In [1]:
import torch
import torch.nn as nn
from torchsummary import summary

## Dynamic Tanh

In [2]:
class DyT(nn.Module):
    def __init__(self, dims, init_alpha=0.5):
        super().__init__()
        self.alpha = nn.Parameter(torch.ones(1)) * init_alpha
        self.gamma = nn.Parameter(torch.ones(dims))
        self.beta = nn.Parameter(torch.zeros(dims))

    def forward(self, x):
        return self.gamma * torch.tanh(self.alpha * x) + self.beta

In [3]:
class DyT_wrapper(nn.Module):
    def __init__(self, dims, init_alpha=0.5):
        super().__init__()
        self.dyt = DyT(dims, init_alpha)

    def forward(self, x):
        x = torch.movedim(x, 1,-1)
        x = self.dyt(x)
        x = torch.movedim(x, -1, 1)
        return x

## ConvNeXt

In [4]:
class ConvNeXtBlock2D(nn.Module):
  def __init__(self, dim, layer_scale_init_value=1e-6, drop=0.2, **kwargs):
    super().__init__()

    self.dwconv = nn.Conv2d(dim, dim, kernel_size=7, padding=3, groups=dim)  # depthwise conv

    # self.norm = nn.LayerNorm(dim)
    self.norm = DyT(dim)

    self.pwconv1 = nn.Linear(dim, 4 * dim)
    self.act = nn.GELU()
    self.pwconv2 = nn.Linear(4 * dim, dim)

    self.dropout = nn.Dropout2d(p=drop)

    self.gamma = nn.Parameter(layer_scale_init_value * torch.ones((dim,)), requires_grad=True) if layer_scale_init_value > 0 else None


  def forward(self, x):
    residual = x
    x = self.dwconv(x)

    # Transpose for LayerNorm
    x = x.permute(0, 2, 3, 1)
    x = self.norm(x)

    x = self.pwconv1(x)
    x = self.act(x)
    x = self.pwconv2(x)

    x = self.dropout(x)

    if self.gamma is not None:
        x = self.gamma * x

    # Transpose back to (B, C, H, W)
    x = x.permute(0, 3, 1, 2)
    # no drop path y
    return residual + x

In [5]:
block = ConvNeXtBlock2D(96)
summary(block, (96,112,112), device='cpu')

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1         [-1, 96, 112, 112]           4,800
               DyT-2         [-1, 112, 112, 96]               0
            Linear-3        [-1, 112, 112, 384]          37,248
              GELU-4        [-1, 112, 112, 384]               0
            Linear-5         [-1, 112, 112, 96]          36,960
         Dropout2d-6         [-1, 112, 112, 96]               0
Total params: 79,008
Trainable params: 79,008
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 4.59
Forward/backward pass size (MB): 110.25
Params size (MB): 0.30
Estimated Total Size (MB): 115.15
----------------------------------------------------------------


In [6]:
class ConvNext(nn.Module):
    def __init__(self, in_chans=1, dims=[32, 64, 128, 256], stages=[1, 1, 3, 1]):
        super().__init__()

        self.in_chans = in_chans
        self.dims = dims
        self.stages = stages

        self.downsample_layers = nn.ModuleList()  # stem and 3 intermediate downsampling conv layers
        stem = nn.Sequential(
            nn.Conv2d(in_chans, dims[0], kernel_size=4, stride=4),
            # myLayerNorm(dims[0], eps=1e-6)
            DyT_wrapper(dims[0])
        )
        self.downsample_layers.append(stem)
        for i in range(3):
            downsample_layer = nn.Sequential(
                # myLayerNorm(dims[i], eps=1e-6),
                DyT_wrapper(dims[i]),
                nn.Conv2d(dims[i], dims[i + 1], kernel_size=2, stride=2),
            )
            self.downsample_layers.append(downsample_layer)

        self.model_layers = nn.ModuleList()
        for i, stage_length in enumerate(stages):
            stage = nn.ModuleList([ConvNeXtBlock2D(dims[i]) for _ in range(stage_length)])
            self.model_layers.append(stage)

        self.pooling = nn.AdaptiveAvgPool2d((1, 1))
        self.flatten = nn.Flatten()
        # self.final_norm = nn.LayerNorm(dims[-1])
        self.final_norm = DyT_wrapper(dims[-1])

    def forward(self, x):
        for i in range(len(self.dims)):
            x = self.downsample_layers[i](x)
            for layer in self.model_layers[i]:
                x = layer(x)

        x = self.pooling(x)
        x = self.flatten(x)
        x = self.final_norm(x)  # Final normalization
        return x

In [7]:
summary(ConvNext(in_chans=3,dims=[96,192,384,768],stages=[3,3,9,3]), (3,224,224), device="cpu")

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1           [-1, 96, 56, 56]           4,704
               DyT-2           [-1, 56, 56, 96]               0
       DyT_wrapper-3           [-1, 96, 56, 56]               0
            Conv2d-4           [-1, 96, 56, 56]           4,800
               DyT-5           [-1, 56, 56, 96]               0
            Linear-6          [-1, 56, 56, 384]          37,248
              GELU-7          [-1, 56, 56, 384]               0
            Linear-8           [-1, 56, 56, 96]          36,960
         Dropout2d-9           [-1, 56, 56, 96]               0
  ConvNeXtBlock2D-10           [-1, 96, 56, 56]               0
           Conv2d-11           [-1, 96, 56, 56]           4,800
              DyT-12           [-1, 56, 56, 96]               0
           Linear-13          [-1, 56, 56, 384]          37,248
             GELU-14          [-1, 56, 

## CSPNet

In [8]:
class CSPStage(nn.Module):
    def __init__(self, in_channels, num_blocks, **kwargs):
        super().__init__()

        ## base layer moved to downsampling layer
        self.csp_channels = in_channels // 2
        self.bottlenecks = nn.Sequential(
            *[ConvNeXtBlock2D(in_channels // 2, **kwargs) for _ in range(num_blocks)]
        )
        self.transition_layer = nn.Conv2d(in_channels, in_channels, kernel_size=1,stride=1)

    def forward(self, x):
        ## base layer is actually just a downsample layer... follow the convnext downsampling technique
        x1, x2 = x[:, self.csp_channels:,:,:], x[:, :self.csp_channels,:,:]
        x2 = self.bottlenecks(x2)
        x = torch.cat([x1,x2], dim=1)
        x = self.transition_layer(x)
        return x        

In [14]:
cspStage = CSPStage(96,1)
summary(cspStage, (96,224,224), device='cpu')

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1         [-1, 48, 224, 224]           2,400
               DyT-2         [-1, 224, 224, 48]               0
            Linear-3        [-1, 224, 224, 192]           9,408
              GELU-4        [-1, 224, 224, 192]               0
            Linear-5         [-1, 224, 224, 48]           9,264
         Dropout2d-6         [-1, 224, 224, 48]               0
   ConvNeXtBlock2D-7         [-1, 48, 224, 224]               0
            Conv2d-8         [-1, 96, 224, 224]           9,312
Total params: 30,384
Trainable params: 30,384
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 18.38
Forward/backward pass size (MB): 275.62
Params size (MB): 0.12
Estimated Total Size (MB): 294.12
----------------------------------------------------------------


In [20]:
class CSPConvNext(nn.Module):
    def __init__(self, in_chans=1, dims=[32, 64, 128, 256], stages=[1, 1, 3, 1]):
        super().__init__()

        self.in_chans = in_chans
        self.dims = dims
        self.stages = stages

        self.downsample_layers = nn.ModuleList()  # stem and 3 intermediate downsampling conv layers
        stem = nn.Sequential(
            nn.Conv2d(in_chans, dims[0], kernel_size=4, stride=4),
            # myLayerNorm(dims[0], eps=1e-6)
            DyT_wrapper(dims[0])
        )
        self.downsample_layers.append(stem)
        for i in range(3):
            downsample_layer = nn.Sequential(
                # myLayerNorm(dims[i], eps=1e-6),
                DyT_wrapper(dims[i]),
                nn.Conv2d(dims[i], dims[i + 1], kernel_size=2, stride=2),
            )
            self.downsample_layers.append(downsample_layer)

        self.model_layers = nn.ModuleList()
        for i, stage_length in enumerate(stages):
            # stage = nn.ModuleList([ConvNeXtBlock2D(dims[i]) for _ in range(stage_length)])
            stage = CSPStage(dims[i], stage_length)
            self.model_layers.append(stage)

    def forward(self, x):
        outputs = []
        for i in range(len(self.dims)):
            x = self.downsample_layers[i](x)
            x = self.model_layers[i](x)
            outputs.append(x)

        return outputs

In [21]:
summary(CSPConvNext(in_chans=3,dims=[96,192,384,768],stages=[3,3,9,3]), (3,224,224), device="cpu")

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1           [-1, 96, 56, 56]           4,704
               DyT-2           [-1, 56, 56, 96]               0
       DyT_wrapper-3           [-1, 96, 56, 56]               0
            Conv2d-4           [-1, 48, 56, 56]           2,400
               DyT-5           [-1, 56, 56, 48]               0
            Linear-6          [-1, 56, 56, 192]           9,408
              GELU-7          [-1, 56, 56, 192]               0
            Linear-8           [-1, 56, 56, 48]           9,264
         Dropout2d-9           [-1, 56, 56, 48]               0
  ConvNeXtBlock2D-10           [-1, 48, 56, 56]               0
           Conv2d-11           [-1, 48, 56, 56]           2,400
              DyT-12           [-1, 56, 56, 48]               0
           Linear-13          [-1, 56, 56, 192]           9,408
             GELU-14          [-1, 56, 