In this week, you are required to implement a toy GATConv and SAGEConv based on document. Also, you need to implement both in PyG and DGL. In this work, you will get a further understanding of tensor-centric in PyG and graph-centric in DGL.

In [26]:
import torch
import numpy as np
import torch.nn as nn
import dgl
import dgl.function as fn
import torch.nn.functional as F
from torch_geometric.nn.conv import MessagePassing
from torch_geometric.utils import add_self_loops, degree

class PyG_GATConv(MessagePassing):
  def __init__(self, in_channel, out_channel):
    super().__init__(aggr='add')

    self.W = nn.Linear(in_channel, out_channel, bias=False)
    self.a = nn.Linear(2*out_channel, 1, bias=False)
    self.dropout = nn.Dropout(0.1)
    
    nn.init.xavier_uniform_(self.W.weight, gain=1.414)
    nn.init.xavier_uniform_(self.a.weight, gain=1.414)

  def forward(self, x, edge_index):
    x = self.W(x)
    edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0))
    out = self.propagate(edge_index, x=x)
    return out

  def message(self, x_i, x_j, edge_index):
    x_cat = torch.cat([x_i, x_j], dim=-1)
    e = F.leaky_relu(self.a(x_cat), negative_slope=0.2)

    node_i = edge_index[1]
    e = self.dropout(e)
    e = torch.exp(e)
    input = torch.zeros_like(e)
    e_sum = torch.scatter_add(input=input, dim=0, index=node_i.unsqueeze(1), src=e)
    alpha = e / (e_sum[node_i] + 1e-16)

    return alpha * x_j

class PyG_SAGEConv(MessagePassing):
  def __init__(self, in_channel, out_channel):
    super().__init__(aggr='mean')

    self.lin = nn.Linear(2*in_channel, out_channel)

    nn.init.xavier_uniform_(self.lin.weight, gain=1.414)
    if self.lin.bias is not None:
      nn.init.zeros_(self.lin.bias)

  def forward(self, x, edge_index):
    edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0))
    agg_neighbors = self.propagate(edge_index, x=x)
    x_cat = torch.cat([x, agg_neighbors], dim=-1)

    out = self.lin(x_cat)
    return out

  def message(self, x_j):
    return x_j

In [34]:
class DGL_GATConv(nn.Module):
  def __init__(self, in_channel, out_channel):
    super().__init__()
    self.W = nn.Linear(in_channel, out_channel, bias=False)
    self.a = nn.Linear(2*out_channel, 1, bias=False)
    self.dropout = nn.Dropout(0.1)
    
    nn.init.xavier_uniform_(self.W.weight, gain=1.414)
    nn.init.xavier_uniform_(self.a.weight, gain=1.414)

  def forward(self, g, h):
    h = self.W(h)
    g = dgl.add_self_loop(g)
    g.ndata['h'] = h

    def msg_func(edges):
      x_cat = torch.cat([edges.dst['h'], edges.src['h']], dim=-1)
      e = F.leaky_relu(self.a(x_cat), negative_slope=0.2)
      return {'e': self.dropout(torch.exp(e))}

    g.apply_edges(msg_func)

    g.update_all(
      fn.copy_e('e', 'm'),
      fn.sum('m', 'e_sum')
    )

    v = g.edges()[1]
    g.edata['alpha'] = g.edata['e'] / (g.ndata['e_sum'][v] + 1e-16)

    g.update_all(
      fn.u_mul_e('h', 'alpha', 'm'),
      fn.sum('m', 'h_out')
    )

    out = g.ndata.pop('h_out')
    return out

class DGL_SAGEConv(nn.Module):
  def __init__(self, in_channel, out_channel):
    super().__init__()
    self.lin = nn.Linear(2*in_channel, out_channel)

    nn.init.xavier_uniform_(self.lin.weight, gain=1.414)
    if self.lin.bias is not None:
      nn.init.zeros_(self.lin.bias)

  def forward(self, g, h):
    g = dgl.add_self_loop(g)
    g.ndata['h'] = h

    def msg_func(edges):
      return {'m': edges.src['h']}

    g.update_all(
      msg_func,
      fn.mean('m', 'agg_neighbors')
    )

    x_concat = torch.cat([g.ndata['h'], g.ndata['agg_neighbors']], dim=-1)

    out = self.lin(x_concat)
    return out

If you want to check your answer, you can run the following code.

In [35]:
edge_index = torch.tensor([[0,1,1,2,2,4],[2,0,2,3,4,3]])
x = torch.ones((5, 8))
conv = PyG_GATConv(8, 4)
output = conv(x, edge_index)
print(output)
conv = PyG_SAGEConv(8, 4)
output = conv(x, edge_index)
print(output)

src = torch.tensor([0, 1, 1, 2, 2, 4])
dst = torch.tensor([2, 0, 2, 3, 4, 3])
h = torch.ones((5, 8))
g = dgl.graph((src, dst))
conv = DGL_GATConv(8, 4)
output = conv(g, h)
print(output)
conv = DGL_SAGEConv(8, 4)
output = conv(g, h)
print(output)

tensor([[-0.9622,  0.8724,  1.7199, -0.5196],
        [-0.9622,  0.8724,  1.7199, -0.5196],
        [-0.9622,  0.8724,  1.7199, -0.5196],
        [-0.9622,  0.8724,  1.7199, -0.5196],
        [-0.9622,  0.8724,  1.7199, -0.5196]], grad_fn=<ScatterAddBackward0>)
tensor([[ 1.8194, -2.2228,  3.6659,  1.1293],
        [ 1.8194, -2.2228,  3.6659,  1.1293],
        [ 1.8194, -2.2228,  3.6659,  1.1293],
        [ 1.8194, -2.2228,  3.6659,  1.1293],
        [ 1.8194, -2.2228,  3.6659,  1.1293]], grad_fn=<AddmmBackward0>)
tensor([[ 0.7014, -1.0442,  0.0533,  2.7487],
        [ 0.7014, -1.0442,  0.0533,  2.7487],
        [ 0.7014, -1.0442,  0.0533,  2.7487],
        [ 0.7014, -1.0442,  0.0533,  2.7487],
        [ 0.7014, -1.0442,  0.0533,  2.7487]], grad_fn=<GSpMMBackward>)
tensor([[-1.6457, -3.3782, -1.0880, -0.9941],
        [-1.6457, -3.3782, -1.0880, -0.9941],
        [-1.6457, -3.3782, -1.0880, -0.9941],
        [-1.6457, -3.3782, -1.0880, -0.9941],
        [-1.6457, -3.3782, -1.0880, -0.99