<a href="https://colab.research.google.com/github/MistreanuIonutCosmin/A-U-Net-Based-Discriminator-for-Generative-Adversarial-Networks/blob/master/A_U_Net_Based_Discriminator.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [0]:
import torch
from torch import nn
import torch.nn.functional as F

In [0]:
class Norms(object):
  BATCH_NORM = 'bn'
  BATCH_NORM_AFFINE = 'bnaffine'
  INSTANCE_NORM = 'in'
  INSTANCE_NORM_AFFINE = 'inaffine'
  NO_NORM = 'nonorm'

class NoNorm(nn.Module):
  def __init__(self, *input):
    super(NoNorm, self).__init__()

  def forward(self, x):
    return x

def get_norm(dim, norm=Norms.BATCH_NORM_AFFINE):
  if norm == Norms.BATCH_NORM_AFFINE:
    return lambda dim: nn.BatchNorm2d(dim, affine=True)
  elif norm == Norms.BATCH_NORM:
    return lambda dim: nn.BatchNorm2d(dim, affine=False)
  elif norm == Norms.INSTANCE_NORM_AFFINE:
    return lambda dim: nn.InstanceNorm2d(dim, affine=True)
  elif norm == Norms.INSTANCE_NORM:
    return lambda dim: nn.InstanceNorm2d(dim, affine=False)
  elif norm == Norms.NO_NORM:
    return NoNorm

In [0]:
norm = NoNorm()

In [0]:
class ResBlock2d(nn.Module):
    """
    Res block, preserve spatial resolution.
    """

    def __init__(self, in_features, out_features, kernel_size, padding, norm=Norms.BATCH_NORM_AFFINE):
        super(ResBlock2d, self).__init__()
        print(in_features)
        self.conv1 = nn.Conv2d(in_channels=int(in_features), out_channels=int(out_features), kernel_size=kernel_size,
                               padding=padding)
        self.conv2 = nn.Conv2d(in_channels=int(out_features), out_channels=int(out_features), kernel_size=kernel_size,
                               padding=padding)
        if in_features != out_features:
          self.shortcut = True
          self.convs = nn.Conv2d(in_channels=int(in_features), out_channels=int(out_features), kernel_size=1,
                                 padding=0)
        else:
          self.shortcut = False

        self.norm_class = get_norm(norm)
        self.norm1 = self.norm_class(int(in_features))
        self.norm2 = self.norm_class(int(out_features))

    def _shortcut(self, x):
      if self.shortcut:
        return self.convs(x)
      else:
        return x

    def forward(self, x):
        out = self.norm1(x)
        out = F.relu(out)
        out = self.conv1(out)
        out = self.norm2(out)
        out = F.relu(out)
        out = self.conv2(out)
        out += self._shortcut(x)
        return out
        

In [0]:
class ResDownBlock2d(nn.Module):
    """
    Downsampling block for use in encoder.
    """

    def __init__(self, in_features, out_features, kernel_size=3, padding=1, 
                 norm=Norms.BATCH_NORM_AFFINE):
        super(ResDownBlock2d, self).__init__()
        self.conv = ResBlock2d(in_features=in_features, out_features=out_features, kernel_size=kernel_size,
                              padding=padding, norm=norm)
        self.pool = nn.AvgPool2d(kernel_size=(2, 2))

    def forward(self, x):
        out = self.conv(x)
        out = self.pool(out)
        return out

class ResUpBlock2d(nn.Module):
    """
    Downsampling block for use in encoder.
    """

    def __init__(self, in_features, out_features, kernel_size=3, padding=1, 
                 norm=Norms.BATCH_NORM_AFFINE):
        super(ResUpBlock2d, self).__init__()
        self.conv = ResBlock2d(in_features=in_features, out_features=out_features, kernel_size=kernel_size,
                              padding=padding, norm=norm)

    def forward(self, x):
        out = self.conv(x)
        out = F.interpolate(out, scale_factor=2)
        return out

In [0]:
class Encoder(nn.Module):
    """
    Hourglass Encoder
    """

    def __init__(self, in_features, block_expansion=64, num_blocks=6, max_features=1024, norm=Norms.BATCH_NORM_AFFINE):
        super(Encoder, self).__init__()

        self.initial_conv = nn.Conv2d(in_features, block_expansion, 1)

        down_blocks = []
        for i in range(num_blocks):
            down_blocks.append(ResDownBlock2d(min(max_features, block_expansion * (2 ** i)),
                                           min(max_features, block_expansion * (2 ** (i + 1))),
                                           kernel_size=3, padding=1, norm=norm))
        self.down_blocks = nn.ModuleList(down_blocks)
        self.out_features = min(max_features, block_expansion * (2 ** (i + 1)))

    def forward(self, x):
        x = self.initial_conv(x)
        outs = [x]
        for down_block in self.down_blocks:
            outs.append(down_block(outs[-1]))
        return outs



In [0]:

class Decoder(nn.Module):
    """
    Hourglass Decoder
    """

    def __init__(self, in_features, block_expansion=64, num_blocks=6, max_features=1024, norm=Norms.BATCH_NORM_AFFINE):
        super(Decoder, self).__init__()

        up_blocks = []

        for i in range(num_blocks)[::-1]:
            in_filters = (1 if i == num_blocks - 1 else 2) * min(max_features, block_expansion * (2 ** (i+1)))
            out_filters = min(max_features, max(0, block_expansion * (2 ** i)))
            print(in_filters, out_filters)

            up_blocks.append(ResUpBlock2d(in_filters, out_filters, kernel_size=3, padding=1, norm=Norms.BATCH_NORM_AFFINE))
 
        self.up_blocks = nn.ModuleList(up_blocks)
        self.out_features = 2 * out_filters

    def forward(self, x):
        out = x.pop()
        print("initial", out.size())
        for i, up_block in enumerate(self.up_blocks):
            out = up_block(out)
            print("after up", out.size())
            skip = x.pop()
            print("from enc", skip.size())
            print("idx, out, skip", (i, out.size(), skip.size()))
            out = torch.cat([out, skip], dim=1)
            print("after cat", out.size())
        print("out", out.size())
        return out

In [0]:
class HourglassDiscriminator(nn.Module):
    """
    Hourglass architecture.
    """

    def __init__(self, in_features, block_expansion=64,  num_blocks=5, max_features=1024, 
                 norm=Norms.BATCH_NORM_AFFINE):
        super(HourglassDiscriminator, self).__init__()
        self.encoder = Encoder(in_features, block_expansion, num_blocks, max_features, 
                               norm=norm)
        
        self.relu = nn.ReLU()
        self.linear = nn.Linear(self.encoder.out_features, 1)

        self.decoder = Decoder(in_features, block_expansion, num_blocks, max_features,
                               norm=norm)
        
        self.final_conv = nn.Conv2d(self.decoder.out_features, 1, kernel_size=1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        enc_out = self.encoder(x)

        enc_discrim_out = self.relu(enc_out[-1])
        enc_discrim_out = torch.sum(enc_discrim_out, dim=(2, 3))
        enc_discrim_out = self.linear(enc_discrim_out)

        dec_discrim_out = self.sigmoid(self.final_conv(self.decoder(enc_out)))

        return enc_discrim_out, dec_discrim_out


In [0]:
discrim = HourglassDiscriminator(3)
img = torch.rand((1, 3, 256, 256))
enc_out, dec_out = discrim(img)
print(dec_out.size(), img.size())
assert dec_out.size()[2:] == img.size()[2:]

64
128
256
512
1024
1024 1024
1024
2048 512
2048
1024 256
1024
512 128
512
256 64
256
initial torch.Size([1, 1024, 8, 8])
after up torch.Size([1, 1024, 16, 16])
from enc torch.Size([1, 1024, 16, 16])
idx, out, skip (0, torch.Size([1, 1024, 16, 16]), torch.Size([1, 1024, 16, 16]))
after cat torch.Size([1, 2048, 16, 16])
after up torch.Size([1, 512, 32, 32])
from enc torch.Size([1, 512, 32, 32])
idx, out, skip (1, torch.Size([1, 512, 32, 32]), torch.Size([1, 512, 32, 32]))
after cat torch.Size([1, 1024, 32, 32])
after up torch.Size([1, 256, 64, 64])
from enc torch.Size([1, 256, 64, 64])
idx, out, skip (2, torch.Size([1, 256, 64, 64]), torch.Size([1, 256, 64, 64]))
after cat torch.Size([1, 512, 64, 64])
after up torch.Size([1, 128, 128, 128])
from enc torch.Size([1, 128, 128, 128])
idx, out, skip (3, torch.Size([1, 128, 128, 128]), torch.Size([1, 128, 128, 128]))
after cat torch.Size([1, 256, 128, 128])
after up torch.Size([1, 64, 256, 256])
from enc torch.Size([1, 64, 256, 256])
idx, out