In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

from pathlib import Path

import sys
import import_ipynb
# Prepend ./notebooks (resolved to an absolute path) to sys.path so that the
# `import globals` below can locate globals.ipynb through the import_ipynb hook.
# NOTE(review): `dir` shadows the builtin dir(); kept because later cells may use it.
dir = Path('notebooks')
sys.path[:0] = [str(dir.resolve())]
import globals


importing Jupyter notebook from globals.ipynb


In [2]:
# Hyperparameters for the WGAN-GP cells below.
CRITIC_STEP = 1  # critic updates per generator update (loop count in wgan.forward)
Z_DIM = 512  # length of each latent vector fed to the Generator
LEARNING_RATE_CRITIC = 1e-4  # presumably the critic optimizer's lr (optimizer is built elsewhere)
LEARNING_RATE_GENERATOR = 1e-4  # presumably the generator optimizer's lr
ADAM_BETA_1, ADAM_BETA_2 = 0.5, 0.999  # presumably Adam (beta1, beta2); not used in this chunk
GP_WEIGHT = 25.0  # multiplier on the gradient-penalty term of the critic loss

In [3]:
#Critic

class Critic(nn.Module):
    """Convolutional critic producing an unbounded score per image (no sigmoid).

    Four stride-2 ``Conv2d`` stages (64 -> 128 -> 256 -> 512 channels), each
    followed by LeakyReLU(0.1) and Dropout(0.1), then a final 4x4 stride-2
    ``Conv2d`` with no padding down to one channel. The last feature map is
    flattened to ``(batch, C*H*W)``, which is ``(batch, 1)`` when that map is
    1x1 (true for 64x64 inputs; confirm against the dataset image size).
    """

    def __init__(self):
        super().__init__()
        # Submodules are registered in the same order as before (conv1..conv5,
        # then dropout, leakyReLU) so state_dict keys and init order match.
        self.conv1 = nn.Conv2d(globals.CHANNELS, 64, kernel_size=4, stride=2, padding=1)
        self.conv2 = nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1)
        self.conv3 = nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1)
        self.conv4 = nn.Conv2d(256, 512, kernel_size=4, stride=2, padding=1)
        self.conv5 = nn.Conv2d(512, 1, kernel_size=4, stride=2, padding=0)

        self._initialize_weights()

        # Parameter-free layers, shared by every stage in forward().
        self.dropout = nn.Dropout(0.1)
        self.leakyReLU = nn.LeakyReLU(negative_slope=0.1)

    def _initialize_weights(self):
        """Redraw every Conv2d weight from N(0, 0.05^2); biases keep PyTorch defaults."""
        conv_layers = (module for module in self.modules() if isinstance(module, nn.Conv2d))
        for conv in conv_layers:
            nn.init.normal_(conv.weight, mean=0.0, std=0.05)

    def forward(self, imgs):
        """Score a batch of images; returns the flattened output of ``conv5``."""
        features = imgs
        for conv in (self.conv1, self.conv2, self.conv3, self.conv4):
            features = self.dropout(self.leakyReLU(conv(features)))
        scores = self.conv5(features)
        return scores.flatten(start_dim=1, end_dim=3)
       


In [4]:
# Generator

class Generator(nn.Module):
    """DCGAN-style generator mapping latent vectors to images in [-1, 1].

    The latent batch is reshaped to ``(batch, z_dim, 1, 1)`` and upsampled by
    five stride-2 ``ConvTranspose2d`` layers (spatial side 1 -> 4 -> 8 -> 16 ->
    32 -> 64). The first four are followed by BatchNorm and LeakyReLU(0.1), and
    the last by ``tanh``.

    Args:
        z_dim: latent vector length; defaults to the notebook constant ``Z_DIM``.
        channels: output image channels; defaults to ``globals.CHANNELS``.

    Note:
        The BatchNorm layers are registered submodules (``bn1`` .. ``bn4``).
        Their affine parameters are therefore trained, saved and moved with
        ``.to(device)``, and their running statistics are used in ``eval()``
        mode. A checkpoint without ``bn*`` keys needs
        ``load_state_dict(..., strict=False)``.
    """

    def __init__(self, z_dim=None, channels=None):
        super(Generator, self).__init__()
        if z_dim is None:
            z_dim = Z_DIM
        if channels is None:
            channels = globals.CHANNELS

        self.conv2dtranspose1 = nn.ConvTranspose2d(in_channels = z_dim, out_channels = 512, kernel_size = 4, stride = 2, padding = 0, bias = False)
        self.conv2dtranspose2 = nn.ConvTranspose2d(in_channels = 512, out_channels = 256, kernel_size = 4, stride = 2, padding = 1, bias = False)
        self.conv2dtranspose3 = nn.ConvTranspose2d(in_channels = 256, out_channels = 128, kernel_size = 4, stride = 2, padding = 1, bias = False)
        self.conv2dtranspose4 = nn.ConvTranspose2d(in_channels = 128, out_channels = 64, kernel_size = 4, stride = 2, padding = 1, bias = False)
        self.conv2dtranspose5 = nn.ConvTranspose2d(in_channels = 64, out_channels = channels, kernel_size = 4, stride = 2, padding = 1, bias = False)

        # Built once here rather than inside forward(). Layers created per call
        # are never optimized, always use batch statistics and live on the CPU.
        self.bn1 = nn.BatchNorm2d(512)
        self.bn2 = nn.BatchNorm2d(256)
        self.bn3 = nn.BatchNorm2d(128)
        self.bn4 = nn.BatchNorm2d(64)

        self.leakyReLU = nn.LeakyReLU(negative_slope = .1)

        # Run after all layers exist so the BatchNorm branch below is reached.
        self._initialize_weights()

    def _initialize_weights(self):
        """DCGAN-style init: deconv weights ~ N(0, 0.1^2), BN gamma ~ N(1, 0.02^2), BN beta = 0."""
        for m in self.modules():
            if isinstance(m, nn.ConvTranspose2d):
                nn.init.normal_(m.weight, 0.0, .1)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.normal_(m.weight, 1.0, 0.02)
                nn.init.constant_(m.bias, 0)

    def forward(self, latent_space):
        """Map latents ``(batch, z_dim)`` (or any shape reshapeable to it) to images.

        Returns a tensor of shape ``(batch, channels, 64, 64)`` with values in [-1, 1].
        """
        batch_size = latent_space.shape[0]
        x = torch.reshape(input = latent_space, shape = (batch_size, self.conv2dtranspose1.in_channels, 1, 1))

        stages = (
            (self.conv2dtranspose1, self.bn1),
            (self.conv2dtranspose2, self.bn2),
            (self.conv2dtranspose3, self.bn3),
            (self.conv2dtranspose4, self.bn4),
        )
        for deconv, norm in stages:
            x = self.leakyReLU(norm(deconv(x)))

        x = self.conv2dtranspose5(x)
        return torch.tanh(x)
    



In [5]:
#Class WGAN

class wgan(nn.Module):
    """WGAN-GP training wrapper; each ``forward`` call is one training iteration.

    One call runs ``CRITIC_STEP`` critic updates, then one generator update,
    then one step of the critic learning-rate scheduler. The critic loss is
    ``mean(D(fake)) - mean(D(real)) + GP_WEIGHT * gradient_penalty``.

    Args:
        generator: module mapping ``(batch, Z_DIM)`` latents to images shaped like the real batch.
        critic: module mapping an image batch to per-sample scores.
        critic_optimizer: optimizer over the critic's parameters.
        generator_optimizer: optimizer over the generator's parameters.
    """

    def __init__(self, generator, critic, critic_optimizer, generator_optimizer):
        super(wgan, self).__init__()
        self.generator = generator
        self.critic = critic
        self.latent_dim = Z_DIM
        self.gp_weight = GP_WEIGHT
        self.critic_opt = critic_optimizer
        self.generator_opt = generator_optimizer
        # NOTE(review): stepped once per forward() call (i.e. per batch), so with
        # step_size=globals.EPOCHS the lr decays every EPOCHS batches. Confirm
        # that this is intended.
        self.scheduler_critic = torch.optim.lr_scheduler.StepLR(self.critic_opt , step_size = globals.EPOCHS, gamma = 0.999)

    def gradient_penalty(self, batch_size, real_images, fake_images):
        """Two-sided WGAN-GP penalty ``mean((||grad_x D(x_hat)||_2 - 1)^2)``.

        ``x_hat = real + alpha * (fake - real)`` with one ``alpha ~ U[0, 1]`` per
        sample, i.e. random points on the segments between paired real and fake images.

        Args:
            batch_size: number of samples; must equal the leading dimension of both batches.
            real_images: batch of real images.
            fake_images: generated batch of the same shape, already detached from
                the generator graph.

        Returns:
            0-d tensor, differentiable w.r.t. the critic parameters (``create_graph=True``).
        """
        # Uniform weights keep x_hat on the real-fake segment (normal sampling
        # would extrapolate). Match the data's device/dtype; the shape broadcasts
        # over all non-batch dims.
        alpha_shape = (batch_size,) + (1,) * (real_images.dim() - 1)
        alpha = torch.rand(alpha_shape, device = real_images.device, dtype = real_images.dtype)

        diff = fake_images - real_images
        interpolated = (real_images + alpha * diff).requires_grad_(True)

        scores = self.critic(interpolated)

        # grad_outputs of ones == gradient of the *sum* of scores, which gives each
        # sample its own input gradient (assuming the critic treats samples
        # independently). Reducing with mean first would scale every gradient by
        # 1/batch_size.
        gradients = torch.autograd.grad(
            outputs = scores,
            inputs = interpolated,
            grad_outputs = torch.ones_like(scores),
            create_graph = True,
            retain_graph = True,
        )[0]

        gradients_norm = gradients.flatten(start_dim = 1).norm(2, dim = 1)
        return ((gradients_norm - 1) ** 2).mean()

    def forward(self, real_images):
        """Run one WGAN-GP iteration on ``real_images`` and print the losses.

        Returns:
            ``(c_loss, g_loss, c_gp, c_wass_loss)`` as detached 0-d tensors. The
            critic values come from the last critic step.

        Raises:
            ValueError: if ``CRITIC_STEP`` is smaller than 1.
        """
        if CRITIC_STEP < 1:
            raise ValueError('CRITIC_STEP must be >= 1, got {}'.format(CRITIC_STEP))

        batch_size = real_images.shape[0]
        device = real_images.device

        for _ in range(CRITIC_STEP):
            self.critic_opt.zero_grad()

            random_latent_vectors = torch.randn(size = (batch_size, self.latent_dim), device = device)
            # The critic update never back-propagates into the generator, so
            # skip building its autograd graph.
            with torch.no_grad():
                fake_images = self.generator(random_latent_vectors)

            fake_pred = self.critic(fake_images)
            real_pred = self.critic(real_images)

            c_wass_loss = torch.mean(fake_pred) - torch.mean(real_pred)
            c_gp = self.gradient_penalty(batch_size, real_images, fake_images)
            c_loss = c_wass_loss + self.gp_weight * c_gp

            c_loss.backward()
            self.critic_opt.step()

        self.generator_opt.zero_grad()

        random_latent_vector = torch.randn(size = (batch_size, self.latent_dim), device = device)
        fake_images = self.generator(random_latent_vector)
        fake_predictions = self.critic(fake_images)
        g_loss = -1.0 * torch.mean(fake_predictions)

        # This also fills the critic's .grad. That is harmless because
        # critic_opt.zero_grad() runs before the next critic update.
        g_loss.backward()
        self.generator_opt.step()
        self.scheduler_critic.step()

        print('c_loss is  {}'.format(c_loss))
        print('g_loss is  {}'.format(g_loss))
        print('c_gp is  {}'.format(c_gp))
        print('c_wass_loss is  {}'.format(c_wass_loss))

        # Detach so callers that store the losses do not keep autograd nodes alive.
        return (c_loss.detach(), g_loss.detach(), c_gp.detach(), c_wass_loss.detach())
    
