In [192]:
from collections import OrderedDict
import torch
import torch.nn as nn

In [193]:
class DisInputBlock(nn.Module):
  """Input stage of the discriminator: 3x3 conv -> (optional BatchNorm) -> LeakyReLU(0.2).

  When ``is_batch_norm`` is True the convolution uses stride 2 (roughly halving
  the spatial size) and is followed by ``BatchNorm2d``; otherwise it uses
  stride 1 with no normalization. Sub-module names are ``conv``,
  ``batch_norm`` (only when enabled) and ``lReLU``.

  Args:
    in_channels: number of channels of the incoming feature map.
    out_channels: number of channels produced by the convolution.
    is_batch_norm: add BatchNorm2d and switch the conv to stride 2.
  """

  def __init__(self, in_channels = None, out_channels = None, is_batch_norm=False):
    super().__init__()

    self.in_channels, self.out_channels = in_channels, out_channels
    self.is_batch_norm = is_batch_norm

    # Base conv geometry; the effective stride is doubled when batch norm is on.
    self.kernel_size, self.stride, self.padding = 3, 1, 1

    self.model = self.input_block()

  def input_block(self):
    """Build the named ``nn.Sequential`` for this stage."""
    conv_stride = self.stride * 2 if self.is_batch_norm else self.stride

    stages = [("conv", nn.Conv2d(self.in_channels, self.out_channels,
                                 kernel_size=self.kernel_size,
                                 stride=conv_stride,
                                 padding=self.padding))]
    if self.is_batch_norm:
      stages.append(("batch_norm", nn.BatchNorm2d(self.out_channels)))
    stages.append(("lReLU", nn.LeakyReLU(negative_slope=0.2)))

    return nn.Sequential(OrderedDict(stages))

  def forward(self, x):
    return self.model(x)

In [194]:
class DisFeatureBlock(nn.Module):
  """Down-sampling feature stage: 3x3 stride-2 conv -> BatchNorm2d -> LeakyReLU(0.2).

  Sub-modules are registered positionally (``0``, ``1``, ``2``) inside
  ``self.model``.

  Args:
    in_channels: number of channels of the incoming feature map.
    out_channels: number of channels produced by the convolution.
  """

  def __init__(self, in_channels = None, out_channels = None):
    super().__init__()

    self.in_channels, self.out_channels = in_channels, out_channels

    # Base conv geometry; feature_block doubles the stride to downsample.
    self.kernel_size, self.stride, self.padding = 3, 1, 1

    self.model = self.feature_block()

  def feature_block(self):
    """Build the positional ``nn.Sequential`` for this stage."""
    return nn.Sequential(
        nn.Conv2d(self.in_channels, self.out_channels,
                  kernel_size=self.kernel_size,
                  stride=2 * self.stride,
                  padding=self.padding),
        nn.BatchNorm2d(self.out_channels),
        nn.LeakyReLU(negative_slope=0.2),
    )

  def forward(self, x):
    return self.model(x)

In [195]:
class DisOutBlock(nn.Module):
  """Head stage: 1x1 convolution (a per-position linear map over channels),
  optionally followed by LeakyReLU(0.2).

  Args:
    in_channels: number of input channels.
    out_channels: number of output channels.
    use_activation: append ``LeakyReLU(0.2)`` after the conv (default True,
      matching the original behaviour). Pass False for a final layer whose
      output should stay linear, e.g. raw logits for ``BCEWithLogitsLoss``.

  Sub-modules are named ``out_conv`` and, when enabled, ``lReLU``.
  """

  def __init__(self, in_channels=None, out_channels=None, use_activation=True):
    super(DisOutBlock, self).__init__()

    self.in_channels = in_channels
    self.out_channels = out_channels
    self.use_activation = use_activation
    self.kernel_size = 1

    self.model = self.out_block()

  def out_block(self):
    """Build the named ``nn.Sequential`` for this stage."""
    layers = OrderedDict()

    layers["out_conv"] = nn.Conv2d(self.in_channels, self.out_channels, self.kernel_size)
    if self.use_activation:
      layers["lReLU"] = nn.LeakyReLU(negative_slope=0.2)

    return nn.Sequential(layers)

  def forward(self, x):
    return self.model(x)

In [196]:
class Discriminator(nn.Module):
  """SRGAN-style discriminator built from the DisInputBlock / DisFeatureBlock /
  DisOutBlock stages defined above.

  Pipeline: two input blocks (the second one stride-2 with BatchNorm), then
  ``total_repetitive`` stride-2 feature blocks whose width doubles on every
  even index, then global average pooling, then two 1x1-conv head stages.
  The forward output has shape ``(N, 1, 1, 1)``.

  Args:
    image_channels: channels of the input image (3 in the original notebook).
    base_channels: width of the input stage.
    total_repetitive: number of DisFeatureBlock stages.
    hidden_channels: width of the first 1x1 head stage.

  With the defaults the architecture, parameter count and module names are
  identical to the original hard-coded version (feature width ends at 512).
  """

  def __init__(self, image_channels=3, base_channels=64, total_repetitive=6, hidden_channels=1024):
    super(Discriminator, self).__init__()

    self.total_repetitive = total_repetitive
    # Kept for backward compatibility; not used to build the network.
    self.out_channels = base_channels * 2

    self.input = nn.Sequential(
        DisInputBlock(in_channels=image_channels, out_channels=base_channels, is_batch_norm=False),
        DisInputBlock(in_channels=base_channels, out_channels=base_channels, is_batch_norm=True),
    )

    channels = base_channels
    layers = []
    for index in range(total_repetitive):
      # Even-indexed blocks double the width, odd-indexed blocks keep it.
      next_channels = channels * 2 if index % 2 == 0 else channels
      layers.append(DisFeatureBlock(channels, next_channels))
      channels = next_channels
    self.features_block = nn.Sequential(*layers)

    # Final feature width (the original left this in self.in_channels too).
    self.in_channels = channels

    self.avg_pool = nn.AdaptiveAvgPool2d(output_size=1)

    # The head width follows the feature stack instead of a hard-coded 512,
    # so changing total_repetitive/base_channels keeps the model consistent.
    # NOTE(review): the last DisOutBlock applies LeakyReLU to the 1-channel
    # score, so the output is neither a raw logit nor a probability; confirm
    # this matches the loss used in training.
    self.out = nn.Sequential(
        DisOutBlock(channels, hidden_channels),
        DisOutBlock(hidden_channels, 1),
    )

  def forward(self, x):
    x = self.input(x)

    x = self.features_block(x)

    x = self.avg_pool(x)

    return self.out(x)

In [197]:
netD = Discriminator()

# Shape sanity check on a random batch. Nothing is backpropagated here, so run
# without autograd (no graph kept for a 64x3x256x256 input) and in eval mode so
# the BatchNorm running statistics are not updated with noise before training.
# The previous train/eval mode is restored afterwards.
was_training = netD.training
netD.eval()
with torch.no_grad():
  dummy_scores = netD(torch.randn(64, 3, 256, 256))
netD.train(was_training)

dummy_scores.size()

torch.Size([64, 1, 1, 1])

In [198]:
total_params = sum(p.numel() for p in netD.parameters())  # all parameters, trainable or not
total_params

5215425

In [199]:
# Display the module tree: every sub-module with its registered name and config.
netD

Discriminator(
  (input): Sequential(
    (0): DisInputBlock(
      (model): Sequential(
        (conv): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (lReLU): LeakyReLU(negative_slope=0.2)
      )
    )
    (1): DisInputBlock(
      (model): Sequential(
        (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
        (batch_norm): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (lReLU): LeakyReLU(negative_slope=0.2)
      )
    )
  )
  (features_block): Sequential(
    (0): DisFeatureBlock(
      (model): Sequential(
        (0): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
        (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): LeakyReLU(negative_slope=0.2)
      )
    )
    (1): DisFeatureBlock(
      (model): Sequential(
        (0): Conv2d(128, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
        (1): BatchN