In [1]:
# Core PyTorch packages used by every cell below.
import torch
import torch.nn as nn
# NOTE(review): star import from the project-local `segformer` module. The cells
# below use OverlapPatchEmbeddings, TransformerBlock, M_EfficientSelfAtten and
# MixFFN_skip without defining them, so they presumably come from here -- confirm,
# and consider importing those names explicitly.
from segformer import *
from typing import Tuple
# `rearrange` is used by PatchExpand and FinalPatchExpand_X4 in the decoder section.
from einops import rearrange


from torch.nn import functional as F


# Marker line so the cell output shows that every import above succeeded.
print("\n\n --------The importing has been done!------------\n\n")



 --------The importing has been done!------------




The input size is 
data shape--------- torch.Size([1, 1, 224, 224]) torch.Size([1, 1, 224, 224])

image: torch.Size([1, 1, 224, 224])

label: torch.Size([1, 1, 224, 224])

In [2]:
# Dummy single-channel batch standing in for one 224x224 grayscale image.
inputs = torch.rand(1, 1, 224, 224)
print("Test the input of size {}".format(inputs.shape))

# The encoder's first patch embedding is built for 3 input channels, so a
# single-channel image is tiled along the channel axis.
inputs = inputs.repeat(1, 3, 1, 1) if inputs.shape[1] == 1 else inputs
print("Test the input of size {}".format(inputs.shape))

Test the input of size torch.Size([1, 1, 224, 224])
Test the input of size torch.Size([1, 3, 224, 224])


# Encoder

In [3]:
class MiT(nn.Module):
    """Four-stage hierarchical transformer encoder.

    Every stage is an overlapping patch embedding, a stack of transformer
    blocks and a LayerNorm.  The normalised tokens of each stage are folded back
    into a ``(B, C, H, W)`` feature map, which is both collected as an output
    and fed to the next stage.

    Args:
        image_size: side length of the input image; stage ``k`` is built with
            ``image_size // (4, 8, 16)`` for the later patch embeddings.
        dims: four channel widths, one per stage (e.g. ``[64, 128, 320, 512]``).
        layers: four integers, the number of transformer blocks per stage.
        token_mlp: token-MLP variant name forwarded to every ``TransformerBlock``.

    Note:
        The first patch embedding is created with 3 input channels, so the
        input image must have 3 channels.
    """

    def __init__(self, image_size, dims, layers, token_mlp='mix_skip'):
        super().__init__()
        patch_sizes = [7, 3, 3, 3]
        strides = [4, 2, 2, 2]
        padding_sizes = [3, 1, 1, 1]
        reduction_ratios = [8, 4, 2, 1]
        heads = [1, 2, 5, 8]

        # patch_embed
        # layers = [2, 2, 2, 2] dims = [64, 128, 320, 512]
        self.patch_embed1 = OverlapPatchEmbeddings(image_size, patch_sizes[0], strides[0], padding_sizes[0], 3, dims[0])
        self.patch_embed2 = OverlapPatchEmbeddings(image_size//4, patch_sizes[1], strides[1],  padding_sizes[1],dims[0], dims[1])
        self.patch_embed3 = OverlapPatchEmbeddings(image_size//8, patch_sizes[2], strides[2],  padding_sizes[2],dims[1], dims[2])
        self.patch_embed4 = OverlapPatchEmbeddings(image_size//16, patch_sizes[3], strides[3],  padding_sizes[3],dims[2], dims[3])

        # transformer encoder
        self.block1 = nn.ModuleList([
            TransformerBlock(dims[0], heads[0], reduction_ratios[0],token_mlp)
        for _ in range(layers[0])])
        self.norm1 = nn.LayerNorm(dims[0])

        self.block2 = nn.ModuleList([
            TransformerBlock(dims[1], heads[1], reduction_ratios[1],token_mlp)
        for _ in range(layers[1])])
        self.norm2 = nn.LayerNorm(dims[1])

        self.block3 = nn.ModuleList([
            TransformerBlock(dims[2], heads[2], reduction_ratios[2], token_mlp)
        for _ in range(layers[2])])
        self.norm3 = nn.LayerNorm(dims[2])

        self.block4 = nn.ModuleList([
            TransformerBlock(dims[3], heads[3], reduction_ratios[3], token_mlp)
        for _ in range(layers[3])])
        self.norm4 = nn.LayerNorm(dims[3])

        # self.head = nn.Linear(dims[3], num_classes)

    def forward(self, x: torch.Tensor) -> list:
        """Run the four stages and return their feature maps.

        Args:
            x: image batch of shape ``(B, 3, H, W)``.

        Returns:
            A list of four tensors, one per stage, each shaped
            ``(B, C_k, H_k, W_k)``.  For a 224x224 input and
            ``dims=[64, 128, 320, 512]`` these are 56x56, 28x28, 14x14 and 7x7.
        """
        B = x.shape[0]
        stages = (
            (self.patch_embed1, self.block1, self.norm1),
            (self.patch_embed2, self.block2, self.norm2),
            (self.patch_embed3, self.block3, self.norm3),
            (self.patch_embed4, self.block4, self.norm4),
        )

        outs = []
        for patch_embed, blocks, norm in stages:
            # The embedding returns the token sequence plus its spatial size.
            x, H, W = patch_embed(x)
            for blk in blocks:
                x = blk(x, H, W)
            x = norm(x)
            # Tokens -> channels-first feature map, which the next stage consumes.
            x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
            outs.append(x)

        return outs

In [4]:
# Encoder hyper-parameters: channel width and number of transformer blocks per stage.
dims = [64, 128, 320, 512]
layers = [2, 2, 2, 2]
token_mlp_mode = "mix_skip"

# Build the encoder for 224x224 images and run the dummy batch through it.
encoder = MiT(224, dims, layers, token_mlp_mode)
output_enc = encoder(inputs)

print("The number of stages from encoder: {}".format(len(output_enc)))
for i, stage_out in enumerate(output_enc):
    print("The size of output from the {} stage: {}".format(i, stage_out.shape))
    


The number of stages from encoder: 4
The size of output from the 0 stage: torch.Size([1, 64, 56, 56])
The size of output from the 1 stage: torch.Size([1, 128, 28, 28])
The size of output from the 2 stage: torch.Size([1, 320, 14, 14])
The size of output from the 3 stage: torch.Size([1, 512, 7, 7])


# Bridge

In [5]:
class BridgeLayer_4(nn.Module):
    """One bridge layer: attention over all four encoder scales, then a per-scale MixFFN.

    The four feature maps are flattened to tokens of the stage-1 channel width
    ``C`` and concatenated into one sequence.  After a residual attention step
    the sequence is split back per scale, each part is reshaped to its own
    channel width (``C``, ``2C``, ``5C``, ``8C``), passed through its own
    ``MixFFN_skip`` and flattened again; the result is added to the attention
    output.

    Args:
        dims: stage-1 channel width ``C`` (the token width of the sequence).
        head: number of attention heads forwarded to ``M_EfficientSelfAtten``.
        reduction_ratios: forwarded unchanged to ``M_EfficientSelfAtten``.

    Note:
        The split offsets are derived from the sequence length and assume
        square feature maps whose side halves from one stage to the next.
        For a 224x224 image this gives 56/28/14/7 and a 6076-token sequence.
    """

    # Channel width of stages 1-4 relative to stage 1 (matches the MixFFN sizes).
    _WIDTH_MULTS = (1, 2, 5, 8)

    def __init__(self, dims, head, reduction_ratios):
        super().__init__()

        self.norm1 = nn.LayerNorm(dims)
        self.attn = M_EfficientSelfAtten(dims, head, reduction_ratios)
        self.norm2 = nn.LayerNorm(dims)
        self.mixffn1 = MixFFN_skip(dims,dims*4)
        self.mixffn2 = MixFFN_skip(dims*2,dims*8)
        self.mixffn3 = MixFFN_skip(dims*5,dims*20)
        self.mixffn4 = MixFFN_skip(dims*8,dims*32)

    @classmethod
    def _stage_layout(cls, num_tokens):
        """Recover the per-stage layout from the concatenated sequence length.

        Returns a list of ``(side, width_mult, token_count)`` tuples, one per
        stage, where ``token_count`` is measured in stage-1-width tokens.

        Raises:
            ValueError: if ``num_tokens`` cannot come from four square stages
                whose side halves at every stage.
        """
        last = len(cls._WIDTH_MULTS) - 1
        # Stage k has side/2**k pixels per edge and mult_k * C channels, so it
        # contributes side**2 * mult_k / 4**k tokens of width C.
        numer = sum(mult * 4 ** (last - level) for level, mult in enumerate(cls._WIDTH_MULTS))
        denom = 4 ** last
        area, remainder = divmod(num_tokens * denom, numer)
        side = int(round(area ** 0.5))
        if remainder or side <= 0 or side * side != area or side % (2 ** last):
            raise ValueError(
                "cannot split a sequence of {} tokens into four square stages".format(num_tokens)
            )
        layout = []
        for level, mult in enumerate(cls._WIDTH_MULTS):
            stage_side = side // (2 ** level)
            layout.append((stage_side, mult, stage_side * stage_side * mult))
        return layout

    def forward(self, inputs):
        """Apply the bridge layer.

        Args:
            inputs: either the list/tuple of four encoder feature maps
                ``(B, C_k, H_k, W_k)`` (first bridge layer) or the token
                sequence ``(B, N, C)`` produced by a previous bridge layer.

        Returns:
            Token sequence of shape ``(B, N, C)``.
        """
        if isinstance(inputs, (list, tuple)):
            # If input type is list, then the block is the first bridge layer
            # The feature from the four stages should be concatenated together
            print("\n\n-----1-----")
            c1, c2, c3, c4 = inputs
            B, C, _, _= c1.shape
            c1f = c1.permute(0, 2, 3, 1).reshape(B, -1, C)  # 3136*64
            c2f = c2.permute(0, 2, 3, 1).reshape(B, -1, C)  # 1568*64
            c3f = c3.permute(0, 2, 3, 1).reshape(B, -1, C)  # 980*64
            c4f = c4.permute(0, 2, 3, 1).reshape(B, -1, C)  # 392*64

            print(c1f.shape, c2f.shape, c3f.shape, c4f.shape)
            inputs = torch.cat([c1f, c2f, c3f, c4f], -2)
            print("The shape of input is {}".format(inputs.shape))
        else:
            print("\n\n-----not 1-----")
            # In this case, the block is not the first bridge layer
            B,_,C = inputs.shape
            print(inputs.shape)
            print("The shape of input is {}".format(inputs.shape))

        # Validate the sequence length up front so a bad size fails clearly.
        layout = self._stage_layout(inputs.shape[1])

        tx1 = inputs + self.attn(self.norm1(inputs))
        tx = self.norm2(tx1)

        print("\n--------seq2imgs--------------\n")
        stage_tokens = []
        start = 0
        for idx, (_, mult, count) in enumerate(layout, start=1):
            # Regroup `count` width-C tokens into H*W tokens of the stage's own width.
            tem = tx[:, start:start + count, :].reshape(B, -1, C * mult)
            print("The shape of tem{} is {}".format(idx, tem.shape))
            stage_tokens.append(tem)
            start += count

        print("\n--------imgs passing through Emix_FFN and img2seq--------------\n")
        ffns = (self.mixffn1, self.mixffn2, self.mixffn3, self.mixffn4)
        mixed = []
        for idx, (ffn, tem, (side, _, _)) in enumerate(zip(ffns, stage_tokens, layout), start=1):
            mf = ffn(tem, side, side).reshape(B, -1, C)
            print("The shape of m{}f is {}".format(idx, mf.shape))
            mixed.append(mf)

        print("\n-------concatenate the seqs---------\n")
        t1 = torch.cat(mixed, -2)
        print("The shape of t1 is {}".format(t1.shape))
        tx2 = tx1 + t1
        # Report the tensor the label names (the original printed t1.shape here).
        print("The shape of tx2 = x1 + t1 is {}".format(tx2.shape))

        return tx2

In [6]:
class BridegeBlock_4(nn.Module):
    """Stack of four bridge layers that returns per-stage skip feature maps.

    The four encoder feature maps go through four ``BridgeLayer_4`` modules in
    sequence.  The resulting token sequence is then cut back into one
    channels-first feature map per encoder stage.

    Args:
        dims: stage-1 channel width, forwarded to every ``BridgeLayer_4``.
        head: number of attention heads, forwarded to every ``BridgeLayer_4``.
        reduction_ratios: forwarded unchanged to every ``BridgeLayer_4``.
    """

    def __init__(self, dims, head, reduction_ratios):
        super().__init__()
        self.bridge_layer1 = BridgeLayer_4(dims, head, reduction_ratios)
        self.bridge_layer2 = BridgeLayer_4(dims, head, reduction_ratios)
        self.bridge_layer3 = BridgeLayer_4(dims, head, reduction_ratios)
        self.bridge_layer4 = BridgeLayer_4(dims, head, reduction_ratios)

    def forward(self, x) -> list:
        """Bridge the encoder outputs.

        Args:
            x: list/tuple of four feature maps ``(B, C_k, H_k, W_k)``.  An
                already-flattened ``(B, N, C)`` sequence is also accepted, in
                which case the 224x224 layout (56/28/14/7) is assumed.

        Returns:
            List of four tensors with the same shapes as the encoder feature
            maps (channels-first views of channels-last memory).
        """
        # Record each stage's (C, H, W) now: after bridging everything is one
        # flat sequence and the per-stage geometry is no longer recoverable.
        if isinstance(x, (list, tuple)):
            stage_shapes = [tuple(feat.shape[1:]) for feat in x]
        else:
            stage_shapes = None

        bridge1 = self.bridge_layer1(x)
        bridge2 = self.bridge_layer2(bridge1)
        bridge3 = self.bridge_layer3(bridge2)
        bridge4 = self.bridge_layer4(bridge3)

        B,_,C = bridge4.shape
        if stage_shapes is None:
            # Legacy layout for a 224x224 image.
            stage_shapes = [(C, 56, 56), (C*2, 28, 28), (C*5, 14, 14), (C*8, 7, 7)]

        outs = []
        start = 0
        for chans, height, width in stage_shapes:
            # The sequence stores every stage as width-C tokens, so a stage with
            # `chans` channels occupies chans/C tokens per pixel.
            n_tokens = chans * height * width // C
            chunk = bridge4[:, start:start + n_tokens, :]
            outs.append(chunk.reshape(B, height, width, chans).permute(0,3,1,2))
            start += n_tokens
        print("\n\nThe shape of sk:\n{} {} {} {}".format(*[sk.shape for sk in outs]))

        return outs

In [7]:
# Use the same reduction ratios that MISSFormer passes to its bridge
# ([1, 2, 4, 8]) so this demo exercises the block with the model's own settings.
# The tensor shapes printed below do not depend on this choice.
reduction_ratios = [1, 2, 4, 8]
bridge = BridegeBlock_4(64, 1, reduction_ratios)
output_br = bridge(output_enc) 
print("The number of stages from bridge: {}".format(len(output_br)))
for i in range(len(output_br)):
    print("The size of output from the {} stage: {}".format(i, output_br[i].shape))



-----1-----
torch.Size([1, 3136, 64]) torch.Size([1, 1568, 64]) torch.Size([1, 980, 64]) torch.Size([1, 392, 64])
The shape of input is torch.Size([1, 6076, 64])

--------seq2imgs--------------

The shape of tem1 is torch.Size([1, 3136, 64])
The shape of tem2 is torch.Size([1, 784, 128])
The shape of tem3 is torch.Size([1, 196, 320])
The shape of tem4 is torch.Size([1, 49, 512])

--------imgs passing through Emix_FFN and img2seq--------------

The shape of m1f is torch.Size([1, 3136, 64])
The shape of m2f is torch.Size([1, 1568, 64])
The shape of m3f is torch.Size([1, 980, 64])
The shape of m4f is torch.Size([1, 392, 64])

-------concatenate the seqs---------

The shape of t1 is torch.Size([1, 6076, 64])
The shape of tx2 = x1 + t1 is torch.Size([1, 6076, 64])


-----not 1-----
torch.Size([1, 6076, 64])
The shape of input is torch.Size([1, 6076, 64])

--------seq2imgs--------------

The shape of tem1 is torch.Size([1, 3136, 64])
The shape of tem2 is torch.Size([1, 784, 128])
The shape

# Decoder

In [8]:
class PatchExpand(nn.Module):
    """Double the spatial resolution of a token sequence.

    With ``dim_scale == 2`` a bias-free linear layer first doubles the channel
    count; otherwise the tokens are passed through unchanged.  The channels are
    then redistributed over a 2x2 pixel neighbourhood, leaving a quarter of the
    channels per output token, and the result is normalised.

    Args:
        input_resolution: ``(H, W)`` of the incoming token grid.
        dim: channel width of the incoming tokens.
        dim_scale: 2 to insert the channel-doubling linear layer.
        norm_layer: normalisation class, built with ``dim // dim_scale`` features.
    """

    def __init__(self, input_resolution, dim, dim_scale=2, norm_layer=nn.LayerNorm):
        super().__init__()
        self.input_resolution = input_resolution
        self.dim = dim
        if dim_scale == 2:
            self.expand = nn.Linear(dim, 2 * dim, bias=False)
        else:
            self.expand = nn.Identity()
        self.norm = norm_layer(dim // dim_scale)

    def forward(self, x):
        """
        x: B, H*W, C

        Returns tokens of shape B, (2H)*(2W), C'/4 where C' is the channel
        count after ``self.expand``.
        """
        height, width = self.input_resolution
        expanded = self.expand(x)

        batch, num_tokens, channels = expanded.shape
        assert num_tokens == height * width, "input feature has wrong size"

        out_channels = channels // 4
        grid = expanded.view(batch, height, width, channels)
        # Move each group of 4 channel slices onto a 2x2 block of pixels.
        grid = rearrange(grid, 'b h w (p1 p2 c)-> b (h p1) (w p2) c', p1=2, p2=2, c=out_channels)
        tokens = grid.view(batch, -1, out_channels)

        return self.norm(tokens.clone())

In [9]:
class FinalPatchExpand_X4(nn.Module):
    """Up-sample a token grid by ``dim_scale`` per side while keeping the channel width.

    A bias-free linear layer multiplies the channel count by ``dim_scale ** 2``;
    those channels are then spread over a ``dim_scale x dim_scale`` block of
    pixels so every output token again has ``dim`` channels.

    Args:
        input_resolution: ``(H, W)`` of the incoming token grid.
        dim: channel width of the incoming (and outgoing) tokens.
        dim_scale: up-sampling factor per spatial side (default 4).
        norm_layer: normalisation class, built with ``dim`` features.
    """

    def __init__(self, input_resolution, dim, dim_scale=4, norm_layer=nn.LayerNorm):
        super().__init__()
        self.input_resolution = input_resolution
        self.dim = dim
        self.dim_scale = dim_scale
        # The expansion must match the dim_scale x dim_scale block used in
        # forward(); a fixed factor of 16 is only right for dim_scale == 4.
        self.expand = nn.Linear(dim, (dim_scale ** 2) * dim, bias=False)
        self.output_dim = dim 
        self.norm = norm_layer(self.output_dim)

    def forward(self, x):
        """
        x: B, H*W, C

        Returns tokens of shape B, (dim_scale*H)*(dim_scale*W), C.
        """
        H, W = self.input_resolution
        x = self.expand(x)
        B, L, C = x.shape
        assert L == H * W, "input feature has wrong size"

        x = x.view(B, H, W, C)
        # Move each group of dim_scale**2 channel slices onto a pixel block.
        x = rearrange(x, 'b h w (p1 p2 c)-> b (h p1) (w p2) c', p1=self.dim_scale, p2=self.dim_scale, c=C//(self.dim_scale**2))
        x = x.view(B,-1,self.output_dim)
        x= self.norm(x.clone())

        return x

In [12]:
class MyDecoderLayer(nn.Module):
    """One decoder stage: fuse with a skip connection, refine, then up-sample.

    Args:
        input_size: ``(H, W)`` token-grid size handed to the up-sampling layer.
        in_out_chan: pair ``[dims, out_dim]``; the fusion linear layer maps
            ``dims * 2`` (``dims * 4`` for the last stage) to ``out_dim``.
        heads: attention heads for both transformer blocks.
        reduction_ratios: reduction ratio for both transformer blocks.
        token_mlp_mode: token-MLP variant name for both transformer blocks.
        n_class: output channels of the 1x1 classifier (last stage only).
        norm_layer: normalisation class used by the up-sampling layer.
        is_last: when True, up-sample 4x and append the 1x1 classifier.
    """

    def __init__(self, input_size, in_out_chan, heads, reduction_ratios,token_mlp_mode, n_class=9, norm_layer=nn.LayerNorm, is_last=False):
        super().__init__()
        dims, out_dim = in_out_chan[0], in_out_chan[1]

        if is_last:
            self.concat_linear = nn.Linear(dims*4, out_dim)
            # 4x up-sampling followed by a per-pixel classifier
            self.layer_up = FinalPatchExpand_X4(input_resolution=input_size, dim=out_dim, dim_scale=4, norm_layer=norm_layer)
            self.last_layer = nn.Conv2d(out_dim, n_class,1)
        else:
            self.concat_linear = nn.Linear(dims*2, out_dim)
            # 2x up-sampling, no classifier
            self.layer_up = PatchExpand(input_resolution=input_size, dim=out_dim, dim_scale=2, norm_layer=norm_layer)
            self.last_layer = None

        self.layer_former_1 = TransformerBlock(out_dim, heads, reduction_ratios, token_mlp_mode)
        self.layer_former_2 = TransformerBlock(out_dim, heads, reduction_ratios, token_mlp_mode)

        # Re-initialise every sub-module: Xavier-uniform weights and zero biases
        # for Linear / Conv2d, unit weight and zero bias for LayerNorm.
        for module in self.modules():
            if isinstance(module, (nn.Linear, nn.Conv2d)):
                nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.zeros_(module.bias)
            elif isinstance(module, nn.LayerNorm):
                nn.init.ones_(module.weight)
                nn.init.zeros_(module.bias)

    def forward(self, x1, x2=None):
        """Decode one stage.

        Args:
            x1: tokens from the previous decoder stage, ``(B, H*W, C1)``.
            x2: optional channels-last skip feature map ``(B, H, W, C2)``.

        Returns:
            Up-sampled tokens, or class logits ``(B, n_class, 4H, 4W)`` when
            this is the last stage.  Without ``x2`` the input is only
            up-sampled.
        """
        if x2 is None:
            # No skip connection: only the up-sampling layer is applied.
            print("What is this else condition?")
            print("x1 shape",x1.shape)
            return self.layer_up(x1)

        print("x1 shape:", x1.shape)
        print("x2 shape:", x2.shape)
        b, h, w, c = x2.shape
        x2 = x2.view(b, -1, c)
        print("------",x1.shape, x2.shape)
        fused = torch.cat([x1, x2], dim=-1)
        print("-----catx shape", fused.shape)

        hidden = self.concat_linear(fused)
        hidden = self.layer_former_1(hidden, h, w)
        hidden = self.layer_former_2(hidden, h, w)
        upsampled = self.layer_up(hidden)

        if self.last_layer is None:
            return upsampled
        # Last stage: fold the 4x up-sampled tokens into an image and classify.
        return self.last_layer(upsampled.view(b, 4*h, 4*w, -1).permute(0,3,1,2))

In [13]:
b,c,_,_ = output_br[3].shape
print("The shapes of output_br: \n output_br[3]{} \noutput_br[2]{} \noutput_br[1]{} \noutput_br[0]{}\n".format(output_br[3].shape, output_br[2].shape,output_br[1].shape, output_br[0].shape))

# Decoder hyper-parameters, identical to the ones MISSFormer.__init__ uses.
# These ratios must not be overwritten before the decoders are built: the
# original cell re-assigned them to [1, 2, 4, 8] at this point, which reversed
# the ratio every decoder stage received.
reduction_ratios = [8, 4, 2, 1]
heads = [1, 2, 5, 8]
d_base_feat_size = 7 #16 for 512 inputsize   7for 224
in_out_chan = [[32, 64],[144, 128],[288, 320],[512, 512]]

dims, layers = [[64, 128, 320, 512], [2, 2, 2, 2]]
num_classes=9
token_mlp_mode="mix_skip"


# One decoder per stage, from the coarsest grid (7x7) to the finest (56x56).
decoder_3= MyDecoderLayer((d_base_feat_size,d_base_feat_size), in_out_chan[3], heads[3], reduction_ratios[3],token_mlp_mode, n_class=num_classes)
decoder_2= MyDecoderLayer((d_base_feat_size*2,d_base_feat_size*2),in_out_chan[2], heads[2], reduction_ratios[2], token_mlp_mode, n_class=num_classes)
decoder_1= MyDecoderLayer((d_base_feat_size*4,d_base_feat_size*4), in_out_chan[1], heads[1], reduction_ratios[1], token_mlp_mode, n_class=num_classes)
decoder_0= MyDecoderLayer((d_base_feat_size*8,d_base_feat_size*8), in_out_chan[0], heads[0], reduction_ratios[0], token_mlp_mode, n_class=num_classes, is_last=True)

#---------------Decoder-------------------------     
# The deepest stage has no skip connection; it is only up-sampled.
print("stage3-----")   
tmp_3 = decoder_3(output_br[3].permute(0,2,3,1).view(b,-1,c))
print("stage2-----")   
tmp_2 = decoder_2(tmp_3, output_br[2].permute(0,2,3,1))
print("stage1-----")   
tmp_1 = decoder_1(tmp_2, output_br[1].permute(0,2,3,1))
print("stage0-----")  
tmp_0 = decoder_0(tmp_1, output_br[0].permute(0,2,3,1))

print("The shapes of tmp: \ntmp_3:{} \ntmp_2:{} \ntmp_1:{} \ntmp_0:{}\n".format(tmp_3.shape, tmp_2.shape,tmp_1.shape, tmp_0.shape))

The shapes of output_br: 
 output_br[3]torch.Size([1, 512, 7, 7]) 
output_br[2]torch.Size([1, 320, 14, 14]) 
output_br[1]torch.Size([1, 128, 28, 28]) 
output_br[0]torch.Size([1, 64, 56, 56])

stage3-----
What is this else condition?
x1 shape torch.Size([1, 49, 512])
stage2-----
x1 shape: torch.Size([1, 196, 256])
x2 shape: torch.Size([1, 14, 14, 320])
------ torch.Size([1, 196, 256]) torch.Size([1, 196, 320])
-----catx shape torch.Size([1, 196, 576])
stage1-----
x1 shape: torch.Size([1, 784, 160])
x2 shape: torch.Size([1, 28, 28, 128])
------ torch.Size([1, 784, 160]) torch.Size([1, 784, 128])
-----catx shape torch.Size([1, 784, 288])
stage0-----
x1 shape: torch.Size([1, 3136, 64])
x2 shape: torch.Size([1, 56, 56, 64])
------ torch.Size([1, 3136, 64]) torch.Size([1, 3136, 64])
-----catx shape torch.Size([1, 3136, 128])
The shapes of tmp: 
tmp_3:torch.Size([1, 196, 256]) 
tmp_2:torch.Size([1, 784, 160]) 
tmp_1:torch.Size([1, 3136, 64]) 
tmp_0:torch.Size([1, 9, 224, 224])



# The MISSFormer Architecture

In [None]:
class MISSFormer(nn.Module):
    """Full segmentation model: MiT encoder, four-layer bridge and four decoder stages.

    Args:
        num_classes: number of output channels of the final classifier.
        token_mlp_mode: token-MLP variant name passed to the encoder and decoders.
        encoder_pretrained: accepted for interface compatibility; not used by
            the code in this class.

    Note:
        The encoder is built for 224x224 images and the decoder grid sizes
        (7, 14, 28, 56) follow from that.
    """

    def __init__(self, num_classes=9, token_mlp_mode="mix_skip", encoder_pretrained=True):
        super().__init__()

        # Encoder: channel width and transformer depth per stage.
        stage_dims = [64, 128, 320, 512]
        stage_depths = [2, 2, 2, 2]
        self.backbone = MiT(224, stage_dims, stage_depths, token_mlp_mode)

        # Bridge over the four encoder scales.
        self.reduction_ratios = [1, 2, 4, 8]
        self.bridge = BridegeBlock_4(64, 1, self.reduction_ratios)

        # Decoder settings, indexed by stage (0 = finest, 3 = coarsest).
        decoder_ratios = [8, 4, 2, 1]
        decoder_heads = [1, 2, 5, 8]
        coarsest_side = 7 #16 for 512 inputsize   7for 224
        in_out_chan = [[32, 64],[144, 128],[288, 320],[512, 512]]

        def build_decoder(stage, **extra):
            # The token grid doubles per step from the coarsest stage (3) to stage 0.
            side = coarsest_side * 2 ** (3 - stage)
            return MyDecoderLayer(
                (side, side),
                in_out_chan[stage],
                decoder_heads[stage],
                decoder_ratios[stage],
                token_mlp_mode,
                n_class=num_classes,
                **extra,
            )

        self.decoder_3 = build_decoder(3)
        self.decoder_2 = build_decoder(2)
        self.decoder_1 = build_decoder(1)
        self.decoder_0 = build_decoder(0, is_last=True)

    def forward(self, x):
        """Segment a batch of images.

        Args:
            x: image batch ``(B, 1 or 3, H, W)``; a single channel is tiled to three.

        Returns:
            The output of the last decoder stage.
        """
        #---------------Encoder-------------------------
        if x.shape[1] == 1:
            x = x.repeat(1, 3, 1, 1)

        skips = self.bridge(self.backbone(x))  # list of four (B, C, H, W) maps

        #---------------Decoder-------------------------
        # The coarsest map has no partner to fuse with: flatten it to tokens.
        b, c = skips[3].shape[:2]
        out = self.decoder_3(skips[3].permute(0, 2, 3, 1).view(b, -1, c))

        # Remaining stages fuse the running output with a channels-last skip map.
        remaining = (
            (self.decoder_2, skips[2]),
            (self.decoder_1, skips[1]),
            (self.decoder_0, skips[0]),
        )
        for decoder, skip in remaining:
            out = decoder(out, skip.permute(0, 2, 3, 1))

        return out