In [5]:
import re
import sympy

class QEDTokenizer:
    """Rule-based tokenizer for the QED expression strings used in this project.

    ``pre_tokenize`` splits a raw string into string tokens using hand-written
    rules (``[STATE_ID]``, ``gamma_{...}``, ``A_`` + backslash-name, ``e_{...}``,
    ``name_123_[STATE_ID]``, ``name_x``, superscripts, ``\\frac{a}{b}``,
    operators, identifiers and integers). ``tokenize`` wraps the result in
    [CLS] ... [SEP] and maps out-of-vocabulary tokens to [UNK]. ``encode`` and
    ``decode`` convert between text and vocabulary ids.

    Args:
        vocab: mapping from token string to integer id. It must contain
            "[PAD]", "[UNK]", "[CLS]" and "[SEP]".
    """

    # Identifiers containing '_' that the generic symbol rule keeps whole
    # instead of splitting on underscores.
    _UNSPLIT_SYMBOLS = frozenset({
        'm_e', 'm_mu', 'm_nu', 'm_tau', 'm_b', 'm_c', 'm_d', 'm_s', 'm_t', 'm_u',
        's_11', 's_12', 's_13', 's_14', 's_22', 's_23', 's_24', 's_33', 's_34', 's_44',
        '_u', '_v',
    })
    # Whitespace-separated parts of e_{...} kept verbatim; any other part is
    # emitted as 'gamma' (standardizes e.g. 'gam' to 'gamma').
    _E_INDEX_TOKENS = frozenset({'i', 'j', 'k', 'l', 'gamma', '[STATE_ID]'})

    _NUMBERED_STATE_ID_RE = re.compile(r'([a-zA-Z]+)_(\d+)_(\[STATE_ID\])')
    _INDEXED_NAME_RE = re.compile(r'([a-zA-Z]+)_([a-zA-Z])(?=\b|[^a-zA-Z0-9_])')
    _OPERATOR_RE = re.compile(r'(\+|-|\*|/|\^|\(|\)|\[|\]|\{|\})')
    _SYMBOL_RE = re.compile(r'[a-zA-Z_][a-zA-Z0-9_]*')
    _NUMBER_RE = re.compile(r'\d+')
    _LETTERS_RE = re.compile(r'[a-zA-Z]+')

    def __init__(self, vocab):
        self.vocab = vocab
        self.vocab_reverse = {v: k for k, v in vocab.items()}
        self.pad_token = "[PAD]"
        self.pad_token_id = vocab["[PAD]"]
        self.unk_token = "[UNK]"
        self.unk_token_id = vocab["[UNK]"]
        self.cls_token = "[CLS]"
        self.cls_token_id = vocab["[CLS]"]
        self.sep_token = "[SEP]"
        self.sep_token_id = vocab["[SEP]"]

    def pre_tokenize(self, text):
        """Split ``text`` into a list of string tokens (no special tokens added).

        At each position a case-insensitive "to" is skipped. Otherwise the
        rules below are tried in order and the first one that applies consumes
        its characters. A rule that does not fully apply consumes nothing, so
        later rules still see the original characters. Characters matched by
        no rule (e.g. ',', '%', whitespace) are dropped.
        """
        rules = (
            self._match_state_id,
            self._match_gamma_braced,
            self._match_a_latex,
            self._match_e_braced,
            self._match_numbered_state_id,
            self._match_indexed_name,
            self._match_superscript,
            self._match_fraction,
            self._match_operator,
            self._match_symbol,
            self._match_number,
        )
        tokens = []
        i = 0
        while i < len(text):
            # Skip the "to" keyword.
            if text[i:i + 2].lower() == "to":
                i += 2
                continue
            # Each rule appends its tokens and returns the index just past what
            # it consumed, or returns None without appending anything.
            for rule in rules:
                end = rule(text, i, tokens)
                if end is not None:
                    i = end
                    break
            else:
                # No rule applies here: drop this character.
                i += 1
        return tokens

    @staticmethod
    def _read_braced(text, i):
        """Read from just after an opening '{' up to its matching '}'.

        Returns ``(content, end)``. ``content`` excludes the closing brace, and
        ``end`` is the index just past it, or ``len(text)`` if the braces are
        unbalanced.
        """
        depth = 1
        chars = []
        while i < len(text) and depth > 0:
            ch = text[i]
            if ch == '{':
                depth += 1
            elif ch == '}':
                depth -= 1
            if depth > 0:
                chars.append(ch)
            i += 1
        return ''.join(chars), i

    @staticmethod
    def _split_gamma_indices(content):
        """Extract '\\lambda'/'\\mu' indices, '[...]' groups and '+' from gamma content.

        Every other character is skipped.
        """
        parts = []
        j = 0
        while j < len(content):
            head = content[j:j + 2]
            if head == r'\l':
                parts.append(content[j:j + 7])  # length of '\lambda'
                j += 7
            elif head == r'\m':
                parts.append(content[j:j + 3])  # length of '\mu'
                j += 3
            elif content[j] == '[':
                close = content.find(']', j)
                if close != -1:
                    parts.append(content[j:close + 1])
                    j = close + 1
                else:
                    j += 1
            elif content[j] == '+':
                parts.append('+')
                j += 1
            else:
                j += 1
        return parts

    def _match_state_id(self, text, i, tokens):
        """Literal '[STATE_ID]'."""
        if not text.startswith('[STATE_ID]', i):
            return None
        tokens.append('[STATE_ID]')
        return i + len('[STATE_ID]')

    def _match_gamma_braced(self, text, i, tokens):
        """'gamma_{...}' -> 'gamma', '{', <indices>, '}'."""
        if not text.startswith('gamma_{', i):
            return None
        content, end = self._read_braced(text, i + len('gamma_{'))
        tokens.append('gamma')
        tokens.append('{')
        tokens.extend(self._split_gamma_indices(content))
        tokens.append('}')
        return end

    def _match_a_latex(self, text, i, tokens):
        """'A_' + backslash + letters -> single token 'A_<letters>'."""
        if not text.startswith('A_\\', i):
            return None
        name = self._LETTERS_RE.match(text, i + 3)
        if not name:
            return None
        tokens.append('A_' + name.group(0))
        return name.end()

    def _match_e_braced(self, text, i, tokens):
        """'e_{...}' with optional '_u'/'_v' and '^(*)' suffixes."""
        if not text.startswith('e_{', i):
            return None
        content, end = self._read_braced(text, i + len('e_{'))
        tokens.append('e')
        tokens.append('{')
        for part in content.split():
            tokens.append(part if part in self._E_INDEX_TOKENS else 'gamma')
        tokens.append('}')
        if text.startswith('_u', end) or text.startswith('_v', end):
            tokens.append(text[end:end + 2])
            end += 2
        if text.startswith('^(*)', end):
            tokens.extend(['^', '(', '*', ')'])
            end += 4
        return end

    def _match_numbered_state_id(self, text, i, tokens):
        """'del_7748_[STATE_ID]' -> 'del', '7748', '[STATE_ID]'."""
        m = self._NUMBERED_STATE_ID_RE.match(text, i)
        if not m:
            return None
        tokens.extend(m.groups())
        return m.end()

    def _match_indexed_name(self, text, i, tokens):
        """'alpha_i' (single-letter index) -> 'alpha', 'i'."""
        m = self._INDEXED_NAME_RE.match(text, i)
        if not m:
            return None
        tokens.extend(m.groups())
        return m.end()

    def _match_superscript(self, text, i, tokens):
        """'^' plus an optional integer exponent ('x^2' -> ..., '^', '2')."""
        if not text.startswith('^', i):
            return None
        tokens.append('^')
        digits = self._NUMBER_RE.match(text, i + 1)
        if digits:
            tokens.append(digits.group(0))
            return digits.end()
        return i + 1

    def _match_fraction(self, text, i, tokens):
        """'\\frac{a}{b}' -> ['-'], '/', a, b (a leading '-' of a is hoisted)."""
        if not text.startswith('\\frac{', i):
            return None
        numerator, end = self._read_braced(text, i + len('\\frac{'))
        if not text.startswith('{', end):
            return None
        denominator, end = self._read_braced(text, end + 1)
        if numerator.startswith('-'):
            tokens.append('-')
            numerator = numerator[1:]
        tokens.append('/')
        tokens.append(numerator)
        tokens.append(denominator)
        return end

    def _match_operator(self, text, i, tokens):
        """Single-character operator or bracket."""
        m = self._OPERATOR_RE.match(text, i)
        if not m:
            return None
        tokens.append(m.group(0))
        return m.end()

    def _match_symbol(self, text, i, tokens):
        """Identifier; split on '_' unless listed in _UNSPLIT_SYMBOLS."""
        m = self._SYMBOL_RE.match(text, i)
        if not m:
            return None
        symbol = m.group(0)
        if '_' in symbol and symbol not in self._UNSPLIT_SYMBOLS:
            # e.g. 'b_gam' -> 'b', 'gam'; empty pieces (from '__') are dropped.
            tokens.extend(part for part in symbol.split('_') if part)
        else:
            tokens.append(symbol)
        return m.end()

    def _match_number(self, text, i, tokens):
        """Multi-digit integer as a single token."""
        m = self._NUMBER_RE.match(text, i)
        if not m:
            return None
        tokens.append(m.group(0))
        return m.end()

    def tokenize(self, text):
        """Pre-tokenize, wrap in [CLS] ... [SEP], and map OOV tokens to [UNK]."""
        tokens = [self.cls_token] + self.pre_tokenize(text) + [self.sep_token]
        return [token if token in self.vocab else self.unk_token for token in tokens]

    def encode(self, text):
        """Return the vocabulary ids of ``tokenize(text)``."""
        return [self.vocab[token] for token in self.tokenize(text)]

    def decode(self, token_ids, for_sympy=False):
        """Convert ids back to a space-separated token string.

        [CLS], [SEP] and [PAD] are dropped, and unknown ids become [UNK].

        Args:
            token_ids: iterable of ids. Anything with ``.tolist()`` (e.g. a 1-D
                torch tensor or numpy array) is converted to a list first.
            for_sympy: if True, rewrite tokens toward a sympy-style string:
                '^' -> '**'; '{', '}', '_u', '_v' dropped; '%' + backslash-name
                -> name; '[STATE_ID]' -> 'id1'; 'gamma { ... }' -> 'gamma';
                'X_[STATE_ID]' -> 'X'.
        """
        if hasattr(token_ids, 'tolist'):
            # Tensor elements hash by identity, so they would never match the
            # int keys of vocab_reverse; convert to plain Python ints first.
            token_ids = token_ids.tolist()
        specials = {self.cls_token, self.sep_token, self.pad_token}
        tokens = [self.vocab_reverse.get(tid, self.unk_token) for tid in token_ids]
        tokens = [token for token in tokens if token not in specials]

        if not for_sympy:
            return ' '.join(tokens)

        sympy_tokens = []
        i = 0
        while i < len(tokens):
            token = tokens[i]
            if token == '^':
                sympy_tokens.append('**')
            elif token in {'_u', '_v', '{', '}'}:
                pass
            elif token.startswith('%\\'):
                # Drop the '%' and any leading backslashes, keeping the name
                # (the old token[3:] cut one character of a '%\name' token).
                sympy_tokens.append(token[2:].lstrip('\\'))
            elif token == '[STATE_ID]':
                sympy_tokens.append('id1')
            elif token == 'gamma':
                # Collapse 'gamma { ... }' (nested braces allowed) to 'gamma'.
                if i + 1 < len(tokens) and tokens[i + 1] == '{':
                    depth = 1
                    i += 2
                    while i < len(tokens) and depth > 0:
                        if tokens[i] == '{':
                            depth += 1
                        elif tokens[i] == '}':
                            depth -= 1
                        i += 1
                    i -= 1  # back onto the closing '}'; the i += 1 below skips it
                sympy_tokens.append('gamma')
            elif token.endswith('_[STATE_ID]'):
                sympy_tokens.append(token.replace('_[STATE_ID]', ''))
            else:
                sympy_tokens.append(token)
            i += 1
        return ' '.join(sympy_tokens)

In [6]:
import pandas as pd
import pickle

In [7]:
# Load the processed dataset and the pickled token -> id vocabulary.
df = pd.read_csv(r'../QED_data/processed_2.csv')
# NOTE: pickle.load can execute arbitrary code; only load vocab.pkl from a trusted source.
# The context manager closes the file instead of leaking the handle.
with open(r'../QED_data/vocab.pkl', 'rb') as vocab_file:
    vocab = pickle.load(vocab_file)



In [8]:
# Build the tokenizer from the loaded vocabulary.
tokenizer = QEDTokenizer(vocab)
# NOTE(review): this rebinds `vocab` to {token: position in iteration order}.
# The tokenizer above keeps the original mapping, so the two disagree if the
# pickled ids are not 0..N-1 in insertion order — confirm this is intended.
vocab = {word: i for i, word in enumerate(vocab)}

# Collect every pre-token produced over all texts and labels
# (presumably to check vocabulary coverage).
all_texts = df['text'].tolist() + df['label'].tolist()
all_tokens = set()
for text in all_texts:
    # Missing CSV cells come back as float NaN, which is truthy; only
    # non-empty strings can be pre-tokenized.
    if isinstance(text, str) and text:
        tokens = tokenizer.pre_tokenize(text)
        all_tokens.update(tokens)

special_tokens = ["[CLS]", "[SEP]", "[PAD]", "[STATE_ID]", "[UNK]"]

In [9]:
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
import torch

class QEDDataset(Dataset):
    """Map-style dataset of (text, label) string pairs encoded to fixed-length ids.

    Each side is encoded with ``tokenizer.encode``, truncated to ``max_length``
    ids and right-padded with ``tokenizer.pad_token_id``. The attention masks
    are 1 for real ids and 0 for padding, and are exactly ``max_length`` long.

    Args:
        texts: sequence of input strings.
        labels: sequence of target strings aligned with ``texts``.
        tokenizer: object providing ``encode(str) -> list[int]`` and ``pad_token_id``.
        max_length: length of every returned sequence.
    """

    def __init__(self, texts, labels, tokenizer, max_length=128):
        self.texts = texts
        self.labels = labels
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.texts)

    def _encode_padded(self, text):
        """Encode, truncate and right-pad one string; return ``(ids, mask)`` lists.

        Truncation keeps the first ``max_length`` ids, so an over-long sequence
        loses whatever end marker the tokenizer appended (e.g. [SEP]).
        """
        ids = self.tokenizer.encode(text)[:self.max_length]
        n_pad = self.max_length - len(ids)
        # Build the mask before padding so padded positions are marked 0.
        mask = [1] * len(ids) + [0] * n_pad
        ids = ids + [self.tokenizer.pad_token_id] * n_pad
        return ids, mask

    def __getitem__(self, idx):
        text_ids, text_mask = self._encode_padded(self.texts[idx])
        label_ids, label_mask = self._encode_padded(self.labels[idx])
        return {
            'text': torch.tensor(text_ids, dtype=torch.long),
            'label': torch.tensor(label_ids, dtype=torch.long),
            'text_attention_mask': torch.tensor(text_mask, dtype=torch.long),
            'label_attention_mask': torch.tensor(label_mask, dtype=torch.long),
        }




In [13]:

# Re-read the processed data so this cell can run on its own, then hold out
# 20% of the (text, label) pairs for validation with a fixed split seed.
df = pd.read_csv(r'../QED_data/processed_2.csv')
(train_texts, val_texts,
 train_labels, val_labels) = train_test_split(
    df['text'].tolist(), df['label'].tolist(),
    test_size=0.2, random_state=42,
)

# Every encoded sequence is truncated / padded to this many ids.
max_length = 512
train_dataset = QEDDataset(texts=train_texts, labels=train_labels,
                           tokenizer=tokenizer, max_length=max_length)
val_dataset = QEDDataset(texts=val_texts, labels=val_labels,
                         tokenizer=tokenizer, max_length=max_length)

# Only the training loader is shuffled.
batch_size = 32
train_dataloader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
val_dataloader = DataLoader(dataset=val_dataset, batch_size=batch_size, shuffle=False)

print("\n".join(
    f"{split_name} dataset size: {len(split_ds)}"
    for split_name, split_ds in (("Training", train_dataset), ("Validation", val_dataset))
))

sample = train_dataset[0]
print("\nSample from dataset:")
print("Text (encoded):", sample['text'])

Training dataset size: 12441
Validation dataset size: 3111

Sample from dataset:
Text (encoded): tensor([  0,  34, 127,  73,   5,  12,  37,  13, 127,   5,   5,  12,  37,  13,
         38,  12,  14,  13,  34, 127,  41,   5,  12,  37,  13,  38,  12,  14,
         13, 127,  50,   5,  12,  37,  13,   5,  19,  34, 127,  12,  37,  13,
        127,  12,  37,  13,  28,  12,  37,  13,   5,  18,  34, 127,  12,  37,
         13, 127,  12,  37,  13,  28,  12,  37,  13,  16,  22,  17,  27,  14,
         76,  14,  68,  38,  20,  14,  75, 141,  15,   4,   4,   4,   4, 142,
         14,  75, 141,   4,   4,   4,   4, 142,  14, 127, 141,  79,   5,   5,
          5,   4,   4, 142,  12,  37,  13,  39,  14, 127, 141,  78,   5,  73,
          5,   4,   4, 142,  12,  37,  13,  40,  38,  12,  14,  13,  14, 127,
        141,  77,   5,  73,   5,   4,   4, 142,  12,  37,  13,  39,  38,  12,
         14,  13,  14, 127, 141,  76,   5,   5,   5,   4,   4, 142,  12,  37,
         13,  40,  17,  12,  37,  13,   1,   

In [14]:
print(f"Label (encoded): {sample['label']}")  # encoded target ids of the first training sample

Label (encoded): tensor([  0,   5,  17,   5,  14,  68,  38,  22,  14,  12,   5,  14,   5, 127,
         38,  22,  15,  26,  14,   5, 127,  38,  20,  14, 103,  15,  26,  14,
        105,  14, 107,  15,  26,  14, 104,  14, 108,  15,  26,  14,   5, 127,
         38,  20,  14, 110,  13,  14,  12,   5, 127,  38,  20,  15, 102,  15,
         20,  14, 103,  15,   5,   5,  13,  38,  12,  16,  20,  13,   1,   2,
          2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,
          2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,
          2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,
          2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,
          2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,
          2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,
          2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,
          2,   2,   2,   2,   2,   2,   2,   2,

In [15]:
import pickle

In [16]:
import os

# Persist the train/validation data loaders, plus the tokenizer needed to
# decode their ids, under ../src/Dataloaders (created if missing).
# NOTE: pickle stores classes by module and name; QEDDataset / QEDTokenizer are
# defined in this notebook's __main__, so the loading code must define or
# import them under the same names before calling pickle.load.
DATALOADER_DIR = r'../src/Dataloaders'
os.makedirs(DATALOADER_DIR, exist_ok=True)
for filename, obj in (
    ('train_loader.pkl', train_dataloader),
    ('val_loader.pkl', val_dataloader),
    ('tokenizer.pkl', tokenizer),
):
    with open(os.path.join(DATALOADER_DIR, filename), 'wb') as f:
        pickle.dump(obj, f)
# TODO: this notebook creates no test split, so test_loader.pkl is not written.
# Add a held-out test split and save its DataLoader here.