In [7]:
"""
Simple Knowledge Graph Query Tools
Two functions: Get Gene ID + Find Diseases
"""

import json
from typing import Dict, List, Optional

import requests


class KnowledgeGraphTools:
    """Small client for resolving gene symbols and querying gene-disease links.

    Gene symbols are resolved through the NCBI E-utilities ``esearch``
    endpoint; disease associations are requested from the TRAPI ``/query``
    endpoint under ``base_url``.
    """

    def __init__(self):
        # Base URL of the TRAPI service; "/query" is appended per request.
        self.base_url = "https://api.bte.ncats.io/v1"
        # NCBI E-utilities search endpoint used for the symbol -> ID lookup.
        self.ncbi_url = "https://eutils.ncbi.nlm.nih.gov/entrez/eutils/esearch.fcgi"

    # FUNCTION 1: Get Gene ID
    def get_gene_id(self, gene_symbol: str) -> Optional[str]:
        """Look up the Entrez Gene ID for a human gene symbol.

        Args:
            gene_symbol: Gene symbol to search for, e.g. ``"APOE"``.

        Returns:
            The first ID in the search result's ``idlist``, or ``None`` when
            the symbol is empty, nothing matched, or the lookup failed. A
            failed lookup is printed rather than raised.
        """
        if not gene_symbol:
            # An empty term would turn the search into an unconstrained query.
            return None

        params = {
            'db': 'gene',
            'term': f"{gene_symbol}[Gene Name] AND human[Organism]",
            'retmode': 'json'
        }

        try:
            response = requests.get(self.ncbi_url, params=params, timeout=10)
            # Report HTTP failures explicitly; otherwise an error response
            # carrying a JSON body would be mistaken for "no match".
            response.raise_for_status()
            data = response.json()
            ids = (data.get('esearchresult') or {}).get('idlist') or []
            return ids[0] if ids else None

        # Deliberately broad: callers rely on this method never raising.
        except Exception as e:
            print(f"Error looking up {gene_symbol}: {e}")
            return None

    # FUNCTION 2: Find Diseases
    def find_diseases(self, gene_symbol: str, gene_id: str, max_results: int = 10) -> Dict:
        """Find diseases associated with a gene.

        Args:
            gene_symbol: Gene symbol, echoed back in the returned dict.
            gene_id: Gene identifier; sent as the CURIE ``NCBIGene:<gene_id>``.
            max_results: Maximum number of results inspected in the response
                (defaults to 10).

        Returns:
            On success, a dict with ``gene_symbol``, ``gene_id``,
            ``diseases`` (list of ``{"id", "name"}`` dicts), ``count`` and
            ``status == "success"``. On failure, a dict with ``gene_symbol``,
            ``gene_id``, ``error`` and ``status == "error"``. This method
            does not raise.
        """
        if not gene_id:
            # Without an ID the query would ask for the node "NCBIGene:None".
            return self._error_result(gene_symbol, gene_id, "Missing gene ID")

        query = {
            "message": {
                "query_graph": {
                    "nodes": {
                        "gene": {
                            "ids": [f"NCBIGene:{gene_id}"],
                            "categories": ["biolink:Gene"]
                        },
                        "disease": {
                            "categories": ["biolink:Disease"]
                        }
                    },
                    "edges": {
                        "e": {
                            "subject": "gene",
                            "object": "disease",
                            "predicates": ["biolink:related_to"]
                        }
                    }
                }
            }
        }

        try:
            response = requests.post(f"{self.base_url}/query", json=query, timeout=120)

            if response.status_code != 200:
                return self._error_result(
                    gene_symbol, gene_id, f"API error: {response.status_code}"
                )

            diseases = self._parse_results(response.json(), limit=max_results)
            return {
                "gene_symbol": gene_symbol,
                "gene_id": gene_id,
                "diseases": diseases,
                "count": len(diseases),
                "status": "success"
            }

        # Deliberately broad: every failure is reported through "status".
        except Exception as e:
            return self._error_result(gene_symbol, gene_id, str(e))

    @staticmethod
    def _error_result(gene_symbol: str, gene_id: Optional[str], message: str) -> Dict:
        """Build the error-shaped dict returned by ``find_diseases``."""
        return {
            "gene_symbol": gene_symbol,
            "gene_id": gene_id,
            "error": message,
            "status": "error"
        }

    def _parse_results(self, data: Dict, limit: int = 10) -> List[Dict]:
        """Extract the disease ID + name from a TRAPI response.

        Args:
            data: Decoded response body; reads ``message.results`` and
                ``message.knowledge_graph.nodes``.
            limit: Number of leading results to inspect (defaults to 10).

        Returns:
            One ``{"id", "name"}`` dict per distinct disease CURIE, in the
            order first seen. ``name`` falls back to the CURIE when the
            knowledge graph has no (or an empty) name for that node.
        """
        # Any of these levels may be missing or null in a response, so each
        # lookup falls back to an empty container instead of failing.
        message = data.get("message") or {}
        results = message.get("results") or []
        kg_nodes = (message.get("knowledge_graph") or {}).get("nodes") or {}

        diseases = []
        seen = set()
        for result in results[:limit]:
            disease_nodes = (result.get("node_bindings") or {}).get("disease") or []

            for node in disease_nodes:
                curie = node.get("id", "Unknown")
                if curie in seen:
                    # Several results can bind the same disease node.
                    continue
                seen.add(curie)

                name = (kg_nodes.get(curie) or {}).get("name") or curie
                diseases.append({
                    "id": curie,
                    "name": name
                })

        return diseases


"""
Test with single gene - Validation below
"""

# Build the query helper used for the whole walkthrough
kg = KnowledgeGraphTools()

# Gene symbol exercised by this smoke test
test_gene = "APOE"

print(f"Testing with gene: {test_gene}")
print("=" * 60)

# Step 1: resolve the symbol to a gene ID
print("\nStep 1: Looking up gene ID...")

gene_id = kg.get_gene_id(test_gene)

if not gene_id:
    # Nothing to query without an ID, so report and stop here
    print(f"✗ Could not find gene ID for {test_gene}")
else:
    print(f"✓ Found ID: {gene_id}")

    # Step 2: ask the knowledge graph for diseases linked to that ID
    print("\nStep 2: Querying Knowledge Graph...")
    result = kg.find_diseases(test_gene, gene_id)

    # Dump the raw result dict first, then a readable summary on success
    print("\nResults:")
    print(json.dumps(result, indent=2))

    if result['status'] == 'success':
        print(f"\n✓ Found {result['count']} disease associations")
        print("\nDiseases:")
        for i, disease in enumerate(result['diseases'], start=1):
            print(f"  {i}. {disease['name']} ({disease['id']})")


Testing with gene: APOE

Step 1: Looking up gene ID...
✓ Found ID: 348

Step 2: Querying Knowledge Graph...

Results:
{
  "gene_symbol": "APOE",
  "gene_id": "348",
  "diseases": [
    {
      "id": "MONDO:0004975",
      "name": "Alzheimer disease"
    },
    {
      "id": "MONDO:0010017",
      "name": "sea-blue histiocyte syndrome"
    },
    {
      "id": "MONDO:0018473",
      "name": "hyperlipoproteinemia type 3"
    },
    {
      "id": "MONDO:0012725",
      "name": "lipoprotein glomerulopathy"
    },
    {
      "id": "MONDO:0007089",
      "name": "Alzheimer disease 2"
    },
    {
      "id": "MONDO:0005044",
      "name": "hypertension"
    },
    {
      "id": "MONDO:0005150",
      "name": "age-related macular degeneration"
    },
    {
      "id": "MONDO:0005311",
      "name": "atherosclerosis"
    },
    {
      "id": "MONDO:0011743",
      "name": "Alzheimer disease 4"
    },
    {
      "id": "MONDO:0005620",
      "name": "cerebral amyloid angiopathy"
    }
  ],
  "