In [1]:
# ==============================================================
# ِAuthor: Ahmed
# ClinVar v2: Gene-disjoint (StratifiedGroupKFold) CV
# REF/ALT-aware windows + Dual-branch (CGR-2D + OneHot-1D)
# Optimized for: RTX 3080 (CUDA) + Threadripper 3990X
#
# ✅ Works in Jupyter on Windows (forces num_workers=0 to avoid spawn/pickle crash)
# ✅ Works as .py script (uses multi-worker DataLoader for CPU preproc speed)
#
# Edit REF_FA and VCF_PATH .
# ==============================================================
import os
import re
import time
import math
import numpy as np
import pandas as pd
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.optim as optim
from torch.nn import functional as F
from torch.utils.data import DataLoader, Dataset

from pyfaidx import Fasta
from sklearn.model_selection import StratifiedGroupKFold
from sklearn.metrics import (
    roc_auc_score, average_precision_score,
    accuracy_score, precision_score, recall_score, f1_score,
    precision_recall_curve, roc_curve, confusion_matrix
)

import matplotlib.pyplot as plt

# ------------------------------
# 0) CONFIG (EDIT THESE PATHS)
# ------------------------------
REF_FA   = r"C:\Users\muham\_Projects\DNA\ref\GRCh38.fa"   # <-- set yours
VCF_PATH = r"clinvar.vcf"                                  # <-- set yours

# Data / featurization
WINDOW_WIDTH = 201
CGR_K = 6                     # 2^6=64 grid (fast). Try 7 later if you want.
MAX_VCF_ROWS = 500_000         # how many VCF rows to scan

# Training
BATCH_SIZE = 64
EPOCHS = 20
LR = 1e-3
N_SPLITS = 5

# IO / cache
CACHE_PARQUET = "df_model_500k.parquet"
SAVE_FOLD_CSV = "gene_disjoint_fold_metrics.csv"
SAVE_TYPE_CSV = "variant_type_fold_metrics.csv"

# Plot settings
FIG_DPI = 300   # set 600 if you want higher resolution figures

# Reproducibility
SEED = 42

# ------------------------------
# 1) Utils: environment + device
# ------------------------------
def set_seed(seed: int = 42):
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

def in_notebook() -> bool:
    try:
        from IPython import get_ipython
        ip = get_ipython()
        return (ip is not None) and ("IPKernelApp" in ip.config)
    except Exception:
        return False

class DummyScaler:
    """Fallback when AMP is disabled (CPU or user choice)."""
    def scale(self, loss): return loss
    def step(self, optimizer): optimizer.step()
    def update(self): pass

def setup_device():
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print("Using device:", device)

    use_amp = (device.type == "cuda")
    if use_amp:
        print("GPU:", torch.cuda.get_device_name(0))
        torch.backends.cudnn.benchmark = True
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True
        try:
            torch.set_float32_matmul_precision("high")
        except Exception:
            pass
        scaler = torch.amp.GradScaler("cuda", enabled=True)
    else:
        scaler = DummyScaler()

    return device, use_amp, scaler

def best_num_workers():
    # In Jupyter on Windows, multi-worker DataLoader often crashes due to pickling in __main__
    if os.name == "nt" and in_notebook():
        return 0
    # Script mode: Threadripper can help; keep it moderate to avoid RAM thrash.
    cpu = os.cpu_count() or 8
    return max(4, min(16, cpu // 4))

# ------------------------------
# 2) ClinVar parsing helpers
# ------------------------------
def parse_info(info_str: str) -> dict:
    out = {"CLNSIG":"", "CLNVC":"", "MC":"", "GENEINFO":"", "SYMBOL":""}
    if not isinstance(info_str, str):
        return out
    fields = {}
    for item in info_str.split(";"):
        if "=" in item:
            k, v = item.split("=", 1)
            fields[k] = v
    for k in ["CLNSIG", "CLNVC", "MC", "GENEINFO"]:
        if k in fields:
            out[k] = fields[k]
    gi = out["GENEINFO"]
    if gi:
        # format often: GENE:ID|GENE2:ID2...
        out["SYMBOL"] = gi.split("|")[0].split(":")[0]
    return out

def normalize_chrom(chrom: str) -> str:
    chrom = str(chrom)
    if chrom.startswith("chr"):
        chrom = chrom[3:]
    if chrom == "M":
        chrom = "MT"
    return chrom

def label_from_clnsig(clnsig: str):
    """
    Label policy:
      - Positive: (Likely) Pathogenic
      - Negative: (Likely) Benign
      - Exclude: VUS / conflicting / uncertain or mixed benign+pathogenic
    """
    s = str(clnsig).lower()
    s = s.replace(" ", "_").replace("/", "_").replace(",", "_")
    if "conflicting" in s or "uncertain" in s:
        return None

    has_path = ("pathogenic" in s)
    has_ben  = ("benign" in s)

    # Mixed signals -> exclude
    if has_path and has_ben:
        return None

    if has_path:
        return 1
    if has_ben:
        return 0
    return None

def tag_variant_group(mc: str) -> str:
    mc = str(mc).lower()
    if "missense_variant" in mc:
        return "missense"
    if "frameshift_variant" in mc:
        return "frameshift"
    if "stop_gained" in mc or "nonsense" in mc:
        return "nonsense/stop_gained"
    if "synonymous_variant" in mc:
        return "synonymous"
    if "intron_variant" in mc:
        return "intronic"
    return "other"

# ------------------------------
# 3) REF/ALT-aware window extraction
# ------------------------------
def get_ref_window(ref_fa, chrom: str, pos_1based: int, width: int):
    chrom = normalize_chrom(chrom)
    if chrom not in ref_fa:
        return None
    half = width // 2
    center0 = pos_1based - 1
    start0 = center0 - half
    end0   = center0 + half + 1
    if start0 < 0 or end0 > len(ref_fa[chrom]):
        return None
    seq = ref_fa[chrom][start0:end0]
    seq = str(seq).upper()
    if len(seq) != width or "N" in seq:
        return None
    return seq, half, start0, end0

def make_ref_alt_window(ref_fa, chrom: str, pos_1based: int, ref_allele: str, alt_allele: str, width: int):
    w = get_ref_window(ref_fa, chrom, pos_1based, width)
    if w is None:
        return None
    seq_ref, center, start0, end0 = w

    ref_allele = str(ref_allele).upper()
    alt_allele = str(alt_allele).upper()

    # multi-ALT -> take first
    if "," in alt_allele:
        alt_allele = alt_allele.split(",")[0]

    # symbolic ALT, missing, etc.
    if alt_allele in {".", "*", ""}:
        return None

    # Verify REF allele matches reference at center
    if seq_ref[center:center+len(ref_allele)] != ref_allele:
        return None

    # Apply edit (supports SNP / indel)
    seq_alt_raw = seq_ref[:center] + alt_allele + seq_ref[center+len(ref_allele):]

    # Force fixed width: truncate or pad by reading extra reference after end0
    if len(seq_alt_raw) > width:
        seq_alt = seq_alt_raw[:width]
    elif len(seq_alt_raw) < width:
        need = width - len(seq_alt_raw)
        chrom_n = normalize_chrom(chrom)
        extra = str(ref_fa[chrom_n][end0:end0+need]).upper()
        if len(extra) != need or "N" in extra:
            return None
        seq_alt = seq_alt_raw + extra
    else:
        seq_alt = seq_alt_raw

    if len(seq_alt) != width or "N" in seq_alt:
        return None

    return seq_ref, seq_alt

# ------------------------------
# 4) Build dataset dataframe (chunked)
# ------------------------------
def build_variant_df(vcf_path: str, ref_fa, max_rows: int, width: int) -> pd.DataFrame:
    usecols = [0, 1, 3, 4, 7]
    names = ["CHROM", "POS", "REF", "ALT", "INFO"]

    records = []
    rows_seen = 0
    chunksize = 100_000

    t0 = time.time()
    reader = pd.read_csv(
        vcf_path,
        sep="\t",
        comment="#",
        header=None,
        usecols=usecols,
        names=names,
        dtype=str,
        engine="c",
        chunksize=chunksize
    )

    for chunk in reader:
        if rows_seen >= max_rows:
            break

        remain = max_rows - rows_seen
        if len(chunk) > remain:
            chunk = chunk.iloc[:remain]

        rows_seen += len(chunk)

        for row in chunk.itertuples(index=False):
            chrom, pos_s, ref_a, alt_a, info_s = row

            try:
                pos = int(pos_s)
            except Exception:
                continue

            info_f = parse_info(info_s)
            y = label_from_clnsig(info_f["CLNSIG"])
            if y is None:
                continue

            pair = make_ref_alt_window(ref_fa, chrom, pos, ref_a, alt_a, width=width)
            if pair is None:
                continue
            seq_ref, seq_alt = pair

            records.append({
                "CHROM": normalize_chrom(chrom),
                "POS": pos,
                "REF": ref_a,
                "ALT": alt_a,
                "CLNSIG": info_f["CLNSIG"],
                "CLASS": int(y),
                "SYMBOL": info_f["SYMBOL"] if info_f["SYMBOL"] else "UNK",
                "CLNVC": info_f["CLNVC"],
                "MC": info_f["MC"],
                "seq_ref": seq_ref,
                "seq_alt": seq_alt
            })

        print(f"Scanned {rows_seen:,} VCF rows | kept {len(records):,} labeled so far...")

    df = pd.DataFrame(records)
    dt = time.time() - t0
    print(f"\nFinished scan: {rows_seen:,} VCF rows in {dt:.1f}s")
    print(f"Kept {len(df):,} labeled allele-aware variants")
    print("Pathogenic fraction:", df["CLASS"].mean() if len(df) else np.nan)
    return df

def filter_acgt(df: pd.DataFrame) -> pd.DataFrame:
    mask_ref = df["seq_ref"].str.fullmatch(r"[ACGT]+")
    mask_alt = df["seq_alt"].str.fullmatch(r"[ACGT]+")
    bad = (~mask_ref) | (~mask_alt)
    bad_n = int(bad.sum())
    print("Bad rows (non-ACGT):", bad_n)
    if bad_n > 0:
        ex = df.loc[bad, "seq_alt"].head(5).tolist()
        print("Example bad seq_alt:", ex)
    df2 = df.loc[~bad].reset_index(drop=True)
    print("After filtering:", len(df2))
    return df2

# ------------------------------
# 5) Representations: CGR + OneHot (ref+alt stacked)
# ------------------------------
def sequence_to_cgr(seq: str, k: int) -> torch.Tensor:
    """
    Fast CGR via direct k-mer bit mapping.
    A:(0,0), C:(0,1), G:(1,1), T:(1,0)
    Output: (2^k, 2^k) float32 normalized.
    """
    mapping = {'A': (0,0), 'C': (0,1), 'G': (1,1), 'T': (1,0)}
    N = 2 ** k

    seq = str(seq).upper()
    if len(seq) < k:
        seq = seq.ljust(k, 'A')

    xb = np.fromiter((mapping.get(b, (0,0))[0] for b in seq), dtype=np.uint8, count=len(seq))
    yb = np.fromiter((mapping.get(b, (0,0))[1] for b in seq), dtype=np.uint8, count=len(seq))
    weights = (1 << np.arange(k-1, -1, -1, dtype=np.int32))

    try:
        from numpy.lib.stride_tricks import sliding_window_view
        xw = sliding_window_view(xb, k)
        yw = sliding_window_view(yb, k)
        x_idx = (xw * weights).sum(axis=1).astype(np.int32)
        y_idx = (yw * weights).sum(axis=1).astype(np.int32)
        freq = np.zeros((N, N), dtype=np.float32)
        np.add.at(freq, (y_idx, x_idx), 1.0)
    except Exception:
        freq = np.zeros((N, N), dtype=np.float32)
        for i in range(len(seq) - k + 1):
            xi = int((xb[i:i+k] * weights).sum())
            yi = int((yb[i:i+k] * weights).sum())
            freq[yi, xi] += 1.0

    m = float(freq.max())
    if m > 0:
        freq /= m
    return torch.tensor(freq, dtype=torch.float32)

def one_hot(seq: str, width: int) -> torch.Tensor:
    base_to_idx = {'A':0, 'C':1, 'G':2, 'T':3}
    seq = str(seq).upper()
    mat = torch.zeros((4, width), dtype=torch.float32)
    for j, b in enumerate(seq[:width]):
        i = base_to_idx.get(b, None)
        if i is not None:
            mat[i, j] = 1.0
    return mat

class VariantDataset(Dataset):
    def __init__(self, df: pd.DataFrame, k: int, width: int):
        self.df = df.reset_index(drop=True)
        self.k = k
        self.width = width
        self.N = 2 ** k

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        r = self.df.iloc[idx]
        sref = r["seq_ref"]
        salt = r["seq_alt"]

        cgr_ref = sequence_to_cgr(sref, self.k).unsqueeze(0)  # (1,N,N)
        cgr_alt = sequence_to_cgr(salt, self.k).unsqueeze(0)  # (1,N,N)
        img = torch.cat([cgr_ref, cgr_alt], dim=0)            # (2,N,N)

        oh_ref = one_hot(sref, self.width)                    # (4,width)
        oh_alt = one_hot(salt, self.width)                    # (4,width)
        seq = torch.cat([oh_ref, oh_alt], dim=0)              # (8,width)

        y = torch.tensor(float(r["CLASS"]), dtype=torch.float32)
        return img, seq, y

# ------------------------------
# 6) Model: dual branch
# ------------------------------
class FocalLoss(nn.Module):
    def __init__(self, alpha=0.25, gamma=2.0):
        super().__init__()
        self.alpha = float(alpha)
        self.gamma = float(gamma)

    def forward(self, logits, labels):
        bce = F.binary_cross_entropy_with_logits(logits, labels, reduction="none")
        pt = torch.exp(-bce)
        loss = self.alpha * (1.0 - pt) ** self.gamma * bce
        return loss.mean()

class DualBranchModel(nn.Module):
    def __init__(self, k: int, seq_len: int):
        super().__init__()
        N = 2 ** k

        self.cnn2d = nn.Sequential(
            nn.Conv2d(2, 32, kernel_size=5, stride=2, padding=2),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
            nn.Flatten()
        )
        self.cnn1d = nn.Sequential(
            nn.Conv1d(8, 16, kernel_size=5, padding=2),
            nn.ReLU(),
            nn.MaxPool1d(2, 2),
            nn.Conv1d(16, 32, kernel_size=5, padding=2),
            nn.ReLU(),
            nn.MaxPool1d(2, 2),
            nn.Flatten()
        )

        with torch.no_grad():
            d1 = self.cnn2d(torch.zeros(1, 2, N, N)).shape[1]
            d2 = self.cnn1d(torch.zeros(1, 8, seq_len)).shape[1]

        self.mlp = nn.Sequential(
            nn.Linear(d1 + d2, 128),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(128, 1)
        )

    def forward(self, img, seq):
        f_img = self.cnn2d(img)
        f_seq = self.cnn1d(seq)
        f = torch.cat([f_img, f_seq], dim=1)
        return self.mlp(f).view(-1)

# ------------------------------
# 7) Metrics + threshold
# ------------------------------
def predict_probs(model, loader, device):
    model.eval()
    ys, ps = [], []
    with torch.no_grad():
        for img, seq, y in loader:
            img = img.to(device, non_blocking=True)
            seq = seq.to(device, non_blocking=True)
            logits = model(img, seq)
            prob = torch.sigmoid(logits).detach().cpu().numpy()
            ys.append(y.detach().cpu().numpy())
            ps.append(prob)
    return np.concatenate(ys), np.concatenate(ps)

def metrics_binary(y_true, y_prob, thr=0.5):
    y_pred = (y_prob >= thr).astype(int)
    return {
        "n": len(y_true),
        "roc_auc": roc_auc_score(y_true, y_prob) if len(np.unique(y_true)) > 1 else np.nan,
        "pr_auc": average_precision_score(y_true, y_prob) if len(np.unique(y_true)) > 1 else np.nan,
        "acc": accuracy_score(y_true, y_pred),
        "precision": precision_score(y_true, y_pred, zero_division=0),
        "recall": recall_score(y_true, y_pred, zero_division=0),
        "f1": f1_score(y_true, y_pred, zero_division=0),
    }

def best_threshold_f1(y_true: np.ndarray, y_prob: np.ndarray) -> float:
    p, r, th = precision_recall_curve(y_true, y_prob)
    f1 = (2 * p * r) / (p + r + 1e-9)
    if th is None or len(th) == 0:
        return 0.5
    best_i = int(np.argmax(f1[1:]))  # skip first point
    return float(th[best_i])

# ------------------------------
# 8) Plotting helpers
# ------------------------------
def save_fig(path, dpi=FIG_DPI):
    plt.tight_layout()
    plt.savefig(path, dpi=dpi, bbox_inches="tight")
    plt.close()

def plot_fold_training_curves(history_df, out_png, fold):
    epochs = history_df["epoch"].values

    plt.figure(figsize=(12, 8))

    # Loss (train)
    plt.subplot(2, 2, 1)
    plt.plot(epochs, history_df["loss"].values, label="Train Loss")
    plt.xlabel("Epoch")
    plt.ylabel("Loss")
    plt.title(f"Fold {fold} - Train Loss")
    plt.legend()

    # AUC (on fold test)
    plt.subplot(2, 2, 2)
    plt.plot(epochs, history_df["auc"].values, label="Test ROC-AUC")
    plt.xlabel("Epoch")
    plt.ylabel("AUC")
    plt.title(f"Fold {fold} - ROC-AUC (Test)")
    plt.legend()

    # F1 curves (on fold test)
    plt.subplot(2, 2, 3)
    plt.plot(epochs, history_df["f1_0p5"].values, label="F1 @ 0.5 (Test)")
    plt.plot(epochs, history_df["f1_thr"].values, label="F1 @ train-chosen thr (Test)")
    plt.xlabel("Epoch")
    plt.ylabel("F1")
    plt.title(f"Fold {fold} - F1 (Test)")
    plt.legend()

    # Threshold + predicted positives
    plt.subplot(2, 2, 4)
    plt.plot(epochs, history_df["thr"].values, label="thr (from train)")
    plt.plot(epochs, history_df["pred_pos"].values, label="pred_pos (test)")
    plt.xlabel("Epoch")
    plt.title(f"Fold {fold} - Threshold / Pred Pos")
    plt.legend()

    save_fig(out_png, dpi=FIG_DPI)

def plot_mean_roc(curves, out_png):
    mean_fpr = np.linspace(0, 1, 500)
    tprs = []
    aucs = []

    for c in curves:
        fpr, tpr = c["fpr"], c["tpr"]
        tpr_i = np.interp(mean_fpr, fpr, tpr)
        tpr_i[0] = 0.0
        tprs.append(tpr_i)
        aucs.append(c["auc"])

    tprs = np.vstack(tprs)
    mean_tpr = tprs.mean(axis=0)
    std_tpr = tprs.std(axis=0)

    mean_auc = np.mean(aucs)
    std_auc = np.std(aucs)

    plt.figure(figsize=(7, 6))
    plt.plot(mean_fpr, mean_tpr, label=f"Mean ROC (AUC={mean_auc:.3f}±{std_auc:.3f})")
    plt.fill_between(mean_fpr, np.maximum(mean_tpr - std_tpr, 0), np.minimum(mean_tpr + std_tpr, 1), alpha=0.2)
    plt.plot([0, 1], [0, 1], linestyle="--", linewidth=1)
    plt.xlabel("False Positive Rate")
    plt.ylabel("True Positive Rate")
    plt.title("Gene-disjoint CV - ROC (mean ± std)")
    plt.legend(loc="lower right")
    save_fig(out_png, dpi=FIG_DPI)

def plot_mean_pr(curves, prevalence, out_png):
    mean_recall = np.linspace(0, 1, 500)
    precs = []
    aps = []

    for c in curves:
        prec, rec = c["precision"], c["recall"]
        # ensure monotonic recall for interpolation
        order = np.argsort(rec)
        rec_s = rec[order]
        prec_s = prec[order]
        p_i = np.interp(mean_recall, rec_s, prec_s)
        precs.append(p_i)
        aps.append(c["ap"])

    precs = np.vstack(precs)
    mean_p = precs.mean(axis=0)
    std_p = precs.std(axis=0)

    mean_ap = np.mean(aps)
    std_ap = np.std(aps)

    plt.figure(figsize=(7, 6))
    plt.plot(mean_recall, mean_p, label=f"Mean PR (AP={mean_ap:.3f}±{std_ap:.3f})")
    plt.fill_between(mean_recall, np.maximum(mean_p - std_p, 0), np.minimum(mean_p + std_p, 1), alpha=0.2)
    plt.hlines(prevalence, 0, 1, linestyles="--", linewidth=1, label=f"Baseline (prev={prevalence:.3f})")
    plt.xlabel("Recall")
    plt.ylabel("Precision")
    plt.title("Gene-disjoint CV - Precision–Recall (mean ± std)")
    plt.legend(loc="lower left")
    save_fig(out_png, dpi=FIG_DPI)

def plot_confmat(cm, out_png, normalize=False, title="Confusion Matrix"):
    cm_plot = cm.astype(np.float64) if normalize else cm.copy()
    if normalize:
        cm_plot = cm_plot / (cm_plot.sum(axis=1, keepdims=True) + 1e-9)

    plt.figure(figsize=(6, 5))
    plt.imshow(cm_plot, interpolation="nearest")
    plt.title(title)
    plt.colorbar()
    tick_marks = np.arange(2)
    plt.xticks(tick_marks, ["Benign", "Pathogenic"])
    plt.yticks(tick_marks, ["Benign", "Pathogenic"])

    fmt = ".3f" if normalize else "d"
    thresh = cm_plot.max() / 2.0 if cm_plot.size else 0
    for i in range(2):
        for j in range(2):
            val = cm_plot[i, j]
            plt.text(j, i, format(val, fmt),
                     ha="center", va="center",
                     color="white" if val > thresh else "black")

    plt.ylabel("True label")
    plt.xlabel("Predicted label")
    save_fig(out_png, dpi=FIG_DPI)

def plot_variant_type_bars(type_df, metric, out_png, title):
    g = type_df.groupby("group")
    means = g[metric].mean()
    stds = g[metric].std().fillna(0.0)
    ns = g["n"].mean()

    order = ns.sort_values(ascending=False).index.tolist()

    plt.figure(figsize=(10, 5))
    x = np.arange(len(order))
    plt.bar(x, means[order].values, yerr=stds[order].values, capsize=3)
    plt.xticks(x, order, rotation=25, ha="right")
    plt.ylabel(metric)
    plt.title(title)
    save_fig(out_png, dpi=FIG_DPI)

# ------------------------------
# 9) Main training + CV + plots
# ------------------------------
def main():
    set_seed(SEED)
    device, use_amp, scaler = setup_device()

    # DataLoader tuning
    NUM_WORKERS = best_num_workers()
    PIN_MEMORY = (device.type == "cuda")
    PERSISTENT = (NUM_WORKERS > 0)
    PREFETCH = 4 if (NUM_WORKERS > 0) else None
    print(f"DataLoader: num_workers={NUM_WORKERS}, pin_memory={PIN_MEMORY}, persistent={PERSISTENT}")

    # Load reference FASTA
    print("Loading reference FASTA...")
    ref_fa = Fasta(REF_FA, as_raw=True, sequence_always_upper=True)

    # Build/load df_model
    if os.path.exists(CACHE_PARQUET):
        print(f"Loading cache: {CACHE_PARQUET}")
        df_model = pd.read_parquet(CACHE_PARQUET)
        print("Loaded cached variants:", len(df_model))
    else:
        print(f"Building dataset from {VCF_PATH} (up to {MAX_VCF_ROWS:,} rows)...")
        df_model = build_variant_df(VCF_PATH, ref_fa, max_rows=MAX_VCF_ROWS, width=WINDOW_WIDTH)
        print(f"Saving cache: {CACHE_PARQUET}")
        df_model.to_parquet(CACHE_PARQUET, index=False)

    # Filter non-ACGT
    df_model = filter_acgt(df_model)

    same = float((df_model["seq_ref"] == df_model["seq_alt"]).mean()) if len(df_model) else 0.0
    print(f"Sanity: fraction where seq_ref == seq_alt: {same:.6f} (should be near 0)")

    prevalence_all = float(df_model["CLASS"].mean()) if len(df_model) else 0.0

    # Prepare CV
    X_idx = np.arange(len(df_model))
    y = df_model["CLASS"].values.astype(int)
    groups = df_model["SYMBOL"].fillna("UNK").values

    cv = StratifiedGroupKFold(n_splits=N_SPLITS, shuffle=True, random_state=SEED)

    fold_rows = []
    per_type_rows = []

    roc_curves = []
    pr_curves = []
    cm_sum = np.zeros((2, 2), dtype=np.int64)

    def make_loader(df, shuffle):
        ds = VariantDataset(df, k=CGR_K, width=WINDOW_WIDTH)
        kwargs = dict(
            batch_size=BATCH_SIZE,
            shuffle=shuffle,
            num_workers=NUM_WORKERS,
            pin_memory=PIN_MEMORY,
            persistent_workers=PERSISTENT,
        )
        if NUM_WORKERS > 0 and PREFETCH is not None:
            kwargs["prefetch_factor"] = PREFETCH
        return DataLoader(ds, **kwargs)

    for fold, (tr, te) in enumerate(cv.split(X_idx, y, groups), 1):
        df_tr = df_model.iloc[tr].reset_index(drop=True)
        df_te = df_model.iloc[te].reset_index(drop=True)

        # gene-disjoint check
        assert set(df_tr["SYMBOL"]).isdisjoint(set(df_te["SYMBOL"]))

        train_loader = make_loader(df_tr, shuffle=True)
        test_loader = make_loader(df_te, shuffle=False)
        train_eval_loader = make_loader(df_tr, shuffle=False)

        model = DualBranchModel(k=CGR_K, seq_len=WINDOW_WIDTH).to(device)
        print(f"\nFold {fold}: model device: {next(model.parameters()).device}")

        pos_frac = float(df_tr["CLASS"].mean())
        alpha = float(1.0 - pos_frac)  # upweight positives
        criterion = FocalLoss(alpha=alpha, gamma=2.0).to(device)
        optimizer = optim.Adam(model.parameters(), lr=LR)

        print(f"Fold {fold}: n_train={len(df_tr):,}, n_test={len(df_te):,}, pos_frac={pos_frac:.4f}, focal_alpha={alpha:.4f}")

        history = []

        for epoch in range(1, EPOCHS + 1):
            model.train()
            total_loss = 0.0

            for img, seq, lab in train_loader:
                img = img.to(device, non_blocking=True)
                seq = seq.to(device, non_blocking=True)
                lab = lab.to(device, non_blocking=True)

                optimizer.zero_grad(set_to_none=True)

                with torch.amp.autocast(device_type="cuda", dtype=torch.float16, enabled=use_amp):
                    logits = model(img, seq)
                    loss = criterion(logits, lab)

                scaler.scale(loss).backward()
                scaler.step(optimizer)
                scaler.update()

                total_loss += float(loss.item()) * lab.size(0)

            # threshold on TRAIN only
            y_tr, p_tr = predict_probs(model, train_eval_loader, device)
            thr = best_threshold_f1(y_tr, p_tr)

            # evaluate on fold test
            y_true, y_prob = predict_probs(model, test_loader, device)
            m05 = metrics_binary(y_true, y_prob, thr=0.5)
            m = metrics_binary(y_true, y_prob, thr=thr)

            pred_pos = float((y_prob >= thr).mean())
            loss_epoch = total_loss / len(df_tr)

            print(
                f"Fold {fold} | Epoch {epoch:02d} | "
                f"loss={loss_epoch:.4f} | "
                f"AUC={m['roc_auc']:.4f} | "
                f"F1@0.5={m05['f1']:.4f} | F1@thr={m['f1']:.4f} | "
                f"thr={thr:.3f} | pred_pos={pred_pos:.3f}"
            )

            history.append({
                "fold": fold,
                "epoch": epoch,
                "loss": loss_epoch,
                "auc": m["roc_auc"],
                "pr_auc": m["pr_auc"],
                "f1_0p5": m05["f1"],
                "f1_thr": m["f1"],
                "thr": thr,
                "pred_pos": pred_pos
            })

        # Save + plot fold training curves
        hist_df = pd.DataFrame(history)
        hist_csv = f"history_fold{fold}.csv"
        hist_png = f"Fold{fold}_TrainingCurves.png"
        hist_df.to_csv(hist_csv, index=False)
        plot_fold_training_curves(hist_df, hist_png, fold=fold)
        print(f"Saved: {hist_csv}, {hist_png}")

        # Final fold metrics (use train-chosen thr)
        y_tr, p_tr = predict_probs(model, train_eval_loader, device)
        thr = best_threshold_f1(y_tr, p_tr)

        y_true, y_prob = predict_probs(model, test_loader, device)
        m05 = metrics_binary(y_true, y_prob, thr=0.5)
        m = metrics_binary(y_true, y_prob, thr=thr)

        # save predictions per fold (reproducibility)
        np.savez_compressed(
            f"preds_fold{fold}.npz",
            y_true=y_true.astype(np.int8),
            y_prob=y_prob.astype(np.float32),
            thr=np.array([thr], dtype=np.float32),
        )

        # ROC/PR curves per fold
        fpr, tpr, _ = roc_curve(y_true, y_prob)
        roc_curves.append({"fpr": fpr, "tpr": tpr, "auc": m["roc_auc"]})

        prec, rec, _ = precision_recall_curve(y_true, y_prob)
        pr_curves.append({"precision": prec, "recall": rec, "ap": m["pr_auc"]})

        # confusion matrix at fold thr
        y_pred = (y_prob >= thr).astype(int)
        cm = confusion_matrix(y_true, y_pred, labels=[0, 1])
        cm_sum += cm

        m.update({
            "fold": fold,
            "thr": thr,
            "f1_0p5": m05["f1"],
            "recall_0p5": m05["recall"],
            "precision_0p5": m05["precision"],
        })
        fold_rows.append(m)

        # per variant-type stratified
        df_te2 = df_te.copy()
        df_te2["y_prob"] = y_prob
        df_te2["group"] = df_te2["MC"].apply(tag_variant_group)

        for g, sub in df_te2.groupby("group"):
            mm = metrics_binary(
                sub["CLASS"].values.astype(int),
                sub["y_prob"].values.astype(float),
                thr=thr
            )
            mm["fold"] = fold
            mm["group"] = g
            per_type_rows.append(mm)

        # free GPU memory between folds
        if device.type == "cuda":
            torch.cuda.empty_cache()

    fold_df = pd.DataFrame(fold_rows)
    type_df = pd.DataFrame(per_type_rows)

    print("\n=== Gene-disjoint CV summary (mean ± std) ===")
    summary = fold_df.drop(columns=["fold"]).agg(["mean", "std"]).T
    print(summary)

    print("\n=== Variant-type stratified (mean over folds) ===")
    type_mean = (
        type_df.groupby("group")[["roc_auc","pr_auc","acc","precision","recall","f1","n"]]
        .mean()
        .sort_values("n", ascending=False)
    )
    print(type_mean)

    fold_df.to_csv(SAVE_FOLD_CSV, index=False)
    type_df.to_csv(SAVE_TYPE_CSV, index=False)
    print(f"\nSaved: {SAVE_FOLD_CSV}")
    print(f"Saved: {SAVE_TYPE_CSV}")

    # --------------------------
    # Final paper plots (CV-level)
    # --------------------------
    plot_mean_roc(roc_curves, out_png="CV_ROC_mean.png")
    plot_mean_pr(pr_curves, prevalence=prevalence_all, out_png="CV_PR_mean.png")
    print("Saved: CV_ROC_mean.png, CV_PR_mean.png")

    plot_confmat(cm_sum, out_png="CV_ConfusionMatrix_sum.png", normalize=False,
                 title="Gene-disjoint CV - Confusion Matrix (sum over folds)")
    plot_confmat(cm_sum, out_png="CV_ConfusionMatrix_sum_norm.png", normalize=True,
                 title="Gene-disjoint CV - Confusion Matrix (row-normalized)")
    print("Saved: CV_ConfusionMatrix_sum.png, CV_ConfusionMatrix_sum_norm.png")

    plot_variant_type_bars(type_df, metric="f1", out_png="VariantType_F1_bar.png",
                           title="Variant-type Stratified F1 (mean ± std over folds)")
    plot_variant_type_bars(type_df, metric="pr_auc", out_png="VariantType_PRAUC_bar.png",
                           title="Variant-type Stratified PR-AUC (mean ± std over folds)")
    print("Saved: VariantType_F1_bar.png, VariantType_PRAUC_bar.png")

# Entry point (safe for .py script; in notebook you can just call main())
if __name__ == "__main__":
    main()


Using device: cuda:0
GPU: NVIDIA GeForce RTX 3080
DataLoader: num_workers=0, pin_memory=True, persistent=False
Loading reference FASTA...
Loading cache: df_model_500k.parquet
Loaded cached variants: 191279
Bad rows (non-ACGT): 0
After filtering: 191279
Sanity: fraction where seq_ref == seq_alt: 0.000110 (should be near 0)

Fold 1: model device: cuda:0
Fold 1: n_train=160,705, n_test=30,574, pos_frac=0.1869, focal_alpha=0.8131
Fold 1 | Epoch 01 | loss=0.0789 | AUC=0.8118 | F1@0.5=0.4177 | F1@thr=0.5130 | thr=0.418 | pred_pos=0.162
Fold 1 | Epoch 02 | loss=0.0680 | AUC=0.8305 | F1@0.5=0.4644 | F1@thr=0.5300 | thr=0.446 | pred_pos=0.151
Fold 1 | Epoch 03 | loss=0.0633 | AUC=0.8349 | F1@0.5=0.4530 | F1@thr=0.5383 | thr=0.420 | pred_pos=0.149
Fold 1 | Epoch 04 | loss=0.0599 | AUC=0.8406 | F1@0.5=0.5177 | F1@thr=0.5540 | thr=0.461 | pred_pos=0.160
Fold 1 | Epoch 05 | loss=0.0573 | AUC=0.8430 | F1@0.5=0.5283 | F1@thr=0.5597 | thr=0.463 | pred_pos=0.161
Fold 1 | Epoch 06 | loss=0.0546 | AUC=0.