[SPARK-12818][SQL] Specializes integral and string types for Count-min Sketch #10968
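The diff below replaces the generic `add(Object, long)` path with type-specialized calls: string columns feed the sketch with the raw bytes of their `UTF8String` values, and integral columns (byte, short, int, long) go through a new `addLong`. For orientation, here is a hypothetical usage sketch of the DataFrame-side API (data, names and parameters are made up, and the `countMinSketch(colName, depth, width, seed)` overload is assumed):

```scala
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().master("local[*]").appName("cms-usage").getOrCreate()
import spark.implicits._

val df = Seq(("apple", 1L), ("banana", 2L), ("apple", 2L)).toDF("word", "id")

// String column: rows reach the sketch as raw UTF-8 bytes via addBinary.
val wordSketch = df.stat.countMinSketch("word", 10, 1000, 42)  // depth, width, seed
println(wordSketch.estimateCount("apple"))  // at least 2

// Integral column: rows reach the sketch through the specialized addLong.
val idSketch = df.stat.countMinSketch("id", 10, 1000, 42)
println(idSketch.estimateCount(2L))  // at least 2

spark.stop()
```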

Closed · wants to merge 2 commits · Changes from 1 commit
@@ -124,6 +124,18 @@ int getVersionNumber() {
*/
public abstract void add(Object item, long count);

public abstract void addLong(long item);
Reviewer comment (Contributor):
Need to add Javadoc.

Also update the other Javadoc to say "Increment item's count by one." or "Increment item's count by count".


public abstract void addLong(long item, long count);

public abstract void addString(String item);

public abstract void addString(String item, long count);

public abstract void addBinary(byte[] item);

public abstract void addBinary(byte[] item, long count);

/**
* Returns the estimated frequency of {@code item}.
*/
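The hunk above (the abstract Count-min Sketch API) gains type-specialized mutators next to the existing `add(Object, long)`. A minimal usage sketch, assuming the `CountMinSketch.create(depth, width, seed)` factory and that the single-argument overloads increment the count by one:

```scala
import org.apache.spark.util.sketch.CountMinSketch

val cms = CountMinSketch.create(10, 1000, 42)  // depth, width, seed

cms.addString("spark")               // same as cms.addString("spark", 1)
cms.addString("spark", 4)            // increment the count of "spark" by 4
cms.addLong(12345L)                  // specialized path for integral values
cms.addBinary(Array[Byte](1, 2, 3))  // raw bytes, e.g. the contents of a UTF8String

println(cms.estimateCount("spark"))  // at least 5 (Count-min never under-counts)
```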
@@ -25,7 +25,6 @@
import java.io.ObjectOutputStream;
import java.io.OutputStream;
import java.io.Serializable;
import java.io.UnsupportedEncodingException;
import java.util.Arrays;
import java.util.Random;

@@ -146,27 +145,49 @@ public void add(Object item, long count) {
}
}

private void addString(String item, long count) {
@Override
public void addString(String item) {
addString(item, 1);
}

@Override
public void addString(String item, long count) {
addBinary(Utils.getBytesFromUTF8String(item), count);
}

@Override
public void addLong(long item) {
addLong(item, 1);
}

@Override
public void addLong(long item, long count) {
if (count < 0) {
throw new IllegalArgumentException("Negative increments not implemented");
}

int[] buckets = getHashBuckets(item, depth, width);

for (int i = 0; i < depth; ++i) {
table[i][buckets[i]] += count;
table[i][hash(item, i)] += count;
}

totalCount += count;
}

private void addLong(long item, long count) {
@Override
public void addBinary(byte[] item) {
addBinary(item, 1);
}

@Override
public void addBinary(byte[] item, long count) {
if (count < 0) {
throw new IllegalArgumentException("Negative increments not implemented");
}

int[] buckets = getHashBuckets(item, depth, width);

for (int i = 0; i < depth; ++i) {
table[i][hash(item, i)] += count;
table[i][buckets[i]] += count;
}

totalCount += count;
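For intuition about the implementation above: a count-min sketch keeps a `depth` × `width` table of counters, bumps one counter per row using a per-row hash (`hash(item, i)` for longs, `getHashBuckets` for byte arrays), and answers frequency queries with the minimum counter across rows, so it can over-count but never under-count. The toy version below illustrates that update rule; it is only a sketch for illustration, not Spark's implementation:

```scala
import scala.util.hashing.MurmurHash3

class ToyCountMinSketch(depth: Int, width: Int) {
  private val table = Array.ofDim[Long](depth, width)

  // One hash function per row; the row index doubles as the hash seed.
  private def bucket(item: Array[Byte], row: Int): Int = {
    val h = MurmurHash3.bytesHash(item, row)
    ((h % width) + width) % width  // map to a non-negative bucket index
  }

  def add(item: Array[Byte], count: Long = 1L): Unit = {
    require(count >= 0, "Negative increments not implemented")
    for (row <- 0 until depth) table(row)(bucket(item, row)) += count
  }

  def estimateCount(item: Array[Byte]): Long =
    (0 until depth).map(row => table(row)(bucket(item, row))).min
}

val toy = new ToyCountMinSketch(depth = 5, width = 64)
toy.add("spark".getBytes("UTF-8"), 3)
toy.add("flink".getBytes("UTF-8"))
println(toy.estimateCount("spark".getBytes("UTF-8")))  // at least 3
```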
@@ -24,7 +24,7 @@ import scala.collection.JavaConverters._
import org.apache.spark.annotation.Experimental
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.execution.stat._
import org.apache.spark.sql.types.{IntegralType, StringType}
import org.apache.spark.sql.types._
import org.apache.spark.util.sketch.{BloomFilter, CountMinSketch}

/**
@@ -38,6 +38,7 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) {

/**
* Calculate the sample covariance of two numerical columns of a DataFrame.
*
* @param col1 the name of the first column
* @param col2 the name of the second column
* @return the covariance of the two columns.
@@ -48,7 +49,6 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) {
* df.stat.cov("rand1", "rand2")
* res1: Double = 0.065...
* }}}
*
* @since 1.4.0
*/
def cov(col1: String, col2: String): Double = {
@@ -70,7 +70,6 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) {
* df.stat.corr("rand1", "rand2")
* res1: Double = 0.613...
* }}}
*
* @since 1.4.0
*/
def corr(col1: String, col2: String, method: String): Double = {
@@ -92,7 +91,6 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) {
* df.stat.corr("rand1", "rand2", "pearson")
* res1: Double = 0.613...
* }}}
*
* @since 1.4.0
*/
def corr(col1: String, col2: String): Double = {
@@ -109,7 +107,6 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) {
* Null elements will be replaced by "null", and back ticks will be dropped from elements if they
* exist.
*
*
* @param col1 The name of the first column. Distinct items will make the first item of
* each row.
* @param col2 The name of the second column. Distinct items will make the column names
@@ -129,7 +126,6 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) {
* | 3| 0| 1| 1|
* +---------+---+---+---+
* }}}
*
* @since 1.4.0
*/
def crosstab(col1: String, col2: String): DataFrame = {
@@ -175,7 +171,6 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) {
* | ... |
* +----------+
* }}}
*
* @since 1.4.0
*/
def freqItems(cols: Array[String], support: Double): DataFrame = {
@@ -193,7 +188,6 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) {
*
* @param cols the names of the columns to search frequent items in.
* @return A Local DataFrame with the Array of frequent items for each column.
*
* @since 1.4.0
*/
def freqItems(cols: Array[String]): DataFrame = {
@@ -236,7 +230,6 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) {
* | ... |
* +----------+
* }}}
*
* @since 1.4.0
*/
def freqItems(cols: Seq[String], support: Double): DataFrame = {
@@ -254,7 +247,6 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) {
*
* @param cols the names of the columns to search frequent items in.
* @return A Local DataFrame with the Array of frequent items for each column.
*
* @since 1.4.0
*/
def freqItems(cols: Seq[String]): DataFrame = {
@@ -263,6 +255,7 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) {

/**
* Returns a stratified sample without replacement based on the fraction given on each stratum.
*
* @param col column that defines strata
* @param fractions sampling fraction for each stratum. If a stratum is not specified, we treat
* its fraction as zero.
@@ -283,7 +276,6 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) {
* | 3| 2|
* +---+-----+
* }}}
*
* @since 1.5.0
*/
def sampleBy[T](col: String, fractions: Map[T, Double], seed: Long): DataFrame = {
@@ -300,13 +292,13 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) {

/**
* Returns a stratified sample without replacement based on the fraction given on each stratum.
*
* @param col column that defines strata
* @param fractions sampling fraction for each stratum. If a stratum is not specified, we treat
* its fraction as zero.
* @param seed random seed
* @tparam T stratum type
* @return a new [[DataFrame]] that represents the stratified sample
*
* @since 1.5.0
*/
def sampleBy[T](col: String, fractions: ju.Map[T, jl.Double], seed: Long): DataFrame = {
@@ -374,21 +366,27 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) {
val singleCol = df.select(col)
val colType = singleCol.schema.head.dataType

require(
colType == StringType || colType.isInstanceOf[IntegralType],
s"Count-min Sketch only supports string type and integral types, " +
s"and does not support type $colType."
)
val updater: (CountMinSketch, InternalRow) => Unit = colType match {
// For string type, we can get bytes of our `UTF8String` directly, and call the `addBinary`
// instead of `addString` to avoid unnecessary conversion.
case StringType => (sketch, row) => sketch.addBinary(row.getUTF8String(0).getBytes)
case ByteType => (sketch, row) => sketch.addLong(row.getByte(0))
case ShortType => (sketch, row) => sketch.addLong(row.getShort(0))
case IntegerType => (sketch, row) => sketch.addLong(row.getInt(0))
case LongType => (sketch, row) => sketch.addLong(row.getLong(0))
case _ =>
throw new IllegalArgumentException(
s"Count-min Sketch only supports string type and integral types, " +
s"and does not support type $colType."
)
}

singleCol.rdd.aggregate(zero)(
(sketch: CountMinSketch, row: Row) => {
sketch.add(row.get(0))
singleCol.queryExecution.toRdd.aggregate(zero)(
(sketch: CountMinSketch, row: InternalRow) => {
updater(sketch, row)
sketch
},

(sketch1: CountMinSketch, sketch2: CountMinSketch) => {
sketch1.mergeInPlace(sketch2)
}
_ mergeInPlace _
Reviewer comment (Contributor):
no infix notation

)
}
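The rewritten aggregation above builds one sketch per partition with the type-specialized `updater` and merges the partial sketches with `mergeInPlace`. The same shape against a plain RDD of longs, as a hedged sketch (session setup and data are made up); the combine step is written without infix notation, as the review comment above requests:

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.util.sketch.CountMinSketch

val spark = SparkSession.builder().master("local[*]").appName("cms-aggregate").getOrCreate()
val values = spark.sparkContext.parallelize(Seq(1L, 2L, 2L, 3L, 3L, 3L))

val zero = CountMinSketch.create(10, 1000, 42)
val merged = values.aggregate(zero)(
  (sketch, v) => { sketch.addLong(v); sketch },  // per-partition update
  (s1, s2) => s1.mergeInPlace(s2)                // combine partial sketches
)
println(merged.estimateCount(3L))  // at least 3
spark.stop()
```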

@@ -447,19 +445,27 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) {
require(colType == StringType || colType.isInstanceOf[IntegralType],
s"Bloom filter only supports string type and integral types, but got $colType.")

val seqOp: (BloomFilter, InternalRow) => BloomFilter = if (colType == StringType) {
(filter, row) =>
// For string type, we can get bytes of our `UTF8String` directly, and call the `putBinary`
// instead of `putString` to avoid unnecessary conversion.
filter.putBinary(row.getUTF8String(0).getBytes)
filter
} else {
(filter, row) =>
// TODO: specialize it.
filter.putLong(row.get(0, colType).asInstanceOf[Number].longValue())
filter
val updater: (BloomFilter, InternalRow) => Unit = colType match {
// For string type, we can get bytes of our `UTF8String` directly, and call the `putBinary`
// instead of `putString` to avoid unnecessary conversion.
case StringType => (filter, row) => filter.putBinary(row.getUTF8String(0).getBytes)
case ByteType => (filter, row) => filter.putLong(row.getByte(0))
case ShortType => (filter, row) => filter.putLong(row.getShort(0))
case IntegerType => (filter, row) => filter.putLong(row.getInt(0))
case LongType => (filter, row) => filter.putLong(row.getLong(0))
case _ =>
throw new IllegalArgumentException(
s"Bloom filter only supports string type and integral types, " +
s"and does not support type $colType."
)
}

singleCol.queryExecution.toRdd.aggregate(zero)(seqOp, _ mergeInPlace _)
singleCol.queryExecution.toRdd.aggregate(zero)(
(filter: BloomFilter, row: InternalRow) => {
updater(filter, row)
filter
},
_ mergeInPlace _
)
}
}
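The Bloom filter path gets the same per-type dispatch. A hypothetical end-to-end usage sketch (data and parameters are made up), assuming a `bloomFilter(colName, expectedNumItems, fpp)` overload on the stat functions:

```scala
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().master("local[*]").appName("bf-usage").getOrCreate()
import spark.implicits._

val words = Seq("apple", "banana", "apple", "cherry").toDF("word")
val bf = words.stat.bloomFilter("word", 1000L, 0.03)  // expected items, false-positive rate

println(bf.mightContain("apple"))   // true (Bloom filters have no false negatives)
println(bf.mightContain("durian"))  // usually false; false positives are possible
spark.stop()
```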