# Preparing Data for Figure 3

This notebook saves the processed data needed to plot **Figure 3** into a `.npz` file and then generates the corresponding plots.

**Important:**  
Before running this notebook, run all of the following notebooks in the `./save_results/` directory:
- [fullmodel_mouse_saveall.ipynb](https://github.com/MouseLand/minimodel/blob/main/figures/save_results/fullmodel_mouse_saveall.ipynb)
- [fullmodel_monkey_saveall.ipynb](https://github.com/MouseLand/minimodel/blob/main/figures/save_results/fullmodel_monkey_saveall.ipynb)
- [minimodel_mouse_saveall.ipynb](https://github.com/MouseLand/minimodel/blob/main/figures/save_results/minimodel_mouse_saveall.ipynb)
- [minimodel_monkey_saveall.ipynb](https://github.com/MouseLand/minimodel/blob/main/figures/save_results/minimodel_monkey_saveall.ipynb)

These notebooks:
- Load the raw neural and stimulus data,
- Run models for each animal (mouse and monkey),
- Save the model outputs needed for plotting.

Each notebook in `./save_results/` covers one animal (mouse or monkey) and one model (full model or minimodel), and writes its results to `./save_results/outputs/`. If you skip any of them, this notebook will fail on missing result files or produce incomplete data.


In [None]:
import os
import torch
import numpy as np
import matplotlib.pyplot as plt
from minimodel import data

# Input/output locations, relative to the directory the notebook is run from.
# (data_path and weight_path are not referenced by the cells in this notebook.)
data_path = '../data'
weight_path = './checkpoints/fullmodel'
result_path = './save_results/outputs'  # .npz files produced by the save_results notebooks

device = torch.device('cuda')

# Accumulates every array that is written to figure3_results.npz at the end of the notebook.
data_dict = {}

# Figure 3a (FEVE as a function of the number of images)

## 1. mouse

In [None]:
# Full-model FEVE as a function of the number of images, one entry per mouse.
nmouse = 6
all_feve = []
for mouse_id in range(nmouse):
    fname = os.path.join(result_path, f'fullmodel_{data.mouse_names[mouse_id]}_results.npz')
    # The context manager closes the .npz archive once the arrays have been read.
    with np.load(fname, allow_pickle=True) as dat:
        # Average over axis 1 (presumably the neuron axis -- TODO confirm).
        all_feve.append(dat['feve_nstims'].mean(1))
        # Image-count grid; assumed identical for every mouse, so the last one read is kept.
        nstims = dat['nstims']

data_dict['feve_all_nstim'] = np.vstack(all_feve)  # stacked along axis 0 in mouse order
data_dict['nstim'] = nstims

## 2. monkey

In [None]:
# Monkey full model: FEVE as a function of the number of images, plus the monkey id arrays.
dat = np.load(os.path.join(result_path, 'fullmodel_monkey_results.npz'), allow_pickle=True)
data_dict.update(
    monkey_feve_all_nstim=dat['feve_nstims'],
    monkey_nstim=dat['nstims'],
    monkey_id=dat['monkey_ids'],
)

# Figure 3b (FEVE as a function of the number of neurons)

## 1. mouse

In [None]:
# Full-model FEVE as a function of the number of neurons used, one entry per mouse.
nmouse = 6
all_feve = []
for mouse_id in range(nmouse):
    fname = os.path.join(result_path, f'fullmodel_{data.mouse_names[mouse_id]}_results.npz')
    # The context manager closes the .npz archive once the arrays have been read.
    with np.load(fname, allow_pickle=True) as dat:
        all_feve.append(dat['feve_nneurons'])
        # Neuron-count grid; assumed identical for every mouse, so the last one read is kept.
        nneurons = dat['nneurons']

data_dict['NNs'] = nneurons
data_dict['feve_all_nn'] = all_feve  # list: one array per mouse, in mouse order

## 2. monkey

In [None]:
# Monkey full model: FEVE as a function of the number of neurons used.
dat = np.load(os.path.join(result_path, 'fullmodel_monkey_results.npz'), allow_pickle=True)
neuron_numbers = dat['nneurons']
feve_nneurons = dat['feve_nneurons']
monkey_ids = dat['nneuron_monkey_ids']  # loaded for reference; not stored in data_dict

# feve_nneurons[k] is indexed by seed along its first axis for the k-th neuron count;
# concatenate the per-seed results into one array per neuron count.
feves = [
    np.hstack(list(np.array(feve_nneurons[k])))
    for k in range(len(neuron_numbers))
]

data_dict['monkey_NNs'] = dat['nneurons']
data_dict['monkey_feve_all_nn'] = feves

# Figure 3d-f

## 1. mouse


In [None]:
# Per-mouse FEV and full-model FEVE. These are kept as lists because the arrays may have
# different lengths for different mice (e.g. different neuron counts -- TODO confirm).
nmouse = 6
fev_all = []
fullmodel_feve_all = []
for mouse_id in range(nmouse):
    fname = os.path.join(result_path, f'fullmodel_{data.mouse_names[mouse_id]}_results.npz')
    # The context manager closes the .npz archive once the arrays have been read.
    with np.load(fname, allow_pickle=True) as dat:
        fev_all.append(dat['fev'])
        fullmodel_feve_all.append(dat['fullmodel_feve_all'])
data_dict['fev_all'] = fev_all
data_dict['fullmodel_feve_all'] = fullmodel_feve_all

In [None]:
# Per-mouse minimodel results: FEVE ('feve_all') and 'wc_all', one list entry per mouse.
nmouse = 6  # same mouse count as the full-model cells above
feve_all = []
wc_all = []
for mouse_id in range(nmouse):
    fname = os.path.join(result_path, f'minimodel_{data.mouse_names[mouse_id]}_result.npz')
    # The context manager closes the .npz archive once the arrays have been read.
    with np.load(fname, allow_pickle=True) as dat:
        feve_all.append(dat['feve_all'])
        wc_all.append(dat['wc_all'])
data_dict['minimodel_feve_all'] = feve_all
data_dict['minimodel_wc_all'] = wc_all

## 2. monkey

In [None]:
# Monkey minimodel results: the 'wc_all' and 'feve_all' arrays.
dat = np.load(os.path.join(result_path, 'minimodel_monkey_result.npz'), allow_pickle=True)
data_dict.update(
    monkey_wc_all=dat['wc_all'],
    monkey_minimodel_feve_all=dat['feve_all'],
)

In [None]:
# Monkey full model: FEVE and FEV arrays for panels d-f.
# 'monkey_id' is read from this same file earlier in the notebook; setting it again here
# keeps the cell self-contained without changing the stored value.
dat = np.load(os.path.join(result_path, 'fullmodel_monkey_results.npz'), allow_pickle=True)
data_dict.update(
    monkey_feve_all=dat['fullmodel_feve_all'],
    monkey_id=dat['monkey_ids'],
    monkey_fev_all=dat['fev_all'],
)

# save

In [None]:
# Save data_dict to figure3_results.npz.
# np.savez converts every value with np.asanyarray. Several entries are Python lists of
# per-animal arrays whose shapes can differ; NumPy >= 1.24 refuses to build such "ragged"
# arrays implicitly and raises ValueError. Such values are packed into 1-D object arrays
# explicitly instead (they are read back with allow_pickle=True).
def _to_savable(value):
    """Return `value` as an ndarray, falling back to a 1-D object array for ragged sequences."""
    try:
        return np.asanyarray(value)
    except ValueError:
        packed = np.empty(len(value), dtype=object)
        # Assign element by element so NumPy never tries to broadcast the items together.
        for idx, item in enumerate(value):
            packed[idx] = item
        return packed


np.savez('figure3_results.npz', **{key: _to_savable(val) for key, val in data_dict.items()})

# plot

In [None]:
import figure3

save_path = './outputs'
# figure3.figure3 receives save_path as its output location (assumed to be a directory it
# writes into); create it so a fresh checkout does not fail on a missing folder.
os.makedirs(save_path, exist_ok=True)

# allow_pickle is needed because the file may hold object arrays (e.g. the per-mouse lists).
# The context manager closes the archive once plotting is done.
with np.load('figure3_results.npz', allow_pickle=True) as dat:
    figure3.figure3(dat, save_path)