HIVE-27112 - implement array_except UDF in Hive #4090

Merged · 8 commits · Jun 23, 2023

@@ -605,6 +605,7 @@ public final class FunctionRegistry {
system.registerGenericUDF("array_distinct", GenericUDFArrayDistinct.class);
system.registerGenericUDF("array_join", GenericUDFArrayJoin.class);
system.registerGenericUDF("array_slice", GenericUDFArraySlice.class);
system.registerGenericUDF("array_except", GenericUDFArrayExcept.class);
system.registerGenericUDF("array_intersect", GenericUDFArrayIntersect.class);
system.registerGenericUDF("deserialize", GenericUDFDeserialize.class);
system.registerGenericUDF("sentences", GenericUDFSentences.class);
@@ -0,0 +1,80 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.hadoop.hive.ql.udf.generic;

import org.apache.hadoop.hive.ql.exec.Description;
import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
import org.apache.hadoop.hive.ql.exec.UDFArgumentTypeException;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils;

import java.util.ArrayList;
import java.util.List;
import java.util.stream.Collectors;

/**
 * GenericUDFArrayExcept: implements array_except(array1, array2), which returns
 * an array of the distinct elements of array1 that are not present in array2.
 */
@Description(name = "array_except", value = "_FUNC_(array1, array2) - Returns an array of the elements in array1 but not in array2.", extended =
"Example:\n" + " > SELECT _FUNC_(array(1, 2, 3,4), array(2,3)) FROM src LIMIT 1;\n"
+ " [1,4]")
public class GenericUDFArrayExcept extends AbstractGenericUDFArrayBase {
static final int ARRAY2_IDX = 1;
private static final String FUNC_NAME = "ARRAY_EXCEPT";
static final String ERROR_NOT_COMPARABLE = "Input arrays are not comparable to use ARRAY_EXCEPT udf";
private transient ListObjectInspector array2OI;
private transient ObjectInspector arrayElementOI;
private transient ObjectInspector array2ElementOI;

public GenericUDFArrayExcept() {
super(FUNC_NAME, 2, 2, ObjectInspector.Category.LIST);
}

@Override public ObjectInspector initialize(ObjectInspector[] arguments) throws UDFArgumentException {
ObjectInspector defaultOI = super.initialize(arguments);
array2OI = (ListObjectInspector) arguments[ARRAY2_IDX];
Contributor:
NIT: this line should come after the following checkArgCategory call, since checkArgCategory is what validates the argument type and it throws a friendlier error message.

Contributor Author:
Done
checkArgCategory(arguments, ARRAY2_IDX, ObjectInspector.Category.LIST, FUNC_NAME,
org.apache.hadoop.hive.serde.serdeConstants.LIST_TYPE_NAME); // array1 is already validated in the parent class
Contributor:
Should we check that the element type of array1 equals that of array2? It might be OK if we allow type conversions here.

Contributor Author:
Added test cases where one array is of int type and the other is of String type; see ql/src/test/org/apache/hadoop/hive/ql/udf/generic/TestGenericUDFArrayExcept.java#testPrimitive.

Contributor (@okumin, May 20, 2023):
Thanks. I think we should first pin down the expected specification. What should the following SQL return?

SELECT array_except(array(1, 2, 3), array(2.0, 3.3));

If it should return array(1, 3), meaning type conversion is applied, we should add a test case where elements are removed via type conversion.
If it should return array(1, 2, 3), meaning the second argument is ignored when the types are unmatched, I personally think we should raise an error instead, since that situation only arises when a user has misunderstood the types of the two arguments.

For example, Spark 3.4 fails in that case. PrestoSQL returns [1.0, 3.0] with the same SQL, meaning PrestoSQL applies the type conversion from int to float. I personally think either is fine, but it should be tested.

spark-sql (default)> select array_except(array(1, 2, 3), array(2.0, 3.3));
[DATATYPE_MISMATCH.BINARY_ARRAY_DIFF_TYPES] Cannot resolve "array_except(array(1, 2, 3), array(2.0, 3.3))" due to data type mismatch: Input to function `array_except` should have been two "ARRAY" with same element type, but it's ["ARRAY<INT>", "ARRAY<DECIMAL(2,1)>"].; line 1 pos 7;
'Project [unresolvedalias(array_except(array(1, 2, 3), array(2.0, 3.3)), None)]
+- OneRowRelation

Contributor Author:
As recommended, now throwing an error for unmatched data types.
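With this strict behavior, okumin's example should now fail at query compilation rather than silently ignoring the second argument. A sketch of the expected session (the exact error text comes from the ERROR_NOT_COMPARABLE constant below):

  SELECT array_except(array(1, 2, 3), array(2.0, 3.3));
  -- FAILED: UDFArgumentTypeException Input arrays are not comparable to use ARRAY_EXCEPT udf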

arrayElementOI = arrayOI.getListElementObjectInspector();
array2ElementOI = array2OI.getListElementObjectInspector();
Contributor:
NIT: array2OI can be a local variable, or we can simply do ObjectInspectorUtils.compareTypes(arrayOI, arguments[ARRAY2_IDX]) because it seems to recursively check types of LIST, MAP, etc.

Contributor Author:
Removed all those variables

Contributor Author:
Please note that the elements of the input array can themselves be of array or list type, which is why the element inspector itself is passed to the compatibility check.
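For reference, the one-liner suggested above would look roughly like this (a sketch of the reviewer's suggestion, not the merged code; it relies on ObjectInspectorUtils.compareTypes recursively comparing nested LIST/MAP element types):

  // Hypothetical simplification: compare the full list inspectors directly,
  // instead of first extracting the element inspectors.
  if (!ObjectInspectorUtils.compareTypes(arrayOI, arguments[ARRAY2_IDX])) {
    throw new UDFArgumentTypeException(ARRAY2_IDX, ERROR_NOT_COMPARABLE);
  }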

if (!ObjectInspectorUtils.compareTypes(arrayElementOI, array2ElementOI)) { // check that the element types of the two arrays are comparable
throw new UDFArgumentTypeException(1, ERROR_NOT_COMPARABLE);
}
return defaultOI;
}

@Override public Object evaluate(DeferredObject[] arguments) throws HiveException {
Object array = arguments[ARRAY_IDX].get();
Object array2 = arguments[ARRAY2_IDX].get();
if (array == null || array2 == null) {
return null;
}

List<?> inputArray = ((ListObjectInspector) argumentOIs[ARRAY_IDX]).getList(array);
List<Object> inputArrayCopy = new ArrayList<>(inputArray); // defensive copy so the source list is not mutated
inputArrayCopy.removeAll(((ListObjectInspector) argumentOIs[ARRAY2_IDX]).getList(array2)); // drop every element present in array2
return inputArrayCopy.stream().distinct().map(o -> converter.convert(o)).collect(Collectors.toList());
Contributor:
We got robustness. Thanks!

}
}
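The core of evaluate reduces to removeAll followed by distinct. A self-contained illustration of those semantics in plain Java, without the Hive ObjectInspector machinery:

  import java.util.ArrayList;
  import java.util.Arrays;
  import java.util.List;
  import java.util.stream.Collectors;

  public class ArrayExceptSketch {
    public static void main(String[] args) {
      List<Integer> array1 = new ArrayList<>(Arrays.asList(1, 2, 3, 4));
      List<Integer> array2 = Arrays.asList(2, 3);
      array1.removeAll(array2); // drop every element that also appears in array2
      List<Integer> result = array1.stream().distinct().collect(Collectors.toList());
      System.out.println(result); // prints [1, 4]
    }
  }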
@@ -0,0 +1,263 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.hadoop.hive.ql.udf.generic;

import org.apache.hadoop.hive.common.type.Date;
import org.apache.hadoop.hive.ql.exec.UDFArgumentTypeException;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.serde2.io.DateWritableV2;
import org.apache.hadoop.hive.serde2.io.DoubleWritable;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
import org.apache.hadoop.io.FloatWritable;
import org.apache.hadoop.io.IntWritable;
import org.apache.hadoop.io.Text;
import org.junit.Assert;
import org.junit.Test;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

import static java.util.Arrays.asList;

public class TestGenericUDFArrayExcept {
private final GenericUDFArrayExcept udf = new GenericUDFArrayExcept();

@Test
public void testPrimitive() throws HiveException {
ObjectInspector intObjectInspector = ObjectInspectorFactory.getStandardListObjectInspector(
PrimitiveObjectInspectorFactory.writableIntObjectInspector);
ObjectInspector floatObjectInspector = ObjectInspectorFactory.getStandardListObjectInspector(
PrimitiveObjectInspectorFactory.writableFloatObjectInspector);
ObjectInspector doubleObjectInspector = ObjectInspectorFactory.getStandardListObjectInspector(
PrimitiveObjectInspectorFactory.writableDoubleObjectInspector);
ObjectInspector longObjectInspector = ObjectInspectorFactory.getStandardListObjectInspector(
PrimitiveObjectInspectorFactory.writableLongObjectInspector);
ObjectInspector stringObjectInspector = ObjectInspectorFactory.getStandardListObjectInspector(
PrimitiveObjectInspectorFactory.writableStringObjectInspector);

Object i1 = new IntWritable(1);
Object i2 = new IntWritable(2);
Object i3 = new IntWritable(4);
Object i4 = new IntWritable(5);
Object i5 = new IntWritable(1);
Object i6 = new IntWritable(3);
Object i7 = new IntWritable(2);
Object i8 = new IntWritable(9);
List<Object> inputList = new ArrayList<>();
inputList.add(i1);
inputList.add(i2);
inputList.add(i3);
inputList.add(i4);

udf.initialize(new ObjectInspector[] { intObjectInspector, intObjectInspector });
runAndVerify(inputList, asList(i5, i6, i7, i8), asList(i3, i4));

i1 = new FloatWritable(3.3f);
i2 = new FloatWritable(1.1f);
i3 = new FloatWritable(4.3f);
i4 = new FloatWritable(2.22f);
i5 = new FloatWritable(3.3f);
i6 = new FloatWritable(1.1f);
i7 = new FloatWritable(2.28f);
i8 = new FloatWritable(2.20f);
List<Object> inputFloatList = new ArrayList<>();
inputFloatList.add(i1);
inputFloatList.add(i2);
inputFloatList.add(i3);
inputFloatList.add(i4);

udf.initialize(new ObjectInspector[] { floatObjectInspector, floatObjectInspector });
runAndVerify(new ArrayList<>(inputFloatList), asList(i5, i6, i7, i8), asList(i3, i4));

Object s1 = new Text("1");
Object s2 = new Text("2");
Object s3 = new Text("4");
Object s4 = new Text("5");
List<Object> inputStringList = new ArrayList<>();
inputStringList.add(s1);
inputStringList.add(s2);
inputStringList.add(s3);
inputStringList.add(s4);

udf.initialize(new ObjectInspector[] { stringObjectInspector, stringObjectInspector });
runAndVerify(inputStringList, asList(s1, s3), asList(s2, s4));
// Empty array output
runAndVerify(inputStringList, inputStringList, asList());
runAndVerify(inputStringList, asList(), inputStringList);
// Empty input arrays
runAndVerify(asList(), asList(), asList());
Contributor:
Great test cases!

// Int & float arrays
UDFArgumentTypeException exception = Assert.assertThrows(UDFArgumentTypeException.class,
() -> udf.initialize(new ObjectInspector[] { floatObjectInspector, intObjectInspector }));
Assert.assertEquals(GenericUDFArrayExcept.ERROR_NOT_COMPARABLE, exception.getMessage());
// float and string arrays
exception = Assert.assertThrows(UDFArgumentTypeException.class,
() -> udf.initialize(new ObjectInspector[] { floatObjectInspector, stringObjectInspector }));
Assert.assertEquals(GenericUDFArrayExcept.ERROR_NOT_COMPARABLE, exception.getMessage());
// long and double arrays
exception = Assert.assertThrows(UDFArgumentTypeException.class,
() -> udf.initialize(new ObjectInspector[] { longObjectInspector, doubleObjectInspector }));
Assert.assertEquals(GenericUDFArrayExcept.ERROR_NOT_COMPARABLE, exception.getMessage());
}

@Test
public void testList() throws HiveException {
ObjectInspector[] inputOIs = {
ObjectInspectorFactory.getStandardListObjectInspector(
ObjectInspectorFactory.getStandardListObjectInspector(
PrimitiveObjectInspectorFactory.writableStringObjectInspector
)
),
ObjectInspectorFactory.getStandardListObjectInspector(
ObjectInspectorFactory.getStandardListObjectInspector(
PrimitiveObjectInspectorFactory.writableStringObjectInspector
)
)
};
udf.initialize(inputOIs);

Object i1 = asList(new Text("aa1"), new Text("dd"), new Text("cc"), new Text("bb"));
Object i2 = asList(new Text("aa2"), new Text("cc"), new Text("ba"), new Text("dd"));
Object i3 = asList(new Text("aa3"), new Text("cc"), new Text("dd"), new Text("ee"), new Text("bb"));
Object i4 = asList(new Text("aa4"), new Text("cc"), new Text("ddd"), new Text("bb"));
List<Object> inputList = new ArrayList<>();
inputList.add(i1);
inputList.add(i2);
inputList.add(i3);
inputList.add(i4);
runAndVerify(inputList, asList(i1, i2, i2), asList(i3, i4));
}

@Test
public void testStruct() throws HiveException {
ObjectInspector[] inputOIs = {
ObjectInspectorFactory.getStandardListObjectInspector(
ObjectInspectorFactory.getStandardStructObjectInspector(
asList("f1", "f2", "f3", "f4"),
asList(
PrimitiveObjectInspectorFactory.writableStringObjectInspector,
PrimitiveObjectInspectorFactory.writableDoubleObjectInspector,
PrimitiveObjectInspectorFactory.writableDateObjectInspector,
ObjectInspectorFactory.getStandardListObjectInspector(
PrimitiveObjectInspectorFactory.writableIntObjectInspector
)
)
)
),
ObjectInspectorFactory.getStandardListObjectInspector(
ObjectInspectorFactory.getStandardStructObjectInspector(
asList("f1", "f2", "f3", "f4"),
asList(
PrimitiveObjectInspectorFactory.writableStringObjectInspector,
PrimitiveObjectInspectorFactory.writableDoubleObjectInspector,
PrimitiveObjectInspectorFactory.writableDateObjectInspector,
ObjectInspectorFactory.getStandardListObjectInspector(
PrimitiveObjectInspectorFactory.writableIntObjectInspector
)
)
)
)
};
udf.initialize(inputOIs);

Object i1 = asList(new Text("a"), new DoubleWritable(3.1415),
new DateWritableV2(Date.of(2015, 5, 26)),
asList(new IntWritable(1), new IntWritable(3),
new IntWritable(2), new IntWritable(4)));

Object i2 = asList(new Text("b"), new DoubleWritable(3.14),
new DateWritableV2(Date.of(2015, 5, 26)),
asList(new IntWritable(1), new IntWritable(3),
new IntWritable(2), new IntWritable(4)));

Object i3 = asList(new Text("a"), new DoubleWritable(3.1415),
new DateWritableV2(Date.of(2015, 5, 25)),
asList(new IntWritable(1), new IntWritable(3),
new IntWritable(2), new IntWritable(5)));

Object i4 = asList(new Text("a"), new DoubleWritable(3.1415),
new DateWritableV2(Date.of(2015, 5, 25)),
asList(new IntWritable(1), new IntWritable(3),
new IntWritable(2), new IntWritable(4)));

List<Object> inputList = new ArrayList<>();
inputList.add(i1);
inputList.add(i2);
inputList.add(i3);
inputList.add(i4);
runAndVerify(inputList, asList(i1, i3), asList(i2, i4));
}

@Test
public void testMap() throws HiveException {
ObjectInspector[] inputOIs = {
ObjectInspectorFactory.getStandardListObjectInspector(
ObjectInspectorFactory.getStandardMapObjectInspector(
PrimitiveObjectInspectorFactory.writableStringObjectInspector,
PrimitiveObjectInspectorFactory.writableIntObjectInspector
)
),
ObjectInspectorFactory.getStandardListObjectInspector(
ObjectInspectorFactory.getStandardMapObjectInspector(
PrimitiveObjectInspectorFactory.writableStringObjectInspector,
PrimitiveObjectInspectorFactory.writableIntObjectInspector
)
)
};
udf.initialize(inputOIs);

Map<Text, IntWritable> m1 = new HashMap<>();
m1.put(new Text("a"), new IntWritable(4));
m1.put(new Text("b"), new IntWritable(3));
m1.put(new Text("c"), new IntWritable(1));
m1.put(new Text("d"), new IntWritable(2));

Map<Text, IntWritable> m2 = new HashMap<>();
m2.put(new Text("d"), new IntWritable(4));
m2.put(new Text("b"), new IntWritable(3));
m2.put(new Text("a"), new IntWritable(1));
m2.put(new Text("c"), new IntWritable(2));

Map<Text, IntWritable> m3 = new HashMap<>();
m3.put(new Text("d"), new IntWritable(4));
m3.put(new Text("b"), new IntWritable(3));
m3.put(new Text("a"), new IntWritable(1));

Map<Text, IntWritable> m4 = new HashMap<>();
m3.put(new Text("e"), new IntWritable(4));
m3.put(new Text("b"), new IntWritable(3));
m3.put(new Text("a"), new IntWritable(1));

List<Object> inputList = new ArrayList<>();
inputList.add(m1);
inputList.add(m3);
inputList.add(m2);
inputList.add(m4);
inputList.add(m1);
runAndVerify(inputList, asList(m1, m3), asList(m2, m4));
}

private void runAndVerify(List<Object> array1, List<Object> array2, List<Object> expected)
throws HiveException {
GenericUDF.DeferredJavaObject[] args = { new GenericUDF.DeferredJavaObject(array1), new GenericUDF.DeferredJavaObject(array2) };
List<?> result = (List<?>) udf.evaluate(args);
Assert.assertArrayEquals("Check content", expected.toArray(), result.toArray());
}
}
ql/src/test/queries/clientpositive/udf_array_except.q
@@ -0,0 +1,42 @@
--! qt:dataset:src

-- SORT_QUERY_RESULTS

set hive.fetch.task.conversion=more;

DESCRIBE FUNCTION array_except;
DESCRIBE FUNCTION EXTENDED array_except;

-- evaluates the function for arrays of primitives
SELECT array_except(array(1, 2, 3, null,3,4),array(1, 3, null));

SELECT array_except(array(),array());

SELECT array_except(array(null),array(null));

SELECT array_except(array(1.12, 2.23, 3.34, null,1.11,1.12,2.9),array(1.12,3.34,1.11,1.12));

SELECT array(1,2,3),array_except(array(1, 2, 3),array(1,3,4));

SELECT array_except(array(1.1234567890, 2.234567890, 3.34567890, null, 3.3456789, 2.234567,1.1234567890),array(1.1234567890, 3.34567890, null,2.234567));

SELECT array_except(array(11234567890, 2234567890, 334567890, null, 11234567890, 2234567890, 334567890, null),array(11234567890, 2234567890, 334567890));

SELECT array_except(array(array("a","b","c","d"),array("a","b","c","d"),array("a","b","c","d","e"),null,array("e","a","b","c","d")),array(array("a","b","c","d"),array("a","b","c","d"),array("a","b","c","d","e"),null));

-- handle null array cases

dfs ${system:test.dfs.mkdir} ${system:test.tmp.dir}/test_null_array;

dfs -copyFromLocal ../../data/files/test_null_array.csv ${system:test.tmp.dir}/test_null_array/;

create external table test_null_array (id int, value Array<String>) ROW FORMAT DELIMITED
FIELDS TERMINATED BY ':' collection items terminated by ',' location '${system:test.tmp.dir}/test_null_array';

select value from test_null_array;

select array_except(value,value) from test_null_array;

select value, array_except(value,value) from test_null_array;

dfs -rm -r ${system:test.tmp.dir}/test_null_array;
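For reference, given the removeAll-plus-distinct semantics above, the simpler queries should produce roughly the following (a sketch; the authoritative output is the generated .q.out file):

  -- array_except(array(1, 2, 3, null, 3, 4), array(1, 3, null))  -> [2,4]
  -- array_except(array(), array())                               -> []
  -- array_except(array(null), array(null))                       -> []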