In [30]:
import torch
from torchsummary import summary
import math
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable

In [42]:
class UNet(nn.Module):
    """Four-level U-Net for 2D images.

    Args:
        in_channels: number of channels of the input image.
        n_classes: number of output classes (channels of the final 1x1 conv).
        feature_scale: divisor applied to the base widths
            [64, 128, 256, 512, 1024]; the default of 4 gives
            [16, 32, 64, 128, 256].
        is_deconv: passed to every ``unetUp`` stage to select a transposed
            convolution (True) or bilinear upsampling (False).
        is_batchnorm: passed to the encoder/center ``unetConv2`` blocks to
            enable batch normalisation.

    ``forward`` returns a tuple ``(log_probs, features)``: the per-pixel class
    log-probabilities and the feature map of the last decoder stage
    (``filters[0]`` channels) that is fed to the final 1x1 convolution.
    """

    def __init__(self, in_channels=3,n_classes=2,feature_scale=4, is_deconv=True, is_batchnorm=True):
        super(UNet, self).__init__()
        self.is_deconv = is_deconv
        self.in_channels = in_channels
        self.is_batchnorm = is_batchnorm
        self.feature_scale = feature_scale

        filters = [64, 128, 256, 512, 1024]
        filters = [int(x / self.feature_scale) for x in filters]

        # downsampling: each level is a double conv followed by a 2x2 max-pool
        self.conv1 = unetConv2(self.in_channels, filters[0], self.is_batchnorm)
        self.maxpool1 = nn.MaxPool2d(kernel_size=2)

        self.conv2 = unetConv2(filters[0], filters[1], self.is_batchnorm)
        self.maxpool2 = nn.MaxPool2d(kernel_size=2)

        self.conv3 = unetConv2(filters[1], filters[2], self.is_batchnorm)
        self.maxpool3 = nn.MaxPool2d(kernel_size=2)

        self.conv4 = unetConv2(filters[2], filters[3], self.is_batchnorm)
        self.maxpool4 = nn.MaxPool2d(kernel_size=2)

        # bottleneck at 1/16 of the input resolution
        self.center = unetConv2(filters[3], filters[4], self.is_batchnorm)

        # upsampling: each stage upsamples and concatenates the matching skip
        self.up_concat4 = unetUp(filters[4], filters[3], self.is_deconv)
        self.up_concat3 = unetUp(filters[3], filters[2], self.is_deconv)
        self.up_concat2 = unetUp(filters[2], filters[1], self.is_deconv)
        self.up_concat1 = unetUp(filters[1], filters[0], self.is_deconv)

        # final conv (without any concat)
        self.final = nn.Conv2d(filters[0], n_classes, 1)

        # initialise weights of every Conv2d / BatchNorm2d in the network
        for m in self.modules():
            if isinstance(m, (nn.Conv2d, nn.BatchNorm2d)):
                init_weights(m, init_type='kaiming')


    def forward(self, inputs):
        """Run the network.

        Args:
            inputs: 4D tensor (batch, in_channels, H, W).

        Returns:
            Tuple ``(log_probs, up1)`` where ``log_probs`` is the log-softmax
            over the class dimension of the final 1x1 conv output and ``up1``
            is the last decoder feature map.
        """
        conv1 = self.conv1(inputs)
        maxpool1 = self.maxpool1(conv1)

        conv2 = self.conv2(maxpool1)
        maxpool2 = self.maxpool2(conv2)

        conv3 = self.conv3(maxpool2)
        maxpool3 = self.maxpool3(conv3)

        conv4 = self.conv4(maxpool3)
        maxpool4 = self.maxpool4(conv4)

        center = self.center(maxpool4)
        up4 = self.up_concat4(conv4, center)
        up3 = self.up_concat3(conv3, up4)
        up2 = self.up_concat2(conv2, up3)
        up1 = self.up_concat1(conv1, up2)

        final = self.final(up1)

        # Normalise explicitly over the class (channel) dimension; relying on
        # the implicit dim is deprecated.
        return F.log_softmax(final, dim=1), up1

    



In [43]:
# Build the U-Net with its defaults (3 input channels, 2 classes, feature_scale=4).
model = UNet()


In [44]:
# Random smoke-test batch: 4 three-channel images of 512x512.
# torch.autograd.Variable has been a deprecated no-op wrapper since PyTorch 0.4,
# so the tensor is passed to the model directly.
# The name `input` (which shadows the Python builtin) is kept because later
# cells refer to it.
input = torch.rand(4, 3, 512, 512)

In [46]:
# Forward pass; UNet.forward returns a (log-probabilities, last decoder features) pair.
y = model(input)



In [47]:
# Class log-probabilities: one channel per class at the input resolution.
y[0].shape

torch.Size([4, 2, 512, 512])

In [48]:
# Last decoder feature map: 64 / feature_scale = 16 channels with the defaults.
y[1].shape

torch.Size([4, 16, 512, 512])

In [49]:
# The forward pass returns a 2-tuple.
len(y)

2

In [54]:
# RNN_GRU_UNet2d is defined in a later cell of this notebook; that cell (and the
# helper cells it depends on) must have been executed before this one.
model = RNN_GRU_UNet2d()

In [55]:
# Random sequence of 24 three-channel 512x512 frames; RNN_GRU_UNet2d iterates
# over the first dimension one frame at a time.
# torch.autograd.Variable has been a deprecated no-op wrapper since PyTorch 0.4,
# so the tensor is passed to the model directly.
input = torch.rand(24, 3, 512, 512)

In [58]:
# Sanity check of the sequence tensor: (frames, channels, height, width).
input.shape

torch.Size([24, 3, 512, 512])

In [56]:
# NOTE(review): the ValueError recorded below may be stale output from an earlier
# revision of the class -- the forward() shown in this notebook unsqueezes every
# frame to 4D before calling the U-Net. Re-run the class cell and this cell to confirm.
y = model(input)

ValueError: Expected 4D tensor as input, got 3D tensor instead.

In [53]:
class RNN_GRU_UNet2d(nn.Module):
    """U-Net applied frame by frame with a convolutional GRU-style recurrence.

    The first dimension of the input is treated as an ordered sequence of
    frames. Each frame is segmented by a shared ``UNet``; the U-Net's last
    decoder feature map drives a GRU-like hidden state (built from 3x3 conv
    blocks), from which a per-pixel multiplier is derived and applied to the
    *next* frame before it enters the U-Net.

    Args:
        in_channel: channels of each input frame (also the channel count of
            the multiplier produced for the next frame).
        n_classes: number of segmentation classes.
        bn: whether the recurrent conv blocks use batch normalisation.
        input_chl: channels of the feature map returned by ``UNet`` as its
            second output (16 with the UNet defaults, i.e. 64 / feature_scale).
    """

    def __init__(self, in_channel=3, n_classes=2,bn = True, input_chl = 16):
        # Initialise nn.Module first so that attribute registration never
        # depends on nn.Module's handling of a not-yet-initialised instance.
        super(RNN_GRU_UNet2d, self).__init__()

        self.in_channel = in_channel
        self.n_classes = n_classes
        self.bn = bn
        self.input_chl = input_chl

        # Projects the U-Net features down to the 4-channel recurrent space.
        self.conv = self.encoder(self.input_chl, 4, bias=True, batchnorm=self.bn)
        # Input-side (i*) and hidden-side (h*) transforms of the update gate z,
        # the reset gate r and the candidate state n.
        self.izt = self.encoder(4, 4, bias=True, batchnorm=self.bn)
        self.hzt = self.encoder(4, 4, bias=True, batchnorm=self.bn)
        self.irt = self.encoder(4, 4, bias=True, batchnorm=self.bn)
        self.hrt = self.encoder(4, 4, bias=True, batchnorm=self.bn)
        self.int = self.encoder(4, 4, bias=True, batchnorm=self.bn)
        self.hnt = self.encoder(4, 4, bias=True, batchnorm=self.bn)
        # Maps the hidden state to one multiplier channel per input channel.
        self.ht = self.encoder(4, in_channel, bias=True, batchnorm=self.bn)

        self.Unet2d = UNet(in_channel, n_classes)
        self.tanh_function = nn.Tanh()
        self.sigmoid = nn.Sigmoid()

    def encoder(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1,
                bias=True, batchnorm=False):
        """Return a Conv2d -> [BatchNorm2d] -> ReLU block as an nn.Sequential."""
        layers = [nn.Conv2d(in_channels, out_channels, kernel_size,
                            stride=stride, padding=padding, bias=bias)]
        if batchnorm:
            layers.append(nn.BatchNorm2d(out_channels))
        layers.append(nn.ReLU())
        return nn.Sequential(*layers)

    def decoder(self, in_channels, out_channels, kernel_size, stride=1, padding=0,
                output_padding=0, bias=True, batchnorm=True):
        """Return a ConvTranspose2d -> [BatchNorm2d] -> ReLU block as an nn.Sequential."""
        layers = [nn.ConvTranspose2d(in_channels, out_channels, kernel_size, stride=stride,
                                     padding=padding, output_padding=output_padding, bias=bias)]
        if batchnorm:
            layers.append(nn.BatchNorm2d(out_channels))
        layers.append(nn.ReLU())
        return nn.Sequential(*layers)

    def _gru_step(self, features, hidden):
        """Apply one GRU-style update of `hidden` driven by `features`."""
        rt = self.sigmoid(self.irt(features) + self.hrt(hidden))
        zt = self.sigmoid(self.izt(features) + self.hzt(hidden))
        nt = self.tanh_function(self.int(features) + rt * self.hnt(hidden))
        return (1 - zt) * hidden + zt * nt

    def forward(self, x):
        """Segment a sequence of frames.

        Args:
            x: 4D tensor (frames, in_channel, H, W), processed in order along
                the first dimension.

        Returns:
            Tensor (frames, n_classes, H', W') holding the U-Net
            log-probabilities of every frame, on the same device and with the
            same dtype as the U-Net output.
        """
        if x.shape[0] == 0:
            # Nothing to process: keep the historical "empty result" contract.
            return x.new_zeros((0, self.n_classes, x.shape[2], x.shape[3]))

        outputs = []
        hidden = None      # recurrent state, created from the first frame
        multiplier = None  # per-pixel modulation computed from the previous frame
        for i in range(x.shape[0]):
            frame = x[i].unsqueeze(dim=0)
            if multiplier is not None:
                frame = frame * multiplier

            res, conv_feature = self.Unet2d(frame)
            features = self.conv(conv_feature)
            if hidden is None:
                # The first frame seeds the hidden state with its own features.
                hidden = features
            hidden = self._gru_step(features, hidden)

            # self.ht ends in a ReLU, so sigmoid(.) >= 0.5 and the multiplier
            # lies in [0, 1).
            multiplier = self.sigmoid(self.ht(hidden)) * 2.0 - 1.0
            outputs.append(res)

        # Concatenating (instead of writing into a torch.zeros buffer) keeps
        # the result on the model's device and in the model's dtype.
        return torch.cat(outputs, dim=0)












In [50]:
import torch
import torch.nn as nn
from torch.nn import init

def weights_init_normal(m):
    """Initialise one module in place with a N(0, 0.02) scheme.

    Dispatch is by class name: modules whose class name contains 'Conv' or
    'Linear' get weights drawn from N(0, 0.02); 'BatchNorm' modules get
    weights from N(1, 0.02) and a zero bias. Modules without the relevant
    parameter (e.g. BatchNorm with affine=False, or a container whose name
    merely contains 'Conv') are left untouched.

    Intended to be passed to ``nn.Module.apply``.
    """
    classname = m.__class__.__name__
    weight = getattr(m, 'weight', None)
    bias = getattr(m, 'bias', None)
    if classname.find('Conv') != -1 or classname.find('Linear') != -1:
        if weight is not None:
            init.normal_(weight.data, 0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        if weight is not None:
            init.normal_(weight.data, 1.0, 0.02)
        if bias is not None:
            init.constant_(bias.data, 0.0)


def weights_init_xavier(m):
    """Initialise one module in place with Xavier (Glorot) normal weights.

    Dispatch is by class name: modules whose class name contains 'Conv' or
    'Linear' get Xavier-normal weights (gain 1); 'BatchNorm' modules get
    weights from N(1, 0.02) and a zero bias. Modules without the relevant
    parameter (e.g. BatchNorm with affine=False, or a container whose name
    merely contains 'Conv') are left untouched.

    Intended to be passed to ``nn.Module.apply``.
    """
    classname = m.__class__.__name__
    weight = getattr(m, 'weight', None)
    bias = getattr(m, 'bias', None)
    if classname.find('Conv') != -1 or classname.find('Linear') != -1:
        if weight is not None:
            init.xavier_normal_(weight.data, gain=1)
    elif classname.find('BatchNorm') != -1:
        if weight is not None:
            init.normal_(weight.data, 1.0, 0.02)
        if bias is not None:
            init.constant_(bias.data, 0.0)


def weights_init_kaiming(m):
    """Initialise one module in place with Kaiming (He) normal weights.

    Dispatch is by class name: modules whose class name contains 'Conv' or
    'Linear' get Kaiming-normal weights (a=0, mode='fan_in'); 'BatchNorm'
    modules get weights from N(1, 0.02) and a zero bias. Modules without the
    relevant parameter (e.g. BatchNorm with affine=False, or a container
    whose name merely contains 'Conv') are left untouched.

    Intended to be passed to ``nn.Module.apply``.
    """
    classname = m.__class__.__name__
    weight = getattr(m, 'weight', None)
    bias = getattr(m, 'bias', None)
    if classname.find('Conv') != -1 or classname.find('Linear') != -1:
        if weight is not None:
            init.kaiming_normal_(weight.data, a=0, mode='fan_in')
    elif classname.find('BatchNorm') != -1:
        if weight is not None:
            init.normal_(weight.data, 1.0, 0.02)
        if bias is not None:
            init.constant_(bias.data, 0.0)


def weights_init_orthogonal(m):
    """Initialise one module in place with (semi-)orthogonal weights.

    Dispatch is by class name: modules whose class name contains 'Conv' or
    'Linear' get orthogonal weights (gain 1); 'BatchNorm' modules get weights
    from N(1, 0.02) and a zero bias. Modules without the relevant parameter
    (e.g. BatchNorm with affine=False, or a container whose name merely
    contains 'Conv') are left untouched.

    Intended to be passed to ``nn.Module.apply``.
    """
    classname = m.__class__.__name__
    weight = getattr(m, 'weight', None)
    bias = getattr(m, 'bias', None)
    if classname.find('Conv') != -1 or classname.find('Linear') != -1:
        if weight is not None:
            init.orthogonal_(weight.data, gain=1)
    elif classname.find('BatchNorm') != -1:
        if weight is not None:
            init.normal_(weight.data, 1.0, 0.02)
        if bias is not None:
            init.constant_(bias.data, 0.0)


def init_weights(net, init_type='normal'):
    """Apply the named initialisation scheme to `net` and all its submodules.

    Args:
        net: module to initialise (in place, via ``net.apply``).
        init_type: one of 'normal', 'xavier', 'kaiming' or 'orthogonal'.

    Raises:
        NotImplementedError: if `init_type` is not one of the names above.
    """
    supported = ('normal', 'xavier', 'kaiming', 'orthogonal')
    if init_type not in supported:
        raise NotImplementedError('initialization method [%s] is not implemented' % init_type)

    # Name -> per-module initialiser; built after the guard so an unknown
    # name is rejected before anything else happens.
    initialisers = {
        'normal': weights_init_normal,
        'xavier': weights_init_xavier,
        'kaiming': weights_init_kaiming,
        'orthogonal': weights_init_orthogonal,
    }
    net.apply(initialisers[init_type])




In [51]:
class unetConv2(nn.Module):
    """Stack of `n` Conv2d -> [BatchNorm2d] -> ReLU stages.

    The stages are registered as attributes ``conv1`` .. ``conv<n>``; the
    first maps `in_size` to `out_size` channels and the rest keep `out_size`.

    Args:
        in_size: input channels of the first stage.
        out_size: output channels of every stage.
        is_batchnorm: insert a BatchNorm2d between convolution and ReLU.
        n: number of stages.
        ks: convolution kernel size.
        stride: convolution stride.
        padding: convolution padding.
    """

    def __init__(self, in_size, out_size, is_batchnorm, n=2, ks=3, stride=1, padding=1):
        super(unetConv2, self).__init__()
        self.n = n
        self.ks = ks
        self.stride = stride
        self.padding = padding

        channels_in = in_size
        for idx in range(1, n + 1):
            layers = [nn.Conv2d(channels_in, out_size, ks, stride, padding)]
            if is_batchnorm:
                layers.append(nn.BatchNorm2d(out_size))
            layers.append(nn.ReLU(inplace=True))
            setattr(self, 'conv%d' % idx, nn.Sequential(*layers))
            # every stage after the first consumes the previous stage's output
            channels_in = out_size

        # Kaiming-initialise each registered stage.
        for stage in self.children():
            init_weights(stage, init_type='kaiming')

    def forward(self, inputs):
        """Feed `inputs` through conv1 .. conv<n> in order and return the result."""
        out = inputs
        for idx in range(self.n):
            stage = getattr(self, 'conv%d' % (idx + 1))
            out = stage(out)
        return out


class unetUp(nn.Module):
    """Decoder stage: upsample, concatenate with the skip connection, convolve.

    Args:
        in_size: channels of the deeper (lower-resolution) feature map.
        out_size: channels produced by this stage.
        is_deconv: True to upsample with a learned transposed convolution
            (which also reduces channels to `out_size`); False to use
            parameter-free bilinear upsampling (channels stay at `in_size`).

    The skip connection is assumed to carry `out_size` channels, as it does
    for every stage built by ``UNet`` in this notebook.
    """

    def __init__(self, in_size, out_size, is_deconv):
        super(unetUp, self).__init__()
        # Channels after concatenation: skip (out_size) + upsampled map.
        # The transposed conv emits out_size channels, which adds up to
        # in_size only because UNet halves the width at each stage; bilinear
        # upsampling keeps all in_size channels, so the conv must accept
        # in_size + out_size in that case.
        conv_in = in_size if is_deconv else in_size + out_size
        self.conv = unetConv2(conv_in, out_size, False)
        if is_deconv:
            self.up = nn.ConvTranspose2d(in_size, out_size, kernel_size=4, stride=2, padding=1)
        else:
            self.up = nn.UpsamplingBilinear2d(scale_factor=2)

        # initialise the blocks (unetConv2 already initialised itself)
        for m in self.children():
            if m.__class__.__name__.find('unetConv2') != -1: continue
            init_weights(m, init_type='kaiming')

    def forward(self, inputs1, inputs2):
        """Fuse a skip connection with an upsampled deeper feature map.

        Args:
            inputs1: skip-connection tensor (batch, C1, H1, W1).
            inputs2: deeper feature map; upsampled before fusion.

        Returns:
            Output of the double conv applied to the channel-wise
            concatenation of the (padded or cropped) skip and the upsampled map.
        """
        outputs2 = self.up(inputs2)
        # Height and width are aligned independently, and an odd difference
        # is split as floor/ceil so the sizes always match exactly. Negative
        # values crop the skip instead of padding it.
        diff_h = outputs2.size()[2] - inputs1.size()[2]
        diff_w = outputs2.size()[3] - inputs1.size()[3]
        # F.pad order for the last two dims: (left, right, top, bottom).
        padding = [diff_w // 2, diff_w - diff_w // 2,
                   diff_h // 2, diff_h - diff_h // 2]
        outputs1 = F.pad(inputs1, padding)
        return self.conv(torch.cat([outputs1, outputs2], 1))


