Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SPARK-41226][SQL] Refactor Spark types by introducing physical types #38750

Closed
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

package org.apache.spark.sql.catalyst.expressions;

import org.apache.spark.sql.catalyst.types.*;
import org.apache.spark.sql.types.*;

public final class SpecializedGettersReader {
Expand All @@ -28,70 +29,56 @@ public static Object read(
DataType dataType,
boolean handleNull,
boolean handleUserDefinedType) {
if (handleNull && (obj.isNullAt(ordinal) || dataType instanceof NullType)) {
PhysicalDataType physicalDataType = dataType.physicalDataType();
if (handleNull && (obj.isNullAt(ordinal) || physicalDataType instanceof PhysicalNullType)) {
return null;
}
if (dataType instanceof BooleanType) {
if (physicalDataType instanceof PhysicalBooleanType) {
return obj.getBoolean(ordinal);
}
if (dataType instanceof ByteType) {
if (physicalDataType instanceof PhysicalByteType) {
return obj.getByte(ordinal);
}
if (dataType instanceof ShortType) {
if (physicalDataType instanceof PhysicalShortType) {
return obj.getShort(ordinal);
}
if (dataType instanceof IntegerType) {
if (physicalDataType instanceof PhysicalIntegerType) {
return obj.getInt(ordinal);
}
if (dataType instanceof LongType) {
if (physicalDataType instanceof PhysicalLongType) {
return obj.getLong(ordinal);
}
if (dataType instanceof FloatType) {
if (physicalDataType instanceof PhysicalFloatType) {
return obj.getFloat(ordinal);
}
if (dataType instanceof DoubleType) {
if (physicalDataType instanceof PhysicalDoubleType) {
return obj.getDouble(ordinal);
}
if (dataType instanceof StringType) {
if (physicalDataType instanceof PhysicalStringType) {
return obj.getUTF8String(ordinal);
}
if (dataType instanceof DecimalType) {
DecimalType dt = (DecimalType) dataType;
if (physicalDataType instanceof PhysicalDecimalType) {
PhysicalDecimalType dt = (PhysicalDecimalType) physicalDataType;
return obj.getDecimal(ordinal, dt.precision(), dt.scale());
}
if (dataType instanceof DateType) {
return obj.getInt(ordinal);
}
if (dataType instanceof TimestampType) {
return obj.getLong(ordinal);
}
if (dataType instanceof TimestampNTZType) {
return obj.getLong(ordinal);
}
if (dataType instanceof CalendarIntervalType) {
if (physicalDataType instanceof PhysicalCalendarIntervalType) {
return obj.getInterval(ordinal);
}
if (dataType instanceof BinaryType) {
if (physicalDataType instanceof PhysicalBinaryType) {
return obj.getBinary(ordinal);
}
if (dataType instanceof StructType) {
return obj.getStruct(ordinal, ((StructType) dataType).size());
if (physicalDataType instanceof PhysicalStructType) {
return obj.getStruct(ordinal, ((PhysicalStructType) physicalDataType).fields().length);
}
if (dataType instanceof ArrayType) {
if (physicalDataType instanceof PhysicalArrayType) {
return obj.getArray(ordinal);
}
if (dataType instanceof MapType) {
if (physicalDataType instanceof PhysicalMapType) {
return obj.getMap(ordinal);
}
if (handleUserDefinedType && dataType instanceof UserDefinedType) {
return obj.get(ordinal, ((UserDefinedType)dataType).sqlType());
}
if (dataType instanceof DayTimeIntervalType) {
return obj.getLong(ordinal);
}
if (dataType instanceof YearMonthIntervalType) {
return obj.getInt(ordinal);
}

throw new UnsupportedOperationException("Unsupported data type " + dataType.simpleString());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import org.apache.spark.annotation.DeveloperApi;
import org.apache.spark.sql.catalyst.InternalRow;
import org.apache.spark.sql.catalyst.expressions.GenericInternalRow;
import org.apache.spark.sql.catalyst.types.*;
import org.apache.spark.sql.types.*;
import org.apache.spark.unsafe.types.CalendarInterval;
import org.apache.spark.unsafe.types.UTF8String;
Expand Down Expand Up @@ -48,36 +49,33 @@ public InternalRow copy() {
row.setNullAt(i);
} else {
DataType dt = columns[i].dataType();
if (dt instanceof BooleanType) {
PhysicalDataType pdt = dt.physicalDataType();
if (pdt instanceof PhysicalBooleanType) {
row.setBoolean(i, getBoolean(i));
} else if (dt instanceof ByteType) {
} else if (pdt instanceof PhysicalByteType) {
row.setByte(i, getByte(i));
} else if (dt instanceof ShortType) {
} else if (pdt instanceof PhysicalShortType) {
row.setShort(i, getShort(i));
} else if (dt instanceof IntegerType || dt instanceof YearMonthIntervalType) {
} else if (pdt instanceof PhysicalIntegerType) {
row.setInt(i, getInt(i));
} else if (dt instanceof LongType || dt instanceof DayTimeIntervalType) {
} else if (pdt instanceof PhysicalLongType) {
row.setLong(i, getLong(i));
} else if (dt instanceof FloatType) {
} else if (pdt instanceof PhysicalFloatType) {
row.setFloat(i, getFloat(i));
} else if (dt instanceof DoubleType) {
} else if (pdt instanceof PhysicalDoubleType) {
row.setDouble(i, getDouble(i));
} else if (dt instanceof StringType) {
} else if (pdt instanceof PhysicalStringType) {
row.update(i, getUTF8String(i).copy());
} else if (dt instanceof BinaryType) {
} else if (pdt instanceof PhysicalBinaryType) {
row.update(i, getBinary(i));
} else if (dt instanceof DecimalType) {
DecimalType t = (DecimalType)dt;
} else if (pdt instanceof PhysicalDecimalType) {
PhysicalDecimalType t = (PhysicalDecimalType)pdt;
row.setDecimal(i, getDecimal(i, t.precision(), t.scale()), t.precision());
} else if (dt instanceof DateType) {
row.setInt(i, getInt(i));
} else if (dt instanceof TimestampType) {
row.setLong(i, getLong(i));
} else if (dt instanceof StructType) {
row.update(i, getStruct(i, ((StructType) dt).fields().length).copy());
} else if (dt instanceof ArrayType) {
} else if (pdt instanceof PhysicalStructType) {
row.update(i, getStruct(i, ((PhysicalStructType) pdt).fields().length).copy());
} else if (pdt instanceof PhysicalArrayType) {
row.update(i, getArray(i).copy());
} else if (dt instanceof MapType) {
} else if (pdt instanceof PhysicalMapType) {
row.update(i, getMap(i).copy());
} else {
throw new RuntimeException("Not implemented. " + dt);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import org.apache.spark.annotation.Evolving;
import org.apache.spark.sql.catalyst.InternalRow;
import org.apache.spark.sql.catalyst.expressions.GenericInternalRow;
import org.apache.spark.sql.catalyst.types.*;
import org.apache.spark.sql.types.*;
import org.apache.spark.unsafe.types.CalendarInterval;
import org.apache.spark.unsafe.types.UTF8String;
Expand Down Expand Up @@ -55,36 +56,33 @@ public InternalRow copy() {
row.setNullAt(i);
} else {
DataType dt = data.getChild(i).dataType();
if (dt instanceof BooleanType) {
PhysicalDataType pdt = dt.physicalDataType();
if (pdt instanceof PhysicalBooleanType) {
row.setBoolean(i, getBoolean(i));
} else if (dt instanceof ByteType) {
} else if (pdt instanceof PhysicalByteType) {
row.setByte(i, getByte(i));
} else if (dt instanceof ShortType) {
} else if (pdt instanceof PhysicalShortType) {
row.setShort(i, getShort(i));
} else if (dt instanceof IntegerType || dt instanceof YearMonthIntervalType) {
} else if (pdt instanceof PhysicalIntegerType) {
row.setInt(i, getInt(i));
} else if (dt instanceof LongType || dt instanceof DayTimeIntervalType) {
} else if (pdt instanceof PhysicalLongType) {
row.setLong(i, getLong(i));
} else if (dt instanceof FloatType) {
} else if (pdt instanceof PhysicalFloatType) {
row.setFloat(i, getFloat(i));
} else if (dt instanceof DoubleType) {
} else if (pdt instanceof PhysicalDoubleType) {
row.setDouble(i, getDouble(i));
} else if (dt instanceof StringType) {
} else if (pdt instanceof PhysicalStringType) {
row.update(i, getUTF8String(i).copy());
} else if (dt instanceof BinaryType) {
} else if (pdt instanceof PhysicalBinaryType) {
row.update(i, getBinary(i));
} else if (dt instanceof DecimalType) {
DecimalType t = (DecimalType)dt;
} else if (pdt instanceof PhysicalDecimalType) {
PhysicalDecimalType t = (PhysicalDecimalType)pdt;
row.setDecimal(i, getDecimal(i, t.precision(), t.scale()), t.precision());
} else if (dt instanceof DateType) {
row.setInt(i, getInt(i));
} else if (dt instanceof TimestampType) {
row.setLong(i, getLong(i));
} else if (dt instanceof StructType) {
row.update(i, getStruct(i, ((StructType) dt).fields().length).copy());
} else if (dt instanceof ArrayType) {
} else if (pdt instanceof PhysicalStructType) {
row.update(i, getStruct(i, ((PhysicalStructType) pdt).fields().length).copy());
} else if (pdt instanceof PhysicalArrayType) {
row.update(i, getArray(i).copy());
} else if (dt instanceof MapType) {
} else if (pdt instanceof PhysicalMapType) {
row.update(i, getMap(i).copy());
} else {
throw new RuntimeException("Not implemented. " + dt);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
package org.apache.spark.sql.catalyst

import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.types._
import org.apache.spark.sql.catalyst.util.{ArrayData, MapData}
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}
Expand Down Expand Up @@ -129,24 +130,25 @@ object InternalRow {
*/
def getAccessor(dt: DataType, nullable: Boolean = true): (SpecializedGetters, Int) => Any = {
val getValueNullSafe: (SpecializedGetters, Int) => Any = dt match {
desmondcheongzx marked this conversation as resolved.
Show resolved Hide resolved
case BooleanType => (input, ordinal) => input.getBoolean(ordinal)
case ByteType => (input, ordinal) => input.getByte(ordinal)
case ShortType => (input, ordinal) => input.getShort(ordinal)
case IntegerType | DateType | _: YearMonthIntervalType =>
(input, ordinal) => input.getInt(ordinal)
case LongType | TimestampType | TimestampNTZType | _: DayTimeIntervalType =>
(input, ordinal) => input.getLong(ordinal)
case FloatType => (input, ordinal) => input.getFloat(ordinal)
case DoubleType => (input, ordinal) => input.getDouble(ordinal)
case StringType => (input, ordinal) => input.getUTF8String(ordinal)
case BinaryType => (input, ordinal) => input.getBinary(ordinal)
case CalendarIntervalType => (input, ordinal) => input.getInterval(ordinal)
case t: DecimalType => (input, ordinal) => input.getDecimal(ordinal, t.precision, t.scale)
case t: StructType => (input, ordinal) => input.getStruct(ordinal, t.size)
case _: ArrayType => (input, ordinal) => input.getArray(ordinal)
case _: MapType => (input, ordinal) => input.getMap(ordinal)
case u: UserDefinedType[_] => getAccessor(u.sqlType, nullable)
case _ => (input, ordinal) => input.get(ordinal, dt)
case _ => dt.physicalDataType match {
case _: PhysicalBooleanType => (input, ordinal) => input.getBoolean(ordinal)
case _: PhysicalByteType => (input, ordinal) => input.getByte(ordinal)
case _: PhysicalShortType => (input, ordinal) => input.getShort(ordinal)
case _: PhysicalIntegerType => (input, ordinal) => input.getInt(ordinal)
case _: PhysicalLongType => (input, ordinal) => input.getLong(ordinal)
case _: PhysicalFloatType => (input, ordinal) => input.getFloat(ordinal)
case _: PhysicalDoubleType => (input, ordinal) => input.getDouble(ordinal)
case _: PhysicalStringType => (input, ordinal) => input.getUTF8String(ordinal)
case _: PhysicalBinaryType => (input, ordinal) => input.getBinary(ordinal)
case _: PhysicalCalendarIntervalType => (input, ordinal) => input.getInterval(ordinal)
case t: PhysicalDecimalType => (input, ordinal) =>
input.getDecimal(ordinal, t.precision, t.scale)
case t: PhysicalStructType => (input, ordinal) => input.getStruct(ordinal, t.fields.size)
case _: PhysicalArrayType => (input, ordinal) => input.getArray(ordinal)
case _: PhysicalMapType => (input, ordinal) => input.getMap(ordinal)
case _ => (input, ordinal) => input.get(ordinal, dt)
}
}

if (nullable) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ import org.apache.spark.sql.catalyst.SerializerBuildHelper._
import org.apache.spark.sql.catalyst.analysis.GetColumnByOrdinal
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.objects._
import org.apache.spark.sql.catalyst.types._
import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData}
import org.apache.spark.sql.errors.QueryExecutionErrors
import org.apache.spark.sql.internal.SQLConf
Expand Down Expand Up @@ -214,6 +215,8 @@ object RowEncoder {
} else {
nonNullOutput
}
// For other data types, return the internal catalyst value as it is.
case _ => inputObject
desmondcheongzx marked this conversation as resolved.
Show resolved Hide resolved
}

/**
Expand Down Expand Up @@ -253,13 +256,17 @@ object RowEncoder {
}
case _: DayTimeIntervalType => ObjectType(classOf[java.time.Duration])
case _: YearMonthIntervalType => ObjectType(classOf[java.time.Period])
case _: DecimalType => ObjectType(classOf[java.math.BigDecimal])
case StringType => ObjectType(classOf[java.lang.String])
case _: ArrayType => ObjectType(classOf[scala.collection.Seq[_]])
case _: MapType => ObjectType(classOf[scala.collection.Map[_, _]])
case _: StructType => ObjectType(classOf[Row])
case p: PythonUserDefinedType => externalDataTypeFor(p.sqlType)
case udt: UserDefinedType[_] => ObjectType(udt.userClass)
case _ => dt.physicalDataType match {
case _: PhysicalArrayType => ObjectType(classOf[scala.collection.Seq[_]])
case _: PhysicalDecimalType => ObjectType(classOf[java.math.BigDecimal])
case _: PhysicalMapType => ObjectType(classOf[scala.collection.Map[_, _]])
case _: PhysicalStringType => ObjectType(classOf[java.lang.String])
case _: PhysicalStructType => ObjectType(classOf[Row])
// For other data types, return the data type as it is.
case _ => dt
desmondcheongzx marked this conversation as resolved.
Show resolved Hide resolved
}
}

private def deserializerFor(input: Expression, schema: StructType): Expression = {
Expand Down Expand Up @@ -358,6 +365,9 @@ object RowEncoder {
If(IsNull(input),
Literal.create(null, externalDataTypeFor(input.dataType)),
CreateExternalRow(convertedFields, schema))

// For other data types, return the internal catalyst value as it is.
case _ => input
desmondcheongzx marked this conversation as resolved.
Show resolved Hide resolved
}

private def expressionForNullableExpr(
Expand Down
Loading