# Common Task 1.2 Dataset preprocessing
The initial task focused on data processing, particularly tokenization, with separate strategies for the source sequences (amplitudes) and the target sequences (squared amplitudes). **For source sequences**, amplitude preprocessing replaces indexed and momentum variables with tokens drawn from predefined pools (`INDEX_n` and `MOMENTUM_n`), which keeps the naming consistent and limits the vocabulary size. Operators and symbols are then isolated and tokenized sequentially, and the index and momentum tokens are renumbered in order of first appearance within each expression; this improves compatibility with vocabulary-based modeling while preserving the position of every index within the equation. **For target sequences**, squared amplitude tokenization is more straightforward: symbols, operators, regular propagators, and masses are identified and isolated while the mathematical relationships and a consistent representation are maintained.

## Source (Amplitude) and Target (Squared Amplitude) Sequences Example  
#### Before and After Processing and Tokenization  

### **Raw Source Sequence**  
$$
-\frac{1}{18} i e^2 \gamma_{\rho_{52180}, \eta_{21888}, \epsilon_{14253}} 
\gamma_{+\rho_{52180}, \gamma_{17939}, \delta_{15647}} 
d_{j_{18804}, \gamma_{17939}}(p_4)_u^{(*)} 
d_{i_{18804}, \delta_{15647}}(p_3)_v 
s_{k_{18802}, \epsilon_{14253}}(p_1)_u 
s_{l_{18802}, \eta_{21888}}(p_2)_v^{(*)} 
\Big/ \left(m_s^2 + s_{12} + \frac{1}{2} \text{reg\_prop} \right)
$$


### **Replaced Indices**  
$$
-\frac{1}{18} i e^2 \gamma_{\text{INDEX}_0, \text{INDEX}_1, \text{INDEX}_2} 
\gamma_{+\text{INDEX}_0, \text{INDEX}_3, \text{INDEX}_4} 
d_{\text{MOMENTUM}_0, \text{INDEX}_3}(p_4)_u^{(*)} 
d_{\text{MOMENTUM}_1, \text{INDEX}_4}(p_3)_v 
s_{\text{MOMENTUM}_2, \text{INDEX}_2}(p_1)_u 
s_{\text{MOMENTUM}_3, \text{INDEX}_1}(p_2)_v^{(*)} 
\Big/ \left(m_s^2 + s_{12} + \frac{1}{2} \text{reg\_prop} \right)
$$


### **Source Tokens**  
```text
['-', '1/18', '*', 'i', '*', 'e', '^', '2', '*', 'gamma_{', '\\', 'INDEX_0', ',', 
'INDEX_1', ',', 'INDEX_2', '}', '*', 'gamma_{', '+', '\\', 'INDEX_0', ',', 'INDEX_3', 
',', 'INDEX_4', '}', '*', 'd_{', 'MOMENTUM_0', ',', 'INDEX_3', '}', '(', 'p_4', ')', 
'_u', '^', '(', '*', ')', '*', 'd_{', 'MOMENTUM_1', ',', 'INDEX_4', '}', '(', 'p_3', 
')', '_v', '*', 's_{', 'MOMENTUM_2', ',', 'INDEX_2', '}', '(', 'p_1', ')', '_u', '*', 
's_{', 'MOMENTUM_3', ',', 'INDEX_1', '}', '(', 'p_2', ')', '_v', '^', '(', '*', ')', 
'/', '(', 'm_s', '^', '2', '+', 's_12', '+', '1/2', '*', 'reg_prop', ')']
```

---

### **Raw Target Sequence**  
$$
\frac{1}{324} e^4 \left( 16 m_d^2 m_s^2 + 8 m_d^2 s_{12} + 8 s_{14} s_{23} + 8 s_{13} s_{24} + 
8 m_s^2 s_{34} \right) \left( m_s^2 + s_{12} + \frac{1}{2} \text{reg\_prop} \right)^{-2}
$$


### **Target Tokens**  
```text
['1/324', '*', 'e', '^', '4', '*', '(', '16', '*', 'm_d', '^', '2', '*', 'm_s', '^', 
'2', '+', '8', '*', 'm_d', '^', '2', '*', 's_12', '+', '8', '*', 's_14', '*', 's_23', 
'+', '8', '*', 's_13', '*', 's_24', '+', '8', '*', 'm_s', '^', '2', '*', 's_34', ')', 
'*', '(', 'm_s', '^', '2', '+', 's_12', '+', '1/2', '*', 'reg_prop', ')', '^', '(', 
'-', '2', ')']
```

---

In [None]:
import os
import pandas as pd
import numpy as np
from tqdm import tqdm
import json

from sklearn.model_selection import train_test_split
from collections import Counter, OrderedDict
from itertools import cycle
import re
import random
from torchtext.vocab import vocab
from tqdm import tqdm
import warnings

In [None]:
class Tokenizer:
    """
    Tokenizer for processing symbolic mathematical expressions.

    Source expressions (amplitudes) can optionally have their indexed and
    momentum variables replaced by tokens drawn from fixed pools
    (``INDEX_n`` / ``MOMENTUM_n``) before being split into tokens. Target
    expressions (squared amplitudes) are split on operators and on
    underscore-joined symbols.

    Args:
        df: Object exposing ``amp`` and ``sqamp`` attributes that provide ``tolist()``.
        index_token_pool_size (int): Number of ``INDEX_n`` tokens in the pool.
        momentum_token_pool_size (int): Number of ``MOMENTUM_n`` tokens in the pool.
        special_symbols: Special tokens passed as ``specials`` when building a vocabulary.
        UNK_IDX: Default index set on the vocabularies built by this class.
        to_replace (bool): Whether ``src_tokenize`` runs ``src_replace`` first.
    """
    def __init__(self, df, index_token_pool_size, momentum_token_pool_size, special_symbols, UNK_IDX, to_replace):
        self.amps = df.amp.tolist()
        self.sqamps = df.sqamp.tolist()

        # Issue warnings if token pool sizes are too small
        if index_token_pool_size < 100:
            warnings.warn(f"Index token pool size ({index_token_pool_size}) is small. Consider increasing it.", UserWarning)
        if momentum_token_pool_size < 100:
            warnings.warn(f"Momentum token pool size ({momentum_token_pool_size}) is small. Consider increasing it.", UserWarning)
        
        # Generate token pools
        self.tokens_pool = [f"INDEX_{i}" for i in range(index_token_pool_size)]
        self.momentum_pool = [f"MOMENTUM_{i}" for i in range(momentum_token_pool_size)]
        
        # Regular expression patterns for token replacement
        # i/j/k/l followed by "_<digits>", e.g. "i_18804"
        self.pattern_momentum = re.compile(r'\b[ijkl]_\d{1,}\b')
        # any "<word>_<digits>" name that does not start with "p_" or "s_"
        self.pattern_num_123 = re.compile(r'\b(?![ps]_)\w+_\d{1,}\b')
        self.pattern_special = re.compile(r'\b\w+_+\w+\b\\')
        # a name immediately followed by "_{", e.g. "gamma_{"
        self.pattern_underscore_curly = re.compile(r'\b\w+_{')
        self.pattern_prop = re.compile(r'Prop')
        self.pattern_int = re.compile(r'int\{')
        # single-character symbols that are padded with spaces during tokenization
        self.pattern_operators = {
            '+': re.compile(r'\+'), '-': re.compile(r'-'), '*': re.compile(r'\*'),
            ',': re.compile(r','), '^': re.compile(r'\^'), '%': re.compile(r'%'),
            '}': re.compile(r'\}'), '(': re.compile(r'\('), ')': re.compile(r'\)')
        }
        # "<word>_<single word char>", e.g. "m_s"
        self.pattern_mass = re.compile(r'\b\w+_\w\b')
        # "<word>_<two or more digits>", e.g. "s_12"
        self.pattern_s = re.compile(r'\b\w+_\d{2,}\b')
        # "<word>_<single digit>"
        self.pattern_reg_prop = re.compile(r'\b\w+_\d{1}\b')
        self.pattern_antipart = re.compile(r'(\w)_\w+_\d+\(X\)\^\(\*\)')
        self.pattern_part = re.compile(r'(\w)_\w+_\d+\(X\)')
        # "<word>_<word>_<two or more digits>"
        self.pattern_index = re.compile(r'\b\w+_\w+_\d{2,}\b')
        
        self.special_symbols = special_symbols
        self.UNK_IDX = UNK_IDX
        self.to_replace = to_replace

    @staticmethod
    def remove_whitespace(expression):
        """Remove all forms of whitespace from the expression."""
        return re.sub(r'\s+', '', expression)

    @staticmethod
    def split_expression(expression):
        """Split the expression by space delimiter."""
        return re.split(r' ', expression)

    @staticmethod
    def _substitute(pattern, mapping, text):
        """Replace whole matches of ``pattern`` that are keys of ``mapping``; keep the rest."""
        if not mapping:
            return text
        return pattern.sub(lambda match: mapping.get(match.group(0), match.group(0)), text)

    def _vocab_from_counter(self, counter):
        """Turn token counts into a vocabulary with the special symbols first."""
        voc = vocab(OrderedDict(counter), specials=self.special_symbols[:], special_first=True)
        voc.set_default_index(self.UNK_IDX)
        return voc

    def build_tgt_vocab(self):
        """Build vocabulary for target sequences."""
        counter = Counter()
        for eqn in tqdm(self.sqamps, desc='Processing target vocab'):
            counter.update(self.tgt_tokenize(eqn))
        return self._vocab_from_counter(counter)

    def build_src_vocab(self, seed):
        """Build vocabulary for source sequences, tokenizing each with ``seed``."""
        counter = Counter()
        for diag in tqdm(self.amps, desc='Processing source vocab'):
            counter.update(self.src_tokenize(diag, seed))
        return self._vocab_from_counter(counter)
    
    def src_replace(self, ampl, seed):
        """
        Replace indexed and momentum variables with tokenized equivalents.

        Whitespace is stripped first. Names are then mapped, in order of first
        appearance, to tokens taken from pools shuffled with ``seed``:

        1. names matching ``pattern_momentum`` become ``MOMENTUM_n`` tokens;
        2. remaining names matching ``pattern_num_123`` become ``INDEX_n`` tokens;
        3. remaining names matching ``pattern_index`` keep their prefix and get
           their trailing number replaced by `` INDEX_n`` (space-separated).

        Only whole matches are rewritten, so a short name (``gam_1``) never
        alters a longer one that merely starts with it (``gam_1000``).

        Args:
            ampl (str): Source expression.
            seed: Seed for the pool shuffles; the same seed and expression
                always give the same result.

        Returns:
            str: The expression with names replaced by pool tokens.
        """
        ampl = self.remove_whitespace(ampl)

        # A private generator produces the same shuffles as seeding the global
        # one would, without resetting random state that callers depend on.
        rng = random.Random(seed)
        token_cycle = cycle(rng.sample(self.tokens_pool, len(self.tokens_pool)))
        momentum_cycle = cycle(rng.sample(self.momentum_pool, len(self.momentum_pool)))

        # dict.fromkeys de-duplicates while keeping first-occurrence order, so
        # the assignment does not depend on string hash randomization.
        momentum_mapping = {
            name: next(momentum_cycle)
            for name in dict.fromkeys(self.pattern_momentum.findall(ampl))
        }
        temp_ampl = self._substitute(self.pattern_momentum, momentum_mapping, ampl)

        # Momentum names also match the index pattern; they are already
        # replaced, so do not spend index tokens on them.
        index_mapping = {
            name: next(token_cycle)
            for name in dict.fromkeys(self.pattern_num_123.findall(ampl))
            if name not in momentum_mapping
        }
        temp_ampl = self._substitute(self.pattern_num_123, index_mapping, temp_ampl)

        # What is left for this pattern are names the index pattern skipped
        # (those starting with "p_" or "s_").
        compound_mapping = {
            name: f"{'_'.join(name.split('_')[:-1])} {next(token_cycle)}"
            for name in dict.fromkeys(self.pattern_index.findall(ampl))
            if name not in index_mapping
        }
        temp_ampl = self._substitute(self.pattern_index, compound_mapping, temp_ampl)

        # The cycles wrap around silently, which would give distinct names the
        # same token; make that visible.
        if len(momentum_mapping) > len(self.momentum_pool):
            warnings.warn(
                f"Expression uses {len(momentum_mapping)} momentum names but the pool holds "
                f"{len(self.momentum_pool)}; tokens were reused.", UserWarning)
        if len(index_mapping) + len(compound_mapping) > len(self.tokens_pool):
            warnings.warn(
                f"Expression uses {len(index_mapping) + len(compound_mapping)} index names but the pool holds "
                f"{len(self.tokens_pool)}; tokens were reused.", UserWarning)

        return temp_ampl
    
    def src_tokenize(self, ampl, seed):
        """
        Tokenize source expression, optionally applying replacements.

        ``seed`` is only used when ``self.to_replace`` is true. Backslashes
        become separate tokens, ``%`` characters are dropped, ``name_{``
        prefixes and the operator symbols are isolated, and the result is
        split on single spaces.

        Returns:
            list of str: Non-empty tokens in order of appearance.
        """
        temp_ampl = self.src_replace(ampl, seed) if self.to_replace else ampl
        temp_ampl = temp_ampl.replace('\\\\', '\\').replace('\\', ' \\ ').replace('%', '')

        temp_ampl = self.pattern_underscore_curly.sub(lambda match: f' {match.group(0)} ', temp_ampl)

        for symbol, pattern in self.pattern_operators.items():
            temp_ampl = pattern.sub(f' {symbol} ', temp_ampl)
        
        temp_ampl = re.sub(r' {2,}', ' ', temp_ampl)
        return [token for token in self.split_expression(temp_ampl) if token]

    def tgt_tokenize(self, sqampl):
        """
        Tokenize target expression.

        Whitespace is removed, operator symbols are isolated, and names
        matching ``pattern_reg_prop``, ``pattern_mass`` and ``pattern_s`` are
        padded with spaces before splitting.

        Returns:
            list of str: Non-empty tokens in order of appearance.
        """
        sqampl = self.remove_whitespace(sqampl)
        temp_sqampl = sqampl
        
        for symbol, pattern in self.pattern_operators.items():
            temp_sqampl = pattern.sub(f' {symbol} ', temp_sqampl)
        
        for pattern in [self.pattern_reg_prop, self.pattern_mass, self.pattern_s]:
            temp_sqampl = pattern.sub(lambda match: f' {match.group(0)} ', temp_sqampl)
        
        temp_sqampl = re.sub(r' {2,}', ' ', temp_sqampl)
        return [token for token in self.split_expression(temp_sqampl) if token]

In [None]:
def read_file_to_list(file_path):
    """Return every line of ``file_path`` with leading/trailing whitespace stripped."""
    with open(file_path, 'r') as handle:
        # str.strip() drops the trailing newline along with any other surrounding whitespace
        return [raw_line.strip() for raw_line in handle]

In [None]:
def normalize_indices(tokenizer, expressions, index_token_pool_size=50, momentum_token_pool_size=50):
    """
    Renumber index and momentum tokens in a list of mathematical expressions.

    Each expression is tokenized with ``tokenizer.src_tokenize(expr, 42)``.
    Within one expression, every distinct token containing ``"INDEX_"`` is
    renamed to ``INDEX_0``, ``INDEX_1``, ... in order of first appearance, and
    afterwards every distinct token containing ``"MOMENTUM_"`` is renamed to
    ``MOMENTUM_0``, ``MOMENTUM_1``, ... the same way. Numbering restarts for
    every expression.

    Args:
        tokenizer: Object providing ``src_tokenize(expression, seed)``.
        expressions (list of str): Expressions to normalize.
        index_token_pool_size (int, optional): Maximum number of distinct index tokens per expression. Defaults to 50.
        momentum_token_pool_size (int, optional): Maximum number of distinct momentum tokens per expression. Defaults to 50.

    Returns:
        list of list of str: One renumbered token list per input expression.

    Raises:
        ValueError: If an expression holds more distinct index or momentum
            tokens than the corresponding pool size allows.
    """

    def renumber(tokens, marker, pool_size, exhausted_message):
        """Rename tokens containing ``marker`` to ``marker + n`` by first appearance."""
        fresh_names = (f"{marker}{n}" for n in range(pool_size))
        assigned = {}
        renamed = []

        for tok in tokens:
            if marker not in tok:
                renamed.append(tok)
                continue
            if tok not in assigned:
                try:
                    assigned[tok] = next(fresh_names)
                except StopIteration:
                    raise ValueError(exhausted_message)
            renamed.append(assigned[tok])

        return renamed

    results = []

    for expr in tqdm(expressions, desc="Normalizing..."):
        pieces = tokenizer.src_tokenize(expr, 42)
        # Indices are renumbered first, momenta second
        pieces = renumber(
            pieces, "INDEX_", index_token_pool_size,
            "Ran out of unique indices, increase index_token_pool_size",
        )
        pieces = renumber(
            pieces, "MOMENTUM_", momentum_token_pool_size,
            "Ran out of unique momenta, increase momentum_token_pool_size",
        )
        results.append(pieces)

    return results

In [None]:
def aug_data(df):
    """
    Augment the amplitude column and renumber its index/momentum tokens.

    Every ``amp`` entry is passed through ``tokenizer.src_replace`` with a
    randomly drawn seed (one sample per entry), the results are renumbered by
    ``normalize_indices`` and joined back into strings. Each ``sqamp`` value is
    repeated once per generated sample so the two columns stay aligned.

    Relies on the module-level ``tokenizer``, ``INDEX_POOL_SIZE`` and
    ``MOMENTUM_POOL_SIZE``.

    Args:
        df (pd.DataFrame): Input DataFrame containing 'amp' and 'sqamp' columns.

    Returns:
        pd.DataFrame: New DataFrame with the processed 'amp' and matching 'sqamp' values.
    """
    amps = df["amp"]
    sqamps = df["sqamp"]

    # How many augmented variants are produced for each original row
    copies_per_row = 1

    replaced_amps = []
    for amp in tqdm(amps, desc="Processing amplitudes"):
        # Draw all seeds for this row before running any replacement
        seeds = [random.randint(1, 1000) for _ in range(copies_per_row)]
        replaced_amps.extend(tokenizer.src_replace(amp, seed) for seed in seeds)

    # Repeat each target so it lines up with its augmented sources
    repeated_sqamps = []
    for sqamp in sqamps:
        repeated_sqamps.extend([sqamp] * copies_per_row)

    token_lists = normalize_indices(tokenizer, replaced_amps, INDEX_POOL_SIZE, MOMENTUM_POOL_SIZE)

    # Tokens are concatenated without a separator to rebuild each expression string
    return pd.DataFrame({
        "amp": ["".join(tokens) for tokens in token_lists],
        "sqamp": repeated_sqamps,
    })

In [None]:
# Directory whose files are all read as raw input
curr_dir = "/pscratch/sd/r/ritesh11/raw_data/SYMBA_test_data"
# One list of stripped lines per directory entry, in os.listdir() order
line_list = [
    read_file_to_list(os.path.join(curr_dir, entry))
    for entry in os.listdir(curr_dir)
]

In [None]:
import random


data_train = {'sqamp': [], 'process': [], 'amp': []}
data_valid = {'sqamp': [], 'process': [], 'amp': []}
data_test = {'sqamp': [], 'process': [], 'amp': []}

data_list = []

# Extract relevant data from each line
skipped_lines = 0
for lines in line_list:
    for c in lines:
        res = c.split(" : ")  # Split line into components
        if len(res) < 4:
            # Too few fields to hold all three values; count it instead of
            # dropping it silently
            skipped_lines += 1
            continue
        data_list.append((res[1], res[2], res[3]))  # Store as (process, amp, sqamp)

if skipped_lines:
    print(f"Skipped {skipped_lines} line(s) with fewer than four ' : '-separated fields")

# Shuffle data to ensure randomness in train/valid/test splits
random.shuffle(data_list)

# Compute split sizes
total = len(data_list)
train_size = int(0.80 * total)  # 80% for training
valid_size = int(0.10 * total)  # 10% for validation; the remainder is the test split

# Assign data to splits
splits = (
    (data_train, data_list[:train_size]),
    (data_valid, data_list[train_size:train_size + valid_size]),
    (data_test, data_list[train_size + valid_size:]),
)
for split, rows in splits:
    # zip(*[]) yields nothing to unpack, so an empty split keeps its empty columns
    if rows:
        split['process'], split['amp'], split['sqamp'] = zip(*rows)


In [None]:
# Special token symbols
SPECIAL_SYMBOLS = ['<S>', '<PAD>', '</S>', '<UNK>', '<SEP>']

# Special token indices, in order: beginning of sequence, padding,
# end of sequence, unknown token, separator token
BOS_IDX, PAD_IDX, EOS_IDX, UNK_IDX, SEP_IDX = range(5)

# Pool sizes used when renumbering index and momentum tokens
INDEX_POOL_SIZE = MOMENTUM_POOL_SIZE = 200

In [None]:
# One DataFrame per split; each frame is built from the dictionary of the same name
df_train = pd.DataFrame(data_train)
df_valid = pd.DataFrame(data_valid)
df_test = pd.DataFrame(data_test)

In [None]:
# Show the training DataFrame as this cell's output
df_train

In [None]:
# Arguments are passed explicitly instead of through IPython's `_` (last cell
# output), so the cell also works when run out of order or outside IPython.
# to_replace=False: src_tokenize only splits; replacement is invoked directly
# through src_replace.
tokenizer = Tokenizer(df_train, 500, 500, SPECIAL_SYMBOLS, UNK_IDX, False)

In [None]:
# Run the augmentation/normalization step on each split: train, then valid, then test
df_train, df_valid, df_test = (
    aug_data(frame) for frame in (df_train, df_valid, df_test)
)

In [None]:
# Show the (rows, columns) shape of the train, valid and test frames
tuple(frame.shape for frame in (df_train, df_valid, df_test))

In [None]:
# Write each split to its CSV file in the working directory, without the index column
for frame, csv_name in (
    (df_train, "SYMBA_testtrain.csv"),
    (df_test, "SYMBA_testtest.csv"),
    (df_valid, "SYMBA_testvalid.csv"),
):
    frame.to_csv(csv_name, index=False)