In [1]:
# Import kamping library before starting the tutorial
import kamping

%load_ext autoreload
%autoreload 2

In [2]:
gene_graphs = kamping.create_graphs('../data/kgml_hsa', type='mixed', verbose=True, ignore_file=['hsa01100.xml'])


            Visit https://www.kegg.jp/kegg-bin/show_pathway?hsa00190 for pathway details.

            There are likely no edges in which to parse...
INFO:KeggGraph:Now parsing: path:hsa00220...
INFO:KeggGraph:Graph path:hsa00220 parsed successfully!
INFO:KeggGraph:Now parsing: path:hsa00230...
INFO:KeggGraph:Graph path:hsa00230 parsed successfully!
INFO:KeggGraph:Now parsing: path:hsa00232...
INFO:KeggGraph:Graph path:hsa00232 parsed successfully!
INFO:KeggGraph:Now parsing: path:hsa00240...
INFO:KeggGraph:Graph path:hsa00240 parsed successfully!
INFO:KeggGraph:Now parsing: path:hsa00250...
INFO:KeggGraph:Graph path:hsa00250 parsed successfully!
INFO:KeggGraph:Now parsing: path:hsa00260...
INFO:KeggGraph:Graph path:hsa00260 parsed successfully!
INFO:KeggGraph:Now parsing: path:hsa00270...
INFO:KeggGraph:Graph path:hsa00270 parsed successfully!
INFO:KeggGraph:Now parsing: path:hsa00280...
INFO:KeggGraph:Graph path:hsa00280 parsed successfully!
INFO:KeggGraph:Now parsing: path:hsa00290

In [3]:
gene_graph_00010 = [graph for graph in gene_graphs if graph.name == 'path:hsa00010'][0]
gene_graph_00010

KEGG Pathway: 
            [Title]: Glycolysis / Gluconeogenesis
            [Name]: path:hsa00010
            [Org]: hsa
            [Link]: https://www.kegg.jp/kegg-bin/show_pathway?hsa00010
            [Image]: https://www.kegg.jp/kegg/pathway/hsa/hsa00010.png
            [Link]: https://www.kegg.jp/kegg-bin/show_pathway?hsa00010
            Graph type: mixed 
            Number of Genes: 67
            Number of Compounds: 26
            Gene ID type : kegg
            Compound ID type : kegg
            Number of Nodes: 93
            Number of Edges: 279

In [4]:
converter = kamping.Converter('hsa', gene_target='uniprot', verbose=True)

In [5]:
for graph in gene_graphs:
    converter.convert(graph)

INFO:kamping.parser.convert:Conversion of path:hsa00010 complete!
INFO:kamping.parser.convert:Conversion of path:hsa00020 complete!
INFO:kamping.parser.convert:Conversion of path:hsa00030 complete!
INFO:kamping.parser.convert:Conversion of path:hsa00040 complete!
INFO:kamping.parser.convert:Conversion of path:hsa00051 complete!
INFO:kamping.parser.convert:Conversion of path:hsa00052 complete!
INFO:kamping.parser.convert:Conversion of path:hsa00053 complete!
INFO:kamping.parser.convert:Conversion of path:hsa00061 complete!
INFO:kamping.parser.convert:Conversion of path:hsa00062 complete!
INFO:kamping.parser.convert:Conversion of path:hsa00071 complete!
INFO:kamping.parser.convert:Conversion of path:hsa00100 complete!
INFO:kamping.parser.convert:Conversion of path:hsa00120 complete!
INFO:kamping.parser.convert:Conversion of path:hsa00130 complete!
INFO:kamping.parser.convert:Conversion of path:hsa00140 complete!
INFO:kamping.parser.convert:Conversion of path:hsa00220 complete!
INFO:kampi

In [26]:
import pandas as pd

# uncommented code below if run the first time
# save the mols to a file
# mols.to_pickle('data/mols.pkl')
# retrieve mol from file
mols = pd.read_pickle('data/mols.pkl')
mol_embeddings = kamping.get_mol_embeddings_from_dataframe(mols, transformer='morgan')

'
                    total 231 Invalid rows with "None" in the ROMol column


In [27]:
mol_embeddings

{'cpd:C00038': array([0., 0., 0., ..., 0., 0., 0.], dtype=float32),
 'cpd:C01180': array([0., 0., 0., ..., 0., 0., 0.], dtype=float32),
 'cpd:C20683': array([0., 0., 0., ..., 0., 0., 0.], dtype=float32),
 'cpd:C02593': array([0., 1., 0., ..., 0., 0., 0.], dtype=float32),
 'cpd:C00286': array([0., 0., 0., ..., 0., 0., 0.], dtype=float32),
 'cpd:C03564': array([0., 0., 0., ..., 0., 0., 0.], dtype=float32),
 'cpd:C05452': array([0., 1., 0., ..., 0., 0., 0.], dtype=float32),
 'cpd:C00603': array([0., 1., 0., ..., 0., 0., 0.], dtype=float32),
 'cpd:C05443': array([0., 1., 0., ..., 0., 0., 0.], dtype=float32),
 'cpd:C06157': array([0., 1., 0., ..., 0., 0., 0.], dtype=float32),
 'cpd:C00055': array([0., 0., 0., ..., 0., 0., 0.], dtype=float32),
 'cpd:C04487': array([0., 1., 0., ..., 0., 0., 0.], dtype=float32),
 'cpd:C05294': array([0., 0., 0., ..., 0., 0., 0.], dtype=float32),
 'cpd:C01674': array([0., 0., 0., ..., 0., 0., 0.], dtype=float32),
 'cpd:C16549': array([0., 1., 0., ..., 0., 0., 0

In [28]:
protein_embeddings = kamping.get_uniprot_protein_embeddings(gene_graphs, '../data/embedding/protein_embedding.h5') 
protein_embeddings

{'up:Q09472': array([ 0.07514679,  0.05754357,  0.02730095, ...,  0.01029478,
        -0.00694583,  0.00510776], dtype=float32),
 'up:Q05940': array([ 0.03543108,  0.11163083,  0.06609377, ...,  0.00057277,
        -0.00168417, -0.00779132], dtype=float32),
 'up:Q8NFF5': array([ 0.03544585,  0.05745461,  0.06277175, ..., -0.02915345,
         0.00568385,  0.05075936], dtype=float32),
 'up:B3KXW6': array([ 0.00720328,  0.04428549,  0.025407  , ..., -0.02186162,
        -0.05797382,  0.02092825], dtype=float32),
 'up:P08697': array([-0.002179  ,  0.04799882,  0.05717658, ...,  0.00177752,
        -0.01672771,  0.06153679], dtype=float32),
 'up:P56750': array([ 0.00189529,  0.11344011,  0.02590067, ..., -0.00397826,
         0.01495809, -0.03836733], dtype=float32),
 'up:A8K8E4': array([-0.01210068,  0.06828921,  0.00862978, ...,  0.00124401,
         0.00249301,  0.00844896], dtype=float32),
 'up:Q6FI00': array([ 0.01111019,  0.0441544 , -0.01572975, ...,  0.00072912,
        -0.06173321

In [29]:
# combine protein embeddings and metabolite embeddings into one dictionary
embeddings = {**protein_embeddings, **mol_embeddings}
len(embeddings)

8837

In [96]:
pyg_graph = kamping.convert_to_single_pyg(gene_graphs, embeddings=embeddings)
data= pyg_graph
data

  hetero_data_dict[group][key] = torch.tensor(value)


HeteroData(
  name='combined',
  type='mixed',
  compound={ x=[1432, 1024] },
  gene={ x=[7405, 1024] },
  (compound, to, compound)={ edge_index=[2, 362] },
  (compound, to, gene)={ edge_index=[2, 8789] },
  (gene, to, compound)={ edge_index=[2, 6955] },
  (gene, to, gene)={ edge_index=[2, 77375] }
)

In [97]:
# # del data['compound']
# # del data[('gene', 'to', 'gene')]
# # del data[('gene', 'to', 'compound')]
# # del data[('compound', 'to', 'gene')]
# # del data[('compound', 'to', 'compound')]
# data

In [98]:
# to undirected graph
from torch_geometric.transforms.to_undirected import ToUndirected
# transform = ToUndirected()
# data = transform(data)

In [99]:
import torch_geometric.transforms as T
transform = T.RandomLinkSplit(
    num_val=0.1,
    num_test=0.1,
    is_undirected=True,
    # disjoint_train_ratio=0.3, # TODO
    neg_sampling_ratio=1.0, # TODO
    add_negative_train_samples=False,
    edge_types=("gene", "to", "gene")
)
train_data, val_data, test_data = transform(data)

In [100]:
import torch

In [101]:
from torch.utils.data import random_split
import torch.nn as nn
from torch_geometric.nn import GCNConv, GAE, GATConv, Linear, to_hetero, SAGEConv
from torch_geometric.loader import DataLoader
import torch
from tqdm import tqdm
import torch_geometric.utils as utils
import torch.nn.functional as F

class GCNEncoder(torch.nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv1 = SAGEConv(in_channels, 2 * out_channels)
        self.conv2 = SAGEConv(2 * out_channels, out_channels)

    def forward(self, x, edge_index):
        x = self.conv1(x, edge_index).relu()
        return self.conv2(x, edge_index)


In [102]:
train_data['compound'].x

tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 1., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]])

In [103]:
from sklearn.metrics import roc_auc_score

In [104]:
def train(model, optimizer, data):

    z_dict = model(data.x_dict, data.edge_index_dict)

    pos_edge_label_index = data[('gene', 'to', 'gene')].edge_label_index
    pos_edge_label = torch.ones(pos_edge_label_index.size(1))

    neg_edge_label_index = utils.negative_sampling(edge_index=pos_edge_label_index, #positive edges
                                                    num_nodes=data['gene'].x.size(0), # number of nodes
                                                    num_neg_samples=pos_edge_label_index.size(1))
    neg_edge_label = torch.zeros(neg_edge_label_index.size(1))

    edge_label_index = torch.cat([pos_edge_label_index, neg_edge_label_index], dim=1)
    edge_label = torch.cat([pos_edge_label, neg_edge_label], dim=0)

    z_src = z_dict['gene'][edge_label_index[0]]
    z_dst = z_dict['gene'][edge_label_index[1]]

    recon = (z_src * z_dst).sum(dim=-1)
    loss = F.binary_cross_entropy_with_logits(recon, edge_label)
    # calculate AUC
    auc = roc_auc_score(edge_label.cpu().detach().numpy(), recon.cpu().detach().numpy())
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    return loss.item(), auc


In [105]:
def main():
    num_epochs = 1000
    batch_size = 10
    in_channels, out_channels = 1024, 10
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    model = GCNEncoder(in_channels, out_channels)
    model = to_hetero(model, data.metadata(), aggr='sum')
    model = model.to(device)

    optimizer = torch.optim.Adam(params=model.parameters(), lr=0.01)

    for _ in tqdm(range(num_epochs)):
        loss, auc = train(model, optimizer, train_data)
        print("Loss is : ", loss)
        print("AUC is : ", auc)


In [107]:
data

HeteroData(
  name='combined',
  type='mixed',
  compound={ x=[1432, 1024] },
  gene={ x=[7405, 1024] },
  (compound, to, compound)={ edge_index=[2, 362] },
  (compound, to, gene)={ edge_index=[2, 8789] },
  (gene, to, compound)={ edge_index=[2, 6955] },
  (gene, to, gene)={ edge_index=[2, 77375] }
)

In [106]:
main()

  0%|          | 2/1000 [00:00<01:51,  8.95it/s]

Loss is :  0.7060911655426025
AUC is :  0.5770540960328022
Loss is :  0.8108980655670166
AUC is :  0.7511234942053767


  0%|          | 4/1000 [00:00<01:51,  8.92it/s]

Loss is :  0.6988070607185364
AUC is :  0.8725923880934381
Loss is :  0.6101466417312622
AUC is :  0.8689185786405437


  1%|          | 6/1000 [00:00<01:50,  8.96it/s]

Loss is :  0.5839148759841919
AUC is :  0.8840822951889831
Loss is :  0.5937895178794861
AUC is :  0.8781830781398758


  1%|          | 8/1000 [00:00<01:47,  9.20it/s]

Loss is :  0.5696653723716736
AUC is :  0.8833380028537744
Loss is :  0.5591142773628235
AUC is :  0.8853442321534339


  1%|          | 10/1000 [00:01<01:48,  9.15it/s]

Loss is :  0.5454058647155762
AUC is :  0.8868895920234667
Loss is :  0.5405158400535583
AUC is :  0.8834632372418629


  1%|▏         | 13/1000 [00:01<01:42,  9.66it/s]

Loss is :  0.5324928760528564
AUC is :  0.8808450532755107
Loss is :  0.5231932401657104
AUC is :  0.8825571158949255
Loss is :  0.512742817401886
AUC is :  0.8900623668777694


  2%|▏         | 15/1000 [00:01<01:40,  9.78it/s]

Loss is :  0.5065770745277405
AUC is :  0.8994640179170846
Loss is :  0.5021194219589233
AUC is :  0.908784216285038
Loss is :  0.4964260756969452
AUC is :  0.9166534757303947


  2%|▏         | 19/1000 [00:01<01:36, 10.19it/s]

Loss is :  0.491542786359787
AUC is :  0.9213433466834
Loss is :  0.487273633480072
AUC is :  0.9266961184267786
Loss is :  0.4851556718349457
AUC is :  0.9294341828537437


  2%|▏         | 21/1000 [00:02<01:35, 10.24it/s]

Loss is :  0.4828897714614868
AUC is :  0.9324222106602484
Loss is :  0.481884628534317
AUC is :  0.9350535305279937
Loss is :  0.47576630115509033
AUC is :  0.9391204352172355


  2%|▎         | 25/1000 [00:02<01:34, 10.35it/s]

Loss is :  0.471223920583725
AUC is :  0.941451248260629
Loss is :  0.47000038623809814
AUC is :  0.9427903262707603
Loss is :  0.4676051735877991
AUC is :  0.9441071138945126


  3%|▎         | 27/1000 [00:02<01:35, 10.17it/s]

Loss is :  0.4686444103717804
AUC is :  0.9432127263198913
Loss is :  0.4659672975540161
AUC is :  0.9431435707931579


  3%|▎         | 29/1000 [00:02<01:35, 10.16it/s]

Loss is :  0.46610772609710693
AUC is :  0.9422878637061068
Loss is :  0.46234041452407837
AUC is :  0.9443701696936891
Loss is :  0.4601728320121765
AUC is :  0.946619656099795


  3%|▎         | 33/1000 [00:03<01:33, 10.30it/s]

Loss is :  0.4598025679588318
AUC is :  0.9494973101959108
Loss is :  0.4571104943752289
AUC is :  0.9519973699864754
Loss is :  0.4541993737220764
AUC is :  0.9542966581268666


  4%|▎         | 35/1000 [00:03<01:33, 10.29it/s]

Loss is :  0.4562593400478363
AUC is :  0.9537251181766059
Loss is :  0.4533967077732086
AUC is :  0.9552584071331736
Loss is :  0.4506532847881317
AUC is :  0.9576329495891531


  4%|▍         | 39/1000 [00:03<01:32, 10.39it/s]

Loss is :  0.4497893154621124
AUC is :  0.9601104707062127
Loss is :  0.45207515358924866
AUC is :  0.9602737094090628
Loss is :  0.45037150382995605
AUC is :  0.9594460133950224


  4%|▍         | 41/1000 [00:04<01:32, 10.36it/s]

Loss is :  0.4476616084575653
AUC is :  0.9597937766955293
Loss is :  0.4457118511199951
AUC is :  0.9610665355757693
Loss is :  0.44631290435791016
AUC is :  0.9625279378472138


  4%|▍         | 45/1000 [00:04<01:32, 10.37it/s]

Loss is :  0.4457659125328064
AUC is :  0.9637393565736672
Loss is :  0.4427776634693146
AUC is :  0.9652121012816275
Loss is :  0.43895474076271057
AUC is :  0.9660040433193029


  5%|▍         | 47/1000 [00:04<01:31, 10.41it/s]

Loss is :  0.4418790340423584
AUC is :  0.9652167147589168
Loss is :  0.43941229581832886
AUC is :  0.9662950591476902
Loss is :  0.4393623471260071
AUC is :  0.9675018057910096


  5%|▌         | 51/1000 [00:05<01:30, 10.46it/s]

Loss is :  0.435165137052536
AUC is :  0.9701500740908949
Loss is :  0.43479475378990173
AUC is :  0.970846698448696
Loss is :  0.43349018692970276
AUC is :  0.9716090300409578


  5%|▌         | 53/1000 [00:05<01:30, 10.45it/s]

Loss is :  0.4343411922454834
AUC is :  0.9716191430150788
Loss is :  0.43287378549575806
AUC is :  0.9717433445539285


  6%|▌         | 55/1000 [00:05<01:36,  9.76it/s]

Loss is :  0.4335426092147827
AUC is :  0.9728745999376001
Loss is :  0.4297180473804474
AUC is :  0.9744469575666864


  6%|▌         | 58/1000 [00:05<01:35,  9.91it/s]

Loss is :  0.429848849773407
AUC is :  0.9744882293148239
Loss is :  0.4263664782047272
AUC is :  0.9752094983755115
Loss is :  0.42545974254608154
AUC is :  0.9754002547279716


  6%|▌         | 60/1000 [00:05<01:33, 10.07it/s]

Loss is :  0.4239937663078308
AUC is :  0.976012187968872
Loss is :  0.42693278193473816
AUC is :  0.9750185025583104
Loss is :  0.4244385063648224
AUC is :  0.9759080447529455


  6%|▋         | 64/1000 [00:06<01:30, 10.31it/s]

Loss is :  0.424399197101593
AUC is :  0.9765180975654573
Loss is :  0.4243509769439697
AUC is :  0.9768105823209807
Loss is :  0.422747939825058
AUC is :  0.9775244430987814


  7%|▋         | 66/1000 [00:06<01:31, 10.15it/s]

Loss is :  0.4230428636074066
AUC is :  0.9773700078760206
Loss is :  0.42158666253089905
AUC is :  0.9782186551644005
Loss is :  0.4220104217529297
AUC is :  0.9782251156710486


  7%|▋         | 68/1000 [00:06<01:31, 10.13it/s]

Loss is :  0.4235648214817047
AUC is :  0.9778904945736382
Loss is :  0.4186919629573822
AUC is :  0.9795329082429188


  7%|▋         | 72/1000 [00:07<01:32, 10.07it/s]

Loss is :  0.4172747731208801
AUC is :  0.9796999078025458
Loss is :  0.41800209879875183
AUC is :  0.9794071911443256
Loss is :  0.4160032272338867
AUC is :  0.9800777699305041


  7%|▋         | 74/1000 [00:07<01:31, 10.17it/s]

Loss is :  0.4190983772277832
AUC is :  0.9800130438713119
Loss is :  0.4183611571788788
AUC is :  0.9802621729052795
Loss is :  0.41739609837532043
AUC is :  0.9806349779160071


  8%|▊         | 78/1000 [00:07<01:31, 10.13it/s]

Loss is :  0.415302038192749
AUC is :  0.9812420727087953
Loss is :  0.41336485743522644
AUC is :  0.9818824581419665
Loss is :  0.41527435183525085
AUC is :  0.9818755319394139


  8%|▊         | 80/1000 [00:07<01:30, 10.17it/s]

Loss is :  0.41723522543907166
AUC is :  0.9814672842564849
Loss is :  0.41459447145462036
AUC is :  0.9821027638133143
Loss is :  0.4133888781070709
AUC is :  0.9823924373788108


  8%|▊         | 84/1000 [00:08<01:29, 10.21it/s]

Loss is :  0.4134126901626587
AUC is :  0.9824627114579533
Loss is :  0.4098944664001465
AUC is :  0.9837766040790192
Loss is :  0.4126131534576416
AUC is :  0.9833203808955412


  9%|▊         | 86/1000 [00:08<01:29, 10.24it/s]

Loss is :  0.409087598323822
AUC is :  0.9839109551418717
Loss is :  0.4140607714653015
AUC is :  0.9828286072807283
Loss is :  0.410148948431015
AUC is :  0.9837839474543576


  9%|▉         | 90/1000 [00:08<01:28, 10.33it/s]

Loss is :  0.41017213463783264
AUC is :  0.984208459834574
Loss is :  0.40812456607818604
AUC is :  0.9851545472166393
Loss is :  0.4092620313167572
AUC is :  0.9848373364662514


  9%|▉         | 92/1000 [00:09<01:29, 10.20it/s]

Loss is :  0.4085206687450409
AUC is :  0.98455527157782
Loss is :  0.40869802236557007
AUC is :  0.9849881085086432
Loss is :  0.41139549016952515
AUC is :  0.9844619212898451


  9%|▉         | 93/1000 [00:09<01:30, 10.03it/s]


KeyboardInterrupt: 