In [10]:
import torch
import numpy as np
import torch.nn as nn
import dgl
import dgl.function as fn
from torch_geometric.nn.conv import MessagePassing
from torch_geometric.utils import add_self_loops,degree
from torch_geometric.utils import softmax

class PyG_GATConv(MessagePassing):
  def __init__(self, in_channel, out_channel):
    super(PyG_GATConv, self).__init__(aggr='mean')
    self.in_channel = in_channel
    self.out_channel = out_channel
    self.linear = nn.Linear(in_channel, out_channel)
    self.attn_linear = nn.Linear(2*out_channel, 1)
    self.leaky_relu = nn.LeakyReLU()
    self.softmax = nn.Softmax(dim=1)

  def forward(self,x, edge_index):
    wh = self.linear(x)
    edge_attr = torch.cat([wh[edge_index[0]], wh[edge_index[1]]], dim=1)
    eij = self.leaky_relu(self.attn_linear(edge_attr)).squeeze()
    # alpha = self.softmax(eij)
    alpha = softmax(eij, edge_index[1])
    return self.propagate(edge_index, x=wh, alpha=alpha)
  def message(self,x_j, alpha):
    return alpha.view(-1,1) * x_j

class PyG_SAGEConv(MessagePassing):
  def __init__(self, in_channel, out_channel):
    super(PyG_SAGEConv, self).__init__(aggr='mean')
    self.in_channel = in_channel
    self.out_channel = out_channel
    self.linear = nn.Linear(in_channel, out_channel)
    self.relu = nn.ReLU()
  def forward(self,x, edge_index):
    print("x,edge_index")
    print(x,edge_index)
    return self.propagate(edge_index, x=x)

    # print("row " + str(row),"col "+str(col))
  def message(self,x_i,x_j):
    return torch.cat([x_i,x_j])
  
  def update(self, aggr_out):
    return self.relu(self.linear(aggr_out))
  
class PyG_ToyConv(MessagePassing):
  def __init__(self,in_channel,out_channel):
    super(PyG_ToyConv,self).__init__(aggr="mean")
    self.in_channel = in_channel
    self.out_channel = out_channel
    self.linear = nn.Linear(in_channel,out_channel)
    self.activation = nn.ReLU()
  
  def forward(self,x,edge_index):
    x = self.linear(x)
    return self.propagate(edge_index,x=x)
  def message(self,x_j):
    return x_j

  def update(self,aggr_out,x_j):
    hcat = torch.cat([x_j,aggr_out])
    hi = self.activation(hcat)
    return hi



In [11]:
edge_index = torch.tensor([[0,1,1,2,2,4],[2,0,2,3,4,3]])
x = torch.ones((5, 8))
# conv = PyG_GATConv(8, 4)
# output = conv(x, edge_index)
# print(output)
# conv = PyG_SAGEConv(8, 4)
# output = conv(x, edge_index)
# print(output)

conv = PyG_ToyConv(8,4)
output = conv(x,edge_index)
print(output)

tensor([[0.0280, 0.0000, 0.5096, 0.4949],
        [0.0280, 0.0000, 0.5096, 0.4949],
        [0.0280, 0.0000, 0.5096, 0.4949],
        [0.0280, 0.0000, 0.5096, 0.4949],
        [0.0280, 0.0000, 0.5096, 0.4949],
        [0.0280, 0.0000, 0.5096, 0.4949],
        [0.0280, 0.0000, 0.5096, 0.4949],
        [0.0000, 0.0000, 0.0000, 0.0000],
        [0.0280, 0.0000, 0.5096, 0.4949],
        [0.0280, 0.0000, 0.5096, 0.4949],
        [0.0280, 0.0000, 0.5096, 0.4949]], grad_fn=<ReluBackward0>)
