[SPARK-7243][SQL] Contingency Tables for DataFrames #5842

Closed · wants to merge 11 commits
24 changes: 24 additions & 0 deletions python/pyspark/sql/dataframe.py
@@ -910,6 +910,25 @@ def cov(self, col1, col2):
raise ValueError("col2 should be a string.")
return self._jdf.stat().cov(col1, col2)

def crosstab(self, col1, col2):
"""
Computes a pair-wise frequency table of the given columns. Also known as a contingency
table. The number of distinct values for each column should be less than 1e4. The first
column of each row will be the distinct values of `col1` and the column names will be the
distinct values of `col2`. Pairs that have no occurrences will have `null` as their counts.

Contributor: Document the first column name. 1e5 -> 1e4.

:func:`DataFrame.crosstab` and :func:`DataFrameStatFunctions.crosstab` are aliases.

:param col1: The name of the first column. Distinct items will make the first item of
each row.
:param col2: The name of the second column. Distinct items will make the column names
of the DataFrame.
"""
if not isinstance(col1, str):
raise ValueError("col1 should be a string.")
if not isinstance(col2, str):
raise ValueError("col2 should be a string.")
return DataFrame(self._jdf.stat().crosstab(col1, col2), self.sql_ctx)

def freqItems(self, cols, support=None):
"""
Finding frequent items for columns, possibly with false positives. Using the
@@ -1390,6 +1409,11 @@ def cov(self, col1, col2):

cov.__doc__ = DataFrame.cov.__doc__

def crosstab(self, col1, col2):
return self.df.crosstab(col1, col2)

crosstab.__doc__ = DataFrame.crosstab.__doc__

def freqItems(self, cols, support=None):
return self.df.freqItems(cols, support)

9 changes: 9 additions & 0 deletions python/pyspark/sql/tests.py
@@ -405,6 +405,15 @@ def test_cov(self):
cov = df.stat.cov("a", "b")
self.assertTrue(abs(cov - 55.0 / 3) < 1e-6)

def test_crosstab(self):
df = self.sc.parallelize([Row(a=i % 3, b=i % 2) for i in range(1, 7)]).toDF()
ct = df.stat.crosstab("a", "b").collect()
ct = sorted(ct, key=lambda x: x[0])
for i, row in enumerate(ct):
self.assertEqual(row[0], str(i))
self.assertEqual(row[1], 1)
self.assertEqual(row[2], 1)

def test_math_functions(self):
df = self.sc.parallelize([Row(a=i, b=2 * i) for i in range(10)]).toDF()
from pyspark.sql import mathfunctions as functions
sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala
@@ -28,6 +28,16 @@ import org.apache.spark.sql.execution.stat._
final class DataFrameStatFunctions private[sql](df: DataFrame) {

/**
* Calculate the sample covariance of two numerical columns of a DataFrame.
* @param col1 the name of the first column
* @param col2 the name of the second column
* @return the covariance of the two columns.
*/
def cov(col1: String, col2: String): Double = {
StatFunctions.calculateCov(df, Seq(col1, col2))
}
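
For reference, the `cov` exposed here is the textbook sample covariance. A minimal plain-Scala sketch of the same quantity (an illustration only, not the distributed implementation behind `StatFunctions.calculateCov`):

  // Sketch: sample covariance of two equal-length sequences,
  // sum((x - mean(x)) * (y - mean(y))) / (n - 1).
  def sampleCov(xs: Seq[Double], ys: Seq[Double]): Double = {
    val n = xs.size
    val meanX = xs.sum / n
    val meanY = ys.sum / n
    xs.zip(ys).map { case (x, y) => (x - meanX) * (y - meanY) }.sum / (n - 1)
  }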

/**
* Calculates the correlation of two columns of a DataFrame. Currently only supports the Pearson
* Correlation Coefficient. For Spearman Correlation, consider using RDD methods found in
* MLlib's Statistics.
@@ -53,6 +63,23 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) {
corr(col1, col2, "pearson")
}

/**
* Computes a pair-wise frequency table of the given columns. Also known as a contingency table.
* The number of distinct values for each column should be less than 1e4. The first
* column of each row will be the distinct values of `col1` and the column names will be the
* distinct values of `col2`. Counts will be returned as `Long`s. Pairs that have no occurrences
* will have `null` as their counts.
*
* @param col1 The name of the first column. Distinct items will make the first item of
* each row.
* @param col2 The name of the second column. Distinct items will make the column names
* of the DataFrame.
* @return A local DataFrame containing the table
*/
def crosstab(col1: String, col2: String): DataFrame = {
StatFunctions.crossTabulate(df, col1, col2)
}
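
As a usage illustration (a sketch, not part of the patch; assumes `import sqlContext.implicits._` for `toDF`, and reuses the data from the "crosstab" test further down):

  val df = Seq((0, 0), (2, 1), (1, 0), (2, 0), (0, 0), (2, 0)).toDF("a", "b")
  val ct = df.stat.crosstab("a", "b")
  // First column "a_b" holds the distinct values of `a` as strings; the
  // remaining Long columns are named after the distinct values of `b`:
  //   a_b | 0 | 1
  //   "0" | 2 | null
  //   "1" | 1 | null
  //   "2" | 2 | 1
  ct.show()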

/**
* Finding frequent items for columns, possibly with false positives. Using the
* frequent element count algorithm described in
@@ -94,14 +121,4 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) {
def freqItems(cols: Seq[String]): DataFrame = {
FrequentItems.singlePassFreqItems(df, cols, 0.01)
}

/**
* Calculate the sample covariance of two numerical columns of a DataFrame.
* @param col1 the name of the first column
* @param col2 the name of the second column
* @return the covariance of the two columns.
*/
def cov(col1: String, col2: String): Double = {
StatFunctions.calculateCov(df, Seq(col1, col2))
}
}
sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala
@@ -17,9 +17,11 @@

package org.apache.spark.sql.execution.stat

import org.apache.spark.sql.catalyst.expressions.Cast
import org.apache.spark.sql.{Column, DataFrame}
import org.apache.spark.sql.types.{DoubleType, NumericType}
import org.apache.spark.sql.{Column, DataFrame, Row}
import org.apache.spark.sql.catalyst.expressions.{GenericMutableRow, Cast}
import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types._

private[sql] object StatFunctions {

@@ -95,4 +97,28 @@ private[sql] object StatFunctions {
val counts = collectStatisticalData(df, cols)
counts.cov
}

/** Generate a table of frequencies for the elements of two columns. */
private[sql] def crossTabulate(df: DataFrame, col1: String, col2: String): DataFrame = {
val tableName = s"${col1}_$col2"
Contributor: In my previous comment, I meant that this tableName is not documented. Users need to know the name of the first column in order to operate on it.

minor: It would be good to check pandas' or R's naming for this column and follow one of them.

Contributor (author): Pandas and R have the concept of row names, which we currently don't have. We have to use the first column as the "row names".

val counts = df.groupBy(col1, col2).agg(col(col1), col(col2), count("*")).take(1e8.toInt)
Contributor: Check the size of counts. If it is 1e8, throw a warning.

Contributor: @brkyvz can you submit a follow-up PR to reduce 1e8 to 1e6? 1e8 is too large.

// get the distinct values of column 2, so that we can make them the column names
val distinctCol2 = counts.map(_.get(1)).distinct.zipWithIndex.toMap
val columnSize = distinctCol2.size
require(columnSize < 1e4, s"The number of distinct values for $col2 can't " +
s"exceed 1e4. Currently $columnSize")
val table = counts.groupBy(_.get(0)).map { case (col1Items, rows) =>
Contributor: col1Items -> col1Item? There is only one item from col1 per record.

val countsRow = new GenericMutableRow(columnSize + 1)
rows.foreach { row =>
countsRow.setLong(distinctCol2.get(row.get(1)).get + 1, row.getLong(2))
Contributor: distinctCol2.get(row.get(1)).get -> distinctCol2(row.get(1))

Contributor (author): distinctCol2(row.get(1)) is an Option. I need the value, don't I?
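
For reference, a short sketch of the `Map` semantics in question: `Map#get` returns an `Option`, while `Map#apply` returns the value itself, so the two spellings in this thread fetch the same value, and both fail on a missing key:

  val m = Map("x" -> 1L)
  m.get("x")      // Some(1): Map#get returns an Option[Long]
  m.get("x").get  // 1: unwraps the Option; throws NoSuchElementException if absent
  m("x")          // 1: Map#apply returns the value directly; likewise throws if absent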

}
// the value of col1 is the first value, the rest are the counts
countsRow.setString(0, col1Items.toString)
countsRow
}.toSeq
val headerNames = distinctCol2.map(r => StructField(r._1.toString, LongType)).toSeq
val schema = StructType(StructField(tableName, StringType) +: headerNames)

new DataFrame(df.sqlContext, LocalRelation(schema.toAttributes, table))
}
}
sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java
@@ -34,6 +34,7 @@

import java.io.Serializable;
import java.util.Arrays;
import java.util.Comparator;
import java.util.List;
import java.util.Map;

@@ -178,6 +179,33 @@ public void testCreateDataFrameFromJavaBeans() {
Assert.assertEquals(bean.getD().get(i), d.apply(i));
}
}

private static final Comparator<Row> crosstabRowComparator = new Comparator<Row>() {
@Override
public int compare(Row row1, Row row2) {
String item1 = row1.getString(0);
String item2 = row2.getString(0);
return item1.compareTo(item2);
}
};

@Test
public void testCrosstab() {
DataFrame df = context.table("testData2");
DataFrame crosstab = df.stat().crosstab("a", "b");
String[] columnNames = crosstab.schema().fieldNames();
Assert.assertEquals(columnNames[0], "a_b");
Assert.assertEquals(columnNames[1], "1");
Assert.assertEquals(columnNames[2], "2");
Row[] rows = crosstab.collect();
Arrays.sort(rows, crosstabRowComparator);
Integer count = 1;
for (Row row : rows) {
Assert.assertEquals(row.get(0).toString(), count.toString());
Assert.assertEquals(row.getLong(1), 1L);
Assert.assertEquals(row.getLong(2), 1L);
count++;
}
}

@Test
public void testFrequentItems() {
sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala
@@ -24,26 +24,9 @@ import org.apache.spark.sql.test.TestSQLContext
import org.apache.spark.sql.test.TestSQLContext.implicits._

class DataFrameStatSuite extends FunSuite {

import TestData._

val sqlCtx = TestSQLContext
def toLetter(i: Int): String = (i + 97).toChar.toString

test("Frequent Items") {
val rows = Seq.tabulate(1000) { i =>
if (i % 3 == 0) (1, toLetter(1), -1.0) else (i, toLetter(i), i * -1.0)
}
val df = rows.toDF("numbers", "letters", "negDoubles")

val results = df.stat.freqItems(Array("numbers", "letters"), 0.1)
val items = results.collect().head
items.getSeq[Int](0) should contain (1)
items.getSeq[String](1) should contain (toLetter(1))

val singleColResults = df.stat.freqItems(Array("negDoubles"), 0.1)
val items2 = singleColResults.collect().head
items2.getSeq[Double](0) should contain (-1.0)
}

test("pearson correlation") {
val df = Seq.tabulate(10)(i => (i, 2 * i, i * -1.0)).toDF("a", "b", "c")
@@ -76,7 +59,43 @@ class DataFrameStatSuite extends FunSuite {
intercept[IllegalArgumentException] {
df.stat.cov("singles", "letters") // doesn't accept non-numerical dataTypes
}
val decimalData = Seq.tabulate(6)(i => (BigDecimal(i % 3), BigDecimal(i % 2))).toDF("a", "b")
val decimalRes = decimalData.stat.cov("a", "b")
assert(math.abs(decimalRes) < 1e-12)
}

test("crosstab") {
val df = Seq((0, 0), (2, 1), (1, 0), (2, 0), (0, 0), (2, 0)).toDF("a", "b")
val crosstab = df.stat.crosstab("a", "b")
val columnNames = crosstab.schema.fieldNames
assert(columnNames(0) === "a_b")
assert(columnNames(1) === "0")
assert(columnNames(2) === "1")
val rows: Array[Row] = crosstab.collect().sortBy(_.getString(0))
assert(rows(0).get(0).toString === "0")
assert(rows(0).getLong(1) === 2L)
assert(rows(0).get(2) === null)
assert(rows(1).get(0).toString === "1")
assert(rows(1).getLong(1) === 1L)
assert(rows(1).get(2) === null)
assert(rows(2).get(0).toString === "2")
assert(rows(2).getLong(1) === 2L)
assert(rows(2).getLong(2) === 1L)
}

test("Frequent Items") {
val rows = Seq.tabulate(1000) { i =>
if (i % 3 == 0) (1, toLetter(1), -1.0) else (i, toLetter(i), i * -1.0)
}
val df = rows.toDF("numbers", "letters", "negDoubles")

val results = df.stat.freqItems(Array("numbers", "letters"), 0.1)
val items = results.collect().head
items.getSeq[Int](0) should contain (1)
items.getSeq[String](1) should contain (toLetter(1))

val singleColResults = df.stat.freqItems(Array("negDoubles"), 0.1)
val items2 = singleColResults.collect().head
items2.getSeq[Double](0) should contain (-1.0)
}
}