# In [None]:

import os, random, numpy as np, tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint
from scipy.ndimage import binary_dilation, binary_erosion
import shap

# -------------------- Reproducibility --------------------
# Fix the seeds of NumPy's global RNG, Python's `random` module and
# TensorFlow's global seed so repeated runs draw the same random numbers.
np.random.seed(42)
random.seed(42)
tf.random.set_seed(42)

# -------------------- Utilities & Loading --------------------
def _exists(name, scope): return (name in scope) and (scope[name] is not None)
def _ensure_ch(x): return x if x.ndim == 5 else x[..., np.newaxis]
def _one_hot(y, n=2):
    """One-hot encode class labels ``y`` into ``n`` columns using Keras' ``to_categorical``."""
    encode = keras.utils.to_categorical
    return encode(y, num_classes=n)

def _load_all(globs):
    """Load datasets and brain-atlas masks, from ``globs`` if preloaded, else from disk.

    Required arrays (numpy): X_train, y_train, X_val, y_val, X_test, y_test
    Brain atlas masks: BMAPS (list of 3D region masks; the pipeline is described
    as using 96 regions) and optional TEMPLATE (3D mask).

    Source selection:
      * If X_train, y_train, X_val, y_val, X_test, y_test and BMAPS are all bound
        to non-None values in ``globs`` (e.g. a notebook's ``globals()``), those
        objects are used. TEMPLATE is taken from ``globs`` too when present.
      * Otherwise each array is read from ``<name>.npy`` in the current working
        directory. The masks are read from ``atlas_masks.npz`` (key ``masks``,
        optional key ``template``). If that file is absent, they are read from
        ``BMAPS.npy`` and an optional ``TEMPLATE.npy``.

    Masks are returned as bool arrays. That way they work directly for boolean
    indexing and can be combined with bool volumes via ``|=`` / ``&=``.

    Returns:
        (X_train, y_train, X_val, y_val, X_test, y_test, BMAPS, TEMPLATE).
        BMAPS is a list of bool arrays; TEMPLATE is a bool array or None.

    Raises:
        FileNotFoundError: a required file is missing on the disk path.
    """
    names = ["X_train", "y_train", "X_val", "y_val", "X_test", "y_test"]

    if all(_exists(n, globs) for n in names + ["BMAPS"]):
        X_train, y_train, X_val, y_val, X_test, y_test = (globs[n] for n in names)
        BMAPS = globs["BMAPS"]
        TEMPLATE = globs["TEMPLATE"] if _exists("TEMPLATE", globs) else None
    else:
        def req(p, w):
            if not os.path.exists(p):
                raise FileNotFoundError(f"Missing {w}: {p}")
            return p

        # Same load order as before, so the first missing file reported is unchanged.
        X_train, y_train, X_val, y_val, X_test, y_test = [
            np.load(req(f"{n}.npy", n)) for n in names
        ]

        if os.path.exists("atlas_masks.npz"):
            # allow_pickle=True lets np.load unpickle objects stored in the file,
            # which can execute code: only load atlas files from a trusted source.
            # The context manager closes the underlying file handle of the NpzFile.
            with np.load("atlas_masks.npz", allow_pickle=True) as npz:
                BMAPS = list(npz["masks"])
                TEMPLATE = npz["template"] if "template" in npz.files else None
        else:
            # SECURITY: allow_pickle=True — trusted files only (see above).
            BMAPS = np.load(req("BMAPS.npy", "BMAPS"), allow_pickle=True).tolist()
            TEMPLATE = np.load("TEMPLATE.npy") if os.path.exists("TEMPLATE.npy") else None

    # Integer masks would be used as fancy indices (not selections) and make
    # in-place |= / &= on bool volumes fail. Nested lists from .tolist() lack .shape.
    BMAPS = [np.asarray(m, dtype=bool) for m in BMAPS]
    TEMPLATE = None if TEMPLATE is None else np.asarray(TEMPLATE, dtype=bool)
    return X_train, y_train, X_val, y_val, X_test, y_test, BMAPS, TEMPLATE

# -------------------- Model --------------------
def build_cnn(input_shape, n_classes=2):
    """Build and compile the pipeline's 3D CNN classifier.

    Architecture:
      * Three stages of Conv3D (ReLU, 'same' padding) followed by MaxPool3D(2).
        The stages use 8, 16 and 32 filters with kernel sizes 5, 3 and 3.
      * Then Flatten -> Dense(1024, ReLU) -> Dropout(0.5) -> Dense(n_classes, softmax).

    The model is named "GASHAP_3DCNN". It is compiled with categorical
    cross-entropy, Adadelta(learning_rate=0.05, rho=0.95) and an accuracy metric.
    """
    inputs = keras.Input(shape=input_shape)
    h = inputs
    for filters, kernel in ((8, 5), (16, 3), (32, 3)):
        h = layers.Conv3D(filters, kernel, padding="same", activation="relu")(h)
        h = layers.MaxPool3D(2)(h)
    h = layers.Flatten()(h)
    h = layers.Dense(1024, activation="relu")(h)
    h = layers.Dropout(0.5)(h)
    outputs = layers.Dense(n_classes, activation="softmax")(h)

    model = keras.Model(inputs, outputs, name="GASHAP_3DCNN")
    model.compile(
        loss="categorical_crossentropy",
        optimizer=keras.optimizers.Adadelta(learning_rate=0.05, rho=0.95),
        metrics=["accuracy"],
    )
    return model

def train_model(model, Xtr, ytr, Xva, yva, outdir="outputs", bs=32, epochs=100):
    """Fit ``model`` on (Xtr, ytr), validating on (Xva, yva); return the fit History.

    ``outdir`` is created if needed. Two callbacks are attached:
      * ModelCheckpoint writes ``<outdir>/best_model.h5`` whenever val_accuracy
        improves (save_best_only=True).
      * EarlyStopping on val_loss with patience 10 and restore_best_weights=True.

    NOTE(review): the two callbacks choose "best" by different metrics. The saved
    checkpoint and the in-memory weights may therefore come from different
    epochs — confirm this is intended.
    """
    os.makedirs(outdir, exist_ok=True)
    callbacks = [
        ModelCheckpoint(
            os.path.join(outdir, "best_model.h5"),
            monitor="val_accuracy",
            save_best_only=True,
        ),
        EarlyStopping(monitor="val_loss", patience=10, restore_best_weights=True),
    ]
    return model.fit(
        Xtr, ytr,
        validation_data=(Xva, yva),
        epochs=epochs,
        batch_size=bs,
        callbacks=callbacks,
        verbose=1,
    )

def evaluate_model(model, X, y_oh):
    """Classify ``X`` and score it against one-hot targets ``y_oh``.

    Returns:
        (accuracy, predicted class indices, raw ``model.predict`` output).
    """
    probs = model.predict(X, verbose=0)
    predicted = np.argmax(probs, axis=1)
    accuracy = np.mean(predicted == np.argmax(y_oh, axis=1))
    return accuracy, predicted, probs

# -------------------- SHAP --------------------
def deep_shap(model, X_bg, X_samples, nsamples=None, class_index=1):
    """Compute Deep SHAP attributions of ``model`` for one output class.

    Args:
        model: Keras model to explain.
        X_bg: background samples passed to ``shap.DeepExplainer``.
        X_samples: samples to explain.
        nsamples: if a positive int, only the first ``nsamples`` rows of
            ``X_samples`` are explained; otherwise all rows are.
        class_index: model output whose attributions are returned. Defaults to 1,
            the class used by the original implementation.

    Returns:
        np.ndarray of attributions for ``class_index``, shaped like the explained
        samples. For 5-D results the trailing singleton channel axis is dropped.
    """
    expl = shap.DeepExplainer(model, X_bg)
    Xs = X_samples[:nsamples] if (nsamples and nsamples > 0) else X_samples
    sv = expl.shap_values(Xs, check_additivity=False)
    if isinstance(sv, list):
        # Older SHAP releases return one array per model output.
        sv = sv[class_index]
    else:
        sv = np.asarray(sv)
        # Newer SHAP releases may instead stack the outputs on a trailing axis:
        # (*Xs.shape, n_outputs). Without this slice both classes' attributions
        # would be mixed into every downstream region score.
        if sv.ndim == np.ndim(Xs) + 1:
            sv = sv[..., class_index]
    if sv.ndim == 5 and sv.shape[-1] == 1:
        sv = sv[..., 0]
    return sv

def region_score(heatmap, rmask, max_shap=None):
    """Score one region as the mean of its positive attribution values.

    Args:
        heatmap: attribution map whose leading axes match ``rmask``.
        rmask: region mask. It is coerced to bool, so an integer (e.g. uint8)
            mask selects its non-zero voxels. Without the coercion, NumPy would
            treat it as integer fancy indices and silently pick the wrong values.
        max_shap: optional normaliser. When it is given and positive, the score
            is divided by ``max_shap + 1e-8``.

    Returns:
        The (optionally normalised) mean as a float. Returns 0.0 when the region
        is empty or contains no positive values.
    """
    values = np.asarray(heatmap)[np.asarray(rmask, dtype=bool)]
    positive = values[values > 0]
    if positive.size == 0:
        return 0.0
    score = positive.mean()
    if max_shap is not None and max_shap > 0:
        score = score / (max_shap + 1e-8)
    return float(score)

# -------------------- Mapping scores -> chromosome --------------------
def scores_to_chrom_percentiles(scores, cuts=(30,50,70,100)):
    """Quantise region scores into chromosome genes 0..3 using percentile cuts.

    Thresholds are the percentiles ``cuts[0]``, ``cuts[1]`` and ``cuts[2]`` of
    ``scores``. Each score becomes:
      * 0 (UR) if it is <= the first threshold,
      * 1 (IR) if it is <= the second,
      * 2 (VIR) if it is <= the third,
      * 3 (VVIR) otherwise.

    ``cuts[3]`` is also evaluated as a percentile but does not affect the result.
    If no score is positive, an all-zero chromosome is returned.

    Returns:
        int32 array with one gene per score.
    """
    values = np.array(scores, dtype=float)
    if not (values > 0).any():
        return np.zeros_like(values, dtype=np.int32)
    q = [cuts[0], cuts[1], cuts[2], cuts[3]]
    t_ur, t_ir, t_vir, _ = np.percentile(values, q)
    # np.select takes the first true condition, mirroring an if/elif chain.
    genes = np.select(
        [values <= t_ur, values <= t_ir, values <= t_vir],
        [0, 1, 2],
        default=3,
    )
    return genes.astype(np.int32)

# -------------------- Chromosome -> mask with correct morphological mapping --------------------
# Per manuscript Table 4/5: 1=IR->Erosion, 2=VIR->No-op, 3=VVIR->Dilation
def chrom_to_mask(BMAPS, DNA, template=None, dil=1, ero=1):
    """Build a 3-D bool mask from a chromosome over atlas regions.

    Gene ``DNA[r]`` controls region ``BMAPS[r]``:
        0 (UR)   -> region excluded
        1 (IR)   -> region eroded ``ero`` times
        2 (VIR)  -> region used as-is
        3 (VVIR) -> region dilated ``dil`` times
    Any other gene value is treated like 2. The selected regions are OR-ed
    together. If ``template`` is given, the result is AND-ed with it.

    Region masks and the template may be any dtype; non-zero voxels count as
    inside. ``ero``/``dil`` values below 1 mean "no erosion/dilation". SciPy
    itself would repeat the operation until nothing changes, which empties or
    fills the whole volume.

    Returns:
        bool array with the shape of ``BMAPS[0]``.
    """
    D, H, W = np.shape(BMAPS[0])
    mask = np.zeros((D, H, W), dtype=bool)
    for r, gene in enumerate(DNA):
        if gene == 0:  # UR - skip
            continue
        # bool so in-place |= below cannot fail with a casting error (e.g. uint8 masks)
        region = np.asarray(BMAPS[r], dtype=bool)
        if gene == 1:    # IR
            part = binary_erosion(region, iterations=ero) if ero > 0 else region
        elif gene == 3:  # VVIR
            part = binary_dilation(region, iterations=dil) if dil > 0 else region
        else:            # VIR (and any unexpected gene)
            part = region
        mask |= part
    if template is not None:
        mask &= np.asarray(template, dtype=bool)
    return mask

def masked_acc(model, X, yidx, m3d, bs=32):
    """Accuracy of ``model`` on ``X`` after multiplying every sample by ``m3d``.

    Args:
        model: Keras classifier; its argmax over axis 1 is the predicted class.
        X: samples whose trailing axes broadcast against ``m3d`` (the pipeline
            passes (N, D, H, W) volumes). A channel axis is appended before
            prediction.
        yidx: true class index per sample.
        m3d: mask applied multiplicatively to every sample.
        bs: number of samples per forward pass.

    Returns:
        Fraction of samples whose predicted class equals ``yidx``.
    """
    yidx = np.asarray(yidx)
    N = X.shape[0]
    correct = 0
    for start in range(0, N, bs):
        sl = slice(start, min(start + bs, N))
        Xb = (X[sl, ...] * m3d[None, ...])[..., np.newaxis]
        # Each call is exactly one batch. predict_on_batch runs it directly,
        # avoiding predict()'s per-call setup overhead, which adds up over the
        # thousands of GA fitness evaluations.
        probs = model.predict_on_batch(Xb)
        correct += int(np.sum(np.argmax(probs, 1) == yidx[sl]))
    return correct / N

def fitness(model, Xref, yref_idx, BMAPS, DNA, template=None, alpha=0.025, beta=0.975):
    """GA fitness of chromosome ``DNA``: ``beta * accuracy + alpha * compactness``.

    The chromosome is turned into a mask with ``chrom_to_mask``. Accuracy is
    ``masked_acc`` on (Xref, yref_idx) with batch size 32. Compactness is
    1 / (number of genes > 0), or 1e-6 when no region is selected.

    ``DNA`` may be any integer sequence (list or ndarray).

    Returns:
        (fit, acc, nreg, m3d): fitness, masked accuracy, number of selected
        regions and the mask used.
    """
    m3d = chrom_to_mask(BMAPS, DNA, template)
    acc = masked_acc(model, Xref, yref_idx, m3d, 32)
    # asarray so list chromosomes work (a plain list cannot be compared with > 0)
    nreg = int(np.count_nonzero(np.asarray(DNA) > 0))
    comp = 1.0 / (nreg if nreg > 0 else 1e6)   # compactness
    fit = beta * acc + alpha * comp
    return fit, acc, nreg, m3d

# -------------------- GA --------------------
def roulette_select(pop, fits):
    """Pick one individual by fitness-proportionate (roulette-wheel) selection.

    Fitnesses are shifted so the minimum maps to 1e-8. Every individual keeps a
    non-zero slice of the wheel, even with negative fitnesses. If the wheel's
    total is not positive, a uniform draw is used instead. If rounding leaves
    the spin past the last slice, the last individual is chosen.

    Returns:
        A copy of the selected individual.
    """
    arr = np.array(fits, dtype=float)
    weights = arr - arr.min() + 1e-8
    total = weights.sum()
    if total <= 0:
        # Degenerate wheel: uniform fallback.
        return pop[np.random.randint(len(pop))].copy()
    spin = np.random.rand() * total
    # First slot whose running total reaches the spin.
    hits = np.flatnonzero(np.cumsum(weights) >= spin)
    chosen = hits[0] if hits.size else len(pop) - 1
    return pop[chosen].copy()

def xover(p1, p2, p=0.4):
    """Single-point crossover, applied with probability ``p``.

    A random draw is always consumed first. If crossover is skipped (draw > p)
    or the parents have fewer than two genes, copies of the parents are
    returned. Otherwise a cut point in [1, len(p1)) is drawn and the tails are
    swapped.

    Returns:
        Two children.
    """
    skip = np.random.rand() > p
    if skip or len(p1) < 2:
        return p1.copy(), p2.copy()
    cut = np.random.randint(1, len(p1))
    child_a = np.concatenate([p1[:cut], p2[cut:]])
    child_b = np.concatenate([p2[:cut], p1[cut:]])
    return child_a, child_b

def mutate(ch, pm=0.6):
    """Return a mutated copy of chromosome ``ch``.

    Each gene is redrawn uniformly from {0, 1, 2, 3} with probability ``pm``.
    The redraw may reproduce the old value. ``ch`` itself is not modified.
    """
    child = ch.copy()
    alphabet = [0, 1, 2, 3]
    for pos in range(len(child)):
        if np.random.rand() < pm:
            child[pos] = np.random.choice(alphabet)
    return child

def run_ga(
    model, Xref, yref_idx, BMAPS, template=None, init_pop=None,
    max_iter=2000, cross_p=0.4, mut_p=0.6, elit=0.05,
    alpha=0.025, beta=0.975, patience=10
):
    """Evolve region chromosomes with a generational genetic algorithm.

    Each generation:
      1. Every chromosome is scored with ``fitness``.
      2. The top ``max(1, int(elit * ps))`` survive unchanged.
      3. The rest of the next population is bred by roulette selection,
         single-point crossover (probability ``cross_p``) and per-gene
         mutation (probability ``mut_p``).

    The run ends after ``max_iter`` generations. It ends earlier once the mean
    fitness has risen by less than 1e-3 (or fallen) for ``patience``
    consecutive generations.

    Args:
        init_pop: optional list of int chromosomes. When None, 200 random
            chromosomes of length ``len(BMAPS)`` with genes in {0,1,2,3} are used.

    Returns:
        (best_sol, best_mask, hist):
          * best_sol: best chromosome seen in any generation (None if no
            generation ran).
          * best_mask: its mask from ``fitness``.
          * hist: list of (generation, mean fitness, mean accuracy,
            int(mean #regions)) tuples.

    Raises:
        ValueError: the population is empty and at least one generation runs.
    """
    R = len(BMAPS)
    ps = len(init_pop) if init_pop is not None else 200
    ne = max(1, int(elit * ps))

    # init
    if init_pop is None:
        init_pop = [np.random.choice([0, 1, 2, 3], size=R).astype(np.int32) for _ in range(ps)]
    pop = [c.copy() for c in init_pop]

    best_fit = -np.inf; best_sol = None; best_mask = None
    hist = []
    prev_mean = None; noimp = 0

    for gen in range(max_iter):
        if not pop:
            raise ValueError("run_ga needs a non-empty population")
        fits = []; accs = []; regs = []
        # Keep only this generation's best mask (first maximum, like argmax).
        # Keeping one full 3-D mask per individual wastes memory.
        gen_bi = 0; gen_mask = None
        for i, ch in enumerate(pop):
            f, a, r, mm = fitness(model, Xref, yref_idx, BMAPS, ch, template, alpha, beta)
            if i == 0 or f > fits[gen_bi]:
                gen_bi, gen_mask = i, mm
            fits.append(f); accs.append(a); regs.append(r)

        if fits[gen_bi] > best_fit:
            best_fit = float(fits[gen_bi]); best_sol = pop[gen_bi].copy(); best_mask = gen_mask

        mfit = float(np.mean(fits))
        hist.append((gen, mfit, float(np.mean(accs)), int(np.mean(regs))))
        # Stagnation counter on the population mean fitness.
        if prev_mean is not None and (mfit - prev_mean) < 1e-3:
            noimp += 1
        else:
            noimp = 0
        prev_mean = mfit
        if noimp >= patience:
            break

        elite_idx = np.argsort(fits)[-ne:]
        elites = [pop[i].copy() for i in elite_idx]

        # new population via roulette selection + crossover + mutation
        newp = []
        while len(newp) < (ps - ne):
            p1 = roulette_select(pop, fits); p2 = roulette_select(pop, fits)
            c1, c2 = xover(p1, p2, p=cross_p)
            c1 = mutate(c1, pm=mut_p); newp.append(c1)
            if len(newp) < (ps - ne):
                c2 = mutate(c2, pm=mut_p); newp.append(c2)
        pop = elites + newp

    return best_sol, best_mask, hist

# -------------------- Init population from SHAP (per manuscript: two experiments with percentiles) --------------------
def init_pop_from_shap_percentiles(shap_maps, BMAPS, template=None, pop_size=200, cuts=(30,50,70,100)):
    """Seed a GA population from SHAP region scores.

    Scoring:
      * Region r's score is the mean over samples i of
        ``region_score(shap_maps[i], BMAPS[r], max_shap=max|shap_maps|)``.
      * The score vector becomes a base chromosome via
        ``scores_to_chrom_percentiles(..., cuts=cuts)``.

    Population: each of the ``pop_size`` individuals starts as a copy of the
    base. Then ``max(1, R // 20)`` distinct genes are picked, and each picked
    gene is redrawn from {0, 1, 2, 3} with probability 0.3.

    NOTE(review): ``template`` is accepted but not used by this function.

    Returns:
        (population as a list of int32 chromosomes, per-region score array)
    """
    n_regions = len(BMAPS)
    n_maps = shap_maps.shape[0]
    global_max = float(np.max(np.abs(shap_maps))) if shap_maps.size > 0 else 0.0

    scores = np.zeros(n_regions, float)
    for r, region in enumerate(BMAPS):
        per_sample = [
            region_score(shap_maps[i], region, max_shap=global_max)
            for i in range(n_maps)
        ]
        scores[r] = np.mean(per_sample) if per_sample else 0.0

    base = scores_to_chrom_percentiles(scores, cuts=cuts)

    population = []
    for _ in range(pop_size):
        individual = base.copy()
        if n_regions > 0:
            # Small random perturbations around the base to add diversity.
            picked = np.random.choice(n_regions, size=max(1, n_regions // 20), replace=False)
            for j in picked:
                if np.random.rand() < 0.3:
                    individual[j] = np.random.choice([0, 1, 2, 3])
        population.append(individual.astype(np.int32))
    return population, scores

# -------------------- End-to-end pipeline --------------------
def _history_array(hist):
    """Convert GA history tuples into a float array of shape (n_generations, 4).

    Columns: generation, mean fitness, mean accuracy, mean #regions. Unlike an
    object array, a numeric array can be read back with plain ``np.load``,
    without ``allow_pickle=True``.
    """
    return np.asarray(hist, dtype=float).reshape(-1, 4)


def _run_cut_experiment(tag, cuts, model, shap_maps, BMAPS, TEMPLATE, X_val, y_val, X_test, y_test_idx):
    """Run one SHAP-seeded GA experiment with percentile ``cuts`` and print its result.

    Returns:
        (scores, best_DNA, best_mask, hist, masked_test_acc, selected_region_indices)
    """
    init_pop, scores = init_pop_from_shap_percentiles(
        shap_maps, BMAPS, template=TEMPLATE, pop_size=200, cuts=cuts
    )
    best_DNA, best_mask, hist = run_ga(
        model=model, Xref=X_val[..., 0], yref_idx=y_val.astype(int),
        BMAPS=BMAPS, template=TEMPLATE, init_pop=init_pop,
        max_iter=2000, cross_p=0.4, mut_p=0.6, elit=0.05, alpha=0.025, beta=0.975, patience=10,
    )
    masked_test = masked_acc(model, X_test[..., 0], y_test_idx, best_mask, bs=32)
    selected = np.where(best_DNA > 0)[0].tolist()
    print(f"[{tag}] Masked Test Acc={masked_test:.3f}  Regions={len(selected)}  idx={selected}")
    return scores, best_DNA, best_mask, hist, masked_test, selected


def run_pipeline(outdir="outputs", bg_per_class=20, shap_nsamples=None):
    """Train the CNN, compute Deep SHAP maps and run both GA experiments.

    Steps:
      1. Load data with ``_load_all``.
      2. Train ``build_cnn`` with ``train_model`` and print baseline accuracies.
      3. Compute Deep SHAP maps on X_val. The background is class-balanced, with
         up to ``bg_per_class`` training samples per label.
      4. Run the GA with percentile cuts (30,50,70,100) and (40,60,80,100).
      5. Evaluate the intersection of the regions both runs selected.
      6. Save all artefacts under ``outdir``.

    Labels are assumed to be integer classes in {0, 1}, as implied by
    ``_one_hot(y, 2)`` and the per-class background sampling — confirm for new
    data.

    Returns:
        dict with the trained model, baseline accuracies, each experiment's
        results and the intersection results.

    Raises:
        ValueError: no training sample has label 0 or 1, so no SHAP background
            can be built.
    """
    X_train, y_train, X_val, y_val, X_test, y_test, BMAPS, TEMPLATE = _load_all(globals())
    X_train = _ensure_ch(X_train); X_val = _ensure_ch(X_val); X_test = _ensure_ch(X_test)
    ytr = _one_hot(y_train, 2); yva = _one_hot(y_val, 2); yte = _one_hot(y_test, 2)

    model = build_cnn(X_train.shape[1:], n_classes=2)
    train_model(model, X_train, ytr, X_val, yva, outdir=outdir, bs=32, epochs=100)

    btr, _, _ = evaluate_model(model, X_train, ytr)
    bva, _, _ = evaluate_model(model, X_val, yva)
    bte, _, _ = evaluate_model(model, X_test, yte)
    print(f"[Baseline] Train={btr:.3f}  Val={bva:.3f}  Test={bte:.3f}")

    # Class-balanced background for Deep SHAP.
    Xb = []
    for cls in [0, 1]:
        idx = np.where(y_train == cls)[0]
        take = min(bg_per_class, idx.size)
        if take > 0:
            Xb.append(X_train[np.random.choice(idx, size=take, replace=False)])
    if not Xb:
        raise ValueError("No training samples labelled 0 or 1 to build the Deep SHAP background")
    X_bg = np.concatenate(Xb, 0)

    shap_maps = deep_shap(model, X_bg, X_val, nsamples=shap_nsamples)
    y_test_idx = np.argmax(yte, 1)

    # Experiments 1 and 2 differ only in their percentile cuts; run order is
    # kept so the global RNG is consumed in the same sequence.
    scores1, bestDNA1, bestMask1, hist1, mtest1, sel1 = _run_cut_experiment(
        "Exp1", (30, 50, 70, 100), model, shap_maps, BMAPS, TEMPLATE, X_val, y_val, X_test, y_test_idx
    )
    scores2, bestDNA2, bestMask2, hist2, mtest2, sel2 = _run_cut_experiment(
        "Exp2", (40, 60, 80, 100), model, shap_maps, BMAPS, TEMPLATE, X_val, y_val, X_test, y_test_idx
    )

    # ---------------- Intersection (as reported in manuscript) ----------------
    inter_idx = sorted(set(sel1) & set(sel2))
    inter_DNA = np.zeros_like(bestDNA1)
    inter_DNA[inter_idx] = np.maximum(bestDNA1[inter_idx], bestDNA2[inter_idx])
    inter_mask = chrom_to_mask(BMAPS, inter_DNA, TEMPLATE)
    mtest_inter = masked_acc(model, X_test[..., 0], y_test_idx, inter_mask, bs=32)
    print(f"[Intersection] Masked Test Acc={mtest_inter:.3f}  Regions={len(inter_idx)}  idx={inter_idx}")

    # ---------------- Save all outputs ----------------
    os.makedirs(outdir, exist_ok=True)

    def _save(name, arr):
        np.save(os.path.join(outdir, name), arr)

    _save("baseline_acc.npy", np.array([btr, bva, bte]))
    _save("shap_scores_exp1.npy", np.array(scores1))
    _save("shap_scores_exp2.npy", np.array(scores2))

    _save("best_DNA_exp1.npy", bestDNA1)
    _save("best_mask_exp1.npy", bestMask1.astype(np.uint8))
    _save("ga_history_exp1.npy", _history_array(hist1))

    _save("best_DNA_exp2.npy", bestDNA2)
    _save("best_mask_exp2.npy", bestMask2.astype(np.uint8))
    _save("ga_history_exp2.npy", _history_array(hist2))

    _save("intersect_idx.npy", np.array(inter_idx))
    _save("intersect_DNA.npy", inter_DNA)
    _save("intersect_mask.npy", inter_mask.astype(np.uint8))

    return dict(
        model=model,
        baseline=dict(train=btr, val=bva, test=bte),
        exp1=dict(masked_test_acc=mtest1, selected_regions=sel1, best_DNA=bestDNA1, hist=hist1),
        exp2=dict(masked_test_acc=mtest2, selected_regions=sel2, best_DNA=bestDNA2, hist=hist2),
        intersection=dict(masked_test_acc=mtest_inter, selected_regions=inter_idx, DNA=inter_DNA)
    )

if __name__ == "__main__":
    # Script entry point: train, explain with Deep SHAP and run both GA
    # experiments, writing artefacts to ./outputs.
    results = run_pipeline(
        outdir="outputs",
        bg_per_class=20,
        shap_nsamples=None,
    )
