In [1]:
# # https://github.com/maxw1489/Mask_RCNN (tensorflow 2.9.1)

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F

import numpy as np


from mmcv.cnn import ConvModule, DepthwiseSeparableConvModule
from collections import OrderedDict

# from mmseg.ops import resize # steht unten
# from ..builder import HEADS
# from .decode_head import BaseDecodeHead
# from mmseg.models.utils import *
# import attr

from IPython import embed




from timm.layers import SqueezeExcite

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
img_size = 256
patch_size=16
frozen_stages = 0
in_chans = 3
embed_dim = [224, 336, 448]
partial_dim = [48, 72, 96] # partial_dim = r*embed_dim with r=1/4.67
qk_dim = [16, 16, 16]
depth = [4, 7, 6]
types = ["i", "s", "s"]
down_ops = [['subsample', 2], ['subsample', 2], ['']]
pretrained = None
distillation = False

In [4]:
class GroupNorm(torch.nn.GroupNorm):
    """
    Group Normalization with 1 group.
    Input: tensor in shape [B, C, H, W]
    """
    def __init__(self, num_channels, **kwargs):
        super().__init__(1, num_channels, **kwargs)

In [5]:
class Conv2d_BN(torch.nn.Sequential):  
    def __init__(self, a, b, ks=1, stride=1, pad=0, dilation=1,
                 groups=1, bn_weight_init=1):
        super().__init__()
        self.add_module('c', torch.nn.Conv2d(
            a, b, ks, stride, pad, dilation, groups, bias=False))
        self.add_module('bn', torch.nn.BatchNorm2d(b))
        torch.nn.init.constant_(self.bn.weight, bn_weight_init)
        torch.nn.init.constant_(self.bn.bias, 0)

    @torch.no_grad()
    def fuse(self):
        c, bn = self._modules.values()
        w = bn.weight / (bn.running_var + bn.eps)**0.5
        w = c.weight * w[:, None, None, None]
        b = bn.bias - bn.running_mean * bn.weight / \
            (bn.running_var + bn.eps)**0.5
        m = torch.nn.Conv2d(w.size(1) * self.c.groups, w.size(
            0), w.shape[2:], stride=self.c.stride, padding=self.c.padding, dilation=self.c.dilation, groups=self.c.groups)
        m.weight.data.copy_(w)
        m.bias.data.copy_(b)
        return m

In [6]:
class BN_Linear(torch.nn.Sequential):
    def __init__(self, a, b, bias=True, std=0.02):
        super().__init__()
        self.add_module('bn', torch.nn.BatchNorm1d(a))
        self.add_module('l', torch.nn.Linear(a, b, bias=bias))
        trunc_normal_(self.l.weight, std=std)
        if bias:
            torch.nn.init.constant_(self.l.bias, 0)

    @torch.no_grad()
    def fuse(self):
        bn, l = self._modules.values()
        w = bn.weight / (bn.running_var + bn.eps)**0.5
        b = bn.bias - self.bn.running_mean * \
            self.bn.weight / (bn.running_var + bn.eps)**0.5
        w = l.weight * w[None, :]
        if l.bias is None:
            b = b @ self.l.weight.T
        else:
            b = (l.weight @ b[:, None]).view(-1) + self.l.bias
        m = torch.nn.Linear(w.size(1), w.size(0))
        m.weight.data.copy_(w)
        m.bias.data.copy_(b)
        return m

In [7]:
class PatchMerging(torch.nn.Module):
    def __init__(self, dim, out_dim):
        super().__init__()
        hid_dim = int(dim * 4)
        self.conv1 = Conv2d_BN(dim, hid_dim, 1, 1, 0)
        self.act = torch.nn.ReLU()
        self.conv2 = Conv2d_BN(hid_dim, hid_dim, 3, 2, 1, groups=hid_dim)
        self.se = SqueezeExcite(hid_dim, .25)
        self.conv3 = Conv2d_BN(hid_dim, out_dim, 1, 1, 0)

    def forward(self, x):
        print("### PatchMerging ###")
        print(x.shape)
        x = self.conv3(self.se(self.act(self.conv2(self.act(self.conv1(x))))))
        print(x.shape)
        print("==="*10)
        return x

In [8]:
class Residual(torch.nn.Module):
    def __init__(self, m, drop=0.):
        super().__init__()
        self.m = m
        self.drop = drop

    def forward(self, x):
        if self.training and self.drop > 0:
            return x + self.m(x) * torch.rand(x.size(0), 1, 1, 1,
                                              device=x.device).ge_(self.drop).div(1 - self.drop).detach()
        else:
            return x + self.m(x)
    
    @torch.no_grad()
    def fuse(self):
        if isinstance(self.m, Conv2d_BN):
            m = self.m.fuse()
            assert(m.groups == m.in_channels)
            identity = torch.ones(m.weight.shape[0], m.weight.shape[1], 1, 1)
            identity = torch.nn.functional.pad(identity, [1,1,1,1])
            m.weight += identity.to(m.weight.device)
            return m
        else:
            return self

In [9]:
class FFN(torch.nn.Module):
    def __init__(self, ed, h):
        super().__init__()
        self.pw1 = Conv2d_BN(ed, h)
        self.act = torch.nn.ReLU()
        self.pw2 = Conv2d_BN(h, ed, bn_weight_init=0)

    def forward(self, x):
        x = self.pw2(self.act(self.pw1(x)))
        return x

In [10]:
class SHSA(torch.nn.Module):
    """Single-Head Self-Attention"""
    def __init__(self, dim, qk_dim, pdim):
        super().__init__()
        self.scale = qk_dim ** -0.5
        self.qk_dim = qk_dim
        self.dim = dim
        self.pdim = pdim

        self.pre_norm = GroupNorm(pdim)

        self.qkv = Conv2d_BN(pdim, qk_dim * 2 + pdim)
        self.proj = torch.nn.Sequential(torch.nn.ReLU(), Conv2d_BN(
            dim, dim, bn_weight_init = 0))
        

    def forward(self, x):
        B, C, H, W = x.shape
        x1, x2 = torch.split(x, [self.pdim, self.dim - self.pdim], dim = 1)
        x1 = self.pre_norm(x1)
        qkv = self.qkv(x1)
        q, k, v = qkv.split([self.qk_dim, self.qk_dim, self.pdim], dim = 1)
        q, k, v = q.flatten(2), k.flatten(2), v.flatten(2)
        
        attn = (q.transpose(-2, -1) @ k) * self.scale
        attn = attn.softmax(dim = -1)
        x1 = (v @ attn.transpose(-2, -1)).reshape(B, self.pdim, H, W)
        x = self.proj(torch.cat([x1, x2], dim = 1))

        return x

In [11]:
class BasicBlock(torch.nn.Module):
    def __init__(self, dim, qk_dim, pdim, type):
        super().__init__()
        if type == "s":    # for later stages
            self.conv = Residual(Conv2d_BN(dim, dim, 3, 1, 1, groups = dim, bn_weight_init = 0))
            self.mixer = Residual(SHSA(dim, qk_dim, pdim))
            self.ffn = Residual(FFN(dim, int(dim * 2)))
        elif type == "i":   # for early stages
            self.conv = Residual(Conv2d_BN(dim, dim, 3, 1, 1, groups = dim, bn_weight_init = 0))
            self.mixer = torch.nn.Identity()
            self.ffn = Residual(FFN(dim, int(dim * 2)))
    
    def forward(self, x):
        return self.ffn(self.mixer(self.conv(x)))

In [12]:
# input_image = torch.randn(1, 3, 512, 512)
'''
patch_embed = torch.nn.Sequential(Conv2d_BN(in_chans, embed_dim[0] // 8, 3, 2, 1), torch.nn.ReLU(),
                           Conv2d_BN(embed_dim[0] // 8, embed_dim[0] // 4, 3, 2, 1), torch.nn.ReLU(),
                           Conv2d_BN(embed_dim[0] // 4, embed_dim[0] // 2, 3, 2, 1), torch.nn.ReLU(),
                           Conv2d_BN(embed_dim[0] // 2, embed_dim[0], 3, 2, 1)
                           )

x = patch_embed(input_image)
print("patch_embed: ", x.shape)
'''

'\npatch_embed = torch.nn.Sequential(Conv2d_BN(in_chans, embed_dim[0] // 8, 3, 2, 1), torch.nn.ReLU(),\n                           Conv2d_BN(embed_dim[0] // 8, embed_dim[0] // 4, 3, 2, 1), torch.nn.ReLU(),\n                           Conv2d_BN(embed_dim[0] // 4, embed_dim[0] // 2, 3, 2, 1), torch.nn.ReLU(),\n                           Conv2d_BN(embed_dim[0] // 2, embed_dim[0], 3, 2, 1)\n                           )\n\nx = patch_embed(input_image)\nprint("patch_embed: ", x.shape)\n'

In [13]:
'''
blocks1 = []
blocks2 = []
blocks3 = []
outs = []


for i, (ed, kd, pd, dpth, do, t) in enumerate(zip(embed_dim, qk_dim, partial_dim, depth, down_ops, types)):
    print (i, ed, kd, pd, dpth, do, t)
    print("dpth: ", dpth)
    for d in range(dpth):
        eval('blocks' + str(i+1)).append(BasicBlock(ed, kd, pd, t))
    if do[0] == 'subsample':
                # Build SHViT downsample block
                #('Subsample' stride)
                blk = eval('blocks' + str(i+1)) # 2
                blk.append(torch.nn.Sequential(Residual(Conv2d_BN(embed_dim[i], embed_dim[i], 3, 1, 1, groups=embed_dim[i])),
                                    Residual(FFN(embed_dim[i], int(embed_dim[i] * 2))),))
                blk.append(PatchMerging(*embed_dim[i:i + 2]))
                
                blk.append(torch.nn.Sequential(Residual(Conv2d_BN(embed_dim[i + 1], embed_dim[i + 1], 3, 1, 1, groups=embed_dim[i + 1])),
                                    Residual(FFN(embed_dim[i + 1], int(embed_dim[i + 1] * 2))),))

print(eval('blocks1'))

print(len(blocks1),len(blocks2),len(blocks3)) # 4, 7+3, 6+3

blocks1 = torch.nn.Sequential(*blocks1)
blocks2 = torch.nn.Sequential(*blocks2)
blocks3 = torch.nn.Sequential(*blocks3)

print()

print("Stage 1: ")
print("block1 in : ", x.shape)
print("block1 in : ", x.shape)
x = blocks1(x)
outs.append(x)
print("block1 out: ", x.shape)

print()
print("Stage 2: ")
print("block2 in : ", x.shape)
x = blocks2(x)
outs.append(x)
print("block2 out: ", x.shape)

print()
print("Stage 3: ")
print("block3 in : ", x.shape)
x = blocks3(x)
outs.append(x)
print("block3 out: ", x.shape)
'''

'\nblocks1 = []\nblocks2 = []\nblocks3 = []\nouts = []\n\n\nfor i, (ed, kd, pd, dpth, do, t) in enumerate(zip(embed_dim, qk_dim, partial_dim, depth, down_ops, types)):\n    print (i, ed, kd, pd, dpth, do, t)\n    print("dpth: ", dpth)\n    for d in range(dpth):\n        eval(\'blocks\' + str(i+1)).append(BasicBlock(ed, kd, pd, t))\n    if do[0] == \'subsample\':\n                # Build SHViT downsample block\n                #(\'Subsample\' stride)\n                blk = eval(\'blocks\' + str(i+1)) # 2\n                blk.append(torch.nn.Sequential(Residual(Conv2d_BN(embed_dim[i], embed_dim[i], 3, 1, 1, groups=embed_dim[i])),\n                                    Residual(FFN(embed_dim[i], int(embed_dim[i] * 2))),))\n                blk.append(PatchMerging(*embed_dim[i:i + 2]))\n                \n                blk.append(torch.nn.Sequential(Residual(Conv2d_BN(embed_dim[i + 1], embed_dim[i + 1], 3, 1, 1, groups=embed_dim[i + 1])),\n                                    Residual(FFN(emb

In [14]:
"""
   def __init__(self,
                 in_chans=3,
                 num_classes=1000,
                 embed_dim= [224, 336, 448], # [128, 256, 384] SHVIT_
                 partial_dim = [32, 64, 96],
                 qk_dim=[16, 16, 16],
                 depth=[1, 2, 3],
                 types = ["s", "s", "s"],   	
                 down_ops=[['subsample', 2], ['subsample', 2], ['']],
                 distillation=False,):
        super().__init__()

"""

'\n   def __init__(self,\n                 in_chans=3,\n                 num_classes=1000,\n                 embed_dim= [224, 336, 448], # [128, 256, 384] SHVIT_\n                 partial_dim = [32, 64, 96],\n                 qk_dim=[16, 16, 16],\n                 depth=[1, 2, 3],\n                 types = ["s", "s", "s"],   \t\n                 down_ops=[[\'subsample\', 2], [\'subsample\', 2], [\'\']],\n                 distillation=False,):\n        super().__init__()\n\n'

In [15]:
img_size = 256
patch_size=16
frozen_stages = 0
in_chans = 3
embed_dim = [224, 336, 448]
partial_dim = [48, 72, 96] # partial_dim = r*embed_dim with r=1/4.67
qk_dim = [16, 16, 16]
depth = [4, 7, 6]
types = ["i", "s", "s"]
down_ops = [['subsample', 2], ['subsample', 2], ['']]
pretrained = None
distillation = False

In [16]:
class SHViT(torch.nn.Module):
    def __init__(self,
                 in_chans=3,
                 num_classes=1000,
                 embed_dim = [224, 336, 448], # [128, 256, 384] SHVIT_
                 partial_dim = [48, 72, 96],
                 qk_dim=[16, 16, 16],
                 depth = [4, 7, 6],
                 types = ["i", "s", "s"],   	
                 down_ops=[['subsample', 2], ['subsample', 2], ['']],
                 distillation=False,):
        super().__init__()

        # Patch embedding
        self.patch_embed = torch.nn.Sequential(Conv2d_BN(in_chans, embed_dim[0] // 8, 3, 2, 1), torch.nn.ReLU(),
                           Conv2d_BN(embed_dim[0] // 8, embed_dim[0] // 4, 3, 2, 1), torch.nn.ReLU(),
                           Conv2d_BN(embed_dim[0] // 4, embed_dim[0] // 2, 3, 2, 1), torch.nn.ReLU(),
                           Conv2d_BN(embed_dim[0] // 2, embed_dim[0], 3, 2, 1))

        self.blocks1 = []
        self.blocks2 = []
        self.blocks3 = []

        # Build SHViT blocks
        for i, (ed, kd, pd, dpth, do, t) in enumerate(
                zip(embed_dim, qk_dim, partial_dim, depth, down_ops, types)):
            for d in range(dpth):
                eval('self.blocks' + str(i+1)).append(BasicBlock(ed, kd, pd, t))
            if do[0] == 'subsample':
                # Build SHViT downsample block
                #('Subsample' stride)
                blk = eval('self.blocks' + str(i+2)) # mÃ¼ssete 1 sein # war 2
                blk.append(torch.nn.Sequential(Residual(Conv2d_BN(embed_dim[i], embed_dim[i], 3, 1, 1, groups=embed_dim[i])),
                                    Residual(FFN(embed_dim[i], int(embed_dim[i] * 2))),))
                blk.append(PatchMerging(*embed_dim[i:i + 2]))
                
                blk.append(torch.nn.Sequential(Residual(Conv2d_BN(embed_dim[i + 1], embed_dim[i + 1], 3, 1, 1, groups=embed_dim[i + 1])),
                                    Residual(FFN(embed_dim[i + 1], int(embed_dim[i + 1] * 2))),))
        self.blocks1 = torch.nn.Sequential(*self.blocks1)
        
        self.blocks2 = torch.nn.Sequential(*self.blocks2)
        self.blocks3 = torch.nn.Sequential(*self.blocks3)
        
        '''
        # Classification head
        self.head = BN_Linear(embed_dim[-1], num_classes) if num_classes > 0 else torch.nn.Identity()
        self.distillation = distillation
        if distillation:
            self.head_dist = BN_Linear(embed_dim[-1], num_classes) if num_classes > 0 else torch.nn.Identity()
        '''




    def forward(self, x):
        x = self.patch_embed(x)
        print("x: ", x.shape)
        x = self.blocks1(x)
        print("block1 out: ", x.shape)
        x = self.blocks2(x)
        print("block2 out: ", x.shape)
        x = self.blocks3(x)
        print("block3 out: ", x.shape)

        '''
        x = torch.nn.functional.adaptive_avg_pool2d(x, 1).flatten(1)
        if self.distillation:
            x = self.head(x), self.head_dist(x)
            if not self.training:
                x = (x[0] + x[1]) / 2
        else:
            x = self.head(x)
        '''
        return x

In [17]:
'''
input_image = torch.randn(1, 3, 512, 512)
in_channels = SHViT().forward(input_image)  
c3_in_channels = in_channels

print("input_image shape: ", input_image.shape)
in_channels.shape
'''

'\ninput_image = torch.randn(1, 3, 512, 512)\nin_channels = SHViT().forward(input_image)  \nc3_in_channels = in_channels\n\nprint("input_image shape: ", input_image.shape)\nin_channels.shape\n'

In [18]:
# ========================================== SegFormer - Head ==========================================

In [19]:
shvit_bibi = SHViT()
# shvit_bibi

In [20]:
def resize(input,
           size=None,
           scale_factor=None,
           mode='nearest',
           align_corners=None,
           warning=True):
    if warning:
        if size is not None and align_corners:
            input_h, input_w = tuple(int(x) for x in input.shape[2:])
            output_h, output_w = tuple(int(x) for x in size)
            if output_h > input_h or output_w > output_h:
                if ((output_h > 1 and output_w > 1 and input_h > 1
                     and input_w > 1) and (output_h - 1) % (input_h - 1)
                        and (output_w - 1) % (input_w - 1)):
                    warnings.warn(
                        f'When align_corners={align_corners}, '
                        'the output would more aligned if '
                        f'input size {(input_h, input_w)} is `x+1` and '
                        f'out size {(output_h, output_w)} is `nx+1`')
    if isinstance(size, torch.Size):
        size = tuple(int(x) for x in size)
    return F.interpolate(input, size, scale_factor, mode, align_corners)

In [21]:
class MLP(nn.Module):
    """
    Linear Embedding
    """
    def __init__(self, input_dim=2048, embed_dim=768): 
        """
        Args:
            c3: input_dim
            c_out = embed_dim
            F_i = x (computed tensor from SHVIT)
        """


        super().__init__()
        self.proj = nn.Linear(input_dim, embed_dim)

    def forward(self, x):
        x = x.flatten(2).transpose(1, 2)
        x = self.proj(x)
        return x

In [22]:
# init function 

device = 'cuda' if torch.cuda.is_available() else 'cpu'
input_image = torch.randn(1, 3, 512, 512)#.to(device)
embedding_dim = 448 # 768
num_classes = 150

In [23]:
# embedding_dim = 448 # 768
class SegFormerHead(nn.Module):
    def __init__(self, num_classes=150, embedding_dim = 448, feature_strides=None, channels=None,  in_channels=448,  dropout_ratio=0.1, **kwargs):
        super(SegFormerHead, self).__init__(**kwargs)
        self.in_channels = in_channels
        c1_in_channels = self.in_channels
        self.num_classes = num_classes
        self.feature_strides = feature_strides  

        

        self.dropout_ratio = dropout_ratio
        if self.dropout_ratio > 0:
            self.dropout = nn.Dropout2d(dropout_ratio)
        else:
            self.dropout = None


        self.linear_c1 = MLP(input_dim=c1_in_channels, embed_dim=embedding_dim)
        self.linear_fuse = ConvModule(
            in_channels=embedding_dim, # 4C --> C --> embedding_dim
            out_channels=embedding_dim,
            kernel_size=1,
            # norm_cfg=dict(type='SyncBN', requires_grad=True)
        )
        self.linear_pred = nn.Conv2d(embedding_dim, self.num_classes, kernel_size=1)




    def forward(self, inputs):
        _, _, h_shvit, w_shvit = input_image.size()
        print("OKKK", h_shvit/4,w_shvit/4)


        x = shvit_bibi(inputs)#.to(device)
        c1 = x
        n, _, h, w = c1.shape

        print("=======================================================================================")
        print(c1.shape)
        _c1 = self.linear_c1(c1).permute(0,2,1).reshape(n, -1, c1.shape[2], c1.shape[3])
        # h_shvit/4,w_shvit/4
        _c1 = resize(_c1, size=(128, 128),mode='bilinear',align_corners=False)
        print(_c1.shape)

        _c = self.linear_fuse(torch.cat([_c1.to('cpu')], dim=1))

        
        
        x = self.dropout(_c)
        x = self.linear_pred(x)
        print("SSSSSSSSShape: ", x.shape)
        return x



tmp_result = SegFormerHead(num_classes=15, in_channels=448).forward(input_image) #.to(device)
tmp_result.shape

OKKK 128.0 128.0
x:  torch.Size([1, 224, 32, 32])
block1 out:  torch.Size([1, 224, 32, 32])
### PatchMerging ###
torch.Size([1, 224, 32, 32])
torch.Size([1, 336, 16, 16])
block2 out:  torch.Size([1, 336, 16, 16])
### PatchMerging ###
torch.Size([1, 336, 16, 16])
torch.Size([1, 448, 8, 8])
block3 out:  torch.Size([1, 448, 8, 8])
torch.Size([1, 448, 8, 8])
torch.Size([1, 448, 128, 128])
SSSSSSSSShape:  torch.Size([1, 15, 128, 128])


torch.Size([1, 15, 128, 128])

In [24]:
########################################## FERTIG!! ##########################################

In [25]:
dim = input_dim
hid_dim = int(dim * 4)
conv1 = Conv2d_BN(dim, hid_dim, 1, 1, 0)
conv1

NameError: name 'input_dim' is not defined

In [None]:
###########
#  x = self._transform_inputs(inputs) 

c3 = x
_c3 = linear_c3(c3).permute(0,2,1).reshape(n, -1, c3.shape[2], c3.shape[3])
_c3 = resize(_c3, size=c1.size()[2:],mode='bilinear',align_corners=False)

In [30]:
class SegFormerHead(SHViT):
    """
    SegFormer: Simple and Efficient Design for Semantic Segmentation with Transformers
    https://github.com/NVlabs/SegFormer/blob/master/local_configs/segformer/B0/segformer.b0.512x512.ade.160k.py#L18 - wegen feature_strides
    """
    def __init__(self, feature_strides, **kwargs):
        super(SegFormerHead, self).__init__(input_transform='multiple_select', **kwargs)
        assert len(feature_strides) == len(self.in_channels)
        assert min(feature_strides) == feature_strides[0]
        self.feature_strides = feature_strides

        c1_in_channels, c2_in_channels, c3_in_channels, c4_in_channels = self.in_channels

        decoder_params = kwargs['decoder_params']
        embedding_dim = decoder_params['embed_dim']

        # self.linear_c4 = MLP(input_dim=c4_in_channels, embed_dim=embedding_dim)
        self.linear_c3 = MLP(input_dim=c3_in_channels, embed_dim=embedding_dim)
        # self.linear_c2 = MLP(input_dim=c2_in_channels, embed_dim=embedding_dim)
        # self.linear_c1 = MLP(input_dim=c1_in_channels, embed_dim=embedding_dim)

        self.linear_fuse = ConvModule(
            in_channels=embedding_dim*4,
            out_channels=embedding_dim,
            kernel_size=1,
            norm_cfg=dict(type='SyncBN', requires_grad=True)
        )

        self.linear_pred = nn.Conv2d(embedding_dim, self.num_classes, kernel_size=1)



        def forward(self, inputs):
            x = self._transform_inputs(inputs)  # len=4, 1/4,1/8,1/16,1/32
            # c1, c2, c3, c4 = x

            ############## MLP decoder on C1-C4 ###########
            n, _, h, w = c4.shape

            _c4 = self.linear_c4(c4).permute(0,2,1).reshape(n, -1, c4.shape[2], c4.shape[3])
            _c4 = resize(_c4, size=c1.size()[2:],mode='bilinear',align_corners=False)

            # _c3 = self.linear_c3(c3).permute(0,2,1).reshape(n, -1, c3.shape[2], c3.shape[3])
            # _c3 = resize(_c3, size=c1.size()[2:],mode='bilinear',align_corners=False)

            # _c2 = self.linear_c2(c2).permute(0,2,1).reshape(n, -1, c2.shape[2], c2.shape[3])
            # _c2 = resize(_c2, size=c1.size()[2:],mode='bilinear',align_corners=False)

            # _c1 = self.linear_c1(c1).permute(0,2,1).reshape(n, -1, c1.shape[2], c1.shape[3])

            # _c = self.linear_fuse(torch.cat([_c4, _c3, _c2, _c1], dim=1))

            _c = self.linear_fuse(torch.cat([_c4], dim=1))
            x = self.dropout(_c)
            x = self.linear_pred(x)

            return x

In [21]:
class SegFormerHead(SHViT):
    """
    SegFormer: Simple and Efficient Design for Semantic Segmentation with Transformers
    """
    def __init__(self, feature_strides, **kwargs):
        super(SegFormerHead, self).__init__(input_transform='multiple_select', **kwargs)
        assert len(feature_strides) == len(self.in_channels)
        assert min(feature_strides) == feature_strides[0]
        self.feature_strides = feature_strides

        c1_in_channels, c2_in_channels, c3_in_channels, c4_in_channels = self.in_channels

        decoder_params = kwargs['decoder_params']
        embedding_dim = decoder_params['embed_dim']

        self.linear_c4 = MLP(input_dim=c4_in_channels, embed_dim=embedding_dim)
        # self.linear_c3 = MLP(input_dim=c3_in_channels, embed_dim=embedding_dim)
        # self.linear_c2 = MLP(input_dim=c2_in_channels, embed_dim=embedding_dim)
        # self.linear_c1 = MLP(input_dim=c1_in_channels, embed_dim=embedding_dim)

        self.linear_fuse = ConvModule(
            in_channels=embedding_dim*4,
            out_channels=embedding_dim,
            kernel_size=1,
            norm_cfg=dict(type='SyncBN', requires_grad=True)
        )

        self.linear_pred = nn.Conv2d(embedding_dim, self.num_classes, kernel_size=1)



        def forward(self, inputs):
            x = self._transform_inputs(inputs)  # len=4, 1/4,1/8,1/16,1/32
            # c1, c2, c3, c4 = x
            c4 = x

            ############## MLP decoder on C1-C4 ###########
            n, _, h, w = c4.shape

            _c4 = self.linear_c4(c4).permute(0,2,1).reshape(n, -1, c4.shape[2], c4.shape[3])
            _c4 = resize(_c4, size=c1.size()[2:],mode='bilinear',align_corners=False)

            # _c3 = self.linear_c3(c3).permute(0,2,1).reshape(n, -1, c3.shape[2], c3.shape[3])
            # _c3 = resize(_c3, size=c1.size()[2:],mode='bilinear',align_corners=False)

            # _c2 = self.linear_c2(c2).permute(0,2,1).reshape(n, -1, c2.shape[2], c2.shape[3])
            # _c2 = resize(_c2, size=c1.size()[2:],mode='bilinear',align_corners=False)

            # _c1 = self.linear_c1(c1).permute(0,2,1).reshape(n, -1, c1.shape[2], c1.shape[3])

            # _c = self.linear_fuse(torch.cat([_c4, _c3, _c2, _c1], dim=1))

            _c = self.linear_fuse(torch.cat([_c4], dim=1))
            x = self.dropout(_c)
            x = self.linear_pred(x)

            return x

In [None]:
512//16, 512//32, 512//64

In [None]:
shapes = [tensor.shape for tensor in outs]
print(shapes)