In [8]:
import torch
from torch.nn import Sequential, Linear, Parameter, ReLU
import torch.nn.functional as F
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import add_self_loops
from typing import Any, Dict, List, Optional, Union
from torch import Tensor
from torch_geometric.nn.aggr import Aggregation


In [5]:
class GNNSolver(MessagePassing):
    """Residual message-passing layer computing ``x + alpha * sum MLP([x_i, x_j])``.

    Every edge, plus one added self-loop per node, produces a message by
    running an MLP over the concatenated features of its two endpoints.
    Messages are summed per node and added to the (optionally projected)
    input features as an update scaled by ``alpha``.
    """

    def __init__(self, in_channels: int, out_channels: int, alpha: float = 0.01) -> None:
        """
        Args:
            in_channels: number of features per graph node in the input.
            out_channels: number of features per node in the output.
            alpha: step size that scales the aggregated update before it is
                added to the residual (input) features.
        """
        # With flow="target_to_source", PyG builds x_i from edge_index[0] and
        # x_j from edge_index[1], and sums messages at edge_index[0].
        super().__init__(aggr="add", flow="target_to_source")
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.alpha = alpha

        # Projection shortcut for the residual branch. It is only applied when
        # the input and output widths differ; otherwise the shortcut is identity.
        self.linear = Linear(in_channels, out_channels, bias=False)

        # Edge MLP: [x_i || x_j] (2 * in_channels) -> out_channels.
        self.mlp = Sequential(
            Linear(2 * in_channels, out_channels),
            ReLU(),
            Linear(out_channels, out_channels),
        )
        self.reset_parameters()

    def reset_parameters(self) -> None:
        """Re-initialise all learnable weights (shortcut projection and edge MLP)."""
        self.linear.reset_parameters()
        for layer in self.mlp:
            # ReLU has no parameters; only the Linear layers need resetting.
            if hasattr(layer, "reset_parameters"):
                layer.reset_parameters()

    def forward(self, x: Tensor, edge_index: Tensor) -> Tensor:
        """Apply one residual message-passing step.

        Args:
            x: node features of shape [num_nodes, in_channels].
            edge_index: graph connectivity of shape [2, num_edges].

        Returns:
            Updated node features of shape [num_nodes, out_channels].
        """
        # Add a self-loop per node so each node also receives MLP([x_i, x_i]).
        edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0))

        update = self.propagate(edge_index, x=x)

        # ResNet-style update. The shortcut is identity when widths match and
        # projected otherwise, so the sum is always well-shaped.
        if self.in_channels == self.out_channels:
            res = x
        else:
            res = self.linear(x)
        return res + self.alpha * update

    def message(self, x_i: Tensor, x_j: Tensor) -> Tensor:
        """Build one message per edge from the features of both endpoints.

        Args:
            x_i: features of the aggregating endpoint, [num_edges, in_channels].
            x_j: features of the other endpoint, [num_edges, in_channels].

        Returns:
            Messages of shape [num_edges, out_channels].
        """
        pair = torch.cat([x_i, x_j], dim=1)  # [num_edges, 2 * in_channels]
        return self.mlp(pair)

    def __repr__(self) -> str:
        return (f"{self.__class__.__name__}({self.in_channels}, "
                f"{self.out_channels}, alpha={self.alpha})")