In [1]:
#|default_exp models

In [9]:
#| export
import torch
from x_transformers import ContinuousTransformerWrapper, Encoder, Decoder
from torch import nn

In [10]:
#| export
class LogCoshLoss(nn.Module):
    """Mean log-cosh loss between a target and a prediction tensor.

    log(cosh(e)) is ~e**2 / 2 for small errors and ~|e| - log(2) for large
    ones, giving a smooth regression loss that grows only linearly on outliers.
    """

    # log(2), used by the overflow-free formulation in `forward`.
    _LOG_2 = 0.6931471805599453

    def __init__(self):
        super().__init__()

    def forward(self, y_t, y_prime_t):
        """Return mean(log(cosh(y_t - y_prime_t))) over all elements.

        Evaluated as |e| + softplus(-2|e|) - log(2), which equals log(cosh(e))
        exactly but never calls cosh; the naive form overflows to inf for
        large |e| (cosh exceeds the float32 range around |e| ~ 89).
        """
        abs_err = (y_t - y_prime_t).abs()
        # softplus's argument is always <= 0 here, so it stays in its exact branch.
        per_elem = abs_err + nn.functional.softplus(-2.0 * abs_err) - self._LOG_2
        return torch.mean(per_elem)



class MeanPoolingWithMask(nn.Module):
    """Average a padded sequence over its non-padded positions only.

    The broadcasting in `forward` is written for ``x`` of shape (batch, seq, dim)
    and ``mask`` of shape (batch, seq) — assumed to be boolean (or 0/1), with
    True marking real positions; confirm against callers.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x, mask):
        """Return the masked mean of ``x`` over dim 1.

        A sample whose mask is entirely False pools to zeros instead of the
        NaN that 0 / 0 would produce.
        """
        # Zero out padded positions so they contribute nothing to the sum.
        summed = torch.sum(x * mask.unsqueeze(-1), dim=1)

        # Number of real positions per sample (keepdim so it broadcasts over dim).
        counts = torch.sum(mask, dim=1, keepdim=True)
        # Only empty rows are touched: their sum is already 0, so 0 / 1 = 0.
        counts = counts.masked_fill(counts == 0, 1)

        return summed / counts

class FeedForward(nn.Module):
    """Two-layer MLP: Linear(dim -> dim*mult) -> GELU -> Linear(dim*mult -> dim_out).

    Args:
        dim: size of the last axis of the input.
        dim_out: size of the last axis of the output; defaults to ``dim``.
        mult: width multiplier for the hidden layer.
    """

    def __init__(self, dim, dim_out = None, mult = 4):
        super().__init__()
        # The signature advertises dim_out=None; without this fallback the
        # second Linear would be built with out_features=None and raise.
        dim_out = dim if dim_out is None else dim_out
        hidden = dim * mult
        # Layout (and therefore state_dict keys net.0 / net.2) is unchanged.
        self.net = nn.Sequential(
            nn.Linear(dim, hidden),
            nn.GELU(),
            nn.Linear(hidden, dim_out)
        )

    def forward(self, x):
        return self.net(x)


class IceCubeModelEncoderV0(nn.Module):
    """Transformer encoder + unmasked mean pooling + MLP head.

    NOTE: pooling is a plain mean over dim 1, so positions the mask marks as
    padding are included in the average. IceCubeModelEncoderV1 pools with the
    mask instead.

    Args:
        dim_in: features per input sequence element.
        dim: transformer width; also the pooled embedding size fed to the head.
        max_seq_len: longest sequence the encoder's positional embedding supports.
        depth: number of encoder layers.
        heads: attention heads per layer.
        num_outputs: size of the head's output vector.
    Defaults reproduce the original hard-coded configuration.
    """

    def __init__(self, dim_in=6, dim=128, max_seq_len=150, depth=3, heads=8, num_outputs=2):
        # The first parameter was previously misspelled `slf`, so every `self.`
        # below raised NameError and the class could not be instantiated.
        super().__init__()
        self.encoder = ContinuousTransformerWrapper(
            dim_in=dim_in,
            dim_out=dim,
            max_seq_len=max_seq_len,
            attn_layers=Encoder(dim=dim,
                        depth=depth,
                        heads=heads),
        )
        self.head = FeedForward(dim, num_outputs)

    def forward(self, x, mask):
        """Encode ``x`` (attention restricted by ``mask``), mean-pool, apply the head.

        Presumably x is (batch, seq, dim_in) and mask is (batch, seq) bool — TODO confirm.
        """
        x = self.encoder(x, mask = mask)
        # Unmasked mean: padded positions are averaged in as well.
        x = x.mean(dim=1)
        x = self.head(x)
        return x


class IceCubeModelEncoderV1(nn.Module):
    """Transformer encoder + mask-aware mean pooling + MLP head.

    Pooling uses MeanPoolingWithMask, so only positions where ``mask`` is set
    contribute to the pooled embedding.

    Args:
        dim_in: features per input sequence element.
        dim: transformer width; also the pooled embedding size fed to the head.
        max_seq_len: longest sequence the encoder's positional embedding supports.
        depth: number of encoder layers.
        heads: attention heads per layer.
        num_outputs: size of the head's output vector.
    Defaults reproduce the original hard-coded configuration.
    """

    def __init__(self, dim_in=6, dim=128, max_seq_len=150, depth=3, heads=8, num_outputs=2):
        super().__init__()
        self.encoder = ContinuousTransformerWrapper(
            dim_in=dim_in,
            dim_out=dim,
            max_seq_len=max_seq_len,
            attn_layers=Encoder(dim=dim,
                        depth=depth,
                        heads=heads),
        )

        self.pool = MeanPoolingWithMask()
        self.head = FeedForward(dim, num_outputs)

    def forward(self, x, mask):
        """Encode ``x`` (attention restricted by ``mask``), masked-mean-pool, apply the head.

        Presumably x is (batch, seq, dim_in) and mask is (batch, seq) bool with
        True for real positions — TODO confirm against the data pipeline.
        """
        x = self.encoder(x, mask = mask)
        # Same mask drives both attention and pooling, so padding is ignored twice over.
        x = self.pool(x, mask)
        x = self.head(x)
        return x

In [11]:
# Build the default model and display its module tree.
model_v1 = IceCubeModelEncoderV1()
model_v1

IceCubeModelEncoderV1(
  (encoder): ContinuousTransformerWrapper(
    (pos_emb): AbsolutePositionalEmbedding(
      (emb): Embedding(150, 128)
    )
    (post_emb_norm): Identity()
    (emb_dropout): Dropout(p=0.0, inplace=False)
    (project_in): Linear(in_features=6, out_features=128, bias=True)
    (attn_layers): Encoder(
      (layers): ModuleList(
        (0): ModuleList(
          (0): ModuleList(
            (0): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
            (1): None
            (2): None
          )
          (1): Attention(
            (to_q): Linear(in_features=128, out_features=512, bias=False)
            (to_k): Linear(in_features=128, out_features=512, bias=False)
            (to_v): Linear(in_features=128, out_features=512, bias=False)
            (dropout): Dropout(p=0.0, inplace=False)
            (to_out): Linear(in_features=512, out_features=128, bias=False)
          )
          (2): Residual()
        )
        (1): ModuleList(
          (0): M

In [4]:
# Smoke test: one forward pass on a random, partially padded batch.
# Sequence length must not exceed the encoder's max_seq_len (150): its absolute
# positional embedding is Embedding(150, 128), so the earlier draft's 1024 would fail.
demo_model = IceCubeModelEncoderV1()
demo_x = torch.rand(2, 150, 6)
demo_mask = torch.ones(2, 150, dtype=torch.bool)
demo_mask[1, 100:] = False  # pad the tail of the second sample
demo_out = demo_model(demo_x, mask=demo_mask)
assert demo_out.shape == (2, 2)

In [5]:
#|hide
#|eval: false
# Export this project's `#| export` cells to Python modules via nbdev
# (this notebook's target module is set by `#|default_exp models`).
import nbdev.doclinks as doclinks
doclinks.nbdev_export()