# Efficient Unet

My attempt at building a more efficient U-Net by using depthwise-separable transposed-convolution ("deconv") blocks in the decoder.

In [231]:
import numpy as np
import matplotlib.pyplot as plt
import scipy.io

import torch
from torch import nn, Tensor
import torch.optim as optim
import torchvision

from tqdm.notebook import tqdm
from prettytable import PrettyTable
from collections import OrderedDict
from typing import Any, Callable, List, Optional, Sequence, Tuple, Union

from misc import Conv2dNormActivation, InvertedResidual


def count_parameters(model, showTable=False):
    """Count the trainable parameters of ``model``.

    Args:
        model: Object exposing ``named_parameters()`` (e.g. an ``nn.Module``).
        showTable: When true, also print a per-parameter table of element
            counts.

    Returns:
        Total number of elements over all parameters whose ``requires_grad``
        flag is set.
    """
    trainable = [
        (name, parameter.numel())
        for name, parameter in model.named_parameters()
        if parameter.requires_grad
    ]

    if showTable:
        # Build the table only when it is going to be printed; plain counting
        # should not pay for (or depend on) table construction.
        table = PrettyTable(["Modules", "Parameters"])
        for name, params in trainable:
            table.add_row([name, params])
        print(table)

    return sum(params for _, params in trainable)


def calculate_storage(model, show_buffer=True):
    """Return the in-memory size of ``model`` in MiB.

    The size is the byte count of every parameter plus every buffer,
    divided by 1024**2.

    Args:
        model: Object exposing ``parameters()`` and ``buffers()``.
        show_buffer: When true, print the buffer-only size before returning.

    Returns:
        Combined parameter + buffer size in MiB.
    """

    def _total_bytes(tensors):
        # bytes = number of elements * bytes per element, summed over tensors
        return sum(t.nelement() * t.element_size() for t in tensors)

    param_bytes = _total_bytes(model.parameters())
    buffer_bytes = _total_bytes(model.buffers())

    if show_buffer:
        print(f"Buffer size: {buffer_bytes/1024**2:.3f} MB")

    return (param_bytes + buffer_bytes) / 1024**2

In [258]:
import torch
from torch import nn, Tensor

from typing import Any, Callable, List, Optional, Tuple, Union


class Conv2dNormActivation(nn.Sequential):
    """``Conv2d -> norm -> Dropout2d -> activation`` packaged as one Sequential.

    Args:
        in_channels: Input channels of the convolution.
        out_channels: Output channels of the convolution (and norm width).
        kernel_size: Convolution kernel size, int or per-dimension tuple.
        stride: Convolution stride.
        padding: Convolution padding. ``None`` derives it per dimension from
            ``kernel_size`` and ``dilation`` (see ``_default_padding``).
        groups: Number of convolution groups.
        norm_layer: Callable building the norm layer from ``out_channels``;
            ``None`` selects ``nn.BatchNorm2d``.
        activation_layer: Callable building the activation; ``None`` selects
            ``nn.ReLU6``.
        dropout_p: Probability handed to the ``nn.Dropout2d`` layer that sits
            between the norm and the activation.
        dilation: Convolution dilation, int or per-dimension tuple.
        bias: Whether the convolution has a bias; ``None`` is treated as
            ``False``.
        inplace: Forwarded to ``activation_layer`` as ``inplace=...``. Pass
            ``None`` to omit the keyword for activations that do not take it.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: Union[int, Tuple[int, ...]] = 3,
        stride: Union[int, Tuple[int, ...]] = 1,
        padding: Optional[Union[int, Tuple[int, ...], str]] = None,
        groups: int = 1,
        norm_layer: Optional[Callable[..., nn.Module]] = None,
        activation_layer: Optional[Callable[..., nn.Module]] = None,
        dropout_p: float = 0.0,
        dilation: Union[int, Tuple[int, ...]] = 1,
        bias: Optional[bool] = None,
        inplace: Optional[bool] = True,
    ) -> None:
        if padding is None:
            padding = self._default_padding(kernel_size, dilation)

        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        if activation_layer is None:
            activation_layer = nn.ReLU6

        # Only pass ``inplace`` when the caller asked for it, so activations
        # without that keyword (e.g. nn.Sigmoid) can be used with inplace=None.
        activation_kwargs = {} if inplace is None else {"inplace": inplace}

        layers = [
            nn.Conv2d(
                in_channels,
                out_channels,
                kernel_size,
                stride=stride,
                padding=padding,
                dilation=dilation,
                groups=groups,
                bias=bool(bias),
            ),
            norm_layer(out_channels),
            nn.Dropout2d(p=dropout_p),
            activation_layer(**activation_kwargs),
        ]

        super().__init__(*layers)

    @staticmethod
    def _default_padding(kernel_size, dilation):
        """Derive the padding used when the caller passes ``padding=None``.

        Per dimension: ``dilation * (kernel - 1) // 2`` when dilation > 1,
        otherwise ``kernel // 2``. Scalars yield a scalar; if either argument
        is a tuple/list the result is a tuple (scalars are broadcast).
        """

        def pad_1d(k, d):
            return d * (k - 1) // 2 if d > 1 else k // 2

        kernel_is_seq = isinstance(kernel_size, (tuple, list))
        dilation_is_seq = isinstance(dilation, (tuple, list))
        if not kernel_is_seq and not dilation_is_seq:
            return pad_1d(kernel_size, dilation)

        n_dims = len(kernel_size) if kernel_is_seq else len(dilation)
        kernels = tuple(kernel_size) if kernel_is_seq else (kernel_size,) * n_dims
        dilations = tuple(dilation) if dilation_is_seq else (dilation,) * n_dims
        return tuple(pad_1d(k, d) for k, d in zip(kernels, dilations))


class Deconv2dNormActivation(nn.Sequential):
    """``ConvTranspose2d -> norm -> Dropout2d -> activation`` as one Sequential.

    Args:
        in_channels: Input channels of the transposed convolution.
        out_channels: Output channels of the transposed convolution.
        kernel_size: Kernel size, int or per-dimension tuple.
        stride: Stride (upsampling factor) of the transposed convolution.
        padding: Padding. ``None`` derives it per dimension from
            ``kernel_size`` and ``dilation`` (see ``_default_padding``).
        output_padding: Extra size added to one side of the output. ``None``
            means no extra padding (0).
        groups: Number of convolution groups.
        norm_layer: Callable building the norm layer from ``out_channels``;
            ``None`` selects ``nn.BatchNorm2d``.
        activation_layer: Callable building the activation; ``None`` selects
            ``nn.ReLU6``.
        dropout_p: Probability handed to the ``nn.Dropout2d`` layer that sits
            between the norm and the activation.
        dilation: Dilation, int or per-dimension tuple.
        bias: Whether the layer has a bias; ``None`` is treated as ``False``.
        inplace: Forwarded to ``activation_layer`` as ``inplace=...``. Pass
            ``None`` to omit the keyword for activations that do not take it.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: Union[int, Tuple[int, ...]] = 3,
        stride: Union[int, Tuple[int, ...]] = 1,
        padding: Optional[Union[int, Tuple[int, ...], str]] = None,
        output_padding: Optional[Union[int, Tuple[int, ...], str]] = None,
        groups: int = 1,
        norm_layer: Optional[Callable[..., nn.Module]] = None,
        activation_layer: Optional[Callable[..., nn.Module]] = None,
        dropout_p: float = 0.0,
        dilation: Union[int, Tuple[int, ...]] = 1,
        bias: Optional[bool] = None,
        inplace: Optional[bool] = True,
    ) -> None:
        if padding is None:
            padding = self._default_padding(kernel_size, dilation)

        # nn.ConvTranspose2d has no notion of ``None`` here: it would store
        # (None, None) and fail in the forward pass, so fall back to 0.
        if output_padding is None:
            output_padding = 0

        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        if activation_layer is None:
            activation_layer = nn.ReLU6

        # Only pass ``inplace`` when the caller asked for it, so activations
        # without that keyword (e.g. nn.Sigmoid) can be used with inplace=None.
        activation_kwargs = {} if inplace is None else {"inplace": inplace}

        layers = [
            nn.ConvTranspose2d(
                in_channels,
                out_channels,
                kernel_size,
                stride=stride,
                padding=padding,
                output_padding=output_padding,
                groups=groups,
                bias=bool(bias),
                dilation=dilation,
            ),
            norm_layer(out_channels),
            nn.Dropout2d(p=dropout_p),
            activation_layer(**activation_kwargs),
        ]

        super().__init__(*layers)

    @staticmethod
    def _default_padding(kernel_size, dilation):
        """Derive the padding used when the caller passes ``padding=None``.

        Per dimension: ``dilation * (kernel - 1) // 2`` when dilation > 1,
        otherwise ``kernel // 2``. Scalars yield a scalar; if either argument
        is a tuple/list the result is a tuple (scalars are broadcast).
        """

        def pad_1d(k, d):
            return d * (k - 1) // 2 if d > 1 else k // 2

        kernel_is_seq = isinstance(kernel_size, (tuple, list))
        dilation_is_seq = isinstance(dilation, (tuple, list))
        if not kernel_is_seq and not dilation_is_seq:
            return pad_1d(kernel_size, dilation)

        n_dims = len(kernel_size) if kernel_is_seq else len(dilation)
        kernels = tuple(kernel_size) if kernel_is_seq else (kernel_size,) * n_dims
        dilations = tuple(dilation) if dilation_is_seq else (dilation,) * n_dims
        return tuple(pad_1d(k, d) for k, d in zip(kernels, dilations))


class InvertedResidual(nn.Module):
    """Pointwise expansion -> 3x3 depthwise conv -> pointwise projection.

    An identity shortcut is added around the block when it changes neither
    the channel count nor the spatial resolution.

    Args:
        inp: Number of input channels.
        oup: Number of output channels.
        stride: Stride of the depthwise convolution.
        expand_ratio: Multiplier applied to ``inp`` to get the hidden width;
            a value of 1 skips the expansion layer.
    """

    def __init__(self, inp: int, oup: int, stride: int, expand_ratio: int) -> None:
        super().__init__()

        hidden_dim = int(round(inp * expand_ratio))
        # ``x + out`` only type-checks shape-wise when channels AND resolution
        # are preserved; a strided block with inp == oup must not use it.
        self.use_residual = bool(stride == 1 and inp == oup)

        layers: List[nn.Module] = []

        if expand_ratio != 1:
            # pointwise expansion
            layers.append(
                Conv2dNormActivation(inp, hidden_dim, kernel_size=1, bias=False)
            )

        def conv_dw(inp, oup, stride):
            """Depthwise 3x3 conv (+BN+ReLU6) followed by a linear 1x1 projection."""
            return nn.Sequential(
                # depthwise: one filter per channel (groups == channels)
                Conv2dNormActivation(
                    inp,
                    inp,
                    kernel_size=3,
                    stride=stride,
                    padding=1,
                    groups=inp,
                    bias=False,
                    norm_layer=nn.BatchNorm2d,
                    activation_layer=nn.ReLU6,
                ),
                # pointwise projection, no activation afterwards
                nn.Conv2d(inp, oup, 1, 1, padding=0, bias=False),
                nn.BatchNorm2d(oup),
            )

        layers.append(conv_dw(hidden_dim, oup, stride))
        self.conv = nn.Sequential(*layers)

    def forward(self, x: Tensor) -> Tensor:
        """Apply the block, adding the input back when ``use_residual`` is set."""
        out = self.conv(x)
        if self.use_residual:
            out = x + out
        # NOTE: shape trace used by the notebook for debugging.
        print(out.shape)
        return out


def deconv_dw(inp: int, oup: int, stride: int) -> nn.Sequential:
    """Build a depthwise-separable transposed convolution.

    The result is a 3x3 depthwise ``Deconv2dNormActivation`` (BatchNorm +
    ReLU6) followed by a 1x1 convolution and BatchNorm without activation.

    Args:
        inp: Number of input channels (also the depthwise group count).
        oup: Number of output channels of the pointwise projection.
        stride: Stride of the transposed convolution, i.e. the spatial
            upsampling factor.

    Returns:
        ``nn.Sequential`` of the depthwise stage, the 1x1 conv and its norm.
    """
    # With kernel 3 and padding 1 the transposed conv yields
    # (H - 1) * stride + 1 + output_padding, so stride - 1 gives exactly
    # H * stride. This equals the previous 0 / 1 for strides 1 / 2.
    out_pad = int(stride) - 1
    return nn.Sequential(
        # depthwise: one filter per channel (groups == channels)
        Deconv2dNormActivation(
            inp,
            inp,
            3,
            stride=stride,
            padding=1,
            output_padding=out_pad,
            groups=inp,
            bias=False,
            norm_layer=nn.BatchNorm2d,
            activation_layer=nn.ReLU6,
        ),
        # pointwise projection, no activation afterwards
        nn.Conv2d(inp, oup, 1, 1, padding=0, bias=False),
        nn.BatchNorm2d(oup),
    )


class DecomposedDeconv(nn.Module):
    """Thin module wrapper around a single ``deconv_dw`` stack.

    Input and output channel counts may be arbitrary. Any extra positional
    arguments are accepted and ignored.

    Args:
        inp: Number of input channels.
        oup: Number of output channels.
        stride: Stride handed to ``deconv_dw``.
    """

    def __init__(self, inp: int, oup: int, stride: int, *args) -> None:
        super().__init__()

        separable_stack = deconv_dw(inp, oup, stride)
        self.deconv = nn.Sequential(separable_stack)

    def forward(self, x: Tensor) -> Tensor:
        """Run the wrapped stack, print the result's shape and return it."""
        result = self.deconv(x)
        print(result.shape)
        return result


class InvertedResidualDeconv(nn.Module):
    """Pointwise expansion followed by a depthwise-separable transposed conv.

    An identity shortcut is added around the block when it changes neither
    the channel count nor the spatial resolution.

    Args:
        inp: Number of input channels.
        oup: Number of output channels.
        stride: Stride of the transposed depthwise convolution.
        expand_ratio: Multiplier applied to ``inp`` to get the hidden width;
            a value of 1 skips the expansion layer.
    """

    def __init__(self, inp: int, oup: int, stride: int, expand_ratio: int) -> None:
        super().__init__()

        hidden_dim = int(round(inp * expand_ratio))
        # A strided transposed conv changes the spatial size, so ``x + out``
        # is only valid when stride == 1 in addition to inp == oup.
        self.use_residual = bool(stride == 1 and inp == oup)

        layers: List[nn.Module] = []

        if expand_ratio != 1:
            # pointwise expansion
            layers.append(
                Conv2dNormActivation(inp, hidden_dim, kernel_size=1, bias=False)
            )

        layers.append(deconv_dw(hidden_dim, oup, stride))
        self.deconv = nn.Sequential(*layers)

    def forward(self, x: Tensor) -> Tensor:
        """Apply the block, adding the input back when ``use_residual`` is set."""
        out = self.deconv(x)
        if self.use_residual:
            out = x + out
        # NOTE: shape trace used by the notebook for debugging.
        print(out.shape)
        return out

In [253]:
# Per-layer Dropout2d probabilities read by MobileUNet while it is built.
# Row 0 -> encoder: [conv0, bottleneck1 .. bottleneck7, convLast]
# Row 1 -> decoder: [deconv0, decomposed_deconv1 .. decomposed_deconv7, deconvLast]
drop_cfg = [
    [0, 0, 0, 0, 0, 0, 0.5, 0.5, 0.5],
    [0.5, 0.5, 0.5, 0, 0, 0, 0, 0, 0],
]

In [259]:
class MobileUNet(nn.Module):
    """Encoder/decoder network built from inverted-residual (de)conv blocks.

    The encoder is a stem convolution, one stage per row of ``self.config``
    and a final 1x1 convolution to ``hidden_dim`` channels. The decoder
    mirrors the stages with ``InvertedResidualDeconv`` blocks, concatenating
    stored encoder features after every stride-2 decoder stage, and ends with
    a stride-2 transposed convolution to ``num_classes`` channels.

    Dropout probabilities are read from the module-level ``drop_cfg`` table.

    Args:
        alpha: Width multiplier applied (with ceiling) to the channel column
            of the stage configuration.
    """

    def __init__(self, alpha: float = 1.0) -> None:
        super().__init__()

        # ================ parameters ================
        self.hidden_dim: int = 1280
        self.num_classes: int = 3
        self.config: np.ndarray = np.array(
            [
                # t: expand ratio, c: output channels, n: blocks in the stage,
                # s: stride of the stage's strided block.
                # op: not read anywhere in this class (presumably an
                # output-padding flag — TODO confirm).
                # t, c, n, s, op
                [1, 16, 1, 1, 0],
                [6, 24, 2, 2, 1],  # skip
                [6, 32, 3, 2, 1],  # skip
                [6, 64, 4, 2, 0],  # skip
                [6, 96, 3, 1, 0],
                [6, 160, 3, 2, 0],  # skip
                [6, 320, 1, 1, 0],
            ]
        )
        encoder_block = InvertedResidual
        # DecomposedDeconv accepts the same positional arguments and can be
        # substituted here.
        decoder_block = InvertedResidualDeconv
        dropout_block = nn.Dropout2d

        self.alpha = alpha
        self.config[:, 1] = np.ceil(self.config[:, 1] * self.alpha).astype(int)

        # ================== encoder ==================
        encoder = OrderedDict()
        inp = 32  # first conv2d channels
        encoder["conv0"] = Conv2dNormActivation(
            3, inp, 3, 2, 1, bias=False, dropout_p=drop_cfg[0][0]
        )

        # add bottleneck layers; ``*_`` absorbs the trailing ``op`` column so
        # rows with more than four entries unpack cleanly
        for i, (t, c, n, s, *_) in enumerate(self.config):
            blocks: List[nn.Module] = []
            for j in range(n):
                # only the first block of a stage is strided
                sj = s if j == 0 else 1
                blocks.extend(
                    [encoder_block(inp, c, sj, t), dropout_block(p=drop_cfg[0][i + 1])]
                )
                inp = c
            encoder[f"bottleneck{i+1}"] = nn.Sequential(*blocks)

        # append last conv2d
        encoder["convLast"] = Conv2dNormActivation(
            inp, self.hidden_dim, 1, 1, bias=False, dropout_p=drop_cfg[0][-1]
        )

        # ================== decoder ==================
        def get_deconv_config(config: np.ndarray) -> np.ndarray:
            """Mirror the encoder config: shift channels by one stage, reverse rows."""
            deconv_config = config.copy()
            # each decoder stage outputs the channel count of the encoder stage
            # one level up; the outermost one outputs 3 channels
            deconv_config[:, 1] = np.concatenate(([3], deconv_config[:, 1][:-1]))

            return deconv_config[::-1]

        self.deconv_config = get_deconv_config(self.config)
        print(self.deconv_config)

        decoder = OrderedDict()
        decoder["deconv0"] = Conv2dNormActivation(
            self.hidden_dim, inp, 1, 1, bias=False, dropout_p=drop_cfg[1][0]
        )
        # channels added to the next block's input by a skip concatenation
        skip = 0
        for i, (t, c, n, s, *_) in enumerate(self.deconv_config):
            blocks: List[nn.Module] = []
            for j in range(n):
                # only the last block of a stage is strided and changes width
                sj = s if j == n - 1 else 1
                cj = c if j == n - 1 else inp
                blocks.extend(
                    [
                        decoder_block(inp + skip, cj, sj, t),
                        dropout_block(p=drop_cfg[1][i + 1]),
                    ]
                )
                inp = cj
                # forward() concatenates a c-channel encoder feature after a
                # stride-2 stage, widening the following block's input
                skip = c if sj == 2 else 0
            decoder[f"decomposed_deconv{i+1}"] = nn.Sequential(*blocks)

        # append last deconv2d
        decoder["deconvLast"] = Deconv2dNormActivation(
            inp, self.num_classes, 3, 2, 1, 1, bias=False, dropout_p=drop_cfg[1][-1]
        )

        # ========= weight initialization ========
        self.encoder = nn.Sequential(encoder)
        self.decoder = nn.Sequential(decoder)
        self.softmax = nn.Softmax(dim=1)

        self._init_weights()

    def _init_weights(self) -> None:
        """Initialise conv, norm, linear and transposed-conv layers in place."""
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode="fan_out")
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.ones_(m.weight)
                nn.init.zeros_(m.bias)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                nn.init.zeros_(m.bias)
            elif isinstance(m, nn.ConvTranspose2d):
                torch.nn.init.xavier_normal_(m.weight)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)

    def forward(self, x: Tensor) -> Tensor:
        """Run encoder and decoder and return the softmax over dim 1.

        Progress and shape information is printed along the way.
        """
        # store features for skip connections
        features = []

        # ENCODER
        x = self.encoder[0](x)
        for i in range(len(self.config)):
            print(f"Encoder {i+1}")
            x = self.encoder[i + 1](x)
            # keep the output of a stage when the NEXT stage downsamples
            if i < len(self.config) - 1 and self.config[i + 1][3] == 2:
                features.append(x)
        # final hidden state
        x = self.encoder[-1](x)
        print("hidden state", x.shape)
        # DECODER
        # first deconv
        x = self.decoder[0](x)
        # symmetrical decoder
        for i in range(len(self.config)):
            print(f"Decoder {i+1}")
            x = self.decoder[i + 1](x)
            if self.deconv_config[i][3] == 2:
                print("skip connection here")
                # skip connection: most recently stored feature first
                x = torch.cat((x, features.pop()), dim=1)
        # final output
        x = self.decoder[-1](x)
        x = self.softmax(x)
        return x


# Instantiate the network with width multiplier 1.
# The constructor prints the decoder stage configuration as a side effect.
unet = MobileUNet(alpha=1)

[[  6 160   1   1]
 [  6  96   3   2]
 [  6  64   3   1]
 [  6  32   4   2]
 [  6  24   3   2]
 [  6  16   2   2]
 [  1   3   1   1]]


In [265]:
# Smoke test: push one all-zero batch of shape (1, 3, 480, 640) through the
# network and print the shape of the result.
x_batch = torch.zeros(1, 3, 480, 640)
print(f'output dimension: {unet(x_batch).shape}')

Encoder 1
torch.Size([1, 16, 240, 320])
Encoder 2
torch.Size([1, 24, 120, 160])
torch.Size([1, 24, 120, 160])
Encoder 3
torch.Size([1, 32, 60, 80])
torch.Size([1, 32, 60, 80])
torch.Size([1, 32, 60, 80])
Encoder 4
torch.Size([1, 64, 30, 40])
torch.Size([1, 64, 30, 40])
torch.Size([1, 64, 30, 40])
torch.Size([1, 64, 30, 40])
Encoder 5
torch.Size([1, 96, 30, 40])
torch.Size([1, 96, 30, 40])
torch.Size([1, 96, 30, 40])
Encoder 6
torch.Size([1, 160, 15, 20])
torch.Size([1, 160, 15, 20])
torch.Size([1, 160, 15, 20])
Encoder 7
torch.Size([1, 320, 15, 20])
hidden state torch.Size([1, 1280, 15, 20])
Decoder 1
torch.Size([1, 160, 15, 20])
Decoder 2
torch.Size([1, 160, 15, 20])
torch.Size([1, 160, 15, 20])
torch.Size([1, 96, 30, 40])
skip connection here
Decoder 3
torch.Size([1, 96, 30, 40])
torch.Size([1, 96, 30, 40])
torch.Size([1, 64, 30, 40])
Decoder 4
torch.Size([1, 64, 30, 40])
torch.Size([1, 64, 30, 40])
torch.Size([1, 64, 30, 40])
torch.Size([1, 32, 60, 80])
skip connection here
Decoder 

In [186]:
# ENCODER:
# bottleneck1: torch.Size([1, 16, 112, 112])
# bottleneck2: torch.Size([1, 24, 56, 56])
# bottleneck3: torch.Size([1, 32, 28, 28])
# bottleneck4: torch.Size([1, 64, 14, 14])
# bottleneck5: torch.Size([1, 96, 14, 14])
# bottleneck6: torch.Size([1, 160, 7, 7])
# bottleneck7: torch.Size([1, 320, 7, 7])
# final conv:  torch.Size([1, 1280, 7, 7])
# DECODER:
# torch.Size([1, 320, 7, 7])
# decoder1: torch.Size([1, 320, 7, 7])
# decoder2: torch.Size([1, 160, 14, 14])
# decoder3: torch.Size([1, 96, 14, 14])
# decoder4: torch.Size([1, 64, 28, 28])
# decoder5: torch.Size([1, 32, 56, 56])
# decoder6: torch.Size([1, 24, 112, 112])
# decoder7: torch.Size([1, 16, 112, 112])
# final deconv:  torch.Size([1, 3, 224, 224])
# output dimension: torch.Size([1, 3, 224, 224])

In [None]:
# Display the decoder's module tree.
unet.decoder

In [238]:
# Per-block storage / parameter report for each half of the network,
# followed by the totals for the whole model.
for section_title, section in (("ENCODER", unet.encoder), ("DECODER", unet.decoder)):
    print(section_title)
    for i, block in enumerate(section):
        print(f"----Block {i}----")
        size_all_mb = calculate_storage(block, show_buffer=False)
        print("model size: {:.3f}MB".format(size_all_mb))

        total_params = count_parameters(block)
        print(f"Total Trainable Params: {total_params}")

print("----UNET----")
size_all_mb = calculate_storage(unet)
print("model size: {:.3f}MB".format(size_all_mb))
total_params = count_parameters(unet)
print(f"Total Trainable Params: {total_params}")


# Normal:
# Buffer size: 0.145 MB
# model size: 31.730MB
# Total Trainable Params: 8279984

# Decomposed:
# Buffer size: 0.155 MB
# model size: 23.437MB
# Total Trainable Params: 6103125

# Inverted Residual:
# Buffer size: 0.293 MB
# model size: 33.209MB
# Total Trainable Params: 8628749

ENCODER
----Block 0----
model size: 0.004MB
Total Trainable Params: 928
----Block 1----
model size: 0.004MB
Total Trainable Params: 896
----Block 2----
model size: 0.057MB
Total Trainable Params: 13968
----Block 3----
model size: 0.160MB
Total Trainable Params: 39696
----Block 4----
model size: 0.724MB
Total Trainable Params: 183872
----Block 5----
model size: 1.182MB
Total Trainable Params: 303168
----Block 6----
model size: 3.076MB
Total Trainable Params: 795264
----Block 7----
model size: 1.825MB
Total Trainable Params: 473920
----Block 8----
model size: 1.582MB
Total Trainable Params: 412160
DECODER
----Block 0----
model size: 1.567MB
Total Trainable Params: 410240
----Block 1----
model size: 3.643MB
Total Trainable Params: 946880
----Block 2----
model size: 3.474MB
Total Trainable Params: 898432
----Block 3----
model size: 2.192MB
Total Trainable Params: 564992
----Block 4----
model size: 0.806MB
Total Trainable Params: 204736
----Block 5----
model size: 0.280MB
Total Trainable Pa

In [237]:
# Storage and trainable-parameter totals for the encoder and decoder halves.
for banner, half in (
    ("----ENCODER----", unet.encoder),
    ("----decoder----", unet.decoder),
):
    print(banner)
    size_all_mb = calculate_storage(half)
    print("model size: {:.3f}MB".format(size_all_mb))
    total_params = count_parameters(half)
    print(f"Total Trainable Params: {total_params}")

----ENCODER----
Buffer size: 0.131 MB
model size: 8.614MB
Total Trainable Params: 2223872
----decoder----
Buffer size: 0.162 MB
model size: 12.095MB
Total Trainable Params: 3128077


In [262]:
# Display the encoder's module tree.
unet.encoder

Sequential(
  (conv0): Conv2dNormActivation(
    (0): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
    (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): Dropout2d(p=0, inplace=False)
    (3): ReLU6(inplace=True)
  )
  (bottleneck1): Sequential(
    (0): InvertedResidual(
      (conv): Sequential(
        (0): Sequential(
          (0): Conv2dNormActivation(
            (0): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
            (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (2): Dropout2d(p=0.0, inplace=False)
            (3): ReLU6(inplace=True)
          )
          (1): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (2): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
      )
    )
    (1): Dropout2d(p=0, inplace=False)
  )
  (bottleneck2): Sequent