[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/DalasNoin/arena/blob/main/w4/gan.ipynb)

In [3]:
# ! wget https://raw.githubusercontent.com/callummcdougall/arena-v1/main/w4d1/utils.py
# ! wget https://raw.githubusercontent.com/dalasnoin/arena/main/w4/gan_modules.py
# ! gdown --id 1lfEWQ05cZ5FgWkSIxwyi7TLryhvnCDWb
# ! mkdir data
# ! unzip -qq img_align_celeba.zip -d data
# ! pip install einops fancy_einsum tqdm plotly
# ! pip install wandb 

In [4]:


import torch as t
from typing import Union
from torch import nn
import torch.nn.functional as F
import plotly.express as px
import plotly.graph_objects as go
from einops import rearrange, reduce, repeat
from einops.layers.torch import Rearrange
from fancy_einsum import einsum
import os
from tqdm.auto import tqdm
from torchvision import transforms, datasets
from torchvision.datasets import ImageFolder
from torch.utils.data import Dataset, DataLoader, TensorDataset
import wandb
import utils
import gan_modules
from typing import Optional

image_size = img_size = 64 # i misleadingly used both names in the code
batch_size = 3
latent_dim_size = 100
img_channels = 3
generator_num_features = 512
n_layers = 4

device = t.device("cuda:0" if t.cuda.is_available() else "cpu")
wandb_key=""
keyfile = "keystore.yaml"
if not wandb_key and os.path.exists(keyfile):
    import yaml
    keys = yaml.safe_load(open(keyfile,"r"))
    wandb_key = keys["wandb"]
os.environ["WANDB_API_KEY"] = wandb_key


In [5]:
# @t.no_grad() not necessary since nn.init already uses nograd mode
def initialize_weights(model: nn.Module) -> None:
    """ 
    They mention at the end of page 3 that all weights were initialized from a N(0, 0.02)N(0,0.02) distribution. 
    This applies to the convolutional and convolutional transpose layers' weights, 
    but the BatchNorm layers' weights should be initialised from N(1, 0.02)N(1,0.02) (since 1 is their default value). 
    The BatchNorm biases should all be set to zero (which they are by default).
    """
    for name, parameter in model.named_parameters():
        if "batchnorm" in name:
            if "bias" in name:
                nn.init.constant_(parameter.data, 0.0)
            elif "weight" in name:
                nn.init.uniform_(parameter.data, a=0.02, b=1.0)
        else:
            nn.init.uniform_(parameter.data, a=0.0, b=0.02)

In [6]:
class Generator(nn.Module):

    def __init__(
        self,
        latent_dim_size: int,           # size of the random vector we use for generating outputs
        img_size = int,                 # size of the images we're generating
        img_channels = int,             # indicates RGB images
        generator_num_features = int,   # number of channels after first projection and reshaping
        n_layers = int,                 # number of CONV_n layers
    ):
        super().__init__()
        self.latent_dim_size = latent_dim_size
        self.latent_dim_projected = 8192
        self.img_size = img_size
        self.img_channels = img_channels
        self.generator_num_features = generator_num_features
        self.initial_width = 4
        self.n_layers = n_layers
        

        self._build()


    def _build(self):
        self.latent_sequential = nn.Sequential(nn.Linear(self.latent_dim_size, self.latent_dim_projected, bias=False),
            Rearrange("a (b c d) -> a b c d", c=self.initial_width, d=self.initial_width),
            nn.BatchNorm2d((self.latent_dim_projected//self.initial_width**2)),
            nn.ReLU()
        )

        self.layer_structure = [(self.generator_num_features//(2**i),
                                    self.generator_num_features//2**(i+1), 
                                    self.initial_width*2**i) for i in range(0, self.n_layers-1)]
        self.layer_structure.append((self.img_size,self.img_channels,self.img_size))

        block_list = [ConvTransposeBlock(*structure) for structure in self.layer_structure[:-1]]
        block_list.append(ConvTransposeBlock(*self.layer_structure[-1],t.tanh, False))

        self.upsample_sequential = nn.Sequential(*block_list)

    def forward(self, x: t.Tensor):
        x = self.latent_sequential(x)
        return self.upsample_sequential(x)

class ConvTransposeBlock(nn.Module):
    def __init__(self, in_channels: int, out_channels: int, width: int, activation_function: callable=None, uses_batchnorm:bool=True):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.width = width
        self.convtranspose = gan_modules.ConvTranspose2d(in_channels=in_channels,
                                            out_channels=out_channels,
                                            kernel_size=4,
                                            stride=2,
                                            padding=1)
        self.uses_batchnorm = uses_batchnorm
        if uses_batchnorm:
            self.batchnorm = nn.BatchNorm2d((self.out_channels))
        if activation_function is None:
            self.activation_function = nn.ReLU()
        else:
            self.activation_function = activation_function

    def forward(self, x):
        x = self.convtranspose(x)
        if self.uses_batchnorm:
            x = self.batchnorm(x)
        return self.activation_function(x)

generator = Generator(
    latent_dim_size=latent_dim_size,
    img_size=image_size,
    img_channels=img_channels,
    generator_num_features=generator_num_features,
    n_layers=n_layers
)
generator.layer_structure


[(512, 256, 4), (256, 128, 8), (128, 64, 16), (64, 3, 64)]

In [7]:
class Discriminator(nn.Module):
    def __init__(
        self,
        img_size = 64,
        img_channels = 3,
        generator_num_features = 1024,
        n_layers = 4,
    ):
        super().__init__()
        self.img_size = img_size
        self.img_channels = img_channels
        self.generator_num_features = generator_num_features
        self.n_layers = n_layers

        # set as constants for now
        self.conv_params = {
            "kernel_size": 4,
            "stride": 2,
            "padding": 1,
            "bias": False
        }

        

        self.layer_structure = [(img_size*2**i, img_size*(2**(i+1)), img_size//2**(i+2)) for i in range(0, n_layers-1)]

        self.initial_conv = nn.Conv2d(in_channels=self.img_channels, 
                                        out_channels=self.img_size,  # this seems to be the case for this model
                                        **self.conv_params)
        
        self.downsampling_block = nn.Sequential(*[ConvBlock(*in_out_width_tuple,conv_params=self.conv_params) for in_out_width_tuple in self.layer_structure])

        self.rearrange_layer = Rearrange("a b h w -> a (b h w)")

        self.classifier = nn.Linear(in_features=(self.layer_structure[-1][1]*self.layer_structure[-1][2]**2), # output size of the last layer
                                    out_features=1,
                                    bias=False
        )

        self.sigmoid = nn.Sigmoid()
        

    def forward(self, x: t.Tensor):
        x = self.initial_conv(x)
        x = self.downsampling_block(x)
        x = self.rearrange_layer(x)
        return self.sigmoid(self.classifier(x))


class ConvBlock(nn.Module):
    def __init__(self, in_channels:int, out_channels:int, width:int, conv_params:dict, uses_batchnorm:bool=True):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.width = width
        self.uses_batchnorm = uses_batchnorm

        self.conv_params = conv_params

        self.negative_slope=0.02

        self.conv = nn.Conv2d(in_channels=self.in_channels,
            out_channels=self.out_channels,
            **self.conv_params
        )

        if self.uses_batchnorm:
            self.batchnorm = nn.BatchNorm2d((out_channels))

        self.leaky_relu = gan_modules.LeakyReLU(negative_slope=self.negative_slope)

    def forward(self, x):
        x = self.conv(x)
        if self.uses_batchnorm:
            x = self.batchnorm(x)
        return self.leaky_relu(x)

discriminator = Discriminator(
    img_size=img_size,
    img_channels=img_channels,
    generator_num_features=generator_num_features,
    n_layers=n_layers
)


In [8]:
from torchvision import transforms, datasets

from torch.utils.data import DataLoader


transform = transforms.Compose([
    transforms.Resize((image_size,image_size)),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

trainset = ImageFolder(
    root="data",
    transform=transform
)

utils.show_images(trainset, rows=3, cols=5)

trainloader = DataLoader(trainset, batch_size=batch_size, shuffle=True)

In [9]:
trainset[0][0].shape

torch.Size([3, 64, 64])

In [10]:
def train_generator_discriminator(
    netG: Generator, 
    netD: Discriminator, 
    optG,
    optD,
    trainloader,
    epochs: int,
    max_epoch_duration: Optional[Union[int, float]] = None,           # Each epoch terminates after this many seconds
    print_netG_output_interval: Optional[Union[int, float]] = None,   # Generator output is printed at this frequency
    use_wandb: bool = False
):
    """
Discriminator:
1. Zero the gradients of DD. This is important because if the last thing we did was evaluate D(G(z)) 
    (in order to update the parameters of G), then D will have stored gradients from that backward pass.

Generate random noise z, and compute D(G(z)). Take the average of log(1−D(G(z))), 
and we have the first part of our loss function.

Take the real images x in the current batch, 
and use that to compute log(D(x)) (we use this rather than log(1−D(x)), 
for reasons we'll discuss below). This gives us the second part of our loss function.

We now add the two terms together, and perform gradient ascent (since we're trying to maximise this expression).
You can perform gradient ascent by either flipping the sign of the thing you're doing a backward pass on, 
or passing the keyword argument maximize=True when defining your optimiser (all optimisers have this option).

Tip - when calculating D(G(z)), for the purpose of training the discriminator, 
it's best to first calculate G(z) then call detach on this tensor before passing it to DD. 
This is because you then don't need to worry about gradients accumulating for GG.


"""

    if use_wandb:
        wandb.init()

    step = 0

    netG.train().to(device)
    netD.train().to(device)

    for epoch in range(epochs):
        for i, (real_images, labels) in enumerate(trainloader):
            
            real_images=real_images.to(device)
            labels=labels.to(device)
            # Discriminator training step
            # 1.
            optD.zero_grad()
            # generate image from random
            z = t.empty((batch_size, latent_dim_size)).uniform_(0,1).to(device)
            
            generated_images = netG(z).detach()
            # use discriminator to get both components of the loss
            loss_tensor = t.log(1-netD(generated_images) + 1e-5)
            loss_generated_images = t.mean(loss_tensor)
            
            
            loss_real_images = t.mean(t.log(netD(real_images)))

            loss_discriminator = loss_generated_images + loss_real_images
            print(loss_discriminator, loss_generated_images, loss_real_images)

            loss_discriminator.backward()

            optD.step()


            # Generator training step
            """Generator:
We take the following steps:

Zero the gradients of GG.
Generate random noise zz, and compute D(G(z)).
We don't use log(1−D(G(z))) to calculate our loss function, instead we use log(D(G(z))) (and gradient ascent).
"""
            optG.zero_grad()
            z = t.empty((batch_size, latent_dim_size)).uniform_(0,1).to(device)
            predictions = discriminator(generator(z))
            loss_generator = t.log(t.mean(predictions))
            loss_generator.backward()
            
            optG.step()

            step += batch_size

            if use_wandb:
                wandb.log(dict(epoch=epoch, loss_discriminator=loss_discriminator, loss_generator=loss_generator), step=step)

            
    if use_wandb:         
        wandb.finish()

lr = 5e-5

initialize_weights(generator)
initialize_weights(discriminator)

optimizer_generator = t.optim.Adam(generator.parameters(),lr=lr, maximize=True)
optimizer_discriminator = t.optim.Adam(discriminator.parameters(),lr=lr, maximize=True)

train_generator_discriminator(
    netG=generator,
    netD=discriminator,
    optG=optimizer_generator,
    optD=optimizer_discriminator,
    trainloader=trainloader,
    epochs=1,
    use_wandb=False
)



tensor(-4.5192, grad_fn=<AddBackward0>) tensor(-4.1511, grad_fn=<MeanBackward0>) tensor(-0.3681, grad_fn=<MeanBackward0>)
tensor(-7.9414, grad_fn=<AddBackward0>) tensor(-7.8113, grad_fn=<MeanBackward0>) tensor(-0.1301, grad_fn=<MeanBackward0>)
tensor(-10.5911, grad_fn=<AddBackward0>) tensor(-10.0312, grad_fn=<MeanBackward0>) tensor(-0.5599, grad_fn=<MeanBackward0>)
tensor(-4.4218, grad_fn=<AddBackward0>) tensor(-4.3658, grad_fn=<MeanBackward0>) tensor(-0.0559, grad_fn=<MeanBackward0>)
tensor(-7.4499, grad_fn=<AddBackward0>) tensor(-7.4419, grad_fn=<MeanBackward0>) tensor(-0.0080, grad_fn=<MeanBackward0>)
tensor(-8.1867, grad_fn=<AddBackward0>) tensor(-7.8148, grad_fn=<MeanBackward0>) tensor(-0.3719, grad_fn=<MeanBackward0>)
tensor(-8.5450, grad_fn=<AddBackward0>) tensor(-7.9149, grad_fn=<MeanBackward0>) tensor(-0.6301, grad_fn=<MeanBackward0>)
tensor(-10.3165, grad_fn=<AddBackward0>) tensor(-9.9695, grad_fn=<MeanBackward0>) tensor(-0.3470, grad_fn=<MeanBackward0>)
tensor(-7.6789, grad_

KeyboardInterrupt: 

### If stuck
compare solutions to own version

In [None]:
# ! wget https://raw.githubusercontent.com/callummcdougall/arena-v1/main/w4d1/solutions.py
# ! wget https://raw.githubusercontent.com/callummcdougall/arena-v1/main/w4d1/w0d2_solutions.py
# ! wget https://raw.githubusercontent.com/callummcdougall/arena-v1/main/w4d1/w0d3_solutions.py


from solutions import netD_celeb_mini
utils.print_param_count(discriminator, netD_celeb_mini)

In [None]:
from solutions import netG_celeb_mini
utils.print_param_count(generator, netG_celeb_mini)

In [None]:
512*4*4

In [None]:
[(self.generator_num_features//(2**i),
                                    self.generator_num_features//2**(i+1), 
                                    self.initial_width*2**i) for i in range(1, self.n_layers)]