Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SPARK-12938][SQL] DataFrame API for Bloom filter #10937

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -49,9 +49,9 @@ public enum Version {
* {@code BloomFilter} binary format version 1 (all values written in big-endian order):
* <ul>
* <li>Version number, always 1 (32 bit)</li>
* <li>Number of hash functions (32 bit)</li>
* <li>Total number of words of the underlying bit array (32 bit)</li>
* <li>The words/longs (numWords * 64 bit)</li>
* <li>Number of hash functions (32 bit)</li>
* </ul>
*/
V1(1);
Expand Down Expand Up @@ -97,6 +97,21 @@ int getVersionNumber() {
*/
public abstract boolean put(Object item);

/**
* A specialized variant of {@link #put(Object)}, that can only be used to put utf-8 string.
*/
public abstract boolean putString(String str);

/**
* A specialized variant of {@link #put(Object)}, that can only be used to put long.
*/
public abstract boolean putLong(long l);

/**
* A specialized variant of {@link #put(Object)}, that can only be used to put byte array.
*/
public abstract boolean putBinary(byte[] bytes);

/**
* Determines whether a given bloom filter is compatible with this bloom filter. For two
* bloom filters to be compatible, they must have the same bit size.
Expand All @@ -121,6 +136,23 @@ int getVersionNumber() {
*/
public abstract boolean mightContain(Object item);

/**
* A specialized variant of {@link #mightContain(Object)}, that can only be used to test utf-8
* string.
*/
public abstract boolean mightContainString(String str);

/**
* A specialized variant of {@link #mightContain(Object)}, that can only be used to test long.
*/
public abstract boolean mightContainLong(long l);

/**
* A specialized variant of {@link #mightContain(Object)}, that can only be used to test byte
* array.
*/
public abstract boolean mightContainBinary(byte[] bytes);

/**
* Writes out this {@link BloomFilter} to an output stream in binary format.
* It is the caller's responsibility to close the stream.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,10 @@

import java.io.*;

public class BloomFilterImpl extends BloomFilter {
public class BloomFilterImpl extends BloomFilter implements Serializable {

private final int numHashFunctions;
private final BitArray bits;
private int numHashFunctions;
private BitArray bits;

BloomFilterImpl(int numHashFunctions, long numBits) {
this(new BitArray(numBits), numHashFunctions);
Expand All @@ -33,6 +33,8 @@ private BloomFilterImpl(BitArray bits, int numHashFunctions) {
this.numHashFunctions = numHashFunctions;
}

private BloomFilterImpl() {}

@Override
public boolean equals(Object other) {
if (other == this) {
Expand Down Expand Up @@ -63,55 +65,75 @@ public long bitSize() {
return bits.bitSize();
}

private static long hashObjectToLong(Object item) {
@Override
public boolean put(Object item) {
if (item instanceof String) {
try {
byte[] bytes = ((String) item).getBytes("utf-8");
return hashBytesToLong(bytes);
} catch (UnsupportedEncodingException e) {
throw new RuntimeException("Only support utf-8 string", e);
}
return putString((String) item);
} else if (item instanceof byte[]) {
return putBinary((byte[]) item);
} else {
long longValue;

if (item instanceof Long) {
longValue = (Long) item;
} else if (item instanceof Integer) {
longValue = ((Integer) item).longValue();
} else if (item instanceof Short) {
longValue = ((Short) item).longValue();
} else if (item instanceof Byte) {
longValue = ((Byte) item).longValue();
} else {
throw new IllegalArgumentException(
"Support for " + item.getClass().getName() + " not implemented"
);
}

int h1 = Murmur3_x86_32.hashLong(longValue, 0);
int h2 = Murmur3_x86_32.hashLong(longValue, h1);
return (((long) h1) << 32) | (h2 & 0xFFFFFFFFL);
return putLong(Utils.integralToLong(item));
}
}

private static long hashBytesToLong(byte[] bytes) {
@Override
public boolean putString(String str) {
return putBinary(Utils.getBytesFromUTF8String(str));
}

@Override
public boolean putBinary(byte[] bytes) {
int h1 = Murmur3_x86_32.hashUnsafeBytes(bytes, Platform.BYTE_ARRAY_OFFSET, bytes.length, 0);
int h2 = Murmur3_x86_32.hashUnsafeBytes(bytes, Platform.BYTE_ARRAY_OFFSET, bytes.length, h1);
return (((long) h1) << 32) | (h2 & 0xFFFFFFFFL);

long bitSize = bits.bitSize();
boolean bitsChanged = false;
for (int i = 1; i <= numHashFunctions; i++) {
int combinedHash = h1 + (i * h2);
// Flip all the bits if it's negative (guaranteed positive number)
if (combinedHash < 0) {
combinedHash = ~combinedHash;
}
bitsChanged |= bits.set(combinedHash % bitSize);
}
return bitsChanged;
}

@Override
public boolean put(Object item) {
public boolean mightContainString(String str) {
return mightContainBinary(Utils.getBytesFromUTF8String(str));
}

@Override
public boolean mightContainBinary(byte[] bytes) {
int h1 = Murmur3_x86_32.hashUnsafeBytes(bytes, Platform.BYTE_ARRAY_OFFSET, bytes.length, 0);
int h2 = Murmur3_x86_32.hashUnsafeBytes(bytes, Platform.BYTE_ARRAY_OFFSET, bytes.length, h1);

long bitSize = bits.bitSize();
for (int i = 1; i <= numHashFunctions; i++) {
int combinedHash = h1 + (i * h2);
// Flip all the bits if it's negative (guaranteed positive number)
if (combinedHash < 0) {
combinedHash = ~combinedHash;
}
if (!bits.get(combinedHash % bitSize)) {
return false;
}
}
return true;
}

// Here we first hash the input element into 2 int hash values, h1 and h2, then produce n hash
// values by `h1 + i * h2` with 1 <= i <= numHashFunctions.
// Note that `CountMinSketch` use a different strategy for long type, it hash the input long
// element with every i to produce n hash values.
long hash64 = hashObjectToLong(item);
int h1 = (int) (hash64 >> 32);
int h2 = (int) hash64;
@Override
public boolean putLong(long l) {
// Here we first hash the input long element into 2 int hash values, h1 and h2, then produce n
// hash values by `h1 + i * h2` with 1 <= i <= numHashFunctions.
// Note that `CountMinSketch` use a different strategy, it hash the input long element with
// every i to produce n hash values.
// TODO: the strategy of `CountMinSketch` looks more advanced, should we follow it here?
int h1 = Murmur3_x86_32.hashLong(l, 0);
int h2 = Murmur3_x86_32.hashLong(l, h1);

long bitSize = bits.bitSize();
boolean bitsChanged = false;
for (int i = 1; i <= numHashFunctions; i++) {
int combinedHash = h1 + (i * h2);
Expand All @@ -125,12 +147,11 @@ public boolean put(Object item) {
}

@Override
public boolean mightContain(Object item) {
long bitSize = bits.bitSize();
long hash64 = hashObjectToLong(item);
int h1 = (int) (hash64 >> 32);
int h2 = (int) hash64;
public boolean mightContainLong(long l) {
int h1 = Murmur3_x86_32.hashLong(l, 0);
int h2 = Murmur3_x86_32.hashLong(l, h1);

long bitSize = bits.bitSize();
for (int i = 1; i <= numHashFunctions; i++) {
int combinedHash = h1 + (i * h2);
// Flip all the bits if it's negative (guaranteed positive number)
Expand All @@ -144,6 +165,17 @@ public boolean mightContain(Object item) {
return true;
}

@Override
public boolean mightContain(Object item) {
if (item instanceof String) {
return mightContainString((String) item);
} else if (item instanceof byte[]) {
return mightContainBinary((byte[]) item);
} else {
return mightContainLong(Utils.integralToLong(item));
}
}

@Override
public boolean isCompatible(BloomFilter other) {
if (other == null) {
Expand Down Expand Up @@ -191,18 +223,33 @@ public void writeTo(OutputStream out) throws IOException {
DataOutputStream dos = new DataOutputStream(out);

dos.writeInt(Version.V1.getVersionNumber());
bits.writeTo(dos);
dos.writeInt(numHashFunctions);
bits.writeTo(dos);
}

public static BloomFilterImpl readFrom(InputStream in) throws IOException {
private void readFrom0(InputStream in) throws IOException {
DataInputStream dis = new DataInputStream(in);

int version = dis.readInt();
if (version != Version.V1.getVersionNumber()) {
throw new IOException("Unexpected Bloom filter version number (" + version + ")");
}

return new BloomFilterImpl(BitArray.readFrom(dis), dis.readInt());
this.numHashFunctions = dis.readInt();
this.bits = BitArray.readFrom(dis);
}

public static BloomFilterImpl readFrom(InputStream in) throws IOException {
BloomFilterImpl filter = new BloomFilterImpl();
filter.readFrom0(in);
return filter;
}

private void writeObject(ObjectOutputStream out) throws IOException {
writeTo(out);
}

private void readObject(ObjectInputStream in) throws IOException {
readFrom0(in);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,7 @@ class CountMinSketchImpl extends CountMinSketch implements Serializable {
private double eps;
private double confidence;

private CountMinSketchImpl() {
}
private CountMinSketchImpl() {}

CountMinSketchImpl(int depth, int width, int seed) {
this.depth = depth;
Expand Down Expand Up @@ -143,23 +142,7 @@ public void add(Object item, long count) {
if (item instanceof String) {
addString((String) item, count);
} else {
long longValue;

if (item instanceof Long) {
longValue = (Long) item;
} else if (item instanceof Integer) {
longValue = ((Integer) item).longValue();
} else if (item instanceof Short) {
longValue = ((Short) item).longValue();
} else if (item instanceof Byte) {
longValue = ((Byte) item).longValue();
} else {
throw new IllegalArgumentException(
"Support for " + item.getClass().getName() + " not implemented"
);
}

addLong(longValue, count);
addLong(Utils.integralToLong(item), count);
}
}

Expand Down Expand Up @@ -201,13 +184,7 @@ private int hash(long item, int count) {
}

private static int[] getHashBuckets(String key, int hashCount, int max) {
byte[] b;
try {
b = key.getBytes("UTF-8");
} catch (UnsupportedEncodingException e) {
throw new RuntimeException(e);
}
return getHashBuckets(b, hashCount, max);
return getHashBuckets(Utils.getBytesFromUTF8String(key), hashCount, max);
}

private static int[] getHashBuckets(byte[] b, int hashCount, int max) {
Expand All @@ -225,23 +202,7 @@ public long estimateCount(Object item) {
if (item instanceof String) {
return estimateCountForStringItem((String) item);
} else {
long longValue;

if (item instanceof Long) {
longValue = (Long) item;
} else if (item instanceof Integer) {
longValue = ((Integer) item).longValue();
} else if (item instanceof Short) {
longValue = ((Short) item).longValue();
} else if (item instanceof Byte) {
longValue = ((Byte) item).longValue();
} else {
throw new IllegalArgumentException(
"Support for " + item.getClass().getName() + " not implemented"
);
}

return estimateCountForLongItem(longValue);
return estimateCountForLongItem(Utils.integralToLong(item));
}
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.util.sketch;

import java.io.UnsupportedEncodingException;

public class Utils {
public static byte[] getBytesFromUTF8String(String str) {
try {
return str.getBytes("utf-8");
} catch (UnsupportedEncodingException e) {
throw new IllegalArgumentException("Only support utf-8 string", e);
}
}

public static long integralToLong(Object i) {
long longValue;

if (i instanceof Long) {
longValue = (Long) i;
} else if (i instanceof Integer) {
longValue = ((Integer) i).longValue();
} else if (i instanceof Short) {
longValue = ((Short) i).longValue();
} else if (i instanceof Byte) {
longValue = ((Byte) i).longValue();
} else {
throw new IllegalArgumentException("Unsupported data type " + i.getClass().getName());
}

return longValue;
}
}
Loading