In [2]:
import os
import torch
import torch.nn.functional as F
import pickle
import gradio as gr
import pandas as pd
import networkx as nx
from pyvis.network import Network
from torch_geometric.data import HeteroData
from torch_geometric.nn import RGCNConv
from torch_geometric.explain import Explainer, GNNExplainer
from transformers import AutoTokenizer, AutoModelForCausalLM
import tempfile
import base64
from IPython.display import IFrame
from huggingface_hub import snapshot_download

data_fp = '../data/PROCESSED/'
model_fp = '../models'

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
HF_TOKEN = input("Enter Hugging Face token: ").strip() # Enter hugging face token
assert HF_TOKEN, " Please provide valid Hugging Face token."

In [4]:
snapshot_download(repo_id="meta-llama/Meta-Llama-3.1-8B-Instruct", repo_type="model", token=HF_TOKEN)

Fetching 17 files: 100%|██████████| 17/17 [00:00<00:00, 352985.98it/s]


'/home/agrima/.cache/huggingface/hub/models--meta-llama--Meta-Llama-3.1-8B-Instruct/snapshots/0e9e39f249a16976918f6564b8830bc894c89659'

In [5]:
#  Load LLaMA 3.1 8B Instruct model
model_name = "meta-llama/Meta-Llama-3.1-8B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(model_name, token=HF_TOKEN)
llm = AutoModelForCausalLM.from_pretrained(model_name, token=HF_TOKEN, torch_dtype=torch.float16, device_map="auto")

Loading checkpoint shards: 100%|██████████| 4/4 [00:00<00:00,  4.21it/s]


In [6]:
# Load trained model components
hetero_dict = torch.load(f"{model_fp}/hetero_data_dict_version_final.pt", map_location="cpu")
hetero_data = HeteroData.from_dict(hetero_dict)
with open(f"{data_fp}/node_maps_version_final.pkl", "rb") as f:
    node_maps = pickle.load(f)

relation_to_id = {rel: i for i, rel in enumerate(hetero_data.edge_types)}
x = torch.cat([hetero_data[n].x for n in hetero_data.node_types], dim=0)
edge_index_all, edge_type_all = [], []
for etype, eidx in hetero_data.edge_index_dict.items():
    rel_id = relation_to_id[etype]
    edge_index_all.append(eidx)
    edge_type_all.append(torch.full((eidx.size(1),), rel_id, dtype=torch.long))
edge_index = torch.cat(edge_index_all, dim=1)
edge_type = torch.cat(edge_type_all)



  hetero_dict = torch.load(f"{model_fp}/hetero_data_dict_version_final.pt", map_location="cpu")


 ## LOADING MODEL AND PERFORMING EXPLAINABILITY


In [7]:
# Model definitions
class RGCNEncoder(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels, num_relations):
        super().__init__()
        self.conv1 = RGCNConv(in_channels, hidden_channels, num_relations)
        self.conv2 = RGCNConv(hidden_channels, out_channels, num_relations)

    def forward(self, x, edge_index, edge_type):
        x = self.conv1(x, edge_index, edge_type)
        x = F.relu(x)
        x = self.conv2(x, edge_index, edge_type)
        return x

class DistMultPredictor(torch.nn.Module):
    def __init__(self, encoder, embedding_dim, num_relations):
        super().__init__()
        self.encoder = encoder
        self.rel_embeddings = torch.nn.Embedding(num_relations, embedding_dim)

    def forward(self, x, edge_index, edge_type, edge_label_index, edge_type_ids):
        z = self.encoder(x, edge_index, edge_type)
        src = z[edge_label_index[0]]
        dst = z[edge_label_index[1]]
        rel = self.rel_embeddings(edge_type_ids)
        return (src * rel * dst).sum(dim=-1)

# Load trained model state
embedding_dim = 128
encoder = RGCNEncoder(128, 256, embedding_dim, len(relation_to_id))
ckpt = torch.load(f"{model_fp}/rgcn_distmult_multirel_metrics.pt", map_location="cpu")
encoder.load_state_dict(ckpt['encoder_state_dict'])
predictor = DistMultPredictor(encoder, embedding_dim, len(relation_to_id))
predictor.rel_embeddings.load_state_dict({'weight': ckpt['decoder_state_dict']['relation_embeddings.weight']})
predictor.eval()


  ckpt = torch.load(f"{model_fp}/rgcn_distmult_multirel_metrics.pt", map_location="cpu")


DistMultPredictor(
  (encoder): RGCNEncoder(
    (conv1): RGCNConv(128, 256, num_relations=11)
    (conv2): RGCNConv(256, 128, num_relations=11)
  )
  (rel_embeddings): Embedding(11, 128)
)

In [None]:
def explain_query(disease, phenotype):
    try:
        src_idx = node_maps["disease"].get(disease)
        dst_idx = node_maps["phenotype"].get(phenotype)
        rel_tuple = ("disease", "disease_phenotype_positive", "phenotype")
        rel_id = torch.tensor([relation_to_id[rel_tuple]])

        print("FOUND DISEASE:", disease, "PHENOTYPE:", phenotype)


        if src_idx is None or dst_idx is None:
            return "Invalid disease or phenotype name.", "<p style='color:red;'> Disease or phenotype not found.</p>"

        edge_label_index = torch.tensor([[src_idx], [dst_idx]])

        # Prediction score from full graph
        with torch.no_grad():
            full_score = predictor(x, edge_index, edge_type, edge_label_index, rel_id).item()
            normalized_full_score = torch.sigmoid(torch.tensor(full_score)).item()

        # GNNExplainer setup
        explainer = Explainer(
            model=predictor,
            algorithm=GNNExplainer(epochs=75),
            explanation_type="model",
            edge_mask_type="object",
            model_config=dict(
                mode="binary_classification",
                task_level="edge",
                return_type="raw",
            ),
        )

        print("Completed explainer setup.")

        explanation = explainer(
            x=x,
            edge_index=edge_index,
            edge_type=edge_type,
            edge_label_index=edge_label_index,
            edge_type_ids=rel_id
        )

        edge_mask = explanation.edge_mask
        top_edges = edge_mask.topk(15).indices
        important_edges = edge_index[:, top_edges]

        if important_edges.size(1) == 0:
            return "No important edges found for this prediction.", "<p>No influential subgraph detected.</p>"

        # Subgraph confidence
        masked_edge_index = explanation.edge_index[:, explanation.edge_mask.bool()]
        masked_edge_type = edge_type[explanation.edge_mask.bool()]
        with torch.no_grad():
            subgraph_score = predictor(x, masked_edge_index, masked_edge_type, edge_label_index, rel_id).item()
            normalized_subgraph_score = torch.sigmoid(torch.tensor(subgraph_score)).item()

        # Create graph
        index_to_name = {v: k for t in node_maps for k, v in node_maps[t].items()}
        G = nx.DiGraph()
        for src, dst in important_edges.t().tolist():
            G.add_edge(index_to_name.get(src, str(src)), index_to_name.get(dst, str(dst)))

        print("Completed explainer setup.")

        net = Network(height="700px", width="100%", notebook=False, cdn_resources="in_line", directed=True)
        for node in G.nodes():
            color = (
                "#A3C4F3" if node == disease else
                "#FFB3C6" if node == phenotype else
                "#D3F8E2"
            )
            net.add_node(node, label=node, color=color, font={'size': 28, 'color': '#eeeeee'})

        for src, dst in G.edges():
            net.add_edge(src, dst)

        net.set_options('''
        {
          "nodes": {
            "shape": "dot",
            "size": 25,
            "font": { "size": 28, "face": "arial", "color": "#eeeeee" }
          },
          "edges": {
            "width": 1.5,
            "color": { "color": "#cccccc" },
            "smooth": false
          },
          "physics": {
            "enabled": true,
            "barnesHut": {
              "gravitationalConstant": -25000,
              "centralGravity": 0.1,
              "springLength": 300,
              "springConstant": 0.04,
              "damping": 0.15,
              "avoidOverlap": 1
            }
          },
          "layout": {
            "improvedLayout": true
          }
        }
        ''')

        # Save and encode graph
        with tempfile.NamedTemporaryFile("w+", suffix=".html", delete=False) as tmp_file:
            net.save_graph(tmp_file.name)
            tmp_file.seek(0)
            html_content = tmp_file.read()

        dark_css = """
        <style>
          body { background-color: #111111; margin: 0; padding: 0; }
          #mynetwork { background-color: #111111 !important; }
        </style>
        """
        html_content = html_content.replace("</head>", f"{dark_css}</head>")
        encoded_html = base64.b64encode(html_content.encode("utf-8")).decode("utf-8")
        iframe_html = f'''
        <iframe src="data:text/html;base64,{encoded_html}"
                width="100%" height="550px" frameborder="0"
                sandbox="allow-scripts allow-same-origin"></iframe>
        '''

        # LLM explanation
        summary = "\n".join([f"- {index_to_name[src]} ⟶ {index_to_name[dst]}" for src, dst in important_edges.t().tolist()])
        prompt = f"""
A biomedical GNN model predicted a relationship:
- Disease: {disease}
- Phenotype: {phenotype}

Confidence Scores:
- Full graph prediction (normalized): {normalized_full_score:.4f}
- GNNExplainer subgraph prediction (normalized): {normalized_subgraph_score:.4f}

Important edges influencing this prediction:
{summary}

Explain why this might make sense biologically, in a small paragraph. And don't write anything else.
"""

        inputs = tokenizer(prompt, return_tensors="pt").to(llm.device)
        outputs = llm.generate(**inputs, max_new_tokens=250)
        explanation = tokenizer.decode(outputs[0], skip_special_tokens=True)

        return explanation.strip(), iframe_html

    except Exception as e:
        import traceback
        print("Exception:", traceback.format_exc())
        return f"Error: {str(e)}", f"<p style='color:red;'> {str(e)}</p>" 



In [9]:
with gr.Blocks(title="PhenoMap") as demo:
    gr.Markdown("Enter disease and phenotype to get explanation of how they're linked.")
    with gr.Row():
        with gr.Column(scale=1):
            disease_input = gr.Textbox(label="Enter Disease")
            phenotype_input = gr.Textbox(label="Enter Phenotype")
            run_button = gr.Button("Run Explanation")
        with gr.Column(scale=2):
            explanation_output = gr.Textbox(label="Prediction & Confidence", lines=8, interactive=False)
            graph_output = gr.HTML(label="Subgraph Visualization")

    run_button.click(fn=explain_query, inputs=[disease_input, phenotype_input], outputs=[explanation_output, graph_output])

demo.launch(share=True)


* Running on local URL:  http://127.0.0.1:7860
* Running on public URL: https://35d40ad6319453daa4.gradio.live

This share link expires in 72 hours. For free permanent hosting and GPU upgrades, run `gradio deploy` from the terminal in the working directory to deploy to Hugging Face Spaces (https://huggingface.co/spaces)




FOUND DISEASE: psoriasis PHENOTYPE: abnormality of the skin
Completed explainer setup.


Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.


Completed explainer setup.
