In [None]:
# Re-import edited project modules (minimodel, figure1) automatically before each cell runs.
%load_ext autoreload
%autoreload 2

In [None]:
import os
import torch
import numpy as np
import matplotlib.pyplot as plt
from minimodel import data, model_builder, model_trainer, metrics

plt.rcParams.update({'font.size': 12})
# Use the GPU when present, otherwise fall back to the CPU instead of failing at the
# first `.to(device)`; checkpoint loads should pass map_location=device to torch.load.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Collects every array written to figure1_results.npz at the end of the notebook.
data_dict = {}

data_path = '../data'
# Root checkpoint directory. Treat it as a constant: derive per-model directories
# from it into new variables instead of reassigning it.
weight_path = './checkpoints/fullmodel'
result_path = './save_results/outputs'

# save results

In [None]:
# Pool per-neuron FEV over all six mice and count neurons above the 0.15 cut.
fev_per_mouse = []
for mouse_id in range(6):
    results_file = os.path.join(result_path, f'fullmodel_{data.mouse_names[mouse_id]}_results.npz')
    dat = np.load(results_file, allow_pickle=True)
    fev_per_mouse.append(dat['fev'])

fev_all = np.hstack(fev_per_mouse)
n_above = np.sum(fev_all > 0.15)
print(f'{n_above}/{len(fev_all)} neurons have FEV > 0.15')

# figure 1c

In [None]:
# Example mouse used for panels 1c-1d.
mouse_id = 1

# stimulus images for this mouse
img = data.load_images(data_path, mouse_id, file=os.path.join(data_path, data.img_file_name[mouse_id]))

# neural responses (training set plus repeated test presentations)
db_entry = data.db[mouse_id]
fname = '%s_nat60k_%s.npz' % (db_entry['mname'], db_entry['datexp'])
spks, istim_train, istim_test, xpos, ypos, spks_rep_all = data.load_neurons(
    file_path=os.path.join(data_path, fname), mouse_id=mouse_id)
n_stim, n_max_neurons = spks.shape

# hold out 10% of the training stimuli for validation
itrain, ival = data.split_train_val(istim_train, train_frac=0.9)

# normalize responses; itrain is passed in, presumably so statistics come from the train split
spks, spks_rep_all = data.normalize_spks(spks, spks_rep_all, itrain)

# keep every neuron
ineur = np.arange(n_max_neurons)

spks_val = torch.from_numpy(spks[ival][:, ineur])
spks_rep_all = [rep[:, ineur] for rep in spks_rep_all]

img_val = torch.from_numpy(img[istim_train][ival]).to(device).unsqueeze(1)
img_test = torch.from_numpy(img[istim_test]).to(device).unsqueeze(1)

input_Ly, input_Lx = img_test.shape[-2:]

In [None]:
# Stack the repeat-response list into one array and average along axis 1.
# NOTE(review): axis 1 is assumed to be the repeat axis (giving a per-stimulus mean) — confirm
# against the layout returned by data.load_neurons.
spks_rep_all = np.stack(spks_rep_all)
spks_rep_mean = spks_rep_all.mean(axis=1)
print(spks_rep_mean.shape, spks_rep_mean.min(), spks_rep_mean.max())
data_dict['spks_test_mean_example'] = spks_rep_mean

In [None]:
# Load the trained 4-layer full model for the example mouse and score it on the test set.
nlayers = 4
nconv1 = 192
nconv2 = 192
suffix = ''
model, in_channels = model_builder.build_model(NN=len(ineur), n_layers=nlayers, n_conv=nconv1, n_conv_mid=nconv2)
model_name = model_builder.create_model_name(data.mouse_names[mouse_id], data.exp_date[mouse_id], n_layers=nlayers, in_channels=in_channels)

# Keep the per-mouse directory in its own variable: reassigning `weight_path` here would
# make every later join (and every re-run of this cell) nest one directory deeper.
model_dir = os.path.join(weight_path, 'fullmodel', data.mouse_names[mouse_id])
os.makedirs(model_dir, exist_ok=True)
model_path = os.path.join(model_dir, model_name)
print('model path: ', model_path)
model = model.to(device)
# map_location lets a GPU-saved checkpoint load when `device` is the CPU
model.load_state_dict(torch.load(model_path, map_location=device))
print('loaded model', model_path)

# test model
test_pred = model_trainer.test_epoch(model, img_test)
print('test_pred: ', test_pred.shape, test_pred.min(), test_pred.max())

test_fev, test_feve = metrics.feve(spks_rep_all, test_pred)
print('FEVE (test, all): ', np.mean(test_feve))

# only neurons with reliable test responses (FEV above threshold) enter the summary FEVE
threshold = 0.15
print(f'filtering neurons with FEV > {threshold}')
valid_idxes = np.where(test_fev > threshold)[0]
print(f'valid neurons: {len(valid_idxes)} / {len(test_fev)}')
print(f'FEVE (test, FEV>{threshold}): {np.mean(test_feve[valid_idxes])}')

data_dict['fev_example'] = test_fev
data_dict['feve_example'] = test_feve
data_dict['spks_pred_test_example'] = test_pred

# figure 1d FEV versus FEVE

In [None]:
# Pool FEV and FEVE (feve_depth[3], presumably the deepest model) over all six mice.
fev_per_mouse, feve_per_mouse = [], []
for mouse_id in range(6):
    results_file = os.path.join(result_path, f'fullmodel_{data.mouse_names[mouse_id]}_results.npz')
    dat = np.load(results_file, allow_pickle=True)
    fev_per_mouse.append(dat['fev'])
    feve_per_mouse.append(dat['feve_depth'][3])

fev_all = np.hstack(fev_per_mouse)
feve_all = np.hstack(feve_per_mouse)

data_dict.update(feve_all=feve_all, fev_all=fev_all)

# figure 1e (5k vs 30k train images)

In [None]:
# Test FEVE of 4-layer models trained with 5k vs 30k training images, for every mouse.
seed = 1
nlayers = 4
nconv1 = 192
nconv2 = 192
nmouse = 6
nstim_list = [5000, 30000]
threshold = 0.15
feve_nstims = np.zeros((nmouse, len(nstim_list)))
for mouse_id in range(nmouse):
    # load images
    img = data.load_images(data_path, mouse_id, file=os.path.join(data_path, data.img_file_name[mouse_id]))
    nimg, Ly, Lx = img.shape
    print('img: ', img.shape, img.min(), img.max(), img.dtype)

    # load neurons
    fname = '%s_nat60k_%s.npz'%(data.db[mouse_id]['mname'], data.db[mouse_id]['datexp'])
    spks, istim_train, istim_test, xpos, ypos, spks_rep_all = data.load_neurons(file_path=os.path.join(data_path, fname), mouse_id=mouse_id)
    n_stim, n_max_neurons = spks.shape
    itrain, ival = data.split_train_val(istim_train, train_frac=0.9)
    # normalize spks
    spks, spks_rep_all = data.normalize_spks(spks, spks_rep_all, itrain)
    ineur = np.arange(0, n_max_neurons)
    spks_rep_all = [spks_rep_all[i][:, ineur] for i in range(len(spks_rep_all))]
    img_test = torch.from_numpy(img[istim_test]).to(device).unsqueeze(1)

    input_Ly, input_Lx = img_test.shape[-2:]

    # Per-mouse checkpoint directory in its own variable; re-joining onto `weight_path`
    # inside the loop would nest the path deeper on every iteration.
    model_dir = os.path.join(weight_path, 'fullmodel', data.mouse_names[mouse_id])
    os.makedirs(model_dir, exist_ok=True)

    for i, n_stim in enumerate(nstim_list):
        # cap at the available training-set size; the checkpoint name uses the capped value
        n_stim = min(n_stim, len(itrain))

        suffix = f'nstims_{n_stim}'
        model, in_channels = model_builder.build_model(NN=len(ineur), n_layers=nlayers, n_conv=nconv1, n_conv_mid=nconv2)
        model_name = model_builder.create_model_name(data.mouse_names[mouse_id], data.exp_date[mouse_id], n_layers=nlayers, in_channels=in_channels, seed=seed, suffix=suffix)

        model_path = os.path.join(model_dir, model_name)
        print('model path: ', model_path)
        model = model.to(device)
        # map_location lets a GPU-saved checkpoint load when `device` is the CPU
        model.load_state_dict(torch.load(model_path, map_location=device))
        print('loaded model', model_path)

        # test model, averaging FEVE over neurons with FEV above threshold
        test_pred = model_trainer.test_epoch(model, img_test)
        test_fev, test_feve = metrics.feve(spks_rep_all, test_pred)
        valid_idxes = np.where(test_fev > threshold)[0]
        feve_nstims[mouse_id, i] = np.mean(test_feve[valid_idxes])
        print('FEVE (test): ', feve_nstims[mouse_id, i])

In [None]:
# (mouse, n_train_images) table of mean test FEVE from the cell above
data_dict.update({'5k_30k_feve': feve_nstims})

# figure 1e (FEVE distribution)

In [None]:
# FEVE of reliable neurons (FEV > 0.15), pooled over all six mice.
fev_per_mouse, feve_per_mouse = [], []
for mouse_id in range(6):
    dat = np.load(os.path.join(result_path, f'fullmodel_{data.mouse_names[mouse_id]}_results.npz'), allow_pickle=True)
    fev_per_mouse.append(dat['fev'])
    feve_per_mouse.append(dat['feve_depth'][3])
fev_all = np.hstack(fev_per_mouse)
feve_all = np.hstack(feve_per_mouse)
reliable = fev_all > 0.15
data_dict['feve_all_mice'] = feve_all[reliable]

# figure 1f (performance change with model depth)

In [None]:
# Per mouse: mean FEVE of reliable neurons (FEV > 0.15) at each stored depth, and for the LN model.
nmouse = 6
feve_depth_all = np.zeros((nmouse, 4))
feve_LN_all = np.zeros(nmouse)
for mouse_id in range(6):
    dat = np.load(os.path.join(result_path, f'fullmodel_{data.mouse_names[mouse_id]}_results.npz'), allow_pickle=True)
    fev, feve_depth = dat['fev'], dat['feve_depth']
    reliable = fev > 0.15
    feve_depth_all[mouse_id] = feve_depth[:, reliable].mean(axis=1)
    feve_LN_all[mouse_id] = dat['LNmodel_feve_all'][reliable].mean()
data_dict.update(feve_our_model=feve_depth_all, feve_LN_model=feve_LN_all)

In [None]:
# FEVE of the Lurz model; the .npy file is produced by Lurz_model_train_test.ipynb (run it first).
lurz_feve_all = np.load(os.path.join(result_path, 'lurz_feve_all.npy'))
data_dict['feve_lurz_model'] = lurz_feve_all
print(lurz_feve_all.mean(axis=0))

# figure 1g (visualize readout)

In [None]:
# Load the 2-layer full model of the example mouse so its spatial readout can be inspected.
mouse_id = 1

# all recorded neurons of this mouse
ineur = np.arange(0, data.NNs[mouse_id])
# input image size (not passed to build_model in this cell)
input_Ly, input_Lx = 66, 130

nlayers = 2
nconv1 = 192
nconv2 = 192
model, in_channels = model_builder.build_model(NN=len(ineur), n_layers=nlayers, n_conv=nconv1, n_conv_mid=nconv2)
model_name = model_builder.create_model_name(data.mouse_names[mouse_id], data.exp_date[mouse_id], n_layers=nlayers, in_channels=in_channels)

# Separate variable so the global `weight_path` is never re-joined (and nested) by this cell.
model_dir = os.path.join(weight_path, 'fullmodel', data.mouse_names[mouse_id])
os.makedirs(model_dir, exist_ok=True)
model_path = os.path.join(model_dir, model_name)
print('model path: ', model_path)
model = model.to(device)
# map_location lets a GPU-saved checkpoint load when `device` is the CPU
model.load_state_dict(torch.load(model_path, map_location=device))
print('loaded model', model_path)

In [None]:
# Spatial readout weights of the example model, pulled out as numpy arrays.
Wc = model.readout.Wc.detach().cpu().numpy().squeeze()
Wx = model.readout.Wx.detach().cpu().numpy()
Wy = model.readout.Wy.detach().cpu().numpy()
# per-neuron outer product of Wy and Wx, summed over axis 1 -> (neuron, y, x)
Wxy = np.einsum('icj,ick->ijk', Wy, Wx)

# pooling-area estimate: x-bandwidth times y-bandwidth of each neuron's channel-0 weights
from minimodel.utils import weight_bandwidth
NN = Wxy.shape[0]
bandwidth_Wx = np.zeros(NN)
bandwidth_Wy = np.zeros(NN)
for n in range(NN):
    bandwidth_Wx[n] = weight_bandwidth(Wx[n, 0, :])
    bandwidth_Wy[n] = weight_bandwidth(Wy[n, 0, :])
rf_size = bandwidth_Wx * bandwidth_Wy
print(f'average rf size: {np.mean(rf_size):.2f}')

data_dict.update(Wxy_example=Wxy, Wx_example=Wx, Wy_example=Wy)

# figure 1h (distribution of pooling area)

In [None]:
# Pooling-area (rf size) estimates for every neuron of all six mice, from the 2-layer models.
from minimodel.utils import weight_bandwidth

mouse_list = [0, 1, 2, 3, 4, 5]
nlayers = 2
nconv1 = 192
nconv2 = 192
# input image size (not passed to build_model in this cell)
input_Ly, input_Lx = 66, 130

rfsize_all = []
for mouse_id in mouse_list:
    # load neurons
    ineur = np.arange(0, data.NNs[mouse_id])

    suffix = ''
    if mouse_id == 5:
        # NOTE(review): `xrange_max` is not defined anywhere in this notebook, so this raises
        # NameError unless it was set in the kernel beforehand. Define it explicitly; its value
        # must match the suffix used when this mouse's checkpoint was saved.
        suffix += f'xrange_{xrange_max}'
    model, in_channels = model_builder.build_model(NN=len(ineur), n_layers=nlayers, n_conv=nconv1, n_conv_mid=nconv2)
    model_name = model_builder.create_model_name(data.mouse_names[mouse_id], data.exp_date[mouse_id], n_layers=nlayers, in_channels=in_channels, suffix=suffix)

    # Per-mouse directory in its own variable; re-joining onto `weight_path` inside the
    # loop would nest the path one level deeper for every mouse.
    model_dir = os.path.join(weight_path, 'fullmodel', data.mouse_names[mouse_id])
    os.makedirs(model_dir, exist_ok=True)
    model_path = os.path.join(model_dir, model_name)
    print('model path: ', model_path)
    model = model.to(device)
    # map_location lets a GPU-saved checkpoint load when `device` is the CPU
    model.load_state_dict(torch.load(model_path, map_location=device))
    print('loaded model', model_path)

    Wc = model.readout.Wc.detach().cpu().numpy().squeeze()
    Wx = model.readout.Wx.detach().cpu().numpy()
    Wy = model.readout.Wy.detach().cpu().numpy()
    # per-neuron outer product of Wy and Wx, summed over axis 1
    Wxy = np.einsum('icj,ick->ijk', Wy, Wx)
    print(Wxy.shape, Wc.shape)
    print(Wc.shape, Wc.min(), Wc.max())

    # pooling area = x-bandwidth * y-bandwidth of each neuron's channel-0 weights
    NN = Wxy.shape[0]
    bandwidth_Wx = np.zeros(NN)
    bandwidth_Wy = np.zeros(NN)
    for i in range(NN):
        bandwidth_Wx[i] = weight_bandwidth(Wx[i, 0, :])
        bandwidth_Wy[i] = weight_bandwidth(Wy[i, 0, :])
    rf_size = bandwidth_Wx * bandwidth_Wy
    print(f'average rf size: {np.mean(rf_size):.2f}')
    rfsize_all.append(rf_size)
rfsize_all = np.hstack(rfsize_all)
print(rfsize_all.shape)
data_dict['rfsize_all'] = rfsize_all

# monkey

## figure 1i (images)

In [None]:
# load data
dat = np.load(os.path.join(data_path, 'monkeyv1_cadena_2019.npz'))
images = dat['images']
responses = dat['responses']
real_responses = dat['real_responses']
test_images = dat['test_images']
test_responses = dat['test_responses']
test_real_responses = dat['test_real_responses']
train_idx = dat['train_idx']
val_idx = dat['val_idx']
repetitions = dat['repetitions']
monkey_ids = dat['subject_id']
image_ids = dat['image_ids']

print(images.shape)

In [None]:
# Ten randomly chosen monkey stimuli, averaged over axis 0 of each image, in a 2x5 grid saved as PDF.
n_img, _, Ly, Lx = images.shape
np.random.seed(42)
iselect = np.random.choice(n_img, 10, replace=False)
fig, ax = plt.subplots(2, 5, figsize=(10, 4))
# ax.flat walks the grid row by row, matching (i // 5, i % 5) indexing
for panel, img_idx in zip(ax.flat, iselect):
    panel.imshow(images[img_idx].mean(0), cmap='gray')
    panel.axis('off')
plt.savefig('monkeyv1_images.pdf', dpi=300)

## figure 1j-k

In [None]:
# load data
dat = np.load(os.path.join(data_path, 'monkeyv1_cadena_2019.npz'))
images = dat['images']
responses = dat['responses']
real_responses = dat['real_responses']
test_images = dat['test_images']
test_responses = dat['test_responses']
test_real_responses = dat['test_real_responses']
train_idx = dat['train_idx']
val_idx = dat['val_idx']
repetitions = dat['repetitions']
monkey_ids = dat['subject_id']
image_ids = dat['image_ids']

# normalize responses
responses_nan = np.where(real_responses, responses, np.nan)
resp_std = np.nanstd(responses_nan, axis=0)
responses = responses / resp_std
test_responses = test_responses / resp_std

test_responses_nan = np.where(test_real_responses, test_responses, np.nan)
print('test_responses_nan: ', test_responses_nan.shape)

In [None]:
# Load the 2-layer monkey V1 model (NN=166, 80x80 input) and predict responses to the test images.
nlayers = 2
nconv1 = 192
nconv2 = 192
Lx, Ly = 80, 80
model, in_channels = model_builder.build_model(NN=166, n_layers=nlayers, n_conv=nconv1, n_conv_mid=nconv2, input_Lx=Lx, input_Ly=Ly)
suffix = ''
model_name = model_builder.create_model_name('monkeyV1', '2019', n_layers=nlayers, in_channels=in_channels, suffix=suffix)
# Separate variable so the global `weight_path` is not re-joined (and nested) by this cell.
model_dir = os.path.join(weight_path, 'fullmodel', 'monkeyV1')
model_path = os.path.join(model_dir, model_name)
print('model path: ', model_path)
model = model.to(device)
# map_location lets a GPU-saved checkpoint load when `device` is the CPU
model.load_state_dict(torch.load(model_path, map_location=device))
print('loaded model', model_path)
# torch.as_tensor accepts the numpy array or an already-converted tensor, so re-running
# this cell does not fail the way torch.from_numpy(<tensor>) would.
test_images = torch.as_tensor(test_images).to(device)
spks_pred_test = model_trainer.test_epoch(model, test_images)
print('spks_pred_test: ', spks_pred_test.shape)

In [None]:
# Mean test response along axis 0 (NaN = missing, skipped) plus saved FEV/FEVE results for the monkey panels.
spks_test_mean = np.nanmean(test_responses_nan, axis=0)
dat = np.load(os.path.join(result_path, 'fullmodel_monkey_results.npz'), allow_pickle=True)

data_dict.update({
    'monkey_fev': dat['fev_all'],
    'monkey_feve': dat['feve_depth'][1],
    'monkey_spks_test_mean': spks_test_mean,
    'monkey_spks_pred_test': spks_pred_test,
    'monkey_LNmodel_feve': dat['LNmodel_feve_all'].mean(),
})

## figure 1l (FEVE change with depth)

In [None]:
# FEVE at each model depth, overall and split by subject id (4 and 34).
# Read the same results file as the previous monkey cell (under result_path) rather than a
# cwd-relative 'outputs/' path, which can point at a different or missing file.
dat = np.load(os.path.join(result_path, 'fullmodel_monkey_results.npz'), allow_pickle=True)
data_dict['monkey_id'] = dat['monkey_ids']
monkey_eve = dat['feve_depth']
imonkey1 = np.where(data_dict['monkey_id'] == 4)[0]
imonkey2 = np.where(data_dict['monkey_id'] == 34)[0]
# per-subject mean FEVE at each depth (kept for inspection; not written to data_dict)
separate_eve = np.zeros((2, 4))
separate_eve[0] = np.mean(monkey_eve[:, imonkey1], axis=1)
separate_eve[1] = np.mean(monkey_eve[:, imonkey2], axis=1)
data_dict['monkey_depth_eve'] = monkey_eve.T
data_dict['monkey_feve'] = dat['feve_depth'][1]
data_dict['monkey_fev'] = dat['fev_all']
print(monkey_eve.mean(1))

In [None]:
# load VGG feve
vgg_eve = np.load(os.path.join(result_path, 'Cadena_vgg_feve.npy'))
print(vgg_eve.mean())

data_dict['vgg_eve'] = vgg_eve
# data_dict['fullmodel_eve'] = fullmodel_eve
vgg_eve_monkey1 = vgg_eve[imonkey1].mean()
vgg_eve_monkey2 = vgg_eve[imonkey2].mean()
print(vgg_eve_monkey1, vgg_eve_monkey2, vgg_eve.mean())

## figure 1m-n

In [None]:
# Monkey readout: spatial weights of the 2-layer model and per-neuron pooling-area estimate.
nlayers = 2
nconv1 = 192
nconv2 = 192
Lx, Ly = 80, 80
model, in_channels = model_builder.build_model(NN=166, n_layers=nlayers, n_conv=nconv1, n_conv_mid=nconv2, input_Lx=Lx, input_Ly=Ly)
suffix = ''
model_name = model_builder.create_model_name('monkeyV1', '2019', n_layers=nlayers, in_channels=in_channels, suffix=suffix)
# Separate variable so the global `weight_path` is not re-joined (and nested) by this cell.
model_dir = os.path.join(weight_path, 'fullmodel', 'monkeyV1')
model_path = os.path.join(model_dir, model_name)
print('model path: ', model_path)
# the model is not moved to a GPU here (only its weights are read), so load onto the CPU
model.load_state_dict(torch.load(model_path, map_location='cpu'))
print('loaded model', model_path)

Wc = model.readout.Wc.detach().cpu().numpy().squeeze()
Wx = model.readout.Wx.detach().cpu().numpy()
Wy = model.readout.Wy.detach().cpu().numpy()
# per-neuron outer product of Wy and Wx, summed over axis 1
Wxy = np.einsum('icj,ick->ijk', Wy, Wx)

# pooling area = x-bandwidth * y-bandwidth of each neuron's channel-0 weights
from minimodel.utils import weight_bandwidth
NN = Wxy.shape[0]
bandwidth_Wx = np.zeros(NN)
bandwidth_Wy = np.zeros(NN)
for i in range(NN):
    bandwidth_Wx[i] = weight_bandwidth(Wx[i, 0, :])
    bandwidth_Wy[i] = weight_bandwidth(Wy[i, 0, :])
rf_size = bandwidth_Wx * bandwidth_Wy
print(f'average rf size: {np.mean(rf_size):.2f}')
data_dict['monkey_rfsize'] = rf_size

In [None]:
# keep the monkey readout weights for the figure
data_dict.update(monkey_Wxy=Wxy, monkey_Wx=Wx, monkey_Wy=Wy)

# save

In [None]:
# save data_dict
# every collected array becomes one named entry of the .npz archive
np.savez('figure1_results.npz', **data_dict)

# plot

In [None]:
# Render the figure-1 panels from the saved results archive into ./outputs.
import figure1
save_path = './outputs'
dat = np.load('figure1_results.npz', allow_pickle=True)
figure1.figure1(dat, save_path)