In [3]:
# Importing required libraries
import pandas as pd
import numpy as np
import networkx as nx
import matplotlib.pyplot as plt
from collections import Counter
import numpy as np
from tqdm import tqdm
import pickle

In [4]:
# Importing the data
!wget -O kg.csv https://dataverse.harvard.edu/api/access/datafile/6180620

--2025-06-07 06:05:00--  https://dataverse.harvard.edu/api/access/datafile/6180620
Resolving dataverse.harvard.edu (dataverse.harvard.edu)... 3.212.84.5, 54.225.229.227, 52.86.99.82
Connecting to dataverse.harvard.edu (dataverse.harvard.edu)|3.212.84.5|:443... connected.
HTTP request sent, awaiting response... 303 See Other
Location: https://dvn-cloud.s3.amazonaws.com/10.7910/DVN/IXA7BM/1805e679c4c-72137dbedbf1?response-content-disposition=attachment%3B%20filename%2A%3DUTF-8%27%27kg.csv&response-content-type=text%2Fcsv&X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Date=20250607T060500Z&X-Amz-SignedHeaders=host&X-Amz-Expires=3600&X-Amz-Credential=AKIAIEJ3NV7UYCSRJC7A%2F20250607%2Fus-east-1%2Fs3%2Faws4_request&X-Amz-Signature=22bd37770c6fcfd8c63523e2e154150c49525db4572dbebfedfd43c9765c6349 [following]
--2025-06-07 06:05:00--  https://dvn-cloud.s3.amazonaws.com/10.7910/DVN/IXA7BM/1805e679c4c-72137dbedbf1?response-content-disposition=attachment%3B%20filename%2A%3DUTF-8%27%27kg.csv&response-conten

In [5]:
df = pd.read_csv('kg.csv', low_memory=False)
df.head()

Unnamed: 0,relation,display_relation,x_index,x_id,x_type,x_name,x_source,y_index,y_id,y_type,y_name,y_source
0,protein_protein,ppi,0,9796,gene/protein,PHYHIP,NCBI,8889,56992,gene/protein,KIF15,NCBI
1,protein_protein,ppi,1,7918,gene/protein,GPANK1,NCBI,2798,9240,gene/protein,PNMA1,NCBI
2,protein_protein,ppi,2,8233,gene/protein,ZRSR2,NCBI,5646,23548,gene/protein,TTC33,NCBI
3,protein_protein,ppi,3,4899,gene/protein,NRF1,NCBI,11592,11253,gene/protein,MAN1B1,NCBI
4,protein_protein,ppi,4,5297,gene/protein,PI4KA,NCBI,2122,8601,gene/protein,RGS20,NCBI


In [6]:
print(f"Original dataset shape: {df.shape}")
print(f"Columns: {df.columns.tolist()}")

Original dataset shape: (8100498, 12)
Columns: ['relation', 'display_relation', 'x_index', 'x_id', 'x_type', 'x_name', 'x_source', 'y_index', 'y_id', 'y_type', 'y_name', 'y_source']


In [8]:
# Create nodes dataframe
# Extract X nodes (source nodes)
x_nodes = df[['x_id', 'x_name', 'x_type', 'x_source']].copy()
x_nodes.columns = ['id', 'name', 'type', 'source']

# Extract Y nodes (target nodes)
y_nodes = df[['y_id', 'y_name', 'y_type', 'y_source']].copy()
y_nodes.columns = ['id', 'name', 'type', 'source']

# Combine all nodes and remove duplicates
all_nodes = pd.concat([x_nodes, y_nodes], ignore_index=True)
nodes = all_nodes.drop_duplicates(subset=['id']).reset_index(drop=True)

# Clean up any potential NaN values and strip whitespace from string columns
for col in ['name', 'type', 'source']:
    if col in nodes.columns:
        nodes[col] = nodes[col].astype(str).str.strip()

print(f"Total unique nodes: {len(nodes)}")
print(f"Node types distribution:")
print(nodes['type'].value_counts())



Total unique nodes: 135010
Node types distribution:
type
gene/protein          44553
biological_process    35362
effect/phenotype      12393
disease                9846
anatomy                9407
molecular_function     8814
drug                   7957
cellular_component     3344
pathway                2516
exposure                818
Name: count, dtype: int64


In [9]:
# Create edges dataframe
edges = df[['x_id', 'y_id', 'relation', 'display_relation']].copy()

# Remove any duplicate edges (same source-target-relation combination)
edges = edges.drop_duplicates().reset_index(drop=True)

# Clean up relation columns
edges['relation'] = edges['relation'].astype(str).str.strip()
edges['display_relation'] = edges['display_relation'].astype(str).str.strip()

print(f"Total unique edges: {len(edges)}")
print(f"Relation types distribution:")
print(edges['relation'].value_counts())

# Verify data integrity
print("\n=== Data Integrity Checks ===")

# Check if all edge nodes exist in nodes
x_ids_in_edges = set(edges['x_id'].unique())
y_ids_in_edges = set(edges['y_id'].unique())
all_edge_node_ids = x_ids_in_edges.union(y_ids_in_edges)
node_ids = set(nodes['id'].unique())

missing_nodes = all_edge_node_ids - node_ids
if missing_nodes:
    print(f"Warning: {len(missing_nodes)} node IDs in edges are missing from nodes")
    print(f"First few missing IDs: {list(missing_nodes)[:10]}")
else:
    print("✓ All edge node IDs are present in nodes")

# Check for any missing values
print(f"Nodes with missing values: {nodes.isnull().sum().sum()}")
print(f"Edges with missing values: {edges.isnull().sum().sum()}")

# Save to CSV files
nodes.to_csv('nodes.csv', index=False)
edges.to_csv('edges.csv', index=False)


Total unique edges: 8097947
Relation types distribution:
relation
anatomy_protein_present       3033894
drug_drug                     2672628
protein_protein                642150
disease_phenotype_positive     300634
bioprocess_protein             289603
cellcomp_protein               166782
disease_protein                160819
molfunc_protein                139053
drug_effect                    129568
bioprocess_bioprocess          105772
pathway_protein                 85292
disease_disease                 64388
contraindication                61350
drug_protein                    51306
anatomy_protein_absent          39774
phenotype_phenotype             37472
anatomy_anatomy                 28064
molfunc_molfunc                 27148
indication                      18776
cellcomp_cellcomp                9690
phenotype_protein                6660
off-label use                    5136
pathway_pathway                  5070
exposure_disease                 4608
exposure_exposure     

In [10]:
print(f"\nFiles Created")
print(f"nodes.csv: {len(nodes)} rows, {len(nodes.columns)} columns")
print(f"edges.csv: {len(edges)} rows, {len(edges.columns)} columns")

# Display sample data
print(f"\n Sample Nodes")
print(nodes.head())

print(f"\nSample Edges")
print(edges.head())

# Summary statistics
print(f"\n Summary Statistics")
print(f"Original relations: {len(df)}")
print(f"Unique edges after deduplication: {len(edges)}")
print(f"Unique nodes: {len(nodes)}")
print(f"Average degree per node: {2 * len(edges) / len(nodes):.2f}")

# Show node type breakdown
print(f"\nNode Type Breakdown")
for node_type in nodes['type'].unique():
    count = len(nodes[nodes['type'] == node_type])
    print(f"{node_type}: {count} nodes")

# Show relation type breakdown
print(f"\nRelation Type Breakdown")
for relation in edges['relation'].unique():
    count = len(edges[edges['relation'] == relation])
    print(f"{relation}: {count} edges")


=== Files Created ===
nodes.csv: 135010 rows, 4 columns
edges.csv: 8097947 rows, 4 columns

=== Sample Nodes ===
     id    name          type source
0  9796  PHYHIP  gene/protein   NCBI
1  7918  GPANK1  gene/protein   NCBI
2  8233   ZRSR2  gene/protein   NCBI
3  4899    NRF1  gene/protein   NCBI
4  5297   PI4KA  gene/protein   NCBI

=== Sample Edges ===
   x_id   y_id         relation display_relation
0  9796  56992  protein_protein              ppi
1  7918   9240  protein_protein              ppi
2  8233  23548  protein_protein              ppi
3  4899  11253  protein_protein              ppi
4  5297   8601  protein_protein              ppi

=== Summary Statistics ===
Original relations: 8100498
Unique edges after deduplication: 8097947
Unique nodes: 135010
Average degree per node: 119.96

=== Node Type Breakdown ===
gene/protein: 44553 nodes
drug: 7957 nodes
effect/phenotype: 12393 nodes
disease: 9846 nodes
biological_process: 35362 nodes
molecular_function: 8814 nodes
cellular_com

In [11]:
# Verifying data integrity
print("\nData Integrity Checks:\n")

# Checking if all edge nodes exist in nodes
x_ids_in_edges = set(edges['x_id'].unique())
y_ids_in_edges = set(edges['y_id'].unique())
all_edge_node_ids = x_ids_in_edges.union(y_ids_in_edges)
node_ids = set(nodes['id'].unique())

missing_nodes = all_edge_node_ids - node_ids
if missing_nodes:
    print(f"Warning: {len(missing_nodes)} node IDs in edges are missing from nodes")
    print(f"First few missing IDs: {list(missing_nodes)[:10]}")
else:
    print("All edge node IDs are present in nodes")

# Checking for any missing values
print(f"Nodes with missing values: {nodes.isnull().sum().sum()}")
print(f"Edges with missing values: {edges.isnull().sum().sum()}")


Data Integrity Checks:

All edge node IDs are present in nodes
Nodes with missing values: 0
Edges with missing values: 0


In [12]:
# Saving to CSV files
nodes.to_csv('nodes.csv', index=False)
edges.to_csv('edges.csv', index=False)

print(f"Files Created:")
print(f"nodes.csv: {len(nodes)} rows, {len(nodes.columns)} columns")
print(f"edges.csv: {len(edges)} rows, {len(edges.columns)} columns")

Files Created:
nodes.csv: 135010 rows, 4 columns
edges.csv: 8097947 rows, 4 columns


In [13]:
# Displaying sample data
print(f"\nSample Nodes:")
print(nodes.head())

print(f"\nSample Edges:")
print(edges.head())


Sample Nodes:
     id    name          type source
0  9796  PHYHIP  gene/protein   NCBI
1  7918  GPANK1  gene/protein   NCBI
2  8233   ZRSR2  gene/protein   NCBI
3  4899    NRF1  gene/protein   NCBI
4  5297   PI4KA  gene/protein   NCBI

Sample Edges:
   x_id   y_id         relation display_relation
0  9796  56992  protein_protein              ppi
1  7918   9240  protein_protein              ppi
2  8233  23548  protein_protein              ppi
3  4899  11253  protein_protein              ppi
4  5297   8601  protein_protein              ppi


In [14]:
# Summary statistics
print(f"\nSummary Statistics:")
print(f"Original relations: {len(df)}")
print(f"Unique edges after deduplication: {len(edges)}")
print(f"Unique nodes: {len(nodes)}")
print(f"Average degree per node: {2 * len(edges) / len(nodes):.2f}")


Summary Statistics:
Original relations: 8100498
Unique edges after deduplication: 8097947
Unique nodes: 135010
Average degree per node: 119.96


In [15]:
# Node type breakdown
print(f"\nNode Type Breakdown:")
for node_type in nodes['type'].unique():
    count = len(nodes[nodes['type'] == node_type])
    print(f"{node_type}: {count} nodes")

# Relation type breakdown
print(f"\n=== Relation Type Breakdown ===")
for relation in edges['relation'].unique():
    count = len(edges[edges['relation'] == relation])
    print(f"{relation}: {count} edges")


Node Type Breakdown:
gene/protein: 44553 nodes
drug: 7957 nodes
effect/phenotype: 12393 nodes
disease: 9846 nodes
biological_process: 35362 nodes
molecular_function: 8814 nodes
cellular_component: 3344 nodes
exposure: 818 nodes
pathway: 2516 nodes
anatomy: 9407 nodes

=== Relation Type Breakdown ===
protein_protein: 642150 edges
drug_protein: 51306 edges
contraindication: 61350 edges
indication: 18776 edges
off-label use: 5136 edges
drug_drug: 2672628 edges
phenotype_protein: 6660 edges
phenotype_phenotype: 37472 edges
disease_phenotype_negative: 2386 edges
disease_phenotype_positive: 300634 edges
disease_protein: 160819 edges
disease_disease: 64388 edges
drug_effect: 129568 edges
bioprocess_bioprocess: 105772 edges
molfunc_molfunc: 27148 edges
cellcomp_cellcomp: 9690 edges
molfunc_protein: 139053 edges
cellcomp_protein: 166782 edges
bioprocess_protein: 289603 edges
exposure_protein: 2424 edges
exposure_disease: 4608 edges
exposure_exposure: 4140 edges
exposure_bioprocess: 3250 edges


In [16]:
print(f"Number of rows in nodes: {nodes.shape[0]}")
print(f"Number of rows in edges: {edges.shape[0]}")

Number of rows in nodes: 135010
Number of rows in edges: 8097947


In [17]:
display(nodes.head())
display(edges.head())

Unnamed: 0,id,name,type,source
0,9796,PHYHIP,gene/protein,NCBI
1,7918,GPANK1,gene/protein,NCBI
2,8233,ZRSR2,gene/protein,NCBI
3,4899,NRF1,gene/protein,NCBI
4,5297,PI4KA,gene/protein,NCBI


Unnamed: 0,x_id,y_id,relation,display_relation
0,9796,56992,protein_protein,ppi
1,7918,9240,protein_protein,ppi
2,8233,23548,protein_protein,ppi
3,4899,11253,protein_protein,ppi
4,5297,8601,protein_protein,ppi


In [18]:
dupes = nodes[nodes.duplicated(subset='id', keep=False)]
display(dupes)

Unnamed: 0,id,name,type,source


In [19]:
# Building the knowledge graph
def build_knowledge_graph(nodes_df, edges_df):

    # Building a NetworkX graph from nodes and edges DataFrames
    G = nx.Graph()

    # Adding nodes with attributes
    for _, node in tqdm(nodes_df.iterrows(), total=len(nodes_df), desc="Nodes"):
        G.add_node(
            node['id'],
            name=node['name'],
            type=node['type'],
            source=node['source']
        )

    # Adding edges with attributes
    for _, edge in tqdm(edges_df.iterrows(), total=len(edges_df), desc="Edges"):
        if edge['x_id'] in G.nodes and edge['y_id'] in G.nodes:
            G.add_edge(
                edge['x_id'],
                edge['y_id'],
                relation=edge['relation'],
                display_relation=edge['display_relation']
            )

    return G

kg_graph = build_knowledge_graph(nodes, edges)

Nodes: 100%|██████████| 135010/135010 [00:06<00:00, 21799.24it/s]
Edges: 100%|██████████| 8097947/8097947 [08:08<00:00, 16580.39it/s]


In [20]:
# Saving the graph for later use
with open('primekg_graph.pkl', 'wb') as f:
    pickle.dump(kg_graph, f)
print("Graph saved as 'primekg_graph.pkl'")

Graph saved as 'primekg_graph.pkl'


In [21]:
# Graph statistics
print("\nKnowledge Graph Statistics:")
print(f"Number of nodes: {kg_graph.number_of_nodes():,}")
print(f"Number of edges: {kg_graph.number_of_edges():,}")
print(f"Is connected: {nx.is_connected(kg_graph)}")
print(f"Number of connected components: {nx.number_connected_components(kg_graph)}")


Knowledge Graph Statistics:
Number of nodes: 135,010
Number of edges: 4,395,511
Is connected: False
Number of connected components: 42


In [22]:
# Finding all connected components from your existing graph
components = list(nx.connected_components(kg_graph))

# Sorting components by size (largest first)
components_sorted = sorted(components, key=len, reverse=True)

# Getting component sizes
component_sizes = [len(comp) for comp in components_sorted]

print(f"Total connected components: {len(components)}")
print(f"Total nodes in graph: {kg_graph.number_of_nodes()}")

Total connected components: 42
Total nodes in graph: 135010


In [23]:
# Top 10 component sizes
print("Top 10 component sizes:")
for i, size in enumerate(component_sizes[:10]):
    percentage = (size / kg_graph.number_of_nodes()) * 100
    print(f"Component {i+1}: {size:,} nodes ({percentage:.2f}%)")

print()

Top 10 component sizes:
Component 1: 134,877 nodes (99.90%)
Component 2: 8 nodes (0.01%)
Component 3: 7 nodes (0.01%)
Component 4: 6 nodes (0.00%)
Component 5: 5 nodes (0.00%)
Component 6: 5 nodes (0.00%)
Component 7: 5 nodes (0.00%)
Component 8: 5 nodes (0.00%)
Component 9: 4 nodes (0.00%)
Component 10: 4 nodes (0.00%)



In [24]:
# Getting the largest component
largest_component = components_sorted[0]
print(f"Largest component has {len(largest_component):,} nodes")
print(f"This represents {(len(largest_component)/kg_graph.number_of_nodes())*100:.2f}% of all nodes")

Largest component has 134,877 nodes
This represents 99.90% of all nodes


In [25]:
# Showing some example nodes from the largest component with their attributes
print(f"\nSample nodes from largest component:")
sample_nodes = list(largest_component)[:10]
for node in sample_nodes:
    node_data = kg_graph.nodes[node]
    print(f"  - {node}: {node_data['name']} ({node_data['type']})")


Sample nodes from largest component:
  - 1: A1BG (gene/protein)
  - 2: A2M (gene/protein)
  - 3: naris (anatomy)
  - 84695: LOXL3 (gene/protein)
  - 4: nose (anatomy)
  - 6: islet of Langerhans (anatomy)
  - 7: pituitary gland (anatomy)
  - 8: Abnormal morphology of female internal genitalia (effect/phenotype)
  - 9: NAT1 (gene/protein)
  - 10: NAT2 (gene/protein)


In [26]:
# Creating subgraph of just the largest component
largest_subgraph = kg_graph.subgraph(largest_component)

print(f"\nLargest component statistics:")
print(f"Nodes: {largest_subgraph.number_of_nodes():,}")
print(f"Edges: {largest_subgraph.number_of_edges():,}")
print(f"Density: {nx.density(largest_subgraph):.6f}")


Largest component statistics:
Nodes: 134,877
Edges: 4,395,416
Density: 0.000483


In [27]:
# Showing distribution of smaller components
if len(components) > 1:
    small_components = component_sizes[1:]
    print(f"\nSmaller components:")
    print(f"Number of small components: {len(small_components)}")
    print(f"Largest small component: {max(small_components)} nodes")
    print(f"Smallest components: {min(small_components)} nodes")
    print(f"Average size of small components: {sum(small_components)/len(small_components):.1f} nodes")

    # Showing details about the top 5 smaller components
    print(f"\nTop 5 smaller components details:")
    for i in range(1, min(6, len(components_sorted))):
        comp = components_sorted[i]
        print(f"Component {i+1} ({len(comp)} nodes):")
        sample_nodes = list(comp)[:5]
        for node in sample_nodes:
            node_data = kg_graph.nodes[node]
            print(f"    - {node_data['name']} ({node_data['type']})")
        if len(comp) > 5:
            print(f"    ... and {len(comp) - 5} more nodes")
        print()


Smaller components:
Number of small components: 41
Largest small component: 8 nodes
Smallest components: 2 nodes
Average size of small components: 3.2 nodes

Top 5 smaller components details:
Component 2 (8 nodes):
    - suppression by virus of host cell cycle arrest (biological_process)
    - suppression by virus of host exit from mitosis (biological_process)
    - suppression by virus of G2/M transition of host mitotic cell cycle (biological_process)
    - modulation by virus of host cell cycle (biological_process)
    - modification by virus of host cell cycle regulation (biological_process)
    ... and 3 more nodes

Component 3 (7 nodes):
    - rolling circle viral DNA replication (biological_process)
    - viral DNA genome replication (biological_process)
    - rolling hairpin viral DNA replication (biological_process)
    - bidirectional double-stranded viral DNA replication (biological_process)
    - viral DNA strand displacement replication (biological_process)
    ... and 2 m

In [28]:
# Analyzing node types
node_types = [kg_graph.nodes[node]['type'] for node in kg_graph.nodes()]
type_counts = Counter(node_types)
print(f"\nNode type distribution:")
for node_type, count in type_counts.most_common():
    print(f"  {node_type}: {count:,}")


Node type distribution:
  gene/protein: 44,553
  biological_process: 35,362
  effect/phenotype: 12,393
  disease: 9,846
  anatomy: 9,407
  molecular_function: 8,814
  drug: 7,957
  cellular_component: 3,344
  pathway: 2,516
  exposure: 818


In [29]:
# Analyzing edge relations
edge_relations = [kg_graph.edges[edge]['relation'] for edge in kg_graph.edges()]
relation_counts = Counter(edge_relations)
print(f"\nTop 10 relation types:")
for relation, count in relation_counts.most_common(10):
    print(f"  {relation}: {count:,}")


Top 10 relation types:
  anatomy_protein_present: 1,534,770
  drug_drug: 1,336,314
  protein_protein: 374,527
  disease_phenotype_positive: 210,751
  bioprocess_protein: 182,933
  disease_protein: 145,895
  drug_effect: 91,605
  cellcomp_protein: 83,069
  pathway_protein: 76,236
  bioprocess_bioprocess: 72,296


In [30]:
# Calculating basic network metrics
print(f"Average degree: {np.mean([d for n, d in kg_graph.degree()]):.2f}")
print(f"Network density: {nx.density(kg_graph):.6f}")

# Finding highly connected nodes (hubs)
degree_centrality = nx.degree_centrality(kg_graph)
top_hubs = sorted(degree_centrality.items(), key=lambda x: x[1], reverse=True)[:10]
print(f"\nTop 10 most connected nodes:")
for node_id, centrality in top_hubs:
    node_name = kg_graph.nodes[node_id]['name']
    node_type = kg_graph.nodes[node_id]['type']
    degree = kg_graph.degree[node_id]
    print(f"  {node_name} ({node_type}): {degree} connections")

Average degree: 65.11
Network density: 0.000482

Top 10 most connected nodes:
  ATF4 (gene/protein): 17460 connections
  ETFA (gene/protein): 16944 connections
  RERE (gene/protein): 16839 connections
  AP2A1 (gene/protein): 16833 connections
  KRT83 (gene/protein): 16786 connections
  prostate gland (anatomy): 16649 connections
  spleen (anatomy): 16601 connections
  CD33 (gene/protein): 16544 connections
  CD52 (gene/protein): 16539 connections
  dorsolateral prefrontal cortex (anatomy): 16515 connections


In [31]:
# Function to query the graph
def query_node(graph, node_id):
    """Query information about a specific node"""
    if node_id not in graph.nodes:
        print(f"Node {node_id} not found in graph")
        return None

    node_data = graph.nodes[node_id]
    neighbors = list(graph.neighbors(node_id))

    print(f"\nNode Information:")
    print(f"  ID: {node_id}")
    print(f"  Name: {node_data['name']}")
    print(f"  Type: {node_data['type']}")
    print(f"  Source: {node_data['source']}")
    print(f"  Degree: {graph.degree[node_id]}")
    print(f"  Number of neighbors: {len(neighbors)}")

    return {
        'id': node_id,
        'data': node_data,
        'neighbors': neighbors,
        'degree': graph.degree[node_id]
    }

# Function to find shortest path between nodes
def find_path(graph, source, target, max_length=5):
    """Find shortest path between two nodes"""
    try:
        path = nx.shortest_path(graph, source, target)
        if len(path) <= max_length:
            print(f"\nShortest path from {source} to {target}:")
            for i, node in enumerate(path):
                node_name = graph.nodes[node]['name']
                node_type = graph.nodes[node]['type']
                print(f"  {i+1}. {node_name} ({node_type})")
                if i < len(path) - 1:
                    edge_data = graph.edges[node, path[i+1]]
                    print(f"     --[{edge_data['display_relation']}]-->")
            return path
        else:
            print(f"Path too long ({len(path)} steps)")
            return None
    except nx.NetworkXNoPath:
        print(f"No path found between {source} and {target}")
        return None

# Function to get subgraph around a node
def get_subgraph(graph, center_node, radius=1):
    """Extract subgraph around a center node within given radius"""
    if center_node not in graph.nodes:
        print(f"\nNode {center_node} not found")
        return None

    # Getting all nodes within radius
    nodes_in_radius = set([center_node])
    current_nodes = set([center_node])

    for _ in range(radius):
        next_nodes = set()
        for node in current_nodes:
            next_nodes.update(graph.neighbors(node))
        nodes_in_radius.update(next_nodes)
        current_nodes = next_nodes

    # Creating subgraph
    subgraph = graph.subgraph(nodes_in_radius).copy()
    print(f"\nSubgraph around {center_node} (radius {radius}):")
    print(f"  Nodes: {subgraph.number_of_nodes()}")
    print(f"  Edges: {subgraph.number_of_edges()}")

    return subgraph

# Function to search nodes by name or type
def search_nodes(graph, query, search_type='name', limit=10):
    """Search for nodes by name or type"""
    results = []

    for node_id in graph.nodes():
        node_data = graph.nodes[node_id]
        if search_type == 'name':
            if query.lower() in node_data['name'].lower():
                results.append((node_id, node_data))
        elif search_type == 'type':
            if query.lower() in node_data['type'].lower():
                results.append((node_id, node_data))

    print(f"\nSearch results for '{query}' in {search_type} (showing first {limit}):")
    for i, (node_id, node_data) in enumerate(results[:limit]):
        print(f"  {i+1}. {node_data['name']} ({node_data['type']}) - ID: {node_id}")

    return results[:limit]

# Example usage functions
# Searching for diabetes-related nodes
diabetes_nodes = search_nodes(kg_graph, 'diabetes', 'name', 5)

# If diabetes nodes is found, explore one
if diabetes_nodes:
    example_node = diabetes_nodes[0][0]  # Get first result's ID
    query_node(kg_graph, example_node)

    # Getting subgraph around this node
    subgraph = get_subgraph(kg_graph, example_node, radius=2)


Search results for 'diabetes' in name (showing first 5):
  1. maturity-onset diabetes of the young (disease) - ID: 14589_18911_10894_11668_12818_7452_13240_12348_12513_13242_7453_14674_11667
  2. permanent neonatal diabetes mellitus (disease) - ID: 100164_16391
  3. neonatal diabetes mellitus (disease) - ID: 20525_12522_30089_30088_11073_12480_30087_100165
  4. X-linked intellectual disability-limb spasticity-retinal dystrophy-diabetes insipidus syndrome (disease) - ID: 18495
  5. type 2 diabetes mellitus (disease) - ID: 5148

Node Information:
  ID: 14589_18911_10894_11668_12818_7452_13240_12348_12513_13242_7453_14674_11667
  Name: maturity-onset diabetes of the young
  Type: disease
  Source: MONDO_grouped
  Degree: 69
  Number of neighbors: 69

Subgraph around 14589_18911_10894_11668_12818_7452_13240_12348_12513_13242_7453_14674_11667 (radius 2):
  Nodes: 14006
  Edges: 1337988


In [36]:
from pyvis.network import Network

# Create a PyVis network
net = Network(height='600px', width='100%', bgcolor='#222222', font_color='white', notebook=True)

# Set static layout
net.set_options("""
var options = {
  "physics": {
    "enabled": false
  },
  "layout": {
    "improvedLayout": true
  },
  "nodes": {
    "shape": "dot",
    "size": 10,
    "font": {
      "size": 12,
      "color": "white"
    }
  },
  "edges": {
    "color": {
      "inherit": true
    },
    "smooth": false
  }
}
""")

# Select 200 nodes
sample_nodes = list(kg_graph.nodes)[:200]
subgraph = kg_graph.subgraph(sample_nodes)

# Add nodes with name as label
for node, attrs in subgraph.nodes(data=True):
    net.add_node(
        node,
        label=attrs.get('name', str(node)),
        title=f"Type: {attrs.get('type', 'N/A')}, Source: {attrs.get('source', 'N/A')}"
    )

# Add edges
for source, target, attrs in subgraph.edges(data=True):
    net.add_edge(source, target, title=attrs.get('relation', ''))

# Generate the graph
net.show("kg_graph_named.html")


kg_graph_named.html
