## Deep Convolutional Generative Adversarial Network (DCGAN)

Paper: [Unsupervised Representation Learning with Deep Convolutional Generative Adversarial Networks](https://arxiv.org/pdf/1511.06434v2)

Helpful Resources:
- [Aladdin Persson's playlist on GANs](https://youtube.com/playlist?list=PLhhyoLH6IjfwIp8bZnzX8QR30TRcHO8Va&si=8ooImkbbXhCUC1xB)
- [GANs Specialization on Coursera](https://www.coursera.org/specializations/generative-adversarial-networks-gans)
- [Stanford's Deep Generative Models playlist](https://youtube.com/playlist?list=PLoROMvodv4rPOWA-omMM6STXaWW4FvJT8&si=N_TpTe1bPIhte-t8)
- [AssemblyAI's GAN tutorial](https://youtu.be/_pIMdDWK5sc?si=Mtx2oWh1ZO9tqWYg)

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
from torchvision.utils import make_grid

import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm

torch.manual_seed(0)

print("Imports done!")

Imports done!


In [None]:
def plot_images(img_tensor, num_imgs=25, size=(1,28,28)):
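    """
    Plots the first `num_imgs` images of `img_tensor` in a grid with 5 images per row.
    `img_tensor` may be flat (e.g. shape (N, 784)) or already shaped (N, *size);
    each image is reshaped to `size` = (channels, height, width) before plotting.
    """
    img_unflat = img_tensor.detach().cpu().view(-1, *size)  # (N, C, H, W), detached and moved to the CPU
    img_grid = make_grid(img_unflat[:num_imgs], nrow=5)  # (3, H_grid, W_grid); 1-channel images are repeated to 3 channels
    plt.imshow(img_grid.permute(1,2,0).squeeze())  # imshow expects (H, W, C)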
    plt.show()


In [None]:
def plot_results(results):
    """
    Plots training curves from `results`, a dictionary with the keys
        "gen_train_loss", "gen_test_loss", "disc_train_loss", "disc_test_loss",
        "gen_train_acc", "gen_test_acc", "disc_train_acc", and "disc_test_acc",
        each holding a sequence of values.

    For now, only the generator and discriminator train losses are plotted.
    """
    plt.plot(results["gen_train_loss"], label="Generator train loss")
    plt.plot(results["disc_train_loss"], label="Discriminator train loss")
    plt.legend()
    plt.xlabel("Epoch")
    plt.ylabel("Loss")
    plt.show()
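`plot_results` only documents the shape of `results`, so the following is a minimal sketch of one way to build and fill that dictionary during training. The helpers `new_results` and `record_epoch` are hypothetical, and averaging the per-batch losses into one value per epoch is an assumption chosen to match the "Epoch" x-axis label.

```python
from statistics import mean

# The eight keys listed in plot_results' docstring.
RESULT_KEYS = [
    "gen_train_loss", "gen_test_loss",
    "disc_train_loss", "disc_test_loss",
    "gen_train_acc", "gen_test_acc",
    "disc_train_acc", "disc_test_acc",
]


def new_results():
    """Hypothetical helper: an empty results dict with one list per key."""
    return {key: [] for key in RESULT_KEYS}


def record_epoch(results, gen_batch_losses, disc_batch_losses):
    """Hypothetical helper: append the mean per-batch train losses for one epoch."""
    results["gen_train_loss"].append(mean(gen_batch_losses))
    results["disc_train_loss"].append(mean(disc_batch_losses))


# Usage with made-up per-batch losses for two epochs.
results = new_results()
record_epoch(results, [1.0, 2.0, 3.0], [0.5, 1.5])
record_epoch(results, [1.0, 1.0], [0.25, 0.75])
print(results["gen_train_loss"])   # [2.0, 1.0]
print(results["disc_train_loss"])  # [1.0, 0.5]
```

The resulting dictionary can be passed directly to `plot_results`; the unused test and accuracy lists simply stay empty.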
    