## Bloom Filters

A bloom filter is a set data structure with constant $\Omicron(1)$ lookup performance that has no false negatives (always detects memberships if actually a member) but can have false positives (can detect membership but actually not a member), along with $\Omicron(1)$ insert complexity.

It specializes in having a constant memory cost, of which the size directly corrosponds with the false positive rate assuming a perfect hashing function.


In [50]:
import hashlib
import string
import random
import math
import scipy.optimize as opt
import sys

from typing import Iterator


def random_string(min_length: int = 9, max_length: int = 18):
    length = random.randint(min_length, max_length)
    letters = string.ascii_letters + string.digits
    return "".join(random.choices(letters, k=length))


class BloomFilter:
    # filter memory

    memory: bytearray
    # number of hash functions

    k: int
    # number of items in the filter
    items: int

    # bytes of each hash function output
    BLOCK_SIZE: int = 4

    def __init__(self, size: int = 1000, k: int = 16):
        """Creates the filter with *size* bytes of memory and *k* hash functions"""

        self.memory = bytearray(size)
        self.k = k
        self.items = 0

    def hash(self, data: bytes) -> bytes:
        """Returns the hash of length k*BLOCK for the data *data*"""

        h = hashlib.shake_256()
        h.update(data)
        return h.digest(self.BLOCK_SIZE * self.k)

    def indices(self, data: bytes) -> Iterator[int]:
        """Returns a iterator of set indices for the data *data*"""
        digest = self.hash(data)

        for i in range(self.k):
            pos = int.from_bytes(
                digest[self.BLOCK_SIZE * i : self.BLOCK_SIZE * (i + 1)], "little"
            ) % (8 * len(self.memory))
            yield pos

    def add(self, data: bytes):
        """Add *data* to the filter"""
        for pos in self.indices(data):
            self.memory[pos // 8] |= 1 << (pos % 8)

        self.items += 1

    def __contains__(self, item: bytes):
        """Checks if *data* is likely in the filter"""
        for pos in self.indices(item):
            if self.memory[pos // 8] & (1 << (pos % 8)) == 0:
                return False

        return True

    def fp_rate(self) -> float:
        """Returns the probability that an element not in the filter is detected to be in the filter"""

        n = 8 * len(self.memory)
        m = self.items
        k = self.k

        return (1 - (1 - 1 / n) ** (k * m)) ** k

    @staticmethod
    def least_memory(max_fp_rate: float, max_items: int, k: int) -> int:
        """Returns the least memory in bytes needed to store a max of *max_item* items with *max_fp_rate*"""
        occ_rate = -math.log(1 - max_fp_rate ** (1 / k)) / k
        mem = max_items / occ_rate
        return int(mem // 8 + 1)

    @staticmethod
    def best_k(max_fp_rate: float) -> int:
        res = opt.minimize_scalar(
            lambda k: -k / math.log(1 - max_fp_rate ** (1 / k)), bounds=(0.5, 64)
        )
        if not res.success:
            k = 16
        else:
            k = math.ceil(res.x)

        return k

    @staticmethod
    def create_filter(max_fp_rate: float, max_items: int, k: int = None):
        if k is None:
            k = BloomFilter.best_k(max_fp_rate)

        mem = BloomFilter.least_memory(max_fp_rate, max_items, k)
        return BloomFilter(mem, k)

    def __repr__(self) -> str:
        return f"BloomFilter(n={len(self.memory)} bytes, #={self.items}, k={self.k})"


# fil = BloomFilter.create_filter(0.001, 1000, 8)

# all_strings = set()


# print(fil.fp_rate())
# print(repr(fil))
fA = BloomFilter.create_filter(0.00001, 1000)
fB = BloomFilter.create_filter(0.00001, 1000, 16)
fset = set()

for _ in range(1000):
    word = random_string().encode()
    fset.add(word)
    fA.add(word)
    fB.add(word)

print(repr(fA), f'fp={fA.fp_rate()}')
print(repr(fB), f'fp={fB.fp_rate()}')

print(sys.getsizeof(fset))
# print(BloomFilter.least_memory(0.11, 1000))
# mistakes = 0
# total = 1000000
# for _ in range(total):
#     while True:
#         word = random_string().encode()
#         if word not in all_strings:
#             break

#     if word in fil:
#         mistakes += 1

# print(mistakes / total)

BloomFilter(n=2996 bytes, #=1000, k=17) fp=9.995548181145193e-06
BloomFilter(n=2997 bytes, #=1000, k=16) fp=9.98850733116421e-06
32984


In [47]:
filt = BloomFilter.create_filter(0.0001, 10_000)

print(repr(filt))

filt.add(b"www.youtube.com")
filt.add(b"www.bbc.com")

print(b"www.youtube.com" in filt)
print(b"www.google.com" in filt)

BloomFilter(n=23983B, #=0, k=14)
True
False
