# In [7]:  (Jupyter cell marker left over from the notebook export)
import torch
import torch.nn as nn

class Critic(nn.Module):
    """Convolutional critic that reduces an image batch to one score per sample.

    The layer comments assume a ``N x channels_img x 64 x 64`` input: four
    stride-2 convolutions halve the spatial size down to 4 x 4, and a final
    4 x 4 convolution collapses it to ``N x 1 x 1 x 1``. No activation is
    applied to the final score.

    Args:
        channels_img: Number of channels in the input images.
        n_features_map_critic: Channel width of the first convolution; each
            following block doubles it.
    """

    def __init__(self, channels_img, n_features_map_critic):
        super().__init__()
        width = n_features_map_critic

        # Stem: plain convolution + LeakyReLU, no normalisation (64 -> 32).
        layers = [
            nn.Conv2d(channels_img, width, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2),
        ]

        # Three downsampling blocks: 32 -> 16 -> 8 -> 4, doubling channels.
        for in_mult, out_mult in ((1, 2), (2, 4), (4, 8)):
            layers.append(
                self._block(
                    width * in_mult, width * out_mult, kernel_s=4, stride=2, pad=1
                )
            )

        # Head: the 4 x 4 feature map becomes a single 1 x 1 score.
        layers.append(nn.Conv2d(width * 8, 1, kernel_size=4, stride=2, padding=0))

        self.critic = nn.Sequential(*layers)

    def _block(self, in_chan, out_chan, kernel_s, stride, pad):
        """Return a Conv2d -> BatchNorm2d -> LeakyReLU(0.2) stage.

        The convolution has no bias term because the BatchNorm layer that
        follows it has its own learnable shift.
        """
        conv = nn.Conv2d(in_chan, out_chan, kernel_s, stride, pad, bias=False)
        norm = nn.BatchNorm2d(out_chan)
        act = nn.LeakyReLU(0.2)
        return nn.Sequential(conv, norm, act)

    def forward(self, x):
        """Run ``x`` through the critic stack and return the raw scores."""
        scores = self.critic(x)
        return scores
    
    

class Generator(nn.Module):
    """Transposed-convolution generator mapping noise vectors to images.

    The layer comments assume a ``N x z_dim x 1 x 1`` input. The first block
    expands it to 4 x 4, three further blocks double the spatial size up to
    32 x 32, and a final transposed convolution followed by Tanh yields
    ``N x channels_img x 64 x 64`` with values in [-1, 1].

    Args:
        z_dim: Dimension of the input noise.
        channels_img: Number of channels in the generated images.
        n_features_map_gen: Base channel width; the first block uses 16x this
            value and every later block halves it.
    """

    def __init__(self, z_dim, channels_img, n_features_map_gen):
        super().__init__()
        width = n_features_map_gen

        # Projection of the noise vector: 1 x 1 -> 4 x 4.
        stages = [self._block(z_dim, width * 16, kernel_s=4, stride=1, pad=0)]

        # Upsampling blocks: 4 -> 8 -> 16 -> 32, halving channels each time.
        for in_mult, out_mult in ((16, 8), (8, 4), (4, 2)):
            stages.append(
                self._block(
                    width * in_mult, width * out_mult, kernel_s=4, stride=2, pad=1
                )
            )

        # Output layer: 32 -> 64. Tanh keeps the output in [-1, 1], the same
        # range the original comment states the inputs are normalised to.
        stages.append(
            nn.ConvTranspose2d(
                width * 2,
                out_channels=channels_img,
                kernel_size=4,
                stride=2,
                padding=1,
            )
        )
        stages.append(nn.Tanh())

        self.gen = nn.Sequential(*stages)

    def _block(self, in_chan, out_chan, kernel_s, stride, pad):
        """Return a ConvTranspose2d -> BatchNorm2d -> ReLU upsampling stage.

        The transposed convolution has no bias term because the BatchNorm
        layer that follows it has its own learnable shift.
        """
        upsample = nn.ConvTranspose2d(
            in_chan, out_chan, kernel_s, stride, pad, bias=False
        )
        norm = nn.BatchNorm2d(out_chan)
        act = nn.ReLU()
        return nn.Sequential(upsample, norm, act)

    def forward(self, x):
        """Generate images from the noise batch ``x``."""
        images = self.gen(x)
        return images

# Initialize the weights as prescribed by the paper.
def initialize_weights(model):
    """Re-initialise convolution and BatchNorm parameters of ``model`` in place.

    * ``nn.Conv2d`` / ``nn.ConvTranspose2d`` kernels are drawn from
      N(0, 0.02).
    * ``nn.BatchNorm2d`` scales are drawn from N(1, 0.02) and the shifts are
      set to zero. Centring the scale at 1 rather than 0 keeps the normalised
      activations at roughly unit magnitude; a scale centred at 0 would
      shrink every BatchNorm output to about 2% of its normalised value.

    All other module types and convolution biases are left untouched.
    BatchNorm layers built with ``affine=False`` have no weight and are
    skipped.

    Args:
        model: Any ``nn.Module``; all of its sub-modules are visited.
    """
    for m in model.modules():
        if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
            # nn.init functions already run under no_grad, so the parameter
            # can be passed directly instead of going through `.data`.
            nn.init.normal_(m.weight, 0.0, 0.02)
        elif isinstance(m, nn.BatchNorm2d) and m.weight is not None:
            nn.init.normal_(m.weight, 1.0, 0.02)
            nn.init.zeros_(m.bias)

# Sanity check for the output dimensions of the model architectures.
def test():
    """Smoke-test the output shapes of ``Critic`` and ``Generator``.

    Builds both networks with a base width of 8, applies
    ``initialize_weights`` to each, and checks that the critic returns one
    value per sample and the generator returns images with the expected
    shape. Prints a confirmation message when both checks pass.
    """
    batch, img_channels, height, width = 8, 3, 64, 64
    noise_dim = 100

    # Critic: one scalar per input image.
    images = torch.randn((batch, img_channels, height, width))
    critic = Critic(channels_img=img_channels, n_features_map_critic=8)
    initialize_weights(critic)
    scores = critic(images)
    assert scores.shape == (batch, 1, 1, 1)

    # Generator: noise in, full-size images out.
    gen = Generator(noise_dim, img_channels, 8)
    initialize_weights(gen)
    noise = torch.randn((batch, noise_dim, 1, 1))
    fakes = gen(noise)
    assert fakes.shape == (batch, img_channels, height, width)

    print("Funca!!")
    
    
# test()

# Critic_WGANGP: the critic variant for WGAN-GP, which uses InstanceNorm instead of BatchNorm.
class Critic_WGANGP(nn.Module):
    """Critic variant whose inner blocks use InstanceNorm2d, not BatchNorm2d.

    The layer comments assume a ``N x channels_img x 64 x 64`` input: four
    stride-2 convolutions reduce it to 4 x 4 and a final 4 x 4 convolution
    produces ``N x 1 x 1 x 1``. No activation is applied to the final score.

    Args:
        channels_img: Number of channels in the input images.
        n_features_map_critic: Channel width of the first convolution; each
            following block doubles it.
    """

    def __init__(self, channels_img, n_features_map_critic):
        super().__init__()
        f = n_features_map_critic

        # Stem (64 -> 32): convolution + LeakyReLU without normalisation.
        stem_conv = nn.Conv2d(channels_img, f, kernel_size=4, stride=2, padding=1)
        stem_act = nn.LeakyReLU(0.2)

        # Downsampling blocks: 32 -> 16 -> 8 -> 4.
        down_16 = self._block(f, f * 2, kernel_s=4, stride=2, pad=1)
        down_8 = self._block(f * 2, f * 4, kernel_s=4, stride=2, pad=1)
        down_4 = self._block(f * 4, f * 8, kernel_s=4, stride=2, pad=1)

        # Head: collapse the 4 x 4 map into a single 1 x 1 score.
        head = nn.Conv2d(f * 8, 1, kernel_size=4, stride=2, padding=0)

        self.critic_WGANGP = nn.Sequential(
            stem_conv, stem_act, down_16, down_8, down_4, head
        )

    def _block(self, in_chan, out_chan, kernel_s, stride, pad):
        """Return a Conv2d -> InstanceNorm2d(affine) -> LeakyReLU(0.2) stage.

        The convolution has no bias term; the affine InstanceNorm layer that
        follows it has its own learnable shift.
        """
        conv = nn.Conv2d(in_chan, out_chan, kernel_s, stride, pad, bias=False)
        norm = nn.InstanceNorm2d(out_chan, affine=True)
        act = nn.LeakyReLU(0.2)
        return nn.Sequential(conv, norm, act)

    def forward(self, x):
        """Run ``x`` through the critic stack and return the raw scores."""
        scores = self.critic_WGANGP(x)
        return scores