## Adversarial Autoencoder (Basic Implementation)

In [23]:
import matplotlib.pyplot as plt
import numpy as np
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

from adversarial import AdversarialAutoencoder
# TODO: the original line `from likelihood import` was left incomplete, which is a
# SyntaxError that stops this whole cell (and so the notebook) from running.
# Import the specific helpers from the project's `likelihood` module here once
# the likelihood experiments in the Evaluation section are written, e.g.:
# from likelihood import <names>

## Dataset Configuration

In [24]:
def configure_mnist(batch_size=100, root='./data', download=True):
    """Load MNIST as flattened [0, 1] float tensors plus a shuffled training DataLoader.

    Args:
        batch_size: Mini-batch size of the returned training DataLoader.
        root: Directory torchvision uses to store / look up the MNIST files.
        download: Whether torchvision may download MNIST into ``root`` if it is missing.

    Returns:
        Tuple ``(X_train, X_test, Y_train, Y_test, train_loader)``:
        - ``X_train`` / ``X_test``: float tensors of shape (n_samples, 784),
          pixel values scaled to [0, 1].
        - ``Y_train`` / ``Y_test``: the datasets' label tensors.
        - ``train_loader``: DataLoader over the training split that yields
          shuffled batches of (flattened image, label).
    """
    # ToTensor scales pixels to [0, 1]; the Lambda flattens each image to a vector
    # so loader batches match the dense (batch, INPUT_DIM) input of the model.
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Lambda(lambda x: x.view(-1)),
    ])

    train_dataset = datasets.MNIST(root=root, train=True, transform=transform, download=download)
    test_dataset = datasets.MNIST(root=root, train=False, transform=transform, download=download)

    # Whole-split tensors for evaluation/plotting. These bypass the transform, so
    # the same /255 scaling and flattening are applied by hand to the raw images.
    X_train = train_dataset.data.float().div(255).flatten(start_dim=1)
    X_test = test_dataset.data.float().div(255).flatten(start_dim=1)

    Y_train = train_dataset.targets
    Y_test = test_dataset.targets

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

    return X_train, X_test, Y_train, Y_test, train_loader






## Paper Configuration

In [28]:
# Hyperparameters for this run. The lower-case names are kept as-is because the
# model-construction cell references them by these exact names.
INPUT_DIM = 784          # flattened 28x28 MNIST image
BATCH_SIZE = 100
AE_HIDDEN = 1000         # passed as ae_hidden (autoencoder hidden width)
DC_HIDDEN = 1000         # passed as dc_hidden (presumably the discriminator's hidden width)
LATENT_DIM = 8
PRIOR_STD = 5.0          # passed to train_mbgd as prior_std; presumably the std of the latent prior
recon_loss = nn.MSELoss()
init_recon_lr = 0.01
init_gen_lr = init_disc_lr = 0.1
use_decoder_sigmoid = True   # presumably squashes decoder output to [0, 1]; confirm in adversarial.py

# Seed torch's RNG (on all devices) so weight initialisation and DataLoader
# shuffling, plus prior sampling if adversarial.py draws it with torch, are
# reproducible on Restart & Run All.
SEED = 0
torch.manual_seed(SEED)


## Training

In [26]:
# Use the GPU when torch can see one; otherwise run on the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

aae = AdversarialAutoencoder(
    # architecture
    input_dim=INPUT_DIM,
    ae_hidden=AE_HIDDEN,
    dc_hidden=DC_HIDDEN,
    latent_dim=LATENT_DIM,
    use_decoder_sigmoid=use_decoder_sigmoid,
    # objective and learning rates
    recon_loss_fn=recon_loss,
    init_recon_lr=init_recon_lr,
    init_gen_lr=init_gen_lr,
    init_disc_lr=init_disc_lr,
    device=device,
)

# Flattened MNIST tensors for evaluation, plus the shuffled training loader.
X_train, X_test, Y_train, Y_test, train_loader = configure_mnist(batch_size=BATCH_SIZE)



In [27]:
# Full training run via train_mbgd (presumably mini-batch gradient descent).
# This is slow: the saved output below was interrupted after the first epoch was
# logged. The Evaluation section calls aae.load_weights() instead of retraining.
N_EPOCHS = 50
aae.train_mbgd(data_loader=train_loader, epochs=N_EPOCHS, prior_std=PRIOR_STD)

Epoch (1/50)	)Recon Loss: 0.2297	)Disc Loss: 0.3583	)Gen Loss: 3.3949	)


KeyboardInterrupt: 

## Evaluation

In [5]:
# Restore saved model weights from disk instead of relying on the (interrupted)
# training run above.
# NOTE(review): the saved output reports the literal name "aae_weights_*.pth";
# confirm in adversarial.py which file load_weights() actually reads.
# NOTE(review): this cell's execution count (In [5]) predates the cells above, so
# its saved output came from an earlier kernel session; re-run it in order.
aae.load_weights()

Weights loaded from aae_weights_*.pth


In [None]:
# TODO: add the likelihood experiments in this cell.