Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[FLINK-2678]DataSet API does not support multi-dimensional arrays as keys #1566

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,16 @@

import com.google.common.base.Preconditions;
import org.apache.flink.api.common.ExecutionConfig;
import org.apache.flink.api.common.typeinfo.AtomicType;
import org.apache.flink.api.common.typeinfo.PrimitiveArrayTypeInfo;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.common.typeutils.CompositeType;
import org.apache.flink.api.common.typeutils.TypeComparator;
import org.apache.flink.api.common.typeutils.TypeSerializer;
import org.apache.flink.api.java.typeutils.runtime.ObjectArrayComparator;
import org.apache.flink.api.common.typeutils.base.GenericArraySerializer;

public class ObjectArrayTypeInfo<T, C> extends TypeInformation<T> {
public class ObjectArrayTypeInfo<T, C> extends TypeInformation<T> implements AtomicType<T> {

private static final long serialVersionUID = 1L;

Expand Down Expand Up @@ -72,15 +77,59 @@ public TypeInformation<C> getComponentInfo() {

@Override
public boolean isKeyType() {
return false;
return true;
}

@SuppressWarnings("unchecked")
@Override
public TypeSerializer<T> createSerializer(ExecutionConfig executionConfig) {
return (TypeSerializer<T>) new GenericArraySerializer<C>(
componentInfo.getTypeClass(),
componentInfo.createSerializer(executionConfig));
componentInfo.getTypeClass(),
componentInfo.createSerializer(executionConfig));
}

@SuppressWarnings("unchecked")
private TypeComparator<? super Object> getBaseComparatorInfo(TypeInformation<? extends Object> componentInfo, boolean sortOrderAscending, ExecutionConfig executionConfig) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do you extract the element comparator for certain types (for example, ObjectArrayTypeInfo) but create the TypeComparator for others (for example, CompositeTypeInfo)? I don't get it. Why do you need the getBaseComparatorInfo method at all? Simply check the different subtypes in createComparator and then create the TypeComparator.

/**
* method tries to find out the Comparator to be used to compare each element (of primitive type or composite type) of the provided Object arrays.
*/
if (componentInfo instanceof ObjectArrayTypeInfo) {
return getBaseComparatorInfo(((ObjectArrayTypeInfo) componentInfo).getComponentInfo(), sortOrderAscending, executionConfig);
}
else if (componentInfo instanceof PrimitiveArrayTypeInfo) {
return getBaseComparatorInfo(((PrimitiveArrayTypeInfo<? extends Object>) componentInfo).getComponentType(), sortOrderAscending, executionConfig);
}
else {
if (componentInfo instanceof AtomicType) {
return ((AtomicType<? super Object>) componentInfo).createComparator(sortOrderAscending, executionConfig);
}
else if (componentInfo instanceof CompositeType) {
int componentArity = ((CompositeType<? extends Object>) componentInfo).getArity();
int [] logicalKeyFields = new int[componentArity];
boolean[] orders = new boolean[componentArity];

for (int i=0;i < componentArity;i++) {
logicalKeyFields[i] = i;
orders[i] = sortOrderAscending;
}

return ((CompositeType<? super Object>) componentInfo).createComparator(logicalKeyFields, orders, 0, executionConfig);
}
else {
throw new IllegalArgumentException("Could not add a comparator for the component type " + componentInfo.getClass().getName());
}
}
}

@SuppressWarnings("unchecked")
@Override
public TypeComparator<T> createComparator(boolean sortOrderAscending, ExecutionConfig executionConfig) {

return (TypeComparator<T>) new ObjectArrayComparator<T,C>(
sortOrderAscending,
(GenericArraySerializer<T>) createSerializer(executionConfig),
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why this cast here?

getBaseComparatorInfo(componentInfo, sortOrderAscending, executionConfig)
);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,199 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.flink.api.java.typeutils.runtime;

import org.apache.flink.api.common.typeutils.TypeComparator;
import org.apache.flink.api.common.typeutils.TypeSerializer;
import org.apache.flink.core.memory.DataInputView;
import org.apache.flink.core.memory.DataOutputView;
import org.apache.flink.core.memory.MemorySegment;

import java.io.IOException;
import java.lang.reflect.Array;
import java.util.Arrays;


public class ObjectArrayComparator<T,C> extends TypeComparator<T[]> implements java.io.Serializable {

private static final long serialVersionUID = 1L;

// Reference array installed by setReference(); declared transient, so it is not serialized with the comparator.
private transient T[] reference;

// Sort-direction flag supplied to the constructor; read by invertNormalizedKey() and by the length tie-break in parseGenericArray().
protected final boolean ascendingComparison;

// Serializer used by compareSerialized() to deserialize both arrays before comparing them.
private final TypeSerializer<T[]> serializer;

// Comparator applied to individual (non-array) elements; see compareValues().
private TypeComparator<? super Object> comparatorInfo;

// Single-element array holding this comparator; returned by getFlatComparators().
@SuppressWarnings("rawtypes")
private final TypeComparator[] comparators = new TypeComparator[] {this};

/**
 * Creates a comparator for object arrays.
 *
 * @param ascending sort-direction flag stored in {@code ascendingComparison}
 * @param serializer serializer for the array type; used when comparing serialized data
 * @param comparatorInfo comparator used to compare the individual array elements
 */
public ObjectArrayComparator(boolean ascending, TypeSerializer<T[]> serializer, TypeComparator<? super Object> comparatorInfo) {
this.ascendingComparison = ascending;
this.serializer = serializer;
this.comparatorInfo = comparatorInfo;
}

/**
 * Stores the given array as the reference used by {@code equalToReference} and
 * {@code compareToReference}. The array is kept by reference, not copied.
 */
@Override
public void setReference(T[] reference) {
this.reference = reference;
}

/**
 * Tells whether {@code candidate} compares as equal to the current reference array.
 *
 * @param candidate the array to check against the reference
 * @return {@code true} if {@code compare(reference, candidate)} yields zero
 */
@Override
public boolean equalToReference(T[] candidate) {
    final int order = compare(this.reference, candidate);
    return order == 0;
}

/**
 * Compares the reference held by another comparator with this comparator's reference.
 *
 * @param referencedComparator comparator whose reference is used as the first operand;
 *        it is cast to {@code ObjectArrayComparator}
 * @return the result of {@code compare(otherReference, thisReference)}
 */
@Override
public int compareToReference(TypeComparator<T[]> referencedComparator) {
    final ObjectArrayComparator<T, C> other = (ObjectArrayComparator<T, C>) referencedComparator;
    return compare(other.reference, this.reference);
}

/**
 * Compares two serialized arrays by fully deserializing both and comparing the results.
 *
 * @param firstSource view positioned at the first serialized array
 * @param secondSource view positioned at the second serialized array
 * @return the result of {@code compare} on the two deserialized arrays
 * @throws IOException if deserialization fails
 */
@Override
public int compareSerialized(DataInputView firstSource, DataInputView secondSource) throws IOException {
    // Arguments are evaluated left to right, so the first source is read before the second.
    return compare(serializer.deserialize(firstSource), serializer.deserialize(secondSource));
}

/**
 * Writes the record itself into {@code target[index]} as the single extracted key.
 *
 * @return 1, the number of keys written
 */
@Override
public int extractKeys(Object record, Object[] target, int index) {
target[index] = record;
return 1;
}

/**
 * Returns the pre-built {@code comparators} array, whose only element is this comparator.
 */
@Override
public TypeComparator[] getFlatComparators() {
return comparators;
}

/**
 * Always returns {@code false}: this comparator does not provide normalized keys.
 */
@Override
public boolean supportsNormalizedKey() {
return false;
}

/**
 * Always returns {@code false}; the key-normalizing write/read methods of this class throw
 * {@code UnsupportedOperationException}.
 */
@Override
public boolean supportsSerializationWithKeyNormalization() {
return false;
}

/**
 * Always returns 0, matching {@code supportsNormalizedKey()} returning {@code false}.
 */
@Override
public int getNormalizeKeyLen() {
return 0;
}

/**
 * Not supported by this comparator.
 *
 * @throws UnsupportedOperationException always
 */
@Override
public boolean isNormalizedKeyPrefixOnly(int keyBytes) {
throw new UnsupportedOperationException();
}

/**
 * Not supported by this comparator.
 *
 * @throws UnsupportedOperationException always
 */
@Override
public void putNormalizedKey(T[] record, MemorySegment target, int offset, int numBytes) {
throw new UnsupportedOperationException();
}

/**
 * Not supported by this comparator.
 *
 * @throws UnsupportedOperationException always
 */
@Override
public void writeWithKeyNormalization(T[] record, DataOutputView target) throws IOException {
throw new UnsupportedOperationException();
}

/**
 * Not supported by this comparator.
 *
 * @throws UnsupportedOperationException always
 */
@Override
public T[] readWithKeyDenormalization(T[] reuse, DataInputView source) throws IOException {
throw new UnsupportedOperationException();
}

/**
 * Returns the negation of {@code ascendingComparison}, i.e. {@code true} for a descending comparator.
 */
@Override
public boolean invertNormalizedKey() {
return !ascendingComparison;
}

/**
 * Computes a content-based hash code for the given array.
 *
 * <p>{@link Arrays#deepHashCode(Object[])} is used instead of {@code Arrays.hashCode} because
 * the elements may themselves be arrays (multi-dimensional keys). {@code Arrays.hashCode}
 * would hash nested arrays by identity, so two arrays that {@code compare} as equal could
 * produce different hashes. For arrays whose elements are not arrays, both methods return
 * the same value.
 *
 * @param record the array to hash; {@code null} yields 0
 * @return the deep, content-based hash code of {@code record}
 */
@Override
public int hash(T[] record) {
    return Arrays.deepHashCode(record);
}

/**
 * Compares two array elements using the element comparator passed to the constructor
 * (a comparator for a primitive or composite type).
 */
private int compareValues(Object first, Object second) {
return comparatorInfo.compare(first, second);
}

@SuppressWarnings("unchecked")
private int parseGenericArray(Object firstArray, Object secondArray) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why exactly are you doing all the array operation on Objects?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this is not the right way to go. Simply check the length of the arrays in the compare method, and then call the type comparator of the array's component type (which you have extracted from the component type info) for each pair of elements.

int compareResult = 0;

/**
* logic to determine comparison result due to length difference.
* the length difference cannot fully determine the result of the comparison. Hence, result added to tempResult.
*/
int min = Array.getLength(firstArray);
int tempResult = 0;
if (min < Array.getLength(secondArray)) {
tempResult = ascendingComparison ? -1: 1;
}
if (min > Array.getLength(secondArray)) {
tempResult = ascendingComparison? 1: -1;
min = Array.getLength(secondArray);
}

/**
* comparing the actual content of two arrays.
*/
for (int i=0;i < min;i++) {
int val;

if (!Array.get(firstArray, i).getClass().isArray() && !Array.get(secondArray, i).getClass().isArray()) {
val = compareValues(Array.get(firstArray, i), Array.get(secondArray, i));
}
else {
val = parseGenericArray(Array.get(firstArray, i), Array.get(secondArray, i));
}

if (val != 0) {
compareResult = val;
break;
}
}

/**
* if the actual comparison cannot distinguish between two arrays, then length differences take preference.
*/
if (compareResult == 0) {
compareResult = tempResult;
}
return compareResult;
}

/**
 * Compares two arrays by delegating to {@code parseGenericArray}.
 */
@Override
public int compare(T[] first, T[] second) {
return parseGenericArray(first, second);
}

/**
 * Creates a new comparator with the same sort direction, serializer and element comparator,
 * and hands it this comparator's current reference.
 *
 * @return the newly created comparator
 */
@Override
public TypeComparator<T[]> duplicate() {
    final ObjectArrayComparator<T, C> copy =
        new ObjectArrayComparator<T, C>(this.ascendingComparison, this.serializer, this.comparatorInfo);
    // The reference array is shared with the copy, not cloned.
    copy.setReference(reference);
    return copy;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.flink.api.java.typeutils.runtime;

import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.common.typeutils.ComparatorTestBase;
import org.apache.flink.api.common.typeutils.CompositeType;
import org.apache.flink.api.common.typeutils.TypeComparator;
import org.apache.flink.api.common.typeutils.TypeSerializer;
import org.apache.flink.api.common.typeutils.base.GenericArraySerializer;
import org.apache.flink.api.java.tuple.Tuple;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.api.java.typeutils.ObjectArrayTypeInfo;
import org.apache.flink.api.java.typeutils.TupleTypeInfo;
import org.junit.Assert;

import java.lang.reflect.Array;

public class ObjectArrayComparatorCompositeTypeTest extends ComparatorTestBase<Tuple2<String, Integer>[][]> {
// Type information for Tuple2<String, Integer>[], the component type of the two-dimensional arrays under test.
private final TypeInformation<Tuple2<String, Integer>[]> componentInfo;

/**
 * Builds {@code componentInfo} via {@code ObjectArrayTypeInfo.getInfoFor} from a tuple type
 * with a String field and an Integer field.
 */
public ObjectArrayComparatorCompositeTypeTest() {
this.componentInfo = ObjectArrayTypeInfo.getInfoFor(new TupleTypeInfo<Tuple>(BasicTypeInfo.STRING_TYPE_INFO, BasicTypeInfo.INT_TYPE_INFO));
}

/**
 * Creates the serializer for the two-dimensional tuple arrays under test: a
 * {@code GenericArraySerializer} whose element class and element serializer are taken from
 * {@code componentInfo}. A {@code null} ExecutionConfig is passed when creating the element serializer.
 */
@SuppressWarnings("unchecked")
@Override
protected TypeSerializer<Tuple2<String, Integer>[][]> createSerializer() {
return (TypeSerializer<Tuple2<String, Integer>[][]>) new GenericArraySerializer<Tuple2<String, Integer>[]>(
componentInfo.getTypeClass(),
componentInfo.createSerializer(null));
}

/**
 * Creates the comparator under test.
 *
 * <p>The element comparator is a tuple comparator over all tuple fields (key positions
 * 0..arity-1), each using the requested sort direction. It is wrapped in an
 * {@code ObjectArrayComparator} together with the serializer from {@code createSerializer()}.
 *
 * @param ascending sort direction used for every tuple field and for the array comparator
 */
@SuppressWarnings("unchecked")
@Override
protected TypeComparator<Tuple2<String, Integer>[][]> createComparator(boolean ascending) {
CompositeType<? extends Object> baseComponentInfo = new TupleTypeInfo<Tuple>(BasicTypeInfo.STRING_TYPE_INFO, BasicTypeInfo.INT_TYPE_INFO);
int componentArity = baseComponentInfo.getArity();
int [] logicalKeyFields = new int[componentArity];
boolean[] orders = new boolean[componentArity];

// Use every tuple field as a key, in field order, with the same sort direction.
for (int i=0;i < componentArity;i++) {
logicalKeyFields[i] = i;
orders[i] = ascending;
}

// NOTE(review): the second type argument (Character) looks arbitrary — confirm it is unused by ObjectArrayComparator.
return (TypeComparator<Tuple2<String, Integer>[][]>) new ObjectArrayComparator<Tuple2<String, Integer>[], Character>(ascending,
(GenericArraySerializer<Tuple2<String, Integer>[]>) createSerializer(),
((CompositeType<? super Object>) baseComponentInfo).createComparator(logicalKeyFields, orders, 0, null)
);
}

/**
 * Asserts that two two-dimensional tuple arrays have the same shape and that the tuples at
 * each position have equal {@code f0} and {@code f1} fields.
 *
 * <p>The supplied {@code message} is forwarded to every assertion, and lengths are checked
 * with {@code assertEquals} so that a failure reports the expected and actual values.
 *
 * @param message failure message supplied by the test base
 * @param should the expected array
 * @param is the actual array
 */
@Override
protected void deepEquals(String message, Tuple2<String, Integer>[][] should, Tuple2<String, Integer>[][] is) {
    Assert.assertEquals(message, should.length, is.length);
    for (int i = 0; i < should.length; i++) {
        Assert.assertEquals(message, should[i].length, is[i].length);
        for (int j = 0; j < should[i].length; j++) {
            Assert.assertEquals(message, should[i][j].f0, is[i][j].f0);
            Assert.assertEquals(message, should[i][j].f1, is[i][j].f1);
        }
    }
}

@SuppressWarnings("unchecked")
@Override
protected Tuple2<String, Integer>[][][] getSortedTestData() {
Object result = Array.newInstance(Tuple2.class, new int[]{2, 2, 1});

((Tuple2<String, Integer>[][][]) result)[0][0][0] = new Tuple2<String, Integer>();
((Tuple2<String, Integer>[][][]) result)[0][0][0].f0 = "be";
((Tuple2<String, Integer>[][][]) result)[0][0][0].f1 = 2;

((Tuple2<String, Integer>[][][]) result)[0][1][0] = new Tuple2<String, Integer>();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You can write something like Tuple2.of("not", 3), which makes your life easier.

((Tuple2<String, Integer>[][][]) result)[0][1][0].f0 = "not";
((Tuple2<String, Integer>[][][]) result)[0][1][0].f1 = 3;


((Tuple2<String, Integer>[][][]) result)[1][0][0] = new Tuple2<String, Integer>();
((Tuple2<String, Integer>[][][]) result)[1][0][0].f0 = "or";
((Tuple2<String, Integer>[][][]) result)[1][0][0].f1 = 2;

((Tuple2<String, Integer>[][][]) result)[1][1][0] = new Tuple2<String, Integer>();
((Tuple2<String, Integer>[][][]) result)[1][1][0].f0 = "to";
((Tuple2<String, Integer>[][][]) result)[1][1][0].f1 = 2;

return (Tuple2<String, Integer>[][][]) result;
}
}
Loading