In [7]:
import math
from torch import nn, einsum
import torch
from timm.models.layers import drop_path, to_2tuple, trunc_normal_
import torch.utils.checkpoint as checkpoint
from einops import rearrange, repeat
from torch_geometric.utils import to_dense_batch
from torch_geometric.nn.pool import knn_graph
from graphnet.models.gnn.gnn import GNN
from graphnet.models.utils import calculate_xyzt_homophily
from torch_geometric.typing import Adj
from torch_scatter import scatter_max, scatter_mean, scatter_min, scatter_sum
from typing import Any, Callable, List, Optional, Sequence, Tuple, Union
from torch import Tensor, LongTensor
from torch_geometric.nn import MessagePassing, global_add_pool, global_mean_pool
from torch_geometric.nn import EdgeConv
import torch.nn.functional as F
from torch.nn import Linear, ReLU, SiLU, Sequential
from graphnet.utilities.config import save_model_config
from torch_geometric.utils import to_dense_adj

# Registry of supported global pooling schemes: maps the scheme name accepted
# in `global_pooling_schemes` to the `torch_scatter` reduction implementing it.
GLOBAL_POOLINGS = dict(
    min=scatter_min,
    max=scatter_max,
    sum=scatter_sum,
    mean=scatter_mean,
)

class DynEdgeConv(EdgeConv):
    """Dynamical edge convolution layer.

    Runs a regular `EdgeConv` step and then rebuilds the k-nearest-neighbour
    graph from (a subset of) the freshly convolved node features, so that the
    next layer operates on an updated adjacency.
    """

    def __init__(
        self,
        nn: Callable,
        aggr: str = "max",
        nb_neighbors: int = 8,
        features_subset: Optional[Union[Sequence[int], slice]] = None,
        **kwargs: Any,
    ):
        """Construct `DynEdgeConv`.
        Args:
            nn: The MLP/torch.Module to be used within the `EdgeConv`.
            aggr: Aggregation method to be used with `EdgeConv`.
            nb_neighbors: Number of neighbours to be clustered after the
                `EdgeConv` operation.
            features_subset: Subset of features in `Data.x` that should be used
                when dynamically performing the new graph clustering after the
                `EdgeConv` operation. Defaults to all features. May be a list
                or tuple of column indices, or a slice.
            **kwargs: Additional features to be passed to `EdgeConv`.
        """
        # Check(s)
        if features_subset is None:
            features_subset = slice(None)  # Use all features
        elif isinstance(features_subset, tuple):
            # The annotation admits any integer sequence; normalise tuples to
            # a list so they pass validation and index columns like a list.
            features_subset = list(features_subset)
        assert isinstance(features_subset, (list, slice))

        # Base class constructor
        super().__init__(nn=nn, aggr=aggr, **kwargs)

        # Additional member variables
        self.nb_neighbors = nb_neighbors
        self.features_subset = features_subset

    def forward(
        self, x: Tensor, edge_index: Adj, batch: Optional[Tensor] = None
    ) -> Tuple[Tensor, Tensor]:
        """Forward pass.

        Args:
            x: Node features.
            edge_index: Adjacency used for the `EdgeConv` step.
            batch: Optional per-node graph assignment, forwarded to
                `knn_graph` so neighbours are only searched within a graph.

        Returns:
            A tuple `(x, edge_index)`: the convolved node features and the
            adjacency recomputed from `x[:, features_subset]`.
        """
        # Standard EdgeConv forward pass
        x = super().forward(x, edge_index)
        dev = x.device

        # Recompute adjacency from the updated node features
        edge_index = knn_graph(
            x=x[:, self.features_subset],
            k=self.nb_neighbors,
            batch=batch,
        ).to(dev)

        return x, edge_index
    
    
class DynEdgeFEXTRACTRO(GNN):
    """DynEdge (dynamical edge convolutional) model used as a feature extractor.

    `forward` stops after the post-processing MLP and returns per-node
    features together with the last adjacency and the batch vector. The
    read-out MLP (`_readout`) is constructed and `_global_pooling` is defined,
    but neither is invoked by `forward` in this class.
    """

    @save_model_config
    def __init__(
        self,
        nb_inputs: int,
        *,
        nb_neighbours: int = 8,
        features_subset: Optional[Union[List[int], slice]] = None,
        dynedge_layer_sizes: Optional[List[Tuple[int, ...]]] = None,
        post_processing_layer_sizes: Optional[List[int]] = None,
        readout_layer_sizes: Optional[List[int]] = None,
        global_pooling_schemes: Optional[Union[str, List[str]]] = None,
        add_global_variables_after_pooling: bool = False,
    ):
        """Construct `DynEdge`.
        Args:
            nb_inputs: Number of input features on each node.
            nb_neighbours: Number of neighbours to use in the k-nearest
                neighbour clustering which is performed after each (dynamical)
                edge convolution.
            features_subset: The subset of latent features on each node that
                are used as metric dimensions when performing the k-nearest
                neighbours clustering. Defaults to the first three columns
                (`slice(0, 3)`).
            dynedge_layer_sizes: The layer sizes, or latent feature dimensions,
                used in the `DynEdgeConv` layer. Each entry in
                `dynedge_layer_sizes` corresponds to a single `DynEdgeConv`
                layer; the integers in the corresponding tuple correspond to
                the layer sizes in the multi-layer perceptron (MLP) that is
                applied within each `DynEdgeConv` layer. That is, a list of
                size-two tuples means that all `DynEdgeConv` layers contain a
                two-layer MLP.
                Defaults to [(128, 256), (336, 256), (336, 256), (336, 256)].
            post_processing_layer_sizes: Hidden layer sizes in the MLP
                following the skip-concatenation of the outputs of each
                `DynEdgeConv` layer. Defaults to [336, 256].
            readout_layer_sizes: Hidden layer sizes in the MLP following the
                post-processing _and_ optional global pooling. The last entry
                is passed to the base class as the number of outputs.
                Defaults to [128,].
            global_pooling_schemes: The list of global pooling schemes to use.
                Options are: "min", "max", "mean", and "sum".
            add_global_variables_after_pooling: Whether to add global variables
                after global pooling. The alternative is to add (distribute)
                them to the individual nodes before any convolutional
                operations.
        """
        # Latent feature subset for computing nearest neighbours in DynEdge.
        if features_subset is None:
            features_subset = slice(0, 3)

        # DynEdge layer sizes
        if dynedge_layer_sizes is None:
            dynedge_layer_sizes = [
                (
                    128,
                    256,
                ),
                (
                    336,
                    256,
                ),
                (
                    336,
                    256,
                ),
                (
                    336,
                    256,
                ),
            ]

        # Must be a non-empty list of non-empty tuples of positive sizes.
        assert isinstance(dynedge_layer_sizes, list)
        assert len(dynedge_layer_sizes)
        assert all(isinstance(sizes, tuple) for sizes in dynedge_layer_sizes)
        assert all(len(sizes) > 0 for sizes in dynedge_layer_sizes)
        assert all(all(size > 0 for size in sizes) for sizes in dynedge_layer_sizes)

        self._dynedge_layer_sizes = dynedge_layer_sizes

        # Post-processing layer sizes
        if post_processing_layer_sizes is None:
            post_processing_layer_sizes = [
                336,
                256,
            ]

        assert isinstance(post_processing_layer_sizes, list)
        assert len(post_processing_layer_sizes)
        assert all(size > 0 for size in post_processing_layer_sizes)

        self._post_processing_layer_sizes = post_processing_layer_sizes

        # Read-out layer sizes
        if readout_layer_sizes is None:
            readout_layer_sizes = [
                128,
            ]

        assert isinstance(readout_layer_sizes, list)
        assert len(readout_layer_sizes)
        assert all(size > 0 for size in readout_layer_sizes)

        self._readout_layer_sizes = readout_layer_sizes

        # Global pooling scheme(s): a single name is wrapped into a list.
        if isinstance(global_pooling_schemes, str):
            global_pooling_schemes = [global_pooling_schemes]

        if isinstance(global_pooling_schemes, list):
            for pooling_scheme in global_pooling_schemes:
                assert (
                    pooling_scheme in GLOBAL_POOLINGS
                ), f"Global pooling scheme {pooling_scheme} not supported."
        else:
            assert global_pooling_schemes is None

        self._global_pooling_schemes = global_pooling_schemes

        if add_global_variables_after_pooling:
            assert self._global_pooling_schemes, (
                "No global pooling schemes were request, so cannot add global"
                " variables after pooling."
            )
        self._add_global_variables_after_pooling = add_global_variables_after_pooling

        # Base class constructor
        super().__init__(nb_inputs, self._readout_layer_sizes[-1])

        # Remaining member variables()
        # A single GELU instance is shared by every MLP built below.
        self._activation = torch.nn.GELU()
        self._nb_inputs = nb_inputs
        # Width of the vector built by `_calculate_global_variables` as called
        # from `forward`: nb_inputs feature means + 4 homophily values + 1
        # additional attribute (log10 of the pulse count).
        self._nb_global_variables = 5 + nb_inputs
        self._nb_neighbours = nb_neighbours
        self._features_subset = features_subset

        self._construct_layers()

    def _construct_layers(self) -> None:
        """Construct layers (torch.nn.Modules)."""
        # Convolutional operations
        nb_input_features = self._nb_inputs
        if not self._add_global_variables_after_pooling:
            # Global variables are concatenated onto every node in `forward`.
            nb_input_features += self._nb_global_variables

        self._conv_layers = torch.nn.ModuleList()
        nb_latent_features = nb_input_features
        for sizes in self._dynedge_layer_sizes:
            layers = []
            layer_sizes = [nb_latent_features] + list(sizes)
            for ix, (nb_in, nb_out) in enumerate(
                zip(layer_sizes[:-1], layer_sizes[1:])
            ):
                if ix == 0:
                    # The first linear layer of each conv MLP takes twice the
                    # node feature width (EdgeConv feeds it a pair of node
                    # feature vectors — see the PyG `EdgeConv` docs).
                    nb_in *= 2
                layers.append(torch.nn.Linear(nb_in, nb_out))
                layers.append(nn.LayerNorm(nb_out))
                layers.append(self._activation)

            conv_layer = DynEdgeConv(
                torch.nn.Sequential(*layers),
                aggr="add",
                nb_neighbors=self._nb_neighbours,
                features_subset=self._features_subset,
            )
            self._conv_layers.append(conv_layer)

            # The next conv layer consumes this layer's output width.
            nb_latent_features = nb_out

        # Post-processing operations
        # Input width = skip-concatenation of the raw input and the output of
        # every conv layer.
        nb_latent_features = (
            sum(sizes[-1] for sizes in self._dynedge_layer_sizes) + nb_input_features
        )

        post_processing_layers = []
        layer_sizes = [nb_latent_features] + list(self._post_processing_layer_sizes)
        for nb_in, nb_out in zip(layer_sizes[:-1], layer_sizes[1:]):
            post_processing_layers.append(torch.nn.Linear(nb_in, nb_out))
            post_processing_layers.append(nn.LayerNorm(nb_out))
            post_processing_layers.append(self._activation)

        self._post_processing = torch.nn.Sequential(*post_processing_layers)

        # Read-out operations
        # NOTE: `_readout` is built here but not called by `forward` below.
        nb_poolings = (
            len(self._global_pooling_schemes) if self._global_pooling_schemes else 1
        )
        nb_latent_features = nb_out * nb_poolings
        if self._add_global_variables_after_pooling:
            nb_latent_features += self._nb_global_variables

        readout_layers = []
        layer_sizes = [nb_latent_features] + list(self._readout_layer_sizes)
        for nb_in, nb_out in zip(layer_sizes[:-1], layer_sizes[1:]):
            readout_layers.append(torch.nn.Linear(nb_in, nb_out))
            readout_layers.append(nn.LayerNorm(nb_out))
            readout_layers.append(self._activation)

        self._readout = torch.nn.Sequential(*readout_layers)

    def _global_pooling(self, x: Tensor, batch: LongTensor) -> Tensor:
        """Perform global pooling.

        Applies every configured pooling scheme over the nodes of each graph
        (grouped by `batch`) and concatenates the results along dim 1.
        """
        assert self._global_pooling_schemes
        pooled = []
        for pooling_scheme in self._global_pooling_schemes:
            pooling_fn = GLOBAL_POOLINGS[pooling_scheme]
            pooled_x = pooling_fn(x, index=batch, dim=0)
            if isinstance(pooled_x, tuple) and len(pooled_x) == 2:
                # `scatter_{min,max}`, which return also an argument, vs.
                # `scatter_{mean,sum}`
                pooled_x, _ = pooled_x
            pooled.append(pooled_x)

        return torch.cat(pooled, dim=1)

    def _calculate_global_variables(
        self,
        x: Tensor,
        edge_index: LongTensor,
        batch: LongTensor,
        *additional_attributes: Tensor,
    ) -> Tensor:
        """Calculate global variables.

        Concatenates, per graph: the mean of each node feature, the four
        homophily values returned by `calculate_xyzt_homophily`, and each of
        `additional_attributes` (unsqueezed to a column).
        """
        # Calculate homophily (scalar variables)
        h_x, h_y, h_z, h_t = calculate_xyzt_homophily(x, edge_index, batch)

        # Calculate mean features
        global_means = scatter_mean(x, batch, dim=0)

        # Add global variables
        global_variables = torch.cat(
            [
                global_means,
                h_x,
                h_y,
                h_z,
                h_t,
            ]
            + [attr.unsqueeze(dim=1) for attr in additional_attributes],
            dim=1,
        )

        return global_variables

    def forward(self, x: Tensor, edge_index: Tensor, batch: Tensor, n_pulses: Tensor) -> Tuple[Tensor, Tensor, Tensor]:
        """Apply learnable forward pass.

        Args:
            x: Node features.
            edge_index: Initial adjacency.
            batch: Per-node graph assignment.
            n_pulses: Per-graph pulse count; its log10 is appended to the
                global variables.

        Returns:
            A tuple `(x, edge_index, batch)`: post-processed per-node features,
            the adjacency produced by the last `DynEdgeConv` layer, and the
            unchanged batch vector.
        """
        # Convenience variables
        global_variables = self._calculate_global_variables(
            x,
            edge_index,
            batch,
            torch.log10(n_pulses),
        )

        # Distribute global variables out to each node
        if not self._add_global_variables_after_pooling:
            # One-hot (node x graph) membership matrix.
            distribute = (
                batch.unsqueeze(dim=1) == torch.unique(batch).unsqueeze(dim=0)
            ).type(torch.float)

            # Select, for every node, the global-variable row of its graph.
            global_variables_distributed = torch.sum(
                distribute.unsqueeze(dim=2) * global_variables.unsqueeze(dim=0),
                dim=1,
            )

            x = torch.cat((x, global_variables_distributed), dim=1)

        # DynEdge-convolutions
        skip_connections = [x]
        for conv_layer in self._conv_layers:
            x, edge_index = conv_layer(x, edge_index, batch)
            skip_connections.append(x)

        # Skip-cat
        x = torch.cat(skip_connections, dim=1)
        x = self._post_processing(x)
        return x, edge_index, batch

def exists(val):
    """Return True for any value other than `None` (falsy values included)."""
    if val is None:
        return False
    return True

def batched_index_select(values, indices):
    """Gather whole feature vectors from `values` along dim 1.

    For each batch row, `indices` lists positions along dim 1 of `values`;
    the index is broadcast over the last (feature) dimension so that the
    complete vector at each position is returned.
    """
    feature_dim = values.size(-1)
    gather_index = indices.unsqueeze(2).expand(-1, -1, feature_dim)
    return torch.gather(values, 1, gather_index)

class FeedForward(nn.Module):
    """Two-layer MLP: expand to `dim * mult`, GELU, dropout, project back to `dim`."""

    def __init__(self, dim, mult = 4, dropout = 0.):
        super().__init__()
        hidden_dim = dim * mult
        stages = [
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
        ]
        self.net = nn.Sequential(*stages)

    def forward(self, x, **kwargs):
        """Apply the MLP to `x`; extra keyword arguments are accepted and ignored."""
        out = self.net(x)
        return out


class Residual(nn.Module):
    """Wrap `fn` with an additive skip connection: `x + fn(x, **kwargs)`."""

    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x, **kwargs):
        """Return the input plus the wrapped callable's output on it."""
        branch = self.fn(x, **kwargs)
        return x + branch


class PreNorm(nn.Module):
    """Apply `LayerNorm(dim)` to the input before handing it to `fn`."""

    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn

    def forward(self, x, **kwargs):
        """Normalise `x`, then call the wrapped callable with the keyword arguments."""
        normed = self.norm(x)
        return self.fn(normed, **kwargs)

class LAttentionV2(nn.Module):
    """Multi-head cross-attention: queries come from `x`, keys/values from `context`.

    `and_self_attend` is stored on the module but is not read by `forward`.
    """

    def __init__(
        self,
        dim,
        heads = 8,
        dim_head = 64,
        and_self_attend = False
    ):
        super().__init__()
        inner_dim = heads * dim_head

        self.heads = heads
        self.scale = dim_head ** -0.5
        self.and_self_attend = and_self_attend

        # Bias-free projections; keys and values share one fused projection.
        self.to_q = nn.Linear(dim, inner_dim, bias = False)
        self.to_kv = nn.Linear(dim, inner_dim * 2, bias = False)
        self.to_out = nn.Linear(inner_dim, dim, bias = False)

    def forward(
        self,
        x,
        context,
        mask = None
    ):
        """Attend from `x` to `context`.

        `mask`, when given, is a boolean `(b, n)` tensor over the context
        positions; positions where it is False are excluded from the softmax.
        """
        num_heads = self.heads

        def split_heads(t):
            return rearrange(t, 'b n (h d) -> b h n d', h = num_heads)

        queries = self.to_q(x)
        keys, values = self.to_kv(context).chunk(2, dim = -1)
        queries, keys, values = split_heads(queries), split_heads(keys), split_heads(values)

        dots = einsum('b h i d, b h j d -> b h i j', queries, keys) * self.scale

        if mask is not None:
            # Most negative finite value so masked keys vanish after softmax.
            fill_value = -torch.finfo(dots.dtype).max
            key_mask = rearrange(mask, 'b n -> b 1 1 n')
            dots.masked_fill_(~key_mask, fill_value)

        weights = dots.softmax(dim = -1)
        attended = einsum('b h i j, b h j d -> b h i d', weights, values)

        merged = rearrange(attended, 'b h n d -> b n (h d)', h = num_heads)
        return self.to_out(merged)

class LocalLatentsAttent(nn.Module):
    """Attention through a small set of latent vectors.

    The latents first attend to the input (`attn1`), then the input attends
    back to the updated latents (`attn2`).
    """

    def __init__(
        self,
        dim,
        heads = 8,
        num_latents = 64,
        latent_self_attend = False
    ):
        super().__init__()
        # Learned latents of shape (num_latents, dim), used when none are passed in.
        self.latents = nn.Parameter(torch.randn(num_latents, dim))
        self.attn1 = LAttentionV2(dim, heads, and_self_attend = latent_self_attend)
        self.attn2 = LAttentionV2(dim, heads)

    def forward(self, x, latents = None, mask = None):
        """Run the two-step latent attention.

        Args:
            x: Input tensor whose first dimension is the batch size.
            latents: Optional externally supplied latents, either `(n, d)` or
                already batched `(b, n, d)`. When omitted, the module's
                learned latents are used.
            mask: Optional boolean mask over the positions of `x`, applied
                when the latents attend to `x`.

        Returns:
            `(out, latents)`: the attended features for `x` and the updated
            latents.
        """
        b, *_ = x.shape

        # Fix: previously the argument was always overwritten by the learned
        # parameter, so caller-provided latents were silently discarded.
        if latents is None:
            latents = self.latents

        if latents.ndim == 2:
            latents = repeat(latents, 'n d -> b n d', b = b)

        latents = self.attn1(latents, x, mask = mask)
        out     = self.attn2(x, latents)

        return out, latents
    
class LocalAttenV2(nn.Module):
    """Stack of `depth` blocks, each a latent-attention step followed by a feed-forward step."""

    def __init__(
        self,
        dim,
        depth,
        heads = 8,
        num_latents = 64,
        ff_dropout = 0.
    ):
        super().__init__()
        self.layers = nn.ModuleList([])

        for _ in range(depth):
            # Pre-normalised latent attention; its residual is added in `forward`.
            latent_attention = LocalLatentsAttent(
                dim = dim,
                heads = heads,
                num_latents = num_latents
            )
            attn_block = PreNorm(dim, latent_attention)

            # Pre-normalised feed-forward with its own residual connection.
            feed_forward = FeedForward(dim = dim, dropout = ff_dropout)
            ff_block = Residual(PreNorm(dim, feed_forward))

            self.layers.append(nn.ModuleList([attn_block, ff_block]))

    def forward(self, x, mask = None):
        """Pass `x` through every block; `mask` is forwarded to each attention step."""
        for attn_block, ff_block in self.layers:
            attended, _ = attn_block(x, mask = mask)
            x = ff_block(x + attended)
        return x

class SinusoidalPosEmb(nn.Module):
    """Sinusoidal embedding of scalar values into `dim` features.

    The output's last dimension is the concatenation of `dim // 2` sines and
    `dim // 2` cosines at geometrically spaced frequencies governed by `M`.
    """

    def __init__(self, dim=16, M=10000):
        super().__init__()
        self.dim = dim
        self.M = M

    def forward(self, x):
        """Embed every element of `x`; a new trailing dimension holds the embedding."""
        half_dim = self.dim // 2
        log_step = math.log(self.M) / half_dim
        freqs = torch.exp(torch.arange(half_dim, device=x.device) * (-log_step))
        angles = x[..., None] * freqs[None, ...]
        return torch.cat((angles.sin(), angles.cos()), dim=-1)
    
    
class LocalGlobalAttention(nn.Module):
    """Stack of blocks combining neighbourhood attention, latent attention and an MLP.

    Each of the `depth` blocks applies, in order: adjacency-restricted
    attention (`NMatrixAttention`, with residual), latent attention
    (`LocalLatentsAttent`, residual added in `forward`) and a feed-forward
    network (with residual).
    """

    def __init__(
        self,
        dim,
        depth,
        dim_head = 64,
        heads = 8,
        num_neighbors_cutoff = None,
        attn_dropout = 0.1,
        ff_dropout=0.,
    ):
        super().__init__()
        # Optional cap on the number of neighbours each node attends to.
        self.num_neighbors_cutoff = num_neighbors_cutoff
        self.layers = nn.ModuleList([])

        for _ in range(depth):
            global_attn = PreNorm(dim, LocalLatentsAttent(
                dim = dim,
                heads = heads,
                num_latents = 64
            ))

            self.layers.append(nn.ModuleList([
                Residual(PreNorm(dim, NMatrixAttention(
                    dim = dim,
                    dim_head = dim_head,
                    heads = heads,
                    dropout = attn_dropout
                ))),
                global_attn,
                Residual(PreNorm(dim, FeedForward(
                    dim = dim,
                    dropout = ff_dropout
                )))
            ]))

    def forward(self, x, adjacency_mat, mask = None):
        """Run all blocks.

        Args:
            x: Node features; dim 1 is the node dimension.
            adjacency_mat: Boolean adjacency matrix. It is no longer modified
                in place, so the caller's tensor is left untouched.
            mask: Optional boolean node mask; when given, edges touching a
                masked-out node are removed and the mask is forwarded to the
                latent attention.

        Returns:
            The transformed node features.
        """
        device, n = x.device, x.shape[1]

        # Add self-loops. Done out of place: the original `|=` / `&=` silently
        # rewrote the adjacency tensor owned by the caller.
        diag = torch.eye(adjacency_mat.shape[-1], device = device).bool()
        adjacency_mat = adjacency_mat | diag
        if exists(mask):
            # Keep an edge only when both of its endpoints are valid nodes.
            adjacency_mat = adjacency_mat & (mask[:, :, None] * mask[:, None, :])

        adj_mat = adjacency_mat.float()
        max_neighbors = int(adj_mat.sum(dim = -1).max())

        if exists(self.num_neighbors_cutoff) and max_neighbors > self.num_neighbors_cutoff:
            # Small random noise breaks ties so that top-k picks a random
            # subset of the true neighbours when there are more than the cutoff.
            noise = torch.empty((n, n), device = device).uniform_(-0.01, 0.01)
            adj_mat = adj_mat + noise
            adj_mask, adj_kv_indices = adj_mat.topk(dim = -1, k = self.num_neighbors_cutoff)
            # Values near 1 are real edges; values near 0 are padding slots.
            adj_mask = (adj_mask > 0.5).float()
        else:
            adj_mask, adj_kv_indices = adj_mat.topk(dim = -1, k = max_neighbors)

        for attn, latent_attn, ff in self.layers:
            x = attn(
                x,
                adj_kv_indices = adj_kv_indices,
                mask = adj_mask
            )
            out, _ = latent_attn(x, mask = mask)
            x = x + out
            x = ff(x)

        return x


    
class ExtractorV0(nn.Module):
    """Embed the per-pulse fields of a batch dict into one feature vector per pulse.

    Continuous fields (`pos`, `charge`, `time`, `ice_properties`) are scaled
    and passed through sinusoidal embeddings; the two-valued `auxiliary` and
    `qe` fields use learned embedding tables. The concatenation has width
    `dim_base * 7` and is optionally projected to `dim`.
    """

    def __init__(self, dim_base=128, dim=384, proj = True):
        super().__init__()
        self.emb = SinusoidalPosEmb(dim=dim_base)
        self.emb2 = SinusoidalPosEmb(dim=dim_base//2)
        self.aux_emb = nn.Embedding(2,dim_base//2)
        self.qe_emb = nn.Embedding(2,dim_base//2)
        self.proj = nn.Linear(dim_base*7,dim) if proj else nn.Identity()

    def forward(self, x, Lmax=None):
        """Build the per-pulse embedding.

        Args:
            x: Mapping with keys `pos`, `charge`, `time`, `auxiliary`, `qe`
                and `ice_properties`. The legacy key `aux` is accepted as a
                fallback for `auxiliary`.
            Lmax: When given, every field is truncated to its first `Lmax`
                positions along dim 1.

        Returns:
            The projected (or raw, if `proj=False`) concatenated embedding.
        """
        def field(key):
            values = x[key]
            return values if Lmax is None else values[:, :Lmax]

        # Fix: the un-truncated path used to read `x['aux']` while the
        # truncated path (and the rest of the file) uses `'auxiliary'`, so a
        # standard batch raised KeyError when `Lmax` was None.
        aux_key = 'auxiliary' if 'auxiliary' in x else 'aux'

        pos = field('pos')
        charge = field('charge')
        time = field('time')
        auxiliary = field(aux_key)
        qe = field('qe')
        ice_properties = field('ice_properties')

        x = torch.cat([self.emb(100*pos).flatten(-2), self.emb(40*charge),
                       self.emb(100*time),self.aux_emb(auxiliary),self.qe_emb(qe),
                       self.emb2(50*ice_properties).flatten(-2)],-1)
        x = self.proj(x)
        return x


class BeDropPath(nn.Module):
    """Module wrapper around the functional `drop_path` (stochastic depth)."""

    def __init__(self, drop_prob=None):
        super().__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        """Apply `drop_path` with the stored probability and the module's training flag."""
        is_training = self.training
        return drop_path(x, self.drop_prob, is_training)

    def extra_repr(self) -> str:
        """Show the drop probability in the module's repr."""
        return f'p={self.drop_prob}'
    
class BeMLP(nn.Module):
    """Transformer MLP: linear, activation, linear, then dropout on the output only."""

    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        # Unspecified (falsy) widths default to the input width.
        if not out_features:
            out_features = in_features
        if not hidden_features:
            hidden_features = in_features

        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        """Run the MLP; no dropout is applied between the activation and `fc2`."""
        hidden = self.act(self.fc1(x))
        projected = self.fc2(hidden)
        return self.drop(projected)

#BEiTv2 Beblock
class BeBlock(nn.Module):
    """Pre-norm transformer block with optional layer scale and stochastic depth.

    Args:
        dim: Embedding width.
        num_heads: Number of attention heads.
        mlp_ratio: Hidden width of the MLP as a multiple of `dim`.
        drop: Dropout used by both the attention module and the MLP.
        drop_path: Stochastic-depth probability; `0` disables it.
        init_values: Initial value of the per-channel layer-scale parameters
            `gamma_1` / `gamma_2`; `None` disables layer scale.
        act_layer: Activation class for the MLP.
        norm_layer: Normalisation class applied before each sub-layer.
        qkv_bias, qk_scale, attn_drop, window_size, attn_head_dim, **kwargs:
            Accepted for signature compatibility; not used by this block.
    """

    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., init_values=None, act_layer=nn.GELU, norm_layer=nn.LayerNorm,
                 window_size=None, attn_head_dim=None, **kwargs):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = nn.MultiheadAttention(dim, num_heads, dropout=drop, batch_first=True)
        self.drop_path = BeDropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = BeMLP(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)

        if init_values is not None:
            self.gamma_1 = nn.Parameter(init_values * torch.ones((dim)),requires_grad=True)
            self.gamma_2 = nn.Parameter(init_values * torch.ones((dim)),requires_grad=True)
        else:
            self.gamma_1, self.gamma_2 = None, None

    def forward(self, x, attn_mask=None, key_padding_mask=None):
        """Apply self-attention and the MLP, each with a residual connection.

        Args:
            x: Input of shape `(batch, seq, dim)` (the attention module is
                built with `batch_first=True`).
            attn_mask: Forwarded to `nn.MultiheadAttention`.
            key_padding_mask: Forwarded to `nn.MultiheadAttention`.
        """
        xn = self.norm1(x)
        attn_out = self.attn(xn, xn, xn,
                             attn_mask=attn_mask,
                             key_padding_mask=key_padding_mask,
                             need_weights=False)[0]

        if self.gamma_1 is None:
            x = x + self.drop_path(attn_out)
            x = x + self.drop_path(self.mlp(self.norm2(x)))
        else:
            # Fix: the attention branch used to be wrapped in `drop_path`
            # twice (once inside and once outside the layer-scale product),
            # unlike the MLP branch; each branch now gets exactly one pass.
            x = x + self.drop_path(self.gamma_1 * attn_out)
            x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x)))
        return x


class BeDeepIceModel(nn.Module):
    """Stack of `BeBlock` transformer blocks with an optional classification head.

    Args:
        dim: Embedding width.
        depth: Number of blocks.
        out_class: When equal to 3, the first token is projected to 3 outputs;
            for any other value the full sequence is returned unprojected.
        use_checkpoint: Run each block under activation checkpointing.
        drop_b: Dropout passed to each block as `drop`.
        div_factor: Head width; the number of heads is `dim // div_factor`.
        attn_drop_b: Passed to each block as `attn_drop`.
        drop_path: Stochastic-depth probability passed to each block.
    """

    def __init__(self, dim=384, depth=12, out_class = 3, use_checkpoint=False, drop_b= 0., div_factor=64, attn_drop_b = 0., drop_path = 0.,  **kwargs):
        super().__init__()
        self.Beblocks = nn.ModuleList([
            BeBlock(
                dim=dim, num_heads=dim//div_factor, mlp_ratio=4, drop_path=drop_path, init_values=1, attn_drop=attn_drop_b, drop=drop_b)
            for i in range(depth)])
        self.out_class = out_class
        self.proj_out = nn.Linear(dim,out_class) if out_class == 3 else nn.Identity()
        self.use_checkpoint = use_checkpoint
        self.apply(self._init_weights)

    def fix_init_weight(self):
        """Rescale each block's output projections by `1 / sqrt(2 * layer_index)` (1-based)."""
        def rescale(param, layer_id):
            param.div_(math.sqrt(2.0 * layer_id))

        for layer_id, layer in enumerate(self.Beblocks):
            # Fix: `nn.MultiheadAttention` names its output projection
            # `out_proj`; the former `attn.proj` raised AttributeError.
            rescale(layer.attn.out_proj.weight.data, layer_id + 1)
            rescale(layer.mlp.fc2.weight.data, layer_id + 1)

    def _init_weights(self, m):
        """Truncated-normal init for linear weights; zeros/ones for biases and LayerNorm."""
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def init_weights(self, pretrained=None):
        """Re-apply the default initialisation to every sub-module.

        `pretrained` is accepted for interface compatibility and is not used.
        """
        # Same rules as `_init_weights`; reuse it instead of a private copy.
        self.apply(self._init_weights)

    def forward(self, x, mask):
        """Run all blocks.

        Args:
            x: Token embeddings, `(batch, seq, dim)`.
            mask: Boolean `(batch, seq)` tensor, True for valid tokens.

        Returns:
            `(batch, out_class)` predictions from the first token when
            `out_class == 3`, otherwise the full encoded sequence.
        """
        # Additive float mask: 0 for valid tokens, -inf for padded ones. It is
        # passed as the third positional argument of the block, i.e. as the
        # key padding mask.
        attn_mask = torch.zeros(mask.shape, device=mask.device)
        attn_mask[~mask] = -torch.inf
        for blk in self.Beblocks:
            if self.use_checkpoint:
                x = checkpoint.checkpoint(blk, x, None, attn_mask)
            else:
                x = blk(x, None, attn_mask)
        if self.out_class == 3:
            x = self.proj_out(x[:,0]) #cls token
        return x
    

class EncoderWithDirectionReconstructionV18(nn.Module):
    """Encoder fusing a sequence branch and a graph branch ahead of a transformer.

    Each branch yields `dim_out // 2` features per pulse; the two are
    concatenated, a learned class token is prepended, and the result is fed
    to `BeDeepIceModel`.
    """

    def __init__(self, dim_out=256, drop_path=0.):
        super().__init__()
        half_dim = dim_out // 2
        self.fe = ExtractorV0(dim=half_dim, dim_base=32)
        self.encoder = BeDeepIceModel(dim_out, drop_path=drop_path)
        # The weight of this bias-free layer, shape (1, dim_out), is the class token.
        self.cls_token = nn.Linear(dim_out, 1, bias=False)
        self.local_root = DynEdgeFEXTRACTRO(
            9,
            post_processing_layer_sizes=[336, half_dim],
            dynedge_layer_sizes=[(128, 256), (336, 256), (336, 256), (336, 256)],
        )
        self.global_root = LocalAttenV2(dim=half_dim, depth=4)
        trunc_normal_(self.cls_token.weight, std=.02)

    def forward(self, batch):
        """Encode a batch dict holding `mask`, `pos`, `time`, `auxiliary`, `qe`, `charge` and `ice_properties`."""
        mask = batch["mask"]  # (bs, seq_len)

        # Flat per-pulse feature matrix over all valid pulses (9 columns).
        node_columns = [
            batch["pos"][mask],
            batch['time'][mask].view(-1, 1),
            batch['auxiliary'][mask].view(-1, 1),
            batch['qe'][mask].view(-1, 1),
            batch['charge'][mask].view(-1, 1),
            batch["ice_properties"][mask],
        ]
        node_features = torch.concat(node_columns, dim=1)

        bs = mask.shape[0]
        # Trim the padding down to the longest event in the batch.
        mask = mask[:, :mask.sum(-1).max()]
        batch_index = mask.nonzero()[:, 0]
        # Initial kNN graph on the first three feature columns.
        edge_index = knn_graph(x=node_features[:, :3], k=8, batch=batch_index).to(mask.device)

        pulse_counts = mask.sum(-1)
        x = self.fe(batch, pulse_counts.max())

        # Graph branch, returned to a dense (bs, seq, feat) layout.
        node_features, _, _ = self.local_root(node_features, edge_index, batch_index, pulse_counts)
        node_features, mask = to_dense_batch(node_features, batch_index)

        # Sequence branch.
        global_features = self.global_root(x, mask)

        x = torch.cat([global_features, node_features], 2)
        cls_token = self.cls_token.weight.unsqueeze(0).expand(bs, -1, -1)
        x = torch.cat([cls_token, x], 1)
        # The class token is always a valid position.
        cls_mask = torch.ones(bs, 1, dtype=torch.bool, device=x.device)
        mask = torch.cat([cls_mask, mask], dim=1)
        return self.encoder(x, mask=mask)
    

class NMatrixAttention(nn.Module):
    """Multi-head attention in which every node attends only to a listed set of neighbours.

    A learned "null" key/value pair per head is prepended to each node's
    neighbour list and is never masked.
    """

    def __init__(
        self,
        *,
        dim,
        dim_head = 64,
        heads = 4,
        dropout = 0.1
    ):
        super().__init__()
        inner_dim = dim_head * heads
        self.scale = dim_head ** -0.5
        self.heads = heads

        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
        self.to_out = nn.Linear(inner_dim, dim)
        # `null_k` and `null_v` are learnable "null" key and value vectors,
        # one per attention head. They give every node a default slot to
        # attend to, which is always available even when all of its neighbour
        # slots are masked out (e.g. when adjacency information is missing or
        # sparse), so the softmax below always has at least one valid entry.
        self.null_k = nn.Parameter(torch.randn(heads, dim_head))
        self.null_v = nn.Parameter(torch.randn(heads, dim_head))

        self.dropout = nn.Dropout(dropout)

    def forward(
        self,
        x,
        adj_kv_indices,
        mask
    ):
        """Attend from each node to its selected neighbours.

        Args:
            x: Node features, `(b, n, d)`.
            adj_kv_indices: `(b, n, a)` indices of the `a` neighbour slots of
                each node along the node dimension.
            mask: `(b, n, a)` tensor; non-zero marks a valid neighbour slot.

        Returns:
            Tensor shaped like `x` after the output projection.
        """
        b, n, d, h = *x.shape, self.heads
        # Repeat the indices for every head and flatten the (node, slot) axes
        # so they can be used in a single gather.
        flat_indices = repeat(adj_kv_indices, 'b n a -> (b h) (n a)', h = h)
        # Split the fused projection into query, key and value tensors.
        q, k, v = self.to_qkv(x).chunk(3, dim = -1)
        # Move the head dimension out in front of the node dimension.
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, k, v))

        # Gather, for every node, the keys and values of its neighbour slots.
        k, v = map(lambda t: rearrange(t, 'b h n d -> (b h) n d'), (k, v))
        k = batched_index_select(k, flat_indices)
        v = batched_index_select(v, flat_indices)
        k, v = map(lambda t: rearrange(t, '(b h) (n a) d -> b h n a d', h = h, n = n), (k, v))

        # Prepend the null key/value as slot 0 of every node's neighbour list.
        nk, nv = map(lambda t: rearrange(t, 'h d -> () h () () d').expand(b, -1, n, 1, -1), (self.null_k, self.null_v))
        k = torch.cat((nk, k), dim = -2)
        v = torch.cat((nv, v), dim = -2)
        # Extend the mask by one leading, always-valid entry for the null slot.
        mask = F.pad(mask, (1, 0), value = 1)
        # Scaled similarity between each query and the keys of its slots.
        sim = einsum('b h n d, b h n a d -> b h n a', q, k) * self.scale

        # Invalid slots get the most negative finite value so that they
        # vanish after the softmax.
        mask_value = -torch.finfo(sim.dtype).max
        mask = rearrange(mask.bool(), 'b n a -> b () n a')
        sim.masked_fill_(~mask.bool(), mask_value)

        attn = sim.softmax(dim = -1)
        attn = self.dropout(attn)

        # Weighted sum of neighbour values, then merge the heads back.
        out = einsum('b h n a, b h n a d -> b h n d', attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')

        return self.to_out(out)
    
class LocalAttenNetwok(nn.Module):
    """Stack of attention layers restricted to each node's graph neighbourhood.

    Every layer is a residual, pre-normalised ``NMatrixAttention`` that attends
    only over the neighbours selected from the adjacency matrix passed to
    ``forward``.

    Args:
        dim: Feature dimension of the node embeddings.
        depth: Number of attention layers.
        dim_head: Per-head dimension forwarded to ``NMatrixAttention``.
        heads: Number of attention heads.
        num_neighbors_cutoff: Optional upper bound on the number of neighbours
            kept per node; ``None`` keeps as many as the densest row has.
        attn_dropout: Dropout rate forwarded to ``NMatrixAttention``.
    """

    def __init__(
        self,
        dim,
        depth,
        dim_head = 64,
        heads = 8,
        num_neighbors_cutoff = None,
        attn_dropout = 0.1,
    ):
        super().__init__()
        self.num_neighbors_cutoff = num_neighbors_cutoff
        self.layers = nn.ModuleList([])

        for _ in range(depth):
            # Second slot of each layer pair is an empty placeholder; it is kept
            # so the module layout (and hence parameter names) stays the same.
            global_attn = None
            self.layers.append(nn.ModuleList([
                Residual(PreNorm(dim, NMatrixAttention(
                    dim = dim,
                    dim_head = dim_head,
                    heads = heads,
                    dropout = attn_dropout
                ))),
                global_attn,
            ]))

    def forward(self, x, adjacency_mat, mask = None):
        """Run all local-attention layers over ``x``.

        Args:
            x: Node features; ``x.shape[1]`` is taken as the number of nodes.
            adjacency_mat: Boolean or integer adjacency tensor whose last
                dimension is the number of nodes. It is NOT modified.
            mask: Optional per-node validity mask, indexed as ``(b, n)``; when
                given, edges touching a masked-out node are removed.

        Returns:
            Node features after every attention layer has been applied.
        """
        device, n = x.device, x.shape[1]

        diag = torch.eye(adjacency_mat.shape[-1], device = device).bool()
        # Build the working adjacency out-of-place: the previous in-place
        # `|=` / `&=` silently rewrote the tensor owned by the caller.
        adjacency = adjacency_mat | diag  # every node is its own neighbour
        if exists(mask):
            adjacency = adjacency & (mask[:, :, None] * mask[:, None, :])

        adj_mat = adjacency.float()
        max_neighbors = int(adj_mat.sum(dim = -1).max())

        if exists(self.num_neighbors_cutoff) and max_neighbors > self.num_neighbors_cutoff:
            # Small uniform noise perturbs the 0/1 entries before top-k, so the
            # choice among equally weighted neighbours is randomised; the
            # selected values are then thresholded back to a 0/1 mask.
            noise = torch.empty((n, n), device = device).uniform_(-0.01, 0.01)
            adj_mat = adj_mat + noise
            adj_mask, adj_kv_indices = adj_mat.topk(dim = -1, k = self.num_neighbors_cutoff)
            adj_mask = (adj_mask > 0.5).float()
        else:
            adj_mask, adj_kv_indices = adj_mat.topk(dim = -1, k = max_neighbors)

        # Only the first module of each pair is used; the second is the placeholder.
        for attn, _ in self.layers:
            x = attn(
                x,
                adj_kv_indices = adj_kv_indices,
                mask = adj_mask
            )

        return x
    
    
    
class EncoderWithDirectionReconstructionV11_V2_LOCAL_GLOBAL(nn.Module):
    """Encoder that runs ``LocalGlobalAttention`` over a kNN graph before ``BeDeepIceModel``.

    Args:
        dim_out: Embedding width shared by the extractor, the attention stack
            and the encoder.
        drop_path: Value forwarded to ``BeDeepIceModel`` as ``drop_path``.
    """

    def __init__(self, dim_out=256 + 64, drop_path=0.):
        super().__init__()
        self.fe = ExtractorV0(dim=dim_out, dim_base=96)
        self.encoder = BeDeepIceModel(dim_out, drop_path=drop_path)
        # Only the (1, dim_out) weight of this bias-free Linear is read in
        # forward(); it acts as the learnable classification token.
        self.cls_token = nn.Linear(dim_out, 1, bias=False)
        self.loacl_attn = LocalGlobalAttention(
            dim=dim_out,
            depth=3,
            num_neighbors_cutoff=24,
        )
        trunc_normal_(self.cls_token.weight, std=.02)

    def forward(self, batch):
        """Encode one batch.

        Args:
            batch: Mapping that provides at least ``"mask"`` (bs, seq_len) and
                ``"pos"``; the whole mapping is also handed to the extractor.

        Returns:
            Output of ``self.encoder`` for the sequence with the classification
            token prepended at position 0.
        """
        valid = batch["mask"]  # bs, seq_len
        n_samples = valid.shape[0]

        # Positions of the valid entries only, taken with the untrimmed mask.
        coords = batch["pos"][valid]

        # Trim the mask to the largest per-row count of valid entries.
        valid = valid[:, :valid.sum(-1).max()]

        # Row index of every remaining valid entry, used as the graph batch vector.
        sample_ids = valid.nonzero()[:, 0]
        knn_edges = knn_graph(x=coords, k=8, batch=sample_ids).to(valid.device)
        dense_adj = to_dense_adj(knn_edges, sample_ids).int()

        tokens = self.fe(batch, valid.sum(-1).max())
        tokens = self.loacl_attn(tokens, dense_adj, valid)

        # Prepend the classification token and extend the mask to cover it.
        cls_rows = self.cls_token.weight.unsqueeze(0).expand(n_samples, -1, -1)
        tokens = torch.cat([cls_rows, tokens], 1)
        cls_valid = torch.ones(n_samples, 1, dtype=torch.bool, device=tokens.device)
        valid = torch.cat([cls_valid, valid], dim=1)

        return self.encoder(tokens, mask=valid)
    
class EncoderWithDirectionReconstructionV11(nn.Module):
    """Encoder that runs ``LocalAttenNetwok`` over a kNN graph before ``BeDeepIceModel``.

    Args:
        dim_out: Embedding width shared by the extractor, the local attention
            stack and the encoder.
        drop_path: Value forwarded to ``BeDeepIceModel`` as ``drop_path``.
    """

    def __init__(self, dim_out=256 + 64, drop_path=0.):
        super().__init__()
        self.fe = ExtractorV0(dim=dim_out, dim_base=96)
        self.encoder = BeDeepIceModel(dim_out, drop_path=drop_path)
        # Only the (1, dim_out) weight of this bias-free Linear is read in
        # forward(); it acts as the learnable classification token.
        self.cls_token = nn.Linear(dim_out, 1, bias=False)
        self.loacl_attn = LocalAttenNetwok(
            dim=dim_out,
            depth=3,
            num_neighbors_cutoff=24,
        )
        trunc_normal_(self.cls_token.weight, std=.02)

    def forward(self, batch):
        """Encode one batch.

        Args:
            batch: Mapping that provides at least ``"mask"`` (bs, seq_len) and
                ``"pos"``; the whole mapping is also handed to the extractor.

        Returns:
            Output of ``self.encoder`` for the sequence with the classification
            token prepended at position 0.
        """
        hit_mask = batch["mask"]  # bs, seq_len
        batch_size = hit_mask.shape[0]

        # Positions of the valid entries only, taken with the untrimmed mask.
        hit_pos = batch["pos"][hit_mask]

        # Trim the mask to the largest per-row count of valid entries.
        hit_mask = hit_mask[:, :hit_mask.sum(-1).max()]

        # Row index of every remaining valid entry, used as the graph batch vector.
        row_of_hit = hit_mask.nonzero()[:, 0]
        edges = knn_graph(x=hit_pos, k=8, batch=row_of_hit).to(hit_mask.device)
        adjacency = to_dense_adj(edges, row_of_hit).int()

        feats = self.fe(batch, hit_mask.sum(-1).max())
        feats = self.loacl_attn(feats, adjacency, hit_mask)

        # Prepend the classification token and extend the mask to cover it.
        cls_rows = self.cls_token.weight.unsqueeze(0).expand(batch_size, -1, -1)
        feats = torch.cat([cls_rows, feats], 1)
        cls_mask = torch.ones(batch_size, 1, dtype=torch.bool, device=feats.device)
        hit_mask = torch.cat([cls_mask, hit_mask], dim=1)

        return self.encoder(feats, mask=hit_mask)

In [7]:
# Build the V18 encoder and restore its weights from the saved checkpoint.
# map_location='cpu' makes the load independent of the device the checkpoint
# tensors were saved from; load_state_dict copies the values into the model's
# existing parameters, so the model's own device placement is unaffected.
md = EncoderWithDirectionReconstructionV18()
md.load_state_dict(
    torch.load(
        '/opt/slh/icecube/hb_training_loop/V18FT/models/model_6.pth',
        map_location='cpu',
    )
)

<All keys matched successfully>

In [8]:
# Instantiate both encoder variants with their default arguments. Only the
# last expression (the V11 model) is echoed as the cell output; the first
# instance is constructed and discarded.
EncoderWithDirectionReconstructionV11_V2_LOCAL_GLOBAL()
EncoderWithDirectionReconstructionV11()

EncoderWithDirectionReconstructionV11(
  (fe): ExtractorV0(
    (emb): SinusoidalPosEmb()
    (emb2): SinusoidalPosEmb()
    (aux_emb): Embedding(2, 48)
    (qe_emb): Embedding(2, 48)
    (proj): Linear(in_features=672, out_features=320, bias=True)
  )
  (encoder): BeDeepIceModel(
    (Beblocks): ModuleList(
      (0): BeBlock(
        (norm1): LayerNorm((320,), eps=1e-05, elementwise_affine=True)
        (attn): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=320, out_features=320, bias=True)
        )
        (drop_path): Identity()
        (norm2): LayerNorm((320,), eps=1e-05, elementwise_affine=True)
        (mlp): BeMLP(
          (fc1): Linear(in_features=320, out_features=1280, bias=True)
          (act): GELU()
          (fc2): Linear(in_features=1280, out_features=320, bias=True)
          (drop): Dropout(p=0.0, inplace=False)
        )
      )
      (1): BeBlock(
        (norm1): LayerNorm((320,), eps=1e-05, elementwise_affine=True)
     

In [1]:
import polars

In [2]:
# Show the installed polars version.
polars.__version__

'0.16.8'

In [3]:
import pyarrow

In [4]:
# Show the installed pyarrow version.
pyarrow.__version__

'11.0.0'