**SegResNet_VAE: 3D MRI brain tumor segmentation using autoencoder regularization**    
*Andriy Myronenko*   
[[paper](https://arxiv.org/abs/1810.11654)]   
MICCAI 2018   

In [1]:
import torch
import torch.nn as nn

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
class ResBlock(nn.Module):
    """Pre-activation residual block: (GroupNorm -> ReLU -> 3x3x3 conv) twice, plus a skip path.

    Args:
        in_dim: channels of the input tensor. Must be divisible by 8 (GroupNorm with 8 groups).
        out_dim: channels of the output tensor. Must be divisible by 8 (GroupNorm with 8 groups).

    The convolutions use stride 1 and padding 1, so spatial size is preserved.
    When ``in_dim == out_dim`` the input is added unchanged. Otherwise a 1x1x1
    convolution projects it to ``out_dim`` channels so the addition is shape-compatible.
    """

    def __init__(self, in_dim, out_dim) -> None:
        super().__init__()

        self.conv = nn.Sequential(
            nn.GroupNorm(num_channels=in_dim, num_groups=8),
            nn.ReLU(),
            nn.Conv3d(in_channels=in_dim, out_channels=out_dim, kernel_size=3, stride=1, padding=1),
            nn.GroupNorm(num_groups=8, num_channels=out_dim),
            nn.ReLU(),
            nn.Conv3d(in_channels=out_dim, out_channels=out_dim, kernel_size=3, stride=1, padding=1)
        )

        # Identity adds no parameters, so equal-width blocks keep the same state_dict.
        self.skip = (
            nn.Identity()
            if in_dim == out_dim
            else nn.Conv3d(in_channels=in_dim, out_channels=out_dim, kernel_size=1, stride=1, padding=0)
        )

    def forward(self, x):
        """Return ``conv(x) + skip(x)``, with the same spatial size as ``x``."""
        # Out-of-place add: the conv output is not mutated.
        return self.conv(x) + self.skip(x)

In [3]:
class Encoder(nn.Module):
    """ResNet-style 3D encoder with four resolution levels.

    Args:
        init_dim: channels of the input volume.
        hidden_dim: channel widths per level, shallow to deep; indices 0..3 are used.
            Each width must be divisible by 8 (GroupNorm with 8 groups).

    Structure:
        stem   : 3x3x3 conv, init_dim -> hidden_dim[0]
        layer1 : 1 ResBlock at full resolution
        layer2 : stride-2 downsample to hidden_dim[1] + 2 ResBlocks
        layer3 : stride-2 downsample to hidden_dim[2] + 2 ResBlocks
        layer4 : stride-2 downsample to hidden_dim[3] + 4 ResBlocks
    """

    def __init__(self, init_dim, hidden_dim) -> None:
        super().__init__()

        # Plain conv on the raw input. A GroupNorm(8) here would require
        # init_dim % 8 == 0 (it fails for e.g. 4 input channels). A ReLU here
        # would discard negative input values before any learned layer.
        self.init_conv = nn.Conv3d(in_channels=init_dim, out_channels=hidden_dim[0], kernel_size=3, stride=1, padding=1)

        self.layer1 = ResBlock(hidden_dim[0], hidden_dim[0])
        self.layer2 = self._down_stage(hidden_dim[0], hidden_dim[1], n_blocks=2)
        self.layer3 = self._down_stage(hidden_dim[1], hidden_dim[2], n_blocks=2)
        self.layer4 = self._down_stage(hidden_dim[2], hidden_dim[3], n_blocks=4)

    @staticmethod
    def _down_stage(in_dim, out_dim, n_blocks):
        """Build GroupNorm -> ReLU -> 3x3x3 stride-2 conv, followed by ``n_blocks`` ResBlocks."""
        # Stride 2 halves each spatial dim. The decoder relies on this: it upsamples x2
        # per level and adds the encoder map of the same level.
        return nn.Sequential(
            nn.GroupNorm(num_groups=8, num_channels=in_dim),
            nn.ReLU(),
            nn.Conv3d(in_channels=in_dim, out_channels=out_dim, kernel_size=3, stride=2, padding=1),
            *[ResBlock(out_dim, out_dim) for _ in range(n_blocks)],
        )

    def forward(self, x):
        """Encode ``x``.

        Returns:
            ``(h4, stage_outputs)``. ``h4`` is the deepest feature map (hidden_dim[3]
            channels, spatial size halved three times). ``stage_outputs`` maps
            'h1', 'h2', 'h3' to the outputs of layer1..layer3, used as decoder skips.
        """
        h = self.init_conv(x)

        h1 = self.layer1(h)
        h2 = self.layer2(h1)
        h3 = self.layer3(h2)
        h4 = self.layer4(h3)

        stage_outputs = {
            'h1': h1,
            'h2': h2,
            'h3': h3,
        }

        return h4, stage_outputs


In [4]:
class Decoder(nn.Module):
    """Segmentation decoder: three upsizing levels followed by a 1x1x1 output head.

    Each level does the following, in order:
      - reduces channels with a 1x1x1 convolution;
      - doubles the spatial size with trilinear upsampling;
      - adds the encoder feature map of the same level (from ``stage_outputs``);
      - refines the sum with a ResBlock.

    Args:
        out_dim: channels produced by the final 1x1x1 convolution.
        hidden_dim: per-level channel widths, shallow to deep; indices 0..3 are read.
    """

    def __init__(self, out_dim, hidden_dim) -> None:
        super().__init__()

        def upsize(c_in, c_out):
            # 1x1x1 channel reduction, then x2 trilinear spatial upsampling.
            return nn.Sequential(
                nn.Conv3d(in_channels=c_in, out_channels=c_out, kernel_size=1, stride=1, padding=0),
                nn.Upsample(scale_factor=2, mode='trilinear'),
            )

        self.up1 = upsize(hidden_dim[3], hidden_dim[2])
        self.conv1 = ResBlock(hidden_dim[2], hidden_dim[2])

        self.up2 = upsize(hidden_dim[2], hidden_dim[1])
        self.conv2 = ResBlock(hidden_dim[1], hidden_dim[1])

        self.up3 = upsize(hidden_dim[1], hidden_dim[0])
        self.conv3 = ResBlock(hidden_dim[0], hidden_dim[0])

        self.conv4 = nn.Sequential(
            nn.GroupNorm(num_groups=8, num_channels=hidden_dim[0]),
            nn.ReLU(),
            nn.Conv3d(in_channels=hidden_dim[0], out_channels=out_dim, kernel_size=1, stride=1, padding=0),
        )

    def forward(self, enc_out, stage_outputs):
        """Decode the deepest encoder map, adding skips 'h3', 'h2', 'h1' from ``stage_outputs``."""
        levels = (
            (self.up1, self.conv1, "h3"),
            (self.up2, self.conv2, "h2"),
            (self.up3, self.conv3, "h1"),
        )
        feat = enc_out
        for upsample, refine, skip_key in levels:
            # "... addition of encoder output" of the equivalent spatial level
            feat = refine(upsample(feat) + stage_outputs[skip_key])
        return self.conv4(feat)

In [None]:
class VAE_decoder(nn.Module):
    """Variational auto-encoder branch that reconstructs a volume from the deepest encoder features.

    Pipeline:
        VD     : GroupNorm -> ReLU -> 3x3x3 stride-2 conv to ``reduced_channels``
                 -> flatten -> Linear producing ``2 * latent_dim`` values
                 (first half: mean, second half: raw sigma).
        sample : training: z = mu + sigma * eps, eps ~ N(0, I) (reparameterisation);
                 eval: z = mu.
        VU     : Linear -> ReLU -> reshape -> 1x1x1 conv to hidden_dim[3]
                 -> trilinear resize to ``enc_output_size``.
        then three levels of (1x1x1 conv, x2 upsample, ResBlock) and a final 1x1x1 conv to ``out_dim``.

    Args:
        out_dim: channels of the reconstruction.
        hidden_dim: channel widths. hidden_dim[3] must match the channel count of
            ``enc_out`` (it feeds the first GroupNorm). ``hidden_dim[-1] // 2`` sets ``latent_dim``.
        enc_output_size: spatial size (D, H, W) of ``enc_out``. The bottleneck Linear
            layers are sized from it.
        reduced_channels: channels after the VD strided conv (16 in the paper).
    """

    def __init__(self, out_dim, hidden_dim, enc_output_size=(20, 24, 16), reduced_channels=16) -> None:
        super().__init__()
        import math

        self.latent_dim = hidden_dim[-1] // 2

        enc_output_size = tuple(int(s) for s in enc_output_size)
        # Output size of a 3x3x3, stride-2, padding-1 conv is ceil(s / 2).
        reduced_size = tuple((s + 1) // 2 for s in enc_output_size)
        flat_dim = reduced_channels * math.prod(reduced_size)

        # Strided conv + Linear keeps the bottleneck small. A conv whose kernel spans
        # the whole feature map would need hidden_dim[3]**2 * prod(enc_output_size) weights.
        self.VD = nn.Sequential(
            nn.GroupNorm(num_groups=8, num_channels=hidden_dim[3]),
            nn.ReLU(),
            nn.Conv3d(in_channels=hidden_dim[3], out_channels=reduced_channels, kernel_size=3, stride=2, padding=1),
            nn.Flatten(),
            nn.Linear(flat_dim, 2 * self.latent_dim),
        )

        # Linear back to a small spatial map before resizing. Upsampling a 1x1x1
        # tensor would only broadcast a constant over the volume.
        self.VU = nn.Sequential(
            nn.Linear(self.latent_dim, flat_dim),
            nn.ReLU(),
            nn.Unflatten(1, (reduced_channels, *reduced_size)),
            nn.Conv3d(in_channels=reduced_channels, out_channels=hidden_dim[3], kernel_size=1, stride=1, padding=0),
            nn.Upsample(size=enc_output_size, mode='trilinear'),
        )

        self.VUp2 = nn.Sequential(
            nn.Conv3d(in_channels=hidden_dim[3], out_channels=hidden_dim[2], kernel_size=1, stride=1, padding=0),
            nn.Upsample(scale_factor=2, mode='trilinear')
        )
        self.VBlock2 = ResBlock(in_dim=hidden_dim[2], out_dim=hidden_dim[2])

        self.VUp1 = nn.Sequential(
            nn.Conv3d(in_channels=hidden_dim[2], out_channels=hidden_dim[1], kernel_size=1, stride=1, padding=0),
            nn.Upsample(scale_factor=2, mode='trilinear')
        )
        self.VBlock1 = ResBlock(in_dim=hidden_dim[1], out_dim=hidden_dim[1])

        self.VUp0 = nn.Sequential(
            nn.Conv3d(in_channels=hidden_dim[1], out_channels=hidden_dim[0], kernel_size=1, stride=1, padding=0),
            nn.Upsample(scale_factor=2, mode='trilinear')
        )
        self.VBlock0 = ResBlock(in_dim=hidden_dim[0], out_dim=hidden_dim[0])

        self.Vend = nn.Conv3d(in_channels=hidden_dim[0], out_channels=out_dim, kernel_size=1, stride=1)

    def forward(self, enc_out, return_stats=False):
        """Reconstruct from ``enc_out``.

        Args:
            enc_out: deepest encoder feature map. Its spatial size should equal ``enc_output_size``.
            return_stats: if True, also return ``(mu, sigma)`` of the latent Gaussian
                (e.g. for a KL loss term).

        Returns:
            ``h`` (the reconstruction), or ``(h, mu, sigma)`` when ``return_stats`` is True.
        """
        stats = self.VD(enc_out)
        # Split along the feature axis (dim 1), not the batch axis.
        mu = stats[:, :self.latent_dim]
        # Softplus keeps sigma strictly positive; the raw Linear output may be negative.
        sigma = nn.functional.softplus(stats[:, self.latent_dim:])

        if self.training:
            # Reparameterisation trick: gradients reach mu and sigma.
            # torch.normal(mean, std) does not backpropagate through the sample.
            z = mu + sigma * torch.randn_like(sigma)
        else:
            z = mu

        h = self.VU(z)
        h = self.VUp2(h)
        h = self.VBlock2(h)
        h = self.VUp1(h)
        h = self.VBlock1(h)
        h = self.VUp0(h)
        h = self.VBlock0(h)

        h = self.Vend(h)

        if return_stats:
            return h, mu, sigma
        return h

In [None]:
class SegResNet_VAE(nn.Module):
    """Encoder with a segmentation decoder and a VAE reconstruction branch.

    Args:
        init_dim: channels of the input volume.
        out_dim: channels of the segmentation output.
        VAE_out_dim: channels of the VAE reconstruction.
        enc_out_size: spatial size (D, H, W) of the encoder output. It is passed to
            the VAE branch as ``enc_output_size``.
        hidden_dim: per-level channel widths, shallow to deep.

    In training mode ``forward`` returns ``(segmentation, reconstruction)``.
    In eval mode it returns only the segmentation; the VAE branch is skipped.
    """

    def __init__(self, init_dim, out_dim, VAE_out_dim, enc_out_size, hidden_dim: list) -> None:
        super().__init__()

        self.encoder = Encoder(init_dim=init_dim, hidden_dim=hidden_dim)
        self.decoder = Decoder(out_dim=out_dim, hidden_dim=hidden_dim)
        # Forward enc_out_size so the VAE branch is sized for the actual encoder
        # output instead of silently falling back to its default.
        self.VAE_dec = VAE_decoder(out_dim=VAE_out_dim, hidden_dim=hidden_dim, enc_output_size=enc_out_size)

    def forward(self, x):
        """Return the segmentation, plus the VAE reconstruction when ``self.training``."""
        enc_out, stage_outputs = self.encoder(x)
        out = self.decoder(enc_out, stage_outputs)

        if not self.training:
            return out

        # The VAE branch only regularises the shared encoder during training.
        VAE_out = self.VAE_dec(enc_out)
        return out, VAE_out