In [1]:
import torch
from torch.nn import Linear, Parameter
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import add_self_loops, degree

In [23]:
class GCNConv(MessagePassing):
    def __init__(self, in_channels, out_channels):
        super().__init__(aggr='add')  # "Add" aggregation (Step 5).
        self.lin = Linear(in_channels, out_channels, bias=False)
        self.bias = Parameter(torch.Tensor(out_channels))

        self.reset_parameters()

    def reset_parameters(self):
        self.lin.reset_parameters()
        self.bias.data.zero_()

    def forward(self, x, edge_index):
        # x has shape [N, in_channels]
        # edge_index has shape [2, E]

        # Step 1: Add self-loops to the adjacency matrix.
        edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0))

        # Step 2: Linearly transform node feature matrix.
        x = self.lin(x)

        # Step 3: Compute normalization.
        row, col = edge_index
        deg = degree(col, x.size(0), dtype=x.dtype)
        deg_inv_sqrt = deg.pow(-0.5)
        deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0
        norm = deg_inv_sqrt[row] * deg_inv_sqrt[col]
        print("a", norm.size())

        # Step 4-5: Start propagating messages.
        out = self.propagate(edge_index, x=x, norm=norm)

        # Step 6: Apply a final bias vector.
        out += self.bias

        return out

    def message(self, x_j, norm):
        # x_j has shape [E, out_channels]
        print(x_j.size())
        print(norm.size())

        # Step 4: Normalize node features.
        return norm.view(-1, 1) * x_j

In [3]:
from torch_geometric.datasets import TUDataset
dataset = TUDataset(root='/tmp/ENZYMES', name='ENZYMES')

Downloading https://www.chrsmrrs.com/graphkerneldatasets/ENZYMES.zip
Extracting /tmp/ENZYMES/ENZYMES/ENZYMES.zip
Processing...
Done!


In [20]:
data = dataset[0]
data

Data(edge_index=[2, 168], x=[37, 3], y=[1])

In [24]:
conv = GCNConv(3, 16)
x, edge_index = data.x, data.edge_index
x = conv(x, edge_index)
x

tensor([0.5000, 0.4082, 0.4082, 0.3780, 0.4472, 0.4082, 0.4472, 0.4082, 0.5000,
        0.4082, 0.4472, 0.4472, 0.4082, 0.4472, 0.4472, 0.4472, 0.4472, 0.4472,
        0.5000, 0.5000, 0.3780, 0.4082, 0.4472, 0.4472, 0.4472, 0.4082, 0.4082,
        0.4082, 0.3780, 0.3536, 0.4082, 0.5000, 0.4472, 0.3780, 0.3780, 0.4082,
        0.4472])
a torch.Size([205])
torch.Size([205, 16])
torch.Size([205])


tensor([[-0.3964,  0.2050, -0.4111, -0.2535,  0.1883,  0.0435,  0.2698, -0.1096,
         -0.1150,  0.4174,  0.1600, -0.1886, -0.1801, -0.4785, -0.3219, -0.1456],
        [-0.4371,  0.1925, -0.4458, -0.3130,  0.3232,  0.2261,  0.2209,  0.0347,
         -0.0507,  0.5021,  0.1395, -0.0516, -0.3287, -0.2221, -0.4494, -0.2300],
        [-0.4279,  0.1904, -0.4369, -0.3044,  0.3095,  0.2107,  0.2208,  0.0247,
         -0.0542,  0.4890,  0.1388, -0.0599, -0.3140, -0.2358, -0.4343, -0.2210],
        [-0.4973,  0.2434, -0.5127, -0.3317,  0.2835,  0.1274,  0.3072, -0.0739,
         -0.1131,  0.5408,  0.1857, -0.1727, -0.2792, -0.4754, -0.4425, -0.2111],
        [-0.4030,  0.1932, -0.4146, -0.2729,  0.2438,  0.1249,  0.2396, -0.0410,
         -0.0825,  0.4433,  0.1460, -0.1210, -0.2420, -0.3482, -0.3701, -0.1795],
        [-0.4459,  0.2167, -0.4593, -0.2989,  0.2596,  0.1225,  0.2718, -0.0590,
         -0.0979,  0.4868,  0.1647, -0.1476, -0.2564, -0.4120, -0.4011, -0.1925],
        [-0.4627,  0.2