In this week, you are required to implement a toy GATConv and SAGEConv based on document. Also, you need to implement both in PyG and DGL. In this work, you will get a further understanding of tensor-centric in PyG and graph-centric in DGL.

In [None]:
!pip install  dgl -f https://data.dgl.ai/wheels/repo.html
!pip install torch_geometric
!pip install pyg_lib torch_scatter torch_sparse torch_cluster torch_spline_conv -f https://data.pyg.org/whl/torch-2.0.1+cpu.html

In [4]:
import torch
import numpy as np
import torch.nn as nn
import dgl
import dgl.function as fn
from torch_geometric.nn.conv import MessagePassing
from torch_geometric.utils import add_self_loops,degree
from torch_geometric.utils import softmax

class PyG_GATConv(MessagePassing):
  def __init__(self, in_channel, out_channel):
    super(PyG_GATConv, self).__init__(aggr='mean')
    self.in_channel = in_channel
    self.out_channel = out_channel
    self.linear = nn.Linear(in_channel, out_channel)
    self.attn_linear = nn.Linear(2*out_channel, 1)
    self.leaky_relu = nn.LeakyReLU()
    self.softmax = nn.Softmax(dim=1)

  def forward(self,x, edge_index):
    wh = self.linear(x)
    edge_attr = torch.cat([wh[edge_index[0]], wh[edge_index[1]]], dim=1)
    eij = self.leaky_relu(self.attn_linear(edge_attr)).squeeze()
    # alpha = self.softmax(eij)
    alpha = softmax(eij, edge_index[1])
    return self.propagate(edge_index, x=wh, alpha=alpha)
  def message(self,x_j, alpha):
    return alpha.view(-1,1) * x_j

class PyG_SAGEConv(MessagePassing):
  def __init__(self, in_channel, out_channel):
    super(PyG_SAGEConv, self).__init__(aggr='mean')
    self.in_channel = in_channel
    self.out_channel = out_channel
    self.linear = nn.Linear(in_channel, out_channel)
    self.relu = nn.ReLU()
  def forward(self,x, edge_index):
    print("x,edge_index")
    print(x,edge_index)
    return self.propagate(edge_index, x=x)

    # print("row " + str(row),"col "+str(col))
  def message(self,x_i,x_j):
    return torch.cat([x_i,x_j])
  
  def update(self, aggr_out):
    return self.relu(self.linear(aggr_out))


In [5]:
class DGL_GATConv(nn.Module):
  def __init__(self, in_channel, out_channel):
    super(DGL_GATConv, self).__init__()
    self.linear = nn.Linear(in_channel, out_channel, bias=False)
    self.attn_linear = nn.Linear(2 * out_channel, 1, bias=False)
    
    self.leaky_relu = nn.LeakyReLU()
    self.softmax = nn.Softmax(dim=1)


  def forward(self, g, h):
    # h:[5,8]
    g.ndata['h'] = self.linear(h)
    # data_h:[5,4]

    g.update_all(self.message_leakyrelu, self.reduce_softmax)
    return g.ndata.pop('h')

  def reduce_softmax(self,nodes):
      # print(nodes.mailbox)
      alpha = self.softmax(nodes.mailbox['e'])
      h = torch.sum(alpha * nodes.mailbox['h'], dim=1)
      return {'h': h}

  def message_leakyrelu(self, edges):
      h = torch.cat([edges.src['h'] , edges.dst['h']],dim=1)
      Wh = self.attn_linear(h)
      return {'e': self.leaky_relu(Wh), 'h' : edges.src['h']}
    
class DGL_SAGEConv(nn.Module):
  def __init__(self, in_channel, out_channel):
    super(DGL_SAGEConv, self).__init__()
    self.linear = nn.Linear(in_channel * 2, out_channel, bias=False)
    self.relu = nn.ReLU()
  

  def forward(self, g, h):
    g.ndata['h'] = h
    g.update_all(fn.copy_u('h', 'm'), fn.mean('m', 'hn'))
    hn = g.ndata.pop('hn')
    h = torch.cat((h, hn), dim=1)
    h = self.linear(h)
    h = self.relu(h)
    return h

If you want to check your answer, you can run the following code.

In [6]:
edge_index = torch.tensor([[0,1,1,2,2,4],[2,0,2,3,4,3]])
x = torch.ones((5, 8))
conv = PyG_GATConv(8, 4)
output = conv(x, edge_index)
print(output)
conv = PyG_SAGEConv(8, 4)
output = conv(x, edge_index)
print(output)

src = torch.tensor([0, 1, 1, 2, 2, 4])
dst = torch.tensor([2, 0, 2, 3, 4, 3])
h = torch.ones((5, 8))
g = dgl.graph((src, dst))
conv = DGL_GATConv(8, 4)
output = conv(g, h)
print(output)
conv = DGL_SAGEConv(8, 4)
output = conv(g, h)
print(output)

tensor([[-1.0417, -0.2492,  0.4913,  0.0191],
        [ 0.0000,  0.0000,  0.0000,  0.0000],
        [-0.5209, -0.1246,  0.2456,  0.0095],
        [-0.5209, -0.1246,  0.2456,  0.0095],
        [-1.0417, -0.2492,  0.4913,  0.0191]], grad_fn=<DivBackward0>)
x,edge_index
tensor([[1., 1., 1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1., 1., 1.]]) tensor([[0, 1, 1, 2, 2, 4],
        [2, 0, 2, 3, 4, 3]])


RuntimeError: The expanded size of the tensor (12) must match the existing size (6) at non-singleton dimension 0.  Target sizes: [12, 8].  Tensor sizes: [6, 1]

In [24]:
import torch
import numpy as np
import torch.nn as nn
import dgl
import dgl.function as fn
class DGL_GATConv(nn.Module):
  def __init__(self, in_channel, out_channel):
    super(DGL_GATConv, self).__init__()
    self.linear = nn.Linear(in_channel, out_channel, bias=False)
    self.attn_linear = nn.Linear(2 * out_channel, 1, bias=False)
    
    self.leaky_relu = nn.LeakyReLU(0.2)
    self.softmax = nn.Softmax(dim=1)


  def forward(self, g, h):
    # h:[5,8]
    g.ndata['h'] = self.linear(h)
    # data_h:[5,4]

    g.update_all(self.message_leakyrelu, self.reduce_softmax)
    return g.ndata.pop('h')

  def reduce_softmax(self,nodes):
      # print(nodes.mailbox)
      alpha = self.softmax(nodes.mailbox['e'])
      h = torch.sum(alpha * nodes.mailbox['h'], dim=1)
      return {'h': h}

  def message_leakyrelu(self, edges):
      h = torch.cat([edges.src['h'] , edges.dst['h']],dim=1)
      Wh = self.attn_linear(h)
      return {'e': self.leaky_relu(Wh), 'h' : edges.src['h']}
  
#test
src = torch.tensor([0, 1, 1, 2, 2, 4])
dst = torch.tensor([2, 0, 2, 3, 4, 3])
h = torch.ones((5, 8))
g = dgl.graph((src, dst))
conv = DGL_GATConv(8, 4)
output = conv(g, h)
print(output)

tensor([[-0.1365, -0.9197,  0.6146,  0.4226],
        [ 0.0000,  0.0000,  0.0000,  0.0000],
        [-0.1365, -0.9197,  0.6146,  0.4226],
        [-0.1365, -0.9197,  0.6146,  0.4226],
        [-0.1365, -0.9197,  0.6146,  0.4226]], grad_fn=<IndexCopyBackward0>)
