In [4]:
from collections import Counter, OrderedDict
from itertools import cycle
import re
import random
from torchtext.vocab import vocab
from tqdm import tqdm
import warnings
import pandas as pd 


In [9]:
class Tokenizer:
    """
    Tokenizer for processing symbolic mathematical expressions.

    Source expressions are read from ``df.amplitude`` and target expressions
    from ``df.squared_amplitude``. When ``to_replace`` is truthy, names that
    match the momentum / index patterns in a source expression are swapped
    for placeholder tokens (``MOMENTUM_<n>`` / ``INDEX_<n>``) drawn from a
    seed-shuffled pool before the expression is split into tokens.
    """
    def __init__(self, df, index_token_pool_size, momentum_token_pool_size, special_symbols, UNK_IDX, to_replace):
        """
        Args:
            df: object exposing ``amplitude`` and ``squared_amplitude``
                attributes that support ``.tolist()``.
            index_token_pool_size: number of ``INDEX_<n>`` placeholders.
            momentum_token_pool_size: number of ``MOMENTUM_<n>`` placeholders.
            special_symbols: special tokens placed first in each vocabulary.
            UNK_IDX: default index returned by the vocabularies for unknown tokens.
            to_replace: whether ``src_tokenize`` applies ``src_replace`` first.
        """
        self.amps = df.amplitude.tolist()
        self.sqamps = df.squared_amplitude.tolist()

        # Issue warnings if token pool sizes are too small.
        # stacklevel=2 attributes the warning to the code constructing the tokenizer.
        if index_token_pool_size < 100:
            warnings.warn(f"Index token pool size ({index_token_pool_size}) is small. Consider increasing it.", UserWarning, stacklevel=2)
        if momentum_token_pool_size < 100:
            warnings.warn(f"Momentum token pool size ({momentum_token_pool_size}) is small. Consider increasing it.", UserWarning, stacklevel=2)

        # Generate token pools
        self.tokens_pool = [f"INDEX_{i}" for i in range(index_token_pool_size)]
        self.momentum_pool = [f"MOMENTUM_{i}" for i in range(momentum_token_pool_size)]

        # Regular expression patterns for token replacement.
        # pattern_special, pattern_prop, pattern_int, pattern_antipart and
        # pattern_part are not referenced by any method of this class; they are
        # kept as attributes in case external code reads them.
        self.pattern_momentum = re.compile(r'\b[ijkl]_\d{1,}\b')
        self.pattern_num_123 = re.compile(r'\b(?![ps]_)\w+_\d{1,}\b')
        self.pattern_special = re.compile(r'\b\w+_+\w+\b\\')
        self.pattern_underscore_curly = re.compile(r'\b\w+_{')
        self.pattern_prop = re.compile(r'Prop')
        self.pattern_int = re.compile(r'int\{')
        self.pattern_operators = {
            '+': re.compile(r'\+'), '-': re.compile(r'-'), '*': re.compile(r'\*'),
            ',': re.compile(r','), '^': re.compile(r'\^'), '%': re.compile(r'%'),
            '}': re.compile(r'\}'), '(': re.compile(r'\('), ')': re.compile(r'\)')
        }
        self.pattern_mass = re.compile(r'\b\w+_\w\b')
        self.pattern_s = re.compile(r'\b\w+_\d{2,}\b')
        self.pattern_reg_prop = re.compile(r'\b\w+_\d{1}\b')
        self.pattern_antipart = re.compile(r'(\w)_\w+_\d+\(X\)\^\(\*\)')
        self.pattern_part = re.compile(r'(\w)_\w+_\d+\(X\)')
        self.pattern_index = re.compile(r'\b\w+_\w+_\d{2,}\b')

        self.special_symbols = special_symbols
        self.UNK_IDX = UNK_IDX
        self.to_replace = to_replace

    @staticmethod
    def remove_whitespace(expression):
        """Remove all forms of whitespace from the expression."""
        return re.sub(r'\s+', '', expression)

    @staticmethod
    def split_expression(expression):
        """Split the expression on single space characters (may yield empty strings)."""
        return re.split(r' ', expression)

    @staticmethod
    def _apply_mapping(pattern, mapping, text):
        """Replace every whole ``pattern`` match in ``text`` that is a key of ``mapping``.

        Matches that are not keys are left untouched. Substituting whole regex
        matches (instead of ``str.replace`` on the raw key) prevents a short
        name such as ``i_1`` from rewriting the prefix of a longer one such as
        ``i_12``.
        """
        if not mapping:
            return text
        return pattern.sub(lambda m: mapping.get(m.group(0), m.group(0)), text)

    def build_tgt_vocab(self):
        """Build vocabulary for target sequences from ``self.sqamps``."""
        counter = Counter()
        for eqn in tqdm(self.sqamps, desc='Processing target vocab'):
            counter.update(self.tgt_tokenize(eqn))
        # Pass a copy of the specials so the vocab cannot alias our list.
        voc = vocab(OrderedDict(counter), specials=self.special_symbols[:], special_first=True)
        voc.set_default_index(self.UNK_IDX)
        return voc

    def build_src_vocab(self, seed):
        """Build vocabulary for source sequences from ``self.amps``.

        ``seed`` is forwarded to ``src_tokenize`` for every expression.
        """
        counter = Counter()
        for diag in tqdm(self.amps, desc='Processing source vocab'):
            counter.update(self.src_tokenize(diag, seed))
        voc = vocab(OrderedDict(counter), specials=self.special_symbols[:], special_first=True)
        voc.set_default_index(self.UNK_IDX)
        return voc

    def src_replace(self, ampl, seed):
        """Replace indexed and momentum variables with tokenized equivalents.

        Whitespace is stripped first. Replacement then runs in three stages,
        each collecting its names from the whitespace-free *original* string
        and applying them to the progressively rewritten string:

        1. ``pattern_momentum`` matches -> next ``MOMENTUM_<n>`` token.
        2. ``pattern_num_123`` matches  -> next ``INDEX_<n>`` token.
        3. ``pattern_index`` matches    -> ``"<name without last _part> INDEX_<n>"``.

        Because stages 2 and 3 collect from the original string, a pool token
        is consumed even for a name that an earlier stage already rewrote.

        Names receive tokens in order of first appearance, and the shuffle
        uses a private ``random.Random(seed)``, so the result depends only on
        ``ampl`` and ``seed`` and the global ``random`` state is not touched.

        Args:
            ampl: source expression string.
            seed: seed for shuffling the placeholder pools.

        Returns:
            The expression with placeholders substituted.
        """
        ampl = self.remove_whitespace(ampl)

        # Local generator: same shuffle as random.seed(seed) would give, without
        # resetting the process-wide RNG on every call.
        rng = random.Random(seed)
        token_cycle = cycle(rng.sample(self.tokens_pool, len(self.tokens_pool)))
        momentum_cycle = cycle(rng.sample(self.momentum_pool, len(self.momentum_pool)))

        # Replace momentum tokens. dict.fromkeys de-duplicates while keeping
        # first-occurrence order, so the assignment is deterministic.
        momentum_mapping = {
            match: next(momentum_cycle)
            for match in dict.fromkeys(self.pattern_momentum.findall(ampl))
        }
        temp_ampl = self._apply_mapping(self.pattern_momentum, momentum_mapping, ampl)

        # Replace index tokens
        num_123_mapping = {
            match: next(token_cycle)
            for match in dict.fromkeys(self.pattern_num_123.findall(ampl))
        }
        temp_ampl = self._apply_mapping(self.pattern_num_123, num_123_mapping, temp_ampl)

        # Replace pattern index tokens: keep everything before the final
        # underscore and append a fresh index placeholder as a separate word.
        pattern_index_mapping = {}
        for match in dict.fromkeys(self.pattern_index.findall(ampl)):
            prefix = '_'.join(match.split('_')[:-1])
            pattern_index_mapping[match] = f"{prefix} {next(token_cycle)}"
        temp_ampl = self._apply_mapping(self.pattern_index, pattern_index_mapping, temp_ampl)

        return temp_ampl

    def src_tokenize(self, ampl, seed):
        """Tokenize source expression, optionally applying replacements.

        Args:
            ampl: source expression string.
            seed: forwarded to ``src_replace`` when ``self.to_replace`` is truthy.

        Returns:
            List of non-empty token strings.
        """
        temp_ampl = self.src_replace(ampl, seed) if self.to_replace else ampl
        # Collapse doubled backslashes, make every backslash its own token, drop '%'.
        temp_ampl = temp_ampl.replace('\\\\', '\\').replace('\\', ' \\ ').replace('%', '')

        # Keep "<name>_{" together as a single token.
        temp_ampl = self.pattern_underscore_curly.sub(lambda match: f' {match.group(0)} ', temp_ampl)

        # Surround every operator / bracket with spaces so it splits cleanly.
        for symbol, pattern in self.pattern_operators.items():
            temp_ampl = pattern.sub(f' {symbol} ', temp_ampl)

        temp_ampl = re.sub(r' {2,}', ' ', temp_ampl)
        return [token for token in self.split_expression(temp_ampl) if token]

    def tgt_tokenize(self, sqampl):
        """Tokenize target expression.

        Whitespace is removed, operators are space-separated, then matches of
        ``pattern_reg_prop``, ``pattern_mass`` and ``pattern_s`` are padded
        with spaces (in that order) before splitting.

        Returns:
            List of non-empty token strings.
        """
        sqampl = self.remove_whitespace(sqampl)
        temp_sqampl = sqampl

        for symbol, pattern in self.pattern_operators.items():
            temp_sqampl = pattern.sub(f' {symbol} ', temp_sqampl)

        for pattern in [self.pattern_reg_prop, self.pattern_mass, self.pattern_s]:
            temp_sqampl = pattern.sub(lambda match: f' {match.group(0)} ', temp_sqampl)

        temp_sqampl = re.sub(r' {2,}', ' ', temp_sqampl)
        return [token for token in self.split_expression(temp_sqampl) if token]

In [None]:
# Load the preprocessed dataset; the Tokenizer reads its `amplitude` and
# `squared_amplitude` columns.
df = pd.read_csv(r"../QED_data/preprocessed_data.csv")

# Tokenizer configuration.
special_symbols = ["<unk>", "<pad>", "<bos>", "<eos>"]
UNK_IDX = 0
index_token_pool_size = 200
momentum_token_pool_size = 200
to_replace = True
seed = 42

tokenizer = Tokenizer(
    df,
    index_token_pool_size,
    momentum_token_pool_size,
    special_symbols,
    UNK_IDX,
    to_replace,
)

# Tokenize every source / target expression once, in row order.
tokenized_amps = [tokenizer.src_tokenize(amp, seed) for amp in df["amplitude"]]
tokenized_sqamps = [tokenizer.tgt_tokenize(sqamp) for sqamp in df["squared_amplitude"]]

# Vocabularies are built from the same expressions, using the same seed for the source side.
src_vocab = tokenizer.build_src_vocab(seed)
tgt_vocab = tokenizer.build_tgt_vocab()

Processing source vocab: 100%|██████████| 15552/15552 [00:03<00:00, 3892.58it/s]
Processing target vocab: 100%|██████████| 15552/15552 [00:01<00:00, 10796.21it/s]


In [12]:
import torch
from torch.utils.data import Dataset, DataLoader

class MathExprDataset(Dataset):
    """
    Custom Dataset for handling tokenized mathematical expressions.

    Each item is a ``(src, tgt)`` pair of 1-D ``torch.long`` tensors holding
    the vocabulary indices of the tokens at the requested position.

    Args:
        tokenized_src: sequence of token lists for the source side.
        tokenized_tgt: sequence of token lists for the target side, aligned
            position-by-position with ``tokenized_src``.
        src_vocab: mapping supporting ``src_vocab[token] -> int``.
        tgt_vocab: mapping supporting ``tgt_vocab[token] -> int``.
    """
    def __init__(self, tokenized_src, tokenized_tgt, src_vocab, tgt_vocab):
        self.tokenized_src = tokenized_src
        self.tokenized_tgt = tokenized_tgt
        self.src_vocab = src_vocab
        self.tgt_vocab = tgt_vocab

    def __len__(self):
        """Number of examples, taken from the source side."""
        return len(self.tokenized_src)

    def __getitem__(self, idx):
        """Return the ``(src_indices, tgt_indices)`` tensors for example ``idx``."""
        src_tokens = self.tokenized_src[idx]
        tgt_tokens = self.tokenized_tgt[idx]

        # Converting tokens to indices using vocab
        src_indices = [self.src_vocab[token] for token in src_tokens]
        tgt_indices = [self.tgt_vocab[token] for token in tgt_tokens]

        # Pin the dtype: torch.tensor([]) would otherwise default to float32,
        # giving empty sequences a different dtype from non-empty ones.
        return (
            torch.tensor(src_indices, dtype=torch.long),
            torch.tensor(tgt_indices, dtype=torch.long),
        )


In [21]:
from torch.nn.utils.rnn import pad_sequence

def collate_fn(batch):
    """
    Collate (src, tgt) tensor pairs into two batch-first padded tensors.

    Every sequence is right-padded to the longest one on its side of the
    batch, using the "<pad>" index looked up in the module-level `src_vocab`
    and `tgt_vocab` respectively.
    """
    # Transpose the list of pairs into one tuple of sources and one of targets.
    sources, targets = zip(*batch)

    # pad_sequence is given the pad index as a float.
    src_pad_value = float(src_vocab["<pad>"])
    tgt_pad_value = float(tgt_vocab["<pad>"])

    return (
        pad_sequence(sources, batch_first=True, padding_value=src_pad_value),
        pad_sequence(targets, batch_first=True, padding_value=tgt_pad_value),
    )


In [22]:
from sklearn.model_selection import train_test_split

# First cut: hold out 30% of the (source, target) pairs; the remaining 70% is the training set.
train_src, temp_src, train_tgt, temp_tgt = train_test_split(
    tokenized_amps,
    tokenized_sqamps,
    test_size=0.3,
    random_state=42,
)

# Second cut: divide the held-out pairs evenly into validation and test sets.
val_src, test_src, val_tgt, test_tgt = train_test_split(
    temp_src,
    temp_tgt,
    test_size=0.5,
    random_state=42,
)


In [23]:
# One dataset per split; all three share the same source / target vocabularies.
train_dataset = MathExprDataset(train_src, train_tgt, src_vocab, tgt_vocab)
val_dataset = MathExprDataset(val_src, val_tgt, src_vocab, tgt_vocab)
test_dataset = MathExprDataset(test_src, test_tgt, src_vocab, tgt_vocab)

# Only the training loader shuffles; every loader pads per batch via collate_fn.
train_dataloader = DataLoader(
    train_dataset, batch_size=32, shuffle=True, collate_fn=collate_fn
)
val_dataloader = DataLoader(
    val_dataset, batch_size=32, shuffle=False, collate_fn=collate_fn
)
test_dataloader = DataLoader(
    test_dataset, batch_size=32, shuffle=False, collate_fn=collate_fn
)


In [None]:
# Sanity check: show the first padded training batch, then stop iterating.
for batch_idx, (src_batch, tgt_batch) in enumerate(train_dataloader):
    print(
        f"Batch {batch_idx+1}",
        f"Source Batch (src_batch): {src_batch}",
        f"Target Batch (tgt_batch): {tgt_batch}",
        sep="\n",
    )
    break


Batch 1
Source Batch (src_batch): tensor([[79,  6,  7,  ...,  1,  1,  1],
        [ 4, 78,  6,  ...,  1,  1,  1],
        [ 4, 82,  6,  ...,  1,  1,  1],
        ...,
        [79,  6,  7,  ...,  1,  1,  1],
        [ 4,  7,  6,  ...,  1,  1,  1],
        [ 7,  6,  8,  ...,  1,  1,  1]])
Target Batch (tgt_batch): tensor([[51,  5,  6,  ...,  1,  1,  1],
        [50,  5,  6,  ...,  1,  1,  1],
        [53,  5,  6,  ...,  1,  1,  1],
        ...,
        [51,  5,  6,  ...,  1,  1,  1],
        [ 6,  7,  8,  ...,  1,  1,  1],
        [12, 43,  5,  ..., 20,  1,  1]])


In [26]:
import pickle 

In [27]:
# Persist each DataLoader to its own file. Previously all three files were
# written from train_dataloader, so the "test" and "val" pickles silently
# contained the training data.
# NOTE(review): pickle stores classes and functions by reference, so loading
# these files presumably requires MathExprDataset and collate_fn to be
# importable under the same names in the loading process — confirm in the consumer.
with open(r'../src/Dataloaders/train_loader.pkl', 'wb') as fp:
    pickle.dump(train_dataloader, fp)

with open(r'../src/Dataloaders/test_loader.pkl', 'wb') as fp:
    pickle.dump(test_dataloader, fp)

with open(r'../src/Dataloaders/val_loader.pkl', 'wb') as fp:
    pickle.dump(val_dataloader, fp)
