## Setting up the environment

In [1]:
!pip uninstall torch torch-geometric torch-scatter torch-sparse torch-cluster torch-spline-conv pyg-lib -y
!pip install torch==2.5.0+cu124 torchvision torchaudio --index-url https://download.pytorch.org/whl/cu124
!pip install torch-geometric --no-cache-dir
!pip install torch-scatter torch-sparse torch-cluster torch-spline-conv pyg-lib --no-cache-dir --find-links https://data.pyg.org/whl/torch-2.5.0+cu124.html
!pip install --upgrade transformers huggingface_hub accelerate
!pip install -U bitsandbytes
!pip install pyvis
!pip install -q gradio

Found existing installation: torch 2.6.0+cu124
Uninstalling torch-2.6.0+cu124:
  Successfully uninstalled torch-2.6.0+cu124
Looking in indexes: https://download.pytorch.org/whl/cu124
Collecting torch==2.5.0+cu124
  Downloading https://download.pytorch.org/whl/cu124/torch-2.5.0%2Bcu124-cp311-cp311-linux_x86_64.whl (908.3 MB)
Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch==2.5.0+cu124)
  Downloading https://download.pytorch.org/whl/cu124/nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl (24.6 MB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch==2.5.0+cu124)
  Downloading https://download.pytorch.org/whl/cu124/nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl (883 kB)

## Importing all the libraries

In [1]:
import os
import torch
import torch.nn.functional as F
import pickle
import gradio as gr
import pandas as pd
import networkx as nx
from pyvis.network import Network
from torch_geometric.data import HeteroData
from torch_geometric.nn import RGCNConv
from torch_geometric.explain import Explainer, GNNExplainer
from transformers import AutoTokenizer, AutoModelForCausalLM
from IPython.display import IFrame

## Loading LLaMA 3.1 8B Instruct Model (Meta)
- Loads the Meta LLaMA 3.1 8B Instruct model with Hugging Face’s transformers library.

- Requires a valid Hugging Face access token, because the model is gated.

- Loads the weights in float16 precision with automatic device mapping for GPU acceleration.

- The model is used to generate natural language explanations of biomedical graph predictions.

In [None]:
# Hugging Face token (for Colab runtime only); getpass keeps the token out of the cell output
from getpass import getpass

HF_TOKEN = getpass("Enter your Hugging Face token: ").strip()
if not HF_TOKEN:
    raise ValueError("Please provide a valid Hugging Face token.")

In [3]:
#  Load LLaMA 3.1 8B Instruct model
model_name = "meta-llama/Meta-Llama-3.1-8B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(model_name, token=HF_TOKEN)
llm = AutoModelForCausalLM.from_pretrained(model_name, token=HF_TOKEN, torch_dtype=torch.float16, device_map="auto")


## Loading the Data

- Loads the heterogeneous graph (HeteroData) from disk.

- Loads the node maps (used for converting node names ↔ indices).

- Constructs a relation-to-ID mapping for all edge types.

- Concatenates node features from all node types into a unified feature matrix x.

- Flattens all edge indices and types into edge_index and edge_type, so the model can treat it as a single homogeneous graph with edge type annotations.



In [4]:
import tempfile
import base64
import gradio as gr

# Load trained model components
hetero_dict = torch.load("/content/drive/MyDrive/hetero_data_dict.pt", map_location="cpu")
hetero_data = HeteroData.from_dict(hetero_dict)
with open("/content/drive/MyDrive/node_maps.pkl", "rb") as f:
    node_maps = pickle.load(f)

relation_to_id = {rel: i for i, rel in enumerate(hetero_data.edge_types)}
x = torch.cat([hetero_data[n].x for n in hetero_data.node_types], dim=0)
edge_index_all, edge_type_all = [], []
for etype, eidx in hetero_data.edge_index_dict.items():
    rel_id = relation_to_id[etype]
    edge_index_all.append(eidx)
    edge_type_all.append(torch.full((eidx.size(1),), rel_id, dtype=torch.long))
edge_index = torch.cat(edge_index_all, dim=1)
edge_type = torch.cat(edge_type_all)
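
One point worth checking in the flattening step above: in a `HeteroData` object, edge indices are normally local to their source and destination node types, while `x` is stacked across all types. If that is true of this file (an assumption; the saved graph may already use global indices), each edge type needs its endpoints shifted by the offset of its node type before concatenation. The sketch below shows the offset arithmetic in plain Python on a made-up toy graph; all names and counts are hypothetical. PyG's `HeteroData.to_homogeneous()` performs a comparable conversion.

```python
# Hypothetical toy graph: node counts per type, in the order they are concatenated.
toy_num_nodes = {"disease": 3, "phenotype": 2, "drug": 4}

def type_offsets(counts):
    """Map each node type to the row where its block starts in the stacked feature matrix."""
    offsets, start = {}, 0
    for node_type, n in counts.items():
        offsets[node_type] = start
        start += n
    return offsets

def to_global(edge_type, local_edges, offsets):
    """Shift (src, dst) pairs from per-type indices to global indices."""
    src_type, _, dst_type = edge_type
    return [(s + offsets[src_type], d + offsets[dst_type]) for s, d in local_edges]

toy_offsets = type_offsets(toy_num_nodes)

# Local indices: disease 0 -> phenotype 1, disease 2 -> phenotype 0
toy_local_edges = [(0, 1), (2, 0)]
toy_global_edges = to_global(
    ("disease", "disease_phenotype_positive", "phenotype"),
    toy_local_edges,
    toy_offsets,
)

print(toy_offsets)       # {'disease': 0, 'phenotype': 3, 'drug': 5}
print(toy_global_edges)  # [(0, 4), (2, 3)]
```

Without the shift, disease 0 and phenotype 0 would both point at row 0 of the stacked matrix.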





## Model Loading: Multi-Relation R-GCN with DistMult Decoder

This section loads the trained Relational Graph Convolutional Network (R-GCN) model, equipped with a DistMult decoder, for link prediction.

- Encoder: A two-layer R-GCN that captures relational dependencies between heterogeneous biomedical entities (e.g., drug, disease, phenotype).

- Decoder: A DistMult scoring function using a learned relation embedding per edge type to compute edge plausibility.

- Checkpoint: Loading pre-trained model weights saved after training on multi-relational edge types like drug ➝ drug_effect ➝ phenotype and disease ➝ disease_phenotype_positive ➝ phenotype.
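
To make the decoder concrete, here is a minimal sketch of the DistMult score for one (source, relation, destination) triple, written in plain Python with made-up two-dimensional embeddings. It mirrors the `(src * rel * dst).sum(dim=-1)` expression used by the predictor in the next cell; the vectors and variable names are illustrative only.

```python
def distmult_score(src, rel, dst):
    """DistMult score: sum over dimensions of src[i] * rel[i] * dst[i]."""
    return sum(s * r * d for s, r, d in zip(src, rel, dst))

# Hypothetical embeddings for one disease, one phenotype, and one relation
z_disease = [1.0, 2.0]
z_phenotype = [3.0, 4.0]
rel_vec = [0.5, -1.0]

forward_score = distmult_score(z_disease, rel_vec, z_phenotype)
reverse_score = distmult_score(z_phenotype, rel_vec, z_disease)

print(forward_score)  # -6.5
print(reverse_score)  # -6.5
```

Because the product is element-wise, swapping source and destination gives the same score, so DistMult on its own cannot tell the direction of a relation.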

This model will now be used for:

- Making link predictions between biomedical entities,

- Explaining them using GNNExplainer, and

- Describing the predictions in natural language using an LLM (e.g., LLaMA 3.1 8B Instruct).

In [5]:
# Model definitions
class RGCNEncoder(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels, num_relations):
        super().__init__()
        self.conv1 = RGCNConv(in_channels, hidden_channels, num_relations)
        self.conv2 = RGCNConv(hidden_channels, out_channels, num_relations)

    def forward(self, x, edge_index, edge_type):
        x = self.conv1(x, edge_index, edge_type)
        x = F.relu(x)
        x = self.conv2(x, edge_index, edge_type)
        return x

class DistMultPredictor(torch.nn.Module):
    def __init__(self, encoder, embedding_dim, num_relations):
        super().__init__()
        self.encoder = encoder
        self.rel_embeddings = torch.nn.Embedding(num_relations, embedding_dim)

    def forward(self, x, edge_index, edge_type, edge_label_index, edge_type_ids):
        z = self.encoder(x, edge_index, edge_type)
        src = z[edge_label_index[0]]
        dst = z[edge_label_index[1]]
        rel = self.rel_embeddings(edge_type_ids)
        return (src * rel * dst).sum(dim=-1)

# Load trained model state
embedding_dim = 128
encoder = RGCNEncoder(128, 256, embedding_dim, len(relation_to_id))
ckpt = torch.load("/content/drive/MyDrive/rgcn_distmult_multirel_metrics.pt", map_location="cpu")
encoder.load_state_dict(ckpt['encoder_state_dict'])
predictor = DistMultPredictor(encoder, embedding_dim, len(relation_to_id))
predictor.rel_embeddings.load_state_dict({'weight': ckpt['decoder_state_dict']['relation_embeddings.weight']})
predictor.eval()




DistMultPredictor(
  (encoder): RGCNEncoder(
    (conv1): RGCNConv(128, 256, num_relations=10)
    (conv2): RGCNConv(256, 128, num_relations=10)
  )
  (rel_embeddings): Embedding(10, 128)
)

## LLM-Enhanced Biomedical Link Prediction & Explanation
We designed an interactive Gradio app in which the user enters a disease and a phenotype, and the system then performs the following steps:

### Step 1: Predict the Relationship
- Uses a trained R-GCN encoder and a DistMult decoder to predict the likelihood
  of a link between the disease and phenotype.

- Outputs a normalized confidence score.
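
The normalization in this step is a logistic sigmoid applied to the raw DistMult score. Below is a small, numerically stable sketch in plain Python. Whether the result can be read as a calibrated probability depends on how the model was trained, which this notebook does not show, so treat it as a score squashed into (0, 1).

```python
import math

def sigmoid(score):
    """Logistic function, written to avoid overflow for large negative scores."""
    if score >= 0:
        return 1.0 / (1.0 + math.exp(-score))
    e = math.exp(score)
    return e / (1.0 + e)

print(sigmoid(0.0))            # 0.5
print(round(sigmoid(2.0), 4))  # 0.8808
```

A raw score of 0 maps to 0.5; positive scores map above it and negative scores below it.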

### Step 2: Generate an Explanation with GNNExplainer
- Extracts the top-k most influential edges using GNNExplainer.

- Computes a separate confidence score using only the explanatory subgraph to
  show explainability fidelity.
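
GNNExplainer produces a soft (real-valued) importance weight per edge, so some rule is needed to turn it into a subgraph. The sketch below, in plain Python with a made-up mask, picks the k highest-weighted edges and contrasts that with simply testing the mask for nonzero values, which keeps far more edges. The mask values and the helper name are hypothetical.

```python
def top_k_edges(edge_mask, k):
    """Return the indices of the k largest mask values, highest first."""
    k = min(k, len(edge_mask))  # guard against graphs with fewer than k edges
    order = sorted(range(len(edge_mask)), key=lambda i: edge_mask[i], reverse=True)
    return order[:k]

# Hypothetical soft mask over six edges
toy_edge_mask = [0.05, 0.90, 0.00, 0.40, 0.75, 0.10]

top_edges = top_k_edges(toy_edge_mask, 3)
nonzero_edges = [i for i, w in enumerate(toy_edge_mask) if w != 0]

print(top_edges)      # [1, 4, 3]
print(nonzero_edges)  # [0, 1, 3, 4, 5]
```

Re-scoring the disease–phenotype pair on only the top-k edges and comparing with the full-graph score gives a rough fidelity check: the closer the two scores, the more of the prediction the selected edges account for.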

### Step 3: Visualize the Subgraph
- Displays an interactive PyVis network of important nodes and edges that
  contributed to the prediction.


### Step 4: Summarize Using a Language Model
- Feeds the edge explanation and scores into LLaMA 3.1 8B Instruct to generate
  a human-readable biological rationale.

The generated explanation helps users understand why the model predicted the link.
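
One detail matters when reading the LLM output: for decoder-only models, Hugging Face's `generate` returns the prompt tokens followed by the newly generated tokens, so decoding the whole sequence repeats the prompt in the answer. The sketch below shows the slicing idea on plain Python lists with made-up token ids; the helper name is hypothetical.

```python
def strip_prompt(prompt_ids, output_ids):
    """Return only the tokens that were generated after the prompt."""
    if output_ids[:len(prompt_ids)] != prompt_ids:
        raise ValueError("output does not start with the prompt")
    return output_ids[len(prompt_ids):]

toy_prompt_ids = [101, 7, 8, 9]          # hypothetical ids for the prompt
toy_output_ids = [101, 7, 8, 9, 42, 43]  # prompt followed by two generated tokens

print(strip_prompt(toy_prompt_ids, toy_output_ids))  # [42, 43]
```

With tensors, the equivalent is to slice the output at the prompt length, for example `outputs[0][inputs["input_ids"].shape[1]:]`, before calling `tokenizer.decode`.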



In [6]:
import gradio as gr
import tempfile
import base64
import torch
from torch_geometric.explain import Explainer, GNNExplainer
import networkx as nx
from pyvis.network import Network
from transformers import AutoTokenizer, AutoModelForCausalLM

# Assuming these are already loaded:
# predictor, x, edge_index, edge_type, node_maps, relation_to_id, tokenizer, llm

def explain_query(disease, phenotype):
    try:
        src_idx = node_maps["disease"].get(disease)
        dst_idx = node_maps["phenotype"].get(phenotype)
        rel_tuple = ("disease", "disease_phenotype_positive", "phenotype")
        rel_id = torch.tensor([relation_to_id[rel_tuple]])

        if src_idx is None or dst_idx is None:
            return "Invalid disease or phenotype name.", "<p style='color:red;'> Disease or phenotype not found.</p>"

        edge_label_index = torch.tensor([[src_idx], [dst_idx]])

        # Prediction score from full graph
        with torch.no_grad():
            full_score = predictor(x, edge_index, edge_type, edge_label_index, rel_id).item()
            normalized_full_score = torch.sigmoid(torch.tensor(full_score)).item()

        # GNNExplainer setup
        explainer = Explainer(
            model=predictor,
            algorithm=GNNExplainer(epochs=75),
            explanation_type="model",
            edge_mask_type="object",
            model_config=dict(
                mode="binary_classification",
                task_level="edge",
                return_type="raw",
            ),
        )

        explanation = explainer(
            x=x,
            edge_index=edge_index,
            edge_type=edge_type,
            edge_label_index=edge_label_index,
            edge_type_ids=rel_id
        )

        # Keep the k edges with the largest explainer mask values
        edge_mask = explanation.edge_mask
        k = min(15, edge_mask.numel())  # topk raises if k exceeds the number of edges
        top_edges = edge_mask.topk(k).indices
        important_edges = edge_index[:, top_edges]

        if important_edges.size(1) == 0:
            return "No important edges found for this prediction.", "<p>No influential subgraph detected.</p>"

        # Subgraph confidence: re-score the pair using only the top-k edges.
        # The mask is a soft (float) mask, so .bool() would keep every edge with a
        # nonzero weight rather than the explanatory subgraph.
        masked_edge_index = important_edges
        masked_edge_type = edge_type[top_edges]
        with torch.no_grad():
            subgraph_score = predictor(x, masked_edge_index, masked_edge_type, edge_label_index, rel_id).item()
            normalized_subgraph_score = torch.sigmoid(torch.tensor(subgraph_score)).item()

        # Create graph
        index_to_name = {v: k for t in node_maps for k, v in node_maps[t].items()}
        G = nx.DiGraph()
        for src, dst in important_edges.t().tolist():
            G.add_edge(index_to_name.get(src, str(src)), index_to_name.get(dst, str(dst)))

        net = Network(height="700px", width="100%", notebook=False, cdn_resources="in_line", directed=True)
        for node in G.nodes():
            color = (
                "#A3C4F3" if node == disease else
                "#FFB3C6" if node == phenotype else
                "#D3F8E2"
            )
            net.add_node(node, label=node, color=color, font={'size': 28, 'color': '#eeeeee'})

        for src, dst in G.edges():
            net.add_edge(src, dst)

        net.set_options('''
        {
          "nodes": {
            "shape": "dot",
            "size": 25,
            "font": { "size": 28, "face": "arial", "color": "#eeeeee" }
          },
          "edges": {
            "width": 1.5,
            "color": { "color": "#cccccc" },
            "smooth": false
          },
          "physics": {
            "enabled": true,
            "barnesHut": {
              "gravitationalConstant": -25000,
              "centralGravity": 0.1,
              "springLength": 300,
              "springConstant": 0.04,
              "damping": 0.15,
              "avoidOverlap": 1
            }
          },
          "layout": {
            "improvedLayout": true
          }
        }
        ''')

        # Save and encode graph
        with tempfile.NamedTemporaryFile("w+", suffix=".html", delete=False) as tmp_file:
            net.save_graph(tmp_file.name)
            tmp_file.seek(0)
            html_content = tmp_file.read()

        dark_css = """
        <style>
          body { background-color: #111111; margin: 0; padding: 0; }
          #mynetwork { background-color: #111111 !important; }
        </style>
        """
        html_content = html_content.replace("</head>", f"{dark_css}</head>")
        encoded_html = base64.b64encode(html_content.encode("utf-8")).decode("utf-8")
        iframe_html = f'''
        <iframe src="data:text/html;base64,{encoded_html}"
                width="100%" height="550px" frameborder="0"
                sandbox="allow-scripts allow-same-origin"></iframe>
        '''

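        # LLM explanation (fall back to the raw index when a name is missing, as in the graph above)
        summary = "\n".join([
            f"- {index_to_name.get(src, str(src))} ⟶ {index_to_name.get(dst, str(dst))}"
            for src, dst in important_edges.t().tolist()
        ])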
        prompt = f"""
A biomedical GNN model predicted a relationship:
- Disease: {disease}
- Phenotype: {phenotype}

Confidence Scores:
- Full graph prediction (normalized): {normalized_full_score:.4f}
- GNNExplainer subgraph prediction (normalized): {normalized_subgraph_score:.4f}

Important edges influencing this prediction:
{summary}

Explain why this might make sense biologically, in a small paragraph. And don't write anything else.
"""

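        inputs = tokenizer(prompt, return_tensors="pt").to(llm.device)
        outputs = llm.generate(**inputs, max_new_tokens=250)
        # generate() returns the prompt followed by the continuation; decode only the new tokens
        prompt_len = inputs["input_ids"].shape[1]
        llm_text = tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True).strip()

        # Show both confidence scores alongside the LLM rationale
        report = (
            f"Full-graph confidence: {normalized_full_score:.4f}\n"
            f"Explanation-subgraph confidence: {normalized_subgraph_score:.4f}\n\n"
            f"{llm_text}"
        )
        return report, iframe_html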

    except Exception as e:
        import traceback
        print("Exception:", traceback.format_exc())
        return f"Error: {str(e)}", f"<p style='color:red;'> {str(e)}</p>"
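
`explain_query` inverts `node_maps` into a single `index_to_name` dictionary. If the indices in `node_maps` are local to each node type (an assumption; the contents of the pickle are not shown here), different node types share the same integers and a flat inversion silently drops names. The sketch below demonstrates the collision on a made-up map and one way around it, keying the lookup by node type and index. All names are hypothetical.

```python
# Hypothetical node maps with per-type (local) indices
toy_node_maps = {
    "disease": {"diabetes": 0, "asthma": 1},
    "phenotype": {"retinopathy": 0, "wheezing": 1},
}

# Flat inversion into {index: name}: later node types overwrite earlier ones
flat_lookup = {
    idx: name
    for node_type in toy_node_maps
    for name, idx in toy_node_maps[node_type].items()
}

# Inversion keyed by (node type, index) keeps every node
typed_lookup = {
    (node_type, idx): name
    for node_type, names in toy_node_maps.items()
    for name, idx in names.items()
}

print(flat_lookup)                   # {0: 'retinopathy', 1: 'wheezing'}
print(typed_lookup[("disease", 0)])  # diabetes
```

If the maps already hold global indices into the stacked feature matrix, the flat inversion is safe and no change is needed.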



## XplainMD: A Multi-Relation Biomedical Link Prediction & Explanation Tool
We built an interactive Gradio app that lets users enter any disease–phenotype pair and receive an on-demand explanation of the predicted link from our trained multi-relation R-GCN model with a DistMult decoder.

- User Input: Disease and phenotype names (e.g., diabetes, retinopathy)

- Prediction: The model computes the confidence score of the link

- Explainability: GNNExplainer identifies key edges influencing the prediction

- LLM Reasoning: LLaMA 3.1 generates a human-readable biological explanation

- Subgraph Visualization: Interactive PyVis graph highlights important nodes and relationships in the knowledge graph

This interface bridges the gap between complex model internals and domain interpretability, enabling both researchers and clinicians to interactively explore and validate AI-driven biomedical hypotheses.

In [None]:
# Gradio app
with gr.Blocks(title="XplainMD: Multi-Relation DistMult") as demo:
    gr.Markdown("## XplainMD: A Multi-Relation Biomedical Link Prediction & Explanation Tool")
    gr.Markdown("This tool predicts and explains disease-phenotype links using a multi-relation R-GCN with a DistMult decoder.")
    with gr.Row():
        with gr.Column(scale=1):
            disease_input = gr.Textbox(label="Enter Disease")
            phenotype_input = gr.Textbox(label="Enter Phenotype")
            run_button = gr.Button("Run Explanation")
        with gr.Column(scale=2):
            explanation_output = gr.Textbox(label="Prediction & Confidence", lines=8, interactive=False)
            graph_output = gr.HTML(label="Subgraph Visualization")

    run_button.click(fn=explain_query, inputs=[disease_input, phenotype_input], outputs=[explanation_output, graph_output])

demo.launch(share=True)