In this week, you are required to implement a toy GATConv and SAGEConv based on document. Also, you need to implement both in PyG and DGL. In this work, you will get a further understanding of tensor-centric in PyG and graph-centric in DGL.

In [None]:
!pip install  dgl -f https://data.dgl.ai/wheels/repo.html
!pip install torch_geometric
!pip install pyg_lib torch_scatter torch_sparse torch_cluster torch_spline_conv -f https://data.pyg.org/whl/torch-2.0.1+cpu.html

In [5]:
import torch
import numpy as np
import torch.nn as nn
import dgl
import dgl.function as fn
from torch_geometric.nn.conv import MessagePassing
from torch_geometric.utils import add_self_loops, softmax
class PyG_GATConv(MessagePassing):
    def __init__(self, in_channel, out_channel):
        super(PyG_GATConv, self).__init__(aggr='add')
        self.lin = nn.Linear(in_channel, out_channel, bias=False)
        self.att = nn.Parameter(torch.Tensor(1, out_channel))
        self.reset_parameters()

    def reset_parameters(self):
        nn.init.xavier_uniform_(self.lin.weight)
        nn.init.xavier_uniform_(self.att)

    def forward(self, x, edge_index):
        edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0))
        x = self.lin(x)
        alpha = (x @ self.att.t()).squeeze(-1)
        ptr = torch.arange(x.size(0) + 1, dtype=torch.long, device=edge_index.device)
        return self.propagate(edge_index, x=x, alpha=alpha, ptr=ptr)


    def message(self, x_j, alpha_j, index, ptr, size_i):
        alpha = softmax(alpha_j, index)
        return x_j * alpha.view(-1, 1)


    def update(self, aggr_out):
        return aggr_out


class PyG_SAGEConv(MessagePassing):
    def __init__(self, in_channels, out_channels):
        super(PyG_SAGEConv, self).__init__(aggr='mean')
        self.lin = nn.Linear(in_channels, out_channels)  # Set the output dimension of the linear layer
        self.update_lin = nn.Linear(in_channels , out_channels)

    def forward(self, x, edge_index):
        edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0))
        x = self.lin(x)
        return self.propagate(edge_index, x=x)

    def message(self, x_j):
        return x_j

    def update(self, aggr_out, x):
        combined = torch.cat([x, aggr_out], dim=-1)
        return self.update_lin(combined)


In [6]:
class DGL_GATConv(nn.Module):
    def __init__(self, in_channel, out_channel):
        super(DGL_GATConv, self).__init__()
        self.in_channel = in_channel
        self.out_channel = out_channel
        self.W = nn.Parameter(torch.Tensor(in_channel, out_channel))
        self.att = nn.Parameter(torch.Tensor(1, out_channel))
        self.reset_parameters()

    def reset_parameters(self):
        nn.init.xavier_uniform_(self.W)
        nn.init.xavier_uniform_(self.att)

    def forward(self, g, h):
        with g.local_scope():
            h = torch.matmul(h, self.W)
            g.ndata['h'] = h
            g.ndata['a'] = torch.matmul(h, self.att.t())
            g.apply_edges(fn.u_add_v('a', 'a', 'e'))
            e = torch.exp(g.edata.pop('e'))
            g.edata['a'] = dgl.ops.edge_softmax(g, e)
            g.update_all(fn.u_mul_e('h', 'a', 'm'), fn.sum('m', 'h'))
            return g.ndata.pop('h')


class DGL_SAGEConv(nn.Module):
    def __init__(self, in_channel, out_channel):
        super(DGL_SAGEConv, self).__init__()
        self.in_channel = in_channel
        self.out_channel = out_channel
        self.W = nn.Linear(in_channel, out_channel)
        self.update_lin = nn.Linear(in_channel + out_channel, out_channel)

    def forward(self, g, h):
        with g.local_scope():
            h_in = h
            h = self.W(h)
            g.ndata['h'] = h
            g.update_all(fn.copy_u('h', 'm'), fn.mean('m', 'h'))
            h = g.ndata.pop('h')
            return self.update_lin(torch.cat([h_in, h], dim=1))

If you want to check your answer, you can run the following code.

In [7]:
edge_index = torch.tensor([[0,1,1,2,2,4],[2,0,2,3,4,3]])
x = torch.ones((5, 8))
conv = PyG_GATConv(8, 4)
output = conv(x, edge_index)
print(output)
conv = PyG_SAGEConv(8, 4)
output = conv(x, edge_index)
print(output)

src = torch.tensor([0, 1, 1, 2, 2, 4])
dst = torch.tensor([2, 0, 2, 3, 4, 3])
h = torch.ones((5, 8))
g = dgl.graph((src, dst))
conv = DGL_GATConv(8, 4)
output = conv(g, h)
print(output)
conv = DGL_SAGEConv(8, 4)
output = conv(g, h)
print(output)

tensor([[ 0.2320,  0.6226, -0.2256, -0.0852],
        [ 0.3480,  0.9339, -0.3384, -0.1278],
        [ 0.2320,  0.6226, -0.2256, -0.0852],
        [ 0.2320,  0.6226, -0.2256, -0.0852],
        [ 0.3480,  0.9339, -0.3384, -0.1278]],
       grad_fn=<CppNode<class SegmentSumCSR>>)
tensor([[-0.1045, -0.2056, -0.2056, -0.2831],
        [-0.1045, -0.2056, -0.2056, -0.2831],
        [-0.1045, -0.2056, -0.2056, -0.2831],
        [-0.1045, -0.2056, -0.2056, -0.2831],
        [-0.1045, -0.2056, -0.2056, -0.2831]], grad_fn=<AddmmBackward>)
tensor([[-0.4209, -0.2771,  0.1855,  1.8558],
        [ 0.0000,  0.0000,  0.0000,  0.0000],
        [-0.4209, -0.2771,  0.1855,  1.8558],
        [-0.4209, -0.2771,  0.1855,  1.8558],
        [-0.4209, -0.2771,  0.1855,  1.8558]], grad_fn=<GSpMMBackward>)
tensor([[ 0.7488, -0.4662, -0.2543,  0.8355],
        [ 0.6620, -0.6672,  0.0468,  0.6709],
        [ 0.7488, -0.4662, -0.2543,  0.8355],
        [ 0.7488, -0.4662, -0.2543,  0.8355],
        [ 0.7488, -0.4662,