[SPARK-7243][SQL] Contingency Tables for DataFrames
Computes a pair-wise frequency table of the given columns. Also known as cross-tabulation.
cc mengxr rxin
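
For context, a minimal usage sketch of the new API (Scala; the `sqlContext` value and the column names are illustrative, not part of this commit):

    import sqlContext.implicits._  // assumes an existing SQLContext named sqlContext
    val df = Seq((1, "a"), (1, "b"), (2, "a")).toDF("key", "value")
    df.stat.crosstab("key", "value").show()
    // possible output (row order is not guaranteed):
    // key_value  a  b
    // 1          1  1
    // 2          1  null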

Author: Burak Yavuz <brkyvz@gmail.com>

Closes #5842 from brkyvz/df-cont and squashes the following commits:

a07c01e [Burak Yavuz] addressed comments v4.1
ae9e01d [Burak Yavuz] fix test
9106585 [Burak Yavuz] addressed comments v4.0
bced829 [Burak Yavuz] fix merge conflicts
a63ad00 [Burak Yavuz] addressed comments v3.0
a0cad97 [Burak Yavuz] addressed comments v3.0
6805df8 [Burak Yavuz] addressed comments and fixed test
939b7c4 [Burak Yavuz] lint python
7f098bc [Burak Yavuz] add crosstab pyTest
fd53b00 [Burak Yavuz] added python support for crosstab
27a5a81 [Burak Yavuz] implemented crosstab

(cherry picked from commit 8055411)
Signed-off-by: Reynold Xin <rxin@databricks.com>
brkyvz authored and rxin committed May 5, 2015
1 parent 863ec0c commit ecf0d8a
Showing 6 changed files with 160 additions and 31 deletions.
25 changes: 25 additions & 0 deletions python/pyspark/sql/dataframe.py
@@ -931,6 +931,26 @@ def cov(self, col1, col2):
raise ValueError("col2 should be a string.")
return self._jdf.stat().cov(col1, col2)

    def crosstab(self, col1, col2):
        """
        Computes a pair-wise frequency table of the given columns. Also known as a contingency
        table. The number of distinct values for each column should be less than 1e4. The first
        column of each row will be the distinct values of `col1` and the column names will be the
        distinct values of `col2`. The name of the first column will be `$col1_$col2`. Pairs that
        have no occurrences will have `null` as their counts.

        :func:`DataFrame.crosstab` and :func:`DataFrameStatFunctions.crosstab` are aliases.

        :param col1: The name of the first column. Distinct items will make the first item of
            each row.
        :param col2: The name of the second column. Distinct items will make the column names
            of the DataFrame.
        """
        if not isinstance(col1, str):
            raise ValueError("col1 should be a string.")
        if not isinstance(col2, str):
            raise ValueError("col2 should be a string.")
        return DataFrame(self._jdf.stat().crosstab(col1, col2), self.sql_ctx)

    def freqItems(self, cols, support=None):
        """
        Finding frequent items for columns, possibly with false positives. Using the
@@ -1423,6 +1443,11 @@ def cov(self, col1, col2):

    cov.__doc__ = DataFrame.cov.__doc__

    def crosstab(self, col1, col2):
        return self.df.crosstab(col1, col2)

    crosstab.__doc__ = DataFrame.crosstab.__doc__

    def freqItems(self, cols, support=None):
        return self.df.freqItems(cols, support)

9 changes: 9 additions & 0 deletions python/pyspark/sql/tests.py
@@ -405,6 +405,15 @@ def test_cov(self):
        cov = df.stat.cov("a", "b")
        self.assertTrue(abs(cov - 55.0 / 3) < 1e-6)

    def test_crosstab(self):
        # each (a, b) pair below occurs exactly once, so every count should be 1
        df = self.sc.parallelize([Row(a=i % 3, b=i % 2) for i in range(1, 7)]).toDF()
        ct = df.stat.crosstab("a", "b").collect()
        ct = sorted(ct, key=lambda x: x[0])
        for i, row in enumerate(ct):
            self.assertEqual(row[0], str(i))
            self.assertEqual(row[1], 1)
            self.assertEqual(row[2], 1)

    def test_math_functions(self):
        df = self.sc.parallelize([Row(a=i, b=2 * i) for i in range(10)]).toDF()
        from pyspark.sql import mathfunctions as functions
sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala
@@ -28,6 +28,16 @@ import org.apache.spark.sql.execution.stat._
final class DataFrameStatFunctions private[sql](df: DataFrame) {

  /**
   * Calculate the sample covariance of two numerical columns of a DataFrame.
   * @param col1 the name of the first column
   * @param col2 the name of the second column
   * @return the covariance of the two columns.
   */
  def cov(col1: String, col2: String): Double = {
    StatFunctions.calculateCov(df, Seq(col1, col2))
  }

  /**
   * Calculates the correlation of two columns of a DataFrame. Currently only supports the Pearson
   * Correlation Coefficient. For Spearman Correlation, consider using RDD methods found in
   * MLlib's Statistics.
@@ -53,6 +63,23 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) {
    corr(col1, col2, "pearson")
  }

  /**
   * Computes a pair-wise frequency table of the given columns. Also known as a contingency table.
   * The number of distinct values for each column should be less than 1e4. The first
   * column of each row will be the distinct values of `col1` and the column names will be the
   * distinct values of `col2`. The name of the first column will be `$col1_$col2`. Counts will be
   * returned as `Long`s. Pairs that have no occurrences will have `null` as their counts.
   *
   * @param col1 The name of the first column. Distinct items will make the first item of
   *             each row.
   * @param col2 The name of the second column. Distinct items will make the column names
   *             of the DataFrame.
   * @return A local DataFrame containing the table.
   */
  def crosstab(col1: String, col2: String): DataFrame = {
    StatFunctions.crossTabulate(df, col1, col2)
  }

  /**
   * Finding frequent items for columns, possibly with false positives. Using the
   * frequent element count algorithm described in
@@ -94,14 +121,4 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) {
  def freqItems(cols: Seq[String]): DataFrame = {
    FrequentItems.singlePassFreqItems(df, cols, 0.01)
  }

  /**
   * Calculate the sample covariance of two numerical columns of a DataFrame.
   * @param col1 the name of the first column
   * @param col2 the name of the second column
   * @return the covariance of the two columns.
   */
  def cov(col1: String, col2: String): Double = {
    StatFunctions.calculateCov(df, Seq(col1, col2))
  }
}
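
Given the contract documented above (counts come back as `Long`s, absent pairs as `null`), a hedged sketch of consuming the result; the DataFrame `df` and its columns `a` and `b` are assumed for illustration:

    val table = df.stat.crosstab("a", "b")
    table.collect().foreach { row =>
      val col1Value = row.getString(0)  // a distinct value of column `a`
      // a null cell means that (a, b) pair never occurred, so check before getLong
      val count = if (row.isNullAt(1)) 0L else row.getLong(1)
      println(s"$col1Value -> $count")
    }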
sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala
@@ -17,11 +17,14 @@

package org.apache.spark.sql.execution.stat

import org.apache.spark.sql.catalyst.expressions.Cast
import org.apache.spark.Logging
import org.apache.spark.sql.{Column, DataFrame}
import org.apache.spark.sql.types.{DoubleType, NumericType}
import org.apache.spark.sql.catalyst.expressions.{GenericMutableRow, Cast}
import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types._

private[sql] object StatFunctions {
private[sql] object StatFunctions extends Logging {

  /** Calculate the Pearson Correlation Coefficient for the given columns */
  private[sql] def pearsonCorrelation(df: DataFrame, cols: Seq[String]): Double = {
@@ -95,4 +98,32 @@ private[sql] object StatFunctions {
    val counts = collectStatisticalData(df, cols)
    counts.cov
  }

  /** Generate a table of frequencies for the elements of two columns. */
  private[sql] def crossTabulate(df: DataFrame, col1: String, col2: String): DataFrame = {
    val tableName = s"${col1}_$col2"
    // count the co-occurrences of each (col1, col2) pair, capped at 1e8 collected rows
    val counts = df.groupBy(col1, col2).agg(col(col1), col(col2), count("*")).take(1e8.toInt)
    if (counts.length == 1e8.toInt) {
      logWarning("The maximum limit of 1e8 pairs has been collected, which may not be all of " +
        "the pairs. Please try reducing the number of distinct items in your columns.")
    }
    // get the distinct values of column 2, so that we can make them the column names
    val distinctCol2 = counts.map(_.get(1)).distinct.zipWithIndex.toMap
    val columnSize = distinctCol2.size
    require(columnSize < 1e4, s"The number of distinct values for $col2 can't " +
      s"exceed 1e4. Currently: $columnSize")
    val table = counts.groupBy(_.get(0)).map { case (col1Item, rows) =>
      val countsRow = new GenericMutableRow(columnSize + 1)
      rows.foreach { row =>
        countsRow.setLong(distinctCol2.get(row.get(1)).get + 1, row.getLong(2))
      }
      // the value of col1 is the first value, the rest are the counts
      countsRow.setString(0, col1Item.toString)
      countsRow
    }.toSeq
    val headerNames = distinctCol2.map(r => StructField(r._1.toString, LongType)).toSeq
    val schema = StructType(StructField(tableName, StringType) +: headerNames)

    new DataFrame(df.sqlContext, LocalRelation(schema.toAttributes, table))
  }
}
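
To make the pivot inside `crossTabulate` concrete, here is a simplified, Spark-free Scala sketch of the same idea over plain collections; `pairs` stands in for the `(col1, col2, count)` rows produced by the `groupBy`/`agg` step, and all names are illustrative:

    // distinct col2 values become the output columns; rows are grouped by col1,
    // with null left wherever a (col1, col2) pair never occurred
    def pivot(pairs: Seq[(String, String, Long)]): (Seq[String], Seq[Array[Any]]) = {
      val header = pairs.map(_._2).distinct
      val col2Index = header.zipWithIndex.toMap  // col2 value -> column offset
      val rows = pairs.groupBy(_._1).toSeq.map { case (col1Item, group) =>
        val row = Array.fill[Any](header.size + 1)(null)
        row(0) = col1Item  // the first cell holds the col1 value
        group.foreach { case (_, c2, cnt) => row(col2Index(c2) + 1) = cnt }
        row
      }
      (header, rows)
    }
    // pivot(Seq(("0", "0", 2L), ("2", "1", 1L))) yields header Seq("0", "1"),
    // row Array("0", 2, null) and row Array("2", null, 1)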
sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java
@@ -34,6 +34,7 @@

import java.io.Serializable;
import java.util.Arrays;
import java.util.Comparator;
import java.util.List;
import java.util.Map;

@@ -178,6 +179,33 @@ public void testCreateDataFrameFromJavaBeans() {
      Assert.assertEquals(bean.getD().get(i), d.apply(i));
    }
  }

  private static final Comparator<Row> crosstabRowComparator = new Comparator<Row>() {
    @Override
    public int compare(Row row1, Row row2) {
      String item1 = row1.getString(0);
      String item2 = row2.getString(0);
      return item1.compareTo(item2);
    }
  };

  @Test
  public void testCrosstab() {
    DataFrame df = context.table("testData2");
    DataFrame crosstab = df.stat().crosstab("a", "b");
    String[] columnNames = crosstab.schema().fieldNames();
    Assert.assertEquals(columnNames[0], "a_b");
    Assert.assertEquals(columnNames[1], "1");
    Assert.assertEquals(columnNames[2], "2");
    Row[] rows = crosstab.collect();
    Arrays.sort(rows, crosstabRowComparator);
    int count = 1;
    for (Row row : rows) {
      Assert.assertEquals(row.get(0).toString(), Integer.toString(count));
      Assert.assertEquals(row.getLong(1), 1L);
      Assert.assertEquals(row.getLong(2), 1L);
      count++;
    }
  }

  @Test
  public void testFrequentItems() {
sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala
@@ -24,26 +24,9 @@ import org.apache.spark.sql.test.TestSQLContext
import org.apache.spark.sql.test.TestSQLContext.implicits._

class DataFrameStatSuite extends FunSuite {

  import TestData._

  val sqlCtx = TestSQLContext
  def toLetter(i: Int): String = (i + 97).toChar.toString

test("Frequent Items") {
val rows = Seq.tabulate(1000) { i =>
if (i % 3 == 0) (1, toLetter(1), -1.0) else (i, toLetter(i), i * -1.0)
}
val df = rows.toDF("numbers", "letters", "negDoubles")

val results = df.stat.freqItems(Array("numbers", "letters"), 0.1)
val items = results.collect().head
items.getSeq[Int](0) should contain (1)
items.getSeq[String](1) should contain (toLetter(1))

val singleColResults = df.stat.freqItems(Array("negDoubles"), 0.1)
val items2 = singleColResults.collect().head
items2.getSeq[Double](0) should contain (-1.0)
}

test("pearson correlation") {
val df = Seq.tabulate(10)(i => (i, 2 * i, i * -1.0)).toDF("a", "b", "c")
@@ -76,7 +59,43 @@ class DataFrameStatSuite extends FunSuite {
    intercept[IllegalArgumentException] {
      df.stat.cov("singles", "letters") // doesn't accept non-numerical dataTypes
    }
    val decimalData = Seq.tabulate(6)(i => (BigDecimal(i % 3), BigDecimal(i % 2))).toDF("a", "b")
    val decimalRes = decimalData.stat.cov("a", "b")
    assert(math.abs(decimalRes) < 1e-12)
  }

test("crosstab") {
val df = Seq((0, 0), (2, 1), (1, 0), (2, 0), (0, 0), (2, 0)).toDF("a", "b")
val crosstab = df.stat.crosstab("a", "b")
val columnNames = crosstab.schema.fieldNames
assert(columnNames(0) === "a_b")
assert(columnNames(1) === "0")
assert(columnNames(2) === "1")
val rows: Array[Row] = crosstab.collect().sortBy(_.getString(0))
assert(rows(0).get(0).toString === "0")
assert(rows(0).getLong(1) === 2L)
assert(rows(0).get(2) === null)
assert(rows(1).get(0).toString === "1")
assert(rows(1).getLong(1) === 1L)
assert(rows(1).get(2) === null)
assert(rows(2).get(0).toString === "2")
assert(rows(2).getLong(1) === 2L)
assert(rows(2).getLong(2) === 1L)
}

test("Frequent Items") {
val rows = Seq.tabulate(1000) { i =>
if (i % 3 == 0) (1, toLetter(1), -1.0) else (i, toLetter(i), i * -1.0)
}
val df = rows.toDF("numbers", "letters", "negDoubles")

val results = df.stat.freqItems(Array("numbers", "letters"), 0.1)
val items = results.collect().head
items.getSeq[Int](0) should contain (1)
items.getSeq[String](1) should contain (toLetter(1))

val singleColResults = df.stat.freqItems(Array("negDoubles"), 0.1)
val items2 = singleColResults.collect().head
items2.getSeq[Double](0) should contain (-1.0)
}
}
