In [1]:
# Automatically re-import edited modules so changes to the local helper modules
# (dataset and model classes imported below) are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import glob
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from torch.utils.data import DataLoader

import sys
sys.path.append("../electric-properties-only")
sys.path.append("../end-to-end")
sys.path.append("../end-to-end-with-feedback")
sys.path.append("../../efish-physics-model/objects")
sys.path.append("../../efish-physics-model/helper_functions")
sys.path.append("../../efish-physics-model/uniform_points_generation")

# from helpers_conv_nn_models import make_true_vs_predicted_figure
from electric_images_dataset import ElectricImagesDataset
from EndToEndConvNN_PL import EndToEndConvNN_PL
from EndToEndConvNNWithFeedback_PL import EndToEndConvNNWithFeedback_PL

## Load trained models

In [3]:
# Registry of loaded models: one row per checkpoint, filled by the loading cells below.
models = pd.DataFrame(data=None)

#### Prepare a dataset for a dummy run to initialize the LazyLayers

In [4]:
# Loader settings for feeding electric images to the models.
batch_size = 100
data_dir_name = "../../efish-physics-model/data/processed/data-2024_12_04-discrimination_dataset"
# Alternative dataset: "../../efish-physics-model/data/processed/data-2024_06_18-characterization_dataset"

# NOTE: pd.read_pickle can execute arbitrary code -- only load trusted, locally generated files.
raw_dataset = pd.read_pickle(data_dir_name + "/dataset.pkl")

# fish_t / fish_u set the last two image dimensions (the recorded batch shape is (N, 2, 20, 30)).
dset = ElectricImagesDataset(data_dir_name=data_dir_name, fish_t=20, fish_u=30)
dloader = DataLoader(
    dset,
    batch_size=batch_size,
    shuffle=False,  # deterministic sample order
    drop_last=False,
    num_workers=12,
)

#### Load full end-to-end models

In [5]:
# Load every full end-to-end model checkpoint found under ../figures/stats-panel/full-model*.
# Each run folder name encodes the random seed (text between the first "_" and the next "-")
# and lambda_RC (the integer after the last "_").
# Rows are collected in a list and concatenated once, instead of re-copying the whole
# DataFrame with pd.concat on every iteration (quadratic in the number of models).
full_model_rows = []
for folder in sorted(glob.glob("../figures/stats-panel/full-model*")):
    checkpoint_path = f"{folder}/lightning_logs/version_0/checkpoints/epoch=4-step=25015.ckpt"
    model = EndToEndConvNN_PL.load_from_checkpoint(checkpoint_path)
    # Inference only: evaluation mode, frozen parameters, kept on the CPU.
    model.eval()
    model.freeze()
    model.cpu()
    rand_seed = int(folder.split("_")[1].split("-")[0])
    lambda_RC = int(folder.split("_")[-1])
    full_model_rows.append(
        {"rand_seed": rand_seed, "lambda_RC": lambda_RC, "model_type": "full", "model": model}
    )

# Guard keeps `models` untouched when no checkpoint folders match (same as the original loop).
if full_model_rows:
    models = pd.concat([models, pd.DataFrame(full_model_rows)]).reset_index(drop=True)



In [6]:
# Load the feedback models. Models trained with feedback of the true values ("feedback_vals")
# are loaded first, then those trained with feedback of estimates ("feedback_esti").
# Both variants share the loading logic, so one loop over (glob pattern, type label) pairs
# replaces the two copy-pasted loops.
# Rows are collected in a list and concatenated once, instead of re-copying the DataFrame
# with pd.concat on every iteration.
feedback_model_rows = []
for feedback_pattern, feedback_type in (
    ("../figures/stats-panel/feedback-with-values*", "feedback_vals"),
    ("../figures/stats-panel/feedback-with-estimates*", "feedback_esti"),
):
    for folder in sorted(glob.glob(feedback_pattern)):
        checkpoint_path = f"{folder}/lightning_logs/version_0/checkpoints/epoch=4-step=25015.ckpt"
        model = EndToEndConvNNWithFeedback_PL.load_from_checkpoint(checkpoint_path)
        # Inference only: evaluation mode, frozen parameters, kept on the CPU.
        model.eval()
        model.freeze()
        model.cpu()
        # Folder name encodes seed (between first "_" and next "-") and lambda_RC (after last "_").
        rand_seed = int(folder.split("_")[1].split("-")[0])
        lambda_RC = int(folder.split("_")[-1])
        feedback_model_rows.append(
            {"rand_seed": rand_seed, "lambda_RC": lambda_RC, "model_type": feedback_type, "model": model}
        )

# Guard keeps `models` untouched when no checkpoint folders match (same as the original loops).
if feedback_model_rows:
    models = pd.concat([models, pd.DataFrame(feedback_model_rows)]).reset_index(drop=True)

## Inspect discrimination data

In [None]:
# Take the first batch from the loader; its first element is the tensor of electric images.
first_batch = next(iter(dloader))
eis = first_batch[0]
# NOTE(review): the fixed index (5, 6) on the last two axes presumably marks the
# point of maximal modulation -- confirm against the dataset definition.
max_mod = eis[:, :, 5, 6]
print(eis.shape, max_mod.shape)

torch.Size([36, 2, 20, 30]) torch.Size([36, 2])


## Compute model predictions

In [9]:
def _predict_from_images(row):
    """Run one registry row's network on the electric images `eis`.

    Models of type "feedback_vals" are skipped and get NaN, so they can be
    dropped afterwards. Every other model returns its prediction as a NumPy array.
    """
    if row["model_type"] == "feedback_vals":
        return np.nan
    return row["model"].model(eis).detach().cpu().numpy()


models["prediction"] = models.apply(_predict_from_images, axis=1)
predictions = models["prediction"].dropna()

In [11]:
# One prediction array per evaluated model; "feedback_vals" rows were dropped (NaN) above.
predictions

0     [[1.0879288, -1.7104414, 0.19586463, -1.165871...
1     [[1.2015797, -1.8188689, 0.060691796, -0.87724...
2     [[1.2062205, -1.7061975, -0.08710666, -0.64469...
3     [[1.1098105, -1.9065433, -0.056732252, -1.2388...
4     [[1.0690147, -1.8444798, 0.06205017, -0.782846...
5     [[1.2554508, -1.633863, -0.011240186, -0.93791...
6     [[0.9909905, -1.623841, -0.14795007, -0.999237...
7     [[1.1543545, -1.7656376, 0.0122752115, -0.8426...
8     [[1.0690492, -1.6558446, -0.068323046, -1.1859...
9     [[1.143685, -1.8807266, -0.15735708, -0.865185...
10    [[1.1822836, -1.74764, 0.045839936, -1.1373272...
11    [[1.2732633, -1.8734568, 0.21749571, -1.203840...
24    [[1.1611611, -1.7896407, -0.6544116, -0.985888...
25    [[1.0937381, -1.8148575, -0.05253224, -1.27198...
26    [[1.2175602, -0.71030885, -0.46170676, -0.2227...
27    [[1.1597536, -1.3854177, -0.11776447, -1.19350...
28    [[1.2561973, -1.269551, 0.31300837, -1.1037918...
29    [[1.4800192, 0.4341957, 0.005852923, 0.182