In [6]:
import os, glob, gc, csv
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image, ImageDraw, ImageFont
import matplotlib.cm as cm
from Bio.PDB import PDBParser
from collections import Counter

# ================================ HELPERS ==============================
def normalize01(a: np.ndarray) -> np.ndarray:
    """Min-max rescale an array to the range [0, 1] as float32.

    Args:
        a: Numeric array (or array-like) of any shape.

    Returns:
        A float32 array with the same shape as ``a``. A (near-)constant input,
        i.e. ``max - min < 1e-8``, maps to all zeros. A zero-size input is
        returned as a zero-size float32 array instead of raising.
    """
    a = np.asarray(a).astype(np.float32, copy=False)
    if a.size == 0:
        # min()/max() raise ValueError on zero-size arrays.
        return np.zeros_like(a, dtype=np.float32)
    mn, mx = float(a.min()), float(a.max())
    if mx - mn < 1e-8:
        # Avoid dividing by (almost) zero for flat inputs.
        return np.zeros_like(a, dtype=np.float32)
    return (a - mn) / (mx - mn)

def save_png_gray(arr: np.ndarray, path: str):
    """Save ``arr`` to ``path`` as an 8-bit image after min-max scaling to 0..255."""
    scaled = normalize01(arr) * 255
    pixels = scaled.astype(np.uint8)
    Image.fromarray(pixels).save(path)

def save_overlay(inp: np.ndarray, cam: np.ndarray, path: str):
    """Blend a jet-coloured heatmap of ``cam`` over a grey rendering of ``inp``.

    Both arrays are min-max scaled independently. The saved RGB image is a
    fixed 60% input / 40% heatmap mix.
    """
    base01 = normalize01(inp)
    heat01 = normalize01(cam)
    # Grey input replicated into three identical colour channels.
    base_rgb = (np.stack([base01, base01, base01], -1) * 255).astype(np.uint8)
    # cm.jet yields RGBA; only the RGB channels are kept.
    jet_rgb = (cm.jet(heat01)[..., :3] * 255).astype(np.uint8)
    blended = 0.6 * base_rgb + 0.4 * jet_rgb
    Image.fromarray(blended.clip(0, 255).astype(np.uint8)).save(path)

def save_line_plot_pil(values, out_path, title="", xlabel="", ylabel="", size=(1000, 360), margin=60):
    """Draw a minimal line chart with PIL and save it as a PNG.

    Args:
        values: 1-D sequence of numbers; min-max scaled to the plot height.
        out_path: Destination file path (always written in PNG format).
        title, xlabel, ylabel: Optional text drawn with PIL's default font.
        size: ``(width, height)`` of the image in pixels.
        margin: Blank border, in pixels, between the image edge and the axes.

    Text rendering is best-effort: if the font cannot be loaded or a string
    cannot be drawn, a warning is printed and the plot is still saved.
    """
    values = np.asarray(values, dtype=np.float32)
    W, H = size
    img = Image.new("RGB", (W, H), "white")
    draw = ImageDraw.Draw(img)
    left, top = margin, margin
    right, bottom = W - margin, H - margin
    # axes
    draw.line([(left, bottom), (right, bottom)], fill="black", width=2)
    draw.line([(left, top), (left, bottom)], fill="black", width=2)
    if values.size > 0:
        vmin, vmax = float(values.min()), float(values.max())
        # Guard the denominator so a flat series maps to the baseline.
        norm = (values - vmin) / max(vmax - vmin, 1e-8)
        xs = np.linspace(left, right, len(values))
        ys = bottom - norm * (bottom - top)
        for i in range(len(xs) - 1):
            draw.line([(xs[i], ys[i]), (xs[i+1], ys[i+1])], fill=(30,144,255), width=2)
    try:
        font = ImageFont.load_default()
        if title:
            draw.text((left, top-25), title, fill="black", font=font)
        if xlabel:
            draw.text((left + (right-left)//2 - 40, bottom + 8), xlabel, fill="black", font=font)
        if ylabel:
            draw.text((8, top - 10), ylabel, fill="black", font=font)
    except Exception as exc:
        # Labels are optional decoration; report the problem instead of hiding it,
        # and let KeyboardInterrupt/SystemExit propagate.
        print(f"[WARN] could not draw plot labels for {out_path}: {exc}")
    img.save(out_path, format="PNG")

def save_bar_plot_pil(values, labels, out_path, title="", size=(1200, 500), margin=80, color=(30,144,255)):
    """Draw a horizontal bar chart with PIL and save it as a PNG.

    Args:
        values: 1-D sequence of bar lengths; the largest value spans the plot width.
        labels: One label per value, drawn at the left edge (truncated to 22 chars).
        out_path: Destination file path (always written in PNG format).
        title: Optional title drawn above the plot.
        size: ``(width, height)`` of the image in pixels.
        margin: Blank border, in pixels, around the plot area.
        color: RGB fill colour of the bars.

    Text is best-effort: when the default font cannot be loaded, or a string
    cannot be rendered, a warning is printed and the bars are still drawn.
    """
    values = np.asarray(values, dtype=np.float32)
    n = len(values)
    W, H = size
    img = Image.new("RGB", (W, H), "white")
    draw = ImageDraw.Draw(img)
    left, top = margin, margin
    right, bottom = W - margin, H - margin
    plot_w, plot_h = right - left, bottom - top

    # Bind `font` up front so a failed load cannot leave the name undefined.
    font = None
    try:
        font = ImageFont.load_default()
    except Exception as exc:
        print(f"[WARN] default font unavailable; text skipped for {out_path}: {exc}")

    def _text(xy, text):
        # One unrenderable string must not abort the plot or suppress other text.
        if font is None:
            return
        try:
            draw.text(xy, text, fill="black", font=font)
        except Exception as exc:
            print(f"[WARN] could not draw text {text!r} on {out_path}: {exc}")

    if title:
        _text((left, top - 25), title)
    if n == 0:
        img.save(out_path, format="PNG")
        return

    vmax = float(values.max())
    scale = plot_w / max(vmax, 1e-8)
    bar_h = plot_h / max(n, 1)
    # The 4 px gap only fits when a row is taller than 8 px; otherwise y1 would
    # fall above y0, which recent Pillow versions reject (assumed from the
    # ImageDraw.rectangle docs; confirm for the installed version).
    pad = 4 if bar_h > 8 else 0
    for i, v in enumerate(values):
        y0 = top + i * bar_h + pad
        y1 = top + (i + 1) * bar_h - pad
        x1 = left + float(v) * scale
        # Order the x coordinates so a negative value cannot invert the rectangle.
        x_lo, x_hi = min(left, x1), max(left, x1)
        draw.rectangle([x_lo, y0, x_hi, y1], fill=color)
        full_label = str(labels[i])
        lbl = full_label[:22] + ("…" if len(full_label) > 22 else "")
        text_y = (y0 + y1) / 2 - 6
        _text((5, text_y), lbl)
        _text((x_hi + 5, text_y), f"{v:.3f}")
    img.save(out_path, format="PNG")

def load_sample(path, exp_index,
                one_hot_dir="/home/alexhernandez/transmembranebindingAI/Data/SplitData/Cholesterol/IvanTestSet/ivan-separate-graphs-5A/positive"):
    """Load a model-input matrix and the matching one-hot atom-subtype matrix.

    Args:
        path: Path to the ``.npy`` model-input matrix. The first four
            characters of its file name are used as the protein id.
        exp_index: Accepted for call-site compatibility; not used here.
        one_hot_dir: Directory holding ``<protein_id>-filtered_graphs.npy``.
            Defaults to the location that was previously hard-coded.

    Returns:
        M_in: np.ndarray loaded from ``path`` (callers expect [H,37]).
        encoded: the ``"encoded_matrix"`` entry of the one-hot file
            (presumably [N,37], unpadded; callers align it to M_in).

    Raises:
        FileNotFoundError: if either file is missing.
        KeyError: if the one-hot file has no ``"encoded_matrix"`` entry.
    """
    protein_id = os.path.basename(path)[:4]
    M_in = np.load(path)

    one_hot_path = os.path.join(one_hot_dir, f"{protein_id}-filtered_graphs.npy")

    # SECURITY: allow_pickle=True unpickles the file, which can execute
    # arbitrary code. Only point this at trusted, locally generated files.
    obj = np.load(one_hot_path, allow_pickle=True).item()
    encoded = obj["encoded_matrix"]
    return M_in, encoded


def atom_type_from_onehot(onehot_row, type_names):
    """Return the name at the arg-max position of ``onehot_row``.

    An all-zero row has no active type and yields ``"UNKNOWN"``.
    """
    if not np.any(onehot_row):
        return "UNKNOWN"
    hot_index = int(np.argmax(onehot_row))
    return type_names[hot_index]

def align_encoded_to_padded(M_in: np.ndarray, ENC: np.ndarray):
    """
    Spread an unpadded one-hot matrix over the rows of a padded model input.

    A row of M_in counts as padding when every one of its 37 entries is
    within 1e-8 of zero. When the number of non-padding rows equals the number
    of ENC rows, ENC is written into exactly those rows. Otherwise a warning
    is printed and the first min(H, N) rows of ENC are copied to the top.

    Args:
      M_in: [H,37] padded model input
      ENC:  [N,37] one-hot matrix

    Returns:
      ENC_aligned: [H,37] float32, zeros on rows that received no ENC row
      nonpad_mask: [H] boolean mask, True where M_in has a non-zero row
    """
    assert M_in.ndim == 2 and M_in.shape[1] == 37, f"M_in must be [H,37], got {M_in.shape}"
    assert ENC.ndim == 2 and ENC.shape[1] == 37,   f"ENC must be [N,37], got {ENC.shape}"

    total_rows = M_in.shape[0]
    is_pad_row = np.isclose(M_in, 0.0, atol=1e-8).all(axis=1)
    nonpad_mask = ~is_pad_row
    real_rows = np.flatnonzero(nonpad_mask)

    ENC_aligned = np.zeros_like(M_in, dtype=np.float32)

    if len(real_rows) != ENC.shape[0]:
        # Counts disagree: fall back to a top-aligned copy and say so.
        keep = min(total_rows, ENC.shape[0])
        ENC_aligned[:keep, :] = ENC[:keep, :]
        print(f"[WARN] nonpad rows in M_in ({len(real_rows)}) != ENC rows ({ENC.shape[0]}). "
              f"Used first-{keep} fallback alignment.")
        return ENC_aligned, nonpad_mask

    # One ENC row per detected real row, in order.
    ENC_aligned[real_rows, :] = ENC
    return ENC_aligned, nonpad_mask

def save_bar_with_errorbars(values, errors, labels, out_path,
                            title="",
                            ylabel="Percent of Top-K atoms"):
    """Save a vertical bar chart with error bars using matplotlib.

    Args:
        values: Bar heights, one per label.
        errors: Symmetric error-bar half-lengths, one per value.
        labels: Tick labels, drawn rotated 75 degrees.
        out_path: Destination image path (saved at 200 dpi).
        title: Figure title.
        ylabel: Y-axis label.

    The figure is closed even when plotting or saving raises, so figures
    cannot pile up in memory after a failed call.
    """
    # pyplot is imported locally; the notebook's import cell only loads matplotlib.cm.
    import matplotlib.pyplot as plt

    x = np.arange(len(values))
    # Widen the figure with the bar count so tick labels stay readable.
    fig, ax = plt.subplots(figsize=(max(9, 0.35*len(values)), 5))
    try:
        ax.bar(x, values)
        ax.errorbar(x, values, yerr=errors, fmt='none', capsize=4, linewidth=1)

        ax.set_xticks(x)
        ax.set_xticklabels(labels, rotation=75, ha='right')
        ax.set_ylabel(ylabel)
        ax.set_title(title)
        fig.tight_layout()
        fig.savefig(out_path, dpi=200)
    finally:
        plt.close(fig)

def get_atom_residues(pdb_file):
    """List ``(residue name, residue number, chain id)`` for every atom in a PDB file.

    Atoms are visited model by model, chain by chain, residue by residue. A
    disordered atom contributes one entry per alternate conformer.
    """
    structure = PDBParser(QUIET=True).get_structure("mol", pdb_file)

    def _conformers(atom):
        # Disordered atoms expand to ALL of their conformers (A, B, etc.).
        if atom.is_disordered():
            return atom.disordered_get_list()
        return [atom]

    flat_atoms = [
        concrete
        for model in structure
        for chain in model
        for residue in chain
        for atom in residue
        for concrete in _conformers(atom)
    ]

    records = []
    for atom in flat_atoms:
        owner = atom.get_parent()
        records.append((
            owner.get_resname(),            # residue name (e.g., LEU)
            owner.get_id()[1],              # residue number
            owner.get_parent().get_id(),    # chain ID
        ))
    return records

# Three-letter codes of the 20 standard amino acids, in alphabetical order.
# order_residue_labels() lists these first (in this order) when ordering
# residue labels; any other label is placed after them.
RES_ORDER_CANON = [
    "ALA","ARG","ASN","ASP","CYS","GLN","GLU","GLY","HIS","ILE",
    "LEU","LYS","MET","PHE","PRO","SER","THR","TRP","TYR","VAL"
]

def counts_to_percent(counts: Counter, total: int) -> dict:
    """Convert per-key counts into percentages of ``total``.

    Every key of ``counts`` appears in the result, in the same order. When
    ``total`` is zero or negative, each key maps to 0.0.
    """
    percents = {}
    for key, n in counts.items():
        percents[key] = (n/total)*100.0 if total > 0 else 0.0
    return percents

def order_residue_labels(all_labels):
    """Order residue labels: canonical residues first, then the rest alphabetically.

    Canonical labels follow the order of RES_ORDER_CANON; labels not in that
    list are appended in sorted order.
    """
    canonical_present = []
    for res in RES_ORDER_CANON:
        if res in all_labels:
            canonical_present.append(res)
    extras = [res for res in all_labels if res not in RES_ORDER_CANON]
    extras.sort()
    return canonical_present + extras

def write_residue_freq_csv(path, labels, counts_row, pct_row, total_k, k_per_sample):
    """Write one CSV row per residue label with its Top-K count and percentage.

    ``counts_row`` and ``pct_row`` are dicts keyed by residue label; a label
    missing from either is written as 0 / 0.0. ``total_k`` and
    ``k_per_sample`` are repeated on every row.
    """
    import csv
    header = ["residue","count_in_topK","percent_of_topK","topk_atoms_total","K_per_sample"]
    with open(path, "w", newline="") as handle:
        writer = csv.writer(handle)
        writer.writerow(header)
        writer.writerows(
            [res, int(counts_row.get(res,0)), float(pct_row.get(res,0.0)), int(total_k), int(k_per_sample)]
            for res in labels
        )

In [7]:
# Root folder for all outputs of this cell (relative to the working directory).
OUT_ROOT_BASE = r"gnn_out_external"
os.makedirs(OUT_ROOT_BASE, exist_ok=True)

# Cross-experiment accumulators, filled once per experiment by the loop below.
# All of them are (re)created here: several are otherwise only created lazily
# behind `'name' not in globals()` checks, so re-running this cell in a live
# kernel would keep appending to the lists left over from the previous run.

# Per-experiment subtype % vectors (length W=37) for cross-experiment stats
cross_exp_pct_list = []
cross_exp_expnames = []

# Per-experiment raw Top-K subtype counts (length W) and their experiment names
cross_exp_freq_list = []
cross_exp_expnames_counts = []

# Per-experiment residue percentages / counts, the union of residue labels
# seen, and the experiment names
cross_exp_residue_pct_list = []
cross_exp_residue_counts_list = []
cross_exp_residue_labels_union = set()
cross_exp_residue_expnames = []

for exp_index in range(5):
    # Every experiment is evaluated on the same external positive set; only
    # the model folder changes with the experiment number.
    # Alternative (disabled): per-experiment internal test sets
    # if exp_index == 0:
    #     M_GLOB = r"/home/alexhernandez/transmembranebindingAI/Data/SplitData/Cholesterol/cholesterol-graph-5A/Test/Positive/*.npy"
    # else:
    #     M_GLOB = f"/home/alexhernandez/transmembranebindingAI/Data/SplitData/Cholesterol/cholesterol-graph-5A_exp{exp_index + 1}/Test/Positive/*.npy"
    M_GLOB = "/home/alexhernandez/transmembranebindingAI/Data/SplitData/Cholesterol/IvanTestSet/ivan-graph-5A/positive/*.npy"
    MODEL_GLOB = f"/home/alexhernandez/transmembranebindingAI/Models/Cholesterol/GNN/GNN-5A_Exp{exp_index + 1}/Models/*.pth"

    # ============================== CONFIG ===============================
    # Outputs of this experiment go to <OUT_ROOT_BASE>/exp<N>.
    OUT_ROOT = os.path.join(OUT_ROOT_BASE, f"exp{exp_index+1}")

    print(exp_index, "is exp index", M_GLOB, " is M_GLOB")

    # The 37 atom-type names, in the exact column order of the input matrices:
    # 36 named subtypes followed by a catch-all 'UNKNOWN' column.
    ATOM_TYPE_NAMES = [
        'C', 'CA', 'CB', 'CD', 'CD1', 'CD2', 'CE', 'CE1', 'CE2', 'CE3', 'CG', 'CG1', 'CG2', 'CH2', 'CZ', 'CZ2', 'CZ3',
        'O', 'OH', 'OD1', 'OD2', 'OE1', 'OE2', 'OG', 'OG1',
        'N', 'NE', 'NE1', 'NE2', 'ND1', 'ND2', 'NZ', 'NH1', 'NH2',
        'SD', 'SG'
    ] + ['UNKNOWN']

    # Grad-CAM target: "positive" shows evidence FOR the positive class.
    TARGET_CLASS = "positive"

    # CPU-only is the most stable setting. Set to False to try the GPU.
    FORCE_CPU = True

    if FORCE_CPU:
        os.environ["CUDA_VISIBLE_DEVICES"] = ""
    if not FORCE_CPU and torch.cuda.is_available():
        DEVICE = torch.device("cuda")
    else:
        DEVICE = torch.device("cpu")
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    os.makedirs(OUT_ROOT, exist_ok=True)
    # ============================ MODEL & GRAD-CAM =========================
    class CNN2D(nn.Module):
        """Three conv/pool stages followed by a two-layer classifier head.

        Each stage is Conv2d(3x3, padding 1) -> BatchNorm2d -> ReLU -> MaxPool2d(2).
        The feature map is adaptively average-pooled to ``out_grid``, flattened,
        and passed through Linear(128) -> ReLU -> Dropout(0.5) -> Linear(1).
        ``forward`` returns one raw logit per sample.
        """
        def __init__(self, in_ch=1, out_grid=(4,18)):
            super().__init__()

            def conv_bn_relu(c_in, c_out):
                return nn.Sequential(nn.Conv2d(c_in, c_out, 3, padding=1), nn.BatchNorm2d(c_out), nn.ReLU())

            # Attribute names and creation order are kept stable so the
            # state_dict keys stay the same.
            self.conv1 = conv_bn_relu(in_ch, 32)
            self.pool1 = nn.MaxPool2d(2, 2)
            self.conv2 = conv_bn_relu(32, 64)
            self.pool2 = nn.MaxPool2d(2, 2)
            self.conv3 = conv_bn_relu(64, 128)
            self.pool3 = nn.MaxPool2d(2, 2)
            self.adapt = nn.AdaptiveAvgPool2d(out_grid)
            self.flat = nn.Flatten()
            self.fc1 = nn.Linear(128*out_grid[0]*out_grid[1], 128)
            self.drop = nn.Dropout(0.5)
            self.fc2 = nn.Linear(128, 1)

        def forward(self, x):
            stages = (
                (self.conv1, self.pool1),
                (self.conv2, self.pool2),
                (self.conv3, self.pool3),
            )
            for conv, pool in stages:
                x = pool(conv(x))
            features = self.flat(self.adapt(x))
            hidden = self.drop(torch.relu(self.fc1(features)))
            return self.fc2(hidden)

    class GradCAM:
        """Grad-CAM for a single-logit model, using hooks on one target layer.

        The constructor switches ``model`` to eval mode and registers a forward
        hook (captures the layer output) and a full backward hook (captures the
        gradient w.r.t. that output). Call ``remove()`` to detach both hooks.
        """
        def __init__(self, model, target_layer):
            self.model, self.target_layer = model.eval(), target_layer
            self.activations, self.gradients = None, None
            self._fwd = target_layer.register_forward_hook(self._fwd_hook)
            self._bwd = target_layer.register_full_backward_hook(self._bwd_hook)

        def _fwd_hook(self, m, i, o):
            self.activations = o.detach()

        def _bwd_hook(self, m, gi, go):
            self.gradients = go[0].detach()

        @staticmethod
        @torch.no_grad()
        def _norm(cam):
            """Min-max scale each map in the batch to [0, 1] independently."""
            B = cam.size(0)
            flat = cam.view(B, -1)
            flat = flat - flat.min(dim=1, keepdim=True).values
            flat = flat / flat.max(dim=1, keepdim=True).values.clamp_min(1e-8)
            return flat.view_as(cam)

        def generate(self, x, target="positive"):
            """Compute normalized Grad-CAM maps for a batch.

            Args:
                x: Input tensor of shape [B,C,H,W].
                target: "positive" explains the logit itself; any other value
                    explains the negated logit.

            Returns:
                (cam, logits): cam is [B,1,H,W] scaled to [0,1] per sample;
                logits is the detached [B] model output.

            Raises:
                RuntimeError: if either hook did not fire during this call.
            """
            _, _, H, W = x.shape
            # Forget captures from earlier calls: if a hook does not fire now
            # (e.g. after remove()), the check below must fail rather than
            # silently reuse tensors from a previous sample.
            self.activations, self.gradients = None, None
            self.model.zero_grad(set_to_none=True)
            logits = self.model(x).squeeze(-1)           # [B]
            score = logits if target == "positive" else (-logits)
            score.sum().backward()
            A, dA = self.activations, self.gradients
            if A is None or dA is None:
                raise RuntimeError("Grad-CAM hooks failed")
            w = dA.mean(dim=(2,3), keepdim=True)         # [B,C,1,1] channel weights
            cam = (w*A).sum(dim=1, keepdim=True)         # [B,1,h,w]
            cam = F.relu(cam)
            if cam.shape[-2:] != (H, W):
                # Upsample the layer-resolution map back to the input size.
                cam = F.interpolate(cam, size=(H, W), mode='bilinear', align_corners=False)
            return self._norm(cam), logits.detach()

        def remove(self):
            """Detach both hooks from the target layer."""
            self._fwd.remove()
            self._bwd.remove()

    def last_conv2d(model):
        """Return the last ``nn.Conv2d`` yielded by ``model.modules()``, or None."""
        conv_layers = [layer for layer in model.modules() if isinstance(layer, nn.Conv2d)]
        return conv_layers[-1] if conv_layers else None

    def load_model(path, in_ch=1):
        """Build a CNN2D on DEVICE, load weights from ``path`` and return it in eval mode.

        The checkpoint may be a bare state dict or a dict with a
        ``"state_dict"`` entry. A leading ``"module."`` on a key is stripped
        before the strict load.
        """
        # weights_only=True restricts unpickling to tensors/containers.
        checkpoint = torch.load(path, map_location=DEVICE, weights_only=True)
        if isinstance(checkpoint, dict) and "state_dict" in checkpoint:
            checkpoint = checkpoint["state_dict"]

        cleaned = {}
        for key, tensor in checkpoint.items():
            if key.startswith("module."):
                key = key[7:]
            cleaned[key] = tensor

        net = CNN2D(in_ch=in_ch).to(DEVICE)
        net.load_state_dict(cleaned, strict=True)
        net.eval()
        return net

    # ========================= DISCOVER DATA/MODELS ========================
    npy_files = sorted(glob.glob(M_GLOB))
    # Validate before indexing: with no matches, npy_files[0] would raise a
    # bare IndexError and the explanatory message below would never be shown.
    assert npy_files, f"No .npy files match: {M_GLOB}"
    print(npy_files[0], "is first npy file")

    model_paths = sorted(glob.glob(MODEL_GLOB))

    # Peek at the first sample to learn the matrix size (H rows, W columns)
    # that the accumulators for this experiment are allocated with.
    M0, ENC0 = load_sample(npy_files[0], (exp_index + 1))     # M0 -> model input, ENC0 -> original one-hot
    assert M0.ndim == 2 and M0.shape[1] == 37, f"Expected [H,37], got {M0.shape}"
    assert ENC0.ndim == 2 and ENC0.shape[1] == 37, f"Expected ENC0 [N,37], got {ENC0.shape}"

    # Align the one-hot to the padded input. The result is not used further;
    # the call prints a warning early if the first sample does not line up.
    ENC0_aligned, nonpad0 = align_encoded_to_padded(M0, ENC0)

    H, W = M0.shape
    assert W == 37, f"Expected 37 type columns, got {W}"
    assert len(ATOM_TYPE_NAMES) == 37, f"ATOM_TYPE_NAMES must have 37 entries; got {len(ATOM_TYPE_NAMES)}"

    # ========================= PRELOAD MODELS & GRADCAM ====================
    def preload_models(paths, in_ch=1):
        """Load all models once and attach persistent Grad-CAM hooks.

        Args:
            paths: Checkpoint paths to load.
            in_ch: Number of input channels passed to ``load_model``.

        Returns:
            A list of dicts with keys ``"path"``, ``"model"`` and ``"gc"``
            (the GradCAM object), one per usable checkpoint.

        Raises:
            RuntimeError: if no checkpoint could be loaded. Individual load
                failures are printed and skipped.
        """
        bank = []
        for mp in paths:
            # Bind both names up front so the error path never touches an
            # unbound variable (the old `del m, gc_obj` relied on a bare except).
            m = gc_obj = None
            try:
                m = load_model(mp, in_ch=in_ch)     # already on DEVICE, eval(), strict load
                conv = last_conv2d(m)
                if conv is None:
                    print(f"[SKIP no conv2d] {os.path.basename(mp)}")
                    m = None                        # release the unusable model
                    continue
                gc_obj = GradCAM(m, conv)           # registers hooks
                bank.append({"path": mp, "model": m, "gc": gc_obj})

            except Exception as e:
                print(f"[SKIP preload {os.path.basename(mp)}] {e}")
                # Drop local references, then free memory held by the failed load.
                m = gc_obj = None
                gc.collect()
                if DEVICE.type == 'cuda':
                    torch.cuda.empty_cache()
        if not bank:
            raise RuntimeError("No usable models after preload.")
        print(f"[INFO] Preloaded {len(bank)} model(s).")
        return bank

    models_bank = preload_models(model_paths, in_ch=1)

    # ========================= AGGREGATE ACCUMULATORS ======================
    # Running totals over every successfully processed sample of this experiment.
    count_samples       = 0
    sum_cam             = np.zeros((H, W), dtype=np.float64)
    sum_atom_importance = np.zeros(H, dtype=np.float64)
    sum_type_importance = np.zeros(W, dtype=np.float64)

    # Tallies over the Top-K atoms of all samples: by atom subtype (column
    # index) and by residue name.
    topk_total_atoms     = 0
    topk_subtype_counts  = np.zeros(W, dtype=np.int64)
    topk_total_atoms_res = 0
    topk_residue_counts  = Counter()

    # Start the per-sample score CSV with its header row; one row is appended
    # per sample later.
    scores_path = os.path.join(OUT_ROOT, "per_sample_scores.csv")
    with open(scores_path, "w", newline="") as fcsv:
        csv.writer(fcsv).writerow(("sample_name", "ensemble_prob_mean", "num_models_used"))

    # ========================= PROCESS EACH POSITIVE =======================
    # For every positive sample: run all preloaded models, average their
    # Grad-CAM maps, write per-sample outputs, and update the experiment-level
    # accumulators. A failure in one sample is reported and the loop continues.
    for sidx, fpath in enumerate(npy_files):
        try:
            # ---- load model input + original one-hot ----
            M_in, ENC = load_sample(fpath, (exp_index + 1))
            ENC_aligned, nonpad_mask = align_encoded_to_padded(M_in, ENC)

            # ---- shape to model input [1,1,H,37] ----
            x_np = M_in[None, None, :, :].astype(np.float32)  # [1,1,H,37]
            x_t  = torch.from_numpy(x_np).to(DEVICE)

            run_sum = np.zeros((H,37), dtype=np.float64)
            logits_list = []
            used = 0

            # Top-K rows of THIS sample. Reset for every sample so the residue
            # tally further down can never see rows left over from a previous
            # sample (or an undefined name) when no valid rows exist.
            top_rows = np.empty(0, dtype=np.intp)

            # ---- look up residues from pdb ----
            basename = os.path.basename(fpath)
            protein_id = basename[:4]
            residues = get_atom_residues(f"/home/alexhernandez/transmembranebindingAI/Notebooks/Cholesterol/GNN/ivan-pdbs-distinct-5A/positive/{protein_id}-filtered.pdb")

            # --------- REUSE PRELOADED MODELS & HOOKS ---------
            for entry in models_bank:
                gc_obj = entry["gc"]
                # Bind before the try: if generate() raises, the cleanup below
                # must not fail on unbound names and abort the whole sample.
                cam_b1hw = logits = None
                try:
                    cam_b1hw, logits = gc_obj.generate(x_t, target=TARGET_CLASS)
                    cam = cam_b1hw[0,0].detach().cpu().numpy()  # [H,37]
                    run_sum += cam
                    logits_list.append(float(logits[0].item()))
                    used += 1
                except Exception as e:
                    print(f"[SKIP {os.path.basename(entry['path'])} on {os.path.basename(fpath)}] {e}")
                finally:
                    # Ensure no graph is held across iterations
                    cam_b1hw = logits = None
                    gc.collect()
                    if DEVICE.type == 'cuda':
                        torch.cuda.empty_cache()

            if used == 0:
                print(f"[WARN] {os.path.basename(fpath)} produced no CAMs; skipped.")
                continue

            # Average across models, then renormalize to [0,1] for visuals only
            cam_mean = (run_sum / used).astype(np.float32)     # [H,37]
            cam_mean_vis = normalize01(cam_mean)               # for images only

            # Ensemble probability = sigmoid of the mean logit.
            mean_logit = float(np.mean(logits_list))
            prob = float(1/(1+np.exp(-mean_logit)))

            # ===================== IMPORTANCE LOGIC =====================
            atom_importance = cam_mean.sum(axis=1).astype(np.float32)   # [H]
            type_importance = cam_mean.sum(axis=0).astype(np.float32)   # [37]

            # ===================== OUTPUTS =====================
            sname = os.path.splitext(os.path.basename(fpath))[0]
            sdir  = os.path.join(OUT_ROOT, f"sample_{sidx:05d}_{sname}")
            os.makedirs(sdir, exist_ok=True)

            np.save(os.path.join(sdir, "input_M.npy"), M_in)
            np.save(os.path.join(sdir, "encoded_matrix.npy"), ENC)
            np.save(os.path.join(sdir, "ensemble_cam.npy"), cam_mean)
            np.save(os.path.join(sdir, "atom_importance.npy"), atom_importance)
            np.save(os.path.join(sdir, "type_importance.npy"), type_importance)

            # atoms CSV (real, i.e. non-padding, rows only)
            atoms_csv = os.path.join(sdir, "atoms_importance.csv")
            with open(atoms_csv, "w", newline="") as f:
                w = csv.writer(f)
                w.writerow(["atom_index","is_real_atom","atom_type_from_onehot","importance"])
                for a in range(H):
                    is_real = int(nonpad_mask[a])
                    if not is_real:
                        continue
                    atom_type = atom_type_from_onehot(ENC_aligned[a], ATOM_TYPE_NAMES)
                    w.writerow([a, is_real, atom_type, float(atom_importance[a])])

            # per-type CSV
            types_csv = os.path.join(sdir, "type_importance.csv")
            with open(types_csv, "w", newline="") as f:
                w = csv.writer(f)
                w.writerow(["atom_type","importance_sum"])
                for j in range(37):
                    w.writerow([ATOM_TYPE_NAMES[j], float(type_importance[j])])

            # Visuals
            save_png_gray(M_in,                    os.path.join(sdir, "input_M.png"))
            save_png_gray(cam_mean_vis,            os.path.join(sdir, "ensemble_cam.png"))
            save_overlay(M_in, cam_mean_vis,       os.path.join(sdir, "ensemble_cam_overlay.png"))

            save_line_plot_pil(atom_importance,    os.path.join(sdir, "atom_importance_sum_plot.png"),
                            title="Atom importance (SUM over type-columns)",
                            xlabel="Atom index", ylabel="Importance (sum)")

            # Top-K atoms (K=5; the same 5 is written as K_per_sample in the
            # aggregate CSVs). Only real rows that carry a one-hot type count.
            valid_rows_mask = nonpad_mask & (ENC_aligned.sum(axis=1) > 0)
            valid_idx = np.flatnonzero(valid_rows_mask)
            if valid_idx.size > 0:
                order_local = np.argsort(atom_importance[valid_idx])[::-1]
                k_local = min(5, order_local.size)
                top_rows = valid_idx[order_local[:k_local]]
                top_vals   = atom_importance[top_rows]
                top_labels = [f"{idx} ({atom_type_from_onehot(ENC_aligned[idx], ATOM_TYPE_NAMES)})" for idx in top_rows]
                save_bar_plot_pil(top_vals, top_labels, os.path.join(sdir, "top_atoms_bar.png"),
                                title=f"Top {k_local} atoms by importance (SUM across types)")

                # ---- Global Top-K subtype frequency updates ----
                subtypes_top = np.argmax(ENC_aligned[top_rows], axis=1)
                np.add.at(topk_subtype_counts, subtypes_top, 1)
                topk_total_atoms += k_local

            # append sample score
            with open(scores_path, "a", newline="") as fcsv:
                csv.writer(fcsv).writerow([sname, prob, used])

            # aggregates
            sum_atom_importance += atom_importance
            sum_type_importance += type_importance
            sum_cam             += cam_mean
            count_samples       += 1

            # ---- Residue frequency updates for this experiment (Top-K atoms) ----
            # The residue list from the PDB is only usable when it has exactly
            # one entry per real (non-padding) row; otherwise warn and skip
            # residue counting for this sample.
            n_real = int(nonpad_mask.sum())
            if len(residues) != n_real:
                print("[WARN residues alignment]",
                      f"Residue list ({len(residues)}) != real atoms ({n_real}). Check alignment.")
            else:
                # Map padded row index -> compact real-atom index.
                # This assumes residues are listed in the same order as the
                # non-padding rows of the input matrix.
                real_row_indices = np.flatnonzero(nonpad_mask)
                padded_to_real = {int(p_idx): r_idx for r_idx, p_idx in enumerate(real_row_indices)}

                # Collect residue names for this sample's Top-K rows
                for pr in top_rows:
                    rr = padded_to_real.get(int(pr))
                    if rr is None:
                        # Shouldn't happen since top_rows is drawn from real rows
                        continue
                    resname = residues[rr][0]
                    topk_residue_counts[resname] += 1
                    topk_total_atoms_res += 1

        except Exception as e:
            print(f"[FAIL sample {fpath}] {e}")

    # ============================== CLEANUP =================================
    # All samples are done: detach the Grad-CAM hooks, drop the model
    # references, then ask Python (and CUDA, when used) to release memory.
    for entry in models_bank:
        try:
            entry["gc"].remove()
        except Exception:
            # Hook removal is best-effort; the references are dropped regardless.
            pass
        del entry["gc"], entry["model"]
    gc.collect()
    if DEVICE.type == 'cuda':
        torch.cuda.empty_cache()

    # ============================= AGGREGATES ==============================
    # Per-experiment summary: means over all processed samples, Top-K subtype
    # and residue frequency tables, and a few plots, all written under
    # <OUT_ROOT>/aggregate. Skipped entirely when no sample was processed.
    if count_samples > 0:
        agg_dir = os.path.join(OUT_ROOT, "aggregate")
        os.makedirs(agg_dir, exist_ok=True)

        # Sample means of the running sums.
        mean_atom_importance = (sum_atom_importance / count_samples).astype(np.float32)  # [H]
        mean_type_importance = (sum_type_importance / count_samples).astype(np.float32)  # [37]
        mean_cam             = (sum_cam / count_samples).astype(np.float32)             # [H,37]

        # arrays
        np.save(os.path.join(agg_dir, "mean_atom_importance.npy"), mean_atom_importance)
        np.save(os.path.join(agg_dir, "mean_type_importance.npy"), mean_type_importance)
        np.save(os.path.join(agg_dir, "mean_cam.npy"),             mean_cam)

        # CSVs
        with open(os.path.join(agg_dir, "mean_atom_importance.csv"), "w", newline="") as f:
            w = csv.writer(f); w.writerow(["atom_index","mean_importance"])
            for a in range(H): w.writerow([a, float(mean_atom_importance[a])])

        with open(os.path.join(agg_dir, "mean_type_importance.csv"), "w", newline="") as f:
            w = csv.writer(f); w.writerow(["atom_type","mean_importance"])
            for j in range(37): w.writerow([ATOM_TYPE_NAMES[j], float(mean_type_importance[j])])

        # Top-K subtype frequency CSV/plot
        if topk_total_atoms > 0:
            freq = topk_subtype_counts.astype(np.int64)
            # Share of all counted Top-K atoms that fall in each subtype column.
            pct  = (freq / topk_total_atoms) * 100.0
            with open(os.path.join(agg_dir, "subtype_topk_frequency.csv"), "w", newline="") as f:
                w = csv.writer(f)
                w.writerow(["atom_type", "count_in_topK", "percent_of_topK", "topk_atoms_total", "K_per_sample"])
                for j in range(W):
                    # The trailing 5 is the per-sample K, written as a literal.
                    w.writerow([ATOM_TYPE_NAMES[j], int(freq[j]), float(pct[j]), int(topk_total_atoms), int(5)])

            # Plot only subtypes that occurred, most frequent first.
            order = np.argsort(freq)[::-1]
            order = [j for j in order if freq[j] > 0]
            if order:
                save_bar_plot_pil(
                    freq[order],
                    [ATOM_TYPE_NAMES[j] for j in order],
                    os.path.join(agg_dir, "top_subtypes_topk_frequency_bar.png"),
                    title=f"Subtype frequency among Top-K atoms (total counted: {topk_total_atoms})"
                )
            # ---- Save this experiment's percent vector for cross-experiment stats ----
            cross_exp_pct_list.append(pct.copy())              # pct is length W
            cross_exp_expnames.append(f"exp{exp_index+1}")

        # Keep this experiment's raw counts (length W) for cross-experiment counts table
        # NOTE(review): the two lists below are created lazily via a globals()
        # check, so values from a previous execution of this cell persist
        # unless the names are reset before the loop starts.
        if topk_total_atoms > 0:
            if 'cross_exp_freq_list' not in globals():
                cross_exp_freq_list = []
                cross_exp_expnames_counts = []
            cross_exp_freq_list.append(freq.copy())                          # raw counts per subtype (length W)
            cross_exp_expnames_counts.append(f"exp{exp_index+1}")
        
        # -------- Per-experiment RESIDUE frequency outputs --------
        if topk_total_atoms_res > 0:
            # Normalize to percent
            res_pct = counts_to_percent(topk_residue_counts, topk_total_atoms_res)

            # Label order for this experiment: canonical residues first, then others
            exp_res_labels = order_residue_labels(list(topk_residue_counts.keys()))

            # Save per-experiment residues CSV
            residues_csv = os.path.join(agg_dir, "residues_topk_frequency.csv")
            write_residue_freq_csv(
                residues_csv,
                exp_res_labels,
                topk_residue_counts,
                res_pct,
                topk_total_atoms_res,
                5  # K per sample, written as a literal (adjust if K changes)
            )

            # Save per-experiment bar plot of raw counts (not percentages).
            # save_bar_plot_pil converts its values with np.asarray, so the
            # array wrapper here is a convenience only.
            from numpy import array as nparr
            save_bar_plot_pil(
                nparr([topk_residue_counts[r] for r in exp_res_labels]),
                exp_res_labels,
                os.path.join(agg_dir, "residues_topk_frequency_bar.png"),
                title=f"Residue frequency among Top-K atoms (total counted: {topk_total_atoms_res})"
            )

            # ---- Collect for cross-experiment mean±SD ----
            # Stored as dicts; a dense vector in a global label order is built later.
            # NOTE(review): same lazy globals() creation as above; stale values
            # survive a re-run of this cell unless reset beforehand.
            if 'cross_exp_residue_pct_list' not in globals():
                cross_exp_residue_pct_list = []
                cross_exp_residue_counts_list = []
                cross_exp_residue_labels_union = set()
                cross_exp_residue_expnames = []

            cross_exp_residue_pct_list.append(res_pct)           # dict: residue -> %
            cross_exp_residue_counts_list.append(topk_residue_counts.copy())
            cross_exp_residue_labels_union.update(res_pct.keys())
            cross_exp_residue_expnames.append(f"exp{exp_index+1}")



        # visuals
        save_line_plot_pil(mean_atom_importance, os.path.join(agg_dir, "mean_atom_importance_plot.png"),
                        title="Mean atom importance (positives)", xlabel="Atom index", ylabel="Mean importance")
        # Atom types sorted by mean importance, highest first.
        order = np.argsort(mean_type_importance)[::-1]
        save_bar_plot_pil(mean_type_importance[order], [ATOM_TYPE_NAMES[i] for i in order],
                        os.path.join(agg_dir, "mean_type_importance_bar.png"),
                        title="Mean atom-type importance (positives)")
        save_png_gray(mean_cam, os.path.join(agg_dir, "mean_cam.png"))

    print(f"[DONE] Processed {count_samples} positive sample(s). Outputs -> {OUT_ROOT}")

# =================== CROSS-EXPERIMENT: COUNTS TABLE (Top-K) ===================
# Writes one CSV row per experiment holding the raw Top-K counts of every atom
# subtype. The header is "__experiment__" followed by the names in
# ATOM_TYPE_NAMES, in that order.
if 'cross_exp_freq_list' in globals() and cross_exp_freq_list:
    exp_freqs = np.stack(cross_exp_freq_list, axis=0)  # [E, W]
    E, W_local = exp_freqs.shape
    assert W_local == len(ATOM_TYPE_NAMES), "Mismatch in subtype dimension."

    # Atom type name -> column index in the count vectors
    type_to_idx = {}
    for position, type_name in enumerate(ATOM_TYPE_NAMES):
        type_to_idx[type_name] = position

    cross_dir = os.path.join(OUT_ROOT_BASE, "cross_experiment")
    os.makedirs(cross_dir, exist_ok=True)

    out_counts_csv = os.path.join(cross_dir, "top_subtypes_topk_frequency_counts_by_experiment.csv")
    with open(out_counts_csv, "w", newline="") as f:
        w = csv.writer(f)
        w.writerow(["__experiment__"] + ATOM_TYPE_NAMES)

        for exp_row, exp_label in enumerate(cross_exp_expnames_counts):
            # Every name is present in type_to_idx, so no default is needed.
            counts = [int(exp_freqs[exp_row, type_to_idx[t]]) for t in ATOM_TYPE_NAMES]
            w.writerow([exp_label] + counts)

# ======================= CROSS-EXPERIMENT RESIDUE AGG =======================
# Combines the per-experiment residue percentages and counts (among Top-K atoms)
# into cross-experiment CSV tables plus a mean ± SD bar plot.
if 'cross_exp_residue_pct_list' in globals() and cross_exp_residue_pct_list:
    cross_dir = os.path.join(OUT_ROOT_BASE, "cross_experiment")
    os.makedirs(cross_dir, exist_ok=True)

    # One global residue order shared by every table and plot below.
    # order_residue_labels is defined elsewhere in this notebook; per the original
    # note it places canonical residues first, then the others.
    all_residues = order_residue_labels(list(cross_exp_residue_labels_union))

    # Dense [E, R] matrix of percents in that order. A residue missing from an
    # experiment's dict contributes 0.0.
    # (np and csv come from the imports at the top of the file; the local
    # re-import that used to sit here was redundant and has been removed.)
    E = len(cross_exp_residue_pct_list)
    R = len(all_residues)
    pct_mat = np.zeros((E, R), dtype=float)

    for e, pct_dict in enumerate(cross_exp_residue_pct_list):
        for j, res in enumerate(all_residues):
            pct_mat[e, j] = float(pct_dict.get(res, 0.0))

    mean_pct = pct_mat.mean(axis=0)
    # Sample standard deviation (ddof=1) is undefined for a single experiment,
    # so fall back to zeros in that case.
    std_pct  = pct_mat.std(axis=0, ddof=1) if E > 1 else np.zeros_like(mean_pct)

    # CSV: mean/std plus one percent column per experiment.
    csv_path = os.path.join(cross_dir, "top_residues_topk_frequency_mean_std.csv")
    with open(csv_path, "w", newline="") as f:
        w = csv.writer(f)
        header = ["residue", "mean_percent", "std_percent"] + [f"{name}_percent" for name in cross_exp_residue_expnames]
        w.writerow(header)
        for j, res in enumerate(all_residues):
            row = [res, float(mean_pct[j]), float(std_pct[j])] + [float(pct_mat[e, j]) for e in range(E)]
            w.writerow(row)

    # Plot mean ± SD, highest mean first, dropping residues with no signal at all.
    # save_bar_with_errorbars is a helper defined elsewhere in this notebook.
    order = np.argsort(mean_pct)[::-1]
    order = [idx for idx in order if (mean_pct[idx] > 0) or (std_pct[idx] > 0)]
    if order:
        png_path = os.path.join(cross_dir, "top_residues_topk_frequency_mean_std.png")
        save_bar_with_errorbars(
            values=mean_pct[order],
            errors=std_pct[order],
            labels=[all_residues[j] for j in order],
            out_path=png_path,
            title=f"Residue frequency among Top-K atoms (mean ± SD across {E} experiments)",
            ylabel="Percent of Top-K atoms"
        )

    # Cross-experiment COUNTS table aligned to the same residue order.
    # Counts matrix [E, R]; a residue missing from an experiment counts as 0.
    cnt_mat = np.zeros((E, R), dtype=int)
    for e, cnts in enumerate(cross_exp_residue_counts_list):
        for j, res in enumerate(all_residues):
            cnt_mat[e, j] = int(cnts.get(res, 0))

    out_counts_csv = os.path.join(cross_dir, "top_residues_topk_frequency_counts_by_experiment.csv")
    with open(out_counts_csv, "w", newline="") as f:
        w = csv.writer(f)
        w.writerow(["__experiment__"] + all_residues)
        for e, name in enumerate(cross_exp_residue_expnames):
            w.writerow([name] + [int(cnt_mat[e, j]) for j in range(R)])

# ======================= CROSS-EXPERIMENT AGGREGATION =======================
# Summarises the per-experiment atom-subtype percentages (mean ± SD across
# experiments) into a CSV table and an error-bar plot.
# The `in globals()` check matches the guard used by the other cross-experiment
# sections of this cell: if the accumulator was never created, this step is
# skipped instead of raising NameError.
if 'cross_exp_pct_list' in globals() and cross_exp_pct_list:
    exp_pcts = np.stack(cross_exp_pct_list, axis=0)  # [E, W]
    E, W = exp_pcts.shape

    mean_pct = exp_pcts.mean(axis=0)                       # [W]
    # Sample standard deviation (ddof=1) needs at least two experiments;
    # with a single experiment report zeros.
    std_pct  = exp_pcts.std(axis=0, ddof=1) if E > 1 else np.zeros_like(mean_pct)

    # Order by mean descending, keep only those with any signal
    order = np.argsort(mean_pct)[::-1]
    order = [j for j in order if mean_pct[j] > 0 or std_pct[j] > 0]

    cross_dir = os.path.join(OUT_ROOT_BASE, "cross_experiment")
    os.makedirs(cross_dir, exist_ok=True)

    # CSV: mean, std, and per-experiment columns (all W subtypes, unfiltered and
    # in ATOM_TYPE_NAMES order — only the plot below uses the filtered ordering).
    csv_path = os.path.join(cross_dir, "top_subtypes_topk_frequency_mean_std.csv")
    with open(csv_path, "w", newline="") as f:
        w = csv.writer(f)
        header = ["atom_type", "mean_percent", "std_percent"] + [f"{name}_percent" for name in cross_exp_expnames]
        w.writerow(header)
        for j in range(W):
            row = [ATOM_TYPE_NAMES[j], float(mean_pct[j]), float(std_pct[j])] + [float(exp_pcts[e, j]) for e in range(E)]
            w.writerow(row)

    # Plot (mean ± std) as error bars; skipped when no subtype has any signal.
    # save_bar_with_errorbars is a helper defined elsewhere in this notebook.
    if order:
        png_path = os.path.join(cross_dir, "top_subtypes_topk_frequency_mean_std.png")
        save_bar_with_errorbars(
            values=mean_pct[order],
            errors=std_pct[order],
            labels=[ATOM_TYPE_NAMES[j] for j in order],
            out_path=png_path,
            title=f"Subtype frequency among Top-K atoms (mean ± SD across {E} experiments)",
            ylabel="Percent of Top-K atoms"
        )

0 is exp index /home/alexhernandez/transmembranebindingAI/Data/SplitData/Cholesterol/IvanTestSet/ivan-graph-5A/positive/*.npy  is M_GLOB
/home/alexhernandez/transmembranebindingAI/Data/SplitData/Cholesterol/IvanTestSet/ivan-graph-5A/positive/4HQJ-filtered_combined_matrix.npy is first npy file
[INFO] Preloaded 50 model(s).
[DONE] Processed 57 positive sample(s). Outputs -> gnn_out_external/exp1
1 is exp index /home/alexhernandez/transmembranebindingAI/Data/SplitData/Cholesterol/IvanTestSet/ivan-graph-5A/positive/*.npy  is M_GLOB
/home/alexhernandez/transmembranebindingAI/Data/SplitData/Cholesterol/IvanTestSet/ivan-graph-5A/positive/4HQJ-filtered_combined_matrix.npy is first npy file
[INFO] Preloaded 50 model(s).
[DONE] Processed 57 positive sample(s). Outputs -> gnn_out_external/exp2
2 is exp index /home/alexhernandez/transmembranebindingAI/Data/SplitData/Cholesterol/IvanTestSet/ivan-graph-5A/positive/*.npy  is M_GLOB
/home/alexhernandez/transmembranebindingAI/Data/SplitData/Cholesterol

In [8]:
# def min_max_normalization(matrix):
#     """
#     Perform Min-Max normalization on a given matrix.

#     Parameters:
#     matrix (np.ndarray): The input matrix to be normalized.

#     Returns:
#     np.ndarray: The normalized matrix with values scaled to the range [0, 1].
#     """
#     # Compute the minimum and maximum values for the matrix
#     min_val = np.min(matrix)
#     max_val = np.max(matrix)

#     # Apply Min-Max normalization formula
#     normalized_matrix = (matrix - min_val) / (max_val - min_val)

#     return normalized_matrix

# data = np.load("/home/alexhernandez/transmembranebindingAI/Data/SplitData/Cholesterol/cholesterol-separate-graphs-5A_exp1/Positive/1LRI-filtered_graphs.npy", allow_pickle=True).item()
# inverse_distance = data['inverse_distance']
# encoded_matrix = data['encoded_matrix']

# print(encoded_matrix[:10])
# print(inverse_distance[:10])

# max_atoms = 150

# combined_matrix = inverse_distance @ encoded_matrix # for gnn
# combined_matrix = min_max_normalization(combined_matrix)

# num_atoms = inverse_distance.shape[0]

# combined_matrix = np.pad(combined_matrix, ((0, max_atoms - num_atoms), (0, 0)), mode='constant') # padding for gnn
# print(combined_matrix[:10])