In [50]:
import torch
from torch.nn import Sequential, Linear, Parameter, ReLU
import torch.nn.functional as F

import torch_geometric
from torch_geometric.nn import MessagePassing, GeneralConv
from torch_geometric.data import Data
from torch_geometric.utils import add_self_loops, segregate_self_loops
from typing import Any, Dict, List, Optional, Union
from torch import Tensor
from torch_geometric.nn.aggr import Aggregation
from torch_geometric.typing import Adj, OptTensor

import networkx as nx


In [8]:
class GNNSolver(MessagePassing):
    """Edge-conditioned message-passing layer with a residual update.

    A shared MLP turns ``[x_i, x_j, edge_attr]`` into a message for each edge.
    Three groups of messages are aggregated by summation:

    * ``phi_in``   -- ``j -> i`` along every non-loop edge (aggregated at the target),
    * ``phi_out``  -- ``i -> j``, i.e. the same edges traversed in reverse
      (aggregated at the source),
    * ``phi_loop`` -- ``i -> i`` for the self-loops present in ``edge_index``.

    The update is applied in ResNet fashion:
    ``out = res + alpha * (phi_in + phi_loop + phi_out)``.
    Here ``res`` is ``x`` itself when the feature sizes match, and ``self.linear(x)`` otherwise.
    """

    def __init__(self, in_channels: int, out_channels: int, num_edge_features: int, alpha: float = 0.01) -> None:
        """
        Args:
            in_channels: number of features per graph node.
            out_channels: size of each message produced by the MLP (and of the output).
            num_edge_features: number of features per edge (columns of ``edge_attr``).
            alpha: scale applied to the aggregated update before it is added to the input.
        """
        super().__init__(aggr="add")
        # Stored so that __repr__ (and callers) can report the layer sizes.
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.num_edge_features = num_edge_features
        self.alpha = alpha

        # Residual projection; only used when in_channels != out_channels.
        self.linear = Linear(in_channels, out_channels, bias=False)

        self.mlp = Sequential(
            Linear(2 * in_channels + num_edge_features, out_channels),
            ReLU(),
            Linear(out_channels, out_channels)
        )
        self.reset_parameters()

    def reset_parameters(self) -> None:
        """Re-initialise the residual projection and every learnable layer of the message MLP."""
        self.linear.reset_parameters()
        for layer in self.mlp:
            # ReLU has no parameters; only reset layers that support it.
            if hasattr(layer, "reset_parameters"):
                layer.reset_parameters()

    def forward(self, x: Tensor, edge_index: Adj,
                edge_attr: OptTensor = None) -> Tensor:
        """
        Args:
            x: node features of shape [num_nodes, in_channels].
            edge_index: COO connectivity of shape [2, num_edges]. It may contain
                self-loops. It is assumed to be a dense LongTensor, since
                ``segregate_self_loops`` and ``flip`` need one.
            edge_attr: optional edge features of shape [num_edges, num_edge_features].
                When ``None``, zero features are used in ``message``.

        Returns:
            Updated node features of shape [num_nodes, out_channels].
        """
        # Split ordinary edges from self-loops so the i -> i message is not
        # mixed into the directional messages.
        edge_index, edge_attr, loop_index, loop_attr = segregate_self_loops(edge_index, edge_attr)

        # j -> i: message along each edge, aggregated at its target.
        phi_in = self.propagate(edge_index, x=x, edge_attr=edge_attr)

        # i -> j: flipping the rows makes each source the aggregation target.
        # This is equivalent to flow='target_to_source', but it does not mutate
        # self.flow (which previously leaked out of forward).
        phi_out = self.propagate(edge_index.flip(0), x=x, edge_attr=edge_attr)

        # i -> i: messages on self-loops; nodes without a loop receive zeros.
        phi_loop = self.propagate(loop_index, x=x, edge_attr=loop_attr)

        out = phi_in + phi_loop + phi_out

        # ResNet fashion: identity skip when sizes match, learned projection otherwise.
        res = x if x.size(-1) == out.size(-1) else self.linear(x)
        return res + self.alpha * out

    def message(self, x_i: Tensor, x_j: Tensor, edge_attr: OptTensor, edge_index_i: Tensor, edge_index_j: Tensor) -> Tensor:
        """
        Args:
            x_i: features of the aggregating (target) node of each edge, [num_edges, in_channels].
            x_j: features of the sending (source) node of each edge, [num_edges, in_channels].
            edge_attr: edge features [num_edges, num_edge_features], or ``None``.
            edge_index_i, edge_index_j: unused; filled in by PyG and kept for signature compatibility.

        Returns:
            One message per edge, shape [num_edges, out_channels].
        """
        if edge_attr is None:
            # The MLP input width includes num_edge_features, so pad with zeros.
            edge_attr = x_i.new_zeros(x_i.size(0), self.num_edge_features)
        elif edge_attr.dim() == 1:
            # Accept a flat [num_edges] vector when there is a single edge feature.
            edge_attr = edge_attr.unsqueeze(-1)
        # Cast so that integer-valued edge features can be concatenated with float node features.
        tmp = torch.cat([x_i, x_j, edge_attr.to(x_i.dtype)], dim=-1)  # [E, 2 * in_channels + num_edge_features]
        return self.mlp(tmp)

    def __repr__(self) -> str:
        return (f'{self.__class__.__name__}({self.in_channels}, '
                f'{self.out_channels}, aggr={self.aggr})')


In [71]:
# Toy path graph 0 -> 1 -> 2, used to check how PyG handles self-loops.
edge_index = torch.tensor([
  [0, 1],
  [1, 2]
], dtype=torch.long)

# add_self_loops appends an (i, i) edge for every node; its returned edge_attr (`_`) is unused here.
edge_index, _ = add_self_loops(edge_index)

# segregate_self_loops splits them back out into ordinary edges and (i, i) loop edges.
edge_index, edge_attr, loop_edge_index, loop_edge_attr = segregate_self_loops(edge_index)

# Invariants: every loop edge is (i, i), and none remain among the ordinary edges.
assert torch.equal(loop_edge_index[0], loop_edge_index[1]), "loop edges must be (i, i)"
assert not bool((edge_index[0] == edge_index[1]).any()), "no self-loops should remain in edge_index"

# TODO: build Data(x=..., edge_index=..., edge_attr=...) here to exercise GNNSolver,
# e.g. x = [[-1], [0], [1]] (float) and edge_attr = [[4, 5], [6, 7]] (one row per
# non-loop edge), and draw it via torch_geometric.utils.to_networkx + nx.draw.

In [72]:
# Ordinary (non-loop) edges left after segregate_self_loops; displayed as the cell's output.
edge_index

tensor([[0, 1],
        [1, 2]])


In [74]:
# Self-loop edges split off by segregate_self_loops; displayed as the cell's output.
loop_edge_index

tensor([[0, 1, 2],
        [0, 1, 2]])
