In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets
import numpy as np
import os
import sys
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import json

# Add parent directory to path to import utils etc
sys.path.append(os.path.dirname(os.path.abspath("")))
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(""))))

from utils import find_raptor_checkpoint
from dataloader import imagenet_transform
from dino_wrapper import DinoModelWrapper
from raptor_wrapper import RaptorWrapper
from overcomplete.metrics import r2_score

In [None]:
from paths import IMAGENET_VAL_DIR
# Project root = parent of the notebook's working directory (the same directory the
# sys.path setup in the import cell adds); passed to find_raptor_checkpoint below.
BASE_DIR = os.path.dirname(os.path.abspath(""))
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
BATCH_SIZE = 64

# Runs launched by run_all.sh, as (variant, model_seed, probe_seed).
# model_seed is None for the plain DINOv2 baselines, which have no RAPTOR backbone.
_EXPERIMENT_ROWS = [
    ("dino_s", None, 4000),
    ("dino_b", None, 4001),
    ("raptor2", 1001, 4002),
    ("raptor2", 1002, 4003),
    ("raptor2", 1003, 4004),
    ("raptor3", 1101, 4005),
    ("raptor3", 1102, 4006),
    ("raptor3", 1103, 4007),
    ("raptor4", 1201, 4008),
    ("raptor4", 1202, 4009),
    ("raptor4", 1203, 4010),
]
# Build every entry from the one table so all dicts share the same keys in the same
# order (the per-run log line prints the dict repr).
experiments = [
    {"variant": variant, "model_seed": model_seed, "seed": probe_seed}
    for variant, model_seed, probe_seed in _EXPERIMENT_ROWS
]

In [None]:
# Unshuffled ImageNet validation split, shared by every evaluation below.
valset = datasets.ImageFolder(root=IMAGENET_VAL_DIR, transform=imagenet_transform())
val_loader = DataLoader(
    valset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    # Up to 8 workers, but never more than the machine has cores (DataLoader warns
    # about oversubscription otherwise).
    num_workers=min(8, os.cpu_count() or 1),
    # Pinned host memory only speeds up host->GPU copies; without CUDA it is useless
    # and PyTorch emits a warning.
    pin_memory=torch.cuda.is_available(),
)


In [None]:
def _classifier_path(variant, model_seed, probe_seed):
    """Absolute path of the probe checkpoint, following the train_probe.py naming scheme."""
    if variant.startswith("raptor"):
        filename = f"{variant}_classifier_modelseed_{model_seed}_probeseed_{probe_seed}.pt"
    else:
        filename = f"{variant}_classifier_probeseed_{probe_seed}.pt"
    # Resolved against the current working directory.
    return os.path.abspath(filename)


def _load_raptor(variant, model_seed, classifier_path, device):
    """Build a RaptorWrapper around a fresh DINOv2 ViT-B/14 hub model; return (model, dino)."""
    try:
        raptor_model_path = find_raptor_checkpoint(variant, model_seed, BASE_DIR)
        print(f"Loading Raptor backbone from: {raptor_model_path}")
    except Exception as e:
        print(f"Error locating Raptor checkpoint: {e}")
        return None, None

    try:
        model_state = torch.load(raptor_model_path, map_location=device)
        dino = torch.hub.load("facebookresearch/dinov2", "dinov2_vitb14_reg").to(device)
        model = RaptorWrapper(model_state, dino)  # classifier weights are loaded separately
        model.classifier.load_state_dict(torch.load(classifier_path, map_location=device))

        model.dino.eval()
        model.raptor.eval()
        model.classifier.eval()
        model = model.to(device).float()
        return model, dino
    except Exception as e:
        print(f"Error loading Raptor model: {e}")
        return None, None


def _load_dino(variant, classifier_path, device):
    """Build a DinoModelWrapper (ViT-S/14 for "dino_s", else ViT-B/14); return (model, None)."""
    dino_name = "dinov2_vits14_reg" if variant == "dino_s" else "dinov2_vitb14_reg"
    print(f"Loading Dino model: {dino_name}")
    # Same report-and-skip policy as the RAPTOR path, so one bad checkpoint does not
    # abort the whole evaluation sweep.
    try:
        model = DinoModelWrapper(dino_model=dino_name, device=device).to(device)
        model.classifier.load_state_dict(torch.load(classifier_path, map_location=device))
        model.dino.eval()
        model.classifier.eval()
        model = model.to(device).float()
    except Exception as e:
        print(f"Error loading Dino model: {e}")
        return None, None
    return model, None


def load_model_and_classifier(variant, model_seed, probe_seed, device):
    """Load a backbone plus its trained linear-probe classifier for one experiment.

    Args:
        variant: ``"raptor*"`` for a RAPTOR backbone or ``"dino*"`` for plain DINOv2
            (``"dino_s"`` selects ViT-S/14; any other ``dino*`` name selects ViT-B/14).
        model_seed: Seed of the RAPTOR backbone checkpoint; only used for RAPTOR variants.
        probe_seed: Seed of the linear probe; part of the classifier filename.
        device: Device the weights are loaded onto.

    Returns:
        ``(model, dino)``. ``model`` has its submodules switched to eval mode and is cast
        with ``.float()``. ``dino`` is the DINOv2 hub model passed to ``RaptorWrapper`` for
        RAPTOR variants, and ``None`` for DINO variants. Returns ``(None, None)`` when the
        classifier file is missing, the backbone cannot be located or loaded, or the
        variant is unknown. A message is printed in these cases instead of raising.
    """
    classifier_path = _classifier_path(variant, model_seed, probe_seed)
    if not os.path.exists(classifier_path):
        print(f"Warning: Classifier not found at {classifier_path}")
        return None, None

    if variant.startswith("raptor"):
        return _load_raptor(variant, model_seed, classifier_path, device)
    if variant.startswith("dino"):
        return _load_dino(variant, classifier_path, device)

    print(f"Warning: Unknown variant {variant!r}; expected a 'raptor*' or 'dino*' name")
    return None, None

In [None]:
def validation_step(model, dino_ref, dataloader, variant, device=None):
    """Evaluate top-1 accuracy and, for RAPTOR variants, activation R² over a dataloader.

    Args:
        model: Wrapper called as ``model(x, layer_start=0, layer_end=12)`` that returns a
            pair ``(logits, activations)``.
        dino_ref: Reference wrapper called the same way; its activations are the R²
            targets for RAPTOR variants. Unused for non-RAPTOR variants. If ``None``,
            R² is not computed.
        dataloader: Iterable of ``(x, y)`` batches.
        variant: Experiment variant name; names starting with ``"raptor"`` enable R².
        device: Device batches are moved to. Defaults to the notebook-level ``DEVICE``.

    Returns:
        ``(avg_acc, avg_r2)`` as Python floats. ``avg_r2`` is ``0.0`` for non-RAPTOR
        variants. For RAPTOR variants it is the batch-size-weighted mean over batches
        where R² could be computed, or ``nan`` if there were none.

    Raises:
        ValueError: If ``dataloader`` yields no samples.
    """
    if device is None:
        device = DEVICE

    model.eval()
    is_raptor = variant.startswith("raptor")
    # R² needs reference activations; without a reference model it is reported as NaN.
    compute_r2 = is_raptor and dino_ref is not None
    if compute_r2:
        dino_ref.eval()

    num_correct = 0
    num_samples = 0
    r2_weighted_sum = 0.0  # per-batch R² weighted by batch size
    r2_num_samples = 0     # samples that actually contributed to r2_weighted_sum
    r2_skipped_batches = 0

    with torch.no_grad():
        for x, y in dataloader:
            x = x.to(device)
            y = y.to(device)
            batch_size = x.size(0)
            num_samples += batch_size

            # Both wrapper kinds return (logits, activations); activations only feed R².
            logits, a_pred = model(x, layer_start=0, layer_end=12)
            num_correct += (logits.argmax(dim=1) == y).sum().item()

            if not compute_r2:
                continue

            _, a_dino = dino_ref(x, layer_start=0, layer_end=12)
            if a_pred is None or a_dino is None or a_pred.shape != a_dino.shape:
                # Not comparable: leave the batch out of the R² average entirely
                # (numerator AND denominator) rather than counting it as zero.
                r2_skipped_batches += 1
                continue

            # Flatten all leading dims so each row is one feature vector of size
            # size(3) (assumes 4-D activations -- TODO confirm layout).
            batch_r2 = r2_score(
                a_pred.reshape(-1, a_pred.size(3)),
                a_dino.reshape(-1, a_dino.size(3)),
            )
            # float() accepts both a Python number and a one-element tensor.
            r2_weighted_sum += float(batch_r2) * batch_size
            r2_num_samples += batch_size

    if num_samples == 0:
        raise ValueError("validation_step: dataloader yielded no samples")
    if r2_skipped_batches:
        print(f"Warning: R2 not computed for {r2_skipped_batches} batch(es) "
              "with missing or mismatched activations")

    avg_acc = num_correct / num_samples
    if not is_raptor:
        # R² is not defined for a plain DINO probe; 0.0 keeps the results table numeric.
        avg_r2 = 0.0
    elif r2_num_samples:
        avg_r2 = r2_weighted_sum / r2_num_samples
    else:
        avg_r2 = float("nan")
    return avg_acc, avg_r2

In [None]:
# A single DinoModelWrapper shared across every RAPTOR run; validation_step compares
# RAPTOR's activations against its outputs to compute R².
dino_ref = (
    DinoModelWrapper(device=DEVICE)
    .to(DEVICE)
    .float()
)
dino_ref.eval()

In [None]:
# Evaluate every configured run; runs whose checkpoints fail to load are skipped.
results = []
for exp in experiments:
    print(f"Processing {exp}")
    # Keep the second return value under a real name (not `_`, which IPython uses for
    # output history) so it can be released explicitly at the end of the iteration.
    model, backbone = load_model_and_classifier(exp["variant"], exp["model_seed"], exp["seed"], DEVICE)

    if model is None:
        print(f"Skipping {exp} due to load failure")
        continue

    acc, r2 = validation_step(model, dino_ref, val_loader, exp["variant"])
    print(f"Result: Acc={acc:.4f}, R2={r2:.4f}")

    results.append({
        "variant": exp["variant"],
        "seed": exp["seed"],
        "model_seed": exp["model_seed"],
        "acc": acc,
        "r2": r2,
    })

    # Drop every reference to this run's networks before clearing the CUDA cache.
    # Otherwise the separately returned backbone stays alive until the next iteration's
    # load has already allocated another one.
    del model, backbone
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

In [None]:
# One JSON record per evaluated run, so the summary and plots below can be rebuilt
# without re-running the evaluation loop.
results_df = pd.DataFrame(data=results)
results_df.to_json(path_or_buf="eval_results.json", orient="records")
print("Results saved to eval_results.json")
display(results_df)

In [None]:
# Mean and standard deviation of accuracy and R² across the seeds of each variant.
# Variants with a single run (the DINO baselines) get a NaN std from pandas.
agg = results_df.groupby("variant").agg(
    {
        "acc": ["mean", "std"],
        "r2": ["mean", "std"],
    }
)


def get_stats(variant_name):
    """Return ``(acc_mean, acc_std, r2_mean, r2_std)`` for ``variant_name``.

    Reads the module-level ``agg`` table. A variant missing from ``agg`` (e.g. every run
    of it was skipped at load time) yields ``(0, 0, 0, 0)``.
    NOTE(review): a zero ViT-B accuracy cannot be used to normalise the plot.
    """
    if variant_name not in agg.index:
        return 0, 0, 0, 0
    row = agg.loc[variant_name]
    return (
        row[("acc", "mean")],
        row[("acc", "std")],
        row[("r2", "mean")],
        row[("r2", "std")],
    )


# Per-depth RAPTOR summaries and DINO baselines consumed by the plotting code below.
raptor2_acc, raptor2_err, raptor2_r2, raptor2_r2_err = get_stats("raptor2")
raptor3_acc, raptor3_err, raptor3_r2, raptor3_r2_err = get_stats("raptor3")
raptor4_acc, raptor4_err, raptor4_r2, raptor4_r2_err = get_stats("raptor4")

dino_s_acc = get_stats("dino_s")[0]
dino_b_acc = get_stats("dino_b")[0]

def plot_bar_acc_r2_with_error(accs, acc_errs, r2s, r2_errs, dino_s_acc, dino_b_acc, fontsize=16,
                               n_blocks=None, acc_ylim=(90, 100), r2_ylim=(0.5, 1.0),
                               save_path="raptor_vs_dino_bar_scaled_error_bars.pdf"):
    """Grouped bar chart of RAPTOR accuracy (left axis) and R² (right axis) per depth.

    Accuracies and their error bars are expressed as a percentage of ``dino_b_acc``
    (DINOv2 ViT-B). ``dino_s_acc`` is drawn on the same scale as a dashed reference line.

    Args:
        accs, acc_errs: Per-depth RAPTOR accuracy means / standard deviations, in the
            same units as ``dino_b_acc``.
        r2s, r2_errs: Per-depth R² means / standard deviations.
        dino_s_acc: DINOv2 ViT-S accuracy, drawn as a horizontal baseline.
        dino_b_acc: DINOv2 ViT-B accuracy used for normalisation; must be finite and > 0.
        fontsize: Base font size for axis labels and ticks.
        n_blocks: Number of recurrent blocks for each bar group (x tick labels).
            Defaults to 2, 3, 4, ... with one entry per element of ``accs``.
        acc_ylim, r2_ylim: y-limits of the accuracy / R² axes; ``None`` autoscales.
        save_path: File the figure is saved to; ``None`` skips saving.

    Raises:
        ValueError: If ``dino_b_acc`` is not a positive finite number, or if the
            per-depth sequences (and ``n_blocks``) differ in length.
    """
    # Validate before touching matplotlib so a bad input fails fast and clearly.
    if not (np.isfinite(dino_b_acc) and dino_b_acc > 0):
        raise ValueError(f"dino_b_acc must be a positive finite number to normalise by, got {dino_b_acc!r}")
    n_groups = len(accs)
    if not (len(acc_errs) == len(r2s) == len(r2_errs) == n_groups):
        raise ValueError("accs, acc_errs, r2s and r2_errs must all have the same length")
    if n_blocks is None:
        n_blocks = range(2, 2 + n_groups)
    n_blocks = list(n_blocks)
    if len(n_blocks) != n_groups:
        raise ValueError("n_blocks must have exactly one entry per bar group")

    sns.set_theme()
    fig, ax1 = plt.subplots(figsize=(8, 8))

    # Normalise to DINOv2 ViT-B accuracy; error bars are scaled by the same factor.
    acc_norm = [(a / dino_b_acc) * 100.0 for a in accs]
    acc_err_norm = [(e / dino_b_acc) * 100.0 for e in acc_errs]
    dino_s_norm = (dino_s_acc / dino_b_acc) * 100.0

    width = 0.35
    x_idx = np.arange(n_groups)

    acc_color = sns.color_palette("muted")[0]   # soft blue
    r2_color = sns.color_palette("muted")[1]    # soft orange

    # DINOv2 ViT-S baseline
    ax1.axhline(dino_s_norm, label="DINOv2 ViT-S Acc (%)", linestyle="--", color="gray")

    # ---- Left axis: accuracy ----
    bars_acc = ax1.bar(x_idx - width / 2, acc_norm, width, yerr=acc_err_norm,
                       label="RAPTOR Acc (%)", color=acc_color, capsize=5)
    ax1.set_ylabel("Accuracy (% of ViT-B)", color=acc_color, fontsize=fontsize)
    ax1.tick_params(axis="y", labelcolor=acc_color)
    if acc_ylim is not None:
        ax1.set_ylim(*acc_ylim)

    # ---- Right axis: R² ----
    ax2 = ax1.twinx()
    bars_r2 = ax2.bar(x_idx + width / 2, r2s, width, yerr=r2_errs,
                      label="RAPTOR R²", color=r2_color, capsize=5)
    ax2.set_ylabel("R²", color=r2_color, fontsize=fontsize)
    ax2.tick_params(axis="y", labelcolor=r2_color, labelsize=fontsize)
    if r2_ylim is not None:
        ax2.set_ylim(*r2_ylim)

    # ---- Value labels (normalised accuracy, raw R²) just above each bar ----
    # Offsets scale with the visible axis range so labels stay clear of the bars.
    acc_lo, acc_hi = ax1.get_ylim()
    r2_lo, r2_hi = ax2.get_ylim()
    acc_offset = 0.02 * (acc_hi - acc_lo)
    r2_offset = 0.04 * (r2_hi - r2_lo)
    for bar in bars_acc:
        h = bar.get_height()
        ax1.text(bar.get_x() + bar.get_width() / 2, h + acc_offset, f"{h:.1f}",
                 ha="center", va="bottom", fontsize=fontsize - 4, color=acc_color)
    for bar in bars_r2:
        h = bar.get_height()
        ax2.text(bar.get_x() + bar.get_width() / 2, h + r2_offset, f"{h:.2f}",
                 ha="center", va="bottom", fontsize=fontsize - 4, color=r2_color)

    # ---- X axis: one tick per depth ----
    ax1.set_xticks(list(x_idx))
    ax1.set_xticklabels([str(n) for n in n_blocks], fontsize=fontsize)
    ax1.set_xlabel("Number of Recurrent Blocks", fontsize=fontsize)

    # ---- Legend combining both axes ----
    h1, l1 = ax1.get_legend_handles_labels()
    h2, l2 = ax2.get_legend_handles_labels()
    ax1.legend(h1 + h2, l1 + l2, fontsize=12, loc="upper left")

    ax1.tick_params(axis="both", labelsize=fontsize)
    ax2.tick_params(axis="both", labelsize=fontsize)
    ax1.grid(False)
    ax2.grid(False)
    fig.tight_layout()
    if save_path is not None:
        fig.savefig(save_path)
    plt.show()

# One tuple per RAPTOR depth, ordered 2 -> 3 -> 4 recurrent blocks to match the x-axis.
_per_depth_stats = [
    (raptor2_acc, raptor2_err, raptor2_r2, raptor2_r2_err),
    (raptor3_acc, raptor3_err, raptor3_r2, raptor3_r2_err),
    (raptor4_acc, raptor4_err, raptor4_r2, raptor4_r2_err),
]
# Transpose into one list per metric.
acc_list, acc_err_list, r2_list, r2_err_list = (list(column) for column in zip(*_per_depth_stats))

plot_bar_acc_r2_with_error(
    accs=acc_list,
    acc_errs=acc_err_list,
    r2s=r2_list,
    r2_errs=r2_err_list,
    dino_s_acc=dino_s_acc,
    dino_b_acc=dino_b_acc,
)