In [1]:
import os
import pickle

import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
import numpy as np
import sklearn.utils

from model import FiringRateModel, PolynomialActivation, train_model
from data import load_data, preprocess_data, get_train_test_data
from evaluate import explained_variance_ratio

In [2]:
# Run on a CUDA GPU when torch can see one; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [3]:
def _to_numpy(x):
    """Return ``x`` as a NumPy array, copying torch tensors to host memory first.

    Matplotlib cannot plot CUDA tensors (or tensors that require grad) directly,
    so tensors are detached and moved to the CPU; anything else goes through
    ``np.asarray``.
    """
    if isinstance(x, torch.Tensor):
        return x.detach().cpu().numpy()
    return np.asarray(x)


def _from_notebook_globals(name, value):
    """Return ``value``, or the notebook-level global ``name`` when ``value`` is None.

    Keeps older call sites that relied on the sweep loop's global ``cell_id`` /
    ``bin_size`` / ``k`` / ``l`` working, while letting callers pass them explicitly.
    """
    if value is not None:
        return value
    try:
        return globals()[name]
    except KeyError:
        raise NameError(
            f"name {name!r} is not defined; pass {name}=... to plot_predictions"
        ) from None


def plot_predictions(model, Is, fs, evr=None, *, cell_id=None, bin_size=None,
                     k=None, l=None, save_path=None):
    """Plot recorded vs. predicted firing rate (top) and input current (bottom) for one trial.

    Args:
        model: fitted model; ``model.predict(Is)`` supplies the predicted rate.
        Is: input current for one trial, one value per time bin (pA per the axis label).
        fs: recorded firing rate for the same trial, aligned with ``Is``.
        evr: optional ``(evr, quantized_evr)`` pair shown in the title.
        cell_id, bin_size, k, l: configuration shown in the title. ``bin_size`` also
            sets the time axis (``bin_size / 1000`` seconds per bin, i.e. presumably
            milliseconds). Each falls back to the notebook global of the same name
            when omitted, which matches the previous behaviour.
        save_path: if given, the figure is written there (missing parent directories
            are created) and then closed; otherwise it stays open for inline display.
    """
    cell_id = _from_notebook_globals("cell_id", cell_id)
    bin_size = _from_notebook_globals("bin_size", bin_size)
    k = _from_notebook_globals("k", k)
    l = _from_notebook_globals("l", l)

    # Predict from the original input (it may be a device tensor the model expects),
    # then convert everything to NumPy so plotting also works when running on CUDA.
    pred_fs = _to_numpy(model.predict(Is))
    Is = _to_numpy(Is)
    fs = _to_numpy(fs)
    ts = np.arange(len(Is)) * bin_size / 1000

    fig, axs = plt.subplots(2, sharex=True)
    title = f"cell_id={cell_id}, bin_size={bin_size}, k={k}, l={l}"
    if evr is not None:
        title += f", evr={evr[0]:.3f}/{evr[1]:.3f}"
    fig.suptitle(title)

    axs[0].plot(ts, fs, label="data")
    axs[0].plot(ts, pred_fs, label="model")
    axs[0].set_ylabel("firing rate")
    axs[0].legend(loc="upper right")
    axs[1].plot(ts, Is)
    axs[1].set_ylabel("current (pA)")
    axs[1].set_xlabel("time (s)")

    if save_path is not None:
        os.makedirs(os.path.dirname(save_path) or ".", exist_ok=True)
        fig.savefig(save_path)
        plt.close(fig)  # avoid accumulating open figures across a long sweep


def train(cell_id, bin_size, k, l, loss_fn, save=True, *, lr=0.1, random_state=None):
    """Fit a FiringRateModel for one (cell, bin size, k, l) configuration and plot the fit.

    Reads the notebook-level ``data`` and ``device`` globals.

    Args:
        cell_id: cell key passed with ``data`` to ``get_train_test_data``.
        bin_size: bin size used to select the data and the pre-fitted activation file.
        k, l: forwarded to ``FiringRateModel(actv, k=k, l=l)``.
        loss_fn: ``"poisson"`` (``PoissonNLLLoss`` on rates) or ``"huber"``.
        save: if True, every figure is written under ``figures/model/bin_size_{bin_size}/``
            and closed; otherwise figures are left open for inline display.
        lr: Adam learning rate (default 0.1, unchanged).
        random_state: seed for the training-set shuffle; ``None`` keeps it random.

    Returns:
        The trained model.

    Raises:
        ValueError: if ``loss_fn`` is neither ``"poisson"`` nor ``"huber"``.
    """
    # Validate before loading any data or files so a typo fails fast and clearly.
    if loss_fn == "poisson":
        # log_input=False: the model's predictions are treated as rates, not log-rates.
        criterion = torch.nn.PoissonNLLLoss(log_input=False)
    elif loss_fn == "huber":
        criterion = torch.nn.HuberLoss()
    else:
        raise ValueError(f"unknown loss_fn {loss_fn!r}; expected 'poisson' or 'huber'")

    Is_tr, fs_tr, Is_te, fs_te = get_train_test_data(data, cell_id, bin_size, device=device)
    # Shuffle a copy for training only. Plotting below uses the original order, so the
    # trial index in each figure name refers to the same trial for every (k, l).
    Is_tr_shuf, fs_tr_shuf = sklearn.utils.shuffle(Is_tr, fs_tr, random_state=random_state)

    actv = PolynomialActivation()
    # NOTE(review): this loads a local .pickle file; only point it at trusted files.
    actv.init_from_file(f"model/activation/{loss_fn}/bin_size_{bin_size}/{cell_id}_1e-05.pickle")
    model = FiringRateModel(actv, k=k, l=l).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    # For k == l == 0 the fitted `a` / `b` lists are empty, so presumably little is
    # left to optimise; a single epoch is used in that case.
    k_and_l_zero = k == 0 and l == 0
    train_model(
        model,
        criterion,
        optimizer,
        Is_tr_shuf,
        fs_tr_shuf,
        epochs=1 if k_and_l_zero else 50,
        print_every=1 if k_and_l_zero else 10,
        loss_fn=loss_fn,
        bin_size=bin_size,
        up_factor=5,
    )

    fig_dir = f"figures/model/bin_size_{bin_size}"
    labels = dict(cell_id=cell_id, bin_size=bin_size, k=k, l=l)

    for i in range(len(Is_tr)):
        # Skip (near-)silent trials where every bin's rate is <= 0.01.
        if torch.all(fs_tr[i] <= 0.01):
            continue
        save_path = f"{fig_dir}/{cell_id}_{k}_{l}_{i}.png" if save else None
        plot_predictions(model, Is_tr[i], fs_tr[i], evr=None, save_path=save_path, **labels)

    # The EVRs are computed over the whole test set; only its first trial is plotted.
    r = explained_variance_ratio(model, Is_te, fs_te, bin_size)
    rq = explained_variance_ratio(model, Is_te, fs_te, bin_size, quantize=True)
    save_path = f"{fig_dir}/{cell_id}_{k}_{l}_noise2.png" if save else None
    plot_predictions(model, Is_te[0], fs_te[0], evr=(r, rq), save_path=save_path, **labels)
    return model

In [4]:
# Load the dataset once; `train` reads this notebook-level `data` global.
# with_zero=True is forwarded to load_data (presumably keeps zero-current trials; TODO confirm).
data = load_data(with_zero=True)
data.keys()  # keys are the cell ids; they match the ids swept in the training cell

dict_keys([583836069, 565871768, 605889373])

In [5]:
# Sweep every (cell, bin size, k, l) configuration with the Poisson loss and keep the
# fitted `a` / `b` coefficients, keyed by configuration.
# The loop variables stay module-level on purpose: plotting code may look up
# `cell_id`, `bin_size`, `k` and `l` as notebook globals.
sweep = (
    (cell_id, bin_size, k, l)
    for cell_id in [583836069, 565871768, 605889373]
    for bin_size in [20, 50, 100]
    for k in [0, 1, 2]
    for l in [0, 1]
)
params = {}
for cell_id, bin_size, k, l in sweep:
    print(f"cell_id={cell_id}, bin_size={bin_size}, k={k}, l={l}")
    model = train(cell_id, bin_size, k, l, "poisson", save=True)
    coeffs = {"a": model.a.tolist(), "b": model.b.tolist()}
    params[(cell_id, bin_size, k, l)] = coeffs
    print(coeffs["a"], coeffs["b"])

cell_id=583836069, bin_size=20, k=0, l=0
Epoch 1 / Loss: 11918.103953152895
[] []
cell_id=583836069, bin_size=20, k=0, l=1
Epoch 10 / Loss: 11866.666827082634
Epoch 20 / Loss: 11813.96374103427
Epoch 30 / Loss: 11762.358575612307
Epoch 40 / Loss: 11711.832786798477
Epoch 50 / Loss: 11662.294202625751
[] [82.23571014404297]
cell_id=583836069, bin_size=20, k=1, l=0
Epoch 10 / Loss: 2739.8150255978107
Epoch 20 / Loss: 2739.7977062165737
Epoch 30 / Loss: 2739.7719454169273
Epoch 40 / Loss: 2739.736669033766
Epoch 50 / Loss: 2739.690047621727
[-1.444038987159729] []
cell_id=583836069, bin_size=20, k=1, l=1
Epoch 10 / Loss: 2741.338117226027
Epoch 20 / Loss: 2741.219035744667
Epoch 30 / Loss: 2741.074725329876
Epoch 40 / Loss: 2740.8755066394806
Epoch 50 / Loss: 2740.6097138524055
[-1.6595782041549683] [0.9941470623016357]
cell_id=583836069, bin_size=20, k=2, l=0
Epoch 10 / Loss: 1998.8511716127396
Epoch 20 / Loss: 1604.054558366537
Epoch 30 / Loss: 2469.0817211270332
Epoch 40 / Loss: 1857.8

KeyboardInterrupt: 