In [9]:
import pandas as pd
import re
from collections import Counter

In [10]:
csv_path = "../QED_data/processed_dataset.csv"
df = pd.read_csv(csv_path)

In [11]:
print("Columns in CSV:", df.columns.tolist())
print("First few rows of CSV:\n", df.head())

df = df.dropna(subset=['text', 'label'])
print(f"Number of rows after filtering NaN: {len(df)}")

df['text'] = df['text'].astype(str)
df['label'] = df['label'].astype(str)

Columns in CSV: ['text', 'label']
First few rows of CSV:
                                                 text  \
0  e_gam_[STATE_ID](X)^(*) e_del_[STATE_ID](X)^(*...   
1  e_gam_[STATE_ID](X)^(*) e_del_[STATE_ID](X)^(*...   
2  e_gam_[STATE_ID](X)^(*) e_del_[STATE_ID](X)^(*...   
3  e_gam_[STATE_ID](X)^(*) e_del_[STATE_ID](X)^(*...   
4  e_gam_[STATE_ID](X)^(*) e_del_[STATE_ID](X)^(*...   

                                               label  
0  2*e^4*(m_e^4 + -1/2*m_e^2*s_13 + 1/2*s_14*s_23...  
1  2*e^4*(m_e^4 + -1/2*m_e^2*s_14 + -1/2*m_e^2*s_...  
2  2*e^4*(m_e^4 + -1/2*m_e^2*s_13 + 1/2*s_14*s_23...  
3  2*e^4*(m_e^4 + -1/2*m_e^2*s_14 + -1/2*m_e^2*s_...  
4  8*e^4*(m_e^4 + -1/2*m_e^2*s_13 + 1/2*s_14*s_23...  
Number of rows after filtering NaN: 15552


In [12]:

def replace_state_ids(text):
    text = re.sub(r'%% ]}', '', text)
    text = re.sub(r'%%+', '', text)
    text = re.sub(r'Ġ', '', text)
    text = re.sub(r'čĊ', '', text)
    text = re.sub(r'igma', 'sigma', text)

    text = re.sub(r'\\ssigma', 'sigma', text)
    text = re.sub(r'\\s', '', text)
    text = re.sub(r'\+%sigma', '+sigma', text)
    text = re.sub(r'%sigma', 'sigma', text)

    text = re.sub(r'\bdel\b', 'delta', text)

    while '[STATE_ID][STATE_ID]' in text:
        text = re.sub(r'\[STATE_ID\]\[STATE_ID\]', '[STATE_ID]', text)
    text = re.sub(r'\[STATE_ID\](?:\[STATE_ID\]_?)+', '[STATE_ID]', text)

    text = re.sub(r'(_\d+_)', ' [STATE_ID]', text)

    prefixes = r'(sigma|gam|delta|eta|nu|mu|eps|alpha|beta|tau|rho|lambda|t_eps|t_alpha|t_eta|t_gam|e_eps|e_eta|e_beta|s_eps|s_alpha|s_eta|s_gam|s_delta|s_beta|c_eps|c_eta|c_beta|c_gam|c_delta|e_gam|e_delta)'
    text = re.sub(rf'{prefixes}_(\d+)(?:_\[STATE_ID\](?:_\[STATE_ID\])?)?', r'\1 [STATE_ID]', text)

    text = re.sub(rf'\\({prefixes})_(\d+)(?:_\[STATE_ID\](?:_\[STATE_ID\])?)?', r'\\\1 [STATE_ID]', text)

    text = re.sub(r'(j|i|k|l)_(\d+)', r'\1 [STATE_ID]', text)

    text = re.sub(r'\[STATE_ID\]_\[STATE_ID\]', '[STATE_ID]', text)

    text = re.sub(r'\s+', ' ', text).strip()

    return text


In [13]:
df['text'] = df['text'].apply(replace_state_ids)
df['label'] = df['label'].apply(replace_state_ids)

preprocessed_csv_path = "preprocessed_data.csv"
df.to_csv(preprocessed_csv_path, index=False)
print(f"Preprocessed CSV saved to {preprocessed_csv_path}")

Preprocessed CSV saved to preprocessed_data.csv


In [14]:

def hep_pre_tokenize(text):
    tokens = []
    i = 0
    while i < len(text):
        matched = False
        # Skip the "to" keyword
        if text[i:i+2].lower() == "to":
            i += 2
            continue
        if text[i] == '[':
            end = text.find(']', i)
            if end != -1:
                token = text[i:end+1]
                if token == '[STATE_ID]':
                    tokens.append(token)
                    i = end + 1
                    matched = True
        if not matched and text[i:i+6].startswith('gamma_'):
            start = i
            i += 6
            if i < len(text) and text[i] == '{':
                i += 1
                brace_count = 1
                content = []
                while i < len(text) and brace_count > 0:
                    if text[i] == '{':
                        brace_count += 1
                    elif text[i] == '}':
                        brace_count -= 1
                    if brace_count > 0:
                        content.append(text[i])
                    i += 1
                tokens.append('gamma')
                tokens.append('{')
                content_str = ''.join(content)
                content_parts = []
                j = 0
                while j < len(content_str):
                    if content_str[j:j+3] == '%\\':
                        end = j + 3
                        while end < len(content_str) and content_str[end].isalpha():
                            end += 1
                        latex_index = content_str[j:end]
                        content_parts.append(latex_index)
                        j = end
                    elif content_str[j] == '[':
                        end = content_str.find(']', j)
                        if end != -1:
                            content_parts.append(content_str[j:end+1])
                            j = end + 1
                        else:
                            j += 1
                    elif content_str[j] == '+':
                        content_parts.append('+')
                        j += 1
                    else:
                        j += 1
                for part in content_parts:
                    tokens.append(part)
                tokens.append('}')
                matched = True

        if not matched and text[i:i+2].startswith('A_'):
            start = i
            i += 2
            if i < len(text) and text[i] == '\\':
                i += 1
                symbol_match = re.match(r'[a-zA-Z]+', text[i:])
                if symbol_match:
                    token = symbol_match.group(0)
                    tokens.append(f'A_{token}')
                    i += len(token)
                    matched = True

        if not matched and text[i:i+2].startswith('e_'):
            start = i
            i += 2
            if i < len(text) and text[i] == '{':
                i += 1
                brace_count = 1
                content = []
                while i < len(text) and brace_count > 0:
                    if text[i] == '{':
                        brace_count += 1
                    elif text[i] == '}':
                        brace_count -= 1
                    if brace_count > 0:
                        content.append(text[i])
                    i += 1
                tokens.append('e')
                tokens.append('{')
                content_str = ''.join(content)
                content_parts = content_str.split()
                for part in content_parts:
                    tokens.append(part)
                tokens.append('}')
                if i + 2 <= len(text) and text[i:i+2] == '_u':
                    tokens.append('_u')
                    i += 2
                elif i + 2 <= len(text) and text[i:i+2] == '_v':
                    tokens.append('_v')
                    i += 2
                if i + 4 <= len(text) and text[i:i+4] == '^(*)':
                    tokens.append('^')
                    tokens.append('(')
                    tokens.append('*')
                    tokens.append(')')
                    i += 4
                matched = True

        if not matched:
            operators = r'(\+|-|\*|/|\^|\(|\)|\[|\]|\{|\})'
            match = re.match(operators, text[i:])
            if match:
                token = match.group(0)
                tokens.append(token)
                i += len(token)
                matched = True

        if not matched:
            symbol_match = re.match(r'[a-zA-Z_][a-zA-Z0-9_]*(?:\[[a-zA-Z0-9_]+\])?', text[i:])
            if symbol_match:
                token = symbol_match.group(0)
                if token.endswith('_') and i + len(token) < len(text):
                    next_part = re.match(r'[a-zA-Z0-9_]+', text[i + len(token):])
                    if next_part:
                        token += next_part.group(0)
                tokens.append(token)
                i += len(token)
                matched = True

        if not matched and text[i].isdigit():
            num_start = i
            while i < len(text) and text[i].isdigit():
                i += 1
            num = text[num_start:i]
            tokens.extend(list(num))
            matched = True

        if not matched and text[i:i+8].lower() == 'antipart':
            tokens.append('AntiPart')
            i += 8
            matched = True

        if not matched:
            i += 1

    return tokens

In [16]:
all_texts = df['text'].tolist() + df['label'].tolist()
all_tokens = set()
token_counter = Counter()

for text in all_texts:
    if text:
        tokens = hep_pre_tokenize(text)
        all_tokens.update(tokens)
        token_counter.update(tokens)

print(f"\nNumber of unique tokens in the dataset: {len(all_tokens)}")
print(f"Sample tokens: {list(all_tokens)[:20]}")
print("\nMost common tokens (top 20):")
for token, count in token_counter.most_common(20):
    print(f"{token}: {count}")

with open("../QED_data/unique_tokens.txt", "w") as f:
    for token in sorted(all_tokens):
        f.write(f"{token}\n")
print("Unique tokens saved to unique_tokens.txt")


Number of unique tokens in the dataset: 139
Sample tokens: ['1', 'tt_', 'u', 'c_gam_[STATE_ID]', '-', '(', 'b_beta', 's', 'mu_alpha_[STATE_ID]', ')', '_u', 'gam', 'u_', 'c_eta_[STATE_ID]', 'eta_[STATE_ID]', 'reg_prop', 'b', 'rho', 'tt_alpha', 't_beta']

Most common tokens (top 20):
*: 600380
[STATE_ID]: 527366
): 500040
(: 489096
+: 281402
X: 276216
^: 247371
2: 231819
-: 152129
{: 141002
}: 141002
/: 112763
1: 86884
e: 77376
i: 70812
AntiPart: 62208
8: 55607
reg_prop: 51264
gamma: 44234
s_23: 43632
Unique tokens saved to unique_tokens.txt
