In [None]:
# ─────────────────────────────────────────────────────────────────────────────
# 05_benchmarks.ipynb
# ─────────────────────────────────────────────────────────────────────────────

# 0) Ensure src/ is on PYTHONPATH, and import necessary pieces
%run setup.py

import time
import numpy as np
import pandas as pd
import torch
from pathlib import Path

from sklearn.metrics import (
    log_loss,
    accuracy_score,
    f1_score,
    precision_score,
    recall_score,
    roc_auc_score,
)
from src.pretrained_models import default_experts, MoEClassifier
from src.logs import log_event, LogKind

# ─────────────────────────────────────────────────────────────────────────────
# 1) Find all “moe_<key>_idxs.npy” under models/gates/ and pair with
#    the matching “gate_<key>_retrained.pt”
# ─────────────────────────────────────────────────────────────────────────────
# Directory holding both the expert-index arrays and the retrained gate weights.
gate_dir = Path("models/gates")
idx_files = sorted(gate_dir.glob("moe_*_idxs.npy"))

if len(idx_files) == 0:
    raise RuntimeError(f"No ‘moe_<key>_idxs.npy’ files found in {gate_dir}.")

# File-name decoration wrapped around the run key in every index file's stem.
idx_prefix, idx_suffix = "moe_", "_idxs"

# Collect one (key, idx_path, gate_path) triple per index file that has a
# matching retrained gate checkpoint; index files without one are ignored.
top_runs = []
for idx_path in idx_files:
    stem = idx_path.stem  # e.g. "moe_LRFeatureExpert+XGBFeatureExpert_idxs"
    if stem.startswith(idx_prefix) and stem.endswith(idx_suffix):
        # Strip the decoration, e.g. -> "LRFeatureExpert+XGBFeatureExpert"
        key = stem[len(idx_prefix) : -len(idx_suffix)]
        gate_path = gate_dir / f"gate_{key}_retrained.pt"
        if gate_path.exists():
            top_runs.append((key, idx_path, gate_path))

if len(top_runs) == 0:
    raise RuntimeError(f"No matching ‘gate_<key>_retrained.pt’ found under {gate_dir}.")

# Summarise what will be benchmarked: a header line, then one line per run.
report_lines = [f"Found {len(top_runs)} gate checkpoints:"]
report_lines.extend(
    f"  * key = {key!r}, idxs = {idx_path.name}, gate = {gate_path.name}"
    for key, idx_path, gate_path in top_runs
)
print("\n".join(report_lines))

# ─────────────────────────────────────────────────────────────────────────────
# 2) Prepare the TEST split once
# ─────────────────────────────────────────────────────────────────────────────
# Read the held-out split, then discard rows where either question is missing.
test_df = pd.read_csv("../data/splits/test.csv")
test_df = test_df.dropna(subset=["question1", "question2"])

# Question pairs in row order, plus the integer labels aligned with them.
pairs_test = list(zip(test_df["question1"], test_df["question2"]))
y_test = test_df["is_duplicate"].values.astype(int)

# ─────────────────────────────────────────────────────────────────────────────
# 3) Evaluate each “top” run on TEST, collect & log metrics
# ─────────────────────────────────────────────────────────────────────────────
# QuoraDistilExpert needs the 768-dim question embeddings; both paths are
# forwarded unchanged to default_experts() below.
emb_path_768 = "../data/processed/question_embeddings_768.npy"
lr_path      = "../models/pretrained/quoradistil_lr.pkl"

for key, idx_path, gate_path in top_runs:
    # Positions (into the default expert list) that were saved for this run,
    # e.g. [0, 3, 5, ...]
    idxs = np.load(idx_path).tolist()

    # Rebuild the default expert list and keep only the saved positions, in
    # the saved order.
    # NOTE(review): the list is rebuilt for every run; if default_experts() is
    # expensive and the experts hold no per-run state, it could be hoisted
    # above the loop — confirm before changing.
    all_experts = default_experts(
        emb_path=emb_path_768,
        lr_path=lr_path
    )
    experts = [all_experts[i] for i in idxs]

    # Build the MoE wrapper and restore the saved gate weights.
    # NOTE(review): epochs=0 presumably means "no further training" — confirm
    # against MoEClassifier.
    moe = MoEClassifier(experts, lr=1.0, epochs=0)
    # NOTE(review): torch.load unpickles the file; only load checkpoints from
    # a trusted source.
    moe.gate.load_state_dict(torch.load(gate_path, map_location="cpu"))
    moe.gate.eval()

    # Time inference on TEST with a monotonic, high-resolution clock so the
    # measurement is not distorted by system clock adjustments.
    t0 = time.perf_counter()
    probs = moe.predict_prob(pairs_test)
    inference_time = time.perf_counter() - t0

    # Hard predictions at the 0.5 threshold
    preds = (probs > 0.5).astype(int)

    # Passing labels explicitly keeps log-loss defined even when TEST holds a
    # single class; without it log_loss raises ValueError and the AUC
    # fallback below could never be reached.
    ll   = log_loss(y_test, probs, labels=[0, 1])
    acc  = accuracy_score(y_test, preds)
    f1   = f1_score(y_test, preds)
    prec = precision_score(y_test, preds)
    rec  = recall_score(y_test, preds)

    # ROC-AUC is undefined unless both classes are present; record NaN then.
    try:
        auc = roc_auc_score(y_test, probs)
    except ValueError:
        auc = float("nan")

    # Print to notebook
    print("\n" + "="*80)
    print(f"Gate key: {key}")
    print(f"  * TEST log-loss  = {ll:.4f}")
    print(f"  * TEST accuracy  = {acc:.4f}")
    print(f"  * TEST  F1 score = {f1:.4f}")
    print(f"  * TEST Precision = {prec:.4f}")
    print(f"  * TEST   Recall  = {rec:.4f}")
    print(f"  * TEST    AUC    = {auc:.4f}")
    print(f"  * Inference time = {inference_time:.2f}s")
    print("="*80)

    # Record the run through the project logger.
    # NOTE(review): presumably this lands in metric_logs/benchmarks.csv, the
    # file the reporting cell reads — confirm against src.logs.
    log_event(
        LogKind.TEST,
        model=key,
        phase="eval",
        seconds=round(inference_time, 2),
        test_LL=round(ll, 6),
        test_ACC=round(acc, 4),
        test_F1=round(f1, 4),
        test_PREC=round(prec, 4),
        test_REC=round(rec, 4),
        test_AUC=round(auc, 4),
    )

In [None]:
import pandas as pd
import matplotlib.pyplot as plt

# 1) Load the benchmark CSV
df = pd.read_csv("metric_logs/benchmarks.csv")

# `ace_tools` only exists inside the ChatGPT sandbox. Import it when present,
# otherwise fall back to the notebook's own rich display so the cell still
# runs on a regular Jupyter kernel.
try:
    import ace_tools as tools
except ImportError:
    tools = None


def _show_frame(name, frame):
    """Render `frame` under the heading `name` with whichever viewer is available."""
    if tools is not None:
        tools.display_dataframe_to_user(name=name, dataframe=frame)
    else:
        print(name)
        display(frame)  # `display` is injected as a builtin by IPython kernels


# 2) Show the full DataFrame for inspection
_show_frame("Benchmark Results", df)

# 3) Display summary statistics
summary = df.describe().round(4)
_show_frame("Summary Statistics", summary)

# 4-6) Bar plots: one figure per headline metric, all sharing the same layout.
# NOTE(review): if benchmarks.csv accumulates rows across several runs, repeated
# model labels will be drawn on top of each other — confirm whether the log
# should be de-duplicated before plotting.
for metric_column, metric_label in (
    ("test_LL", "Test Log-Loss"),
    ("test_ACC", "Test Accuracy"),
    ("test_F1", "Test F1 Score"),
):
    plt.figure(figsize=(8, 5))
    plt.bar(df["model"], df[metric_column])
    plt.xticks(rotation=45, ha="right")
    plt.ylabel(metric_label)
    plt.title(f"{metric_label} by Model")
    plt.tight_layout()
    plt.show()

# 7) Scatter: Inference Time vs Test Log-Loss
plt.figure(figsize=(6, 6))
plt.scatter(df["seconds"], df["test_LL"])
plt.xlabel("Inference Time (s)")
plt.ylabel("Test Log-Loss")
plt.title("Inference Time vs Test Log-Loss")
plt.tight_layout()
plt.show()

# 8) Correlation matrix heatmap of metrics
corr = df[
    ["test_LL", "test_ACC", "test_F1", "test_PREC", "test_REC", "test_AUC", "seconds"]
].corr()

plt.figure(figsize=(6, 6))
plt.imshow(corr, cmap="viridis", interpolation="none")
plt.colorbar()
plt.xticks(range(len(corr)), corr.columns, rotation=45, ha="right")
plt.yticks(range(len(corr)), corr.columns)
plt.title("Correlation Matrix of Metrics")
plt.tight_layout()
plt.show()