In [1]:
# Standard library
import math
import os
from datetime import datetime

# Third-party
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import torch
import torch.nn as nn
import torch.utils
import torch.utils.data
# from torchvision import datasets, transforms
# from torch.autograd import Variable

# NLB tooling
from nlb_tools.nwb_interface import NWBDataset
from nlb_tools.make_tensors import make_train_input_tensors, make_eval_input_tensors

In [2]:

"""implementation of the Variational Recurrent
Neural Network (VRNN) from https://arxiv.org/abs/1506.02216
using unimodal isotropic gaussian distributions for 
inference, prior, and generating models."""

# Pick the accelerator when one is visible to torch, otherwise stay on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# Machine epsilon of float32; added inside logarithms to keep them finite.
EPS = torch.finfo(torch.float32).eps

class VRNN(nn.Module):
    """Variational Recurrent Neural Network (https://arxiv.org/abs/1506.02216).

    At every step the model
      * encodes the current input together with the top-layer GRU state into a
        Gaussian posterior over the latent ``z``,
      * computes a Gaussian prior over ``z`` from the GRU state alone,
      * decodes a posterior sample into per-channel probabilities (sigmoid),
      * advances the GRU with the input and latent features, and
      * maps the GRU output through ``transform`` and ``exp`` to obtain
        strictly positive values of size ``out_dim``.

    Args:
        x_dim: number of input features per step.
        h_dim: width of the feature extractors and of the GRU hidden state.
        z_dim: size of the latent variable.
        out_dim: size of the exponentiated read-out.
        n_layers: number of stacked GRU layers.
        dropout: dropout probability applied to the raw input and to the GRU output.
        bias: whether the GRU uses bias terms.
    """

    def __init__(self, x_dim, h_dim, z_dim, out_dim, n_layers, dropout=0.3, bias=False):
        super(VRNN, self).__init__()

        self.x_dim = x_dim
        self.h_dim = h_dim
        self.z_dim = z_dim
        self.out_dim = out_dim
        self.n_layers = n_layers

        # NOTE: sub-modules are created in the same order as before so that
        # state_dict keys (and seeded initialisation) stay compatible.
        self.dropout1 = torch.nn.Dropout(p=dropout)
        self.dropout2 = torch.nn.Dropout(p=dropout)

        # feature-extracting transformations
        self.phi_x = nn.Sequential(
            nn.Linear(x_dim, h_dim),
            nn.ReLU(),
            nn.Linear(h_dim, h_dim),
            nn.ReLU())
        self.phi_z = nn.Sequential(
            nn.Linear(z_dim, h_dim),
            nn.ReLU())

        # encoder (approximate posterior over z)
        self.enc = nn.Sequential(
            nn.Linear(h_dim + h_dim, h_dim),
            nn.ReLU(),
            nn.Linear(h_dim, h_dim),
            nn.ReLU())
        self.enc_mean = nn.Linear(h_dim, z_dim)
        self.enc_std = nn.Sequential(
            nn.Linear(h_dim, z_dim),
            nn.Softplus())

        # prior over z, conditioned on the recurrent state only
        self.prior = nn.Sequential(
            nn.Linear(h_dim, h_dim),
            nn.ReLU())
        self.prior_mean = nn.Linear(h_dim, z_dim)
        self.prior_std = nn.Sequential(
            nn.Linear(h_dim, z_dim),
            nn.Softplus())

        # decoder
        self.dec = nn.Sequential(
            nn.Linear(h_dim + h_dim, h_dim),
            nn.ReLU(),
            nn.Linear(h_dim, h_dim),
            nn.ReLU())
        # dec_std is not used by forward() or sample() (the reconstruction term
        # is Bernoulli); the layer is kept so existing checkpoints still load.
        self.dec_std = nn.Sequential(
            nn.Linear(h_dim, x_dim),
            nn.Softplus())
        self.dec_mean = nn.Sequential(
            nn.Linear(h_dim, x_dim),
            nn.Sigmoid())

        # recurrence and read-out
        self.rnn = nn.GRU(h_dim + h_dim, h_dim, num_layers=n_layers, bias=bias)
        self.transform = torch.nn.Linear(h_dim, out_dim)

    def forward(self, x):
        """Run the model over a sequence and accumulate the variational losses.

        Args:
            x: tensor indexed as ``(step, batch, x_dim)``: the loop runs over
                dim 0 and dim 1 is used as the batch size of the GRU state.
                NOTE(review): the callers in this notebook appear to pass
                ``(trials, time, channels)`` tensors without transposing, which
                makes the recurrence run across trials - confirm the intended
                layout.

        Returns:
            Tuple ``(rates, kld, nll)`` where ``rates`` has shape
            ``(steps, batch, out_dim)`` and is strictly positive, and ``kld`` /
            ``nll`` are the KL and Bernoulli reconstruction terms averaged over
            steps.

        Raises:
            ValueError: if ``x`` has no steps along dim 0.
        """
        n_steps = x.size(0)
        if n_steps == 0:
            raise ValueError("VRNN.forward needs at least one step along dim 0")

        kld_loss = 0
        nll_loss = 0
        output_list = []

        # Allocate the initial state next to the input instead of on a
        # module-level device, so the model works wherever it has been moved.
        h = x.new_zeros(self.n_layers, x.size(1), self.h_dim)
        for t in range(n_steps):
            phi_x_t = self.phi_x(self.dropout1(x[t]))

            # encoder
            enc_t = self.enc(torch.cat([phi_x_t, h[-1]], 1))
            enc_mean_t = self.enc_mean(enc_t)
            enc_std_t = self.enc_std(enc_t)

            # prior
            prior_t = self.prior(h[-1])
            prior_mean_t = self.prior_mean(prior_t)
            prior_std_t = self.prior_std(prior_t)

            # sampling and reparameterization
            z_t = self._reparameterized_sample(enc_mean_t, enc_std_t)
            phi_z_t = self.phi_z(z_t)

            # decoder (uses the state from before this step's GRU update)
            dec_t = self.dec(torch.cat([phi_z_t, h[-1]], 1))
            dec_mean_t = self.dec_mean(dec_t)

            # recurrence
            output, h = self.rnn(torch.cat([phi_x_t, phi_z_t], 1).unsqueeze(0), h)
            output = self.transform(self.dropout2(output))

            # losses
            kld_loss += self._kld_gauss(enc_mean_t, enc_std_t, prior_mean_t, prior_std_t)
            nll_loss += self._nll_bernoulli(dec_mean_t, x[t])

            # exp keeps the read-out positive
            output_list.append(torch.exp(output))

        return torch.cat(output_list), kld_loss / n_steps, nll_loss / n_steps

    def sample(self, seq_len):
        """Generate ``seq_len`` steps from the prior for a single sequence.

        Returns:
            Tensor of shape ``(seq_len, x_dim)`` holding the decoder
            probabilities of every step; it carries no autograd history.
        """
        ref = next(self.parameters())
        with torch.no_grad():
            sample = torch.zeros(seq_len, self.x_dim, device=ref.device, dtype=ref.dtype)
            h = torch.zeros(self.n_layers, 1, self.h_dim, device=ref.device, dtype=ref.dtype)
            for t in range(seq_len):
                # prior
                prior_t = self.prior(h[-1])
                prior_mean_t = self.prior_mean(prior_t)
                prior_std_t = self.prior_std(prior_t)

                # sampling and reparameterization
                z_t = self._reparameterized_sample(prior_mean_t, prior_std_t)
                phi_z_t = self.phi_z(z_t)

                # decoder
                dec_t = self.dec(torch.cat([phi_z_t, h[-1]], 1))
                dec_mean_t = self.dec_mean(dec_t)

                # the generated step is fed back as the next input
                phi_x_t = self.phi_x(dec_mean_t)

                # recurrence
                _, h = self.rnn(torch.cat([phi_x_t, phi_z_t], 1).unsqueeze(0), h)

                sample[t] = dec_mean_t[0]

        return sample

    def reset_parameters(self, stdv=1e-1):
        """Re-initialise every parameter from a normal distribution N(0, stdv**2)."""
        with torch.no_grad():
            for weight in self.parameters():
                weight.normal_(0, stdv)

    def _init_weights(self, stdv):
        """Placeholder kept for interface compatibility; does nothing."""
        pass

    @staticmethod
    def _eps(tensor):
        """Machine epsilon of ``tensor``'s floating dtype (float32 in this notebook)."""
        return torch.finfo(tensor.dtype).eps

    def _reparameterized_sample(self, mean, std):
        """Draw ``mean + std * noise`` with standard-normal noise (reparameterisation)."""
        # randn_like keeps the noise on the same device / dtype as ``std``.
        return mean + std * torch.randn_like(std)

    def _kld_gauss(self, mean_1, std_1, mean_2, std_2):
        """Mean element-wise KL( N(mean_1, std_1^2) || N(mean_2, std_2^2) )."""
        eps = self._eps(std_1)
        kld_element = (2 * torch.log(std_2 + eps) - 2 * torch.log(std_1 + eps) +
                       (std_1.pow(2) + (mean_1 - mean_2).pow(2)) /
                       std_2.pow(2) - 1)
        return 0.5 * torch.mean(kld_element)

    def _nll_bernoulli(self, theta, x):
        """Mean Bernoulli negative log-likelihood of ``x`` under probabilities ``theta``."""
        eps = self._eps(theta)
        # Clamp instead of shifting: with theta == 1.0 (a saturated sigmoid)
        # the former log(1 - theta - EPS) was log of a negative number -> NaN.
        theta = theta.clamp(min=eps, max=1 - eps)
        return -torch.mean(x * torch.log(theta) + (1 - x) * torch.log1p(-theta))

    def _nll_gauss(self, mean, std, x):
        """Mean Gaussian negative log-likelihood of ``x`` under N(mean, std^2)."""
        eps = self._eps(std)
        # math.log, not torch.log: 2*pi is a Python float, not a tensor.
        return torch.mean(torch.log(std + eps) + 0.5 * math.log(2 * math.pi)
                          + (x - mean).pow(2) / (2 * std.pow(2)))

In [3]:
class NLBRunner:
    """Full-batch trainer for a model that predicts firing rates.

    The wrapped model's forward pass must return
    ``(predictions, kld_loss, nll_loss)``; the training objective is the
    Poisson NLL of the predictions plus those two terms.
    """

    def __init__(self, model_init, model_cfg, data, train_cfg, use_gpu=False, num_gpus=1):
        """
        Args:
            model_init: callable that builds the model; called as ``model_init(**model_cfg)``.
            model_cfg: keyword arguments for ``model_init``.
            data: sequence of tensors whose first four entries are
                ``(train_input, train_output, val_input, val_output)``.
            train_cfg: dict with optional keys ``'lr'`` (default 1e-3),
                ``'alpha'`` (Adam weight decay, default 0.0) and
                ``'cd_ratio'`` (default 0.2).
            use_gpu: move model and data to ``cuda:0`` when CUDA is available.
            num_gpus: accepted for interface compatibility; not used here.
        """
        self.model = model_init(**model_cfg)
        self.data = data
        if use_gpu and torch.cuda.is_available():
            device = torch.device('cuda:0')
            self.model.to(device)
            self.data = tuple(d.to(device) for d in self.data)
        self.cd_ratio = train_cfg.get('cd_ratio', 0.2)
        self.optimizer = torch.optim.Adam(self.model.parameters(),
                                          lr=train_cfg.get('lr', 1e-3),
                                          weight_decay=train_cfg.get('alpha', 0.0))

    def make_cd_mask(self, train_input, train_output):
        """Creates boolean masks for coordinated dropout.

        In coordinated dropout, a random set of inputs is zeroed out,
        and only the corresponding outputs (i.e. same trial, timestep, and neuron)
        are used to compute loss and update model weights. This prevents
        exact spike times from being directly passed through the model.

        Returns:
            ``(input_mask, output_mask)``: ``input_mask`` is True where an input
            entry is dropped; ``output_mask`` is True where an output entry
            contributes to the loss (dropped inputs plus every output channel
            beyond the input channels). Each mask lives on the device of the
            tensor it indexes.
        """
        n_elements = train_input.numel()
        n_masked = int(round(self.cd_ratio * n_elements))
        # Drawn on the CPU so the selection follows the CPU random generator.
        flat_mask = torch.zeros(n_elements, dtype=torch.bool)
        flat_mask[torch.randperm(n_elements)[:n_masked]] = True
        input_mask = flat_mask.view(train_input.shape)
        output_mask = torch.ones(train_output.shape, dtype=torch.bool)
        output_mask[:, :, :input_mask.shape[2]] = input_mask
        return input_mask.to(train_input.device), output_mask.to(train_output.device)

    def train_epoch(self):
        """Trains model for one epoch.

        This simple script does not support splitting training samples into batches.

        Returns:
            ``(res, (train_predictions, val_predictions))`` where ``res`` merges
            the train and validation metrics produced by ``score``.
        """
        self.model.train()
        self.optimizer.zero_grad()
        train_input, train_output, val_input, val_output, *_ = self.data
        # create mask for coordinated dropout and zero the selected inputs
        input_mask, output_mask = self.make_cd_mask(train_input, train_output)
        masked_train_input = train_input.clone()
        masked_train_input[input_mask] = 0.0
        # NOTE(review): tensors are passed to the model in the layout they are
        # stored in; if the model iterates over dim 0 as time (as VRNN in this
        # notebook does) and the data is (trials, time, channels), a transpose
        # is needed here and on the predictions - confirm the intended layout.
        train_predictions, kld_loss, nll_loss = self.model(masked_train_input)

        # learn only from masked inputs (and the extra output channels)
        p_loss = torch.nn.functional.poisson_nll_loss(
            train_predictions[output_mask], train_output[output_mask], log_input=False)
        loss = p_loss + kld_loss + nll_loss
        loss.backward()
        self.optimizer.step()

        # metrics on the unmasked data
        train_res, train_pred = self.score(train_input, train_output, prefix='train')
        val_res, val_pred = self.score(val_input, val_output, prefix='val')
        res = train_res.copy()
        res.update(val_res)
        return res, (train_pred, val_pred)

    def score(self, input, output, prefix='val'):
        """Evaluates model performance on given data.

        Returns:
            ``(metrics, predictions)``. ``metrics`` holds ``'<prefix>_nll'``
            (Poisson NLL plus the model's KL and reconstruction terms) and
            ``'<prefix>_cosmooth_nll'`` (Poisson NLL restricted to the output
            channels beyond the input channels; NaN when there are none).
            ``predictions`` carries no autograd history.
        """
        was_training = self.model.training
        self.model.eval()
        try:
            # Scoring never back-propagates, so skip building the graph.
            with torch.no_grad():
                predictions, pred_kld, pred_nll = self.model(input)
        finally:
            # Restore whichever mode the caller had, not unconditionally train().
            self.model.train(was_training)
        loss = torch.nn.functional.poisson_nll_loss(predictions, output, log_input=False) + pred_kld + pred_nll
        num_heldin = input.shape[2]
        if output.shape[2] > num_heldin:
            # Index from the front: a "-0:" slice would select every channel.
            cosmooth = torch.nn.functional.poisson_nll_loss(
                predictions[:, :, num_heldin:], output[:, :, num_heldin:], log_input=False).item()
        else:
            cosmooth = float('nan')
        return {f'{prefix}_nll': loss.item(), f'{prefix}_cosmooth_nll': cosmooth}, predictions

    def train(self, n_iter=1000, patience=200, save_path=None, verbose=False, log_frequency=50):
        """Trains model for given number of iterations with early stopping.

        A checkpoint is written to ``save_path`` (when given) every time the
        validation NLL improves; training stops once more than ``patience``
        iterations pass without improvement.

        Returns:
            List with one metrics dict per completed iteration.
        """
        train_log = []
        best_score = float('inf')
        last_improv = -1
        for i in range(n_iter):
            res, _ = self.train_epoch()
            res['iter'] = i
            train_log.append(res)
            if verbose and (i % log_frequency) == 0:
                print(res)
            if res['val_nll'] < best_score:
                best_score = res['val_nll']
                last_improv = i
                if save_path is not None:
                    self.save_checkpoint(save_path, res.copy())
            if (i - last_improv) > patience:
                break
        return train_log

    def save_checkpoint(self, file_path, data):
        """Saves model and optimizer state, plus the entries of ``data``, to ``file_path``."""
        default_ckpt = {
            "state_dict": self.model.state_dict(),
            "optim_state": self.optimizer.state_dict(),
        }
        assert "state_dict" not in data, "'state_dict' is reserved for the model weights"
        assert "optim_state" not in data, "'optim_state' is reserved for the optimizer state"
        default_ckpt.update(data)
        torch.save(default_ckpt, file_path)

In [4]:
def get_data(dataset_name, phase='test', bin_size=5,
             data_path='/scratch/gilbreth/akamsali/Research/Makin/000138/sub-Jenkins'):
    """Function that extracts and formats data for training model.

    Args:
        dataset_name: dataset identifier forwarded to the nlb_tools tensor builders.
        phase: ``'test'`` trains on the ``train`` + ``val`` splits, anything
            else trains on ``train`` only; the value is also used as the
            evaluation split.
        bin_size: value passed to ``dataset.resample`` (previously ignored in
            favour of a hard-coded 5).
        data_path: location handed to ``NWBDataset``; defaults to the path that
            used to be hard-coded here.

    Returns:
        ``(dataset, train_dict, eval_dict, training_input, training_output, eval_input)``.
        ``training_input`` / ``eval_input`` are the held-in spikes extended
        along axis 1 with zeros covering the forward-prediction window;
        ``training_output`` stacks held-in and held-out spikes (observed and
        forward windows) along axis 2.
    """
    dataset = NWBDataset(data_path,
        skip_fields=['cursor_pos', 'eye_pos', 'cursor_vel', 'eye_vel', 'hand_pos'])
    dataset.resample(bin_size)

    train_split = ['train', 'val'] if phase == 'test' else 'train'
    eval_split = phase
    train_dict = make_train_input_tensors(dataset, dataset_name, train_split, save_file=False, include_forward_pred=True)
    eval_dict = make_eval_input_tensors(dataset, dataset_name, eval_split, save_file=False)

    heldin = train_dict['train_spikes_heldin']
    heldin_forward = train_dict['train_spikes_heldin_forward']
    eval_heldin = eval_dict['eval_spikes_heldin']

    # The model sees zeros where it has to predict forward in time.
    training_input = np.concatenate([heldin, np.zeros(heldin_forward.shape)], axis=1)
    training_output = np.concatenate([
        np.concatenate([heldin, heldin_forward], axis=1),
        np.concatenate([
            train_dict['train_spikes_heldout'],
            train_dict['train_spikes_heldout_forward'],
        ], axis=1),
    ], axis=2)
    eval_input = np.concatenate([
        eval_heldin,
        np.zeros((eval_heldin.shape[0], heldin_forward.shape[1], eval_heldin.shape[2])),
    ], axis=1)
    return dataset, train_dict, eval_dict, training_input, training_output, eval_input

In [5]:
# Which dataset / split to work on, and the bin size passed to get_data
dataset_name = 'mc_maze_large'
phase = 'val'
bin_size = 5

# Load and format the arrays
(dataset, train_dict, eval_dict,
 training_input, training_output, eval_input) = get_data(dataset_name, phase, bin_size)

# First 75% of the trials are used for training, the remainder for validation
num_train = int(round(0.75 * training_input.shape[0]))

# Convert to Torch tensors (torch.Tensor yields the default float dtype)
train_input, val_input = (
    torch.Tensor(part) for part in (training_input[:num_train], training_input[num_train:])
)
train_output, val_output = (
    torch.Tensor(part) for part in (training_output[:num_train], training_output[num_train:])
)
eval_input = torch.Tensor(eval_input)

In [6]:
# Hyperparameters
DROPOUT = 0
L2_WEIGHT = 5e-7
LR_INIT = 1.5e-2
CD_RATIO = 0.27
HIDDEN_DIM = 40
USE_GPU = True
MAX_GPUS = 1
SEED = 0

# Weight initialisation, coordinated-dropout masks and latent sampling are all
# random; seed torch so a Restart & Run All gives comparable numbers.
torch.manual_seed(SEED)

RUN_NAME = datetime.now().strftime("%Y-%m-%d_%H-%M-%S") + '_model'
RUN_DIR = './runs/'
# makedirs(exist_ok=True) keeps this cell re-runnable
os.makedirs(RUN_DIR, exist_ok=True)

# Build the runner (training itself happens in the next cell)
runner = NLBRunner(
    model_init=VRNN,
    model_cfg={'x_dim': train_input.shape[2], 'h_dim': HIDDEN_DIM, 'z_dim': HIDDEN_DIM, 'out_dim': train_output.shape[2], 'n_layers': 1, 'dropout': DROPOUT},
    data=(train_input, train_output, val_input, val_output, eval_input),
    train_cfg={'lr': LR_INIT, 'alpha': L2_WEIGHT, 'cd_ratio': CD_RATIO},
    use_gpu=USE_GPU,
    num_gpus=MAX_GPUS,
)

In [7]:
model_dir = os.path.join(RUN_DIR, RUN_NAME)
# exist_ok=True: re-running this cell must not fail because the folder exists
os.makedirs(model_dir, exist_ok=True)
train_log = runner.train(n_iter=2500, patience=1000, save_path=os.path.join(model_dir, 'model.ckpt'), verbose=True, log_frequency=1)

# Save results (pandas is already imported in the first cell)
train_log = pd.DataFrame(train_log)
train_log.to_csv(os.path.join(model_dir, 'train_log.csv'))

{'train_nll': 1.6090447902679443, 'train_cosmooth_nll': 0.918481707572937, 'val_nll': 1.6093521118164062, 'val_cosmooth_nll': 0.9192109704017639, 'iter': 0}
{'train_nll': 1.3504079580307007, 'train_cosmooth_nll': 0.6960930228233337, 'val_nll': 1.3534586429595947, 'val_cosmooth_nll': 0.6997567415237427, 'iter': 1}
{'train_nll': 0.9257137179374695, 'train_cosmooth_nll': 0.4490068256855011, 'val_nll': 0.9336225986480713, 'val_cosmooth_nll': 0.45491668581962585, 'iter': 2}
{'train_nll': 0.5137958526611328, 'train_cosmooth_nll': 0.281981498003006, 'val_nll': 0.5222513675689697, 'val_cosmooth_nll': 0.28724053502082825, 'iter': 3}
{'train_nll': 0.3554554581642151, 'train_cosmooth_nll': 0.19213685393333435, 'val_nll': 0.35942143201828003, 'val_cosmooth_nll': 0.19594672322273254, 'iter': 4}
{'train_nll': 0.3436306118965149, 'train_cosmooth_nll': 0.14863835275173187, 'val_nll': 0.34422528743743896, 'val_cosmooth_nll': 0.15109747648239136, 'iter': 5}
{'train_nll': 0.3571227192878723, 'train_cosmo

In [8]:
# Rebuild the architecture with the training configuration and load the best checkpoint
conf = {'x_dim': train_input.shape[2], 'h_dim': HIDDEN_DIM, 'z_dim': HIDDEN_DIM, 'out_dim': train_output.shape[2], 'n_layers': 1, 'dropout': DROPOUT}
saved_model = VRNN(**conf).to(device)
# map_location lets a checkpoint written on the GPU be restored on whatever
# device this session actually has.
ckpt = torch.load(os.path.join(model_dir, 'model.ckpt'), map_location=device)
saved_model.load_state_dict(ckpt['state_dict'])
ckpt['state_dict'].keys()

odict_keys(['phi_x.0.weight', 'phi_x.0.bias', 'phi_x.2.weight', 'phi_x.2.bias', 'phi_z.0.weight', 'phi_z.0.bias', 'enc.0.weight', 'enc.0.bias', 'enc.2.weight', 'enc.2.bias', 'enc_mean.weight', 'enc_mean.bias', 'enc_std.0.weight', 'enc_std.0.bias', 'prior.0.weight', 'prior.0.bias', 'prior_mean.weight', 'prior_mean.bias', 'prior_std.0.weight', 'prior_std.0.bias', 'dec.0.weight', 'dec.0.bias', 'dec.2.weight', 'dec.2.bias', 'dec_std.0.weight', 'dec_std.0.bias', 'dec_mean.0.weight', 'dec_mean.0.bias', 'rnn.weight_ih_l0', 'rnn.weight_hh_l0', 'transform.weight', 'transform.bias'])

In [9]:
saved_model.eval()
# as_tensor accepts both the NumPy arrays from get_data and tensors left over
# from a previous run of this cell, so the cell can be re-executed safely.
training_input = torch.as_tensor(training_input, dtype=torch.float32).to(device)
eval_input = torch.as_tensor(eval_input, dtype=torch.float32).to(device)

# Inference only: no autograd graph is needed.
# NOTE(review): the inputs are passed as stored; if they are laid out as
# (trials, time, channels) while the model iterates over dim 0 as time, a
# transpose is required here - keep this consistent with training.
with torch.no_grad():
    training_predictions, _, _ = saved_model(training_input)
    eval_predictions, _, _ = saved_model(eval_input)

training_predictions = training_predictions.cpu().numpy()
eval_predictions = eval_predictions.cpu().numpy()

tlen = train_dict['train_spikes_heldin'].shape[1]
num_heldin = train_dict['train_spikes_heldin'].shape[2]

# Submission key follows the configured dataset; non-5 bin sizes get a
# '_<bin_size>' suffix (assumed nlb_tools naming convention - TODO confirm).
suffix = '' if bin_size == 5 else f'_{int(bin_size)}'
submission = {
    dataset_name + suffix: {
        'train_rates_heldin': training_predictions[:, :tlen, :num_heldin],
        'train_rates_heldout': training_predictions[:, :tlen, num_heldin:],
        'eval_rates_heldin': eval_predictions[:, :tlen, :num_heldin],
        'eval_rates_heldout': eval_predictions[:, :tlen, num_heldin:],
        'eval_rates_heldin_forward': eval_predictions[:, tlen:, :num_heldin],
        'eval_rates_heldout_forward': eval_predictions[:, tlen:, num_heldin:]
    }
}

In [10]:
from nlb_tools.evaluation import evaluate
from nlb_tools.make_tensors import make_eval_target_tensors

# Build the evaluation targets for the same dataset the submission was made for.
# The splits are literal and match phase == 'val' as configured above.
target_dict = make_eval_target_tensors(dataset=dataset,
                                       dataset_name=dataset_name,
                                       train_trial_split='train',
                                       eval_trial_split='val',
                                       include_psth=True,
                                       save_file=False)

evaluate(target_dict, submission)

[{'mc_maze_scaling_split': {'[500] co-bps': 0.0751514829820409,
   '[500] vel R2': 0.025792186360115377,
   '[500] psth R2': 0.15769809450841424,
   '[500] fp-bps': -0.030471098079586333}}]