In [None]:
import os
import torch
import numpy as np
from minimodel import data, metrics, model_builder, model_trainer

# All models are evaluated on the GPU.
device = torch.device('cuda')

# Collects every result array; written to disk by the final save cell.
data_dict = {}

data_path = '../../data'
# Checkpoint root. Model sub-folders ('fullmodel/monkeyV1', 'LNmodel/monkeyV1')
# are joined onto it by the cells below, so it must not already end in 'fullmodel'.
weight_path = '../checkpoints'

# Effect of core depth (number of conv layers) on test FEVE

In [None]:
# Load the monkey V1 dataset. Using the NpzFile as a context manager closes the
# underlying file handle once every array below has been read into memory.
with np.load(os.path.join(data_path, 'monkeyv1_cadena_2019.npz')) as dat:
    images = dat['images']
    responses = dat['responses']
    real_responses = dat['real_responses']
    test_images = dat['test_images']
    test_responses = dat['test_responses']
    test_real_responses = dat['test_real_responses']
    train_idx = dat['train_idx']
    val_idx = dat['val_idx']
    repetitions = dat['repetitions']
    monkey_ids = dat['subject_id']
    image_ids = dat['image_ids']

# Normalise every neuron by its std over the non-test responses, ignoring entries
# not flagged in real_responses (they become NaN and are skipped by nanstd).
# The same per-neuron std is applied to the test set so both share one scale.
responses_nan = np.where(real_responses, responses, np.nan)
resp_std = np.nanstd(responses_nan, axis=0)
responses = responses / resp_std
test_responses = test_responses / resp_std

train_images = images[train_idx]
val_images = images[val_idx]
train_responses = responses[train_idx]
val_responses = responses[val_idx]
train_real_responses = real_responses[train_idx]
val_real_responses = real_responses[val_idx]

print('train:', train_images.shape, train_responses.shape, train_real_responses.shape)
print('val:', val_images.shape, val_responses.shape, val_real_responses.shape)
print('test:', test_images.shape, test_responses.shape, test_real_responses.shape)

print('resp:', responses.min(), responses.max())
print('test resp:', test_responses.min(), test_responses.max())

# Mask test entries without a real recording so the FEVE metric can ignore them.
test_responses = np.where(test_real_responses, test_responses, np.nan)

NN = train_responses.shape[1]
Lx, Ly = train_images.shape[2], train_images.shape[3]

train_images = torch.from_numpy(train_images)
val_images = torch.from_numpy(val_images)
train_responses = torch.from_numpy(train_responses)
val_responses = torch.from_numpy(val_responses)
train_real_responses = torch.from_numpy(train_real_responses)
val_real_responses = torch.from_numpy(val_real_responses)
# Only the test images are used for evaluation below, so only they go to the GPU.
test_images = torch.from_numpy(test_images).to(device)

In [None]:
# Test-set FEVE of the full model as a function of core depth (1-4 layers), seed 1.
nconv1 = 192
nconv2 = 192
n_max_neurons = 166
seed = 1
feve_nlayers = np.zeros((4, n_max_neurons))
# Checkpoints for depth 3 and 4 were saved with a non-default core weight decay;
# only non-default values (!= 0.1) are encoded in the checkpoint name.
wd_core_by_depth = {3: 0.2, 4: 0.3}

# Neuron indices of each animal (subject ids 4 and 34); identical for every model.
imonkey1 = np.where(monkey_ids == 4)[0]
imonkey2 = np.where(monkey_ids == 34)[0]
# Separate variable: re-assigning weight_path here would compound the path on
# every iteration and break every later cell.
fullmodel_dir = os.path.join(weight_path, 'fullmodel', 'monkeyV1')

for nlayers in range(1, 5):
    weight_decay_core = wd_core_by_depth.get(nlayers, 0.1)

    model, in_channels = model_builder.build_model(NN=n_max_neurons, n_layers=nlayers, n_conv=nconv1, n_conv_mid=nconv2, input_Lx=Lx, input_Ly=Ly)
    suffix = f'wdcore_{weight_decay_core}' if weight_decay_core != 0.1 else ''
    model_name = model_builder.create_model_name('monkeyV1', '2019', n_layers=nlayers, in_channels=in_channels, seed=seed, suffix=suffix)
    model_path = os.path.join(fullmodel_dir, model_name)
    print('model path: ', model_path)
    # map_location lets checkpoints saved on another device load as well.
    model.load_state_dict(torch.load(model_path, map_location=device))
    print('loaded model', model_path)
    model = model.to(device)
    # Inference mode (harmless if test_epoch already switches to eval itself).
    model.eval()

    spks_pred_test = model_trainer.test_epoch(model, test_images)
    test_fev, test_feve = metrics.monkey_feve(test_responses, spks_pred_test, repetitions)
    print('FEVE (test): ', np.mean(test_feve))
    feve_nlayers[nlayers-1] = test_feve

    print('FEVE (test) monkey 1: ', np.mean(test_feve[imonkey1]))
    print('FEVE (test) monkey 2: ', np.mean(test_feve[imonkey2]))
    print('FEVE (test) mean: ', np.mean([np.mean(test_feve[imonkey1]), np.mean(test_feve[imonkey2])]))

In [None]:
# Per-neuron test FEVE for depths 1-4 (row k = k+1 layers).
data_dict.update(feve_depth=feve_nlayers)

# LN baseline model (no activation, average pooling): test FEVE

In [None]:
# 'LN' baseline (presumably linear-nonlinear): same 2-layer 16/320 architecture,
# built with activation=None and avgpool=True, checkpoint suffix 'LN'.
seed = 1
nlayers = 2
nconv1 = 16
nconv2 = 320
n_max_neurons = 166

model, in_channels = model_builder.build_model(NN=n_max_neurons, n_layers=nlayers, n_conv=nconv1, n_conv_mid=nconv2, input_Lx=Lx, input_Ly=Ly, activation=None, avgpool=True)
model_name = model_builder.create_model_name('monkeyV1', '2019', n_layers=nlayers, in_channels=in_channels, seed=seed, suffix='LN')
# Separate variable so the shared weight_path root is not mutated for later cells.
lnmodel_dir = os.path.join(weight_path, 'LNmodel', 'monkeyV1')
model_path = os.path.join(lnmodel_dir, model_name)
print('model path: ', model_path)
model.load_state_dict(torch.load(model_path, map_location=device))
print('loaded model', model_path)
model = model.to(device)

model.eval()
spks_pred_test = model_trainer.test_epoch(model, test_images)
test_fev, test_feve = metrics.monkey_feve(test_responses, spks_pred_test, repetitions)
print('FEVE (test): ', np.mean(test_feve))

# Neuron indices of each animal (subject ids 4 and 34).
imonkey1 = np.where(monkey_ids == 4)[0]
imonkey2 = np.where(monkey_ids == 34)[0]

print('FEVE (test) monkey 1: ', np.mean(test_feve[imonkey1]))
print('FEVE (test) monkey 2: ', np.mean(test_feve[imonkey2]))
print('FEVE (test) mean: ', np.mean([np.mean(test_feve[imonkey1]), np.mean(test_feve[imonkey2])]))
data_dict['LNmodel_feve_all'] = test_feve

# Effect of conv1 / conv2 channel counts (network width) on test FEVE

In [None]:
# Test-set FEVE of 2-layer full models over a grid of conv1 x conv2 channel counts (seed 2).
nlayers = 2
n_max_neurons = 166
nconv1_list = [8, 16, 32, 64, 128, 192, 256, 320, 384, 448]
nconv2_list = [8, 16, 32, 64, 128, 192, 256, 320, 384, 448]
seed = 2
feve_width = np.zeros((len(nconv1_list), len(nconv2_list), n_max_neurons))

# Neuron indices of each animal (subject ids 4 and 34); identical for every model.
imonkey1 = np.where(monkey_ids == 4)[0]
imonkey2 = np.where(monkey_ids == 34)[0]
# Separate variable: re-assigning weight_path inside the loop would compound the path.
fullmodel_dir = os.path.join(weight_path, 'fullmodel', 'monkeyV1')

for i, nconv1 in enumerate(nconv1_list):
    for j, nconv2 in enumerate(nconv2_list):
        # Skips the (16, 320) model only when seed == 1; with seed = 2 as set above,
        # every combination is evaluated.
        if nconv1 == 16 and nconv2 == 320 and seed == 1:
            continue
        model, in_channels = model_builder.build_model(NN=n_max_neurons, n_layers=nlayers, n_conv=nconv1, n_conv_mid=nconv2, input_Lx=Lx, input_Ly=Ly)
        model_name = model_builder.create_model_name('monkeyV1', '2019', n_layers=nlayers, in_channels=in_channels, seed=seed)
        model_path = os.path.join(fullmodel_dir, model_name)
        print('model path: ', model_path)
        model.load_state_dict(torch.load(model_path, map_location=device))
        print('loaded model', model_path)
        model = model.to(device)
        # Inference mode (harmless if test_epoch already switches to eval itself).
        model.eval()

        spks_pred_test = model_trainer.test_epoch(model, test_images)
        test_fev, test_feve = metrics.monkey_feve(test_responses, spks_pred_test, repetitions)
        print('FEVE (test): ', np.mean(test_feve))
        feve_width[i, j] = test_feve

        print('FEVE (test) monkey 1: ', np.mean(test_feve[imonkey1]))
        print('FEVE (test) monkey 2: ', np.mean(test_feve[imonkey2]))
        print('FEVE (test) mean: ', np.mean([np.mean(test_feve[imonkey1]), np.mean(test_feve[imonkey2])]))

In [None]:
# Per-neuron test FEVE indexed by (conv1 index, conv2 index, neuron).
data_dict.update(feve_width=feve_width)

# Effect of the number of training stimuli on test FEVE

In [None]:
# Test-set FEVE as a function of the number of training stimuli: 10 log-spaced
# counts from 500 to n_max_stims, duplicates from integer rounding removed.
n_max_stims = 4640
stim_numbers = np.geomspace(500, n_max_stims, num=10, dtype=int)
stim_numbers = np.unique(stim_numbers)
print(stim_numbers)

# build model
seed = 1
nlayers = 2
nconv1 = 16
nconv2 = 320
n_max_neurons = 166
feve_nstims = np.zeros((len(stim_numbers), n_max_neurons))
# Separate variable: re-assigning weight_path inside the loop would compound the path.
fullmodel_dir = os.path.join(weight_path, 'fullmodel', 'monkeyV1')

for i, n_stim in enumerate(stim_numbers):
    model, in_channels = model_builder.build_model(NN=n_max_neurons, n_layers=nlayers, n_conv=nconv1, n_conv_mid=nconv2, input_Lx=Lx, input_Ly=Ly)
    suffix = f'nstims_{n_stim}'
    model_name = model_builder.create_model_name('monkeyV1', '2019', n_layers=nlayers, in_channels=in_channels, seed=seed, suffix=suffix)
    model_path = os.path.join(fullmodel_dir, model_name)
    print('model path: ', model_path)
    model.load_state_dict(torch.load(model_path, map_location=device))
    print('loaded model', model_path)
    model = model.to(device)
    # Inference mode (harmless if test_epoch already switches to eval itself).
    model.eval()

    spks_pred_test = model_trainer.test_epoch(model, test_images)
    test_fev, test_feve = metrics.monkey_feve(test_responses, spks_pred_test, repetitions)
    print('FEVE (test): ', np.mean(test_feve))
    feve_nstims[i] = test_feve


In [None]:
# Stimulus counts and the matching per-neuron test FEVE (row i <-> stim_numbers[i]).
data_dict.update(nstims=stim_numbers, feve_nstims=feve_nstims)

# Effect of the number of neurons a model is trained on

In [None]:
# Test-set FEVE as a function of how many neurons a model was trained on.
# For each neuron count, several random neuron subsets (one per seed) are evaluated.
# A subset is regenerated from np.random.seed(n_neuron * seed); this presumably
# mirrors the training script, so do not change the seeding or the draw.
nlayers = 2
nconv1 = 16
nconv2 = 320
n_max_neurons = 166
feve_nneurons = []       # feve_nneurons[i][s]: per-neuron test FEVE for neuron count i, seed s+1
nneuron_monkey_ids = []  # matching subject id of every neuron in that subset

# Log-spaced neuron counts (1 forced in, duplicates from int rounding removed);
# smaller counts get more seeds, spaced linearly from 20 down to 1.
neuron_numbers = np.geomspace(1, n_max_neurons, num=10, dtype=int)
neuron_numbers = np.unique(np.concatenate(([1], neuron_numbers)))
seed_numbers = np.linspace(20, 1, num=len(neuron_numbers), dtype=int)
# Separate variable: re-assigning weight_path inside the loop would compound the path.
fullmodel_dir = os.path.join(weight_path, 'fullmodel', 'monkeyV1')

for i, n_neuron in enumerate(neuron_numbers):
    feve_nneurons.append([])
    nneuron_monkey_ids.append([])
    for seed in range(1, seed_numbers[i] + 1):
        ineurons = np.arange(n_max_neurons)
        if n_neuron != n_max_neurons:
            np.random.seed(n_neuron * seed)
            ineurons = np.random.choice(np.arange(n_max_neurons), size=n_neuron, replace=False)

        # Slice the arrays already normalised in the data cell instead of re-reading
        # the .npz on every iteration. Normalisation is per neuron (std over axis 0),
        # so slicing afterwards gives the same values as normalising the subset.
        # New names keep the full-population globals used by later cells intact.
        sub_test_responses = test_responses[:, :, ineurons]
        sub_repetitions = repetitions[ineurons]
        sub_monkey_ids = monkey_ids[ineurons]

        model, in_channels = model_builder.build_model(NN=n_neuron, n_layers=nlayers, n_conv=nconv1, n_conv_mid=nconv2, input_Lx=Lx, input_Ly=Ly)
        suffix = f'nneurons_{n_neuron}' if n_neuron != n_max_neurons else ''
        model_name = model_builder.create_model_name('monkeyV1', '2019', n_layers=nlayers, in_channels=in_channels, seed=seed, suffix=suffix)
        model_path = os.path.join(fullmodel_dir, model_name)
        print('model path: ', model_path)
        model.load_state_dict(torch.load(model_path, map_location=device))
        print('loaded model', model_path)
        model = model.to(device)
        # Inference mode (harmless if test_epoch already switches to eval itself).
        model.eval()

        # test_images was already converted to a tensor on `device` in the data cell.
        spks_pred_test = model_trainer.test_epoch(model, test_images)
        test_fev, test_feve = metrics.monkey_feve(sub_test_responses, spks_pred_test, sub_repetitions)
        print('FEVE (test): ', np.mean(test_feve))
        feve_nneurons[i].append(test_feve)

        nneuron_monkey_ids[i].append(sub_monkey_ids)

In [None]:
# Raw (ragged) per-seed results of the neuron-count sweep.
data_dict.update(
    nneurons=neuron_numbers,
    feve_nneurons=feve_nneurons,
    nneuron_monkey_ids=nneuron_monkey_ids,
)

# Mean test FEVE per animal for each neuron count: row 0 = subject 4,
# row 1 = subject 34. Each random subset (seed) is averaged first; a subset with
# no neuron from an animal contributes NaN, which np.nanmean then skips.
subject_ids = (4, 34)
feves = np.zeros((len(subject_ids), len(neuron_numbers)))
for col in range(len(neuron_numbers)):
    seed_feves = np.array(feve_nneurons[col])
    seed_monkey_ids = nneuron_monkey_ids[col]
    per_seed = np.full((len(subject_ids), len(seed_feves)), np.nan)
    for k in range(len(seed_feves)):
        ids = np.array(seed_monkey_ids[k])
        for row, subject in enumerate(subject_ids):
            in_subject = ids == subject
            if in_subject.any():
                per_seed[row, k] = seed_feves[k][in_subject].mean()
    for row in range(len(subject_ids)):
        feves[row, col] = np.nanmean(per_seed[row])

# Reference full model (16 conv1 / 320 conv2 channels) for visualization: test FEVE and readout weights

In [None]:
# Reference full model (2 layers, 16 conv1 / 320 conv2 channels, seed 2):
# per-neuron test FEVE. model / test_fev / test_feve are read by the next cell.
nlayers = 2
nconv1 = 16
nconv2 = 320
n_max_neurons = 166
weight_decay_core = 0.1
seed = 2

model, in_channels = model_builder.build_model(NN=n_max_neurons, n_layers=nlayers, n_conv=nconv1, n_conv_mid=nconv2, input_Lx=Lx, input_Ly=Ly)
# Only a non-default core weight decay (!= 0.1) is encoded in the checkpoint name.
suffix = f'wdcore_{weight_decay_core}' if weight_decay_core != 0.1 else ''
model_name = model_builder.create_model_name('monkeyV1', '2019', n_layers=nlayers, in_channels=in_channels, seed=seed, suffix=suffix)
# Separate variable so the shared weight_path root is not mutated.
fullmodel_dir = os.path.join(weight_path, 'fullmodel', 'monkeyV1')
model_path = os.path.join(fullmodel_dir, model_name)
print('model path: ', model_path)
model.load_state_dict(torch.load(model_path, map_location=device))
print('loaded model', model_path)
model = model.to(device)
# Inference mode (harmless if test_epoch already switches to eval itself).
model.eval()

spks_pred_test = model_trainer.test_epoch(model, test_images)
test_fev, test_feve = metrics.monkey_feve(test_responses, spks_pred_test, repetitions)
print('FEVE (test): ', np.mean(test_feve))

# Neuron indices of each animal (subject ids 4 and 34).
imonkey1 = np.where(monkey_ids == 4)[0]
imonkey2 = np.where(monkey_ids == 34)[0]

print('FEVE (test) monkey 1: ', np.mean(test_feve[imonkey1]))
print('FEVE (test) monkey 2: ', np.mean(test_feve[imonkey2]))
print('FEVE (test) mean: ', np.mean([np.mean(test_feve[imonkey1]), np.mean(test_feve[imonkey2])]))

In [None]:
def _readout_weight(param):
    """Readout parameter as a numpy array: copied to CPU, detached, singleton dims squeezed."""
    return param.cpu().detach().numpy().squeeze()


# Results of the model evaluated in the previous cell, plus each neuron's subject id.
readout = model.readout
data_dict['fev_all'] = test_fev
data_dict['fullmodel_Wx'] = _readout_weight(readout.Wx)
data_dict['fullmodel_Wy'] = _readout_weight(readout.Wy)
data_dict['fullmodel_feve_all'] = test_feve
data_dict['monkey_ids'] = monkey_ids

# save

In [None]:
def _as_savez_array(value):
    """Convert one data_dict entry to an ndarray that np.savez can write.

    Ragged nested lists, such as feve_nneurons / nneuron_monkey_ids (whose inner
    lists hold a different number of seeds per neuron count), make
    np.asanyarray raise ValueError on NumPy >= 1.24. Such entries are stored as
    a 1-D object array holding the original items. Reading them back requires
    np.load(..., allow_pickle=True).
    """
    try:
        return np.asanyarray(value)
    except ValueError:
        ragged = np.empty(len(value), dtype=object)
        for k, item in enumerate(value):
            ragged[k] = item
        return ragged


# np.savez does not create missing directories.
os.makedirs('outputs', exist_ok=True)
np.savez('outputs/fullmodel_monkey_results.npz',
         **{key: _as_savez_array(value) for key, value in data_dict.items()})