# Task 1.2 – Dataset Preprocessing

Loads SYMBA test data, **normalises** growing numeric indices, **tokenises** amplitude and squared-amplitude expressions, builds a shared **vocabulary**, and produces an **80-10-10** train / val / test split.

---

## Tokenisation

Regex-based tokeniser: one compiled pattern with alternatives matched in a fixed order (longer patterns first) so meaningful units stay **one token**.

- **Kept whole:** conjugation `^(*)`, parenthesised negatives `(-2)`, fractions `1/2`, LaTeX-with-index `\sigma_0`, compound names `m_e` / `s_12` / `gam_1`, standalone LaTeX commands and words.
- **Single-token fallbacks:** integers, then the operators/delimiters `+ - * / ^ ( ) { } , % _`, then `\S` for any other non-whitespace character.

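No alternative matches whitespace, so `findall` skips it; the `if t.strip()` filter applied to the matches is only a safeguard. Keeping each physics object as one token shortens sequences and avoids splitting an object across several tokens.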
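
To make the effect of the alternative order concrete, here is a minimal sketch using a reduced version of the pattern. The names `TOKEN_RE` and `tokenize` and the input string are made up for illustration; the full pattern is defined in section 4.

```python
import re

# Reduced, illustrative pattern (not the full tokeniser): specific alternatives first.
TOKEN_RE = re.compile(
    r"""
    \^\(\*\)            |  # ^(*) conjugation
    \(-\d+\)            |  # (-2) parenthesised negative
    [-+]?\d+/\d+        |  # fractions 1/2, -1/6
    \\[a-zA-Z]+_\d+     |  # \sigma_0
    [a-zA-Z]+_\d+       |  # s_12
    [a-zA-Z]+           |  # e, i
    \d+                 |  # 16
    \S                     # any other non-whitespace character
    """,
    re.VERBOSE,
)


def tokenize(text):
    # No capturing groups, so findall returns the full text of each match.
    return TOKEN_RE.findall(text)


tokens = tokenize(r"-1/2 * e^2 * \sigma_0 * s_12^(*) + (-2)")
print(tokens)
```

Because the regex engine tries alternatives left to right at each position, `-1/2`, `\sigma_0`, `s_12`, `^(*)` and `(-2)` each come out as a single token, while the bare `^` before `2` falls through to the `\S` fallback.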

---

## Index normalisation

Raw SYMBA expressions contain **numeric subscripts** `_NUMBER` (e.g. `\sigma_249`, `gam_165`). These are **dummy indices** (summation/Lorentz) whose values grow with diagram size. Left as-is, every new number would become a new vocabulary entry, so we renumber them per example to `0, 1, 2, ...`.

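- **Physical indices** (in `_PHYSICAL_NUMS`: `0`, `1`–`4`, `12`, `13`, `14`, `23`, `24`, `34`) are **unchanged**: they label momenta `p_1..p_4` and Mandelstam pairs like `s_12`, `t_23`. The check is on the number alone, so any subscript with one of these values is kept, whatever symbol it is attached to.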
- **Dummy indices** get a new id by first occurrence in that string; same original number maps to the same normalised id within the example. Mapping is built anew per string (amplitude and squared-amplitude separately).
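
The two rules above can be sketched on a small made-up string. The helper name `renumber` and the input are illustrative only; the `PHYSICAL` set is assumed to be the same as `_PHYSICAL_NUMS`.

```python
import re

# Assumed to mirror _PHYSICAL_NUMS from the notebook.
PHYSICAL = frozenset({"0", "1", "2", "3", "4", "12", "13", "14", "23", "24", "34"})


def renumber(text):
    mapping = {}  # original dummy number -> new id, rebuilt for every string

    def repl(m):
        num = m.group(1)
        if num in PHYSICAL:
            return m.group(0)  # physical index: keep as written
        # First occurrence gets the next free id; later ones reuse it.
        return "_" + mapping.setdefault(num, str(len(mapping)))

    return re.sub(r"_(\d+)", repl, text)


out = renumber(r"\sigma_249 * gam_165 * \sigma_249 * s_12 * p_3")
print(out)  # \sigma_0 * gam_1 * \sigma_0 * s_12 * p_3
```

Both occurrences of `249` map to `0`, `165` maps to `1`, and `s_12` / `p_3` are untouched. One consequence of this scheme is that renumbered dummy ids share the range `0, 1, 2, ...` with the physical indices, so a renumbered `gam_1` looks the same as a subscript that was physically `1`.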


## 1. Imports and constants

In [None]:
from __future__ import annotations

import os
import re
import random
from collections import Counter
from typing import Dict, List, Tuple

import torch
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader, Dataset

PAD, BOS, EOS, UNK = "<pad>", "<bos>", "<eos>", "<unk>"
SPECIALS = [PAD, BOS, EOS, UNK]
PAD_ID = 0

_PHYSICAL_NUMS = frozenset(
    {"0", "1", "2", "3", "4", "12", "13", "14", "23", "24", "34"}
)

## 2. Data loading

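Each line has the form `interaction : diagram : amplitude : squared_amplitude`. We load every `.txt` file whose name starts with the model prefix (e.g. QED, QCD), skip blank lines and lines with fewer than four fields, and keep any extra ` : ` separators as part of the squared amplitude.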

In [2]:
# Load SYMBA .txt files; each line: interaction, diagram, amplitude, squared_amplitude.
def load_raw_data(data_dir: str, model_prefix: str) -> List[Dict[str, str]]:
    records = []
    for fn in sorted(os.listdir(data_dir)):
        if not fn.endswith(".txt") or not fn.startswith(model_prefix):
            continue
        with open(os.path.join(data_dir, fn), encoding="utf-8") as fh:
            for line in fh:
                line = line.strip()
                if not line:
                    continue
                parts = line.split(" : ")
                if len(parts) < 4:
                    continue
                records.append({
                    "interaction": parts[0].strip(),
                    "diagram": parts[1].strip(),
                    "amplitude": parts[2].strip(),
                    "squared_amplitude": " : ".join(parts[3:]).strip(),
                })
    return records
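
As a quick check of the field handling, here is a sketch of the per-line logic in isolation. `parse_line` is a hypothetical helper and the input lines are made up for illustration, not taken from the SYMBA files.

```python
SEP = " : "


def parse_line(line):
    """Split one line into the four fields; return None if it has fewer than four."""
    parts = line.strip().split(SEP)
    if len(parts) < 4:
        return None
    return {
        "interaction": parts[0].strip(),
        "diagram": parts[1].strip(),
        "amplitude": parts[2].strip(),
        # Anything after the third separator belongs to the squared amplitude.
        "squared_amplitude": SEP.join(parts[3:]).strip(),
    }


rec = parse_line("e e to mu mu : diagram text : i*e^2 : 1/4*e^4\n")
short = parse_line("only : three : fields")
extra = parse_line("a : b : c : d : e")
print(rec, short, extra["squared_amplitude"])
```

Joining `parts[3:]` has the same effect as calling `line.split(" : ", 3)`, which stops splitting after the third separator.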

## 3. Index normalisation

Replace every `_NUMBER` with a per-example sequential id, except numbers in `_PHYSICAL_NUMS` (momenta, Mandelstam indices, etc.).

In [3]:
# Renumber dummy indices per example; keep physical indices (s_12, p_1) unchanged.
def normalize_indices(text: str) -> str:
    mapping = {}
    counter = [0]

    def _repl(m: re.Match) -> str:
        num = m.group(1)
        if num in _PHYSICAL_NUMS:
            return "_" + num
        if num not in mapping:
            mapping[num] = str(counter[0])
            counter[0] += 1
        return "_" + mapping[num]

    return re.sub(r"_(\d+)", _repl, text)

## 4. Tokenisation

A single regex whose alternatives are tried in this order: conjugation `^(*)`, parenthesised negatives `(-2)`, fractions, LaTeX commands with an index (`\sigma_0`), compound identifiers, standalone LaTeX commands and words, integers, single delimiters, and a `\S` fallback.

In [4]:
_TOKEN_RE = re.compile(
    r"""
    \^\(\*\)                   |  # ^(*) conjugation
    \(-\d+\)                   |  # (-2) parenthesised negative
    [-+]?\d+/\d+               |  # fractions  1/2  -1/6
    \\[a-zA-Z]+_\d+            |  # \sigma_0
    [a-zA-Z]+_[a-zA-Z]+_\d+   |  # compound_word_num (rare)
    [a-zA-Z]+_[a-zA-Z]+        |  # m_e  reg_prop
    [a-zA-Z]+_\d+              |  # gam_1  s_12  k_3
    \\[a-zA-Z]+                |  # \sigma  (standalone)
    [a-zA-Z]+                  |  # gamma  e  i
    \d+                        |  # 16  4
    [{}()^*+\-/,%_]            |  # single delimiters
    \S                            # fallback
    """,
    re.VERBOSE,
)


# Split expression into tokens via regex (fractions, LaTeX, identifiers, operators, etc.).
def tokenize_expr(text: str) -> List[str]:
    return [t for t in _TOKEN_RE.findall(text.strip()) if t.strip()]

## 5. Physics-type classification (for Task 3)

Assign each token to one of 9 types, with ids 0–8 in this order: special, coupling, mass, Mandelstam, number, regulator, operator, imaginary, other. The checks run in that order and the first match wins.

In [5]:
NUM_TOKEN_TYPES = 9

_MASS_RE = re.compile(r"^m_[a-zA-Z]+$")
_MANDEL_RE = re.compile(r"^[stu]_\d+$")
_NUM_RE = re.compile(r"^[-+]?\d+(/\d+)?$|^\(-?\d+\)$")
_OPERATORS = frozenset("+-*/^(){},%_")


# Map token to one of 9 physics types (special, coupling, mass, Mandelstam, etc.).
def token_physics_type(tok: str) -> int:
    if tok in SPECIALS:
        return 0
    if tok in ("e", "g"):
        return 1
    if _MASS_RE.match(tok):
        return 2
    if _MANDEL_RE.match(tok):
        return 3
    if _NUM_RE.match(tok):
        return 4
    if tok in ("reg_prop", "reg"):
        return 5
    if tok in _OPERATORS:
        return 6
    if tok == "i":
        return 7
    return 8
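
To see which strings the three regex-based checks accept, here is a reduced sketch that reuses the patterns from the cell above. `coarse_type` is a hypothetical helper covering only three of the nine types; everything else is reported as `"other"`.

```python
import re

# Patterns copied from the cell above.
MASS_RE = re.compile(r"^m_[a-zA-Z]+$")
MANDEL_RE = re.compile(r"^[stu]_\d+$")
NUM_RE = re.compile(r"^[-+]?\d+(/\d+)?$|^\(-?\d+\)$")


def coarse_type(tok):
    """Reduced classifier: mass / mandelstam / number, otherwise 'other'."""
    if MASS_RE.match(tok):
        return "mass"
    if MANDEL_RE.match(tok):
        return "mandelstam"
    if NUM_RE.match(tok):
        return "number"
    return "other"


samples = ["m_mu", "s_12", "-1/144", "(-16)", "16", "gam_1"]
coarse = {tok: coarse_type(tok) for tok in samples}
print(coarse)
```

Fractions, plain integers and parenthesised negatives all land in the number class, while an indexed object such as `gam_1` matches none of the three patterns.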

## 6. Vocabulary

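Special tokens occupy ids 0–3, followed by every token seen in the training split in descending frequency order. The class provides `encode` / `decode` and a `type_ids` tensor (one physics type per vocabulary entry) for Task 3.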

In [6]:
# Shared vocabulary from token counts; encode/decode and optional type_ids for Task 3.
class Vocab:
    def __init__(self, token_lists: List[List[str]], min_freq: int = 1):
        counts = Counter()
        for toks in token_lists:
            counts.update(toks)
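        self.itos = list(SPECIALS)
        # Counter keys are already unique; the set only guards against a
        # special token appearing in the data, with O(1) membership tests.
        seen = set(self.itos)
        for tok, cnt in counts.most_common():
            if cnt >= min_freq and tok not in seen:
                self.itos.append(tok)
                seen.add(tok)
        self.stoi = {s: i for i, s in enumerate(self.itos)}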
        self.type_ids = torch.tensor([token_physics_type(t) for t in self.itos], dtype=torch.long)

    @property
    def pad_id(self) -> int:
        return 0

    @property
    def bos_id(self) -> int:
        return 1

    @property
    def eos_id(self) -> int:
        return 2

    @property
    def unk_id(self) -> int:
        return 3

    def __len__(self) -> int:
        return len(self.itos)

    def encode(self, tokens: List[str]) -> List[int]:
        return [self.stoi.get(t, self.unk_id) for t in tokens]

    def decode(self, ids):
        return [self.itos[i] if i < len(self.itos) else UNK for i in ids]
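
The id layout and the `<unk>` fallback can be illustrated without PyTorch. This is a plain-Python sketch on made-up token lists, mirroring the ordering logic of `Vocab` rather than using the class itself.

```python
from collections import Counter

SPECIALS = ["<pad>", "<bos>", "<eos>", "<unk>"]

# Made-up training tokens, for illustration only.
token_lists = [["e", "^", "2", "*", "m_e"], ["e", "^", "4"]]

counts = Counter(tok for toks in token_lists for tok in toks)
# most_common() sorts by count; ties keep first-seen order.
itos = SPECIALS + [tok for tok, _ in counts.most_common()]
stoi = {tok: i for i, tok in enumerate(itos)}

unk_id = stoi["<unk>"]
ids = [stoi.get(tok, unk_id) for tok in ["e", "^", "m_mu"]]
decoded = [itos[i] for i in ids]
print(ids, decoded)  # [4, 5, 3] ['e', '^', '<unk>']
```

`m_mu` never occurred in the toy training lists, so it encodes to id 3 and decodes back as `<unk>`. This is what happens to validation or test tokens that are absent from the training split.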

## 7. Dataset and collate

In [7]:
# Dataset: amplitude to squared_amplitude (src_ids, tgt_ids).
class Seq2SeqDataset(Dataset):
    def __init__(self, records: List[Dict[str, str]], vocab: Vocab):
        self.samples = []
        for rec in records:
            src_toks = tokenize_expr(normalize_indices(rec["amplitude"]))
            tgt_toks = tokenize_expr(normalize_indices(rec["squared_amplitude"]))
            src_ids = [vocab.bos_id] + vocab.encode(src_toks) + [vocab.eos_id]
            tgt_ids = [vocab.bos_id] + vocab.encode(tgt_toks) + [vocab.eos_id]
            self.samples.append((
                torch.tensor(src_ids, dtype=torch.long),
                torch.tensor(tgt_ids, dtype=torch.long),
            ))

    def __len__(self) -> int:
        return len(self.samples)

    def __getitem__(self, idx: int):
        return self.samples[idx]


# Right-pad src and tgt separately, each to the longest sequence of its kind in the batch.
def seq2seq_collate(batch, pad_id: int = PAD_ID):
    srcs, tgts = zip(*batch)
    return (
        pad_sequence(srcs, batch_first=True, padding_value=pad_id),
        pad_sequence(tgts, batch_first=True, padding_value=pad_id),
    )
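
For intuition, the following plain-Python sketch shows the right-padding that `pad_sequence(..., batch_first=True, padding_value=pad_id)` applies to the tensors. It works on lists of ids and is not the PyTorch implementation; `pad_batch` and the id values are made up for illustration.

```python
def pad_batch(seqs, pad_id=0):
    """Right-pad every list of ids to the length of the longest one."""
    max_len = max(len(s) for s in seqs)
    return [s + [pad_id] * (max_len - len(s)) for s in seqs]


# Two toy sequences of the form <bos> ... <eos> (ids 1 and 2).
batch = [[1, 7, 9, 2], [1, 5, 2]]
padded = pad_batch(batch)
# Id 0 is reserved for <pad>, so a mask of real tokens is a simple comparison.
mask = [[tok != 0 for tok in row] for row in padded]
print(padded)  # [[1, 7, 9, 2], [1, 5, 2, 0]]
print(mask)
```

Because the pad id is never assigned to a real token, downstream code can recover the padding positions from the batch itself.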

## 8. Build data (80-10-10 split, loaders, vocab)

In [8]:
# Load data, 80-10-10 split; build vocab and train/val/test loaders.
def build_data(
    data_dir: str,
    model_prefix: str,
    seed: int = 42,
    batch_size: int = 16,
) -> Tuple[DataLoader, DataLoader, DataLoader, Vocab, List[Dict[str, str]]]:
    records = load_raw_data(data_dir, model_prefix)
    if not records:
        raise RuntimeError(f"No data found for prefix '{model_prefix}' in {data_dir}")
    random.seed(seed)
    random.shuffle(records)
    n = len(records)
    n_train = int(0.8 * n)
    n_val = int(0.1 * n)
    train_recs = records[:n_train]
    val_recs = records[n_train : n_train + n_val]
    test_recs = records[n_train + n_val :]
    all_toks = []
    for rec in train_recs:
        all_toks.append(tokenize_expr(normalize_indices(rec["amplitude"])))
        all_toks.append(tokenize_expr(normalize_indices(rec["squared_amplitude"])))
    vocab = Vocab(all_toks)
    train_ds = Seq2SeqDataset(train_recs, vocab)
    val_ds = Seq2SeqDataset(val_recs, vocab)
    test_ds = Seq2SeqDataset(test_recs, vocab)
    # Shared loader settings; pin memory only when a CUDA device is available.
    loader_kwargs = dict(
        batch_size=batch_size,
        collate_fn=lambda batch: seq2seq_collate(batch, vocab.pad_id),
        num_workers=0,
        pin_memory=torch.cuda.is_available(),
    )
    train_loader = DataLoader(train_ds, shuffle=True, **loader_kwargs)
    val_loader = DataLoader(val_ds, shuffle=False, **loader_kwargs)
    test_loader = DataLoader(test_ds, shuffle=False, **loader_kwargs)
    print(f"[preprocess] {model_prefix}:  {len(train_recs)} train / {len(val_recs)} val / {len(test_recs)} test  | vocab {len(vocab)}")
    src_lens = [s.size(0) for s, _ in train_ds.samples]
    tgt_lens = [t.size(0) for _, t in train_ds.samples]
    print(f"[preprocess] src lengths:  min={min(src_lens)} max={max(src_lens)} avg={sum(src_lens)/len(src_lens):.0f}")
    print(f"[preprocess] tgt lengths:  min={min(tgt_lens)} max={max(tgt_lens)} avg={sum(tgt_lens)/len(tgt_lens):.0f}")
    return train_loader, val_loader, test_loader, vocab, test_recs
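
The split sizes follow from truncating the 80% and 10% shares and giving the remainder to the test set. A small sketch of that arithmetic, with `split_sizes` as a hypothetical helper name:

```python
def split_sizes(n, train_frac=0.8, val_frac=0.1):
    """Truncate the train and val shares; the test set takes whatever is left."""
    n_train = int(train_frac * n)
    n_val = int(val_frac * n)
    return n_train, n_val, n - n_train - n_val


print(split_sizes(360))  # (288, 36, 36)
print(split_sizes(234))  # (187, 23, 24)
```

When `n` is not a multiple of 10, the test set can end up one or two records larger than the validation set, as in the 234-record case.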

## 9. Quick demo

Set `DATA_DIR` and run for QED and/or QCD to inspect loaded records and tokenisation.

In [9]:
DATA_DIR = "SYMBA - Test Data"

for prefix in ("QED", "QCD"):
    print(f"\n{'='*60}\n  {prefix}\n{'='*60}")
    recs = load_raw_data(DATA_DIR, prefix)
    print(f"  Loaded {len(recs)} records")
    if recs:
        ex = recs[0]
        amp_toks = tokenize_expr(normalize_indices(ex["amplitude"]))
        sq_toks = tokenize_expr(normalize_indices(ex["squared_amplitude"]))
        print(f"  Example amplitude  ({len(amp_toks)} tokens): {amp_toks[:20]} ...")
        print(f"  Example sq. ampl.  ({len(sq_toks)} tokens):  {sq_toks[:20]} ...")
        build_data(DATA_DIR, prefix)


  QED
  Loaded 360 records
  Example amplitude  (104 tokens): ['-1/2', '*', 'i', '*', 'e', '^', '2', '*', 'gamma', '_', '{', '+', '%', '\\sigma_0', ',', '%', 'gam_1', ',', '%', 'del_1'] ...
  Example sq. ampl.  (59 tokens):  ['1/4', '*', 'e', '^', '4', '*', '(', '16', '*', 'm_e', '^', '2', '*', 'm_mu', '^', '2', '+', '8', '*', 'm_mu'] ...
[preprocess] QED:  288 train / 36 val / 36 test  | vocab 181
[preprocess] src lengths:  min=106 max=198 avg=141
[preprocess] tgt lengths:  min=61 max=97 avg=75

  QCD
  Loaded 234 records
  Example amplitude  (190 tokens): ['-1/4', '*', 'i', '*', 'g', '^', '2', '*', 'gamma', '_', '{', '+', '%', '\\sigma_0', ',', '%', 'gam_1', ',', '%', 'del_1'] ...
  Example sq. ampl.  (197 tokens):  ['-1/144', '*', 'g', '^', '4', '*', '(', '(-16)', '*', 'm_d', '^', '2', '*', 'm_u', '^', '2', '+', '(-8)', '*', 'm_d'] ...
[preprocess] QCD:  187 train / 23 val / 24 test  | vocab 916
[preprocess] src lengths:  min=192 max=2126 avg=532
[preprocess] tgt lengths:  min=95 m