In [7]:
import pandas as pd
import re

# Load the CSV data
csv_path = "../QED_data/processed_2.csv"
df = pd.read_csv(csv_path)

# Print column names and first few rows for verification
print("Columns in CSV:", df.columns.tolist())
print("First few rows of CSV:\n", df.head())

# Filter out rows with missing data
df = df.dropna(subset=['text', 'label'])
print(f"Number of rows after filtering NaN: {len(df)}")

# Convert all entries to strings
df['text'] = df['text'].astype(str)
df['label'] = df['label'].astype(str)

# Enhanced replace_state_ids function
def replace_state_ids(text):
    # Step 1: Remove tokenizer artifacts
    text = re.sub(r'%% ]}', '', text)
    text = re.sub(r'%%+', '', text)
    text = re.sub(r'Ġ', '', text)
    text = re.sub(r'čĊ', '', text)
    text = re.sub(r'igma', 'sigma', text)

    # Step 2: Fix LaTeX issues
    text = re.sub(r'\\ssigma', 'sigma', text)
    text = re.sub(r'\\s', '', text)
    text = re.sub(r'\+%sigma', '+sigma', text)
    text = re.sub(r'%sigma', 'sigma', text)

    # Step 3: Simplify repeated [STATE_ID][STATE_ID]_[STATE_ID] patterns
    text = re.sub(r'\[STATE_ID\](?:\[STATE_ID\]_?)+', '[STATE_ID]', text)

    # Step 4: Replace state IDs in formats like _120386_ (e.g., del_120386_)
    text = re.sub(r'(_\d+_)', ' [STATE_ID]', text)

    # Step 5: Replace state IDs in particle names
    prefixes = r'(sigma|gam|del|eta|nu|mu|eps|alpha|beta|tau|rho|lambda|t_eps|t_alpha|t_eta|t_gam|e_eps|e_eta|e_beta|s_eps|s_alpha|s_eta|s_gam|s_del|s_beta|c_eps|c_eta|c_beta|c_gam|c_del|e_gam|e_del)'
    text = re.sub(rf'{prefixes}_(\d+)(?:_\[STATE_ID\](?:_\[STATE_ID\])?)?', r'\1 [STATE_ID]', text)

    # Step 6: Replace state IDs in gamma_{...}
    text = re.sub(rf'\\({prefixes})_(\d+)(?:_\[STATE_ID\](?:_\[STATE_ID\])?)?', r'\\\1 [STATE_ID]', text)

    # Step 7: Replace indices with state IDs (e.g., i_151807, k_151795, j_151787)
    text = re.sub(r'(j|i|k|l)_(\d+)', r'\1 [STATE_ID]', text)

    # Step 8: Simplify any remaining [STATE_ID]_[STATE_ID] patterns
    text = re.sub(r'\[STATE_ID\]_\[STATE_ID\]', '[STATE_ID]', text)

    # Step 9: Remove extra spaces and normalize
    text = re.sub(r'\s+', ' ', text).strip()

    return text

# Apply the replacement to input and target columns
df['text'] = df['text'].apply(replace_state_ids)
df['label'] = df['label'].apply(replace_state_ids)

# Save the preprocessed CSV
preprocessed_csv_path = "preprocessed_data.csv"
df.to_csv(preprocessed_csv_path, index=False)
print(f"Preprocessed CSV saved to {preprocessed_csv_path}")

# Define the vocabulary
# Comprehensive QED-specific vocabulary
vocab = {
    # Special tokens
    "[CLS]": 0,
    "[SEP]": 1,
    "[PAD]": 2,
    "[MASK]": 3,
    "[STATE_ID]": 4,
    "[UNK]": 5,

    # Mathematical operators
    "+": 6,
    "-": 7,
    "*": 8,
    "/": 9,
    "^": 10,
    "(": 11,
    ")": 12,
    "[": 13,
    "]": 14,
    "{": 15,
    "}": 16,

    # Numbers (for coefficients and indices)
    "0": 17,
    "1": 18,
    "2": 19,
    "3": 20,
    "4": 21,
    "5": 22,
    "6": 23,
    "7": 24,
    "8": 25,
    "9": 26,

    # Common variables and constants
    "e": 27,        # Electric charge or electron
    "alpha": 28,    # Fine-structure constant
    "hbar": 29,     # Reduced Planck constant
    "c": 30,        # Speed of light
    "G": 31,        # Gravitational constant (if needed)
    "pi": 32,       # Pi constant
    "i": 33,        # Imaginary unit

    # Particle masses
    "m_e": 34,      # Electron mass
    "m_mu": 35,     # Muon mass
    "m_tau": 36,    # Tau mass
    "m_u": 37,      # Up quark mass
    "m_d": 38,      # Down quark mass
    "m_c": 39,      # Charm quark mass
    "m_s": 40,      # Strange quark mass
    "m_t": 41,      # Top quark mass
    "m_b": 42,      # Bottom quark mass

    # Mandelstam variables
    "s": 43,
    "t": 44,
    "u": 45,
    "s_11": 46,
    "s_12": 47,
    "s_13": 48,
    "s_14": 49,
    "s_21": 50,
    "s_22": 51,
    "s_23": 52,
    "s_24": 53,
    "s_31": 54,
    "s_32": 55,
    "s_33": 56,
    "s_34": 57,
    "s_41": 58,
    "s_42": 59,
    "s_43": 60,
    "s_44": 61,

    # Momenta and coordinates
    "p_1": 62,
    "p_2": 63,
    "p_3": 64,
    "p_4": 65,
    "k_1": 66,
    "k_2": 67,
    "q": 68,
    "X": 69,        # Generic variable (e.g., in e(X))

    # QED-specific terms
    "reg_prop": 70, # Regularized propagator
    "Delta": 71,    # Delta function or difference
    "A": 72,        # Photon field (e.g., A(X))
    "V_0": 73,      # Potential term
    "V_1": 74,      # Potential term

    # Greek letters (common in QED)
    "gamma": 75,    # Dirac matrices
    "sigma": 76,    # Pauli matrices or other
    "epsilon": 77,  # Levi-Civita tensor or small parameter
    "delta": 78,    # Kronecker delta
    "theta": 79,    # Angle or parameter
    "phi": 80,      # Scalar field or angle
    "omega": 81,    # Frequency or field
    "eta": 82,      # Pseudoscalar meson or parameter
    "mu": 83,       # Muon or Lorentz index
    "nu": 84,       # Neutrino or Lorentz index
    "rho": 85,      # Rho meson or density
    "tau": 86,      # Tau lepton or parameter
    "lambda": 87,   # Parameter or field
    "beta": 88,     # Parameter or velocity
    "alpha": 89,    # Fine-structure constant (already included, but repeated for clarity)
    "gam": 90,      # Short form for gamma in subscripts

    # Particles and fields
    "e_gam": 91,    # Electron-gamma interaction
    "e_del": 92,    # Electron-delta interaction
    "e_eps": 93,    # Electron-epsilon interaction
    "e_eta": 94,    # Electron-eta interaction
    "e_mu": 95,     # Electron-muon interaction
    "e_nu": 96,     # Electron-neutrino interaction
    "photon": 97,   # Photon
    "Z": 98,        # Z boson
    "W": 99,        # W boson

    # Subscripts and superscripts
    "_u": 100,      # Up-type subscript
    "_d": 101,      # Down-type subscript
    "_mu": 102,     # Muon subscript
    "_nu": 103,     # Neutrino subscript
    "_e": 104,      # Electron subscript
    "_gam": 105,    # Gamma subscript
    "_del": 106,    # Delta subscript
    "_eps": 107,    # Epsilon subscript
    "_eta": 108,    # Eta subscript

    # Mathematical functions
    "sin": 109,
    "cos": 110,
    "tan": 111,
    "exp": 112,
    "log": 113,

    # Indices (for Dirac matrices, tensors, etc.)
    "i": 114,       # Index (e.g., in e_{i})
    "j": 115,
    "k": 116,
    "l": 117,
}


# Improved pre-tokenizer
def hep_pre_tokenize(text, vocab):
    tokens = []
    i = 0
    while i < len(text):
        matched = False
        # Skip the "to" keyword
        if text[i:i+2] == "to":
            i += 2
            continue
        # Try to match the longest possible symbol from the vocabulary
        for j in range(len(text), i, -1):
            substring = text[i:j]
            if substring in vocab:
                tokens.append(substring)
                i = j
                matched = True
                break
        if not matched:
            # Try to match LaTeX-like structures (e.g., gamma_{...}, e_{...}_u^(*))
            if text[i:i+6].startswith('gamma_'):
                start = i
                i += 6
                if i < len(text) and text[i] == '{':
                    i += 1
                    brace_count = 1
                    content = []
                    while i < len(text) and brace_count > 0:
                        if text[i] == '{':
                            brace_count += 1
                        elif text[i] == '}':
                            brace_count -= 1
                        if brace_count > 0:
                            content.append(text[i])
                        i += 1
                    tokens.append('gamma')
                    tokens.append('{')
                    content_parts = ''.join(content).split()
                    for part in content_parts:
                        if part in vocab:
                            tokens.append(part)
                        elif part.startswith('+'):
                            tokens.append('+')
                            sub_part = part[1:]
                            if sub_part in vocab:
                                tokens.append(sub_part)
                    tokens.append('}')
                    matched = True
            elif text[i:i+2].startswith('e_'):
                start = i
                i += 2
                if i < len(text) and text[i] == '{':
                    i += 1
                    brace_count = 1
                    content = []
                    while i < len(text) and brace_count > 0:
                        if text[i] == '{':
                            brace_count += 1
                        elif text[i] == '}':
                            brace_count -= 1
                        if brace_count > 0:
                            content.append(text[i])
                        i += 1
                    tokens.append('e')
                    tokens.append('{')
                    content_parts = ''.join(content).split()
                    for part in content_parts:
                        if part in vocab:
                            tokens.append(part)
                    tokens.append('}')
                    if i + 2 <= len(text) and text[i:i+2] == '_u':
                        tokens.append('_u')
                        i += 2
                    if i + 4 <= len(text) and text[i:i+4] == '^(*)':
                        tokens.append('^')
                        tokens.append('(')
                        tokens.append('*')
                        tokens.append(')')
                        i += 4
                    matched = True
            if not matched:
                operators = r'(\+|-|\*|/|\^|\(|\)|\[|\]|\{|\})'
                match = re.match(operators, text[i:])
                if match:
                    token = match.group(0)
                    tokens.append(token)
                    i += len(token)
                    matched = True
                else:
                    if text[i].isdigit():
                        num_start = i
                        while i < len(text) and text[i].isdigit():
                            i += 1
                        num = text[num_start:i]
                        tokens.extend(list(num))
                        matched = True
                    else:
                        i += 1
    return tokens

# Updated tokenizer class
class QEDTokenizer:
    def __init__(self, vocab):
        self.vocab = vocab
        self.vocab_reverse = {v: k for k, v in vocab.items()}
        self.pad_token = "[PAD]"
        self.pad_token_id = vocab["[PAD]"]

    def pre_tokenize(self, text):
        return hep_pre_tokenize(text, self.vocab)

    def tokenize(self, text):
        tokens = self.pre_tokenize(text)
        final_tokens = []
        for token in tokens:
            if token in self.vocab:
                final_tokens.append(token)
            elif token.isdigit():
                final_tokens.extend(list(token))
            else:
                continue
        return final_tokens

    def encode(self, text):
        tokens = self.tokenize(text)
        token_ids = [self.vocab.get(token, self.vocab["[PAD]"]) for token in tokens]
        return token_ids

    def decode(self, token_ids, skip_special_tokens=True):
        tokens = []
        for tid in token_ids:
            token = self.vocab_reverse.get(tid, "")
            if skip_special_tokens and token in ["[CLS]", "[SEP]", "[PAD]", "[MASK]", "[STATE_ID]"]:
                continue
            tokens.append(token)
        return "".join(tokens)

# Initialize the tokenizer
tokenizer = QEDTokenizer(vocab)

# Count unique tokens in the preprocessed data
input_texts = df['text'].tolist()
target_texts = df['label'].tolist()
all_texts = input_texts + target_texts

all_tokens = set()
for text in all_texts:
    if text:
        tokens = tokenizer.tokenize(text)
        all_tokens.update(tokens)

print(f"Number of unique tokens in the preprocessed dataset: {len(all_tokens)}")
print(f"Sample tokens: {list(all_tokens)[:20]}")

# Print the preprocessed version of the first non-empty row for verification
first_non_empty_row = df.iloc[0]
print("\nPreprocessed first non-empty row:")
print("Input:", first_non_empty_row['text'])
print("Target:", first_non_empty_row['label'])

# Test tokenization on the first row
input_tokens = tokenizer.tokenize(first_non_empty_row['text'])
target_tokens = tokenizer.tokenize(first_non_empty_row['label'])
print("\nTokenized Input:", input_tokens)
print("Tokenized Target:", target_tokens)

Columns in CSV: ['text', 'label']
First few rows of CSV:
                                                 text  \
0  e_gam_[STATE_ID](X)^(*) e_del_[STATE_ID](X)^(*...   
1  e_gam_[STATE_ID](X)^(*) e_del_[STATE_ID](X)^(*...   
2  e_gam_[STATE_ID](X)^(*) e_del_[STATE_ID](X)^(*...   
3  e_gam_[STATE_ID](X)^(*) e_del_[STATE_ID](X)^(*...   
4  e_gam_[STATE_ID](X)^(*) e_del_[STATE_ID](X)^(*...   

                                               label  
0  2*e^4*(m_e^4 + -1/2*m_e^2*s_13 + 1/2*s_14*s_23...  
1  2*e^4*(m_e^4 + -1/2*m_e^2*s_14 + -1/2*m_e^2*s_...  
2  2*e^4*(m_e^4 + -1/2*m_e^2*s_13 + 1/2*s_14*s_23...  
3  2*e^4*(m_e^4 + -1/2*m_e^2*s_14 + -1/2*m_e^2*s_...  
4  8*e^4*(m_e^4 + -1/2*m_e^2*s_13 + 1/2*s_14*s_23...  
Number of rows after filtering NaN: 15552
Preprocessed CSV saved to preprocessed_data.csv
Number of unique tokens in the preprocessed dataset: 70
Sample tokens: ['u', '+', 'gam', 'i', 'eta', '/', 'alpha', ')', 'e_del', '_eta', 'X', '5', '6', 'm_u', 's_33', '(', 'm_t', '8', '

In [8]:
# Create the vocabulary
vocab = {}

# Add special tokens
special_tokens = ["[CLS]", "[SEP]", "[PAD]", "[MASK]", "[STATE_ID]", "[UNK]"]
for i, token in enumerate(special_tokens):
    vocab[token] = i

# Add digits (0-9)
for digit in range(10):
    token = str(digit)
    vocab[token] = len(vocab)

# Add all unique tokens from the dataset
for token in sorted(all_tokens):
    if token not in vocab:  # Avoid duplicates (e.g., digits already added)
        vocab[token] = len(vocab)

# Print the vocabulary size and some entries
print(f"Vocabulary size: {len(vocab)}")
print("Sample vocabulary entries:")
for token, idx in list(vocab.items())[:10]:
    print(f"{token}: {idx}")
print("...")
for token, idx in list(vocab.items())[-10:]:
    print(f"{token}: {idx}")

Vocabulary size: 77
Sample vocabulary entries:
[CLS]: 0
[SEP]: 1
[PAD]: 2
[MASK]: 3
[STATE_ID]: 4
[UNK]: 5
0: 6
1: 7
2: 8
3: 9
...
s_24: 67
s_33: 68
s_34: 69
s_44: 70
sigma: 71
t: 72
tau: 73
u: 74
{: 75
}: 76
