# NL-MTP HoF Experiment

Homoiconic transformer with fast-weight (LoR) layers for MTP policy evaluation on BOOM's heat-of-formation (HoF) task.

## Setup

- **Objective**: Train a transformer to perform MTP policy evaluation under a δ = +14 Da molecular-weight shift
- **Architecture**: 12-layer transformer with LoR at layers {3,7,11}, rank=8
- **Losses**: DR/AIPW + MDN propensity + REx invariance + LoR locality
- **Data**: BOOM HoF splits with RDKit features

In [1]:
# Imports and setup
import os
import sys
import torch
import torch.optim as optim
from pathlib import Path


def _prepend_sys_path(path):
    """Put `path` at the front of sys.path without leaving duplicates on cell re-runs."""
    entry = str(path)
    if entry in sys.path:
        # Move rather than re-insert so precedence matches a fresh insert(0, ...).
        sys.path.remove(entry)
    sys.path.insert(0, entry)


# Add repo root to path (the notebook is expected to run from two levels below it)
repo_root = Path.cwd().parent.parent
_prepend_sys_path(repo_root)

# Import experiment modules; if they cannot be resolved via the repo root,
# fall back to the notebook's own working directory.
try:
    from dataset import make_dataloaders
    from model import NL_MTP_Model
    from trainer import train_epoch, evaluate
    from eval import save_metrics_json, make_all_plots
except ImportError:
    _prepend_sys_path(Path.cwd())
    from dataset import make_dataloaders
    from model import NL_MTP_Model
    from trainer import train_epoch, evaluate
    from eval import save_metrics_json, make_all_plots

# Device
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"Using device: {device}")

Using device: cpu


In [2]:
# Configuration
config = {
    'batch_size': 64,
    'epochs': 30,
    'delta': 14.0,
    'lr': 2e-4,
    # Passed to train_epoch (its meaning lives in trainer.py). This is NOT the LR
    # warmup: the scheduler cell derives its own step-based warmup.
    'warmup_epochs': 5,
    'out_dir': 'results',

    # Model
    'emb_dim': 512,
    'num_layers': 12,
    'num_heads': 8,
    'dim_ff': 2048,
    'lor_layers': (3, 7, 11),
    'lor_rank': 8,
    'mdn_components': 8,

    # Reproducibility
    'seed': 42,
}

# Seed torch's global RNG before any data loading or model construction, so weight
# init and (if enabled) DataLoader shuffling repeat across Restart & Run All.
# NOTE(review): if dataset.py / trainer.py draw from numpy or `random`, seed those too.
torch.manual_seed(config['seed'])

os.makedirs(config['out_dir'], exist_ok=True)
print("Configuration:")
for k, v in config.items():
    print(f"  {k}: {v}")

Configuration:
  batch_size: 64
  epochs: 30
  delta: 14.0
  lr: 0.0002
  warmup_epochs: 5
  out_dir: results
  emb_dim: 512
  num_layers: 12
  num_heads: 8
  dim_ff: 2048
  lor_layers: (3, 7, 11)
  lor_rank: 8
  mdn_components: 8


In [3]:
# Load data
print("Loading data...")
loaders = make_dataloaders(batch_size=config['batch_size'])
train_dl, id_dl, ood_dl = loaders

# Report how many batches each split yields.
for split_name, split_dl in (("Train", train_dl), ("ID", id_dl), ("OOD", ood_dl)):
    print(f"{split_name} batches: {len(split_dl)}")

# Peek at one training batch to confirm what the model will receive.
sample = next(iter(train_dl))
print("\nSample batch shapes:")
for key, tensor in sample.items():
    print(f"  {key}: {tensor.shape}")

Loading data...
File downloaded successfully and saved as <repo_root>\experiments\gnn-nlmtp\10k_dft_density_data.csv
File downloaded successfully and saved as <repo_root>\experiments\gnn-nlmtp\10k_dft_hof_data.csv
There are 1000 OOD density samples. There are 1000 OOD hof samples. There are 440 IID density samples. There are 423 IID hof samples.
Train batches: 138
ID batches: 7
OOD batches: 16

Sample batch shapes:
  x_ctx: torch.Size([64, 2223])
  mw: torch.Size([64])
  y: torch.Size([64])
  env_idx: torch.Size([64])


In [4]:
# Initialize model
print("Initializing model...")
# Forward the architecture hyper-parameters from `config` by keyword.
model_kwargs = {
    name: config[name]
    for name in ('emb_dim', 'num_layers', 'num_heads', 'dim_ff',
                 'lor_layers', 'lor_rank', 'mdn_components')
}
model = NL_MTP_Model(**model_kwargs).to(device)

# Tally total and trainable parameter counts in a single pass.
total_params = 0
trainable_params = 0
for param in model.parameters():
    n_elem = param.numel()
    total_params += n_elem
    if param.requires_grad:
        trainable_params += n_elem
print(f"Total params: {total_params:,}")
print(f"Trainable params: {trainable_params:,}")

Initializing model...
Total params: 40,939,044
Trainable params: 40,939,044


In [5]:
# Optimizer and scheduler
opt = optim.AdamW(model.parameters(), lr=config['lr'], weight_decay=1e-2)

# NOTE(review): assumes train_epoch calls sched.step() once per training batch —
# TODO confirm; OneCycleLR raises if stepped more than total_steps times.
total_steps = config['epochs'] * len(train_dl)
if total_steps <= 0:
    # Fail with a clear message instead of a ZeroDivisionError in pct_start below.
    raise ValueError(
        f"OneCycleLR needs at least one step; got epochs={config['epochs']}, "
        f"len(train_dl)={len(train_dl)}"
    )

# LR warmup: 10% of all steps, capped at 2000. This is step-based and independent
# of config['warmup_epochs'], which is only passed to train_epoch.
warmup_steps = min(2000, total_steps // 10)

# OneCycleLR ramps from max_lr/div_factor up to max_lr over the warmup fraction,
# then cosine-anneals down for the remaining steps.
sched = optim.lr_scheduler.OneCycleLR(
    opt,
    max_lr=config['lr'],
    total_steps=total_steps,
    pct_start=warmup_steps / total_steps,
    anneal_strategy='cos',
)

print(f"Total steps: {total_steps}")
print(f"Warmup steps: {warmup_steps}")

Total steps: 4140
Warmup steps: 414


In [None]:
# Training loop
print(f"\nTraining for {config['epochs']} epochs...\n")

best_ckpt_path = os.path.join(config['out_dir'], 'best_model.pth')
best_ood_rmse = float('inf')
best_epoch = None  # epoch whose weights are stored at best_ckpt_path (None = nothing saved yet)
history = []

for epoch in range(1, config['epochs'] + 1):
    # Train one epoch; opt and sched are handed to train_epoch, which performs the updates
    train_metrics = train_epoch(
        model,
        loaders,
        opt,
        sched,
        delta=config['delta'],
        epoch=epoch,
        device=device,
        warmup_epochs=config['warmup_epochs'],
    )

    # Evaluate on the ID and OOD splits
    val_metrics = evaluate(model, loaders, delta=config['delta'], device=device)

    # Log
    print(f"Epoch {epoch}/{config['epochs']}")
    print(f"  Train Loss: {train_metrics['train_loss']:.4f} "
          f"(obs={train_metrics['L_obs']:.4f}, mdn={train_metrics['L_mdn']:.4f}, "
          f"dr_func={train_metrics['L_dr_func']:.4f}, rex={train_metrics['L_rex']:.4f})")
    print(f"  ID:  RMSE={val_metrics['id_rmse']:.4f}, MAE={val_metrics['id_mae']:.4f}, "
          f"Contrast={val_metrics['id_policy_contrast']:.4f}")
    print(f"  OOD: RMSE={val_metrics['ood_rmse']:.4f}, MAE={val_metrics['ood_mae']:.4f}, "
          f"Contrast={val_metrics['ood_policy_contrast']:.4f}")

    # Save best.
    # NOTE(review): the checkpoint is selected on OOD RMSE, and the final-evaluation
    # cell reports OOD RMSE on the same split, so that number is optimistically
    # biased. A validation split carved from the training data would avoid this.
    if val_metrics['ood_rmse'] < best_ood_rmse:
        best_ood_rmse = val_metrics['ood_rmse']
        best_epoch = epoch
        torch.save(model.state_dict(), best_ckpt_path)
        print(f"  *** New best OOD RMSE: {best_ood_rmse:.4f} ***")

    # Store history (tensor-valued eval outputs are dropped so records stay tabular)
    history.append({
        'epoch': epoch,
        **train_metrics,
        **{k: v for k, v in val_metrics.items() if not isinstance(v, torch.Tensor)}
    })

if best_epoch is None:
    # NaN never compares less than inf, so an all-NaN run (or epochs=0) would otherwise
    # save nothing. The evaluation cell would then fail, or load a stale checkpoint
    # left over from an earlier run.
    torch.save(model.state_dict(), best_ckpt_path)
    print("WARNING: OOD RMSE never improved; saved the final weights as best_model.pth.")
else:
    print(f"\nBest OOD RMSE {best_ood_rmse:.4f} at epoch {best_epoch}")

print("\nTraining complete!")


Training for 30 epochs...



In [None]:
# Final evaluation
print("="*60)
print("Final Evaluation")
print("="*60)

# Reload the checkpoint written by the training cell. map_location lets this work
# when the checkpoint was saved on a different device (e.g. trained on CUDA,
# reloaded on a CPU-only machine).
# NOTE(review): the checkpoint was selected by OOD RMSE, so the OOD numbers below
# are optimistically biased.
best_ckpt_path = os.path.join(config['out_dir'], 'best_model.pth')
model.load_state_dict(torch.load(best_ckpt_path, map_location=device))
final_metrics = evaluate(model, loaders, delta=config['delta'], device=device)

print(f"\nID:  RMSE={final_metrics['id_rmse']:.4f}, MAE={final_metrics['id_mae']:.4f}")
print(f"OOD: RMSE={final_metrics['ood_rmse']:.4f}, MAE={final_metrics['ood_mae']:.4f}")
print("\nPolicy contrast (Δpred):")
print(f"  ID:  {final_metrics['id_policy_contrast']:.4f}")
print(f"  OOD: {final_metrics['ood_policy_contrast']:.4f}")
# 'id_alpha' / 'ood_alpha' come from evaluate(); presumably the mean LoR gate value — TODO confirm
print("\nAlpha (LoR gate):")
print(f"  ID:  {final_metrics['id_alpha']:.4f}")
print(f"  OOD: {final_metrics['ood_alpha']:.4f}")

In [None]:
# Save results
out_dir = config['out_dir']
save_metrics_json(final_metrics, os.path.join(out_dir, 'metrics.json'))
make_all_plots(final_metrics, out_dir)

# List the artifacts expected in out_dir after this cell.
print(f"\nResults saved to {out_dir}")
for artifact in ('metrics.json',
                 'NL_MTP_HoF_ID_parity.png',
                 'NL_MTP_HoF_OOD_parity.png',
                 'best_model.pth'):
    print(f"  - {artifact}")

In [None]:
# Plot training history
import matplotlib.pyplot as plt
import pandas as pd


def _plot_history_panel(ax, frame, series, ylabel, title):
    """Draw one panel: each (column, label) pair in `series` as a line against epoch."""
    for column, label in series:
        ax.plot(frame['epoch'], frame[column], label=label)
    ax.set_xlabel('Epoch')
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    ax.legend()
    ax.grid(alpha=0.3)


df = pd.DataFrame(history)

if df.empty:
    # With no epochs recorded, the frame has no 'epoch' column to plot against.
    print("No training history to plot (history is empty).")
else:
    fig, axes = plt.subplots(2, 2, figsize=(12, 8))

    _plot_history_panel(axes[0, 0], df, [('train_loss', 'Train Loss')],
                        'Loss', 'Training Loss')
    _plot_history_panel(axes[0, 1], df, [('id_rmse', 'ID'), ('ood_rmse', 'OOD')],
                        'RMSE', 'RMSE over Epochs')
    _plot_history_panel(axes[1, 0], df,
                        [('L_obs', 'L_obs'), ('L_mdn', 'L_mdn'),
                         ('L_dr_func', 'L_DR-func'), ('L_rex', 'L_rex')],
                        'Loss', 'Loss Components')
    _plot_history_panel(axes[1, 1], df,
                        [('id_policy_contrast', 'ID'), ('ood_policy_contrast', 'OOD')],
                        'Δpred (Policy Contrast)', 'Policy Contrast over Epochs')

    fig.tight_layout()
    fig.savefig(os.path.join(config['out_dir'], 'training_history.png'), dpi=150)
    plt.show()

    print("Training history plot saved.")

In [None]:
# Display parity plots
from IPython.display import Image, display

# (heading, image file in out_dir) for each evaluation split, shown in order.
parity_plots = (
    ("ID Parity Plot:", 'NL_MTP_HoF_ID_parity.png'),
    ("\nOOD Parity Plot:", 'NL_MTP_HoF_OOD_parity.png'),
)
for heading, png_name in parity_plots:
    print(heading)
    display(Image(filename=os.path.join(config['out_dir'], png_name)))