In [9]:
import os
import pickle

import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
import numpy as np
import sklearn.utils

from model import FiringRateModel, PolynomialActivation, train_model
from data import load_data, preprocess_data, get_train_test_data
from evaluate import explained_variance_ratio

In [10]:
# Compute device used by every model and tensor in this notebook:
# the GPU when CUDA is available, otherwise the CPU.
# (torch.device("mps") can be substituted on Apple silicon.)
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [15]:
def to_numpy(x):
    """Return ``x`` as a host-side NumPy array.

    Tensors are detached and copied to the CPU first, because matplotlib cannot
    plot CUDA tensors or tensors that require grad. Anything else goes through
    ``np.asarray``.
    """
    if isinstance(x, torch.Tensor):
        return x.detach().cpu().numpy()
    return np.asarray(x)


def resolve_run_param(name, value):
    """Return ``value`` if it was given, else the notebook global called ``name``.

    Keeps older call sites of ``plot_predictions`` working: they relied on the
    notebook-level ``cell_id`` / ``bin_size`` / ``k`` / ``l`` variables.

    Raises:
        NameError: if ``value`` is None and no such global exists.
    """
    if value is not None:
        return value
    try:
        return globals()[name]
    except KeyError:
        raise NameError(
            f"{name!r} was not passed explicitly and is not defined at notebook level"
        ) from None


def plot_predictions(model, Is, fs, evr=None, *, cell_id=None, bin_size=None, k=None, l=None):
    """Plot observed vs. predicted firing rate (top) and the input current (bottom) for one sweep.

    Args:
        model: object whose ``predict(Is)`` returns the predicted firing-rate trace.
        Is: input-current trace of one sweep (tensor or array-like).
        fs: observed firing-rate trace of the same sweep.
        evr: optional pair ``(evr, evr_quantized)`` appended to the title.
        cell_id, bin_size, k, l: run parameters shown in the title; ``bin_size`` (ms)
            also scales the time axis. Any that are omitted fall back to the
            notebook global of the same name.

    The figure is left open as the current pyplot figure so the caller can
    ``plt.savefig`` / ``plt.close`` it.
    """
    cell_id = resolve_run_param("cell_id", cell_id)
    bin_size = resolve_run_param("bin_size", bin_size)
    k = resolve_run_param("k", k)
    l = resolve_run_param("l", l)

    pred_fs = to_numpy(model.predict(Is))
    Is_np, fs_np = to_numpy(Is), to_numpy(fs)
    ts = np.arange(len(Is_np)) * bin_size / 1000  # bin index -> seconds (bin_size is in ms)

    fig, axs = plt.subplots(2)
    title = f"cell_id={cell_id}, bin_size={bin_size}, k={k}, l={l}"
    if evr is not None:
        title += f", evr={evr[0]:.3f}/{evr[1]:.3f}"
    fig.suptitle(title)

    axs[0].plot(ts, fs_np)
    axs[0].plot(ts, pred_fs)
    axs[1].plot(ts, Is_np)
    axs[0].set_ylabel("firing rate")
    axs[1].set_ylabel("current (pA)")
    axs[1].set_xlabel("time (s)")


def make_criterion(loss_fn):
    """Return the training loss module for ``loss_fn`` ("poisson" or "huber").

    Raises:
        ValueError: for any other ``loss_fn``.
    """
    if loss_fn == "poisson":
        # log_input=False: the loss receives rates, not log-rates.
        return torch.nn.PoissonNLLLoss(log_input=False)
    if loss_fn == "huber":
        return torch.nn.HuberLoss()
    raise ValueError(f"unsupported loss_fn {loss_fn!r}; expected 'poisson' or 'huber'")


def save_and_close(path):
    """Save the current pyplot figure to ``path`` (creating parent directories) and close it."""
    os.makedirs(os.path.dirname(path) or ".", exist_ok=True)
    plt.savefig(path)
    plt.close()


def train(cell_id, bin_size, k, l, loss_fn, save=True, a=None, b=None):
    """Fit a ``FiringRateModel`` for one cell / bin size / (k, l) and plot its predictions.

    Reads the notebook globals ``data`` and ``device``. The activation is initialised
    from ``model/activation/{loss_fn}/bin_size_{bin_size}/{cell_id}_1e-05.pickle``.

    Args:
        cell_id: key into ``data``.
        bin_size: bin width in ms.
        k, l: forwarded to ``FiringRateModel``; ``k == l == 0`` is trained for one epoch only.
        loss_fn: "poisson" or "huber".
        save: when True, every figure is written under
            ``figures/model/{cell_id}/bin_size_{bin_size}/`` and then closed.
        a, b: optional initial values forwarded to ``FiringRateModel``.

    Returns:
        The trained model.

    Raises:
        ValueError: if ``loss_fn`` is not supported (checked before any data is loaded).
    """
    criterion = make_criterion(loss_fn)
    fig_dir = f"figures/model/{cell_id}/bin_size_{bin_size}"
    run = dict(cell_id=cell_id, bin_size=bin_size, k=k, l=l)

    Is_tr, fs_tr, Is_te, fs_te, ws = get_train_test_data(data, cell_id, bin_size, device=device)
    # Shuffle copies for training only. The plots below keep the original order, so the
    # sweep index in each filename refers to the same sweep for every (k, l).
    Is_shuf, fs_shuf, ws_shuf = sklearn.utils.shuffle(Is_tr, fs_tr, ws)

    actv = PolynomialActivation()
    actv.init_from_file(f"model/activation/{loss_fn}/bin_size_{bin_size}/{cell_id}_1e-05.pickle")
    model = FiringRateModel(actv, k=k, l=l, a=a, b=b, static_g=False).to(device)
    # NOTE(review): logged losses in this notebook jump by ~3x between epochs at lr=0.1;
    # a smaller learning rate may train more stably.
    optimizer = torch.optim.RMSprop(model.parameters(), lr=0.1, centered=True)

    trivial = k == 0 and l == 0
    losses = train_model(
        model,
        criterion,
        optimizer,
        Is_shuf,
        fs_shuf,
        epochs=1 if trivial else 100,
        print_every=1 if trivial else 10,
        loss_fn=loss_fn,
        bin_size=bin_size,
        up_factor=5,
        ws=ws_shuf,
    )

    for i, (I_sweep, f_sweep) in enumerate(zip(Is_tr, fs_tr)):
        # Skip sweeps whose firing rate is <= 0.01 everywhere.
        if torch.all(f_sweep <= 0.01):
            continue
        plot_predictions(model, I_sweep, f_sweep, evr=None, **run)
        if save:
            save_and_close(f"{fig_dir}/{k}_{l}_{i}.png")

    if save:
        _, ax = plt.subplots()  # fresh figure, independent of any open pyplot state
        ax.plot(range(len(losses)), losses)
        ax.set_xlabel("epoch")
        ax.set_ylabel("loss")
        save_and_close(f"{fig_dir}/loss/{k}_{l}.png")

    r = explained_variance_ratio(model, Is_te, fs_te, bin_size)
    rq = explained_variance_ratio(model, Is_te, fs_te, bin_size, quantize=True)
    plot_predictions(model, Is_te[0], fs_te[0], evr=(r, rq), **run)
    if save:
        # NOTE(review): the filename assumes test sweep 0 is a "Noise 2" stimulus — confirm.
        save_and_close(f"{fig_dir}/{k}_{l}_noise2.png")
    return model

In [16]:
# Load the recordings. `data` is keyed by cell id (see the dict_keys output below);
# the cells that follow treat each value as a sequence of sweep dicts with a
# "stimulus_name" field. NOTE(review): `with_zero` semantics live in load_data — confirm.
data = load_data(with_zero=True)
data.keys()  # display the available cell ids

dict_keys([583836069, 565871768, 605889373])

In [17]:
# Per cell: how many sweeps of each stimulus type were recorded (first-seen order).
# NOTE(review): the final sweep of every cell is left out ([:-1]) — confirm the reason.
for cell_id in data:
    print(cell_id)
    counts = {}
    for stim_name in (sweep["stimulus_name"] for sweep in data[cell_id][:-1]):
        counts[stim_name] = counts.get(stim_name, 0) + 1
    print(counts)

583836069
{'Short Square': 9, 'Long Square': 22, 'Ramp': 2, 'Noise 1': 3, 'Noise 2': 4, 'Square - 0.5ms Subthreshold': 10, 'Test': 1}
565871768
{'Short Square': 13, 'Long Square': 36, 'Ramp': 1, 'Noise 2': 4, 'Noise 1': 2, 'Square - 0.5ms Subthreshold': 10, 'Square - 2s Suprathreshold': 12, 'Test': 1}
605889373
{'Short Square': 17, 'Long Square': 37, 'Ramp': 3, 'Noise 1': 3, 'Noise 2': 3, 'Square - 0.5ms Subthreshold': 10, 'Test': 1}


In [18]:
def exists(params, cell_id, bin_size, k, l):
    """Report whether ``params[cell_id][bin_size]`` already holds an entry for ``(k, l)``.

    Missing intermediate keys simply yield False.
    """
    if cell_id not in params:
        return False
    per_bin = params[cell_id]
    if bin_size not in per_bin:
        return False
    return (k, l) in per_bin[bin_size]

PARAMS_PATH = "model/params.pickle"


def load_params(path):
    """Load the fitted-parameter store ``{cell_id: {bin_size: {(k, l): params}}}`` from ``path``.

    Returns ``{}`` only when the file does not exist yet. Any other failure
    (truncated or unreadable pickle, permission error) propagates, so a damaged
    store is never silently replaced by an empty dict and then overwritten on the
    next save.
    """
    try:
        with open(path, "rb") as file:
            # pickle can execute code on load: only use this on locally produced files.
            return pickle.load(file)
    except FileNotFoundError:
        print("No pre-existing params to load. A new file will be created on first save.")
        return {}


def save_params(params, path):
    """Write ``params`` to ``path`` atomically (dump to a temp file, then ``os.replace``).

    Interrupting the dump (e.g. a KeyboardInterrupt during a long sweep) leaves the
    previous file intact instead of a truncated pickle.
    """
    os.makedirs(os.path.dirname(path) or ".", exist_ok=True)
    tmp_path = path + ".tmp"
    with open(tmp_path, "wb") as handle:
        pickle.dump(params, handle, protocol=pickle.HIGHEST_PROTOCOL)
    os.replace(tmp_path, path)


def warm_start_ab(params, cell_id, bin_size, k, l):
    """Initial ``(a, b)`` for model ``(k, l)`` taken from an already-fitted neighbour.

    Prefers the ``(k-1, l)`` fit with a zero prepended to ``a``; otherwise the
    ``(k, l-1)`` fit with a zero prepended to ``b``; otherwise ``(None, None)``.
    The returned tensors are copies, never the stored objects.
    """
    if exists(params, cell_id, bin_size, k - 1, l):
        prev = params[cell_id][bin_size][(k - 1, l)]
        # new_zeros matches the stored tensor's dtype and device.
        return torch.cat((prev["a"].new_zeros(1), prev["a"])), prev["b"].clone()
    if exists(params, cell_id, bin_size, k, l - 1):
        prev = params[cell_id][bin_size][(k, l - 1)]
        return prev["a"].clone(), torch.cat((prev["b"].new_zeros(1), prev["b"]))
    return None, None


params = load_params(PARAMS_PATH)

use_prev = False  # warm-start each fit from a neighbouring (k, l) fit when available
# The loop variables are intentionally module-level: other cells (e.g. plot_predictions)
# may read cell_id / bin_size / k / l as notebook globals.
for cell_id in [565871768]:
    for bin_size in [20]:
        for k in range(8, 15):
            for l in [0, 1]:
                print(f"cell_id={cell_id}, bin_size={bin_size}, k={k}, l={l}")
                if exists(params, cell_id, bin_size, k, l):
                    print("Skipped")
                    continue

                a, b = warm_start_ab(params, cell_id, bin_size, k, l) if use_prev else (None, None)
                model = train(cell_id, bin_size, k, l, "poisson", save=True, a=a, b=b)

                params.setdefault(cell_id, {}).setdefault(bin_size, {})[(k, l)] = model.get_params()
                print(model.a.tolist(), model.b.tolist())

                # Persist after every fit so an interrupted sweep keeps the completed models.
                save_params(params, PARAMS_PATH)

cell_id=565871768, bin_size=20, k=8, l=0
Epoch 10 / Loss: 3112.0475042249295
Epoch 20 / Loss: 3114.022049478847
Epoch 30 / Loss: 3080.6556921864194
Epoch 40 / Loss: 3079.5195633023336
Epoch 50 / Loss: 3092.8171028244124
Epoch 60 / Loss: 3096.632309150543
Epoch 70 / Loss: 1906.2244903123826
Epoch 80 / Loss: 3077.8277327267997
Epoch 90 / Loss: 2337.8224507484942
Epoch 100 / Loss: 962.6774030310279
[-1.4619816541671753, -0.8456491231918335, -0.6440469026565552, -1.314702033996582, 0.0495331771671772, 1.1349689960479736, 0.09044358134269714, 1.6839436292648315] []
cell_id=565871768, bin_size=20, k=8, l=1
Epoch 10 / Loss: 3095.415077450862
Epoch 20 / Loss: 3092.923970682137
Epoch 30 / Loss: 3046.684546763359
Epoch 40 / Loss: 3019.07139769128
Epoch 50 / Loss: 2979.8659706636627
Epoch 60 / Loss: 2941.774903518063
Epoch 70 / Loss: 2759.2874977112574
Epoch 80 / Loss: 1077.4973709262522
Epoch 90 / Loss: 2636.359562307517
Epoch 100 / Loss: 2959.313719865391
[-4.707728385925293, -0.415303528308868

KeyboardInterrupt: 