diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py
index aac5b8c4c5770..b4b5ca51e7991 100644
--- a/python/pyspark/sql/dataframe.py
+++ b/python/pyspark/sql/dataframe.py
@@ -910,6 +910,26 @@ def cov(self, col1, col2):
             raise ValueError("col2 should be a string.")
         return self._jdf.stat().cov(col1, col2)
 
+    def crosstab(self, col1, col2):
+        """
+        Computes a pair-wise frequency table of the given columns. Also known as a contingency
+        table. The number of distinct values for each column should be less than 1e4. The first
+        column of each row will be the distinct values of `col1` and the column names will be the
+        distinct values of `col2`. The name of the first column will be `$col1_$col2`. Pairs that
+        have no occurrences will have `null` as their counts.
+        :func:`DataFrame.crosstab` and :func:`DataFrameStatFunctions.crosstab` are aliases.
+
+        :param col1: The name of the first column. Distinct items will make the first item of
+            each row.
+        :param col2: The name of the second column. Distinct items will make the column names
+            of the DataFrame.
+        """
+        if not isinstance(col1, str):
+            raise ValueError("col1 should be a string.")
+        if not isinstance(col2, str):
+            raise ValueError("col2 should be a string.")
+        return DataFrame(self._jdf.stat().crosstab(col1, col2), self.sql_ctx)
+
     def freqItems(self, cols, support=None):
         """
         Finding frequent items for columns, possibly with false positives. Using the
@@ -1390,6 +1410,11 @@ def cov(self, col1, col2):
 
     cov.__doc__ = DataFrame.cov.__doc__
 
+    def crosstab(self, col1, col2):
+        return self.df.crosstab(col1, col2)
+
+    crosstab.__doc__ = DataFrame.crosstab.__doc__
+
     def freqItems(self, cols, support=None):
         return self.df.freqItems(cols, support)
 
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index d652c302a54ba..7ea6656d31c4e 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -405,6 +405,15 @@ def test_cov(self):
         cov = df.stat.cov("a", "b")
         self.assertTrue(abs(cov - 55.0 / 3) < 1e-6)
 
+    def test_crosstab(self):
+        df = self.sc.parallelize([Row(a=i % 3, b=i % 2) for i in range(1, 7)]).toDF()
+        ct = df.stat.crosstab("a", "b").collect()
+        ct = sorted(ct, key=lambda x: x[0])
+        for i, row in enumerate(ct):
+            self.assertEqual(row[0], str(i))
+            self.assertEqual(row[1], 1)
+            self.assertEqual(row[2], 1)
+
     def test_math_functions(self):
         df = self.sc.parallelize([Row(a=i, b=2 * i) for i in range(10)]).toDF()
         from pyspark.sql import mathfunctions as functions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala
index 903532105284e..fcf21ca741a7c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala
@@ -28,6 +28,16 @@ import org.apache.spark.sql.execution.stat._
 final class DataFrameStatFunctions private[sql](df: DataFrame) {
 
   /**
+   * Calculate the sample covariance of two numerical columns of a DataFrame.
+   * @param col1 the name of the first column
+   * @param col2 the name of the second column
+   * @return the covariance of the two columns.
+   */
+  def cov(col1: String, col2: String): Double = {
+    StatFunctions.calculateCov(df, Seq(col1, col2))
+  }
+
+  /**
    * Calculates the correlation of two columns of a DataFrame. Currently only supports the Pearson
    * Correlation Coefficient. For Spearman Correlation, consider using RDD methods found in
    * MLlib's Statistics.
@@ -53,6 +63,23 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) {
     corr(col1, col2, "pearson")
   }
 
+  /**
+   * Computes a pair-wise frequency table of the given columns. Also known as a contingency table.
+   * The number of distinct values for each column should be less than 1e4. The first
+   * column of each row will be the distinct values of `col1` and the column names will be the
+   * distinct values of `col2`. The name of the first column will be `$col1_$col2`. Counts will be
+   * returned as `Long`s. Pairs that have no occurrences will have `null` as their counts.
+   *
+   * @param col1 The name of the first column. Distinct items will make the first item of
+   *             each row.
+   * @param col2 The name of the second column. Distinct items will make the column names
+   *             of the DataFrame.
+   * @return A local DataFrame containing the table.
+   */
+  def crosstab(col1: String, col2: String): DataFrame = {
+    StatFunctions.crossTabulate(df, col1, col2)
+  }
+
   /**
    * Finding frequent items for columns, possibly with false positives. Using the
    * frequent element count algorithm described in
@@ -94,14 +121,4 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) {
   def freqItems(cols: Seq[String]): DataFrame = {
     FrequentItems.singlePassFreqItems(df, cols, 0.01)
   }
-
-  /**
-   * Calculate the sample covariance of two numerical columns of a DataFrame.
-   * @param col1 the name of the first column
-   * @param col2 the name of the second column
-   * @return the covariance of the two columns.
-   */
-  def cov(col1: String, col2: String): Double = {
-    StatFunctions.calculateCov(df, Seq(col1, col2))
-  }
 }
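Reviewer note: to make the new public API concrete, here is a small usage sketch (not part of the patch). The SparkContext/SQLContext setup, object name, and column names are illustrative assumptions; the data and expected output are taken from the crosstab test added to DataFrameStatSuite.scala below, and row order is not guaranteed.

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.SQLContext

object CrosstabExample {
  def main(args: Array[String]): Unit = {
    // Illustrative local setup, not part of the patch.
    val sc = new SparkContext(new SparkConf().setAppName("crosstab-example").setMaster("local"))
    val sqlContext = new SQLContext(sc)
    import sqlContext.implicits._

    // Same data as the Scala test below: (0, 1) and (1, 1) never co-occur.
    val df = Seq((0, 0), (2, 1), (1, 0), (2, 0), (0, 0), (2, 0)).toDF("a", "b")
    df.stat.crosstab("a", "b").show()
    // Expected table (row order may vary):
    //   a_b  0  1
    //   0    2  null
    //   1    1  null
    //   2    2  1

    sc.stop()
  }
}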
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala
index 67b48e58b17ab..b50f606d9cbe3 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala
@@ -17,11 +17,14 @@
 
 package org.apache.spark.sql.execution.stat
 
-import org.apache.spark.sql.catalyst.expressions.Cast
+import org.apache.spark.Logging
 import org.apache.spark.sql.{Column, DataFrame}
-import org.apache.spark.sql.types.{DoubleType, NumericType}
+import org.apache.spark.sql.catalyst.expressions.{GenericMutableRow, Cast}
+import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
+import org.apache.spark.sql.functions._
+import org.apache.spark.sql.types._
 
-private[sql] object StatFunctions {
+private[sql] object StatFunctions extends Logging {
 
   /** Calculate the Pearson Correlation Coefficient for the given columns */
   private[sql] def pearsonCorrelation(df: DataFrame, cols: Seq[String]): Double = {
@@ -95,4 +98,32 @@ private[sql] object StatFunctions {
     val counts = collectStatisticalData(df, cols)
     counts.cov
   }
+
+  /** Generate a table of frequencies for the elements of two columns. */
+  private[sql] def crossTabulate(df: DataFrame, col1: String, col2: String): DataFrame = {
+    val tableName = s"${col1}_$col2"
+    val counts = df.groupBy(col1, col2).agg(col(col1), col(col2), count("*")).take(1e8.toInt)
+    if (counts.length == 1e8.toInt) {
+      logWarning("The maximum limit of 1e8 pairs has been collected, which may not be all of " +
+        "the pairs. Please try reducing the number of distinct items in your columns.")
+    }
+    // get the distinct values of column 2, so that we can make them the column names
+    val distinctCol2 = counts.map(_.get(1)).distinct.zipWithIndex.toMap
+    val columnSize = distinctCol2.size
+    require(columnSize < 1e4, s"The number of distinct values for $col2 can't " +
+      s"exceed 1e4. Currently it is $columnSize.")
+    val table = counts.groupBy(_.get(0)).map { case (col1Item, rows) =>
+      val countsRow = new GenericMutableRow(columnSize + 1)
+      rows.foreach { row =>
+        countsRow.setLong(distinctCol2.get(row.get(1)).get + 1, row.getLong(2))
+      }
+      // the value of col1 is the first value, the rest are the counts
+      countsRow.setString(0, col1Item.toString)
+      countsRow
+    }.toSeq
+    val headerNames = distinctCol2.toSeq.sortBy(_._2).map(r => StructField(r._1.toString, LongType))
+    val schema = StructType(StructField(tableName, StringType) +: headerNames)
+
+    new DataFrame(df.sqlContext, LocalRelation(schema.toAttributes, table))
+  }
 }
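Reviewer note: a plain-Scala mirror of the pivot that crossTabulate performs on the collected counts, to make the row/column bookkeeping easy to follow. This is an illustrative sketch only; the object name and sample data are assumptions, not part of the patch. It also shows why the header order must follow the indices assigned by zipWithIndex, which is what the sortBy in headerNames above guarantees.

object CrosstabPivotSketch {
  def main(args: Array[String]): Unit = {
    // (col1 value, col2 value, count) triples, as produced by the groupBy/agg step.
    val counts = Seq(("0", "0", 2L), ("1", "0", 1L), ("2", "0", 2L), ("2", "1", 1L))

    // Index the distinct col2 values; each index becomes an output column position.
    val distinctCol2: Map[String, Int] = counts.map(_._2).distinct.zipWithIndex.toMap

    // One output row per distinct col1 value: slot 0 holds the col1 value, slot
    // (index + 1) holds the count for that col2 value; untouched slots stay null.
    val table = counts.groupBy(_._1).map { case (col1Item, rows) =>
      val row = new Array[Any](distinctCol2.size + 1)
      row(0) = col1Item
      rows.foreach { case (_, col2Item, n) => row(distinctCol2(col2Item) + 1) = n }
      row.toList
    }

    table.foreach(println)
    // Prints (in some order): List(0, 2, null), List(1, 1, null), List(2, 2, 1)
  }
}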
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java
index 78e847239f405..58cc8e5be6075 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java
@@ -34,6 +34,7 @@
 
 import java.io.Serializable;
 import java.util.Arrays;
+import java.util.Comparator;
 import java.util.List;
 import java.util.Map;
 
@@ -178,6 +179,33 @@ public void testCreateDataFrameFromJavaBeans() {
       Assert.assertEquals(bean.getD().get(i), d.apply(i));
     }
   }
+
+  private static final Comparator<Row> crosstabRowComparator = new Comparator<Row>() {
+    public int compare(Row row1, Row row2) {
+      String item1 = row1.getString(0);
+      String item2 = row2.getString(0);
+      return item1.compareTo(item2);
+    }
+  };
+
+  @Test
+  public void testCrosstab() {
+    DataFrame df = context.table("testData2");
+    DataFrame crosstab = df.stat().crosstab("a", "b");
+    String[] columnNames = crosstab.schema().fieldNames();
+    Assert.assertEquals("a_b", columnNames[0]);
+    Assert.assertEquals("1", columnNames[1]);
+    Assert.assertEquals("2", columnNames[2]);
+    Row[] rows = crosstab.collect();
+    Arrays.sort(rows, crosstabRowComparator);
+    Integer count = 1;
+    for (Row row : rows) {
+      Assert.assertEquals(count.toString(), row.get(0).toString());
+      Assert.assertEquals(1L, row.getLong(1));
+      Assert.assertEquals(1L, row.getLong(2));
+      count++;
+    }
+  }
 
   @Test
   public void testFrequentItems() {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala
index 06764d2a122f1..46b1845a9180c 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala
@@ -24,26 +24,9 @@ import org.apache.spark.sql.test.TestSQLContext
 import org.apache.spark.sql.test.TestSQLContext.implicits._
 
 class DataFrameStatSuite extends FunSuite {
-
-  import TestData._
+  val sqlCtx = TestSQLContext
 
   def toLetter(i: Int): String = (i + 97).toChar.toString
-
-  test("Frequent Items") {
-    val rows = Seq.tabulate(1000) { i =>
-      if (i % 3 == 0) (1, toLetter(1), -1.0) else (i, toLetter(i), i * -1.0)
-    }
-    val df = rows.toDF("numbers", "letters", "negDoubles")
-
-    val results = df.stat.freqItems(Array("numbers", "letters"), 0.1)
-    val items = results.collect().head
-    items.getSeq[Int](0) should contain (1)
-    items.getSeq[String](1) should contain (toLetter(1))
-
-    val singleColResults = df.stat.freqItems(Array("negDoubles"), 0.1)
-    val items2 = singleColResults.collect().head
-    items2.getSeq[Double](0) should contain (-1.0)
-  }
 
   test("pearson correlation") {
     val df = Seq.tabulate(10)(i => (i, 2 * i, i * -1.0)).toDF("a", "b", "c")
@@ -76,7 +59,43 @@ class DataFrameStatSuite extends FunSuite {
     intercept[IllegalArgumentException] {
       df.stat.cov("singles", "letters") // doesn't accept non-numerical dataTypes
     }
+
     val decimalData = Seq.tabulate(6)(i => (BigDecimal(i % 3), BigDecimal(i % 2))).toDF("a", "b")
     val decimalRes = decimalData.stat.cov("a", "b")
     assert(math.abs(decimalRes) < 1e-12)
   }
+
+  test("crosstab") {
+    val df = Seq((0, 0), (2, 1), (1, 0), (2, 0), (0, 0), (2, 0)).toDF("a", "b")
+    val crosstab = df.stat.crosstab("a", "b")
+    val columnNames = crosstab.schema.fieldNames
+    assert(columnNames(0) === "a_b")
+    assert(columnNames(1) === "0")
+    assert(columnNames(2) === "1")
+    val rows: Array[Row] = crosstab.collect().sortBy(_.getString(0))
+    assert(rows(0).get(0).toString === "0")
+    assert(rows(0).getLong(1) === 2L)
+    assert(rows(0).get(2) === null)
+    assert(rows(1).get(0).toString === "1")
+    assert(rows(1).getLong(1) === 1L)
+    assert(rows(1).get(2) === null)
+    assert(rows(2).get(0).toString === "2")
+    assert(rows(2).getLong(1) === 2L)
+    assert(rows(2).getLong(2) === 1L)
+  }
+
+  test("Frequent Items") {
+    val rows = Seq.tabulate(1000) { i =>
+      if (i % 3 == 0) (1, toLetter(1), -1.0) else (i, toLetter(i), i * -1.0)
+    }
+    val df = rows.toDF("numbers", "letters", "negDoubles")
+
+    val results = df.stat.freqItems(Array("numbers", "letters"), 0.1)
+    val items = results.collect().head
+    items.getSeq[Int](0) should contain (1)
+    items.getSeq[String](1) should contain (toLetter(1))
+
+    val singleColResults = df.stat.freqItems(Array("negDoubles"), 0.1)
+    val items2 = singleColResults.collect().head
+    items2.getSeq[Double](0) should contain (-1.0)
+  }
 }