# Preparing Data for Figure 2

This notebook collects the processed data needed to plot **Figure 2**, saves it to `figure2_results.npz`, and then generates the corresponding plots.

**Important:**  
Before running this notebook, run the following notebooks in the `./save_results` directory:
- [fullmodel_mouse_saveall.ipynb](https://github.com/MouseLand/minimodel/blob/main/figures/save_results/fullmodel_mouse_saveall.ipynb)
- [fullmodel_monkey_saveall.ipynb](https://github.com/MouseLand/minimodel/blob/main/figures/save_results/fullmodel_monkey_saveall.ipynb)

These notebooks:
- Load the raw neural and stimulus data,
- Run models for each animal (mouse and monkey),
- Save the model outputs needed for plotting.

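Each notebook in `./save_results` corresponds to a specific condition or model variant. Skipping any of them may result in missing or incomplete data when running this notebook. The Figure 2k–l cell additionally reads `texturenet_accuracy.npy` and `top1_top5_summary.npy` from `./save_results/outputs`, so make sure those files exist as well.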


In [None]:
import os
import torch
import numpy as np
import matplotlib.pyplot as plt
from minimodel import data

# Prefer the GPU when one is present, but fall back to the CPU so the notebook still
# runs on CPU-only machines (none of the cells below move tensors to `device`).
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Collects every array needed for Figure 2; written to figure2_results.npz further down.
data_dict = {}

data_path = '../data'                    # raw data directory (not read by the cells below)
weight_path = './checkpoints/fullmodel'  # model checkpoints (not read by the cells below)
result_path = './save_results/outputs'   # where the ./save_results notebooks write their outputs

# figure 2b-d (performance change with width)

In [None]:
# Channel-count grid swept for nconv1 and nconv2 (the same values are used for both).
nconv1_list = [8, 16, 32, 64, 128, 192, 256, 320, 384, 448]
nconv2_list = list(nconv1_list)

# One FEVE-vs-width array per mouse (first six entries of data.mouse_names), each
# averaged over axis 2 (presumably neurons — TODO confirm), then stacked along a new
# leading mouse axis.
feve_width = np.stack([
    np.load(
        os.path.join(result_path, f'fullmodel_{data.mouse_names[mouse_id]}_results.npz'),
        allow_pickle=True,
    )['feve_width'].mean(axis=2)
    for mouse_id in range(6)
])

data_dict.update(
    feve_our_model_vary_width=feve_width,
    nconv1=nconv1_list,
    nconv2=nconv2_list,
)

# figure 2e (mouse conv1 kernels)

In [None]:
# First-layer (conv1) weights of the full model for the first mouse, used for the kernel panel.
mouse_id = 0
mouse_results_file = os.path.join(result_path, f'fullmodel_{data.mouse_names[mouse_id]}_results.npz')
dat = np.load(mouse_results_file, allow_pickle=True)
data_dict['conv1_W'] = dat['fullmodel_conv1_W']

# Monkey

## figure 2f-g (performance change with width)

In [None]:
# Monkey FEVE over the width grid, averaged over axis 2 like the mouse data
# (presumably neurons — TODO confirm).
monkey_results_file = os.path.join(result_path, 'fullmodel_monkey_results.npz')
dat = np.load(monkey_results_file, allow_pickle=True)
monkey_feve_width = dat['feve_width']
data_dict['monkey_all_width_eve'] = monkey_feve_width.mean(axis=2)

## figure 2i (kernels)

In [None]:
# Monkey first-layer (conv1) weights for the kernel panel, plotted in the next cell.
# Reloads the monkey results file so this cell does not depend on the previous one.
# (A leftover `mouse_id = 0` from the mouse cell was removed; nothing here uses it.)
dat = np.load(os.path.join(result_path, 'fullmodel_monkey_results.npz'), allow_pickle=True)
data_dict['monkey_conv1_W'] = dat['conv1_W']

In [None]:
# Monkey conv1 kernels on a 4x4 grid, in a hand-picked display order.
conv1_W = data_dict['monkey_conv1_W']
# Permutation of kernel indices 0-15 giving the display order (row-major over the grid);
# assumes conv1_W holds at least 16 kernels.
isort = [0, 1, 11, 9, 5, 8, 13, 7, 4, 6, 12, 14, 15, 10, 2, 3]
fig, ax = plt.subplots(4, 4, figsize=(8, 8))
for panel, kernel_idx in zip(ax.flat, isort):
    # Symmetric limits put zero weight at the centre of the diverging RdBu_r colormap.
    panel.imshow(conv1_W[kernel_idx], cmap='RdBu_r', vmin=-0.15, vmax=0.15)
    panel.axis('off')
fig.suptitle('Monkey conv1 kernels')
plt.show()

# figure 2k-l

In [None]:
# Classification accuracy summaries for panels k-l. Each .npy file holds a pickled
# object (presumably a dict saved with np.save); `[()]` unwraps the stored object.
texturenet_file = os.path.join(result_path, 'texturenet_accuracy.npy')
imagenet_file = os.path.join(result_path, 'top1_top5_summary.npy')
texturenet_acc = np.load(texturenet_file, allow_pickle=True)[()]['accuracy']
imagenet_accuracy = np.load(imagenet_file, allow_pickle=True)[()]['top1']
data_dict.update(
    texturenet_accuracy=texturenet_acc,
    imagenet_accuracy=imagenet_accuracy,
)

# save

In [None]:
# Bundle every collected entry into a single archive; the plotting cell reads it back.
results_filename = 'figure2_results.npz'
np.savez(results_filename, **data_dict)

# plot

In [None]:
import figure2  # project plotting module for Figure 2

# Reload the saved archive so plotting depends only on what was written to disk.
dat = np.load('figure2_results.npz', allow_pickle=True)
save_path = './outputs'
# Make sure the output directory exists before handing it to the plotting routine
# (assumes figure2.figure2 writes files under save_path — TODO confirm).
os.makedirs(save_path, exist_ok=True)
figure2.figure2(dat, save_path)