In [1]:
# Reload edited project modules (electric_images_dataset, EndToEndConvNN*, ...)
# automatically before code is executed, so edits to the helper .py files are
# picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import glob
import dill
import numpy as np
import pandas as pd
import h5py
import matplotlib.pyplot as plt
import seaborn as sns
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import lightning as L
from collections import OrderedDict

import sys
sys.path.append("../end-to-end")
# sys.path.append("../../../electric_fish/ActiveZone/electrodynamic/objects")
# sys.path.append("../../../electric_fish/ActiveZone/electrodynamic/helper_functions")
# sys.path.append("../../../electric_fish/ActiveZone/electrodynamic/uniform_points_generation")
sys.path.append("../../efish-physics-model/objects")
sys.path.append("../../efish-physics-model/helper_functions")
sys.path.append("../../efish-physics-model/uniform_points_generation")

from electric_images_dataset import ElectricImagesDataset
from EndToEndConvNN import EndToEndConvNN
from EndToEndConvNN_PL import EndToEndConvNN_PL
from EndToEndConvNNWithFeedback import EndToEndConvNNWithFeedback
from EndToEndConvNNWithFeedback_PL import EndToEndConvNNWithFeedback_PL

In [None]:
# Shell commands for training the small models with feedback (run from a terminal, not from this notebook):
# python train_end_to_end_models_with_feedback.py --input_noise_std 0.01 --save_dir small-with-values --use_estimates_as_feedback false --gpu 1
# python train_end_to_end_models_with_feedback.py --input_noise_std 0.01 --save_dir small-with-estimates --use_estimates_as_feedback true --gpu 3

In [11]:
# Lightning checkpoint to inspect. This is the full checkpoint dict, not just the
# model weights: it also carries the "hyper_parameters" entry read below.
# Swap in the commented path to inspect the "with-estimates" run instead.
CHECKPOINT_PATH = "./with-values/lightning_logs/version_2/checkpoints/epoch=84-step=481950.ckpt"
# CHECKPOINT_PATH = "./with-estimates/lightning_logs/version_2/checkpoints/epoch=85-step=487620.ckpt"

# The runs were trained on GPU; map_location="cpu" lets the checkpoint be opened
# on a CPU-only machine as well (only metadata / hyper-parameters are inspected here).
state_dict = torch.load(CHECKPOINT_PATH, map_location="cpu")

In [12]:
# Loss weighting stored with the checkpoint's hyper-parameters.
# NOTE(review): presumably one weight per predicted quantity -- confirm in the LightningModule that consumes it.
state_dict["hyper_parameters"]["loss_lambda"]

tensor([ 1.,  1.,  1., 10., 20., 20.])

In [8]:
# Full hyper-parameter dict stored in the checkpoint (layer architecture under
# "layers_properties", loss weights, ...). The displayed output is long.
state_dict["hyper_parameters"]

{'layers_properties': OrderedDict([('conv1',
               {'in_channels': 1,
                'out_channels': 8,
                'kernel_size': 7,
                'stride': 1,
                'max_pool': {'kernel_size': 3, 'stride': 1}}),
              ('conv2',
               {'in_channels': 8,
                'out_channels': 16,
                'kernel_size': 5,
                'stride': 1}),
              ('conv3',
               {'in_channels': 16,
                'out_channels': 32,
                'kernel_size': 5,
                'stride': 1}),
              ('conv4',
               {'in_channels': 32,
                'out_channels': 16,
                'kernel_size': 5,
                'stride': 1,
                'max_pool': {'kernel_size': 3, 'stride': 1}}),
              ('fc1', {'out_features': 5120}),
              ('fc2', {'in_features': 5120, 'out_features': 2560}),
              ('fc3', {'in_features': 2560, 'out_features': 1280}),
              ('fc4', {'in_features':

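# Check model size

Build a candidate architecture on the characterization dataset and print the tensor shape produced by every layer of its spatial path.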

In [9]:
# Processed characterization dataset generated in the efish-physics-model repository.
data_dir_name = "../../efish-physics-model/data/processed/data-2024_06_18-characterization_dataset"

# NOTE(review): fish_t / fish_u look like the electric-image grid size (the shape
# printout further down is 20 x 30) -- confirm in ElectricImagesDataset.
dset = ElectricImagesDataset(
    fish_t=20,
    fish_u=30,
    data_dir_name=data_dir_name,
)

In [12]:
# Architecture of the spatial path, listed in forward order.
#   conv*: Conv2d settings, optionally followed by a max-pooling stage ("max_pool").
#   fc*:   fully connected layers; they may add dropout or a flatten step, and
#          activation=False leaves out the activation (here only on the last layer).
# NOTE(review): fc1 has in_features=None -- presumably inferred on the first forward
# pass (the dummy pass below); confirm in EndToEndConvNNWithFeedback.
layers_properties = OrderedDict()
layers_properties["conv1"] = {
    "in_channels": 1,
    "out_channels": 8,
    "kernel_size": 7,
    "stride": 1,
    "max_pool": {"kernel_size": 3, "stride": 1},
}
layers_properties["conv2"] = {"in_channels": 8, "out_channels": 8, "kernel_size": 5, "stride": 1}
layers_properties["conv3"] = {
    "in_channels": 8,
    "out_channels": 8,
    "kernel_size": 5,
    "stride": 1,
    "max_pool": {"kernel_size": 3, "stride": 2},
}
layers_properties["fc1"] = {"dropout": 0.5, "flatten": True, "in_features": None, "out_features": 960}
layers_properties["fc2"] = {"dropout": 0.5, "in_features": 960, "out_features": 240}
layers_properties["fc3"] = {"in_features": 240, "out_features": 4, "activation": False}

# Candidate model: the spatial (convolutional) path defined above plus the
# feedback path used for extracting electric properties.
model_PL = EndToEndConvNNWithFeedback_PL(
    # spatial model properties
    layers_properties=layers_properties,
    activation_spatial="relu",
    model_type="two_paths",
    # feedback model properties (for extracting electric properties)
    kernel_size=7,
    in_channels=2,
    poly_degree_distance=4,
    poly_degree_radius=3,
    activation_feedback="relu",
    # miscellaneous properties
    use_estimates_as_feedback=True,
    input_noise_std=0.01,
    input_noise_type="additive",
)

# Dummy forward pass to initialize the model (e.g. fc1, whose in_features is None).
# shuffle=False makes the sample batch -- and hence the shape printout below -- reproducible.
dloader = DataLoader(dset, batch_size=4, shuffle=False)
batch = next(iter(dloader))

# No autograd graph is needed just to initialize the model.
# NOTE(review): if the model is in train mode (the PyTorch default), BatchNorm
# running statistics are still updated by this pass.
with torch.no_grad():
    _ = model_PL.model(
        batch[0],
        distances=torch.zeros(batch[0].shape[0], device=batch[0].device),
        radii=torch.zeros(batch[0].shape[0], device=batch[0].device),
    )



In [13]:
# Run the sample batch through the spatial path, printing each layer followed by
# the tensor shape it produces (see the output below).
out_model = model_PL.model.spatial_model.forward_print_dims(batch[0])

torch.Size([4, 1, 20, 30])
Conv2d(1, 8, kernel_size=(7, 7), stride=(1, 1), padding=(3, 0), padding_mode=circular)
torch.Size([4, 8, 20, 24])
--------------------------------------------------------------------------------------------------
BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
torch.Size([4, 8, 20, 24])
--------------------------------------------------------------------------------------------------
ReLU()
torch.Size([4, 8, 20, 24])
--------------------------------------------------------------------------------------------------
CircularPad2d((0, 0, 1, 1))
torch.Size([4, 8, 22, 24])
--------------------------------------------------------------------------------------------------
MaxPool2d(kernel_size=3, stride=1, padding=0, dilation=1, ceil_mode=False)
torch.Size([4, 8, 20, 22])
--------------------------------------------------------------------------------------------------
Conv2d(8, 8, kernel_size=(5, 5), stride=(1, 1), padding=(2, 0), pad