In [None]:
import os
from importlib import reload
from typing import Optional

import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1.axes_divider import make_axes_locatable
import pandas as pd
from scipy.io import loadmat
from scipy.stats import binned_statistic

import seaborn as sns
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import TensorDataset, DataLoader

from poglm import model, utils, inference

reload(model)

In [20]:
## data
# Spike-sorted recording for one channel. Column 0 of `cluster_class` is used as the
# unit label and column 1 as the spike time; the /1000 presumably converts ms -> s
# (TODO confirm against the CRCNS pvc-5 documentation).
x = loadmat('crcns-pvc5/rawSpikeTime/times_090425blk10_ch92.mat')['cluster_class']
timestamps_list = [x[x[:, 0] == i, 1] / 1000 for i in np.unique(x[:, 0])]
# Bin the spike times (dt=0.02, T=900). Below, axis 0 is treated as time bins and
# axis 1 as neurons.
spikes = torch.from_numpy(utils.continuous_to_discrete(timestamps_list, dt=0.02, T=900)).to(torch.float32)

## hyper-parameters
decay = 0.25
window_size = 10
n_vis_neurons = spikes.shape[1]
n_neurons = n_vis_neurons  # no hidden neurons: every recorded unit is visible
basis = utils.exp_basis(decay, window_size, window_size)

# First half of the recording is the training set, second half the test set; each half
# is cut into trials of `trial_len` bins -> (n_trials, trial_len, n_vis_neurons).
n_train_bins = 22500
trial_len = 100
assert spikes.shape[0] == 2 * n_train_bins, (
    f'expected {2 * n_train_bins} time bins, got {spikes.shape[0]}'
)
vis_spikes_list_train = spikes[:n_train_bins].reshape(-1, trial_len, n_vis_neurons)
vis_spikes_list_test = spikes[n_train_bins:].reshape(-1, trial_len, n_vis_neurons)
convolved_vis_spikes_list_train = utils.convolve_spikes_with_basis(vis_spikes_list_train, basis, direction='forward')
convolved_vis_spikes_list_test = utils.convolve_spikes_with_basis(vis_spikes_list_test, basis, direction='forward')
train_dataset = TensorDataset(vis_spikes_list_train, convolved_vis_spikes_list_train)
train_dataloader = DataLoader(train_dataset, batch_size=25, shuffle=False)

torch.manual_seed(0)
inf_model = model.POGLM(n_neurons, n_vis_neurons, basis)
# Start from an all-zero linear layer. Zeroing in place keeps whatever parameter shape
# POGLM allocated instead of assuming (n_neurons, n_neurons).
with torch.no_grad():
    inf_model.linear.weight.zero_()
    inf_model.linear.bias.zero_()

inf_optimizer = torch.optim.Adam(inf_model.parameters(), lr=0.1)

n_epochs = 20
print_freq = 1

epoch_loss_list = torch.zeros(n_epochs)

for epoch in range(n_epochs):
    for spikes_list, convolved_spikes_list in train_dataloader:
        batch_size = spikes_list.shape[0]
        loss = 0
        for sample in range(batch_size):
            # Per-trial names: do not rebind `spikes`, which holds the full recording.
            trial_spikes = spikes_list[sample]
            trial_convolved = convolved_spikes_list[sample]

            # Columns [n_vis_neurons:] are the hidden neurons (an empty slice here since
            # n_neurons == n_vis_neurons); they get a leading sample axis via [None, ...].
            hid_spikes_list = trial_spikes[None, :, n_vis_neurons:]
            convolved_hid_spikes_list = trial_convolved[None, :, n_vis_neurons:]
            vis_spikes = trial_spikes[:, :n_vis_neurons]
            convolved_vis_spikes = trial_convolved[:, :n_vis_neurons]
            loss -= inf_model.complete_log_likelihood(hid_spikes_list, convolved_hid_spikes_list, vis_spikes, convolved_vis_spikes)[0]

        # Mean negative log-likelihood over the trials in this batch.
        loss /= batch_size
        loss.backward()
        inf_optimizer.step()
        inf_optimizer.zero_grad()

        epoch_loss_list[epoch] += loss.item()
    # Average of the per-batch losses for this epoch.
    epoch_loss_list[epoch] /= len(train_dataloader)

    if epoch % print_freq == 0:
        print(epoch, epoch_loss_list[epoch].item(), flush=True)

os.makedirs('model', exist_ok=True)
torch.save(inf_model.state_dict(), 'model/GLM.pt')
            
def evaluate_rgc_0(inf_model, spikes_list, convolved_spikes_list, seed: int = 0, n_vis: Optional[int] = None):
    """Score a fitted model on held-out trials, one trial at a time.

    Args:
        inf_model: object exposing ``complete_log_likelihood(hid_spikes_list,
            convolved_hid_spikes_list, vis_spikes, convolved_vis_spikes)``; element
            ``[0]`` of its return value is taken as the trial's score and must be
            convertible with ``float()``.
        spikes_list: trials indexed along axis 0; each trial is sliced as
            ``[time, neuron]`` (assumed shape ``(n_trials, n_bins, n_neurons)`` — TODO confirm).
        convolved_spikes_list: basis-convolved counterpart of ``spikes_list``, same layout.
        seed: passed to ``torch.manual_seed`` before scoring.
        n_vis: number of leading neuron columns treated as visible. When omitted, falls
            back to the notebook-level global ``n_vis_neurons`` (the original behaviour).

    Returns:
        DataFrame indexed ``0..n_trials-1`` with float64 columns
        ``'marginal log-likelihood'`` (per-trial score) and ``'ELBO'`` (always NaN here).
    """
    if n_vis is None:
        # Backward-compatible fallback to the global defined in the training cell.
        n_vis = n_vis_neurons

    n_samples = spikes_list.shape[0]
    torch.manual_seed(seed)

    # Preallocate plain floats so the column is numeric (float64), not object dtype
    # holding torch tensors — even when there are zero trials.
    log_likelihoods = np.empty(n_samples)
    with torch.no_grad():
        for sample in range(n_samples):
            trial_spikes = spikes_list[sample]
            trial_convolved = convolved_spikes_list[sample]

            # Hidden columns get a leading sample axis; visible columns are passed as-is.
            # When n_vis covers every neuron the hidden slices are empty, so this
            # complete-data score presumably equals the marginal log-likelihood the
            # column is named after (TODO confirm in poglm).
            score = inf_model.complete_log_likelihood(
                trial_spikes[None, :, n_vis:],
                trial_convolved[None, :, n_vis:],
                trial_spikes[:, :n_vis],
                trial_convolved[:, :n_vis],
            )[0]
            log_likelihoods[sample] = float(score)

    return pd.DataFrame(
        {'marginal log-likelihood': log_likelihoods, 'ELBO': np.nan},
        index=np.arange(n_samples),
    )


# One-row summary: mean of each per-trial test score, plus a `time` column written as NaN.
# astype(float) makes the mean numeric even if the per-trial frame has object-dtype
# columns (e.g. 0-dim tensors), so the CSV holds numbers rather than tensor reprs.
df = evaluate_rgc_0(inf_model, vis_spikes_list_test, convolved_vis_spikes_list_test).astype(float).mean().to_frame().T
df['time'] = np.nan
os.makedirs('csv', exist_ok=True)
df.to_csv('csv/GLM.csv', index=False)

0 tensor(165.1292)
1 tensor(136.4076)
2 tensor(128.8828)
3 tensor(125.4257)
4 tensor(123.5623)
5 tensor(122.6213)
6 tensor(121.8679)
7 tensor(121.3196)
8 tensor(120.9073)
9 tensor(120.6300)
10 tensor(120.4323)
11 tensor(120.2767)
12 tensor(120.1534)
13 tensor(120.0594)
14 tensor(119.9867)
15 tensor(119.9293)
16 tensor(119.8845)
17 tensor(119.8500)
18 tensor(119.8232)
19 tensor(119.8026)
