# Create a generative adversarial network: build and train a GAN that generates handwritten digits (0-9)

### Learning Objectives
1. Build the generator and discriminator components of a GAN from scratch.
2. Create generator and discriminator loss functions.
3. Train your GAN and visualize the generated images.

### Getting Started
Import the packages you will need and the MNIST dataset you will use to build and train your GAN. The `show_tensor_images` helper function lets you inspect the images the GAN generates.

```python
import torch
from torch import nn
from tqdm.auto import tqdm
from torchvision import transforms
from torchvision.datasets import MNIST  # Training dataset
from torchvision.utils import make_grid
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
torch.manual_seed(0)  # Fixed seed for reproducible results; do not change

def show_tensor_images(image_tensor, num_images=25, size=(1, 28, 28)):
    '''
    Function for visualizing images: given a tensor of images, the number of images, and the size per image,
    plots the images in a uniform grid.
    '''
    # Reshape flat vectors back into (N, C, H, W) images
    image_unflat = image_tensor.detach().cpu().view(-1, *size)
    # Tile the first num_images images into a grid with 5 images per row
    image_grid = make_grid(image_unflat[:num_images], nrow=5)
    # matplotlib expects channels-last (H, W, C) layout
    plt.imshow(image_grid.permute(1, 2, 0).squeeze())
    plt.show()
```


### Generator
Create a function that builds a single layer (block) of the generator's neural network. Each block should include a linear transformation that maps the input to another dimension, batch normalization for stabilization, and finally a nonlinear activation function (ReLU) so the output can be transformed in complex ways.
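For an input vector $x$, each block computes the following, where $W$ and $b$ are the linear layer's weights and bias, $\mu_j$ and $\sigma_j^2$ are the mean and variance of feature $j$ over the current batch, $\gamma_j$ and $\beta_j$ are learnable batch-norm parameters, and $\epsilon$ is a small constant for numerical stability. This describes batch normalization in training mode; at evaluation time, running estimates replace the batch statistics.

$$
\begin{aligned}
z &= W x + b \\
\hat{z}_j &= \frac{z_j - \mu_j}{\sqrt{\sigma_j^2 + \epsilon}} \\
y_j &= \gamma_j \hat{z}_j + \beta_j \\
h_j &= \max(0,\, y_j)
\end{aligned}
$$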

```python
# Define the generator block function
def get_generator_block(input_dim, output_dim):
    '''
    Function for returning a block of the generator's neural network given input and output dimensions.
    Parameters:
        input_dim: the dimension of the input vector, a scalar
        output_dim: the dimension of the output vector, a scalar
    Returns:
        a generator neural network layer, with a linear transformation
        followed by a batch normalization and then a ReLU activation
    '''
    # PyTorch layer reference: https://pytorch.org/docs/stable/nn.html
    return nn.Sequential(
        nn.Linear(input_dim, output_dim),  # map input_dim features to output_dim features
        nn.BatchNorm1d(output_dim),        # normalize each of the output_dim features over the batch
        nn.ReLU(inplace=True)              # nonlinearity, applied in place to save memory
    )
```
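As a minimal sketch of how these blocks might be combined, the code below stacks four of them into a full generator that maps noise vectors to flattened 28×28 images. The sizes (`z_dim=10`, `hidden_dim=128`), the doubling of the hidden width, and the final `Linear` + `Sigmoid` output layer are illustrative assumptions, not requirements of this exercise.

```python
import torch
from torch import nn

def get_generator_block(input_dim, output_dim):
    # Same structure as the block above: Linear -> BatchNorm1d -> ReLU
    return nn.Sequential(
        nn.Linear(input_dim, output_dim),
        nn.BatchNorm1d(output_dim),
        nn.ReLU(inplace=True),
    )

# Hypothetical sizes: 10-dim noise, 128 hidden units, 784 = 28 * 28 pixels
z_dim, hidden_dim, im_dim = 10, 128, 784

# Assumed architecture: widen the hidden layer at each block, then project to pixels
generator = nn.Sequential(
    get_generator_block(z_dim, hidden_dim),
    get_generator_block(hidden_dim, hidden_dim * 2),
    get_generator_block(hidden_dim * 2, hidden_dim * 4),
    get_generator_block(hidden_dim * 4, hidden_dim * 8),
    nn.Linear(hidden_dim * 8, im_dim),
    nn.Sigmoid(),  # squash pixel values into the [0, 1] range
)

noise = torch.randn(32, z_dim)  # batch of 32 noise vectors
fake_images = generator(noise)
print(fake_images.shape)  # torch.Size([32, 784])
```

Because batch normalization in training mode computes statistics across the batch, `nn.BatchNorm1d` raises an error for a batch containing a single sample. Use batches of two or more, or call `generator.eval()` before sampling one image at a time.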