From 27a5a81de859b41b39174fcdd9aad8df80703b5b Mon Sep 17 00:00:00 2001
From: Burak Yavuz
Date: Fri, 1 May 2015 12:01:07 -0700
Subject: [PATCH 1/8] implemented crosstab

---
 .../spark/sql/DataFrameStatFunctions.scala    | 16 +++++++-
 .../sql/execution/stat/ContingencyTable.scala | 38 +++++++++++++++++++
 .../apache/spark/sql/JavaDataFrameSuite.java  | 18 +++++++++
 .../apache/spark/sql/DataFrameStatSuite.scala | 17 +++++++++
 4 files changed, 88 insertions(+), 1 deletion(-)
 create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/stat/ContingencyTable.scala

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala
index 42e5cbc05e1e0..ba90888226194 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala
@@ -18,7 +18,7 @@
 package org.apache.spark.sql
 
 import org.apache.spark.annotation.Experimental
-import org.apache.spark.sql.execution.stat.FrequentItems
+import org.apache.spark.sql.execution.stat.{ContingencyTable, FrequentItems}
 
 /**
  * :: Experimental ::
@@ -27,6 +27,20 @@ import org.apache.spark.sql.execution.stat.FrequentItems
 @Experimental
 final class DataFrameStatFunctions private[sql](df: DataFrame) {
 
+  /**
+   * Computes a pair-wise frequency table of the given columns. Also known as a contingency table.
+   * The number of distinct values for each column should be less than Int.MaxValue. The first
+   * column of each row will be the distinct values of `col1` and the column names will be the
+   * distinct values of `col2` sorted in lexicographical order. Counts will be returned as `Long`s.
+   *
+   * @param col1 The name of the first column.
+   * @param col2 The name of the second column.
+   * @return A local DataFrame containing the table.
+   */
+  def crosstab(col1: String, col2: String): DataFrame = {
+    ContingencyTable.crossTabulate(df, col1, col2)
+  }
+
   /**
    * Finding frequent items for columns, possibly with false positives. Using the
    * frequent element count algorithm described in
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/ContingencyTable.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/ContingencyTable.scala
new file mode 100644
index 0000000000000..916df88fec957
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/ContingencyTable.scala
@@ -0,0 +1,38 @@
+package org.apache.spark.sql.execution.stat
+
+import org.apache.spark.sql.{Row, DataFrame}
+import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
+import org.apache.spark.sql.types._
+import org.apache.spark.sql.functions._
+
+
+private[sql] object ContingencyTable {
+
+  /** Generate a table of frequencies for the elements of two columns. */
+  private[sql] def crossTabulate(df: DataFrame, col1: String, col2: String): DataFrame = {
+    val tableName = s"${col1}_$col2"
+    val distinctVals = df.select(countDistinct(col1), countDistinct(col2)).collect().head
+    val distinctCol1 = distinctVals.getLong(0)
+    val distinctCol2 = distinctVals.getLong(1)
+
+    require(distinctCol1 < Int.MaxValue, s"The number of distinct values for $col1 can't " +
+      s"exceed Int.MaxValue. Currently $distinctCol1")
+    require(distinctCol2 < Int.MaxValue, s"The number of distinct values for $col2 can't " +
+      s"exceed Int.MaxValue. Currently $distinctCol2")
+    // Aggregate the counts for the two columns
+    val allCounts =
+      df.groupBy(col1, col2).agg(col(col1), col(col2), count("*")).orderBy(col1, col2).collect()
+    // Pivot the table
+    val pivotedTable = allCounts.grouped(distinctCol2.toInt).toArray
+    // Get the column names (distinct values of col2)
+    val headerNames = pivotedTable.head.map(r => StructField(r.get(1).toString, LongType))
+    val schema = StructType(StructField(tableName, StringType) +: headerNames)
+    val table = pivotedTable.map { rows =>
+      // the value of col1 is the first value, the rest are the counts
+      val rowValues = rows.head.get(0).toString +: rows.map(_.getLong(2))
+      Row(rowValues:_*)
+    }
+    new DataFrame(df.sqlContext, LocalRelation(schema.toAttributes, table))
+  }
+
+}
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java
index ebe96e649d940..3896ef96ffb40 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java
@@ -178,6 +178,24 @@ public void testCreateDataFrameFromJavaBeans() {
       Assert.assertEquals(bean.getD().get(i), d.apply(i));
     }
   }
+
+  @Test
+  public void testCrosstab() {
+    DataFrame df = context.table("testData2");
+    DataFrame crosstab = df.stat().crosstab("a", "b");
+    String[] columnNames = crosstab.schema().fieldNames();
+    Assert.assertEquals(columnNames[0], "a_b");
+    Assert.assertEquals(columnNames[1], "1");
+    Assert.assertEquals(columnNames[2], "2");
+    Row[] rows = crosstab.collect();
+    Integer count = 1;
+    for (Row row : rows) {
+      Assert.assertEquals(row.get(0).toString(), count.toString());
+      Assert.assertEquals(row.getLong(1), 1L);
+      Assert.assertEquals(row.getLong(2), 1L);
+      count++;
+    }
+  }
 
   @Test
   public void testFrequentItems() {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala
index bb1d29c71d23b..c0afef2b7fe11 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala
@@ -24,9 +24,26 @@ import org.apache.spark.sql.test.TestSQLContext
 import org.apache.spark.sql.test.TestSQLContext.implicits._
 
 class DataFrameStatSuite extends FunSuite {
+  import TestData._
 
   val sqlCtx = TestSQLContext
 
+  test("crosstab") {
+    val crosstab = testData2.stat.crosstab("a", "b")
+    val columnNames = crosstab.schema.fieldNames
+    assert(columnNames(0) === "a_b")
+    assert(columnNames(1) === "1")
+    assert(columnNames(2) === "2")
+    val rows: Array[Row] = crosstab.collect()
+    var count: Integer = 1
+    rows.foreach { row =>
+      assert(row.get(0).toString === count.toString)
+      assert(row.getLong(1) === 1L)
+      assert(row.getLong(2) === 1L)
+      count += 1
+    }
+  }
+
   test("Frequent Items") {
     def toLetter(i: Int): String = (i + 96).toChar.toString
     val rows = Array.tabulate(1000) { i =>
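An aside on the pivot inside crossTabulate above: because the aggregated counts come back sorted by (col1, col2), slicing the flat array with grouped(distinctCol2) yields exactly one row of the contingency table per slice. A minimal, Spark-free Scala sketch of the trick (the toy counts array is invented for illustration; note the slicing only works when every (col1, col2) pair actually occurs, an assumption the map-based rewrite in PATCH 8 below does away with):

    // Dense, sorted (col1, col2, count) triples over two distinct col2 values.
    val allCounts = Array(("a", "x", 2L), ("a", "y", 1L), ("b", "x", 4L), ("b", "y", 3L))
    val distinctCol2 = 2
    // Each slice of length distinctCol2 shares one col1 value and becomes one row.
    val pivotedTable = allCounts.grouped(distinctCol2).toArray
    pivotedTable.foreach { rows =>
      val label = rows.head._1  // the shared col1 value labels the row
      println((label +: rows.map(_._3)).mkString("\t"))
    }
    // a    2    1
    // b    4    3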
Currently $distinctCol2") + // Aggregate the counts for the two columns + val allCounts = + df.groupBy(col1, col2).agg(col(col1), col(col2), count("*")).orderBy(col1, col2).collect() + // Pivot the table + val pivotedTable = allCounts.grouped(distinctCol2.toInt).toArray + // Get the column names (distinct values of col2) + val headerNames = pivotedTable.head.map(r => StructField(r.get(1).toString, LongType)) + val schema = StructType(StructField(tableName, StringType) +: headerNames) + val table = pivotedTable.map { rows => + // the value of col1 is the first value, the rest are the counts + val rowValues = rows.head.get(0).toString +: rows.map(_.getLong(2)) + Row(rowValues:_*) + } + new DataFrame(df.sqlContext, LocalRelation(schema.toAttributes, table)) + } + +} diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java index ebe96e649d940..3896ef96ffb40 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java @@ -178,6 +178,24 @@ public void testCreateDataFrameFromJavaBeans() { Assert.assertEquals(bean.getD().get(i), d.apply(i)); } } + + @Test + public void testCrosstab() { + DataFrame df = context.table("testData2"); + DataFrame crosstab = df.stat().crosstab("a", "b"); + String[] columnNames = crosstab.schema().fieldNames(); + Assert.assertEquals(columnNames[0], "a_b"); + Assert.assertEquals(columnNames[1], "1"); + Assert.assertEquals(columnNames[2], "2"); + Row[] rows = crosstab.collect(); + Integer count = 1; + for (Row row : rows) { + Assert.assertEquals(row.get(0).toString(), count.toString()); + Assert.assertEquals(row.getLong(1), 1L); + Assert.assertEquals(row.getLong(2), 1L); + count++; + } + } @Test public void testFrequentItems() { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala index bb1d29c71d23b..c0afef2b7fe11 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala @@ -24,9 +24,26 @@ import org.apache.spark.sql.test.TestSQLContext import org.apache.spark.sql.test.TestSQLContext.implicits._ class DataFrameStatSuite extends FunSuite { + import TestData._ val sqlCtx = TestSQLContext + test("crosstab") { + val crosstab = testData2.stat.crosstab("a", "b") + val columnNames = crosstab.schema.fieldNames + assert(columnNames(0) === "a_b") + assert(columnNames(1) === "1") + assert(columnNames(2) === "2") + val rows: Array[Row] = crosstab.collect() + var count: Integer = 1 + rows.foreach { row => + assert(row.get(0).toString === count.toString) + assert(row.getLong(1) === 1L) + assert(row.getLong(2) === 1L) + count += 1 + } + } + test("Frequent Items") { def toLetter(i: Int): String = (i + 96).toChar.toString val rows = Array.tabulate(1000) { i => From 7f098bc807a2b97e5ee6b9723acde42422522b9a Mon Sep 17 00:00:00 2001 From: Burak Yavuz Date: Fri, 1 May 2015 14:51:54 -0700 Subject: [PATCH 2/8] add crosstab pyTest --- python/pyspark/sql/dataframe.py | 7 +------ python/pyspark/sql/tests.py | 8 ++++++++ 2 files changed, 9 insertions(+), 6 deletions(-) diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index cca5dfccdd82c..3877de7e2aaa1 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -899,17 +899,12 @@ def 
From 6805df8e34cfcc7520cbda30d6660f93045c948c Mon Sep 17 00:00:00 2001
From: Burak Yavuz
Date: Fri, 1 May 2015 19:42:40 -0700
Subject: [PATCH 4/8] addressed comments and fixed test

---
 .../org/apache/spark/sql/execution/stat/StatFunctions.scala | 6 +++---
 .../scala/org/apache/spark/sql/DataFrameStatSuite.scala     | 6 +++---
 2 files changed, 6 insertions(+), 6 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala
index cc08479f43163..394d1bce8025b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala
@@ -82,9 +82,9 @@ private[sql] object StatFunctions {
   /** Generate a table of frequencies for the elements of two columns. */
   private[sql] def crossTabulate(df: DataFrame, col1: String, col2: String): DataFrame = {
     val tableName = s"${col1}_$col2"
-    val distinctCol2 = df.select(col2).distinct.orderBy(col2).collect()
+    val distinctCol2 = df.select(col2).distinct.collect().sortBy(_.get(0).toString)
     val columnSize = distinctCol2.size
-    require(columnSize < 1e5, s"The number of distinct values for $col2, can't " +
-      s"exceed 1e5. Currently $columnSize")
+    require(columnSize < 1e4, s"The number of distinct values for $col2 can't " +
+      s"exceed 1e4. Currently $columnSize")
     var i = 0
     val col2Map = distinctCol2.map { r =>
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala
index d9d12c5df24bd..aa8c503d341e8 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala
@@ -33,10 +33,10 @@ class DataFrameStatSuite extends FunSuite {
     val crosstab = df.stat.crosstab("a", "b")
     val columnNames = crosstab.schema.fieldNames
     assert(columnNames(0) === "a_b")
-    assert(columnNames(1) === "1")
-    assert(columnNames(2) === "2")
+    assert(columnNames(1) === "0")
+    assert(columnNames(2) === "1")
     val rows: Array[Row] = crosstab.collect()
-    var count: Integer = 1
+    var count: Integer = 0
     rows.foreach { row =>
       assert(row.get(0).toString === count.toString)
       assert(row.getLong(1) === 1L)
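For reference, the col2Map being built at the end of the hunk above boils down to giving each distinct col2 value a stable column slot, keyed on its string form. A standalone sketch (values invented for illustration; PATCH 8 below lands on the same zipWithIndex.toMap pattern):

    // Distinct col2 values, ordered by their string form as in the patch above.
    // Note the ordering is lexicographic: "10" sorts before "2".
    val distinctCol2 = Seq("2", "10", "1").sortBy(_.toString)  // Seq("1", "10", "2")
    // Slot 0 of each output row holds the col1 value, so the count for
    // distinctCol2(i) lands in output column i + 1.
    val col2Map: Map[String, Int] = distinctCol2.zipWithIndex.toMap
    // Map("1" -> 0, "10" -> 1, "2" -> 2)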
Currently $columnSize") var i = 0 val col2Map = distinctCol2.map { r => diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala index d9d12c5df24bd..aa8c503d341e8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala @@ -33,10 +33,10 @@ class DataFrameStatSuite extends FunSuite { val crosstab = df.stat.crosstab("a", "b") val columnNames = crosstab.schema.fieldNames assert(columnNames(0) === "a_b") - assert(columnNames(1) === "1") - assert(columnNames(2) === "2") + assert(columnNames(1) === "0") + assert(columnNames(2) === "1") val rows: Array[Row] = crosstab.collect() - var count: Integer = 1 + var count: Integer = 0 rows.foreach { row => assert(row.get(0).toString === count.toString) assert(row.getLong(1) === 1L) From a63ad0036be42c84ba55ea1f1632390c6a0484c2 Mon Sep 17 00:00:00 2001 From: Burak Yavuz Date: Sun, 3 May 2015 10:52:39 -0700 Subject: [PATCH 5/8] addressed comments v3.0 --- .../apache/spark/sql/DataFrameStatSuite.scala | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala index aa8c503d341e8..570127aec5111 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala @@ -29,20 +29,22 @@ class DataFrameStatSuite extends FunSuite { def toLetter(i: Int): String = (i + 97).toChar.toString test("crosstab") { - val df = sqlCtx.sparkContext.parallelize((1 to 6).map(i => (i % 3, i % 2))).toDF("a", "b") + val df = Seq.tabulate(8)(i => (i % 3, i % 2)).toDF("a", "b") val crosstab = df.stat.crosstab("a", "b") val columnNames = crosstab.schema.fieldNames assert(columnNames(0) === "a_b") assert(columnNames(1) === "0") assert(columnNames(2) === "1") val rows: Array[Row] = crosstab.collect() - var count: Integer = 0 - rows.foreach { row => - assert(row.get(0).toString === count.toString) - assert(row.getLong(1) === 1L) - assert(row.getLong(2) === 1L) - count += 1 - } + assert(rows(0).get(0).toString === "0") + assert(rows(0).getLong(1) === 2L) + assert(rows(0).getLong(2) === 1L) + assert(rows(1).get(0).toString === "1") + assert(rows(1).getLong(1) === 1L) + assert(rows(1).getLong(2) === 2L) + assert(rows(2).get(0).toString === "2") + assert(rows(2).getLong(1) === 1L) + assert(rows(2).getLong(2) === 1L) } test("Frequent Items") { From bced829ea4cb1c939ecff636aad4c7933a61130e Mon Sep 17 00:00:00 2001 From: Burak Yavuz Date: Sun, 3 May 2015 11:54:56 -0700 Subject: [PATCH 6/8] fix merge conflicts --- python/pyspark/sql/dataframe.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index fa7b9679225e1..4d3587e69e43e 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -1381,17 +1381,15 @@ def cov(self, col1, col2): cov.__doc__ = DataFrame.cov.__doc__ -<<<<<<< HEAD def crosstab(self, col1, col2): return self.df.crosstab(col1, col2) crosstab.__doc__ = DataFrame.crosstab.__doc__ -======= + def freqItems(self, cols, support=None): return self.df.freqItems(cols, support) freqItems.__doc__ = DataFrame.freqItems.__doc__ ->>>>>>> 49549d5a1a867c3ba25f5e4aec351d4102444bc0 def _test(): From 
From bced829ea4cb1c939ecff636aad4c7933a61130e Mon Sep 17 00:00:00 2001
From: Burak Yavuz
Date: Sun, 3 May 2015 11:54:56 -0700
Subject: [PATCH 6/8] fix merge conflicts

---
 python/pyspark/sql/dataframe.py | 4 +---
 1 file changed, 1 insertion(+), 3 deletions(-)

diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py
index fa7b9679225e1..4d3587e69e43e 100644
--- a/python/pyspark/sql/dataframe.py
+++ b/python/pyspark/sql/dataframe.py
@@ -1381,17 +1381,15 @@ def cov(self, col1, col2):
 
     cov.__doc__ = DataFrame.cov.__doc__
 
-<<<<<<< HEAD
     def crosstab(self, col1, col2):
         return self.df.crosstab(col1, col2)
 
     crosstab.__doc__ = DataFrame.crosstab.__doc__
 
-=======
+
     def freqItems(self, cols, support=None):
         return self.df.freqItems(cols, support)
 
     freqItems.__doc__ = DataFrame.freqItems.__doc__
->>>>>>> 49549d5a1a867c3ba25f5e4aec351d4102444bc0
 
 
 def _test():

From ae9e01d48b0451baa219943449d5532ab2226708 Mon Sep 17 00:00:00 2001
From: Burak Yavuz
Date: Sun, 3 May 2015 23:43:45 -0700
Subject: [PATCH 7/8] fix test

---
 python/pyspark/sql/dataframe.py                               | 2 +-
 python/pyspark/sql/tests.py                                   | 2 +-
 .../scala/org/apache/spark/sql/DataFrameStatFunctions.scala   | 2 +-
 3 files changed, 3 insertions(+), 3 deletions(-)

diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py
index 6fa322cfd3a58..832c758b592a7 100644
--- a/python/pyspark/sql/dataframe.py
+++ b/python/pyspark/sql/dataframe.py
@@ -915,7 +915,7 @@ def crosstab(self, col1, col2):
         Computes a pair-wise frequency table of the given columns. Also known as a contingency
         table. The number of distinct values for each column should be less than 1e4. The first
         column of each row will be the distinct values of `col1` and the column names will be the
-        distinct values of `col2`. Pairs that have no occurrences will have `null` as their values.
+        distinct values of `col2`. Pairs that have no occurrences will have `null` as their counts.
         :func:`DataFrame.crosstab` and :func:`DataFrameStatFunctions.crosstab` are aliases.
 
         :param col1: The name of the first column. Distinct items will make the first item of
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index d6cbd0a046d6b..7ea6656d31c4e 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -408,7 +408,7 @@ def test_cov(self):
     def test_crosstab(self):
         df = self.sc.parallelize([Row(a=i % 3, b=i % 2) for i in range(1, 7)]).toDF()
         ct = df.stat.crosstab("a", "b").collect()
-        ct = sorted(ct, lambda r: r[0])
+        ct = sorted(ct, key=lambda x: x[0])
         for i, row in enumerate(ct):
             self.assertEqual(row[0], str(i))
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala
index 67327ad5da8c9..6b4e68dfe60ad 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala
@@ -68,7 +68,7 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) {
    * The number of distinct values for each column should be less than 1e4. The first
    * column of each row will be the distinct values of `col1` and the column names will be the
    * distinct values of `col2`. Counts will be returned as `Long`s. Pairs that have no occurrences
-   * will have `null` as their values.
+   * will have `null` as their counts.
    *
    * @param col1 The name of the first column. Distinct items will make the first item of
    *             each row.
From a07c01e07f17935f3729185f6be507971b4a4561 Mon Sep 17 00:00:00 2001
From: Burak Yavuz
Date: Mon, 4 May 2015 10:14:08 -0700
Subject: [PATCH 8/8] addressed comments v4.1

---
 python/pyspark/sql/dataframe.py                       |  3 ++-
 .../apache/spark/sql/DataFrameStatFunctions.scala     |  4 ++--
 .../spark/sql/execution/stat/StatFunctions.scala      | 13 +++++++++----
 3 files changed, 13 insertions(+), 7 deletions(-)

diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py
index 832c758b592a7..b4b5ca51e7991 100644
--- a/python/pyspark/sql/dataframe.py
+++ b/python/pyspark/sql/dataframe.py
@@ -915,7 +915,8 @@ def crosstab(self, col1, col2):
         Computes a pair-wise frequency table of the given columns. Also known as a contingency
         table. The number of distinct values for each column should be less than 1e4. The first
         column of each row will be the distinct values of `col1` and the column names will be the
-        distinct values of `col2`. Pairs that have no occurrences will have `null` as their counts.
+        distinct values of `col2`. The name of the first column will be `$col1_$col2`. Pairs that
+        have no occurrences will have `null` as their counts.
         :func:`DataFrame.crosstab` and :func:`DataFrameStatFunctions.crosstab` are aliases.
 
         :param col1: The name of the first column. Distinct items will make the first item of
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala
index 6b4e68dfe60ad..fcf21ca741a7c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala
@@ -67,8 +67,8 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) {
    * Computes a pair-wise frequency table of the given columns. Also known as a contingency table.
    * The number of distinct values for each column should be less than 1e4. The first
    * column of each row will be the distinct values of `col1` and the column names will be the
-   * distinct values of `col2`. Counts will be returned as `Long`s. Pairs that have no occurrences
-   * will have `null` as their counts.
+   * distinct values of `col2`. The name of the first column will be `$col1_$col2`. Counts will be
+   * returned as `Long`s. Pairs that have no occurrences will have `null` as their counts.
    *
    * @param col1 The name of the first column. Distinct items will make the first item of
    *             each row.
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala
index 8345f17dcd941..b50f606d9cbe3 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala
@@ -17,13 +17,14 @@
 
 package org.apache.spark.sql.execution.stat
 
-import org.apache.spark.sql.{Column, DataFrame, Row}
+import org.apache.spark.Logging
+import org.apache.spark.sql.{Column, DataFrame}
 import org.apache.spark.sql.catalyst.expressions.{GenericMutableRow, Cast}
 import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.types._
 
-private[sql] object StatFunctions {
+private[sql] object StatFunctions extends Logging {
 
   /** Calculate the Pearson Correlation Coefficient for the given columns */
   private[sql] def pearsonCorrelation(df: DataFrame, cols: Seq[String]): Double = {
@@ -102,18 +103,22 @@ private[sql] object StatFunctions {
   private[sql] def crossTabulate(df: DataFrame, col1: String, col2: String): DataFrame = {
     val tableName = s"${col1}_$col2"
     val counts = df.groupBy(col1, col2).agg(col(col1), col(col2), count("*")).take(1e8.toInt)
+    if (counts.length == 1e8.toInt) {
+      logWarning("The maximum limit of 1e8 pairs has been collected, which may not be all of " +
+        "the pairs. Please try reducing the number of distinct items in your columns.")
+    }
     // get the distinct values of column 2, so that we can make them the column names
     val distinctCol2 = counts.map(_.get(1)).distinct.zipWithIndex.toMap
     val columnSize = distinctCol2.size
     require(columnSize < 1e4, s"The number of distinct values for $col2 can't " +
       s"exceed 1e4. Currently $columnSize")
Currently $columnSize") - val table = counts.groupBy(_.get(0)).map { case (col1Items, rows) => + val table = counts.groupBy(_.get(0)).map { case (col1Item, rows) => val countsRow = new GenericMutableRow(columnSize + 1) rows.foreach { row => countsRow.setLong(distinctCol2.get(row.get(1)).get + 1, row.getLong(2)) } // the value of col1 is the first value, the rest are the counts - countsRow.setString(0, col1Items.toString) + countsRow.setString(0, col1Item.toString) countsRow }.toSeq val headerNames = distinctCol2.map(r => StructField(r._1.toString, LongType)).toSeq