In [1]:
import os
import pickle

import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
import numpy as np
import sklearn.utils

from tqdm import tqdm

from model import FiringRateModel, PolynomialActivation
from train import train_model
from data import get_data, get_train_test_data
from evaluate import explained_variance_ratio
from utils import plot_predictions, plot_kernel
from config import config

In [2]:
# Run on the GPU when torch can see one, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Seed every RNG this notebook draws from so Restart & Run All reproduces the results:
# torch (presumably used for model parameter initialisation) and NumPy's global RNG,
# which sklearn.utils.shuffle falls back to when no random_state is passed.
SEED = 0
torch.manual_seed(SEED)
np.random.seed(SEED)

print(device)

cpu


In [3]:
def train(Is, fs, g, cell_id, bin_size, ds=None, device=None, repeats=1, plot=False,
          epochs=150, lr=0.03, lr_gamma=0.85, lr_step_size=5):
    """Fit a FiringRateModel `repeats` times and keep the run with the lowest final loss.

    Each repeat builds a fresh model from the activation `g`. The model is optimised with
    centered RMSprop under a StepLR schedule against a PoissonNLLLoss that takes rates
    (not log-rates) as input. The per-epoch losses returned by `train_model` are recorded.

    Args:
        Is, fs: training inputs and target firing rates, forwarded unchanged to `train_model`.
        g: activation module; moved to `device` and passed to `FiringRateModel`.
        cell_id: not used here; kept so existing call sites keep working.
        bin_size: forwarded to `FiringRateModel` and `train_model`.
        ds: forwarded to `FiringRateModel`.
        device: torch device the model is placed on.
        repeats: number of independent fits, each from a freshly built model; must be >= 1.
        plot: if True, draw each run's loss curve on the current matplotlib axes.
        epochs: training epochs per fit (default 150, the previously hard-coded value).
        lr: RMSprop learning rate (default 0.03).
        lr_gamma, lr_step_size: StepLR decay factor and period, counted in
            `scheduler.step()` calls (presumably once per epoch inside `train_model` - confirm).

    Returns:
        (best_model, best_losses): the model with the lowest final loss and its loss history.
        A NaN final loss is ranked worst, so a finite run always beats a diverged one.
        A model is always returned; there is no `None` placeholder.

    Raises:
        ValueError: if `repeats` < 1.
    """
    import math

    if repeats < 1:
        raise ValueError(f"repeats must be >= 1, got {repeats}")

    def _final_loss(history):
        # Rank empty or NaN-terminated histories last so any finite run is preferred.
        if len(history) == 0:
            return math.inf
        value = float(history[-1])
        return math.inf if math.isnan(value) else value

    best_model, best_losses, best_final = None, None, math.inf

    for _ in range(repeats):
        model = FiringRateModel(
            g.to(device), ds, bin_size=bin_size, device=device
        ).to(device)

        # log_input=False: the model outputs rates directly, not log-rates.
        criterion = torch.nn.PoissonNLLLoss(log_input=False, reduction="none")
        optimizer = torch.optim.RMSprop(model.parameters(), lr=lr, centered=True)
        scheduler = torch.optim.lr_scheduler.StepLR(
            optimizer, gamma=lr_gamma, step_size=lr_step_size
        )

        losses = train_model(
            model,
            criterion,
            optimizer,
            Is,
            fs,
            epochs=epochs,
            # Larger than `epochs`, so presumably train_model never prints mid-run - confirm.
            print_every=epochs + 1,
            bin_size=bin_size,
            up_factor=1,
            scheduler=scheduler,
        )

        if plot:
            plt.plot(list(range(len(losses))), losses)

        run_final = _final_loss(losses)
        # Always keep the first run; afterwards only a strictly better final loss replaces it.
        if best_model is None or run_final < best_final:
            best_model, best_losses, best_final = model, losses, run_final

    return best_model, best_losses

In [None]:
def save_cell_figures(model, Is_tr, fs_tr, stims, Is_te, fs_te, cell_id, bin_size, actv_bin_size):
    """Write the prediction and kernel figures for one fitted cell.

    Saves:
      - one prediction plot per row `i` of every training entry, named "<stim> <i>_<actv_bin_size>";
      - one prediction plot of the first test entry, annotated with its explained-variance
        ratios (plain and quantized);
      - one plot of the model kernel.

    Output locations are decided by `plot_predictions` / `plot_kernel`.
    """
    for Is, fs, stim in zip(Is_tr, fs_tr, stims):
        # Presumably each entry is (n_trials, n_bins) - TODO confirm against get_train_test_data.
        for i in range(Is.shape[0]):
            plot_predictions(
                model,
                Is[i, :],
                fs[i, :],
                cell_id,
                bin_size,
                evr=None,
                save=True,
                fname=f"{stim} {i}_{actv_bin_size}",
            )

    # Held-out evaluation uses only the first test entry.
    r = explained_variance_ratio(model, Is_te[0], fs_te[0], bin_size)
    rq = explained_variance_ratio(model, Is_te[0], fs_te[0], bin_size, quantize=True)
    plot_predictions(
        model,
        Is_te[0][0, :],
        fs_te[0][0, :],
        cell_id,
        bin_size,
        evr=(r, rq),
        save=True,
        fname=f"Noise 2_{actv_bin_size}",
    )

    plot_kernel(model, cell_id, bin_size, save=True, fname=f"Kernel_{actv_bin_size}")


bin_size = 20                     # forwarded to data loading, the model and the plots
ds = np.linspace(0.05, 1.0, 20)   # forwarded to FiringRateModel as `ds`
actv_bin_size = 100               # selects which pre-fitted activation file is loaded
cell_ids = [504615116, 513593674, 565871768, 583836069, 605889373]
save = True                       # write figures under config["fig_save_path"]; otherwise show loss curves inline

d = {}  # cell_id -> (model, losses); name kept because later cells may read it

for cell_id in cell_ids:
    print(cell_id)
    data = get_data(cell_id, aligned=False)
    Is_tr, fs_tr, Is_te, fs_te, stims = get_train_test_data(data, bin_size, device=device)
    # Shuffle the training entries jointly so inputs, targets and stimulus names stay aligned.
    Is_tr, fs_tr, stims = sklearn.utils.shuffle(Is_tr, fs_tr, stims)

    actv = PolynomialActivation()
    actv.init_from_file(f"model/activation/poisson/bin_size_{actv_bin_size}/{cell_id}_0.pickle")

    # Give each cell its own figure: train(..., plot=True) draws on the *current* figure,
    # so without this the loss curves of successive cells would pile up on one plot.
    loss_fig = plt.figure()
    model, losses = train(Is_tr, fs_tr, actv, cell_id, bin_size, ds=ds, device=device, repeats=1, plot=True)
    d[cell_id] = (model, losses)

    # Labels must be set before savefig, otherwise the saved file has unlabelled axes.
    plt.xlabel("epoch")
    plt.ylabel("loss")
    plt.title(f"Training loss, cell {cell_id}")

    if save:
        loss_path = config["fig_save_path"] + f"{cell_id}/bin_size_{bin_size}/Loss_{actv_bin_size}.png"
        os.makedirs(os.path.dirname(loss_path), exist_ok=True)
        loss_fig.savefig(loss_path)
    else:
        plt.show()
    plt.close(loss_fig)

    if save:
        save_cell_figures(model, Is_tr, fs_tr, stims, Is_te, fs_te, cell_id, bin_size, actv_bin_size)

504615116


Train model:  45%|████████████▉                | 67/150 [00:46<00:54,  1.53it/s]