# DCGAN - SageMaker Training

This notebook is a wrapper around `gaudi_dcgan.py` and is provided for those who prefer
to train on SageMaker using a GPU instance (or those who have the determination to run a notebook
server on top of a DL1 instance).

This notebook does **NOT** take full advantage of the Gaudi accelerators and I would direct you to 
`run_gaudi_dcgan.py` for the fully-migrated training experience.

**NOTE:** On SageMaker, either the `conda_amazonei_pytorch_latest_p37` kernel (on `notebook-al2-v1`) OR the `conda_pytorch_p36` kernel (on `notebook-al1-v1`) is sufficient for this notebook.

-----------------

In [None]:
# General Deps
import random
import os
import numpy as np
import matplotlib.pyplot as plt

# Torch Deps
import torch
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torchvision.utils as vutils

# DCGAN
import gaudi_dcgan as dcgan

## Data Loading & Transformations 
--------------

In [None]:
# Notebook-Level Inputs
DATAROOT = "/efs/images/"  # Root directory handed to the ImageFolder dataset
SEED = 215  # Seed applied to the `random` and `torch` RNGs

In [None]:
# Seed the Python and PyTorch RNGs
random.seed(SEED)
torch.manual_seed(SEED)

# Init Training and Model Checkpoint Configs - See gaudi_dcgan.py for descriptions. For clarity,
# every field is spelled out explicitly rather than relying on the class defaults.
#
# NOTE: The training hyperparameters belong on `train_cfg` and the naming/checkpointing options on
# `model_cfg` - the cells below read `train_cfg.img_size`, `train_cfg.batch_size`,
# `model_cfg.model_dir`, and `model_cfg.model_name`.
train_cfg = dcgan.TrainingConfig(
    batch_size=128,
    img_size=64,
    nc=3,
    nz=100,
    ngf=64,
    ndf=64,
    lr=0.0002,
    beta1=0.5,
    beta2=0.999,
)

model_cfg = dcgan.ModelCheckpointConfig(
    model_name="msls_dcgan_ml_p3_8xlarge_001", # Custom Model Name To Identify Gaudi vs GPU Trained!
    model_dir="/efs/trained_model",
    save_frequency=1,
    log_frequency=50,
    gen_progress_frequency=250,
)

In [None]:
%% time

# In general, the ImageFolder/Dataloader reads from the directory of images and applys a transformation
# at runtime to generate our training images. See `Data and Transformations` section for details.
dataset = dset.ImageFolder(
    root=DATAROOT,
    transform=transforms.Compose(
        [
            transforms.RandomAffine(degrees=0, translate=(0.3, 0.0)),
            transforms.CenterCrop(train_cfg.img_size * 4),
            transforms.Resize(train_cfg.img_size),
            transforms.ToTensor(),
            transforms.Normalize(
                (
                    0.5,
                    0.5,
                    0.5,
                ),
                (
                    0.5,
                    0.5,
                    0.5,
                ),
            ),
        ]
    ),
)

# Habana ships an analog to PyTorch's DataLoader that can be more efficient on Gaudi-Accelerated
# instances under specific conditions. The plain PyTorch loader built here uses *similar* params
# to those that would cause the ht.DataLoader() to use acceleration.

# NOTE: Expect this step to be slow while the images are processed (esp. on a new EFS); anecdotally,
# around 8-10 min to load 1MM images (~30GB total)
dataloader = torch.utils.data.DataLoader(
    dataset=dataset,
    batch_size=train_cfg.batch_size,
    shuffle=False,
    num_workers=8,
    pin_memory=True,
    timeout=0,
)

In [None]:
# Check that the dset.ImageFolder && data.DataLoader are correct and the training data look OK
real_batch = next(iter(dataloader))
plt.figure(figsize=(16, 16))
plt.axis("off")
plt.title("Training Images")

# Plot Sample Training Images - take the first 16 images *before* the device transfer so that
# only the images that are actually plotted get copied.
sample_grid = vutils.make_grid(
    real_batch[0][:16].to(train_cfg.dev), padding=2, normalize=True
).cpu()
plt.imshow(np.transpose(sample_grid, (1, 2, 0)))  # (C, H, W) -> (H, W, C) for imshow

# Create Figures Directory if Not Yet Exists - `exist_ok=True` avoids the check-then-create race
figures_dir = f"{model_cfg.model_dir}/{model_cfg.model_name}/figures"
os.makedirs(figures_dir, exist_ok=True)

# Save Sample Training Images
plt.savefig(f"{figures_dir}/train_samples.png")

## Model Training 
----------------------------

In [None]:
# Kick off the training run - Refer to documentation on `dcgan.start_or_resume_training_run` for
# details on the arguments and on the returned `result`.
result = dcgan.start_or_resume_training_run(
    dataloader,
    train_cfg,
    model_cfg,
    n_epochs=64,
    st_epoch=0,
)