In [168]:
#|default_exp models

In [169]:
#| export
import sys
sys.path.append('/opt/slh/archive/software/graphnet/src')
sys.path.append('/opt/slh/icecube/')
import torch
from x_transformers import ContinuousTransformerWrapper, Encoder, Decoder
from torch import nn, einsum
from einops import rearrange
import torch.nn.functional as F
from datasets import load_from_disk
from abc import abstractmethod
from torch import Tensor
from typing import Optional, Any
import scipy
import numpy as np
from graphnet.models.task.reconstruction import DirectionReconstructionWithKappa, AzimuthReconstructionWithKappa, ZenithReconstruction
from graphnet.training.loss_functions import VonMisesFisher3DLoss,  VonMisesFisher2DLoss, EuclideanDistanceLoss
from graphnet.training.labels import Direction
from icecube.modelsgraph import EGNNModeLFEAT, DynEdgeFEXTRACTRO
from torch_geometric.nn.pool import knn_graph
import torch.utils.checkpoint as checkpoint
from einops import repeat
from torch_geometric.utils import to_dense_adj


In [170]:
# | export

import math
class SinusoidalPosEmb(nn.Module):
    def __init__(self, dim=16, M=10000):
        super().__init__()
        self.dim = dim
        self.M = M

    def forward(self, x):
        device = x.device
        half_dim = self.dim // 2
        emb = math.log(self.M) / half_dim
        emb = torch.exp(torch.arange(half_dim, device=device) * (-emb))
        emb = x[...,None] * emb[None,...]
        emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
        return emb
    

class EuclideanDistanceLossG(torch.nn.Module):
    def __init__(self, eps=1e-6, reduction='mean'):
        super().__init__()
        self.eps = eps
        self.reduction = reduction
        
    def forward(self, prediction, target):
        diff = prediction - target
        loss = torch.norm(diff, dim=1) + self.eps
        if self.reduction == 'mean':
            loss = torch.mean(loss)
        elif self.reduction == 'sum':
            loss = torch.sum(loss)
        return loss
    
    
class VonMisesFisher3DLossCosineSimularityLoss(nn.Module):
    def __init__(self, eps=1e-8):
        super().__init__()
        self.vonmis = VonMisesFisher3DLoss()
        self.cosine = nn.CosineSimilarity(dim=1, eps=eps)
        
    def forward(self, y_pred, y_true):
        return (self.vonmis(y_pred, y_true) + (1-self.cosine(y_pred[:, :3], y_true).mean()))/2
    
    
class VonMisesFisher3DLossEcludeLoss(nn.Module):
    def __init__(self, eps=1e-8):
        super().__init__()
        self.vonmis = VonMisesFisher3DLoss()
        self.cosine = EuclideanDistanceLossG()
        
    def forward(self, y_pred, y_true):
        return (self.vonmis(y_pred, y_true) + self.cosine(y_pred[:, :3], y_true))/2
    
class VonMisesFisher3DLossEcludeLossCosine(nn.Module):
    def __init__(self, eps=1e-8):
        super().__init__()
        self.vonmis = VonMisesFisher3DLoss()
        self.cosine =  nn.CosineSimilarity(dim=1, eps=eps)
        self.euclud = EuclideanDistanceLossG()
        self.eps = eps
        
    def forward(self, y_pred, y_true):
        return (self.vonmis(y_pred, y_true) + 
                (self.euclud(y_pred[:, :3], y_true)) +
                (1-self.cosine(y_pred[:, :3], y_true).mean()))/3
    


class VonMisesFisher2DLossL1Loss(nn.Module):
    def __init__(self, eps=1e-8):
        super().__init__()
        self.vonmis = VonMisesFisher2DLoss()
        self.l1 = nn.L1Loss()
        
    def forward(self, y_pred, y_true):
        vm = self.vonmis(y_pred[:, :2], y_true)
        l1 = self.l1(y_pred[:, 2], y_true[:, -1])
        return (vm + l1)/2
        
    
    
class LogCoshLoss(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, y_t, y_prime_t):
        ey_t = y_t - y_prime_t
        return torch.mean(torch.log(torch.cosh(ey_t + 1e-12)))


class SigmoidRange(nn.Module):
    def __init__(self, low, high):
        super().__init__()
        self.low = low
        self.high = high

    def forward(self, x):
        return torch.sigmoid(x) * (self.high - self.low) + self.low


class Adjustoutput(nn.Module):
    def __init__(self):
        super().__init__()
        self.az = SigmoidRange(6.436839548775502e-08, 6.2891)
        self.zn = SigmoidRange(8.631674577710722e-05, 3.1417)

    def forward(self, x):
        x[:, 0] = self.az(x[:, 0])
        x[:, 1] = self.zn(x[:, 1])
        return x


class PoolingWithMask(nn.Module):
    def __init__(self, pool_type):
        super().__init__()
        self.pool_type = pool_type

    def forward(self, x, mask):
        # Multiply the mask with the input tensor to zero out the padded values
        x = x * mask.unsqueeze(-1)

        if self.pool_type == "mean":
            # Sum the values along the sequence dimension
            x = torch.sum(x, dim=1)

            # Divide the sum by the number of non-padded values (i.e. the sum of the mask)
            x = x / torch.sum(mask, dim=1, keepdim=True)
        elif self.pool_type == "max":
            # Find the maximum value along the sequence dimension
            x, _ = torch.max(x, dim=1)
        elif self.pool_type == "min":
            # Find the minimum value along the sequence dimension
            x, _ = torch.min(x, dim=1)
        else:
            raise ValueError("Invalid pool_type. Choose from ['mean', 'max', 'min']")

        return x


class MeanPoolingWithMask(nn.Module):
    def __init__(self):
        super(MeanPoolingWithMask, self).__init__()

    def forward(self, x, mask):
        # Multiply the mask with the input tensor to zero out the padded values
        x = x * mask.unsqueeze(-1)

        # Sum the values along the sequence dimension
        x = torch.sum(x, dim=1)

        # Divide the sum by the number of non-padded values (i.e. the sum of the mask)
        x = x / torch.sum(mask, dim=1, keepdim=True)

        return x


class FeedForward(nn.Module):
    def __init__(self, dim, dim_out=None, mult=4):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, dim * mult), nn.GELU(), nn.Linear(dim * mult, dim_out)
        )

    def forward(self, x):
        return self.net(x)


class IceCubeModelEncoderV0(nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = ContinuousTransformerWrapper(
            dim_in=6,
            dim_out=128,
            max_seq_len=150,
            attn_layers=Encoder(dim=128, depth=6, heads=8),
        )

        # self.pool = MeanPoolingWithMask()
        self.head = FeedForward(128, 2)

    def forward(self, batch):
        x, mask = batch["event"], batch["mask"]
        x = self.encoder(x, mask=mask)
        x = x.mean(dim=1)
        x = self.head(x)
        return x


class IceCubeModelEncoderV1(nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = ContinuousTransformerWrapper(
            dim_in=8,
            dim_out=128,
            max_seq_len=150,
            attn_layers=Encoder(dim=128, depth=6, heads=8),
        )

        self.pool = MeanPoolingWithMask()
        self.head = FeedForward(128, 2)

    def forward(self, batch):
        x, mask = batch["event"], batch["mask"]
        x = self.encoder(x, mask=mask)
        x = self.pool(x, mask)
        x = self.head(x)
        return x


class IceCubeModelEncoderV1CombinePool(nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = ContinuousTransformerWrapper(
            dim_in=8,
            dim_out=128,
            max_seq_len=150,
            attn_layers=Encoder(dim=128, depth=6, heads=8),
        )

        self.pool_mean = PoolingWithMask('mean')
        self.pool_max = PoolingWithMask('max')
        self.head = FeedForward(128 * 2, 2)

    def forward(self, batch):
        x, mask = batch["event"], batch["mask"]
        x = self.encoder(x, mask=mask)
        x = torch.concat([self.pool_mean(x, mask), self.pool_max(x, mask)], dim=1)
        x = self.head(x)
        return x


class always:
    def __init__(self, val):
        self.val = val

    def __call__(self, *args, **kwargs):
        return self.val


def l2norm(t, groups=1):
    t = rearrange(t, "... (g d) -> ... g d", g=groups)
    t = F.normalize(t, p=2, dim=-1)
    return rearrange(t, "... g d -> ... (g d)")


class TokenEmbedding(nn.Module):
    def __init__(self, dim, num_tokens, l2norm_embed=False):
        super().__init__()
        self.l2norm_embed = l2norm_embed
        self.emb = nn.Embedding(num_tokens, dim, padding_idx=0)

    def forward(self, x):
        token_emb = self.emb(x)
        return l2norm(token_emb) if self.l2norm_embed else token_emb

    def init_(self):
        nn.init.kaiming_normal_(self.emb.weight)


class IceCubeModelEncoderSensorEmbeddinng(nn.Module):
    def __init__(self, dim=128, in_features=14):
        super().__init__()
        self.token_emb = TokenEmbedding(dim, num_tokens=5161)
        self.post_norma = nn.LayerNorm(dim)
        self.token_emb.init_()
        self.encoder = ContinuousTransformerWrapper(
            dim_in=in_features + dim,
            dim_out=256,
            max_seq_len=150,
            attn_layers=Encoder(dim=256, depth=6, heads=8),
        )

        self.pool = MeanPoolingWithMask()
        self.head = FeedForward(256, 2)

    def forward(self, batch):
        x, mask, sensor_id = batch["event"], batch["mask"], batch["sensor_id"]
        embed = self.token_emb(sensor_id)
        embed = self.post_norma(embed)
        x = torch.cat([x, embed], dim=-1)
        x = self.encoder(x, mask=mask)
        x = self.pool(x, mask)
        x = self.head(x)
        return x


class IceCubeModelEncoderSensorEmbeddinngV1(nn.Module):
    def __init__(self, dim=128, in_features=6):
        super().__init__()
        self.token_emb = TokenEmbedding(dim, num_tokens=5161)
        self.post_norma = nn.LayerNorm(dim)
        self.token_emb.init_()
        self.encoder = ContinuousTransformerWrapper(
            dim_in=in_features + dim,
            dim_out=256,
            max_seq_len=150,
            attn_layers=Encoder(dim=256, depth=6, heads=8),
        )

        self.pool = MeanPoolingWithMask()
        self.head = FeedForward(256, 2)

    def forward(self, batch):
        x, mask, sensor_id = batch["event"], batch["mask"], batch["sensor_id"]
        embed = self.token_emb(sensor_id)
        embed = self.post_norma(embed)
        x = torch.cat([x, embed], dim=-1)
        x = self.encoder(x, mask=mask)
        x = self.pool(x, mask)
        x = self.head(x)
        return x


class TokenEmbeddingV2(nn.Module):
    def __init__(self, dim, num_tokens, l2norm_embed=False):
        super().__init__()
        self.l2norm_embed = l2norm_embed
        self.emb = nn.Embedding(num_tokens, dim)

    def forward(self, x):
        token_emb = self.emb(x)
        return l2norm(token_emb) if self.l2norm_embed else token_emb

    def init_(self):
        nn.init.kaiming_normal_(self.emb.weight)


class IceCubeModelEncoderSensorEmbeddinngV2(nn.Module):
    def __init__(self, dim=128, in_features=6):
        super().__init__()
        self.token_emb = TokenEmbeddingV2(dim, num_tokens=5161)
        self.post_norma = nn.LayerNorm(dim)
        self.token_emb.init_()
        self.encoder = ContinuousTransformerWrapper(
            dim_in=in_features + dim,
            dim_out=256,
            max_seq_len=196,
            attn_layers=Encoder(dim=256, depth=6, heads=8),
        )

        self.pool = MeanPoolingWithMask()
        self.head = FeedForward(256, 2)

    def forward(self, batch):
        x, mask, sensor_id = batch["event"], batch["mask"], batch["sensor_id"]
        embed = self.token_emb(sensor_id)
        x = torch.cat([x, embed], dim=-1)
        x = self.encoder(x, mask=mask)
        x = self.pool(x, mask)
        x = self.head(x)
        return x


class IceCubeModelEncoderSensorEmbeddinngV3(nn.Module):
    def __init__(self, dim=128, in_features=6):
        super().__init__()
        self.token_emb = TokenEmbeddingV2(dim, num_tokens=5161)
        self.post_norma = nn.LayerNorm(dim)
        self.token_emb.init_()
        self.encoder = ContinuousTransformerWrapper(
            dim_in=in_features + dim,
            dim_out=256,
            max_seq_len=150,
            attn_layers=Encoder(dim=256, depth=6, heads=8),
        )

        self.pool = MeanPoolingWithMask()
        self.head = FeedForward(256, 2)
        self.sigmout = Adjustoutput()

    def forward(self, batch):
        x, mask, sensor_id = batch["event"], batch["mask"], batch["sensor_id"]
        embed = self.token_emb(sensor_id)
        x = torch.cat([x, embed], dim=-1)
        x = self.encoder(x, mask=mask)
        x = self.pool(x, mask)
        x = self.head(x)
        s = self.sigmout(x)
        return x
    
class IceCubeModelEncoderSensorEmbeddinngV4(nn.Module):
    def __init__(self, dim=128, in_features=9):
        super().__init__()
        self.token_emb = TokenEmbeddingV2(dim, num_tokens=5168)
        self.post_norma = nn.LayerNorm(dim)
        self.token_emb.init_()
        self.encoder = ContinuousTransformerWrapper(
            dim_in=in_features + dim,
            dim_out=256,
            max_seq_len=480,
            post_emb_norm = True,
            attn_layers=Encoder(dim=256, 
                                depth=8, 
                                heads=8, 
                                ff_glu = True,
                                rotary_pos_emb = True),
        )

        self.pool_mean = PoolingWithMask("mean")
        self.pool_max = PoolingWithMask("max")
        
        self.out = DirectionReconstructionWithKappa(
            hidden_size=256 * 2,
            target_labels='direction',
            loss_function=VonMisesFisher3DLoss(),
        )
        
        self.apply(self._init_weights)
    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, batch):
        x, mask, sensor_id = batch["event"], batch["mask"], batch["sensor_id"]
        embed = self.token_emb(sensor_id)
        x = torch.cat([x, embed], dim=-1)
        x = self.encoder(x, mask=mask)
        x = torch.concat([self.pool_mean(x, mask), self.pool_max(x, mask)], dim=1)
        x = self.out(x)
        return x
    
    



In [171]:
model = IceCubeModelEncoderSensorEmbeddinngV4().eval()
event = torch.rand(4, 100, 9)
mask = torch.ones(4, 100, dtype=torch.bool)
sensor_id = torch.randint(0, 5161, (4, 100))
input = dict(event=event, mask=mask, sensor_id=sensor_id)
with torch.no_grad():
    y = model(input)

y.shape

torch.Size([4, 4])

In [172]:
def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

In [173]:
count_parameters(model)/1000000

11.277827

In [174]:
# | export
class IceCubeModelEncoderV2(nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = ContinuousTransformerWrapper(
            dim_in=8,
            dim_out=128,
            max_seq_len=150,
            attn_layers=Encoder(dim=128, depth=6, heads=8),
        )

        self.pool = MeanPoolingWithMask()
        self.out = DirectionReconstructionWithKappa(128)

    def forward(self, batch):
        x, mask = batch["event"], batch["mask"]
        x = self.encoder(x, mask=mask)
        x = self.pool(x, mask)
        return self.out(x)


class EncoderWithDirectionReconstruction(nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = ContinuousTransformerWrapper(
            dim_in=8,
            dim_out=128,
            max_seq_len=150,
            attn_layers=Encoder(dim=128, depth=6, heads=8),
        )

        self.pool_mean = PoolingWithMask("mean")
        self.pool_max = PoolingWithMask("max")
        self.out = DirectionReconstructionWithKappa(
            hidden_size=256,
            target_labels='direction',
            loss_function=VonMisesFisher3DLoss(),
        )

    def forward(self, batch):
        x, mask = batch["event"], batch["mask"]
        x = self.encoder(x, mask=mask)
        x = torch.concat([self.pool_mean(x, mask), self.pool_max(x, mask)], dim=1)
        return self.out(x)


class EncoderWithDirectionReconstructionV1(nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = ContinuousTransformerWrapper(
            dim_in=9,
            dim_out=128,
            max_seq_len=150,
            attn_layers=Encoder(dim=128, depth=6, heads=8),
        )

        self.pool_mean = PoolingWithMask("mean")
        self.pool_max = PoolingWithMask("max")
        self.pool_min = PoolingWithMask("min")
        self.ae = FeedForward(384, 384)
        self.out = DirectionReconstructionWithKappa(
            hidden_size=384,
            target_labels='direction',
            loss_function=VonMisesFisher3DLoss(),
        )

    def forward(self, batch):
        x, mask = batch["event"], batch["mask"]
        x = self.encoder(x, mask=mask)
        x = torch.concat([self.pool_mean(x, mask), 
                          self.pool_max(x, mask), 
                          self.pool_min(x, mask)], dim=1)
        x = self.ae(x)
        return self.out(x)
    
class EncoderWithDirectionReconstructionV2(nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = ContinuousTransformerWrapper(
            dim_in=9,
            dim_out=128,
            max_seq_len=150,
            attn_layers=Encoder(dim=128, depth=6, heads=8),
        )

        self.pool_mean = PoolingWithMask("mean")
        self.pool_max = PoolingWithMask("max")
        self.out = DirectionReconstructionWithKappa(
            hidden_size=256,
            target_labels='direction',
            loss_function=VonMisesFisher3DLoss(),
        )

    def forward(self, batch):
        x, mask = batch["event"], batch["mask"]
        x = self.encoder(x, mask=mask)
        x = torch.concat([self.pool_mean(x, mask), self.pool_max(x, mask)], dim=1)
        return self.out(x)
    
class EncoderWithDirectionReconstructionV3(nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = ContinuousTransformerWrapper(
            dim_in=8,
            dim_out=128,
            max_seq_len=150,
            attn_layers=Encoder(dim=128, depth=6, heads=8),
        )

        self.pool_mean = PoolingWithMask("mean")
        self.pool_max = PoolingWithMask("max")

        self.azimuth_task = AzimuthReconstructionWithKappa(
            hidden_size=256,
            loss_function=VonMisesFisher2DLoss(),
            target_labels=["azimuth", "kappa"],
        )

        self.zenith_task = ZenithReconstruction(
            hidden_size=256,
            loss_function=nn.L1Loss(),
            target_labels=["zenith"],
        )

    def forward(self, batch):
        x, mask = batch["event"], batch["mask"]
        x = self.encoder(x, mask=mask)
        x = torch.concat([self.pool_mean(x, mask), self.pool_max(x, mask)], dim=1)
        az = self.azimuth_task(x)
        zn = self.zenith_task(x)
        return torch.concat([az, zn], dim=1)


class EncoderWithDirectionReconstructionV3(nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = ContinuousTransformerWrapper(
            dim_in=8,
            dim_out=128,
            max_seq_len=150,
            attn_layers=Encoder(dim=128, depth=6, heads=8),
        )

        self.pool_mean = PoolingWithMask("mean")
        self.pool_max = PoolingWithMask("max")

        self.azimuth_task = AzimuthReconstructionWithKappa(
            hidden_size=256,
            loss_function=VonMisesFisher2DLoss(),
            target_labels=["azimuth", "kappa"],
        )

        self.zenith_task = ZenithReconstruction(
            hidden_size=256,
            loss_function=nn.L1Loss(),
            target_labels=["zenith"],
        )

    def forward(self, batch):
        x, mask = batch["event"], batch["mask"]
        x = self.encoder(x, mask=mask)
        x = torch.concat([self.pool_mean(x, mask), self.pool_max(x, mask)], dim=1)
        az = self.azimuth_task(x)
        zn = self.zenith_task(x)
        return torch.concat([az, zn], dim=1)


class EncoderWithDirectionReconstructionV4(nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = ContinuousTransformerWrapper(
            dim_in=9,
            dim_out=128,
            max_seq_len=150,
            post_emb_norm = True,
            attn_layers=Encoder(dim=128,
                                depth=8,
                                heads=8,
                                ff_glu = True,
                                rotary_pos_emb = True),
        )

        self.pool_mean = PoolingWithMask("mean")
        self.pool_max = PoolingWithMask("max")
        self.out = DirectionReconstructionWithKappa(
            hidden_size=256,
            target_labels='direction',
            loss_function=VonMisesFisher3DLoss(),
        )
        
        self.apply(self._init_weights)

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            
    def forward(self, batch):
        x, mask = batch["event"], batch["mask"]
        x = self.encoder(x, mask=mask)
        x = torch.concat([self.pool_mean(x, mask), self.pool_max(x, mask)], dim=1)
        return self.out(x)
    

class EncoderWithDirectionReconstructionV5(nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = ContinuousTransformerWrapper(
            dim_in=9,
            dim_out=128,
            max_seq_len=150,
            post_emb_norm = True,
            attn_layers=Encoder(dim=128,
                                depth=8,
                                heads=8,
                                ff_glu = True,
                                rotary_pos_emb = True),
        )

        self.cls_token = nn.Parameter(torch.rand(1, 1, 128))
        self.out = DirectionReconstructionWithKappa(
            hidden_size=128,
            target_labels='direction',
            loss_function=VonMisesFisher3DLoss(),
        )
        
        self.apply(self._init_weights)
        torch.nn.init.trunc_normal_(self.cls_token, std=0.02)

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            
    def forward(self, batch):
        x, mask = batch["event"], batch["mask"]
        bs = x.shape[0]
        cls_tokens  = self.cls_token.expand(bs, -1, -1)
        mask = torch.cat([torch.ones(bs, 1, dtype=torch.bool, device=x.device), mask], dim=1)
        x = self.encoder(x, mask=mask, prepend_embeds=cls_tokens)
        #pool on cls token
        x = x[:, 0]
        return self.out(x)
    
    

    
class ExtractorV0(nn.Module):
    def __init__(self, dim_base=128, dim=384, proj = True):
        super().__init__()
        self.emb = SinusoidalPosEmb(dim=dim_base)
        self.emb2 = SinusoidalPosEmb(dim=dim_base//2)
        self.aux_emb = nn.Embedding(2,dim_base//2)
        self.qe_emb = nn.Embedding(2,dim_base//2)
        self.proj = nn.Linear(dim_base*7,dim) if proj else nn.Identity()
        
    def forward(self, x, Lmax=None):
        pos = x['pos'] if Lmax is None else x['pos'][:,:Lmax]
        charge = x['charge'] if Lmax is None else x['charge'][:,:Lmax]
        time = x['time'] if Lmax is None else x['time'][:,:Lmax]
        auxiliary = x['aux'] if Lmax is None else x['auxiliary'][:,:Lmax]
        qe = x['qe'] if Lmax is None else x['qe'][:,:Lmax]
        ice_properties = x['ice_properties'] if Lmax is None else x['ice_properties'][:,:Lmax]
        
        x = torch.cat([self.emb(100*pos).flatten(-2), self.emb(40*charge),
                       self.emb(100*time),self.aux_emb(auxiliary),self.qe_emb(qe),
                       self.emb2(50*ice_properties).flatten(-2)],-1)
        x = self.proj(x)
        return x

    
    
class ExtractorV1(nn.Module):
    def __init__(self, dim_base=128):
        super().__init__()
        self.emb = SinusoidalPosEmb(dim=dim_base)
        self.emb2 = SinusoidalPosEmb(dim=dim_base//2)
        self.aux_emb = TokenEmbeddingV2(dim_base//4, 2, True)
        self.qe_emb = TokenEmbeddingV2(dim_base//4, 2, True)
        self.rank = TokenEmbeddingV2(dim_base//4, 4, True)
        
    def forward(self, x):
        ice_properties = torch.stack([x['scattering'], x['absorption']], dim=2)
        
        x = torch.cat([self.emb(100*x['pos']).flatten(-2), 
                       self.emb(40*x['charge']),
                       self.emb(100*x['time']),
                       self.aux_emb(x["aux"]),
                       self.qe_emb(x["qe"]),
                       self.rank(x["rank"]),
                       self.emb2(50*ice_properties).flatten(-2)],-1)
        return x
    

class ExtractorV2(nn.Module):
    def __init__(self, dim_base=128, out_dim=196):
        super().__init__()
        self.emb = SinusoidalPosEmb(dim=dim_base)
        self.emb2 = SinusoidalPosEmb(dim=dim_base//2)
        self.aux_emb = TokenEmbeddingV2(dim_base//4, 2, True)
        self.qe_emb = TokenEmbeddingV2(dim_base//4, 2, True)
        self.rank = TokenEmbeddingV2(dim_base//4, 4, True)
        self.out = nn.Linear(864, out_dim)
        
    def forward(self, x):
        ice_properties = torch.stack([x['scattering'], x['absorption']], dim=2)
        
        x = torch.cat([self.emb(100*x['pos']).flatten(-2), 
                       self.emb(40*x['charge']),
                       self.emb(100*x['time']),
                       self.aux_emb(x["aux"]),
                       self.qe_emb(x["qe"]),
                       self.rank(x["rank"]),
                       self.emb2(50*ice_properties).flatten(-2)],-1)
        return x
    
    
class ExtractorV2(nn.Module):
    def __init__(self, dim_base=128, out_dim=196):
        super().__init__()
        self.emb = SinusoidalPosEmb(dim=dim_base)
        self.emb2 = SinusoidalPosEmb(dim=dim_base//2)
        self.aux_emb = TokenEmbeddingV2(dim_base//4, 2, True)
        self.qe_emb = TokenEmbeddingV2(dim_base//4, 2, True)
        self.rank = TokenEmbeddingV2(dim_base//4, 4, True)
        self.out = nn.Linear(864, out_dim)
        
    def forward(self, x):
        ice_properties = torch.stack([x['scattering'], x['absorption']], dim=2)
        
        x = torch.cat([self.emb(100*x['pos']).flatten(-2), 
                       self.emb(40*x['charge']),
                       self.emb(100*x['time']),
                       self.aux_emb(x["aux"]),
                       self.qe_emb(x["qe"]),
                       self.rank(x["rank"]),
                       self.emb2(50*ice_properties).flatten(-2)],-1)
        return x
    
    
class EncoderWithDirectionReconstructionV6(nn.Module):
    def __init__(self, dim_in = 864, dim_out=256, attn_depth = 8, heads = 8):
        super().__init__()
        self.encoder = ContinuousTransformerWrapper(
            dim_in=dim_in,
            dim_out=dim_out,
            max_seq_len=440,
            post_emb_norm = True,
            attn_layers=Encoder(dim=dim_out,
                                depth=attn_depth,
                                heads=heads,
                                ff_glu = True,
                                rotary_pos_emb = True),
        )

        self.cls_token = nn.Parameter(torch.rand(1, 1, dim_in))
        self.out = DirectionReconstructionWithKappa(
            hidden_size=dim_out,
            target_labels='direction',
            loss_function=VonMisesFisher3DLoss(),
        )
        self.fe = ExtractorV1()
        
        self.apply(self._init_weights)
        torch.nn.init.trunc_normal_(self.cls_token, std=0.02)

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            
    def forward(self, batch):
        mask = batch["mask"]
        x = self.fe(batch)
        bs = x.shape[0]
        cls_tokens  = self.cls_token.expand(bs, -1, -1)
        x = torch.cat((cls_tokens, x), dim = -2)
        mask = torch.cat([torch.ones(bs, 1, dtype=torch.bool, device=x.device), mask], dim=1)
        x = self.encoder(x, mask=mask)
        x = x[:, 0]
        return self.out(x)
    
    
class EncoderWithDirectionReconstructionV7(nn.Module):
    def __init__(self, dim_in = 864, dim_out=256, attn_depth = 12, heads = 12):
        super().__init__()
        self.encoder = ContinuousTransformerWrapper(
            dim_in=dim_in,
            dim_out=dim_out,
            max_seq_len=440,
            post_emb_norm = True,
            use_abs_pos_emb = False, 
            emb_dropout = 0.1, 
            attn_layers=Encoder(dim=dim_out,
                                depth=attn_depth,
                                heads=heads,
                                ff_glu = True,
                                rotary_pos_emb = True, 
                                use_rmsnorm = True,
                                layer_dropout = 0.1, 
                                attn_dropout = 0.1,    
                                ff_dropout = 0.1)   
        )

        self.pool_mean = PoolingWithMask("mean")
        self.pool_max = PoolingWithMask("max")
        self.cls_token = nn.Parameter(torch.rand(1, 1, dim_in))
        self.out = DirectionReconstructionWithKappa(
            hidden_size=dim_out * 3,
            target_labels='direction',
            loss_function=VonMisesFisher3DLoss(),
        )
        self.fe = ExtractorV1()
        
        self.apply(self._init_weights)
        torch.nn.init.trunc_normal_(self.cls_token, std=0.02)


    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            
    def forward(self, batch):
        mask = batch["mask"]
        x = self.fe(batch)
        bs = x.shape[0]
        cls_tokens  = self.cls_token.expand(bs, -1, -1)
        x = torch.cat((cls_tokens, x), dim = -2)
        mask = torch.cat([torch.ones(bs, 1, dtype=torch.bool, device=x.device), mask], dim=1)
        x = self.encoder(x, mask=mask)
        x = torch.concat([self.pool_mean(x, mask), self.pool_max(x, mask), x[:, 0]], dim=1)
        return self.out(x)
    
from timm.models.layers import drop_path, to_2tuple, trunc_normal_
class EncoderWithDirectionReconstructionV8(nn.Module):
    def __init__(self, dim_in = 864, dim_out=256, attn_depth = 8, heads = 12):
        super().__init__()
        self.encoder = ContinuousTransformerWrapper(
            dim_in=dim_out,
            dim_out=dim_out,
            max_seq_len=300,
            post_emb_norm = True,
            use_abs_pos_emb = False, 
            emb_dropout = 0.1, 
            attn_layers=Encoder(dim=dim_out,
                                depth=attn_depth,
                                heads=heads,
                                ff_glu = True,
                                rel_pos_bias = True, 
                                layer_dropout = 0.01, 
                                attn_dropout = 0.01,    
                                ff_dropout = 0.01)   
        )
        
        self.cls_token = nn.Linear(dim_out,1,bias=False)
        self.pool_mean = PoolingWithMask("mean")
        self.pool_max = PoolingWithMask("max")
        
        self.out = nn.Linear(dim_out * 3, 3)
        self.fe = ExtractorV0(dim=dim_out, dim_base=96)
        
        self.apply(self._init_weights)
        trunc_normal_(self.cls_token.weight, std=.02)
        #torch.nn.init.trunc_normal_(self.cls_token, std=0.02)


    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            
    def forward(self, batch):
        mask = batch["mask"]
        x = self.fe(batch, mask.sum(-1).max())
        bs = x.shape[0]
        cls_token = self.cls_token.weight.unsqueeze(0).expand(bs,-1,-1)
        x = torch.cat([cls_token,x],1)
        mask = torch.cat([torch.ones(bs, 1, dtype=torch.bool, device=x.device), mask], dim=1)
        x = x[:,:mask.sum(-1).max()]
        mask = mask[:,:mask.sum(-1).max()]
        x = self.encoder(x, mask=mask)
        x = torch.concat([self.pool_mean(x, mask), self.pool_max(x, mask), x[:, 0]], dim=1)
        return self.out(x)


class EncoderWithDirectionReconstructionV9(nn.Module):
    def __init__(self, dim_out=256, attn_depth = 10, heads = 12):
        super().__init__()
        self.encoder = ContinuousTransformerWrapper(
            dim_out=dim_out,
            max_seq_len=300,
            post_emb_norm = True,
            use_abs_pos_emb = False, 
            emb_dropout = 0.1, 
            attn_layers=Encoder(dim=dim_out,
                                depth=attn_depth,
                                heads=heads,
                                use_rmsnorm = True,
                                ff_glu = True,
                                alibi_pos_bias = True, 
                                alibi_num_heads = 4 ,  
                                layer_dropout = 0.01, 
                                attn_dropout = 0.01,    
                                ff_dropout = 0.01)   
        )
        
        self.cls_token = nn.Linear(dim_out,1,bias=False)  
        self.out = nn.Linear(dim_out, 3)
        self.fe = ExtractorV0(dim=dim_out, dim_base=96)
        self.apply(self._init_weights)
        trunc_normal_(self.cls_token.weight, std=.02)
        #torch.nn.init.trunc_normal_(self.cls_token, std=0.02)


    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            
    def forward(self, batch):
        mask = batch["mask"]
        x = self.fe(batch, mask.sum(-1).max())
        bs = x.shape[0]
        cls_token = self.cls_token.weight.unsqueeze(0).expand(bs,-1,-1)
        x = torch.cat([cls_token,x],1)
        mask = torch.cat([torch.ones(bs, 1, dtype=torch.bool, device=x.device), mask], dim=1)
        x = x[:,:mask.sum(-1).max()]
        mask = mask[:,:mask.sum(-1).max()]
        x = self.encoder(x, mask=mask)
        x = x[:, 0]
        return self.out(x)
    
from torch_geometric.utils import to_dense_batch
class EncoderWithDirectionReconstructionV10(nn.Module):
    def __init__(self, dim_in = 864, dim_out=256, attn_depth = 8, heads = 12):
        super().__init__()
        self.encoder = ContinuousTransformerWrapper(
            dim_out=dim_out,
            max_seq_len=300,
            post_emb_norm = True,
            use_abs_pos_emb = False, 
            emb_dropout = 0.1, 
            attn_layers=Encoder(dim=dim_out,
                                depth=attn_depth,
                                heads=heads,
                                ff_glu = True,
                                rel_pos_bias = True, 
                                layer_dropout = 0.01, 
                                attn_dropout = 0.01,    
                                ff_dropout = 0.01)   
        )
    
        
        self.cls_token = nn.Linear(dim_out,1,bias=False)
        self.out = nn.Linear(dim_out, 3)
        self.fe = ExtractorV0(dim=dim_out, dim_base=96)
        self.graph_feat = EGNNModeLFEAT( emb_dim=dim_out, num_layers=2)
        trunc_normal_(self.cls_token.weight, std=.02)
        #torch.nn.init.trunc_normal_(self.cls_token, std=0.02)

    def forward(self, batch):
        mask = batch["mask"] #bs, seq_len
        bs = mask.shape[0] # int
        pos = batch["pos"][mask] 
        mask = mask[:,:mask.sum(-1).max()] 
        batch_index = mask.nonzero()[:,0] 
        edge_index = knn_graph(x = pos, k=8, batch=batch_index).to(mask.device)
        x = self.fe(batch, mask.sum(-1).max())
        x = x[mask]
        x = self.graph_feat(x, pos, edge_index)
        x, mask = to_dense_batch(x, batch_index)
        cls_token = self.cls_token.weight.unsqueeze(0).expand(bs,-1,-1)
        x = torch.cat([cls_token,x],1)
        mask = torch.cat([torch.ones(bs, 1, dtype=torch.bool, device=x.device), mask], dim=1)
        x = self.encoder(x, mask=mask)
        x = x[:, 0]
        return self.out(x)


In [175]:
from x_transformers.x_transformers import ScaledSinusoidalEmbedding
from icecube.utils import collate_fn_v2
from icecube.dataset import HuggingFaceDatasetV14
from torch.utils.data import DataLoader

In [176]:
ds = HuggingFaceDatasetV14(load_from_disk('/opt/slh/icecube/data/hf_cashe/batch_1.parquet'))
dl = DataLoader(ds, batch_size=12, collate_fn=collate_fn_v2, num_workers=1, drop_last=True)
batch=next(iter(dl))

In [177]:
md = EncoderWithDirectionReconstructionV10().eval()
#with torch.no_grad():
#    out = md(batch)


In [178]:
count_parameters(md)/1000000

13.776453

In [179]:
model = EncoderWithDirectionReconstructionV5().eval()
event = torch.rand(4, 100, 9)
mask = torch.ones(4, 100, dtype=torch.bool)
sensor_id = torch.randint(0, 5161, (10, 100))
label = torch.rand(4, 3)
input = dict(event=event, mask=mask, sensor_id=sensor_id, label=label)
with torch.no_grad():
    y = model(input)

y.shape

torch.Size([4, 4])

### Graph Transformers

In [180]:
# | export




DIST_KERNELS = {
    "exp": {
        "fn": lambda t: torch.exp(-t),
        "mask_value_fn": lambda t: torch.finfo(t.dtype).max,
    },
    "softmax": {
        "fn": lambda t: torch.softmax(t, dim=-1),
        "mask_value_fn": lambda t: -torch.finfo(t.dtype).max,
    },
}

# helpers


def exists(val):
    return val is not None


def default(val, d):
    return d if not exists(val) else val


# helper classes


class Residual(nn.Module):
    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x, **kwargs):
        return x + self.fn(x, **kwargs)


class PreNorm(nn.Module):
    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn

    def forward(self, x, **kwargs):
        x = self.norm(x)
        return self.fn(x, **kwargs)


class FeedForwardV1(nn.Module):
    def __init__(self, dim, dim_out=None, mult=4):
        super().__init__()
        dim_out = default(dim_out, dim)
        self.net = nn.Sequential(
            nn.Linear(dim, dim * mult), nn.GELU(), nn.Linear(dim * mult, dim_out)
        )

    def forward(self, x):
        return self.net(x)


class Attention(nn.Module):
    def __init__(
        self, dim, heads=8, dim_head=64, Lg=0.5, Ld=0.5, La=1, dist_kernel_fn="exp"
    ):
        super().__init__()
        inner_dim = dim_head * heads
        self.heads = heads
        self.scale = dim_head**-0.5
        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)
        self.to_out = nn.Linear(inner_dim, dim)

        # hyperparameters controlling the weighted linear combination from
        # self-attention (La)
        # adjacency graph (Lg)
        # pair-wise distance matrix (Ld)

        self.La = La
        self.Ld = Ld
        self.Lg = Lg

        self.dist_kernel_fn = dist_kernel_fn

    def forward(self, x, mask=None, adjacency_mat=None, distance_mat=None):
        h, La, Ld, Lg, dist_kernel_fn = (
            self.heads,
            self.La,
            self.Ld,
            self.Lg,
            self.dist_kernel_fn,
        )

        qkv = self.to_qkv(x)
        q, k, v = rearrange(qkv, "b n (h qkv d) -> b h n qkv d", h=h, qkv=3).unbind(
            dim=-2
        )
        dots = einsum("b h i d, b h j d -> b h i j", q, k) * self.scale

        assert (
            dist_kernel_fn in DIST_KERNELS
        ), f"distance kernel function needs to be one of {DIST_KERNELS.keys()}"
        dist_kernel_config = DIST_KERNELS[dist_kernel_fn]

        if exists(distance_mat):
            distance_mat = rearrange(distance_mat, "b i j -> b () i j")

        if exists(adjacency_mat):
            adjacency_mat = rearrange(adjacency_mat, "b i j -> b () i j")

        if exists(mask):
            mask_value = torch.finfo(dots.dtype).max
            mask = mask[:, None, :, None] * mask[:, None, None, :]

            # mask attention
            dots.masked_fill_(~mask, -mask_value)

            if exists(distance_mat):
                # mask distance to infinity
                # todo - make sure for softmax distance kernel, use -infinity
                dist_mask_value = dist_kernel_config["mask_value_fn"](dots)
                distance_mat.masked_fill_(~mask, dist_mask_value)

            if exists(adjacency_mat):
                adjacency_mat.masked_fill_(~mask, 0.0)

        attn = dots.softmax(dim=-1)

        # sum contributions from adjacency and distance tensors
        attn = attn * La

        if exists(adjacency_mat):
            attn = attn + Lg * adjacency_mat

        if exists(distance_mat):
            distance_mat = dist_kernel_config["fn"](distance_mat)
            attn = attn + Ld * distance_mat

        out = einsum("b h i j, b h j d -> b h i d", attn, v)
        out = rearrange(out, "b h n d -> b n (h d)")
        return self.to_out(out)


# main class


class MAT(nn.Module):
    def __init__(
        self,
        *,
        dim_in,
        model_dim,
        dim_out,
        depth,
        heads=8,
        Lg=0.5,
        Ld=0.5,
        La=1,
        dist_kernel_fn="exp",
    ):
        super().__init__()

        self.embed_to_model = nn.Linear(dim_in, model_dim)
        self.layers = nn.ModuleList([])

        for _ in range(depth):
            layer = nn.ModuleList(
                [
                    Residual(
                        PreNorm(
                            model_dim,
                            Attention(
                                model_dim,
                                heads=heads,
                                Lg=Lg,
                                Ld=Ld,
                                La=La,
                                dist_kernel_fn=dist_kernel_fn,
                            ),
                        )
                    ),
                    Residual(PreNorm(model_dim, FeedForwardV1(model_dim))),
                ]
            )
            self.layers.append(layer)

        self.norm_out = nn.LayerNorm(model_dim)
        self.ff_out = FeedForward(model_dim, dim_out)


    def forward(self, batch):

        x = batch["event"]
        mask = batch["mask"]
        adjacency_mat = batch["adjecent_matrix"]
        distance_mat = batch["distance_matrix"]

        x = self.embed_to_model(x)

        for (attn, ff) in self.layers:
            x = attn(
                x, mask=mask, adjacency_mat=adjacency_mat, distance_mat=distance_mat
            )
            x = ff(x)

        x = self.norm_out(x)
        x = x.mean(dim=-2)
        x = self.ff_out(x)
        return x


class MATMaskedPool(nn.Module):
    def __init__(
        self,
        *,
        dim_in,
        model_dim,
        dim_out,
        depth,
        heads=8,
        Lg=1.0,
        Ld=1.0,
        La=1,
        dist_kernel_fn="exp",
    ):
        super().__init__()

        self.embed_to_model = nn.Linear(dim_in, model_dim)
        self.layers = nn.ModuleList([])

        for _ in range(depth):
            layer = nn.ModuleList(
                [
                    Residual(
                        PreNorm(
                            model_dim,
                            Attention(
                                model_dim,
                                heads=heads,
                                Lg=Lg,
                                Ld=Ld,
                                La=La,
                                dist_kernel_fn=dist_kernel_fn,
                            ),
                        )
                    ),
                    Residual(PreNorm(model_dim, FeedForwardV1(model_dim))),
                ]
            )
            self.layers.append(layer)

        self.norm_out = nn.LayerNorm(model_dim)
        self.pool = MeanPoolingWithMask()
        self.ff_out = FeedForward(model_dim, dim_out)


    def forward(self, batch):

        x = batch["event"]
        mask = batch["mask"]
        adjacency_mat = batch["adjecent_matrix"]
        distance_mat = batch["distance_matrix"]

        x = self.embed_to_model(x)

        for (attn, ff) in self.layers:
            x = attn(
                x, mask=mask, adjacency_mat=adjacency_mat, distance_mat=distance_mat
            )
            x = ff(x)

        x = self.norm_out(x)
        x = self.pool(x, mask=mask)
        x = self.ff_out(x)
        return x
    
    

class MATAdjusted(nn.Module):
    def __init__(
        self,
        model_dim,
        depth,
        heads=12,
        Lg=0.75,
        Ld=0.75,
        La=1.,
        dist_kernel_fn="softmax",
    ):
        super().__init__()
        self.layers = nn.ModuleList([])

        for _ in range(depth):
            layer = nn.ModuleList(
                [
                    Residual(
                        PreNorm(
                            model_dim,
                            Attention(
                                model_dim,
                                heads=heads,
                                Lg=Lg,
                                Ld=Ld,
                                La=La,
                                dist_kernel_fn=dist_kernel_fn,
                            ),
                        )
                    ),
                    Residual(PreNorm(model_dim, FeedForwardV1(model_dim))),
                ]
            )
            self.layers.append(layer)



    def forward(self, x, adjacency_mat, mask):


        for (attn, ff) in self.layers:
            x = attn(
                x, mask=mask, adjacency_mat=adjacency_mat
            )
            x = ff(x)

        return x



class IceCubeModelEncoderMAT(nn.Module):
    def __init__(self, dim=128, in_features=6):
        super().__init__()
        self.md = MAT(
            dim_in=6,
            model_dim=128,
            dim_out=2,
            depth=6,
            Lg=0.5,  # lambda (g)raph - weight for adjacency matrix
            Ld=0.5,  # lambda (d)istance - weight for distance matrix
            La=1,  # lambda (a)ttention - weight for usual self-attention
            dist_kernel_fn="exp",  # distance kernel fn - either 'exp' or 'softmax'
        )

    def forward(self, batch):
        return self.md(batch)


class IceCubeModelEncoderMATMasked(nn.Module):
    def __init__(self, dim=128, in_features=6):
        super().__init__()
        self.md = MATMaskedPool(
            dim_in=6,
            model_dim=128,
            dim_out=2,
            depth=6,
            Lg=0.5,  # lambda (g)raph - weight for adjacency matrix
            Ld=0.5,  # lambda (d)istance - weight for distance matrix
            La=1,  # lambda (a)ttention - weight for usual self-attention
            dist_kernel_fn="exp",  # distance kernel fn - either 'exp' or 'softmax'
        )

    def forward(self, batch):
        return self.md(batch)


# helpers

def exists(val):
    return val is not None

def batched_index_select(values, indices):
    last_dim = values.shape[-1]
    return values.gather(1, indices[:, :, None].expand(-1, -1, last_dim))


class FeedForward(nn.Module):
    def __init__(self, dim, mult = 4, dropout = 0.):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, dim * mult),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(dim * mult, dim)
        )

    def forward(self, x, **kwargs):
        return self.net(x)


class NMatrixAttention(nn.Module):
    def __init__(
        self,
        *,
        dim,
        dim_head = 64,
        heads = 4,
        dropout = 0.1
    ):
        super().__init__()
        inner_dim = dim_head * heads
        self.scale = dim_head ** -0.5
        self.heads = heads

        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
        self.to_out = nn.Linear(inner_dim, dim)
        #null_k and null_v parameters serve as learnable "null" key and value vectors.
        #provids a default key and value for each attention head 
        #when there is no connection between two nodes or when adjacency information is missing.
        #By including these null keys and values, the attention mechanism can learn to assign a
        #ppropriate importance to the null entries in the adjacency matrix, effectively allowing the model to learn 
        #how to handle situations where neighborhood information is incomplete or scarce.
        self.null_k = nn.Parameter(torch.randn(heads, dim_head))
        self.null_v = nn.Parameter(torch.randn(heads, dim_head))

        self.dropout = nn.Dropout(dropout)

    def forward(
        self,
        x,
        adj_kv_indices,
        mask
    ):
        b, n, d, h = *x.shape, self.heads
        flat_indices = repeat(adj_kv_indices, 'b n a -> (b h) (n a)', h = h)
        #splits the input tensor into query q, key k, and value v tensors using the to_qkv linear laye
        q, k, v = self.to_qkv(x).chunk(3, dim = -1)
        #rearranges q, k, and v tensors to have separate head dimensions.
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, k, v))

        #batched_index_select to select the corresponding k and v tensors based on the adjacency indices
        k, v = map(lambda t: rearrange(t, 'b h n d -> (b h) n d'), (k, v))
        k = batched_index_select(k, flat_indices)
        v = batched_index_select(v, flat_indices)
        k, v = map(lambda t: rearrange(t, '(b h) (n a) d -> b h n a d', h = h, n = n), (k, v))

        nk, nv = map(lambda t: rearrange(t, 'h d -> () h () () d').expand(b, -1, n, 1, -1), (self.null_k, self.null_v))
        k = torch.cat((nk, k), dim = -2)
        v = torch.cat((nv, v), dim = -2)
        mask = F.pad(mask, (1, 0), value = 1)
        #calculate the similarity scores between queries and keys, scales them, and applies the mask.
        sim = einsum('b h n d, b h n a d -> b h n a', q, k) * self.scale

        mask_value = -torch.finfo(sim.dtype).max
        mask = rearrange(mask.bool(), 'b n a -> b () n a')
        sim.masked_fill_(~mask.bool(), mask_value)

        attn = sim.softmax(dim = -1)
        attn = self.dropout(attn)

        out = einsum('b h n a, b h n a d -> b h n d', attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')

        return self.to_out(out)


class LocalAttenNetwok(nn.Module):
    def __init__(
        self,
        dim,
        depth,
        dim_head = 64,
        heads = 8,
        num_neighbors_cutoff = None,
        attn_dropout = 0.1,
    ):
        super().__init__()
        self.num_neighbors_cutoff = num_neighbors_cutoff
        self.layers = nn.ModuleList([])

        for _ in range(depth):
            global_attn = None
            self.layers.append(nn.ModuleList([
                Residual(PreNorm(dim, NMatrixAttention(
                    dim = dim,
                    dim_head = dim_head,
                    heads = heads,
                    dropout = attn_dropout
                ))),
                global_attn,
            ]))

    def forward(self, x, adjacency_mat, mask = None):
        device, n = x.device, x.shape[1]

        diag = torch.eye(adjacency_mat.shape[-1], device = device).bool()
        adjacency_mat |= diag
        if exists(mask):
            adjacency_mat &= (mask[:, :, None] * mask[:, None, :])

        adj_mat = adjacency_mat.float()
        max_neighbors = int(adj_mat.sum(dim = -1).max())

        if exists(self.num_neighbors_cutoff) and max_neighbors > self.num_neighbors_cutoff:
            noise = torch.empty((n, n), device = device).uniform_(-0.01, 0.01)
            adj_mat = adj_mat + noise
            adj_mask, adj_kv_indices = adj_mat.topk(dim = -1, k = self.num_neighbors_cutoff)
            adj_mask = (adj_mask > 0.5).float()
        else:
            adj_mask, adj_kv_indices = adj_mat.topk(dim = -1, k = max_neighbors)
        for attn, _ in self.layers:
            x = attn(
                x,
                adj_kv_indices = adj_kv_indices,
                mask = adj_mask
            )


        return x
    

class LAttentionV2(nn.Module):
    def __init__(
        self,
        dim,
        heads = 8,
        dim_head = 64,
        and_self_attend = False
    ):
        super().__init__()
        inner_dim = heads * dim_head
        self.heads = heads
        self.scale = dim_head ** -0.5

        self.and_self_attend = and_self_attend

        self.to_q = nn.Linear(dim, inner_dim, bias = False)
        self.to_kv = nn.Linear(dim, inner_dim * 2, bias = False)
        self.to_out = nn.Linear(inner_dim, dim, bias = False)

    def forward(
        self,
        x,
        context,
        mask = None
    ):
        h, scale = self.heads, self.scale
        q, k, v = (self.to_q(x), *self.to_kv(context).chunk(2, dim = -1))

        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, k, v))
        dots = einsum('b h i d, b h j d -> b h i j', q, k) * scale

        if exists(mask):
            mask_value = -torch.finfo(dots.dtype).max
            mask = rearrange(mask, 'b n -> b 1 1 n')
            dots.masked_fill_(~mask, mask_value)

        attn = dots.softmax(dim = -1)
        out = einsum('b h i j, b h j d -> b h i d', attn, v)

        out = rearrange(out, 'b h n d -> b n (h d)', h = h)
        return self.to_out(out)

class LocalLatentsAttent(nn.Module):
    def __init__(
        self,
        dim,
        heads = 8,
        num_latents = 64,
        latent_self_attend = False
    ):
        super().__init__()
        self.latents = nn.Parameter(torch.randn(num_latents, dim))
        self.attn1 = LAttentionV2(dim, heads, and_self_attend = latent_self_attend)
        self.attn2 = LAttentionV2(dim, heads)

    def forward(self, x, latents = None, mask = None):
        b, *_ = x.shape

        latents = self.latents

        if latents.ndim == 2:
            latents = repeat(latents, 'n d -> b n d', b = b)

        latents = self.attn1(latents, x, mask = mask)
        out     = self.attn2(x, latents)

        return out, latents
    
class LocalAttenV2(nn.Module):
    def __init__(
        self,
        dim,
        depth,
        heads = 8,
        num_latents = 64,
        ff_dropout = 0.
    ):
        super().__init__()
        self.layers = nn.ModuleList([])

        for _ in range(depth):
            global_attn = PreNorm(dim, LocalLatentsAttent(
                dim = dim,
                heads = heads,
                num_latents = num_latents
            )) 

            self.layers.append(nn.ModuleList([
                global_attn,
                Residual(PreNorm(dim, FeedForward(
                    dim = dim,
                    dropout = ff_dropout
                )))
            ]))

    def forward(self, x, mask = None):
        for attn, ff in self.layers:
            out, _ = attn(x, mask = mask)
            x = x + out
            x = ff(x)
        return x
    
class GlobalLocalAttention(nn.Module):
    def __init__(
        self,
        dim,
        depth,
        dim_head = 64,
        heads = 8,
        num_neighbors_cutoff = None,
        attn_dropout = 0.1,
        ff_dropout=0.,
    ):
        super().__init__()
        self.num_neighbors_cutoff = num_neighbors_cutoff
        self.layers = nn.ModuleList([])

        for _ in range(depth):
            global_attn = PreNorm(dim, LocalLatentsAttent(
                dim = dim,
                heads = heads,
                num_latents = 64
            )) 
        
            self.layers.append(nn.ModuleList([
                global_attn,
                Residual(PreNorm(dim, NMatrixAttention(
                    dim = dim,
                    dim_head = dim_head,
                    heads = heads,
                    dropout = attn_dropout
                ))),
                Residual(PreNorm(dim, FeedForward(
                    dim = dim,
                    dropout = ff_dropout
                )))
            ]))

    def forward(self, x, adjacency_mat, mask = None):
        device, n = x.device, x.shape[1]

        diag = torch.eye(adjacency_mat.shape[-1], device = device).bool()
        adjacency_mat |= diag
        if exists(mask):
            adjacency_mat &= (mask[:, :, None] * mask[:, None, :])

        adj_mat = adjacency_mat.float()
        max_neighbors = int(adj_mat.sum(dim = -1).max())

        if exists(self.num_neighbors_cutoff) and max_neighbors > self.num_neighbors_cutoff:
            noise = torch.empty((n, n), device = device).uniform_(-0.01, 0.01)
            adj_mat = adj_mat + noise
            adj_mask, adj_kv_indices = adj_mat.topk(dim = -1, k = self.num_neighbors_cutoff)
            adj_mask = (adj_mask > 0.5).float()
        else:
            adj_mask, adj_kv_indices = adj_mat.topk(dim = -1, k = max_neighbors)
        
        for attn, locla_attn, ff in self.layers:
            x, _ = attn(x, mask = mask)
            out = locla_attn(
                x,
                adj_kv_indices = adj_kv_indices,
                mask = adj_mask
            )
            x = x + out
            x = ff(x)
            
        return x
    

class LocalGlobalAttention(nn.Module):
    def __init__(
        self,
        dim,
        depth,
        dim_head = 64,
        heads = 8,
        num_neighbors_cutoff = None,
        attn_dropout = 0.1,
        ff_dropout=0.,
    ):
        super().__init__()
        self.num_neighbors_cutoff = num_neighbors_cutoff
        self.layers = nn.ModuleList([])

        for _ in range(depth):
            global_attn = PreNorm(dim, LocalLatentsAttent(
                dim = dim,
                heads = heads,
                num_latents = 64
            )) 
        
            self.layers.append(nn.ModuleList([
                Residual(PreNorm(dim, NMatrixAttention(
                    dim = dim,
                    dim_head = dim_head,
                    heads = heads,
                    dropout = attn_dropout
                ))),
                global_attn,
                Residual(PreNorm(dim, FeedForward(
                    dim = dim,
                    dropout = ff_dropout
                )))
            ]))

    def forward(self, x, adjacency_mat, mask = None):
        device, n = x.device, x.shape[1]

        diag = torch.eye(adjacency_mat.shape[-1], device = device).bool()
        adjacency_mat |= diag
        if exists(mask):
            adjacency_mat &= (mask[:, :, None] * mask[:, None, :])

        adj_mat = adjacency_mat.float()
        max_neighbors = int(adj_mat.sum(dim = -1).max())

        if exists(self.num_neighbors_cutoff) and max_neighbors > self.num_neighbors_cutoff:
            noise = torch.empty((n, n), device = device).uniform_(-0.01, 0.01)
            adj_mat = adj_mat + noise
            adj_mask, adj_kv_indices = adj_mat.topk(dim = -1, k = self.num_neighbors_cutoff)
            adj_mask = (adj_mask > 0.5).float()
        else:
            adj_mask, adj_kv_indices = adj_mat.topk(dim = -1, k = max_neighbors)
        
        for attn, locla_attn, ff in self.layers:
            x = attn(
                x,
                adj_kv_indices = adj_kv_indices,
                mask = adj_mask
            )
            out, _ = locla_attn(x, mask = mask)
            x = x + out
            x = ff(x)
            
        return x
    
class RMSNorm(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.scale = dim ** 0.5
        self.gamma = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        normed = F.normalize(x, dim = -1)
        return normed * self.scale * self.gamma
    
class gAttention(nn.Module):
    def __init__(
        self,
        dim,
        dim_head = 64,
        heads = 8
    ):
        super().__init__()
        self.heads = heads
        self.scale = dim_head ** -0.5
        dim_hidden = dim_head * heads

        self.norm = RMSNorm(dim)
        self.to_q = nn.Linear(dim, dim_hidden, bias = False)
        self.to_kv = nn.Linear(dim, dim_hidden * 2, bias = False)
        self.to_out = nn.Linear(dim_hidden, dim, bias = False)

    def forward(
        self,
        x,
        context = None,
        mask = None,
    ):
        h = self.heads
        x = self.norm(x)
        if exists(context):
            context = self.norm(context)
        context = default(context, x)
        q, k, v = (self.to_q(x), *self.to_kv(context).chunk(2, dim = -1))
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), (q, k, v))
        q = q * self.scale
        sim = einsum('b h i d, b h j d -> b h i j', q, k)

        if exists(mask):
            mask = rearrange(mask, 'b j -> b 1 1 j')
            sim = sim.masked_fill(~mask, -torch.finfo(sim.dtype).max)

        attn = sim.softmax(dim = -1)

        out = einsum('b h i j, b h j d -> b h i d', attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)
    
class GlobalAttentionV5(nn.Module):
    def __init__(
        self,
        dim,
        depth,
        heads = 8,
        dim_heads = 64,
        ff_dropout = 0.
    ):
        super().__init__()
        self.layers = nn.ModuleList([])

        for _ in range(depth):
            global_attn = gAttention(
                dim = dim,
                heads = heads,
                dim_head = dim_heads,
            )

            self.layers.append(nn.ModuleList([
                global_attn,
                Residual(PreNorm(dim, FeedForward(
                    dim = dim,
                    dropout = ff_dropout
                )))
            ]))

    def forward(self, x, mask = None, context = None):
        for attn, ff in self.layers:
            out = attn(x, mask = mask, context = context)
            x = x + out
            x = ff(x)
        return x

class BeDropPath(nn.Module):
    def __init__(self, drop_prob=None):
        super(BeDropPath, self).__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        return drop_path(x, self.drop_prob, self.training)
    
    def extra_repr(self) -> str:
        return 'p={}'.format(self.drop_prob)
    
class BeMLP(nn.Module):
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        # x = self.drop(x)
        # commit this for the orignal BERT implement 
        x = self.fc2(x)
        x = self.drop(x)
        return x

#BEiTv2 Beblock
class BeBlock(nn.Module):
    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., init_values=None, act_layer=nn.GELU, norm_layer=nn.LayerNorm,
                 window_size=None, attn_head_dim=None, **kwargs):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = nn.MultiheadAttention(dim, num_heads, dropout=drop, batch_first=True)
        self.drop_path = BeDropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = BeMLP(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)

        if init_values is not None:
            self.gamma_1 = nn.Parameter(init_values * torch.ones((dim)),requires_grad=True)
            self.gamma_2 = nn.Parameter(init_values * torch.ones((dim)),requires_grad=True)
        else:
            self.gamma_1, self.gamma_2 = None, None

    def forward(self, x, attn_mask=None, key_padding_mask=None):
        if self.gamma_1 is None:
            xn = self.norm1(x)
            x = x + self.drop_path(self.attn(xn,xn,xn,
                            attn_mask=attn_mask,
                            key_padding_mask=key_padding_mask,
                            need_weights=False)[0])
            x = x + self.drop_path(self.mlp(self.norm2(x)))
        else:
            xn = self.norm1(x)
            x = x + self.drop_path(self.gamma_1 * self.drop_path(self.attn(xn,xn,xn,
                            attn_mask=attn_mask,
                            key_padding_mask=key_padding_mask,
                            need_weights=False)[0]))
            x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x)))
        return x


class BeDeepIceModel(nn.Module):
    def __init__(self, dim=384, depth=12, out_class = 3, use_checkpoint=False, drop_b= 0., div_factor=64, attn_drop_b = 0., drop_path = 0.,  **kwargs):
        super().__init__()
        self.Beblocks = nn.ModuleList([ 
            BeBlock(
                dim=dim, num_heads=dim//div_factor, mlp_ratio=4, drop_path=drop_path, init_values=1, attn_drop=attn_drop_b, drop=drop_b)
            for i in range(depth)])
        #self.Beblocks = nn.ModuleList([ 
        #    nn.TransformerEncoderLayer(dim,dim//64,dim*4,dropout=0,
        #        activation=nn.GELU(), batch_first=True, norm_first=True)
        #    for i in range(depth)])
        self.out_class = out_class
        self.proj_out = nn.Linear(dim,out_class) if out_class == 3 else nn.Identity()
        self.use_checkpoint = use_checkpoint
        self.apply(self._init_weights)

    def fix_init_weight(self):
        def rescale(param, layer_id):
            param.div_(math.sqrt(2.0 * layer_id))

        for layer_id, layer in enumerate(self.Beblocks):
            rescale(layer.attn.proj.weight.data, layer_id + 1)
            rescale(layer.mlp.fc2.weight.data, layer_id + 1)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def init_weights(self, pretrained=None):
        def _init_weights(m):
            if isinstance(m, nn.Linear):
                trunc_normal_(m.weight, std=.02)
                if isinstance(m, nn.Linear) and m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.LayerNorm):
                nn.init.constant_(m.bias, 0)
                nn.init.constant_(m.weight, 1.0)
        self.apply(_init_weights)
    
    def forward(self, x, mask):
        attn_mask = torch.zeros(mask.shape, device=mask.device)
        attn_mask[~mask] = -torch.inf
        for blk in self.Beblocks:
            if self.use_checkpoint:
                x = checkpoint.checkpoint(blk, x, None, attn_mask)
            else: x = blk(x, None, attn_mask)
        if self.out_class == 3:
            x = self.proj_out(x[:,0]) #cls token
        return x
    
    
class EncoderWithDirectionReconstructionV11(nn.Module):
    def __init__(self, dim_out=256 + 64, drop_path=0.):
        super().__init__()
        self.fe = ExtractorV0(dim=dim_out, dim_base=96)
        self.encoder = BeDeepIceModel(dim_out, drop_path=drop_path)
        self.cls_token = nn.Linear(dim_out,1,bias=False)
        self.loacl_attn = LocalAttenNetwok(dim = dim_out, depth = 3, num_neighbors_cutoff = 24)
        trunc_normal_(self.cls_token.weight, std=.02)

    def forward(self, batch):
        mask = batch["mask"] #bs, seq_len
        bs = mask.shape[0] # int
        pos = batch["pos"][mask] 
        mask = mask[:,:mask.sum(-1).max()] 
        batch_index = mask.nonzero()[:,0] 
        edge_index = knn_graph(x = pos, k=8, batch=batch_index).to(mask.device)
        adj_matrix = to_dense_adj(edge_index, batch_index).int()
        x = self.fe(batch, mask.sum(-1).max())
        x = self.loacl_attn(x, adj_matrix, mask)
        cls_token = self.cls_token.weight.unsqueeze(0).expand(bs,-1,-1)
        x = torch.cat([cls_token,x],1)
        mask = torch.cat([torch.ones(bs, 1, dtype=torch.bool, device=x.device), mask], dim=1)
        x = self.encoder(x, mask=mask)
        return x
    
    
class EncoderWithDirectionReconstructionV11_V2_GLOBAL_LOCAL(nn.Module):
    def __init__(self, dim_out=256 + 64, drop_path=0.):
        super().__init__()
        self.fe = ExtractorV0(dim=dim_out, dim_base=96)
        self.encoder = BeDeepIceModel(dim_out, drop_path=drop_path)
        self.cls_token = nn.Linear(dim_out,1,bias=False)
        self.loacl_attn = GlobalLocalAttention(dim = dim_out, depth = 3, num_neighbors_cutoff = 24)
        trunc_normal_(self.cls_token.weight, std=.02)

    def forward(self, batch):
        mask = batch["mask"] #bs, seq_len
        bs = mask.shape[0] # int
        pos = batch["pos"][mask] 
        mask = mask[:,:mask.sum(-1).max()] 
        batch_index = mask.nonzero()[:,0] 
        edge_index = knn_graph(x = pos, k=8, batch=batch_index).to(mask.device)
        adj_matrix = to_dense_adj(edge_index, batch_index).int()
        x = self.fe(batch, mask.sum(-1).max())
        x = self.loacl_attn(x, adj_matrix, mask)
        cls_token = self.cls_token.weight.unsqueeze(0).expand(bs,-1,-1)
        x = torch.cat([cls_token,x],1)
        mask = torch.cat([torch.ones(bs, 1, dtype=torch.bool, device=x.device), mask], dim=1)
        x = self.encoder(x, mask=mask)
        return x
    
    
class EncoderWithDirectionReconstructionV11_V2_LOCAL_GLOBAL(nn.Module):
    def __init__(self, dim_out=256 + 64, drop_path=0.):
        super().__init__()
        self.fe = ExtractorV0(dim=dim_out, dim_base=96)
        self.encoder = BeDeepIceModel(dim_out, drop_path=drop_path)
        self.cls_token = nn.Linear(dim_out,1,bias=False)
        self.loacl_attn = LocalGlobalAttention(dim = dim_out, depth = 3, num_neighbors_cutoff = 24)
        trunc_normal_(self.cls_token.weight, std=.02)

    def forward(self, batch):
        mask = batch["mask"] #bs, seq_len
        bs = mask.shape[0] # int
        pos = batch["pos"][mask] 
        mask = mask[:,:mask.sum(-1).max()] 
        batch_index = mask.nonzero()[:,0] 
        edge_index = knn_graph(x = pos, k=8, batch=batch_index).to(mask.device)
        adj_matrix = to_dense_adj(edge_index, batch_index).int()
        x = self.fe(batch, mask.sum(-1).max())
        x = self.loacl_attn(x, adj_matrix, mask)
        cls_token = self.cls_token.weight.unsqueeze(0).expand(bs,-1,-1)
        x = torch.cat([cls_token,x],1)
        mask = torch.cat([torch.ones(bs, 1, dtype=torch.bool, device=x.device), mask], dim=1)
        x = self.encoder(x, mask=mask)
        return x
    
    
class EncoderWithDirectionReconstructionV12(nn.Module):
    def __init__(self, dim_out=256 + 64):
        super().__init__()
        self.fe = ExtractorV0(dim=dim_out, dim_base=96)
        self.encoder = BeDeepIceModel(dim_out)
        self.cls_token = nn.Linear(dim_out,1,bias=False)
        self.loacl_attn = MATAdjusted(model_dim = dim_out, depth = 3)
        trunc_normal_(self.cls_token.weight, std=.02)

    def forward(self, batch):
        mask = batch["mask"] #bs, seq_len
        bs = mask.shape[0] # int
        xyz = batch["pos"][mask]
        mask = mask[:,:mask.sum(-1).max()] 
        batch_index = mask.nonzero()[:,0] 
        edge_index = knn_graph(x = xyz, k=8, batch=batch_index).to(mask.device)
        adj_matrix = to_dense_adj(edge_index, batch_index).int()
        x = self.fe(batch, mask.sum(-1).max())
        x = self.loacl_attn(x, adj_matrix, mask)
        cls_token = self.cls_token.weight.unsqueeze(0).expand(bs,-1,-1)
        x = torch.cat([cls_token,x],1)
        mask = torch.cat([torch.ones(bs, 1, dtype=torch.bool, device=x.device), mask], dim=1)
        x = self.encoder(x, mask=mask)
        return x
    
class EncoderWithDirectionReconstructionV12_V2(nn.Module):
    def __init__(self, dim_out=256 + 64):
        super().__init__()
        self.fe = ExtractorV0(dim=dim_out, dim_base=96)
        self.encoder = ContinuousTransformerWrapper(
            dim_out=dim_out,
            max_seq_len=256,
            post_emb_norm = True,
            use_abs_pos_emb = False, 
            attn_layers=Encoder(dim=dim_out,
                                depth=8,
                                heads=8,
                                use_rmsnorm = True,
                                ff_glu = True,
                                rel_pos_bias = True,
                                deepnorm=True)   
        )
    
        self.cls_token = nn.Linear(dim_out,1,bias=False)
        self.loacl_attn = MATAdjusted(model_dim = dim_out, depth = 3)
        self.out = nn.Linear(dim_out, 3)
        trunc_normal_(self.cls_token.weight, std=.02)
        

    def forward(self, batch):
        mask = batch["mask"] #bs, seq_len
        bs = mask.shape[0] # int
        xyz = batch["pos"][mask]
        mask = mask[:,:mask.sum(-1).max()] 
        batch_index = mask.nonzero()[:,0] 
        edge_index = knn_graph(x = xyz, k=8 ,batch=batch_index).to(mask.device)
        adj_matrix = to_dense_adj(edge_index, batch_index).int()
        x = self.fe(batch, mask.sum(-1).max())
        x = self.loacl_attn(x, adj_matrix, mask)
        cls_token = self.cls_token.weight.unsqueeze(0).expand(bs,-1,-1)
        x = torch.cat([cls_token,x],1)
        mask = torch.cat([torch.ones(bs, 1, dtype=torch.bool, device=x.device), mask], dim=1)
        x = self.encoder(x, mask=mask)
        x = x[:, 0]
        return self.out(x)
    
    
class EncoderWithDirectionReconstructionV13(nn.Module):
    def __init__(self, dim_out=256):
        super().__init__()
        self.encoder = BeDeepIceModel(dim_out)
        self.cls_token = nn.Linear(dim_out,1,bias=False)
        self.graphnet = DynEdgeFEXTRACTRO(dim_out + 4)
        self.fe = ExtractorV0(dim=dim_out, dim_base=96)
        trunc_normal_(self.cls_token.weight, std=.02)

    def forward(self, batch):
        mask = batch["mask"] #bs, seq_len
        bs = mask.shape[0] # int
        xyzt = torch.concat([batch["pos"][mask] , batch['time'][mask].view(-1, 1)], dim=1)
        mask = mask[:,:mask.sum(-1).max()] 
        batch_index = mask.nonzero()[:,0] 
        edge_index = knn_graph(x = xyzt, k=12, batch=batch_index).to(mask.device)
        x = self.fe(batch, mask.sum(-1).max())
        x = x[mask]
        x = torch.cat([x, xyzt], dim=1)
        x, _, _ = self.graphnet(x, edge_index, batch_index, mask.sum(-1))
        x, mask = to_dense_batch(x, batch_index)
        cls_token = self.cls_token.weight.unsqueeze(0).expand(bs,-1,-1)
        x = torch.cat([cls_token,x],1)
        mask = torch.cat([torch.ones(bs, 1, dtype=torch.bool, device=x.device), mask], dim=1)
        x = self.encoder(x, mask=mask)
        
        return x
    
    
class EncoderWithDirectionReconstructionV14(nn.Module):
    def __init__(self, dim_out=256 + 64, drop_path=0.):
        super().__init__()
        self.fe = ExtractorV0(dim=dim_out, dim_base=96)
        self.encoder = BeDeepIceModel(dim_out, drop_path=drop_path)
        self.cls_token = nn.Linear(dim_out,1,bias=False)
        self.loacl_attn = LocalAttenV2(dim = dim_out, depth = 4)
        trunc_normal_(self.cls_token.weight, std=.02)

    def forward(self, batch):
        mask = batch["mask"] #bs, seq_len
        bs = mask.shape[0] # int
        mask = mask[:,:mask.sum(-1).max()] 
        x = self.fe(batch, mask.sum(-1).max())
        x = self.loacl_attn(x, mask)
        cls_token = self.cls_token.weight.unsqueeze(0).expand(bs,-1,-1)
        x = torch.cat([cls_token,x],1)
        mask = torch.cat([torch.ones(bs, 1, dtype=torch.bool, device=x.device), mask], dim=1)
        x = self.encoder(x, mask=mask)
        return x
    
    
class EncoderWithDirectionReconstructionV15(nn.Module):
    def __init__(self, dim_out=256 + 64, drop_path=0.):
        super().__init__()
        self.fe = ExtractorV0(dim=dim_out, dim_base=96)
        self.encoder = BeDeepIceModel(dim_out, drop_path=drop_path)
        self.cls_token = nn.Linear(dim_out,1,bias=False)
        self.loacl_attn = LocalAttenV2(dim = dim_out, depth = 5)
        trunc_normal_(self.cls_token.weight, std=.02)

    def forward(self, batch):
        mask = batch["mask"] #bs, seq_len
        bs = mask.shape[0] # int
        mask = mask[:,:mask.sum(-1).max()] 
        x = self.fe(batch, mask.sum(-1).max())
        cls_token = self.cls_token.weight.unsqueeze(0).expand(bs,-1,-1)
        x = torch.cat([cls_token,x],1)
        mask = torch.cat([torch.ones(bs, 1, dtype=torch.bool, device=x.device), mask], dim=1)
        x = self.loacl_attn(x, mask)
        x = self.encoder(x, mask=mask)
        return x
    
    
def get_ds_matrix(x, batch_index, mask, Lmax=None):
    pos = x['pos'] if Lmax is None else x['pos'][:,:Lmax]
    time = x['time'] if Lmax is None else x['time'][:,:Lmax]
    ds2 = (pos[:,:,None] - pos[:,None,:]).pow(2).sum(-1) - \
                ((time[:,:,None] - time[:,None,:])*(3e4/500*3e-1)).pow(2)
    d = torch.sign(ds2)*torch.sqrt(torch.abs(ds2))
    edge_index = knn_graph(x = d[mask], k=8, batch=batch_index).to(mask.device)
    return edge_index
    
    
class EncoderWithDirectionReconstructionV16(nn.Module):
    def __init__(self, dim_out=256, drop_path=0.):
        super().__init__()
        self.fe = ExtractorV0(dim=dim_out, dim_base=32)
        self.encoder = BeDeepIceModel(dim_out , drop_path=drop_path)
        self.cls_token = nn.Linear(dim_out,1,bias=False)
        self.local_root= EGNNModeLFEAT( emb_dim=dim_out, num_layers=2)
        self.global_root =  LocalAttenV2(dim = dim_out, depth =2)
        self.gl_lc = GlobalAttentionV5(dim = dim_out, depth = 1)
        self.lc_gl = GlobalAttentionV5(dim = dim_out, depth = 1)
        trunc_normal_(self.cls_token.weight, std=.02)

    def forward(self, batch):
        mask = batch["mask"] #bs, seq_len
        bs = mask.shape[0] # int
        pos = batch["pos"][mask] 
        mask = mask[:,:mask.sum(-1).max()] 
        batch_index = mask.nonzero()[:,0] 
        edge_index = get_ds_matrix(batch, batch_index, mask, Lmax=mask.sum(-1).max())
        x = self.fe(batch, mask.sum(-1).max())
        
        graph_featutre = self.local_root(x[mask], pos, edge_index)
        graph_featutre, mask = to_dense_batch(graph_featutre, batch_index)
        global_featutre = self.global_root(x, mask)
        x = self.gl_lc(graph_featutre, mask, context = global_featutre) + self.lc_gl(global_featutre, mask, context =graph_featutre )
        
        cls_token = self.cls_token.weight.unsqueeze(0).expand(bs,-1,-1)
        x = torch.cat([cls_token,x],1)
        mask = torch.cat([torch.ones(bs, 1, dtype=torch.bool, device=x.device), mask], dim=1)
        x = self.encoder(x, mask=mask)
        return x



class EncoderWithDirectionReconstructionV17(nn.Module):
    def __init__(self, dim_out=256, drop_path=0.):
        super().__init__()
        self.fe = ExtractorV0(dim=dim_out//2, dim_base=32)
        self.encoder = BeDeepIceModel(dim_out , drop_path=drop_path)
        self.cls_token = nn.Linear(dim_out,1,bias=False)
        self.local_root= EGNNModeLFEAT( emb_dim=dim_out//2, num_layers=3)
        self.global_root =  LocalAttenV2(dim = dim_out//2, depth =4)
        trunc_normal_(self.cls_token.weight, std=.02)

    def forward(self, batch):
        mask = batch["mask"] #bs, seq_len
        bs = mask.shape[0] # int
        pos = batch["pos"][mask] 
        mask = mask[:,:mask.sum(-1).max()] 
        batch_index = mask.nonzero()[:,0] 
        edge_index = get_ds_matrix(batch, batch_index, mask, Lmax=mask.sum(-1).max())
        x = self.fe(batch, mask.sum(-1).max())
        
        graph_featutre = self.local_root(x[mask], pos, edge_index)
        graph_featutre, mask = to_dense_batch(graph_featutre, batch_index)
        global_featutre = self.global_root(x, mask)
        x = torch.cat([global_featutre, graph_featutre],2)
        cls_token = self.cls_token.weight.unsqueeze(0).expand(bs,-1,-1)
        x = torch.cat([cls_token,x],1)
        mask = torch.cat([torch.ones(bs, 1, dtype=torch.bool, device=x.device), mask], dim=1)
        x = self.encoder(x, mask=mask)
        return x
    
    
class EncoderWithDirectionReconstructionV18(nn.Module):
    def __init__(self, dim_out=256, drop_path=0.):
        super().__init__()
        self.fe = ExtractorV0(dim=dim_out//2, dim_base=32)
        self.encoder = BeDeepIceModel(dim_out , drop_path=drop_path)
        self.cls_token = nn.Linear(dim_out,1,bias=False)
        self.local_root= DynEdgeFEXTRACTRO(9, 
                                           post_processing_layer_sizes = [336, dim_out//2], 
                                           dynedge_layer_sizes = [(128, 256), (336, 256), (336, 256), (336, 256)])
        self.global_root =  LocalAttenV2(dim = dim_out//2, depth =4)
        trunc_normal_(self.cls_token.weight, std=.02)

    def forward(self, batch):
        mask = batch["mask"] #bs, seq_len
        graph_featutre = torch.concat([batch["pos"][mask] , 
                             batch['time'][mask].view(-1, 1),
                             batch['auxiliary'][mask].view(-1, 1),
                             batch['qe'][mask].view(-1, 1),
                             batch['charge'][mask].view(-1, 1),
                             batch["ice_properties"][mask], 
                              ], dim=1)
        bs = mask.shape[0] # int
        mask = mask[:,:mask.sum(-1).max()] 
        batch_index = mask.nonzero()[:,0] 
        edge_index = knn_graph(x = graph_featutre[:,:3], k=8, batch=batch_index).to(mask.device)
        x = self.fe(batch, mask.sum(-1).max())
        
        graph_featutre, _, _ = self.local_root(graph_featutre, edge_index, batch_index, mask.sum(-1))
        graph_featutre, mask = to_dense_batch(graph_featutre, batch_index)
        global_featutre = self.global_root(x, mask)
        x = torch.cat([global_featutre, graph_featutre],2)
        cls_token = self.cls_token.weight.unsqueeze(0).expand(bs,-1,-1)
        x = torch.cat([cls_token,x],1)
        mask = torch.cat([torch.ones(bs, 1, dtype=torch.bool, device=x.device), mask], dim=1)
        x = self.encoder(x, mask=mask)
        return x
    
    
    
class EncoderWithDirectionReconstructionV19(nn.Module):
    def __init__(self, dim_out=256, drop_path=0.):
        super().__init__()
        self.fe = ExtractorV0(dim=dim_out, dim_base=32)
        self.encoder = BeDeepIceModel(dim_out, drop_path=drop_path)
        self.cls_token = nn.Linear(dim_out,1,bias=False)
        self.gl_attn = GlobalAttentionV5(dim = dim_out, depth = 12)
        trunc_normal_(self.cls_token.weight, std=.02)

    def forward(self, batch):
        mask = batch["mask"] #bs, seq_len
        bs = mask.shape[0] # int
        mask = mask[:,:mask.sum(-1).max()] 
        x = self.fe(batch, mask.sum(-1).max())
        x = self.gl_attn(x, mask)
        cls_token = self.cls_token.weight.unsqueeze(0).expand(bs,-1,-1)
        x = torch.cat([cls_token,x],1)
        mask = torch.cat([torch.ones(bs, 1, dtype=torch.bool, device=x.device), mask], dim=1)
        x = self.encoder(x, mask=mask)
        return x
    

In [181]:
#| export
class DropPath(nn.Module):
    def __init__(self, drop_prob=None):
        super(DropPath, self).__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        return drop_path(x, self.drop_prob, self.training)
    
    def extra_repr(self) -> str:
        return 'p={}'.format(self.drop_prob)
    
class Mlp(nn.Module):
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        # x = self.drop(x)
        # commit this for the orignal BERT implement 
        x = self.fc2(x)
        x = self.drop(x)
        return x

#BEiTv2 block
class Block(nn.Module):
    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., init_values=None, act_layer=nn.GELU, norm_layer=nn.LayerNorm,
                 window_size=None, attn_head_dim=None, **kwargs):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = nn.MultiheadAttention(dim, num_heads, dropout=drop, batch_first=True)
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)

        if init_values is not None:
            self.gamma_1 = nn.Parameter(init_values * torch.ones((dim)),requires_grad=True)
            self.gamma_2 = nn.Parameter(init_values * torch.ones((dim)),requires_grad=True)
        else:
            self.gamma_1, self.gamma_2 = None, None

    def forward(self, x, attn_mask=None, key_padding_mask=None):
        if self.gamma_1 is None:
            xn = self.norm1(x)
            x = x + self.drop_path(self.attn(xn,xn,xn,
                            attn_mask=attn_mask,
                            key_padding_mask=key_padding_mask,
                            need_weights=False)[0])
            x = x + self.drop_path(self.mlp(self.norm2(x)))
        else:
            xn = self.norm1(x)
            x = x + self.drop_path(self.gamma_1 * self.attn(xn,xn,xn,
                            attn_mask=attn_mask,
                            key_padding_mask=key_padding_mask,
                            need_weights=False)[0])
            x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x)))
        return x

class Attention_rel(nn.Module):
    def __init__(
            self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0.,
            proj_drop=0.0, attn_head_dim=None):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        if attn_head_dim is not None:
            head_dim = attn_head_dim
        all_head_dim = head_dim * self.num_heads
        self.scale = qk_scale or head_dim ** -0.5

        self.proj_q = nn.Linear(dim, all_head_dim, bias=False)
        self.proj_k = nn.Linear(dim, all_head_dim, bias=False)
        self.proj_v = nn.Linear(dim, all_head_dim, bias=False)
        if qkv_bias:
            self.q_bias = nn.Parameter(torch.zeros(all_head_dim))
            self.v_bias = nn.Parameter(torch.zeros(all_head_dim))
        else:
            self.q_bias = None
            self.v_bias = None

        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(all_head_dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, q, k, v, rel_pos_bias=None, key_padding_mask=None):
        #rel_pos_bias: B L L C/h
        #key_padding_mask - float with -inf
        B, N, C = q.shape
        #qkv_bias = None
        #if self.q_bias is not None:
        #    qkv_bias = torch.cat((self.q_bias, torch.zeros_like(self.v_bias, requires_grad=False), self.v_bias))
        # qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        #qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)
        #qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
        #q, k, v = qkv[0], qkv[1], qkv[2]   # make torchscript happy (cannot use tensor as tuple)
        
        q = F.linear(input=q, weight=self.proj_q.weight, bias=self.q_bias)
        q = q.reshape(B, N, self.num_heads, -1).permute(0, 2, 1, 3)
        k = F.linear(input=k, weight=self.proj_k.weight, bias=None)
        k = k.reshape(B, k.shape[1], self.num_heads, -1).permute(0, 2, 1, 3)
        v = F.linear(input=v, weight=self.proj_v.weight, bias=self.v_bias)
        v = v.reshape(B, v.shape[1], self.num_heads, -1).permute(0, 2, 1, 3)

        q = q * self.scale
        attn = (q @ k.transpose(-2, -1))
        if rel_pos_bias is not None:
            bias = torch.einsum('bhic,bijc->bhij', q, rel_pos_bias)
            attn = attn + bias
        if key_padding_mask is not None:
            assert key_padding_mask.dtype == torch.float32 or key_padding_mask.dtype == torch.float16, \
                'incorrect mask dtype'
            bias = torch.min(key_padding_mask[:,None,:], key_padding_mask[:,:,None])
            bias[torch.max(key_padding_mask[:,None,:], key_padding_mask[:,:,None]) < 0] = 0
            #print(bias.shape,bias.min(),bias.max())
            attn = attn + bias.unsqueeze(1)
        
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2)
        if rel_pos_bias is not None:
            x = x + torch.einsum('bhij,bijc->bihc', attn, rel_pos_bias)
        x = x.reshape(B, N, -1)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x
    
#BEiTv2 block
class Block_rel(nn.Module):
    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., init_values=None, act_layer=nn.GELU, norm_layer=nn.LayerNorm,
                 window_size=None, attn_head_dim=None, **kwargs):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = Attention_rel(dim, num_heads, attn_drop=attn_drop, qkv_bias=qkv_bias)
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)

        if init_values is not None:
            self.gamma_1 = nn.Parameter(init_values * torch.ones((dim)),requires_grad=True)
            self.gamma_2 = nn.Parameter(init_values * torch.ones((dim)),requires_grad=True)
        else:
            self.gamma_1, self.gamma_2 = None, None

    def forward(self, x, key_padding_mask=None, rel_pos_bias=None, kv=None):
        if self.gamma_1 is None:
            xn = self.norm1(x)
            kv = xn if kv is None else self.norm1(kv)
            x = x + self.drop_path(self.attn(xn, kv, kv,
                            rel_pos_bias=rel_pos_bias,
                            key_padding_mask=key_padding_mask))
            x = x + self.drop_path(self.mlp(self.norm2(x)))
        else:
            xn = self.norm1(x)
            kv = xn if kv is None else self.norm1(kv)
            x = x + self.drop_path(self.gamma_1 * self.drop_path(self.attn(xn, kv, kv,
                            rel_pos_bias=rel_pos_bias,
                            key_padding_mask=key_padding_mask)))
            x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x)))
        return x

    
class SinusoidalPosEmb(nn.Module):
    def __init__(self, dim=16, M=10000):
        super().__init__()
        self.dim = dim
        self.M = M

    def forward(self, x):
        device = x.device
        half_dim = self.dim // 2
        emb = math.log(self.M) / half_dim
        emb = torch.exp(torch.arange(half_dim, device=device) * (-emb))
        emb = x[...,None] * emb[None,...]
        emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
        return emb
    
class ExtractorV11(nn.Module):
    def __init__(self, dim_base=128, dim=384):
        super().__init__()
        self.emb = SinusoidalPosEmb(dim=dim_base)
        self.aux_emb = nn.Embedding(2,dim_base//2)
        self.proj = nn.Linear(11*dim_base//2,dim)
        
    def forward(self, x, Lmax=None):
        pos = x['pos'] if Lmax is None else x['pos'][:,:Lmax]
        charge = x['charge'] if Lmax is None else x['charge'][:,:Lmax]
        time = x['time'] if Lmax is None else x['time'][:,:Lmax]
        auxiliary = x['auxiliary'] if Lmax is None else x['auxiliary'][:,:Lmax]
        qe = x['qe'] if Lmax is None else x['qe'][:,:Lmax]
        ice_properties = x['ice_properties'] if Lmax is None else x['ice_properties'][:,:Lmax]
        
        x = torch.cat([self.emb(4096*pos).flatten(-2), self.emb(1024*charge),
                       self.emb(4096*time),self.aux_emb(auxiliary)
                      ],-1)
        x = self.proj(x)
        return x
    
class ScaledSinusoidalEmbedding(nn.Module):
    def __init__(self, dim=32, M = 10000):
        super().__init__()
        assert (dim % 2) == 0
        self.scale = nn.Parameter(torch.ones(1) * dim ** -0.5)
        self.dim = dim
        self.M = M

    def forward(self, x):
        device = x.device
        half_dim = self.dim // 2
        emb = math.log(self.M) / half_dim
        emb = torch.exp(torch.arange(half_dim, device=device) * (-emb))
        emb = x[...,None] * emb[None,...]
        emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
        return emb * self.scale

class ExtractorV11Scaled(nn.Module):
    def __init__(self, dim_base=128, dim=384):
        super().__init__()
        self.pos = ScaledSinusoidalEmbedding(dim=dim_base)
        self.emb_charge = ScaledSinusoidalEmbedding(dim=dim_base)
        self.time = ScaledSinusoidalEmbedding(dim=dim_base)
        self.aux_emb = nn.Embedding(2,dim_base//2)
        self.emb2 = ScaledSinusoidalEmbedding(dim=dim_base//2)
        self.proj = nn.Sequential(nn.Linear(6*dim_base,6*dim_base),nn.LayerNorm(6*dim_base),
                                  nn.GELU(),nn.Linear(6*dim_base,dim))
        
    def forward(self, x, Lmax=None):
        pos = x['pos'] if Lmax is None else x['pos'][:,:Lmax]
        charge = x['charge'] if Lmax is None else x['charge'][:,:Lmax]
        time = x['time'] if Lmax is None else x['time'][:,:Lmax]
        auxiliary = x['auxiliary'] if Lmax is None else x['auxiliary'][:,:Lmax]
        length = torch.log10(x['L0'].to(dtype=pos.dtype))

        
        x = torch.cat([self.pos(4096*pos).flatten(-2), self.emb_charge(1024*charge),
                       self.time(4096*time),self.aux_emb(auxiliary), 
                       self.emb2(length).unsqueeze(1).expand(-1,pos.shape[1],-1)
                      ],-1)
        x = self.proj(x)
        return x

    
class Rel_ds(nn.Module):
    def __init__(self, dim=32):
        super().__init__()
        self.emb = SinusoidalPosEmb(dim=dim)
        self.proj = nn.Linear(dim,dim)
        
    def forward(self, x, Lmax=None):
        pos = x['pos'] if Lmax is None else x['pos'][:,:Lmax]
        time = x['time'] if Lmax is None else x['time'][:,:Lmax]
        ds2 = (pos[:,:,None] - pos[:,None,:]).pow(2).sum(-1) - \
                ((time[:,:,None] - time[:,None,:])*(3e4/500*3e-1)).pow(2)
        d = torch.sign(ds2)*torch.sqrt(torch.abs(ds2))
        emb = self.emb(1024*d.clip(-4,4))
        rel_attn = self.proj(emb)
        return rel_attn,emb

def get_nbs(x, Lmax=None, K=8):
    pos = x['pos'] if Lmax is None else x['pos'][:,:Lmax]
    mask = x['mask'][:,:Lmax]
    B = pos.shape[0]
    
    d = -torch.cdist(pos, pos, p=2) 
    d -= 100*(~torch.min(mask[:,None,:],mask[:,:,None]))
    d -= 200*torch.eye(Lmax, dtype=pos.dtype, device=pos.device).unsqueeze(0)
    nbs = d.topk(K-1,dim=-1)[1]
    nbs = torch.cat([
            torch.arange(Lmax, dtype=nbs.dtype, device=nbs.device).unsqueeze(0).unsqueeze(-1).expand(B,-1,-1),
            nbs],-1)
    return nbs
    
class LocalBlock(nn.Module):
    def __init__(self, dim=192, num_heads=192//64, mlp_ratio=4, drop_path=0, init_values=1, **kwargs):
        super().__init__()
        self.proj_rel_bias = nn.Linear(dim//num_heads,dim//num_heads)
        self.block = Block_rel(dim=dim, num_heads=num_heads, mlp_ratio=mlp_ratio, 
                               drop_path=drop_path, init_values=init_values)
        
    def forward(self, x, nbs, key_padding_mask=None, rel_pos_bias=None):
        B,Lmax,C = x.shape
        mask = key_padding_mask if not (key_padding_mask is None) \
            else torch.ones(B, Lmax, dtype=torch.bool, device=x.deice)
        
        m = torch.gather(mask.unsqueeze(1).expand(-1,Lmax,-1), 2, nbs)
        attn_mask = torch.zeros(m.shape, device=m.device)
        attn_mask[~mask] = -torch.inf
        attn_mask = attn_mask[mask]
        
        if rel_pos_bias is not None:
            rel_pos_bias = torch.gather(rel_pos_bias, 2, 
                                        nbs.unsqueeze(-1).expand(-1,-1,-1,rel_pos_bias.shape[-1]))
            rel_pos_bias = rel_pos_bias[mask]
            rel_pos_bias = self.proj_rel_bias(rel_pos_bias).unsqueeze(1)
            
        xl = torch.gather(x.unsqueeze(1).expand(-1,Lmax,-1,-1), 2, nbs.unsqueeze(-1).expand(-1,-1,-1,C))
        xl = xl[mask]
        # modify only the node (0th element)
        #print(xl[:,:1].shape,rel_pos_bias.shape,attn_mask[:,:1].shape,xl.shape)
        xl = self.block(xl[:,:1], rel_pos_bias=rel_pos_bias, key_padding_mask=attn_mask[:,:1], kv=xl)
        x = torch.zeros(x.shape, device=x.device, dtype=xl.dtype)
        x[mask] = xl.squeeze(1)
        return x


    
    
class EncoderWithDirectionReconstructionV20(nn.Module):
    def __init__(self, dim=384, dim_base=128, depth=8, use_checkpoint=False, head_size=64, **kwargs):
        super().__init__()
        self.extractor = ExtractorV11(dim_base,dim)
        self.rel_pos = Rel_ds(head_size)
        self.sandwich = nn.ModuleList([ 
            Block_rel(dim=dim, num_heads=dim//head_size),
            Block_rel(dim=dim, num_heads=dim//head_size),
            Block_rel(dim=dim, num_heads=dim//head_size),
            Block_rel(dim=dim, num_heads=dim//head_size),
        ])
        self.cls_token = nn.Linear(dim * 2,1,bias=False)
        self.blocks = nn.ModuleList([ 
            Block(
                dim=(dim * 2), num_heads=(dim * 2) //head_size, mlp_ratio=4, drop_path=0.0*(i/(depth-1)), init_values=1,)
            for i in range(depth)])
        self.proj_out = nn.Linear(dim * 2,3)
        self.local_root= DynEdgeFEXTRACTRO(9, 
                                           post_processing_layer_sizes = [336, dim], 
                                           dynedge_layer_sizes = [(128, 256), (336, 256), (336, 256), (336, 256)])   
        self.use_checkpoint = use_checkpoint
        self.apply(self._init_weights)
        trunc_normal_(self.cls_token.weight, std=.02)

    def fix_init_weight(self):
        def rescale(param, layer_id):
            param.div_(math.sqrt(2.0 * layer_id))

        for layer_id, layer in enumerate(self.blocks):
            rescale(layer.attn.proj.weight.data, layer_id + 1)
            rescale(layer.mlp.fc2.weight.data, layer_id + 1)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def init_weights(self, pretrained=None):
        def _init_weights(m):
            if isinstance(m, nn.Linear):
                trunc_normal_(m.weight, std=.02)
                if isinstance(m, nn.Linear) and m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.LayerNorm):
                nn.init.constant_(m.bias, 0)
                nn.init.constant_(m.weight, 1.0)
        self.apply(_init_weights)
        
    @torch.jit.ignore
    def no_weight_decay(self):
        return {'cls_token'}
    
    def forward(self, x0):
        mask = x0['mask']
        graph_featutre = torch.concat([x0["pos"][mask] , 
                             x0['time'][mask].view(-1, 1),
                             x0['auxiliary'][mask].view(-1, 1),
                             x0['qe'][mask].view(-1, 1),
                             x0['charge'][mask].view(-1, 1),
                             x0["ice_properties"][mask], 
                              ], dim=1)
        Lmax = mask.sum(-1).max()
        x = self.extractor(x0, Lmax)
        rel_pos_bias, rel_enc = self.rel_pos(x0, Lmax)
        #nbs = get_nbs(x0, Lmax)
        mask = mask[:,:Lmax]
        batch_index = mask.nonzero()[:,0] 
        edge_index = knn_graph(x = graph_featutre[:,:3], k=8, batch=batch_index).to(mask.device)
        graph_featutre, _, _ = self.local_root(graph_featutre, edge_index, batch_index, x0['L0'])
        graph_featutre, _ = to_dense_batch(graph_featutre, batch_index)
        
        B,_ = mask.shape
        attn_mask = torch.zeros(mask.shape, device=mask.device)
        attn_mask[~mask] = -torch.inf
        
        for blk in self.sandwich:
            if isinstance(blk,LocalBlock):
                x = blk(x,nbs,mask,rel_enc)
            else:
                x = blk(x,attn_mask,rel_pos_bias)
                rel_pos_bias = None
        x = torch.cat([x,graph_featutre],2)
        mask = torch.cat([torch.ones(B,1,dtype=mask.dtype, device=mask.device),mask],1)
        attn_mask = torch.zeros(mask.shape, device=mask.device)
        attn_mask[~mask] = -torch.inf
        cls_token = self.cls_token.weight.unsqueeze(0).expand(B,-1,-1)
        x = torch.cat([cls_token,x],1)
        
        for blk in self.blocks:
            if self.use_checkpoint:
                x = checkpoint.checkpoint(blk, x, None, attn_mask)
            else: x = blk(x, None, attn_mask)
                
        x = self.proj_out(x[:,0]) #cls token
        return x
    
    
class EncoderWithDirectionReconstructionV22(nn.Module):
    def __init__(self, dim=384, dim_base=128, depth=8, use_checkpoint=False, head_size=64, **kwargs):
        super().__init__()
        self.extractor = ExtractorV11Scaled(dim_base,dim//2)
        self.rel_pos = Rel_ds(head_size)
        self.sandwich = nn.ModuleList([ 
            Block_rel(dim=dim, num_heads=dim//head_size),
            Block_rel(dim=dim, num_heads=dim//head_size),
            Block_rel(dim=dim, num_heads=dim//head_size),
            Block_rel(dim=dim, num_heads=dim//head_size),
        ])
        self.cls_token = nn.Linear(dim,1,bias=False)
        self.blocks = nn.ModuleList([ 
            Block(
                dim=dim, num_heads=dim//head_size, mlp_ratio=4, drop_path=0.0*(i/(depth-1)), init_values=1,)
            for i in range(depth)])
        self.proj_out = nn.Linear(dim,3)
        self.use_checkpoint = use_checkpoint
        self.local_root= DynEdgeFEXTRACTRO(9, 
                                           post_processing_layer_sizes = [336, dim//2], 
                                           dynedge_layer_sizes = [(128, 256), (336, 256), (336, 256), (336, 256)])   
        self.apply(self._init_weights)
        trunc_normal_(self.cls_token.weight, std=.02)

    def fix_init_weight(self):
        def rescale(param, layer_id):
            param.div_(math.sqrt(2.0 * layer_id))

        for layer_id, layer in enumerate(self.blocks):
            rescale(layer.attn.proj.weight.data, layer_id + 1)
            rescale(layer.mlp.fc2.weight.data, layer_id + 1)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def init_weights(self, pretrained=None):
        def _init_weights(m):
            if isinstance(m, nn.Linear):
                trunc_normal_(m.weight, std=.02)
                if isinstance(m, nn.Linear) and m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.LayerNorm):
                nn.init.constant_(m.bias, 0)
                nn.init.constant_(m.weight, 1.0)
        self.apply(_init_weights)
        
    @torch.jit.ignore
    def no_weight_decay(self):
        return {'cls_token'}
    
    def forward(self, x0):
        mask = x0['mask']
        graph_featutre = torch.concat([x0["pos"][mask] , 
                             x0['time'][mask].view(-1, 1),
                             x0['auxiliary'][mask].view(-1, 1),
                             x0['qe'][mask].view(-1, 1),
                             x0['charge'][mask].view(-1, 1),
                             x0["ice_properties"][mask], 
                              ], dim=1)
        Lmax = mask.sum(-1).max()
        x = self.extractor(x0, Lmax)
        rel_pos_bias, rel_enc = self.rel_pos(x0, Lmax)
        #nbs = get_nbs(x0, Lmax)
        mask = mask[:,:Lmax]
        batch_index = mask.nonzero()[:,0] 
        edge_index = knn_graph(x = graph_featutre[:,:3], k=8, batch=batch_index).to(mask.device)
        graph_featutre, _, _ = self.local_root(graph_featutre, edge_index, batch_index, x0['L0'])
        graph_featutre, _ = to_dense_batch(graph_featutre, batch_index)
        
        B,_ = mask.shape
        attn_mask = torch.zeros(mask.shape, device=mask.device)
        attn_mask[~mask] = -torch.inf
        x = torch.cat([x,graph_featutre],2)
        
        for blk in self.sandwich:
            if isinstance(blk,LocalBlock):
                x = blk(x,nbs,mask,rel_enc)
            else:
                x = blk(x,attn_mask,rel_pos_bias)
                rel_pos_bias = None
        mask = torch.cat([torch.ones(B,1,dtype=mask.dtype, device=mask.device),mask],1)
        attn_mask = torch.zeros(mask.shape, device=mask.device)
        attn_mask[~mask] = -torch.inf
        cls_token = self.cls_token.weight.unsqueeze(0).expand(B,-1,-1)
        x = torch.cat([cls_token,x],1)
        
        for blk in self.blocks:
            if self.use_checkpoint:
                x = checkpoint.checkpoint(blk, x, None, attn_mask)
            else: x = blk(x, None, attn_mask)
                
        x = self.proj_out(x[:,0]) #cls token
        return x
    
    
    
class EncoderWithDirectionReconstructionV23(nn.Module):
    def __init__(self, dim=384, dim_base=128, depth=8, use_checkpoint=False, head_size=64, **kwargs):
        super().__init__()
        self.extractor = ExtractorV11Scaled(dim_base,dim//2)
        self.rel_pos = Rel_ds(head_size)
        self.sandwich = nn.ModuleList([ 
            Block_rel(dim=dim, num_heads=dim//head_size),
            Block_rel(dim=dim, num_heads=dim//head_size),
            Block_rel(dim=dim, num_heads=dim//head_size),
            Block_rel(dim=dim, num_heads=dim//head_size),
        ])
        self.cls_token = nn.Linear(dim,1,bias=False)
        self.blocks = nn.ModuleList([ 
            Block(
                dim=dim, num_heads=dim//head_size, mlp_ratio=4, drop_path=0.0*(i/(depth-1)), init_values=1,)
            for i in range(depth)])
        self.proj_out = nn.Linear(dim,3)
        self.use_checkpoint = use_checkpoint
        self.local_root= DynEdgeFEXTRACTRO(9, 
                                           post_processing_layer_sizes = [336, dim//2], 
                                           dynedge_layer_sizes = [(128, 256), (336, 256), (336, 256), (336, 256)])   
        self.apply(self._init_weights)
        trunc_normal_(self.cls_token.weight, std=.02)

    def fix_init_weight(self):
        def rescale(param, layer_id):
            param.div_(math.sqrt(2.0 * layer_id))

        for layer_id, layer in enumerate(self.blocks):
            rescale(layer.attn.proj.weight.data, layer_id + 1)
            rescale(layer.mlp.fc2.weight.data, layer_id + 1)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def init_weights(self, pretrained=None):
        def _init_weights(m):
            if isinstance(m, nn.Linear):
                trunc_normal_(m.weight, std=.02)
                if isinstance(m, nn.Linear) and m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.LayerNorm):
                nn.init.constant_(m.bias, 0)
                nn.init.constant_(m.weight, 1.0)
        self.apply(_init_weights)
        
    @torch.jit.ignore
    def no_weight_decay(self):
        return {'cls_token'}
    
    def forward(self, x0):
        mask = x0['mask']
        graph_featutre = torch.concat([x0["pos"][mask] , 
                             x0['time'][mask].view(-1, 1),
                             x0['auxiliary'][mask].view(-1, 1),
                             x0['qe'][mask].view(-1, 1),
                             x0['charge'][mask].view(-1, 1),
                             x0["ice_properties"][mask], 
                              ], dim=1)
        Lmax = mask.sum(-1).max()
        x = self.extractor(x0, Lmax)
        rel_pos_bias, rel_enc = self.rel_pos(x0, Lmax)
        #nbs = get_nbs(x0, Lmax)
        mask = mask[:,:Lmax]
        batch_index = mask.nonzero()[:,0] 
        edge_index = knn_graph(x = graph_featutre[:,:4], k=8, batch=batch_index).to(mask.device)
        graph_featutre, _, _ = self.local_root(graph_featutre, edge_index, batch_index, x0['L0'])
        graph_featutre, _ = to_dense_batch(graph_featutre, batch_index)
        
        B,_ = mask.shape
        attn_mask = torch.zeros(mask.shape, device=mask.device)
        attn_mask[~mask] = -torch.inf
        x = torch.cat([x,graph_featutre],2)
        
        for blk in self.sandwich:
            if isinstance(blk,LocalBlock):
                x = blk(x,nbs,mask,rel_enc)
            else:
                x = blk(x,attn_mask,rel_pos_bias)
                #rel_pos_bias = None
        mask = torch.cat([torch.ones(B,1,dtype=mask.dtype, device=mask.device),mask],1)
        attn_mask = torch.zeros(mask.shape, device=mask.device)
        attn_mask[~mask] = -torch.inf
        cls_token = self.cls_token.weight.unsqueeze(0).expand(B,-1,-1)
        x = torch.cat([cls_token,x],1)
        
        for blk in self.blocks:
            if self.use_checkpoint:
                x = checkpoint.checkpoint(blk, x, None, attn_mask)
            else: x = blk(x, None, attn_mask)
                
        x = self.proj_out(x[:,0]) #cls token
        return x

In [182]:
EncoderWithDirectionReconstructionV23(dim=384, dim_base=128, depth=8, head_size=32)

EncoderWithDirectionReconstructionV23(
  (extractor): ExtractorV11Scaled(
    (pos): ScaledSinusoidalEmbedding()
    (emb_charge): ScaledSinusoidalEmbedding()
    (time): ScaledSinusoidalEmbedding()
    (aux_emb): Embedding(2, 64)
    (emb2): ScaledSinusoidalEmbedding()
    (proj): Sequential(
      (0): Linear(in_features=768, out_features=768, bias=True)
      (1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      (2): GELU()
      (3): Linear(in_features=768, out_features=192, bias=True)
    )
  )
  (rel_pos): Rel_ds(
    (emb): SinusoidalPosEmb()
    (proj): Linear(in_features=32, out_features=32, bias=True)
  )
  (sandwich): ModuleList(
    (0): Block_rel(
      (norm1): LayerNorm((384,), eps=1e-05, elementwise_affine=True)
      (attn): Attention_rel(
        (proj_q): Linear(in_features=384, out_features=384, bias=False)
        (proj_k): Linear(in_features=384, out_features=384, bias=False)
        (proj_v): Linear(in_features=384, out_features=384, bias=False)
      

In [183]:
model = IceCubeModelEncoderMATMasked()
event = torch.randn(2, 100, 6)
mask = torch.ones(2, 100).bool()
adjecent_matrix = torch.empty(2, 100, 100).random_(2).float()
distance_matrix = torch.randn(2, 100, 100)
batch = dict(
    event=event,
    mask=mask,
    adjecent_matrix=adjecent_matrix,
    distance_matrix=distance_matrix,
)
out = model(batch)
out.shape


torch.Size([2, 128])

In [184]:
import torch
torch.__version__

'1.11.0+cu115'

In [185]:
#|hide
#|eval: false
from nbdev.doclinks import nbdev_export
nbdev_export()

In [186]:
(1 + 2)/2

1.5

In [187]:
import torch

In [188]:
from typing import Optional, Tuple, Type
from dataclasses import dataclass
import math

import torch
from torch import nn
import torch.nn.functional as F


@dataclass
class ModelArgs:
    dim: int = 512
    n_layers: int = 8
    n_heads: int = 8
    vocab_size: int = -1  # defined later by tokenizer
    multiple_of: int = 256  # make SwiGLU hidden layer size multiple of large power of 2
    norm_eps: float = 1e-5

    max_batch_size: int = 32
    max_seq_len: int = 1024


class RMSNorm(torch.nn.Module):
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x):
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x):
        output = self._norm(x.float()).type_as(x)
        return output * self.weight


def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
    t = torch.arange(end, device=freqs.device)  # type: ignore
    freqs = torch.outer(t, freqs).float()  # type: ignore
    freqs_cis = torch.polar(torch.ones_like(freqs), freqs)  # complex64
    return freqs_cis


def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
    ndim = x.ndim
    assert 0 <= 1 < ndim
    assert freqs_cis.shape == (x.shape[1], x.shape[-1])
    shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
    return freqs_cis.view(*shape)


def apply_rotary_emb(
    xq: torch.Tensor,
    xk: torch.Tensor,
    freqs_cis: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
    xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
    freqs_cis = reshape_for_broadcast(freqs_cis, xq_)
    xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
    xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
    return xq_out.type_as(xq), xk_out.type_as(xk)


class Attention(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()

        self.n_local_heads = args.n_heads
        self.head_dim = args.dim // args.n_heads

        self.wq = nn.Linear(
            args.dim,
            args.n_heads * self.head_dim,
            bias=False,
        )
        self.wk = nn.Linear(
            args.dim,
            args.n_heads * self.head_dim,
            bias=False,
        )
        self.wv = nn.Linear(
            args.dim,
            args.n_heads * self.head_dim,
            bias=False,
        )
        self.wo = nn.Linear(
            args.dim,
            args.n_heads * self.head_dim,
            bias=False,
        )

        self.cache_k = torch.zeros(
            (args.max_batch_size, args.max_seq_len, self.n_local_heads, self.head_dim)
        ).cuda()
        self.cache_v = torch.zeros(
            (args.max_batch_size, args.max_seq_len, self.n_local_heads, self.head_dim)
        ).cuda()

    def forward(
        self,
        x: torch.Tensor,
        start_pos: int,
        freqs_cis: torch.Tensor,
        mask: Optional[torch.Tensor],
    ):
        bsz, seqlen, _ = x.shape
        xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)

        xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim)
        xk = xk.view(bsz, seqlen, self.n_local_heads, self.head_dim)
        xv = xv.view(bsz, seqlen, self.n_local_heads, self.head_dim)

        xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)

        self.cache_k = self.cache_k.to(xq)
        self.cache_v = self.cache_v.to(xq)

        self.cache_k[:bsz, start_pos : start_pos + seqlen] = xk
        self.cache_v[:bsz, start_pos : start_pos + seqlen] = xv

        keys = self.cache_k[:bsz, : start_pos + seqlen]
        values = self.cache_v[:bsz, : start_pos + seqlen]

        xq = xq.transpose(1, 2)
        keys = keys.transpose(1, 2)
        values = values.transpose(1, 2)
        scores = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(self.head_dim)
        if mask is not None:
            scores = scores + mask  # (bs, n_local_heads, slen, cache_len + slen)
        scores = F.softmax(scores.float(), dim=-1).type_as(xq)
        output = torch.matmul(scores, values)  # (bs, n_local_heads, slen, head_dim)
        output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)

        return self.wo(output)


class FeedForward(nn.Module):
    def __init__(
        self,
        dim: int,
        hidden_dim: int,
        multiple_of: int,
    ):
        super().__init__()
        hidden_dim = int(2 * hidden_dim / 3)
        hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)

        self.w1 = nn.Linear(dim, hidden_dim, bias=False)
        self.w2 = nn.Linear(
            hidden_dim,
            dim,
            bias=False,
        )
        self.w3 = nn.Linear(
            dim,
            hidden_dim,
            bias=False,
        )

    def forward(self, x):
        return self.w2(F.silu(self.w1(x)) * self.w3(x))


class TransformerBlock(nn.Module):
    def __init__(self, layer_id: int, args: ModelArgs):
        super().__init__()
        self.n_heads = args.n_heads
        self.dim = args.dim
        self.head_dim = args.dim // args.n_heads
        self.attention = Attention(args)
        self.feed_forward = FeedForward(
            dim=args.dim, hidden_dim=4 * args.dim, multiple_of=args.multiple_of
        )
        self.layer_id = layer_id
        self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps)
        self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps)

    def forward(
        self,
        x: torch.Tensor,
        start_pos: int,
        freqs_cis: torch.Tensor,
        mask: Optional[torch.Tensor],
    ):
        h = x + self.attention.forward(
            self.attention_norm(x), start_pos, freqs_cis, mask
        )
        out = h + self.feed_forward.forward(self.ffn_norm(h))
        return out


def convert_linear_to_bnb(float_linear):
    new_layer = InferenceQuantizedLinear(
        float_linear.in_features,
        float_linear.out_features,
        bias=float_linear.bias is not None,
    )
    new_layer._parameters["weight"] = bnb.nn.Int8Params(
        float_linear.weight.data.cpu(),
        requires_grad=False,
        has_fp16_weights=False,
    )
    if float_linear.bias is not None:
        new_layer._parameters["bias"] = float_linear.bias
    return new_layer


class Transformer(nn.Module):
    def __init__(self, params: ModelArgs):
        super().__init__()
        self.params = params
        self.vocab_size = params.vocab_size
        self.n_layers = params.n_layers

        self.tok_embeddings = torch.nn.Embedding(params.vocab_size, params.dim)

        self.layers = torch.nn.ModuleList()
        for layer_id in range(params.n_layers):
            self.layers.append(TransformerBlock(layer_id, params))

        self.norm = RMSNorm(params.dim, eps=params.norm_eps)

        self.output = nn.Linear(params.dim, params.vocab_size, bias=False)

        self.freqs_cis = precompute_freqs_cis(
            self.params.dim // self.params.n_heads, self.params.max_seq_len * 2
        )

    @torch.inference_mode()
    def forward(self, tokens: torch.Tensor, start_pos: int):
        _bsz, seqlen = tokens.shape
        h = self.tok_embeddings(tokens)
        self.freqs_cis = self.freqs_cis.to(h.device)
        freqs_cis = self.freqs_cis[start_pos : start_pos + seqlen]

        mask = None
        if seqlen > 1:
            mask = torch.full(
                (1, 1, seqlen, seqlen), float("-inf"), device=tokens.device
            )
            mask = torch.triu(mask, diagonal=start_pos + 1).type_as(h)

        for layer in self.layers:
            h = h.to(layer.parameters().__next__().device)
            h = layer(h, start_pos, freqs_cis, mask)
        h = h.to(self.norm.parameters().__next__().device)
        h = self.norm(h)

        hl = h[:, -1, :]
        hl = hl.to(self.output.parameters().__next__().device)
        output = self.output(hl)
        return output.float()

In [189]:
ModelArgs().dim


512

In [190]:
freqs_cis = precompute_freqs_cis(64, 1024)

In [191]:
freqs_cis[10:40].shape

torch.Size([30, 32])

In [192]:
ScaledSinusoidalEmbedding(32)

ScaledSinusoidalEmbedding()

In [193]:
x = torch.rand(3, 5, 1)

In [194]:
SinusoidalPosEmb()(x)

tensor([[[[3.8143e-01, 1.2344e-01, 3.9125e-02, 1.2375e-02, 3.9135e-03,
           1.2375e-03, 3.9135e-04, 1.2376e-04, 9.2440e-01, 9.9235e-01,
           9.9923e-01, 9.9992e-01, 9.9999e-01, 1.0000e+00, 1.0000e+00,
           1.0000e+00]],

         [[8.3218e-01, 3.0588e-01, 9.8144e-02, 3.1081e-02, 9.8301e-03,
           3.1086e-03, 9.8303e-04, 3.1086e-04, 5.5451e-01, 9.5207e-01,
           9.9517e-01, 9.9952e-01, 9.9995e-01, 1.0000e+00, 1.0000e+00,
           1.0000e+00]],

         [[5.3265e-02, 1.6851e-02, 5.3290e-03, 1.6852e-03, 5.3290e-04,
           1.6852e-04, 5.3290e-05, 1.6852e-05, 9.9858e-01, 9.9986e-01,
           9.9999e-01, 1.0000e+00, 1.0000e+00, 1.0000e+00, 1.0000e+00,
           1.0000e+00]],

         [[5.4645e-02, 1.7288e-02, 5.4672e-03, 1.7289e-03, 5.4672e-04,
           1.7289e-04, 5.4672e-05, 1.7289e-05, 9.9851e-01, 9.9985e-01,
           9.9999e-01, 1.0000e+00, 1.0000e+00, 1.0000e+00, 1.0000e+00,
           1.0000e+00]],

         [[2.9026e-01, 9.2995e-02, 2.9446e-0

In [195]:
dice = torch.tensor([0.2, 0.4, 0.5, 0.6])

In [196]:

F.softmax(dice)

  """Entry point for launching an IPython kernel.


tensor([0.1975, 0.2412, 0.2666, 0.2946])

In [197]:
def loss_fn(x):
    #x dice score (bs, 4)
    y = F.one_hot(x.argmax(dim=-1), num_classes=x.shape[1]).float()
    x = F.normalize(x, dim=-1, p=2)
    return 2 - 2 * (x * y).sum(dim=-1)

In [198]:
loss_fn(torch.rand(3, 4))

tensor([0.0612, 0.5434, 0.5291])

In [199]:
x = torch.rand(3, 4)

In [200]:
F.one_hot(x.argmax(dim=-1), ).float()

tensor([[0., 0., 1.],
        [1., 0., 0.],
        [1., 0., 0.]])

In [201]:
torch.tensor([0.0250, 0.0250, 0.0250, 0.0250, 0.1000])  * 0.70

tensor([0.0175, 0.0175, 0.0175, 0.0175, 0.0700])

In [202]:
torch.tensor([0.0175, 0.0175, 0.0175, 0.0175, 0.0700]).sum()

tensor(0.1400)

In [203]:
1 - 0.14

0.86

In [204]:
0.86/4

0.215

In [205]:
 torch.tensor([0.26666, 0.26666, 0.26666,0.0250, 0.0250, 0.0250, 0.0250, 0.1000]) * 0.5

tensor([0.1333, 0.1333, 0.1333, 0.0125, 0.0125, 0.0125, 0.0125, 0.0500])

In [206]:
torch.tensor([0.215, 0.215, 0.215, 0.215, 0.0175, 0.0175, 0.0175, 0.0175, 0.0700]).sum()

tensor(1.)

In [207]:
1e-5/10.

1.0000000000000002e-06

In [208]:
1e-6 == 0.000001

True