# DCGAN - SageMaker Training

This notebook is a wrapper around `run_dcgan.py`, provided for those who prefer
to test training on SageMaker. It does **NOT** take full advantage of the Gaudi accelerators; I would direct you to the docs for the fully-migrated training experience on `DL1` instances.

**WARNING:** On SageMaker, neither the `conda_amazonei_pytorch_latest_p37` (on `notebook-al2-v1`) nor the `conda_amazonei_pytorch_latest_p36` (on `notebook-al1-v1`) kernel is suitable for this notebook. **<mark>Please use the base conda_python3 environment and install the module's dependencies with the cell below!</mark>**

-----------------

In [1]:
%%capture
# %pip installs into the running kernel's environment; `!pip3` may target a different Python
%pip install ../model ffmpeg
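Because the warning above hinges on which environment the kernel runs in, a quick check after the install cell can confirm that the kernel actually sees the dependencies. This is a minimal sketch: the package list simply mirrors this notebook's imports, and it only tests whether each top-level package is importable, not which version is installed.

```python
import importlib.util
import sys


def missing_packages(names):
    """Return the top-level package names the current kernel cannot import."""
    return [name for name in names if importlib.util.find_spec(name) is None]


# Path of the interpreter backing this kernel (should belong to conda_python3).
print(sys.executable)
# Names mirror the imports used in this notebook; an empty list means all were found.
print(missing_packages(["numpy", "matplotlib", "torch", "torchvision", "msls"]))
```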

In [None]:
# General Deps
import numpy as np
import matplotlib.pyplot as plt
import torch.multiprocessing as mp

# Torch Deps
import torch
import torchvision
import torchvision.utils as vutils

# DCGAN
import msls.gpu_dcgan as dcgan
import msls.dcgan_utils as utils

In [1]:
## Set Training Params...
TOTAL_NUM_EPOCHS = 16
START_EPOCH = 0
ENABLE_PROFILING = False
ENABLE_LOGGING = True
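The entry point is named `start_or_resume_training_run`, so `START_EPOCH` presumably selects where a resumed run picks up. The sketch below assumes (check `gpu_dcgan.py` to confirm) that `TOTAL_NUM_EPOCHS` is the absolute final epoch count and `START_EPOCH` is a 0-based epoch index; `epochs_to_run` is a hypothetical helper, not part of the `msls` package.

```python
def epochs_to_run(start_epoch: int, total_epochs: int) -> range:
    """Epoch indices a fresh or resumed run would cover.

    Assumption: total_epochs is the absolute final epoch count and start_epoch
    is the 0-based epoch to resume from, so start_epoch=0 means a fresh run.
    """
    if not 0 <= start_epoch <= total_epochs:
        raise ValueError(f"start_epoch={start_epoch} must be in [0, {total_epochs}]")
    return range(start_epoch, total_epochs)


print(list(epochs_to_run(0, 16)))   # fresh run: epochs 0 through 15
print(list(epochs_to_run(12, 16)))  # resumed run: [12, 13, 14, 15]
```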

## Data Loading & Transformations 
--------------

In [None]:
# Init Model and Training Configs w. Default Values - See gaudi_dcgan.py and dcgan_utils.py 
# for descriptions.
model_cfg = utils.ModelCheckpointConfig(
    name="sagemaker_demo_model",  # Custom name to tell GPU-trained from Gaudi-trained checkpoints
    root="/efs/trained_model",
    save_frequency=1,
    log_frequency=50,
)

train_cfg = utils.TrainingConfig(
    dev=torch.device("cuda"),  ## For illustrative purposes. Again, please train on EC2.
    data_root="/efs/imgs/train_val/zurich/",
)

# Initialize dataloader from the training config...
dataloader = dcgan.get_msls_dataloader(0, train_cfg)
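`save_frequency` and `log_frequency` control how often the run checkpoints and logs. The sketch below is a hypothetical stand-in for those two fields of `utils.ModelCheckpointConfig`, written under assumptions that `dcgan_utils.py` should confirm: checkpoints are counted in epochs and logging in training iterations. It only illustrates the usual modulo-based schedule, not the module's actual logic.

```python
from dataclasses import dataclass


@dataclass
class Schedule:
    """Hypothetical stand-in for the frequency fields of utils.ModelCheckpointConfig."""

    save_frequency: int = 1   # assumed unit: epochs
    log_frequency: int = 50   # assumed unit: training iterations

    def should_save(self, epoch: int) -> bool:
        # Save after every `save_frequency`-th completed epoch (epochs are 0-based).
        return (epoch + 1) % self.save_frequency == 0

    def should_log(self, iteration: int) -> bool:
        # Log on iteration 0 and every `log_frequency` iterations after that.
        return iteration % self.log_frequency == 0


sched = Schedule()
print([i for i in range(200) if sched.should_log(i)])  # [0, 50, 100, 150]
print(all(sched.should_save(e) for e in range(16)))    # True: save_frequency=1 saves every epoch
```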

In [None]:
# Now we check that the data.DataLoader is correct and the training data look OK w. the
# default transforms applied...
real_batch = next(iter(dataloader))

plt.figure(figsize=(16, 16))
plt.axis("off")
plt.title("Training Images")

# Plot a grid of sample training images (first 16 of the batch)
grid = vutils.make_grid(
    real_batch[0].to(train_cfg.dev)[:16],
    padding=2,
    normalize=True
).cpu()

# make_grid returns a (C, H, W) tensor; matplotlib expects (H, W, C)
plt.imshow(
    np.transpose(grid, (1, 2, 0))
)
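To make the `make_grid(..., normalize=True)` plus transpose step concrete, here is a simplified NumPy sketch of what it does: min-max scale the whole batch into [0, 1] using one global range, tile the images with padding (8 per row by default), and move channels last for `plt.imshow`. `make_grid_np` is a hypothetical helper for illustration, not a drop-in replacement for the torchvision function (it skips, for example, single-channel expansion and `scale_each`).

```python
import math

import numpy as np


def make_grid_np(batch: np.ndarray, nrow: int = 8, padding: int = 2,
                 pad_value: float = 0.0) -> np.ndarray:
    """Tile an (N, C, H, W) batch into one (H', W', C) image for plt.imshow."""
    n, c, h, w = batch.shape
    # Min-max scale the whole batch into [0, 1] with one global range.
    lo, hi = batch.min(), batch.max()
    scaled = (batch - lo) / max(hi - lo, 1e-5)

    ncols = min(nrow, n)
    nrows = math.ceil(n / ncols)
    cell_h, cell_w = h + padding, w + padding
    # Background canvas filled with the padding value.
    grid = np.full((c, nrows * cell_h + padding, ncols * cell_w + padding),
                   pad_value, dtype=scaled.dtype)
    for k in range(n):
        r, col = divmod(k, ncols)
        top, left = r * cell_h + padding, col * cell_w + padding
        grid[:, top:top + h, left:left + w] = scaled[k]
    # Channels-last layout expected by matplotlib: (C, H, W) -> (H, W, C).
    return np.transpose(grid, (1, 2, 0))
```

With a CPU batch in hand, `plt.imshow(make_grid_np(real_batch[0][:16].numpy()))` would draw a comparable 2 x 8 grid.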

## Model Training 
----------------------------

In [None]:
# Train the Model - Refer to documentation on `dcgan.start_or_resume_training_run` for details
mp.spawn(
    dcgan.start_or_resume_training_run,
    nprocs=torch.cuda.device_count(),
    args=(
        train_cfg,
        model_cfg,
        TOTAL_NUM_EPOCHS,
        START_EPOCH,
        ENABLE_PROFILING,
        ENABLE_LOGGING,
    ),
    join=True
)
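`torch.multiprocessing.spawn` launches `nprocs` processes and calls `fn(i, *args)` in each, where `i` is the process index. So `start_or_resume_training_run` receives the rank first, followed by the six values in `args` in the order listed. The sketch below reproduces only that calling convention, sequentially and in-process; `spawn_in_process` and `fake_worker` are hypothetical names, and the stub does no training.

```python
def spawn_in_process(fn, nprocs, args=()):
    """Call fn the way torch.multiprocessing.spawn does, but sequentially in this process.

    spawn runs fn(i, *args) for i in range(nprocs), each in its own process;
    this stand-in keeps the same argument order and collects the return values.
    """
    return [fn(rank, *args) for rank in range(nprocs)]


def fake_worker(rank, train_cfg, model_cfg, total_epochs, start_epoch, profile, log):
    # Hypothetical stub with the same positional layout as the args tuple above.
    return (rank, total_epochs, start_epoch)


print(spawn_in_process(fake_worker, 2, ("train_cfg", "model_cfg", 16, 0, False, True)))
# [(0, 16, 0), (1, 16, 0)]
```

Note that the real `spawn` runs each call in a separate process and, with `join=True`, blocks until all of them exit, so per-rank return values are not collected the way this stand-in does.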