In [1]:
import os
import dgl
import torch
import numpy as np
import pandas as pd
import networkx as nx
import torch.nn as nn
from tqdm.notebook import tqdm
import torch.nn.functional as F
from dgl.data import DGLDataset

Using backend: pytorch


### Setting up the Dataset

In [2]:
class PolypharmacyDataset(DGLDataset):
    """Graph-classification dataset assembled from the toy polypharmacy CSVs.

    Every ``graph_id`` found in the edges table becomes one ``dgl`` graph.
    Each node carries a ``'PSE'`` feature vector looked up by the protein ID
    attached to that node in the edges table; the graph label and node count
    come from the properties table.

    Attributes set by :meth:`process`:
        graphs: list of the constructed graphs, in ascending ``graph_id`` order.
        labels: ``torch.long`` tensor holding one label per graph.
        dim_nfeats: width of the node feature vectors (964).
        gclasses: number of output classes (964, one per PSE).
    """

    def __init__(self):
        super().__init__(name='polypharmacy')

    def process(self):
        """Read the CSV files and build ``self.graphs`` / ``self.labels``.

        Raises:
            ValueError: if the feature table does not have 964 columns.
            KeyError: if a graph ID is missing from the properties table, or
                a protein ID used in the edges table has no feature row.
        """
        edges = pd.read_csv('../data/GNN_edges-toy.csv')
        properties = pd.read_csv('../data/GNN_properties-toy.csv')
        features = pd.read_csv('../data/GNN-GSE_full_pkd_norm.csv', index_col='ProteinID', sep=',')
        self.graphs = []
        self.labels = []

        num_features = 964  # no. of PSEs
        self.dim_nfeats = num_features
        self.gclasses = num_features

        if features.shape[1] != num_features:
            raise ValueError(
                f'expected {num_features} feature columns, found {features.shape[1]}')

        # Per-graph label and node count, keyed by graph ID.
        label_dict = dict(zip(properties['graph_id'], properties['label']))
        num_nodes_dict = dict(zip(properties['graph_id'], properties['num_nodes']))

        # One float32 matrix with a row per protein; rows are addressed through
        # the ProteinID index, so the IDs need not be the contiguous range 1..N.
        # NOTE(review): Index.get_indexer below requires unique ProteinIDs.
        feature_matrix = torch.tensor(features.to_numpy(), dtype=torch.float32)

        # groupby iterates the graph IDs in ascending order.
        for graph_id, edges_of_id in edges.groupby('graph_id'):
            src = edges_of_id['src'].to_numpy()
            dst = edges_of_id['dst'].to_numpy()
            num_nodes = int(num_nodes_dict[graph_id])
            label = int(label_dict[graph_id])

            g = dgl.graph((src, dst), num_nodes=num_nodes)

            # Resolve the protein ID of every node from the explicit
            # (node, protein) pairs on both edge endpoints. This does not rely
            # on nodes being numbered in first-appearance order of src_prot,
            # and it also covers nodes that only ever occur as a destination.
            node_prot = np.zeros(num_nodes, dtype=np.int64)
            has_prot = np.zeros(num_nodes, dtype=bool)
            for node_ids, prot_col in ((src, 'src_prot'), (dst, 'dst_prot')):
                node_prot[node_ids] = edges_of_id[prot_col].to_numpy()
                has_prot[node_ids] = True

            rows = features.index.get_indexer(node_prot[has_prot])
            if (rows < 0).any():
                unknown = np.unique(node_prot[has_prot][rows < 0]).tolist()
                raise KeyError(
                    f'graph {graph_id}: no feature row for protein IDs {unknown}')

            # Nodes that touch no edge have no known protein and keep the
            # all-zero feature vector.
            node_feats = torch.zeros(num_nodes, num_features)
            node_idx = torch.as_tensor(np.flatnonzero(has_prot), dtype=torch.long)
            node_feats[node_idx] = feature_matrix[torch.as_tensor(rows, dtype=torch.long)]
            g.ndata['PSE'] = node_feats

            self.graphs.append(g)
            self.labels.append(label)

        # Convert the label list to tensor for saving.
        self.labels = torch.LongTensor(self.labels)

    def __getitem__(self, i):
        """Return the ``(graph, label)`` pair stored at position ``i``."""
        return self.graphs[i], self.labels[i]

    def __len__(self):
        """Return the number of graphs in the dataset."""
        return len(self.graphs)
    

# Build the dataset (this reads the CSV files and constructs every graph),
# then display the first (graph, label) pair as a quick sanity check.
dataset = PolypharmacyDataset()
graph, label = dataset[0]
print(graph, label)


Graph(num_nodes=722, num_edges=38344,
      ndata_schemes={'PSE': Scheme(shape=(964,), dtype=torch.float32)}
      edata_schemes={}) tensor(85)


In [3]:
from dgl.dataloading import GraphDataLoader
from torch.utils.data.sampler import SubsetRandomSampler

# Positional 80/20 split: the leading 80% of dataset indices are used for
# training and the remaining indices for testing. Each sampler draws from its
# own index subset in random order.
num_examples = len(dataset)
num_train = int(0.8 * num_examples)

train_sampler = SubsetRandomSampler(torch.arange(0, num_train))
test_sampler = SubsetRandomSampler(torch.arange(num_examples)[num_train:])

# Both loaders batch five graphs at a time and keep the last, possibly
# smaller, batch.
train_dataloader = GraphDataLoader(
    dataset, batch_size=5, sampler=train_sampler, drop_last=False)
test_dataloader = GraphDataLoader(
    dataset, batch_size=5, sampler=test_sampler, drop_last=False)

In [4]:
# Pull a single mini-batch from the training loader to inspect its structure:
# the loader yields a batched graph together with a tensor of labels.
it = iter(train_dataloader)
batch = next(it)
print(batch)

[Graph(num_nodes=21272, num_edges=4610760,
      ndata_schemes={'PSE': Scheme(shape=(964,), dtype=torch.float32)}
      edata_schemes={}), tensor([158, 119, 160, 143, 146])]


In [5]:
# Split the mini-batch into its batched graph and label tensor, and recover
# the original graph elements from the batched graph.
batched_graph, labels = batch
graphs = dgl.unbatch(batched_graph)

# Per-element sizes recorded inside the batched graph.
print('Number of nodes for each graph element in the batch:', batched_graph.batch_num_nodes())
print('Number of edges for each graph element in the batch:', batched_graph.batch_num_edges())

# Heading and graph list, one per line.
print('The original graphs in the minibatch:', graphs, sep='\n')

Number of nodes for each graph element in the batch: tensor([2600,  719, 6921, 5656, 5376])
Number of edges for each graph element in the batch: tensor([ 268040,   38194, 1876090, 1269352, 1159084])
The original graphs in the minibatch:
[Graph(num_nodes=2600, num_edges=268040,
      ndata_schemes={'PSE': Scheme(shape=(964,), dtype=torch.float32)}
      edata_schemes={}), Graph(num_nodes=719, num_edges=38194,
      ndata_schemes={'PSE': Scheme(shape=(964,), dtype=torch.float32)}
      edata_schemes={}), Graph(num_nodes=6921, num_edges=1876090,
      ndata_schemes={'PSE': Scheme(shape=(964,), dtype=torch.float32)}
      edata_schemes={}), Graph(num_nodes=5656, num_edges=1269352,
      ndata_schemes={'PSE': Scheme(shape=(964,), dtype=torch.float32)}
      edata_schemes={}), Graph(num_nodes=5376, num_edges=1159084,
      ndata_schemes={'PSE': Scheme(shape=(964,), dtype=torch.float32)}
      edata_schemes={})]


### GNN Model: GCN

In [6]:
from dgl.nn import GraphConv

class GCN(nn.Module):
    """Two-layer graph convolutional network with a mean-over-nodes readout.

    Args:
        in_feats: size of each input node feature vector.
        h_feats: size of the hidden node representation.
        num_classes: size of the per-graph output vector.
    """

    def __init__(self, in_feats, h_feats, num_classes):
        super().__init__()
        self.conv1 = GraphConv(in_feats, h_feats)
        self.conv2 = GraphConv(h_feats, num_classes)

    def forward(self, g, in_feat):
        """Return one output row per graph contained in ``g``.

        Args:
            g: a (possibly batched) DGL graph.
            in_feat: node feature tensor passed to the first convolution.

        Returns:
            The node outputs of the second convolution averaged per graph
            with ``dgl.mean_nodes``.
        """
        h = F.relu(self.conv1(g, in_feat))
        h = self.conv2(g, h)
        # Store the temporary 'h' field inside a local scope so the readout
        # does not leave an extra node feature behind on the caller's graph.
        with g.local_scope():
            g.ndata['h'] = h
            return dgl.mean_nodes(g, 'h')

In [None]:
# Create the model with given dimensions
model = GCN(dataset.dim_nfeats, 16, dataset.gclasses)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

# Training: 20 passes over the training loader, minimising cross-entropy
# between the per-graph outputs and the graph labels.
model.train()
for epoch in range(20):
    for batched_graph, labels in train_dataloader:
        pred = model(batched_graph, batched_graph.ndata['PSE'].float())
        loss = F.cross_entropy(pred, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

# Evaluation: switch to inference mode and disable autograd so that no
# computation graph is recorded for the test batches.
model.eval()
num_correct = 0
num_tests = 0
with torch.no_grad():
    for batched_graph, labels in test_dataloader:
        pred = model(batched_graph, batched_graph.ndata['PSE'].float())
        num_correct += (pred.argmax(1) == labels).sum().item()
        num_tests += len(labels)

# Guard the division: an empty test split would otherwise raise
# ZeroDivisionError.
if num_tests:
    print('Test accuracy:', num_correct / num_tests)
else:
    print('Test accuracy: n/a (empty test split)')

### Testing

In [123]:
# Load the raw tables at notebook level so the individual steps of
# PolypharmacyDataset.process can be reproduced and inspected by hand.
# The feature table is indexed by its 'ProteinID' column.
edges = pd.read_csv('../data/GNN_edges-toy.csv')
properties = pd.read_csv('../data/GNN_properties-toy.csv')
features = pd.read_csv('../data/GNN-GSE_full_pkd_norm.csv', index_col = 'ProteinID', sep=',')

In [130]:
# Map each protein ID in 1..len(features) to its feature row as a tensor,
# then show how many entries were built.
feature_dic = {
    prot_id: torch.tensor(features.loc[prot_id,])
    for prot_id in range(1, len(features) + 1)
}
len(feature_dic)

19555

In [111]:
# Display the per-graph properties table (graph_id, label, num_nodes).
properties

Unnamed: 0,graph_id,label,num_nodes
0,1,85,722
1,2,119,719
2,3,137,935
3,4,143,5656
4,5,146,5376
5,6,158,2600
6,7,159,10469
7,8,160,6921
8,9,175,1158
9,10,187,6830


In [124]:
# Keep only the edge rows that belong to graph 3 and pull out the source and
# destination node columns as NumPy arrays; the filtered table is displayed.
graph = edges[edges['graph_id'] == 3]
src, dst = (graph[endpoint].to_numpy() for endpoint in ('src', 'dst'))
graph

Unnamed: 0,graph_id,src,dst,src_prot,dst_prot
76538,3,0,172,32,4621
76539,3,0,50,32,1027
76540,3,0,54,32,1129
76541,3,0,828,32,18028
76542,3,0,717,32,16363
...,...,...,...,...,...
126803,3,933,128,19514,3408
126804,3,933,679,19514,15709
126805,3,933,664,19514,15442
126806,3,934,420,19530,10008


In [125]:
# Build the DGL graph for graph 3 from the src/dst arrays; 935 is the
# num_nodes value listed for graph_id 3 in the properties table.
g = dgl.graph((src, dst), num_nodes=935)
g

Graph(num_nodes=935, num_edges=50270,
      ndata_schemes={}
      edata_schemes={})

In [137]:
# Same selection as the boolean filter on graph_id == 3, done through groupby
# to mirror the lookup used inside PolypharmacyDataset.process.
edges_group = edges.groupby('graph_id')
edges_of_id = edges_group.get_group(3)
edges_of_id

Unnamed: 0,graph_id,src,dst,src_prot,dst_prot
76538,3,0,172,32,4621
76539,3,0,50,32,1027
76540,3,0,54,32,1129
76541,3,0,828,32,18028
76542,3,0,717,32,16363
...,...,...,...,...,...
126803,3,933,128,19514,3408
126804,3,933,679,19514,15709
126805,3,933,664,19514,15442
126806,3,934,420,19530,10008


In [142]:
# Distinct source protein IDs in order of first appearance.
prot_ids = edges_of_id['src_prot'].unique().tolist()

# Position -> protein ID. Because prot_ids holds unique values, enumerating
# it yields exactly the mapping that prot_ids.index(prot) produced, without
# the quadratic list scan.
# NOTE(review): this assumes node i corresponds to the i-th distinct src_prot
# and that every node appears as a source at least once - confirm against the
# (src, src_prot) pairs in the edges table.
convert_prot = dict(enumerate(prot_ids))

# Start from all-zero features, then copy each node's protein feature row.
g.ndata['PSE'] = torch.zeros(g.num_nodes(), 964)
for node in g.nodes().tolist():
    g.ndata['PSE'][node] = feature_dic[convert_prot[node]]

g

Graph(num_nodes=935, num_edges=50270,
      ndata_schemes={'PSE': Scheme(shape=(964,), dtype=torch.float32)}
      edata_schemes={})

In [150]:
features.loc[3291]

Arthralgia               0.781995
Diarrhoea                0.782882
Headache                 0.784467
Vomiting                 0.783447
Dyspepsia                0.780225
                           ...   
Hypertensive crisis      0.773061
Pneumonia bacterial      0.712454
Hepatocellular injury    0.782869
Shock haemorrhagic       0.737688
Haemorrhagic stroke      0.785618
Name: 3291, Length: 964, dtype: float64

In [149]:
g.ndata['PSE'][127]

tensor([0.7820, 0.7829, 0.7845, 0.7834, 0.7802, 0.7793, 0.7811, 0.7759, 0.7768,
        0.7833, 0.8071, 0.7828, 0.7835, 0.7612, 0.7716, 0.7817, 0.7594, 0.7814,
        0.7882, 0.7843, 0.7842, 0.8261, 0.7802, 0.7832, 0.7789, 0.8044, 0.7800,
        0.7852, 0.7725, 0.7356, 0.7820, 0.7800, 0.7807, 0.7702, 0.7781, 0.7819,
        0.7551, 0.7790, 0.7845, 0.7783, 0.7888, 0.7822, 0.8066, 0.7891, 0.7784,
        0.7758, 0.7890, 0.7904, 0.7863, 0.7741, 0.7836, 0.7685, 0.7760, 0.7046,
        0.7756, 0.7830, 0.7710, 0.7940, 0.7789, 0.7524, 0.7860, 0.7832, 0.7861,
        0.8036, 0.7798, 0.7782, 0.6946, 0.7210, 0.7794, 0.7820, 0.7886, 0.7735,
        0.8290, 0.7740, 0.7896, 0.7720, 0.7496, 0.7786, 0.7837, 0.7852, 0.7835,
        0.7766, 0.6761, 0.7887, 0.7695, 0.7522, 0.7793, 0.7869, 0.7750, 0.7919,
        0.7936, 0.7718, 0.7461, 0.7875, 0.7758, 0.7751, 0.7904, 0.7810, 0.7727,
        0.7553, 0.8057, 0.7663, 0.7780, 0.7712, 0.7804, 0.7836, 0.7762, 0.7887,
        0.7848, 0.7870, 0.7761, 0.7808, 

In [148]:
convert_prot

{0: 32,
 1: 90,
 2: 95,
 3: 100,
 4: 119,
 5: 153,
 6: 172,
 7: 182,
 8: 207,
 9: 210,
 10: 221,
 11: 225,
 12: 268,
 13: 302,
 14: 310,
 15: 315,
 16: 322,
 17: 332,
 18: 336,
 19: 355,
 20: 368,
 21: 405,
 22: 421,
 23: 525,
 24: 550,
 25: 615,
 26: 627,
 27: 638,
 28: 650,
 29: 663,
 30: 668,
 31: 673,
 32: 744,
 33: 766,
 34: 769,
 35: 803,
 36: 819,
 37: 834,
 38: 848,
 39: 850,
 40: 857,
 41: 887,
 42: 901,
 43: 906,
 44: 916,
 45: 929,
 46: 931,
 47: 947,
 48: 954,
 49: 1026,
 50: 1027,
 51: 1052,
 52: 1054,
 53: 1096,
 54: 1129,
 55: 1136,
 56: 1180,
 57: 1217,
 58: 1256,
 59: 1258,
 60: 1262,
 61: 1265,
 62: 1300,
 63: 1330,
 64: 1336,
 65: 1340,
 66: 1341,
 67: 1350,
 68: 1355,
 69: 1356,
 70: 1359,
 71: 1364,
 72: 1368,
 73: 1390,
 74: 1414,
 75: 1417,
 76: 1427,
 77: 1454,
 78: 1477,
 79: 1478,
 80: 1513,
 81: 1536,
 82: 1545,
 83: 1571,
 84: 1605,
 85: 1629,
 86: 1685,
 87: 1695,
 88: 1708,
 89: 1738,
 90: 1756,
 91: 1842,
 92: 1868,
 93: 1936,
 94: 1969,
 95: 1995,
 96: 2

In [None]:
# Scratch example of assigning and reading node/edge features on a tiny graph.
# The notebook imports PyTorch as `torch`; the `th` alias used previously is
# never defined here and raised NameError.
# Note: this rebinds the notebook-level name `g`.
g = dgl.graph(([0, 0, 1, 5], [1, 2, 2, 0])) # 6 nodes, 4 edges
g
g.ndata['x'] = torch.ones(g.num_nodes(), 3)               # node feature of length 3
g.edata['x'] = torch.ones(g.num_edges(), dtype=torch.int32)  # scalar integer feature
g
# different names can have different shapes
g.ndata['y'] = torch.randn(g.num_nodes(), 5)
g.ndata['x'][1]                  # get node 1's feature
g.edata['x'][torch.tensor([0, 3])]  # get features of edge 0 and 3
g.ndata['x'][0] = torch.zeros(1, 3)