From 11f1df3470d812f867c68e68ec2094cc08577f7d Mon Sep 17 00:00:00 2001
From: Davies Liu
Date: Tue, 14 Apr 2015 11:34:13 -0700
Subject: [PATCH 1/6] improve accessor for nested types

---
 python/pyspark/sql/dataframe.py               | 48 ++++++++++++++++---
 python/pyspark/sql/tests.py                   | 18 +++++++
 .../scala/org/apache/spark/sql/Column.scala   |  7 +--
 3 files changed, 64 insertions(+), 9 deletions(-)

diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py
index ef91a9c4f522..fd5109c3cf47 100644
--- a/python/pyspark/sql/dataframe.py
+++ b/python/pyspark/sql/dataframe.py
@@ -545,16 +545,23 @@ def __getitem__(self, item):
         [Row(name=u'Alice', age=2), Row(name=u'Bob', age=5)]
         >>> df[ df.age > 3 ].collect()
         [Row(age=5, name=u'Bob')]
+        >>> df[df[0] > 3].collect()
+        [Row(age=5, name=u'Bob')]
         """
         if isinstance(item, basestring):
+            if item not in self.columns:
+                raise IndexError("no such column: %s" % item)
             jc = self._jdf.apply(item)
             return Column(jc)
         elif isinstance(item, Column):
             return self.filter(item)
-        elif isinstance(item, list):
+        elif isinstance(item, (list, tuple)):
             return self.select(*item)
+        elif isinstance(item, int):
+            jc = self._jdf.apply(self.columns[item])
+            return Column(jc)
         else:
-            raise IndexError("unexpected index: %s" % item)
+            raise TypeError("unexpected type: %s" % type(item))
 
     def __getattr__(self, name):
         """Returns the :class:`Column` denoted by ``name``.
@@ -562,10 +569,11 @@ def __getattr__(self, name):
         >>> df.select(df.age).collect()
         [Row(age=2), Row(age=5)]
         """
-        if name.startswith("__"):
+        try:
+            jc = self._jdf.apply(name)
+            return Column(jc)
+        except Exception:
             raise AttributeError(name)
-        jc = self._jdf.apply(name)
-        return Column(jc)
 
     def select(self, *cols):
         """Projects a set of expressions and returns a new :class:`DataFrame`.
@@ -1067,7 +1075,35 @@ def __init__(self, jc):
     # container operators
     __contains__ = _bin_op("contains")
     __getitem__ = _bin_op("getItem")
-    getField = _bin_op("getField", "An expression that gets a field by name in a StructField.")
+    __getattr__ = _bin_op("getField")
+
+    def getItem(self, key):
+        """An expression that gets an item at position `ordinal` out of a list,
+        or gets an item by key out of a dict.
+
+        >>> df = sc.parallelize([([1, 2], {"key": "value"})]).toDF(["l", "d"])
+        >>> df.select(df.l.getItem(0), df.d.getItem("key")).show()
+        l[0] d[key]
+        1    value
+        >>> df.select(df.l[0], df.d["key"]).show()
+        l[0] d[key]
+        1    value
+        """
+        return self[key]
+
+    def getField(self, name):
+        """An expression that gets a field by name in a StructField.
+
+        >>> from pyspark.sql import Row
+        >>> df = sc.parallelize([Row(r=Row(a=1, b="b"))]).toDF()
+        >>> df.select(df.r.getField("b")).show()
+        r.b
+        b
+        >>> df.select(df.r.a).show()
+        r.a
+        1
+        """
+        return getattr(self, name)
 
     # string methods
     rlike = _bin_op("rlike")
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index b3a6a2c6a922..b740baed298c 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -426,6 +426,24 @@ def test_help_command(self):
         pydoc.render_doc(df.foo)
         pydoc.render_doc(df.take(1))
 
+    def test_access_column(self):
+        df = self.df
+        self.assertTrue(isinstance(df.key, Column))
+        self.assertTrue(isinstance(df['key'], Column))
+        self.assertTrue(isinstance(df[0], Column))
+        self.assertRaises(IndexError, lambda: df[2])
+        self.assertRaises(IndexError, lambda: df["bad_key"])
+        self.assertRaises(TypeError, lambda: df[{}])
+
+    def test_access_nested_types(self):
+        df = self.sc.parallelize([Row(l=[1], r=Row(a=1, b="b"), d={"k": "v"})]).toDF()
+        self.assertEqual(1, df.select(df.l[0]).first()[0])
+        self.assertEqual(1, df.select(df.l.getItem(0)).first()[0])
+        self.assertEqual(1, df.select(df.r.a).first()[0])
+        self.assertEqual("b", df.select(df.r.getField("b")).first()[0])
+        self.assertEqual("v", df.select(df.d["k"]).first()[0])
+        self.assertEqual("v", df.select(df.d.getItem("k")).first()[0])
+
     def test_infer_long_type(self):
         longrow = [Row(f1='a', f2=100000000000000)]
         df = self.sc.parallelize(longrow).toDF()
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
index 3cd7adf8cab5..edb229c059e6 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
@@ -515,14 +515,15 @@ class Column(protected[sql] val expr: Expression) extends Logging {
   def rlike(literal: String): Column = RLike(expr, lit(literal).expr)
 
   /**
-   * An expression that gets an item at position `ordinal` out of an array.
+   * An expression that gets an item at position `ordinal` out of an array,
+   * or gets a value by key `key` in a [[MapType]].
    *
    * @group expr_ops
    */
-  def getItem(ordinal: Int): Column = GetItem(expr, Literal(ordinal))
+  def getItem(key: Any): Column = GetItem(expr, Literal(key))
 
   /**
-   * An expression that gets a field by name in a [[StructField]].
+   * An expression that gets a field by name in a [[StructType]].
    *
    * @group expr_ops
    */
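Note: taken together, the dataframe.py hunks above give a PySpark DataFrame three
lookup styles (by name, by ordinal, by attribute) and route nested access through
getItem/getField. A minimal sketch of the resulting surface, assuming the doctest
globals used elsewhere in this file (a SparkContext `sc` and the two-row `df` with
columns ['age', 'name']); the `nested` frame is ad-hoc and not part of the patch:

    from pyspark.sql import Row

    df["age"]            # lookup by name, now validated against df.columns
    df[0]                # lookup by ordinal, via self.columns[item]
    df[["name", "age"]]  # a list/tuple of columns becomes a select()
    df[df.age > 3]       # a Column predicate becomes a filter()

    nested = sc.parallelize([Row(l=[1, 2], d={"key": "value"}, r=Row(a=1, b="b"))]).toDF()
    nested.select(nested.l[0], nested.d["key"], nested.r.b).first()
    # one row combining list, dict and struct access (getItem/getField)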
From 6c32e79b8b54aae5a8a8627eed6bdd949442f97a Mon Sep 17 00:00:00 2001
From: Davies Liu
Date: Wed, 15 Apr 2015 00:10:44 -0700
Subject: [PATCH 2/6] add scala tests

---
 .../test/scala/org/apache/spark/sql/DataFrameSuite.scala | 6 ++++++
 1 file changed, 6 insertions(+)

diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index f5df8c6a59f1..e9cc0acbe53d 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -86,6 +86,12 @@ class DataFrameSuite extends QueryTest {
     TestSQLContext.setConf(SQLConf.DATAFRAME_EAGER_ANALYSIS, oldSetting.toString)
   }
 
+  test("access complex data") {
+    assert(complexData.filter(complexData("a").getItem(0) === 2).count() == 1)
+    assert(complexData.filter(complexData("m").getItem(1) === "1").count() == 1)
+    assert(complexData.filter(complexData("s").getField("key") === 1).count() == 1)
+  }
+
   test("table scan") {
     checkAnswer(
       testData,
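Note: this Scala test drives the same expressions the Python accessors delegate to,
against the `complexData` fixture (still keyed `Map[Int, String]` at this point;
patch 4 below flips it to `Map[String, Int]`, which is why the map assertion changes
there). For symmetry, a PySpark sketch of the same three checks, using an ad-hoc
fixture with the shape the series settles on, not anything defined by the patch:

    from pyspark.sql import Row

    cd = sc.parallelize([
        Row(m={"1": 1}, s=Row(key=1, value="1"), a=[1], b=True),
        Row(m={"2": 2}, s=Row(key=2, value="2"), a=[2], b=False),
    ]).toDF()
    assert cd.filter(cd.a.getItem(0) == 2).count() == 1       # list by ordinal
    assert cd.filter(cd.m.getItem("1") == 1).count() == 1     # dict by key
    assert cd.filter(cd.s.getField("key") == 1).count() == 1  # struct by field name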
From 6b62540c1aaac73783fcc27103c70c8bd4f85e55 Mon Sep 17 00:00:00 2001
From: Davies Liu
Date: Wed, 15 Apr 2015 10:49:31 -0700
Subject: [PATCH 3/6] fix test

---
 python/pyspark/sql/dataframe.py | 8 ++++++--
 1 file changed, 6 insertions(+), 2 deletions(-)

diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py
index fd5109c3cf47..b7638ff2a6e3 100644
--- a/python/pyspark/sql/dataframe.py
+++ b/python/pyspark/sql/dataframe.py
@@ -1075,7 +1075,6 @@ def __init__(self, jc):
     # container operators
     __contains__ = _bin_op("contains")
     __getitem__ = _bin_op("getItem")
-    __getattr__ = _bin_op("getField")
 
     def getItem(self, key):
         """An expression that gets an item at position `ordinal` out of a list,
@@ -1103,7 +1102,12 @@ def getField(self, name):
         r.a
         1
         """
-        return getattr(self, name)
+        return Column(self._jc.getField(name))
+
+    def __getattr__(self, item):
+        if item.startswith("__"):
+            raise AttributeError(item)
+        return self.getField(item)
 
     # string methods
     rlike = _bin_op("rlike")
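Note: this patch drops the class-level `__getattr__ = _bin_op("getField")` in favour
of an explicit method with a dunder guard, and builds getField directly on the JVM
column instead of recursing through getattr. One plausible reading of the terse
"fix test" subject (the patch itself does not spell it out): anything that probes
optional hooks via getattr(), such as pickle looking for __getstate__ or the pydoc
introspection in test_help_command, must see AttributeError rather than a spurious
getField Column. A sketch, assuming the doctest `sc`:

    from pyspark.sql import Row

    df = sc.parallelize([Row(r=Row(a=1, b="b"))]).toDF()
    df.r.a                     # plain attribute -> Column(r.getField("a"))
    try:
        df.r.__getstate__      # dunder -> AttributeError, so protocols fall back
    except AttributeError:
        pass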
From d125ac46d353054ff430cf8fb6726def4e3a9b81 Mon Sep 17 00:00:00 2001
From: Davies Liu
Date: Wed, 15 Apr 2015 15:34:03 -0700
Subject: [PATCH 4/6] check column name, improve scala tests

---
 python/pyspark/sql/dataframe.py                           | 9 ++++-----
 .../test/scala/org/apache/spark/sql/DataFrameSuite.scala  | 2 +-
 .../src/test/scala/org/apache/spark/sql/TestData.scala    | 9 ++++-----
 3 files changed, 9 insertions(+), 11 deletions(-)

diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py
index b7638ff2a6e3..a9f036e93f0f 100644
--- a/python/pyspark/sql/dataframe.py
+++ b/python/pyspark/sql/dataframe.py
@@ -569,11 +569,10 @@ def __getattr__(self, name):
         >>> df.select(df.age).collect()
         [Row(age=2), Row(age=5)]
         """
-        try:
-            jc = self._jdf.apply(name)
-            return Column(jc)
-        except Exception:
-            raise AttributeError(name)
+        if name not in self.columns:
+            raise AttributeError("No such column: %s" % name)
+        jc = self._jdf.apply(name)
+        return Column(jc)
 
     def select(self, *cols):
         """Projects a set of expressions and returns a new :class:`DataFrame`.
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index 12297be57566..34b2cb054a3e 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -88,7 +88,7 @@ class DataFrameSuite extends QueryTest {
 
   test("access complex data") {
     assert(complexData.filter(complexData("a").getItem(0) === 2).count() == 1)
-    assert(complexData.filter(complexData("m").getItem(1) === "1").count() == 1)
+    assert(complexData.filter(complexData("m").getItem("1") === 1).count() == 1)
     assert(complexData.filter(complexData("s").getField("key") === 1).count() == 1)
   }
 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala b/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala
index 637f59b2e68c..225b51bd73d6 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala
@@ -20,9 +20,8 @@ package org.apache.spark.sql
 import java.sql.Timestamp
 
 import org.apache.spark.sql.catalyst.plans.logical
-import org.apache.spark.sql.functions._
-import org.apache.spark.sql.test._
 import org.apache.spark.sql.test.TestSQLContext.implicits._
+import org.apache.spark.sql.test._
 
 case class TestData(key: Int, value: String)
 
@@ -199,11 +198,11 @@ object TestData {
     Salary(1, 1000.0) :: Nil).toDF()
   salary.registerTempTable("salary")
 
-  case class ComplexData(m: Map[Int, String], s: TestData, a: Seq[Int], b: Boolean)
+  case class ComplexData(m: Map[String, Int], s: TestData, a: Seq[Int], b: Boolean)
   val complexData =
     TestSQLContext.sparkContext.parallelize(
-      ComplexData(Map(1 -> "1"), TestData(1, "1"), Seq(1), true)
-      :: ComplexData(Map(2 -> "2"), TestData(2, "2"), Seq(2), false)
+      ComplexData(Map("1" -> 1), TestData(1, "1"), Seq(1), true)
+      :: ComplexData(Map("2" -> 2), TestData(2, "2"), Seq(2), false)
       :: Nil).toDF()
   complexData.registerTempTable("complexData")
 }
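Note: replacing the broad try/except in DataFrame.__getattr__ with an explicit
membership check gives a precise error message and stops genuine Py4J failures from
being swallowed; the Scala side of the patch only retypes the test fixture's map
keys to match the Python tests. The resulting error behaviour, sketched against the
doctest `df` with columns ['age', 'name']:

    df.age          # ok: a Column
    df.aeg          # AttributeError: No such column: aeg
    df["bad_key"]   # IndexError: no such column: bad_key  (from __getitem__)
    df[2]           # IndexError, only two columns to index
    df[{}]          # TypeError: unexpected type: <type 'dict'>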
From 7ada9ebe2405963eba610c8cfcf227204ff74fdb Mon Sep 17 00:00:00 2001
From: Davies Liu
Date: Thu, 16 Apr 2015 09:29:01 -0700
Subject: [PATCH 5/6] update timeout

---
 dev/run-tests-jenkins | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/dev/run-tests-jenkins b/dev/run-tests-jenkins
index 3c1c91a11135..06e7cd3acc23 100755
--- a/dev/run-tests-jenkins
+++ b/dev/run-tests-jenkins
@@ -47,7 +47,7 @@ COMMIT_URL="https://github.com/apache/spark/commit/${ghprbActualCommit}"
 # GitHub doesn't auto-link short hashes when submitted via the API, unfortunately. :(
 SHORT_COMMIT_HASH="${ghprbActualCommit:0:7}"
 
-TESTS_TIMEOUT="120m" # format: http://linux.die.net/man/1/timeout
+TESTS_TIMEOUT="180m" # format: http://linux.die.net/man/1/timeout
 
 # Array to capture all tests to run on the pull request. These tests are held under the
 #+ dev/tests/ directory.

From e04d5a03db1659ef4d0bd58b9fcc18c93df79998 Mon Sep 17 00:00:00 2001
From: Davies Liu
Date: Thu, 16 Apr 2015 10:21:53 -0700
Subject: [PATCH 6/6] Update run-tests-jenkins

---
 dev/run-tests-jenkins | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/dev/run-tests-jenkins b/dev/run-tests-jenkins
index 06e7cd3acc23..030f2cdddb35 100755
--- a/dev/run-tests-jenkins
+++ b/dev/run-tests-jenkins
@@ -47,7 +47,7 @@ COMMIT_URL="https://github.com/apache/spark/commit/${ghprbActualCommit}"
 # GitHub doesn't auto-link short hashes when submitted via the API, unfortunately. :(
 SHORT_COMMIT_HASH="${ghprbActualCommit:0:7}"
 
-TESTS_TIMEOUT="180m" # format: http://linux.die.net/man/1/timeout
+TESTS_TIMEOUT="150m" # format: http://linux.die.net/man/1/timeout
 
 # Array to capture all tests to run on the pull request. These tests are held under the
 #+ dev/tests/ directory.