From 632b5dc202cce7ec7e2826018e06d120a4cd33d1 Mon Sep 17 00:00:00 2001 From: Herman van Hovell Date: Sat, 2 Jan 2016 12:34:05 +0100 Subject: [PATCH] Prevent leaking state from internal/external row. --- .../spark/sql/catalyst/expressions/rows.scala | 8 ++--- .../scala/org/apache/spark/sql/RowTest.scala | 30 +++++++++++++++++++ 2 files changed, 34 insertions(+), 4 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala index cfc68fc00bea8..814b3c22f8806 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala @@ -199,9 +199,9 @@ class GenericRow(protected[sql] val values: Array[Any]) extends Row { override def get(i: Int): Any = values(i) - override def toSeq: Seq[Any] = values.toSeq + override def toSeq: Seq[Any] = values.clone() - override def copy(): Row = this + override def copy(): GenericRow = this } class GenericRowWithSchema(values: Array[Any], override val schema: StructType) @@ -226,11 +226,11 @@ class GenericInternalRow(private[sql] val values: Array[Any]) extends BaseGeneri override protected def genericGet(ordinal: Int) = values(ordinal) - override def toSeq(fieldTypes: Seq[DataType]): Seq[Any] = values + override def toSeq(fieldTypes: Seq[DataType]): Seq[Any] = values.clone() override def numFields: Int = values.length - override def copy(): InternalRow = new GenericInternalRow(values.clone()) + override def copy(): GenericInternalRow = this } /** diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/RowTest.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/RowTest.scala index 5c22a72192541..61361ac17547d 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/RowTest.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/RowTest.scala @@ -104,4 +104,34 @@ class RowTest extends FunSpec with Matchers { internalRow shouldEqual internalRow2 } } + + describe("row immutability") { + val values = Array(1, 2, "3", "IV", 6L) + val externalRow = Row(values.clone(): _*) + val internalRow = InternalRow(values.clone(): _*) + + def modifyValues(values: Seq[Any]): Seq[Any] = { + val array = values.toArray + array(2) = "42" + array + } + + it("copy should return same ref for external rows") { + externalRow should be theSameInstanceAs externalRow.copy() + } + + it("copy should return same ref for interal rows") { + internalRow should be theSameInstanceAs internalRow.copy() + } + + it("toSeq should not expose internal state for external rows") { + val modifiedValues = modifyValues(externalRow.toSeq) + externalRow.toSeq should not equal modifiedValues + } + + it("toSeq should not expose internal state for internal rows") { + val modifiedValues = modifyValues(internalRow.toSeq(Seq.empty)) + internalRow.toSeq(Seq.empty) should not equal modifiedValues + } + } }