In [2]:
import torch
import numpy as np
import matplotlib.pyplot as plt
import torch_geometric as tgn
from torch.nn import Sequential as Seq, Linear, ReLU
from torch_geometric.utils import add_self_loops, degree
from torch_geometric.nn import MessagePassing
from torch_geometric.data import Data
import torch.nn.functional as F

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
class GCNConv(MessagePassing):
    def __init__(self, in_channels, out_channels):
        super().__init__(aggr='add')  # "Add" aggregation (Step 5).
        self.lin = torch.nn.Linear(in_channels, out_channels)

    def forward(self, x, edge_index):
        # x has shape [N, in_channels]
        # edge_index has shape [2, E]

        # Step 1: Add self-loops to the adjacency matrix.
        edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0))

        # Step 2: Linearly transform node feature matrix.
        x = self.lin(x)

        # Step 3: Compute normalization.
        row, col = edge_index
        deg = degree(col, x.size(0), dtype=x.dtype)
        deg_inv_sqrt = deg.pow(-0.5)
        deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0
        norm = deg_inv_sqrt[row] * deg_inv_sqrt[col]

        # Step 4-5: Start propagating messages.
        return self.propagate(edge_index, x=x, norm=norm)

    def message(self, x_j, norm):
        # x_j has shape [E, out_channels]

        # Step 4: Normalize node features.
        return norm.view(-1, 1) * x_j

In [4]:
edge_index = torch.tensor([[0, 1],
                           [1, 0],
                           [1, 2],
                           [2, 1]], dtype=torch.long)

x = torch.tensor([[-1,-1], [0,-1], [1,3]], dtype=torch.float)
y = torch.tensor([0,1,1]).float()
data = Data(x=x, edge_index=edge_index.t().contiguous(), y=y)

In [5]:
model = GCNConv(2,1)
optimizer = torch.optim.Adam(model.parameters())
L = torch.nn.MSELoss()

In [6]:
model.train()
epochs = 100
for epoch in range(100):
            optimizer.zero_grad()
            out = model(data.x, data.edge_index)
            loss = L(out, data.y)
            loss.backward()
            optimizer.step()

  return F.mse_loss(input, target, reduction=self.reduction)
