# 1.a Setup

In [4]:
# General Programming tools
import random
from pathlib import Path

# General ML tools
import numpy as np
import pandas as pd

# DL tools
import torch
import torch.nn as nn
from torch.utils.data import DataLoader

# 1.b Helper

In [5]:
# Seed
def seed_all(seed=None):
    """Seed every RNG the notebook uses (Python, NumPy, PyTorch CPU/CUDA) and
    force deterministic cuDNN kernel selection.

    Args:
        seed: integer seed. If None, a fresh seed is drawn from the OS entropy
            source and printed, so the run can still be reproduced later.

    Returns:
        The seed that was actually applied.
    """
    if seed is None:
        # np.random.seed only accepts values in [0, 2**32 - 1].
        seed = random.SystemRandom().randrange(2**32)
    print("[ Using Seed : ", seed, " ]")

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)           # seeds the CPU and all CUDA devices
    torch.cuda.manual_seed_all(seed)  # explicit; silently ignored without CUDA
    # Trade cuDNN autotuning speed for run-to-run reproducibility.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    return seed

# Training device retriever
def get_training_device():
    """Pick the training device and report the choice.

    Returns:
        torch.device("cuda:0") when CUDA is available, otherwise torch.device("cpu").
    """
    if not torch.cuda.is_available():
        print("[ Running on the CPU ]")
        return torch.device("cpu")

    gpu_name = torch.cuda.get_device_name(0)
    gpu_count = torch.cuda.device_count()
    print("[ GPU: {} (1/{}) ]".format(gpu_name, gpu_count))
    return torch.device("cuda:0")

# Image writer



# 1.c Logging

In [6]:
# Tensorboard

# 1.d Configuration

In [7]:
# Root run configuration: random seed, training-loop and optimiser settings.
config = {
    "seed": 1024,
    "train": {
        "epochs": 70,
    },
    "optimiser": {
        "params": {
            "lr": 0.001,
        },
    },
}

# Dataset location, relative to the notebook's working directory.
# NOTE(review): Kaggle mounts datasets under the absolute path /kaggle/input;
# confirm that this relative "./kaggle/..." path is intended.
DATA_ROOT = Path("./kaggle/input/mnist")

# 2.a Data Preparation

# 3.a Generator

In [8]:
def non_linear_layer(in_feat, out_feat, normalize=True):
    """Build one fully-connected stage as a list of modules:
    Linear -> optional BatchNorm1d -> LeakyReLU(0.2, in place).

    The list is meant to be unpacked into ``nn.Sequential(*...)``.
    """
    linear = nn.Linear(in_feat, out_feat)
    # NOTE(review): BatchNorm1d's 2nd positional argument is `eps`, so this is
    # eps=0.8, not momentum=0.8. Kept to preserve behavior; confirm the intent.
    norm = [nn.BatchNorm1d(out_feat, eps=0.8)] if normalize else []
    activation = nn.LeakyReLU(0.2, inplace=True)
    return [linear, *norm, activation]

class Generator(nn.Module):
    """MLP generator mapping latent noise vectors to images.

    Architecture: 4 fully-connected stages (128 -> 256 -> 512 -> 1024, BatchNorm
    on all but the first), then a Linear to prod(img_shape) and Tanh, so pixel
    values are in [-1, 1].

    Args:
        latent_dim: length of the input noise vector z (100 is a common choice).
        img_shape: (C, H, W) of generated images; the default matches
            single-channel 28x28 MNIST — confirm against the data-prep cell.

    Note: BatchNorm1d in training mode needs more than one sample per batch.
    """

    def __init__(self, latent_dim=100, img_shape=(1, 28, 28)):
        super().__init__()
        self.latent_dim = latent_dim
        self.img_shape = tuple(img_shape)

        self.model = nn.Sequential(
            *non_linear_layer(latent_dim, 128, normalize=False),
            *non_linear_layer(128, 256),
            *non_linear_layer(256, 512),
            *non_linear_layer(512, 1024),
            nn.Linear(1024, int(np.prod(self.img_shape))),
            nn.Tanh(),
        )

    def forward(self, z):
        """Map z of shape (batch, latent_dim) to images of shape (batch, *img_shape)."""
        img = self.model(z)
        return img.view(img.size(0), *self.img_shape)

# 3.b Discriminator

In [9]:
class Discriminator(nn.Module):
    """MLP discriminator scoring images as real (-> 1) or fake (-> 0).

    Images are flattened and passed through Linear(->512) -> LeakyReLU ->
    Linear(->256) -> LeakyReLU -> Linear(->1) -> Sigmoid, so outputs are
    probabilities of shape (batch, 1), as nn.BCELoss expects.

    Args:
        img_shape: (C, H, W) of input images; the default matches
            single-channel 28x28 MNIST — confirm against the data-prep cell.
    """

    def __init__(self, img_shape=(1, 28, 28)):
        super().__init__()
        self.img_shape = tuple(img_shape)
        in_features = int(np.prod(self.img_shape))

        self.model = nn.Sequential(
            nn.Linear(in_features, 512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(256, 1),
            nn.Sigmoid(),
        )

    def forward(self, img):
        """Return validity probabilities of shape (batch, 1) for a batch of images."""
        # reshape (not view) so non-contiguous inputs, e.g. permuted tensors, work.
        img_flat = img.reshape(img.size(0), -1)
        return self.model(img_flat)

# 4.a Train Single Batch

In [None]:
def _train_single_batch_G():
    pass

def _train_single_batch_D():
    pass

# 4.b Train

In [None]:
def _latent_dim_of(model_G):
    """Return the noise-vector length `model_G` expects: its `latent_dim`
    attribute if it has one, else the in_features of its first nn.Linear."""
    latent_dim = getattr(model_G, "latent_dim", None)
    if latent_dim is not None:
        return int(latent_dim)
    for module in model_G.modules():
        if isinstance(module, nn.Linear):
            return module.in_features
    raise ValueError(
        "cannot infer latent dimension: generator has no `latent_dim` "
        "attribute and no nn.Linear layer"
    )


def train(dataloader, model_G, model_D, optimiser_G, optimiser_D, loss_fn, device, config):
    """Train a GAN with alternating discriminator / generator updates.

    For every batch:
      1. Discriminator step: push D(real) towards 1 and D(G(z)) towards 0,
         with G's output computed without gradients.
      2. Generator step: push D(G(z')) towards 1 on fresh noise z'. With
         loss_fn=nn.BCELoss() this is the non-saturating generator loss.

    Args:
        dataloader: iterable of (images, labels) batches; labels are ignored.
            Assumes images are already scaled to the generator's output range
            (e.g. [-1, 1] for a Tanh output) — TODO confirm in data prep.
        model_G: generator; called as model_G(z) with z of shape (batch, latent_dim).
        model_D: discriminator returning probabilities (as BCELoss expects).
        optimiser_G: optimiser over model_G's parameters.
        optimiser_D: optimiser over model_D's parameters.
        loss_fn: binary loss called as loss_fn(prediction, target).
        device: torch.device the batches and noise are placed on.
        config: dict; reads config["train"]["epochs"].

    Returns:
        List with one dict per epoch: {"epoch" (1-based), "loss_D", "loss_G"},
        holding mean batch losses (NaN when the dataloader yields no batches).
    """
    epochs = config["train"]["epochs"]
    latent_dim = _latent_dim_of(model_G)
    history = []

    for epoch in range(epochs):
        model_G.train()
        model_D.train()
        total_D = total_G = 0.0
        n_batches = 0

        for images, _ in dataloader:
            real = images.to(device)
            batch_size = real.size(0)

            # 1. Discriminator step on real vs. (gradient-free) fake images.
            with torch.no_grad():
                fake = model_G(torch.randn(batch_size, latent_dim, device=device))
            optimiser_D.zero_grad()
            pred_real = model_D(real)
            pred_fake = model_D(fake)
            loss_D = (loss_fn(pred_real, torch.ones_like(pred_real))
                      + loss_fn(pred_fake, torch.zeros_like(pred_fake)))
            loss_D.backward()
            optimiser_D.step()

            # 2. Generator step: fresh samples, labelled "real" to fool D.
            #    D also accumulates grads here; they are cleared by the
            #    optimiser_D.zero_grad() call at the next discriminator step.
            optimiser_G.zero_grad()
            pred_gen = model_D(model_G(torch.randn(batch_size, latent_dim, device=device)))
            loss_G = loss_fn(pred_gen, torch.ones_like(pred_gen))
            loss_G.backward()
            optimiser_G.step()

            total_D += loss_D.item()
            total_G += loss_G.item()
            n_batches += 1

        mean_D = total_D / n_batches if n_batches else float("nan")
        mean_G = total_G / n_batches if n_batches else float("nan")
        history.append({"epoch": epoch + 1, "loss_D": mean_D, "loss_G": mean_G})
        print(f"[ Epoch {epoch + 1}/{epochs} ] loss_D={mean_D:.4f} loss_G={mean_G:.4f}")
        # TODO: periodic sample-image logging and model checkpoints
        # (sections 1.b "Image writer" and 1.c "Logging" are still empty).

    return history

# 4.c Run Pipeline 

In [13]:
def run(config, dataloader=None):
    """Seed, build the GAN (generator, discriminator, Adam optimisers,
    BCE loss) on the training device, and train it.

    Args:
        config: root config dict; reads config["seed"] and
            config["optimiser"]["params"] (forwarded as keyword arguments to
            torch.optim.Adam, e.g. {"lr": 0.001}), plus whatever `train` reads.
        dataloader: iterable of (images, labels) batches. When omitted, falls
            back to a module-level `dataloader` (expected from 2.a Data Preparation).

    Returns:
        (model_G, model_D) after training.

    Raises:
        NameError: if no dataloader is passed and none is defined at module level.
    """
    if dataloader is None:
        dataloader = globals().get("dataloader")
        if dataloader is None:
            raise NameError(
                "run(): no dataloader given and no module-level `dataloader` "
                "defined (see 2.a Data Preparation)"
            )

    seed_all(seed=config["seed"])

    # models
    model_G = Generator()
    model_D = Discriminator()

    # CPU/GPU
    device = get_training_device()
    model_G = model_G.to(device)
    model_D = model_D.to(device)

    # optimisers: one per network, over that network's own parameters
    optimiser_params = config["optimiser"]["params"]
    optimiser_G = torch.optim.Adam(model_G.parameters(), **optimiser_params)
    optimiser_D = torch.optim.Adam(model_D.parameters(), **optimiser_params)

    # loss function (the discriminator outputs probabilities)
    loss_fn = torch.nn.BCELoss()

    train(dataloader, model_G, model_D, optimiser_G, optimiser_D, loss_fn, device, config)
    return model_G, model_D

In [None]:
# Launch the full pipeline with the configuration from section 1.d.
run(config=config)