In [None]:
import os
import re
import json
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

# -----------------------
# Configuration
# -----------------------
# Raw inputs: five training fold files plus one held-out test file, all
# located under DATA_DIR.
DATA_DIR = "datasets"
TEST_FILE = "test.csv"
TRAIN_FOLDS = [f"fold_{fold_idx}.csv" for fold_idx in range(5)]

TRAIN_PATHS = [os.path.join(DATA_DIR, fold_name) for fold_name in TRAIN_FOLDS]
TEST_PATH = os.path.join(DATA_DIR, TEST_FILE)

# Destination for every artefact of this step (cleaned CSVs, metadata.json,
# figures). Created up front so later writes never hit a missing directory.
CLEAN_DIR = "clean"
os.makedirs(CLEAN_DIR, exist_ok=True)

# The 20 standard amino acids. The list order is saved to metadata.json and
# defines the one-hot column layout used downstream; the set is for fast
# membership checks.
AA_VOCAB = list("ACDEFGHIKLMNPQRSTVWY")
AA_SET = set(AA_VOCAB)

# -----------------------
# Allele name correction utilities
# -----------------------
def fix_allele_format(allele: str) -> str:
    """Coerce an allele label toward the canonical form, e.g. ``HLA-A*02:101``.

    The input is converted with ``str()``, stripped and upper-cased. If it
    then already matches ``HLA-<letter>*<2 digits>:<2-3 digits>`` it is
    returned as-is. Otherwise these best-effort repairs run in order:

    1. remove spaces and underscores, and collapse ``--`` into ``-``;
    2. ensure the label begins with ``HLA-`` (add it, or insert the dash
       after a bare ``HLA`` prefix);
    3. if there is no ``*``, insert one between the locus letter and the
       digits that follow it;
    4. if the label ends with ``*`` + 4-5 digits, split them as
       ``*NN:NN[N]``.

    The repaired label is not re-validated here; callers should check it.
    """
    label = str(allele).strip().upper()
    if re.match(r"^HLA-[A-Z]\*\d{2}:\d{2,3}$", label):
        return label

    # 1. Strip separators that never belong in a canonical name.
    for junk, replacement in ((" ", ""), ("_", ""), ("--", "-")):
        label = label.replace(junk, replacement)

    # 2. Normalise the "HLA-" prefix.
    if not label.startswith("HLA-"):
        label = label.replace("HLA", "HLA-", 1) if label.startswith("HLA") else "HLA-" + label

    # 3. Add the missing "*" after the locus letter.
    if "*" not in label:
        label = re.sub(r"^(HLA-[A-Z])(\d+)", r"\1*\2", label)

    # 4. Split a trailing "*0201"-style run into "*02:01".
    return re.sub(r"(\*\d{2})(\d{2,3})$", r"\1:\2", label)


def clean_allele_column(df: pd.DataFrame, name: str) -> pd.DataFrame:
    """Normalise allele names in ``df["allele"]`` and print a short report.

    Each non-missing entry is passed through ``fix_allele_format``. Missing
    entries (NaN/None) are left as they are. ``fix_allele_format``
    stringifies its input, so a missing value would otherwise become the
    bogus label ``"HLA-NAN"``, which a later ``dropna`` could not catch.

    Afterwards, the non-missing entries are re-checked against the canonical
    ``HLA-<letter>*<2 digits>:<2-3 digits>`` pattern, and the number still
    failing it is reported.

    Args:
        df: Frame with an ``allele`` column. The column is overwritten in place.
        name: Prefix used in the log messages.

    Returns:
        The same ``df`` object, with the corrected ``allele`` column.
    """
    print(f"[{name}] Checking and fixing allele formats...")

    valid_pattern = re.compile(r"^HLA-[A-Z]\*\d{2}:\d{2,3}$")

    original = df["allele"].tolist()
    corrected_alleles = [
        fix_allele_format(allele) if pd.notna(allele) else allele
        for allele in original
    ]
    # Only count real changes; missing entries are passed through untouched.
    n_corrected = sum(
        1
        for before, after in zip(original, corrected_alleles)
        if pd.notna(before) and before != after
    )
    corrections_made = n_corrected > 0

    if corrections_made:
        print(f"[DEBUG] Correcting allele naming... ({n_corrected} entries changed)")

    df["allele"] = corrected_alleles

    missing_count = sum(1 for allele in corrected_alleles if pd.isna(allele))
    invalid_count = sum(
        1
        for allele in corrected_alleles
        if pd.notna(allele) and not valid_pattern.match(str(allele))
    )

    if missing_count > 0:
        print(f"[{name}] {missing_count} allele entries are missing and were left unchanged.")
    if invalid_count > 0:
        print(f"[{name}] WARNING: {invalid_count} allele entries still invalid after correction.")
    elif corrections_made:
        print(f"[{name}] All allele names corrected successfully.")
    else:
        print(f"[{name}] No mismatch in allele found.")

    return df


# -----------------------
# Utils
# -----------------------
def load_csv_safe(path: str) -> pd.DataFrame:
    """Read the CSV at *path* and check that it has the pipeline's columns.

    Raises:
        FileNotFoundError: if *path* is not an existing regular file.
        ValueError: if any of ``peptide``, ``allele`` or ``hit`` is absent.
    """
    if os.path.isfile(path):
        frame = pd.read_csv(path)
    else:
        raise FileNotFoundError(f"File not found: {path}")

    required = {"peptide", "allele", "hit"}
    missing = required - set(frame.columns)
    if not missing:
        return frame
    raise ValueError(f"{path} missing columns: {missing}")

def dataset_stats(df: pd.DataFrame, name: str):
    """Print a human-readable summary of ``df`` under the heading *name*.

    The summary covers the row count, the column list, and the number of
    distinct alleles. It also includes the ``hit`` value counts (NaN
    included) when that column exists, and ``describe()`` statistics of
    peptide string lengths. Returns None.
    """
    print(f"\n=== Dataset Stats: {name} ===")
    print(f"Rows: {len(df):,}")
    print(f"Columns: {list(df.columns)}")
    print(f"Unique alleles: {df['allele'].nunique()}")

    if "hit" in df.columns:
        print("Label distribution (hit):")
        label_counts = df["hit"].value_counts(dropna=False)
        print(label_counts)

    print("Peptide length stats:")
    peptide_lengths = df["peptide"].astype(str).str.len()
    print(peptide_lengths.describe())

def remove_exact_duplicates(df: pd.DataFrame, name: str) -> pd.DataFrame:
    """Drop rows that are identical in every column, keeping the first copy.

    Prints how many rows were removed and returns the de-duplicated frame;
    the input frame is not modified.
    """
    deduped = df.drop_duplicates(keep="first")
    n_dropped = len(df) - len(deduped)
    print(f"[{name}] Removed exact duplicate rows: {n_dropped}")
    return deduped

def remove_conflicting_duplicates(df: pd.DataFrame, name: str) -> pd.DataFrame:
    """Drop every row whose (peptide, allele) pair carries more than one ``hit`` value.

    Such pairs have contradictory labels, so all of their rows are removed
    rather than picking a winner. When no conflict exists, ``df`` itself is
    returned; otherwise a filtered copy is returned.
    """
    labels_per_pair = df.groupby(["peptide", "allele"])["hit"].nunique()
    conflicting_pairs = labels_per_pair[labels_per_pair > 1]
    if len(conflicting_pairs) == 0:
        print(f"[{name}] Conflicting duplicates: none.")
        return df

    row_keys = pd.MultiIndex.from_frame(df[["peptide", "allele"]])
    is_conflicting = row_keys.isin(conflicting_pairs.index)
    kept = df.loc[~is_conflicting].copy()
    n_removed = len(df) - len(kept)
    print(f"[{name}] Conflicting (peptide, allele) pairs: {len(conflicting_pairs)} | Rows removed: {n_removed}")
    return kept

def handle_missing(df: pd.DataFrame, name: str) -> pd.DataFrame:
    """Report NaN counts per column and drop rows lacking peptide, allele or hit.

    NaNs in other columns are reported but tolerated. If the frame has no
    missing values at all, the input object itself is returned unchanged;
    otherwise a filtered copy is returned.
    """
    na_per_column = df.isna().sum()
    if not na_per_column.any():
        print(f"[{name}] Missing values: none.")
        return df

    print(f"[{name}] Missing values per column:\n{na_per_column}")
    required = ["peptide", "allele", "hit"]
    complete = df.dropna(subset=required).copy()
    n_dropped = len(df) - len(complete)
    print(f"[{name}] Rows dropped due to missing peptide/allele/hit: {n_dropped}")
    return complete

def is_valid_peptide(seq: str) -> bool:
    """Return True iff *seq*, stripped and upper-cased, is non-empty and uses
    only letters from ``AA_SET`` (the 20 standard amino acids)."""
    residues = str(seq).strip().upper()
    return bool(residues) and all(residue in AA_SET for residue in residues)

def clean_invalid_peptides(df: pd.DataFrame, name: str) -> pd.DataFrame:
    """Keep only rows whose peptide passes ``is_valid_peptide``.

    Prints the number of rejected rows, followed by ``describe()`` summaries
    of peptide string lengths before and after filtering. Returns a filtered
    copy of ``df``.
    """
    peptides = df["peptide"].astype(str)
    stats_before = peptides.str.len().describe()
    keep = peptides.str.upper().apply(is_valid_peptide)
    n_rejected = (~keep).sum()
    kept = df[keep].copy()
    stats_after = kept["peptide"].astype(str).str.len().describe()

    print(f"[{name}] Non-standard/invalid peptide rows removed: {n_rejected}")
    print(f"[{name}] Sequence length stats BEFORE removal:\n{stats_before}")
    print(f"[{name}] Sequence length stats AFTER removal:\n{stats_after}")
    return kept

def filter_test_alleles_in_train(train: pd.DataFrame, test: pd.DataFrame) -> pd.DataFrame:
    """Drop test rows whose allele never appears in the training data.

    No per-allele model can exist for such alleles. If every test allele is
    known, ``test`` itself is returned. Otherwise the unknown alleles are
    printed in sorted order and a filtered copy is returned.
    """
    known = set(train["allele"].unique())
    unseen = sorted(set(test["allele"].unique()).difference(known))
    if not unseen:
        print("[TEST] All test alleles exist in training.")
        return test

    kept = test[test["allele"].isin(known)].copy()
    print(f"[TEST] Alleles not present in TRAIN: {unseen}")
    print(f"[TEST] Removed rows with unknown alleles: {len(test) - len(kept)}")
    return kept

def compute_max_len(train_df: pd.DataFrame) -> int:
    """Return the length of the longest peptide in ``train_df``, and log it.

    This is meant for the cleaned training split, so the padding length
    comes from training data only. It expects at least one row.
    """
    lengths = train_df["peptide"].astype(str).str.len()
    longest = int(lengths.max())
    print(f"[FE] Max sequence length (from CLEANED TRAIN): {longest}")
    return longest

# -----------------------
# Visualization helpers
# -----------------------
def plot_class_distribution(df, out_prefix, top_n_alleles=15):
    """Save two bar charts of the ``hit`` label distribution into ``CLEAN_DIR``.

    * ``<out_prefix>_class_distribution.png``: row count per class.
    * ``<out_prefix>_class_distribution_per_allele.png``: per-class counts
      for the ``top_n_alleles`` most frequent alleles.
    """
    overall_png = os.path.join(CLEAN_DIR, f"{out_prefix}_class_distribution.png")
    per_allele_png = os.path.join(CLEAN_DIR, f"{out_prefix}_class_distribution_per_allele.png")

    # Figure 1: global class balance.
    plt.figure(figsize=(4, 4))
    sns.countplot(x="hit", data=df)
    plt.title("General Class Distribution (hit)")
    plt.xlabel("Class (hit)")
    plt.ylabel("Count")
    plt.tight_layout()
    plt.savefig(overall_png)
    plt.close()

    # Figure 2: class balance restricted to the most frequent alleles.
    plt.figure(figsize=(12, 6))
    counts = df.groupby(["allele", "hit"]).size().reset_index(name="count")
    frequent = df["allele"].value_counts().head(top_n_alleles).index
    sns.barplot(x="allele", y="count", hue="hit", data=counts[counts["allele"].isin(frequent)])
    plt.title(f"Class Distribution per Allele (top {top_n_alleles})")
    plt.xticks(rotation=45, ha="right")
    plt.tight_layout()
    plt.savefig(per_allele_png)
    plt.close()

def plot_length_distribution(df, out_prefix, top_n_alleles=15):
    """Save peptide-length plots for ``df`` into ``CLEAN_DIR``.

    * ``<out_prefix>_length_distribution.png``: histogram and KDE of all
      peptide lengths.
    * ``<out_prefix>_length_distribution_per_allele.png``: box plot of
      lengths for the ``top_n_alleles`` most frequent alleles.

    ``df`` is not modified.
    """
    # Build a small private frame. Assigning a "length" column onto ``df``
    # would mutate the caller's data and leak that column into anything the
    # caller saves afterwards (e.g. the cleaned training CSV).
    plot_df = pd.DataFrame({
        "allele": df["allele"],
        "length": df["peptide"].astype(str).str.len(),
    })

    plt.figure(figsize=(5, 4))
    sns.histplot(plot_df["length"], bins=20, kde=True, color="steelblue")
    plt.title("General Peptide Length Distribution")
    plt.xlabel("Peptide length")
    plt.ylabel("Count")
    plt.tight_layout()
    plt.savefig(os.path.join(CLEAN_DIR, f"{out_prefix}_length_distribution.png"))
    plt.close()

    top_alleles = plot_df["allele"].value_counts().head(top_n_alleles).index
    plt.figure(figsize=(12, 6))
    sns.boxplot(
        x="allele", y="length",
        data=plot_df[plot_df["allele"].isin(top_alleles)],
        showfliers=False
    )
    plt.title(f"Peptide Length Distribution per Allele (top {top_n_alleles})")
    plt.xticks(rotation=45, ha="right")
    plt.tight_layout()
    plt.savefig(os.path.join(CLEAN_DIR, f"{out_prefix}_length_distribution_per_allele.png"))
    plt.close()

# -----------------------
# Main flow
# -----------------------
def _load_and_normalize(path: str, name: str) -> pd.DataFrame:
    """Load one raw CSV and prepare it for cleaning.

    Steps:
    1. ``load_csv_safe``: the file exists and has the required columns.
    2. ``handle_missing``: drop rows missing peptide, allele or hit.
    3. Upper-case and strip peptides; strip alleles.
    4. ``clean_allele_column``: repair allele names.

    Missing rows must be dropped *before* the ``astype(str)`` casts. Casting
    turns NaN into the string ``"nan"``, and upper-cased ``"NAN"`` consists
    only of standard residues (N, A). A missing peptide would therefore
    survive every later filter as a real 3-mer, and a missing allele would
    become ``"HLA-NAN"``.
    """
    df = load_csv_safe(path)
    df = handle_missing(df, name)
    df["peptide"] = df["peptide"].astype(str).str.upper().str.strip()
    df["allele"] = df["allele"].astype(str).str.strip()
    return clean_allele_column(df, name)


train_raw = pd.concat(
    [_load_and_normalize(p, f"TRAIN file {os.path.basename(p)}") for p in TRAIN_PATHS],
    ignore_index=True,
)
test_raw = _load_and_normalize(TEST_PATH, "TEST")

# Stats ("raw" = as loaded, minus rows lacking peptide/allele/hit)
dataset_stats(train_raw, "TRAIN (raw)")
dataset_stats(test_raw,  "TEST (raw)")

# Clean. Missing-value handling already ran at load time (see _load_and_normalize).
train = remove_exact_duplicates(train_raw, "TRAIN")
test  = remove_exact_duplicates(test_raw,  "TEST")

train = remove_conflicting_duplicates(train, "TRAIN")
if "hit" in test.columns:
    test = remove_conflicting_duplicates(test, "TEST")

train = clean_invalid_peptides(train, "TRAIN")
test  = clean_invalid_peptides(test,  "TEST")

# Models are trained per allele, so test alleles without training data cannot be scored.
test  = filter_test_alleles_in_train(train, test)

# Stats after cleaning
dataset_stats(train, "TRAIN (cleaned)")
dataset_stats(test,  "TEST (cleaned)")

# Visualization
print("\n[INFO] Generating distribution plots...")
plot_class_distribution(train, out_prefix="train")
plot_length_distribution(train, out_prefix="train")
print(f"[INFO] Figures saved to {CLEAN_DIR}")

# Compute max length
max_seq_len = compute_max_len(train)

# Save cleaned data & metadata
train_out = os.path.join(CLEAN_DIR, "train_clean.csv")
test_out  = os.path.join(CLEAN_DIR, "test_clean.csv")
meta_out  = os.path.join(CLEAN_DIR, "metadata.json")

train.to_csv(train_out, index=False)
test.to_csv(test_out, index=False)

metadata = {
    "aa_vocab": AA_VOCAB,
    "max_seq_len": max_seq_len,
    "train_alleles": sorted(train["allele"].unique())
}
with open(meta_out, "w") as f:
    json.dump(metadata, f, indent=2)

print(f"\nSaved cleaned TRAIN to: {train_out}")
print(f"Saved cleaned TEST  to: {test_out}")
print(f"Saved metadata to: {meta_out}")
print(f"Saved figures to: {CLEAN_DIR}")


In [None]:
# Notebook 2 — Training (Full Dataset with Batch Processing)

import os
import json
import numpy as np
import pandas as pd
from sklearn.linear_model import SGDClassifier
from sklearn.dummy import DummyClassifier
from sklearn.utils.class_weight import compute_class_weight
import joblib
import gc

# -----------------------
# Configuration
# -----------------------
DATA_DIR = "."
CLEAN_DIR = os.path.join(DATA_DIR, "clean")   # written by the cleaning notebook
MODEL_DIR = os.path.join(DATA_DIR, "models")  # per-allele models + manifest.json go here
os.makedirs(MODEL_DIR, exist_ok=True)

TRAIN_CLEAN = os.path.join(CLEAN_DIR, "train_clean.csv")
META_PATH = os.path.join(CLEAN_DIR, "metadata.json")

# Training hyper-parameters.
RANDOM_STATE = 53
CHUNK_SIZE = 50_000  # rows one-hot encoded and fed to the model per step (per allele)
EPOCHS = 10          # passes over each allele's full set of rows
MAX_ITER = 5         # forwarded as SGDClassifier(max_iter=...); scikit-learn documents
                     # that it affects fit() but not partial_fit()

# -----------------------
# Load data & metadata
# -----------------------
print("[INFO] Loading cleaned training data...")
train = pd.read_csv(TRAIN_CLEAN)
with open(META_PATH, "r") as f:
    meta = json.load(f)

# Amino-acid alphabet saved by the cleaning step; its order fixes the
# one-hot column layout (column = position * width + AA_TO_IDX[residue]).
AA_VOCAB = meta["aa_vocab"]
AA_TO_IDX = dict(zip(AA_VOCAB, range(len(AA_VOCAB))))

print(f"[INFO] Loaded {len(train):,} rows | {train['allele'].nunique()} alleles")

# -----------------------
# Encoder
# -----------------------
def one_hot_encode_padded(seqs, max_len, aa_to_idx):
    """Flatten each peptide into a fixed-width, position-major one-hot float32 row.

    The result has shape ``(len(seqs), max_len * len(aa_to_idx))``. Column
    ``pos * width + aa_to_idx[residue]`` is 1.0 when that residue is at
    ``pos``. Sequences are stripped and upper-cased, truncated after
    ``max_len`` residues, and zero-padded when shorter. Residues missing
    from ``aa_to_idx`` leave their position all-zero.
    """
    width = len(aa_to_idx)
    encoded = np.zeros((len(seqs), max_len * width), dtype=np.float32)
    for row, raw in enumerate(seqs):
        residues = str(raw).strip().upper()[:max_len]
        for pos, residue in enumerate(residues):
            col = aa_to_idx.get(residue)
            if col is None:
                continue
            encoded[row, pos * width + col] = 1.0
    return encoded

# Compute per-allele max sequence length: the longest training peptide seen
# for each allele becomes that allele's one-hot input length (in residues).
allele_maxlen = {
    allele_name: int(longest)
    for allele_name, longest in (
        train["peptide"].astype(str).str.len().groupby(train["allele"]).max().items()
    )
}
print(f"[INFO] Computed max sequence lengths for {len(allele_maxlen)} alleles.")

# -----------------------
# Training per allele (chunked incremental mode)
# -----------------------
manifest = {
    "aa_vocab": AA_VOCAB,
    "allele_max_len": allele_maxlen,
    "models": []
}

# Binary label set shared by every allele model. partial_fit needs it on the
# first call because an individual chunk may contain only one class.
classes = np.array([0, 1])


def _train_single_class(y_full, max_len):
    """Fit a constant DummyClassifier for an allele whose labels are all identical.

    It is fit on a single all-zero row of width ``max_len * len(AA_TO_IDX)``,
    so its ``n_features_in_`` matches the one-hot width used for this allele.
    """
    clf = DummyClassifier(strategy="most_frequent", random_state=RANDOM_STATE)
    clf.fit(np.zeros((1, max_len * len(AA_TO_IDX))), [y_full[0]])
    print("[INFO] Only one class present → DummyClassifier.")
    return clf


def _train_sgd_chunked(tr_full, y_full, max_len):
    """Train a class-balanced, logistic-loss SGDClassifier on one allele in chunks.

    Each epoch visits the rows in a fresh random order, CHUNK_SIZE rows at a
    time, so only one chunk is one-hot encoded in memory at once. Every chunk
    goes through ``partial_fit`` and therefore *updates* the same model.
    The previous code called ``fit`` per chunk, which re-initialises the
    coefficients each time, so the saved model reflected only the final
    chunk of the final epoch.

    Args:
        tr_full: This allele's training rows (needs a ``peptide`` column).
        y_full: Integer labels aligned with ``tr_full`` rows; must be 0/1.
        max_len: One-hot length (in residues) for this allele.

    Raises:
        ValueError: if ``y_full`` contains labels other than 0 and 1.
    """
    if not np.isin(y_full, classes).all():
        raise ValueError(f"hit labels must be 0/1, got {np.unique(y_full).tolist()}")

    weights = compute_class_weight("balanced", classes=classes, y=y_full)
    class_weight_dict = {int(c): float(w) for c, w in zip(classes, weights)}
    print(f"[INFO] Computed class weights: {class_weight_dict}")

    clf = SGDClassifier(
        loss="log_loss",
        penalty="l2",
        random_state=RANDOM_STATE,
        max_iter=MAX_ITER,        # used by fit() only; partial_fit() does one pass per call
        validation_fraction=0.1,  # unused: early stopping is off
        learning_rate="optimal",
        tol=1e-3,                 # used by fit() only
        n_jobs=-1,
        early_stopping=False,
    )

    n_samples = len(tr_full)
    peptides = tr_full["peptide"].to_numpy()
    rng = np.random.default_rng(RANDOM_STATE)

    for epoch in range(EPOCHS):
        print(f"  [Epoch {epoch+1}/{EPOCHS}] ----------------------")
        # Reshuffle row positions every epoch; no DataFrame copies needed.
        order = rng.permutation(n_samples)

        for start in range(0, n_samples, CHUNK_SIZE):
            end = min(start + CHUNK_SIZE, n_samples)
            batch_pos = order[start:end]
            X_batch = one_hot_encode_padded(peptides[batch_pos].tolist(), max_len, AA_TO_IDX)
            y_batch = y_full[batch_pos]
            # Labels were validated as 0/1 above, so each label is its own
            # index into `weights` (which is ordered like `classes`).
            sample_weight = weights[y_batch]

            # partial_fit continues from the current coefficients.
            clf.partial_fit(X_batch, y_batch, classes=classes, sample_weight=sample_weight)

            del X_batch, y_batch, sample_weight
            print(f"    [Batch {start//CHUNK_SIZE+1}] trained on rows {start:,}-{end:,}")

    return clf


allele_list = sorted(train["allele"].unique())
# Row positions for every allele in one groupby pass, instead of a boolean
# scan of the whole table for each allele.
allele_rows = train.groupby("allele").indices

for idx, allele in enumerate(allele_list, 1):
    positions = allele_rows.get(allele)
    if positions is None or len(positions) == 0:
        continue
    tr_full = train.iloc[positions]

    n_samples = len(tr_full)
    max_len = allele_maxlen[allele]
    print(f"\n[{idx}/{len(allele_list)}] Training allele={allele} | n={n_samples:,} | max_len={max_len}")

    y_full = tr_full["hit"].astype(int).to_numpy()

    # Single-class alleles get a constant model; all others are trained with SGD.
    if len(np.unique(y_full)) < 2:
        clf = _train_single_class(y_full, max_len)
    else:
        clf = _train_sgd_chunked(tr_full, y_full, max_len)

    # --- Save model ---
    safe_allele = allele.replace("*", "").replace(":", "_").replace("/", "-")
    model_path = os.path.join(MODEL_DIR, f"model_{safe_allele}.joblib")
    joblib.dump(clf, model_path)

    manifest["models"].append({
        "allele": allele,
        "path": model_path,
        "n_train": int(n_samples),
        "max_len": int(max_len),
        "model_type": type(clf).__name__,
        "chunk_size": CHUNK_SIZE,
        "epochs": EPOCHS
    })

    print(f"[DONE] Saved model for {allele} -> {model_path}")
    del tr_full, y_full
    gc.collect()

# -----------------------
# Save manifest
# -----------------------
# manifest.json is what the evaluation notebook reads to find each allele's
# model file and one-hot length; it also records the AA vocabulary.
manifest_path = os.path.join(MODEL_DIR, "manifest.json")

with open(manifest_path, "w") as f:
    json.dump(manifest, f, indent=2)

print(f"\n✅ Saved manifest to {manifest_path}")
print("[INFO] All models trained successfully.")


In [None]:
# Notebook 3 — Evaluation & Inference

import os
import json
import numpy as np
import pandas as pd
import joblib
from sklearn.metrics import accuracy_score, roc_auc_score, f1_score

# -----------------------
# Configuration
# -----------------------
DATA_DIR   = "."
CLEAN_DIR  = os.path.join(DATA_DIR, "clean")    # cleaned CSVs + metadata.json
MODEL_DIR  = os.path.join(DATA_DIR, "models")   # per-allele models + manifest.json
OUT_DIR    = os.path.join(DATA_DIR, "outputs")  # predictions and metrics are written here
os.makedirs(OUT_DIR, exist_ok=True)

TEST_CLEAN, META_PATH, MANIFEST = (
    os.path.join(CLEAN_DIR, "test_clean.csv"),
    os.path.join(CLEAN_DIR, "metadata.json"),
    os.path.join(MODEL_DIR, "manifest.json"),
)

# -----------------------
# Load data
# -----------------------
test = pd.read_csv(TEST_CLEAN)
with open(META_PATH, "r") as f:
    meta = json.load(f)
with open(MANIFEST, "r") as f:
    manifest = json.load(f)

AA_VOCAB  = meta["aa_vocab"]
AA_TO_IDX = dict(zip(AA_VOCAB, range(len(AA_VOCAB))))

# Per-allele one-hot lengths written by the training notebook (empty if absent),
# and the model file for each trained allele.
allele_maxlen = manifest.get("allele_max_len", {})
allele_to_model = {entry["allele"]: entry["path"] for entry in manifest["models"]}

print(f"Loaded test rows: {len(test):,}")
print(f"Alleles in test: {test['allele'].nunique()}")
print(f"Models available for: {len(allele_to_model)} alleles")

# -----------------------
# Encoder
# -----------------------
def one_hot_encode_padded(seqs, max_len, aa_to_idx):
    """Encode peptides as flat, position-major one-hot float32 rows.

    The output shape is ``(len(seqs), max_len * len(aa_to_idx))``. Each
    sequence is stripped and upper-cased, and only its first ``max_len``
    residues are encoded. Shorter sequences stay zero-padded, and characters
    that are not keys of ``aa_to_idx`` contribute nothing.
    """
    n_symbols = len(aa_to_idx)
    out = np.zeros((len(seqs), max_len * n_symbols), dtype=np.float32)
    for i, seq in enumerate(seqs):
        cleaned = str(seq).strip().upper()
        hot_cols = [
            pos * n_symbols + aa_to_idx[ch]
            for pos, ch in enumerate(cleaned[:max_len])
            if ch in aa_to_idx
        ]
        if hot_cols:
            out[i, hot_cols] = 1.0
    return out

# -----------------------
# Evaluate per allele
# -----------------------
pred_rows = []
metric_rows = []

have_labels = "hit" in test.columns and not test["hit"].isna().any()
y_true_all, y_pred_all, y_prob_all = [], [], []

alleles_in_test = sorted(test["allele"].unique())
# Row positions per allele in one pass (avoids a full boolean scan per allele).
test_rows = test.groupby("allele").indices

# Create subdirectory for per-allele outputs
ALLELE_OUT_DIR = os.path.join(OUT_DIR, "predictions")
os.makedirs(ALLELE_OUT_DIR, exist_ok=True)


def _positive_class_scores(clf, X):
    """Return a P(hit == 1)-like score in [0, 1] for each row of ``X``.

    * ``predict_proba``: the class-1 column is looked up in ``clf.classes_``
      instead of being assumed to be column 1. A model fit on a single class
      (such as the DummyClassifier used for single-class alleles) returns
      only one column, so ``[:, 1]`` raised IndexError. If class 1 was never
      seen, the score is 0.
    * ``decision_function``: mapped through the logistic sigmoid, so the 0.5
      threshold matches the model's own decision boundary (score 0). The
      old per-allele min-max rescaling shifted that boundary depending on
      which peptides happened to be scored together.
    * otherwise: the hard ``predict`` labels, cast to float.
    """
    if hasattr(clf, "predict_proba"):
        proba = clf.predict_proba(X)
        known = list(getattr(clf, "classes_", [0, 1]))
        if 1 in known:
            return proba[:, known.index(1)]
        return np.zeros(X.shape[0], dtype=float)
    if hasattr(clf, "decision_function"):
        scores = np.asarray(clf.decision_function(X), dtype=float)
        # Overflow-free sigmoid: 1 / (1 + exp(-s)) == 0.5 * (1 + tanh(s / 2)).
        return 0.5 * (1.0 + np.tanh(0.5 * scores))
    return np.asarray(clf.predict(X), dtype=float)


for allele in alleles_in_test:
    if allele not in allele_to_model:
        print(f"Skipping allele={allele}: no trained model found.")
        continue

    positions = test_rows.get(allele)
    if positions is None or len(positions) == 0:
        continue
    te = test.iloc[positions]

    # Determine correct max_len for this allele
    max_len = allele_maxlen.get(allele, meta.get("max_seq_len", 15))

    # Encode test peptides
    X_te = one_hot_encode_padded(te["peptide"].tolist(), max_len, AA_TO_IDX)
    model_path = allele_to_model[allele]
    # NOTE(review): joblib.load unpickles, so only load model files you produced yourself.
    clf = joblib.load(model_path)

    # Verify feature size
    expected_features = getattr(clf, "n_features_in_", X_te.shape[1])
    if X_te.shape[1] != expected_features:
        print(f"[WARN] Skipping allele={allele}: feature size mismatch "
              f"(model={expected_features}, test={X_te.shape[1]})")
        continue

    prob = _positive_class_scores(clf, X_te)
    pred = (prob >= 0.5).astype(int)

    te_out = te[["peptide", "allele"]].copy()
    te_out["y_prob"] = prob
    te_out["y_pred"] = pred
    if have_labels:
        te_out["hit"] = te["hit"].values
    pred_rows.append(te_out)

    # Save per-allele CSV
    safe_allele = allele.replace('*', '').replace(':', '_').replace('/', '-')
    allele_csv = os.path.join(ALLELE_OUT_DIR, f"predictions_{safe_allele}.csv")
    te_out.to_csv(allele_csv, index=False)
    print(f"Saved per-allele predictions to: {allele_csv}")

    # Per-allele metrics
    if have_labels:
        y_te = te["hit"].astype(int).to_numpy()
        # ROC AUC needs both classes; accuracy and F1 are defined either way.
        auc = None
        if len(np.unique(y_te)) > 1:
            try:
                auc = roc_auc_score(y_te, prob)
            except ValueError:
                auc = None
        metric_rows.append({
            "allele": allele, "n_test": len(te),
            "acc": accuracy_score(y_te, pred),
            "auc": auc,
            "f1": f1_score(y_te, pred, zero_division=0),
        })
        # Every labelled row feeds the pooled metrics, including alleles
        # whose test rows are all one class. Skipping them biased the
        # "overall" numbers toward two-class alleles.
        y_true_all.extend(y_te.tolist())
        y_pred_all.extend(pred.tolist())
        y_prob_all.extend(prob.tolist())

    print(f"Evaluated allele={allele:20s} | n_test={len(te):5d} | max_len={max_len}")

# -----------------------
# Combine & save
# -----------------------
if pred_rows:
    predictions = pd.concat(pred_rows, ignore_index=True)
else:
    predictions = pd.DataFrame(columns=["peptide","allele","y_prob","y_pred"] + (["hit"] if have_labels else []))


def _json_number(value):
    """Return *value* as a plain float, or None when it is None or NaN.

    ``json.dump`` would otherwise write NaN as the non-standard token ``NaN``.
    """
    if value is None:
        return None
    value = float(value)
    return None if np.isnan(value) else value


# Overall metrics, pooled over every labelled row that was evaluated.
overall = {"auc": None, "acc": None, "f1": None}
if have_labels and len(y_true_all) > 0:
    # ROC AUC is undefined unless both classes are present; acc/F1 always are.
    if len(set(y_true_all)) > 1:
        try:
            overall["auc"] = _json_number(roc_auc_score(y_true_all, y_prob_all))
        except ValueError:
            overall["auc"] = None
    overall["acc"] = _json_number(accuracy_score(y_true_all, y_pred_all))
    overall["f1"] = _json_number(f1_score(y_true_all, y_pred_all, zero_division=0))

# Save
pred_path = os.path.join(OUT_DIR, "predictions_per_allele.csv")
metrics_path = os.path.join(OUT_DIR, "metrics_per_allele.csv")
overall_path = os.path.join(OUT_DIR, "metrics_overall.json")

predictions.to_csv(pred_path, index=False)
pd.DataFrame(metric_rows).to_csv(metrics_path, index=False)
with open(overall_path, "w") as f:
    json.dump(overall, f, indent=2)

print(f"\nSaved predictions to: {pred_path}")
print(f"Saved per-allele metrics to: {metrics_path}")
print(f"Saved overall metrics to: {overall_path}")
print("\nOverall metrics:", overall)
