In [None]:
import os

import numpy as np
import torch

# Run on the GPU when CUDA is available, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Directory holding the monkey V1 .npz dataset, and directory for model checkpoints.
monkey_data_path = '../data'
weight_path = './checkpoints/fullmodel'

In [None]:
# load data
# Using np.load as a context manager closes the underlying .npz file handle once every array
# below has been read into memory (otherwise the NpzFile keeps the file open).
with np.load(os.path.join(monkey_data_path, 'monkeyv1_cadena_2019.npz')) as dat:
    images = dat['images']
    responses = dat['responses']
    real_responses = dat['real_responses']
    test_images = dat['test_images']
    test_responses = dat['test_responses']
    test_real_responses = dat['test_real_responses']
    train_idx = dat['train_idx']
    val_idx = dat['val_idx']
    repetitions = dat['repetitions']
    monkey_id = dat['subject_id']
    image_ids = dat['image_ids']

# normalize responses
# The per-column std is computed only over entries flagged in `real_responses`;
# unflagged entries become NaN and are ignored by nanstd.
responses_nan = np.where(real_responses, responses, np.nan)
resp_std = np.nanstd(responses_nan, axis=0)
# A column with no flagged entries (NaN std) or constant responses (zero std) would turn into
# inf/NaN after the division below and silently poison training; fail loudly instead.
assert np.all(np.isfinite(resp_std) & (resp_std > 0)), 'some neurons have zero or undefined response std'
responses = responses / resp_std
test_responses = test_responses / resp_std

train_images = images[train_idx]
val_images = images[val_idx]
train_responses = responses[train_idx]
val_responses = responses[val_idx]
train_real_responses = real_responses[train_idx]
val_real_responses = real_responses[val_idx]

print('train:', train_images.shape, train_responses.shape, train_real_responses.shape)
print('val:', val_images.shape, val_responses.shape, val_real_responses.shape)
print('test:', test_images.shape, test_responses.shape, test_real_responses.shape)

print('resp:', responses.min(), responses.max())
print('test resp:', test_responses.min(), test_responses.max())

In [None]:
# Replace test entries not flagged in test_real_responses with NaN, so downstream
# metrics can tell them apart from real measurements.
test_responses = np.where(test_real_responses, test_responses, np.nan)

# NN: number of response columns; Lx, Ly: sizes of image axes 2 and 3.
NN = train_responses.shape[1]
Lx, Ly = (train_images.shape[axis] for axis in (2, 3))

In [None]:
# Convert the splits to torch tensors. torch.as_tensor shares memory with the numpy array like
# torch.from_numpy, but also accepts an already-converted tensor, so re-running this cell does
# not raise a TypeError.
train_images = torch.as_tensor(train_images)
val_images = torch.as_tensor(val_images)
train_responses = torch.as_tensor(train_responses)
val_responses = torch.as_tensor(val_responses)
train_real_responses = torch.as_tensor(train_real_responses)
val_real_responses = torch.as_tensor(val_real_responses)

In [None]:
# build model
from minimodel import model_builder
seed = 1
nlayers = 2
nconv1 = 16
nconv2 = 320
# Seed torch (CPU and CUDA) and numpy before building the model so that `seed` actually takes
# effect: weight initialization, and a training run in the next cell, become reproducible.
torch.manual_seed(seed)
np.random.seed(seed)
# Size the model from the data (NN = number of response columns in train_responses) instead of
# a hard-coded 166, so it always matches the loaded dataset.
model, in_channels = model_builder.build_model(NN=NN, n_layers=nlayers, n_conv=nconv1, n_conv_mid=nconv2, input_Lx=Lx, input_Ly=Ly)
model_name = model_builder.create_model_name('monkeyV1', '2019', n_layers=nlayers, in_channels=in_channels)
model_path = os.path.join(weight_path, model_name)
print('model path: ', model_path)
model = model.to(device)

In [None]:
# Train only when no checkpoint exists at model_path; the resulting state dict is saved there.
if not os.path.exists(model_path):
    from minimodel import model_trainer
    # Create the checkpoint directory before training so a long run is not lost to torch.save
    # failing on a missing directory.
    os.makedirs(os.path.dirname(model_path) or '.', exist_ok=True)
    best_state_dict = model_trainer.monkey_train(model, train_responses, train_real_responses, val_responses, val_real_responses, train_images, val_images, device=device)
    torch.save(best_state_dict, model_path)
    print('model saved', model_path)

In [None]:
# Load the trained weights. map_location remaps stored tensors onto `device`, so a checkpoint
# saved on a GPU can also be loaded on a CPU-only machine.
model.load_state_dict(torch.load(model_path, map_location=device))
print('loaded model', model_path)
model = model.to(device)

In [None]:
from minimodel import model_trainer
# torch.as_tensor accepts the original numpy array or an already-converted tensor, so this cell
# can be re-run (torch.from_numpy raises a TypeError on a tensor).
test_images = torch.as_tensor(test_images, device=device)
spks_pred_test = model_trainer.test_epoch(model, test_images)
print('predictions:', spks_pred_test.shape, spks_pred_test.min(), spks_pred_test.max())

In [None]:
from minimodel import metrics

# Score the test-set predictions. test_responses holds NaN wherever test_real_responses was
# false. NOTE(review): presumably monkey_feve ignores those NaNs and FEVE = fraction of
# explainable variance explained — confirm in minimodel.metrics.
(test_fev, test_feve) = metrics.monkey_feve(test_responses, spks_pred_test, repetitions)
print('FEVE (test): ', np.mean(test_feve))

In [None]:
# Validation FEVE. The validation split is assumed to contain `num_reps` presentations of the same
# images, stored repetition-major (all images of rep 0, then rep 1, ...) — TODO confirm against
# how val_idx was constructed.
num_reps = 4
sz = val_responses.shape[0]
assert sz % num_reps == 0, f'validation size {sz} is not divisible by num_reps={num_reps}'
n_val_unique = sz // num_reps

# NaN-mask unrecorded entries exactly as done for test_responses, so they do not enter the metric.
# New names are used so val_responses / val_images are not mutated and the cell stays re-runnable.
val_responses_masked = np.where(np.asarray(val_real_responses), np.asarray(val_responses), np.nan)
val_responses_reps = val_responses_masked.reshape([num_reps, n_val_unique, NN])
# Every repetition shows the same images, so predict once on the first repetition's images.
val_images_unique = val_images.reshape([num_reps, n_val_unique, *val_images.shape[1:]])[0].to(device)

spks_pred_val = model_trainer.test_epoch(model, val_images_unique)
# NOTE(review): `repetitions` comes from the dataset file and is shared with the test evaluation;
# confirm it is valid for the validation split's num_reps=4.
val_fev, val_feve = metrics.monkey_feve(val_responses_reps, spks_pred_val, repetitions)
print('FEVE (val):', np.mean(val_feve))