In [None]:
import wandb

In [None]:
wandb.login()

  | |_| | '_ \/ _` / _` |  _/ -_)


<IPython.core.display.Javascript object>

[34m[1mwandb[0m: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize
wandb: Paste an API key from your profile and hit enter:

 ··········


[34m[1mwandb[0m: No netrc file found, creating one.
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc
[34m[1mwandb[0m: Currently logged in as: [33mayush89718[0m ([33mayush89718-alliance-university[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


True

In [1]:
import torch.nn as nn
import torch
import torch.nn.functional as F
from einops import rearrange
import math

In [2]:
class SineKAN1D(nn.Module):
    """
    Input:  (..., input_dim)  e.g., (B, L)
    Output: (..., output_dim) e.g., (B, O)
    """
    def __init__(self, input_dim, output_dim, device='cuda', grid_size=8, is_first=False, add_bias=True, norm_freq=True):
        super().__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.grid_size = grid_size

        device = torch.device(device)
        dtype = torch.get_default_dtype()

        # Frequencies as a 1D learnable vector
        freq = torch.arange(1, grid_size + 1, dtype=dtype, device=device)
        if norm_freq:
            freq = freq / ((grid_size + 1) ** (0 if is_first else 1))
        self.freq = nn.Parameter(freq)  # (G,)

        # Phase matrix (K, G) precomputed
        input_phase = torch.linspace(0, math.pi, input_dim, dtype=dtype, device=device) * input_dim  # (K,)
        grid_phase = torch.arange(1, grid_size + 1, dtype=dtype, device=device) / (grid_size + 1) * grid_size  # (G,)
        phase = input_phase[:, None] + grid_phase[None, :]  # (K, G)
        self.register_buffer('phase', phase)

        # Amplitudes as (O, K, G). Final matmul done with nn.functional.linear on flattened KG.
        if is_first:
            amp = torch.empty(output_dim, input_dim, grid_size, dtype=dtype, device=device).normal_(0, 0.4)
        else:
            amp = torch.empty(output_dim, input_dim, grid_size, dtype=dtype, device=device).uniform_(-1, 1)
        grid_norm = torch.arange(1, grid_size + 1, dtype=dtype, device=device)  # (G,)
        amp = amp / output_dim / grid_norm[None, None, :]
        self.amplitudes = nn.Parameter(amp)  # (O, K, G)

        if add_bias:
            self.bias = nn.Parameter(torch.ones(output_dim, dtype=dtype, device=device) / output_dim)
        else:
            self.register_parameter('bias', None)

    @property
    def _W(self):
        # (O, K*G)
        return self.amplitudes.reshape(self.output_dim, self.input_dim * self.grid_size)

    def forward(self, x):
        # Support arbitrary leading dims ending with input_dim
        out_shape = x.shape[:-1] + (self.output_dim,)
        x2 = x.reshape(-1, self.input_dim)  # (N, K)
        # Compute sin with minimal broadcasting: (N, K, G)
        s = torch.sin(x2[..., :, None] * self.freq[None, None, :] + self.phase[None, :, :])
        # Dense linear over flattened (K*G)
        y = nn.functional.linear(s.reshape(-1, self.input_dim * self.grid_size), self._W, self.bias)  # (N, O)
        return y.reshape(out_shape)


class SineKANSeqFeat(nn.Module):
    """
    Input:  (B, L, F) where F == input_dim
    Output: (B, L, O) broadcast along L
    """
    def __init__(self, input_dim, output_dim, device='cuda', grid_size=8, is_first=False, add_bias=True, norm_freq=True):
        super().__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.grid_size = grid_size

        device = torch.device(device)
        dtype = torch.get_default_dtype()

        freq = torch.arange(1, grid_size + 1, dtype=dtype, device=device)
        if norm_freq:
            freq = freq / ((grid_size + 1) ** (0 if is_first else 1))
        self.freq = nn.Parameter(freq)  # (G,)

        input_phase = torch.linspace(0, math.pi, input_dim, dtype=dtype, device=device) * input_dim  # (F,)
        grid_phase = torch.arange(1, grid_size + 1, dtype=dtype, device=device) / (grid_size + 1) * grid_size  # (G,)
        phase = input_phase[:, None] + grid_phase[None, :]  # (F, G)
        self.register_buffer('phase', phase)

        if is_first:
            amp = torch.empty(output_dim, input_dim, grid_size, dtype=dtype, device=device).normal_(0, 0.4)
        else:
            amp = torch.empty(output_dim, input_dim, grid_size, dtype=dtype, device=device).uniform_(-1, 1)
        grid_norm = torch.arange(1, grid_size + 1, dtype=dtype, device=device)
        amp = amp / output_dim / grid_norm[None, None, :]
        self.amplitudes = nn.Parameter(amp)  # (O, F, G)

        if add_bias:
            self.bias = nn.Parameter(torch.ones(output_dim, dtype=dtype, device=device) / output_dim)
        else:
            self.register_parameter('bias', None)

    @property
    def _W(self):
        return self.amplitudes.reshape(self.output_dim, self.input_dim * self.grid_size)

    def forward(self, x):
        B, L, F = x.shape
        assert F == self.input_dim
        # (B, L, F, G)
        s = torch.sin(x.unsqueeze(-1) * self.freq.view(1, 1, 1, -1) + self.phase.view(1, 1, F, -1))
        # Dense linear per time step
        y = nn.functional.linear(s.reshape(B, L, F * self.grid_size), self._W, self.bias).reshape(B, L, self.output_dim)
        return y

In [3]:
class MoeLayer(nn.Module):
  def __init__(self , d_model , n_experts , k):
    super().__init__()
    self.n_experts = n_experts
    self.experts = nn.ModuleList([SineKAN1D(d_model , d_model , grid_size=8) for i in range(self.n_experts)])
    self.gate = nn.Linear(d_model , self.n_experts)
    self.k = k

  def forward(self , x):
    gate_logits = self.gate(x)
    weights , selected_experts = torch.topk(gate_logits , k = self.k)
    weights = F.softmax(weights , dim=-1)
    out = torch.zeros_like(x)
    for i , current_expert in enumerate(self.experts):
      batch_idx , seq_idx , k_idx = torch.where(selected_experts == i)
      token_x = x[batch_idx, seq_idx]
      token_w = weights[batch_idx, seq_idx, k_idx].unsqueeze(-1)
      expert_out = current_expert(token_x)
      out[batch_idx, seq_idx] += token_w * expert_out
    return out


In [4]:
class MultiHeadAttention(nn.Module):
  def __init__(self, d_model , n_heads, max_seq_len , bias = False) -> None:
    super().__init__()
    assert d_model % n_heads == 0 , "d_model not divisible by n_heads"
    self.d_model = d_model
    self.n_heads = n_heads
    self.bias = bias
    self.qw = nn.Linear(self.d_model , self.d_model , bias = self.bias)
    self.kw = nn.Linear(self.d_model , self.d_model , bias = self.bias)
    self.vw = nn.Linear(self.d_model , self.d_model , bias = self.bias)
    self.project = nn.Linear(self.d_model , self.d_model)
    self.rope = RotaryPositionalEncoding(d_model=d_model, max_seq_len=max_seq_len)

  # def cross_forward()

  def forward(self, q , k , v ,  mask = None):
    ## X dimension ==> (batch_size, seq_len, dim)

    ## (batch_dim, seq_len, d_out)
    q = self.qw(q)
    k = self.kw(k)
    v = self.vw(v)

    k = self.rope(k)
    v = self.rope(v)

    # (batch seq_len d_out) -> (batch n_heads seq_len d_out)
    q = rearrange(q , "b s (h d) -> b h s d", h = self.n_heads)
    k = rearrange(k , "b s (h d) -> b h s d", h = self.n_heads)
    v = rearrange(v , "b s (h d) -> b h s d", h = self.n_heads)

    attention_scores = (q @ k.transpose(-2 , -1)) / (k.shape[-1]**0.5)

    # if mask:
    #   masks = torch.triu(torch.ones(seq_len , seq_len) , diagonal=1)
    #   attention_scores = attention_scores.masked_fill(masks == 1 , float("-inf"))
    #   attention_weights = F.softmax(attention_scores/(k.shape[-1]**0.5) , dim=-1)
    # else:
    #   attention_weights = F.softmax(attention_scores/(k.shape[-1]**0.5) , dim=-1)


    if mask is not None:
      attention_scores = attention_scores.masked_fill(mask == 0 , -1e9)
    attention_weights = F.softmax(attention_scores , dim = -1)

    context_vector = attention_weights @ v

    context_vector = rearrange(context_vector , "b h s d -> b s (h d)")
    return self.project(context_vector)


In [5]:
def forward_step(i_n, grid_size, A, K, C):
    ratio = A * grid_size**(-K) + C
    i_n1 = ratio * i_n
    return i_n1

class SineKANLayer(torch.nn.Module):
    def __init__(self, input_dim, output_dim, device='cuda', grid_size=5, is_first=False, add_bias=True, norm_freq=True):
        super(SineKANLayer,self).__init__()
        self.grid_size = grid_size
        self.device = device
        self.is_first = is_first
        self.add_bias = add_bias
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.A, self.K, self.C = 0.9724108095811765, 0.9884401790754128, 0.999449553483052

        self.grid_norm_factor = (torch.arange(grid_size) + 1)
        self.grid_norm_factor = self.grid_norm_factor.reshape(1, 1, grid_size)

        if is_first:
            self.amplitudes = torch.nn.Parameter(torch.empty(output_dim, input_dim, 1).normal_(0, .4) / output_dim  / self.grid_norm_factor)
        else:
            self.amplitudes = torch.nn.Parameter(torch.empty(output_dim, input_dim, 1).uniform_(-1, 1) / output_dim  / self.grid_norm_factor)

        grid_phase = torch.arange(1, grid_size + 1).reshape(1, 1, 1, grid_size) / (grid_size + 1)
        self.input_phase = torch.linspace(0, math.pi, input_dim).reshape(1, 1, input_dim, 1).to(device)
        phase = grid_phase.to(device) + self.input_phase

        if norm_freq:
            self.freq = torch.nn.Parameter(torch.arange(1, grid_size + 1).float().reshape(1, 1, 1, grid_size) / (grid_size + 1)**(1 - is_first))
        else:
            self.freq = torch.nn.Parameter(torch.arange(1, grid_size + 1).float().reshape(1, 1, 1, grid_size))

        for i in range(1, self.grid_size):
            phase = forward_step(phase, i, self.A, self.K, self.C)
        # self.phase = torch.nn.Parameter(phase)
        self.register_buffer('phase', phase)

        if self.add_bias:
            self.bias  = torch.nn.Parameter(torch.ones(1, output_dim) / output_dim)

    def forward(self, x):
        x_shape = x.shape
        output_shape = x_shape[0:-1] + (self.output_dim,)
        x = torch.reshape(x, (-1, self.input_dim))
        x_reshaped = torch.reshape(x, (x.shape[0], 1, x.shape[1], 1))
        s = torch.sin(x_reshaped * self.freq + self.phase)
        y = torch.einsum('ijkl,jkl->ij', s, self.amplitudes)
        if self.add_bias:
            y += self.bias
        y = torch.reshape(y, output_shape)
        return y

In [6]:
## Expansion and Contraction Module

class FeedForward(nn.Module):
  def __init__(self, d_in) -> None:
    super().__init__()
    self.ff = nn.Sequential(
        nn.Linear(d_in , 4 * d_in),
        nn.GELU(),
        nn.Linear(4 * d_in , d_in)
    )

  def forward(self, x):
    return self.ff(x)

In [7]:
class Encoder(nn.Module):
  def __init__(self, d_model , n_heads , max_seq_len , dropout_ratio = 0.2 , bias = False) -> None:
    super().__init__()
    self.attention = MultiHeadAttention(d_model=d_model , n_heads=n_heads , max_seq_len=max_seq_len , bias = bias)
    self.ff = FeedForward(d_in=d_model)
    self.norm1 = nn.LayerNorm(d_model)
    self.norm2 = nn.LayerNorm(d_model)
    self.dropout = nn.Dropout(dropout_ratio)

  def forward(self , x , src_mask):
    attention_out = self.attention(x , x , x , src_mask)
    x = x + self.norm1(self.dropout(attention_out))
    ff_out = self.ff(x)
    x = x + self.norm2(self.dropout(ff_out))
    return x

In [8]:
class Decoder(nn.Module):
  def __init__(self, d_model , max_seq_len , n_heads , dropout_ratio = 0.2 , bias = False) -> None:
    super().__init__()
    self.attention = MultiHeadAttention(d_model=d_model , n_heads=n_heads , max_seq_len=max_seq_len , bias = bias)
    self.ff = FeedForward(d_in=d_model)
    self.norm1 = nn.LayerNorm(d_model)
    self.norm2 = nn.LayerNorm(d_model)
    self.norm3 = nn.LayerNorm(d_model)
    self.dropout = nn.Dropout(dropout_ratio)

  def forward(self, x , enc_output , src_mask, tgt_mask):
    attention_outputs = self.attention(x , x, x, tgt_mask)
    x = x + self.norm1(self.dropout(attention_outputs))
    cross_attention_outputs = self.attention(x , enc_output , enc_output ,  src_mask)
    x = x + self.norm2(self.dropout(cross_attention_outputs))
    ff_output = self.ff(x)
    x = x + self.norm3(self.dropout(ff_output))
    return x




In [9]:
class PositionalEncoding(nn.Module):
  def __init__(self , d_model , max_seq_len):
    super().__init__()
    self.pos = torch.arange(0 , max_seq_len)
    self.theta = 1 / ((10000 ** (torch.arange(0 , d_model , 2))) / d_model)

    pe = torch.zeros(max_seq_len , d_model)

    pe[...,0::2] = torch.sin(self.pos[: , None] / self.theta)
    pe[...,1::2] = torch.cos(self.pos[: , None] / self.theta)

    self.register_buffer("pe" , pe)

  def forward(self, x):
    b , s , d = x.shape
    return x + self.pe[:s , :]

In [10]:
def get_mask(src , tgt):
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
  src_mask = (src != 0).to(device)
  src_mask = src_mask[: , None , None , :]
  tgt_mask = (tgt != 0)[: , None , : , None].to(device)
  seq_len = tgt_mask.shape[-2]
  causal_mask = torch.tril(torch.ones(1 , seq_len , seq_len)).bool().to(device)
  final_tgt_mask = tgt_mask & causal_mask
  return src_mask , final_tgt_mask.to(device)


class Transformer(nn.Module):
  def __init__(self , src_vocab_size , d_model , tgt_vocab_size , max_seq_len , n_heads , dropout_ratio , bias , n_encoders , n_decoders, n_experts , k) -> None:
    super().__init__()
    self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    self.encod_embedding = nn.Embedding(src_vocab_size, d_model)
    self.decod_embedding = nn.Embedding(tgt_vocab_size, d_model)
    self.encoding = PositionalEncoding(d_model=d_model, max_seq_len=max_seq_len)
    self.encoder = nn.ModuleList([Encoder(d_model=d_model, n_heads=n_heads, max_seq_len=max_seq_len , dropout_ratio=dropout_ratio , bias = bias) for i in range(n_encoders)])
    self.decoder = nn.ModuleList([Decoder(d_model=d_model, n_heads=n_heads , max_seq_len=max_seq_len , dropout_ratio=dropout_ratio , bias = bias) for i in range(n_encoders)])
    self.linear = nn.Linear(d_model , tgt_vocab_size)
    self.dropout = nn.Dropout(dropout_ratio)
    self.ff = MoeLayer(d_model=d_model , n_experts=n_experts , k=k)
    self.ff2 = FeedForward(d_model)

  def forward(self , src, tgt):
    src_mask , tgt_mask = get_mask(src , tgt)
    src_mask = src_mask.to(self.device)
    tgt_mask = tgt_mask.to(self.device)
    encod_embed = self.dropout(self.encod_embedding(src))
    decod_embed = self.dropout(self.decod_embedding(tgt))

    enc_output = encod_embed
    for encoder in self.encoder:
      enc_output = encoder(enc_output , src_mask)

    dec_output = decod_embed
    for decoder in self.decoder:
      dec_output = decoder(dec_output , enc_output , src_mask , tgt_mask )
    # print("Decoder output " , dec_output)
    out = self.ff(dec_output)

    # out = self.ff(dec_output)
    return self.linear(out)


In [11]:
class RotaryPositionalEncoding(nn.Module):
  def __init__(self, d_model , max_seq_len) -> None:
    super().__init__()
    self.max_seq_len = max_seq_len
    self.d_model = d_model
    self.half_d  = d_model // 2
    self.pos = torch.arange(0 , max_seq_len)
    self.theta = 1 / (10000 ** ((2 * torch.arange(0 , self.half_d)) / self.d_model))
    self.angles = self.pos[: , None] * self.theta[None , :]
    sin = torch.sin(self.angles)
    cos = torch.cos(self.angles)
    self.register_buffer("sin" , sin)
    self.register_buffer("cos" , cos)

  def forward(self , x):
    assert x.shape[-2] <= self.max_seq_len , "seq len should be less than max_seq_len"
    seq_len = x.shape[-2]
    x1 = x[..., 0::2]
    x2 = x[..., 1::2]
    x_rot_1 = x1 * self.cos[:seq_len , :]  - x2 * self.sin[:seq_len , :]
    x_rot_2 = x1 * self.sin[:seq_len , :]  + x2 * self.cos[:seq_len , :]

    out = torch.zeros_like(x)
    out[..., 0::2] = x_rot_1
    out[..., 1::2] = x_rot_2
    return out


In [12]:
from typing import List, Union, Dict , OrderedDict
import warnings
import re
import time
from tqdm import tqdm
from torch.utils.data import Dataset , DataLoader
import pandas as pd

In [13]:
class SymbolicQEDTokenizer:
    def __init__(self, df=None, index_token_pool_size=100, special_symbols=None, unk_idx=1, to_replace=True):
        self.amps = df.amp.tolist() if df is not None else None
        self.sqamps = df.sqamp.tolist() if df is not None else None
        if index_token_pool_size < 50:
            warnings.warn(f"Index token pool size ({index_token_pool_size}) may be insufficient. Consider using at least 50-100 tokens for symbolic tasks.", UserWarning)
        self.index_pool = [f"INDEX_{i}" for i in range(index_token_pool_size)]
        self.particle_index_pool = [f"PINDEX_{i}" for i in range(index_token_pool_size)]
        self.special_symbols = special_symbols or ["<PAD>", "<UNK>", "<BOS>", "<EOS>", "<SEP>"]
        self.unk_idx = unk_idx
        self.to_replace = to_replace
        self.pattern_underscore_curly = re.compile(r'\b[\w]+(?:_[\w]+)*_{')
        self.pattern_mass = re.compile(r'\bm_([a-z]+)\b')
        self.pattern_mandelstam = re.compile(r'\bs_(\d{2,})\b')
        self.pattern_momentum = re.compile(r'\bp_(\d+)\b')
        self.pattern_single_s = re.compile(r'\bs_(\d+)\b(?!\d)')
        self.pattern_exponent = re.compile(r'\^(\w+|\([^)]+\))')
        self.pattern_special = re.compile(r'_([uv])|\\(\w+_\d+|\w+\b)')
        self.pattern_num_123 = re.compile(r'\b(?![psijkl]_)(?!MOMENTUM_)(?!MASS_)(?!P_)(?!S_)(?!MANDELSTAM_)\w+_\d+\b')
        self.pattern_particle = re.compile(r'(?P<prefix>\b(?:\w+_)?)?(?P<target>[ijkl]_\d+\b)')

    def preprocess_expression(self, expr):
        expr = expr.replace(' * ', '*').replace(' / ', '/').replace(' ^ ', '^')
        expr = expr.replace(' + ', '+').replace(' - ', '-')
        expr = expr.replace("+-", "-")
        expr = expr.replace("-+", "-")
        expr = ' '.join(expr.split())
        expr = expr.replace('me', 'm_e')
        return expr

    @staticmethod
    def remove_whitespace(expression: str):
        return re.sub(r'\s+', '', expression)

    def protect_structures(self, ampl: str):
        protected = []
        return ampl, protected

    def physics_aware_replace(self, ampl: str, is_source: bool = True):
        ampl = self.remove_whitespace(ampl)
        ampl = re.sub(r'\bi\b(?!\w)', 'I_UNIT', ampl)
        ampl = re.sub(r'\be\b(?=\^|[+\-*/()| ])', 'E_CHARGE', ampl)
        ampl = ampl.replace('reg_prop', 'REG_PROP')
        ampl = self.pattern_mandelstam.sub(r'MANDELSTAM_\1', ampl)
        ampl = self.pattern_momentum.sub(r'P_\1', ampl)
        ampl = self.pattern_single_s.sub(r'S_\1', ampl)
        ampl = ampl.replace('(*)', 'CONJ')
        return ampl

    def replace_indices(self, ampl: str, is_source: bool = True):
        if not self.to_replace:
            return ampl
        index_pool = iter(self.index_pool)
        particle_index_pool = iter(self.particle_index_pool)
        index_pool_set = set(self.index_pool) if is_source else set()
        ampl = self.pattern_mandelstam.sub(lambda m: f'MANDELSTAM_{m.group(1)}', ampl)

        def get_unique_matches(pattern):
            matches = list(OrderedDict.fromkeys(pattern.findall(ampl)))
            return [m for m in matches if m not in index_pool_set]

        def replace_particle_tokens():
            nonlocal ampl
            matches = list(OrderedDict.fromkeys(m.group('target') for m in sorted(self.pattern_particle.finditer(ampl), key=lambda m: m.start())))
            try:
                mapping = {m: next(particle_index_pool) for m in matches}
            except StopIteration:
                raise RuntimeError("particle_index_pool exhausted. Increase the size of the particle_index_pool.")
            for key in sorted(mapping.keys(), key=len, reverse=True):
                ampl = ampl.replace(key, mapping[key])

        matches = get_unique_matches(self.pattern_num_123)
        try:
            for match in matches:
                ampl = ampl.replace(match, next(index_pool))
        except StopIteration:
            raise RuntimeError("index_pool exhausted. Increase pool size.")
        replace_particle_tokens()
        return ampl

    def tokenize_expression(self, ampl: str, protected: List[str], is_source: bool = True):
        ampl = ampl.replace('\\\\', '\\')
        def replace_special(match):
            if match.group(1):
                return f' _ {match.group(1)} '
            elif match.group(2):
                return f' \\ {match.group(2)} '
        ampl = self.pattern_special.sub(replace_special, ampl)
        if is_source:
            ampl = self.pattern_underscore_curly.sub(lambda match: f' {match.group(0)} ', ampl)
            for symbol in ['{', '}', ',']:
                ampl = ampl.replace(symbol, f' {symbol} ')
        for symbol in ['/', '+', '-', '*', '(', ')', '^']:
            ampl = ampl.replace(symbol, f' {symbol} ')
        ampl = self.pattern_exponent.sub(r' ^ \1 ', ampl)
        ampl = ampl.replace('_PINDEX', '_ PINDEX').replace('_INDEX', '_ INDEX')
        ampl = ampl.replace('REG_PROP', ' reg_prop ')
        ampl = re.sub(r' +', ' ', ampl).strip()
        tokens = [token for token in ampl.split(' ') if token]
        final_tokens = []
        for token in tokens:
            if token.startswith('PROTECTED_'):
                try:
                    idx = int(token.split('_')[1])
                    final_tokens.append(protected[idx])
                except (IndexError, ValueError):
                    final_tokens.append(token)
            else:
                final_tokens.append(token)
        return final_tokens

    def src_tokenize(self, ampl: str):
        try:
            ampl = self.preprocess_expression(ampl)
            ampl, protected = self.protect_structures(ampl)
            ampl = self.physics_aware_replace(ampl, is_source=True)
            ampl = self.replace_indices(ampl, is_source=True)
            return self.tokenize_expression(ampl, protected, is_source=True)
        except Exception as e:
            warnings.warn(f"Source tokenization failed for '{ampl}': {e}")
            return [self.special_symbols[self.unk_idx]]

    def tgt_tokenize(self, sqampl: str):
        try:
            sqampl = self.preprocess_expression(sqampl)
            sqampl, protected = self.protect_structures(sqampl)
            sqampl = self.physics_aware_replace(sqampl, is_source=False)
            sqampl = self.replace_indices(sqampl, is_source=False)
            return self.tokenize_expression(sqampl, protected, is_source=False)
        except Exception as e:
            warnings.warn(f"Target tokenization failed for '{sqampl}': {e}")
            return [self.special_symbols[self.unk_idx]]

    def build_src_vocab(self):
        if self.amps is None:
            return set()
        vocab_set = set()
        start_time = time.time()
        for expr in tqdm(self.amps, desc="Processing source vocab"):
            vocab_set.update(self.src_tokenize(expr))
        end_time = time.time()
        print(f"Source vocab built in {end_time - start_time:.2f} seconds, size: {len(vocab_set)}")
        return vocab_set

    def build_tgt_vocab(self):
        if self.sqamps is None:
            return set()
        vocab_set = set()
        start_time = time.time()
        for expr in tqdm(self.sqamps, desc="Processing target vocab"):
            vocab_set.update(self.tgt_tokenize(expr))
        end_time = time.time()
        print(f"Target vocab built in {end_time - start_time:.2f} seconds, size: {len(vocab_set)}")
        return vocab_set

class SymbolicVocab:
    def __init__(self, tokens: set, special_symbols: list, bos_idx: int, pad_idx: int, eos_idx: int, unk_idx: int, sep_idx: int):
        self.token_list = special_symbols + sorted(list(tokens))
        self.token_to_idx = {token: idx for idx, token in enumerate(self.token_list)}
        self.idx_to_token = {idx: token for token, idx in self.token_to_idx.items()}
        self.unk_idx = unk_idx
        self.pad_idx = pad_idx
        self.bos_idx = bos_idx
        self.eos_idx = eos_idx
        self.sep_idx = sep_idx
        self.unk_tok = special_symbols[unk_idx]
        self.pad_tok = special_symbols[pad_idx]
        self.bos_tok = special_symbols[bos_idx]
        self.eos_tok = special_symbols[eos_idx]
        self.sep_tok = special_symbols[sep_idx]

    def encode(self, tokens: list):
        return [self.token_to_idx.get(token, self.unk_idx) for token in tokens]

    def decode(self, indices: list, include_special_tokens: bool = True):
        if include_special_tokens:
            return [self.idx_to_token.get(idx, self.unk_tok) for idx in indices]
        return [self.idx_to_token.get(idx, self.unk_tok) for idx in indices if idx not in {self.pad_idx, self.bos_idx, self.eos_idx, self.sep_idx}]

    def __len__(self):
        return len(self.token_list)

    def __getitem__(self, item):
        if isinstance(item, int):
            return self.idx_to_token.get(item, self.unk_tok)
        return self.token_to_idx.get(item, self.unk_idx)

class QEDDataset(Dataset):
    def __init__(self, data, tokenizer, max_length):
        self.data = data
        self.tokenizer = tokenizer
        self.max_length = max_length
        start_time = time.time()
        self.src_vocab = SymbolicVocab(tokens=tokenizer.build_src_vocab(), special_symbols=tokenizer.special_symbols, bos_idx=2, pad_idx=0, eos_idx=3, unk_idx=1, sep_idx=4)
        self.tgt_vocab = SymbolicVocab(tokens=tokenizer.build_tgt_vocab(), special_symbols=tokenizer.special_symbols, bos_idx=2, pad_idx=0, eos_idx=3, unk_idx=1, sep_idx=4)
        end_time = time.time()
        print(f"Dataset initialized in {end_time - start_time:.2f} seconds, src_vocab_size: {len(self.src_vocab)}, tgt_vocab_size: {len(self.tgt_vocab)}")
        if len(self.src_vocab) == 5 or len(self.tgt_vocab) == 5:
            warnings.warn("Vocabulary size is minimal (only special tokens). Check dataset or tokenization.")

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        src = str(self.data.iloc[idx]["amp"])
        trg = str(self.data.iloc[idx]["sqamp"])
        src_tokens = self.tokenizer.src_tokenize(src)
        trg_tokens = self.tokenizer.tgt_tokenize(trg)
        src_ids = self.src_vocab.encode(src_tokens)
        trg_ids = self.tgt_vocab.encode(trg_tokens)
        src_ids = src_ids[:self.max_length] + [self.src_vocab.pad_idx] * (self.max_length - len(src_ids))
        trg_ids = trg_ids[:self.max_length] + [self.tgt_vocab.pad_idx] * (self.max_length - len(trg_ids))
        return {"input_ids": torch.tensor(src_ids, dtype=torch.long), "labels": torch.tensor(trg_ids, dtype=torch.long)}


In [14]:
data_df = pd.read_csv(r'/content/train_data.csv')

start_time = time.time()
tokenizer = SymbolicQEDTokenizer(df=data_df, index_token_pool_size=100, special_symbols=["<PAD>", "<UNK>", "<BOS>", "<EOS>", "<SEP>"], to_replace=True)
src_vocab_size = len(tokenizer.build_src_vocab()) + 5
tgt_vocab_size = len(tokenizer.build_tgt_vocab()) + 5

start_time = time.time()
dataset = QEDDataset(data_df, tokenizer, 300)


Processing source vocab: 100%|██████████| 9952/9952 [00:02<00:00, 3450.59it/s]


Source vocab built in 2.89 seconds, size: 78


Processing target vocab: 100%|██████████| 9952/9952 [00:01<00:00, 5357.93it/s]


Target vocab built in 1.86 seconds, size: 45


Processing source vocab: 100%|██████████| 9952/9952 [00:02<00:00, 3512.68it/s]


Source vocab built in 2.84 seconds, size: 78


Processing target vocab: 100%|██████████| 9952/9952 [00:01<00:00, 5397.09it/s]

Target vocab built in 1.85 seconds, size: 45
Dataset initialized in 4.68 seconds, src_vocab_size: 83, tgt_vocab_size: 50





In [15]:
train_loader = DataLoader(dataset , batch_size = 64)

In [16]:
for batch in train_loader:
  print(batch['input_ids'].shape , batch['labels'].shape)
  print(batch['input_ids'][0] , batch['labels'][0])
  break

torch.Size([64, 300]) torch.Size([64, 300])
tensor([12, 11, 19,  7, 37,  7, 22, 53, 14,  7, 59, 81,  8, 52, 23,  9, 24,  9,
        29, 82,  7, 59, 81, 52, 23,  9, 30,  9, 31, 82,  7, 55, 81, 44,  9, 31,
        82,  5, 50,  6, 54, 80,  7, 55, 81, 45,  9, 30, 82,  5, 51,  6, 54, 78,
        53, 21,  7, 56, 81, 46,  9, 29, 82,  5, 48,  6, 54, 78,  7, 56, 81, 47,
         9, 24, 82,  5, 49,  6, 54, 80, 53, 21, 11,  5, 62, 53, 14,  8, 39,  8,
        12, 11, 14,  7, 74,  6,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
         0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
         0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
         0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
         0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
         0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
         0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0, 

In [17]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device

device(type='cuda')

In [18]:
import torch
import torch.nn as nn
import torch.optim as optim

model = Transformer(
    src_vocab_size=src_vocab_size,
    d_model=512,
    tgt_vocab_size=tgt_vocab_size,
    max_seq_len=300,
    n_heads=4,
    dropout_ratio=0.1,
    bias=False,
    n_encoders=4,
    n_decoders=4,
    n_experts=8,
    k=3
)

model = model.to(device)

epochs = 50
optimizer = optim.Adam(model.parameters() , lr=3e-6)
criterion = nn.CrossEntropyLoss()

In [19]:
torch.cuda.empty_cache()

In [20]:
!nvidia-smi

Thu Sep 25 15:31:07 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.54.15              Driver Version: 550.54.15      CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  Tesla T4                       Off |   00000000:00:04.0 Off |                    0 |
| N/A   66C    P0             31W /   70W |     298MiB /  15360MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
                                                

In [21]:
import torch
from einops import rearrange

def token_accuracy(output, target):
    """
    output: (batch, seq_len, vocab_size)
    target: (batch, seq_len)
    """
    preds = torch.argmax(output, dim=-1)       # (batch, seq_len)
    acc = (preds == target).float().mean().item()
    return acc

def sequence_accuracy(output, target):
    """
    output: (batch, seq_len, vocab_size)
    target: (batch, seq_len)
    """
    preds = torch.argmax(output, dim=-1)       # (batch, seq_len)
    correct_sequences = (preds == target).all(dim=1)  # True if all tokens match
    seq_acc = correct_sequences.float().mean().item()
    return seq_acc

# # Device
# device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# model = model.to(device)

# Training

epochs = 10
start_time = time.time()
for epoch in range(epochs):
    model.train()
    total_loss = 0.0
    total_token_acc = 0.0
    total_seq_acc = 0.0

    for batch in train_loader:
        src = batch['input_ids'].to(device)
        target = batch['labels'].to(device)

        optimizer.zero_grad()
        output = model(src, target)  # (batch, seq_len, vocab_size)

        # Flatten for loss
        output_flat = rearrange(output, 'b s c -> (b s) c')
        target_flat = rearrange(target, 'b s -> (b s)')

        # Compute loss
        loss = criterion(output_flat, target_flat)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

        # Compute accuracies
        total_token_acc += token_accuracy(output, target)
        total_seq_acc += sequence_accuracy(output, target)

    avg_loss = total_loss / len(train_loader)
    avg_token_acc = total_token_acc / len(train_loader)
    avg_seq_acc = total_seq_acc / len(train_loader)

    print(f"Epoch {epoch+1}/{epochs} | Loss: {avg_loss:.4f} | "
          f"Token Acc: {avg_token_acc:.4f} | Seq Acc: {avg_seq_acc:.4f}")
end_time = time.time()

print(end_time - start_time)


Epoch 1/10 | Loss: 1.5058 | Token Acc: 0.7788 | Seq Acc: 0.0000
Epoch 2/10 | Loss: 0.2677 | Token Acc: 0.9486 | Seq Acc: 0.0069
Epoch 3/10 | Loss: 0.1068 | Token Acc: 0.9855 | Seq Acc: 0.1518
Epoch 4/10 | Loss: 0.0488 | Token Acc: 0.9961 | Seq Acc: 0.4174
Epoch 5/10 | Loss: 0.0265 | Token Acc: 0.9979 | Seq Acc: 0.6086
Epoch 6/10 | Loss: 0.0172 | Token Acc: 0.9982 | Seq Acc: 0.6783
Epoch 7/10 | Loss: 0.0122 | Token Acc: 0.9989 | Seq Acc: 0.7335
Epoch 8/10 | Loss: 0.0091 | Token Acc: 0.9991 | Seq Acc: 0.7960
Epoch 9/10 | Loss: 0.0069 | Token Acc: 0.9994 | Seq Acc: 0.8736
Epoch 10/10 | Loss: 0.0053 | Token Acc: 0.9997 | Seq Acc: 0.9247
2611.1186583042145
