In [169]:
from __future__ import annotations
import torch
from torch import nn
from mamba_ssm import Mamba
from timm.models.layers import DropPath, trunc_normal_
import math
from transformers import SegformerForSemanticSegmentation
from torchinfo import summary

# Layer Norm

In [3]:
class LayerNorm(nn.Module):
    r"""
        Layer normalisation over the channel axis, supporting two layouts
        (channels as the last dimension or as the first non-batch dimension).
        channels_last : B, H, W, C (default)
        channels_first : B, C, H, W
    """

    _SUPPORTED_FORMATS = ("channels_last", "channels_first")

    def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"):
        super().__init__()
        # Learnable per-channel scale and shift.
        self.weight = nn.Parameter(torch.ones(normalized_shape))
        self.bias = nn.Parameter(torch.zeros(normalized_shape))
        self.eps = eps
        self.data_format = data_format
        if data_format not in self._SUPPORTED_FORMATS:
            raise NotImplementedError
        # F.layer_norm expects the normalised shape as a tuple.
        self.normalized_shape = (normalized_shape, )

    def forward(self, x):
        if self.data_format == "channels_first":
            # Manual normalisation over dim 1 (the channel axis of B, C, H, W).
            mean = x.mean(1, keepdim=True)
            centered = x - mean
            variance = centered.pow(2).mean(1, keepdim=True)
            normed = centered / torch.sqrt(variance + self.eps)
            # Broadcast the per-channel affine parameters over H and W.
            return self.weight[:, None, None] * normed + self.bias[:, None, None]
        if self.data_format == "channels_last":
            return nn.functional.layer_norm(
                x, self.normalized_shape, self.weight, self.bias, self.eps
            )

In [4]:
# Random channels-last test tensor of shape (batch, height, width, channels).
B, C = 1, 3
H = W = 28
x = torch.rand(B, H, W, C)

In [5]:

# With the default channels_last layout, LayerNorm normalises each pixel over its
# C channel values (the last axis), not each channel over the H x W grid.
lnorm = LayerNorm(normalized_shape=C)
x_norm  = lnorm(x)
# The statistics printed below are taken per channel across all pixels, so they are
# only approximately 0 (mean) and 1 (variance) rather than exact.
for i in range(C):
    print(f"C={i}", "mean:", x_norm[..., i].mean().item(), "var:", x_norm[..., i].std().item()**2)

C=0 mean: -0.004667768720537424 var: 1.000094654415662
C=1 mean: 0.006295911036431789 var: 1.0178018509495956
C=2 mean: -0.0016281544230878353 var: 0.9855169191951063


# Depthwise conv (DWConv)

In [9]:
class DWConv(nn.Module):
    """Depthwise 3x3x3 convolution applied to a flattened token sequence.

    The tokens are folded back into a (B, C, nf, H, W) volume, convolved with one
    filter per channel (groups=dim, stride 1, padding 1, so the size is preserved),
    and flattened to (B, N, C) again.
    """

    def __init__(self, dim=768):
        super().__init__()
        self.dwconv = nn.Conv3d(
            dim, dim, kernel_size=3, stride=1, padding=1, groups=dim, bias=True
        )

    def forward(self, x, nf, H, W):
        """Convolve `x` of shape (B, N, C), where N must equal nf * H * W."""
        batch, _, channels = x.shape
        # (B, N, C) -> (B, C, N) -> (B, C, nf, H, W)
        volume = x.permute(0, 2, 1).view(batch, channels, nf, H, W)
        volume = self.dwconv(volume)
        # (B, C, nf, H, W) -> (B, C, N) -> (B, N, C)
        return volume.flatten(2).permute(0, 2, 1)

In [69]:
# Shape round-trip check for DWConv on a small random video volume.
C = 3
nf = 5
H, W = 224,224
B = 1
x = torch.randn(B, C, nf, H, W)
# Flatten frames and spatial dims into tokens: (B, C, nf, H, W) -> (B, N, C), N = nf*H*W.
x = x.flatten(2).transpose(1,2)

dwconv = DWConv(dim=C)
out = dwconv(x, nf, H, W)
out.shape # B, N, C

# Fold the tokens back into a volume; only this last expression is displayed by the cell.
y = out.transpose(1,2).view(B, C, nf, H, W)
y.shape # B, C, nf, H, W

torch.Size([1, 3, 5, 224, 224])

# MLP

In [83]:
class Mlp(nn.Module):
    """Token MLP with a depthwise 3D convolution between the two linear layers.

    Pipeline: fc1 -> DWConv -> activation -> dropout -> fc2 -> dropout.

    Args:
        in_features: channel size of the incoming tokens.
        hidden_features: width of the hidden layer (defaults to `in_features`).
        out_features: channel size of the output (defaults to `in_features`).
        act_layer: activation module class, instantiated without arguments.
        drop: dropout probability used after the activation and after fc2.
    """

    def __init__(self, in_features, hidden_features=None, out_features=None,
                 act_layer=nn.GELU, drop=0.):
        super(Mlp, self).__init__()
        hidden_features = hidden_features or in_features
        out_features = out_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.dwconv = DWConv(hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)
        # Re-initialise every submodule created above.
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Initialise one submodule; called for each module by `self.apply`."""
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, (nn.Conv2d, nn.Conv3d)):
            # The only convolution in this module is the Conv3d inside DWConv, so
            # matching Conv2d alone left this branch dead. Fan-out uses every kernel
            # dimension, which reduces to the previous formula for Conv2d.
            kernel_elems = 1
            for k in m.kernel_size:
                kernel_elems *= k
            fan_out = kernel_elems * m.out_channels
            fan_out //= m.groups
            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
            if m.bias is not None:
                m.bias.data.zero_()

    def forward(self, x, nf, H, W):
        """Apply the MLP to tokens `x` of shape (B, N, in_features), N = nf * H * W.

        `nf`, `H` and `W` are forwarded to DWConv so it can fold the tokens back
        into a volume. Returns a tensor of shape (B, N, out_features).
        """
        x = self.fc1(x)
        x = self.dwconv(x, nf, H, W)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x

In [90]:
# Shape check for Mlp: tokens with C channels are projected to 10 output channels.
C = 3
nf = 5
H, W = 24,24
B = 1
N = nf*H*W  # token count; not used again in this cell
x = torch.randn(B, C, nf, H, W)
x = x.flatten(2).transpose(1,2) #B, N , C
mlp = Mlp(in_features=C, out_features=10)
mlp(x, nf, H, W).shape

torch.Size([1, 2880, 10])

# Mamba Layer


In [125]:

class MambaLayer(nn.Module):
    """Pre-norm residual block: a Mamba mixer followed by a depthwise-conv MLP.

    The input volume (B, C, nf, H, W) is flattened into a token sequence
    (B, nf*H*W, C), passed through both residual branches, and reshaped back.

    Args:
        dim: channel size C of the input; checked in `forward`.
        d_state: SSM state expansion factor passed to `Mamba`.
        d_conv: local convolution width passed to `Mamba`.
        expand: block expansion factor passed to `Mamba`.
        mlp_ratio: hidden width of the MLP as a multiple of `dim`.
        drop: dropout probability inside the MLP.
        drop_path: stochastic-depth rate applied to both residual branches.
        act_layer: activation class used inside the MLP.
    """

    def __init__(
        self,
        dim,
        d_state=16, d_conv=4, expand=2, mlp_ratio=4, drop=0., drop_path=0., act_layer=nn.GELU
    ):
        super().__init__()
        self.dim = dim
        self.norm1 = nn.LayerNorm(dim)
        self.mamba = Mamba(
            d_model=dim,  # model (channel) dimension
            d_state=d_state,  # SSM state expansion factor
            d_conv=d_conv,  # local convolution width (was accepted but never forwarded)
            expand=expand,  # block expansion factor
        )
        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
        self.norm2 = nn.LayerNorm(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)

    def forward(self, x):
        """Process `x` of shape (B, C, nf, H, W) and return the same shape."""
        B, C, nf, H, W = x.shape
        assert C == self.dim
        img_dims = x.shape[2:]
        n_tokens = img_dims.numel()  # nf * H * W
        x_flat = x.reshape(B, C, n_tokens).transpose(-1, -2)  # B, C, N -> B, N, C
        # Residual branch 1: sequence mixing with Mamba.
        x_mamba = x_flat + self.drop_path(self.mamba(self.norm1(x_flat)))
        # Residual branch 2: MLP, which needs nf/H/W to rebuild the volume for its conv.
        x_mamba = x_mamba + self.drop_path(self.mlp(self.norm2(x_mamba), nf, H, W))
        out = x_mamba.transpose(-1, -2).reshape(B, C, *img_dims)
        return out
       

In [126]:
 
# Shape check for MambaLayer on a (B, C, nf, H, W) volume; the layer and the input
# are both moved to the GPU with .cuda().
x = torch.randn(1, 3, 5, 28,28)
dim = 3
mal = MambaLayer(dim=dim).cuda()
mal(x.cuda()).shape

torch.Size([1, 3, 5, 28, 28])

# Mamba Block


In [228]:
class mamba_block(nn.Module):
    """Encoder that interleaves a SegFormer backbone's stages with Mamba stages.

    For each stage the frames are patch-embedded and passed through the backbone's
    blocks individually, then regrouped per clip so a stack of `MambaLayer`s can
    mix tokens across frames and space.

    Args:
        backbone: model exposing `segformer.encoder` with `patch_embeddings` and
            `block` sequences (one entry per stage).
        in_chans: accepted for interface compatibility; not used here.
        depths: number of MambaLayers in each stage.
        dims: channel size of each stage; must match the backbone's stage outputs,
            because MambaLayer asserts on it.
        drop_path_rate: maximum stochastic-depth rate; rates rise linearly from 0
            to this value across all MambaLayers.
        layer_scale_init_value: accepted for interface compatibility; not used here.
        out_indices: stored on the instance; not used by `forward_features`.
    """

    def __init__(
        self,
        backbone,
        in_chans=1,
        depths=[2, 2, 2, 2],
        dims=[64, 128, 320, 512],
        drop_path_rate=0.0,
        layer_scale_init_value=1e-6,
        out_indices=[0, 1, 2, 3],
    ):
        super().__init__()

        # NOTE(review): only `patch_embeddings` and `block` of this encoder are used
        # below; if the encoder also defines a per-stage norm it is skipped — confirm
        # this is intended.
        self.downsample_layers = backbone.segformer.encoder

        self.stages = nn.ModuleList()
        dp_rates = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
        cur = 0  # index of the first block of the current stage within dp_rates
        for i in range(len(dims)):
            stage = nn.Sequential(
                *[
                    # Each block takes its own rate from the schedule. Indexing with
                    # the stage index `i` gave every block in a stage the same rate
                    # and left most of the schedule unused. The single-item Sequential
                    # wrapper is kept so the module hierarchy is unchanged.
                    nn.Sequential(MambaLayer(dim=dims[i], drop_path=dp_rates[cur + j]))
                    for j in range(depths[i])
                ]
            )
            self.stages.append(stage)
            cur += depths[i]

        self.out_indices = out_indices

    def forward_features(self, x):
        """Encode a clip batch of shape (B, nf, C, H, W).

        Returns a tuple with one feature map per stage, each of shape
        (B * nf, C_stage, H_stage, W_stage).
        """
        outs = []
        B, nf = x.shape[:2]
        # Fold frames into the batch axis so the 2D backbone sees single images.
        x = x.reshape(B * nf, x.shape[-3], x.shape[-2], x.shape[-1])  # B*nf, C, H, W

        layers = [
            self.downsample_layers.patch_embeddings,
            self.downsample_layers.block,
            self.stages,
        ]

        for embedding_layer, block_layer, mam_stage in zip(*layers):
            # first, obtain patch embeddings
            x, height, width = embedding_layer(x)

            # second, send embeddings through blocks
            # (the trailing False is presumably `output_attentions` — confirm)
            for blk in block_layer:
                layer_outputs = blk(x, height, width, False)
                x = layer_outputs[0]

            # third, reshape tokens back to (B*nf, C, height, width)
            x = x.reshape(B * nf, height, width, -1).permute(0, 3, 1, 2).contiguous()
            # regroup frames per clip: (B, nf, C, h, w) -> (B, C, nf, h, w) for Mamba
            x = x.reshape(B, nf, x.shape[-3], x.shape[-2], x.shape[-1]).transpose(1, 2)
            x = mam_stage(x)
            # back to per-frame maps for the next stage and for the decoder
            x = x.transpose(1, 2)
            x = x.reshape(B * nf, x.shape[-3], x.shape[-2], x.shape[-1])

            outs.append(x)

        return tuple(outs)

    def forward(self, x):
        """Alias for `forward_features`."""
        x = self.forward_features(x)
        return x


# ViVim

In [229]:
class Vivim(nn.Module):
    """Video segmentation model: SegFormer + Mamba encoder with a SegFormer decode head.

    Args:
        in_channels: forwarded to the encoder as `in_chans`.
        out_channels: number of channels of the predicted logits.
        depths: number of MambaLayers per encoder stage.
        feature_size: channel size of each encoder stage.
        drop_path_rate: maximum stochastic-depth rate for the MambaLayers.
        layer_scale_init_value: forwarded to the encoder.
        hidden_size: stored on the instance only.
        spatial_dims: stored on the instance only.
        with_edge: if True, `forward` also returns an edge map predicted from the
            first encoder stage.
    """

    def __init__(
        self,
        in_channels=3,
        out_channels=1,
        depths=[2, 2, 2, 2],
        feature_size=[64, 128, 320, 512],
        drop_path_rate=0,
        layer_scale_init_value=1e-6,
        hidden_size: int = 768,
        spatial_dims=2,
        with_edge=False,
    ) -> None:
        super().__init__()

        self.hidden_size = hidden_size
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.depths = depths
        self.drop_path_rate = drop_path_rate
        self.feature_size = feature_size
        self.layer_scale_init_value = layer_scale_init_value

        self.spatial_dims = spatial_dims

        backbone = SegformerForSemanticSegmentation.from_pretrained(
            "nvidia/segformer-b3-finetuned-ade-512-512"
        )
        # Forward the stage configuration to the encoder. It was previously stored
        # on the instance only, so the encoder always ran with its own defaults.
        self.encoder = mamba_block(
            backbone,
            in_channels,
            depths=depths,
            dims=feature_size,
            drop_path_rate=drop_path_rate,
            layer_scale_init_value=layer_scale_init_value,
        )
        self.decoder = backbone.decode_head
        # self.decoder.classifier = nn.Sequential()

        # NOTE(review): 768 is presumably the channel size produced by the decoder's
        # fuse layer for this checkpoint; `hidden_size` is not used here — confirm
        # before wiring it in.
        self.out = nn.Conv2d(768, out_channels, kernel_size=1)
        self.with_edge = with_edge
        if with_edge:
            # The edge head reads the first encoder stage, whose channel size is
            # feature_size[0] (64 with the defaults).
            self.edgeocr_cls_head = nn.Conv2d(
                feature_size[0], 1, kernel_size=1, stride=1, padding=0, bias=True
            )

    def proj_feat(self, x):
        """Reshape and permute `x` using `self.proj_view_shape` / `self.proj_axes`.

        NOTE(review): neither attribute is assigned in this class, so this method
        raises AttributeError unless they are set elsewhere; it is also not called
        within this class.
        """
        new_view = [x.size(0)] + self.proj_view_shape
        x = x.view(new_view)
        x = x.permute(self.proj_axes).contiguous()
        # The result was previously computed and then discarded.
        return x

    def decode(self, encoder_hidden_states):
        """Fuse the per-stage encoder features into logits.

        Each feature map is projected by its matching `decoder.linear_c` MLP,
        upsampled to the resolution of the first stage, concatenated in reverse
        stage order, and passed through the decoder's fuse / batch-norm /
        activation / dropout layers and finally `self.out`.

        Returns logits of shape (batch, out_channels, h0, w0), where (h0, w0) is
        the spatial size of the first encoder feature map.
        """
        batch_size = encoder_hidden_states[-1].shape[0]

        all_hidden_states = ()
        for encoder_hidden_state, mlp in zip(encoder_hidden_states, self.decoder.linear_c):
            # Token-shaped (3D) inputs are folded back into a square feature map.
            if self.decoder.config.reshape_last_stage is False and encoder_hidden_state.ndim == 3:
                height = width = int(math.sqrt(encoder_hidden_state.shape[-1]))
                encoder_hidden_state = (
                    encoder_hidden_state.reshape(batch_size, height, width, -1)
                    .permute(0, 3, 1, 2)
                    .contiguous()
                )

            # unify channel dimension
            height, width = encoder_hidden_state.shape[2], encoder_hidden_state.shape[3]
            encoder_hidden_state = mlp(encoder_hidden_state)
            encoder_hidden_state = encoder_hidden_state.permute(0, 2, 1)
            encoder_hidden_state = encoder_hidden_state.reshape(batch_size, -1, height, width)
            # upsample
            encoder_hidden_state = nn.functional.interpolate(
                encoder_hidden_state,
                size=encoder_hidden_states[0].size()[2:],
                mode="bilinear",
                align_corners=False,
            )
            all_hidden_states += (encoder_hidden_state,)
        concat_hidden_states = torch.cat(all_hidden_states[::-1], dim=1)
        hidden_states = self.decoder.linear_fuse(concat_hidden_states)
        hidden_states = self.decoder.batch_norm(hidden_states)
        hidden_states = self.decoder.activation(hidden_states)
        hidden_states = self.decoder.dropout(hidden_states)

        logits = self.out(hidden_states)

        return logits

    def forward(self, x):
        """Segment a clip batch `x` of shape (B, nf, C, H, W).

        Returns logits of shape (B * nf, out_channels, H, W). When `with_edge` is
        set, returns a tuple `(logits, edge)` where `edge` has shape
        (B * nf, 1, H, W).
        """
        H, W = x.shape[-2:]
        outs = self.encoder(x)
        logits = self.decode(outs)
        # Bring the logits back to the input resolution.
        upsampled_logits = nn.functional.interpolate(
            logits, size=(H, W), mode="bilinear", align_corners=False
        )

        if self.with_edge:
            edge = self.edgeocr_cls_head(outs[0])
            edge = nn.functional.interpolate(edge, size=(H, W), mode="bilinear", align_corners=False)
            return upsampled_logits, edge
        else:
            return upsampled_logits

In [240]:
# One clip of 5 RGB frames at 224x224, laid out as (B, nf, C, H, W), on the GPU.
x = torch.randn(1, 5, 3, 224, 224).cuda()
# Constructing Vivim calls SegformerForSemanticSegmentation.from_pretrained for the backbone.
model = Vivim(in_channels=3, out_channels=1).cuda()

In [241]:
# Frames are folded into the batch axis, so the output is (B * nf, out_channels, H, W).
model(x).shape


torch.Size([5, 1, 224, 224])

In [244]:
# NOTE(review): this import rebinds `Vivim`, shadowing the class defined earlier in
# this notebook; the keyword names below (in_chans / out_chans) differ from that class.
from vivim import Vivim
# NOTE(review): the recorded run of this cell ends with
# "TypeError: Mamba.__init__() got an unexpected keyword argument 'bimamba_type'".
# Presumably the external `vivim` module expects a Mamba build that accepts
# `bimamba_type` — confirm the required mamba_ssm version.
model = Vivim(in_chans=1, out_chans=1).cuda()
x = torch.randn(1, 5, 1, 224, 224).cuda()
model(x)

TypeError: Mamba.__init__() got an unexpected keyword argument 'bimamba_type'