## 1. Environment setup

In [1]:
import os
import dgl
import torch
import numpy as np
import pandas as pd
import networkx as nx
import torch.nn as nn
from tqdm.notebook import tqdm
import torch.nn.functional as F
from dgl.data import DGLDataset
#torch.cuda.set_device(0)  

Using backend: pytorch


## 2. Setting up the Dataset

In [2]:
class PolypharmacyDataset(DGLDataset):
    """Drug-pair dataset for polypharmacy side-effect (PSE) prediction.

    Each drug is a protein graph built from the edges table. Every node carries
    the feature row of its protein, taken from the features table and stored as
    ``ndata['PSE']``. Items are ``([drug1_graph, drug2_graph], pse_labels)``:
    ``pse_labels`` is the tensor of all combination-table columns after the two
    drug-ID columns.

    Args:
        data_dir: directory that holds the four input CSV files
            (default ``'../data'``, matching the original hard-coded paths).

    Attributes filled by ``process``:
        graphs / labels: one protein graph per drug, and that drug's ID (the
            ``label`` column of the properties table).
        comb_graphs / comb_labels: for each combination row, the two drug
            graphs and the PSE label tensor.
        dim_nfeats / gclasses: number of feature columns.
    """

    def __init__(self, data_dir='../data'):
        # DGLDataset.__init__ runs process(), so the input directory must be
        # stored before delegating to it.
        self._csv_dir = data_dir
        super().__init__(name='polypharmacy')

    def process(self):
        """Read the CSV inputs and build per-drug graphs and drug-pair examples."""
        data_dir = self._csv_dir
        edges = pd.read_csv(os.path.join(data_dir, 'GNN_edges-toy.csv'))
        properties = pd.read_csv(os.path.join(data_dir, 'GNN_properties-toy.csv'))
        drug_comb = pd.read_csv(os.path.join(data_dir, 'GNN-TWOSIDE-train-PSE-964-toy.csv'), sep=',')  # or 3347
        features = pd.read_csv(os.path.join(data_dir, 'GNN-GSE_full_pkd_norm.csv'), index_col='ProteinID', sep=',')

        self.graphs = []
        self.labels = []
        self.comb_graphs = []
        self.comb_labels = []

        num_features = len(features.columns)  # no. of PSEs
        self.dim_nfeats = num_features
        self.gclasses = num_features

        # graph_id -> drug ID (label) and graph_id -> number of nodes.
        label_dict = dict(zip(properties['graph_id'], properties['label']))
        num_nodes_dict = dict(zip(properties['graph_id'], properties['num_nodes']))

        # One graph per graph_id, in sorted graph_id order.
        for graph_id, edges_of_id in edges.groupby('graph_id'):
            g = self._build_graph(edges_of_id, int(num_nodes_dict[graph_id]), features)
            self.graphs.append(g)
            self.labels.append(label_dict[graph_id])

        # Convert a drug ID into the index of its graph in self.graphs.
        drug2graph = {drug: idx for idx, drug in enumerate(self.labels)}

        # Positional columns: 0 = drug 1, 1 = drug 2, the rest are the PSE labels.
        for row in drug_comb.to_numpy():
            g1 = self.graphs[drug2graph[row[0]]]  # Drug1 graph
            g2 = self.graphs[drug2graph[row[1]]]  # Drug2 graph
            self.comb_graphs.append([g1, g2])
            self.comb_labels.append(torch.tensor(row[2:]))  # PSE values

    @staticmethod
    def _build_graph(edges_of_id, num_nodes, features):
        """Build one drug's protein graph and attach per-node 'PSE' features.

        Node IDs are the local ``src``/``dst`` values. The protein of each node
        is read from the matching ``src_prot``/``dst_prot`` column rather than
        inferred from the order in which proteins first appear. Nodes that occur
        in no edge keep all-zero features.
        """
        src = edges_of_id['src'].to_numpy()
        dst = edges_of_id['dst'].to_numpy()
        g = dgl.graph((src, dst), num_nodes=num_nodes)

        node_prot = dict(zip(src.tolist(), edges_of_id['src_prot'].tolist()))
        node_prot.update(zip(dst.tolist(), edges_of_id['dst_prot'].tolist()))

        pse = torch.zeros(num_nodes, features.shape[1])
        if node_prot:
            nodes = list(node_prot)
            # One vectorised lookup instead of a per-node Python loop.
            rows = features.loc[[node_prot[n] for n in nodes]].to_numpy()
            pse[nodes] = torch.as_tensor(rows, dtype=torch.float32)
        g.ndata['PSE'] = pse
        return g

    def __getitem__(self, i):
        """Return ``([drug1_graph, drug2_graph], pse_label_tensor)`` for pair ``i``."""
        return self.comb_graphs[i], self.comb_labels[i]

    def __len__(self):
        """Number of drug-pair examples (rows of the combination table)."""
        return len(self.comb_graphs)
    

dataset = PolypharmacyDataset()
# Peek at the first example: ([drug1_graph, drug2_graph], pse_labels).
first_example = dataset[0]
graph, label = first_example
print(graph)


[Graph(num_nodes=722, num_edges=38344,
      ndata_schemes={'PSE': Scheme(shape=(964,), dtype=torch.float32)}
      edata_schemes={}), Graph(num_nodes=4515, num_edges=799034,
      ndata_schemes={'PSE': Scheme(shape=(964,), dtype=torch.float32)}
      edata_schemes={})]


## 3. Data loading and batching

In [123]:
from dgl.dataloading import GraphDataLoader
from torch.utils.data.sampler import SubsetRandomSampler

TRAIN_FRACTION = 0.8  # share of drug pairs used for training
BATCH_SIZE = 3
SPLIT_SEED = 0        # makes the shuffling order reproducible across runs

num_examples = len(dataset)
num_train = int(num_examples * TRAIN_FRACTION)

# The split is contiguous: the first rows train, the remaining rows test.
# NOTE(review): this assumes the combination file is not ordered in a way
# that biases the split. TODO confirm.
train_sampler = SubsetRandomSampler(
    torch.arange(num_train), generator=torch.Generator().manual_seed(SPLIT_SEED))
test_sampler = SubsetRandomSampler(
    torch.arange(num_train, num_examples), generator=torch.Generator().manual_seed(SPLIT_SEED + 1))

train_dataloader = GraphDataLoader(
    dataset, sampler=train_sampler, batch_size=BATCH_SIZE, drop_last=False)
test_dataloader = GraphDataLoader(
    dataset, sampler=test_sampler, batch_size=BATCH_SIZE, drop_last=False)

In [124]:
len(dataset)  # number of drug-pair examples (one per row of the combination table)

18

In [125]:
it = iter(train_dataloader)
batch = next(it)
batch[0][0]  # drug-1 graphs of the first mini-batch, batched into one DGLGraph

Graph(num_nodes=8973, num_edges=2343928,
      ndata_schemes={'PSE': Scheme(shape=(964,), dtype=torch.float32)}
      edata_schemes={})


In [126]:
batched_graph, labels = batch
# batched_graph[0] / batched_graph[1] hold the drug-1 / drug-2 graphs of every pair.
for what, counter in (('nodes', 'batch_num_nodes'), ('edges', 'batch_num_edges')):
    for side, bg in enumerate(batched_graph, start=1):
        print('Number of %s for each graph%d element in the batch:' % (what, side),
              getattr(bg, counter)())

# Split the batched drug-1 graph back into its member graphs.
graphs = dgl.unbatch(batched_graph[0])
print('The original graphs1 in the minibatch:')
print(graphs)

Number of nodes for each graph1 element in the batch: tensor([ 694,  694, 7585])
Number of nodes for each graph2 element in the batch: tensor([2756, 5656,  853])
Number of edges for each graph1 element in the batch: tensor([  36906,   36906, 2270116])
Number of edges for each graph2 element in the batch: tensor([ 292614, 1269352,   44876])
The original graphs1 in the minibatch:
[Graph(num_nodes=694, num_edges=36906,
      ndata_schemes={'PSE': Scheme(shape=(964,), dtype=torch.float32)}
      edata_schemes={}), Graph(num_nodes=694, num_edges=36906,
      ndata_schemes={'PSE': Scheme(shape=(964,), dtype=torch.float32)}
      edata_schemes={}), Graph(num_nodes=7585, num_edges=2270116,
      ndata_schemes={'PSE': Scheme(shape=(964,), dtype=torch.float32)}
      edata_schemes={})]


## 4. GNN Model: Siamese GCN

In [130]:
from dgl.nn import GraphConv

class GCN(nn.Module):
    def __init__(self, in_feats, h_feats, num_classes):
        super(GCN, self).__init__()
        self.conv1 = GraphConv(in_feats, h_feats)
        self.conv2 = GraphConv(h_feats,  num_classes)
        
    def forward(self, g, in_feat):
        h = self.conv1(g, in_feat)
        h = F.relu(h)
        h = self.conv2(g, h)
        g.ndata['h'] = h
        out = F.relu(dgl.mean_nodes(g, 'h'))
        return out


## 5. Training

In [129]:
%%time
# Create the model with given dimensions
model = GCN(dataset.dim_nfeats, 50, dataset.gclasses)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

for epoch in range(5):
    for batched_graph, labels in train_dataloader:
        g1 = batched_graph[0]
        g2 = batched_graph[1]
        pred1 = model(g1, g1.ndata['PSE'].float())
        pred2 = model(g2, g2.ndata['PSE'].float())
        pred = F.relu((pred1+pred2)/2)
        loss = F.binary_cross_entropy(F.sigmoid(pred).float(),labels.float())
        #loss = 1-F.cosine_similarity(pred,labels).mean()
        #loss = F.triplet_margin_loss(labels,pred1,pred2)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        print ('epoch %s | loss = %s' % (epoch,loss.tolist()))
        
    
num_correct = 0
num_correct_0 = 0
num_incorrect = 0
num_tests = 0
for batched_graph, labels in test_dataloader:
    pred1 = model(batched_graph[0], batched_graph[0].ndata['PSE'].float())
    pred2 = model(batched_graph[1], batched_graph[1].ndata['PSE'].float())
    pred = F.relu((pred1+pred2)/2)
    for i in range(len(labels)):
        for j in range(len(labels[i])):
            num_tests += 1
            if labels[i][j] == 1 and pred[i][j] != 0:
                num_correct += 1
            elif labels[i][j] == 1 and pred[i][j] == 0:
                num_incorrect += 1
            elif labels[i][j] == 0 and pred[i][j] != 0:
                num_incorrect += 1
            elif labels[i][j] == 0 and pred[i][j] == 0:
                num_correct_0 += 1
                pass

    acc = ((num_correct+num_correct_0)*100)/(num_tests)
    prec = ((num_correct)*100)/(num_incorrect+num_correct)
    sim = ((F.cosine_similarity(pred.float(),labels.float())).mean().tolist())*100
    print('Accuracy: %s | Precision: %s | Similarity: %s' %(round(acc,5),round(prec,5),round(sim,5)))

epoch 0 | loss = 0.7125621438026428
epoch 0 | loss = 0.7406646609306335
epoch 0 | loss = 0.7002500891685486
epoch 0 | loss = 0.6930876970291138
epoch 0 | loss = 0.693204402923584
epoch 1 | loss = 0.693008542060852
epoch 1 | loss = 0.6929177641868591
epoch 1 | loss = 0.6932021379470825
epoch 1 | loss = 0.6930321455001831
epoch 1 | loss = 0.6931098699569702
epoch 2 | loss = 0.6930577158927917
epoch 2 | loss = 0.692969024181366
epoch 2 | loss = 0.6927382349967957
epoch 2 | loss = 0.6930434703826904
epoch 2 | loss = 0.6930602788925171
epoch 3 | loss = 0.6926549673080444
epoch 3 | loss = 0.6930493116378784
epoch 3 | loss = 0.6928863525390625
epoch 3 | loss = 0.6928797960281372
epoch 3 | loss = 0.69314044713974
epoch 4 | loss = 0.6926441192626953
epoch 4 | loss = 0.6928839087486267
epoch 4 | loss = 0.6930856108665466
epoch 4 | loss = 0.6927308440208435
epoch 4 | loss = 0.692812979221344
Accuracy: 94.22545 | Precision: 8.24176 | Similarity: 13.54651
Accuracy: 94.16494 | Precision: 10.35857 | 

In [98]:
loss.item()  # last training loss as a Python float

0.693041980266571

In [None]:
loss  # the last training-batch loss tensor, shown with its grad_fn

In [47]:
# Compare geometric vs arithmetic mean as ways to fuse the two drug embeddings,
# using the first example of whatever batch pred1/pred2 were last computed on.
x1, x2 = pred1[0], pred2[0]
y1 = F.relu((x1 * x2).pow(0.5))
y2 = F.relu(0.5 * (x1 + x2))
print(y1.mean(), y2.mean())

tensor(0.1765, grad_fn=<MeanBackward0>) tensor(0.1765, grad_fn=<MeanBackward0>)


In [49]:
# Triplet margin loss with anchor=labels, positive=pred1, negative=pred2, using
# whatever batch labels/pred1/pred2 were last set to (the final test batch if
# this runs right after the training cell).
F.triplet_margin_loss(labels,pred1,pred2)

tensor(0.9331, grad_fn=<MeanBackward0>)

In [9]:
# Indices (and values) of the non-zero PSE labels of the first example.
first_labels = labels[0].tolist()
pos = [i for i, value in enumerate(first_labels) if value != 0]
for i in pos:
    print(i, ':', first_labels[i])

0 : 1
1 : 1
2 : 1
4 : 1
6 : 1
8 : 1
9 : 1
10 : 1
11 : 1
13 : 1
14 : 1
15 : 1
17 : 1
18 : 1
19 : 1
23 : 1
24 : 1
26 : 1
27 : 1
29 : 1
31 : 1
34 : 1
35 : 1
37 : 1
38 : 1
40 : 1
41 : 1
43 : 1
44 : 1
45 : 1
48 : 1
49 : 1
53 : 1
54 : 1
55 : 1
56 : 1
57 : 1
63 : 1
71 : 1
73 : 1
77 : 1
79 : 1
80 : 1
85 : 1
87 : 1
90 : 1
102 : 1
106 : 1
110 : 1
114 : 1
116 : 1
125 : 1
135 : 1
138 : 1
140 : 1
149 : 1
154 : 1
166 : 1
167 : 1
169 : 1
174 : 1
184 : 1
186 : 1
190 : 1
197 : 1
200 : 1
201 : 1
206 : 1
212 : 1
215 : 1
225 : 1
233 : 1
240 : 1
243 : 1
247 : 1
250 : 1
262 : 1
276 : 1
278 : 1
280 : 1
295 : 1
305 : 1
314 : 1
321 : 1
340 : 1
343 : 1
345 : 1
354 : 1
357 : 1
438 : 1
456 : 1
468 : 1
475 : 1
496 : 1
499 : 1
514 : 1
516 : 1
518 : 1
547 : 1
558 : 1
576 : 1
671 : 1
731 : 1
737 : 1
765 : 1
766 : 1
838 : 1


In [10]:
# Indices (and values) of the non-zero predictions for the first example.
first_preds = pred[0].tolist()
pos2 = [i for i, value in enumerate(first_preds) if value != 0]
for i in pos2:
    print(i, ':', first_preds[i])

6 : 0.38044387102127075
8 : 0.587792158126831
16 : 0.12098941206932068
35 : 0.0996902585029602
37 : 0.1469675749540329
40 : 0.25196316838264465
43 : 0.3600175380706787
45 : 0.22222143411636353
46 : 0.3237467110157013
48 : 0.37982177734375
49 : 0.48251789808273315
50 : 0.4170515537261963
52 : 0.4106405973434448
61 : 0.39962413907051086
62 : 0.13347773253917694
71 : 0.14868420362472534
73 : 0.4536293148994446
75 : 0.13180279731750488
78 : 0.3722996413707733
85 : 0.13340209424495697
88 : 0.06968531757593155
90 : 0.1370120346546173
94 : 0.11294928938150406
97 : 0.13625746965408325
100 : 0.3915035128593445
109 : 0.5002908706665039
113 : 0.27593111991882324
115 : 0.11313841491937637
117 : 0.28581711649894714
118 : 0.6448791027069092
119 : 0.5943384170532227
126 : 0.1319393366575241
128 : 0.5394355058670044
130 : 0.3075450360774994
150 : 0.02749987319111824
157 : 0.42085522413253784
159 : 0.47558271884918213
162 : 0.22198130190372467
164 : 0.4906436800956726
171 : 0.004296622704714537
178 : 0

In [11]:
# Number of non-zero labels vs number of non-zero predictions for the first example.
print(len(pos), len(pos2))

107 119


In [None]:
# Exact-match accuracy: share of label entries that the ReLU'd score reproduces exactly.
num_correct = 0
num_tests = 0
with torch.no_grad():
    for batched_graph, labels in test_dataloader:
        pred1 = model(batched_graph[0], batched_graph[0].ndata['PSE'].float())
        pred2 = model(batched_graph[1], batched_graph[1].ndata['PSE'].float())
        pred = F.relu((pred1 + pred2) / 2)
        diff = labels - pred
        batch_correct = (diff == 0).sum().item()
        # Divide by this batch's actual entry count; the old hard-coded 4*964 assumed
        # batches of 4 although batch_size is 3 and the last batch can be smaller.
        acc = (batch_correct * 100) / diff.numel()
        num_correct += batch_correct
        num_tests += diff.numel()
        print('Test accuracy:', acc)
print('Overall test accuracy:', (num_correct * 100) / num_tests if num_tests else float('nan'))

### Scratch: step-by-step checks of the graph construction (exploratory, order-dependent)

In [123]:
# Raw inputs for inspecting the graph construction step by step.
edges, properties = (pd.read_csv(path) for path in ('../data/GNN_edges-toy.csv',
                                                    '../data/GNN_properties-toy.csv'))
features = pd.read_csv('../data/GNN-GSE_full_pkd_norm.csv', sep=',', index_col='ProteinID')

In [130]:
# ProteinID -> feature tensor, keyed by the actual index values instead of assuming
# the IDs are exactly 1..len(features). This also avoids one .loc lookup per protein.
feature_dic = {int(pid): torch.tensor(row) for pid, row in zip(features.index, features.to_numpy())}
len(feature_dic)

19555

In [111]:
# Properties table: graph_id, the drug label of that graph, and its node count.
properties

Unnamed: 0,graph_id,label,num_nodes
0,1,85,722
1,2,119,719
2,3,137,935
3,4,143,5656
4,5,146,5376
5,6,158,2600
6,7,159,10469
7,8,160,6921
8,9,175,1158
9,10,187,6830


In [124]:
# Edge rows of graph 3 and its local source/destination node ids.
graph = edges[edges['graph_id'].eq(3)]
src, dst = (graph[col].to_numpy() for col in ('src', 'dst'))
graph

Unnamed: 0,graph_id,src,dst,src_prot,dst_prot
76538,3,0,172,32,4621
76539,3,0,50,32,1027
76540,3,0,54,32,1129
76541,3,0,828,32,18028
76542,3,0,717,32,16363
...,...,...,...,...,...
126803,3,933,128,19514,3408
126804,3,933,679,19514,15709
126805,3,933,664,19514,15442
126806,3,934,420,19530,10008


In [125]:
# Take graph 3's node count from the properties table instead of a hard-coded 935.
num_nodes_3 = int(properties.loc[properties['graph_id'] == 3, 'num_nodes'].iloc[0])
g = dgl.graph((src, dst), num_nodes=num_nodes_3)
g

Graph(num_nodes=935, num_edges=50270,
      ndata_schemes={}
      edata_schemes={})

In [137]:
# Same edge rows for graph 3, obtained via groupby as in PolypharmacyDataset.process.
edges_group = edges.groupby('graph_id')
edges_of_id = edges_group.get_group(3)
edges_of_id

Unnamed: 0,graph_id,src,dst,src_prot,dst_prot
76538,3,0,172,32,4621
76539,3,0,50,32,1027
76540,3,0,54,32,1129
76541,3,0,828,32,18028
76542,3,0,717,32,16363
...,...,...,...,...,...
126803,3,933,128,19514,3408
126804,3,933,679,19514,15709
126805,3,933,664,19514,15442
126806,3,934,420,19530,10008


In [142]:
# node id -> protein id, read from both endpoint columns. This covers nodes that only
# appear as `dst`, and it does not rely on node ids matching the order in which
# proteins first appear (the old src_prot-only mapping did).
convert_prot = dict(zip(edges_of_id['src'].tolist(), edges_of_id['src_prot'].tolist()))
convert_prot.update(zip(edges_of_id['dst'].tolist(), edges_of_id['dst_prot'].tolist()))

pse = torch.zeros(g.num_nodes(), len(features.columns))
for node, prot in convert_prot.items():
    pse[node] = feature_dic[prot]
g.ndata['PSE'] = pse

g

Graph(num_nodes=935, num_edges=50270,
      ndata_schemes={'PSE': Scheme(shape=(964,), dtype=torch.float32)}
      edata_schemes={})

In [150]:
# Feature row of protein 3291: one value per PSE column.
features.loc[3291]

Arthralgia               0.781995
Diarrhoea                0.782882
Headache                 0.784467
Vomiting                 0.783447
Dyspepsia                0.780225
                           ...   
Hypertensive crisis      0.773061
Pneumonia bacterial      0.712454
Hepatocellular injury    0.782869
Shock haemorrhagic       0.737688
Haemorrhagic stroke      0.785618
Name: 3291, Length: 964, dtype: float64

In [2]:
# Feature vector assigned to node 127 of the current `g`, presumably the graph-3
# graph built above. Verify that `g` was not reassigned in between.
g.ndata['PSE'][127]

In [3]:
# node index -> ProteinID mapping used for the feature assignment above.
convert_prot

In [None]:
from dgl.nn import GraphConv

class GCN(nn.Module):
    """Two-layer GCN that can also embed a (drug1, drug2) pair Siamese-style.

    NOTE(review): this cell redefines ``GCN`` and shadows the training model's
    class if run afterwards. ``forward(g, in_feat)`` keeps the same call shape
    as that class so earlier cells still work.

    Args:
        in_feats: width of the input node features.
        h_feats: hidden width after the first GraphConv.
        num_classes: output width (one logit per PSE).
        feat_key: ``ndata`` key used when node features are not passed explicitly.
    """

    def __init__(self, in_feats, h_feats, num_classes, feat_key='PSE'):
        super(GCN, self).__init__()
        self.conv1 = GraphConv(in_feats, h_feats)
        self.conv2 = GraphConv(h_feats, num_classes)
        self.in_feats = in_feats
        self.num_classes = num_classes
        self.feat_key = feat_key

    def forward_one(self, g, in_feat=None):
        """Embed one (batched) graph into a (num_graphs, num_classes) tensor.

        Node features default to ``g.ndata[self.feat_key]`` when ``in_feat`` is None.
        The previous version passed the integer ``self.in_feats`` here, which is
        not a feature tensor.
        """
        if in_feat is None:
            in_feat = g.ndata[self.feat_key].float()
        h = F.relu(self.conv1(g, in_feat))
        h = self.conv2(g, h)
        # local_scope keeps the temporary 'h' feature out of the caller's graph.
        with g.local_scope():
            g.ndata['h'] = h
            return dgl.mean_nodes(g, 'h')

    def forward(self, g, in_feat=None):
        """Embed a single graph, or average the embeddings of a ``(g1, g2)`` pair."""
        if isinstance(g, (list, tuple)):
            g1, g2 = g
            return (self.forward_one(g1) + self.forward_one(g2)) / 2
        return self.forward_one(g, in_feat)

In [None]:
# DGL node/edge feature demo. torch is imported as `torch` in this notebook (not `th`).
# A separate name avoids clobbering the graph-3 `g` inspected above.
demo_g = dgl.graph(([0, 0, 1, 5], [1, 2, 2, 0]))  # 6 nodes, 4 edges
print(demo_g)
demo_g.ndata['x'] = torch.ones(demo_g.num_nodes(), 3)                  # node feature of length 3
demo_g.edata['x'] = torch.ones(demo_g.num_edges(), dtype=torch.int32)  # scalar integer feature
print(demo_g)
# different names can have different shapes
demo_g.ndata['y'] = torch.randn(demo_g.num_nodes(), 5)
print(demo_g.ndata['x'][1])                     # node 1's feature
print(demo_g.edata['x'][torch.tensor([0, 3])])  # features of edges 0 and 3
demo_g.ndata['x'][0] = torch.zeros(3)
demo_g.ndata['x']