In [None]:
import numpy as np
import pandas as pd
import lightning as L
from lightning.pytorch import loggers as pl_loggers
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt
from helpers_convolutional_filters import compute_num_conv_features, init_conv_layers
from Convolutional_Mormyromast import ConvMormyromast
from lfp_response_dataset import LfpResponseDataset, create_train_and_validation_datasets

# Single-trial LFP responses.
# NOTE(review): read_pickle executes arbitrary code from the file; only load trusted data.
data_fname = "../data/lfp-abby/processed/single_trials.pkl"
data = pd.read_pickle(data_fname)

# Trial-averaged responses, loaded alongside the single trials.
data_fname_averages = "../data/lfp-abby/processed/trial_averages.pkl"
data_averages = pd.read_pickle(data_fname_averages)

# Sanity check: report the shapes of both loaded objects.
print(f"Loaded data shape: single trials - {data.shape}, trial averages - {data_averages.shape}")

In [None]:
# Split the single-trial data for one fish / one zone into train and validation sets.
train_dataset, valid_dataset = create_train_and_validation_datasets(
    data,
    fish_id="fish_01",
    zone="dlz",
    percent_train=0.8,
)

# Only the training loader is shuffled; validation keeps dataset order.
train_loader = DataLoader(train_dataset, shuffle=True, batch_size=32)
valid_loader = DataLoader(valid_dataset, shuffle=False, batch_size=32)

# Pull one batch to read the input length off axis 2 of the input tensor.
# NOTE(review): presumably inputs are (batch, channels, length) -- confirm against the dataset.
sample_batch = next(iter(train_loader))
input_length = sample_batch[0].shape[2]

model = ConvMormyromast(
    input_length=input_length,
    input_channels=1,
    conv_layer_fraction_widths=[1],
    conv_output_channels=1,
    conv_stride=25,
    N_receptors=1,
)

In [None]:
# Learning rate handed to the Adam optimizer in LitModel.configure_optimizers.
learning_rate = 0.001
# Scale of the Gaussian noise added to the inputs in LitModel.training_step
# (also the default for LitModel's `input_noise_std` argument).
input_noise_std = 0.25

# define the LightningModule
class LitModel(L.LightningModule):
    """Lightning wrapper that fits ``model`` with a mean-squared-error loss.

    Gaussian noise is added to the inputs in ``training_step`` only;
    ``validation_step`` evaluates on the inputs as given.

    Args:
        model: Network to optimize; it is called as ``model(x)``.
        input_noise_std: Standard deviation of the zero-mean Gaussian noise
            added to every training input.
        learning_rate: Learning rate passed to Adam. Defaults to the
            notebook-level ``learning_rate`` at class-definition time.
    """

    def __init__(self, model, input_noise_std=input_noise_std, learning_rate=learning_rate):
        super().__init__()
        self.model = model
        self.input_noise_std = input_noise_std
        # Stored on the module so the optimizer does not depend on a notebook global.
        self.learning_rate = learning_rate

    def training_step(self, batch, batch_idx):
        """Add input noise, run the model, log and return the MSE loss."""
        x, y = batch
        # Out-of-place add: leaves the batch tensor untouched. randn_like
        # samples directly with x's dtype and device (no CPU -> device copy).
        x = x + torch.randn_like(x) * self.input_noise_std
        y_hat = self.model(x)
        loss = nn.functional.mse_loss(y_hat, y)
        self.log("train_loss", loss, prog_bar=False)
        return loss

    def validation_step(self, batch, batch_idx):
        """Run the model on clean inputs, log and return the MSE loss."""
        x, y = batch
        y_hat = self.model(x)
        val_loss = nn.functional.mse_loss(y_hat, y)
        self.log("val_loss", val_loss, prog_bar=False)
        return val_loss

    def configure_optimizers(self):
        """Return an Adam optimizer over all parameters of this module."""
        return torch.optim.Adam(self.parameters(), lr=self.learning_rate)


lit_model = LitModel(model, input_noise_std=input_noise_std, learning_rate=learning_rate)

In [None]:
# TensorBoard logs are written under lightning_logs/mormyromast/.
logger = pl_loggers.TensorBoardLogger(
    save_dir="lightning_logs",
    name="mormyromast",
)

# Train for a fixed number of epochs, validating on the held-out loader.
trainer = L.Trainer(logger=logger, max_epochs=100)
trainer.fit(
    lit_model,
    train_dataloaders=train_loader,
    val_dataloaders=valid_loader,
)

In [None]:
# Plot the weights of the first convolutional layer after training.
# .cpu() is required before .numpy() if the weights live on an accelerator;
# it is a no-op when they are already on the CPU.
plt.plot(lit_model.model.conv_list[0].weight.detach().cpu().squeeze().numpy())
plt.show()

In [None]:
# Restore a LitModel from a checkpoint written by a previous training run.
# NOTE(review): the path is hard-coded to logger version_1 and a specific
# epoch/step -- update it if the run of interest was logged elsewhere.
checkpoint = "./lightning_logs/mormyromast/version_1/checkpoints/epoch=99-step=7800.ckpt"
# `model` is passed through because LitModel.__init__ requires it.
lit_model = LitModel.load_from_checkpoint(checkpoint_path=checkpoint, model=model)

In [None]:
# Plot the first convolutional layer's weights from the reloaded model.
first_conv_weight = lit_model.model.conv_list[0].weight
kernel = first_conv_weight.detach().cpu().squeeze().numpy()
plt.plot(kernel)
plt.show()