In [None]:
from nlb_tools.nwb_interface import NWBDataset
from nlb_tools.make_tensors import (
    make_train_input_tensors,
    make_eval_input_tensors,
    make_eval_target_tensors,
)
from nlb_tools.evaluation import evaluate

import numpy as np
import pandas as pd
import h5py
import logging
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from LRU_pytorch import LRU

# Setup: INFO-level logging surfaces nlb_tools' loading/resampling messages
logging.basicConfig(level=logging.INFO)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load dataset
# NOTE(review): the loading log shows DANDI 000128 'ses-full' NWB files, which appears to be the
# full MC_Maze release rather than MC_Maze_Small — confirm dataset_name matches the data on disk.
dataset_name = 'mc_maze_small'
datapath = '~/data/foundational_ssm/motor/raw/000128/sub-Jenkins/'
dataset = NWBDataset(datapath)

# Phase and binning
phase = 'val'  # 'val' or 'test'
# Fail fast: any other value would silently be treated like 'test' when choosing trial splits
if phase not in ('val', 'test'):
    raise ValueError(f"phase must be 'val' or 'test', got {phase!r}")
bin_width = 5
dataset.resample(bin_width)
# Result keys get a bin-width suffix unless the default 5 ms bins are used
suffix = '' if bin_width == 5 else f'_{int(bin_width)}'

# Prepare training data: the 'test' phase also trains on the val trials
train_split = ['train', 'val'] if phase != 'val' else 'train'
train_dict = make_train_input_tensors(
    dataset,
    dataset_name=dataset_name,
    trial_split=train_split,
    save_file=False,
)
train_spikes_heldin, train_spikes_heldout = (
    train_dict['train_spikes_heldin'],
    train_dict['train_spikes_heldout'],
)
print("Train held-in shape:", train_spikes_heldin.shape)

# Prepare evaluation data for the trials of the current phase
eval_dict = make_eval_input_tensors(
    dataset,
    dataset_name=dataset_name,
    trial_split=phase,
    save_file=False,
)
eval_spikes_heldin = eval_dict['eval_spikes_heldin']
print("Eval dict keys:", eval_dict.keys())
print("Eval held-in shape:", eval_spikes_heldin.shape)

# Prepare targets for the 'val' split. Only meaningful in the 'val' phase: in the 'test' phase the
# val trials are part of the training split, so scoring against them would not be an evaluation.
if phase == 'val':
    target_dict = make_eval_target_tensors(
        dataset,
        dataset_name=dataset_name,
        train_trial_split='train',
        eval_trial_split='val',
        # The RTT task is not well suited to trial averaging, so PSTHs are not made for it
        include_psth=('mc_rtt' not in dataset_name),
        save_file=False,
    )

INFO:nlb_tools.nwb_interface:Loading /nfs/ghome/live/mlaimon/data/foundational_ssm/motor/raw/000128/sub-Jenkins/sub-Jenkins_ses-full_desc-test_ecephys.nwb


  warn("Ignoring cached namespace '%s' version %s because version %s is already loaded."
  warn("Ignoring cached namespace '%s' version %s because version %s is already loaded."
  warn("Ignoring cached namespace '%s' version %s because version %s is already loaded."
INFO:nlb_tools.nwb_interface:Loading /nfs/ghome/live/mlaimon/data/foundational_ssm/motor/raw/000128/sub-Jenkins/sub-Jenkins_ses-full_desc-train_behavior+ecephys.nwb
INFO:nlb_tools.nwb_interface:Resampling data to 5 ms.


In [None]:
from sklearn.linear_model import PoissonRegressor

def fit_poisson(train_input, eval_input, train_output, alpha=0.0, max_iter=500):
    """Fit one Poisson GLM per output column and predict on the train and eval inputs.

    Parameters
    ----------
    train_input : 2-D array, (n_train_samples, n_features)
        Regressors used for fitting and for the train predictions.
    eval_input : 2-D array, (n_eval_samples, n_features)
        Regressors for the eval predictions; same feature columns as ``train_input``.
    train_output : 2-D array, (n_train_samples, n_outputs)
        Targets; an independent ``PoissonRegressor`` is fit to each column.
    alpha : float
        Regularization strength passed to ``PoissonRegressor``.
    max_iter : int
        Maximum solver iterations for each regressor.

    Returns
    -------
    train_pred : ndarray, (n_train_samples, n_outputs)
    eval_pred : ndarray, (n_eval_samples, n_outputs)
        Predicted means; with zero output columns these are empty
        (instead of raising, as stacking an empty list would).
    """
    n_outputs = train_output.shape[1]
    # Preallocate so each channel's predictions land directly in their column
    train_pred = np.empty((np.shape(train_input)[0], n_outputs))
    eval_pred = np.empty((np.shape(eval_input)[0], n_outputs))
    for chan in range(n_outputs):
        pr = PoissonRegressor(alpha=alpha, max_iter=max_iter)
        pr.fit(train_input, train_output[:, chan])
        train_pred[:, chan] = pr.predict(train_input)
        eval_pred[:, chan] = pr.predict(eval_input)
    return train_pred, eval_pred

In [None]:
# Assign useful variables (spike arrays are indexed (trial, time bin, neuron), per the reshapes below)
tlength = train_spikes_heldin.shape[1]
num_train = train_spikes_heldin.shape[0]
num_eval = eval_spikes_heldin.shape[0]
num_heldin = train_spikes_heldin.shape[2]
num_heldout = train_spikes_heldout.shape[2]

# Smooth spikes with 40 ms std gaussian
import scipy.signal as signal
kern_sd_ms = 40
# Kernel std in bins; clamp to >= 1 so very coarse bin widths don't produce an empty window
kern_sd = max(1, int(round(kern_sd_ms / dataset.bin_width)))
# scipy.signal.gaussian was deprecated and later removed from SciPy; the window lives in
# scipy.signal.windows
window = signal.windows.gaussian(kern_sd * 6, kern_sd, sym=True)
window /= np.sum(window)  # unit area, so smoothing preserves total spike count


def filt(x):
    """Convolve a 1-D spike train with the Gaussian window ('same' mode).

    Output length is max(len(x), len(window)), i.e. len(x) whenever the trial
    is longer than the window.
    """
    return np.convolve(x, window, 'same')


# Smooth along axis 1 (time), independently for every trial and neuron
train_spksmth_heldin = np.apply_along_axis(filt, 1, train_spikes_heldin)
eval_spksmth_heldin = np.apply_along_axis(filt, 1, eval_spikes_heldin)

## Generate rate predictions

# Collapse trial and time axes: each time bin becomes one regression sample, neurons are columns
train_spksmth_heldin_s = train_spksmth_heldin.reshape(-1, train_spksmth_heldin.shape[2])
eval_spksmth_heldin_s = eval_spksmth_heldin.reshape(-1, eval_spksmth_heldin.shape[2])
train_spikes_heldout_s = train_spikes_heldout.reshape(-1, train_spikes_heldout.shape[2])

# Poisson GLM from log smoothed held-in activity to held-out spike counts.
# The 1e-4 offset keeps log() finite where the smoothed activity is exactly zero.
train_spksmth_heldout_s, eval_spksmth_heldout_s = fit_poisson(
    train_input=np.log(train_spksmth_heldin_s + 1e-4),
    eval_input=np.log(eval_spksmth_heldin_s + 1e-4),
    train_output=train_spikes_heldout_s,
    alpha=0.1,
)

# Restore the 3-D (trial, time, neuron) layout of the input arrays
train_rates_heldin = train_spksmth_heldin_s.reshape(num_train, tlength, num_heldin)
train_rates_heldout = train_spksmth_heldout_s.reshape(num_train, tlength, num_heldout)
eval_rates_heldin = eval_spksmth_heldin_s.reshape(num_eval, tlength, num_heldin)
eval_rates_heldout = eval_spksmth_heldout_s.reshape(num_eval, tlength, num_heldout)

## Prepare submission data

# Predictions are nested under the same dataset_name + suffix key that target_dict uses
output_dict = {
    dataset_name + suffix: dict(
        train_rates_heldin=train_rates_heldin,
        train_rates_heldout=train_rates_heldout,
        eval_rates_heldin=eval_rates_heldin,
        eval_rates_heldout=eval_rates_heldout,
    )
}

## Make data to evaluate predictions with

# Raise the root logging level to hide excessive INFO messages while building targets;
# try/finally guarantees it is restored even if target construction fails.
logging.getLogger().setLevel(logging.WARNING)
try:
    # If 'val' phase, make the target data
    if phase == 'val':
        target_dict = make_eval_target_tensors(
            dataset,
            dataset_name=dataset_name,
            train_trial_split='train',
            eval_trial_split='val',
            # The RTT task is not well suited to trial averaging, so PSTHs are not made for it
            include_psth=('mc_rtt' not in dataset_name),
            save_file=False,
        )

        # Demonstrate target_dict structure
        print(target_dict.keys())
        print(target_dict[dataset_name + suffix].keys())
finally:
    # Set logging level again
    logging.getLogger().setLevel(logging.INFO)

if phase == 'val':
    print(evaluate(target_dict, output_dict))