## 1. Environment setup

In [1]:
import os
import dgl
import torch
import numpy as np
import pandas as pd
import networkx as nx
import torch.nn as nn
from tqdm.notebook import tqdm
import torch.nn.functional as F
from dgl.data import DGLDataset
# Select GPU 0 as the default CUDA device, but only when CUDA is usable, so a
# fresh kernel on a CPU-only machine does not fail in the setup cell.
if torch.cuda.is_available():
    torch.cuda.set_device(0)

Using backend: pytorch


## 2. Setting up the Dataset

In [3]:
class PolypharmacyDataset(DGLDataset):
    """Dataset of drug pairs, each drug represented by a protein graph.

    Item ``i`` is ``([g1, g2], label)``: the DGL graphs of the two drugs named
    in the first two columns of row ``i`` of the drug-combination table, and a
    tensor built from the remaining columns of that row (the PSE values).

    Attributes set by ``process``:
        graphs, labels: one graph and one drug label per ``graph_id``.
        comb_graphs, comb_labels: one ``[g1, g2]`` pair and one target tensor
            per drug-combination row.
        dim_nfeats: number of feature columns stored on every node ('PSE').
        gclasses: output size handed to the model; set equal to ``dim_nfeats``.
    """

    def __init__(self):
        super().__init__(name='polypharmacy')

    def process(self):
        """Read the CSV tables and build per-drug graphs and drug-pair samples.

        Raises:
            KeyError: if an edge references a protein ID that is missing from
                the feature table, or a combination row names a drug label
                for which no graph was built.
        """
        edges = pd.read_csv('../data/GNN_edges-toy.csv')
        properties = pd.read_csv('../data/GNN_properties-toy.csv')
        drug_comb = pd.read_csv('../data/GNN-TWOSIDE-train-PSE-964-toy.csv', sep=',') # or 3347
        features = pd.read_csv('../data/GNN-GSE_full_pkd_norm.csv', index_col = 'ProteinID', sep=',')

        self.graphs = []
        self.labels = []
        self.comb_graphs = []
        self.comb_labels = []

        num_features = len(features.columns) # no. of PSEs
        self.dim_nfeats = num_features
        self.gclasses = num_features

        # graph_id -> drug label / node count, taken from the properties table.
        label_dict = dict(zip(properties['graph_id'], properties['label']))
        num_nodes_dict = dict(zip(properties['graph_id'], properties['num_nodes']))

        # One graph per graph_id, visited in sorted graph_id order.
        for graph_id, edges_of_id in edges.groupby('graph_id'):
            src = edges_of_id['src'].to_numpy()
            dst = edges_of_id['dst'].to_numpy()
            num_nodes = int(num_nodes_dict[graph_id])
            label = label_dict[graph_id]

            g = dgl.graph((src, dst), num_nodes=num_nodes)

            # Node index -> protein ID, read from the explicit
            # (src, src_prot) / (dst, dst_prot) pairs of the edge table rather
            # than from the order in which protein IDs happen to appear.
            endpoint_nodes = np.concatenate([src, dst])
            endpoint_prots = np.concatenate([
                edges_of_id['src_prot'].to_numpy(),
                edges_of_id['dst_prot'].to_numpy(),
            ])
            node_ids, first_seen = np.unique(endpoint_nodes, return_index=True)
            node_prots = endpoint_prots[first_seen]

            # Nodes that never occur in an edge have no protein ID in the edge
            # table; their feature rows stay all-zero.
            node_feats = torch.zeros(num_nodes, num_features)
            node_feats[torch.as_tensor(node_ids, dtype=torch.long)] = torch.tensor(
                features.loc[node_prots].to_numpy(), dtype=torch.float32)
            g.ndata['PSE'] = node_feats

            self.graphs.append(g)
            self.labels.append(label)

        # Drug label -> position of that drug's graph in self.graphs.
        drug2graph = {label: idx for idx, label in enumerate(self.labels)}

        # Columns 0 and 1 hold the two drug labels; all later columns are the
        # PSE values. Access is positional (.iloc) throughout.
        pse_values = drug_comb.iloc[:, 2:].to_numpy()
        for drug1, drug2, targets in zip(drug_comb.iloc[:, 0], drug_comb.iloc[:, 1], pse_values):
            g1 = self.graphs[drug2graph[drug1]] # Drug1 graph
            g2 = self.graphs[drug2graph[drug2]] # Drug2 graph
            self.comb_graphs.append([g1, g2])
            self.comb_labels.append(torch.tensor(targets)) # PSE values

    def __getitem__(self, i):
        """Return ``([g1, g2], label_tensor)`` for drug pair ``i``."""
        return self.comb_graphs[i], self.comb_labels[i]

    def __len__(self):
        """Return the number of drug-pair samples."""
        return len(self.comb_graphs)
    

# Instantiate the dataset, then fetch and print the first sample: `graph` is
# the [g1, g2] pair of drug graphs and `label` its target tensor.
dataset = PolypharmacyDataset()
graph, label = dataset[0]
print(graph)


AttributeError: 'PolypharmacyDataset' object has no attribute 'to'

In [4]:
# Repeat of the inspection at the end of the dataset cell: print the two
# drug graphs that make up sample 0.
graph, label = dataset[0]
print(graph)

[Graph(num_nodes=722, num_edges=38344,
      ndata_schemes={'PSE': Scheme(shape=(964,), dtype=torch.float32)}
      edata_schemes={}), Graph(num_nodes=4515, num_edges=799034,
      ndata_schemes={'PSE': Scheme(shape=(964,), dtype=torch.float32)}
      edata_schemes={})]


## 3. Data loading and batch

In [101]:
from dgl.dataloading import GraphDataLoader
from torch.utils.data.sampler import SubsetRandomSampler

# Split by position: the first 80% of the samples are used for training and
# the remainder for testing. Each subset is visited in random order.
num_examples = len(dataset)
num_train = int(num_examples * 0.8)

train_indices = torch.arange(num_train)
test_indices = torch.arange(num_train, num_examples)
train_sampler = SubsetRandomSampler(train_indices)
test_sampler = SubsetRandomSampler(test_indices)

# Both loaders share the same batching options.
loader_options = dict(batch_size=5, drop_last=False)
train_dataloader = GraphDataLoader(dataset, sampler=train_sampler, **loader_options)
test_dataloader = GraphDataLoader(dataset, sampler=test_sampler, **loader_options)

In [6]:
# Number of drug-pair samples in the dataset.
len(dataset)

18

In [76]:
# Pull one mini-batch from the training loader. batch[0] holds the batched
# graphs; batch[0][0] is the batched graph of the first drug of every pair.
it = iter(train_dataloader)
batch = next(it)
print(batch[0][0])

Graph(num_nodes=15641, num_edges=3666126,
      ndata_schemes={'PSE': Scheme(shape=(964,), dtype=torch.float32)}
      edata_schemes={})


In [8]:
# Split the mini-batch into its graph pair and labels, then report the size
# of every individual graph that was merged into each batched graph.
batched_graph, labels = batch
first_drugs, second_drugs = batched_graph
print('Number of nodes for each graph1 element in the batch:', first_drugs.batch_num_nodes())
print('Number of nodes for each graph2 element in the batch:', second_drugs.batch_num_nodes())
print('Number of edges for each graph1 element in the batch:', first_drugs.batch_num_edges())
print('Number of edges for each graph2 element in the batch:', second_drugs.batch_num_edges())

# Undo the batching of the first-drug graphs to get the individual graphs back.
graphs = dgl.unbatch(first_drugs)
print('The original graphs1 in the minibatch:')
print(graphs)

Number of nodes for each graph1 element in the batch: tensor([7585, 5656,  853,  722,  694])
Number of nodes for each graph2 element in the batch: tensor([ 722,  846, 2756, 4515,  846])
Number of edges for each graph1 element in the batch: tensor([2270116, 1269352,   44876,   38344,   36906])
Number of edges for each graph2 element in the batch: tensor([ 38344,  44316, 292614, 799034,  44316])
The original graphs1 in the minibatch:
[Graph(num_nodes=7585, num_edges=2270116,
      ndata_schemes={'PSE': Scheme(shape=(964,), dtype=torch.float32)}
      edata_schemes={}), Graph(num_nodes=5656, num_edges=1269352,
      ndata_schemes={'PSE': Scheme(shape=(964,), dtype=torch.float32)}
      edata_schemes={}), Graph(num_nodes=853, num_edges=44876,
      ndata_schemes={'PSE': Scheme(shape=(964,), dtype=torch.float32)}
      edata_schemes={}), Graph(num_nodes=722, num_edges=38344,
      ndata_schemes={'PSE': Scheme(shape=(964,), dtype=torch.float32)}
      edata_schemes={}), Graph(num_nodes=694, 

## 4. GNN Model: Siamese GCN

In [9]:
from dgl.nn import GraphConv

class GCN(nn.Module):
    """Two-layer graph convolutional network with a mean-over-nodes readout.

    Args:
        in_feats: size of each input node feature vector.
        h_feats: size of the hidden node representation.
        num_classes: size of the per-graph output vector.
    """

    def __init__(self, in_feats, h_feats, num_classes):
        super().__init__()
        self.conv1 = GraphConv(in_feats, h_feats)
        self.conv2 = GraphConv(h_feats, num_classes)

    def forward(self, g, in_feat):
        """Return one output row per graph in the (possibly batched) graph ``g``.

        The final node representations are also stored in ``g.ndata['h']``.
        """
        hidden = F.relu(self.conv1(g, in_feat))
        g.ndata['h'] = self.conv2(g, hidden)
        return dgl.mean_nodes(g, 'h')


## 5. Training

In [104]:
%%time
# Create the model with given dimensions
model = GCN(dataset.dim_nfeats, 450, dataset.gclasses)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

for epoch in range(5):
    print('epoch',epoch+1)
    for batched_graph, labels in train_dataloader:
        g1 = batched_graph[0]
        g2 = batched_graph[1]
        pred1 = model(g1, g1.ndata['PSE'].float())
        pred2 = model(g2, g2.ndata['PSE'].float())
        pred = F.relu((pred1+pred2)/2)
        loss = F.mse_loss(pred.float(), labels.float())
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
    
num_correct = 0
num_tests = 0
for batched_graph, labels in test_dataloader:
    pred1 = model(batched_graph[0], batched_graph[0].ndata['PSE'].float())
    pred2 = model(batched_graph[1], batched_graph[1].ndata['PSE'].float())
    pred = F.relu((pred1+pred2)/2)
    diff = labels - pred
    #acc = diff.mean()
    #num_correct += (pred.argmax(1) == labels).sum().item()
    #num_tests += len(labels)
    diff = diff.tolist()
    y = 0
    for x in diff:
        y += x.count(0)
    acc = (y*100)/(4*964)
    print('Test accuracy:', acc)

epoch 0
epoch 1
epoch 2
epoch 3
epoch 4
Test accuracy: 72.8734439834025
Wall time: 1min 10s


In [68]:
# Test accuracy (%) observed by number of training epochs (manual runs):
#  1 epoch  -> 53, 64, 60, 68, 57  (~60)
#  2 epochs -> 73
#  5 epochs -> 86
# 10 epochs -> 91

593

In [105]:
# Collect (and print) the positions of the non-zero entries in the label
# vector of the first sample of the last evaluated batch.
pos = []
first_labels = labels[0].tolist()
for i, label in enumerate(first_labels):
    if label == 0:
        continue
    pos.append(i)
    print(i,':',label)

1 : 1
2 : 1
3 : 1
5 : 1
8 : 1
17 : 1
23 : 1
26 : 1
45 : 1
48 : 1
49 : 1
54 : 1
58 : 1
61 : 1
71 : 1
79 : 1
89 : 1
103 : 1
104 : 1
106 : 1
113 : 1
116 : 1
118 : 1
126 : 1
130 : 1
131 : 1
132 : 1
142 : 1
148 : 1
154 : 1
158 : 1
177 : 1
183 : 1
212 : 1
215 : 1
246 : 1
258 : 1
265 : 1
295 : 1
321 : 1
340 : 1
343 : 1
345 : 1
366 : 1
372 : 1
391 : 1
435 : 1
461 : 1
489 : 1
508 : 1
516 : 1
518 : 1
535 : 1
536 : 1
537 : 1
611 : 1
613 : 1
642 : 1
660 : 1
763 : 1
764 : 1


In [110]:
# Collect (and print) the positions whose predicted value exceeds 0.5 for
# the first sample of the last evaluated batch.
pos2 = []
first_pred = pred[0].tolist()
for i, predic in enumerate(first_pred):
    if not predic > 0.5:
        continue
    pos2.append(i)
    print(i,':',predic)

1 : 1.1869322061538696
3 : 0.5132281184196472
5 : 0.8081663846969604
8 : 0.6965665221214294
14 : 0.5136588215827942
15 : 0.6814299821853638
19 : 0.6229294538497925
20 : 0.889102041721344
21 : 0.8402390480041504
24 : 0.774488091468811
31 : 0.5501083135604858
35 : 0.678189754486084
40 : 0.5777214169502258
54 : 0.8008038997650146
57 : 0.5739651322364807
58 : 0.6112098693847656
70 : 0.5463200807571411
90 : 0.654937744140625
99 : 0.6231285333633423
103 : 1.0847125053405762
106 : 0.6038809418678284
110 : 0.6283327341079712
116 : 0.5954090356826782
118 : 0.6291906237602234
128 : 0.7105851173400879
129 : 0.5435681343078613
131 : 0.5831461548805237
132 : 0.7943530082702637
136 : 0.6944834589958191
144 : 0.5415425896644592
154 : 0.5906703472137451
161 : 0.5148847103118896
174 : 0.5993022322654724
178 : 0.5105489492416382
183 : 0.9503158926963806
186 : 0.5618095397949219
194 : 0.5665977001190186
206 : 0.8024017810821533
212 : 0.5391796827316284
213 : 0.5940852761268616
223 : 0.6121821403503418
23

In [111]:
# Number of non-zero label entries vs. number of predictions above 0.5.
print(len(pos), len(pos2))

61 84


### ==================== Testing ====================

In [123]:
# Load the edge, property and node-feature tables at notebook scope; these are
# the same files PolypharmacyDataset.process reads.
edges = pd.read_csv('../data/GNN_edges-toy.csv')
properties = pd.read_csv('../data/GNN_properties-toy.csv')
features = pd.read_csv('../data/GNN-GSE_full_pkd_norm.csv', index_col = 'ProteinID', sep=',')

In [130]:
# Protein ID -> feature tensor. The lookup keys are generated as
# 1..len(features), so the ProteinID index must contain each of those labels.
protein_ids = range(1, len(features) + 1)
feature_dic = {pid: torch.tensor(features.loc[pid,]) for pid in protein_ids}
len(feature_dic)

19555

In [111]:
# Display the properties table (graph_id, label and num_nodes for each graph).
properties

Unnamed: 0,graph_id,label,num_nodes
0,1,85,722
1,2,119,719
2,3,137,935
3,4,143,5656
4,5,146,5376
5,6,158,2600
6,7,159,10469
7,8,160,6921
8,9,175,1158
9,10,187,6830


In [124]:
# Pick out the edge rows of graph 3 and pull its endpoint columns as arrays.
is_graph_3 = edges['graph_id'] == 3
graph = edges.loc[is_graph_3]
src, dst = (graph[col].to_numpy() for col in ('src', 'dst'))
graph

Unnamed: 0,graph_id,src,dst,src_prot,dst_prot
76538,3,0,172,32,4621
76539,3,0,50,32,1027
76540,3,0,54,32,1129
76541,3,0,828,32,18028
76542,3,0,717,32,16363
...,...,...,...,...,...
126803,3,933,128,19514,3408
126804,3,933,679,19514,15709
126805,3,933,664,19514,15442
126806,3,934,420,19530,10008


In [125]:
# Build the DGL graph from the src/dst arrays of graph 3; 935 is the num_nodes
# value listed for graph_id 3 in the properties table.
g = dgl.graph((src, dst), num_nodes=935)
g

Graph(num_nodes=935, num_edges=50270,
      ndata_schemes={}
      edata_schemes={})

In [137]:
# Select the edges of graph 3 through a groupby on graph_id, the same way
# PolypharmacyDataset.process selects the edges of each graph.
edges_group = edges.groupby('graph_id')
edges_of_id = edges_group.get_group(3)
edges_of_id

Unnamed: 0,graph_id,src,dst,src_prot,dst_prot
76538,3,0,172,32,4621
76539,3,0,50,32,1027
76540,3,0,54,32,1129
76541,3,0,828,32,18028
76542,3,0,717,32,16363
...,...,...,...,...,...
126803,3,933,128,19514,3408
126804,3,933,679,19514,15709
126805,3,933,664,19514,15442
126806,3,934,420,19530,10008


In [142]:
# Node index -> protein ID, read from the explicit (src, src_prot) and
# (dst, dst_prot) pairs so that nodes appearing only as a destination are
# covered and no assumption is made about the order in which IDs appear.
convert_prot = dict(zip(edges_of_id['src'].tolist(), edges_of_id['src_prot'].tolist()))
convert_prot.update(zip(edges_of_id['dst'].tolist(), edges_of_id['dst_prot'].tolist()))
prot_ids = [convert_prot[node] for node in sorted(convert_prot)]

# One feature row per node; the width follows the feature table instead of a
# hard-coded 964. Nodes that never occur in an edge keep an all-zero row.
node_feats = torch.zeros(g.num_nodes(), len(features.columns))
for node, prot in convert_prot.items():
    node_feats[node] = feature_dic[prot]
g.ndata['PSE'] = node_feats

g

Graph(num_nodes=935, num_edges=50270,
      ndata_schemes={'PSE': Scheme(shape=(964,), dtype=torch.float32)}
      edata_schemes={})

In [150]:
# Inspect the feature row of ProteinID 3291 (features is indexed by ProteinID).
features.loc[3291]

Arthralgia               0.781995
Diarrhoea                0.782882
Headache                 0.784467
Vomiting                 0.783447
Dyspepsia                0.780225
                           ...   
Hypertensive crisis      0.773061
Pneumonia bacterial      0.712454
Hepatocellular injury    0.782869
Shock haemorrhagic       0.737688
Haemorrhagic stroke      0.785618
Name: 3291, Length: 964, dtype: float64

In [2]:
# Inspect the 'PSE' feature vector stored on node 127 of g.
g.ndata['PSE'][127]

In [3]:
# Display the node-index -> protein-ID mapping.
convert_prot

In [None]:
from dgl.nn import GraphConv

class GCN(nn.Module):
    """Two-layer graph convolutional network with a mean-over-nodes readout.

    For drug pairs, the same module is applied to each of the two graphs and
    the outputs are combined by the caller (the training cell averages them
    and applies ReLU).

    Args:
        in_feats: size of each input node feature vector.
        h_feats: size of the hidden node representation.
        num_classes: size of the per-graph output vector.
    """

    def __init__(self, in_feats, h_feats, num_classes):
        super(GCN, self).__init__()
        self.conv1 = GraphConv(in_feats, h_feats)
        self.conv2 = GraphConv(h_feats, num_classes)
        self.in_feats = in_feats
        self.num_classes=num_classes

    def forward(self, g, in_feat=None):
        """Return one output row per graph in the (possibly batched) graph ``g``.

        Args:
            g: DGL graph or batched graph.
            in_feat: node feature tensor. When omitted, the 'PSE' node data of
                ``g`` is used, so both ``model(g)`` and ``model(g, feats)`` work.

        The final node representations are also stored in ``g.ndata['h']``.
        """
        if in_feat is None:
            # The feature tensor, not the integer feature size, must be passed
            # to the first convolution.
            in_feat = g.ndata['PSE'].float()
        h = self.conv1(g, in_feat)
        h = F.relu(h)
        h = self.conv2(g, h)
        g.ndata['h'] = h
        return dgl.mean_nodes(g, 'h')

In [None]:
# Scratch example of assigning node and edge features on a tiny DGL graph.
# torch is imported as `torch` in this notebook, so the `th` alias is not used.
# NOTE: this rebinds `g`, replacing the protein graph built in earlier cells.
g = dgl.graph(([0, 0, 1, 5], [1, 2, 2, 0])) # 6 nodes, 4 edges
g
g.ndata['x'] = torch.ones(g.num_nodes(), 3)               # node feature of length 3
g.edata['x'] = torch.ones(g.num_edges(), dtype=torch.int32)  # scalar integer feature
g
# different names can have different shapes
g.ndata['y'] = torch.randn(g.num_nodes(), 5)
g.ndata['x'][1]                  # get node 1's feature
g.edata['x'][torch.tensor([0, 3])]  # get features of edge 0 and 3
g.ndata['x'][0] = torch.zeros(1, 3)