In [1]:
import numpy as np
import matplotlib.pyplot as plt
from glob import glob

In [2]:
def _style_raster_axis(ax, title):
    """Apply the shared title, axis-label, aspect and tick formatting to one raster panel."""
    ax.set_title(title + '\n', fontsize=30)
    ax.set_xlabel('Time(s)', fontsize=30)
    ax.set_ylabel('Node index', fontsize=30)
    ax.set_aspect(0.1)
    ax.tick_params(axis='both', labelsize=20)


def _add_lambda_colorbar(fig, ax, image, label):
    """Attach a colorbar for ``image`` in a narrow new axes placed just right of ``ax``."""
    box = ax.get_position()
    cax = fig.add_axes([box.x1 + 0.01, box.y0, 0.02, box.height])
    cbar = fig.colorbar(image, cax=cax)
    cbar.ax.set_ylabel(label, fontsize=30)
    cbar.ax.tick_params(axis='y', labelsize=20)
    return cbar


def plot_lambda_raster(im_o, im_t, sampling_rate=10000, figsize=(90, 80), save_path=None):
    """Plot predicted (top) and target (bottom) lambda rasters with colorbars.

    Parameters
    ----------
    im_o : array-like with ``.shape``, indexed (time step, node)
        Predicted values. Transposed before plotting so time runs along the
        x-axis and node index along the y-axis.
    im_t : array-like, same layout as ``im_o``
        Target values; drawn with the same extent as ``im_o``.
    sampling_rate : float, optional
        Time steps per second, used to convert ``im_o.shape[0]`` into the
        x-axis extent in seconds. Defaults to the previously hard-coded 10000.
    figsize : tuple of float, optional
        Matplotlib figure size in inches. Defaults to the previous (90, 80).
    save_path : str or path-like, optional
        If given, the figure is also written there with ``bbox_inches='tight'``.

    Returns
    -------
    None
        The figure is left open for the notebook's inline backend to display.
    """
    fig, axes = plt.subplots(nrows=2, ncols=1, figsize=figsize)
    duration = im_o.shape[0] / sampling_rate  # seconds
    n_nodes = im_o.shape[1]  # was hard-coded as 100
    extent = [0, duration, 0, n_nodes]
    for ax, raster, title in zip(axes, (im_o, im_t), ('Prediction', 'Target')):
        # origin='lower' puts row 0 (node 0) at y=0. With the default origin='upper',
        # an extent whose top is n_nodes draws node 0 at the top labelled ~n_nodes,
        # i.e. the node-index axis read upside down.
        image = ax.imshow(np.transpose(raster), cmap='binary', extent=extent, origin='lower')
        _style_raster_axis(ax, title)
        _add_lambda_colorbar(fig, ax, image, 'Lambda')
    if save_path is not None:
        fig.savefig(save_path, bbox_inches='tight')


In [3]:
import pickle

import torch

In [4]:
# Load the saved spike data (the saved output of the next cell shows shape (100, 4800000)).
# NOTE(review): pickle.load can execute arbitrary code -- only load files you created yourself.
with open('./data/LNP_spk_all.pickle', 'rb') as spike_file:
    spike = pickle.load(spike_file)

In [5]:
# Shape of the raw spike array as loaded from disk.
np.shape(spike)

(100, 4800000)

In [6]:
# Load the saved lambda data (the saved output of the next cell shows shape (100, 4800000)).
# NOTE(review): pickle.load can execute arbitrary code -- only load files you created yourself.
with open('./data/LNP_lam_all.pickle', 'rb') as lam_file:
    lam = pickle.load(lam_file)

In [7]:
# Shape of the raw lambda array as loaded from disk.
np.shape(lam)

(100, 4800000)

In [None]:
# Both pickles are already stored as (nodes, time) -- (100, 4800000) per the outputs
# above -- which is the layout every later cell assumes: `num_neurons = data.shape[0]`,
# sample windows are sliced along axis 1, and `data[0]` is plotted as one trace.
# The former `.transpose((1, 0))` swapped the axes, making num_neurons 4,800,000
# (the dense adjacency below would then need 4.8M x 4.8M floats) and total_time 100
# (so batch_size went negative and no samples were built).
# `[:1000]` keeps at most the first 1000 rows (nodes); it is a no-op for 100 nodes.
data = spike[:1000]
lam = lam[:1000]
assert data.shape == lam.shape, (data.shape, lam.shape)

In [None]:
# Nodes live on the first axis of `data`, time on the last.
num_neurons, total_time = data.shape[0], data.shape[-1]

In [None]:
time_steps = 200  # history length per sample; 200 steps = 20 ms (per the original note)
pred_steps = 20   # number of future steps to predict
window_size = time_steps + pred_steps - 1  # input length x per training sample
# Number of consecutive, non-overlapping samples that fit in the recording.
# Sample i starts at i * (window_size + 1) and reads up to (exclusive)
# step + time_steps + pred_steps == step + window_size + 1, so exactly
# total_time // (window_size + 1) samples fit. The previous
# `int(np.floor(...)) - 1` dropped the last valid sample and became negative
# for recordings shorter than one window.
batch_size = total_time // (window_size + 1)

In [None]:
# Input window length per training sample (time_steps + pred_steps - 1).
window_size

In [None]:
# Number of samples the loop below builds (it iterates over range(batch_size)).
batch_size

In [None]:
# Dense all-to-all adjacency without self-loops: 1.0 everywhere except the diagonal.
fully_connected = 1.0 - np.eye(num_neurons)

In [None]:
# Should be (num_neurons, num_neurons).
np.shape(fully_connected)

In [None]:
# Coordinates of every nonzero adjacency entry as a (2, num_edges) int64 array:
# row 0 holds the first-axis index, row 1 the second-axis index.
encoder_edge = np.stack(np.nonzero(fully_connected)).astype(np.int64)

In [None]:
# Edge index array: column k is the (first-axis, second-axis) index pair of the
# k-th nonzero entry of fully_connected.
encoder_edge

In [None]:
# Should be (2, num_neurons * (num_neurons - 1)) for the no-self-loop adjacency.
np.shape(encoder_edge)

In [None]:
# Convert to tensors for PyTorch Geometric. torch.as_tensor replaces the legacy
# typed constructors (torch.FloatTensor / torch.LongTensor), accepts either an
# ndarray or an existing tensor, and skips the copy when the dtype already matches.
data = torch.as_tensor(data, dtype=torch.float32)
lam = torch.as_tensor(lam, dtype=torch.float32)
encoder_edge = torch.as_tensor(encoder_edge, dtype=torch.long)

In [None]:
from tqdm import tqdm
from torch_geometric.data import Data

In [None]:
# Cut the recording into consecutive, non-overlapping chunks of (window_size + 1)
# steps; each chunk becomes one graph sample:
#   x = the first window_size steps of every row of `data`,
#   y = [lambda, spike] stacked on a new last axis for the pred_steps steps that
#       follow the first time_steps of the chunk -> (data.shape[0], pred_steps, 2).
# The loop variables (step, data_sample, lam_tar, ...) intentionally stay at
# notebook level: the inspection cells below read the last iteration's values.
data_list = []

for i in tqdm(range(batch_size)):
    step = (window_size + 1) * i
    target_start = step + time_steps
    target_stop = target_start + pred_steps
    data_sample = data[:, step:step + window_size]
    lam_tar = lam[:, target_start:target_stop]
    spk_tar = data[:, target_start:target_stop]
    lam_spk_tar = torch.stack((lam_tar, spk_tar), dim=-1)
    data_item = Data(x=data_sample, edge_index=encoder_edge, y=lam_spk_tar)
    data_list += [data_item]

In [None]:
# Start offset of the last sample built by the loop above. This relies on the loop
# variable leaking into the notebook namespace (NameError if the loop ran zero times).
step

In [None]:
# Input window of the last sample: (rows of data, window_size).
data_sample.size()

In [None]:
# Lambda targets of the last sample: (rows of lam, pred_steps).
lam_tar.size()

In [None]:
# Spike targets of the last sample: (rows of data, pred_steps).
spk_tar.size()

In [None]:
# Stacked [lambda, spike] targets of the last sample; last axis has size 2.
lam_spk_tar.size()

In [None]:
# Last Data object built by the loop above (same object as data_list[-1]);
# NameError if the loop ran zero times.
data_item

In [None]:
# Node-feature tensor of the last sample (same tensor as data_sample).
data_item.x.size()

In [None]:
# Edge index shared by every sample (the encoder_edge tensor).
data_item.edge_index.size()

In [None]:
# Target tensor of the last sample (same tensor as lam_spk_tar).
data_item.y.size()

In [None]:
# First row of `data`: node 0's spike train, given the (nodes, time) layout the
# notebook assumes (num_neurons = data.shape[0]).
fig, ax = plt.subplots(figsize=(12, 3))
ax.plot(data[0])
ax.set(title='Spike train, node 0', xlabel='Time step', ylabel='Spike')
plt.show()

In [None]:
# First row of `lam`: node 0's lambda trace, given the (nodes, time) layout the
# notebook assumes (num_neurons = data.shape[0]).
fig, ax = plt.subplots(figsize=(12, 3))
ax.plot(lam[0])
ax.set(title='Lambda, node 0', xlabel='Time step', ylabel='Lambda')
plt.show()