**V-Net: Fully Convolutional Neural Networks for Volumetric Medical Image Segmentation** (implementation in progress)   
*Fausto Milletari, Nassir Navab, Seyed-Ahmad Ahmadi*   
[[paper](https://arxiv.org/abs/1606.04797)]   
3DV 2016 

In [10]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [11]:
# Encoder; 'compression path'
# ... The left part of the network consists of a compression path, ...
class CompressionPath(nn.Module):
    """V-Net encoder: the left-hand "compression path".

    Five resolution stages. Stage ``k`` applies ``conv{k}`` (1, 2, 3, 3 and 3
    5x5x5 convolutions with padding 2, each followed by PReLU) and adds the
    stage input back as a residual. Between stages, ``down{k}`` is a 2x2x2
    stride-2 convolution + PReLU that halves each spatial dimension (floor)
    and doubles the channel count: 16 -> 32 -> 64 -> 128 -> 256.

    ``forward`` expects a single-channel volume ``(N, 1, D, H, W)`` (``conv1``
    has exactly one input channel). Because every down-sampling floors odd
    sizes, D, H and W should be divisible by 16 so the stage outputs line up
    again when the features are up-sampled by a factor of 2 per stage.
    """

    def __init__(self) -> None:
        super().__init__()

        # Attribute names and layer order match the original definition so
        # parameter names (state_dict keys) are unchanged.
        self.conv1 = self._conv_stage(1, 16, n_convs=1)
        self.down1 = self._down(16, 32)

        self.conv2 = self._conv_stage(32, 32, n_convs=2)
        self.down2 = self._down(32, 64)

        self.conv3 = self._conv_stage(64, 64, n_convs=3)
        self.down3 = self._down(64, 128)

        self.conv4 = self._conv_stage(128, 128, n_convs=3)
        self.down4 = self._down(128, 256)

        self.conv5 = self._conv_stage(256, 256, n_convs=3)

    @staticmethod
    def _conv_stage(in_channels: int, channels: int, n_convs: int) -> nn.Sequential:
        """Build ``n_convs`` x (5x5x5 conv, padding 2 -> PReLU); only the first conv changes width."""
        layers = []
        for i in range(n_convs):
            conv_in = in_channels if i == 0 else channels
            layers.append(nn.Conv3d(conv_in, channels, kernel_size=5, stride=1, padding=2))
            layers.append(nn.PReLU())
        return nn.Sequential(*layers)

    @staticmethod
    def _down(in_channels: int, out_channels: int) -> nn.Sequential:
        """Build a 2x2x2 stride-2 down-sampling convolution followed by PReLU."""
        # nn.PReLU must be *instantiated*: nn.Sequential only accepts Module
        # instances and raises TypeError when handed the class itself.
        return nn.Sequential(
            nn.Conv3d(in_channels, out_channels, kernel_size=2, stride=2, padding=0),
            nn.PReLU(),
        )

    def forward(self, x):
        """Encode ``x``.

        Returns:
            ``(h5, stage_outputs)``. ``h5`` holds the 256-channel bottleneck
            features. ``stage_outputs`` is ``[h1, h2, h3, h4]``, the residual
            outputs of stages 1-4 ordered finest to coarsest, with 16, 32, 64
            and 128 channels. They are forwarded to the decoder as skip
            connections.
        """
        # Stage 1: x has one channel and conv1's output has 16, so the residual
        # addition broadcasts x across all 16 channels.
        h1 = self.conv1(x) + x
        d1 = self.down1(h1)

        h2 = self.conv2(d1) + d1
        d2 = self.down2(h2)

        h3 = self.conv3(d2) + d2
        d3 = self.down3(h3)

        h4 = self.conv4(d3) + d3
        d4 = self.down4(h4)

        h5 = self.conv5(d4) + d4

        stage_outputs = [h1, h2, h3, h4]
        return h5, stage_outputs


In [12]:
# decoder; 'decompression path'
# ... while the right part decompresses the signal until its original size is reached.
class DecompressionPath(nn.Module):
    """V-Net decoder: the right-hand "decompression path".

    Four stages. Each stage:
      1. up-samples with a 2x2x2 stride-2 transposed convolution + PReLU that
         doubles every spatial dimension and halves the channel count,
      2. concatenates the matching encoder stage output (same channel count)
         along dim 1, which restores the channel count, and
      3. runs three 5x5x5 convolutions (padding 2, PReLU) and adds the
         concatenated tensor back as a residual.
    Stage widths after concatenation are 256, 128, 64 and 32 channels. A final
    1x1x1 convolution maps the 32 full-resolution channels to ``num_classes``
    per-voxel scores. No activation is applied, so the caller can apply
    softmax or use the raw logits.

    Args:
        num_classes: number of output feature maps. Defaults to 2, matching the
            two output maps described in the V-Net paper.
    """

    def __init__(self, num_classes: int = 2) -> None:
        super().__init__()

        # Input width of each up-conv equals the previous stage's width
        # (256 for the bottleneck, then 256, 128, 64 after each decoder stage).
        self.up1 = self._up(256, 128)
        self.conv1 = self._conv_stage(256)

        self.up2 = self._up(256, 64)
        self.conv2 = self._conv_stage(128)

        self.up3 = self._up(128, 32)
        self.conv3 = self._conv_stage(64)

        self.up4 = self._up(64, 16)
        self.conv4 = self._conv_stage(32)

        self.conv5 = nn.Conv3d(32, num_classes, kernel_size=1, stride=1, padding=0)

    @staticmethod
    def _up(in_channels: int, out_channels: int) -> nn.Sequential:
        """Build a 2x2x2 stride-2 transposed convolution followed by PReLU."""
        return nn.Sequential(
            nn.ConvTranspose3d(in_channels, out_channels, kernel_size=2, stride=2, padding=0),
            nn.PReLU(),
        )

    @staticmethod
    def _conv_stage(channels: int, n_convs: int = 3) -> nn.Sequential:
        """Build ``n_convs`` x (5x5x5 conv, padding 2 -> PReLU) at constant width."""
        layers = []
        for _ in range(n_convs):
            layers.append(nn.Conv3d(channels, channels, kernel_size=5, stride=1, padding=2))
            layers.append(nn.PReLU())
        return nn.Sequential(*layers)

    def forward(self, enc_out, stage_outputs):
        """Decode bottleneck features back to full resolution.

        Args:
            enc_out: 256-channel bottleneck features.
            stage_outputs: skip-connection features indexed from the end.
                ``stage_outputs[-1]``..``[-4]`` must have 128, 64, 32 and 16
                channels and spatial sizes 2, 4, 8 and 16 times that of
                ``enc_out``. This presumably is the list returned by the
                encoder.

        Returns:
            Tensor with ``num_classes`` channels at the resolution of
            ``stage_outputs[-4]``.
        """
        stages = (
            (self.up1, self.conv1),
            (self.up2, self.conv2),
            (self.up3, self.conv3),
            (self.up4, self.conv4),
        )
        h = enc_out
        for depth, (up, conv) in enumerate(stages, start=1):
            merged = torch.cat([up(h), stage_outputs[-depth]], dim=1)
            # Residual over the whole stage. The conv stack keeps the width of
            # `merged`, so that is the tensor added back. The up-sampled tensor
            # alone has only half as many channels and cannot be broadcast.
            h = conv(merged) + merged
        return self.conv5(h)

In [14]:
class VNet(nn.Module):
    """Full V-Net: a CompressionPath encoder followed by a DecompressionPath decoder."""

    def __init__(self) -> None:
        super().__init__()

        self.encoder = CompressionPath()
        self.decoder = DecompressionPath()

    def forward(self, x, activate=False):
        """Run the network on ``x``.

        Args:
            x: input volume, passed unchanged to the encoder.
            activate: when True, apply a softmax over dim 1 (the channel
                dimension of the decoder output) before returning.

        Returns:
            The decoder output, or its channel-wise softmax if ``activate``.
        """
        bottleneck, skips = self.encoder(x)
        logits = self.decoder(bottleneck, skips)

        if activate:
            return F.softmax(logits, dim=1)
        return logits
        

In [16]:
# dice loss
def DiceLoss(inputs, targets, smooth=1):
    """Soft Dice loss.

    Returns ``1 - D`` where

        D = (2 * sum(p * g) + smooth) / (sum(p**2) + sum(g**2) + smooth),

    ``p = sigmoid(inputs)`` and ``g = targets``. Both are flattened over *all*
    elements, batch included.

    NOTE(review): because every element goes through an independent sigmoid
    and is pooled together, this assumes a single foreground score map (binary
    segmentation). A multi-channel output would need per-class handling.

    Args:
        inputs: raw scores (logits); the sigmoid is applied here.
        targets: ground-truth mask with the same number of elements as
            ``inputs``, presumably with values in {0, 1}. Any numeric or bool
            dtype is accepted; it is cast to the dtype of the probabilities.
        smooth: constant added to both numerator and denominator.

    Returns:
        Scalar tensor ``1 - D``.
    """
    # reshape (unlike view) also handles non-contiguous tensors, e.g. masks
    # produced by permute()/transpose().
    probs = torch.sigmoid(inputs).reshape(-1)
    truth = targets.reshape(-1).to(probs.dtype)

    intersection = (probs * truth).sum()
    denominator = (probs ** 2).sum() + (truth ** 2).sum() + smooth
    dice_coef = (2.0 * intersection + smooth) / denominator

    return 1 - dice_coef