# Task 3 – Physics-Informed Model for Squared Amplitude

This notebook trains and evaluates a **physics-informed** encoder–decoder model that combines two ideas:

1. **Graph-based encoding** of the Feynman diagram (vertices, external legs, propagators) via a GNN (TransformerConv). The text encoder output cross-attends to the node embeddings and the result is merged back through a gated residual, so the memory seen by the decoder carries both amplitude tokens and diagram topology.

2. **Physics-type token embeddings**: each vocabulary token has a physics type (coupling, mass, Mandelstam, number, regulator, operator, imaginary, other). A learned type embedding is added to the token embedding on both the encoder and decoder sides.
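
To make the physics-type idea concrete, here is a minimal sketch of how tokens could be mapped to type ids. The real mapping lives in `preprocess` and is not shown in this notebook, so the rules, the id order, and the name `guess_token_type` are all hypothetical; only the list of type names comes from the description above.

```python
import re

# Physics types named in this notebook. The integer ids and the rules below
# are hypothetical and may differ from the real mapping in `preprocess`.
TOKEN_TYPES = ["other", "coupling", "mass", "mandelstam", "number",
               "regulator", "operator", "imaginary"]
TYPE_TO_ID = {name: i for i, name in enumerate(TOKEN_TYPES)}

# Integers and fractions, optionally signed and parenthesised: 4, 1/36, (-8)
_NUMBER_RE = re.compile(r"\(?-?\d+(?:/\d+)?\)?")


def guess_token_type(tok: str) -> str:
    """Assign a physics type to one token using simple, assumed rules."""
    if tok == "e":  # QED coupling; a QCD coupling symbol would be added here
        return "coupling"
    if tok == "i":
        return "imaginary"
    if tok.startswith("m_"):
        return "mass"
    if re.fullmatch(r"s_\d+", tok):
        return "mandelstam"
    if tok.startswith("reg_"):
        return "regulator"
    if tok in {"+", "-", "*", "/", "^", "(", ")"}:
        return "operator"
    if _NUMBER_RE.fullmatch(tok):
        return "number"
    return "other"


tokens = ["1/36", "*", "e", "^", "4", "*", "m_s", "^", "2", "+", "s_13", "reg_prop", "i"]
type_ids = [TYPE_TO_ID[guess_token_type(t)] for t in tokens]
```

In the model, a per-vocabulary tensor of such ids is indexed by the token ids, and the resulting type embedding is summed with the token embedding.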

**Usage:** Set `MODEL` and `DATA_DIR` in the configuration cell, then run all cells. Supports **QED** and **QCD** 2-to-2 processes.

---

**Encoding and decoding ideas**

| | Idea |
|---|------|
| **Encoding** | **Dual encoding:** (1) **Text encoder**: a standard Transformer encoder over the amplitude token sequence, with sinusoidal positional encoding. (2) **Graph encoder**: the Feynman diagram is parsed into a PyG graph (vertices, external legs, propagators), and a GNN (TransformerConv layers with edge features) produces node embeddings. The two are fused into a single memory: the text memory cross-attends to the node embeddings and the result is blended back through a learned gate, so the decoder can attend to both the amplitude tokens and the diagram topology. |
| **Decoding** | **Physics-informed decoder:** Autoregressive Transformer decoder that cross-attends to the fused memory. In addition to token embeddings, each token gets a **physics-type embedding** (coupling, mass, Mandelstam, number, regulator, operator, etc.), so the model can treat different symbol types differently. Next-token prediction with cross-entropy loss. |
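
The next-token objective in the decoding row can be illustrated without any framework. The sketch below shows the teacher-forcing shift and a cross-entropy that skips padded positions, mirroring what the training helper does with `ignore_index`. The special-token ids and the function names are assumptions made for this example.

```python
import math

PAD, BOS, EOS = 0, 1, 2  # hypothetical special-token ids


def shift_for_teacher_forcing(tgt):
    """Decoder input drops the last token; the label drops the first."""
    return tgt[:-1], tgt[1:]


def masked_cross_entropy(logits, labels, pad_id=PAD):
    """Mean negative log-likelihood over non-PAD positions."""
    total, count = 0.0, 0
    for row, label in zip(logits, labels):
        if label == pad_id:
            continue  # padded positions do not contribute to the loss
        log_z = math.log(sum(math.exp(v) for v in row))
        total += log_z - row[label]
        count += 1
    return total / max(count, 1)


tgt = [BOS, 5, 7, EOS, PAD]
tgt_in, tgt_out = shift_for_teacher_forcing(tgt)

# Uniform logits over an 8-token vocabulary give a loss of log(8) per position.
uniform_logits = [[0.0] * 8 for _ in tgt_out]
loss = masked_cross_entropy(uniform_logits, tgt_out)
```

At position `t` the decoder sees `tgt_in[: t + 1]` and is scored against `tgt_out[t]`, which is why the two lists are offset by one.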

---

## 1. Imports and path setup

In [1]:
from __future__ import annotations

import math
import os
import re
import time
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_sequence
from torch.optim import AdamW
from torch.optim.lr_scheduler import OneCycleLR
from torch.utils.data import DataLoader, Dataset

try:
    from torch_geometric.nn import TransformerConv
    from torch_geometric.data import Batch as PygBatch, Data as PygData
except ImportError as exc:
    # Chain the original error so the underlying import failure stays visible.
    raise ImportError("torch_geometric is required: pip install torch-geometric") from exc


import sys
if os.getcwd() not in sys.path:
    sys.path.insert(0, os.getcwd())
from preprocess import (
    NUM_TOKEN_TYPES,
    Vocab,
    load_raw_data,
    normalize_indices,
    tokenize_expr,
)

## 2. Diagram parsing to PyG graph

Parse SYMBA diagram text into a graph: **nodes** = vertices ∪ external legs; **edges** = propagators (between vertices) and attachments (vertex–external).
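
As a quick sanity check of the node and edge bookkeeping, the following standard-library sketch applies the same two regular expressions to a toy diagram and counts what the graph would contain. The diagram string is hypothetical: its layout is inferred from the regexes, not copied from a SYMBA file.

```python
import re

# Same patterns as the parsing cell in this section.
VERTEX_RE = re.compile(r"Vertex\s+V_(\d+):(.*?)(?=Vertex\s+V_|\Z)", re.S)
ENTRY_RE = re.compile(
    r"(?P<wrappers>(?:(?:AntiPart|OffShell)\s+)*)"
    r"(?P<particle>[A-Za-z]+)"
    r"\((?P<loc>[^)]+)\)"
)

# Hypothetical text for a two-vertex diagram with a photon propagator.
TOY = (
    "Vertex V_0: e(X_1), AntiPart e(X_2), OffShell A(V_1) "
    "Vertex V_1: mu(X_3), AntiPart mu(X_4), OffShell A(V_0)"
)

vertex_ids, external_ids, offshell = [], set(), {}
attachments = 0
for vm in VERTEX_RE.finditer(TOY):
    vid = int(vm.group(1))
    vertex_ids.append(vid)
    for em in ENTRY_RE.finditer(vm.group(2)):
        wrappers = em.group("wrappers").split()
        loc = em.group("loc").strip()
        if loc.startswith("X_"):
            external_ids.add(loc)
            attachments += 1  # one vertex-external attachment
        if "OffShell" in wrappers:
            offshell.setdefault(em.group("particle"), []).append(vid)

# A propagator is created only when an off-shell type appears at two vertices.
propagators = sum(1 for vids in offshell.values() if len(vids) >= 2)
num_nodes = len(vertex_ids) + len(external_ids)
# Every undirected edge is stored in both directions.
num_directed_edges = 2 * (propagators + attachments)
```

For this toy input the graph has two vertex nodes, four external nodes, one propagator, and four attachments.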

In [2]:
PARTICLE_LIST = ["e", "mu", "tau", "u", "d", "s", "c", "b", "t", "A", "G"]
_PARTICLE_TO_IDX = {p: i for i, p in enumerate(PARTICLE_LIST)}
NUM_PARTICLE_TYPES = len(PARTICLE_LIST) + 1

_VERTEX_SECTION_RE = re.compile(r"Vertex\s+V_(\d+):(.*?)(?=Vertex\s+V_|\Z)", re.S)
_ENTRY_RE = re.compile(
    r"(?P<wrappers>(?:(?:AntiPart|OffShell)\s+)*)"
    r"(?P<particle>[A-Za-z]+)"
    r"\((?P<loc>[^)]+)\)"
)


# Parse diagram text into vertices, externals, offshell propagators.
def _parse_diagram(text: str):
    vertices = {}
    externals = {}
    offshell_by_type = {}
    for m in _VERTEX_SECTION_RE.finditer(text):
        vid = int(m.group(1))
        body = m.group(2)
        entries = []
        for em in _ENTRY_RE.finditer(body):
            wrappers = em.group("wrappers").split()
            particle = em.group("particle")
            loc = em.group("loc").strip()
            is_anti = "AntiPart" in wrappers
            is_offshell = "OffShell" in wrappers
            entry = dict(particle=particle, loc=loc, is_anti=is_anti, vid=vid)
            entries.append(entry)
            if loc.startswith("X_"):
                externals[loc] = dict(particle=particle, is_anti=is_anti)
            if is_offshell:
                offshell_by_type.setdefault(particle, []).append(vid)
        vertices[vid] = entries
    return vertices, externals, offshell_by_type


# One-hot particle type vector.
def _particle_onehot(name: str):
    idx = _PARTICLE_TO_IDX.get(name, len(PARTICLE_LIST))
    vec = [0.0] * NUM_PARTICLE_TYPES
    vec[idx] = 1.0
    return vec


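def diagram_to_graph(text: str) -> PygData:
    """Convert diagram text to a PyG graph.

    Node features: [is_vertex, is_anti] + particle one-hot (all zeros for vertices).
    Edge features: [is_propagator] + particle one-hot; each edge is stored in both directions.
    """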
    vertices, externals, offshell_by_type = _parse_diagram(text)
    node_ids = {}
    node_feats = []
    for vid in sorted(vertices):
        node_ids[f"V_{vid}"] = len(node_ids)
        node_feats.append([1.0, 0.0] + [0.0] * NUM_PARTICLE_TYPES)
    for xkey in sorted(externals):
        ext = externals[xkey]
        node_ids[xkey] = len(node_ids)
        feat = [0.0, 1.0 if ext["is_anti"] else 0.0] + _particle_onehot(ext["particle"])
        node_feats.append(feat)
    edge_src, edge_dst, edge_attr = [], [], []
    for ptype, vids in offshell_by_type.items():
        if len(vids) >= 2:
            a, b = node_ids[f"V_{vids[0]}"], node_ids[f"V_{vids[1]}"]
            prop_feat = [1.0] + _particle_onehot(ptype)
            edge_src += [a, b]; edge_dst += [b, a]; edge_attr += [prop_feat, prop_feat]
    for vid, entries in vertices.items():
        v_node = node_ids[f"V_{vid}"]
        for ent in entries:
            if ent["loc"].startswith("X_") and ent["loc"] in node_ids:
                x_node = node_ids[ent["loc"]]
                att_feat = [0.0] + _particle_onehot(ent["particle"])
                edge_src += [v_node, x_node]; edge_dst += [x_node, v_node]; edge_attr += [att_feat, att_feat]
    if not edge_src:
        edge_index = torch.zeros(2, 0, dtype=torch.long)
        edge_attr_t = torch.zeros(0, 1 + NUM_PARTICLE_TYPES)
    else:
        edge_index = torch.tensor([edge_src, edge_dst], dtype=torch.long)
        edge_attr_t = torch.tensor(edge_attr, dtype=torch.float)
    x = torch.tensor(node_feats, dtype=torch.float)
    return PygData(x=x, edge_index=edge_index, edge_attr=edge_attr_t)

## 3. Dataset and collate

`GraphSeq2SeqDataset` yields `(graph, src_ids, tgt_ids)` per sample. `graph_collate` batches graphs with PyG and pads sequences.
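
The sequence side of the dataset and collate logic reduces to two small operations: truncating to `max_len` while forcing a final EOS, and right-padding to the longest sequence in the batch. The sketch below reproduces both on plain lists. The helper names and special-token ids are assumptions for illustration; the notebook itself uses `pad_sequence` on tensors.

```python
def truncate_with_eos(ids, max_len, eos_id):
    """Cut a sequence to max_len and force the last token to be EOS."""
    ids = list(ids[:max_len])
    if ids and ids[-1] != eos_id:
        ids[-1] = eos_id
    return ids


def pad_batch(seqs, pad_id=0):
    """Right-pad every sequence to the length of the longest one."""
    width = max(len(s) for s in seqs)
    return [list(s) + [pad_id] * (width - len(s)) for s in seqs]


BOS, EOS, PAD = 1, 2, 0  # hypothetical special-token ids
long_seq = [BOS, 10, 11, 12, 13, EOS]
short_seq = [BOS, 20, EOS]

batch = pad_batch([truncate_with_eos(long_seq, 4, EOS), short_seq], PAD)
```

Note that truncation overwrites the last kept token with EOS, so a truncated target is no longer the full expression; exact-match accuracy on such samples is measured against the truncated form.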

In [3]:
# Yields (graph, src_ids, tgt_ids) per sample.
class GraphSeq2SeqDataset(Dataset):
    def __init__(self, records, vocab: Vocab, max_len: Optional[int] = None):
        self.samples = []
        for rec in records:
            graph = diagram_to_graph(rec["diagram"])
            src_toks = tokenize_expr(normalize_indices(rec["amplitude"]))
            tgt_toks = tokenize_expr(normalize_indices(rec["squared_amplitude"]))
            src_ids = [vocab.bos_id] + vocab.encode(src_toks) + [vocab.eos_id]
            tgt_ids = [vocab.bos_id] + vocab.encode(tgt_toks) + [vocab.eos_id]
            if max_len is not None:
                src_ids = src_ids[:max_len]
                tgt_ids = tgt_ids[:max_len]
                if src_ids[-1] != vocab.eos_id:
                    src_ids[-1] = vocab.eos_id
                if tgt_ids[-1] != vocab.eos_id:
                    tgt_ids[-1] = vocab.eos_id
            self.samples.append((
                graph,
                torch.tensor(src_ids, dtype=torch.long),
                torch.tensor(tgt_ids, dtype=torch.long),
            ))

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        return self.samples[idx]


def graph_collate(batch, pad_id=0):
    graphs, srcs, tgts = zip(*batch)
    graph_batch = PygBatch.from_data_list(list(graphs))
    src_pad = pad_sequence(srcs, batch_first=True, padding_value=pad_id)
    tgt_pad = pad_sequence(tgts, batch_first=True, padding_value=pad_id)
    return graph_batch, src_pad, tgt_pad

## 4. GNN encoder for diagrams

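Three `TransformerConv` layers produce per-node embeddings of size `d_model`. The first two use **residual connections**; the third changes the width from `hidden` to `out`, so its residual path goes through a linear skip projection (an identity when the two widths match).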
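
Written out as equations, the forward pass of the GNN in this section reads as follows. The symbols are shorthand introduced here: $x$ are the raw node features, $E$ stands for the edge index together with the edge features, $\mathrm{LN}_k$ are the four layer norms, and $P$ is the skip projection.

$$
\begin{aligned}
h^{(0)} &= \mathrm{GELU}\big(\mathrm{LN}_0(W_{\mathrm{in}}\,x + b_{\mathrm{in}})\big) \\
h^{(1)} &= h^{(0)} + \mathrm{GELU}\big(\mathrm{LN}_1(\mathrm{Conv}_1(h^{(0)}, E))\big) \\
h^{(2)} &= h^{(1)} + \mathrm{GELU}\big(\mathrm{LN}_2(\mathrm{Conv}_2(h^{(1)}, E))\big) \\
z &= \mathrm{GELU}\big(\mathrm{LN}_3(\mathrm{Conv}_3(h^{(2)}, E))\big) + P\,h^{(2)}
\end{aligned}
$$

Here $z$ is the per-node output of width `out`, which the model sets to `d_model`.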

In [4]:
# GNN: 3× TransformerConv with residual connections.
class DiagramGNN(nn.Module):
    def __init__(self, node_in: int, edge_in: int, hidden: int, out: int):
        super().__init__()
        self.input_proj = nn.Linear(node_in, hidden)
        self.conv1 = TransformerConv(hidden, hidden, edge_dim=edge_in)
        self.conv2 = TransformerConv(hidden, hidden, edge_dim=edge_in)
        self.conv3 = TransformerConv(hidden, out, edge_dim=edge_in)
        self.norm0 = nn.LayerNorm(hidden)
        self.norm1 = nn.LayerNorm(hidden)
        self.norm2 = nn.LayerNorm(hidden)
        self.norm3 = nn.LayerNorm(out)
        self.skip_proj = nn.Linear(hidden, out) if hidden != out else nn.Identity()

    def forward(self, x, edge_index, edge_attr):
        h = F.gelu(self.norm0(self.input_proj(x)))
        h = h + F.gelu(self.norm1(self.conv1(h, edge_index, edge_attr)))
        h = h + F.gelu(self.norm2(self.conv2(h, edge_index, edge_attr)))
        out = F.gelu(self.norm3(self.conv3(h, edge_index, edge_attr)))
        return out + self.skip_proj(h)

## 5. Physics-informed Transformer

This cell defines a sinusoidal **PositionalEncoding** and the **PhysicsTransformer**. The text encoder and graph encoder outputs are combined into a **cross-attention fused** memory (text attends to graph, followed by a gated residual), and the sum of token and physics-type embeddings is used on **both** the encoder and decoder sides.
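
The fusion step implemented by `_fuse_memory` can be summarised in equations. The notation is introduced here for readability: $T$ is the text memory, $G$ the padded node embeddings, $[\,\cdot\,;\,\cdot\,]$ concatenation along the feature axis, and $\sigma$ the sigmoid.

$$
\begin{aligned}
A &= \mathrm{MHA}(Q = T,\; K = G,\; V = G) \\
F_1 &= \mathrm{LN}(T + A) \\
F_2 &= \mathrm{LN}\big(F_1 + \mathrm{FFN}(F_1)\big) \\
g &= \sigma\big(W_g\,[T \,;\, F_2] + b_g\big) \\
M &= T + g \odot (F_2 - T) = (1 - g) \odot T + g \odot F_2
\end{aligned}
$$

Because the queries come from $T$, the fused memory $M$ has the same length as the source sequence, which is why the decoder reuses the source padding mask as its memory mask.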

In [5]:
# Sinusoidal positional encoding.
class PositionalEncoding(nn.Module):
    def __init__(self, d_model: int, max_len: int = 4096, dropout: float = 0.1):
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        pe = torch.zeros(max_len, d_model)
        pos = torch.arange(max_len).unsqueeze(1).float()
        div = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(pos * div)
        pe[:, 1::2] = torch.cos(pos * div)
        self.register_buffer("pe", pe.unsqueeze(0))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.dropout(x + self.pe[:, : x.size(1)])


class PhysicsTransformer(nn.Module):
    """
    Encoder-decoder model with a fused graph + text memory and physics-type embeddings.
    The text memory attends to the graph node embeddings via cross-attention,
    and the result is blended back into the text memory through a learned gate.
    """
    def __init__(
        self,
        vocab_size: int,
        d_model: int = 256,
        nhead: int = 8,
        num_enc_layers: int = 3,
        num_dec_layers: int = 3,
        dim_ff: int = 1024,
        dropout: float = 0.1,
        pad_id: int = 0,
        node_feat_dim: int = 14,
        edge_feat_dim: int = 13,
        gnn_hidden: int = 128,
        num_token_types: int = NUM_TOKEN_TYPES,
        type_ids: Optional[torch.Tensor] = None,
    ):
        super().__init__()
        self.d_model = d_model
        self.pad_id = pad_id
        self.embed = nn.Embedding(vocab_size, d_model, padding_idx=pad_id)
        self.type_embed = nn.Embedding(num_token_types, d_model)
        self.pos_enc = PositionalEncoding(d_model, dropout=dropout)
        self.register_buffer("type_ids", type_ids if type_ids is not None else torch.zeros(vocab_size, dtype=torch.long))
        enc_layer = nn.TransformerEncoderLayer(d_model, nhead, dim_ff, dropout, batch_first=True)
        self.text_encoder = nn.TransformerEncoder(enc_layer, num_enc_layers)
        self.gnn = DiagramGNN(node_feat_dim, edge_feat_dim, gnn_hidden, d_model)
        self.gnn_proj = nn.Linear(d_model, d_model)
        # Cross-attention fusion: text attends to graph
        self.graph_cross_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=True)
        self.cross_attn_norm = nn.LayerNorm(d_model)
        self.cross_attn_ff = nn.Sequential(
            nn.Linear(d_model, dim_ff), nn.GELU(), nn.Dropout(dropout),
            nn.Linear(dim_ff, d_model), nn.Dropout(dropout),
        )
        self.ff_norm = nn.LayerNorm(d_model)
        self.fusion_gate = nn.Sequential(nn.Linear(d_model * 2, d_model), nn.Sigmoid())
        dec_layer = nn.TransformerDecoderLayer(d_model, nhead, dim_ff, dropout, batch_first=True)
        self.decoder = nn.TransformerDecoder(dec_layer, num_dec_layers)
        self.out_proj = nn.Linear(d_model, vocab_size)

    def _embed_src(self, ids: torch.Tensor) -> torch.Tensor:
        tok_emb = self.embed(ids) * math.sqrt(self.d_model)
        typ_emb = self.type_embed(self.type_ids[ids])
        return self.pos_enc(tok_emb + typ_emb)

    def _embed_tgt(self, ids: torch.Tensor) -> torch.Tensor:
        tok_emb = self.embed(ids) * math.sqrt(self.d_model)
        typ_emb = self.type_embed(self.type_ids[ids])
        return self.pos_enc(tok_emb + typ_emb)

    @staticmethod
    def _causal_mask(sz: int, device: torch.device) -> torch.Tensor:
        return torch.triu(torch.full((sz, sz), float("-inf"), device=device), diagonal=1)

    def _encode_graphs(self, graph_batch, batch_size: int, device: torch.device):
        node_h = self.gnn(graph_batch.x.to(device), graph_batch.edge_index.to(device), graph_batch.edge_attr.to(device))
        node_h = self.gnn_proj(node_h)
        batch_vec = graph_batch.batch.to(device)
        counts = torch.zeros(batch_size, dtype=torch.long, device=device)
        counts.scatter_add_(0, batch_vec, torch.ones_like(batch_vec))
        max_nodes = int(counts.max().item())
        padded = torch.zeros(batch_size, max_nodes, self.d_model, device=device)
        mask = torch.ones(batch_size, max_nodes, dtype=torch.bool, device=device)
        for b in range(batch_size):
            sel = (batch_vec == b)
            n = int(sel.sum().item())
            padded[b, :n] = node_h[sel]
            mask[b, :n] = False
        return padded, mask

    def _fuse_memory(self, text_mem, graph_mem, graph_pad_mask):
        """Cross-attention fusion with gated residual."""
        attn_out, _ = self.graph_cross_attn(text_mem, graph_mem, graph_mem, key_padding_mask=graph_pad_mask)
        fused = self.cross_attn_norm(text_mem + attn_out)
        fused = self.ff_norm(fused + self.cross_attn_ff(fused))
        gate = self.fusion_gate(torch.cat([text_mem, fused], dim=-1))
        return text_mem + gate * (fused - text_mem)

    def forward(self, graph_batch, src, tgt_in):
        B, device = src.size(0), src.device
        src_pad = src == self.pad_id
        tgt_pad = tgt_in == self.pad_id
        tgt_mask = self._causal_mask(tgt_in.size(1), device)
        text_mem = self.text_encoder(self._embed_src(src), src_key_padding_mask=src_pad)
        graph_mem, graph_pad_mask = self._encode_graphs(graph_batch, B, device)
        memory = self._fuse_memory(text_mem, graph_mem, graph_pad_mask)
        dec_out = self.decoder(self._embed_tgt(tgt_in), memory, tgt_mask=tgt_mask, tgt_key_padding_mask=tgt_pad, memory_key_padding_mask=src_pad)
        return self.out_proj(dec_out)

    @torch.no_grad()
    def generate(self, graph_batch, src, bos_id, eos_id, max_len=400):
        self.eval()
        B, device = src.size(0), src.device
        src_pad = src == self.pad_id
        text_mem = self.text_encoder(self._embed_src(src), src_key_padding_mask=src_pad)
        graph_mem, graph_pad_mask = self._encode_graphs(graph_batch, B, device)
        memory = self._fuse_memory(text_mem, graph_mem, graph_pad_mask)
        ys = torch.full((B, 1), bos_id, dtype=torch.long, device=device)
        finished = torch.zeros(B, dtype=torch.bool, device=device)
        for _ in range(max_len):
            tgt_mask = self._causal_mask(ys.size(1), device)
            tgt_pad = ys == self.pad_id
            out = self.decoder(self._embed_tgt(ys), memory, tgt_mask=tgt_mask, tgt_key_padding_mask=tgt_pad, memory_key_padding_mask=src_pad)
            nxt = self.out_proj(out[:, -1, :]).argmax(dim=-1)
            ys = torch.cat([ys, nxt.unsqueeze(1)], dim=1)
            finished |= nxt == eos_id
            if finished.all():
                break
        return ys

## 6. Training and evaluation helpers
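
The two metrics reported by `evaluate` are easy to misread, so here is a minimal standalone version of the same scoring rule on plain id lists. The function name is hypothetical; the arithmetic follows the evaluation loop in this section.

```python
def sequence_and_token_accuracy(pairs):
    """Score (pred, true) id lists the way the evaluation loop does."""
    seq_correct = tok_correct = tok_total = 0
    for pred, true in pairs:
        seq_correct += int(pred == true)
        # zip stops at the shorter list, so surplus tokens never count as matches
        tok_correct += sum(p == t for p, t in zip(pred, true))
        # the longer length is the denominator, which penalises length errors
        tok_total += max(len(pred), len(true))
    n = len(pairs)
    return seq_correct / max(n, 1), tok_correct / max(tok_total, 1)


pairs = [
    ([5, 6, 7], [5, 6, 7]),     # exact match: 3 of 3 tokens
    ([5, 6, 9], [5, 6, 7, 8]),  # 2 matches out of max(3, 4) = 4
]
seq_acc, tok_acc = sequence_and_token_accuracy(pairs)
```

Token accuracy is position-wise, so a single inserted or dropped token early in a prediction shifts every later position and can lower the score sharply even when most symbols are right.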

In [6]:
# Strip BOS/EOS/PAD from id list.
def _strip(ids, bos, eos, pad):
    out = []
    for i in ids:
        if i in (bos, pad):
            continue
        if i == eos:
            break
        out.append(i)
    return out


def train_one_epoch(model, loader, optimiser, scheduler, device, pad_id, max_len=None, label_smoothing=0.0, accum_steps=1):
    model.train()
    total_loss, n = 0.0, 0
    optimiser.zero_grad()
    for i, (graph_batch, src, tgt) in enumerate(loader):
        src, tgt = src.to(device), tgt.to(device)
        if max_len is not None:
            src, tgt = src[:, :max_len], tgt[:, :max_len]
        tgt_in, tgt_out = tgt[:, :-1], tgt[:, 1:]
        logits = model(graph_batch, src, tgt_in)
        loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)), tgt_out.reshape(-1), ignore_index=pad_id, label_smoothing=label_smoothing) / accum_steps
        loss.backward()
        if (i + 1) % accum_steps == 0 or (i + 1) == len(loader):
            nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimiser.step()
            if scheduler is not None:
                scheduler.step()
            optimiser.zero_grad()
        total_loss += loss.item() * accum_steps
        n += 1
    return total_loss / max(n, 1)


# Eval: sequence and token accuracy; returns metrics and sample predictions.
@torch.no_grad()
def evaluate(model, loader, device, vocab, max_gen=400, max_src_len=None):
    model.eval()
    seq_correct = tok_correct = tok_total = total = 0
    examples = []
    for graph_batch, src, tgt in loader:
        src = src.to(device)
        if max_src_len is not None:
            src = src[:, :max_src_len]
        preds = model.generate(graph_batch, src, vocab.bos_id, vocab.eos_id, max_gen)
        for i in range(src.size(0)):
            pred_ids = _strip(preds[i].tolist(), vocab.bos_id, vocab.eos_id, vocab.pad_id)
            true_ids = _strip(tgt[i].tolist(), vocab.bos_id, vocab.eos_id, vocab.pad_id)
            if pred_ids == true_ids:
                seq_correct += 1
            mn = min(len(pred_ids), len(true_ids))
            tok_correct += sum(p == t for p, t in zip(pred_ids[:mn], true_ids[:mn]))
            tok_total += max(len(pred_ids), len(true_ids))
            total += 1
            if len(examples) < 5:
                examples.append({"pred": " ".join(vocab.decode(pred_ids)), "true": " ".join(vocab.decode(true_ids))})
    return {"seq_acc": seq_correct / max(total, 1), "tok_acc": tok_correct / max(tok_total, 1), "n": total, "examples": examples}


# 80-10-10 split, vocab, graph+seq datasets and loaders.
def build_data_task3(data_dir, model_prefix, seed=42, batch_size=16, max_len=None):
    import random
    records = load_raw_data(data_dir, model_prefix)
    if not records:
        raise RuntimeError(f"No data for {model_prefix}")
    random.seed(seed)
    random.shuffle(records)
    n = len(records)
    n_train, n_val = int(0.8 * n), int(0.1 * n)
    train_recs = records[:n_train]
    val_recs = records[n_train : n_train + n_val]
    test_recs = records[n_train + n_val :]
    all_toks = []
    for rec in train_recs:
        all_toks.append(tokenize_expr(normalize_indices(rec["amplitude"])))
        all_toks.append(tokenize_expr(normalize_indices(rec["squared_amplitude"])))
    vocab = Vocab(all_toks)
    train_ds = GraphSeq2SeqDataset(train_recs, vocab, max_len)
    val_ds = GraphSeq2SeqDataset(val_recs, vocab, max_len)
    test_ds = GraphSeq2SeqDataset(test_recs, vocab, max_len)
    collate = lambda b: graph_collate(b, vocab.pad_id)
    kw = dict(num_workers=0, pin_memory=True, collate_fn=collate)
    train_ld = DataLoader(train_ds, batch_size=batch_size, shuffle=True, **kw)
    val_ld = DataLoader(val_ds, batch_size=batch_size, shuffle=False, **kw)
    test_ld = DataLoader(test_ds, batch_size=batch_size, shuffle=False, **kw)
    print(f"[task3] {model_prefix}:  {len(train_recs)} train / {len(val_recs)} val / {len(test_recs)} test  | vocab {len(vocab)}")
    return train_ld, val_ld, test_ld, vocab, test_recs

## 7. Configuration

Set `MODEL` to `"QED"` or `"QCD"`, and `DATA_DIR` to the folder containing the SYMBA test `.txt` files. Run from the `gsoc_tasks` directory so `preprocess` can be imported.

In [7]:
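MODEL = "QED"  # "QED" or "QCD"
DATA_DIR = "SYMBA - Test Data"
EPOCHS = 200
BATCH_SIZE = None  # None -> 16 for QED, 2 for QCD
LR = 3e-4
D_MODEL = 256
NHEAD = 8
NUM_LAYERS = 3
DIM_FF = 1024
DROPOUT = 0.1
LABEL_SMOOTHING = 0.0
ACCUM_STEPS = 1
MAX_SEQ_LEN = None  # None -> 300 for QED, 1500 for QCD
SEED = 42
OUT_DIR = None  # None -> results_task3_{MODEL}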

# Evaluate/save best checkpoint more frequently so we don't miss the peak.
EVAL_EVERY = 5
EVAL_FROM_EPOCH = 10

if BATCH_SIZE is None:
    BATCH_SIZE = 16 if MODEL == "QED" else 2
if MAX_SEQ_LEN is None:
    MAX_SEQ_LEN = 300 if MODEL == "QED" else 1500
if OUT_DIR is None:
    OUT_DIR = f"results_task3_{MODEL}"
os.makedirs(OUT_DIR, exist_ok=True)

# Reproducibility (helps when chasing 100% seq accuracy)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Device: {device}")

Device: cuda


## 8. Build data and create model
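
The split sizes and the scheduler length follow from simple arithmetic, sketched below with the QED defaults from the configuration cell. The record count of 360 is an inference from the 288 / 36 / 36 split printed by this section's cell, and the helper names are hypothetical.

```python
import math


def split_sizes(n, train_frac=0.8, val_frac=0.1):
    """Train/val sizes are floored; the test split takes the remainder."""
    n_train, n_val = int(train_frac * n), int(val_frac * n)
    return n_train, n_val, n - n_train - n_val


def optimiser_steps(n_train, batch_size, accum_steps, epochs):
    """Number of optimiser (and scheduler) steps over the whole run."""
    batches_per_epoch = math.ceil(n_train / batch_size)  # the last partial batch is kept
    steps_per_epoch = math.ceil(batches_per_epoch / accum_steps)
    return steps_per_epoch, epochs * steps_per_epoch


# 360 records is inferred from the printed split, not read from the data.
n_train, n_val, n_test = split_sizes(360)
steps_per_epoch, total_steps = optimiser_steps(
    n_train, batch_size=16, accum_steps=1, epochs=200
)
warmup_steps = round(0.05 * total_steps)  # pct_start=0.05 in OneCycleLR
```

With a validation set of 36 samples, each sequence is worth about 2.8 points of sequence accuracy, which helps explain the large swings between evaluations in the training log.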

In [8]:
train_ld, val_ld, test_ld, vocab, _ = build_data_task3(DATA_DIR, MODEL, SEED, BATCH_SIZE, MAX_SEQ_LEN)

node_feat_dim = 2 + NUM_PARTICLE_TYPES
edge_feat_dim = 1 + NUM_PARTICLE_TYPES

model = PhysicsTransformer(
    vocab_size=len(vocab),
    d_model=D_MODEL,
    nhead=NHEAD,
    num_enc_layers=NUM_LAYERS,
    num_dec_layers=NUM_LAYERS,
    dim_ff=DIM_FF,
    dropout=DROPOUT,
    pad_id=vocab.pad_id,
    node_feat_dim=node_feat_dim,
    edge_feat_dim=edge_feat_dim,
    gnn_hidden=128,
    type_ids=vocab.type_ids,
).to(device)

n_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Model parameters: {n_params:,}")

optimiser = AdamW(model.parameters(), lr=LR, weight_decay=1e-2)
steps_per_epoch = math.ceil(len(train_ld) / ACCUM_STEPS)
total_steps = EPOCHS * steps_per_epoch
scheduler = OneCycleLR(optimiser, max_lr=LR, total_steps=total_steps, pct_start=0.05)

best_val_seq = -1.0
best_val_tok = -1.0
best_path = os.path.join(OUT_DIR, "best.pt")

[task3] QED:  288 train / 36 val / 36 test  | vocab 181
Model parameters: 6,918,709


## 9. Training loop

In [9]:
print(f"\nTraining {MODEL} (physics-informed) for {EPOCHS} epochs …\n")
t0 = time.time()

for epoch in range(1, EPOCHS + 1):
    loss = train_one_epoch(
        model, train_ld, optimiser, scheduler, device,
        vocab.pad_id, MAX_SEQ_LEN,
        label_smoothing=LABEL_SMOOTHING,
        accum_steps=ACCUM_STEPS,
    )

    do_eval = (epoch >= EVAL_FROM_EPOCH) and (epoch % EVAL_EVERY == 0 or epoch == EPOCHS)
    if do_eval:
        metrics = evaluate(model, val_ld, device, vocab, max_gen=MAX_SEQ_LEN, max_src_len=MAX_SEQ_LEN)
        elapsed = time.time() - t0
        print(
            f"Epoch {epoch:4d} | loss {loss:.4f} | "
            f"val seq_acc {metrics['seq_acc']:.4f} | val tok_acc {metrics['tok_acc']:.4f} | {elapsed:.0f}s"
        )

        # Save best checkpoint (tie-break on token accuracy).
        if (metrics["seq_acc"] > best_val_seq) or (
            metrics["seq_acc"] == best_val_seq and metrics["tok_acc"] > best_val_tok
        ):
            best_val_seq = metrics["seq_acc"]
            best_val_tok = metrics["tok_acc"]
            torch.save(model.state_dict(), best_path)
    elif epoch % 5 == 0 or epoch == 1:
        print(f"Epoch {epoch:4d} | loss {loss:.4f} | {time.time() - t0:.0f}s")


Training QED (physics-informed) for 200 epochs …



Epoch    1 | loss 4.7985 | 2s
Epoch    5 | loss 1.4683 | 9s


Epoch   10 | loss 0.4976 | val seq_acc 0.0000 | val tok_acc 0.1573 | 25s
Epoch   15 | loss 0.1944 | val seq_acc 0.0556 | val tok_acc 0.6748 | 35s
Epoch   20 | loss 0.0893 | val seq_acc 0.3056 | val tok_acc 0.8500 | 44s
Epoch   25 | loss 0.0536 | val seq_acc 0.3056 | val tok_acc 0.8022 | 53s
Epoch   30 | loss 0.0345 | val seq_acc 0.6389 | val tok_acc 0.8785 | 63s
Epoch   35 | loss 0.0277 | val seq_acc 0.4167 | val tok_acc 0.9182 | 72s
Epoch   40 | loss 0.0180 | val seq_acc 0.8056 | val tok_acc 0.9802 | 82s
Epoch   45 | loss 0.0141 | val seq_acc 0.7222 | val tok_acc 0.9652 | 92s
Epoch   50 | loss 0.0127 | val seq_acc 0.8333 | val tok_acc 0.9849 | 102s
Epoch   55 | loss 0.0098 | val seq_acc 0.7778 | val tok_acc 0.9629 | 113s
Epoch   60 | loss 0.0077 | val seq_acc 0.8611 | val tok_acc 0.9842 | 123s
Epoch   65 | loss 0.0090 | val seq_acc 0.4722 | val tok_acc 0.9314 | 134s
Epoch   70 | loss 0.0066 | val seq_acc 0.4722 | val tok_acc 0.9424 | 144s
Epoch   75 | loss 0.0043 | val seq_acc 0.8611 

## 10. Test evaluation

In [10]:
print("Loading best checkpoint and evaluating on test set …")
model.load_state_dict(torch.load(best_path, weights_only=True))

test_metrics = evaluate(model, test_ld, device, vocab, max_gen=MAX_SEQ_LEN, max_src_len=MAX_SEQ_LEN)
print(f"\n  Test Sequence Accuracy : {test_metrics['seq_acc']:.4f}")
print(f"  Test Token Accuracy    : {test_metrics['tok_acc']:.4f}")
print(f"  Test examples          : {test_metrics['n']}")
print("\nSample predictions:")
for i, ex in enumerate(test_metrics["examples"][:3]):
    print(f"\n  [{i}] TRUE : {ex['true'][:120]}…")
    print(f"      PRED : {ex['pred'][:120]}…")
print(f"\nResults saved to {OUT_DIR}")

Loading best checkpoint and evaluating on test set …

  Test Sequence Accuracy : 1.0000
  Test Token Accuracy    : 1.0000
  Test examples          : 36

Sample predictions:

  [0] TRUE : 1/36 * e ^ 4 * ( 16 * m_s ^ 2 * m_mu ^ 2 + (-8) * m_s ^ 2 * s_13 + 8 * s_14 * s_23 + (-8) * m_mu ^ 2 * s_24 + 8 * s_12 *…
      PRED : 1/36 * e ^ 4 * ( 16 * m_s ^ 2 * m_mu ^ 2 + (-8) * m_s ^ 2 * s_13 + 8 * s_14 * s_23 + (-8) * m_mu ^ 2 * s_24 + 8 * s_12 *…

  [1] TRUE : 2/81 * e ^ 4 * s_14 * s_34 * ( s_12 + 1/2 * reg_prop ) ^ (-2) + -4/81 * i * e ^ 2 * ( i * e ^ 2 * m_b ^ 2 * ( m_b ^ 2 + …
      PRED : 2/81 * e ^ 4 * s_14 * s_34 * ( s_12 + 1/2 * reg_prop ) ^ (-2) + -4/81 * i * e ^ 2 * ( i * e ^ 2 * m_b ^ 2 * ( m_b ^ 2 + …

  [2] TRUE : 4/81 * e ^ 4 * ( 16 * m_c ^ 2 * m_u ^ 2 + (-8) * m_c ^ 2 * s_13 + 8 * s_14 * s_23 + (-8) * m_u ^ 2 * s_24 + 8 * s_12 * s…
      PRED : 4/81 * e ^ 4 * ( 16 * m_c ^ 2 * m_u ^ 2 + (-8) * m_c ^ 2 * s_13 + 8 * s_14 * s_23 + (-8) * m_u ^ 2 * s_24 + 8 * s_12 * s…

Results saved