diff --git a/machine_learning/apriori_algorithm.py b/machine_learning/apriori_algorithm.py index 5c3e2baba2c2..560b92f2fb71 100644 --- a/machine_learning/apriori_algorithm.py +++ b/machine_learning/apriori_algorithm.py @@ -45,16 +45,16 @@ def prune(itemset: list, candidates: list, length: int) -> list: >>> prune(itemset, candidates, 3) [] """ + # Use a set for O(1) membership and a Counter for multiplicity checks to + # preserve robustness for edge cases and match existing doctest behavior. + itemset_set = {tuple(item) for item in itemset} itemset_counter = Counter(tuple(item) for item in itemset) pruned = [] for candidate in candidates: is_subsequence = True for item in candidate: item_tuple = tuple(item) - if ( - item_tuple not in itemset_counter - or itemset_counter[item_tuple] < length - 1 - ): + if item_tuple not in itemset_set or itemset_counter[item_tuple] < length - 1: is_subsequence = False break if is_subsequence: