In [None]:
import json
import random
from pathlib import Path

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

# ------------------------
# Config
# ------------------------
TIERS = {
    "0_1k":    "sequences_0_1k.json",
    "1k_2k":   "sequences_1k_2k.json",
    "2k_7k":   "sequences_2k_7k.json",
    "7k_plus": "sequences_7k_plus.json",
}

SEQ_LEN = 30        # kept short — avg sequence lengths are 11-26
EMBED_DIM = 64
HIDDEN_DIM = 128
NUM_LAYERS = 1
LR = 1e-3
EPOCHS = 30
BATCH_SIZE = 32
VAL_SPLIT = 0.1
SEED = 42           # makes the train/val split, weight init and batch order reproducible

# Seed every RNG the notebook uses: `random` drives the train/val shuffle,
# torch drives parameter init and DataLoader(shuffle=True) ordering.
random.seed(SEED)
torch.manual_seed(SEED)

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Device: {DEVICE}")

In [None]:
def load_sequences(path: str):
    """Load the trick sequences stored under the "sequences" key of a JSON file.

    Args:
        path: path to a JSON file whose top-level object has a "sequences"
            entry (presumably a list of lists of trick names -- TODO confirm
            against whatever produces these files).

    Returns:
        The value of ``data["sequences"]``.

    Raises:
        KeyError: if the JSON object has no "sequences" key.
    """
    # Explicit UTF-8: open()'s default encoding is locale-dependent (e.g.
    # cp1252 on Windows), which would mis-decode non-ASCII trick names.
    with open(path, "r", encoding="utf-8") as f:
        data = json.load(f)
    return data["sequences"]


def build_vocab(sequences, pad_token="<pad>"):
    """Build a trick vocabulary with index 0 reserved for padding.

    Histories are left-padded with id 0 elsewhere in this notebook. If id 0
    belonged to a real trick, padding would be indistinguishable from that
    trick, so by default index 0 is a dedicated padding token.

    Args:
        sequences: iterable of sequences of trick names.
        pad_token: token stored at index 0. Pass ``None`` to get the legacy
            layout where real tricks start at index 0.

    Returns:
        (vocab, stoi, itos): ``vocab`` lists tokens in id order (padding token
        first, then tricks sorted), ``stoi`` maps token -> id and ``itos`` maps
        id -> token.

    Raises:
        ValueError: if a real trick has the same name as ``pad_token``.
    """
    trick_set = {t for seq in sequences for t in seq}
    if pad_token is not None and pad_token in trick_set:
        raise ValueError(f"pad_token {pad_token!r} collides with a real trick name")
    tricks = sorted(trick_set)
    vocab = tricks if pad_token is None else [pad_token] + tricks
    stoi = {t: i for i, t in enumerate(vocab)}
    itos = {i: t for t, i in stoi.items()}
    return vocab, stoi, itos


# Preview all tiers: sequence count, distinct trick names, mean sequence length.
for tier_name, tier_file in TIERS.items():
    seqs = load_sequences(tier_file)
    # Count distinct trick names straight from the data.
    n_tricks = len({t for s in seqs for t in s})
    lens = [len(s) for s in seqs]
    # Guard against an empty tier file instead of dividing by zero.
    avg_len = sum(lens) / len(lens) if lens else 0.0
    print(f"{tier_name:>8s}: {len(seqs):5d} seqs, {n_tricks:3d} tricks, avg len {avg_len:.1f}")

In [None]:
class NextTrickDataset(Dataset):
    """Next-trick prediction samples built from trick sequences.

    Every position ``j >= 1`` of every sequence yields one sample. The input is
    the (at most) ``seq_len`` tricks before position ``j``, left-padded with
    ``pad_id`` to exactly ``seq_len`` ids. The target is the trick at position
    ``j``. Short and long sequences are treated the same way, so a sequence of
    exactly ``seq_len`` tricks still yields samples. Early positions of long
    sequences are also kept as training targets.

    Args:
        sequences: iterable of sequences of trick names.
        stoi: trick name -> integer id; every trick in ``sequences`` must be a key.
        seq_len: fixed length of each input window (must be >= 1).
        pad_id: id used for left padding (default 0, the same id the inference
            code in this notebook pads with).

    Raises:
        ValueError: if ``seq_len < 1``.
        KeyError: if a trick is missing from ``stoi``.
    """

    def __init__(self, sequences, stoi, seq_len, pad_id=0):
        if seq_len < 1:
            raise ValueError(f"seq_len must be >= 1, got {seq_len}")
        self.samples = []

        for seq in sequences:
            ids = [stoi[t] for t in seq]
            # Sequences with fewer than 2 tricks produce no (history, next) pair
            # because range(1, len(ids)) is empty.
            for j in range(1, len(ids)):
                # Most recent `seq_len` tricks; left padding keeps the newest
                # trick at the last time step, which the model reads.
                history = ids[max(0, j - seq_len):j]
                padding = [pad_id] * (seq_len - len(history))
                self.samples.append((padding + history, ids[j]))

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        """Return (input ids of shape (seq_len,), scalar target id) as long tensors."""
        x, y = self.samples[idx]
        return torch.tensor(x, dtype=torch.long), torch.tensor(y, dtype=torch.long)


class TrickGRU(nn.Module):
    """Embedding -> GRU -> linear head that scores every possible next trick.

    The attribute names ``embed``, ``gru`` and ``fc`` become the state_dict
    keys written to the ``.pt`` checkpoints, so they must stay unchanged.
    """

    def __init__(self, vocab_size, embed_dim, hidden_dim, num_layers):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, embed_dim)
        self.gru = nn.GRU(
            embed_dim,
            hidden_dim,
            num_layers=num_layers,
            batch_first=True,
        )
        self.fc = nn.Linear(hidden_dim, vocab_size)

    def forward(self, x):
        """Map (batch, seq) token ids to (batch, vocab_size) logits.

        Only the GRU output at the final time step feeds the head; with
        left-padded inputs that step corresponds to the most recent trick.
        """
        step_outputs, _final_hidden = self.gru(self.embed(x))
        last_step = step_outputs[:, -1, :]
        return self.fc(last_step)

In [None]:
def top_k_accuracy(logits, targets, k=3):
    """Fraction of rows whose target id is among the k highest-scoring classes.

    Args:
        logits: 2-D tensor (batch, num_classes) of scores.
        targets: 1-D tensor (batch,) of integer class ids.
        k: how many top predictions count as a hit. It is clamped to
            num_classes so that small vocabularies don't make torch.topk raise.

    Returns:
        Python float in [0, 1] (NaN for an empty batch).
    """
    k = min(k, logits.size(1))
    topk = torch.topk(logits, k, dim=1).indices
    return (topk == targets.unsqueeze(1)).any(dim=1).float().mean().item()


def train_epoch(model, loader, criterion, optimizer):
    """Run one optimisation pass over ``loader`` on ``DEVICE``.

    Returns per-sample averages over the whole epoch, as the tuple
    (loss, top-1 accuracy, top-3 accuracy, top-5 accuracy). The metrics are
    taken from the same forward pass that produced each gradient update.
    """
    model.train()
    loss_sum = 0
    hits1 = 0
    hits3 = 0
    hits5 = 0
    seen = 0

    for inputs, labels in loader:
        inputs = inputs.to(DEVICE)
        labels = labels.to(DEVICE)
        batch = inputs.size(0)

        optimizer.zero_grad()
        logits = model(inputs)
        loss = criterion(logits, labels)
        loss.backward()
        optimizer.step()

        # Batch means are re-weighted by batch size so a smaller final batch
        # doesn't skew the epoch average. This assumes `criterion` averages over
        # the batch, as nn.CrossEntropyLoss() does in this notebook.
        loss_sum += loss.item() * batch
        hits1 += (logits.argmax(dim=1) == labels).sum().item()
        hits3 += top_k_accuracy(logits, labels, k=3) * batch
        hits5 += top_k_accuracy(logits, labels, k=5) * batch
        seen += batch

    return loss_sum / seen, hits1 / seen, hits3 / seen, hits5 / seen


def eval_epoch(model, loader, criterion):
    """Evaluate ``model`` on ``loader`` without tracking gradients.

    Returns per-sample averages (loss, top-1, top-3, top-5). If ``loader`` is
    None, a 4-tuple of None is returned instead.
    """
    if loader is None:
        return (None,) * 4

    model.eval()
    loss_sum = 0
    hits1 = 0
    hits3 = 0
    hits5 = 0
    seen = 0

    with torch.no_grad():
        for inputs, labels in loader:
            inputs = inputs.to(DEVICE)
            labels = labels.to(DEVICE)
            batch = inputs.size(0)

            logits = model(inputs)
            loss = criterion(logits, labels)

            # Same batch-size weighting as train_epoch, so train and val
            # metrics are directly comparable.
            loss_sum += loss.item() * batch
            hits1 += (logits.argmax(dim=1) == labels).sum().item()
            hits3 += top_k_accuracy(logits, labels, k=3) * batch
            hits5 += top_k_accuracy(logits, labels, k=5) * batch
            seen += batch

    return loss_sum / seen, hits1 / seen, hits3 / seen, hits5 / seen

In [None]:
# ============================================================
# TRAIN ALL 4 TIER MODELS
# ============================================================

def _split_sequences(sequences, val_fraction):
    """Shuffle the usable sequences and split them into (train, val) lists.

    The split is done on whole sequences, not on individual samples. Each
    sequence yields several overlapping (history, next-trick) samples, so a
    sample-level split would put near-duplicate histories into both sets and
    make the validation loss optimistic. Sequences with fewer than 2 tricks
    are dropped because they yield no samples. When at least 2 usable
    sequences exist, both returned lists are non-empty.
    """
    usable = [seq for seq in sequences if len(seq) >= 2]
    random.shuffle(usable)
    if len(usable) < 2:
        return usable, []
    n_val = min(max(1, int(len(usable) * val_fraction)), len(usable) - 1)
    return usable[n_val:], usable[:n_val]


trained_models = {}

for tier_name, tier_file in TIERS.items():
    print(f"\n{'='*60}")
    print(f"  TIER: {tier_name}")
    print(f"{'='*60}")

    # Load data + build vocab over ALL sequences so every trick gets an id
    sequences = load_sequences(tier_file)
    vocab, stoi, itos = build_vocab(sequences)
    vocab_size = len(vocab)
    print(f"Sequences: {len(sequences)}, Vocab: {vocab_size}")

    # Sequence-level train/val split, then one dataset per split
    train_seqs, val_seqs = _split_sequences(sequences, VAL_SPLIT)
    train_dataset = NextTrickDataset(train_seqs, stoi, SEQ_LEN)
    val_dataset = NextTrickDataset(val_seqs, stoi, SEQ_LEN)
    print(f"Training samples: {len(train_dataset)}, validation samples: {len(val_dataset)}")
    if len(train_dataset) == 0 or len(val_dataset) == 0:
        # train_epoch/eval_epoch divide by the sample count; an empty split would crash.
        print("  Skipping tier: need at least 2 sequences with 2+ tricks")
        continue

    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE)

    # Create model
    model = TrickGRU(vocab_size, EMBED_DIM, HIDDEN_DIM, NUM_LAYERS).to(DEVICE)
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=LR)

    # Track the best validation loss and its weights. Everything is reset per
    # tier so a previous tier's snapshot can never be restored into this model.
    best_val_loss = float("inf")
    best_epoch = 0
    best_state = None

    for epoch in range(1, EPOCHS + 1):
        tr_loss, tr_acc, tr_t3, tr_t5 = train_epoch(model, train_loader, criterion, optimizer)
        va_loss, va_acc, va_t3, va_t5 = eval_epoch(model, val_loader, criterion)

        if va_loss < best_val_loss:
            best_val_loss = va_loss
            best_epoch = epoch
            best_state = {k: v.detach().clone() for k, v in model.state_dict().items()}

        if epoch % 5 == 0 or epoch == 1:
            print(
                f"  Epoch {epoch:02d} | "
                f"train loss {tr_loss:.4f} acc {tr_acc:.3f} t3 {tr_t3:.3f} t5 {tr_t5:.3f} | "
                f"val loss {va_loss:.4f} acc {va_acc:.3f} t3 {va_t3:.3f} t5 {va_t5:.3f}"
            )

    # Restore the best weights. If val loss never improved (e.g. it was NaN
    # every epoch), keep the final weights instead of failing.
    if best_state is not None:
        model.load_state_dict(best_state)
        print(f"  Best val loss {best_val_loss:.4f} at epoch {best_epoch}")
    else:
        print("  Validation loss never improved; keeping final weights")

    # Save checkpoint
    pt_path = f"tricks_gru_{tier_name}.pt"
    torch.save({
        "model_state": model.state_dict(),
        "stoi": stoi,
        "itos": itos,
        "vocab": vocab,
        "seq_len": SEQ_LEN,
        "tier": tier_name,
    }, pt_path)
    print(f"  Saved {pt_path}")

    trained_models[tier_name] = {
        "model": model, "vocab": vocab, "stoi": stoi, "itos": itos,
        "best_val_loss": best_val_loss, "best_epoch": best_epoch,
    }

In [None]:
# ============================================================
# EXPORT ALL MODELS: ONNX + JSON
# ============================================================

def _write_json(path, obj):
    """Write ``obj`` to ``path`` as JSON with 2-space indentation."""
    with open(path, "w") as f:
        json.dump(obj, f, indent=2)


for tier_name, info in trained_models.items():
    model = info["model"]
    vocab = info["vocab"]
    stoi = info["stoi"]
    itos = info["itos"]
    model.eval()

    suffix = tier_name

    # ONNX export traces the model with `dummy`, so the dummy input must live
    # on the same device as the weights (the model may sit on CUDA after training).
    model_device = next(model.parameters()).device
    dummy = torch.randint(0, len(vocab), (1, SEQ_LEN), dtype=torch.long, device=model_device)
    onnx_path = f"tricks_gru_{suffix}.onnx"
    torch.onnx.export(
        model, dummy, onnx_path,
        export_params=True, opset_version=17,
        input_names=["input_ids"], output_names=["logits"],
        dynamic_axes={"input_ids": {0: "batch_size"}, "logits": {0: "batch_size"}},
    )

    # JSON exports. JSON object keys must be strings, so itos ids are stringified.
    _write_json(f"stoi_{suffix}.json", stoi)
    _write_json(f"itos_{suffix}.json", {str(k): v for k, v in itos.items()})
    _write_json(f"vocab_{suffix}.json", vocab)
    _write_json(f"metadata_{suffix}.json", {
        "vocab_size": len(vocab),
        "seq_len": SEQ_LEN,
        "model_type": "GRU",
        "embed_dim": EMBED_DIM,
        "hidden_dim": HIDDEN_DIM,
        "num_layers": NUM_LAYERS,
        "tier": tier_name,
    })

    print(f"Exported {tier_name}: {onnx_path}, stoi/itos/vocab/metadata JSONs")

In [None]:
# ============================================================
# QUICK TEST: predictions per tier
# ============================================================

def predict_topk(model, stoi, itos, trick_history, k=5, seq_len=None):
    """Return the k most likely next tricks after ``trick_history``.

    Args:
        model: module mapping a (1, seq_len) long tensor of ids to
            (1, vocab_size) logits.
        stoi: trick name -> id. Tricks missing from it are silently dropped.
        itos: id -> trick name. Both int keys (as built in this notebook) and
            str keys (as written to itos_*.json) are accepted.
        trick_history: trick names, oldest first.
        k: number of predictions to return; clamped to the vocabulary size.
        seq_len: model input length; defaults to the notebook's SEQ_LEN.

    Returns:
        List of (trick_name, probability) pairs, most likely first. "?" is
        used for any id absent from ``itos``.
    """
    if seq_len is None:
        seq_len = SEQ_LEN
    ids = [stoi[t] for t in trick_history if t in stoi]
    # Keep the most recent `seq_len` tricks and left-pad with id 0, matching
    # the training samples. An empty history becomes all padding.
    ids = ids[-seq_len:]
    ids = [0] * (seq_len - len(ids)) + ids

    # Run on whatever device the model's weights live on.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else DEVICE
    x = torch.tensor([ids], dtype=torch.long, device=device)
    with torch.no_grad():
        probs = F.softmax(model(x), dim=-1).squeeze(0)
    top = torch.topk(probs, min(k, probs.size(-1)))

    results = []
    for prob, idx in zip(top.values.tolist(), top.indices.tolist()):
        name = itos[idx] if idx in itos else itos.get(str(idx), "?")
        results.append((name, prob))
    return results


# Sanity-check each tier's model on a couple of hand-written trick histories.
test_histories = [
    ["S", "RS", "B"],
    ["B", "F", "WB", "WF"],
]

for tier_name in trained_models:
    info = trained_models[tier_name]
    print(f"\n--- {tier_name} ---")
    for hist in test_histories:
        preds = predict_topk(info["model"], info["stoi"], info["itos"], hist)
        print(f"  After {hist}: {preds}")