In [None]:
import os
import torch
import numpy as np
from minimodel import data

# Use the GPU when one is available; otherwise fall back to CPU so the
# notebook still runs (slowly) on machines without CUDA instead of failing
# at the first `.to(device)`.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [None]:
mouse_id = 5

data_path = '../data'
weight_path = './checkpoints/fullmodel'

# Seed every RNG the notebook relies on. numpy keeps the original seed;
# torch is seeded too, because it is presumably used for weight
# initialisation / shuffling inside model_builder and model_trainer,
# which would otherwise differ from run to run.
seed = 1
np.random.seed(seed)
torch.manual_seed(seed)

In [None]:
# Load the stimulus images for this mouse; the image file name comes from
# data.img_file_name[mouse_id], resolved relative to data_path.
img = data.load_images(
    data_path,
    mouse_id,
    file=os.path.join(data_path, data.img_file_name[mouse_id]),
)

In [None]:
# Load neural responses. The .npz name is assembled from the session
# record data.db[mouse_id] ('mname' and 'datexp' fields).
fname = f"{data.db[mouse_id]['mname']}_nat60k_{data.db[mouse_id]['datexp']}.npz"
(spks, istim_train, istim_test,
 xpos, ypos, spks_rep_all) = data.load_neurons(
    file_path=os.path.join(data_path, fname),
    mouse_id=mouse_id,
)
# spks is 2-D here: rows are stimuli, columns are neurons.
n_stim, n_neurons = spks.shape

In [None]:
# Split the training stimuli into train / validation subsets.
# train_frac=0.9 presumably keeps 90% for training (see data.split_train_val);
# the np.random.seed call in the config cell is what should make it repeatable.
itrain, ival = data.split_train_val(
    istim_train,
    train_frac=0.9,
)

In [None]:
# Normalize responses. itrain is passed in, presumably so normalization
# statistics come from the training split only — confirm in data.normalize_spks.
# NOTE: this rebinds spks / spks_rep_all, so re-running the cell applies the
# normalization a second time; restart from the load cell instead.
spks, spks_rep_all = data.normalize_spks(
    spks,
    spks_rep_all,
    itrain,
)

In [None]:
# Select the neurons to model (currently all of them) and convert responses
# and images to torch tensors. Responses stay on the host; images go to device.
ineur = np.arange(n_neurons)
spks_train = torch.from_numpy(spks[itrain][:, ineur])
spks_val = torch.from_numpy(spks[ival][:, ineur])

print('spks_train: ', spks_train.shape, spks_train.min(), spks_train.max())
print('spks_val: ', spks_val.shape, spks_val.min(), spks_val.max())

# itrain / ival index into istim_train, so compose the indices first:
# img[istim_train[itrain]] selects the same images as img[istim_train][itrain]
# without first materialising a copy of every training image.
# unsqueeze(1) inserts a singleton axis at position 1 (presumably the input
# channel axis expected by the conv model).
img_train = torch.from_numpy(img[istim_train[itrain]]).to(device).unsqueeze(1)
img_val = torch.from_numpy(img[istim_train[ival]]).to(device).unsqueeze(1)
img_test = torch.from_numpy(img[istim_test]).to(device).unsqueeze(1)

print('img_train: ', img_train.shape, img_train.min(), img_train.max())
print('img_val: ', img_val.shape, img_val.min(), img_val.max())
print('img_test: ', img_test.shape, img_test.min(), img_test.max())

# Spatial size of the model input (last two axes of the image tensor).
input_Ly, input_Lx = img_train.shape[-2:]

In [None]:
# build model
from minimodel import model_builder
# Architecture hyper-parameters passed to model_builder.build_model
# (presumably layer count and conv channel widths — see model_builder).
nlayers, nconv1, nconv2 = 2, 16, 320
model, in_channels = model_builder.build_model(
    NN=len(ineur),
    n_layers=nlayers,
    n_conv=nconv1,
    n_conv_mid=nconv2,
)
# The checkpoint file name encodes mouse, experiment date and architecture.
model_name = model_builder.create_model_name(
    data.mouse_names[mouse_id],
    data.exp_date[mouse_id],
    n_layers=nlayers,
    in_channels=in_channels,
)

model_path = os.path.join(weight_path, model_name)
model = model.to(device)

In [None]:
# train model
from minimodel import model_trainer
# Train only when no checkpoint exists yet; otherwise reuse the saved weights.
if not os.path.exists(model_path):
    # torch.save does not create missing parent directories, so make sure
    # the checkpoint folder exists before training (avoids losing a finished run).
    os.makedirs(os.path.dirname(model_path) or '.', exist_ok=True)
    best_state_dict = model_trainer.train(model, spks_train, spks_val, img_train, img_val, device=device)
    torch.save(best_state_dict, model_path)
    print('saved model', model_path)
# map_location lets a checkpoint saved on one device (e.g. GPU) load on the
# current one.
model.load_state_dict(torch.load(model_path, map_location=device))
print('loaded model', model_path)

In [None]:
# test model
# Predict responses to the held-out test images with the loaded weights,
# then report the prediction shape and value range.
test_pred = model_trainer.test_epoch(
    model,
    img_test,
)
print(
    'test_pred: ',
    test_pred.shape,
    test_pred.min(),
    test_pred.max(),
)

In [None]:
from minimodel import metrics
# Per-neuron scores on the test set: metrics.feve returns two per-neuron
# arrays (test_fev is thresholded below, test_feve is averaged).
test_fev, test_feve = metrics.feve(spks_rep_all, test_pred)

# Only neurons whose FEV exceeds the threshold enter the reported mean.
threshold = 0.15
print(f'filtering neurons with FEV > {threshold}')
valid_idxes = np.where(test_fev > threshold)[0]
print(f'valid neurons: {len(valid_idxes)} / {len(test_fev)}')
if len(valid_idxes) == 0:
    # np.mean of an empty selection would silently give nan (with a RuntimeWarning).
    print('FEVE (test): no neurons pass the FEV threshold')
else:
    print(f'FEVE (test): {np.mean(test_feve[valid_idxes])}')