In [1]:
%load_ext autoreload
%autoreload 2

In [26]:
import gc
from pathlib import Path
import numpy as np
import yaml

import torch
import torch.nn as nn
import torch.nn.functional as F

from torchsummary import summary
from pytorch_lightning.trainer import seed_everything

from turboflow.models.phyrff import plDivFreeRFFNet
from turboflow.dataloaders import TurboFlowDataModule
from turboflow.utils import phy_utils as phy
from turboflow.utils import torch_utils as tch

import matplotlib.pyplot as plt

In [38]:
class SpatialGatingUnit(nn.Module):
    def __init__(self, d_ffn, seq_len):
        super().__init__()
        self.norm = nn.LayerNorm(d_ffn)
        self.spatial_proj = nn.Conv1d(seq_len, seq_len, kernel_size=1)
        nn.init.constant_(self.spatial_proj.bias, 1.0)

    def forward(self, x):
        u, v = x.chunk(2, dim=-1)
        v = self.norm(v)
        v = self.spatial_proj(v).squeeze()
        out = u * v
        return out


class gMLPBlock(nn.Module):
    def __init__(self, d_model, d_ffn, seq_len):
        super().__init__()
        self.norm = nn.LayerNorm(d_model, 1)
        self.channel_proj1 = nn.Linear(d_model, d_ffn * 2)
        self.channel_proj2 = nn.Linear(d_ffn, d_model)
        self.sgu = SpatialGatingUnit(d_ffn, seq_len)

    def forward(self, x):
        residual = x
        x = self.norm(x) # norm axis = channel
        x = F.gelu(self.channel_proj1(x))
        x = self.sgu(x)
        x = self.channel_proj2(x)
        out = x + residual
        return out

class gMLP(nn.Module):
    def __init__(self, d_model=256, d_ffn=512, seq_len=256, num_layers=6):
        super().__init__()
        self.model = nn.Sequential(
            *[gMLPBlock(d_model, d_ffn, seq_len) for _ in range(num_layers)]
        )

    def forward(self, x):
        return self.model(x)
    
model = gMLP(d_model=256, d_ffn=512, seq_len=128, num_layers=1)
print(model)

x = torch.rand(32,128,256)
print(model(x).shape)


gMLP(
  (model): Sequential(
    (0): gMLPBlock(
      (norm): LayerNorm((256,), eps=1, elementwise_affine=True)
      (channel_proj1): Linear(in_features=256, out_features=1024, bias=True)
      (channel_proj2): Linear(in_features=512, out_features=256, bias=True)
      (sgu): SpatialGatingUnit(
        (norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        (spatial_proj): Conv1d(128, 128, kernel_size=(1,), stride=(1,))
      )
    )
  )
)
tensor(0.4969)
tensor(0.4927)
tensor(-0.0046, grad_fn=<MeanBackward0>)
tensor(-8.8476e-09, grad_fn=<MeanBackward0>)
torch.Size([32, 128, 256])
