In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import os
import torch
import numpy as np
from minimodel import data, metrics, model_builder, model_trainer

# Run on the GPU when one is available; fall back to the CPU otherwise.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Results collected by the analysis below; written to disk with np.savez at the end.
data_dict = {}

data_path = '../../data'
# Root checkpoint directory. The model cells append 'fullmodel/<mouse>' or
# 'minimodel/<mouse>' to it, so it must NOT already end in 'fullmodel'
# (the previous value '../checkpoints/fullmodel' produced .../fullmodel/fullmodel/...).
weight_path = '../checkpoints'

In [None]:
# Recording to analyse (index into the `data` lookup tables), the seed passed to
# create_model_name when locating minimodel checkpoints, and the n_layers value
# passed to build_model for both the full model and the minimodels.
mouse_id, seed, nlayers = 5, 1, 2

In [None]:
# Load the images and neural responses, split the training stimuli into
# train / validation subsets, and build the tensors used by the cells below.


def _describe(label, arr, *extra):
    """Print `label`, then the shape, min and max of `arr`, then any `extra` values."""
    print(label, arr.shape, arr.min(), arr.max(), *extra)


def _to_img_tensor(images):
    """Convert a stack of images to a tensor on `device` with a singleton axis inserted at dim 1."""
    return torch.from_numpy(images).to(device).unsqueeze(1)


# --- images ---
img = data.load_images(data_path, mouse_id, file=os.path.join(data_path, data.img_file_name[mouse_id]))
nimg, Ly, Lx = img.shape
_describe('img: ', img, img.dtype)

# --- neural responses ---
mouse_db = data.db[mouse_id]
fname = '%s_nat60k_%s.npz' % (mouse_db['mname'], mouse_db['datexp'])
spks, istim_train, istim_test, xpos, ypos, spks_rep_all = data.load_neurons(
    file_path=os.path.join(data_path, fname), mouse_id=mouse_id)
n_stim, n_max_neurons = spks.shape
_describe('spks: ', spks)
print('spks_rep_all: ', len(spks_rep_all), spks_rep_all[0].shape)
_describe('istim_train: ', istim_train)
_describe('istim_test: ', istim_test)

# --- split the training stimuli (train_frac=0.9 goes to itrain) ---
itrain, ival = data.split_train_val(istim_train, train_frac=0.9)
_describe('itrain: ', itrain)
_describe('ival: ', ival)

# itrain is passed in, presumably so normalization statistics come from the
# training split only — confirm in data.normalize_spks.
spks, spks_rep_all = data.normalize_spks(spks, spks_rep_all, itrain)

# --- image tensors (itrain / ival index into the istim_train ordering) ---
img_val = _to_img_tensor(img[istim_train][ival])
img_test = _to_img_tensor(img[istim_test])
img_train = _to_img_tensor(img[istim_train][itrain])

for label, tensor in (('img_train: ', img_train), ('img_val: ', img_val), ('img_test: ', img_test)):
    _describe(label, tensor)

input_Ly, input_Lx = img_train.shape[-2:]

# --- response tensors for the full population (every neuron column) ---
ineur = np.arange(spks.shape[-1])
spks_train = torch.from_numpy(spks[itrain][:, ineur])
spks_val = torch.from_numpy(spks[ival][:, ineur])
spks_rep = [rep[:, ineur] for rep in spks_rep_all]

# load fullmodel

In [None]:
# Build the full model, load its checkpoint and evaluate it on the test images.
# Neurons whose test FEV exceeds `fev_threshold` form `valid_idxes`, the set of
# neurons whose minimodels are evaluated further down.
nconv1 = 16
nconv2 = 320
fev_threshold = 0.15  # minimum test FEV for a neuron to be kept

model, in_channels = model_builder.build_model(NN=data.NNs[mouse_id], n_layers=nlayers, n_conv=nconv1, n_conv_mid=nconv2)
model_name = model_builder.create_model_name(data.mouse_names[mouse_id], data.exp_date[mouse_id], n_layers=nlayers, in_channels=in_channels)
# Use a dedicated variable rather than reassigning the global `weight_path`:
# the minimodel cell joins its own sub-directory onto `weight_path`, so
# mutating it here corrupted that path.
fullmodel_dir = os.path.join(weight_path, 'fullmodel', data.mouse_names[mouse_id])
model_path = os.path.join(fullmodel_dir, model_name)

# map_location lets a checkpoint saved on another device be restored onto `device`.
model.load_state_dict(torch.load(model_path, map_location=device))
print('loaded model', model_path)
model = model.to(device)

# test model
test_pred = model_trainer.test_epoch(model, img_test)
test_fev, test_feve = metrics.feve(spks_rep_all, test_pred)
is_valid = test_fev > fev_threshold
print('FEVE (test): ', np.mean(test_feve[is_valid]))

valid_idxes = np.where(is_valid)[0]

# load all minimodels

In [None]:
# Load the minimodel of every neuron in `valid_idxes`, evaluate it on the test
# images, and collect its FEV / FEVE, test predictions and readout weights.
nconv1 = 16
nconv2 = 64

n_max_neurons = data.NNs_valid[mouse_id]  # NOTE(review): not used below in this notebook
ineurons = valid_idxes

# Loop-invariant hyper-parameters of the minimodels: wc_coef is passed to
# build_model, hs_readout to create_model_name (it selects the checkpoint file).
wc_coef = 0.2
hs_readout = 0.03

# Computed once, outside the loop: reassigning `weight_path` inside the loop
# appended another 'minimodel/<mouse>' level on every iteration, so every
# neuron after the first looked in a non-existent directory.
minimodel_dir = os.path.join(weight_path, 'minimodel', data.mouse_names[mouse_id])

feve_all = np.zeros(len(ineurons))
fev_all = np.zeros(len(ineurons))
wc_all = np.zeros((len(ineurons), nconv2))
wx_all = []
wy_all = []
test_pred_all = np.zeros((len(ineurons), len(istim_test)))

for j, ineuron in enumerate(ineurons):
    # responses of this single neuron (list index keeps the column axis)
    ineur = [ineuron]
    spks_train = torch.from_numpy(spks[itrain][:, ineur])
    spks_val = torch.from_numpy(spks[ival][:, ineur])
    spks_rep = [rep[:, ineur] for rep in spks_rep_all]

    print('spks_train: ', spks_train.shape, spks_train.min(), spks_train.max())
    print('spks_val: ', spks_val.shape, spks_val.min(), spks_val.max())

    # build model
    model, in_channels = model_builder.build_model(NN=1, n_layers=nlayers, n_conv=nconv1, n_conv_mid=nconv2, Wc_coef=wc_coef)
    model_name = model_builder.create_model_name(data.mouse_names[mouse_id], data.exp_date[mouse_id], ineuron=ineuron, n_layers=nlayers, in_channels=in_channels, seed=seed, hs_readout=hs_readout)

    model_path = os.path.join(minimodel_dir, model_name)
    print('model path: ', model_path)
    # map_location lets a checkpoint saved on another device be restored onto `device`.
    model.load_state_dict(torch.load(model_path, map_location=device))
    print('loaded model', model_path)
    model = model.to(device)

    # test model
    test_pred = model_trainer.test_epoch(model, img_test)
    test_pred_all[j] = test_pred.squeeze()

    test_fev, test_feve = metrics.feve(spks_rep, test_pred)
    print('FEVE (test): ', test_feve)

    # With a single neuron, feve presumably returns size-1 arrays; .item()
    # extracts the scalar explicitly instead of relying on implicit
    # size-1-array -> scalar conversion (deprecated in recent NumPy).
    feve_all[j] = np.asarray(test_feve).item()
    fev_all[j] = np.asarray(test_fev).item()
    wc_all[j] = model.readout.Wc.detach().cpu().numpy().squeeze()
    wx_all.append(model.readout.Wx.detach().cpu().numpy().squeeze())
    wy_all.append(model.readout.Wy.detach().cpu().numpy().squeeze())

In [None]:
# Gather the per-neuron minimodel results into `data_dict` for saving with np.savez.
wx_stack = np.stack(wx_all)
wy_stack = np.stack(wy_all)
data_dict.update(
    feve_all=feve_all,
    fev_all=fev_all,
    wc_all=wc_all,
    wx_all=wx_stack,
    wy_all=wy_stack,
    test_pred_all=test_pred_all,
)

print(feve_all.mean())
print(wx_stack.shape)

# save

In [None]:
# Write every array in `data_dict` to outputs/minimodel_<mouse>_result.npz.
output_dir = 'outputs'
os.makedirs(output_dir, exist_ok=True)  # np.savez does not create missing directories
output_file = os.path.join(output_dir, f"minimodel_{data.mouse_names[mouse_id]}_result.npz")
np.savez(output_file, **data_dict)
print(f'saved minimodel result to {output_file}')