In [None]:
# Imports
!pip install torch
!pip install torchvision
!pip install matplotlib
import numpy as np
from matplotlib import pyplot as plt
import matplotlib.animation as animation
import torch.nn as nn
import torch.nn.functional as F
import torch
import torchvision
import torchvision.transforms as transforms
import torchvision.utils as vutils
from torch.utils.data import ConcatDataset, DataLoader
from functools import partial
import torch.optim as optim
# Compute device. NOTE(review): no cell shown here moves models or tensors with
# `.to(device)`, so training currently runs on the CPU even when CUDA is available.
device = "cuda:0" if torch.cuda.is_available() else "cpu"

In [None]:
# Connect to Google Drive so checkpoints and figures persist between Colab sessions
from google.colab import drive
drive.mount('/content/gdrive')
root_dir = '/content/gdrive/MyDrive/NN for Images/Ex3'

# Create every output folder the notebook writes into: torch.save, savefig and
# the GIF writer all fail when the target directory does not exist yet.
import os
for _out_dir in ('Q1', 'Q2', 'Q3/Denoising', 'Q3/Inpainting'):
    os.makedirs(os.path.join(root_dir, _out_dir), exist_ok=True)

Mounted at /content/gdrive


In [None]:
# Reproducibility: seed every random number generator the notebook may draw from
import random
manualSeed = 999
random.seed(manualSeed)        # Python RNG (position of the inpainting window)
np.random.seed(manualSeed)     # NumPy RNG (seeded for completeness)
torch.manual_seed(manualSeed);  # torch RNG: weight init, latent vectors, DataLoader shuffling ('; ' hides the Generator repr)

<torch._C.Generator at 0x7f6ef054a1d0>

# **Data Loaders**
The dataset I used is the full MNIST, including both train and test sets

In [None]:
def get_dataloader(batch_size=128):
    """Build a shuffled DataLoader over the whole MNIST set (train + test splits).

    Images are only passed through ``ToTensor``, so pixel values lie in [0, 1].
    """
    to_tensor = transforms.ToTensor()
    splits = [
        torchvision.datasets.MNIST(root='./data', train=is_train, download=True, transform=to_tensor)
        for is_train in (True, False)
    ]
    return torch.utils.data.DataLoader(ConcatDataset(splits), batch_size=batch_size, shuffle=True)

# **Global Parameters & Helper Functions**

### **Global Parameters**
Definitions of the image/latent dimensions, the data loader, and the real/fake labels

In [None]:
image_size = 28    # Height/width in pixels of each (square) MNIST image
im_c = 1           # Number of channels per image (grayscale)
z_size = 100       # Length of the latent vector Z fed to the Generator
batch_size = 128
dataloader = get_dataloader(batch_size)

# Target values for ground-truth ("real") and generated ("fake") samples
real_label = 1.
fake_label = 0.


def real_label_vec(n):
    """Return a float tensor of shape (n,) filled with ``real_label``."""
    return torch.full((n,), real_label, dtype=torch.float)


def fake_label_vec(n):
    """Return a float tensor of shape (n,) filled with ``fake_label``."""
    return torch.full((n,), fake_label, dtype=torch.float)

### **Weights Initialization**
The original DCGAN paper specifies that all weights should be initialized from a Normal distribution with `mean=0` and `std=0.02`. This function applies that rule to the convolution layers. BatchNorm scales are drawn from a Normal distribution with `mean=1` and `std=0.02`, and BatchNorm biases are set to 0.

In [None]:
def weights_init(m):
    """DCGAN weight initialisation, meant to be used as ``module.apply(weights_init)``.

    * Modules whose class name contains 'Conv' (e.g. Conv2d, ConvTranspose2d):
      weights ~ N(0, 0.02).
    * Modules whose class name contains 'BatchNorm': weights ~ N(1, 0.02), bias = 0.
    All other modules are left untouched.
    """
    kind = type(m).__name__
    if 'Conv' in kind:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
        return
    if 'BatchNorm' in kind:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0)

### **Animate List of Images**
This function creates an animated GIF from a list of images.

In our case each frame is a grid of 64 small images that G generated from a fixed batch of latent vectors, which helps monitor its progress during training.

In [None]:
def animate(img_list, name, dim=(8,8), iter_D_per_G=1):
    """Save a GIF that shows the images in ``img_list`` one per second.

    Args:
        img_list: sequence of CHW images (e.g. grids from ``vutils.make_grid``);
            each frame is transposed to HWC for ``imshow``.
        name: prefix of the output file name.
        dim: figure size in inches.
        iter_D_per_G: only used in the file name ``<name>_<iter_D_per_G>:1.gif``.

    The GIF is written to ``<root_dir>/Q1/``. Nothing is written when ``img_list``
    is empty (e.g. training failed before the first snapshot).
    """
    if len(img_list) == 0:
        return

    fig = plt.figure(figsize=dim)
    plt.axis("off")

    # Draw the first frame once and only swap its pixel data afterwards, rather than
    # stacking a new image artist on the axes for every frame.
    frame = plt.imshow(np.transpose(img_list[0], (1, 2, 0)), animated=True)

    def update(i):
        frame.set_data(np.transpose(img_list[i], (1, 2, 0)))
        return [frame]

    ani = animation.FuncAnimation(fig, update, frames=len(img_list), interval=1000)
    ani.save("%s/Q1/%s_%d:1.gif" % (root_dir, name, iter_D_per_G))
    # animate() is called in a loop; release the figure so open figures don't pile up
    plt.close(fig)

### **Generator Loader**
Load the proper Generator given a `Config` object

In [None]:
# Generator loader
def load_G(con):
    """Load the Generator checkpoint saved for config ``con`` and return it in eval mode.

    The checkpoint path is ``<root_dir>/Q1/G_with_loss_<con.loss>_<con.iter_D_per_G>:1.pth``.
    Weights are mapped onto the CPU, where the inversion code creates its latent
    vectors; eval mode makes BatchNorm use its running statistics.
    """
    path = r'{}/Q1/{}_with_loss_{}_{}:1.pth'
    G = Generator(con.Gfeatures)
    # map_location keeps loading working even if the checkpoint was written from a GPU
    state = torch.load(path.format(root_dir, "G", con.loss, con.iter_D_per_G), map_location="cpu")
    G.load_state_dict(state)
    G.eval()
    return G

### **Images Plotter**
Helper function to plot Q2's and Q3's results



In [None]:
def show_images(im_list, path, rect_coor=None, rect_size=8):
    """Plot original / (damaged) / reconstructed images side by side and save to ``path``.

    Args:
        im_list: ``[original, reconstructed]`` or ``[original, damaged, reconstructed]``;
            each is expected to be a tensor that squeezes to a 2-D grayscale image.
        path: output file passed to ``plt.savefig``.
        rect_coor: optional ``(top, left)`` of the inpainting window to outline on the
            damaged image. ``None`` or ``(None, None)`` draws no outline. Only used
            when three images are given.
        rect_size: side length, in pixels, of the outlined window.
    """
    slots = len(im_list)
    if slots == 2:
        original, reconstructed = im_list
        damaged = None
    else:
        original, damaged, reconstructed = im_list

    fig, axs = plt.subplots(1, slots, figsize=(8, 4))

    def _draw(ax, img, title):
        # Render one grayscale tensor on `ax` without axis ticks
        ax.imshow(img.squeeze().detach().numpy(), cmap='gray')
        ax.axis('off')
        ax.set_title(title)

    _draw(axs[0], original, 'Original Image')

    # Middle panel only exists when a damaged image was given
    if damaged is not None:
        _draw(axs[1], damaged, 'Damaged Image')
        if rect_coor is not None and rect_coor[0] is not None:
            top, left = rect_coor
            # Rectangle takes (x, y) = (column, row), hence (left, top)
            rect = plt.Rectangle((left, top), rect_size, rect_size, edgecolor='red', linewidth=5, fill=False)
            axs[1].add_patch(rect)

    _draw(axs[slots - 1], reconstructed, 'Reconstructed Image')

    plt.savefig(path)
    # Called in loops: close the figure so open figures don't accumulate
    plt.close(fig)

### **Configuration object**
A simple container that keeps track of a run's hyperparameters.

In [None]:
# An object for tracking hyper parameters
class Config:
    """Hyperparameters of one GAN training run.

    Attributes:
        Gfeatures: base channel count of the Generator.
        Dfeatures: base channel count of the Discriminator.
        iter_D_per_G: the Generator is updated once every ``iter_D_per_G`` batches.
        num_epochs: number of passes over the dataloader.
        lr: Adam learning rate, shared by both networks.
        l2_reg: used as Adam's ``beta1`` for both optimizers. Despite its name it is
            not an L2 / weight-decay coefficient; ``beta1`` holds the same value.
        beta1: alias of ``l2_reg`` with an accurate name.
        loss: loss-family name ('Saturated', 'Non-Saturation' or 'L2').
    """

    def __init__(self, loss_fn, iter_D_per_G=3, num_epochs=10, lr=0.0002,
                 beta1=0.5, Gfeatures=64, Dfeatures=32):
        self.Gfeatures = Gfeatures
        self.Dfeatures = Dfeatures
        self.iter_D_per_G = iter_D_per_G
        self.num_epochs = num_epochs
        self.lr = lr
        # Historical attribute name, kept because existing code reads config.l2_reg
        self.l2_reg = beta1
        self.beta1 = beta1
        self.loss = loss_fn

# **Modules Definitions**

### **Generator**
The Generator receives a latent vector Z of size `z_size` (=100), viewed here as a 1 x 1 image with 100 channels. It outputs an image of size `image_size x image_size` (28 x 28) with `im_c` (=1) channel. The number of channels in every up-convolution layer is a multiple of `Gfeatures` (=64). The exact architecture is as follows:

1.   Transposed stride-2 convolution (kernel=3, in=100, out=256) → 256 x 3 x 3
2.   Batch Normalization
3.   ReLU
4.   Transposed stride-2 convolution (kernel=3, in=256, out=128) → 128 x 7 x 7
5.   Batch Normalization
6.   ReLU
7.   Transposed stride-2 convolution (kernel=3, in=128, out=64) → 64 x 15 x 15
8.   Batch Normalization
9.   ReLU
10.  Transposed stride-2 convolution (kernel=3, in=64, out=1, padding=2, output_padding=1) → 1 x 28 x 28
11.  Sigmoid




In [None]:
class Generator(nn.Module):
    """DCGAN-style generator mapping a latent vector to a 28x28 image.

    Args:
        feat: base channel count; the hidden layers have 4*feat, 2*feat, feat channels.
        z_dim: length of the latent vector (input channels of the first layer).
        out_c: number of channels of the generated image.

    ``forward`` expects ``z`` of shape (N, z_dim, 1, 1) and returns (N, out_c, 28, 28)
    with values in (0, 1) from the final Sigmoid.
    """

    def __init__(self, feat, z_dim=100, out_c=1):
        super(Generator, self).__init__()
        # ConvTranspose2d output side: (in - 1) * stride - 2 * padding + kernel + output_padding
        self.seq = nn.Sequential(
            # Input size: z_dim x 1 x 1

            nn.ConvTranspose2d(in_channels=z_dim, out_channels=4 * feat,
                               kernel_size=3, stride=2, padding=0, bias=False),
            nn.BatchNorm2d(4 * feat),
            nn.ReLU(inplace=True),
            # current size: 4*feat x 3 x 3

            nn.ConvTranspose2d(in_channels=4 * feat, out_channels=2 * feat,
                               kernel_size=3, stride=2, padding=0, bias=False),
            nn.BatchNorm2d(2 * feat),
            nn.ReLU(inplace=True),
            # current size: 2*feat x 7 x 7

            nn.ConvTranspose2d(in_channels=2 * feat, out_channels=feat,
                               kernel_size=3, stride=2, padding=0, bias=False),
            nn.BatchNorm2d(feat),
            nn.ReLU(inplace=True),
            # current size: feat x 15 x 15

            # padding=2 / output_padding=1 give exactly 28: 14*2 - 4 + 3 + 1
            nn.ConvTranspose2d(in_channels=feat, out_channels=out_c,
                               kernel_size=3, stride=2, padding=2, output_padding=1, bias=False),
            # current size: out_c x 28 x 28

            # For the last layer, the activation is Sigmoid (pixel values in (0, 1))
            nn.Sigmoid()
        )

    def forward(self, z):
        return self.seq(z)

### **Discriminator**
The Discriminator receives an image of size `image_size x image_size` (28 x 28) with `im_c` (=1) channel, and outputs a single scalar that represents the probability that the image is real. The number of channels in every convolution layer is a multiple of `Dfeatures` (=32). The exact architecture is as follows:

1.   Stride-2 convolution (kernel=4, padding=1, in=1, out=32) → 32 x 14 x 14
2.   Batch Normalization
3.   LeakyReLU (slope=0.2)
4.   Stride-2 convolution (kernel=4, padding=1, in=32, out=64) → 64 x 7 x 7
5.   Batch Normalization
6.   LeakyReLU (slope=0.2)
7.   Stride-2 convolution (kernel=4, padding=1, in=64, out=128) → 128 x 3 x 3
8.   Batch Normalization
9.   LeakyReLU (slope=0.2)
10.  Flatten → 1152 (= 128 x 3 x 3)
11.  Fully-Connected (in=1152, out=1)
12.  Sigmoid

In [None]:
class FlattenBatch(nn.Module):
    """Parameter-free layer that flattens every dimension except the batch one: (N, ...) -> (N, -1)."""

    def forward(self, x):
        return torch.flatten(x, start_dim=1)

class Discriminator(nn.Module):
    """DCGAN-style discriminator that scores 28x28 images as real (-> 1) or fake (-> 0).

    Args:
        feat: base channel count; the conv layers have feat, 2*feat, 4*feat channels.
        in_c: number of channels of the input images.

    ``forward`` expects (N, in_c, 28, 28) and returns (N, 1) Sigmoid scores.
    """

    def __init__(self, feat, in_c=1):
        super(Discriminator, self).__init__()
        # Conv2d output side: floor((in + 2 * padding - kernel) / stride) + 1
        self.seq = nn.Sequential(
            # current size: in_c x 28 x 28
            nn.Conv2d(in_channels=in_c, out_channels=feat, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(feat),
            nn.LeakyReLU(0.2, inplace=True),
            # current size: feat x 14 x 14

            nn.Conv2d(in_channels=feat, out_channels=2 * feat, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(2 * feat),
            nn.LeakyReLU(0.2, inplace=True),
            # current size: 2*feat x 7 x 7

            nn.Conv2d(in_channels=2 * feat, out_channels=4 * feat, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(4 * feat),
            nn.LeakyReLU(0.2, inplace=True),
            # current size: 4*feat x 3 x 3

            # (N, 4*feat, 3, 3) -> (N, 4*feat*9). nn.Flatten has no parameters, so the
            # Linear layer keeps its index (and state_dict key) inside the Sequential.
            nn.Flatten(),
            nn.Linear(4 * feat * 3 * 3, 1),

            nn.Sigmoid()
        )

    def forward(self, x):
        return self.seq(x)

### **Loss functions**
Definitions of the 3 different loss functions we investigated:


*   **The Original DCGAN Loss:**
In this case, D's loss is the Binary Cross-Entropy (`BCELoss`), and G tries to minimize `log(1-D(G(z)))`. `BCELoss` against the fake label (0) equals `-log(1-D(G(z)))`. G's loss is therefore defined as `BCELoss` comparing D's predictions with fake labels only, with the result negated.
*   **The Non-Saturating Loss:**
In this case, G tries to maximize `log(D(G(z)))`, or equivalently to minimize `-log(D(G(z)))`. `BCELoss` against the real label (1) equals `-log(D(G(z)))`. G's loss is therefore defined as `BCELoss` comparing D's predictions with real labels only.



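*   **The L2 Loss:** This case is very similar to the Non-Saturating case, but both D and G use `MSELoss` (the L2 / least-squares loss) instead of `BCELoss`. As in the Non-Saturating case, G's loss compares D's predictions with real labels.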





In [None]:
def get_loss(loss_name):
    """Return the Discriminator and Generator criteria for one GAN loss family.

    Args:
        loss_name: 'Saturated', 'Non-Saturation' or 'L2'.

    Returns:
        ``{"D_loss": criterion(pred, target), "G_loss": fn(pred_on_fake)}`` where
        * 'Saturated':      D uses BCE; G minimises log(1 - D(G(z))), written as the
                            negated BCE against the fake label.
        * 'Non-Saturation': D uses BCE; G minimises -log(D(G(z))), i.e. BCE against
                            the real label.
        * 'L2':             D uses MSE; G minimises the MSE to the real label.

    Raises:
        ValueError: for any other ``loss_name``.
    """
    if loss_name not in ('Saturated', 'Non-Saturation', 'L2'):
        raise ValueError("loss_name should be one of: 'Saturated', 'Non-Saturation', 'L2'")

    D_loss = nn.MSELoss() if loss_name == 'L2' else nn.BCELoss()

    if loss_name == 'Saturated':
        def G_loss(x):
            # BCE(x, 0) = -log(1 - x); negating it leaves log(1 - D(G(z))) to minimise
            return -D_loss(x, fake_label_vec(x.size(0)))
    else:
        def G_loss(x):
            # Push D's scores on generated samples towards the 'real' label
            return D_loss(x, real_label_vec(x.size(0)))

    return {"D_loss": D_loss, "G_loss": G_loss}

# **Train Loop**

In [None]:
def train_disciminator(D, G, data, criterion, optimizerD):
    """Run one Discriminator update on a real batch plus an equally sized fake batch.

    Args:
        D, G: discriminator / generator modules.
        data: batch of real images; its first dimension is the batch size.
        criterion: D's loss, called as ``criterion(predictions, labels)``.
        optimizerD: optimizer over D's parameters.

    Returns:
        ``(errD, fake_data)``: the summed real + fake loss as a Python float, and the
        generated batch with its autograd graph to G still attached, so the caller
        can reuse it for the Generator step.
    """
    n = data.size(0)
    D.zero_grad()

    # Real half: accumulate gradients of the loss against the 'real' label
    loss_real = criterion(D(data).view(-1), real_label_vec(n))
    loss_real.backward()

    # Fake half: detach so this backward pass stops at D and does not reach G
    fake_data = G(torch.randn(n, z_size, 1, 1))
    loss_fake = criterion(D(fake_data.detach()).view(-1), fake_label_vec(n))
    loss_fake.backward()

    optimizerD.step()
    return (loss_real + loss_fake).item(), fake_data

In [None]:
def train_generator(D, G, fake_data, criterion, optimizerG):
    """Perform a single Generator update on an already generated batch.

    Args:
        D: discriminator that scores ``fake_data``. Gradients flow through it, but
            only ``optimizerG`` is stepped here.
        G: generator; its gradients are zeroed before the backward pass.
        fake_data: a batch produced by ``G`` that is still attached to G's graph.
        criterion: callable mapping D's flattened scores to G's scalar loss.
        optimizerG: optimizer over G's parameters.

    Returns:
        G's loss for this step as a Python float.
    """
    G.zero_grad()

    # D's scores for the generated batch, as a flat vector
    verdict = D(fake_data).view(-1)
    g_loss = criterion(verdict)

    g_loss.backward()
    optimizerG.step()
    return g_loss.item()

In [None]:
def trainloop(config=None, img_list=None, D_err=None, G_err=None):
    """Train a freshly initialised Generator/Discriminator pair according to ``config``.

    Args:
        config: a ``Config``; required despite the ``None`` default.
        img_list: list that receives grids of G's output on a fixed batch of 64
            latent vectors (every 500 iterations and after the last batch).
        D_err, G_err: lists that receive D's and G's losses at every Generator step.
        Pass your own lists to keep the partial progress if training raises; when
        omitted, fresh lists are created for this call.

    Returns:
        ``(G, D)``, the trained modules. Their state dicts are also saved to
        ``<root_dir>/Q1/{G,D}_with_loss_<loss>_<iter_D_per_G>:1.pth``.

    Raises:
        ValueError: if ``config`` is None.
    """
    import os

    if config is None:
        raise ValueError("trainloop requires a Config object")

    # A mutable default ([]) would be created once and shared by every call;
    # build fresh lists per call instead.
    img_list = [] if img_list is None else img_list
    D_err = [] if D_err is None else D_err
    G_err = [] if G_err is None else G_err

    # batch of latent vectors (Z), that will help us to monitor the
    # performance of G along the training loop
    fixed_noise, iters = torch.randn(64, z_size, 1, 1), 0

    # Initialize G and D, then re-draw their weights with weights_init
    # (Conv weights ~ N(0, 0.02); BatchNorm weights ~ N(1, 0.02), bias 0)
    G, D = Generator(config.Gfeatures), Discriminator(config.Dfeatures)
    G.apply(weights_init)
    D.apply(weights_init)

    # Build the loss functions
    losses = get_loss(config.loss)
    D_loss, G_loss = losses["D_loss"], losses["G_loss"]

    # Adam optimizers for both G and D. config.l2_reg is passed as Adam's beta1
    # (first-moment decay), not as an L2 / weight-decay coefficient.
    optimizerD = optim.Adam(D.parameters(), lr=config.lr, betas=(config.l2_reg, 0.999))
    optimizerG = optim.Adam(G.parameters(), lr=config.lr, betas=(config.l2_reg, 0.999))

    print("Starting Training Loop...")

    for epoch in range(config.num_epochs):
        for i, data in enumerate(dataloader):

            # Update D on every batch; this also produces the fake batch for G's step
            errD, fake_data = train_disciminator(D, G, data[0], D_loss, optimizerD)

            # Update G once every <iter_D_per_G> batches. i == 0 always qualifies,
            # so errG is defined before the first print below.
            if i % config.iter_D_per_G == 0:
                errG = train_generator(D, G, fake_data, G_loss, optimizerG)

                # Record the losses for the loss-curve plot
                D_err.append(errD)
                G_err.append(errG)

            # Print some statistics
            if i % 50 == 0:
                print('[%d/%d][%d/%d]\tLoss_D: %.4f\tLoss_G: %.4f' % (epoch, config.num_epochs, i, len(dataloader), errD, errG))

            # Snapshot G's output on the fixed noise every 500 iterations and at the end.
            # G is still in train mode here, so BatchNorm uses this batch's statistics.
            if (iters % 500 == 0) or ((epoch == config.num_epochs-1) and (i == len(dataloader)-1)):
                with torch.no_grad():
                    fake = G(fixed_noise).detach()
                    img_list.append(vutils.make_grid(fake, padding=2, normalize=True))
            iters += 1

    # Save the trained modules, creating the output folder if it is missing
    path = r'{}/Q1/{}_with_loss_{}_{}:1.pth'
    os.makedirs(os.path.join(root_dir, 'Q1'), exist_ok=True)
    torch.save(G.state_dict(), path.format(root_dir, "G", config.loss, config.iter_D_per_G))
    torch.save(D.state_dict(), path.format(root_dir, "D", config.loss, config.iter_D_per_G))

    return G, D

# **Q1**
Investigating the 3 different losses

In [None]:
def _save_loss_curves(loss_fn, D_err, G_err, path):
    """Plot the recorded D and G losses (one point per G step) on a new figure and save it."""
    fig, ax = plt.subplots()
    ax.set_title(loss_fn + " loss")
    steps = range(len(D_err))
    ax.plot(steps, D_err, alpha=0.5, label="D loss")
    ax.plot(steps, G_err, alpha=0.5, label="G loss")
    ax.legend()
    fig.savefig(path)
    plt.close(fig)


def Q1(loss_fns=('Saturated', 'Non-Saturation', 'L2')):
    """Train one GAN per loss in ``loss_fns`` (default: all three) and save its figures.

    For each loss, train with ``Config(loss_fn)``, then save:
    * the D/G loss curves to ``<root_dir>/Q1/<loss>_loss__<k>:1.png``;
    * an animation of G's progress via ``animate``.
    Here ``k`` is the run's ``iter_D_per_G``.

    A failing run is reported with its traceback and the loop moves on to the
    next loss. Whatever was recorded before the failure is still plotted.
    """
    import traceback

    # Try different loss functions
    for loss_fn in loss_fns:
        con = Config(loss_fn=loss_fn)
        img_list, D_err, G_err = [], [], []

        try:
            trainloop(con, img_list, D_err, G_err)
        except Exception:
            # Best effort across losses, but never hide why a run failed
            print("Training with '%s' loss failed:" % loss_fn)
            traceback.print_exc()

        # plot the loss and animate the progress even when training fails or is interrupted
        finally:
            _save_loss_curves(loss_fn, D_err, G_err,
                              '%s/Q1/%s_loss__%d:1.png' % (root_dir, loss_fn, con.iter_D_per_G))
            # Nothing to animate if training failed before the first snapshot
            if img_list:
                # Use the run's ratio so the GIF name matches the plot/checkpoint names
                animate(img_list, loss_fn, dim=(8, 8), iter_D_per_G=con.iter_D_per_G)
            print("Figures saved successfully")

# **Q2**
Find optimal Z for a given image

In [None]:
def find_optimal_latent_vec(im, G, loss="L2", steps=1000, lr=0.01, z_dim=None):
    """GAN inversion: optimise latent vectors so that ``G(z)`` reconstructs ``im``.

    Args:
        im: target image batch; ``G(z)`` must have the same shape. One latent
            vector is optimised per sample (``im.size(0)`` of them).
        G: generator. Only ``z`` is optimised; G's weights are left untouched.
        loss: "L2" (MSE) or "L1" (mean absolute error) reconstruction loss.
        steps: number of Adam steps.
        lr: Adam learning rate.
        z_dim: latent length; defaults to the notebook-wide ``z_size``.

    Returns:
        The optimised latent tensor of shape (im.size(0), z_dim, 1, 1), with
        ``requires_grad=True``.

    Raises:
        ValueError: if ``loss`` is neither "L2" nor "L1".
    """
    if loss not in ("L2", "L1"):
        raise ValueError("loss should be 'L2' or 'L1', got %r" % (loss,))
    criterion = nn.MSELoss() if loss == "L2" else nn.L1Loss()

    if z_dim is None:
        z_dim = z_size

    # Start from random noise; z is the only tensor the optimizer updates
    z = torch.randn(im.size(0), z_dim, 1, 1, requires_grad=True)
    optimizer = optim.Adam([z], lr)

    # Temporarily freeze G's parameters so backward() neither computes nor keeps
    # accumulating gradients for them on every step; flags are restored afterwards.
    params = list(G.parameters()) if isinstance(G, nn.Module) else []
    saved_flags = [p.requires_grad for p in params]
    for p in params:
        p.requires_grad_(False)

    try:
        for _ in range(steps):
            optimizer.zero_grad()
            # Reconstruction error between G(z) and the target image
            rec_loss = criterion(G(z), im)
            rec_loss.backward()
            optimizer.step()
    finally:
        for p, flag in zip(params, saved_flags):
            p.requires_grad_(flag)

    return z

In [None]:
def Q2(loss_fn, iter_D_per_G, num_images=5):
    """GAN inversion on random MNIST digits with a trained Generator.

    Loads the Generator trained with ``loss_fn`` / ``iter_D_per_G`` and inverts
    ``num_images`` random images with an L2 reconstruction loss. For each one it
    saves an original-vs-reconstruction figure under ``<root_dir>/Q2/``.
    """
    # Load the matching G (load_G returns it in eval mode)
    con = Config(loss_fn=loss_fn, iter_D_per_G=iter_D_per_G)
    G = load_G(con)

    # One shuffled pass over single images: each next() yields a different digit
    images = iter(get_dataloader(batch_size=1))

    # Use GAN inversion for num_images different images
    for i in range(num_images):
        image, _ = next(images)

        # Find an optimal latent vector that reconstructs this image
        optimal_Z = find_optimal_latent_vec(image, G, "L2")

        # Reconstruct from the inverted latent vector; no graph is needed for plotting
        with torch.no_grad():
            reconstructed_image = G(optimal_Z)

        # Two-panel figure (target vs. reconstruction): no window to outline
        path = '%s/Q2/%s_loss__%d:1_%d.png'%(root_dir, con.loss, con.iter_D_per_G, i)
        show_images([image, reconstructed_image], path, rect_coor=None)

# **Q3**
Reconstructing corrupted images using GAN inversion

In [None]:
# Adding noise taken from Normal distribution with mean=0 and std=0.1 (by default)
def add_gaussian_noise(im, std=0.1):
    """Corrupt ``im`` with additive Gaussian noise N(0, std^2) for the denoising task.

    Returns ``(noisy_image, None, None)``. The two ``None`` slots mirror the
    ``(image, top, left)`` return shape of the inpainting corruption, so both can
    be called the same way. ``im`` is not modified and the result is not clipped.
    """
    noisy_image = im + torch.randn_like(im) * std
    return noisy_image, None, None

In [None]:
def paint_random_window_black(image, window_size=8, margin=5):
    """Zero out a random ``window_size`` x ``window_size`` square of ``image`` (inpainting task).

    The window's top-left corner is drawn with Python's ``random`` so the window
    stays at least ``margin`` pixels away from every border. The square is blacked
    out over the last two dimensions of every leading (batch / channel) slice.

    Returns:
        ``(damaged_clone, top, left)``; ``image`` itself is left unchanged.

    Raises:
        ValueError: (from ``random.randint``) if the image is too small for the
            window plus margins.
    """
    height, width = image.shape[-2], image.shape[-1]

    # Choose the window's corner randomly
    top = random.randint(margin, height - window_size - margin)
    left = random.randint(margin, width - window_size - margin)

    # Paint the window in black, on a copy
    clone = image.clone()
    clone[..., top:top + window_size, left:left + window_size] = 0

    return clone, top, left

In [None]:
def Q3(loss_fn, iter_D_per_G, num_images=30):
    """Restore corrupted MNIST digits by GAN inversion with a trained Generator.

    There are two tasks:
    * "Denoising": additive Gaussian noise, inverted with an L2 loss.
    * "Inpainting": a random black window, inverted with an L1 loss.

    For each task, ``num_images`` random images are corrupted and inverted through
    G. Original / damaged / reconstructed figures are saved under
    ``<root_dir>/Q3/<task>/``.
    """
    # Load the proper G (load_G returns it in eval mode)
    con = Config(loss_fn=loss_fn, iter_D_per_G=iter_D_per_G)
    G = load_G(con)

    # Definitions: For denoising choose MSE, for inpainting choose L1
    tasks = [
        ("Denoising", add_gaussian_noise, "L2"),
        ("Inpainting", paint_random_window_black, "L1"),
    ]

    # One shuffled stream of single images shared by both tasks, so no image repeats
    images = iter(get_dataloader(batch_size=1))

    for method, corrupt, l in tasks:
        for i in range(num_images):
            orig_image, _ = next(images)

            # Corrupt the image. In case of inpainting, top/left are the window's corner
            corrupt_image, top, left = corrupt(orig_image)

            # Find an optimal latent vector that reconstructs this image
            optimal_Z = find_optimal_latent_vec(corrupt_image, G, l, steps=1000, lr=0.01)

            # Reconstruct from the inverted latent vector; no graph is needed for plotting
            with torch.no_grad():
                reconstructed_image = G(optimal_Z)

            # Display the original, damaged and reconstructed images
            path = '%s/Q3/%s/%s_loss__%d:1_%d.png'%(root_dir, method, con.loss, con.iter_D_per_G, i)
            show_images([orig_image, corrupt_image, reconstructed_image], path, rect_coor=[top,left])


# **Prompt Runners**

In [None]:
# Train one GAN per loss ('Saturated', 'Non-Saturation', 'L2'); checkpoints and figures go to root_dir/Q1
Q1()

In [None]:
# Invert random MNIST digits using the G trained with the non-saturating loss (G updated every 3 batches)
Q2('Non-Saturation', 3)

In [None]:
# Repair corrupted digits via GAN inversion with the same Generator; figures go to root_dir/Q3
Q3('Non-Saturation', 3)