In [None]:
import os
import torch
import numpy as np
from minimodel import data, metrics, model_builder, model_trainer

# Run configuration shared by every cell below.
data_path = '../../data'  # stimulus images and neural-response .npz files
# NOTE(review): the model-loading cells join 'fullmodel'/'LNmodel' + mouse name
# onto this root again (e.g. <root>/fullmodel/<mouse>) — confirm this is the
# intended checkpoint root.
weight_path = '../checkpoints/fullmodel'
device = torch.device('cuda')

# Collects every result array; written out with np.savez by the last cell.
data_dict = {}

In [None]:
# Select the recording and load its stimulus images.
mouse_id = 0

img = data.load_images(data_path, mouse_id, file=os.path.join(data_path, data.img_file_name[mouse_id]))
nimg, Ly, Lx = img.shape
print('img: ', img.shape, img.min(), img.max(), img.dtype)

# Load neural responses; the file name is built from the recording's db entry.
fname = '%s_nat60k_%s.npz' % (data.db[mouse_id]['mname'], data.db[mouse_id]['datexp'])
spks, istim_train, istim_test, xpos, ypos, spks_rep_all = data.load_neurons(
    file_path=os.path.join(data_path, fname),
    mouse_id=mouse_id,
)
n_stim, n_max_neurons = spks.shape

# Hold out 10% of the training stimuli for validation, then normalize the
# responses (normalize_spks is given the train indices).
itrain, ival = data.split_train_val(istim_train, train_frac=0.9)
spks, spks_rep_all = data.normalize_spks(spks, spks_rep_all, itrain)

# Keep every neuron.
ineur = np.arange(n_max_neurons)
spks_val = torch.from_numpy(spks[ival][:, ineur])
spks_rep_all = [rep[:, ineur] for rep in spks_rep_all]

# Images become (N, 1, Ly, Lx) tensors on `device`; unsqueeze(1) inserts the
# singleton axis after the image index.
img_val = torch.from_numpy(img[istim_train][ival]).to(device).unsqueeze(1)
img_test = torch.from_numpy(img[istim_test]).to(device).unsqueeze(1)
input_Ly, input_Lx = img_test.shape[-2:]

In [None]:
# Count the rows across every array in spks_rep_all (repeated test trials) and
# report the total together with the rows of spks.
ntest_trails = sum(rep.shape[0] for rep in spks_rep_all)
print('ntest_trails: ', ntest_trails)
print('total_trails: ', spks.shape[0] + ntest_trails)

# Results vs. number of conv layers

In [None]:
# Evaluate full models with 1-4 conv layers on the test set. Width and seed are
# shared across depths; only n_layers changes.
seed = 1
nconv1 = 192
nconv2 = 192
threshold = 0.15  # FEV cutoff used to select reliable neurons
# Resolve the checkpoint directory once. The old cell re-assigned
# `weight_path = os.path.join(weight_path, ...)` inside the loop, so the path
# nested one level deeper for every model after the first.
model_dir = os.path.join(weight_path, 'fullmodel', data.mouse_names[mouse_id])

feve_nlayers = []
for nlayers in range(1, 5):
    model, in_channels = model_builder.build_model(NN=len(ineur), n_layers=nlayers, n_conv=nconv1, n_conv_mid=nconv2)
    model_name = model_builder.create_model_name(data.mouse_names[mouse_id], data.exp_date[mouse_id], n_layers=nlayers, in_channels=in_channels, seed=seed)

    model_path = os.path.join(model_dir, model_name)
    print('model path: ', model_path)
    model = model.to(device)
    # map_location keeps loading working if the checkpoint was saved on another device
    model.load_state_dict(torch.load(model_path, map_location=device))
    print('loaded model', model_path)

    # test model
    test_pred = model_trainer.test_epoch(model, img_test)
    test_fev, test_feve = metrics.feve(spks_rep_all, test_pred)
    print('FEVE (test, all): ', np.mean(test_feve))

    print(f'filtering neurons with FEV > {threshold}')
    valid_idxes = np.where(test_fev > threshold)[0]
    print(f'valid neurons: {len(valid_idxes)} / {len(test_fev)}')
    print(f'FEVE (test, FEV>0.15): {np.mean(test_feve[valid_idxes])}')

    feve_nlayers.append(test_feve)

# Shape: (number of depths, number of neurons)
feve_nlayers = np.stack(feve_nlayers)
print(feve_nlayers.shape)

In [None]:
# Depth-sweep results plus the FEV / reliable-neuron selection from its last model.
data_dict.update(feve_depth=feve_nlayers, valid_idxes=valid_idxes, fev=test_fev)

# Results vs. conv1 / conv2 width

In [None]:
# build model
# Evaluate 2-layer models over a grid of conv1 x conv2 channel counts. Each cell
# of the grid holds the mean test FEVE over neurons with FEV > threshold.
nlayers = 2
nconv1_list = [8, 16, 32, 64, 128, 192, 256, 320, 384, 448]
nconv2_list = [8, 16, 32, 64, 128, 192, 256, 320, 384, 448]
seed = 1
threshold = 0.15  # FEV cutoff used to select reliable neurons
n_valid_neurons = len(valid_idxes)
feve_width = np.zeros((len(nconv1_list), len(nconv2_list), n_valid_neurons))
# Resolve the checkpoint directory once (re-joining onto `weight_path` inside the
# loop nested the path one level deeper for every model).
model_dir = os.path.join(weight_path, 'fullmodel', data.mouse_names[mouse_id])

for i, nconv1 in enumerate(nconv1_list):
    for j, nconv2 in enumerate(nconv2_list):
        model, in_channels = model_builder.build_model(NN=len(ineur), n_layers=nlayers, n_conv=nconv1, n_conv_mid=nconv2)
        model_name = model_builder.create_model_name(data.mouse_names[mouse_id], data.exp_date[mouse_id], n_layers=nlayers, in_channels=in_channels, seed=seed)

        model_path = os.path.join(model_dir, model_name)
        print('model path: ', model_path)
        model = model.to(device)
        model.load_state_dict(torch.load(model_path, map_location=device))
        print('loaded model', model_path)

        # test model
        test_pred = model_trainer.test_epoch(model, img_test)
        test_fev, test_feve = metrics.feve(spks_rep_all, test_pred)
        print('FEVE (test, all): ', np.mean(test_feve))

        print(f'filtering neurons with FEV > {threshold}')
        valid_idxes = np.where(test_fev > threshold)[0]
        print(f'valid neurons: {len(valid_idxes)} / {len(test_fev)}')
        mean_feve = np.mean(test_feve[valid_idxes])
        print(f'FEVE (test, FEV>0.15): {mean_feve}')
        # NOTE(review): this scalar is broadcast across the whole neuron axis;
        # store test_feve[valid_idxes] instead if per-neuron values are wanted.
        feve_width[i, j] = mean_feve

In [None]:
data_dict.update(feve_width=feve_width)  # conv1 x conv2 width-sweep FEVE grid

# Results vs. number of training stimuli

In [None]:
# Ten log-spaced training-set sizes from 500 to 30000. np.unique sorts them and
# drops any duplicates produced by the integer cast.
stim_numbers = np.unique(np.geomspace(500, 30000, num=10, dtype=int))
print(stim_numbers)

In [None]:
# build model
# Evaluate 2-layer, 16/320-channel models trained on increasing numbers of
# stimuli. Each model is scored by its mean test FEVE over `valid_idxes` (the
# reliable neurons selected by the preceding cell).
seed = 1
nlayers = 2
nconv1 = 16
nconv2 = 320
# Named separately so this cell no longer overwrites `n_max_neurons` (the total
# neuron count set by the loading cell).
n_valid_neurons = len(valid_idxes)
feve_nstims = np.zeros((len(stim_numbers), n_valid_neurons))
# Resolve the checkpoint directory once (re-joining onto `weight_path` inside the
# loop nested the path one level deeper for every model).
model_dir = os.path.join(weight_path, 'fullmodel', data.mouse_names[mouse_id])

for i, n_stim in enumerate(stim_numbers):
    # Requested sizes larger than the train split are capped at its size.
    n_stim = min(n_stim, len(itrain))

    suffix = f'nstims_{n_stim}'
    model, in_channels = model_builder.build_model(NN=len(ineur), n_layers=nlayers, n_conv=nconv1, n_conv_mid=nconv2)
    model_name = model_builder.create_model_name(data.mouse_names[mouse_id], data.exp_date[mouse_id], n_layers=nlayers, in_channels=in_channels, seed=seed, suffix=suffix)

    model_path = os.path.join(model_dir, model_name)
    print('model path: ', model_path)
    model = model.to(device)
    model.load_state_dict(torch.load(model_path, map_location=device))
    print('loaded model', model_path)

    # test model
    test_pred = model_trainer.test_epoch(model, img_test)
    test_fev, test_feve = metrics.feve(spks_rep_all, test_pred)
    mean_feve = np.mean(test_feve[valid_idxes])
    print('FEVE (test): ', mean_feve)

    # NOTE(review): this scalar is broadcast across the whole neuron axis;
    # store test_feve[valid_idxes] instead if per-neuron values are wanted.
    feve_nstims[i] = mean_feve

In [None]:
# Training-set-size sweep: FEVE per size, and the sizes requested.
data_dict.update(feve_nstims=feve_nstims, nstims=stim_numbers)

# Results vs. number of training neurons

In [None]:
# Log-spaced neuron-subset sizes from 1 to 1000. 1 is included explicitly, and
# np.unique sorts the sizes and removes duplicates from the integer cast.
neuron_numbers = np.unique(np.concatenate(([1], np.geomspace(1, 1000, num=10, dtype=int))))
# Number of seeds per size: 10 for the smallest subset down to 1 for the largest.
seed_numbers = np.linspace(10, 1, num=len(neuron_numbers), dtype=int)
print(neuron_numbers)
print(seed_numbers)

In [None]:
# build model
# Evaluate 2-layer, 16/320-channel models trained on random subsets of the
# reliable neurons (`valid_idxes`). Size neuron_numbers[i] is repeated for
# seed_numbers[i] seeds. Once a size reaches len(valid_idxes), all reliable
# neurons are used.
nlayers = 2
nconv1 = 16
nconv2 = 320
# Resolve the checkpoint directory once (re-joining onto `weight_path` inside the
# loop nested the path one level deeper for every model).
model_dir = os.path.join(weight_path, 'fullmodel', data.mouse_names[mouse_id])

feve_nneurons = []

for i, n_neuron in enumerate(neuron_numbers):
    feve_nneurons.append([])
    for seed in range(1, seed_numbers[i] + 1):
        # Keep this exact legacy seeding + np.random.choice sequence: it presumably
        # mirrors how the subsets were drawn for training, so changing it would
        # pair a checkpoint with the wrong neurons.
        np.random.seed(n_neuron * seed)
        if n_neuron < len(valid_idxes):
            ineur = np.random.choice(valid_idxes, size=n_neuron, replace=False)
        else:
            ineur = valid_idxes

        # Name by the number of neurons actually used (capped at len(valid_idxes)).
        suffix = f'nneurons_{len(ineur)}'
        spks_rep = [rep[:, ineur] for rep in spks_rep_all]

        model, in_channels = model_builder.build_model(NN=len(ineur), n_layers=nlayers, n_conv=nconv1, n_conv_mid=nconv2)
        model_name = model_builder.create_model_name(data.mouse_names[mouse_id], data.exp_date[mouse_id], n_layers=nlayers, in_channels=in_channels, seed=seed, suffix=suffix)

        model_path = os.path.join(model_dir, model_name)
        print('model path: ', model_path)
        model = model.to(device)
        model.load_state_dict(torch.load(model_path, map_location=device))
        print('loaded model', model_path)

        # test model
        test_pred = model_trainer.test_epoch(model, img_test)
        test_fev, test_feve = metrics.feve(spks_rep, test_pred)
        mean_feve = np.mean(test_feve)
        print('FEVE (test): ', mean_feve)

        feve_nneurons[i].append(mean_feve)

In [None]:
feve_nneurons = list(map(np.mean, feve_nneurons))  # average over seeds per subset size

In [None]:
# Neuron-count sweep: seed-averaged FEVE per subset size, and the sizes requested.
data_dict.update(feve_nneurons=feve_nneurons, nneurons=neuron_numbers)

# Full model (16 conv1 / 320 conv2 channels): readout weights and FEVE for visualization

In [None]:
# Load the 2-layer full model (16 conv1 / 320 conv2 channels, all neurons).
# Keep its readout weights and per-neuron test FEVE.
nlayers = 2
nconv1 = 16
nconv2 = 320
# Set the seed explicitly and use the default (no) name suffix, as the width
# sweep does for its 16/320 checkpoint.
# Previously this cell reused `seed`/`suffix` left over from the neuron-count
# sweep (suffix='nneurons_<N>'). That names a model trained on a neuron subset,
# whose readout presumably does not match NN=len(ineur) below, so
# load_state_dict would fail on a fresh kernel.
seed = 1
n_stim, n_max_neurons = spks.shape
ineur = np.arange(n_max_neurons)
model, in_channels = model_builder.build_model(NN=len(ineur), n_layers=nlayers, n_conv=nconv1, n_conv_mid=nconv2)
model_name = model_builder.create_model_name(data.mouse_names[mouse_id], data.exp_date[mouse_id], n_layers=nlayers, in_channels=in_channels, seed=seed)

# Derive the directory locally instead of re-assigning `weight_path`; earlier
# cells' repeated joins would otherwise leave it nested many levels deep.
model_dir = os.path.join(weight_path, 'fullmodel', data.mouse_names[mouse_id])
model_path = os.path.join(model_dir, model_name)
print('model path: ', model_path)
model = model.to(device)
model.load_state_dict(torch.load(model_path, map_location=device))
print('loaded model', model_path)

# test model
test_pred = model_trainer.test_epoch(model, img_test)
test_fev, test_feve = metrics.feve(spks_rep_all, test_pred)
print('FEVE (test, all): ', np.mean(test_feve))

threshold = 0.15  # FEV cutoff used to select reliable neurons
print(f'filtering neurons with FEV > {threshold}')
valid_idxes = np.where(test_fev > threshold)[0]
print(f'valid neurons: {len(valid_idxes)} / {len(test_fev)}')
print(f'FEVE (test, FEV>0.15): {np.mean(test_feve[valid_idxes])}')

# Detach before moving to CPU so no autograd history is carried along.
data_dict['fullmodel_Wx'] = model.readout.Wx.detach().cpu().numpy().squeeze()
data_dict['fullmodel_Wy'] = model.readout.Wy.detach().cpu().numpy().squeeze()
data_dict['fullmodel_feve_all'] = test_feve

# LN model result

In [None]:
# Load the LN model: the same 2-layer 16/320 architecture, built with
# activation=None and avgpool=True. Its checkpoints live under 'LNmodel'.
nlayers = 2
nconv1 = 16
nconv2 = 320
seed = 1
suffix = 'LN'
n_stim, n_max_neurons = spks.shape
ineur = np.arange(n_max_neurons)
model, in_channels = model_builder.build_model(NN=len(ineur), n_layers=nlayers, n_conv=nconv1, n_conv_mid=nconv2, activation=None, avgpool=True)
model_name = model_builder.create_model_name(data.mouse_names[mouse_id], data.exp_date[mouse_id], n_layers=nlayers, in_channels=in_channels, seed=seed, suffix=suffix)

# Derive the directory locally instead of re-assigning `weight_path`; earlier
# cells' repeated joins would otherwise leave it nested many levels deep.
model_dir = os.path.join(weight_path, 'LNmodel', data.mouse_names[mouse_id])
model_path = os.path.join(model_dir, model_name)
print('model path: ', model_path)
model = model.to(device)
model.load_state_dict(torch.load(model_path, map_location=device))
print('loaded model', model_path)

# test model
test_pred = model_trainer.test_epoch(model, img_test)
test_fev, test_feve = metrics.feve(spks_rep_all, test_pred)
print('FEVE (test, all): ', np.mean(test_feve))

threshold = 0.15  # FEV cutoff used to select reliable neurons
print(f'filtering neurons with FEV > {threshold}')
valid_idxes = np.where(test_fev > threshold)[0]
print(f'valid neurons: {len(valid_idxes)} / {len(test_fev)}')
print(f'FEVE (test, FEV>0.15): {np.mean(test_feve[valid_idxes])}')

data_dict['LNmodel_feve_all'] = test_feve


# save

In [None]:
# Write all collected results to one .npz file. np.savez does not create missing
# directories, so make sure outputs/ exists first.
os.makedirs('outputs', exist_ok=True)
np.savez(f'outputs/fullmodel_{data.mouse_names[mouse_id]}_results.npz', **data_dict)