# Part 1: Preprocessing

In [1]:
# Import necessary libraries
import os
import re
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split
from collections import Counter


In [2]:
# Function to load and parse data from all files
def load_data(file_paths):
    data = []
    for file_path in file_paths:
        with open(file_path, 'r') as f:
            for line in f:
                parts = line.strip().split(' : ')
                if len(parts) == 4:
                    event_type, diagram, amplitude, squared_amplitude = parts
                    data.append({
                        'event_type': event_type,
                        'diagram': diagram,
                        'amplitude': amplitude.strip(),
                        'squared_amplitude': squared_amplitude.strip()
                    })
    return pd.DataFrame(data)


In [3]:
def normalize_indices(expr):
    # Find all patterns like %something_number
    pattern = r'(%[^_]+)_(\d+)'
    matches = re.findall(pattern, expr)
    
    # Get unique numeric indices for each variable type
    var_indices = {}
    for var_type, num_idx in matches:
        if var_type not in var_indices:
            var_indices[var_type] = set()
        var_indices[var_type].add(num_idx)
    
    # Create mapping from original indices to normalized ones for each variable type
    index_maps = {}
    for var_type, indices in var_indices.items():
        sorted_indices = sorted(indices, key=int)
        index_maps[var_type] = {orig_idx: str(i+1) for i, orig_idx in enumerate(sorted_indices)}
    
    # Replace indices according to the mapping
    def replace_match(m):
        var_type, num_idx = m.groups()
        return f"{var_type}_{index_maps[var_type][num_idx]}"
    
    normalized_expr = re.sub(pattern, replace_match, expr)
    
    return normalized_expr





In [25]:
# Tokenization function for mathematical expressions
def tokenize_expression(expr):
    # Define patterns for different token types
    patterns = [
        # Constants and numbers
        r'(\d+/\d+|\d+\.\d+|\d+)',
        # Variables with indices and special notations
        r'([a-zA-Z]+(?:_[a-zA-Z0-9]+)?(?:\^\([*]\))?)',
        # Mathematical operators and symbols
        r'([\+\-\*/\^\(\)\[\]\{\}])',
        # Special sym|bols and groupings
        r'(%[a-zA-Z]+_\d+|_{[^}]+})'
    ]
    
    # Combine patterns
    combined_pattern = '|'.join(patterns)
    tokens = re.findall(combined_pattern, expr)
    
    # Flatten and filter empty strings
    tokens = [t for sublist in tokens for t in sublist if t]
    return tokens


In [26]:
# Example of how the tokenization works
example_expr = "-1/2*i*e^2*gamma_{+%\sigma_165,%gam_145,%gam_146}*gamma_{%\sigma_165,%gam_147,%del_137}*e_{i_3,%gam_146}(p_1)_u*e_{k_3,%del_137}(p_2)_u*e_{l_3,%gam_145}(p_3)_u^(*)*e_{i_5,%gam_147}(p_4)_u^(*)/(m_e^2 + -s_13 + 1/2*reg_prop)"
tokens = tokenize_expression(example_expr)
print("Original expression:", example_expr)
print("Tokenized expression:", tokens)


Original expression: -1/2*i*e^2*gamma_{+%\sigma_165,%gam_145,%gam_146}*gamma_{%\sigma_165,%gam_147,%del_137}*e_{i_3,%gam_146}(p_1)_u*e_{k_3,%del_137}(p_2)_u*e_{l_3,%gam_145}(p_3)_u^(*)*e_{i_5,%gam_147}(p_4)_u^(*)/(m_e^2 + -s_13 + 1/2*reg_prop)
Tokenized expression: ['-', '1/2', '*', 'i', '*', 'e', '^', '2', '*', 'gamma', '_{+%\\sigma_165,%gam_145,%gam_146}', '*', 'gamma', '_{%\\sigma_165,%gam_147,%del_137}', '*', 'e', '_{i_3,%gam_146}', '(', 'p_1', ')', 'u', '*', 'e', '_{k_3,%del_137}', '(', 'p_2', ')', 'u', '*', 'e', '_{l_3,%gam_145}', '(', 'p_3', ')', 'u^(*)', '*', 'e', '_{i_5,%gam_147}', '(', 'p_4', ')', 'u^(*)', '/', '(', 'm_e', '^', '2', '+', '-', 's_13', '+', '1/2', '*', 'reg_prop', ')']


In [27]:
# tokennize using tokeninzer
from transformers import T5Tokenizer
tokenizer = T5Tokenizer.from_pretrained('t5-small')
#special_tokens = ['*', '/', '+', '-', '^', '(', ')', '{', '}', '_', 'gamma', 'sigma', 'e^2']
#tokenizer.add_tokens(special_tokens)
#tokenizer.add_tokens(['[START]'])

In [28]:
# Tokenization function for mathematical expressions
def tokenize_expression(expr):
    # Tokenize the expression
    tokens = tokenizer.tokenize(expr)
    return tokens

# Example of how the tokenization works
tokens = tokenize_expression(example_expr)
print("Original expression:", example_expr)
print("Tokenized expression:", tokens)


Original expression: -1/2*i*e^2*gamma_{+%\sigma_165,%gam_145,%gam_146}*gamma_{%\sigma_165,%gam_147,%del_137}*e_{i_3,%gam_146}(p_1)_u*e_{k_3,%del_137}(p_2)_u*e_{l_3,%gam_145}(p_3)_u^(*)*e_{i_5,%gam_147}(p_4)_u^(*)/(m_e^2 + -s_13 + 1/2*reg_prop)
Tokenized expression: ['▁', '-', '1/2', '*', 'i', '*', 'e', '^', '2', '*', 'gam', 'm', 'a', '_', '{', '+', '%', '\\', 's', 'igma', '_', '165', ',', '%', 'gam', '_', '145', ',', '%', 'gam', '_', '146', '}', '*', 'gam', 'm', 'a', '_', '{', '%', '\\', 's', 'igma', '_', '165', ',', '%', 'gam', '_', '147', ',', '%', 'de', 'l', '_', '137', '}', '*', 'e', '_', '{', 'i', '_', '3,', '%', 'gam', '_', '146', '}', '(', 'p', '_', '1)', '_', 'u', '*', 'e', '_', '{', 'k', '_', '3,', '%', 'de', 'l', '_', '137', '}', '(', 'p', '_', '2)', '_', 'u', '*', 'e', '_', '{', 'l', '_', '3,', '%', 'gam', '_', '145', '}', '(', 'p', '_', '3)', '_', 'u', '^', '(', '*', ')', '*', 'e', '_', '{', 'i', '_', '5,', '%', 'gam', '_', '147', '}', '(', 'p', '_', '4)', '_', 'u', '^', '(

In [None]:
class Tokenizer:
    """
    Tokenizer for processing symbolic mathematical expressions.
    """
    def __init__(self, df, index_token_pool_size, momentum_token_pool_size, special_symbols, UNK_IDX, to_replace):
        self.amps = df.amplitude.tolist()
        self.sqamps = df.squared_amplitude.tolist()

        # # Issue warnings if token pool sizes are too small
        # if index_token_pool_size < 100:
        #     warnings.warn(f"Index token pool size ({index_token_pool_size}) is small. Consider increasing it.", UserWarning)
        # if momentum_token_pool_size < 100:
        #     warnings.warn(f"Momentum token pool size ({momentum_token_pool_size}) is small. Consider increasing it.", UserWarning)
        
        # Generate token pools
        self.tokens_pool = [f"INDEX_{i}" for i in range(index_token_pool_size)]
        self.momentum_pool = [f"MOMENTUM_{i}" for i in range(momentum_token_pool_size)]
        
        # Regular expression patterns for token replacement
        self.pattern_momentum = re.compile(r'\b[ijkl]_\d{1,}\b')
        self.pattern_num_123 = re.compile(r'\b(?![ps]_)\w+_\d{1,}\b')
        self.pattern_special = re.compile(r'\b\w+_+\w+\b\\')
        self.pattern_underscore_curly = re.compile(r'\b\w+_{')
        self.pattern_prop = re.compile(r'Prop')
        self.pattern_int = re.compile(r'int\{')
        self.pattern_operators = {
            '+': re.compile(r'\+'), '-': re.compile(r'-'), '*': re.compile(r'\*'),
            ',': re.compile(r','), '^': re.compile(r'\^'), '%': re.compile(r'%'),
            '}': re.compile(r'\}'), '(': re.compile(r'\('), ')': re.compile(r'\)')
        }
        self.pattern_mass = re.compile(r'\b\w+_\w\b')
        self.pattern_s = re.compile(r'\b\w+_\d{2,}\b')
        self.pattern_reg_prop = re.compile(r'\b\w+_\d{1}\b')
        self.pattern_antipart = re.compile(r'(\w)_\w+_\d+\(X\)\^\(\*\)')
        self.pattern_part = re.compile(r'(\w)_\w+_\d+\(X\)')
        self.pattern_index = re.compile(r'\b\w+_\w+_\d{2,}\b')
        
        self.special_symbols = special_symbols
        self.UNK_IDX = UNK_IDX
        self.to_replace = to_replace

    @staticmethod
    def remove_whitespace(expression):
        """Remove all forms of whitespace from the expression."""
        return re.sub(r'\s+', '', expression)

    @staticmethod
    def split_expression(expression):
        """Split the expression by space delimiter."""
        return re.split(r' ', expression)

    def build_tgt_vocab(self):
        """Build vocabulary for target sequences."""
        counter = Counter()
        for eqn in tqdm(self.sqamps, desc='Processing target vocab'):
            counter.update(self.tgt_tokenize(eqn))
        voc = vocab(OrderedDict(counter), specials=self.special_symbols[:], special_first=True)
        voc.set_default_index(self.UNK_IDX)
        return voc

    def build_src_vocab(self, seed):
        """Build vocabulary for source sequences."""
        counter = Counter()
        for diag in tqdm(self.amps, desc='Processing source vocab'):
            counter.update(self.src_tokenize(diag, seed))
        voc = vocab(OrderedDict(counter), specials=self.special_symbols[:], special_first=True)
        voc.set_default_index(self.UNK_IDX)
        return voc
    
    def src_replace(self, ampl, seed):
        """Replace indexed and momentum variables with tokenized equivalents."""
        ampl = self.remove_whitespace(ampl)
        
        random.seed(seed)
        token_cycle = cycle(random.sample(self.tokens_pool, len(self.tokens_pool)))
        momentum_cycle = cycle(random.sample(self.momentum_pool, len(self.momentum_pool)))
        
        # Replace momentum tokens
        temp_ampl = ampl
        momentum_mapping = {match: next(momentum_cycle) for match in set(self.pattern_momentum.findall(ampl))}
        for key, value in momentum_mapping.items():
            temp_ampl = temp_ampl.replace(key, value)
        
        # Replace index tokens
        num_123_mapping = {match: next(token_cycle) for match in set(self.pattern_num_123.findall(ampl))}
        for key, value in num_123_mapping.items():
            temp_ampl = temp_ampl.replace(key, value)

        # Replace pattern index tokens
        pattern_index_mapping = {match: f"{'_'.join(match.split('_')[:-1])} {next(token_cycle)}"
                for match in set(self.pattern_index.findall(ampl))
            }
        for key, value in pattern_index_mapping.items():
            temp_ampl = temp_ampl.replace(key, value)
            
        return temp_ampl
    
    def src_tokenize(self, ampl, seed):
        """Tokenize source expression, optionally applying replacements."""
        temp_ampl = self.src_replace(ampl, seed) if self.to_replace else ampl
        temp_ampl = temp_ampl.replace('\\\\', '\\').replace('\\', ' \\ ').replace('%', '')

        temp_ampl = self.pattern_underscore_curly.sub(lambda match: f' {match.group(0)} ', temp_ampl)

        
        for symbol, pattern in self.pattern_operators.items():
            temp_ampl = pattern.sub(f' {symbol} ', temp_ampl)
        
        temp_ampl = re.sub(r' {2,}', ' ', temp_ampl)
        return [token for token in self.split_expression(temp_ampl) if token]

    def tgt_tokenize(self, sqampl):
        """Tokenize target expression."""
        sqampl = self.remove_whitespace(sqampl)
        temp_sqampl = sqampl
        
        for symbol, pattern in self.pattern_operators.items():
            temp_sqampl = pattern.sub(f' {symbol} ', temp_sqampl)
        
        for pattern in [self.pattern_reg_prop, self.pattern_mass, self.pattern_s]:
            temp_sqampl = pattern.sub(lambda match: f' {match.group(0)} ', temp_sqampl)
        
        temp_sqampl = re.sub(r' {2,}', ' ', temp_sqampl)
        return [token for token in self.split_expression(temp_sqampl) if token]

In [50]:
# Special token indices
BOS_IDX = 0  # Beginning of Sequence
PAD_IDX = 1  # Padding
EOS_IDX = 2  # End of Sequence
UNK_IDX = 3  # Unknown Token
SEP_IDX = 4  # Separator Token

# Special token symbols
SPECIAL_SYMBOLS = ['<START>', '<PAD>', '<END>', '<UNK>', '<SEP>']
import re
import random
from itertools import cycle
tokenizer = Tokenizer(df, 100, 100, SPECIAL_SYMBOLS, UNK_IDX, to_replace=True)

In [51]:
def tokenize_expression(expr):
    # Tokenize the expression
    tokens = tokenizer.tgt_tokenize(expr)
    return tokens

# Example of how the tokenization works
tokens = tokenize_expression(example_expr)
print("Original expression:", example_expr)
print("Tokenized expression:", tokens)

Original expression: -i*e^2*gamma_{+%\sigma_157721,%gam_166722,%eps_44575}*gamma_{%\sigma_157721,%gam_166723,%del_106099}*e_{i_36289,%del_106099}(p_3)_v*e_{k_36277,%gam_166723}(p_1)_v^(*)*mu_{l_36277,%gam_166722}(p_2)_v^(*)*mu_{j_36269,%eps_44575}(p_4)_v/(m_e^2 + (-2)*s_13 + s_33 + reg_prop)
Tokenized expression: ['-', 'i', '*', 'e', '^', '2', '*', 'gamma_{', '+', '%', '\\', 'sigma_157721', ',', '%', 'gam_166722', ',', '%', 'eps_44575', '}', '*', 'gamma_{', '%', '\\', 'sigma_157721', ',', '%', 'gam_166723', ',', '%', 'del_106099', '}', '*', 'e_{', 'i_36289', ',', '%', 'del_106099', '}', '(', 'p_3', ')', '_v', '*', 'e_{', 'k_36277', ',', '%', 'gam_166723', '}', '(', 'p_1', ')', '_v', '^', '(', '*', ')', '*', 'mu_{', 'l_36277', ',', '%', 'gam_166722', '}', '(', 'p_2', ')', '_v', '^', '(', '*', ')', '*', 'mu_{', 'j_36269', ',', '%', 'eps_44575', '}', '(', 'p_4', ')', '_v/', '(', 'm_e', '^', '2', '+', '(', '-', '2', ')', '*', 's_13', '+', 's_33', '+', 'reg_prop', ')']


In [52]:
def normalize_indices(tokenizer, expressions, index_token_pool_size=50, momentum_token_pool_size=50):
    # Function to replace indices with a new set of tokens for each expression
    def replace_indices(token_list, index_map):
        new_index = (f"INDEX_{i}" for i in range(index_token_pool_size))  # Local generator for new indices
        new_tokens = []
        for token in token_list:
            if "INDEX_" in token:
                if token not in index_map:
                    try:
                        index_map[token] = next(new_index)
                    except StopIteration:
                        # Handle the case where no more indices are available
                        raise ValueError("Ran out of unique indices, increase token_pool_size")
                new_tokens.append(index_map[token])
            else:
                new_tokens.append(token)
        return new_tokens

    def replace_momenta(token_list, index_map):
        new_index = (f"MOMENTUM_{i}" for i in range(momentum_token_pool_size))  # Local generator for new indices
        new_tokens = []
        for token in token_list:
            if "MOMENTUM_" in token:
                if token not in index_map:
                    try:
                        index_map[token] = next(new_index)
                    except StopIteration:
                        # Handle the case where no more indices are available
                        raise ValueError("Ran out of unique indices, increase momentum_token_pool_size")
                new_tokens.append(index_map[token])
            else:
                new_tokens.append(token)
        return new_tokens

    normalized_expressions = []
    # Replace indices in each expression randomly
    for expr in expressions:
        toks = tokenizer.src_tokenize(expr,42)
        print(toks)
        normalized_expressions.append(replace_momenta(replace_indices(toks, {}), {}))

    return normalized_expressions

In [None]:
notmalized_tokens = normalize_indices(tokenizer, [example_expr])
print("Original expression:", example_expr)
print("Normalized expression:", notmalized_tokens[0])

['-', 'i', '*', 'e', '^', '2', '*', 'gamma_{', '+', '\\', 'INDEX_31', ',', 'INDEX_35', ',', 'INDEX_28', '}', '*', 'gamma_{', '\\', 'INDEX_31', ',', 'INDEX_3', ',', 'INDEX_94', '}', '*', 'e_{', 'MOMENTUM_27', ',', 'INDEX_94', '}', '(', 'p_3', ')', '_v', '*', 'e_{', 'MOMENTUM_40', ',', 'INDEX_3', '}', '(', 'p_1', ')', '_v', '^', '(', '*', ')', '*', 'mu_{', 'MOMENTUM_72', ',', 'INDEX_35', '}', '(', 'p_2', ')', '_v', '^', '(', '*', ')', '*', 'mu_{', 'MOMENTUM_91', ',', 'INDEX_28', '}', '(', 'p_4', ')', '_v/', '(', 'm_e', '^', '2', '+', '(', '-', '2', ')', '*', 's_13', '+', 's_33', '+', 'reg_prop', ')']
Original expression: -i*e^2*gamma_{+%\sigma_157721,%gam_166722,%eps_44575}*gamma_{%\sigma_157721,%gam_166723,%del_106099}*e_{i_36289,%del_106099}(p_3)_v*e_{k_36277,%gam_166723}(p_1)_v^(*)*mu_{l_36277,%gam_166722}(p_2)_v^(*)*mu_{j_36269,%eps_44575}(p_4)_v/(m_e^2 + (-2)*s_13 + s_33 + reg_prop)
Normalized expression: ['-', 'i', '*', 'e', '^', '2', '*', 'gamma_{', '+', '\\', 'INDEX_0', ',', 'IND

: 

In [46]:
example_expr = "-i*e^2*gamma_{+%\sigma_157721,%gam_166722,%eps_44575}*gamma_{%\sigma_157721,%gam_166723,%del_106099}*e_{i_36289,%del_106099}(p_3)_v*e_{k_36277,%gam_166723}(p_1)_v^(*)*mu_{l_36277,%gam_166722}(p_2)_v^(*)*mu_{j_36269,%eps_44575}(p_4)_v/(m_e^2 + (-2)*s_13 + s_33 + reg_prop)"
normalized = normalize_indices(tokenizer, [example_expr])[0]
print("Original expression:", example_expr)
print("Normalized expression:", normalized)

Original expression: -i*e^2*gamma_{+%\sigma_157721,%gam_166722,%eps_44575}*gamma_{%\sigma_157721,%gam_166723,%del_106099}*e_{i_36289,%del_106099}(p_3)_v*e_{k_36277,%gam_166723}(p_1)_v^(*)*mu_{l_36277,%gam_166722}(p_2)_v^(*)*mu_{j_36269,%eps_44575}(p_4)_v/(m_e^2 + (-2)*s_13 + s_33 + reg_prop)
Normalized expression: ['-', 'i', '*', 'e', '^', '2', '*', 'gamma_{', '+', '\\', 'INDEX_0', ',', 'INDEX_1', ',', 'INDEX_2', '}', '*', 'gamma_{', '\\', 'INDEX_0', ',', 'INDEX_3', ',', 'INDEX_4', '}', '*', 'e_{', 'MOMENTUM_0', ',', 'INDEX_4', '}', '(', 'p_3', ')', '_v', '*', 'e_{', 'MOMENTUM_1', ',', 'INDEX_3', '}', '(', 'p_1', ')', '_v', '^', '(', '*', ')', '*', 'mu_{', 'MOMENTUM_2', ',', 'INDEX_1', '}', '(', 'p_2', ')', '_v', '^', '(', '*', ')', '*', 'mu_{', 'MOMENTUM_3', ',', 'INDEX_2', '}', '(', 'p_4', ')', '_v/', '(', 'm_e', '^', '2', '+', '(', '-', '2', ')', '*', 's_13', '+', 's_33', '+', 'reg_prop', ')']


In [13]:
file_paths = [f"SYMBA - Test Data\QED-2-to-2-diag-TreeLevel-{i}.txt" for i in range(0, 10)]

# Load the sample data
df = load_data(file_paths)

# Display the first few rows
print("Dataset shape:", df.shape)
df.head()

Dataset shape: (15552, 4)


Unnamed: 0,event_type,diagram,amplitude,squared_amplitude
0,Interaction: e_gam_239(X)^(*) e_del_219(X)^(...,"Vertex V_1:e(X_2), e(X_4), OffShell A(V_1), V...","-1/2*i*e^2*gamma_{+%\sigma_165,%gam_145,%gam_1...",2*e^4*(m_e^4 + -1/2*m_e^2*s_13 + 1/2*s_14*s_23...
1,Interaction: e_gam_239(X)^(*) e_del_219(X)^(...,"Vertex V_0:e(X_2), e(X_3), OffShell A(V_0), V...","1/2*i*e^2*gamma_{+%\sigma_172,%gam_162,%del_14...",2*e^4*(m_e^4 + -1/2*m_e^2*s_14 + -1/2*m_e^2*s_...
2,Interaction: e_gam_239(X)^(*) e_del_219(X)^(...,"Vertex V_1:e(X_2), OffShell e(X_4), OffShell...","-1/2*i*e^2*gamma_{+%\sigma_293,%gam_358,%gam_3...",2*e^4*(m_e^4 + -1/2*m_e^2*s_13 + 1/2*s_14*s_23...
3,Interaction: e_gam_239(X)^(*) e_del_219(X)^(...,"Vertex V_0:e(X_2), e(X_3), OffShell A(V_0), V...","1/2*i*e^2*gamma_{+%\sigma_301,%gam_377,%del_27...",2*e^4*(m_e^4 + -1/2*m_e^2*s_14 + -1/2*m_e^2*s_...
4,Interaction: e_gam_239(X)^(*) e_del_219(X)^(...,"Vertex V_1:e(X_2), e(X_4), OffShell A(V_1), V...","-i*e^2*gamma_{+%\sigma_435,%gam_574,%gam_575}*...",8*e^4*(m_e^4 + -1/2*m_e^2*s_13 + 1/2*s_14*s_23...


In [9]:
# Normalize indices in amplitudes and squared amplitudes
df['normalized_amplitude'] = df['amplitude'].apply(normalize_indices)
df['normalized_squared_amplitude'] = df['squared_amplitude'].apply(normalize_indices)

# Display an example of normalization
print("Original amplitude:")
print(df['amplitude'].iloc[0])
print("\nNormalized amplitude:")
print(df['normalized_amplitude'].iloc[0])

Original amplitude:
-1/2*i*e^2*gamma_{+%\sigma_165,%gam_145,%gam_146}*gamma_{%\sigma_165,%gam_147,%del_137}*e_{i_3,%gam_146}(p_1)_u*e_{k_3,%del_137}(p_2)_u*e_{l_3,%gam_145}(p_3)_u^(*)*e_{i_5,%gam_147}(p_4)_u^(*)/(m_e^2 + -s_13 + 1/2*reg_prop)

Normalized amplitude:
-1/2*i*e^2*gamma_{+%\sigma_1,%gam_1,%gam_2}*gamma_{%\sigma_1,%gam_3,%del_1}*e_{i_3,%gam_2}(p_1)_u*e_{k_3,%del_1}(p_2)_u*e_{l_3,%gam_1}(p_3)_u^(*)*e_{i_5,%gam_3}(p_4)_u^(*)/(m_e^2 + -s_13 + 1/2*reg_prop)


In [10]:
# Tokenize normalized expressions
df['tokenized_amplitude'] = df['normalized_amplitude'].apply(tokenize_expression)
df['tokenized_squared_amplitude'] = df['normalized_squared_amplitude'].apply(tokenize_expression)

# Display an example of tokenization
print("Normalized amplitude:")
print(df['normalized_amplitude'].iloc[0])
print("\nTokenized amplitude (first 20 tokens):")
print(df['tokenized_amplitude'].iloc[0][:20])

Normalized amplitude:
-1/2*i*e^2*gamma_{+%\sigma_1,%gam_1,%gam_2}*gamma_{%\sigma_1,%gam_3,%del_1}*e_{i_3,%gam_2}(p_1)_u*e_{k_3,%del_1}(p_2)_u*e_{l_3,%gam_1}(p_3)_u^(*)*e_{i_5,%gam_3}(p_4)_u^(*)/(m_e^2 + -s_13 + 1/2*reg_prop)

Tokenized amplitude (first 20 tokens):
['-', '1/2', '*', 'i', '*', 'e', '^', '2', '*', 'gamma', '_{+%\\sigma_1,%gam_1,%gam_2}', '*', 'gamma', '_{%\\sigma_1,%gam_3,%del_1}', '*', 'e', '_{i_3,%gam_2}', '(', 'p_1', ')']


In [11]:
# Split into train, validation, and test sets (80-10-10)
train_df, temp_df = train_test_split(df, test_size=0.2, random_state=42)
val_df, test_df = train_test_split(temp_df, test_size=0.5, random_state=42)

print(f"Train set size: {len(train_df)}")
print(f"Validation set size: {len(val_df)}")
print(f"Test set size: {len(test_df)}")

Train set size: 12441
Validation set size: 1555
Test set size: 1556


In [12]:
# Analyze token distribution
all_tokens = []
for tokens in df['tokenized_amplitude'] + df['tokenized_squared_amplitude']:
    all_tokens.extend(tokens)

token_counts = Counter(all_tokens)
print(f"Total unique tokens: {len(token_counts)}")
print(f"Most common tokens: {token_counts.most_common(10)}")


Total unique tokens: 44416
Most common tokens: [('*', 575442), ('(', 285768), (')', 285768), ('+', 272448), ('2', 215021), ('^', 198555), ('-', 168857), ('e', 70464), ('reg_prop', 62784), ('gamma', 58752)]


In [13]:
# Dump into a pickle file
outout_dir = "data"
os.makedirs(outout_dir, exist_ok=True)
train_df.to_pickle(os.path.join(outout_dir, "train.pkl"))
val_df.to_pickle(os.path.join(outout_dir, "val.pkl"))
test_df.to_pickle(os.path.join(outout_dir, "test.pkl"))
