In this week, you are required to implement a toy GATConv and SAGEConv based on document. Also, you need to implement both in PyG and DGL. In this work, you will get a further understanding of tensor-centric in PyG and graph-centric in DGL.

In [7]:
# !pip install  dgl -f https://data.dgl.ai/wheels/repo.html
# !pip install torch_geometric
# !pip install pyg_lib torch_scatter torch_sparse torch_cluster torch_spline_conv -f https://data.pyg.org/whl/torch-2.0.1+cpu.html

Defaulting to user installation because normal site-packages is not writeable
Looking in indexes: https://pypi.tuna.tsinghua.edu.cn/simple
Looking in links: https://data.dgl.ai/wheels/repo.html
Could not fetch URL https://data.dgl.ai/wheels/repo.html: There was a problem confirming the ssl certificate: HTTPSConnectionPool(host='data.dgl.ai', port=443): Max retries exceeded with url: /wheels/repo.html (Caused by SSLError(SSLZeroReturnError(6, 'TLS/SSL connection has been closed (EOF) (_ssl.c:1131)'))) - skipping
Could not fetch URL https://pypi.tuna.tsinghua.edu.cn/simple/pip/: There was a problem confirming the ssl certificate: HTTPSConnectionPool(host='pypi.tuna.tsinghua.edu.cn', port=443): Max retries exceeded with url: /simple/pip/ (Caused by SSLError(SSLZeroReturnError(6, 'TLS/SSL connection has been closed (EOF) (_ssl.c:1131)'))) - skipping


DEPRECATION: matlabengineforpython R2021b has a non-standard version number. pip 24.0 will enforce this behaviour change. A possible replacement is to upgrade to a newer version of matlabengineforpython or contact the author to suggest that they release a version with a conforming version number. Discussion can be found at https://github.com/pypa/pip/issues/12063


^C
^C


In [11]:
import torch
import numpy as np
import torch.nn as nn
import dgl
import dgl.function as fn
from torch_geometric.nn.conv import MessagePassing
from torch_geometric.utils import add_self_loops, degree
from torch_geometric.utils import softmax

class PyG_GATConv(MessagePassing):
    def __init__(self, in_channels: int, out_channels: int):
        super().__init__(aggr='add')  # "Addition" aggregation.
        self.in_channels = in_channels
        self.out_channels = out_channels

        self.lin = torch.nn.Linear(in_channels, out_channels, bias=False)
        self.att = torch.nn.Linear(2 * out_channels, 1, bias=False)
        self.act = torch.nn.LeakyReLU() # not a real layer, just for activation

    def forward(self, x, edge_index):
        x = self.lin(x)
        # compute attention coefficients based on edge features e_ij
        edge_attr = torch.cat([x[edge_index[0]], x[edge_index[1]]], dim=-1)
        edge_attr = self.act(self.att(edge_attr))
        
        # alpha_ij is the normalized attention scores
        alpha = softmax(edge_attr, edge_index[1])
        
        # calc message passing with attention scores
        out = self.propagate(edge_index, x=x, alpha=alpha)
        
        return out

    def message(self, x_j, alpha):
        # x_j is the input node features, alpha is the attention scores as weights
        return alpha * x_j



class PyG_SAGEConv(MessagePassing):
    def __init__(self, in_channels, out_channels):
        super(PyG_SAGEConv, self).__init__(aggr='mean')  # "mean"
        self.lin = torch.nn.Linear(in_channels, out_channels)
        self.act = torch.nn.ReLU()

    def forward(self, x, edge_index):
        edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0))

        row, col = edge_index
        deg = degree(col, x.size(0), dtype=x.dtype)
        deg_inv_sqrt = deg.pow(-0.5)
        norm = deg_inv_sqrt[row] * deg_inv_sqrt[col]

        return self.propagate(edge_index, x=x, norm=norm)

    def message(self, x_j, norm):
        return norm.view(-1, 1) * x_j

    def update(self, aggr_out):
        return self.act(self.lin(aggr_out))

In [12]:
class DGL_GATConv(nn.Module):
    def __init__(self, in_feats, out_feats):
        super(DGL_GATConv, self).__init__()
        self.fc = nn.Linear(in_feats, out_feats, bias=False)
        self.attn_fc = nn.Linear(2 * out_feats, 1, bias=False)
        self.reset_parameters()

    def reset_parameters(self):
        gain = nn.init.calculate_gain('relu')
        nn.init.xavier_uniform_(self.fc.weight, gain=gain)
        nn.init.xavier_uniform_(self.attn_fc.weight, gain=gain)

    def edge_attention(self, edges):
        z2 = torch.cat([edges.src['z'], edges.dst['z']], dim=1)
        a = self.attn_fc(z2)
        return {'e': a}

    def message_func(self, edges):
        return {'z': edges.src['z'], 'e': edges.data['e']}

    def reduce_func(self, nodes):
        alpha = torch.softmax(nodes.mailbox['e'], dim=1)
        h = torch.sum(alpha * nodes.mailbox['z'], dim=1)
        return {'h': h}

    def forward(self, g, x):
        z = self.fc(x)
        g.ndata['z'] = z
        g.apply_edges(self.edge_attention)
        g.update_all(self.message_func, self.reduce_func)
        return g.ndata.pop('h')


class DGL_SAGEConv(nn.Module):
    def __init__(self, in_channels: int, out_channels: int):
        super().__init__()
        self.fc = nn.Linear(in_channels * 2, out_channels)
        self.act = nn.ReLU()

    def forward(self, g: dgl.DGLGraph, h: torch.Tensor) -> torch.Tensor:
        with g.local_scope():
            g.ndata['h'] = h
            g.update_all(fn.copy_u('h', 'm'), fn.mean('m', 'neigh'))
            neigh = g.ndata['neigh']
            return self.act(self.fc(torch.cat([h, neigh], dim=1)))

If you want to check your answer, you can run the following code.

In [13]:
edge_index = torch.tensor([[0,1,1,2,2,4],[2,0,2,3,4,3]])
x = torch.ones((5, 8))
conv = PyG_GATConv(8, 4)
output = conv(x, edge_index)
print(output)
conv = PyG_SAGEConv(8, 4)
output = conv(x, edge_index)
print(output)

src = torch.tensor([0, 1, 1, 2, 2, 4])
dst = torch.tensor([2, 0, 2, 3, 4, 3])
h = torch.ones((5, 8))
g = dgl.graph((src, dst))
conv = DGL_GATConv(8, 4)
output = conv(g, h)
print(output)
conv = DGL_SAGEConv(8, 4)
output = conv(g, h)
print(output)

tensor([[0.4891, 0.2247, 0.0578, 0.2553],
        [0.0000, 0.0000, 0.0000, 0.0000],
        [0.4891, 0.2247, 0.0578, 0.2553],
        [0.4891, 0.2247, 0.0578, 0.2553],
        [0.4891, 0.2247, 0.0578, 0.2553]], grad_fn=<ScatterAddBackward0>)
tensor([[0.0000, 0.0000, 0.0000, 0.4698],
        [0.0535, 0.0000, 0.0000, 0.5607],
        [0.0000, 0.0000, 0.0000, 0.4322],
        [0.0000, 0.0110, 0.0000, 0.4136],
        [0.0000, 0.0000, 0.0000, 0.4356]], grad_fn=<ReluBackward0>)
tensor([[-2.5431, -1.1511,  0.5339,  3.7166],
        [ 0.0000,  0.0000,  0.0000,  0.0000],
        [-2.5431, -1.1511,  0.5339,  3.7166],
        [-2.5431, -1.1511,  0.5339,  3.7166],
        [-2.5431, -1.1511,  0.5339,  3.7166]], grad_fn=<IndexCopyBackward0>)
tensor([[0.2385, 0.5614, 0.0000, 0.0000],
        [0.0000, 0.2318, 0.0000, 0.0000],
        [0.2385, 0.5614, 0.0000, 0.0000],
        [0.2385, 0.5614, 0.0000, 0.0000],
        [0.2385, 0.5614, 0.0000, 0.0000]], grad_fn=<ReluBackward0>)
