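# Adversarial Autoencoder: Toronto Face Database (TFD) Likelihood Experiments

Trains an adversarial autoencoder on the unlabeled TFD faces, then estimates the log-likelihood of held-out test faces under a density estimate built from model samples, with the bandwidth `sigma` chosen by cross-validation.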

### Imports

In [13]:
import numpy as np
import torch
import matplotlib.pyplot as plt
from torch.utils.data import Dataset
from utils import save_weights, load_weights
import scipy.io
from torch import nn


from torch.utils.data import TensorDataset, DataLoader

import torch
from torch.utils.data import DataLoader, TensorDataset

from src.aae import  AdversarialAutoencoder

from sklearn.model_selection import train_test_split

from likelihood import cross_validate_sigma, estimate_log_likelihood
from utils import compute_mean_std, normalize_data, rescale_to_unit_interval, save_weights

In [14]:
# Seed NumPy's and torch's global RNGs so random draws (e.g. the train/test
# shuffle and any torch-side initialisation) repeat across Restart & Run All.
# torch.manual_seed seeds every device, CUDA included; note that some CUDA
# kernels can still be non-deterministic.
SEED = 0
np.random.seed(SEED)
torch.manual_seed(SEED)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [15]:
# Raw TFD 48x48 archive (relative to the notebook's working directory).
TFD_MAT_PATH = './data/TFD/TFD_48x48.mat'
data = scipy.io.loadmat(TFD_MAT_PATH)
data.keys()

dict_keys(['__header__', '__version__', '__globals__', 'images', 'labs_ex', 'labs_id', 'folds'])

### Preprocessing Pipeline

In [16]:


def configure_tfd(data, device='cuda', batch_size=100, train_fraction=0.9, seed=None):
    """Build the unlabeled-TFD training loader and a held-out test split.

    Selects the images whose identity label (``data['labs_id']``) is -1,
    flattens each image to a row vector, shuffles, splits the rows into a
    training part (``train_fraction``) and a test part (the rest), and
    standardizes every pixel with a mean/std computed on the training part only,
    so no test-set statistics leak into preprocessing.

    Args:
        data: Mapping with ``'images'`` (one image per leading index) and
            ``'labs_id'`` (one label per image), e.g. the dict returned by
            ``scipy.io.loadmat`` for ``TFD_48x48.mat``. ``images`` may be a
            NumPy array or a torch tensor.
        device: Device on which all returned tensors live.
        batch_size: Mini-batch size of the returned training ``DataLoader``.
        train_fraction: Fraction of the unlabeled samples used for training
            (default 0.9, the previous hard-coded split).
        seed: Optional seed for the train/test shuffle. ``None`` uses torch's
            global RNG, as before.

    Returns:
        ``(train_loader, X_train, X_test, mean, std)``. ``X_train`` and ``X_test``
        are the standardized flattened images. ``mean`` and ``std`` are the
        per-pixel training statistics; a pixel with zero (or undefined) variance
        gets ``std = 1`` to avoid division by zero. ``train_loader`` yields
        ``(x, dummy_label)`` batches whose labels are all zero.

    Raises:
        ValueError: if no unlabeled samples exist or the split leaves the
            training set empty.
    """
    images = data['images']
    labels = data['labs_id']

    # Unlabeled samples are marked with label == -1.
    mask = labels.flatten() == -1
    if isinstance(images, torch.Tensor):
        images_unlabeled = images[torch.as_tensor(mask, device=images.device)]
    else:
        # Mask in NumPy first so only the selected images are converted to float32.
        images_unlabeled = torch.as_tensor(np.asarray(images)[np.asarray(mask)])
    # Force float32 so mean/std work even for integer (e.g. uint8) pixel data.
    images_unlabeled = images_unlabeled.to(device=device, dtype=torch.float32)

    print(f"Number of unlabeled samples: {images_unlabeled.shape[0]}")
    print(images_unlabeled.shape)

    total_samples = images_unlabeled.shape[0]
    if total_samples == 0:
        raise ValueError("configure_tfd: no unlabeled samples (labs_id == -1) found")

    # Flatten each image to a single row vector.
    X = images_unlabeled.reshape(total_samples, -1)

    # Shuffle and split. A dedicated CPU generator makes the split reproducible
    # when `seed` is given; the permutation is then moved to `device`.
    generator = None if seed is None else torch.Generator().manual_seed(seed)
    perm = torch.randperm(total_samples, generator=generator).to(device)
    train_size = int(train_fraction * total_samples)
    if train_size < 1:
        raise ValueError(
            f"configure_tfd: train_fraction={train_fraction} leaves no training "
            f"samples out of {total_samples}"
        )
    X_train_raw = X[perm[:train_size]]
    X_test_raw = X[perm[train_size:]]

    # Normalization statistics come from the training split only.
    mean = X_train_raw.mean(dim=0)
    std = X_train_raw.std(dim=0)
    # Constant pixels (std == 0) or a single training row (std is NaN) would
    # otherwise produce inf/NaN; leave such pixels merely mean-centred.
    std = torch.where(std > 0, std, torch.ones_like(std))

    X_train = (X_train_raw - mean) / std
    X_test = (X_test_raw - mean) / std

    # The loader needs a second tensor; the labels are unused, so zeros suffice.
    dummy_labels = torch.zeros(X_train.size(0), dtype=torch.uint8, device=device)
    train_loader = DataLoader(TensorDataset(X_train, dummy_labels), batch_size=batch_size, shuffle=True)

    return train_loader, X_train, X_test, mean, std













## Paper Configuration

In [17]:
# --- Architecture -----------------------------------------------------------
INPUT_DIM = 48 * 48          # flattened 48x48 TFD image
AE_HIDDEN = 1000             # passed to the model as ae_hidden
DC_HIDDEN = 1000             # passed to the model as dc_hidden
LATENT_DIM = 15              # size of the latent code

# --- Optimisation -----------------------------------------------------------
BATCH_SIZE = 100
init_recon_lr = 0.01         # reconstruction-phase learning rate
init_gen_lr = 0.1            # generator (encoder) learning rate
init_disc_lr = 0.1           # discriminator learning rate
recon_loss = nn.MSELoss()

# --- Prior / decoder output -------------------------------------------------
PRIOR_STD = 10.0             # prior_std used for training and sampling
# No output sigmoid: presumably because the inputs are standardized rather
# than confined to [0, 1].
use_decoder_sigmoid = False

## Training

In [22]:
# Build the model from the configuration cell.
aae = AdversarialAutoencoder(
    input_dim=INPUT_DIM,
    ae_hidden=AE_HIDDEN,
    dc_hidden=DC_HIDDEN,
    latent_dim=LATENT_DIM,
    recon_loss_fn=recon_loss,
    init_recon_lr=init_recon_lr,
    init_gen_lr=init_gen_lr,
    init_disc_lr=init_disc_lr,
    use_decoder_sigmoid=use_decoder_sigmoid,
    device=device,
)

# Standardized unlabeled-TFD splits. Use BATCH_SIZE from the config cell instead
# of a second hard-coded copy that could silently drift from it.
train_loader, X_train, X_test, mean, std = configure_tfd(data=data, device=device, batch_size=BATCH_SIZE)

# Guard against a config/data mismatch before spending hours training.
assert X_train.shape[1] == INPUT_DIM, f"data width {X_train.shape[1]} != INPUT_DIM {INPUT_DIM}"

Number of unlabeled samples: 98362
torch.Size([98362, 48, 48])


In [23]:
# Mini-batch training on the standardized unlabeled faces, using the same
# prior_std that is later used for sampling.
training_config = {
    "data_loader": train_loader,
    "epochs": 2000,
    "prior_std": PRIOR_STD,
}
aae.train_mbgd(**training_config)

Epoch (1/2000)	)Recon Loss: 1.0010	)Disc Loss: 0.4317	)Gen Loss: 3.2179	)
Epoch (2/2000)	)Recon Loss: 1.0008	)Disc Loss: 1.0698	)Gen Loss: 1.7996	)
Epoch (3/2000)	)Recon Loss: 0.8745	)Disc Loss: 1.1990	)Gen Loss: 1.2941	)
Epoch (4/2000)	)Recon Loss: 0.5782	)Disc Loss: 1.2264	)Gen Loss: 1.1891	)
Epoch (5/2000)	)Recon Loss: 0.5149	)Disc Loss: 1.2595	)Gen Loss: 1.0607	)
Epoch (6/2000)	)Recon Loss: 0.4326	)Disc Loss: 1.2901	)Gen Loss: 0.9566	)
Epoch (7/2000)	)Recon Loss: 0.3969	)Disc Loss: 1.3115	)Gen Loss: 0.9101	)
Epoch (8/2000)	)Recon Loss: 0.3658	)Disc Loss: 1.3096	)Gen Loss: 0.8985	)
Epoch (9/2000)	)Recon Loss: 0.3533	)Disc Loss: 1.3207	)Gen Loss: 0.8819	)
Epoch (10/2000)	)Recon Loss: 0.3330	)Disc Loss: 1.3238	)Gen Loss: 0.8743	)
Epoch (11/2000)	)Recon Loss: 0.3224	)Disc Loss: 1.3335	)Gen Loss: 0.8512	)
Epoch (12/2000)	)Recon Loss: 0.3135	)Disc Loss: 1.3385	)Gen Loss: 0.8321	)
Epoch (13/2000)	)Recon Loss: 0.3004	)Disc Loss: 1.3421	)Gen Loss: 0.8184	)
Epoch (14/2000)	)Recon Loss: 0.288

In [None]:
# Restore encoder/decoder/discriminator parameters from a saved checkpoint.
# NOTE(review): no cell in this notebook calls save_weights, so the
# "2000_tfd_aae_weights" files must already exist from an earlier session.
# Running this after the training cell presumably replaces the freshly trained
# parameters; confirm against utils.load_weights.
checkpoint_prefix = "2000_tfd_aae_weights"
load_weights(
    encoder=aae.encoder,
    decoder=aae.decoder,
    discriminator=aae.discriminator,
    device=device,
    path_prefix=checkpoint_prefix,
)

In [None]:
# Draw 10,000 model samples with the same prior_std used during training.
samples = aae.generate_samples(
    n=10_000,
    prior_std=PRIOR_STD,
)

In [None]:
# Choose the bandwidth sigma from 20 log-spaced candidates in [0.09, 0.2],
# comparing model samples against rows 50000-59999 of X_train, both mapped
# back from the standardized space with rescale_to_unit_interval.
# NOTE(review): X_train is the split the model was trained on (it backs
# train_loader), so this is not a held-out validation set; consider carving a
# dedicated validation split.
samples_unit = rescale_to_unit_interval(samples, mean, std)
validation_unit = rescale_to_unit_interval(X_train[50000:60000], mean, std)
log_sigma_grid = np.linspace(np.log(0.09), np.log(0.2), 20)
cross_validate_sigma(
    samples=samples_unit,
    validation_dataset=validation_unit,
    sigma_range=np.exp(log_sigma_grid),
    batch_size=100,
)

In [None]:

# Bandwidth presumably copied from the cross-validation cell's output.
# NOTE(review): hard-coded; re-run cross-validation and update this value
# whenever the model, data split or preprocessing changes.
SELECTED_SIGMA = 0.1061675131218835

samples_unit = rescale_to_unit_interval(samples, mean, std)
test_unit = rescale_to_unit_interval(X_test, mean, std)
estimate_log_likelihood(
    samples=samples_unit,
    test_data=test_unit,
    sigma=SELECTED_SIGMA,
)