In [10]:
#|default_exp modelsgraph

In [11]:
#| export
import sys
sys.path.append('/opt/slh/archive/software/graphnet/src')

from x_transformers import ContinuousTransformerWrapper, Encoder, Decoder
from graphnet.models.task.reconstruction import DirectionReconstructionWithKappa
from graphnet.training.loss_functions import VonMisesFisher3DLoss
from graphnet.training.labels import Direction

from typing import Callable, Optional, Union
import torch
from torch import nn
from torch.nn import functional as F
import torch_geometric
from torch_geometric.nn import SchNet, global_add_pool, global_mean_pool
import torch_scatter
from torch_scatter import scatter
from torch.nn import Linear, ReLU, SiLU, Sequential
from torch_geometric.nn import MessagePassing, global_add_pool, global_mean_pool
from torch_scatter import scatter
from torch_geometric.loader import DataLoader as gDataLoader


In [12]:
sys.path.append('/opt/slh/icecube')
from icecube.graphdataset import GraphDasetV0
from datasets import  load_from_disk

In [13]:
#| export

class EGNNLayer(MessagePassing):
    def __init__(self, emb_dim, activation="relu", norm="layer", aggr="add"):
        r"""E(n) Equivariant GNN Layer
        Paper: E(n) Equivariant Graph Neural Networks, Satorras et al.

        Jointly updates node features `h` and node coordinates `pos`.

        Args:
            emb_dim: (int) - hidden dimension `d`
            activation: (str) - non-linearity within MLPs (swish/relu)
            norm: (str) - normalisation layer (layer/batch)
            aggr: (str) - aggregation function `\oplus` for the feature messages.
                It is handed both to ``MessagePassing`` and to
                ``torch_scatter.scatter``, so it must be a name both accept
                (e.g. add/sum/mean/max). Coordinate updates are always
                mean-aggregated.

        Raises:
            KeyError: if `activation` or `norm` is not one of the supported names.
        """
        # Set the aggregation function
        super().__init__(aggr=aggr)

        self.emb_dim = emb_dim
        # A single activation module is reused inside every MLP below.
        self.activation = {"swish": SiLU(), "relu": ReLU()}[activation]
        # Normalisation *class* (instantiated once per use below).
        self.norm = {"layer": torch.nn.LayerNorm, "batch": torch.nn.BatchNorm1d}[norm]

        # MLP `\psi_h` for computing messages `m_ij` from [h_i, h_j, ||pos_i - pos_j||]
        self.mlp_msg = Sequential(
            Linear(2 * emb_dim + 1, emb_dim),
            self.norm(emb_dim),
            self.activation,
            Linear(emb_dim, emb_dim),
            self.norm(emb_dim),
            self.activation,
        )
        # MLP `\psi_x`: maps each message to one scalar weight for its displacement vector
        self.mlp_pos = Sequential(
            Linear(emb_dim, emb_dim), self.norm(emb_dim), self.activation, Linear(emb_dim, 1)
        )
        # MLP `\phi` for computing updated node features `h_i^{l+1}` from [h_i, aggregated m_i]
        self.mlp_upd = Sequential(
            Linear(2 * emb_dim, emb_dim),
            self.norm(emb_dim),
            self.activation,
            Linear(emb_dim, emb_dim),
            self.norm(emb_dim),
            self.activation,
        )

    def forward(self, h, pos, edge_index):
        """Run one round of message passing.

        Args:
            h: (n, d) - node features
            pos: (n, 3) - node coordinates
            edge_index: (2, e) - edge list in PyG COO layout
                (row 0 = source `j`, row 1 = target `i` under the default flow)
        Returns:
            (h_out, pos_out): updated node features (n, d) and
            updated node coordinates (n, 3)
        """
        return self.propagate(edge_index, h=h, pos=pos)

    def message(self, h_i, h_j, pos_i, pos_j):
        """Build per-edge messages.

        Returns:
            (msg, pos_diff): feature message `m_ij` (e, d) and the displacement
            `pos_i - pos_j` (e, 3) scaled by a learned per-edge scalar.
        """
        pos_diff = pos_i - pos_j
        dists = torch.norm(pos_diff, dim=-1, keepdim=True)  # (e, 1)
        msg = self.mlp_msg(torch.cat([h_i, h_j, dists], dim=-1))
        # Scale magnitude of displacement vector. The weight is unbounded;
        # clamping it is a possible stability tweak if coordinates diverge.
        pos_diff = pos_diff * self.mlp_pos(msg)
        return msg, pos_diff

    def aggregate(self, inputs, index, ptr=None, dim_size=None):
        """Reduce per-edge messages onto their target nodes.

        Args:
            inputs: (msgs, pos_diffs) as returned by :meth:`message`
            index: (e,) target-node index of every edge
            ptr: unused; accepted to mirror ``MessagePassing.aggregate``
            dim_size: number of nodes. It is forwarded to ``scatter`` so that
                nodes with no incoming edges (including the highest-index ones,
                or every node of an edgeless graph) still get a zero-filled row.
                The output then always lines up row-for-row with `h` in
                :meth:`update`.
        Returns:
            (msg_aggr, pos_aggr): (n, d) aggregated messages (reduced with
            `self.aggr`) and (n, 3) mean displacement.
        """
        msgs, pos_diffs = inputs
        msg_aggr = scatter(msgs, index, dim=self.node_dim, dim_size=dim_size, reduce=self.aggr)
        pos_aggr = scatter(pos_diffs, index, dim=self.node_dim, dim_size=dim_size, reduce="mean")
        return msg_aggr, pos_aggr

    def update(self, aggr_out, h, pos):
        """Compute new node features from [h, aggregated messages] and shift coordinates."""
        msg_aggr, pos_aggr = aggr_out
        upd_out = self.mlp_upd(torch.cat([h, msg_aggr], dim=-1))
        upd_pos = pos + pos_aggr
        return upd_out, upd_pos

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(emb_dim={self.emb_dim}, aggr={self.aggr})"
    
    
class EGNNModel(torch.nn.Module):
    def __init__(
        self,
        num_layers=5,
        emb_dim=128,
        in_dim=9,
        out_dim=4,
        activation="relu",
        norm="layer",
        aggr="sum",
        pool="sum",
        residual=True
    ):
        r"""E(n) Equivariant GNN model with a graphnet direction-reconstruction head.

        Args:
            num_layers: (int) - number of message passing layers
            emb_dim: (int) - hidden dimension
            in_dim: (int) - initial node feature dimension
            out_dim: (int) - currently unused: the output width is fixed by the
                ``DirectionReconstructionWithKappa`` head. Kept so existing
                callers that pass it keep working.
            activation: (str) - non-linearity within MLPs (swish/relu)
            norm: (str) - normalisation layer (layer/batch)
            aggr: (str) - aggregation function `\oplus` (sum/mean/max)
            pool: (str) - global pooling function (sum/mean)
            residual: (bool) - whether to use residual connections on node features

        Raises:
            KeyError: if `pool` (or `activation`/`norm`, via EGNNLayer) is not a
                supported name.
        """
        super().__init__()

        # Linear projection of raw node features into the hidden space
        self.emb_in = nn.Linear(in_dim, emb_dim)

        # Stack of GNN layers
        self.convs = torch.nn.ModuleList(
            EGNNLayer(emb_dim, activation, norm, aggr) for _ in range(num_layers)
        )

        # Global pooling/readout function (node -> graph)
        self.pool = {"mean": global_mean_pool, "sum": global_add_pool}[pool]

        # Predictor MLP applied to the pooled graph embedding
        self.postpool = torch.nn.Sequential(
            torch.nn.Linear(emb_dim, emb_dim),
            torch.nn.ReLU()
        )

        # graphnet task head, trained with a von Mises-Fisher 3D loss
        self.out = DirectionReconstructionWithKappa(
            hidden_size=emb_dim,
            target_labels='direction',
            loss_function=VonMisesFisher3DLoss(),
        )

        self.residual = residual

    def forward(self, batch):
        """Predict one output row per graph in a batched PyG graph.

        Args:
            batch: batched graph exposing ``x`` (n, in_dim), ``pos`` (n, 3),
                ``edge_index`` (2, e) and ``batch`` (n,) graph-assignment vector.
        Returns:
            Output of the direction head, one row per graph. Presumably
            (batch_size, 4) = 3 direction components + kappa; TODO confirm
            against graphnet's DirectionReconstructionWithKappa.
        """
        h = self.emb_in(batch.x)  # (n, in_dim) -> (n, d)
        pos = batch.pos  # (n, 3)

        for conv in self.convs:
            h_update, pos_update = conv(h, pos, batch.edge_index)
            # Residual on node features only; coordinates are replaced outright
            # because EGNNLayer already returns pos + displacement.
            h = (h + h_update) if self.residual else h_update
            pos = pos_update

        out = self.pool(h, batch.batch)  # (n, d) -> (batch_size, d)
        out = self.postpool(out)  # (batch_size, d) -> (batch_size, d)
        return self.out(out)  # (batch_size, d) -> (batch_size, 4)


In [14]:
# Load one shard of the cached HF dataset and draw a reproducible demo batch.
DATA_PATH = '/opt/slh/icecube/data/hf_cashe/batch_1.parquet'
BATCH_SIZE = 64
SEED = 42

ds = GraphDasetV0(load_from_disk(DATA_PATH))
# Seeded generator: shuffle=True would otherwise yield a different batch every run.
dl = gDataLoader(ds, batch_size=BATCH_SIZE, shuffle=True,
                 generator=torch.Generator().manual_seed(SEED))
x = next(iter(dl))



In [15]:
# Forward-only smoke test: eval mode, and no autograd graph is recorded.
md = EGNNModel()
md.eval()
with torch.no_grad():
    out = md(x)

In [16]:
# One prediction row per graph in the batch, 4 output columns.
assert out.shape == (x.num_graphs, 4), out.shape
out.shape

torch.Size([64, 4])

In [17]:
# Regroup the flat target tensor into rows of 3 (assumed: one direction vector per event).
torch.reshape(x.y, (-1, 3))

tensor([[-0.3024, -0.2626, -0.9163],
        [ 0.5685, -0.7906, -0.2276],
        [-0.5923,  0.4075,  0.6950],
        [ 0.9434, -0.0632,  0.3256],
        [-0.3884,  0.4468,  0.8059],
        [ 0.9331,  0.3209, -0.1622],
        [ 0.2757,  0.1921,  0.9419],
        [-0.2261,  0.7435, -0.6293],
        [-0.1910, -0.5903,  0.7843],
        [ 0.4528, -0.7029,  0.5486],
        [ 0.3039, -0.9135,  0.2705],
        [ 0.5533, -0.2467,  0.7956],
        [ 0.0754, -0.3414,  0.9369],
        [-0.4229,  0.8572,  0.2937],
        [ 0.8806,  0.1395, -0.4529],
        [-0.1885,  0.7806,  0.5959],
        [-0.1094, -0.1582,  0.9813],
        [ 0.9775, -0.0162,  0.2105],
        [ 0.9373, -0.3399,  0.0770],
        [-0.2167, -0.2301, -0.9487],
        [-0.9903, -0.1197,  0.0710],
        [-0.5924,  0.6438,  0.4843],
        [-0.0430,  0.8164, -0.5759],
        [-0.6117, -0.2344,  0.7556],
        [-0.3599, -0.5484, -0.7548],
        [-0.9938, -0.1021,  0.0447],
        [-0.6852, -0.7262, -0.0564],
 

In [18]:
#|hide
#|eval: false
# Regenerate the library module from this notebook's `#| export` cells.
from nbdev.doclinks import nbdev_export
nbdev_export()