In [None]:
!pip install -r FedStar/requirements.txt

Collecting dtaidistance
  Downloading dtaidistance-2.3.10-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (2.4 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.4/2.4 MB[0m [31m32.7 MB/s[0m eta [36m0:00:00[0m00:01[0m00:01[0m
Collecting torch-scatter
  Downloading torch_scatter-2.1.1.tar.gz (107 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m107.6/107.6 kB[0m [31m15.3 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25ldone
[?25hCollecting torch-sparse
  Downloading torch_sparse-0.6.17.tar.gz (209 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m209.2/209.2 kB[0m [31m34.6 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l-

In [2]:
!wget http://vision.cs.aston.ac.uk/datasets/UCID/data/ucid.v2.tar.gz

--2023-04-25 06:11:24--  http://vision.cs.aston.ac.uk/datasets/UCID/data/ucid.v2.tar.gz
Resolving vision.cs.aston.ac.uk (vision.cs.aston.ac.uk)... 134.151.37.9
Connecting to vision.cs.aston.ac.uk (vision.cs.aston.ac.uk)|134.151.37.9|:80... failed: Connection timed out.
Retrying.

--2023-04-25 06:13:35--  (try: 2)  http://vision.cs.aston.ac.uk/datasets/UCID/data/ucid.v2.tar.gz
Connecting to vision.cs.aston.ac.uk (vision.cs.aston.ac.uk)|134.151.37.9|:80... ^C


In [1]:
#generic
from pathlib import Path
import os, sys
import argparse
import random
import copy
from random import choices
import pickle

#torch
import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, GINConv, global_add_pool, SAGEConv
from torch_geometric.transforms import OneHotDegree
from torch_geometric.utils import to_networkx, degree, to_dense_adj, to_scipy_sparse_matrix
from sklearn.model_selection import train_test_split
from scipy import sparse as sp
import torch_geometric
from torch_geometric.data import Data, Dataset, Batch
from torch_geometric.utils import to_networkx, subgraph
import torch_geometric.utils as utils
from torch.nn.functional import one_hot


#utility
import networkx as nx
from dtaidistance import dtw
from tensorboardX import SummaryWriter
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import pymetis
from ogb.nodeproppred import PygNodePropPredDataset
import tqdm

num_clients = 3
device = "cuda" if torch.cuda.is_available() else "cpu"
alg = 'fedstar'
num_rounds = 20
local_epoch = 10
lr = 0.01
weight_decay = 5e-4
nlayer = 3 # number of GINConv layers
hidden = 64
dropout = 0.5
batch_size = 128  # not used
seed = 69
datapath = '.Data'
outbase = 'outputs'
data_group = 'arxiv'
n_rw = 16
n_dg = 16
n_ones = 16
type_init = 'rw_dg' #options are rw, dg and rw_dg
print(device)
seed_dataSplit = 123
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)

cuda


In [2]:
def get_numGraphLabels(g):
    s = set(g.y.flatten().tolist())
    return len(s)

def init_structure_encoding(  g, type_init = 'rw_dg'):

    if type_init == 'rw':
        # Geometric diffusion features with Random Walk
        A = to_scipy_sparse_matrix(g.edge_index, num_nodes=g.num_nodes)
        D = (degree(g.edge_index[0], num_nodes=g.num_nodes) ** -1.0).numpy()

        Dinv=sp.diags(D)
        RW=A*Dinv
        M=RW

        SE_rw=[torch.from_numpy(M.diagonal()).float()]
        M_power=M
        for _ in range(n_rw-1):
            M_power=M_power*M
            SE_rw.append(torch.from_numpy(M_power.diagonal()).float())
        SE_rw=torch.stack(SE_rw,dim=-1)

        g['stc_enc'] = SE_rw

    elif type_init == 'dg':
        # PE_degree
        g_dg = (degree(g.edge_index[0], num_nodes=g.num_nodes)).numpy().clip(1, n_dg)
        SE_dg = torch.zeros([g.num_QCnodes, n_dg])
        for i in range(len(g_dg)):
            SE_dg[i,int(g_dg[i]-1)] = 1

        g['stc_enc'] = SE_dg

    elif type_init == 'rw_dg':
        # SE_rw
        A = to_scipy_sparse_matrix(g.edge_index, num_nodes=g.num_nodes)
        D = (degree(g.edge_index[0], num_nodes=g.num_nodes) ** -1.0).numpy()

        Dinv=sp.diags(D)
        RW=A*Dinv
        M=RW

        SE=[torch.from_numpy(M.diagonal()).float()]
        M_power=M
        for _ in range(n_rw-1):
            M_power=M_power*M
            SE.append(torch.from_numpy(M_power.diagonal()).float())
        SE_rw=torch.stack(SE,dim=-1)

        # PE_degree
        g_dg = (degree(g.edge_index[0], num_nodes=g.num_nodes)).numpy().clip(1, n_dg)
        SE_dg = torch.zeros([g.num_nodes, n_dg])
        for i in range(len(g_dg)):
            SE_dg[i,int(g_dg[i]-1)] = 1

        g['stc_enc'] = torch.cat([SE_rw, SE_dg], dim=1)

    return g

def get_stats(df, ds, graph_train, graph_val=None, graph_test=None):
    from collections import Counter
    labels_train = graph_train.y.flatten().tolist()
    df.loc[ds, '#Nodes_train'] = graph_train.num_nodes
    df.loc[ds, '#Edges_train'] = graph_train.num_edges
    df.loc[ds, 'Avg_degree_train'] = graph_train.num_edges/graph_train.num_nodes
    df.loc[ds, '#Labels_train'] = len(set(labels_train))
    df.loc[ds, 'Class_dist_train'] = str(dict(Counter(labels_train)))
    
    if graph_test:
        labels_test = graph_test.y.flatten().tolist()
        df.loc[ds, '#Nodes_test'] = graph_test.num_nodes
        df.loc[ds, '#Edges_test'] = graph_test.num_edges
        df.loc[ds, 'Avg_degree_test'] = graph_test.num_edges/graph_test.num_nodes
        df.loc[ds, '#Labels_test'] = len(set(labels_test))
        df.loc[ds, 'Class_dist_test'] = str(dict(Counter(labels_test)))
        
    if graph_val:
        labels_val = graph_val.y.flatten().tolist()
        df.loc[ds, '#Nodes_val'] = graph_val.num_nodes
        df.loc[ds, '#Edges_val'] = graph_val.num_edges
        df.loc[ds, 'Avg_degree_val'] = graph_val.num_edges/graph_val.num_nodes
        df.loc[ds, '#Labels_val'] = len(set(labels_val))
        df.loc[ds, 'Class_dist_val'] = str(dict(Counter(labels_val)))
        
    return df

In [3]:
file = open('graph_struc_pickle', 'rb')
graph = pickle.load(file)
file.close()
graph

Data(num_nodes=169343, edge_index=[2, 1166243], x=[169343, 128], node_year=[169343, 1], y=[169343, 1], stc_enc=[169343, 32])

In [4]:
class GIN_dc(torch.nn.Module):
    def __init__(self, nfeat, n_se, nhid, nclass, nlayer, dropout):
        super(GIN_dc, self).__init__()
        self.num_layers = nlayer
        self.dropout = dropout

        self.pre = torch.nn.Sequential(torch.nn.Linear(nfeat, nhid))

        self.embedding_s = torch.nn.Linear(n_se, nhid)

        self.graph_convs = torch.nn.ModuleList()
        self.nn1 = torch.nn.Sequential(torch.nn.Linear(nhid + nhid, nhid), torch.nn.ReLU(), torch.nn.Linear(nhid, nhid))
        self.graph_convs.append(GINConv(self.nn1))
        self.graph_convs_s_gcn = torch.nn.ModuleList()
        self.graph_convs_s_gcn.append(GCNConv(nhid, nhid))

        for l in range(nlayer - 1):
            self.nnk = torch.nn.Sequential(torch.nn.Linear(nhid + nhid, nhid), torch.nn.ReLU(), torch.nn.Linear(nhid, nhid))
            self.graph_convs.append(GINConv(self.nnk))
            self.graph_convs_s_gcn.append(GCNConv(nhid, nhid))

        self.Whp = torch.nn.Linear(nhid + nhid, nhid)
        self.post = torch.nn.Sequential(torch.nn.Linear(nhid, nhid), torch.nn.ReLU())
        self.readout = torch.nn.Sequential(torch.nn.Linear(nhid, nclass))

    def forward(self, data):
        x, edge_index, s = data.x, data.edge_index, data.stc_enc
        x = self.pre(x)
        s = self.embedding_s(s)
        for i in range(len(self.graph_convs)):
            x = torch.cat((x, s), -1)
            x = self.graph_convs[i](x, edge_index)
            x = F.relu(x)
            x = F.dropout(x, self.dropout, training=self.training)
            s = self.graph_convs_s_gcn[i](s, edge_index)
            s = torch.tanh(s)
        x = self.Whp(torch.cat((x, s), -1))
        x = self.post(x)
        x = F.dropout(x, self.dropout, training=self.training)
        x = self.readout(x)
        # print(x)
        x = F.log_softmax(x, dim=1)
        # print(x)
        return x.float()
    def loss(self, pred, label):
        # print(pred, label)
        return F.cross_entropy(pred, label)
        # return F.nll_loss(pred, label)
data = copy.deepcopy(graph)
num_classes = get_numGraphLabels(data)
n_se = n_rw+n_dg
data.y = one_hot(data.y).squeeze(dim=1).float()
model = GIN_dc(nfeat=data.num_node_features, n_se=n_se, nhid=64, nclass=num_classes, nlayer=3, dropout=0.5).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.1)
data.to(device)

Data(num_nodes=169343, edge_index=[2, 1166243], x=[169343, 128], node_year=[169343, 1], y=[169343, 40], stc_enc=[169343, 32])

In [18]:
per = torch.randperm(data.num_nodes)

In [37]:
# optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
model = GIN_dc(nfeat=data.num_node_features, n_se=n_se, nhid=64, nclass=num_classes, nlayer=3, dropout=0.5).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
batch_size = 1024
data.to(device)
start = 0
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=10)
for epoch in range(20):
    acc_sum = 0
    total_loss = 0
    all = torch.randperm(data.num_nodes)
    for batch_num in range(data.num_nodes//batch_size):
    # batch = random.choices(all, k=len(all)//20)
        model.train()
        model.to(data.x.device)
        data.to(device)
        # print(f'epoch={epoch}')
        loss = torch.tensor([0.0], device='cuda:0')
    # loss.to(device)
    # print(loss.device)
        optimizer.zero_grad()
        out = model(data)
    # break
#     print(out.shape, data.y.shape)
#     loss = model.loss(out, data.y)
        out.to(device)
    # for i in range(out.shape[0]): 
        batch = per[batch_size*batch_num:batch_size*(batch_num+1)]
        for i in batch :
        # i_tensor = torch.tensor([i], dtype=torch.long, device=out.device)
            loss += model.loss(out[i], data.y[i])
            total_loss += loss.item()
        # loss += model.loss(out[i], data.y[i])
# Convert one-hot encoded labels to class indices
# Calculate number of correct predictions
        acc_sum = out.max(dim=1)[1].eq(data.y.max(dim=1)[1]).sum().item()
        # acc_sum += acc
        # if start == 0 : print(f'epoch={epoch}', total_loss, acc_sum, acc_sum*100/data.num_nodes)
        # start = 1
        # print(loss.item(), acc_sum, acc_sum*100/data.num_nodes)
    # break
        loss.backward()
        scheduler.step(loss)
        
    print(f'epoch={epoch}', total_loss, acc_sum, acc_sum*100/data.num_nodes)
    # optimizer.step()

epoch=0 326294302.67658997 4273 2.5232811512728603
epoch=1 326372319.4695916 4350 2.5687509964982316
epoch=2 326097088.06268835 4324 2.553397542266288
epoch=3 326572505.798996 4365 2.577608758555122
epoch=4 326675604.1150408 4244 2.506156144629539
epoch=5 326140412.3133693 4409 2.603591527255334
epoch=6 326823382.5893841 4263 2.5173759765682666
epoch=7 326278767.6821778 4327 2.555169094677666
epoch=8 326818438.9265218 4290 2.5333199482706696
epoch=9 326441913.9854653 4252 2.5108802843932136
epoch=10 326505077.26677394 4279 2.5268242560956167
epoch=11 326703572.72160435 4368 2.5793803109665
epoch=12 326655642.3296597 4354 2.571113066380069
epoch=13 326309872.6503229 4226 2.4955268301612703
epoch=14 326368298.9786763 4272 2.522690633802401
epoch=15 326572511.3947983 4354 2.571113066380069
epoch=16 326328794.6280587 4295 2.5362725356229663
epoch=17 326508996.8394308 4361 2.5752466886732845
epoch=18 326868802.4372189 4307 2.543358745268479
epoch=19 326404487.56867194 4289 2.532729430800210

In [8]:
model = GIN_dc(nfeat=data.num_node_features, n_se=n_se, nhid=64, nclass=num_classes, nlayer=3, dropout=0.5).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.1)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=10)

for epoch in range(50):
    model.train()
    loss = torch.tensor([0.0], device='cuda:0')
    optimizer.zero_grad()
    out = model(data)

    for i in range(out.shape[0]) : loss += model.loss(out[i], data.y[i])
    acc_sum = out.max(dim=1)[1].eq(data.y.max(dim=1)[1]).sum().item()
    print(f'epoch-{epoch}',loss.item(), acc_sum, acc_sum*100/data.num_nodes)    
    loss.backward()
    optimizer.step()
    # scheduler.step(loss)


epoch-0 651575.9375 2490 1.4703885014438152
epoch-1 6497465856.0 20797 12.280991833143384
epoch-2 280105216.0 23410 13.8240139834537
epoch-3 21737978.0 12795 7.555671034527556
epoch-4 6683801088.0 9634 5.6890453104055085
epoch-5 26866046.0 12121 7.157662259437945
epoch-6 11571762.0 9978 5.89218332024353
epoch-7 7295755.0 12111 7.151757084733352
epoch-8 1479485.125 9346 5.518976278913212
epoch-9 896559.8125 5813 3.43267805578028
epoch-10 621848.3125 6240 3.6848290156664283
epoch-11 656478.625 21383 12.627035070832571
epoch-12 597703.9375 21401 12.63766438530084
epoch-13 592652.5625 21406 12.640616972653136
epoch-14 3814739.25 21379 12.624673000950732
epoch-15 761229.125 21407 12.641207490123595
epoch-16 575462.9375 21407 12.641207490123595
epoch-17 570312.8125 21405 12.640026455182676
epoch-18 564959.4375 21405 12.640026455182676
epoch-19 560239.3125 21405 12.640026455182676
epoch-20 555848.9375 21404 12.639435937712218
epoch-21 550875.9375 21408 12.641798007594055
epoch-22 546820.625 2

In [9]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GINConv

class GIN(torch.nn.Module):
    def __init__(self, nfeat, nhid, nclass, nlayer, dropout):
        super(GIN, self).__init__()
        self.num_layers = nlayer
        self.dropout = dropout

        self.pre = torch.nn.Sequential(torch.nn.Linear(nfeat, nhid))

        self.graph_convs = torch.nn.ModuleList()
        self.nn1 = torch.nn.Sequential(torch.nn.Linear(nhid, nhid), torch.nn.ReLU(), torch.nn.Linear(nhid, nhid))
        self.graph_convs.append(GINConv(self.nn1))
        for l in range(nlayer - 1):
            self.nnk = torch.nn.Sequential(torch.nn.Linear(nhid, nhid), torch.nn.ReLU(), torch.nn.Linear(nhid, nhid))
            self.graph_convs.append(GINConv(self.nnk))

        self.post = torch.nn.Sequential(torch.nn.Linear(nhid, nhid), torch.nn.ReLU())
        self.readout = torch.nn.Sequential(torch.nn.Linear(nhid, nclass))

    def forward(self, data):
        x, edge_index = data.x, data.edge_index
        x = self.pre(x)
        for i in range(len(self.graph_convs)):
            x = self.graph_convs[i](x, edge_index)
            x = F.relu(x)
            x = F.dropout(x, self.dropout, training=self.training)
        # x = global_add_pool(x, batch)
        x = self.post(x)
        x = F.dropout(x, self.dropout, training=self.training)
        x = self.readout(x)
        x = F.log_softmax(x, dim=1)
        return x

    def loss(self, pred, label):
        return F.cross_entropy(pred, label)
    
model = GIN( nfeat = 128, nhid = 64, nclass = 40, nlayer = 3, dropout = 0.5).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.1)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=10)

for epoch in range(50):
    model.train()
    loss = torch.tensor([0.0], device='cuda:0')
    optimizer.zero_grad()
    out = model(data)

    for i in range(out.shape[0]) : loss += model.loss(out[i], data.y[i])
    acc_sum = out.max(dim=1)[1].eq(data.y.max(dim=1)[1]).sum().item()
    print(f'epoch-{epoch}',loss.item(), acc_sum, acc_sum*100/data.num_nodes)    
    loss.backward()
    optimizer.step()
    # scheduler.step(loss)

epoch-0 734984.875 3710 2.1908198154042386
epoch-1 22825074688.0 27274 16.10577348930868
epoch-2 469514560.0 19319 11.408207011804445
epoch-3 595741.5625 19616 11.583590700530875
epoch-4 1240723.25 18542 10.949374937257518
epoch-5 6899958.0 15475 9.13825785535865
epoch-6 594683.5 18177 10.73383606053985
epoch-7 650151.5 18027 10.645258439970947
epoch-8 566168.0 17766 10.491133380181052
epoch-9 561576.875 17096 10.095486674973278
epoch-10 556037.25 17598 10.391926445143879
epoch-11 550756.4375 18626 10.998978404776105
epoch-12 547767.0625 19549 11.544026030010098
epoch-13 544715.875 20441 12.07076761365985
epoch-14 540745.4375 21482 12.685496300408047
epoch-15 538318.6875 21769 12.854974814429886
epoch-16 536066.0625 21684 12.804780829440839
epoch-17 533345.6875 21761 12.85025067466621
epoch-18 530799.3125 21842 12.898082589773418
epoch-19 528958.5625 21881 12.921112771121333
epoch-20 526862.8125 21817 12.883319653011934
epoch-21 524818.5 21923 12.945914504880626
epoch-22 523051.8125 21

In [None]:
model = GIN_dc(nfeat=data.num_node_features, n_se=n_se, nhid=64, nclass=num_classes, nlayer=2, dropout=0.2).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

for epoch in range(20):
    model.train()
    loss = torch.tensor([0.0], device='cuda:0')
    optimizer.zero_grad()
    out = model(data)

    for i in range(out.shape[0]) : loss += model.loss(out[i], data.y[i])
    acc_sum = out.max(dim=1)[1].eq(data.y.max(dim=1)[1]).sum().item()
    print(f'epoch-{epoch}',loss.item(), acc_sum, acc_sum*100/data.num_nodes)    
    loss.backward()
    optimizer.step()
    # scheduler.step(loss)


In [17]:
!pip install tqdm

[0m

In [19]:
import tqdm
#Define the GCN model
class GCN(torch.nn.Module):
    def __init__(self, num_features, hidden_channels, num_classes):
        super(GCN, self).__init__()
        self.conv1 = GCNConv(num_features, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, num_classes)

    def forward(self, data):
        x, edge_index = data.x, data.edge_index
        x = self.conv1(x, edge_index)
        x = x.relu()
        x = F.dropout(x, training=self.training)
        x = self.conv2(x, edge_index)
        return F.log_softmax(x, dim=1)
    def loss(self, pred, label):
        # print(pred, label)
        return F.cross_entropy(pred, label)
    
model = GCN(data.num_features, 16, 40).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

data.to(device)
for epoch in range(20):
    model.train()
    loss = torch.tensor([0.0], device='cuda:0')
    optimizer.zero_grad()
    out = model(data)

    for i in range(out.shape[0]) : 
        loss += model.loss(out[i], data.y[i])
    acc_sum = out.max(dim=1)[1].eq(data.y.max(dim=1)[1]).sum().item()
    print(f'epoch-{epoch+1}',loss.item(), acc_sum, acc_sum*100/data.num_nodes)    
    loss.backward()
    optimizer.step()

epoch-0 638850.375 3843 2.269358638975334
epoch-1 612070.5 13345 7.880455643280206
epoch-2 594977.375 15323 9.048499199848827
epoch-3 580307.25 18690 11.036771522885505
epoch-4 567277.375 21301 12.578612638254903
epoch-5 557120.5625 22858 13.498048339760132
epoch-6 549456.0625 24322 14.36256591651264
epoch-7 541887.3125 25590 15.111342069055112
epoch-8 535173.875 26065 15.39183786752331
epoch-9 528745.25 25779 15.222949870971933
epoch-10 521939.78125 26617 15.717803511216879
epoch-11 516624.9375 28308 16.71636855376366
epoch-12 511698.5625 32522 19.204809174279422
epoch-13 506232.03125 36700 21.671991165858643
epoch-14 502167.65625 38365 22.655202754173484
epoch-15 497568.0 39734 23.46362117123235
epoch-16 494119.90625 41007 24.21534991112712
epoch-17 491217.71875 42019 24.812953591231995
epoch-18 487619.28125 42990 25.386346055048037
epoch-19 484228.0 44104 26.04418251713977


In [21]:
from ogb.nodeproppred import Evaluator

evaluator = Evaluator(name = 'ogbn-arxiv')
print(evaluator.expected_input_format) 
print(evaluator.expected_output_format) 

==== Expected input format of Evaluator for ogbn-arxiv
{'y_true': y_true, 'y_pred': y_pred}
- y_true: numpy ndarray or torch tensor of shape (num_nodes num_tasks)
- y_pred: numpy ndarray or torch tensor of shape (num_nodes num_tasks)
where y_pred stores predicted class label (integer),
num_task is 1, and each row corresponds to one node.

==== Expected output format of Evaluator for ogbn-arxiv
{'acc': acc}
- acc (float): Accuracy score averaged across 1 task(s)



In [26]:
a = out.max(dim=1)[1].unsqueeze(dim=1)
b = data.y.max(dim=1)[1].unsqueeze(dim=1)
ans = {'y_true':b, 'y_pred': a}
evaluator.eval(ans)
# print(a)

{'acc': 0.2604418251713977}

In [34]:
import torch_geometric.transforms as T
dataset = PygNodePropPredDataset(name='ogbn-arxiv',
                                     transform=T.ToSparseTensor())

data = dataset[0]
data.adj_t = data.adj_t.to_symmetric()

class GCN(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels, num_layers,
                 dropout):
        super(GCN, self).__init__()

        self.convs = torch.nn.ModuleList()
        self.convs.append(GCNConv(in_channels, hidden_channels, cached=True))
        self.bns = torch.nn.ModuleList()
        self.bns.append(torch.nn.BatchNorm1d(hidden_channels))
        for _ in range(num_layers - 2):
            self.convs.append(
                GCNConv(hidden_channels, hidden_channels, cached=True))
            self.bns.append(torch.nn.BatchNorm1d(hidden_channels))
        self.convs.append(GCNConv(hidden_channels, out_channels, cached=True))

        self.dropout = dropout

    def reset_parameters(self):
        for conv in self.convs:
            conv.reset_parameters()
        for bn in self.bns:
            bn.reset_parameters()

    def forward(self, x, adj_t):
        for i, conv in enumerate(self.convs[:-1]):
            x = conv(x, adj_t)
            x = self.bns[i](x)
            x = F.relu(x)
            x = F.dropout(x, p=self.dropout, training=self.training)
        x = self.convs[-1](x, adj_t)
        return x.log_softmax(dim=-1)
data.to(device)
num_layers=3
hidden_channels=256
dropout=0.5
lr=0.01
epochs=500
runs=10
model = GCN(data.num_features, hidden_channels, 40, num_layers, dropout).to(device)
for epoch in range(epochs):
    model.train()

    optimizer.zero_grad()
    out = model(data.x, data.adj_t)
    loss = F.nll_loss(out, data.y.squeeze(1))
    acc_sum = out.max(dim=1)[1].eq(data.y.max(dim=1)[1]).sum().item()
    if epoch %5 == 0 : print(f'epoch-{epoch+1}',loss.item(), acc_sum, acc_sum*100/data.num_nodes)    
    loss.backward()
    optimizer.step()
    
a = out.max(dim=1)[1].unsqueeze(dim=1)
b = data.y.max(dim=1)[1].unsqueeze(dim=1)
ans = {'y_true':b, 'y_pred': a}
evaluator.eval(ans)

epoch-1 4.114811897277832 8855 5.229032200917664
epoch-6 4.114686012268066 9233 5.452247804751304
epoch-11 4.114619731903076 8222 4.855234642116887
epoch-16 4.11920690536499 8475 5.004635562143106
epoch-21 4.119435787200928 9091 5.368394323946074
epoch-26 4.112491130828857 8074 4.767838056488901
epoch-31 4.118819713592529 8904 5.257967556970173
epoch-36 4.113580226898193 8571 5.061325239307205
epoch-41 4.1190900802612305 9010 5.320562408838866
epoch-46 4.118008136749268 8139 4.8062216920687595
epoch-51 4.12195348739624 8571 5.061325239307205
epoch-56 4.126706600189209 9045 5.3412305203049435
epoch-61 4.118515968322754 8033 4.743626840200068
epoch-66 4.1193060874938965 9692 5.7232953236921515
epoch-71 4.115555763244629 8967 5.295170157609113
epoch-76 4.122092247009277 7488 4.4217948187997145
epoch-81 4.113526821136475 7981 4.71291993173618
epoch-86 4.109722137451172 8412 4.967432961504166
epoch-91 4.119235038757324 8991 5.3093425769001374
epoch-96 4.111802101135254 8209 4.84755791500091

{'acc': 0.04508010369486781}

Data(num_nodes=169343, x=[169343, 128], node_year=[169343, 1], y=[169343, 1], adj_t=[169343, 169343, nnz=2315598])