In [1]:
"""
Load and use the trained autoencoder model.
"""
import sys
from pathlib import Path
project_root = Path().resolve().parent
sys.path.insert(0, str(project_root))

import torch
from models.autoencoder import ToolInvocationAutoencoder
from training.config import AutoencoderConfig

print(f"Project root: {project_root}")
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"CUDA device: {torch.cuda.get_device_name(0)}")

  from .autonotebook import tqdm as notebook_tqdm


Project root: /scratch4/home/akrik/NTILC
PyTorch version: 2.9.1+cu128
CUDA available: True
CUDA device: NVIDIA H100 80GB HBM3


In [2]:
# Configuration: Set the path to your trained model checkpoint
# Default is ./checkpoints/best_model.pt, but you can change this
CHECKPOINT_PATH = project_root / "checkpoints" / "best_model.pt"

# If the default path doesn't exist, try to find it
if not CHECKPOINT_PATH.exists():
    print(f"Warning: {CHECKPOINT_PATH} not found. Please update CHECKPOINT_PATH.")
    print("Looking for checkpoints in common locations...")
    
    # Try alternative locations
    alt_paths = [
        project_root / "checkpoints" / "best_model.pt",
        Path("./checkpoints/best_model.pt"),
        Path("../checkpoints/best_model.pt"),
    ]
    
    found = False
    for alt_path in alt_paths:
        if alt_path.exists():
            CHECKPOINT_PATH = alt_path
            print(f"Found checkpoint at: {CHECKPOINT_PATH}")
            found = True
            break
    
    if not found:
        print("No checkpoint found. Please specify CHECKPOINT_PATH manually.")
else:
    print(f"Found checkpoint at: {CHECKPOINT_PATH}")

Found checkpoint at: /scratch4/home/akrik/NTILC/checkpoints/best_model.pt


In [3]:
# Load the checkpoint
print(f"Loading checkpoint from {CHECKPOINT_PATH}...")
checkpoint = torch.load(CHECKPOINT_PATH, map_location='cpu')

# Extract saved configuration
saved_config_dict = checkpoint.get('config', {})
print("\nSaved model configuration:")
for key, value in saved_config_dict.items():
    print(f"  {key}: {value}")

# Create config object from saved config
config = AutoencoderConfig(**saved_config_dict)

# Print validation metrics if available
if 'val_metrics' in checkpoint:
    print("\nValidation metrics from training:")
    for key, value in checkpoint['val_metrics'].items():
        if isinstance(value, (int, float)):
            print(f"  {key}: {value:.4f}")
        else:
            print(f"  {key}: {value}")

print(f"\nModel was trained for {checkpoint.get('epoch', 'unknown')} epochs")

Loading checkpoint from /scratch4/home/akrik/NTILC/checkpoints/best_model.pt...

Saved model configuration:
  embedding_dim: 256
  encoder_model: google/flan-t5-base
  decoder_model: google/flan-t5-base
  pooling_strategy: attention
  max_length: 128
  dropout: 0.1
  freeze_encoder: False
  freeze_decoder: False
  batch_size: 64
  learning_rate: 1e-05
  weight_decay: 0.01
  num_epochs: 50
  warmup_steps: 1000
  gradient_clip: 0.5
  use_lr_scheduler: True
  use_gradient_checkpointing: True
  torch_dtype: bfloat16
  num_train_samples: 100000
  num_val_samples: 10000
  num_test_samples: 10000
  output_dir: ./checkpoints
  log_dir: ./logs
  data_dir: ./data
  log_interval: 100
  eval_interval: 1000
  save_interval: 5000
  use_wandb: True
  wandb_project: ntilc
  wandb_entity: andykr1k
  wandb_run_name: None
  early_stopping_patience: 7
  early_stopping_min_delta: 0.001

Validation metrics from training:
  exact_match_accuracy: 0.0476
  tool_accuracy: 0.9966
  param_str_accuracy: 0.1305
  p

In [4]:
# Initialize model with saved configuration
print("Initializing model...")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

model = ToolInvocationAutoencoder(
    embedding_dim=config.embedding_dim,
    encoder_model=config.encoder_model,
    decoder_model=config.decoder_model,
    pooling_strategy=config.pooling_strategy,
    max_length=config.max_length,
    dropout=config.dropout,
    freeze_encoder=config.freeze_encoder,
    freeze_decoder=config.freeze_decoder,
    torch_dtype=config.torch_dtype,
    use_gradient_checkpointing=config.use_gradient_checkpointing
)

# Load model weights
print("Loading model weights...")
model.load_state_dict(checkpoint['model_state_dict'])
model = model.to(device)
model.eval()  # Set to evaluation mode

print("Model loaded successfully!")
print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")
print(f"Trainable parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad):,}")

Initializing model...
Using device: cuda


`torch_dtype` is deprecated! Use `dtype` instead!


Loading model weights...
Model loaded successfully!
Model parameters: 360,949,633
Trainable parameters: 360,949,633


In [10]:
# Example tool invocations to test
test_tool_calls = [
    "search(query='machine learning', max_results=10)",
    "calculate(expression='2 + 2 * 3')",
    "database_query(table='users', limit=50)",
    "send_email(to='user@example.com', subject='Test', body='Hello world')",
    "web_fetch(url='https://example.com', method='GET')",
    "file_read(path='/path/to/file.txt')",
]

print("Testing model with example tool invocations:\n")
for i, tool_call in enumerate(test_tool_calls, 1):
    print(f"{i}. {tool_call}")

Testing model with example tool invocations:

1. search(query='machine learning', max_results=10)
2. calculate(expression='2 + 2 * 3')
3. database_query(table='users', limit=50)
4. send_email(to='user@example.com', subject='Test', body='Hello world')
5. web_fetch(url='https://example.com', method='GET')
6. file_read(path='/path/to/file.txt')


In [11]:
# Test encoding: Convert tool calls to embeddings
print("Encoding tool calls to embeddings...")
with torch.no_grad():
    embeddings = model.encode(test_tool_calls)

print(f"Embeddings shape: {embeddings.shape}")  # Should be (batch_size, embedding_dim)
print(f"Embedding dimension: {embeddings.shape[1]}")
print(f"\nFirst embedding stats:")
print(f"  Min: {embeddings[0].min().item():.4f}")
print(f"  Max: {embeddings[0].max().item():.4f}")
print(f"  Mean: {embeddings[0].mean().item():.4f}")
print(f"  Std: {embeddings[0].std().item():.4f}")

Encoding tool calls to embeddings...
Embeddings shape: torch.Size([6, 256])
Embedding dimension: 256

First embedding stats:
  Min: -2.8125
  Max: 2.6094
  Mean: -0.0000
  Std: 0.9844


In [8]:
# Test reconstruction: Encode then decode
print("Reconstructing tool calls (encode -> decode)...\n")
reconstructed = model.reconstruct(test_tool_calls)

for i, (original, recon) in enumerate(zip(test_tool_calls, reconstructed), 1):
    match = "✓" if original == recon else "✗"
    print(f"{i}. {match}")
    print(f"   Original:     {original}")
    print(f"   Reconstructed: {recon}")
    print()

Reconstructing tool calls (encode -> decode)...

1. ✗
   Original:     search(query='machine learning', max_results=10)
   Reconstructed: search(query='digital computing', max_results=7)

2. ✗
   Original:     calculate(expression='2 + 2 * 3')
   Reconstructed: calculate(expression='2 + 2')

3. ✗
   Original:     database_query(table='users', filter={'age': {'>': 25}}, limit=50)
   Reconstructed: '', '', '', ''

4. ✗
   Original:     send_email(to='user@example.com', subject='Test', body='Hello world')
   Reconstructed: send me the following information: name, email, password, subject, and subject.

5. ✗
   Original:     web_fetch(url='https://example.com', method='GET')
   Reconstructed: web_fetch(url='https://api.example.com/get', method='GET')

6. ✗
   Original:     file_read(path='/path/to/file.txt')
   Reconstructed: file_read(path='/home/files/tmp/files/tmp/files/files/path/path/files/path/path/path/path/path/path/path/path/path/path/path/path/path/path/path/path/path/path/path/p

In [12]:
# Test decoding: Decode embeddings back to tool calls
print("Decoding embeddings back to tool calls...\n")
decoded = model.decode(embeddings)

for i, (original, decoded_call) in enumerate(zip(test_tool_calls, decoded), 1):
    match = "✓" if original == decoded_call else "✗"
    print(f"{i}. {match}")
    print(f"   Original: {original}")
    print(f"   Decoded:  {decoded_call}")
    print()

Decoding embeddings back to tool calls...

1. ✗
   Original: search(query='machine learning', max_results=10)
   Decoded:  search(query='digital computing', max_results=7)

2. ✗
   Original: calculate(expression='2 + 2 * 3')
   Decoded:  calculate(expression='2 + 2')

3. ✗
   Original: database_query(table='users', limit=50)
   Decoded:  database_query(query='database', search='database', data='database', time='5')

4. ✗
   Original: send_email(to='user@example.com', subject='Test', body='Hello world')
   Decoded:  send me the following information: name, email, password, subject, and subject.

5. ✗
   Original: web_fetch(url='https://example.com', method='GET')
   Decoded:  web_fetch(url='https://api.example.com/get', method='GET')

6. ✗
   Original: file_read(path='/path/to/file.txt')
   Decoded:  file_read(path='/home/files/tmp/files/tmp/files/files/path/path/files/path/path/path/path/path/path/path/path/path/path/path/path/path/path/path/path/path/path/path/path/path/path/path/path

In [13]:
# Compute reconstruction accuracy
exact_matches = sum(1 for orig, recon in zip(test_tool_calls, reconstructed) if orig == recon)
accuracy = exact_matches / len(test_tool_calls)

print(f"Reconstruction Accuracy: {accuracy:.2%} ({exact_matches}/{len(test_tool_calls)})")

Reconstruction Accuracy: 0.00% (0/6)


In [14]:
# Example: Compute similarity between embeddings
import torch.nn.functional as F

print("Computing pairwise cosine similarity between embeddings...\n")
with torch.no_grad():
    # Normalize embeddings for cosine similarity
    normalized_embeddings = F.normalize(embeddings, p=2, dim=1)
    similarity_matrix = torch.mm(normalized_embeddings, normalized_embeddings.t())

print("Cosine similarity matrix:")
print("      ", end="")
for i in range(len(test_tool_calls)):
    print(f"{i+1:>6}", end="")
print()
for i in range(len(test_tool_calls)):
    print(f"  {i+1:2d}  ", end="")
    for j in range(len(test_tool_calls)):
        sim = similarity_matrix[i, j].item()
        print(f"{sim:6.3f}", end="")
    print()

Computing pairwise cosine similarity between embeddings...

Cosine similarity matrix:
           1     2     3     4     5     6
   1   0.996 0.820 0.906 0.855 0.852 0.832
   2   0.820 0.996 0.809 0.777 0.801 0.785
   3   0.906 0.809 1.000 0.875 0.871 0.859
   4   0.855 0.777 0.875 1.000 0.914 0.891
   5   0.852 0.801 0.871 0.914 1.000 0.898
   6   0.832 0.785 0.859 0.891 0.898 1.000


In [15]:
# Example: Interpolate between two tool calls
print("Interpolating between two tool calls...\n")

# Select two tool calls
idx1, idx2 = 0, 2
tool1 = test_tool_calls[idx1]
tool2 = test_tool_calls[idx2]

print(f"Tool 1: {tool1}")
print(f"Tool 2: {tool2}\n")

# Get embeddings
emb1 = embeddings[idx1:idx1+1]
emb2 = embeddings[idx2:idx2+1]

# Interpolate
num_steps = 5
interpolated_tool_calls = []
for alpha in torch.linspace(0, 1, num_steps):
    interp_emb = (1 - alpha) * emb1 + alpha * emb2
    decoded = model.decode(interp_emb)
    interpolated_tool_calls.append(decoded[0])
    print(f"α={alpha:.2f}: {decoded[0]}")

Interpolating between two tool calls...

Tool 1: search(query='machine learning', max_results=10)
Tool 2: database_query(table='users', limit=50)

α=0.00: search(query='digital computing', max_results=7)
α=0.25: search(query='digital computing', max_results=10)
α=0.50: search(query='digital computing', max_results=10)
α=0.75: search(query='shortest search', max_result=10)
α=1.00: database_query(query='database', search='database', data='database', time='5')


In [16]:
# Optional: Load and evaluate on test dataset
from training.data_generator import ToolInvocationGenerator, DataGeneratorConfig
from transformers import AutoTokenizer
from torch.utils.data import Dataset, DataLoader
from evaluation.metrics import compute_metrics

# Load test data if available
test_data_path = project_root / "data" / "test_data.txt"
if test_data_path.exists():
    print(f"Loading test data from {test_data_path}...")
    data_config = DataGeneratorConfig()
    generator = ToolInvocationGenerator(data_config)
    test_tool_calls = generator.load_dataset(str(test_data_path))
    
    print(f"Loaded {len(test_tool_calls)} test samples")
    
    # Evaluate on a subset (first 100 for quick testing)
    num_samples = min(100, len(test_tool_calls))
    test_subset = test_tool_calls[:num_samples]
    
    print(f"\nEvaluating on {num_samples} samples...")
    reconstructed = model.reconstruct(test_subset)
    embeddings = model.encode(test_subset)
    
    # Compute metrics
    metrics = compute_metrics(test_subset, reconstructed, embeddings)
    
    print("\nTest Metrics:")
    for key, value in metrics.items():
        if isinstance(value, (int, float)):
            print(f"  {key}: {value:.4f}")
        else:
            print(f"  {key}: {value}")
else:
    print(f"Test data not found at {test_data_path}")
    print("Skipping test evaluation. Run training script to generate test data.")

Loading test data from /scratch4/home/akrik/NTILC/data/test_data.txt...
Loaded 10000 test samples

Evaluating on 100 samples...

Test Metrics:
  exact_match_accuracy: 0.0600
  tool_accuracy: 0.9800
  param_str_accuracy: 0.1751
  param_int_accuracy: 0.0333
  embedding_mean_norm: 15.6875
  embedding_std_norm: 0.0302
  embedding_min_norm: 15.6875
  embedding_max_norm: 15.7500
  embedding_mean_variance: 0.1387
  embedding_std_variance: 0.0776
  embedding_mean_per_dim: -0.0000
  embedding_std_per_dim: 0.3594
