# KAMPING Tutorial 2. Homogeneous Graph Neural Network Modeling

Date created: 2024-10-25

In [1]:
# Import kamping library before starting the tutorial
import kamping

%load_ext autoreload
%autoreload 2

In the previous tutorial, we showed how to parse the information in a single KGML file into a KeggGraph object that stores it in an easily accessible way. In this tutorial, we will show how the resulting dataset can be used, through the provided utility functions, with one of the most popular graph machine learning packages, PyTorch Geometric ("pytorch-geometric").

Graph machine learning models often use datasets that contain more than one graph. You can use the `kamping.create_graphs` function to create a list of KeggGraph objects from a directory of KGML files.
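The directory-based workflow can be pictured with the sketch below, which lists the KGML files to parse and skips any names given in an ignore list, in the spirit of the `ignore_file` argument used in the next cell. `list_kgml_files` is a hypothetical helper, not part of KAMPING, and it assumes every KGML file ends in `.xml`.

```python
from pathlib import Path


def list_kgml_files(directory, ignore_file=()):
    """Return the sorted .xml paths in `directory`, skipping file names in `ignore_file`.

    Hypothetical helper for illustration only; it is not KAMPING's implementation.
    """
    ignored = set(ignore_file)
    return sorted(
        path
        for path in Path(directory).glob("*.xml")
        if path.name not in ignored
    )
```

For example, `list_kgml_files('../data/kgml_hsa', ignore_file=['hsa01100.xml'])` would return every pathway file in that directory except `hsa01100.xml`.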

In this tutorial we target homogeneous graphs, which contain only one type of node. In our case, a homogeneous graph is either a "gene-only" graph or a "metabolite-only" graph. Training on a homogeneous graph is easy to understand, which is why we start here. Later, we will show that KAMPING can also convert heterogeneous graphs in a similar way with just a little extra effort.
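To make "only one type of node" concrete, the sketch below keeps the edges whose two endpoints have the same type, using the `entry1_type`/`entry2_type` column names that appear in the KeggGraph edge table later in this tutorial. The helper name, the example rows, and the `"compound"` type label are assumptions for illustration; in this tutorial, gene-only graphs are obtained by passing `type='gene'` to `kamping.create_graphs`.

```python
def homogeneous_edges(edges, node_type="gene"):
    """Return the edges whose two endpoints are both of `node_type` (hypothetical helper)."""
    return [
        edge
        for edge in edges
        if edge["entry1_type"] == node_type and edge["entry2_type"] == node_type
    ]


# Hypothetical rows; the "compound" label for metabolite nodes is an assumption
edges = [
    {"entry1": "hsa:10327", "entry2": "hsa:124", "entry1_type": "gene", "entry2_type": "gene"},
    {"entry1": "hsa:10327", "entry2": "cpd:C00031", "entry1_type": "gene", "entry2_type": "compound"},
]
print(len(homogeneous_edges(edges)))  # 1
```

Simply dropping mixed edges loses gene-compound-gene connections. The `compound-propagation` subtype in the edge table suggests that KAMPING instead links genes that were connected through a compound, so treat this sketch as a simplification.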

In [45]:
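# Build one gene-only (homogeneous) KeggGraph per KGML file; despite the variable name,
# type='gene' keeps gene nodes only. hsa01100.xml (the global metabolic pathways map) is skipped.
metabolite_graphs = kamping.create_graphs('../data/kgml_hsa', type='gene', verbose=True, ignore_file=['hsa01100.xml'])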

INFO:KeggGraph:Now parsing: path:hsa00010...
INFO:KeggGraph:Graph path:hsa00010 parsed successfully!
INFO:KeggGraph:Now parsing: path:hsa00020...
INFO:KeggGraph:Graph path:hsa00020 parsed successfully!
INFO:KeggGraph:Now parsing: path:hsa00030...
INFO:KeggGraph:Graph path:hsa00030 parsed successfully!
INFO:KeggGraph:Now parsing: path:hsa00040...
INFO:KeggGraph:Graph path:hsa00040 parsed successfully!
INFO:KeggGraph:Now parsing: path:hsa00051...
INFO:KeggGraph:Graph path:hsa00051 parsed successfully!
INFO:KeggGraph:Now parsing: path:hsa00052...
INFO:KeggGraph:Graph path:hsa00052 parsed successfully!
INFO:KeggGraph:Now parsing: path:hsa00053...
INFO:KeggGraph:Graph path:hsa00053 parsed successfully!
INFO:KeggGraph:Now parsing: path:hsa00061...
INFO:KeggGraph:Graph path:hsa00061 parsed successfully!
INFO:KeggGraph:Now parsing: path:hsa00062...
INFO:KeggGraph:Graph path:hsa00062 parsed successfully!
INFO:KeggGraph:Now parsing: path:hsa00071...
INFO:KeggGraph:Graph path:hsa00071 parsed succ

Batch processing of KGML files can also be useful for everyday tasks. To access the result for a specific KGML file, you can use the code below.

In [46]:
metabolite_graph_00010 = [graph for graph in metabolite_graphs if graph.name == 'path:hsa00010'][0]
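The list comprehension above raises an `IndexError` when no graph has the requested name, and it rescans the whole list on every lookup. When you need several pathways, indexing the graphs by name once is a common alternative. `Pathway` below is a hypothetical stand-in for a KeggGraph, since only its `name` attribute is used.

```python
from dataclasses import dataclass


@dataclass
class Pathway:
    """Hypothetical stand-in for a KeggGraph; only `name` matters here."""
    name: str


def index_by_name(graphs):
    """Map each graph's `name` (e.g. 'path:hsa00010') to the graph object."""
    return {graph.name: graph for graph in graphs}
```

With the real data this would be `graphs_by_name = index_by_name(metabolite_graphs)` followed by `graphs_by_name['path:hsa00010']`; `graphs_by_name.get(...)` returns `None` for a missing pathway instead of raising.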

In [47]:
metabolite_graph_00010

KEGG Pathway: 
            [Title]: Glycolysis / Gluconeogenesis
            [Name]: path:hsa00010
            [Org]: hsa
            [Link]: https://www.kegg.jp/kegg-bin/show_pathway?hsa00010
            [Image]: https://www.kegg.jp/kegg/pathway/hsa/hsa00010.png
            [Link]: https://www.kegg.jp/kegg-bin/show_pathway?hsa00010
            Graph type: gene 
            Number of Genes: 67
            Number of Compounds: 0
            Gene ID type : kegg
            Compound ID type : kegg
            Number of Nodes: 67
            Number of Edges: 559

In [48]:
metabolite_graph_00010.edges

| | entry1 | entry2 | type | subtype_name | subtype_value | entry1_type | entry2_type |
|---|---|---|---|---|---|---|---|
| 0 | hsa:10327 | hsa:124 | PPrel | compound-propagation | custom | gene | gene |
| 1 | hsa:10327 | hsa:125 | PPrel | compound-propagation | custom | gene | gene |
| 2 | hsa:10327 | hsa:126 | PPrel | compound-propagation | custom | gene | gene |
| 3 | hsa:10327 | hsa:127 | PPrel | compound-propagation | custom | gene | gene |
| 4 | hsa:10327 | hsa:128 | PPrel | compound-propagation | custom | gene | gene |
| ... | ... | ... | ... | ... | ... | ... | ... |
| 554 | hsa:9562 | hsa:387712 | PPrel | compound-propagation | custom | gene | gene |
| 555 | hsa:9562 | hsa:441531 | PPrel | compound-propagation | custom | gene | gene |
| 556 | hsa:9562 | hsa:5223 | PPrel | compound-propagation | custom | gene | gene |
| 557 | hsa:9562 | hsa:5224 | PPrel | compound-propagation | custom | gene | gene |


In this tutorial, we will use pre-computed protein embeddings taken directly from UniProt, so we need to convert KEGG gene IDs into UniProt IDs. We don't need to convert the KEGG compound IDs, so we leave them untouched. If you don't specify `compound_target` when initializing the converter, it defaults to `"kegg"`, so passing only `gene_target` is enough when you only want to convert gene IDs.
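Conceptually, ID conversion is a lookup from one identifier space into another. The sketch below uses a plain dict as a stand-in for a KEGG-to-UniProt table; its values are placeholders, not real UniProt accessions, and `convert_ids` is a hypothetical helper rather than KAMPING's API. It also illustrates one plausible reason for the NA values mentioned in the conversion cell below: once IDs have been converted they are no longer keys of a KEGG-based table, so a second pass finds nothing. Whether `kamping.Converter` behaves exactly this way is an assumption.

```python
def convert_ids(ids, mapping):
    """Translate each ID through `mapping`; IDs missing from it become None (an "NA").

    Hypothetical helper; `mapping` stands in for a KEGG-to-UniProt lookup table.
    """
    return [mapping.get(identifier) for identifier in ids]


# Placeholder table: the values are NOT real UniProt accessions
kegg_to_uniprot = {"hsa:124": "UNIPROT_A", "hsa:125": "UNIPROT_B"}

once = convert_ids(["hsa:124", "hsa:125"], kegg_to_uniprot)
twice = convert_ids(once, kegg_to_uniprot)  # converted IDs are no longer keys
print(once)   # ['UNIPROT_A', 'UNIPROT_B']
print(twice)  # [None, None]
```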

In [49]:
converter = kamping.Converter('hsa', gene_target='uniprot', verbose=True)

In [50]:
# Note: run the conversion only once; converting the same graphs a second time
# produces NA values (the cause is still being investigated)
for graph in metabolite_graphs:
    converter.convert(graph)

INFO:kamping.parser.convert:Conversion of path:hsa00010 complete!
INFO:kamping.parser.convert:Conversion of path:hsa00020 complete!
INFO:kamping.parser.convert:Conversion of path:hsa00030 complete!
INFO:kamping.parser.convert:Conversion of path:hsa00040 complete!
INFO:kamping.parser.convert:Conversion of path:hsa00051 complete!
INFO:kamping.parser.convert:Conversion of path:hsa00052 complete!
INFO:kamping.parser.convert:Conversion of path:hsa00053 complete!
INFO:kamping.parser.convert:Conversion of path:hsa00061 complete!
INFO:kamping.parser.convert:Conversion of path:hsa00062 complete!
INFO:kamping.parser.convert:Conversion of path:hsa00071 complete!
INFO:kamping.parser.convert:Conversion of path:hsa00100 complete!
INFO:kamping.parser.convert:Conversion of path:hsa00120 complete!
INFO:kamping.parser.convert:Conversion of path:hsa00130 complete!
INFO:kamping.parser.convert:Conversion of path:hsa00140 complete!
INFO:kamping.parser.convert:Conversion of path:hsa00220 complete!
INFO:kampi

If you did not convert compound IDs into another ID type, you can use the `kamping.get_kegg_mol` function to retrieve the molfile of every compound in all graphs from KEGG and build an RDKit (https://www.rdkit.org/) Mol object for each one. It returns a pd.DataFrame whose first column (`id`) holds the compound ID and whose second column (`ROMol`) holds the Mol object.
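A minimal sketch of the retrieval step, assuming the public KEGG REST `get` operation with the `mol` option; this is not necessarily how `kamping.get_kegg_mol` is implemented, and both function names are hypothetical. The returned text could then be passed to RDKit's `Chem.MolFromMolBlock` to build a Mol object.

```python
from urllib.request import urlopen

KEGG_REST = "https://rest.kegg.jp"


def kegg_mol_url(compound_id):
    """Build the KEGG REST URL that serves the MDL molfile of a compound, e.g. 'cpd:C00038'."""
    return f"{KEGG_REST}/get/{compound_id}/mol"


def fetch_molblock(compound_id, timeout=30):
    """Download the molfile text; return None if the request fails or the body is empty."""
    try:
        with urlopen(kegg_mol_url(compound_id), timeout=timeout) as response:
            text = response.read().decode("utf-8")
    except OSError:  # URLError and HTTPError (e.g. 404 for entries without a molfile)
        return None
    return text or None
```

Fetching one compound per request is slow for thousands of IDs, which is one reason to cache the result as described next.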

In [51]:
# uncomment to run for the first time
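# `graphs` must be a list of KeggGraph objects that still contain compound nodes
# (the gene-only graphs created above have none)
# mols = kamping.get_kegg_mol(graphs)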

The process might take a while because of the large number of compounds across so many graphs. It is a good idea to save the resulting pd.DataFrame for reuse when testing different approaches to embedding metabolites.
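A small cache-or-compute helper captures the "compute once, reload afterwards" pattern without commenting and uncommenting lines. This is a sketch built on the standard `pickle` module; `load_or_compute` is a hypothetical name. Because `pd.DataFrame.to_pickle` also uses pickle, a DataFrame cached this way can still be read back with `pd.read_pickle`.

```python
import pickle
from pathlib import Path


def load_or_compute(path, compute):
    """Return the object cached at `path`; on the first call, compute it and pickle it there."""
    path = Path(path)
    if path.exists():
        with path.open("rb") as handle:
            return pickle.load(handle)
    result = compute()
    path.parent.mkdir(parents=True, exist_ok=True)  # create e.g. data/ if missing
    with path.open("wb") as handle:
        pickle.dump(result, handle)
    return result
```

In this notebook it might be used as `mols = load_or_compute('data/mols.pkl', lambda: kamping.get_kegg_mol(graphs))`, with `graphs` being a compound-containing list as noted in the previous cell.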

In [52]:
import pandas as pd

# uncomment the code below on the first run
# save the mols to a file
# mols.to_pickle('data/mols.pkl')
# retrieve mol from file
mols = pd.read_pickle('data/mols.pkl')

Not all compounds have a molfile in KEGG. Most compounds without a molfile are glycans, which don't have a fixed atomic composition. For now, we can simply ignore them.
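If you want to separate these rows yourself before embedding, the logic is a simple split on whether the `ROMol` cell holds an object. The sketch below uses plain dicts so it runs without pandas; `split_valid` is a hypothetical helper, and with the DataFrame shown next, `mols[mols['ROMol'].notna()]` should select the same valid rows.

```python
def split_valid(records, column="ROMol"):
    """Split rows into (valid, invalid) by whether `column` holds a non-None value."""
    valid = [row for row in records if row.get(column) is not None]
    invalid = [row for row in records if row.get(column) is None]
    return valid, invalid


# Hypothetical rows mirroring the table below: a compound with a Mol, a glycan without one
records = [
    {"id": "cpd:C00038", "ROMol": object()},  # object() stands in for an RDKit Mol
    {"id": "gl:G00083", "ROMol": None},
]
valid, invalid = split_valid(records)
print([row["id"] for row in valid])    # ['cpd:C00038']
print([row["id"] for row in invalid])  # ['gl:G00083']
```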

In [53]:
mols

| | id | ROMol |
|---|---|---|
| 0 | cpd:C00038 | `<rdkit.Chem.rdchem.Mol object at 0x2d0822ed0>` |
| 1 | cpd:C01180 | `<rdkit.Chem.rdchem.Mol object at 0x2d0822e80>` |
| 2 | gl:G00083 | |
| 3 | cpd:C20683 | `<rdkit.Chem.rdchem.Mol object at 0x2d0822f20>` |
| 4 | cpd:C02593 | `<rdkit.Chem.rdchem.Mol object at 0x2d0822fc0>` |
| ... | ... | ... |
| 1658 | gl:G10599 | |
| 1659 | cpd:C03090 | `<rdkit.Chem.rdchem.Mol object at 0x2c64a0a40>` |
| 1660 | cpd:C00097 | `<rdkit.Chem.rdchem.Mol object at 0x2db798b30>` |
| 1661 | cpd:C11134 | `<rdkit.Chem.rdchem.Mol object at 0x2db748ae0>` |


After we get the Mol object for each compound, we can use RDKit to embed the compounds as vectors that machine learning models can work with.

In [54]:
# todo: consider depending on scikit-mol for molecule featurization

In [55]:
mol_embeddings = kamping.get_mol_embeddings_from_dataframe(mols, transformer='morgan')

total 231 Invalid rows with "None" in the ROMol column


In [59]:
protein_embeddings = kamping.get_uniprot_protein_embeddings(metabolite_graphs, '../data/embedding/protein_embedding.h5')

In [60]:
pyg_one_graph = kamping.get_pyg_one_graph(metabolite_graphs, embeddings=protein_embeddings)

In [61]:
data = pyg_one_graph

In [62]:
data

Data(edge_index=[2, 108747], node_type=[7277], node_name=[7277], edge_type=[108747], edge_subtype_name=[108747], entry1_type=[108747], entry2_type=[108747], name='combined', type='gene', x=[7277, 1024])
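The `Data` object above stores all pathways as one graph (`name='combined'`): `node_name` lists the nodes and `edge_index` is a `[2, num_edges]` array of integer node positions. The sketch below shows one way several edge lists can be merged into that layout by giving each distinct node name a single index, so a gene shared between pathways becomes one node. `build_edge_index` is hypothetical, and whether `kamping.get_pyg_one_graph` deduplicates nodes exactly like this is an assumption.

```python
def build_edge_index(edge_lists):
    """Merge (source, target) name pairs from several graphs into one index space.

    Returns (node_names, edge_index) where edge_index is [sources, targets],
    the two-row layout PyTorch Geometric uses for `edge_index`.
    """
    node_to_idx = {}
    sources, targets = [], []
    for edges in edge_lists:
        for src, dst in edges:
            for name in (src, dst):
                if name not in node_to_idx:
                    node_to_idx[name] = len(node_to_idx)  # first time this node is seen
            sources.append(node_to_idx[src])
            targets.append(node_to_idx[dst])
    node_names = list(node_to_idx)  # dict insertion order matches the assigned indices
    return node_names, [sources, targets]


names, edge_index = build_edge_index([[("A", "B"), ("A", "C")], [("B", "C")]])
print(names)       # ['A', 'B', 'C']
print(edge_index)  # [[0, 0, 1], [1, 2, 2]]
```

`torch.tensor(edge_index, dtype=torch.long)` would turn the result into a tensor of shape `[2, num_edges]`.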

In [63]:
import os.path as osp

import torch
import torch.nn.functional as F
from sklearn.metrics import roc_auc_score

from torch_geometric.utils import negative_sampling
from torch_geometric.datasets import Planetoid
import torch_geometric.transforms as T
from torch_geometric.nn import GCNConv
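# train_test_split_edges is deprecated in recent PyG releases in favor of
# torch_geometric.transforms.RandomLinkSplit
from torch_geometric.utils import train_test_split_edges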

In [64]:
from torch_geometric import transforms

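# Drop node-level masks/labels (not used for link prediction), then split the edges into
# train/val/test positives plus sampled val/test negatives
data.train_mask = data.val_mask = data.test_mask = data.y = None
data = train_test_split_edges(data)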
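`train_test_split_edges` treats the graph as undirected: it keeps each edge once, holds out a fraction of edges for validation and testing (by default `val_ratio=0.05` and `test_ratio=0.1`), and samples the same number of non-edges as negatives for those two sets. The sketch below reproduces only the positive-edge split in plain Python, to show why an edge and its reverse always land in the same split. `split_undirected_edges` is a hypothetical helper, not PyG code.

```python
import random


def split_undirected_edges(edges, val_ratio=0.05, test_ratio=0.1, seed=0):
    """Split undirected edges into train/val/test lists of (u, v) pairs with u < v.

    Each edge is stored once, so (u, v) and (v, u) always share a split;
    self-loops are dropped and split sizes are rounded down.
    """
    unique = sorted({(min(u, v), max(u, v)) for u, v in edges if u != v})
    rng = random.Random(seed)
    rng.shuffle(unique)
    n_val = int(len(unique) * val_ratio)
    n_test = int(len(unique) * test_ratio)
    val = unique[:n_val]
    test = unique[n_val:n_val + n_test]
    train = unique[n_val + n_test:]
    return train, val, test
```

Storing edges as sorted pairs before shuffling prevents leakage: if `(u, v)` were in the test set while `(v, u)` stayed in training, the model would see the held-out link during message passing.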



In [65]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device = "cpu"  # force CPU; delete this line to use a GPU when one is available

In [66]:
class Net(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = GCNConv(1024, 128)  # 1024-dim protein embeddings -> 128 hidden units
        self.conv2 = GCNConv(128, 64)    # 128 -> 64-dim node embeddings

    def encode(self):
        # Reads the global `data`; message passing uses the training edges only
        x = self.conv1(data.x, data.train_pos_edge_index)  # convolution 1
        x = x.relu()
        return self.conv2(x, data.train_pos_edge_index)  # convolution 2

    def decode(self, z, pos_edge_index, neg_edge_index):  # score positive and negative edges
        edge_index = torch.cat([pos_edge_index, neg_edge_index], dim=-1)  # concatenate pos and neg edges
        logits = (z[edge_index[0]] * z[edge_index[1]]).sum(dim=-1)  # dot product per edge
        return logits

    def decode_all(self, z):
        logits_adj = z @ z.t()  # N x N matrix of pairwise logits
        # logit > 0 is equivalent to sigmoid probability > 0.5; returns the predicted edge list
        return (logits_adj > 0).nonzero(as_tuple=False).t()
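The `decode` method scores a candidate edge as the dot product of its two node embeddings, and `decode_all` keeps the pairs whose score is positive. A plain-Python version of the same arithmetic shows why a threshold of 0 on the logit matches a threshold of 0.5 on the sigmoid probability. The function names and toy embeddings are made up for illustration.

```python
import math


def dot_product_logits(z, pairs):
    """Score each (u, v) pair as the dot product of the embeddings z[u] and z[v]."""
    return [sum(a * b for a, b in zip(z[u], z[v])) for u, v in pairs]


def sigmoid(x):
    """Logistic function mapping a logit to a probability in (0, 1)."""
    return 1.0 / (1.0 + math.exp(-x))


# Toy 2-dimensional embeddings for three nodes (illustrative values only)
z = [[1.0, 0.0], [0.5, 0.5], [-1.0, 0.0]]
logits = dot_product_logits(z, [(0, 1), (0, 2)])
print(logits)                               # [0.5, -1.0]
print([sigmoid(x) > 0.5 for x in logits])   # [True, False]
```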

In [67]:
model, data = Net().to(device), data.to(device)
optimizer = torch.optim.Adam(params=model.parameters(), lr=0.01)
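The training function in the next cell pairs the positive training edges with an equal number of randomly drawn non-edges via PyG's `negative_sampling`, whose `num_neg_samples` is documented as approximate. The sketch below shows the underlying idea with simple rejection sampling; `sample_negative_edges` is hypothetical and is not PyG's implementation, which is vectorized and much faster on a graph of this size.

```python
import random


def sample_negative_edges(pos_edges, num_nodes, num_samples, seed=0):
    """Rejection-sample `num_samples` distinct directed (u, v) pairs that are not positive edges.

    Self-loops are excluded. For an undirected graph, pass both directions of every
    edge in `pos_edges` so that neither direction is drawn as a negative.
    """
    forbidden = {(u, v) for u, v in pos_edges if u != v}
    available = num_nodes * (num_nodes - 1) - len(forbidden)
    if num_samples > available:
        raise ValueError("not enough non-edges to sample from")
    rng = random.Random(seed)
    negatives = set()
    while len(negatives) < num_samples:
        u, v = rng.randrange(num_nodes), rng.randrange(num_nodes)
        if u != v and (u, v) not in forbidden:
            negatives.add((u, v))
    return sorted(negatives)
```

Balancing positives and negatives one-to-one, as the tutorial does with `num_neg_samples=data.train_pos_edge_index.size(1)`, keeps the binary cross-entropy loss from being dominated by either class.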

In [68]:
def get_link_labels(pos_edge_index, neg_edge_index):
    # returns a tensor:
    # [1,1,1,1,...,0,0,0,0,0,..] where the number of ones equals the length of pos_edge_index
    # and the number of zeros is equal to the length of neg_edge_index
    E = pos_edge_index.size(1) + neg_edge_index.size(1)
    link_labels = torch.zeros(E, dtype=torch.float, device=device)
    link_labels[:pos_edge_index.size(1)] = 1.
    return link_labels


def train():
    model.train()

    neg_edge_index = negative_sampling(
        edge_index=data.train_pos_edge_index, #positive edges
        num_nodes=data.num_nodes, # number of nodes
        num_neg_samples=data.train_pos_edge_index.size(1)) # number of neg_sample equal to number of pos_edges

    optimizer.zero_grad()

    z = model.encode() #encode
    link_logits = model.decode(z, data.train_pos_edge_index, neg_edge_index) # decode

    link_labels = get_link_labels(data.train_pos_edge_index, neg_edge_index)
    loss = F.binary_cross_entropy_with_logits(link_logits, link_labels)
    loss.backward()
    optimizer.step()

    return loss.item()  # plain float for logging


@torch.no_grad()
def test():
    model.eval()
    perfs = []
    for prefix in ["val", "test"]:
        pos_edge_index = data[f'{prefix}_pos_edge_index']
        neg_edge_index = data[f'{prefix}_neg_edge_index']

        z = model.encode() # encode train
        link_logits = model.decode(z, pos_edge_index, neg_edge_index) # decode test or val
        link_probs = link_logits.sigmoid() # apply sigmoid

        link_labels = get_link_labels(pos_edge_index, neg_edge_index) # ground truth: 1 = real edge, 0 = sampled non-edge

        perfs.append(roc_auc_score(link_labels.cpu(), link_probs.cpu())) #compute roc_auc score
    return perfs
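`roc_auc_score` measures how well the predicted probabilities rank real edges above sampled non-edges. The sketch below computes the same quantity from its pairwise definition: the probability that a randomly chosen positive scores higher than a randomly chosen negative, with ties counting half. `roc_auc` is a hypothetical helper that is quadratic in the number of examples, so it is meant for intuition only; sklearn's implementation is the one to use in practice.

```python
def roc_auc(labels, scores):
    """ROC AUC from its pairwise definition; ties between a positive and a negative count 0.5."""
    pos = [s for s, y in zip(scores, labels) if y == 1]
    neg = [s for s, y in zip(scores, labels) if y == 0]
    if not pos or not neg:
        raise ValueError("need at least one positive and one negative example")
    wins = 0.0
    for p in pos:
        for n in neg:
            if p > n:
                wins += 1.0
            elif p == n:
                wins += 0.5
    return wins / (len(pos) * len(neg))


print(roc_auc([1, 1, 0, 0], [0.9, 0.8, 0.3, 0.1]))  # 1.0
print(roc_auc([1, 1, 0, 0], [0.9, 0.2, 0.3, 0.1]))  # 0.75
```

An AUC of 0.5 corresponds to random ranking, which gives a reference point for the validation and test values printed during training.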

In [69]:
best_val_perf = test_perf = 0
for epoch in range(1, 101):
    train_loss = train()
    val_perf, tmp_test_perf = test()
    if val_perf > best_val_perf:
        best_val_perf = val_perf
        test_perf = tmp_test_perf
    # Report the best validation AUC so far and the test AUC from that same epoch
    log = 'Epoch: {:03d}, Loss: {:.4f}, AUC Val: {:.4f}, AUC Test: {:.4f}'
    print(log.format(epoch, train_loss, best_val_perf, test_perf))

Epoch: 001, Loss: 0.6767, AUC Val: 0.8186, AUC Test: 0.8343
Epoch: 002, Loss: 2.6535, AUC Val: 0.8661, AUC Test: 0.8850
Epoch: 003, Loss: 0.6298, AUC Val: 0.8661, AUC Test: 0.8850
Epoch: 004, Loss: 0.8200, AUC Val: 0.8661, AUC Test: 0.8850
Epoch: 005, Loss: 0.7325, AUC Val: 0.8661, AUC Test: 0.8850
Epoch: 006, Loss: 0.6600, AUC Val: 0.8661, AUC Test: 0.8850
Epoch: 007, Loss: 0.6440, AUC Val: 0.8661, AUC Test: 0.8850
Epoch: 008, Loss: 0.6413, AUC Val: 0.8661, AUC Test: 0.8850
Epoch: 009, Loss: 0.6396, AUC Val: 0.8661, AUC Test: 0.8850
Epoch: 010, Loss: 0.6359, AUC Val: 0.8661, AUC Test: 0.8850
Epoch: 011, Loss: 0.6323, AUC Val: 0.8661, AUC Test: 0.8850
Epoch: 012, Loss: 0.6254, AUC Val: 0.8686, AUC Test: 0.8814
Epoch: 013, Loss: 0.6187, AUC Val: 0.8759, AUC Test: 0.8899
Epoch: 014, Loss: 0.6123, AUC Val: 0.8810, AUC Test: 0.8952
Epoch: 015, Loss: 0.6049, AUC Val: 0.8835, AUC Test: 0.8978
Epoch: 016, Loss: 0.5974, AUC Val: 0.8847, AUC Test: 0.8986
Epoch: 017, Loss: 0.5885, AUC Val: 0.885