# ClinVar Pathogenicity Evaluation

This notebook evaluates the multitask CRISPR design model on held-out ClinVar test data for variant pathogenicity classification.

## Objectives
1. Load trained model checkpoint
2. Evaluate on held-out test set
3. Compute classification metrics (AUROC, AUPRC, F1, accuracy)
4. Plot ROC and PR curves
5. Analyze the confusion matrix, calibration, threshold sensitivity, and misclassifications

In [None]:
import sys
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import torch
from transformers import AutoTokenizer

sys.path.insert(0, str(Path.cwd().parent / "src"))

from crispr_design_agent.evaluation.metrics import (
    compute_classification_metrics,
    compute_pr_curve_data,
    compute_roc_curve_data,
)
from crispr_design_agent.evaluation.visualization import (
    plot_calibration_curve,
    plot_confusion_matrix,
    plot_pr_curve,
    plot_roc_curve,
)
from crispr_design_agent.training.module import MultiTaskLightningModule

# Plot styling: seaborn "whitegrid" theme, figures rendered at 100 DPI.
sns.set_style("whitegrid")
plt.rcParams.update({"figure.dpi": 100})

## Configuration

In [None]:
# --- Inputs ------------------------------------------------------------------
CHECKPOINT_PATH = "../models/checkpoints/multitask-epoch=10.ckpt"  # trained model checkpoint
DATA_PATH = "../data/processed/clinvar.parquet"  # processed ClinVar table

# --- Evaluation settings -----------------------------------------------------
SEED = 42  # random_state for the shuffle that precedes the split
VAL_SPLIT = 0.1  # fraction of the shuffled rows kept as the evaluation set
MAX_LENGTH = 1024  # tokenizer truncation / padding length
THRESHOLD = 0.5  # probability at or above which a variant is called pathogenic

# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

print(f"Using device: {DEVICE}")

## Load Data

In [None]:
# Read the processed ClinVar table and report its size and schema.
df = pd.read_parquet(DATA_PATH)

n_variants = len(df)
column_names = df.columns.tolist()
print(f"Loaded {n_variants} ClinVar variants")
print(f"\nDataset shape: {df.shape}")
print(f"\nColumns: {column_names}")

# Keep this expression last so the notebook renders the preview table.
df.head()

In [None]:
# Label balance: row counts per class, then the share of pathogenic rows.
labels = df["is_pathogenic"]
print("\nClass distribution:")
print(labels.value_counts())
print(f"\nPathogenic ratio: {labels.mean():.2%}")

In [None]:
# Shuffle deterministically (seeded), then keep the trailing VAL_SPLIT fraction
# of rows as the evaluation set; the leading rows are not used in this notebook.
# NOTE(review): presumably this mirrors the split used during training —
# confirm, otherwise these rows are not truly held out.
df = df.sample(frac=1.0, random_state=SEED)
df = df.reset_index(drop=True)

cutoff = int(len(df) * (1 - VAL_SPLIT))
test_df = df.iloc[cutoff:]
test_df = test_df.reset_index(drop=True)

print(f"Test set size: {len(test_df)}")
print(f"Test pathogenic ratio: {test_df['is_pathogenic'].mean():.2%}")

## Load Model

In [None]:
# Restore the trained multitask model. `map_location=DEVICE` remaps the stored
# tensors onto the device chosen in the configuration, so a checkpoint written
# on a GPU machine also loads when DEVICE is "cpu".
# NOTE(review): strict=False tolerates missing/unexpected state-dict keys, so a
# mismatched checkpoint loads without raising — confirm this is intended.
model = MultiTaskLightningModule.load_from_checkpoint(
    CHECKPOINT_PATH,
    map_location=DEVICE,
    strict=False,
)
model.eval()  # switch the module to evaluation mode
model.to(DEVICE)

# Tokenizer matching the encoder name recorded on the model.
# NOTE(review): trust_remote_code=True allows the tokenizer repository to run
# its own Python code on load — only use with encoder sources you trust.
tokenizer = AutoTokenizer.from_pretrained(model.encoder_name, trust_remote_code=True)

print(f"Model loaded from {CHECKPOINT_PATH}")
print(f"Encoder: {model.encoder_name}")

## Run Inference

In [None]:
# Score the evaluation set in mini-batches and collect one sigmoid probability
# per row from the "clinvar" head.
predictions_proba = []
batch_size = 8
# Progress cadence counted in batches, so it keeps working if batch_size is
# changed (25 batches == 200 rows at the default batch_size of 8).
log_every_n_batches = 25
n_total = len(test_df)

with torch.inference_mode():
    for batch_idx, i in enumerate(range(0, n_total, batch_size), start=1):
        batch_df = test_df.iloc[i : i + batch_size]
        sequences = batch_df["sequence"].tolist()

        # NOTE(review): padding="max_length" pads every batch to MAX_LENGTH
        # tokens. Dynamic padding would be cheaper, but may change the pooled
        # output if pooling is not mask-aware — confirm before changing.
        tokens = tokenizer(
            sequences,
            return_tensors="pt",
            truncation=True,
            max_length=MAX_LENGTH,
            padding="max_length",
        ).to(DEVICE)

        # Call the module itself rather than .forward() so that any registered
        # hooks run as part of the forward pass.
        pooled = model(tokens["input_ids"], tokens["attention_mask"])
        logits = model.heads["clinvar"](pooled).squeeze(-1)
        proba = torch.sigmoid(logits)

        # Cast to float32 before leaving torch so reduced-precision outputs
        # (e.g. bfloat16) convert cleanly to Python floats.
        predictions_proba.extend(proba.float().cpu().tolist())

        processed = min(i + batch_size, n_total)
        # Report periodically and always once the last batch has been scored.
        if batch_idx % log_every_n_batches == 0 or processed == n_total:
            print(f"Processed {processed} / {n_total}")

# Attach probabilities and the thresholded hard call (1 when proba >= THRESHOLD).
test_df["prediction_proba"] = predictions_proba
test_df["prediction"] = (test_df["prediction_proba"] >= THRESHOLD).astype(int)
print("\nInference complete!")

## Compute Metrics

In [None]:
# Ground-truth labels and predicted probabilities pulled from the scored frame.
y_true = test_df["is_pathogenic"].values
y_pred_proba = test_df["prediction_proba"].values

metrics = compute_classification_metrics(y_true, y_pred_proba, threshold=THRESHOLD)

print("\n=== ClinVar Classification Metrics ===")
for key, value in metrics.items():
    # Floats are shown with four decimals; anything else is printed as-is.
    rendered = f"{value:.4f}" if isinstance(value, float) else f"{value}"
    print(f"{key:20s}: {rendered}")

## ROC Curve

In [None]:
# Fetch the ROC data, unpack the entries the plot helper takes, and draw it.
roc_data = compute_roc_curve_data(y_true, y_pred_proba)
roc_fpr, roc_tpr, roc_auroc = (roc_data[name] for name in ("fpr", "tpr", "auroc"))

fig = plot_roc_curve(roc_fpr, roc_tpr, roc_auroc, title="ClinVar Pathogenicity - ROC Curve")
plt.show()

## Precision-Recall Curve

In [None]:
# Fetch the precision-recall data, unpack the plot inputs, and draw the curve.
pr_data = compute_pr_curve_data(y_true, y_pred_proba)
pr_precision, pr_recall, pr_auprc = (
    pr_data[name] for name in ("precision", "recall", "auprc")
)

fig = plot_pr_curve(pr_precision, pr_recall, pr_auprc, title="ClinVar Pathogenicity - PR Curve")
plt.show()

## Confusion Matrix

In [None]:
# Hard predictions (already thresholded) versus the true labels.
y_pred = test_df["prediction"].values

confusion_options = {
    "class_names": ["Benign", "Pathogenic"],
    "title": "ClinVar Pathogenicity - Confusion Matrix",
    "normalize": False,
}
fig = plot_confusion_matrix(y_true, y_pred, **confusion_options)
plt.show()

## Calibration Curve

In [None]:
# Calibration plot of the predicted probabilities, using 10 bins.
calibration_options = {"n_bins": 10, "title": "ClinVar Model Calibration"}
fig = plot_calibration_curve(y_true, y_pred_proba, **calibration_options)
plt.show()

## Prediction Distribution

In [None]:
# Two side-by-side histograms of the predicted probabilities.
fig, axes = plt.subplots(1, 2, figsize=(12, 4))
ax_all, ax_by_class = axes
threshold_label = f"Threshold={THRESHOLD}"

# Left panel: every prediction pooled together, with the decision threshold.
ax_all.hist(y_pred_proba, bins=50, alpha=0.7, edgecolor="black")
ax_all.axvline(x=THRESHOLD, color="r", linestyle="--", lw=2, label=threshold_label)
ax_all.set_title("Prediction Distribution")

# Right panel: the same probabilities split by the true label.
true_label = test_df["is_pathogenic"]
benign_proba = test_df.loc[true_label == 0, "prediction_proba"]
pathogenic_proba = test_df.loc[true_label == 1, "prediction_proba"]

class_panels = (
    (benign_proba, "Benign", "blue"),
    (pathogenic_proba, "Pathogenic", "red"),
)
for class_values, class_label, class_color in class_panels:
    ax_by_class.hist(class_values, bins=30, alpha=0.6, label=class_label, color=class_color)
ax_by_class.axvline(x=THRESHOLD, color="black", linestyle="--", lw=2, label=threshold_label)
ax_by_class.set_title("Prediction Distribution by True Class")

# Axis labels, legend, and grid are identical on both panels.
for panel in axes:
    panel.set_xlabel("Predicted Probability")
    panel.set_ylabel("Count")
    panel.legend()
    panel.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

## Error Analysis

In [None]:
report_columns = ["sequence", "is_pathogenic", "prediction_proba"]
truth = test_df["is_pathogenic"]
called = test_df["prediction"]

# False positives: label 0 but predicted 1; list the five highest probabilities.
false_positives = test_df[(truth == 0) & (called == 1)]
print(f"\n=== False Positives (n={len(false_positives)}) ===")
print(false_positives.nlargest(5, "prediction_proba")[report_columns])

# False negatives: label 1 but predicted 0; list the five lowest probabilities.
false_negatives = test_df[(truth == 1) & (called == 0)]
print(f"\n=== False Negatives (n={len(false_negatives)}) ===")
print(false_negatives.nsmallest(5, "prediction_proba")[report_columns])

## Threshold Analysis

In [None]:
# Sweep the decision threshold from 0.1 in steps of 0.05 (0.9 excluded) and
# record the selected metrics at each value.
thresholds = np.arange(0.1, 0.9, 0.05)
tracked_metrics = ("accuracy", "f1", "precision", "sensitivity")

threshold_metrics = []
for thresh in thresholds:
    metrics_t = compute_classification_metrics(y_true, y_pred_proba, threshold=thresh)
    row = {"threshold": thresh}
    row.update((name, metrics_t[name]) for name in tracked_metrics)
    threshold_metrics.append(row)

threshold_df = pd.DataFrame(threshold_metrics)

# One line per metric, each with its own marker, over the swept thresholds.
fig, ax = plt.subplots(figsize=(10, 6))
line_specs = (
    ("accuracy", "o", "Accuracy"),
    ("f1", "s", "F1"),
    ("precision", "^", "Precision"),
    ("sensitivity", "v", "Sensitivity"),
)
for column, marker, label in line_specs:
    ax.plot(threshold_df["threshold"], threshold_df[column], marker=marker, label=label)

ax.set_xlabel("Threshold")
ax.set_ylabel("Score")
ax.set_title("Metrics vs Classification Threshold")
ax.legend()
ax.grid(True, alpha=0.3)
plt.show()

## Save Results

In [None]:
# Both artifacts go in the same results directory; create it if needed.
results_dir = Path("../results")
results_dir.mkdir(parents=True, exist_ok=True)

# Per-variant predictions (input columns plus probability and hard call).
output_path = results_dir / "clinvar_predictions.csv"
test_df.to_csv(output_path, index=False)
print(f"Predictions saved to {output_path}")

# Summary metrics written as a single-row table.
metrics_path = results_dir / "clinvar_metrics.csv"
metrics_df = pd.DataFrame([metrics])
metrics_df.to_csv(metrics_path, index=False)
print(f"Metrics saved to {metrics_path}")