In [1]:
# %%
"""
Ex8_1 (modified): Sequence-Level Softmax Routing with Frozen Experts and Load-Balancing Loss
CPU-friendly PyTorch notebook

Changes:
 - Domains/topics replaced with: math, legal, biokem, storytelling
 - Synthetic generators and test samples updated accordingly
 - Rest of the pipeline kept intact
"""

# %%
# Imports & Config
import os
import random
import math
from collections import defaultdict, Counter
import numpy as np
import matplotlib.pyplot as plt
import csv

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

from transformers import GPT2Tokenizer
import seaborn as sns

# Config (tweak for CPU speed)
CONFIG = {
    'hidden_size': 128,
    'num_experts': 4,
    'expert_layers': 1,
    'pretrain_epochs': 2,
    'pretrain_n_per_domain': 300,
    'moe_train_epochs': 4,
    'moe_n_per_domain': 600,
    'batch_size': 8,
    'lr': 1e-4,
    'seed_list': [42],
    'use_pretrained_embeddings': False,
    'freeze_pretrained_for': 2,   # freeze expert weights for first k MOE epochs
    'entropy_reg_seq': 0.5,      # weight to penalize per-sequence entropy (encourage low entropy)
    'entropy_reg_mean': 0.5,     # weight to penalize mean-gate entropy across batch (encourage non-uniform mean)
    'temperature': 0.7,
    'save_dir': 'ex8_1_results'
}

os.makedirs(CONFIG['save_dir'], exist_ok=True)

# Reproducibility helper
def set_seed(s):
    random.seed(s)
    np.random.seed(s)
    torch.manual_seed(s)

# %%
# Synthetic dataset (topics replaced: math, legal, biokem, storytelling)
class SpecializedTextDataset(Dataset):
    def __init__(self, tokenizer, domain=None, n_per_domain=300, max_length=64):
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.data = []
        self.labels = []

        def gen_math(n):
            templates = [
                "We prove that {statement} by induction on {index}.",
                "The {object} has eigenvalues given by {expr}, which implies {consequence}.",
                "A simple algorithm for {problem} runs in O({complexity}) time and uses {space} space."
            ]
            statements = ["the sequence converges", "the function is continuous", "the series is divergent"]
            objects = ["matrix A", "operator T", "linear map"]
            exprs = ["λ_i = i^2", "roots r_1 and r_2", "sinusoids at harmonics"]
            problems = ["sorting", "graph traversal", "prime testing"]
            complexities = ["n log n", "n^2", "poly(n)"]
            spaces = ["O(1)", "O(n)"]
            out = []
            for _ in range(n):
                out.append(random.choice(templates).format(
                    statement=random.choice(statements),
                    index=random.choice(["n","k"]),
                    object=random.choice(objects),
                    expr=random.choice(exprs),
                    consequence=random.choice(["stability","boundedness","uniqueness"]),
                    problem=random.choice(problems),
                    complexity=random.choice(complexities),
                    space=random.choice(spaces)
                ))
            return out

        def gen_legal(n):
            templates = [
                "The court found that the {contract_term} breached the statute under {law}.",
                "Under {jurisdiction} law, the party may seek {remedy} for {violation}.",
                "The precedent in {case} clarifies the interpretation of {doctrine}."
            ]
            contract_terms = ["non-compete clause", "warranty provision", "confidentiality clause"]
            laws = ["Section 12", "tort law", "consumer protection statutes"]
            jurisdictions = ["federal", "state", "EU"]
            remedies = ["injunctive relief", "damages", "restitution"]
            violations = ["fraud", "negligence", "breach of contract"]
            cases = ["Smith v. Jones", "R v. Corporation"]
            doctrines = ["reasonable expectation", "strict liability", "duty of care"]
            out = []
            for _ in range(n):
                out.append(random.choice(templates).format(
                    contract_term=random.choice(contract_terms),
                    law=random.choice(laws),
                    jurisdiction=random.choice(jurisdictions),
                    remedy=random.choice(remedies),
                    violation=random.choice(violations),
                    case=random.choice(cases),
                    doctrine=random.choice(doctrines)
                ))
            return out

        def gen_biokem(n):
            # "biokem" used as shorthand for biochemical/biokemistry style text
            templates = [
                "The assay measured {molecule} concentration after {treatment}.",
                "{enzyme} catalyzes the conversion of {substrate} to {product} with Km={km} and Vmax={vmax}.",
                "Mass spectrometry revealed a peak corresponding to {compound} consistent with {interpretation}."
            ]
            molecules = ["ATP", "glucose", "lactate"]
            treatments = ["incubation", "heat shock", "drug exposure"]
            enzymes = ["hexokinase", "polymerase", "lipase"]
            substrates = ["glucose", "DNA", "lipid"]
            products = ["G6P", "cDNA", "fatty acids"]
            kms = ["0.1 mM", "5 µM", "50 µM"]
            vmaxs = ["100 nmol/min", "2 µmol/min", "0.5 µmol/min"]
            compounds = ["peptide A", "metabolite X"]
            interpretations = ["metabolic activation", "post-translational modification"]
            out = []
            for _ in range(n):
                out.append(random.choice(templates).format(
                    molecule=random.choice(molecules),
                    treatment=random.choice(treatments),
                    enzyme=random.choice(enzymes),
                    substrate=random.choice(substrates),
                    product=random.choice(products),
                    km=random.choice(kms),
                    vmax=random.choice(vmaxs),
                    compound=random.choice(compounds),
                    interpretation=random.choice(interpretations)
                ))
            return out

        def gen_storytelling(n):
            templates = [
                "When {character} entered the {place}, they found a {object} that changed everything.",
                "The narrative follows {protagonist} as they face {conflict} and discover {reveal}.",
                "A quiet scene at {setting} escalates into a confrontation about {theme}."
            ]
            characters = ["an old sailor", "a young coder", "a grieving parent"]
            places = ["abandoned pier", "neon-lit café", "derelict apartment"]
            objects = ["tattered map", "burned letter", "silver locket"]
            protagonists = ["Mara", "Jon", "the narrator"]
            conflicts = ["a moral dilemma", "a long-buried secret", "a financial crisis"]
            reveals = ["their true origin", "a hidden motive", "an unexpected ally"]
            settings = ["dawn", "midnight", "a crowded market"]
            themes = ["loss", "ambition", "betrayal"]
            out = []
            for _ in range(n):
                out.append(random.choice(templates).format(
                    character=random.choice(characters),
                    place=random.choice(places),
                    object=random.choice(objects),
                    protagonist=random.choice(protagonists),
                    conflict=random.choice(conflicts),
                    reveal=random.choice(reveals),
                    setting=random.choice(settings),
                    theme=random.choice(themes)
                ))
            return out

        pools = {
            'math': gen_math(n_per_domain),
            'legal': gen_legal(n_per_domain),
            'biokem': gen_biokem(n_per_domain),
            'storytelling': gen_storytelling(n_per_domain)
        }

        if domain is None:
            for d, samples in pools.items():
                for s in samples:
                    self.data.append(s)
                    self.labels.append(d)
        else:
            for s in pools[domain]:
                self.data.append(s)
                self.labels.append(domain)

        combined = list(zip(self.data, self.labels))
        random.shuffle(combined)
        self.data, self.labels = zip(*combined)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        text = self.data[idx]
        label = self.labels[idx]
        enc = tokenizer(text, max_length=self.max_length, padding='max_length', truncation=True, return_tensors='pt')
        return {'input_ids': enc['input_ids'].squeeze(0), 'attention_mask': enc['attention_mask'].squeeze(0), 'domain': label}

# %%
# Model: experts + sequence-level softmax router + dense mixing
class TransformerExpert(nn.Module):
    def __init__(self, hidden_size=CONFIG['hidden_size'], num_layers=CONFIG['expert_layers']):
        super().__init__()
        encoder_layer = nn.TransformerEncoderLayer(d_model=hidden_size, nhead=4, dim_feedforward=hidden_size*2, batch_first=True)
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)

    def forward(self, hidden, src_key_padding_mask=None):
        return self.transformer(hidden, src_key_padding_mask=src_key_padding_mask)

class SeqRouterSoftmax(nn.Module):
    def __init__(self, hidden_size=CONFIG['hidden_size'], num_experts=CONFIG['num_experts'], hidden_mid=None):
        super().__init__()
        hidden_mid = hidden_mid or max(16, hidden_size//4)
        self.net = nn.Sequential(nn.Linear(hidden_size, hidden_mid), nn.ReLU(), nn.Linear(hidden_mid, num_experts))

    def forward(self, pooled, temperature=1.0):
        logits = self.net(pooled) / (temperature + 1e-12)
        probs = F.softmax(logits, dim=-1)
        return logits, probs

class TransformerMoESeqSoftmax(nn.Module):
    def __init__(self, vocab_size, hidden_size=CONFIG['hidden_size'], num_experts=CONFIG['num_experts'], num_layers=CONFIG['expert_layers'], temperature=CONFIG['temperature']):
        super().__init__()
        self.emb = nn.Embedding(vocab_size, hidden_size)
        self.router = SeqRouterSoftmax(hidden_size, num_experts)
        self.experts = nn.ModuleList([TransformerExpert(hidden_size, num_layers) for _ in range(num_experts)])
        self.lm_head = nn.Linear(hidden_size, vocab_size)
        self.num_experts = num_experts
        self.temperature = temperature

    def forward(self, input_ids, attention_mask=None, return_routing=False):
        B, L = input_ids.shape
        hidden = self.emb(input_ids)
        if attention_mask is not None:
            lengths = attention_mask.sum(dim=1, keepdim=True).clamp(min=1)
            pooled = (hidden * attention_mask.unsqueeze(-1)).sum(dim=1) / lengths
        else:
            pooled = hidden.mean(dim=1)

        logits_seq, probs_seq = self.router(pooled, temperature=self.temperature)
        # compute outputs for each expert
        expert_outputs = []
        for e in range(self.num_experts):
            out = self.experts[e](hidden)
            expert_outputs.append(out)
        expert_outputs = torch.stack(expert_outputs, dim=1)  # [B, E, L, H]

        # mix experts densely using probs_seq weights
        probs_seq_exp = probs_seq.unsqueeze(-1).unsqueeze(-1)  # [B, E, 1, 1]
        mixed = (expert_outputs * probs_seq_exp).sum(dim=1)  # [B, L, H]

        logits = self.lm_head(mixed)
        if return_routing:
            return logits, probs_seq
        return logits

    def load_pretrained_expert(self, idx, state_dict):
        self.experts[idx].load_state_dict(state_dict)
        print(f"Loaded expert into slot {idx}")

# %%
# Pretrain experts lightly (Phase 1)
def pretrain_expert_light(domain, tokenizer, vocab_size, device, n_samples=CONFIG['pretrain_n_per_domain'], epochs=CONFIG['pretrain_epochs']):
    ds = SpecializedTextDataset(tokenizer, domain=domain, n_per_domain=n_samples)
    loader = DataLoader(ds, batch_size=CONFIG['batch_size'], shuffle=True)

    expert = TransformerExpert(CONFIG['hidden_size'], CONFIG['expert_layers']).to(device)
    emb = nn.Embedding(vocab_size, CONFIG['hidden_size']).to(device)
    lm_head = nn.Linear(CONFIG['hidden_size'], vocab_size).to(device)

    opt = torch.optim.Adam(list(expert.parameters()) + list(emb.parameters()) + list(lm_head.parameters()), lr=CONFIG['lr'])
    criterion = nn.CrossEntropyLoss(ignore_index=tokenizer.pad_token_id)

    for epoch in range(epochs):
        expert.train()
        total, count = 0.0, 0
        for batch in loader:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            hidden = emb(input_ids)
            out = expert(hidden)
            logits = lm_head(out)
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = input_ids[..., 1:].contiguous()
            loss = criterion(shift_logits.view(-1, vocab_size), shift_labels.view(-1))
            opt.zero_grad(); loss.backward(); opt.step()
            total += loss.item(); count += 1
        print(f"Pretrain {domain}: epoch {epoch+1}/{epochs} avg loss {total/count:.4f}")

    return expert.state_dict(), emb.state_dict()

# %%
# Transfer + train MoE with frozen experts and entropy penalties (Phase 2)
def transfer_and_train_moe_softmax(pre_states, pre_embeds, tokenizer, vocab_size, device):
    moe = TransformerMoESeqSoftmax(vocab_size, CONFIG['hidden_size'], CONFIG['num_experts'], CONFIG['expert_layers'], temperature=CONFIG['temperature']).to(device)

    # load pretrained experts
    for i, st in enumerate(pre_states):
        moe.load_pretrained_expert(i, st)

    # optionally average embeddings
    if CONFIG['use_pretrained_embeddings']:
        avg = None
        for e in pre_embeds:
            w = torch.tensor(e['weight'])
            if avg is None: avg = w.clone()
            else: avg += w
        avg = (avg / len(pre_embeds)).to(moe.emb.weight.device)
        moe.emb.weight.data.copy_(avg)

    # dataset
    mixed = SpecializedTextDataset(tokenizer, domain=None, n_per_domain=CONFIG['moe_n_per_domain'])
    loader = DataLoader(mixed, batch_size=CONFIG['batch_size'], shuffle=True)

    opt = torch.optim.Adam(moe.parameters(), lr=CONFIG['lr'])
    criterion = nn.CrossEntropyLoss(ignore_index=tokenizer.pad_token_id)

    # freeze experts for initial epochs
    freeze_k = CONFIG['freeze_pretrained_for']

    for epoch in range(CONFIG['moe_train_epochs']):
        if epoch < freeze_k:
            for p in moe.experts.parameters(): p.requires_grad = False
            print(f"Epoch {epoch+1}: experts frozen")
        elif epoch == freeze_k:
            for p in moe.experts.parameters(): p.requires_grad = True
            print(f"Epoch {epoch+1}: experts unfrozen")

        moe.train()
        total, count = 0.0, 0
        seq_entropies = []
        mean_gate_entropies = []
        for batch in loader:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            logits, probs = moe(input_ids, attention_mask, return_routing=True)
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = input_ids[..., 1:].contiguous()
            loss = criterion(shift_logits.view(-1, vocab_size), shift_labels.view(-1))

            # per-sequence entropy (encourage low entropy -> confident gating)
            seq_entropy = - (probs * (probs + 1e-12).log()).sum(dim=-1).mean()
            mean_gates = probs.mean(dim=0)
            mean_gate_entropy = - (mean_gates * (mean_gates + 1e-12).log()).sum()

            # add penalties (positive weights penalize high entropy)
            loss = loss + CONFIG['entropy_reg_seq'] * seq_entropy + CONFIG['entropy_reg_mean'] * mean_gate_entropy

            opt.zero_grad(); loss.backward(); opt.step()
            total += loss.item(); count += 1
            seq_entropies.append(seq_entropy.item()); mean_gate_entropies.append(mean_gate_entropy.item())

        print(f"Epoch {epoch+1}/{CONFIG['moe_train_epochs']} avg loss {total/count:.4f} avg_seq_ent {np.mean(seq_entropies):.4f} avg_mean_gate_ent {np.mean(mean_gate_entropies):.4f}")

    return moe

# %%
# Analysis: compute usage heatmap, DDI, entropy per domain

def analyze_moe(moe, tokenizer, device):
    print('\nAnalyzing routing...')
    test_samples = {
        'math': [
            "We show the sequence is Cauchy and hence convergent in the given metric.",
            "An algorithm for prime sieving runs in n log log n time in practice.",
            "The eigenvalues of the symmetric matrix are all real and bounded."
        ],
        'legal': [
            "The contract's non-compete clause was struck down under state precedent.",
            "Under EU law, the plaintiff may claim damages for the breach of statutory duty.",
            "The appellate decision clarified the doctrine of reasonable expectation."
        ],
        'biokem': [
            "The enzyme kinetics showed a Michaelis-Menten curve with Km around 5 µM.",
            "After drug exposure the assay detected increased levels of lactate.",
            "Mass spectrometry confirmed the presence of peptide A consistent with activation."
        ],
        'storytelling': [
            "When Mara opened the letter she discovered a map leading to the abandoned pier.",
            "A quiet morning escalated into a confrontation about betrayal at the café.",
            "The protagonist faces a moral dilemma that forces a difficult choice."
        ],
    }
    domains = ['math', 'legal', 'biokem', 'storytelling']

    moe.eval()
    routing_avgs = {d: [] for d in domains}
    seq_entropies = {d: [] for d in domains}

    with torch.no_grad():
        for d, samples in test_samples.items():
            for s in samples:
                enc = tokenizer(s, max_length=64, padding='max_length', truncation=True, return_tensors='pt')
                input_ids = enc['input_ids'].to(device)
                attention_mask = enc['attention_mask'].to(device)
                logits, probs = moe(input_ids, attention_mask, return_routing=True)
                probs_np = probs[0].cpu().numpy()
                routing_avgs[d].append(probs_np)
                ent = - (probs * (probs + 1e-12).log()).sum(dim=-1).item()
                seq_entropies[d].append(ent)

    usage = np.zeros((len(domains), moe.num_experts))
    for i, d in enumerate(domains):
        arr = np.array(routing_avgs[d]) if routing_avgs[d] else np.zeros((1, moe.num_experts))
        usage[i] = arr.mean(axis=0) * 100.0

    # DDI: mean of max per-row fraction (normalized 0..1)
    row_max = usage.max(axis=1)
    ddi = float(row_max.mean() / 100.0)

    mean_ent = {d: float(np.mean(seq_entropies[d])) for d in domains}

    # save CSV and heatmap
    csv_path = os.path.join(CONFIG['save_dir'], 'ex8_1_usage.csv')
    with open(csv_path, 'w', newline='') as f:
        writer = csv.writer(f)
        writer.writerow(['domain'] + [f'E{e}' for e in range(moe.num_experts)])
        for i, d in enumerate(domains):
            writer.writerow([d] + usage[i].tolist())

    plt.figure(figsize=(6,4))
    sns.heatmap(usage, annot=True, fmt='.1f', cmap='YlOrRd', xticklabels=[f'E{e}' for e in range(moe.num_experts)], yticklabels=[d.capitalize() for d in domains])
    plt.title('Sequence-level Average Expert Usage (%) by Domain')
    plt.tight_layout()
    plt.savefig(os.path.join(CONFIG['save_dir'], 'ex8_1_usage_heatmap.png'), dpi=150, bbox_inches='tight')
    plt.close()

    return usage, ddi, mean_ent

# %%
# Multi-seed runner

def run_single_seed(seed):
    print(f"\n=== Seed {seed} ===")
    set_seed(seed)
    device = torch.device('cpu')
    tok = GPT2Tokenizer.from_pretrained('gpt2')
    tok.pad_token = tok.eos_token
    global tokenizer
    tokenizer = tok
    vocab_size = tokenizer.vocab_size

    # Phase 1: pretrain experts
    domains = ['math', 'legal', 'biokem', 'storytelling']
    pre_states, pre_embeds = [], []
    for d in domains:
        st, eb = pretrain_expert_light(d, tokenizer, vocab_size, device)
        pre_states.append(st); pre_embeds.append(eb)

    # Phase 2: transfer & train
    moe = transfer_and_train_moe_softmax(pre_states, pre_embeds, tokenizer, vocab_size, device)

    # Analyze
    usage, ddi, mean_ent = analyze_moe(moe, tokenizer, device)
    print('DDI:', ddi)
    print('Mean entropy per domain:', mean_ent)

    return {'usage': usage, 'ddi': ddi, 'entropy': mean_ent}


def run_multi_seed():
    results = []
    for s in CONFIG['seed_list']:
        res = run_single_seed(s)
        results.append(res)
    # aggregate DDI
    ddis = [r['ddi'] for r in results]
    print('\nMulti-seed DDI mean:', np.mean(ddis), 'std:', np.std(ddis))
    return results

# %%
# Entry
if __name__ == '__main__':
    res = run_multi_seed()
    print('\nFinished. Results saved to', CONFIG['save_dir'])

# %%
# Small textual summary for LLM consumption
summary = []
summary.append('Ex8_1 run summary (topics: math, legal, biokem, storytelling):')
summary.append(f"Config: {CONFIG}")
print('\n'.join(summary))

# End of notebook



=== Seed 42 ===
Pretrain math: epoch 1/2 avg loss 10.4968
Pretrain math: epoch 2/2 avg loss 9.4091
Pretrain legal: epoch 1/2 avg loss 10.4417
Pretrain legal: epoch 2/2 avg loss 9.1212
Pretrain biokem: epoch 1/2 avg loss 10.4311
Pretrain biokem: epoch 2/2 avg loss 9.3107
Pretrain storytelling: epoch 1/2 avg loss 10.4534
Pretrain storytelling: epoch 2/2 avg loss 9.2630
Loaded expert into slot 0
Loaded expert into slot 1
Loaded expert into slot 2
Loaded expert into slot 3
Epoch 1: experts frozen
Epoch 1/4 avg loss 11.0638 avg_seq_ent 1.2067 avg_mean_gate_ent 1.2172
Epoch 2: experts frozen
Epoch 2/4 avg loss 6.9722 avg_seq_ent 0.2914 avg_mean_gate_ent 0.3070
Epoch 3: experts unfrozen
Epoch 3/4 avg loss 3.1954 avg_seq_ent 0.0516 avg_mean_gate_ent 0.0546
Epoch 4/4 avg loss 1.6950 avg_seq_ent 0.0236 avg_mean_gate_ent 0.0250

Analyzing routing...
DDI: 0.9925249481201172
Mean entropy per domain: {'math': 0.03245699033141136, 'legal': 0.08825654909014702, 'biokem': 0.054280150681734085, 'storyt