In [None]:
!pip install torch
!pip install torch_geometric



In [None]:
import os
import pickle
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.data import HeteroData
from torch_geometric.nn import HGTConv
from torch.utils.data import DataLoader
import numpy as np

In [None]:
def is_colab():
    """Report whether the notebook is executing inside Google Colab.

    Returns:
        bool: True if the ``google.colab`` module can be imported, otherwise False.
    """
    in_colab = True
    try:
        import google.colab  # noqa: F401  (presence check only)
    except ImportError:
        in_colab = False
    return in_colab

IN_COLAB = is_colab()

if IN_COLAB:
    from google.colab import drive
    drive.mount('/content/drive/')
    %cd "/content/drive/My Drive/Colab Notebooks/CS224W/DRKG Project"

# Load the NetworkX graph
if os.path.exists("./ml_graph.pkl"):
    G = pickle.load(open("./ml_graph.pkl", "rb"))
else:
    print("Graph file not found!")
    # raise FileNotFoundError("ml_graph.pkl not found.")

# Spot-check a few edges: gene->gene edges in both directions, then some
# compound->disease edges (printed with their type prefixes stripped).
print("\nChecking example edge attributes:\n")
for example_src, example_dst in [
    ("Gene::6434", "Gene::27429"),
    ("Gene::27429", "Gene::6434"),
    ("Gene::26092", "Gene::5577"),
]:
    print(f"{example_src} -> {example_dst}: {G[example_src][example_dst]}")

print("Checking example edge attributes:\n")
compound_disease_examples = [
    ("Compound::DB00004", "Disease::MESH:C063419"),
    ("Compound::MESH:C004656", "Disease::MESH:C537014"),
    ("Compound::DB00997", "Disease::DOID:363"),
]
for position, (example_src, example_dst) in enumerate(compound_disease_examples):
    src_label = example_src.split("::", 1)[1]
    dst_label = example_dst.split("::", 1)[1]
    trailer = "\n" if position == len(compound_disease_examples) - 1 else ""
    print(
        f"Edge from {src_label} to {dst_label}: {G.get_edge_data(example_src, example_dst)}{trailer}"
    )

Drive already mounted at /content/drive/; to attempt to forcibly remount, call drive.mount("/content/drive/", force_remount=True).
/content/drive/My Drive/Colab Notebooks/CS224W/DRKG Project

Checking example edge attributes:

Gene::6434 -> Gene::27429: {'relation': 'STRING::OTHER::Gene:Gene', 'coexpression': 0.042, 'experimentally_determined_interaction': 0.047, 'automated_textmining': 0.692, 'combined_score': 0.694, 'rmsd_score': 'N/A', 'gene_gene_features': array([-0.47708675, -0.43628007,  1.306391  ,  0.5278838 ,  0.        ],
      dtype=float32), 'gene_gene_mask': array([0., 0., 0., 0., 1.], dtype=float32)}
Gene::27429 -> Gene::6434: {'relation': 'STRING::OTHER::Gene:Gene', 'coexpression': 0.042, 'experimentally_determined_interaction': 0.047, 'automated_textmining': 0.692, 'combined_score': 0.694, 'rmsd_score': 'N/A', 'gene_gene_features': array([-0.47708675, -0.43628007,  1.306391  ,  0.5278838 ,  0.        ],
      dtype=float32), 'gene_gene_mask': array([0., 0., 0., 0., 1.],

In [None]:
###################################################
# Convert the processed NetworkX graph to a PyG HeteroData
###################################################

def get_node_type(node):
    """Map a DRKG node id such as ``"Gene::1234"`` to its PyG node-type name.

    Only the ``Compound::``, ``Disease::`` and ``Gene::`` prefixes (case-sensitive)
    get their own type; every other node is lumped into ``"other"``.
    """
    prefix_to_type = (
        ("Compound::", "compound"),
        ("Disease::", "disease"),
        ("Gene::", "gene"),
    )
    return next(
        (ntype for prefix, ntype in prefix_to_type if node.startswith(prefix)),
        "other",
    )

# Give every node a (type, index-within-type) pair. Indices are contiguous per type,
# assigned in G.nodes() iteration order, so they can index per-type node tensors.
node_type_map = {}
node_types_count = {}
for node in G.nodes():
    ntype = get_node_type(node)
    local_idx = node_types_count.get(ntype, 0)
    node_type_map[node] = (ntype, local_idx)
    node_types_count[ntype] = local_idx + 1

# Infer edge-feature widths from the first edges that carry each feature
# (0 means no edge in the graph has that feature).
treat_type_dim = gene_gene_feat_dim = mask_dim = 0
for _src_node, _dst_node, edge_data in G.edges(data=True):
    if "treat_type_onehot" in edge_data:
        treat_type_dim = len(edge_data["treat_type_onehot"])
    if "gene_gene_features" in edge_data:
        gene_gene_feat_dim = len(edge_data["gene_gene_features"])
        mask_dim = len(edge_data["gene_gene_mask"])
    if treat_type_dim and gene_gene_feat_dim:
        break

hetero_path = "./ml_graph_heterodata.pkl"
if os.path.exists(hetero_path):
    # NOTE: this cache is NOT invalidated when ml_graph.pkl changes -- delete it to force
    # a rebuild. pickle.load can execute arbitrary code, so only load files you produced.
    print(f"Loading existing HeteroData from {hetero_path}...")
    with open(hetero_path, "rb") as f:
        data = pickle.load(f)
else:
    print("Building new HeteroData object...")
    data = HeteroData()

    # Every node type gets num_nodes and an empty [count, 0] x: the model learns its
    # own node embeddings, so there are no input node features.
    for ntype, count in node_types_count.items():
        data[ntype].x = torch.empty((count, 0))
        data[ntype].num_nodes = count

    def get_edge_type(u, v, edata):
        """Return ``(src_type, relation, dst_type)``, keeping the original DRKG relation name."""
        rel = edata.get("relation", None)
        if rel is None:
            rel = "unknown_relation"
        return (node_type_map[u][0], rel, node_type_map[v][0])

    # Edge endpoints and optional per-edge feature rows, all keyed by the FULL edge type
    # (not just the relation name) so feature row i always belongs to edge column i.
    edge_dict = {}
    edge_treat_type_features = {}
    edge_gene_gene_features = {}
    edge_gene_gene_mask = {}

    for u, v, d in G.edges(data=True):
        et = get_edge_type(u, v, d)
        srcs, dsts = edge_dict.setdefault(et, ([], []))
        srcs.append(node_type_map[u][1])
        dsts.append(node_type_map[v][1])

        if "treat_type_onehot" in d:
            edge_treat_type_features.setdefault(et, []).append(d["treat_type_onehot"])
        if "gene_gene_features" in d:
            edge_gene_gene_features.setdefault(et, []).append(d["gene_gene_features"])
            edge_gene_gene_mask.setdefault(et, []).append(d["gene_gene_mask"])

    def _stack_edge_rows(rows, et, feature_name):
        """Stack per-edge feature rows into a float32 [E, F] tensor, one row per edge of ``et``."""
        num_edges = len(edge_dict[et][0])
        if len(rows) != num_edges:
            # Partial coverage would silently misalign edge_attr rows with edge_index columns.
            raise ValueError(
                f"'{feature_name}' present on {len(rows)} of {num_edges} edges of type {et}"
            )
        # np.stack first: torch.tensor() on a list of numpy arrays is very slow.
        return torch.from_numpy(np.stack([np.asarray(r, dtype=np.float32) for r in rows]))

    for et, (srcs, dsts) in edge_dict.items():
        store = data[et]
        store.edge_index = torch.tensor([srcs, dsts], dtype=torch.long)

        has_treat = et in edge_treat_type_features
        has_gene = et in edge_gene_gene_features
        if has_treat and has_gene:
            # treat_type is expected on compound->disease edges and STRING features on
            # gene->gene edges; both on one edge type is unexpected input.
            raise ValueError(f"edge attributes already exist for this relation: {et}")
        if has_treat:
            store.edge_attr = _stack_edge_rows(
                edge_treat_type_features[et], et, "treat_type_onehot"
            )
        elif has_gene:
            fmat = _stack_edge_rows(edge_gene_gene_features[et], et, "gene_gene_features")
            mmask = _stack_edge_rows(edge_gene_gene_mask[et], et, "gene_gene_mask")
            # edge_attr = [features | mask]
            store.edge_attr = torch.cat([fmat, mmask], dim=1)

    print(f"Saving HeteroData to {hetero_path}...")
    with open(hetero_path, "wb") as f:
        pickle.dump(data, f)

Loading existing HeteroData from ./ml_graph_heterodata.pkl...


In [None]:
###################################################
# Define the HGT-based Model
###################################################
# - HGTConv is used unmodified for node-level message passing.
# - After the HGTConv layers, the node embeddings are used for link prediction; edge features are concatenated in just before the final scoring layer.

# Future work: subclass MessagePassing so the edge features are used during message passing itself.

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Move every tensor the model consumes (node x, edge_index, optional edge_attr) to `device`.
for node_type in data.node_types:
    node_store = data[node_type]
    if node_store.x is not None:
        node_store.x = node_store.x.to(device)

for edge_type in data.edge_types:
    edge_store = data[edge_type]
    edge_store.edge_index = edge_store.edge_index.to(device)
    edge_feats = getattr(edge_store, "edge_attr", None)
    if edge_feats is not None:
        edge_store.edge_attr = edge_feats.to(device)

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import HGTConv

class HGTModel(nn.Module):
    """HGT encoder over learned per-type node embeddings, plus a linear link scorer.

    Every node type gets a trainable ``nn.Embedding`` (the input ``x`` tensors are
    ignored; only the keys of ``x_dict`` are used). ``num_layers`` HGTConv layers with
    residual + LayerNorm + dropout are applied, and the per-layer outputs are combined
    JK-style (``'concat'``, ``'max'`` or ``'mean'``) followed by a linear layer.

    Args:
        metadata: ``(node_types, edge_types)`` tuple, e.g. ``HeteroData.metadata()``.
        hidden_channels: embedding / hidden width.
        out_channels: width of the score produced by ``predict_link``.
        num_layers: number of HGTConv layers.
        num_heads: attention heads per HGTConv layer.
        node_types_count: mapping node type -> number of nodes of that type.
        treat_type_dim: width of the treat-type edge feature (0 if unused).
        gene_edge_dim: width of the gene-gene feature + mask edge feature (0 if unused).
        dropout: dropout probability for embeddings, layer outputs and scorer input.
        jk_mode: how per-layer outputs are combined.

    Raises:
        ValueError: if ``jk_mode`` is not one of ``'concat'``, ``'max'``, ``'mean'``.
    """

    def __init__(self, metadata, hidden_channels, out_channels, num_layers, num_heads,
                 node_types_count, treat_type_dim, gene_edge_dim, dropout=0.3, jk_mode='concat'):
        super().__init__()
        if jk_mode not in ("concat", "max", "mean"):
            raise ValueError("jk_mode must be one of ['concat', 'max', 'mean']")
        self.node_types = list(node_types_count.keys())
        self.hidden_channels = hidden_channels
        self.dropout = dropout
        self.treat_type_dim = treat_type_dim
        self.gene_edge_dim = gene_edge_dim
        self.num_layers = num_layers
        self.jk_mode = jk_mode

        # learned input embeddings, one table per node type
        self.node_embs = nn.ModuleDict()
        for ntype, count in node_types_count.items():
            self.node_embs[ntype] = nn.Embedding(count, hidden_channels)

        self.convs = nn.ModuleList([
            HGTConv(in_channels=hidden_channels, out_channels=hidden_channels,
                    metadata=metadata, heads=num_heads) for _ in range(num_layers)
        ])
        # one LayerNorm per layer, shared across node types
        self.norms = nn.ModuleList([nn.LayerNorm(hidden_channels) for _ in range(num_layers)])

        # JK aggregator: 'concat' widens the input; 'max'/'mean' keep hidden_channels
        if jk_mode == 'concat':
            self.final_lin = nn.Linear(hidden_channels * num_layers, hidden_channels)
        else:
            self.final_lin = nn.Linear(hidden_channels, hidden_channels)

        self.lin_score = nn.Linear(hidden_channels * 2 + treat_type_dim + gene_edge_dim, out_channels)

    def forward(self, x_dict, edge_index_dict, edge_attr_dict):
        """Encode all nodes.

        Args:
            x_dict: dict keyed by node type; only the keys are used (each type's input
                is its learned embedding table). The caller's dict is not modified.
            edge_index_dict: dict keyed by edge type with ``[2, E]`` edge indices.
            edge_attr_dict: accepted for interface compatibility; not used here.

        Returns:
            dict mapping node type -> ``[num_nodes, hidden_channels]`` embeddings.
        """
        # Fresh dict: never overwrite the caller's x_dict in place.
        h_dict = {
            ntype: F.dropout(self.node_embs[ntype].weight, p=self.dropout, training=self.training)
            for ntype in x_dict
        }

        layer_outputs = []
        for conv, norm in zip(self.convs, self.norms):
            conv_out = conv(h_dict, edge_index_dict)
            new_h = {}
            for ntype, prev in h_dict.items():
                # A node type that receives no messages may be missing from (or None in)
                # the conv output; treat that as a zero update so the type is carried
                # forward instead of crashing the residual or vanishing from later layers.
                msg = conv_out.get(ntype)
                out = prev if msg is None else msg + prev
                out = norm(out)
                new_h[ntype] = F.dropout(out, p=self.dropout, training=self.training)
            h_dict = new_h
            # Each layer builds new tensors, so no clone() is needed to keep history.
            layer_outputs.append(h_dict)

        out_dict = {}
        for ntype in h_dict:
            per_layer = [layer[ntype] for layer in layer_outputs]
            if self.jk_mode == 'concat':
                combined = torch.cat(per_layer, dim=-1)  # [N, hidden * num_layers]
            elif self.jk_mode == 'max':
                combined = torch.stack(per_layer, dim=0).max(dim=0).values
            elif self.jk_mode == 'mean':
                combined = torch.stack(per_layer, dim=0).mean(dim=0)
            else:
                raise ValueError("jk_mode must be one of ['concat', 'max', 'mean']")
            out_dict[ntype] = self.final_lin(combined)
        return out_dict

    def predict_link(self, x_dict, src_nodes, dst_nodes, edge_attr=None):
        """Score candidate (src, dst) links.

        Args:
            x_dict: node embeddings as returned by ``forward``.
            src_nodes: ``(node_type, index_tensor)`` for the source endpoints.
            dst_nodes: ``(node_type, index_tensor)`` for the destination endpoints.
            edge_attr: optional ``[num_links, d]`` features with
                ``d <= treat_type_dim + gene_edge_dim``; missing trailing columns are
                zero-padded and ``None`` means all-zero features.

        Returns:
            ``[num_links, out_channels]`` raw scores (logits).

        Raises:
            ValueError: if ``edge_attr`` is wider than the scorer expects.
        """
        src_ntype, src_ids = src_nodes
        dst_ntype, dst_ids = dst_nodes
        x_src = x_dict[src_ntype][src_ids]
        x_dst = x_dict[dst_ntype][dst_ids]

        needed_dim = self.treat_type_dim + self.gene_edge_dim
        num_links = x_src.size(0)
        if edge_attr is None:
            edge_attr = x_src.new_zeros(num_links, needed_dim)
        else:
            current_dim = edge_attr.size(1)
            if current_dim > needed_dim:
                raise ValueError(
                    f"edge_attr has {current_dim} columns but the scorer expects at most {needed_dim}"
                )
            if current_dim < needed_dim:
                padding = torch.zeros(num_links, needed_dim - current_dim, device=x_src.device)
                edge_attr = torch.cat([edge_attr, padding], dim=1)

        combined = torch.cat([x_src, x_dst, edge_attr], dim=-1)
        combined = F.dropout(combined, p=self.dropout, training=self.training)
        return self.lin_score(combined)

# Train / val / test split (80/10/10) of the compound->disease "Treats" edges.

metadata = data.metadata()

cd_rel = ("compound", "DRKG::Treats::Compound:Disease", "disease")

cd_edge_index = data[cd_rel].edge_index
cd_edge_attr = data[cd_rel].edge_attr if hasattr(data[cd_rel], "edge_attr") else None

# Seed torch's RNGs (CPU and CUDA) so the split, the negatives drawn below and the
# model initialisation are reproducible across Restart & Run All.
SEED = 42
torch.manual_seed(SEED)

num_cd_edges = cd_edge_index.size(1)
perm = torch.randperm(num_cd_edges, device=device)

train_size = int(0.8 * num_cd_edges)
val_size = int(0.1 * num_cd_edges)
test_size = num_cd_edges - train_size - val_size

train_edges = cd_edge_index[:, perm[:train_size]]
val_edges = cd_edge_index[:, perm[train_size : train_size + val_size]]
test_edges = cd_edge_index[:, perm[train_size + val_size :]]


def slice_edge_attr(edge_attr_full, idxs):
    """Select rows ``idxs`` of an optional edge-attribute tensor.

    Returns None when ``edge_attr_full`` is None (the relation has no edge features).
    """
    if edge_attr_full is None:
        return None
    return edge_attr_full[idxs]


# Edge features for each split, aligned with the columns of the split edge indices.
train_edge_attr, val_edge_attr, test_edge_attr = (
    slice_edge_attr(cd_edge_attr, split_idx)
    for split_idx in (
        perm[:train_size],
        perm[train_size : train_size + val_size],
        perm[train_size + val_size :],
    )
)

compound_count, disease_count = data["compound"].num_nodes, data["disease"].num_nodes


def negative_sampling(num_samples):
    """Draw ``num_samples`` uniformly random (compound, disease) index pairs as negatives.

    Pairs are sampled independently with no filtering, so an occasional true
    positive may be drawn.

    Returns:
        (src, dst): two int64 index tensors of length ``num_samples`` on ``device``.
    """
    shape = (num_samples,)
    src = torch.randint(0, compound_count, shape, device=device)
    dst = torch.randint(0, disease_count, shape, device=device)
    return src, dst


# One fixed negative set per split, each as large as the split's positive set.
neg_train_src, neg_train_dst = negative_sampling(train_edges.size(1))
neg_val_src, neg_val_dst = negative_sampling(val_edges.size(1))
neg_test_src, neg_test_dst = negative_sampling(test_edges.size(1))

In [None]:
# Model hyper-parameters.
hidden_channels, out_channels = 128, 1
num_layers, num_heads = 2, 8

model = HGTModel(
    metadata,
    hidden_channels,
    out_channels,
    num_layers,
    num_heads,
    node_types_count=node_types_count,
    treat_type_dim=treat_type_dim,
    # gene-gene edge feature width = features + mask (0 when no such edges exist)
    gene_edge_dim=(gene_gene_feat_dim + mask_dim) if gene_gene_feat_dim > 0 else 0,
).to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=1e-4)

def compute_loss(pos_edges, neg_edges, pos_edge_attr, model, node_embs):
    """Binary cross-entropy link-prediction loss for compound->disease edges.

    Args:
        pos_edges: ``[2, P]`` tensor of positive (compound, disease) index pairs.
        neg_edges: ``(neg_src, neg_dst)`` index tensors for the negative pairs.
        pos_edge_attr: optional edge features for the positives (row i belongs to
            column i of ``pos_edges``), or None.
        model: object exposing ``predict_link`` (e.g. ``HGTModel``).
        node_embs: node embeddings produced by ``model(...)``.

    Returns:
        Scalar loss tensor (mean over all positive and negative scores).

    NOTE(review): positives are scored WITH their treat-type features while negatives
    (and the all-pairs inference later on) get all-zero features, so the scorer can
    likely separate the classes from the edge features alone. Consider passing None
    for the positives too if the goal is structure-based prediction.
    """
    pos_src, pos_dst = pos_edges
    neg_src, neg_dst = neg_edges
    # Same rows the former slice_edge_attr(..., arange(P)) produced, without relying on
    # module-level helpers or the global `device`.
    pos_attr = None if pos_edge_attr is None else pos_edge_attr[: pos_src.size(0)]

    pos_scores = model.predict_link(node_embs, ('compound', pos_src), ('disease', pos_dst), pos_attr)
    neg_scores = model.predict_link(node_embs, ('compound', neg_src), ('disease', neg_dst), None)

    scores = torch.cat([pos_scores, neg_scores], dim=0)
    # Labels live on the scores' own device/dtype.
    labels = torch.cat([torch.ones_like(pos_scores), torch.zeros_like(neg_scores)], dim=0)
    return F.binary_cross_entropy_with_logits(scores, labels)

# Encoder inputs, built once (they do not change between epochs).
x_dict = {ntype: data[ntype].x for ntype in data.node_types}
edge_index_dict = {rel: data[rel].edge_index for rel in data.edge_types}
edge_attr_dict = {
    rel: (data[rel].edge_attr if hasattr(data[rel], "edge_attr") else None)
    for rel in data.edge_types
}

# The split only partitions the supervision edges; data[cd_rel] still holds ALL Treats
# edges, so without this the encoder would pass messages along the very val/test links
# it is scored on. Keep only the training Treats edges for message passing
# (set to False to reproduce the original, leaky setup).
EXCLUDE_HELDOUT_TREATS_EDGES = True
if EXCLUDE_HELDOUT_TREATS_EDGES:
    edge_index_dict[cd_rel] = train_edges
    edge_attr_dict[cd_rel] = train_edge_attr

train_losses = []
val_losses = []

# Early stopping parameters
patience = 40  # number of epochs to wait for an improvement
min_delta = 1e-4  # minimum decrease in validation loss that counts as an improvement
best_val_loss = float("inf")
best_model_state = None
patience_counter = 0
best_epoch = 0

epochs = 4000
current_epoch = 0
for epoch in range(epochs):
    model.train()
    optimizer.zero_grad()
    # shallow copy so the model can never rewrite the shared input dict
    node_embs = model(dict(x_dict), edge_index_dict, edge_attr_dict)

    train_loss = compute_loss(
        train_edges, (neg_train_src, neg_train_dst), train_edge_attr, model, node_embs
    )
    train_loss.backward()
    optimizer.step()

    # validation loss
    model.eval()
    with torch.no_grad():
        node_embs = model(dict(x_dict), edge_index_dict, edge_attr_dict)
        val_loss = compute_loss(
            val_edges, (neg_val_src, neg_val_dst), val_edge_attr, model, node_embs
        )

    current_val_loss = val_loss.item()
    train_losses.append(train_loss.item())
    val_losses.append(current_val_loss)

    # early stopping check
    if current_val_loss < best_val_loss - min_delta:
        best_val_loss = current_val_loss
        # model.state_dict() holds references to the live parameter tensors, which
        # optimizer.step() keeps updating in place; snapshot copies, otherwise
        # "restoring the best model" would silently restore the latest weights.
        best_model_state = {k: v.detach().clone() for k, v in model.state_dict().items()}
        patience_counter = 0
        best_epoch = epoch
    else:
        patience_counter += 1

    print(
        f"Epoch {epoch+1}/{epochs} - Train Loss: {train_loss.item():.4f}, Val Loss: {current_val_loss:.4f}"
        + (" *" if patience_counter == 0 else "")
    )

    # early stopping
    if patience_counter >= patience:
        print(
            f"\nEarly stopping triggered! Best epoch was {best_epoch + 1} with validation loss: {best_val_loss:.4f}"
        )
        if best_model_state is not None:
            model.load_state_dict(best_model_state)
        break
    current_epoch += 1
else:
    # Ran every epoch without early stopping: still return to the best checkpoint.
    if best_model_state is not None and best_epoch != epochs - 1:
        print(
            f"\nTraining completed. Restoring best model from epoch {best_epoch + 1} with validation loss: {best_val_loss:.4f}"
        )
        model.load_state_dict(best_model_state)

Epoch 1/4000 - Train Loss: 0.7125, Val Loss: 0.6897 *
Epoch 2/4000 - Train Loss: 0.6995, Val Loss: 0.6821 *
Epoch 3/4000 - Train Loss: 0.6958, Val Loss: 0.6756 *
Epoch 4/4000 - Train Loss: 0.6849, Val Loss: 0.6700 *
Epoch 5/4000 - Train Loss: 0.6828, Val Loss: 0.6653 *
Epoch 6/4000 - Train Loss: 0.6806, Val Loss: 0.6611 *
Epoch 7/4000 - Train Loss: 0.6781, Val Loss: 0.6574 *
Epoch 8/4000 - Train Loss: 0.6743, Val Loss: 0.6542 *
Epoch 9/4000 - Train Loss: 0.6702, Val Loss: 0.6511 *
Epoch 10/4000 - Train Loss: 0.6685, Val Loss: 0.6482 *
Epoch 11/4000 - Train Loss: 0.6671, Val Loss: 0.6454 *
Epoch 12/4000 - Train Loss: 0.6634, Val Loss: 0.6427 *
Epoch 13/4000 - Train Loss: 0.6671, Val Loss: 0.6399 *
Epoch 14/4000 - Train Loss: 0.6580, Val Loss: 0.6369 *
Epoch 15/4000 - Train Loss: 0.6606, Val Loss: 0.6339 *
Epoch 16/4000 - Train Loss: 0.6528, Val Loss: 0.6306 *
Epoch 17/4000 - Train Loss: 0.6458, Val Loss: 0.6271 *
Epoch 18/4000 - Train Loss: 0.6461, Val Loss: 0.6232 *
Epoch 19/4000 - Tra

In [None]:
# write train and val loss to disk
with open("train_val_losses.csv", "w") as f:
    f.write("epoch,train_loss,val_loss\n")
    for i, (tr_l, vl_l) in enumerate(zip(train_losses, val_losses), start=1):
        f.write(f"{i},{tr_l},{vl_l}\n")

import matplotlib.pyplot as plt

# Derive the x-axis from the recorded history. `range(1, current_epoch + 2)` is one
# element too long when training runs all epochs without early stopping, which makes
# plt.plot raise a dimension-mismatch error.
epochs_axis = range(1, len(train_losses) + 1)
fig, ax = plt.subplots()
ax.plot(epochs_axis, train_losses, label="Train Loss")
ax.plot(epochs_axis, val_losses, label="Val Loss")
ax.set(
    xlabel="Epoch",
    ylabel="Loss",
    title=f"HGT {hidden_channels} dim, {num_layers} layers, {num_heads} heads",
)
ax.legend()
fig.savefig("loss_plot.png", dpi=300)
plt.close(fig)

# Held-out test loss (eval mode: dropout off, no autograd graph).
model.eval()
with torch.no_grad():
    node_embs = model(x_dict, edge_index_dict, edge_attr_dict)
    test_loss = compute_loss(
        pos_edges=test_edges,
        neg_edges=(neg_test_src, neg_test_dst),
        pos_edge_attr=test_edge_attr,
        model=model,
        node_embs=node_embs,
    )
print(f"Test Loss: {test_loss.item():.4f}")


Test Loss: 0.1335


In [None]:
# save the trained model
model_path = "trained_hgt_model.pt"
print(f"\nSaving trained model to {model_path}...")
torch.save(
    {
        "epoch": best_epoch,
        "model_state_dict": model.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
        "loss": train_loss,
        "hidden_channels": hidden_channels,
        "out_channels": out_channels,
        "num_layers": num_layers,
        "num_heads": num_heads,
        "node_types_count": node_types_count,
        "treat_type_dim": treat_type_dim,
        "gene_edge_dim": gene_gene_feat_dim + mask_dim if gene_gene_feat_dim > 0 else 0,
    },
    model_path,
)


def load_trained_model(model_path, device, metadata=None):
    """Rebuild an ``HGTModel`` and its Adam optimizer from a saved checkpoint.

    Args:
        model_path: path to the ``torch.save`` checkpoint written above.
        device: device to map tensors to and to move the model onto.
        metadata: graph ``(node_types, edge_types)``; defaults to the notebook's global
            ``data.metadata()``, which must match the graph used at training time.

    Returns:
        (model, optimizer, epoch) where ``epoch`` is the checkpoint's saved epoch.
    """
    checkpoint = torch.load(model_path, map_location=device)
    if metadata is None:
        metadata = data.metadata()

    model = HGTModel(
        metadata=metadata,
        hidden_channels=checkpoint["hidden_channels"],
        out_channels=checkpoint["out_channels"],
        num_layers=checkpoint["num_layers"],
        num_heads=checkpoint["num_heads"],
        node_types_count=checkpoint["node_types_count"],
        treat_type_dim=checkpoint["treat_type_dim"],
        gene_edge_dim=checkpoint["gene_edge_dim"],
        # older checkpoints did not store these; fall back to the constructor defaults
        dropout=checkpoint.get("dropout", 0.3),
        jk_mode=checkpoint.get("jk_mode", "concat"),
    )
    model.load_state_dict(checkpoint["model_state_dict"])
    model = model.to(device)

    # For resumed training; load_state_dict restores the saved group hyper-parameters
    # (lr, weight_decay, ...), so the lr passed here is only a placeholder.
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    optimizer.load_state_dict(checkpoint["optimizer_state_dict"])

    return model, optimizer, checkpoint["epoch"]


# Optionally replace the in-memory model with the saved checkpoint.
LOAD_SAVED_MODEL = False
if LOAD_SAVED_MODEL:
    if os.path.exists(model_path):
        model, optimizer, start_epoch = load_trained_model(model_path, device)
    else:
        print("No trained model found.")

In [None]:
###################################################
# score all compound->disease pairs for diseases in list and rank compounds by score.
###################################################

# SARS-CoV-2 protein "disease" nodes, followed by a set of MeSH disease ids
# (presumably coronavirus-related -- confirm against the DRKG documentation).
COV_disease_list = [
    f"Disease::SARS-CoV2 {protein}"
    for protein in (
        "E", "M", "N", "Spike",
        "nsp1", "nsp10", "nsp11", "nsp12", "nsp13", "nsp14", "nsp15",
        "nsp2", "nsp4", "nsp5", "nsp5_C145A", "nsp6", "nsp7", "nsp8", "nsp9",
        "orf10", "orf3a", "orf3b", "orf6", "orf7a", "orf8", "orf9b", "orf9c",
    )
] + [
    f"Disease::MESH:{mesh_id}"
    for mesh_id in (
        "D045169", "D045473", "D001351", "D065207", "D028941", "D058957", "D006517",
    )
]

# Control diseases ranked alongside the COVID list for comparison.
control_disease_list = [
    "Disease::MESH:D015451",
    "Disease::MESH:D053158",
    "Disease::MESH:D001474",
]

TOPK_CANDIDATES = 10

model.eval()
with torch.no_grad():
    # Inference uses the full graph (all Treats edges visible).
    x_dict = {ntype: data[ntype].x for ntype in data.node_types}
    edge_index_dict = {rel: data[rel].edge_index for rel in data.edge_types}
    edge_attr_dict = {
        rel: (data[rel].edge_attr if hasattr(data[rel], "edge_attr") else None)
        for rel in data.edge_types
    }
    node_embs = model(x_dict, edge_index_dict, edge_attr_dict)

    # Lookup tables, built once (the reverse compound map used to be rebuilt per disease).
    disease_name_to_id = {n: idx for n, (t, idx) in node_type_map.items() if t == "disease"}
    compound_id_to_name = {idx: n for n, (t, idx) in node_type_map.items() if t == "compound"}

    compound_ids = torch.arange(data["compound"].num_nodes, device=device)
    # torch.topk raises if k exceeds the number of candidates
    topk = min(TOPK_CANDIDATES, compound_ids.numel())

    for disease_name in COV_disease_list + control_disease_list:
        if disease_name not in disease_name_to_id:
            print(f"Disease {disease_name} not found in the graph.")
            continue

        disease_id = disease_name_to_id[disease_name]
        # No treat_type edge_attr exists for arbitrary pairs, so score with zero features.
        scores = model.predict_link(
            node_embs,
            ("compound", compound_ids),
            ("disease", torch.full_like(compound_ids, disease_id)),
            None,
        )
        # squeeze(-1), not squeeze(): stays 1-D even when there is a single compound
        probs = torch.sigmoid(scores).squeeze(-1)
        top_scores, top_indices = torch.topk(probs, topk)

        print(f"Top {topk} candidate compounds for {disease_name}:")
        for rank, (score_val, c_idx) in enumerate(
            zip(top_scores.tolist(), top_indices.tolist()), start=1
        ):
            print(f"{rank}: {compound_id_to_name[c_idx]} (score={score_val:.4f})")

###################################################
# The printed rankings list candidate compounds for each COVID-related disease node.
# After training, the model scores every compound against each disease in
# COV_disease_list (plus the control diseases) and prints the top-ranked compounds.
#
###################################################

Top 10 candidate compounds for Disease::SARS-CoV2 E:
1: Compound::MESH:D013256 (score=0.9601)
2: Compound::MESH:D000305 (score=0.9571)
3: Compound::MESH:D010406 (score=0.9529)
4: Compound::CHEBI:35341 (score=0.9445)
5: Compound::DB00207 (score=0.9310)
6: Compound::CHEBI:50858 (score=0.9206)
7: Compound::DB00563 (score=0.9204)
8: Compound::DB00104 (score=0.9181)
9: Compound::DB09140 (score=0.9086)
10: Compound::DB00741 (score=0.8956)
Top 10 candidate compounds for Disease::SARS-CoV2 M:
1: Compound::MESH:D013256 (score=0.9381)
2: Compound::MESH:D000305 (score=0.9336)
3: Compound::MESH:D010406 (score=0.9273)
4: Compound::CHEBI:35341 (score=0.9146)
5: Compound::DB00207 (score=0.8947)
6: Compound::CHEBI:50858 (score=0.8796)
7: Compound::DB00563 (score=0.8793)
8: Compound::DB00104 (score=0.8759)
9: Compound::DB09140 (score=0.8624)
10: Compound::DB00741 (score=0.8438)
Top 10 candidate compounds for Disease::SARS-CoV2 N:
1: Compound::MESH:D013256 (score=0.9209)
2: Compound::MESH:D000305 (score

In [None]:
import pandas as pd

# known treatments and phases

def parse_phase(phase_str):
    """Parse a clinical-trial phase value into a float.

    Accepts a real number (returned as float, including numpy scalars), or a string of
    comma-separated numbers (e.g. ``"2, 3"`` -> 2.5: the mean of the parseable parts;
    non-numeric parts are skipped).

    Returns:
        The phase as a float, or None for ``"Not Available"``, other non-string and
        non-number input, strings without a numeric part, and NaN/infinite values
        (so a missing phase cannot poison downstream weights).
    """
    import math
    import numbers

    if isinstance(phase_str, numbers.Real):
        value = float(phase_str)
        return value if math.isfinite(value) else None
    if not isinstance(phase_str, str) or phase_str == "Not Available":
        return None

    numeric_parts = []
    for part in phase_str.split(","):
        try:
            value = float(part.strip())
        except ValueError:
            continue
        if math.isfinite(value):
            numeric_parts.append(value)

    return sum(numeric_parts) / len(numeric_parts) if numeric_parts else None

# For every compound node, record its treatment metadata and, when a numeric
# clinical-trial phase is available, one row for trial_df.
known_treatments = {disease: {} for disease in COV_disease_list}
trial_data = []

for c_node, attrs in G.nodes(data=True):
    if not c_node.startswith("Compound::"):
        continue
    treatment_info = attrs.get("treatment_data", {})

    # NOTE(review): any compound with treatment data is recorded as a known treatment
    # for EVERY disease in COV_disease_list -- confirm this is intended.
    if treatment_info:
        for disease in COV_disease_list:
            known_treatments[disease][c_node] = treatment_info

    ect_data = treatment_info.get("External Clinical Trials", {})
    if not ect_data:
        continue
    phase = ect_data.get("Phase", "Not Available")
    parsed_phase = parse_phase(phase)
    if parsed_phase is None:
        continue
    trial_data.append(
        {
            "compound": c_node,
            "phase": parsed_phase,
            "phase_raw": phase,
            "is_experimental": "Experimental Unapproved" in treatment_info,
        }
    )

trial_df = pd.DataFrame(trial_data)
median_phase = 0.0 if trial_df.empty else trial_df["phase"].median()

###################################################
# evaluate model predictions
###################################################


def evaluate_predictions(
    COV_disease_list,
    model,
    data,
    node_type_map,
    known_treatments,
    trial_df,
    median_phase,
    device,
    topk=10,
):
    """Hit rate of the model's top-k compound predictions against known treatments.

    For each disease in ``COV_disease_list`` that is present in the graph, every
    compound is scored with ``model.predict_link`` (zero edge features). The ``topk``
    highest-scoring compounds are then checked against ``known_treatments[disease]``.

    Each prediction also gets a weight: its sigmoid score, multiplied by 1.5 if the
    compound is flagged experimental in ``trial_df``, and by ``1 + phase / 10`` if it
    has trial data. A missing or NaN phase falls back to ``median_phase``.

    Args:
        COV_disease_list: disease node names to evaluate.
        model: trained model exposing ``forward`` and ``predict_link``.
        data: graph providing ``node_types``, ``edge_types`` and per-type storages.
        node_type_map: node name -> (node_type, index within type).
        known_treatments: disease name -> {compound name: info}.
        trial_df: DataFrame with ``compound``, ``phase`` and ``is_experimental`` columns
            (may be empty).
        median_phase: fallback phase for compounds with a missing phase.
        device: device for the compound index tensor.
        topk: predictions per disease; clamped to the number of compounds.

    Returns:
        (accuracy, weighted_accuracy): the fraction of correct predictions, and the
        weight of correct predictions divided by the total weight (0.0 when undefined).
    """
    model.eval()
    with torch.no_grad():
        x_dict = {ntype: data[ntype].x for ntype in data.node_types}
        edge_index_dict = {rel: data[rel].edge_index for rel in data.edge_types}
        edge_attr_dict = {
            rel: (data[rel].edge_attr if hasattr(data[rel], "edge_attr") else None)
            for rel in data.edge_types
        }
        node_embs = model(x_dict, edge_index_dict, edge_attr_dict)

        disease_name_to_id = {
            n: idx for n, (t, idx) in node_type_map.items() if t == "disease"
        }
        compound_id_to_name = {
            idx: n for n, (t, idx) in node_type_map.items() if t == "compound"
        }

        compound_ids = torch.arange(data["compound"].num_nodes, device=device)
        # torch.topk raises if k exceeds the number of candidates.
        k = min(topk, compound_ids.numel())

        trial_lookup = {}
        if not trial_df.empty:
            for compound, phase, is_experimental in zip(
                trial_df["compound"], trial_df["phase"], trial_df["is_experimental"]
            ):
                trial_lookup[compound] = {"phase": phase, "is_experimental": is_experimental}

        total_correct = 0
        total_predictions = 0
        total_weighted_correct = 0.0
        total_weighted = 0.0

        for disease_name in COV_disease_list:
            disease_id = disease_name_to_id.get(disease_name)
            if disease_id is None:
                continue

            scores = model.predict_link(
                node_embs,
                ("compound", compound_ids),
                ("disease", torch.full_like(compound_ids, disease_id)),
                None,
            )
            # squeeze(-1) keeps a 1-D tensor even with a single compound
            probs = torch.sigmoid(scores).squeeze(-1)
            top_scores, top_indices = torch.topk(probs, k)

            known_for_disease = known_treatments.get(disease_name, {})

            for score_val, c_idx in zip(top_scores.cpu().tolist(), top_indices.cpu().tolist()):
                compound_name = compound_id_to_name[c_idx]
                correct = compound_name in known_for_disease

                weight = score_val
                trial_info = trial_lookup.get(compound_name)
                if trial_info:
                    if trial_info["is_experimental"]:
                        weight *= 1.5
                    phase = trial_info["phase"]
                    if pd.isna(phase):
                        # a NaN phase would turn every weighted total into NaN
                        phase = median_phase
                    weight *= 1 + phase / 10.0

                total_predictions += 1
                total_weighted += weight
                if correct:
                    total_correct += 1
                    total_weighted_correct += weight

        accuracy = total_correct / total_predictions if total_predictions > 0 else 0.0
        weighted_accuracy = (
            total_weighted_correct / total_weighted if total_weighted > 0 else 0.0
        )

        print(f"Accuracy: {accuracy:.4f}")
        print(f"Weighted Accuracy: {weighted_accuracy:.4f}")

        return accuracy, weighted_accuracy


###################################################
# Check the top-20 compounds per COVID disease against the known treatments.
accuracy, weighted_accuracy = evaluate_predictions(
    COV_disease_list=COV_disease_list,
    model=model,
    data=data,
    node_type_map=node_type_map,
    known_treatments=known_treatments,
    trial_df=trial_df,
    median_phase=median_phase,
    device=device,
    topk=20,
)
###################################################

Accuracy: 0.4500
Weighted Accuracy: 0.5277
