Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

package org.apache.spark.sql.execution.columnar

import org.apache.spark.TaskContext
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
Expand Down Expand Up @@ -167,11 +168,23 @@ case class InMemoryTableScanExec(
if (enableAccumulatorsForTest && iter.hasNext) {
readPartitions.add(1)
}
val inputMetrics = TaskContext.get().taskMetrics().inputMetrics
var rowCount = 0L
// RDD.getOrCompute increments inputMetrics.recordsRead by 1 per CachedBatch on
// cache hit, counting batches instead of rows. Register a completion listener that
// corrects the final count to the actual number of rows once all batches in this
// partition have been consumed. The math works for both cache-hit and cache-miss:
// - cache hit: getOrCompute added numBatches; we add (rowCount - numBatches)
// - cache miss: source scan already added rowCount; we add (rowCount - rowCount) = 0
TaskContext.get().addTaskCompletionListener[Unit] { _ =>
inputMetrics.incRecordsRead(rowCount - inputMetrics.recordsRead)
}
iter.map { batch =>
if (enableAccumulatorsForTest) {
readBatches.add(1)
}
numOutputRows += batch.numRows
rowCount += batch.numRows
batch
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import java.sql.{Date, Timestamp}
import java.util.concurrent.atomic.AtomicInteger

import org.apache.spark.rdd.RDD
import org.apache.spark.scheduler.{SparkListener, SparkListenerTaskEnd}
import org.apache.spark.sql.{QueryTest, Row}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, AttributeSet, In}
Expand Down Expand Up @@ -620,4 +621,29 @@ class InMemoryColumnarQuerySuite extends QueryTest

assert(exceptionCnt.get == 0)
}

test("SPARK-55523: inputMetrics.recordsRead should count rows not CachedBatches") {
  val numRows = 1000
  val df = spark.range(numRows).toDF()
  df.cache()
  df.count() // materialize the cache before attaching the listener

  // Sum of inputMetrics.recordsRead across all tasks of the cached read.
  // @volatile: onTaskEnd fires on the listener-bus thread while the assertion
  // below reads the total on the test thread; without it the captured local is
  // lifted to a plain (non-volatile) LongRef and the final read may be stale.
  @volatile var totalRecordsRead = 0L
  val listener = new SparkListener {
    override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = {
      totalRecordsRead += taskEnd.taskMetrics.inputMetrics.recordsRead
    }
  }
  sparkContext.addSparkListener(listener)
  try {
    df.count() // read from cache; should report rows, not CachedBatch count
    // Block until every queued task-end event has been delivered so the
    // assertion observes the metrics of all tasks (5s timeout).
    sparkContext.listenerBus.waitUntilEmpty(5000)
  } finally {
    // Detach the listener and drop the cache even if the read above fails,
    // so later tests are not polluted.
    sparkContext.removeSparkListener(listener)
    df.unpersist()
  }

  assert(totalRecordsRead == numRows,
    s"Expected inputMetrics.recordsRead=$numRows (actual rows), but got $totalRecordsRead")
}
}