In [None]:
import torch
from torch_geometric.nn import global_add_pool, global_mean_pool, GlobalAttention, Set2Set
from ogb.graphproppred.mol_encoder import AtomEncoder, BondEncoder
from torch_geometric.nn import MessagePassing

## 基于图同构网络（GIN）的图表征网络的实现
基于图同构网络（Graph Isomorphism Network, GIN）的图表征网络是当前最经典的图表征学习网络之一。
基于图同构网络的图表征学习主要包含以下两个过程：
1. 首先计算得到节点表征；
2. 其次对图上各个节点的表征做图池化（Graph Pooling），或称为图读出（Graph Readout），得到图的表征（Graph Representation）。

`ogb.graphproppred.mol_encoder` 中 `AtomEncoder` 和 `BondEncoder` 的使用：二者分别把分子图中类别型的原子（节点）属性和化学键（边）属性转换为 `emb_dim` 维的嵌入向量，供后续的 GIN 层使用。

### 基于图同构网络的图表征模块（GINGraphRepr Module）

In [None]:
class GINConv(MessagePassing):
    """Graph Isomorphism Network layer that also consumes edge attributes.

    ``torch_geometric.nn.GINConv`` does not accept edge attributes, so this
    layer implements the edge-aware update directly::

        out_i = MLP((1 + eps) * x_i + sum_{(j, i) in edge_index} ReLU(x_j + e_ji))

    where ``e_ji`` is the ``BondEncoder`` embedding of the edge's attributes
    and ``eps`` is a learnable one-element parameter initialised to 0.

    Args:
        emb_dim: Width shared by the node features, the edge embeddings and
            the hidden/output layers of the MLP.
    """

    def __init__(self, emb_dim):
        # Messages from neighbours are combined by summation.
        super().__init__(aggr="add")

        mlp_layers = [
            torch.nn.Linear(emb_dim, emb_dim),
            torch.nn.BatchNorm1d(emb_dim),
            torch.nn.ReLU(),
            torch.nn.Linear(emb_dim, emb_dim),
        ]
        self.mlp = torch.nn.Sequential(*mlp_layers)
        self.eps = torch.nn.Parameter(torch.Tensor([0]))
        self.bond_encoder = BondEncoder(emb_dim=emb_dim)

    def forward(self, x, edge_index, edge_attr):
        """Apply one GIN update to the node features ``x``.

        Args:
            x: Node feature tensor.
            edge_index: Graph connectivity passed to ``propagate``.
            edge_attr: Categorical edge attributes, embedded by ``BondEncoder``.

        Returns:
            The MLP applied to the self term plus the aggregated neighbour messages.
        """
        # Turn the categorical edge attributes into dense edge embeddings first.
        edge_emb = self.bond_encoder(edge_attr)
        self_term = (1 + self.eps) * x
        neighbour_sum = self.propagate(edge_index, x=x, edge_attr=edge_emb)
        return self.mlp(self_term + neighbour_sum)

    def message(self, x_j, edge_attr):
        """Build the message for one edge: ReLU(source features + edge embedding)."""
        # Parameter names are significant: MessagePassing maps them to propagate() kwargs.
        return torch.nn.functional.relu(x_j + edge_attr)

    def update(self, aggr_out):
        """Return the aggregated messages unchanged."""
        return aggr_out

class GINNodeEmbedding(torch.nn.Module):
    """Compute node representations with a stack of edge-aware GIN layers.

    Node features are embedded with ``AtomEncoder`` and then passed through
    ``num_layers`` rounds of ``GINConv`` -> ``BatchNorm1d`` -> (ReLU, except on
    the last layer) -> dropout. Residual connections between consecutive
    layers are optional.

    Args:
        num_layers: Number of GIN layers; must be at least 2.
        emb_dim: Embedding width used by the encoder and every layer.
        drop_ratio: Dropout probability applied after every layer.
        JK: How the per-layer outputs are combined. ``"last"`` returns the
            output of the final layer. ``"sum"`` returns the element-wise sum
            of the input embedding and the outputs of all layers.
        residual: If True, add each layer's input to its output.
        *args, **kwargs: Forwarded to ``torch.nn.Module.__init__``.

    Raises:
        ValueError: If ``num_layers < 2`` or ``JK`` is not a supported mode.
    """

    _JK_MODES = ("last", "sum")

    def __init__(self, num_layers, emb_dim, drop_ratio=0.5, JK="last", residual=False, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.num_layers = num_layers
        self.drop_ratio = drop_ratio
        self.JK = JK
        self.residual = residual

        if self.num_layers < 2:
            raise ValueError("Number of GNN layers must be greater than 1")
        # Fail at construction time instead of with an UnboundLocalError on
        # the first forward pass.
        if self.JK not in self._JK_MODES:
            raise ValueError(f"Unsupported JK mode {JK!r}; expected one of {self._JK_MODES}")

        self.atom_encoder = AtomEncoder(emb_dim)

        # One message-passing layer and one batch norm per GIN layer.
        self.convs = torch.nn.ModuleList()
        self.batch_norms = torch.nn.ModuleList()
        for _ in range(num_layers):
            self.convs.append(GINConv(emb_dim))
            self.batch_norms.append(torch.nn.BatchNorm1d(emb_dim))

    def forward(self, batched_data):
        """Return node representations for a (batched) graph.

        Args:
            batched_data: Object exposing ``x``, ``edge_index`` and ``edge_attr``
                attributes (presumably a PyG ``Data``/``Batch``; only these three
                attributes are read).

        Returns:
            Node representations combined according to ``self.JK``.

        Raises:
            ValueError: If ``self.JK`` is not a supported mode.
        """
        x, edge_index, edge_attr = batched_data.x, batched_data.edge_index, batched_data.edge_attr

        # h_list[0] is the input embedding; h_list[k] is the output of layer k.
        h_list = [self.atom_encoder(x)]
        for layer in range(self.num_layers):
            h = self.convs[layer](h_list[layer], edge_index, edge_attr)
            h = self.batch_norms[layer](h)
            if layer == self.num_layers - 1:
                # No non-linearity on the final layer.
                h = torch.nn.functional.dropout(h, self.drop_ratio, training=self.training)
            else:
                h = torch.nn.functional.dropout(torch.nn.functional.relu(h), self.drop_ratio, training=self.training)
            if self.residual:
                # Out-of-place add. Dropout can hand back its input tensor
                # as-is (e.g. in eval mode), so an in-place ``+=`` would
                # overwrite the ReLU output autograd saved for backward.
                h = h + h_list[layer]
            h_list.append(h)

        if self.JK == "last":
            return h_list[-1]
        if self.JK == "sum":
            # Sum the input embedding and every layer's output
            # (num_layers + 1 tensors in total).
            return torch.stack(h_list, dim=0).sum(dim=0)
        raise ValueError(f"Unsupported JK mode {self.JK!r}; expected one of {self._JK_MODES}")