Inspired by https://github.com/kevinlu1211/pytorch-unet-resnet-50-encoder/blob/master/u_net_resnet_50_encoder.py

In [None]:
import torch
import torch.nn as nn
import torchvision
import torch.nn.functional as F

In [None]:
# Conv -> BN -> activation -> Dropout
class ConvBlock(nn.Module):
    """Conv2d -> BatchNorm2d -> activation -> Dropout.

    Args:
        channels_in: Number of input channels.
        channels_out: Number of output channels.
        padding, kernel_size, stride: Forwarded to ``nn.Conv2d``. With the
            defaults (3x3 kernel, padding 1, stride 1) H and W are preserved.
        act_fct: Activation module applied after batch norm. ``None``
            (default) creates a fresh ``nn.ReLU()`` for this block.
        dropout_p: Dropout probability applied last (default 0.3).
    """

    def __init__(self, channels_in, channels_out, padding=1, kernel_size=3, stride=1, act_fct=None,
                 dropout_p=0.3):
        super().__init__()
        self.conv = nn.Conv2d(channels_in, channels_out, padding=padding, kernel_size=kernel_size, stride=stride)
        self.bn = nn.BatchNorm2d(channels_out)
        # A default of `nn.ReLU()` in the signature would be evaluated once at
        # definition time, so the same module instance would be shared (and
        # registered) by every ConvBlock. Build one per block instead.
        self.act_fct = nn.ReLU() if act_fct is None else act_fct
        self.dropout = nn.Dropout(dropout_p)

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        x = self.act_fct(x)
        x = self.dropout(x)
        return x

In [None]:
# U-Net bottleneck: downsample, then two conv blocks
class Bridge(nn.Module):
    """Bottleneck stage of the U-Net.

    Applies a 3x3 max-pool with stride 2 and padding 1, which maps H to
    ceil(H / 2). It then applies two ``ConvBlock``s; the first changes
    ``channels_in`` to ``channels_out``.
    """

    def __init__(self, channels_in, channels_out):
        super().__init__()
        self.max_pool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.conv1 = ConvBlock(channels_in, channels_out)
        self.conv2 = ConvBlock(channels_out, channels_out)

    def forward(self, x):
        pooled = self.max_pool(x)
        return self.conv2(self.conv1(pooled))

In [None]:
# U-Net bottleneck without any downsampling
class Bridge_no_pool(nn.Module):
    """Bottleneck stage made of two ``ConvBlock``s and no pooling.

    With ConvBlock's default 3x3 kernel, padding 1 and stride 1, H and W are
    unchanged. Only the channel count goes from ``channels_in`` to
    ``channels_out``.
    """

    def __init__(self, channels_in, channels_out):
        super().__init__()
        self.conv1 = ConvBlock(channels_in, channels_out)
        self.conv2 = ConvBlock(channels_out, channels_out)

    def forward(self, x):
        for layer in (self.conv1, self.conv2):
            x = layer(x)
        return x

In [None]:
# Decoder stage: upsample x2, concatenate skip, two conv blocks
class UpBlock(nn.Module):
    """One U-Net decoder stage.

    Args:
        channels_in: Channels entering the first ConvBlock, i.e. the
            upsampler's output channels plus the skip tensor's channels.
        channels_out: Output channels of both ConvBlocks.
        up_channels_in: Channels of ``x_up``; defaults to ``channels_in``.
        up_channels_out: Channels produced by the upsampler; defaults to
            ``channels_out``.
    """

    def __init__(self, channels_in, channels_out, up_channels_in=None, up_channels_out=None):
        super().__init__()
        upsample_in = channels_in if up_channels_in is None else up_channels_in
        upsample_out = channels_out if up_channels_out is None else up_channels_out

        # kernel_size=2 with stride=2 doubles H and W.
        self.upsample = nn.ConvTranspose2d(upsample_in, upsample_out, kernel_size=2, stride=2)
        self.conv_block_1 = ConvBlock(channels_in, channels_out)
        self.conv_block_2 = ConvBlock(channels_out, channels_out)

    def forward(self, x_up, x_down):
        # Upsample the deeper features, then stack the skip tensor on the
        # channel axis (dim 1 of an N, C, H, W tensor). Spatial sizes must match.
        merged = torch.cat((self.upsample(x_up), x_down), dim=1)
        return self.conv_block_2(self.conv_block_1(merged))

In [None]:
class UNet50(nn.Module):
    """U-Net whose encoder is a pretrained torchvision ResNet-50.

    Args:
        n_classes: Output channels of the final 1x1 convolution.
        rgb: If False, a learned 1x1 convolution (``conv0``) first maps a
            1-channel input to the 3 channels the ResNet stem expects.

    The output goes through ``nn.Softmax2d``, i.e. a softmax over the
    channel dimension at every pixel.
    NOTE(review): the input H/W presumably need to be divisible by 64
    (encoder downsampling plus the pooling Bridge) so skip sizes match;
    confirm.
    """

    # Number of skip tensors saved in forward (indices 0..5). The decoder
    # consumes them from index DEPTH - 1 down to 0.
    DEPTH = 6

    def __init__(self, n_classes, rgb=True):
        super().__init__()
        self.rgb = rgb
        encoder = torchvision.models.resnet.resnet50(pretrained=True)
        stages = list(encoder.children())

        self.conv0 = torch.nn.Conv2d(1, 3, 1, 1)
        # The first three ResNet children form the stem; child 3 is its max-pool.
        self.input_block = nn.Sequential(*stages)[:3]
        self.input_pool = stages[3]
        # The residual stages are the nn.Sequential children of the ResNet.
        self.blocks_down = nn.ModuleList([stage for stage in stages if isinstance(stage, nn.Sequential)])

        self.bridge = Bridge(2048, 4096)

        self.blocks_up = nn.ModuleList([
            UpBlock(4096, 2048),
            UpBlock(2048, 1024),
            UpBlock(1024, 512),
            UpBlock(512, 256),
            UpBlock(channels_in=128 + 64, channels_out=128, up_channels_in=256, up_channels_out=128),
            UpBlock(channels_in=64 + 3, channels_out=64, up_channels_in=128, up_channels_out=64),
        ])

        self.out = nn.Conv2d(64, n_classes, kernel_size=1, stride=1)
        self.softmax = nn.Softmax2d()

    def forward(self, x):
        if not self.rgb:
            x = self.conv0(x)

        # skips[0] = (3-channel) input, skips[1] = stem output before pooling,
        # skips[2:] = output of each residual stage, in encoder order.
        skips = [x]
        x = self.input_block(x)
        skips.append(x)
        x = self.input_pool(x)
        for stage in self.blocks_down:
            x = stage(x)
            skips.append(x)

        x = self.bridge(x)

        for step, up in enumerate(self.blocks_up, start=1):
            x = up(x, skips[UNet50.DEPTH - step])

        return self.softmax(self.out(x))

In [None]:
class UNet18(nn.Module):
    """U-Net whose encoder is a pretrained torchvision ResNet-18.

    Args:
        n_classes: Output channels of the final 1x1 convolution.
        rgb: If False, a learned 1x1 convolution (``conv0``) first maps a
            1-channel input to the 3 channels the ResNet stem expects.

    Skip tensors saved in ``forward`` under keys ``layer_0`` .. ``layer_5``:
    the (3-channel) input, the stem output before pooling, and the output of
    each residual stage. The bridge does not pool, so the five decoder blocks
    consume ``layer_4`` down to ``layer_0``. ``layer_5`` is the bridge's own
    input and is not reused as a skip.

    The output goes through ``nn.Softmax2d`` (softmax over channels per pixel).
    NOTE(review): input H/W presumably need to be divisible by 32 for the skip
    concatenations to line up; confirm.
    """

    # Number of skip tensors saved in forward (layer_0 .. layer_5).
    DEPTH = 6

    def __init__(self, n_classes, rgb=True):
        super().__init__()

        blocks_up = []
        blocks_down = []
        resnet = torchvision.models.resnet.resnet18(pretrained=True)

        self.rgb = rgb
        self.conv0 = torch.nn.Conv2d(1, 3, 1, 1)
        # The first three ResNet children form the stem; child 3 is its max-pool.
        self.input_block = nn.Sequential(*list(resnet.children()))[:3]
        self.input_pool = list(resnet.children())[3]

        # Down blocks: the residual stages are the nn.Sequential children.
        for bottleneck in list(resnet.children()):
            if isinstance(bottleneck, nn.Sequential):
                blocks_down.append(bottleneck)
        self.blocks_down = nn.ModuleList(blocks_down)

        # Bridge (keeps spatial size, no pooling)
        self.bridge = Bridge_no_pool(512, 512)

        # Up blocks
        blocks_up.append(UpBlock(512, 256))
        blocks_up.append(UpBlock(256, 128))
        blocks_up.append(UpBlock(128, 64))
        blocks_up.append(UpBlock(channels_in=128, channels_out=64, up_channels_in=64, up_channels_out=64))
        blocks_up.append(UpBlock(channels_in=64 + 3, channels_out=64, up_channels_in=64, up_channels_out=64))
        self.blocks_up = nn.ModuleList(blocks_up)

        self.out = nn.Conv2d(64, n_classes, kernel_size=1, stride=1)
        self.softmax = nn.Softmax2d()

    def forward(self, x):
        if not self.rgb:
            x = self.conv0(x)
        pre_pools = dict()
        pre_pools["layer_0"] = x
        x = self.input_block(x)
        pre_pools["layer_1"] = x
        x = self.input_pool(x)

        for i, block in enumerate(self.blocks_down, 2):
            x = block(x)
            pre_pools[f"layer_{i}"] = x

        x = self.bridge(x)

        # Use this class's own DEPTH (not another model's). With no pooling in
        # the bridge, decoder step i (1..5) consumes layer_{DEPTH - 1 - i},
        # i.e. layer_4 .. layer_0.
        for i, block in enumerate(self.blocks_up, 1):
            key = f"layer_{UNet18.DEPTH - 1 - i}"
            x = block(x, pre_pools[key])

        x = self.out(x)
        x = self.softmax(x)
        return x
    

In [None]:
class DiceLoss(nn.Module):
    """Soft Dice loss, ``1 - dice``, computed over all elements at once.

    ``dice = (2 * sum(p * t) + smooth) / (sum(p) + sum(t) + smooth)``

    Args:
        weight: Unused; accepted for API compatibility.
        size_average: Unused; accepted for API compatibility.
        apply_sigmoid: If True (default), ``inputs`` are passed through a
            sigmoid first. Pass False when the model already ends in a
            sigmoid or equivalent activation.
    """

    def __init__(self, weight=None, size_average=True, apply_sigmoid=True):
        super().__init__()
        self.apply_sigmoid = apply_sigmoid

    def forward(self, inputs, targets, smooth=1):
        if self.apply_sigmoid:
            inputs = torch.sigmoid(inputs)

        # Flatten label and prediction tensors. reshape (unlike view) also
        # accepts non-contiguous tensors, e.g. after a transpose/permute.
        inputs = inputs.reshape(-1)
        targets = targets.reshape(-1)

        intersection = (inputs * targets).sum()
        dice = (2.0 * intersection + smooth) / (inputs.sum() + targets.sum() + smooth)

        return 1 - dice

In [None]:
ALPHA = 0.5  # < 0.5 penalises FP more, > 0.5 penalises FN more
CE_RATIO = 0.5  # weighted contribution of modified CE loss compared to Dice loss


class ComboLoss(nn.Module):
    """Combo loss: weighted binary cross-entropy minus weighted Dice score.

    ``combo = CE_RATIO * weighted_ce - (1 - CE_RATIO) * dice``, where

    ``weighted_ce = mean(-(ALPHA * t * log(p) + (1 - ALPHA) * (1 - t) * log(1 - p)))``

    No activation is applied here, so ``inputs`` are used directly as
    probabilities. NOTE(review): assumes callers pass values in [0, 1];
    confirm.

    Args:
        weight: Unused; accepted for API compatibility.
        size_average: Unused; accepted for API compatibility.
    """

    def __init__(self, weight=None, size_average=True):
        super().__init__()

    def forward(self, inputs, targets, smooth=1, ALPHA=ALPHA, CE_RATIO=CE_RATIO, eps=1e-9):
        # Flatten label and prediction tensors; reshape also handles
        # non-contiguous tensors, which view(-1) rejects.
        inputs = inputs.reshape(-1)
        targets = targets.reshape(-1)

        # Soft Dice on the unclamped probabilities.
        intersection = (inputs * targets).sum()
        dice = (2.0 * intersection + smooth) / (inputs.sum() + targets.sum() + smooth)

        # In float32, 1 - 1e-9 rounds to exactly 1.0, which would make the
        # upper clamp a no-op. log(1 - p) would then be -inf, and NaN once it
        # is multiplied by a zero target. Never let the upper margin be smaller
        # than the dtype's machine epsilon; the lower bound stays `eps`.
        upper_eps = eps
        if inputs.is_floating_point():
            upper_eps = max(eps, torch.finfo(inputs.dtype).eps)
        inputs = torch.clamp(inputs, eps, 1.0 - upper_eps)

        # ALPHA weights the positive-class term and (1 - ALPHA) the
        # negative-class term, so ALPHA = 0.5 is balanced.
        out = -(ALPHA * targets * torch.log(inputs)
                + (1 - ALPHA) * (1.0 - targets) * torch.log(1.0 - inputs))
        weighted_ce = out.mean(-1)
        combo = (CE_RATIO * weighted_ce) - ((1 - CE_RATIO) * dice)

        return combo