**DynUNet: Optimized U-Net for Brain Tumor Segmentation**    
*Michał Futrega, Alexandre Milesi, Michal Marcinkiewicz, Pablo Ribalta*   
[[paper](https://arxiv.org/abs/2110.03352)]   
MICCAI-BraTS 2021   

In [2]:
import torch
import torch.nn as nn

from typing import Union, List, Tuple

  from .autonotebook import tqdm as notebook_tqdm


In [50]:
class ConvBlock(nn.Module):
    """Conv3d -> InstanceNorm3d -> SiLU building block.

    Args:
        in_dim: number of input channels.
        out_dim: number of output channels.
        stride: convolution stride; with the default odd kernel, stride 1
            keeps the spatial size and stride 2 halves every even spatial size.
        kernel_size: cubic kernel size (default 3). Padding is
            ``kernel_size // 2`` so odd kernels give "same"-size output at
            stride 1.
    """

    def __init__(self, in_dim, out_dim, stride=1, kernel_size=3) -> None:
        super().__init__()
        # Flag for a possible residual connection (shapes would only match
        # when channels and resolution are unchanged). NOTE: forward() does
        # not currently apply a residual.
        self.use_residual = in_dim == out_dim and stride == 1

        # kernel_size is a required argument of nn.Conv3d; padding keeps the
        # spatial size aligned with the stride so skip connections line up.
        self.conv = nn.Conv3d(
            in_channels=in_dim,
            out_channels=out_dim,
            kernel_size=kernel_size,
            stride=stride,
            padding=kernel_size // 2,
        )
        self.norm = nn.InstanceNorm3d(num_features=out_dim)
        self.act = nn.SiLU()

    def forward(self, x):
        """Apply convolution, instance normalization and SiLU to ``x``."""
        h = self.conv(x)
        h = self.norm(h)
        return self.act(h)

In [3]:


class DynUNet_encoder(nn.Module):
    """Seven-stage encoder built from pairs of ConvBlocks.

    Stage 1 uses stride 1 in both blocks. Stages 2-7 each open with a
    stride-2 ConvBlock followed by a stride-1 ConvBlock. The stage modules
    are registered as ``layer1`` ... ``layer7``.

    Args:
        init_dim: number of input channels.
        hidden_dim: channel widths; entries 0-6 are read (one per stage).
    """

    def __init__(self, init_dim, hidden_dim:Union[List, Tuple]) -> None:
        super().__init__()

        prev_width = init_dim
        for stage_idx in range(7):
            width = hidden_dim[stage_idx]
            entry_stride = 1 if stage_idx == 0 else 2
            stage = nn.Sequential(
                ConvBlock(prev_width, width, stride=entry_stride),
                ConvBlock(width, width),
            )
            # nn.Module.__setattr__ registers the stage as a submodule.
            setattr(self, f"layer{stage_idx + 1}", stage)
            prev_width = width

    def forward(self, x):
        """Run all seven stages.

        Returns:
            The output of ``layer7`` and a dict mapping ``'h1'`` ... ``'h6'``
            to the outputs of ``layer1`` ... ``layer6``.
        """
        skips = {}
        feat = x
        for stage_no in range(1, 8):
            feat = getattr(self, f"layer{stage_no}")(feat)
            if stage_no != 7:
                skips[f"h{stage_no}"] = feat
        return feat, skips

In [53]:
class DynUNet_decoder(nn.Module):
    """Six-stage decoder with skip connections, plus a final 1x1x1 conv.

    Decoder stage k (k = 1..6):
      * ``up{k}`` upsamples from ``hidden_dim[7-k]`` to ``hidden_dim[6-k]``
        channels with a stride-2 transposed convolution;
      * the result is concatenated with the encoder skip ``h{7-k}`` along the
        channel axis;
      * ``conv{k}`` (two ConvBlocks) maps ``2 * hidden_dim[6-k]`` channels
        back to ``hidden_dim[6-k]``.
    ``conv7`` then projects ``hidden_dim[0]`` channels to ``out_dim``.

    Args:
        out_dim: number of output channels.
        hidden_dim: channel widths; entries 0-6 are read.
    """

    def __init__(self, out_dim, hidden_dim:Union[List, Tuple]) -> None:
        super().__init__()

        for k in range(1, 7):
            deep = hidden_dim[7 - k]
            shallow = hidden_dim[6 - k]
            # Transposed-conv output size is (D-1)*2 - 2*1 + 3 + output_padding
            # = 2D - 1 + output_padding; output_padding=1 gives exactly 2D, so
            # the result matches the skip tensor it is concatenated with
            # (assuming the encoder halved each dimension exactly).
            setattr(self, f"up{k}", nn.ConvTranspose3d(
                in_channels=deep,
                out_channels=shallow,
                kernel_size=3,
                stride=2,
                padding=1,
                output_padding=1,
            ))
            # After concatenation the tensor has `shallow` channels from the
            # upsampler plus `shallow` channels from the skip.
            setattr(self, f"conv{k}", nn.Sequential(
                ConvBlock(in_dim=2 * shallow, out_dim=shallow),
                ConvBlock(in_dim=shallow, out_dim=shallow),
            ))

        self.conv7 = nn.Conv3d(in_channels=hidden_dim[0], out_channels=out_dim, kernel_size=1, stride=1, padding=0)

    def forward(self, enc_out, stage_outputs):
        """Decode ``enc_out`` using the encoder skips ``stage_outputs['h1'..'h6']``.

        NOTE(review): for the concatenations to line up, every spatial
        dimension of the original network input presumably must be divisible
        by 2**6 = 64 (six stride-2 stages) — confirm against the encoder.
        """
        h = enc_out
        for k in range(1, 7):
            h = getattr(self, f"up{k}")(h)
            h = torch.cat([h, stage_outputs[f"h{7 - k}"]], dim=1)
            h = getattr(self, f"conv{k}")(h)
        return self.conv7(h)

In [None]:
class DynUNet(nn.Module):
    """U-Net composed of a DynUNet_encoder and a DynUNet_decoder.

    Args:
        init_dim: number of input channels passed to the encoder.
        out_dim: number of output channels produced by the decoder.
        hidden_dim: per-stage channel widths shared by encoder and decoder.
    """

    def __init__(self, init_dim, out_dim, hidden_dim:Union[Tuple, List]) -> None:
        super().__init__()
        self.encoder = DynUNet_encoder(init_dim=init_dim, hidden_dim=hidden_dim)
        self.decoder = DynUNet_decoder(out_dim=out_dim, hidden_dim=hidden_dim)

    def forward(self, x):
        """Encode ``x``, then decode it using the encoder's skip features."""
        bottleneck, skips = self.encoder(x)
        return self.decoder(bottleneck, skips)