Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add type hints #16

Merged
merged 5 commits into from
Sep 10, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion .github/workflows/linter.yml
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,14 @@ jobs:
uses: actions/setup-python@v2

- name: Install Python dependencies
run: pip install black
run: |
pip install black \
pip install mypy

- name: Run linters
uses: wearerequired/lint-action@v2
with:
github_token: ${{ secrets.GITHUB_TOKEN }}
black: true
mypy: true
auto_fix: true
2 changes: 1 addition & 1 deletion .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ jobs:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: ['3.7', '3.8', '3.9', '3.10', '3.11']
python-version: ['3.8', '3.9', '3.10', '3.11']

steps:
- uses: actions/checkout@v2
Expand Down
2 changes: 2 additions & 0 deletions bloomfilter/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,3 @@
from bloomfilter.bloomfilter import BloomFilter

__all__ = ["BloomFilter"]
32 changes: 21 additions & 11 deletions bloomfilter/bloomfilter.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,19 @@
import base64
import math
import typing

from bitarray import bitarray

Check failure on line 5 in bloomfilter/bloomfilter.py

View workflow job for this annotation

GitHub Actions / Mypy

bloomfilter/bloomfilter.py#L5

Cannot find implementation or library stub for module named "bitarray" [import]
from bloomfilter.bloomfilter_strategy import MURMUR128_MITZ_32, MURMUR128_MITZ_64
from bloomfilter.bloomfilter_strategy import (
Strategy,
MURMUR128_MITZ_32,
MURMUR128_MITZ_64,
)


STRATEGIES = [MURMUR128_MITZ_32, MURMUR128_MITZ_64]
STRATEGIES: typing.List[typing.Type[Strategy]] = [MURMUR128_MITZ_32, MURMUR128_MITZ_64]


class BloomFilter(object):
class BloomFilter:
"""
Bloomfilter class.

Expand All @@ -21,7 +26,10 @@
"""

def __init__(
self, expected_insertions, err_rate, strategy=MURMUR128_MITZ_64, *args, **kwargs
self,
expected_insertions: int,
err_rate: float,
strategy: typing.Type[Strategy] = MURMUR128_MITZ_64,
):
if err_rate <= 0:
raise ValueError("Error rate must be > 0.0")
Expand All @@ -37,7 +45,9 @@
data = bitarray("0") * math.ceil(num_bits / 64) * 64
self.setup(num_hash_functions, data, strategy)

def setup(self, num_hash_functions, data, strategy):
def setup(
self, num_hash_functions: int, data: bitarray, strategy: typing.Type[Strategy]
) -> None:
self.num_hash_functions = num_hash_functions
self.data = data
self.strategy = strategy
Expand Down Expand Up @@ -95,7 +105,7 @@
result += self.num_hash_functions.to_bytes(1, byteorder="little")
result += math.ceil(len(self.data) / 64).to_bytes(4, byteorder="big")
for i in range(0, len(self.data), 64):
result += self.data[i : i + 64][::-1]
result += self.data[i : i + 64][::-1].tobytes()
return result

def dumps_to_hex(self) -> str:
Expand All @@ -111,7 +121,7 @@
return base64.b64encode(self.dumps())

@classmethod
def num_of_bits(cls, expected_insertions, err_rate):
def num_of_bits(cls, expected_insertions: int, err_rate: float) -> int:
"""
Compute the number of bits required for the Bloomfilter given expected insertions and error rate.

Expand All @@ -129,7 +139,7 @@
)

@classmethod
def num_of_hash_functions(cls, expected_insertions, num_bits):
def num_of_hash_functions(cls, expected_insertions: int, num_bits: int) -> int:
"""
Compute the number of hash functions required per each element insertion.

Expand All @@ -142,17 +152,17 @@
"""
return max(1, round(num_bits / expected_insertions * math.log(2)))

def put(self, key):
def put(self, key: typing.Union[int, str]) -> bool:
"""
Put an element into the Bloomfilter.
"""
return self.strategy.put(key, self.num_hash_functions, self.data)

def might_contain(self, key):
def might_contain(self, key: typing.Union[int, str]) -> bool:
"""
Return ``True`` if given element exists in Bloomfilter. Otherwise return ``False``.
"""
return self.strategy.might_contain(key, self.num_hash_functions, self.data)

def __contains__(self, key):
def __contains__(self, key: typing.Union[int, str]) -> bool:
return self.might_contain(key)
87 changes: 54 additions & 33 deletions bloomfilter/bloomfilter_strategy.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from abc import ABC, abstractmethod
from bitarray import bitarray

Check failure on line 2 in bloomfilter/bloomfilter_strategy.py

View workflow job for this annotation

GitHub Actions / Mypy

bloomfilter/bloomfilter_strategy.py#L2

Cannot find implementation or library stub for module named "bitarray" [import]

import mmh3

Check failure on line 4 in bloomfilter/bloomfilter_strategy.py

View workflow job for this annotation

GitHub Actions / Mypy

bloomfilter/bloomfilter_strategy.py#L4

Cannot find implementation or library stub for module named "mmh3" [import]
import typing


class Strategy(ABC):
Expand All @@ -9,110 +11,129 @@
INT_MAX = 0x7FFFFFFF
INT_MIN = -0x7FFFFFFF

@classmethod
@abstractmethod
def put(self, key, num_hash_functions, array):
def put(
cls, key: typing.Union[int, str], num_hash_functions: int, array: bitarray
) -> bool:
pass

@classmethod
@abstractmethod
def might_contain(self, key, num_hash_functions, array):
def might_contain(
cls, key: typing.Union[int, str], num_hash_functions: int, array: bitarray
) -> bool:
pass

@classmethod
@abstractmethod
def ordinal(self):
def ordinal(cls) -> int:
pass


class MURMUR128_MITZ_32(Strategy):
@classmethod
def put(self, key, num_hash_functions, array):
def put(
cls, key: typing.Union[int, str], num_hash_functions: int, array: bitarray
) -> bool:
bit_size = len(array)
if isinstance(key, int) and self.INT_MIN <= key <= self.INT_MAX:
hash_value, _ = mmh3.hash64(key.to_bytes(4, byteorder="little"))
elif isinstance(key, int) and self.LONG_MIN <= key <= self.LONG_MAX:
hash_value, _ = mmh3.hash64(key.to_bytes(8, byteorder="little"))
if isinstance(key, int):
if cls.INT_MIN <= key <= cls.INT_MAX:
hash_value, _ = mmh3.hash64(key.to_bytes(4, byteorder="little"))
elif cls.LONG_MIN <= key <= cls.LONG_MAX:
hash_value, _ = mmh3.hash64(key.to_bytes(8, byteorder="little"))
else:
hash_value, _ = mmh3.hash64(key)
hash1 = hash_value & self.INT_MAX
hash1 = hash_value & cls.INT_MAX
hash2 = (hash_value >> 32) & 0xFFFFFFFF

bits_changed = False
for i in range(1, num_hash_functions + 1):
combined_hash = hash1 + (i * hash2)
combined_hash &= 0xFFFFFFFF
if combined_hash > self.INT_MAX or combined_hash < 0:
combined_hash = (~combined_hash) & self.INT_MAX
if combined_hash > cls.INT_MAX or combined_hash < 0:
combined_hash = (~combined_hash) & cls.INT_MAX
index = combined_hash % bit_size
if array[index] == 0:
bits_changed = True
array[index] = 1
return bits_changed

@classmethod
def might_contain(self, key, num_hash_functions, array):
def might_contain(
cls, key: typing.Union[int, str], num_hash_functions: int, array: bitarray
) -> bool:
bit_size = len(array)
if isinstance(key, int) and self.INT_MIN <= key <= self.INT_MAX:
hash_value, _ = mmh3.hash64(key.to_bytes(4, byteorder="little"))
elif isinstance(key, int) and self.LONG_MIN <= key <= self.LONG_MAX:
hash_value, _ = mmh3.hash64(key.to_bytes(8, byteorder="little"))
if isinstance(key, int):
if cls.INT_MIN <= key <= cls.INT_MAX:
hash_value, _ = mmh3.hash64(key.to_bytes(4, byteorder="little"))
elif cls.LONG_MIN <= key <= cls.LONG_MAX:
hash_value, _ = mmh3.hash64(key.to_bytes(8, byteorder="little"))

Check warning on line 71 in bloomfilter/bloomfilter_strategy.py

View check run for this annotation

Codecov / codecov/patch

bloomfilter/bloomfilter_strategy.py#L67-L71

Added lines #L67 - L71 were not covered by tests
else:
hash_value, _ = mmh3.hash64(key)
hash1 = hash_value & self.INT_MAX
hash1 = hash_value & cls.INT_MAX

Check warning on line 74 in bloomfilter/bloomfilter_strategy.py

View check run for this annotation

Codecov / codecov/patch

bloomfilter/bloomfilter_strategy.py#L74

Added line #L74 was not covered by tests
hash2 = (hash_value >> 32) & 0xFFFFFFFF

for i in range(1, num_hash_functions + 1):
combined_hash = hash1 + (i * hash2)
combined_hash &= 0xFFFFFFFF
if combined_hash > self.INT_MAX or combined_hash < 0:
combined_hash = (~combined_hash) & self.INT_MAX
if combined_hash > cls.INT_MAX or combined_hash < 0:
combined_hash = (~combined_hash) & cls.INT_MAX

Check warning on line 81 in bloomfilter/bloomfilter_strategy.py

View check run for this annotation

Codecov / codecov/patch

bloomfilter/bloomfilter_strategy.py#L80-L81

Added lines #L80 - L81 were not covered by tests
index = combined_hash % bit_size
if array[index] == 0:
return False
return True

@classmethod
def ordinal(self):
def ordinal(cls) -> int:
return 0


class MURMUR128_MITZ_64(Strategy):
@classmethod
def put(self, key, num_hash_functions, array):
def put(
cls, key: typing.Union[int, str], num_hash_functions: int, array: bitarray
) -> bool:
bit_size = len(array)
if isinstance(key, int) and self.INT_MIN <= key <= self.INT_MAX:
hash1, hash2 = mmh3.hash64(key.to_bytes(4, byteorder="little"))
elif isinstance(key, int) and self.LONG_MIN <= key <= self.LONG_MAX:
hash1, hash2 = mmh3.hash64(key.to_bytes(8, byteorder="little"))
if isinstance(key, int):
if cls.INT_MIN <= key <= cls.INT_MAX:
hash1, hash2 = mmh3.hash64(key.to_bytes(4, byteorder="little"))
elif cls.LONG_MIN <= key <= cls.LONG_MAX:
hash1, hash2 = mmh3.hash64(key.to_bytes(8, byteorder="little"))

Check warning on line 102 in bloomfilter/bloomfilter_strategy.py

View check run for this annotation

Codecov / codecov/patch

bloomfilter/bloomfilter_strategy.py#L101-L102

Added lines #L101 - L102 were not covered by tests
else:
hash1, hash2 = mmh3.hash64(key)

bits_changed = False
combined_hash = hash1
for _ in range(num_hash_functions):
index = (combined_hash & self.LONG_MAX) % bit_size
index = (combined_hash & cls.LONG_MAX) % bit_size
if array[index] == 0:
bits_changed = True
array[index] = 1
combined_hash += hash2
return bits_changed

@classmethod
def might_contain(self, key, num_hash_functions, array):
def might_contain(
cls, key: typing.Union[int, str], num_hash_functions: int, array: bitarray
) -> bool:
bit_size = len(array)
if isinstance(key, int) and self.INT_MIN <= key <= self.INT_MAX:
hash1, hash2 = mmh3.hash64(key.to_bytes(4, byteorder="little"))
elif isinstance(key, int) and self.LONG_MIN <= key <= self.LONG_MAX:
hash1, hash2 = mmh3.hash64(key.to_bytes(8, byteorder="little"))
if isinstance(key, int):
if cls.INT_MIN <= key <= cls.INT_MAX:
hash1, hash2 = mmh3.hash64(key.to_bytes(4, byteorder="little"))
elif cls.LONG_MIN <= key <= cls.LONG_MAX:
hash1, hash2 = mmh3.hash64(key.to_bytes(8, byteorder="little"))

Check warning on line 125 in bloomfilter/bloomfilter_strategy.py

View check run for this annotation

Codecov / codecov/patch

bloomfilter/bloomfilter_strategy.py#L124-L125

Added lines #L124 - L125 were not covered by tests
else:
hash1, hash2 = mmh3.hash64(key)

combined_hash = hash1
for _ in range(num_hash_functions):
index = (combined_hash & self.LONG_MAX) % bit_size
index = (combined_hash & cls.LONG_MAX) % bit_size
if not array[index]:
return False
combined_hash += hash2
return True

@classmethod
def ordinal(self):
def ordinal(cls) -> int:
return 1
7 changes: 3 additions & 4 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@ classifiers = [
"License :: OSI Approved :: MIT License",
"Programming Language :: Python",
"Programming Language :: Python :: 3",
"Programming Language :: Python :: 3.7",
"Programming Language :: Python :: 3.8",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
Expand All @@ -20,15 +19,15 @@ classifiers = [
"Topic :: Software Development :: Libraries :: Python Modules",
]
dependencies = [
"bitarray==2.7.6",
"mmh3==3.1.0",
"bitarray==2.8.1",
"mmh3==4.0.1",
]
description = "Yet another bloomfilter implementation in Python"
keywords = ["bloomfilter"]
license = {file = "LICENSE"}
name = "bloomfilter-py"
readme = {file = "README.md", content-type = "text/markdown"}
requires-python = ">=3.7"
requires-python = ">=3.8"
version = "1.dev0"

[project.urls]
Expand Down
4 changes: 2 additions & 2 deletions tests/__init__.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
import os

guava_file_dir = os.path.join(
guava_file_dir: str = os.path.join(
os.path.abspath(os.path.dirname(__file__)), "guava_dump_files"
)


def read_data(filename):
def read_data(filename: str) -> bytes:
"""
Read Bloomfilter serialized data from Guava's dump file.

Expand Down
14 changes: 7 additions & 7 deletions tests/test_bloomfilter.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,15 +7,15 @@


class BloomFilterTest(unittest.TestCase):
def test_num_of_bits(self):
def test_num_of_bits(self) -> None:
test_cases = [(500, 0.01, 4792), (500, 0.0, 774727), (10, 0.01, 95)]
for case in test_cases:
num_bits = BloomFilter.num_of_bits(case[0], case[1])
self.assertEqual(
num_bits, case[2], f"Expected {case[2]} bits, but got {num_bits}"
)

def test_num_of_hash_functions(self):
def test_num_of_hash_functions(self) -> None:
test_cases = [(500, 4792, 7), (500, 774727, 1074)]
for case in test_cases:
num_hash_functions = BloomFilter.num_of_hash_functions(case[0], case[1])
Expand All @@ -25,7 +25,7 @@ def test_num_of_hash_functions(self):
f"Expected {case[2]} hash functions, but got {num_hash_functions}",
)

def test_basic_functionality(self):
def test_basic_functionality(self) -> None:
bloom_filter = BloomFilter(10000000, 0.001)
for i in range(200):
bloom_filter.put(i)
Expand Down Expand Up @@ -54,7 +54,7 @@ def test_basic_functionality(self):
"Word 'not_exist' is expected to be in bloomfilter",
)

def test_dumps(self):
def test_dumps(self) -> None:
bloom_filter = BloomFilter(300, 0.0001, MURMUR128_MITZ_32)
for i in range(100):
bloom_filter.put(i)
Expand Down Expand Up @@ -82,7 +82,7 @@ def test_dumps(self):
"New filter's dump is expected to be the same as old filter's",
)

def test_guava_compatibility(self):
def test_guava_compatibility(self) -> None:
bloom_filter = BloomFilter.loads(read_data("500_0_01_0_to_99_test.out"))
num_bits = BloomFilter.num_of_bits(500, 0.01)
num_hash_functions = BloomFilter.num_of_hash_functions(500, num_bits)
Expand Down Expand Up @@ -111,7 +111,7 @@ def test_guava_compatibility(self):
f"Number {i} is expected to be in bloomfilter",
)

def test_dumps_to_hex(self):
def test_dumps_to_hex(self) -> None:
bloom_filter = BloomFilter(500, 0.0001, MURMUR128_MITZ_32)
for _ in range(100):
bloom_filter.put(random.randint(100000000, 10000000000))
Expand Down Expand Up @@ -139,7 +139,7 @@ def test_dumps_to_hex(self):
"New filter's dump is expected to be the same as old filter's",
)

def test_dumps_to_base64(self):
def test_dumps_to_base64(self) -> None:
bloom_filter = BloomFilter(500, 0.0001, MURMUR128_MITZ_32)
for _ in range(100):
bloom_filter.put(random.randint(100000000, 10000000000))
Expand Down
Loading