# Graph Analysis EDA
This notebook performs exploratory data analysis (EDA) on the drug-disease knowledge graph.

In [None]:
import networkx as nx
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path
import json
from collections import Counter
import plotly.express as px
import plotly.graph_objects as go
from IPython.display import display

# Set style
# Select the matplotlib style sheet and the default seaborn colour palette
# used by the figures drawn in the cells below.
# NOTE(review): the 'seaborn-v0_8-*' style names are presumably only available
# in newer matplotlib releases — confirm against the pinned version.
plt.style.use('seaborn-v0_8-whitegrid')
sns.set_palette("viridis")

In [None]:
# Load the knowledge graph
# The path is relative, so it is resolved against the working directory the
# notebook is run from.
graph_path = Path('../data/processed/graph/knowledge_graph.graphml')

print(f"Loading graph from {graph_path}...")
# G is the graph object that every later cell analyses.
G = nx.read_graphml(graph_path)
print(f"Graph loaded with {G.number_of_nodes()} nodes and {G.number_of_edges()} edges")

In [None]:
# Basic graph statistics
def print_graph_stats(G):
    """Print headline structural statistics for ``G``.

    Prints node and edge counts, whether the graph is directed, the number of
    connected components and the share of nodes in the largest one (both
    computed on an undirected copy of the graph), degree summary statistics
    and the graph density.

    Args:
        G: NetworkX graph to summarise.
    """
    n_nodes = G.number_of_nodes()

    print("Graph Statistics:")
    print(f"  Nodes: {n_nodes}")
    print(f"  Edges: {G.number_of_edges()}")
    print(f"  Is Directed: {nx.is_directed(G)}")

    if n_nodes == 0:
        # The component, degree and percentage figures below are undefined
        # for an empty graph (max()/min() of nothing, division by zero).
        print("  Graph is empty; no further statistics available")
        return

    # to_undirected() builds a full copy of the graph, so convert once and
    # walk the components once instead of repeating both per statistic.
    component_sizes = [len(c) for c in nx.connected_components(G.to_undirected())]
    print(f"  Connected Components: {len(component_sizes)}")

    # Largest connected component
    largest = max(component_sizes)
    print(f"  Largest Component Size: {largest} nodes ({largest/n_nodes*100:.2f}%)")

    # Node degree statistics
    degrees = [d for _, d in G.degree()]
    print(f"  Average Degree: {np.mean(degrees):.2f}")
    print(f"  Min Degree: {min(degrees)}")
    print(f"  Max Degree: {max(degrees)}")
    print(f"  Median Degree: {np.median(degrees):.2f}")

    # Density
    print(f"  Graph Density: {nx.density(G):.6f}")


print_graph_stats(G)

In [None]:
# Analyze node types
def analyze_node_types(G):
    """Count the nodes of ``G`` per ``type`` attribute.

    Nodes without a ``type`` attribute are counted under ``'unknown'``.

    Args:
        G: Graph whose ``nodes(data=True)`` yields ``(node, attrs)`` pairs.

    Returns:
        DataFrame indexed by node type with a single ``count`` column,
        sorted by count in descending order.
    """
    type_counts = Counter(
        attrs.get('type', 'unknown') for _, attrs in G.nodes(data=True)
    )
    counts_df = pd.DataFrame.from_dict(type_counts, orient='index', columns=['count'])
    return counts_df.sort_values('count', ascending=False)

# Tabulate node counts per type. node_types_df is reused by the plotting,
# drug/disease guard and summary cells further down.
node_types_df = analyze_node_types(G)
print("\nNode Types:")
display(node_types_df)


In [None]:
# Plot node type distribution
# DataFrame.plot opens its own figure unless it is handed an Axes, so create
# the figure/axes pair first and pass it in. Otherwise the 10x6 figure stays
# empty and the bars land on a second, default-sized figure.
fig, ax = plt.subplots(figsize=(10, 6))
node_types_df.plot(kind='bar', color='skyblue', ax=ax)
ax.set_title('Distribution of Node Types')
ax.set_xlabel('Node Type')
ax.set_ylabel('Count')
ax.grid(axis='y', linestyle='--', alpha=0.7)
fig.tight_layout()

# savefig does not create missing directories, so make sure the target exists.
figure_path = Path('../docs/figures/node_type_distribution.png')
figure_path.parent.mkdir(parents=True, exist_ok=True)
fig.savefig(figure_path, dpi=300, bbox_inches='tight')
plt.show()



In [None]:
# Analyze edge types
def analyze_edge_types(G):
    """Count the edges of ``G`` per ``type`` attribute.

    Edges without a ``type`` attribute are counted under ``'unknown'``.

    Args:
        G: Graph whose ``edges(data=True)`` yields ``(u, v, attrs)`` triples.

    Returns:
        DataFrame indexed by edge type with a single ``count`` column,
        sorted by count in descending order.
    """
    type_counts = Counter(
        attrs.get('type', 'unknown') for _, _, attrs in G.edges(data=True)
    )
    counts_df = pd.DataFrame.from_dict(type_counts, orient='index', columns=['count'])
    return counts_df.sort_values('count', ascending=False)

# Tabulate edge counts per type. edge_types_df is reused by the edge-type
# plot and the summary cell further down.
edge_types_df = analyze_edge_types(G)
print("\nEdge Types:")
display(edge_types_df)



In [None]:
# Plot edge type distribution
# DataFrame.plot opens its own figure unless it is handed an Axes, so create
# the figure/axes pair first and pass it in. Otherwise the 10x6 figure stays
# empty and the bars land on a second, default-sized figure.
fig, ax = plt.subplots(figsize=(10, 6))
edge_types_df.plot(kind='bar', color='lightgreen', ax=ax)
ax.set_title('Distribution of Edge Types')
ax.set_xlabel('Edge Type')
ax.set_ylabel('Count')
ax.grid(axis='y', linestyle='--', alpha=0.7)
fig.tight_layout()

# savefig does not create missing directories, so make sure the target exists.
figure_path = Path('../docs/figures/edge_type_distribution.png')
figure_path.parent.mkdir(parents=True, exist_ok=True)
fig.savefig(figure_path, dpi=300, bbox_inches='tight')
plt.show()



In [None]:
# Analyze node degree distribution by type
def analyze_node_degree_by_type(G):
    """Summarise node degrees separately for each node ``type``.

    Nodes without a ``type`` attribute are grouped under ``'unknown'``.

    Args:
        G: Graph providing ``nodes(data=True)`` and ``degree(node)``.

    Returns:
        DataFrame indexed by node type with the columns ``count``, ``mean``,
        ``median``, ``min``, ``max`` and ``std`` of the degrees in that group.
    """
    # Group the degree of every node under its type label
    degrees_by_type = {}
    for node, attrs in G.nodes(data=True):
        label = attrs.get('type', 'unknown')
        degrees_by_type.setdefault(label, []).append(G.degree(node))

    # One row of summary statistics per type
    summary = {
        label: {
            'count': len(values),
            'mean': np.mean(values),
            'median': np.median(values),
            'min': min(values),
            'max': max(values),
            'std': np.std(values),
        }
        for label, values in degrees_by_type.items()
    }

    return pd.DataFrame.from_dict(summary, orient='index')

# Per-type degree summary. node_degree_stats is reused by the average-degree
# plot and the summary cell further down.
node_degree_stats = analyze_node_degree_by_type(G)
print("\nNode Degree Statistics by Type:")
display(node_degree_stats)



In [None]:
# Plot node degree statistics
plt.figure(figsize=(12, 6))
# reset_index() exposes the node-type labels as a regular 'index' column
degree_table = node_degree_stats.reset_index()
ax = sns.barplot(data=degree_table, x='index', y='mean')
ax.set_title('Average Node Degree by Type')
ax.set_xlabel('Node Type')
ax.set_ylabel('Average Degree')
ax.grid(axis='y', linestyle='--', alpha=0.7)
plt.tight_layout()
plt.savefig('../docs/figures/avg_degree_by_type.png', dpi=300, bbox_inches='tight')
plt.show()



In [None]:
# Analyze drug-disease paths
def find_paths_between_types(G, source_type, target_type, max_length=3,
                             max_sources=100, max_targets=10, seed=None):
    """Collect short shortest paths from source-type nodes to target-type nodes.

    Up to ``max_sources`` nodes of ``source_type`` are drawn at random without
    replacement and paired with the first ``max_targets`` nodes of
    ``target_type`` in graph iteration order. For each pair, one shortest
    path is kept if it has at most ``max_length`` edges.

    Args:
        G: NetworkX graph whose nodes carry a ``type`` attribute.
        source_type: ``type`` value of the path start nodes.
        target_type: ``type`` value of the path end nodes.
        max_length: Maximum number of edges a reported path may have.
        max_sources: Maximum number of source nodes to sample.
        max_targets: Number of leading target nodes each source is paired with.
        seed: Optional seed for the source sample; ``None`` draws a fresh
            random sample on every call.

    Returns:
        Tuple ``(paths, path_lengths, path_patterns)`` of parallel lists: the
        node sequence of each path, its edge count, and the ``'->'``-joined
        node types along it.
    """
    # Get nodes of each type
    source_nodes = [n for n, attrs in G.nodes(data=True) if attrs.get('type') == source_type]
    target_nodes = [n for n, attrs in G.nodes(data=True) if attrs.get('type') == target_type]

    paths = []
    path_lengths = []
    path_patterns = []
    if not source_nodes or not target_nodes:
        return paths, path_lengths, path_patterns

    # Sample to limit computational load. Positions are sampled rather than
    # the nodes themselves so node IDs keep their original Python type
    # instead of being coerced into a NumPy array.
    rng = np.random.default_rng(seed)
    n_samples = min(max_sources, len(source_nodes))
    picked = rng.choice(len(source_nodes), size=n_samples, replace=False)
    sampled_sources = [source_nodes[i] for i in picked]
    candidate_targets = target_nodes[:max_targets]

    for source in sampled_sources:
        # One breadth-first search per source, truncated at max_length hops,
        # replaces a separate shortest-path search per (source, target) pair.
        # Targets beyond the cutoff or unreachable are absent from the result.
        reachable = nx.single_source_shortest_path(G, source, cutoff=max_length)

        for target in candidate_targets:
            path = reachable.get(target)
            if path is None:
                continue

            paths.append(path)
            path_lengths.append(len(path) - 1)  # Convert node count to edge count

            # Path pattern: sequence of node types along the path
            path_patterns.append(
                '->'.join(G.nodes[node].get('type', 'unknown') for node in path)
            )

    return paths, path_lengths, path_patterns



In [None]:
# Analyze drug-disease paths
if 'drug' in node_types_df.index and 'disease' in node_types_df.index:
    print("\nAnalyzing paths between drugs and diseases...")
    paths, path_lengths, path_patterns = find_paths_between_types(G, 'drug', 'disease', max_length=4)

    if paths:
        print(f"Found {len(paths)} paths between drugs and diseases (max length 4)")

        # Analyze path lengths
        path_length_counts = Counter(path_lengths)
        path_length_df = pd.DataFrame.from_dict(path_length_counts, orient='index', columns=['count']).sort_index()

        # DataFrame.plot opens its own figure unless it is handed an Axes, so
        # create the figure/axes pair first and pass it in. Otherwise the 8x5
        # figure stays empty and the bars land on a second figure.
        fig, ax = plt.subplots(figsize=(8, 5))
        path_length_df.plot(kind='bar', color='coral', ax=ax)
        ax.set_title('Distribution of Path Lengths (Drug to Disease)')
        ax.set_xlabel('Path Length')
        ax.set_ylabel('Count')
        ax.grid(axis='y', linestyle='--', alpha=0.7)
        fig.tight_layout()

        # savefig does not create missing directories
        figure_path = Path('../docs/figures/drug_disease_path_length.png')
        figure_path.parent.mkdir(parents=True, exist_ok=True)
        fig.savefig(figure_path, dpi=300, bbox_inches='tight')
        plt.show()

        # Analyze path patterns
        path_pattern_counts = Counter(path_patterns)
        common_patterns = pd.DataFrame.from_dict(path_pattern_counts, orient='index', columns=['count']).sort_values('count', ascending=False).head(10)

        print("\nMost Common Path Patterns (Drug to Disease):")
        display(common_patterns)
else:
    print("Either drug or disease nodes are missing from the graph")



In [None]:
# Network visualization of a subgraph
def _bfs_node_sample(G, start_node, n_nodes):
    """Return the set of up to ``n_nodes`` nodes found breadth-first from ``start_node``."""
    from collections import deque  # local import: only Counter is imported at file level

    visited = {start_node}
    frontier = deque([start_node])  # popleft() is O(1), unlike list.pop(0)

    while frontier and len(visited) < n_nodes:
        current = frontier.popleft()
        for neighbor in G.neighbors(current):
            if neighbor in visited:
                continue
            visited.add(neighbor)
            frontier.append(neighbor)
            if len(visited) >= n_nodes:
                break

    return visited


def visualize_subgraph(G, start_node=None, n_nodes=50, node_size_factor=100, edge_width=1.0,
                       save_path='../docs/figures/subgraph_visualization.png'):
    """Draw a breadth-first neighbourhood of ``start_node`` and save it as an image.

    Args:
        G: NetworkX graph; nodes may carry ``type`` and ``name`` attributes.
        start_node: Node to start from; a random node is used when ``None``.
        n_nodes: Maximum number of nodes in the drawn subgraph.
        node_size_factor: Marker size per unit of degree within the subgraph.
        edge_width: Line width used for edges.
        save_path: Where to write the figure; ``None`` skips saving.

    Raises:
        ValueError: If ``start_node`` is ``None`` and the graph has no nodes.
    """
    if start_node is None:
        candidates = list(G.nodes())
        if not candidates:
            raise ValueError("Cannot visualize an empty graph")
        # Pick by position so the node keeps its original Python type
        start_node = candidates[np.random.randint(len(candidates))]

    # Extract a subgraph using BFS
    H = G.subgraph(_bfs_node_sample(G, start_node, n_nodes))

    # Prepare node colors and sizes based on type
    color_map = {
        'drug': 'skyblue',
        'protein': 'lightgreen',
        'disease': 'salmon',
        'pathway': 'gold',
        'category': 'violet',
        'unknown': 'gray'
    }

    node_colors = []
    node_sizes = []
    node_labels = {}
    legend_types = set()

    for node in H.nodes():
        node_type = H.nodes[node].get('type', 'unknown')
        node_colors.append(color_map.get(node_type, 'gray'))

        # Types missing from color_map are drawn gray, so they are listed
        # under the 'unknown' legend entry.
        legend_types.add(node_type if node_type in color_map else 'unknown')

        # Node size based on degree. A node with no edges inside the subgraph
        # would otherwise get size 0 and vanish from the plot.
        node_sizes.append(max(H.degree(node), 1) * node_size_factor)

        # Node labels
        node_labels[node] = H.nodes[node].get('name', node)

    # Plot
    plt.figure(figsize=(12, 12))
    pos = nx.spring_layout(H, seed=42)

    nx.draw_networkx_nodes(H, pos, node_color=node_colors, node_size=node_sizes, alpha=0.8)
    nx.draw_networkx_edges(H, pos, width=edge_width, alpha=0.5, arrows=True)
    nx.draw_networkx_labels(H, pos, labels=node_labels, font_size=8, font_weight='bold')

    # Create legend for the node types present in the subgraph
    legend_elements = [
        plt.Line2D([0], [0], marker='o', color='w', markerfacecolor=color,
                   markersize=10, label=node_type)
        for node_type, color in color_map.items()
        if node_type in legend_types
    ]

    plt.legend(handles=legend_elements, loc='upper right')
    plt.axis('off')
    plt.title(f'Subgraph Starting from {node_labels[start_node]}')
    plt.tight_layout()

    if save_path is not None:
        # savefig does not create missing directories
        save_path = Path(save_path)
        save_path.parent.mkdir(parents=True, exist_ok=True)
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.show()



In [None]:
# Start the visualization at the first drug node when the graph has one;
# passing start_node=None lets visualize_subgraph pick a random node instead.
drug_nodes = [n for n, attrs in G.nodes(data=True) if attrs.get('type') == 'drug']
visualize_subgraph(G, start_node=drug_nodes[0] if drug_nodes else None, n_nodes=30)



In [None]:
# Centrality analysis
def analyze_centrality(G, top_n=20, betweenness_samples=100, seed=None):
    """Compute centrality measures on the largest connected component of ``G``.

    Components are determined on an undirected copy of the graph; the
    centralities themselves are computed on the corresponding subgraph of
    ``G``.

    Args:
        G: NetworkX graph; nodes may carry ``type`` and ``name`` attributes.
        top_n: Number of rows in each of the returned "top" tables.
        betweenness_samples: Number of pivot nodes used to approximate
            betweenness centrality. The exact computation is used when the
            component has no more nodes than this, or when ``None``.
        seed: Optional seed for the betweenness pivot sample.

    Returns:
        Tuple ``(centrality_df, top_degree, top_betweenness, top_eigenvector)``.
        ``centrality_df`` is indexed by node with the columns ``degree``,
        ``betweenness``, ``eigenvector``, ``type`` and ``name``; the other
        three are its ``top_n`` rows by the respective measure. All four are
        empty (but carry these columns) for an empty graph.
    """
    print("Calculating centrality measures...")

    if G.number_of_nodes() == 0:
        # max() over zero components would raise; hand back empty tables that
        # still have the columns the callers select.
        columns = ['degree', 'betweenness', 'eigenvector', 'type', 'name']
        return tuple(pd.DataFrame(columns=columns) for _ in range(4))

    # Limit to largest connected component for centrality calculations
    largest_cc = max(nx.connected_components(G.to_undirected()), key=len)
    H = G.subgraph(largest_cc).copy()

    # Sampling more pivots than there are nodes raises, so small components
    # fall back to the exact computation (k=None).
    n_nodes = H.number_of_nodes()
    if betweenness_samples is None or betweenness_samples >= n_nodes:
        pivots = None
    else:
        pivots = betweenness_samples

    # Calculate centrality measures
    degree_centrality = nx.degree_centrality(H)
    betweenness_centrality = nx.betweenness_centrality(H, k=pivots, seed=seed)
    # NOTE(review): eigenvector centrality may be ill-defined or rejected on a
    # directed component that is not strongly connected — confirm against the
    # installed networkx version.
    eigenvector_centrality = nx.eigenvector_centrality_numpy(H)

    # Combine results
    centrality_df = pd.DataFrame({
        'degree': pd.Series(degree_centrality),
        'betweenness': pd.Series(betweenness_centrality),
        'eigenvector': pd.Series(eigenvector_centrality)
    })

    # Add node attributes
    centrality_df['type'] = [H.nodes[node].get('type', 'unknown') for node in centrality_df.index]
    centrality_df['name'] = [H.nodes[node].get('name', node) for node in centrality_df.index]

    # Get top nodes by each measure
    top_degree = centrality_df.sort_values('degree', ascending=False).head(top_n)
    top_betweenness = centrality_df.sort_values('betweenness', ascending=False).head(top_n)
    top_eigenvector = centrality_df.sort_values('eigenvector', ascending=False).head(top_n)

    return centrality_df, top_degree, top_betweenness, top_eigenvector


centrality_df, top_degree, top_betweenness, top_eigenvector = analyze_centrality(G)

print("\nTop Nodes by Degree Centrality:")
display(top_degree[['name', 'type', 'degree']])

print("\nTop Nodes by Betweenness Centrality:")
display(top_betweenness[['name', 'type', 'betweenness']])

print("\nTop Nodes by Eigenvector Centrality:")
display(top_eigenvector[['name', 'type', 'eigenvector']])



In [None]:
# Plot centrality distributions by node type
def plot_centrality_by_type(centrality_df, measure='degree'):
    """Bar-plot the mean of one centrality measure per node type.

    Bars are ordered by descending mean and carry error bars of one standard
    deviation. The figure is saved under ``../docs/figures`` with a file name
    derived from ``measure``.

    Args:
        centrality_df: Table with a ``type`` column and a ``measure`` column.
        measure: Name of the centrality column to plot.
    """
    plt.figure(figsize=(10, 6))

    # Per-type statistics, highest mean first
    per_type = (
        centrality_df.groupby('type')[measure]
        .agg(['mean', 'median', 'std'])
        .sort_values('mean', ascending=False)
    )
    means = per_type['mean']

    # Bars for the means, then one-standard-deviation error bars on top
    ax = sns.barplot(x=per_type.index, y=means)
    ax.errorbar(
        x=range(len(per_type)),
        y=means,
        yerr=per_type['std'],
        fmt='none',
        color='black',
        capsize=5
    )

    ax.set_title(f'{measure.capitalize()} Centrality by Node Type')
    ax.set_xlabel('Node Type')
    ax.set_ylabel(f'Mean {measure.capitalize()} Centrality')
    plt.xticks(rotation=45)
    ax.grid(axis='y', linestyle='--', alpha=0.7)
    plt.tight_layout()
    plt.savefig(f'../docs/figures/{measure}_centrality_by_type.png', dpi=300, bbox_inches='tight')
    plt.show()


# Plot centrality measures by node type
for centrality_measure in ('degree', 'betweenness', 'eigenvector'):
    plot_centrality_by_type(centrality_df, measure=centrality_measure)



In [None]:
# Explore potential drug-disease pairs for prediction
def identify_prediction_candidates(G, max_path_length=3, max_drugs=50, max_diseases=50):
    """Rank unlinked drug-disease pairs by the neighbours they share.

    The first ``max_drugs`` drug nodes and ``max_diseases`` disease nodes (in
    graph iteration order) are paired. Pairs that are directly linked are
    skipped; the rest are kept when they have at least one common neighbour
    and scored by the Jaccard similarity of their neighbour sets.

    Edge direction is ignored throughout: on a directed graph a node's
    neighbours are its predecessors plus its successors, matching the
    direction-agnostic direct-link check.

    Args:
        G: Graph whose nodes carry ``type`` (and optionally ``name``).
        max_path_length: Unused; kept so existing callers keep working.
        max_drugs: Maximum number of drug nodes to consider.
        max_diseases: Maximum number of disease nodes to consider.

    Returns:
        DataFrame with the columns ``drug_id``, ``drug_name``, ``disease_id``,
        ``disease_name``, ``shared_count``, ``jaccard`` and ``shared_types``
        (a dict of shared-neighbour type -> count), sorted by ``jaccard`` in
        descending order. Empty, with these columns, when no pair qualifies.
    """
    columns = ['drug_id', 'drug_name', 'disease_id', 'disease_name',
               'shared_count', 'jaccard', 'shared_types']

    directed = G.is_directed()

    def neighbor_set(node):
        # G.neighbors() only yields successors on a directed graph
        if directed:
            return set(G.predecessors(node)) | set(G.successors(node))
        return set(G.neighbors(node))

    drug_nodes = [n for n, attrs in G.nodes(data=True) if attrs.get('type') == 'drug']
    disease_nodes = [n for n, attrs in G.nodes(data=True) if attrs.get('type') == 'disease']

    # Limit to a sample for computational feasibility
    sample_drugs = drug_nodes[:max_drugs]
    sample_diseases = disease_nodes[:max_diseases]

    # Disease names and neighbour sets do not depend on the drug, so build
    # them once instead of once per drug.
    disease_info = [
        (disease, G.nodes[disease].get('name', disease), neighbor_set(disease))
        for disease in sample_diseases
    ]

    potential_pairs = []
    for drug in sample_drugs:
        drug_name = G.nodes[drug].get('name', drug)
        drug_neighbors = neighbor_set(drug)

        for disease, disease_name, disease_neighbors in disease_info:
            # Skip if a direct connection exists (in either direction)
            if disease in drug_neighbors:
                continue

            shared = drug_neighbors & disease_neighbors
            if not shared:
                continue

            # Jaccard similarity of the two neighbour sets
            jaccard = len(shared) / len(drug_neighbors | disease_neighbors)

            # Types of the shared neighbours
            type_counts = Counter(G.nodes[n].get('type', 'unknown') for n in shared)

            potential_pairs.append({
                'drug_id': drug,
                'drug_name': drug_name,
                'disease_id': disease,
                'disease_name': disease_name,
                'shared_count': len(shared),
                'jaccard': jaccard,
                'shared_types': dict(type_counts)
            })

    if not potential_pairs:
        return pd.DataFrame(columns=columns)

    pairs_df = pd.DataFrame(potential_pairs, columns=columns)
    # A stable sort keeps tied pairs in discovery order, so output is deterministic
    return pairs_df.sort_values('jaccard', ascending=False, kind='mergesort')



In [None]:
# Candidate search only makes sense when the graph has both drugs and diseases
if 'drug' not in node_types_df.index or 'disease' not in node_types_df.index:
    print("Either drug or disease nodes are missing from the graph")
else:
    print("\nIdentifying potential drug-disease pairs for prediction...")
    prediction_candidates = identify_prediction_candidates(G)

    if prediction_candidates.empty:
        print("No potential drug-disease pairs found. This could be due to a small sample or disconnected graph.")
    else:
        print(f"Found {len(prediction_candidates)} potential drug-disease pairs for prediction")
        print("\nTop candidates by Jaccard similarity:")
        display(prediction_candidates.head(10))

        # Histogram of the Jaccard scores
        plt.figure(figsize=(10, 6))
        sns.histplot(prediction_candidates['jaccard'], bins=20, kde=True)
        plt.title('Distribution of Jaccard Similarities for Potential Drug-Disease Pairs')
        plt.xlabel('Jaccard Similarity')
        plt.ylabel('Count')
        plt.grid(linestyle='--', alpha=0.7)
        plt.tight_layout()
        plt.savefig('../docs/figures/jaccard_distribution.png', dpi=300, bbox_inches='tight')
        plt.show()

        # Expand the per-pair {type: count} dicts into columns, then total
        # each type over all pairs
        type_matrix = prediction_candidates['shared_types'].apply(pd.Series).fillna(0)
        shared_types = type_matrix.sum().sort_values(ascending=False)

        plt.figure(figsize=(10, 6))
        shared_types.plot(kind='bar', color='teal')
        plt.title('Types of Shared Neighbors Between Drug-Disease Pairs')
        plt.xlabel('Node Type')
        plt.ylabel('Total Count')
        plt.grid(axis='y', linestyle='--', alpha=0.7)
        plt.tight_layout()
        plt.savefig('../docs/figures/shared_neighbor_types.png', dpi=300, bbox_inches='tight')
        plt.show()



In [None]:
# Generate summary of findings
banner = "=" * 50
print("\n" + banner)
print("KNOWLEDGE GRAPH ANALYSIS SUMMARY")
print(banner)

print(f"\nThe knowledge graph contains {G.number_of_nodes()} nodes and {G.number_of_edges()} edges.")
# Type labels come straight from graph attributes; str() keeps join() from
# failing should any label be a non-string value.
print(f"Node types: {', '.join(map(str, node_types_df.index))}")
print(f"Edge types: {', '.join(map(str, edge_types_df.index))}")

# Summarize node degree statistics
print("\nNode degree statistics:")
for node_type, row in node_degree_stats.iterrows():
    # iterrows() yields each row as a single float Series, so the integer max
    # degree would otherwise print as e.g. "12.0".
    print(f"  {node_type}: mean={row['mean']:.2f}, median={row['median']:.1f}, max={row['max']:.0f}")

# Summarize centrality findings
if not centrality_df.empty:
    print("\nMost central entities:")

    top_degree_node = top_degree.iloc[0]
    print(f"  Most connected: {top_degree_node['name']} ({top_degree_node['type']}) with {top_degree_node['degree']:.4f} degree centrality")

    top_betweenness_node = top_betweenness.iloc[0]
    print(f"  Most bridging: {top_betweenness_node['name']} ({top_betweenness_node['type']}) with {top_betweenness_node['betweenness']:.4f} betweenness centrality")

    top_eigenvector_node = top_eigenvector.iloc[0]
    print(f"  Most influential: {top_eigenvector_node['name']} ({top_eigenvector_node['type']}) with {top_eigenvector_node['eigenvector']:.4f} eigenvector centrality")

# Summarize drug-disease relationships
# prediction_candidates is only created by the earlier cell under this same
# drug/disease condition, so it is only read inside it.
if 'drug' in node_types_df.index and 'disease' in node_types_df.index:
    drug_count = node_types_df.loc['drug', 'count']
    disease_count = node_types_df.loc['disease', 'count']
    print(f"\nThe graph contains {drug_count} drugs and {disease_count} diseases.")

    if not prediction_candidates.empty:
        best_candidate = prediction_candidates.iloc[0]
        print(f"Identified {len(prediction_candidates)} potential drug-disease pairs for prediction.")
        print(f"Top candidate: {best_candidate['drug_name']} - {best_candidate['disease_name']} (Jaccard: {best_candidate['jaccard']:.4f})")

print("\nRecommendations for model development:")
print("  1. Use network centrality to prioritize important entities in the graph")
print("  2. Consider path-based features for drug-disease interaction prediction")
print("  3. Incorporate node type information in the graph neural network design")
print("  4. Pay attention to multi-hop relationships between drugs and diseases")
print("  5. Leverage shared protein targets as important prediction signals")

print("\n" + banner)