# https://colab.research.google.com/drive/1r_FWLSFf9iL0OWeHeD31d_Opt031P1Nq?usp=sharing#scrollTo=Ogh615ka9I2c
# https://pytorch-geometric.readthedocs.io/en/latest/notes/heterogeneous.html

In [1]:
import torch
from torch import Tensor
print(torch.__version__)

2.6.0


In [2]:
# Install required packages.
import os
os.environ['TORCH'] = torch.__version__

#!pip install torch-scatter -f https://data.pyg.org/whl/torch-${TORCH}.html
#!pip install torch-sparse -f https://data.pyg.org/whl/torch-${TORCH}.html
#!pip install pyg-lib -f https://data.pyg.org/whl/nightly/torch-${TORCH}.html
#!pip install git+https://github.com/pyg-team/pytorch_geometric.git

## Heterogeneous Graph Creation

### Connect to the Graph Database

In [3]:
#Connect to the Graph Database
import pandas as pd
import numpy as np
from gqlalchemy import Memgraph
from gqlalchemy import match
from gqlalchemy.query_builders.memgraph_query_builder import Operator

# Make a connection to the database
memgraph = Memgraph(host='alzkb.ai', port=7687)

### Node features

In [4]:
def explode_cypher_node(df_nodes):
    nodes_data = []
    for i, record in df_nodes.iterrows():
        node = record["n"]
        node_data = {
            "labels": node._labels,
            "id": node._id,
            **node._properties
        }
        nodes_data.append(node_data)

    nodes_df = pd.DataFrame(nodes_data)
    return nodes_df

In [5]:
from sklearn.preprocessing import LabelEncoder
import torch

def get_node_feature(type):
    query = memgraph.execute_and_fetch(f"MATCH (n:{type}) RETURN n")
    df_nodes = pd.DataFrame(query)
    df_nodes = explode_cypher_node(df_nodes)

    #Create a mapping from unique node indices to range [0, num_user_nodes):
    unique_id = df_nodes['id'].unique()
    unique_id = pd.DataFrame(data={
        f'{type}Id': unique_id,
        'mappedID': pd.RangeIndex(len(unique_id)),
    })
    print(unique_id.head())
    
    #label encoding
    label_encoder = LabelEncoder()

    df_nodes_encoded = df_nodes.set_index('id')
    df_nodes_encoded = df_nodes_encoded.drop(columns={'labels', 'nodeID'})
    df_nodes_encoded = df_nodes_encoded.apply(label_encoder.fit_transform)
    print(df_nodes_encoded.head())

    #save feature representation of node to torch format
    feat = torch.from_numpy(df_nodes_encoded.values).to(torch.float)
    print(feat.shape)
    print('--------------------------------')

    return feat, unique_id

In [7]:
#drug_feat, unique_drug_id = get_node_feature('Drug')
#disease_feat, unique_disease_id = get_node_feature('Disease')

labels = ['BiologicalProcess', 'BodyPart', 'CellularComponent', 'Disease', 'Drug'
          , 'DrugClass', 'Gene', 'MolecularFunction', 'Pathway', 'Symptom', 'TranscriptionFactor']  

for label in labels:
    variable_name = f"{label.lower()}_feat"
    variable_name2 = f"unique_{label.lower()}_id"
    feature, unique_id = get_node_feature(label)
    globals()[variable_name] = feature
    globals()[variable_name2] = unique_id



   BiologicalProcessId  mappedID
0               209860         0
1               209861         1
2               209862         2
3               209863         3
4               209864         4
        xrefGeneOntology  commonName  sourceDatabase
id                                                  
209860              5749           0               0
209861              5751           1               0
209862              5301           2               0
209863              1114           3               0
209864              4667           4               0
torch.Size([12322, 3])
--------------------------------




   BodyPartId  mappedID
0      232358         0
1      232359         1
2      232360         2
3      232361         3
4      232362         4
        xrefUberon  commonName  sourceDatabase
id                                            
232358         395           0               0
232359          92           1               0
232360         137           2               0
232361          33           3               0
232362         215           4               0
torch.Size([652, 3])
--------------------------------




   CellularComponentId  mappedID
0               230158         0
1               230159         1
2               230160         2
3               230161         3
4               230162         4
        commonName  xrefGeneOntology
id                                  
230158           0               336
230159           1               125
230160           2              1631
230161           3               316
230162           4              1519
torch.Size([1695, 2])
--------------------------------




   DiseaseId  mappedID
0     233484         0
1     233485         1
2     233486         2
3     233487         3
4     233488         4
        xrefUmlsCUI  xrefDiseaseOntology  commonName  sourceDatabase
id                                                                  
233484           18                    9           0               0
233485           25                   13           1               0
233486           26                   14           2               0
233487           16                   17           3               0
233488           11                    6           4               0
torch.Size([34, 4])
--------------------------------
   DrugId  mappedID
0       0         0
1       1         1
2       2         2
3       3         3
4       4         4
    xrefDrugbank  xrefCasRN  commonName  sourceDatabase
id                                                     
0           3143       2892           0               0
1           1941        357           



   DrugClassId  mappedID
0       233010         0
1       233011         1
2       233012         2
3       233013         3
4       233014         4
        xrefNciThesaurus  commonName  sourceDatabase
id                                                  
233010               339           0               0
233011               128           1               0
233012               239           2               0
233013               147           3               0
233014               327           4               0
torch.Size([474, 3])
--------------------------------
   GeneId  mappedID
0   16581         0
1   16582         1
2   16583         2
3   16584         3
4   16585         4
       xrefNcbiGene  chromosome  commonName  geneSymbol  xrefEnsembl  \
id                                                                     
16581         55031           4           0           0            0   
16582         36003          10           0          13        33031   
16583         285



   MolecularFunctionId  mappedID
0               226698         0
1               226699         1
2               226700         2
3               226701         3
4               226702         4
        commonName  xrefGeneOntology
id                                  
226698           0               140
226699           1              1113
226700           2              3139
226701           3              1679
226702           4              3366
torch.Size([3460, 2])
--------------------------------




   PathwayId  mappedID
0     222182         0
1     222183         1
2     222184         2
3     222185         3
4     222186         4
        pathwayName  pathwayId  sourceDatabase
id                                            
222182         4313          1              23
222183           35          2               9
222184         4461          3               0
222185         4240          4              23
222186         3964          5              23
torch.Size([4516, 3])
--------------------------------
   SymptomId  mappedID
0     231853         0
1     231854         1
2     231855         2
3     231856         3
4     231857         4
        xrefMeSH  commonName  sourceDatabase
id                                          
231853         0           0               0
231854       328           1               0
231855       490           2               0
231856        50           3               0
231857        51           4               0
torch.Size([505, 3])
----



In [12]:
print(f'shape of drug class feature = {drugclass_feat.shape}')
print(f'number of drug class = {len(unique_drugclass_id)}')

shape of drug class feature = torch.Size([474, 3])
number of drug class = 474


### Edge Index

In [16]:
edges = memgraph.execute_and_fetch("MATCH (n)-[r]->(m) RETURN id(n) as source, id(m) as target, type(r) as type")
df_edges = pd.DataFrame(edges)
df_edges

Unnamed: 0,source,target,type
0,233591,17276,TRANSCRIPTIONFACTORINTERACTSWITHGENE
1,232420,17887,BODYPARTOVEREXPRESSESGENE
2,232932,17887,BODYPARTOVEREXPRESSESGENE
3,232371,18695,BODYPARTOVEREXPRESSESGENE
4,232411,18695,BODYPARTUNDEREXPRESSESGENE
...,...,...,...
1668482,208092,233511,GENEASSOCIATESWITHDISEASE
1668483,208260,233511,GENEASSOCIATESWITHDISEASE
1668484,208709,233511,GENEASSOCIATESWITHDISEASE
1668485,208771,233511,GENEASSOCIATESWITHDISEASE


In [17]:
df_edges['type'].unique()

array(['TRANSCRIPTIONFACTORINTERACTSWITHGENE',
       'BODYPARTOVEREXPRESSESGENE', 'BODYPARTUNDEREXPRESSESGENE',
       'CHEMICALBINDSGENE', 'CHEMICALINCREASESEXPRESSION',
       'CHEMICALDECREASESEXPRESSION', 'GENEREGULATESGENE',
       'GENEINTERACTSWITHGENE', 'GENECOVARIESWITHGENE',
       'GENEPARTICIPATESINBIOLOGICALPROCESS', 'GENEINPATHWAY',
       'GENEHASMOLECULARFUNCTION', 'GENEASSOCIATEDWITHCELLULARCOMPONENT',
       'DISEASELOCALIZESTOANATOMY', 'DRUGINCLASS',
       'GENEASSOCIATESWITHDISEASE', 'DRUGTREATSDISEASE',
       'DRUGCAUSESEFFECT', 'SYMPTOMMANIFESTATIONOFDISEASE'], dtype=object)

In [58]:
# Create a mapping from unique node indices to range [0, num_user_nodes):
"""
edge_index =df_edges[df_edges['type']=='DRUGTREATSDISEASE']
unique_drug_id = edge_index['source'].unique()
unique_drug_id = pd.DataFrame(data={
    'drugId': unique_drug_id,
    'mappedID': pd.RangeIndex(len(unique_drug_id)),
})
print("Mapping of drug IDs to consecutive values:")
print("==========================================")
print(unique_drug_id.head())
print()
"""



In [18]:
def create_edge_index(source, rel_type, target):
    df_type = df_edges[df_edges['type']==rel_type]

    # Perform merge to obtain the edges:
    source_id = pd.merge(df_type['source'], globals()[f"unique_{source.lower()}_id"],
                                left_on='source', right_on=f'{source}Id', how='left')
    #print(source_id.head())
    source_id = torch.from_numpy(source_id['mappedID'].values)

    target_id = pd.merge(df_type['target'], globals()[f"unique_{target.lower()}_id"],
                                left_on='target', right_on=f'{target}Id', how='left')
    #print(target_id.head())
    target_id = torch.from_numpy(target_id['mappedID'].values)

    #construct edge_index in COO format
    edge_index = torch.stack([source_id, target_id], dim=0)
    print(edge_index)
    print(edge_index.size())
    print('---------------------------')

    return edge_index


#globals()[f"edge_index_{rel_type}"]
edge_index_DRUGTREATSDISEASE = create_edge_index('Drug', 'DRUGTREATSDISEASE', 'Disease')
edge_index_GENEASSOCIATESWITHDISEASE = create_edge_index('Gene', 'GENEASSOCIATESWITHDISEASE', 'Disease')
edge_index_CHEMICALINCREASESEXPRESSION = create_edge_index('Drug', 'CHEMICALINCREASESEXPRESSION', 'Gene')
edge_index_CHEMICALDECREASESEXPRESSION = create_edge_index('Drug', 'CHEMICALDECREASESEXPRESSION', 'Gene')
edge_index_CHEMICALBINDSGENE = create_edge_index('Drug', 'CHEMICALBINDSGENE', 'Gene')

tensor([[ 4255,  6564,  7842,  8196, 10310, 13241, 13670, 13719, 14129],
        [   24,    24,    24,    24,    24,    24,    24,    24,    24]])
torch.Size([2, 9])
---------------------------
tensor([[149728, 166552, 166552,  ..., 192128, 192190, 183707],
        [     1,      3,      4,  ...,     27,     27,     29]])
torch.Size([2, 508])
---------------------------
tensor([[  5435,   6027,   6094,  ...,  15950,  15954,  15957],
        [149507, 149508, 149508,  ..., 193277, 193277, 193277]])
torch.Size([2, 18713])
---------------------------
tensor([[ 16022,  11878,  10312,  ...,   7661,  12376,  12797],
        [149507, 149512, 149517,  ..., 193268, 193268, 193268]])
torch.Size([2, 21051])
---------------------------
tensor([[  3905,   4551,   6020,  ...,  11580,  15568,   3749],
        [  3448,   3448,   3448,  ..., 192384, 192385, 193277]])
torch.Size([2, 25726])
---------------------------


### Initialize `HeteroData` object and pass the necessary information to it.
Note that we also pass in a `node_id` vector to each node type in order to reconstruct the original node indices from sampled subgraphs.

We also take care of adding reverse edges to the `HeteroData` object.
This allows our GNN model to use both directions of the edge for message passing:

In [21]:
#for var_name in globals():
#    print(var_name)

In [22]:
from torch_geometric.data import HeteroData
import torch_geometric.transforms as T
from torch_geometric.transforms import AddSelfLoops
from torch_geometric.transforms import ToUndirected

data = HeteroData()

# Save node indices and node features:
nodes = ['Disease', 'Drug', 'Gene']  
for node in nodes:
    data[node].node_id = torch.arange(len(globals()[f"unique_{node.lower()}_id"]))
    data[node].x = globals()[f"{node.lower()}_feat"]

# Add edge indices:
data['Drug', 'DRUGTREATSDISEASE', 'Disease'].edge_index = edge_index_DRUGTREATSDISEASE
data['Gene', 'GENEASSOCIATESWITHDISEASE', 'Disease'].edge_index = edge_index_GENEASSOCIATESWITHDISEASE
data['Drug', 'CHEMICALINCREASESEXPRESSION', 'Gene'].edge_index = edge_index_CHEMICALINCREASESEXPRESSION 
data['Drug', 'CHEMICALDECREASESEXPRESSION', 'Gene'].edge_index = edge_index_CHEMICALDECREASESEXPRESSION 
data['Drug', 'CHEMICALBINDSGENE', 'Gene'].edge_index = edge_index_CHEMICALBINDSGENE

# We may also need to make sure to add the reverse edges in order to let a GNN be able to pass messages in both directions.
# We can leverage the `T.ToUndirected()` transform for this from PyG:
transform = ToUndirected()
data = transform(data)

print(data)

assert data.node_types == ['Disease', 'Drug', 'Gene']  
#assert data.edge_types == [('Drug', 'DRUGTREATSDISEASE', 'Disease'),
#                           ('Gene', 'GENEASSOCIATESWITHDISEASE', 'Disease'),
#                           ('Drug', 'CHEMICALINCREASESEXPRESSION', 'Gene'),
#                           ('Drug', 'CHEMICALDECREASESEXPRESSION', 'Gene'),
#                           ('Drug', 'CHEMICALBINDSGENE', 'Gene')]
assert data["Disease"].num_nodes == 34
assert data["Drug"].num_features == 4
assert data["Gene"].num_nodes == 193279
assert data['Drug', 'DRUGTREATSDISEASE', 'Disease'].num_edges == 9
assert data['Gene', 'GENEASSOCIATESWITHDISEASE', 'Disease'].num_edges == 508

  Referenced from: <99947679-2853-3CFE-9677-9AD439037D88> /opt/anaconda3/lib/python3.9/site-packages/torch_scatter/_version_cpu.so
  Expected in:     <36F46DB8-DB62-3926-8653-E332C34252FB> /opt/anaconda3/lib/python3.9/site-packages/torch/lib/libtorch_cpu.dylib
  Referenced from: <99F10DE3-301A-311B-8B7F-D998A38D1857> /opt/anaconda3/lib/python3.9/site-packages/torch_sparse/_version_cpu.so
  Expected in:     <36F46DB8-DB62-3926-8653-E332C34252FB> /opt/anaconda3/lib/python3.9/site-packages/torch/lib/libtorch_cpu.dylib


HeteroData(
  Disease={
    node_id=[34],
    x=[34, 4],
  },
  Drug={
    node_id=[16581],
    x=[16581, 4],
  },
  Gene={
    node_id=[193279],
    x=[193279, 9],
  },
  (Drug, DRUGTREATSDISEASE, Disease)={ edge_index=[2, 9] },
  (Gene, GENEASSOCIATESWITHDISEASE, Disease)={ edge_index=[2, 508] },
  (Drug, CHEMICALINCREASESEXPRESSION, Gene)={ edge_index=[2, 18713] },
  (Drug, CHEMICALDECREASESEXPRESSION, Gene)={ edge_index=[2, 21051] },
  (Drug, CHEMICALBINDSGENE, Gene)={ edge_index=[2, 25726] },
  (Disease, rev_DRUGTREATSDISEASE, Drug)={ edge_index=[2, 9] },
  (Disease, rev_GENEASSOCIATESWITHDISEASE, Gene)={ edge_index=[2, 508] },
  (Gene, rev_CHEMICALINCREASESEXPRESSION, Drug)={ edge_index=[2, 18713] },
  (Gene, rev_CHEMICALDECREASESEXPRESSION, Drug)={ edge_index=[2, 21051] },
  (Gene, rev_CHEMICALBINDSGENE, Drug)={ edge_index=[2, 25726] }
)


### Save HeteroData

In [None]:
import os

path =''
torch.save(data, os.path.join(path,'hetero_graph.pt'))

## Heterogeneous Link-level GNN

We are now ready to create our heterogeneous GNN.
The GNN is responsible for learning enriched node representations from the surrounding subgraphs, which can be then used to derive edge-level predictions.
For defining our heterogenous GNN, we make use of [`nn.SAGEConv`](https://pytorch-geometric.readthedocs.io/en/latest/modules/nn.html#torch_geometric.nn.conv.SAGEConv) and the [`nn.to_hetero()`](https://pytorch-geometric.readthedocs.io/en/latest/modules/nn.html#torch_geometric.nn.to_hetero_transformer.to_hetero) function, which transforms a GNN defined on homogeneous graphs to be applied on heterogeneous ones.

In addition, we define a final link-level classifier, which simply takes both node embeddings of the link we are trying to predict, and applies a dot-product on them.

As users do not have any node-level information, we choose to learn their features jointly via a `torch.nn.Embedding` layer. In order to improve the expressiveness of movie features, we do the same for movie nodes, and simply add their shallow embeddings to the pre-defined genre features.

### Load HeteroData

In [None]:
import os
import torch

#path =''

data = torch.load(os.path.join(path,'hetero_graph.pt'))
data

HeteroData(
  Disease={
    node_id=[34],
    x=[34, 4],
  },
  Drug={
    node_id=[16581],
    x=[16581, 4],
  },
  Gene={
    node_id=[193279],
    x=[193279, 9],
  },
  (Drug, DRUGTREATSDISEASE, Disease)={ edge_index=[2, 9] },
  (Gene, GENEASSOCIATESWITHDISEASE, Disease)={ edge_index=[2, 508] },
  (Drug, CHEMICALINCREASESEXPRESSION, Gene)={ edge_index=[2, 18713] },
  (Drug, CHEMICALDECREASESEXPRESSION, Gene)={ edge_index=[2, 21051] },
  (Drug, CHEMICALBINDSGENE, Gene)={ edge_index=[2, 25726] },
  (Disease, rev_DRUGTREATSDISEASE, Drug)={ edge_index=[2, 9] },
  (Disease, rev_GENEASSOCIATESWITHDISEASE, Gene)={ edge_index=[2, 508] },
  (Gene, rev_CHEMICALINCREASESEXPRESSION, Drug)={ edge_index=[2, 18713] },
  (Gene, rev_CHEMICALDECREASESEXPRESSION, Drug)={ edge_index=[2, 21051] },
  (Gene, rev_CHEMICALBINDSGENE, Drug)={ edge_index=[2, 25726] }
)

In [19]:
data.metadata()

(['Disease', 'Drug', 'Gene'],
 [('Drug', 'DRUGTREATSDISEASE', 'Disease'),
  ('Gene', 'GENEASSOCIATESWITHDISEASE', 'Disease'),
  ('Drug', 'CHEMICALINCREASESEXPRESSION', 'Gene'),
  ('Drug', 'CHEMICALDECREASESEXPRESSION', 'Gene'),
  ('Drug', 'CHEMICALBINDSGENE', 'Gene'),
  ('Disease', 'rev_DRUGTREATSDISEASE', 'Drug'),
  ('Disease', 'rev_GENEASSOCIATESWITHDISEASE', 'Gene'),
  ('Gene', 'rev_CHEMICALINCREASESEXPRESSION', 'Drug'),
  ('Gene', 'rev_CHEMICALDECREASESEXPRESSION', 'Drug'),
  ('Gene', 'rev_CHEMICALBINDSGENE', 'Drug')])

### Edge-level Training Splits

Since our data is now ready-to-be-used, we can split the ratings of users into training, validation, and test splits.
This is needed in order to ensure that we leak no information about edges used during evaluation into the training phase.

For this, we make use of the [`transforms.RandomLinkSplit`](https://pytorch-geometric.readthedocs.io/en/latest/modules/transforms.html#torch_geometric.transforms.RandomLinkSplit) transformation from PyG.
This transforms randomly divides the edges in the `("drug", "DRUGTREATSDISEASE", "Disease")` into training, validation and test edges.
The `disjoint_train_ratio` parameter further separates edges in the training split into edges used for message passing (`edge_index`) and edges used for supervision (`edge_label_index`).


Note that we may also need to specify the reverse edge type `("movie", "rev_rates", "user")`.
This allows the `RandomLinkSplit` transform to drop reverse edges accordingly to not leak any information into the training phase.

In [23]:
transform = T.RandomLinkSplit(
    num_val=0.2,  
    num_test=0.2,  
    #disjoint_train_ratio=...,  
    neg_sampling_ratio=20,  
    #add_negative_train_samples=False,   #default true
    edge_types=('Drug', 'DRUGTREATSDISEASE', 'Disease'),
    rev_edge_types=("Disease", "rev_DRUGTREATSDISEASE", "Drug"),
)

train_data, val_data, test_data = transform(data)
print("Training data:")
print("==============")
print(train_data)
print()
print("Validation data:")
print("================")
print(val_data)
print("Testing data:")
print("==============")
print(test_data)
print()

Training data:
HeteroData(
  Disease={
    node_id=[34],
    x=[34, 4],
  },
  Drug={
    node_id=[16581],
    x=[16581, 4],
  },
  Gene={
    node_id=[193279],
    x=[193279, 9],
  },
  (Drug, DRUGTREATSDISEASE, Disease)={
    edge_index=[2, 7],
    edge_label=[147],
    edge_label_index=[2, 147],
  },
  (Gene, GENEASSOCIATESWITHDISEASE, Disease)={ edge_index=[2, 508] },
  (Drug, CHEMICALINCREASESEXPRESSION, Gene)={ edge_index=[2, 18713] },
  (Drug, CHEMICALDECREASESEXPRESSION, Gene)={ edge_index=[2, 21051] },
  (Drug, CHEMICALBINDSGENE, Gene)={ edge_index=[2, 25726] },
  (Disease, rev_DRUGTREATSDISEASE, Drug)={ edge_index=[2, 7] },
  (Disease, rev_GENEASSOCIATESWITHDISEASE, Gene)={ edge_index=[2, 508] },
  (Gene, rev_CHEMICALINCREASESEXPRESSION, Drug)={ edge_index=[2, 18713] },
  (Gene, rev_CHEMICALDECREASESEXPRESSION, Drug)={ edge_index=[2, 21051] },
  (Gene, rev_CHEMICALBINDSGENE, Drug)={ edge_index=[2, 25726] }
)

Validation data:
HeteroData(
  Disease={
    node_id=[34],
    x=[3

In [24]:
data['Drug', 'DRUGTREATSDISEASE', 'Disease']

{'edge_index': tensor([[ 4255,  6564,  7842,  8196, 10310, 13241, 13670, 13719, 14129],
        [   24,    24,    24,    24,    24,    24,    24,    24,    24]])}

In [25]:
train_data['Drug', 'DRUGTREATSDISEASE', 'Disease']

{'edge_index': tensor([[ 7842, 13719,  6564, 10310, 13670,  4255, 14129],
        [   24,    24,    24,    24,    24,    24,    24]]), 'edge_label': tensor([1., 1., 1., 1., 1., 1., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0.]), 'edge_label_index': tensor([[ 7842, 13719,  6564, 10310, 13670,  4255, 14129,  4564, 15252, 12108,
          3444,  5076, 13653,  5741,  2293,  7732, 15381, 10370, 15044,  7485,
          2

In [26]:
val_data['Drug', 'DRUGTREATSDISEASE', 'Disease']

{'edge_index': tensor([[ 7842, 13719,  6564, 10310, 13670,  4255, 14129],
        [   24,    24,    24,    24,    24,    24,    24]]), 'edge_label': tensor([1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0.]), 'edge_label_index': tensor([[13241,  7337,  1702,   163, 13215,   979,   575,  2100,  5847,  5857,
          8867, 13779,  3181,  8636,  3039, 11022,  7244,  7261,  9122,  6203,
         15284],
        [   24,     4,     6,    22,    29,    21,    12,    26,    27,    29,
             5,    16,    19,    11,    27,    25,    26,     7,    12,    14,
            21]])}

In [27]:
test_data['Drug', 'DRUGTREATSDISEASE', 'Disease']

{'edge_index': tensor([[ 7842, 13719,  6564, 10310, 13670,  4255, 14129, 13241],
        [   24,    24,    24,    24,    24,    24,    24,    24]]), 'edge_label': tensor([1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0.]), 'edge_label_index': tensor([[ 8196, 12104, 12413,  5424,  5999, 13711,  2079, 12786, 12513,  1845,
         14511, 15384,  4010,  4177, 10710, 12616, 14991,  9194, 10519, 12914,
          2537],
        [   24,    14,     4,    11,    15,     5,     1,    33,    21,    19,
             9,    10,    14,    14,     8,    33,     9,    21,    22,     9,
             6]])}

In [31]:
"""
assert train_data['Drug', 'DRUGTREATSDISEASE', 'Disease'].num_edges == 8
assert train_data['Drug', 'DRUGTREATSDISEASE', 'Disease'].edge_label_index.size(1) == 8

# No negative edges added:
assert train_data['Drug', 'DRUGTREATSDISEASE', 'Disease'].edge_label.min() == 1
assert train_data['Drug', 'DRUGTREATSDISEASE', 'Disease'].edge_label.max() == 1

assert val_data['Drug', 'DRUGTREATSDISEASE', 'Disease'].num_edges == 8
assert val_data['Drug', 'DRUGTREATSDISEASE', 'Disease'].edge_label_index.size(1) == 0

# Negative edges with ratio 2:1:
#assert val_data['Drug', 'DRUGTREATSDISEASE', 'Disease'].edge_label.long().bincount().tolist() == [20166, 10083]
"""

In [62]:
train_data

HeteroData(
  Disease={
    node_id=[34],
    x=[34, 4],
  },
  Drug={
    node_id=[16581],
    x=[16581, 4],
  },
  Gene={
    node_id=[193279],
    x=[193279, 9],
  },
  (Drug, DRUGTREATSDISEASE, Disease)={
    edge_index=[2, 7],
    edge_label=[147],
    edge_label_index=[2, 147],
  },
  (Gene, GENEASSOCIATESWITHDISEASE, Disease)={ edge_index=[2, 508] },
  (Drug, CHEMICALINCREASESEXPRESSION, Gene)={ edge_index=[2, 18713] },
  (Drug, CHEMICALDECREASESEXPRESSION, Gene)={ edge_index=[2, 21051] },
  (Drug, CHEMICALBINDSGENE, Gene)={ edge_index=[2, 25726] },
  (Disease, rev_DRUGTREATSDISEASE, Drug)={ edge_index=[2, 7] },
  (Disease, rev_GENEASSOCIATESWITHDISEASE, Gene)={ edge_index=[2, 508] },
  (Gene, rev_CHEMICALINCREASESEXPRESSION, Drug)={ edge_index=[2, 18713] },
  (Gene, rev_CHEMICALDECREASESEXPRESSION, Drug)={ edge_index=[2, 21051] },
  (Gene, rev_CHEMICALBINDSGENE, Drug)={ edge_index=[2, 25726] }
)

In [28]:
import torch
import torch.nn.functional as F
from torch_geometric.nn import HeteroConv, SAGEConv, Linear
from torch_geometric.transforms import AddSelfLoops

class HeteroGNN(torch.nn.Module):
    def __init__(self, metadata, hidden_channels):
        super().__init__()
        # Define convolutional layers for each relationship type in the HeteroData
        self.convs = HeteroConv({
            ('Drug', 'DRUGTREATSDISEASE', 'Disease'): SAGEConv((-1, -1), hidden_channels),
            ('Gene', 'GENEASSOCIATESWITHDISEASE', 'Disease'): SAGEConv((-1, -1), hidden_channels),
            ('Drug', 'CHEMICALINCREASESEXPRESSION', 'Gene'): SAGEConv((-1, -1), hidden_channels),
            ('Drug', 'CHEMICALDECREASESEXPRESSION', 'Gene'): SAGEConv((-1, -1), hidden_channels),
            ('Drug', 'CHEMICALBINDSGENE', 'Gene'): SAGEConv((-1, -1), hidden_channels),
            ('Disease', 'rev_DRUGTREATSDISEASE', 'Drug'): SAGEConv((-1, -1), hidden_channels),
            ('Disease', 'rev_GENEASSOCIATESWITHDISEASE', 'Gene'): SAGEConv((-1, -1), hidden_channels),
            ('Gene', 'rev_CHEMICALINCREASESEXPRESSION', 'Drug'): SAGEConv((-1, -1), hidden_channels),
            ('Gene', 'rev_CHEMICALDECREASESEXPRESSION', 'Drug'): SAGEConv((-1, -1), hidden_channels),
            ('Gene', 'rev_CHEMICALBINDSGENE', 'Drug'): SAGEConv((-1, -1), hidden_channels),
        }, aggr='sum')  # Aggregating results with sum
        self.conv = SAGEConv((-1, -1), hidden_channels)
        
        # Linear layer for link prediction
        self.linear = Linear(hidden_channels * 2, 1)

    def forward(self, data):
        # Apply heterogeneous convolution
        x_dict = self.convs(data.x_dict, data.edge_index_dict)
        return x_dict

    def decode(self, z_dict, edge_label_index):
        # Decode the node embeddings to predict the edge
        src, dst = z_dict['Drug'][edge_label_index[0]], z_dict['Disease'][edge_label_index[1]]
        edge_feat = torch.cat([src, dst], dim=-1)
        return self.linear(edge_feat).sigmoid()

# Instantiate the model
model = HeteroGNN(data.metadata(), hidden_channels=64)

In [32]:
# Define optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

best_val_loss = float('inf')
patience, patience_counter = 10, 0  # For early stopping

for epoch in range(100):
    model.train()
    optimizer.zero_grad()

    # Forward pass on training data
    z_dict = model(train_data)
    
    # Predict edges for training data
    train_pred = model.decode(z_dict, train_data['Drug', 'DRUGTREATSDISEASE', 'Disease'].edge_label_index)
    
    # Compute the loss on training data
    train_loss = F.binary_cross_entropy(train_pred.squeeze(), train_data['Drug', 'DRUGTREATSDISEASE', 'Disease'].edge_label)
    
    # Backpropagation
    train_loss.backward()
    optimizer.step()
    
    # Validation step
    model.eval()
    with torch.no_grad():
        z_dict = model(val_data)
        val_pred = model.decode(z_dict, val_data['Drug', 'DRUGTREATSDISEASE', 'Disease'].edge_label_index)
        val_loss = F.binary_cross_entropy(val_pred.squeeze(), val_data['Drug', 'DRUGTREATSDISEASE', 'Disease'].edge_label)
    
    print(f'Epoch {epoch+1}, Train Loss: {train_loss.item()}, Val Loss: {val_loss.item()}')

    # Early stopping check
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        patience_counter = 0
        #torch.save(model.state_dict(), os.path.join(path,'HeteroGNN_best_model.pth'))  # Save the best model
    else:
        patience_counter += 1
        if patience_counter >= patience:
            print("Early stopping triggered!")
            break

Epoch 1, Train Loss: 6.122448921203613, Val Loss: 4.761904716491699
Epoch 2, Train Loss: 6.122448921203613, Val Loss: 4.761904716491699
Epoch 3, Train Loss: 6.122448921203613, Val Loss: 4.761904716491699
Epoch 4, Train Loss: 6.122448921203613, Val Loss: 4.761904716491699
Epoch 5, Train Loss: 6.122448921203613, Val Loss: 4.761904716491699
Epoch 6, Train Loss: 6.122448921203613, Val Loss: 4.761904716491699
Epoch 7, Train Loss: 6.122448921203613, Val Loss: 4.761904716491699
Epoch 8, Train Loss: 6.122448921203613, Val Loss: 4.761904716491699
Epoch 9, Train Loss: 6.122448921203613, Val Loss: 4.761904716491699
Epoch 10, Train Loss: 6.122448921203613, Val Loss: 4.761904716491699
Epoch 11, Train Loss: 6.122448921203613, Val Loss: 4.761904716491699
Early stopping triggered!


In [33]:
train_data['Drug', 'DRUGTREATSDISEASE', 'Disease'].edge_label

tensor([1., 1., 1., 1., 1., 1., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0.])

In [34]:
train_pred.squeeze()

tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0.], grad_fn=<SqueezeBackward0>)

In [35]:
val_data['Drug', 'DRUGTREATSDISEASE', 'Disease'].edge_label

tensor([1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0.])

In [36]:
val_data['Drug', 'DRUGTREATSDISEASE', 'Disease'].edge_index

tensor([[ 7842, 13719,  6564, 10310, 13670,  4255, 14129],
        [   24,    24,    24,    24,    24,    24,    24]])

In [37]:
print(train_data['Drug', 'DRUGTREATSDISEASE', 'Disease'].edge_index) 

tensor([[ 7842, 13719,  6564, 10310, 13670,  4255, 14129],
        [   24,    24,    24,    24,    24,    24,    24]])


In [38]:
train_data.edge_index_dict

{('Drug',
  'DRUGTREATSDISEASE',
  'Disease'): tensor([[ 7842, 13719,  6564, 10310, 13670,  4255, 14129],
         [   24,    24,    24,    24,    24,    24,    24]]),
 ('Gene',
  'GENEASSOCIATESWITHDISEASE',
  'Disease'): tensor([[149728, 166552, 166552,  ..., 192128, 192190, 183707],
         [     1,      3,      4,  ...,     27,     27,     29]]),
 ('Drug',
  'CHEMICALINCREASESEXPRESSION',
  'Gene'): tensor([[  5435,   6027,   6094,  ...,  15950,  15954,  15957],
         [149507, 149508, 149508,  ..., 193277, 193277, 193277]]),
 ('Drug',
  'CHEMICALDECREASESEXPRESSION',
  'Gene'): tensor([[ 16022,  11878,  10312,  ...,   7661,  12376,  12797],
         [149507, 149512, 149517,  ..., 193268, 193268, 193268]]),
 ('Drug',
  'CHEMICALBINDSGENE',
  'Gene'): tensor([[  3905,   4551,   6020,  ...,  11580,  15568,   3749],
         [  3448,   3448,   3448,  ..., 192384, 192385, 193277]]),
 ('Disease',
  'rev_DRUGTREATSDISEASE',
  'Drug'): tensor([[   24,    24,    24,    24,    24,    24,

In [39]:
print(train_data.x_dict['Drug'].shape)

torch.Size([16581, 4])


In [41]:
from sklearn.metrics import roc_auc_score

# Load the best model saved during training
#model.load_state_dict(torch.load(os.path.join(path,'HeteroGNN_best_model.pth')))

# Evaluate on the test set
model.eval()
with torch.no_grad():
    z_dict = model(test_data)
    test_pred = model.decode(z_dict, test_data['Drug', 'DRUGTREATSDISEASE', 'Disease'].edge_label_index)
    
    # Calculate AUC and other metrics
    auc = roc_auc_score(test_data['Drug', 'DRUGTREATSDISEASE', 'Disease'].edge_label.cpu(), test_pred.cpu())
    print(f'Test AUC: {auc}')


Test AUC: 0.5


In [42]:
test_data['Drug', 'DRUGTREATSDISEASE', 'Disease'].edge_label

tensor([1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0.])

In [43]:
test_pred.squeeze()

tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.])

## Graph Attention Networks (GAT)

### Load HeteroData

In [None]:
import os
import torch

#path =''

data = torch.load(os.path.join(path,'hetero_graph.pt'))
data

HeteroData(
  Disease={
    node_id=[34],
    x=[34, 4],
  },
  Drug={
    node_id=[16581],
    x=[16581, 4],
  },
  Gene={
    node_id=[193279],
    x=[193279, 9],
  },
  (Drug, DRUGTREATSDISEASE, Disease)={ edge_index=[2, 9] },
  (Gene, GENEASSOCIATESWITHDISEASE, Disease)={ edge_index=[2, 508] },
  (Drug, CHEMICALINCREASESEXPRESSION, Gene)={ edge_index=[2, 18713] },
  (Drug, CHEMICALDECREASESEXPRESSION, Gene)={ edge_index=[2, 21051] },
  (Drug, CHEMICALBINDSGENE, Gene)={ edge_index=[2, 25726] },
  (Disease, rev_DRUGTREATSDISEASE, Drug)={ edge_index=[2, 9] },
  (Disease, rev_GENEASSOCIATESWITHDISEASE, Gene)={ edge_index=[2, 508] },
  (Gene, rev_CHEMICALINCREASESEXPRESSION, Drug)={ edge_index=[2, 18713] },
  (Gene, rev_CHEMICALDECREASESEXPRESSION, Drug)={ edge_index=[2, 21051] },
  (Gene, rev_CHEMICALBINDSGENE, Drug)={ edge_index=[2, 25726] }
)

In [19]:
data.metadata()

(['Disease', 'Drug', 'Gene'],
 [('Drug', 'DRUGTREATSDISEASE', 'Disease'),
  ('Gene', 'GENEASSOCIATESWITHDISEASE', 'Disease'),
  ('Drug', 'CHEMICALINCREASESEXPRESSION', 'Gene'),
  ('Drug', 'CHEMICALDECREASESEXPRESSION', 'Gene'),
  ('Drug', 'CHEMICALBINDSGENE', 'Gene'),
  ('Disease', 'rev_DRUGTREATSDISEASE', 'Drug'),
  ('Disease', 'rev_GENEASSOCIATESWITHDISEASE', 'Gene'),
  ('Gene', 'rev_CHEMICALINCREASESEXPRESSION', 'Drug'),
  ('Gene', 'rev_CHEMICALDECREASESEXPRESSION', 'Drug'),
  ('Gene', 'rev_CHEMICALBINDSGENE', 'Drug')])

### Edge-level Training Splits

In [14]:
transform = T.RandomLinkSplit(
    num_val=0.2,  
    num_test=0.2,  
    #disjoint_train_ratio=...,  
    neg_sampling_ratio=20,  
    #add_negative_train_samples=False,   #default true
    edge_types=('Drug', 'DRUGTREATSDISEASE', 'Disease'),
    rev_edge_types=("Disease", "rev_DRUGTREATSDISEASE", "Drug"),
)

train_data, val_data, test_data = transform(data)
print("Training data:")
print("==============")
print(train_data)
print()
print("Validation data:")
print("================")
print(val_data)
print("Testing data:")
print("==============")
print(test_data)
print()

Training data:
HeteroData(
  Disease={
    node_id=[34],
    x=[34, 4],
  },
  Drug={
    node_id=[16581],
    x=[16581, 4],
  },
  Gene={
    node_id=[193279],
    x=[193279, 9],
  },
  (Drug, DRUGTREATSDISEASE, Disease)={
    edge_index=[2, 7],
    edge_label=[147],
    edge_label_index=[2, 147],
  },
  (Gene, GENEASSOCIATESWITHDISEASE, Disease)={ edge_index=[2, 508] },
  (Drug, CHEMICALINCREASESEXPRESSION, Gene)={ edge_index=[2, 18713] },
  (Drug, CHEMICALDECREASESEXPRESSION, Gene)={ edge_index=[2, 21051] },
  (Drug, CHEMICALBINDSGENE, Gene)={ edge_index=[2, 25726] },
  (Disease, rev_DRUGTREATSDISEASE, Drug)={ edge_index=[2, 7] },
  (Disease, rev_GENEASSOCIATESWITHDISEASE, Gene)={ edge_index=[2, 508] },
  (Gene, rev_CHEMICALINCREASESEXPRESSION, Drug)={ edge_index=[2, 18713] },
  (Gene, rev_CHEMICALDECREASESEXPRESSION, Drug)={ edge_index=[2, 21051] },
  (Gene, rev_CHEMICALBINDSGENE, Drug)={ edge_index=[2, 25726] }
)

Validation data:
HeteroData(
  Disease={
    node_id=[34],
    x=[3

### GAT

In [44]:
import torch
from torch.nn import Linear
from torch_geometric.nn import GATConv, HeteroConv

class GATLinkPredictor(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels):
        super(GATLinkPredictor, self).__init__()
        self.conv1 = HeteroConv({
            ('Drug', 'DRUGTREATSDISEASE', 'Disease'): GATConv(in_channels, hidden_channels, heads=8, concat=False),
            ('Gene', 'GENEASSOCIATESWITHDISEASE', 'Disease'): GATConv(in_channels, hidden_channels, heads=8, concat=False),
            ('Drug', 'CHEMICALINCREASESEXPRESSION', 'Gene'): GATConv(in_channels, hidden_channels, heads=8, concat=False),
            ('Drug', 'CHEMICALDECREASESEXPRESSION', 'Gene'): GATConv(in_channels, hidden_channels, heads=8, concat=False),
            ('Drug', 'CHEMICALBINDSGENE', 'Gene'): GATConv(in_channels, hidden_channels, heads=8, concat=False),
            ('Disease', 'rev_DRUGTREATSDISEASE', 'Drug'): GATConv(in_channels, hidden_channels, heads=8, concat=False),
            ('Disease', 'rev_GENEASSOCIATESWITHDISEASE', 'Gene'): GATConv(in_channels, hidden_channels, heads=8, concat=False),
            ('Gene', 'rev_CHEMICALINCREASESEXPRESSION', 'Drug'): GATConv(in_channels, hidden_channels, heads=8, concat=False),
            ('Gene', 'rev_CHEMICALDECREASESEXPRESSION', 'Drug'): GATConv(in_channels, hidden_channels, heads=8, concat=False),
            ('Gene', 'rev_CHEMICALBINDSGENE', 'Drug'): GATConv(in_channels, hidden_channels, heads=8, concat=False),
        }, aggr='mean')
        self.conv2 = HeteroConv({
            ('Drug', 'DRUGTREATSDISEASE', 'Disease'): GATConv(hidden_channels, hidden_channels, heads=8, concat=False),
            ('Gene', 'GENEASSOCIATESWITHDISEASE', 'Disease'): GATConv(in_channels, hidden_channels, heads=8, concat=False),
            ('Drug', 'CHEMICALINCREASESEXPRESSION', 'Gene'): GATConv(in_channels, hidden_channels, heads=8, concat=False),
            ('Drug', 'CHEMICALDECREASESEXPRESSION', 'Gene'): GATConv(in_channels, hidden_channels, heads=8, concat=False),
            ('Drug', 'CHEMICALBINDSGENE', 'Gene'): GATConv(in_channels, hidden_channels, heads=8, concat=False),
            ('Disease', 'rev_DRUGTREATSDISEASE', 'Drug'): GATConv(in_channels, hidden_channels, heads=8, concat=False),
            ('Disease', 'rev_GENEASSOCIATESWITHDISEASE', 'Gene'): GATConv(in_channels, hidden_channels, heads=8, concat=False),
            ('Gene', 'rev_CHEMICALINCREASESEXPRESSION', 'Drug'): GATConv(in_channels, hidden_channels, heads=8, concat=False),
            ('Gene', 'rev_CHEMICALDECREASESEXPRESSION', 'Drug'): GATConv(in_channels, hidden_channels, heads=8, concat=False),
            ('Gene', 'rev_CHEMICALBINDSGENE', 'Drug'): GATConv(in_channels, hidden_channels, heads=8, concat=False),            
        }, aggr='mean')
        self.fc = Linear(hidden_channels * 2, out_channels)

    def forward(self, x_dict, edge_index_dict):
        x_dict = self.conv1(x_dict, edge_index_dict)
        x_dict = {key: torch.relu(x) for key, x in x_dict.items()}
        x_dict = self.conv2(x_dict, edge_index_dict)
        return x_dict

    def decode(self, z_dict, edge_label_index):
        src_z, dst_z = z_dict['Drug'][edge_label_index[0]], z_dict['Disease'][edge_label_index[1]]
        return torch.sigmoid(self.fc(torch.cat([src_z, dst_z], dim=1)))

In [45]:
import torch.nn.functional as F
from torch_geometric.nn import GATConv, HeteroConv, Linear

class HeteroGAT(torch.nn.Module):
    def __init__(self, metadata, hidden_channels, out_channels):
        super(HeteroGAT, self).__init__()
        self.convs = HeteroConv({
            ('Drug', 'DRUGTREATSDISEASE', 'Disease'): GATConv((-1, -1), hidden_channels, add_self_loops=False),
            ('Gene', 'GENEASSOCIATESWITHDISEASE', 'Disease'): GATConv((-1, -1), hidden_channels, add_self_loops=False),
            ('Drug', 'CHEMICALINCREASESEXPRESSION', 'Gene'): GATConv((-1, -1), hidden_channels, add_self_loops=False),
            ('Drug', 'CHEMICALDECREASESEXPRESSION', 'Gene'): GATConv((-1, -1), hidden_channels, add_self_loops=False),
            ('Drug', 'CHEMICALBINDSGENE', 'Gene'): GATConv((-1, -1), hidden_channels, add_self_loops=False),
            ('Disease', 'rev_DRUGTREATSDISEASE', 'Drug'): GATConv((-1, -1), hidden_channels, add_self_loops=False),
            ('Disease', 'rev_GENEASSOCIATESWITHDISEASE', 'Gene'): GATConv((-1, -1), hidden_channels, add_self_loops=False),
            ('Gene', 'rev_CHEMICALINCREASESEXPRESSION', 'Drug'): GATConv((-1, -1), hidden_channels, add_self_loops=False),
            ('Gene', 'rev_CHEMICALDECREASESEXPRESSION', 'Drug'): GATConv((-1, -1), hidden_channels, add_self_loops=False),
            ('Gene', 'rev_CHEMICALBINDSGENE', 'Drug'): GATConv((-1, -1), hidden_channels, add_self_loops=False),
        }, aggr='sum')
        
        self.lin = Linear(hidden_channels * 2, out_channels)

    def forward(self, x_dict, edge_index_dict):
        # Apply HeteroConv layers
        x_dict = self.convs(x_dict, edge_index_dict)
        return x_dict

    def decode(self, z_dict, edge_label_index):
        src_z, dst_z = z_dict['Drug'][edge_label_index[0]], z_dict['Disease'][edge_label_index[1]]
        return self.lin(torch.cat([src_z, dst_z], dim=-1)).squeeze(-1)


In [47]:
import torch
import torch.nn.functional as F
from torch.optim import Adam
from torch_geometric.data import Data

# Define hyperparameters
drug_feature_size = 4
in_channels = drug_feature_size  # Assuming feature sizes match
hidden_channels = 64
out_channels = 1
learning_rate = 0.01
num_epochs = 100
best_val_loss = float('inf')

# Initialize model, optimizer, and loss function
#model = GATLinkPredictor(in_channels, hidden_channels, out_channels)
model = HeteroGAT(metadata=data.metadata(), hidden_channels=64, out_channels=1)
optimizer = Adam(model.parameters(), lr=learning_rate)
criterion = torch.nn.BCEWithLogitsLoss()

# Training loop
for epoch in range(num_epochs):
    model.train()
    optimizer.zero_grad()

    z_dict = model(train_data.x_dict, train_data.edge_index_dict) #(train_data)
    
    # Forward pass and loss calculation
    train_pred = model.decode(z_dict, train_data['Drug', 'DRUGTREATSDISEASE', 'Disease'].edge_label_index)
    train_loss = criterion(train_pred.squeeze(), train_data['Drug', 'DRUGTREATSDISEASE', 'Disease'].edge_label)
    
    # Backward pass and optimization
    train_loss.backward()
    optimizer.step()

    # Validation step
    model.eval()
    with torch.no_grad():
        z_dict = model(val_data.x_dict, val_data.edge_index_dict)
        val_pred = model.decode(z_dict, val_data['Drug', 'DRUGTREATSDISEASE', 'Disease'].edge_label_index)
        val_loss = F.binary_cross_entropy_with_logits(val_pred, val_data['Drug', 'DRUGTREATSDISEASE', 'Disease'].edge_label)
    
    print(f'Epoch {epoch+1}, Train Loss: {train_loss.item()}, Val Loss: {val_loss.item()}')

    # Early stopping check
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        patience_counter = 0
        #torch.save(model.state_dict(), os.path.join(path,'GAT_best_model.pth'))  # Save the best model
    else:
        patience_counter += 1
        if patience_counter >= patience:
            print("Early stopping triggered!")
            break

Epoch 1, Train Loss: 807.485107421875, Val Loss: 1687.5162353515625
Epoch 2, Train Loss: 650.6926879882812, Val Loss: 1414.9110107421875
Epoch 3, Train Loss: 456.26153564453125, Val Loss: 486.6878356933594
Epoch 4, Train Loss: 193.6825714111328, Val Loss: 251.7686004638672
Epoch 5, Train Loss: 205.13084411621094, Val Loss: 432.7135009765625
Epoch 6, Train Loss: 302.5349426269531, Val Loss: 296.48931884765625
Epoch 7, Train Loss: 215.1543426513672, Val Loss: 0.15046530961990356
Epoch 8, Train Loss: 52.565513610839844, Val Loss: 203.27024841308594
Epoch 9, Train Loss: 128.68841552734375, Val Loss: 429.249267578125
Epoch 10, Train Loss: 193.04415893554688, Val Loss: 411.8004150390625
Epoch 11, Train Loss: 179.84066772460938, Val Loss: 233.8612060546875
Epoch 12, Train Loss: 104.82941436767578, Val Loss: 49.87337112426758
Epoch 13, Train Loss: 26.638402938842773, Val Loss: 34.20005416870117
Epoch 14, Train Loss: 78.2496337890625, Val Loss: 141.28173828125
Epoch 15, Train Loss: 127.42868804

In [49]:
from sklearn.metrics import roc_auc_score

# Load the best model saved during training
#model.load_state_dict(torch.load(os.path.join(path,'GAT_best_model.pth')))

# Evaluate on the test set
model.eval()
with torch.no_grad():
    z_dict = model(test_data.x_dict, test_data.edge_index_dict)
    test_pred = model.decode(z_dict, test_data['Drug', 'DRUGTREATSDISEASE', 'Disease'].edge_label_index)
    
    # Calculate AUC and other metrics
    auc = roc_auc_score(test_data['Drug', 'DRUGTREATSDISEASE', 'Disease'].edge_label.cpu(), test_pred.cpu())
    print(f'Test AUC: {auc}')


Test AUC: 1.0


## Relational Graph Convolutional Network (R-GCN)

### Load HeteroData

In [None]:
import os
import torch

#path =''

data = torch.load(os.path.join(path,'hetero_graph.pt'))
data

In [2]:
data.metadata()

(['Disease', 'Drug', 'Gene'],
 [('Drug', 'DRUGTREATSDISEASE', 'Disease'),
  ('Gene', 'GENEASSOCIATESWITHDISEASE', 'Disease'),
  ('Drug', 'CHEMICALINCREASESEXPRESSION', 'Gene'),
  ('Drug', 'CHEMICALDECREASESEXPRESSION', 'Gene'),
  ('Drug', 'CHEMICALBINDSGENE', 'Gene'),
  ('Disease', 'rev_DRUGTREATSDISEASE', 'Drug'),
  ('Disease', 'rev_GENEASSOCIATESWITHDISEASE', 'Gene'),
  ('Gene', 'rev_CHEMICALINCREASESEXPRESSION', 'Drug'),
  ('Gene', 'rev_CHEMICALDECREASESEXPRESSION', 'Drug'),
  ('Gene', 'rev_CHEMICALBINDSGENE', 'Drug')])

### Edge-level Training Splits

In [56]:
import torch_geometric.transforms as T

transform = T.RandomLinkSplit(
    num_val=0.2,  
    num_test=0.2,  
    #disjoint_train_ratio=...,  
    neg_sampling_ratio=20,  
    #add_negative_train_samples=False,   #default true
    edge_types=('Drug', 'DRUGTREATSDISEASE', 'Disease'),
    rev_edge_types=("Disease", "rev_DRUGTREATSDISEASE", "Drug"),
)

train_data, val_data, test_data = transform(data)
print("Training data:")
print("==============")
print(train_data)
print()
print("Validation data:")
print("================")
print(val_data)
print("Testing data:")
print("==============")
print(test_data)
print()

Training data:
HeteroData(
  Disease={
    node_id=[34],
    x=[34, 4],
  },
  Drug={
    node_id=[16581],
    x=[16581, 4],
  },
  Gene={
    node_id=[193279],
    x=[193279, 9],
  },
  (Drug, DRUGTREATSDISEASE, Disease)={
    edge_index=[2, 7],
    edge_label=[147],
    edge_label_index=[2, 147],
  },
  (Gene, GENEASSOCIATESWITHDISEASE, Disease)={ edge_index=[2, 508] },
  (Drug, CHEMICALINCREASESEXPRESSION, Gene)={ edge_index=[2, 18713] },
  (Drug, CHEMICALDECREASESEXPRESSION, Gene)={ edge_index=[2, 21051] },
  (Drug, CHEMICALBINDSGENE, Gene)={ edge_index=[2, 25726] },
  (Disease, rev_DRUGTREATSDISEASE, Drug)={ edge_index=[2, 7] },
  (Disease, rev_GENEASSOCIATESWITHDISEASE, Gene)={ edge_index=[2, 508] },
  (Gene, rev_CHEMICALINCREASESEXPRESSION, Drug)={ edge_index=[2, 18713] },
  (Gene, rev_CHEMICALDECREASESEXPRESSION, Drug)={ edge_index=[2, 21051] },
  (Gene, rev_CHEMICALBINDSGENE, Drug)={ edge_index=[2, 25726] }
)

Validation data:
HeteroData(
  Disease={
    node_id=[34],
    x=[3

In [None]:
"""
        edge_indices = []
        edge_type_indices = []
        
        for edge_type, edge_index in edge_index_dict.items():
            edge_indices.append(edge_index)
            edge_type_idx = edge_type_dict[edge_type]
            edge_type_indices.append(torch.full((edge_index.size(1),), edge_type_idx, dtype=torch.long))
        
        # Concatenate all edge indices and type indices
        edge_index = torch.cat(edge_indices, dim=1)
        edge_type_indices = torch.cat(edge_type_indices, dim=0)

        # Pass through RGCN layers
        x_dict = self.conv1(x_dict, edge_index, edge_type_indices)
        x_dict = F.relu(x_dict)
        x_dict = self.conv2(x_dict, edge_index, edge_type_indices)
        """

In [57]:
import torch.nn.functional as F
from torch_geometric.nn import RGCNConv, HeteroConv, Linear

class RGCN(torch.nn.Module):
    def __init__(self, num_nodes_dict, num_relations, hidden_channels, out_channels):
        super(RGCN, self).__init__()
        self.convs = HeteroConv({
            ('Drug', 'DRUGTREATSDISEASE', 'Disease'): RGCNConv((1, 1), hidden_channels, num_relations=num_relations),
            ('Gene', 'GENEASSOCIATESWITHDISEASE', 'Disease'): RGCNConv((1, 1), hidden_channels, num_relations=num_relations),
            ('Drug', 'CHEMICALINCREASESEXPRESSION', 'Gene'): RGCNConv((1, 1), hidden_channels, num_relations=num_relations),
            ('Drug', 'CHEMICALDECREASESEXPRESSION', 'Gene'): RGCNConv((1, 1), hidden_channels, num_relations=num_relations),
            ('Drug', 'CHEMICALBINDSGENE', 'Gene'): RGCNConv((1, 1), hidden_channels, num_relations=num_relations),
            ('Disease', 'rev_DRUGTREATSDISEASE', 'Drug'): RGCNConv((1, 1), hidden_channels, num_relations=num_relations),
            ('Disease', 'rev_GENEASSOCIATESWITHDISEASE', 'Gene'): RGCNConv((1, 1), hidden_channels, num_relations=num_relations),
            ('Gene', 'rev_CHEMICALINCREASESEXPRESSION', 'Drug'): RGCNConv((1, 1), hidden_channels, num_relations=num_relations),
            ('Gene', 'rev_CHEMICALDECREASESEXPRESSION', 'Drug'): RGCNConv((1, 1), hidden_channels, num_relations=num_relations),
            ('Gene', 'rev_CHEMICALBINDSGENE', 'Drug'): RGCNConv((1, 1), hidden_channels, num_relations=num_relations),
        }, aggr='sum')
        #self.conv1 = RGCNConv((1, 1), hidden_channels, num_relations=num_relations)
        #self.conv2 = RGCNConv(hidden_channels, hidden_channels, num_relations=num_relations)
        self.lin = Linear(hidden_channels * 2, out_channels)
        
    def forward(self, x_dict, edge_index_dict):
        #x_dict = self.conv1(x_dict, edge_index_dict)
        #x_dict = {k: F.relu(v) for k, v in x_dict.items()}
        #x_dict = self.conv2(x_dict, edge_index_dict)
        x_dict = self.convs(x_dict, edge_index_dict)
        return x_dict

    def decode(self, z_dict, edge_label_index):
        src_z, dst_z = z_dict['Drug'][edge_label_index[0]], z_dict['Disease'][edge_label_index[1]]
        return self.lin(torch.cat([src_z, dst_z], dim=-1)).squeeze(-1)


In [58]:
data.edge_types

[('Drug', 'DRUGTREATSDISEASE', 'Disease'),
 ('Gene', 'GENEASSOCIATESWITHDISEASE', 'Disease'),
 ('Drug', 'CHEMICALINCREASESEXPRESSION', 'Gene'),
 ('Drug', 'CHEMICALDECREASESEXPRESSION', 'Gene'),
 ('Drug', 'CHEMICALBINDSGENE', 'Gene'),
 ('Disease', 'rev_DRUGTREATSDISEASE', 'Drug'),
 ('Disease', 'rev_GENEASSOCIATESWITHDISEASE', 'Gene'),
 ('Gene', 'rev_CHEMICALINCREASESEXPRESSION', 'Drug'),
 ('Gene', 'rev_CHEMICALDECREASESEXPRESSION', 'Drug'),
 ('Gene', 'rev_CHEMICALBINDSGENE', 'Drug')]

In [59]:
edge_type_dict = {
    ('Drug', 'DRUGTREATSDISEASE', 'Disease'): edge_index_DRUGTREATSDISEASE,
    ('Gene', 'GENEASSOCIATESWITHDISEASE', 'Disease'): edge_index_GENEASSOCIATESWITHDISEASE,
    ('Drug', 'CHEMICALINCREASESEXPRESSION', 'Gene'): edge_index_CHEMICALINCREASESEXPRESSION,
    ('Drug', 'CHEMICALDECREASESEXPRESSION', 'Gene'): edge_index_CHEMICALDECREASESEXPRESSION,
    ('Drug', 'CHEMICALBINDSGENE', 'Gene'): edge_index_CHEMICALBINDSGENE,
    ('Disease', 'rev_DRUGTREATSDISEASE', 'Drug'): edge_index_DRUGTREATSDISEASE,
    ('Disease', 'rev_GENEASSOCIATESWITHDISEASE', 'Gene'): edge_index_GENEASSOCIATESWITHDISEASE,
    ('Gene', 'rev_CHEMICALINCREASESEXPRESSION', 'Drug'): edge_index_CHEMICALINCREASESEXPRESSION,
    ('Gene', 'rev_CHEMICALDECREASESEXPRESSION', 'Drug'): edge_index_CHEMICALDECREASESEXPRESSION,
    ('Gene', 'rev_CHEMICALBINDSGENE', 'Drug'): edge_index_CHEMICALBINDSGENE
}
edge_type_dict

{('Drug',
  'DRUGTREATSDISEASE',
  'Disease'): tensor([[ 4255,  6564,  7842,  8196, 10310, 13241, 13670, 13719, 14129],
         [   24,    24,    24,    24,    24,    24,    24,    24,    24]]),
 ('Gene',
  'GENEASSOCIATESWITHDISEASE',
  'Disease'): tensor([[149728, 166552, 166552,  ..., 192128, 192190, 183707],
         [     1,      3,      4,  ...,     27,     27,     29]]),
 ('Drug',
  'CHEMICALINCREASESEXPRESSION',
  'Gene'): tensor([[  5435,   6027,   6094,  ...,  15950,  15954,  15957],
         [149507, 149508, 149508,  ..., 193277, 193277, 193277]]),
 ('Drug',
  'CHEMICALDECREASESEXPRESSION',
  'Gene'): tensor([[ 16022,  11878,  10312,  ...,   7661,  12376,  12797],
         [149507, 149512, 149517,  ..., 193268, 193268, 193268]]),
 ('Drug',
  'CHEMICALBINDSGENE',
  'Gene'): tensor([[  3905,   4551,   6020,  ...,  11580,  15568,   3749],
         [  3448,   3448,   3448,  ..., 192384, 192385, 193277]]),
 ('Disease',
  'rev_DRUGTREATSDISEASE',
  'Drug'): tensor([[ 4255,  6564,

In [60]:
from torch.optim import Adam

model = RGCN(num_nodes_dict=data.num_nodes_dict, 
             num_relations=len(data.edge_types), 
             hidden_channels=64, 
             out_channels=1)

optimizer = Adam(model.parameters(), lr=0.01)

for epoch in range(100):
    model.train()
    optimizer.zero_grad()
    
    # Forward pass
    z_dict = model(train_data.x_dict, train_data.edge_index_dict) #(train_data.edge_index_dict, train_data.edge_types)
    train_pred = model.decode(z_dict, train_data['Drug', 'DRUGTREATSDISEASE', 'Disease'].edge_label_index)
    train_loss = F.binary_cross_entropy_with_logits(train_pred, train_data['Drug', 'DRUGTREATSDISEASE', 'Disease'].edge_label.float())
    
    train_loss.backward()
    optimizer.step()
    
    # Validation step
    model.eval()
    with torch.no_grad():
        z_dict = model(val_data.x_dict, val_data.edge_index_dict)
        val_pred = model.decode(z_dict, val_data['Drug', 'DRUGTREATSDISEASE', 'Disease'].edge_label_index)
        val_loss = F.binary_cross_entropy_with_logits(val_pred, val_data['Drug', 'DRUGTREATSDISEASE', 'Disease'].edge_label)
    
    print(f'Epoch {epoch+1}, Train Loss: {train_loss.item()}, Val Loss: {val_loss.item()}')

    # Early stopping check
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        patience_counter = 0
        #torch.save(model.state_dict(), os.path.join(path,'GAT_best_model.pth'))  # Save the best model
    else:
        patience_counter += 1
        if patience_counter >= patience:
            print("Early stopping triggered!")
            break

AssertionError: 