# GANs Homework

This assignment consists of 2 parts:
- SNGAN on CIFAR-10 (15 points)
- CycleGAN on ColoredMNIST (10 points)

Each task comes with starter code so that you can concentrate on the most interesting parts of the assignment. You are free to modify any code in the solution sections if that makes things more convenient. However, you should not modify the data acquisition and results sections.

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import IPython
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as data
import torch.optim as optim
from scipy.stats import norm
from tqdm import trange, tqdm_notebook
import os.path as osp
import warnings
warnings.filterwarnings('ignore')

# GANs on CIFAR-10 (15 points)

In this exercise, you will train GANs on CIFAR-10. Execute the cell below to visualize the dataset. 

In [None]:
import torchvision
from torchvision.utils import make_grid
from matplotlib.pyplot import savefig

def show_samples(samples, fname=None, nrow=10, title='Samples'):
    samples = (torch.FloatTensor(samples) / 255).permute(0, 3, 1, 2)
    grid_img = make_grid(samples, nrow=nrow)
    plt.figure()
    plt.title(title)
    plt.imshow(grid_img.permute(1, 2, 0))
    plt.axis('off')

    if fname is not None:
        savefig(fname)
    else:
        plt.show()

def load_cifar():
    train_data = torchvision.datasets.CIFAR10("./data", transform=torchvision.transforms.ToTensor(),
                                              download=True, train=True)
    return train_data

def visualize_cifar():
    train_data = load_cifar()
    imgs = train_data.data[:100]
    show_samples(imgs, title=f'CIFAR-10 Samples')

visualize_cifar()

We'll use the CIFAR-10 architecture from the [SN-GAN paper](https://arxiv.org/pdf/1802.05957.pdf), with a latent vector $z \in \mathbb R ^{128}$ sampled as $z \sim \mathcal N (0, I_{128})$. Instead of upsampling via transposed convolutions and downsampling via pooling or striding, we'll use the DepthToSpace and SpaceToDepth operations shown below to change the spatial configuration of our hidden states.

```
# Spatial Upsampling with Nearest Neighbors
Upsample_Conv2d(in_dim, out_dim, kernel_size=(3, 3), stride=1, padding=1, bias=True):
    x = torch.cat([x, x, x, x], dim=1)
    x = DepthToSpace(block_size=2)(x)
    x = nn.Conv2d(in_dim, out_dim, kernel_size,
                  stride=stride, padding=padding, bias=bias)(x)


# Spatial Downsampling with Spatial Mean Pooling
Downsample_Conv2d(in_dim, out_dim, kernel_size=(3, 3), stride=1, padding=1, bias=True):
    x = SpaceToDepth(block_size=2)(x)
    x = sum(x.chunk(4, dim=1)) / 4.0
    x = nn.Conv2d(in_dim, out_dim, kernel_size,
                  stride=stride, padding=padding, bias=bias)(x)
```
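As a sanity check for your own modules, the two resampling steps above (before the convolution) should behave like nearest-neighbor upsampling and 2x2 mean pooling. The sketch below illustrates this with compact, hypothetical helpers `depth_to_space` and `space_to_depth`. They are not part of the starter code and assume the TensorFlow-style channel ordering, in which the block offset varies slowest along the channel axis.

```python
import torch
import torch.nn.functional as F

def depth_to_space(x, block_size=2):
    # Hypothetical helper: channel index is assumed to be (i * r + j) * C + c
    b, d, h, w = x.shape
    r = block_size
    c = d // (r * r)
    x = x.reshape(b, r, r, c, h, w)        # (B, i, j, C, H, W)
    x = x.permute(0, 3, 4, 1, 5, 2)        # (B, C, H, i, W, j)
    return x.reshape(b, c, h * r, w * r)

def space_to_depth(x, block_size=2):
    # Hypothetical helper: inverse rearrangement of depth_to_space
    b, c, h, w = x.shape
    r = block_size
    x = x.reshape(b, c, h // r, r, w // r, r)  # (B, C, H', i, W', j)
    x = x.permute(0, 3, 5, 1, 2, 4)            # (B, i, j, C, H', W')
    return x.reshape(b, r * r * c, h // r, w // r)

x = torch.randn(2, 3, 4, 4)

# Upsampling path: four copies along channels, then depth-to-space
up = depth_to_space(torch.cat([x, x, x, x], dim=1))
up_matches_nearest = torch.allclose(
    up, F.interpolate(x, scale_factor=2, mode="nearest"))

# Downsampling path: space-to-depth, then average the four channel chunks
down = sum(space_to_depth(x).chunk(4, dim=1)) / 4.0
down_matches_avgpool = torch.allclose(down, F.avg_pool2d(x, 2))
```

Comparing your `Upsample_Conv2d` and `Downsample_Conv2d` intermediates against `F.interpolate` and `F.avg_pool2d` in the same way is a cheap way to catch channel-ordering mistakes.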

Here's pseudocode for how we'll implement a ResBlockUp, used in the generator:

```
ResnetBlockUp(x, in_dim, kernel_size=(3, 3), n_filters=256):
    _x = x
    _x = nn.BatchNorm2d(in_dim)(_x)
    _x = nn.ReLU()(_x)
    _x = nn.Conv2d(in_dim, n_filters, kernel_size, padding=1)(_x)
    _x = nn.BatchNorm2d(n_filters)(_x)
    _x = nn.ReLU()(_x)
    residual = Upsample_Conv2d(n_filters, n_filters, kernel_size, padding=1)(_x)
    shortcut = Upsample_Conv2d(in_dim, n_filters, kernel_size=(1, 1), padding=0)(x)
    return residual + shortcut
```
The ResBlockDown module is similar, except that it uses Downsample_Conv2d, omits the BatchNorm layers, and applies spectral_norm. You may also try spectral_norm in the generator; it works quite well as a regularizer.

Finally, here's the architecture for the generator:
```
def Generator(*, n_samples=1024, n_filters=128):
    z = Normal(0, 1)([n_samples, 128])
    nn.Linear(128, 4*4*256)
    reshape output of linear layer
    ResnetBlockUp(in_dim=256, n_filters=n_filters),
    ResnetBlockUp(in_dim=n_filters, n_filters=n_filters),
    ResnetBlockUp(in_dim=n_filters, n_filters=n_filters),
    nn.BatchNorm2d(n_filters),
    nn.ReLU(),
    nn.Conv2d(n_filters, 3, kernel_size=(3, 3), padding=1),
    nn.Tanh()
```
Again, the discriminator has the same architecture, except that it uses ResnetBlockDown and **no BatchNorm**.

**Hyperparameters**

We'll implement SNGAN, which uses spectral normalization to enforce Lipschitz continuity. To add the spectral norm parametrization, use [spectral_norm](https://pytorch.org/docs/1.13/generated/torch.nn.utils.parametrizations.spectral_norm.html) from PyTorch.
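A minimal sketch of what the parametrization does, using a small stand-in `nn.Linear` layer (the layer sizes and the number of warm-up passes are arbitrary choices for illustration): after wrapping, the effective weight is rescaled so that its largest singular value is approximately 1.

```python
import torch
import torch.nn as nn
from torch.nn.utils.parametrizations import spectral_norm

torch.manual_seed(0)

# Stand-in layer; in the assignment you would wrap your conv/linear layers.
layer = spectral_norm(nn.Linear(16, 8))

# In training mode, each access to the weight refines the power-iteration
# estimate of the largest singular value.
layer.train()
with torch.no_grad():
    for _ in range(50):
        layer(torch.randn(4, 16))

# In eval mode the stored estimate is reused without further updates.
layer.eval()
sigma = torch.linalg.matrix_norm(layer.weight.detach(), ord=2).item()
print(f"largest singular value of normalized weight: {sigma:.4f}")
```

Because the estimate comes from power iteration, the value is only approximately 1, especially early in training.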

Use the Adam optimizer with $\alpha = 2 \times 10^{-4}$, $\beta_1 = 0.5$, $\beta_2 = 0.999$. Use the GAN training parameter $n_{critic} = 3$. Use a batch size of 256 and n_filters=128 within the ResBlocks.
Train for at least 25000 gradient steps, with the learning rate linearly annealed to 0 over training. (**Warning: 25000 steps will take ~12 hours on Colab, so consider starting with 2500 steps and continuing only if your generator converges to something reasonable.**)
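One possible way to set up these hyperparameters is sketched below with `torch.optim.lr_scheduler.LambdaLR`. The `model` and `n_steps` names are hypothetical stand-ins. In your solver, `n_steps` should equal the number of times the scheduler is actually stepped.

```python
import torch.nn as nn
import torch.optim as optim

model = nn.Linear(128, 1)   # hypothetical stand-in for a generator/discriminator
n_steps = 10                # hypothetical total number of scheduler steps

optimizer = optim.Adam(model.parameters(), lr=2e-4, betas=(0.5, 0.999))
# Multiplicative factor decays linearly from 1 to 0 over n_steps
scheduler = optim.lr_scheduler.LambdaLR(
    optimizer, lr_lambda=lambda step: 1.0 - step / n_steps)

lrs = []
for _ in range(n_steps):
    lrs.append(scheduler.get_last_lr()[0])
    optimizer.step()        # no gradients here; a real loop would backprop first
    scheduler.step()
final_lr = scheduler.get_last_lr()[0]
```

Other schedulers that implement the same linear decay would work equally well; this is only one option.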

**Objective**

We'll use the hinge-loss objective for the discriminator:

$$\arg\max_D \left( \mathbb{E}_{x \sim q_{data}(x)} \min(0, -1 + D(x)) + \mathbb{E}_{z \sim p(z)} \min(0, -1 - D(G(z))) \right)$$

And for the generator:

$$\arg\min_G \mathbb{E}_{z \sim p(z)} \left[ -D(G(z)) \right]$$
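Since optimizers minimize, the discriminator objective is usually negated in code. Using $\min(0, -1 + a) = -\max(0, 1 - a)$, the two objectives can be sketched as follows (the function names `d_hinge_loss` and `g_hinge_loss` are hypothetical, and the inputs are toy discriminator outputs):

```python
import torch
import torch.nn.functional as F

def d_hinge_loss(d_real, d_fake):
    # Negated discriminator objective: minimizing this maximizes
    # E[min(0, -1 + D(x))] + E[min(0, -1 - D(G(z)))]
    return F.relu(1.0 - d_real).mean() + F.relu(1.0 + d_fake).mean()

def g_hinge_loss(d_fake):
    # Generator objective: minimize E[-D(G(z))]
    return -d_fake.mean()

# Toy discriminator outputs for illustration only
d_real = torch.tensor([2.0, 0.0])
d_fake = torch.tensor([-2.0, 0.0])
d_loss = d_hinge_loss(d_real, d_fake)
g_loss = g_hinge_loss(d_fake)
```

Samples that are already classified with a margin of at least 1 contribute nothing to the discriminator loss.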

**You will provide the following deliverables**
1. Inception score (CIFAR-10 version) of the final model. We provide a utility that will automatically do this for you.
2. Generator objectives across training.
3. 100 samples.
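For intuition about the first deliverable, the Inception score is commonly defined as $\exp\left(\mathbb{E}_x \, KL(p(y|x) \,\|\, p(y))\right)$, where $p(y|x)$ is a classifier's prediction for a sample and $p(y)$ is the average prediction over all samples. The sketch below evaluates it on toy prediction matrices; the function name and the inputs are hypothetical and do not replace the provided utility.

```python
import numpy as np

def inception_score(preds):
    # preds: (n_samples, n_classes) rows of class probabilities p(y | x)
    marginal = preds.mean(axis=0, keepdims=True)            # p(y)
    kl = (preds * (np.log(preds) - np.log(marginal))).sum(axis=1)
    return float(np.exp(kl.mean()))

n_classes = 10

# Uninformative predictions: every sample gets the uniform distribution
uniform = np.full((50, n_classes), 1.0 / n_classes)

# Confident and diverse predictions: one near-one-hot row per class
eps = 1e-6
confident = np.full((n_classes, n_classes), eps / (n_classes - 1))
np.fill_diagonal(confident, 1.0 - eps)

score_uniform = inception_score(uniform)      # close to 1 (the minimum)
score_confident = inception_score(confident)  # close to n_classes (the maximum)
```

Higher scores therefore require samples that are both individually recognizable and collectively diverse.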

## Solution

In [None]:
from torch.nn.utils.parametrizations import spectral_norm


class DepthToSpace(nn.Module):
    def __init__(self, block_size):
        super(DepthToSpace, self).__init__()
        self.block_size = block_size
        self.block_size_sq = block_size * block_size

    def forward(self, input):
        output = input.permute(0, 2, 3, 1)
        (batch_size, d_height, d_width, d_depth) = output.size()
        s_depth = int(d_depth / self.block_size_sq)
        s_width = int(d_width * self.block_size)
        s_height = int(d_height * self.block_size)
        t_1 = output.reshape(batch_size, d_height, d_width, self.block_size_sq, s_depth)
        spl = t_1.split(self.block_size, 3)
        stack = [t_t.reshape(batch_size, d_height, s_width, s_depth) for t_t in spl]
        output = torch.stack(stack, 0).transpose(0, 1).permute(0, 2, 1, 3, 4).reshape(batch_size, 
                                                                                      s_height, 
                                                                                      s_width,
                                                                                      s_depth)
        output = output.permute(0, 3, 1, 2)
        return output.contiguous()


class SpaceToDepth(nn.Module):
    def __init__(self, block_size):
        super(SpaceToDepth, self).__init__()
        self.block_size = block_size
        self.block_size_sq = block_size * block_size

    def forward(self, input):
        output = input.permute(0, 2, 3, 1)
        (batch_size, s_height, s_width, s_depth) = output.size()
        d_depth = s_depth * self.block_size_sq
        d_width = int(s_width / self.block_size)
        d_height = int(s_height / self.block_size)
        t_1 = output.split(self.block_size, 2)
        stack = [t_t.reshape(batch_size, d_height, d_depth) for t_t in t_1]
        output = torch.stack(stack, 1)
        output = output.permute(0, 2, 1, 3)
        output = output.permute(0, 3, 1, 2)
        return output.contiguous()


class Upsample_Conv2d(nn.Module):
    def __init__(self, in_dim, out_dim, kernel_size=(3, 3), stride=1, padding=1, bias=True):
        super(Upsample_Conv2d, self).__init__()
        # YOUR_CODE

    def forward(self, x):
        # YOUR_CODE
        return x


class Downsample_Conv2d(nn.Module):
    def __init__(self, in_dim, out_dim, kernel_size=(3, 3), stride=1, padding=1, bias=True):
        super(Downsample_Conv2d, self).__init__()
        # YOUR_CODE

    def forward(self, x):
        # YOUR_CODE
        return x


class ResnetBlockUp(nn.Module):
    def __init__(self, in_dim, kernel_size=(3, 3), n_filters=256):
        super(ResnetBlockUp, self).__init__()
        self.residual = # YOUR_CODE
        self.shortcut = # YOUR_CODE

    def forward(self, x):
        _x = x
        residual = self.residual(_x)
        shortcut = self.shortcut(x)
        return residual + shortcut


class ResnetBlockDown(nn.Module):
    def __init__(self, in_dim, kernel_size=(3, 3), stride=1, n_filters=256):
        super(ResnetBlockDown, self).__init__()
        self.residual = # YOUR_CODE
        self.shortcut = # YOUR_CODE

    def forward(self, x):
        _x = x
        residual = self.residual(_x)
        shortcut = self.shortcut(x)
        return residual + shortcut

In [None]:
class Generator(nn.Module):
    def __init__(self, n_filters=256):
        super(Generator, self).__init__()
        self.fc = # YOUR_CODE
        network = [
            # YOUR_CODE
        ]
        self.net = nn.Sequential(*network)
        self.noise = torch.distributions.Normal(torch.tensor(0.), torch.tensor(1.))

    def forward(self, z):
        z = self.fc(z).reshape(-1, 256, 4, 4)
        return self.net(z)

    def sample(self, n_samples):
        z = self.noise.sample([n_samples, 128]).to(device)
        return self.forward(z)

    
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        network = [
            # YOUR_CODE
        ]
        self.net = nn.Sequential(*network)
        self.fc = spectral_norm(nn.Linear(128, 1))

    def forward(self, z):
        z = self.net(z)
        z = torch.mean(z, dim=(2, 3))
        return self.fc(z)

In [None]:
def plot_gan_training(losses, title, fname):
    plt.figure()
    n_itr = len(losses)
    xs = np.arange(n_itr)

    plt.plot(xs, losses, label='loss')
    plt.legend()
    plt.title(title)
    plt.xlabel('Training Iteration')
    plt.ylabel('Loss')

In [None]:
from tqdm import tqdm
from IPython.display import clear_output

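# Fall back to CPU when no GPU is available (training will be very slow on CPU)
device = 'cuda' if torch.cuda.is_available() else 'cpu'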

class Solver(object):
    def __init__(self, train_data, n_iterations=5000, batch_size=256, n_filters=128):
        self.n_critic = 3
        self.log_interval = 100
        self.batch_size = batch_size
        self.n_filters = n_filters
        self.train_loader = self.create_loaders(train_data)
        self.n_batches_in_epoch = len(self.train_loader)
        self.n_epochs = n_iterations // self.n_batches_in_epoch
        self.curr_itr = 0

    def build(self, part_name):
        self.g = Generator(n_filters=self.n_filters).to(device)
        self.d = Discriminator().to(device)
        self.g_optimizer = # TODO (do not forget to use recommended parameters)
        self.g_scheduler = # TODO
        self.d_optimizer = # TODO
        self.d_scheduler = # TODO
        self.part_name = part_name

    def create_loaders(self, train_data):
        train_loader = data.DataLoader(train_data, batch_size=self.batch_size, shuffle=True)
        return train_loader
    
    def log_progress(self, losses):
        clear_output()
        print(f'current g_lr: {self.g_scheduler.get_last_lr()}')
        print(f'current d_lr: {self.d_scheduler.get_last_lr()}')
        print(f'iter: {self.curr_itr}')
        plot_gan_training(losses, 'Q2 Losses', 'results/q2_losses.png')
        
        self.g.eval()
        with torch.no_grad():
            samples = self.g.sample(100)
            samples = samples.permute(0, 2, 3, 1).detach().cpu().numpy() * 0.5 + 0.5
        show_samples(samples[:100] * 255.0, fname=None, title=f'CIFAR-10 generated samples')

    def train(self):
        train_losses = []
        for epoch_i in range(self.n_epochs):
            epoch_i += 1

            for batch_i, x in enumerate(self.train_loader):
                self.d.train()
                self.g.train()
                batch_i += 1
                self.curr_itr += 1
                # the DataLoader already yields tensors, so convert in place
                x = x.float().to(device)
                x = 2 * (x - 0.5)

                # do a critic update
                # TODO
                
                # generator update
                if self.curr_itr % self.n_critic == 0:
                    self.d.eval()
                    # TODO

                    train_losses.append(g_loss.cpu().item())
                    self.d.train()
                
                if self.curr_itr % self.log_interval == 0:
                    self.log_progress(train_losses)
            
            # step the learning rate
            self.g_scheduler.step()
            self.d_scheduler.step()
            np.save("q2_train_losses.npy", np.array(train_losses))

        train_losses = np.array(train_losses)
        self.save_model(f"{self.part_name}.pt")
        return train_losses

    def save_model(self, filename):
        split_path = list(osp.split(filename))
        g_path = osp.join(*split_path[:-1], 'g_' + split_path[-1])
        d_path = osp.join(*split_path[:-1], 'd_' + split_path[-1])
        torch.save(self.g.state_dict(), g_path)
        torch.save(self.d.state_dict(), d_path)

    def load_model(self, filename):
        split_path = list(osp.split(filename))
        g_path = osp.join(*split_path[:-1], 'g_' + split_path[-1])
        d_path = osp.join(*split_path[:-1], 'd_' + split_path[-1])
        self.d.load_state_dict(torch.load(d_path))
        self.g.load_state_dict(torch.load(g_path))

In [None]:
def q2(train_data, model_path=None, losses_path="./q2_train_losses.npy", n_iterations=5000):
    """
    train_data: An (n_train, 3, 32, 32) numpy array of CIFAR-10 images with values in [0, 1]

    Returns
    - a (# of generator updates,) numpy array of generator train losses, recorded at every generator update
    - a (1000, 32, 32, 3) numpy array of samples from your model in [0, 1]. 
        The first 100 will be displayed, and the rest will be used to calculate the Inception score. 
    """
    solver = Solver(train_data, n_iterations=n_iterations)
    solver.build("sngan")
    if model_path is not None and losses_path is not None:
        solver.load_model(model_path)
        losses = np.load(losses_path)
    else:
        losses = solver.train()
    solver.g.eval()
    solver.d.eval()
    with torch.no_grad():
        # 1000 samples, as documented above: 100 to display, the rest for the Inception score
        samples = solver.g.sample(1000)
        samples = samples.permute(0, 2, 3, 1).detach().cpu().numpy() * 0.5 + 0.5

    return losses, samples

## Results

In [None]:
import torch.nn.functional as F
import math
import sys

def calculate_is(samples):
    assert (type(samples[0]) == np.ndarray)
    assert (len(samples[0].shape) == 3)

    model = torchvision.models.googlenet(pretrained=True).to(device)

    bs = 100
    model.eval()
    with torch.no_grad():
        preds = []
        n_batches = int(math.ceil(float(len(samples)) / float(bs)))
        for i in range(n_batches):
            sys.stdout.write(".")
            sys.stdout.flush()
            inp = torch.FloatTensor(samples[(i * bs):min((i + 1) * bs, len(samples))]).to(device)
            pred = F.softmax(model(inp), dim=1).detach().cpu().numpy()
            preds.append(pred)
    preds = np.concatenate(preds, 0)
    kl = preds * (np.log(preds) - np.log(np.expand_dims(np.mean(preds, 0), 0)))
    kl = np.mean(np.sum(kl, 1))
    return np.exp(kl)

    
def sngan_save_results(fn):
    train_data = load_cifar()
    train_data = train_data.data.transpose((0, 3, 1, 2)) / 255.0
    train_losses, samples = fn(train_data)

    print("Inception score:", calculate_is(samples.transpose([0, 3, 1, 2])))
    plot_gan_training(train_losses, 'Q2 Losses', 'results/q2_losses.png')
    show_samples(samples[:100] * 255.0, fname=None, title=f'CIFAR-10 generated samples')

sngan_save_results(q2)

# CycleGAN (10 points)
In this question, you'll train a CycleGAN model to learn to translate between two different image domains, without any paired data. Execute the following cell to visualize our two datasets: MNIST and Colored MNIST. 

In [None]:
!wget https://upload.wikimedia.org/wikipedia/en/7/7d/Lenna_%28test_image%29.png -O lena.jpg

In [None]:
from torchvision import transforms

def load_mnist_data():
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,))
    ])
    train_data = torchvision.datasets.MNIST(root="./data", train=True, download=True, transform=transform)
    test_data = torchvision.datasets.MNIST(root="./data", train=False, download=True, transform=transform)
    return train_data, test_data

In [None]:
import PIL
import cv2
import scipy
import scipy.ndimage  # imported explicitly; get_colored_mnist calls scipy.ndimage.zoom

def get_colored_mnist(data):
    # from https://www.wouterbulten.nl/blog/tech/getting-started-with-gans-2-colorful-mnist/
    # Read Lena image
    lena = PIL.Image.open('./lena.jpg')

    # Resize
    batch_resized = np.asarray([scipy.ndimage.zoom(image, (2.3, 2.3, 1), order=1) for image in data])

    # Extend to RGB
    batch_rgb = np.concatenate([batch_resized, batch_resized, batch_resized], axis=3)

    # Make binary
    batch_binary = (batch_rgb > 0.5)

    batch = np.zeros((data.shape[0], 28, 28, 3))

    for i in range(data.shape[0]):
        # Take a random crop of the Lena image (background)
        x_c = np.random.randint(0, lena.size[0] - 64)
        y_c = np.random.randint(0, lena.size[1] - 64)
        image = lena.crop((x_c, y_c, x_c + 64, y_c + 64))
        image = np.asarray(image) / 255.0

        # Invert the colors at the location of the number
        image[batch_binary[i]] = 1 - image[batch_binary[i]]

        batch[i] = cv2.resize(image, (0, 0), fx=28 / 64, fy=28 / 64, interpolation=cv2.INTER_AREA)
    return batch.transpose(0, 3, 1, 2)

def load_q4_data():
    train, _ = load_mnist_data()
    mnist = np.array(train.data.reshape(-1, 28, 28, 1) / 255.0)
    colored_mnist = get_colored_mnist(mnist)
    return mnist.transpose(0, 3, 1, 2), colored_mnist

def visualize_cyclegan_datasets():
    mnist, colored_mnist = load_q4_data()
    mnist, colored_mnist = mnist[:100], colored_mnist[:100]
    show_samples(mnist.reshape([100, 28, 28, 1]) * 255.0, title=f'MNIST samples')
    show_samples(colored_mnist.transpose([0, 2, 3, 1]) * 255.0, title=f'Colored MNIST samples')

visualize_cyclegan_datasets()

In [CycleGAN](https://arxiv.org/pdf/1703.10593.pdf), the goal is to learn functions $F$ and $G$ that transform images from $X \rightarrow Y$ and from $Y \rightarrow X$, respectively. This is an under-constrained problem, so we additionally enforce the *cycle-consistency* property, where we want 
$$x \approx G(F(x))$$
and  
$$y \approx F(G(y))$$
The cycle-consistency loss encourages $F$ and $G$ to approximately invert each other. In addition to this loss, we also have a standard GAN loss so that $F(x)$ and $G(y)$ look like real images from the other domain. 
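A minimal sketch of a cycle-consistency term, assuming an L1 reconstruction penalty as in the CycleGAN paper. The names `f_xy` and `g_yx` are hypothetical stand-ins for $F$ and $G$, and the 1x1 convolutions only make the shapes line up; real translators would be deeper networks.

```python
import torch
import torch.nn as nn
import torch.nn.functional as nnf

def cycle_consistency_loss(f_xy, g_yx, x, y):
    # Round trips: X -> Y -> X and Y -> X -> Y
    x_rec = g_yx(f_xy(x))
    y_rec = f_xy(g_yx(y))
    return nnf.l1_loss(x_rec, x) + nnf.l1_loss(y_rec, y)

# Hypothetical stand-in translators: greyscale (1 channel) <-> color (3 channels)
f_xy = nn.Conv2d(1, 3, kernel_size=1)
g_yx = nn.Conv2d(3, 1, kernel_size=1)

x = torch.rand(4, 1, 28, 28)   # toy batch from the MNIST-like domain
y = torch.rand(4, 3, 28, 28)   # toy batch from the colored domain
loss = cycle_consistency_loss(f_xy, g_yx, x, y)
```

In training, this term is typically added to the two adversarial losses with a weighting coefficient that you will need to tune.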

Since this is a bonus question, we won't do much hand-holding. We recommend reading through the original paper to get a sense of what architectures and hyperparameters are useful. Note that our datasets are fairly simple, so you won't need excessively large models. 

**You will report the following deliverables**
1. A set of images showing real MNIST digits, transformations of those images into Colored MNIST digits, and reconstructions back into the greyscale domain. 
2. A set of images showing real Colored MNIST digits, transformations of those images, and reconstructions. 

## Solution

In [None]:
def q4(mnist_data, cmnist_data):
    """
    mnist_data: An (60000, 1, 28, 28) numpy array of black and white images with values in [0, 1]
    cmnist_data: An (60000, 3, 28, 28) numpy array of colored images with values in [0, 1]

    Returns
    - a (20, 28, 28, 1) numpy array of real MNIST digits, in [0, 1]
    - a (20, 28, 28, 3) numpy array of translated Colored MNIST digits, in [0, 1]
    - a (20, 28, 28, 1) numpy array of reconstructed MNIST digits, in [0, 1]

    - a (20, 28, 28, 3) numpy array of real Colored MNIST digits, in [0, 1]
    - a (20, 28, 28, 1) numpy array of translated MNIST digits, in [0, 1]
    - a (20, 28, 28, 3) numpy array of reconstructed Colored MNIST digits, in [0, 1]
    """
    """ YOUR CODE HERE """

## Results

In [None]:
def q4_save_results(fn):
    mnist, cmnist = load_q4_data()

    m1, c1, m2, c2, m3, c3 = fn(mnist, cmnist)
    m1, m2, m3 = m1.repeat(3, axis=3), m2.repeat(3, axis=3), m3.repeat(3, axis=3)
    mnist_reconstructions = np.concatenate([m1, c1, m2], axis=0)
    colored_mnist_reconstructions = np.concatenate([c2, m3, c3], axis=0)

    show_samples(mnist_reconstructions * 255.0, nrow=20,
                 fname=None,
                 title=f'Source domain: MNIST')
    show_samples(colored_mnist_reconstructions * 255.0, nrow=20,
                 fname=None,
                 title=f'Source domain: Colored MNIST')
    pass

q4_save_results(q4)