In [155]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [241]:
def resize(input,
           size=None,
           scale_factor=None,
           mode='nearest',
           align_corners=None):
    """Resample a tensor with ``F.interpolate``, accepting a ``torch.Size`` target.

    Args:
        input: tensor handed to ``F.interpolate`` unchanged.
        size: target output size; a ``torch.Size`` is turned into a tuple of
            plain Python ints before it is forwarded.
        scale_factor: forwarded to ``F.interpolate``.
        mode: interpolation mode name, forwarded to ``F.interpolate``.
        align_corners: forwarded to ``F.interpolate``.
    Returns:
        The tensor returned by ``F.interpolate``.
    """
    target_size = size
    if isinstance(target_size, torch.Size):
        # Unwrap torch.Size into an ordinary tuple of ints
        target_size = tuple(map(int, target_size))

    return F.interpolate(input,
                         size=target_size,
                         scale_factor=scale_factor,
                         mode=mode,
                         align_corners=align_corners)

# Encoder

In [242]:
class OverlapPatchEmbeddings(nn.Module):
    """Overlapping patch embedding: a strided convolution followed by LayerNorm.

    The convolution uses ``patch_size // 2`` padding on both spatial axes, so
    neighbouring patches overlap whenever ``stride < patch_size``.
    """

    def __init__(
        self,
        image_size: int,
        patch_size: int,
        stride: int,
        in_channels: int = 4,
        embedding_dim: int = 768
    ):
        super().__init__()
        # Kept for reference only; it is not read anywhere else in this class
        self.image_size = image_size

        half_patch = patch_size // 2
        self.embeddings_projection = nn.Conv2d(in_channels,
                                               embedding_dim,
                                               kernel_size=patch_size,
                                               stride=stride,
                                               padding=(half_patch, half_patch))
        self.layer_norm = nn.LayerNorm(embedding_dim)

    def forward(self, x: torch.Tensor) -> tuple:
        """
        Args:
            x: (B, C, H, W)
        Returns:
            tokens: (B, H' * W', embedding_dim), layer-normalised
            num_patches_H: int, height H' of the projected feature map
            num_patches_W: int, width W' of the projected feature map
        """
        feature_map = self.embeddings_projection(x)  # (B, embedding_dim, H', W')
        num_patches_H, num_patches_W = feature_map.shape[2:]

        # (B, embedding_dim, H' * W') -> (B, H' * W', embedding_dim)
        tokens = self.reshape_for_layer_norm(feature_map.flatten(start_dim=2))

        return self.layer_norm(tokens), num_patches_H, num_patches_W

    def reshape_for_layer_norm(self, x: torch.Tensor) -> torch.Tensor:
        """Swap dims 1 and 2 so the channel axis is last, as LayerNorm expects."""
        return torch.transpose(x, 1, 2)


In [243]:
# Patch embedding with a 7x7 kernel and stride 4 for 4-channel 512 x 512 inputs.
patch_embed1 = OverlapPatchEmbeddings(image_size=512,
                                      patch_size=7,
                                      stride=4,
                                      in_channels=4,
                                      embedding_dim=64)

In [244]:
# Random dummy batch in (B, C, H, W) layout: 2 images, 4 channels, 512 x 512.
img_batch = torch.randn(2, 4, 512, 512)

In [245]:
# Shape check: stride 4 on 512 x 512 gives 128 x 128 = 16384 tokens of width 64.
patch_embed1(img_batch)[0].shape

torch.Size([2, 16384, 64])

In [246]:
class EfficientSelfAttention(nn.Module):
    """Multi-head self-attention with optional spatial reduction of keys/values.

    When ``reduction_ratio > 1`` the token sequence is folded back into a
    (B, C, H, W) map and downsampled by a strided convolution before keys and
    values are computed, so the attention matrix has shape (N, N_reduced)
    instead of (N, N).
    """

    def __init__(
        self,
        d_model: int,
        num_heads: int = 8,
        qkv_bias: bool = False,
        attention_dropout: float = 0.,
        projection_dropout: float = 0.,
        reduction_ratio: int = 1
    ):
        super().__init__()
        # Based on original paper (Attention is all you need), d_model must be divisible by num_heads
        assert d_model % num_heads == 0, f"d_model ({d_model}) must be divisible by num_heads ({num_heads}) !"

        self.d_model = d_model
        self.num_heads = num_heads
        self.reduction_ratio = reduction_ratio
        # 1 / sqrt(head_dim), the scaling factor of scaled dot-product attention
        self.scale = (d_model // num_heads) ** -0.5

        self.q = nn.Linear(d_model, d_model, bias=qkv_bias)
        self.k = nn.Linear(d_model, d_model, bias=qkv_bias)
        self.v = nn.Linear(d_model, d_model, bias=qkv_bias)

        self.attention_dropout = nn.Dropout(p=attention_dropout)

        self.projection = nn.Linear(d_model, d_model)
        self.projection_dropout = nn.Dropout(p=projection_dropout)

        # The reduction layers only exist when a reduction is actually requested
        if reduction_ratio > 1:
            self.reduction = nn.Conv2d(d_model,
                                       d_model,
                                       kernel_size=reduction_ratio,
                                       stride=reduction_ratio)
            self.reduction_layer_norm = nn.LayerNorm(d_model)

    def _split_heads(self, tokens: torch.Tensor) -> torch.Tensor:
        """(B, L, C) -> (B, num_heads, L, C // num_heads)."""
        batch_size, _, channels = tokens.shape
        per_head = tokens.reshape(batch_size, -1, self.num_heads, channels // self.num_heads)
        return per_head.permute(0, 2, 1, 3)

    def forward(self, x: torch.Tensor, H: int, W: int) -> torch.Tensor:
        """
        Args:
            x: (B, N, C) token sequence with N == H * W
            H: int, height of the token grid
            W: int, width of the token grid
        Returns:
            x: (B, N, C)
        """
        B, N, C = x.shape

        queries = self._split_heads(self.q(x))  # (B, num_heads, N, C // num_heads)

        # Keys/values are computed from the full sequence, or from a spatially reduced one
        kv_source = x
        if self.reduction_ratio > 1:
            feature_map = x.permute(0, 2, 1).reshape(B, C, H, W)  # (B, C, H, W)
            reduced = self.reduction(feature_map).reshape(B, C, -1).permute(0, 2, 1)  # (B, N_reduced, C)
            kv_source = self.reduction_layer_norm(reduced)

        keys = self._split_heads(self.k(kv_source))      # (B, num_heads, N_kv, C // num_heads)
        values = self._split_heads(self.v(kv_source))    # (B, num_heads, N_kv, C // num_heads)

        scores = queries.matmul(keys.transpose(-2, -1)) * self.scale  # (B, num_heads, N, N_kv)
        weights = self.attention_dropout(scores.softmax(dim=-1))

        context = weights.matmul(values).transpose(1, 2).reshape(B, N, C)  # (B, N, C)
        return self.projection_dropout(self.projection(context))
        
        

In [247]:
# Attention over 64-dim tokens with 8 heads; keys/values come from a token grid
# downsampled by a stride-8 convolution (reduction_ratio=8).
efficient_self_attention = EfficientSelfAttention(d_model=64,
                                                  num_heads=8,
                                                  reduction_ratio=8)

In [248]:
# Shape check: 56 x 56 = 3136 tokens in, the same (10, 3136, 64) shape out.
efficient_self_attention(torch.randn(10, (56*56), 64), 56, 56).shape

torch.Size([10, 3136, 64])

In [249]:
class DWConv(nn.Module):
    """
    Taken from: https://github.com/NVlabs/SegFormer/tree/master

    3x3 depthwise convolution (``groups == dim``) applied to a token sequence
    that is temporarily folded back into its (B, C, H, W) grid.
    """
    def __init__(self, dim=768):
        super().__init__()
        self.dwconv = nn.Conv2d(dim,
                                dim,
                                kernel_size=3,
                                stride=1,
                                padding=1,
                                bias=True,
                                groups=dim)

    def forward(self, x, H, W):
        """
        Args:
            x: (B, N, C) token sequence with N == H * W
            H: int, height of the token grid
            W: int, width of the token grid
        Returns:
            (B, N, C) tokens after the depthwise convolution
        """
        batch_size, _, channels = x.shape
        # Tokens -> image grid so the convolution can mix spatial neighbours
        feature_map = x.transpose(1, 2).view(batch_size, channels, H, W)
        # Image grid -> tokens again
        return self.dwconv(feature_map).flatten(2).transpose(1, 2)

In [250]:
class MixFFN(nn.Module):
    """Feed-forward network with a depthwise convolution between its two linear layers."""

    def __init__(
        self,
        in_features: int,
        hidden_features=None,
        out_features=None,
        dropout: float = 0.,
    ):
        super().__init__()
        # Both optional widths fall back to the input width when not given
        hidden_dim = hidden_features or in_features
        output_dim = out_features or in_features

        self.fc1 = nn.Linear(in_features, hidden_dim)
        self.conv = DWConv(dim=hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, output_dim)
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x: torch.Tensor, H: int, W: int) -> torch.Tensor:
        """
        Args:
            x: (B, N, C)
            H: int, height of the token grid (forwarded to the depthwise conv)
            W: int, width of the token grid (forwarded to the depthwise conv)
        Returns:
            x: (B, N, out_features)
        """
        hidden = self.conv(self.fc1(x), H, W)
        hidden = self.dropout(F.gelu(hidden))
        # The same dropout module is applied after each linear layer
        return self.dropout(self.fc2(hidden))

In [251]:
# MixFFN with default widths: hidden and output sizes both fall back to in_features=64.
mlp = MixFFN(in_features=64)

In [252]:
# Shape check: the feed-forward keeps the (10, 3136, 64) token shape for a 56 x 56 grid.
mlp(torch.randn(10, (56*56), 64), 56, 56).shape

torch.Size([10, 3136, 64])

In [253]:
class TransformerBlock(nn.Module):
    """Pre-norm encoder block: attention and MixFFN, each wrapped in a residual connection."""

    def __init__(
        self,
        d_model: int,
        num_heads: int,
        mlp_ratio: int = 4,
        qkv_bias: bool = False,
        attention_dropout: float = 0.,
        dropout: float = 0.,
        reduction_ratio: int = 1
    ):
        super().__init__()
        self.norm1 = nn.LayerNorm(d_model)
        self.attention = EfficientSelfAttention(d_model=d_model,
                                                num_heads=num_heads,
                                                qkv_bias=qkv_bias,
                                                attention_dropout=attention_dropout,
                                                projection_dropout=dropout,
                                                reduction_ratio=reduction_ratio)

        self.norm2 = nn.LayerNorm(d_model)
        # Hidden width of the feed-forward is mlp_ratio times the model width
        self.mix_ffn = MixFFN(in_features=d_model,
                              hidden_features=int(d_model * mlp_ratio),
                              dropout=dropout)

    def forward(self, x: torch.Tensor, H: int, W: int) -> torch.Tensor:
        """
        Args:
            x: (B, N, C)
            H: int, height of the token grid
            W: int, width of the token grid
        Returns:
            x: (B, N, C)
        """
        attended = self.attention(self.norm1(x), H, W)
        x = x + attended

        fed_forward = self.mix_ffn(self.norm2(x), H, W)
        return x + fed_forward
        

In [254]:
# Single encoder block: 64-dim tokens, 8 heads, 4x feed-forward expansion,
# keys/values computed after a stride-8 spatial reduction.
transformer_block = TransformerBlock(d_model=64,
                                     num_heads=8,
                                     mlp_ratio=4,
                                     reduction_ratio=8)

In [255]:
# Shape check: the block preserves the (10, 3136, 64) token shape.
transformer_block(torch.randn(10, (56*56), 64), 56, 56).shape

torch.Size([10, 3136, 64])

In [256]:
class MixVisionTransformer(nn.Module):
    """Hierarchical Mix Transformer encoder with four stages.

    Every stage applies an overlapping patch embedding, a stack of
    ``TransformerBlock`` layers and a final LayerNorm, then folds the tokens
    back into a (B, C, H, W) feature map that feeds the next stage.

    Args:
        image_size: nominal input resolution; only forwarded to the patch embeddings.
        in_channels: number of channels of the input image.
        num_classes: stored on the instance; not used by the encoder itself.
        embedding_dims: token width of each of the four stages.
        num_heads: number of attention heads per stage.
        mlp_ratios: feed-forward expansion factor per stage.
        qkv_bias: whether the query/key/value projections have a bias.
        attention_dropout: dropout applied to the attention weights.
        dropout: dropout applied after the attention projection and in the feed-forward.
        reduction_ratios: key/value spatial reduction ratio per stage.
        depths: number of transformer blocks per stage.
    """

    def __init__(
        self,
        image_size: int,
        in_channels: int = 4,
        num_classes: int = 1,
        embedding_dims: list = [64, 128, 256, 512],
        num_heads: list = [1, 2, 4, 8],
        mlp_ratios: list = [4, 4, 4, 4],
        qkv_bias: bool = False,
        attention_dropout: float = 0.,
        dropout: float = 0.,
        reduction_ratios: list = [8, 4, 2, 1],
        depths: list = [3, 4, 6, 3]
    ):
        super().__init__()
        self.num_classes = num_classes
        # Store a copy: keeping a reference to the (mutable) default list would let an
        # in-place edit of `model.depths` silently change the default for later models.
        self.depths = list(depths)

        # Per-stage patch embedding settings: (image_size, patch_size, stride, in_channels)
        patch_settings = [
            (image_size, 7, 4, in_channels),
            (image_size // 4, 3, 2, embedding_dims[0]),
            (image_size // 8, 3, 2, embedding_dims[1]),
            (image_size // 16, 3, 2, embedding_dims[2]),
        ]

        # Registration order is patch_embed1..4 first, then block1, norm1, block2, norm2, ...
        # The attribute names are part of the state_dict layout and must not change.
        for stage, (stage_image_size, patch_size, stride, stage_in_channels) in enumerate(patch_settings):
            patch_embed = OverlapPatchEmbeddings(image_size=stage_image_size,
                                                 patch_size=patch_size,
                                                 stride=stride,
                                                 in_channels=stage_in_channels,
                                                 embedding_dim=embedding_dims[stage])
            setattr(self, f"patch_embed{stage + 1}", patch_embed)

        for stage in range(len(patch_settings)):
            blocks = nn.ModuleList([
                TransformerBlock(d_model=embedding_dims[stage],
                                 num_heads=num_heads[stage],
                                 mlp_ratio=mlp_ratios[stage],
                                 qkv_bias=qkv_bias,
                                 attention_dropout=attention_dropout,
                                 dropout=dropout,
                                 reduction_ratio=reduction_ratios[stage]) for _ in range(depths[stage])
            ])
            setattr(self, f"block{stage + 1}", blocks)
            setattr(self, f"norm{stage + 1}", nn.LayerNorm(normalized_shape=embedding_dims[stage]))

    def forward(self, x: torch.Tensor):
        """
        Args:
            x: (B, in_channels, H, W)
        Returns:
            list with one (B, C_i, H_i, W_i) feature map per stage, in stage order
        """
        B = x.shape[0]
        outputs = []

        stages = (self.compute_stage_1,
                  self.compute_stage_2,
                  self.compute_stage_3,
                  self.compute_stage_4)
        for compute_stage in stages:
            # Each stage consumes the feature map produced by the previous one
            x = compute_stage(x, B)
            outputs.append(x)

        return outputs

    def _run_stage(self, x: torch.Tensor, B: int, patch_embed, blocks, norm) -> torch.Tensor:
        """Run one encoder stage: embed patches, apply the blocks, normalise, fold to (B, C, H, W)."""
        x, H, W = patch_embed(x)

        for transformer_block in blocks:
            x = transformer_block(x, H, W)

        x = norm(x)

        # (B, H * W, C) -> (B, C, H, W)
        return x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()

    def compute_stage_1(self, x: torch.Tensor, B: int) -> torch.Tensor:
        """Stage 1 applied to the input image batch."""
        return self._run_stage(x, B, self.patch_embed1, self.block1, self.norm1)

    def compute_stage_2(self, x: torch.Tensor, B: int) -> torch.Tensor:
        """Stage 2 applied to the stage-1 feature map."""
        return self._run_stage(x, B, self.patch_embed2, self.block2, self.norm2)

    def compute_stage_3(self, x: torch.Tensor, B: int) -> torch.Tensor:
        """Stage 3 applied to the stage-2 feature map."""
        return self._run_stage(x, B, self.patch_embed3, self.block3, self.norm3)

    def compute_stage_4(self, x: torch.Tensor, B: int) -> torch.Tensor:
        """Stage 4 applied to the stage-3 feature map."""
        return self._run_stage(x, B, self.patch_embed4, self.block4, self.norm4)

In [257]:
# Encoder built with the default per-stage settings.
mit = MixVisionTransformer(image_size=512)

In [258]:
# Inference-only smoke test: no_grad() avoids keeping the autograd graph of every
# encoder block alive for a 10 x 4 x 512 x 512 batch, which only costs memory here.
with torch.no_grad():
    mit_outputs = mit(torch.randn(10, 4, 512, 512))

In [259]:
# One feature map per stage: channels grow while the spatial size halves after stage 1.
for output in mit_outputs:
    print(output.shape)

torch.Size([10, 64, 128, 128])
torch.Size([10, 128, 64, 64])
torch.Size([10, 256, 32, 32])
torch.Size([10, 512, 16, 16])


# Decoder

In [260]:
class MLP(nn.Module):
    """Per-pixel linear projection of a feature map into token form."""

    def __init__(
        self,
        input_dim: int,
        embedding_dim: int
    ):
        super().__init__()
        self.fc = nn.Linear(input_dim, embedding_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: (B, input_dim, H, W)
        Returns:
            x: (B, H * W, embedding_dim)
        """
        # Collapse the spatial axes and move channels last: (B, H * W, input_dim)
        tokens = torch.flatten(x, start_dim=2).transpose(1, 2)
        return self.fc(tokens)

In [261]:
# Decoder-side linear projection: 512 input channels -> 768-dim embedding.
# NOTE(review): this rebinds `mlp`, which held a MixFFN instance in an earlier cell.
mlp = MLP(input_dim=512, embedding_dim=768)

In [262]:
# Shape check: 16 x 16 = 256 tokens, each projected to 768 features.
mlp(torch.randn(10, 512, 16, 16)).shape

torch.Size([10, 256, 768])

In [263]:
class SegFormerSemanticSegmentationHead(nn.Module):
    """All-MLP segmentation head that fuses the four encoder feature maps.

    Each feature map is linearly projected to ``embedding_dim`` channels,
    upsampled to the resolution of the first (largest) map, concatenated,
    fused by a 1x1 convolution + batch norm + ReLU and finally mapped to
    ``num_classes`` channels by another 1x1 convolution.

    Args:
        c1_in_channels: channels of the first (highest resolution) feature map.
        c2_in_channels: channels of the second feature map.
        c3_in_channels: channels of the third feature map.
        c4_in_channels: channels of the fourth (lowest resolution) feature map.
        num_classes: number of output channels.
        embedding_dim: common channel width used inside the head.
        dropout: dropout probability applied to the fused features.
    """

    def __init__(
        self,
        c1_in_channels: int,
        c2_in_channels: int,
        c3_in_channels: int,
        c4_in_channels: int,
        num_classes: int,
        embedding_dim: int = 768,
        dropout: float = 0.
    ):
        super().__init__()
        self.linear_c4 = MLP(input_dim=c4_in_channels, embedding_dim=embedding_dim)
        self.linear_c3 = MLP(input_dim=c3_in_channels, embedding_dim=embedding_dim)
        self.linear_c2 = MLP(input_dim=c2_in_channels, embedding_dim=embedding_dim)
        self.linear_c1 = MLP(input_dim=c1_in_channels, embedding_dim=embedding_dim)

        # The fuse convolution sees the four projected maps concatenated along channels
        number_of_blocks = 4
        self.linear_fuse_conv = nn.Conv2d(in_channels=embedding_dim * number_of_blocks,
                                          out_channels=embedding_dim,
                                          kernel_size=1)
        # NOTE(review): SyncBatchNorm is meant for multi-process training; confirm it is
        # intended when this head is used on a single device.
        self.linear_fuse_norm = nn.SyncBatchNorm(num_features=embedding_dim)

        self.dropout = nn.Dropout(p=dropout)

        self.out = nn.Conv2d(in_channels=embedding_dim, out_channels=num_classes, kernel_size=1)

    def apply_linear_fuse(self, c1: torch.Tensor, c2: torch.Tensor, c3: torch.Tensor, c4: torch.Tensor) -> torch.Tensor:
        """Concatenate the projected maps (c4 first) and apply conv -> norm -> ReLU."""
        output = self.linear_fuse_conv(torch.cat([c4, c3, c2, c1], dim=1))
        output = self.linear_fuse_norm(output)
        output = F.relu(output)
        return output

    def _project(self, projection: nn.Module, feature_map: torch.Tensor, target_size=None) -> torch.Tensor:
        """Project a (B, C, H, W) map to (B, embedding_dim, H, W), optionally resizing it to `target_size`."""
        n, _, h, w = feature_map.shape
        # The projection returns (B, H * W, embedding_dim); fold it back into an image grid
        projected = projection(feature_map).permute(0, 2, 1).reshape(n, -1, h, w)
        if target_size is not None:
            projected = resize(projected, size=target_size, mode='bilinear', align_corners=False)
        return projected

    def forward(self, inputs: list) -> torch.Tensor:
        """
        Args:
            inputs: list of the four (B, C_i, H_i, W_i) tensors from the output of the
                MixVisionTransformer, ordered from highest to lowest resolution
        Returns:
            x: (B, num_classes, H_1, W_1), i.e. at the spatial resolution of the first input
        """
        c1, c2, c3, c4 = inputs

        # Every coarser map is upsampled to the resolution of c1 before fusion
        target_size = c1.size()[2:]

        _c4 = self._project(self.linear_c4, c4, target_size)
        _c3 = self._project(self.linear_c3, c3, target_size)
        _c2 = self._project(self.linear_c2, c2, target_size)
        _c1 = self._project(self.linear_c1, c1)

        _c = self.apply_linear_fuse(c1=_c1, c2=_c2, c3=_c3, c4=_c4)

        x = self.dropout(_c)
        x = self.out(x)

        return x

In [264]:
# Decoder head whose input widths match the encoder's default stage widths (64, 128, 256, 512).
head = SegFormerSemanticSegmentationHead(c1_in_channels=64,
                                         c2_in_channels=128,
                                         c3_in_channels=256,
                                         c4_in_channels=512,
                                         num_classes=1,
                                         embedding_dim=768)

In [265]:
# Shape check: the output has num_classes=1 channel at the stage-1 resolution (128 x 128).
head(mit_outputs).shape

torch.Size([10, 1, 128, 128])

# SegFormer Encoder-Decoder

In [268]:
class SegFormer(nn.Module):
    """SegFormer encoder-decoder: a MixVisionTransformer encoder plus an all-MLP head.

    Args:
        image_size: side length the output is resized to; also passed to the encoder.
        num_classes: number of output channels.
        in_channels: number of channels of the input image.
        encoder_embedding_dims: token width of each of the four encoder stages.
        decoder_embedding_dim: channel width used inside the decoder head.
        num_heads: attention heads per encoder stage.
        mlp_ratios: feed-forward expansion factor per encoder stage.
        qkv_bias: whether the encoder's query/key/value projections have a bias.
        attention_dropout: dropout on the encoder's attention weights.
        dropout: dropout inside the encoder blocks.
        reduction_ratios: key/value spatial reduction ratio per encoder stage.
        depths: number of transformer blocks per encoder stage.
        decoder_dropout: dropout applied to the fused decoder features.
    """

    def __init__(
        self,
        image_size: int = 512,
        num_classes: int = 1,
        in_channels: int = 4,
        encoder_embedding_dims: list = [64, 128, 256, 512],
        decoder_embedding_dim: int = 768,
        num_heads: tuple = (1, 2, 4, 8),
        mlp_ratios: tuple = (4, 4, 4, 4),
        qkv_bias: bool = False,
        attention_dropout: float = 0.5,
        dropout: float = 0.5,
        reduction_ratios: tuple = (8, 4, 2, 1),
        depths: tuple = (3, 4, 6, 3),
        decoder_dropout: float = 0.5,
    ):
        super().__init__()
        self.image_size = image_size
        self.num_classes = num_classes
        self.in_channels = in_channels
        # Copy so the instance never shares state with the mutable default list
        self.encoder_embedding_dims = list(encoder_embedding_dims)
        self.decoder_embedding_dim = decoder_embedding_dim

        # The new keyword arguments default to the values that used to be hard-coded
        # here, so existing calls build exactly the same model.
        self.encoder = MixVisionTransformer(image_size=image_size,
                                             in_channels=in_channels,
                                             num_classes=num_classes,
                                             embedding_dims=self.encoder_embedding_dims,
                                             num_heads=list(num_heads),
                                             mlp_ratios=list(mlp_ratios),
                                             qkv_bias=qkv_bias,
                                             attention_dropout=attention_dropout,
                                             dropout=dropout,
                                             reduction_ratios=list(reduction_ratios),
                                             depths=list(depths))

        self.decoder = SegFormerSemanticSegmentationHead(c1_in_channels=self.encoder_embedding_dims[0],
                                                         c2_in_channels=self.encoder_embedding_dims[1],
                                                         c3_in_channels=self.encoder_embedding_dims[2],
                                                         c4_in_channels=self.encoder_embedding_dims[3],
                                                         num_classes=num_classes,
                                                         embedding_dim=decoder_embedding_dim,
                                                         dropout=decoder_dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: (B, in_channels, H, W)
        Returns:
            (B, num_classes, image_size, image_size)

        NOTE: the output is always resized to (image_size, image_size), regardless of
        the spatial size of the tensor that was passed in.
        """
        x = self.encoder(x)
        x = self.decoder(x)
        return resize(input=x, size=(self.image_size, self.image_size), mode='bilinear', align_corners=False)

In [269]:
# Full encoder-decoder model with the default settings.
segformer = SegFormer(image_size=512)

In [270]:
# End-to-end shape check: the output is resized to image_size x image_size (512 x 512).
segformer(torch.randn(10, 4, 512, 512)).shape

torch.Size([10, 1, 512, 512])