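## Adversarial Autoencoder (Basic Implementation)

This notebook trains an adversarial autoencoder (AAE) on flattened MNIST images with an 8-dimensional latent code. It then evaluates the model by generating samples and estimating the test-set log-likelihood, using a kernel bandwidth `sigma` chosen by cross-validation.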

In [1]:
import numpy as np
import torch
import matplotlib.pyplot as plt
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from torch import nn
from adversarial import AdversarialAutoencoder
from likelihood import cross_validate_sigma, estimate_log_likelihood

## Dataset Configuration

In [2]:
def configure_mnist(batch_size=100, root='./data', seed=None, n_val=0):
    """Load MNIST as flattened float tensors in [0, 1] plus a shuffling train loader.

    Args:
        batch_size: Mini-batch size of the returned training DataLoader.
        root: Directory where torchvision stores (and, if needed, downloads) MNIST.
        seed: If not None, seeds a dedicated generator for the loader's shuffle
            so the batch order is reproducible. None keeps the original
            behaviour of drawing from PyTorch's global RNG.
        n_val: Number of images at the END of the training split to withhold
            from ``train_loader``. For example, 10000 keeps ``X_train[50000:]``
            out of training so it can serve as a held-out validation set.
            0 (default) trains on every training image.

    Returns:
        Tuple ``(X_train, X_test, Y_train, Y_test, train_loader)``.
        ``X_*`` are the transformed images stacked into one tensor each.
        ``Y_*`` are copies of the dataset label tensors. ``train_loader``
        yields shuffled batches with the same transform applied.
        ``X_train``/``Y_train`` always cover the full training split,
        regardless of ``n_val``.

    Raises:
        ValueError: If ``n_val`` is negative or not smaller than the training set.
    """
    # Transform: ToTensor (scales uint8 pixels to float in [0, 1]) then flatten.
    transform = transforms.Compose([
        transforms.ToTensor(),  # Automatically scales pixels to [0, 1]
        transforms.Lambda(lambda x: x.view(-1))  # Flatten
    ])

    # Load datasets (the transform is applied on item access).
    train_dataset = datasets.MNIST(root=root, train=True, transform=transform, download=True)
    test_dataset = datasets.MNIST(root=root, train=False, transform=transform, download=True)

    if not 0 <= n_val < len(train_dataset):
        raise ValueError(
            f"n_val must be in [0, {len(train_dataset)}), got {n_val}"
        )

    # Materialise the data by iterating the datasets themselves, so these
    # tensors are exactly what the DataLoader serves (same transform).
    X_train = torch.stack([x for x, _ in train_dataset])
    X_test = torch.stack([x for x, _ in test_dataset])

    Y_train = train_dataset.targets.clone()
    Y_test = test_dataset.targets.clone()

    # Optionally keep the tail of the training split out of the loader so it
    # can be used for validation without having been trained on.
    loader_dataset = train_dataset
    if n_val > 0:
        loader_dataset = torch.utils.data.Subset(
            train_dataset, range(len(train_dataset) - n_val)
        )

    # A dedicated seeded generator makes the shuffle order reproducible;
    # generator=None falls back to the global RNG (the original behaviour).
    generator = torch.Generator().manual_seed(seed) if seed is not None else None

    train_loader = DataLoader(
        loader_dataset,
        batch_size=batch_size,
        shuffle=True,
        generator=generator,
    )

    return X_train, X_test, Y_train, Y_test, train_loader






## Hyperparameters (Paper Configuration)

In [3]:
# Hyperparameters for the MNIST AAE run.
INPUT_DIM = 784        # length of a flattened 28x28 MNIST image
BATCH_SIZE = 100
AE_HIDDEN = 1000       # hidden width passed as ae_hidden
DC_HIDDEN = 1000       # hidden width passed as dc_hidden (presumably the discriminator) — confirm in adversarial.py
LATENT_DIM = 8
PRIOR_STD = 5.0        # prior std used both for training and for sampling
recon_loss = nn.MSELoss()
init_recon_lr = 0.01
init_gen_lr = init_disc_lr = 0.1
use_decoder_sigmoid = True

# Seed the numpy and torch RNGs so that repeated Restart & Run All executions
# draw the same random numbers. torch.manual_seed also seeds all CUDA devices.
SEED = 0
np.random.seed(SEED)
torch.manual_seed(SEED)


## Training

In [4]:
# Run on the GPU when one is available, otherwise fall back to the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Every hyperparameter comes from the configuration cell above.
aae = AdversarialAutoencoder(
    input_dim=INPUT_DIM,
    latent_dim=LATENT_DIM,
    ae_hidden=AE_HIDDEN,
    dc_hidden=DC_HIDDEN,
    recon_loss_fn=recon_loss,
    init_recon_lr=init_recon_lr,
    init_gen_lr=init_gen_lr,
    init_disc_lr=init_disc_lr,
    use_decoder_sigmoid=use_decoder_sigmoid,
    device=DEVICE,
)

# Flattened [0, 1] MNIST tensors plus the shuffling training loader.
X_train, X_test, Y_train, Y_test, train_loader = configure_mnist(batch_size=BATCH_SIZE)



In [5]:
# NOTE(review): this full run is very slow under Restart & Run All, and the
# Evaluation section loads a saved checkpoint afterwards. Consider skipping
# this cell when that checkpoint already exists.
N_EPOCHS = 2000
aae.train_mbgd(data_loader=train_loader, epochs=N_EPOCHS, prior_std=PRIOR_STD)

Epoch (1/2000)	)Recon Loss: 0.2296	)Disc Loss: 0.3463	)Gen Loss: 3.4776	)
Epoch (2/2000)	)Recon Loss: 0.2252	)Disc Loss: 0.6248	)Gen Loss: 3.0876	)
Epoch (3/2000)	)Recon Loss: 0.2083	)Disc Loss: 1.1080	)Gen Loss: 1.5984	)
Epoch (4/2000)	)Recon Loss: 0.1182	)Disc Loss: 1.2357	)Gen Loss: 1.0939	)
Epoch (5/2000)	)Recon Loss: 0.0735	)Disc Loss: 1.2830	)Gen Loss: 0.9630	)
Epoch (6/2000)	)Recon Loss: 0.0703	)Disc Loss: 1.3251	)Gen Loss: 0.8720	)
Epoch (7/2000)	)Recon Loss: 0.0690	)Disc Loss: 1.3416	)Gen Loss: 0.8288	)
Epoch (8/2000)	)Recon Loss: 0.0686	)Disc Loss: 1.3617	)Gen Loss: 0.7791	)
Epoch (9/2000)	)Recon Loss: 0.0684	)Disc Loss: 1.3635	)Gen Loss: 0.7697	)
Epoch (10/2000)	)Recon Loss: 0.0676	)Disc Loss: 1.3710	)Gen Loss: 0.7484	)
Epoch (11/2000)	)Recon Loss: 0.0674	)Disc Loss: 1.3748	)Gen Loss: 0.7378	)
Epoch (12/2000)	)Recon Loss: 0.0667	)Disc Loss: 1.3784	)Gen Loss: 0.7249	)
Epoch (13/2000)	)Recon Loss: 0.0663	)Disc Loss: 1.3761	)Gen Loss: 0.7295	)
Epoch (14/2000)	)Recon Loss: 0.065

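## Evaluation

This section proceeds in four steps:

1. Load the checkpoint saved under the `2000_mnist_aae_weights` prefix.
2. Draw 10,000 samples.
3. Choose the bandwidth `sigma` on a validation slice of the training images.
4. Estimate the log-likelihood of the test set.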

In [6]:
# Restore model parameters from the checkpoint files sharing this prefix.
# NOTE(review): this presumably replaces whatever the training cell produced.
# The message printed by load_weights says "saved", which looks like a wording
# slip inside load_weights — confirm.
CHECKPOINT_PREFIX = "2000_mnist_aae_weights"
aae.load_weights(path_prefix=CHECKPOINT_PREFIX)

Weights saved to 2000_mnist_aae_weights_*.pth


In [14]:
# Generate samples with the same prior std used for training; they feed both
# the sigma sweep and the log-likelihood estimate.
N_SAMPLES = 10_000
samples = aae.generate_samples(n=N_SAMPLES, prior_std=PRIOR_STD)

In [None]:

# Slice of the training images used to select the kernel bandwidth sigma.
# NOTE(review): train_loader, as built in the setup cell, serves all training
# images. This slice was therefore also seen during training; the loader would
# have to exclude it for a truly held-out validation set.
VAL_SLICE = slice(50000, 60000)

# Make sure no single digit dominates (or is missing from) the validation slice
# before sigma is tuned on it.
val_class_counts = torch.bincount(Y_train[VAL_SLICE], minlength=10)
assert val_class_counts.min() > 0.5 * val_class_counts.float().mean(), (
    f"validation slice is class-imbalanced: {val_class_counts.tolist()}"
)

cross_validate_sigma(
    samples=samples,
    validation_dataset=X_train[VAL_SLICE],
    # 10 candidate bandwidths spaced evenly in log-space from 0.1 to 1.0
    sigma_range=np.exp(np.linspace(np.log(0.1), np.log(1.0), 10)),
    batch_size=100,
)





In [16]:
# Bandwidth hard-coded so this cell runs without repeating the sigma sweep;
# presumably it is the value that sweep selected.
# NOTE(review): an earlier note recorded 0.1668100537200059 instead. Confirm
# which run this value came from, ideally by reading it from the sweep's result.
LIKELIHOOD_SIGMA = 0.16859878364193906
estimate_log_likelihood(samples=samples, test_data=X_test, sigma=LIKELIHOOD_SIGMA)

(np.float64(340.1380381282135), np.float64(2.939785348750419))