<a href="https://colab.research.google.com/github/MicheleCattaneo/3D_Reconstruction_CV/blob/main/message_passing_gnn.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [65]:
import torch
# !pip install torch_geometric
from torch_geometric.data import Data
from torch_geometric.utils import scatter
import torch_geometric.transforms as T
from torch.nn import Linear, Parameter
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import add_self_loops, degree

##### GCN Layer

$$
x_i^{(k)} = \sum_{j\in \mathcal{N}(i)\cup \{i\}}\frac{1}{\sqrt{deg(i)}\cdot\sqrt{deg(j)}}\cdot \left( W^T\cdot x_j^{(k-1)}\right)+b
$$

In [109]:
from torch_geometric.nn.models.dimenet import sqrt
class GCNConv(MessagePassing):
  """Graph convolution layer with symmetric degree normalisation.

  For every node i the layer computes

      x_i' = sum_{j in N(i) ∪ {i}} 1 / (sqrt(deg(i)) * sqrt(deg(j))) * (W^T x_j) + b

  where W is the weight of a bias-free linear map and b is a learnable
  per-channel bias that is added once, after aggregation.
  """

  def __init__(self, in_channels, out_channels):
    """Create the layer.

    Args:
      in_channels: size of each input node feature vector.
      out_channels: size of each output node feature vector.
    """
    # 'add' aggregation implements the sum over neighbours in the formula.
    super().__init__(aggr='add')
    # The linear map carries no bias: b is applied after aggregation instead.
    self.linear = Linear(in_channels, out_channels, bias=False)
    # Start from zeros rather than torch.empty, which leaves the parameter
    # holding whatever happened to be in memory (possibly NaN/inf).
    self.b = Parameter(torch.zeros(out_channels))

  def reset_parameters(self):
    """Re-initialise the linear weight and reset the bias to zero."""
    self.linear.reset_parameters()
    self.b.data.zero_()

  def forward(self, x, edge_index):
    """Run one round of normalised message passing.

    Args:
      x: node features, shape [N, in_channels].
      edge_index: edge list, shape [2, E]; row 0 holds the source node j of
        each edge and row 1 the target node i.

    Returns:
      Updated node features, shape [N, out_channels].
    """
    num_nodes = x.size(0)

    # Self loops make every node part of its own neighbourhood.
    edge_index, _ = add_self_loops(edge_index, num_nodes=num_nodes)

    # Transform features first so messages are already in output space.
    x = self.linear(x)

    # Degrees are counted on the target row, i.e. incoming edges per node
    # (self loop included).
    src, tgt = edge_index
    deg = degree(tgt, num_nodes, dtype=x.dtype)
    sqrt_deg_inv = deg.pow(-0.5)
    # A zero degree would give inf; guard against it so such nodes
    # contribute nothing instead of poisoning the sum.
    sqrt_deg_inv[sqrt_deg_inv == float('inf')] = 0

    # One coefficient per edge j -> i: 1 / (sqrt(deg(j)) * sqrt(deg(i))).
    # norm.shape = [E,]
    norm = sqrt_deg_inv[src] * sqrt_deg_inv[tgt]

    # norm is reshaped to [E, 1] so it broadcasts over the feature dimension
    # inside message().
    out = self.propagate(edge_index=edge_index, x=x, norm=norm.view(-1, 1))

    # Out-of-place add: does not mutate the tensor returned by propagate().
    return out + self.b

  def message(self, x_j, norm):
    """Scale each source-node feature vector by its edge coefficient.

    Args:
      x_j: features of the source node of every edge, shape [E, out_channels].
      norm: normalisation coefficients, shape [E, 1].
    """
    return norm * x_j
