# Preparing Data for Figure 4

This notebook saves the processed data needed to plot **Figure 4** into a `.npz` file and then generates the corresponding plots.

**Important:**  
Before running this notebook, please make sure to execute the following notebooks in the `./save_results` directory:
- [minimodel_mouse_saveall.ipynb](https://github.com/MouseLand/minimodel/blob/main/figures/save_results/minimodel_mouse_saveall.ipynb)
- [minimodel_monkey_saveall.ipynb](https://github.com/MouseLand/minimodel/blob/main/figures/save_results/minimodel_monkey_saveall.ipynb)
- [mouse_invariance_saveall.ipynb](https://github.com/MouseLand/minimodel/blob/main/figures/save_results/mouse_invariance_saveall.ipynb)
- [monkey_invariance_saveall.ipynb](https://github.com/MouseLand/minimodel/blob/main/figures/save_results/monkey_invariance_saveall.ipynb)

These notebooks:
- Load the raw neural and stimulus data,
- Run models for each animal (mouse and monkey),
- Save the model outputs needed for plotting.

Each notebook in `./save_results` corresponds to a specific condition or model variant. Skipping any of them may result in missing or incomplete data when running this notebook.


In [None]:
import os
import torch
import numpy as np
from minimodel import data

# Use the GPU when one is present and fall back to the CPU otherwise, so the
# notebook does not depend on CUDA being installed.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Every array needed to draw Figure 4 is collected in this dict and written to
# a single .npz archive at the end of the notebook.
data_dict = {}

# Directory holding the stimulus images and the text16 recordings.
data_path = '../data'
# Presumably the model checkpoints; not read by the cells visible in this notebook.
weight_path = './checkpoints/fullmodel'
# Directory holding the result files produced by the ./save_results notebooks.
result_path = './save_results/outputs'

# figure 4a visualize texture 16 images

In [None]:
# Load the stimulus images of the mouse used for the example panel.
mouse_id = 2
img_file = os.path.join(data_path, data.img_file_name[mouse_id])
img = data.load_images(
    data_path,
    mouse_id,
    file=img_file,
    normalize=False,
    crop=False,
)

# The stack is unpacked as (number of images, height, width).
nimg, Ly, Lx = img.shape
print('img: ', img.shape, img.min(), img.max(), img.dtype)

In [None]:
# Build a staggered montage of a few example images for panel 4a.
# Only images from index 60000 onward are used (presumably the texture set --
# TODO confirm); the image axis is moved last so a single image is [:, :, k].
img_transpose = img[60000:].transpose(1, 2, 0)
Ly, Lx, N = img_transpose.shape

nimg = 5
ids = [0, 5000, 6500, 2, 7000]  # positions inside img_transpose

# One class label per image: 16 classes of 500 consecutive images, labelled 1..16.
ncls = 16
cls_ids = np.repeat(np.arange(ncls), 500) + 1

xpad = 15  # horizontal shift between consecutive images
ypad = 35  # vertical shift between consecutive images

# White canvas; each image is pasted one step further down and to the right.
# NOTE(review): the canvas width is padded with ypad although images are
# shifted horizontally by xpad, so the canvas is wider than the pasted images
# need -- confirm whether this margin is intentional.
pimg = np.full((Ly + (nimg - 1) * ypad, Lx + (nimg - 1) * ypad), 255.0)
for k, img_idx in enumerate(ids):
    top = k * ypad
    left = k * xpad
    pimg[top:top + Ly, left:left + Lx] = img_transpose[:, :, img_idx]

data_dict['dataset_imgs'] = pimg
data_dict['dataset_img_ids'] = cls_ids[ids]

# figure 4b decoding accuracy as a function of the number of neurons (NN)

In [None]:
from approxineuro.neural_utils import texture_accuracy
from sklearn.linear_model import LogisticRegression

# NOTE(review): every other project import in this notebook comes from
# `minimodel`; confirm that `approxineuro` is still where texture_accuracy lives.

# For each mouse, decode the texture class from random subsets of neurons of
# increasing size, repeating every subset size with several random seeds.
n_seeds = 10
n_mouse = 4
# log-spaced subset sizes; the first and the three largest entries are dropped
n_neurons_list = np.logspace(0, 4, 20).astype(int)[1:-3]
print(n_neurons_list)
all_accs = np.zeros((n_mouse, len(n_neurons_list), n_seeds))
all_neuron_accs = np.zeros(n_mouse)
for m, mouse_id in enumerate([2,3,4,5]):
    fname = 'text16_%s_%s.npz'%(data.db[mouse_id]['mname'], data.db[mouse_id]['datexp'])
    dat = np.load(os.path.join(data_path, fname), allow_pickle=True)
    txt16_spks = dat['sp']                             # indexed as (neuron, train stimulus)
    txt16_istim = dat['istim'].astype(int)
    txt16_labels = dat['labels']
    txt16_test_istim = dat['ss_istim'].astype(int)
    txt16_test_labels = dat['ss_labels']
    txt16_test_spks = np.stack(dat['ss_all']).mean(1)  # averaged over axis 1 (the repeats)
    print('train spks:', txt16_spks.shape, 'test spks:', txt16_test_spks.shape)

    # reference accuracy using the full population
    allneuron_acc = texture_accuracy(txt16_spks.T, txt16_labels, txt16_test_spks, txt16_test_labels)
    print(f'accuracy of all neurons: {allneuron_acc:.2f}')
    all_neuron_accs[m] = allneuron_acc

    NN = txt16_spks.shape[0]
    n_classes = len(np.unique(txt16_labels))
    # NOTE(review): classes are drawn from 0..15; this assumes the labels in the
    # file are 0-based -- confirm against the data.
    selected_classes = np.random.choice(np.arange(16), n_classes, replace=False)
    selected_idxes_train = np.where(np.isin(txt16_labels, selected_classes))[0]
    selected_idxes_test = np.where(np.isin(txt16_test_labels, selected_classes))[0]

    # the targets do not depend on the neuron subset, so build them once per mouse
    train_y = txt16_labels[selected_idxes_train]
    test_y = txt16_test_labels[selected_idxes_test]

    for k, n_neurons in enumerate(n_neurons_list):
        for iseed, seed in enumerate(range(n_seeds)):
            # a distinct, reproducible seed for every (mouse, repetition) pair
            np.random.seed(42*m+seed)
            random_ineurons = np.random.choice(NN, n_neurons, replace=False)
            train_X = txt16_spks[random_ineurons][:, selected_idxes_train].T
            test_X = txt16_test_spks[selected_idxes_test][:, random_ineurons]
            # standardise both sets with the training statistics only
            mean_x = train_X.mean(axis=0)
            std_x = train_X.std(axis=0)
            # a neuron that is constant over the training set would give 0/0 = NaN,
            # which LogisticRegression rejects; leave such features at zero instead
            std_x[std_x == 0] = 1
            train_X = (train_X - mean_x) / std_x
            test_X = (test_X - mean_x) / std_x
            clf = LogisticRegression(random_state=0, penalty='l2', C=0.1).fit(train_X, train_y)
            acc = clf.score(test_X, test_y)
            all_accs[m, k, iseed] = acc

In [None]:
# Keep the panel-4b decoding results for the final archive.
data_dict.update({
    'n_neurons': n_neurons_list,
    'classification_accs': all_accs,
    'all_neuron_accs': all_neuron_accs,
})
# mean full-population accuracy across mice
print(all_neuron_accs.mean())

# figure 4c visualize category variance (catvar)

In [None]:
# Load the text16 recording of the example mouse and stack its train and test trials.
mouse_id = 3
fname = 'text16_%s_%s.npz' % (data.db[mouse_id]['mname'], data.db[mouse_id]['datexp'])
dat = np.load(os.path.join(data_path, fname), allow_pickle=True)

# ---- test set: unpacked as (stimulus, repeat, neuron) ----
txt16_spks_test = dat['ss_all']
nstim, nrep, nneuron = txt16_spks_test.shape
# Copy every per-stimulus value once per repeat so that it lines up with the
# rows of the flattened response matrix built just below.
txt16_istim_test = np.repeat(
    dat['ss_istim'].astype(int)[:, np.newaxis], nrep, axis=1
).flatten()
# one row per (stimulus, repeat) pair
txt16_spks_test = txt16_spks_test.reshape(-1, nneuron)
txt16_labels_test = np.repeat(
    dat['ss_labels'][:, np.newaxis], nrep, axis=1
).flatten()

print('txt16_spks_test shape:', txt16_spks_test.shape)
print('txt16_labels_test shape:', txt16_labels_test.shape)

# ---- training set: transposed so that rows are trials, like the test set ----
txt16_spks_train = dat['sp'].T
txt16_istim_train = dat['istim'].astype(int)
txt16_labels_train = dat['labels']

print('txt16_spks_train shape:', txt16_spks_train.shape)
print('txt16_labels_train shape:', txt16_labels_train.shape)

# ---- training trials first, then test trials ----
txt16_spks = np.vstack((txt16_spks_train, txt16_spks_test))
txt16_labels = np.hstack((txt16_labels_train, txt16_labels_test))
txt16_istim = np.hstack((txt16_istim_train, txt16_istim_test))

print('txt16_spks shape:', txt16_spks.shape)
print('txt16_labels shape:', txt16_labels.shape)

In [None]:
from utils import metrics

# Category variance restricted to one pair of texture classes. The result is
# indexed by neuron further down, so it appears to hold one value per neuron.
iclass1, iclass2 = 6, 15
catvar = metrics.category_variance_pairwise(
    txt16_spks.T,
    txt16_labels,
    ss=[iclass1, iclass2],
)
print(np.mean(catvar))

In [None]:
# Rank the neurons from highest to lowest category variance.
isort = np.argsort(catvar)[::-1]
print(catvar.shape)

# Pick two consecutive neurons from the ranking, starting at 0-based rank i.
i = 1
ineurons = isort[i:i + 2]
print(catvar[ineurons])

# Test-set responses of the two neurons, one row per neuron, each row scaled
# in place to unit standard deviation.
neuron_spks = txt16_spks_test[:, ineurons].T
neuron_spks /= neuron_spks.std(axis=1, keepdims=True)

print(neuron_spks.shape)

# Distinct test stimuli that belong to each of the two example classes.
unique_istims_cls1 = np.unique(txt16_istim_test[txt16_labels_test == iclass1])
unique_istims_cls2 = np.unique(txt16_istim_test[txt16_labels_test == iclass2])

data_dict.update({
    'example_istim_cls1': unique_istims_cls1,
    'example_istim_cls2': unique_istims_cls2,
    'example_neuron_spks': neuron_spks,
    'example_ineuron': ineurons,
    'example_classes': np.array([iclass1, iclass2]),
    'txt16_istim_test': txt16_istim_test,
})

# figure 4d-g

In [None]:
from minimodel.utils import weight_bandwidth

# Pool the per-mouse category-variance results and the minimodel
# receptive-field sizes that panels 4d-g need.

# Only neurons whose minimodel FEVE exceeds this value contribute an RF size.
feve_threshold = 0.7

valid_idxes_all = []
pos_corr_all = []
noisy_catvar_all = []
neural_catvar_all = []
model_catvar_all = []
model_rfsize_all = []
pred_catvar_all = []

for mouse_id in [2,3,4,5]:
    # ---- category-variance results of this mouse ----
    fpath = os.path.join(result_path, f'catvar_{data.db[mouse_id]["mname"]}_result.npz')
    dat = np.load(fpath, allow_pickle=True)

    valid_idxes = dat['valid_ineurons']
    pos_corr = dat['mean_correlation']
    noisy_catvar = dat['noisy_model_catvar']
    # keep the neural values of the valid neurons only
    neural_catvar = dat['neural_catvar'][valid_idxes]
    model_catvar = dat['model_catvar']
    pred_catvar = dat['pred_catvar']
    # written on every iteration, so the value from the last mouse is the one kept
    data_dict['op_all'] = dat['op_names']

    valid_idxes_all.append(valid_idxes)
    pos_corr_all.append(pos_corr)
    noisy_catvar_all.append(noisy_catvar)
    neural_catvar_all.append(neural_catvar)
    model_catvar_all.append(model_catvar)
    pred_catvar_all.append(pred_catvar)

    # ---- receptive-field size from the readout weights Wx / Wy ----
    # Read from result_path like every other result file in this notebook
    # (previously a hard-coded 'outputs/' directory relative to the working dir).
    dat = np.load(
        os.path.join(result_path, f'minimodel_{data.mouse_names[mouse_id]}_result.npz'),
        allow_pickle=True,
    )
    Wx = dat['wx_all']
    Wy = dat['wy_all']
    feve = dat['feve_all']
    high_feve_idxes = np.where(feve > feve_threshold)[0]
    NN = Wx.shape[0]
    bandwidth_Wx = np.zeros(NN)
    bandwidth_Wy = np.zeros(NN)
    for i in range(NN):
        bandwidth_Wx[i] = weight_bandwidth(Wx[i, :])
        bandwidth_Wy[i] = weight_bandwidth(Wy[i, :])
    # RF size = product of the two 1-D bandwidths
    rf_size = bandwidth_Wx * bandwidth_Wy
    model_rfsize_all.append(rf_size[high_feve_idxes])

# The per-mouse index arrays generally differ in length. Store them in an
# explicit object array (one entry per mouse): np.savez would otherwise have to
# turn a ragged list into an array, which recent NumPy versions refuse to do.
valid_idxes_obj = np.empty(len(valid_idxes_all), dtype=object)
for k, idxes in enumerate(valid_idxes_all):
    valid_idxes_obj[k] = idxes

data_dict['valid_idxes'] = valid_idxes_obj
data_dict['minimodel_poscorr_all'] = np.hstack(pos_corr_all)
data_dict['minimodel_catvar_noise_all'] = np.hstack(noisy_catvar_all)
data_dict['neural_catvar_all'] = np.hstack(neural_catvar_all)
data_dict['minimodel_catvar_all'] = np.vstack(model_catvar_all)
data_dict['rfsize'] = np.hstack(model_rfsize_all)
data_dict['pred_catvar_all'] = np.hstack(pred_catvar_all)

In [None]:
# Sanity check: shapes of the arrays pooled across mice.
for label, key in [
    ('minimodel poscorr all:', 'minimodel_poscorr_all'),
    ('minimodel catvar noise all:', 'minimodel_catvar_noise_all'),
    ('neural catvar all:', 'neural_catvar_all'),
    ('minimodel catvar all:', 'minimodel_catvar_all'),
    ('rfsize:', 'rfsize'),
    ('pred catvar all:', 'pred_catvar_all'),
]:
    print(label, data_dict[key].shape)

# figure 4h-j monkey catvar

In [None]:
# Copy the monkey category-variance results into data_dict.
fpath = os.path.join(result_path, 'catvar_monkey_result.npz')
dat = np.load(fpath, allow_pickle=True)

# (key in data_dict, field in the result file)
# NOTE(review): 'monkey_minimodel_catvar_all' and 'monkey_catvar_all' both read
# 'model_catvar' -- confirm that the second one is not meant to be another field.
for out_key, src_key in (
    ('monkey_op_all', 'op_names'),
    ('monkey_minimodel_catvar_all', 'model_catvar'),
    ('monkey_minimodel_poscorr_all', 'mean_correlation'),
    ('monkey_catvar_all', 'model_catvar'),
    ('monkey_pred_catvar', 'pred_catvar'),
):
    data_dict[out_key] = dat[src_key]

In [None]:
from minimodel.utils import weight_bandwidth

# Receptive-field size of the monkey minimodels, from the readout weights Wx / Wy.
dat = np.load(os.path.join(result_path, 'minimodel_monkey_result.npz'), allow_pickle=True)
Wx = dat['wx_all']
Wy = dat['wy_all']

# The bandwidth is measured on each 1-D profile separately, so the full outer
# product of Wy and Wx is never needed; only the number of neurons is.
NN = Wx.shape[0]
bandwidth_Wx = np.zeros(NN)
bandwidth_Wy = np.zeros(NN)
for i in range(NN):
    bandwidth_Wx[i] = weight_bandwidth(Wx[i, :])
    bandwidth_Wy[i] = weight_bandwidth(Wy[i, :])
# RF size = product of the two 1-D bandwidths, in the units weight_bandwidth returns
rf_size = bandwidth_Wx * bandwidth_Wy
data_dict['monkey_rfsize'] = rf_size
# A rescaling of the RF size was left disabled by the author:
# rfsize = np.pi * rfsize * (1.1/80)**2

feve_all = dat['feve_all']
data_dict['monkey_feve'] = feve_all

In [None]:
# Display names of the model stages. This replaces the 'op_all' entry that was
# filled from the per-mouse result files earlier in the notebook.
op_all = [
    'conv1',
    'conv1(ReLU)',
    'conv1(pool)',
    'conv2(1x1)',
    'conv2(spatial)',
    'conv2(ReLU)',
    'readout(Wxy)',
    'readout(ELU)',
]
data_dict['op_all'] = op_all

# save

In [None]:
# Write everything collected so far to one archive; each key of data_dict
# becomes the name of an array inside the file.
np.savez('figure4_results.npz', **data_dict)

# plot

In [None]:
import figure4

# Reload the saved archive and hand it, together with the output directory,
# to the plotting routine.
save_path = './outputs'
dat = np.load('figure4_results.npz', allow_pickle=True)
figure4.figure4(dat, save_path)