# Importing the necessary libraries and creating a function for visualization

In [None]:
import torch
from torch import nn
from tqdm.auto import tqdm
from torch.utils.data import Dataset, DataLoader
from torchvision import datasets, transforms
from torchvision.datasets import MNIST
from torchvision.utils import make_grid
from torch.utils.data import DataLoader
from torchinfo import summary
import matplotlib.pyplot as plt
import glob
from PIL import Image


def show_tensor_images(image_tensor, num_images=25, size=(1, 28, 28), nrow=5):
    '''
    Function for visualizing images: given a batch of images, plots the first
    `num_images` of them in a uniform grid.

    Parameters:
        image_tensor: batch of images with values in [-1, 1] (the range produced
            by the generator's Tanh and by the dataset normalization). It is
            handed to make_grid unchanged, so it is presumably (N, C, H, W).
        num_images: maximum number of images taken from the front of the batch.
        size: unused; kept only so that callers passing it keep working.
        nrow: number of images per grid row (forwarded to make_grid).
    '''
    # Detach first so the rescaling below is never tracked by autograd, and
    # slice before the device copy so only the displayed images are moved.
    image_unflat = image_tensor.detach()[:num_images].cpu()
    # Map [-1, 1] back to [0, 1] for display.
    image_unflat = (image_unflat + 1) / 2
    image_grid = make_grid(image_unflat, nrow=nrow)
    # Move the channel axis last, which is the layout imshow expects.
    plt.imshow(image_grid.permute(1, 2, 0).squeeze())
    plt.show()

# Loading the images into a list

In [None]:
# Each glob result is kept as its own group inside `images_list`; `images` is
# the flat list of file paths obtained by concatenating all groups in order.
images_list = [
    glob.glob('/kaggle/input/siim-isic-2019-organized/dataset organized/malignant/*.jpg'),
]
images = [filepath for filepaths in images_list for filepath in filepaths]

<h2> What are WGANs? </h2>
<p style="font-size: 15px;"> A Wasserstein Generative Adversarial Network (WGAN) is an advanced type of Generative Adversarial Network (GAN) designed to improve the stability and effectiveness of training GANs. Traditional GANs often suffer from issues like mode collapse, vanishing gradients, and unstable training dynamics, which make them challenging to train. WGAN addresses these issues by using the Earth Mover's (EM) distance, also known as the Wasserstein distance, as the loss function instead of the Jensen-Shannon divergence used in standard GANs. </p>

<h2> The Generator </h2>

<p style="font-size: 15px;"> The generator in a WGAN is designed to create synthetic data samples that are indistinguishable from real data. It operates by taking in random noise vectors sampled from a latent space, which are usually drawn from a simple distribution such as a Gaussian or uniform distribution. These noise vectors are transformed through a series of neural network layers in the generator and the result of this transformation is a synthetic data sample that should closely resemble the real data the model is trying to replicate. During the training process, the generator aims to minimize the Wasserstein distance between the distribution of the generated data and the real data distribution. </p>

In [None]:
class Generator(nn.Module):
    '''
    Transposed-convolution generator that maps a latent vector to an image
    with values in [-1, 1] (Tanh output). With the six blocks below a 1 x 1
    input is upsampled to 128 x 128.

    Parameters:
        z_dim: length of the latent noise vector.
        im_chan: number of channels of the generated image.
        hidden_dim: base channel width; the blocks use multiples of it.
    '''
    def __init__(self, z_dim=100, im_chan=3, hidden_dim=128):
        super().__init__()
        self.z_dim = z_dim
        # Build the neural network (shapes in the comments assume the defaults)
        self.gen = nn.Sequential(
            # Z latent vector 100
            self.make_gen_block(z_dim, hidden_dim * 8, kernel_size=4, stride=1, padding=0),
            # state = 1024 x 4 x 4
            self.make_gen_block(hidden_dim * 8, hidden_dim * 4, kernel_size=4, stride=2, padding=1),
            # state = 512 x 8 x 8
            self.make_gen_block(hidden_dim * 4, hidden_dim * 2, kernel_size=4, stride=2, padding=1),
            # state = 256 x 16 x 16
            self.make_gen_block(hidden_dim * 2, hidden_dim, kernel_size=4, stride=2, padding=1),
            # state = 128 x 32 x 32
            self.make_gen_block(hidden_dim, hidden_dim // 2, kernel_size=4, stride=2, padding=1),
            # state = 64 x 64 x 64
            self.make_gen_block(hidden_dim // 2, im_chan, kernel_size=4, stride=2, padding=1, final_layer=True),
            # Final state = 3 x 128 x 128
        )

    def make_gen_block(self, input_channels, output_channels, kernel_size=4, stride=1, padding=0, final_layer=False):
        '''
        Build one upsampling block.

        Hidden blocks are ConvTranspose2d -> BatchNorm2d -> ReLU; the final
        block is ConvTranspose2d -> Tanh. `padding` is honoured in both cases
        (the final block previously ignored it and always used 1).
        '''
        conv = nn.ConvTranspose2d(input_channels, output_channels, kernel_size, stride, padding=padding)
        if final_layer:
            return nn.Sequential(conv, nn.Tanh())
        return nn.Sequential(
            conv,
            nn.BatchNorm2d(output_channels),
            nn.ReLU(inplace=True),
        )

    def forward(self, noise):
        '''
        Generate images from a batch of noise vectors of shape (n, z_dim).
        '''
        # Treat each latent vector as a z_dim-channel 1 x 1 feature map.
        x = noise.view(len(noise), self.z_dim, 1, 1)
        return self.gen(x)

def get_noise(n_samples, z_dim, device='cpu'):
    '''
    Draw a batch of latent vectors from a standard normal distribution.
    Returns a tensor of shape (n_samples, z_dim) created on `device`.
    '''
    shape = (n_samples, z_dim)
    return torch.randn(shape, device=device)

<h2> The Critic </h2>
<p style="font-size: 15px;"> The critic in a WGAN replaces the discriminator used in traditional GANs. Unlike the discriminator, which classifies inputs as real or fake, the critic's job is to score the inputs in a way that approximates the Wasserstein distance between the real and generated data distributions. This scoring mechanism allows the critic to provide more informative feedback for improving the generator. During the training process, the critic's objective is to maximize the Wasserstein distance, making the scores for real and fake data as distinct as possible.</p>

In [None]:
class Critic(nn.Module):
    '''
    Convolutional critic that maps an image batch to one unbounded score per
    image. With the six blocks below a 128 x 128 input is reduced to 1 x 1.

    Parameters:
        im_chan: number of channels of the input images.
        hidden_dim: base channel width; the blocks use multiples of it.
        norm_layer: callable that takes a channel count and returns the
            normalization module placed after each hidden convolution, or
            None for no normalization. Defaults to nn.BatchNorm2d, which
            keeps the existing architecture and state_dict layout.
            NOTE(review): the WGAN-GP paper advises against batch
            normalization in the critic because the gradient penalty is
            defined per sample; consider e.g. nn.InstanceNorm2d here.
    '''
    def __init__(self, im_chan=3, hidden_dim=128, norm_layer=nn.BatchNorm2d):
        super().__init__()
        # Must be set before the blocks are built, since make_crit_block reads it.
        self.norm_layer = norm_layer
        self.crit = nn.Sequential(
            # Image (3x128x128)
            self.make_crit_block(im_chan, hidden_dim),
            # State = 128 x 64 x 64
            self.make_crit_block(hidden_dim, hidden_dim * 2),
            # State = 256 x 32 x 32
            self.make_crit_block(hidden_dim * 2, hidden_dim * 4),
            # State = 512 x 16 x 16
            self.make_crit_block(hidden_dim * 4, hidden_dim * 8),
            # State = 1024 x 8 x 8
            self.make_crit_block(hidden_dim * 8, hidden_dim * 16),
            # State = 2048 x 4 x 4
            self.make_crit_block(hidden_dim * 16, 1, stride=1, padding=0, final_layer=True),
            # Final state = 1 x 1 x 1
        )

    def make_crit_block(self, input_channels, output_channels, kernel_size=4, stride=2, padding=1, final_layer=False):
        '''
        Build one downsampling block.

        Hidden blocks are Conv2d -> normalization -> LeakyReLU(0.2); the final
        block is a bare Conv2d so the score stays unbounded.
        '''
        layers = [nn.Conv2d(input_channels, output_channels, kernel_size, stride, padding=padding)]
        if not final_layer:
            if self.norm_layer is not None:
                layers.append(self.norm_layer(output_channels))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
        return nn.Sequential(*layers)

    def forward(self, image):
        '''
        Score a batch of images; returns a tensor with one row per image.
        '''
        crit_pred = self.crit(image)
        # Flatten everything after the batch dimension.
        return crit_pred.view(len(crit_pred), -1)

<h2> Data Preprocessing </h2>
<p style="font-size: 15px;"> The images are loaded, resized so that their shorter side is 128 pixels, and center-cropped to a shape of 128 x 128 x 3. The augmentations applied are a random horizontal flip with a probability of 0.5 and, with a probability of 0.2, a color jitter followed by a random rotation of up to 20 degrees. Finally, each channel is normalized by subtracting a mean of 0.5 and dividing by a standard deviation of 0.5, which maps the pixel values to the range [-1, 1].  </p>

In [None]:
# Data preprocessing
class CustomDataset(Dataset):
    '''
    Dataset that yields one RGB image per file path.

    Parameters:
        image_paths: sequence of image file paths.
        transform: optional callable applied to the PIL image before it is
            returned; when None the PIL image is returned as is.
    '''
    def __init__(self, image_paths, transform=None):
        self.image_paths = image_paths
        self.transform = transform

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        image_path = self.image_paths[idx]

        # Read the image inside a context manager so the underlying file is
        # closed deterministically; convert() returns an independent copy.
        with Image.open(image_path) as raw:
            img = raw.convert("RGB")
        # Compare against None explicitly: a transform object may define its
        # own truthiness (e.g. an empty container-like module is falsy).
        if self.transform is not None:
            img = self.transform(img)
        return img

# Augmentations that are applied together as one group, and only occasionally.
# NOTE(review): ColorJitter() is built with its default arguments; per the
# torchvision docs those are all zero, i.e. no jitter at all -- confirm intent.
random_transforms = [
    transforms.ColorJitter(),
    transforms.RandomRotation(degrees=20),
]

# Full preprocessing pipeline, applied to every image in this order.
transform = transforms.Compose([
    transforms.Resize(128),
    transforms.CenterCrop(128),
    transforms.RandomHorizontalFlip(p=0.5),
    # The grouped augmentations above fire with probability 0.2.
    transforms.RandomApply(random_transforms, p=0.2),
    transforms.ToTensor(),
    # Per-channel (x - 0.5) / 0.5.
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

# Dataset over the collected image paths, served in shuffled batches of 64.
custom_dataset = CustomDataset(images, transform=transform)
dataloader = DataLoader(dataset=custom_dataset, batch_size=64, shuffle=True)

<h2> Sample images </h2>

In [None]:
# Draw one batch from the dataloader, print its shape, and preview it.
item = next(iter(dataloader))
print(item.shape)
show_tensor_images(item)

<h2> Setting hyperparameters </h2>

In [None]:
# Training schedule
n_epochs = 100
display_step = 10      # visualize every `display_step` generator steps
batch_size = 64
crit_repeats = 5       # critic updates per generator update

# Latent space
z_dim = 100

# Adam optimizer settings
lr = 0.0002
beta_1, beta_2 = 0.5, 0.999
betas = (beta_1, beta_2)

# Weight of the gradient penalty in the critic loss
c_lambda = 10

# Use the GPU when one is available
device = 'cuda' if torch.cuda.is_available() else 'cpu'
device

<h2> Initialization </h2>
<p style="font-size: 15px;"> The generator and the critic are initialized. The weights of both models are initialized from a normal distribution. </p>

In [None]:
# Build the generator and the critic on the selected device, each with its own
# Adam optimizer that shares the learning rate and betas set above.
gen = Generator(z_dim).to(device)
gen_opt = torch.optim.Adam(gen.parameters(), lr = lr, betas = betas)
critic = Critic().to(device)
crit_opt = torch.optim.Adam(critic.parameters(), lr = lr, betas = betas)

def weights_init(m):
    '''
    DCGAN-style initializer, meant to be passed to `Module.apply`.

    Conv2d / ConvTranspose2d weights are drawn from N(0, 0.02). Affine
    BatchNorm2d scales are drawn from N(1, 0.02) with zero bias, so that
    normalization starts close to the identity (a scale centred on 0 would
    squash every normalized activation towards zero). Other modules, and
    BatchNorm2d layers without affine parameters, are left untouched.
    '''
    if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
        torch.nn.init.normal_(m.weight, 0.0, 0.02)
    elif isinstance(m, nn.BatchNorm2d) and m.weight is not None:
        torch.nn.init.normal_(m.weight, 1.0, 0.02)
        torch.nn.init.constant_(m.bias, 0)
# Run the initializer over every submodule of both networks.
gen = gen.apply(weights_init)
critic = critic.apply(weights_init)


<h2> Summary of the generator and the critic </h2>

In [None]:
# Summarize the generator for a (batch_size, z_dim) noise input.
summary(gen, input_size = (batch_size, z_dim))

In [None]:
# Summarize the critic for a batch of 3 x 128 x 128 images.
summary(critic, input_size = (batch_size, 3, 128, 128))

<h2> Gradient penalty </h2>
<p style="font-size: 15px;"> A crucial aspect of the critic's training is ensuring that it remains 1-Lipschitz continuous, which is typically enforced through a technique called gradient penalty. A function is Lipschitz continuous if there exists a constant 
𝐾 such that for any two points a and b, the absolute difference between the function's values at these points is bounded by 𝐾 times the distance between a and b. In mathematical terms, for a function 𝑓, </p>

<p style="font-size: 15px; text-align: center;"> ∣𝑓(b) − 𝑓(a)∣ ≤ 𝐾 ∥b−a∥. </p>

<p style="font-size: 15px;"> In the context of WGAN, we require the critic function 
𝐷 to be 1-Lipschitz continuous, meaning 𝐾=1. This constraint is necessary for the theoretical properties of the Wasserstein distance to hold, ensuring stable and meaningful gradients. Mathematically, the above equation is equivalent to ensuring that the norm of the function's gradient must not be greater than 1. </p>

<p style="font-size: 15px;"> In order to enforce this constraint, the gradient penalty method was used. Introduced by Gulrajani et al. (https://arxiv.org/abs/1704.00028), this method adds a regularization term to the loss function that penalizes the critic whenever the norm of its gradient deviates from one, which encourages the 1-Lipschitz constraint. Since measuring the gradient of the critic at every possible point of the feature space is impractical, the authors came up with a clever way to choose where the gradient is checked. Instead of evaluating all points, a new image is interpolated between a real image and a fake image using a random mixing weight and is passed through the critic. The squared difference between the norm of the resulting gradient and one is added to the loss as the penalty. Through this method, the 1-Lipschitz constraint is not strictly enforced at all points but is rather encouraged. This soft constraint has been observed to work well in practice. </p>

In [None]:
# Getting gradient for gradient penalty
def get_gradient(critic, real, fake, epsilon):
    '''
    Gradient of the critic's scores with respect to a blend of real and fake
    images.

    Parameters:
        critic: callable that scores a batch of images.
        real: batch of real images.
        fake: batch of generated images, broadcast-compatible with `real`.
        epsilon: mixing weights; `epsilon` goes to `real`, `1 - epsilon` to `fake`.
    Returns:
        The gradient with respect to the blended images, with the graph kept
        so that a loss built from it can itself be backpropagated.
    '''
    # Per-sample blend of the two batches.
    mixed = epsilon * real + (1 - epsilon) * fake
    scores = critic(mixed)
    (grad,) = torch.autograd.grad(
        outputs=scores,
        inputs=mixed,
        # Equivalent to differentiating the sum of all scores.
        grad_outputs=torch.ones_like(scores),
        create_graph=True,
        retain_graph=True,
    )
    return grad


In [None]:
# Computing gradient penalty
def gradient_penalty(gradient):
    '''
    Mean squared distance between each sample's gradient L2 norm and 1.

    `gradient` holds one gradient per sample along its first dimension; all
    remaining dimensions are flattened before the norm is taken.
    '''
    per_sample = gradient.view(len(gradient), -1)
    norms = per_sample.norm(2, dim=1)
    deviation = norms - 1
    return (deviation ** 2).mean()

<h2> Critic Loss </h2>
<p style="font-size: 15px;"> The critic maximizes the difference between the original distribution and the generated distribution. Or in mathematical terms, the critic maximizes </p>
\begin{equation}
\mathbb{E}_{\mathbf{x} \sim p_\text{data}} [D(\mathbf{x})] - \mathbb{E}_{\mathbf{z} \sim p_\mathbf{z}} [D(G(\mathbf{z}))]
\end{equation}
<br>
<p style="font-size: 15px;"> Maximizing an expression is the same as minimizing its negative. So, with the gradient penalty added, the critic tries to minimize: </p>

\begin{equation}
L_D = \mathbb{E}_{\mathbf{z} \sim p_\mathbf{z}} [D(G(\mathbf{z}))] - \mathbb{E}_{\mathbf{x} \sim p_\text{data}} [D(\mathbf{x})] + \lambda \mathbb{E}_{\hat{\mathbf{x}} \sim p_{\hat{\mathbf{x}}}} \left[ (\|\nabla_{\hat{\mathbf{x}}} D(\hat{\mathbf{x}})\|_2 - 1)^2 \right]
\end{equation}
<br>
<p style="font-size: 15px;">The first two terms form the regular Earth Mover's loss, and the third term is the gradient penalty regularization term, weighted by λ. </p>


In [None]:
def get_crit_loss(crit_fake_pred, crit_real_pred, gp, c_lambda):
    '''
    Critic loss: mean fake score minus mean real score, plus the gradient
    penalty `gp` weighted by `c_lambda`.
    '''
    score_gap = crit_fake_pred.mean() - crit_real_pred.mean()
    weighted_penalty = c_lambda * gp
    return score_gap + weighted_penalty

<h2> Generator Loss </h2>
\begin{equation}
L_G = -\mathbb{E}_{\mathbf{z} \sim p_\mathbf{z}} [D(G(\mathbf{z}))]
\end{equation}
<br>

In [None]:
def get_gen_loss(crit_fake_pred):
    '''
    Generator loss: the negated mean of the critic's scores on generated images.
    '''
    mean_score = crit_fake_pred.mean()
    return mean_score.neg()

In [None]:
def visualize_code(generator_losses, critic_losses, curr_step, real, fake, display_step = 10):
    '''
    Report training progress: print the recent mean losses, show the fake and
    real batches, and plot both loss curves averaged over bins of 20 steps.

    Parameters:
        generator_losses: list of generator losses (floats), one per step.
        critic_losses: list of critic losses (floats), one per step.
        curr_step: current step number, used only in the printed message.
        real: batch of real images to display.
        fake: batch of generated images to display.
        display_step: number of most recent losses averaged for the printout.
    '''
    # Average over the entries actually present in the window, so a history
    # shorter than `display_step` does not skew the reported mean.
    recent_gen = generator_losses[-display_step:]
    recent_crit = critic_losses[-display_step:]
    gen_mean = sum(recent_gen) / len(recent_gen) if recent_gen else float('nan')
    crit_mean = sum(recent_crit) / len(recent_crit) if recent_crit else float('nan')
    print(f"Step {curr_step}: Generator loss: {gen_mean}, critic loss: {crit_mean}")
    show_tensor_images(fake)
    show_tensor_images(real)

    step_bins = 20
    num_bins = len(generator_losses) // step_bins
    # Reshaping an empty tensor with view(-1, step_bins) raises, so skip the
    # curves until at least one full bin of losses has been collected.
    if num_bins == 0:
        return
    num_examples = num_bins * step_bins
    plt.plot(
        range(num_bins),
        torch.Tensor(generator_losses[:num_examples]).view(-1, step_bins).mean(1),
        label="Generator Loss"
    )
    plt.plot(
        range(num_bins),
        torch.Tensor(critic_losses[:num_examples]).view(-1, step_bins).mean(1),
        label="Critic Loss"
    )
    plt.legend()
    plt.show()

<h2> Training loop </h2>

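<p style="font-size: 15px;"> The model was trained for 500 epochs (note that <code>n_epochs</code> is set to 100 in the hyperparameter cell, so adjust it to reproduce this run). The critic is updated 5 times for every update of the generator so that it can provide meaningful feedback to the generator. </p>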

In [None]:
# Step counter and per-step loss histories used by the visualization.
curr_step = 0
generator_losses = []
critic_losses = []

for epoch in range(0): # Change this to n_epochs to start training
    for real in tqdm(dataloader):
        real = real.to(device)
        num_of_samples = real.shape[0]

        # Update critic: it is updated more times than the generator
        sum_crit_loss = 0
        for _ in range(crit_repeats):
            crit_opt.zero_grad()
            noise = get_noise(num_of_samples, z_dim, device)
            # The critic step never backpropagates into the generator, so do
            # not build the generator's graph at all.
            with torch.no_grad():
                fake = gen(noise)
            crit_fake_pred = critic(fake)
            crit_real_pred = critic(real)

            # One random mixing weight per sample; requires_grad makes the
            # interpolated images differentiable inside get_gradient.
            epsilon = torch.rand(len(real), 1, 1, 1, device=device, requires_grad=True)

            gradient = get_gradient(critic, real, fake, epsilon)
            gp = gradient_penalty(gradient)
            crit_loss = get_crit_loss(crit_fake_pred, crit_real_pred, gp, c_lambda)
            sum_crit_loss += crit_loss.item()

            # A fresh graph is built on every repeat, so nothing has to be
            # retained after this backward pass.
            crit_loss.backward()
            crit_opt.step()

        # Record the critic loss averaged over its repeats
        mean_crit_loss = sum_crit_loss / crit_repeats
        critic_losses.append(mean_crit_loss)

        # Update generator
        gen_opt.zero_grad()
        noise = get_noise(num_of_samples, z_dim, device)
        fake = gen(noise)
        crit_fake_pred = critic(fake)
        gen_loss = get_gen_loss(crit_fake_pred)
        gen_loss.backward()
        gen_opt.step()
        generator_losses.append(gen_loss.item())

        # Visualization step
        if curr_step % display_step == 0 and curr_step > 0:
            visualize_code(generator_losses, critic_losses, curr_step, real, fake, display_step)
        curr_step += 1

In [None]:
# Saving model
# torch.save(gen.state_dict(), 'gen128_500.pth')
# torch.save(critic.state_dict(), 'critic128_500.pth')

<h2> Sampling from the model </h2>

In [None]:
# Rebuild a fresh generator and critic (and their optimizers). These replace
# whatever `gen` / `critic` held before, so they carry newly initialized
# weights unless the load_state_dict calls below are re-enabled.
gen = Generator(z_dim).to(device)
gen_opt = torch.optim.Adam(gen.parameters(), lr = lr, betas = betas)
critic = Critic().to(device)
crit_opt = torch.optim.Adam(critic.parameters(), lr = lr, betas = betas)

# Load the saved state dictionaries
# (left disabled: this doesn't work when trying to save the notebook)
# gen.load_state_dict(torch.load('gen128_500.pth'))
# critic.load_state_dict(torch.load('critic128_500.pth'))

# Displaying generator outcome
# noise = get_noise(64, z_dim, device)
# fake = gen(noise)
# show_tensor_images(fake)

<h2> Displaying output from generator </h2>

In [None]:
# Import IPython's Image under an alias: a plain `Image` import would rebind
# the PIL `Image` that CustomDataset.__getitem__ relies on, so any dataloader
# iteration after this cell would fail on `Image.open`.
from IPython.display import Image as IPyImage, display

# Display the saved image
display(IPyImage(filename='/kaggle/input/gen-image/image-gen.png', width=400, height=300))

In [None]:
# from torchvision.utils import save_image
# import os
# directory = "/kaggle/working/output_imgs/"
# if not os.path.exists(directory):
#     os.mkdir(directory)
# def save_imgs(image_tensor, offset):
#     offset = 600 * offset
#     image_tensor = (image_tensor + 1) / 2
#     image_unflat = image_tensor.detach().cpu()
#     for i in range(image_unflat.shape[0]):
#         img = image_unflat[i]
#         save_image(img, directory + f"img{i + offset}_test.png")

In [None]:
# # Loading images to a working folder 
# for i in range(10):
#     noise = get_noise(600, z_dim, device)
#     fake_imgs = gen(noise)
#     save_imgs(fake_imgs, i)

In [None]:
# # Converting output to zip
# !zip -r "/kaggle/working/output_imgs_final.zip" "/kaggle/working/output_imgs/"