In [None]:
import os
import random

# ===== CONFIG =====
DATA_DIR = "data/img_align_celeba"
KEEP_COUNT = 20000
IMAGE_EXTS = (".jpg", ".jpeg", ".png")
SEED = 42  # fixes which KEEP_COUNT images survive, so the subset is reproducible

# ==================

# WARNING: this cell permanently deletes every matching image that is not
# selected below. Re-running it on an already-pruned folder deletes nothing,
# because the count check below short-circuits.

# Sort the listing: os.listdir order is arbitrary, and a seeded sample is only
# reproducible if it draws from the same ordered population every time.
files = sorted(
    f for f in os.listdir(DATA_DIR)
    if f.lower().endswith(IMAGE_EXTS)
)

total_files = len(files)

if total_files <= KEEP_COUNT:
    # A plain branch instead of exit(): in a Jupyter kernel exit() is IPython's
    # exit, which ends the kernel session rather than just stopping this cell.
    print(f"Only {total_files} files found. Nothing to delete.")
else:
    # Dedicated RNG so the selection neither depends on nor disturbs the
    # global `random` state.
    rng = random.Random(SEED)
    keep_files = set(rng.sample(files, KEEP_COUNT))

    deleted = 0
    for f in files:
        if f not in keep_files:
            os.remove(os.path.join(DATA_DIR, f))
            deleted += 1

    print("===================================")
    print(f"Initial files : {total_files}")
    print(f"Kept files    : {KEEP_COUNT}")
    print(f"Deleted files : {deleted}")
    print("===================================")


Initial files : 202599
Kept files    : 20000
Deleted files : 182599


In [2]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torchvision.utils import save_image, make_grid
import matplotlib.pyplot as plt
import numpy as np

# Run on the GPU when PyTorch can see one; later cells move models/tensors to `device`.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print("Using device:", device)


Using device: cpu


In [4]:
DATA_DIR = "data/img_align_celeba"
IMAGE_SIZE = 64
BATCH_SIZE = 128

# Resize the *shorter* side to IMAGE_SIZE and then center-crop to a square.
# Resize((IMAGE_SIZE, IMAGE_SIZE)) would squash non-square frames to 1:1.
# img_align_celeba images are portrait (178x218 in the standard release;
# confirm for your copy).
# Normalize(mean=0.5, std=0.5) maps ToTensor's [0, 1] pixels to [-1, 1], which
# matches the Tanh at the end of the VAE decoder.
transform = transforms.Compose([
    transforms.Resize(IMAGE_SIZE),
    transforms.CenterCrop(IMAGE_SIZE),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5]*3, std=[0.5]*3)
])

# ImageFolder expects root/<class_name>/<image files>. Pointing it at the parent
# of DATA_DIR makes "img_align_celeba" the single (unused) class. Any other
# sub-folder of that parent would also be loaded as a class.
dataset = datasets.ImageFolder(root=os.path.dirname(DATA_DIR), transform=transform)
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=4)

In [5]:
LATENT_DIM = 128


class VAE(nn.Module):
    """Convolutional VAE for 3x64x64 images with a LATENT_DIM-dimensional Gaussian latent.

    The encoder halves the spatial size four times (64 -> 4) and the decoder
    mirrors it back up. ``forward`` returns ``(reconstruction, mu, logvar)``.
    Reconstructions pass through Tanh, so they lie in [-1, 1].
    """

    # Shape of the encoder's final feature map. It is shared by the encoder
    # heads and the decoder input so the two sides cannot drift apart.
    FEAT_CHANNELS = 256
    FEAT_SIZE = 4
    FEAT_DIM = FEAT_CHANNELS * FEAT_SIZE * FEAT_SIZE

    def __init__(self):
        super().__init__()

        # Encoder: each Conv2d(kernel=4, stride=2, padding=1) halves H and W.
        self.encoder = nn.Sequential(
            nn.Conv2d(3, 32, 4, 2, 1),     # 64 -> 32
            nn.ReLU(),
            nn.Conv2d(32, 64, 4, 2, 1),    # 32 -> 16
            nn.ReLU(),
            nn.Conv2d(64, 128, 4, 2, 1),   # 16 -> 8
            nn.ReLU(),
            nn.Conv2d(128, 256, 4, 2, 1),  # 8 -> 4
            nn.ReLU()
        )

        self.fc_mu = nn.Linear(self.FEAT_DIM, LATENT_DIM)
        self.fc_logvar = nn.Linear(self.FEAT_DIM, LATENT_DIM)

        # Decoder: project the latent back to the 256x4x4 map, then double H and W
        # four times with ConvTranspose2d(kernel=4, stride=2, padding=1).
        self.fc_dec = nn.Linear(LATENT_DIM, self.FEAT_DIM)

        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(256, 128, 4, 2, 1),  # 4 -> 8
            nn.ReLU(),
            nn.ConvTranspose2d(128, 64, 4, 2, 1),   # 8 -> 16
            nn.ReLU(),
            nn.ConvTranspose2d(64, 32, 4, 2, 1),    # 16 -> 32
            nn.ReLU(),
            nn.ConvTranspose2d(32, 3, 4, 2, 1),     # 32 -> 64
            nn.Tanh()
        )

    def encode(self, x):
        """Map images (expected (B, 3, 64, 64)) to ``(mu, logvar)``, each (B, LATENT_DIM)."""
        enc = self.encoder(x).flatten(start_dim=1)
        return self.fc_mu(enc), self.fc_logvar(enc)

    def decode(self, z):
        """Map latents of shape (B, LATENT_DIM) to images of shape (B, 3, 64, 64) in [-1, 1]."""
        dec = self.fc_dec(z).view(-1, self.FEAT_CHANNELS, self.FEAT_SIZE, self.FEAT_SIZE)
        return self.decoder(dec)

    def reparameterize(self, mu, logvar):
        """Sample z = mu + eps * exp(0.5 * logvar), with eps ~ N(0, I).

        Gradients flow back to ``mu`` and ``logvar``.
        """
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def forward(self, x):
        """Return ``(reconstruction, mu, logvar)``.

        In training mode z is sampled via the reparameterization trick. In eval
        mode the posterior mean ``mu`` is decoded, so reconstructions after
        ``model.eval()`` are deterministic.
        """
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar) if self.training else mu
        return self.decode(z), mu, logvar


In [6]:
def vae_loss(recon, x, mu, logvar):
    """VAE training objective: reconstruction error plus KL regularizer.

    Args:
        recon: decoder output, same shape as ``x``.
        x: target images.
        mu: posterior means from the encoder.
        logvar: posterior log-variances from the encoder.

    Returns:
        Scalar tensor equal to the summed squared error between ``recon`` and ``x``
        plus the analytic KL( N(mu, exp(logvar)) || N(0, I) ). Both terms are
        summed, not averaged, over every element.
    """
    squared_error = nn.functional.mse_loss(recon, x, reduction='sum')
    kl_terms = 1 + logvar - mu.pow(2) - logvar.exp()
    kl_divergence = -0.5 * kl_terms.sum()
    return squared_error + kl_divergence


In [None]:
# Seed torch's global RNG so weight init, the DataLoader shuffle order and the
# reparameterization noise repeat across Restart & Run All. This is exact on
# CPU; some CUDA kernels are nondeterministic.
TRAIN_SEED = 42
torch.manual_seed(TRAIN_SEED)

model = VAE().to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-3)
EPOCHS = 20

for epoch in range(EPOCHS):
    model.train()
    total_loss = 0.0

    for imgs, _ in dataloader:
        imgs = imgs.to(device)
        optimizer.zero_grad()

        recon, mu, logvar = model(imgs)
        # vae_loss sums over the batch, so total_loss accumulates per-image losses.
        loss = vae_loss(recon, imgs, mu, logvar)

        loss.backward()
        optimizer.step()
        total_loss += loss.item()

    # Report the average per-image loss for the epoch.
    print(f"Epoch [{epoch+1}/{EPOCHS}] Loss: {total_loss/len(dataset):.4f}")


In [None]:
# Push one batch through the VAE. Show its first 8 images (top row) above their
# reconstructions (bottom row).
model.eval()
with torch.no_grad():
    batch, _ = next(iter(dataloader))
    batch = batch.to(device)
    reconstructed, _, _ = model(batch)

side_by_side = torch.cat([batch[:8], reconstructed[:8]])
recon_grid = make_grid(side_by_side.cpu(), nrow=8, normalize=True)

fig, ax = plt.subplots(figsize=(12, 4))
ax.imshow(recon_grid.permute(1, 2, 0).numpy())  # CHW -> HWC for matplotlib
ax.axis("off")
ax.set_title("Top: Original | Bottom: Reconstructed")
plt.show()


In [None]:
def interpolate(z1, z2, steps=10):
    """Return ``steps`` points on the straight line from ``z1`` to ``z2``.

    Point i is ``(1 - t_i) * z1 + t_i * z2``, where ``t_i`` is evenly spaced in
    [0, 1]. The first point is therefore ``z1`` and the last is ``z2``. The result
    is a Python list of tensors, each shaped like ``z1`` when ``z1`` and ``z2``
    share a shape.
    """
    path = []
    for alpha in torch.linspace(0, 1, steps):
        path.append((1 - alpha) * z1 + alpha * z2)
    return path

# Decode evenly spaced points on the straight line between the posterior means
# of two images taken from one (shuffled) batch.
model.eval()
with torch.no_grad():
    imgs, _ = next(iter(dataloader))
    imgs = imgs[:2].to(device)

    _, mu, _ = model(imgs)
    z_interp = interpolate(mu[0], mu[1])

    # Decode every interpolation step in one batched pass instead of one
    # decoder call per step.
    z_batch = torch.stack(z_interp)  # (steps, LATENT_DIM)
    decoded = model.decoder(model.fc_dec(z_batch).view(-1, 256, 4, 4))

# value_range=(-1, 1) maps the decoder's Tanh output straight to [0, 1].
# normalize=True alone rescales by this batch's own min/max, which exaggerates
# contrast in a grid made only of decoded images.
grid = make_grid(decoded.cpu(), nrow=decoded.size(0), normalize=True, value_range=(-1, 1))
fig, ax = plt.subplots(figsize=(14, 3))
ax.imshow(grid.permute(1, 2, 0).numpy())  # CHW -> HWC for matplotlib
ax.axis("off")
ax.set_title("Latent Space Interpolation")
plt.show()


In [None]:
def get_latent_vector(model, dataloader, n=100):
    """Average the encoder means (``mu``) of the first ``n`` samples drawn from ``dataloader``.

    Batches are consumed only until at least ``n`` samples have been seen, and
    the result is truncated to exactly ``n`` rows before averaging. The original
    stop test ``i * batch_size > n`` over-read by one batch and averaged
    everything it read.

    Args:
        model: module whose forward returns ``(recon, mu, logvar)``, like ``VAE``.
            Its mode (train/eval) is not changed here.
        dataloader: iterable of ``(images, labels)`` batches. Labels are ignored.
        n: number of samples to average. If the loader yields fewer, all are used.

    Returns:
        Tensor of shape ``mu.shape[1:]``: the mean latent vector.

    Raises:
        ValueError: if ``n`` is not positive or ``dataloader`` yields no batches.
    """
    if n <= 0:
        raise ValueError(f"n must be positive, got {n}")

    # Send inputs to wherever the model's weights live instead of relying on a
    # notebook-global `device`. Fall back to CPU for a parameterless model.
    first_param = next(model.parameters(), None)
    target_device = first_param.device if first_param is not None else torch.device("cpu")

    zs = []
    seen = 0
    with torch.no_grad():
        for imgs, _ in dataloader:
            _, mu, _ = model(imgs.to(target_device))
            zs.append(mu)
            seen += mu.size(0)
            if seen >= n:
                break

    if not zs:
        raise ValueError("dataloader yielded no batches")
    return torch.cat(zs)[:n].mean(dim=0)


In [None]:
# PLACEHOLDER attribute direction.
# Both means come from the same unlabelled, shuffled `dataloader`, so their
# difference is just sampling noise between two random subsets, not a "smile"
# direction.
# TODO: average over images selected by a real attribute label for each side
# (e.g. CelebA's Smiling flag).
smile_vector = get_latent_vector(model, dataloader)
neutral_vector = get_latent_vector(model, dataloader)

attribute_vector = torch.sub(smile_vector, neutral_vector)


In [None]:
# Apply attribute
model.eval()
with torch.no_grad():
    img, _ = next(iter(dataloader))
    img = img[:1].to(device)

    _, mu, _ = model(img)
    modified_z = mu + 0.8 * attribute_vector

    out = model.decoder(model.fc_dec(modified_z).view(1,256,4,4))

comparison = torch.cat([img, out])
grid = make_grid(comparison.cpu(), nrow=2, normalize=True)

plt.figure(figsize=(4,4))
plt.imshow(np.transpose(grid, (1,2,0)))
plt.axis("off")
plt.title("Original vs Attribute Modified")
plt.show()
