In [1]:
# transformer block 1
# overlapping patch merging -> [selfefficient attention -> Mix-FFN]
#The OverlapPatchMerging block can be implemented with a convolution layer with a stride less than the kernel_size, so it overlaps different patches. In SegFormer, the conv layer is followed by a layer norm.

import torch
from torch import nn
from timm.models.layers import DropPath, to_2tuple, trunc_normal_, to_3tuple
import math
import numpy as np
import torch.nn.functional as F
#from mmcv.runner import load_checkpoint
#from mmseg.utils import get_root_logger

class OverlapPatchEmbed(nn.Module):
    """ Image to Patch Embedding
    """

    def __init__(self, img_size=224, patch_size=7, stride=4, in_chans=1, embed_dim=768):
        super().__init__()
        if isinstance(img_size, int) :
            img_size = to_3tuple(img_size)
        if isinstance(patch_size, int) :
            patch_size = to_3tuple(patch_size)
        #img_size = to_3tuple(img_size)
        #patch_size = to_3tuple(patch_size)

        self.img_size = img_size
        self.patch_size = patch_size
        self.D, self.H, self.W = img_size[0] // patch_size[0], img_size[1] // patch_size[1], img_size[2]//patch_size[2]
        self.num_patches = self.D * self.H * self.W
        self.proj = nn.Conv3d(in_chans, embed_dim, kernel_size=patch_size, stride=stride,
                              padding=(patch_size[0] // 2, patch_size[1] // 2, patch_size[2] // 2))
        self.norm = nn.LayerNorm(embed_dim)

        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Conv3d):
            fan_out = m.kernel_size[0] * m.kernel_size[1] * m.kernel_size[2] * m.out_channels
            fan_out //= m.groups
            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
            if m.bias is not None:
                m.bias.data.zero_()

    def forward(self, x):
        x = self.proj(x)
        _, _, D, H, W = x.shape
        x = x.flatten(2).transpose(1, 2)
        x = self.norm(x)

        return x, D, H, W


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# !pip install -U openmim
# !mim install mmcv


In [3]:
# tensor = torch.randn(1, 2, 3, 4, 5)
# torch.randn(1, 1, 224, 224, 224).shape
# model = OverlapPatchEmbed()
# print(model(torch.randn(1, 1, 224, 224, 224))[0].shape)
# #embed_tensor = model(tensor)[0]


In [4]:
# print(tensor.permute(0, 2, 3, 4, 1).shape)
# print(tensor.reshape(1, 3, -1).shape)

In [5]:
# print(tensor.reshape(1, -1, 3).shape)
# print(tensor.transpose(-2, -1).shape)

In [6]:
# m = nn.Softmax(dim=-1)
# input = torch.randn(2, 3)
# print(input)
# output = m(input)
# print(output)

In [7]:
# print(np.exp(input[1]))
# exp = np.exp(input[1])
# exp = np.array(exp)
# print(type(exp))
# print(np.sum(exp))
# exp = exp/np.sum(exp)
# print(exp)
# #print(np.sum(np.exp(input[0])))
# #norm = np.exp(input[0])/np.sum(np.exp(input[0]))
# #print(norm)

In [8]:
class Attention(nn.Module):
    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., sr_ratio=1):
        super().__init__()
        assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}."

        self.dim = dim
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim ** -0.5

        self.q = nn.Linear(dim, dim, bias=qkv_bias)
        self.kv = nn.Linear(dim, dim * 2, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        self.sr_ratio = sr_ratio
        if sr_ratio > 1:
            self.sr = nn.Conv3d(dim, dim, kernel_size=sr_ratio, stride=sr_ratio)
            self.norm = nn.LayerNorm(dim)

        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Conv3d):
            fan_out = m.kernel_size[0] * m.kernel_size[1] *  m.kernel_size[2] * m.out_channels
            fan_out //= m.groups
            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
            if m.bias is not None:
                m.bias.data.zero_()

    def forward(self, x, D, H, W ):
        B, N, C = x.shape
        q = self.q(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)

        if self.sr_ratio > 1:
            x_ = x.permute(0, 2, 1).reshape(B, C, D, H, W)
            x_ = self.sr(x_).reshape(B, C, -1).permute(0, 2, 1)
            x_ = self.norm(x_)
            kv = self.kv(x_).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        else:
            kv = self.kv(x).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        k, v = kv[0], kv[1]

        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)

        return x
    
        

        #

In [9]:
class DWConv(nn.Module):
    def __init__(self, dim=768):
        super(DWConv, self).__init__()
        self.dwconv = nn.Conv3d(dim, dim, 3, 1, 1, bias=True, groups=dim)

    def forward(self, x, D, H, W):
        B, N, C = x.shape
        x = x.transpose(1, 2).view(B, C, D, H, W)
        x = self.dwconv(x)
        x = x.flatten(2).transpose(1, 2)

        return x

In [10]:
class Mlp(nn.Module):
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.dwconv = DWConv(hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Conv3d):
            fan_out = m.kernel_size[0] * m.kernel_size[1] *  m.kernel_size[2] * m.out_channels
            fan_out //= m.groups
            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
            if m.bias is not None:
                m.bias.data.zero_()

    def forward(self, x, D, H, W):
        x = self.fc1(x)
        x = self.dwconv(x, D, H, W)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x

In [11]:
tensor = torch.randn(1, 27, 768)
model = DWConv(dim=768)
print(model.forward(tensor, 3, 3, 3).size())
#print(model(tensor)) 

B, N, C = tensor.shape
print(B, N, C)
tensor = tensor.transpose(1, 2)
print(tensor.size())

tensor = tensor.transpose(1, 2).view(B, C, 3, 3, 3)
print(tensor.size())

torch.Size([1, 27, 768])
1 27 768
torch.Size([1, 768, 27])
torch.Size([1, 768, 3, 3, 3])


In [12]:
sum(p.numel() for p in model.parameters() if p.requires_grad)

21504

In [13]:
dim = 768 
q = nn.Linear(dim, dim)


In [14]:
t = torch.randn(1, 3, 224, 224)
t.flatten(2).transpose(1, 2).shape
tup =to_2tuple(224)
print(tup)

(224, 224)


In [15]:
t.flatten(2).shape

torch.Size([1, 3, 50176])

In [16]:
scale = None or 0.5
print(scale)

0.5


In [17]:
class Block(nn.Module):

    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, sr_ratio=1):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = Attention(
            dim,
            num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
            attn_drop=attn_drop, proj_drop=drop, sr_ratio=sr_ratio)
        # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)

        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Conv3d):
            fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
            fan_out //= m.groups
            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
            if m.bias is not None:
                m.bias.data.zero_()

    def forward(self, x, D, H, W):
        x = x + self.drop_path(self.attn(self.norm1(x), D, H, W))
        x = x + self.drop_path(self.mlp(self.norm2(x), D, H, W))

        return x


In [18]:
# drop_path_rate = 0
# depths=[3, 4, 6, 3]
# print(sum(depths))
# lst = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
# print(lst)
# print(len(lst))


In [19]:
# for x in torch.linspace(0, 12, sum(depths)):
#     print(x.item())

In [20]:
class MixVisionTransformer(nn.Module):
    def __init__(self, img_size=224, patch_size=16, in_chans=1, num_classes=1000, embed_dims=[64, 128, 256, 512],
                 num_heads=[1, 2, 4, 8], mlp_ratios=[4, 4, 4, 4], qkv_bias=False, qk_scale=None, drop_rate=0.,
                 attn_drop_rate=0., drop_path_rate=0., norm_layer=nn.LayerNorm,
                 depths=[3, 4, 6, 3], sr_ratios=[8, 4, 2, 1]):
        super().__init__()
        assert isinstance(img_size, int) or (isinstance(img_size, tuple) and len(img_size) == 3), f"the image {img_size} is the type of {type(img_size)}. It has to be int or tuple of 3 int"
        self.num_classes = num_classes
        self.depths = depths

        # patch_embed
        if isinstance(img_size, int):
            self.patch_embed1 = OverlapPatchEmbed(img_size=img_size, patch_size=7, stride=4, in_chans=in_chans,
                                              embed_dim=embed_dims[0])
            self.patch_embed2 = OverlapPatchEmbed(img_size=img_size // 4, patch_size=3, stride=2, in_chans=embed_dims[0],
                                              embed_dim=embed_dims[1])
            self.patch_embed3 = OverlapPatchEmbed(img_size=img_size // 8, patch_size=3, stride=2, in_chans=embed_dims[1],
                                              embed_dim=embed_dims[2])
            self.patch_embed4 = OverlapPatchEmbed(img_size=img_size // 16, patch_size=3, stride=2, in_chans=embed_dims[2],
                                              embed_dim=embed_dims[3])
            
        elif isinstance(img_size, tuple):
            self.patch_embed1 = OverlapPatchEmbed(img_size=(img_size[0], img_size[1], img_size[2]), patch_size=7, stride=4, in_chans=in_chans,
                                              embed_dim=embed_dims[0])
            self.patch_embed2 = OverlapPatchEmbed(img_size=(img_size[0] // 4, img_size[1] //4, img_size[2]//4), patch_size=3, stride=2, in_chans=embed_dims[0],
                                              embed_dim=embed_dims[1])
            self.patch_embed3 = OverlapPatchEmbed(img_size=(img_size[0] // 8, img_size[1]//8, img_size[2]//8), patch_size=3, stride=2, in_chans=embed_dims[1],
                                              embed_dim=embed_dims[2])
            self.patch_embed4 = OverlapPatchEmbed(img_size=(img_size[0] // 16, img_size[1]//16, img_size[2]//16), patch_size=3, stride=2, in_chans=embed_dims[2],
                                              embed_dim=embed_dims[3])

        # transformer encoder
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]  # stochastic depth decay rule
        cur = 0
        self.block1 = nn.ModuleList([Block(
            dim=embed_dims[0], num_heads=num_heads[0], mlp_ratio=mlp_ratios[0], qkv_bias=qkv_bias, qk_scale=qk_scale,
            drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer,
            sr_ratio=sr_ratios[0])
            for i in range(depths[0])])
        self.norm1 = norm_layer(embed_dims[0])

        cur += depths[0]
        self.block2 = nn.ModuleList([Block(
            dim=embed_dims[1], num_heads=num_heads[1], mlp_ratio=mlp_ratios[1], qkv_bias=qkv_bias, qk_scale=qk_scale,
            drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer,
            sr_ratio=sr_ratios[1])
            for i in range(depths[1])])
        self.norm2 = norm_layer(embed_dims[1])

        cur += depths[1]
        self.block3 = nn.ModuleList([Block(
            dim=embed_dims[2], num_heads=num_heads[2], mlp_ratio=mlp_ratios[2], qkv_bias=qkv_bias, qk_scale=qk_scale,
            drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer,
            sr_ratio=sr_ratios[2])
            for i in range(depths[2])])
        self.norm3 = norm_layer(embed_dims[2])

        cur += depths[2]
        self.block4 = nn.ModuleList([Block(
            dim=embed_dims[3], num_heads=num_heads[3], mlp_ratio=mlp_ratios[3], qkv_bias=qkv_bias, qk_scale=qk_scale,
            drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer,
            sr_ratio=sr_ratios[3])
            for i in range(depths[3])])
        self.norm4 = norm_layer(embed_dims[3])

        # classification head
        # self.head = nn.Linear(embed_dims[3], num_classes) if num_classes > 0 else nn.Identity()

        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Conv3d):
            fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
            fan_out //= m.groups
            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
            if m.bias is not None:
                m.bias.data.zero_()

    # def init_weights(self, pretrained=None):
    #     if isinstance(pretrained, str):
    #         logger = get_root_logger()
    #         load_checkpoint(self, pretrained, map_location='cpu', strict=False, logger=logger)

    def reset_drop_path(self, drop_path_rate):
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(self.depths))]
        cur = 0
        for i in range(self.depths[0]):
            self.block1[i].drop_path.drop_prob = dpr[cur + i]

        cur += self.depths[0]
        for i in range(self.depths[1]):
            self.block2[i].drop_path.drop_prob = dpr[cur + i]

        cur += self.depths[1]
        for i in range(self.depths[2]):
            self.block3[i].drop_path.drop_prob = dpr[cur + i]

        cur += self.depths[2]
        for i in range(self.depths[3]):
            self.block4[i].drop_path.drop_prob = dpr[cur + i]

    def freeze_patch_emb(self):
        self.patch_embed1.requires_grad = False

    @torch.jit.ignore
    def no_weight_decay(self):
        return {'pos_embed1', 'pos_embed2', 'pos_embed3', 'pos_embed4', 'cls_token'}  # has pos_embed may be better

    def get_classifier(self):
        return self.head

    def reset_classifier(self, num_classes, global_pool=''):
        self.num_classes = num_classes
        self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()

    def forward_features(self, x):
        B = x.shape[0]
        outs = []

        # stage 1
        x, D, H, W = self.patch_embed1(x)
        for i, blk in enumerate(self.block1):
            x = blk(x, D, H, W)
        x = self.norm1(x)
        x = x.reshape(B, D, H, W, -1).permute(0, 4, 1, 2, 3).contiguous()
        outs.append(x)

        # stage 2
        x, D, H, W = self.patch_embed2(x)
        for i, blk in enumerate(self.block2):
            x = blk(x, D, H, W)
        x = self.norm2(x)
        x = x.reshape(B, D, H, W, -1).permute(0, 4, 1, 2, 3).contiguous()
        outs.append(x)

        # stage 3
        x, D, H, W = self.patch_embed3(x)
        for i, blk in enumerate(self.block3):
            x = blk(x, D, H, W)
        x = self.norm3(x)
        x = x.reshape(B, D, H, W, -1).permute(0, 4, 1, 2, 3).contiguous()
        outs.append(x)

        # stage 4
        x, D, H, W = self.patch_embed4(x)
        for i, blk in enumerate(self.block4):
            x = blk(x, D, H, W)
        x = self.norm4(x)
        x = x.reshape(B, D, H, W, -1).permute(0, 4, 1, 2, 3).contiguous()
        outs.append(x)

        return outs

    def forward(self, x):
        x = self.forward_features(x)
        # x = self.head(x)

        return x

In [21]:
# model = MixVisionTransformer()
# #print(model)
# tensor = torch.randn(1, 1, 224, 224, 224)
# print(tensor.shape)
# print(model(tensor)[0].shape, model(tensor)[1].shape, model(tensor)[2].shape, model(tensor)[3].shape)

In [22]:
print(model)

DWConv(
  (dwconv): Conv3d(768, 768, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), groups=768)
)


In [23]:
#!pip install pyproject-toml
#!pip install wheel
#!pip install functools

In [24]:

class mit_b1(MixVisionTransformer):
    def __init__(self, **kwargs):
        super(mit_b1, self).__init__(
            patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4],
            qkv_bias=True, norm_layer=nn.LayerNorm, depths=[2, 2, 2, 2], sr_ratios=[8, 4, 2, 1],
            drop_rate=0.0, drop_path_rate=0.1)

In [25]:
# model = mit_b1()
# #print(model)
# tensor = torch.randn(1, 1, 224, 224, 224)
# print(tensor.shape)
# print(model(tensor)[0].shape, model(tensor)[1].shape, model(tensor)[2].shape, model(tensor)[3].shape)
# output_tensor = model(tensor)

In [26]:
print(model)

DWConv(
  (dwconv): Conv3d(768, 768, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), groups=768)
)


In [27]:
class MLP(nn.Module):
    """
    Linear Embedding
    """
    def __init__(self, input_dim=2048, embed_dim=768):
        super().__init__()
        self.proj = nn.Linear(input_dim, embed_dim)

    def forward(self, x):
        x = x.flatten(2).transpose(1, 2)
        x = self.proj(x)
        return x

In [28]:
class SegFormerHead(nn.Module):
    """
    SegFormer: Simple and Efficient Design for Semantic Segmentation with Transformers
    """
    def __init__(self, feature_strides, in_channels=[64, 128, 320, 512] ,embedding_dim = 768, num_classes=19, dropout_ratio=0.1):
        super().__init__()
        self.in_channels = in_channels
        assert len(feature_strides) == len(self.in_channels)
        assert min(feature_strides) == feature_strides[0]
        self.feature_strides = feature_strides
        self.num_classes = num_classes
        self.embedding_dim = embedding_dim
        self.dropout_ratio = dropout_ratio
        self.dropout = nn.Dropout(self.dropout_ratio)


        c1_in_channels, c2_in_channels, c3_in_channels, c4_in_channels = self.in_channels

        self.linear_c4 = MLP(input_dim=c4_in_channels, embed_dim=embedding_dim)
        self.linear_c3 = MLP(input_dim=c3_in_channels, embed_dim=embedding_dim)
        self.linear_c2 = MLP(input_dim=c2_in_channels, embed_dim=embedding_dim)
        self.linear_c1 = MLP(input_dim=c1_in_channels, embed_dim=embedding_dim)

        self.linear_fuse = nn.Sequential(
            nn.Conv3d(embedding_dim*4, embedding_dim, kernel_size=1, bias=False),
            nn.ReLU(),
            nn.BatchNorm3d(embedding_dim)
        )

        self.linear_pred = nn.Conv3d(embedding_dim, self.num_classes, kernel_size=1)

    def resize(self, input, size=None, scale_factor=None, mode='nearest', align_corners=None):
            return F.interpolate(input, size, scale_factor, mode, align_corners)
    

    def forward(self, x):
        c1, c2, c3, c4 = x

        ############## MLP decoder on C1-C4 ###########
        n, _, d, h, w = c4.shape

        _c4 = self.linear_c4(c4).permute(0,2,1).reshape(n, -1, c4.shape[2], c4.shape[3], c4.shape[4])
        _c4 = self.resize(_c4, size=c1.size()[2:],mode='trilinear',align_corners=False)

        _c3 = self.linear_c3(c3).permute(0,2,1).reshape(n, -1, c3.shape[2], c3.shape[3], c3.shape[4])
        _c3 = self.resize(_c3, size=c1.size()[2:],mode='trilinear',align_corners=False)

        _c2 = self.linear_c2(c2).permute(0,2,1).reshape(n, -1, c2.shape[2], c2.shape[3], c2.shape[4])
        _c2 = self.resize(_c2, size=c1.size()[2:],mode='trilinear',align_corners=False)

        _c1 = self.linear_c1(c1).permute(0,2,1).reshape(n, -1, c1.shape[2], c1.shape[3], c1.shape[4])

        _c = self.linear_fuse(torch.cat([_c4, _c3, _c2, _c1], dim=1))

        x = self.dropout(_c)
        x = self.linear_pred(x)

        return x

      

In [29]:
# #print(type(model(tensor)))
# inputs = model(tensor)
# print(len(inputs))
# c1, c2, c3, c4 = inputs
# print(c1.shape, c2.shape, c3.shape, c4.shape)
# n, _, d, h, w = c4.shape

In [30]:
# def resize(input, size=None, scale_factor=None, mode='nearest', align_corners=None):
#     return F.interpolate(input, size, scale_factor, mode, align_corners)

In [31]:
# linear = MLP(input_dim=512, embed_dim=768)
# #print(linear(c4).permute(0,2,1).reshape(n, -1, c4.shape[2], c4.shape[3], c4.shape[4]).shape)
# _c4 = linear(c4).permute(0,2,1).reshape(n, -1, c4.shape[2], c4.shape[3], c4.shape[4]) 
# print(_c4.shape)
# _c4 = resize(_c4, size=c1.size()[2:],mode='trilinear',align_corners=False)
# print(_c4.shape)


In [32]:
# model = SegFormerHead(feature_strides=[4, 8, 16, 32], in_channels=[64, 128, 320, 512] ,embedding_dim = 768, num_classes=19)
# #print(model)
# model(output_tensor).shape


In [40]:
class ThreeDimSegFormer(nn.Module):
    def __init__(self, img_size=(200,224,224), patch_size=4, in_chans=1, num_classes=1000, embed_dims=[64, 128, 256, 512],
                 num_heads=[1, 2, 4, 8], mlp_ratios=[4, 4, 4, 4], qkv_bias=True, qk_scale=None, drop_rate=0.,
                 attn_drop_rate=0., drop_path_rate=0.1, norm_layer=nn.LayerNorm,
                 depths=[3, 4, 6, 3], sr_ratios=[8, 4, 2, 1]):
        super().__init__()
        self.encoder = MixVisionTransformer(img_size=img_size, patch_size=patch_size, in_chans=in_chans, num_classes=num_classes, embed_dims=embed_dims,
                 num_heads=num_heads, mlp_ratios=mlp_ratios, qkv_bias=qkv_bias, qk_scale=qk_scale, drop_rate=drop_rate,
                 attn_drop_rate=attn_drop_rate, drop_path_rate=drop_path_rate, norm_layer=norm_layer,
                 depths=depths, sr_ratios=sr_ratios)
        self.decoder = SegFormerHead(feature_strides=[4, 8, 16, 32], in_channels=embed_dims ,embedding_dim = embed_dims[-1], num_classes=num_classes)
    def final_resize(self, input, size=None, scale_factor=None, mode='nearest', align_corners=None):
        return F.interpolate(input, size, scale_factor, mode, align_corners)
    
    def forward(self, input):
        x = self.encoder(input)
        x = self.decoder(x)
        out = self.final_resize(x, size=input.shape[2:],mode='trilinear',align_corners=False)

        return out

In [43]:
model = ThreeDimSegFormer(depths=[2, 2, 2, 2])
print(model)
device = torch.device('cuda:0')
model = model.to(device)
tensor = torch.randn(1, 1, 100, 124, 124)
tensor = tensor.to(device)
print(tensor.shape)
print(model(tensor).shape)


ThreeDimSegFormer(
  (encoder): MixVisionTransformer(
    (patch_embed1): OverlapPatchEmbed(
      (proj): Conv3d(1, 64, kernel_size=(7, 7, 7), stride=(4, 4, 4), padding=(3, 3, 3))
      (norm): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
    )
    (patch_embed2): OverlapPatchEmbed(
      (proj): Conv3d(64, 128, kernel_size=(3, 3, 3), stride=(2, 2, 2), padding=(1, 1, 1))
      (norm): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
    )
    (patch_embed3): OverlapPatchEmbed(
      (proj): Conv3d(128, 256, kernel_size=(3, 3, 3), stride=(2, 2, 2), padding=(1, 1, 1))
      (norm): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
    )
    (patch_embed4): OverlapPatchEmbed(
      (proj): Conv3d(256, 512, kernel_size=(3, 3, 3), stride=(2, 2, 2), padding=(1, 1, 1))
      (norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
    )
    (block1): ModuleList(
      (0): Block(
        (norm1): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
        (attn): Atten