# Pix2Pix GAN

In this notebook we create and train the Pix2Pix GAN architecture, which is a type of conditional GAN, or cGAN.

## Dataset

We will use face images and their corresponding edge maps from our data folder (`../data/`).

## Goals

For this notebook I will:
1. Load the face images and their edge maps.
2. Build the U-Net generator.
3. Build the PatchGAN discriminator.
4. Set the training parameters and train the model.

## Import Libraries and Load Data

```python
import warnings

import matplotlib.pyplot as plt
import torch
import torch.nn.functional as F
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torchvision.utils import make_grid
from tqdm.auto import tqdm

%matplotlib inline
torch.manual_seed(0)  # reproducible weight initialisation and shuffling

# Ignore warnings
warnings.filterwarnings("ignore")

data_dir = '../data/'

# Convert the images to tensors in [0, 1]; Resize(256) scales the shorter side to 256 pixels
def load_data_sets(data_dir, batch_size):
    transform = transforms.Compose([transforms.Resize(256),
                                    transforms.ToTensor()])
    # ImageFolder expects one sub-folder per class inside data_dir
    dataset = datasets.ImageFolder(data_dir, transform=transform)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
    return dataloader

# ImageFolder loads images as RGB, so the default size has 3 channels
def show_tensor_images(image_tensor, num_images=25, size=(3, 256, 256)):
    # Split the batch into images of shape `size` and plot them in a grid 5 images wide
    image_unflat = image_tensor.detach().cpu().view(-1, *size)
    image_grid = make_grid(image_unflat[:num_images], nrow=5)
    plt.imshow(image_grid.permute(1, 2, 0).squeeze())
    plt.show()
```

## U-Net Generator
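
The Pix2Pix generator is a U-Net. It is an encoder-decoder in which each decoder block is concatenated with the matching encoder feature map (a skip connection). Low-level detail, such as where the edges are, can therefore pass straight through to the output. Below is a deliberately small sketch, not this notebook's implementation. The class name `TinyUNet`, the two-level depth and `hidden_channels=32` are assumptions chosen for brevity. It maps a 3-channel input to a 3-channel output, matching `input_dim` and `real_dim` under Training Preparation.

```python
import torch
from torch import nn

class TinyUNet(nn.Module):
    """Hypothetical two-level U-Net sketch; not the notebook's own generator."""

    def __init__(self, input_channels=3, output_channels=3, hidden_channels=32):
        super().__init__()
        h = hidden_channels
        # Encoder: each block halves the spatial size and widens the channels
        self.down1 = nn.Sequential(
            nn.Conv2d(input_channels, h, 4, stride=2, padding=1), nn.LeakyReLU(0.2))
        self.down2 = nn.Sequential(
            nn.Conv2d(h, 2 * h, 4, stride=2, padding=1), nn.BatchNorm2d(2 * h), nn.LeakyReLU(0.2))
        # Decoder: each block doubles the spatial size
        self.up1 = nn.Sequential(
            nn.ConvTranspose2d(2 * h, h, 4, stride=2, padding=1), nn.BatchNorm2d(h), nn.ReLU())
        # 2*h input channels: up1's output concatenated with the down1 skip connection
        self.up2 = nn.ConvTranspose2d(2 * h, output_channels, 4, stride=2, padding=1)

    def forward(self, x):
        d1 = self.down1(x)               # (B, h, H/2, W/2)
        d2 = self.down2(d1)              # (B, 2h, H/4, W/4)
        u1 = self.up1(d2)                # (B, h, H/2, W/2)
        u1 = torch.cat([u1, d1], dim=1)  # skip connection -> (B, 2h, H/2, W/2)
        return torch.sigmoid(self.up2(u1))  # (B, out, H, W), values in [0, 1]

gen = TinyUNet()
out = gen(torch.rand(2, 3, 256, 256))
print(out.shape)  # torch.Size([2, 3, 256, 256])
```

The final `sigmoid` is a choice made for this sketch. `load_data_sets` only applies `ToTensor`, so real images lie in [0, 1]. The original Pix2Pix instead normalizes images to [-1, 1] and ends the generator with `tanh`.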

## PatchGAN Discriminator
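
A PatchGAN discriminator receives the condition image and a real or generated image, concatenated along the channel axis. Instead of a single score, it outputs a grid of logits, and each logit judges whether one patch of the image looks real. The sketch below is an assumption-laden minimal version. The class name `PatchDiscriminator`, the three downsampling layers and `hidden_channels=32` are illustrative choices, not the notebook's design. It returns raw logits so that it pairs with `nn.BCEWithLogitsLoss`, the `adv_criterion` used under Training Preparation.

```python
import torch
from torch import nn

class PatchDiscriminator(nn.Module):
    """Hypothetical PatchGAN sketch; returns one logit per image patch."""

    def __init__(self, input_channels=6, hidden_channels=32):
        super().__init__()
        h = hidden_channels
        self.net = nn.Sequential(
            # Three stride-2 convolutions: 256 -> 128 -> 64 -> 32
            nn.Conv2d(input_channels, h, 4, stride=2, padding=1), nn.LeakyReLU(0.2),
            nn.Conv2d(h, 2 * h, 4, stride=2, padding=1), nn.BatchNorm2d(2 * h), nn.LeakyReLU(0.2),
            nn.Conv2d(2 * h, 4 * h, 4, stride=2, padding=1), nn.BatchNorm2d(4 * h), nn.LeakyReLU(0.2),
            # 1x1 convolution down to a single channel of logits (no sigmoid)
            nn.Conv2d(4 * h, 1, kernel_size=1),
        )

    def forward(self, image, condition):
        # Condition the judgment on the input image by stacking the channels
        return self.net(torch.cat([image, condition], dim=1))

disc = PatchDiscriminator(input_channels=3 + 3)
logits = disc(torch.rand(2, 3, 256, 256), torch.rand(2, 3, 256, 256))
print(logits.shape)  # torch.Size([2, 1, 32, 32]): a 32 x 32 grid of patch logits per image
```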

## Training Preparation
The following parameters must be set before training:
  *   **real_dim**: the number of channels in the real (target) image, which is also the number of channels the generator must output
  *   **adv_criterion**: the adversarial loss function, which measures how well the generator is fooling the discriminator and how well the discriminator is catching the generator's fakes
  *   **recon_criterion**: a loss function that rewards generated images for being close to the ground truth, i.e. for "reconstructing" the target image
  *   **lambda_recon**: how heavily the reconstruction loss is weighted relative to the adversarial loss
  *   **n_epochs**: the number of passes through the entire dataset during training
  *   **input_dim**: the number of channels in the input (condition) image
  *   **display_step**: how often to display/visualize the images
  *   **batch_size**: the number of images per forward/backward pass
  *   **lr**: the learning rate
  *   **target_shape**: the size of the output image, in pixels
  *   **device**: the device to train on (`'cpu'` or `'cuda'`)
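
`adv_criterion`, `recon_criterion` and `lambda_recon` combine in the generator's objective. The adversarial term asks the discriminator to classify every patch of the generated image as real (label 1). The reconstruction term, scaled by `lambda_recon`, keeps the output close to the ground truth. Below is a minimal sketch. It assumes the generator is called as `gen(condition)` and the discriminator as `disc(image, condition)`; the helper name `get_gen_loss` and both call signatures are hypothetical. The two stand-in modules exist only so the snippet runs on its own.

```python
import torch
from torch import nn

def get_gen_loss(gen, disc, real, condition, adv_criterion, recon_criterion, lambda_recon):
    """Hypothetical helper: Pix2Pix-style generator loss (adversarial + weighted L1)."""
    fake = gen(condition)
    disc_fake_logits = disc(fake, condition)
    # The generator wants every patch of its output judged real (label 1)
    adv_loss = adv_criterion(disc_fake_logits, torch.ones_like(disc_fake_logits))
    # Pixel-wise distance to the ground-truth image
    recon_loss = recon_criterion(fake, real)
    return adv_loss + lambda_recon * recon_loss

# Stand-in models (assumptions, not the notebook's U-Net and PatchGAN)
class StandInGen(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 3, kernel_size=3, padding=1)
    def forward(self, x):
        return torch.sigmoid(self.conv(x))

class StandInDisc(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3 + 3, 1, kernel_size=4, stride=4)  # 64x64 -> 16x16 logits
    def forward(self, image, condition):
        return self.conv(torch.cat([image, condition], dim=1))

torch.manual_seed(0)
condition = torch.rand(2, 3, 64, 64)  # e.g. a batch of edge maps
real = torch.rand(2, 3, 64, 64)       # e.g. the matching face photos
loss = get_gen_loss(StandInGen(), StandInDisc(), real, condition,
                    nn.BCEWithLogitsLoss(), nn.L1Loss(), lambda_recon=200)
loss.backward()  # gradients reach the generator's parameters
print(loss.item())
```

With `lambda_recon = 200`, even a small average L1 error can dominate the adversarial term. This is how the setting pushes the generator toward faithful reconstructions.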

```python
import torch
from torch import nn

real_dim = 3
adv_criterion = nn.BCEWithLogitsLoss()
recon_criterion = nn.L1Loss()
lambda_recon = 200
n_epochs = 20
input_dim = 3
display_step = 200
batch_size = 4
lr = 0.0002
target_shape = 256
device = 'cuda' if torch.cuda.is_available() else 'cpu'  # use the GPU when one is available
```
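
Each training iteration also updates the discriminator. It learns to label real (image, condition) pairs as 1 and generated pairs as 0. The sketch below shows that step. The helper name `get_disc_loss` and the stand-in discriminator are hypothetical. Random tensors replace a batch from `load_data_sets` and a real generator output. Adam with `betas=(0.5, 0.999)` follows the original Pix2Pix paper, which is an assumption for this notebook; `lr=0.0002` matches `lr` above. `fake.detach()` stops the discriminator's loss from sending gradients back into the generator.

```python
import torch
from torch import nn

def get_disc_loss(disc, fake, real, condition, adv_criterion):
    """Hypothetical helper: PatchGAN discriminator loss on real and generated pairs."""
    # detach() cuts the graph so this loss cannot update the generator
    fake_logits = disc(fake.detach(), condition)
    real_logits = disc(real, condition)
    fake_loss = adv_criterion(fake_logits, torch.zeros_like(fake_logits))
    real_loss = adv_criterion(real_logits, torch.ones_like(real_logits))
    # Average the fake and real terms
    return (fake_loss + real_loss) / 2

# Stand-in discriminator: stacks image and condition, outputs a grid of patch logits
class StandInDisc(nn.Module):
    def __init__(self, channels=3 + 3):
        super().__init__()
        self.conv = nn.Conv2d(channels, 1, kernel_size=4, stride=4)
    def forward(self, image, condition):
        return self.conv(torch.cat([image, condition], dim=1))

disc = StandInDisc()
# betas=(0.5, 0.999) as in the original Pix2Pix paper (an assumption here)
disc_opt = torch.optim.Adam(disc.parameters(), lr=0.0002, betas=(0.5, 0.999))

condition = torch.rand(4, 3, 64, 64)
real = torch.rand(4, 3, 64, 64)
fake = torch.rand(4, 3, 64, 64, requires_grad=True)  # pretend generator output

disc_opt.zero_grad()
disc_loss = get_disc_loss(disc, fake, real, condition, nn.BCEWithLogitsLoss())
disc_loss.backward()
disc_opt.step()
print(disc_loss.item(), fake.grad)  # fake.grad is None because of detach()
```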