Tokenizer

In [None]:
import os
import re
import math
from google.colab import drive

# --- Google Drive Mounting ---
# Mount Google Drive at /content/drive; force_remount=True is passed so the
# mount is requested again even when this cell is re-run.
drive.mount('/content/drive', force_remount=True)

# --- Path Configuration ---
# Folder whose *.txt files are scanned for "Interaction:" lines by the driver below.
DRIVE_FOLDER = '/content/drive/My Drive/Data Files No Momentum Permutation May 2025-selected'
# Single text file that receives both the processing log and the cleaned expressions.
OUTPUT_FILE_PATH = '/content/drive/My Drive/processed_expressions_output_final12.txt'

# Debugging cap: the driver stops early once this many lines have reported
# cleaning issues (counted across all files).
MAX_TOTAL_ERROR_LINES = 1000 

# --- Physics Profile Definitions ---
def get_physics_profile(theory):
    """Return the symbol profile for a theory name, or None if it is unknown.

    The lookup is an exact, case-sensitive match on "QED", "EW" or "QCD".
    A fresh dictionary is built on every call, so callers may mutate the
    result freely. The returned profile holds the keys ``particles``,
    ``constants``, ``operators``, ``functions``, ``special_terms`` and
    ``name`` (the theory string that was looked up).
    """
    def _make_profile(particles, constants, special_terms):
        # Operators and functions are identical for every theory; each profile
        # still gets its own list objects.
        return {
            "particles": particles,
            "constants": constants,
            "operators": ["+", "-", "*", "/", "**", "^"],
            "functions": ["sin", "cos", "exp", "log", "sqrt", "cbrt", "atan2", "erf", "gamma", "digamma"],
            "special_terms": special_terms,
        }

    table = {
        "QED": _make_profile(
            ["e", "mu", "tau", "A"],
            ["e", "m_e", "m_mu", "m_tau", "reg_prop"],
            ["P_L", "P_R", "gamma"],
        ),
        "EW": _make_profile(
            ["e", "mu", "tau", "nue_L", "numu_L", "nutau_L", "u", "d", "c", "s", "t", "b", "A", "Z", "W", "h"],
            ["e", "v", "theta_W", "m_W", "m_Z", "m_h", "m_u", "m_d", "m_c", "m_s", "m_t", "m_b", "reg_prop", "gs"],
            ["P_L", "P_R", "gamma"],
        ),
        "QCD": _make_profile(
            ["u", "d", "c", "s", "t", "b", "G"],
            ["gs", "g", "m_u", "m_d", "m_c", "m_s", "m_t", "m_b", "reg_prop"],
            ["gamma", "T_C"],
        ),
    }

    chosen = table.get(theory)
    if chosen is None:
        return None
    # Tag the profile with the theory it belongs to.
    chosen['name'] = theory
    return chosen


# --- Core Processing Functions ---

def clean_physics_expression(line: str, filename: str, line_num: int) -> tuple[str, list]:
    """Flatten one "Interaction: ..." line into a space-separated expression.

    The line is split into a prose description and a formula, the trailing
    numeric result (and any "Error evaluating combination" tail) is dropped,
    LaTeX-like index notation is rewritten into flat identifiers, and every
    operator/delimiter is padded with single spaces.

    Args:
        line: Raw text line as read from an input file.
        filename: Name of the source file. Not used by the cleaning logic;
            kept so existing callers keep working.
        line_num: Line number within the source file. Not used by the
            cleaning logic; kept so existing callers keep working.

    Returns:
        A ``(cleaned_line, unrecognized_snippets)`` tuple. When the line does
        not contain the ``"Interaction: "`` marker the stripped input and an
        empty list are returned. Each snippet entry is a dict with the keys
        ``snippet``, ``index`` (position in ``cleaned_line``) and ``type``;
        after five entries a final ``{'snippet': '...', 'index': -1,
        'type': 'truncated'}`` entry is appended and scanning stops.
    """
    raw_expression = line.strip()
    unrecognized_snippets_in_line = []

    expression_start_marker = "Interaction: "
    if expression_start_marker not in raw_expression:
        return raw_expression, []

    # --- Step 1: isolate the description and the mathematical formula ---

    # 1. Drop the first "Interaction: " marker.
    temp_line = raw_expression.replace(expression_start_marker, '', 1).strip()

    # 2. Remove the final numerical result from the end of the line.
    # The pattern looks for " : " followed by a number/fraction (optionally with
    # an exponent) or an 'i' term, plus trailing characters up to end of line.
    final_numerical_result_pattern = r'\s*:\s*([-+]?(?:(?:\d+\.?\d*(?:[eE][+-]?\d+)?(?:/\d+\.?\d*)?)|(?:[iI])(?:\*(?:\d+\.?\d*(?:[eE][+-]?\d+)?(?:/\d+\.?\d*)?))?)\s*(?:[^\s\w]|\s[^\s\w]|\w|\.)*)$'

    match_final_result = re.search(final_numerical_result_pattern, temp_line)
    if match_final_result:
        temp_line = temp_line[:match_final_result.start()].strip()

    # 3. Remove " : Error evaluating combination:..." (last occurrence) if present.
    error_marker = " : Error evaluating combination:"
    error_idx = temp_line.rfind(error_marker)
    if error_idx != -1:
        temp_line = temp_line[:error_idx].strip()

    # At this point temp_line is expected to look like
    # `[Initial Description] : [Vertex/OffShell/Particle Prose] : [Mathematical Formula]`.
    # The formula is located by the first colon that is followed by something
    # that looks like the start of an algebraic expression: an optional sign,
    # then digits, 'i', 'e' or '(', then an operator or '('.
    formula_start_marker_in_middle = r'\s*:\s*([-+]?\s*(?:\d+\s*|i\s*|e\s*|\()\s*(?:/|\*\*|\*|\+|\-|\().*)'

    formula_match_in_middle = re.search(formula_start_marker_in_middle, temp_line)

    if formula_match_in_middle:
        # Description: everything before the matched colon (and its leading spaces).
        description_part = temp_line[:formula_match_in_middle.start()].strip()
        # Formula: the captured math content after the colon.
        formula_part = formula_match_in_middle.group(1).strip()
    else:
        # Fallback: no recognizable formula start, so treat everything as description.
        description_part = temp_line.strip()
        formula_part = ""

    # Strip vertex/off-shell/anti-particle prose and index decorations from the description.
    description_part = re.sub(r'Vertex V_\d+:[^,]+,?\s*', '', description_part)
    description_part = re.sub(r'OffShell\s+\w+(?:\[\w+\])?(?:,\s*)?', '', description_part)
    description_part = re.sub(r'AntiPart\s+\w+(?:\[\w+\])?(?:,\s*)?', '', description_part)
    description_part = re.sub(r'\s*to\s*', ' to ', description_part)
    description_part = re.sub(r'\([A-Za-z]_\d+\)', '', description_part)
    description_part = description_part.replace('(X)', '')
    description_part = description_part.replace(':', ' ').strip()  # Remove any remaining colons
    description_part = re.sub(r'\s+', ' ', description_part).strip()  # Collapse multiple spaces

    # Re-join description and formula with a single, clean colon.
    if formula_part:
        cleaned_line = f"{description_part} : {formula_part}".strip()
    else:
        # No formula was extracted: the line is just the cleaned description.
        cleaned_line = description_part.strip()

    # --- Step 2: flatten LaTeX-like symbols ---
    # These rules run over the whole `cleaned_line`, but mostly affect the formula.

    # '^(*)' marks conjugation and must be handled before '^' becomes the power operator.
    cleaned_line = cleaned_line.replace('^(*)', ' conj ')
    cleaned_line = cleaned_line.replace('^', ' ** ')

    def process_gamma_indices(match):
        # gamma_{...}: flatten the index list into gamma_<idx>_<idx>..., using a
        # gamma_plus_ prefix when the index list starts with '+' or '\+'.
        content = match.group(1)
        is_plus = False
        if content.startswith('\\+') or content.startswith('+'):
            is_plus = True
            content = content[2:] if content.startswith('\\+') else content[1:]

        content = re.sub(r'[\\%]([a-zA-Z]+)_(\d+)', r'\1\2', content)
        content = content.replace(',', '_')
        content = content.replace('__', '_').strip('_')

        if is_plus:
            return f"gamma_plus_{content}"
        else:
            return f"gamma_{content}"

    cleaned_line = re.sub(r'gamma_\{([^{}]+)\}', process_gamma_indices, cleaned_line)

    def process_general_subscript_indices(match):
        # Generic X_{...} subscripts such as A_{i_3,+sigma_166}.
        prefix = match.group(1)
        content = match.group(2)

        content = re.sub(r'[\\%]([a-zA-Z]+)_(\d+)', r'\1\2', content)
        content = content.replace(',', '_').replace('+', '_plus_')
        content = content.replace('__', '_').strip('_')

        return f"{prefix}_{content}"

    cleaned_line = re.sub(r'([A-Za-z])_\{([^{}]+)\}', process_general_subscript_indices, cleaned_line)

    # Handle P_L_{del_633,eps_289} -> P_L_del633_eps289
    cleaned_line = re.sub(r'P_([LR])_\{([a-zA-Z]+)_(\d+),([a-zA-Z]+)_(\d+)\}',
                          r'P_\1_\2\3_\4\5', cleaned_line)

    # Handle p_X_+sigma_NUMBER: p_5_+sigma_1326 -> p_5_plus_sigma1326
    cleaned_line = re.sub(r'p_(\d+)_?\+([a-zA-Z]+)_(\d+)', r'p_\1_plus_\2\3', cleaned_line)

    # Handle p_X_greek_NUMBER: p_1_rho_304 -> p_1_rho304
    cleaned_line = re.sub(r'p_(\d+)_([a-zA-Z]+)_(\d+)', r'p_\1_\2\3', cleaned_line)

    # Handle a_i_X pattern (e.g. A_i_3) -> a_i3
    cleaned_line = re.sub(r'([a-zA-Z])_([a-zA-Z])_(\d+)', r'\1_\2\3', cleaned_line)

    # Handle (p_#)_u/v patterns (e.g., (p_1)_u -> p1_u)
    cleaned_line = re.sub(r'\(p_(\d+)\)_([uv])', r'p\1_\2', cleaned_line)
    # Attach the conj suffix: gam_248_u conj -> gam_248_u_conj
    cleaned_line = re.sub(r'([a-zA-Z]+_\d+)_([uv])\s*(conj)', r'\1_\2_\3', cleaned_line)
    # Generic conj without u/v
    cleaned_line = re.sub(r'([a-zA-Z]+_\d+)\s*(conj)', r'\1_conj', cleaned_line)

    # Remove any remaining backslashes, percent signs and curly braces.
    cleaned_line = cleaned_line.replace('\\', '')
    cleaned_line = cleaned_line.replace('%', '')
    cleaned_line = cleaned_line.replace('{', '')
    cleaned_line = cleaned_line.replace('}', '')

    # Pad operators and delimiters with spaces in ONE pass. Replacing '**' and
    # then '*' one after the other would split every power operator into
    # '* *'; matching '**' first in a single alternation keeps it intact.
    cleaned_line = re.sub(r'\*\*|[*/+\-=()\[\]]',
                          lambda op_match: f' {op_match.group(0)} ',
                          cleaned_line)

    # Any comma still present is a list/prose separator, so turn it into a space.
    cleaned_line = cleaned_line.replace(',', ' ')
    cleaned_line = re.sub(r'\s+', ' ', cleaned_line).strip()  # Collapse multiple spaces

    # --- Step 3: report characters that survived every cleaning rule ---
    # Anything outside identifiers, numbers, whitespace, operators, brackets,
    # '=' and ':' is reported. The pattern never matches whitespace, so every
    # hit is a real, non-blank character.
    problematic_character_pattern = r'[^a-zA-Z0-9_.\s+\-*/()\[\]=:]'

    for bad_char in re.finditer(problematic_character_pattern, cleaned_line):
        unrecognized_snippets_in_line.append({
            'snippet': bad_char.group(0),
            'index': bad_char.start(),
            'type': 'uncleaned_char'
        })
        # Cap the report at five real entries plus one truncation marker.
        if len(unrecognized_snippets_in_line) >= 5:
            unrecognized_snippets_in_line.append({'snippet': '...', 'index': -1, 'type': 'truncated'})
            break

    return cleaned_line, unrecognized_snippets_in_line


# --- Main Execution Loop ---
# Driver: cleans every "Interaction:" line of every .txt file in DRIVE_FOLDER,
# writes cleaned lines and a log to OUTPUT_FILE_PATH, and prints summaries.
if __name__ == "__main__":
    all_processed_data = []                 # one dict per line cleaned without issues
    all_unrecognized_snippets_summary = {}  # snippet -> number of occurrences
    total_error_lines_count = 0             # lines with cleaning issues, across all files

    print(f"Attempting to list files in: {DRIVE_FOLDER}")

    try:
        # NOTE(review): makedirs creates the folder when it is missing, so a
        # missing input folder normally yields an empty file list rather than
        # reaching the FileNotFoundError branch below -- confirm this is intended.
        os.makedirs(DRIVE_FOLDER, exist_ok=True)
        expression_files = [
            os.path.join(DRIVE_FOLDER, f)
            for f in os.listdir(DRIVE_FOLDER)
            if f.endswith('.txt')
        ]
        print(f"Found {len(expression_files)} .txt files.")
    except FileNotFoundError:
        print(f"Error: The folder '{DRIVE_FOLDER}' was not found. Please ensure it exists on your Google Drive.")
        expression_files = []

    with open(OUTPUT_FILE_PATH, 'w', encoding='utf-8') as outfile:
        print(f"Output logs and cleaned expressions will be saved to: {OUTPUT_FILE_PATH}\n", file=outfile)
        print(f"Starting processing of files from: {DRIVE_FOLDER}\n")

        # Set once MAX_TOTAL_ERROR_LINES is reached; stops both the line loop and the file loop.
        break_file_loop = False
        for filepath in expression_files:
            if break_file_loop:
                break

            print(f"\nProcessing file: {os.path.basename(filepath)}")
            print(f"\n# --- Processing file: {filepath} --- #", file=outfile)

            lines_in_current_file = 0   # counts EVERY line read, not only Interaction lines
            errors_in_current_file = 0  # Interaction lines that reported cleaning issues

            try:
                # Metadata is derived from the file name only.
                filename = os.path.basename(filepath)
                physics_theory = "Unknown"
                # Plain substring checks, first match wins (QED, then EW, then QCD).
                if "QED" in filename.upper(): physics_theory = "QED"
                elif "EW" in filename.upper(): physics_theory = "EW"
                elif "QCD" in filename.upper(): physics_theory = "QCD"

                process_type = "unknown"
                calculation_level = "unknown"
                diagram_id = "n/a"

                # File names are split on '-' after lower-casing and dropping the extension.
                filename_base = os.path.splitext(filename)[0]
                filename_parts = filename_base.lower().split('-')

                # "...-treelevel-<digits>-..." -> level 'tree' and an optional diagram id.
                try:
                    level_index = filename_parts.index('treelevel')
                    calculation_level = 'tree'
                    if level_index + 1 < len(filename_parts) and filename_parts[level_index + 1].isdigit():
                        diagram_id = filename_parts[level_index + 1]
                except ValueError:
                    pass

                # "...-<a>-to-<b>-..." -> process type "<a>-to-<b>" (needs a part on both sides).
                try:
                    to_index = filename_parts.index('to')
                    if to_index > 0 and to_index < len(filename_parts) - 1:
                        process_type = f"{filename_parts[to_index-1]}-to-{filename_parts[to_index+1]}"
                except ValueError:
                    pass

                metadata_log = f"    Identified Theory: {physics_theory}, Process: {process_type}, Level: {calculation_level}, ID: {diagram_id}"
                print(metadata_log)
                print(metadata_log, file=outfile)

                # Undecodable bytes are silently dropped (errors='ignore').
                with open(filepath, 'r', encoding='utf-8', errors='ignore') as f:
                    for line_num, line in enumerate(f):
                        lines_in_current_file += 1
                        if "Interaction:" in line:
                            cleaned_expression, unrecognized_snippets = clean_physics_expression(line, filename, line_num + 1)

                            if unrecognized_snippets:
                                # Lines with leftover characters are logged and counted,
                                # but their cleaned text is NOT written or stored.
                                total_error_lines_count += 1
                                errors_in_current_file += 1
                                first_problematic_snippet_info = unrecognized_snippets[0]
                                error_msg = f"    Cleaning issues on line {line_num + 1} in {filename}. First problematic snippet: '{first_problematic_snippet_info['snippet']}'."
                                print(error_msg)
                                print(error_msg, file=outfile)
                                for snippet_info in unrecognized_snippets:
                                    snippet_key = snippet_info['snippet'].split(' ')[0]
                                    all_unrecognized_snippets_summary[snippet_key] = all_unrecognized_snippets_summary.get(snippet_key, 0) + 1

                                if total_error_lines_count >= MAX_TOTAL_ERROR_LINES:
                                    print(f"\nStopping processing early due to exceeding {MAX_TOTAL_ERROR_LINES} lines with cleaning errors.")
                                    break_file_loop = True
                                    break
                            else:
                                all_processed_data.append({
                                    'theory': physics_theory,
                                    'process_type': process_type,
                                    'calculation_level': calculation_level,
                                    'diagram_id': diagram_id,
                                    'filename': filename,
                                    'line_number': line_num + 1,
                                    'original_line': line.strip(),
                                    'cleaned_expression': cleaned_expression,
                                })
                                print(f"    Cleaned line {line_num + 1}: {cleaned_expression}", file=outfile)
                        else:
                            # Non-interaction lines are copied through (stripped) unchanged.
                            outfile.write(line.strip() + '\n')

                        if break_file_loop:
                            break
            except Exception as e:
                # Best effort: a failure in one file is logged and the run continues.
                error_msg = f"Failed to read or process file {filepath}: {e}"
                print(error_msg)
                print(error_msg, file=outfile)

            # NOTE(review): "successful" here is every line read minus the lines with
            # cleaning issues, so pass-through (non-Interaction) lines count as successes.
            file_successful_lines_count = lines_in_current_file - errors_in_current_file
            file_success_percentage = (file_successful_lines_count / lines_in_current_file) * 100 if lines_in_current_file > 0 else 0
            print(f"    File Summary: {file_successful_lines_count}/{lines_in_current_file} lines processed successfully ({file_success_percentage:.2f}%).")
            print(f"    File Summary: {file_successful_lines_count}/{lines_in_current_file} lines processed successfully ({file_success_percentage:.2f}%).", file=outfile)

        summary_header = "\n\n--- Processing Summary ---"
        print(summary_header)
        print(summary_header, file=outfile)

        # Second pass over all files to count Interaction lines. After an early
        # stop this still counts lines that were never processed, which lowers
        # the reported success rate.
        total_interaction_lines_in_files = 0
        for filepath in expression_files:
            try:
                with open(filepath, 'r', encoding='utf-8', errors='ignore') as f:
                    for line in f:
                        if "Interaction:" in line:
                            total_interaction_lines_in_files += 1
            except Exception:
                # Unreadable files simply contribute nothing to the total.
                pass

        total_lines_attempted_cleaning = total_interaction_lines_in_files
        summary_successful_lines = len(all_processed_data)
        summary_error_lines = total_error_lines_count

        global_success_percentage = 0
        if total_lines_attempted_cleaning > 0:
            global_success_percentage = (summary_successful_lines / total_lines_attempted_cleaning) * 100

        # Totals go to stdout first, then the same four lines go to the log file.
        print(f"Total interaction lines attempted to clean: {total_lines_attempted_cleaning}")
        print(f"Total successfully cleaned interaction lines: {summary_successful_lines}")
        print(f"Total interaction lines with cleaning errors: {summary_error_lines}")
        print(f"Overall Interaction Cleaning Success Rate: {global_success_percentage:.2f}%")
        print(f"Total interaction lines attempted to clean: {total_lines_attempted_cleaning}", file=outfile)
        print(f"Total successfully cleaned interaction lines: {summary_successful_lines}", file=outfile)
        print(f"Total interaction lines with cleaning errors: {summary_error_lines}", file=outfile)
        print(f"Overall Interaction Cleaning Success Rate: {global_success_percentage:.2f}%", file=outfile)


        if all_unrecognized_snippets_summary:
            # The section header below is printed to stdout only; the per-snippet
            # lines and the closing hint go to both stdout and the log file.
            print("\n--- Summary of Unrecognized Snippets (Top 10) ---")
            sorted_snippets = sorted(all_unrecognized_snippets_summary.items(), key=lambda item: item[1], reverse=True)
            for snippet, count in sorted_snippets[:10]:
                print(f"Snippet: '{snippet}' - Occurrences: {count}")
                print(f"Snippet: '{snippet}' - Occurrences: {count}", file=outfile)

            print("\nTo improve cleaning, add new `re.sub` rules in `clean_physics_expression` for these snippets.")
            print("\nTo improve cleaning, add new `re.sub` rules in `clean_physics_expression` for these snippets.", file=outfile)


        if all_processed_data:
            # Console-only preview of the first three stored entries (truncated to 120 chars).
            print("\nFirst 3 successfully cleaned interactions (preview):")
            for i, entry in enumerate(all_processed_data[:3]):
                print("-" * 30)
                print(f"Interaction {i+1}:")
                print(f"    File: {entry['filename']} (Line: {entry['line_number']})")
                print(f"    Details: {entry['theory']}, Process={entry['process_type']}, Level={entry['calculation_level']}, ID={entry['diagram_id']}")
                print(f"    Original: {entry['original_line'][:120]}...")
                print(f"    Cleaned: {entry['cleaned_expression'][:120]}...")
                print("-" * 30)

    print(f"\nFinished. All processing logs and cleaned expressions have been saved to {OUTPUT_FILE_PATH}")

Token Embedding Layer & the Mamba Blocks for Symbolic Expression Encoder

In [None]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import os
import re
import collections

# --- MambaBlock definition with causal padding and cropping ---
class MambaBlock(nn.Module):
    """Gated, depthwise-causal-convolution block operating on (batch, seq, d_model).

    The input is projected to two branches of width ``expand * d_model``.
    One branch goes through a depthwise Conv1d and two SiLU activations; the
    other acts as a SiLU gate. Their product is projected back to ``d_model``.

    Note: the parameters ``A``, ``B`` and ``C`` are created (and therefore
    appear in ``state_dict``) but are not referenced by ``forward``.
    """

    def __init__(self, d_model, d_state=16, d_conv=4, expand=2):
        super().__init__()
        self.d_model = d_model
        self.d_state = d_state
        self.d_conv = d_conv
        self.expand = expand

        inner_dim = expand * d_model

        # One projection produces both the conv branch and the gate branch.
        self.in_proj = nn.Linear(d_model, inner_dim * 2)

        # Depthwise convolution (groups == channels). Conv1d pads BOTH ends with
        # d_conv - 1 zeros, so its output is d_conv - 1 steps longer than the
        # input; forward() keeps only the leading part, which makes it causal.
        self.conv_layer = nn.Conv1d(
            in_channels=inner_dim,
            out_channels=inner_dim,
            kernel_size=d_conv,
            groups=inner_dim,
            padding=d_conv - 1,
            bias=False,
        )
        self.silu = nn.SiLU()
        self.out_proj = nn.Linear(inner_dim, d_model)

        # State-space parameters; currently unused by forward().
        self.A = nn.Parameter(torch.randn(inner_dim, d_state))
        self.B = nn.Linear(inner_dim, d_state)
        self.C = nn.Linear(inner_dim, d_state)

    def forward(self, x):
        """Map x of shape (batch, seq_len, d_model) to a tensor of the same shape."""
        n_steps = x.shape[1]

        conv_branch, gate = self.in_proj(x).chunk(2, dim=-1)

        # Conv1d expects (batch, channels, length).
        conv_out = self.conv_layer(conv_branch.permute(0, 2, 1))

        # Keep the first n_steps outputs: output t then depends only on
        # inputs t - (d_conv - 1) .. t, i.e. never on future positions.
        conv_out = conv_out[..., :n_steps].permute(0, 2, 1)

        activated = self.silu(conv_out)
        # SiLU is applied a second time where a state-space scan would go.
        ssm_output = self.silu(activated)

        gated = ssm_output * self.silu(gate)
        return self.out_proj(gated)


class SymbolicExpressionEncoder(nn.Module):
    """Token + learned positional embeddings followed by a stack of MambaBlocks.

    ``forward`` returns both the per-token hidden states and their mean over
    the sequence dimension. The mean is taken over every position, including
    any padding positions present in ``input_ids``.
    """

    def __init__(self, vocab_size, d_model, num_mamba_layers, max_seq_len, padding_token_id):
        super().__init__()
        self.token_embedding = nn.Embedding(vocab_size, d_model, padding_idx=padding_token_id)
        # Learned absolute positions; sequences longer than max_seq_len cannot be indexed.
        self.positional_encoding = nn.Embedding(max_seq_len, d_model)

        blocks = []
        for _ in range(num_mamba_layers):
            blocks.append(MambaBlock(d_model))
        self.mamba_layers = nn.ModuleList(blocks)

        self.d_model = d_model
        self.max_seq_len = max_seq_len

    def forward(self, input_ids):
        """Encode input_ids of shape (batch, seq_len).

        Returns:
            (hidden_states, pooled): hidden states of shape
            (batch, seq_len, d_model) and their mean over dim 1.
        """
        n_positions = input_ids.size(1)
        position_ids = torch.arange(n_positions, device=input_ids.device).unsqueeze(0)

        # The (1, seq_len, d_model) positional term broadcasts over the batch.
        hidden = self.token_embedding(input_ids) + self.positional_encoding(position_ids)

        for block in self.mamba_layers:
            hidden = block(hidden)

        # Mean pooling for the final embedding.
        pooled = hidden.mean(dim=1)
        return hidden, pooled

# --- Custom Dataset and DataLoader ---
class TextClassificationDataset(Dataset):
    """Index-aligned pairing of token-id sequences with their labels.

    ``input_ids`` and ``labels`` may be any indexable containers; item ``i``
    is returned as ``{'input_ids': input_ids[i], 'labels': labels[i]}``.
    The dataset length is the length of ``input_ids``.
    """

    def __init__(self, input_ids, labels):
        self.input_ids, self.labels = input_ids, labels

    def __len__(self):
        n_samples = len(self.input_ids)
        return n_samples

    def __getitem__(self, idx):
        sample = {}
        sample['input_ids'] = self.input_ids[idx]
        sample['labels'] = self.labels[idx]
        return sample

# --- Preprocessing output location and tokenizer constants ---
# NOTE: despite its name, `output_dir` holds the path of a single text FILE.
output_dir = "/content/processed_expressions_output_final12 (3).txt"
# Both names below point at that same file. Neither is opened or parsed in
# this cell; the tokenizer constants that follow are hard-coded instead.
# NOTE(review): confirm the values below match the real vocabulary/preprocessing.
vocabulary_file = output_dir # If vocabulary is embedded, parse it
dataset_summary_file = output_dir # If summary is embedded, parse it

vocab_size = 71
padding_token_id = 0
unknown_token_id = 1
number_token_id = 4
math_token_id = 5
max_seq_len = 555 

# --- Report the hard-coded tokenizer constants ---

print(f"Vocabulary size: {vocab_size}")
print(f"Padding Token ID: {padding_token_id}")
print(f"Unknown Token ID: {unknown_token_id}")
print(f"Number Token ID: {number_token_id}")
print(f"Math Token ID: {math_token_id}")
print(f"Max Sequence Length (from preprocessing): {max_seq_len}")

# Split sizes and batch size used for the placeholder data below.
train_dataset_size = 88468
val_dataset_size = 29490
test_dataset_size = 29490
batch_size = 2

# Placeholder data: token ids and binary labels are drawn at RANDOM with
# torch.randint; no preprocessed expressions are loaded here.
# NOTE(review): each id tensor is (dataset_size, max_seq_len); assuming
# torch.randint's default int64 dtype, the train tensor alone is ~0.39 GB.
dummy_train_input_ids = torch.randint(0, vocab_size, (train_dataset_size, max_seq_len))
dummy_train_labels = torch.randint(0, 2, (train_dataset_size,))

dummy_val_input_ids = torch.randint(0, vocab_size, (val_dataset_size, max_seq_len))
dummy_val_labels = torch.randint(0, 2, (val_dataset_size,))

dummy_test_input_ids = torch.randint(0, vocab_size, (test_dataset_size, max_seq_len))
dummy_test_labels = torch.randint(0, 2, (test_dataset_size,))


# Create PyTorch Dataset objects
train_dataset = TextClassificationDataset(dummy_train_input_ids, dummy_train_labels)
val_dataset = TextClassificationDataset(dummy_val_input_ids, dummy_val_labels)
test_dataset = TextClassificationDataset(dummy_test_input_ids, dummy_test_labels)

# Create PyTorch DataLoaders (only the training loader shuffles)
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_dataloader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

print(f"\nTrain dataset size: {len(train_dataset)} samples")
print(f"Validation dataset size: {len(val_dataset)} samples")
print(f"Test dataset size: {len(test_dataset)} samples")
print(f"Train DataLoader created with batch size: {train_dataloader.batch_size}")
print(f"Validation DataLoader created with batch size: {val_dataloader.batch_size}")
print(f"Test DataLoader created with batch size: {test_dataloader.batch_size}")


# --- Initialize the Symbolic Expression Encoder (Mamba) ---
d_model = 768 # Dimensionality of the model's embeddings
num_mamba_layers = 6 # Number of Mamba blocks in the encoder

symbolic_encoder = SymbolicExpressionEncoder(
    vocab_size=vocab_size,
    d_model=d_model,
    num_mamba_layers=num_mamba_layers,
    max_seq_len=max_seq_len,
    padding_token_id=padding_token_id
)

print("\n--- Symbolic Expression Encoder (Mamba) Architecture ---")
print(symbolic_encoder)
# Counts trainable parameters only (requires_grad=True).
print(f"Total parameters in Symbolic Encoder: {sum(p.numel() for p in symbolic_encoder.parameters() if p.requires_grad)}")

# --- Smoke test: run the encoder on the first training batch ---
print("\n--- Demonstrating Symbolic Expression Encoder with a sample batch from Train DataLoader ---")
# Only the first batch is used; the loop exits via `break` right after it.
for batch_idx, batch in enumerate(train_dataloader):
    if batch_idx == 0: # Process only the first batch
        input_ids = batch['input_ids']
        labels = batch['labels']  # fetched but not used in this forward-only check

        print(f"Input IDs shape from DataLoader: {input_ids.shape}")

        # Move model and data to GPU if available
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        symbolic_encoder.to(device)
        input_ids = input_ids.to(device)

        # Forward pass only; no loss or optimizer step is performed.
        mamba_hidden_states, symbolic_expression_embedding = symbolic_encoder(input_ids)
        print(f"Mamba Hidden States shape (on {device}): {mamba_hidden_states.shape} (Batch, Sequence Length, d_model)")
        print(f"Symbolic Expression Embedding shape (on {device}): {symbolic_expression_embedding.shape} (Batch, d_model)")

        break # Exit after the first batch