In [1]:
import torch
import torch.nn as nn

## Equivariant Graph Neural Network
We consider a graph $G = (V, E)$ with $N$ nodes $v_i \in V$ and edges $e_{ij} \in E$. Each node has a feature vector $h_i \in \mathbb{R}^k$ and an $n$-dimensional coordinate vector $x_i \in \mathbb{R}^n$.

We want the model to be equivariant to rotations, translations and reflections of the coordinates (the Euclidean group E(n)) and to permutations of the nodes.
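
The EGCL gets E(n) equivariance from two ingredients: $\phi_e$ only sees invariant quantities (node features, edge values and squared distances), and each node is moved only along the relative positions $x_i - x_j$. The minimal sketch below uses random data to check the first ingredient numerically. The helper `squared_distances` is our own and is not part of the paper. It shows that squared pairwise distances are unchanged by a random orthogonal transformation followed by a translation, and that permuting the nodes only permutes the rows and columns of the distance matrix.

```python
import torch

torch.manual_seed(0)
x = torch.randn(6, 3, dtype=torch.float64)                      # 6 nodes in 3-D
Q, _ = torch.linalg.qr(torch.randn(3, 3, dtype=torch.float64))  # random orthogonal matrix
t = torch.randn(3, dtype=torch.float64)                         # random translation
x_moved = x @ Q.T + t                                           # rotate/reflect, then translate

def squared_distances(coords):
    # ||x_i - x_j||^2 for every pair (i, j), shape (N, N)
    return ((coords[:, None, :] - coords[None, :, :]) ** 2).sum(dim=-1)

d2 = squared_distances(x)
d2_moved = squared_distances(x_moved)

# Permuting the nodes only permutes the rows and columns of the distance matrix
perm = torch.randperm(6)
d2_perm = squared_distances(x[perm])
```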


### Equivariant Graph Convolutional Layer (EGCL)
- **Input**: node embeddings $h^l = \{ h^l_0, ..., h^l_{N - 1}\}$, coordinate embeddings $x^l = \{ x^l_0, ..., x^l_{N - 1} \}$ and edge information $E$:

$$
    \{ h^l, x^l, E \}
$$

- **Output**: the updated node embeddings and coordinate embeddings:

$$
    h^{l+1}, x^{l+1} = \mathrm{EGCL}[h^l, x^l, E]
$$

The equations that define this layer are the following:
$$
    m_{ij} = \phi_e(h^l_i, h^l_j, || x^l_i - x^l_j ||^2, a_{ij})
$$
$$
    m_i = \sum_{j \in \mathcal{N}(i)} m_{ij}
$$
$$
    x^{l+1}_i = x^{l}_i + C \sum_{j \neq i} (x^l_i - x^l_j) \phi_x(m_{ij})
$$
$$
    h^{l+1}_i = \phi_h(h^l_i, m_i)
$$
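
It helps to have a direct, loop-based transcription of these equations to test a vectorized implementation against. The sketch below makes several assumptions and is not the paper's code. `egcl_reference` and the toy stand-ins for $\phi_e$, $\phi_x$ and $\phi_h$ are hypothetical. The normalization is $C = \frac{1}{N-1}$. $m_i$ sums only over pairs with $a_{ij} \neq 0$, while the coordinate sum runs over all $j \neq i$. The sketch also checks equivariance numerically.

```python
import torch

def egcl_reference(h, x, a, phi_e, phi_x, phi_h):
    """Loop-based EGCL update that mirrors the four equations term by term.

    h: (N, k) node features, x: (N, n) coordinates, a: (N, N) edge values a_ij,
    where a_ij == 0 means that j is not a neighbour of i.
    phi_e, phi_x, phi_h: arbitrary callables standing in for the three MLPs.
    """
    N = h.shape[0]
    C = 1.0 / (N - 1)
    h_new, x_new = [], []
    for i in range(N):
        messages, is_neighbour = [], []
        shift = torch.zeros_like(x[i])
        for j in range(N):
            if j == i:
                continue
            d2 = ((x[i] - x[j]) ** 2).sum()              # ||x_i - x_j||^2
            m_ij = phi_e(h[i], h[j], d2, a[i, j])
            messages.append(m_ij)
            is_neighbour.append(float(a[i, j] != 0))
            shift = shift + (x[i] - x[j]) * phi_x(m_ij)   # coordinate sum over all j != i
        messages = torch.stack(messages)                  # (N - 1, message size)
        mask = torch.tensor(is_neighbour, dtype=messages.dtype)
        m_i = (mask[:, None] * messages).sum(dim=0)       # m_i sums over neighbours only
        x_new.append(x[i] + C * shift)
        h_new.append(phi_h(h[i], m_i))
    return torch.stack(h_new), torch.stack(x_new)

# Toy, rotation-invariant stand-ins for the MLPs (hypothetical, for illustration only)
phi_e = lambda hi, hj, d2, aij: torch.stack([d2, aij, (hi * hj).sum()])
phi_x = lambda m: torch.tanh(m.sum())
phi_h = lambda hi, mi: hi + mi.sum()

torch.manual_seed(0)
N, k, n = 5, 4, 3
h = torch.randn(N, k, dtype=torch.float64)
x = torch.randn(N, n, dtype=torch.float64)
a = (torch.rand(N, N) > 0.5).double()
h1, x1 = egcl_reference(h, x, a, phi_e, phi_x, phi_h)

# Rotate/reflect and translate the input coordinates, then apply the layer again
Q, _ = torch.linalg.qr(torch.randn(n, n, dtype=torch.float64))
t = torch.randn(n, dtype=torch.float64)
h2, x2 = egcl_reference(h, x @ Q.T + t, a, phi_e, phi_x, phi_h)
```

The toy functions only see invariant quantities. As a result, $h^{l+1}$ is unchanged (`h1` equals `h2`) and $x^{l+1}$ transforms exactly like the input coordinates (`x1 @ Q.T + t` equals `x2`).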

Following the paper, the edge attributes are simply the edge values, $a_{ij} = e_{ij}$, although other attributes could also be included. The normalization constant is $C = \frac{1}{N-1}$, because the coordinate update sums over the $N - 1$ nodes $j \neq i$.
**It remains to define $\phi_e$, $\phi_x$ and $\phi_h$.** Following the appendix of the paper:
- Edge function **$\phi_e$**: a two-layer MLP with two Swish non-linearities: Input → {LinearLayer() → Swish() → LinearLayer() → Swish()} → Output.
- Coordinate function **$\phi_x$**: a two-layer MLP with one non-linearity: $m_{ij}$ → {LinearLayer() → Swish() → LinearLayer()} → Output.
- Node function **$\phi_h$**: a two-layer MLP with one non-linearity and a residual connection: $[h^l_i, m_i]$ → {LinearLayer() → Swish() → LinearLayer() → Addition($h^l_i$)} → $h^{l+1}_i$.

Note that the paper **does not specify how the inputs of the edge function are combined**. Given the nature of the inputs, we define the function as follows:
$$
    \hat{h}_{ij} = a_{ij} || x_i - x_j ||^2 \, [h_i, h_j] \\
    m_{ij} = \phi_e (\hat{h}_{ij})
$$
Here the product $a_{ij} || x_i - x_j ||^2$ is a scalar weight applied to the concatenation of the two embeddings $[h_i, h_j]$. One consequence is that $\hat{h}_{ij} = 0$ whenever $a_{ij} = 0$, so all non-adjacent pairs receive the same message $\phi_e(0)$. Other formulations are of course possible.
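
One natural alternative is to concatenate all four inputs, $[h_i, h_j, ||x_i - x_j||^2, a_{ij}]$, and feed the result to the MLP. To our knowledge this is close to what the authors' public code does, but treat that as unverified. The class `PhiEConcat` below is a hypothetical sketch of this variant. It uses `nn.SiLU` (Swish with $\beta = 1$) in place of the custom `Swish` module.

```python
import torch
import torch.nn as nn

class PhiEConcat(nn.Module):
    """Hypothetical edge function on the concatenation [h_i, h_j, ||x_i - x_j||^2, a_ij]."""
    def __init__(self, n_features, n_hidden):
        super().__init__()
        self.model = nn.Sequential(nn.Linear(2 * n_features + 2, n_hidden), nn.SiLU(),
                                   nn.Linear(n_hidden, n_hidden), nn.SiLU())

    def forward(self, edges, coordinates, embeddings):
        N = embeddings.shape[0]
        # Squared distances ||x_i - x_j||^2, shape (N, N)
        d2 = ((coordinates[:, None, :] - coordinates[None, :, :]) ** 2).sum(dim=-1)
        h_i = embeddings[:, None, :].expand(N, N, -1)   # h_i repeated along columns
        h_j = embeddings[None, :, :].expand(N, N, -1)   # h_j repeated along rows
        inputs = torch.cat([h_i, h_j, d2[..., None], edges[..., None]], dim=-1)  # (N, N, 2k + 2)
        return self.model(inputs)                       # m_ij, shape (N, N, n_hidden)

phi_e = PhiEConcat(n_features=8, n_hidden=16)
m = phi_e(torch.ones(4, 4), torch.randn(4, 3), torch.randn(4, 8))
```

Unlike the multiplicative weighting, this variant keeps the information in $h_i$ and $h_j$ even when $a_{ij} = 0$ or $i = j$.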
Let's first define the Swish activation function:

In [None]:
class Swish(nn.Module):
    def __init__(self):
        super(Swish, self).__init__()
    def forward(self, x):
        return x * torch.sigmoid(x)
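
Swish with a fixed slope of 1, $x \cdot \sigma(x)$, is the same function that PyTorch provides as `nn.SiLU`, so either module can be used. A quick check on a few points:

```python
import torch
import torch.nn as nn

class Swish(nn.Module):
    def forward(self, x):
        return x * torch.sigmoid(x)

x = torch.linspace(-5, 5, steps=11)
swish_out = Swish()(x)
silu_out = nn.SiLU()(x)   # built-in x * sigmoid(x)
```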

### Edge function
The architecture follows Appendix C of the paper, with the input combination described above.

In [None]:
class Phi_e(nn.Module):
    def __init__(self, n_input_features, n_hidden_features, n_output_features):
        super(Phi_e, self).__init__()
        # The MLP input is the concatenation [h_i, h_j], hence twice the node feature size
        n_input_features = n_input_features*2
        self.model = nn.Sequential(nn.Linear(n_input_features, n_hidden_features),
                                   Swish(),
                                   nn.Linear(n_hidden_features, n_output_features),
                                   Swish())

    def forward(self, Edges, Coordinates, Embeddings):
        # Edges: (N, N) edge values a_ij, Coordinates: (N, n), Embeddings: (N, k)
        n_nodes = Embeddings.shape[0]
        # Squared distance matrix ||x_i - x_j||^2, as in the definition of m_ij
        differences = Coordinates[:, None, :] - Coordinates[None, :, :]
        distance_matrix = (differences ** 2).sum(dim=-1)
        # Scalar weight a_ij * ||x_i - x_j||^2 for every pair (i, j)
        w = Edges * distance_matrix
        # All ordered pairs [h_i, h_j] (j = i included), shape (N, N, 2k).
        # The order matters: [h_i, h_j] and [h_j, h_i] need not give the same result.
        H = torch.cat([Embeddings[:, None, :].expand(n_nodes, n_nodes, -1),
                       Embeddings[None, :, :].expand(n_nodes, n_nodes, -1)], dim=-1)
        H_hat = w[:, :, None] * H
        # nn.Linear acts on the last dimension, so no flattening is needed
        M = self.model(H_hat)   # (N, N, n_output_features)
        return M

### Coordinate Function
The architecture follows Appendix C of the paper.

In [None]:
class Phi_x(nn.Module):
    def __init__(self, n_input_features, n_hidden_features, n_output_features):
        super(Phi_x, self).__init__()
        self.model = nn.Sequential(nn.Linear(n_input_features, n_hidden_features),
                                   Swish(),
                                   nn.Linear(n_hidden_features, n_output_features))

    def forward(self, M):
        return self.model(M)

### Node Function
The architecture follows Appendix C of the paper.

In [None]:
class Phi_h(nn.Module):
    def __init__(self, n_input_features, n_hidden_features, n_output_features):
        super(Phi_h, self).__init__()
        self.model = nn.Sequential(nn.Linear(n_input_features, n_hidden_features),
                                   Swish(),
                                   nn.Linear(n_hidden_features, n_output_features))

    def forward(self, H, M):
        # Input [h_i, m_i]; the residual connection requires n_output_features == size of H
        out = self.model(torch.cat([H, M], dim=1))
        return out + H

## Equivariant Graph Convolutional Layer
The layer below also supports node velocities. Following the paper, the coordinate update then becomes the following, where $\phi_v$ outputs one scalar per node:
$$
    v^{l+1}_i = \phi_v(h^l_i)v^{l}_i + C \sum_{j \neq i} (x^l_i - x^l_j) \phi_x(m_{ij}) \\
    x^{l+1}_i = x^l_i + v^{l+1}_i
$$
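
Suppose all pairwise differences are stored as a tensor with `diff[i, j]` $= x_i - x_j$. The sum over $j \neq i$ then runs along the second axis (`dim=1`). Summing along `dim=0` would instead give node $j$ the terms $x_i - x_j$, which point toward the other nodes rather than away from them, and would pair them with $\phi_x(m_{ij})$ instead of $\phi_x(m_{ji})$. The sketch below compares the vectorized velocity update against an explicit double loop. It uses random placeholder values for $\phi_x(m_{ij})$ and $\phi_v(h_i)$, which are not produced by any network.

```python
import torch

torch.manual_seed(0)
N, n = 5, 3
x = torch.randn(N, n)              # coordinates x_i^l
v = torch.randn(N, n)              # velocities v_i^l
phi_x_vals = torch.randn(N, N)     # placeholder for phi_x(m_ij), entry [i, j]
phi_v_vals = torch.rand(N, 1)      # placeholder for phi_v(h_i), one scalar per node
C = 1.0 / (N - 1)

diff = x[:, None, :] - x[None, :, :]           # diff[i, j] = x_i - x_j, shape (N, N, n)
off_diag = 1.0 - torch.eye(N)                  # zero out the j == i terms
weighted = diff * (phi_x_vals * off_diag)[..., None]
v_new = phi_v_vals * v + C * weighted.sum(dim=1)   # the sum over j is along dim=1
x_new = x + v_new

# The same update written as explicit loops over i and j != i
v_loop = torch.stack([
    phi_v_vals[i] * v[i]
    + C * sum((x[i] - x[j]) * phi_x_vals[i, j] for j in range(N) if j != i)
    for i in range(N)
])
```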

In [None]:
class EGCL(nn.Module):
    def __init__(self, n_input_features, n_hidden_features, n_output_features, with_velocity):
        super(EGCL, self).__init__()

        self.with_velocity = with_velocity
        if with_velocity:
            # phi_v: one scalar per node that scales the previous velocity
            self.phi_v = nn.Sequential(nn.Linear(n_input_features, n_hidden_features),
                                       Swish(),
                                       nn.Linear(n_hidden_features, 1))
        # phi_e reads the node embeddings, whose size is n_input_features
        self.phi_e = Phi_e(n_input_features=n_input_features,
                           n_hidden_features=n_hidden_features,
                           n_output_features=n_hidden_features)

        self.phi_x = Phi_x(n_input_features=n_hidden_features,
                           n_hidden_features=n_hidden_features,
                           n_output_features=1)

        # phi_h reads [h_i, m_i]; its residual connection requires n_output_features == n_input_features
        self.phi_h = Phi_h(n_input_features=n_input_features + n_hidden_features,
                           n_hidden_features=n_hidden_features,
                           n_output_features=n_output_features)

    def forward(self, x):
        # x is a tuple so that several layers can be chained with nn.Sequential
        if self.with_velocity:
            edges, coordinates, embeddings, velocities = x
        else:
            edges, coordinates, embeddings = x
        n_nodes = embeddings.shape[0]
        C = 1.0 / (n_nodes - 1)

        M_edge = self.phi_e(edges, coordinates, embeddings)   # m_ij, shape (N, N, hidden)
        # differences_of_positions[i, j] = x_i - x_j, shape (N, N, n)
        differences_of_positions = coordinates[:, None, :] - coordinates[None, :, :]
        # phi_x(m_ij) with the j == i terms set to zero, shape (N, N, 1)
        off_diagonal = 1.0 - torch.eye(n_nodes, device=M_edge.device, dtype=M_edge.dtype)
        partials = self.phi_x(M_edge) * off_diagonal[..., None]
        # Sum over j (dim=1): C * sum_{j != i} (x_i - x_j) phi_x(m_ij)
        coordinate_update = C * torch.sum(differences_of_positions * partials, dim=1)
        if self.with_velocity:
            new_velocities = self.phi_v(embeddings) * velocities + coordinate_update
            new_coordinates = coordinates + new_velocities
        else:
            new_coordinates = coordinates + coordinate_update

        # m_i: sum of the messages m_ij weighted by the edge values a_ij (neighbours only)
        M_aggregated = torch.sum(M_edge * edges.unsqueeze(-1), dim=1)
        new_embeddings = self.phi_h(embeddings, M_aggregated)
        if self.with_velocity:
            return edges, new_coordinates, new_embeddings, new_velocities
        return edges, new_coordinates, new_embeddings

## E(n) Equivariant Graph Neural Network
The paper describes the EGNN architecture used in the QM9 experiments as follows:
- _"Our EGNN consists of 7 layers.[...]"_
- _"Finally, the output of our EGNN $h^L$ is forwarded through a two layers MLP that acts node-wise, a sum pooling operation and another two layers MLP that maps the averaged embedding to the predicted property value, more formally: $h^L$ → {Linear() → Swish() → Linear() → Sum-Pooling() → Linear() → Swish() → Linear} → Property. The number of hidden features for all model hidden layers is 128"_
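
The per-graph sum pooling in `EnGNN.forward` loops over consecutive pairs of `batch_pointer`. These are PyTorch Geometric-style node offsets, where graph $g$ owns nodes `ptr[g]:ptr[g+1]`. A loop-free alternative is sketched below, assuming `ptr` is sorted and starts at 0. It builds the graph index of every node and scatters the embeddings with `index_add_`. The function name `sum_pool_ptr` is hypothetical. PyTorch Geometric also provides `global_add_pool`, which takes a per-node `batch` vector instead of `ptr`.

```python
import torch

def sum_pool_ptr(node_embeddings, ptr):
    """Sum node embeddings per graph, where graph g owns rows ptr[g]:ptr[g + 1]."""
    counts = ptr[1:] - ptr[:-1]                                  # number of nodes per graph
    n_graphs = counts.shape[0]
    graph_ids = torch.arange(n_graphs, device=ptr.device)
    node_to_graph = torch.repeat_interleave(graph_ids, counts)   # graph id of every node
    pooled = torch.zeros(n_graphs, node_embeddings.shape[1],
                         dtype=node_embeddings.dtype, device=node_embeddings.device)
    return pooled.index_add_(0, node_to_graph.to(node_embeddings.device), node_embeddings)

emb = torch.arange(12, dtype=torch.float32).reshape(6, 2)  # 6 nodes with 2 features each
ptr = torch.tensor([0, 2, 6])                              # graph 0: nodes 0-1, graph 1: nodes 2-5
pooled = sum_pool_ptr(emb, ptr)                            # [[2., 4.], [28., 32.]]
```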

In [None]:
class EnGNN(nn.Module):
    def __init__(self, n_input_features, n_hidden_features, n_output_features, with_velocity=False, n_layers=7):
        super(EnGNN, self).__init__()
        self.with_velocity = with_velocity
        self.embed = nn.Linear(n_input_features, n_hidden_features)
        # 7 EGCL layers by default, as in the QM9 experiments of the paper
        self.model = nn.Sequential(*[EGCL(n_hidden_features, n_hidden_features, n_hidden_features, with_velocity)
                                     for _ in range(n_layers)])
        # Node-wise MLP applied before pooling
        self.head_model_1 = nn.Sequential(nn.Linear(n_hidden_features, n_hidden_features),
                                          Swish(),
                                          nn.Linear(n_hidden_features, n_hidden_features))
        # Graph-level MLP applied after sum pooling
        self.head_model_2 = nn.Sequential(nn.Linear(n_hidden_features, n_hidden_features),
                                          Swish(),
                                          nn.Linear(n_hidden_features, n_output_features))

        # phi_inf (edge inference): defined but not used yet, see the Notes below
        self.edge_inferring_model = nn.Sequential(nn.Linear(n_hidden_features, 1),
                                                  nn.Sigmoid())

    def forward(self, Edges, Coordinates, Embeddings, batch_pointer, velocities=None):
        Embeddings = self.embed(Embeddings)
        if self.with_velocity:
            edges, coordinates, embeddings, velocities = self.model((Edges, Coordinates, Embeddings, velocities))
        else:
            edges, coordinates, embeddings = self.model((Edges, Coordinates, Embeddings))
        embeddings = self.head_model_1(embeddings)
        # Sum pooling per graph: batch_pointer holds PyTorch Geometric-style node offsets,
        # so graph g owns the nodes batch_pointer[g]:batch_pointer[g + 1]
        ranges = batch_pointer.unfold(0, 2, 1)
        embeddings = torch.stack([embeddings[start:end].sum(dim=0) for start, end in ranges])
        out = self.head_model_2(embeddings)
        return out

## Notes 
The paper leaves some implementation details unspecified, so we had to infer them:
- It does not define how the inputs of $\phi_e$ are combined:
$$
\phi_e(h^l_i, h^l_j, || x^l_i - x^l_j ||^2, a_{ij})
$$
- The paper uses $\phi_{\text{inf}}$ to infer the edges of the graph, starting from a fully connected one. Since it is applied as $\phi_{\text{inf}}(m_{ij})$, it depends on the output of $\phi_e$, which is part of each EGCL. It is therefore unclear whether $\phi_{\text{inf}}$ infers the edges at every layer, each time starting from a fully connected graph, or whether the edges are inferred only in the first layer.
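
In the paper, the inferred edge weights are $e_{ij} = \phi_{\text{inf}}(m_{ij})$, where $\phi_{\text{inf}}$ is a linear layer followed by a sigmoid (as in `edge_inferring_model` above). These weights enter the aggregation as $m_i = \sum_{j \neq i} e_{ij} m_{ij}$. Under the per-layer interpretation, each EGCL would own its own $\phi_{\text{inf}}$. The hypothetical module `EdgeInference` below sketches that reading. It reflects our interpretation and is not a confirmed detail of the paper.

```python
import torch
import torch.nn as nn

class EdgeInference(nn.Module):
    """Hypothetical per-layer phi_inf: e_ij = sigmoid(Linear(m_ij))."""
    def __init__(self, n_hidden):
        super().__init__()
        self.model = nn.Sequential(nn.Linear(n_hidden, 1), nn.Sigmoid())

    def forward(self, messages):
        # messages: (N, N, n_hidden) edge messages m_ij of the current layer
        N = messages.shape[0]
        e = self.model(messages).squeeze(-1)                        # soft edge weights in (0, 1)
        e = e * (1.0 - torch.eye(N, device=messages.device))        # drop self-edges (j != i)
        m_i = (e[..., None] * messages).sum(dim=1)                  # m_i = sum_{j != i} e_ij m_ij
        return e, m_i

phi_inf = EdgeInference(n_hidden=16)
e, m_i = phi_inf(torch.randn(5, 5, 16))
```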