In [None]:
%load_ext autoreload
# Pick up edits to imported project modules (minimodel, sup_figure, fig_utils) without restarting the kernel.
%autoreload 2

In [None]:
import os
import torch
import numpy as np
from minimodel import data, model_builder, metrics, model_trainer

# Locations of the raw recordings, trained checkpoints and precomputed result files.
data_path = '../data'
weight_path = './checkpoints/fullmodel'
result_path = './save_results/outputs'

# Accumulates everything handed to the sup_figure.* calls; saved in the last cell.
data_dict = dict()

device = torch.device('cuda')

# figure 1: retinotopy

In [None]:
# load model
# For each mouse: neuron positions, full-model readout peak positions converted
# to visual degrees, and per-neuron FEV (figure 1, retinotopy).
from minimodel.utils import weight_bandwidth


def _stack_per_mouse(arrays):
    """Stack a list with one array per mouse into a single ndarray.

    Equal-shaped inputs give a regular stacked array, as before. Inputs whose
    shapes differ (e.g. different neuron counts per mouse) give a 1-D object
    array with one entry per mouse. NumPy >= 1.24 raises ValueError instead of
    inferring that object array, so it is built explicitly here.
    """
    try:
        return np.array(arrays)
    except ValueError:
        stacked = np.empty(len(arrays), dtype=object)
        for k, arr in enumerate(arrays):
            stacked[k] = arr
        return stacked


xpos_all = []
ypos_all = []
xpos_visual_all = []
ypos_visual_all = []
x_pixel_ratio = 0.75
y_pixel_ratio = 0.5
fev_all = []
for mouse_id in range(6):
    fname = '%s_nat60k_%s.npz'%(data.db[mouse_id]['mname'], data.db[mouse_id]['datexp'])
    spks, istim_train, istim_test, xpos, ypos, spks_rep_all = data.load_neurons(file_path = os.path.join(data_path, fname), mouse_id = mouse_id)
    xpos_all.append(xpos)
    ypos_all.append(ypos)
    dat = np.load(os.path.join(result_path, f'fullmodel_{data.mouse_names[mouse_id]}_results.npz'), allow_pickle=True)
    Wx = dat['fullmodel_Wx']
    Wy = dat['fullmodel_Wy']
    # NOTE(review): the figure-3 cells index Wx[i, 0, :]; here rows are taken as
    # Wx[i, :]. Confirm fullmodel_Wx/Wy are stored without the singleton axis.
    NN = Wx.shape[0]
    bandwidth_Wx = np.zeros(NN)
    bandwidth_Wy = np.zeros(NN)
    centerpos_Wx = np.zeros(NN)
    centerpos_Wy = np.zeros(NN)
    for i in range(NN):
        bandwidth_Wx[i], centerpos_Wx[i] = weight_bandwidth(Wx[i, :], return_peak=True)
        bandwidth_Wy[i], centerpos_Wy[i] = weight_bandwidth(Wy[i, :], return_peak=True)
    xpos_visual = centerpos_Wx*2*(270/264) - 135  # 0 is at the center, so it should be 135 pixels
    ypos_visual = centerpos_Wy*2*(65/66) - 32.5  # vertical visual range is 65, so it should be (66/65) pixels per degree
    if mouse_id == 5: xpos_visual += (46*270/264)  # xrange of mouse 6 is 46-176
    # NOTE(review): xpos/ypos were already appended above, so the rescaled and
    # (for mouse 1) flipped copies computed below are never stored. Confirm
    # whether figure1 expects raw or rescaled positions.
    xpos = xpos / x_pixel_ratio
    ypos = ypos / y_pixel_ratio
    if mouse_id == 1:
        idx_up = np.where(xpos > 325 / x_pixel_ratio)[0]
        ymax = ypos[idx_up].max()
        ypos[idx_up] = ymax - ypos[idx_up] + 300
    xpos_visual_all.append(xpos_visual)
    ypos_visual_all.append(ypos_visual)
    fev_all.append(dat['fev'])
data_dict['xpos_all'] = _stack_per_mouse(xpos_all)
data_dict['ypos_all'] = _stack_per_mouse(ypos_all)
data_dict['xpos_visual_all'] = _stack_per_mouse(xpos_visual_all)
data_dict['ypos_visual_all'] = _stack_per_mouse(ypos_visual_all)
data_dict['fev_all'] = _stack_per_mouse(fev_all)

In [None]:
from pathlib import Path

# Recording metadata; mname, datexp, blk and stim together form each aligned
# retinotopy-map file name.
db = [
    {'mname': 'L1_A5', 'datexp': '2023_01_27', 'blk':'3', 'stim':'short3'},
    {'mname': 'L1_A1', 'datexp': '2023_03_27', 'blk':'1', 'stim':'short3'},
    {'mname': 'FX9', 'datexp': '2023_05_02', 'blk':'2', 'stim':'short3'},
    {'mname': 'FX10', 'datexp': '2023_05_02', 'blk':'1', 'stim':'short3'},
    {'mname': 'FX8', 'datexp': '2023_05_02', 'blk':'1', 'stim':'short3'},
    {'mname': 'FX20', 'datexp': '2023_09_08', 'blk':'2','stim':'nat15k'},
]

ret_map_dir = Path(result_path) / 'Ret_maps'
ret_fields = ('iregion', 'iarea', 'xy', 'out', 'jxy')
per_mouse = {field: [] for field in ret_fields}
for entry in db[:6]:
    aligned_path = ret_map_dir / f"{entry['mname']}_{entry['datexp']}_{entry['blk']}_{entry['stim']}.npz"
    m = np.load(aligned_path, allow_pickle=True)
    for field in ret_fields:
        per_mouse[field].append(m[field])

# NOTE(review): np.array raises ValueError on NumPy >= 1.24 if any of these
# fields differ in shape between mice; confirm the shapes match.
for field in ret_fields:
    data_dict[field] = np.array(per_mouse[field])


In [None]:
# Supplementary figure 1: retinotopy.
import sup_figure
from fig_utils import *  # NOTE(review): star import kept; what sup_figure needs from it is not visible here

save_path = './outputs'  # presumably the directory figure1 writes into
sup_figure.figure1(data_dict, save_path)

# figure 2: signal variance

In [None]:
# Example neurons (trial 1 vs trial 2): repeated-stimulus responses of three
# neurons from one mouse, plus their FEV.
mouse_id = 1
fname = '%s_nat60k_%s.npz' % (data.db[mouse_id]['mname'], data.db[mouse_id]['datexp'])
spks, istim_train, istim_test, xpos, ypos, spks_rep_all = data.load_neurons(
    file_path=os.path.join(data_path, fname),
    mouse_id=mouse_id,
)

# The training split is passed to normalize_spks (presumably so normalisation
# statistics come from training stimuli only).
itrain, ival = data.split_train_val(istim_train, train_frac=0.9)
spks, spks_rep_all = data.normalize_spks(spks, spks_rep_all, itrain)
fev_all = metrics.fev(spks_rep_all)

ineurons = [2550, 5020, 264]
spks_rep = [rep[:, ineurons] for rep in spks_rep_all]
data_dict['example_repeats'] = np.stack(spks_rep)
data_dict['example_fev'] = fev_all[ineurons]

In [None]:
# Signal variance distribution (mouse): FEV of every neuron, pooled across the six mice.
# NOTE(review): this overwrites the per-mouse data_dict['fev_all'] stored for figure 1.
fev_all = np.hstack([
    np.load(os.path.join(result_path, f'fullmodel_{data.mouse_names[mouse_id]}_results.npz'), allow_pickle=True)['fev']
    for mouse_id in range(6)
])
data_dict['fev_all'] = fev_all

## monkey

In [None]:
# Load the monkey V1 dataset (monkeyv1_cadena_2019.npz) and the full-model FEV.
dat = np.load(os.path.join(data_path, 'monkeyv1_cadena_2019.npz'))
images, responses, real_responses = dat['images'], dat['responses'], dat['real_responses']
test_images, test_responses, test_real_responses = dat['test_images'], dat['test_responses'], dat['test_real_responses']
train_idx, val_idx = dat['train_idx'], dat['val_idx']
repetitions = dat['repetitions']
monkey_ids = dat['subject_id']
image_ids = dat['image_ids']

# Scale every neuron by the std of its `responses`, ignoring entries that
# `real_responses` marks as not recorded.
resp_std = np.nanstd(np.where(real_responses, responses, np.nan), axis=0)
responses = responses / resp_std
test_responses = test_responses / resp_std

# Test responses with unrecorded entries replaced by NaN.
test_responses_nan = np.where(test_real_responses, test_responses, np.nan)
print('test_responses_nan: ', test_responses_nan.shape)

dat = np.load(os.path.join(result_path, 'fullmodel_monkey_results.npz'), allow_pickle=True)
monkey_fev = dat['fev_all']
ineurons = [51, 105, 86]
data_dict['monkey_fev_all'] = monkey_fev
data_dict['monkey_example_repeats'] = test_responses_nan[:, :, ineurons]
data_dict['monkey_example_fev'] = monkey_fev[ineurons]
data_dict['monkey_example_reps'] = repetitions[ineurons]

In [None]:
# Monkey example neurons. This repeats the assignments at the end of the
# previous cell, so it can be re-run on its own to refresh the selection.
ineurons = [51, 105, 86]
data_dict.update({
    'monkey_example_repeats': test_responses_nan[:, :, ineurons],
    'monkey_example_fev': dat['fev_all'][ineurons],
    'monkey_example_reps': repetitions[ineurons],
})

In [None]:
# Supplementary figure 2: signal variance (mouse and monkey).
import sup_figure
from fig_utils import *  # NOTE(review): star import kept; what sup_figure needs from it is not visible here

save_path = './outputs'  # presumably the directory figure2 writes into
sup_figure.figure2(data_dict, save_path)

# figure 3: Wx and Wy

In [None]:
# mouse
# Load each mouse's full model and keep the readout Wx/Wy (and FEV) of the
# neurons with FEV > 0.15.
Wx_all = []
Wy_all = []
fev_all = []
nlayers = 2
nconv1 = 192
nconv2 = 192
for mouse_id in range(6):
    n_neurons = data.NNs[mouse_id]
    ineur = np.arange(0, n_neurons)
    suffix = 'xrange_176' if mouse_id == 5 else ''
    # NOTE(review): `pool`, `depth_separable` and `clamp` are not defined anywhere
    # in this notebook, so these lines raise NameError on a fresh kernel. Define
    # them in the config cell with the values used for training (`clamp` is part
    # of the checkpoint name).
    model, in_channels = model_builder.build_model(NN=len(ineur), n_layers=nlayers, n_conv=nconv1, n_conv_mid=nconv2, pool=pool, depth_separable=depth_separable)
    model_name = model_builder.create_model_name(data.mouse_names[mouse_id], data.exp_date[mouse_id], n_layers=nlayers, in_channels=in_channels, clamp=clamp, suffix=suffix)

    # Keep the per-mouse directory in a local name. Reassigning `weight_path`
    # here made every iteration join onto the previous mouse's path.
    # NOTE(review): weight_path already ends in 'fullmodel', so this resolves to
    # .../fullmodel/fullmodel/<mouse>; confirm that checkpoint layout.
    model_dir = os.path.join(weight_path, 'fullmodel', data.mouse_names[mouse_id])
    model_path = os.path.join(model_dir, model_name)
    print('model path: ', model_path)
    model = model.to(device)
    model.load_state_dict(torch.load(model_path))
    print('loaded model', model_path)

    Wx = model.readout.Wx.detach().cpu().numpy()
    Wy = model.readout.Wy.detach().cpu().numpy()

    dat = np.load(os.path.join(result_path, f'fullmodel_{data.mouse_names[mouse_id]}_results.npz'), allow_pickle=True)
    fev = dat['fev']
    valid_idx = np.where(fev > 0.15)[0]
    Wx_all.append(Wx[valid_idx])
    Wy_all.append(Wy[valid_idx])
    fev_all.append(fev[valid_idx])

data_dict['fev_all'] = fev_all
data_dict['Wx_all'] = Wx_all
data_dict['Wy_all'] = Wy_all

In [None]:
from minimodel.utils import weight_bandwidth
from scipy.interpolate import interp1d

# Pool the FEV-filtered readout weights of all mice into single arrays.
fev_all, Wx_all, Wy_all = (data_dict[key] for key in ('fev_all', 'Wx_all', 'Wy_all'))
fev = np.hstack(fev_all)
Wx, Wy = np.vstack(Wx_all), np.vstack(Wy_all)

In [None]:
# Peak positions of every neuron's Wx/Wy profile, then Wx profiles re-sampled
# on a grid of offsets relative to each neuron's peak.
NN = Wx.shape[0]
bandwidth_Wx = np.zeros(NN)
bandwidth_Wy = np.zeros(NN)
centerpos_Wx = np.zeros(NN)
centerpos_Wy = np.zeros(NN)
for i in range(NN):
    bandwidth_Wx[i], centerpos_Wx[i] = weight_bandwidth(Wx[i, 0, :], return_peak=True)
    bandwidth_Wy[i], centerpos_Wy[i] = weight_bandwidth(Wy[i, 0, :], return_peak=True)

Lx, Ly = Wx.shape[-1], Wy.shape[-1]
x_xrange, y_xrange = np.arange(Lx), np.arange(Ly)

ineurons = np.arange(0, NN)
# Common grid spanning every peak-shifted position, with as many points as Lx.
common_x = np.linspace(np.min(x_xrange - np.max(centerpos_Wx)),
                       np.max(x_xrange - np.min(centerpos_Wx)),
                       num=len(x_xrange))

interp_Wx_values = []
for i in ineurons:
    # Positions of this neuron's Wx samples relative to its own peak.
    original_x = x_xrange - centerpos_Wx[i]
    # Points outside the neuron's shifted range become NaN. np.nan, not np.NaN:
    # the capitalised alias was removed in NumPy 2.0.
    interp_func = interp1d(original_x, Wx[i, 0, :], kind='linear', bounds_error=False, fill_value=np.nan)
    interp_Wx_values.append(interp_func(common_x))

data_dict['interp_Wx_values'] = interp_Wx_values
data_dict['common_x'] = common_x

In [None]:
from scipy.interpolate import interp1d

# Re-sample each neuron's Wy profile on a grid of offsets relative to its peak.
# Uses NN, Wy, y_xrange and centerpos_Wy computed in the previous cell.
ineurons = np.arange(0, NN)
# Common grid spanning every peak-shifted position, with as many points as Ly.
common_x = np.linspace(np.min(y_xrange - np.max(centerpos_Wy)),
                       np.max(y_xrange - np.min(centerpos_Wy)),
                       num=len(y_xrange))

interp_Wy_values = []
for i in ineurons:
    # Positions of this neuron's Wy samples relative to its own peak.
    original_y = y_xrange - centerpos_Wy[i]
    # Points outside the shifted range become NaN (np.nan; np.NaN is gone in NumPy 2.0).
    interp_func = interp1d(original_y, Wy[i, 0, :], kind='linear', bounds_error=False, fill_value=np.nan)
    interp_Wy_values.append(interp_func(common_x))

data_dict['interp_Wy_values'] = interp_Wy_values
data_dict['common_y'] = common_x

## monkey

In [None]:
# monkey
# Load the monkey V1 full model and extract its readout weights.
nlayers = 2  # same value the mouse cell sets; defined here so this cell stands alone
Lx, Ly = 80, 80
nconv1 = 192
nconv2 = 192
use_sensorium_normalization = True
model, in_channels = model_builder.build_model(NN=166, n_layers=nlayers, n_conv=nconv1, n_conv_mid=nconv2, input_Lx=Lx, input_Ly=Ly)
suffix = ''
model_name = model_builder.create_model_name('monkeyV1', '2019', n_layers=nlayers, in_channels=in_channels, suffix=suffix, seed=1, use_sensorium_normalization=use_sensorium_normalization)
# Local name instead of reassigning `weight_path`, which compounded with the
# reassignments made earlier in the notebook.
monkey_weight_dir = os.path.join(weight_path, 'fullmodel', 'monkeyV1')
model_path = os.path.join(monkey_weight_dir, model_name)
print('model path: ', model_path)
model.load_state_dict(torch.load(model_path))
print('loaded model', model_path)
model = model.to(device)

Wc = model.readout.Wc.detach().cpu().numpy().squeeze()
Wx = model.readout.Wx.detach().cpu().numpy()
Wy = model.readout.Wy.detach().cpu().numpy()
# Per-neuron outer product of the Wy and Wx profiles, summed over axis 1.
Wxy = np.einsum('icj,ick->ijk', Wy, Wx)
print(Wxy.shape, Wc.shape)
print(Wc.shape, Wc.min(), Wc.max())

In [None]:
# Monkey: peak positions of every neuron's Wx/Wy profile, then Wx profiles
# re-sampled on a grid of offsets relative to each neuron's peak.
NN = Wx.shape[0]
bandwidth_Wx = np.zeros(NN)
bandwidth_Wy = np.zeros(NN)
centerpos_Wx = np.zeros(NN)
centerpos_Wy = np.zeros(NN)
for i in range(NN):
    bandwidth_Wx[i], centerpos_Wx[i] = weight_bandwidth(Wx[i, 0, :], return_peak=True)
    bandwidth_Wy[i], centerpos_Wy[i] = weight_bandwidth(Wy[i, 0, :], return_peak=True)

Lx, Ly = Wx.shape[-1], Wy.shape[-1]
x_xrange, y_xrange = np.arange(Lx), np.arange(Ly)

ineurons = np.arange(0, NN)
# Common grid spanning every peak-shifted position, with as many points as Lx.
common_x = np.linspace(np.min(x_xrange - np.max(centerpos_Wx)),
                       np.max(x_xrange - np.min(centerpos_Wx)),
                       num=len(x_xrange))

interp_Wx_values = []
for i in ineurons:
    # Positions of this neuron's Wx samples relative to its own peak.
    original_x = x_xrange - centerpos_Wx[i]
    # Points outside the shifted range become NaN (np.nan; np.NaN is gone in NumPy 2.0).
    interp_func = interp1d(original_x, Wx[i, 0, :], kind='linear', bounds_error=False, fill_value=np.nan)
    interp_Wx_values.append(interp_func(common_x))

data_dict['monkey_interp_Wx_values'] = interp_Wx_values
data_dict['monkey_common_x'] = common_x

In [None]:
from scipy.interpolate import interp1d

# Monkey: re-sample each neuron's Wy profile on a grid of offsets relative to
# its peak. Uses NN, Wy, y_xrange and centerpos_Wy from the previous cell.
ineurons = np.arange(0, NN)
# Common grid spanning every peak-shifted position, with as many points as Ly.
common_x = np.linspace(np.min(y_xrange - np.max(centerpos_Wy)),
                       np.max(y_xrange - np.min(centerpos_Wy)),
                       num=len(y_xrange))

interp_Wy_values = []
for i in ineurons:
    # Positions of this neuron's Wy samples relative to its own peak.
    original_y = y_xrange - centerpos_Wy[i]
    # Points outside the shifted range become NaN (np.nan; np.NaN is gone in NumPy 2.0).
    interp_func = interp1d(original_y, Wy[i, 0, :], kind='linear', bounds_error=False, fill_value=np.nan)
    interp_Wy_values.append(interp_func(common_x))

data_dict['monkey_interp_Wy_values'] = interp_Wy_values
data_dict['monkey_common_y'] = common_x

In [None]:
# plot
import sup_figure
from fig_utils import *
save_path = './outputs'
sup_figure.figure3(data_dict, save_path)

# figure 4: sparsity penalty on minimodel

## validation (mouse and monkey)

In [None]:
# Validation-set results of the minimodel parameter search: FEVE and conv2
# channel counts for the first mouse (two runs) and for the monkey.
mouse_id = 0
hs_list = [0, 0.01, 0.02, 0.03, 0.04, 0.05, 0.1, 0.2, 0.5]
nhs = len(hs_list)

mouse_prefix = f'{data.mouse_names[mouse_id]}_{data.exp_date[mouse_id]}_minimodel_16_64'
for run_tag, feve_key, nconv2_key in (
    ('choose_param_val', 'mouse_feve_val', 'mouse_nconv2_val'),
    ('choose_param_5k_val', 'mouse_feve_val_5k', 'mouse_nconv2_val_5k'),
):
    dat = np.load(os.path.join(result_path, f'{mouse_prefix}_{run_tag}_result.npz'))
    feve_all, nconv2_all = dat['feve_val'], dat['nconv2']
    print(feve_all.shape, nconv2_all.shape)
    data_dict[feve_key] = feve_all
    data_dict[nconv2_key] = nconv2_all

# monkey
# NOTE(review): this path is relative to the working directory ('outputs/'),
# while the figure-4 monkey cell below loads the same file name from
# result_path. Confirm which copy is intended.
dat = np.load("outputs/minimodel_monkey_result.npz")
feve_all, nconv2_all = dat['param_search_feve_val'], dat['param_search_nconv2_val']
print(feve_all.shape, nconv2_all.shape)
data_dict['monkey_feve_val'] = feve_all
data_dict['monkey_nconv2_val'] = nconv2_all

## mouse test

In [None]:
# Test-set results for all mice: FEVE, and the number of entries of wc with
# |wc| > 0.01 along axis 2 (presumably the active conv2 channels).
# hs_list presumably lists the sparsity penalties of the runs — confirm.
nmouse = 6
hs_list = [0, 0.01, 0.02, 0.03, 0.04, 0.05, 0.1, 0.2, 0.5]
nhs = len(hs_list)
feve_per_mouse = []
nconv2_per_mouse = []
for mouse_id in range(nmouse):
    res = np.load(os.path.join(result_path, f'{data.mouse_names[mouse_id]}_{data.exp_date[mouse_id]}_minimodel_16_64_choose_param_result.npz'))
    feve_per_mouse.append(res['feve_all'])
    nconv2_per_mouse.append((np.abs(res['wc_all']) > 0.01).sum(axis=2))
feve_all = np.vstack(feve_per_mouse)
nconv2_all = np.vstack(nconv2_per_mouse)
print(feve_all.shape, nconv2_all.shape)

In [None]:
# Store the pooled test-set results together with the penalty list.
data_dict.update({
    'sparsity_feve_all': feve_all,
    'sparsity_nconv2_all': nconv2_all,
    'sparsity_hs': hs_list,
})

### mouse 5k

In [None]:
# Same as the test-set cell above, but for the '5k' runs.
nmouse = 6
hs_list = [0, 0.01, 0.02, 0.03, 0.04, 0.05, 0.1, 0.2, 0.5]
nhs = len(hs_list)
feve_per_mouse = []
nconv2_per_mouse = []
for mouse_id in range(nmouse):
    res = np.load(os.path.join(result_path, f'{data.mouse_names[mouse_id]}_{data.exp_date[mouse_id]}_minimodel_16_64_choose_param_5k_result.npz'))
    feve_per_mouse.append(res['feve_all'])
    nconv2_per_mouse.append((np.abs(res['wc_all']) > 0.01).sum(axis=2))
feve_all = np.vstack(feve_per_mouse)
nconv2_all = np.vstack(nconv2_per_mouse)
print(feve_all.shape, nconv2_all.shape)

data_dict.update({
    'sparsity_feve_5k_all': feve_all,
    'sparsity_nconv2_5k_all': nconv2_all,
})

## monkey

In [None]:
# Monkey: FEVE and count of |wc| > 0.01 entries along axis 2, per sparsity penalty.
dat = np.load(os.path.join(result_path, 'minimodel_monkey_result.npz'))
feve_hs_all, wc_hs_all, hs_list = dat['feve_hs_all'], dat['wc_hs_all'], dat['hs_list']
nconv2_all = (np.abs(wc_hs_all) > 0.01).sum(axis=2)
data_dict.update({
    'sparsity_monkey_feve_all': feve_hs_all,
    'sparsity_monkey_nconv2_all': nconv2_all,
    'sparsity_monkey_hs': hs_list,
})

## plot

In [None]:
# plot
import sup_figure
from fig_utils import *
save_path = './outputs'
sup_figure.figure4(data_dict, save_path)

# figure 5: model structure

In [None]:
# Supplementary figure 5: model structure.
import sup_figure
from fig_utils import *  # NOTE(review): star import kept; what sup_figure needs from it is not visible here

save_path = './outputs'  # presumably the directory figure5 writes into
sup_figure.figure5(data_dict, save_path)

# figure 6: reuse conv1

In [None]:
# Reuse conv1: for NN randomly chosen neurons per mouse, evaluate minimodels
# whose checkpoint suffix names the mouse conv1 was pretrained on (empty suffix
# when it is the neuron's own mouse).
nmouse = 6
NN = 100  # neurons sampled per mouse
# Indexed [mouse, conv1-source mouse, sampled neuron].
feve_matrix = np.zeros((nmouse, nmouse, NN))

nconv1 = 16
nconv2 = 64
wc_coef = 0.2
hs_readout = 0.03
nlayers = 2

for mouse_id in range(nmouse):
    img = data.load_images(data_path, mouse_id, file=os.path.join(data_path, data.img_file_name[mouse_id]))
    nimg, Ly, Lx = img.shape
    print('img: ', img.shape, img.min(), img.max(), img.dtype)
    fname = '%s_nat60k_%s.npz'%(data.db[mouse_id]['mname'], data.db[mouse_id]['datexp'])

    spks, istim_train, istim_test, xpos, ypos, spks_rep_all = data.load_neurons(file_path = os.path.join(data_path, fname), mouse_id = mouse_id)
    n_stim, n_max_neurons = spks.shape

    # normalize spks
    itrain, ival = data.split_train_val(istim_train, train_frac=0.9)
    spks, spks_rep_all = data.normalize_spks(spks, spks_rep_all, itrain)
    img_test = torch.from_numpy(img[istim_test]).to(device).unsqueeze(1)
    input_Ly, input_Lx = img_test.shape[-2:]

    # Fixed seed so the same neurons are sampled on every run.
    np.random.seed(42)
    ineurons = np.random.choice(np.arange(data.NNs_valid[mouse_id]), NN, replace=False)

    fev_test = metrics.fev(spks_rep_all)
    isort_neurons = np.argsort(fev_test)[::-1]  # neuron indices, highest FEV first

    # Per-mouse checkpoint directory in a local name. Reassigning `weight_path`
    # inside the inner loop made every load join onto the previous path.
    minimodel_dir = os.path.join(weight_path, 'minimodel', data.mouse_names[mouse_id])

    for mouse_id_base in range(nmouse):
        # The suffix depends only on the conv1-source mouse, so build it once here.
        if mouse_id_base != mouse_id:
            suffix = f'pretrainconv1_{data.mouse_names[mouse_id_base]}_{data.exp_date[mouse_id_base]}'
            if mouse_id_base == 5:
                suffix = f'pretrainconv1_{data.mouse_names[mouse_id_base]}_{data.exp_date[mouse_id_base]}_xrange_176'
        else:
            suffix = ''
        for i, ineuron in enumerate(ineurons):
            ineur = [isort_neurons[ineuron]]
            spks_rep = [rep[:, ineur] for rep in spks_rep_all]

            model, in_channels = model_builder.build_model(NN=1, n_layers=nlayers, n_conv=nconv1, n_conv_mid=nconv2, Wc_coef=wc_coef)
            model_name = model_builder.create_model_name(data.mouse_names[mouse_id], data.exp_date[mouse_id], ineuron=ineur[0], n_layers=nlayers, in_channels=in_channels, hs_readout=hs_readout, suffix=suffix)

            model_path = os.path.join(minimodel_dir, model_name)
            print('model path: ', model_path)
            model.load_state_dict(torch.load(model_path))
            print('loaded model', model_path)
            model = model.to(device)

            # test model
            test_pred = model_trainer.test_epoch(model, img_test)
            test_fev, test_feve = metrics.feve(spks_rep, test_pred)
            print('FEVE (test): ', test_feve)

            feve_matrix[mouse_id, mouse_id_base, i] = np.mean(test_feve)

In [None]:
# Swap the first two axes so results are indexed [conv1-source mouse, mouse, neuron].
data_dict['reuse_conv1_feve'] = np.swapaxes(feve_matrix, 0, 1)

In [None]:
# conv1 kernels
mouse_id = 0
conv1_all = []
for mouse_id in range(6):
    dat = np.load(os.path.join(result_path, f'fullmodel_{data.mouse_names[mouse_id]}_results.npz'), allow_pickle=True)
    conv1 = dat['fullmodel_conv1_W']
    conv1_all.append(conv1)
data_dict['conv1_W'] = np.stack(conv1_all)
print(data_dict['conv1_W'].shape)

In [None]:
# Supplementary figure 6: reuse of conv1 across mice.
import sup_figure
from fig_utils import *  # NOTE(review): star import kept; what sup_figure needs from it is not visible here

save_path = './outputs'  # presumably the directory figure6 writes into
sup_figure.figure6(data_dict, save_path)

# figure 7: example neurons visualization (rendered by `sup_figure.figure9`)

In [None]:
# Example-neuron visualisation.
# NOTE(review): the section heading says figure 7 but this calls figure9;
# confirm the intended function.
import sup_figure
from fig_utils import *  # NOTE(review): star import kept; what sup_figure needs from it is not visible here

save_path = './outputs'  # presumably the directory figure9 writes into
sup_figure.figure9(data_dict, save_path)

# save all

In [None]:
# Save everything gathered above so the figures can be rebuilt without the raw data.
# NOTE(review): several cells reuse keys ('fev_all' is set for figures 1, 2 and 3;
# the 'monkey_example_*' keys twice); only the last assignment ends up in the file.
def _as_savable(value):
    """Return `value` in a form np.savez can store.

    Lists of equally shaped items are stacked as np.savez would do implicitly.
    Ragged lists (e.g. the per-mouse Wx_all/Wy_all/fev_all of figure 3) become
    1-D object arrays, which NumPy >= 1.24 no longer infers. Loading those
    entries back needs np.load(..., allow_pickle=True).
    """
    if not isinstance(value, list):
        return value
    try:
        return np.array(value)
    except ValueError:
        ragged = np.empty(len(value), dtype=object)
        for k, item in enumerate(value):
            ragged[k] = item
        return ragged


np.savez('supfig_results.npz', **{key: _as_savable(val) for key, val in data_dict.items()})