Specialized integral and string types for Count-min Sketch
liancheng committed Jan 28, 2016
1 parent 4a09123 commit 4ad74c0
Showing 3 changed files with 82 additions and 43 deletions.
@@ -124,6 +124,18 @@ int getVersionNumber() {
*/
public abstract void add(Object item, long count);

public abstract void addLong(long item);

public abstract void addLong(long item, long count);

public abstract void addString(String item);

public abstract void addString(String item, long count);

public abstract void addBinary(byte[] item);

public abstract void addBinary(byte[] item, long count);

/**
* Returns the estimated frequency of {@code item}.
*/
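For context, here is a minimal usage sketch of the new specialized interface (not part of the diff); it assumes the existing CountMinSketch.create(depth, width, seed) factory and the estimateCount(Object) accessor from this class:

import org.apache.spark.util.sketch.CountMinSketch

val sketch = CountMinSketch.create(10, 1000, 42)  // depth, width, seed

// The specialized overloads skip the runtime type dispatch in add(Object).
sketch.addLong(42L)                     // increment by the default count of 1
sketch.addLong(42L, 5L)                 // increment by an explicit count
sketch.addString("spark")
sketch.addBinary(Array[Byte](1, 2, 3))

println(sketch.estimateCount(42L))      // at least 6, modulo hash collisions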
@@ -25,7 +25,6 @@
import java.io.ObjectOutputStream;
import java.io.OutputStream;
import java.io.Serializable;
import java.io.UnsupportedEncodingException;
import java.util.Arrays;
import java.util.Random;

@@ -146,27 +145,49 @@ public void add(Object item, long count) {
}
}

private void addString(String item, long count) {
@Override
public void addString(String item) {
addString(item, 1);
}

@Override
public void addString(String item, long count) {
addBinary(Utils.getBytesFromUTF8String(item), count);
}

@Override
public void addLong(long item) {
addLong(item, 1);
}

@Override
public void addLong(long item, long count) {
if (count < 0) {
throw new IllegalArgumentException("Negative increments not implemented");
}

int[] buckets = getHashBuckets(item, depth, width);

for (int i = 0; i < depth; ++i) {
table[i][buckets[i]] += count;
table[i][hash(item, i)] += count;
}

totalCount += count;
}

private void addLong(long item, long count) {
@Override
public void addBinary(byte[] item) {
addBinary(item, 1);
}

@Override
public void addBinary(byte[] item, long count) {
if (count < 0) {
throw new IllegalArgumentException("Negative increments not implemented");
}

int[] buckets = getHashBuckets(item, depth, width);

for (int i = 0; i < depth; ++i) {
table[i][hash(item, i)] += count;
table[i][buckets[i]] += count;
}

totalCount += count;
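To illustrate the delegation above (a sketch, not from the patch): addString now forwards the string's UTF-8 bytes to addBinary, so a string and its UTF-8 bytes should update, and estimate against, the same counters. StandardCharsets.UTF_8 stands in here for Utils.getBytesFromUTF8String.

import java.nio.charset.StandardCharsets
import org.apache.spark.util.sketch.CountMinSketch

val viaString = CountMinSketch.create(10, 1000, 42)
val viaBinary = CountMinSketch.create(10, 1000, 42)

viaString.addString("count-min")
viaBinary.addBinary("count-min".getBytes(StandardCharsets.UTF_8))

// Same depth/width/seed and same underlying bytes, so the two
// estimates are expected to agree.
println(viaString.estimateCount("count-min"))
println(viaBinary.estimateCount("count-min"))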
@@ -24,7 +24,7 @@ import scala.collection.JavaConverters._
import org.apache.spark.annotation.Experimental
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.execution.stat._
import org.apache.spark.sql.types.{IntegralType, StringType}
import org.apache.spark.sql.types._
import org.apache.spark.util.sketch.{BloomFilter, CountMinSketch}

/**
@@ -38,6 +38,7 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) {

/**
* Calculate the sample covariance of two numerical columns of a DataFrame.
*
* @param col1 the name of the first column
* @param col2 the name of the second column
* @return the covariance of the two columns.
@@ -48,7 +49,6 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) {
* df.stat.cov("rand1", "rand2")
* res1: Double = 0.065...
* }}}
*
* @since 1.4.0
*/
def cov(col1: String, col2: String): Double = {
@@ -70,7 +70,6 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) {
* df.stat.corr("rand1", "rand2")
* res1: Double = 0.613...
* }}}
*
* @since 1.4.0
*/
def corr(col1: String, col2: String, method: String): Double = {
@@ -92,7 +91,6 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) {
* df.stat.corr("rand1", "rand2", "pearson")
* res1: Double = 0.613...
* }}}
*
* @since 1.4.0
*/
def corr(col1: String, col2: String): Double = {
@@ -109,7 +107,6 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) {
* Null elements will be replaced by "null", and back ticks will be dropped from elements if they
* exist.
*
*
* @param col1 The name of the first column. Distinct items will make the first item of
* each row.
* @param col2 The name of the second column. Distinct items will make the column names
@@ -129,7 +126,6 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) {
* | 3| 0| 1| 1|
* +---------+---+---+---+
* }}}
*
* @since 1.4.0
*/
def crosstab(col1: String, col2: String): DataFrame = {
@@ -175,7 +171,6 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) {
* | ... |
* +----------+
* }}}
*
* @since 1.4.0
*/
def freqItems(cols: Array[String], support: Double): DataFrame = {
@@ -193,7 +188,6 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) {
*
* @param cols the names of the columns to search frequent items in.
* @return A Local DataFrame with the Array of frequent items for each column.
*
* @since 1.4.0
*/
def freqItems(cols: Array[String]): DataFrame = {
@@ -236,7 +230,6 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) {
* | ... |
* +----------+
* }}}
*
* @since 1.4.0
*/
def freqItems(cols: Seq[String], support: Double): DataFrame = {
@@ -254,7 +247,6 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) {
*
* @param cols the names of the columns to search frequent items in.
* @return A Local DataFrame with the Array of frequent items for each column.
*
* @since 1.4.0
*/
def freqItems(cols: Seq[String]): DataFrame = {
@@ -263,6 +255,7 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) {

/**
* Returns a stratified sample without replacement based on the fraction given on each stratum.
*
* @param col column that defines strata
* @param fractions sampling fraction for each stratum. If a stratum is not specified, we treat
* its fraction as zero.
@@ -283,7 +276,6 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) {
* | 3| 2|
* +---+-----+
* }}}
*
* @since 1.5.0
*/
def sampleBy[T](col: String, fractions: Map[T, Double], seed: Long): DataFrame = {
@@ -300,13 +292,13 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) {

/**
* Returns a stratified sample without replacement based on the fraction given on each stratum.
*
* @param col column that defines strata
* @param fractions sampling fraction for each stratum. If a stratum is not specified, we treat
* its fraction as zero.
* @param seed random seed
* @tparam T stratum type
* @return a new [[DataFrame]] that represents the stratified sample
*
* @since 1.5.0
*/
def sampleBy[T](col: String, fractions: ju.Map[T, jl.Double], seed: Long): DataFrame = {
@@ -374,21 +366,27 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) {
val singleCol = df.select(col)
val colType = singleCol.schema.head.dataType

require(
colType == StringType || colType.isInstanceOf[IntegralType],
s"Count-min Sketch only supports string type and integral types, " +
s"and does not support type $colType."
)
val updater: (CountMinSketch, InternalRow) => Unit = colType match {
// For string type, we can get bytes of our `UTF8String` directly, and call the `addBinary`
// instead of `addString` to avoid unnecessary conversion.
case StringType => (sketch, row) => sketch.addBinary(row.getUTF8String(0).getBytes)
case ByteType => (sketch, row) => sketch.addLong(row.getByte(0))
case ShortType => (sketch, row) => sketch.addLong(row.getShort(0))
case IntegerType => (sketch, row) => sketch.addLong(row.getInt(0))
case LongType => (sketch, row) => sketch.addLong(row.getLong(0))
case _ =>
throw new IllegalArgumentException(
s"Count-min Sketch only supports string type and integral types, " +
s"and does not support type $colType."
)
}

singleCol.rdd.aggregate(zero)(
(sketch: CountMinSketch, row: Row) => {
sketch.add(row.get(0))
singleCol.queryExecution.toRdd.aggregate(zero)(
(sketch: CountMinSketch, row: InternalRow) => {
updater(sketch, row)
sketch
},

(sketch1: CountMinSketch, sketch2: CountMinSketch) => {
sketch1.mergeInPlace(sketch2)
}
_ mergeInPlace _
)
}
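End to end, the specialized updater above is exercised through the public countMinSketch overloads earlier in this file. A minimal usage sketch, assuming the (colName, depth, width, seed) overload and a SQLContext in scope:

// `range` yields a single LongType column named "id", so the updater
// takes the LongType branch of the pattern match above.
val df = sqlContext.range(0, 1000)
val sketch = df.stat.countMinSketch("id", 10, 1000, 42)  // depth, width, seed
println(sketch.estimateCount(100L))  // ~1, since each id occurs once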

@@ -447,19 +445,27 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) {
require(colType == StringType || colType.isInstanceOf[IntegralType],
s"Bloom filter only supports string type and integral types, but got $colType.")

val seqOp: (BloomFilter, InternalRow) => BloomFilter = if (colType == StringType) {
(filter, row) =>
// For string type, we can get bytes of our `UTF8String` directly, and call the `putBinary`
// instead of `putString` to avoid unnecessary conversion.
filter.putBinary(row.getUTF8String(0).getBytes)
filter
} else {
(filter, row) =>
// TODO: specialize it.
filter.putLong(row.get(0, colType).asInstanceOf[Number].longValue())
filter
val updater: (BloomFilter, InternalRow) => Unit = colType match {
// For string type, we can get bytes of our `UTF8String` directly, and call the `putBinary`
// instead of `putString` to avoid unnecessary conversion.
case StringType => (filter, row) => filter.putBinary(row.getUTF8String(0).getBytes)
case ByteType => (filter, row) => filter.putLong(row.getByte(0))
case ShortType => (filter, row) => filter.putLong(row.getShort(0))
case IntegerType => (filter, row) => filter.putLong(row.getInt(0))
case LongType => (filter, row) => filter.putLong(row.getLong(0))
case _ =>
throw new IllegalArgumentException(
s"Bloom filter only supports string type and integral types, " +
s"and does not support type $colType."
)
}

singleCol.queryExecution.toRdd.aggregate(zero)(seqOp, _ mergeInPlace _)
singleCol.queryExecution.toRdd.aggregate(zero)(
(filter: BloomFilter, row: InternalRow) => {
updater(filter, row)
filter
},
_ mergeInPlace _
)
}
}
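The Bloom filter path mirrors this. A minimal usage sketch, assuming the bloomFilter(colName, expectedNumItems, fpp) overload defined earlier in this file and the same SQLContext:

val df = sqlContext.range(0, 1000)
val filter = df.stat.bloomFilter("id", 1000, 0.03)  // expected items, target fpp
assert(filter.mightContain(100L))    // Bloom filters have no false negatives
assert(!filter.mightContain(5000L))  // true with high probability (fpp = 3%)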
