In [None]:
import os
import torch
import numpy as np
from minimodel import data, metrics, model_builder, model_trainer

# Compute device: use the GPU when one is available, otherwise fall back to CPU.
# NOTE(review): minimodel helpers may still assume CUDA internally — unverified.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

data_dict = {}  # NOTE(review): not used in the visible cells — confirm before removing

# Root directories for the raw recordings/images and for trained checkpoints.
# `weight_path` is treated as a fixed root; per-mouse sub-directories are
# derived from it in the evaluation cells rather than by mutating it.
data_path = '../../data'
weight_path = '../checkpoints/fullmodel'

# Parameter search: evaluate every hs_readout value for all selected neurons

In [None]:
mouse_id = 5

# --- stimulus images -------------------------------------------------------
img_file = os.path.join(data_path, data.img_file_name[mouse_id])
img = data.load_images(data_path, mouse_id, file=img_file)
nimg, Ly, Lx = img.shape
print('img: ', img.shape, img.min(), img.max(), img.dtype)

# --- neural responses --------------------------------------------------------
db_entry = data.db[mouse_id]
fname = '%s_nat60k_%s.npz' % (db_entry['mname'], db_entry['datexp'])
(spks, istim_train, istim_test,
 xpos, ypos, spks_rep_all) = data.load_neurons(file_path=os.path.join(data_path, fname),
                                               mouse_id=mouse_id)
n_stim, n_max_neurons = spks.shape

# Split the training stimuli into train / validation indices (90 / 10).
itrain, ival = data.split_train_val(istim_train, train_frac=0.9)

# Normalize responses; only the training indices are passed in
# (presumably so normalization statistics come from the train split).
spks, spks_rep_all = data.normalize_spks(spks, spks_rep_all, itrain)

ineur = np.arange(n_max_neurons)

# Validation / test images as 4-D tensors (N, 1, Ly, Lx) on the compute device.
img_val = torch.from_numpy(img[istim_train][ival]).to(device).unsqueeze(1)
img_test = torch.from_numpy(img[istim_test]).to(device).unsqueeze(1)

input_Ly, input_Lx = img_test.shape[-2:]

In [None]:
# Randomly pick 1000 neurons out of the pooled valid neurons of all mice, then
# map the picks back to per-mouse indices.
nmouse = 6
np.random.seed(42)
n_valid_total = np.sum(data.NNs_valid)
ind_selected = np.random.choice(np.arange(n_valid_total), 1000, replace=False)
ind_all = np.zeros(n_valid_total, dtype=bool)
ind_all[ind_selected] = True

# bounds[k]:bounds[k + 1] is the slice of the pooled index belonging to mouse k.
bounds = np.concatenate(([0], np.cumsum(data.NNs_valid[:nmouse])))
ineurons_all = [np.where(ind_all[bounds[k]:bounds[k + 1]])[0] for k in range(nmouse)]
ineurons = ineurons_all[mouse_id]

In [None]:
# Evaluate pre-trained single-neuron minimodels for every hs_readout value and
# record test FEVE (metrics.feve) and the learned readout weights Wc.
hs_list = [0, 0.01, 0.02, 0.03, 0.04, 0.05, 0.1, 0.2, 0.5]
nlayers = 2
nconv1 = 16
nconv2 = 64
nhs = len(hs_list)
nneurons = len(ineurons)
feve_all = np.zeros((nneurons, nhs))
wc_all = np.zeros((nneurons, nhs, nconv2))

# Neurons ranked by FEV on the repeated test stimuli (highest first);
# entries of `ineurons` are positions in this ranking.
fev_test = metrics.fev(spks_rep_all)
isort_neurons = np.argsort(fev_test)[::-1]

# Checkpoints trained with a restricted stimulus count carry an 'nstims_<n>'
# suffix in their name; -1 selects the un-suffixed checkpoints.
nstims = 5000
suffix = ''
if nstims != -1:
    suffix = f'nstims_{nstims}'

# Per-mouse checkpoint directory, computed once. Re-assigning `weight_path`
# inside the loop would nest 'minimodel/<mouse>' one level deeper per iteration
# and leave the global corrupted for later cells.
model_dir = os.path.join(weight_path, 'minimodel', data.mouse_names[mouse_id])

# NOTE(review): `pool`, `depth_separable` and `clamp` are not defined in any
# visible cell; set them to the values used at training time before running
# this cell on a fresh kernel.
wc_coef = 0.2
for i, ineuron in enumerate(ineurons):
    # Neuron-specific data is independent of hs_readout: compute once per neuron.
    ineur = [isort_neurons[ineuron]]
    spks_rep = [rep[:, ineur] for rep in spks_rep_all]

    for ihs, hs_readout in enumerate(hs_list):
        # build model
        model, in_channels = model_builder.build_model(NN=1, n_layers=nlayers, n_conv=nconv1, n_conv_mid=nconv2, pool=pool, depth_separable=depth_separable, Wc_coef=wc_coef)
        model_name = model_builder.create_model_name(data.mouse_names[mouse_id], data.exp_date[mouse_id], ineuron=ineur[0], n_layers=nlayers, in_channels=in_channels, clamp=clamp, hs_readout=hs_readout, suffix=suffix)

        model_path = os.path.join(model_dir, model_name)
        print('ineuron: ', ineuron)
        print('model path: ', model_path)
        # map_location loads the tensors straight onto the chosen device.
        model.load_state_dict(torch.load(model_path, map_location=device))
        print('loaded model', model_path)
        model = model.to(device)

        # test model
        test_pred = model_trainer.test_epoch(model, img_test)
        test_fev, test_feve = metrics.feve(spks_rep, test_pred)
        print('FEVE (test): ', test_feve)

        feve_all[i, ihs] = np.mean(test_feve)
        wc_all[i, ihs] = model.readout.Wc.detach().cpu().numpy()

In [None]:
# Save per-neuron FEVE and readout weights. The file name encodes the nstims
# setting so runs with different stimulus counts do not overwrite each other.
if nstims == -1:
    nstims_tag = ''
elif nstims == 5000:
    nstims_tag = '_5k'
else:
    # Any other setting still gets saved, under an explicit tag, instead of silently skipping.
    nstims_tag = f'_nstims_{nstims}'
os.makedirs('outputs', exist_ok=True)
# np.savez appends the '.npz' extension itself.
np.savez(f'outputs/{data.mouse_names[mouse_id]}_{data.exp_date[mouse_id]}_minimodel_16_64_choose_param{nstims_tag}_result', feve_all=feve_all, wc_all=wc_all)

# Choose the hs_readout parameter using a random subset of 10 neurons

In [None]:
mouse_id = 0

# --- stimulus images -------------------------------------------------------
img_file = os.path.join(data_path, data.img_file_name[mouse_id])
img = data.load_images(data_path, mouse_id, file=img_file)
nimg, Ly, Lx = img.shape
print('img: ', img.shape, img.min(), img.max(), img.dtype)

# --- neural responses --------------------------------------------------------
db_entry = data.db[mouse_id]
fname = '%s_nat60k_%s.npz' % (db_entry['mname'], db_entry['datexp'])
(spks, istim_train, istim_test,
 xpos, ypos, spks_rep_all) = data.load_neurons(file_path=os.path.join(data_path, fname),
                                               mouse_id=mouse_id)
n_stim, n_max_neurons = spks.shape

# Split the training stimuli into train / validation indices (90 / 10).
itrain, ival = data.split_train_val(istim_train, train_frac=0.9)

# Normalize responses; only the training indices are passed in
# (presumably so normalization statistics come from the train split).
spks, spks_rep_all = data.normalize_spks(spks, spks_rep_all, itrain)

ineur = np.arange(n_max_neurons)

# Validation / test images as 4-D tensors (N, 1, Ly, Lx) on the compute device.
img_val = torch.from_numpy(img[istim_train][ival]).to(device).unsqueeze(1)
img_test = torch.from_numpy(img[istim_test]).to(device).unsqueeze(1)

input_Ly, input_Lx = img_test.shape[-2:]

In [None]:
# Evaluate the single-neuron minimodels of NN randomly chosen neurons for every
# hs_readout value, recording validation FEVE, test FEVE and readout weights Wc.
hs_list = [0, 0.01, 0.02, 0.03, 0.04, 0.05, 0.1, 0.2, 0.5]

NN = 10
nhs = len(hs_list)
nlayers = 2
nconv1 = 16
nconv2 = 64
feve_val_all = np.zeros((NN, nhs))
feve_test_all = np.zeros((NN, nhs))
wc_all = np.zeros((NN, nhs, nconv2))

# Neurons ranked by FEV on the repeated test stimuli (highest first).
fev_test = metrics.fev(spks_rep_all)
isort_neurons = np.argsort(fev_test)[::-1]

# Pick NN ranks among this mouse's valid neurons (seeded for reproducibility).
np.random.seed(0)
ineurons = np.random.choice(np.arange(data.NNs_valid[mouse_id]), NN, replace=False)

# Checkpoints trained with a restricted stimulus count carry an 'nstims_<n>'
# suffix in their name; -1 selects the un-suffixed checkpoints.
nstims = -1
suffix = ''
if nstims != -1:
    suffix = f'nstims_{nstims}'

# Per-mouse checkpoint directory, computed once (assumes `weight_path` still
# holds the checkpoint root). Re-assigning `weight_path` inside the loop would
# nest 'minimodel/<mouse>' one level deeper per iteration.
model_dir = os.path.join(weight_path, 'minimodel', data.mouse_names[mouse_id])

# NOTE(review): `pool`, `depth_separable` and `clamp` are not defined in any
# visible cell; set them to the values used at training time before running
# this cell on a fresh kernel.
wc_coef = 0.2
for i, ineuron in enumerate(ineurons):
    # Neuron-specific targets are independent of hs_readout: compute once per neuron.
    ineur = [isort_neurons[ineuron]]
    spks_val = spks[ival][:, ineur][:, np.newaxis, :]
    spks_rep = [rep[:, ineur] for rep in spks_rep_all]

    for ihs, hs_readout in enumerate(hs_list):
        # build model
        model, in_channels = model_builder.build_model(NN=1, n_layers=nlayers, n_conv=nconv1, n_conv_mid=nconv2, pool=pool, depth_separable=depth_separable, Wc_coef=wc_coef)
        model_name = model_builder.create_model_name(data.mouse_names[mouse_id], data.exp_date[mouse_id], ineuron=ineur[0], n_layers=nlayers, in_channels=in_channels, clamp=clamp, hs_readout=hs_readout, suffix=suffix)

        model_path = os.path.join(model_dir, model_name)
        print('model path: ', model_path)
        # map_location loads the tensors straight onto the chosen device.
        model.load_state_dict(torch.load(model_path, map_location=device))
        print('loaded model', model_path)
        model = model.to(device)

        # get FEVE val (no noise estimation)
        val_pred = model_trainer.test_epoch(model, img_val)
        val_fev, val_feve = metrics.feve(spks_val, val_pred, multi_repeats=False)
        feve_val_all[i, ihs] = np.mean(val_feve)

        # test model
        test_pred = model_trainer.test_epoch(model, img_test)
        test_fev, test_feve = metrics.feve(spks_rep, test_pred)
        print('FEVE (test): ', test_feve)

        feve_test_all[i, ihs] = np.mean(test_feve)
        wc_all[i, ihs] = model.readout.Wc.detach().cpu().numpy()

In [None]:
# Average FEVE over the sampled neurons (axis 0) for each hs_readout value.
feve_val = feve_val_all.mean(axis=0)
feve_test = feve_test_all.mean(axis=0)
# Per (neuron, hs) count of readout weights with |Wc| > 0.01, averaged over neurons.
# NOTE: this re-binds `nconv2` (the integer channel count above) to an array.
nconv2 = (np.abs(wc_all) > 0.01).sum(axis=2).mean(axis=0)

In [None]:
# Save the validation/test summary. The file name encodes the nstims setting.
# Every branch assigns `fname`; otherwise it would keep the neuron-data file
# name set in the loading cell, and the results would be saved under that name.
if nstims == -1:
    nstims_tag = ''
elif nstims == 5000:
    nstims_tag = '_5k'
else:
    nstims_tag = f'_nstims_{nstims}'
fname = f'outputs/{data.mouse_names[mouse_id]}_{data.exp_date[mouse_id]}_minimodel_16_64_choose_param{nstims_tag}_val_result.npz'
os.makedirs(os.path.dirname(fname), exist_ok=True)
np.savez(fname, feve_val=feve_val, feve_test=feve_test, nconv2=nconv2, hs_list=hs_list)