# Notebook to calculate group differences based on paper drakesmith et al

This is the  multi-threshold permutation correction approach and consists:

1) Apply thresholds to networks and compute network metrics for all networks across all thresholds
2) Compute test statistics for each network
3) Permute across groups to get the null distrubition
4) Take the maximum test statistic across all thresholds for each permutation resulting in one summarised null statistic for each permutation. 
 

In [1]:
import os
import pandas as pd
import numpy as np
import networkx as nx
import scona as scn
import matplotlib.pyplot as plt
import functions.plotting_functions as Pfun
import functions.statistical_functions as Sfun
import seaborn as sns
sns.set_style('dark')
from decouple import config

data = config('data')

Import data.

In [57]:
lh_volume = pd.read_csv(f'{data}/lh_volume.dat',sep='\t').drop(['BrainSegVolNotVent', 'eTIV'],axis=1).rename(columns={'lh.aparc.volume':'G-Number'})
rh_volume =  pd.read_csv(f'{data}/rh_volume.dat',sep='\t').drop(['BrainSegVolNotVent', 'eTIV','rh.aparc.volume'],axis=1)

group = pd.read_csv(f'{data}/cortical_measures.csv').iloc[0:,2]

volume = pd.concat([lh_volume, rh_volume, group],axis=1)

names = list(volume.columns.drop(['G-Number','age_adjusted_group']))

centroids = pd.read_csv(f'{data}/atlas.csv')
centroids = centroids[['x.mni', 'y.mni', 'z.mni']].to_numpy()


group = volume.groupby('age_adjusted_group')
aan = group.get_group('AAN').reset_index(drop=True)
hc = group.get_group('HC').reset_index(drop=True)
wr = group.get_group('WR').reset_index(drop=True)

float64


## 1) Apply thresholds to networks and compute network metrics for all networks across all thresholds

First the threshold range needs to be decided. To do this the methodology used is taken from Bassett et al (https://doi.org/10.1073/pnas.0606005103) where the upper limit must not exceed 2 * natural log(nodes). 
Starting point of 4 is used otherwise minimum spanning tree is too large and throws up error message.

Threshold values are stored in a list.

In [3]:
threshold_value=[]
for threshold in range(4 , 100):
    aan_graphs = Sfun.create_graphs(aan.iloc[:,1:69], names, centroids, threshold=threshold)
    aan_graphs['graph_threshold'].calculate_nodal_measures()
    cal = aan_graphs['graph_threshold'].report_nodal_measures()
    
    if cal['degree'].mean() > 2 * np.log(len(aan_graphs['graph_threshold'].nodes())):
        break
    
    threshold_value.append(threshold)

Now create graphs at the pre-defined thresholds.

In [4]:
result = {
    'aan_graphs':[],
    'hc_graphs':[],
    'wr_graphs':[]
}

In [5]:
for threshold in threshold_value:
    aan_graphs = Sfun.create_graphs(aan.iloc[:,1:69], names, centroids, threshold=threshold)
    aan_graphs['graph_threshold'].__name__ = f'aan_graph_threshold_value_{threshold}'
    result['aan_graphs'].append(aan_graphs['graph_threshold'])

    hc_graphs = Sfun.create_graphs(hc.iloc[:,1:69], names, centroids, threshold=threshold)
    hc_graphs['graph_threshold'].__name__ = f'hc_graph_threshold_value_{threshold}'
    result['hc_graphs'].append(hc_graphs['graph_threshold'])

    wr_graphs = Sfun.create_graphs(wr.iloc[:,1:69], names, centroids, threshold=threshold)
    wr_graphs['graph_threshold'].__name__ = f'wr_graph_threshold_value_{threshold}'
    result['wr_graphs'].append(wr_graphs['graph_threshold'])

Finally calculate global measures for each graph at the threshold ranges.

In [6]:
measures = {}

In [7]:
for keys, values in result.items():
    for graph_object in values:
        global_measures = graph_object.calculate_global_measures()
        measures[f'{graph_object.__name__}'] = []
        measures[f'{graph_object.__name__}'].append(global_measures)

## 2) Compute test statistics for each network

Calculate the test statistic for each condition. This is going to be the difference between each global measure for two groups. 

In [43]:
test_statistics = {
    'aan_hc':{},
    'wr_hc':{},
    'wr_aan':{}
}

In [46]:
for threshold in threshold_value:
    for keys, value in measures['aan_graph_threshold_value_4'][0].items():
        statistic_aan_hc =  measures[f'aan_graph_threshold_value_{threshold}'][0][keys] - measures[f'hc_graph_threshold_value_{threshold}'][0][keys]
        statistic_wr_hc =  measures[f'wr_graph_threshold_value_{threshold}'][0][keys] - measures[f'hc_graph_threshold_value_{threshold}'][0][keys]
        statistic_wr_aan =  measures[f'wr_graph_threshold_value_{threshold}'][0][keys] - measures[f'aan_graph_threshold_value_{threshold}'][0][keys]
        
        key = f'{keys}_at_threshold_value_{threshold}'
      
        test_statistics['aan_hc'][key] = statistic_aan_hc 
        test_statistics['wr_hc'][key] = statistic_wr_hc
        test_statistics['wr_aan'][key] = statistic_wr_aan


## 3) Permute across groups to get the null distrubition

Now we permutate to create a null distrubtion


In [69]:
def permuations(permutations:int, length_of_group_1:int, length_of_group_2:int, participants:pd.DataFrame, names:list, centroids:np.float64, threshold:int) -> dict:
    measures = ['average_clustering', 'average_shortest_path_length', 'assortativity', 'modularity', 'efficiency']
    null_distribution ={
        
        'average_clustering':[],
        'average_shortest_path_length':[], 
        'assortativity':[], 
        'modularity':[], 
        'efficiency':[]
    
    }
    
    for perm in range(permutations):
        group_1_participants = participants.sample(n=length_of_group_1)
        group_2_participants = participants.sample(n=length_of_group_2)
        
        group_1_graphs = Sfun.create_graphs(group_1_participants.iloc[:,1:69], names, centroids, threshold=threshold)
        group_2_graphs = Sfun.create_graphs(group_2_participants.iloc[:,1:69], names, centroids)
        
        group_1_values = group_1_graphs['graph_threshold'].calculate_global_measures()
        group_2_graphs = group_2_graphs['graph_threshold'].calculate_global_measures()
    
        for meas in measures:
            crit_val = group_1_values[meas] -  group_2_graphs[meas]
            null_distribution[meas].append(crit_val)
            
    return null_distribution

In [61]:
aan_hc_df = pd.concat([aan, hc], ignore_index=True)
aan_wr_df = pd.concat([aan, wr], ignore_index=True)
wr_hc_df = pd.concat([wr, hc], ignore_index=True)

In [79]:
null_distribution = {
    'aan_hc':{},
    'wr_hc':{},
    'wr_aan':{}
}

In [80]:
for threshold in threshold_value:
    aan_hc_perm = permuations(10, len(aan['G-Number']), len(hc['G-Number']), aan_hc_df, names, centroids, threshold)
    aan_wr_perm = permuations(10, len(aan['G-Number']), len(wr['G-Number']), aan_wr_df, names, centroids, threshold)
    wr_hc_perm = permuations(10, len(wr['G-Number']), len(hc['G-Number']), wr_hc_df, names, centroids, threshold)
    
    key = f'thresholded_value_{threshold}'
    null_distribution['aan_hc'][key] = aan_hc_perm 
    null_distribution['wr_hc'][key] = aan_wr_perm
    null_distribution['wr_aan'][key] = wr_hc_perm

In [94]:
print(max(null_distribution['aan_hc']['thresholded_value_4']['average_shortest_path_length']))
print(test_statistics['aan_hc']['average_shortest_path_length_at_threshold_value_4'])

3.2734855136084287
-0.08077260755048243


## 4) Take the maximum test statistic across all thresholds for each permutation resulting in one summarised null statistic for each permutation. 



In [110]:
def find_max_null_stat(null_distribution:dict, group_key:str, measure_key:str, list_number:int) -> int:
    
    '''
    Function to loop through each element in a list for each thresold value.

    Parameters
    ----------
    null_distribution:dict dictionary of null_distribution
    group_key:str dictionary key for group 
    measure_key:str dictionary key for the graph theory measure.
    list_number:int list index of permutation

    Returns
    -------
    max_null_statistic:int max null statistic for that permutation.

    '''
    
    values = []
    for threshold in threshold_value:
        values.append(null_distribution[group_key][f'thresholded_value_{threshold}'][measure_key][list_number])
    
    max_null_statistic = max(values, key=abs)
    return max_null_statistic

In [None]:
null_stat_summarized = {
    
    'aan_hc':{},
    'wr_hc':{},
    'wr_aan':{}
}

In [115]:
#TODO need to get max_null_stat for each permutation into dictionary
#for perm in range(0, 10):
#    find_max_null_stat