In [1]:
import os
import torch
print("Using torch", torch.__version__)

Using torch 2.1.0+cu118


In [2]:
!pip install torch-scatter torch-sparse torch-cluster torch-spline-conv torch-geometric -f https://data.pyg.org/whl/torch-2.1.0+cu118.html
!pip install pyg-library

Looking in links: https://data.pyg.org/whl/torch-2.1.0+cu118.html


In [3]:
!pip install ogb



In [4]:
from torch_geometric.data import Data
from ogb.linkproppred import PygLinkPropPredDataset
from torch_geometric import nn
import torch_geometric.transforms as T

In [5]:
dataset = PygLinkPropPredDataset(name="ogbl-biokg", root='dataset/')

split_edge = dataset.get_edge_split()
train_edge, valid_edge, test_edge = split_edge["train"], split_edge["valid"], split_edge["test"]
biokg_raw = dataset[0]

print(biokg_raw)

Data(
  num_nodes_dict={
    disease=10687,
    drug=10533,
    function=45085,
    protein=17499,
    sideeffect=9969,
  },
  edge_index_dict={
    (disease, disease-protein, protein)=[2, 73547],
    (drug, drug-disease, disease)=[2, 5147],
    (drug, drug-drug_acquired_metabolic_disease, drug)=[2, 63430],
    (drug, drug-drug_bacterial_infectious_disease, drug)=[2, 18554],
    (drug, drug-drug_benign_neoplasm, drug)=[2, 30348],
    (drug, drug-drug_cancer, drug)=[2, 48514],
    (drug, drug-drug_cardiovascular_system_disease, drug)=[2, 94842],
    (drug, drug-drug_chromosomal_disease, drug)=[2, 316],
    (drug, drug-drug_cognitive_disorder, drug)=[2, 34660],
    (drug, drug-drug_cryptorchidism, drug)=[2, 128],
    (drug, drug-drug_developmental_disorder_of_mental_health, drug)=[2, 14314],
    (drug, drug-drug_endocrine_system_disease, drug)=[2, 55994],
    (drug, drug-drug_fungal_infectious_disease, drug)=[2, 36114],
    (drug, drug-drug_gastrointestinal_system_disease, drug)=[2, 8321

In [6]:
from torch_geometric.data import HeteroData
biokg = HeteroData()
for node_type, num_nodes in biokg_raw.num_nodes_dict.items():
    biokg[node_type].num_nodes = num_nodes
for edge_type, edge_index in biokg_raw.edge_index_dict.items():
    biokg[edge_type].edge_index = edge_index

In [7]:
biokg

HeteroData(
  disease={ num_nodes=10687 },
  drug={ num_nodes=10533 },
  function={ num_nodes=45085 },
  protein={ num_nodes=17499 },
  sideeffect={ num_nodes=9969 },
  (disease, disease-protein, protein)={ edge_index=[2, 73547] },
  (drug, drug-disease, disease)={ edge_index=[2, 5147] },
  (drug, drug-drug_acquired_metabolic_disease, drug)={ edge_index=[2, 63430] },
  (drug, drug-drug_bacterial_infectious_disease, drug)={ edge_index=[2, 18554] },
  (drug, drug-drug_benign_neoplasm, drug)={ edge_index=[2, 30348] },
  (drug, drug-drug_cancer, drug)={ edge_index=[2, 48514] },
  (drug, drug-drug_cardiovascular_system_disease, drug)={ edge_index=[2, 94842] },
  (drug, drug-drug_chromosomal_disease, drug)={ edge_index=[2, 316] },
  (drug, drug-drug_cognitive_disorder, drug)={ edge_index=[2, 34660] },
  (drug, drug-drug_cryptorchidism, drug)={ edge_index=[2, 128] },
  (drug, drug-drug_developmental_disorder_of_mental_health, drug)={ edge_index=[2, 14314] },
  (drug, drug-drug_endocrine_syste

In [8]:
print(biokg.num_nodes)
print(biokg.num_edges)
# print(biokg.edge_index_dict)

93773
4762678


In [9]:
node_types, edge_types = biokg.metadata()
print(node_types)
print(edge_types)

['disease', 'drug', 'function', 'protein', 'sideeffect']
[('disease', 'disease-protein', 'protein'), ('drug', 'drug-disease', 'disease'), ('drug', 'drug-drug_acquired_metabolic_disease', 'drug'), ('drug', 'drug-drug_bacterial_infectious_disease', 'drug'), ('drug', 'drug-drug_benign_neoplasm', 'drug'), ('drug', 'drug-drug_cancer', 'drug'), ('drug', 'drug-drug_cardiovascular_system_disease', 'drug'), ('drug', 'drug-drug_chromosomal_disease', 'drug'), ('drug', 'drug-drug_cognitive_disorder', 'drug'), ('drug', 'drug-drug_cryptorchidism', 'drug'), ('drug', 'drug-drug_developmental_disorder_of_mental_health', 'drug'), ('drug', 'drug-drug_endocrine_system_disease', 'drug'), ('drug', 'drug-drug_fungal_infectious_disease', 'drug'), ('drug', 'drug-drug_gastrointestinal_system_disease', 'drug'), ('drug', 'drug-drug_hematopoietic_system_disease', 'drug'), ('drug', 'drug-drug_hematopoietic_system_diseases', 'drug'), ('drug', 'drug-drug_hypospadias', 'drug'), ('drug', 'drug-drug_immune_system_diseas

In [10]:
print(biokg["disease"].num_nodes)
print(biokg["drug"].num_nodes)
print(biokg["function"].num_nodes)
print(biokg["protein"].num_nodes)
print(biokg["sideeffect"].num_nodes)

10687
10533
45085
17499
9969


In [11]:
for t in edge_types:
    print(f"{t}: {biokg[t].num_edges}")

('disease', 'disease-protein', 'protein'): 73547
('drug', 'drug-disease', 'disease'): 5147
('drug', 'drug-drug_acquired_metabolic_disease', 'drug'): 63430
('drug', 'drug-drug_bacterial_infectious_disease', 'drug'): 18554
('drug', 'drug-drug_benign_neoplasm', 'drug'): 30348
('drug', 'drug-drug_cancer', 'drug'): 48514
('drug', 'drug-drug_cardiovascular_system_disease', 'drug'): 94842
('drug', 'drug-drug_chromosomal_disease', 'drug'): 316
('drug', 'drug-drug_cognitive_disorder', 'drug'): 34660
('drug', 'drug-drug_cryptorchidism', 'drug'): 128
('drug', 'drug-drug_developmental_disorder_of_mental_health', 'drug'): 14314
('drug', 'drug-drug_endocrine_system_disease', 'drug'): 55994
('drug', 'drug-drug_fungal_infectious_disease', 'drug'): 36114
('drug', 'drug-drug_gastrointestinal_system_disease', 'drug'): 83210
('drug', 'drug-drug_hematopoietic_system_disease', 'drug'): 79202
('drug', 'drug-drug_hematopoietic_system_diseases', 'drug'): 3006
('drug', 'drug-drug_hypospadias', 'drug'): 292
('dr

In [12]:
nums_edge = [biokg[t].num_edges for t in edge_types]

In [13]:
print(edge_types[-8])
relation_focus = edge_types[-8]
relation_edges = biokg[relation_focus]

('protein', 'protein-function', 'function')


In [14]:
transform = T.Compose([
    T.RandomLinkSplit(
        num_val=0.05,
        num_test=0.1,
        disjoint_train_ratio=0.2,   # supervision
        add_negative_train_samples=False,
        neg_sampling_ratio=1.0,
        edge_types=edge_types
    )
])

In [15]:
train_data, val_data, test_data = transform(biokg)

In [16]:
print(f"Train Data: {train_data}")
# print(f"Val Data: {val_data}")
# print(f"Test Data: {test_data}")

Train Data: HeteroData(
  disease={ num_nodes=10687 },
  drug={ num_nodes=10533 },
  function={ num_nodes=45085 },
  protein={ num_nodes=17499 },
  sideeffect={ num_nodes=9969 },
  (disease, disease-protein, protein)={
    edge_index=[2, 50013],
    edge_label=[12503],
    edge_label_index=[2, 12503],
  },
  (drug, drug-disease, disease)={
    edge_index=[2, 3501],
    edge_label=[875],
    edge_label_index=[2, 875],
  },
  (drug, drug-drug_acquired_metabolic_disease, drug)={
    edge_index=[2, 43133],
    edge_label=[10783],
    edge_label_index=[2, 10783],
  },
  (drug, drug-drug_bacterial_infectious_disease, drug)={
    edge_index=[2, 12618],
    edge_label=[3154],
    edge_label_index=[2, 3154],
  },
  (drug, drug-drug_benign_neoplasm, drug)={
    edge_index=[2, 20638],
    edge_label=[5159],
    edge_label_index=[2, 5159],
  },
  (drug, drug-drug_cancer, drug)={
    edge_index=[2, 32991],
    edge_label=[8247],
    edge_label_index=[2, 8247],
  },
  (drug, drug-drug_cardiovascular

In [17]:
import sys
import torch_geometric
from torch_geometric.nn import GCNConv, SAGEConv, GATConv, GINConv
from torch_geometric.nn import BatchNorm, LayerNorm, HeteroBatchNorm, HeteroLayerNorm
from torch_geometric.nn import to_hetero
from torch.nn import Embedding

class SAGE(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels):
        super().__init__()

        self.relu = torch.nn.ReLU()
        self.conv1 = SAGEConv(in_channels, hidden_channels)
        self.conv2 = SAGEConv(hidden_channels, hidden_channels)
        self.conv3 = SAGEConv(hidden_channels, out_channels)

    def forward(self, node_feature, edge_index):

        output1 = self.relu(self.conv1(node_feature, edge_index))
        output2 = self.relu(self.conv2(output1, edge_index))
        output = self.conv3(output2, edge_index)
        return output

class SAGE_RES(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels):
        super().__init__()

        self.relu = torch.nn.ReLU()
        self.conv1 = SAGEConv(in_channels, hidden_channels)
        self.conv2 = SAGEConv(hidden_channels, hidden_channels)
        self.conv3 = SAGEConv(hidden_channels, out_channels)
        self.res = torch.nn.Linear(in_channels, out_channels)
        # self.norm1 = BatchNorm(hidden_channels, 2)
        # self.norm2 = BatchNorm(hidden_channels, 2)

    def forward(self, node_feature, edge_index):

        output1 = self.relu(self.conv1(node_feature, edge_index))
        # output1 = self.norm1(output1)
        output2 = self.relu(self.conv2(output1, edge_index))
        # output2 = self.norm2(output2)
        output3 = self.conv3(output2, edge_index)
        output_res = self.res(node_feature)
        return (output3 + output_res) * 0.5

class GAT(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels):
        super().__init__()

        self.relu = torch.nn.ReLU()
        self.conv1 = GATConv(in_channels, hidden_channels, add_self_loops=False)
        self.conv2 = GATConv(hidden_channels, hidden_channels, add_self_loops=False)
        self.conv3 = GATConv(hidden_channels, out_channels, add_self_loops=False)

    def forward(self, node_feature, edge_index):

        output1 = self.relu(self.conv1(node_feature, edge_index))
        output2 = self.relu(self.conv2(output1, edge_index))
        output = self.conv3(output2, edge_index)
        return output

class GAT_RES(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels):
        super().__init__()

        self.relu = torch.nn.ReLU()
        self.conv1 = GATConv(in_channels, hidden_channels, add_self_loops=False)
        self.conv2 = GATConv(hidden_channels, hidden_channels, add_self_loops=False)
        self.conv3 = GATConv(hidden_channels, out_channels, add_self_loops=False)
        self.res = torch.nn.Linear(in_channels, out_channels)

    def forward(self, node_feature, edge_index):

        output1 = self.relu(self.conv1(node_feature, edge_index))
        output2 = self.relu(self.conv2(output1, edge_index))
        output3 = self.conv3(output2, edge_index)
        output_res = self.res(node_feature)
        return (output3 + output_res) * 0.5

class Embedder(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels):
        super().__init__()
        # Embedding dim (in_channels): 20
        proteins = biokg["protein"]
        functions = biokg["function"]
        self.sage = SAGE(in_channels, hidden_channels, out_channels)
        self.sage = to_hetero(self.sage, metadata=biokg.metadata())
        self.sage_res = SAGE_RES(in_channels, hidden_channels, out_channels)
        self.sage_res = to_hetero(self.sage_res, metadata=biokg.metadata())
        self.gat = GAT(in_channels, hidden_channels, out_channels)
        self.gat = to_hetero(self.gat, metadata=biokg.metadata())
        self.gat_res = GAT_RES(in_channels, hidden_channels, out_channels)
        self.gat_res = to_hetero(self.gat_res, metadata=biokg.metadata())

        self.emb_tr_protein = Embedding(proteins.num_nodes, in_channels)
        self.emb_tr_function = Embedding(functions.num_nodes, in_channels)

        self.gnn = self.gat_res

    def forward(self, hetero_data):
        features = {
            "protein": self.emb_tr_function(torch.arange(biokg["protein"].num_nodes)),
            "function": self.emb_tr_function(torch.arange(biokg["function"].num_nodes))
        }
        embeddings = self.gnn(features, hetero_data.edge_index_dict)
        return embeddings

In [18]:
def calc_emb_similarity(node_embs, edge_index, method="cosine"):
    if method == "cosine":
        return torch.sum(node_embs["protein"][edge_index[0]] * node_embs["function"][edge_index[1]], 1)

In [19]:
from torch_geometric.loader import LinkNeighborLoader
train_loader = LinkNeighborLoader(
    data=train_data,
    num_neighbors=[40, 20, 10],
    neg_sampling="binary",
    neg_sampling_ratio=1.0,
    edge_label_index=(relation_focus, train_data[relation_focus].edge_label_index),
    edge_label=train_data[relation_focus].edge_label,
    batch_size=256,
    shuffle=True
)



In [20]:
from tqdm import tqdm
from torch_geometric.utils import negative_sampling

def train(model, dataloader, optimizer, loss_fn):
    correct_count = 0
    all_count = 0
    loss = 0
    model.train()
    for batch in tqdm(dataloader):
        optimizer.zero_grad()

        node_embeddings = model(batch)

        """
        neg_edge_index = negative_sampling(
            edge_index=batch[edge_types[0]].edge_index,
            num_nodes=batch.num_nodes,
            num_neg_samples=batch[edge_types[0]].edge_label.shape
        )

        edges_all = torch.cat((batch[edge_types[0]].edge_label_index, neg_edge_index), dim=1)
        labels_all = torch.cat((batch[edge_types[0]].edge_label, torch.zeros(batch[edge_types[0]].edge_label.shape)), dim=0)
        """

        labels = batch[relation_focus].edge_label
        similarities = calc_emb_similarity(node_embeddings, batch[relation_focus].edge_label_index)
        predictions = similarities.sigmoid() > 0.5
        correct_count += torch.sum(predictions == labels)
        all_count += len(labels)

        loss = loss_fn(similarities, labels)
        loss.backward()
        optimizer.step()
    return model, (float(correct_count) / float(all_count))


In [21]:
from sklearn.metrics import roc_auc_score

@torch.no_grad()
def test(model, hetero_data):
    model.eval()
    node_embs = model(hetero_data)
    node_embs = calc_emb_similarity(node_embs, hetero_data[relation_focus].edge_label_index).view(-1).sigmoid()
    return roc_auc_score(hetero_data[relation_focus].edge_label.cpu().numpy(), node_embs.cpu().numpy())

In [24]:
model = Embedder(in_channels=20, hidden_channels=128, out_channels=64)
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-2)
loss_fn = torch.nn.BCEWithLogitsLoss()
# print(model)

In [25]:
epochs = 10
for epoch in range(1, epochs + 1):
    model, acc = train(model, train_loader, optimizer, loss_fn)
    val_auc = test(model, val_data)
    test_auc = test(model, test_data)
    print(f'Epoch: {epoch:03d}, Training Accuracy: {acc:.4f}, Val AUC: {val_auc:.4f}, Test AUC: {test_auc:.4f}')

  0%|          | 0/517 [00:01<?, ?it/s]


AttributeError: ignored