In [None]:
# Load IPython's autoreload extension. Mode 2 reloads all imported modules
# before each cell is executed, so edits to local .py files are picked up
# without restarting the kernel.
%load_ext autoreload
%autoreload 2

from IPython.display import clear_output
# wait=True postpones clearing this cell's output until new output is available to replace it.
clear_output(wait=True)


In [None]:
import sys

# Directories added to the module search path so the project-local imports
# further down can be resolved (presumably these hold load_data, NaiveConvNet,
# etc. -- verify). The paths are relative, so they resolve against the
# notebook's current working directory.
_LOCAL_MODULE_DIRS = (
    '../../../electric_fish/ActiveZone/electrodynamic/helper_functions',
    '../../../electric_fish/ActiveZone/electrodynamic/objects',
    '../../../electric_fish/ActiveZone/electrodynamic/uniform_points_generation',
)
for _module_dir in _LOCAL_MODULE_DIRS:
    # Re-running this cell must not pile up duplicate entries on sys.path.
    if _module_dir not in sys.path:
        sys.path.append(_module_dir)


import time
import numpy as np
from collections import OrderedDict
import matplotlib.pyplot as plt
import h5py

import torch

import torch.nn as nn

from torch.utils.data import DataLoader

from load_data import load_data_full
from NaiveConvNet import NaiveConvNet, TwoPathsNaiveConvNet
from train_naive_convNets_TorchDataset import ElectricImagesDataset

def parse_device(device):
    """Translate a device label such as ``'cpu'`` or ``'gpu:1'`` into a ``torch.device``.

    Args:
        device: String label. Any label containing ``'gpu'`` (case-insensitive)
            requests CUDA; the run of digits at the end of the label, if any,
            selects the CUDA device index.

    Returns:
        ``torch.device('cuda:<index>')`` when a GPU is requested and CUDA is
        available (``torch.device('cuda')`` if the label carries no index),
        otherwise ``torch.device('cpu')``.
    """
    label = device.strip()
    if 'gpu' not in label.lower() or not torch.cuda.is_available():
        return torch.device('cpu')
    # Take *all* trailing digits: using only the last character would turn
    # 'gpu:10' into cuda:0 and a bare 'gpu' into the invalid 'cuda:u'.
    prefix = label.rstrip('0123456789')
    index = label[len(prefix):]
    return torch.device(f'cuda:{index}' if index else 'cuda')

In [None]:
# Value handed to the dataset's N_data_samples_that_fit_in_RAM argument.
N_SAMPLES_IN_RAM = 40_000

dataset = ElectricImagesDataset(
    N_data_samples_that_fit_in_RAM=N_SAMPLES_IN_RAM,
)
len(dataset)

In [None]:
# Number of samples delivered per batch.
BATCH_SIZE = 5_000

dataset_loader = DataLoader(
    dataset=dataset,
    batch_size=BATCH_SIZE,
    # Page-locked host memory only helps host-to-GPU copies; on a CPU-only
    # machine requesting it just triggers a warning, so enable it only when
    # CUDA is actually available.
    pin_memory=torch.cuda.is_available(),
    shuffle=False,
)
# Number of batches the loader will yield.
len(dataset_loader)

In [None]:
# Pull a single batch from the loader and keep its two parts around for the
# cells below; the cell displays both shapes.
first_batch = next(iter(dataset_loader))
in_data, out_data = first_batch
in_data.shape, out_data.shape

In [None]:
# Measure how long the loader needs to deliver each of the first few batches.
N_TIMED_BATCHES = 8

# perf_counter() is monotonic and high-resolution, which makes it the right
# clock for interval measurements (time.time() can jump if the system clock
# is adjusted).
start_time = time.perf_counter()
for i, (source, targets) in enumerate(dataset_loader):
    end_time = time.perf_counter()
    print(i, source.shape, targets.shape, f'{end_time-start_time:.2f}s')
    # The next interval starts where this one ended, so printing is not counted twice.
    start_time = end_time
    if i == N_TIMED_BATCHES - 1:
        break

In [None]:
# Earlier variant, left disabled (note that `self` is not defined in this notebook):
# example_ei = (in_data / self.base_stim)[31,0,:,:24].numpy()

# Sample 31, channel 0, every row, first 24 columns of the fetched batch.
example_ei = in_data[31, 0, :, :24].numpy()

# Largest absolute value; used as a symmetric colour limit so that zero sits
# at the centre of the diverging 'seismic' colormap.
vval = np.abs(example_ei).max()

ax = plt.gca()
ei_image = ax.imshow(example_ei, cmap='seismic', vmin=-vval, vmax=vval)
plt.colorbar(ei_image, ax=ax)
plt.show()

In [None]:
# Convolutional layer specs, listed in the order they are passed to the network.
conv_layers = [
    (
        "conv1",
        {
            "in_channels": 1,
            "out_channels": 4,
            "kernel_size": 7,
            "stride": 1,
            "max_pool": {"kernel_size": 3, "stride": 1},
        },
    ),
    (
        "conv2",
        {"in_channels": 4, "out_channels": 16, "kernel_size": 5, "stride": 1},
    ),
    (
        "conv3",
        {
            "in_channels": 16,
            "out_channels": 8,
            "kernel_size": 5,
            "stride": 1,
            "max_pool": {"kernel_size": 3, "stride": 2},
        },
    ),
]

# Fully connected layer specs. NOTE(review): in_features=480 presumably has to
# match the flattened output of conv3 -- confirm with forward_print_dims.
fc_layers = [
    ("fc1", {"dropout": 0.5, "flatten": True, "in_features": 480, "out_features": 240}),
    ("fc2", {"dropout": 0.5, "in_features": 240, "out_features": 120}),
    ("fc3", {"in_features": 120, "out_features": 6, "activation": False}),
]

# Insertion order is preserved: conv1, conv2, conv3, fc1, fc2, fc3.
layers_properties = OrderedDict(conv_layers + fc_layers)

model = TwoPathsNaiveConvNet(
    layers_properties=layers_properties,
    activation="relu",
)

In [None]:
# Feed the first 10 samples of the fetched batch through the model's
# forward_print_dims helper (presumably prints intermediate tensor shapes -- verify).
dims_probe_batch = in_data[:10]
model.forward_print_dims(dims_probe_batch)