In [None]:
DATA

# Figure 1c

In [1]:
import os
import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import sem

# --- Figure 1c configuration ------------------------------------------------
# Which fish / sequence length to plot, and how many training runs to average.
fish, seq = 11, 10
use_components = list(range(6))
num_runs = 10

# No model overlays: with an empty list, model_map is empty and only the
# ground truth is drawn.
models_to_use = []
panel_plot = True  # True -> one panel per component; False -> summed trace

# Figure geometry (matplotlib figsize units, inches).
width = 3.0
height_per_panel = 0.8

# Optional (start, end) frame crop; None keeps every frame.
frame_range = None

RESULTS_ROOT = os.path.abspath(
    os.path.join(os.getcwd(), os.pardir, os.pardir, "results", "experiment_1")
)

# Display name -> (file-name key, plot colour).
full_model_map = dict([
    ("GPT-2 (base)", ("gpt2_pretrained", "#1f77b4")),
    ("LSTM", ("lstm", "#ff7f0e")),
    ("DeepSeek-coder-7b", ("deepseek_moe", "#000000")),
    ("BERT-small", ("bert", "#7f7f7f")),
])
model_map = dict((name, full_model_map[name]) for name in models_to_use)

# Sampling rate in frames per second: frame index / rate gives the seconds
# shown on the "Time (s)" axis. Same value the Figure 2c cell uses.
sampling_rate_hz = 1.1

gt_path = os.path.join(
    RESULTS_ROOT, f"fish{fish}",
    f"fish{fish}_final_predictions_groundtruth_test.npy"
)
y_true = np.load(gt_path)
time_vec = np.arange(y_true.shape[0]) / sampling_rate_hz

# Optional crop to frames [start, end); `start`/`end` are reused when the
# model predictions are loaded so both are cropped identically.
if frame_range is not None:
    start, end = frame_range
    y_true = y_true[start:end]
    time_vec = time_vec[start:end]

plt.rcParams["pdf.fonttype"] = 42  # embed TrueType fonts so PDF text stays editable


def _collect_runs(key, project):
    """Load every available run of one model and reduce each to a 1-D trace.

    Parameters
    ----------
    key : str
        File-name key of the model (first element of a ``model_map`` value).
    project : callable
        Applied to each loaded prediction array (already cropped to
        ``frame_range`` when that is set); returns the trace to average.

    Returns
    -------
    list of numpy.ndarray
        One trace per run file that exists. Missing runs are skipped, so the
        list may be empty.
    """
    traces = []
    for run in range(1, num_runs + 1):
        fpath = os.path.join(
            RESULTS_ROOT,
            f"fish{fish}", f"run_{run}",
            f"seq_{seq}",
            f"fish{fish}_final_predictions_{key}_test_run{run}.npy"
        )
        if not os.path.isfile(fpath):
            continue
        pred = np.load(fpath)
        if frame_range is not None:
            pred = pred[start:end]
        traces.append(project(pred))
    return traces


def _plot_mean_sem(ax, traces, color, label):
    """Draw the across-run mean of ``traces`` on ``ax`` with a +/-1 SEM band."""
    data = np.vstack(traces)
    mean_ts = data.mean(axis=0)
    sem_ts = sem(data, axis=0)
    ax.plot(time_vec, mean_ts, color=color, lw=1.2, label=label)
    ax.fill_between(
        time_vec, mean_ts - sem_ts, mean_ts + sem_ts,
        color=color, alpha=0.2
    )


if panel_plot:
    # Per-component baseline: mean of the first 50 (possibly cropped)
    # ground-truth frames. Predictions are shifted by the same baseline.
    baselines = {comp: np.mean(y_true[:50, comp]) for comp in use_components}
    n = len(use_components)
    fig, axes = plt.subplots(
        n, 1, sharex=True,
        figsize=(width, height_per_panel * n),
        dpi=300,
        squeeze=False,
    )
    axes = axes[:, 0]  # 1-D array of Axes even when a single component is used

    for idx, comp in enumerate(use_components):
        ax = axes[idx]
        gt_label = "Ground truth" if idx == 0 else None
        y_gt = y_true[:, comp] - baselines[comp]
        ax.plot(time_vec, y_gt, color="red", lw=1.5, label=gt_label)

        for model_name, (key, color) in model_map.items():
            runs = _collect_runs(key, lambda pred: pred[:, comp] - baselines[comp])
            if not runs:
                continue
            # Label only the first panel so the shared legend has one entry per model.
            _plot_mean_sem(ax, runs, color, model_name if idx == 0 else None)

        ax.set_ylabel(rf"$\Delta\theta_{{{comp}}}$")
        ax.spines["top"].set_visible(False)
        ax.spines["right"].set_visible(False)

    # Shared legend placed just above the figure; savefig(bbox_inches="tight")
    # expands the saved area to include it.
    handles, labels = axes[0].get_legend_handles_labels()
    fig.legend(
        handles, labels,
        ncol=len(handles),
        frameon=False,
        loc="lower center",
        bbox_to_anchor=(0.5, 1.0),
        bbox_transform=fig.transFigure
    )
    axes[-1].set_xlabel("Time (s)")
    fig.tight_layout()
    out_name = f"tail_panels_fish{fish}_s{seq}_baseline.pdf"

else:
    # Sum of the selected components, shifted by its own first-50-frame baseline.
    y_sum_full = y_true[:, use_components].sum(axis=1)
    baseline_sum = np.mean(y_sum_full[:50])
    y_sum_shifted = y_sum_full - baseline_sum

    fig, ax = plt.subplots(figsize=(width, height_per_panel), dpi=300)
    ax.plot(time_vec, y_sum_shifted, color="red", lw=1.5, label="Ground truth")

    for model_name, (key, color) in model_map.items():
        runs = _collect_runs(
            key, lambda pred: pred[:, use_components].sum(axis=1) - baseline_sum
        )
        if not runs:
            continue
        _plot_mean_sem(ax, runs, color, model_name)

    ax.set_xlabel("Time (s)")
    ax.set_ylabel(r"$\Delta \sum_{i}\theta_i$")
    ax.spines["top"].set_visible(False)
    ax.spines["right"].set_visible(False)
    ax.legend(ncol=1, frameon=False, loc="upper right")
    fig.tight_layout()
    out_name = f"tail_sum_fish{fish}_s{seq}_baseline.pdf"

out_file = os.path.join(RESULTS_ROOT, out_name)
fig.savefig(out_file, dpi=300, bbox_inches="tight")
print("Saved to", out_file)


Exception ignored in: <bound method IPythonKernel._clean_thread_parent_frames of <ipykernel.ipkernel.IPythonKernel object at 0x000002816DA0D160>>
Traceback (most recent call last):
  File "C:\Users\jacob\AppData\Roaming\jupyterlab-desktop\envs\fconn-holo\Lib\site-packages\ipykernel\ipkernel.py", line 790, in _clean_thread_parent_frames
    active_threads = {thread.ident for thread in threading.enumerate()}
  File "C:\Users\jacob\AppData\Roaming\jupyterlab-desktop\envs\fconn-holo\Lib\threading.py", line 1477, in enumerate
    def enumerate():
KeyboardInterrupt: 

KeyboardInterrupt



## Figure 2a

## Figure 2b

## Figure 2c

In [None]:
import os
import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import sem, t
from matplotlib.lines import Line2D

# --- Figure 2c configuration ------------------------------------------------
fish, seq = 11, 10
use_components = list(range(6))
num_runs = 10

# Display names to overlay; each must be a key of full_model_map below.
models_to_use = [
    "GPT-2 (base)",
    "LSTM",
    "DeepSeek-coder-7b",
    "BERT-small",
    "Reservoir computer",
]
panel_plot = True  # True -> one panel per component; False -> summed trace

# True -> one legend listing every model plus ground truth;
# False -> models-only legend (no reservoir) and a separate ground-truth legend.
reservoir_in_legend = True
single_run = None  # set to a run number to load only that run

# Figure geometry (inches) and time axis.
width, height_per_panel = 3, 1
frame_range = None        # optional (start, end) frame crop
sampling_rate_hz = 1.1    # frames per second

RESULTS_ROOT = os.path.abspath(
    os.path.join(os.getcwd(), os.pardir, os.pardir, "results", "experiment_1")
)

# Display name -> (file-name key, plot colour).
full_model_map = dict([
    ("GPT-2 (base)", ("gpt2_pretrained", "#1f77b4")),
    ("LSTM", ("lstm", "#ff7f0e")),
    ("DeepSeek-coder-7b", ("deepseek_moe", "#000000")),
    ("BERT-small", ("bert", "#7f7f7f")),
    ("Reservoir computer", ("reservoir", "purple")),
])
model_map = dict((name, full_model_map[name]) for name in models_to_use)

# Ground-truth test traces; rows are frames (time axis below).
gt_file = f"fish{fish}_final_predictions_groundtruth_test.npy"
gt_path = os.path.join(RESULTS_ROOT, f"fish{fish}", gt_file)
y_true = np.load(gt_path)
time_vec = np.arange(len(y_true)) / sampling_rate_hz  # frame index -> seconds

if frame_range is not None:
    # Crop to frames [start, end); predictions are cropped with the same bounds.
    start, end = frame_range
    y_true, time_vec = y_true[start:end], time_vec[start:end]

plt.rcParams["pdf.fonttype"] = 42  # editable TrueType text in PDFs

def get_run_list():
    """Return the run numbers to load.

    Returns ``[single_run]`` when a single run is selected, otherwise
    ``range(1, num_runs + 1)`` (both read from module-level config).
    """
    if single_run is None:
        return range(1, num_runs + 1)
    return [single_run]

def _collect_model_runs(key, project):
    """Load every available run of one model and reduce each to a 1-D trace.

    Parameters
    ----------
    key : str
        File-name key of the model (first element of a ``model_map`` value).
    project : callable
        Applied to each loaded prediction array (already cropped to
        ``frame_range`` when that is set); returns the trace to average.

    Returns
    -------
    list of numpy.ndarray
        One trace per run (from ``get_run_list()``) whose file exists.
        Missing runs are skipped, so the list may be empty.
    """
    traces = []
    for run in get_run_list():
        fpath = os.path.join(
            RESULTS_ROOT,
            f"fish{fish}", f"run_{run}",
            f"seq_{seq}",
            f"fish{fish}_final_predictions_{key}_test_run{run}.npy"
        )
        if not os.path.isfile(fpath):
            continue
        pred = np.load(fpath)
        if frame_range is not None:
            pred = pred[start:end]
        traces.append(project(pred))
    return traces


def _mean_ci(traces):
    """Return (mean, half-width of the two-sided 95% t-based CI) across runs.

    With a single run there is no spread to estimate, so the half-width is
    zero (``sem`` would otherwise return NaN and emit a warning).
    """
    data = np.vstack(traces)
    mean_ts = data.mean(axis=0)
    n_runs = data.shape[0]
    if n_runs < 2:
        return mean_ts, np.zeros_like(mean_ts)
    return mean_ts, t.ppf(0.975, df=n_runs - 1) * sem(data, axis=0)


if panel_plot:
    # Per-component baseline: mean of the first 50 (possibly cropped)
    # ground-truth frames; predictions are shifted by the same baseline.
    baselines = {c: np.mean(y_true[:50, c]) for c in use_components}
    n = len(use_components)
    fig, axes = plt.subplots(
        n, 1, sharex=True,
        figsize=(width, height_per_panel * n),
        dpi=300,
        squeeze=False,
    )
    axes = axes[:, 0]  # 1-D array of Axes even when a single component is used

    for idx, comp in enumerate(use_components):
        ax = axes[idx]
        gt_label = "Ground truth" if idx == 0 else None
        y_gt = y_true[:, comp] - baselines[comp]
        ax.plot(time_vec, y_gt, color="red", lw=0.5, label=gt_label)

        for model_name, (key, color) in model_map.items():
            # The reservoir computer is drawn on the first panel only.
            if model_name == "Reservoir computer" and idx != 0:
                continue

            runs = _collect_model_runs(
                key, lambda pred: pred[:, comp] - baselines[comp]
            )
            if not runs:
                continue

            mean_ts, ci = _mean_ci(runs)
            mlabel = model_name if idx == 0 else None
            ax.plot(time_vec, mean_ts, color=color, lw=0.5, label=mlabel)
            ax.fill_between(
                time_vec, mean_ts - ci, mean_ts + ci,
                color=color, alpha=0.2
            )

        ax.set_ylabel(rf"$\Delta\theta_{{{comp}}}$")
        ax.spines["top"].set_visible(False)
        ax.spines["right"].set_visible(False)

    handles, labels = axes[0].get_legend_handles_labels()
    hdict = dict(zip(labels, handles))

    # Only models that had run files were plotted (and labelled), so filter the
    # requested order instead of raising KeyError for a missing model.
    if reservoir_in_legend:
        ordered = [
            name for name in (
                "LSTM",
                "BERT-small",
                "GPT-2 (base)",
                "DeepSeek-coder-7b",
                "Reservoir computer",
                "Ground truth",
            )
            if name in hdict
        ]
        fig.legend(
            [hdict[name] for name in ordered], ordered,
            ncol=2,
            frameon=False,
            loc="lower center",
            bbox_to_anchor=(0.5, 1.00),
            bbox_transform=fig.transFigure
        )
    else:
        model_order = [
            name for name in ("LSTM", "BERT-small", "GPT-2 (base)", "DeepSeek-coder-7b")
            if name in hdict
        ]
        fig.legend(
            [hdict[name] for name in model_order], model_order,
            ncol=2,
            frameon=False,
            loc="lower center",
            bbox_to_anchor=(0.5, 1.00),
            bbox_transform=fig.transFigure
        )
        # Ground truth gets its own legend just below the model legend.
        fig.legend(
            [hdict["Ground truth"]], ["Ground truth"],
            ncol=1,
            frameon=False,
            loc="lower center",
            bbox_to_anchor=(0.5, 0.95),
            bbox_transform=fig.transFigure
        )

    axes[-1].set_xlabel("Time (s)")
    fig.tight_layout()
    out_name = f"tail_panels_fish{fish}_s{seq}_baseline.pdf"

else:
    # Sum of the selected components, shifted by its own first-50-frame baseline.
    y_sum_full = y_true[:, use_components].sum(axis=1)
    baseline_sum = np.mean(y_sum_full[:50])
    y_sum_shifted = y_sum_full - baseline_sum

    fig, ax = plt.subplots(figsize=(width, height_per_panel), dpi=300)
    ax.plot(
        time_vec, y_sum_shifted,
        color="red", lw=1.0, label="Ground truth"
    )

    for model_name, (key, color) in model_map.items():
        runs = _collect_model_runs(
            key, lambda pred: pred[:, use_components].sum(axis=1) - baseline_sum
        )
        if not runs:
            continue

        mean_ts, ci = _mean_ci(runs)
        ax.plot(
            time_vec, mean_ts,
            color=color, lw=0.8, label=model_name
        )
        ax.fill_between(
            time_vec, mean_ts - ci, mean_ts + ci,
            color=color, alpha=0.2
        )

    ax.set_xlabel("Time (s)")
    ax.set_ylabel(r"$\Delta \sum_{i}\theta_i$")
    ax.spines["top"].set_visible(False)
    ax.spines["right"].set_visible(False)
    ax.legend(ncol=2, frameon=False, loc="upper right")

    fig.tight_layout()
    out_name = f"tail_sum_fish{fish}_s{seq}_baseline.pdf"

out_file = os.path.join(RESULTS_ROOT, out_name)
fig.savefig(out_file, dpi=300, bbox_inches="tight")
print("Saved to", out_file)



# --- Figure 2c variant: the same traces with event windows shaded -----------
width = 6
height_per_panel = 1
frame_range = None
sampling_rate_hz = 1.1

# Event windows as (start, end) in seconds on the same axis as `time_vec`;
# drawn as shaded spans behind the traces.
EVENTS = [
    (18.3, 24.7), (45.0, 52.5), (67.2, 75.0),
    (115, 120), (135, 140), (162, 167),
    (169, 174), (195, 205), (210, 220),
    (240, 248), (255, 265), (300, 310),
    (327, 342), (368, 375), (408, 423),
    (449, 465), (490, 500),
]
EVENT_COLOR = "gray"
EVENT_ALPHA = 0.12

RESULTS_ROOT = os.path.abspath(
    os.path.join(
        os.getcwd(), os.pardir, os.pardir, "results", "experiment_1"
    )
)

# Short display names for this variant -> (file-name key, plot colour).
full_model_map = {
    "GPT-2":        ("gpt2_pretrained", "#1f77b4"),
    "LSTM":         ("lstm",            "#ff7f0e"),
    "DeepSeek-c7b": ("deepseek_moe",    "#000000"),
    "BERT-bu":      ("bert",            "#7f7f7f"),
    "RC":           ("reservoir",       "purple"),
}
# Re-declare the model list with the short names used above: the earlier
# Figure 2c list uses long names ("GPT-2 (base)", ...), which are not keys of
# this map and would raise KeyError below.
models_to_use = ["GPT-2", "LSTM", "DeepSeek-c7b", "BERT-bu", "RC"]
model_map = {m: full_model_map[m] for m in models_to_use}

# Ground-truth test traces; rows are frames (time axis below).
gt_file = f"fish{fish}_final_predictions_groundtruth_test.npy"
gt_path = os.path.join(RESULTS_ROOT, f"fish{fish}", gt_file)
y_true = np.load(gt_path)
time_vec = np.arange(len(y_true)) / sampling_rate_hz  # frame index -> seconds

if frame_range is not None:
    # Crop to frames [start, end); predictions are cropped with the same bounds.
    start, end = frame_range
    y_true, time_vec = y_true[start:end], time_vec[start:end]

plt.rcParams["pdf.fonttype"] = 42  # editable TrueType text in PDFs

def get_run_list():
    """Return the run numbers to load.

    Returns ``[single_run]`` when a single run is selected, otherwise
    ``range(1, num_runs + 1)`` (both read from module-level config).
    NOTE(review): identical to the helper defined earlier in this cell.
    """
    if single_run is None:
        return range(1, num_runs + 1)
    return [single_run]

def _collect_model_runs(key, project):
    """Load every available run of one model and reduce each to a 1-D trace.

    Parameters
    ----------
    key : str
        File-name key of the model (first element of a ``model_map`` value).
    project : callable
        Applied to each loaded prediction array (already cropped to
        ``frame_range`` when that is set); returns the trace to average.

    Returns
    -------
    list of numpy.ndarray
        One trace per run (from ``get_run_list()``) whose file exists.
        Missing runs are skipped, so the list may be empty.
    """
    traces = []
    for run in get_run_list():
        fpath = os.path.join(
            RESULTS_ROOT,
            f"fish{fish}", f"run_{run}",
            f"seq_{seq}",
            f"fish{fish}_final_predictions_{key}_test_run{run}.npy"
        )
        if not os.path.isfile(fpath):
            continue
        pred = np.load(fpath)
        if frame_range is not None:
            pred = pred[start:end]
        traces.append(project(pred))
    return traces


def _mean_ci(traces):
    """Return (mean, half-width of the two-sided 95% t-based CI) across runs.

    With a single run there is no spread to estimate, so the half-width is
    zero (``sem`` would otherwise return NaN and emit a warning).
    """
    data = np.vstack(traces)
    mean_ts = data.mean(axis=0)
    n_runs = data.shape[0]
    if n_runs < 2:
        return mean_ts, np.zeros_like(mean_ts)
    return mean_ts, t.ppf(0.975, df=n_runs - 1) * sem(data, axis=0)


def _shade_events(ax):
    """Shade every (start_s, end_s) window in EVENTS behind the data on ``ax``."""
    for start_s, end_s in EVENTS:
        ax.axvspan(start_s, end_s, color=EVENT_COLOR, alpha=EVENT_ALPHA, zorder=0)


if panel_plot:
    # Per-component baseline: mean of the first 50 (possibly cropped)
    # ground-truth frames; predictions are shifted by the same baseline.
    baselines = {c: np.mean(y_true[:50, c]) for c in use_components}
    n_panels = len(use_components)
    fig, axes = plt.subplots(
        n_panels, 1,
        sharex=True,
        figsize=(width, height_per_panel * n_panels),
        dpi=300,
        squeeze=False,
    )
    axes = axes[:, 0]  # 1-D array of Axes even when a single component is used

    for idx, comp in enumerate(use_components):
        ax = axes[idx]
        _shade_events(ax)

        gt_label = "Ground truth" if idx == 0 else None
        y_gt = y_true[:, comp] - baselines[comp]
        ax.plot(time_vec, y_gt, color="red", lw=0.5, label=gt_label)

        for model_name, (key, color) in model_map.items():
            # The reservoir computer is drawn on the first panel only.
            if model_name == "RC" and idx != 0:
                continue

            runs = _collect_model_runs(
                key, lambda pred: pred[:, comp] - baselines[comp]
            )
            if not runs:
                continue

            mean_ts, ci = _mean_ci(runs)
            mlabel = model_name if idx == 0 else None
            ax.plot(time_vec, mean_ts, color=color, lw=0.5, label=mlabel)
            ax.fill_between(
                time_vec, mean_ts - ci, mean_ts + ci,
                color=color, alpha=0.2
            )

        ax.set_ylabel(rf"$\Delta\theta_{{{comp}}}$")
        ax.spines["top"].set_visible(False)
        ax.spines["right"].set_visible(False)

    handles, labels = axes[0].get_legend_handles_labels()
    hdict = dict(zip(labels, handles))

    # Only models that had run files were plotted (and labelled), so filter the
    # requested order instead of raising KeyError for a missing model.
    if reservoir_in_legend:
        ordered = [
            name for name in (
                "LSTM",
                "BERT-bu",
                "GPT-2",
                "DeepSeek-c7b",
                "RC",
                "Ground truth",
            )
            if name in hdict
        ]
        fig.legend(
            [hdict[name] for name in ordered], ordered,
            ncol=3, frameon=False,
            loc="lower center",
            bbox_to_anchor=(0.5, 1.00),
            bbox_transform=fig.transFigure
        )
    else:
        model_order = [
            name for name in ("LSTM", "BERT-bu", "GPT-2", "DeepSeek-c7b", "RC")
            if name in hdict
        ]
        fig.legend(
            [hdict[name] for name in model_order], model_order,
            ncol=3, frameon=False,
            loc="lower center",
            bbox_to_anchor=(0.5, 1.00),
            bbox_transform=fig.transFigure
        )
        # Ground truth gets its own legend just below the model legend.
        fig.legend(
            [hdict["Ground truth"]], ["Ground truth"],
            ncol=1, frameon=False,
            loc="lower center",
            bbox_to_anchor=(0.5, 0.95),
            bbox_transform=fig.transFigure
        )

    axes[-1].set_xlabel("Time (s)")
    fig.tight_layout()
    # Distinct name (like the summed variant below) so this event-shaded figure
    # does not overwrite the un-shaded panel PDF saved earlier in this cell.
    out_name = f"tail_panels_fish{fish}_s{seq}_baseline2.pdf"

else:
    # Sum of the selected components, shifted by its own first-50-frame baseline.
    y_sum_full = y_true[:, use_components].sum(axis=1)
    baseline_sum = np.mean(y_sum_full[:50])
    y_sum_shifted = y_sum_full - baseline_sum

    fig, ax = plt.subplots(figsize=(width, height_per_panel), dpi=300)
    _shade_events(ax)

    ax.plot(time_vec, y_sum_shifted, color="red", lw=1.0, label="Ground truth")

    for model_name, (key, color) in model_map.items():
        runs = _collect_model_runs(
            key, lambda pred: pred[:, use_components].sum(axis=1) - baseline_sum
        )
        if not runs:
            continue

        mean_ts, ci = _mean_ci(runs)
        ax.plot(time_vec, mean_ts, color=color, lw=0.8, label=model_name)
        ax.fill_between(
            time_vec, mean_ts - ci, mean_ts + ci,
            color=color, alpha=0.2
        )

    ax.set_xlabel("Time (s)")
    ax.set_ylabel(r"$\Delta \sum_{i}\theta_i$")
    ax.spines["top"].set_visible(False)
    ax.spines["right"].set_visible(False)
    ax.legend(ncol=3, frameon=False, loc="upper right")

    fig.tight_layout()
    out_name = f"tail_sum_fish{fish}_s{seq}_baseline2.pdf"

out_file = os.path.join(RESULTS_ROOT, out_name)
fig.savefig(out_file, dpi=300, bbox_inches="tight")
print("Saved to", out_file)

In [None]:
fish = 11

# After the experiment-4b results have been saved, BASE_DIR must point at the
# folder that holds them, and img_dir at the folder with this fish's images.
_project_root = os.path.join(os.getcwd(), os.pardir, os.pardir)
BASE_DIR = os.path.abspath(os.path.join(_project_root, "results", "experiment_4b"))
img_dir = os.path.join(
    _project_root, "exp1-4_data", "data_prepped_for_models", f"fish{fish}_images"
)

In [None]:
import os
import glob
import tifffile
import numpy as np
import matplotlib.pyplot as plt

# Scale-bar geometry. It must be defined in this cell: the only other
# definition is in a later cell, so relying on it fails on Restart & Run All.
SCALE_BAR_UM = 50     # scale-bar length (µm)
UM_PER_PIXEL = 0.6    # image calibration (µm per pixel)

# Pick the plane-0 image deterministically and fail with a clear message if
# it is missing (rather than an IndexError on an empty glob result).
plane0_candidates = sorted(glob.glob(os.path.join(img_dir, "plane_0.*")))
if not plane0_candidates:
    raise FileNotFoundError(f"No plane_0.* image found in {img_dir}")
plane0_file = plane0_candidates[0]

# Min-max normalise to [0, 1], brighten by 1.5x, then clip back into [0, 1].
# A constant image has no contrast to stretch, so avoid 0/0 -> NaN.
im = tifffile.imread(plane0_file).astype(float)
im_min, im_max = im.min(), im.max()
im = (im - im_min) / (im_max - im_min) if im_max > im_min else np.zeros_like(im)
im = np.clip(im * 1.5, 0, 1)

fig, ax = plt.subplots(figsize=(6, 6), dpi=300)
ax.imshow(im, cmap="viridis")

bar_len_px = SCALE_BAR_UM / UM_PER_PIXEL
# Bar anchor in axes-fraction coordinates: (x0, 1 - y0) = (0.02, 0.05).
x0, y0 = 0.02, 0.95

# convert axes fraction to data coordinates
xdata0, ydata0 = ax.transAxes.transform((x0, 1 - y0))
xdata0, ydata0 = ax.transData.inverted().transform((xdata0, ydata0))

# draw scale bar
ax.hlines(
    y=ydata0,
    xmin=xdata0,
    xmax=xdata0 + bar_len_px,
    colors="white",
    linewidth=3,
    transform=ax.transData,
    clip_on=False
)
ax.text(
    xdata0 + bar_len_px / 2,
    ydata0 - bar_len_px * 0.3,
    f"{SCALE_BAR_UM} µm",
    color="white",
    ha="center",
    va="top",
    fontsize=12
)

# label
ax.text(
    0.01,
    0.99,
    "GCaMP6s",
    transform=ax.transAxes,
    va="top",
    ha="left",
    color="white",
    weight="bold",
    fontsize=16
)

ax.axis("off")

out = os.path.join(BASE_DIR, f"fish{fish}_calcium.pdf")
fig.savefig(out, bbox_inches="tight", transparent=True)
plt.close(fig)
print("saved", out)


In [None]:
import os ,glob ,ast ,numpy as np ,pandas as pd ,matplotlib .pyplot as plt ,tifffile 

fish = 11
TOP_K = 30       # number of most-salient neurons to mark
DOT_MIN = 50     # marker area for the least salient of the TOP_K
DOT_MAX = 300    # marker area for the most salient

def bright_plane0(f, boost=1.5):
    """Load the plane-0 image from ``img_dir`` and return a brightened copy.

    Parameters
    ----------
    f : int
        Fish id. NOTE(review): unused — the image is always read from the
        module-level ``img_dir``, which the config cell builds for one fish.
    boost : float, optional
        Gain applied after min-max normalisation. The result is clipped to
        [0, 1], so ``boost > 1`` saturates the brightest pixels.

    Returns
    -------
    numpy.ndarray
        Float image with values in [0, 1].

    Raises
    ------
    FileNotFoundError
        If no ``plane_0.*`` file exists in ``img_dir``.
    """
    # Sorted so the choice is deterministic if several extensions match.
    matches = sorted(glob.glob(os.path.join(img_dir, "plane_0.*")))
    if not matches:
        raise FileNotFoundError(f"No plane_0.* image found in {img_dir}")
    im = tifffile.imread(matches[0]).astype(float)
    lo, hi = im.min(), im.max()
    # A constant image has no contrast to stretch; avoid 0/0 -> NaN.
    im = (im - lo) / (hi - lo) if hi > lo else np.zeros_like(im)
    return np.clip(im * boost, 0, 1)

def load_saliency(f, base_dir=None):
    """Average every ``importance.npy`` found under ``<base_dir>/fish<f>``.

    Parameters
    ----------
    f : int
        Fish id; selects the ``fish<f>`` sub-folder.
    base_dir : str, optional
        Results root to search. Defaults to the module-level ``BASE_DIR``.

    Returns
    -------
    numpy.ndarray
        Element-wise mean of all loaded arrays (stacked with ``np.vstack``).

    Raises
    ------
    FileNotFoundError
        If no ``importance.npy`` exists anywhere below the fish folder.
    """
    root = os.path.join(BASE_DIR if base_dir is None else base_dir, f"fish{f}")
    print(root)
    vecs = []
    for dirpath, dirnames, filenames in os.walk(root):
        # Walk sub-folders in sorted order so the float sum (and thus the mean)
        # is reproducible regardless of filesystem listing order.
        dirnames.sort()
        if "importance.npy" in filenames:
            vecs.append(np.load(os.path.join(dirpath, "importance.npy")))
    if not vecs:
        raise FileNotFoundError(
            f"No importance.npy files under {root}. "
            "Check the folder structure or BASE_DIR."
        )
    return np.vstack(vecs).mean(0)

def plane0_coords(f):
    """Return plane-0 neuron coordinates and their row indices for fish ``f``.

    Reads ``functional_types_df.h5``, keeps rows whose ``plane`` is
    "plane_0", parses ``neur_coords`` (``ast.literal_eval`` for string cells)
    and stacks them into an array (column 0 is drawn as x, column 1 as y by
    the overlay plot). The second return value is the table index as int.

    NOTE(review): the path resolves to ``<folder containing results>/fish{f}_images/``,
    not to ``img_dir`` (``.../exp1-4_data/data_prepped_for_models/fish{f}_images``).
    Confirm which location actually holds the table.
    """
    table_path = os.path.normpath(
        os.path.join(
            os.path.dirname(BASE_DIR), os.pardir,
            f"fish{f}_images", "functional_types_df.h5",
        )
    )
    table = pd.read_hdf(table_path)
    plane0 = table[table.plane == "plane_0"]

    def _parse(cell):
        return ast.literal_eval(cell) if isinstance(cell, str) else cell

    coords = np.vstack(plane0.neur_coords.apply(_parse))
    return coords, plane0.index.values.astype(int)

sal = load_saliency(fish)
coords, idx = plane0_coords(fish)
sal0 = sal[idx]  # saliency of the plane-0 neurons, row-aligned with `coords`
if sal0.size == 0:
    raise ValueError(f"No plane_0 neurons found for fish {fish}.")

# Positions (into sal0 / coords) of the TOP_K most salient neurons, highest first.
top = np.argsort(sal0)[-TOP_K:][::-1]
arr = sal0[top]
# Map saliency linearly onto marker area [DOT_MIN, DOT_MAX]; the epsilon
# avoids 0/0 when all selected values are equal.
sizes = (arr - arr.min()) / (np.ptp(arr) + 1e-9)
sizes = sizes * (DOT_MAX - DOT_MIN) + DOT_MIN

bg = bright_plane0(fish)
# dpi=300 like the other image figures: savefig falls back to the figure dpi
# (by default), which sets the resolution of the background raster in the PDF.
fig, ax = plt.subplots(figsize=(6, 6), dpi=300)
ax.imshow(bg, cmap="gray", vmin=0, vmax=1)
ax.scatter(coords[top, 0], coords[top, 1],
           s=sizes, c="gray", edgecolors="white", alpha=.8)

ax.text(0.01, 0.99, f"Fish {fish}", transform=ax.transAxes,
        va="top", ha="left", color="white", weight="bold", fontsize=16)

# Scale bar anchored near the lower-left corner (axes fraction (0.02, 0.05)).
SCALE_BAR_UM = 50
UM_PER_PIXEL = 0.6
bar_len_px = SCALE_BAR_UM / UM_PER_PIXEL
x0, y0 = 0.02, 0.95

xdata0, ydata0 = ax.transAxes.transform((x0, 1 - y0))
xdata0, ydata0 = ax.transData.inverted().transform((xdata0, ydata0))
ax.hlines(y=ydata0, xmin=xdata0, xmax=xdata0 + bar_len_px,
          colors="white", linewidth=3, transform=ax.transData, clip_on=False)
ax.text(xdata0 + bar_len_px / 2, ydata0 - bar_len_px * 0.3,
        f"{SCALE_BAR_UM} µm", color="white", ha="center", va="top", fontsize=12)

ax.axis("off")
# Transparent figure and axes backgrounds.
fig.patch.set_alpha(0)
ax.patch.set_alpha(0)

out = os.path.join(BASE_DIR, f"fish{fish}_overlay.pdf")
fig.savefig(out, bbox_inches="tight", transparent=True)
plt.close(fig)
print("saved ", out)


In [None]:
import os ,glob ,ast ,numpy as np ,pandas as pd ,matplotlib .pyplot as plt ,tifffile 

fish = 11
# Region group -> marker colour (insertion order is the plotting order).
COLOR_MAP = dict(Pt="orange", Hb="green", Other="purple")

def bright_plane0(f, boost=1.5):
    """Load the plane-0 image from ``img_dir`` and return a brightened copy.

    Parameters
    ----------
    f : int
        Fish id. NOTE(review): unused — the image is always read from the
        module-level ``img_dir``, which the config cell builds for one fish.
    boost : float, optional
        Gain applied after min-max normalisation. The result is clipped to
        [0, 1], so ``boost > 1`` saturates the brightest pixels.

    Returns
    -------
    numpy.ndarray
        Float image with values in [0, 1].

    Raises
    ------
    FileNotFoundError
        If no ``plane_0.*`` file exists in ``img_dir``.
    """
    # Sorted so the choice is deterministic if several extensions match.
    matches = sorted(glob.glob(os.path.join(img_dir, "plane_0.*")))
    if not matches:
        raise FileNotFoundError(f"No plane_0.* image found in {img_dir}")
    im = tifffile.imread(matches[0]).astype(float)
    lo, hi = im.min(), im.max()
    # A constant image has no contrast to stretch; avoid 0/0 -> NaN.
    im = (im - lo) / (hi - lo) if hi > lo else np.zeros_like(im)
    return np.clip(im * boost, 0, 1)

def plane0_table(f):
    """Return the rows of fish ``f``'s ``functional_types_df.h5`` on plane 0.

    Rows are those whose ``plane`` column equals "plane_0".

    NOTE(review): the path resolves to ``<folder containing results>/fish{f}_images/``,
    not to ``img_dir`` (``.../exp1-4_data/data_prepped_for_models/fish{f}_images``).
    Confirm which location actually holds the table.
    """
    table_path = os.path.join(
        os.path.dirname(BASE_DIR), os.pardir,
        f"fish{f}_images", "functional_types_df.h5",
    )
    table = pd.read_hdf(os.path.normpath(table_path))
    on_plane0 = table.plane == "plane_0"
    return table[on_plane0]

# Plane-0 neurons with parsed coordinates and a coarse group label:
# "Pt", "Hb", or "Other" (any other region, including a missing one).
df0 = plane0_table(fish).copy()
df0["coords"] = df0.neur_coords.apply(
    lambda cell: ast.literal_eval(cell) if isinstance(cell, str) else cell
)
df0["grp"] = df0.region.fillna("unknown").apply(
    lambda region: region if region in ("Pt", "Hb") else "Other"
)

bg = bright_plane0(fish)
fig, ax = plt.subplots(figsize=(6, 6), dpi=300)
ax.imshow(bg, cmap="gray", vmin=0, vmax=1)

# One scatter per group, in COLOR_MAP order; empty groups are skipped.
for grp_name, grp_color in COLOR_MAP.items():
    grp_coords = df0.loc[df0.grp == grp_name, "coords"].tolist()
    if grp_coords:
        pts = np.vstack(grp_coords)
        ax.scatter(
            pts[:, 0], pts[:, 1],
            s=40, c=grp_color, edgecolors="white", alpha=.9,
            label=grp_name,
        )

ax.legend(loc="lower right")
ax.axis("off")

# Transparent figure and axes backgrounds.
fig.patch.set_alpha(0)
ax.patch.set_alpha(0)

out = os.path.join(BASE_DIR, f"fish{fish}_clusters_overlay.pdf")
fig.savefig(out, bbox_inches="tight", transparent=True)
plt.close(fig)
print("saved ", out)


In [None]:
import os ,numpy as np ,pandas as pd ,matplotlib .pyplot as plt 
from scipy .stats import sem ,ttest_rel 

fish_list = list(range(11, 14))  # fish ids 11, 12, 13
GROUPS = ["Pt", "Hb", "Other"]   # bar order
COLORS = dict(zip(GROUPS, ["orange", "green", "purple"]))

def load_saliency(f, base_dir=None):
    """Average every ``importance.npy`` found under ``<base_dir>/fish<f>``.

    Parameters
    ----------
    f : int
        Fish id; selects the ``fish<f>`` sub-folder.
    base_dir : str, optional
        Results root to search. Defaults to the module-level ``BASE_DIR``.

    Returns
    -------
    numpy.ndarray
        Element-wise mean of all loaded arrays (stacked with ``np.vstack``).

    Raises
    ------
    FileNotFoundError
        If no ``importance.npy`` exists anywhere below the fish folder.
    """
    root = os.path.join(BASE_DIR if base_dir is None else base_dir, f"fish{f}")
    print(root)
    vecs = []
    for dirpath, dirnames, filenames in os.walk(root):
        # Walk sub-folders in sorted order so the float sum (and thus the mean)
        # is reproducible regardless of filesystem listing order.
        dirnames.sort()
        if "importance.npy" in filenames:
            vecs.append(np.load(os.path.join(dirpath, "importance.npy")))
    if not vecs:
        raise FileNotFoundError(
            f"No importance.npy files under {root}. "
            "Check the folder structure or BASE_DIR."
        )
    return np.vstack(vecs).mean(0)

def plane0_table(f):
    """Return the rows of fish ``f``'s ``functional_types_df.h5`` on plane 0.

    Rows are those whose ``plane`` column equals "plane_0".

    NOTE(review): this path is the grandparent of ``img_dir`` plus
    ``fish{f}_images`` (i.e. ``.../exp1-4_data/fish{f}_images/``). That differs
    both from ``img_dir`` itself and from the BASE_DIR-based ``plane0_table``
    in the cluster-overlay cell, which this definition replaces from here on.
    Confirm the intended location.
    """
    table_path = os.path.join(
        os.path.dirname(img_dir), os.pardir,
        f"fish{f}_images", "functional_types_df.h5",
    )
    table = pd.read_hdf(os.path.normpath(table_path))
    on_plane0 = table.plane == "plane_0"
    return table[on_plane0]

def _region_group(region):
    """Collapse a region label to "Pt", "Hb", or "Other"."""
    return region if region in ("Pt", "Hb") else "Other"


# One Series per fish: mean saliency of its plane-0 neurons in each group.
group_sals = []
for f in fish_list:
    sal = load_saliency(f)
    df = plane0_table(f).copy()
    df["grp"] = df.region.fillna("unknown").apply(_region_group)
    vals = sal[df.index.astype(int)]
    group_sals.append(pd.Series(vals, index=df.grp).groupby(level=0).mean())

# Rows = groups (GROUPS order), columns = fish.
df_groups = pd.concat(group_sals, axis=1).reindex(GROUPS)
means = df_groups.mean(axis=1)        # across-fish mean per group
errs = df_groups.apply(sem, axis=1)   # across-fish SEM per group

# Bars: across-fish mean saliency per group, with SEM error bars.
bar_colors = [COLORS[g] for g in GROUPS]
x = np.arange(len(GROUPS))
fig, ax = plt.subplots(figsize=(5, 2), dpi=300)
ax.bar(x, means, yerr=errs, capsize=5, color=bar_colors)
ax.set_xticks(x)
ax.set_xticklabels(GROUPS)
# Hard-coded y-range tuned to this dataset; revisit if the values change.
ax.set_ylim(0.00045, 0.00095)
ax.set_ylabel("Mean saliency")

def draw_bracket(ax, x1, x2, y, h, text):
    """Draw a significance bracket from ``x1`` to ``x2`` and label it.

    The bracket rises from height ``y`` to ``y + h``; ``text`` is centred
    just above its top edge.
    """
    top = y + h
    ax.plot([x1, x1, x2, x2], [y, top, top, y], lw=1.5, c="black")
    mid_x = (x1 + x2) / 2
    ax.text(mid_x, top + 0.00001, text, ha="center", va="bottom", fontsize=12)

def _significance_label(p):
    """Return the annotation drawn over a bracket for p-value ``p``.

    Conventional stars below 0.05 (``*``), 0.01 (``**``) and 0.001 (``***``);
    otherwise the p-value itself to two decimals. (The previous one-decimal
    rounding rendered e.g. p = 0.03 as "p=0.0".)
    """
    if p < 0.001:
        return "***"
    if p < 0.01:
        return "**"
    if p < 0.05:
        return "*"
    return f"p={p:.2f}"


# Paired t-tests between groups; columns of df_groups are fish, so each pair
# of rows is paired by fish. Pairs with p < 0.2 get a bracket above the bars.
print("Paired t-tests:")
offset = errs.max() * 0.1   # gap between the taller bar's error bar and the first bracket
h = errs.max() * 0.1        # bracket leg height
pairs = [(0, 1), (0, 2), (1, 2)]
for i, j in pairs:
    grp_i, grp_j = GROUPS[i], GROUPS[j]
    vals_i, vals_j = df_groups.loc[grp_i], df_groups.loc[grp_j]
    t_stat, p_val = ttest_rel(vals_i, vals_j, nan_policy="omit")
    print(f"{grp_i} vs {grp_j}: p = {p_val:.4f}")
    if p_val < 0.2:
        star = _significance_label(p_val)
        # means/errs are indexed by group name, so positional access needs .iloc.
        y = max(means.iloc[i] + errs.iloc[i], means.iloc[j] + errs.iloc[j]) + offset
        draw_bracket(ax, i, j, y, h, star)
        offset += h * 2.8  # stack later brackets above earlier ones

fig.tight_layout()
out = os.path.join(BASE_DIR, "cluster_importance_barplot.pdf")
fig.savefig(out, bbox_inches="tight", transparent=True)
plt.close(fig)
print("saved ", out)
