Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ARROW] Fix Spark session timezone format in arrow-based result format #4326

Closed
wants to merge 10 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 17 additions & 16 deletions docs/deployment/settings.md

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -162,17 +162,12 @@ class ExecuteStatement(
}
}

// TODO:(fchen) make this configurable
val kyuubiBeelineConvertToString = true

def convertComplexType(df: DataFrame): DataFrame = {
if (kyuubiBeelineConvertToString) {
SparkDatasetHelper.convertTopLevelComplexTypeToHiveString(df)
} else {
df
}
SparkDatasetHelper.convertTopLevelComplexTypeToHiveString(df, timestampAsString)
}

override def getResultSetMetadataHints(): Seq[String] =
Seq(s"__kyuubi_operation_result_format__=$resultFormat")
Seq(
s"__kyuubi_operation_result_format__=$resultFormat",
s"__kyuubi_operation_result_arrow_timestampAsString__=$timestampAsString")
}
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ import org.apache.hive.service.rpc.thrift.{TGetResultSetMetadataResp, TProgressU
import org.apache.spark.kyuubi.{SparkProgressMonitor, SQLOperationListener}
import org.apache.spark.kyuubi.SparkUtilsHelper.redact
import org.apache.spark.sql.{DataFrame, Row, SparkSession}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.execution.SQLExecution
import org.apache.spark.sql.types.StructType

import org.apache.kyuubi.{KyuubiSQLException, Utils}
Expand Down Expand Up @@ -136,7 +136,7 @@ abstract class SparkOperation(session: Session)
spark.sparkContext.setLocalProperty

protected def withLocalProperties[T](f: => T): T = {
SQLConf.withExistingConf(spark.sessionState.conf) {
SQLExecution.withSQLConfPropagated(spark) {
val originalSession = SparkSession.getActiveSession
try {
SparkSession.setActiveSession(spark)
Expand Down Expand Up @@ -279,6 +279,10 @@ abstract class SparkOperation(session: Session)
spark.conf.get("kyuubi.operation.result.format", "thrift")
}

protected def timestampAsString: Boolean = {
spark.conf.get("kyuubi.operation.result.arrow.timestampAsString", "false").toBoolean
}

protected def setSessionUserSign(): Unit = {
(
session.conf.get(KYUUBI_SESSION_SIGN_PUBLICKEY),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,6 @@

package org.apache.spark.sql.kyuubi

import java.time.ZoneId

import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, Dataset, Row}
import org.apache.spark.sql.functions._
Expand All @@ -31,12 +29,13 @@ object SparkDatasetHelper {
ds.toArrowBatchRdd
}

def convertTopLevelComplexTypeToHiveString(df: DataFrame): DataFrame = {
val timeZone = ZoneId.of(df.sparkSession.sessionState.conf.sessionLocalTimeZone)
def convertTopLevelComplexTypeToHiveString(
df: DataFrame,
timestampAsString: Boolean): DataFrame = {

val quotedCol = (name: String) => col(quoteIfNeeded(name))

// an udf to call `RowSet.toHiveString` on complex types(struct/array/map).
// an udf to call `RowSet.toHiveString` on complex types(struct/array/map) and timestamp type.
val toHiveStringUDF = udf[String, Row, String]((row, schemaDDL) => {
val dt = DataType.fromDDL(schemaDDL)
dt match {
Expand All @@ -46,6 +45,8 @@ object SparkDatasetHelper {
RowSet.toHiveString((row.toSeq.head, at), nested = true)
case StructType(Array(StructField(_, mt: MapType, _, _))) =>
RowSet.toHiveString((row.toSeq.head, mt), nested = true)
case StructType(Array(StructField(_, tt: TimestampType, _, _))) =>
RowSet.toHiveString((row.toSeq.head, tt), nested = true)
case _ =>
throw new UnsupportedOperationException
}
Expand All @@ -56,6 +57,8 @@ object SparkDatasetHelper {
toHiveStringUDF(quotedCol(name), lit(sf.toDDL)).as(name)
case sf @ StructField(name, _: MapType | _: ArrayType, _, _) =>
toHiveStringUDF(struct(quotedCol(name)), lit(sf.toDDL)).as(name)
case sf @ StructField(name, _: TimestampType, _, _) if timestampAsString =>
toHiveStringUDF(struct(quotedCol(name)), lit(sf.toDDL)).as(name)
case StructField(name, _, _, _) => quotedCol(name)
}
df.select(cols: _*)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,13 @@ class SparkArrowbasedOperationSuite extends WithSparkSQLEngine with SparkDataTyp

override def resultFormat: String = "arrow"

override def beforeEach(): Unit = {
super.beforeEach()
withJdbcStatement() { statement =>
checkResultSetFormat(statement, "arrow")
}
}

test("detect resultSet format") {
withJdbcStatement() { statement =>
checkResultSetFormat(statement, "arrow")
Expand All @@ -43,7 +50,42 @@ class SparkArrowbasedOperationSuite extends WithSparkSQLEngine with SparkDataTyp
}
}

def checkResultSetFormat(statement: Statement, expectFormat: String): Unit = {
test("Spark session timezone format") {
withJdbcStatement() { statement =>
def check(expect: String): Unit = {
val query =
"""
|SELECT
| from_utc_timestamp(
| from_unixtime(
| 1670404535000 / 1000, 'yyyy-MM-dd HH:mm:ss'
| ),
| 'GMT+08:00'
| )
|""".stripMargin
val resultSet = statement.executeQuery(query)
assert(resultSet.next())
assert(resultSet.getString(1) == expect)
}

def setTimeZone(timeZone: String): Unit = {
val rs = statement.executeQuery(s"set spark.sql.session.timeZone=$timeZone")
assert(rs.next())
}

Seq("true", "false").foreach { timestampAsString =>
statement.executeQuery(
s"set ${KyuubiConf.ARROW_BASED_ROWSET_TIMESTAMP_AS_STRING.key}=$timestampAsString")
checkArrowBasedRowSetTimestampAsString(statement, timestampAsString)
setTimeZone("UTC")
check("2022-12-07 17:15:35.0")
setTimeZone("GMT+8")
check("2022-12-08 01:15:35.0")
}
}
}

private def checkResultSetFormat(statement: Statement, expectFormat: String): Unit = {
val query =
s"""
|SELECT '$${hivevar:${KyuubiConf.OPERATION_RESULT_FORMAT.key}}' AS col
Expand All @@ -52,4 +94,16 @@ class SparkArrowbasedOperationSuite extends WithSparkSQLEngine with SparkDataTyp
assert(resultSet.next())
assert(resultSet.getString("col") === expectFormat)
}

private def checkArrowBasedRowSetTimestampAsString(
statement: Statement,
expect: String): Unit = {
val query =
s"""
|SELECT '$${hivevar:${KyuubiConf.ARROW_BASED_ROWSET_TIMESTAMP_AS_STRING.key}}' AS col
|""".stripMargin
val resultSet = statement.executeQuery(query)
assert(resultSet.next())
assert(resultSet.getString("col") === expect)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -1676,6 +1676,14 @@ object KyuubiConf {
.transform(_.toLowerCase(Locale.ROOT))
.createWithDefault("thrift")

val ARROW_BASED_ROWSET_TIMESTAMP_AS_STRING: ConfigEntry[Boolean] =
buildConf("kyuubi.operation.result.arrow.timestampAsString")
.doc("When true, arrow-based rowsets will convert columns of type timestamp to strings for" +
" transmission.")
.version("1.7.0")
.booleanConf
.createWithDefault(false)

val SERVER_OPERATION_LOG_DIR_ROOT: ConfigEntry[String] =
buildConf("kyuubi.operation.log.dir.root")
.doc("Root directory for query operation log at server-side.")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
public class JdbcColumnAttributes {
public int precision = 0;
public int scale = 0;
public String timeZone = "";
public String timeZone = null;

public JdbcColumnAttributes() {}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ public abstract class KyuubiArrowBasedResultSet implements SQLResultSet {
protected Schema arrowSchema;
protected VectorSchemaRoot root;
protected ArrowColumnarBatchRow row;
protected boolean timestampAsString = true;

protected BufferAllocator allocator;

Expand Down Expand Up @@ -312,11 +313,18 @@ private Object getColumnValue(int columnIndex) throws SQLException {
if (wasNull) {
return null;
} else {
return row.get(columnIndex - 1, columnType);
JdbcColumnAttributes attributes = columnAttributes.get(columnIndex - 1);
return row.get(
columnIndex - 1,
columnType,
attributes == null ? null : attributes.timeZone,
timestampAsString);
}
} catch (Exception e) {
e.printStackTrace();
throw new KyuubiSQLException("Unrecognized column type:", e);
throw new KyuubiSQLException(
String.format(
"Error getting row of type %s at column index %d", columnType, columnIndex - 1),
e);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,9 +58,6 @@ public class KyuubiArrowQueryResultSet extends KyuubiArrowBasedResultSet {
private boolean isScrollable = false;
private boolean fetchFirst = false;

// TODO:(fchen) make this configurable
protected boolean convertComplexTypeToString = true;

private final TProtocolVersion protocol;

public static class Builder {
Expand All @@ -87,6 +84,8 @@ public static class Builder {
private boolean isScrollable = false;
private ReentrantLock transportLock = null;

private boolean timestampAsString = true;

public Builder(Statement statement) throws SQLException {
this.statement = statement;
this.connection = statement.getConnection();
Expand Down Expand Up @@ -153,6 +152,11 @@ public Builder setScrollable(boolean setScrollable) {
return this;
}

public Builder setTimestampAsString(boolean timestampAsString) {
this.timestampAsString = timestampAsString;
return this;
}

public Builder setTransportLock(ReentrantLock transportLock) {
this.transportLock = transportLock;
return this;
Expand Down Expand Up @@ -189,10 +193,10 @@ protected KyuubiArrowQueryResultSet(Builder builder) throws SQLException {
this.maxRows = builder.maxRows;
}
this.isScrollable = builder.isScrollable;
this.timestampAsString = builder.timestampAsString;
this.protocol = builder.getProtocolVersion();
arrowSchema =
ArrowUtils.toArrowSchema(
columnNames, convertComplexTypeToStringType(columnTypes), columnAttributes);
ArrowUtils.toArrowSchema(columnNames, convertToStringType(columnTypes), columnAttributes);
if (allocator == null) {
initArrowSchemaAndAllocator();
}
Expand Down Expand Up @@ -271,8 +275,7 @@ private void retrieveSchema() throws SQLException {
columnAttributes.add(getColumnAttributes(primitiveTypeEntry));
}
arrowSchema =
ArrowUtils.toArrowSchema(
columnNames, convertComplexTypeToStringType(columnTypes), columnAttributes);
ArrowUtils.toArrowSchema(columnNames, convertToStringType(columnTypes), columnAttributes);
} catch (SQLException eS) {
throw eS; // rethrow the SQLException as is
} catch (Exception ex) {
Expand Down Expand Up @@ -480,22 +483,25 @@ public boolean isClosed() {
return isClosed;
}

private List<TTypeId> convertComplexTypeToStringType(List<TTypeId> colTypes) {
if (convertComplexTypeToString) {
return colTypes.stream()
.map(
type -> {
if (type == TTypeId.ARRAY_TYPE
|| type == TTypeId.MAP_TYPE
|| type == TTypeId.STRUCT_TYPE) {
return TTypeId.STRING_TYPE;
} else {
return type;
}
})
.collect(Collectors.toList());
} else {
return colTypes;
}
/**
* 1. the complex types (map/array/struct) are always converted to string type to transport 2. if
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

let's move 2. *** to new line

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The spotless plugin makes this change, otherwise, we wouldn't pass the style check :(

Copy link
Member

@pan3793 pan3793 Feb 17, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

All right, leave it then

* the user set `timestampAsString = true`, then the timestamp type will be converted to string
* type too.
*/
private List<TTypeId> convertToStringType(List<TTypeId> colTypes) {
return colTypes.stream()
.map(
type -> {
if ((type == TTypeId.ARRAY_TYPE
|| type == TTypeId.MAP_TYPE
|| type == TTypeId.STRUCT_TYPE) // complex type (map/array/struct)
// timestamp type
|| (type == TTypeId.TIMESTAMP_TYPE && timestampAsString)) {
return TTypeId.STRING_TYPE;
} else {
return type;
}
})
.collect(Collectors.toList());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ public class KyuubiStatement implements SQLStatement, KyuubiLoggable {
public static final Logger LOG = LoggerFactory.getLogger(KyuubiStatement.class.getName());
public static final int DEFAULT_FETCH_SIZE = 1000;
public static final String DEFAULT_RESULT_FORMAT = "thrift";
public static final String DEFAULT_ARROW_TIMESTAMP_AS_STRING = "false";
private final KyuubiConnection connection;
private TCLIService.Iface client;
private TOperationHandle stmtHandle = null;
Expand All @@ -45,7 +46,8 @@ public class KyuubiStatement implements SQLStatement, KyuubiLoggable {
private int fetchSize = DEFAULT_FETCH_SIZE;
private boolean isScrollableResultset = false;
private boolean isOperationComplete = false;
private Map<String, String> properties = new HashMap<>();

private Map<String, String> properties = new TreeMap<>(String.CASE_INSENSITIVE_ORDER);
/**
* We need to keep a reference to the result set to support the following: <code>
* statement.execute(String sql);
Expand Down Expand Up @@ -213,6 +215,11 @@ private boolean executeWithConfOverlay(String sql, Map<String, String> confOverl
LOG.info("kyuubi.operation.result.format: " + resultFormat);
switch (resultFormat) {
case "arrow":
boolean timestampAsString =
Boolean.parseBoolean(
properties.getOrDefault(
"__kyuubi_operation_result_arrow_timestampAsString__",
DEFAULT_ARROW_TIMESTAMP_AS_STRING));
resultSet =
new KyuubiArrowQueryResultSet.Builder(this)
.setClient(client)
Expand All @@ -222,6 +229,7 @@ private boolean executeWithConfOverlay(String sql, Map<String, String> confOverl
.setFetchSize(fetchSize)
.setScrollable(isScrollableResultset)
.setSchema(columnNames, columnTypes, columnAttributes)
.setTimestampAsString(timestampAsString)
.build();
break;
default:
Expand Down Expand Up @@ -270,6 +278,11 @@ public boolean executeAsync(String sql) throws SQLException {
LOG.info("kyuubi.operation.result.format: " + resultFormat);
switch (resultFormat) {
case "arrow":
boolean timestampAsString =
Boolean.parseBoolean(
properties.getOrDefault(
"__kyuubi_operation_result_arrow_timestampAsString__",
DEFAULT_ARROW_TIMESTAMP_AS_STRING));
resultSet =
new KyuubiArrowQueryResultSet.Builder(this)
.setClient(client)
Expand All @@ -279,6 +292,7 @@ public boolean executeAsync(String sql) throws SQLException {
.setFetchSize(fetchSize)
.setScrollable(isScrollableResultset)
.setSchema(columnNames, columnTypes, columnAttributes)
.setTimestampAsString(timestampAsString)
.build();
break;
default:
Expand Down
Loading