In [129]:
#|default_exp models

In [130]:
#| export
import sys
sys.path.append('/opt/slh/archive/software/graphnet/src')
import torch
from x_transformers import ContinuousTransformerWrapper, Encoder, Decoder
from torch import nn, einsum
from einops import rearrange
import torch.nn.functional as F
from datasets import load_from_disk
from abc import abstractmethod
from torch import Tensor
from typing import Optional, Any
import scipy
import numpy as np
from graphnet.models.task.reconstruction import DirectionReconstructionWithKappa
from graphnet.training.loss_functions import VonMisesFisher3DLoss
from graphnet.training.labels import Direction


In [131]:
# | export

class VonMisesFisher3DLossCosineSimularityLoss(nn.Module):
    """Average of graphnet's ``VonMisesFisher3DLoss`` and a cosine-distance term.

    Args:
        eps: Forwarded to ``nn.CosineSimilarity`` as its ``eps`` argument.
    """

    def __init__(self, eps=1e-8):
        super().__init__()
        self.vonmis = VonMisesFisher3DLoss()
        self.cosine = nn.CosineSimilarity(dim=1, eps=eps)

    def forward(self, y_pred, y_true):
        """Return ``(vmf_loss + (1 - mean cosine similarity)) / 2``.

        The cosine term only looks at the first three columns of ``y_pred``;
        the von Mises-Fisher term receives ``y_pred`` unchanged.
        """
        vmf_term = self.vonmis(y_pred, y_true)
        direction = y_pred[:, :3]
        cosine_distance = 1 - self.cosine(direction, y_true).mean()
        return (vmf_term + cosine_distance) / 2
        
    
    
class LogCoshLoss(nn.Module):
    """Mean ``log(cosh(y_t - y_prime_t))`` computed in a numerically stable way.

    The naive ``torch.log(torch.cosh(e))`` overflows to ``inf`` once ``|e|``
    is large enough for ``cosh`` to exceed the float range (roughly 89 in
    float32). The identity ``log(cosh(e)) = |e| + log(1 + exp(-2|e|)) - log(2)``
    yields the same value without ever forming ``cosh(e)``.
    """

    # log(2), kept as a literal so the block needs no extra import.
    _LOG_2 = 0.6931471805599453

    def __init__(self):
        super().__init__()

    def forward(self, y_t, y_prime_t):
        """Return the scalar mean log-cosh of the element-wise error.

        Args:
            y_t: Tensor of values.
            y_prime_t: Tensor broadcastable against ``y_t``.
        """
        abs_err = torch.abs(y_t - y_prime_t)
        # softplus(-2|e|) == log(1 + exp(-2|e|)); its argument is never positive,
        # so the exponential cannot overflow.
        return torch.mean(abs_err + F.softplus(-2.0 * abs_err) - self._LOG_2)


class SigmoidRange(nn.Module):
    """Map inputs into the range between ``low`` and ``high`` via a scaled sigmoid.

    Args:
        low: Value approached as the input goes to negative infinity.
        high: Value approached as the input goes to positive infinity.
    """

    def __init__(self, low, high):
        super().__init__()
        self.low, self.high = low, high

    def forward(self, x):
        """Return ``sigmoid(x) * (high - low) + low`` element-wise."""
        span = self.high - self.low
        return torch.sigmoid(x) * span + self.low


class Adjustoutput(nn.Module):
    """Squash columns 0 and 1 of a 2-D tensor into fixed ranges, in place.

    Column 0 goes through ``self.az`` and column 1 through ``self.zn``.
    NOTE(review): the bounds are hard-coded; presumably they are observed
    azimuth / zenith limits in radians -- confirm against the dataset.
    """

    def __init__(self):
        super().__init__()
        self.az = SigmoidRange(6.436839548775502e-08, 6.2891)
        self.zn = SigmoidRange(8.631674577710722e-05, 3.1417)

    def forward(self, x):
        """Overwrite ``x[:, 0]`` and ``x[:, 1]`` and return the same tensor object.

        The input is mutated: callers holding a reference to ``x`` see the
        adjusted values as well.
        """
        for column, squash in enumerate((self.az, self.zn)):
            x[:, column] = squash(x[:, column])
        return x


class PoolingWithMask(nn.Module):
    """Pool a padded sequence over dim 1 while ignoring masked-out positions.

    Args:
        pool_type: One of ``"mean"``, ``"max"`` or ``"min"``. The value is
            validated when ``forward`` is called.
    """

    def __init__(self, pool_type):
        super().__init__()
        self.pool_type = pool_type

    def forward(self, x, mask):
        """Reduce ``x`` along dim 1 using only positions where ``mask`` is non-zero.

        Args:
            x: Floating-point tensor indexed as (batch, sequence, features).
            mask: Tensor indexed as (batch, sequence); non-zero / True marks a
                position that takes part in the pooling.

        Returns:
            Tensor indexed as (batch, features). Rows without any valid
            position yield zeros.

        Raises:
            ValueError: If ``pool_type`` is not ``"mean"``, ``"max"`` or ``"min"``.
        """
        if self.pool_type == "mean":
            # Zero out the padded values, then divide by the number of valid ones.
            summed = torch.sum(x * mask.unsqueeze(-1), dim=1)
            count = torch.sum(mask, dim=1, keepdim=True)
            # Guard against 0/0 for rows that are entirely padding.
            count = count.masked_fill(count == 0, 1)
            return summed / count

        if self.pool_type not in ("max", "min"):
            raise ValueError("Invalid pool_type. Choose from ['mean', 'max', 'min']")

        valid = mask.bool().unsqueeze(-1)
        limits = torch.finfo(x.dtype)
        # Padded slots must never win the reduction. Multiplying them by zero
        # (as was done before) made 0 the max of all-negative rows and the min
        # of all-positive rows, so fill them with the opposite extreme instead.
        if self.pool_type == "max":
            pooled, _ = torch.max(x.masked_fill(~valid, limits.min), dim=1)
        else:
            pooled, _ = torch.min(x.masked_fill(~valid, limits.max), dim=1)

        # Rows with no valid position fall back to zeros rather than +-finfo.max.
        return pooled.masked_fill(~valid.any(dim=1), 0.0)


class MeanPoolingWithMask(nn.Module):
    """Masked average over dim 1 of a padded sequence tensor."""

    def __init__(self):
        super().__init__()

    def forward(self, x, mask):
        """Average ``x`` over dim 1, weighting each position by ``mask``.

        Args:
            x: Tensor indexed as (batch, sequence, features).
            mask: Tensor indexed as (batch, sequence); zero entries drop the
                corresponding position from the average.
        """
        # Broadcast the mask over the feature axis so padded rows contribute 0.
        weights = mask.unsqueeze(-1)
        masked_total = (x * weights).sum(dim=1)
        # Number of contributing positions per sample, kept 2-D for broadcasting.
        n_valid = mask.sum(dim=1, keepdim=True)
        return masked_total / n_valid


class FeedForward(nn.Module):
    """Two-layer MLP: ``Linear(dim, dim * mult) -> GELU -> Linear(dim * mult, dim_out)``.

    Args:
        dim: Input feature size.
        dim_out: Output feature size. ``None`` (the default) means "same as
            ``dim``"; previously the default was passed straight to
            ``nn.Linear`` and failed at construction time.
        mult: Width multiplier of the hidden layer.
    """

    def __init__(self, dim, dim_out=None, mult=4):
        super().__init__()
        out_features = dim if dim_out is None else dim_out
        hidden = dim * mult
        self.net = nn.Sequential(
            nn.Linear(dim, hidden), nn.GELU(), nn.Linear(hidden, out_features)
        )

    def forward(self, x):
        """Apply the MLP to the last dimension of ``x``."""
        return self.net(x)


class IceCubeModelEncoderV0(nn.Module):
    """Transformer encoder over 6-feature inputs with a 2-output head.

    The sequence is reduced with a plain ``mean`` over dim 1, so every
    position -- including any the mask excludes from attention -- takes part
    in the average.
    """

    def __init__(self):
        super().__init__()
        attn_layers = Encoder(dim=128, depth=6, heads=8)
        self.encoder = ContinuousTransformerWrapper(
            dim_in=6, dim_out=128, max_seq_len=150, attn_layers=attn_layers
        )
        self.head = FeedForward(128, 2)

    def forward(self, batch):
        """Run the model on a batch dict with ``"event"`` and ``"mask"`` entries."""
        events, padding_mask = batch["event"], batch["mask"]
        encoded = self.encoder(events, mask=padding_mask)
        # Unmasked average over the sequence axis.
        return self.head(encoded.mean(dim=1))


class IceCubeModelEncoderV1(nn.Module):
    """Transformer encoder over 8-feature inputs, masked mean pooling, 2-output head."""

    def __init__(self):
        super().__init__()
        attn_layers = Encoder(dim=128, depth=6, heads=8)
        self.encoder = ContinuousTransformerWrapper(
            dim_in=8, dim_out=128, max_seq_len=150, attn_layers=attn_layers
        )
        self.pool = MeanPoolingWithMask()
        self.head = FeedForward(128, 2)

    def forward(self, batch):
        """Run the model on a batch dict with ``"event"`` and ``"mask"`` entries."""
        events, padding_mask = batch["event"], batch["mask"]
        encoded = self.encoder(events, mask=padding_mask)
        pooled = self.pool(encoded, padding_mask)
        return self.head(pooled)


class IceCubeModelEncoderV1CombinePool(nn.Module):
    """Transformer encoder whose mean- and max-pooled outputs are concatenated.

    The head therefore receives ``128 * 2`` features and emits 2 values.
    """

    def __init__(self):
        super().__init__()
        attn_layers = Encoder(dim=128, depth=6, heads=8)
        self.encoder = ContinuousTransformerWrapper(
            dim_in=8, dim_out=128, max_seq_len=150, attn_layers=attn_layers
        )
        self.pool_mean = PoolingWithMask('mean')
        self.pool_max = PoolingWithMask('max')
        self.head = FeedForward(128 * 2, 2)

    def forward(self, batch):
        """Run the model on a batch dict with ``"event"`` and ``"mask"`` entries."""
        events, padding_mask = batch["event"], batch["mask"]
        encoded = self.encoder(events, mask=padding_mask)
        mean_feats = self.pool_mean(encoded, padding_mask)
        max_feats = self.pool_max(encoded, padding_mask)
        # Feature-wise concatenation: mean features first, then max features.
        return self.head(torch.cat([mean_feats, max_feats], dim=1))


class always:
    """Callable that ignores its arguments and returns a fixed value."""

    def __init__(self, val):
        # The value handed back on every call.
        self.val = val

    def __call__(self, *_args, **_kwargs):
        return self.val


def l2norm(t, groups=1):
    """L2-normalise the last dimension of ``t`` independently within ``groups`` chunks.

    The last axis is split into ``groups`` equal parts, each part is scaled
    to unit L2 norm, and the parts are merged back into the original layout.
    """
    grouped = rearrange(t, "... (g d) -> ... g d", g=groups)
    normalized = F.normalize(grouped, p=2, dim=-1)
    return rearrange(normalized, "... g d -> ... (g d)")


class TokenEmbedding(nn.Module):
    """Embedding lookup with index 0 reserved as padding.

    Args:
        dim: Embedding size.
        num_tokens: Number of rows in the embedding table.
        l2norm_embed: When true, looked-up vectors are L2-normalised.
    """

    def __init__(self, dim, num_tokens, l2norm_embed=False):
        super().__init__()
        self.l2norm_embed = l2norm_embed
        self.emb = nn.Embedding(num_tokens, dim, padding_idx=0)

    def forward(self, x):
        """Return the embeddings for the integer indices in ``x``."""
        token_emb = self.emb(x)
        return l2norm(token_emb) if self.l2norm_embed else token_emb

    def init_(self):
        """Kaiming-normal initialise the table while keeping the padding row at zero."""
        nn.init.kaiming_normal_(self.emb.weight)
        # kaiming_normal_ rewrites every row, including the padding row that
        # nn.Embedding had zeroed. That row receives no gradient, so it would
        # otherwise stay frozen at random noise; restore the zero vector.
        if self.emb.padding_idx is not None:
            with torch.no_grad():
                self.emb.weight[self.emb.padding_idx].zero_()


class IceCubeModelEncoderSensorEmbeddinng(nn.Module):
    """Transformer encoder fed with event features plus a learned sensor embedding.

    Args:
        dim: Size of the sensor-id embedding.
        in_features: Number of per-position event features; the encoder input
            width is ``in_features + dim``.
    """

    def __init__(self, dim=128, in_features=14):
        super().__init__()
        self.token_emb = TokenEmbedding(dim, num_tokens=5161)
        self.post_norma = nn.LayerNorm(dim)
        self.token_emb.init_()
        attn_layers = Encoder(dim=256, depth=6, heads=8)
        self.encoder = ContinuousTransformerWrapper(
            dim_in=in_features + dim,
            dim_out=256,
            max_seq_len=150,
            attn_layers=attn_layers,
        )
        self.pool = MeanPoolingWithMask()
        self.head = FeedForward(256, 2)

    def forward(self, batch):
        """Run the model on a dict with ``"event"``, ``"mask"`` and ``"sensor_id"``."""
        events, padding_mask, sensor_id = batch["event"], batch["mask"], batch["sensor_id"]
        # Layer-normalised sensor embedding, appended to the event features.
        sensor_vec = self.post_norma(self.token_emb(sensor_id))
        tokens = torch.cat([events, sensor_vec], dim=-1)
        encoded = self.encoder(tokens, mask=padding_mask)
        pooled = self.pool(encoded, padding_mask)
        return self.head(pooled)


class IceCubeModelEncoderSensorEmbeddinngV1(nn.Module):
    """Sensor-embedding transformer encoder with a 6-feature default input.

    Args:
        dim: Size of the sensor-id embedding.
        in_features: Number of per-position event features; the encoder input
            width is ``in_features + dim``.
    """

    def __init__(self, dim=128, in_features=6):
        super().__init__()
        self.token_emb = TokenEmbedding(dim, num_tokens=5161)
        self.post_norma = nn.LayerNorm(dim)
        self.token_emb.init_()
        attn_layers = Encoder(dim=256, depth=6, heads=8)
        self.encoder = ContinuousTransformerWrapper(
            dim_in=in_features + dim,
            dim_out=256,
            max_seq_len=150,
            attn_layers=attn_layers,
        )
        self.pool = MeanPoolingWithMask()
        self.head = FeedForward(256, 2)

    def forward(self, batch):
        """Run the model on a dict with ``"event"``, ``"mask"`` and ``"sensor_id"``."""
        events, padding_mask, sensor_id = batch["event"], batch["mask"], batch["sensor_id"]
        # Layer-normalised sensor embedding, appended to the event features.
        sensor_vec = self.post_norma(self.token_emb(sensor_id))
        tokens = torch.cat([events, sensor_vec], dim=-1)
        encoded = self.encoder(tokens, mask=padding_mask)
        pooled = self.pool(encoded, padding_mask)
        return self.head(pooled)


class TokenEmbeddingV2(nn.Module):
    """Embedding lookup without a reserved padding index.

    Args:
        dim: Embedding size.
        num_tokens: Number of rows in the embedding table.
        l2norm_embed: When true, looked-up vectors are L2-normalised.
    """

    def __init__(self, dim, num_tokens, l2norm_embed=False):
        super().__init__()
        self.l2norm_embed = l2norm_embed
        self.emb = nn.Embedding(num_tokens, dim)

    def forward(self, x):
        """Return the embeddings for the integer indices in ``x``."""
        vectors = self.emb(x)
        if self.l2norm_embed:
            return l2norm(vectors)
        return vectors

    def init_(self):
        """Re-initialise the embedding table with Kaiming-normal values."""
        nn.init.kaiming_normal_(self.emb.weight)


class IceCubeModelEncoderSensorEmbeddinngV2(nn.Module):
    """Sensor-embedding transformer encoder using ``TokenEmbeddingV2`` and ``max_seq_len=196``.

    Args:
        dim: Size of the sensor-id embedding.
        in_features: Number of per-position event features; the encoder input
            width is ``in_features + dim``.
    """

    def __init__(self, dim=128, in_features=6):
        super().__init__()
        self.token_emb = TokenEmbeddingV2(dim, num_tokens=5161)
        # Created but not applied in forward(); kept so that the module's
        # parameter names stay as they are.
        self.post_norma = nn.LayerNorm(dim)
        self.token_emb.init_()
        attn_layers = Encoder(dim=256, depth=6, heads=8)
        self.encoder = ContinuousTransformerWrapper(
            dim_in=in_features + dim,
            dim_out=256,
            max_seq_len=196,
            attn_layers=attn_layers,
        )
        self.pool = MeanPoolingWithMask()
        self.head = FeedForward(256, 2)

    def forward(self, batch):
        """Run the model on a dict with ``"event"``, ``"mask"`` and ``"sensor_id"``."""
        events, padding_mask, sensor_id = batch["event"], batch["mask"], batch["sensor_id"]
        # Raw (un-normalised) sensor embedding appended to the event features.
        tokens = torch.cat([events, self.token_emb(sensor_id)], dim=-1)
        encoded = self.encoder(tokens, mask=padding_mask)
        pooled = self.pool(encoded, padding_mask)
        return self.head(pooled)


class IceCubeModelEncoderSensorEmbeddinngV3(nn.Module):
    """Sensor-embedding transformer encoder whose two outputs are range-limited.

    Args:
        dim: Size of the sensor-id embedding.
        in_features: Number of per-position event features; the encoder input
            width is ``in_features + dim``.
    """

    def __init__(self, dim=128, in_features=6):
        super().__init__()
        self.token_emb = TokenEmbeddingV2(dim, num_tokens=5161)
        # Created but not applied in forward(); kept so that the module's
        # parameter names stay as they are.
        self.post_norma = nn.LayerNorm(dim)
        self.token_emb.init_()
        self.encoder = ContinuousTransformerWrapper(
            dim_in=in_features + dim,
            dim_out=256,
            max_seq_len=150,
            attn_layers=Encoder(dim=256, depth=6, heads=8),
        )

        self.pool = MeanPoolingWithMask()
        self.head = FeedForward(256, 2)
        self.sigmout = Adjustoutput()

    def forward(self, batch):
        """Run the model on a dict with ``"event"``, ``"mask"`` and ``"sensor_id"``.

        Returns the head output after ``self.sigmout`` has squashed its two
        columns into their fixed ranges.
        """
        x, mask, sensor_id = batch["event"], batch["mask"], batch["sensor_id"]
        embed = self.token_emb(sensor_id)
        x = torch.cat([x, embed], dim=-1)
        x = self.encoder(x, mask=mask)
        x = self.pool(x, mask)
        x = self.head(x)
        # Return the adjusted tensor explicitly. The previous code discarded
        # this result and returned `x`, which only carried the adjustment
        # because Adjustoutput writes into its argument in place.
        return self.sigmout(x)


In [132]:
# Smoke test: push one random 100-step, 8-feature event through the model and
# inspect the output shape.
model = IceCubeModelEncoderV1CombinePool().eval()
event = torch.rand(1, 100, 8)
mask = torch.ones(1, 100, dtype=torch.bool)
sensor_id = torch.randint(0, 5161, (1, 100))
# Named `batch` rather than `input` so the builtin `input` is not shadowed.
batch = dict(event=event, mask=mask, sensor_id=sensor_id)
with torch.no_grad():
    y = model(batch)

y.shape

torch.Size([1, 2])

In [133]:
# | export


class IceCubeModelEncoderV2(nn.Module):
    """Transformer encoder with masked mean pooling and a direction-with-kappa head."""

    def __init__(self):
        super().__init__()
        self.encoder = ContinuousTransformerWrapper(
            dim_in=8,
            dim_out=128,
            max_seq_len=150,
            attn_layers=Encoder(dim=128, depth=6, heads=8),
        )

        self.pool = MeanPoolingWithMask()
        # Configure the task head the same way the other direction-
        # reconstruction models in this notebook do.
        # NOTE(review): the bare `DirectionReconstructionWithKappa(128)` call
        # presumably fails on graphnet versions where `target_labels` and
        # `loss_function` are required -- confirm against the installed version.
        self.out = DirectionReconstructionWithKappa(
            hidden_size=128,
            target_labels='direction',
            loss_function=VonMisesFisher3DLoss(),
        )

    def forward(self, batch):
        """Run the model on a batch dict with ``"event"`` and ``"mask"`` entries."""
        x, mask = batch["event"], batch["mask"]
        x = self.encoder(x, mask=mask)
        x = self.pool(x, mask)
        return self.out(x)


class EncoderWithDirectionReconstruction(nn.Module):
    """8-feature transformer encoder, mean+max pooling, direction-with-kappa head."""

    def __init__(self):
        super().__init__()
        attn_layers = Encoder(dim=128, depth=6, heads=8)
        self.encoder = ContinuousTransformerWrapper(
            dim_in=8, dim_out=128, max_seq_len=150, attn_layers=attn_layers
        )
        self.pool_mean = PoolingWithMask("mean")
        self.pool_max = PoolingWithMask("max")
        # 256 = 128 mean-pooled + 128 max-pooled features.
        self.out = DirectionReconstructionWithKappa(
            hidden_size=256,
            target_labels='direction',
            loss_function=VonMisesFisher3DLoss(),
        )

    def forward(self, batch):
        """Run the model on a batch dict with ``"event"`` and ``"mask"`` entries."""
        events, padding_mask = batch["event"], batch["mask"]
        encoded = self.encoder(events, mask=padding_mask)
        mean_feats = self.pool_mean(encoded, padding_mask)
        max_feats = self.pool_max(encoded, padding_mask)
        return self.out(torch.cat([mean_feats, max_feats], dim=1))


class EncoderWithDirectionReconstructionV1(nn.Module):
    """9-feature transformer encoder with mean+max+min pooling and an extra MLP.

    The three pooled vectors (3 x 128 = 384 features) pass through a
    ``FeedForward(384, 384)`` block before the direction-with-kappa head.
    """

    def __init__(self):
        super().__init__()
        attn_layers = Encoder(dim=128, depth=6, heads=8)
        self.encoder = ContinuousTransformerWrapper(
            dim_in=9, dim_out=128, max_seq_len=150, attn_layers=attn_layers
        )
        self.pool_mean = PoolingWithMask("mean")
        self.pool_max = PoolingWithMask("max")
        self.pool_min = PoolingWithMask("min")
        self.ae = FeedForward(384, 384)
        self.out = DirectionReconstructionWithKappa(
            hidden_size=384,
            target_labels='direction',
            loss_function=VonMisesFisher3DLoss(),
        )

    def forward(self, batch):
        """Run the model on a batch dict with ``"event"`` and ``"mask"`` entries."""
        events, padding_mask = batch["event"], batch["mask"]
        encoded = self.encoder(events, mask=padding_mask)
        # Concatenation order is mean, max, min.
        pooled = torch.cat(
            [
                self.pool_mean(encoded, padding_mask),
                self.pool_max(encoded, padding_mask),
                self.pool_min(encoded, padding_mask),
            ],
            dim=1,
        )
        return self.out(self.ae(pooled))
    

class EncoderWithDirectionReconstructionV2(nn.Module):
    """9-feature transformer encoder, mean+max pooling, direction-with-kappa head."""

    def __init__(self):
        super().__init__()
        attn_layers = Encoder(dim=128, depth=6, heads=8)
        self.encoder = ContinuousTransformerWrapper(
            dim_in=9, dim_out=128, max_seq_len=150, attn_layers=attn_layers
        )
        self.pool_mean = PoolingWithMask("mean")
        self.pool_max = PoolingWithMask("max")
        # 256 = 128 mean-pooled + 128 max-pooled features.
        self.out = DirectionReconstructionWithKappa(
            hidden_size=256,
            target_labels='direction',
            loss_function=VonMisesFisher3DLoss(),
        )

    def forward(self, batch):
        """Run the model on a batch dict with ``"event"`` and ``"mask"`` entries."""
        events, padding_mask = batch["event"], batch["mask"]
        encoded = self.encoder(events, mask=padding_mask)
        mean_feats = self.pool_mean(encoded, padding_mask)
        max_feats = self.pool_max(encoded, padding_mask)
        return self.out(torch.cat([mean_feats, max_feats], dim=1))



In [134]:
# Instantiate graphnet's `Direction` label class to confirm it imports and constructs.
Direction()

<graphnet.training.labels.Direction at 0x7f9312604650>

In [135]:
# Smoke test: push one random 100-step, 9-feature event through the model and
# inspect the output shape. `y` and `label` are reused by the loss check below.
model = EncoderWithDirectionReconstructionV1().eval()
event = torch.rand(1, 100, 9)
mask = torch.ones(1, 100, dtype=torch.bool)
sensor_id = torch.randint(0, 5161, (1, 100))
label = torch.rand(1, 3)
# Named `batch` rather than `input` so the builtin `input` is not shadowed.
batch = dict(event=event, mask=mask, sensor_id=sensor_id, label=label)
with torch.no_grad():
    y = model(batch)

y.shape

torch.Size([1, 4])

In [136]:
# Sanity check: mean row-wise cosine similarity between two random (10, 4) tensors.
nn.CosineSimilarity(dim=1, eps=1e-6)(torch.rand(10, 4), torch.rand(10, 4)).mean()

tensor(0.8334)

In [137]:
# Evaluate the combined loss on `y` and `label` defined in the smoke-test cell above.
VonMisesFisher3DLossCosineSimularityLoss()(y, label)

tensor(1.8100)

### Graph Transformers

In [138]:
# | export
# MOLECULAR TRANSFORMER
# Registry of distance kernels. For each kernel name:
#   "fn"            turns a distance tensor into attention-like weights;
#   "mask_value_fn" gives the fill value that makes a masked distance
#                   contribute (close to) nothing under that kernel, using the
#                   dtype of the tensor it is handed.
DIST_KERNELS = dict(
    exp=dict(
        fn=lambda dist: torch.exp(-dist),
        mask_value_fn=lambda ref: torch.finfo(ref.dtype).max,
    ),
    softmax=dict(
        fn=lambda dist: torch.softmax(dist, dim=-1),
        mask_value_fn=lambda ref: -torch.finfo(ref.dtype).max,
    ),
)

# helpers


def exists(val):
    """Return True unless ``val`` is ``None``."""
    return not (val is None)


def default(val, d):
    """Return ``val`` when it is not ``None``, otherwise the fallback ``d``."""
    if exists(val):
        return val
    return d


# helper classes


class Residual(nn.Module):
    """Wrap a callable so its output is added to its input (skip connection)."""

    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x, **kwargs):
        """Return ``x + fn(x, **kwargs)``."""
        branch = self.fn(x, **kwargs)
        return x + branch


class PreNorm(nn.Module):
    """Apply ``LayerNorm(dim)`` to the input before handing it to ``fn``."""

    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn

    def forward(self, x, **kwargs):
        """Return ``fn(norm(x), **kwargs)``."""
        normed = self.norm(x)
        return self.fn(normed, **kwargs)


class FeedForwardV1(nn.Module):
    """Two-layer GELU MLP whose output width defaults to its input width.

    Args:
        dim: Input feature size.
        dim_out: Output feature size; ``None`` means "same as ``dim``".
        mult: Width multiplier of the hidden layer.
    """

    def __init__(self, dim, dim_out=None, mult=4):
        super().__init__()
        out_features = default(dim_out, dim)
        hidden = dim * mult
        self.net = nn.Sequential(
            nn.Linear(dim, hidden),
            nn.GELU(),
            nn.Linear(hidden, out_features),
        )

    def forward(self, x):
        """Apply the MLP to the last dimension of ``x``."""
        return self.net(x)


class Attention(nn.Module):
    """Multi-head self-attention blended with adjacency and distance matrices.

    The final attention weights are
    ``La * softmax(q k^T * scale) + Lg * adjacency + Ld * kernel(distance)``,
    where the last two terms are only added when the matrices are supplied.

    Args:
        dim: Input and output feature size.
        heads: Number of attention heads.
        dim_head: Feature size per head.
        Lg: Weight of the adjacency-matrix term.
        Ld: Weight of the distance-matrix term.
        La: Weight of the ordinary self-attention term.
        dist_kernel_fn: Key into ``DIST_KERNELS`` (``"exp"`` or ``"softmax"``).
    """

    def __init__(
        self, dim, heads=8, dim_head=64, Lg=0.5, Ld=0.5, La=1, dist_kernel_fn="exp"
    ):
        super().__init__()
        inner_dim = dim_head * heads
        self.heads = heads
        self.scale = dim_head**-0.5
        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)
        self.to_out = nn.Linear(inner_dim, dim)

        # hyperparameters controlling the weighted linear combination from
        # self-attention (La)
        # adjacency graph (Lg)
        # pair-wise distance matrix (Ld)

        self.La = La
        self.Ld = Ld
        self.Lg = Lg

        self.dist_kernel_fn = dist_kernel_fn

    def forward(self, x, mask=None, adjacency_mat=None, distance_mat=None):
        """Apply the blended attention.

        Args:
            x: Tensor indexed as (batch, n, dim).
            mask: Optional boolean tensor indexed as (batch, n); a pair (i, j)
                is kept only when both positions are True.
            adjacency_mat: Optional tensor indexed as (batch, n, n).
            distance_mat: Optional tensor indexed as (batch, n, n).

        Returns:
            Tensor indexed as (batch, n, dim).

        The caller's ``adjacency_mat`` and ``distance_mat`` are left untouched.
        """
        h, La, Ld, Lg, dist_kernel_fn = (
            self.heads,
            self.La,
            self.Ld,
            self.Lg,
            self.dist_kernel_fn,
        )

        qkv = self.to_qkv(x)
        q, k, v = rearrange(qkv, "b n (h qkv d) -> b h n qkv d", h=h, qkv=3).unbind(
            dim=-2
        )
        dots = einsum("b h i d, b h j d -> b h i j", q, k) * self.scale

        assert (
            dist_kernel_fn in DIST_KERNELS
        ), f"distance kernel function needs to be one of {DIST_KERNELS.keys()}"
        dist_kernel_config = DIST_KERNELS[dist_kernel_fn]

        # Insert a singleton head axis so the matrices broadcast over heads.
        if exists(distance_mat):
            distance_mat = rearrange(distance_mat, "b i j -> b () i j")

        if exists(adjacency_mat):
            adjacency_mat = rearrange(adjacency_mat, "b i j -> b () i j")

        if exists(mask):
            mask_value = torch.finfo(dots.dtype).max
            # Pairwise mask: True only where both query and key are valid.
            mask = mask[:, None, :, None] * mask[:, None, None, :]

            # mask attention (dots is a fresh tensor, so in-place is safe)
            dots.masked_fill_(~mask, -mask_value)

            # The reshaped matrices may still share storage with the tensors
            # the caller passed in, so fill out-of-place: an in-place fill
            # would overwrite the caller's batch data.
            if exists(distance_mat):
                # mask distance so the kernel maps it to (almost) zero weight
                dist_mask_value = dist_kernel_config["mask_value_fn"](dots)
                distance_mat = distance_mat.masked_fill(~mask, dist_mask_value)

            if exists(adjacency_mat):
                adjacency_mat = adjacency_mat.masked_fill(~mask, 0.0)

        attn = dots.softmax(dim=-1)

        # sum contributions from adjacency and distance tensors
        attn = attn * La

        if exists(adjacency_mat):
            attn = attn + Lg * adjacency_mat

        if exists(distance_mat):
            distance_mat = dist_kernel_config["fn"](distance_mat)
            attn = attn + Ld * distance_mat

        out = einsum("b h i j, b h j d -> b h i d", attn, v)
        out = rearrange(out, "b h n d -> b n (h d)")
        return self.to_out(out)


# main class


class MAT(nn.Module):
    """Stack of adjacency/distance-aware attention blocks with unmasked mean pooling.

    Args:
        dim_in: Input feature size per position.
        model_dim: Hidden size used throughout the stack.
        dim_out: Size of the final output vector.
        depth: Number of (attention, feed-forward) layer pairs.
        heads: Number of attention heads.
        Lg: Weight of the adjacency-matrix term.
        Ld: Weight of the distance-matrix term.
        La: Weight of the self-attention term.
        dist_kernel_fn: Key into ``DIST_KERNELS``.
    """

    def __init__(
        self,
        *,
        dim_in,
        model_dim,
        dim_out,
        depth,
        heads=8,
        Lg=0.5,
        Ld=0.5,
        La=1,
        dist_kernel_fn="exp",
    ):
        super().__init__()

        self.embed_to_model = nn.Linear(dim_in, model_dim)

        def build_layer():
            # Pre-norm residual attention block followed by a pre-norm
            # residual feed-forward block.
            attn_block = Residual(
                PreNorm(
                    model_dim,
                    Attention(
                        model_dim,
                        heads=heads,
                        Lg=Lg,
                        Ld=Ld,
                        La=La,
                        dist_kernel_fn=dist_kernel_fn,
                    ),
                )
            )
            ff_block = Residual(PreNorm(model_dim, FeedForwardV1(model_dim)))
            return nn.ModuleList([attn_block, ff_block])

        self.layers = nn.ModuleList([build_layer() for _ in range(depth)])

        self.norm_out = nn.LayerNorm(model_dim)
        self.ff_out = FeedForward(model_dim, dim_out)

    def forward(self, batch):
        """Run the stack on a dict with ``"event"``, ``"mask"``,
        ``"adjecent_matrix"`` and ``"distance_matrix"`` entries.

        The sequence axis is averaged without regard to the mask.
        """
        events = batch["event"]
        padding_mask = batch["mask"]
        adjacency = batch["adjecent_matrix"]
        distance = batch["distance_matrix"]

        hidden = self.embed_to_model(events)

        for attn_block, ff_block in self.layers:
            hidden = attn_block(
                hidden,
                mask=padding_mask,
                adjacency_mat=adjacency,
                distance_mat=distance,
            )
            hidden = ff_block(hidden)

        pooled = self.norm_out(hidden).mean(dim=-2)
        return self.ff_out(pooled)


class MATMaskedPool(nn.Module):
    """Same layer stack as ``MAT`` but pooled with ``MeanPoolingWithMask``.

    Args:
        dim_in: Input feature size per position.
        model_dim: Hidden size used throughout the stack.
        dim_out: Size of the final output vector.
        depth: Number of (attention, feed-forward) layer pairs.
        heads: Number of attention heads.
        Lg: Weight of the adjacency-matrix term.
        Ld: Weight of the distance-matrix term.
        La: Weight of the self-attention term.
        dist_kernel_fn: Key into ``DIST_KERNELS``.
    """

    def __init__(
        self,
        *,
        dim_in,
        model_dim,
        dim_out,
        depth,
        heads=8,
        Lg=0.5,
        Ld=0.5,
        La=1,
        dist_kernel_fn="exp",
    ):
        super().__init__()

        self.embed_to_model = nn.Linear(dim_in, model_dim)

        def build_layer():
            # Pre-norm residual attention block followed by a pre-norm
            # residual feed-forward block.
            attn_block = Residual(
                PreNorm(
                    model_dim,
                    Attention(
                        model_dim,
                        heads=heads,
                        Lg=Lg,
                        Ld=Ld,
                        La=La,
                        dist_kernel_fn=dist_kernel_fn,
                    ),
                )
            )
            ff_block = Residual(PreNorm(model_dim, FeedForwardV1(model_dim)))
            return nn.ModuleList([attn_block, ff_block])

        self.layers = nn.ModuleList([build_layer() for _ in range(depth)])

        self.norm_out = nn.LayerNorm(model_dim)
        self.pool = MeanPoolingWithMask()
        self.ff_out = FeedForward(model_dim, dim_out)

    def forward(self, batch):
        """Run the stack on a dict with ``"event"``, ``"mask"``,
        ``"adjecent_matrix"`` and ``"distance_matrix"`` entries.

        Only positions kept by the mask contribute to the pooled vector.
        """
        events = batch["event"]
        padding_mask = batch["mask"]
        adjacency = batch["adjecent_matrix"]
        distance = batch["distance_matrix"]

        hidden = self.embed_to_model(events)

        for attn_block, ff_block in self.layers:
            hidden = attn_block(
                hidden,
                mask=padding_mask,
                adjacency_mat=adjacency,
                distance_mat=distance,
            )
            hidden = ff_block(hidden)

        pooled = self.pool(self.norm_out(hidden), mask=padding_mask)
        return self.ff_out(pooled)


class IceCubeModelEncoderMAT(nn.Module):
    """Thin wrapper that builds a 6-layer ``MAT`` with two outputs.

    Args:
        dim: Hidden size of the ``MAT`` stack (``model_dim``).
        in_features: Number of per-position input features (``dim_in``).

    Both arguments used to be accepted but ignored in favour of hard-coded
    128 / 6; they are now forwarded. The defaults equal the old constants, so
    ``IceCubeModelEncoderMAT()`` builds the same network as before.
    """

    def __init__(self, dim=128, in_features=6):
        super().__init__()
        self.md = MAT(
            dim_in=in_features,
            model_dim=dim,
            dim_out=2,
            depth=6,
            Lg=0.5,  # lambda (g)raph - weight for adjacency matrix
            Ld=0.5,  # lambda (d)istance - weight for distance matrix
            La=1,  # lambda (a)ttention - weight for usual self-attention
            dist_kernel_fn="exp",  # distance kernel fn - either 'exp' or 'softmax'
        )

    def forward(self, batch):
        """Delegate to the wrapped ``MAT``; see its ``forward`` for the batch keys."""
        return self.md(batch)


class IceCubeModelEncoderMATMasked(nn.Module):
    """Thin wrapper that builds a 6-layer ``MATMaskedPool`` with two outputs.

    Args:
        dim: Hidden size of the stack (``model_dim``).
        in_features: Number of per-position input features (``dim_in``).

    Both arguments used to be accepted but ignored in favour of hard-coded
    128 / 6; they are now forwarded. The defaults equal the old constants, so
    ``IceCubeModelEncoderMATMasked()`` builds the same network as before.
    """

    def __init__(self, dim=128, in_features=6):
        super().__init__()
        self.md = MATMaskedPool(
            dim_in=in_features,
            model_dim=dim,
            dim_out=2,
            depth=6,
            Lg=0.5,  # lambda (g)raph - weight for adjacency matrix
            Ld=0.5,  # lambda (d)istance - weight for distance matrix
            La=1,  # lambda (a)ttention - weight for usual self-attention
            dist_kernel_fn="exp",  # distance kernel fn - either 'exp' or 'softmax'
        )

    def forward(self, batch):
        """Delegate to the wrapped ``MATMaskedPool``; see its ``forward`` for the batch keys."""
        return self.md(batch)



In [139]:
# Smoke test: run the masked MAT model on a random batch of two 100-step events
# and inspect the output shape.
model = IceCubeModelEncoderMATMasked()
event = torch.randn(2, 100, 6)
mask = torch.ones(2, 100).bool()
# Random 0/1 adjacency matrix, cast to float.
adjecent_matrix = torch.empty(2, 100, 100).random_(2).float()
distance_matrix = torch.randn(2, 100, 100)
batch = {
    "event": event,
    "mask": mask,
    "adjecent_matrix": adjecent_matrix,
    "distance_matrix": distance_matrix,
}
out = model(batch)
out.shape


torch.Size([2, 2])

In [140]:
import torch
# Show the installed PyTorch version the cells above were run with.
torch.__version__

'1.11.0+cu115'

In [141]:
#|hide
#|eval: false
from nbdev.doclinks import nbdev_export
# Presumably writes the cells tagged `#| export` to the module named by
# `#|default_exp` at the top of the notebook -- see the nbdev documentation.
nbdev_export()

In [147]:
import torch
from torch.nn import Linear, ReLU, SiLU, Sequential
from torch_geometric.nn import MessagePassing, global_add_pool, global_mean_pool
from torch_scatter import scatter


class EGNNLayer(MessagePassing):
    """Message-passing layer that updates node features and node coordinates."""

    def __init__(self, emb_dim, activation="relu", norm="layer", aggr="add"):
        # Raw docstring: `\oplus` would otherwise be an invalid escape sequence.
        r"""E(n) Equivariant GNN Layer
        Paper: E(n) Equivariant Graph Neural Networks, Satorras et al.
        
        Args:
            emb_dim: (int) - hidden dimension `d`
            activation: (str) - non-linearity within MLPs (swish/relu)
            norm: (str) - normalisation layer (layer/batch)
            aggr: (str) - aggregation function `\oplus` (sum/mean/max)

        Raises:
            KeyError: if `activation` or `norm` is not one of the listed names.
        """
        # Set the aggregation function
        super().__init__(aggr=aggr)

        self.emb_dim = emb_dim
        self.activation = {"swish": SiLU(), "relu": ReLU()}[activation]
        # Stored as a class (not an instance); instantiated per layer below.
        self.norm = {"layer": torch.nn.LayerNorm, "batch": torch.nn.BatchNorm1d}[norm]

        # MLP `\psi_h` for computing messages `m_ij`
        # (input is [h_i, h_j, distance], hence 2 * emb_dim + 1 features)
        self.mlp_msg = Sequential(
            Linear(2 * emb_dim + 1, emb_dim),
            self.norm(emb_dim),
            self.activation,
            Linear(emb_dim, emb_dim),
            self.norm(emb_dim),
            self.activation,
        )
        # MLP `\psi_x` for computing messages `\overrightarrow{m}_ij`
        # (one scalar per edge that scales the displacement vector)
        self.mlp_pos = Sequential(
            Linear(emb_dim, emb_dim), self.norm(emb_dim), self.activation, Linear(emb_dim, 1)
        )
        # MLP `\phi` for computing updated node features `h_i^{l+1}`
        self.mlp_upd = Sequential(
            Linear(2 * emb_dim, emb_dim),
            self.norm(emb_dim),
            self.activation,
            Linear(emb_dim, emb_dim),
            self.norm(emb_dim),
            self.activation,
        )

    def forward(self, h, pos, edge_index):
        """
        Args:
            h: (n, d) - initial node features
            pos: (n, 3) - initial node coordinates
            edge_index: pairs of edges (i, j), passed straight to `propagate`
                (NOTE(review): PyG conventionally expects shape (2, e) -- confirm)
        Returns:
            out: [(n, d),(n,3)] - updated node features and coordinates
        """
        out = self.propagate(edge_index, h=h, pos=pos)
        return out

    def message(self, h_i, h_j, pos_i, pos_j):
        """Return per-edge (feature message, scaled displacement vector)."""
        # Compute messages
        pos_diff = pos_i - pos_j
        dists = torch.norm(pos_diff, dim=-1).unsqueeze(1)
        msg = torch.cat([h_i, h_j, dists], dim=-1)
        msg = self.mlp_msg(msg)
        # Scale magnitude of displacement vector
        pos_diff = pos_diff * self.mlp_pos(msg)  # torch.clamp(updates, min=-100, max=100)
        return msg, pos_diff

    def aggregate(self, inputs, index):
        """Reduce edge outputs per target node: features with `self.aggr`, positions with mean."""
        msgs, pos_diffs = inputs
        # Aggregate messages
        msg_aggr = scatter(msgs, index, dim=self.node_dim, reduce=self.aggr)
        # Aggregate displacement vectors
        pos_aggr = scatter(pos_diffs, index, dim=self.node_dim, reduce="mean")
        return msg_aggr, pos_aggr

    def update(self, aggr_out, h, pos):
        """Combine aggregated messages with the current node state."""
        msg_aggr, pos_aggr = aggr_out
        upd_out = self.mlp_upd(torch.cat([h, msg_aggr], dim=-1))
        upd_pos = pos + pos_aggr
        return upd_out, upd_pos

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(emb_dim={self.emb_dim}, aggr={self.aggr})"


class MPNNLayer(MessagePassing):
    """Message-passing layer that updates node features only."""

    def __init__(self, emb_dim, activation="relu", norm="layer", aggr="add"):
        # Raw docstring: `\oplus` would otherwise be an invalid escape sequence.
        r"""Vanilla Message Passing GNN layer
        
        Args:
            emb_dim: (int) - hidden dimension `d`
            activation: (str) - non-linearity within MLPs (swish/relu)
            norm: (str) - normalisation layer (layer/batch)
            aggr: (str) - aggregation function `\oplus` (sum/mean/max)

        Raises:
            KeyError: if `activation` or `norm` is not one of the listed names.
        """
        # Set the aggregation function
        super().__init__(aggr=aggr)

        self.emb_dim = emb_dim
        self.activation = {"swish": SiLU(), "relu": ReLU()}[activation]
        # Stored as a class (not an instance); instantiated per layer below.
        self.norm = {"layer": torch.nn.LayerNorm, "batch": torch.nn.BatchNorm1d}[norm]

        # MLP `\psi_h` for computing messages `m_ij`
        self.mlp_msg = Sequential(
            Linear(2 * emb_dim, emb_dim),
            self.norm(emb_dim),
            self.activation,
            Linear(emb_dim, emb_dim),
            self.norm(emb_dim),
            self.activation,
        )
        # MLP `\phi` for computing updated node features `h_i^{l+1}`
        self.mlp_upd = Sequential(
            Linear(2 * emb_dim, emb_dim),
            self.norm(emb_dim),
            self.activation,
            Linear(emb_dim, emb_dim),
            self.norm(emb_dim),
            self.activation,
        )

    def forward(self, h, edge_index):
        """
        Args:
            h: (n, d) - initial node features
            edge_index: pairs of edges (i, j), passed straight to `propagate`
                (NOTE(review): PyG conventionally expects shape (2, e) -- confirm)
        Returns:
            out: (n, d) - updated node features
        """
        out = self.propagate(edge_index, h=h)
        return out

    def message(self, h_i, h_j):
        """Return the per-edge message computed from both endpoint features."""
        # Compute messages
        msg = torch.cat([h_i, h_j], dim=-1)
        msg = self.mlp_msg(msg)
        return msg

    def aggregate(self, inputs, index):
        """Reduce edge messages per target node with `self.aggr`."""
        # Aggregate messages
        msg_aggr = scatter(inputs, index, dim=self.node_dim, reduce=self.aggr)
        return msg_aggr

    def update(self, aggr_out, h):
        """Combine aggregated messages with the current node features."""
        upd_out = self.mlp_upd(torch.cat([h, aggr_out], dim=-1))
        return upd_out

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(emb_dim={self.emb_dim}, aggr={self.aggr})"

In [143]:
class EGNNModel(torch.nn.Module):
    """Stack of ``EGNNLayer`` blocks with graph-level pooling and an MLP predictor."""

    def __init__(
        self,
        num_layers=5,
        emb_dim=128,
        in_dim=1,
        out_dim=1,
        activation="relu",
        norm="layer",
        aggr="sum",
        pool="sum",
        residual=True
    ):
        # Raw docstring: `\oplus` would otherwise be an invalid escape sequence.
        r"""E(n) Equivariant GNN model 
        
        Args:
            num_layers: (int) - number of message passing layers
            emb_dim: (int) - hidden dimension
            in_dim: (int) - initial node feature dimension
            out_dim: (int) - output number of classes
            activation: (str) - non-linearity within MLPs (swish/relu)
            norm: (str) - normalisation layer (layer/batch)
            aggr: (str) - aggregation function `\oplus` (sum/mean/max)
            pool: (str) - global pooling function (sum/mean)
            residual: (bool) - whether to use residual connections

        Raises:
            KeyError: if `pool` is not "mean" or "sum".
        """
        super().__init__()

        # Embedding lookup for initial node features
        self.emb_in = torch.nn.Embedding(in_dim, emb_dim)

        # Stack of GNN layers
        self.convs = torch.nn.ModuleList()
        for _ in range(num_layers):
            self.convs.append(EGNNLayer(emb_dim, activation, norm, aggr))

        # Global pooling/readout function
        self.pool = {"mean": global_mean_pool, "sum": global_add_pool}[pool]

        # Predictor MLP
        self.pred = torch.nn.Sequential(
            torch.nn.Linear(emb_dim, emb_dim),
            torch.nn.ReLU(),
            torch.nn.Linear(emb_dim, out_dim)
        )
        self.residual = residual

    def forward(self, batch):
        """Predict one `out_dim` vector per graph.

        Reads `batch.atoms`, `batch.pos`, `batch.edge_index` and `batch.batch`.
        """
        h = self.emb_in(batch.atoms)  # (n,) -> (n, d)
        pos = batch.pos  # (n, 3)

        for conv in self.convs:
            # Message passing layer
            h_update, pos_update = conv(h, pos, batch.edge_index)

            # Update node features (n, d) -> (n, d)
            h = h + h_update if self.residual else h_update 

            # Update node coordinates (no residual) (n, 3) -> (n, 3)
            pos = pos_update

        out = self.pool(h, batch.batch)  # (n, d) -> (batch_size, d)
        return self.pred(out)  # (batch_size, out_dim)

In [144]:
#feats = torch.randint(0, 21, (1, 200)) # (1, 200)
# NOTE(review): `net` is not defined in any cell visible in this notebook, so
# this cell depends on hidden kernel state and fails on Restart & Run All.
feats = torch.rand(1, 200, 8)         # (1, 200, 8)
coors = torch.randn(1, 200, 4)         # (1, 200, 4)
mask = torch.ones_like(torch.rand(1, 200)).bool()    # (1, 200), all True

feats_out, coors_out = net(feats, coors, mask = mask) # output shapes depend on `net`, which is not shown here

torch.Size([1, 200, 128])


In [145]:
# Shape of the token embedding `net` produces for `feats` (`net` is defined outside the visible cells).
net.token_emb(feats).shape

torch.Size([1, 200, 128])

In [146]:
# Display the feature tensor returned by `net` in the cell above.
feats_out

tensor([[[ 0.6397,  0.0150, -0.5736,  ...,  0.4262,  0.0347, -0.2446],
         [-0.6280,  1.0807,  1.9626,  ..., -1.2121,  0.1134,  1.1187],
         [ 0.3124,  0.5278,  2.3593,  ...,  0.5552,  0.8637,  0.1939],
         ...,
         [-1.0462,  0.7790,  0.9071,  ...,  2.0370, -1.1702, -0.5642],
         [-0.0345,  0.5069, -0.1040,  ...,  1.0613, -0.1551, -0.6776],
         [-0.5054, -0.1815,  1.8672,  ..., -0.1932,  0.8704, -0.0397]]],
       grad_fn=<AddBackward0>)