In [15]:
# ============================== CONFIG ===============================
# Input sample(s), model checkpoints (both passed to glob) and the output folder.
M_GLOB       = r"/home/alexhernandez/transmembranebindingAI/Data/SplitData/Cholesterol/cholesterol-graph-5A/Test/Positive/1LRI-filtered_combined_matrix.npy"
MODEL_GLOB   = r"/home/alexhernandez/transmembranebindingAI/Models/Cholesterol/GNN/GNN-5A_Exp1/Models/*.pth"
OUT_ROOT     = r"gnn_out_1LRI"

# OPTIONAL: coordinates per sample (e.g., [N, H, 3] or list of .npy matching order).
# Leave as None when no coordinates are available.
COORDS_PATH  = None   # e.g., r"C:/.../coords.npy"

# Atom-type names, in the exact column order of the input matrices.
# Grouped by element for readability; the trailing 'UNKNOWN' brings the total to 37.
ATOM_TYPE_NAMES = (
    "C CA CB CD CD1 CD2 CE CE1 CE2 CE3 CG CG1 CG2 CH2 CZ CZ2 CZ3".split()   # carbons
    + "O OH OD1 OD2 OE1 OE2 OG OG1".split()                                   # oxygens
    + "N NE NE1 NE2 ND1 ND2 NZ NH1 NH2".split()                               # nitrogens
    + "SD SG".split()                                                         # sulfurs
    + ["UNKNOWN"]
)

# Grad-CAM target: "positive" shows evidence FOR the positive class.
TARGET_CLASS = "positive"

# Safety: CPU-only is most stable. Set to False to try GPU with aggressive cleanup.
FORCE_CPU = True

# ============================ IMPORTS/SETUP ============================
import os, glob, gc, csv
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image, ImageDraw, ImageFont
import matplotlib.cm as cm

# Pick the compute device. In CPU-only mode every GPU is hidden from CUDA
# and CUDA availability is never queried.
if FORCE_CPU:
    os.environ["CUDA_VISIBLE_DEVICES"] = ""
    DEVICE = torch.device("cpu")
elif torch.cuda.is_available():
    DEVICE = torch.device("cuda")
else:
    DEVICE = torch.device("cpu")

# Ask cuDNN for deterministic kernels and disable its autotuner.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# All outputs are written below OUT_ROOT.
os.makedirs(OUT_ROOT, exist_ok=True)

# ================================ HELPERS ==============================
def normalize01(a: np.ndarray) -> np.ndarray:
    """Min-max scale an array to the range [0, 1] as float32.

    Args:
        a: Numeric array of any shape.

    Returns:
        A float32 array with the same shape as ``a``. If the value range is
        smaller than 1e-8 (constant input) or the array is empty, an
        all-zero array of that shape is returned instead of dividing by
        (almost) zero.
    """
    a = np.asarray(a).astype(np.float32, copy=False)
    # min()/max() raise on zero-size arrays, so handle the empty case first.
    if a.size == 0:
        return np.zeros_like(a, dtype=np.float32)
    mn, mx = float(a.min()), float(a.max())
    span = mx - mn
    if span < 1e-8:
        return np.zeros_like(a, dtype=np.float32)
    return (a - mn) / span

def save_png_gray(arr: np.ndarray, path: str):
    """Write ``arr`` to ``path`` as an 8-bit grayscale image.

    The array is min-max scaled with ``normalize01`` and mapped to 0..255.
    """
    scaled = normalize01(arr) * 255
    pixels = scaled.astype(np.uint8)
    Image.fromarray(pixels).save(path)

def save_overlay(inp: np.ndarray, cam: np.ndarray, path: str):
    """Save a blend of the grayscale input and a jet-colored heatmap.

    Both arrays are min-max scaled independently; the result is
    60% input (as gray RGB) plus 40% heatmap, written to ``path``.
    """
    base = normalize01(inp)
    heat = normalize01(cam)

    # Replicate the gray channel three times to get an RGB image.
    base_rgb = (np.stack([base, base, base], axis=-1) * 255).astype(np.uint8)
    # cm.jet returns RGBA; drop the alpha channel.
    heat_rgb = (cm.jet(heat)[..., :3] * 255).astype(np.uint8)

    blended = 0.6 * base_rgb + 0.4 * heat_rgb
    Image.fromarray(blended.clip(0, 255).astype(np.uint8)).save(path)

def save_line_plot_pil(values, out_path, title="", xlabel="", ylabel="", size=(1000, 360), margin=60):
    """Render a simple line plot with PIL and save it as a PNG.

    Args:
        values: Sequence of numbers to plot; it is flattened to 1-D.
        out_path: Destination file path (always written in PNG format).
        title, xlabel, ylabel: Optional text labels; empty strings are skipped.
        size: ``(width, height)`` of the image in pixels.
        margin: Gap in pixels between the image border and the axes.

    Values are min-max scaled to the vertical extent of the plot area and
    spread evenly along the x axis. Text labels are best-effort: if they
    cannot be drawn, a warning is printed and the plot is still saved.
    """
    values = np.asarray(values, dtype=np.float32).ravel()
    W, H = size
    img = Image.new("RGB", (W, H), "white")
    draw = ImageDraw.Draw(img)
    left, top = margin, margin
    right, bottom = W - margin, H - margin

    # axes
    draw.line([(left, bottom), (right, bottom)], fill="black", width=2)
    draw.line([(left, top), (left, bottom)], fill="black", width=2)

    line_color = (30, 144, 255)
    if values.size > 0:
        vmin, vmax = float(values.min()), float(values.max())
        norm = (values - vmin) / max(vmax - vmin, 1e-8)
        xs = np.linspace(left, right, len(values))
        ys = bottom - norm * (bottom - top)
        # Hand PIL plain Python floats rather than numpy scalars.
        points = list(zip(xs.tolist(), ys.tolist()))
        if len(points) == 1:
            # A single value has no segment to draw; mark it with a dot instead.
            x, y = points[0]
            draw.ellipse([x - 2, y - 2, x + 2, y + 2], fill=line_color)
        for start, end in zip(points, points[1:]):
            draw.line([start, end], fill=line_color, width=2)

    try:
        font = ImageFont.load_default()
        if title:
            draw.text((left, top - 25), title, fill="black", font=font)
        if xlabel:
            draw.text((left + (right - left) // 2 - 40, bottom + 8), xlabel, fill="black", font=font)
        if ylabel:
            draw.text((8, top - 10), ylabel, fill="black", font=font)
    except Exception as exc:
        # Labels are optional decoration: report the problem but keep the plot.
        print(f"[WARN] could not draw labels for {out_path}: {exc}")

    img.save(out_path, format="PNG")

def save_bar_plot_pil(values, labels, out_path, title="", size=(1200, 500), margin=80, color=(30,144,255)):
    """Render a horizontal bar chart with PIL and save it as a PNG.

    Args:
        values: Sequence of bar lengths; the largest value spans the full
            plot width. Negative values are drawn as zero-length bars.
        labels: One label per value; labels longer than 22 characters are
            truncated and suffixed with "...".
        out_path: Destination file path (always written in PNG format).
        title: Optional title drawn above the plot area.
        size: ``(width, height)`` of the image in pixels.
        margin: Gap in pixels between the image border and the plot area.
        color: RGB fill color of the bars.

    Text is best-effort: if any label cannot be drawn, one warning is
    printed at the end and the chart is still saved.
    """
    values = np.asarray(values, dtype=np.float32)
    n = len(values)
    W, H = size
    img = Image.new("RGB", (W, H), "white")
    draw = ImageDraw.Draw(img)
    left, top = margin, margin
    right, bottom = W - margin, H - margin
    plot_w, plot_h = right - left, bottom - top

    text_error = None  # first text-drawing failure, reported once at the end
    try:
        font = ImageFont.load_default()
    except Exception as exc:
        font, text_error = None, exc

    def draw_text(xy, text):
        """Draw one label, remembering (not raising) the first failure."""
        nonlocal text_error
        try:
            draw.text(xy, text, fill="black", font=font)
        except Exception as exc:
            if text_error is None:
                text_error = exc

    if title:
        draw_text((left, top - 25), title)

    if n > 0:
        vmax = float(values.max())
        scale = plot_w / max(vmax, 1e-8)
        bar_h = plot_h / n
        # Keep the 4px gap between bars, but shrink it when bars are so thin
        # that the gap would invert the rectangle (y1 < y0).
        pad = 4 if bar_h > 8 else bar_h / 4
        for i, v in enumerate(values):
            v = float(v)
            y0 = top + i * bar_h + pad
            y1 = top + (i + 1) * bar_h - pad
            # Clamp so a negative value cannot produce x1 < left.
            x1 = left + max(v, 0.0) * scale
            draw.rectangle([left, y0, x1, y1], fill=color)

            full_label = str(labels[i])
            # ASCII ellipsis: the previous literal was a mis-encoded character.
            lbl = full_label[:22] + ("..." if len(full_label) > 22 else "")
            text_y = (y0 + y1) / 2 - 6
            draw_text((5, text_y), lbl)
            draw_text((x1 + 5, text_y), f"{v:.3f}")

    if text_error is not None:
        print(f"[WARN] some labels could not be drawn for {out_path}: {text_error}")
    img.save(out_path, format="PNG")

# ============================ MODEL & GRAD-CAM =========================
class CNN2D(nn.Module):
    """Three conv/pool stages followed by a two-layer classifier head.

    Each stage is Conv2d(3x3, padding=1) -> BatchNorm2d -> ReLU, followed by
    a 2x2 max-pool. The feature map is then adaptively average-pooled to
    ``out_grid``, flattened, and mapped to a single output value per sample.

    Args:
        in_ch: Number of input channels.
        out_grid: ``(rows, cols)`` of the adaptive pooling output; it fixes
            the input size of the first fully connected layer.
    """

    @staticmethod
    def _conv_stage(c_in, c_out):
        """Build one Conv2d -> BatchNorm2d -> ReLU stage."""
        return nn.Sequential(
            nn.Conv2d(c_in, c_out, 3, padding=1),
            nn.BatchNorm2d(c_out),
            nn.ReLU(),
        )

    def __init__(self, in_ch=1, out_grid=(4,18)):
        super().__init__()
        # Attribute names and creation order are kept stable so that saved
        # state dicts keep loading with strict=True.
        self.conv1 = self._conv_stage(in_ch, 32)
        self.pool1 = nn.MaxPool2d(2, 2)
        self.conv2 = self._conv_stage(32, 64)
        self.pool2 = nn.MaxPool2d(2, 2)
        self.conv3 = self._conv_stage(64, 128)
        self.pool3 = nn.MaxPool2d(2, 2)
        self.adapt = nn.AdaptiveAvgPool2d(out_grid)
        self.flat = nn.Flatten()
        flat_features = 128 * out_grid[0] * out_grid[1]
        self.fc1 = nn.Linear(flat_features, 128)
        self.drop = nn.Dropout(0.5)
        self.fc2 = nn.Linear(128, 1)

    def forward(self, x):
        """Return one value per sample (shape [B, 1]) for input [B, in_ch, H, W]."""
        stages = (
            (self.conv1, self.pool1),
            (self.conv2, self.pool2),
            (self.conv3, self.pool3),
        )
        for conv, pool in stages:
            x = pool(conv(x))
        x = self.flat(self.adapt(x))
        hidden = self.drop(torch.relu(self.fc1(x)))
        return self.fc2(hidden)

class GradCAM:
    """Grad-CAM helper that records one layer's activations and gradients.

    A forward hook and a full backward hook are registered on
    ``target_layer`` at construction time. Call ``remove()`` when done, or
    use the object as a context manager so the hooks are always detached::

        with GradCAM(model, layer) as cam:
            heatmap, logits = cam.generate(x)

    Args:
        model: Network to explain; it is switched to eval mode.
        target_layer: Module inside ``model`` whose output feature map is
            used to build the heatmap.
    """

    def __init__(self, model, target_layer):
        self.model, self.target_layer = model.eval(), target_layer
        self.activations, self.gradients = None, None
        self._fwd = target_layer.register_forward_hook(self._fwd_hook)
        self._bwd = target_layer.register_full_backward_hook(self._bwd_hook)

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        self.remove()
        return False  # never suppress exceptions

    def _fwd_hook(self, m, i, o):
        """Store the target layer's output."""
        self.activations = o.detach()

    def _bwd_hook(self, m, gi, go):
        """Store the gradient w.r.t. the target layer's output."""
        self.gradients = go[0].detach()

    @staticmethod
    @torch.no_grad()
    def _norm(cam):
        """Min-max scale each sample of ``cam`` ([B, ...]) to [0, 1]."""
        B = cam.size(0)
        # reshape (not view) so non-contiguous inputs are accepted as well
        flat = cam.reshape(B, -1)
        flat = flat - flat.min(dim=1, keepdim=True).values
        flat = flat / flat.max(dim=1, keepdim=True).values.clamp_min(1e-8)
        return flat.reshape(cam.shape)

    def generate(self, x, target="positive"):
        """Compute a normalized Grad-CAM heatmap for a batch.

        Args:
            x: Input tensor of shape [B, C, H, W].
            target: ``"positive"`` back-propagates the logit itself; any
                other value back-propagates the negated logit.

        Returns:
            ``(cam, logits)`` where ``cam`` has shape [B, 1, H, W] with each
            sample scaled to [0, 1], and ``logits`` is the detached model
            output with its last dimension squeezed.

        Raises:
            RuntimeError: If the hooks did not capture activations and
                gradients during this call.
        """
        _, _, H, W = x.shape
        # Drop state from earlier calls so a hook failure cannot be masked
        # by stale tensors.
        self.activations, self.gradients = None, None

        # Gradients are required even if the caller runs under no_grad().
        with torch.enable_grad():
            self.model.zero_grad(set_to_none=True)
            logits = self.model(x).squeeze(-1)           # [B]
            score = logits if target == "positive" else (-logits)
            score.sum().backward()

        A, dA = self.activations, self.gradients
        if A is None or dA is None:
            raise RuntimeError("Grad-CAM hooks failed")

        w = dA.mean(dim=(2, 3), keepdim=True)            # [B,C,1,1] channel weights
        cam = (w * A).sum(dim=1, keepdim=True)           # [B,1,h,w]
        cam = F.relu(cam)                                # keep positive evidence only
        if cam.shape[-2:] != (H, W):
            cam = F.interpolate(cam, size=(H, W), mode='bilinear', align_corners=False)
        return self._norm(cam), logits.detach()

    def remove(self):
        """Detach both hooks from the target layer (safe to call repeatedly)."""
        self._fwd.remove()
        self._bwd.remove()

def last_conv2d(model):
    """Return the last ``nn.Conv2d`` found by ``model.modules()``, or None."""
    conv_layers = [layer for layer in model.modules() if isinstance(layer, nn.Conv2d)]
    return conv_layers[-1] if conv_layers else None

def load_model(path, in_ch=1):
    """Load a ``CNN2D`` checkpoint from ``path`` and return it in eval mode.

    The file may hold either a bare state dict or a dict that wraps one
    under the ``"state_dict"`` key. A leading ``"module."`` on parameter
    names is stripped before loading. Loading is strict, so any
    missing or unexpected key raises.
    """
    # weights_only=True restricts unpickling to tensors and plain containers.
    checkpoint = torch.load(path, map_location=DEVICE, weights_only=True)
    if isinstance(checkpoint, dict) and "state_dict" in checkpoint:
        checkpoint = checkpoint["state_dict"]

    prefix = "module."
    weights = {}
    for key, tensor in checkpoint.items():
        if key.startswith(prefix):
            key = key[len(prefix):]
        weights[key] = tensor

    net = CNN2D(in_ch=in_ch).to(DEVICE)
    net.load_state_dict(weights, strict=True)
    net.eval()
    return net

# ========================= DISCOVER DATA/MODELS ========================
# Explicit raises are used instead of assert so the checks survive `python -O`.
npy_files = sorted(glob.glob(M_GLOB))
if not npy_files:
    raise FileNotFoundError(f"No .npy files match: {M_GLOB}")
model_paths = sorted(glob.glob(MODEL_GLOB))
if not model_paths:
    raise FileNotFoundError(f"No models match: {MODEL_GLOB}")

# Peek first sample to get H,W and validate W=37
first = np.load(npy_files[0])
if first.ndim not in (2, 3):   # accepted: [H, W], [1, H, W] or [C, H, W]
    raise ValueError(f"Unsupported sample shape: {first.shape}")
H, W = first.shape[-2], first.shape[-1]
if W != 37:
    raise ValueError(f"Expected 37 type columns, got {W}")
if len(ATOM_TYPE_NAMES) != 37:
    raise ValueError(f"ATOM_TYPE_NAMES must have 37 entries; got {len(ATOM_TYPE_NAMES)}")

# Optional coordinates load
coords = None
if COORDS_PATH:
    if os.path.exists(COORDS_PATH):
        coords = np.load(COORDS_PATH)  # expect [N, H, 2 or 3]
        # The per-atom CSV reads coords.shape[2], so the array must be 3-D.
        if coords.ndim != 3:
            raise ValueError(f"coords must be 3-D [N, H, dims]; got shape {coords.shape}")
        if coords.shape[0] != len(npy_files):
            raise ValueError("coords N must match number of .npy files")
        if coords.shape[1] != H:
            raise ValueError("coords H must match input row count")
    else:
        # A configured but missing file is reported rather than silently ignored.
        print(f"[WARN] COORDS_PATH not found, continuing without coordinates: {COORDS_PATH}")

# ========================= AGGREGATE ACCUMULATORS ======================
# Running sums over all successfully processed samples (float64 accumulators).
sum_atom_importance = np.zeros(H, dtype=np.float64)
sum_type_importance = np.zeros(W, dtype=np.float64)
sum_cam = np.zeros((H, W), dtype=np.float64)
count_samples = 0

# Start the per-sample score CSV with just its header row; one row per
# sample is appended to this file later.
scores_path = os.path.join(OUT_ROOT, "per_sample_scores.csv")
with open(scores_path, "w", newline="") as fcsv:
    header_writer = csv.writer(fcsv)
    header_writer.writerow(["sample_name", "ensemble_prob_mean", "num_models_used"])

# ========================= PROCESS EACH POSITIVE =======================
# Number of atom-type columns every sample must have (37 in this setup).
n_types = len(ATOM_TYPE_NAMES)

for sidx, fpath in enumerate(npy_files):
    try:
        M = np.load(fpath)  # shape: [H,37] or [1,H,37]

        # Reduce the sample to a single 2-D [H, n_types] view. Everything
        # below (model input, atom-type lookup, saved input) uses this view,
        # so 2-D and [1,H,37] inputs are handled identically.
        if M.ndim == 2:
            inp2d = M
        elif M.ndim == 3 and M.shape[0] == 1:
            inp2d = M[0]
        elif M.ndim == 3:
            raise ValueError(f"Unexpected 3D shape; expected [1,H,37], got {M.shape}")
        else:
            raise ValueError(f"Unsupported sample shape: {M.shape}")
        if inp2d.shape != (H, n_types):
            raise ValueError(f"Sample {os.path.basename(fpath)} has wrong HxW")

        x_t = torch.from_numpy(inp2d.astype(np.float32)[None, None]).to(DEVICE)  # [1,1,H,37]

        run_sum = np.zeros((H, n_types), dtype=np.float64)
        logits_list = []
        used = 0

        # One Grad-CAM per checkpoint. Each model is loaded, used and released
        # before the next one so only a single model is held at a time.
        for mp in model_paths:
            model = gc_obj = cam_b1hw = logits = None
            try:
                model = load_model(mp, in_ch=1)
                conv = last_conv2d(model)
                if conv is None:
                    print(f"[SKIP {os.path.basename(mp)}] no Conv2d layer found")
                    continue
                gc_obj = GradCAM(model, conv)
                cam_b1hw, logits = gc_obj.generate(x_t, target=TARGET_CLASS)

                run_sum += cam_b1hw[0, 0].detach().cpu().numpy()  # [H,37] in [0,1]
                logits_list.append(float(logits[0].item()))
                used += 1
            except Exception as e:
                print(f"[SKIP {os.path.basename(mp)}] {e}")
            finally:
                # Always detach the hooks and drop references, also on failure.
                if gc_obj is not None:
                    gc_obj.remove()
                model = gc_obj = cam_b1hw = logits = None
                gc.collect()
                if DEVICE.type == 'cuda':
                    torch.cuda.empty_cache()

        if used == 0:
            print(f"[WARN] {os.path.basename(fpath)} produced no CAMs; skipped.")
            continue

        cam_mean = normalize01((run_sum / used).astype(np.float32))   # [H,37]
        print(cam_mean.shape, "is shape of cam mean")
        with np.printoptions(threshold=np.inf, linewidth=10_000, suppress=True):
            print(cam_mean)

        # Ensemble score: sigmoid of the mean logit across the models used.
        mean_logit = float(np.mean(logits_list))
        prob = float(1/(1+np.exp(-mean_logit)))

        # -------- per-atom & per-type importance --------
        atom_importance = cam_mean.mean(axis=1).astype(np.float32)  # [H]
        type_importance = cam_mean.mean(axis=0).astype(np.float32)  # [37]

        # Label each row by its strongest column; all-zero rows are "UNKNOWN".
        # NOTE(review): if the input is a neighbor-weighted matrix, this is the
        # dominant type around the atom, not necessarily its own type - confirm.
        atom_types = [
            ATOM_TYPE_NAMES[int(np.argmax(row))] if np.any(row) else "UNKNOWN"
            for row in inp2d
        ]

        # -------- save per-sample outputs --------
        sname = os.path.splitext(os.path.basename(fpath))[0]
        sdir  = os.path.join(OUT_ROOT, f"sample_{sidx:05d}_{sname}")
        os.makedirs(sdir, exist_ok=True)

        # arrays
        np.save(os.path.join(sdir, "input_M.npy"), inp2d)                             # [H,37]
        np.save(os.path.join(sdir, "ensemble_cam.npy"), cam_mean)                     # [H,37]
        np.save(os.path.join(sdir, "atom_importance.npy"), atom_importance)           # [H]
        np.save(os.path.join(sdir, "type_importance.npy"), type_importance)           # [37]

        # CSV: atoms (+ coords if available)
        atoms_csv = os.path.join(sdir, "atoms_importance.csv")
        coord_dim = coords.shape[2] if coords is not None else 0
        with open(atoms_csv, "w", newline="") as f:
            writer = csv.writer(f)
            writer.writerow(["atom_index", "atom_type", "importance"]
                            + [f"coord_{k}" for k in range(coord_dim)])
            for a, atom_type in enumerate(atom_types):
                row = [a, atom_type, float(atom_importance[a])]
                if coords is not None:
                    row += [float(c) for c in coords[sidx, a]]
                writer.writerow(row)

        # CSV: per-type with names
        types_csv = os.path.join(sdir, "type_importance.csv")
        with open(types_csv, "w", newline="") as f:
            writer = csv.writer(f)
            writer.writerow(["atom_type", "importance"])
            for type_name, importance in zip(ATOM_TYPE_NAMES, type_importance):
                writer.writerow([type_name, float(importance)])

        # Visuals (PIL only)
        save_png_gray(inp2d,                  os.path.join(sdir, "input_M.png"))
        save_png_gray(cam_mean,               os.path.join(sdir, "ensemble_cam.png"))
        save_overlay(inp2d, cam_mean,         os.path.join(sdir, "ensemble_cam_overlay.png"))
        save_line_plot_pil(atom_importance,   os.path.join(sdir, "atom_importance_plot.png"),
                           title="Atom importance (mean over types)",
                           xlabel="Atom index", ylabel="Importance")
        # Sorted per-type bar
        order = np.argsort(type_importance)[::-1]
        save_bar_plot_pil(type_importance[order], [ATOM_TYPE_NAMES[i] for i in order],
                          os.path.join(sdir, "type_importance_bar.png"),
                          title="Atom-type importance (mean over atoms)")

        # append sample score
        with open(scores_path, "a", newline="") as fcsv:
            csv.writer(fcsv).writerow([sname, prob, used])

        # aggregates
        sum_atom_importance += atom_importance
        sum_type_importance += type_importance
        sum_cam             += cam_mean
        count_samples       += 1

    except Exception as e:
        # One bad sample must not abort the whole run.
        print(f"[FAIL sample {fpath}] {e}")

# ============================= AGGREGATES ==============================
# Aggregate outputs are only written when at least one sample succeeded.
if count_samples > 0:
    agg_dir = os.path.join(OUT_ROOT, "aggregate")
    os.makedirs(agg_dir, exist_ok=True)

    # Convert the running sums into per-sample means.
    mean_atom_importance = (sum_atom_importance / count_samples).astype(np.float32)  # [H]
    mean_type_importance = (sum_type_importance / count_samples).astype(np.float32)  # [37]
    mean_cam             = (sum_cam / count_samples).astype(np.float32)             # [H,37]

    # arrays
    for stem, array in (
        ("mean_atom_importance", mean_atom_importance),
        ("mean_type_importance", mean_type_importance),
        ("mean_cam", mean_cam),
    ):
        np.save(os.path.join(agg_dir, stem + ".npy"), array)

    # CSVs
    with open(os.path.join(agg_dir, "mean_atom_importance.csv"), "w", newline="") as f:
        atom_writer = csv.writer(f)
        atom_writer.writerow(["atom_index", "mean_importance"])
        atom_writer.writerows(
            [idx, float(val)] for idx, val in enumerate(mean_atom_importance)
        )

    with open(os.path.join(agg_dir, "mean_type_importance.csv"), "w", newline="") as f:
        type_writer = csv.writer(f)
        type_writer.writerow(["atom_type", "mean_importance"])
        type_writer.writerows(
            [type_name, float(val)]
            for type_name, val in zip(ATOM_TYPE_NAMES, mean_type_importance)
        )

    # visuals
    save_line_plot_pil(
        mean_atom_importance,
        os.path.join(agg_dir, "mean_atom_importance_plot.png"),
        title="Mean atom importance (positives)",
        xlabel="Atom index",
        ylabel="Mean importance",
    )
    # Bars are shown from most to least important type.
    order = np.argsort(mean_type_importance)[::-1]
    sorted_type_names = [ATOM_TYPE_NAMES[i] for i in order]
    save_bar_plot_pil(
        mean_type_importance[order],
        sorted_type_names,
        os.path.join(agg_dir, "mean_type_importance_bar.png"),
        title="Mean atom-type importance (positives)",
    )
    save_png_gray(mean_cam, os.path.join(agg_dir, "mean_cam.png"))

print(f"[DONE] Processed {count_samples} positive sample(s). Outputs -> {OUT_ROOT}")


(150, 37) is shape of cam mean
[[0.4381806  0.4381806  0.43794274 0.4374076  0.43687245 0.4363373  0.4349929  0.43203005 0.42906716 0.42610434 0.43894482 0.507097   0.57524925 0.64340144 0.70320034 0.69617295 0.6891455  0.68211806 0.6750907  0.67622465 0.6773585  0.6784925  0.67962635 0.66931325 0.6575692  0.6458253  0.6340813  0.5768286  0.50657356 0.4363184  0.36606333 0.32421017 0.2965579  0.26890564 0.24125338 0.22896348 0.22896348]
 [0.4381806  0.4381806  0.43794274 0.4374076  0.43687245 0.4363373  0.4349929  0.43203005 0.42906716 0.42610434 0.43894482 0.507097   0.57524925 0.64340144 0.70320034 0.69617295 0.6891455  0.68211806 0.6750907  0.67622465 0.6773585  0.6784925  0.67962635 0.66931325 0.6575692  0.6458253  0.6340813  0.5768286  0.50657356 0.4363184  0.36606333 0.32421017 0.2965579  0.26890564 0.24125338 0.22896348 0.22896348]
 [0.3947082  0.3947082  0.39782342 0.40483257 0.41184172 0.41885087 0.42288122 0.420954   0.41902676 0.4170995  0.43075496 0.49894992 0.56714493 0.63

In [12]:
def min_max_normalization(matrix):
    """
    Perform Min-Max normalization on a given matrix.

    Parameters:
    matrix (np.ndarray): The input matrix to be normalized.

    Returns:
    np.ndarray: The normalized matrix with values scaled to the range [0, 1].
        A constant matrix (max == min) or an empty matrix yields an all-zero
        array of the same shape instead of NaNs or an exception.
    """
    matrix = np.asarray(matrix)
    # Floating inputs keep their dtype; other inputs are promoted to float64,
    # which is what true division produces on the regular path.
    out_dtype = matrix.dtype if np.issubdtype(matrix.dtype, np.floating) else np.float64

    # np.min / np.max raise on zero-size arrays.
    if matrix.size == 0:
        return np.zeros(matrix.shape, dtype=out_dtype)

    # Compute the minimum and maximum values for the matrix
    min_val = np.min(matrix)
    max_val = np.max(matrix)
    value_range = max_val - min_val

    # A zero range would divide by zero and fill the result with NaN.
    if value_range == 0:
        return np.zeros(matrix.shape, dtype=out_dtype)

    # Apply Min-Max normalization formula
    return (matrix - min_val) / value_range

# NOTE(review): allow_pickle=True unpickles arbitrary objects from the file;
# only load files from a trusted source.
data = np.load(
    "/home/alexhernandez/transmembranebindingAI/Data/SplitData/Cholesterol/cholesterol-separate-graphs-5A_exp1/Positive/1LRI-filtered_graphs.npy",
    allow_pickle=True,
).item()
encoded_matrix = data['encoded_matrix']
inverse_distance = data['inverse_distance']

# Show the first rows of both inputs for a quick visual check.
print(encoded_matrix[:10])
print(inverse_distance[:10])

max_atoms = 150
num_atoms = inverse_distance.shape[0]

# Weight the encoded rows by inverse distance, then scale the product to [0, 1].
combined_matrix = min_max_normalization(inverse_distance @ encoded_matrix)  # for gnn

# Zero-pad extra rows at the bottom so the matrix has max_atoms rows.
combined_matrix = np.pad(
    combined_matrix,
    pad_width=((0, max_atoms - num_atoms), (0, 0)),
    mode='constant',
)  # padding for gnn
print(combined_matrix[:10])

[[0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
  0]
 [1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
  0]
 [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
  0]
 [0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
  0]
 [0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
  0]
 [0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
  0]
 [1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
  0]
 [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
  0]
 [0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
  0]
 [0 0 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
  0]]
[[1.         0.656567   0.4184693  0.6601029  0.332777   0.1901095
  0.17879032 0.15053831 0.2045873  0.18143432 0.18421562 0.1424453
  0.1951022  0.16694595 0.17375705 0.14421593 0.10617352 0.1192816
  0.0784