In [1]:
import os
import pickle

import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
import numpy as np
import sklearn.utils

from model import FiringRateModel, PolynomialActivation, train_model
from data import load_data, preprocess_data, get_train_test_data
from evaluate import explained_variance_ratio

In [2]:
# Run on the GPU when CUDA is available, otherwise on the CPU.
# (torch.device("mps") is the Apple-silicon alternative if ever needed.)
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [3]:
def plot_predictions(model, Is, fs, evr=None):    
    pred_fs = model.predict(Is)
    ts = np.arange(len(Is)) * bin_size / 1000
    fig, axs = plt.subplots(2)
    
    if evr is not None:
        fig.suptitle(f"cell_id={cell_id}, bin_size={bin_size}, k={k}, l={l}, evr={evr[0]:.3f}/{evr[1]:.3f}")
    else:
        fig.suptitle(f"cell_id={cell_id}, bin_size={bin_size}, k={k}, l={l}")
        
    axs[0].plot(ts, fs)
    axs[0].plot(ts, pred_fs)
    axs[1].plot(ts, Is)
    axs[0].set_ylabel("firing rate")
    axs[1].set_ylabel("current (pA)")
    axs[1].set_xlabel("time (s)")
    
def train(cell_id, bin_size, k, l, loss_fn, save=True, a=None, b=None):
    Is_tr, fs_tr, Is_te, fs_te = get_train_test_data(data, cell_id, bin_size, device=device)
    Is_tr, fs_tr = sklearn.utils.shuffle(Is_tr, fs_tr)
    
    actv = PolynomialActivation()
    actv.init_from_file(f"model/activation/{loss_fn}/bin_size_{bin_size}/{cell_id}_1e-05.pickle")

    model = FiringRateModel(actv, k=k, l=l, a=a, b=b).to(device)
    if loss_fn == "poisson":
        criterion = torch.nn.PoissonNLLLoss(log_input=False)
    elif loss_fn == "huber":
        criterion = torch.nn.HuberLoss()
        #criterion = torch.nn.MSELoss()
    #optimizer = torch.optim.Adam(model.parameters(), lr=0.1)
    optimizer = torch.optim.RMSprop(model.parameters(), lr=0.1, centered=True)
    
    epochs = 1 if k == 0 and l == 0 else 30
    print_every = 1 if k == 0 and l == 0 else 10
    losses = train_model(
        model, 
        criterion, 
        optimizer,
        Is_tr,
        fs_tr,
        epochs = epochs,
        print_every = print_every,
        loss_fn = loss_fn,
        bin_size = bin_size,
        up_factor = 5
    )
    
    for i in range(len(Is_tr)):
        if not torch.all(fs_tr[i] <= 0.01):
            plot_predictions(model, Is_tr[i], fs_tr[i], evr=None)
            if save:
                plt.savefig(f"figures/model/{cell_id}/bin_size_{bin_size}/{k}_{l}_{i}.png")
                plt.close()
                
    if save:
        plt.plot(list(range(len(losses))), losses)
        plt.xlabel("epoch")
        plt.ylabel("loss")
        plt.savefig(f"figures/model/{cell_id}/bin_size_{bin_size}/loss/{k}_{l}.png")
        plt.close()
    
    r = explained_variance_ratio(model, Is_te, fs_te, bin_size)
    rq = explained_variance_ratio(model, Is_te, fs_te, bin_size, quantize=True)
    plot_predictions(model, Is_te[0], fs_te[0], evr=(r, rq))
    if save:
        plt.savefig(f"figures/model/{cell_id}/bin_size_{bin_size}/{k}_{l}_noise2.png")
        plt.close()
    return model

In [6]:
# Load all recorded cells; the dict keys are the cell ids iterated over below.
data = load_data(with_zero=True)
data.keys()

dict_keys([583836069, 565871768, 605889373])

In [10]:
# For each cell, tally how many sweeps use each stimulus type.
# NOTE(review): the last entry of each cell's sweep list is excluded by the
# [:-1] slice; presumably it is not a regular sweep — confirm against load_data.
for cell_id in data:
    print(cell_id)
    counts = {}
    for sweep in data[cell_id][:-1]:
        counts[sweep["stimulus_name"]] = counts.get(sweep["stimulus_name"], 0) + 1
    print(counts)

583836069
{'Short Square': 9, 'Long Square': 22, 'Ramp': 2, 'Noise 1': 3, 'Noise 2': 4, 'Square - 0.5ms Subthreshold': 10, 'Test': 1}
565871768
{'Short Square': 13, 'Long Square': 36, 'Ramp': 1, 'Noise 2': 4, 'Noise 1': 2, 'Square - 0.5ms Subthreshold': 10, 'Square - 2s Suprathreshold': 12, 'Test': 1}
605889373
{'Short Square': 17, 'Long Square': 37, 'Ramp': 3, 'Noise 1': 3, 'Noise 2': 3, 'Square - 0.5ms Subthreshold': 10, 'Test': 1}


In [9]:
def exists(params, cell_id, bin_size, k, l):
    """Return True if ``params`` already holds a fit for this cell / bin size / (k, l).

    ``params`` is the nested cache ``params[cell_id][bin_size][(k, l)]``.
    """
    if cell_id not in params:
        return False
    per_cell = params[cell_id]
    if bin_size not in per_cell:
        return False
    return (k, l) in per_cell[bin_size]

# Sweep (k, l) for each cell / bin size. Each fit is warm-started from a smaller
# neighbouring fit, and the whole parameter cache is checkpointed after every fit.
PARAMS_PATH = "model/params.pickle"
CELL_IDS = [565871768]
BIN_SIZES = [20]
K_VALUES = range(1, 10)
L_VALUES = [1, 2, 3]


def _load_params(path=PARAMS_PATH):
    """Return the saved parameter cache, or an empty dict if none exists yet.

    Only a missing file means "start fresh". Any other error (e.g. a corrupt
    pickle) propagates, so the next checkpoint cannot silently overwrite
    previously fitted parameters with an almost-empty cache.
    Only load caches you produced yourself: unpickling can execute arbitrary code.
    """
    try:
        with open(path, "rb") as file:
            return pickle.load(file)
    except FileNotFoundError:
        print(f"No saved params at {path}; starting from scratch")
        return {}


def _save_params(params, path=PARAMS_PATH):
    """Write the cache via a temp file + rename so an interrupted dump cannot corrupt it."""
    tmp_path = path + ".tmp"
    with open(tmp_path, "wb") as handle:
        pickle.dump(params, handle, protocol=pickle.HIGHEST_PROTOCOL)
    os.replace(tmp_path, path)


def _warm_start(params, cell_id, bin_size, k, l):
    """Initial ``(a, b)`` tensors for a (k, l) fit, grown from a neighbouring cached fit.

    Prefers the (k-1, l) fit and prepends a zero to its ``a``. Otherwise uses the
    (k, l-1) fit and prepends a zero to its ``b``. Returns ``(None, None)`` when
    neither exists. The zero pad is created with ``new_zeros`` so it matches the
    stored tensor's dtype and device.
    """
    if exists(params, cell_id, bin_size, k - 1, l):
        prev = params[cell_id][bin_size][(k - 1, l)]
        a_prev = prev["a"]
        return torch.cat((a_prev.new_zeros(1), a_prev)), prev["b"].clone()
    if exists(params, cell_id, bin_size, k, l - 1):
        prev = params[cell_id][bin_size][(k, l - 1)]
        b_prev = prev["b"]
        return prev["a"].clone(), torch.cat((b_prev.new_zeros(1), b_prev))
    return None, None


params = _load_params()

for cell_id in CELL_IDS:
    for bin_size in BIN_SIZES:
        for k in K_VALUES:
            for l in L_VALUES:
                print(f"cell_id={cell_id}, bin_size={bin_size}, k={k}, l={l}")
                if exists(params, cell_id, bin_size, k, l):
                    print("Skipped")
                    continue

                a, b = _warm_start(params, cell_id, bin_size, k, l)
                model = train(cell_id, bin_size, k, l, "poisson", save=True, a=a, b=b)

                params.setdefault(cell_id, {}).setdefault(bin_size, {})[(k, l)] = model.get_params()
                print(model.a.tolist(), model.b.tolist())

                # Checkpoint after every fit so an interrupted sweep can resume.
                _save_params(params)

cell_id=565871768, bin_size=20, k=1, l=1
Skipped
cell_id=565871768, bin_size=20, k=1, l=2
Epoch 10 / Loss: 7421.435245513916
Epoch 20 / Loss: 7310.283526420593
Epoch 30 / Loss: 16006.1729221344
[0.5738366842269897] [60.39689254760742, 109.53589630126953]
cell_id=565871768, bin_size=20, k=1, l=3
Epoch 10 / Loss: 11733.613585472107
Epoch 20 / Loss: 7184.987804412842
Epoch 30 / Loss: 13905.457458496094
[0.7356340885162354] [52.881839752197266, 125.88520812988281, 169.90554809570312]
cell_id=565871768, bin_size=20, k=2, l=1
Skipped
cell_id=565871768, bin_size=20, k=2, l=2
Epoch 10 / Loss: 14861.689311027527
Epoch 20 / Loss: 10847.328112830874
Epoch 30 / Loss: 18103.55552005768
[0.19210930168628693, 1.0360426902770996] [100.95767974853516, 157.2069091796875]
cell_id=565871768, bin_size=20, k=2, l=3
Epoch 10 / Loss: 39161.742510318756
Epoch 20 / Loss: 39161.283450603485
Epoch 30 / Loss: 39173.323622226715
[-1.7513484954833984, 0.08677950501441956] [60.332733154296875, 133.3321533203125, 174.