In [None]:
# General Deps
import random
import os

import numpy as np
import matplotlib.pyplot as plt

import matplotlib.animation as animation
from IPython.display import HTML

# Torch Deps
import torch
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torchvision.utils as vutils

# DCGAN
import gaudi_dcgan as dcgan

In [None]:
# Seed every RNG this notebook imports from a single constant so a fresh
# "Restart & Run All" starts from the same random state.
SEED = 151
random.seed(SEED)
np.random.seed(SEED)  # NumPy's global RNG is seeded too, alongside random and torch
torch.manual_seed(SEED)

# Init model/training config with the default DCGAN values.
# Later cells read train_cfg.img_size, train_cfg.batch_size and train_cfg.dev.
model_cfg = dcgan.ModelCheckpointConfig()
train_cfg = dcgan.TrainingConfig()

In [None]:
# Root directory for dataset
dataroot = "/efs/images/sample"

# Per-channel mean and std handed to Normalize (three channels, 0.5 each).
channel_mean = (0.5, 0.5, 0.5)
channel_std = (0.5, 0.5, 0.5)

# Preprocessing pipeline: centre-crop to twice the target size, resize down to
# train_cfg.img_size, convert to a tensor, then normalise each channel.
preprocess = transforms.Compose(
    [
        transforms.CenterCrop(train_cfg.img_size * 2),
        transforms.Resize(train_cfg.img_size),
        transforms.ToTensor(),
        transforms.Normalize(channel_mean, channel_std),
    ]
)

# The data is laid out so that an image folder dataset can read it directly.
dataset = dset.ImageFolder(root=dataroot, transform=preprocess)

# Create the dataloader.
# os.cpu_count() can return None when the CPU count cannot be determined, and
# "count - 1" must never go negative, so fall back to one CPU and clamp at 0
# (num_workers=0 makes the DataLoader load batches in the main process).
loader_workers = max((os.cpu_count() or 1) - 1, 0)
dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=train_cfg.batch_size,
    shuffle=True,
    num_workers=loader_workers,
)

# Plot some training images
real_batch = next(iter(dataloader))

# Take the first 16 entries of real_batch[0] *before* the device transfer so
# that only the images being displayed are moved, not the whole batch.
preview_images = real_batch[0][:16].to(train_cfg.dev)
preview_grid = vutils.make_grid(preview_images, padding=2, normalize=True).cpu()

plt.figure(figsize=(16, 16))
plt.axis("off")
plt.title("Training Images")
# The grid is transposed with axes (1, 2, 0) before being handed to imshow.
plt.imshow(np.transpose(preview_grid, (1, 2, 0)))
# Render explicitly, like the other plotting cells, so the cell does not end
# on the object returned by imshow.
plt.show()

In [None]:
# Training schedule for this run.
NUM_EPOCHS = 16
START_EPOCH = 0

# Launch (or resume) the DCGAN training run. Later cells read
# result["losses"] and result["img_list"] from the returned value.
result = dcgan.start_or_resume_training_run(
    dataloader,
    train_cfg,
    model_cfg,
    num_epochs=NUM_EPOCHS,
    start_epoch=START_EPOCH,
)

In [None]:
# Plot the two loss series stored under result["losses"] on a single axes.
loss_fig, loss_ax = plt.subplots(figsize=(10, 5))
loss_ax.set_title("Generator and Discriminator Loss During Training")
for series_key, series_label in (("_G", "G"), ("_D", "D")):
    loss_ax.plot(result["losses"][series_key], label=series_label)
loss_ax.set_xlabel("iterations")
loss_ax.set_ylabel("Loss")
loss_ax.legend()
plt.show()

In [None]:
# Build one animation frame per entry in result["img_list"].
fig = plt.figure(figsize=(8, 8))
plt.axis("off")
ims = [
    [plt.imshow(np.transpose(grid, (1, 2, 0)), animated=True)]
    for grid in result["img_list"]
]

# interval / repeat_delay are in milliseconds: one frame per second, with a
# one-second pause before the animation repeats.
ani = animation.ArtistAnimation(fig, ims, interval=1000, repeat_delay=1000, blit=True)

# Render the HTML/JS player while the figure is still open, then close the
# figure so the notebook does not also emit a static copy of it.
content = HTML(ani.to_jshtml())
plt.close(fig)

# Bare last expression so the notebook actually displays the player;
# previously `content` was only assigned and the animation never appeared.
content

In [None]:
# Grab a batch of real images from the dataloader
real_batch = next(iter(dataloader))

# Slice to the 64 entries that are displayed *before* moving them to the
# device, so the rest of the batch is not transferred just to be discarded.
real_grid = vutils.make_grid(
    real_batch[0][:64].to(train_cfg.dev), padding=5, normalize=True
).cpu()

# Plot the real images (left panel)
plt.figure(figsize=(15, 15))
plt.subplot(1, 2, 1)
plt.axis("off")
plt.title("Real Images")
plt.imshow(np.transpose(real_grid, (1, 2, 0)))

# Plot the fake images (right panel): the last entry of result["img_list"]
plt.subplot(1, 2, 2)
plt.axis("off")
plt.title("Fake Images")
plt.imshow(np.transpose(result["img_list"][-1], (1, 2, 0)))
plt.show()