Skip to content

Commit

Permalink
[FLINK-2678]DataSet API does not support multi-dimensional arrays as …
Browse files Browse the repository at this point in the history
…keys
  • Loading branch information
sbcd90 committed Feb 1, 2016
1 parent ef58cf3 commit 30da04d
Show file tree
Hide file tree
Showing 4 changed files with 305 additions and 6 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,205 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.flink.api.common.typeutils.base;

import org.apache.flink.api.common.typeutils.TypeComparator;
import org.apache.flink.api.common.typeutils.TypeSerializer;
import org.apache.flink.core.memory.DataInputView;
import org.apache.flink.core.memory.DataOutputView;
import org.apache.flink.core.memory.MemorySegment;

import java.io.IOException;
import java.lang.reflect.Array;
import java.util.Arrays;

/**
 * Comparator for object arrays whose elements may themselves be arrays
 * (for example {@code char[][]}), so that such arrays can be used as keys.
 *
 * <p>Two arrays are compared lexicographically: the first pair of elements that
 * differs decides the result, and if one array is a prefix of the other the
 * shorter one sorts first. Nested arrays are compared recursively with the same
 * rule.
 *
 * <p>{@link #compare(Object[], Object[])} always yields the natural (ascending)
 * order. The {@code ascending} flag passed to the constructor is applied by
 * {@link #compareToReference(TypeComparator)} and
 * {@link #compareSerialized(DataInputView, DataInputView)}.
 *
 * <p>Normalized keys are not supported.
 *
 * @param <T> component type of the compared arrays
 */
public class GenericArrayComparator<T> extends TypeComparator<T[]> implements java.io.Serializable {

    private static final long serialVersionUID = 1L;

    /** Reference array set through {@link #setReference(Object[])}; not serialized. */
    private transient T[] reference;

    /** {@code true} for ascending order, {@code false} for descending order. */
    protected final boolean ascendingComparison;

    /** Serializer used to materialize both arrays in {@link #compareSerialized}. */
    private final TypeSerializer<T[]> serializer;

    // For use by getFlatComparators
    @SuppressWarnings("rawtypes")
    private final TypeComparator[] comparators = new TypeComparator[] {this};

    /**
     * @param ascending  whether reference-based and serialized comparisons sort ascending
     * @param serializer serializer for the compared array type
     */
    public GenericArrayComparator(boolean ascending, TypeSerializer<T[]> serializer) {
        this.ascendingComparison = ascending;
        this.serializer = serializer;
    }

    @Override
    public void setReference(T[] reference) {
        this.reference = reference;
    }

    @Override
    public boolean equalToReference(T[] candidate) {
        return compare(this.reference, candidate) == 0;
    }

    @Override
    public int compareToReference(TypeComparator<T[]> referencedComparator) {
        T[] otherReference = ((GenericArrayComparator<T>) referencedComparator).reference;
        int comp = compare(otherReference, reference);
        return ascendingComparison ? comp : -comp;
    }

    /**
     * Deserializes one array from each source and compares them, honouring the
     * configured sort direction.
     */
    @Override
    public int compareSerialized(DataInputView firstSource, DataInputView secondSource) throws IOException {
        T[] firstArray = serializer.deserialize(firstSource);
        T[] secondArray = serializer.deserialize(secondSource);

        int comp = compare(firstArray, secondArray);
        return ascendingComparison ? comp : -comp;
    }

    /** The whole array is the single key: it is stored at {@code target[index]}. */
    @Override
    public int extractKeys(Object record, Object[] target, int index) {
        target[index] = record;
        return 1;
    }

    @SuppressWarnings("rawtypes")
    @Override
    public TypeComparator[] getFlatComparators() {
        return comparators;
    }

    @Override
    public boolean supportsNormalizedKey() {
        return false;
    }

    @Override
    public boolean supportsSerializationWithKeyNormalization() {
        return false;
    }

    @Override
    public int getNormalizeKeyLen() {
        return 0;
    }

    @Override
    public boolean isNormalizedKeyPrefixOnly(int keyBytes) {
        throw new UnsupportedOperationException();
    }

    @Override
    public void putNormalizedKey(T[] record, MemorySegment target, int offset, int numBytes) {
        throw new UnsupportedOperationException();
    }

    @Override
    public void writeWithKeyNormalization(T[] record, DataOutputView target) throws IOException {
        throw new UnsupportedOperationException();
    }

    @Override
    public T[] readWithKeyDenormalization(T[] reuse, DataInputView source) throws IOException {
        throw new UnsupportedOperationException();
    }

    @Override
    public boolean invertNormalizedKey() {
        return !ascendingComparison;
    }

    /**
     * Hashes the array by content.
     *
     * <p>{@code Arrays.deepHashCode} is required here: {@code Arrays.hashCode}
     * would hash nested arrays by identity, so two multi-dimensional arrays with
     * equal contents would end up with different hash codes.
     */
    @Override
    public int hash(T[] record) {
        return Arrays.deepHashCode(record);
    }

    /**
     * Compares two arrays in natural (ascending) order, ignoring the configured
     * sort direction.
     *
     * @throws IllegalArgumentException if two elements of the same class have to
     *         be compared and that class does not implement {@link Comparable}
     */
    @Override
    public int compare(T[] first, T[] second) {
        return compareNested(first, second);
    }

    @Override
    public TypeComparator<T[]> duplicate() {
        GenericArrayComparator<T> dupe = new GenericArrayComparator<T>(ascendingComparison, serializer);
        dupe.setReference(this.reference);
        return dupe;
    }

    /**
     * Recursively compares two values that may be arrays (object or primitive)
     * or leaf values. {@code null} sorts before any non-null value.
     */
    private static int compareNested(Object first, Object second) {
        if (first == second) {
            return 0;
        }
        if (first == null) {
            return -1;
        }
        if (second == null) {
            return 1;
        }
        if (!first.getClass().isArray() || !second.getClass().isArray()) {
            return compareValues(first, second);
        }

        int firstLength = Array.getLength(first);
        int secondLength = Array.getLength(second);
        int commonLength = Math.min(firstLength, secondLength);

        // The first differing position decides; no need to look any further.
        for (int i = 0; i < commonLength; i++) {
            int elementResult = compareNested(Array.get(first, i), Array.get(second, i));
            if (elementResult != 0) {
                return elementResult;
            }
        }

        // All shared positions are equal: the shorter array sorts first.
        return Integer.compare(firstLength, secondLength);
    }

    /**
     * Compares two non-null values of which at least one is not an array.
     *
     * <p>Values of the same class are compared through {@link Comparable}. Values
     * of different classes are ordered by class name, which keeps the result
     * antisymmetric instead of reporting "less than" in both directions.
     *
     * @throws IllegalArgumentException if both values share a class that does
     *         not implement {@link Comparable}
     */
    @SuppressWarnings("unchecked")
    private static int compareValues(Object first, Object second) {
        Class<?> firstClass = first.getClass();
        Class<?> secondClass = second.getClass();

        if (firstClass != secondClass) {
            return firstClass.getName().compareTo(secondClass.getName());
        }
        if (first instanceof Comparable) {
            // Safe: both values have exactly the same runtime class.
            return ((Comparable<Object>) first).compareTo(second);
        }
        throw new IllegalArgumentException(
            "Cannot compare array elements of type " + firstClass.getName()
                + ": the type does not implement java.lang.Comparable");
    }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.flink.api.common.typeutils.base;

import org.apache.flink.api.common.typeinfo.PrimitiveArrayTypeInfo;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.common.typeutils.ComparatorTestBase;
import org.apache.flink.api.common.typeutils.TypeComparator;
import org.apache.flink.api.common.typeutils.TypeSerializer;
import org.junit.Assert;

/**
 * Runs the checks inherited from {@link ComparatorTestBase} against
 * {@link GenericArrayComparator}, using two-dimensional {@code char} arrays as
 * the compared type.
 */
public class GenericArrayComparatorTest extends ComparatorTestBase<char[][]> {

    /** Type information of the inner {@code char[]} rows. */
    private final TypeInformation<char[]> componentInfo;

    public GenericArrayComparatorTest() {
        componentInfo = PrimitiveArrayTypeInfo.CHAR_PRIMITIVE_ARRAY_TYPE_INFO;
    }

    @Override
    protected TypeSerializer<char[][]> createSerializer() {
        return new GenericArraySerializer<char[]>(
            componentInfo.getTypeClass(),
            componentInfo.createSerializer(null));
    }

    @Override
    protected TypeComparator<char[][]> createComparator(boolean ascending) {
        return new GenericArrayComparator<char[]>(ascending, createSerializer());
    }

    /**
     * Asserts that both arrays have the same number of rows and that every row
     * has identical contents. {@code message} is attached to each assertion so a
     * failure identifies the check that triggered it.
     */
    @Override
    protected void deepEquals(String message, char[][] should, char[][] is) {
        Assert.assertEquals(message, should.length, is.length);
        for (int i = 0; i < should.length; i++) {
            Assert.assertArrayEquals(message, should[i], is[i]);
        }
    }

    /**
     * Test data in ascending order; each entry already differs from its
     * neighbours in the first character of its first row.
     */
    @Override
    protected char[][][] getSortedTestData() {
        return new char[][][] {
            new char[][] {
                new char[] {'b', 'e'},
                new char[] {'2'}
            },
            new char[][] {
                new char[] {'n', 'o', 't'},
                new char[] {'3'}
            },
            new char[][] {
                new char[] {'o', 'r'},
                new char[] {'2'}
            },
            new char[][] {
                new char[] {'t', 'o'},
                new char[] {'2'}
            }
        };
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,14 @@

import com.google.common.base.Preconditions;
import org.apache.flink.api.common.ExecutionConfig;
import org.apache.flink.api.common.typeinfo.AtomicType;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.common.typeutils.TypeComparator;
import org.apache.flink.api.common.typeutils.TypeSerializer;
import org.apache.flink.api.common.typeutils.base.GenericArrayComparator;
import org.apache.flink.api.common.typeutils.base.GenericArraySerializer;

public class ObjectArrayTypeInfo<T, C> extends TypeInformation<T> {
public class ObjectArrayTypeInfo<T, C> extends TypeInformation<T> implements AtomicType<T> {

private static final long serialVersionUID = 1L;

Expand Down Expand Up @@ -72,15 +75,24 @@ public TypeInformation<C> getComponentInfo() {

@Override
public boolean isKeyType() {
return false;
return true;
}

@SuppressWarnings("unchecked")
@Override
public TypeSerializer<T> createSerializer(ExecutionConfig executionConfig) {
return (TypeSerializer<T>) new GenericArraySerializer<C>(
componentInfo.getTypeClass(),
componentInfo.createSerializer(executionConfig));
componentInfo.getTypeClass(),
componentInfo.createSerializer(executionConfig));
}

/**
 * Creates the comparator used when arrays of this type serve as keys.
 *
 * @param sortOrderAscending whether the comparator sorts in ascending order
 * @param executionConfig    configuration forwarded to {@code createSerializer}
 * @return a {@link GenericArrayComparator} backed by this type's serializer
 */
@SuppressWarnings("unchecked")
@Override
public TypeComparator<T> createComparator(boolean sortOrderAscending, ExecutionConfig executionConfig) {
    // T is the array type and C its component type, so the comparator is
    // parameterized with C, mirroring how createSerializer builds its serializer.
    return (TypeComparator<T>) new GenericArrayComparator<C>(
        sortOrderAscending,
        (TypeSerializer<C[]>) createSerializer(executionConfig));
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -311,7 +311,7 @@ public void testPojoSingleOrderFull() {
.sortLocalOutput("*", Order.ASCENDING);
}

@Test(expected = InvalidProgramException.class)
@Test
public void testArrayOrderFull() {

List<Object[]> arrayData = new ArrayList<>();
Expand All @@ -322,7 +322,7 @@ public void testArrayOrderFull() {
DataSet<Object[]> pojoDs = env
.fromCollection(arrayData);

// must not work
// should work
pojoDs.writeAsText("/tmp/willNotHappen")
.sortLocalOutput("*", Order.ASCENDING);
}
Expand Down

0 comments on commit 30da04d

Please sign in to comment.