# Train
This notebook trains the network, sets its hyperparameters, and visualizes the generator outputs.


In [None]:
import numpy as np

import torch
from torch.nn import BCELoss
from torch.optim import Adam

# Networks and utility functions
from network import Generator, Discriminator, weights_init
from utils import get_dataloader_image

# Visualization
import matplotlib.pyplot as plt
import torchvision.utils as vutils

# Parameters
The `generator_layers` and `discriminator_layers` variables determine the total number of layers and the number of channels in each layer.

The `image_size` is set to match the given architecture of the generator.

The `visualization_noise` is fixed here so that results across epochs can be compared directly.

In [None]:
image_folder = "images"
batch_size = 128
latent_size = 200
num_epochs = 100
learning_rate = 1e-4
# Channel counts per layer. The generator's first entry is the latent (noise) dimension.
generator_layers = [latent_size, 512, 256, 128, 64]
# The discriminator starts from 1 input channel (presumably grayscale images -- confirm in `network`).
discriminator_layers = [1, 64, 128, 256, 512]
# Derived from the generator depth so the loaded images match the generator's output size.
image_size = 2**len(generator_layers)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
beta = 0.5 # Adam beta1 (first-moment decay); beta2 is fixed at 0.999 where the optimizers are built

# Seed torch's RNGs (CPU and CUDA) before any random tensor is drawn, so the showcase
# noise, weight initialisation and data shuffling are reproducible across runs.
seed = 0
torch.manual_seed(seed)

visualize = True
print_freq = 25  # show a sample grid every `print_freq` iterations
visualization_noise = torch.randn(25, latent_size, 1, 1).to(device)

# Train modules
Module setup; no changes should be made in this cell.

In [None]:
# Build both networks on the target device, then re-initialise their weights with
# `weights_init`. Construction (D, then G) and initialisation (D, then G) happen in
# this order so random-number consumption stays the same.
D, G = (
    Discriminator(discriminator_layers).to(device),
    Generator(generator_layers).to(device),
)
D, G = D.apply(weights_init), G.apply(weights_init)

# Binary cross-entropy between the discriminator's predictions and the real/fake targets.
loss_fcn = BCELoss()


def _make_adam(module):
    """Return an Adam optimizer over ``module``'s parameters using the configured lr/beta."""
    return Adam(module.parameters(), lr=learning_rate, betas=[beta, 0.999])


optimizer_D = _make_adam(D)
optimizer_G = _make_adam(G)

dataloader = get_dataloader_image(image_folder, batch_size, image_size)

## Train
Main training loop. On a GTX 1060 Ti (a low-end GPU), this cell ran for about 2 hours.

In [None]:
# Alternating GAN updates: one discriminator step, then one generator step per batch.
# `plots` collects the showcase grids so the training progression can be replayed later.
plots = []
for n in range(num_epochs):
    for it, images_real in enumerate(dataloader):
        # Discriminator update #
        D.zero_grad()

        images_real = images_real.to(device)
        # Size of *this* batch (the last one may be smaller). Kept under its own name so
        # the configured `batch_size` is not overwritten -- otherwise re-running the setup
        # cell after training would build the dataloader with the wrong batch size.
        current_batch_size = images_real.shape[0]

        # Predict real images (target label 1)
        label = torch.ones(current_batch_size, dtype=torch.float, device=device)
        predictions_real = D(images_real)

        loss_real = loss_fcn(predictions_real, label)
        loss_real.backward()

        # Generate fake images
        input_noise = torch.randn(current_batch_size, latent_size, 1, 1).to(device)
        images_fake = G(input_noise)

        # Predict fake images (target label 0); detached so this loss does not reach G
        label = torch.zeros(current_batch_size, dtype=torch.float, device=device)
        predictions_fake = D(images_fake.detach())

        loss_fake = loss_fcn(predictions_fake, label)
        loss_fake.backward()

        # Update discriminator weights using the gradients accumulated by both backward
        # passes above. The summed loss is kept only for inspection, not backpropagated.
        loss_discriminator = loss_real + loss_fake
        optimizer_D.step()

        # Generator update #
        G.zero_grad()

        # The generator is trained to make D label its fakes as real (label 1)
        label = torch.ones(current_batch_size, dtype=torch.float, device=device)
        predictions_DG = D(images_fake)
        loss_DG = loss_fcn(predictions_DG, label)

        # Update generator weights. This backward pass also writes gradients into D's
        # parameters; they are cleared by `D.zero_grad()` at the start of the next batch.
        loss_DG.backward()
        optimizer_G.step()

        if visualize and (it + 1) % print_freq == 0:
            with torch.no_grad():
                G.eval()
                images_showcase = G(visualization_noise)
                # Build the HWC grid once and reuse it for display and for the saved history.
                showcase_grid = np.transpose(
                    vutils.make_grid(images_showcase, padding=2, normalize=True, nrow=5).cpu(),
                    (1, 2, 0),
                )
                plt.figure(figsize=(8,8))
                plt.axis("off")
                plt.title(f"Showcase images - Epoch: {n + 1} | Iter: {it + 1}")
                plt.imshow(showcase_grid)
                plt.show()
                plots.append(showcase_grid)
                G.train()

    # Checkpoint both networks and the showcase history after every epoch.
    torch.save(G.state_dict(), f"modelG_{n+1}.pth")
    torch.save(D.state_dict(), f"modelD_{n+1}.pth")
    torch.save(plots, "plots.pth")

# Visualization
Visualization of the final network outputs.

In [None]:

# Render 36 samples from the generator checkpoint saved after the last epoch and
# write the grid to "final.png".
noise = torch.randn(36, latent_size, 1, 1).to(device)

# Load the last epoch's checkpoint straight onto the configured device (CPU or CUDA),
# so this cell also runs on machines without a GPU.
G.load_state_dict(torch.load(f"modelG_{num_epochs}.pth", map_location=device))
G = G.to(device).eval()

# Inference only: without no_grad the generated images (and the grid built from them)
# would require grad, and NumPy conversion for plotting rejects such tensors.
with torch.no_grad():
    image_showcase = G(noise)

# 36 samples laid out 9 per row (4 rows).
plt.figure(figsize=(40,40))
plt.axis("off")
plt.imshow(np.transpose(vutils.make_grid(image_showcase, padding=2, normalize=True, nrow=9).cpu(),(1,2,0)))
plt.box(False)
plt.savefig("final.png", bbox_inches='tight')
plt.show()

