# GAN Demo

## Load Libraries

In [None]:
import torch
import torch.optim as optim
from torchvision import datasets
from torchvision import transforms
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder

import matplotlib.pyplot as plt

%matplotlib inline
plt.rcParams['figure.figsize'] = (10.0, 8.0) # set default size of plots
plt.rcParams['image.interpolation'] = 'nearest'
plt.rcParams['image.cmap'] = 'gray'

%load_ext autoreload
%autoreload 2

## Load Utils

In [None]:
from models.gan_train import train_gan
from models.utils import show_images

In [None]:
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")  # first GPU if visible, else CPU

In [None]:
NOISE_DIM: int = 100  # size of the noise input handed to every generator constructor in this notebook

## Load Losses

In [None]:
from models.gan_loss import discriminator_loss, generator_loss, ls_discriminator_loss, ls_generator_loss

## LNN-based GAN

### Load Model

In [None]:
from models.gan import LNN_Discriminator, LNN_Generator

### Load Data

In [None]:
batch_size = 128

mnist = datasets.MNIST('./data', train=True, download=True, transform=transforms.ToTensor())
train_loader = DataLoader(mnist, batch_size=batch_size, drop_last=True)

# Preview the first batch. `next(iter(loader))` is the supported way to pull a
# batch. The old `loader.__iter__().next()` fallback used a Python-2-era
# method that current DataLoader iterators do not provide. Its bare `except`
# only hid the real error (e.g. a failed download) behind an AttributeError.
first_batch_images = next(iter(train_loader))[0]
# Flatten each 28x28 image into a 784-vector before handing it to show_images.
imgs = first_batch_images.view(batch_size, 784).numpy().squeeze()

show_images(imgs)

### Train Model with GAN Loss

In [None]:
LNN_D = LNN_Discriminator(input_channels=1, image_size=28).to(device)
LNN_G = LNN_Generator(NOISE_DIM).to(device)

# One Adam optimizer per network, both with lr=1e-3 and betas=(0.5, 0.999);
# built discriminator first, then generator.
LNN_D_optimizer, LNN_G_optimizer = (
  optim.Adam(net.parameters(), lr=1e-3, betas=(0.5, 0.999)) for net in (LNN_D, LNN_G)
)

# Vanilla GAN objectives on MNIST.
train_gan(
  LNN_D, LNN_G,
  LNN_D_optimizer, LNN_G_optimizer,
  discriminator_loss, generator_loss,
  show_every=500,
  train_loader=train_loader,
  num_epochs=10,
  device=device,
)

### Train Model with LSGAN Loss

In [None]:
# Fresh networks, so this run is independent of the vanilla-GAN run above.
LNN_D_LS = LNN_Discriminator(input_channels=1, image_size=28).to(device)
LNN_G_LS = LNN_Generator(NOISE_DIM).to(device)

# Same Adam settings as the vanilla run (lr=1e-3, betas=(0.5, 0.999)).
LNN_D_LS_optimizer, LNN_G_LS_optimizer = (
  optim.Adam(net.parameters(), lr=1e-3, betas=(0.5, 0.999)) for net in (LNN_D_LS, LNN_G_LS)
)

# Least-squares GAN objectives on MNIST.
train_gan(
  LNN_D_LS, LNN_G_LS,
  LNN_D_LS_optimizer, LNN_G_LS_optimizer,
  ls_discriminator_loss, ls_generator_loss,
  show_every=500,
  train_loader=train_loader,
  num_epochs=10,
  device=device,
)

## CNN-based GAN

### Load Model

In [None]:
from models.gan import CNN_Discriminator, CNN_Generator

### Load Data

In [None]:
batch_size, image_size = 64, 64
cat_root = './data/cats'

# Augmentation pipeline: convert to a tensor, resize the shorter side to
# int(1.15 * image_size) pixels, take a random image_size x image_size crop,
# then apply a small color jitter.
cat_train = ImageFolder(
  root=cat_root,
  transform=transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize(int(1.15 * image_size), antialias=True),
    transforms.RandomCrop(image_size),
    transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.1),
  ]),
)

cat_train_loader = DataLoader(cat_train, batch_size=batch_size, drop_last=True)

In [None]:
# Preview the first batch of cat images. `next(iter(loader))` is the supported
# way to pull a batch. The old `__iter__().next()` fallback used a method
# current DataLoader iterators lack, and its bare `except` hid real errors
# (e.g. an empty or missing image folder).
imgs = next(iter(cat_train_loader))[0].numpy().squeeze()

show_images(imgs, color=True)

### Train Model with GAN Loss

In [None]:
CNN_D = CNN_Discriminator(input_channels=3).to(device)
CNN_G = CNN_Generator(NOISE_DIM).to(device)

# One Adam optimizer per network, both with lr=1e-3 and betas=(0.5, 0.999);
# built discriminator first, then generator.
CNN_D_optimizer, CNN_G_optimizer = (
  optim.Adam(net.parameters(), lr=1e-3, betas=(0.5, 0.999)) for net in (CNN_D, CNN_G)
)

# Vanilla GAN objectives on the cat images.
train_gan(
  CNN_D, CNN_G,
  CNN_D_optimizer, CNN_G_optimizer,
  discriminator_loss, generator_loss,
  show_every=250,
  train_loader=cat_train_loader,
  num_epochs=50,
  device=device,
)

### Train Model with LSGAN Loss

In [None]:
# Fresh networks, so this run is independent of the vanilla-GAN run above.
CNN_D_LS = CNN_Discriminator(input_channels=3).to(device)
CNN_G_LS = CNN_Generator(NOISE_DIM).to(device)

CNN_D_LS_optimizer = optim.Adam(CNN_D_LS.parameters(), lr=1e-3, betas=(0.5, 0.999))
CNN_G_LS_optimizer = optim.Adam(CNN_G_LS.parameters(), lr=1e-3, betas=(0.5, 0.999))

# This is the LSGAN run, so it must use the least-squares losses, as the LNN
# LSGAN cell does. Passing discriminator_loss/generator_loss here would just
# repeat the vanilla-GAN experiment.
train_gan(CNN_D_LS, CNN_G_LS, CNN_D_LS_optimizer, CNN_G_LS_optimizer, ls_discriminator_loss, ls_generator_loss, show_every=250, train_loader=cat_train_loader, num_epochs=50, device=device)