Skip to content

Commit

Permalink
[FLINK-2678]DataSet API does not support multi-dimensional arrays as …
Browse files Browse the repository at this point in the history
…keys
  • Loading branch information
sbcd90 committed Feb 4, 2016
1 parent aabb268 commit 6ae8399
Show file tree
Hide file tree
Showing 5 changed files with 453 additions and 9 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,16 @@

import com.google.common.base.Preconditions;
import org.apache.flink.api.common.ExecutionConfig;
import org.apache.flink.api.common.typeinfo.AtomicType;
import org.apache.flink.api.common.typeinfo.PrimitiveArrayTypeInfo;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.common.typeutils.CompositeType;
import org.apache.flink.api.common.typeutils.TypeComparator;
import org.apache.flink.api.common.typeutils.TypeSerializer;
import org.apache.flink.api.java.typeutils.runtime.ObjectArrayComparator;
import org.apache.flink.api.common.typeutils.base.GenericArraySerializer;

public class ObjectArrayTypeInfo<T, C> extends TypeInformation<T> {
public class ObjectArrayTypeInfo<T, C> extends TypeInformation<T> implements AtomicType<T> {

private static final long serialVersionUID = 1L;

Expand Down Expand Up @@ -72,15 +77,59 @@ public TypeInformation<C> getComponentInfo() {

@Override
public boolean isKeyType() {
return false;
return true;
}

@SuppressWarnings("unchecked")
@Override
public TypeSerializer<T> createSerializer(ExecutionConfig executionConfig) {
return (TypeSerializer<T>) new GenericArraySerializer<C>(
componentInfo.getTypeClass(),
componentInfo.createSerializer(executionConfig));
componentInfo.getTypeClass(),
componentInfo.createSerializer(executionConfig));
}

@SuppressWarnings("unchecked")
private TypeComparator<? super Object> getBaseComparatorInfo(TypeInformation<? extends Object> componentInfo, boolean sortOrderAscending, ExecutionConfig executionConfig) {
/**
* method tries to find out the Comparator to be used to compare each element (of primitive type or composite type) of the provided Object arrays.
*/
if (componentInfo instanceof ObjectArrayTypeInfo) {
return getBaseComparatorInfo(((ObjectArrayTypeInfo) componentInfo).getComponentInfo(), sortOrderAscending, executionConfig);
}
else if (componentInfo instanceof PrimitiveArrayTypeInfo) {
return getBaseComparatorInfo(((PrimitiveArrayTypeInfo<? extends Object>) componentInfo).getComponentType(), sortOrderAscending, executionConfig);
}
else {
if (componentInfo instanceof AtomicType) {
return ((AtomicType<? super Object>) componentInfo).createComparator(sortOrderAscending, executionConfig);
}
else if (componentInfo instanceof CompositeType) {
int componentArity = ((CompositeType<? extends Object>) componentInfo).getArity();
int [] logicalKeyFields = new int[componentArity];
boolean[] orders = new boolean[componentArity];

for (int i=0;i < componentArity;i++) {
logicalKeyFields[i] = i;
orders[i] = sortOrderAscending;
}

return ((CompositeType<? super Object>) componentInfo).createComparator(logicalKeyFields, orders, 0, executionConfig);
}
else {
throw new IllegalArgumentException("Could not add a comparator for the component type " + componentInfo.getClass().getName());
}
}
}

@SuppressWarnings("unchecked")
@Override
public TypeComparator<T> createComparator(boolean sortOrderAscending, ExecutionConfig executionConfig) {

return (TypeComparator<T>) new ObjectArrayComparator<T,C>(
sortOrderAscending,
(GenericArraySerializer<T>) createSerializer(executionConfig),
getBaseComparatorInfo(componentInfo, sortOrderAscending, executionConfig)
);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,199 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.flink.api.java.typeutils.runtime;

import org.apache.flink.api.common.typeutils.TypeComparator;
import org.apache.flink.api.common.typeutils.TypeSerializer;
import org.apache.flink.core.memory.DataInputView;
import org.apache.flink.core.memory.DataOutputView;
import org.apache.flink.core.memory.MemorySegment;

import java.io.IOException;
import java.lang.reflect.Array;
import java.util.Arrays;


public class ObjectArrayComparator<T,C> extends TypeComparator<T[]> implements java.io.Serializable {

private static final long serialVersionUID = 1L;

private transient T[] reference;

protected final boolean ascendingComparison;

private final TypeSerializer<T[]> serializer;

private TypeComparator<? super Object> comparatorInfo;

// For use by getComparators
@SuppressWarnings("rawtypes")
private final TypeComparator[] comparators = new TypeComparator[] {this};

public ObjectArrayComparator(boolean ascending, TypeSerializer<T[]> serializer, TypeComparator<? super Object> comparatorInfo) {
this.ascendingComparison = ascending;
this.serializer = serializer;
this.comparatorInfo = comparatorInfo;
}

@Override
public void setReference(T[] reference) {
this.reference = reference;
}

@Override
public boolean equalToReference(T[] candidate) {
return compare(this.reference, candidate) == 0;
}

@Override
public int compareToReference(TypeComparator<T[]> referencedComparator) {
int comp = compare(((ObjectArrayComparator<T,C>) referencedComparator).reference, reference);
return comp;
}

@Override
public int compareSerialized(DataInputView firstSource, DataInputView secondSource) throws IOException {
T[] firstArray = serializer.deserialize(firstSource);
T[] secondArray = serializer.deserialize(secondSource);

int comp = compare(firstArray, secondArray);
return comp;
}

@Override
public int extractKeys(Object record, Object[] target, int index) {
target[index] = record;
return 1;
}

@Override
public TypeComparator[] getFlatComparators() {
return comparators;
}

@Override
public boolean supportsNormalizedKey() {
return false;
}

@Override
public boolean supportsSerializationWithKeyNormalization() {
return false;
}

@Override
public int getNormalizeKeyLen() {
return 0;
}

@Override
public boolean isNormalizedKeyPrefixOnly(int keyBytes) {
throw new UnsupportedOperationException();
}

@Override
public void putNormalizedKey(T[] record, MemorySegment target, int offset, int numBytes) {
throw new UnsupportedOperationException();
}

@Override
public void writeWithKeyNormalization(T[] record, DataOutputView target) throws IOException {
throw new UnsupportedOperationException();
}

@Override
public T[] readWithKeyDenormalization(T[] reuse, DataInputView source) throws IOException {
throw new UnsupportedOperationException();
}

@Override
public boolean invertNormalizedKey() {
return !ascendingComparison;
}

@Override
public int hash(T[] record) {
return Arrays.hashCode(record);
}

private int compareValues(Object first, Object second) {
/**
* uses the chosen comparator ( of primitive or composite type ) & compares the provided objects as input
*/
return comparatorInfo.compare(first, second);
}

@SuppressWarnings("unchecked")
private int parseGenericArray(Object firstArray, Object secondArray) {
int compareResult = 0;

/**
* logic to determine comparison result due to length difference.
* the length difference cannot fully determine the result of the comparison. Hence, result added to tempResult.
*/
int min = Array.getLength(firstArray);
int tempResult = 0;
if (min < Array.getLength(secondArray)) {
tempResult = ascendingComparison ? -1: 1;
}
if (min > Array.getLength(secondArray)) {
tempResult = ascendingComparison? 1: -1;
min = Array.getLength(secondArray);
}

/**
* comparing the actual content of two arrays.
*/
for (int i=0;i < min;i++) {
int val;

if (!Array.get(firstArray, i).getClass().isArray() && !Array.get(secondArray, i).getClass().isArray()) {
val = compareValues(Array.get(firstArray, i), Array.get(secondArray, i));
}
else {
val = parseGenericArray(Array.get(firstArray, i), Array.get(secondArray, i));
}

if (val != 0) {
compareResult = val;
break;
}
}

/**
* if the actual comparison cannot distinguish between two arrays, then length differences take preference.
*/
if (compareResult == 0) {
compareResult = tempResult;
}
return compareResult;
}

@Override
public int compare(T[] first, T[] second) {
return parseGenericArray(first, second);
}

@Override
public TypeComparator<T[]> duplicate() {
ObjectArrayComparator<T,C> dupe = new ObjectArrayComparator<T,C>(ascendingComparison, serializer, comparatorInfo);
dupe.setReference(this.reference);
return dupe;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.flink.api.java.typeutils.runtime;

import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.common.typeutils.ComparatorTestBase;
import org.apache.flink.api.common.typeutils.CompositeType;
import org.apache.flink.api.common.typeutils.TypeComparator;
import org.apache.flink.api.common.typeutils.TypeSerializer;
import org.apache.flink.api.common.typeutils.base.GenericArraySerializer;
import org.apache.flink.api.java.tuple.Tuple;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.api.java.typeutils.ObjectArrayTypeInfo;
import org.apache.flink.api.java.typeutils.TupleTypeInfo;
import org.junit.Assert;

import java.lang.reflect.Array;

public class ObjectArrayComparatorCompositeTypeTest extends ComparatorTestBase<Tuple2<String, Integer>[][]> {
private final TypeInformation<Tuple2<String, Integer>[]> componentInfo;

public ObjectArrayComparatorCompositeTypeTest() {
this.componentInfo = ObjectArrayTypeInfo.getInfoFor(new TupleTypeInfo<Tuple>(BasicTypeInfo.STRING_TYPE_INFO, BasicTypeInfo.INT_TYPE_INFO));
}

@SuppressWarnings("unchecked")
@Override
protected TypeSerializer<Tuple2<String, Integer>[][]> createSerializer() {
return (TypeSerializer<Tuple2<String, Integer>[][]>) new GenericArraySerializer<Tuple2<String, Integer>[]>(
componentInfo.getTypeClass(),
componentInfo.createSerializer(null));
}

@SuppressWarnings("unchecked")
@Override
protected TypeComparator<Tuple2<String, Integer>[][]> createComparator(boolean ascending) {
CompositeType<? extends Object> baseComponentInfo = new TupleTypeInfo<Tuple>(BasicTypeInfo.STRING_TYPE_INFO, BasicTypeInfo.INT_TYPE_INFO);
int componentArity = baseComponentInfo.getArity();
int [] logicalKeyFields = new int[componentArity];
boolean[] orders = new boolean[componentArity];

for (int i=0;i < componentArity;i++) {
logicalKeyFields[i] = i;
orders[i] = ascending;
}

return (TypeComparator<Tuple2<String, Integer>[][]>) new ObjectArrayComparator<Tuple2<String, Integer>[], Character>(ascending,
(GenericArraySerializer<Tuple2<String, Integer>[]>) createSerializer(),
((CompositeType<? super Object>) baseComponentInfo).createComparator(logicalKeyFields, orders, 0, null)
);
}

@Override
protected void deepEquals(String message, Tuple2<String, Integer>[][] should, Tuple2<String, Integer>[][] is) {
Assert.assertTrue(should.length==is.length);
for (int i=0;i < should.length;i++) {
Assert.assertTrue(should[i].length==is[i].length);
for (int j=0;j < should[i].length;j++) {
Assert.assertEquals(should[i][j].f0,is[i][j].f0);
Assert.assertEquals(should[i][j].f1,is[i][j].f1);
}
}
}

@SuppressWarnings("unchecked")
@Override
protected Tuple2<String, Integer>[][][] getSortedTestData() {
Object result = Array.newInstance(Tuple2.class, new int[]{2, 2, 1});

((Tuple2<String, Integer>[][][]) result)[0][0][0] = new Tuple2<String, Integer>();
((Tuple2<String, Integer>[][][]) result)[0][0][0].f0 = "be";
((Tuple2<String, Integer>[][][]) result)[0][0][0].f1 = 2;

((Tuple2<String, Integer>[][][]) result)[0][1][0] = new Tuple2<String, Integer>();
((Tuple2<String, Integer>[][][]) result)[0][1][0].f0 = "not";
((Tuple2<String, Integer>[][][]) result)[0][1][0].f1 = 3;


((Tuple2<String, Integer>[][][]) result)[1][0][0] = new Tuple2<String, Integer>();
((Tuple2<String, Integer>[][][]) result)[1][0][0].f0 = "or";
((Tuple2<String, Integer>[][][]) result)[1][0][0].f1 = 2;

((Tuple2<String, Integer>[][][]) result)[1][1][0] = new Tuple2<String, Integer>();
((Tuple2<String, Integer>[][][]) result)[1][1][0].f0 = "to";
((Tuple2<String, Integer>[][][]) result)[1][1][0].f1 = 2;

return (Tuple2<String, Integer>[][][]) result;
}
}
Loading

0 comments on commit 6ae8399

Please sign in to comment.