New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[SPARK-7240][SQL] Single pass covariance calculation for dataframes #5825
Changes from 3 commits
408cb77
7dc6dbc
a7115f1
e3b0b85
4e97a50
aa2ad29
8456eca
0c6a759
51e39b8
f2e862b
cb18046
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -34,7 +34,8 @@ | |
from pyspark.sql.types import _create_cls, _parse_datatype_json_string | ||
|
||
|
||
__all__ = ["DataFrame", "GroupedData", "Column", "SchemaRDD", "DataFrameNaFunctions"] | ||
__all__ = ["DataFrame", "GroupedData", "Column", "SchemaRDD", "DataFrameNaFunctions", | ||
"DataFrameStatFunctions"] | ||
|
||
|
||
class DataFrame(object): | ||
|
@@ -93,6 +94,12 @@ def na(self): | |
""" | ||
return DataFrameNaFunctions(self) | ||
|
||
@property
def stat(self):
    """Access statistic functions for this :class:`DataFrame`.

    :return: a :class:`DataFrameStatFunctions` wrapper bound to this frame
    """
    return DataFrameStatFunctions(self)
|
||
@ignore_unicode_prefix | ||
def toJSON(self, use_unicode=True): | ||
"""Converts a :class:`DataFrame` into a :class:`RDD` of string. | ||
|
@@ -868,6 +875,17 @@ def fillna(self, value, subset=None): | |
|
||
return DataFrame(self._jdf.na().fill(value, self._jseq(subset)), self.sql_ctx) | ||
|
||
def cov(self, col1, col2):
    """Calculate the covariance for the given columns, specified by their names.

    Alias for ``stat.cov()``.

    :param col1: the name of the first column
    :param col2: the name of the second column
    :return: the covariance of the two columns
    """
    stats = self.stat
    return stats.cov(col1, col2)
|
||
@ignore_unicode_prefix | ||
def withColumn(self, colName, col): | ||
"""Returns a new :class:`DataFrame` by adding a column. | ||
|
@@ -1311,6 +1329,28 @@ def fill(self, value, subset=None): | |
fill.__doc__ = DataFrame.fillna.__doc__ | ||
|
||
|
||
class DataFrameStatFunctions(object):
    """Functionality for statistic functions with :class:`DataFrame`."""

    def __init__(self, df):
        # The DataFrame these statistic functions operate on.
        self.df = df

    def cov(self, col1, col2):
        """Calculate the covariance for the given columns, specified by their names.

        :param col1: the name of the first column
        :param col2: the name of the second column
        :return: the covariance of the two columns
        """
        # Validate both column names before touching the JVM gateway.
        for arg_name, value in (("col1", col1), ("col2", col2)):
            if not isinstance(value, str):
                raise ValueError("%s should be a string." % arg_name)
        return self.df._jdf.stat().cov(col1, col2)
|
||
|
||
def _test(): | ||
import doctest | ||
from pyspark.context import SparkContext | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,81 @@ | ||
/* | ||
* Licensed to the Apache Software Foundation (ASF) under one or more | ||
* contributor license agreements. See the NOTICE file distributed with | ||
* this work for additional information regarding copyright ownership. | ||
* The ASF licenses this file to You under the Apache License, Version 2.0 | ||
* (the "License"); you may not use this file except in compliance with | ||
* the License. You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
package org.apache.spark.sql.execution.stat | ||
|
||
import org.apache.spark.sql.{Column, DataFrame} | ||
import org.apache.spark.sql.types.NumericType | ||
|
||
private[sql] object StatFunctions {

  /** Helper class to simplify tracking and merging counts. */
  private class CovarianceCounter extends Serializable {
    var xAvg = 0.0 // running mean of the first column
    var yAvg = 0.0 // running mean of the second column
    var Ck = 0.0   // running co-moment sum: sum_i (x_i - xAvg) * (y_i - yAvg)
    var count = 0L // number of observed examples; Long so `count * other.count` style
                   // products and very large datasets cannot overflow Int

    /**
     * Add a single (x, y) observation, using the numerically stable online update.
     * See https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Covariance
     */
    def add(x: Number, y: Number): this.type = {
      val oldX = xAvg
      val otherX = x.doubleValue()
      val otherY = y.doubleValue()
      count += 1
      xAvg += (otherX - xAvg) / count
      yAvg += (otherY - yAvg) / count
      Ck += (otherY - yAvg) * (otherX - oldX)
      this
    }

    /**
     * Merge counters from other partitions, using the pairwise update formula from
     * https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Covariance
     */
    def merge(other: CovarianceCounter): this.type = {
      val totalCount = count + other.count
      // Evaluate left-to-right so the expression stays in Double space:
      // (Double * count) / totalCount * other.count never forms the Long
      // product count * other.count, which could overflow.
      Ck += other.Ck +
        (xAvg - other.xAvg) * (yAvg - other.yAvg) * count / totalCount * other.count
      xAvg = (xAvg * count + other.xAvg * other.count) / totalCount
      yAvg = (yAvg * count + other.yAvg * other.count) / totalCount
      count = totalCount
      this
    }

    // Return the covariance for the observed examples.
    // NOTE(review): this is the population covariance (divides by n). The reviewer
    // suggested sample covariance (n - 1, matching R); the visible test suite still
    // expects the n divisor, so confirm and update both together before changing.
    def cov: Double = Ck / count
  }

  /**
   * Calculate the covariance of two numerical columns of a DataFrame.
   * @param df The DataFrame
   * @param cols the column names; exactly two numeric columns are required
   * @return the covariance of the two columns.
   */
  private[sql] def calculateCov(df: DataFrame, cols: Seq[String]): Double = {
    require(cols.length == 2, "Currently cov supports calculating the covariance " +
      "between two columns.")
    // Validate up front that both columns exist and are numeric.
    cols.map(name => (name, df.schema.fields.find(_.name == name))).foreach { case (name, data) =>
      require(data.nonEmpty, s"Couldn't find column with name $name")
      require(data.get.dataType.isInstanceOf[NumericType], "Covariance calculation for columns " +
        s"with dataType ${data.get.dataType} not supported.")
    }
    // Single pass over the data: per-partition online updates, then a tree of merges.
    val counts = df.select(cols.map(Column(_)): _*).rdd.aggregate(new CovarianceCounter)(
      seqOp = (counter, row) => counter.add(row.getAs[Number](0), row.getAs[Number](1)),
      combOp = (baseCounter, other) => baseCounter.merge(other))
    counts.cov
  }
}
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -25,10 +25,11 @@ import org.apache.spark.sql.test.TestSQLContext.implicits._ | |
|
||
class DataFrameStatSuite extends FunSuite { | ||
|
||
import TestData._ | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. i think we prefer having data defined inline now since it is easier to read. those were there historically because it was hard to create schemardds. |
||
val sqlCtx = TestSQLContext | ||
|
||
def toLetter(i: Int): String = (i + 97).toChar.toString | ||
|
||
test("Frequent Items") { | ||
def toLetter(i: Int): String = (i + 96).toChar.toString | ||
val rows = Array.tabulate(1000) { i => | ||
if (i % 3 == 0) (1, toLetter(1), -1.0) else (i, toLetter(i), i * -1.0) | ||
} | ||
|
@@ -44,4 +45,19 @@ class DataFrameStatSuite extends FunSuite { | |
items2.getSeq[Double](0) should contain (-1.0) | ||
|
||
} | ||
|
||
test("covariance") {
  // 10 rows where doubles == 2 * singles, so cov(singles, doubles) = 2 * var(singles).
  val rows = Array.tabulate(10)(i => (i, 2.0 * i, toLetter(i)))
  val df = sqlCtx.sparkContext.parallelize(rows).toDF("singles", "doubles", "letters")

  val results = df.stat.cov("singles", "doubles")
  // Population covariance of (0..9, 0..18 by 2): 2 * 8.25 = 16.5.
  assert(math.abs(results - 16.5) < 1e-6)
  intercept[IllegalArgumentException] {
    df.stat.cov("singles", "letters") // doesn't accept non-numerical dataTypes
  }
  // Decimal columns a, b are constructed so their covariance is zero.
  val decimalRes = decimalData.stat.cov("a", "b")
  assert(math.abs(decimalRes) < 1e-6)
}
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Does `:return:` work in Python docs? Can you try compiling it? Also, say something about the return type — maybe "covariance of the two columns as a double value".
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I guess not. IntelliJ generated them automatically... Thought I would give it a shot.