In [None]:
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1.axes_divider import make_axes_locatable
import pandas as pd
import seaborn as sns
import torch
import torch.nn as nn
import torch.nn.functional as F

import model, utils, inference

# Re-import the (presumably project-local) `model` module so edits to its source
# take effect without restarting the kernel. Only `model` is reloaded here;
# `utils` and `inference` keep whatever version was first imported.
# NOTE(review): `%load_ext autoreload` + `%autoreload 2` would cover all three.
from importlib import reload
reload(model)

In [None]:
from torch.utils.data import DataLoader, TensorDataset

# --- Time discretisation and filter basis -----------------------------------
dt = 0.05          # presumably the time-bin width
window_size = 5    # number of bins spanned by the basis window
decay = 5          # decay parameter forwarded to utils.exp_basis
T = 5              # not referenced by any of the cells below

# --- Network size: spike columns [:n_vis_neurons] are treated as visible, ---
# --- the remaining columns as hidden (see the training cell).             ---
n_neurons = 5
n_vis_neurons = 3

basis = utils.exp_basis(decay, window_size, window_size * dt)

In [None]:
# Load one trial of the stored dataset and build the inference model.
trial = 0

# NOTE(review): unpickling can execute arbitrary code — only load data.pkl
# from a trusted source.
df = pd.read_pickle('data.pkl')

# TensorDataset indexes its tensors along the first dimension, so these are
# assumed to be tensors of shape (n_samples, ...) — TODO confirm with the
# code that produced data.pkl.
spikes_list_train = df.at[trial, 'spikes_list_train']
convolved_spikes_list_train = df.at[trial, 'convolved_spikes_list_train']

train_dataset = TensorDataset(spikes_list_train, convolved_spikes_list_train)
train_dataloader = DataLoader(train_dataset, batch_size=10, shuffle=False)

inf_model = model.POGLM(n_neurons, n_vis_neurons, basis)

# Start inference from all-zero weight and bias. Zeroing in place keeps the
# dtype, device and shape the model created the parameters with; assigning a
# fresh float32 CPU tensor to `.data` would silently override them.
with torch.no_grad():
    inf_model.linear.weight.zero_()
    inf_model.linear.bias.zero_()

inf_optimizer = torch.optim.Adam(inf_model.parameters())

In [None]:
# Fit the inference model by minimising the negative complete-data
# log-likelihood, averaged over the samples in each mini-batch.
n_epochs = 1000
print_freq = 100

# Per epoch: the sum over batches of each batch's mean negative log-likelihood.
epoch_loss_list = torch.zeros(n_epochs)

# Generating-model parameters for this trial (presumably a state_dict), used
# only to report parameter-recovery error. Looked up once, outside the loop.
gen_state = df.at[trial, 'gen_model']

for epoch in range(n_epochs):
    for spikes_list, convolved_spikes_list in train_dataloader:
        batch_size = spikes_list.shape[0]
        loss = 0
        for sample in range(batch_size):
            spikes = spikes_list[sample]
            convolved_spikes = convolved_spikes_list[sample]

            # Columns [:n_vis_neurons] are visible, the rest hidden. The hidden
            # slices get a leading axis of size 1 (presumably the "number of
            # hidden samples" axis complete_log_likelihood expects — TODO confirm).
            hid_spikes_list = spikes[None, :, n_vis_neurons:]
            convolved_hid_spikes_list = convolved_spikes[None, :, n_vis_neurons:]
            vis_spikes = spikes[:, :n_vis_neurons]
            convolved_vis_spikes = convolved_spikes[:, :n_vis_neurons]
            loss -= inf_model.complete_log_likelihood(
                hid_spikes_list, convolved_hid_spikes_list,
                vis_spikes, convolved_vis_spikes,
            )[0]

        loss /= batch_size
        loss.backward()
        inf_optimizer.step()
        inf_optimizer.zero_grad()

        epoch_loss_list[epoch] += loss.item()

    # Report every print_freq epochs and always after the last epoch.
    if epoch % print_freq == 0 or epoch == n_epochs - 1:
        with torch.no_grad():
            weight_err = (gen_state['linear.weight'] - inf_model.linear.weight).abs().mean()
            bias_err = (gen_state['linear.bias'] - inf_model.linear.bias).abs().mean()
        print(epoch, epoch_loss_list[epoch].item(), weight_err.item(), bias_err.item(),
              flush=True)

In [None]:
# Weights of the generating model's `linear` layer for this trial; the training
# printout reports the mean absolute error of the inferred weights against these.
plt.matshow(df.at[trial, 'gen_model']['linear.weight'])
plt.title('gen_model linear.weight')
plt.colorbar();

In [None]:
# Inferred `linear` weights after training. A colorbar is added to match the
# generating-model weight plot, so the two value scales can be compared.
# detach().cpu() gives a grad-free host tensor that matplotlib can convert.
plt.matshow(inf_model.linear.weight.detach().cpu())
plt.title('inf_model linear.weight')
plt.colorbar();