In [1]:
# train_trie.py
# Build a TTG-guided token-replacement trie for the PubChem-100K slice.

import time
from utils import iter_smiles
import trie_funcs as tf

# Input data set (parquet). It is handed to iter_smiles() when the trie is
# built and again when the trie metrics are computed.
SLICE = "data/pubchem_100K.parquet"
# Directory prefix that _make_out_name() puts in front of the pickle file name.
OUT_DIR = "ttg_vocab"            

In [2]:
def compute_trie_metrics(smiles_list, tokenizer):
    """
    Compute all Trie metrics (fertility, mean, variance, normalized entropy) in a single pass.

    Every SMILES string is split with ``trie_funcs.tokenize`` and the resulting
    token list is compressed with ``trie_funcs.compress`` using
    ``tokenizer.replace_root``.

    Args:
        smiles_list: SMILES source, forwarded unchanged to ``utils.iter_smiles``
            (the caller in this file passes the path of a parquet file).
        tokenizer: Trie compressor state. Only its ``replace_root`` attribute
            is read here.

    Returns:
        tuple: (fertility, mean, variance, normalized_entropy)
            fertility          -- total compressed tokens / total SMILES characters
            mean               -- average compressed-token count per SMILES
            variance           -- population variance of the per-SMILES token count
            normalized_entropy -- Shannon entropy (base 2) of the token frequencies
                                  divided by log2 of the number of distinct tokens
                                  observed; 0.0 when at most one distinct token exists
        Each value is 0 when its denominator would be zero (e.g. no SMILES read).

    Raises:
        ImportError: if ``trie_funcs`` cannot be imported; the original import
            error is attached as ``__cause__``.
    """
    from utils import iter_smiles
    from collections import Counter
    import tqdm
    import math

    # Import trie_funcs locally to avoid import issues
    try:
        import trie_funcs as tf
    except ImportError as exc:
        # Chain explicitly so the underlying reason (missing module, broken
        # dependency inside trie_funcs, ...) stays visible in the traceback.
        raise ImportError("trie_funcs module is required for Trie metrics") from exc

    total_tokens = 0
    total_chars = 0
    token_counts = []        # per-SMILES token counts, kept for the variance
    token_freq = Counter()   # corpus-wide token frequencies, for the entropy

    # Single pass through the data
    for smi in tqdm.tqdm(iter_smiles(smiles_list), desc="Trie metrics"):
        # Tokenize once per SMILES
        base = tf.tokenize(smi)  # list of atomic tokens
        tokens = tf.compress(base, tokenizer.replace_root)
        token_count = len(tokens)

        # Collect data for all metrics
        total_tokens += token_count
        total_chars += len(smi)
        token_counts.append(token_count)
        token_freq.update(tokens)

    n = len(token_counts)

    # Basic metrics; every division is guarded so empty input yields zeros.
    fertility = total_tokens / total_chars if total_chars > 0 else 0
    trie_avg = total_tokens / n if n > 0 else 0
    trie_var = sum((count - trie_avg) ** 2 for count in token_counts) / n if n > 0 else 0

    # Normalized entropy: with zero or one distinct token log2(vocab_size) is
    # undefined or zero, so the result is pinned to 0.0 in those cases.
    vocab_size = len(token_freq)
    if total_tokens == 0 or vocab_size <= 1:
        normalized_entropy = 0.0
    else:
        # Shannon entropy
        entropy = -sum((cnt / total_tokens) * math.log2(cnt / total_tokens)
                       for cnt in token_freq.values())
        # Normalize by log2(observed_vocab_size)
        normalized_entropy = entropy / math.log2(vocab_size)

    return fertility, trie_avg, trie_var, normalized_entropy

In [3]:
def _make_out_name(k: int, freq: int, ent: float) -> str:
    """
    Build the output path of the pickle for one hyper-parameter setting.

    The name encodes K, the frequency threshold and the entropy threshold,
    e.g. ttg_pubchem100K_K8_F4_H2p0.pkl, and is prefixed with OUT_DIR.
    Every "." in the textual form of the entropy value is written as "p"
    so the name stays shell-safe.
    """
    # Splitting on "." and re-joining with "p" swaps every dot for a "p".
    entropy_tag = "p".join(str(ent).split("."))
    file_name = f"ttg_pubchem100K_K{k}_F{freq}_H{entropy_tag}.pkl"
    return f"{OUT_DIR}/{file_name}"

def run_ttg(k: int = 8, freq_thr: int = 4,
            entropy_thr: float = 2.0) -> "tuple[str, float, float, float, float]":
    """
    Build a TTG-guided compressor with the given hyper-parameters, save it,
    and compute the trie metrics over the same data slice.

    Parameters
    ----------
    k : int
        Forwarded as ``K`` to ``tf.prepare_compressor_with_ttg``.
    freq_thr : int
        Forwarded as ``freq_thr`` to ``tf.prepare_compressor_with_ttg``.
    entropy_thr : float
        Forwarded as ``entropy_thr`` to ``tf.prepare_compressor_with_ttg``.

    Returns
    -------
    tuple
        ``(out_path, fertility, mean, variance, normalized_entropy)``:
        the path of the saved pickle followed by the four values returned by
        ``compute_trie_metrics``.

    Example
    -------
    (progress output omitted)

    >>> run_ttg(k=10, freq_thr=3, entropy_thr=1.5)[0]
    'ttg_vocab/ttg_pubchem100K_K10_F3_H1p5.pkl'
    """
    import os

    out_path = _make_out_name(k, freq_thr, entropy_thr)

    print(f"Building TTG-guided trie compressor (K={k}, FREQ_THR={freq_thr}, "
          f"ENTROPY_THR={entropy_thr}) …")
    # perf_counter is monotonic, so the elapsed time cannot be skewed by
    # system clock adjustments during a long build.
    t0 = time.perf_counter()

    state = tf.prepare_compressor_with_ttg(
        iter_smiles(SLICE),
        K=k,
        freq_thr=freq_thr,
        entropy_thr=entropy_thr,
    )

    # Make sure the target directory exists so a fresh checkout does not lose
    # the freshly built state to a missing-directory error at save time.
    os.makedirs(os.path.dirname(out_path) or ".", exist_ok=True)
    tf.save_state(state, out_path)
    print(f"✔ Trie saved → {out_path}  ({time.perf_counter() - t0:.1f}s)")

    trie_fert, trie_avg, trie_var, trie_ent = compute_trie_metrics(SLICE, state)

    # Elapsed time here is cumulative (build + save + metrics).
    print(f"✔ Trie Metrics Computed → {out_path}  ({time.perf_counter() - t0:.1f}s)")

    return out_path, trie_fert, trie_avg, trie_var, trie_ent

In [4]:
run_ttg()

Building TTG-guided trie compressor (K=8, FREQ_THR=4, ENTROPY_THR=2.0) …
✔ Trie saved → ttg_vocab/ttg_pubchem100K_K8_F4_H2p0.pkl  (15.2s)


Trie metrics: 100000it [00:00, 103106.26it/s]

✔ Trie Metrics Computed → ttg_vocab/ttg_pubchem100K_K8_F4_H2p0.pkl  (16.2s)





('ttg_vocab/ttg_pubchem100K_K8_F4_H2p0.pkl',
 0.5747782249481419,
 25.5174,
 203.6626972400091,
 0.33564296222247203)