diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/TimestampLTZNanos.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/TimestampLTZNanos.java new file mode 100644 index 0000000000000..b8f5bf8c9d37b --- /dev/null +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/TimestampLTZNanos.java @@ -0,0 +1,61 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.unsafe.types; + +import org.apache.spark.annotation.Unstable; + +import java.io.Serializable; +import java.util.Objects; + +/** + * Physical representation of {@code TIMESTAMP_LTZ(p)} with nanosecond-capable precision. + * Values are stored as epoch microseconds plus nanoseconds within that microsecond (0-999). 
+ * + * @since 4.2.0 + */ +@Unstable +public final class TimestampLTZNanos implements Serializable { + public static final int SIZE_IN_BYTES = 16; + + public final long epochMicros; + public final short nanosWithinMicro; + + public TimestampLTZNanos(long epochMicros, short nanosWithinMicro) { + TimestampNanosUtils.validateNanosWithinMicro(nanosWithinMicro); + this.epochMicros = epochMicros; + this.nanosWithinMicro = nanosWithinMicro; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + TimestampLTZNanos that = (TimestampLTZNanos) o; + return epochMicros == that.epochMicros && nanosWithinMicro == that.nanosWithinMicro; + } + + @Override + public int hashCode() { + return Objects.hash(epochMicros, nanosWithinMicro); + } + + @Override + public String toString() { + return "TimestampLTZNanos(" + epochMicros + ", " + nanosWithinMicro + ")"; + } +} diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/TimestampNTZNanos.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/TimestampNTZNanos.java new file mode 100644 index 0000000000000..f3a998679496c --- /dev/null +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/TimestampNTZNanos.java @@ -0,0 +1,61 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.unsafe.types; + +import org.apache.spark.annotation.Unstable; + +import java.io.Serializable; +import java.util.Objects; + +/** + * Physical representation of {@code TIMESTAMP_NTZ(p)} with nanosecond-capable precision. + * Values are stored as epoch microseconds plus nanoseconds within that microsecond (0-999). + * + * @since 4.2.0 + */ +@Unstable +public final class TimestampNTZNanos implements Serializable { + public static final int SIZE_IN_BYTES = 16; + + public final long epochMicros; + public final short nanosWithinMicro; + + public TimestampNTZNanos(long epochMicros, short nanosWithinMicro) { + TimestampNanosUtils.validateNanosWithinMicro(nanosWithinMicro); + this.epochMicros = epochMicros; + this.nanosWithinMicro = nanosWithinMicro; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + TimestampNTZNanos that = (TimestampNTZNanos) o; + return epochMicros == that.epochMicros && nanosWithinMicro == that.nanosWithinMicro; + } + + @Override + public int hashCode() { + return Objects.hash(epochMicros, nanosWithinMicro); + } + + @Override + public String toString() { + return "TimestampNTZNanos(" + epochMicros + ", " + nanosWithinMicro + ")"; + } +} diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/TimestampNanosUtils.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/TimestampNanosUtils.java new file mode 100644 index 0000000000000..2ee76b4adb322 --- /dev/null +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/TimestampNanosUtils.java @@ -0,0 +1,39 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.unsafe.types; + +import java.util.Map; + +import org.apache.spark.SparkIllegalArgumentException; + +final class TimestampNanosUtils { + static final int MAX_NANOS_WITHIN_MICRO = 999; + + private TimestampNanosUtils() {} + + static void validateNanosWithinMicro(short nanosWithinMicro) { + if (nanosWithinMicro < 0 || nanosWithinMicro > MAX_NANOS_WITHIN_MICRO) { + throw new SparkIllegalArgumentException( + "INTERNAL_ERROR", + Map.of( + "message", + "nanosWithinMicro must be in [0, " + MAX_NANOS_WITHIN_MICRO + "], got: " + + nanosWithinMicro)); + } + } +} diff --git a/common/unsafe/src/test/java/org/apache/spark/unsafe/types/TimestampNanosSuite.java b/common/unsafe/src/test/java/org/apache/spark/unsafe/types/TimestampNanosSuite.java new file mode 100644 index 0000000000000..d125bbed7fdd8 --- /dev/null +++ b/common/unsafe/src/test/java/org/apache/spark/unsafe/types/TimestampNanosSuite.java @@ -0,0 +1,65 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.unsafe.types; + +import org.apache.spark.SparkIllegalArgumentException; +import org.junit.jupiter.api.Test; + +import static org.junit.jupiter.api.Assertions.*; + +public class TimestampNanosSuite { + + @Test + public void timestampNTZNanosEqualsAndHashCode() { + TimestampNTZNanos t1 = new TimestampNTZNanos(1000L, (short) 100); + TimestampNTZNanos t2 = new TimestampNTZNanos(1001L, (short) 100); + TimestampNTZNanos t3 = new TimestampNTZNanos(1000L, (short) 101); + TimestampNTZNanos t4 = new TimestampNTZNanos(1000L, (short) 100); + + assertNotEquals(t1, t2); + assertNotEquals(t1, t3); + assertEquals(t1, t4); + assertEquals(t1.hashCode(), t4.hashCode()); + } + + @Test + public void timestampLTZNanosEqualsAndHashCode() { + TimestampLTZNanos t1 = new TimestampLTZNanos(2000L, (short) 0); + TimestampLTZNanos t2 = new TimestampLTZNanos(2000L, (short) 1); + TimestampLTZNanos t3 = new TimestampLTZNanos(2000L, (short) 0); + + assertNotEquals(t1, t2); + assertEquals(t1, t3); + } + + @Test + public void invalidNanosWithinMicroNTZ() { + assertThrows(SparkIllegalArgumentException.class, + () -> new TimestampNTZNanos(0L, (short) -1)); + assertThrows(SparkIllegalArgumentException.class, + () -> new TimestampNTZNanos(0L, (short) 1000)); + } + + @Test + public void invalidNanosWithinMicroLTZ() { + assertThrows(SparkIllegalArgumentException.class, + () -> new TimestampLTZNanos(0L, (short) -1)); + assertThrows(SparkIllegalArgumentException.class, + () -> new TimestampLTZNanos(0L, (short) 1000)); + } +} diff --git 
a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/SpecializedGetters.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/SpecializedGetters.java index 2a3a6884c3c6e..f0f5261cf0573 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/SpecializedGetters.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/SpecializedGetters.java @@ -23,6 +23,8 @@ import org.apache.spark.sql.types.Decimal; import org.apache.spark.sql.catalyst.util.MapData; import org.apache.spark.unsafe.types.CalendarInterval; +import org.apache.spark.unsafe.types.TimestampLTZNanos; +import org.apache.spark.unsafe.types.TimestampNTZNanos; import org.apache.spark.unsafe.types.UTF8String; import org.apache.spark.unsafe.types.VariantVal; import org.apache.spark.unsafe.types.GeographyVal; @@ -58,6 +60,10 @@ public interface SpecializedGetters { CalendarInterval getInterval(int ordinal); + TimestampNTZNanos getTimestampNTZNanos(int ordinal); + + TimestampLTZNanos getTimestampLTZNanos(int ordinal); + VariantVal getVariant(int ordinal); InternalRow getStruct(int ordinal, int numFields); diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/SpecializedGettersReader.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/SpecializedGettersReader.java index 830aa0d0d0fb4..d1d4608a1d4c7 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/SpecializedGettersReader.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/SpecializedGettersReader.java @@ -72,6 +72,12 @@ public static Object read( if (physicalDataType instanceof PhysicalCalendarIntervalType) { return obj.getInterval(ordinal); } + if (physicalDataType instanceof PhysicalTimestampNTZNanosType) { + return obj.getTimestampNTZNanos(ordinal); + } + if (physicalDataType instanceof PhysicalTimestampLTZNanosType) { + return obj.getTimestampLTZNanos(ordinal); + 
} if (physicalDataType instanceof PhysicalBinaryType) { return obj.getBinary(ordinal); }
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/TimestampNanosRowValues.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/TimestampNanosRowValues.java
new file mode 100644
index 0000000000000..f32124d72779d
--- /dev/null
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/TimestampNanosRowValues.java
@@ -0,0 +1,89 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions;
+
+import org.apache.spark.unsafe.Platform;
+import org.apache.spark.unsafe.types.TimestampLTZNanos;
+import org.apache.spark.unsafe.types.TimestampNTZNanos;
+
+/**
+ * Shared read/write helpers for nanosecond timestamp values in UnsafeRow/UnsafeArrayData.
+ *
+ * <p>Each value occupies {@link #SIZE_IN_BYTES} bytes: an 8-byte word holding the epoch
+ * microseconds, followed by an 8-byte word whose low 16 bits hold the nanoseconds within that
+ * microsecond and whose upper 48 bits are zero.
+ */
+public final class TimestampNanosRowValues {
+  /** Size in bytes of one stored value; same as {@link TimestampNTZNanos#SIZE_IN_BYTES}. */
+  public static final int SIZE_IN_BYTES = TimestampNTZNanos.SIZE_IN_BYTES;
+
+  private TimestampNanosRowValues() {
+  }
+
+  /**
+   * Writes one value at {@code baseOffset + cursor}.
+   *
+   * @param baseObject object passed through to {@link Platform} as the write target
+   * @param baseOffset base offset passed through to {@link Platform}
+   * @param cursor position of the value, added to {@code baseOffset}
+   * @param epochMicros stored in the first 8-byte word
+   * @param nanosWithinMicro stored in the low 16 bits of the second 8-byte word
+   */
+  public static void writePayload(
+      Object baseObject, long baseOffset, int cursor, long epochMicros, short nanosWithinMicro) {
+    Platform.putLong(baseObject, baseOffset + cursor, epochMicros);
+    // Store nanos in the low 16 bits; upper 48 bits remain zero.
+    Platform.putLong(baseObject, baseOffset + cursor + 8, ((long) nanosWithinMicro) & 0xFFFFL);
+  }
+
+  /** Overwrites both 8-byte words of the value at {@code baseOffset + cursor} with zero. */
+  public static void zeroPayload(Object baseObject, long baseOffset, int cursor) {
+    Platform.putLong(baseObject, baseOffset + cursor, 0L);
+    Platform.putLong(baseObject, baseOffset + cursor + 8, 0L);
+  }
+
+  /** Returns the first 8-byte word (epoch microseconds) at {@code baseOffset + offset}. */
+  public static long readEpochMicros(Object baseObject, long baseOffset, int offset) {
+    return Platform.getLong(baseObject, baseOffset + offset);
+  }
+
+  /**
+   * Returns the nanoseconds-within-microsecond component of the value at
+   * {@code baseOffset + offset}.
+   */
+  public static short readNanosWithinMicro(Object baseObject, long baseOffset, int offset) {
+    // writePayload stores this component with an 8-byte putLong, so read the same 8-byte word
+    // and keep its low 16 bits. A 2-byte read at this address would return those bits only on
+    // little-endian platforms; on big-endian ones it would return the zeroed upper bits.
+    return (short) Platform.getLong(baseObject, baseOffset + offset + 8);
+  }
+
+  /** Reads the value at {@code baseOffset + offset} as a {@link TimestampNTZNanos}. */
+  public static TimestampNTZNanos readNTZ(Object baseObject, long baseOffset, int offset) {
+    return new TimestampNTZNanos(
+        readEpochMicros(baseObject, baseOffset, offset),
+        readNanosWithinMicro(baseObject, baseOffset, offset));
+  }
+
+  /** Reads the value at {@code baseOffset + offset} as a {@link TimestampLTZNanos}. */
+  public static TimestampLTZNanos readLTZ(Object baseObject, long baseOffset, int offset) {
+    return new TimestampLTZNanos(
+        readEpochMicros(baseObject, baseOffset, offset),
+        readNanosWithinMicro(baseObject, baseOffset, offset));
+  }
+}
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java index 09ac634955fcb..5cf9635b87c00 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java @@ -38,6 +38,8 @@ import
org.apache.spark.unsafe.bitset.BitSetMethods; import org.apache.spark.unsafe.hash.Murmur3_x86_32; import org.apache.spark.unsafe.types.CalendarInterval; +import org.apache.spark.unsafe.types.TimestampLTZNanos; +import org.apache.spark.unsafe.types.TimestampNTZNanos; import org.apache.spark.unsafe.types.UTF8String; import org.apache.spark.unsafe.types.VariantVal; import org.apache.spark.unsafe.types.GeographyVal; @@ -248,6 +250,20 @@ public CalendarInterval getInterval(int ordinal) { return new CalendarInterval(months, days, microseconds); } + @Override + public TimestampNTZNanos getTimestampNTZNanos(int ordinal) { + if (isNullAt(ordinal)) return null; + final int offset = (int) (getLong(ordinal) >> 32); + return TimestampNanosRowValues.readNTZ(baseObject, baseOffset, offset); + } + + @Override + public TimestampLTZNanos getTimestampLTZNanos(int ordinal) { + if (isNullAt(ordinal)) return null; + final int offset = (int) (getLong(ordinal) >> 32); + return TimestampNanosRowValues.readLTZ(baseObject, baseOffset, offset); + } + @Override public VariantVal getVariant(int ordinal) { if (isNullAt(ordinal)) return null; diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java index ff9eeea9bf126..d45eeee4db565 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java @@ -38,6 +38,8 @@ import org.apache.spark.unsafe.bitset.BitSetMethods; import org.apache.spark.unsafe.hash.Murmur3_x86_32; import org.apache.spark.unsafe.types.CalendarInterval; +import org.apache.spark.unsafe.types.TimestampLTZNanos; +import org.apache.spark.unsafe.types.TimestampNTZNanos; import org.apache.spark.unsafe.types.UTF8String; import org.apache.spark.unsafe.types.VariantVal; import org.apache.spark.unsafe.types.GeographyVal; @@ -96,7 +98,9 @@ public 
static boolean isMutable(DataType dt) { } PhysicalDataType pdt = PhysicalDataType.apply(dt); return pdt instanceof PhysicalPrimitiveType || pdt instanceof PhysicalDecimalType || - pdt instanceof PhysicalCalendarIntervalType; + pdt instanceof PhysicalCalendarIntervalType || + pdt instanceof PhysicalTimestampNTZNanosType || + pdt instanceof PhysicalTimestampLTZNanosType; } ////////////////////////////////////////////////////////////////////////////// @@ -322,6 +326,36 @@ public void setInterval(int ordinal, CalendarInterval value) { } } + @Override + public void setTimestampNTZNanos(int ordinal, TimestampNTZNanos value) { + setTimestampNanosPayload(ordinal, value == null, value == null ? 0L : value.epochMicros, + value == null ? 0 : value.nanosWithinMicro); + } + + @Override + public void setTimestampLTZNanos(int ordinal, TimestampLTZNanos value) { + setTimestampNanosPayload(ordinal, value == null, value == null ? 0L : value.epochMicros, + value == null ? 0 : value.nanosWithinMicro); + } + + private void setTimestampNanosPayload( + int ordinal, boolean isNull, long epochMicros, short nanosWithinMicro) { + assertIndexIsValid(ordinal); + long cursor = getLong(ordinal) >>> 32; + assert cursor > 0 : "invalid cursor " + cursor; + if (isNull) { + setNullAt(ordinal); + TimestampNanosRowValues.zeroPayload(baseObject, baseOffset, (int) cursor); + Platform.putLong( + baseObject, getFieldOffset(ordinal), + (cursor << 32) | TimestampNanosRowValues.SIZE_IN_BYTES); + } else { + TimestampNanosRowValues.writePayload( + baseObject, baseOffset, (int) cursor, epochMicros, nanosWithinMicro); + setLong(ordinal, (cursor << 32) | TimestampNanosRowValues.SIZE_IN_BYTES); + } + } + @Override public Object get(int ordinal, DataType dataType) { return SpecializedGettersReader.read(this, ordinal, dataType, true, true); @@ -446,6 +480,26 @@ public CalendarInterval getInterval(int ordinal) { } } + @Override + public TimestampNTZNanos getTimestampNTZNanos(int ordinal) { + if (isNullAt(ordinal)) { + 
return null; + } else { + final int offset = (int) (getLong(ordinal) >> 32); + return TimestampNanosRowValues.readNTZ(baseObject, baseOffset, offset); + } + } + + @Override + public TimestampLTZNanos getTimestampLTZNanos(int ordinal) { + if (isNullAt(ordinal)) { + return null; + } else { + final int offset = (int) (getLong(ordinal) >> 32); + return TimestampNanosRowValues.readLTZ(baseObject, baseOffset, offset); + } + } + @Override public VariantVal getVariant(int ordinal) { if (isNullAt(ordinal)) return null; diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeArrayWriter.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeArrayWriter.java index 3070fa3e74b1f..d59f35771785e 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeArrayWriter.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeArrayWriter.java @@ -23,6 +23,8 @@ import org.apache.spark.unsafe.array.ByteArrayMethods; import org.apache.spark.unsafe.bitset.BitSetMethods; import org.apache.spark.unsafe.types.CalendarInterval; +import org.apache.spark.unsafe.types.TimestampLTZNanos; +import org.apache.spark.unsafe.types.TimestampNTZNanos; import static org.apache.spark.sql.catalyst.expressions.UnsafeArrayData.calculateHeaderPortionInBytes; @@ -206,4 +208,24 @@ public void write(int ordinal, CalendarInterval input) { super.write(ordinal, input); } } + + @Override + public void write(int ordinal, TimestampNTZNanos input) { + assertIndexIsValid(ordinal); + if (input == null) { + setNull(ordinal); + } else { + super.write(ordinal, input); + } + } + + @Override + public void write(int ordinal, TimestampLTZNanos input) { + assertIndexIsValid(ordinal); + if (input == null) { + setNull(ordinal); + } else { + super.write(ordinal, input); + } + } } diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeWriter.java 
b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeWriter.java
index e2abc108bb1bc..ffbb9e577aa59 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeWriter.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeWriter.java
@@ -24,6 +24,9 @@
 import org.apache.spark.unsafe.array.ByteArrayMethods;
 import org.apache.spark.unsafe.bitset.BitSetMethods;
 import org.apache.spark.unsafe.types.CalendarInterval;
+import org.apache.spark.unsafe.types.TimestampLTZNanos;
+import org.apache.spark.unsafe.types.TimestampNTZNanos;
+import org.apache.spark.sql.catalyst.expressions.TimestampNanosRowValues;
 import org.apache.spark.unsafe.types.GeographyVal;
 import org.apache.spark.unsafe.types.GeometryVal;
 import org.apache.spark.unsafe.types.UTF8String;
@@ -161,6 +164,45 @@ public void write(int ordinal, CalendarInterval input) {
     increaseCursor(16);
   }
 
+  /**
+   * Writes a {@link TimestampNTZNanos} for field {@code ordinal}. A null {@code input} marks the
+   * field as null; the value's region is reserved either way.
+   */
+  public void write(int ordinal, TimestampNTZNanos input) {
+    writeTimestampNanos(ordinal, input == null, input == null ? 0L : input.epochMicros,
+        input == null ? 0 : input.nanosWithinMicro);
+  }
+
+  /**
+   * Writes a {@link TimestampLTZNanos} for field {@code ordinal}. A null {@code input} marks the
+   * field as null; the value's region is reserved either way.
+   */
+  public void write(int ordinal, TimestampLTZNanos input) {
+    writeTimestampNanos(ordinal, input == null, input == null ? 0L : input.epochMicros,
+        input == null ? 0 : input.nanosWithinMicro);
+  }
+
+  /**
+   * Shared body of the two nanosecond-timestamp {@code write} overloads: reserves
+   * {@link TimestampNanosRowValues#SIZE_IN_BYTES} bytes at the cursor, fills them, records the
+   * offset and size for {@code ordinal}, and advances the cursor.
+   */
+  private void writeTimestampNanos(
+      int ordinal, boolean isNull, long epochMicros, short nanosWithinMicro) {
+    grow(TimestampNanosRowValues.SIZE_IN_BYTES);
+    if (isNull) {
+      BitSetMethods.set(getBuffer(), startingOffset, ordinal);
+      // Zero the reserved region, as UnsafeRow's null setter does, so the bytes of a null field
+      // do not depend on whatever the buffer held before.
+      TimestampNanosRowValues.zeroPayload(getBuffer(), 0, (int) cursor());
+    } else {
+      TimestampNanosRowValues.writePayload(
+          getBuffer(), 0, (int) cursor(), epochMicros, nanosWithinMicro);
+    }
+    setOffsetAndSize(ordinal, TimestampNanosRowValues.SIZE_IN_BYTES);
+    increaseCursor(TimestampNanosRowValues.SIZE_IN_BYTES);
+  }
+
   public void write(int ordinal, VariantVal input) {
     // See the class comment of VariantVal for the format of the binary content.
byte[] value = input.getValue(); diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/vectorized/ColumnVector.java b/sql/catalyst/src/main/java/org/apache/spark/sql/vectorized/ColumnVector.java index 8e9a5a620b3e4..6d269e3117527 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/vectorized/ColumnVector.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/vectorized/ColumnVector.java @@ -18,11 +18,14 @@ import scala.PartialFunction; +import org.apache.spark.SparkUnsupportedOperationException; import org.apache.spark.annotation.Evolving; import org.apache.spark.sql.types.DataType; import org.apache.spark.sql.types.Decimal; import org.apache.spark.sql.types.UserDefinedType; import org.apache.spark.unsafe.types.CalendarInterval; +import org.apache.spark.unsafe.types.TimestampLTZNanos; +import org.apache.spark.unsafe.types.TimestampNTZNanos; import org.apache.spark.unsafe.types.UTF8String; import org.apache.spark.unsafe.types.VariantVal; import org.apache.spark.unsafe.types.GeographyVal; @@ -326,6 +329,14 @@ public CalendarInterval getInterval(int rowId) { return new CalendarInterval(months, days, microseconds); } + public TimestampNTZNanos getTimestampNTZNanos(int rowId) { + throw SparkUnsupportedOperationException.apply(); + } + + public TimestampLTZNanos getTimestampLTZNanos(int rowId) { + throw SparkUnsupportedOperationException.apply(); + } + /** * Returns the Variant value for {@code rowId}. 
Similar to {@link #getInterval(int)}, the * implementation must implement {@link #getChild(int)} and define 2 child vectors of binary type diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/vectorized/ColumnarArray.java b/sql/catalyst/src/main/java/org/apache/spark/sql/vectorized/ColumnarArray.java index 861a6a4c50e44..becf4b121be86 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/vectorized/ColumnarArray.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/vectorized/ColumnarArray.java @@ -24,6 +24,8 @@ import org.apache.spark.sql.catalyst.util.GenericArrayData; import org.apache.spark.sql.types.*; import org.apache.spark.unsafe.types.CalendarInterval; +import org.apache.spark.unsafe.types.TimestampLTZNanos; +import org.apache.spark.unsafe.types.TimestampNTZNanos; import org.apache.spark.unsafe.types.UTF8String; import org.apache.spark.unsafe.types.VariantVal; import org.apache.spark.unsafe.types.GeographyVal; @@ -191,6 +193,16 @@ public CalendarInterval getInterval(int ordinal) { return data.getInterval(offset + ordinal); } + @Override + public TimestampNTZNanos getTimestampNTZNanos(int ordinal) { + return data.getTimestampNTZNanos(offset + ordinal); + } + + @Override + public TimestampLTZNanos getTimestampLTZNanos(int ordinal) { + return data.getTimestampLTZNanos(offset + ordinal); + } + @Override public VariantVal getVariant(int ordinal) { return data.getVariant(offset + ordinal); diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/vectorized/ColumnarBatchRow.java b/sql/catalyst/src/main/java/org/apache/spark/sql/vectorized/ColumnarBatchRow.java index 42b335dfd2bc1..d99efe81ce3df 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/vectorized/ColumnarBatchRow.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/vectorized/ColumnarBatchRow.java @@ -25,6 +25,8 @@ import org.apache.spark.sql.catalyst.types.*; import org.apache.spark.sql.types.*; import org.apache.spark.unsafe.types.CalendarInterval; +import 
org.apache.spark.unsafe.types.TimestampLTZNanos; +import org.apache.spark.unsafe.types.TimestampNTZNanos; import org.apache.spark.unsafe.types.UTF8String; import org.apache.spark.unsafe.types.VariantVal; import org.apache.spark.unsafe.types.GeographyVal; @@ -155,6 +157,16 @@ public CalendarInterval getInterval(int ordinal) { return columns[ordinal].getInterval(rowId); } + @Override + public TimestampNTZNanos getTimestampNTZNanos(int ordinal) { + return columns[ordinal].getTimestampNTZNanos(rowId); + } + + @Override + public TimestampLTZNanos getTimestampLTZNanos(int ordinal) { + return columns[ordinal].getTimestampLTZNanos(rowId); + } + @Override public VariantVal getVariant(int ordinal) { return columns[ordinal].getVariant(rowId); diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/vectorized/ColumnarRow.java b/sql/catalyst/src/main/java/org/apache/spark/sql/vectorized/ColumnarRow.java index d66baa8fd8fe3..a7433c1b1ca20 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/vectorized/ColumnarRow.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/vectorized/ColumnarRow.java @@ -23,6 +23,8 @@ import org.apache.spark.sql.catalyst.types.*; import org.apache.spark.sql.types.*; import org.apache.spark.unsafe.types.CalendarInterval; +import org.apache.spark.unsafe.types.TimestampLTZNanos; +import org.apache.spark.unsafe.types.TimestampNTZNanos; import org.apache.spark.unsafe.types.UTF8String; import org.apache.spark.unsafe.types.VariantVal; import org.apache.spark.unsafe.types.GeographyVal; @@ -160,6 +162,16 @@ public CalendarInterval getInterval(int ordinal) { return data.getChild(ordinal).getInterval(rowId); } + @Override + public TimestampNTZNanos getTimestampNTZNanos(int ordinal) { + return data.getChild(ordinal).getTimestampNTZNanos(rowId); + } + + @Override + public TimestampLTZNanos getTimestampLTZNanos(int ordinal) { + return data.getChild(ordinal).getTimestampLTZNanos(rowId); + } + @Override public VariantVal getVariant(int ordinal) { 
return data.getChild(ordinal).getVariant(rowId); diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala index b27283cb3f647..22cabed9c0e01 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala @@ -22,7 +22,7 @@ import org.apache.spark.sql.catalyst.types._ import org.apache.spark.sql.catalyst.types.ops.TypeOps import org.apache.spark.sql.catalyst.util.{ArrayData, MapData} import org.apache.spark.sql.types._ -import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String} +import org.apache.spark.unsafe.types.{CalendarInterval, TimestampLTZNanos, TimestampNTZNanos, UTF8String} import org.apache.spark.util.ArrayImplicits._ /** @@ -63,6 +63,10 @@ abstract class InternalRow extends SpecializedGetters with Serializable { def setInterval(i: Int, value: CalendarInterval): Unit = update(i, value) + def setTimestampNTZNanos(i: Int, value: TimestampNTZNanos): Unit = update(i, value) + + def setTimestampLTZNanos(i: Int, value: TimestampLTZNanos): Unit = update(i, value) + /** * Make a copy of the current [[InternalRow]] object. 
*/ @@ -144,6 +148,10 @@ object InternalRow { case _: PhysicalStringType => (input, ordinal) => input.getUTF8String(ordinal) case PhysicalBinaryType => (input, ordinal) => input.getBinary(ordinal) case PhysicalCalendarIntervalType => (input, ordinal) => input.getInterval(ordinal) + case PhysicalTimestampNTZNanosType => (input, ordinal) => + input.getTimestampNTZNanos(ordinal) + case PhysicalTimestampLTZNanosType => (input, ordinal) => + input.getTimestampLTZNanos(ordinal) case t: PhysicalDecimalType => (input, ordinal) => input.getDecimal(ordinal, t.precision, t.scale) case t: PhysicalStructType => (input, ordinal) => input.getStruct(ordinal, t.fields.length) @@ -185,6 +193,10 @@ object InternalRow { case DoubleType => (input, v) => input.setDouble(ordinal, v.asInstanceOf[Double]) case CalendarIntervalType => (input, v) => input.setInterval(ordinal, v.asInstanceOf[CalendarInterval]) + case _: TimestampNTZNanosType => + (input, v) => input.setTimestampNTZNanos(ordinal, v.asInstanceOf[TimestampNTZNanos]) + case _: TimestampLTZNanosType => + (input, v) => input.setTimestampLTZNanos(ordinal, v.asInstanceOf[TimestampLTZNanos]) case DecimalType.Fixed(precision, _) => (input, v) => input.setDecimal(ordinal, v.asInstanceOf[Decimal], precision) case udt: UserDefinedType[_] => getWriter(ordinal, udt.sqlType) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ProjectingInternalRow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ProjectingInternalRow.scala index 0e451db6cfe25..fad88e11778f4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ProjectingInternalRow.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ProjectingInternalRow.scala @@ -105,6 +105,14 @@ case class ProjectingInternalRow(schema: StructType, row.getInterval(colOrdinals(ordinal)) } + override def getTimestampNTZNanos(ordinal: Int): TimestampNTZNanos = { + row.getTimestampNTZNanos(colOrdinals(ordinal)) + } + + override def 
getTimestampLTZNanos(ordinal: Int): TimestampLTZNanos = { + row.getTimestampLTZNanos(colOrdinals(ordinal)) + } + override def getVariant(ordinal: Int): VariantVal = { row.getVariant(colOrdinals(ordinal)) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedUnsafeProjection.scala index 53b3e0598d586..1b7022fa4e1cb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedUnsafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedUnsafeProjection.scala @@ -159,6 +159,12 @@ object InterpretedUnsafeProjection { case PhysicalCalendarIntervalType => (v, i) => writer.write(i, v.getInterval(i)) + case PhysicalTimestampNTZNanosType => (v, i) => + writer.write(i, v.getTimestampNTZNanos(i)) + + case PhysicalTimestampLTZNanosType => (v, i) => + writer.write(i, v.getTimestampLTZNanos(i)) + case PhysicalBinaryType => (v, i) => writer.write(i, v.getBinary(i)) case _: PhysicalStringType => (v, i) => writer.write(i, v.getUTF8String(i)) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/JoinedRow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/JoinedRow.scala index 4211dd5e4df01..a6bc79e5b59ce 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/JoinedRow.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/JoinedRow.scala @@ -126,6 +126,14 @@ class JoinedRow extends InternalRow { override def getInterval(i: Int): CalendarInterval = if (i < row1.numFields) row1.getInterval(i) else row2.getInterval(i - row1.numFields) + override def getTimestampNTZNanos(i: Int): TimestampNTZNanos = + if (i < row1.numFields) row1.getTimestampNTZNanos(i) + else row2.getTimestampNTZNanos(i - row1.numFields) + + override def getTimestampLTZNanos(i: Int): 
TimestampLTZNanos = + if (i < row1.numFields) row1.getTimestampLTZNanos(i) + else row2.getTimestampLTZNanos(i - row1.numFields) + override def getVariant(i: Int): VariantVal = if (i < row1.numFields) row1.getVariant(i) else row2.getVariant(i - row1.numFields) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index 139a7f03cfa40..729f5604bd965 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -1531,6 +1531,8 @@ object CodeGenerator extends Logging { classOf[UTF8String].getName, classOf[Decimal].getName, classOf[CalendarInterval].getName, + classOf[org.apache.spark.unsafe.types.TimestampNTZNanos].getName, + classOf[org.apache.spark.unsafe.types.TimestampLTZNanos].getName, classOf[VariantVal].getName, classOf[ArrayData].getName, classOf[UnsafeArrayData].getName, @@ -1695,6 +1697,8 @@ object CodeGenerator extends Logging { case _: PhysicalGeographyType => s"$input.getGeography($ordinal)" case _: PhysicalGeometryType => s"$input.getGeometry($ordinal)" case PhysicalCalendarIntervalType => s"$input.getInterval($ordinal)" + case PhysicalTimestampNTZNanosType => s"$input.getTimestampNTZNanos($ordinal)" + case PhysicalTimestampLTZNanosType => s"$input.getTimestampLTZNanos($ordinal)" case t: PhysicalDecimalType => s"$input.getDecimal($ordinal, ${t.precision}, ${t.scale})" case _: PhysicalMapType => s"$input.getMap($ordinal)" case PhysicalNullType => "null" @@ -1768,6 +1772,8 @@ object CodeGenerator extends Logging { dataType match { case _ if isPrimitiveType(jt) => s"$row.set${primitiveTypeName(jt)}($ordinal, $value)" case CalendarIntervalType => s"$row.setInterval($ordinal, $value)" + case _: TimestampNTZNanosType => 
s"$row.setTimestampNTZNanos($ordinal, $value)" + case _: TimestampLTZNanosType => s"$row.setTimestampLTZNanos($ordinal, $value)" case t: DecimalType => s"$row.setDecimal($ordinal, $value, ${t.precision})" case udt: UserDefinedType[_] => setColumn(row, udt.sqlType, ordinal, value) // The UTF8String, InternalRow, ArrayData and MapData may came from UnsafeRow, we should copy @@ -1983,6 +1989,8 @@ object CodeGenerator extends Logging { case PhysicalBooleanType => JAVA_BOOLEAN case PhysicalByteType => JAVA_BYTE case PhysicalCalendarIntervalType => "CalendarInterval" + case PhysicalTimestampNTZNanosType => "TimestampNTZNanos" + case PhysicalTimestampLTZNanosType => "TimestampLTZNanos" case PhysicalIntegerType => JAVA_INT case _: PhysicalDecimalType => "Decimal" case PhysicalDoubleType => JAVA_DOUBLE @@ -2015,6 +2023,8 @@ object CodeGenerator extends Logging { case _: GeometryType => classOf[GeometryVal] case _: StringType => classOf[UTF8String] case CalendarIntervalType => classOf[CalendarInterval] + case _: TimestampNTZNanosType => classOf[org.apache.spark.unsafe.types.TimestampNTZNanos] + case _: TimestampLTZNanosType => classOf[org.apache.spark.unsafe.types.TimestampLTZNanos] case _: StructType => classOf[InternalRow] case _: ArrayType => classOf[ArrayData] case _: MapType => classOf[MapData] diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala index 3b222ca05235a..b27536aff4b22 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala @@ -149,6 +149,8 @@ object Literal { case _ if clz == classOf[BigInt] => DecimalType.SYSTEM_DEFAULT case _ if clz == classOf[BigDecimal] => DecimalType.SYSTEM_DEFAULT case _ if clz == classOf[CalendarInterval] => CalendarIntervalType + case _ if clz == classOf[TimestampNTZNanos] 
=> TimestampNTZNanosType() + case _ if clz == classOf[TimestampLTZNanos] => TimestampLTZNanosType() case _ if clz == classOf[VariantVal] => VariantType case _ if clz.isArray => ArrayType(componentTypeToDataType(clz.getComponentType)) @@ -242,6 +244,8 @@ object Literal { case PhysicalBooleanType => v.isInstanceOf[Boolean] case PhysicalByteType => v.isInstanceOf[Byte] case PhysicalCalendarIntervalType => v.isInstanceOf[CalendarInterval] + case PhysicalTimestampNTZNanosType => v.isInstanceOf[TimestampNTZNanos] + case PhysicalTimestampLTZNanosType => v.isInstanceOf[TimestampLTZNanos] case PhysicalIntegerType => v.isInstanceOf[Int] case _: PhysicalDecimalType => v.isInstanceOf[Decimal] case PhysicalDoubleType => v.isInstanceOf[Double] diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala index b8d6054fc6fc5..2a71ab8f022ab 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala @@ -49,6 +49,8 @@ trait BaseGenericInternalRow extends InternalRow { override def getGeometry(ordinal: Int): GeometryVal = getAs(ordinal) override def getArray(ordinal: Int): ArrayData = getAs(ordinal) override def getInterval(ordinal: Int): CalendarInterval = getAs(ordinal) + override def getTimestampNTZNanos(ordinal: Int): TimestampNTZNanos = getAs(ordinal) + override def getTimestampLTZNanos(ordinal: Int): TimestampLTZNanos = getAs(ordinal) override def getVariant(ordinal: Int): VariantVal = getAs(ordinal) override def getMap(ordinal: Int): MapData = getAs(ordinal) override def getStruct(ordinal: Int, numFields: Int): InternalRow = getAs(ordinal) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/PhysicalDataType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/PhysicalDataType.scala index 
6f49b3998652c..ba7a02c902720 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/PhysicalDataType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/PhysicalDataType.scala @@ -24,8 +24,8 @@ import org.apache.spark.sql.catalyst.expressions.{Ascending, BoundReference, Int import org.apache.spark.sql.catalyst.types.ops.TypeOps import org.apache.spark.sql.catalyst.util.{ArrayData, CollationFactory, MapData, SQLOrderingUtil} import org.apache.spark.sql.errors.QueryExecutionErrors -import org.apache.spark.sql.types.{ArrayType, BinaryType, BooleanType, ByteExactNumeric, ByteType, CalendarIntervalType, CharType, DataType, DateType, DayTimeIntervalType, Decimal, DecimalExactNumeric, DecimalType, DoubleExactNumeric, DoubleType, FloatExactNumeric, FloatType, FractionalType, GeographyType, GeometryType, IntegerExactNumeric, IntegerType, IntegralType, LongExactNumeric, LongType, MapType, NullType, NumericType, ShortExactNumeric, ShortType, StringType, StructField, StructType, TimestampNTZType, TimestampType, TimeType, VarcharType, VariantType, YearMonthIntervalType} -import org.apache.spark.unsafe.types.{ByteArray, GeographyVal, GeometryVal, UTF8String, VariantVal} +import org.apache.spark.sql.types.{ArrayType, BinaryType, BooleanType, ByteExactNumeric, ByteType, CalendarIntervalType, CharType, DataType, DateType, DayTimeIntervalType, Decimal, DecimalExactNumeric, DecimalType, DoubleExactNumeric, DoubleType, FloatExactNumeric, FloatType, FractionalType, GeographyType, GeometryType, IntegerExactNumeric, IntegerType, IntegralType, LongExactNumeric, LongType, MapType, NullType, NumericType, ShortExactNumeric, ShortType, StringType, StructField, StructType, TimestampLTZNanosType, TimestampNTZNanosType, TimestampNTZType, TimestampType, TimeType, VarcharType, VariantType, YearMonthIntervalType} +import org.apache.spark.unsafe.types.{ByteArray, GeographyVal, GeometryVal, TimestampLTZNanos, TimestampNTZNanos, UTF8String, VariantVal} 
import org.apache.spark.util.ArrayImplicits._ sealed abstract class PhysicalDataType { @@ -55,6 +55,8 @@ object PhysicalDataType { case TimestampType => PhysicalLongType case TimestampNTZType => PhysicalLongType case CalendarIntervalType => PhysicalCalendarIntervalType + case _: TimestampNTZNanosType => PhysicalTimestampNTZNanosType + case _: TimestampLTZNanosType => PhysicalTimestampLTZNanosType case DayTimeIntervalType(_, _) => PhysicalLongType case YearMonthIntervalType(_, _) => PhysicalIntegerType case DateType => PhysicalIntegerType @@ -166,6 +168,24 @@ class PhysicalCalendarIntervalType() extends PhysicalDataType { } case object PhysicalCalendarIntervalType extends PhysicalCalendarIntervalType +class PhysicalTimestampNTZNanosType() extends PhysicalDataType { + override private[sql] def ordering = + throw QueryExecutionErrors.orderedOperationUnsupportedByDataTypeError( + "PhysicalTimestampNTZNanosType") + override private[sql] type InternalType = TimestampNTZNanos + @transient private[sql] lazy val tag = typeTag[InternalType] +} +case object PhysicalTimestampNTZNanosType extends PhysicalTimestampNTZNanosType + +class PhysicalTimestampLTZNanosType() extends PhysicalDataType { + override private[sql] def ordering = + throw QueryExecutionErrors.orderedOperationUnsupportedByDataTypeError( + "PhysicalTimestampLTZNanosType") + override private[sql] type InternalType = TimestampLTZNanos + @transient private[sql] lazy val tag = typeTag[InternalType] +} +case object PhysicalTimestampLTZNanosType extends PhysicalTimestampLTZNanosType + case class PhysicalDecimalType(precision: Int, scale: Int) extends PhysicalFractionalType { private[sql] type InternalType = Decimal private[sql] val ordering = Decimal.DecimalIsFractional diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/GenericArrayData.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/GenericArrayData.scala index 808a3d43bf200..c43344f0a1e0c 100644 --- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/GenericArrayData.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/GenericArrayData.scala @@ -75,6 +75,8 @@ class GenericArrayData(val array: Array[Any]) extends ArrayData { override def getGeography(ordinal: Int): GeographyVal = getAs(ordinal) override def getGeometry(ordinal: Int): GeometryVal = getAs(ordinal) override def getInterval(ordinal: Int): CalendarInterval = getAs(ordinal) + override def getTimestampNTZNanos(ordinal: Int): TimestampNTZNanos = getAs(ordinal) + override def getTimestampLTZNanos(ordinal: Int): TimestampLTZNanos = getAs(ordinal) override def getVariant(ordinal: Int): VariantVal = getAs(ordinal) override def getStruct(ordinal: Int, numFields: Int): InternalRow = getAs(ordinal) override def getArray(ordinal: Int): ArrayData = getAs(ordinal) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/UnsafeRowUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/UnsafeRowUtils.scala index f2925314e2e2b..4da9c88dd9a9e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/UnsafeRowUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/UnsafeRowUtils.scala @@ -173,6 +173,7 @@ object UnsafeRowUtils { def avoidSetNullAt(dt: DataType): Boolean = dt match { case t: DecimalType if t.precision > Decimal.MAX_LONG_DIGITS => true case CalendarIntervalType => true + case _: TimestampNTZNanosType | _: TimestampLTZNanosType => true case _ => false } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/TimestampNanosRowSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/TimestampNanosRowSuite.scala new file mode 100644 index 0000000000000..91d6c442171e2 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/TimestampNanosRowSuite.scala @@ -0,0 +1,95 @@ +/* + * Licensed to the Apache 
Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection +import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.{TimestampLTZNanos, TimestampNTZNanos} + +class TimestampNanosRowSuite extends SparkFunSuite with ExpressionEvalHelper { + + private val ntzValue = new TimestampNTZNanos(1234567890123L, 42.toShort) + private val ltzValue = new TimestampLTZNanos(9876543210987L, 999.toShort) + + test("GenericInternalRow roundtrip for TIMESTAMP_NTZ nanos") { + val row = new GenericInternalRow(Array[Any](ntzValue, null)) + val accessor = InternalRow.getAccessor(TimestampNTZNanosType(9)) + val writer = InternalRow.getWriter(0, TimestampNTZNanosType(9)) + assert(accessor(row, 0) === ntzValue) + assert(accessor(row, 1) === null) + + val row2 = new GenericInternalRow(Array[Any](null, null)) + writer(row2, ntzValue) + assert(accessor(row2, 0) === ntzValue) + } + + test("GenericInternalRow roundtrip for TIMESTAMP_LTZ nanos") { + val row = new GenericInternalRow(Array[Any](ltzValue, null)) + val accessor = 
InternalRow.getAccessor(TimestampLTZNanosType(8)) + val writer = InternalRow.getWriter(0, TimestampLTZNanosType(8)) + assert(accessor(row, 0) === ltzValue) + assert(accessor(row, 1) === null) + + val row2 = new GenericInternalRow(Array[Any](null, null)) + writer(row2, ltzValue) + assert(accessor(row2, 0) === ltzValue) + } + + testBothCodegenAndInterpreted("UnsafeRow roundtrip for nanos timestamp columns") { + val schema = StructType(Seq( + StructField("ntz", TimestampNTZNanosType(9), nullable = true), + StructField("ltz", TimestampLTZNanosType(7), nullable = true))) + val fieldTypes = schema.map(_.dataType).toArray + val converter = UnsafeProjection.create(fieldTypes) + + val input = new SpecificInternalRow(fieldTypes.toIndexedSeq) + input.update(0, ntzValue) + input.update(1, ltzValue) + + val unsafeRow = converter.apply(input) + assert(unsafeRow.getTimestampNTZNanos(0) === ntzValue) + assert(unsafeRow.getTimestampLTZNanos(1) === ltzValue) + + val updatedNtz = new TimestampNTZNanos(1L, 0.toShort) + unsafeRow.setTimestampNTZNanos(0, updatedNtz) + assert(unsafeRow.getTimestampNTZNanos(0) === updatedNtz) + + val offset = unsafeRow.getLong(0) >>> 32 + unsafeRow.setTimestampNTZNanos(0, null) + assert(unsafeRow.getTimestampNTZNanos(0) === null) + assert(unsafeRow.getLong(0) >>> 32 === offset) + } + + testBothCodegenAndInterpreted("codegen projection reads nanos timestamp column") { + val boundRef = BoundReference(0, TimestampNTZNanosType(9), nullable = false) + val projection = GenerateUnsafeProjection.generate(Seq(boundRef)) + val input = new GenericInternalRow(Array[Any](ntzValue)) + val output = projection.apply(input) + assert(output.getTimestampNTZNanos(0) === ntzValue) + } + + test("literal validation for nanosecond timestamp types") { + Literal.validateLiteralValue(ntzValue, TimestampNTZNanosType(9)) + Literal.validateLiteralValue(ltzValue, TimestampLTZNanosType(7)) + intercept[IllegalArgumentException] { + Literal.validateLiteralValue(ntzValue, 
TimestampLTZNanosType(7)) + } + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjectionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjectionSuite.scala index 9c0d610f35f6b..eee45a4b8bc49 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjectionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjectionSuite.scala @@ -90,6 +90,8 @@ object AlwaysNull extends InternalRow { override def getGeography(ordinal: Int): GeographyVal = notSupported override def getGeometry(ordinal: Int): GeometryVal = notSupported override def getInterval(ordinal: Int): CalendarInterval = notSupported + override def getTimestampNTZNanos(ordinal: Int): TimestampNTZNanos = notSupported + override def getTimestampLTZNanos(ordinal: Int): TimestampLTZNanos = notSupported override def getVariant(ordinal: Int): VariantVal = notSupported override def getStruct(ordinal: Int, numFields: Int): InternalRow = notSupported override def getArray(ordinal: Int): ArrayData = notSupported @@ -122,6 +124,8 @@ object AlwaysNonNull extends InternalRow { override def getGeography(ordinal: Int): GeographyVal = notSupported override def getGeometry(ordinal: Int): GeometryVal = notSupported override def getInterval(ordinal: Int): CalendarInterval = notSupported + override def getTimestampNTZNanos(ordinal: Int): TimestampNTZNanos = notSupported + override def getTimestampLTZNanos(ordinal: Int): TimestampLTZNanos = notSupported override def getVariant(ordinal: Int): VariantVal = notSupported override def getStruct(ordinal: Int, numFields: Int): InternalRow = notSupported override def getArray(ordinal: Int): ArrayData = stringToUTF8Array(Array("1", "2", "3")) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala index 1a7524dbc5a73..a3e2bbca817b1 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala @@ -417,6 +417,15 @@ class DataTypeSuite extends SparkFunSuite { checkDefaultSize(TimestampNTZNanosType(TimestampNTZNanosType.MIN_PRECISION), 10) checkDefaultSize(TimestampNTZNanosType(TimestampNTZNanosType.MAX_PRECISION), 10) + test("PhysicalDataType for nanosecond timestamp types") { + for (p <- TimestampNTZNanosType.MIN_PRECISION to TimestampNTZNanosType.MAX_PRECISION) { + assert(PhysicalDataType(TimestampNTZNanosType(p)) != UninitializedPhysicalType) + } + for (p <- TimestampLTZNanosType.MIN_PRECISION to TimestampLTZNanosType.MAX_PRECISION) { + assert(PhysicalDataType(TimestampLTZNanosType(p)) != UninitializedPhysicalType) + } + } + def checkEqualsIgnoreCompatibleNullability( from: DataType, to: DataType, diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/MutableColumnarRow.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/MutableColumnarRow.java index a46b5143eef6d..a5a50d448e51d 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/MutableColumnarRow.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/MutableColumnarRow.java @@ -29,6 +29,8 @@ import org.apache.spark.sql.vectorized.ColumnarMap; import org.apache.spark.sql.vectorized.ColumnarRow; import org.apache.spark.unsafe.types.CalendarInterval; +import org.apache.spark.unsafe.types.TimestampLTZNanos; +import org.apache.spark.unsafe.types.TimestampNTZNanos; import org.apache.spark.unsafe.types.GeographyVal; import org.apache.spark.unsafe.types.GeometryVal; import org.apache.spark.unsafe.types.UTF8String; @@ -165,6 +167,16 @@ public CalendarInterval getInterval(int ordinal) { return columns[ordinal].getInterval(rowId); } + @Override + 
public TimestampNTZNanos getTimestampNTZNanos(int ordinal) { + return columns[ordinal].getTimestampNTZNanos(rowId); + } + + @Override + public TimestampLTZNanos getTimestampLTZNanos(int ordinal) { + return columns[ordinal].getTimestampLTZNanos(rowId); + } + @Override public VariantVal getVariant(int ordinal) { return columns[ordinal].getVariant(rowId);