In [1]:
import networkx as nx
import torch
from torch_geometric.utils.convert import from_networkx

In [2]:
#Part 1

#Task 1.1
def gen_cycle_pairs(ns):
    graphs = []
    
    #generate pairs between 3 and n-3 such that their sum equals n
    def gen_pairs(n):
        res = []
        for i in range(3, n-2):
            res.append((i, n-i))
        return res
    
    #iterate over all ns 
    for n in ns:
        
        #generate pairs for given n and iterate over all pairs
        cycles = gen_pairs(n)
        for cycle in cycles:
            #for each pair generate the corresponing graphs
            c1 = nx.cycle_graph(cycle[0])
            c2 = nx.cycle_graph(cycle[1])
            graphs.append(nx.disjoint_union(c1, c2))
            
            #also append a single cycle graph for balance
            graphs.append(nx.cycle_graph(n))  
    
    return graphs
    
nx_graphs = gen_cycle_pairs(range(6, 16))

#Sanity check
for i in range(0, 110,2):  
    assert(nx.algorithms.graph_hashing.weisfeiler_lehman_graph_hash(nx_graphs[i]) 
        == nx.algorithms.graph_hashing.weisfeiler_lehman_graph_hash(nx_graphs[i+1]))

    
#Task 1.2
py_graphs = []
for i, graph in enumerate(nx_graphs):

    py_graph = from_networkx(graph)
    
    #x features
    py_graph.x = torch.zeros(py_graph.num_nodes,50)
    
    #for all even indices the graph is not a simple cycle
    if (i%2 == 0):
        py_graph.y = torch.tensor([0])
        
    #for all odd indices the graph is a simple cycle
    else:
        py_graph.y = torch.tensor([1])
        
    py_graphs.append(py_graph)
    


In [3]:
#Part 2
import torch.nn as nn
import torch_geometric.nn as tg_nn
import torch.nn.functional as F
from torch_geometric.data import DataLoader

#Task 2.1

input_dim = 50
output_dim = 1

class Network(nn.Module):
    
    def __init__(self):
        super().__init__()
        
        #ModuleList of 16 Message-passing layers
        self.convs = nn.ModuleList()
        self.convs.append(tg_nn.GCNConv(input_dim, 50))
        
        self.num_MP_layers = 16
        
        for i in range(15):
            self.convs.append(tg_nn.GCNConv(50, 50))
        
        #MLPs for post processing
        self.MLPs = nn.Sequential(
            nn.Linear(50, 50), nn.Dropout(0.25),
            nn.Linear(50,50), nn.Dropout(0.1),
            nn.Linear(50, output_dim)
            )
        
    def forward(self, data):
        x, edge_index, batch = data.x, data.edge_index, data.batch
        
        #forwards throug MPNNs
        for i in range(self.num_MP_layers):
            x = self.convs[i](x, edge_index)
            x = F.relu(x)
            x = F.dropout(x, 0.25)
        
        #global mean pool
        x = tg_nn.global_mean_pool(x, batch)
        
        #MLPs
        x = self.MLPs(x)     
        return x
    
    #reset parameters of every parameterised layer
    def reset_parameters(self):
        for layers in self.children():
            for layer in layers:
                if hasattr(layer, 'reset_parameters'):
                    layer.reset_parameters()
    
    #one output feature -> BCE loss
    def loss(self, pred, label):       
        return F.binary_cross_entropy(F.sigmoid(pred), label.float().unsqueeze(1))
    
#Task 2.2

def train(net, training_loader, epochs):
    
    #Adam with 1.e-4 learning rate
    opt = torch.optim.Adam(net.parameters(), lr = 1.e-4)

    for epoch in range(1,epochs+1):
        total_loss = 0
        net.train()
        for batch in training_loader:
            opt.zero_grad()
            
            prediction = net(batch)          
            label = batch.y
            
            loss = net.loss(prediction, label)
            loss.backward()
            opt.step()
            total_loss += loss.item() * batch.num_graphs
            
        total_loss /= len(training_loader.dataset)
        
        print(f'Epoch {epoch+0:03}: | Loss: {total_loss:.5f}')

def test(net,loader):
    
    net.eval()

    correct = 0

    for data in loader:
        with torch.no_grad():
            output = net(data)
            
            #sigmoid(output) > 0.5 -> true, else false
            classification = torch.round(F.sigmoid(output))
            label = data.y
            #print('classification', classification, 'label', label)
        
        correct+= (classification == label).sum()

    return correct/ len(loader.dataset)

def cross_validation(net, dataset, epochs):
    data_size = len(dataset)
    fold_index = int(0.2*len(dataset))
    
    mean_accuracy = 0
    for k in range(5):
        
        #reset parameters after starting new fold
        net.reset_parameters()
        
        #use 6 indices for left and right bound of the three data chunks
        #[train_index_left_left, train_index_left_right],[test_index_left, test_index_right],
        #[train_index_right_left, train_index_right_right]
        
        train_index_left_left = 0
        train_index_left_right = k * fold_index

        test_index_left = train_index_left_right
        test_index_right = test_index_left + fold_index

        train_index_right_left = test_index_right
        train_index_right_right = data_size

        #training set is union of two training chunks
        train_indices = list(range(train_index_left_left, train_index_left_right))+ list(range(train_index_right_left, train_index_right_right))
        test_indices = list(range(test_index_left, test_index_right))

        train_set = torch.utils.data.dataset.Subset(dataset,train_indices)
        test_set = torch.utils.data.dataset.Subset(dataset, test_indices)

        train_loader = DataLoader(train_set, shuffle = True)
        test_loader = DataLoader(test_set, shuffle = True)

        train(net, train_loader, epochs)
        
        accuracy = test(net, test_loader)
        mean_accuracy+= accuracy
        print(f'\nFold {k}: | Accuracy: {accuracy:.5f}\n')
        
    print('Total Mean Accuracy', mean_accuracy/5)


net = Network()
net = cross_validation(net, py_graphs, 200)





Epoch 001: | Loss: 0.69380
Epoch 002: | Loss: 0.69354
Epoch 003: | Loss: 0.69482
Epoch 004: | Loss: 0.69631
Epoch 005: | Loss: 0.69285
Epoch 006: | Loss: 0.69359
Epoch 007: | Loss: 0.69486
Epoch 008: | Loss: 0.69558
Epoch 009: | Loss: 0.69373
Epoch 010: | Loss: 0.69332
Epoch 011: | Loss: 0.69449
Epoch 012: | Loss: 0.69541
Epoch 013: | Loss: 0.69394
Epoch 014: | Loss: 0.69426
Epoch 015: | Loss: 0.69394
Epoch 016: | Loss: 0.69344
Epoch 017: | Loss: 0.69275
Epoch 018: | Loss: 0.69223
Epoch 019: | Loss: 0.69603
Epoch 020: | Loss: 0.69357
Epoch 021: | Loss: 0.69587
Epoch 022: | Loss: 0.69303
Epoch 023: | Loss: 0.69240
Epoch 024: | Loss: 0.69588
Epoch 025: | Loss: 0.69094
Epoch 026: | Loss: 0.69583
Epoch 027: | Loss: 0.69511
Epoch 028: | Loss: 0.69700
Epoch 029: | Loss: 0.69570
Epoch 030: | Loss: 0.69378
Epoch 031: | Loss: 0.69419
Epoch 032: | Loss: 0.69232
Epoch 033: | Loss: 0.69555
Epoch 034: | Loss: 0.69371
Epoch 035: | Loss: 0.69185
Epoch 036: | Loss: 0.69083
Epoch 037: | Loss: 0.69470
E

Epoch 104: | Loss: 0.69450
Epoch 105: | Loss: 0.69463
Epoch 106: | Loss: 0.69530
Epoch 107: | Loss: 0.69292
Epoch 108: | Loss: 0.69320
Epoch 109: | Loss: 0.69348
Epoch 110: | Loss: 0.69290
Epoch 111: | Loss: 0.69394
Epoch 112: | Loss: 0.69272
Epoch 113: | Loss: 0.69285
Epoch 114: | Loss: 0.69360
Epoch 115: | Loss: 0.69360
Epoch 116: | Loss: 0.69503
Epoch 117: | Loss: 0.69178
Epoch 118: | Loss: 0.69038
Epoch 119: | Loss: 0.69209
Epoch 120: | Loss: 0.69412
Epoch 121: | Loss: 0.69300
Epoch 122: | Loss: 0.69291
Epoch 123: | Loss: 0.69311
Epoch 124: | Loss: 0.69386
Epoch 125: | Loss: 0.69262
Epoch 126: | Loss: 0.69418
Epoch 127: | Loss: 0.69377
Epoch 128: | Loss: 0.69598
Epoch 129: | Loss: 0.69458
Epoch 130: | Loss: 0.69258
Epoch 131: | Loss: 0.69474
Epoch 132: | Loss: 0.69449
Epoch 133: | Loss: 0.69489
Epoch 134: | Loss: 0.69291
Epoch 135: | Loss: 0.69460
Epoch 136: | Loss: 0.69314
Epoch 137: | Loss: 0.69323
Epoch 138: | Loss: 0.69561
Epoch 139: | Loss: 0.69280
Epoch 140: | Loss: 0.69300
E

Epoch 006: | Loss: 0.69623
Epoch 007: | Loss: 0.69523
Epoch 008: | Loss: 0.69345
Epoch 009: | Loss: 0.69531
Epoch 010: | Loss: 0.69370
Epoch 011: | Loss: 0.69437
Epoch 012: | Loss: 0.69449
Epoch 013: | Loss: 0.69330
Epoch 014: | Loss: 0.69555
Epoch 015: | Loss: 0.69089
Epoch 016: | Loss: 0.69392
Epoch 017: | Loss: 0.69407
Epoch 018: | Loss: 0.69308
Epoch 019: | Loss: 0.69063
Epoch 020: | Loss: 0.69251
Epoch 021: | Loss: 0.69424
Epoch 022: | Loss: 0.69261
Epoch 023: | Loss: 0.69448
Epoch 024: | Loss: 0.69336
Epoch 025: | Loss: 0.69332
Epoch 026: | Loss: 0.69416
Epoch 027: | Loss: 0.69515
Epoch 028: | Loss: 0.69181
Epoch 029: | Loss: 0.69510
Epoch 030: | Loss: 0.69295
Epoch 031: | Loss: 0.69420
Epoch 032: | Loss: 0.69263
Epoch 033: | Loss: 0.69332
Epoch 034: | Loss: 0.69546
Epoch 035: | Loss: 0.69073
Epoch 036: | Loss: 0.69305
Epoch 037: | Loss: 0.69501
Epoch 038: | Loss: 0.69134
Epoch 039: | Loss: 0.69272
Epoch 040: | Loss: 0.69456
Epoch 041: | Loss: 0.69282
Epoch 042: | Loss: 0.69405
E

Epoch 109: | Loss: 0.69387
Epoch 110: | Loss: 0.69246
Epoch 111: | Loss: 0.69393
Epoch 112: | Loss: 0.69390
Epoch 113: | Loss: 0.69483
Epoch 114: | Loss: 0.69284
Epoch 115: | Loss: 0.69224
Epoch 116: | Loss: 0.69348
Epoch 117: | Loss: 0.69314
Epoch 118: | Loss: 0.69267
Epoch 119: | Loss: 0.69451
Epoch 120: | Loss: 0.69323
Epoch 121: | Loss: 0.69474
Epoch 122: | Loss: 0.69196
Epoch 123: | Loss: 0.69519
Epoch 124: | Loss: 0.69371
Epoch 125: | Loss: 0.69472
Epoch 126: | Loss: 0.69292
Epoch 127: | Loss: 0.69363
Epoch 128: | Loss: 0.69429
Epoch 129: | Loss: 0.69526
Epoch 130: | Loss: 0.69438
Epoch 131: | Loss: 0.69178
Epoch 132: | Loss: 0.69327
Epoch 133: | Loss: 0.69372
Epoch 134: | Loss: 0.69295
Epoch 135: | Loss: 0.69483
Epoch 136: | Loss: 0.69408
Epoch 137: | Loss: 0.69467
Epoch 138: | Loss: 0.69423
Epoch 139: | Loss: 0.69287
Epoch 140: | Loss: 0.69106
Epoch 141: | Loss: 0.69566
Epoch 142: | Loss: 0.69254
Epoch 143: | Loss: 0.69240
Epoch 144: | Loss: 0.69492
Epoch 145: | Loss: 0.69141
E

In [4]:
#Part 3

#Tast 3.1
class Network_RNI(nn.Module):
    
    def __init__(self):
        super().__init__()
        
        self.convs = nn.ModuleList()
        self.convs.append(tg_nn.GCNConv(input_dim, 50)
        )
        
        self.num_MP_layers = 16
        
        for i in range(15):
            self.convs.append(tg_nn.GCNConv(50, 50))
        
        self.MLPs = nn.Sequential(
            nn.Linear(50, 50), nn.Dropout(0.25),
            nn.Linear(50,50), nn.Dropout(0.1),
            nn.Linear(50,50), nn.Dropout(0.1),
            nn.Linear(50, output_dim)
            )
        
    def forward(self, data):
        x, edge_index, batch = data.x, data.edge_index, data.batch
        
        #sample 25 random features for ever data point
        rvec = torch.randn((len(batch),25))
        
        x_pre = x.narrow(1, 0,25)
        
        #concatenate 25 zeros with 25 N(0,1) values for consecutive processing
        x = torch.cat((x_pre, rvec), dim = 1)
        #print(x)
        
        for i in range(0,self.num_MP_layers):
            x = self.convs[i](x, edge_index)
            x = F.relu(x)
            x = F.dropout(x, 0.25)
            
        x = tg_nn.global_max_pool(x, batch)
        
        x = self.MLPs(x)
        
        
        return x
    
    def reset_parameters(self):
        for layers in self.children():
            for layer in layers:
                if hasattr(layer, 'reset_parameters'):
                    layer.reset_parameters()
    
    def loss(self, pred, label):
        return F.binary_cross_entropy(F.sigmoid(pred), label.float().unsqueeze(1))

net_rni = Network_RNI()
print(net_rni)

cross_validation(net_rni, py_graphs, 400)

Network_RNI(
  (convs): ModuleList(
    (0): GCNConv(50, 50)
    (1): GCNConv(50, 50)
    (2): GCNConv(50, 50)
    (3): GCNConv(50, 50)
    (4): GCNConv(50, 50)
    (5): GCNConv(50, 50)
    (6): GCNConv(50, 50)
    (7): GCNConv(50, 50)
    (8): GCNConv(50, 50)
    (9): GCNConv(50, 50)
    (10): GCNConv(50, 50)
    (11): GCNConv(50, 50)
    (12): GCNConv(50, 50)
    (13): GCNConv(50, 50)
    (14): GCNConv(50, 50)
    (15): GCNConv(50, 50)
  )
  (MLPs): Sequential(
    (0): Linear(in_features=50, out_features=50, bias=True)
    (1): Dropout(p=0.25, inplace=False)
    (2): Linear(in_features=50, out_features=50, bias=True)
    (3): Dropout(p=0.1, inplace=False)
    (4): Linear(in_features=50, out_features=50, bias=True)
    (5): Dropout(p=0.1, inplace=False)
    (6): Linear(in_features=50, out_features=1, bias=True)
  )
)
Epoch 001: | Loss: 0.69485
Epoch 002: | Loss: 0.69523
Epoch 003: | Loss: 0.69388
Epoch 004: | Loss: 0.69336
Epoch 005: | Loss: 0.69099
Epoch 006: | Loss: 0.69426
Epoch 0

Epoch 274: | Loss: 0.62139
Epoch 275: | Loss: 0.56743
Epoch 276: | Loss: 0.53347
Epoch 277: | Loss: 0.53469
Epoch 278: | Loss: 0.59009
Epoch 279: | Loss: 0.54193
Epoch 280: | Loss: 0.56852
Epoch 281: | Loss: 0.48989
Epoch 282: | Loss: 0.49382
Epoch 283: | Loss: 0.61931
Epoch 284: | Loss: 0.50238
Epoch 285: | Loss: 0.51866
Epoch 286: | Loss: 0.52915
Epoch 287: | Loss: 0.60739
Epoch 288: | Loss: 0.58385
Epoch 289: | Loss: 0.52773
Epoch 290: | Loss: 0.50676
Epoch 291: | Loss: 0.64587
Epoch 292: | Loss: 0.54830
Epoch 293: | Loss: 0.47613
Epoch 294: | Loss: 0.48640
Epoch 295: | Loss: 0.48659
Epoch 296: | Loss: 0.60829
Epoch 297: | Loss: 0.54477
Epoch 298: | Loss: 0.60325
Epoch 299: | Loss: 0.49433
Epoch 300: | Loss: 0.51041
Epoch 301: | Loss: 0.59832
Epoch 302: | Loss: 0.52306
Epoch 303: | Loss: 0.53185
Epoch 304: | Loss: 0.52351
Epoch 305: | Loss: 0.50214
Epoch 306: | Loss: 0.57828
Epoch 307: | Loss: 0.55049
Epoch 308: | Loss: 0.50807
Epoch 309: | Loss: 0.58569
Epoch 310: | Loss: 0.59241
E

Epoch 177: | Loss: 0.50718
Epoch 178: | Loss: 0.50976
Epoch 179: | Loss: 0.60893
Epoch 180: | Loss: 0.58386
Epoch 181: | Loss: 0.57302
Epoch 182: | Loss: 0.56169
Epoch 183: | Loss: 0.50371
Epoch 184: | Loss: 0.54303
Epoch 185: | Loss: 0.58698
Epoch 186: | Loss: 0.56950
Epoch 187: | Loss: 0.63756
Epoch 188: | Loss: 0.53854
Epoch 189: | Loss: 0.53637
Epoch 190: | Loss: 0.53566
Epoch 191: | Loss: 0.57124
Epoch 192: | Loss: 0.54465
Epoch 193: | Loss: 0.51461
Epoch 194: | Loss: 0.49652
Epoch 195: | Loss: 0.51557
Epoch 196: | Loss: 0.59903
Epoch 197: | Loss: 0.61005
Epoch 198: | Loss: 0.55207
Epoch 199: | Loss: 0.51032
Epoch 200: | Loss: 0.52644
Epoch 201: | Loss: 0.49680
Epoch 202: | Loss: 0.43190
Epoch 203: | Loss: 0.60024
Epoch 204: | Loss: 0.56281
Epoch 205: | Loss: 0.43295
Epoch 206: | Loss: 0.64556
Epoch 207: | Loss: 0.54288
Epoch 208: | Loss: 0.50908
Epoch 209: | Loss: 0.44174
Epoch 210: | Loss: 0.44744
Epoch 211: | Loss: 0.48877
Epoch 212: | Loss: 0.51094
Epoch 213: | Loss: 0.52571
E

Epoch 080: | Loss: 0.65034
Epoch 081: | Loss: 0.62923
Epoch 082: | Loss: 0.63373
Epoch 083: | Loss: 0.63052
Epoch 084: | Loss: 0.61439
Epoch 085: | Loss: 0.64457
Epoch 086: | Loss: 0.62419
Epoch 087: | Loss: 0.62295
Epoch 088: | Loss: 0.67889
Epoch 089: | Loss: 0.64164
Epoch 090: | Loss: 0.60212
Epoch 091: | Loss: 0.65076
Epoch 092: | Loss: 0.65471
Epoch 093: | Loss: 0.61841
Epoch 094: | Loss: 0.67166
Epoch 095: | Loss: 0.62341
Epoch 096: | Loss: 0.66337
Epoch 097: | Loss: 0.64785
Epoch 098: | Loss: 0.64809
Epoch 099: | Loss: 0.59761
Epoch 100: | Loss: 0.56936
Epoch 101: | Loss: 0.65146
Epoch 102: | Loss: 0.57250
Epoch 103: | Loss: 0.59356
Epoch 104: | Loss: 0.64494
Epoch 105: | Loss: 0.63057
Epoch 106: | Loss: 0.61283
Epoch 107: | Loss: 0.61326
Epoch 108: | Loss: 0.63381
Epoch 109: | Loss: 0.64476
Epoch 110: | Loss: 0.61273
Epoch 111: | Loss: 0.60659
Epoch 112: | Loss: 0.60330
Epoch 113: | Loss: 0.66216
Epoch 114: | Loss: 0.57294
Epoch 115: | Loss: 0.57531
Epoch 116: | Loss: 0.57910
E

Epoch 384: | Loss: 0.56843
Epoch 385: | Loss: 0.47429
Epoch 386: | Loss: 0.58066
Epoch 387: | Loss: 0.49081
Epoch 388: | Loss: 0.45315
Epoch 389: | Loss: 0.51406
Epoch 390: | Loss: 0.50385
Epoch 391: | Loss: 0.57530
Epoch 392: | Loss: 0.51612
Epoch 393: | Loss: 0.51989
Epoch 394: | Loss: 0.48290
Epoch 395: | Loss: 0.51570
Epoch 396: | Loss: 0.54476
Epoch 397: | Loss: 0.57008
Epoch 398: | Loss: 0.46958
Epoch 399: | Loss: 0.50351
Epoch 400: | Loss: 0.53081

Fold 2: | Accuracy: 0.81818

Epoch 001: | Loss: 0.69586
Epoch 002: | Loss: 0.69383
Epoch 003: | Loss: 0.69389
Epoch 004: | Loss: 0.69522
Epoch 005: | Loss: 0.69516
Epoch 006: | Loss: 0.69419
Epoch 007: | Loss: 0.69373
Epoch 008: | Loss: 0.69441
Epoch 009: | Loss: 0.69266
Epoch 010: | Loss: 0.69342
Epoch 011: | Loss: 0.69288
Epoch 012: | Loss: 0.69126
Epoch 013: | Loss: 0.69389
Epoch 014: | Loss: 0.69583
Epoch 015: | Loss: 0.69473
Epoch 016: | Loss: 0.69244
Epoch 017: | Loss: 0.69283
Epoch 018: | Loss: 0.69432
Epoch 019: | Loss: 0.6924

Epoch 287: | Loss: 0.54812
Epoch 288: | Loss: 0.55949
Epoch 289: | Loss: 0.57992
Epoch 290: | Loss: 0.58201
Epoch 291: | Loss: 0.59184
Epoch 292: | Loss: 0.59614
Epoch 293: | Loss: 0.57409
Epoch 294: | Loss: 0.55298
Epoch 295: | Loss: 0.63223
Epoch 296: | Loss: 0.58465
Epoch 297: | Loss: 0.58063
Epoch 298: | Loss: 0.53644
Epoch 299: | Loss: 0.52613
Epoch 300: | Loss: 0.67024
Epoch 301: | Loss: 0.59664
Epoch 302: | Loss: 0.48974
Epoch 303: | Loss: 0.55177
Epoch 304: | Loss: 0.62680
Epoch 305: | Loss: 0.55426
Epoch 306: | Loss: 0.54774
Epoch 307: | Loss: 0.61153
Epoch 308: | Loss: 0.52029
Epoch 309: | Loss: 0.59205
Epoch 310: | Loss: 0.57342
Epoch 311: | Loss: 0.61847
Epoch 312: | Loss: 0.48278
Epoch 313: | Loss: 0.56381
Epoch 314: | Loss: 0.54326
Epoch 315: | Loss: 0.52455
Epoch 316: | Loss: 0.60107
Epoch 317: | Loss: 0.57991
Epoch 318: | Loss: 0.63972
Epoch 319: | Loss: 0.51508
Epoch 320: | Loss: 0.59116
Epoch 321: | Loss: 0.57239
Epoch 322: | Loss: 0.52449
Epoch 323: | Loss: 0.58487
E

Epoch 190: | Loss: 0.62509
Epoch 191: | Loss: 0.59839
Epoch 192: | Loss: 0.59642
Epoch 193: | Loss: 0.57828
Epoch 194: | Loss: 0.62873
Epoch 195: | Loss: 0.57277
Epoch 196: | Loss: 0.49003
Epoch 197: | Loss: 0.58436
Epoch 198: | Loss: 0.60668
Epoch 199: | Loss: 0.57231
Epoch 200: | Loss: 0.49643
Epoch 201: | Loss: 0.54906
Epoch 202: | Loss: 0.63058
Epoch 203: | Loss: 0.60492
Epoch 204: | Loss: 0.54457
Epoch 205: | Loss: 0.49203
Epoch 206: | Loss: 0.70776
Epoch 207: | Loss: 0.55179
Epoch 208: | Loss: 0.53009
Epoch 209: | Loss: 0.63005
Epoch 210: | Loss: 0.57249
Epoch 211: | Loss: 0.45012
Epoch 212: | Loss: 0.59637
Epoch 213: | Loss: 0.71380
Epoch 214: | Loss: 0.50292
Epoch 215: | Loss: 0.58595
Epoch 216: | Loss: 0.57553
Epoch 217: | Loss: 0.55753
Epoch 218: | Loss: 0.55760
Epoch 219: | Loss: 0.53317
Epoch 220: | Loss: 0.48721
Epoch 221: | Loss: 0.60940
Epoch 222: | Loss: 0.57568
Epoch 223: | Loss: 0.59111
Epoch 224: | Loss: 0.52320
Epoch 225: | Loss: 0.60899
Epoch 226: | Loss: 0.63899
E

In [5]:
class Network_RNI_adapted(nn.Module):
    
    def __init__(self):
        super().__init__()
        
        self.convs = nn.ModuleList()
        self.convs.append(tg_nn.GCNConv(input_dim, 50)
        )
        #4 MP layers
        self.num_MP_layers = 4
        
        for i in range(3):
            self.convs.append(tg_nn.GCNConv(50, 50))
        
        #8 MLP layers
        self.MLPs = nn.Sequential(
            nn.Linear(50, 50), nn.Dropout(0.25),
            nn.Linear(50,50), nn.Dropout(0.1),
            nn.Linear(50,50), nn.Dropout(0.1),
            nn.Linear(50,50), nn.Dropout(0.1),
            nn.Linear(50,50), nn.Dropout(0.1),
            nn.Linear(50,50), nn.Dropout(0.1),
            nn.Linear(50,50), nn.Dropout(0.1),
            nn.Linear(50, output_dim)
            )
        
    def forward(self, data):
        x, edge_index, batch = data.x, data.edge_index, data.batch
        
        rvec = torch.randn((len(batch),25))
        
        x_pre = x.narrow(1, 0,25)
        
        x = torch.cat((x_pre, rvec), dim = 1)
        #print(x)
        
        for i in range(0,self.num_MP_layers):
            #print(x.shape)
            #print(x)
            x = self.convs[i](x, edge_index)
            x = F.relu(x)
            x = F.dropout(x, 0.25)
            
        x = tg_nn.global_max_pool(x, batch)
        
        x = self.MLPs(x)
        
        
        return x
    
    def reset_parameters(self):
        for layers in self.children():
            for layer in layers:
                if hasattr(layer, 'reset_parameters'):
                    layer.reset_parameters()
    
    def loss(self, pred, label):
        #print('lossfuntion', F.sigmoid(pred[0]))
        return F.binary_cross_entropy(F.sigmoid(pred), label.float().unsqueeze(1))
    
    
    
net_rni_adapted = Network_RNI_adapted()

cross_validation(net_rni_adapted, py_graphs, 400)

#Accuracy decreases drastically with too few MPLs, because for graphs with up to 15 nodes,
#we need 14 MPLs to represent the information coming from all other nodes

Epoch 001: | Loss: 0.69060
Epoch 002: | Loss: 0.69460
Epoch 003: | Loss: 0.69483
Epoch 004: | Loss: 0.69587
Epoch 005: | Loss: 0.69657
Epoch 006: | Loss: 0.69243
Epoch 007: | Loss: 0.69497
Epoch 008: | Loss: 0.69555
Epoch 009: | Loss: 0.69394
Epoch 010: | Loss: 0.69505
Epoch 011: | Loss: 0.69247
Epoch 012: | Loss: 0.69306
Epoch 013: | Loss: 0.69377
Epoch 014: | Loss: 0.69416
Epoch 015: | Loss: 0.69348
Epoch 016: | Loss: 0.69450
Epoch 017: | Loss: 0.69315
Epoch 018: | Loss: 0.69444
Epoch 019: | Loss: 0.69418
Epoch 020: | Loss: 0.69475
Epoch 021: | Loss: 0.69470
Epoch 022: | Loss: 0.69397
Epoch 023: | Loss: 0.69334
Epoch 024: | Loss: 0.69533
Epoch 025: | Loss: 0.69511
Epoch 026: | Loss: 0.69581
Epoch 027: | Loss: 0.69416
Epoch 028: | Loss: 0.69363
Epoch 029: | Loss: 0.69308
Epoch 030: | Loss: 0.69364
Epoch 031: | Loss: 0.69232
Epoch 032: | Loss: 0.69351
Epoch 033: | Loss: 0.69454
Epoch 034: | Loss: 0.69525
Epoch 035: | Loss: 0.69555
Epoch 036: | Loss: 0.69348
Epoch 037: | Loss: 0.69450
E

Epoch 305: | Loss: 0.68099
Epoch 306: | Loss: 0.69542
Epoch 307: | Loss: 0.69384
Epoch 308: | Loss: 0.69302
Epoch 309: | Loss: 0.70066
Epoch 310: | Loss: 0.68610
Epoch 311: | Loss: 0.67910
Epoch 312: | Loss: 0.68105
Epoch 313: | Loss: 0.69168
Epoch 314: | Loss: 0.67443
Epoch 315: | Loss: 0.68691
Epoch 316: | Loss: 0.69623
Epoch 317: | Loss: 0.68063
Epoch 318: | Loss: 0.67962
Epoch 319: | Loss: 0.69336
Epoch 320: | Loss: 0.68344
Epoch 321: | Loss: 0.69245
Epoch 322: | Loss: 0.69083
Epoch 323: | Loss: 0.70471
Epoch 324: | Loss: 0.70029
Epoch 325: | Loss: 0.67135
Epoch 326: | Loss: 0.68697
Epoch 327: | Loss: 0.67833
Epoch 328: | Loss: 0.70461
Epoch 329: | Loss: 0.68512
Epoch 330: | Loss: 0.68231
Epoch 331: | Loss: 0.69052
Epoch 332: | Loss: 0.69049
Epoch 333: | Loss: 0.71383
Epoch 334: | Loss: 0.67913
Epoch 335: | Loss: 0.69405
Epoch 336: | Loss: 0.68282
Epoch 337: | Loss: 0.68454
Epoch 338: | Loss: 0.69615
Epoch 339: | Loss: 0.68472
Epoch 340: | Loss: 0.69980
Epoch 341: | Loss: 0.67881
E

Epoch 208: | Loss: 0.68277
Epoch 209: | Loss: 0.66685
Epoch 210: | Loss: 0.68601
Epoch 211: | Loss: 0.68193
Epoch 212: | Loss: 0.70008
Epoch 213: | Loss: 0.69645
Epoch 214: | Loss: 0.68522
Epoch 215: | Loss: 0.69266
Epoch 216: | Loss: 0.69885
Epoch 217: | Loss: 0.67375
Epoch 218: | Loss: 0.69726
Epoch 219: | Loss: 0.67571
Epoch 220: | Loss: 0.66938
Epoch 221: | Loss: 0.70440
Epoch 222: | Loss: 0.68579
Epoch 223: | Loss: 0.68059
Epoch 224: | Loss: 0.69972
Epoch 225: | Loss: 0.68927
Epoch 226: | Loss: 0.67417
Epoch 227: | Loss: 0.69407
Epoch 228: | Loss: 0.67636
Epoch 229: | Loss: 0.69760
Epoch 230: | Loss: 0.68769
Epoch 231: | Loss: 0.67457
Epoch 232: | Loss: 0.67815
Epoch 233: | Loss: 0.68191
Epoch 234: | Loss: 0.67496
Epoch 235: | Loss: 0.67787
Epoch 236: | Loss: 0.67458
Epoch 237: | Loss: 0.69101
Epoch 238: | Loss: 0.68810
Epoch 239: | Loss: 0.68171
Epoch 240: | Loss: 0.66023
Epoch 241: | Loss: 0.69978
Epoch 242: | Loss: 0.67644
Epoch 243: | Loss: 0.71507
Epoch 244: | Loss: 0.68058
E

Epoch 111: | Loss: 0.69440
Epoch 112: | Loss: 0.68309
Epoch 113: | Loss: 0.69580
Epoch 114: | Loss: 0.68701
Epoch 115: | Loss: 0.69823
Epoch 116: | Loss: 0.69765
Epoch 117: | Loss: 0.68344
Epoch 118: | Loss: 0.68825
Epoch 119: | Loss: 0.68934
Epoch 120: | Loss: 0.69841
Epoch 121: | Loss: 0.68995
Epoch 122: | Loss: 0.68565
Epoch 123: | Loss: 0.70208
Epoch 124: | Loss: 0.68965
Epoch 125: | Loss: 0.68654
Epoch 126: | Loss: 0.69220
Epoch 127: | Loss: 0.68951
Epoch 128: | Loss: 0.69207
Epoch 129: | Loss: 0.68836
Epoch 130: | Loss: 0.68686
Epoch 131: | Loss: 0.69821
Epoch 132: | Loss: 0.68080
Epoch 133: | Loss: 0.69406
Epoch 134: | Loss: 0.67684
Epoch 135: | Loss: 0.68439
Epoch 136: | Loss: 0.66124
Epoch 137: | Loss: 0.70128
Epoch 138: | Loss: 0.69075
Epoch 139: | Loss: 0.69908
Epoch 140: | Loss: 0.68805
Epoch 141: | Loss: 0.70049
Epoch 142: | Loss: 0.69303
Epoch 143: | Loss: 0.68440
Epoch 144: | Loss: 0.68078
Epoch 145: | Loss: 0.68294
Epoch 146: | Loss: 0.69239
Epoch 147: | Loss: 0.68062
E

Epoch 014: | Loss: 0.69428
Epoch 015: | Loss: 0.69435
Epoch 016: | Loss: 0.69552
Epoch 017: | Loss: 0.69447
Epoch 018: | Loss: 0.69330
Epoch 019: | Loss: 0.69421
Epoch 020: | Loss: 0.69449
Epoch 021: | Loss: 0.69488
Epoch 022: | Loss: 0.69367
Epoch 023: | Loss: 0.69362
Epoch 024: | Loss: 0.69365
Epoch 025: | Loss: 0.69162
Epoch 026: | Loss: 0.69353
Epoch 027: | Loss: 0.69212
Epoch 028: | Loss: 0.69336
Epoch 029: | Loss: 0.69549
Epoch 030: | Loss: 0.69822
Epoch 031: | Loss: 0.69323
Epoch 032: | Loss: 0.69575
Epoch 033: | Loss: 0.69218
Epoch 034: | Loss: 0.69633
Epoch 035: | Loss: 0.69452
Epoch 036: | Loss: 0.69393
Epoch 037: | Loss: 0.69504
Epoch 038: | Loss: 0.69298
Epoch 039: | Loss: 0.69405
Epoch 040: | Loss: 0.69600
Epoch 041: | Loss: 0.69506
Epoch 042: | Loss: 0.69320
Epoch 043: | Loss: 0.69527
Epoch 044: | Loss: 0.69540
Epoch 045: | Loss: 0.69226
Epoch 046: | Loss: 0.69326
Epoch 047: | Loss: 0.69397
Epoch 048: | Loss: 0.69432
Epoch 049: | Loss: 0.69425
Epoch 050: | Loss: 0.69485
E

Epoch 318: | Loss: 0.64414
Epoch 319: | Loss: 0.67466
Epoch 320: | Loss: 0.68428
Epoch 321: | Loss: 0.70027
Epoch 322: | Loss: 0.69136
Epoch 323: | Loss: 0.68964
Epoch 324: | Loss: 0.68950
Epoch 325: | Loss: 0.69389
Epoch 326: | Loss: 0.65222
Epoch 327: | Loss: 0.67576
Epoch 328: | Loss: 0.69004
Epoch 329: | Loss: 0.69227
Epoch 330: | Loss: 0.68048
Epoch 331: | Loss: 0.65831
Epoch 332: | Loss: 0.67840
Epoch 333: | Loss: 0.66201
Epoch 334: | Loss: 0.65242
Epoch 335: | Loss: 0.64021
Epoch 336: | Loss: 0.67030
Epoch 337: | Loss: 0.69010
Epoch 338: | Loss: 0.66910
Epoch 339: | Loss: 0.67267
Epoch 340: | Loss: 0.71107
Epoch 341: | Loss: 0.66258
Epoch 342: | Loss: 0.66660
Epoch 343: | Loss: 0.67757
Epoch 344: | Loss: 0.69519
Epoch 345: | Loss: 0.68014
Epoch 346: | Loss: 0.67199
Epoch 347: | Loss: 0.68388
Epoch 348: | Loss: 0.67913
Epoch 349: | Loss: 0.68097
Epoch 350: | Loss: 0.64562
Epoch 351: | Loss: 0.69392
Epoch 352: | Loss: 0.64811
Epoch 353: | Loss: 0.67052
Epoch 354: | Loss: 0.64829
E

Epoch 221: | Loss: 0.68672
Epoch 222: | Loss: 0.66746
Epoch 223: | Loss: 0.71035
Epoch 224: | Loss: 0.65107
Epoch 225: | Loss: 0.68844
Epoch 226: | Loss: 0.67810
Epoch 227: | Loss: 0.67150
Epoch 228: | Loss: 0.70405
Epoch 229: | Loss: 0.68118
Epoch 230: | Loss: 0.70040
Epoch 231: | Loss: 0.69856
Epoch 232: | Loss: 0.68287
Epoch 233: | Loss: 0.68656
Epoch 234: | Loss: 0.69550
Epoch 235: | Loss: 0.64677
Epoch 236: | Loss: 0.71736
Epoch 237: | Loss: 0.68078
Epoch 238: | Loss: 0.66114
Epoch 239: | Loss: 0.70351
Epoch 240: | Loss: 0.66828
Epoch 241: | Loss: 0.67513
Epoch 242: | Loss: 0.66151
Epoch 243: | Loss: 0.65197
Epoch 244: | Loss: 0.70583
Epoch 245: | Loss: 0.71630
Epoch 246: | Loss: 0.67933
Epoch 247: | Loss: 0.69791
Epoch 248: | Loss: 0.68520
Epoch 249: | Loss: 0.67208
Epoch 250: | Loss: 0.68030
Epoch 251: | Loss: 0.70707
Epoch 252: | Loss: 0.68739
Epoch 253: | Loss: 0.66702
Epoch 254: | Loss: 0.68232
Epoch 255: | Loss: 0.67420
Epoch 256: | Loss: 0.69508
Epoch 257: | Loss: 0.68316
E