In [1]:
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
import torchvision
import torchvision.transforms as T
from torch.utils.data import DataLoader

In [2]:
# Use the first CUDA GPU when PyTorch reports one as available; otherwise fall back to the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [3]:
def make_generator_network(
    input_size=20,
    num_hidden_layers=1,
    num_hidden_units=100,
    num_output_units=784,  # 28 * 28
):
    """Build the fully connected generator.

    The network is ``num_hidden_layers`` blocks of ``Linear -> LeakyReLU``
    followed by an output ``Linear`` layer and a ``Tanh``.

    Args:
        input_size: Width of the noise vector fed to the first layer.
        num_hidden_layers: Number of ``Linear -> LeakyReLU`` blocks; 0 wires
            the input straight into the output layer.
        num_hidden_units: Width of every hidden layer.
        num_output_units: Width of the generated output.

    Returns:
        An ``nn.Sequential`` whose layers are named ``g_fc_{i}``,
        ``g_relu_{i}`` and ``g_tanh``.
    """
    model = nn.Sequential()

    # Width of the tensor entering the next Linear layer. It must be advanced
    # after each hidden block, otherwise every block past the first would
    # still expect `input_size` features and the forward pass would fail.
    in_features = input_size
    for i in range(num_hidden_layers):
        model.add_module(f"g_fc_{i}", nn.Linear(in_features, num_hidden_units))
        model.add_module(f"g_relu_{i}", nn.LeakyReLU())
        in_features = num_hidden_units

    # `in_features` equals `input_size` when there are no hidden layers.
    model.add_module(
        f"g_fc_{num_hidden_layers}", nn.Linear(in_features, num_output_units)
    )
    model.add_module("g_tanh", nn.Tanh())
    return model


def make_discriminator_network(
    input_size=784,  # 28 * 28
    num_hidden_layers=1,
    num_hidden_units=100,
    num_output_units=1,
):
    """Build the fully connected discriminator.

    The network is ``num_hidden_layers`` blocks of
    ``Linear -> LeakyReLU -> Dropout(p=0.5)`` followed by an output ``Linear``
    layer and a ``Sigmoid``.

    Args:
        input_size: Width of the flattened input sample.
        num_hidden_layers: Number of hidden blocks; 0 wires the input straight
            into the output layer.
        num_hidden_units: Width of every hidden layer.
        num_output_units: Width of the output (1 for a single probability).

    Returns:
        An ``nn.Sequential`` whose layers are named ``d_fc_{i}``,
        ``d_relu_{i}``, ``d_dropout_{i}`` and ``d_sigmoid``.
    """
    model = nn.Sequential()

    in_features = input_size
    for i in range(num_hidden_layers):
        model.add_module(f"d_fc_{i}", nn.Linear(in_features, num_hidden_units))
        model.add_module(f"d_relu_{i}", nn.LeakyReLU())
        # The dropout name carries the layer index: registering every dropout
        # under one shared name would overwrite the previous one, leaving a
        # single dropout in the whole network. Dropout has no parameters, so
        # state_dict keys are unaffected by the name.
        model.add_module(f"d_dropout_{i}", nn.Dropout(p=0.5))
        in_features = num_hidden_units

    # `in_features` equals `input_size` when there are no hidden layers.
    model.add_module(
        f"d_fc_{num_hidden_layers}", nn.Linear(in_features, num_output_units)
    )
    model.add_module("d_sigmoid", nn.Sigmoid())
    return model


In [4]:
# Per-image preprocessing: ToTensor, then single-channel Normalize with
# mean 0.5 and std 0.5, i.e. (x - 0.5) / 0.5.
transformer = T.Compose([T.ToTensor(), T.Normalize(mean=(0.5,), std=(0.5,))])

In [5]:
def create_noise(batch_size, z_size=20, mode='normal'):  # z_size == gen_model.input_size
    """Sample a ``(batch_size, z_size)`` noise tensor.

    Args:
        batch_size: Number of noise vectors to draw.
        z_size: Length of each noise vector.
        mode: ``'normal'`` draws from a standard normal distribution; any
            other value draws uniformly from [-1, 1).

    Returns:
        A tensor of shape ``(batch_size, z_size)``.
    """
    shape = (batch_size, z_size)
    if mode != 'normal':
        # Rescale U[0, 1) samples to U[-1, 1).
        return (torch.rand(shape) * 2) - 1
    return torch.randn(shape)

In [6]:
def discriminator_train(d_model, g_model, criterion, d_optimizer, x):
    """Run one discriminator update on a real batch and an equally sized fake batch.

    Args:
        d_model: Discriminator; its gradients are reset and its optimizer stepped.
        g_model: Generator; only used to produce fake samples, never updated here.
        criterion: Loss applied to the discriminator output and the 1/0 labels.
        d_optimizer: Optimizer that is stepped after the backward pass.
        x: Batch of real samples; the first dimension is the batch dimension
            and the remaining dimensions are flattened.

    Returns:
        A tuple ``(d_loss, d_proba_real, d_proba_fake)``: the summed loss as a
        Python float, and the detached discriminator outputs for the real and
        the fake batch.
    """
    d_model.zero_grad()

    # Size labels and noise from the batch actually received rather than from
    # the global BATCH_SIZE, so a smaller batch does not break the reshape below.
    batch_size = x.size(0)

    # Train on REAL data
    x = x.view(batch_size, -1).to(device)  # Flatten each sample and move the batch to `device`
    d_labels_real = torch.ones(batch_size, 1, device=device)
    d_proba_real = d_model(x)
    d_loss_real = criterion(d_proba_real, d_labels_real)

    # Train on NOISE
    noise = create_noise(batch_size).to(device)
    # Only the discriminator is stepped in this function, so the fake samples
    # are produced without recording the generator's autograd graph.
    with torch.no_grad():
        g_output = g_model(noise)
    d_proba_fake = d_model(g_output)
    d_labels_fake = torch.zeros(batch_size, 1, device=device)
    d_loss_fake = criterion(d_proba_fake, d_labels_fake)

    # Backprop and optimize d_model only
    d_loss = d_loss_real + d_loss_fake
    d_loss.backward()
    d_optimizer.step()
    return d_loss.item(), d_proba_real.detach(), d_proba_fake.detach()

In [7]:
def generator_train(g_model, d_model, criterion, g_optimizer, batch_size=None):
    """Run one generator update against the current discriminator.

    Args:
        g_model: Generator; its gradients are reset and its optimizer stepped.
        d_model: Discriminator used to score the generated samples.
        criterion: Loss applied to the discriminator output and all-ones labels.
        g_optimizer: Optimizer that is stepped after the backward pass.
        batch_size: Number of noise vectors to generate. Defaults to the
            global ``BATCH_SIZE`` when omitted.

    Returns:
        The generator loss as a Python float.
    """
    if batch_size is None:
        batch_size = BATCH_SIZE

    g_model.zero_grad()
    noise = create_noise(batch_size).to(device)
    g_output = g_model(noise)
    d_proba_fake = d_model(g_output)
    # The generator is rewarded when the discriminator labels its samples as
    # real, hence the all-ones targets.
    g_labels_real = torch.ones(batch_size, 1, device=device)
    g_loss = criterion(d_proba_fake, g_labels_real)
    # Backprop and optimize g_model only (only g_optimizer is stepped)
    g_loss.backward()
    g_optimizer.step()
    return g_loss.item()

In [8]:
def train(g_model, d_model, criterion, g_optimizer, d_optimizer, data_loader, epochs=1):
    """Alternate discriminator and generator updates over ``data_loader``.

    For every batch the discriminator is updated first, then the generator.
    After each epoch the batch averages are printed and recorded.

    Args:
        g_model: Generator network.
        d_model: Discriminator network.
        criterion: Loss passed through to both training steps.
        g_optimizer: Optimizer for the generator.
        d_optimizer: Optimizer for the discriminator.
        data_loader: Iterable yielding ``(x, label)`` pairs; the label is ignored.
        epochs: Number of passes over ``data_loader``.

    Returns:
        A tuple ``(all_d_losses, all_g_losses, all_d_real, all_d_fake)`` of
        lists of floats with one entry per epoch: mean discriminator loss,
        mean generator loss, and the mean discriminator output on real and on
        fake samples.
    """
    all_d_losses = []
    all_g_losses = []
    all_d_real = []
    all_d_fake = []

    for epoch in range(1, epochs + 1):
        d_losses, g_losses = [], []
        d_vals_real, d_vals_fake = [], []

        for x, _ in data_loader:
            # Train and record d_model
            d_loss, d_proba_real, d_proba_fake = discriminator_train(
                d_model, g_model, criterion, d_optimizer, x
            )
            d_losses.append(d_loss)
            d_vals_real.append(d_proba_real.mean().item())
            d_vals_fake.append(d_proba_fake.mean().item())
            # Train and record g_model
            g_loss = generator_train(g_model, d_model, criterion, g_optimizer)
            g_losses.append(g_loss)

        # Record for every epoch (stored as plain floats)
        all_d_losses.append(torch.tensor(d_losses).mean().item())
        all_g_losses.append(torch.tensor(g_losses).mean().item())
        all_d_real.append(torch.tensor(d_vals_real).mean().item())
        all_d_fake.append(torch.tensor(d_vals_fake).mean().item())

        print(f'Epoch {epoch:03d} | Avg Losses >>'
              f' G/D {all_g_losses[-1]:.4f}/{all_d_losses[-1]:.4f}'
              f' [D-Real: {all_d_real[-1]:.4f}'
              f' D-Fake: {all_d_fake[-1]:.4f}]')

    # Hand the histories back to the caller; previously they were built and discarded.
    return all_d_losses, all_g_losses, all_d_real, all_d_fake

In [9]:
# Mini-batch size shared by the DataLoader and the training-step functions.
BATCH_SIZE = 128

In [10]:
# download=True lets a fresh environment fetch MNIST into ./mnist/ instead of
# failing because the files are missing.
train_data = torchvision.datasets.MNIST('./mnist/', train=True, download=True, transform=transformer)
# drop_last=True discards a final partial batch, so every batch holds exactly BATCH_SIZE samples.
train_dl = DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True, num_workers=8, drop_last=True)

In [11]:
# Seed PyTorch's global RNG before the first stochastic step (weight
# initialisation) so that repeated runs start from the same state.
# NOTE: some GPU kernels may still be nondeterministic.
SEED = 42
torch.manual_seed(SEED)

# Both networks use their default sizes and are moved to `device`.
gen_model = make_generator_network().to(device)
disc_model = make_discriminator_network().to(device)
# Binary cross-entropy on the discriminator's sigmoid output.
criterion = nn.BCELoss()
# One AdamW optimizer per network, each with default hyperparameters.
g_optimizer = torch.optim.AdamW(gen_model.parameters())
d_optimizer = torch.optim.AdamW(disc_model.parameters())

In [None]:
# Train the GAN for 100 epochs; %time reports how long this single call takes.
%time train(gen_model, disc_model, criterion, g_optimizer, d_optimizer, train_dl, 100)