In [1]:
import torch
from torch.nn import Linear, Parameter
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import add_self_loops, degree

In [37]:
import torch
from torch.nn import Linear, Parameter
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import add_self_loops, degree

class GCNConv(MessagePassing):
    def __init__(self, in_channels, out_channels):
        super().__init__(aggr='add')  # "Add" aggregation (Step 5).
        self.lin = Linear(in_channels, out_channels, bias=False)
        self.bias = Parameter(torch.empty(out_channels))

        self.reset_parameters()

    def reset_parameters(self):
        self.lin.reset_parameters()
        self.bias.data.zero_()

    def forward(self, x, edge_index):
        # x has shape [N, in_channels]
        # N is the number of nodes in the graph, and 
        # in_channels is the dimensionality of the input features for each node
        # edge_index has shape [2, E]
        # E is the number of edges in the graph
        # first row contains the indicies of the source nodes of each graph
        # second row contains the indices of the target nodes of each edge, which do the aggregation

        # Step 1: Add self-loops to the adjacency matrix.
        # self loop is added to ensure that each node can send and receive messages
        # Specify the number of nodes in the graph
        num_nodes = x.size(0)
        print("num nodes: " + str(num_nodes))
        edge_index, _ = add_self_loops(edge_index, num_nodes=num_nodes)

        # Step 2: Linearly transform node feature matrix.
        # output is transformed feature matrix
        x = self.lin(x)

        # Step 3: Compute normalization. of the GCN Layer
        row, col = edge_index
        # calculates the degree of each node
        deg = degree(col, x.size(0), dtype=x.dtype)
        deg_inv_sqrt = deg.pow(-0.5)
        # converts any infinite values of dividing by zero to 0
        deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0
        # shape is [numb edges, ]
        norm = deg_inv_sqrt[row] * deg_inv_sqrt[col]

        # Step 4-5: Start propagating messages.
        # aggregates messages from its neighboring nodes 
        # combines thhem using a specific method (add, mean, etc)
        out = self.propagate(edge_index, x=x, norm=norm)

        # Step 6: Apply a final bias vector.
        out += self.bias

        return out

    def message(self, x_j, norm):
        # x_j has shape [E, out_channels]
        # Lifting a node feature - create a copy of the node, but modify to capture information 
        # related to the edges connected to that node
        # Step 4: Normalize node features.
        return norm.view(-1, 1) * x_j

In [38]:
# edge index needs to be 2,e
edge_index = torch.tensor([[0, 1],
                           [1, 0],
                           [1, 2],
                           [2, 1]], dtype=torch.long).t().contiguous()
x = torch.tensor([[-1], [0], [1]], dtype=torch.float)

In [39]:
# first number is the X size
conv = GCNConv(1, 32)
x = conv(x, edge_index)

num nodes: 3
