## Adversarial Autoencoder (Basic Implementation)

In [12]:
import numpy as np
import torch
import matplotlib.pyplot as plt
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from torch import nn
from components import AdversarialAutoencoder

## Dataset Configuration

In [9]:
def configure_mnist(batch_size=100, root='./data', train=True, shuffle=True, seed=None):
    """Build a DataLoader over MNIST with every image flattened to a 1-D vector.

    Args:
        batch_size: Number of samples per batch.
        root: Directory holding the MNIST files; they are downloaded there if missing.
        train: Use the training split if True, otherwise the test split.
        shuffle: Reshuffle the samples at every epoch.
        seed: Optional integer seed for the shuffle order. ``None`` (the default)
            draws from torch's global RNG, exactly as before.

    Returns:
        A ``DataLoader`` yielding ``(images, labels)`` batches, where ``images``
        has shape ``(B, 784)`` with pixel values scaled to ``[0, 1]``.
    """
    transform = transforms.Compose([
        transforms.ToTensor(),            # PIL uint8 image -> float tensor (1, 28, 28) in [0, 1]
        transforms.Lambda(torch.flatten)  # per-sample (1, 28, 28) -> (784,); a named function
                                          # (unlike a lambda) pickles, so num_workers > 0 works
    ])

    # A dedicated generator makes the shuffle order reproducible without
    # touching the global RNG; None keeps the previous (global RNG) behaviour.
    generator = None if seed is None else torch.Generator().manual_seed(seed)

    return DataLoader(
        datasets.MNIST(root=root, train=train, transform=transform, download=True),
        batch_size=batch_size,
        shuffle=shuffle,
        generator=generator,
    )


## Paper Configuration

In [14]:
# Seed every RNG the experiment touches (weight init, DataLoader shuffling, and
# presumably prior sampling inside AdversarialAutoencoder) so Restart & Run All
# reproduces the logged losses.
SEED = 0
torch.manual_seed(SEED)  # seeds torch's RNG on all devices (CPU and CUDA)
np.random.seed(SEED)

INPUT_DIM = 784          # 28 x 28 MNIST image, flattened by the transform in configure_mnist
BATCH_SIZE = 100
AE_HIDDEN = 1000         # presumably the hidden-layer width of encoder/decoder -- see components
DC_HIDDEN = 1000         # presumably the hidden-layer width of the discriminator -- see components
LATENT_DIM = 8           # dimensionality of the latent code
PRIOR_STD = 5.0          # passed to aae.train(prior_std=...); presumably the std of the Gaussian prior p(z)
recon_loss = nn.MSELoss()  # reconstruction loss handed to the AAE as recon_loss_fn
# Appendix A.1 does not give a learning rate for this experiment (the paper does
# specify one for later experiments); 1e-3 is our own choice.
learning_rate = 1e-3
# NOTE(review): presumably squashes decoder output into [0, 1] to match the
# ToTensor pixel range -- confirm in components.
use_decoder_sigmoid = True


## Training

In [16]:
# Build the model, load flattened MNIST and train; aae.train prints one
# loss summary line per epoch (see output below).
EPOCHS = 50  # full run; lower this for a quick smoke test
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

aae = AdversarialAutoencoder(
    input_dim=INPUT_DIM,
    ae_hidden=AE_HIDDEN,
    dc_hidden=DC_HIDDEN,
    latent_dim=LATENT_DIM,
    recon_loss_fn=recon_loss,
    lr=learning_rate,
    use_decoder_sigmoid=use_decoder_sigmoid,
    device=DEVICE,
)

loader = configure_mnist(batch_size=BATCH_SIZE)

aae.train(
    data_loader=loader,
    epochs=EPOCHS,
    prior_std=PRIOR_STD,
)

Epoch (1/50)	)Recon Loss: 0.0787	)Disc Loss: 1.4262	)Gen Loss: 1.6439	)
Epoch (2/50)	)Recon Loss: 0.0704	)Disc Loss: 1.8373	)Gen Loss: 2.6186	)
Epoch (3/50)	)Recon Loss: 0.0673	)Disc Loss: 1.1156	)Gen Loss: 2.3606	)
Epoch (4/50)	)Recon Loss: 0.0671	)Disc Loss: 0.6345	)Gen Loss: 4.8821	)
Epoch (5/50)	)Recon Loss: 0.0672	)Disc Loss: 0.6537	)Gen Loss: 4.4879	)
Epoch (6/50)	)Recon Loss: 0.0673	)Disc Loss: 0.2102	)Gen Loss: 4.5482	)
Epoch (7/50)	)Recon Loss: 0.0662	)Disc Loss: 0.1323	)Gen Loss: 4.8886	)
Epoch (8/50)	)Recon Loss: 0.0655	)Disc Loss: 0.1692	)Gen Loss: 5.1607	)


KeyboardInterrupt: 