# Deep Mutational Scanning (DMS) Evaluation

This notebook evaluates the multitask CRISPR design model on held-out DMS test data.

## Objectives
1. Load trained model checkpoint
2. Evaluate on held-out test set
3. Compute regression metrics (MAE, RMSE, R², Pearson/Spearman correlation)
4. Visualize predictions vs ground truth
5. Analyze error patterns and outliers

In [None]:
# Notebook setup: standard and third-party imports, make the project's `src/`
# directory importable, then import the project's evaluation/training helpers.
import sys
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import torch
from transformers import AutoTokenizer

# Prepend <cwd>/../src to the import path. This must run before the
# crispr_design_agent imports below, otherwise they cannot be resolved.
sys.path.insert(0, str(Path.cwd().parent / "src"))

from crispr_design_agent.evaluation.metrics import (
    compute_regression_metrics,
    stratified_evaluation,
)
from crispr_design_agent.evaluation.visualization import (
    plot_regression_results,
)
from crispr_design_agent.training.module import MultiTaskLightningModule

# Global plot styling applied to every figure produced later in the notebook.
sns.set_style("whitegrid")
plt.rcParams["figure.dpi"] = 100

## Configuration

In [None]:
# Inputs to evaluate: the trained checkpoint and the processed DMS table.
# Both paths are resolved relative to the current working directory.
CHECKPOINT_PATH = "../models/checkpoints/multitask-epoch=10.ckpt"
DATA_PATH = "../data/processed/dms.parquet"

# Held-out fraction and shuffle seed used to carve out the evaluation split.
# NOTE(review): these presumably need to match the training run's split
# settings for the evaluation rows to be genuinely unseen -- confirm.
VAL_SPLIT = 0.1
SEED = 42

# Tokenizer budget: sequences are truncated and padded to this many tokens.
MAX_LENGTH = 1024

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

print(f"Using device: {DEVICE}")

## Load Data

In [None]:
# Read the processed DMS table and print a quick overview of what was loaded.
df = pd.read_parquet(DATA_PATH)

row_count = df.shape[0]
column_names = df.columns.tolist()

print(f"Loaded {row_count} DMS measurements")
print(f"\nDataset shape: {df.shape}")
print(f"\nColumns: {column_names}")

# Last expression in the cell: the notebook renders the first rows.
df.head()

In [None]:
# Deterministically shuffle, then treat everything from `cutoff` onward (the
# trailing VAL_SPLIT fraction) as the held-out evaluation set. The leading
# portion is not used in this notebook.
df = df.sample(frac=1.0, random_state=SEED).reset_index(drop=True)
cutoff = int(len(df) * (1 - VAL_SPLIT))

held_out = df.iloc[cutoff:]
test_df = held_out.reset_index(drop=True)

print(f"Test set size: {len(test_df)}")
print("\nTarget statistics:")
print(test_df["effect"].describe())

## Load Model

In [None]:
# Restore the multitask module from its Lightning checkpoint and switch it to
# eval mode on the selected device.
# NOTE(review): strict=False tolerates missing/unexpected state-dict keys, so a
# checkpoint that does not match the module will not raise here -- consider
# strict loading to catch that.
model = MultiTaskLightningModule.load_from_checkpoint(CHECKPOINT_PATH, strict=False)
model = model.eval().to(DEVICE)

# The tokenizer is looked up by the encoder name stored on the module.
# trust_remote_code=True lets Hugging Face execute code shipped with that
# repository; only use encoders from sources you trust.
encoder_name = model.encoder_name
tokenizer = AutoTokenizer.from_pretrained(encoder_name, trust_remote_code=True)

print(f"Model loaded from {CHECKPOINT_PATH}")
print(f"Encoder: {encoder_name}")

## Run Inference

In [None]:
# Batched inference over the held-out split. Each batch is tokenized, encoded
# into a pooled representation by the model, and scored by its "dms" head.
predictions = []
batch_size = 8
report_every = 100  # progress-print interval, measured in sequences
n_total = len(test_df)

with torch.inference_mode():
    for start in range(0, n_total, batch_size):
        batch_df = test_df.iloc[start : start + batch_size]
        sequences = batch_df["sequence"].tolist()

        # NOTE(review): padding="max_length" pads every batch to MAX_LENGTH
        # tokens. padding="longest" would be cheaper if the model's pooling
        # honours attention_mask -- confirm before changing.
        tokens = tokenizer(
            sequences,
            return_tensors="pt",
            truncation=True,
            max_length=MAX_LENGTH,
            padding="max_length",
        ).to(DEVICE)

        # Call the module instance rather than .forward() so registered hooks run.
        pooled = model(tokens["input_ids"], tokens["attention_mask"])

        # Expect one scalar per sequence. reshape(-1), unlike squeeze(-1), never
        # produces a 0-d tensor for a single-row final batch, which would make
        # .tolist() return a bare float and break predictions.extend().
        scores = model.heads["dms"](pooled).reshape(-1)

        # Cast to float32 before .numpy(): NumPy has no bfloat16 dtype.
        predictions.extend(scores.float().cpu().numpy().tolist())

        # Print whenever a multiple of report_every is crossed, independent of
        # batch_size, and always after the final batch.
        done = min(start + batch_size, n_total)
        if done // report_every > start // report_every or done == n_total:
            print(f"Processed {done} / {n_total}")

test_df["prediction"] = predictions
print("\nInference complete!")

## Compute Metrics

In [None]:
# Overall regression metrics on the held-out split: ground-truth effects
# versus model predictions.
effect_column = test_df["effect"]
prediction_column = test_df["prediction"]
y_true = effect_column.values
y_pred = prediction_column.values

metrics = compute_regression_metrics(y_true, y_pred)

print("\n=== DMS Regression Metrics ===")
for metric_name, metric_value in metrics.items():
    print(f"{metric_name:20s}: {metric_value:.4f}")

## Visualize Results

In [None]:
# Prediction-vs-truth diagnostics drawn by the project's plotting helper.
plot_options = dict(title="DMS Effect Prediction", figsize=(15, 4))
fig = plot_regression_results(y_true, y_pred, **plot_options)
plt.show()

## Error Analysis

In [None]:
# Per-row absolute error, then the ten rows with the largest error.
residual = test_df["effect"] - test_df["prediction"]
test_df["abs_error"] = residual.abs()

worst_predictions = test_df.nlargest(10, "abs_error")
columns_to_show = ["sequence", "effect", "prediction", "abs_error"]

print("\n=== Top 10 Worst Predictions ===")
print(worst_predictions[columns_to_show])

In [None]:
# Scatter absolute error against input length. Length is the raw character
# count of the sequence string, not its tokenized length.
test_df["seq_length"] = test_df["sequence"].str.len()

fig, ax = plt.subplots(figsize=(10, 6))
ax.scatter(test_df["seq_length"], test_df["abs_error"], alpha=0.5, s=10)
ax.set(
    xlabel="Sequence Length",
    ylabel="Absolute Error",
    title="Prediction Error vs Sequence Length",
)
ax.grid(True, alpha=0.3)
plt.show()

## Stratified Evaluation

In [None]:
# Per-protein breakdown, computed only when the table carries a protein_id column.
# The inner loop variable is deliberately NOT named `metrics`. Rebinding that
# global here would make the "Save Results" cell write the last protein's
# metrics instead of the overall ones.
if "protein_id" in test_df.columns:
    stratified_metrics = stratified_evaluation(
        y_true,
        y_pred,
        test_df["protein_id"].values,
        problem_type="regression",
    )

    print("\n=== Per-Protein Metrics ===")
    for protein_id, protein_metrics in stratified_metrics.items():
        print(f"\n{protein_id}:")
        for key, value in protein_metrics.items():
            print(f"  {key}: {value:.4f}")
else:
    print("No protein_id column found for stratified evaluation")

## Save Results

In [None]:
# Persist the per-row predictions (including the error and length columns added
# during error analysis) and the overall regression metrics.
results_dir = Path("../results")
results_dir.mkdir(parents=True, exist_ok=True)

output_path = results_dir / "dms_predictions.csv"
test_df.to_csv(output_path, index=False)
print(f"Predictions saved to {output_path}")

# Recompute the overall metrics from test_df rather than reusing the global
# `metrics`. Earlier cells may rebind that name (e.g. as a loop variable), which
# would silently save the wrong numbers.
overall_metrics = compute_regression_metrics(
    test_df["effect"].values,
    test_df["prediction"].values,
)
metrics_df = pd.DataFrame([overall_metrics])
metrics_path = results_dir / "dms_metrics.csv"
metrics_df.to_csv(metrics_path, index=False)
print(f"Metrics saved to {metrics_path}")