In [None]:
#|default_exp models.gMLP

# gMLP

This is an unofficial PyTorch implementation based on:

* Liu, H., Dai, Z., So, D. R., & Le, Q. V. (2021). <span style="color:dodgerblue">**Pay Attention to MLPs**</span>. arXiv preprint arXiv:2105.08050.

* Cholakov, R., & Kolev, T. (2022). <span style="color:dodgerblue">**The GatedTabTransformer. An enhanced deep learning architecture for tabular modeling**</span>. arXiv preprint arXiv:2201.00199.

In [None]:
#|export
from tsai.imports import *
from tsai.models.layers import *

In [None]:
#|export
class _SpatialGatingUnit(nn.Module):
    def __init__(self, d_ffn, seq_len):
        super().__init__()
        self.norm = nn.LayerNorm(d_ffn)
        self.spatial_proj = nn.Conv1d(seq_len, seq_len, kernel_size=1)
        nn.init.constant_(self.spatial_proj.bias, 1.0)
        nn.init.normal_(self.spatial_proj.weight, std=1e-6)

    def forward(self, x):
        u, v = x.chunk(2, dim=-1)
        v = self.norm(v)
        v = self.spatial_proj(v)
        out = u * v
        return out


class _gMLPBlock(nn.Module):
    """One pre-norm residual gMLP block.

    Computes: x + Linear_out(SGU(GELU(Linear_in(LayerNorm(x))))), where Linear_in
    expands d_model -> 2 * d_ffn and the SGU halves it back to d_ffn before
    Linear_out projects d_ffn -> d_model.

    Args:
        d_model: feature size of the input/output (last dimension).
        d_ffn: size of each half of the expanded hidden representation.
        seq_len: number of tokens along dimension 1, forwarded to the SGU.
    """

    def __init__(self, d_model, d_ffn, seq_len):
        super().__init__()
        # Registration order is kept (norm, proj1, proj2, sgu) so state_dict keys and RNG init order match.
        self.norm = nn.LayerNorm(d_model)
        self.channel_proj1 = nn.Linear(d_model, d_ffn * 2)
        self.channel_proj2 = nn.Linear(d_ffn, d_model)
        self.sgu = _SpatialGatingUnit(d_ffn, seq_len)

    def forward(self, x):
        hidden = F.gelu(self.channel_proj1(self.norm(x)))
        hidden = self.channel_proj2(self.sgu(hidden))
        # residual connection around the whole block
        return hidden + x


class _gMLPBackbone(nn.Module):
    """A sequential stack of `depth` identical gMLP blocks.

    Args:
        d_model: feature size (last dimension) carried through every block.
        d_ffn: size of each half of the hidden representation inside a block.
        seq_len: number of tokens along dimension 1 that each block's SGU mixes.
        depth: number of blocks in the stack.
    """

    def __init__(self, d_model=256, d_ffn=512, seq_len=256, depth=6):
        super().__init__()
        layers = []
        for _ in range(depth):
            layers.append(_gMLPBlock(d_model, d_ffn, seq_len))
        # Stored as `model` (nn.Sequential) so subclasses and state_dict keys can rely on it.
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        out = self.model(x)
        return out


class gMLP(_gMLPBackbone):
    """gMLP model for sequence inputs with optional patching.

    A strided Conv1d splits the input into non-overlapping patches of length
    `patch_size` and embeds each one into `d_model` features. The resulting token
    sequence goes through `depth` gMLP blocks. The block outputs are averaged over
    tokens and a linear head maps the average to `c_out` outputs.

    Args:
        c_in: number of input channels.
        c_out: number of outputs produced by the linear head.
        seq_len: input sequence length; must be divisible by `patch_size`.
        patch_size: kernel size and stride of the patching convolution; the
            backbone sees `seq_len // patch_size` tokens.
        d_model: embedding size of each token.
        d_ffn: size of each half of the hidden representation inside a block.
        depth: number of stacked gMLP blocks.

    Raises:
        AssertionError: if `seq_len` is not divisible by `patch_size`.
    """

    def __init__(
        self,
        c_in,
        c_out,
        seq_len,
        patch_size=1,
        d_model=256,
        d_ffn=512,
        depth=6,
    ):
        assert seq_len % patch_size == 0, "`seq_len` must be divisible by `patch_size`"
        super().__init__(d_model, d_ffn, seq_len // patch_size, depth)
        self.patcher = nn.Conv1d(
            c_in, d_model, kernel_size=patch_size, stride=patch_size
        )
        self.head = nn.Linear(d_model, c_out)

    def forward(self, x):
        # assumes x is (batch, c_in, seq_len) -> patches: (batch, d_model, n_patches)
        patches = self.patcher(x)
        # Each block's SGU mixes along dim 1 and its LayerNorms act on the last dim,
        # so put tokens on dim 1 and features last: (batch, n_patches, d_model).
        tokens = patches.permute(0, 2, 1)
        embedding = self.model(tokens)
        # global average pooling over the token axis -> (batch, d_model)
        return self.head(embedding.mean(dim=1))

In [None]:
# Smoke test: a (bs, c_in, seq_len) batch should be mapped to (bs, c_out) outputs.
bs, c_in, c_out = 16, 3, 2
seq_len, patch_size = 64, 4
xb = torch.rand(bs, c_in, seq_len)
model = gMLP(c_in, c_out, seq_len, patch_size=patch_size)
test_eq(model(xb).shape, (bs, c_out))

In [None]:
#|eval: false
#|hide
# Regenerate the library module from this notebook (nbdev export); not run during tests.
from tsai.export import get_nb_name
nb_name = get_nb_name(locals())
from tsai.imports import create_scripts
create_scripts(nb_name)

<IPython.core.display.Javascript object>

/Users/nacho/notebooks/tsai/nbs/103d_models.gMLP.ipynb saved at 2022-11-09 13:01:47
Correct notebook to script conversion! 😃
Wednesday 09/11/22 13:01:49 CET
