In [1]:
import csv
from itertools import combinations, chain
from collections import defaultdict

def load_transactions(file_path):
    """
    Read a basket file in which every line is one transaction and the items
    on a line are separated by commas.

    Each field is stripped of surrounding whitespace; fields that are empty
    after stripping are dropped, and lines that end up with no items are
    skipped entirely.

    Returns a list of transactions, each one a set of item strings, in file
    order.
    """
    baskets = []
    with open(file_path, newline='', encoding='utf-8') as handle:
        for fields in csv.reader(handle):
            # A set both de-duplicates repeated items and discards blanks.
            basket = {field.strip() for field in fields if field.strip()}
            if basket:
                baskets.append(basket)
    return baskets

def get_frequent_itemsets(transactions, min_support):
    """
    Run Apriori to find all frequent itemsets with support >= min_support.

    Args:
        transactions: list of transactions; each transaction is a collection
            (normally a set) of hashable items. It is scanned once per level,
            so it must be re-iterable and support len().
        min_support: minimum fraction of transactions that must contain an
            itemset for it to count as frequent.

    Returns:
        A tuple (frequent_itemsets, all_frequent):
        - frequent_itemsets[k] is a dict mapping each frequent k-itemset
          (frozenset) to its support (float). Key 1 is always present, even
          when empty; larger keys exist only for non-empty levels.
        - all_frequent is one flat dict mapping every frequent itemset of any
          size to its support.
    """
    num_trans = len(transactions)

    # Level 1: count individual items. With no transactions the comprehension
    # below never runs, so the division cannot hit a zero denominator.
    item_counts = defaultdict(int)
    for trans in transactions:
        for item in trans:
            item_counts[frozenset([item])] += 1
    L1 = {
        itemset: count / num_trans
        for itemset, count in item_counts.items()
        if count / num_trans >= min_support
    }

    frequent_itemsets = {1: L1}
    all_frequent = dict(L1)  # include 1-itemsets

    k = 2
    prev_L = L1
    while prev_L:
        # Join step: two (k-1)-itemsets that share k-2 items union to size k.
        candidates = {
            s1 | s2
            for s1, s2 in combinations(prev_L, 2)
            if len(s1 | s2) == k
        }

        # Prune step: every (k-1)-subset of a candidate must itself be
        # frequent, otherwise the candidate cannot be frequent.
        candidates_pruned = [
            cand
            for cand in candidates
            if all(
                frozenset(subset) in prev_L
                for subset in combinations(cand, k - 1)
            )
        ]

        # Count step: one pass over the transactions per level.
        cand_count = defaultdict(int)
        for trans in transactions:
            # A transaction with fewer than k items cannot contain a k-itemset.
            if len(trans) < k:
                continue
            for cand in candidates_pruned:
                if cand.issubset(trans):
                    cand_count[cand] += 1

        Lk = {
            cand: count / num_trans
            for cand, count in cand_count.items()
            if count / num_trans >= min_support
        }
        if not Lk:
            break

        frequent_itemsets[k] = Lk
        all_frequent.update(Lk)
        prev_L = Lk
        k += 1

    return frequent_itemsets, all_frequent

def generate_association_rules(all_frequent, min_confidence, min_lift, min_length, transactions=None):
    """
    Derive association rules X -> Y from a table of frequent itemsets.

    For every frequent itemset with at least two items, each split into a
    non-empty antecedent X and non-empty consequent Y (X ∪ Y = itemset,
    X ∩ Y = ∅) is scored:

        confidence = support(X ∪ Y) / support(X)
        lift       = confidence / support(Y)

    Args:
        all_frequent: dict mapping frozenset itemsets to their support.
        min_confidence: rules with lower confidence are dropped.
        min_lift: rules with lower lift are dropped.
        min_length: minimum total number of items (|X| + |Y|) in a rule.
        transactions: unused; every quantity is computed from the supports in
            all_frequent. Kept only so existing callers keep working.

    Returns:
        A list of tuples (X, Y, support, confidence, lift) where X and Y are
        frozensets and support is the support of X ∪ Y.
    """
    rules = []
    for itemset, supp_itemset in all_frequent.items():
        size = len(itemset)
        # |X| + |Y| equals the itemset size for every split, so the length
        # threshold can be decided once per itemset instead of once per rule.
        if size < 2 or size < min_length:
            continue
        # r ranges over 1..size-1, so both X and Y are always non-empty.
        for r in range(1, size):
            for subset in combinations(itemset, r):
                X = frozenset(subset)
                Y = itemset - X
                sup_X = all_frequent.get(X, 0)
                sup_Y = all_frequent.get(Y, 0)
                # A side whose support is missing from the table cannot be
                # scored (and would divide by zero), so the split is skipped.
                if sup_X <= 0 or sup_Y <= 0:
                    continue
                conf = supp_itemset / sup_X
                lift = conf / sup_Y
                if conf >= min_confidence and lift >= min_lift:
                    rules.append((X, Y, supp_itemset, conf, lift))
    return rules

def _format_itemset(items):
    """Render an itemset as a set literal with its items in sorted order."""
    # Sorting the reprs makes the text independent of set iteration order,
    # which for strings varies between interpreter runs.
    return "{" + ", ".join(sorted(repr(item) for item in items)) + "}"


def print_rules(rules):
    """
    Print rules with their support, confidence, lift in a readable table.

    Args:
        rules: iterable of tuples (X, Y, support, confidence, lift), where X
            and Y are collections of items.

    Rows are ordered by descending lift, then descending confidence. The rule
    column is at least 40 characters wide and grows to fit the longest rule,
    so the numeric columns stay aligned. Items inside each itemset are
    printed in sorted order so the output is reproducible.
    """
    rules_sorted = sorted(rules, key=lambda rule: (rule[4], rule[3]), reverse=True)
    rows = [
        (f"{_format_itemset(X)} => {_format_itemset(Y)}", supp, conf, lift)
        for (X, Y, supp, conf, lift) in rules_sorted
    ]
    # 40 is the historical column width; never go narrower than that.
    width = max([40] + [len(row[0]) for row in rows])

    print(f"{'Rule':{width}s}  {'Support':>8s}  {'Confidence':>10s}  {'Lift':>8s}")
    print("-" * (width + 30))
    for rule_str, supp, conf, lift in rows:
        print(f"{rule_str:{width}s}  {supp:8.4f}  {conf:10.4f}  {lift:8.4f}")

def main():
    """
    Run the whole pipeline on groceries.csv: load the baskets, mine frequent
    itemsets with Apriori, derive association rules that pass the thresholds
    set below, and print progress plus the resulting rule table.
    """
    # Thresholds used for mining and rule filtering
    support_floor = 0.0040
    confidence_floor = 0.2
    lift_floor = 3.0
    shortest_rule = 2

    # Location of the groceries dataset (adjust as needed)
    dataset_path = "groceries.csv"

    print("Loading transactions …")
    baskets = load_transactions(dataset_path)
    print(f"Total transactions loaded: {len(baskets)}")

    print("Running Apriori to find frequent itemsets …")
    by_size, every_frequent = get_frequent_itemsets(baskets, support_floor)
    # Total number of frequent itemsets across all sizes
    n_frequent = 0
    for level in by_size.values():
        n_frequent += len(level)
    print(f"Found {n_frequent} frequent itemsets")

    print("Generating association rules …")
    found_rules = generate_association_rules(
        every_frequent, confidence_floor, lift_floor, shortest_rule, baskets
    )
    print(f"Number of rules found: {len(found_rules)}")

    print("\n== Rules meeting thresholds ==")
    print_rules(found_rules)

# Run the pipeline only when this file is executed as a script, not when it
# is imported as a module.
if __name__ == "__main__":
    main()


Loading transactions …
Total transactions loaded: 9835
Running Apriori to find frequent itemsets …
Found 1398 frequent itemsets
Generating association rules …
Number of rules found: 153

== Rules meeting thresholds ==
Rule                                       Support  Confidence      Lift
----------------------------------------------------------------------
{'flour'} => {'sugar'}                      0.0050      0.2865    8.4631
{'tropical fruit', 'root vegetables'} => {'citrus fruit', 'other vegetables'}    0.0045      0.2126    7.3610
{'root vegetables', 'citrus fruit'} => {'tropical fruit', 'other vegetables'}    0.0045      0.2529    7.0454
{'processed cheese'} => {'white bread'}     0.0042      0.2515    5.9754
{'tropical fruit', 'whipped/sour cream'} => {'yogurt', 'whole milk'}    0.0044      0.3162    5.6435
{'tropical fruit', 'root vegetables'} => {'other vegetables', 'yogurt'}    0.0050      0.2367    5.4522
{'liquor'} => {'bottled beer'}              0.0047      0.4220    5