# GANs with Python

## Setup

### 1. Go to this link: bit.ly/45svhk0
### 2. Log in to your Google Drive account
### 3. Make sure your files are in the right place

## Importing Libraries

Like in the previous workshop, we're going to import some libraries for the code.

In [None]:
import os
import random
from urllib import request  # This is for downloading
import zipfile  # This is for handling zip files

import cv2
import numpy as np
import torch
from torch.utils.data import Dataset
import matplotlib.pyplot as plt
import torch.nn as nn
from torch.utils.data import DataLoader

import math
from PIL import Image

%matplotlib inline

## A quick intro to booleans

### Data type recap

Earlier we discussed different types of variables you see in your code. The main examples we went through were strings (text data), integers (whole numbers), and floats (numbers with a decimal point).

A quick recap:
```python
my_string = "some text!"
my_float = 39.59382
my_int = 20
```
### Bools and why we're using them

There is another basic data type that we frequently use in Python called booleans (or bools). These can either have the values `True` or `False`. They are used when we want a part of our code to run only when a certain condition has been met.

The code we are using can take several hours to complete, depending on how many images you give it and a few other factors. For this reason it has been written to save its progress, so that it can pick up from where it left off later. This is also useful if the code crashes for whatever reason, as the progress it has made will be preserved.

However, this has the downside of creating progress files in your Google Drive folder that can build up in size quite quickly!

### Bools for our GAN

Before running the code, we will set values for the following bools:
- `do_preprocess` - Determines if we run or skip the code for resizing our image files. This will need to be done at least once.
- `from_checkpoint` - Determines if the GAN starts from scratch, or picks up where it left off by loading a model file. For the first run this will have to be set to `False`.

In [None]:
do_preprocess = True
from_checkpoint = False

Here is a basic example of what bools allow you to do:

In [None]:
my_bool = True

if my_bool:
    print("We will see this printed.")

if my_bool:
    print("This will also be printed")
else:
    print("But not this.")

This code contains `if` statements. `if` is another Python keyword, used for conditional execution: it lets us write code that only runs some of the time. Here it ensures that some text is printed only when `my_bool` is `True`, while the `else` branch runs only when it is `False`.

Now let's see what happens when something is false.

**Exercise**: What will the output be?

In [None]:
my_bool = False

if my_bool:
    print("This will be printed. Or will it...?")

You may remember that the first workshop showed us how to use `type()` to find the type of some data. Now we can use it with a bool to see that it is in fact a bool.

In [None]:
print(type(True), type(False))

### Downloading Some Data

Now that we understand a bit about what booleans do, please set the bool below to True or False depending on whether you've uploaded 500+ image files to Google Drive and have put them in the right folder.

In [None]:
student_uploaded_own_data = False

## Check the GPU

The PyTorch library allows us to perform calculations with a GPU or a CPU. Setting up a GPU on a desktop machine can be tricky, and sometimes a library won't recognise your GPU even when one is present. For this reason PyTorch and other Python libraries that use the GPU usually have a command for checking whether a GPU has been detected (in PyTorch this is `torch.cuda.is_available()`). This tells you whether you can continue coding, or whether you need to do some troubleshooting and figure out why things didn't install properly.

Fortunately we don't need to worry about these things when using Colab. The GPU is already set up for us.

In [None]:
# check CUDA availability and set device
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print("Use GPU: {}".format(str(device) != "cpu"))

Google Colab is sometimes able to tell that a notebook requires a GPU, so you may have been given a T4 GPU runtime automatically. You can check this in the top right. If not, you can choose one under Runtime > Change runtime type.

## Prepare the Folder Names (and Data?)

This code will create quite a bit of output, and so we'll need to create some folders to save all these files:

- One folder for the resized images that we will feed into the GAN
- Another folder for the model files that act as snapshots of how far the GAN has come along
- And another for the computer-generated images

Like in the previous workshop, the `os` library will be used to create our paths. This is for safety, as `os` builds paths correctly whether you are running Windows, Linux, or Mac.
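Here is a quick illustration of what `os.path.join` does. The folder and file names are made up for the example (`cat_001.jpg` in particular is hypothetical).

```python
import os

# os.path.join glues the parts together with the separator used by the
# current operating system: "/" on Linux and Mac, "\" on Windows
example_path = os.path.join("gans-workshop-files", "cats", "cat_001.jpg")
print(example_path)

# os.path.split does the reverse for the last part of the path
folder, filename = os.path.split(example_path)
print(folder)
print(filename)  # cat_001.jpg
```

Because the separator is chosen for you, the same code works whether `project_dir` points at Colab's Drive folder or at `os.getcwd()` on a desktop.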

In [None]:
project_dir = "./drive/MyDrive/"
GANS_WORKSHOP_FOLDER = "gans-workshop-files"

try:
    from google.colab import drive

    drive.mount("/content/drive")

except ModuleNotFoundError:
    project_dir = os.getcwd()

### A bit about `try-except`

The purpose of `try-except` is to give the code a kind of Plan B on what to do when a certain block of code creates an error. Sometimes we are _expecting_ an error in a particular situation, and know what we want the code to do should this happen. This is where a `try-except` block comes into play. It allows us to say "If you run across this problem as the code runs, do this instead."

Here I am mounting the Google Drive folder in a `try-except` block because sometimes I run this code on a desktop. In this scenario the `from google.colab import drive` would lead to a `ModuleNotFoundError`. This is an error you get when you try to import a library that hasn't been installed on your system. When I am running this code on a desktop machine I am not importing Colab and instead getting the input files from my hard drive, and so I tell the code to use `os.getcwd()` as the path to work from. This is the _current working directory_.

Here's an example below of _catching_ an error caused by using the wrong index for a list.

In [None]:
# This list contains three items
my_list = [1, 2, 3]

# This is how we print the items in a list line by line
print(my_list[0])
print(my_list[1])
print(my_list[2])

try:
    # I am going to try and print the 4th item in the list
    print(my_list[3])
    print("Will the code reach this point?")
except IndexError:
    print("You have accessed something outside of the list.")
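The Google Drive cell earlier follows the same Plan B pattern. Below is a stand-alone version that deliberately imports a module that doesn't exist, so the `except` branch runs on any machine. `not_a_real_module_xyz` is a made-up name.

```python
try:
    # This import fails because the module doesn't exist
    import not_a_real_module_xyz  # noqa: F401

    environment = "module found"
except ModuleNotFoundError:
    # Plan B: carry on with a fallback, like using os.getcwd() outside Colab
    environment = "fallback"

print(environment)  # fallback
```

Catching `ModuleNotFoundError` specifically, rather than every possible error, means that other mistakes inside the `try` block still show up as errors instead of being silently hidden.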

## Downloading the Files

Now with that out of the way we can start sorting out the images to give to the GAN. Pick a type of image you want to download if you didn't already prepare some data.

In [None]:
# @title Choose the type of data you wish to download if you don't already have something { display-mode: "form" }

data_folder_name = "cats"  # @param ["cats", "flowers", "abstract-paintings"]

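Next, an `if` statement decides where the data comes from. If you uploaded your own images, change `"pokemon"` to the name of your data folder. Otherwise, the dataset you chose is downloaded and unzipped, unless its folder already exists.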

In [None]:
GANS_WORKSHOP_FOLDER = "gans-workshop-files"

if student_uploaded_own_data:
    # Change this to the name of your data folder
    data_folder_name = "pokemon"
else:
    data_path = os.path.join(project_dir, GANS_WORKSHOP_FOLDER, data_folder_name)
    if not os.path.exists(data_path):
        os.makedirs(data_path, exist_ok=True)
        print(f"Downloading {data_folder_name} dataset...")
        local_filename, _ = request.urlretrieve(
            f"https://github.com/DolicaAkelloEgwel/gans-datasets/raw/master/{data_folder_name}.zip"
        )
        with zipfile.ZipFile(local_filename, "r") as downloaded_dataset:
            downloaded_dataset.extractall(data_path)

We're going to need quite a few folders for all the files that the code creates. These are:
- A folder for our resized data
- A folder for our model/checkpoint files
- A folder for the GAN's output images

For now, we're just going to set the paths for these folders. They may already exist from an earlier run, so some conditional logic is needed to check whether each folder actually needs to be created.

In [None]:
workshop_dir = os.path.join(project_dir, GANS_WORKSHOP_FOLDER)
data_dir = os.path.join(workshop_dir, data_folder_name)
data_resized_dir = os.path.join(workshop_dir, f"{data_folder_name}-resized")
models_folder = os.path.join(workshop_dir, f"{data_folder_name}-models")
image_folder = os.path.join(workshop_dir, f"{data_folder_name}-gans-images")
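Here is a sketch of the kind of conditional folder creation described above. `make_missing_folders` is a hypothetical helper that isn't used elsewhere in the notebook, and the folders are created inside a temporary directory so the example doesn't touch your Drive.

```python
import os
import tempfile


def make_missing_folders(folder_paths):
    """Create each folder only if it doesn't exist yet; return the ones created."""
    created = []
    for path in folder_paths:
        if not os.path.isdir(path):
            os.makedirs(path)
            created.append(path)
    return created


# Use a throwaway temporary directory instead of the real workshop folder
with tempfile.TemporaryDirectory() as demo_dir:
    folders = [
        os.path.join(demo_dir, name)
        for name in ("cats-resized", "cats-models", "cats-gans-images")
    ]
    first_run = make_missing_folders(folders)  # creates all three
    second_run = make_missing_folders(folders)  # they exist now, so nothing to do

print(len(first_run), len(second_run))  # 3 0
```

`os.makedirs(path, exist_ok=True)`, which the download cell uses, performs the same check in a single call.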

## Preprocess the Files

The preprocessing step crops each image to a square and resizes it to 128x128 pixels, then places the results in a new folder. Like the data folder earlier, the new folder is only created if it doesn't already exist, this time with `os.mkdir`.

To do this, we loop through all the files in our data directory using `os.listdir`. This lists every file it can find in a folder, and `image_filename` holds the name of one file on each pass through the loop.

Each file's full path (built with `os.path.join`) is passed to `cv2.imread`, which loads the image into a NumPy array. `crop_image_in_centre` then crops it to a square, `cv2.resize` shrinks it to 128x128, and `cv2.imwrite` saves the result to `data_resized_dir`.

In [None]:
def crop_image_in_centre(image: np.ndarray) -> np.ndarray:
    """Crops an image in the centre to make it square.

    Args:
        image (np.ndarray): The image data to reshape.

    Returns:
        np.ndarray: The cropped image.
    """
    height, width = image.shape[:2]
    # Only do the cropping if the image isn't square
    if height != width:
        min_side = min(height, width)
        # Offsets that centre the square; adding min_side keeps both sides equal
        top = (height - min_side) // 2
        left = (width - min_side) // 2
        image = image[top : top + min_side, left : left + min_side, :]
    return image
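To see the cropping arithmetic on concrete numbers, here is a sketch that computes the crop box for a dummy 100x160 array. `centre_square_box` is a hypothetical helper written for this example. It measures the bottom and right edges from `top` and `left`, which keeps the crop square even when the difference between height and width is odd.

```python
import numpy as np


def centre_square_box(height, width):
    """Return (top, bottom, left, right) of the largest centred square."""
    side = min(height, width)
    top = (height - side) // 2
    left = (width - side) // 2
    # Measuring bottom/right from top/left keeps both sides exactly `side` long
    return top, top + side, left, left + side


# A dummy 100x160 "image" with 3 colour channels
demo_image = np.zeros((100, 160, 3), dtype=np.uint8)
top, bottom, left, right = centre_square_box(*demo_image.shape[:2])
cropped = demo_image[top:bottom, left:right, :]
print(top, bottom, left, right)  # 0 100 30 130
print(cropped.shape)  # (100, 100, 3)

# An odd difference (161 - 100 = 61) still gives a 100x100 crop
print(centre_square_box(100, 161))  # (0, 100, 30, 130)
```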

### Loading the Image Data

In [None]:
if do_preprocess:
    # Make a folder for the resized images if one doesn't already exist
    if not os.path.isdir(data_resized_dir):
        os.mkdir(data_resized_dir)

    # Go through each of our input images, resize them, and then save them to the new folder
    for image_filename in os.listdir(data_dir):
        try:
            image = cv2.imread(os.path.join(data_dir, image_filename))
            image = crop_image_in_centre(image)
            image = cv2.resize(image, (128, 128))
            cv2.imwrite(os.path.join(data_resized_dir, image_filename), image)
        except Exception as e:
            print(str(e))

Sometimes stray files such as `.DS_Store` make their way into our folders. Our `os.listdir` loop simply finds every file in a folder and can't tell which ones are images. When `cv2.imread` can't read a file as an image, it doesn't raise an error; it returns `None` instead. The error actually comes from the next step, when `crop_image_in_centre` tries to read the shape of `None`. The `try-except` catches that error, prints it, and lets the loop carry on with the next file. This is fine: if a file can't be read as an image, it most likely wasn't one, so we can ignore it.
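An alternative (or extra) safeguard is to skip files by their extension before trying to load them. This is a sketch: `is_image_file` is a hypothetical helper, and the set of extensions it accepts is an assumption that you may need to extend for your own data.

```python
import os

# Extensions treated as images; an assumption, extend as needed
IMAGE_EXTENSIONS = {".jpg", ".jpeg", ".png", ".bmp"}


def is_image_file(filename):
    """Guess from the extension (case-insensitive) whether a file is an image."""
    return os.path.splitext(filename)[1].lower() in IMAGE_EXTENSIONS


filenames = [".DS_Store", "cat_001.jpg", "cat_002.PNG", "notes.txt"]
image_filenames = [name for name in filenames if is_image_file(name)]
print(image_filenames)  # ['cat_001.jpg', 'cat_002.PNG']
```

`os.path.splitext` treats a leading dot as part of the name, so `.DS_Store` has no extension at all and is skipped.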

## Setup Helper Functions

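The cells below define helper functions for loading images into a batch (`get_image` and `get_batch`) and for arranging several images in a square grid (`images_square_grid`). We'll use them to check that the preprocessing worked.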

In [None]:
def get_image(image_path, mode):
    """Loads a numpy image.

    Args:
        image_path (str): The path for the image to load.
        mode (str): The mode to give when converting the image.

    Returns:
        np.ndarray: The image in the form of a numpy array.
    """
    return np.array(Image.open(image_path).convert(mode))


def get_batch(image_files, mode):
    """Creates a batch of images.

    Args:
        image_files (list): A list of several image files.
        mode (str): The mode to use when converting the images.

    Returns:
        np.ndarray: An array of a batch of images.
    """
    data_batch = np.array(
        [get_image(sample_file, mode) for sample_file in image_files]
    ).astype(np.float32)

    # Make sure the images are in 4 dimensions
    if len(data_batch.shape) < 4:
        data_batch = data_batch.reshape(data_batch.shape + (1,))

    return data_batch
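The reshape at the end of `get_batch` matters for greyscale (`"L"` mode) images, which load without a channel axis. Here is a small NumPy sketch with dummy arrays that shows the shapes involved:

```python
import numpy as np

# 5 greyscale 128x128 images: (batch, height, width), no channel axis
grey_batch = np.zeros((5, 128, 128), dtype=np.float32)
# The same reshape as get_batch: append a channel axis of length 1
grey_batch_4d = grey_batch.reshape(grey_batch.shape + (1,))

# 5 RGB images already have 4 dimensions, so get_batch leaves them alone
rgb_batch = np.zeros((5, 128, 128, 3), dtype=np.float32)

print(grey_batch_4d.shape)  # (5, 128, 128, 1)
print(rgb_batch.shape)  # (5, 128, 128, 3)
```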

In [None]:
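def images_square_grid(images: np.ndarray, mode: str) -> Image.Image:
    """Arranges a batch of images in a square grid.

    Args:
        images (np.ndarray): A batch of images with shape (N, height, width, channels).
            Only the first floor(sqrt(N))**2 images are used, e.g. 9 out of 10.
        mode (str): The PIL mode of the images, e.g. "RGB" or "L".

    Returns:
        PIL.Image.Image: A single image containing the grid.
    """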
    # Get maximum size for square grid of images
    save_size = math.floor(np.sqrt(images.shape[0]))

    # Scale to 0-255
    images = (((images - images.min()) * 255) / (images.max() - images.min())).astype(
        np.uint8
    )

    # Put images in a square arrangement
    images_in_square = np.reshape(
        images[: save_size * save_size],
        (save_size, save_size, images.shape[1], images.shape[2], images.shape[3]),
    )
    if mode == "L":
        images_in_square = np.squeeze(images_in_square, 4)

    # Combine images to grid image
    new_im = Image.new(mode, (images.shape[1] * save_size, images.shape[2] * save_size))
    for col_i, col_images in enumerate(images_in_square):
        for image_i, image in enumerate(col_images):
            im = Image.fromarray(image, mode)
            new_im.paste(im, (col_i * images.shape[1], image_i * images.shape[2]))

    return new_im
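Here are two bits of arithmetic from `images_square_grid`, shown on made-up numbers. The first is the min-max scaling that maps the smallest value to 0 and the largest to 255. The second is the `floor(sqrt(n))` grid size.

```python
import math

import numpy as np

# Made-up values in the generator's output range [-1, 1]
values = np.array([-1.0, -0.5, 0.0, 1.0])
# The same min-max scaling as images_square_grid: min -> 0, max -> 255
scaled = (((values - values.min()) * 255) / (values.max() - values.min())).astype(np.uint8)
print(scaled)  # [  0  63 127 255]

# With 10 images, the largest square grid is 3x3, so only 9 are shown
n_images = 10
save_size = math.floor(np.sqrt(n_images))
print(save_size, save_size * save_size)  # 3 9
```

`astype(np.uint8)` truncates, so 63.75 becomes 63 and 127.5 becomes 127.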

In [None]:
def _preprocess_image(img: np.ndarray) -> np.ndarray:
    """Pre-process image."""
    # normalize
    img = img / 128.0  # between 0 and 2
    img -= 1.0  # between -1 and 1
    # transpose
    img = img.transpose((2, 0, 1))
    return img

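This creates a `Dataset` object for PyTorch to use. It subclasses PyTorch's `Dataset`: `__len__` reports how many images there are, and `__getitem__` loads one image, mirrors it at random, resizes it and normalises it into a tensor. A `DataLoader` uses these two methods to assemble shuffled batches for training. Note that the new class reuses the name `Dataset`, so after this cell `Dataset` refers to the new class.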

In [None]:
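class Dataset(Dataset):
    """Dataset of square images, resized to 64x64 pixels when loaded.

    The Generator and DCGANDiscriminator below work with 64x64 images, so the
    128x128 preprocessed images are downsampled to that size here.
    """

    def __init__(self, img_paths, mirror=True):
        self.img_paths = img_paths
        self.size = 64  # must match the Generator's 64x64 output
        self.mirror = mirror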

    def __len__(self):
        return len(self.img_paths)

    def __getitem__(self, idx):
        img_path = self.img_paths[idx]
        img = cv2.imread(img_path)

        # mirror img with a 50% chance
        if self.mirror and random.random() > 0.5:
            img = cv2.flip(img, 1)  # flipCode 1 mirrors left-right
        # resize
        img = cv2.resize(img, (self.size, self.size))

        # normalize
        img = _preprocess_image(img)

        return torch.tensor(img.astype(np.float32))

This code below will display 9 of our resized images by using the `images_square_grid` function that was defined earlier. This lets us know that the preprocessing worked.

In [None]:
# Create a list of the files in our resized images folder
resized_data_filenames = [
    os.path.join(data_resized_dir, resized_image_filename)
    for resized_image_filename in os.listdir(data_resized_dir)
]
show_n_images = 9
# Get a batch of 9 of the resized images
train_images = get_batch(resized_data_filenames[:show_n_images], "RGB")
# Create a grid from the resized images and then plot them
plt.imshow(images_square_grid(train_images, "RGB"))

## DCGAN

In [None]:
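def postprocess_img(img):
    """Transform a network output in [-1, 1] into a displayable uint8 image."""
    img = img.transpose((1, 2, 0))  # (C, H, W) -> (H, W, C)
    img = img + 1.0  # between 0 and 2; not in place, so the caller's array is unchanged
    # Scale to 0-256 and clip, since 2.0 * 128 = 256 doesn't fit in a uint8
    img = np.clip(img * 128.0, 0, 255).astype(np.uint8)
    return img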
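Here is a sketch of the round trip between pixel values and the network's range. It applies the same arithmetic as `_preprocess_image` and `postprocess_img` (minus the transpose) to every possible 8-bit value. The `np.clip` call guards against a tanh output of exactly 1.0, which would otherwise become 256 and fall outside the `uint8` range.

```python
import numpy as np

# Every possible 8-bit pixel value, as a tiny 16x16 single-channel image
pixels = np.arange(256, dtype=np.float64).reshape(16, 16, 1)

# Same arithmetic as _preprocess_image: [0, 255] -> [-1, 0.9921875]
network_range = pixels / 128.0 - 1.0
print(network_range.min(), network_range.max())  # -1.0 0.9921875

# Same arithmetic as postprocess_img, with a clip to stay inside 0-255
restored = np.clip((network_range + 1.0) * 128.0, 0, 255).astype(np.uint8)
print(np.array_equal(restored, pixels.astype(np.uint8)))  # True

# A tanh output of exactly 1.0 maps to 256 before clipping
print((1.0 + 1.0) * 128.0)  # 256.0
```

Dividing by 128 (a power of two) is exact in floating point, which is why every pixel value comes back unchanged.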

In [None]:
def save_image_grid(img_batch, grid_size, epoch, img_path):
    """Create a grid-like visualization from a batch of images and save it."""
    if grid_size**2 != img_batch.shape[0]:
        print(
            "grid_size**2 and batch size not equal: {} {}. Skipping".format(
                grid_size**2, img_batch.shape[0]
            )
        )
        return None

    # create black canvas
    img_size = img_batch.shape[2]
    canvas = np.zeros(
        (grid_size * (img_size + 2) - 2 + 28, grid_size * (img_size + 2) - 2, 3),
        dtype=np.uint8,
    )

    # add the epoch number to the bottom
    text_size = cv2.getTextSize(
        "Epoch {}".format(epoch), cv2.FONT_HERSHEY_SIMPLEX, 0.6, 1
    )  # get text size
    # calculate text position
    text_left = (canvas.shape[1] - text_size[0][0]) // 2
    text_bottom = canvas.shape[0] - (28 - text_size[0][1]) // 2
    # add text
    cv2.putText(
        canvas,
        "Epoch {}".format(epoch),
        (text_left, text_bottom),
        cv2.FONT_HERSHEY_SIMPLEX,
        0.6,
        (255, 255, 255),
        1,
    )

    for img_idx, img in enumerate(img_batch):
        col = math.floor(img_idx / grid_size)
        row = img_idx - col * grid_size
        canvas[
            col * (img_size + 2) : col * (img_size + 2) + img_size,
            row * (img_size + 2) : row * (img_size + 2) + img_size,
            :,
        ] = postprocess_img(img)

    cv2.imwrite(img_path, canvas)

In [None]:
def weights_init(m):
    """Initialize the weights of a module randomly, using Gaussian distribution."""
    classname = m.__class__.__name__
    if classname.find("Conv2d") != -1 or classname.find("ConvTranspose2d") != -1:
        torch.nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find("BatchNorm") != -1:
        torch.nn.init.normal_(m.weight.data, 1.0, 0.02)
        torch.nn.init.constant_(m.bias.data, 0.0)

In [None]:
class ConvBlock(nn.Module):
    def __init__(
        self,
        block_name,
        in_size,
        out_size,
        normalize=True,
        kernel_size=4,
        stride=2,
        padding=1,
        bias=False,
        activation_fn=nn.LeakyReLU(0.2),
    ):
        super(ConvBlock, self).__init__()

        self.model = nn.Sequential()
        self.model.add_module(
            block_name + "_conv2d",
            nn.Conv2d(
                in_size,
                out_size,
                kernel_size=kernel_size,
                stride=stride,
                padding=padding,
                bias=bias,
            ),
        )

        if normalize:
            self.model.add_module(block_name + "_norm", nn.BatchNorm2d(out_size))

        if activation_fn is not None:
            self.model.add_module(block_name + "_activation_fn", activation_fn)

    def forward(self, x):
        x = self.model(x)
        return x

## Creating a Generator

In PyTorch, each network is a subclass of `nn.Module`. Layers assigned as attributes in `__init__` (such as `self.upconv1`) are registered automatically, so `generator.parameters()` returns only the generator's weights and `discriminator.parameters()` returns only the discriminator's. This is how each optimizer in `init_training` is given the weights of just one network.

Because a network is an ordinary Python object, it can also be called on different inputs while keeping the same weights:
- Generator: The generator will be trained, but we're also going to be sampling from it (generating our fake data) during the training.
- Discriminator: The same discriminator, with the same weights, scores both the real and the fake input images.

In [None]:
class TransConvBlock(nn.Module):
    def __init__(
        self,
        block_name,
        in_size,
        out_size,
        normalize=True,
        kernel_size=4,
        stride=2,
        padding=1,
        bias=False,
        activation_fn=nn.ReLU(),
    ):
        super(TransConvBlock, self).__init__()

        self.model = nn.Sequential()
        self.model.add_module(
            block_name + "_trans_conv",
            nn.ConvTranspose2d(
                in_size,
                out_size,
                kernel_size=kernel_size,
                stride=stride,
                padding=padding,
                bias=bias,
            ),
        )

        if normalize:
            self.model.add_module(block_name + "_norm", nn.BatchNorm2d(out_size))

        if activation_fn is not None:
            self.model.add_module(block_name + "_activation_fn", activation_fn)

    def forward(self, x):
        x = self.model(x)
        return x

In [None]:
class Generator(nn.Module):
    """Generator architecture."""

    def __init__(self):
        super(Generator, self).__init__()
        self.upconv1 = TransConvBlock(
            "gen_start_block",
            100,
            1024,
            normalize=True,
            kernel_size=4,
            stride=1,
            padding=0,
        )
        self.upconv2 = TransConvBlock(
            "gen_mid_block1",
            1024,
            512,
            normalize=True,
            kernel_size=4,
            stride=2,
            padding=1,
        )
        self.upconv3 = TransConvBlock(
            "gen_mid_block2",
            512,
            256,
            normalize=True,
            kernel_size=4,
            stride=2,
            padding=1,
        )
        self.upconv4 = TransConvBlock(
            "gen_mid_block3",
            256,
            128,
            normalize=True,
            kernel_size=4,
            stride=2,
            padding=1,
        )
        self.upconv5 = TransConvBlock(
            "gen_end_block",
            128,
            3,
            normalize=False,
            kernel_size=4,
            stride=2,
            padding=1,
            activation_fn=nn.Tanh(),
        )

    def forward(self, x):
        x = self.upconv1(x)
        x = self.upconv2(x)
        x = self.upconv3(x)
        x = self.upconv4(x)
        x = self.upconv5(x)
        return x
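How big are the generator's images? For `nn.ConvTranspose2d` with the default `dilation=1` and `output_padding=0`, the output size given in the PyTorch documentation reduces to `(in - 1) * stride - 2 * padding + kernel_size`. This pure-Python sketch traces the 1x1 latent input through the five blocks above.

```python
def trans_conv_out(size, kernel_size, stride, padding):
    """Output height/width of nn.ConvTranspose2d (dilation=1, output_padding=0)."""
    return (size - 1) * stride - 2 * padding + kernel_size


# (kernel_size, stride, padding) of upconv1 ... upconv5 in Generator
generator_layers = [(4, 1, 0), (4, 2, 1), (4, 2, 1), (4, 2, 1), (4, 2, 1)]

size = 1  # the latent vector z has shape (batch, 100, 1, 1)
sizes = [size]
for kernel_size, stride, padding in generator_layers:
    size = trans_conv_out(size, kernel_size, stride, padding)
    sizes.append(size)

print(sizes)  # [1, 4, 8, 16, 32, 64]
```

So the generator produces 64x64 images, and the discriminator has to accept images of that size.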

## Creating a Discriminator

In [None]:
class DCGANDiscriminator(nn.Module):
    """DCGAN Discriminator architecture."""

    def __init__(self):
        super(DCGANDiscriminator, self).__init__()

        self.conv1 = ConvBlock(
            "disc_start_block",
            3,
            128,
            normalize=False,
            kernel_size=4,
            stride=2,
            padding=1,
        )
        self.conv2 = ConvBlock(
            "disc_mid_block1",
            128,
            256,
            normalize=True,
            kernel_size=4,
            stride=2,
            padding=1,
        )
        self.conv3 = ConvBlock(
            "disc_mid_block2",
            256,
            512,
            normalize=True,
            kernel_size=4,
            stride=2,
            padding=1,
        )
        self.conv4 = ConvBlock(
            "disc_mid_block3",
            512,
            1024,
            normalize=True,
            kernel_size=4,
            stride=2,
            padding=1,
        )
        self.conv5 = ConvBlock(
            "disc_end_block",
            1024,
            1,
            normalize=False,
            kernel_size=4,
            stride=1,
            padding=0,
            activation_fn=nn.Sigmoid(),
        )

    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        x = self.conv4(x)
        x = self.conv5(x)
        x = x.view(-1)
        return x
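The matching formula for `nn.Conv2d` (with `dilation=1`) is `floor((in + 2 * padding - kernel_size) / stride) + 1`. Tracing the discriminator with this pure-Python sketch shows why the input image size matters.

```python
def conv_out(size, kernel_size, stride, padding):
    """Output height/width of nn.Conv2d (dilation=1)."""
    return (size + 2 * padding - kernel_size) // stride + 1


# (kernel_size, stride, padding) of conv1 ... conv5 in DCGANDiscriminator
discriminator_layers = [(4, 2, 1), (4, 2, 1), (4, 2, 1), (4, 2, 1), (4, 1, 0)]


def trace_sizes(size):
    """List the feature-map size after each discriminator block."""
    sizes = [size]
    for kernel_size, stride, padding in discriminator_layers:
        size = conv_out(size, kernel_size, stride, padding)
        sizes.append(size)
    return sizes


print(trace_sizes(64))  # [64, 32, 16, 8, 4, 1]
print(trace_sizes(128))  # [128, 64, 32, 16, 8, 5]
```

A 64x64 input ends at 1x1, so `x.view(-1)` gives one score per image. This matches the `valid` and `fake` label vectors of length `batch_size` used in training. A 128x128 input would end at 5x5, giving 25 scores per image, and `BCELoss` would reject the mismatched shapes with an error.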

In [None]:
dataset = Dataset(resized_data_filenames)

In [None]:
def init_training(train_config):
    """Initialize networks, optimizers and the data pipeline."""
    # create dataset
    dataset = Dataset(train_config["data_path"])
    print("num of images:", len(dataset))

    # define data loader
    data_loader = DataLoader(
        dataset,
        batch_size=train_config["batch_size"],
        shuffle=True,
        num_workers=train_config["num_workers"],
    )

    # init networks
    generator = Generator()  # the generator is the same for both the DCGAN and the WGAN
    generator.apply(weights_init)
    generator = generator.to(device)

    discriminator = DCGANDiscriminator()
    discriminator.apply(weights_init)
    discriminator = discriminator.to(device)

    # Optimizers
    optimizers = {
        "gen": torch.optim.Adam(
            generator.parameters(),
            lr=train_config["learning_rate_g"],
            betas=(train_config["b1"], train_config["b2"]),
        ),
        "disc": torch.optim.Adam(
            discriminator.parameters(),
            lr=train_config["learning_rate_d"],
            betas=(train_config["b1"], train_config["b2"]),
        ),
    }

    # make save dir, if needed
    os.makedirs(os.path.join(train_config["checkpoint_path"], "weights"), exist_ok=True)
    os.makedirs(os.path.join(train_config["checkpoint_path"], "samples"), exist_ok=True)

    # load weights if the training is not starting from the beginning
    if train_config["start_epoch"] > 1:
        gen_path = os.path.join(
            train_config["checkpoint_path"],
            "weights",
            "checkpoint_ep{}_gen.pt".format(train_config["start_epoch"] - 1),
        )
        disc_path = os.path.join(
            train_config["checkpoint_path"],
            "weights",
            "checkpoint_ep{}_disc.pt".format(train_config["start_epoch"] - 1),
        )
        generator.load_state_dict(torch.load(gen_path, map_location=device))
        discriminator.load_state_dict(torch.load(disc_path, map_location=device))

    return device, data_loader, train_config, generator, discriminator, optimizers

## Compute the Loss

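The loss measures how well each network is doing at its own task. In `training_step_dcgan` below, the discriminator is trained to output 1 for real images and 0 for generated ones. The generator is then trained to make the discriminator output 1 for its fakes. Both losses use binary cross-entropy (`torch.nn.BCELoss`).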
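Here is a worked example of these two losses, using made-up discriminator outputs. `bce` is a hypothetical stand-in for `torch.nn.BCELoss` applied to a single image; the real loss averages this value over the batch.

```python
import math


def bce(prediction, target):
    """Per-element binary cross-entropy, the formula torch.nn.BCELoss averages."""
    return -(target * math.log(prediction) + (1 - target) * math.log(1 - prediction))


# Hypothetical discriminator outputs (its probability that an image is real)
d_real = 0.9  # on a real image
d_fake = 0.2  # on a generated image

# Discriminator: real images should score 1, fakes should score 0
d_loss = bce(d_real, 1.0) + bce(d_fake, 0.0)
# Generator: it wants the discriminator to score its fakes as 1
g_loss = bce(d_fake, 1.0)

print(round(d_loss, 4), round(g_loss, 4))  # 0.3285 1.6094
```

As the discriminator gets better at spotting fakes, `d_fake` drops towards 0. This lowers `d_loss` but raises `g_loss`, and this tug-of-war is why the two losses rarely fall steadily together.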

In [None]:
def training_step_dcgan(
    batch, device, generator, discriminator, optimizers, train_config, loss_fn
):
    """Run the DCGAN training steps."""
    imgs = batch.to(device)

    # Sample noise as generator input
    z = torch.randn(imgs.size()[0], train_config["latent_dim"], 1, 1).to(device)

    valid = torch.ones(imgs.size(0)).to(device)
    fake = torch.zeros(imgs.size(0)).to(device)

    # -------------------
    # Train Discriminator
    # -------------------
    optimizers["disc"].zero_grad()
    # Sample real
    real_loss = loss_fn(discriminator(imgs), valid)
    # Sample fake
    gen_imgs = generator(z)  # generate fakes
    fake_loss = loss_fn(discriminator(gen_imgs.detach()), fake)
    # Backprop.
    d_loss = real_loss + fake_loss
    d_loss.backward()
    optimizers["disc"].step()

    # ---------------
    # Train generator
    # ---------------
    optimizers["gen"].zero_grad()
    g_loss = loss_fn(discriminator(gen_imgs), valid)
    # Backprop.
    g_loss.backward()
    optimizers["gen"].step()

    return g_loss, d_loss


def run_training(args):
    """Initialize and run the full training process using the hyper-params in args."""
    (
        device,
        data_loader,
        train_config,
        generator,
        discriminator,
        optimizers,
    ) = init_training(args)

    # generate a sample with fixed seed, and reset the seed to pseudo-random
    torch.manual_seed(42)
    z_sample = torch.randn(
        train_config["batch_size"], train_config["latent_dim"], 1, 1
    ).to(device)
    torch.manual_seed(random.randint(0, 2**32 - 1))

    # Binary cross-entropy loss for the DCGAN
    loss_fn = torch.nn.BCELoss().to(device)

    # Training
    print("Training:")
    for epoch in range(train_config["start_epoch"], train_config["max_epoch"] + 1):
        for batch_idx, batch in enumerate(data_loader):
            g_loss, d_loss = training_step_dcgan(
                batch,
                device,
                generator,
                discriminator,
                optimizers,
                train_config,
                loss_fn,
            )

        print(
            "\nEpoch {}/{}:\n"
            "  Discriminator loss={:.4f}\n"
            "  Generator loss={:.4f}".format(
                epoch, train_config["max_epoch"], d_loss.item(), g_loss.item()
            )
        )

        if epoch == 1 or epoch % train_config["sample_save_freq"] == 0:
            # Save a grid of samples generated from the fixed noise z_sample
            with torch.no_grad():
                gen_sample = generator(z_sample)
            save_image_grid(
                img_batch=gen_sample[: train_config["grid_size"] ** 2].cpu().numpy(),
                grid_size=train_config["grid_size"],
                epoch=epoch,
                img_path=os.path.join(
                    train_config["checkpoint_path"],
                    "samples",
                    "checkpoint_ep{}_sample.png".format(epoch),
                ),
            )
            print("Image sample saved.")

        if epoch == 1 or epoch % train_config["save_freq"] == 0:
            # Save checkpoint
            gen_path = os.path.join(
                train_config["checkpoint_path"],
                "weights",
                "checkpoint_ep{}_gen.pt".format(epoch),
            )
            disc_path = os.path.join(
                train_config["checkpoint_path"],
                "weights",
                "checkpoint_ep{}_disc.pt".format(epoch),
            )
            torch.save(generator.state_dict(), gen_path)
            torch.save(discriminator.state_dict(), disc_path)
            print("Checkpoint saved.")
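When `from_checkpoint` is `True`, training has to restart from the epoch after the last saved checkpoint. `init_training` loads `checkpoint_ep{start_epoch - 1}_gen.pt` and `checkpoint_ep{start_epoch - 1}_disc.pt` whenever `start_epoch > 1`. The notebook doesn't show how to find that epoch, so here is one possible sketch. `latest_checkpoint_epoch` is a hypothetical helper that scans a `weights` folder for generator checkpoint files.

```python
import os
import re
import tempfile

# Matches the generator file names written by run_training
CHECKPOINT_PATTERN = re.compile(r"checkpoint_ep(\d+)_gen\.pt$")


def latest_checkpoint_epoch(weights_dir):
    """Return the largest N among files named checkpoint_ep{N}_gen.pt, or 0 if none."""
    epochs = []
    for name in os.listdir(weights_dir):
        match = CHECKPOINT_PATTERN.match(name)
        if match:
            epochs.append(int(match.group(1)))
    # Compare as integers, so epoch 10 counts as later than epoch 9
    return max(epochs, default=0)


# Demo with empty placeholder files in a temporary folder
with tempfile.TemporaryDirectory() as weights_dir:
    for epoch in (1, 9, 10):
        for part in ("gen", "disc"):
            path = os.path.join(weights_dir, f"checkpoint_ep{epoch}_{part}.pt")
            open(path, "w").close()
    last_epoch = latest_checkpoint_epoch(weights_dir)

start_epoch = last_epoch + 1
print(last_epoch, start_epoch)  # 10 11
```

In the notebook you would point it at the `weights` folder inside whichever folder you use as `checkpoint_path`. It only looks at generator files, so it assumes each `_gen.pt` file has a matching `_disc.pt` file, which is how `run_training` saves them.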

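After the first epoch, and then every `sample_save_freq` epochs, `run_training` saves a grid of images to the `samples` folder inside the checkpoint folder. These are generated from the same fixed noise (`z_sample`) each time, so you can watch the fake images improve while the GAN is running.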

## Define the Training Function

## Set the Parameters

In [None]:
# Size of the preprocessed images (height, width, channels)
real_size = (128, 128, 3)

# Size of latent vector to generator
z_dim = 100
learning_rate_D = 0.000005  # Thanks to Alexia Jolicoeur Martineau https://ajolicoeur.wordpress.com/cats/
learning_rate_G = 0.00002  # Thanks to Alexia Jolicoeur Martineau https://ajolicoeur.wordpress.com/cats/
batch_size = 32
epochs = 100
alpha = 0.2  # slope of the LeakyReLU activations (ConvBlock uses 0.2 by default)
beta1 = 0.5  # Adam's first momentum term (b1)

In [None]:
# Load the data and train the network here
dataset = Dataset(resized_data_filenames)

In [None]:
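print(len(dataset))  # how many images the Dataset can load
print(dataset[0].shape)  # one preprocessed image tensor: (channels, height, width)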

## Train the Model

In [None]:
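# Settings from "Set the Parameters", plus the data and checkpoint folders
train_config = dict(batch_size=batch_size, latent_dim=z_dim, max_epoch=epochs, b1=beta1)
train_config.update(learning_rate_d=learning_rate_D, learning_rate_g=learning_rate_G)
train_config.update(data_path=resized_data_filenames, checkpoint_path=models_folder)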
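`init_training` and `run_training` look up all of their settings in `train_config` by key. Here is a sketch of a helper that fills in any key you haven't set with a default, and fails early if something essential is missing. `complete_train_config` and every value in `ASSUMED_DEFAULTS` are assumptions for illustration, not settings from the workshop. The learning rates, `batch_size`, `b1`, `latent_dim` and `max_epoch` copy the values in "Set the Parameters". `data_path` and `checkpoint_path` have no default and must always be supplied.

```python
# Keys looked up in train_config by init_training and run_training
REQUIRED_KEYS = {
    "data_path", "checkpoint_path", "start_epoch", "max_epoch",
    "batch_size", "num_workers", "latent_dim", "learning_rate_g",
    "learning_rate_d", "b1", "b2", "grid_size", "sample_save_freq", "save_freq",
}

# Assumed defaults for illustration; change them to suit your run
ASSUMED_DEFAULTS = {
    "start_epoch": 1,  # 1 = start from scratch
    "max_epoch": 100,
    "batch_size": 32,
    "num_workers": 2,
    "latent_dim": 100,  # the Generator's first layer expects 100 input channels
    "learning_rate_g": 0.00002,
    "learning_rate_d": 0.000005,
    "b1": 0.5,
    "b2": 0.999,  # PyTorch's default second Adam beta
    "grid_size": 4,  # 4x4 = 16 samples per saved grid
    "sample_save_freq": 5,
    "save_freq": 10,
}


def complete_train_config(config):
    """Return a copy of config with missing keys filled from ASSUMED_DEFAULTS."""
    completed = {**ASSUMED_DEFAULTS, **config}
    missing = REQUIRED_KEYS - set(completed)
    if missing:
        raise KeyError(f"train_config is missing {sorted(missing)}")
    # run_training saves the first grid_size**2 images of one batch
    if completed["grid_size"] ** 2 > completed["batch_size"]:
        raise ValueError("grid_size ** 2 must not be larger than batch_size")
    return completed


example = complete_train_config(
    {"data_path": ["cat_001.jpg"], "checkpoint_path": "cats-models", "batch_size": 64}
)
print(example["batch_size"], example["b2"])  # 64 0.999
```

In the notebook, something like `train_config = complete_train_config(train_config)` followed by `run_training(train_config)` would then start training.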

## Improving the Generated Images

- Give it a larger dataset (~10K images)
- Run for a larger number of epochs

## Learning More About GANs

- Two books about GANs are on their way to the library. 
- If you find something interesting, maybe send it to me and I can ask for it to be ordered too.

## Interesting ML / GANs Tools

- 