In [1]:
import networkx as nx
from pyvis.network import Network
import json

def create_enhanced_kg_visualization(
    kg_data,
    output_file="enhanced_kg_visualization.html",
    center_id='Q199804',
    center_label="ASTHMA",
    hops=3,
):
    """
    Create an interactive visualization of the enhanced knowledge graph.

    Only the neighborhood of ``center_id`` (edges followed in both
    directions, up to ``hops`` steps) is drawn. Nodes are colored by their
    first "instance of" (P31) claim, and the center node is highlighted.

    Args:
        kg_data: Knowledge-graph dict. The code reads
            ``kg_data['subgraph']['triples']`` (dicts with ``source``,
            ``predicate`` and ``target`` entries that each carry an ``id``;
            optionally ``qualifiers``), ``kg_data['subgraph']['entities']``
            (entity id -> dict with optional ``metadata`` and ``data``) and,
            if present, ``kg_data['metadata']['medical_properties']``
            (property id -> display name).
        output_file: Path of the HTML file to write.
        center_id: Id of the entity the neighborhood is built around.
        center_label: Label shown on the highlighted center node.
        hops: Number of expansion steps around the center node.

    Returns:
        None. The graph is written to ``output_file`` and a short summary is
        printed.
    """
    # Create a new network
    net = Network(height="750px", width="100%", bgcolor="#ffffff", font_color="black")
    net.force_atlas_2based()

    subgraph = kg_data['subgraph']
    all_triples = subgraph['triples']
    entities = subgraph['entities']
    # Optional mapping; a graph without it simply gets no "Properties" section.
    medical_properties = kg_data.get('metadata', {}).get('medical_properties', {})

    # Track added nodes to avoid duplicates
    added_nodes = set()

    # Display names for well-known properties (built once, not per edge).
    property_labels = {
        'P780': 'has symptom',
        'P828': 'has cause',
        'P927': 'located in',
        'P2176': 'treated by',
        'P1050': 'medical condition',
        'P2175': 'treats',
        'P31': 'is a',
        'P279': 'subclass of',
    }

    # Edge colors per relation family; anything else uses the default gray.
    default_edge_color = '#7c7c7c'
    edge_colors = {
        'P780': '#ff9999',   # symptoms and causes
        'P828': '#ff9999',
        'P2176': '#99ccff',  # treatments
        'P2175': '#99ccff',
        'P31': '#c2c2c2',    # classification
        'P279': '#c2c2c2',
    }

    def entity_label(entity_id):
        """Return the metadata label of an entity, or its id if unknown."""
        return entities.get(entity_id, {}).get('metadata', {}).get('label', entity_id)

    def claim_entity_id(claim):
        """Return the entity id a claim points to, or None if it has none."""
        value = claim.get('mainsnak', {}).get('datavalue', {}).get('value')
        if isinstance(value, dict):
            return value.get('id')
        return None

    def analyze_entity_types():
        """
        Count the targets of all P31/P279 triples and assign each a color.

        Returns a dict mapping type id -> color, plus a 'default' entry.
        """
        entity_types = {}
        type_counts = {}

        # Extract all instance_of (P31) and subclass_of (P279) relations
        for triple in all_triples:
            if triple['predicate']['id'] not in ('P31', 'P279'):
                continue
            target_id = triple['target']['id']
            if target_id not in entity_types:
                entity_types[target_id] = entity_label(target_id)
                type_counts[target_id] = 0
            type_counts[target_id] += 1

        # Sort types by frequency (ties keep first-seen order)
        sorted_types = sorted(type_counts.items(), key=lambda item: item[1], reverse=True)

        # Define a broader color palette
        colors = [
            '#ff6b6b',  # Red
            '#4ecdc4',  # Turquoise
            '#45b7d1',  # Blue
            '#96ceb4',  # Soft Green
            '#88d8b0',  # Light Green
            '#ffeead',  # Light Yellow
            '#ff9f80',  # Coral
            '#d4a5a5',  # Dusty Rose
            '#82b74b',  # Green
            '#9b59b6',  # Purple
            '#3498db',  # Blue
            '#e74c3c',  # Red
            '#2ecc71',  # Green
            '#f1c40f',  # Yellow
            '#1abc9c',  # Turquoise
            '#e67e22',  # Orange
            '#95a5a6',  # Gray
            '#16a085',  # Dark Turquoise
            '#d35400',  # Dark Orange
            '#c0392b',  # Dark Red
            '#a3a3a3',  # Light Gray
            '#7f8c8d',  # Dark Gray
        ]

        # Create color mapping
        color_by_type = {'default': '#7f7f7f'}
        print("\nEntity Types in Knowledge Graph:")
        for i, (type_id, count) in enumerate(sorted_types):
            # Types beyond the palette size all share the last color.
            color = colors[min(i, len(colors) - 1)]
            color_by_type[type_id] = color
            print(f"{entity_types[type_id]} ({type_id}): {count} instances -> {color}")

        return color_by_type

    # Generate color scheme based on actual data
    type_colors = analyze_entity_types()

    def get_entity_type(entity_data):
        """Return the id of the entity's first P31 claim, or 'default'."""
        claims = (entity_data or {}).get('data', {}).get('claims', {})
        instance_claims = claims.get('P31', [])
        if instance_claims:
            return claim_entity_id(instance_claims[0]) or 'default'
        return 'default'

    def create_tooltip(entity_data):
        """Create tooltip with enhanced entity information"""
        if not entity_data:
            return "No information available"

        parts = ["<div style='max-width: 300px;'>"]

        # Get metadata from enhanced structure
        metadata = entity_data.get('metadata', {})
        label = metadata.get('label', 'Unknown')
        description = metadata.get('description', '')

        parts.append(f"<strong>{label}</strong>")
        if description:
            parts.append(f"<br><em>{description}</em>")

        claims = entity_data.get('data', {}).get('claims')
        if claims is not None:
            type_claims = claims.get('P31', [])
            subclass_claims = claims.get('P279', [])

            # Add entity type information
            if type_claims or subclass_claims:
                parts.append("<br><br><strong>Classifications:</strong>")

            # Add instance of (P31) relations, skipping claims without a target
            for claim in type_claims:
                type_id = claim_entity_id(claim)
                if type_id is not None:
                    parts.append(f"<br>• Type: {entity_label(type_id)}")

            # Add subclass of (P279) relations, skipping claims without a target
            for claim in subclass_claims:
                superclass_id = claim_entity_id(claim)
                if superclass_id is not None:
                    parts.append(f"<br>• Subclass of: {entity_label(superclass_id)}")

            # List the medical properties this entity has claims for
            properties = [
                f"<br>• {medical_properties[prop_id]}"
                for prop_id in claims
                if prop_id in medical_properties
            ]
            if properties:
                parts.append("<br><br><strong>Properties:</strong>")
                parts.extend(properties)

        parts.append("</div>")
        return ''.join(parts)

    # Build an undirected adjacency map once so each hop only touches the
    # current frontier instead of rescanning every triple.
    neighbors = {}
    for triple in all_triples:
        source_id = triple['source']['id']
        target_id = triple['target']['id']
        neighbors.setdefault(source_id, set()).add(target_id)
        neighbors.setdefault(target_id, set()).add(source_id)

    # Expand the neighborhood of the center node hop by hop
    nodes_in_scope = {center_id}
    frontier = {center_id}
    for _hop in range(hops):
        reached = set()
        for node in frontier:
            reached.update(neighbors.get(node, ()))
        frontier = reached - nodes_in_scope
        if not frontier:
            break
        nodes_in_scope.update(frontier)

    # Filter triples to only those involving nodes in scope
    filtered_triples = [
        triple for triple in all_triples
        if triple['source']['id'] in nodes_in_scope
        and triple['target']['id'] in nodes_in_scope
    ]

    for triple in filtered_triples:
        # Process nodes
        for endpoint in ('source', 'target'):
            entity_id = triple[endpoint]['id']
            if entity_id in added_nodes:
                continue
            entity_data = entities.get(entity_id, {})

            if entity_id == center_id:
                # Special styling for the center node
                net.add_node(
                    entity_id,
                    label=center_label,
                    title=create_tooltip(entity_data),
                    color='#ff0000',
                    size=100,
                    borderWidth=4,
                    borderWidthSelected=8,
                    font={'size': 20, 'color': 'black', 'face': 'arial', 'bold': True}
                )
            else:
                # Regular node
                net.add_node(
                    entity_id,
                    label=entity_data.get('metadata', {}).get('label', entity_id),
                    title=create_tooltip(entity_data),
                    color=type_colors.get(get_entity_type(entity_data), type_colors['default']),
                    size=30
                )
            added_nodes.add(entity_id)

        # Add edge with enhanced information
        predicate = triple['predicate']
        predicate_id = predicate['id']
        if predicate_id in property_labels:
            edge_label = property_labels[predicate_id]
        else:
            # Fall back to the data's own label, then to the bare id.
            edge_label = predicate.get('label', predicate_id)

        # Create edge tooltip with qualifiers
        edge_tooltip = f"Relationship: {edge_label}"
        if triple.get('qualifiers'):
            edge_tooltip += "\nQualifiers:"
            for qualifier in triple['qualifiers']:
                edge_tooltip += f"\n• {qualifier['property']}: {qualifier['value']}"

        net.add_edge(
            triple['source']['id'],
            triple['target']['id'],
            label=edge_label,
            title=edge_tooltip,
            color=edge_colors.get(predicate_id, default_edge_color),
            arrows={'to': {'enabled': True}}
        )

    # Configure physics settings
    net.set_options("""
    var options = {
        "physics": {
            "forceAtlas2Based": {
                "gravitationalConstant": -50,
                "centralGravity": 0.01,
                "springLength": 200,
                "springConstant": 0.08
            },
            "maxVelocity": 50,
            "solver": "forceAtlas2Based",
            "timestep": 0.35,
            "stabilization": {"iterations": 150}
        },
        "edges": {
            "color": {"inherit": false},
            "smooth": {"type": "continuous"},
            "length": 200,
            "font": {
                "size": 10
            }
        },
        "nodes": {
            "font": {
                "size": 12,
                "face": "arial"
            },
            "borderWidth": 2,
            "borderWidthSelected": 4
        }
    }
    """)

    # Save the visualization
    net.save_graph(output_file)
    print(f"Visualization saved to {output_file}")
    print(f"Number of nodes in visualization: {len(added_nodes)}")
    # Report the edges the network actually holds: pyvis may skip an edge
    # when one already exists between the same pair of nodes, so the number
    # of filtered triples can overstate what is drawn.
    print(f"Number of edges in visualization: {len(net.edges)}")

def main():
    """Load the enhanced knowledge graph from JSON and render it to HTML."""
    # Load the enhanced knowledge graph.
    # The encoding is given explicitly so that non-ASCII labels decode the
    # same way regardless of the platform's default encoding.
    # NOTE(review): the input file is named after COPD while the rendered
    # graph is centered on asthma by default - confirm this is intended.
    with open('enhanced_copd_kg_4hops.json', 'r', encoding='utf-8') as f:
        kg_data = json.load(f)

    # Create visualization
    create_enhanced_kg_visualization(
        kg_data=kg_data,
        output_file="enhanced_asthma_knowledge_graph.html"
    )

# Run the visualization only when this code is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()


Entity Types in Knowledge Graph:
gene (Q7187): 697 instances -> #ff6b6b
protein-coding gene (Q20747295): 686 instances -> #4ecdc4
type of chemical entity (Q113145171): 274 instances -> #45b7d1
chemical compound (Q11173): 140 instances -> #96ceb4
class of disease (Q112193867): 118 instances -> #88d8b0
disease (Q12136): 81 instances -> #ffeead
symptom or sign (Q112965645): 59 instances -> #ff9f80
medication (Q12140): 44 instances -> #d4a5a5
group of stereoisomers (Q59199015): 38 instances -> #82b74b
clinical sign (Q1441305): 31 instances -> #9b59b6
symptom (Q169872): 26 instances -> #3498db
mixture (Q169336): 24 instances -> #e74c3c
Q193430 (Q193430): 23 instances -> #2ecc71
Q79529 (Q79529): 21 instances -> #f1c40f
medical specialty (Q930752): 17 instances -> #1abc9c
Q427087 (Q427087): 15 instances -> #e67e22
academic discipline (Q11862829): 14 instances -> #95a5a6
metaclass (Q19478619): 14 instances -> #16a085
rare disease (Q929833): 13 instances -> #d35400
physiological condition (Q71