# The model: UNetSCDC segmentation network and discriminators


In [None]:
class DDCONV_Block(nn.Module):
    # The following class is taken from https://github.com/samleoqh/DDCM-Semantic-Segmentation-PyTorch/
    """Densely connected stack of dilated convolutions (DDCM block).

    Each dilation rate gets its own ``Conv2d -> PReLU -> BatchNorm2d`` branch.
    A branch produces ``out_dim`` channels, which are concatenated in front of
    the running feature map, so branch ``i`` consumes ``in_dim + i * out_dim``
    channels. A closing ``1x1 Conv2d -> PReLU -> BatchNorm2d`` fuses the
    ``in_dim + len(rates) * out_dim`` accumulated channels into
    ``self.out_dim`` channels. With an odd ``kernel`` the chosen padding keeps
    height and width unchanged.

    Args:
        in_dim: channels of the input tensor.
        out_dim: channels produced by each dilated branch.
        rates: dilation rate of each branch, in application order.
        kernel: spatial kernel size of the dilated convolutions.
        bias: whether the convolutions carry a bias term.
        extend_dim: if true and ``rates`` is non-empty, the block outputs
            ``out_dim * len(rates)`` channels instead of ``out_dim``.
    """

    def __init__(self, in_dim, out_dim, rates, kernel=3, bias=False, extend_dim=False):
        super().__init__()
        self.num = len(rates)
        self.in_dim = in_dim
        widen = extend_dim and self.num > 0
        self.out_dim = out_dim * self.num if widen else out_dim

        branches = []
        for idx, rate in enumerate(rates):
            dilated = nn.Conv2d(
                in_dim + idx * out_dim,
                out_dim,
                kernel_size=kernel,
                dilation=rate,
                padding=rate * (kernel - 1) // 2,
                bias=bias,
            )
            branches.append(nn.Sequential(dilated, nn.PReLU(), nn.BatchNorm2d(out_dim)))
        # With no rates the attribute stays a plain (empty) list.
        self.features = nn.ModuleList(branches) if self.num > 0 else branches

        fused_channels = in_dim + out_dim * self.num
        self.conv1x1_out = nn.Sequential(
            nn.Conv2d(fused_channels, self.out_dim, kernel_size=1, bias=bias),
            nn.PReLU(),
            nn.BatchNorm2d(self.out_dim),
        )

    def forward(self, x):
        """Apply every dilated branch densely, then the 1x1 fusion layer."""
        stacked = x
        for branch in self.features:
            stacked = torch.cat((branch(stacked), stacked), dim=1)
        return self.conv1x1_out(stacked)

def crop(e, d):
    """Center-crop ``e`` to the spatial size of ``d``.

    Used to align a (larger) encoder skip map with the upsampled decoder map
    before concatenation. Height and width are cropped independently, so
    non-square maps are handled correctly.

    Args:
        e: tensor indexed as (N, C, H_e, W_e); the map being cropped.
        d: tensor indexed as (N, C', H_d, W_d); only its height/width are read.

    Returns:
        The slice ``e[:, :, top:top + H_d, left:left + W_d]``, centered (for an
        odd size difference the extra row/column is dropped at the end).

    Raises:
        ValueError: if ``d`` is taller or wider than ``e``.
    """
    h, w = d.size(2), d.size(3)
    if h > e.size(2) or w > e.size(3):
        raise ValueError(
            f"cannot crop {tuple(e.shape[2:])} down to larger size {(h, w)}"
        )
    top = (e.size(2) - h) // 2
    left = (e.size(3) - w) // 2
    return e[:, :, top:top + h, left:left + w]

class UNetSCDC(nn.Module):
    """U-Net style segmentation network with a DDCM block as the bottleneck.

    Encoder: four levels of two unpadded 3x3 conv + ReLU (64, 128, 256, 512
    channels), separated by 2x2 max-pooling. The deepest encoder map goes
    through ``DDCONV_Block`` (512 -> 512 channels, dilation rates 2, 4, 8, 16).

    Decoder: three levels. Each one upsamples 2x with a transposed conv,
    concatenates the center-cropped encoder map of matching depth, and applies
    two unpadded 3x3 conv + ReLU.

    Output: a 1x1 conv gives ``out_channels`` maps, which are bilinearly
    resized to the input's height/width and passed through a sigmoid.

    Because every 3x3 conv is unpadded, each input side must be at least 92
    pixels for the decoder feature maps to stay non-empty.

    Args:
        in_channels: channels of the input image (default 3, the previously
            hard-coded value).
        out_channels: channels of the output map (default 1, the previously
            hard-coded value).
    """

    @staticmethod
    def _double_conv(cin, cout):
        """Two unpadded 3x3 conv + ReLU layers: ``cin`` -> ``cout`` -> ``cout``."""
        return nn.Sequential(
            nn.Conv2d(in_channels=cin, out_channels=cout, kernel_size=3),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_channels=cout, out_channels=cout, kernel_size=3),
            nn.ReLU(inplace=True),
        )

    def __init__(self, in_channels=3, out_channels=1):
        super().__init__()
        # Submodules are created in the original order and under the original
        # attribute names. This includes the convt_2 / decoder_conv_2 numbering
        # left over from a removed fifth level. Keeping both means seeded
        # initialisation and state_dict keys are unchanged.
        self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.ddconv = DDCONV_Block(in_dim=512, out_dim=512, rates=[2, 4, 8, 16])

        self.encoder_conv_1 = self._double_conv(in_channels, 64)
        self.encoder_conv_2 = self._double_conv(64, 128)
        self.encoder_conv_3 = self._double_conv(128, 256)
        self.encoder_conv_4 = self._double_conv(256, 512)

        self.convt_2 = nn.ConvTranspose2d(in_channels=512, out_channels=256, kernel_size=2, stride=2)
        self.convt_3 = nn.ConvTranspose2d(in_channels=256, out_channels=128, kernel_size=2, stride=2)
        self.convt_4 = nn.ConvTranspose2d(in_channels=128, out_channels=64, kernel_size=2, stride=2)

        # Decoder input width is doubled by the skip-connection concatenation.
        self.decoder_conv_2 = self._double_conv(512, 256)
        self.decoder_conv_3 = self._double_conv(256, 128)
        self.decoder_conv_4 = self._double_conv(128, 64)
        self.decoder_conv_5 = nn.Conv2d(in_channels=64, out_channels=out_channels, kernel_size=1)

    def forward(self, x):
        """Return sigmoid scores with the same height/width as ``x``.

        Assumes ``x`` is indexed as (N, C, H, W).
        """
        e1 = self.encoder_conv_1(x)
        e2 = self.encoder_conv_2(self.maxpool(e1))
        e3 = self.encoder_conv_3(self.maxpool(e2))
        e4 = self.encoder_conv_4(self.maxpool(e3))

        d = self.ddconv(e4)
        stages = (
            (self.convt_2, self.decoder_conv_2, e3),
            (self.convt_3, self.decoder_conv_3, e2),
            (self.convt_4, self.decoder_conv_4, e1),
        )
        for upsample, decode, skip in stages:
            t = upsample(d)
            # Encoder maps are larger (unpadded convs); crop them to match.
            d = decode(torch.cat((t, crop(skip, t)), 1))

        d = self.decoder_conv_5(d)
        # The unpadded convs shrink the map; resize back to the input size.
        d = F.interpolate(d, size=x.shape[-2:], mode='bilinear', align_corners=False)
        return torch.sigmoid(d)

In [None]:
# Segmentation model instance (default 3 input / 1 output channel). Presumably
# the generator that the discriminators defined next are trained against —
# confirm against the training cell.
unet = UNetSCDC()
ndf = 4  # channel width after the first discriminator conv


def _make_discriminator(in_channels, base_width, flat_features=25):
    """Build a conv classifier that maps an image-like tensor to one sigmoid score per sample.

    Layers, in order:
      - A 4x4 stride-2 conv (``in_channels`` -> ``base_width``) + LeakyReLU(0.2).
      - Three more 4x4 stride-2 convs, each doubling the width up to
        ``8 * base_width``. Each is followed by BatchNorm and LeakyReLU(0.2).
      - A 4x4 unpadded conv down to one channel.
      - Flatten, ``Linear(flat_features, 1)`` and Sigmoid.

    Args:
        in_channels: channels of the discriminator input.
        base_width: channels after the first conv (``ndf``).
        flat_features: number of elements in the flattened final map. The
            default 25 is a 5x5 map, which is what e.g. 128x128 inputs produce;
            other input sizes may need a different value.

    Returns:
        An ``nn.Sequential`` with the same 15-layer order, and therefore the
        same ``state_dict`` keys, as the original inline definitions.
    """
    w = base_width
    layers = [
        nn.Conv2d(in_channels, w, 4, 2, 1, bias=False),
        nn.LeakyReLU(0.2, inplace=True),
    ]
    for mult in (1, 2, 4):
        layers += [
            nn.Conv2d(w * mult, w * mult * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(w * mult * 2),
            nn.LeakyReLU(0.2, inplace=True),
        ]
    layers += [
        nn.Conv2d(w * 8, 1, 4, 1, 0, bias=False),
        nn.Flatten(),
        nn.Linear(flat_features, 1),
        nn.Sigmoid(),
    ]
    return nn.Sequential(*layers)


# The existing (misspelled) names are kept because other cells may use them.
# Presumably 4 channels = 3-channel image + 1-channel mask, and 1 channel =
# mask alone — confirm against the training cell.
descriminator_1 = _make_discriminator(4, ndf)
descriminator_2 = _make_discriminator(1, ndf)