<a href="https://colab.research.google.com/github/as9786/ComputerVision/blob/main/ImageSegmentation/code/Segformer.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Install

In [None]:
!pip install einops

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting einops
  Downloading einops-0.6.1-py3-none-any.whl (42 kB)
[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/42.2 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m42.2/42.2 kB[0m [31m1.9 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: einops
Successfully installed einops-0.6.1


# 함수

In [None]:
import torch
from einops import rearrange
import torch.nn as nn

In [None]:
# 층 정규화

class LayerNormalization2D(nn.LayerNorm):
    """``nn.LayerNorm`` applied to channels-first 4-D feature maps.

    ``nn.LayerNorm`` normalizes over the trailing dimension(s) given by its
    ``normalized_shape``. Feature maps here are ``(batch, channels, h, w)``,
    so they are moved to channels-last layout, normalized, and moved back.
    ``normalized_shape`` is expected to be the channel count (that is how the
    rest of this module constructs it).
    """

    def forward(self, x):
        # (b, c, h, w) -> (b, h, w, c): channels become the trailing axis.
        channels_last = x.permute(0, 2, 3, 1)
        normalized = super().forward(channels_last)
        # (b, h, w, c) -> (b, c, h, w): restore the caller's layout.
        return normalized.permute(0, 3, 1, 2)

PyTorch의 nn.LayerNorm은 채널 차원이 마지막에 오는 (batch, ..., channels) 형태의 입력을 정규화한다. 반면 PyTorch에서 영상은 (batch, channels, h, w) 형태로 표현되므로, 정규화 전에 채널을 마지막 차원으로 옮기고 정규화 후 원래 순서로 되돌려 주었다.

In [None]:
# Overlap patch merging

class OverlapPatchMerging(nn.Sequential):
    """Overlapping patch embedding: strided conv followed by channels-first LayerNorm.

    When ``patch_size`` exceeds ``overlap_size`` (the stride), neighbouring
    patches share pixels. Padding is ``patch_size // 2``.

    Args:
        in_channels: channels of the incoming feature map.
        out_channels: channels of the produced patch embeddings.
        patch_size: convolution kernel size.
        overlap_size: convolution stride (spatial downsampling factor).
    """

    def __init__(self, in_channels, out_channels, patch_size, overlap_size):
        patch_conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=patch_size,
            stride=overlap_size,
            padding=patch_size // 2,
            bias=False,
        )
        patch_norm = LayerNormalization2D(out_channels)
        super().__init__(patch_conv, patch_norm)

In [None]:
# Efficient multi-head attention

class EfficientMultiHeadAttention(nn.Module):
    """Multi-head self-attention with spatially reduced keys and values.

    Queries use every one of the ``h * w`` positions. Keys/values come from
    the input downsampled by a convolution whose kernel size and stride are
    both ``reduction_ratio`` (followed by LayerNorm), which shrinks the key/value
    sequence roughly by a factor of ``reduction_ratio ** 2``.

    Args:
        channels: embedding dimension; must be divisible by ``num_heads``.
        reduction_ratio: kernel size and stride of the reducing convolution.
        num_heads: number of attention heads.
    """

    def __init__(self, channels, reduction_ratio=1, num_heads=8):
        super().__init__()
        self.reducer = nn.Sequential(
            nn.Conv2d(channels, channels, kernel_size=reduction_ratio, stride=reduction_ratio),
            LayerNormalization2D(channels),
        )
        self.att = nn.MultiheadAttention(channels, num_heads=num_heads, batch_first=True)

    def forward(self, x):
        _, _, h, w = x.shape
        reduced_x = self.reducer(x)
        reduced_x = rearrange(reduced_x, 'b c h w -> b (h w) c')
        x = rearrange(x, 'b c h w -> b (h w) c')
        # The attention weights are discarded, so don't ask for them:
        # need_weights=False avoids materializing a (b, h*w, reduced) tensor
        # and lets PyTorch use its fused attention path.
        out, _ = self.att(x, reduced_x, reduced_x, need_weights=False)
        return rearrange(out, 'b (h w) c -> b c h w', h=h, w=w)

In [None]:
# Sanity check: the attention block keeps the (batch, channels, h, w) shape.
x = torch.randn(1, 8, 64, 64)
block = EfficientMultiHeadAttention(channels=8, reduction_ratio=4)
block(x).shape

torch.Size([1, 8, 64, 64])

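논문에서는 reshape 후 선형 변환(Linear)으로 key/value의 시퀀스 길이를 줄여 attention 연산량을 줄였다. 하지만 kernel 크기와 stride가 reduction ratio인 convolution filter를 사용해도 같은 방식으로 key/value 시퀀스 길이를 줄일 수 있다.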

파라미터 수는 늘어나지만(아래 summary의 reducer 부분 기준 1,048개 vs. 264개), convolution은 주변 픽셀을 함께 보므로 transformer에 부족한 지역(local) 정보를 보완할 수 있다.

In [None]:
# Efficient multi-head attention in paper

class EfficientMultiHeadAttentionInPaper(nn.Module):
    """Efficient self-attention using the paper's reshape + linear sequence reduction.

    Every ``reduction_ratio`` consecutive tokens are folded into one token of
    width ``channels * reduction_ratio`` and projected back to ``channels``
    with ``self.reducer``. The key/value sequence is therefore
    ``reduction_ratio`` times shorter than the query sequence.

    Args:
        channels: embedding dimension; must be divisible by ``num_heads``.
        reduction_ratio: number of tokens merged into one key/value token;
            ``h * w`` of the input must be divisible by it.
        num_heads: number of attention heads (defaults to 8, the previous
            hard-coded value).
    """

    def __init__(self, channels, reduction_ratio, num_heads=8):
        super().__init__()
        self.channels = channels
        self.reduction_ratio = reduction_ratio
        self.att = nn.MultiheadAttention(channels, num_heads=num_heads, batch_first=True)
        self.reducer = nn.Linear(channels * reduction_ratio, channels)

    def forward(self, x):
        _, _, h, w = x.shape
        x = rearrange(x, 'b c h w -> b (h w) c')
        # r must equal self.reduction_ratio: self.reducer expects exactly
        # channels * reduction_ratio input features (a hard-coded 4 broke
        # every other ratio with a shape mismatch).
        reduced_x = rearrange(x, 'b (hw r) c -> b hw (c r)', r=self.reduction_ratio)
        reduced_x = self.reducer(reduced_x)

        # Attention weights are unused; skip computing them.
        out, _ = self.att(x, reduced_x, reduced_x, need_weights=False)
        return rearrange(out, 'b (h w) c -> b c h w', h=h, w=w)

In [None]:
# Sanity check: the paper-style attention block keeps the input shape.
x = torch.randn(1, 8, 64, 64)
block = EfficientMultiHeadAttentionInPaper(channels=8, reduction_ratio=4)
block(x).shape

torch.Size([1, 8, 64, 64])

In [None]:
!pip install torchsummary

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


In [None]:
from torchsummary import summary as summary

In [None]:
# Layer/parameter summary of the convolution-based key/value reducer.
conv_method = EfficientMultiHeadAttention(channels=8, reduction_ratio=4)
summary(conv_method, input_size=(8, 64, 64))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1            [-1, 8, 16, 16]           1,032
LayerNormalization2D-2            [-1, 8, 16, 16]              16
MultiheadAttention-3  [[-1, 4096, 8], [-1, 4096, 256]]               0
Total params: 1,048
Trainable params: 1,048
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.12
Forward/backward pass size (MB): 262143.97
Params size (MB): 0.00
Estimated Total Size (MB): 262144.10
----------------------------------------------------------------


In [None]:
# Layer/parameter summary of the paper's linear key/value reducer.
paper_method = EfficientMultiHeadAttentionInPaper(channels=8, reduction_ratio=4)
summary(paper_method, input_size=(8, 64, 64))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Linear-1              [-1, 1024, 8]             264
MultiheadAttention-2  [[-1, 4096, 8], [-1, 4096, 1024]]               0
Total params: 264
Trainable params: 264
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.12
Forward/backward pass size (MB): 1048575.94
Params size (MB): 0.00
Estimated Total Size (MB): 1048576.06
----------------------------------------------------------------


In [None]:
# Mix-MLP

class MixMLP(nn.Sequential):
    """Mix-FFN: 1x1 conv -> grouped 3x3 conv (expands channels) -> GELU -> 1x1 conv.

    The 3x3 convolution uses ``groups=channels`` and ``padding=1``, so it mixes
    each channel's spatial neighbourhood without changing the spatial size.
    The output therefore has the same shape as the input.

    Args:
        channels: input and output channel count.
        expansion: multiplier for the hidden channel count.
    """

    def __init__(self, channels, expansion=4):
        hidden = channels * expansion
        layers = [
            nn.Conv2d(channels, channels, kernel_size=1),
            nn.Conv2d(channels, hidden, kernel_size=3, groups=channels, padding=1),
            nn.GELU(),
            nn.Conv2d(hidden, channels, kernel_size=1),
        ]
        super().__init__(*layers)

In [None]:
from torchvision.ops import StochasticDepth

In [None]:
class ResidualAdd(nn.Module):
    """Skip connection around ``fn``: ``forward(x, **kw) == x + fn(x, **kw)``.

    Args:
        fn: the wrapped callable (typically an ``nn.Module``; it is then
            registered as a submodule).
    """

    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x, **kwargs):
        return x + self.fn(x, **kwargs)

In [None]:
class SegFormerEncoderBlock(nn.Sequential):
    """One pre-norm SegFormer transformer block with two residual branches.

    ``x = x + drop_path(attention(norm(x)))`` followed by
    ``x = x + drop_path(mix_mlp(norm(x)))``.

    Args:
        channels: channel count of the input and output feature map.
        reduction_ratio: key/value spatial reduction of the attention.
        num_heads: number of attention heads.
        mlp_expansion: hidden-channel multiplier of the Mix-MLP.
        drop_path_prob: stochastic-depth probability (``mode='batch'``)
            applied to BOTH residual branches.
    """

    def __init__(self, channels, reduction_ratio=1, num_heads=8, mlp_expansion=4, drop_path_prob=0):
        super().__init__(
            ResidualAdd(
                nn.Sequential(
                    LayerNormalization2D(channels),
                    EfficientMultiHeadAttention(channels, reduction_ratio, num_heads),
                    # Drop-path on the attention branch as well; previously only
                    # the MLP branch below was regularized by drop_path_prob.
                    # Appended last, so existing state_dict keys are unchanged.
                    StochasticDepth(p=drop_path_prob, mode='batch'),
                )
            ),
            ResidualAdd(
                nn.Sequential(
                    LayerNormalization2D(channels),
                    MixMLP(channels, expansion=mlp_expansion),
                    StochasticDepth(p=drop_path_prob, mode='batch'),
                )
            ),
        )

In [None]:
# Sanity check: an encoder block keeps the (batch, channels, h, w) shape.
x = torch.randn(1, 8, 64, 64)
block = SegFormerEncoderBlock(channels=8, reduction_ratio=4)
block(x).shape

torch.Size([1, 8, 64, 64])

In [None]:
from typing import Iterable
from typing import List

In [None]:
class SegFormerEncoderStage(nn.Sequential):
    """One hierarchical SegFormer encoder stage.

    Children are registered in the order ``overlap_patch_merge`` ->
    ``blocks`` -> ``norm``. ``nn.Sequential.forward`` calls its children in
    registration order, so that order is the stage's forward pass.

    Args:
        in_channels: channels of the stage input.
        out_channels: channels produced by the stage.
        patch_size: kernel size of the overlapping patch-merging conv.
        overlap_size: stride of the overlapping patch-merging conv.
        drop_probs: indexable stochastic-depth probabilities; block ``i``
            uses ``drop_probs[i]`` for ``i`` in ``range(depth)``.
        depth: number of ``SegFormerEncoderBlock``s in the stage.
        reduction_ratio: key/value spatial reduction used by every block.
        num_heads: attention heads per block.
        mlp_expansion: Mix-MLP hidden-channel multiplier per block.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        patch_size,
        overlap_size,
        drop_probs,
        depth,
        reduction_ratio,
        num_heads,
        mlp_expansion,
    ):
        super().__init__()
        # Registration order below defines the Sequential execution order.
        self.add_module(
            "overlap_patch_merge",
            OverlapPatchMerging(in_channels, out_channels, patch_size, overlap_size),
        )
        stage_blocks = []
        for block_index in range(depth):
            stage_blocks.append(
                SegFormerEncoderBlock(
                    out_channels,
                    reduction_ratio,
                    num_heads,
                    mlp_expansion,
                    drop_probs[block_index],
                )
            )
        self.add_module("blocks", nn.Sequential(*stage_blocks))
        self.add_module("norm", LayerNormalization2D(out_channels))

In [None]:
def chunks(data, sizes):
    """Yield consecutive slices of ``data`` whose lengths follow ``sizes``.

    Each chunk is ``data[start:start + size]``, so chunks keep ``data``'s
    sequence type, and a chunk may be shorter than ``size`` (or empty) if
    ``data`` runs out.
    """
    start = 0
    for size in sizes:
        end = start + size
        yield data[start:end]
        start = end

In [None]:
class SegFormerEncoder(nn.Module):
    """Hierarchical SegFormer encoder built from ``SegFormerEncoderStage``s.

    Stage ``i`` maps ``[in_channels, *widths][i]`` channels to ``widths[i]``.
    Stochastic-depth probabilities increase linearly from 0 to ``drop_prob``
    across all ``sum(depths)`` blocks and are split per stage by ``depths``.
    The per-stage lists are combined with ``zip``, so the shortest list
    determines how many stages are built.

    ``forward`` returns a list holding every stage's output, shallowest first.
    """

    def __init__(
        self,
        in_channels,
        widths,
        depths,
        all_num_heads,
        patch_sizes,
        overlap_sizes,
        reduction_ratios,
        mlp_expansions,
        drop_prob
    ):
        super().__init__()

        total_blocks = sum(depths)
        drop_probs = torch.linspace(0, drop_prob, total_blocks).tolist()
        stage_inputs = [in_channels, *widths]
        per_stage_args = zip(
            stage_inputs,
            widths,
            patch_sizes,
            overlap_sizes,
            chunks(drop_probs, sizes=depths),
            depths,
            reduction_ratios,
            all_num_heads,
            mlp_expansions,
        )
        self.stages = nn.ModuleList(
            [SegFormerEncoderStage(*stage_args) for stage_args in per_stage_args]
        )

    def forward(self, x):
        stage_outputs = []
        current = x
        for stage in self.stages:
            current = stage(current)
            stage_outputs.append(current)
        return stage_outputs

In [None]:
class SegFormerDecoderBlock(nn.Sequential):
    """Bilinearly upsample a feature map, then project its channels with a 1x1 conv.

    Args:
        in_channels: channels of the incoming encoder feature.
        out_channels: channels after the 1x1 projection.
        scale_factor: spatial upsampling factor.
    """

    def __init__(self, in_channels: int, out_channels: int, scale_factor: int = 2):
        super().__init__(
            # nn.UpsamplingBilinear2d is deprecated; this is its documented
            # equivalent (bilinear, align_corners=True), so outputs and the
            # Sequential indices (0: upsample, 1: conv) are unchanged.
            nn.Upsample(scale_factor=scale_factor, mode="bilinear", align_corners=True),
            nn.Conv2d(in_channels, out_channels, kernel_size=1),
        )

In [None]:

class SegFormerDecoder(nn.Module):
    """Apply one ``SegFormerDecoderBlock`` to each incoming feature map.

    Block ``i`` takes ``widths[i]`` channels, upsamples by
    ``scale_factors[i]`` and projects to ``out_channels``. ``forward`` pairs
    features with blocks via ``zip`` (surplus items on either side are
    ignored) and returns the list of decoded maps.
    """

    def __init__(self, out_channels: int, widths: List[int], scale_factors: List[int]):
        super().__init__()
        decoder_blocks = [
            SegFormerDecoderBlock(width, out_channels, factor)
            for width, factor in zip(widths, scale_factors)
        ]
        self.stages = nn.ModuleList(decoder_blocks)

    def forward(self, features):
        return [stage(feature) for feature, stage in zip(features, self.stages)]

In [None]:
class SegFormerSegmentationHead(nn.Module):
    """Fuse decoded feature maps and predict per-pixel class logits.

    ``forward`` concatenates the ``num_features`` maps (``channels`` each)
    along dim 1; they must share batch and spatial sizes for ``torch.cat``.
    It then fuses them back to ``channels`` with a 1x1 conv -> ReLU ->
    BatchNorm stack and maps to ``num_classes`` with a 1x1 conv.
    """

    def __init__(self, channels: int, num_classes: int, num_features: int = 4):
        super().__init__()
        fused_in = channels * num_features
        self.fuse = nn.Sequential(
            nn.Conv2d(fused_in, channels, kernel_size=1, bias=False),
            nn.ReLU(),
            # NOTE(review): BatchNorm after ReLU (with a bias-free conv) differs
            # from the common conv -> BN -> ReLU order; reordering would change
            # state_dict keys, so it is kept as is.
            nn.BatchNorm2d(channels),
        )
        self.predict = nn.Conv2d(channels, num_classes, kernel_size=1)

    def forward(self, features):
        stacked = torch.cat(features, dim=1)
        return self.predict(self.fuse(stacked))

In [None]:
class SegFormer(nn.Module):
    """SegFormer semantic-segmentation network.

    Pipeline:

    1. ``SegFormerEncoder`` returns one feature map per stage, shallowest first.
    2. The list is reversed so the deepest map comes first. This is why the
       decoder is built with ``widths[::-1]``, and why ``scale_factors`` is
       listed deepest stage first.
    3. Each map goes through its ``SegFormerDecoderBlock``. ``scale_factors``
       must bring every map to one spatial size, because the head
       concatenates them.
    4. ``SegFormerSegmentationHead`` predicts ``num_classes`` logits per pixel.
    """

    def __init__(
        self,
        in_channels: int,
        widths: List[int],
        depths: List[int],
        all_num_heads: List[int],
        patch_sizes: List[int],
        overlap_sizes: List[int],
        reduction_ratios: List[int],
        mlp_expansions: List[int],
        decoder_channels: int,
        scale_factors: List[int],
        num_classes: int,
        drop_prob: float = 0.0,
    ):
        super().__init__()
        num_stages = len(widths)
        self.encoder = SegFormerEncoder(
            in_channels,
            widths,
            depths,
            all_num_heads,
            patch_sizes,
            overlap_sizes,
            reduction_ratios,
            mlp_expansions,
            drop_prob,
        )
        deepest_first_widths = widths[::-1]
        self.decoder = SegFormerDecoder(decoder_channels, deepest_first_widths, scale_factors)
        self.head = SegFormerSegmentationHead(
            decoder_channels, num_classes, num_features=num_stages
        )

    def forward(self, x):
        stage_features = self.encoder(x)
        decoded = self.decoder(stage_features[::-1])
        return self.head(decoded)

In [None]:

# Example configuration with four encoder stages (one entry per stage in each list).
segformer = SegFormer(
    num_classes=100,
    in_channels=3,
    # --- encoder, per stage (shallowest -> deepest) ---
    widths=[64, 128, 256, 512],
    depths=[3, 4, 6, 3],
    all_num_heads=[1, 2, 4, 8],
    patch_sizes=[7, 3, 3, 3],
    overlap_sizes=[4, 2, 2, 2],
    reduction_ratios=[8, 4, 2, 1],
    mlp_expansions=[4, 4, 4, 4],
    # --- decoder ---
    decoder_channels=256,
    # Deepest stage first; with cumulative strides 32/16/8/4 these upsample
    # every map to the first stage's 1/4 resolution.
    scale_factors=[8, 4, 2, 1],
)

In [None]:
# Forward a random batch of four 3-channel 224x224 inputs.
segmentation = segformer(torch.randn(4, 3, 224, 224))

In [None]:
# (batch, num_classes, H/4, W/4) for the configuration above.
segmentation.size()

torch.Size([4, 100, 56, 56])