In [None]:
import pandas as pd
import numpy as np
import seaborn as sns
import networkx as nx
import matplotlib.pyplot as plt
from sklearn.preprocessing import MinMaxScaler
from sklearn.preprocessing import StandardScaler
# Global seaborn style for the notebook. The bare sns.axes_style("whitegrid")
# call that used to follow only returned a parameter dict and changed nothing,
# so it has been dropped.
sns.set_style('whitegrid')
from networkx.algorithms import community
import itertools
from networkx.algorithms.community.quality import modularity
from sklearn.metrics import jaccard_score
import matplotlib.colors as mcolors


## Functions:

In [None]:
def scale_adjMatrix(adjM, sc):
    """Rescale the strict upper triangle of an adjacency matrix.

    The entries above the main diagonal are flattened into a single column,
    fitted/transformed with one scaler, and written back. The diagonal and the
    lower triangle are copied through untouched, so the returned matrix is in
    general no longer symmetric.

    Parameters
    ----------
    adjM : np.ndarray
        Adjacency (similarity) matrix. It is not modified.
    sc : str
        'standard' to use StandardScaler, 'minmax' to use MinMaxScaler.

    Returns
    -------
    np.ndarray
        Copy of ``adjM`` whose upper-triangle entries have been rescaled.

    Raises
    ------
    ValueError
        If ``sc`` is neither 'standard' nor 'minmax'.
    """
    if sc == 'standard':
        scaler = StandardScaler()
    elif sc == 'minmax':
        scaler = MinMaxScaler()
    else:
        # Previously an unknown name fell through and crashed later with an
        # UnboundLocalError on `scaler`.
        raise ValueError(f"Unknown scaler {sc!r}; expected 'standard' or 'minmax'")

    # True strictly above the diagonal: these are the entries that get scaled.
    upper = np.triu(np.ones_like(adjM, dtype=bool), k=1)
    adjM_sc = adjM.copy()
    adjM_sc[upper] = scaler.fit_transform(adjM[upper].reshape(-1, 1)).flatten()

    return adjM_sc

In [None]:
def create_PSN(adjM, nodelist, phi, s, title, fused, savefig):
    """Build a thresholded similarity network (PSN) from a matrix and draw it.

    Every node pair (i, j) with i < j is visited; an edge weighted by
    ``adjM[i, j]`` is added when that value is strictly greater than ``phi``.
    All nodes are kept, including those that end up without edges.

    Parameters
    ----------
    adjM : np.ndarray
        Square similarity matrix; only entries above the diagonal are read.
    nodelist : sequence of str or falsy
        Per-node labels used to colour the nodes. Each label must be a key of
        the colour table below ('E2E2', 'E2E3', 'E3E3', 'E2E4', 'E4E4', 'E3E4').
        Presumably given in node order -- verify against the caller.
        A falsy value selects a single colour based on ``fused``.
    phi : float
        Similarity cutoff for adding an edge.
    s : int
        Seed passed to the spring layout.
    title : str
        Figure title prefix; also used as the SVG file stem when saving.
    fused : bool
        Selects the single node colour (green if true, blue otherwise) when
        ``nodelist`` is falsy.
    savefig : bool
        When true the figure is written to ``'<title>.svg'``.

    Returns
    -------
    (nx.Graph, list)
        The thresholded graph and all upper-triangle similarities in
        row-major order (including those at or below the cutoff).
    """
    G = nx.from_numpy_array(adjM)
    PSN = nx.Graph()
    PSN.add_nodes_from(G.nodes)

    similarities = []
    n = len(G.nodes)

    for i in range(n):
        for j in range(i + 1, n):
            sim = adjM[i, j]
            similarities.append(sim)
            if sim > phi:
                PSN.add_edge(i, j, weight=sim)

    plt.figure(figsize=(12, 12))
    plt.title(title + ', cutoff: ' + str(phi) + ', n=' + str(n))

    pos = nx.spring_layout(PSN, seed=s, k=1/2)

    # Resolve the node colour(s) first and draw the nodes exactly once.
    # Previously the labelled branch drew the nodes and then fell through to a
    # second draw call that used a variable only set in the unlabelled branch.
    if nodelist:
        node_cols = {'E2E2':'#177245', 'E2E3':'#6aa84f','E3E3':'#0b5394', 'E2E4': '#f2a85a', 'E4E4':'#cc0000', 'E3E4':'#800080'}
        node_color = [mcolors.hex2color(node_cols[gtype]) for gtype in nodelist]
    elif fused:
        node_color = '#38761d'
    else:
        node_color = '#2d6b87'

    nx.draw_networkx_nodes(PSN, pos=pos, node_size=80, alpha=0.7, node_color=node_color)
    nx.draw_networkx_edges(PSN, pos=pos, width=2, alpha=1/10)
    sns.set_style('white')
    plt.grid(False)
    plt.gca().set_facecolor('white')

    if savefig:
        plt.savefig(f'{title}.svg')

    return PSN, similarities

In [None]:
def plot_similarities(sim, title, fused, savefig):
    """Plot a histogram of pairwise similarities and mark the top-10 % cutoff.

    Parameters
    ----------
    sim : sequence of float
        Similarity values. NaN entries are ignored for both the histogram and
        the cutoff.
    title : str
        Figure title; also used as the SVG file stem when saving.
    fused : bool
        True draws the histogram in green, False in blue.
    savefig : bool
        When true the figure is written to ``title + '.svg'``.

    Raises
    ------
    ValueError
        If ``sim`` holds no non-NaN value.

    Prints the NaN-aware mean and the value found one tenth of the way into
    the descending-sorted similarities.
    """
    values = np.asarray(sim, dtype=float)
    values = values[~np.isnan(values)]  # sorting with NaN present is ill-defined
    if values.size == 0:
        raise ValueError('sim contains no non-NaN similarity values to plot')

    # Style has to be configured before the axes are created, otherwise the
    # first figure drawn after a kernel restart misses the font scaling.
    sns.set(font_scale=1.2)
    sns.set_style('whitegrid')

    plt.figure(figsize=(8, 8))
    plt.title(title)
    plt.xlabel('Similarities')
    plt.grid(False)
    plt.gca().set_facecolor('white')

    if fused:  # green for the fused network, blue for a unimodal one
        col = '#668a56'
    else:
        col = '#668d9f'

    sns.histplot(values, bins=100, color=col)

    # Value located 10 % of the way into the descending-sorted similarities.
    sorted_sim = np.sort(values)[::-1]
    top_10_ind = int(0.1 * len(sorted_sim))
    top_10_cutoff = sorted_sim[top_10_ind]

    plt.axvline(top_10_cutoff, color='k', linestyle='dashed', linewidth=2)
    # x is in data units, y is a fraction of the axes height, so the label is
    # visible whatever the bin counts are (it used to sit at a fixed count of 500).
    plt.text(top_10_cutoff + 0.01, 0.9, str(round(top_10_cutoff, 3)),
             va='bottom', ha='left', transform=plt.gca().get_xaxis_transform())

    if savefig:
        plt.savefig(title + '.svg')

    plt.show()

    print('mean: ' + str(np.nanmean(sim)))
    print('Top 10 percent similarities: ' + str(top_10_cutoff))
    

In [None]:
def fuse_PSNs(adjMatrices):
    """Fuse several adjacency matrices by taking their element-wise mean.

    Parameters
    ----------
    adjMatrices : sequence of np.ndarray
        Adjacency matrices to fuse; all must have the same shape.

    Returns
    -------
    np.ndarray
        Matrix whose entries are the average of the corresponding entries of
        the inputs.

    Raises
    ------
    ValueError
        If no matrix is given or the shapes differ. (Previously both cases
        fell through to the return statement and failed with an
        UnboundLocalError.)
    """
    if len(adjMatrices) == 0:
        raise ValueError('fuse_PSNs needs at least one adjacency matrix')

    shapes = {np.shape(adjM) for adjM in adjMatrices}
    if len(shapes) != 1:
        raise ValueError(
            f'All adjacency matrices must share one shape, got {sorted(shapes)}'
        )

    # Average every entry across the matrices.
    return np.mean(adjMatrices, axis=0)

In [None]:
# Seed handed to the spring-layout calls in the plotting functions so that
# node positions are the same on every run.
s = 100

## PSN implementation:

In [None]:
# Each cell below loads one precomputed adjacency matrix from CSV and rescales
# its upper triangle with StandardScaler (via scale_adjMatrix) so the matrices
# are on a comparable scale before fusion.
# File names suggest modality + similarity metric (e.g. "Jacc", "Eucl", "Cos",
# "Pears") -- presumably Jaccard / Euclidean / cosine / Pearson; verify.
adjM_APOE = np.loadtxt('Files_from_tsd/adjMs/adjM_APOE_Jacc.csv', delimiter=',')
adjM_APOE_sc = scale_adjMatrix(adjM_APOE, 'standard')

In [None]:
adjM_eucl = np.loadtxt('Files_from_tsd/adjMs/adjM_V1_Eucl.csv', delimiter=',')
adjM_eucl_sc = scale_adjMatrix(adjM_eucl, 'standard')

In [None]:
adjM_cos = np.loadtxt('Files_from_tsd/adjMs/adjM_V1_Cos.csv', delimiter=',')
adjM_cos_sc = scale_adjMatrix(adjM_cos, 'standard')

In [None]:
adjM_pears = np.loadtxt('Files_from_tsd/adjMs/adjM_V1_Pears.csv', delimiter=',')
adjM_pears_sc = scale_adjMatrix(adjM_pears, 'standard')

In [None]:
adjM_clin = np.loadtxt('Files_from_tsd/adjMs/adjM_clinical_Cos.csv', delimiter=',')
adjM_clin_sc = scale_adjMatrix(adjM_clin, 'standard')

In [None]:
adjM_ptau = np.loadtxt('Files_from_tsd/adjMs/adjM_pTau_Eucl.csv', delimiter=',')
adjM_ptau_sc = scale_adjMatrix(adjM_ptau, 'standard')

In [None]:
# Quick look at the matrix dimensions.
adjM_ptau.shape

### Fuse all:

In [None]:
# Element-wise mean of the four standardised matrices (the cosine matrix is
# the one used for the "V1" data here; the Euclidean and Pearson variants are
# not part of the fusion).
fusedM_APOE_cantab_cos_clinical_ptau = fuse_PSNs([adjM_APOE_sc, adjM_cos_sc, adjM_clin_sc, adjM_ptau_sc])

In [None]:
# Min-max rescale the fused upper triangle so the cutoff below is applied to
# values in [0, 1]. The name is rebound to the rescaled matrix.
fusedM_APOE_cantab_cos_clinical_ptau = scale_adjMatrix(fusedM_APOE_cantab_cos_clinical_ptau, 'minmax')

In [None]:
fusedM_APOE_cantab_cos_clinical_ptau.shape

In [None]:
# Arguments: no per-node labels (False), cutoff 0.632, layout seed s,
# title, fused=True (green nodes), savefig=False.
# NOTE(review): create_PSN appends ', cutoff: <phi>' to the title itself, so
# the cutoff shows up twice in the rendered title.
fusedPSN_all, sims_all = create_PSN(fusedM_APOE_cantab_cos_clinical_ptau, False, 0.632, s, 'Fused all, cutoff 0.632', True, False)

In [None]:
# Histogram of all upper-triangle similarities of the fused matrix.
plot_similarities(sims_all, 'similarities', True, False)

## Community detection

### Louvain communities:

In [None]:
def over_k_louvain(G, k, seed=None):
    """Run Louvain community detection and keep communities with >= k nodes.

    Parameters
    ----------
    G : nx.Graph
        Graph to partition.
    k : int
        Minimum community size (inclusive) for a community to be kept.
    seed : int or None, optional
        Forwarded to ``louvain_communities``. The default ``None`` keeps the
        previous unseeded behaviour; pass an integer for reproducible results.

    Returns
    -------
    (list of set, list)
        The communities with at least ``k`` members, and the nodes of ``G``
        (in graph order) that are in none of them.
    """
    louv_comm = nx.community.louvain_communities(G, seed=seed)

    comm_overk = [comm for comm in louv_comm if len(comm) >= k]

    # Set membership keeps the leftover scan linear in the number of nodes
    # (it used to test against a list).
    covered = set().union(*comm_overk)
    nodes_outside_comm = [node for node in G.nodes if node not in covered]

    return comm_overk, nodes_outside_comm

#### Fused all:

In [None]:
over5_fused_all, rest_fused_all = over_k_louvain(fusedPSN_all, 5)

In [None]:
for i, comm in enumerate(over5_fused_all):
    if len(comm) < 20:
        over5_fused_all.remove(comm)
    print(len(comm), i)

In [None]:
len(over5_fused_all)

### Plot and mark communities:

In [None]:
# Eight colours from seaborn's 'muted' palette -- the same palette request
# create_PSN_with_communities makes for colouring communities.
palette = sns.color_palette('muted', 8)

In [None]:
# Inspect the second colour of the palette.
palette[1]

In [None]:
def create_PSN_with_communities(G, size, seed, title, savefig):
    """Detect Louvain communities in ``G`` and draw the graph coloured by community.

    Communities with at least ``size`` members each get their own colour;
    all remaining nodes are drawn in grey.

    Parameters
    ----------
    G : nx.Graph
        Network to analyse and draw.
    size : int
        Minimum community size passed to ``over_k_louvain``.
    seed : int
        Seed for the spring layout (node positions only).
    title : str
        Text inserted into the figure title and the SVG file name.
    savefig : bool
        When true the figure is written to ``'Communitites_<title>.svg'``.

    Returns
    -------
    list of set
        The retained communities, largest first. Position ``i`` in this list
        is drawn with palette colour ``i`` (cycling when there are more
        communities than colours).
    """
    communities, rest_nodes = over_k_louvain(G, size)

    # Order by size, largest first. A plain sorted() on sets compares by
    # subset relation, which gives no meaningful order for disjoint communities.
    communities = sorted(communities, key=len, reverse=True)

    col_list = sns.color_palette('muted', 8)

    plt.figure(figsize=(10, 10))
    plt.title(f'PSN with communities detected for {title}, minimum comm size={size}')
    pos = nx.spring_layout(G, seed=seed, k=1/15)

    for i, comm in enumerate(communities):
        # Cycle through the palette so more than len(col_list) communities
        # no longer raise an IndexError.
        nx.draw_networkx_nodes(G, pos, nodelist=list(comm), node_size=80, alpha=0.7,
                               node_color=col_list[i % len(col_list)])

    nx.draw_networkx_nodes(G, pos, nodelist=rest_nodes, node_size=80, alpha=0.7, node_color='#5b5b5b')
    nx.draw_networkx_edges(G, pos, width=2, alpha=0.1)

    plt.grid(False)
    plt.gca().set_facecolor('white')

    if savefig:
        plt.savefig(f'Communitites_{title}.svg')

    return communities
    

#### Fused all data:

In [None]:
# Detect and draw communities with at least 10 members in the fused network;
# the returned list holds only those retained communities.
comms_all = create_PSN_with_communities(fusedPSN_all, 10, s, 'Fused all, cutoff 0.632', False) 

In [None]:
# NOTE(review): networkx's modularity() expects the communities to form a
# partition of all nodes of the graph. comms_all omits communities smaller
# than 10, so this presumably raises NotAPartition whenever such small
# communities exist -- confirm and, if so, score the full Louvain partition.
modularity_score_all = modularity(fusedPSN_all, comms_all)
print(modularity_score_all)

In [None]:
# Keep only communities with at least 20 members, then list their sizes.
# Rebuilding the list avoids calling .remove() on it while iterating, which
# skipped the element following every removal.
comms_all = [comm for comm in comms_all if len(comm) >= 20]
for i, comm in enumerate(comms_all):
    print(len(comm), i)

In [None]:
community_dist = {'A':126, 'B':110, 'C': 106}

In [None]:
comm_df = pd.DataFrame(list(community_dist.items()), columns=['Community', 'Size'])
sns.barplot(x='Community', y='Size', data=comm_df, palette='muted')
plt.title('Community distribution')
plt.savefig('Community_dist_bar.svg')
plt.show()

## Cluster alignment:

In [None]:
def cluster_alignment(adjM1, adjM2):
    '''
    Finds cluster alignment for two networks from their adjacency matrices.
    For every node, the 10 nearest neighbors in network 1 are compared with the
    10 nearest neighbors of the same node in network 2 using Jaccard similarity
    (size of the intersection divided by size of the union of the two sets).

    input: adjacency matrix 1 and 2 (must have the same shape; row/column i is
    taken to be the same node in both)

    returns list with one alignment score per node, each in [0, 1]. A node with
    no neighbors in either network scores 0.0.

    raises ValueError if the two matrices have different shapes.
    '''
    # Graphs built with from_numpy_array are always labelled 0..n-1, so
    # comparing node labels cannot detect a mismatch; compare the shapes instead.
    if np.shape(adjM1) != np.shape(adjM2):
        raise ValueError(
            f'Adjacency matrices must have the same shape, got '
            f'{np.shape(adjM1)} and {np.shape(adjM2)}'
        )

    G1 = nx.from_numpy_array(adjM1)
    G2 = nx.from_numpy_array(adjM2)

    alignment_scores = []

    for node in G1.nodes:
        nn_1 = set(find_10_nn(G1, node))
        nn_2 = set(find_10_nn(G2, node))

        # Jaccard similarity computed inline: the jaccard_sim helper this
        # function used to call is not defined anywhere in the notebook.
        union = nn_1 | nn_2
        align_score = len(nn_1 & nn_2) / len(union) if union else 0.0
        alignment_scores.append(align_score)

    return alignment_scores
    
    

In [None]:
def find_10_nn(G, node, k=10):
    '''
    Finding the k (default 10) nearest neighbors for node in graph G, i.e. the
    neighbors joined to it by the greatest edge weights, best first.

    The node itself is never returned: a non-zero diagonal in the adjacency
    matrix produces a self-loop, which would otherwise occupy one of the slots.

    G: graph (anything where G[node] maps neighbor -> {'weight': value})
    node: node whose neighbors are ranked
    k: maximum number of neighbors to return

    returns list of at most k neighbor labels
    '''
    candidates = [(neigh, attrs) for neigh, attrs in G[node].items() if neigh != node]
    ranked = sorted(candidates, key=lambda item: item[1]['weight'], reverse=True)

    return [neigh for neigh, _ in ranked[:k]]
        

In [None]:
# Pairwise neighbourhood alignment between the three "V1" networks.
# Note that the unscaled matrices (adjM_*, not adjM_*_sc) are compared.
cluster_align_eucl_cos = cluster_alignment(adjM_eucl, adjM_cos)

In [None]:
cluster_align_eucl_pears = cluster_alignment(adjM_eucl, adjM_pears)

In [None]:
cluster_align_cos_pears = cluster_alignment(adjM_cos, adjM_pears)

In [None]:
# Mean alignment score for the cos/pears pair.
np.average(cluster_align_cos_pears)

In [None]:
# One column of per-node alignment scores for each pair of similarity metrics.
cluster_algn_df = pd.DataFrame({
    'eucl_cos': cluster_align_eucl_cos,
    'eucl_pears': cluster_align_eucl_pears,
    'cos_pears': cluster_align_cos_pears,
})

cluster_algn_df

In [None]:
plt.Figure(figsize=(5,5))
plt.title('Cluster alignment distribution between PSNs created with the respective similarity metrics')
sns.violinplot(cluster_algn_df)
plt.yticks([0, 0.5, 1.0])
plt.ylim(0,1.1)
plt.savefig('cluster alignment distribution.svg')

In [None]:
plt.Figure(figsize=(5,5))
plt.title('Cluster alignment distribution cantab, eucl and cos')
sns.violinplot(cluster_align_eucl_cos)
plt.yticks([0, 0.5, 1.0])
plt.ylim(0,1.1)
plt.savefig('violin_CA_eucl_cos.svg')

In [None]:

plt.Figure(figsize=(5,5))
plt.title('Cluster alignment distribution cantab, eucl and pears')
ax1 = sns.violinplot(cluster_align_eucl_pears, inner_kws=dict(alpha=0.5))
ax1.set_alpha(0.5)
#sns.stripplot(cluster_align_eucl_pears)
plt.yticks([0, 0.5, 1.0])
plt.ylim(0,1.1)
plt.savefig('violin_CA_eucl_pears.svg')

#sns.violinplot(data=df, x="age", inner_kws=dict(box_width=15, whis_width=2, color=".8"))

In [None]:
plt.Figure(figsize=(2,2))
plt.title('Cluster alignment distribution cantab, cos and pears')
sns.violinplot(cluster_align_cos_pears)
plt.yticks([0, 0.5, 1.0])
plt.ylim(0,1.1)
plt.savefig('violin_CA_cos_pears.svg')
print(np.mean(cluster_align_cos_pears))

## Finding sim-cutoff vs modularity

In [None]:
# Average node degree of the fused network.
# NOTE(review): this cell used to read a variable `x` that is not defined
# anywhere in the notebook (leftover kernel state). fusedPSN_all is the only
# graph built at notebook level, so it is used here -- confirm that this is
# the graph that was meant.
degrees = dict(fusedPSN_all.degree())
avg_deg = sum(degrees.values()) / len(fusedPSN_all)
avg_deg

In [None]:
def plot_cutoff_vs_modularity(adjM, file_name, mark_cutoff=None):
    """Plot how the thresholded network changes as the similarity cutoff is swept.

    For each of 20 evenly spaced cutoffs in [0, 1] a graph is built that keeps
    every node pair (i < j) whose similarity is strictly greater than the
    cutoff. For each graph the number of Louvain communities, the modularity
    of that partition, modularity divided by the number of communities, the
    average node degree and the average clustering coefficient are computed
    and drawn in a 3x2 panel figure. Cutoffs that leave no edges get NaN for
    the community-based quantities and are masked in the plots.

    Parameters
    ----------
    adjM : np.ndarray
        Square similarity matrix; only entries above the diagonal are read.
    file_name : str
        Path the figure is saved to; a falsy value skips saving.
    mark_cutoff : float, optional
        x-position of the dashed vertical marker drawn in every panel.
        Defaults to ``cutoffs[12]`` (the value that used to be hard-coded).

    Returns
    -------
    None
        Shows the figure and prints the cutoffs, the modularity/#communities
        array, the community counts and the masked ratio per cutoff index.

    Notes
    -----
    Louvain is called without a seed, so community counts and modularity may
    differ between runs.
    """
    cutoffs = np.linspace(0, 1, num=20)
    print(cutoffs)

    if mark_cutoff is None:
        mark_cutoff = cutoffs[12]

    # The pair list does not depend on the cutoff: extract the upper triangle
    # once (row-major, i < j) instead of re-walking the matrix per cutoff.
    adjM = np.asarray(adjM)
    n = adjM.shape[0]
    rows, cols = np.triu_indices(n, k=1)
    pair_sims = adjM[rows, cols]

    number_of_communities = []
    modularity_scores = []
    degrees = []
    clustering = []
    mod_vs_comm = []

    for cutoff in cutoffs:  # build one thresholded network per cutoff
        PSN = nx.Graph()
        PSN.add_nodes_from(range(n))

        keep = pair_sims > cutoff
        for i, j, sim in zip(rows[keep].tolist(), cols[keep].tolist(), pair_sims[keep].tolist()):
            PSN.add_edge(i, j, weight=sim)

        if len(PSN.edges()) == 0:
            # No edges left: community structure is undefined for this cutoff.
            number_of_communities.append(np.nan)
            modularity_score = np.nan
            mod_vs_comm_score = np.nan
        else:
            communities = community.louvain_communities(PSN)
            number_of_communities.append(len(communities))
            modularity_score = modularity(PSN, communities)
            mod_vs_comm_score = modularity_score / len(communities)

        modularity_scores.append(modularity_score)
        mod_vs_comm.append(mod_vs_comm_score)

        deg = dict(PSN.degree())
        degrees.append(sum(deg.values()) / len(PSN))
        clustering.append(nx.average_clustering(PSN))

    # Mask the NaN placeholders so matplotlib leaves gaps instead of drawing them.
    masked_mod_array = np.ma.masked_invalid(np.array(modularity_scores))
    masked_comm_array = np.ma.masked_invalid(np.array(number_of_communities))
    mod_vs_comm_array = np.array(mod_vs_comm)
    masked_modcom_array = np.ma.masked_invalid(mod_vs_comm_array)

    # plotting cutoffs against the variables:

    sns.set_style('whitegrid')

    fig, ((ax1, ax2), (ax3, ax4), (ax5, ax6)) = plt.subplots(3, 2, figsize=(8, 10))
    fig.suptitle('Dynamics of PSN connectivity and modularity as cutoff changes')

    tenths = [i/10 for i in range(11)]

    ax1.plot(cutoffs, masked_comm_array, marker='o')
    ax1.set_ylabel('Number of communities')

    ax2.plot(cutoffs, masked_mod_array, marker='o', color='r')
    ax2.set_ylabel('Modularity score')
    ax2.set_yticks(tenths)

    ax3.plot(cutoffs, degrees, marker='o', color='g')
    ax3.set_ylabel('Average node degree')
    ax3.set_yticks([0, 50, 100, 150, 200, 250, 300, 350])

    ax4.plot(cutoffs, clustering, marker='o', color='#6a329f')
    ax4.set_ylabel('Average clustering')
    ax4.set_yticks(tenths)

    ax5.plot(cutoffs, masked_modcom_array, marker='o', color='#f39b3e')
    ax5.set_ylabel('modularity/#communities')
    ax5.set_ylim(0, 0.1)

    ax6.plot(cutoffs, masked_mod_array, marker='o', color='r', label='Modularity')
    ax6.plot(cutoffs, clustering, marker='o', color='#6a329f', label='Clustering')
    ax6.set_ylabel('Score')
    ax6.set_yticks(tenths)
    ax6.tick_params(axis='both', which='both', direction='out', length=6, width=1)
    ax6.legend()

    # Styling shared by all six panels.
    for ax in (ax1, ax2, ax3, ax4, ax5, ax6):
        ax.axvline(mark_cutoff, color='k', linestyle='--', linewidth=2)
        ax.set_xlabel('Similarity cutoff')
        ax.set_xlim(-0.05, 1.05)
        ax.grid(False)
        ax.set_facecolor('white')

    plt.tight_layout()
    # NOTE(review): kept from the original. This changes seaborn's global
    # settings after the figure is built, so it presumably only influences
    # figures created later in the notebook -- confirm it is wanted here.
    sns.set(font_scale=1.2)

    if file_name:
        plt.savefig(file_name)

    plt.show()
    print(mod_vs_comm_array)

    for comms in number_of_communities:
        print(comms)

    for i, modcom in enumerate(masked_modcom_array):
        print(f'{modcom}, {i}')
     


### Fused all:

In [None]:
# Sweep the similarity cutoff on the min-max scaled fused matrix and save the
# resulting panel figure as an SVG.
plot_cutoff_vs_modularity(fusedM_APOE_cantab_cos_clinical_ptau, 'cutoff_mod_all_data.svg')