In [1]:
import torch
from torch import Tensor
from torch.nn import Parameter
from torch_scatter import scatter_add
from torch_sparse import SparseTensor, fill_diag, matmul, mul
from torch_sparse import sum as sparsesum
from torch_geometric.nn.conv import MessagePassing
from torch_geometric.nn.dense.linear import Linear
from torch_geometric.nn.inits import zeros
from torch_geometric.typing import Adj, OptTensor, PairTensor
from torch_geometric.utils import add_remaining_self_loops
from torch_geometric.utils.num_nodes import maybe_num_nodes
from torch_geometric.datasets import Planetoid
import torch.nn.functional as F
from typing import Optional, Tuple

def gcn_norm(edge_index, edge_weight=None, num_nodes=None, improved=False, add_self_loops=True, flow="source_to_target", dtype=None):
    """Rescale edge weights by the inverse square root of both endpoint degrees.

    Each returned weight is ``deg[row]^-1/2 * w * deg[col]^-1/2``, where
    ``deg`` is the per-node sum of edge weights.

    Args:
        edge_index: Index tensor; ``edge_index[0]`` and ``edge_index[1]`` are
            read as the two endpoint rows and ``edge_index.size(1)`` as the
            number of edges.
        edge_weight: Optional per-edge weights. When ``None``, a vector of
            ones (one entry per edge) is used.
        num_nodes: Node count, passed through ``maybe_num_nodes``
            (presumably inferred from ``edge_index`` when ``None`` — confirm
            against the torch_geometric helper).
        improved: Selects the self-loop fill value: ``2.`` if true, else ``1.``.
        add_self_loops: When true, ``add_remaining_self_loops`` is applied to
            the edges before degrees are computed.
        flow: ``"source_to_target"`` accumulates degrees over
            ``edge_index[1]``; any other value accumulates over
            ``edge_index[0]``.
        dtype: dtype of the default weights created when ``edge_weight`` is
            ``None``.

    Returns:
        Tuple ``(edge_index, normalized_edge_weight)``; ``edge_index`` is the
        possibly self-loop-augmented index tensor.
    """
    if improved:
        self_loop_weight = 2.
    else:
        self_loop_weight = 1.

    num_nodes = maybe_num_nodes(edge_index, num_nodes)

    # Unweighted input: give every edge a unit weight.
    if edge_weight is None:
        n_edges = edge_index.size(1)
        edge_weight = torch.ones((n_edges, ), dtype=dtype,
                                 device=edge_index.device)

    if add_self_loops:
        edge_index, edge_weight = add_remaining_self_loops(
            edge_index, edge_weight, self_loop_weight, num_nodes)

    row_idx, col_idx = edge_index[0], edge_index[1]

    # The flow direction decides which endpoint the degree is accumulated on.
    if flow == "source_to_target":
        degree_idx = col_idx
    else:
        degree_idx = row_idx

    # Weighted degree per node, then d -> d^-1/2 in place.
    norm = scatter_add(edge_weight, degree_idx, dim=0, dim_size=num_nodes)
    norm.pow_(-0.5)
    # Zero-degree nodes yield inf after the power; zero them out instead.
    norm.masked_fill_(norm == float('inf'), 0)

    return edge_index, norm[row_idx] * edge_weight * norm[col_idx]


class GCNConv(MessagePassing):
    """Graph convolution layer: linear transform, weighted propagation, bias.

    ``forward`` optionally normalizes the edge weights with ``gcn_norm``,
    applies a bias-free linear map to the node features, propagates them
    along the edges, and finally adds a learnable bias.

    Args:
        in_channels: Input feature size passed to the linear layer.
        out_channels: Output feature size (also the bias length).
        improved: Forwarded to ``gcn_norm`` (self-loop fill value 2 instead
            of 1). Defaults to ``False``.
        add_self_loops: Forwarded to ``gcn_norm``. Only takes effect when
            ``normalize`` is true, because self-loops are added inside
            ``gcn_norm``. Defaults to ``True``.
        normalize: Whether to run ``gcn_norm`` on the edges in ``forward``.
            Defaults to ``True``.
        bias: Whether to create and add a learnable bias. Defaults to
            ``True``.
        **kwargs: Forwarded to ``MessagePassing``; ``aggr`` defaults to
            ``'add'`` when not supplied.
    """

    def __init__(self, in_channels, out_channels, *, improved=False,
                 add_self_loops=True, normalize=True, bias=True, **kwargs):
        kwargs.setdefault('aggr', 'add')
        super().__init__(**kwargs)
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.improved = improved
        self.add_self_loops = add_self_loops
        self.normalize = normalize

        self.lin = Linear(in_channels, out_channels, bias=False, weight_initializer='glorot')
        if bias:
            # Allocated uninitialized; reset_parameters() zeroes it below.
            self.bias = Parameter(torch.empty(out_channels))
        else:
            # Register the name so ``self.bias is None`` checks work.
            self.register_parameter('bias', None)

        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialize the linear weights and zero the bias (if any)."""
        self.lin.reset_parameters()
        if self.bias is not None:
            zeros(self.bias)

    def forward(self, x, edge_index, edge_weight=None):
        """Run the layer on node features ``x`` over ``edge_index``.

        Args:
            x: Node feature tensor; its size along ``self.node_dim`` is used
                as the node count for normalization.
            edge_index: Edge indices handed to ``gcn_norm`` / ``propagate``.
            edge_weight: Optional per-edge weights.

        Returns:
            The propagated features, with the bias added when present.
        """
        if self.normalize:
            edge_index, edge_weight = gcn_norm(edge_index, edge_weight,
                                               x.size(self.node_dim), self.improved,
                                               self.add_self_loops, self.flow)

        x = self.lin(x)

        # propagate_type: (x: Tensor, edge_weight: OptTensor)
        out = self.propagate(edge_index, x=x, edge_weight=edge_weight, size=None)

        if self.bias is not None:
            # Out-of-place add: `+=` would overwrite the tensor returned by
            # propagate(), which can break autograd if the selected
            # aggregation keeps that tensor for its backward pass.
            out = out + self.bias

        return out

    def message(self, x_j, edge_weight):
        """Scale each per-edge feature row ``x_j`` by its edge weight, if any."""
        return x_j if edge_weight is None else edge_weight.view(-1, 1)*x_j

    def message_and_aggregate(self, adj_t, x):
        """Fused message + aggregation as a sparse matmul with ``self.aggr``.

        NOTE(review): ``gcn_norm`` indexes ``edge_index`` like a dense tensor,
        so whether a sparse adjacency works with ``normalize=True`` is not
        established by this file — confirm before relying on this path.
        """
        return matmul(adj_t, x, reduce=self.aggr)


In [2]:
# Load the 'Cora' dataset through Planetoid, keeping its files under /tmp/Cora.
# NOTE(review): presumably downloads on first use — confirm network access.
dataset = Planetoid(name='Cora', root='/tmp/Cora')

In [3]:
class GCN(torch.nn.Module):
    """Two-layer GCN classifier producing per-node log-probabilities.

    Input and output widths are read from the module-level ``dataset``
    (``num_node_features`` and ``num_classes``), so ``dataset`` must be
    defined before this class is instantiated.

    Args:
        hidden_channels: Width of the hidden representation between the two
            convolutions. Defaults to ``16``.
        dropout: Dropout probability applied after the first layer while in
            training mode. Defaults to ``0.5`` (the ``F.dropout`` default).
    """

    def __init__(self, hidden_channels=16, dropout=0.5):
        super().__init__()
        self.dropout = dropout
        self.conv1 = GCNConv(dataset.num_node_features, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, dataset.num_classes)

    def forward(self, data):
        """Return log-softmax class scores for every node in ``data``.

        Args:
            data: Object exposing ``x`` (node features) and ``edge_index``.
        """
        x, edge_index = data.x, data.edge_index

        x = self.conv1(x, edge_index)
        x = F.relu(x)
        # Only active in training mode; a no-op after model.eval().
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = self.conv2(x, edge_index)

        return F.log_softmax(x, dim=1)

In [4]:
# Training configuration, kept in one place so a reader can tune it.
SEED = 42
NUM_EPOCHS = 500
LEARNING_RATE = 0.01
WEIGHT_DECAY = 5e-4

# Seed before the model is built: weight init and dropout are both random,
# so without this every Restart & Run All gives different losses/accuracy.
# NOTE(review): GPU scatter kernels may still be non-deterministic — confirm
# if bit-exact reproducibility is required.
torch.manual_seed(SEED)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = GCN().to(device)
data = dataset[0].to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)

model.train()
for epoch in range(NUM_EPOCHS):
    optimizer.zero_grad()
    out = model(data)
    # The loss is computed on the training nodes only.
    loss = F.nll_loss(out[data.train_mask], data.y[data.train_mask])
    print("NLL Loss: {}".format(loss.item()))
    loss.backward()
    optimizer.step()

NLL Loss: 1.9540339708328247
NLL Loss: 1.8549575805664062
NLL Loss: 1.722895860671997
NLL Loss: 1.6031818389892578
NLL Loss: 1.4577544927597046
NLL Loss: 1.3632465600967407
NLL Loss: 1.214414358139038
NLL Loss: 1.1096059083938599
NLL Loss: 1.0113153457641602
NLL Loss: 0.9122030735015869
NLL Loss: 0.8009372353553772
NLL Loss: 0.738771378993988
NLL Loss: 0.6319516897201538
NLL Loss: 0.5524605512619019
NLL Loss: 0.48144006729125977
NLL Loss: 0.45087242126464844
NLL Loss: 0.42074069380760193
NLL Loss: 0.37666186690330505
NLL Loss: 0.30514439940452576
NLL Loss: 0.2766708433628082
NLL Loss: 0.21188490092754364
NLL Loss: 0.20924663543701172
NLL Loss: 0.18866832554340363
NLL Loss: 0.16344475746154785
NLL Loss: 0.1463128626346588
NLL Loss: 0.16747958958148956
NLL Loss: 0.09954558312892914
NLL Loss: 0.11196306347846985
NLL Loss: 0.1226775050163269
NLL Loss: 0.0986974909901619
NLL Loss: 0.11168782413005829
NLL Loss: 0.11779920011758804
NLL Loss: 0.10296595096588135
NLL Loss: 0.12333191186189651
N

NLL Loss: 0.030006807297468185
NLL Loss: 0.03098428249359131
NLL Loss: 0.023853838443756104
NLL Loss: 0.017495010048151016
NLL Loss: 0.02384697087109089
NLL Loss: 0.02303258329629898
NLL Loss: 0.012917105108499527
NLL Loss: 0.020390067249536514
NLL Loss: 0.023172516375780106
NLL Loss: 0.024942034855484962
NLL Loss: 0.02567211724817753
NLL Loss: 0.03368431702256203
NLL Loss: 0.02081015706062317
NLL Loss: 0.02715565823018551
NLL Loss: 0.02164369821548462
NLL Loss: 0.027976712211966515
NLL Loss: 0.03575754165649414
NLL Loss: 0.04442604258656502
NLL Loss: 0.018933819606900215
NLL Loss: 0.029846396297216415
NLL Loss: 0.02199404314160347
NLL Loss: 0.02686881832778454
NLL Loss: 0.0423690490424633
NLL Loss: 0.01739085279405117
NLL Loss: 0.02086290717124939
NLL Loss: 0.020008353516459465
NLL Loss: 0.021676957607269287
NLL Loss: 0.02381288632750511
NLL Loss: 0.021141819655895233
NLL Loss: 0.029934167861938477
NLL Loss: 0.015182336792349815
NLL Loss: 0.0244023147970438
NLL Loss: 0.022857712581753

In [5]:
# Evaluate on the held-out test nodes.
model.eval()
# Inference needs no gradients; skipping autograd avoids building a graph.
with torch.no_grad():
    pred = model(data).argmax(dim=1)
correct = (pred[data.test_mask] == data.y[data.test_mask]).sum()
acc = int(correct) / int(data.test_mask.sum())
print(f'Accuracy: {acc:.4f}')

Accuracy: 0.7990
