In [None]:
import os
from importlib import reload

import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1.axes_divider import make_axes_locatable
import pandas as pd
from scipy.io import loadmat
from scipy.stats import binned_statistic

import seaborn as sns
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import TensorDataset, DataLoader

import model, utils, inference
reload(model)

In [None]:
def load_data(path='SpTimesRGC.mat', n_time_bins=20 * 60 * 120, n_neurons=27):
    """Load spike times from a MATLAB file and count spikes per unit-width bin.

    The file must contain a variable ``SpTimes`` whose first row holds one
    2-D array per neuron; the first column of each array is read as that
    neuron's spike times. Times are counted into bins with edges
    ``0, 1, ..., n_time_bins``, i.e. they are interpreted in units of one bin.
    Spike times outside ``[0, n_time_bins]`` are not counted.

    Args:
        path: Location of the ``.mat`` file.
        n_time_bins: Number of unit-width bins. The default, 20 * 60 * 120,
            follows the original note "20 min * 119.9820 Hz".
        n_neurons: Number of leading entries of ``SpTimes`` to read.

    Returns:
        ``numpy.ndarray`` of shape ``(n_time_bins, n_neurons)`` holding the
        spike count of every neuron in every bin (float dtype).
    """
    sp_times = loadmat(path, squeeze_me=False, struct_as_record=False)['SpTimes'][0]
    # Same edges as hstack(([0], linspace(1, n, n))): 0, 1, ..., n_time_bins.
    bin_edges = np.arange(n_time_bins + 1, dtype=float)
    spikes = np.zeros((n_time_bins, n_neurons))
    for i in range(n_neurons):
        cell = sp_times[i]
        if cell.size == 0:
            # A neuron without any spike has no first column to read;
            # its column simply stays all-zero.
            continue
        times = cell[:, 0]
        # 'count' ignores the `values` argument; the times themselves are
        # passed so the call stays within the documented signature.
        spikes[:, i] = binned_statistic(times, times, bins=bin_edges, statistic='count')[0]
    return spikes

# Binned spike counts; load_data() returns an array of shape (n_time_bins, n_neurons).
spikes = load_data()

## hyper-parameters
# `decay`, `window_size` and `dt * window_size` are the three arguments handed to
# utils.exp_basis below; their exact meaning is defined in `utils` (not visible here).
decay = 5
# Alternative value left by the author: dt = 1/120.
# NOTE(review): load_data() sizes its bins as 20 * 60 * 120, which would suggest
# dt = 1/120 — confirm that 0.05 is the intended value.
dt = 0.05
window_size = 1
# Every recorded neuron is treated as visible, and the model gets no extra neurons
# beyond those (n_neurons == n_vis_neurons).
n_vis_neurons = spikes.shape[1]
n_neurons = n_vis_neurons
basis = utils.exp_basis(decay, window_size, dt*window_size)


# Train/test split of the binned recording, cut into fixed-length trials.
N_TRAIN_BINS = 96000  # leading time bins used for training; the remainder is the test set
TRIAL_LEN = 100       # time bins per trial


def _to_trials(binned_spikes, trial_len):
    """Reshape (n_bins, n_neurons) counts into a float32 tensor (n_trials, trial_len, n_neurons)."""
    n_columns = binned_spikes.shape[-1]
    return torch.from_numpy(binned_spikes.reshape(-1, trial_len, n_columns)).to(torch.float32)


vis_spikes_list_train = _to_trials(spikes[:N_TRAIN_BINS], TRIAL_LEN)
vis_spikes_list_test = _to_trials(spikes[N_TRAIN_BINS:], TRIAL_LEN)

# Basis-convolved versions of both splits (computed by the project-local utils helper).
convolved_vis_spikes_list_train = utils.convolve_spikes_with_basis(vis_spikes_list_train, basis)
convolved_vis_spikes_list_test = utils.convolve_spikes_with_basis(vis_spikes_list_test, basis)

# Each dataset item pairs a trial with its convolved counterpart; batches keep the
# original trial order (shuffle=False).
train_dataset = TensorDataset(vis_spikes_list_train, convolved_vis_spikes_list_train)
train_dataloader = DataLoader(train_dataset, batch_size=32, shuffle=False)

# Build the model from the project-local `model` module (constructor semantics are
# defined there and not visible in this notebook).
inf_model = model.POGLM(n_neurons, n_vis_neurons, basis)
# Replace the linear layer's weight and bias with all-zero tensors of shape
# (n_neurons, n_neurons) and (n_neurons,), so training starts from a zero linear map.
with torch.no_grad():
    inf_model.linear.weight.data = torch.zeros((n_neurons, n_neurons))
    inf_model.linear.bias.data = torch.zeros((n_neurons, ))
    
# Adam over all model parameters, learning rate 0.01.
inf_optimizer = torch.optim.Adam(inf_model.parameters(), lr=0.01)


n_epochs = 10
print_freq = 1  # print the running loss every `print_freq` epochs

# Per-epoch training loss: mean over batches of the batch-averaged negative
# value returned by inf_model.complete_log_likelihood.
epoch_loss_list = torch.zeros(n_epochs)

for epoch in range(n_epochs):
    for spikes_list, convolved_spikes_list in train_dataloader:
        batch_size = spikes_list.shape[0]
        loss = 0
        # Per-trial names are deliberately distinct from `spikes`: reusing that
        # name here would overwrite the dataset array returned by load_data().
        for trial_spikes, trial_convolved_spikes in zip(spikes_list, convolved_spikes_list):
            # Columns from n_vis_neurons onwards are passed as the "hidden" part with a
            # leading axis of size 1; the first n_vis_neurons columns are the visible part.
            hid_spikes_list = trial_spikes[None, :, n_vis_neurons:]
            convolved_hid_spikes_list = trial_convolved_spikes[None, :, n_vis_neurons:]
            vis_spikes = trial_spikes[:, :n_vis_neurons]
            convolved_vis_spikes = trial_convolved_spikes[:, :n_vis_neurons]
            # Accumulate the negative log-likelihood of this trial (first entry of the result).
            loss -= inf_model.complete_log_likelihood(hid_spikes_list, convolved_hid_spikes_list, vis_spikes, convolved_vis_spikes)[0]

        # Average over the batch, then take one optimizer step.
        loss /= batch_size
        loss.backward()
        inf_optimizer.step()
        inf_optimizer.zero_grad()

        epoch_loss_list[epoch] += loss.item()
    epoch_loss_list[epoch] /= len(train_dataloader)

    if epoch % print_freq == 0:
        print(epoch, epoch_loss_list[epoch], flush=True)

            
def evaluate_rgc_0(inf_model, spikes_list, convolved_spikes_list, seed: int = 0):
    """Score every trial with ``inf_model`` and return the results as a table.

    For each trial the columns up to the module-level ``n_vis_neurons`` are
    passed as the visible part and the remaining columns (with a leading axis
    of size 1) as the hidden part to ``inf_model.complete_log_likelihood``;
    the first entry of its result is recorded for that trial.

    Args:
        inf_model: Object exposing ``complete_log_likelihood(hid_spikes_list,
            convolved_hid_spikes_list, vis_spikes, convolved_vis_spikes)``.
        spikes_list: Tensor indexed as ``[trial, time, neuron]``.
        convolved_spikes_list: Tensor aligned with ``spikes_list``, indexed
            the same way.
        seed: Value passed to ``torch.manual_seed`` before scoring. Note that
            this reseeds torch's global random number generator.

    Returns:
        ``pandas.DataFrame`` with one row per trial and the float columns
        ``'marginal log-likelihood'`` (the recorded value) and ``'ELBO'``
        (always NaN; it is not computed here).
    """
    n_samples = spikes_list.shape[0]

    torch.manual_seed(seed)

    # Collect plain floats instead of writing tensors cell by cell into an
    # object-dtype frame, so the resulting columns are numeric.
    log_likelihoods = np.empty(n_samples, dtype=float)
    with torch.no_grad():
        for sample in range(n_samples):
            trial_spikes = spikes_list[sample]
            trial_convolved_spikes = convolved_spikes_list[sample]

            hid_spikes_list = trial_spikes[None, :, n_vis_neurons:]
            convolved_hid_spikes_list = trial_convolved_spikes[None, :, n_vis_neurons:]
            vis_spikes = trial_spikes[:, :n_vis_neurons]
            convolved_vis_spikes = trial_convolved_spikes[:, :n_vis_neurons]
            log_likelihoods[sample] = float(
                inf_model.complete_log_likelihood(
                    hid_spikes_list, convolved_hid_spikes_list, vis_spikes, convolved_vis_spikes
                )[0]
            )

    return pd.DataFrame(
        {'marginal log-likelihood': log_likelihoods, 'ELBO': np.nan},
        index=np.arange(n_samples),
        columns=['marginal log-likelihood', 'ELBO'],
    )


# Average the per-trial test metrics into a single summary row.
df = evaluate_rgc_0(inf_model, vis_spikes_list_test, convolved_vis_spikes_list_test).mean().to_frame().T
# The 'time' column is written as NaN (no timing is recorded here).
df['time'] = np.nan
# DataFrame.to_csv does not create missing parent directories, so make sure
# the output folder exists before writing.
os.makedirs('csv', exist_ok=True)
df.to_csv('csv/rgc_0.csv', index=False)