In [1]:
"""
Load and use the trained autoencoder model.
"""
import sys
from pathlib import Path

# The project root is the directory one level above the current working
# directory. It goes first on sys.path, presumably so the `models` and
# `training` packages imported below resolve from there.
project_root = Path().resolve().parent
sys.path.insert(0, str(project_root))

import torch
from models.autoencoder import ToolInvocationAutoencoder
from training.config import AutoencoderConfig

# Report the environment this notebook is running in.
cuda_ready = torch.cuda.is_available()
print(f"Project root: {project_root}")
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {cuda_ready}")
if cuda_ready:
    # Only device 0 is reported.
    print(f"CUDA device: {torch.cuda.get_device_name(0)}")

  from .autonotebook import tqdm as notebook_tqdm


Project root: /scratch4/home/akrik/NTILC
PyTorch version: 2.9.1+cu128
CUDA available: True
CUDA device: NVIDIA H100 80GB HBM3


In [2]:
# Configuration: path to the trained model checkpoint.
# The default is <project_root>/checkpoints/best_model.pt; edit CHECKPOINT_PATH
# below to point at a different file.
CHECKPOINT_PATH = project_root / "checkpoints" / "best_model.pt"

if CHECKPOINT_PATH.exists():
    print(f"Found checkpoint at: {CHECKPOINT_PATH}")
else:
    print(f"Warning: {CHECKPOINT_PATH} not found. Please update CHECKPOINT_PATH.")
    print("Looking for checkpoints in common locations...")

    # Fallbacks are relative to the current working directory. The default
    # path is not repeated here because it has just been checked.
    fallback_paths = [
        Path("./checkpoints/best_model.pt"),
        Path("../checkpoints/best_model.pt"),
    ]
    located = next((path for path in fallback_paths if path.exists()), None)

    if located is None:
        # Fail here with a clear message instead of letting the loading cell
        # fail later on a path that is already known to be missing.
        raise FileNotFoundError(
            "No checkpoint found. Please specify CHECKPOINT_PATH manually."
        )

    CHECKPOINT_PATH = located
    print(f"Found checkpoint at: {CHECKPOINT_PATH}")

Found checkpoint at: /scratch4/home/akrik/NTILC/checkpoints/best_model.pt


In [3]:
# Read the checkpoint from disk, mapping any stored tensors onto the CPU
print(f"Loading checkpoint from {CHECKPOINT_PATH}...")
checkpoint = torch.load(CHECKPOINT_PATH, map_location='cpu')

# Show the configuration stored alongside the weights (empty dict if absent)
saved_config_dict = checkpoint.get('config', {})
print("\nSaved model configuration:")
for name, setting in saved_config_dict.items():
    print(f"  {name}: {setting}")

# Rebuild the config object from the stored key/value pairs
config = AutoencoderConfig(**saved_config_dict)

# Show validation metrics when the checkpoint has them; int/float values are
# printed with four decimal places, anything else with its default formatting
if 'val_metrics' in checkpoint:
    print("\nValidation metrics from training:")
    for name, metric in checkpoint['val_metrics'].items():
        if isinstance(metric, (int, float)):
            rendered = f"{metric:.4f}"
        else:
            rendered = f"{metric}"
        print(f"  {name}: {rendered}")

# Falls back to the literal 'unknown' when no epoch was recorded
trained_epochs = checkpoint.get('epoch', 'unknown')
print(f"\nModel was trained for {trained_epochs} epochs")

Loading checkpoint from /scratch4/home/akrik/NTILC/checkpoints/best_model.pt...

Saved model configuration:
  embedding_dim: 256
  encoder_model: google/flan-t5-base
  decoder_model: google/flan-t5-base
  pooling_strategy: attention
  max_length: 256
  dropout: 0.1
  freeze_encoder: False
  freeze_decoder: False
  freeze_encoder_layers: 4
  freeze_decoder_layers: 2
  batch_size: 32
  learning_rate: 5e-05
  weight_decay: 0.01
  num_epochs: 30
  warmup_steps: 1000
  warmup_ratio: 0.1
  gradient_clip: 1.0
  use_lr_scheduler: True
  label_smoothing: 0.1
  use_gradient_checkpointing: True
  torch_dtype: bfloat16
  gradient_accumulation_steps: 2
  num_train_samples: 250000
  num_val_samples: 2500
  num_test_samples: 2500
  output_format: python
  regenerate_data: True
  use_validity_loss: False
  validity_loss_weight: 0.1
  use_contrastive_loss: True
  contrastive_loss_weight: 0.03
  contrastive_margin: 0.5
  contrastive_temperature: 0.07
  embedding_l2_weight: 0.001
  embedding_variance_wei

In [4]:
# Build the model from the saved configuration
print("Initializing model...")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Each constructor keyword is read from the config attribute of the same name.
model_kwargs = {
    field: getattr(config, field)
    for field in (
        "embedding_dim",
        "encoder_model",
        "decoder_model",
        "pooling_strategy",
        "max_length",
        "dropout",
        "freeze_encoder",
        "freeze_decoder",
        "torch_dtype",
        "use_gradient_checkpointing",
    )
}
model = ToolInvocationAutoencoder(**model_kwargs)

# Restore the trained weights, move the model to the chosen device, and switch
# it to evaluation mode
print("Loading model weights...")
model.load_state_dict(checkpoint['model_state_dict'])
model = model.to(device)
model.eval()

print("Model loaded successfully!")
total_params = sum(param.numel() for param in model.parameters())
print(f"Model parameters: {total_params:,}")
trainable_params = sum(
    param.numel() for param in model.parameters() if param.requires_grad
)
print(f"Trainable parameters: {trainable_params:,}")

Initializing model...
Using device: cuda


`torch_dtype` is deprecated! Use `dtype` instead!


Loading model weights...
Model loaded successfully!
Model parameters: 359,378,305
Trainable parameters: 359,378,305


In [5]:
# Sample inputs fed to the encode / reconstruct / decode cells below.
# NOTE(review): these are bare literals rather than call-style strings;
# presumably the model was trained on call syntax. Confirm whether an exact
# round-trip is expected for inputs like these.
test_tool_calls = ["pi", "3.14159", "3.14"]

print("Testing model with example tool invocations:\n")
for position, call_text in enumerate(test_tool_calls, start=1):
    print(f"{position}. {call_text}")

Testing model with example tool invocations:

1. pi
2. 3.14159
3. 3.14


In [6]:
# Encode the sample strings into embeddings; gradients are not needed here
print("Encoding tool calls to embeddings...")
with torch.no_grad():
    embeddings = model.encode(test_tool_calls)

# The shape is expected to be (batch_size, embedding_dim)
print(f"Embeddings shape: {embeddings.shape}")
print(f"Embedding dimension: {embeddings.shape[1]}")

# Summary statistics for the first input's embedding only
first_embedding = embeddings[0]
print("\nFirst embedding stats:")
for label, reducer in (("Min", "min"), ("Max", "max"), ("Mean", "mean"), ("Std", "std")):
    statistic = getattr(first_embedding, reducer)()
    print(f"  {label}: {statistic.item():.4f}")

Encoding tool calls to embeddings...
Embeddings shape: torch.Size([3, 256])
Embedding dimension: 256

First embedding stats:
  Min: -0.1914
  Max: 0.2090
  Mean: 0.0000
  Std: 0.0625


In [7]:
# Round-trip test: encode each input, then decode it back to text
print("Reconstructing tool calls (encode -> decode)...\n")

# This is inference only, so run it under no_grad like the encoding cell;
# that avoids recording autograd state for a result that is never
# backpropagated.
with torch.no_grad():
    reconstructed = model.reconstruct(test_tool_calls)

# Compare each reconstruction with its input; only an exact string match counts
for i, (original, recon) in enumerate(zip(test_tool_calls, reconstructed), 1):
    match = "✓" if original == recon else "✗"
    print(f"{i}. {match}")
    print(f"   Original:     {original}")
    print(f"   Reconstructed: {recon}")
    print()

Reconstructing tool calls (encode -> decode)...

1. ✗
   Original:     pi
   Reconstructed: file_pi()

2. ✗
   Original:     3.14159
   Reconstructed: calculate(expression='round(expression='round(expression='round(expression='round(expression='round(expression='round(expression='round(expression='round(expression='round(expression='round(expression='round(expression='round(expression='round(expression='round(expression='round(expression='round(expression='round(expression='round(expression='round(expression='round(expression='round(expression='round(expression='round(expression='round(expression='round(expression='round(expression='round(expression='round(expression='round(expression='round(expression='round(expression='round(expression='round(expression='round(expression='round(expression='round(expression='round(expression='round(expression='round(expression='round(expression='round(expression='round(expression='round(expression='round(expression='round(expression='round(expression=

In [8]:
# Decode the previously computed embeddings back into text
print("Decoding embeddings back to tool calls...\n")

# This is inference only, so run it under no_grad like the encoding cell;
# that avoids recording autograd state for a result that is never
# backpropagated.
with torch.no_grad():
    decoded = model.decode(embeddings)

# Compare each decoded string with its input; only an exact string match counts
for i, (original, decoded_call) in enumerate(zip(test_tool_calls, decoded), 1):
    match = "✓" if original == decoded_call else "✗"
    print(f"{i}. {match}")
    print(f"   Original: {original}")
    print(f"   Decoded:  {decoded_call}")
    print()

Decoding embeddings back to tool calls...

1. ✗
   Original: pi
   Decoded:  file_pi()

2. ✗
   Original: 3.14159
   Decoded:  calculate(expression='round(expression='round(expression='round(expression='round(expression='round(expression='round(expression='round(expression='round(expression='round(expression='round(expression='round(expression='round(expression='round(expression='round(expression='round(expression='round(expression='round(expression='round(expression='round(expression='round(expression='round(expression='round(expression='round(expression='round(expression='round(expression='round(expression='round(expression='round(expression='round(expression='round(expression='round(expression='round(expression='round(expression='round(expression='round(expression='round(expression='round(expression='round(expression='round(expression='round(expression='round(expression='round(expression='round(expression='round(expression='round(expression='round(expression='round(expression='round

In [9]:
# Exact-match reconstruction accuracy over the sample inputs
total_calls = len(test_tool_calls)
exact_matches = sum(
    1 for orig, recon in zip(test_tool_calls, reconstructed) if orig == recon
)

# Report 0% instead of raising ZeroDivisionError when the sample list is empty
accuracy = exact_matches / total_calls if total_calls else 0.0

print(f"Reconstruction Accuracy: {accuracy:.2%} ({exact_matches}/{total_calls})")

Reconstruction Accuracy: 0.00% (0/3)


In [10]:
# Example: pairwise cosine similarity between the embeddings
import torch.nn.functional as F

print("Computing pairwise cosine similarity between embeddings...\n")
with torch.no_grad():
    # The model is built with config.torch_dtype, which may be a reduced
    # precision type such as bfloat16. Upcast to float32 first so that
    # normalisation and the dot products are not rounded to that precision.
    normalized_embeddings = F.normalize(embeddings.float(), p=2, dim=1)
    # Dot products of unit-length rows are the cosine similarities.
    similarity_matrix = torch.mm(normalized_embeddings, normalized_embeddings.t())

# Size the table from the matrix itself, so it stays consistent even if
# test_tool_calls was edited after the embeddings were computed.
similarity_rows = similarity_matrix.tolist()

print("Cosine similarity matrix:")
# Header row: six spaces of padding, then 1-based column numbers
print("      " + "".join(f"{column + 1:>6}" for column in range(len(similarity_rows))))
for row_number, row_values in enumerate(similarity_rows, start=1):
    row_cells = "".join(f"{value:6.3f}" for value in row_values)
    print(f"  {row_number:2d}  {row_cells}")

Computing pairwise cosine similarity between embeddings...

Cosine similarity matrix:
           1     2     3
   1   1.000 0.486 0.516
   2   0.486 1.000 0.902
   3   0.516 0.902 1.000
