Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -246,6 +246,10 @@ object CachedColumnarBatchKryoSerializer {
val STATS_FRAMED_MAGIC: Array[Byte] =
  // Frame prefix bytes 0xFE 0xCA 0x53 0x02, written as Int literals and narrowed to Byte.
  Array(0xfe, 0xca, 0x53, 0x02).map(_.toByte)

// V3 frame prefix: first three bytes equal those of STATS_FRAMED_MAGIC, fourth byte is 0x03.
val STATS_FRAMED_MAGIC_V3: Array[Byte] =
  Array(0xfe, 0xca, 0x53, 0x03).map(_.toByte)

// Per-column statsBlob layout (LE throughout, matches the cpp emitter in
// VeloxColumnarBatchSerializer.cc):
//
Expand Down Expand Up @@ -605,23 +609,36 @@ object CachedColumnarBatchKryoSerializer {
}

/**
* Parse the JNI `serializeWithStats` framed return into (stats InternalRow, bytesBlob).
*
* Framed layout (matches cpp VeloxColumnarBatchSerializer.cc): `[ STATS_FRAMED_MAGIC: 4B ] [
* statsLen: u32 LE ] [ statsBlob ] [ bytesLen: u32 LE ] [ bytesBlob ]`.
* Parse the JNI `serializeWithStats` framed return into (stats InternalRow, bytesBlob). Routes on
* magic byte[3]: 0x02 -> V2, 0x03 -> V3.
*
* Eager guards catch corrupt magic / truncated framing before they propagate.
* V2 layout: `[ magic: 4B ] [ statsLen: u32 LE ] [ statsBlob ] [ bytesLen: u32 LE ] [ bytesBlob
* ]` V3 layout: `[ magic: 4B ] [ statsLen: u32 LE ] [ statsBlob ] [ numRows: u32 LE ] [ numCols:
* u32 LE ] [ per-col ]`
*/
// Entry point for framed bytes: checks the minimum length, then hands the frame to the V2 or V3
// parser according to the fourth byte. Any other fourth byte raises IllegalArgumentException.
private[execution] def parseFramedBytes(
framed: Array[Byte],
schema: StructType): (InternalRow, Array[Byte]) = {
// V2 minimum = 4+4+4=12B; V3 minimum = 4+4+4+4=16B; use 12 for dispatcher guard.
// NOTE(review): the next two condition lines state the same guard twice (4 + 4 + 4 and 12);
// require takes a single condition, so confirm which form the file is meant to keep.
require(
framed != null && framed.length >= 4 + 4 + 4,
framed != null && framed.length >= 12,
s"framed bytes too short: len=${if (framed == null) -1 else framed.length}")
// `& 0xff` turns the signed byte into its unsigned value so it can be matched against 0x02/0x03.
// Only byte 3 is inspected here; bytes 0..2 are not checked by the dispatcher itself.
val magicVersion = framed(3) & 0xff
magicVersion match {
case 0x02 => parseV2Frame(framed, schema)
case 0x03 => parseV3Frame(framed, schema)
case other =>
throw new IllegalArgumentException(
f"framed bytes magic version 0x$other%02X unknown; expected 0x02(V2) or 0x03(V3)")
}
}

/** V2 parse: extract stats + pure Presto bytesBlob. */
private def parseV2Frame(framed: Array[Byte], schema: StructType): (InternalRow, Array[Byte]) = {
require(
framed(0) == STATS_FRAMED_MAGIC(0) && framed(1) == STATS_FRAMED_MAGIC(1) &&
framed(2) == STATS_FRAMED_MAGIC(2) && framed(3) == STATS_FRAMED_MAGIC(3),
f"framed bytes magic mismatch: expected " +
f"V2 framed bytes magic mismatch: expected " +
f"0x${STATS_FRAMED_MAGIC(0) & 0xff}%02X${STATS_FRAMED_MAGIC(1) & 0xff}%02X" +
f"${STATS_FRAMED_MAGIC(2) & 0xff}%02X${STATS_FRAMED_MAGIC(3) & 0xff}%02X, got " +
f"0x${framed(0) & 0xff}%02X${framed(1) & 0xff}%02X" +
Expand All @@ -632,18 +649,37 @@ object CachedColumnarBatchKryoSerializer {
val statsLen = buf.getInt
require(
statsLen >= 0 && statsLen <= buf.remaining() - 4,
s"framed bytes statsLen=$statsLen exceeds remaining buffer ${buf.remaining() - 4}")
s"V2 framed bytes statsLen=$statsLen exceeds remaining buffer ${buf.remaining() - 4}")
val statsBlob = new Array[Byte](statsLen)
buf.get(statsBlob)
val stats = deserializeStats(statsBlob, schema)
val bytesLen = buf.getInt
require(
bytesLen >= 0 && bytesLen == buf.remaining(),
s"framed bytes bytesLen=$bytesLen != remaining ${buf.remaining()} (truncated or trailing)")
s"V2 framed bytes bytesLen=$bytesLen != remaining ${buf.remaining()} (truncated or trailing)")
val bytesBlob = new Array[Byte](bytesLen)
buf.get(bytesBlob)
(stats, bytesBlob)
}

/**
 * V3 parse: decode the stats section and hand back the complete frame untouched.
 *
 * Unlike the V2 parser, no payload is sliced out: the second tuple element is the very `framed`
 * array that was passed in, magic included, so the native V3 deserializer can read it from
 * byte 0. The per-column section that follows numRows / numCols is not inspected here.
 *
 * The dispatcher routes on byte 3 alone, so all four magic bytes are compared against
 * STATS_FRAMED_MAGIC_V3 here (the same kind of check parseV2Frame makes for V2). Without it a
 * frame with a garbage prefix and a 0x03 fourth byte would be accepted on the write side yet
 * fail the four-byte `isV3Format` test on the read side.
 *
 * @param framed
 *   frame bytes; parseFramedBytes has already rejected null and anything shorter than 12 bytes
 * @param schema
 *   forwarded unchanged to deserializeStats
 * @return
 *   (decoded stats row, `framed` itself)
 * @throws IllegalArgumentException
 *   if the frame is shorter than 16 bytes, the magic does not match, or statsLen is negative or
 *   leaves fewer than 8 bytes for numRows + numCols
 */
private def parseV3Frame(framed: Array[Byte], schema: StructType): (InternalRow, Array[Byte]) = {
  require(framed.length >= 16, s"V3 framed bytes too short (min 16B): len=${framed.length}")

  val expectedMagic = STATS_FRAMED_MAGIC_V3
  def hex(bytes: Array[Byte]): String = bytes.map(b => f"${b & 0xff}%02X").mkString
  require(
    expectedMagic.indices.forall(i => framed(i) == expectedMagic(i)),
    s"V3 framed bytes magic mismatch: expected 0x${hex(expectedMagic)}, " +
      s"got 0x${hex(framed.take(expectedMagic.length))}"
  )

  val buf = ByteBuffer.wrap(framed).order(ByteOrder.LITTLE_ENDIAN)
  buf.position(4) // skip magic
  val statsLen = buf.getInt
  require(
    statsLen >= 0 && statsLen <= buf.remaining() - 8, // 8 = numRows(4)+numCols(4)
    s"V3 framed bytes statsLen=$statsLen invalid")
  val statsBlob = new Array[Byte](statsLen)
  buf.get(statsBlob)
  val stats = deserializeStats(statsBlob, schema)
  // The whole frame travels on as the payload; only the stats were needed on the JVM side.
  (stats, framed)
}
}

/**
Expand Down Expand Up @@ -750,8 +786,11 @@ class ColumnarCachedBatchSerializer extends SimpleMetricsCachedBatchSerializer
val structSchema = StructType(
schema.map(a => StructField(a.name, a.dataType, a.nullable)))
val backendName = BackendsApiManager.getBackendName
// Hoist partition-level configs: GlutenConfig.get allocates a fresh object on each call.
val partitionStatsEnabled =
GlutenConfig.get.getConf(GlutenConfig.COLUMNAR_TABLE_CACHE_PARTITION_STATS_ENABLED)
val lazyEnabled =
GlutenConfig.get.getConf(GlutenConfig.COLUMNAR_TABLE_CACHE_LAZY_DESERIALIZATION_ENABLED)
val jni = ColumnarBatchSerializerJniWrapper.create(
Runtimes.contextInstance(
backendName,
Expand All @@ -778,7 +817,25 @@ class ColumnarCachedBatchSerializer extends SimpleMetricsCachedBatchSerializer
// UnsatisfiedLinkError on the first invocation; we catch it once, cache the
// result, and fall back to the legacy serialize() path emitting stats=null. The
// buildFilter wrapper directs such batches through without pruning.
if (partitionStatsEnabled && ColumnarCachedBatchSerializer.statsExtAvailable) {
if (lazyEnabled && ColumnarCachedBatchSerializer.statsExtV3Available) {
// V3 path: per-column serialization + stats.
ColumnarCachedBatchSerializer.serializeOneBatchWithStatsV3(
jni,
handle,
batch.numRows(),
structSchema,
() =>
if (partitionStatsEnabled && ColumnarCachedBatchSerializer.statsExtAvailable) {
ColumnarCachedBatchSerializer.serializeOneBatchWithStats(
jni,
handle,
batch.numRows(),
structSchema,
() => legacySerializeInline())
} else legacySerializeInline()
)
} else if (partitionStatsEnabled && ColumnarCachedBatchSerializer.statsExtAvailable) {
// V2 stats path.
ColumnarCachedBatchSerializer.serializeOneBatchWithStats(
jni,
handle,
Expand Down Expand Up @@ -812,6 +869,8 @@ class ColumnarCachedBatchSerializer extends SimpleMetricsCachedBatchSerializer
a => cacheAttributes.map(_.exprId).indexOf(a.exprId)
}
val shouldSelectAttributes = cacheAttributes != selectedAttributes
val lazyEnabled =
GlutenConfig.get.getConf(GlutenConfig.COLUMNAR_TABLE_CACHE_LAZY_DESERIALIZATION_ENABLED)
val localSchema = toStructType(cacheAttributes)
val timezoneId = SQLConf.get.sessionLocalTimeZone
input.mapPartitions {
Expand All @@ -835,21 +894,37 @@ class ColumnarCachedBatchSerializer extends SimpleMetricsCachedBatchSerializer

// Takes the next cached batch from `it` and turns its bytes back into a ColumnarBatch.
// NOTE(review): this span comes from a diff view, so the pre-change body (the first
// deserialize/select sequence and the lone `batch` right after `} else {`) is shown alongside
// the new V3/V2 branching — confirm which lines are live before editing.
override def next(): ColumnarBatch = {
val cachedBatch = it.next().asInstanceOf[CachedColumnarBatch]
val batchHandle =
jniWrapper
.deserialize(deserializerHandle, cachedBatch.bytes)
val batch = ColumnarBatches.create(batchHandle)
if (shouldSelectAttributes) {
try {
ColumnarBatches.select(
BackendsApiManager.getBackendName,
batch,
requestedColumnIndices.toArray)
} finally {
batch.close()
}
// V3 bytes are ALWAYS routed to deserializeWithProjection.
// V3 framed bytes must NOT go to jni.deserialize() (expects Presto format).
if (isV3Format(cachedBatch.bytes)) {
// Column projection is always M-column, regardless of lazyEnabled.
// lazyEnabled controls WHEN columns are loaded (lazy vs eager), not HOW MANY.
// NOTE(review): reqIndices reads only cacheAttributes, selectedAttributes and
// requestedColumnIndices, none of which come from cachedBatch — presumably it could be
// computed once outside next(); confirm.
val reqIndices: Array[Int] =
if (cacheAttributes == selectedAttributes) null // all cols: C++ loadAll
else if (requestedColumnIndices.isEmpty) Array.empty[Int] // count(*): 0 cols
else requestedColumnIndices.toArray // projection: M cols
val batchHandle = jniWrapper.deserializeWithProjection(
deserializerHandle,
cachedBatch.bytes,
reqIndices)
ColumnarBatches.create(batchHandle)
// No ColumnarBatches.select(): C++ returns M-column batch.
} else {
batch
// V2 path (original logic).
val batchHandle = jniWrapper.deserialize(deserializerHandle, cachedBatch.bytes)
val batch = ColumnarBatches.create(batchHandle)
if (shouldSelectAttributes) {
// The full batch is closed in `finally` whether or not select() succeeds; the value of
// the try expression (the selected batch) is what this branch yields.
try {
ColumnarBatches.select(
BackendsApiManager.getBackendName,
batch,
requestedColumnIndices.toArray)
} finally {
batch.close()
}
} else {
batch
}
}
}
})
Expand Down Expand Up @@ -898,6 +973,12 @@ class ColumnarCachedBatchSerializer extends SimpleMetricsCachedBatchSerializer
}
}

/**
 * Reports whether `bytes` begins with the four-byte V3 prefix 0xFE 0xCA 0x53 0x03. A null array
 * or one shorter than four bytes yields false.
 */
private def isV3Format(bytes: Array[Byte]): Boolean = {
  if (bytes == null || bytes.length < 4) {
    false
  } else {
    // Compare unsigned values so the Int literals below line up with the signed bytes.
    val expectedPrefix = Array(0xfe, 0xca, 0x53, 0x03)
    expectedPrefix.indices.forall(i => (bytes(i) & 0xff) == expectedPrefix(i))
  }
}

override def buildFilter(
predicates: Seq[Expression],
cachedAttributes: Seq[Attribute])
Expand Down Expand Up @@ -1029,4 +1110,58 @@ object ColumnarCachedBatchSerializer extends Logging {
)
}
}

// Test hook: puts the capability flag back to `true` so a unit test can run the probe-once
// logic again from a clean state.
private[execution] def resetStatsExtAvailableForTesting(): Unit =
  statsExtAvailableFlag = true

// V3 lazy deserialization support

// Latch for the V3 JNI symbol (framedSerializeWithStatsV3), held in its own flag. It starts out
// true; markStatsExtV3Unavailable below sets it to false.
@volatile private var statsExtV3AvailableFlag: Boolean = true

/** Current value of the V3 capability latch. */
def statsExtV3Available: Boolean = statsExtV3AvailableFlag

/**
 * Clears the V3 latch and logs a warning carrying `cause`. Does nothing when the latch has
 * already been cleared.
 */
def markStatsExtV3Unavailable(cause: Throwable): Unit = {
  val stillAvailable = statsExtV3AvailableFlag
  if (stillAvailable) {
    statsExtV3AvailableFlag = false
    val warning =
      "serializeWithStatsV3 JNI returned null (backend not supported or symbol missing); " +
        "disabling V3 per-column lazy deserialization for this JVM. " +
        "This typically indicates a Gluten jar / native library version mismatch."
    logWarning(warning, cause)
  }
}

/**
 * Serialize one batch through the V3 JNI entry point, falling back when V3 cannot be used.
 *
 * Outcomes of the V3 attempt:
 *   - frame returned and parsed: a CachedColumnarBatch whose bytes are the full V3 frame (magic
 *     included) and whose stats come from parseFramedBytes;
 *   - JNI returned null: the V3 latch is cleared and the fallback is used;
 *   - UnsatisfiedLinkError: the V3 latch is cleared and the fallback is used;
 *   - any other non-fatal error: reported through warnCorruptStatsFrame and the fallback is used.
 *
 * `fallbackToV2` is invoked at most once and outside the try block. It used to run inside it on
 * the null-return path, so a failure of the fallback itself was caught, reported as a corrupt
 * stats frame and the fallback executed a second time. Its exceptions now propagate unchanged.
 *
 * @param jni
 *   serializer wrapper used for the native call
 * @param handle
 *   handle passed to serializeWithStatsV3
 * @param numRows
 *   row count recorded on the resulting CachedColumnarBatch
 * @param structSchema
 *   schema used to decode stats and stored on the resulting batch
 * @param fallbackToV2
 *   produces the batch when the V3 attempt yields nothing
 * @return
 *   the V3 batch, or whatever `fallbackToV2` returns
 */
private[execution] def serializeOneBatchWithStatsV3(
    jni: ColumnarBatchSerializerJniWrapper,
    handle: Long,
    numRows: Int,
    structSchema: StructType,
    fallbackToV2: () => CachedBatch): CachedBatch = {
  val v3Batch: Option[CachedBatch] =
    try {
      Option(jni.serializeWithStatsV3(handle)) match {
        case Some(framed) =>
          val (stats, _) =
            CachedColumnarBatchKryoSerializer.parseFramedBytes(framed, structSchema)
          // bytes = full V3 frame; the parsed payload element is the same array, so it is ignored.
          Some(CachedColumnarBatch(numRows, framed.length, framed, stats, structSchema))
        case None =>
          // A null frame is treated as "V3 not supported", not as a corrupt frame.
          markStatsExtV3Unavailable(
            new RuntimeException(
              "framedSerializeWithStatsV3 returned null (backend not supported)"))
          None
      }
    } catch {
      case e: UnsatisfiedLinkError =>
        markStatsExtV3Unavailable(e)
        None
      case NonFatal(e) =>
        warnCorruptStatsFrame(e) // count against shared corrupt-frame cap
        None
    }
  v3Batch.getOrElse(fallbackToV2())
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -509,4 +509,39 @@ class ColumnarCachedBatchE2ESuite
}
}
}

// V3 lazy deserialization smoke tests

test("V3 enabled: cache + equality filter produces correct result") {
  val lazyDeserializationKey = GlutenConfig.COLUMNAR_TABLE_CACHE_LAZY_DESERIALIZATION_ENABLED.key
  withSQLConf(lazyDeserializationKey -> "true") {
    val cached = cacheRange()
    try {
      // Run one action before the filtered query (presumably so the cache is populated first —
      // depends on what cacheRange() returns).
      cached.count()
      val matching = cached.filter(col("k") === pivot)
      val result = matching.count()
      assert(result == 1L, s"V3: expected 1 row matching k=$pivot, got $result")
    } finally {
      cached.unpersist()
    }
  }
}

test("V3 enabled: multi-column cache, partial projection, no crash") {
  withSQLConf(GlutenConfig.COLUMNAR_TABLE_CACHE_LAZY_DESERIALIZATION_ENABLED.key -> "true") {
    // Three bigint columns derived from id: a = id, b = 2 * id, c = id + 1.
    val cached = spark
      .range(N)
      .selectExpr(
        "cast(id as bigint) as a",
        "cast(id*2 as bigint) as b",
        "cast(id+1 as bigint) as c")
      .repartitionByRange(P, col("a"))
      .cache()
    try {
      cached.count()
      // collect() instead of count(): count() needs no column values, so the optimizer is free
      // to drop `c` from the scan and the projected data would never be looked at. Collecting
      // the row makes the query read columns a and c (skipping b) and lets us check the values.
      val rows = cached.filter(col("a") === pivot).select("a", "c").collect()
      assert(rows.length == 1, s"V3 projection: expected 1 row, got ${rows.length}")
      val projected = rows.head
      assert(
        projected.getLong(0) == pivot,
        s"V3 projection: expected a=$pivot, got ${projected.getLong(0)}")
      assert(
        projected.getLong(1) == projected.getLong(0) + 1L,
        s"V3 projection: expected c=a+1, got a=${projected.getLong(0)} c=${projected.getLong(1)}")
    } finally {
      cached.unpersist()
    }
  }
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -409,4 +409,88 @@ class ColumnarCachedBatchFramedBytesSuite extends AnyFunSuite {
assert(statsBlob(4) === 0.toByte, "supported byte must be 0 (no-bounds branch)")
}

// V3 framing tests

/**
 * Assemble a V3 frame for tests. Written in order: the V3 magic, the statsBlob length and the
 * statsBlob, numRows, numCols, then for every entry of `colBytesLists` its length followed by
 * its bytes. All integers go through writeU32LE. numCols is written as given; it is not checked
 * against the number of entries in `colBytesLists`.
 */
private def craftV3Framed(
    statsBlob: Array[Byte],
    numRows: Int,
    numCols: Int,
    colBytesLists: List[Array[Byte]]): Array[Byte] = {
  val sink = new java.io.ByteArrayOutputStream()
  val v3Magic = Array[Byte](0xfe.toByte, 0xca.toByte, 0x53.toByte, 0x03.toByte)
  sink.write(v3Magic)
  // Stats section: length prefix, then payload.
  writeU32LE(sink, statsBlob.length)
  sink.write(statsBlob)
  // Batch shape.
  writeU32LE(sink, numRows)
  writeU32LE(sink, numCols)
  // Per-column sections, each length-prefixed.
  for (columnBytes <- colBytesLists) {
    writeU32LE(sink, columnBytes.length)
    sink.write(columnBytes)
  }
  sink.toByteArray
}

test("V3: parseFramedBytes routes magic 0x03 to parseV3Frame") {
  val stats: InternalRow = new GenericInternalRow(Array[Any](1L, 10L, 0, 5, 100L))
  // One column carrying two dummy bytes.
  val framed = craftV3Framed(
    CachedColumnarBatchKryoSerializer.serializeStats(stats, null),
    5,
    1,
    List(Array[Byte](0xab.toByte, 0xcd.toByte)))

  val parsed = CachedColumnarBatchKryoSerializer.parseFramedBytes(framed, null)
  val parsedStats = parsed._1
  val returnedBytes = parsed._2
  // The stats row must come back non-null.
  assert(parsedStats != null)
  // For V3 the payload handed back is the complete frame, magic included.
  assert(
    java.util.Arrays.equals(returnedBytes, framed),
    "V3: returned bytes must equal full frame")
}

test("V3: wrong magic version throws with clear message") {
  val badMagic = Array[Byte](0xfe.toByte, 0xca.toByte, 0x53.toByte, 0x05.toByte) // unknown 0x05
  // 4 magic bytes + 12 zero bytes = 16 bytes, enough to clear the dispatcher's 12-byte guard.
  val padded = badMagic ++ Array.fill(12)(0.toByte)
  val ex = intercept[IllegalArgumentException] {
    CachedColumnarBatchKryoSerializer.parseFramedBytes(padded, null)
  }
  val message = ex.getMessage
  // Both conditions are required. With the former `||` the check could never fail, because
  // every message on this path contains the word "magic"; the offending version byte is the
  // part that makes the message actionable.
  assert(
    message.contains("0x05"),
    s"expected offending version byte 0x05 in message, got: $message")
  assert(
    message.toLowerCase(Locale.ROOT).contains("magic"),
    s"expected version/magic info in message, got: $message")
}

test("V3: too-short frame (< 12 bytes) rejected by dispatcher") {
  // V3 magic followed by two zero bytes: six bytes in total.
  val shortV3 = Array(0xfe, 0xca, 0x53, 0x03, 0, 0).map(_.toByte)
  intercept[IllegalArgumentException] {
    CachedColumnarBatchKryoSerializer.parseFramedBytes(shortV3, null)
  }
}

test("V3: frame with truncated colLen claim is correctly formed at JVM layer") {
  // Build a V3 frame where numCols=1 and the column's length prefix claims more bytes than the
  // buffer actually holds. parseV3Frame only reads the header and the stats section; it does not
  // look at the per-column bytes, so it must accept this frame and hand it back unchanged for
  // the native side to validate.
  val stats: InternalRow = new GenericInternalRow(Array[Any](1L, 10L, 0, 5, 100L))
  val statsBlob = CachedColumnarBatchKryoSerializer.serializeStats(stats, null)
  val colBytes = Array[Byte](0xab.toByte, 0xcd.toByte)
  // Drop the final payload byte: the length prefix still says 2, but only 1 byte follows.
  // (Previously the frame was left well-formed, so the truncation was never exercised.)
  val framed = craftV3Framed(statsBlob, 5, 1, List(colBytes)).dropRight(1)
  val (parsedStats, returnedBytes) =
    CachedColumnarBatchKryoSerializer.parseFramedBytes(framed, null)
  assert(parsedStats != null)
  assert(java.util.Arrays.equals(returnedBytes, framed))
}

test("V3 + V2: V2 frames still parsed correctly after V3 magic added") {
  val payload = Array[Byte](10, 20, 30)
  val stats: InternalRow = new GenericInternalRow(Array[Any](5L, 50L, 0, 10, 200L))
  // craftFramed emits a frame with the V2 magic (0x02).
  val v2Framed = craftFramed(stats, payload)
  val parsed = CachedColumnarBatchKryoSerializer.parseFramedBytes(v2Framed, null)
  val parsedStats = parsed._1
  val bytesBlob = parsed._2
  assert(parsedStats != null)
  // For V2 the second element is just the payload, not the whole frame.
  assert(java.util.Arrays.equals(bytesBlob, payload), "V2 bytesBlob must be pure Presto bytes")
}
}
Loading
Loading