In [20]:
#|default_exp modelsgraph

In [21]:
#| export
import sys
sys.path.append('/opt/slh/archive/software/graphnet/src')


from x_transformers import ContinuousTransformerWrapper, Encoder, Decoder
from graphnet.models.task.reconstruction import DirectionReconstructionWithKappa
from graphnet.training.loss_functions import VonMisesFisher3DLoss
from graphnet.training.labels import Direction

from typing import Callable, Optional, Union
import torch
from torch import nn
from torch.nn import functional as F
import torch_geometric
from torch_geometric.nn import SchNet, global_add_pool, global_mean_pool
import torch_scatter
from torch_scatter import scatter
from torch.nn import Linear, ReLU, SiLU, Sequential
from torch_geometric.nn import MessagePassing, global_add_pool, global_mean_pool
from torch_scatter import scatter
from torch_geometric.loader import DataLoader as gDataLoader


In [22]:
sys.path.append('/opt/slh/icecube')
from icecube.graphdataset import GraphDasetV0
from datasets import  load_from_disk

In [23]:
#| export

class EGNNLayer(MessagePassing):
    r"""Single E(n)-equivariant message-passing layer.

    Paper: E(n) Equivariant Graph Neural Networks, Satorras et al.

    For every edge the layer builds a message from the two endpoint features
    and their Euclidean distance, then uses the aggregated messages to update
    node features and the aggregated (scaled) displacement vectors to update
    node coordinates.
    """

    def __init__(self, emb_dim, activation="relu", norm="layer", aggr="add"):
        r"""Build the three MLPs used by the layer.

        Args:
            emb_dim: (int) - hidden dimension `d`
            activation: (str) - non-linearity within MLPs (swish/relu)
            norm: (str) - normalisation layer (layer/batch)
            aggr: (str) - aggregation function `\oplus` for the feature
                messages; forwarded both to `MessagePassing` and to
                `torch_scatter.scatter` (e.g. add/sum/mean/max)

        Raises:
            KeyError: if `activation` or `norm` is not one of the listed names.
        """
        # Set the aggregation function
        super().__init__(aggr=aggr)

        self.emb_dim = emb_dim
        # One activation instance is shared by all three MLPs.
        self.activation = {"swish": SiLU(), "relu": ReLU()}[activation]
        # Stored as a class (not an instance): each MLP stage instantiates its own norm.
        self.norm = {"layer": torch.nn.LayerNorm, "batch": torch.nn.BatchNorm1d}[norm]

        # MLP `\psi_h` for computing messages `m_ij`;
        # input is [h_i, h_j, ||x_i - x_j||], hence 2 * emb_dim + 1 features.
        self.mlp_msg = Sequential(
            Linear(2 * emb_dim + 1, emb_dim),
            self.norm(emb_dim),
            self.activation,
            Linear(emb_dim, emb_dim),
            self.norm(emb_dim),
            self.activation,
        )
        # MLP `\psi_x` for computing messages `\overrightarrow{m}_ij`;
        # produces one scalar per edge that scales the displacement vector.
        self.mlp_pos = Sequential(
            Linear(emb_dim, emb_dim),
            self.norm(emb_dim),
            self.activation,
            Linear(emb_dim, 1),
        )
        # MLP `\phi` for computing updated node features `h_i^{l+1}`;
        # input is [h_i, aggregated messages], hence 2 * emb_dim features.
        self.mlp_upd = Sequential(
            Linear(2 * emb_dim, emb_dim),
            self.norm(emb_dim),
            self.activation,
            Linear(emb_dim, emb_dim),
            self.norm(emb_dim),
            self.activation,
        )

    def forward(self, h, pos, edge_index):
        """Run one round of message passing.

        Args:
            h: (n, d) - initial node features
            pos: (n, 3) - initial node coordinates
            edge_index: (2, e) - edge list in PyTorch Geometric's layout
        Returns:
            out: [(n, d), (n, 3)] - updated node features and coordinates
        """
        return self.propagate(edge_index, h=h, pos=pos)

    def message(self, h_i, h_j, pos_i, pos_j):
        """Compute, per edge, the feature message and the scaled displacement."""
        # Relative position and its length for every edge
        pos_diff = pos_i - pos_j
        dists = torch.norm(pos_diff, dim=-1).unsqueeze(1)
        # Feature message from both endpoints and their distance
        msg = self.mlp_msg(torch.cat([h_i, h_j, dists], dim=-1))
        # Scale magnitude of displacement vector by a learned per-edge scalar.
        # NOTE: the scale is unbounded; clamping it (e.g. to [-100, 100]) is an
        # option if coordinate updates become unstable.
        pos_diff = pos_diff * self.mlp_pos(msg)
        return msg, pos_diff

    def aggregate(self, inputs, index, dim_size=None):
        """Reduce per-edge messages and displacements onto their target nodes.

        Args:
            inputs: tuple `(msgs, pos_diffs)` as returned by `message`
            index: (e,) - target node index of every edge
            dim_size: (int, optional) - number of nodes, supplied by
                `propagate`; fixes the number of output rows
        Returns:
            (msg_aggr, pos_aggr) - aggregated messages and displacements
        """
        msgs, pos_diffs = inputs
        # Without `dim_size`, scatter sizes its output as index.max() + 1, so
        # trailing nodes that receive no edge would be dropped and `update`
        # would fail on a row-count mismatch when concatenating with `h`.
        # Aggregate messages with the configured reduction
        msg_aggr = scatter(
            msgs, index, dim=self.node_dim, dim_size=dim_size, reduce=self.aggr
        )
        # Aggregate displacement vectors (always averaged, independent of `aggr`)
        pos_aggr = scatter(
            pos_diffs, index, dim=self.node_dim, dim_size=dim_size, reduce="mean"
        )
        return msg_aggr, pos_aggr

    def update(self, aggr_out, h, pos):
        """Combine aggregated values with the current node state."""
        msg_aggr, pos_aggr = aggr_out
        # New features from the old features and the aggregated messages
        upd_out = self.mlp_upd(torch.cat([h, msg_aggr], dim=-1))
        # Coordinates are shifted by the averaged, scaled displacement
        upd_pos = pos + pos_aggr
        return upd_out, upd_pos

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(emb_dim={self.emb_dim}, aggr={self.aggr})"
    
    
class EGNNModel(torch.nn.Module):
    r"""Stack of `EGNNLayer`s with graph-level pooling and a direction head.

    Node features are linearly projected to the hidden size, passed through
    the message-passing layers, pooled per graph, and fed through a small MLP
    into a `DirectionReconstructionWithKappa` task head.
    """

    def __init__(
        self,
        num_layers=5,
        emb_dim=128,
        in_dim=9,
        out_dim=4,
        activation="relu",
        norm="layer",
        aggr="sum",
        pool="sum",
        residual=True
    ):
        r"""E(n) Equivariant GNN model

        Args:
            num_layers: (int) - number of message passing layers
            emb_dim: (int) - hidden dimension
            in_dim: (int) - initial node feature dimension
            out_dim: (int) - accepted for interface compatibility but not
                read by this class; the output width is set by the task head
            activation: (str) - non-linearity within MLPs (swish/relu)
            norm: (str) - normalisation layer (layer/batch)
            aggr: (str) - aggregation function `\oplus` (sum/mean/max)
            pool: (str) - global pooling function (sum/mean)
            residual: (bool) - whether to use residual connections

        Raises:
            KeyError: if `pool` is not "mean" or "sum" (or if `activation` /
                `norm` is rejected by `EGNNLayer`).
        """
        super().__init__()

        # Linear projection of the raw node features into the hidden space
        self.emb_in = nn.Linear(in_dim, emb_dim)

        # Stack of GNN layers
        self.convs = torch.nn.ModuleList(
            [EGNNLayer(emb_dim, activation, norm, aggr) for _ in range(num_layers)]
        )

        # Global pooling/readout function
        self.pool = {"mean": global_mean_pool, "sum": global_add_pool}[pool]

        # MLP applied to the pooled graph representation
        self.postpool = torch.nn.Sequential(
            torch.nn.Linear(emb_dim, emb_dim),
            torch.nn.ReLU(),
        )

        # Task head producing the direction prediction together with kappa
        self.out = DirectionReconstructionWithKappa(
            hidden_size=emb_dim,
            target_labels='direction',
            loss_function=VonMisesFisher3DLoss(),
        )

        self.residual = residual

    def forward(self, batch):
        """Predict one direction (plus kappa) per graph in `batch`.

        Args:
            batch: batched graph object providing `x` (node features), `pos`
                (node coordinates), `edge_index` and `batch` (graph
                assignment of every node)
        Returns:
            Output of the task head, one row per graph; with the default
            settings this is (batch_size, 4).
        """
        h = self.emb_in(batch.x)  # (n, in_dim) -> (n, d)
        pos = batch.pos  # (n, 3)
        edge_index = batch.edge_index

        for conv in self.convs:
            # Message passing layer; coordinates are replaced outright (no residual)
            h_update, pos = conv(h, pos, edge_index)

            # Update node features (n, d) -> (n, d)
            h = h + h_update if self.residual else h_update

        pooled = self.pool(h, batch.batch)  # (n, d) -> (batch_size, d)
        hidden = self.postpool(pooled)  # (batch_size, d) -> (batch_size, d)
        return self.out(hidden)  # (batch_size, d) -> (batch_size, 4)


In [24]:
# Read the saved `datasets` batch from disk, wrap it as a graph dataset and
# draw one shuffled mini-batch of 64 graphs to use as model input below.
hf_batch = load_from_disk('/opt/slh/icecube/data/hf_cashe/batch_1.parquet')
ds = GraphDasetV0(hf_batch)
dl = gDataLoader(ds, shuffle=True, batch_size=64)
x = next(iter(dl))



In [25]:
# Smoke test: one gradient-free forward pass of a default model in eval mode.
md = EGNNModel()
md.eval()
with torch.no_grad():
    out = md(x)

In [26]:
# One output row per graph in the mini-batch (64 graphs -> 64 rows of width 4).
# NOTE(review): presumably 3 direction components + kappa - confirm against the task head.
out.shape

torch.Size([64, 4])

In [27]:
# View the batch's `y` as rows of 3 values for comparison with the model output.
# NOTE(review): presumably one unit direction vector per event - confirm against the dataset.
x.y.reshape(-1, 3)

tensor([[ 4.3360e-01,  7.8692e-01, -4.3903e-01],
        [ 3.5603e-01, -6.5630e-01,  6.6522e-01],
        [ 5.6932e-01,  5.3317e-01,  6.2578e-01],
        [-8.7482e-01, -4.5451e-01,  1.6765e-01],
        [-8.0436e-01, -4.1511e-01, -4.2506e-01],
        [-4.7833e-01,  1.7376e-01,  8.6082e-01],
        [ 2.5790e-01,  9.4348e-01, -2.0815e-01],
        [-7.0665e-02,  9.0034e-01,  4.2940e-01],
        [ 4.8997e-01,  2.4937e-02, -8.7138e-01],
        [ 7.0611e-01,  1.3525e-01, -6.9506e-01],
        [ 4.0584e-01, -9.1394e-01, -7.8477e-04],
        [-8.5595e-01,  3.1931e-01, -4.0668e-01],
        [-6.3817e-02, -8.8359e-01,  4.6389e-01],
        [-7.1645e-01,  6.9421e-01, -6.9115e-02],
        [-1.7384e-01, -9.2958e-01,  3.2505e-01],
        [ 2.4919e-01, -9.6507e-01, -8.0927e-02],
        [ 6.3849e-01, -1.8381e-01, -7.4736e-01],
        [-4.6091e-01, -2.2264e-01, -8.5907e-01],
        [ 1.3469e-01,  8.9402e-01,  4.2731e-01],
        [ 3.2838e-01,  1.9967e-01, -9.2320e-01],
        [-1.0364e-01

In [28]:
#|hide
#|eval: false
# Export the cells marked `#| export` into the library module via nbdev.
from nbdev.doclinks import nbdev_export
nbdev_export()