In [None]:
import os
import pickle

import numpy as np
import matplotlib.pyplot as plt
import pandas as pd

import funcs
import hubs
import clusters as clust

import correlations_figure as cf
import cluster_graphs_figure as cgf
import metrics_figures as mf
import hubs_multiple_trials_figure as hmtf
import figures as f

# Folder prefix (note the trailing slash) that is prepended to the names of
# the stored result/metric files read and written further down.
paper_folder = 'du_Plessis_et_al_2022/'


#### Import data:
#### The functions here require access to a set of raw data.
#### This can be loaded using 'funcs.load_data'.
#### The 'data' object is a pandas DataFrame whose rows are labelled by 'Region' and whose columns use a MultiIndex with 2 levels: 'group' and the animal name (any string).
#### The results of the analysis below are used to create most figures, and are already stored in the thresholds_results_dict.dat file in the paper folder.

In [None]:
# If the raw data file is available, load it and derive the group names from it:
# data = funcs.load_data()
# groups = data.columns.get_level_values('group').unique()

# Otherwise, define the group names by hand:
groups = [
    'naive_male',
    'naive_female',
    'trained_male',
    'trained_female',
]

In [None]:
# Threshold values to sweep: 10 evenly spaced values from 0.005 to 0.05
# (both ends included). Defined once here so the sweep cannot drift out of
# sync with a second copy.
thresholds = np.linspace(0.005, 0.05, 10)
# Passed unchanged as the threshold_type / method keyword arguments of the
# analysis helpers in the cell below.
threshold_type = 'p'
method = 'pearson'

# Per-threshold analysis results, keyed by threshold value. Filled by the
# analysis loop, or replaced when the stored results are loaded from file.
thresholds_results_dict = {}

In [None]:
# This initial analysis takes ~5 minutes to compute, and requires access to
# the original data (the `data` DataFrame, e.g. from funcs.load_data()).
# Alternatively, skip this cell and load the stored results from the .dat
# file in the cell below.

for threshold in thresholds:
    adj_mat_dict = funcs.get_adj_mat_dict_for_groups(data, groups,
                                                threshold_type=threshold_type,
                                                threshold=threshold,
                                                method=method)
    cluster_ids_dict = clust.get_cluster_ids_for_groups(data, groups,
                                                threshold_type=threshold_type,
                                                threshold=threshold,
                                                method=method)
    gr_dict = {}
    hub_dict = {}
    for group in groups:
        gr_dict[group] = cgf.create_graph(adj_mat_dict[group])
        hub_dict[group] = hubs.centrality_measures_with_hub_regions(data,
                                    group=group,
                                    threshold_type=threshold_type,
                                    threshold=threshold,
                                    method=method)
    # Tuple order matters: later cells read the entries by position
    # ([0] adjacency matrices, [1] cluster ids, [2] hub data, [3] graphs).
    thresholds_results_dict[threshold] = (adj_mat_dict, cluster_ids_dict, hub_dict, gr_dict)

# The handle is deliberately not called `f`: that name is the alias of the
# imported `figures` module and would otherwise be rebound to a closed file.
with open(paper_folder+'thresholds_results_dict.dat', 'wb') as results_file:
    pickle.dump(thresholds_results_dict, results_file)


In [None]:
# Load the stored per-threshold results instead of recomputing them.
# NOTE: unpickling can execute arbitrary code; only load .dat files that come
# from a trusted source.
# The context manager closes the file as soon as loading finishes.
with open(paper_folder+'thresholds_results_dict.dat', 'rb') as results_file:
    thresholds_results_dict = pickle.load(results_file)

In [None]:
# For every group, gather its hub data (entry [2] of each stored result tuple)
# across all thresholds and hand the list to hubs.hub_counts_for_multiple_trials.
hubs_all_thresholds = {}
for group in groups:
    per_threshold_hubs = [
        threshold_results[2][group]
        for threshold_results in thresholds_results_dict.values()
    ]
    hubs_all_thresholds[group] = hubs.hub_counts_for_multiple_trials(per_threshold_hubs)

In [None]:
def _load_pickle(path):
    """Return the object stored in the pickle file at `path`, closing the file afterwards.

    NOTE: unpickling can execute arbitrary code; only use on trusted files.
    """
    with open(path, 'rb') as pickle_file:
        return pickle.load(pickle_file)


# Stored per-group metrics lists.
metrics_naive_male_list = _load_pickle(paper_folder+'metrics_naive_male_list.dat')
metrics_trained_male_list = _load_pickle(paper_folder+'metrics_trained_male_list.dat')
metrics_naive_female_list = _load_pickle(paper_folder+'metrics_naive_female_list.dat')
metrics_trained_female_list = _load_pickle(paper_folder+'metrics_trained_female_list.dat')

# Concatenate the four lists (naive groups first, then trained) and build a
# single DataFrame from the combined records.
all_groups_metrics = (
    metrics_naive_male_list
    + metrics_naive_female_list
    + metrics_trained_male_list
    + metrics_trained_female_list
)
metrics_all_groups_df = pd.DataFrame(all_groups_metrics)

## Figure 3 -- Correlation matrices for each group

In [None]:
# Figure 3: built by the correlations_figure module for the 'regionsB' cluster group.
fig = cf.make_correlation_matrix_group_clusters_figure(cluster_group='regionsB')

## Figure 4 -- Network graphs for each group, along with centrality metrics indicating hubs

In [None]:
# Threshold whose stored results are used for the main network figure.
primary_threshold = 0.05
# The first three entries of the stored tuple are, in order, the adjacency
# matrices, the cluster ids and the hub data.
adj_mat_dict, cluster_ids_dict, hub_data_dict = thresholds_results_dict[primary_threshold][:3]

In [None]:
# Node colours: a hand-picked ordering of entries from matplotlib's 'tab20' palette.
tab20 = plt.get_cmap('tab20').colors
color_select = [1,3,5,9,11,13,19,17,4,12,16,18,2,0]
colors = [tab20[color_num] for color_num in color_select]
# Pass hub_data_dict (the primary-threshold hub data selected in the previous
# cell). The loop-local `hub_dict` only exists when the long analysis cell
# has been run, so using it fails with a NameError when the results were
# loaded from file instead.
figure_out = cgf.make_cluster_graph_4groups_figure(hub_data_dict=hub_data_dict,
                                                   adj_mat_dict=adj_mat_dict,
                                                   cluster_ids_dict=cluster_ids_dict,
                                                   node_color_map=colors,
                                                   node_layout='set_position',
                                                   node_position_file=paper_folder+'node_pos_p05_size5_2.json')

In [None]:
# When the figure is shown in an interactive mode the node positions can be
# adjusted by hand; the adjusted layout can then be written to a file as below.
# The file name is given without the paper_folder prefix, so it is resolved
# relative to the current working directory.
netgraphs_list = [group_output[1] for group_output in figure_out.values()]
position_groups = ['naive_male', 'trained_male', 'naive_female', 'trained_female']
cgf.save_node_positions('node_pos_p05_size5_2.json', netgraphs_list,
                        groups_list=position_groups)

## Figure S1 -- metrics for primary threshold, for all groups, with randomizations

In [None]:
# Figure S1: drawn from the combined metrics table of all groups.
mf.make_network_metrics_primary_threshold_figure(metrics_df=metrics_all_groups_df)

## Figure S2 -- metrics for all thresholds, for all groups, with randomizations

In [None]:
# Figure S2: drop the rows whose small_worldness exceeds 1e6 before plotting.
extreme_small_worldness = metrics_all_groups_df['small_worldness'] > 1e6
rows_to_drop = metrics_all_groups_df[extreme_small_worldness].index
mf.make_network_metrics_all_thresholds_figure(
    metrics_df=metrics_all_groups_df.drop(rows_to_drop))

## Figure S3 -- hub counts for multiple thresholds, for each group

In [None]:
# Figure S3: the plotting helper receives the full per-threshold results
# dictionary (not the hubs_all_thresholds summary computed earlier).
hmtf.make_hub_counts_4groups_figure(thresholds_results_dict)