# In [7]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class Expand(nn.Module):
    """Prepend a singleton axis: a tensor of shape S becomes (1, *S)."""

    def forward(self, x):
        """Return a view of ``x`` with a new leading dimension of size 1."""
        return x.unsqueeze(0)

class Squeeze(nn.Module):
    """Drop axis 1 when it has size 1; otherwise return the input unchanged."""

    def forward(self, x):
        """Squeeze dimension 1 of ``x`` (no-op if that dimension is not 1)."""
        return x.squeeze(1)

class SE_block(nn.Module):
    """Squeeze-and-Excitation block.

    Global-average-pools the input over its spatial dims, feeds the per-channel
    descriptor through a bottleneck MLP (Linear -> ReLU -> Linear -> Sigmoid)
    and rescales every channel of the input by the resulting gate in (0, 1).

    Args:
        ratio: channel-reduction factor of the bottleneck. The hidden width is
            ``channels // ratio``, clamped to at least 1.
        channels: number of input channels. When given, the MLP is built here.
            When ``None`` (the default), it is built on the first ``forward``
            call on the input's device/dtype and registered as a sub-module.
            In that case, run one forward before creating an optimizer or
            loading a state dict.
    """

    def __init__(self, ratio=16, channels=None):
        super(SE_block, self).__init__()
        self.ratio = ratio
        self.channels = channels
        # Registered sub-module (not rebuilt per call), so its weights persist
        # between calls, appear in parameters()/state_dict and follow .to().
        self.excitation = (
            None if channels is None
            else self.dense_block(in_channels=channels, out_channels=channels)
        )

    def dense_block(self, in_channels, out_channels):
        """Return the bottleneck MLP mapping ``in_channels`` features to ``out_channels`` gates."""
        # Clamp so in_channels < ratio does not produce a zero-width layer.
        hidden = max(1, in_channels // self.ratio)
        return nn.Sequential(
            nn.Linear(in_channels, hidden, bias=False),
            nn.ReLU(),
            nn.Linear(hidden, out_channels, bias=False),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Rescale the channels of ``x`` (N, C, H, W); the output has the shape of ``x``.

        Raises:
            ValueError: if ``x`` has a different channel count than the one
                the block was built for.
        """
        filters = x.size(1)
        if self.excitation is None:
            # Lazy construction: build once on the input's device/dtype, then reuse.
            self.channels = filters
            self.excitation = self.dense_block(filters, filters).to(
                device=x.device, dtype=x.dtype
            )
        elif filters != self.channels:
            raise ValueError(
                f"SE_block was built for {self.channels} channels, got {filters}"
            )
        se = F.adaptive_avg_pool2d(x, (1, 1)).flatten(1)  # (N, C)
        se = self.excitation(se)  # (N, C) gates in (0, 1)
        return x * se.view(x.size(0), filters, 1, 1)

class BN_block2d(nn.Module):
    """Two stacked (3x3 Conv2d -> BatchNorm2d -> ReLU) stages.

    With a 3x3 kernel and stride 1, ``padding=1`` already is "same" padding.
    The spatial size (H, W) is therefore preserved; only the channel count
    changes, from ``in_channels`` to ``out_channels``.
    """

    def __init__(self, in_channels, out_channels):
        super(BN_block2d, self).__init__()
        stages = []
        # First stage changes the channel count; the second keeps it.
        for stage_in in (in_channels, out_channels):
            stages.extend([
                nn.Conv2d(stage_in, out_channels, kernel_size=3, padding=1),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(),
            ])
        self.bn_block = nn.Sequential(*stages)

    def forward(self, x):
        """Map (N, in_channels, H, W) to (N, out_channels, H, W)."""
        return self.bn_block(x)

class BN_block3d(nn.Module):
    """Two stacked (3x3x3 Conv3d -> BatchNorm3d -> ReLU) stages.

    ``padding=1`` keeps (D, H, W) unchanged; only the channel count changes,
    from ``in_channels`` to ``out_channels``.
    """

    def __init__(self, in_channels, out_channels):
        super(BN_block3d, self).__init__()

        def conv_bn_relu(stage_in):
            # One conv/norm/activation stage producing out_channels feature maps.
            return [
                nn.Conv3d(stage_in, out_channels, kernel_size=3, padding=1),
                nn.BatchNorm3d(out_channels),
                nn.ReLU(),
            ]

        self.bn_block = nn.Sequential(
            *conv_bn_relu(in_channels),
            *conv_bn_relu(out_channels),
        )

    def forward(self, x):
        """Map (N, in_channels, D, H, W) to (N, out_channels, D, H, W)."""
        return self.bn_block(x)

class D_SE_Add(nn.Module):
    """Fuse a 3-D feature map into a 2-D one: squeeze, SE-recalibrate, add.

    Steps:
      - A 1x1x1 conv collapses the 3-D input (N, in_channels, D, H, W) to one
        channel.
      - Axis 1 is then dropped, so the depth axis acts as 2-D channels:
        (N, D, H, W).
      - A 3x3 conv + ReLU + SE block maps that to ``out_channels``.
      - The result is added to the SE-recalibrated 2-D input, which must
        therefore be (N, out_channels, H, W).

    Args:
        in_channels: channel count of the 3-D input.
        out_channels: channel count of the 2-D input and of the output.
        depth: depth D of the 3-D input. When ``None`` (default), the 2-D conv
            infers it on the first forward (``nn.LazyConv2d``). In that case,
            run one forward before creating an optimizer.
    """

    def __init__(self, in_channels, out_channels, depth=None):
        super(D_SE_Add, self).__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels

        self.squeeze_block_3d = nn.Sequential(
            nn.Conv3d(in_channels, 1, kernel_size=1, padding=0),
            Squeeze()
        )
        # Built once here rather than on every forward call, so the weights are
        # registered, persist between calls, train, and follow .to(device).
        if depth is None:
            self.block_2d = nn.Sequential(
                nn.LazyConv2d(out_channels, kernel_size=3, padding=1),
                nn.ReLU(),
                SE_block(),
            )
        else:
            self.block_2d = self.squeeze_block_2d(depth, out_channels)
        self.se_2d = SE_block()

    def squeeze_block_2d(self, in_channels, out_channels):
        """Return a 3x3 Conv2d (in -> out channels) + ReLU + SE block."""
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(),
            SE_block()
        )

    def forward(self, in_3d, in_2d):
        """Return the squeezed/recalibrated ``in_3d`` plus the SE-recalibrated ``in_2d``."""
        in_3d = self.squeeze_block_3d(in_3d)  # depth axis now plays the channel role
        return self.block_2d(in_3d) + self.se_2d(in_2d)

def up_block(in_channels, out_channels):
    """Return a 2x nearest-neighbour upsample followed by a 3x3 conv and ReLU.

    Maps (N, in_channels, H, W) to (N, out_channels, 2H, 2W).
    """
    upsample = nn.Upsample(scale_factor=2, mode='nearest')
    conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
    return nn.Sequential(upsample, conv, nn.ReLU())

class DUnet(nn.Module):
    """2-D U-Net-style encoder/decoder fed by a 3-D side stream.

    The ``in_channels`` input slices are also stacked along a depth axis and
    run through a small 3-D conv stream. Its features are fused into the 2-D
    encoder by ``D_SE_Add`` at 1/2 and 1/4 resolution. Skip connections are
    additive, and the head is a 1x1 conv followed by a sigmoid.

    All layers are created here, once, so they are registered sub-modules:
    their weights persist between calls, are returned by ``parameters()``,
    follow ``.to(device)`` and respect ``train()``/``eval()`` (dropout and
    batch-norm statistics).

    Args:
        in_channels: number of 2-D input channels (slices) C.

    Input shape is (N, C, H, W); output shape is (N, 1, H, W) with values in
    (0, 1). H and W must be divisible by 16, because four 2x poolings are
    later added back to upsampled features. C must be at least 4, because the
    3-D stream pools the depth axis twice.
    """

    def __init__(self, in_channels):
        super(DUnet, self).__init__()

        self.in_channels = in_channels
        c = in_channels

        # 3-D stream: the input is a 1-channel volume whose depth is C.
        self.conv3d1 = BN_block3d(1, 32)
        self.conv3d2 = BN_block3d(32, 64)
        self.conv3d3 = BN_block3d(64, 128)
        self.pool3d = nn.MaxPool3d(kernel_size=2)

        # 2-D encoder. D_SE_Add's first argument is the channel count of the
        # 3-D feature it squeezes (64 / 128), independent of c.
        self.conv1 = BN_block2d(c, c * 8)
        self.conv2 = BN_block2d(c * 8, c * 16)
        self.fuse2 = D_SE_Add(64, c * 16)
        self.conv3 = BN_block2d(c * 16, c * 32)
        self.fuse3 = D_SE_Add(128, c * 32)
        self.conv4 = BN_block2d(c * 32, c * 64)
        self.conv5 = BN_block2d(c * 64, c * 128)
        self.pool2d = nn.MaxPool2d(kernel_size=2)
        self.dropout = nn.Dropout(0.3)

        # 2-D decoder with additive skip connections.
        self.up6 = up_block(c * 128, c * 64)
        self.conv6 = BN_block2d(c * 64, c * 64)
        self.up7 = up_block(c * 64, c * 32)
        self.conv7 = BN_block2d(c * 32, c * 32)
        self.up8 = up_block(c * 32, c * 16)
        self.conv8 = BN_block2d(c * 16, c * 16)
        self.up9 = up_block(c * 16, c * 8)
        self.conv9 = BN_block2d(c * 8, c * 8)

        self.head = nn.Sequential(
            nn.Conv2d(c * 8, 1, kernel_size=1, padding=0),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Map ``x`` of shape (N, C, H, W) to a (N, 1, H, W) sigmoid output."""
        # (N, C, H, W) -> (N, 1, C, H, W): the C slices become the depth axis.
        input3d = x.unsqueeze(1)

        # 3-D stream
        conv3d1 = self.conv3d1(input3d)
        conv3d2 = self.conv3d2(self.pool3d(conv3d1))  # 1/2 spatial resolution
        conv3d3 = self.conv3d3(self.pool3d(conv3d2))  # 1/4 spatial resolution

        # 2-D encoder, fusing 3-D features at matching resolutions
        conv1 = self.conv1(x)
        conv2 = self.fuse2(conv3d2, self.conv2(self.pool2d(conv1)))
        conv3 = self.fuse3(conv3d3, self.conv3(self.pool2d(conv2)))
        conv4 = self.dropout(self.conv4(self.pool2d(conv3)))
        conv5 = self.dropout(self.conv5(self.pool2d(conv4)))

        # 2-D decoder
        conv6 = self.conv6(conv4 + self.up6(conv5))
        conv7 = self.conv7(conv3 + self.up7(conv6))
        conv8 = self.conv8(conv2 + self.up8(conv7))
        conv9 = self.conv9(conv1 + self.up9(conv8))

        return self.head(conv9)

# In [8]:
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

batch_size = 4

# Shape smoke test. torch.randn gives well-defined values; torch.Tensor(*sizes)
# returns uninitialised memory that may contain NaN/Inf.
x = torch.randn(batch_size, 4, 192, 192, device=device)
model = DUnet(4).to(device)
model.eval()  # inference mode: dropout off, batch-norm uses running stats

with torch.no_grad():  # no autograd graph needed just to check shapes
    output = model(x)

print(output.size())  # (batch_size, 1, 192, 192)
assert output.shape == (batch_size, 1, 192, 192)

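# Recorded output of the cell above. The first line presumably came from a
# debug print of input3d's shape that is no longer in the code:
# torch.Size([4, 1, 4, 192, 192])
# torch.Size([4, 1, 192, 192])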
