Skip to content

Commit

Permalink
[FLINK-2565] Support primitive arrays as keys
Browse files Browse the repository at this point in the history
This closes #1043
  • Loading branch information
supermegaciaccount authored and StephanEwen committed Aug 27, 2015
1 parent 1e38d6f commit 0807eec
Show file tree
Hide file tree
Showing 22 changed files with 1,057 additions and 15 deletions.
Expand Up @@ -24,13 +24,22 @@
import org.apache.flink.api.common.ExecutionConfig;
import org.apache.flink.api.common.functions.InvalidTypesException;
import org.apache.flink.api.common.typeutils.TypeSerializer;
import org.apache.flink.api.common.typeutils.base.array.BooleanPrimitiveArrayComparator;
import org.apache.flink.api.common.typeutils.base.array.BooleanPrimitiveArraySerializer;
import org.apache.flink.api.common.typeutils.base.array.BytePrimitiveArrayComparator;
import org.apache.flink.api.common.typeutils.base.array.BytePrimitiveArraySerializer;
import org.apache.flink.api.common.typeutils.base.array.CharPrimitiveArrayComparator;
import org.apache.flink.api.common.typeutils.base.array.CharPrimitiveArraySerializer;
import org.apache.flink.api.common.typeutils.base.array.DoublePrimitiveArrayComparator;
import org.apache.flink.api.common.typeutils.base.array.DoublePrimitiveArraySerializer;
import org.apache.flink.api.common.typeutils.base.array.FloatPrimitiveArrayComparator;
import org.apache.flink.api.common.typeutils.base.array.FloatPrimitiveArraySerializer;
import org.apache.flink.api.common.typeutils.base.array.IntPrimitiveArrayComparator;
import org.apache.flink.api.common.typeutils.base.array.IntPrimitiveArraySerializer;
import org.apache.flink.api.common.typeutils.base.array.LongPrimitiveArrayComparator;
import org.apache.flink.api.common.typeutils.base.array.LongPrimitiveArraySerializer;
import org.apache.flink.api.common.typeutils.base.array.PrimitiveArrayComparator;
import org.apache.flink.api.common.typeutils.base.array.ShortPrimitiveArrayComparator;
import org.apache.flink.api.common.typeutils.base.array.ShortPrimitiveArraySerializer;

/**
Expand All @@ -39,19 +48,18 @@
*
* @param <T> The type represented by this type information, e.g., int[], double[], long[]
*/
public class PrimitiveArrayTypeInfo<T> extends TypeInformation<T> {
public class PrimitiveArrayTypeInfo<T> extends TypeInformation<T> implements AtomicType<T> {

private static final long serialVersionUID = 1L;

public static final PrimitiveArrayTypeInfo<boolean[]> BOOLEAN_PRIMITIVE_ARRAY_TYPE_INFO = new PrimitiveArrayTypeInfo<boolean[]>(boolean[].class, BooleanPrimitiveArraySerializer.INSTANCE);
public static final PrimitiveArrayTypeInfo<byte[]> BYTE_PRIMITIVE_ARRAY_TYPE_INFO = new PrimitiveArrayTypeInfo<byte[]>(byte[].class, BytePrimitiveArraySerializer.INSTANCE);
public static final PrimitiveArrayTypeInfo<short[]> SHORT_PRIMITIVE_ARRAY_TYPE_INFO = new PrimitiveArrayTypeInfo<short[]>(short[].class, ShortPrimitiveArraySerializer.INSTANCE);
public static final PrimitiveArrayTypeInfo<int[]> INT_PRIMITIVE_ARRAY_TYPE_INFO = new PrimitiveArrayTypeInfo<int[]>(int[].class, IntPrimitiveArraySerializer.INSTANCE);
public static final PrimitiveArrayTypeInfo<long[]> LONG_PRIMITIVE_ARRAY_TYPE_INFO = new PrimitiveArrayTypeInfo<long[]>(long[].class, LongPrimitiveArraySerializer.INSTANCE);
public static final PrimitiveArrayTypeInfo<float[]> FLOAT_PRIMITIVE_ARRAY_TYPE_INFO = new PrimitiveArrayTypeInfo<float[]>(float[].class, FloatPrimitiveArraySerializer.INSTANCE);
public static final PrimitiveArrayTypeInfo<double[]> DOUBLE_PRIMITIVE_ARRAY_TYPE_INFO = new PrimitiveArrayTypeInfo<double[]>(double[].class, DoublePrimitiveArraySerializer.INSTANCE);
public static final PrimitiveArrayTypeInfo<char[]> CHAR_PRIMITIVE_ARRAY_TYPE_INFO = new PrimitiveArrayTypeInfo<char[]>(char[].class, CharPrimitiveArraySerializer.INSTANCE);

public static final PrimitiveArrayTypeInfo<boolean[]> BOOLEAN_PRIMITIVE_ARRAY_TYPE_INFO = new PrimitiveArrayTypeInfo<boolean[]>(boolean[].class, BooleanPrimitiveArraySerializer.INSTANCE, BooleanPrimitiveArrayComparator.class);
public static final PrimitiveArrayTypeInfo<byte[]> BYTE_PRIMITIVE_ARRAY_TYPE_INFO = new PrimitiveArrayTypeInfo<byte[]>(byte[].class, BytePrimitiveArraySerializer.INSTANCE, BytePrimitiveArrayComparator.class);
public static final PrimitiveArrayTypeInfo<short[]> SHORT_PRIMITIVE_ARRAY_TYPE_INFO = new PrimitiveArrayTypeInfo<short[]>(short[].class, ShortPrimitiveArraySerializer.INSTANCE, ShortPrimitiveArrayComparator.class);
public static final PrimitiveArrayTypeInfo<int[]> INT_PRIMITIVE_ARRAY_TYPE_INFO = new PrimitiveArrayTypeInfo<int[]>(int[].class, IntPrimitiveArraySerializer.INSTANCE, IntPrimitiveArrayComparator.class);
public static final PrimitiveArrayTypeInfo<long[]> LONG_PRIMITIVE_ARRAY_TYPE_INFO = new PrimitiveArrayTypeInfo<long[]>(long[].class, LongPrimitiveArraySerializer.INSTANCE, LongPrimitiveArrayComparator.class);
public static final PrimitiveArrayTypeInfo<float[]> FLOAT_PRIMITIVE_ARRAY_TYPE_INFO = new PrimitiveArrayTypeInfo<float[]>(float[].class, FloatPrimitiveArraySerializer.INSTANCE, FloatPrimitiveArrayComparator.class);
public static final PrimitiveArrayTypeInfo<double[]> DOUBLE_PRIMITIVE_ARRAY_TYPE_INFO = new PrimitiveArrayTypeInfo<double[]>(double[].class, DoublePrimitiveArraySerializer.INSTANCE, DoublePrimitiveArrayComparator.class);
public static final PrimitiveArrayTypeInfo<char[]> CHAR_PRIMITIVE_ARRAY_TYPE_INFO = new PrimitiveArrayTypeInfo<char[]>(char[].class, CharPrimitiveArraySerializer.INSTANCE, CharPrimitiveArrayComparator.class);
// --------------------------------------------------------------------------------------------

/** The class of the array (such as int[].class) */
Expand All @@ -60,12 +68,15 @@ public class PrimitiveArrayTypeInfo<T> extends TypeInformation<T> {
/** The serializer for the array */
private final TypeSerializer<T> serializer;

/** The class of the comparator for the array */
private Class<? extends PrimitiveArrayComparator<T, ?>> comparatorClass;

/**
* Creates a new type info for a
* @param arrayClass The class of the array (such as int[].class)
* @param serializer The serializer for the array.
*/
private PrimitiveArrayTypeInfo(Class<T> arrayClass, TypeSerializer<T> serializer) {
private PrimitiveArrayTypeInfo(Class<T> arrayClass, TypeSerializer<T> serializer, Class<? extends PrimitiveArrayComparator<T, ?>> comparatorClass) {
if (arrayClass == null || serializer == null) {
throw new NullPointerException();
}
Expand All @@ -74,6 +85,7 @@ private PrimitiveArrayTypeInfo(Class<T> arrayClass, TypeSerializer<T> serializer
}
this.arrayClass = arrayClass;
this.serializer = serializer;
this.comparatorClass = comparatorClass;
}

// --------------------------------------------------------------------------------------------
Expand Down Expand Up @@ -105,7 +117,7 @@ public Class<T> getTypeClass() {

@Override
public boolean isKeyType() {
return false;
return true;
}

@Override
Expand Down Expand Up @@ -161,4 +173,13 @@ public static <X> PrimitiveArrayTypeInfo<X> getInfoFor(Class<X> type) {
TYPES.put(double[].class, DOUBLE_PRIMITIVE_ARRAY_TYPE_INFO);
TYPES.put(char[].class, CHAR_PRIMITIVE_ARRAY_TYPE_INFO);
}

@Override
public PrimitiveArrayComparator<T, ?> createComparator(boolean sortOrderAscending, ExecutionConfig executionConfig) {
try {
return comparatorClass.getConstructor(boolean.class).newInstance(sortOrderAscending);
} catch (Exception e) {
throw new RuntimeException("Could not initialize primitive " + comparatorClass.getName() + " array comparator.", e);
}
}
}
@@ -0,0 +1,56 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.flink.api.common.typeutils.base.array;

import static java.lang.Math.min;
import org.apache.flink.api.common.typeutils.TypeComparator;
import org.apache.flink.api.common.typeutils.base.BooleanComparator;

public class BooleanPrimitiveArrayComparator extends PrimitiveArrayComparator<boolean[], BooleanComparator> {
public BooleanPrimitiveArrayComparator(boolean ascending) {
super(ascending, new BooleanComparator(ascending));
}

@Override
public int hash(boolean[] record) {
int result = 0;
for (boolean field : record) {
result += field ? 1231 : 1237;
}
return result;
}

@Override
public int compare(boolean[] first, boolean[] second) {
for (int x = 0; x < min(first.length, second.length); x++) {
int cmp = (second[x] == first[x] ? 0 : (first[x] ? 1 : -1));
if (cmp != 0) {
return ascending ? cmp : -cmp;
}
}
int cmp = first.length - second.length;
return ascending ? cmp : -cmp;
}

@Override
public TypeComparator<boolean[]> duplicate() {
BooleanPrimitiveArrayComparator dupe = new BooleanPrimitiveArrayComparator(this.ascending);
dupe.setReference(this.reference);
return dupe;
}
}
@@ -0,0 +1,56 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.flink.api.common.typeutils.base.array;

import static java.lang.Math.min;
import org.apache.flink.api.common.typeutils.TypeComparator;
import org.apache.flink.api.common.typeutils.base.ByteComparator;

public class BytePrimitiveArrayComparator extends PrimitiveArrayComparator<byte[], ByteComparator> {
public BytePrimitiveArrayComparator(boolean ascending) {
super(ascending, new ByteComparator(ascending));
}

@Override
public int hash(byte[] record) {
int result = 0;
for (byte field : record) {
result += (int) field;
}
return result;
}

@Override
public int compare(byte[] first, byte[] second) {
for (int x = 0; x < min(first.length, second.length); x++) {
int cmp = first[x] - second[x];
if (cmp != 0) {
return ascending ? cmp : -cmp;
}
}
int cmp = first.length - second.length;
return ascending ? cmp : -cmp;
}

@Override
public TypeComparator<byte[]> duplicate() {
BytePrimitiveArrayComparator dupe = new BytePrimitiveArrayComparator(this.ascending);
dupe.setReference(this.reference);
return dupe;
}
}
@@ -0,0 +1,56 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.flink.api.common.typeutils.base.array;

import static java.lang.Math.min;
import org.apache.flink.api.common.typeutils.TypeComparator;
import org.apache.flink.api.common.typeutils.base.CharComparator;

public class CharPrimitiveArrayComparator extends PrimitiveArrayComparator<char[], CharComparator> {
public CharPrimitiveArrayComparator(boolean ascending) {
super(ascending, new CharComparator(ascending));
}

@Override
public int hash(char[] record) {
int result = 0;
for (char field : record) {
result += (int) field;
}
return result;
}

@Override
public int compare(char[] first, char[] second) {
for (int x = 0; x < min(first.length, second.length); x++) {
int cmp = first[x] - second[x];
if (cmp != 0) {
return ascending ? cmp : -cmp;
}
}
int cmp = first.length - second.length;
return ascending ? cmp : -cmp;
}

@Override
public TypeComparator<char[]> duplicate() {
CharPrimitiveArrayComparator dupe = new CharPrimitiveArrayComparator(this.ascending);
dupe.setReference(this.reference);
return dupe;
}
}
@@ -0,0 +1,57 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.flink.api.common.typeutils.base.array;

import static java.lang.Math.min;
import org.apache.flink.api.common.typeutils.TypeComparator;
import org.apache.flink.api.common.typeutils.base.DoubleComparator;

public class DoublePrimitiveArrayComparator extends PrimitiveArrayComparator<double[], DoubleComparator> {
public DoublePrimitiveArrayComparator(boolean ascending) {
super(ascending, new DoubleComparator(ascending));
}

@Override
public int hash(double[] record) {
int result = 0;
for (double field : record) {
long bits = Double.doubleToLongBits(field);
result += (int) (bits ^ (bits >>> 32));
}
return result;
}

@Override
public int compare(double[] first, double[] second) {
for (int x = 0; x < min(first.length, second.length); x++) {
int cmp = Double.compare(first[x], second[x]);
if (cmp != 0) {
return ascending ? cmp : -cmp;
}
}
int cmp = first.length - second.length;
return ascending ? cmp : -cmp;
}

@Override
public TypeComparator<double[]> duplicate() {
DoublePrimitiveArrayComparator dupe = new DoublePrimitiveArrayComparator(this.ascending);
dupe.setReference(this.reference);
return dupe;
}
}
@@ -0,0 +1,56 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.flink.api.common.typeutils.base.array;

import static java.lang.Math.min;
import org.apache.flink.api.common.typeutils.TypeComparator;
import org.apache.flink.api.common.typeutils.base.FloatComparator;

public class FloatPrimitiveArrayComparator extends PrimitiveArrayComparator<float[], FloatComparator> {
public FloatPrimitiveArrayComparator(boolean ascending) {
super(ascending, new FloatComparator(ascending));
}

@Override
public int hash(float[] record) {
int result = 0;
for (float field : record) {
result += Float.floatToIntBits(field);
}
return result;
}

@Override
public int compare(float[] first, float[] second) {
for (int x = 0; x < min(first.length, second.length); x++) {
int cmp = Float.compare(first[x], second[x]);
if (cmp != 0) {
return ascending ? cmp : -cmp;
}
}
int cmp = first.length - second.length;
return ascending ? cmp : -cmp;
}

@Override
public TypeComparator<float[]> duplicate() {
FloatPrimitiveArrayComparator dupe = new FloatPrimitiveArrayComparator(this.ascending);
dupe.setReference(this.reference);
return dupe;
}
}

0 comments on commit 0807eec

Please sign in to comment.