Skip to content

Commit

Permalink
[SPARK-12938][SQL] DataFrame API for Bloom filter
Browse files Browse the repository at this point in the history
This PR integrates Bloom filter from spark-sketch into DataFrame. This version resorts to RDD.aggregate for building the filter. A more performant UDAF version can be built in future follow-up PRs.

This PR also add 2 specify `put` version(`putBinary` and `putLong`) into `BloomFilter`, which makes it easier to build a Bloom filter over a `DataFrame`.

Author: Wenchen Fan <wenchen@databricks.com>

Closes #10937 from cloud-fan/bloom-filter.
  • Loading branch information
cloud-fan authored and rxin committed Jan 27, 2016
1 parent 32f7411 commit 680afab
Show file tree
Hide file tree
Showing 7 changed files with 306 additions and 93 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -49,9 +49,9 @@ public enum Version {
* {@code BloomFilter} binary format version 1 (all values written in big-endian order):
* <ul>
* <li>Version number, always 1 (32 bit)</li>
* <li>Number of hash functions (32 bit)</li>
* <li>Total number of words of the underlying bit array (32 bit)</li>
* <li>The words/longs (numWords * 64 bit)</li>
* <li>Number of hash functions (32 bit)</li>
* </ul>
*/
V1(1);
Expand Down Expand Up @@ -97,6 +97,21 @@ int getVersionNumber() {
*/
public abstract boolean put(Object item);

/**
* A specialized variant of {@link #put(Object)}, that can only be used to put utf-8 string.
*/
public abstract boolean putString(String str);

/**
* A specialized variant of {@link #put(Object)}, that can only be used to put long.
*/
public abstract boolean putLong(long l);

/**
* A specialized variant of {@link #put(Object)}, that can only be used to put byte array.
*/
public abstract boolean putBinary(byte[] bytes);

/**
* Determines whether a given bloom filter is compatible with this bloom filter. For two
* bloom filters to be compatible, they must have the same bit size.
Expand All @@ -121,6 +136,23 @@ int getVersionNumber() {
*/
public abstract boolean mightContain(Object item);

/**
* A specialized variant of {@link #mightContain(Object)}, that can only be used to test utf-8
* string.
*/
public abstract boolean mightContainString(String str);

/**
* A specialized variant of {@link #mightContain(Object)}, that can only be used to test long.
*/
public abstract boolean mightContainLong(long l);

/**
* A specialized variant of {@link #mightContain(Object)}, that can only be used to test byte
* array.
*/
public abstract boolean mightContainBinary(byte[] bytes);

/**
* Writes out this {@link BloomFilter} to an output stream in binary format.
* It is the caller's responsibility to close the stream.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,10 @@

import java.io.*;

public class BloomFilterImpl extends BloomFilter {
public class BloomFilterImpl extends BloomFilter implements Serializable {

private final int numHashFunctions;
private final BitArray bits;
private int numHashFunctions;
private BitArray bits;

BloomFilterImpl(int numHashFunctions, long numBits) {
this(new BitArray(numBits), numHashFunctions);
Expand All @@ -33,6 +33,8 @@ private BloomFilterImpl(BitArray bits, int numHashFunctions) {
this.numHashFunctions = numHashFunctions;
}

private BloomFilterImpl() {}

@Override
public boolean equals(Object other) {
if (other == this) {
Expand Down Expand Up @@ -63,55 +65,75 @@ public long bitSize() {
return bits.bitSize();
}

private static long hashObjectToLong(Object item) {
@Override
public boolean put(Object item) {
if (item instanceof String) {
try {
byte[] bytes = ((String) item).getBytes("utf-8");
return hashBytesToLong(bytes);
} catch (UnsupportedEncodingException e) {
throw new RuntimeException("Only support utf-8 string", e);
}
return putString((String) item);
} else if (item instanceof byte[]) {
return putBinary((byte[]) item);
} else {
long longValue;

if (item instanceof Long) {
longValue = (Long) item;
} else if (item instanceof Integer) {
longValue = ((Integer) item).longValue();
} else if (item instanceof Short) {
longValue = ((Short) item).longValue();
} else if (item instanceof Byte) {
longValue = ((Byte) item).longValue();
} else {
throw new IllegalArgumentException(
"Support for " + item.getClass().getName() + " not implemented"
);
}

int h1 = Murmur3_x86_32.hashLong(longValue, 0);
int h2 = Murmur3_x86_32.hashLong(longValue, h1);
return (((long) h1) << 32) | (h2 & 0xFFFFFFFFL);
return putLong(Utils.integralToLong(item));
}
}

private static long hashBytesToLong(byte[] bytes) {
@Override
public boolean putString(String str) {
return putBinary(Utils.getBytesFromUTF8String(str));
}

@Override
public boolean putBinary(byte[] bytes) {
int h1 = Murmur3_x86_32.hashUnsafeBytes(bytes, Platform.BYTE_ARRAY_OFFSET, bytes.length, 0);
int h2 = Murmur3_x86_32.hashUnsafeBytes(bytes, Platform.BYTE_ARRAY_OFFSET, bytes.length, h1);
return (((long) h1) << 32) | (h2 & 0xFFFFFFFFL);

long bitSize = bits.bitSize();
boolean bitsChanged = false;
for (int i = 1; i <= numHashFunctions; i++) {
int combinedHash = h1 + (i * h2);
// Flip all the bits if it's negative (guaranteed positive number)
if (combinedHash < 0) {
combinedHash = ~combinedHash;
}
bitsChanged |= bits.set(combinedHash % bitSize);
}
return bitsChanged;
}

@Override
public boolean put(Object item) {
public boolean mightContainString(String str) {
return mightContainBinary(Utils.getBytesFromUTF8String(str));
}

@Override
public boolean mightContainBinary(byte[] bytes) {
int h1 = Murmur3_x86_32.hashUnsafeBytes(bytes, Platform.BYTE_ARRAY_OFFSET, bytes.length, 0);
int h2 = Murmur3_x86_32.hashUnsafeBytes(bytes, Platform.BYTE_ARRAY_OFFSET, bytes.length, h1);

long bitSize = bits.bitSize();
for (int i = 1; i <= numHashFunctions; i++) {
int combinedHash = h1 + (i * h2);
// Flip all the bits if it's negative (guaranteed positive number)
if (combinedHash < 0) {
combinedHash = ~combinedHash;
}
if (!bits.get(combinedHash % bitSize)) {
return false;
}
}
return true;
}

// Here we first hash the input element into 2 int hash values, h1 and h2, then produce n hash
// values by `h1 + i * h2` with 1 <= i <= numHashFunctions.
// Note that `CountMinSketch` use a different strategy for long type, it hash the input long
// element with every i to produce n hash values.
long hash64 = hashObjectToLong(item);
int h1 = (int) (hash64 >> 32);
int h2 = (int) hash64;
@Override
public boolean putLong(long l) {
// Here we first hash the input long element into 2 int hash values, h1 and h2, then produce n
// hash values by `h1 + i * h2` with 1 <= i <= numHashFunctions.
// Note that `CountMinSketch` use a different strategy, it hash the input long element with
// every i to produce n hash values.
// TODO: the strategy of `CountMinSketch` looks more advanced, should we follow it here?
int h1 = Murmur3_x86_32.hashLong(l, 0);
int h2 = Murmur3_x86_32.hashLong(l, h1);

long bitSize = bits.bitSize();
boolean bitsChanged = false;
for (int i = 1; i <= numHashFunctions; i++) {
int combinedHash = h1 + (i * h2);
Expand All @@ -125,12 +147,11 @@ public boolean put(Object item) {
}

@Override
public boolean mightContain(Object item) {
long bitSize = bits.bitSize();
long hash64 = hashObjectToLong(item);
int h1 = (int) (hash64 >> 32);
int h2 = (int) hash64;
public boolean mightContainLong(long l) {
int h1 = Murmur3_x86_32.hashLong(l, 0);
int h2 = Murmur3_x86_32.hashLong(l, h1);

long bitSize = bits.bitSize();
for (int i = 1; i <= numHashFunctions; i++) {
int combinedHash = h1 + (i * h2);
// Flip all the bits if it's negative (guaranteed positive number)
Expand All @@ -144,6 +165,17 @@ public boolean mightContain(Object item) {
return true;
}

@Override
public boolean mightContain(Object item) {
if (item instanceof String) {
return mightContainString((String) item);
} else if (item instanceof byte[]) {
return mightContainBinary((byte[]) item);
} else {
return mightContainLong(Utils.integralToLong(item));
}
}

@Override
public boolean isCompatible(BloomFilter other) {
if (other == null) {
Expand Down Expand Up @@ -191,18 +223,33 @@ public void writeTo(OutputStream out) throws IOException {
DataOutputStream dos = new DataOutputStream(out);

dos.writeInt(Version.V1.getVersionNumber());
bits.writeTo(dos);
dos.writeInt(numHashFunctions);
bits.writeTo(dos);
}

public static BloomFilterImpl readFrom(InputStream in) throws IOException {
private void readFrom0(InputStream in) throws IOException {
DataInputStream dis = new DataInputStream(in);

int version = dis.readInt();
if (version != Version.V1.getVersionNumber()) {
throw new IOException("Unexpected Bloom filter version number (" + version + ")");
}

return new BloomFilterImpl(BitArray.readFrom(dis), dis.readInt());
this.numHashFunctions = dis.readInt();
this.bits = BitArray.readFrom(dis);
}

public static BloomFilterImpl readFrom(InputStream in) throws IOException {
BloomFilterImpl filter = new BloomFilterImpl();
filter.readFrom0(in);
return filter;
}

private void writeObject(ObjectOutputStream out) throws IOException {
writeTo(out);
}

private void readObject(ObjectInputStream in) throws IOException {
readFrom0(in);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,7 @@ class CountMinSketchImpl extends CountMinSketch implements Serializable {
private double eps;
private double confidence;

private CountMinSketchImpl() {
}
private CountMinSketchImpl() {}

CountMinSketchImpl(int depth, int width, int seed) {
this.depth = depth;
Expand Down Expand Up @@ -143,23 +142,7 @@ public void add(Object item, long count) {
if (item instanceof String) {
addString((String) item, count);
} else {
long longValue;

if (item instanceof Long) {
longValue = (Long) item;
} else if (item instanceof Integer) {
longValue = ((Integer) item).longValue();
} else if (item instanceof Short) {
longValue = ((Short) item).longValue();
} else if (item instanceof Byte) {
longValue = ((Byte) item).longValue();
} else {
throw new IllegalArgumentException(
"Support for " + item.getClass().getName() + " not implemented"
);
}

addLong(longValue, count);
addLong(Utils.integralToLong(item), count);
}
}

Expand Down Expand Up @@ -201,13 +184,7 @@ private int hash(long item, int count) {
}

private static int[] getHashBuckets(String key, int hashCount, int max) {
byte[] b;
try {
b = key.getBytes("UTF-8");
} catch (UnsupportedEncodingException e) {
throw new RuntimeException(e);
}
return getHashBuckets(b, hashCount, max);
return getHashBuckets(Utils.getBytesFromUTF8String(key), hashCount, max);
}

private static int[] getHashBuckets(byte[] b, int hashCount, int max) {
Expand All @@ -225,23 +202,7 @@ public long estimateCount(Object item) {
if (item instanceof String) {
return estimateCountForStringItem((String) item);
} else {
long longValue;

if (item instanceof Long) {
longValue = (Long) item;
} else if (item instanceof Integer) {
longValue = ((Integer) item).longValue();
} else if (item instanceof Short) {
longValue = ((Short) item).longValue();
} else if (item instanceof Byte) {
longValue = ((Byte) item).longValue();
} else {
throw new IllegalArgumentException(
"Support for " + item.getClass().getName() + " not implemented"
);
}

return estimateCountForLongItem(longValue);
return estimateCountForLongItem(Utils.integralToLong(item));
}
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.util.sketch;

import java.io.UnsupportedEncodingException;

public class Utils {
public static byte[] getBytesFromUTF8String(String str) {
try {
return str.getBytes("utf-8");
} catch (UnsupportedEncodingException e) {
throw new IllegalArgumentException("Only support utf-8 string", e);
}
}

public static long integralToLong(Object i) {
long longValue;

if (i instanceof Long) {
longValue = (Long) i;
} else if (i instanceof Integer) {
longValue = ((Integer) i).longValue();
} else if (i instanceof Short) {
longValue = ((Short) i).longValue();
} else if (i instanceof Byte) {
longValue = ((Byte) i).longValue();
} else {
throw new IllegalArgumentException("Unsupported data type " + i.getClass().getName());
}

return longValue;
}
}

0 comments on commit 680afab

Please sign in to comment.