In [None]:
# Re-import modified modules (e.g. the local GAN_pytorch module imported below)
# before every cell execution, so edits to them take effect without a kernel restart.
%load_ext autoreload
%autoreload 2

# Import libraries and load data

In [None]:
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import torchvision.utils as vutils
from torchvision import datasets, transforms
import matplotlib.pyplot as plt

# ToTensor maps pixels to floats in [0, 1]; Normalize(0.5, 0.5) then applies
# (x - 0.5) / 0.5, so the network sees pixel values in [-1, 1].
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean=(0.5,), std=(0.5,))])
data = datasets.MNIST(root='./data', download=True, transform=transform)

# (channels, height, width) of one transformed sample. Index the dataset only once:
# every `data[i]` access re-runs the image decode + transform pipeline.
sample_img, _ = data[1]
img_shape = tuple(sample_img.shape)
print(f'Input size is {img_shape}')

# Training loader: yields shuffled mini-batches of `batch_size` MNIST samples.
batch_size = 256
dataloader = DataLoader(dataset=data, shuffle=True, batch_size=batch_size)

# Plot the first 16 images of one shuffled training batch.
real_batch = next(iter(dataloader))
real_images = real_batch[0][:16]
# make_grid tiles the images into a single (C, H, W) tensor; normalize=True
# min/max-rescales it into [0, 1] for display. imshow expects (H, W, C), so
# permute explicitly instead of relying on np.transpose converting a torch tensor.
real_grid = vutils.make_grid(real_images, padding=2, normalize=True).cpu()
plt.figure(figsize=(6,6))
plt.axis('off')
plt.title('Training images')
plt.imshow(real_grid.permute(1, 2, 0).numpy())
plt.show()

# Build Vanilla GAN

In [None]:
from GAN_pytorch import Generator, Discriminator, Train

dim_latent = 100   # length of the noise vector fed to the generator
lr = 0.001         # Adam learning rate shared by G and D
seed = 0           # RNG seed for reproducible weight initialisation

cuda = torch.cuda.is_available()
device = torch.device('cuda' if cuda else 'cpu')

# Seed torch's RNGs (CPU and all CUDA devices) before building the networks so the
# initial weights repeat across runs. This presumably also fixes the DataLoader
# shuffling and noise drawn inside GAN_pytorch.Train, if it uses torch's global RNG
# (TODO confirm). Bitwise GPU determinism may additionally need
# torch.use_deterministic_algorithms(True).
torch.manual_seed(seed)

# Initialize generator and discriminator
G = Generator(img_shape=img_shape, dim_latent=dim_latent, g_dims=[128,256,512,1024]).to(device)
D = Discriminator(img_shape=img_shape, d_dims=[512, 256]).to(device)

# Separate Adam optimisers; beta1 = 0.5 (instead of the default 0.9) is the usual
# choice for stabilising GAN training.
optimizer_G = torch.optim.Adam(G.parameters(), lr=lr, betas=(0.5, 0.999))
optimizer_D = torch.optim.Adam(D.parameters(), lr=lr, betas=(0.5, 0.999))


# Training

In [None]:
Train(epoch=50, dataloader=dataloader, device=device, G=G, D=D,
      optimizer_G=optimizer_G, optimizer_D=optimizer_D)

# Save the trained weights as state_dicts rather than pickling the whole nn.Module:
# full-module checkpoints cannot be read by torch.load's weights_only=True default
# (PyTorch >= 2.6), and pickling GAN_pytorch classes can fail once %autoreload has
# re-imported that module.
torch.save(G.state_dict(), './vanilla_G.pt')
torch.save(D.state_dict(), './vanilla_D.pt')

# Load trained weights if needed (into networks built with the same architecture).
# map_location lets a checkpoint written on GPU be restored on a CPU-only machine.
G.load_state_dict(torch.load('./vanilla_G.pt', map_location=device))
D.load_state_dict(torch.load('./vanilla_D.pt', map_location=device))

# Examples of Generated Images

In [None]:
# Generate 16 images from random latent vectors and show them as a grid.
G.eval()
# NOTE(review): this noise is uniform on [0, 1) (torch.rand). It must match the latent
# distribution GAN_pytorch.Train samples during training; many GAN setups use a
# standard normal (torch.randn) - confirm against Train.
with torch.no_grad():  # inference only: don't build an autograd graph
    z = torch.rand(16, dim_latent, device=device)
    fake = G(z).cpu()
# normalize=True min/max-rescales the images into [0, 1]; imshow expects (H, W, C).
fake_grid = vutils.make_grid(fake, padding=2, normalize=True)
plt.figure(figsize=(6,6))
plt.axis('off')
plt.title('Generated images')
plt.imshow(fake_grid.permute(1, 2, 0).numpy())
plt.show()