# Cell 1 - Imports

In [9]:
import ast
import itertools
import math
import random
from dataclasses import dataclass, field
from typing import Any, Dict, List, Tuple, Optional

import matplotlib.pyplot as plt
import networkx as nx
import pandas as pd

import ipywidgets as widgets
from IPython.display import display, clear_output

In [10]:
# Cell 2 - Parsing utilities + hashing

def _try_literal_eval(text: str):
    text = text.strip()
    if not text:
        return None
    try:
        return ast.literal_eval(text)
    except Exception:
        return None

def parse_itemsets(text: str) -> List[Tuple[Any, ...]]:
    """
    Parse user-entered itemsets into a list of sorted tuples.

    Accepts:
      - Python literal: [[1,2,3],[1,3,4]]  (a flat literal such as [1,2,3]
        is read as a single itemset)
      - One itemset per line: 1 2 3
      - Comma/space separated per line: 1,2,3
      - Semicolon-separated itemsets: 1 2 3; 1 3 4

    A lone scalar literal (e.g. ``5``) is the one-item itemset ``(5,)``, just
    as it would be on any line of multi-line input. An empty literal (``[]``)
    and lines containing no tokens produce no itemsets. In the line-based
    format each token becomes an int if possible, else a float, else stays a
    string.

    Returns sorted tuples.

    Raises:
        ValueError: if the text is a Python literal that cannot be read as
            itemsets (e.g. a dict, or a list mixing itemsets with bare items).
    """
    containers = (list, tuple, set)

    lit = _try_literal_eval(text)
    if lit is not None:
        if isinstance(lit, (int, float, str)):
            # A single bare item, e.g. "5" -> one itemset containing 5.
            return [(lit,)]
        if isinstance(lit, set):
            return [tuple(sorted(lit))]
        if isinstance(lit, (list, tuple)):
            if not lit:
                return []
            nested = [isinstance(e, containers) for e in lit]
            if all(nested):
                # List of itemsets.
                return [tuple(sorted(e)) for e in lit]
            if not any(nested):
                # Flat list: a single itemset.
                return [tuple(sorted(lit))]
        raise ValueError("Could not interpret candidates literal.")

    def to_item(tok: str) -> Any:
        # Prefer int, then float; otherwise keep the raw token string.
        for convert in (int, float):
            try:
                return convert(tok)
            except ValueError:
                pass
        return tok

    # Fallback: one itemset per line or per ';'-separated chunk.
    itemsets: List[Tuple[Any, ...]] = []
    for line in text.replace(";", "\n").splitlines():
        tokens = line.replace(",", " ").split()
        if tokens:
            itemsets.append(tuple(sorted(to_item(t) for t in tokens)))
    return itemsets

def parse_transactions(text: str) -> List[Tuple[Any, ...]]:
    """
    Parse user-entered transactions into a list of sorted tuples.

    Accepts the same formats as parse_itemsets. Python-literal input is
    interpreted here (so errors mention transactions); any other text is
    handed to parse_itemsets, which parses it line by line. A lone scalar
    literal (e.g. ``5``) is the one-item transaction ``(5,)``, and an empty
    literal (``[]``) yields no transactions. Duplicate items inside a
    transaction are kept.

    Raises:
        ValueError: if the text is a Python literal that cannot be read as
            transactions (e.g. a dict, or a list mixing transactions with
            bare items).
    """
    containers = (list, tuple, set)

    lit = _try_literal_eval(text)
    if lit is None:
        return parse_itemsets(text)
    if isinstance(lit, (int, float, str)):
        return [(lit,)]
    if isinstance(lit, set):
        return [tuple(sorted(lit))]
    if isinstance(lit, (list, tuple)):
        if not lit:
            return []
        nested = [isinstance(t, containers) for t in lit]
        if all(nested):
            return [tuple(sorted(t)) for t in lit]
        if not any(nested):
            return [tuple(sorted(lit))]
    raise ValueError("Could not interpret transactions literal.")

def default_hash_key(item: Any, mod: int) -> int:
    """
    Map an item to a branch key in [0, mod - 1].

    - integers (including bools and other ``numbers.Integral`` types such as
      numpy ints): item % mod
    - floats with an integral value (e.g. 3.0): treated like the equal int.
      Items that compare equal (3 == 3.0) therefore always take the same
      branch. This matters because hash-tree counting routes by this key but
      confirms membership by equality.
    - anything else: deterministic polynomial hash of str(item). Python's
      built-in hash() is randomized for strings across runs, so it is avoided.

    Raises:
        ValueError: if mod < 1.
    """
    import numbers

    if mod <= 0:
        raise ValueError("mod must be >= 1")
    if isinstance(item, float) and item.is_integer():
        item = int(item)
    if isinstance(item, numbers.Integral):
        return int(item) % mod
    h = 0
    for ch in str(item):
        h = (h * 131 + ord(ch)) % 2_147_483_647
    return h % mod


In [11]:
# Cell 3 - Hash Tree core

@dataclass
class HashTreeNode:
    """
    Node of a candidate hash tree.

    A leaf stores candidates in ``bucket``. An internal node routes each
    candidate to ``children[key]``, where
    ``key = default_hash_key(cand[depth], mod)``.
    """

    depth: int  # item position hashed on when this node is internal; root is 0
    is_leaf: bool = True
    bucket: List[Tuple[Any, ...]] = field(default_factory=list)   # candidates at leaf
    children: Dict[int, "HashTreeNode"] = field(default_factory=dict)  # hash key -> child

    def _route(self, cand: Tuple[Any, ...], mod: int, node_size: int, k: int):
        """Insert ``cand`` into the child chosen by hashing ``cand[self.depth]``, creating it if absent."""
        slot = default_hash_key(cand[self.depth], mod)
        child = self.children.get(slot)
        if child is None:
            child = HashTreeNode(depth=self.depth + 1, is_leaf=True)
            self.children[slot] = child
        child.insert(cand, mod=mod, node_size=node_size, k=k)

    def split(self, mod: int, node_size: int, k: int):
        """
        Convert this leaf into an internal node and re-insert its candidates
        into children keyed on the item at position ``depth``.

        Does nothing for internal nodes, or for leaves at depth >= k (no item
        position left to hash on); such leaves simply keep growing.
        """
        if self.is_leaf and self.depth < k:
            pending, self.bucket = self.bucket, []
            self.is_leaf = False
            for cand in pending:
                self._route(cand, mod, node_size, k)

    def insert(self, cand: Tuple[Any, ...], mod: int, node_size: int, k: int):
        """
        Add ``cand`` to the subtree rooted at this node.

        A leaf holding more than ``node_size`` candidates is split, as long as
        it is above depth ``k``.
        """
        if not self.is_leaf and self.depth < k:
            self._route(cand, mod, node_size, k)
            return
        # Leaf case. This path also covers an internal node at depth >= k,
        # which split() never creates; as a defensive fallback it is turned
        # back into a leaf.
        self.is_leaf = True
        self.bucket.append(cand)
        if len(self.bucket) > node_size and self.depth < k:
            self.split(mod=mod, node_size=node_size, k=k)

@dataclass
class HashTree:
    """Hash tree whose candidates all have exactly ``k`` items."""

    k: int  # required length of every candidate
    mod: int  # passed to the node hash: branch keys lie in [0, mod - 1]
    node_size: int  # a leaf splits once it holds more than this many candidates
    root: HashTreeNode = field(default_factory=lambda: HashTreeNode(depth=0, is_leaf=True))

    def add_candidates(self, candidates: List[Tuple[Any, ...]]):
        """
        Insert each candidate into the tree in order.

        Raises ValueError at the first candidate whose length is not ``k``;
        candidates before it have already been inserted.
        """
        for cand in candidates:
            if len(cand) == self.k:
                self.root.insert(cand, mod=self.mod, node_size=self.node_size, k=self.k)
                continue
            raise ValueError(f"All candidates must have length k={self.k}. Found {cand}.")

def build_hash_trees(candidates: List[Tuple[Any, ...]], mod: int, node_size: int) -> Dict[int, HashTree]:
    """
    Build one hash tree per candidate length.

    Each candidate is sorted before insertion. Candidates are grouped by
    length k, and each group goes into its own HashTree.

    Returns:
        {k: HashTree}, with keys inserted in ascending k.
    """
    grouped: Dict[int, List[Tuple[Any, ...]]] = {}
    for cand in candidates:
        size = len(cand)
        if size not in grouped:
            grouped[size] = []
        grouped[size].append(tuple(sorted(cand)))

    trees: Dict[int, HashTree] = {}
    for size in sorted(grouped):
        tree = HashTree(k=size, mod=mod, node_size=node_size)
        tree.add_candidates(grouped[size])
        trees[size] = tree
    return trees


In [12]:
# Cell 4 - Visualization + support counting

def _tree_to_digraph(root: HashTreeNode, show_leaf_buckets: bool = False) -> nx.DiGraph:
    """
    Convert the hash (sub)tree under ``root`` into a networkx DiGraph for plotting.

    Nodes are named "n1", "n2", ... in pre-order and carry a multi-line
    ``label`` attribute (node kind, depth, and bucket or child count, plus
    the bucket contents when ``show_leaf_buckets`` is set). Each edge's
    ``label`` is the child's hash key. Children are visited in ascending key
    order.
    """
    graph = nx.DiGraph()
    ids = itertools.count(1)

    def describe(node: HashTreeNode) -> str:
        if not node.is_leaf:
            return f"Internal d={node.depth}\n|ch|={len(node.children)}"
        text = f"Leaf d={node.depth}\n|C|={len(node.bucket)}"
        if show_leaf_buckets:
            text += "\n" + "\n".join(str(c) for c in node.bucket)
        return text

    def visit(node: HashTreeNode) -> str:
        node_id = "n" + str(next(ids))
        graph.add_node(node_id, label=describe(node))
        if not node.is_leaf:
            for key in sorted(node.children):
                # visit() runs before add_edge, so the child is numbered first.
                graph.add_edge(node_id, visit(node.children[key]), label=str(key))
        return node_id

    visit(root)
    return graph

def draw_hash_tree(tree: HashTree, show_leaf_buckets: bool = False, figsize=(12, 7)):
    """
    Plot ``tree`` with matplotlib and show the figure.

    Tries Graphviz's "dot" layout through networkx's pydot interface. If that
    fails for any reason, it falls back to a spring layout with a fixed seed.
    Node and edge labels come from the ``label`` attributes set by
    _tree_to_digraph.
    """
    graph = _tree_to_digraph(tree.root, show_leaf_buckets=show_leaf_buckets)

    # Graphviz/pydot are optional here; any failure means "use the fallback layout".
    try:
        layout = nx.nx_pydot.graphviz_layout(graph, prog="dot")
    except Exception:
        layout = nx.spring_layout(graph, seed=7)

    node_labels = nx.get_node_attributes(graph, "label")
    edge_labels = nx.get_edge_attributes(graph, "label")

    plt.figure(figsize=figsize)
    nx.draw(graph, layout, with_labels=False, arrows=True, node_size=2200)
    nx.draw_networkx_labels(graph, layout, labels=node_labels, font_size=8)
    nx.draw_networkx_edge_labels(graph, layout, edge_labels=edge_labels, font_size=8)
    plt.title(f"Hash Tree (k={tree.k}, mod={tree.mod}, node_size={tree.node_size})")
    plt.axis("off")
    plt.show()

def count_supports(tree: "HashTree", transactions: List[Tuple[Any, ...]]) -> Dict[Tuple[Any, ...], int]:
    """
    Count how many transactions contain each candidate stored in ``tree``.

    Classic hash-tree counting:
      - At internal nodes, hash each transaction item after the previously
        chosen one (items in sorted order) to pick a child.
      - At leaves, verify the subset relation for each bucket candidate.

    Different items can hash to the same key, so one transaction may reach
    the same leaf along several paths. To keep supports exact, each leaf is
    checked at most once per transaction, and each candidate is credited at
    most once per transaction. Duplicate candidates are therefore also
    counted once.

    Args:
        tree: hash tree whose leaves hold the candidates.
        transactions: iterables of items; repeated items inside one
            transaction are ignored.

    Returns:
        Mapping candidate -> number of transactions containing it.
        Candidates with zero support are absent.
    """
    counts: Dict[Tuple[Any, ...], int] = {}
    mod = tree.mod

    def traverse(node, items: List[Any], start_idx: int, tset: set,
                 visited_leaves: set, matched: Dict[Tuple[Any, ...], None]):
        if node.is_leaf:
            if id(node) in visited_leaves:
                return
            visited_leaves.add(id(node))
            for cand in node.bucket:
                if cand not in matched and all(x in tset for x in cand):
                    matched[cand] = None
            return

        # Choose the item for this depth from the items after the previous pick.
        for i in range(start_idx, len(items)):
            child = node.children.get(default_hash_key(items[i], mod))
            if child is not None:
                traverse(child, items, i + 1, tset, visited_leaves, matched)

    for t in transactions:
        items = sorted(set(t))
        if len(items) < tree.k:
            continue
        tset = set(items)
        # Insertion-ordered set of candidates found in this transaction.
        matched: Dict[Tuple[Any, ...], None] = {}
        traverse(tree.root, items, 0, tset, set(), matched)
        for cand in matched:
            counts[cand] = counts.get(cand, 0) + 1

    return counts

def supports_table(candidates: List[Tuple[Any, ...]], counts: Dict[Tuple[Any, ...], int]) -> pd.DataFrame:
    """
    Tabulate the support of each candidate.

    Rows are ordered by support (descending), then candidate (ascending), with
    a fresh 0..n-1 index. Candidates missing from ``counts`` get support 0.
    The frame always has the ``candidate`` and ``support`` columns, even when
    ``candidates`` is empty.
    """
    rows = [{"candidate": c, "support": int(counts.get(c, 0))} for c in candidates]
    # Explicit columns: with no rows pandas would otherwise create a frame
    # without these columns and sort_values would raise KeyError.
    df = pd.DataFrame(rows, columns=["candidate", "support"])
    return df.sort_values(["support", "candidate"], ascending=[False, True]).reset_index(drop=True)
