# **DCGAN Training**

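Requires:
- Dataset: MNIST is downloaded automatically into `MNIST_full`; any other `dataroot` (e.g. `MNIST_no1s`) must already exist as an image-folder dataset.
- Local modules `networks` and `utils` on the import path.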

## **Imports**

In [None]:
import os
import numpy as np

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torchvision.utils as vutils

import matplotlib.pyplot as plt

import time

from scipy.linalg import sqrtm
from sklearn.manifold import TSNE

import networks as nws
import utils

## **Data**

In [None]:
# Spatial size of the training images. Folder-based datasets are resized to
#   28x28 by the transform pipeline below; MNIST is used at its native size.
# image_size = 28 # Mnist

# Batch size during training
batch_size = 64
# NOTE(review): `workers` is declared but not passed to the DataLoader below,
#   so batches are loaded in the main process. Confirm whether that is
#   intentional before wiring it in as `num_workers=workers`.
workers = 2

# Root directory for dataset
dataroot = "MNIST_full"
# dataroot = "MNIST_no1s"

if dataroot == "MNIST_full":
    # torchvision's MNIST dataset, downloaded into `dataroot` when missing.
    # Pixels are mapped to [-1, 1] by Normalize(mean=0.5, std=0.5).
    dataset = dset.MNIST(
        root=dataroot,
        transform=transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.5,), (0.5,)),
        ]),
        download=True,
    )
else:
    # Any other root is read as an image-folder dataset.
    # Geometry/colour transforms run on the loaded image first, then the
    # result is converted to a tensor and normalised. This way the
    # single-channel mean/std are applied to a single-channel tensor, and the
    # pipeline does not depend on tensor support in Resize/Grayscale.
    dataset = dset.ImageFolder(
        root=dataroot,
        transform=transforms.Compose([
            transforms.Grayscale(),
            transforms.Resize((28, 28)),
            transforms.ToTensor(),
            transforms.Normalize((0.5,), (0.5,)),
        ]),
    )

# Create the dataloader (identical for both dataset kinds)
dataloader = torch.utils.data.DataLoader(
    dataset, batch_size=batch_size, shuffle=True
)
# Suffix used in checkpoint and plot file names
model_suffix = dataroot

# Number of batches per epoch
print(len(dataloader))

## **Models**

In [None]:
# Compute device: the first CUDA GPU when one is available, otherwise the CPU
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)

# Architecture version (passed to both networks) and latent vector size
version = 3
nz = 128

# -- Build the generator, move it to the device, then apply the project's
#    weight initialisation to every sub-module
netG = nws.Generator(version=version, nz=nz)
netG = netG.to(device)
netG.apply(nws.weights_init)

# -- Build the discriminator the same way
netD = nws.Discriminator(version=version)
netD = netD.to(device)
netD.apply(nws.weights_init)

In [None]:
def _count_trainable(net):
    """Return the total element count of `net`'s parameters that require grad."""
    return sum(p.numel() for p in net.parameters() if p.requires_grad)


# Generator size. Counts noted from earlier runs:
# v1  312256
# v2  316640
n_paramsG = _count_trainable(netG)
print(f"GEN version {version}, # params {n_paramsG}")

# NOTE: DIS version match with GEN
#   based on GEN's complexity
n_paramsD = _count_trainable(netD)
print(f"DIS version {version}, # params {n_paramsD}")

## **Training setup**

In [None]:
# Binary cross-entropy on the discriminator's output
criterion = nn.BCELoss()

# A fixed batch of 64 latent vectors, reused throughout training so that the
#  generator's progress can be visualised on the same inputs
fixed_noise = torch.randn(64, nz, 1, 1, device=device)

# Label convention used during training
real_label = 1.0
fake_label = 0.0

# Optimisation hyperparameters
lr = 0.0002        # learning rate shared by both optimizers
beta1 = 0.5        # Adam beta1
num_epochs = 300

# One Adam optimizer per network, with identical settings
adam_betas = (beta1, 0.999)
optimizerD = optim.Adam(netD.parameters(), lr=lr, betas=adam_betas)
optimizerG = optim.Adam(netG.parameters(), lr=lr, betas=adam_betas)

# Output locations: checkpoints and training plots
save_dir = "checkpoints_original"
plts_dir = f"plots/GAN_train_{model_suffix}"
for out_dir in (save_dir, plts_dir):
    os.makedirs(out_dir, exist_ok=True)

# Short tags describing the optimizer setup
c_optimizer = type(optimizerG).__name__  # Adam | SGD
c_lr = str(optimizerD.param_groups[0]['lr']).replace(".", "")  # 00002 | 00004 ...

# Summary of the run configuration
print("-- Epochs: ", num_epochs)
print("-- Current z_dim: ", nz)
print("-- Current optimizer & learning rate: ", c_optimizer, c_lr)
print("-- Is training netG: ", netG.training)
print("-- Is training netD: ", netD.training)
print("-- Device: ", device)

## **Training loop**

In [None]:
# Lists to keep track of progress
# NOTE(review): one image grid is appended to `img_list` every `sample_every`
#   iterations and all of them are kept in memory for the whole run; raise
#   `sample_every` if memory becomes a concern.
img_list = []
G_losses = []
D_losses = []
iters = 0

# How often (in iterations) to print stats / snapshot G's output on fixed_noise
log_every = 50
sample_every = 50
n_batches = len(dataloader)

# Accumulated wall-clock seconds spent in the D update, the G update, and both.
# NOTE: on CUDA, kernels are launched asynchronously, so the D/G split is
#   approximate; perf_counter() is used because it is monotonic.
ttrain_D, ttrain_G, ttrain_GAN = 0, 0, 0
print("Starting Training Loop...")
# For each epoch
for epoch in range(num_epochs):
    # For each batch in the dataloader (class labels are not used)
    for i, (data, _) in enumerate(dataloader):
        ############################
        # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
        ###########################
        tD_start = time.perf_counter()
        ## Train with all-real batch
        netD.zero_grad()
        # Format batch
        real_cpu = data.to(device)
        b_size = real_cpu.size(0)
        label = torch.full((b_size,), real_label, dtype=torch.float, device=device)
        # Forward pass real batch through D
        output = netD(real_cpu).view(-1)
        # Calculate loss on all-real batch
        errD_real = criterion(output, label)
        # Calculate gradients for D in backward pass
        errD_real.backward()
        D_x = output.mean().item()

        ## Train with all-fake batch
        # Generate batch of latent vectors
        noise = torch.randn(b_size, nz, 1, 1, device=device)
        # Generate fake image batch with G
        fake = netG(noise)
        label.fill_(fake_label)
        # Classify all fake batch with D; detach() keeps this backward pass
        # from propagating gradients into G
        output = netD(fake.detach()).view(-1)
        # Calculate D's loss on the all-fake batch
        errD_fake = criterion(output, label)
        # Calculate the gradients for this batch, accumulated (summed) with previous gradients
        errD_fake.backward()
        D_G_z1 = output.mean().item()
        # Compute error of D as sum over the fake and the real batches
        errD = errD_real + errD_fake
        # Update D
        optimizerD.step()

        tD_end = time.perf_counter()
        ttrain_D += tD_end - tD_start

        ############################
        # (2) Update G network: maximize log(D(G(z)))
        ###########################
        tG_start = time.perf_counter()
        netG.zero_grad()
        label.fill_(real_label)  # fake labels are real for generator cost
        # Since we just updated D, perform another forward pass of all-fake batch through D
        output = netD(fake).view(-1)
        # Calculate G's loss based on this output
        errG = criterion(output, label)
        # Calculate gradients for G
        errG.backward()
        D_G_z2 = output.mean().item()
        # Update G
        optimizerG.step()

        tG_end = time.perf_counter()
        ttrain_G += tG_end - tG_start

        ttrain_GAN = ttrain_D + ttrain_G

        # Read each scalar loss once; it is used for both logging and plotting
        loss_D = errD.item()
        loss_G = errG.item()

        # Output training stats
        if i % log_every == 0:
            print(f"[{epoch + 1}/{num_epochs}][{i}/{n_batches}]"
                  f"\tLoss_D: {loss_D:.4f}\tLoss_G: {loss_G:.4f}"
                  f"\tD(x): {D_x:.4f}\tD(G(z)): {D_G_z1:.4f} / {D_G_z2:.4f}")

        # Save Losses for plotting later
        G_losses.append(loss_G)
        D_losses.append(loss_D)

        # Check how the generator is doing by saving G's output on fixed_noise.
        # The very last iteration is always captured so img_list[-1] reflects
        # the final generator. Note that `fake` is rebound here to the
        # fixed-noise samples (on the CPU).
        is_last_step = (epoch == num_epochs - 1) and (i == n_batches - 1)
        if (iters % sample_every == 0) or is_last_step:
            with torch.no_grad():
                fake = netG(fixed_noise).detach().cpu()
            img_list.append(vutils.make_grid(fake, padding=2, normalize=True))

        iters += 1

    # Release unused cached GPU memory once per epoch. Doing this after every
    # batch makes the caching allocator re-request memory on each step, which
    # slows training without lowering the memory the tensors actually need.
    torch.cuda.empty_cache()

    # Save G's weights at epochs 10 and 50
    # if epoch == 10 or epoch == 50:
    #     torch.save(netG.state_dict(), f"{save_dir}/generator_epoch_{epoch}.pth")

print(f">> GAN training time: {ttrain_GAN}")
print(f"  -- Discriminator training time: {ttrain_D}")
print(f"  -- Generator training time: {ttrain_G}")

# Persist the final weights of both networks
torch.save(netG.state_dict(), f"{save_dir}/gen_{model_suffix}_v{version}_nz={nz}_epochs={num_epochs}.pth")
torch.save(netD.state_dict(), f"{save_dir}/dis_{model_suffix}_v{version}_nz={nz}_epochs={num_epochs}.pth")

In [None]:
# Per-iteration generator and discriminator losses on a single set of axes
fig, ax = plt.subplots(figsize=(10, 5))
ax.set_title("Generator and Discriminator Loss During Training")
for series, tag in ((G_losses, "G"), (D_losses, "D")):
    ax.plot(series, label=tag)
ax.set_xlabel("iterations")
ax.set_ylabel("Loss")
ax.legend()

# Save next to the other training plots, then display
figname = f"{plts_dir}/GAN_loss_{model_suffix}_v{version}_nz={nz}_epochs={num_epochs}.jpg"
fig.savefig(figname, bbox_inches="tight")
plt.show()

In [None]:
# Grab one batch of real images; real_batch[0] holds the images
real_batch = next(iter(dataloader))

# Build the grid of (at most 64) real images directly on the CPU: the batch is
# only being displayed, so there is no need to move it to the training device
# and back.
real_grid = vutils.make_grid(real_batch[0][:64], padding=5, normalize=True)

# Plot the real images
plt.figure(figsize=(15, 15))
plt.subplot(1, 2, 1)
plt.axis("off")
plt.title("Real Images")
# make_grid returns channels-first (C, H, W); imshow expects (H, W, C)
plt.imshow(real_grid.permute(1, 2, 0).numpy())

# Plot the fake images from the last epoch (last grid saved during training)
plt.subplot(1, 2, 2)
plt.axis("off")
plt.title("Fake Images")
plt.imshow(img_list[-1].permute(1, 2, 0).numpy())

figname = f"{plts_dir}/Real_vs_Generated_{model_suffix}_v{version}_nz={nz}_epochs={num_epochs}.jpg"
plt.savefig(figname, bbox_inches="tight")
plt.show()