In [5]:
from __init__ import *

In [173]:
import torch
import torch.nn as nn
import torch.nn.functional as F

from segmentation_models_pytorch.encoders import get_encoder
from segmentation_models_pytorch.common.blocks import Conv2dReLU
from segmentation_models_pytorch.base.model import Model
from segmentation_models_pytorch.base import EncoderDecoder
# from segmentation_models_pytorch.unet.decoder import UnetDecoder as OrginalUnetDecoder
from segmentation_models_pytorch.unet.model import Unet

from torchsummary import summary

In [197]:
class scSE_UnetDecoder(Model):
    """U-Net decoder built from five ``scSE_DecoderBlock`` stages.

    Args:
        encoder_channels: channel counts of the encoder feature maps, in the
            same order as the features given to ``forward`` (encoder head
            first, then the skip connections). Entries 0-4 are read.
        decoder_channels: output channels of the five decoder stages.
        final_channels: output channels of the closing 1x1 convolution.
        use_batchnorm: forwarded to every decoder block and the center block.
        center: if true, the encoder head is passed through a ``CenterBlock``
            before decoding.
    """

    def __init__(
            self,
            encoder_channels,
            decoder_channels=(256, 128, 64, 32, 16),
            final_channels=1,
            use_batchnorm=True,
            center=False,
    ):
        super().__init__()

        if center:
            # CenterBlock is not imported by any visible cell of this notebook;
            # import it locally so center=True does not raise NameError.
            from segmentation_models_pytorch.unet.decoder import CenterBlock

            channels = encoder_channels[0]
            self.center = CenterBlock(channels, channels, use_batchnorm=use_batchnorm)
        else:
            self.center = None

        in_channels = self.compute_channels(encoder_channels, decoder_channels)
        out_channels = decoder_channels

        self.layer1 = scSE_DecoderBlock(in_channels[0], out_channels[0], use_batchnorm=use_batchnorm)
        self.layer2 = scSE_DecoderBlock(in_channels[1], out_channels[1], use_batchnorm=use_batchnorm)
        self.layer3 = scSE_DecoderBlock(in_channels[2], out_channels[2], use_batchnorm=use_batchnorm)
        self.layer4 = scSE_DecoderBlock(in_channels[3], out_channels[3], use_batchnorm=use_batchnorm)
        self.layer5 = scSE_DecoderBlock(in_channels[4], out_channels[4], use_batchnorm=use_batchnorm)
        self.final_conv = nn.Conv2d(out_channels[4], final_channels, kernel_size=(1, 1))

        self.initialize()

    def compute_channels(self, encoder_channels, decoder_channels):
        """Return the input channel count of each of the five decoder stages.

        Stage 1 concatenates the encoder head with the first skip; stages 2-4
        concatenate the previous stage's output with the next skip; stage 5
        receives no skip.
        """
        channels = [
            encoder_channels[0] + encoder_channels[1],
            encoder_channels[2] + decoder_channels[0],
            encoder_channels[3] + decoder_channels[1],
            encoder_channels[4] + decoder_channels[2],
            0 + decoder_channels[3],
        ]
        return channels

    def forward(self, x):
        """Decode a sequence of encoder features (head first, then four skips)."""
        encoder_head = x[0]
        skips = x[1:]

        # Explicit None test: do not rely on the truthiness of an nn.Module.
        if self.center is not None:
            encoder_head = self.center(encoder_head)

        x = self.layer1([encoder_head, skips[0]])
        x = self.layer2([x, skips[1]])
        x = self.layer3([x, skips[2]])
        x = self.layer4([x, skips[3]])
        x = self.layer5([x, None])
        x = self.final_conv(x)

        return x

In [218]:
class scSE_hyper_UnetDecoder(Model):
    """scSE U-Net decoder with a hypercolumn head.

    The outputs of all five decoder stages plus the last skip connection are
    resized to a common resolution, concatenated along the channel axis and
    reduced to a single-channel logit map by ``self.logit``.

    Args:
        encoder_channels: channel counts of the encoder feature maps, in the
            same order as the features given to ``forward`` (encoder head
            first, then the skip connections). Entries 0-4 are read.
        decoder_channels: output channels of the five decoder stages.
        final_channels: output channels of ``final_conv``; its output is one
            of the hypercolumn inputs, not the model output.
        use_batchnorm: forwarded to every decoder block and the center block.
        center: if true, the encoder head is passed through a ``CenterBlock``
            before decoding.
    """

    def __init__(
            self,
            encoder_channels,
            decoder_channels=(256, 128, 64, 32, 16),
            final_channels=1,
            use_batchnorm=True,
            center=False,
    ):
        super().__init__()

        if center:
            # CenterBlock is not imported by any visible cell of this notebook;
            # import it locally so center=True does not raise NameError.
            from segmentation_models_pytorch.unet.decoder import CenterBlock

            channels = encoder_channels[0]
            self.center = CenterBlock(channels, channels, use_batchnorm=use_batchnorm)
        else:
            self.center = None

        in_channels = self.compute_channels(encoder_channels, decoder_channels)
        out_channels = decoder_channels

        self.layer1 = scSE_DecoderBlock(in_channels[0], out_channels[0], use_batchnorm=use_batchnorm)
        self.layer2 = scSE_DecoderBlock(in_channels[1], out_channels[1], use_batchnorm=use_batchnorm)
        self.layer3 = scSE_DecoderBlock(in_channels[2], out_channels[2], use_batchnorm=use_batchnorm)
        self.layer4 = scSE_DecoderBlock(in_channels[3], out_channels[3], use_batchnorm=use_batchnorm)
        self.layer5 = scSE_DecoderBlock(in_channels[4], out_channels[4], use_batchnorm=use_batchnorm)
        self.final_conv = nn.Conv2d(out_channels[4], final_channels, kernel_size=(1, 1))

        # Width of the hypercolumn assembled in forward():
        # skips[3] + final_conv output + outputs of layer4, layer3, layer2, layer1.
        # This was hard-coded as 384, which only holds for 64-channel stages
        # with a 64-channel last skip.
        hyper_channels = encoder_channels[4] + final_channels + sum(out_channels[:4])

        self.logit = nn.Sequential(
            nn.Conv2d(hyper_channels, 64, kernel_size=3, padding=1),
            nn.ELU(inplace=True),
            nn.Conv2d(64, 1, kernel_size=1, padding=0),
        )

        self.initialize()

    def compute_channels(self, encoder_channels, decoder_channels):
        """Return the input channel count of each of the five decoder stages.

        Stage 1 concatenates the encoder head with the first skip; stages 2-4
        concatenate the previous stage's output with the next skip; stage 5
        receives no skip.
        """
        channels = [
            encoder_channels[0] + encoder_channels[1],
            encoder_channels[2] + decoder_channels[0],
            encoder_channels[3] + decoder_channels[1],
            encoder_channels[4] + decoder_channels[2],
            0 + decoder_channels[3],
        ]
        return channels

    def forward(self, x):
        """Decode encoder features (head first, then four skips) into a 1-channel logit map."""
        encoder_head = x[0]
        skips = x[1:]

        # Explicit None test: do not rely on the truthiness of an nn.Module.
        if self.center is not None:
            encoder_head = self.center(encoder_head)

        d5 = self.layer1([encoder_head, skips[0]])
        d4 = self.layer2([d5, skips[1]])
        d3 = self.layer3([d4, skips[2]])
        d2 = self.layer4([d3, skips[3]])
        d1 = self.layer5([d2, None])
        d1 = self.final_conv(d1)

        # Each decoder stage upsamples by 2, so the scale factors below bring
        # every map to d1's resolution before concatenation.
        f = torch.cat((
            F.interpolate(skips[3], scale_factor=2, mode='bilinear', align_corners=False),
            d1,
            F.interpolate(d2, scale_factor=2, mode='bilinear', align_corners=False),
            F.interpolate(d3, scale_factor=4, mode='bilinear', align_corners=False),
            F.interpolate(d4, scale_factor=8, mode='bilinear', align_corners=False),
            F.interpolate(d5, scale_factor=16, mode='bilinear', align_corners=False),
        ), 1)

        # The functional dropout defaults to training=True; tie it to the
        # module mode so evaluation/inference is deterministic.
        f = F.dropout2d(f, p=0.50, training=self.training)
        f = self.logit(f)
        return f

In [220]:
class Unet(EncoderDecoder):
    """Unet_ is a fully convolution neural network for image semantic segmentation

    Args:
        encoder_name: name of classification model (without last dense layers) used as feature
            extractor to build segmentation model.
        encoder_weights: one of ``None`` (random initialization), ``imagenet`` (pre-training on ImageNet).
        decoder_channels: list of numbers of ``Conv2D`` layer filters in decoder blocks
        decoder_use_batchnorm: if ``True``, ``BatchNormalisation`` layer between ``Conv2D`` and ``Activation`` layers
            is used.
        scSE: if true, build the decoder from ``scSE_DecoderBlock`` stages (``scSE_UnetDecoder``).
        hyper: only honoured together with ``scSE``; selects ``scSE_hyper_UnetDecoder``. In that mode
            ``decoder_channels`` and ``classes`` are replaced by 64-wide values and the decoder emits a
            single-channel map. With ``scSE`` false this flag is ignored and the plain ``UnetDecoder`` is used.
        classes: a number of classes for output (output shape - ``(batch, classes, h, w)``).
        activation: activation function used in ``.predict(x)`` method for inference.
            One of [``sigmoid``, ``softmax``, callable, None]
        center: if ``True`` add ``Conv2dReLU`` block on encoder head (useful for VGG models)

    Returns:
        ``torch.nn.Module``: **Unet**

    .. _Unet:
        https://arxiv.org/pdf/1505.04597

    """

    def __init__(
            self,
            encoder_name='resnet34',
            encoder_weights='imagenet',
            decoder_use_batchnorm=True,
            decoder_channels=(256, 128, 64, 32, 16),
            scSE=False,
            hyper=False,
            classes=1,
            activation='sigmoid',
            center=False,  # useful for VGG models
    ):
        encoder = get_encoder(
            encoder_name,
            encoder_weights=encoder_weights
        )

        if scSE and hyper:
            # The hypercolumn decoder uses 64-channel stages and a 64-channel
            # final_conv; the caller's decoder_channels/classes are overridden.
            decoder_channels = (64, 64, 64, 64, 64)
            classes = 64
            decoder = scSE_hyper_UnetDecoder(
                encoder_channels=encoder.out_shapes,
                decoder_channels=decoder_channels,
                final_channels=classes,
                use_batchnorm=decoder_use_batchnorm,
                center=center,
            )
        elif scSE:
            decoder = scSE_UnetDecoder(
                encoder_channels=encoder.out_shapes,
                decoder_channels=decoder_channels,
                final_channels=classes,
                use_batchnorm=decoder_use_batchnorm,
                center=center,
            )
        else:
            # The module-level import of UnetDecoder is commented out in this
            # notebook; import it here so the default path does not raise
            # NameError.
            from segmentation_models_pytorch.unet.decoder import UnetDecoder

            decoder = UnetDecoder(
                encoder_channels=encoder.out_shapes,
                decoder_channels=decoder_channels,
                final_channels=classes,
                use_batchnorm=decoder_use_batchnorm,
                center=center,
            )

        super().__init__(encoder, decoder, activation)

        self.name = 'u-{}'.format(encoder_name)


In [207]:
# Count the feature maps the 'resnet34' encoder returns for one random 1x3x256x256 input.
len(get_encoder('resnet34')(torch.randn(1, 3, 256, 256)))

5

In [222]:
# Smoke test: output shape of the scSE U-Net without the hypercolumn head.
# NOTE(review): `img` is only assigned in a later cell (In [25]), so this cell
# fails on a fresh top-to-bottom run.
Unet(scSE=True, hyper=False)(img).shape

4


torch.Size([1, 1, 256, 256])

In [25]:
# Random test batch (1 x 3 x 256 x 256) reused by the smoke-test cells.
img = torch.randn(1, 3, 256, 256)

In [72]:
class ConvBn2d(nn.Module):
    """2-D convolution optionally followed by batch normalisation.

    The convolution carries a bias only when no batch-norm layer follows it.
    Extra keyword arguments are handed to ``nn.BatchNorm2d``.
    """

    def __init__(self, in_channels, out_channels, kernel_size, padding=0,
                 stride=1, use_batchnorm=True, **batchnorm_parmas):
        super().__init__()

        conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=padding,
            bias=not use_batchnorm,
        )
        stages = [conv]
        if use_batchnorm:
            stages += [nn.BatchNorm2d(out_channels, **batchnorm_parmas)]

        self.block = nn.Sequential(*stages)

    def forward(self, x):
        """Apply the convolution (and batch norm, if present) to ``x``."""
        out = self.block(x)
        return out
    

class sSE(nn.Module):
    """Spatial squeeze-and-excitation gate.

    Collapses the channel axis with a 1x1 ``ConvBn2d`` and returns the
    sigmoid of the result, i.e. a one-channel map with values in (0, 1).
    The input itself is not multiplied by the gate here.
    """

    def __init__(self, out_channels):
        super().__init__()
        self.conv = ConvBn2d(
            in_channels=out_channels,
            out_channels=1,
            kernel_size=1,
            padding=0,
        )

    def forward(self, x):
        """Return the per-pixel gate for ``x``."""
        return torch.sigmoid(self.conv(x))
    
    
class cSE(nn.Module):
    """Channel squeeze-and-excitation.

    Global-average-pools each channel, passes the result through a
    bias-free two-layer bottleneck (half the channels in the middle) and
    returns the input rescaled per channel by the resulting sigmoid weights.
    """

    def __init__(self, out_channels):
        super().__init__()
        bottleneck = int(out_channels / 2)
        self.linear1 = nn.Linear(in_features=out_channels, out_features=bottleneck, bias=False)
        self.linear2 = nn.Linear(in_features=bottleneck, out_features=out_channels, bias=False)

    def forward(self, x):
        """Return ``x`` multiplied by its per-channel excitation weights."""
        batch, channels, _, _ = x.size()

        # Squeeze: one scalar per (sample, channel).
        squeeze = nn.AdaptiveAvgPool2d(1)
        pooled = squeeze(x).view(batch, channels)

        # Excite: bottleneck MLP followed by a sigmoid.
        hidden = torch.relu(self.linear1(pooled))
        weights = torch.sigmoid(self.linear2(hidden))
        weights = weights.view(batch, channels, 1, 1)

        return x * weights.expand_as(x)

In [164]:
class scSE_DecoderBlock(nn.Module):
    """U-Net decoder stage with concurrent spatial and channel squeeze-excitation.

    The stage upsamples its input by 2 (nearest neighbour), optionally
    concatenates a skip connection, applies two 3x3 ``Conv2dReLU`` layers and
    finally sums the spatially gated and the channel-recalibrated features.

    Args:
        in_channels: channels entering the first convolution (upsampled input
            plus skip, if any).
        out_channels: channels produced by the stage.
        use_batchnorm: forwarded to both ``Conv2dReLU`` layers.
    """

    def __init__(self, in_channels, out_channels, use_batchnorm=True):
        super().__init__()
        self.block = nn.Sequential(
            Conv2dReLU(in_channels, out_channels, kernel_size=3, padding=1, use_batchnorm=use_batchnorm),
            Conv2dReLU(out_channels, out_channels, kernel_size=3, padding=1, use_batchnorm=use_batchnorm),
        )
        self.spatial_gate = sSE(out_channels)
        self.channel_gate = cSE(out_channels)

    def forward(self, x):
        """Run the stage on ``[features, skip]``; ``skip`` may be ``None``."""
        x, skip = x
        x = F.interpolate(x, scale_factor=2, mode='nearest')
        if skip is not None:
            x = torch.cat([x, skip], dim=1)
        x = self.block(x)

        # sSE returns only the gate, so it still has to be multiplied with x.
        spatial = self.spatial_gate(x) * x
        # cSE already returns x scaled by its channel weights; multiplying by
        # x again (as the previous `g2 * x` did) would square the activations.
        channel = self.channel_gate(x)

        return spatial + channel

In [97]:
# NOTE(review): this class is also defined in an earlier cell of this notebook
# (In [197]); whichever cell runs last wins. Keep a single definition.
class scSE_UnetDecoder(Model):
    """U-Net decoder built from five ``scSE_DecoderBlock`` stages.

    Args:
        encoder_channels: channel counts of the encoder feature maps, in the
            same order as the features given to ``forward`` (encoder head
            first, then the skip connections). Entries 0-4 are read.
        decoder_channels: output channels of the five decoder stages.
        final_channels: output channels of the closing 1x1 convolution.
        use_batchnorm: forwarded to every decoder block and the center block.
        center: if true, the encoder head is passed through a ``CenterBlock``
            before decoding.
    """

    def __init__(
            self,
            encoder_channels,
            decoder_channels=(256, 128, 64, 32, 16),
            final_channels=1,
            use_batchnorm=True,
            center=False,
    ):
        super().__init__()

        if center:
            # CenterBlock is not imported by any visible cell of this notebook;
            # import it locally so center=True does not raise NameError.
            from segmentation_models_pytorch.unet.decoder import CenterBlock

            channels = encoder_channels[0]
            self.center = CenterBlock(channels, channels, use_batchnorm=use_batchnorm)
        else:
            self.center = None

        in_channels = self.compute_channels(encoder_channels, decoder_channels)
        out_channels = decoder_channels

        self.layer1 = scSE_DecoderBlock(in_channels[0], out_channels[0], use_batchnorm=use_batchnorm)
        self.layer2 = scSE_DecoderBlock(in_channels[1], out_channels[1], use_batchnorm=use_batchnorm)
        self.layer3 = scSE_DecoderBlock(in_channels[2], out_channels[2], use_batchnorm=use_batchnorm)
        self.layer4 = scSE_DecoderBlock(in_channels[3], out_channels[3], use_batchnorm=use_batchnorm)
        self.layer5 = scSE_DecoderBlock(in_channels[4], out_channels[4], use_batchnorm=use_batchnorm)
        self.final_conv = nn.Conv2d(out_channels[4], final_channels, kernel_size=(1, 1))

        self.initialize()

    def compute_channels(self, encoder_channels, decoder_channels):
        """Return the input channel count of each of the five decoder stages.

        Stage 1 concatenates the encoder head with the first skip; stages 2-4
        concatenate the previous stage's output with the next skip; stage 5
        receives no skip.
        """
        channels = [
            encoder_channels[0] + encoder_channels[1],
            encoder_channels[2] + decoder_channels[0],
            encoder_channels[3] + decoder_channels[1],
            encoder_channels[4] + decoder_channels[2],
            0 + decoder_channels[3],
        ]
        return channels

    def forward(self, x):
        """Decode a sequence of encoder features (head first, then four skips)."""
        encoder_head = x[0]
        skips = x[1:]

        # Explicit None test: do not rely on the truthiness of an nn.Module.
        if self.center is not None:
            encoder_head = self.center(encoder_head)

        x = self.layer1([encoder_head, skips[0]])
        x = self.layer2([x, skips[1]])
        x = self.layer3([x, skips[2]])
        x = self.layer4([x, skips[3]])
        x = self.layer5([x, None])
        x = self.final_conv(x)

        return x

In [35]:
# Display the dimensions of the random test batch.
img.shape

torch.Size([1, 3, 256, 256])

In [59]:
# Scratch check of the squeeze step used in cSE.forward: take the first two
# dimensions of the 4-D test tensor (batch and channel counts).
b, c, _, _ = img.size()

In [62]:
# Average-pool every channel down to 1x1, then flatten to shape (b, c).
y = nn.AdaptiveAvgPool2d(1)(img).view(b,c)

In [63]:
# Display the shape of the pooled tensor.
y.shape

torch.Size([1, 3])

In [1]:
import torch
from segmentation_models_pytorch.unet.model import Unet

In [14]:
# Smoke test: forward one random 1x3x256x256 batch through the hypercolumn scSE U-Net.
# NOTE(review): `Unet` here is the class imported from
# segmentation_models_pytorch.unet.model in the cell above, not the one defined
# in this notebook; it presumably comes from a locally patched package that
# accepts scSE/hyper - confirm. The cell displays the whole output tensor;
# appending `.shape` would keep the output short.
Unet("resnet152", encoder_weights="imagenet", scSE=True, hyper=True)(torch.randn(1, 3, 256, 256))

tensor([[[[  8.3107,  18.5793,  20.5211,  ...,  12.2965,  13.5621,   5.7937],
          [ 15.2626,  25.5527,  25.5018,  ...,  15.2186,  16.8413,  10.5566],
          [ 13.7626,  23.1354,  23.1595,  ...,  17.2275,  17.8399,  11.1414],
          ...,
          [-11.3372, -17.1666, -16.2905,  ...,  11.1589,  11.1022,   4.9724],
          [ -9.5579, -16.2302, -15.7735,  ...,  14.1393,  14.7918,   7.0590],
          [-13.6933, -16.8718, -15.7791,  ...,  15.3918,  14.4391,  12.7882]]]],
       grad_fn=<MkldnnConvolutionBackward>)