# TRAE (Transformer Autoencoder) for Expression Embeddings

This notebook trains and tests a contrastive projection head on top of the expression embeddings produced by the TRAE (Transformer Autoencoder) encoder. The notebook:

1. Clones the required repository
2. Installs all necessary dependencies
3. Sets up the environment
4. Runs the TRAE training and testing pipeline

**Note**: Make sure to enable GPU acceleration in Kaggle for optimal performance.

## 1. Clone Repository and Setup Environment

In [None]:
import os
import subprocess  # NOTE(review): not used in this cell
import sys  # NOTE(review): not used in this cell

# Clone the repository
# The clone is skipped when the target directory already exists, so re-running the
# cell does not fail on a non-empty directory.
# NOTE(review): an existing checkout is reused as-is (no `git pull`), so it may be stale.
if not os.path.exists('/kaggle/working/CHEHAB_FHE_Compiler_RL'):
    print("Cloning repository...")
    !git clone https://github.com/Abderraouf-D/CHEHAB_FHE_Compiler_RL.git /kaggle/working/CHEHAB_FHE_Compiler_RL
else:
    print("Repository already exists.")

# Change to the project directory
# The next cell's original `pip install -r RL/requirements.txt` resolves relative to this directory.
os.chdir('/kaggle/working/CHEHAB_FHE_Compiler_RL')
print(f"Current directory: {os.getcwd()}")

## 2. Install Dependencies

In [None]:
# Install core dependencies
print("Installing PyTorch and related packages...")
!pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118

print("\nInstalling additional dependencies...")
!pip install tqdm numpy matplotlib pandas

# Install project-specific requirements first
print("\nInstalling project requirements...")
!pip install -r RL/requirements.txt

# Install additional dependencies that might be needed for pytrs
print("\nInstalling additional dependencies for pytrs...")
!pip install setuptools wheel

# Navigate to the RL directory and install pytrs with all dependencies
print("\nInstalling pytrs package...")
os.chdir('/kaggle/working/CHEHAB_FHE_Compiler_RL/RL')

# Try to install any missing dependencies that pytrs might need
!pip install --upgrade setuptools
!pip install --editable .

# Also try to install from the parent directory structure if needed
print("\nAttempting alternative pytrs installation...")
import sys
sys.path.insert(0, '/kaggle/working/CHEHAB_FHE_Compiler_RL/RL')

print("\nAll dependencies installed successfully!")

## 3. Verify Installation and Check GPU

In [None]:
import os
import sys

import numpy as np
import torch
from tqdm.auto import tqdm

RL_DIR = '/kaggle/working/CHEHAB_FHE_Compiler_RL/RL'
PYTRS_DIR = '/kaggle/working/CHEHAB_FHE_Compiler_RL/RL/pytrs'

print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"CUDA version: {torch.version.cuda}")
    print(f"GPU device: {torch.cuda.get_device_name(0)}")
    print(f"GPU memory: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.2f} GB")
else:
    print("CUDA not available - using CPU")

# Put RL/ and the pytrs folder at the front of sys.path (pytrs folder first), skipping
# entries that are already present so re-running the cell does not accumulate duplicates.
# NOTE(review): putting the package folder itself on sys.path lets its modules shadow
# same-named top-level modules; kept in case pytrs relies on it — confirm.
for extra_path in (RL_DIR, PYTRS_DIR):
    if extra_path not in sys.path:
        sys.path.insert(0, extra_path)


def _print_dir_contents(title, path):
    """Print a heading, then every entry of `path` (sorted) if the directory exists."""
    print(f"\n{title}:")
    if os.path.exists(path):
        for item in sorted(os.listdir(path)):
            print(f"  - {item}")


_print_dir_contents("Contents of RL directory", RL_DIR)
_print_dir_contents("Contents of pytrs directory", PYTRS_DIR)

# Only the real package import is attempted. A bare `import tokenize` would resolve to
# Python's standard-library module, and `import parse_sexpr` would bind a module rather
# than the function, so such "fallbacks" report success while leaving wrong names bound.
try:
    from pytrs import parse_sexpr, tokenize, Op
    print("✅ pytrs package imported successfully from pytrs module")
except ImportError as exc:
    print(f"❌ Failed to import from pytrs module: {exc}")
    print("⚠️ Will try to continue with manual imports in the next cells")
except Exception as exc:
    # e.g. an error raised while executing the package's own code
    print(f"❌ Unexpected error during pytrs import: {exc}")
    print("⚠️ Will try to continue with manual setup")

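## 4. Locate Pre-trained Model (if available)

This cell does not download anything. It only checks whether the pre-trained checkpoints are in the cloned repository or in an attached Kaggle dataset (`/kaggle/input/*/`). If they are in neither location, upload the `.pth` files as a Kaggle dataset and attach it to the notebook. Otherwise the encoder will use randomly initialized weights.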

In [None]:
import glob
import os

# Output directory for models (no-op if it already exists)
os.makedirs('/kaggle/working/trained_models', exist_ok=True)

# Checkpoints that may ship inside the cloned repository
model_paths = [
    '/kaggle/working/CHEHAB_FHE_Compiler_RL/RL/fhe_rl/trained_models/model_Transformer_ddp_10399047_epoch_5000000.pth',
    '/kaggle/working/CHEHAB_FHE_Compiler_RL/RL/fhe_rl/trained_models/projection_head.pth'
]

print("Checking for pre-trained models in repository...")
for candidate in model_paths:
    verdict = "Found model" if os.path.exists(candidate) else "Model not found"
    print(f"{verdict}: {candidate}")

# Glob patterns for the same files attached as a Kaggle dataset (any dataset name)
kaggle_model_paths = [
    '/kaggle/input/*/model_Transformer_ddp_10399047_epoch_5000000.pth',
    '/kaggle/input/*/projection_head.pth'
]

print("\nChecking for pre-trained models in Kaggle datasets...")
for pattern in kaggle_model_paths:
    hits = glob.glob(pattern)
    # Only the first match is reported
    print(f"Found model: {hits[0]}" if hits else f"Model not found: {pattern}")

print("\nNote: The notebook will first try to use models from the repository,")
print("then fall back to Kaggle datasets, or train from scratch if none are available.")

## 5. Create Contrastive Training Data

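Generate contrastive pairs for training the projection head. Each arithmetic operator (`+`, `-`, `*`) in an expression is swapped for each of the other two operators to form negative pairs (label 0). Each expression paired with itself forms a positive pair (label 1).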

In [None]:
# Add the RL directory (and pytrs' own folder) to the Python path.
# Entries already present are skipped so re-running the cell does not add duplicates.
import sys
import os
for _extra_path in ('/kaggle/working/CHEHAB_FHE_Compiler_RL/RL',
                    '/kaggle/working/CHEHAB_FHE_Compiler_RL/RL/pytrs'):
    if _extra_path not in sys.path:
        sys.path.append(_extra_path)

# Try multiple import strategies for pytrs
parse_sexpr = None
try:
    from pytrs import parse_sexpr
    print("✅ Successfully imported parse_sexpr from pytrs")
except ImportError as e:
    print(f"❌ Failed to import from pytrs: {e}")
    try:
        # Execute the package's __init__.py directly in this namespace. The `with`
        # block closes the file even when exec() raises.
        # NOTE(review): this fails if __init__.py uses relative imports.
        with open('/kaggle/working/CHEHAB_FHE_Compiler_RL/RL/pytrs/__init__.py') as _init_file:
            exec(_init_file.read())
        print("✅ Successfully loaded pytrs manually")
    except Exception as e2:
        print(f"❌ Failed manual load: {e2}")
        print("⚠️ Using fallback implementation")

        class MockOp:
            """Minimal stand-in for pytrs.Op: an operator plus its list of arguments."""

            def __init__(self, op, args):
                self.op = op
                self.args = args

            def __str__(self):
                return f"({self.op} {' '.join(str(arg) for arg in self.args)})"

            def __repr__(self):
                return self.__str__()

        def parse_sexpr(expr_str):
            """Minimal fallback s-expression parser, used only when pytrs cannot be loaded.

            Handles arbitrarily nested input such as "(Vec (+ a b) (- c 1))" and returns
            a tree of MockOp nodes whose leaves are the raw atom strings (no int
            conversion, unlike pytrs). A single bare atom is returned as its stripped
            string; an empty/blank string is returned unchanged.

            Raises:
                ValueError: on unbalanced parentheses, an empty "()" group, or extra
                    tokens after the first complete expression.
            """
            tokens = expr_str.replace("(", " ( ").replace(")", " ) ").split()
            if not tokens:
                return expr_str
            cursor = 0

            def parse_node():
                # Recursive descent over `tokens`, advancing the shared cursor.
                nonlocal cursor
                if cursor >= len(tokens):
                    raise ValueError(f"Unbalanced parentheses in {expr_str!r}: missing ')'")
                token = tokens[cursor]
                cursor += 1
                if token == ")":
                    raise ValueError(f"Unexpected ')' in {expr_str!r}")
                if token != "(":
                    return token
                if cursor < len(tokens) and tokens[cursor] == ")":
                    raise ValueError(f"Empty '()' in {expr_str!r}")
                op = parse_node()
                args = []
                while cursor < len(tokens) and tokens[cursor] != ")":
                    args.append(parse_node())
                if cursor >= len(tokens):
                    raise ValueError(f"Unbalanced parentheses in {expr_str!r}: missing ')'")
                cursor += 1  # consume the closing ')'
                return MockOp(op, args)

            tree = parse_node()
            if cursor != len(tokens):
                raise ValueError(f"Trailing tokens after expression in {expr_str!r}")
            return tree

# Operator-swap augmentation: each arithmetic operator is replaced in turn by the two other
# operators, giving (anchor, mutated, 0) negatives and (anchor, anchor, 1) positives.
def augment_expression(expr):
    """
    Generate labelled contrastive pairs from one expression string by swapping operators.

        expr: a string representing the expression

    Every '+', '-' or '*' character of the ORIGINAL string (scanned left to right)
    contributes four tuples, in this order:
        (anchor, anchor with that operator -> first alternative,  0)
        (anchor, anchor, 1)
        (anchor, anchor with that operator -> second alternative, 0)
        (anchor, anchor, 1)
    with alternatives '+' -> ('-', '*'), '-' -> ('+', '*'), '*' -> ('+', '-').

    Substitutions accumulate: once an operator has been processed it keeps its second
    alternative, so the anchor used for the next operator already carries every earlier
    substitution.

    Returns a list of (expr_a, expr_b, label) tuples, where label is 1 for positive pairs
    and 0 for negative pairs; [] when expr contains no operator.
    """
    swaps = {"+": ("-", "*"), "-": ("+", "*"), "*": ("+", "-")}
    current = expr
    result = []
    for idx, ch in enumerate(expr):
        alternatives = swaps.get(ch)
        if alternatives is None:
            continue
        anchor = current
        for replacement in alternatives:
            current = anchor[:idx] + replacement + anchor[idx + 1:]
            result.extend([(anchor, current, 0), (anchor, anchor, 1)])
    return result

# Turn raw expression strings into parsed, labelled contrastive pairs.
def build_contrastive_pairs(expr_strs):
    """
    Build parsed contrastive pairs for a collection of expression strings.

        expr_strs: iterable of s-expression strings

    For each string, augment_expression() yields (a, b, label) string tuples; both sides
    are then parsed with parse_sexpr(). If parsing fails for any pair of an expression,
    a warning is printed and ALL pairs of that expression are dropped; the remaining
    expressions are unaffected.

    Returns a flat list of (parsed_a, parsed_b, label) tuples.
    """
    collected = []
    for source in expr_strs:
        candidates = augment_expression(source)
        try:
            parsed = [(parse_sexpr(lhs), parse_sexpr(rhs), lbl) for lhs, rhs, lbl in candidates]
        except Exception as e:
            print(f"⚠️ Error parsing expression pairs for {source}: {e}")
        else:
            collected.extend(parsed)
    return collected

# Small built-in corpus, written to disk so the rest of the cell follows the same
# "read test.txt -> build pairs -> write contrastive_pairs.txt" flow as demo().
sample_expressions = [
    "(Vec (+ a b) (+ c d) (- f g))",
    "(Vec (- x y) (* p q) (+ m n))",
    "(+ (* a b) (- c d))",
    "(- (+ x y) (* z w))",
    "(* (+ a b) (- c d))",
    "(Vec (+ a 1) (- b 2) (* c 3))",
    "(+ (- a b) (+ c d))",
    "(* (+ x 5) (- y 10))",
    "(Vec (* a a) (+ b b) (- c c))",
    "(- (* x y) (+ z w))"
]

os.makedirs('/kaggle/working/datasets', exist_ok=True)

# One expression per line (mimicking test.txt)
with open('/kaggle/working/datasets/test.txt', 'w') as f:
    f.writelines(f"{expr}\n" for expr in sample_expressions)

print("Created sample expressions dataset file")

# Read it back, skipping blank lines (following the exact demo() function flow)
with open("/kaggle/working/datasets/test.txt") as f:
    expr_strs = [stripped for stripped in (line.strip() for line in f) if stripped]

print(f"Loaded {len(expr_strs)} expressions from dataset")

print("Building contrastive pairs...")
try:
    pairs = build_contrastive_pairs(expr_strs)
    print(f"Generated {len(pairs)} contrastive pairs")

    # One "a | b | label" record per line (exactly as in demo() function)
    with open("/kaggle/working/datasets/contrastive_pairs.txt", "w") as f:
        f.writelines(f"{a} | {b} | {lbl}\n" for a, b, lbl in pairs)

    print("Contrastive pairs saved to /kaggle/working/datasets/contrastive_pairs.txt")
except Exception as e:
    print(f"❌ Error building contrastive pairs: {e}")
    print("⚠️ Will create a minimal pairs file for testing")

    # Fallback: each of the first three samples paired with itself (positive) and with
    # the sample that follows it in the list (negative).
    with open("/kaggle/working/datasets/contrastive_pairs.txt", "w") as f:
        for expr, following in zip(sample_expressions[:3], sample_expressions[1:4]):
            f.write(f"{expr} | {expr} | 1\n")  # positive pair
            f.write(f"{expr} | {following} | 0\n")  # negative pair

    print("Created minimal contrastive pairs file for testing")

## 6. TRAE Model Implementation

Now let's implement the TRAE model and related components.

In [None]:
# Set environment variables for optimal performance
# NOTE(review): PyTorch reads PYTORCH_CUDA_ALLOC_CONF when its CUDA allocator is first
# initialised; torch was already imported (and CUDA queried) in an earlier cell, so this
# may have no effect in the current kernel — TODO confirm / move to the first cell.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

import torch
import torch.nn as nn
import torch.optim as optim  # NOTE(review): `optim` is not used; torch.optim is referenced directly later
import random
import numpy as np
from tqdm.auto import tqdm

# Token-id layout and parsing helpers from pytrs.
# This import is unconditional: if pytrs is not importable this cell raises, regardless
# of any fallback parse_sexpr defined in an earlier cell (and it rebinds parse_sexpr).
# NOTE(review): `tokenize` here shadows the standard-library module name in this namespace.
from pytrs import (
    Op,
    VARIABLE_RANGE,
    CONST_OFFSET,
    PAREN_CLOSE,
    PAREN_OPEN,
    node_to_id,
    parse_sexpr,
    tokenize,
    MAX_INT_TOKENS
)

# Let float32 matmuls use faster reduced-precision internal math (e.g. TF32) where supported.
torch.set_float32_matmul_precision("high")

# Configuration
# Hyper-parameters and special-token ids for the TRAE model. vocab_size and d_model fix
# the embedding / output-layer shapes, which must match the checkpoint passed to
# load_state_dict later (a shape mismatch makes loading fail), so edit with care.
class Config:
    max_gen_length = 25122  # NOTE(review): not referenced by any cell shown here
    # Special tokens sit right after the integer-constant id range:
    #   start = CONST_OFFSET + MAX_INT_TOKENS, end = start + 1, pad = start + 2.
    # The initial vocab_size is start + 4, so id start + 3 is never assigned here.
    vocab_size = CONST_OFFSET + MAX_INT_TOKENS + 2 + 1 + 1
    start_token = CONST_OFFSET + MAX_INT_TOKENS
    end_token = CONST_OFFSET + MAX_INT_TOKENS + 1
    pad_token = CONST_OFFSET + MAX_INT_TOKENS + 2
    # CLS takes the next free id, and the vocabulary grows by one to include it.
    cls_token = vocab_size
    vocab_size += 1  # include CLS
    
    # Transformer architecture
    d_model = 256
    num_heads = 8
    num_encoder_layers = 4
    num_decoder_layers = 4
    dim_feedforward = 512
    transformer_dropout = 0.2
    # Size of the positional-embedding table; also the length limit checked before encoding.
    max_seq_length = 25200
    
    # Training hyper-parameters.
    # NOTE(review): none of these five are read by the cells shown here; projection-head
    # training uses its own hard-coded epochs / learning rate.
    batch_size = 16
    learning_rate = 3e-4
    epochs = 50
    dropout_rate = 0.3
    total_samples = 5000000

config = Config()

# Device setup
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

if torch.cuda.is_available():
    print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.2f} GB")

In [None]:
# Expression processing helpers
def dfs_traverse(expr, depth=0, node_list=None):
    """Pre-order walk of an expression tree into a flat list of (token, depth) tuples.

    An Op node contributes PAREN_OPEN, the node itself, each child (recursively, at
    depth + 1), and finally PAREN_CLOSE; all but the children are tagged with
    ``depth``. Any non-Op value is a leaf and contributes a single entry.

    If ``node_list`` is given it is appended to in place and returned; otherwise a
    fresh list is created.
    """
    acc = [] if node_list is None else node_list
    if not isinstance(expr, Op):
        acc.append((expr, depth))
        return acc
    acc.extend(((PAREN_OPEN, depth), (expr, depth)))
    for child in expr.args:
        dfs_traverse(child, depth + 1, acc)
    acc.append((PAREN_CLOSE, depth))
    return acc

def flatten_expr(expr):
    """Flatten an expression into a list of ``{"node_id": ...}`` dicts in DFS order.

    Parenthesis markers produced by dfs_traverse are kept verbatim as ids; every other
    node is mapped through pytrs' node_to_id. The variable/integer maps and their
    counters (starting at VARIABLE_RANGE[0] and CONST_OFFSET) are fresh on every call,
    so any id assignment node_to_id makes is local to this expression. Depth
    information from the traversal is discarded.
    """
    traversal = dfs_traverse(expr, 0)
    var_ids, int_ids = {}, {}
    next_var, next_int = VARIABLE_RANGE[0], CONST_OFFSET
    flat = []
    for token, _depth in traversal:
        if token in (PAREN_OPEN, PAREN_CLOSE):
            node_id = token
        else:
            node_id, next_var, next_int, _ = node_to_id(
                token, var_ids, int_ids, next_var, next_int
            )
        flat.append({"node_id": node_id})
    return flat

print("Expression processing helpers defined.")

In [None]:
# Positional encoding
class PositionalEncoding(nn.Module):
    """Learned absolute positional embeddings, added to a token-embedding tensor.

    Args:
        d_model: embedding width (last dimension of the input).
        max_len: number of distinct positions the embedding table can represent.
    """

    def __init__(self, d_model, max_len=1024):
        super().__init__()
        # The attribute name is part of the state_dict key ("pos_embedding.weight").
        self.pos_embedding = nn.Embedding(max_len, d_model)

    def forward(self, x, positions=None):
        """Return ``x + pos_embedding(positions)`` for a 3-D input (batch, seq, d_model).

        When ``positions`` is None, every row of the batch uses positions 0..seq-1.
        """
        n_batch, n_steps, _ = x.shape
        if positions is None:
            positions = torch.arange(n_steps, device=x.device).expand(n_batch, n_steps)
        return x + self.pos_embedding(positions)

print("PositionalEncoding defined.")

In [None]:
# Transformer autoencoder
class TransformerAutoencoder(nn.Module):
    """Encoder-decoder Transformer over token-id sequences, with a CLS token prepended
    to every source sequence.

    The sub-module attribute names (token_embedding, pos_encoder, encoder, decoder,
    output_fc) determine the state_dict keys and must stay stable for pre-trained
    checkpoints to load.
    """

    def __init__(self):
        super().__init__()
        self.token_embedding = nn.Embedding(config.vocab_size, config.d_model)
        self.pos_encoder = PositionalEncoding(config.d_model, max_len=config.max_seq_length)

        # Encoder and decoder layers share identical hyper-parameters.
        layer_kwargs = dict(
            d_model=config.d_model,
            nhead=config.num_heads,
            dim_feedforward=config.dim_feedforward,
            dropout=config.transformer_dropout,
            batch_first=True,
        )
        self.encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(**layer_kwargs), config.num_encoder_layers
        )
        self.decoder = nn.TransformerDecoder(
            nn.TransformerDecoderLayer(**layer_kwargs), config.num_decoder_layers
        )
        self.output_fc = nn.Linear(config.d_model, config.vocab_size)

    @staticmethod
    def generate_square_subsequent_mask(sz, device):
        """Causal float32 mask of shape (sz, sz): 0.0 on/below the diagonal, -inf above."""
        return torch.full(
            (sz, sz), float("-inf"), dtype=torch.float32, device=device
        ).triu(diagonal=1)

    def _embed_source(self, src_nodes):
        """Prepend the CLS id to every row of src_nodes, embed the tokens, add positions."""
        cls_ids = torch.full(
            (src_nodes.size(0), 1), config.cls_token, dtype=torch.long, device=src_nodes.device
        )
        tokens = torch.cat([cls_ids, src_nodes], dim=1)
        return self.pos_encoder(self.token_embedding(tokens))

    def forward(self, src_nodes, tgt_seq):
        """Encode src_nodes, then decode tgt_seq causally; returns per-position vocab logits."""
        memory = self.encode(src_nodes)
        tgt = self.pos_encoder(self.token_embedding(tgt_seq))
        causal = self.generate_square_subsequent_mask(tgt_seq.size(1), device=tgt_seq.device)
        return self.output_fc(self.decoder(tgt, memory, tgt_mask=causal))

    def encode(self, src_nodes):
        """Run the encoder; position 0 of the result corresponds to the CLS token."""
        return self.encoder(self._embed_source(src_nodes))

    def get_cls_vector(self, memory):
        """Slice out position 0 (the CLS slot) along dim 1 of encoder memory."""
        return memory[:, 0, :]

print("TransformerAutoencoder defined.")

In [None]:
# TRAE wrapper
class TRAE(nn.Module):
    """Thin wrapper that holds a TransformerAutoencoder under the attribute ``model``.

    The ``model.`` prefix appears in every state_dict key, so the attribute name must
    stay as is for checkpoints to load.
    """

    def __init__(self):
        super().__init__()
        self.model = TransformerAutoencoder()

    def forward(self, src_nodes, src_pos, tgt_seq):
        # src_pos is accepted for interface compatibility but is not used.
        return self.model(src_nodes, tgt_seq)

    @property
    def encoder(self):
        """The inner model's bound ``encode`` method (call as ``trae.encoder(src)``)."""
        return self.model.encode

    def get_cls_summary(self, memory):
        """Return the CLS slot (index 0 along dim 1) of encoder memory."""
        return self.model.get_cls_vector(memory)


# Projection head for contrastive learning
class ProjectionHead(nn.Module):
    """Two-layer MLP: Linear(input_dim -> proj_dim), ReLU, Linear(proj_dim -> proj_dim).

    The layers live in ``self.net`` as an nn.Sequential (indices 0, 1, 2); that layout
    fixes the saved state_dict keys (net.0.*, net.2.*).
    """

    def __init__(self, input_dim, proj_dim=128):
        super().__init__()
        layers = [nn.Linear(input_dim, proj_dim), nn.ReLU(), nn.Linear(proj_dim, proj_dim)]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        return self.net(x)

print("TRAE and ProjectionHead defined.")

In [None]:
# Utility functions
def get_expression_cls_embedding(expr, model, track_grad=False):
    """Encode a parsed expression with the TRAE encoder and return its CLS vector.

    Args:
        expr: parsed expression (as produced by parse_sexpr); flattened via flatten_expr.
        model: a TRAE instance (uses ``model.encoder`` and ``model.get_cls_summary``).
        track_grad: when False (default) the encoder runs under ``torch.no_grad()``, so
            no autograd graph is built for the frozen encoder. Pass True only if the
            encoder itself is being fine-tuned.

    Returns:
        The CLS slice of the encoder memory for a batch of one (on ``device``), or None
        when the token sequence plus the prepended CLS token would exceed
        ``config.max_seq_length``.
    """
    node_ids = [entry["node_id"] for entry in flatten_expr(expr)]
    # +1 accounts for the CLS token the encoder prepends
    if len(node_ids) + 1 > config.max_seq_length:
        return None

    src_tensor = torch.tensor([node_ids], dtype=torch.long, device=device)
    if track_grad:
        return model.get_cls_summary(model.encoder(src_tensor))
    with torch.no_grad():
        return model.get_cls_summary(model.encoder(src_tensor))

def contrastive_loss(z1, z2, label, temperature=0.5):
    """Pairwise logistic contrastive loss on cosine similarity.

    Args:
        z1, z2: projections of shape [batch, dim]; rows are L2-normalised first.
        label: shape [batch], 1 for positive pairs and 0 for negative pairs.
        temperature: the cosine similarity is divided by this before the sigmoid.

    Returns:
        A scalar tensor: mean(-log sigmoid(sim)) over positives plus
        mean(-log sigmoid(-sim)) over negatives. A term whose subset is empty
        contributes 0. The result is always a tensor on z1/z2's autograd graph, so
        ``.backward()`` also works for an empty batch.
    """
    z1 = nn.functional.normalize(z1, dim=1)
    z2 = nn.functional.normalize(z2, dim=1)
    sim = torch.sum(z1 * z2, dim=1) / temperature
    pos_mask = label == 1
    neg_mask = label == 0
    # A zero that stays connected to the graph (sum over an empty tensor is 0);
    # sim is finite because the inputs are normalised, so 0 * sum is exactly 0.
    zero = sim.sum() * 0.0
    pos_loss = -nn.functional.logsigmoid(sim[pos_mask]).mean() if pos_mask.any() else zero
    neg_loss = -nn.functional.logsigmoid(-sim[neg_mask]).mean() if neg_mask.any() else zero
    return pos_loss + neg_loss

print("Utility functions defined.")

## 7. Initialize and Train the Model

Now let's initialize the TRAE model, freeze its encoder, and train only the projection head on top of the encoder's CLS embeddings.

In [None]:
# Initialize the model
import glob

print("Initializing TRAE model...")
model = TRAE()
model.eval()  # dropout off: the encoder is only used as a frozen feature extractor below

# Locate a pre-trained checkpoint: repository copy first, then any attached Kaggle dataset.
repo_model_path = '/kaggle/working/CHEHAB_FHE_Compiler_RL/RL/fhe_rl/trained_models/model_Transformer_ddp_10399047_epoch_5000000.pth'
pretrained_model_path = None
if os.path.exists(repo_model_path):
    pretrained_model_path = repo_model_path
    print(f"Found model in repository: {pretrained_model_path}")
else:
    kaggle_matches = glob.glob('/kaggle/input/*/model_Transformer_ddp_10399047_epoch_5000000.pth')
    if kaggle_matches:
        pretrained_model_path = kaggle_matches[0]
        print(f"Found model in Kaggle dataset: {pretrained_model_path}")

if pretrained_model_path:
    print(f"Loading pre-trained model from: {pretrained_model_path}")
    # NOTE(review): torch.load unpickles the file; only load checkpoints from trusted sources.
    state_dict = torch.load(pretrained_model_path, map_location=device)
    # Drop the "module." key prefix added when a model is wrapped in DistributedDataParallel
    # (the checkpoint name suggests DDP training).
    new_sd = {
        (key[len("module."):] if key.startswith("module.") else key): tensor
        for key, tensor in state_dict.items()
    }
    model.load_state_dict(new_sd)
    print("Pre-trained Encoder loaded successfully")
else:
    print("No pre-trained model found. Using randomly initialized weights.")

model.to(device)

# Initialize projection head
print("\nInitializing projection head...")
# Width of the CLS summary, probed through the model with a dummy memory tensor.
cls_dim = model.get_cls_summary(torch.zeros(1, 1, config.d_model, device=device)).shape[-1]
projection_head = ProjectionHead(cls_dim, proj_dim=128).to(device)

# Freeze the encoder: only the projection head's parameters are optimised.
for param in model.parameters():
    param.requires_grad_(False)

optimizer = torch.optim.Adam(projection_head.parameters(), lr=1e-3)

print(f"Model initialized. CLS dimension: {cls_dim}")
print(f"Projection head output dimension: 128")

In [None]:
# Load contrastive pairs from the stored file (following demo() function)
print("Loading contrastive pairs for training...")

PAIRS_FILE = "/kaggle/working/datasets/contrastive_pairs.txt"

# Each line is "<expr_a> | <expr_b> | <label>"; the expressions are re-parsed into objects.
pairs = []
with open(PAIRS_FILE, "r") as f:
    for line_no, raw_line in enumerate(f, start=1):
        line = raw_line.strip()
        if not line:
            continue  # tolerate blank lines (e.g. an extra newline at end of file)
        fields = line.split(" | ")
        if len(fields) != 3:
            raise ValueError(
                f"{PAIRS_FILE}:{line_no}: expected 'expr_a | expr_b | label', got {line!r}"
            )
        expr_a_str, expr_b_str, label_str = fields
        pairs.append((parse_sexpr(expr_a_str), parse_sexpr(expr_b_str), int(label_str)))

print(f"Loaded {len(pairs)} training pairs from file")

# Count positive and negative pairs
pos_pairs = sum(1 for _, _, label in pairs if label == 1)
neg_pairs = sum(1 for _, _, label in pairs if label == 0)
print(f"Positive pairs: {pos_pairs}, Negative pairs: {neg_pairs}")

In [None]:
# Training loop (adapted from the demo() function)
# Only the projection head is trained; the encoder was frozen when it was initialised.
NUM_PROJ_EPOCHS = 10
PROJ_SHUFFLE_SEED = 0  # fixed seed so the pair order, and thus the run, is reproducible

print("Starting projection head training...")

shuffle_rng = random.Random(PROJ_SHUFFLE_SEED)
projection_head.train()

for epoch in tqdm(range(NUM_PROJ_EPOCHS), desc="Training Epochs"):
    shuffle_rng.shuffle(pairs)
    losses = []
    epoch_total = 0.0  # running sum, so the progress-bar average is O(1) per step

    # Progress bar over the training pairs within this epoch
    epoch_pairs = tqdm(pairs, desc=f"Epoch {epoch}", leave=False)

    for expr_a, expr_b, label in epoch_pairs:
        cls_a = get_expression_cls_embedding(expr_a, model)
        cls_b = get_expression_cls_embedding(expr_b, model)
        if cls_a is None or cls_b is None:
            continue  # sequence too long for the encoder's positional table

        # The embeddings are already on `device` (the encoder input is built there).
        z_a = projection_head(cls_a)
        z_b = projection_head(cls_b)
        lbl = torch.tensor([label], dtype=torch.long, device=device)
        loss = contrastive_loss(z_a, z_b, lbl)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        loss_value = loss.item()
        losses.append(loss_value)
        epoch_total += loss_value
        epoch_pairs.set_postfix(
            {"loss": f"{loss_value:.6f}", "avg_loss": f"{epoch_total / len(losses):.6f}"}
        )

    if losses:
        print(f"Epoch {epoch}: Contrastive loss = {np.mean(losses):.6f}")
    else:
        # Avoid np.mean([]) -> nan + RuntimeWarning when every pair was skipped
        print(f"Epoch {epoch}: no encodable pairs, epoch skipped")

# save the projection head
torch.save(projection_head.state_dict(), "/kaggle/working/projection_head.pth")
print("\nProjection head training completed and saved!")

## 8. Test the Trained Model

Now let's test the trained projection head on some example expressions.

In [None]:
# Test the projection head (test_projection_head function implementation)
print("Testing the trained projection head...")

PROJECTION_HEAD_PATH = "/kaggle/working/projection_head.pth"

# Prefer the saved weights; fall back to the in-memory head if they cannot be loaded.
# `except Exception` (not a bare `except:`) so KeyboardInterrupt/SystemExit still propagate.
try:
    projection_head.load_state_dict(torch.load(PROJECTION_HEAD_PATH, map_location=device))
    print("Loaded trained projection head")
except Exception as exc:
    print("Using current projection head state")
    print(f"  (reason: {exc})")
projection_head.to(device)
projection_head.eval()  # evaluation mode whichever weights ended up in use

exp_a = parse_sexpr("(Vec  (+ a b) (+ c d) (- f g) )")
exp_b = parse_sexpr("(Vec  (* a b) (* c d) (+ f g) )")

print(f"Expression A: {exp_a}")
print(f"Expression B: {exp_b}")

# Pure inference: no autograd graph is needed.
with torch.no_grad():
    cls_a = get_expression_cls_embedding(exp_a, model)
    cls_b = get_expression_cls_embedding(exp_b, model)

    if cls_a is None or cls_b is None:
        print("Cannot encode them")
    else:
        z_a = projection_head(cls_a)
        z_b = projection_head(cls_b)

        cosine_sim = nn.functional.cosine_similarity(z_a, z_b).item()
        print("Cosine similarity between test expressions:", cosine_sim)
        print("Embedding for expr_a:", z_a.cpu().numpy())
        print("Embedding for expr_b:", z_b.cpu().numpy())

In [None]:
# Test with more expression pairs
print("\n🔬 Testing multiple expression pairs...")

test_pairs_manual = [
    ("(Vec (+ a b) (+ c d) (- f g))", "(Vec (+ a b) (+ c d) (- f g))", "Identical"),
    ("(Vec (+ a b) (+ c d) (- f g))", "(Vec (- a b) (- c d) (+ f g))", "Different ops"),
    ("(+ a b)", "(- a b)", "Simple change"),
    ("(* x y)", "(+ x y)", "Mult to add"),
    ("(Vec (+ a b) (+ c d))", "(Vec (+ a b) (+ c d) (+ e f))", "Different length")
]


def _shorten(text, limit=30):
    """Truncate text to `limit` characters, appending '...' when it was cut."""
    return text[:limit] + '...' if len(text) > limit else text


results = []

# Evaluation only: no autograd graph is needed.
with torch.no_grad():
    for expr1_str, expr2_str, description in test_pairs_manual:
        try:
            expr1 = parse_sexpr(expr1_str)
            expr2 = parse_sexpr(expr2_str)

            cls1 = get_expression_cls_embedding(expr1, model)
            cls2 = get_expression_cls_embedding(expr2, model)

            if cls1 is None or cls2 is None:
                print(f"\n❌ {description}: Failed to encode expressions")
                continue

            z1 = projection_head(cls1)
            z2 = projection_head(cls2)

            cosine_sim = nn.functional.cosine_similarity(z1, z2).item()
            euclidean_dist = torch.norm(z1 - z2).item()

            results.append({
                'description': description,
                'cosine_sim': cosine_sim,
                'euclidean_dist': euclidean_dist,
                'expr1': _shorten(expr1_str),
                'expr2': _shorten(expr2_str)
            })

            print(f"\n{description}:")
            print(f"  Expr 1: {expr1_str}")
            print(f"  Expr 2: {expr2_str}")
            print(f"  Cosine similarity: {cosine_sim:.6f}")
            print(f"  Euclidean distance: {euclidean_dist:.6f}")

        except Exception as e:
            print(f"\n❌ {description}: Error - {e}")

print(f"\n✅ Testing completed! Processed {len(results)} expression pairs successfully.")

## 9. Optional: Visualize Embeddings

Let's create a simple visualization of the embeddings to understand how the model groups similar expressions.

In [None]:
# Install required packages for visualization
!pip install matplotlib seaborn scikit-learn

import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE
import numpy as np

# Generate embeddings for a variety of expressions
test_expressions = [
    "(+ a b)",
    "(- a b)", 
    "(* a b)",
    "(+ x y)",
    "(- x y)",
    "(* x y)",
    "(Vec (+ a b) (+ c d))",
    "(Vec (- a b) (- c d))",
    "(Vec (* a b) (* c d))",
    "(+ (+ a b) (+ c d))",
    "(- (- a b) (- c d))",
    "(* (* a b) (* c d))"
]

embeddings = []
labels = []
valid_expressions = []

print("Generating embeddings for visualization...")
for expr_str in test_expressions:
    try:
        expr = parse_sexpr(expr_str)
        cls_emb = get_expression_cls_embedding(expr, model)
        
        if cls_emb is not None:
            cls_emb = cls_emb.to(device)
            z = projection_head(cls_emb)
            embeddings.append(z.cpu().detach().numpy().flatten())
            
            # Create labels based on operation type
            if '+' in expr_str:
                labels.append('Addition')
            elif '-' in expr_str:
                labels.append('Subtraction')
            elif '*' in expr_str:
                labels.append('Multiplication')
            else:
                labels.append('Other')
                
            valid_expressions.append(expr_str)
            
    except Exception as e:
        print(f"Failed to process {expr_str}: {e}")
        continue

print(f"Generated {len(embeddings)} embeddings for visualization")

In [None]:
# PCA scatter + cosine-similarity heatmap of the projected embeddings.
# PCA(n_components=2) needs at least two samples, so fewer embeddings are reported instead.
if len(embeddings) >= 2:
    embeddings_array = np.asarray(embeddings)

    # Apply PCA for dimensionality reduction
    pca = PCA(n_components=2)
    embeddings_2d = pca.fit_transform(embeddings_array)

    fig, (ax_pca, ax_sim) = plt.subplots(1, 2, figsize=(12, 8))

    # PCA scatter, one colour per operation label (sorted for a stable colour assignment)
    labels_arr = np.asarray(labels)
    unique_labels = sorted(set(labels))
    colors = plt.cm.Set1(np.linspace(0, 1, len(unique_labels)))
    for color, label in zip(colors, unique_labels):
        mask = labels_arr == label
        ax_pca.scatter(embeddings_2d[mask, 0], embeddings_2d[mask, 1],
                       c=[color], label=label, s=100, alpha=0.7)

    ax_pca.set_title('Expression Embeddings (PCA)', fontsize=14)
    ax_pca.set_xlabel(f'PC1 ({pca.explained_variance_ratio_[0]:.2%} variance)')
    ax_pca.set_ylabel(f'PC2 ({pca.explained_variance_ratio_[1]:.2%} variance)')
    ax_pca.legend()
    ax_pca.grid(True, alpha=0.3)

    # All pairwise cosine similarities at once: L2-normalise rows, then take the Gram matrix.
    norms = np.linalg.norm(embeddings_array, axis=1, keepdims=True)
    unit_rows = embeddings_array / np.maximum(norms, 1e-8)
    similarity_matrix = unit_rows @ unit_rows.T

    # Abbreviated labels for the heatmap axes
    short_labels = [expr[:15] + '...' if len(expr) > 15 else expr for expr in valid_expressions]

    sns.heatmap(similarity_matrix,
                xticklabels=short_labels,
                yticklabels=short_labels,
                annot=True,
                fmt='.2f',
                cmap='coolwarm',
                center=0,
                square=True,
                ax=ax_sim)
    ax_sim.set_title('Cosine Similarity Heatmap', fontsize=14)
    plt.setp(ax_sim.get_xticklabels(), rotation=45, ha='right')
    plt.setp(ax_sim.get_yticklabels(), rotation=0)

    fig.tight_layout()
    fig.savefig('/kaggle/working/embedding_analysis.png', dpi=300, bbox_inches='tight')
    plt.show()

    print("\n📊 Visualization Summary:")
    print(f"- PCA explains {pca.explained_variance_ratio_.sum():.2%} of the variance in the first 2 components")
    print(f"- Similarity matrix shows how similar expressions cluster together")
    print(f"- Visualization saved as /kaggle/working/embedding_analysis.png")

elif embeddings:
    print("⚠️ At least 2 embeddings are needed for the PCA / similarity plots")
else:
    print("❌ No valid embeddings generated for visualization")