In [None]:
import numpy as np
import collections
from torch.utils import data

import sys
sys.path.append('.')
sys.path.append('../')

from einops import rearrange

import torch
import torch.nn as nn
from torch.nn import functional as FeatureAlphaDropout
import pandas as pd
import matplotlib.pyplot as plt

from torch.utils.data.dataloader import DataLoader

import math
from torch.utils.data import Dataset

from scipy import io as scipyio
import skimage
import skvideo.io
from utils import print_full

import os
import glob
parent_path = os.path.dirname(os.path.dirname(os.getcwd())) + "/"

In [None]:
# set up logging
import logging
logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s -   %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
)

In [None]:
from utils import set_seed
set_seed(25)

In [None]:
# R3D: (3 x T x H x W)

from SpikeVidUtils import image_dataset

train_path = parent_path + "code/data/OneCombo3/stimuli/"
# List the stimulus files once so the printed list is exactly the list loaded.
# NOTE(review): glob order is filesystem-dependent, and the [::-1] reversal
# assumes a particular order mapping files to the stimulus indices 0..2 used by
# trial_df_combo3 in the spike-loading cell — TODO confirm, and consider sorting.
stim_files = glob.glob(train_path + '/*.tif')[::-1]
if not stim_files:
    raise FileNotFoundError(f"no .tif stimulus files found in {train_path}")
print(stim_files)
video_stack = np.concatenate([skimage.io.imread(vid) for vid in stim_files],
                             axis=0, dtype=np.float32)

# video_stack = skimage.io.imread("/home/antonis/projects/slab/git/slab/transformer_exp/code/data/OneCombo3/stimuli/Combined Stimuli 3-grating.tif")
# video_stack = image_dataset(video_stack)
# video_stack = video_stack[::3]  # convert from 60 to 20 fps
# video_stack = video_stack.view(1, video_stack.shape[0], video_stack.shape[1], video_stack.shape[2], video_stack.shape[3])

n_stimuli = 3  # the stacked movies are split back into this many stimuli
video_stack = image_dataset(video_stack)
# NOTE(review): subsampling happens after concatenation, so stimulus boundaries
# only stay aligned if each movie's frame count is a multiple of 3 — confirm.
video_stack = video_stack[::3]  # convert from 60 to 20 fps
if video_stack.shape[0] % n_stimuli:
    raise ValueError(f"{video_stack.shape[0]} frames cannot be split evenly "
                     f"into {n_stimuli} stimuli")
video_stack = video_stack.view(n_stimuli, video_stack.shape[0] // n_stimuli,
                               *video_stack.shape[1:])
# video_stack = video_stack.transpose(-1, -2)

# rearrange(video_stack[0, 0:2].transpose(0,1), 'c t (h p1) (w p2) -> (t h w) (p1 p2 c)', p1=16, p2=16).shape

In [None]:
plt.figure()
plt.imshow(video_stack[0, 1].permute(1, 2, 0))
plt.figure()
plt.imshow(video_stack[1, 1].permute(1, 2, 0))
plt.figure()
plt.imshow(video_stack[2, 1].permute(1, 2, 0))

In [None]:
# spike_path = "/home/antonis/projects/slab/git/slab/transformer_exp/code/data/SImNew3D/neural/NatureMoviePart1-A" # "code/data/SImIm/simNeu_3D_WithNorm__Combo3.mat" 
from SpikeVidUtils import trial_df_combo3

spike_data = scipyio.loadmat(parent_path + "code/data/OneCombo3/spiketrain.mat")
spike_data = np.squeeze(spike_data['spiketrain'].T, axis=-1)
spike_data = [trial_df_combo3(spike_data, n_stim) for n_stim in range(3)]
spike_data = pd.concat(spike_data, axis=0)

spike_data['Trial'] = spike_data['Trial'] + 1
spike_data['Time'] = spike_data['Time'] * 0.0751
spike_data = spike_data[(spike_data['Time'] > 0) & (spike_data['Time'] <= 32)]

# vid_duration = [len(vid) * 1/20 for vid in vid_list]

df = spike_data
del spike_data

In [None]:
df = df[df['Trial'] > 20]

In [None]:
# df = pd.read_csv(parent_path + "code/data/OneCombo3/Combo3_all_stim.csv")
window = 0.5
dt = 0.01

from SpikeVidUtils import make_intervals

df['Interval'] = make_intervals(df, window)
df['Interval_dt'] = make_intervals(df, dt)
df['Interval_dt'] = (df['Interval_dt'] - df['Interval'] + window).round(2)
df = df.reset_index(drop=True)

In [None]:
n_dt = sorted((df['Interval_dt'].unique()).round(3)) # add last interval for EOS'

df['Time'] = df['Time'].round(3)

In [None]:
# df.groupby(['Interval', 'Trial']).size().plot.bar()
# df.groupby(['Interval', 'Trial']).agg(['nunique'])
df.groupby(['Interval', 'Trial']).size().nlargest(100)

In [None]:
from SpikeVidUtils import SpikeTimeVidData

## qv-vae feats
# frames = torch.load(parent_path + "code/data/SImNew3D/stimulus/vq-vae_code_feats-24-05-4x4x4.pt").numpy() + 2
# frame_feats = torch.load(parent_path + "code/data/SImNew3D/stimulus/vq-vae_embed_feats-24-05-4x4x4.pt").numpy()
# frame_block_size = frames.shape[-1] - 1

## resnet3d feats
frame_feats = video_stack.transpose(1, 2)

frame_block_size = 560
prev_id_block_size = 30
id_block_size = 30 * 2    # 95
block_size = frame_block_size + id_block_size + prev_id_block_size # frame_block_size * 2  # small window for faster training
frame_memory = 20   # how many frames back does model see
window = window

neurons = sorted(list(set(df['ID'])))
id_stoi = { ch:i for i,ch in enumerate(neurons) }
id_itos = { i:ch for i,ch in enumerate(neurons) }

# translate neural embeddings to separate them from ID embeddings
# frames = frames + [*id_stoi.keys()][-1] 
neurons = [i for i in range(df['ID'].min(), df['ID'].max() + 1)]
# pixels = sorted(np.unique(frames).tolist())
feat_encodings = neurons + ['SOS'] + ['EOS'] + ['PAD']  # + pixels 
stoi = { ch:i for i,ch in enumerate(feat_encodings) }
itos = { i:ch for i,ch in enumerate(feat_encodings) }
stoi_dt = { ch:i for i,ch in enumerate(n_dt) }
itos_dt = { i:ch for i,ch in enumerate(n_dt) }
max(list(itos_dt.values()))

In [None]:
# train_len = round(len(df)*(4/5))
# test_len = round(len(df) - train_len)

# train_data = df[:train_len]
# test_data = df[train_len:train_len + test_len].reset_index().drop(['index'], axis=1)

# Hold out four trials per stimulus for testing. Trial numbers are counted back
# from the end of each stimulus block: (stimulus + 1) * 20 - offset
# (assumes 20 trials per stimulus — TODO confirm).
held_out_offsets = [3, 15, 5, 18]
n = [(n_stim + 1) * 20 - offset
     for n_stim in range(3)
     for offset in held_out_offsets]

is_test_trial = df['Trial'].isin(n)
train_data = df[~is_test_trial].reset_index(drop=True)
test_data = df[is_test_trial].reset_index(drop=True)
# NOTE(review): an earlier cell keeps only Trial > 20, so trial 5 (and the
# stimulus-0 held-out trials 17, 5, 15, 2) are absent from df at this point;
# small_data is therefore empty — verify this is intended.
small_data = df[df['Trial'].isin([5])].reset_index(drop=True)

In [None]:
from SpikeVidUtils import SpikeTimeVidData2

# train_dat1aset = spikeTimeData(spikes, block_size, dt, stoi, itos)

train_dataset = SpikeTimeVidData2(train_data, None, block_size, id_block_size, frame_block_size, prev_id_block_size, window, frame_memory, stoi, itos, neurons, stoi_dt, itos_dt, frame_feats, pred=False)
test_dataset = SpikeTimeVidData2(test_data, None, block_size, id_block_size, frame_block_size, prev_id_block_size, window, frame_memory, stoi, itos, neurons, stoi_dt, itos_dt, frame_feats, pred=False)
# dataset = SpikeTimeVidData(df, frames, frame_feats, block_size, frame_block_size, prev_id_block_size, window, frame_memory, stoi, itos)
# single_batch = SpikeTimeVidData(df[df['Trial'].isin([5])], None, block_size, frame_block_size, prev_id_block_size, window, frame_memory, stoi, itos, neurons, stoi_dt, itos_dt, frame_feats)
small_dataset = SpikeTimeVidData2(small_data, None, block_size, id_block_size, frame_block_size, prev_id_block_size, window, frame_memory, stoi, itos, neurons, stoi_dt, itos_dt, frame_feats, pred=False)


print(f'train: {len(train_dataset)}, test: {len(test_dataset)}')

In [None]:
# def get_class_weights(df, population_size):
#     class_freq = df.groupby(['ID']).size().nlargest(2)
#     class_freq_pad = np.array(class_freq.tolist() + [class_freq.max()]*(population_size - len(class_freq)), dtype=np.float32)
#     return torch.tensor(np.reciprocal(class_freq_pad) * class_freq.max(), dtype=torch.float32) / class_freq.max()

def get_class_weights(df, population_size, train_df=None, neuron_ids=None,
                      id_block=None, prev_id_block=None):
    """Compute inverse-frequency class weights for the ID-token vocabulary.

    The returned vector is laid out as
    ``[neuron ids 0..neuron_ids[-1]] + [SOS] + [EOS] + [PAD]``.

    Args:
        df: spike table with an ``'ID'`` column. Per-id counts come from it,
            and ``len(df)`` is subtracted when estimating the PAD count.
        population_size: unused; kept for backward compatibility.
        train_df: table whose distinct ``('Interval', 'Trial')`` pairs give
            the number of sequences. Defaults to the global ``train_data``.
        neuron_ids: sequence whose last element is the largest neuron id.
            Defaults to the global ``neurons``.
        id_block, prev_id_block: ID slots per sequence (current and previous
            window). Default to the globals ``id_block_size`` and
            ``prev_id_block_size``.

    Returns:
        float32 tensor of ``1 / frequency`` per class. Classes with zero
        frequency get weight 1.
    """
    # Fall back to the notebook globals the original implementation read.
    if train_df is None:
        train_df = train_data
    if neuron_ids is None:
        neuron_ids = neurons
    if id_block is None:
        id_block = id_block_size
    if prev_id_block is None:
        prev_id_block = prev_id_block_size

    # Number of distinct (Interval, Trial) windows.
    len_data = len(train_df.drop_duplicates(subset=['Interval', 'Trial']))
    # Spike count per neuron id; one value_counts pass instead of one scan per id.
    id_counts = df['ID'].value_counts()
    id_freq = [int(id_counts.get(i, 0)) for i in range(neuron_ids[-1] + 1)]
    sos_freq = [len_data * 2]
    eos_freq = [len_data * 1]
    # Estimated PAD tokens = all ID slots minus real spikes. Clamped so that
    # len(df) > slots cannot produce a negative frequency (and negative weight).
    pad_freq = [max(len_data * (id_block + prev_id_block) - len(df), 0)]
    class_freq = np.array(id_freq + sos_freq + eos_freq + pad_freq, dtype=np.float32)
    # Zero-count classes become +inf here; they are replaced below.
    with np.errstate(divide='ignore'):
        class_freq = torch.tensor(np.reciprocal(class_freq) * class_freq.max(),
                                  dtype=torch.float32) / class_freq.max()
    # The original passed only nan=1, which maps +inf to float32 max (~3.4e38)
    # rather than 1; replace every non-finite weight with 1 explicitly.
    return torch.nan_to_num(class_freq, nan=1.0, posinf=1.0, neginf=1.0)

class_weights = get_class_weights(df, train_dataset.id_population_size)

In [None]:
video_stack.shape

In [None]:
# `get_class_weights` is already defined in an earlier cell with an identical
# body; redefining it here would silently shadow that definition. Reuse the
# existing function to (re)compute the weights.
class_weights = get_class_weights(df, train_dataset.id_population_size)

In [None]:
# class_weights.max()

In [3]:
def run_epoch(split):
    """Debug stub: print whether ``split`` is the training split ('train')."""
    print(split == 'train')

run_epoch('test')

False


In [None]:
from model_perceiver import GPT, GPTConfig, neuralGPTConfig, Decoder
# initialize config class and model (holds hyperparameters)
mconf = GPTConfig(train_dataset.population_size, block_size,    # frame_block_size
                        id_vocab_size=train_dataset.id_population_size,
                        frame_block_size=frame_block_size,
                        id_block_size=id_block_size,  # frame_block_size
                        n_dt=len(n_dt),
                        data_size=train_dataset.size,
                        class_weights=class_weights,
                        pretrain=True,
                        n_layer=8, n_head=4, n_embd=256,
                        temp_emb=True, pos_emb=True,
                        id_drop=0.2, im_drop=0.2)
model = GPT(mconf)
# model.load_state_dict(torch.load(parent_path +  "code/transformer_vid3/runs/models/12-08-21-00:35-e:211-b:272-l:5-h:2-ne:512-higher_order.pt"))

In [None]:
from trainer import Trainer, TrainerConfig
# model.load_state_dict(torch.load(parent_path +  "code/transformer_vid3/runs/models/12-01-21-14:18-e:19-b:239-l:4-h:2-ne:512-higher_order.pt"))
# model.load_state_dict(torch.load(parent_path +  "code/transformer_vid3/runs/models/12-14-21-23:44-e:17-b:650-l:8-h:4-ne:256-higher_order.pt"))

max_epochs = 400
batch_size = 16
tconf = TrainerConfig(max_epochs=max_epochs, batch_size=batch_size, learning_rate=3e-5, 
                      num_workers=4, lr_decay=False, warmup_tokens=2e5, 
                      decay_weights=True,
                      final_tokens=len(train_dataset)*(block_size // 8) * (max_epochs),
                      clip_norm=3.0, grad_norm_clip=2.0,
                      dataset='higher_order', mode='predict',
                      block_size=train_dataset.block_size,
                      id_block_size=train_dataset.id_block_size,
                      show_grads=False, plot_raster=False,
                      pretrain_ims=False, pretrain_ids=False)

trainer = Trainer(model, train_dataset, test_dataset, tconf, mconf)
trainer.train()  

In [None]:
# model.load_state_dict(torch.load(parent_path + "code/transformer_vid3/model_cnn_78.pt"))
# torch.save(model.state_dict(), 'epoch_382_model.pt')

In [None]:
""" Predict using TEST dataset """

from utils import predict_raster, predict_raster_resnet, predict_raster_enc_dec, predict_raster_recursive, predict_beam_search, predict_raster_recursive_time, predict_beam_search_time, predict_raster_hungarian
%matplotlib inline

loader = DataLoader(test_dataset, shuffle=False, pin_memory=False,
                                  batch_size=1, num_workers=1)

# device = torch.cuda.current_device()
# model = model.to(device)
# model.load_state_dict(torch.load(parent_path +  "code/transformer_vid3/runs/models/12-14-21-11:49-e:1-b:650-l:4-h:4-ne:256-higher_order.pt"))

""" 

To predict only neurons we pass <frame_end> so we see predictions only for Neurons 
If you want to also see frame_tokens, just pass <frame_end=0>

NOTE: 512 ID is the <end-of-sequence-id>. Right now, makes no difference if I include
it in loss, here it is included in loss and predictions.

"""
# true, predicted, true_timing, predicted_timing = predict_time_raster(model, loader, 
#                                                                     f_block_sz=frame_block_size, id_block_sz=frame_block_size, 
#                                                                     get_dt=True)

# true, predicted, true_timing, predicted_timing = predict_time_raster(model, loader, 
#                                                                     f_block_sz=frame_block_size, id_block_sz=frame_block_size, 
#                                                                     get_dt=True)

# true, predicted = predict_raster(model, loader)

# true, predicted = predict_beam_search(model, loader, stoi, frame_end=frame_block_size)
true, predicted, true_timing = predict_raster_recursive(model, loader, stoi, sample=True, top_k=15, gpu=True, frame_end=frame_block_size)
# true, predicted = predict_raster_hungarian(model, loader)
# true, predicted = predict_raster(model, loader, gpu=True)

true_df = pd.DataFrame(true.numpy())
predicted_df = pd.DataFrame(predicted.numpy())
print(len(true_df[true_df[0] == 512]), len(predicted_df[predicted_df[0] == 512])) 

In [None]:
# model.load_state_dict(torch.load(parent_path + "code/transformer_vid3/runs/models/12-10-21-18:16-e:18-b:635-l:3-h:4-ne:256-higher_order.pt"))
torch.save(model.state_dict(), 'epoch_400_modelGPT.pt')

In [None]:
test_data['Time']

In [None]:
# loader = DataLoader(train_dataset, shuffle=False, pin_memory=False,
#                                   batch_size=1, num_workers=1)

# true_train, predicted_train, true_timing_train = predict_raster_recursive(model, loader, stoi, sample=None, top_k=None)

In [None]:
true_df = pd.DataFrame(true.numpy())
predicted_df = pd.DataFrame(predicted.numpy())
print(len(true_df[true_df[0] >= 512]), len(predicted_df[predicted_df[0] >= 512])) 

In [None]:
true

In [None]:
def plot_this(true_df, predicted_df, max_token=165, column=0):
    """Overlay bar charts of token counts for true vs. predicted sequences.

    Args:
        true_df, predicted_df: frames with one token per row in ``column``
            (this cell builds them from ``true.numpy()`` /
            ``predicted.numpy()``, so the column label is 0).
        max_token: only tokens strictly below this value are counted.
            NOTE(review): the default 165 is the original hard-coded cutoff —
            presumably the neuron-id range; confirm.
        column: column label holding the token ids.
    """
    plt.figure(figsize=(30, 20))
    freq_true = true_df[true_df[column] < max_token].groupby([column]).size()
    print(freq_true)
    freq_pred = predicted_df[predicted_df[column] < max_token].groupby([column]).size()
    plt.bar(freq_pred.index, freq_pred, label='predicted', alpha=0.5)
    plt.bar(freq_true.index, freq_true, label='true', alpha=0.5)
    plt.title('Neuron Firing Distribution (PSTH Loss)', fontsize=40)
    plt.legend(fontsize=30)
    plt.show()

plot_this(pd.DataFrame(true.numpy()), pd.DataFrame(predicted.numpy()))

In [None]:
df['Trial'][df['Trial'] == 10]

In [None]:
def plot_this(true_df, predicted_df, true_label='Trial 5', pred_label='Trial 10'):
    """Overlay per-neuron spike-count bar charts for two spike tables.

    In this cell it compares two recorded trials: the call passes trial 5 as
    ``true_df`` and trial 10 as ``predicted_df``. The default labels match that.

    Args:
        true_df, predicted_df: spike tables with an ``'ID'`` column.
        true_label, pred_label: legend labels for the two series.
    """
    plt.figure(figsize=(30, 20))
    freq_true = true_df.groupby(['ID']).size()
    freq_pred = predicted_df.groupby(['ID']).size()
    # Each series carries the label of the frame it came from (the original
    # labelled the predicted_df counts 'Trial 5' and the true_df counts 'Trial 10').
    plt.bar(freq_pred.index, freq_pred, label=pred_label, alpha=0.5)
    plt.bar(freq_true.index, freq_true, label=true_label, alpha=0.5)
    plt.title('Neuron Firing Distribution (PSTH Loss)', fontsize=40)
    plt.legend(fontsize=30)
    plt.show()

plot_this(df[df['Trial'] == 5], df[df['Trial'] == 10])

In [None]:
len(true)

In [None]:
len_pred = len(true)
# len_pred = 1000
plt.figure(figsize=(40,40))
plt.title('Pixel / Spike Raster', size=50)
plt.xlabel('Time')
plt.ylabel('Neuron ID')
plt.scatter(np.arange(len_pred), true[:len_pred], alpha=0.6, label='true', marker='o')
plt.scatter(np.arange(len_pred), predicted[:len_pred], alpha=0.6, label='predicted', marker='x')
plt.legend(fontsize=50)

In [1]:
len([0])

1

In [None]:
true_df = pd.DataFrame(true.numpy())
predicted_df = pd.DataFrame(predicted.numpy())
print(len(true_df[true_df[0] == 512]), len(predicted_df[predicted_df[0] == 512])) 

plt.figure(figsize=(30,20))
n_min = 10000
freq_true = df[(df['ID'] < n_min) & (df['Trial'] == 4)].groupby(['ID']).size()
freq_pred = predicted_df[predicted_df[0] < n_min].groupby([0]).size()
plt.bar(freq_true.index, freq_true, label='true', alpha=0.3)
# plt.bar(freq_pred.index, freq_pred, label='predicted', alpha=0.3)
plt.title('Neuron Firing Distribution (PSTH Loss)', fontsize=40)
plt.legend(fontsize=30)
plt.show()

In [None]:
# Keep predictions in their own frame: assigning to `df` would clobber the
# spike table that later cells still filter by 'Interval' / 'Trial'.
results_df = pd.DataFrame({'True': true.numpy(), 'Predicted': predicted.numpy()})

# df_pred = pd.DataFrame({'True':true, 'Predicted':predicted, 'Time':true_timing / 100})

# NOTE(review): `true` / `predicted` appear to come from the test-set loader,
# yet the file name says "train" — confirm the intended name.
results_df.to_csv('GPT-one_combo_73-train.csv', index=False)

# df_pred = pd.read_csv(parent_path + "/transformer_vid3/analysis/cs-k25_2-simNeu_3D_WithNorm__Combo3-train.csv")
# df_pred = df_pred.iloc[:, 1:]

In [None]:
df

In [None]:
block_size

In [None]:
146 + 22

In [None]:
df = df.reset_index(drop=True)

In [None]:
train_dataset.id_block_size + train_dataset.id_prev_block_size

In [None]:
frame_block_size

In [None]:
train_dataset.id_prev_block_size

In [None]:
loader = DataLoader(test_dataset, shuffle=True, pin_memory=False,
                                  batch_size=2, num_workers=1)

In [None]:
iterable = iter(loader)

In [None]:
frame_feats.shape

In [None]:
x, y = next(iterable)

In [None]:
x['frames'].shape

In [None]:
x['dt_prev']

In [None]:
# df[(df['Interval'] == x['interval'][0]) & (df['Trial'] == x['interval'][1])])

In [None]:
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
for key, value in x.items():
    x[key] = x[key].to(device)
for key, value in y.items():
    y[key] = y[key].to(device)

In [None]:
# Put the model on the same device the batch tensors were moved to in the
# previous cell; moving it to CPU unconditionally fails when `device` is CUDA.
model = model.to(device)
model(x, y)

In [None]:
y['dt'].shape

In [None]:
x['id'].shape

In [None]:
x['dt'].shape

In [None]:
x['id'][:, :x['id'].shape[-1] - x['pad']]

In [None]:
y['id'][:, :x['id'].shape[-1] - x['pad']]

In [None]:
x['interval']

In [None]:
plt.imshow(x['frames'][0, :, -1].permute(1, 2, 0), cmap='gray')

In [None]:
x['interval'][0]

In [None]:
df[(df['Interval'] >= (x['interval'][0] - 1)) & (df['Interval'] <= x['interval'][0]) & (df['Trial'] == x['interval'][1])]

In [None]:
x['id_prev']

In [None]:
x['dt']

In [None]:
x['dt_prev']

In [None]:
x['id']

In [None]:
id_s = id_block_size + prev_id_block_size

In [None]:
y['id'][:, :id_s - x['pad'] + 1]

In [None]:
math.ceil((20 + 21) // 20 - 1)

In [None]:
model.to('cpu')

In [None]:
x.keys()

In [None]:
preds, features, loss = model(x)

In [None]:
preds['logits'][:, frame_block_size: frame_block_size + x["pad_prev"]].shape

In [None]:
x['dt'].shape

In [None]:
x['id'].shape

In [None]:
y['id'].shape

In [None]:
y['dt'].shape

In [None]:
x['id'].shape

In [None]:
stoi_dt

In [None]:
x['dt']

In [None]:
y['dt']

In [None]:
xx = x['dt'].flatten().tolist()

In [None]:
ss = [xx[n + 1] - xx[n] for n in range(len(xx) - 1)]

In [None]:
ss

In [None]:
x['dt'].shape

In [None]:
frame_block_size + id_block_size + prev_id_block_size

In [None]:
model(x, y)

In [None]:
x['interval'] = x['interval'].flatten().tolist()

In [None]:
interval = x['interval'][0]
trial = x['interval'][1]
prev_int = interval - 1
prev_int = prev_int if prev_int > 0 else 0  
prev_id_interval = prev_int, interval
data_prev = df[(df['Interval'] >= prev_id_interval[0]) & 
                        (df['Interval'] < prev_id_interval[1]) &
                        (df['Trial'] == trial)]

In [None]:
df.index[(df['Interval'] == x['interval'][0]) & (df['Trial'] == x['interval'][1])]

In [None]:
idx = df.index[(df['Interval'] == x['interval'][0]) & (df['Trial'] == x['interval'][1])].item()
df[(df['Interval'] == x['interval'][0]) & (df['Trial'] == x['interval'][1])]

In [None]:
print(x['interval'])

In [None]:
x['id'], x['id_prev']

In [None]:
df.iloc[idx - 10:idx + 2]

In [None]:
x['frames'].shape

In [None]:
frame.shape

In [None]:
frame = x['frames'][0, :, 10]
frame = frame.transpose(0, -1)
plt.imshow(frame, cmap='gray')

In [None]:
df[df['Interval'] == 0.2]

In [None]:
x['dt_prev']

In [None]:
x['dt']

In [None]:
from model_perceiver import VideoFeaturesExtractor

vid = VideoFeaturesExtractor()
vid(x['frames']).shape

In [None]:
class_weights

In [None]:
for key, value in x.items():
    x[key] = x[key].cuda()
for key, value in y.items():
    y[key] = y[key].cuda()

In [None]:
x['id'].shape

In [None]:
x['frames'].shape

In [None]:
x['id'].shape

In [None]:
t = x['id'].shape[-1]  # sequence length; otherwise `t` is only set by a later cell
x['id'][:, :t - x['pad']]

In [None]:
y['id'][:, :t - x['pad']]

In [None]:
y

In [None]:
# model = model.to('cpu')
model = model.cuda()
preds, features, loss = model(x, y)

In [None]:
loss

In [None]:
preds['logits']

In [None]:
loss

In [None]:
y['id']

In [None]:
preds['logits'].shape

In [None]:
t = x['id'].shape[-1]

In [None]:
pad = x['pad']

In [None]:
for i in range(t - pad):
    print(i)

In [None]:
x = {'id' : 2}
def yes(x):
    """Print ``x['id']``, decrement it by one in place, and return the same dict."""
    key = 'id'
    print(x[key])
    x[key] -= 1
    return x

In [None]:
yes(x)

In [None]:
x['id']

In [None]:
y['id']

In [None]:
x['id'].shape

In [None]:
x['id'][:, 0]

In [None]:
x['id'][:, :21 - x['pad']]
y['id'][:, :21 - x['pad']]

In [None]:
y['id']

In [None]:
x['pad']

In [None]:
x['id'].shape

In [None]:
x['id'][:, 0].shape

In [None]:
y['id']

In [None]:
t = x['id'].shape[-1]

In [None]:
t

In [None]:
t - x['pad']

In [None]:
tt = torch.tensor([512])

torch.cat((x['id'], tt[None, ...]), dim=-1)

In [None]:
model = model.to('cpu')
preds, features, loss = model(x, y)

In [None]:
preds['logits'][0].shape

In [None]:
from utils import predict_raster, predict_time_raster, predict_raster_enc_dec
%matplotlib inline
from utils import set_plot_params
set_plot_params()
# model.load_state_dict(torch.load(parent_path + "code/transformer_vid3/runs/models/10-20-21-18:40-e:9-b:166-l:4-h:4-ne:512-higher_order.pt"))
loader = DataLoader(test_dataset, shuffle=False, pin_memory=False,
                                  batch_size=1, num_workers=4)
# device = torch.cuda.current_device()
# model = model.to(device)

# true, predicted, true_timing, predicted_timing = predict_time_raster(model, loader, frame_block_size, train_dataset.id_block_size)
true, predicted, timing = predict_raster_enc_dec(model, loader, frame_block_size, get_dt=True)

true_df = pd.DataFrame(true.numpy())
predicted_df = pd.DataFrame(predicted.numpy())
print(len(true_df[true_df[0] < 512]), len(predicted_df[predicted_df[0] < 512])) 

In [None]:
from utils import set_plot_params
set_plot_params()

plt.rcParams['xtick.labelsize'] = 45
plt.rcParams['ytick.labelsize'] = 45
plt.rcParams['axes.labelsize'] = 45
plt.rcParams['figure.titlesize'] = 1000
plt.rcParams['axes.labelpad'] = 17

len_pred = len(true) # len(true)
plt.figure(figsize=(40,40))
plt.title('Pixel / Spike Raster', size=50)
plt.xlabel('Time')
plt.ylabel('Neuron ID')
plt.scatter(np.arange(len_pred), true, alpha=0.6, label='true', marker='o') # true[len_pred:2 * len_pred], alpha=0.6, label='true', marker='o')
plt.scatter(np.arange(len_pred), predicted, alpha=0.6, label='predicted', marker='x') # predicted[len_pred: 2 * len_pred], alpha=0.6, label='predicted', marker='x')
plt.legend()

In [None]:
# NOTE(review): `id_` is ignored — this keeps every token <= 512 and maps larger
# tokens to None, so the "Neuron ID {idn}" plot shows all tokens, not only `idn`.
# If a single neuron was intended, compare against `id_` instead — TODO confirm.
get_id = lambda data, id_: np.where(data <= 512, data, None)

idn = 174
id_true = get_id(true, idn)
id_predicted = get_id(predicted, idn)
len_pred = len(true)
plt.figure(figsize=(20,20))
plt.rcParams['xtick.labelsize'] = 10
plt.rcParams['ytick.labelsize'] = 10
plt.title(f'Neuron ID {idn}', size=20)
plt.xlabel('Time', size=20)
plt.ylabel('Response', size=20)
plt.scatter(np.arange(len(id_true[:len_pred])), id_true[:len_pred], alpha=0.7, label='true', s=75)
plt.scatter(np.arange(len(id_predicted[:len_pred])), id_predicted[:len_pred], alpha=0.6, label='predicted', marker='x', s=75)

In [None]:
def build_time_seq(time_list):
    """Forward-fill zero entries of a timing sequence with the last non-zero time.

    Each non-zero value is kept as-is and becomes the current time. Each zero
    is replaced by the most recent non-zero value seen so far (0 if none yet).
    NOTE(review): the original never updated ``current_time``, so zeros stayed
    0. This assumes forward-fill was the intent; confirm.

    Args:
        time_list: iterable of timing values.

    Returns:
        list with the same length as ``time_list``.
    """
    times = []
    current_time = 0
    for dt in time_list:
        if dt == 0:
            dt = current_time
        else:
            current_time = dt
        times.append(dt)
    return times

predicted_time = build_time_seq(predicted_timing)
true_time = build_time_seq(true_timing)

In [None]:
from utils import set_plot_params
set_plot_params()

len_pred = 2000 # len(true)
plt.figure(figsize=(40,40))
plt.title('Pixel / Spike Raster', size=50)
plt.xlabel('Time')
plt.ylabel('Neuron ID')
plt.scatter(true_time[:len_pred], true[:len_pred], alpha=0.6, label='true', marker='o')
plt.scatter(predicted_time[:len_pred], predicted[:len_pred], alpha=0.6, label='predicted', marker='x')
plt.legend()

In [None]:
from utils import set_plot_params
set_plot_params()

len_pred = 2000 # len(true)
plt.figure(figsize=(40,40))
plt.title('Pixel / Spike Raster', size=50)
plt.xlabel('Time')
plt.ylabel('Neuron ID')
plt.scatter(np.arange(len_pred), true_time[:len_pred], alpha=0.6, label='true', marker='o')
plt.scatter(np.arange(len_pred), predicted_time[:len_pred], alpha=0.6, label='predicted', marker='x')
plt.legend()

In [None]:
loader = DataLoader(train_dataset, shuffle=False, pin_memory=False,
                                  batch_size=1, num_workers=1)
iterable = iter(loader)

In [None]:
import torch.nn.functional as F 

x, y = next(iterable)
T = train_dataset.id_block_size
frame_end = 0
logits, features, _ = model(x)
PAD = x['pad']
logits = logits[:, frame_end:T - PAD, :]    # get last unpadded token (-x['pad'])
# take logits of final step and apply softmax
probs = F.softmax(logits, dim=-1)
# choose highest topk (1) sample
_, ix = torch.topk(probs, k=1, dim=-1)

In [None]:
decoder.generate_padding_mask(x['pad'])

In [None]:
decoder = Decoder(mconf)
decoder(model.tok_emb(x['id']), x['frames'], x['pad'])

In [None]:
logits

In [None]:
x['pad']

In [None]:
len(ix.flatten())

In [None]:
ix.flatten()

In [None]:
x['id']

In [None]:
x['pad']

In [None]:
y[:, frame_end:T - x['pad']].flatten()

In [None]:
logits.shape

In [None]:
# torch.save(model.state_dict(), 'model_under1.pt')

In [None]:
# NOTE(review): `id_` is ignored — this keeps every token <= 512 and maps larger
# tokens to None, so the "Neuron ID {idn}" plot shows all tokens, not only `idn`.
# If a single neuron was intended, compare against `id_` instead — TODO confirm.
get_id = lambda data, id_: np.where(data <= 512, data, None)

idn = 174
id_true = get_id(true, idn)
id_predicted = get_id(predicted, idn)
len_pred = len(true)
plt.figure(figsize=(20,20))
plt.rcParams['xtick.labelsize'] = 10
plt.rcParams['ytick.labelsize'] = 10
plt.title(f'Neuron ID {idn}', size=20)
plt.xlabel('Time', size=20)
plt.ylabel('Response', size=20)
plt.scatter(np.arange(len(id_true[:len_pred])), id_true[:len_pred], alpha=0.4, label='true', s=75)
plt.scatter(np.arange(len(id_predicted[:len_pred])), id_predicted[:len_pred], alpha=0.4, label='predicted', marker='x', s=75)

In [None]:
from SpikeVidUtils import SpikeTimeVidData

# train_dat1aset = spikeTim/eData(spikes, block_size, dt, stoi, itos)

train_dataset = SpikeTimeVidData(train_data, frames,  block_size, frame_block_size, prev_id_block_size, window, frame_memory, stoi, itos, frame_feats)
test_dataset = SpikeTimeVidData(test_data, frames, block_size, frame_block_size, prev_id_block_size, window, frame_memory, stoi, itos, frame_feats)
# dataset = SpikeTimeVidData(df, frames, frame_feats, block_size, frame_block_size, prev_id_block_size, window, frame_memory, stoi, itos)
single_batch = SpikeTimeVidData(df[df['Trial'].isin([5])], frames, block_size, frame_block_size, prev_id_block_size, window, frame_memory, stoi, itos, frame_feats)


print(f'train: {len(train_dataset)}, test: {len(test_dataset)}')

In [None]:
loader = DataLoader(train_dataset, shuffle=False, pin_memory=False,
                                  batch_size=2, num_workers=1)

In [None]:
iterable = iter(loader)

In [None]:
x, y = next(iterable)

In [None]:
y.shape

In [None]:
logits, features, _ = model(x)

In [None]:
logits.shape

In [None]:
model(x).shape

In [None]:
x['frames']

In [None]:
x['id'].shape

In [None]:
y.shape

In [None]:
xy = torch.rand(2, 100, 1)

In [None]:
xy.squeeze(-1).shape

In [None]:
model.to('cpu')
logits, features, loss = model(x, y)

In [None]:
loss

In [None]:
logits[1]

In [None]:
x['pad']

In [None]:
x['id'].shape

In [None]:
x['frames'].shape

In [None]:
block_size = 4

In [None]:
mask = torch.tril(torch.ones((block_size, block_size))
                                     ).view(1, 1, block_size, block_size)

In [None]:
mask

In [None]:
mask[:, :, :, 3:] = 1

In [None]:
mask

In [None]:
model = model.to('cpu')

In [None]:
x, y = model(x, y)

In [None]:
# len(df[(df['Interval'] == 238.5) & (df['Trial'] == 0)])

In [None]:
# interval_prev = 238.5 - window*5
# data_prev = df[(df['Interval'] > 3) & 
# (df['Interval'] < 6)]

In [None]:
# data_prev

In [None]:
decoder_layer = nn.TransformerDecoderLayer(d_model=512, nhead=8)
transformer_decoder = nn.TransformerDecoder(decoder_layer, num_layers=6)
memory = torch.rand(10, 32, 512)
tgt = torch.rand(20, 32, 512)
out = transformer_decoder(tgt, memory)
