In [3]:
# Gradient-based top genes per class (try shap.GradientExplainer, fallback to manual gradient-saliency)
# Paste & run in Kaggle. Adjust paths if needed.

import os, json, joblib, time, math
import numpy as np, pandas as pd
import torch, torch.nn as nn
import torch.nn.functional as F
from sklearn.preprocessing import LabelEncoder

# ---------- CONFIG ----------
MULTI_GA_CSV = "/kaggle/input/genetic-algorithms/Multi GA.csv"
SCALER_MULTI_PATH = "/kaggle/input/saint-multi/scaler_saint_multi.pkl"
SAINT_PTH = "/kaggle/input/saint-multi/saint_multiclass_best.pth"
LABEL_ENCODER_PATH = "/kaggle/input/saint-multi/label_encoder_saint_multi.pkl"
OUT_JSON = "/kaggle/working/top_genes_gradient.json"

TOPK = 20
BATCH = 128            # batch for gradient computation
SAMPLE_FRACTION = 1.0  # reduce <1.0 if you want fewer samples (faster)
USE_SHAP_GRAD = True   # try shap.GradientExplainer first

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device:", device)

# ---------- LOAD DATA ----------
df = pd.read_csv(MULTI_GA_CSV, low_memory=False)
print("Loaded Multi GA CSV:", df.shape)

# detect label column
label_candidates = [c for c in df.columns if "cancer" in c.lower() or "type" in c.lower() or "label" in c.lower()]
if not label_candidates:
    raise SystemExit("No label column detected in Multi GA CSV.")
label_col = label_candidates[0]
print("Using label column:", label_col)

# feature dataframe (drop label)
Xdf = df.drop(columns=[label_col]).copy()

# ---------- ALIGN TO SAVED SCALER ----------
if not os.path.exists(SCALER_MULTI_PATH):
    raise SystemExit(f"Scaler not found at {SCALER_MULTI_PATH}")

scaler = joblib.load(SCALER_MULTI_PATH)
if not hasattr(scaler, "feature_names_in_"):
    raise SystemExit("Saved multi scaler must contain feature_names_in_ (was fitted on named columns).")

scaler_cols = [str(x).strip() for x in scaler.feature_names_in_]
# check and add missing columns as zero if needed (safe fallback)
missing = [c for c in scaler_cols if c not in Xdf.columns]
if missing:
    print(f"Warning: {len(missing)} scaler columns missing in CSV. Filling them with zeros (first 10 shown): {missing[:10]}")
    for c in missing:
        Xdf[c] = 0.0

# reorder to match scaler
Xdf = Xdf.loc[:, scaler_cols]
# optional sampling
if SAMPLE_FRACTION < 1.0:
    Xdf = Xdf.sample(frac=SAMPLE_FRACTION, random_state=42).reset_index(drop=True)
else:
    Xdf = Xdf.reset_index(drop=True)

# fill NaNs
if Xdf.isna().sum().sum() > 0:
    Xdf = Xdf.fillna(Xdf.median())

# scale
try:
    X = scaler.transform(Xdf.values.astype(float))
except Exception:
    X = scaler.transform(Xdf)

print("Prepared X shape:", X.shape)

# ---------- label encoder ----------
if os.path.exists(LABEL_ENCODER_PATH):
    label_enc = joblib.load(LABEL_ENCODER_PATH)
else:
    label_enc = LabelEncoder(); label_enc.fit(df[label_col].astype(str).values)
classes = list(label_enc.classes_)
n_classes = len(classes)
print("Classes:", n_classes)

# ---------- SAINT model skeleton ----------
class SAINT(nn.Module):
    def __init__(self, input_dim, num_classes, embed_dim=128, num_heads=4, num_layers=4, dropout=0.1):
        super().__init__()
        self.embedding = nn.Linear(input_dim, embed_dim)
        self.layers = nn.ModuleList([ nn.TransformerEncoderLayer(d_model=embed_dim, nhead=num_heads, dropout=dropout) for _ in range(num_layers) ])
        self.fc = nn.Linear(embed_dim, num_classes)
        self.dropout = nn.Dropout(dropout)
    def forward(self, x):
        x = self.embedding(x).unsqueeze(0)
        for layer in self.layers:
            x = layer(x)
        x = x.mean(dim=0)
        return self.fc(self.dropout(x))

# load checkpoint
saint_blob = torch.load(SAINT_PTH, map_location="cpu")
state = saint_blob.get("state_dict", saint_blob) if isinstance(saint_blob, dict) else None
model = SAINT(input_dim=X.shape[1], num_classes=n_classes).to(device)
if state is not None:
    sd = {k.replace("module.",""): v for k,v in state.items()}
    model.load_state_dict(sd, strict=False)
model.eval()
print("SAINT model ready on", device)

# ---------- Try shap.GradientExplainer (fast) ----------
use_shap = False
if USE_SHAP_GRAD:
    try:
        import shap
        print("shap version:", shap.__version__)
        # sample small background for explainer
        bg_size = min(200, max(20, int(0.02 * X.shape[0])))
        rng = np.random.RandomState(42)
        bg_idx = rng.choice(X.shape[0], size=bg_size, replace=False)
        background = X[bg_idx]
        # shap GradientExplainer wants model function that returns logits or outputs
        def model_fn(x_array):
            with torch.no_grad():
                t = torch.tensor(x_array, dtype=torch.float32).to(device)
                out = model(t)
                if isinstance(out, torch.Tensor):
                    return out.cpu().numpy()
                return out
        explainer = shap.GradientExplainer(model_fn, background)
        use_shap = True
        print("Created shap.GradientExplainer with background size", background.shape[0])
    except Exception as e:
        print("shap.GradientExplainer not available or failed:", str(e))
        use_shap = False

results = {}
feat_names = scaler_cols

if use_shap:
    print("Running GradientExplainer for all classes (this is faster than KernelExplainer).")
    # We'll compute mean absolute attributions per class
    # For memory/speed, iterate over classes and compute shap values for a subset of rows (or all rows if possible)
    # We'll use batches to avoid OOM
    all_indices = np.arange(X.shape[0])
    # optionally limit sample count per class to speed up; here we use all rows
    start_time = time.time()
    # shap returns per-class attribution for multi-output models; we request shap values for whole dataset in batches
    batch = 256
    sum_abs_by_class = np.zeros((n_classes, X.shape[1]), dtype=float)
    count_by_class = np.zeros(n_classes, dtype=int)
    for i in range(0, X.shape[0], batch):
        xb = X[i:i+batch]
        # shap values shape: list of arrays (one per output) or array (batch, features, outputs)? handle both
        sv = explainer.shap_values(xb)  # try: returns list len n_outputs each (batch, features)
        if isinstance(sv, list):
            # sv[j] = (batch, features) for class j
            for j in range(n_classes):
                arr = np.array(sv[j])
                sum_abs_by_class[j] += np.sum(np.abs(arr), axis=0)
            # increment counts
            for idx in range(i, min(i+batch, X.shape[0])):
                # determine true label if present (we don't require it; just count samples)
                count_by_class += 0  # we will normalize by number of samples later (we used all rows)
        else:
            # sv shape could be (batch, features, n_outputs)
            arr3 = np.array(sv)  # shape (batch, features, n_outputs)
            if arr3.ndim == 3:
                # move outputs to first axis
                arr3 = np.transpose(arr3, (2,0,1))  # (n_outputs, batch, features)
                for j in range(n_classes):
                    sum_abs_by_class[j] += np.sum(np.abs(arr3[j]), axis=0)
            else:
                raise RuntimeError("Unexpected shap output shape: " + str(arr3.shape))
    elapsed = time.time() - start_time
    print("GradientExplainer completed batches in {:.1f}s".format(elapsed))
    # normalize by total number of samples (we summed abs across all samples)
    total_samples = X.shape[0]
    mean_abs_by_class = sum_abs_by_class / float(total_samples)
    # prepare topk per class
    for c in range(n_classes):
        vec = mean_abs_by_class[c]
        order = np.argsort(vec)[::-1][:TOPK]
        total = vec.sum() + 1e-12
        top_list = [{"gene": feat_names[int(i)], "mean_abs": float(vec[int(i)]), "pct_total": float(100.0 * vec[int(i)] / total)} for i in order]
        results[c] = {"class_name": str(label_enc.inverse_transform([int(c)])[0]), "top_genes": top_list}
else:
    # ---------- FALLBACK: direct mean-absolute-gradient saliency (fast, reliable) ----------
    print("Falling back to manual gradient-saliency (mean abs gradient per class). This is fast and recommended.")
    model.eval()
    n = X.shape[0]
    counts = np.zeros(n_classes, dtype=int)
    imp_sum = np.zeros((n_classes, X.shape[1]), dtype=float)
    X_t = torch.tensor(X, dtype=torch.float32).to(device)
    # process in batches; for each batch compute logits then gradients for each class
    for i in range(0, n, BATCH):
        xb = X_t[i:i+BATCH].clone().detach().requires_grad_(True)
        logits = model(xb)            # (batch, n_classes)
        bs = xb.shape[0]
        # for each class, compute gradient of sum(logit[:,class]) wrt xb
        for cls in range(n_classes):
            model.zero_grad()
            # sum of logits for that class over the batch
            s = logits[:, cls].sum()
            s.backward(retain_graph=True)
            grads = xb.grad.detach().cpu().numpy()  # (batch, features)
            absmean = np.mean(np.abs(grads), axis=0)  # mean abs gradient per feature for this batch
            imp_sum[cls] += absmean * bs
            xb.grad.zero_()
        # accumulate counts by actual class labels if label column present
        # if you want class-sampled normalization, compute counts separately; here we normalize by dataset size
    # normalize by total number of samples used
    total_samples = float(n)
    mean_abs_by_class = imp_sum / total_samples
    for c in range(n_classes):
        vec = mean_abs_by_class[c]
        order = np.argsort(vec)[::-1][:TOPK]
        total = vec.sum() + 1e-12
        top_list = [{"gene": feat_names[int(i)], "score": float(vec[int(i)]), "pct_total": float(100.0 * vec[int(i)] / total)} for i in order]
        results[c] = {"class_name": str(label_enc.inverse_transform([int(c)])[0]), "top_genes": top_list}

# ---------- SAVE RESULTS ----------
with open(OUT_JSON, "w") as f:
    json.dump({
        "method": "shap_gradient_explainer" if use_shap else "gradient_saliency",
        "n_rows_used": int(X.shape[0]),
        "n_features": int(X.shape[1]),
        "topk": TOPK,
        "results": results
    }, f, indent=2)

Device: cuda
Loaded Multi GA CSV: (10459, 8972)
Using label column: cancer_type


https://scikit-learn.org/stable/model_persistence.html#security-maintainability-limitations


Prepared X shape: (10459, 8971)
Classes: 33
SAINT model ready on cuda


https://scikit-learn.org/stable/model_persistence.html#security-maintainability-limitations


shap version: 0.44.1


2025-10-30 04:28:18.798372: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1761798498.962723      37 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1761798499.018409      37 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


shap.GradientExplainer not available or failed: <class 'function'> is not currently a supported model type!
Falling back to manual gradient-saliency (mean abs gradient per class). This is fast and recommended.


In [4]:
import json, os
OUT_JSON = "/kaggle/working/top_genes_gradient.json"

if not os.path.exists(OUT_JSON):
    raise SystemExit(f"{OUT_JSON} not found. Make sure the gradient cell finished and saved the file.")

data = json.load(open(OUT_JSON))
print("Method:", data.get("method", "gradient_saliency"))
print("Rows used:", data.get("n_rows_used"))
print("Features:", data.get("n_features"))
print("Top-k:", data.get("topk"))
print()

results = data.get("results", data) if "results" in data else data
for k in list(results.keys()):
    cls = results[k]["class_name"]
    genes = [g["gene"] for g in results[k]["top_genes"][:10]]
    print(f"Class {k} - {cls} -> top 10 genes: {genes}")

Method: gradient_saliency
Rows used: 10459
Features: 8971
Top-k: 20

Class 0 - TCGA-ACC -> top 10 genes: ['PYY', 'NAMA', 'NXPE1', 'REN', 'NOX1_AMDP1_TRPC6P_TCEB1P24_DDX11L16_WASIR1', 'HOXD_AS2', 'NXPE4', 'TMEM255B', 'TMIGD1', 'RETNLB']
Class 1 - TCGA-BLCA -> top 10 genes: ['IGBP1P1', 'AC006026_10', 'RPL13AP6', 'PSMB11', 'PIN1P1', 'AC105339_1', 'UBE2L3', 'RPL13A', 'PPIAP30', 'IL37']
Class 2 - TCGA-BRCA -> top 10 genes: ['AFP', 'KDM5D', 'NAPSA', 'UTY', 'ITLN2', 'APOD', 'RP11_417E7_2', 'SFTPC', 'SFTPA2', 'RP11_344E13_3']
Class 3 - TCGA-CESC -> top 10 genes: ['PENK', 'RPL19P12', 'NBPF8', 'MIR7_3HG', 'SLC18A1', 'TCF21', 'PCSK2', 'NBPF20', 'HBBP1', 'NBPF10']
Class 4 - TCGA-CHOL -> top 10 genes: ['NXPE1', 'GALNT14', 'PAX1', 'NOX1_AMDP1_TRPC6P_TCEB1P24_DDX11L16_WASIR1', 'SLC26A3', 'RP11_163N6_2', 'RAB34', 'NXPE4', 'ABCC11_RP11_3M1_1', 'NINL']
Class 5 - TCGA-COAD -> top 10 genes: ['TMA7', 'MIR101_2', 'STAU2_AS1', 'USP32P1', 'SULT2A1', 'MGARP', 'INS', 'LYZL6', 'AC006552_1', 'RPL19P12']
Class 6 -

In [5]:
# by TCGA name
tcga_name = "TCGA-BRCA"   # change
# by index
cls_idx = None            # or set to integer, e.g. 2

data = json.load(open("/kaggle/working/top_genes_gradient.json"))
results = data.get("results", data) if "results" in data else data

if cls_idx is None:
    # find by name
    found = None
    for k,v in results.items():
        if v["class_name"] == tcga_name:
            found = (k,v); break
    if found is None:
        raise SystemExit(f"{tcga_name} not found in results. Available classes: {[v['class_name'] for v in results.values()]}")
    k,v = found
else:
    k = str(cls_idx) if isinstance(list(results.keys())[0], str) else cls_idx
    v = results[k]

print("Class:", v["class_name"])
for i,g in enumerate(v["top_genes"]):
    print(i+1, g["gene"], f"score={g.get('score', g.get('mean_abs', None)):.4e}", f"pct={g.get('pct_total',0):.2f}%")

Class: TCGA-BRCA
1 AFP score=5.8962e-03 pct=0.07%
2 KDM5D score=5.5705e-03 pct=0.07%
3 NAPSA score=5.0143e-03 pct=0.06%
4 UTY score=4.9115e-03 pct=0.06%
5 ITLN2 score=4.8846e-03 pct=0.06%
6 APOD score=4.8188e-03 pct=0.06%
7 RP11_417E7_2 score=4.7069e-03 pct=0.06%
8 SFTPC score=4.5862e-03 pct=0.05%
9 SFTPA2 score=4.5236e-03 pct=0.05%
10 RP11_344E13_3 score=4.2508e-03 pct=0.05%
11 SFTPA1 score=4.2006e-03 pct=0.05%
12 PPP1R14A score=4.1656e-03 pct=0.05%
13 ADIPOQ score=4.1556e-03 pct=0.05%
14 SNORA14A score=4.1205e-03 pct=0.05%
15 RP11_432J24_5 score=4.1190e-03 pct=0.05%
16 NKX3_2 score=4.0435e-03 pct=0.05%
17 HBA2 score=4.0203e-03 pct=0.05%
18 TMSB4Y score=4.0139e-03 pct=0.05%
19 NLGN4Y score=3.9703e-03 pct=0.05%
20 AC016683_6 score=3.9435e-03 pct=0.05%


In [6]:
import csv, json
data = json.load(open("/kaggle/working/top_genes_gradient.json"))
results = data.get("results", data) if "results" in data else data
out_csv = "/kaggle/working/top_genes_all_classes.csv"
rows = []
for k,v in results.items():
    class_name = v["class_name"]
    for rank, g in enumerate(v["top_genes"], start=1):
        score = g.get("score", g.get("mean_abs", g.get("mean_abs_shap", None)))
        pct = g.get("pct_total", None)
        rows.append([k, class_name, g["gene"], rank, score, pct])
with open(out_csv, "w", newline="") as f:
    writer = csv.writer(f)
    writer.writerow(["class_index","class_name","gene","rank","score","pct_total"])
    writer.writerows(rows)
print("Saved CSV to", out_csv)

Saved CSV to /kaggle/working/top_genes_all_classes.csv
