### Programming for Biomedical Informatics
#### Week 7 - Network Construction Techniques

Constructing networks that are useful representations of the underlying biological data is a complex task. In this notebook we will explore some key concepts used to incorporate data into networks and then refine those networks using a selection of methodologies.
Quantifying the impact of the assumptions and decisions made in the network construction and refinement process is a key part of the experimental analysis of networks. This is often confounded by the lack of ground-truth data upon which to make decisions.
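
One way to get a feel for how much a single construction decision matters is to sweep it and record simple graph summaries. The sketch below does this for the correlation cut-off, using a synthetic expression matrix rather than the TCGA data; the names `toy_expr`, `toy_corr`, and `sweep_thresholds` are illustrative assumptions and are not part of `graph_functions.py`.

```python
import numpy as np
import pandas as pd
import networkx as nx

# synthetic expression matrix: 100 samples x 30 genes (purely illustrative)
rng = np.random.default_rng(0)
latent = rng.normal(size=(100, 3))        # three hidden expression "programmes"
loadings = rng.normal(size=(3, 30))
toy_expr = pd.DataFrame(
    latent @ loadings + rng.normal(size=(100, 30)),
    columns=[f"gene_{i}" for i in range(30)],
)
toy_corr = toy_expr.corr(method="pearson")


def sweep_thresholds(corr, thresholds):
    """Record simple graph summaries for each absolute-correlation cut-off."""
    values = corr.to_numpy()
    rows, cols = np.triu_indices_from(values, k=1)  # upper triangle, no diagonal
    records = []
    for t in thresholds:
        keep = np.abs(values[rows, cols]) >= t
        G = nx.Graph()
        G.add_nodes_from(corr.columns)
        G.add_edges_from(zip(corr.index[rows[keep]], corr.columns[cols[keep]]))
        records.append({
            "threshold": t,
            "num_edges": G.number_of_edges(),
            "num_isolates": nx.number_of_isolates(G),
            "num_components": nx.number_connected_components(G),
        })
    return pd.DataFrame(records)


threshold_summary = sweep_thresholds(toy_corr, [0.3, 0.5, 0.7, 0.9])
print(threshold_summary)
```

Without ground truth there is no "correct" row in this table, but seeing how quickly edges disappear and orphans appear makes the consequences of a chosen cut-off explicit.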

Thanks to Sebestyen Kamp, who developed parts of these scripts for a workshop on networks presented at ISMB 2024 in Montreal, Canada. Many of the functions used in this code are collected in an accompanying `graph_functions.py` file, which keeps this notebook easier to read and makes function re-use simpler.

In [None]:
'''Biomedical Networks'''
# standard libraries
import os
import pickle
import graph_functions as gf

# scientific and data manipulation libraries
import numpy as np
import pandas as pd
from scipy import stats
from sklearn.feature_selection import mutual_info_regression
# mygene is a library for querying gene information (though you could use eUtils etc.)
import mygene

# graph and network libraries
import networkx as nx

# visualization libraries
import matplotlib.pyplot as plt
import seaborn as sns
import plotly.graph_objects as go
import plotly.io as pio
from IPython.display import Image
from IPython.display import display

# some deprecation warnings that are safe to ignore can be silenced using the warnings library
import warnings
warnings.filterwarnings('ignore')

We're going to look at some gene expression data from The Cancer Genome Atlas (TCGA).

Note that the dataset is multi-modal; we will look at the other modalities in later lectures. In this notebook we concentrate on gene expression.

- **Title**: The Cancer Genome Atlas Lung Adenocarcinoma (TCGA-LUAD)
- **Main Focus**: Study of lung adenocarcinoma (a common type of lung cancer)
- **Data Collected**: Genomic, epigenomic, transcriptomic, and proteomic data from lung adenocarcinoma samples
- **Disease Types**:
  - Acinar Cell Neoplasms
  - Adenomas and Adenocarcinomas
  - Cystic, Mucinous, and Serous Neoplasms
- **Number of Cases**: 585 (498 with transcriptomic data)
- **Data Accessibility**: Available on the NIH-GDC Data Portal

- **Link**: [TCGA-LUAD Project Page](https://portal.gdc.cancer.gov/projects/TCGA-LUAD)

In [None]:
# load the gene expression data from a pickle file
'''a pickle file is a serialized python object that can be saved to disk and loaded back into memory
these can be very useful for sharing python objects. In this case we have a dictionary with the gene 
expression data'''

with open("ISMB_TCGA_GE.pkl", 'rb') as file:
    data = pickle.load(file)

# print the keys of the dictionary
print(data.keys())

In order to construct a biological network, we are going to first:
- examine the TCGA metadata 
- come up with useful strategies to tackle the large data size 
- create the basis of a biological network

We're going to first familiarise ourselves with the data by looking at the metadata that comes with the gene expression data.

In [None]:
# show the first few rows of the gene expression meta-data
metadata = data["datMeta"]
metadata.head()

In [None]:
# Count the number of unique patient identifiers in the 'patient' column of the dataFrame
print(f'There are {data["datMeta"]["patient"].unique().size} patient samples')

In [None]:
# Count the occurrences of each unique value in the 'sample_type' column of the 'datMeta' DataFrame
print(data["datMeta"]['sample_type'].value_counts())

In [None]:
# It's usually a good idea to use a dimensionality reduction method to look for outliers and to see how the samples are distributed
# Here we will do a basic PCA

from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler

expression_data = data['datExpr']

# scale features, run PCA on 2-dimensions
X = expression_data.values
X_scaled = StandardScaler().fit_transform(X)
pca = PCA(n_components=2)
pcs = pca.fit_transform(X_scaled)
sample_pca = pd.DataFrame(pcs, columns=['PC1', 'PC2'], index=expression_data.index)

# add the sample_type so we can see if cancerous and non-cancerous samples are split
sample_pca = sample_pca.join(metadata['sample_type'])
sample_pca.index.name = 'sample_id'
print(sample_pca)

In [None]:
# plot the pca results
from matplotlib.patches import Patch

# map statuses to distinct colors
statuses = list(sample_pca['sample_type'].unique())
cmap = plt.get_cmap('Set1')
color_map = {s: cmap(i % cmap.N) for i, s in enumerate(statuses)}
colors = sample_pca['sample_type'].map(color_map)

# create figure and axes, then draw the scatter plot coloured by sample type
fig, ax = plt.subplots(figsize=(7,5))
sc = ax.scatter(sample_pca['PC1'], sample_pca['PC2'], s=50, edgecolor='k',c=list(colors))

# simple legend with colored patches
legend_elements = [Patch(facecolor=col, edgecolor='k', label=label) 
                   for label, col in color_map.items()]
ax.legend(handles=legend_elements, title='sample_type')

# axis labels (with variance explained) and title
ax.set_xlabel(f"PC1 ({pca.explained_variance_ratio_[0]*100:.1f}%)")
ax.set_ylabel(f"PC2 ({pca.explained_variance_ratio_[1]*100:.1f}%)")
ax.set_title('PCA')
plt.tight_layout()
plt.show()

In [None]:
# in order to focus on genes that are dysregulated between the two sample status classes we will perform differential expression analysis on them
# the data from TCGA are already normalised, which means they are no longer integers.
# although it's not best practice we will simply round to the nearest integer to allow us to import these data into DESeq2
# in a full analysis we would go back to the raw count data per sample and re-build the data matrix
expression_data = expression_data.round().astype(int)

# look at the expression_data
expression_data.head()

We are going to visualise various metadata attributes such as race, gender, sample type, cigarettes per day, and smoking status by gender.

In [None]:
# explore the distributions of various data labels

# Set up the figure and axes for a 2-column layout
fig, axes = plt.subplots(3, 2, figsize=(18, 18))
fig.suptitle('Metadata Distributions', fontsize=20, y=1.02)

# Plot 1: Distribution of Race
sns.countplot(ax=axes[0, 0], x='race', data=data['datMeta'], palette='viridis')
axes[0, 0].set_title('Distribution of Race')
axes[0, 0].set_xlabel('Race')
axes[0, 0].set_ylabel('Count')
axes[0, 0].tick_params(axis='x', rotation=45)

# Plot 2: Gender Distribution
sns.countplot(ax=axes[0, 1], x='gender', data=data['datMeta'], palette='magma')
axes[0, 1].set_title('Gender Distribution')
axes[0, 1].set_xlabel('Gender')
axes[0, 1].set_ylabel('Count')

# Plot 3: Sample Type Distribution
sns.countplot(ax=axes[1, 0], x='sample_type', data=data['datMeta'], palette='plasma')
axes[1, 0].set_title('Sample Type Distribution')
axes[1, 0].set_xlabel('Sample Type')
axes[1, 0].set_ylabel('Count')
axes[1, 0].tick_params(axis='x', rotation=45)

# Plot 4: Distribution of Cigarettes Per Day
sns.histplot(ax=axes[1, 1], data=data['datMeta']['cigarettes_per_day'], kde=True, color='blue')
axes[1, 1].set_title('Distribution of Cigarettes Per Day')
axes[1, 1].set_xlabel('Cigarettes Per Day')
axes[1, 1].set_ylabel('Frequency')

# Plot 5: Smoking Status by Gender
sns.countplot(ax=axes[2, 0], x='Smoked', hue='gender', data=data['datMeta'], palette='coolwarm')
axes[2, 0].set_title('Smoking Status by Gender')
axes[2, 0].set_xlabel('Smoking Status')
axes[2, 0].set_ylabel('Count')
axes[2, 0].legend(title='Gender')

axes[2, 1].axis('off')
plt.tight_layout()
plt.show()

This dataset contains gene expression levels for various samples, identified by their TCGA (The Cancer Genome Atlas) codes.  
Each row represents a different sample, while each column represents a different gene, identified by its Ensembl gene ID.  
The values in the table are the expression levels of the genes for each sample.

In [None]:
# let's set up a DESeq2 experiment to compare cancer vs. non-cancer taking into account smoking status
# this is like the bi-factor analysis we did last week

from pydeseq2.dds import DeseqDataSet
from pydeseq2.default_inference import DefaultInference
from pydeseq2.ds import DeseqStats

inference = DefaultInference(n_cpus=8)
dds = DeseqDataSet(
    counts=expression_data,
    metadata=metadata,
    design="~sample_type+Smoked",
    refit_cooks=True,
    inference=inference,
)

print(dds)

In [None]:
# as we want to focus on the network construction elements we will run a one-shot DESeq2 analysis without exploring intermediate steps
dds.deseq2()

In [None]:
## now we setup the comparison to be made
stat_res = DeseqStats(dds, contrast=['sample_type','Primary Tumor','Solid Tissue Normal'], inference=inference)

In [None]:
# extract the result slot
stat_res.summary()
results = stat_res.results_df

# sort the data by adjusted p-value
sorted_results = results.sort_values(by='padj', ascending=True)
sorted_results.reset_index(inplace=True)
sorted_results.head()

In [None]:
# we are going to restrict the differentially expressed genes retained by setting cut-offs for padj and fold-change
significant = sorted_results[(sorted_results['padj'] <= 0.05) & (abs(sorted_results['log2FoldChange']) >= 0.8)]
print(f'There are {len(significant)} significantly differentially expressed genes')

In [None]:
# get the ids of the genes to be retained
sig_genes = list(significant['_row'])

#select only the gene columns that are significant
significant_genes = expression_data[expression_data.columns.intersection(sig_genes)]
print(significant_genes.shape)

# so for 498 patients we have a fingerprint of 357 genes

In [None]:
# for interest let's conduct a PCA on these
# scale features, run PCA on 2-dimensions
X = significant_genes.values
X_scaled = StandardScaler().fit_transform(X)
pca = PCA(n_components=2)
pcs = pca.fit_transform(X_scaled)
sample_pca = pd.DataFrame(pcs, columns=['PC1', 'PC2'], index=significant_genes.index)

# add the sample_type so we can see if cancerous and non-cancerous samples are split
sample_pca = sample_pca.join(metadata['sample_type'])

#visualise
sample_pca.index.name = 'sample_id'

# map statuses to distinct colors
statuses = list(sample_pca['sample_type'].unique())
cmap = plt.get_cmap('Set1')
color_map = {s: cmap(i % cmap.N) for i, s in enumerate(statuses)}
colors = sample_pca['sample_type'].map(color_map)

# create figure and axes, then draw the scatter plot coloured by sample type
fig, ax = plt.subplots(figsize=(7,5))
sc = ax.scatter(sample_pca['PC1'], sample_pca['PC2'], s=50, edgecolor='k',c=list(colors))

# simple legend with colored patches
legend_elements = [Patch(facecolor=col, edgecolor='k', label=label) 
                   for label, col in color_map.items()]
ax.legend(handles=legend_elements, title='sample_type')

# axis labels (with variance explained) and title
ax.set_xlabel(f"PC1 ({pca.explained_variance_ratio_[0]*100:.1f}%)")
ax.set_ylabel(f"PC2 ({pca.explained_variance_ratio_[1]*100:.1f}%)")
ax.set_title('PCA')
plt.tight_layout()
plt.show()

# very interesting - are there some mis-labelled samples? other explanation?
# note how much better separated the data are and note also PC1 now accounts for 29.1% of the variance in the data
# note also how much tighter the normal tissue cluster is

You may have noticed that the column headers are Ensembl IDs rather than gene names, so we're going to fix that by mapping them (as we have done before in the course). You could use eUtils to do this, or even bulk download the metadata and use table merging, but we're going to use a nice package called [mygene](https://docs.mygene.info/projects/mygene-py/en/latest/).

Converting Ensembl gene IDs (ENSG) to HGNC (HUGO Gene Nomenclature Committee) gene symbols is often good practice, as HGNC is an international standard.
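
Two details of this mapping are easy to miss: versioned Ensembl IDs (for example a trailing `.18`) need the suffix removed before lookup, and more than one Ensembl ID can map to the same symbol. The offline sketch below illustrates both with a hand-written lookup table standing in for the mygene response; `toy_symbols` and the third Ensembl ID are made up for illustration.

```python
import pandas as pd

# toy expression table with versioned Ensembl IDs as column headers
toy = pd.DataFrame(
    {
        "ENSG00000141510.18": [5.0, 7.0],
        "ENSG00000146648.19": [2.0, 4.0],
        "ENSG00000999999.1": [1.0, 3.0],
    },
    index=["sample_A", "sample_B"],
)

# hypothetical lookup table standing in for the mygene response
toy_symbols = {
    "ENSG00000141510": "TP53",
    "ENSG00000146648": "EGFR",
    "ENSG00000999999": "EGFR",  # made-up ID, included only to create a duplicate symbol
}

# strip the version suffix so that the IDs match the lookup keys
toy.columns = toy.columns.str.split(".").str[0]

# rename, then average any columns that now share a symbol
toy = toy.rename(columns=toy_symbols)
toy_by_symbol = toy.T.groupby(level=0).mean().T
print(toy_by_symbol)
```

Averaging duplicates is only one option; keeping the most variable column or summing them are alternatives that lead to slightly different downstream correlations.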

In [None]:
# define a function to do the gene mapping
def rename_ensembl_to_gene_names(df, chunk_size=1000):
    """
    Renames Ensembl gene IDs to gene names using mygene.

    NB we chunk the requests to avoid hitting the rate limit.
    
    Parameters:
    df (pd.DataFrame): DataFrame with Ensembl gene IDs as columns.
    chunk_size (int): Number of Ensembl IDs to query at a time.
    
    Returns:
    pd.DataFrame: DataFrame with gene names as columns, excluding genes that couldn't be mapped.
    """
    
    # Make a copy of the DataFrame to avoid modifying the original
    df_copy = df.copy()

    # Remove the `.number` suffix from ENSG IDs
    df_copy.columns = df_copy.columns.str.split('.').str[0]

    # Initialize mygene client
    mg = mygene.MyGeneInfo()

    # Split ENSG IDs into smaller chunks
    def chunks(lst, n):
        """Yield successive n-sized chunks from lst."""
        for i in range(0, len(lst), n):
            yield lst[i:i + n]

    ensg_ids = df_copy.columns.tolist()
    gene_mappings = []

    unmapped_genes = []

    # send requests in chunks
    for chunk in chunks(ensg_ids, chunk_size):
        result = mg.querymany(chunk, scopes='ensembl.gene', fields='symbol', species='human')
        gene_mappings.extend(result)

    # Create a mapping from ENSG to gene symbol, handle missing mappings
    ensg_to_gene = {item['query']: item.get('symbol', None) for item in gene_mappings}
    
    # Log the unmapped genes
    batch_unmapped_genes = [gene for gene in ensg_ids if ensg_to_gene.get(gene) is None]
    if batch_unmapped_genes:
        # Add unmapped genes to the list
        unmapped_genes.extend(batch_unmapped_genes)

    # Filter the DataFrame to only include columns that have been mapped
    df_filtered = df_copy.loc[:, df_copy.columns.isin(ensg_to_gene.keys())]

    # Further filter to ensure we have the same number of columns as mapped gene names
    df_filtered = df_filtered.loc[:, [ensg for ensg in df_filtered.columns if ensg_to_gene[ensg] is not None]]

    # Assign new column names
    df_filtered.columns = [ensg_to_gene[ensg] for ensg in df_filtered.columns]

    # Handle duplicate gene names by aggregating them (e.g., by taking the mean)
    df_final = df_filtered.T.groupby(df_filtered.columns).mean().T

    return df_final, set(unmapped_genes)

In [None]:
# convert the ensembl gene IDs to gene names
final_genes,unmapped_genes = rename_ensembl_to_gene_names(significant_genes)
print(f'{len(unmapped_genes)} were not mapped to gene names.')

In [None]:
# let's inspect the first few rows of the renamed DataFrame
final_genes.head()

We now want to create correlation matrices based on gene expression profiles to see how groups of genes vary between cancerous and non-cancerous samples.

There are a few correlation metrics one could consider:
- [Pearson](https://en.wikipedia.org/wiki/Pearson_correlation_coefficient)  
  - O(n^2) complexity, fast for large datasets
- [Spearman](https://en.wikipedia.org/wiki/Spearman%27s_rank_correlation_coefficient)
  -  O(n^2 log n) complexity, relatively fast but can be slower than Pearson
- [Absolute biweight midcorrelation](https://en.wikipedia.org/wiki/Biweight_midcorrelation)
  - Robust but slower than Pearson and Spearman, suitable for datasets with outliers


We now want to establish whether there are any strong relationships between the expression levels of the genes in the different samples. We can do this by calculating the correlation between their expression values across samples. We can use this correlation matrix as an adjacency matrix to build a network.

- nodes: genes  
- edges: highly correlated genes (above a given threshold)
- edge-weights: correlation values
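
As a small illustration of the adjacency-matrix view, the sketch below thresholds a hand-written 4 x 4 correlation matrix and passes it to `nx.from_pandas_adjacency`, which creates an edge for every non-zero entry and stores the value as the `weight` attribute. The matrix values and the gene labels are invented for the example.

```python
import numpy as np
import pandas as pd
import networkx as nx

genes = ["gene_a", "gene_b", "gene_c", "gene_d"]
toy_corr = pd.DataFrame(
    [
        [1.0, 0.9, -0.8, 0.1],
        [0.9, 1.0, 0.2, 0.0],
        [-0.8, 0.2, 1.0, 0.3],
        [0.1, 0.0, 0.3, 1.0],
    ],
    index=genes,
    columns=genes,
)

threshold = 0.7

# zero out weak correlations and the diagonal, keeping the sign of strong ones
values = toy_corr.to_numpy(copy=True)
values[np.abs(values) < threshold] = 0.0
np.fill_diagonal(values, 0.0)
adjacency = pd.DataFrame(values, index=genes, columns=genes)

# non-zero entries become weighted edges; every gene is kept as a node
toy_graph = nx.from_pandas_adjacency(adjacency)
print(toy_graph.edges(data=True))
print("Isolated genes:", list(nx.isolates(toy_graph)))
```

Note that `gene_d` stays in the graph as an isolated node: thresholding removes edges, not genes, which is why a separate clean-up step is usually needed afterwards.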

In [None]:
# we're using some functions that we have defined in the accompanying functions file

# Dictionary to store different correlation matrices
correlation_matrices = {}

# Pearson correlation - O(n^2) complexity, fast for large datasets
correlation_matrices['pearson'] = final_genes.corr(method='pearson')

# Spearman rank correlation -  O(n^2 log n) complexity, relatively fast but can be slower than Pearson
correlation_matrices['spearman'] = final_genes.corr(method='spearman')

# Biweight midcorrelation -  Robust but slower than Pearson and Spearman, suitable for datasets with outliers
correlation_matrices['biweight_midcorrelation'] = gf.calc_abs_bicorr(final_genes)

# Print the keys of the correlation matrices to verify
print("Correlation matrices calculated:")
print(correlation_matrices.keys())

In [None]:
# plot these in a multi-panel plot
gf.plot_correlation_matrices(correlation_matrices)

In [None]:
# let's create an actual graph from the pearson correlation matrix
G = gf.create_graph_from_correlation(correlation_matrices['pearson'], threshold=0.7)

We are going to define a function `create_graph_from_correlation` to make networks from correlation matrices.

The function starts by creating an empty graph G. Then iterates through the columns of the correlation matrix and adds each column name as a node in the graph. This means each gene (or feature) in your dataset becomes a node in the graph.

The function iterates over the upper triangle of the correlation matrix (excluding the diagonal) to avoid redundancy and self-loops. This is sufficient because the graph is undirected, so the correlation matrix is symmetric.

For each pair of nodes (i, j), it checks if the absolute value of the correlation coefficient between them is greater than or equal to the specified threshold.

If the condition is met, an edge is added between the nodes i and j with the correlation coefficient as the weight of the edge. This signifies a strong correlation (positive or negative) between the two nodes.

In [None]:
# Create a graph from the correlation matrix using a specified threshold
def create_graph_from_correlation(correlation_matrix, threshold=0.8):
    """
    Creates a graph from a correlation matrix using a specified threshold.

    Parameters:
    correlation_matrix (pd.DataFrame): DataFrame containing the correlation matrix.
    threshold (float): Threshold for including edges based on correlation value.

    Returns:
    G (nx.Graph): Graph created from the correlation matrix.
    """
    G = nx.Graph()

    # Add nodes
    for node in correlation_matrix.columns:
        G.add_node(node)

    # Add an edge for each gene pair whose absolute correlation meets the threshold;
    # starting j at i + 1 walks the upper triangle only, so the diagonal is never visited
    for i in range(correlation_matrix.shape[0]):
        for j in range(i + 1, correlation_matrix.shape[1]):
            weight = correlation_matrix.iloc[i, j]
            if abs(weight) >= threshold:
                G.add_edge(correlation_matrix.index[i], correlation_matrix.columns[j], weight=weight)

    return G

In [None]:
# Create a graph from the Pearson correlation matrix with a threshold of 0.5
pearson_graph = create_graph_from_correlation(correlation_matrices['pearson'], threshold=0.5)

Now let's go through a few useful NetworkX functions and create a `print_graph_info()` function.

In [None]:
# Print basic information about the graph
def print_graph_info(G):
    """
    Print basic information about a NetworkX graph.

    
    Parameters:
    G (nx.Graph): The NetworkX graph.
    """
    print(f"Number of nodes: {G.number_of_nodes()}")
    print(f"Number of edges: {G.number_of_edges()}")
    print("Sample nodes:", list(G.nodes)[:10])  # Print first 10 nodes as a sample
    print("Sample edges:", list(G.edges(data=True))[:10])  # Print first 10 edges as a sample
    
    info_str = "Graph type: "
    is_directed = G.is_directed()
    if is_directed:
        info_str += "directed"
    else:
        info_str += "undirected"
    print(info_str)

    # Check for self-loops
    self_loops = list(nx.selfloop_edges(G))
    if self_loops:
        print(f"Number of self-loops: {len(self_loops)}")
        print("Self-loops:", self_loops)
    else:
        print("No self-loops in the graph.")

    # density of the graph
    density = nx.density(G)
    print(f"Graph density: {density}")

    # Find and print the number of connected components
    num_connected_components = nx.number_connected_components(G)
    print(f"Number of connected components: {num_connected_components}")

    # Calculate and print the clustering coefficient of the graph
    clustering_coeff = nx.average_clustering(G)
    print(f"Average clustering coefficient: {clustering_coeff}")

In [None]:
print_graph_info(pearson_graph)

In [None]:
# Visualise the graph
gf.visualise_graph(pearson_graph, title='Pearson Correlation Network (Threshold = 0.5)')

We now have the base gene correlation network, but we can see that there are a lot of orphans (due to the threshold filtering), so we need to clean the network up. We can use functions from NetworkX for this.
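
The NetworkX calls involved are easiest to see on a tiny hand-built graph. The sketch below uses invented node names and shows `nx.isolates` for finding orphans and `nx.connected_components` for picking out the largest component.

```python
import networkx as nx

# toy graph: one 3-node component, one 2-node component and two orphans
toy = nx.Graph()
toy.add_edges_from([("g1", "g2"), ("g2", "g3"), ("g4", "g5")])
toy.add_nodes_from(["g6", "g7"])

# orphans are nodes with no edges at all
orphans = list(nx.isolates(toy))
print("Orphans:", orphans)

# removing them leaves every remaining node with at least one neighbour
toy_no_orphans = toy.copy()
toy_no_orphans.remove_nodes_from(orphans)

# connected_components yields sets of nodes; take the biggest one
components = sorted(nx.connected_components(toy), key=len, reverse=True)
largest = toy.subgraph(components[0]).copy()
print("Largest component:", sorted(largest.nodes))
```

Removing orphans and keeping only the largest component are separate decisions: the first drops only genes that carry no edge information, while the second also discards small but possibly meaningful modules such as the `g4`-`g5` pair here.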

In [None]:
# Function to clean the graph
def clean_graph(G, degree_threshold=1, keep_largest_component=True):
    """
    Cleans the graph by performing several cleaning steps:
    - Removes unconnected nodes (isolates)
    - Removes self-loops
    - Removes nodes with a degree below a specified threshold
    - Keeps only the largest connected component (optional)

    Parameters:
    G (nx.Graph): The NetworkX graph to clean.
    degree_threshold (int): Minimum degree for nodes to keep.
    keep_largest_component (bool): Whether to keep only the largest connected component.

    Returns:
    G (nx.Graph): Cleaned graph.
    """
    G = G.copy()  # Work on a copy of the graph to avoid modifying the original graph

    # Remove self-loops
    G.remove_edges_from(nx.selfloop_edges(G))

    # Remove nodes with no edges (isolates)
    G.remove_nodes_from(list(nx.isolates(G)))

    # Remove nodes with degree below the threshold
    low_degree_nodes = [node for node, degree in dict(G.degree()).items() if degree < degree_threshold]
    G.remove_nodes_from(low_degree_nodes)

    # Keep only the largest connected component (skipped if no nodes remain)
    if keep_largest_component and G.number_of_nodes() > 0:
        largest_cc = max(nx.connected_components(G), key=len)
        G = G.subgraph(largest_cc).copy()

    return G

In [None]:
# Clean the graph by removing unconnected nodes
pearson_graph_cleaned = clean_graph(pearson_graph,
                                    degree_threshold=1,
                                    keep_largest_component=False)

In [None]:
# view the cleaned graph
# NB this now renders quickly because the graph is much smaller
gf.visualise_graph(pearson_graph_cleaned, title='Pearson Correlation Network - Cleaned')

In [None]:
# we can re-use the function to print the graph information
print_graph_info(pearson_graph_cleaned)

In [None]:
# Clean the graph by keeping only the largest connected component
pearson_graph_pruned = clean_graph(pearson_graph,
                                    degree_threshold=1,
                                    keep_largest_component=True)

In [None]:
gf.visualise_graph(pearson_graph_pruned, title='Pearson Correlation Network - Pruned')

In [None]:
# we can re-use the function to print the graph information
print_graph_info(pearson_graph_pruned)

In [None]:
# Visualize the distribution of edge weights
gf.visualise_edge_weight_distribution(pearson_graph_pruned)

With sparsification we aim to reduce the number of edges in a network while preserving important structural properties.

- Edge Sampling: Randomly removes a fraction of edges.
- Thresholding: Removes edges with weights below a certain threshold.
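- Degree-based Sparsification: Removes nodes whose degree falls below a minimum.
- K-Nearest Neighbours (KNN): Keeps only the k highest-weight edges of each node.
- Spanning Tree: Reduces the graph to a minimum spanning tree.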
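
Edge sampling is the one strategy in this list that is not among the functions defined in the next cell. A minimal sketch follows, assuming uniform sampling without replacement; the function name `edge_sampling_sparsification` and its `keep_fraction` and `seed` parameters are illustrative choices, not part of `graph_functions.py`.

```python
import random
import networkx as nx


def edge_sampling_sparsification(graph, keep_fraction, seed=None):
    """Keep a uniformly random fraction of the edges; all nodes are retained."""
    if not 0.0 <= keep_fraction <= 1.0:
        raise ValueError("keep_fraction must be between 0 and 1")
    rng = random.Random(seed)
    edges = list(graph.edges(data=True))
    n_keep = int(round(len(edges) * keep_fraction))
    kept_edges = rng.sample(edges, n_keep)  # sampling without replacement

    sparsified_graph = nx.Graph()
    sparsified_graph.add_nodes_from(graph.nodes(data=True))
    sparsified_graph.add_edges_from(kept_edges)
    return sparsified_graph


# demo on a complete graph with 10 nodes (45 edges)
demo_graph = nx.complete_graph(10)
sampled_graph = edge_sampling_sparsification(demo_graph, keep_fraction=0.2, seed=42)
print(sampled_graph.number_of_nodes(), sampled_graph.number_of_edges())
```

Because the sampling ignores edge weights, this mainly serves as a random baseline against which weight-aware methods such as thresholding can be compared.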

In [None]:
# simply remove the edges below a certain edge-weight threshold
def threshold_sparsification(graph, threshold):
    """
    Sparsifies the graph by removing edges below the specified weight threshold.

    Parameters:
    graph (nx.Graph): The original NetworkX graph.
    threshold (float): The weight threshold.

    Returns:
    nx.Graph: The sparsified graph.
    """
    graph_copy = graph.copy()
    sparsified_graph = nx.Graph()
    sparsified_graph.add_nodes_from(graph_copy.nodes(data=True))
    # compare on absolute weight so that strong negative correlations are retained too
    sparsified_graph.add_edges_from((u, v, d) for u, v, d in graph_copy.edges(data=True) if abs(d.get('weight', 0)) >= threshold)
    return sparsified_graph

# keep the specified top quantile of edges by edge-weight
def top_percentage_sparsification(graph, top_percentage):
    """
    Sparsifies the graph by keeping the top percentage of edges by weight.

    Parameters:
    graph (nx.Graph): The original NetworkX graph.
    top_percentage (float): The percentage of top-weight edges to keep.

    Returns:
    nx.Graph: The sparsified graph.
    """
    graph_copy = graph.copy()
    # rank by absolute weight so that strong negative correlations count as strong edges
    sorted_edges = sorted(graph_copy.edges(data=True), key=lambda x: abs(x[2].get('weight', 0)), reverse=True)
    top_edges_count = max(1, int(len(sorted_edges) * (top_percentage / 100)))
    sparsified_graph = nx.Graph()
    sparsified_graph.add_nodes_from(graph_copy.nodes(data=True))
    sparsified_graph.add_edges_from(sorted_edges[:top_edges_count])
    return sparsified_graph


# remove nodes with degree below a certain threshold
def remove_by_degree(graph, min_degree):
    """
    Sparsifies the graph by removing nodes with degree below the specified threshold.

    Parameters:
    graph (nx.Graph): The original NetworkX graph.
    min_degree (int): The minimum degree threshold.

    Returns:
    nx.Graph: The sparsified graph.
    """
    graph_copy = graph.copy()
    nodes_to_remove = [node for node, degree in dict(graph_copy.degree()).items() if degree < min_degree]
    
    graph_copy.remove_nodes_from(nodes_to_remove)
    return graph_copy

# use KNN sparsification to keep each node's top-k edges by weight
# NB a node can still end up with more than k edges if it is in the top k of other nodes
def knn_sparsification(graph, k):
    """
    Sparsifies the graph by keeping only the top-k edges with the highest weights for each node.

    Parameters:
    graph (nx.Graph): The original NetworkX graph.
    k (int): The number of nearest neighbors to keep for each node.

    Returns:
    nx.Graph: The sparsified graph.
    """
    graph_copy = graph.copy()
    sparsified_graph = nx.Graph()
    sparsified_graph.add_nodes_from(graph_copy.nodes(data=True))
    
    for node in graph_copy.nodes():
        # rank by absolute weight so that strong negative correlations count as strong edges
        edges = sorted(graph_copy.edges(node, data=True), key=lambda x: abs(x[2].get('weight', 0)), reverse=True)
        sparsified_graph.add_edges_from(edges[:k])
    
    return sparsified_graph

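# create a minimum spanning tree
# NB with correlation values as weights, a minimum spanning tree favours the weakest (or most negative) edges;
# nx.maximum_spanning_tree is the counterpart that favours the highest weights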
def spanning_tree_sparsification(graph):
    """
    Sparsifies the graph by creating a minimum spanning tree.

    Parameters:
    graph (nx.Graph): The original NetworkX graph.

    Returns:
    nx.Graph: The sparsified graph.
    """
    graph_copy = graph.copy()
    return nx.minimum_spanning_tree(graph_copy, weight='weight')



In [None]:
# Initialise a dictionary to store graphs
graphs = {}
# Store the original graph for comparison
graphs['original'] = pearson_graph_pruned.copy()

# Apply sparsification methods to the original graph
graphs['knn_5'] = knn_sparsification(graphs['original'], k=5)
# graphs['threshold'] = threshold_sparsification(graphs['original'], threshold=0.82)
# graphs['top_10_percent'] = top_percentage_sparsification(graphs['original'], top_percentage=10)
# graphs['degree_below_3'] = remove_by_degree(graphs['original'], min_degree=3)
# graphs['spanning_tree'] = spanning_tree_sparsification(graphs['original'])


# Visualise the graphs after sparsification
gf.visualise_graph(graphs['original'], 'Original Graph')
gf.visualise_graph(graphs['knn_5'], 'K-Nearest Neighbors (k=5)')
# gf.visualise_graph(graphs['threshold'], 'Thresholded Graph (weight > 0.82)')
# gf.visualise_graph(graphs['top_10_percent'], 'Top 10% Edges by Weight')
# gf.visualise_graph(graphs['degree_below_3'], 'Degree Below 3')
# gf.visualise_graph(graphs['spanning_tree'], 'Minimum Spanning Tree')


In [None]:
# Let's inspect the information of the KNN sparsified graph
print_graph_info(graphs['knn_5'])

In [None]:
# function to analyse the effect of different k values on the network properties
def analyse_knn_effect(graph, k_values):
    """
    Analyses the effect of different k values on the network properties.

    Parameters:
    graph (nx.Graph): The original NetworkX graph.
    k_values (list): List of k values to use for sparsification.

    Returns:
    pd.DataFrame: DataFrame containing the analysis results.
    """
    results = {
        'k': [],
        'num_edges': [],
        'avg_degree': [],
        'avg_clustering': [],
        'num_connected_components': [],
    }
    
    for k in k_values:
        sparsified_graph = knn_sparsification(graph, k)
        num_edges = sparsified_graph.number_of_edges()
        avg_degree = sum(dict(sparsified_graph.degree()).values()) / sparsified_graph.number_of_nodes()
        avg_clustering = nx.average_clustering(sparsified_graph)
        num_connected_components = nx.number_connected_components(sparsified_graph)
        
        results['k'].append(k)
        results['num_edges'].append(num_edges)
        results['avg_degree'].append(avg_degree)
        results['avg_clustering'].append(avg_clustering)
        results['num_connected_components'].append(num_connected_components)
    
    return pd.DataFrame(results)

# plot the analysis of the effect of different k values on network properties
def plot_knn_analysis(df):
    """
    Plots the analysis of the effect of different k values on network properties.

    Parameters:
    df (pd.DataFrame): DataFrame containing the analysis results.
    """
    fig, axes = plt.subplots(2, 2, figsize=(12, 10))
    
    axes[0, 0].plot(df['k'], df['num_edges'], marker='o')
    axes[0, 0].set_title('Number of Edges vs k')
    axes[0, 0].set_xlabel('k')
    axes[0, 0].set_ylabel('Number of Edges')
    
    axes[0, 1].plot(df['k'], df['avg_degree'], marker='o')
    axes[0, 1].set_title('Average Degree vs k')
    axes[0, 1].set_xlabel('k')
    axes[0, 1].set_ylabel('Average Degree')
    
    axes[1, 0].plot(df['k'], df['avg_clustering'], marker='o')
    axes[1, 0].set_title('Average Clustering Coefficient vs k')
    axes[1, 0].set_xlabel('k')
    axes[1, 0].set_ylabel('Average Clustering Coefficient')
    
    axes[1, 1].plot(df['k'], df['num_connected_components'], marker='o')
    axes[1, 1].set_title('Number of Connected Components vs k')
    axes[1, 1].set_xlabel('k')
    axes[1, 1].set_ylabel('Number of Connected Components')
    
    plt.tight_layout()
    plt.show()

In [None]:
k_values = list(range(1, 11))  # Different k values to analyze
analysis_results = analyse_knn_effect(graphs['original'], k_values)

# Plot the analysis results
plot_knn_analysis(analysis_results)

In [None]:
# function to look at the top nodes based on degree
def get_highest_degree_nodes(graph, top_n=10):
    """
    Returns the nodes with the highest degree in the graph.

    Parameters:
    graph (nx.Graph): The NetworkX graph.
    top_n (int): The number of top nodes to return.

    Returns:
    List of tuples: Each tuple contains a node and its degree.
    """
    degrees = dict(graph.degree())
    sorted_degrees = sorted(degrees.items(), key=lambda x: x[1], reverse=True)
    return sorted_degrees[:top_n]

# gather some information about nodes using mygene
def fetch_gene_info(gene_list):
    """
    Fetches gene information from MyGene.info.

    Parameters:
    gene_list (list): List of gene symbols or Ensembl IDs.

    Returns:
    list: List of dictionaries containing gene information.
    """
    mg = mygene.MyGeneInfo()
    gene_info = mg.querymany(gene_list, scopes='symbol,ensembl.gene', 
                             fields='name,symbol,entrezgene,summary,disease,pathway', 
                             species='human')
    return gene_info

# combined function to report node information alongside gene metadata
def print_gene_info_with_degree(top_genes_with_degrees, gene_info):
    """
    Prints gene information including the degree.

    Parameters:
    top_genes_with_degrees (list): List of tuples containing gene symbols and their degrees.
    gene_info (list): List of dictionaries containing gene information.
    """
    for gene, degree in top_genes_with_degrees:
        info = next((item for item in gene_info if item['query'] == gene), None)
        if info:
            print(f"Gene Symbol: {info.get('symbol', 'N/A')}")
            print(f"Degree: {degree}")
            print(f"Gene Name: {info.get('name', 'N/A')}")
            print(f"Entrez ID: {info.get('entrezgene', 'N/A')}")
            print(f"Summary: {info.get('summary', 'N/A')}")
            if 'disease' in info:
                diseases = ', '.join([d['term'] for d in info['disease']])
                print(f"Diseases: {diseases}")
            else:
                print("Diseases: N/A")
            if 'pathway' in info:
                pathways = []
                if isinstance(info['pathway'], dict):
                    for key in info['pathway']:
                        pathway_data = info['pathway'][key]
                        if isinstance(pathway_data, list):
                            pathways.extend([p['name'] for p in pathway_data if 'name' in p])
                        elif isinstance(pathway_data, dict) and 'name' in pathway_data:
                            pathways.append(pathway_data['name'])
                        elif isinstance(pathway_data, str):
                            pathways.append(pathway_data)
                print(f"Pathways: {', '.join(pathways) if pathways else 'N/A'}")
            else:
                print("Pathways: N/A")
            print("-" * 40)
        else:
            print(f"Gene not found: {gene}")
            print(f"Degree: {degree}")
            print("-" * 40)



In [None]:
# get the top 10 genes with the highest degree in the pruned graph using get_highest_degree_nodes
top_genes_with_degrees = get_highest_degree_nodes(pearson_graph_pruned, top_n=10)
gene_symbols = [gene for gene, degree in top_genes_with_degrees]

# get gene information with fetch_gene_info
gene_info = fetch_gene_info(gene_symbols)

# print gene information including degree
print_gene_info_with_degree(top_genes_with_degrees, gene_info)
