In [63]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class Expand(nn.Module):
    """
        Module wrapper that inserts a new leading axis of size 1.
    """

    def forward(self, x):
        # Equivalent to torch.unsqueeze(x, dim=0): (d0, d1, ...) -> (1, d0, d1, ...)
        return x.unsqueeze(0)

class Squeeze(nn.Module):
    """
        Module wrapper that drops axis 1 when that axis has size 1.
    """

    def forward(self, x):
        # Equivalent to torch.squeeze(x, dim=1); a no-op if x.size(1) != 1.
        return x.squeeze(1)

class SE_block(nn.Module):
    """
    Squeeze Excite block

    Computes one gate in (0, 1) per channel from the spatially averaged
    input and multiplies the input by it.

    Args:
        in_channels: number of channels of the input feature map.
        ratio: reduction factor of the bottleneck; the hidden width is
            ``in_channels // ratio`` but never less than 1.
    """
    def __init__(self, in_channels, ratio=16):
        super().__init__()
        self.avg_2d = F.adaptive_avg_pool2d

        # Was `x.size(1)`, which read an undefined name and therefore only
        # worked when a global `x` happened to exist in the notebook.
        self.filters = in_channels

        # in_channels // ratio is 0 when in_channels < ratio; a zero-width
        # bottleneck makes the gate a constant 0.5, so clamp to at least 1.
        hidden = max(1, in_channels // ratio)
        self.dense_block = nn.Sequential(
            nn.Linear(in_channels, hidden, bias=False),
            nn.ReLU(),
            nn.Linear(hidden, in_channels, bias=False),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Gate `x` of shape (batch, channels, H, W); returns the same shape."""
        batch, channels = x.size(0), x.size(1)
        se = self.avg_2d(x, (1, 1))                          # (batch, channels, 1, 1)
        se = self.dense_block(se.reshape(batch, channels))   # (batch, channels)
        # Restore the two singleton spatial axes so the gate broadcasts over H, W.
        return x * se.reshape(batch, channels, 1, 1)

class BN_block2d(nn.Module):
    """
        Two stacked (3x3 Conv2d -> BatchNorm2d -> ReLU) stages.

        The first stage maps in_channels to out_channels, the second keeps
        out_channels. padding=1 leaves the spatial size unchanged.
    """
    def __init__(self, in_channels, out_channels):
        super().__init__()
        stages = []
        stage_in = in_channels
        for _ in range(2):
            stages.append(nn.Conv2d(stage_in, out_channels, kernel_size=3, padding=1))
            stages.append(nn.BatchNorm2d(out_channels))
            stages.append(nn.ReLU())
            stage_in = out_channels
        self.bn_block = nn.Sequential(*stages)

    def forward(self, x):
        out = self.bn_block(x)
        return out

class BN_block3d(nn.Module):
    """
        Two stacked (3x3x3 Conv3d -> BatchNorm3d -> ReLU) stages.

        The first stage maps in_channels to out_channels, the second keeps
        out_channels. padding=1 leaves depth/height/width unchanged.
    """
    def __init__(self, in_channels, out_channels):
        super().__init__()
        stages = []
        stage_in = in_channels
        for _ in range(2):
            stages.append(nn.Conv3d(stage_in, out_channels, kernel_size=3, padding=1))
            stages.append(nn.BatchNorm3d(out_channels))
            stages.append(nn.ReLU())
            stage_in = out_channels
        self.bn_block = nn.Sequential(*stages)

    def forward(self, x):
        out = self.bn_block(x)
        return out

class D_SE_Add(nn.Module):
    """
        D_SE_Add block

        Collapses a 3-d feature map into a 2-d one, applies a squeeze-excite
        gate to it and to the 2-d feature map, and returns their sum.

        Args:
            in_channels: channel count of the 3-d input (Conv3d in_channels).
            out_channels: channel count produced by the 3-d branch. The 2-d
                input is added to that branch, so it must carry the same
                number of channels.
            mid_channels: in_channels of the Conv2d that follows Squeeze,
                i.e. the size of axis 2 of the 3-d input (after the 1-channel
                Conv3d and Squeeze that axis becomes the channel axis).
    """
    def __init__(self, in_channels, out_channels, mid_channels):
        super().__init__()

        # NOTE(review): conv3d_, Squeeze, conv2d_ and ReLU are never used in
        # forward(); they duplicate the layers inside squeeze_block_3d. They
        # are kept so state_dict keys and seeded initialisation order stay
        # unchanged -- consider removing them once no checkpoint depends on them.
        self.conv3d_ = nn.Conv3d(in_channels, 1, kernel_size=1, padding=0)
        self.Squeeze = Squeeze()
        self.conv2d_ = nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1)
        self.ReLU = nn.ReLU()

        # Gate for the 2-d input. It is sized with out_channels because the
        # 2-d input is summed with the out_channels-wide 3-d branch.
        self.SE_block = SE_block(out_channels)

        self.squeeze_block_3d = nn.Sequential(
            nn.Conv3d(in_channels, 1, kernel_size=1, padding=0),  # (B, C, D, H, W) -> (B, 1, D, H, W)
            Squeeze(),                                            # -> (B, D, H, W)
            nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(),
            # The preceding Conv2d emits out_channels, so the gate must match
            # that width (previously in_channels, which only worked when the
            # two were equal).
            SE_block(out_channels)
        )

    def forward(self, in_3d, in_2d):
        """Return SE(project(in_3d)) + SE(in_2d), a 2-d feature map."""
        in_2d = self.SE_block(in_2d)
        in_3d = self.squeeze_block_3d(in_3d)

        return in_3d + in_2d

def up_block(in_channels, out_channels):
    """
        Build a decoder step: 2x nearest-neighbour upsampling, then a 3x3
        Conv2d (in_channels -> out_channels, padding=1), then ReLU.
    """
    upsample = nn.Upsample(scale_factor=2, mode='nearest')
    conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
    activation = nn.ReLU()
    return nn.Sequential(upsample, conv, activation)

class DUnet(nn.Module):
    """
        U-Net style network with a 2-d encoder/decoder and a parallel 3-d
        encoder stream whose features are fused into the 2-d encoder by
        D_SE_Add at the second and third encoder levels.

        The 3-d stream sees the same input with its channel axis reinterpreted
        as a depth axis: (batch, C, H, W) -> (batch, 1, C, H, W).

        Args:
            in_channels: number of channels C of the 2-d input. Must be >= 4,
                because the 3-d stream halves the depth axis (= C) twice.

        Input:  (batch, in_channels, H, W). H and W should be multiples of 16:
                the encoder halves them four times and the decoder doubles
                them four times before adding the skip connections.
        Output: (batch, 1, H, W), values in (0, 1) from the final Sigmoid.
    """
    def __init__(self, in_channels):
        super().__init__()

        if in_channels < 4:
            raise ValueError(
                f"DUnet needs in_channels >= 4 (got {in_channels}): the 3-d "
                "stream halves the channel/depth axis twice"
            )

        self.in_channels = in_channels
        in_channels_3d = 1

        # Width of the first 3-d level. The 3-d levels must match the 2-d
        # levels they are fused with (in_channels * 16 and * 32), so the
        # widths are derived from in_channels; for in_channels == 4 this is
        # the previous hard-coded 32 / 64 / 128.
        width_3d = in_channels * 8
        # Depth of the 3-d feature maps after one / two MaxPool3d(2) steps
        # (2 and 1 for in_channels == 4); D_SE_Add needs them as mid_channels.
        depth_after_pool1 = in_channels // 2
        depth_after_pool2 = depth_after_pool1 // 2

        # Store an instance rather than the class, so forward() does not
        # build a new module on every call.
        self.Expand = Expand()
        self.MaxPool3d = nn.MaxPool3d(kernel_size=2)
        self.MaxPool2d = nn.MaxPool2d(kernel_size=2)
        self.Dropout = nn.Dropout(0.3)

        # 3d down
        self.bn_3d_1 = BN_block3d(in_channels_3d, width_3d)
        self.bn_3d_2 = BN_block3d(width_3d, width_3d * 2)
        self.bn_3d_3 = BN_block3d(width_3d * 2, width_3d * 4)

        # 2d down

        self.bn_2d_1 = BN_block2d(in_channels, in_channels * 8)

        self.bn_2d_2 = BN_block2d(in_channels * 8, in_channels * 16)
        self.se_add_2 = D_SE_Add(in_channels * 16, in_channels * 16, depth_after_pool1)

        self.bn_2d_3 = BN_block2d(in_channels * 16, in_channels * 32)
        self.se_add_3 = D_SE_Add(in_channels * 32, in_channels * 32, depth_after_pool2)

        self.bn_2d_4 = BN_block2d(in_channels * 32, in_channels * 64)
        self.bn_2d_5 = BN_block2d(in_channels * 64, in_channels * 128)

        # up

        self.up_block_1 = up_block(in_channels * 128, in_channels * 64)
        self.bn_2d_6 = BN_block2d(in_channels * 64, in_channels * 64)

        self.up_block_2 = up_block(in_channels * 64, in_channels * 32)
        self.bn_2d_7 = BN_block2d(in_channels * 32, in_channels * 32)

        self.up_block_3 = up_block(in_channels * 32, in_channels * 16)
        self.bn_2d_8 = BN_block2d(in_channels * 16, in_channels * 16)

        self.up_block_4 = up_block(in_channels * 16, in_channels * 8)
        self.bn_2d_9 = BN_block2d(in_channels * 8, in_channels * 8)

        self.conv_10 = nn.Sequential(
            nn.Conv2d(in_channels * 8, 1, kernel_size=1, padding=0),
            nn.Sigmoid()
        )

    def forward(self, x):
        """Run the network on x of shape (batch, in_channels, H, W)."""
        input3d = self.Expand(x) # 1, batch_size, in_channels, H, W
        input3d = input3d.permute(1, 0, 2, 3, 4) # batch_size, 1, in_channels, H, W

        # 3d Stream
        conv3d1 = self.bn_3d_1(input3d)
        pool3d1 = self.MaxPool3d(conv3d1)

        conv3d2 = self.bn_3d_2(pool3d1)
        pool3d2 = self.MaxPool3d(conv3d2)

        conv3d3 = self.bn_3d_3(pool3d2)

        # 2d Encoding
        conv1 = self.bn_2d_1(x)
        pool1 = self.MaxPool2d(conv1)

        conv2 = self.bn_2d_2(pool1)
        conv2 = self.se_add_2(conv3d2, conv2)  # fuse 3-d level 2 into 2-d level 2
        pool2 = self.MaxPool2d(conv2)

        conv3 = self.bn_2d_3(pool2)
        conv3 = self.se_add_3(conv3d3, conv3)  # fuse 3-d level 3 into 2-d level 3
        pool3 = self.MaxPool2d(conv3)

        conv4 = self.bn_2d_4(pool3)
        conv4 = self.Dropout(conv4)
        pool4 = self.MaxPool2d(conv4)

        conv5 = self.bn_2d_5(pool4)
        conv5 = self.Dropout(conv5)

        # Decoding (skip connections are additive, not concatenated)

        up6 = self.up_block_1(conv5)
        merge6 = conv4 + up6
        conv6 = self.bn_2d_6(merge6)

        up7 = self.up_block_2(conv6)
        merge7 = conv3 + up7
        conv7 = self.bn_2d_7(merge7)

        up8 = self.up_block_3(conv7)
        merge8 = conv2 + up8
        conv8 = self.bn_2d_8(merge8)

        up9 = self.up_block_4(conv8)
        merge9 = conv1 + up9
        conv9 = self.bn_2d_9(merge9)

        conv10 = self.conv_10(conv9)

        return conv10

In [66]:
# Smoke test: push one random batch through DUnet and check the output shape.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

batch_size = 8

# torch.Tensor(*sizes) returns uninitialised memory (it can hold NaN/inf), so
# draw a well-defined random batch instead and create it directly on `device`.
x = torch.randn(batch_size, 4, 192, 192, device=device)
model = DUnet(4).to(device)

output = model(x)

# One single-channel map per sample, at the input resolution.
assert output.shape == (batch_size, 1, 192, 192), output.shape
print(output.size()) # (batch_size, 1, 192, 192)

torch.Size([8, 64, 2, 96, 96]) torch.Size([8, 64, 96, 96])
torch.Size([8, 1, 192, 192])


In [4]:
# Show the tensor type string of `x` (element type and CPU/CUDA backend).
# NOTE(review): this cell's execution count (In [4]) is lower than that of the
# cell that defines `x`, so it was run against an earlier kernel state --
# confirm it still belongs here and re-run it in order.
x.type()

'torch.cuda.FloatTensor'