# Code to generate main manuscript figures

#### Note: Running this notebook requires the output dict of all analyses, which can either be downloaded from the OSF repository (save it into `save_dir`) or generated via `analysis.py` (see README).

In [None]:
import os
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
from all_tnn.analysis.util import * 
import pickle
# Import custom modules
from all_tnn.analysis.config import *
from all_tnn.analysis.visualization.acc_maps_visualization import plot_bar_plot_from_df
from all_tnn.analysis.visualization.acc_spatial_loss import generate_analysis_df
from all_tnn.analysis.visualization.energy_efficiency import plot_energy_consumption_across_epochs_lineplot, plot_stacked_energy_map_energy_vs_eccentricity
from all_tnn.analysis.visualization.layer_visualization import visualize_layer
from all_tnn.analysis.visualization.smoothness_entropy_visualization import *
from all_tnn.analysis.glm_analysis import run_full_GLM_analysis

# Set plotting style
# `scienceplots` is imported only for its side effect (it is not referenced elsewhere);
# presumably it registers the 'science'/'nature'/'ieee' styles used on the next line, so it
# must stay above plt.style.use. 'no-latex' presumably avoids needing a LaTeX install — confirm.
import scienceplots  
plt.style.use(['science', 'nature', "ieee", 'no-latex'])
from all_tnn.analysis.visualization.colors import DECREASING_6COLORS, COLOR_THEME_WITH_ALPHA_SWEEP
# Skip the first colour of the alpha-sweep theme; the rest is the per-model palette passed
# to the Figure 2 plotting helpers below.
color_palette = COLOR_THEME_WITH_ALPHA_SWEEP[1:] 

In [None]:
# Setting Paths
# Root of all analysis outputs (the `save_dir` mentioned in the note at the top).
# Defaults to the original authors' absolute path; set the ALL_TNN_SAVE_DIR environment
# variable to point at a local copy (e.g. the OSF download) without editing this cell.
save_root = os.environ.get(
    'ALL_TNN_SAVE_DIR',
    '/share/klab/datasets/TNN_paper_save_dir/All-TNN_public',
)
base_src_dir = os.path.join(save_root, '_analyses_data')

# One output directory per figure; Figure 1 reads its inputs from and writes into `src`.
# Note: these paths have NO trailing separator, so build file paths with os.path.join.
directories = [
    os.path.join(base_src_dir, 'figure1', 'src'),
    os.path.join(base_src_dir, 'figure2'),
    os.path.join(base_src_dir, 'figure3'),
    os.path.join(base_src_dir, 'figure4'),
    os.path.join(base_src_dir, 'figure5'),
]

# Create any missing directories (existing ones are left as they are) and report them.
for directory in directories:
    os.makedirs(directory, exist_ok=True)
    # Convert to absolute path for better readability
    absolute_path = os.path.abspath(directory)
    print(f"Created directory: {absolute_path}")

plot_path_fig1, plot_path_fig2, plot_path_fig3, plot_path_fig4, plot_path_fig5 = directories

## Figure 1 
### Categorization performance 
Classification accuracy of all All-TNNs and control models on the ecoset test set. 
Running these cells plots the figures and saves them to the corresponding figure directory under `save_dir`.

In [None]:
# Models compared in Figure 1: the CNN and LCN controls plus All-TNNs at alpha = 1, 10, 100.
# NOTE: this rebinds MODEL_NAMES for the rest of the notebook (the Figure 2 cells reuse it).
MODEL_NAMES = [
    "CNN_lr_0.05",
    "LCN_lr_0.05",
    "TNN_alpha_1_lr_0.05",
    "TNN_alpha_10_lr_0.05",
    "TNN_alpha_100_lr_0.05",
]

# Gather per-model / per-seed results into one DataFrame; the smoothness cell below reuses `df`.
df = generate_analysis_df(
    base_src_dir_path=plot_path_fig1,
    MODEL_NAMES=MODEL_NAMES,
    seeds_range=SEEDS_RANGE,
    models_epochs_dict=MODELS_EPOCHS_DICT,
    MODEL_NAMES_TO_PLOT=MODEL_NAMES_TO_PLOT,
    results_file_name='acc_smoothness_loss.pickle',  # or default is multi_models_neural_dict.pickle
)

# os.path.join guarantees a separator between directory and file name; plain `+` would
# write '.../figure1/srcaccuracy_compare_across_alphas.pdf' outside the figure directory.
plot_bar_plot_from_df(
    df,
    os.path.join(plot_path_fig1, 'accuracy_compare_across_alphas.pdf'),
    x="Model", y="Accuracy",
    title="Categorisation performance",
    color3_start_id=1,
    show_plot=True,
    figsize=(3.54, 2),
)

### Spatial smoothness
Spatial smoothness is calculated as 1 / (average cosine distance between the weights of neighbouring units), for all models.

In [None]:
# Spatial smoothness per model from the same `df` built in the accuracy cell, on a log-scaled
# y-axis. os.path.join keeps the PDF inside the Figure 1 directory (plain `+` dropped the separator).
plot_bar_plot_from_df(
    df,
    os.path.join(plot_path_fig1, 'mean_cosdist_compare_across_alphas.pdf'),
    x="Model", y="Spatial Smoothness",
    title="Spatial smoothness",
    color3_start_id=1,
    show_plot=True,
    log_scale=True,  #! log scale
    figsize=(3.54, 2),
)

## Figure 2
### Orientation selectivity maps, entropy of first layer orientation selectivity, and category selectivity maps for the All-TNN ($\alpha = 10$)

In [None]:
# Layer-0 selectivity maps for the All-TNN (alpha = 10), seed 1.
seed = 1
# NOTE(review): 300 names the data folder and is passed to visualize_layer; it appears to be
# the training epoch these neural-level analyses were run at — confirm.
neural_epoch = 300
# Neural-level outputs sit next to `_analyses_data` under the same root, so derive the folder
# from base_src_dir instead of repeating the absolute path (same value by default).
datapath = os.path.join(
    os.path.dirname(os.path.normpath(base_src_dir)),
    'neural_level_analysis', str(neural_epoch), f'seed{seed}',
)
# NOTE(review): pickle.load can execute arbitrary code; only load trusted analysis outputs.
with open(os.path.join(datapath, 'all_multi_models_neural_dict.pickle'), 'rb') as handle:
    output_dict = pickle.load(handle)
# NOTE(review): CATEGORY_STATS is not read by any code visible in this notebook — confirm before removing.
CATEGORY_STATS = 0
model_name = "TNN_alpha_10_lr_0.05"
vis = visualize_layer(
    output_dict[model_name], neural_epoch, layer_i=0, analysis_dir=plot_path_fig2,
    model_name=model_name, layer=None, save=True, show=True,
)

### Radial entropy profile

In [None]:
# Radial entropy profile for every model in MODEL_NAMES (the Figure 1 list if that cell has
# run; otherwise presumably the list star-imported from all_tnn.analysis.config).
# Derived from base_src_dir like the other cells (same folder as the original absolute path).
datapath = os.path.join(
    os.path.dirname(os.path.normpath(base_src_dir)), 'neural_level_analysis', '300',
)
# Only seed 1 is loaded; add seeds here to include them in `all_data`.
seeds_to_load = [1]

all_data = []
for seed in seeds_to_load:
    with open(os.path.join(datapath, f'seed{seed}', 'all_multi_models_neural_dict.pickle'), 'rb') as handle:
        all_data.append(pickle.load(handle))
ent_dict = calculate_radial_entropy(all_data, MODEL_NAMES)
plot_radial_entropy(ent_dict, color_palette, MODEL_NAMES, plot_path_fig2, save=True, show=True)

### Smoothness of orientation selectivity and category selectivity maps



In [None]:
# Both cluster-size analyses share the same leading inputs: the per-seed neural dicts loaded
# above, the colour palette, the model list and the Figure 2 output directory.
cluster_inputs = (all_data, color_palette, MODEL_NAMES, plot_path_fig2)
plot_cluster_size(*cluster_inputs, stats=False, save=True, show=True)
cluster_size_vs_eccentricity(*cluster_inputs, save=True, show=True)

## Figure 3
### Energy consumption
- For Figures 3A & 3B, run the following cells. Figure 3C is generated by running `analysis.py`, which saves it in the same directory as each seed’s models.

In [None]:
# Epochs shown on the line plot: 35, then every 50th epoch from 50 through 600.
epochs_to_plot = [35, *range(50, 601, 50)]

plot_energy_consumption_across_epochs_lineplot(
    model_name_path_dict=MODEL_NAME_PATH_DICT,
    alphas=ALPHAS,
    seed_range=SEEDS_RANGE,
    fixed_epochs=epochs_to_plot,
    pre_or_postrelu='postrelu',
    NORM_PREV_LAYER=True,
    NORM_LAYER_OUT=True,
    save_fig_path=plot_path_fig3,
)

In [None]:
# Stacked energy maps and energy-vs-eccentricity, post-ReLU, 'total' energy only.
# NOTE(review): NORM_LAYER_OUT is False here but True in the epoch line plot — confirm intended.
plot_stacked_energy_map_energy_vs_eccentricity(
    model_name_path_dict=MODEL_NAME_PATH_DICT,
    alphas=ALPHAS,
    seed_range=SEEDS_RANGE,
    models_epochs_dict=MODELS_EPOCHS_DICT,
    pre_or_postrelu='postrelu',
    prefix_list=['ali_'],
    energy_consumption_types=['total'],
    NORM_PREV_LAYER=True,
    NORM_LAYER_OUT=False,
    save_fig_path=plot_path_fig3,
)
    

## Figure 4
### Human object-specific biases are predicted by animacy and real-world size. 
 Non-negative least squares GLM analysis on the averaged human ADM shows significant unique variance explained by animacy and real-world size. 

In [None]:
# Non-negative least-squares GLM on the averaged human ADM (1000 permutations).
# Inputs are read from base_src_dir, the same root every other cell uses, rather than a
# separate relative './save_dir/_analyses_data' path that disagreed with it.
glm_results = run_full_GLM_analysis(
    base_path=base_src_dir,
    data_filename='vicky_adm_dict_600_1_pearsonr_spearmanr.pkl',
    plot_path=plot_path_fig4,
    model_type='average_human_adm',
    predictor_names=['animate', 'size', 'spiky'],  # or a custom list of predictors (see make_predictors() function)
    num_permutations=1000,
    verbose=True,
)

### Spatial biases in human and model behaviour

In [None]:
# Model–human agreement in spatial behavioural biases ('pearsonr' column), read from the
# Figure 4 directory. os.path.join is required: plot_path_fig4 has no trailing separator,
# so `+` produced '.../figure4df_behavior_agreements.csv' outside the directory.
df_behavior_agreements = pd.read_csv(os.path.join(plot_path_fig4, 'df_behavior_agreements.csv'))
plot_bar_plot_from_df(
    df_behavior_agreements,
    os.path.join(plot_path_fig4, 'behaviour_agreements.pdf'),
    x="Model", y='pearsonr', title="Behaviour Agreement Analysis",
    show_plot=False, color3_start_id=1, hline=None,
    figsize=(3.54, 2),
)

### Agreement with object-specific biases in human behaviour

In [None]:
# Model agreement with object-specific human biases ('spearmanr' column), read from the
# Figure 4 directory. os.path.join is required: plot_path_fig4 has no trailing separator,
# so `+` produced '.../figure4df_adm_agreement.csv' outside the directory.
df_adm_agreement = pd.read_csv(os.path.join(plot_path_fig4, 'df_adm_agreement.csv'))
plot_bar_plot_from_df(
    df_adm_agreement,
    os.path.join(plot_path_fig4, 'adm_agreements.pdf'),
    x="Model", y='spearmanr', title="ADM Agreement Analysis",
    show_plot=False, color3_start_id=1, hline=None,
    figsize=(3.54, 2),
)

## Figure 5

- Figure 5 is produced by re-running the code for Figures 1–4 for the self-supervised (SimCLR) All-TNN ($\alpha = 10$) and the supervised All-TNN ($\alpha = 10$).

In [None]:
# Figure 5: supervised All-TNN (alpha = 10) vs. the self-supervised (SimCLR) All-TNN.
# Set this in analysis/config.py as well. Note the Figure 1 cell reassigns MODEL_NAMES,
# so re-running it after this cell restores the Figure 1 model list.
MODEL_NAMES = [
    "TNN_alpha_10_lr_0.05",
    "TNN_simclr_finetune",
]