In [None]:
# General Deps - Use Kernel: conda_amazonei_pytorch_latest_p37
import random
import os

import numpy as np
import matplotlib.pyplot as plt

import matplotlib.animation as animation
from IPython.display import HTML

In [None]:
# Torch Deps
import torch
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torchvision.utils as vutils

# DCGAN
import gaudi_dcgan as dcgan

In [None]:
## Sample Usage on Command Line
# Reference invocation of the standalone script. Executing this cell runs
# run_gaudi_dcgan.py in a subprocess with the same data root ("/efs/images/")
# and seed (215) used by the cells below, and --s_epoch 0 / --n_epoch 16
# mirroring st_epoch=0 / n_epochs=16 in the "Run Training" cell.
# NOTE(review): this presumably performs a full training run on its own, so a
# Restart & Run All trains twice (here and in-notebook) -- skip it unless that
# is intended.
! python3 run_gaudi_dcgan.py \
    --dataroot "/efs/images/" \
    --seed 215 \
    --name msls_2022_01_24_001 \
    --s_epoch 0 \
    --n_epoch 16

In [None]:
# Seed Model
# One seed for every RNG library imported by this notebook; keep it in sync
# with `--seed` in the command-line example.
SEED = 215
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)  # seeds the RNG on all devices, CUDA included

# Init Model Config w. Default DCGAN Values
model_cfg = dcgan.ModelCheckpointConfig()

# Prefer the first GPU, but fall back to CPU so the notebook still runs
# (slowly) on a host without CUDA instead of failing on the first transfer.
train_cfg = dcgan.TrainingConfig(
    dev=torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
)

In [None]:
%%time

# Root directory for dataset
dataroot = "/efs/images/"

# We can use an image folder dataset the way we have it setup. Depending on the size 
# of the training directory this can take a little to instatiate; about 5-8 min for 
# 25GB (also depends on EFS burst)
dataset = dset.ImageFolder(
    root=dataroot,
    transform=transforms.Compose(
        [
            transforms.RandomAffine(degrees = 0, translate = (0.2, 0.0)),
            transforms.CenterCrop(train_cfg.img_size * 4),
            transforms.Resize(train_cfg.img_size),
            transforms.ToTensor(),
            transforms.Normalize(
                (
                    0.5,
                    0.5,
                    0.5,
                ),
                (
                    0.5,
                    0.5,
                    0.5,
                ),
            ),
        ]
    ),
)

In [None]:
%%time

# Create the dataloader with Similar Params to Habana
dataloader = torch.utils.data.DataLoader(
    dataset,
    shuffle=False,
    num_workers=8,
    pin_memory=True,
    timeout=0,
    batch_size=train_cfg.batch_size
)

# Plot some training images
real_batch = next(iter(dataloader))
plt.figure(figsize=(16, 16))
plt.axis("off")
plt.title("Training Images")
plt.imshow(
    np.transpose(
        vutils.make_grid(
            real_batch[0].to(train_cfg.dev)[:16], padding=2, normalize=True
        ).cpu(),
        (1, 2, 0),
    )
)

In [None]:
# Run Training
# Launch the run beginning at epoch 0 for 16 epochs; the returned `result`
# (losses and image grids) feeds the plotting cells below.
result = dcgan.start_or_resume_training_run(
    dataloader,
    train_cfg,
    model_cfg,
    st_epoch=0,
    n_epochs=16,
)

In [None]:
# Plot the Losses Over Time
#   x-axis: training step (index into each recorded loss list)
#   y-axis: loss value
loss_fig, loss_ax = plt.subplots(figsize=(10, 5))
loss_ax.set_title("Generator and Discriminator Loss During Training")
for loss_key, curve_label in (("_G", "G"), ("_D", "D")):
    loss_ax.plot(result["losses"][loss_key], label=curve_label)
loss_ax.set_xlabel("Iterations")
loss_ax.set_ylabel("Loss")
loss_ax.legend()
plt.show()

In [None]:
# Animate the image grids stored in result["img_list"] (presumably generator
# snapshots saved during training -- confirm), one frame per entry.
fig = plt.figure(figsize=(8, 8))
plt.axis("off")
ims = [
    [plt.imshow(np.transpose(i, (1, 2, 0)), animated=True)] for i in result["img_list"]
]

ani = animation.ArtistAnimation(fig, ims, interval=1000, repeat_delay=1000, blit=True)

# Render as an interactive JS player. The HTML object must be the cell's last
# expression to be displayed; merely assigning it (as before) showed nothing.
# Closing the figure after rendering keeps the inline backend from also
# emitting a static image of the final frame.
content = HTML(ani.to_jshtml())
plt.close(fig)
content

In [None]:
# Grab a batch of real images from the dataloader
real_batch = next(iter(dataloader))

# Side-by-side: real images (left) vs. the last grid in result["img_list"] (right).
plt.figure(figsize=(15, 15))

# Plot the real images. Slice to the 64 displayed images on the CPU first;
# moving the whole batch to the training device and straight back is wasted work.
plt.subplot(1, 2, 1)
plt.axis("off")
plt.title("Real Images")
plt.imshow(
    np.transpose(
        vutils.make_grid(real_batch[0][:64], padding=5, normalize=True),
        (1, 2, 0),
    )
)

# Plot the fake images: the last entry of result["img_list"]
plt.subplot(1, 2, 2)
plt.axis("off")
plt.title("Fake Images")
plt.imshow(np.transpose(result["img_list"][-1], (1, 2, 0)))
plt.show()

In [None]:
# Generate a Few Sample Images....
imgs = dcgan.generate_fake_samples(
    n_samples=16,
    train_cfg=train_cfg,
    model_cfg=model_cfg,
    as_of_epoch=16,  # presumably the checkpoint from the 16-epoch run above -- confirm
)

# Build the grid on the CPU directly rather than the old device -> .cpu() round
# trip. detach() is a no-op for tensors that don't track gradients and keeps the
# NumPy conversion in np.transpose safe if the generator output still does.
sample_grid = vutils.make_grid(imgs.detach().cpu(), padding=2, normalize=True)

plt.figure(figsize=(15, 15))
plt.axis("off")
plt.title("Generated Samples")
plt.imshow(np.transpose(sample_grid, (1, 2, 0)))
plt.show()