In [1]:
import os.path as osp
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch_geometric.transforms as T
import torch_geometric
from torch_geometric.datasets import Planetoid, TUDataset
from torch_geometric.data import DataLoader
from torch_geometric.nn.inits import uniform
from torch.nn import Parameter as Param
from torch import Tensor

# Seed PyTorch's global RNG so random initialisation repeats between runs.
torch.manual_seed(42)

# Run on the GPU when CUDA is usable; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

from torch_geometric.nn.conv import MessagePassing

  from .autonotebook import tqdm as notebook_tqdm


In [4]:
# Load Cora with a random node split plus the TargetIndegree transform.
# NOTE: the following cell reloads Cora and rebinds `dataset`, `path` and
# `data`; only `transform` from this cell survives it.
dataset = 'Cora'
path = osp.join('data', dataset)

split_nodes = T.RandomNodeSplit(split='train_rest', num_val=500, num_test=500)
transform = T.Compose([split_nodes, T.TargetIndegree()])

dataset = Planetoid(root=path, name=dataset, transform=transform)
data = dataset[0]

Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.x
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.tx
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.allx
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.y
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.ty
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.ally
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.graph
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.test.index
Processing...
Done!


In [5]:
# Reload Cora with NormalizeFeatures as the only transform and place the
# (single) graph object on the selected device.
dataset = 'Cora'
path = osp.join('data', dataset)
dataset = Planetoid(root=path, name=dataset, transform=T.NormalizeFeatures())
data = dataset[0].to(device)

In [6]:
class MLP(nn.Module):
    """Fully connected network with Tanh activations between layers.

    The network is a stack of ``nn.Linear`` layers whose sizes are
    ``input_dim -> hid_dims[0] -> ... -> hid_dims[-1] -> out_dims``. A
    ``nn.Tanh`` follows every layer except the last one, so the output is the
    raw result of the final linear layer.

    Args:
        input_dim: Size of the last dimension of the input.
        hid_dims: Sequence (list, tuple, ...) of hidden layer widths; may be
            empty, in which case the network is a single linear layer.
        out_dims: Size of the last dimension of the output.
    """

    def __init__(self, input_dim, hid_dims, out_dims):
        super().__init__()
        self.mlp = nn.Sequential()
        # list(...) lets callers pass any sequence (a tuple used to raise
        # TypeError on the concatenation) and avoids aliasing their list.
        dims = [input_dim] + list(hid_dims) + [out_dims]
        last = len(dims) - 2  # index of the final linear layer
        for i, (n_in, n_out) in enumerate(zip(dims[:-1], dims[1:])):
            self.mlp.add_module('lay_{}'.format(i), nn.Linear(in_features=n_in, out_features=n_out))
            if i < last:
                # Same naming scheme as the linear layers (no embedded space).
                self.mlp.add_module('act_{}'.format(i), nn.Tanh())

    def reset_parameters(self):
        """Re-initialise every linear layer's weight with Xavier-normal.

        Biases are left untouched.
        """
        for layer in self.mlp:
            if isinstance(layer, nn.Linear):
                nn.init.xavier_normal_(layer.weight)

    def forward(self, x):
        """Apply the layer stack to ``x`` and return the result."""
        return self.mlp(x)

In [7]:
from torch import Tensor


class GNNM(MessagePassing):
    """Iterative state-propagation GNN with one learnable state per node.

    Each node owns a trainable state vector (initialised with ``torch.rand``).
    ``forward`` repeatedly aggregates neighbour states along the edges, feeds
    the aggregate through the ``transition`` MLP, and stops as soon as every
    node's state moved by less than ``eps`` (L2 norm) or after ``num_layers``
    rounds. The final states are mapped by the ``read_out`` MLP and returned
    as log-probabilities.

    Note that the model never reads input node features; it only uses the
    graph connectivity (and optional edge weights) plus its own states.

    Args:
        n_nodes: Number of nodes, i.e. number of rows of the state matrix.
        out_channels: Size of the read-out output (number of classes).
        features_dim: Dimensionality of each node state.
        hid_dims: Hidden layer widths shared by both MLPs.
        num_layers: Maximum number of propagation rounds.
        eps: Convergence threshold on the per-node state change.
        aggr: Aggregation scheme forwarded to ``MessagePassing``.
        bias: Accepted for interface compatibility; not used by this class.
        **kwargs: Extra arguments forwarded to ``MessagePassing``.
    """

    def __init__(self, n_nodes, out_channels, features_dim, hid_dims, num_layers=50, eps=1e-3, aggr='add', bias=True, **kwargs):
        super().__init__(aggr=aggr, **kwargs)

        self.node_states = Param(torch.rand((n_nodes, features_dim)), requires_grad=True)
        self.out_channels = out_channels
        self.eps = eps
        self.num_layers = num_layers

        # transition: state -> state; read_out: state -> class scores.
        self.transition = MLP(features_dim, hid_dims, features_dim)
        self.read_out = MLP(features_dim, hid_dims, out_channels)

        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise the weights of both MLPs (node states are kept)."""
        self.transition.reset_parameters()
        self.read_out.reset_parameters()

    def forward(self, edge_index=None, edge_weight=None):
        """Propagate node states to (approximate) convergence and classify.

        Args:
            edge_index: Graph connectivity passed to ``propagate``. When
                omitted, the notebook-level ``data.edge_index`` is used so
                that existing ``model()`` calls keep working.
            edge_weight: Optional per-edge weights. Only when ``edge_index``
                is omitted does this fall back to ``data.edge_attr``.

        Returns:
            Log-softmax over the last dimension of the read-out, one row per
            node state.
        """
        if edge_index is None:
            # Backward-compatible fallback to the global graph object.
            edge_index = data.edge_index
            if edge_weight is None:
                edge_weight = data.edge_attr

        node_states = self.node_states
        for _ in range(self.num_layers):
            m = self.propagate(edge_index, x=node_states, edge_weight=edge_weight, size=None)
            new_states = self.transition(m)
            # The stopping test is bookkeeping only; keep it out of autograd.
            with torch.no_grad():
                distance = torch.norm(new_states - node_states, dim=1)
                converged = bool((distance < self.eps).all())
            node_states = new_states
            if converged:
                break

        out = self.read_out(node_states)
        return F.log_softmax(out, dim=-1)

    def message(self, x_j, edge_weight):
        """Return the neighbour state, scaled by ``edge_weight`` if given."""
        return x_j if edge_weight is None else edge_weight.view(-1, 1) * x_j

    def message_and_aggregate(self, adj_t, x):
        """Fused message + aggregation for a sparse adjacency ``adj_t``.

        ``torch.matmul`` has no ``reduce`` argument, so the sparse product is
        delegated to PyG's ``spmm``; on PyG releases without it, the sparse
        tensor's own ``matmul(other, reduce=...)`` method is used instead.
        """
        try:
            from torch_geometric.utils import spmm
        except ImportError:
            # NOTE(review): assumes adj_t is a torch_sparse.SparseTensor here.
            return adj_t.matmul(x, reduce=self.aggr)
        return spmm(adj_t, x, reduce=self.aggr)

    def __repr__(self):
        return '{}({}, num_layers={})'.format(self.__class__.__name__, self.out_channels, self.num_layers)

In [9]:
# Model: one 32-dimensional learnable state per node, five hidden layers of
# width 64 in each MLP, convergence threshold 0.01.
model = GNNM(data.num_nodes, dataset.num_classes, 32, [64, 64, 64, 64, 64], eps=0.01).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
# GNNM.forward already returns log-probabilities (log_softmax), so the matching
# criterion is NLLLoss; CrossEntropyLoss would apply log_softmax a second time.
# The loss value is mathematically the same.
loss_fn = nn.NLLLoss()

# NOTE(review): these splits/loaders slice the dataset by graph index and are
# not used by the train()/test() code in this notebook, which works on node
# masks of `data`. If the dataset holds fewer than 10 graphs, `test_dataset`
# is empty — confirm whether they are still needed.
test_dataset = dataset[:len(dataset)//10]
train_dataset = dataset[len(dataset) // 10:]
test_loader = DataLoader(test_dataset)
train_loader = DataLoader(train_dataset)



In [10]:
def train():
    """Run one optimisation step on the training nodes.

    Uses the module-level ``model``, ``optimizer``, ``loss_fn`` and ``data``:
    the loss is computed only on rows selected by ``data.train_mask``.
    """
    model.train()
    optimizer.zero_grad()

    train_idx = data.train_mask
    predictions = model()[train_idx]
    targets = data.y[train_idx]

    loss = loss_fn(predictions, targets)
    loss.backward()
    optimizer.step()

@torch.no_grad()  # evaluation only: skip building the autograd graph
def test():
    """Evaluate the module-level ``model`` on the three node masks.

    Returns:
        A list ``[train_acc, val_acc, test_acc]`` where each entry is the
        fraction of nodes selected by the corresponding mask of ``data``
        whose arg-max prediction equals ``data.y``.
    """
    model.eval()
    logits = model()
    accs = []
    for _, mask in data('train_mask', 'val_mask', 'test_mask'):
        pred = logits[mask].argmax(dim=1)
        correct = pred.eq(data.y[mask]).sum().item()
        accs.append(correct / mask.sum().item())
    return accs

# Ten epochs: one optimisation step, then report accuracy on each split.
for epoch in range(1, 11):
    train()
    accs = test()
    train_acc, val_acc, test_acc = accs[0], accs[1], accs[2]
    summary = 'Epoch: {:03d}, Train Acc: {:.5f}, Val Acc: {:.5f}, Test Acc: {:.5f}'.format(epoch, train_acc, val_acc, test_acc)
    print(summary)

Epoch: 001, Train Acc: 0.13571, Val Acc: 0.08400, Test Acc: 0.09700
Epoch: 002, Train Acc: 0.09286, Val Acc: 0.05000, Test Acc: 0.04200
Epoch: 003, Train Acc: 0.12143, Val Acc: 0.09000, Test Acc: 0.09400
Epoch: 004, Train Acc: 0.20714, Val Acc: 0.14000, Test Acc: 0.15600
Epoch: 005, Train Acc: 0.14286, Val Acc: 0.26200, Test Acc: 0.26000
Epoch: 006, Train Acc: 0.18571, Val Acc: 0.26600, Test Acc: 0.26100
Epoch: 007, Train Acc: 0.23571, Val Acc: 0.20200, Test Acc: 0.20000
Epoch: 008, Train Acc: 0.13571, Val Acc: 0.06000, Test Acc: 0.06400
Epoch: 009, Train Acc: 0.13571, Val Acc: 0.05800, Test Acc: 0.06100
Epoch: 010, Train Acc: 0.12857, Val Acc: 0.05400, Test Acc: 0.06000
