# Variational Autoencoder (VAE) – Gesichtsgenerierung mit dem CelebA-Datensatz

In diesem Notebook wird ein Variational Autoencoder (VAE) auf dem CelebA-Datensatz trainiert, um realistische Gesichtsbilder zu generieren.


In [76]:
import os
import math
import random
import logging
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from sklearn.manifold import TSNE
import matplotlib.pyplot as plt
import numpy as np
import kagglehub

# Seed every random number generator used in this notebook with the same
# value so that hyperparameter sampling and weight initialisation repeat.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

<torch._C.Generator at 0x1057633f0>

In [77]:
# Send all log records of level INFO and above to data/logs/training.log.
LOG_DIR = "data/logs"
os.makedirs(LOG_DIR, exist_ok=True)
log_file = os.path.join(LOG_DIR, "training.log")
logging.basicConfig(
    filename=log_file,
    level=logging.INFO,
    format="%(asctime)s - %(levelname)s - %(message)s",
)
logger = logging.getLogger()

# Choose the compute device in order of preference: CUDA, then Apple MPS,
# then CPU as the fallback.
if torch.cuda.is_available():
    device_name, device_message = "cuda", "Using CUDA (NVIDIA GPU)!"
elif torch.backends.mps.is_available():
    device_name, device_message = "mps", "Using MPS (Apple GPU)!"
else:
    device_name, device_message = "cpu", "Using CPU!"

device = torch.device(device_name)
logger.info(device_message)

In [78]:
# Fetch the CelebA dataset through kagglehub and record where it was stored.
CELEBA_HANDLE = "jessicali9530/celeba-dataset"
data_path = kagglehub.dataset_download(CELEBA_HANDLE)
logger.info(f"Dataset path: {data_path}")

In [79]:
# Preprocessing pipeline: resize every image to image_size x image_size,
# convert it to a tensor, then normalise each RGB channel with mean 0.5 and
# std 0.5 (the generation cell later inverts this with x * 0.5 + 0.5).
image_size = 64
channel_mean = (0.5, 0.5, 0.5)
channel_std = (0.5, 0.5, 0.5)
preprocessing_steps = [
    transforms.Resize((image_size, image_size)),
    transforms.ToTensor(),
    transforms.Normalize(channel_mean, channel_std),
]
transform = transforms.Compose(preprocessing_steps)

# Build the dataset from the downloaded folder; the class labels that
# ImageFolder assigns are ignored by the training code.
dataset = datasets.ImageFolder(root=data_path, transform=transform)

In [80]:
# Convolutional VAE model
class VAE(nn.Module):
    """Convolutional variational autoencoder for 3-channel images.

    The encoder stacks three stride-2 convolutions (3 -> 64 -> 128 -> 256
    channels) and flattens the result. Two linear heads map the flattened
    features to the mean and log-variance of the latent Gaussian. The decoder
    mirrors the encoder with a linear layer, an unflatten to (256, 8, 8) and
    three stride-2 transposed convolutions, ending in Tanh.

    The linear layers are sized for 256 * 8 * 8 features, which corresponds
    to 64x64 inputs (each stride-2 convolution halves the spatial size).
    """

    def __init__(self, latent_dim=100):
        """Build encoder, latent heads and decoder for the given latent size."""
        super().__init__()
        self.latent_dim = latent_dim

        flat_features = 256 * 8 * 8
        # All (transposed) convolutions share the same geometry.
        conv_kwargs = dict(kernel_size=4, stride=2, padding=1)

        # Encoder
        self.encoder = nn.Sequential(
            nn.Conv2d(3, 64, **conv_kwargs),
            nn.ReLU(),
            nn.Conv2d(64, 128, **conv_kwargs),
            nn.ReLU(),
            nn.Conv2d(128, 256, **conv_kwargs),
            nn.ReLU(),
            nn.Flatten(),
        )
        # Heads producing the parameters of the approximate posterior
        self.fc_mu = nn.Linear(flat_features, latent_dim)
        self.fc_logvar = nn.Linear(flat_features, latent_dim)

        # Decoder
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, flat_features),
            nn.ReLU(),
            nn.Unflatten(1, (256, 8, 8)),
            nn.ConvTranspose2d(256, 128, **conv_kwargs),
            nn.ReLU(),
            nn.ConvTranspose2d(128, 64, **conv_kwargs),
            nn.ReLU(),
            nn.ConvTranspose2d(64, 3, **conv_kwargs),
            nn.Tanh(),
        )

    def reparameterize(self, mu, logvar):
        """Draw z = mu + eps * std with eps ~ N(0, I) and std = exp(logvar / 2)."""
        noise_scale = (0.5 * logvar).exp()
        return mu + torch.randn_like(noise_scale) * noise_scale

    def forward(self, x):
        """Encode x, sample a latent code and decode it.

        Returns:
            Tuple of (reconstruction, mu, logvar).
        """
        features = self.encoder(x)
        mu, logvar = self.fc_mu(features), self.fc_logvar(features)
        reconstruction = self.decoder(self.reparameterize(mu, logvar))
        return reconstruction, mu, logvar

In [81]:
# VAE training objective
def loss_function(recon_x, x, mu, logvar):
    """Return the summed VAE loss: reconstruction error plus KL divergence.

    Args:
        recon_x: Decoder output.
        x: Target tensor that recon_x is compared against.
        mu: Mean of the approximate posterior.
        logvar: Log-variance of the approximate posterior.

    Returns:
        Scalar tensor: sum of squared errors between recon_x and x plus the
        KL divergence of N(mu, exp(logvar)) from N(0, I), both summed over
        every element (no averaging over the batch).
    """
    squared_error = nn.functional.mse_loss(recon_x, x, reduction='sum')
    kl_inner_sum = torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return squared_error + (-0.5 * kl_inner_sum)

# Random search setup: inclusive (low, high) bounds for every sampled
# hyperparameter, the number of trials to run, and the list that collects
# one result entry per trial.
search_space = dict(
    latent_dim=(64, 512),
    learning_rate=(1e-4, 3e-3),
    batch_size=(32, 256),
)
num_trials = 5
results = []

In [None]:
# Random search: every trial samples one configuration, trains a fresh VAE
# with early stopping and records the best monitored loss of that trial.
os.makedirs("data/model_checkpoints", exist_ok=True)

for trial in range(1, num_trials + 1):
    # --- Sample hyperparameters ---
    # latent_dim: uniform int in [64,512]
    min_ld, max_ld = search_space['latent_dim']
    latent_dim = random.randint(min_ld, max_ld)

    # learning_rate: log-uniform in [1e-4,3e-3]
    lr_min, lr_max = search_space['learning_rate']
    log_low, log_high = math.log10(lr_min), math.log10(lr_max)
    lr = 10 ** random.uniform(log_low, log_high)

    # batch_size: uniform int in [32,256]
    min_bs, max_bs = search_space['batch_size']
    batch_size = random.randint(min_bs, max_bs)

    logger.info(
        f"Trial {trial}/{num_trials} - "
        f"latent_dim={latent_dim}, lr={lr:.2e}, batch_size={batch_size}"
    )

    # --- Prepare DataLoader & Model ---
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
    vae = VAE(latent_dim=latent_dim).to(device)
    optimizer = optim.Adam(vae.parameters(), lr=lr)
    ckpt = f"data/model_checkpoints/vae_trial{trial}.pt"

    # --- Training with Early Stopping ---
    # NOTE: early stopping monitors the running *training* loss; there is no
    # held-out validation split in this notebook.
    best_loss = float('inf')
    patience, patience_counter = 3, 0
    eval_interval = 100
    step = 0
    # loss_function sums over the whole batch, so the monitored value is
    # normalised by the number of samples seen. Otherwise trials with a larger
    # batch_size would report proportionally larger losses and the comparison
    # across trials below would be biased towards small batches.
    window_loss, window_samples = 0.0, 0

    vae.train()
    while patience_counter < patience:
        for data, _ in loader:
            data = data.to(device)
            optimizer.zero_grad()
            recon_batch, mu, logvar = vae(data)
            loss = loss_function(recon_batch, data, mu, logvar)
            loss.backward()
            optimizer.step()

            window_loss += loss.item()
            window_samples += data.size(0)
            step += 1

            if step % eval_interval != 0:
                continue

            # Per-sample average over the last eval_interval steps
            avg_loss = window_loss / window_samples
            window_loss, window_samples = 0.0, 0
            logger.info(f"Trial {trial} - Step {step} - Avg Loss: {avg_loss:.4f}")

            if avg_loss < best_loss:
                best_loss = avg_loss
                patience_counter = 0
                torch.save(vae.state_dict(), ckpt)
                logger.info(f"New best model saved: {ckpt}")
            else:
                patience_counter += 1
                logger.warning(f"No improvement. Patience {patience_counter}/{patience}")
                if patience_counter >= patience:
                    # Leave the epoch; the while condition ends the trial.
                    break

    logger.info(f"Trial {trial} completed. Best avg loss: {best_loss:.4f}")
    results.append((latent_dim, lr, batch_size, best_loss))

# After all trials: report the configuration with the lowest loss
best = min(results, key=lambda x: x[3])
logger.info(f"Best config: latent_dim={best[0]}, lr={best[1]:.2e}, batch_size={best[2]}, loss={best[3]:.4f}")

In [None]:
# Every entry of results is a (latent_dim, lr, batch_size, loss) tuple.
# Take the entry with the smallest loss and unpack it into the names that
# the following cells use.
best_trial = min(results, key=lambda record: record[3])
latent_dim, lr, batch_size, best_loss = best_trial

summary_message = (
    f"Random search completed. Best config: "
    f"latent_dim={latent_dim}, lr={lr}, batch_size={batch_size}, loss={best_loss:.4f}"
)
logger.info(summary_message)

In [None]:
# Restore the checkpoint of the winning trial. Trials are numbered from 1,
# so the trial number is the position in results plus one.
idx = results.index(best_trial) + 1  # trial number
print(f"Loading best model for latent_dim={latent_dim}, lr={lr}, batch_size={batch_size}")
vae = VAE(latent_dim=latent_dim).to(device)
vae.load_state_dict(torch.load(f"data/model_checkpoints/vae_trial{idx}.pt", map_location=device))
vae.eval()

# Collect posterior means (and the matching input images) for visualisation.
# Whole batches are gathered until at least num_vis_samples images were seen.
vis_batch_size = 512
num_vis_samples = 5000
dataloader_vis = DataLoader(dataset, batch_size=vis_batch_size, shuffle=True)
latent_vectors, image_samples = [], []
num_collected = 0
with torch.no_grad():
    for data, _ in dataloader_vis:
        data = data.to(device)
        _, mu, _ = vae(data)
        latent_vectors.append(mu.cpu().numpy())
        image_samples.append(data.cpu().numpy())
        # Count real samples instead of assuming every batch is full
        num_collected += data.size(0)
        if num_collected >= num_vis_samples:
            break
latent_vectors = np.concatenate(latent_vectors, axis=0)
image_samples = np.concatenate(image_samples, axis=0)
print(f"Collected latent vectors shape: {latent_vectors.shape}")

In [None]:
# Project the collected latent vectors to two dimensions with t-SNE
print("🔄 Reduziere die Dimensionen für Visualisierung...")
tsne = TSNE(random_state=42, n_components=2)
latent_2d = tsne.fit_transform(latent_vectors)
x_coords, y_coords = latent_2d[:, 0], latent_2d[:, 1]

# Scatter plot of the embedded points (default colour, small markers)
fig, ax = plt.subplots(figsize=(8, 6))
ax.scatter(x_coords, y_coords, s=5, alpha=0.6)
ax.set_title("Visualisierung des latenten Raumes mit t-SNE")
ax.set_xlabel("Dimension 1")
ax.set_ylabel("Dimension 2")
ax.grid(True)
plt.show()

In [None]:
# Sample 16 latent codes from the standard-normal prior and decode them.
# Autograd is disabled: the decoder output is only displayed, and a tensor
# that requires grad (or lives on GPU/MPS) cannot be converted to NumPy.
with torch.no_grad():
    z = torch.randn(16, latent_dim).to(device)
    generated_images = vae.decoder(z)

# Inverse of the input normalisation: map [-1, 1] back to [0, 1]. The clamp
# guards against float rounding just outside the range imshow accepts, and
# the result is moved to host memory as a NumPy array for matplotlib.
gen_imgs = (generated_images * 0.5 + 0.5).clamp(0, 1).cpu().numpy()

# Show the samples in a 4x4 grid
fig, axes = plt.subplots(4, 4, figsize=(8, 8))
for ax, img in zip(axes.flat, gen_imgs):
    ax.imshow(np.transpose(img, (1, 2, 0)))  # (C,H,W) -> (H,W,C)
    ax.axis('off')

plt.tight_layout()
plt.show()