In [21]:
# === Configuration ===
# Directory that holds config.yaml and the *.pt checkpoints to evaluate.
ckpt_dir = "/home/joseph_tennyson/182/in-context-learning-GLM/src/models/poisson-0.32"
n_tasks = 2_500      # number of independently sampled tasks
n_train = 40         # in-context examples per task (one extra point is sampled and held out)
lr = 5e-2            # Adam step size for the per-task baseline fit
max_steps = 100_000  # upper bound on baseline optimisation steps
tol = 1e-12          # baseline stops once the loss changes by less than this between steps


In [2]:
# === Imports ===
import os
import yaml
import torch
from types import SimpleNamespace
from tqdm import tqdm
from torch.nn import PoissonNLLLoss
from models import build_model  # Ensure this is in your working directory

# Use the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")


In [3]:
# === Data Sampling ===
def sample_data(n_tasks, n_points, d, scale=0.32, max_logit=4.0, generator=None):
    """Draw synthetic Poisson-regression tasks.

    Every task gets its own weight vector ``w = scale * N(0, I)``; inputs are
    ``x ~ N(0, I)`` and counts are ``y ~ Poisson(exp(min(x @ w, max_logit)))``.

    Args:
        n_tasks: Number of independent tasks.
        n_points: Number of (x, y) pairs drawn per task.
        d: Input dimensionality.
        scale: Multiplier applied to the standard-normal task weights.
        max_logit: Upper clamp applied to the log-rate before exponentiation,
            which bounds the Poisson rate at ``exp(max_logit)``. Defaults to
            the previously hard-coded value of 4.
        generator: Optional ``torch.Generator`` used for every random draw so
            a run can be reproduced exactly. ``None`` uses the global RNG.

    Returns:
        Tuple ``(xs, ys, ws)`` with shapes ``(n_tasks, n_points, d)``,
        ``(n_tasks, n_points)`` and ``(n_tasks, d, 1)``.
    """
    xs = torch.randn(n_tasks, n_points, d, generator=generator)
    ws = scale * torch.randn(n_tasks, d, 1, generator=generator)
    # Clamp before exp() so a rare large logit cannot produce a huge rate.
    log_rates = (xs @ ws).clamp(max=max_logit)
    ys = torch.poisson(torch.exp(log_rates), generator=generator).squeeze(-1)
    return xs, ys, ws


In [24]:
# === Load Config===
config_path = os.path.join(ckpt_dir, "config.yaml")
# Read inside a context manager so the file handle is closed deterministically.
# NOTE(review): FullLoader can build more than plain scalars/lists/dicts; only
# point this at config files you trust.
with open(config_path) as config_file:
    cfg = yaml.load(config_file, Loader=yaml.FullLoader)

# Expose the "model" section as attributes, which is the form handed to build_model.
model_conf = SimpleNamespace(**cfg["model"])

# === Sample Data ===
d = cfg["model"]["n_dims"]
# NOTE(review): if "scaling" is missing this falls back to 1.0, which differs
# from sample_data's own default of 0.32 -- confirm that is intended.
scale = cfg["training"]["task_kwargs"].get("scaling", 1.0)
# Draw one point more than n_train per task; the extra point is the held-out query.
xs_all, ys_all, _ = sample_data(n_tasks, n_train + 1, d, scale)
xs_train, xs_test = xs_all[:, :-1].to(device), xs_all[:, -1:].to(device)
ys_train, ys_test = ys_all[:, :-1].to(device), ys_all[:, -1:].to(device)


In [28]:
# === Transformer Evaluation ===
def evaluate_transformer(model, ckpt_file, xs_all, ys_all, xs_test, ys_test):
    """Print and return a checkpoint's mean Poisson NLL over every sequence position.

    The model is called as ``model(xs_all, ys_all)`` and its output is treated
    as log-rates (``log_input=True``) that are scored against ``ys_all`` at
    every position, averaged over tasks and positions.

    Args:
        model: Module called as ``model(xs, ys)``; expected to return a tensor
            shaped like ``ys_all`` (``[n_tasks, seq_len]``).
        ckpt_file: Checkpoint name or path; only its basename is printed.
        xs_all: Inputs fed to the model.
        ys_all: Counts fed to the model and used as the scoring targets.
        xs_test: Unused; kept so existing call sites keep working.
        ys_test: Unused; kept so existing call sites keep working.

    Returns:
        The mean Poisson NLL as a Python float (the same value that is printed).
    """
    # NOTE(review): this is an average over all positions of the sequence that
    # is passed in, not a loss on the held-out (xs_test, ys_test) pair, so it
    # is not the same quantity as a held-out-query NLL.
    model.eval()
    with torch.no_grad():
        log_rates = model(xs_all, ys_all)  # expected shape: [n_tasks, seq_len]
        loss_fn = PoissonNLLLoss(log_input=True, full=True, reduction="mean")
        nll = loss_fn(log_rates, ys_all).item()
    print(f"{os.path.basename(ckpt_file)} | Transformer Poisson NLL (all points): {nll:.4f}")
    return nll


In [8]:
# === Oracle Gradient Descent Evaluation ===
def evaluate_oracle(xs_train, ys_train, xs_test, ys_test, lr, max_steps, tol, max_logit=3.0):
    """Fit one Poisson-regression weight vector per task with Adam and score it.

    A randomly initialised weight tensor of shape ``(n_tasks, d, 1)`` is
    optimised on the mean Poisson NLL of the training points. Optimisation
    stops after ``max_steps`` steps or as soon as the loss changes by less
    than ``tol`` between consecutive steps. The fitted weights are then scored
    on ``(xs_test, ys_test)``.

    Args:
        xs_train: Training inputs, shape ``(n_tasks, n_train, d)``.
        ys_train: Training counts, matching ``(xs_train @ w).squeeze(-1)``.
        xs_test: Held-out inputs, shape ``(n_tasks, n_test, d)``.
        ys_test: Held-out counts, matching ``(xs_test @ w).squeeze(-1)``.
        lr: Adam learning rate.
        max_steps: Maximum number of optimisation steps.
        tol: Absolute loss-change threshold used as the stopping rule.
        max_logit: Upper clamp on the log-rate before ``exp``. The default of
            3.0 is the previously hard-coded value.

    Returns:
        The held-out mean Poisson NLL as a Python float (also printed).
    """
    n_tasks, _, d = xs_train.shape
    # Allocate the weights where the data lives instead of relying on a
    # notebook-level ``device`` global, so the function works on any device.
    w_hat = torch.randn(
        n_tasks, d, 1, device=xs_train.device, dtype=xs_train.dtype, requires_grad=True
    )
    opt = torch.optim.Adam([w_hat], lr=lr)
    loss_fn = PoissonNLLLoss(log_input=False, full=True, reduction="mean")

    def predict_rate(xs):
        # Shared by the training loop and the final evaluation so both use the same clamp.
        return torch.exp((xs @ w_hat).squeeze(-1).clamp(max=max_logit))

    # NOTE(review): with float32 inputs the loss is a float32 value, so a tol
    # far below ~1e-7 can only be met when two consecutive losses are
    # bit-identical -- confirm that stopping rule is what is wanted.
    prev = float("inf")
    for step in tqdm(range(1, max_steps + 1), desc="Oracle GD"):
        loss = loss_fn(predict_rate(xs_train), ys_train)
        # Read the scalar once per step; every .item() forces a host sync.
        current = loss.item()
        delta = abs(prev - current)
        if delta < tol:
            print(f"Oracle converged at step {step} (Δloss={delta:.2e})")
            break
        prev = current
        opt.zero_grad()
        loss.backward()
        opt.step()

    with torch.no_grad():
        final_loss = loss_fn(predict_rate(xs_test), ys_test).item()
    print(f"Oracle baseline Poisson NLL: {final_loss:.4f}")
    return final_loss


In [9]:
# === Naive Baseline Evaluation ===
def evaluate_naive(ys_train, ys_test):
    """Print the Poisson NLL of predicting each task's mean training count.

    Args:
        ys_train: Training counts; averaged over dimension 1 for each task.
        ys_test: Held-out counts the per-task mean is scored against.
    """
    criterion = PoissonNLLLoss(log_input=False, full=True, reduction="mean")
    # keepdim=True keeps a length-1 second axis on the per-task mean.
    per_task_rate = ys_train.mean(dim=1, keepdim=True)
    nll = criterion(per_task_rate, ys_test)
    print(f"Naive Mean Baseline Poisson NLL: {nll.item():.4f}")


In [26]:
def _checkpoint_order(fname):
    """Sort key: numeric step for ``<prefix>_<step>.pt`` names, all other names last."""
    step = os.path.splitext(fname)[0].rsplit("_", 1)[-1]
    if step.isdecimal():
        return (0, int(step), fname)
    return (1, 0, fname)


# Sort by the numeric step so that e.g. model_10000.pt follows model_9000.pt;
# a plain string sort would place it before model_2000.pt.
pt_files = sorted(
    (f for f in os.listdir(ckpt_dir) if f.endswith(".pt")),
    key=_checkpoint_order,
)
models = []

for fname in pt_files:
    ckpt_path = os.path.join(ckpt_dir, fname)
    # NOTE(review): torch.load unpickles the file; only load checkpoints from
    # a source you trust.
    state = torch.load(ckpt_path, map_location=device)

    model = build_model(model_conf).to(device)

    print(f"Loading checkpoint {fname}...")
    # A file may hold either a bare state dict or a dict that wraps it under
    # 'model_state_dict'; accept both.
    model.load_state_dict(state['model_state_dict'] if 'model_state_dict' in state else state)

    models.append([model, fname])

Loading checkpoint model_1000.pt...
Loading checkpoint model_2000.pt...
Loading checkpoint model_3000.pt...
Loading checkpoint model_4000.pt...
Loading checkpoint model_5000.pt...
Loading checkpoint model_6000.pt...
Loading checkpoint model_7000.pt...
Loading checkpoint state.pt...


In [29]:
# === Run Transformer Models ===
# NOTE(review): evaluate_transformer is given the training split and averages
# its loss over every position of it, while the oracle and naive baselines
# below are scored on the held-out (xs_test, ys_test) point, so the printed
# numbers are not directly comparable.
for model, fname in models:
    print(f"Evaluating checkpoint {fname}...")
    evaluate_transformer(model, fname, xs_train, ys_train, xs_test, ys_test)

# === Run Oracle & Naive ===
evaluate_oracle(xs_train, ys_train, xs_test, ys_test, lr, max_steps, tol)
evaluate_naive(ys_train, ys_test)


Evaluating checkpoint model_1000.pt...
model_1000.pt | Transformer Poisson NLL (all points): 2.2746
Evaluating checkpoint model_2000.pt...
model_2000.pt | Transformer Poisson NLL (all points): 1.9898
Evaluating checkpoint model_3000.pt...
model_3000.pt | Transformer Poisson NLL (all points): 1.8184
Evaluating checkpoint model_4000.pt...
model_4000.pt | Transformer Poisson NLL (all points): 1.7905
Evaluating checkpoint model_5000.pt...
model_5000.pt | Transformer Poisson NLL (all points): 1.7830
Evaluating checkpoint model_6000.pt...
model_6000.pt | Transformer Poisson NLL (all points): 1.7787
Evaluating checkpoint model_7000.pt...
model_7000.pt | Transformer Poisson NLL (all points): 1.7766
Evaluating checkpoint state.pt...
state.pt | Transformer Poisson NLL (all points): 1.7766


Oracle GD:   0%|          | 385/100000 [00:00<01:55, 864.48it/s]

Oracle converged at step 386 (Δloss=0.00e+00)
Oracle baseline Poisson NLL: 1.5592
Naive Mean Baseline Poisson NLL: 2.3288



