From 7f098bc807a2b97e5ee6b9723acde42422522b9a Mon Sep 17 00:00:00 2001
From: Burak Yavuz
Date: Fri, 1 May 2015 14:51:54 -0700
Subject: [PATCH] add crosstab pyTest

---
Review notes (text in this section is ignored when the patch is applied):
 * DataFrame.crosstab now wraps the JVM result in a Python DataFrame
   bound to self.sql_ctx instead of returning the raw Java object.
 * test_crosstab compares the counts with assertEqual; the second
   argument of assertTrue is the failure message, so
   assertTrue(row[1], 1) only checked that the count was truthy.
 * test_crosstab sorts the collected rows on their first column before
   matching them against str(i), so the test does not rely on the order
   in which collect() happens to return the crosstab rows.

 python/pyspark/sql/dataframe.py | 7 +------
 python/pyspark/sql/tests.py | 8 ++++++++
 2 files changed, 9 insertions(+), 6 deletions(-)

diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py
index cca5dfccdd82c..3877de7e2aaa1 100644
--- a/python/pyspark/sql/dataframe.py
+++ b/python/pyspark/sql/dataframe.py
@@ -899,17 +899,12 @@ def crosstab(self, col1, col2):
 
         :param col1: The name of the first column
         :param col2: The name of the second column
-
-        >>> df3.crosstab("age", "height").show()
-        age_height 80 85
-        2 1 1
-        5 1 1
         """
         if not isinstance(col1, str):
             raise ValueError("col1 should be a string.")
         if not isinstance(col2, str):
             raise ValueError("col2 should be a string.")
-        return self._jdf.stat().crosstab(col1, col2)
+        return DataFrame(self._jdf.stat().crosstab(col1, col2), self.sql_ctx)
 
     @ignore_unicode_prefix
     def withColumn(self, colName, col):
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index 44c8b6a1aac13..4ad56577cc551 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -392,6 +392,14 @@ def test_cov(self):
         cov = df.stat.cov("a", "b")
         self.assertTrue(abs(cov - 55.0 / 3) < 1e-6)
 
+    def test_crosstab(self):
+        df = self.sc.parallelize([Row(a=i % 3, b=i % 2) for i in range(1, 7)]).toDF()
+        ct = df.stat.crosstab("a", "b")
+        for i, row in enumerate(sorted(ct.collect(), key=lambda r: r[0])):
+            self.assertEqual(row[0], str(i))
+            self.assertEqual(row[1], 1)
+            self.assertEqual(row[2], 1)
+
     def test_math_functions(self):
         df = self.sc.parallelize([Row(a=i, b=2 * i) for i in range(10)]).toDF()
         from pyspark.sql import mathfunctions as functions