## Adversarial Autoencoder (Basic Implementation)

In [1]:
import numpy as np
import torch
import matplotlib.pyplot as plt
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from torch import nn
from adversarial import AdversarialAutoencoder
from likelihood import cross_validate_sigma, estimate_log_likelihood

## Dataset Configuration

In [3]:
def configure_mnist(batch_size=100, root='./data'):
    """Load MNIST and return flattened tensors plus a training DataLoader.

    Args:
        batch_size: Number of samples per batch produced by the training
            loader.
        root: Directory in which the MNIST files are stored; they are
            downloaded there when missing.

    Returns:
        Tuple ``(X_train, X_test, Y_train, Y_test, train_loader)``.
        ``X_train`` / ``X_test`` are the raw image tensors converted to
        float, divided by 255 and reshaped to ``(-1, 784)``.
        ``Y_train`` / ``Y_test`` are the datasets' ``targets``.
        ``train_loader`` shuffles over the full training dataset and applies
        the tensor-conversion + flatten transform to each image.
    """
    # Per-sample transform used by the DataLoader: convert to a tensor
    # (ToTensor scales pixel values to [0, 1]) and flatten to one dimension.
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Lambda(lambda x: x.view(-1)),
    ])

    # Load the training and test datasets (downloaded into `root` if absent)
    train_dataset = datasets.MNIST(root=root, train=True, transform=transform, download=True)
    test_dataset = datasets.MNIST(root=root, train=False, transform=transform, download=True)

    # Whole-dataset tensors built directly from the raw uint8 images; the
    # transform above is not applied here, so scale and flatten by hand.
    X_train = train_dataset.data.float().div(255).view(-1, 28 * 28)
    X_test = test_dataset.data.float().div(255).view(-1, 28 * 28)

    # Labels corresponding to the rows of X_train / X_test
    Y_train = train_dataset.targets
    Y_test = test_dataset.targets

    # Shuffled mini-batch loader over the training dataset
    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True
    )

    # Return data and dataloader as a tuple
    return X_train, X_test, Y_train, Y_test, train_loader






## Paper Configuration

In [4]:
# Input size and mini-batch size (passed as `input_dim` / `batch_size`)
INPUT_DIM = 28 * 28
BATCH_SIZE = 100

# Hidden widths (passed as `ae_hidden` / `dc_hidden`) and latent size
AE_HIDDEN = DC_HIDDEN = 1000
LATENT_DIM = 8

# Value handed to training and sampling as `prior_std`
PRIOR_STD = 5.0

# Reconstruction loss and decoder output option
recon_loss = nn.MSELoss()
use_decoder_sigmoid = True

# Initial learning rates
init_recon_lr = 0.01
init_gen_lr = 0.1
init_disc_lr = 0.1


## Training

In [5]:
# Prefer the GPU when CUDA is available; otherwise run on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Instantiate the model from the configuration constants.
aae = AdversarialAutoencoder(
    input_dim=INPUT_DIM, latent_dim=LATENT_DIM,
    ae_hidden=AE_HIDDEN, dc_hidden=DC_HIDDEN,
    recon_loss_fn=recon_loss,
    use_decoder_sigmoid=use_decoder_sigmoid,
    init_recon_lr=init_recon_lr,
    init_gen_lr=init_gen_lr,
    init_disc_lr=init_disc_lr,
    device=device,
)

# Flattened MNIST tensors, labels, and the shuffled training loader.
(X_train, X_test, Y_train, Y_test, train_loader) = configure_mnist(batch_size=BATCH_SIZE)



In [None]:
# Train the model on the MNIST training loader for 50 epochs, passing
# PRIOR_STD through as `prior_std`.
# NOTE(review): this cell has no execution count in the saved notebook; the
# evaluation section loads weights from disk instead of relying on this run.
aae.train_mbgd(
    data_loader=train_loader,
    epochs=50,
    prior_std=PRIOR_STD,
)

## Evaluation

In [6]:
# Restore saved weights for the given prefix; the recorded output reports
# the matched files as 50_aae_weights_*.pth.
aae.load_weights(path_prefix="aae_weights")

Weights loaded from 50_aae_weights_*.pth


In [7]:
# Generate 10,000 samples from the model with prior_std=PRIOR_STD; `samples`
# is reused by the sigma cross-validation and log-likelihood cells below.
samples = aae.generate_samples(n=10000, prior_std=PRIOR_STD)

In [13]:

# Sweep 10 log-spaced sigma values and keep the selected one in `best_sigma`
# so it can be reused programmatically instead of being copied from output.
# NOTE(review): the recorded output below lists sigmas starting at 0.16, not
# 0.167 - the bounds appear to have been edited after the cell last ran;
# re-run this cell so code and output agree.
# NOTE(review): the training loader iterates over the whole training set, so
# X_train[50000:60000] is not held out from training - confirm intended.
best_sigma = cross_validate_sigma(
    samples=samples,
    validation_dataset=X_train[50000:60000],
    sigma_range=np.geomspace(0.167, 0.18, num=10),
    batch_size=100,
)
best_sigma





Evaluating sigma = 0.15999999999999998
Sigma: 0.16000, Log-Likelihood: 285.59569
Evaluating sigma = 0.16210768217462557
Sigma: 0.16211, Log-Likelihood: 283.94969
Evaluating sigma = 0.16424312887518386
Sigma: 0.16424, Log-Likelihood: 284.04009
Evaluating sigma = 0.16640670584415235
Sigma: 0.16641, Log-Likelihood: 282.12530
Evaluating sigma = 0.16859878364193906
Sigma: 0.16860, Log-Likelihood: 289.70884
Evaluating sigma = 0.17081973771034953
Sigma: 0.17082, Log-Likelihood: 286.45709
Evaluating sigma = 0.173069948436889
Sigma: 0.17307, Log-Likelihood: 291.83917
Evaluating sigma = 0.17534980121991275
Sigma: 0.17535, Log-Likelihood: 285.40476
Evaluating sigma = 0.17765968653463365
Sigma: 0.17766, Log-Likelihood: 288.41038
Evaluating sigma = 0.18000000000000002
Sigma: 0.18000, Log-Likelihood: 288.95250
Best sigma: 0.173069948436889


np.float64(0.173069948436889)

In [15]:
# Estimate the log-likelihood of the test set from the generated samples.
# The recorded result is a pair of floats - TODO confirm it is
# (mean, standard error).
# NOTE(review): sigma is hard-coded to 0.16859878364193906, but the recorded
# cross-validation output reports "Best sigma: 0.173069948436889". This looks
# like a value copied from an earlier run - confirm, and prefer passing the
# value returned by cross_validate_sigma.
estimate_log_likelihood(
    samples=samples,
    test_data=X_test,
    sigma=0.16859878364193906
)

(np.float64(290.69415833171877), np.float64(3.2503453024047224))