In [None]:
import os
from PIL import Image
import pandas as pd

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import torchvision
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms

from tqdm import tqdm

In [None]:
class CelebADataset(Dataset):
    """Image dataset that pairs each image file with a vector of 0/1 attributes.

    The attribute file is parsed as whitespace-separated text. Its first line is
    skipped, the next line supplies the attribute (column) names, and the row
    index of the resulting frame is used as the image file name (presumably the
    standard CelebA ``list_attr_celeba.txt`` layout -- confirm for other files).
    Every -1 in the table is replaced by 0 before the values are stored as
    ``float32``.

    Args:
        image_dir: Directory that contains the image files.
        attr_path: Path to the attribute annotation file.
        transform: Optional callable applied to each loaded RGB PIL image.
        selected_attrs: Optional list of attribute names to keep, in the given
            order. ``None`` keeps every attribute column.

    Raises:
        ValueError: If ``selected_attrs`` contains a name that is not a column
            of the attribute file.
    """

    def __init__(self, image_dir, attr_path, transform=None, selected_attrs=None):
        self.image_dir = image_dir
        self.attr_path = attr_path
        self.transform = transform

        # sep=r"\s+" is the supported spelling of the deprecated
        # delim_whitespace=True and parses the file identically.
        df = pd.read_csv(attr_path, sep=r"\s+", skiprows=1)
        df = df.replace(-1, 0)  # convert -1/1 to 0/1
        self.attr_names = df.columns.tolist()
        self.image_names = df.index.tolist()

        if selected_attrs is not None:
            unknown = [attr for attr in selected_attrs if attr not in self.attr_names]
            if unknown:
                # list.index() would also raise ValueError, but without saying
                # which names are valid.
                raise ValueError(
                    f"Unknown attribute(s) {unknown}; available attributes: {self.attr_names}"
                )
            self.attr_indices = [self.attr_names.index(attr) for attr in selected_attrs]
        else:
            self.attr_indices = list(range(len(self.attr_names)))

        self.attrs = df.iloc[:, self.attr_indices].values.astype("float32")

    def __len__(self):
        """Return the number of annotated images."""
        return len(self.image_names)

    def __getitem__(self, idx):
        """Load the ``idx``-th image and return ``(image, label)``.

        ``image`` is the RGB image after ``transform`` (if any) has been applied;
        ``label`` is a float32 tensor holding the selected attributes.
        """
        img_name = self.image_names[idx]
        img_path = os.path.join(self.image_dir, img_name)

        # The context manager closes the underlying file as soon as the pixel
        # data has been copied into the converted image.
        with Image.open(img_path) as img:
            image = img.convert("RGB")

        if self.transform:
            image = self.transform(image)

        label = torch.tensor(self.attrs[idx])
        return image, label

In [None]:
# Preprocessing pipeline: centre-crop to 178x178, downsample to 64x64, then
# convert to a float tensor.
# ToTensor() already rescales 8-bit pixel values from [0, 255] to [0.0, 1.0],
# so no further division by 255 is applied; dividing again would squeeze the
# inputs into [0, 1/255].
transform = transforms.Compose(
    [
        transforms.CenterCrop(178),
        transforms.Resize((64, 64)),
        transforms.ToTensor(),  # PIL image -> float tensor in [0, 1]
    ]
)

# Build the dataset from the image folder and the attribute file;
# selected_attrs=None keeps every attribute column.
dataset = CelebADataset(
    "../../data/celeba/celeba/img_align_celeba",
    "../../data/celeba/celeba/list_attr_celeba.txt",
    transform=transform,
    selected_attrs=None,
)

# Shuffled mini-batches of 64, loaded in the main process (num_workers=0).
dataloader = DataLoader(dataset, batch_size=64, shuffle=True, num_workers=0)

In [None]:
# Pull a single batch (if the loader yields one) and show its tensor shapes.
first_batch = next(iter(dataloader), None)
if first_batch is not None:
    images, labels = first_batch
    print(images.shape)   # [64, 3, 64, 64]
    print(labels.shape)   # [64, 40]

In [None]:
class VAE(nn.Module):
    """Convolutional variational autoencoder sized for 64x64 images.

    Encoder: five stride-2 ``Conv2d -> BatchNorm2d -> LeakyReLU`` stages with
    32, 64, 128, 256 and 512 channels. Each stage halves the spatial size, so a
    64x64 input ends up as a 512x2x2 feature map, which two linear heads turn
    into the mean and log-variance of the latent Gaussian.

    Decoder: a linear layer restores the 512x2x2 map, four stride-2 transposed
    convolutions upsample it to 32x32, and ``final_layer`` upsamples once more
    to 64x64 and applies a sigmoid so outputs lie in [0, 1].

    Args:
        latent_dim: Size of the latent vector. Defaults to 128.
        in_channels: Number of channels of the input image and of the
            reconstruction. Defaults to 3.
    """

    def __init__(self, latent_dim: int = 128, in_channels: int = 3) -> None:
        super().__init__()
        encoder_dims = [32, 64, 128, 256, 512]
        self.final_dim = encoder_dims[-1]
        self.latent_dim = latent_dim
        # Spatial side length of the encoder output for a 64x64 input
        # (64 halved five times).
        self.bottleneck_size = 2
        flat_dim = self.final_dim * self.bottleneck_size * self.bottleneck_size

        modules = []
        prev_channels = in_channels
        for hdim in encoder_dims:
            modules.append(
                nn.Sequential(
                    nn.Conv2d(prev_channels, hdim, kernel_size=3, stride=2, padding=1),
                    nn.BatchNorm2d(hdim),
                    nn.LeakyReLU(),
                )
            )
            prev_channels = hdim

        self.encoder = nn.Sequential(*modules)
        self.fc_mu = nn.Linear(flat_dim, latent_dim)
        self.fc_logvar = nn.Linear(flat_dim, latent_dim)

        self.decoder_input = nn.Linear(latent_dim, flat_dim)
        decoder_dims = encoder_dims[::-1]
        modules = []
        for i in range(len(decoder_dims) - 1):
            modules.append(
                nn.Sequential(
                    # kernel 3 / stride 2 / padding 1 / output_padding 1
                    # exactly doubles the spatial size.
                    nn.ConvTranspose2d(
                        decoder_dims[i],
                        decoder_dims[i + 1],
                        kernel_size=3,
                        stride=2,
                        padding=1,
                        output_padding=1
                    ),
                    nn.BatchNorm2d(decoder_dims[i + 1]),
                    nn.LeakyReLU(),
                )
            )

        self.decoder = nn.Sequential(*modules)

        self.final_layer = nn.Sequential(
            nn.ConvTranspose2d(
                decoder_dims[-1],
                decoder_dims[-1],
                kernel_size=3,
                stride=2,
                padding=1,
                output_padding=1
            ),
            nn.BatchNorm2d(decoder_dims[-1]),
            nn.LeakyReLU(),
            nn.Conv2d(decoder_dims[-1], in_channels, kernel_size=3, stride=1, padding=1),
            nn.Sigmoid()  # Output in [0, 1] range
        )

    def encode(self, x):
        """Return ``(mu, logvar)`` of the latent Gaussian for a batch of images."""
        result = self.encoder(x)
        result = torch.flatten(result, start_dim=1)
        mu = self.fc_mu(result)
        logvar = self.fc_logvar(result)
        return mu, logvar

    def reparameterize(self, mu, logvar):
        """Sample ``z = mu + eps * std`` with ``eps ~ N(0, I)``.

        Writing the sample this way keeps it differentiable with respect to
        ``mu`` and ``logvar``. Noise is drawn on every call, regardless of the
        module's train/eval mode.
        """
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z):
        """Map latent vectors of shape ``(batch, latent_dim)`` to images in [0, 1]."""
        result = self.decoder_input(z)
        result = result.view(-1, self.final_dim, self.bottleneck_size, self.bottleneck_size)
        result = self.decoder(result)
        return self.final_layer(result)

    def forward(self, x):
        """Encode ``x``, sample a latent, and return ``(reconstruction, mu, logvar)``."""
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        return self.decode(z), mu, logvar


In [None]:
def loss_function(recon_x, x, mu, logvar, kld_weight=0.00025):
    """Return the VAE training loss: reconstruction MSE plus a weighted KL term.

    Both terms use mean reduction. The MSE is averaged over every element of
    the batch, and the KL divergence between N(mu, exp(logvar)) and N(0, I) is
    averaged over both the batch and the latent dimensions (not summed over
    the latent dimensions), so ``kld_weight`` scales a per-element average.

    Args:
        recon_x: Reconstructed batch, same shape as ``x``.
        x: Original batch.
        mu: Latent means.
        logvar: Latent log-variances, same shape as ``mu``.
        kld_weight: Multiplier applied to the KL term. Defaults to 0.00025.

    Returns:
        Scalar tensor ``MSE + kld_weight * KLD``.
    """
    recon_loss = F.mse_loss(recon_x, x)
    kld = -0.5 * torch.mean(1 + logvar - mu.pow(2) - logvar.exp())
    return recon_loss + kld_weight * kld

In [None]:
def train_vae(model, dataloader, optimizer, device, num_epochs):
    """Train ``model`` for ``num_epochs`` passes over ``dataloader``.

    The model is switched to training mode once, then every batch is moved to
    ``device``, pushed through the model, scored with ``loss_function`` and used
    for one optimizer step. The second element of each batch is ignored.

    Args:
        model: Module whose forward pass returns ``(reconstruction, mu, logvar)``.
        dataloader: Iterable yielding ``(inputs, labels)`` batches.
        optimizer: Optimizer that updates ``model``'s parameters.
        device: Device the input batches are moved to; the model is expected
            to live there already.
        num_epochs: Number of passes over ``dataloader``.

    Returns:
        list[float]: The loss of every processed batch, in training order.
    """
    model.train()
    losses = []

    for epoch in range(1, num_epochs + 1):
        total_loss = 0
        batch_pbar = tqdm(dataloader, desc=f"Epoch {epoch}", leave=False)

        for x, _ in batch_pbar:
            x = x.to(device)

            optimizer.zero_grad()
            recon_x, mu, logvar = model(x)
            loss = loss_function(recon_x, x, mu, logvar)
            loss.backward()
            optimizer.step()

            loss_value = loss.item()
            losses.append(loss_value)
            total_loss += loss_value

            batch_pbar.set_postfix(loss=loss_value)

        # tqdm.write keeps the message from clobbering an active progress bar.
        tqdm.write(f"[Epoch {epoch}] Total Loss: {total_loss:.4f}")

    # Hand the per-batch history back to the caller; without this the function
    # returned None and the collected losses were lost.
    return losses

In [None]:
def test_vae(model, dataloader, device):
    """Evaluate ``model`` on ``dataloader`` and report the average loss per sample.

    The model is put in eval mode (and left there) and gradients are disabled.
    ``loss_function`` returns a mean over each batch, so every batch loss is
    weighted by its batch size before being divided by the number of samples
    seen; this stays exact when the last batch is smaller than the others.

    Args:
        model: Module whose forward pass returns ``(reconstruction, mu, logvar)``.
        dataloader: Iterable yielding ``(inputs, labels)`` batches.
        device: Device the input batches are moved to.

    Returns:
        float: The sample-weighted average loss.

    Raises:
        ValueError: If ``dataloader`` yields no samples.
    """
    model.eval()
    weighted_loss = 0.0
    num_samples = 0

    with torch.no_grad():
        for x, _ in tqdm(dataloader, desc='Testing'):
            x = x.to(device)
            recon_x, mu, logvar = model(x)
            loss = loss_function(recon_x, x, mu, logvar)

            batch_size = x.size(0)
            # Undo the per-batch mean so batches of different sizes are
            # weighted correctly.
            weighted_loss += loss.item() * batch_size
            num_samples += batch_size

    if num_samples == 0:
        raise ValueError("dataloader yielded no samples to evaluate")

    avg_loss = weighted_loss / num_samples
    # Four decimals, matching the training log; two decimals would round
    # typical loss values to 0.00.
    tqdm.write(f"[Test] Average Loss: {avg_loss:.4f}")
    return avg_loss

In [None]:
# Instantiate the VAE, place it on the CPU, and optimise all of its
# parameters with Adam at a learning rate of 1e-3.
model = VAE()
model = model.to('cpu')
optimizer = optim.Adam(params=model.parameters(), lr=1e-3)

In [None]:
# Run one training epoch on the CPU and keep whatever train_vae returns.
losses = train_vae(
    model=model,
    dataloader=dataloader,
    optimizer=optimizer,
    device='cpu',
    num_epochs=1,
)