# Task 2 – Vanilla Transformer for Squared Amplitude

Train a standard **encoder-decoder Transformer** (next-token prediction) to map **amplitude** token sequences to **squared-amplitude** token sequences.

1. Build an 80-10-10 train/validation/test split and a shared vocabulary (via `preprocess.build_data`).
2. Train with teacher forcing and cross-entropy loss (optional label smoothing and gradient accumulation).
3. Evaluate with **sequence accuracy** (exact match) and **token accuracy**.

Set `MODEL` and `DATA_DIR` in the config cell, then run all cells. Both **QED** and **QCD** 2-to-2 datasets are supported.

## 1. Imports and path

In [13]:
from __future__ import annotations

import math
import os
import time
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import AdamW
from torch.optim.lr_scheduler import OneCycleLR

import sys
if os.getcwd() not in sys.path:
    sys.path.insert(0, os.getcwd())
from preprocess import Vocab, build_data

## 2. Positional encoding and Seq2Seq Transformer

In [14]:
# Sinusoidal positional encoding.
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position information to batch-first embeddings.

    The table is precomputed once for ``max_len`` positions and stored as the
    buffer ``pe`` with shape ``(1, max_len, d_model)``.

    Args:
        d_model: Width of the embedding (last dimension of the input). Both
            even and odd widths are supported.
        max_len: Number of positions held in the precomputed table.
        dropout: Dropout probability applied after the encoding is added.
    """

    def __init__(self, d_model: int, max_len: int = 4096, dropout: float = 0.1):
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        pe = torch.zeros(max_len, d_model)
        pos = torch.arange(max_len).unsqueeze(1).float()
        div = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        angles = pos * div
        pe[:, 0::2] = torch.sin(angles)
        # With an odd d_model there is one fewer cosine column than sine
        # columns, so drop the surplus frequency instead of failing to assign.
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])
        self.register_buffer("pe", pe.unsqueeze(0))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``dropout(x + pe[:, :seq_len])`` for ``x`` of shape (batch, seq_len, d_model).

        Raises:
            RuntimeError: If ``seq_len`` exceeds the size of the precomputed table.
        """
        seq_len = x.size(1)
        table_len = self.pe.size(1)
        if seq_len > table_len:
            # Same exception type as the broadcast failure this replaces, but
            # with a message that points at the actual cause.
            raise RuntimeError(
                f"sequence length {seq_len} exceeds positional-encoding max_len {table_len}"
            )
        return self.dropout(x + self.pe[:, :seq_len])


# Encoder-decoder transformer (amplitude to squared-amplitude).
class Seq2SeqTransformer(nn.Module):
    """Encoder-decoder Transformer over a single vocabulary.

    Source and target tokens go through the same embedding table and the same
    positional encoding. Positions equal to ``pad_id`` are excluded from
    attention through key-padding masks.

    Args:
        vocab_size: Number of token ids (embedding rows and output logits).
        d_model: Embedding / model width.
        nhead: Number of attention heads.
        num_enc_layers: Number of encoder layers.
        num_dec_layers: Number of decoder layers.
        dim_ff: Hidden width of the feed-forward sublayers.
        dropout: Dropout probability used by the positional encoding and the transformer.
        pad_id: Token id treated as padding.
    """

    def __init__(self, vocab_size: int, d_model: int = 256, nhead: int = 8, num_enc_layers: int = 3,
                 num_dec_layers: int = 3, dim_ff: int = 1024, dropout: float = 0.1, pad_id: int = 0):
        super().__init__()
        self.d_model = d_model
        self.pad_id = pad_id
        self.embed = nn.Embedding(vocab_size, d_model, padding_idx=pad_id)
        self.pos_enc = PositionalEncoding(d_model, dropout=dropout)
        self.transformer = nn.Transformer(d_model=d_model, nhead=nhead, num_encoder_layers=num_enc_layers,
            num_decoder_layers=num_dec_layers, dim_feedforward=dim_ff, dropout=dropout, batch_first=True)
        self.out_proj = nn.Linear(d_model, vocab_size)

    def _embed(self, ids: torch.Tensor) -> torch.Tensor:
        """Embed token ids, scale by sqrt(d_model), and add positional encoding."""
        return self.pos_enc(self.embed(ids) * math.sqrt(self.d_model))

    @staticmethod
    def _causal_mask(sz: int, device: torch.device) -> torch.Tensor:
        """Return an (sz, sz) boolean mask that is True strictly above the diagonal.

        True marks a position that may not be attended to. A boolean mask is
        used so that it has the same dtype as the boolean key-padding masks;
        mixing a float attention mask with boolean padding masks is deprecated
        in PyTorch.
        """
        return torch.triu(torch.ones(sz, sz, dtype=torch.bool, device=device), diagonal=1)

    def forward(self, src: torch.Tensor, tgt_in: torch.Tensor) -> torch.Tensor:
        """Return next-token logits of shape (batch, tgt_len, vocab_size).

        Args:
            src: Source token ids, (batch, src_len).
            tgt_in: Decoder input token ids (teacher forcing), (batch, tgt_len).
        """
        src_pad = src == self.pad_id
        tgt_pad = tgt_in == self.pad_id
        tgt_mask = self._causal_mask(tgt_in.size(1), tgt_in.device)
        out = self.transformer(self._embed(src), self._embed(tgt_in), tgt_mask=tgt_mask,
            src_key_padding_mask=src_pad, tgt_key_padding_mask=tgt_pad, memory_key_padding_mask=src_pad)
        return self.out_proj(out)

    @torch.no_grad()
    def generate(self, src: torch.Tensor, bos_id: int, eos_id: int, max_len: int = 400) -> torch.Tensor:
        """Greedy (argmax) decoding.

        The source is encoded once; the decoder is then re-run on the growing
        prefix for at most ``max_len`` steps, stopping early when every row
        has produced ``eos_id``. Rows that have already produced ``eos_id``
        are filled with ``pad_id`` from then on.

        Args:
            src: Source token ids, (batch, src_len).
            bos_id: Token id every output row starts with.
            eos_id: Token id that marks a row as finished.
            max_len: Maximum number of generated tokens (excluding BOS).

        Returns:
            Long tensor of shape (batch, 1 + steps) whose first column is ``bos_id``.
        """
        # Decode in eval mode but hand the module back in whatever mode the
        # caller had it in, so generating mid-training does not silently
        # leave dropout disabled.
        was_training = self.training
        self.eval()
        try:
            B, device = src.size(0), src.device
            src_pad = src == self.pad_id
            memory = self.transformer.encoder(self._embed(src), src_key_padding_mask=src_pad)
            ys = torch.full((B, 1), bos_id, dtype=torch.long, device=device)
            finished = torch.zeros(B, dtype=torch.bool, device=device)
            for _ in range(max_len):
                tgt_mask = self._causal_mask(ys.size(1), device)
                tgt_pad = ys == self.pad_id
                out = self.transformer.decoder(self._embed(ys), memory, tgt_mask=tgt_mask,
                    tgt_key_padding_mask=tgt_pad, memory_key_padding_mask=src_pad)
                nxt = self.out_proj(out[:, -1, :]).argmax(dim=-1)
                # Rows that already emitted EOS only receive padding.
                nxt = nxt.masked_fill(finished, self.pad_id)
                ys = torch.cat([ys, nxt.unsqueeze(1)], dim=1)
                finished |= nxt == eos_id
                if finished.all():
                    break
        finally:
            self.train(was_training)
        return ys

## 3. Training and evaluation helpers

In [15]:
# Strip BOS/EOS/PAD from id list.
def _strip(ids, bos, eos, pad):
    """Return the tokens of ``ids`` up to the first ``eos``, without ``bos``/``pad`` ids.

    ``bos`` and ``pad`` ids are dropped wherever they occur before the first
    ``eos``; everything from that ``eos`` onwards is discarded.
    """
    ignored = (bos, pad)
    kept = []
    for token in ids:
        if token not in ignored:
            if token == eos:
                return kept
            kept.append(token)
    return kept


# One epoch: teacher forcing, CE loss; optional label smoothing and grad accumulation.
def train_one_epoch(model, loader, optimiser, scheduler, device, pad_id, max_len=None, label_smoothing=0.0, accum_steps=1):
    """Run one teacher-forced training pass over ``loader`` and return the mean batch loss.

    Each batch is a ``(src, tgt)`` pair. The decoder input is ``tgt`` without
    its last column and the labels are ``tgt`` without its first column.
    Gradients are accumulated over ``accum_steps`` batches; the optimiser (and
    the scheduler, when given) step once per accumulation window and once more
    on the final batch if the last window is incomplete.

    Args:
        model: Callable as ``model(src, tgt_in)`` returning logits.
        loader: Iterable of ``(src, tgt)`` batches.
        optimiser: Optimiser over the model parameters.
        scheduler: Optional LR scheduler stepped after every optimiser step.
        device: Device the batches are moved to.
        pad_id: Label id ignored by the loss.
        max_len: If set, both ``src`` and ``tgt`` are cut to this many columns.
        label_smoothing: Label-smoothing factor passed to the cross-entropy loss.
        accum_steps: Number of batches per optimiser step.

    Returns:
        Average of the per-batch losses, or 0.0 for an empty loader.
    """
    model.train()
    optimiser.zero_grad()
    running_loss = 0.0
    batches_seen = 0
    for step, (src, tgt) in enumerate(loader, start=1):
        src = src.to(device)
        tgt = tgt.to(device)
        if max_len is not None:
            src = src[:, :max_len]
            tgt = tgt[:, :max_len]

        # Shift by one: predict token t+1 from tokens <= t.
        decoder_input = tgt[:, :-1]
        labels = tgt[:, 1:]
        logits = model(src, decoder_input)
        flat_logits = logits.reshape(-1, logits.size(-1))
        flat_labels = labels.reshape(-1)
        # Divide so that the accumulated gradient is an average over the window.
        scaled_loss = F.cross_entropy(
            flat_logits,
            flat_labels,
            ignore_index=pad_id,
            label_smoothing=label_smoothing,
        ) / accum_steps
        scaled_loss.backward()

        if step % accum_steps == 0 or step == len(loader):
            nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimiser.step()
            if scheduler is not None:
                scheduler.step()
            optimiser.zero_grad()

        # Undo the accumulation scaling for reporting.
        running_loss += scaled_loss.item() * accum_steps
        batches_seen += 1
    return running_loss / max(batches_seen, 1)


# Eval: sequence and token accuracy; returns metrics and sample predictions.
@torch.no_grad()
def evaluate(model, loader, device, vocab, max_gen=400, max_src_len=None):
    """Decode every batch with ``model.generate`` and score it against the targets.

    Predictions and targets are both reduced with ``_strip`` before comparison.
    Sequence accuracy counts exact matches. Token accuracy counts position-wise
    matches over the overlapping prefix and divides by the longer of the two
    lengths, so a length mismatch lowers the score.

    Args:
        model: Model exposing ``eval()`` and ``generate(src, bos_id, eos_id, max_len)``.
        loader: Iterable of ``(src, tgt)`` batches.
        device: Device the source batch is moved to.
        vocab: Object providing ``bos_id``, ``eos_id``, ``pad_id`` and ``decode``.
        max_gen: Maximum number of generated tokens per sequence.
        max_src_len: If set, the source is cut to this many columns.

    Returns:
        Dict with ``seq_acc``, ``tok_acc``, ``n`` (number of sequences) and up
        to five decoded ``examples``.
    """
    model.eval()
    n_seqs = 0
    n_exact = 0
    n_tok_hits = 0
    n_tok_slots = 0
    samples = []
    for src, tgt in loader:
        src = src.to(device)
        if max_src_len is not None:
            src = src[:, :max_src_len]
        preds = model.generate(src, vocab.bos_id, vocab.eos_id, max_gen)
        for row in range(src.size(0)):
            hyp = _strip(preds[row].tolist(), vocab.bos_id, vocab.eos_id, vocab.pad_id)
            ref = _strip(tgt[row].tolist(), vocab.bos_id, vocab.eos_id, vocab.pad_id)
            n_seqs += 1
            if hyp == ref:
                n_exact += 1
            # zip stops at the shorter sequence, i.e. the overlapping prefix.
            n_tok_hits += sum(1 for h, r in zip(hyp, ref) if h == r)
            n_tok_slots += max(len(hyp), len(ref))
            if len(samples) < 5:
                decoded_pred = " ".join(vocab.decode(hyp))
                decoded_true = " ".join(vocab.decode(ref))
                samples.append({"pred": decoded_pred, "true": decoded_true})
    seq_acc = n_exact / max(n_seqs, 1)
    tok_acc = n_tok_hits / max(n_tok_slots, 1)
    return {"seq_acc": seq_acc, "tok_acc": tok_acc, "n": n_seqs, "examples": samples}

## 4. Configuration

Set `MODEL` to `"QED"` or `"QCD"`. Run the notebook from the `gsoc_tasks` directory so that `preprocess` is importable.

In [16]:
# --- User-editable settings -------------------------------------------------
MODEL = "QCD"
DATA_DIR = "SYMBA - Test Data"
EPOCHS = 400
BATCH_SIZE = None       # None -> model-dependent default below
LR = 2e-4
D_MODEL = 256
NHEAD = 8
NUM_LAYERS = 4
DIM_FF = 1024
DROPOUT = 0.1
LABEL_SMOOTHING = 0.1
ACCUM_STEPS = 2
MAX_SEQ_LEN = None      # None -> model-dependent default below
SEED = 42
OUT_DIR = None          # None -> results_task2_<MODEL>

# --- Resolve settings left as None -------------------------------------------
# Any MODEL other than "QED" gets the second (QCD) value.
BATCH_SIZE = (16 if MODEL == "QED" else 2) if BATCH_SIZE is None else BATCH_SIZE
MAX_SEQ_LEN = (300 if MODEL == "QED" else 1500) if MAX_SEQ_LEN is None else MAX_SEQ_LEN
OUT_DIR = f"results_task2_{MODEL}" if OUT_DIR is None else OUT_DIR
os.makedirs(OUT_DIR, exist_ok=True)

# Use the GPU when one is available.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Device: {device}")

Device: cuda


## 5. Build data and model

In [17]:
# Data loaders and vocabulary come from the project's preprocess module.
train_ld, val_ld, test_ld, vocab, _ = build_data(DATA_DIR, MODEL, SEED, BATCH_SIZE)

# Encoder and decoder use the same depth (NUM_LAYERS).
model = Seq2SeqTransformer(
    vocab_size=len(vocab),
    d_model=D_MODEL,
    nhead=NHEAD,
    num_enc_layers=NUM_LAYERS,
    num_dec_layers=NUM_LAYERS,
    dim_ff=DIM_FF,
    dropout=DROPOUT,
    pad_id=vocab.pad_id,
).to(device)
n_params = sum(param.numel() for param in model.parameters() if param.requires_grad)
print(f"Model parameters: {n_params:,}")

optimiser = AdamW(model.parameters(), lr=LR, weight_decay=1e-2)
# The scheduler is stepped once per optimiser step, i.e. once per
# gradient-accumulation window, so the schedule length is counted in windows.
steps_per_epoch = math.ceil(len(train_ld) / ACCUM_STEPS)
total_steps = EPOCHS * steps_per_epoch
scheduler = OneCycleLR(
    optimiser,
    max_lr=LR,
    total_steps=total_steps,
    pct_start=0.05,
)
# Start below any reachable accuracy so the first evaluation always saves.
best_val_acc = -1.0
best_path = os.path.join(OUT_DIR, "best.pt")

[preprocess] QCD:  187 train / 23 val / 24 test  | vocab 916
[preprocess] src lengths:  min=192 max=2126 avg=532
[preprocess] tgt lengths:  min=95 max=1349 avg=357
Model parameters: 7,843,732


## 6. Training loop

In [18]:
print(f"\nTraining {MODEL} for {EPOCHS} epochs ...\n")
t0 = time.time()

for epoch in range(1, EPOCHS + 1):
    loss = train_one_epoch(
        model,
        train_ld,
        optimiser,
        scheduler,
        device,
        vocab.pad_id,
        MAX_SEQ_LEN,
        label_smoothing=LABEL_SMOOTHING,
        accum_steps=ACCUM_STEPS,
    )

    # Epochs without validation: only log the loss, every 5th epoch.
    if epoch % 20 != 0 and epoch != EPOCHS:
        if epoch % 5 == 0:
            print(f"Epoch {epoch:4d} | loss {loss:.4f} | {time.time() - t0:.0f}s")
        continue

    # Validation runs every 20th epoch and on the final epoch.
    metrics = evaluate(model, val_ld, device, vocab, max_gen=MAX_SEQ_LEN, max_src_len=MAX_SEQ_LEN)
    elapsed = time.time() - t0
    print(f"Epoch {epoch:4d} | loss {loss:.4f} | val seq_acc {metrics['seq_acc']:.4f} | val tok_acc {metrics['tok_acc']:.4f} | {elapsed:.0f}s")
    # Keep the weights with the best validation sequence accuracy so far.
    if metrics["seq_acc"] > best_val_acc:
        best_val_acc = metrics["seq_acc"]
        torch.save(model.state_dict(), best_path)


Training QCD for 400 epochs ...

Epoch    5 | loss 2.9578 | 65s
Epoch   10 | loss 1.9349 | 129s
Epoch   15 | loss 1.6859 | 193s
Epoch   20 | loss 1.4880 | val seq_acc 0.0000 | val tok_acc 0.2309 | 376s
Epoch   25 | loss 1.3979 | 439s
Epoch   30 | loss 1.3289 | 503s
Epoch   35 | loss 1.2851 | 568s
Epoch   40 | loss 1.2732 | val seq_acc 0.6087 | val tok_acc 0.3843 | 747s
Epoch   45 | loss 1.2182 | 812s
Epoch   50 | loss 1.2010 | 876s
Epoch   55 | loss 1.1902 | 941s
Epoch   60 | loss 1.1546 | val seq_acc 0.6957 | val tok_acc 0.4171 | 1120s
Epoch   65 | loss 1.1325 | 1183s
Epoch   70 | loss 1.1174 | 1245s
Epoch   75 | loss 1.1105 | 1309s
Epoch   80 | loss 1.0955 | val seq_acc 0.7391 | val tok_acc 0.4277 | 1490s
Epoch   85 | loss 1.0846 | 1556s
Epoch   90 | loss 1.0725 | 1622s
Epoch   95 | loss 1.0617 | 1686s
Epoch  100 | loss 1.0685 | val seq_acc 0.7826 | val tok_acc 0.5094 | 1872s
Epoch  105 | loss 1.0502 | 1940s
Epoch  110 | loss 1.0547 | 2008s
Epoch  115 | loss 1.0491 | 2068s
Epoch  12

KeyboardInterrupt: 

## 7. Test evaluation

In [19]:
print("Loading best checkpoint and evaluating on test set ...")
# map_location remaps the stored tensors onto the current device, so a
# checkpoint written on a GPU still loads when this cell runs in a CPU-only
# session (and vice versa). weights_only=True restricts unpickling to tensors
# and plain containers.
model.load_state_dict(torch.load(best_path, map_location=device, weights_only=True))

test_metrics = evaluate(model, test_ld, device, vocab, max_gen=MAX_SEQ_LEN, max_src_len=MAX_SEQ_LEN)
print(f"\n  Test Sequence Accuracy : {test_metrics['seq_acc']:.4f}")
print(f"  Test Token Accuracy    : {test_metrics['tok_acc']:.4f}")
print(f"  Test examples          : {test_metrics['n']}")
print("\nSample predictions:")
# Show the first three stored examples, each cut to 120 characters.
for i, ex in enumerate(test_metrics["examples"][:3]):
    print(f"\n  [{i}] TRUE : {ex['true'][:120]}...")
    print(f"      PRED : {ex['pred'][:120]}...")
print(f"\nResults saved to {OUT_DIR}")

Loading best checkpoint and evaluating on test set ...

  Test Sequence Accuracy : 0.8750
  Test Token Accuracy    : 0.8764
  Test examples          : 24

Sample predictions:

  [0] TRUE : -1/6 * g ^ 4 * ( m_b ^ 2 * ( m_b ^ 2 + 1/2 * s_34 ) + s_34 * ( m_b ^ 2 + 1/2 * s_34 ) ) * ( s_12 + 1/2 * reg_prop ) ^ (-...
      PRED : -1/2304 * g ^ 4 * ( s_13 * ( 128 * m_b ^ 2 + (-64) * s_13 ) + (-1024) * m_b ^ 2 * ( m_b ^ 2 + -1/2 * s_13 ) + (-224) * m...

  [1] TRUE : -1/2304 * g ^ 4 * ( s_13 * ( 128 * m_t ^ 2 + (-64) * s_13 ) + (-1024) * m_t ^ 2 * ( m_t ^ 2 + -1/2 * s_13 ) + (-224) * m...
      PRED : -1/2304 * g ^ 4 * ( s_13 * ( 128 * m_t ^ 2 + (-64) * s_13 ) + (-1024) * m_t ^ 2 * ( m_t ^ 2 + -1/2 * s_13 ) + (-224) * m...

  [2] TRUE : -1/16 * g ^ 4 * ( (-16) * m_c ^ 2 * m_t ^ 2 + (-8) * m_c ^ 2 * s_12 + (-8) * s_14 * s_23 + (-8) * s_13 * s_24 + (-8) * m...
      PRED : -1/16 * g ^ 4 * ( (-16) * m_c ^ 2 * m_t ^ 2 + (-8) * m_c ^ 2 * s_12 + (-8) * s_14 * s_23 + (-8) * s_13 * s_24 + (-8) * m...
