In [1]:
import numpy as np
import torch
import matplotlib.pyplot as plt
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

from gan import GAN
from torch import nn
from src.aae import AdversarialAutoencoder
from likelihood import cross_validate_sigma, estimate_log_likelihood

In [2]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

## Dataset Configuration

In [3]:
def configure_mnist(batch_size=100):
    # Transform: Just ToTensor (auto 0-1) + flatten
    transform = transforms.Compose([
        transforms.ToTensor(),  # Automatically scales pixels to [0, 1]
        transforms.Lambda(lambda x: x.view(-1))  # Flatten
    ])

    # Load datasets (applies transform automatically)
    train_dataset = datasets.MNIST(root='./data', train=True, transform=transform, download=True)
    test_dataset = datasets.MNIST(root='./data', train=False, transform=transform, download=True)

    # Get the actual transformed data (0-1 scaled, flattened)
    X_train = torch.stack([x for x, _ in train_dataset])  # Exactly what DataLoader will see
    X_test = torch.stack([x for x, _ in test_dataset])

    Y_train = train_dataset.targets.clone()
    Y_test = test_dataset.targets.clone()

    # DataLoader (will serve same transformed data)
    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True
    )

    return X_train, X_test, Y_train, Y_test, train_loader



In [4]:
X_train, X_test, Y_train, Y_test, train_loader = configure_mnist(batch_size=100)

## Paper Configuration

In [5]:
LATENT_DIM = 100
NUM_EPOCHS = 300
UNIFORM_RANGE = 3 ** 0.5
GENERATOR_HIDDEN_DIM=1200
DISCRIMINATOR_HIDDEN_DIM=240
INPUT_DIM = 784
LR = 1e-1
MIN_LR = 1e-6
DECAY_FACTOR = (1/(1+4e-6))
MOMENTUM = 0.5
FINAL_MOMENTUM = 0.7
MOMENTUM_SATURATE = 250
BATCH_SIZE = 100




## Training

In [6]:
gan = GAN(
    latent_dim=LATENT_DIM,
    input_size=INPUT_DIM,
    generator_hidden_dim=GENERATOR_HIDDEN_DIM,
    discriminator_hidden_dim=DISCRIMINATOR_HIDDEN_DIM,
    use_sigmoid_gen=True,
    device=device
)


In [7]:
gan.train_mbgd(
    data_loader=train_loader,
    learning_rate=LR,
    uniform_range=UNIFORM_RANGE,
    min_lr=MIN_LR,
    decay_factor=DECAY_FACTOR,
    epochs=NUM_EPOCHS,
    momentum=MOMENTUM,
    final_momentum=FINAL_MOMENTUM,
    momentum_saturate=MOMENTUM_SATURATE,

    log_dir='./tmp_runs'
)

[Epoch 1/300] Batch 0 | G Loss: 0.6932 | D Loss: 1.3863 | LR: 0.100000 | Momentum: 0.5000
[Epoch 1/300] Batch 100 | G Loss: 5.0786 | D Loss: 0.6481 | LR: 0.099960 | Momentum: 0.5000
[Epoch 1/300] Batch 200 | G Loss: 12.1758 | D Loss: 0.2999 | LR: 0.099920 | Momentum: 0.5000
[Epoch 1/300] Batch 300 | G Loss: 9.9472 | D Loss: 0.0991 | LR: 0.099880 | Momentum: 0.5000
[Epoch 1/300] Batch 400 | G Loss: 7.3296 | D Loss: 0.0414 | LR: 0.099840 | Momentum: 0.5000
[Epoch 1/300] Batch 500 | G Loss: 8.8285 | D Loss: 0.0123 | LR: 0.099800 | Momentum: 0.5000
[Epoch 2/300] Batch 0 | G Loss: 10.2546 | D Loss: 0.0480 | LR: 0.099760 | Momentum: 0.5000
[Epoch 2/300] Batch 100 | G Loss: 7.1078 | D Loss: 0.1172 | LR: 0.099720 | Momentum: 0.5000
[Epoch 2/300] Batch 200 | G Loss: 35.8733 | D Loss: 1.6037 | LR: 0.099681 | Momentum: 0.5000
[Epoch 2/300] Batch 300 | G Loss: 29.8826 | D Loss: 0.5787 | LR: 0.099641 | Momentum: 0.5000
[Epoch 2/300] Batch 400 | G Loss: 10.5191 | D Loss: 0.0284 | LR: 0.099601 | Mome

In [12]:
print(samples.shape, X_train.shape)

torch.Size([10000, 100]) torch.Size([60000, 784])


In [8]:
gan.save_weights(path_prefix="300_mnist_gan_weights")

Weights saved to 300_mnist_gan_weights_*.pth


In [14]:
samples = gan.sample_z(batch_size=10000, uniform_range=UNIFORM_RANGE)
samples = gan.generator(samples).detach()

In [17]:
# check that 1 class is not clumped in last 10k
cross_validate_sigma(
    samples=samples,
    validation_dataset=X_train[50000:60000],
    sigma_range=np.exp(np.linspace(np.log(0.1), np.log(1.0), 10)),
    batch_size=100,
)

Evaluating sigma = 0.10000000000000002
Sigma: 0.10000, Log-Likelihood: 291.09677
Evaluating sigma = 0.12915496650148844
Sigma: 0.12915, Log-Likelihood: 386.49493
Evaluating sigma = 0.1668100537200059


KeyboardInterrupt: 

In [16]:
# 0.13861643358759645
estimate_log_likelihood(
    samples=samples,
    test_data=X_test,
    sigma=0.13861643358759645
)

(np.float64(397.7332554443204), np.float64(4.662149497740721))