In [1]:
# Plot saliency maps of the top neurons onto the fish images (intended for fish 11-13 only; NOTE: the loop in main() covers fish 9-13).

In [None]:
#!/usr/bin/env python
# overlay_saliency_spatial.py
# ─────────────────────────────────────────────────────────────────────────────
"""
Creates a saliency overlay for each fish:
 * loads all importance.npy files from experiment_6
 * averages across runs / seq-lengths
 * picks the top-K neurons
 * draws them onto the anatomical plane_0.jpg
 * saves fish{X}_saliency_overlay.png in the experiment_6 results folder
"""
import os, re, ast, numpy as np, pandas as pd, matplotlib.pyplot as plt
from PIL import Image
# ─────────────────────────────────────────────────────────────────────────────
# CONFIG
# ─────────────────────────────────────────────────────────────────────────────
# Results root: searched recursively for importance.npy files (one
# "fish{N}" sub-folder per fish) and also where the overlay PNGs are written.
EXPT_DIR   = "/hpc/group/naumannlab/jjm132/nlp4neuro/results/experiment_6"
# Template for the per-fish image folder, a sibling of EXPT_DIR.  The literal
# "{fish}" placeholder is filled in later with str.format(fish=...).
IMG_ROOT   = os.path.join(os.path.dirname(EXPT_DIR), "fish{fish}_images")
DATA_ROOT  = "/hpc/group/naumannlab/jjm132/data"                   # base raw-data dir
# Number of neurons (highest mean saliency first) drawn on each overlay.
TOP_K      = 20
# Marker appearance, passed to Axes.scatter as c / edgecolors / alpha.
DOT_COLOR  = "red"
DOT_EDGE   = "white"
DOT_ALPHA  = .8
DOT_MIN, DOT_MAX = 50, 300        # scatter marker-size range (matplotlib `s`, i.e. points^2)
# mapping fish → raw-data folder, relative to DATA_ROOT (same as you used to
# build neural_data_matched); each folder is expected to hold
# functional_types_df.h5
RAW_FOLDER = {
     9: "20241104_elavl3rsChrm_h2bg6s_OMR2stim_fish9_omr_stack-002/20241104_elavl3rsChrm_h2bg6s_OMR2stim_fish9_omr_stack-002",
    10: "20241119_elavl3rsChrm_H2bG6s_7dpf_OMR2Stim_fish10_OMR_stack-002/20241119_elavl3rsChrm_H2bG6s_7dpf_OMR2Stim_fish10_OMR_stack-002",
    11: "20241209_elavl3rsChrm_H2bG6s_OMR2Stim_fish11_omr_stack-007/20241209_elavl3rsChrm_H2bG6s_OMR2Stim_fish11_omr_stack-007",
    12: "20241209_elavl3rsChrm_H2bG6s_OMR2Stim_fish12_omr_stack-009/20241209_elavl3rsChrm_H2bG6s_OMR2Stim_fish12_omr_stack-009",
    13: "20241216_elavl3rsChrm_H2bG6s_OMR2Stim_fish13_omr_stack/20241216_elavl3rsChrm_H2bG6s_OMR2Stim_fish13_omr_stack",
}
# ─────────────────────────────────────────────────────────────────────────────
def load_all_importances(fish_id, expt_dir=None):
    """Stack every ``importance.npy`` saved for one fish into a single array.

    The folder ``<expt_dir>/fish<fish_id>`` is searched recursively and every
    file named exactly ``importance.npy`` is loaded.  Files are loaded in
    sorted path order so the row order of the result is reproducible
    (``os.walk`` itself yields directory entries in arbitrary order).

    Parameters
    ----------
    fish_id
        Identifier used to build the ``fish<fish_id>`` sub-folder name.
    expt_dir
        Results root to search.  Defaults to the module-level ``EXPT_DIR``.

    Returns
    -------
    numpy.ndarray
        ``np.vstack`` of all loaded arrays: one row per file when each file
        stores a 1-D vector, i.e. shape ``(n_files, n_neurons)``.

    Raises
    ------
    FileNotFoundError
        If no ``importance.npy`` exists below the fish folder (including the
        case where the folder itself is missing).
    ValueError
        If the files do not all share the same per-row shape.
    """
    if expt_dir is None:
        expt_dir = EXPT_DIR
    fish_dir = os.path.join(expt_dir, f"fish{fish_id}")

    imp_paths = sorted(
        os.path.join(root, name)
        for root, _, files in os.walk(fish_dir)
        for name in files
        if name == "importance.npy"
    )
    if not imp_paths:
        raise FileNotFoundError(f"No importance.npy found for fish {fish_id}")

    # np.vstack promotes inputs with atleast_2d anyway; doing it up front lets
    # us compare per-row shapes and report the offending files by name.
    arrays = [np.atleast_2d(np.load(path)) for path in imp_paths]
    if len({arr.shape[1:] for arr in arrays}) > 1:
        detail = ", ".join(
            f"{path}: {arr.shape}" for path, arr in zip(imp_paths, arrays)
        )
        raise ValueError(
            f"importance.npy shapes differ for fish {fish_id} ({detail})"
        )
    return np.vstack(arrays)            # shape (n_files, n_neurons)

def _parse_coord(entry):
    """Return one coordinate entry as a Python/NumPy sequence.

    Textual entries (``str`` or ``bytes``) are parsed with
    ``ast.literal_eval``; anything else is assumed to be an already-parsed
    sequence and is returned unchanged.
    """
    if isinstance(entry, bytes):
        entry = entry.decode()
    if isinstance(entry, str):
        return ast.literal_eval(entry)
    return entry


def neuron_xy_array(fish_id):
    """Return the neuron coordinates of one fish as a NumPy array.

    Reads ``functional_types_df.h5`` from the fish's raw-data folder
    (``DATA_ROOT / RAW_FOLDER[fish_id]``) and converts its ``neur_coords``
    column into an array with one row per table row.

    Entries stored as text (e.g. ``"[12, 34]"``) are parsed with
    ``ast.literal_eval``; entries that are already lists/tuples/arrays are
    used as they are, so the function works regardless of how the column was
    serialised.

    Parameters
    ----------
    fish_id
        Key into ``RAW_FOLDER``.

    Returns
    -------
    numpy.ndarray
        ``np.array`` of the parsed coordinates; ``(n_neurons, 2)`` when every
        entry is an ``[x, y]`` pair.

    Raises
    ------
    KeyError
        If ``fish_id`` is not in ``RAW_FOLDER`` or the table has no
        ``neur_coords`` column.
    FileNotFoundError
        If the HDF5 file does not exist.
    """
    raw_dir = os.path.join(DATA_ROOT, RAW_FOLDER[fish_id])
    h5_path = os.path.join(raw_dir, "functional_types_df.h5")
    if not os.path.exists(h5_path):
        raise FileNotFoundError(f"{h5_path} not found")

    df = pd.read_hdf(h5_path)           # expects 'neur_coords' column
    try:
        column = df["neur_coords"]
    except KeyError as err:
        raise KeyError(f"'neur_coords' column missing in {h5_path}") from err

    coords = [_parse_coord(entry) for entry in column]   # list of [x,y]
    return np.array(coords)             # (n_neurons, 2)

def _dot_sizes(values):
    """Linearly map ``values`` onto the ``[DOT_MIN, DOT_MAX]`` marker-size range."""
    # np.ptp() instead of ndarray.ptp(): the method was removed in NumPy 2.0.
    span = np.ptp(values)
    if span == 0:
        # All values identical: avoid 0/0 and draw every dot at minimum size.
        return np.full(values.shape, float(DOT_MIN))
    # Exact normalisation (no additive epsilon) so that very small saliency
    # ranges still use the full size range.
    return (values - values.min()) / span * (DOT_MAX - DOT_MIN) + DOT_MIN


def main(fish_ids=(9, 10, 11, 12, 13)):
    """Render and save a top-K saliency overlay for each requested fish.

    For every fish id this:

    1. loads all importance arrays and averages them over the first axis,
    2. loads the neuron coordinates and checks the two lengths agree,
    3. selects the ``TOP_K`` neurons with the highest mean importance,
    4. draws them (numbered by rank, dot size scaled by importance) on top of
       ``plane_0.jpg`` from the fish's image folder,
    5. writes ``fish{N}_saliency_overlay.png`` into ``EXPT_DIR/fish{N}``.

    Parameters
    ----------
    fish_ids
        Iterable of fish identifiers to process; defaults to fish 9-13.

    Raises
    ------
    ValueError
        If the number of coordinates differs from the number of importance
        values for a fish.
    FileNotFoundError
        If ``plane_0.jpg`` is missing (or propagated from the loaders).
    """
    for fish in fish_ids:
        print(f"Fish {fish}: loading data …")
        imps   = load_all_importances(fish)        # (n_files, n_neurons)
        imp_mu = imps.mean(axis=0)                 # mean saliency per neuron

        coords = neuron_xy_array(fish)             # (n_neurons, 2)
        # Explicit raise rather than assert: asserts vanish under `python -O`.
        if len(coords) != len(imp_mu):
            raise ValueError(
                f"coords and saliency length mismatch for fish {fish}: "
                f"{len(coords)} vs {len(imp_mu)}"
            )

        # pick top-K, most important first
        idx_topk = np.argsort(imp_mu)[-TOP_K:][::-1]
        xy_topk  = coords[idx_topk]
        imp_topk = imp_mu[idx_topk]

        # scale dot size
        s = _dot_sizes(imp_topk)

        # anatomical image
        img_path = os.path.join(IMG_ROOT.format(fish=fish), "plane_0.jpg")
        if not os.path.exists(img_path):
            raise FileNotFoundError(img_path)

        save_to = os.path.join(EXPT_DIR, f"fish{fish}",
                               f"fish{fish}_saliency_overlay.png")

        # plot
        fig, ax = plt.subplots(figsize=(6, 6))
        try:
            # imshow copies the pixel data immediately, so the file handle can
            # be released as soon as the call returns.
            with Image.open(img_path) as img:
                ax.imshow(img)
            ax.scatter(xy_topk[:, 0], xy_topk[:, 1], s=s, c=DOT_COLOR,
                       edgecolors=DOT_EDGE, alpha=DOT_ALPHA)
            # label each dot with its rank (1 = highest mean saliency)
            for rank, (x, y) in enumerate(xy_topk, start=1):
                ax.text(x, y, str(rank), color="white", fontsize=8,
                        ha="center", va="center", weight="bold")
            ax.set_title(f"Fish {fish} – Top-{TOP_K} neurons (plane 0)")
            ax.axis("off")

            fig.tight_layout()
            fig.savefig(save_to, dpi=300)
        finally:
            # Always release the figure, even if drawing or saving failed.
            plt.close(fig)
        print(f"  → saved overlay to {save_to}")

# Generate the overlays only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
