# DCGAN - SageMaker Training

This notebook is a wrapper around `gaudi_dcgan.py`, provided for those who prefer to train on
SageMaker using a GPU instance (or those determined enough to run a notebook server on top of
a DL1 instance).

This notebook does **NOT** take full advantage of the Gaudi accelerators and I would direct you to 
`run_gaudi_dcgan.py` for the fully-migrated training experience.

**NOTE:** On SageMaker, either the `conda_amazonei_pytorch_latest_p37` kernel (on `notebook-al2-v1`) or the `conda_amazonei_pytorch_latest_p36` kernel (on `notebook-al1-v1`) is sufficient for this notebook.

-----------------

In [1]:
%%capture
! pip3 install \
    tensorboard \
    theano

In [None]:
# General Deps
import os
import numpy as np
import matplotlib.pyplot as plt
import torch.multiprocessing as mp

# Torch Deps
import torch
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torchvision.utils as vutils

# DCGAN
import gaudi_dcgan as dcgan

## Data Loading & Transformations 
--------------

In [None]:
# These objects (ALL of them) are usually created + configured via CLI args passed to
# `run_gaudi_dcgan.py`; they are built by hand here only for demonstration.

# Values needed by BOTH the training config and the preview dataset below. They are defined
# once so that the two cannot silently drift apart.
DATA_ROOT = "/efs/imgs/test"
IMG_SIZE = 64

# Init Model and Training Configs w. Default Values - See gaudi_dcgan.py for descriptions. For clarity,
# objects below are initialized with their default values.
model_cfg = dcgan.ModelCheckpointConfig(
    model_name="p3_2xl_sagemaker",  # Custom Model Name To Identify Gaudi vs GPU Trained!
    model_dir="/efs/trained_model",
    save_frequency=1,
    log_frequency=50,
    gen_progress_frequency=250,
)

train_cfg = dcgan.TrainingConfig(
    dev=torch.device("cuda"),  # For illustrative purposes. Again, please train on EC2.
    data_root=DATA_ROOT,
    batch_size=128,
    img_size=IMG_SIZE,
    nc=3,
    nz=100,
    ngf=64,
    ndf=64,
    lr=0.0002,
    beta1=0.5,
    beta2=0.999,
)

# In general, the ImageFolder/Dataloader reads from the directory of images and applies a
# transformation at runtime to generate our training images. See `Data and Transformations`
# section for details.
preview_transform = transforms.Compose(
    [
        # No rotation; random horizontal shift of up to 30% of the image width.
        transforms.RandomAffine(degrees=0, translate=(0.3, 0.0)),
        # Crop a centered square 4x the target size, then shrink it down to the target size.
        transforms.CenterCrop(IMG_SIZE * 4),
        transforms.Resize(IMG_SIZE),
        transforms.ToTensor(),
        # Per-channel (mean, std); these are the commonly used ImageNet statistics.
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
    ]
)
dataset = dset.ImageFolder(root=DATA_ROOT, transform=preview_transform)

# The Habana analog to Pytorch's DataLoader can be more efficient under specific conditions. Here, we
# create the dataloader with *similar* params to those that would cause the ht.DataLoader() to use
# acceleration. NOTE: This step can be slow if images are processed from EFS (see note on using EBS);
dataloader = torch.utils.data.DataLoader(
    dataset,
    shuffle=False,
    num_workers=4,
    pin_memory=True,
    timeout=0,
    batch_size=train_cfg.batch_size,
)

In [None]:
# Check that the dset.ImageFolder && data.DataLoader are correct and the training data look OK
real_batch = next(iter(dataloader))

# Explicit figure/axes handles so the exact figure that is drawn is also the one saved.
fig, ax = plt.subplots(figsize=(16, 16))
ax.axis("off")
ax.set_title("Training Images")

# Plot and Save Sample Training Images.
# `real_batch[0]` holds the image tensors of the batch. Slice to the first 16 *before* moving
# to the device so only the previewed images are transferred, not the whole batch.
preview = real_batch[0][:16].to(train_cfg.dev)
grid = vutils.make_grid(preview, padding=2, normalize=True).cpu()

# make_grid returns channels-first (C, H, W); imshow expects channels-last (H, W, C).
ax.imshow(grid.permute(1, 2, 0).numpy())

# Create Figures && Events Directory if Not Yet Exists. `exist_ok=True` makes this safe to
# re-run and avoids the check-then-create race of a separate `os.path.exists` test.
run_dir = f"{model_cfg.model_dir}/{model_cfg.model_name}"
for subdir in ("figures", "events"):
    os.makedirs(f"{run_dir}/{subdir}", exist_ok=True)

# Save to Disk
fig.savefig(f"{run_dir}/figures/train_samples.png")

## Model Training 
----------------------------

In [None]:
# Set some params...
TOTAL_NUM_EPOCHS = 16
START_EPOCH = 0
ENABLE_PROFILING = False
ENABLE_LOGGING = True

# One worker process is launched per visible CUDA device. With zero devices `mp.spawn` would
# launch no workers and this cell would appear to finish without training anything, so
# fail loudly instead.
NUM_PROCS = torch.cuda.device_count()
if NUM_PROCS < 1:
    raise RuntimeError(
        "No CUDA devices are visible to PyTorch; this notebook must run on a GPU instance."
    )

# Train the Model - Refer to documentation on `dcgan.start_or_resume_training_run` for details.
# `mp.spawn` calls the target as `fn(process_index, *args)`, so the process index is supplied
# automatically ahead of the arguments listed below. `join=True` blocks until all workers exit.
mp.spawn(
    dcgan.start_or_resume_training_run,
    nprocs=NUM_PROCS,
    args=(
        train_cfg,
        model_cfg,
        TOTAL_NUM_EPOCHS,
        START_EPOCH,
        ENABLE_PROFILING,
        ENABLE_LOGGING,
    ),
    join=True,
)