Skip to content

Commit

Permalink
Add more type hints according to mypy --strict
Browse files Browse the repository at this point in the history
  • Loading branch information
OldPanda committed Sep 10, 2023
1 parent b0c3b3f commit 0b38785
Show file tree
Hide file tree
Showing 5 changed files with 26 additions and 20 deletions.
2 changes: 2 additions & 0 deletions bloomfilter/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,3 @@
from bloomfilter.bloomfilter import BloomFilter

__all__ = ["BloomFilter"]
10 changes: 4 additions & 6 deletions bloomfilter/bloomfilter.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,6 @@ def __init__(
expected_insertions: int,
err_rate: float,
strategy: typing.Type[Strategy] = MURMUR128_MITZ_64,
*args,
**kwargs,
):
if err_rate <= 0:
raise ValueError("Error rate must be > 0.0")
Expand All @@ -49,7 +47,7 @@ def __init__(

def setup(
self, num_hash_functions: int, data: bitarray, strategy: typing.Type[Strategy]
):
) -> None:
self.num_hash_functions = num_hash_functions
self.data = data
self.strategy = strategy
Expand Down Expand Up @@ -154,17 +152,17 @@ def num_of_hash_functions(cls, expected_insertions: int, num_bits: int) -> int:
"""
return max(1, round(num_bits / expected_insertions * math.log(2)))

def put(self, key):
def put(self, key: typing.Union[int, str]) -> bool:
"""
Put an element into the Bloomfilter.
"""
return self.strategy.put(key, self.num_hash_functions, self.data)

def might_contain(self, key):
def might_contain(self, key: typing.Union[int, str]) -> bool:
"""
Return ``True`` if given element exists in Bloomfilter. Otherwise return ``False``.
"""
return self.strategy.might_contain(key, self.num_hash_functions, self.data)

def __contains__(self, key):
def __contains__(self, key: typing.Union[int, str]) -> bool:
return self.might_contain(key)
16 changes: 11 additions & 5 deletions bloomfilter/bloomfilter_strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,9 @@ class Strategy(ABC):

@classmethod
@abstractmethod
def put(cls, key: typing.Union[int, str], num_hash_functions: int, array: bitarray):
def put(
cls, key: typing.Union[int, str], num_hash_functions: int, array: bitarray
) -> bool:
pass

@classmethod
Expand All @@ -31,7 +33,9 @@ def ordinal(cls) -> int:

class MURMUR128_MITZ_32(Strategy):
@classmethod
def put(cls, key: typing.Union[int, str], num_hash_functions: int, array: bitarray):
def put(
cls, key: typing.Union[int, str], num_hash_functions: int, array: bitarray
) -> bool:
bit_size = len(array)
if isinstance(key, int):
if cls.INT_MIN <= key <= cls.INT_MAX:
Expand All @@ -58,7 +62,7 @@ def put(cls, key: typing.Union[int, str], num_hash_functions: int, array: bitarr
@classmethod
def might_contain(
cls, key: typing.Union[int, str], num_hash_functions: int, array: bitarray
):
) -> bool:
bit_size = len(array)
if isinstance(key, int):
if cls.INT_MIN <= key <= cls.INT_MAX:
Expand Down Expand Up @@ -87,7 +91,9 @@ def ordinal(cls) -> int:

class MURMUR128_MITZ_64(Strategy):
@classmethod
def put(cls, key: typing.Union[int, str], num_hash_functions: int, array: bitarray):
def put(
cls, key: typing.Union[int, str], num_hash_functions: int, array: bitarray
) -> bool:
bit_size = len(array)
if isinstance(key, int):
if cls.INT_MIN <= key <= cls.INT_MAX:
Expand All @@ -110,7 +116,7 @@ def put(cls, key: typing.Union[int, str], num_hash_functions: int, array: bitarr
@classmethod
def might_contain(
cls, key: typing.Union[int, str], num_hash_functions: int, array: bitarray
):
) -> bool:
bit_size = len(array)
if isinstance(key, int):
if cls.INT_MIN <= key <= cls.INT_MAX:
Expand Down
4 changes: 2 additions & 2 deletions tests/__init__.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
import os

guava_file_dir = os.path.join(
guava_file_dir: str = os.path.join(
os.path.abspath(os.path.dirname(__file__)), "guava_dump_files"
)


def read_data(filename):
def read_data(filename: str) -> bytes:
"""
Read Bloomfilter serialized data from Guava's dump file.
Expand Down
14 changes: 7 additions & 7 deletions tests/test_bloomfilter.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,15 +7,15 @@


class BloomFilterTest(unittest.TestCase):
def test_num_of_bits(self):
def test_num_of_bits(self) -> None:
test_cases = [(500, 0.01, 4792), (500, 0.0, 774727), (10, 0.01, 95)]
for case in test_cases:
num_bits = BloomFilter.num_of_bits(case[0], case[1])
self.assertEqual(
num_bits, case[2], f"Expected {case[2]} bits, but got {num_bits}"
)

def test_num_of_hash_functions(self):
def test_num_of_hash_functions(self) -> None:
test_cases = [(500, 4792, 7), (500, 774727, 1074)]
for case in test_cases:
num_hash_functions = BloomFilter.num_of_hash_functions(case[0], case[1])
Expand All @@ -25,7 +25,7 @@ def test_num_of_hash_functions(self):
f"Expected {case[2]} hash functions, but got {num_hash_functions}",
)

def test_basic_functionality(self):
def test_basic_functionality(self) -> None:
bloom_filter = BloomFilter(10000000, 0.001)
for i in range(200):
bloom_filter.put(i)
Expand Down Expand Up @@ -54,7 +54,7 @@ def test_basic_functionality(self):
"Word 'not_exist' is expected to be in bloomfilter",
)

def test_dumps(self):
def test_dumps(self) -> None:
bloom_filter = BloomFilter(300, 0.0001, MURMUR128_MITZ_32)
for i in range(100):
bloom_filter.put(i)
Expand Down Expand Up @@ -82,7 +82,7 @@ def test_dumps(self):
"New filter's dump is expected to be the same as old filter's",
)

def test_guava_compatibility(self):
def test_guava_compatibility(self) -> None:
bloom_filter = BloomFilter.loads(read_data("500_0_01_0_to_99_test.out"))
num_bits = BloomFilter.num_of_bits(500, 0.01)
num_hash_functions = BloomFilter.num_of_hash_functions(500, num_bits)
Expand Down Expand Up @@ -111,7 +111,7 @@ def test_guava_compatibility(self):
f"Number {i} is expected to be in bloomfilter",
)

def test_dumps_to_hex(self):
def test_dumps_to_hex(self) -> None:
bloom_filter = BloomFilter(500, 0.0001, MURMUR128_MITZ_32)
for _ in range(100):
bloom_filter.put(random.randint(100000000, 10000000000))
Expand Down Expand Up @@ -139,7 +139,7 @@ def test_dumps_to_hex(self):
"New filter's dump is expected to be the same as old filter's",
)

def test_dumps_to_base64(self):
def test_dumps_to_base64(self) -> None:
bloom_filter = BloomFilter(500, 0.0001, MURMUR128_MITZ_32)
for _ in range(100):
bloom_filter.put(random.randint(100000000, 10000000000))
Expand Down

0 comments on commit 0b38785

Please sign in to comment.