In [None]:
import os
import pickle

import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
import numpy as np

from model import FiringRateModel, PolynomialActivation, train_model
from data import load_data, preprocess_data

In [None]:
# Run on a CUDA GPU when one is visible to torch, otherwise on the CPU.
# (An Apple-silicon alternative tried earlier was torch.device("mps").)
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

In [None]:
def test_model(model, criterion, Is, fs, k: int):
    with torch.no_grad():
        f = fs[0] # initialize firing rate to t=0
        loss = 0
        n = 0
        for i in range(k+1, len(Is)):
            currs = Is[i-k:i+1]
            f = model(currs, f)
            loss += criterion(f, fs[i]).item()
            n += 1
    return loss / n

def predict(model, Is, fs, k: int):
    """Roll the model forward over one sweep and collect every output.

    The rollout starts from the recorded rate ``fs[0]`` and afterwards feeds each
    output back in as the previous rate. At step ``t`` (``k+1 .. len(Is)-1``) the
    model sees the current window ``Is[t-k : t+1]``.

    Returns:
        A list with one entry per step (``len(Is) - k - 1`` entries, empty for a
        too-short sweep), holding the model outputs exactly as returned.
    """
    outputs = []
    with torch.no_grad():
        # NOTE(review): for k >= 1 this seeds with fs[0] rather than fs[k];
        # confirm it matches how train_model starts its rollout.
        rate = fs[0]
        for t in range(k + 1, len(Is)):
            window = Is[t - k:t + 1]
            rate = model(window, rate)
            outputs.append(rate)
    return outputs

def plot_predictions(model, Is, fs, k: int, ax=None):
    """Plot the recorded and the predicted firing rate of one sweep.

    ``predict(model, Is, fs, k)`` yields one value per step ``t = k+1 ..
    len(Is)-1``. The recorded rates are plotted over exactly those steps so that
    both curves have the same length for every ``k``.

    Args:
        model, Is, fs, k: forwarded to ``predict``.
        ax: matplotlib Axes to draw on; defaults to the current Axes.

    Returns:
        The Axes that was drawn on.
    """
    if ax is None:
        ax = plt.gca()
    pred_fs = predict(model, Is, fs, k)
    start = k + 1
    ts = list(range(start, start + len(pred_fs)))
    # matplotlib cannot read tensors that live on a GPU, so move both series
    # to the CPU before plotting.
    actual = torch.as_tensor(fs[start:start + len(pred_fs)]).detach().cpu()
    predicted = torch.stack(pred_fs).detach().cpu() if pred_fs else []
    ax.plot(ts, actual, label="Actual")
    ax.plot(ts, predicted, label="Predicted")
    ax.legend()
    ax.set_ylabel("Firing rate")
    ax.set_xlabel("t")
    return ax

In [None]:
# Load the recordings. `data` is indexed by cell id (see get_train_test_data);
# the bare `data.keys()` below displays its keys, presumably the available cell ids.
data = load_data()
data.keys()

In [None]:
def get_train_test_data(data, cell_id, bin_size):
    """Split one cell's sweeps into training and test tensors.

    Each sweep record supplies ``"stimulus_name"``, ``"current"`` and
    ``"firing_rate"``. The latter two are looked up by ``bin_size`` and converted
    to tensors on the notebook-level ``device``. Sweeps whose stimulus is
    ``"Noise 2"`` form the test split, sweeps named ``"Test"`` are discarded, and
    all remaining sweeps go to the training split.

    Returns:
        ``(Is_tr, fs_tr, Is_te, fs_te)``: four lists of tensors, with currents
        and rates paired by position.
    """
    train_I, train_f = [], []
    test_I, test_f = [], []

    # NOTE(review): the cell's last sweep is always skipped; the reason is not
    # visible here — confirm it is intentional.
    sweeps = data[cell_id][:-1]
    for sweep in sweeps:
        name = sweep["stimulus_name"]
        current = torch.tensor(sweep["current"][bin_size], device=device)
        rate = torch.tensor(sweep["firing_rate"][bin_size], device=device)
        if name == "Noise 2":
            test_I.append(current)
            test_f.append(rate)
            continue
        if name != "Test":
            train_I.append(current)
            train_f.append(rate)
    return train_I, train_f, test_I, test_f

def data_lens(Is):
    """Print a ``{length: count}`` dict of the sequence lengths in ``Is``.

    Keys appear in the order their length is first encountered.
    """
    tally = {}
    for seq in Is:
        size = len(seq)
        tally[size] = tally.get(size, 0) + 1
    print(tally)

In [None]:
# Experiment configuration.
cell_id = 605889373   # which cell's sweeps to fit
bin_size = 100        # key into each sweep's "current" / "firing_rate" entries
k = 1                 # window = current sample + k previous current samples
loss_fn = "poisson"   # "poisson" or "huber"; also selects the activation file

Is_tr, fs_tr, Is_te, fs_te = get_train_test_data(data, cell_id, bin_size)

# Sequence-length histogram of the training split, then of the test split.
for split_Is in (Is_tr, Is_te):
    data_lens(split_Is)

In [None]:
# Choose the training loss first so that an unsupported `loss_fn` fails before
# any file is read, instead of leaving `criterion` undefined (or silently stale
# from an earlier run of this cell).
if loss_fn == "poisson":
    # log_input=False: the loss receives the predicted rate itself, not its log.
    criterion = torch.nn.PoissonNLLLoss(log_input=False)
elif loss_fn == "huber":
    # NOTE(review): despite the name this branch trains with plain MSE;
    # torch.nn.HuberLoss() was left disabled — confirm this is intended.
    criterion = torch.nn.MSELoss()
else:
    raise ValueError(f"Unsupported loss_fn {loss_fn!r}; expected 'poisson' or 'huber'")

# Initialise the activation from a saved pickle keyed by loss, bin size and cell.
# NOTE(review): the meaning of the "1e-05" file suffix is not visible here; confirm.
actv = PolynomialActivation()
actv.init_from_file(f"model/activation/{loss_fn}/bin_size_{bin_size}/{cell_id}_1e-05.pickle")

model = FiringRateModel(actv, k=k).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.1)

train_model(
    model,
    criterion,
    optimizer,
    Is_tr,
    fs_tr,
    k,
    epochs=10,
    print_every=1,
    loss_fn=loss_fn,
    bin_size=bin_size,
    up_factor=1000,
)


def _mean_sweep_loss(Is_list, fs_list):
    """Unweighted mean of test_model() over the sweeps long enough to score (NaN if none)."""
    losses = [
        test_model(model, criterion, Is, fs, k)
        for Is, fs in zip(Is_list, fs_list)
        if len(Is) > k + 1
    ]
    return sum(losses) / len(losses) if losses else float("nan")


# Train/test evaluation, off by default; set EVALUATE = True to run it.
# test_model() scores a single sweep, so it is averaged across each split's sweeps.
EVALUATE = False
if EVALUATE:
    train_loss = _mean_sweep_loss(Is_tr, fs_tr)
    test_loss = _mean_sweep_loss(Is_te, fs_te)
    print(f"\nTrain Loss: {train_loss}")
    print(f"Test Loss: {test_loss}")

In [None]:
# One figure per training sweep that actually fires (some rate above 0.01),
# comparing the recorded rate with the model's rollout from predict().
for i in range(len(Is_tr)):
    if torch.all(fs_tr[i] <= 0.01):
        continue  # skip (near-)silent sweeps
    pred_fs = predict(model, Is_tr[i], fs_tr[i], k)
    ts = list(range(len(Is_tr[i]) - k - 1))
    # matplotlib cannot read GPU tensors; move both series to the CPU first.
    actual = fs_tr[i][k + 1:].detach().cpu()
    predicted = torch.stack(pred_fs).detach().cpu() if pred_fs else []
    fig, ax = plt.subplots()
    ax.plot(ts, actual, label="Actual")
    ax.plot(ts, predicted, label="Predicted")
    ax.set(title=f"Training sweep {i}", xlabel="t", ylabel="Firing rate")
    ax.legend()

In [None]:
# Overlay the model's predicted rate for every training sweep on one Axes.
fig, ax = plt.subplots()
for i in range(len(Is_tr)):
    pred_fs = predict(model, Is_tr[i], fs_tr[i], k)
    ts = list(range(len(Is_tr[i]) - k - 1))
    # matplotlib cannot read GPU tensors; move the predictions to the CPU first.
    predicted = torch.stack(pred_fs).detach().cpu() if pred_fs else []
    ax.plot(ts, predicted)
ax.set(title="Predicted firing rate, all training sweeps", xlabel="t", ylabel="Firing rate");

In [None]:
# Overlay the recorded rate of every training sweep on one Axes, over the same
# steps (k+1 onward) the model predicts. No rollout is needed here.
fig, ax = plt.subplots()
for i in range(len(Is_tr)):
    ts = list(range(len(Is_tr[i]) - k - 1))
    # matplotlib cannot read GPU tensors; move the series to the CPU first.
    ax.plot(ts, fs_tr[i][k + 1:].detach().cpu())
ax.set(title="Recorded firing rate, all training sweeps", xlabel="t", ylabel="Firing rate");

In [None]:
# Inspect the trained model: its `a` and `b` attributes, then the activation's
# `poly_coeff` (presumably the polynomial coefficients of `g`; meaning defined in model.py).
print(model.a, model.b)
print(model.g.poly_coeff)