In [1]:
import sys, os
sys.path.append(os.path.dirname(os.path.abspath(os.path.dirname('backbone')))) # to import file from under same-level directory
sys.path.append(os.path.dirname(os.path.abspath(os.path.dirname('utils'))))

from backbone.convnext_se.convnext_se import ConvNextV1, Stage, Stem
from utils.conv_2d import adjust_padding_for_strided_output, DepthWiseSepConv
from utils.stochastic_depth_drop import create_linear_p, create_uniform_p

import os
import glob
import PIL
from PIL import Image
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import regex as re
from tqdm import tqdm

import torch
import torch.nn as nn
import torchvision
from torchvision.transforms import v2
import torchinfo
import albumentations

import sklearn
import sklearn.metrics as metrics

  from .autonotebook import tqdm as notebook_tqdm
  check_for_updates()


In [2]:
# Reference build of the stock project ConvNeXt-T(+SE) backbone.
# To connect encoder and decoder stage-by-stage (U-Net style) the encoder must expose every stage's
# output, which this backbone's forward does not, so the modules are re-implemented in the cells below.

# Per-block stochastic-depth rates for a [3, 3, 9, 3] stage layout (one inner list per stage).
# The printed values INCREASE linearly with depth (0.0 -> ~0.47, scaled by the 0.5 argument); they are
# passed on as `droprate`, i.e. drop probabilities, so it is the survival probability (1 - p) that decays.
dp_list, dp_mode = create_linear_p([3,3,9,3], 'batch', 0.5) # linearly increasing drop probability (= linearly decaying survival)
# Alternative: the same small drop probability for every block.
# dp_list, dp_mode = create_uniform_p([3,3,9,3], 'batch', 0.001) # create constant stochastic depth drop probability
# NOTE(review): the label below says "survival probability", but dp_list holds drop probabilities.
print("linearly decaying survival probability: ", dp_list)
    
# ConvNeXt-T layout: [3, 3, 9, 3] blocks, widths 96..768, 4x expansion, SE in every stage.
# img_hw lists each stage's spatial size for a 224x224 input (stem stride 4, then 2x per transition);
# transition_* entry 0 is a placeholder because stage 0 follows the stem directly.
convnext_t = ConvNextV1(num_blocks=[3, 3, 9, 3], input_channels=3, stem_kersz=(4,4), stem_stride=(4,4), img_hw=[(56,56),(28,28),(14,14),(7,7)], main_channels=[96, 192, 384, 768], expansion_dim=[96*4, 192*4, 384*4, 768*4],
                               kernel_sz=[(7,7), (7,7), (7,7), (7,7)], stride=[(1,1),(1,1),(1,1),(1,1)], padding=['same', 'same', 'same', 'same'], dilation=[1,1,1,1], groups=[1,1,1,1], droprate=dp_list, drop_mode=dp_mode,
                               use_se=[True, True, True, True], squeeze_ratio=16, transition_kersz=[-1, (1,1),(1,1),(1,1)], transition_stride=[-1, (2,2), (2,2), (2,2)], 
                               norm_mode='layer_norm', device='cuda')

torchinfo.summary(convnext_t, (1, 3, 224, 224)) # author's note: ~33M params here vs ~29M (presumably reference ConvNeXt-T); the gap presumably comes from LayerNorm over [C, H, W] (per-position affine weights) — batch_norm would save those weights

linearly decaying survival probability:  [[0.0, 0.027777777777777776, 0.05555555555555555], [0.08333333333333333, 0.1111111111111111, 0.1388888888888889], [0.16666666666666666, 0.19444444444444445, 0.2222222222222222, 0.25, 0.2777777777777778, 0.3055555555555556, 0.3333333333333333, 0.3611111111111111, 0.3888888888888889], [0.4166666666666667, 0.4444444444444444, 0.4722222222222222]]


Layer (type:depth-idx)                             Output Shape              Param #
ConvNextV1                                         [1, 768, 7, 7]            --
├─Stem: 1-1                                        [1, 96, 56, 56]           --
│    └─Conv2d: 2-1                                 [1, 96, 56, 56]           4,704
├─ModuleList: 1-2                                  --                        --
│    └─Stage: 2-2                                  [1, 96, 56, 56]           --
│    │    └─ModuleList: 3-1                        --                        2,075,058
│    └─Stage: 2-3                                  [1, 192, 28, 28]          --
│    │    └─ModuleList: 3-2                        --                        1,940,580
│    └─Stage: 2-4                                  [1, 384, 14, 14]          --
│    │    └─ModuleList: 3-3                        --                        13,652,568
│    └─Stage: 2-5                                  [1, 768, 7, 7]            --
│    │    

In [3]:
# Load one CULane frame as the sample input for the shape checks in the following cells.
# NOTE(review): absolute, machine-specific dataset path — adjust for your environment.
EX_IMG_PATH = "/work/dataset/CULane/driver_193_90frame/06051132_0638.MP4/00000.jpg"
# convert("RGB") guarantees exactly 3 channels: a greyscale/palette file would otherwise yield a 2-D
# array (breaking the permute below) and an RGBA file 4 channels (mismatching input_channels=3).
ex_img = Image.open(EX_IMG_PATH).convert("RGB")
print("original image size: ", ex_img.size)
# (W, H) = (672, 224): both are multiples of 32 (stem stride 4 x three 2x transitions), so every
# encoder stage ends up with an integer spatial size.
ex_img_resize = ex_img.resize((672, 224)) # mask and image will be resized to this size - format (W, H)
# HWC uint8 -> (1, 3, H, W) float32 on the GPU; pixel values stay in [0, 255] (no normalization here).
# ex_img is rebound from the PIL image to this tensor, which is what the later cells consume.
ex_img = torch.tensor(np.array(ex_img_resize)).unsqueeze(0).permute(0, 3, 1, 2).float().to('cuda')
print(ex_img.shape)

original image size:  (1640, 590)
torch.Size([1, 3, 224, 672])


In [4]:
# ConvNeXt v1 + squeeze-and-excitation encoder for a U-Net: same stem/stages as the backbone, but
# forward() also returns every stage's output so the decoder can fuse them stage-wise.
class Encoder(torch.nn.Module):
    """ConvNeXt-v1(+SE) encoder that exposes per-stage feature maps.

    All list-valued arguments are per-stage and indexed by stage number (0 .. len(num_blocks) - 1).

    Args:
        num_blocks: number of blocks in each stage.
        input_channels: channels of the input image (3 for RGB).
        stem_kersz, stem_stride: kernel size / stride of the stem convolution.
        img_hw: expected (H, W) of each stage's feature map, forwarded to ``Stage``.
        main_channels: channel width of each stage.
        expansion_dim: per-stage value forwarded to ``Stage`` as ``out_channels``
            (presumably the blocks' expanded hidden width).
        kernel_sz, stride, padding, dilation, groups: per-stage convolution settings.
        droprate: per-stage stochastic-depth rates (one entry per stage; presumably the per-block
            lists produced by ``create_linear_p`` — TODO confirm against ``Stage``).
        drop_mode: per-stage stochastic-depth mode.
        use_se: per-stage flag enabling squeeze-and-excitation.
        squeeze_ratio: SE reduction ratio shared by all stages.
        transition_kersz, transition_stride: per-stage transition (downsampling) settings; entry 0 is
            unused because stage 0 follows the stem directly.
        norm_mode: 'layer_norm' or 'batch_norm'.
        device: device on which the stem and every stage are created.

    Raises:
        Exception: if ``norm_mode`` is not supported.
    """

    def __init__(self, num_blocks:list, input_channels:int, stem_kersz:tuple, stem_stride:tuple, 
                 img_hw:list, main_channels:list, expansion_dim:list, kernel_sz:list, stride:list, padding:list, dilation:list, groups:list, droprate:list, drop_mode:list, use_se:list, squeeze_ratio:int,
                 transition_kersz:list, transition_stride:list, norm_mode:str, device='cuda'):
        
        super().__init__()
        
        self.num_blocks = num_blocks # number of blocks for each stage
        self.input_channels = input_channels # mostly RGB 3 channel images
        self.stem_kersz = stem_kersz # kernel size for stem layer
        self.stem_stride = stem_stride # stride for stem layer
        self.img_hw = img_hw # representative image height and width for each stage
        self.main_channels = main_channels # main channels for each stage
        self.expansion_dim = expansion_dim # expansion dimension for each stage
        self.kernel_sz = kernel_sz # kernel size for each stage
        self.stride=stride # stride for each stage
        self.padding=padding # padding for each stage
        self.dilation=dilation # dilation for each stage
        self.groups=groups # number of groups for each stage
        self.droprate=droprate # stochastic-depth rates, indexed per stage
        self.drop_mode=drop_mode # stochastic-depth mode, indexed per stage
        self.use_se=use_se # flag for using se operation in each stage
        self.squeeze_ratio=squeeze_ratio # squeeze_ratio is same for all stage
        self.transition_kersz=transition_kersz # transition kernel size for each stage
        self.transition_stride=transition_stride # transition stride for each stage
        
        if norm_mode not in ['layer_norm', 'batch_norm']:
            raise Exception(f"Unsupported normalization method: {norm_mode}. Must be either 'layer_norm' or 'batch_norm'")
        self.norm_mode = norm_mode
        
        self.device=device
        self.num_stages = len(num_blocks)
        
        self.stem = Stem(stem_in_channels=self.input_channels, stem_out_channels=main_channels[0], stem_kernel_sz=self.stem_kersz, stem_stride=self.stem_stride,
                         stem_padding=adjust_padding_for_strided_output(self.stem_kersz, self.stem_stride), stem_dilation=1, stem_groups=1, device=self.device)

        stages = []
        for i in range(self.num_stages):
            # Settings shared by every stage. device used to be hard-coded to 'cuda' here, silently
            # ignoring the `device` argument (the stem already honoured it).
            stage_kwargs = dict(num_blocks=self.num_blocks[i], img_hw=self.img_hw[i], in_channels=self.main_channels[i],
                                out_channels=self.expansion_dim[i], kernel_sz=self.kernel_sz[i], stride=self.stride[i],
                                padding=self.padding[i], dilation=self.dilation[i], groups=self.groups[i],
                                droprate=self.droprate[i], drop_mode=self.drop_mode[i], use_se=self.use_se[i],
                                squeeze_ratio=self.squeeze_ratio, norm_mode=self.norm_mode, device=self.device)
            if i == 0:
                # stage 0 consumes the stem output directly: no transition layer
                stages.append(Stage(transition_flag=False, **stage_kwargs))
            else:
                # later stages start with a transition from the previous stage's width
                stages.append(Stage(transition_flag=True, transition_channels=self.main_channels[i-1],
                                    transition_kersz=self.transition_kersz[i], transition_stride=self.transition_stride[i],
                                    **stage_kwargs))
        self.stages = torch.nn.ModuleList(stages)
        
    def forward(self, x):
        """Run the stem and all stages.

        Returns:
            (x, stage_output): the last stage's output and a list holding every stage's output, ordered
            from the first stage to the last (so ``stage_output[-1] is x``).
        """
        stage_output = []
        x = self.stem(x)
        for stage in self.stages:
            x = stage(x)
            stage_output.append(x)
        return x, stage_output
        

In [6]:
# Per-block stochastic-depth rates; the printed values increase linearly with depth (0.0 -> ~0.47).
dp_list, dp_mode = create_linear_p([3,3,9,3], 'batch', 0.5) # linearly increasing drop probability (= linearly decaying survival)
# dp_list, dp_mode = create_uniform_p([3,3,9,3], 'batch', 0.001) # create constant stochastic depth drop probability

# Encoder for the 224x672 lane image: stage sizes are (224/4, 672/4) = (56, 168), halved at each transition.
unet_encoder = Encoder(num_blocks=[3, 3, 9, 3], input_channels=3, stem_kersz=(4,4), stem_stride=(4,4), img_hw=[(56, 168), (28, 84), (14, 42), (7, 21)], main_channels=[96, 192, 384, 768], expansion_dim=[96*4, 192*4, 384*4, 768*4],
                               kernel_sz=[(7,7), (7,7), (7,7), (7,7)], stride=[(1,1),(1,1),(1,1),(1,1)], padding=['same', 'same', 'same', 'same'], dilation=[1,1,1,1], groups=[1,1,1,1], droprate=dp_list, drop_mode=dp_mode,
                               use_se=[True, True, True, True], squeeze_ratio=16, transition_kersz=[-1, (1,1),(1,1),(1,1)], transition_stride=[-1, (2,2), (2,2), (2,2)], norm_mode='layer_norm', device='cuda')
# encoder_output is the last stage's output (== stage_output[-1]); stage_output feeds the decoder skips.
encoder_output, stage_output = unet_encoder(ex_img)

print("encoder_output shape: ", encoder_output.shape)
print("encoder output: is gradient alive?: ", encoder_output.requires_grad)
print()

# NOTE(review): dp_list holds drop probabilities, not survival probabilities, despite the label.
print("linearly decaying survival probability: ", dp_list)
# NOTE(review): repeats the encoder_output shape already printed above.
print("encoder output shape: ", encoder_output.shape)
for s_index in range(len(stage_output)):
    print(f"shape of {s_index}th stage output: ", stage_output[s_index].shape)
print()

torchinfo.summary(unet_encoder, (1, 3, 224, 672))

encoder_output shape:  torch.Size([1, 768, 7, 21])
encoder output: is gradient alive?:  True

linearly decaying survival probability:  [[0.0, 0.027777777777777776, 0.05555555555555555], [0.08333333333333333, 0.1111111111111111, 0.1388888888888889], [0.16666666666666666, 0.19444444444444445, 0.2222222222222222, 0.25, 0.2777777777777778, 0.3055555555555556, 0.3333333333333333, 0.3611111111111111, 0.3888888888888889], [0.4166666666666667, 0.4444444444444444, 0.4722222222222222]]
encoder output shape:  torch.Size([1, 768, 7, 21])
shape of 0th stage output:  torch.Size([1, 96, 56, 168])
shape of 1th stage output:  torch.Size([1, 192, 28, 84])
shape of 2th stage output:  torch.Size([1, 384, 14, 42])
shape of 3th stage output:  torch.Size([1, 768, 7, 21])



Layer (type:depth-idx)                             Output Shape              Param #
Encoder                                            [1, 768, 7, 21]           --
├─Stem: 1-1                                        [1, 96, 56, 168]          --
│    └─Conv2d: 2-1                                 [1, 96, 56, 168]          4,704
├─ModuleList: 1-2                                  --                        --
│    └─Stage: 2-2                                  [1, 96, 56, 168]          --
│    │    └─ModuleList: 3-1                        --                        5,687,730
│    └─Stage: 2-3                                  [1, 192, 28, 84]          --
│    │    └─ModuleList: 3-2                        --                        3,746,916
│    └─Stage: 2-4                                  [1, 384, 14, 42]          --
│    │    └─ModuleList: 3-3                        --                        16,362,072
│    └─Stage: 2-5                                  [1, 768, 7, 21]           --
│    │    

In [46]:
# Shape experiment for the decoder. Each ConvTranspose2d below is freshly (randomly) initialized, so only
# the printed shapes are meaningful. With kernel 7, stride 2, padding 3, output_padding 1:
#   H_out = (H_in - 1)*2 - 2*3 + 1*(7 - 1) + 1 + 1 = 2 * H_in   (same for W)
# NOTE(review): execution count 46 means this cell was run out of order; it needs `encoder_output`
# and `ex_img` from the earlier cells.
print("before upsampling shape: ", encoder_output.shape)

up_sample = torch.nn.ConvTranspose2d(768, 384, kernel_size=(7,7), stride=2, padding=3, output_padding=(1,1), groups=1, bias=True, dilation=1, padding_mode='zeros', device='cuda')
up_sample_output = up_sample(encoder_output)
print("after 1st upsampling shape: ", up_sample_output.shape)

up_sample = torch.nn.ConvTranspose2d(384, 192, kernel_size=(7,7), stride=2, padding=3, output_padding=(1,1), groups=1, bias=True, dilation=1, padding_mode='zeros', device='cuda')
up_sample_output = up_sample(up_sample_output)
print("after 2nd upsampling shape: ", up_sample_output.shape)

# After three doublings the map is back at the first encoder stage's size (96 ch @ 56x168).
up_sample = torch.nn.ConvTranspose2d(192, 96, kernel_size=(7,7), stride=2, padding=3, output_padding=(1,1), groups=1, bias=True, dilation=1, padding_mode='zeros', device='cuda')
decoder_up_sample = up_sample(up_sample_output)
print("after 3rd upsampling shape: ", decoder_up_sample.shape)

# Option A for the head: two more 2x transposed convs to reach the input resolution.
up_sample = torch.nn.ConvTranspose2d(96, 48, kernel_size=(7,7), stride=2, padding=3, output_padding=(1,1), groups=1, bias=True, dilation=1, padding_mode='zeros', device='cuda')
head_up_sample = up_sample(decoder_up_sample)
print("after 4th upsampling shape: ", head_up_sample.shape)

up_sample = torch.nn.ConvTranspose2d(48, 24, kernel_size=(7,7), stride=2, padding=3, output_padding=(1,1), groups=1, bias=True, dilation=1, padding_mode='zeros', device='cuda')
head_output = up_sample(head_up_sample)
print("after 5th upsampling shape: ", head_output.shape)

print()
print("original image shape: ", ex_img.shape)
print()

# Option B: a single transposed conv mirroring the 4x4/stride-4 stem, H_out = (H_in - 1)*4 + 4 = 4 * H_in.
# 4x increase in image width and image height
up_sample = torch.nn.ConvTranspose2d(96, 48, kernel_size=(4,4), stride=(4,4), padding=(0,0), output_padding=(0,0), groups=1, bias=True, dilation=1, padding_mode='zeros', device='cuda')
up_sample_output = up_sample(decoder_up_sample)
print("upsampling as the way corresponding to stem: ", up_sample_output.shape)

before upsampling shape:  torch.Size([1, 768, 7, 21])
after 1st upsampling shape:  torch.Size([1, 384, 14, 42])
after 2nd upsampling shape:  torch.Size([1, 192, 28, 84])
after 3rd upsampling shape:  torch.Size([1, 96, 56, 168])
after 4th upsampling shape:  torch.Size([1, 48, 112, 336])
after 5th upsampling shape:  torch.Size([1, 24, 224, 672])

original image shape:  torch.Size([1, 3, 224, 672])

upsampling as the way corresponding to stem:  torch.Size([1, 48, 224, 672])


In [7]:
# Decoder block. Be careful with the transition layer: with the stride-2 transposed-convolution settings
# used in this notebook the output size is doubled, so img_hw must describe the block's OUTPUT size.
class DBlock(torch.nn.Module):
    """ConvNeXt-style decoder block with optional upsampling or encoder-feature fusion.

    Variants, chosen at construction time:
      * ``transition=True``: ``conv_1`` and the residual projection ``transition_conv`` are both
        ``ConvTranspose2d(transition_channels -> in_channels)``, so the block upsamples its input.
      * ``is_fuse=True``: the input must already be the channel-concatenation of the decoder features and
        an encoder stage output (``in_channels + fused_channels`` channels). ``conv_1`` (depthwise-separable)
        and a 1x1 ``transition_conv`` project it back to ``in_channels``. The concatenation itself is done
        at stage level, not here.
      * neither: plain residual block, ``in_channels`` in and out.

    Residual branch: conv_1 -> normalization -> 1x1 expand to ``out_channels`` -> GELU -> 1x1 project back
    to ``in_channels`` -> optional squeeze-and-excitation. It is added to the (possibly projected) input
    through ``torchvision.ops.stochastic_depth``.

    Args:
        img_hw: (H, W) of the block output; only used to shape the LayerNorm.
        in_channels: block width (channels of the output).
        out_channels: hidden width of the 1x1 expansion.
        kernel_sz, stride, padding, dilation: settings of the non-transition ``conv_1``.
        groups: groups of the two 1x1 pointwise convolutions.
        droprate: stochastic-depth drop probability ``p``.
        drop_mode: 'row' or 'batch'.
        use_se: enable squeeze-and-excitation on the residual branch.
        squeeze_ratio: SE reduction ratio; the bottleneck has ``max(1, int(in_channels / squeeze_ratio))``
            channels.
        is_fuse, fused_channels: see above; ``fused_channels`` must be positive when ``is_fuse``.
        transition, transition_channels, transition_kersz, transition_stride, transition_padding,
            transition_out_padding: ConvTranspose2d settings used when ``transition`` is True.
        norm_mode: 'layer_norm' (over [C, H, W]) or 'batch_norm'.
        device: device on which all layers are created.

    Raises:
        Exception: invalid ``drop_mode`` / ``norm_mode``, or both ``transition`` and ``is_fuse`` set.
        ValueError: ``is_fuse`` set without a positive ``fused_channels``.
    """

    def __init__(self, img_hw, in_channels, out_channels, kernel_sz, stride, padding, groups, dilation, droprate, drop_mode,
                 use_se, squeeze_ratio, is_fuse=False, fused_channels=-1, transition=False, transition_channels=-1, transition_kersz=-1,
                 transition_stride=-1, transition_padding=-1, transition_out_padding=-1, norm_mode='layer_norm', device='cuda'):
        super().__init__()
        self.img_h = img_hw[0]
        self.img_w = img_hw[1]
        
        # residual path parameters
        self.block_channels = in_channels
        self.block_out_channels = out_channels
        self.kernel_sz = kernel_sz
        self.stride = stride
        self.padding = padding
        self.dilation = dilation
        self.groups = groups
        
        self.droprate = droprate
        if drop_mode not in ['row', 'batch']:
            raise Exception("drop_mode must be either 'row' or 'batch'")
        self.drop_mode = drop_mode
        
        self.use_se = use_se
        self.squeeze_r = squeeze_ratio
        
        if transition and is_fuse:
            raise Exception("Upsampling and encoder stage output concat cannot happen same time!")
        # A fused block's first layers take in_channels + fused_channels inputs; the default -1 would
        # silently build layers of the wrong width, so require an explicit positive value.
        if is_fuse and fused_channels <= 0:
            raise ValueError(f"is_fuse=True requires a positive fused_channels, got {fused_channels}")
        
        self.is_fuse = is_fuse # input already contains the concatenated encoder stage output
        self.fused_channels = fused_channels # number of channels contributed by the encoder output
        
        if norm_mode not in ['layer_norm', 'batch_norm']:
            raise Exception(f"UnSupported normalization method: {norm_mode}. Use either 'layer_norm' or 'batch_norm'")
        self.norm_mode = norm_mode
        self.device = device
        
        # deconvolution (upsampling) path parameters
        self.transition = transition
        self.transition_channels = transition_channels
        self.transition_kersz = transition_kersz
        self.transition_stride = transition_stride
        self.transition_padding = transition_padding
        self.transition_out_padding = transition_out_padding
        
        if self.transition:
            # unlike the encoder, conv_1 uses the transition kernel (not kernel_sz) because several
            # padding/output_padding combinations are possible for the transposed convolution
            self.conv_1 = torch.nn.ConvTranspose2d(in_channels=self.transition_channels, out_channels=self.block_channels, kernel_size=self.transition_kersz, 
                                                stride=self.transition_stride, padding=self.transition_padding, output_padding=self.transition_out_padding, groups=1, bias=True, dilation=1, padding_mode='zeros', device=self.device, dtype=None)
            # residual projection: same upsampling so the shortcut matches the branch's shape
            self.transition_conv = torch.nn.ConvTranspose2d(in_channels=self.transition_channels, out_channels=self.block_channels, kernel_size=self.transition_kersz, 
                                                stride=self.transition_stride, padding=self.transition_padding, output_padding=self.transition_out_padding, groups=1, bias=True, dilation=1, padding_mode='zeros', device=self.device, dtype=None)
        elif self.is_fuse:
            self.conv_1 = DepthWiseSepConv(in_channels=(self.block_channels+self.fused_channels), out_channels=self.block_channels, kernel_sz=self.kernel_sz,
                                                    stride=self.stride, padding=self.padding, dilation=self.dilation, device=self.device)
            # pointwise projection so the shortcut has block_channels, like the residual branch
            self.transition_conv = torch.nn.Conv2d(in_channels=self.block_channels+self.fused_channels, out_channels=self.block_channels, kernel_size=(1,1),
                                                   stride=(1,1), padding=(0,0), dilation=1, groups=1, bias=True, padding_mode='zeros', device=self.device)
        else:
            self.conv_1 = DepthWiseSepConv(in_channels=self.block_channels, out_channels=self.block_channels, kernel_sz=self.kernel_sz,
                                                    stride=self.stride, padding=self.padding, dilation=self.dilation, device=self.device)
            
        self.pointwise_1 = torch.nn.Conv2d(in_channels=self.block_channels, out_channels=self.block_out_channels, kernel_size=(1,1), stride=(1,1), padding=(0,0), dilation=1, groups=self.groups, bias=True, padding_mode='zeros', device=self.device)
        self.pointwise_2 = torch.nn.Conv2d(in_channels=self.block_out_channels, out_channels=self.block_channels, kernel_size=(1,1), stride=(1,1), padding=(0,0), dilation=1, groups=self.groups, bias=True, padding_mode='zeros', device=self.device)
        
        if self.use_se:
            # SE bottleneck width, floored at 1: for in_channels < squeeze_ratio the plain int() gives 0,
            # i.e. a degenerate zero-channel bottleneck that cannot carry any signal.
            squeeze_channels = max(1, int(self.block_channels/self.squeeze_r))
            self.fc1 = torch.nn.Conv2d(in_channels=self.block_channels, out_channels=squeeze_channels, kernel_size=(1,1), stride=(1,1), padding=(0,0), dilation=1, groups=1, bias=True, device=self.device)
            self.relu = torch.nn.ReLU()
            self.fc2 = torch.nn.Conv2d(in_channels=squeeze_channels, out_channels=self.block_channels, kernel_size=(1,1), stride=(1,1), padding=(0,0), dilation=1, groups=1, bias=True, device=self.device)
            self.sigmoid = torch.nn.Sigmoid()
        
        # LayerNorm over [C, H, W] holds per-position affine weights, so its memory grows with spatial size;
        # near the full-resolution head batch norm is far cheaper (4M vs 256M per the author), use it there.
        if self.norm_mode == 'layer_norm':
            self.normalization = torch.nn.LayerNorm([self.block_channels, self.img_h, self.img_w], device=self.device)
        elif self.norm_mode == 'batch_norm':
            self.normalization = torch.nn.BatchNorm2d(num_features=self.block_channels, device=self.device)
        self.gelu = torch.nn.GELU()
    
    # encoder stage output concatenation does not happen at block level; the stage does it
    def forward(self, x):
        """Apply the residual branch and add it to the (projected) shortcut via stochastic depth."""
        highway = x
    
        x = self.conv_1(x)
        x = self.normalization(x)
        x = self.pointwise_1(x)
        x = self.gelu(x)
        x = self.pointwise_2(x)
        
        if self.use_se:
            # global average pool over the full spatial extent -> per-channel gates in (0, 1)
            squeeze = torch.nn.functional.avg_pool2d(x, x.size()[2:])
            excitation = self.fc1(squeeze)
            excitation = self.relu(excitation)
            excitation = self.fc2(excitation)
            attention = self.sigmoid(excitation)
            x = x * attention
        
        # the shortcut must be upsampled / channel-projected to match the branch output
        if self.transition or self.is_fuse:
            highway = self.transition_conv(highway)
        
        highway = highway + torchvision.ops.stochastic_depth(x, p=self.droprate, mode=self.drop_mode, training=self.training)
        return highway
    

In [8]:
# Smoke test of one upsampling DBlock: encoder bottleneck 768 ch @ 7x21 -> 384 ch @ 14x42.
# transition=True makes conv_1 and the shortcut ConvTranspose2d(768 -> 384, k=7, s=2, p=3, op=1), which
# doubles H and W; img_hw=(14, 42) must equal that output size because it shapes the LayerNorm.
# NOTE(review): the module is in training mode (nn.Module default), so stochastic depth with p=0.5 is
# active and the output differs between calls.
decoder_block = DBlock(img_hw=(14, 42), in_channels=384, out_channels=384*4, kernel_sz=(7,7), stride=(1,1), padding=(3,3), groups=1, dilation=1, 
                       droprate=0.5, drop_mode='batch', use_se=True, squeeze_ratio=16, is_fuse=False, fused_channels=-1, transition=True, transition_channels=768, transition_kersz=(7,7),
                       transition_stride=(2,2), transition_padding=(3,3), transition_out_padding=(1,1), device='cuda')
decoder_block_output = decoder_block(encoder_output)

print("decoder 1st block output: ", decoder_block_output.shape)
print("decoder 1st block output: is gradient alive?: ", decoder_block_output.requires_grad)

decoder 1st block output:  torch.Size([1, 384, 14, 42])
decoder 1st block output: is gradient alive?:  True


In [9]:
torchinfo.summary(decoder_block, (1, 768, 7, 21)) # 30,554,136 params (see summary below); the two 7x7 ConvTranspose2d 768->384 layers (conv_1, transition_conv) account for 28,902,144 of them

Layer (type:depth-idx)                   Output Shape              Param #
DBlock                                   [1, 384, 14, 42]          --
├─ConvTranspose2d: 1-1                   [1, 384, 14, 42]          14,451,072
├─LayerNorm: 1-2                         [1, 384, 14, 42]          451,584
├─Conv2d: 1-3                            [1, 1536, 14, 42]         591,360
├─GELU: 1-4                              [1, 1536, 14, 42]         --
├─Conv2d: 1-5                            [1, 384, 14, 42]          590,208
├─Conv2d: 1-6                            [1, 24, 1, 1]             9,240
├─ReLU: 1-7                              [1, 24, 1, 1]             --
├─Conv2d: 1-8                            [1, 384, 1, 1]            9,600
├─Sigmoid: 1-9                           [1, 384, 1, 1]            --
├─ConvTranspose2d: 1-10                  [1, 384, 14, 42]          14,451,072
Total params: 30,554,136
Trainable params: 30,554,136
Non-trainable params: 0
Total mult-adds (G): 17.69
Input size (M

In [10]:
class DStage(torch.nn.Module):
    """Decoder stage: an (optionally upsampling) first block, then blocks that fuse an encoder output.

    Block layout (``num_blocks`` ``DBlock``s, all ``in_channels`` wide):
      * block 0: upsampling block (``transition=True``) when ``transition_flag``, else a plain block;
      * block ``_FUSE_INDEX`` (= 1): fused block (``is_fuse=True``); ``forward`` concatenates the encoder
        stage output to its input along the channel axis;
      * remaining blocks: plain blocks.
    With ``num_blocks < 2`` the encoder output passed to ``forward`` is never used.

    Args:
        transition_flag: whether block 0 upsamples with a transposed convolution.
        num_blocks: number of blocks in the stage.
        img_hw: (H, W) of the stage's feature maps (after upsampling).
        in_channels: stage width; out_channels: hidden width of each block's 1x1 expansion.
        kernel_sz, stride, padding, dilation, groups: conv settings shared by all blocks.
        droprate: per-block stochastic-depth drop probabilities (at least ``num_blocks`` entries).
        drop_mode: per-block stochastic-depth modes, each 'row' or 'batch' (at least ``num_blocks`` entries).
        use_se, squeeze_ratio: squeeze-and-excitation settings shared by all blocks.
        encoder_channels: channel count of the encoder output fused at block ``_FUSE_INDEX``.
        transition_channels, transition_kersz, transition_stride, transition_padding,
            transition_out_padding: ConvTranspose2d settings for block 0 when ``transition_flag``.
        norm_mode: 'layer_norm' or 'batch_norm'.
        device: device on which all blocks are created.

    Raises:
        Exception: a ``drop_mode`` entry or ``norm_mode`` is not supported.
        ValueError: ``droprate`` or ``drop_mode`` has fewer than ``num_blocks`` entries.
    """

    # Index of the block whose input receives the encoder skip connection. __init__ builds that block with
    # is_fuse=True and forward() concatenates right before it; one constant keeps the two in sync.
    _FUSE_INDEX = 1

    def __init__(self, transition_flag, num_blocks:int, img_hw:list, in_channels, out_channels, kernel_sz, stride, padding, dilation, groups,
                 droprate, drop_mode:list, use_se:bool, squeeze_ratio:int, encoder_channels:int, transition_channels=-1, transition_kersz=-1, transition_stride=-1, 
                 transition_padding=-1, transition_out_padding=1, norm_mode='layer_norm', device='cuda'):
        
        super().__init__()
        self.transition_flag = transition_flag
        self.num_blocks = num_blocks
        self.img_h = img_hw[0]
        self.img_w = img_hw[1]
        self.stage_channels = in_channels
        self.stage_out_channels = out_channels
        self.kernel_sz = kernel_sz
        self.stride = stride
        self.padding = padding
        self.dilation = dilation
        self.groups = groups
        
        self.droprate=droprate
        for mode in drop_mode:
            if mode not in ['row', 'batch']:
                raise Exception("drop_mode must be 'row' or 'batch'")
        self.drop_mode = drop_mode
        self.use_se = use_se
        self.squeeze_r = squeeze_ratio
        self.encoder_channels = encoder_channels
        self.device = device
        
        self.transition_channels = transition_channels
        self.transition_kersz = transition_kersz
        self.transition_stride = transition_stride
        self.transition_padding = transition_padding
        self.transition_out_padding = transition_out_padding
        
        if norm_mode not in ['layer_norm', 'batch_norm']:
            raise Exception(f"Unsupported normalization method: {norm_mode}. Must be either 'layer_norm' 'batch_norm'")
        self.norm_mode = norm_mode

        # Fail early with a clear message instead of an IndexError halfway through block construction.
        if len(droprate) < num_blocks:
            raise ValueError(f"droprate has {len(droprate)} entries but num_blocks={num_blocks}")
        if len(drop_mode) < num_blocks:
            raise ValueError(f"drop_mode has {len(drop_mode)} entries but num_blocks={num_blocks}")
        
        blocks = []
        for i in range(num_blocks):
            # settings shared by every block of the stage
            common = dict(in_channels=self.stage_channels, out_channels=self.stage_out_channels,
                          kernel_sz=self.kernel_sz, stride=self.stride, padding=self.padding, groups=self.groups,
                          dilation=self.dilation, droprate=self.droprate[i], drop_mode=self.drop_mode[i],
                          use_se=self.use_se, squeeze_ratio=self.squeeze_r, norm_mode=self.norm_mode,
                          device=self.device)
            if i == 0 and self.transition_flag:
                # upsampling block; it cannot also fuse (DBlock forbids transition and is_fuse together)
                blocks.append(DBlock([self.img_h, self.img_w], transition=True,
                                     transition_channels=self.transition_channels, transition_kersz=self.transition_kersz,
                                     transition_stride=self.transition_stride, transition_padding=self.transition_padding,
                                     transition_out_padding=self.transition_out_padding, **common))
            else:
                # only the block right after upsampling sees the concatenated encoder features
                is_fused = (i == self._FUSE_INDEX)
                blocks.append(DBlock([self.img_h, self.img_w], is_fuse=is_fused,
                                     fused_channels=self.encoder_channels, **common))
        self.blocks = torch.nn.ModuleList(blocks)
                
    def forward(self, x, enc_stage_o):
        """Run all blocks, concatenating ``enc_stage_o`` on dim 1 just before block ``_FUSE_INDEX``.

        Args:
            x: decoder input (previous decoder stage output, or the encoder bottleneck).
            enc_stage_o: matching encoder stage output; its spatial size must equal block 0's output and
                its channel count ``encoder_channels``.
        """
        for i, block in enumerate(self.blocks):
            if i == self._FUSE_INDEX:
                x = torch.concat([x, enc_stage_o], dim=1)
            x = block(x)
        return x

In [12]:
# Chain two decoder stages on the encoder outputs (U-Net skip connections):
#   stage 1: bottleneck 768 @ 7x21 -> upsample to 384 @ 14x42, then fuse stage_output[-2] (384 @ 14x42)
#   stage 2: 384 @ 14x42 -> upsample to 192 @ 28x84, then fuse stage_output[-3] (192 @ 28x84)
# encoder_channels must equal the fused encoder output's channel count; img_hw the upsampled size.
print("encoder stage output shape: ", encoder_output.shape)
print("matched encoder stage output shape: ", stage_output[-2].shape)
print()
decoder_stage1 = DStage(transition_flag=True, num_blocks=3, img_hw=[14, 42], in_channels=384, out_channels=384*4, kernel_sz=(7,7), stride=1,
                       padding='same', dilation=1, groups=1, droprate=[0.5, 0.5, 0.5], drop_mode=['batch', 'batch', 'batch'], use_se=True,
                       squeeze_ratio=16, encoder_channels=384, transition_channels=768, transition_kersz=(7,7), transition_stride=(2,2), transition_padding=(3,3),
                       transition_out_padding=(1,1), norm_mode='layer_norm', device='cuda')

decoder_stage_output = decoder_stage1(encoder_output, stage_output[-2])
print("decoder stage_1 output shape: ", decoder_stage_output.shape)
print("matched encoder stage output shape: ", stage_output[-3].shape)
print("decoder stage_1: is gradient alive?: ", decoder_stage_output.requires_grad)
print()

decoder_stage2 = DStage(transition_flag=True, num_blocks=3, img_hw=[28, 84], in_channels=192, out_channels=192*4, kernel_sz=(7,7), stride=1,
                       padding='same', dilation=1, groups=1, droprate=[0.5, 0.5, 0.5], drop_mode=['batch', 'batch', 'batch'], use_se=True,
                       squeeze_ratio=16, encoder_channels=192, transition_channels=384, transition_kersz=(7,7), transition_stride=(2,2), transition_padding=(3,3),
                       transition_out_padding=(1,1), norm_mode='layer_norm', device='cuda')
decoder_stage_output = decoder_stage2(decoder_stage_output, stage_output[-3])

print("decoder stage_2 output shape: ", decoder_stage_output.shape)
print("decoder stage_2: is gradient alive?: ", decoder_stage_output.requires_grad)

encoder stage output shape:  torch.Size([1, 768, 7, 21])
matched encoder stage output shape:  torch.Size([1, 384, 14, 42])

decoder stage_1 output shape:  torch.Size([1, 384, 14, 42])
matched encoder stage output shape:  torch.Size([1, 192, 28, 84])
decoder stage_1: is gradient alive?:  True

decoder stage_2 output shape:  torch.Size([1, 192, 28, 84])
decoder stage_2: is gradient alive?:  True


In [13]:
# Parameter summary of decoder stage 1; the two inputs are (decoder input, fused encoder stage output).
# NOTE(review): the param totals in these comments were recorded by the author; the summary output below
# is truncated, so they cannot be cross-checked here.
torchinfo.summary(decoder_stage1, [(1, 768, 7, 21),(1, 384, 14, 42)]) # 34,654,152 params
# torchinfo.summary(decoder_stage2, [(1, 384, 14, 42), (1, 192, 28, 84)]) # 11,050,980 params

Layer (type:depth-idx)                   Output Shape              Param #
DStage                                   [1, 384, 14, 42]          --
├─ModuleList: 1-1                        --                        --
│    └─DBlock: 2-1                       [1, 384, 14, 42]          --
│    │    └─ConvTranspose2d: 3-1         [1, 384, 14, 42]          14,451,072
│    │    └─LayerNorm: 3-2               [1, 384, 14, 42]          451,584
│    │    └─Conv2d: 3-3                  [1, 1536, 14, 42]         591,360
│    │    └─GELU: 3-4                    [1, 1536, 14, 42]         --
│    │    └─Conv2d: 3-5                  [1, 384, 14, 42]          590,208
│    │    └─Conv2d: 3-6                  [1, 24, 1, 1]             9,240
│    │    └─ReLU: 3-7                    [1, 24, 1, 1]             --
│    │    └─Conv2d: 3-8                  [1, 384, 1, 1]            9,600
│    │    └─Sigmoid: 3-9                 [1, 384, 1, 1]            --
│    │    └─ConvTranspose2d: 3-10        [1, 384, 14, 42

In [14]:
# Decoder stem made of one transposed convolution; with kernel_size == stride it enlarges H and W by
# exactly the stride factor.
class DStemNx(torch.nn.Module):
    """Single ``ConvTranspose2d`` decoder stem (the mirror image of the encoder's strided stem).

    Every argument is forwarded unchanged to ``torch.nn.ConvTranspose2d`` (with ``bias=True`` and
    ``padding_mode='zeros'``). With ``kernel_size == stride``, ``padding == 0``, ``output_padding == 0``
    and ``dilation == 1``, the output is exactly ``stride`` times larger in each spatial dimension.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, output_padding=0, groups=1,
                 dilation=1, device='cuda'):
        super().__init__()
        # hyper-parameters kept on the module for introspection
        self.in_channels, self.out_channels = in_channels, out_channels
        self.kernel_size, self.stride, self.padding = kernel_size, stride, padding
        self.out_padding = output_padding
        self.groups, self.dilation, self.device = groups, dilation, device

        conv_options = dict(kernel_size=kernel_size, stride=stride, padding=padding, output_padding=output_padding,
                            groups=groups, bias=True, dilation=dilation, padding_mode='zeros', device=device)
        self.stem_conv = torch.nn.ConvTranspose2d(in_channels, out_channels, **conv_options)

    def forward(self, x):
        """Upsample ``x`` with the single transposed convolution."""
        return self.stem_conv(x)

# Decoder stem made of stacked transposed convolutions plus a 1x1 projection; the total upsampling factor
# is the product of the per-layer factors.
class DStemStacked(torch.nn.Module):
    """Decoder stem built from ``len(in_channels)`` stacked ``ConvTranspose2d`` layers and a 1x1 ``Conv2d``.

    Every argument except ``device`` is a list with one entry per transposed-conv layer: layer ``k`` is
    ``ConvTranspose2d(in_channels[k], out_channels[k], kernel_size[k], stride[k], padding[k],
    output_padding[k], groups[k], dilation[k])`` with bias and zero padding mode. A trailing 1x1 conv maps
    ``out_channels[-1]`` to ``out_channels[-1]`` channels. No activation or normalization sits between
    the layers.
    """

    def __init__(self, in_channels:list, out_channels:list, kernel_size:list, stride:list, padding:list, output_padding:list, groups:list,
                 dilation:list, device):
        super().__init__()
        # per-layer hyper-parameters kept on the module for introspection
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.out_padding = output_padding
        self.groups = groups
        self.dilation = dilation
        self.device = device

        self.num_stacked = len(in_channels)

        self.stacks = torch.nn.ModuleList([
            torch.nn.ConvTranspose2d(in_channels=in_channels[k], out_channels=out_channels[k], kernel_size=kernel_size[k],
                                     stride=stride[k], padding=padding[k], output_padding=output_padding[k],
                                     groups=groups[k], bias=True, dilation=dilation[k], padding_mode='zeros', device=device)
            for k in range(self.num_stacked)
        ])

        final_channels = out_channels[-1]
        self.out_conv = torch.nn.Conv2d(in_channels=final_channels, out_channels=final_channels, kernel_size=(1,1), stride=1,
                                        padding=0, dilation=1, groups=1, bias=True, padding_mode='zeros', device=device)

    def forward(self, x):
        """Apply every transposed convolution in order, then the 1x1 projection."""
        for layer in self.stacks:
            x = layer(x)
        return self.out_conv(x)

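# Reference copy of DBlock's signature (DBlock is defined in an earlier cell). The calls below also pass
# `norm_mode`, which this copy omits, so the actual DBlock signature must accept it.
# class DBlock(torch.nn.Module):
#     def __init__(self, img_hw, in_channels, out_channels, kernel_sz, stride, padding, groups, dilation, droprate, drop_mode,
#                  use_se, squeeze_ratio, is_fuse=False, fused_channels=-1, transition=False, transition_channels=-1, transition_kersz=-1,
#                  transition_stride=-1, transition_padding=-1, transition_out_padding=-1, device='cuda'):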

# continues convnext style decoder block
class DStemStaged(torch.nn.Module):
    """ConvNeXt-style decoder head built from stages of DBlocks.

    Stage i holds ``num_blocks[i]`` DBlocks. The first block of every stage is built with
    ``transition=True`` plus the transition_* settings (the up-sampling step). The remaining
    blocks of the stage are built without a transition.

    Args:
        num_blocks: number of blocks per stage; its length sets the number of stages.
        img_hw: per-stage (height, width) passed to every DBlock of that stage.
        input_channels: channels of the tensor fed into the first stage's transition.
        main_channels, expansion_channels: per-stage ``in_channels`` / ``out_channels`` passed to DBlock.
        kernel_sz, stride, padding, dilation, groups: per-stage convolution settings.
        droprate, drop_mode: stochastic depth settings indexed ``[stage][block]``.
        use_se: per-stage squeeze-and-excitation flags; ``squeeze_ratio`` is shared by all blocks.
        transition_kersz, transition_stride, transition_padding, transition_out_padding:
            per-stage settings used only by the first (transition) block of each stage.
        norm_mode: 'batch_norm' or 'layer_norm'.
        device: device every DBlock is created on.

    Raises:
        ValueError: if ``norm_mode`` is not supported.
    """
    def __init__(self, num_blocks:list, img_hw:list, input_channels, main_channels, expansion_channels, kernel_sz, stride, padding, dilation, groups, droprate, drop_mode,
                 use_se, squeeze_ratio, transition_kersz, transition_stride, transition_padding, transition_out_padding, norm_mode, device='cuda'):
        super().__init__()
        self.num_blocks = num_blocks
        self.img_hw = img_hw

        self.input_channels = input_channels # number of last stage channels
        self.main_channels = main_channels
        self.expansion_channels = expansion_channels

        self.kernel_sz = kernel_sz
        self.stride = stride
        self.padding = padding
        self.dilation = dilation
        self.groups = groups

        self.droprate = droprate
        self.drop_mode = drop_mode

        self.use_se = use_se
        self.squeeze_ratio = squeeze_ratio
        self.transition_kersz = transition_kersz
        self.transition_stride = transition_stride
        self.transition_padding = transition_padding
        self.transition_out_padding = transition_out_padding

        if norm_mode not in ['batch_norm', 'layer_norm']:
            raise ValueError(f"Unsupported normalization method: {norm_mode}. Must be either 'layer_norm', 'batch_norm'")
        self.norm_mode = norm_mode

        self.device = device

        self.num_stages = len(self.num_blocks)
        stages = []
        for i in range(self.num_stages):
            stage = [self._make_block(i, j) for j in range(self.num_blocks[i])]
            stages.append(torch.nn.ModuleList(stage))
        self.stages = torch.nn.ModuleList(stages)

    def _make_block(self, i, j):
        """Build block j of stage i; only the first block of a stage carries the up-sampling transition."""
        block_kwargs = dict(img_hw=self.img_hw[i], in_channels=self.main_channels[i], out_channels=self.expansion_channels[i],
                            kernel_sz=self.kernel_sz[i], stride=self.stride[i], padding=self.padding[i], groups=self.groups[i],
                            dilation=self.dilation[i], droprate=self.droprate[i][j], drop_mode=self.drop_mode[i][j],
                            use_se=self.use_se[i], squeeze_ratio=self.squeeze_ratio, norm_mode=self.norm_mode,
                            # every block gets the requested device; non-first blocks used to fall back to DBlock's default
                            device=self.device)
        if j == 0:
            # the first stage up-samples the head's input; later stages up-sample the previous stage's output
            transition_channels = self.input_channels if i == 0 else self.main_channels[i-1]
            block_kwargs.update(transition=True, transition_channels=transition_channels,
                                transition_kersz=self.transition_kersz[i], transition_stride=self.transition_stride[i],
                                transition_padding=self.transition_padding[i],
                                transition_out_padding=self.transition_out_padding[i])
        return DBlock(**block_kwargs)

    def forward(self, x):
        """Run every block of every stage in order."""
        for i in range(len(self.stages)):
            for j in range(len(self.stages[i])):
                x = self.stages[i][j](x)
        return x

In [15]:
# class DBlock(torch.nn.Module):
#     def __init__(self, img_hw, in_channels, out_channels, kernel_sz, stride, padding, groups, dilation, droprate, drop_mode,
#                  use_se, squeeze_ratio, is_fuse=False, fused_channels=-1, transition=False, transition_channels=-1, transition_kersz=-1,
#                  transition_stride=-1, transition_padding=-1, transition_out_padding=-1, device='cuda'):

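# Reference copy of DStage's signature (DStage is defined elsewhere in the notebook). Decoder below also passes
# `norm_mode`, which this copy omits, so the actual DStage signature must accept it.
# class DStage(torch.nn.Module):
#     def __init__(self, transition_flag, num_blocks:int, img_hw:list, in_channels, out_channels, kernel_sz, stride, padding, dilation, groups,
#                  droprate, drop_mode:list, use_se:bool, squeeze_ratio:int, encoder_channels:int, transition_channels=-1, transition_kersz=-1, transition_stride=-1, 
#                  transition_padding=-1, transition_out_padding=1, device='cuda'):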

# U-Net decoder (this does not need to be symmetrical with encoder)
class Decoder(torch.nn.Module):
    """ConvNeXt-style U-Net decoder: a sequence of up-sampling DStages followed by a head.

    Decoder stage i always up-samples its input (``transition_flag=True``) and consumes one
    encoder skip feature map. After the last stage, ``head`` (built outside this class) is applied.

    Args:
        num_blocks: number of blocks per decoder stage; its length sets the number of stages.
        img_hw: per-stage representative (img_height, img_width).
        main_channels, expansion_dim: per-stage channel settings passed to DStage as in/out channels.
        kernel_sz, stride, padding, dilation, groups: per-stage convolution settings.
        droprate, drop_mode: stochastic depth settings (outer list: stages, inner: blocks).
        use_se: per-stage squeeze-and-excitation flags; ``squeeze_ratio`` is shared.
        encoder_channels: encoder stage output channels ordered deepest first. ``[0]`` is the
            channel count of the tensor entering the decoder and ``[i + 1]`` is the channel count
            of the skip map consumed by stage i, so it needs at least ``len(num_blocks) + 1`` entries.
        transition_kersz, transition_stride, transition_padding, transition_out_padding:
            per-stage settings of the up-sampling transition.
        norm_mode: 'layer_norm' or 'batch_norm'.
        head: module applied to the last stage output (e.g. DStemNx, DStemStacked, DStemStaged).
        device: device the stages are created on.

    Raises:
        ValueError: unsupported ``norm_mode`` or too few ``encoder_channels``.
    """
    def __init__(self, num_blocks:list, img_hw:list, main_channels:list, expansion_dim:list, kernel_sz:list, 
                 stride:list, padding:list, dilation:list, groups:list, droprate:list, drop_mode:list, 
                 use_se:list, squeeze_ratio:int, encoder_channels:list,  transition_kersz:list, transition_stride:list, 
                 transition_padding:list, transition_out_padding:list, norm_mode:str, head:torch.nn.Module, device='cuda'):
        super().__init__()

        self.num_blocks = num_blocks # list that contains number of blocks for each stage
        self.img_hw = img_hw # list that contains representative (img_height, img_width) for each stage
        self.main_channels = main_channels
        self.expansion_dim = expansion_dim
        self.kernel_sz = kernel_sz
        self.stride = stride
        self.padding = padding
        self.dilation = dilation
        self.groups = groups

        self.droprate = droprate # stochastic depth drop probability ex) [[], [], []] -> outer: number of stages, inner: number of blocks
        self.drop_mode = drop_mode # stochastic depth drop mode

        self.use_se = use_se  # list that contains whether to use squeeze and excitation module in the stage
        self.squeeze_ratio = squeeze_ratio

        # channels of each encoder stage output, deepest first (includes the final encoder stage output at [0])
        self.encoder_channels = encoder_channels

        self.transition_kersz = transition_kersz
        self.transition_stride = transition_stride
        self.transition_padding = transition_padding
        self.transition_out_padding = transition_out_padding

        if norm_mode not in ['layer_norm', 'batch_norm']:
            raise ValueError(f"Unsupported normalization method: {norm_mode}. Must be either 'layer_norm', 'batch_norm'")
        self.norm_mode = norm_mode

        # head is intentionally created outside of Decoder class because many variations and img size can be required in various situation.
        self.head = head # decoder head for further convolution

        self.device = device

        self.num_stages = len(self.num_blocks) # length of list indicates the number of stages
        if len(self.encoder_channels) < self.num_stages + 1:
            raise ValueError(f"encoder_channels needs at least {self.num_stages + 1} entries for {self.num_stages} decoder stages, "
                             f"got {len(self.encoder_channels)}")

        stages = []
        for i in range(self.num_stages):
            # the first stage consumes the last (deepest) encoder output; later stages consume the previous decoder stage
            transition_channels = self.encoder_channels[0] if i == 0 else self.main_channels[i-1]

            # unlike encoder, decoder always up-samples the input (transition_flag is always True)
            stages.append(DStage(transition_flag=True, num_blocks=self.num_blocks[i], img_hw=self.img_hw[i], in_channels=self.main_channels[i],
                                 out_channels=self.expansion_dim[i], kernel_sz=self.kernel_sz[i], stride=self.stride[i], padding=self.padding[i],
                                 dilation=self.dilation[i], groups=self.groups[i], droprate=self.droprate[i], drop_mode=self.drop_mode[i],
                                 use_se=self.use_se[i], squeeze_ratio=self.squeeze_ratio, encoder_channels=self.encoder_channels[i+1],
                                 transition_channels=transition_channels, transition_kersz=self.transition_kersz[i],
                                 transition_stride=self.transition_stride[i], transition_padding=self.transition_padding[i],
                                 transition_out_padding=self.transition_out_padding[i], norm_mode=self.norm_mode, device=self.device))
        self.stages = torch.nn.ModuleList(stages)

    def forward(self, x, enc_stage_out):
        """Run every decoder stage with its encoder skip map, then the head.

        Args:
            x: deepest encoder feature map (the decoder input).
            enc_stage_out: list of encoder stage outputs. Its last entry (the deepest output,
                already passed as ``x``) is not used as a skip; stage i consumes
                ``enc_stage_out[len(enc_stage_out) - 2 - i]``. Assumes shallow-to-deep ordering,
                matching ``encoder_channels`` read deepest first (TODO confirm against the Encoder).

        Raises:
            ValueError: if fewer than ``self.num_stages`` skip maps are available. Previously the
                surplus stages were silently skipped, which fed a wrongly shaped tensor to the head.
        """
        num_skips = max(len(enc_stage_out) - 1, 0)
        if num_skips < self.num_stages:
            raise ValueError(f"Decoder has {self.num_stages} stages but only {num_skips} encoder skip outputs are available "
                             f"(enc_stage_out has {len(enc_stage_out)} entries; the last one is not used as a skip)")

        for i in range(self.num_stages):
            x = self.stages[i](x, enc_stage_out[num_skips - 1 - i])

        return self.head(x)

In [19]:
# Reference signatures of the decoder-head candidates defined in an earlier cell:
# class DStemNx(torch.nn.Module):
#     def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, output_padding=0, groups=1,
#                  dilation=1, device='cuda'):

# class DStemStacked(torch.nn.Module):
#     def __init__(self, in_channels:list, out_channels:list, kernel_size:list, stride:list, padding:list, output_padding:list, groups:list,
#                  dilation:list, device):

# class DStemStaged(torch.nn.Module):
#     def __init__(self, num_blocks:list, img_hw:list, input_channels, main_channels, expansion_channels, kernel_sz, stride, padding, dilation, groups, droprate, drop_mode,
#                  use_se, squeeze_ratio, transition_kersz, transition_stride, transition_padding, transition_out_padding, norm_mode, device='cuda'):

# 4x head: one 4x4 / stride-4 transposed conv lifting the 96-channel last decoder stage output
# (56x168) to 128 channels at the full 224x672 resolution (see the summary output below).
head_4x = DStemNx(in_channels=96, out_channels=128, kernel_size=(4,4), stride=(4,4), padding=(0,0), output_padding=(0,0), groups=1, dilation=1, device='cuda')
# head_stack = DStemStacked(in_channels=[96, 128], out_channels=[128, 256], kernel_size=[(4,4), (12,12)], stride=[(2,4), (4,4)], padding=[(1,0), (1,0)], output_padding=[(0,0), (0,0)], 
#                           groups=[1,1], dilation=[1,1], device='cuda')

# droprate, drop_mode = create_linear_p([3,9,3,2,2], dp_mode='batch', last_p=0.5)
# head_stage = DStemStaged(num_blocks=[2,2], img_hw=[(112, 336), (224, 672)], input_channels=96, main_channels=[96, 128], expansion_channels=[96*4, 128*4], 
#                          kernel_sz=[(7,7), (7,7)], stride=[(1,1), (1,1)], padding=['same', 'same'], dilation=[1, 1], groups=[1,1], 
#                          droprate=droprate[3:], drop_mode=drop_mode[3:],
#                          use_se=[True, True], squeeze_ratio=16, transition_kersz=[(7,7), (7,7)], transition_stride=[(2,2), (2,2)], transition_padding=[(3,3), (3,3)],
#                          transition_out_padding=[(1,1), (1,1)], norm_mode='batch_norm', device='cuda')


torchinfo.summary(head_4x, (1, 96, 56, 168)) # 196,736 params
# torchinfo.summary(head_stack, (1, 96, 73, 102)) # 4,981,376 params
# torchinfo.summary(head_stage, (1, 96, 73, 102)) # 2,564,476 params, much heavier than head_4x but entire number of params are not that huge

Layer (type:depth-idx)                   Output Shape              Param #
DStemNx                                  [1, 128, 224, 672]        --
├─ConvTranspose2d: 1-1                   [1, 128, 224, 672]        196,736
Total params: 196,736
Trainable params: 196,736
Non-trainable params: 0
Total mult-adds (G): 29.61
Input size (MB): 3.61
Forward/backward pass size (MB): 154.14
Params size (MB): 0.79
Estimated Total Size (MB): 158.54

In [22]:
# without considering stem layers
# unet_decoder = Decoder(num_blocks=[3,9,3], img_hw=[(14, 42), (28, 84), (56, 168)], main_channels=[384, 192, 96], expansion_dim=[384*4, 192*4, 96*4],
#                   kernel_sz=[(7,7)]*3, stride=[(1,1)]*3, padding=['same']*3, dilation=[1]*3, groups=[1]*3, droprate=droprate[:3], drop_mode=drop_mode[:3],
#                   use_se=[True]*3, squeeze_ratio=16, encoder_channels=[768, 384, 192, 96], transition_kersz=[(7,7)]*3, transition_stride=[(2,2)]*3, 
#                   transition_padding=[(3,3)]*3, transition_out_padding=[(1,1)]*3, norm_mode='layer_norm', head=head_stage, device='cuda')


# stochastic depth settings for the 3 decoder stages ([3, 9, 3] blocks per stage)
droprate, drop_mode = create_linear_p([3,9,3], dp_mode='batch', last_p=0.5)
# 3 up-sampling stages (14x42 -> 28x84 -> 56x168); head_4x then restores the 224x672 input resolution.
# encoder_channels is deepest first: 768 enters the decoder, 384/192/96 are the skips used by stages 0/1/2.
unet_decoder = Decoder(num_blocks=[3,9,3], img_hw=[(14, 42), (28, 84), (56, 168)], main_channels=[384, 192, 96], expansion_dim=[384*4, 192*4, 96*4],
                  kernel_sz=[(7,7)]*3, stride=[(1,1)]*3, padding=['same']*3, dilation=[1]*3, groups=[1]*3, droprate=droprate, drop_mode=drop_mode,
                  use_se=[True]*3, squeeze_ratio=16, encoder_channels=[768, 384, 192, 96], transition_kersz=[(7,7)]*3, transition_stride=[(2,2)]*3, 
                  transition_padding=[(3,3)]*3, transition_out_padding=[(1,1)]*3, norm_mode='layer_norm', head=head_4x, device='cuda')


# NOTE(review): `encoder_output` / `stage_output` are produced by running the encoder in an earlier cell
# (not visible here) — re-run that cell first on a fresh kernel.
decoder_output = unet_decoder(encoder_output, stage_output)
print("decoder output shape: ", decoder_output.shape) # restore to original image size
print("decoder output: is gradient alive?: ", decoder_output.requires_grad)
# NOTE(review): stale — the decoder here is `unet_decoder`, and these shapes predate the 224x672 input (encoder output is now 7x21).
# torchinfo.summary(decoder, [(1, 768, 10, 13), (1, 96, 73, 102), (1, 192, 37, 51), (1, 384, 19, 26), (1, 768, 10, 13)])

decoder output shape:  torch.Size([1, 128, 224, 672])
decoder output: is gradient alive?:  True


In [23]:
# container encoder and decoder
class UNet(torch.nn.Module):
    """Encoder-decoder segmentation network with a 1x1 classification head.

    Args:
        encoder: module whose forward returns ``(deepest_features, stage_outputs)``; it must
            also expose a ``device`` attribute, which is used to place the classifier.
        decoder: Decoder whose ``head`` is a DStemNx, DStemStacked or DStemStaged.
        num_cls: number of segmentation classes (output channels).
        output_mode: 'logits' returns raw class scores; 'probs' applies a softmax over the
            class channel at every pixel.

    Raises:
        ValueError: if ``output_mode`` is not 'logits' or 'probs'.
        NotImplementedError: if the decoder head type is not supported.
    """
    def __init__(self, encoder: Encoder, decoder: Decoder, num_cls:int, output_mode:str):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.num_cls = num_cls # number of segmentation class

        if output_mode not in ['logits', 'probs']:
            raise ValueError(f"Unsupported output mode: {output_mode}. Must be either 'logits' or 'probs'")
        self.output_mode = output_mode

        if self.output_mode == 'probs':
            # Softmax2d normalizes over the channel (class) dimension at each spatial location
            self.softmax = torch.nn.Softmax2d()

        # the classifier's input width is whatever the decoder head emits
        head = self.decoder.head
        if isinstance(head, DStemNx):
            in_channels = head.out_channels
        elif isinstance(head, DStemStacked):
            in_channels = head.out_channels[-1]
        elif isinstance(head, DStemStaged):
            in_channels = head.main_channels[-1]
        else:
            raise NotImplementedError("Not Implemented: Currently Head supports only DStemNx, DStemStacked, DStemStaged")

        self.cls_head = torch.nn.Conv2d(in_channels=in_channels, out_channels=self.num_cls, kernel_size=(1,1),
                                        stride=(1,1), padding=0, dilation=1, groups=1, bias=True, padding_mode='zeros', device=encoder.device)

    def forward(self, x):
        """Encode, decode with skip connections, then classify every pixel.

        Returns class logits, or per-pixel class probabilities when ``output_mode == 'probs'``;
        shape ``(B, num_cls, H_out, W_out)``.
        """
        features, skips = self.encoder(x)
        decoded = self.decoder(features, skips)
        scores = self.cls_head(decoded)  # (B, num_cls, ori_h, ori_w)

        if self.output_mode == 'probs':
            return self.softmax(scores) # return probabilities
        return scores

In [24]:
# assemble the full model: 5 segmentation classes, 'probs' appends a per-pixel softmax over the classes
convnext_unet = UNet(encoder=unet_encoder, decoder=unet_decoder, num_cls=5, output_mode='probs')
# NOTE(review): `unet_encoder` and `ex_img` come from earlier cells (not visible here); `ex_img` is
# presumably a (1, 3, 224, 672) image batch, matching the summary input below — confirm.
unet_o = convnext_unet(ex_img)
print("unet output shape: ", unet_o.shape)
print("unet output: is gradient alive?: ", unet_o.requires_grad)

torchinfo.summary(convnext_unet, (1, 3, 224, 672)) # 106,024,261 params total

unet output shape:  torch.Size([1, 5, 224, 672])
unet output: is gradient alive?:  True


Layer (type:depth-idx)                                  Output Shape              Param #
UNet                                                    [1, 5, 224, 672]          --
├─Encoder: 1-1                                          [1, 768, 7, 21]           --
│    └─Stem: 2-1                                        [1, 96, 56, 168]          --
│    │    └─Conv2d: 3-1                                 [1, 96, 56, 168]          4,704
│    └─ModuleList: 2-2                                  --                        --
│    │    └─Stage: 3-2                                  [1, 96, 56, 168]          5,687,730
│    │    └─Stage: 3-3                                  [1, 192, 28, 84]          3,746,916
│    │    └─Stage: 3-4                                  [1, 384, 14, 42]          16,362,072
│    │    └─Stage: 3-5                                  [1, 768, 7, 21]           16,936,848
├─Decoder: 1-2                                          [1, 128, 224, 672]        --
│    └─ModuleList: 2-3     