# DCGAN - Parzen Window-based Log-Likelihood Estimates

This notebook is a wrapper around the Parzen window log-likelihood estimator described in the original GAN paper (Goodfellow et al., 2014) and implemented in its
[reference code](https://github.com/goodfeli/adversarial/blob/master/parzen_ll.py).

> We estimate probability of the test set data under $p_g$ by fitting a Gaussian Parzen window to the
> samples generated with G and reporting the log-likelihood under this distribution. The σ parameter of the Gaussians was obtained by cross validation on the validation set. This procedure was introduced in Breuleux et al. [8] and used for various generative models for which the exact likelihood
> is not tractable.
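
For reference, a Gaussian Parzen window places an isotropic Gaussian of width σ on each of the $N$ generated samples $s_i \in \mathbb{R}^d$ and scores a point $x$ as $\log \hat{p}(x) = \log \frac{1}{N} \sum_{i=1}^{N} \exp\left(-\frac{\lVert x - s_i \rVert^2}{2\sigma^2}\right) - \frac{d}{2} \log(2\pi\sigma^2)$. The block below is a minimal NumPy sketch of that formula, not the Theano implementation in `parzen_ll.py`; the name `parzen_log_likelihood` and the toy arrays are hypothetical.

```python
import numpy as np

def parzen_log_likelihood(x, samples, sigma):
    """Per-row log-density of `x` under a Gaussian Parzen window fit to `samples`.

    x: [M x D] points to score; samples: [N x D] kernel centers; sigma: kernel width.
    """
    x = np.asarray(x, dtype=np.float64)
    samples = np.asarray(samples, dtype=np.float64)
    d = samples.shape[1]
    # Pairwise squared distances [M x N] via ||x||^2 - 2 x.s + ||s||^2,
    # which avoids materializing an [M x N x D] array
    sq_dist = (x**2).sum(1)[:, None] - 2.0 * x @ samples.T + (samples**2).sum(1)[None, :]
    exponent = -0.5 * np.maximum(sq_dist, 0.0) / sigma**2
    # Numerically stable log-mean-exp over the N kernels
    max_e = exponent.max(axis=1, keepdims=True)
    log_mean_exp = max_e[:, 0] + np.log(np.exp(exponent - max_e).mean(axis=1))
    # Gaussian normalization constant: (D / 2) * log(2 * pi * sigma^2)
    log_norm = d * np.log(sigma * np.sqrt(2.0 * np.pi))
    return log_mean_exp - log_norm

rng = np.random.default_rng(0)
toy_samples = rng.normal(size=(100, 8))  # stand-in for flattened G(Z)
toy_test = rng.normal(size=(10, 8))      # stand-in for flattened held-out images
toy_ll = parzen_log_likelihood(toy_test, toy_samples, sigma=0.5)  # shape: (10,)
```

With a single kernel centered on the query point, the estimate reduces to the log of the Gaussian density at its mean, which is a quick sanity check for the normalization term.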

Slight modifications are made in the local file (`parzen_ll.py`) for the following:

- Migrate from Python2 -> Python3 syntax
- Add comments and docstrings for clarity

The goal of this project is not to develop a new framework for evaluating generative models; consequently, the log-likelihoods calculated here are meant only for internal comparison between models. As you'll notice, I do not use `MNIST`, `TFD`, or `CIFAR-10` as a validation set, but rather a sample of MSLS images held out from training.

**Note:** Only tested on `conda_amazonei_pytorch_latest_p3X` and `python_latest_p3X`

--------------------

In [None]:
%%capture
# These are NOT on all `conda_amazonei_pytorch_latest_p3X` or `conda_pytorch_p3X` builds
! pip3 install ./../model

In [None]:
# General
import numpy as np
import matplotlib.pyplot as plt
import datetime
import copy

# Torch Deps
import torch
import torchvision
import torchvision.utils as vutils
import torchvision.datasets as dset
import torchvision.transforms as transforms

# DCGAN
import msls.gpu_dcgan as dcgan
import msls.dcgan_utils as utils
import msls.gan as gan
import msls.evaluation as evaluation

In [None]:
# Inputs
# NOTE:
#    - The directory (`CV_DATAROOT`) is assumed to be populated with a holdout of images from MSLS.

# Normally these don't need to be set; they are set here so the data can be reshaped
# for cross-validation with the Parzen estimator.
CV_DATAROOT = "/efs/imgs/test/miami" 
IMG_SIZE = 64
BATCH_SIZE = 64
VALIDATION_SAMPLE_SIZE = 100
EPOCH_FREQUENCY = 1
ESTIMATION_EPOCH = 16
N_SAMPLES = 100

# See `Data and Transformations` section for details.
TORCH_DL_COMPOSED_TRANSFORMS = transforms.Compose([
    transforms.RandomAffine(degrees=0, translate=(0.3, 0.0)),
    transforms.CenterCrop(IMG_SIZE * 4),
    transforms.Resize(IMG_SIZE),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])


In [None]:
# ImageFolder/DataLoader reads from the directory of images and applies a transformation
dataset = torchvision.datasets.ImageFolder(
    root=CV_DATAROOT,
    transform=TORCH_DL_COMPOSED_TRANSFORMS
)

# Use LimitDataset wrapper to ensure we're able to transfer into memory safely
# WARNING: By setting `batch_size=len(limited_msls_data)` below, we're forcing the loader to read all
# data in a single iteration. This is "safe" ONLY because `utils.LimitDataset` caps the number of
# images (and therefore the memory) that operation will use!
limited_msls_data = utils.LimitDataset(
    dataset, min(VALIDATION_SAMPLE_SIZE, len(dataset))
)

msls_real_dataloader = torch.utils.data.DataLoader(
    limited_msls_data,
    shuffle=False,
    num_workers=0,
    pin_memory=True,
    batch_size=len(limited_msls_data),  # Force Single Batch
)

# The implementation of Parzen LL expects data in a particular shape; convert from
# [N x 3 x 64 x 64] => [N x (3 x 64 x 64)]
msls_real_data = next(iter(msls_real_dataloader))[0].numpy()
print(f"Shape after Fetch From Loader: {msls_real_data.shape}")

msls_real_data = msls_real_data.reshape(
    (msls_real_data.shape[0], np.prod(msls_real_data.shape[1:]))
)

print(f"Shape after Reshape: {msls_real_data.shape}")

## Initialize Model and Training Configs 

Model and training configs are required to specify which model to load and to generate samples from `G` as of a specific epoch.

-------

In [None]:
# Initialize Model and Training Configs w. default args. 

model_cfg = dcgan.ModelCheckpointConfig(
    name="msls-dcgan-128",  # Custom Model Name To Identify Gaudi vs GPU Trained!
    root="/efs/trained_model",
    save_frequency=1,
    log_frequency=50,
)


# Train Config: Must Have Same Size Params as Model...
# python3 -m msls.run_dcgan \
#     -c '{"name": "msls-dcgan-128", "root": "/efs/trained_model/", "log_frequency": 50, "save_frequency": 1}' \
#     -t '{"nc": 3, "nz": 256, "ngf": 256, "ndf": 64, "lr": 0.0002, "beta1": 0.5, "beta2": 0.999, "batch_size": 256, "img_size": 64, "weight_decay": 0.05}'\
#     --s_epoch 0 \
#     --n_epoch 16 \
#     --dataroot /data/imgs/train_val/helsinki \
#     --logging True \
#     --profile True  \
#     --s3_bucket 'dmw2151-habana-model-outputs'

train_cfg = dcgan.TrainingConfig(
    dev = torch.device("cpu"),
    data_root = CV_DATAROOT,
    nz = 256,
    nc = 3,
    ngf = 256,
    ndf = 64,  
)

## Calculating Log-Likelihood For a Single Epoch

In the section below we generate samples from `G` as of a specific epoch, `ESTIMATION_EPOCH`. We then use these samples to select `sigma` from a set of candidate values by cross-validation, and finally estimate the log-likelihood of the held-out MSLS data.

--------------

In [None]:
## Restore the Generator for Creating New Images
G, opt_G = train_cfg.get_network(gan.Generator64, device_rank=0)

checkpoint = utils.get_checkpoint(
    path=f"{model_cfg.root}/{model_cfg.name}/checkpoint_{ESTIMATION_EPOCH}.pt",
    cpu=True,
)

utils.restore_G_for_inference(checkpoint, G)

# Generate N samples
Z = torch.randn(N_SAMPLES, train_cfg.nz, 1, 1, device=train_cfg.dev)
generated_data = G(Z).detach().numpy()

# The implementation of Parzen LL expects data in a particular shape; convert from
# [N x 3 x 64 x 64] => [N x (3 x 64 x 64)]
generated_data = generated_data.reshape(
    (generated_data.shape[0], np.prod(generated_data.shape[1:]))
)

# Estimate Sigma on G(Z) and MSLS Data...
sigma = evaluation.cross_validate_sigma(
    generated_data,
    msls_real_data,
    np.logspace(-1.0, 0, num=10),  # Default sigma grid from the reference code: 10 values in [0.1, 1.0]
    BATCH_SIZE,  # Default Batch Size
)

# Fit Parzen Estimator && Calculate LL
parzen = evaluation.theano_parzen(generated_data, sigma)

ll = evaluation.get_nll(msls_real_data, parzen, batch_size=BATCH_SIZE)

se = ll.std() / np.sqrt(msls_real_data.shape[0])

print(f"Log-Likelihood Results:\t[σ : {sigma:.4f}]\t[ll: {ll.mean():.4f}]\t[se: {se:.4f}]")

## Calculate Log-Likelihood For Multiple Epochs

This section repeats the procedure above across multiple checkpoints: for each epoch, restore `G`, generate samples from the same latent vectors `Z`, and estimate the log-likelihood of the held-out data.
`sigma` is held fixed at the value selected above. This may not be the best way to assess the quality of a GAN's output, but it does show progress over training.
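
Each epoch is summarized by the mean per-image log-likelihood and its standard error, the standard deviation divided by the square root of the number of held-out images. The block below is a small sketch of that summary; `summarize_ll` is a hypothetical helper and the input values are toy numbers, not model output.

```python
import numpy as np

def summarize_ll(per_image_ll):
    """Return (mean, standard error) of per-image log-likelihoods."""
    per_image_ll = np.asarray(per_image_ll, dtype=np.float64)
    # ndarray.std() defaults to ddof=0 (population standard deviation)
    se = per_image_ll.std() / np.sqrt(per_image_ll.shape[0])
    return per_image_ll.mean(), se

# Toy values for illustration only; not output from any model
mean_ll, se = summarize_ll([-3.0, -1.0, -2.0, -2.0])
```

With only 100 held-out images the standard error can be large relative to the differences between epochs, which is worth keeping in mind when reading the plot.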

----------

In [None]:
# Fit Parzen Estimator && Calculate LL
log_likelihoods = []
std_errs = []

# Sample N latent vectors once so every epoch is evaluated on the same Z
Z = torch.randn(N_SAMPLES, train_cfg.nz, 1, 1, device=train_cfg.dev)

for cur_epoch in range(0, ESTIMATION_EPOCH, EPOCH_FREQUENCY):
    
    # Restore G to a Particular Epoch
    checkpoint = utils.get_checkpoint(
        path=f"{model_cfg.root}/{model_cfg.name}/checkpoint_{cur_epoch}.pt",
        cpu=True,
    )
    
    utils.restore_G_for_inference(checkpoint, G)
    
    # Generate Data as of Epoch
    generated_data = G(Z).detach().numpy()

    generated_data = generated_data.reshape(
        (generated_data.shape[0], np.prod(generated_data.shape[1:]))
    )

    parzen = evaluation.theano_parzen(generated_data, sigma)

    # Estimate Log-Likelihood
    ll = evaluation.get_nll(msls_real_data, parzen, batch_size=BATCH_SIZE)

    se = ll.std() / np.sqrt(msls_real_data.shape[0])
    log_likelihoods.append(ll.mean())
    std_errs.append(se)
    
    print(f"[Epoch: {cur_epoch}]\t[σ : {sigma:.4f}]\t[ll: {ll.mean():.4f}]\t[se: {se:.4f}]")

In [None]:
# Plot LL over Range && Save Figure to Disk

plt.figure(figsize=(12, 6))
plt.title(f"Log-Likelihood over Training Epochs - {model_cfg.name}")

plt.plot(range(0, ESTIMATION_EPOCH, EPOCH_FREQUENCY), log_likelihoods)

plt.errorbar(
    range(0, ESTIMATION_EPOCH, EPOCH_FREQUENCY), log_likelihoods, yerr=std_errs, fmt="o"
)

plt.xlabel("Epoch")
plt.ylabel("Log-Likelihood")

# Save before `plt.show()`; saving after the figure has been shown can write a blank image
plt.savefig(
    f"{model_cfg.root}/{model_cfg.name}/figures/log_likelihood.png"
)

plt.show()