In [1]:
import os

from collections import OrderedDict
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Module
from torch.nn import init
import torchvision.models

In [2]:
def weights_init_kaiming(m):
    """Initialize weights according to method describe here:
    https://www.cv-foundation.org/openaccess/content_iccv_2015/papers/He_Delving_Deep_into_ICCV_2015_paper.pdf
    """

    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
    elif classname.find('Linear') != -1:
        init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
    elif classname.find('BatchNorm') != -1:
        init.normal_(m.weight.data, 1.0, 0.02)
        init.constant_(m.bias.data, 0.0)

In [3]:
class DownConv(Module):

    def __init__(self, in_feat, out_feat, drop_rate=0.4, bn_momentum=0.1, is_2d=True):
        super(DownConv, self).__init__()
        if is_2d:
            conv = nn.Conv2d
            bn = nn.BatchNorm2d
            dropout = nn.Dropout2d
        else:
            conv = nn.Conv3d
            bn = nn.InstanceNorm3d
            dropout = nn.Dropout3d

        self.conv1 = conv(in_feat, out_feat, kernel_size=3, padding=1)
        self.conv1_bn = bn(out_feat, momentum=bn_momentum)
        self.conv1_drop = dropout(drop_rate)

        self.conv2 = conv(out_feat, out_feat, kernel_size=3, padding=1)
        self.conv2_bn = bn(out_feat, momentum=bn_momentum)
        self.conv2_drop = dropout(drop_rate)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = self.conv1_bn(x)
        x = self.conv1_drop(x)

        x = F.relu(self.conv2(x))
        x = self.conv2_bn(x)
        x = self.conv2_drop(x)
        return x

In [4]:
class UpConv(Module):

    def __init__(self, in_feat, out_feat, drop_rate=0.4, bn_momentum=0.1, is_2d=True):
        super(UpConv, self).__init__()
        self.is_2d = is_2d
        self.downconv = DownConv(in_feat, out_feat, drop_rate, bn_momentum, is_2d) # chm

    def forward(self, x, y):
        # For retrocompatibility purposes
        if not hasattr(self, "is_2d"):
            self.is_2d = True
        mode = 'bilinear' if self.is_2d else 'trilinear'
        dims = -2 if self.is_2d else -3
        x = F.interpolate(x, size=y.size()[dims:], mode=mode, align_corners=True)
        x = torch.cat([x, y], dim=1)
        x = self.downconv(x)
        return x

In [5]:
# put it in the code 
class _GridAttentionBlockND(nn.Module):
    def __init__(self, in_channels, gating_channels, inter_channels=None, dimension=3,
                 sub_sample_factor=2):
        super(_GridAttentionBlockND, self).__init__()

        assert dimension in [2, 3] # for debugging

        # Downsampling rate for the input featuremap
        if isinstance(sub_sample_factor, tuple): self.sub_sample_factor = sub_sample_factor
        elif isinstance(sub_sample_factor, list): self.sub_sample_factor = tuple(sub_sample_factor)
        else: self.sub_sample_factor = tuple([sub_sample_factor]) * dimension

        # Default parameter set
        self.dimension = dimension
        self.sub_sample_kernel_size = self.sub_sample_factor

        # Number of channels (pixel dimensions)
        self.in_channels = in_channels 
        self.gating_channels = gating_channels
        self.inter_channels = inter_channels

        if self.inter_channels is None:
            self.inter_channels = in_channels // 2
            if self.inter_channels == 0:
                self.inter_channels = 1

        if dimension == 3:
            conv_nd = nn.Conv3d
            bn = nn.BatchNorm3d
            self.upsample_mode = 'trilinear'
        elif dimension == 2:
            conv_nd = nn.Conv2d
            bn = nn.BatchNorm2d
            self.upsample_mode = 'bilinear'
        else:
            raise NotImplemented

        # Output transform
        self.W = nn.Sequential(
            conv_nd(in_channels=self.in_channels, out_channels=self.in_channels, kernel_size=1, stride=1, padding=0),
            bn(self.in_channels),
        )

        # Theta^T * x_ij + Phi^T * gating_signal + bias
        self.theta = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels,
                             kernel_size=self.sub_sample_kernel_size, stride=self.sub_sample_factor, padding=0, bias=False)
        self.phi = conv_nd(in_channels=self.gating_channels, out_channels=self.inter_channels,
                           kernel_size=1, stride=1, padding=0, bias=True)
        self.psi = conv_nd(in_channels=self.inter_channels, out_channels=1, kernel_size=1, stride=1, padding=0, bias=True)

        # Initialise weights
        for m in self.children():
            weights_init_kaiming(m)
         
        # Define the operation
        self.operation_function = self._concatenation

    def forward(self, x, g):
        '''
        :param x: (b, c, t, h, w)
        :param g: (b, g_d)
        :return:
        '''

        output = self.operation_function(x, g)
        return output

    def _concatenation(self, x, g):
        input_size = x.size()
        batch_size = input_size[0]
        assert batch_size == g.size(0)

        # theta => (b, c, t, h, w) -> (b, i_c, t, h, w) -> (b, i_c, thw)
        # phi   => (b, g_d) -> (b, i_c)
        theta_x = self.theta(x)
        theta_x_size = theta_x.size()

        # g (b, c, t', h', w') -> phi_g (b, i_c, t', h', w')
        #  Relu(theta_x + phi_g + bias) -> f = (b, i_c, thw) -> (b, i_c, t/s1, h/s2, w/s3)
        phi_g = F.upsample(self.phi(g), size=theta_x_size[2:], mode=self.upsample_mode)
        f = F.relu(theta_x + phi_g, inplace=True)


        #  psi^T * f -> (b, psi_i_c, t/s1, h/s2, w/s3)
        sigm_psi_f = F.sigmoid(self.psi(f))

        # upsample the attentions and multiply
        sigm_psi_f = F.upsample(sigm_psi_f, size=input_size[2:], mode=self.upsample_mode)

        y = sigm_psi_f.expand_as(x) * x

        W_y = self.W(y)

        return W_y, sigm_psi_f


class GridAttentionBlock2D(_GridAttentionBlockND):
    def __init__(self, in_channels, gating_channels, inter_channels=None,
                 sub_sample_factor=2):
        super(GridAttentionBlock2D, self).__init__(in_channels,
                                                   inter_channels=inter_channels,
                                                   gating_channels=gating_channels,
                                                   dimension=2,
                                                   sub_sample_factor=sub_sample_factor,
                                                   )


class GridAttentionBlock3D(_GridAttentionBlockND):
    def __init__(self, in_channels, gating_channels, inter_channels=None,
                 sub_sample_factor=(2,2,2)):
        super(GridAttentionBlock3D, self).__init__(in_channels,
                                                   inter_channels=inter_channels,
                                                   gating_channels=gating_channels,
                                                   dimension=3,
                                                   sub_sample_factor=sub_sample_factor,
                                                   )

In [6]:

class UnetGridGatingSignal2(nn.Module):
    """Operation to extract important features for a specific task using 1x1 convolution (Gating) which is used in the 2D
    attention blocks.

    Args:
        in_size (int): Number of channels in the input image.
        out_size (int): Number of channels in the output image.
        kernel_size (tuple): Convolution kernel size.
        is_batchnorm (bool): Boolean indicating whether to apply batch normalization or not.

    Attributes:
        conv1 (Sequential): 2D convolution, batch normalization and ReLU activation.
    """

    def __init__(self, in_size, out_size, kernel_size=(1, 1), is_batchnorm=True):
        super(UnetGridGatingSignal2, self).__init__()

        if is_batchnorm:
            
            self.conv1 = nn.Sequential(nn.Conv2d(in_size, out_size, kernel_size, (1, 1), (0, 0)),
                                       nn.BatchNorm2d(out_size),
                                       nn.ReLU(inplace=True)
                                       )
        else:
                                       
            self.conv1 = nn.Sequential(nn.Conv2d(in_size, out_size, kernel_size, (1, 1), (0, 0)),
                                       nn.ReLU(inplace=True)
                                       )

        # initialise the blocks
        for m in self.children():
            weights_init_kaiming(m)

    def forward(self, inputs):
        outputs = self.conv1(inputs)
        return outputs

In [7]:
class Encoder(Module):

    def __init__(self, in_channel=1, depth=3, drop_rate=0.4, bn_momentum=0.1, n_metadata=None, film_layers=None,
                 is_2d=True, n_filters=64):
        super(Encoder, self).__init__()
        self.depth = depth
        self.down_path = nn.ModuleList()
        # first block
        self.down_path.append(DownConv(in_channel, n_filters, drop_rate, bn_momentum, is_2d))
        self.down_path.append(FiLMlayer(n_metadata, n_filters) if film_layers and film_layers[0] else None)
        max_pool = nn.MaxPool2d if is_2d else nn.MaxPool3d
        self.down_path.append(max_pool(2))

        for i in range(depth - 1):
            self.down_path.append(DownConv(n_filters*2**i, n_filters*2**(i+1), drop_rate, bn_momentum, is_2d))
            self.down_path.append(FiLMlayer(n_metadata, n_filters * 2**(i+1)) if film_layers and film_layers[i + 1] else None)
            self.down_path.append(max_pool(2))

        # Bottom
        self.conv_bottom = DownConv(n_filters*2**(depth-1), n_filters*2**depth, drop_rate, bn_momentum, is_2d)# change by me
        self.film_bottom = FiLMlayer(n_metadata, n_filters*2**depth) if film_layers and film_layers[self.depth] else None #changeby me

    def forward(self, x, context=None):
        features = []

        # First block
        x = self.down_path[0](x)
        if self.down_path[1]:
            x, w_film = self.down_path[1](x, context, None)
        features.append(x)
        x = self.down_path[2](x)

        # Down-sampling path (other blocks)
        for i in range(1, self.depth):
            x = self.down_path[i * 3](x)
            if self.down_path[i * 3 + 1]:
                x, w_film = self.down_path[i * 3 + 1](x, context, None if 'w_film' not in locals() else w_film)
            features.append(x)
            x = self.down_path[i * 3 + 2](x)

        # Bottom level
        x = self.conv_bottom(x)
        if self.film_bottom:
            x, w_film = self.film_bottom(x, context, None if 'w_film' not in locals() else w_film)
        features.append(x)
        return features, None if 'w_film' not in locals() else w_film


class Decoder(Module):


    def __init__(self, out_channel=1, depth=3, drop_rate=0.4, bn_momentum=0.1,
                 n_metadata=None, film_layers=None, hemis=False, final_activation="sigmoid", is_2d=True,
                 n_filters=64, attention= True):
        super(Decoder, self).__init__()
        self.depth = depth
        self.out_channel = out_channel
        self.base_n_filter = n_filters
        self.attention = attention
        self.final_activation = final_activation
        # Up-Sampling path
        self.up_path = nn.ModuleList()
        self.att_path = nn.ModuleList()
        if hemis:
            in_channel = n_filters * 2 ** self.depth
            self.up_path.append(UpConv(in_channel * 2, n_filters * 2 ** (self.depth - 1), drop_rate, bn_momentum,
                                       is_2d))
            if film_layers and film_layers[self.depth + 1]:
                self.up_path.append(FiLMlayer(n_metadata, n_filters * 2 ** (self.depth - 1)))
            else:
                self.up_path.append(None)
            # self.depth += 1
        else:
            in_channel = n_filters * 2 ** self.depth

            self.up_path.append(UpConv(in_channel+n_filters * 2 ** (self.depth - 1)
                                       , n_filters * 2 ** (self.depth - 1), drop_rate, bn_momentum, is_2d))#chm
            if film_layers and film_layers[self.depth + 1]:
                self.up_path.append(FiLMlayer(n_metadata, n_filters * 2 ** (self.depth - 1)))
            else:
                self.up_path.append(None)

        for i in range(1, depth):
            in_channel //= 2

            self.up_path.append(UpConv(in_channel+ n_filters * 2 ** (self.depth - i - 1 + int(hemis)),
                                       n_filters * 2 ** (self.depth - i - 1),
                       drop_rate, bn_momentum, is_2d))
            if film_layers and film_layers[self.depth + i + 1]:
                self.up_path.append(FiLMlayer(n_metadata, n_filters * 2 ** (self.depth - i - 1)))
            else:
                self.up_path.append(None)

        # Last Convolution
        conv = nn.Conv2d if is_2d else nn.Conv3d
        self.last_conv = conv(in_channel // 2, out_channel, kernel_size=3, padding=1)
        self.last_film = FiLMlayer(n_metadata, self.out_channel) if film_layers and film_layers[-1] else None
        self.softmax = nn.Softmax(dim=1)
        
        ### ATTENTION MODULE ###
        if self.attention:
            
            self.gating = UnetGridGatingSignal2(self.base_n_filter * 2**(self.depth),
                                                self.base_n_filter * 2**(self.depth-1), kernel_size=(1, 1),
                                                is_batchnorm=True)
            for k in range(1,self.depth+1):
                
                self.att_path.append(GridAttentionBlock2D(in_channels=self.base_n_filter * 2**(self.depth-k),
                                                        gating_channels=self.base_n_filter * 2**(self.depth-1),
                                                        inter_channels=self.base_n_filter * 2**(self.depth-k),
                                                        sub_sample_factor=2))


    def forward(self, features, context=None, w_film=None):
        x = features[-1]
        gating = self.gating(x)

        for i in range(self.depth):
            if self.attention:
                y, att = self.att_path[i](features[(self.depth-1)-i],gating)
                x = self.up_path[2*i](x, y)
                
            else:
                x = self.up_path[2*i](x, features[(self.depth-1)-i])
            if self.up_path[2*i+1]:
                x, w_film = self.up_path[2*i+ 1](x, context, w_film)
                


        # Last convolution
        x = self.last_conv(x)
        if self.last_film:
            x, w_film = self.last_film(x, context, w_film)

        if hasattr(self, "final_activation") and self.final_activation == "softmax":
            preds = self.softmax(x)
        elif hasattr(self, "final_activation") and self.final_activation == "relu":
            preds = nn.ReLU()(x) / nn.ReLU()(x).max() if bool(nn.ReLU()(x).max()) else nn.ReLU()(x)
            # If model multiclass
            if preds.shape[1] > 1:
                class_sum = preds.sum(dim=1).unsqueeze(1)
                # Avoid division by zero
                class_sum[class_sum == 0] = 1
                preds /= class_sum
        else:
            preds = torch.sigmoid(x)

        if self.out_channel > 1:
            # Remove background class
            preds = preds[:, 1:, ]

        return preds


#if __name__=="__main__":
 #   features = [torch.rand((2, 64, 322, 322)), torch.rand((2, 128, 161, 161)),
 #              torch.rand((2, 256, 80, 80)), torch.rand((2, 512, 40, 40))]
 #   model = Decoder()
 #   print(model(features))


In [9]:

class Unet(Module):
    """A reference U-Net model.

    .. seealso::
        Ronneberger, O., et al (2015). U-Net: Convolutional
        Networks for Biomedical Image Segmentation
        ArXiv link: https://arxiv.org/abs/1505.04597

    Args:
        in_channel (int): Number of channels in the input image.
        out_channel (int): Number of channels in the output image.
        depth (int): Number of down convolutions minus bottom down convolution.
        drop_rate (float): Probability of dropout.
        bn_momentum (float): Batch normalization momentum.
        final_activation (str): Choice of final activation between "sigmoid", "relu" and "softmax".
        is_2d (bool): Indicates dimensionality of model: True for 2D convolutions, False for 3D convolutions.
        n_filters (int):  Number of base filters in the U-Net.
        **kwargs:

    Attributes:
        encoder (Encoder): U-Net encoder.
        decoder (Decoder): U-net decoder.
    """

    def __init__(self, in_channel=1, out_channel=1, depth=3, drop_rate=0.4, bn_momentum=0.1, final_activation='sigmoid',
                 is_2d=True, n_filters=64, **kwargs):
        super(Unet, self).__init__()

        # Encoder path
        self.encoder = Encoder(in_channel=1, depth=3, drop_rate=0.4, bn_momentum=0.1, n_metadata=None, film_layers=None,
                 is_2d=True, n_filters=64)

        # Decoder path
        self.decoder = Decoder(out_channel=1, depth=3, drop_rate=0.4, bn_momentum=0.1,
                 n_metadata=None, film_layers=None, hemis=False, final_activation="sigmoid", is_2d=True,
                 n_filters=64, attention=True)


    def forward(self, x):
        features, _ = self.encoder(x)
        preds = self.decoder(features)
        
        print(f"the final size is equal to {preds.size()}")

        return preds
    
    
if __name__=="__main__":
    image = torch.rand((2,1,86,86)) # batch, Channle, W, H as input
    model = Unet()
    print(model(image))
        

torch.Size([2, 256, 10, 10])
torch.Size([2, 1, 21, 21])
torch.Size([2, 256, 21, 21])
torch.Size([2, 256, 21, 21])
torch.Size([2, 1, 43, 43])
torch.Size([2, 128, 43, 43])
torch.Size([2, 128, 43, 43])
torch.Size([2, 1, 86, 86])
torch.Size([2, 64, 86, 86])
torch.Size([2, 64, 86, 86])
the final size is equal to torch.Size([2, 1, 86, 86])
tensor([[[[0.5289, 0.5782, 0.6105,  ..., 0.4764, 0.4914, 0.5742],
          [0.5252, 0.5174, 0.6103,  ..., 0.4310, 0.4038, 0.5962],
          [0.4992, 0.6070, 0.5364,  ..., 0.3027, 0.4503, 0.6104],
          ...,
          [0.4476, 0.3083, 0.2959,  ..., 0.4956, 0.2178, 0.3265],
          [0.4757, 0.4349, 0.4090,  ..., 0.4034, 0.3720, 0.3553],
          [0.5666, 0.5279, 0.5734,  ..., 0.5136, 0.4664, 0.4260]]],


        [[[0.5342, 0.4628, 0.3959,  ..., 0.4934, 0.4997, 0.5142],
          [0.6124, 0.5599, 0.5561,  ..., 0.5135, 0.4738, 0.5327],
          [0.5291, 0.5371, 0.4802,  ..., 0.5075, 0.5027, 0.5677],
          ...,
          [0.5433, 0.5759, 0.4579,  