In [None]:
# --- CELL 1: imports & utils ---
import math
import multiprocessing
import os
import random
from pathlib import Path

import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torch.utils.data import DataLoader
from torchvision import transforms, datasets, utils
from tqdm import tqdm

%matplotlib inline

# Seed the global RNGs of Python's `random` module and of PyTorch so that
# weight initialisation, DataLoader shuffling and latent sampling repeat
# identically from run to run.
seed = 42
for _seed_fn in (random.seed, torch.manual_seed):
    _seed_fn(seed)

# utility to show image grid
def show_tensor_images(img_tensor, title=None):
    """Display a batch of images as one grid with 8 images per row.

    Args:
        img_tensor: image batch shaped (B, C, H, W). Values may be in [0, 1]
            or normalized; every image is min-max rescaled independently
            before display (``normalize=True, scale_each=True``).
        title: optional figure title; nothing is drawn when it is falsy.
    """
    grid = utils.make_grid(
        img_tensor.detach().cpu(),
        nrow=8,
        padding=2,
        normalize=True,
        scale_each=True,
    )
    fig, ax = plt.subplots(figsize=(8, 8))
    ax.axis('off')
    if title:
        ax.set_title(title)
    # make_grid yields channels-first (C, H, W); imshow wants (H, W, C).
    ax.imshow(grid.permute(1, 2, 0))
    plt.show()

In [None]:
# --- CELL 2: configuration ---
# IMPORTANT: point this at the local folder that contains the CelebA images.
# It can be overridden without editing the notebook by setting the CELEBA_DIR
# environment variable, e.g. CELEBA_DIR=/home/user/datasets/celeba/img_align_celeba
# NOTE(review): the fallback below is a machine-specific absolute path; replace it
# with your own location (or rely on CELEBA_DIR) before sharing the notebook.
data_dir = os.environ.get(
    'CELEBA_DIR',
    r"C:\Users\GANESH\Downloads\archive\img_align_celeba\img_align_celeba",
)

# training hyperparameters
img_size = 64  # images are resized/cropped to img_size x img_size; the VAE's 4x4 bottleneck assumes 64
batch_size = 128
num_epochs = 1  # kept at 1 so a full run quickly produces one set of generated samples
learning_rate = 1e-3
latent_dim = 128
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('Device:', device)

In [None]:
# --- CELL 3: dataset & dataloader ---
# We assume images exist directly in data_dir, or in a subfolder. If they are in a CSV or different layout, adapt.
# Resize the SHORTER side to img_size (keeping the aspect ratio), then cut out
# the central img_size x img_size square. Resizing straight to
# (img_size, img_size) would distort non-square images and turn the
# CenterCrop into a no-op.
transform = transforms.Compose([
    transforms.Resize(img_size),
    transforms.CenterCrop(img_size),
    transforms.ToTensor(),
])

# Some CelebA downloads nest the images in an extra `img_align_celeba` folder;
# descend into it when it exists, otherwise read images from data_dir itself.
img_root = (
    os.path.join(data_dir, 'img_align_celeba')
    if os.path.isdir(os.path.join(data_dir, 'img_align_celeba'))
    else data_dir
)
print('Using images from:', img_root)

# Simple dataset loader
from PIL import Image

class SimpleImageDataset(torch.utils.data.Dataset):
    """Unlabelled dataset over every .jpg/.jpeg/.png file found below ``root``.

    Files are discovered recursively and sorted, so the index -> file mapping
    is stable across runs and filesystems (``rglob`` yields entries in no
    guaranteed order, which would defeat the notebook's fixed seed).

    Args:
        root: directory to search (recursively) for images.
        transform: optional callable applied to each RGB ``PIL.Image``.

    Each item is ``(image, 0)``: the constant 0 is a dummy target so batches
    unpack as ``(images, _)``.
    """

    IMAGE_SUFFIXES = {'.jpg', '.jpeg', '.png'}

    def __init__(self, root, transform=None):
        self.root = Path(root)
        self.paths = sorted(
            p for p in self.root.rglob('*')
            # is_file() skips directories that merely have an image-like suffix.
            if p.suffix.lower() in self.IMAGE_SUFFIXES and p.is_file()
        )
        self.transform = transform

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, idx):
        p = self.paths[idx]
        # Close the file as soon as it is decoded; without the context manager
        # the handle lingers until garbage collection, which can exhaust the
        # open-file limit when iterating over a large dataset.
        with Image.open(p) as raw:
            img = raw.convert('RGB')
        if self.transform:
            img = self.transform(img)
        return img, 0

# Initialize dataset and dataloader
ds = SimpleImageDataset(img_root, transform=transform)
if len(ds) == 0:
    # Fail early with an actionable message instead of the sampler's opaque
    # "num_samples should be a positive integer" error.
    raise FileNotFoundError(
        f'No .jpg/.jpeg/.png images found under {img_root!r}; check data_dir in the configuration cell.'
    )

# Worker processes started with "spawn"/"forkserver" (the default on Windows and
# macOS) cannot unpickle a Dataset class defined inside a notebook, so only use
# background workers when the multiprocessing start method is "fork".
num_workers = 4 if multiprocessing.get_start_method() == 'fork' else 0
dataloader = DataLoader(
    ds,
    batch_size=batch_size,
    shuffle=True,
    num_workers=num_workers,
    pin_memory=(device.type == 'cuda'),  # pinned memory only speeds up host->GPU copies
)

print('Number of images in dataset:', len(dataloader.dataset))

batch, _ = next(iter(dataloader))
show_tensor_images(batch[:32], title='Sample training images')


In [None]:
# --- CELL 4: VAE model ---
class ConvVAE(nn.Module):
    """Convolutional variational autoencoder for 64x64 images.

    Args:
        image_channels: channels of the input / reconstructed images (3 for RGB).
        hidden_dim: channel count of the deepest 4x4 feature map; the latent
            projection layers operate on ``hidden_dim * 4 * 4`` flattened features.
        z_dim: dimensionality of the latent space.

    ``forward(x)`` expects ``x`` shaped (B, image_channels, 64, 64) and returns
    ``(x_recon, mu, logvar)``: the reconstruction (in [0, 1] because of the
    final Sigmoid) and the mean / log-variance of the diagonal Gaussian
    posterior q(z|x), each shaped (B, z_dim).
    """

    def __init__(self, image_channels=3, hidden_dim=256, z_dim=128):
        super().__init__()
        self.hidden_dim = hidden_dim
        flat_dim = hidden_dim * 4 * 4
        # Four stride-2 convolutions halve the resolution 64 -> 32 -> 16 -> 8 -> 4,
        # mirroring the four stride-2 transposed convolutions of the decoder.
        self.encoder = nn.Sequential(
            nn.Conv2d(image_channels, 32, 4, 2, 1),
            nn.BatchNorm2d(32),
            nn.ReLU(True),
            nn.Conv2d(32, 64, 4, 2, 1),
            nn.BatchNorm2d(64),
            nn.ReLU(True),
            nn.Conv2d(64, 128, 4, 2, 1),
            nn.BatchNorm2d(128),
            nn.ReLU(True),
            nn.Conv2d(128, hidden_dim, 4, 2, 1),
            nn.BatchNorm2d(hidden_dim),
            nn.ReLU(True),
        )
        self.fc_mu = nn.Linear(flat_dim, z_dim)
        self.fc_logvar = nn.Linear(flat_dim, z_dim)
        self.fc_dec = nn.Linear(z_dim, flat_dim)
        # Four stride-2 transposed convolutions: 4 -> 8 -> 16 -> 32 -> 64.
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(hidden_dim, 128, 4, 2, 1),
            nn.BatchNorm2d(128),
            nn.ReLU(True),
            nn.ConvTranspose2d(128, 64, 4, 2, 1),
            nn.BatchNorm2d(64),
            nn.ReLU(True),
            nn.ConvTranspose2d(64, 32, 4, 2, 1),
            nn.BatchNorm2d(32),
            nn.ReLU(True),
            nn.ConvTranspose2d(32, image_channels, 4, 2, 1),
            nn.Sigmoid(),  # pixel values in [0, 1], as required by BCE loss
        )

    def encode(self, x):
        """Return ``(mu, logvar)`` of q(z|x) for an image batch ``x``."""
        h = self.encoder(x)
        h = h.view(h.size(0), -1)
        return self.fc_mu(h), self.fc_logvar(h)

    def reparameterize(self, mu, logvar):
        """Sample z = mu + sigma * eps with eps ~ N(0, I) (reparameterization trick)."""
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z):
        """Map latent vectors (B, z_dim) to images (B, image_channels, 64, 64)."""
        h = self.fc_dec(z)
        h = h.view(h.size(0), self.hidden_dim, 4, 4)
        return self.decoder(h)

    def forward(self, x):
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        x_recon = self.decode(z)
        return x_recon, mu, logvar


# Build the VAE on the selected device, an Adam optimizer over its parameters,
# and a summed (not averaged) binary cross-entropy reconstruction criterion.
model = ConvVAE(
    image_channels=3,
    hidden_dim=256,
    z_dim=latent_dim,
)
model = model.to(device)
optimizer = torch.optim.Adam(params=model.parameters(), lr=learning_rate)
bce_loss = nn.BCELoss(reduction='sum')

In [None]:
# --- CELL 5: training loop ---
# Train for `num_epochs` epochs by minimising the negative ELBO: summed BCE
# reconstruction loss plus the closed-form KL divergence to N(0, I).
# After each epoch, decode 64 random latent vectors, display them and save the
# grid to `save_dir`.
model.train()
save_dir = Path('./vae_outputs')
save_dir.mkdir(exist_ok=True)

for epoch in range(1, num_epochs + 1):
    train_loss = 0.0
    n_seen = 0  # images processed so far this epoch (the last batch may be smaller)
    pbar = tqdm(dataloader, total=len(dataloader), desc=f'Epoch {epoch}/{num_epochs}')
    for imgs, _ in pbar:
        imgs = imgs.to(device)
        optimizer.zero_grad()
        recon, mu, logvar = model(imgs)
        recon_loss = bce_loss(recon, imgs)
        # KL(q(z|x) || N(0, I)) for a diagonal Gaussian, summed over the batch.
        kl = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
        loss = recon_loss + kl
        loss.backward()
        optimizer.step()
        train_loss += loss.item()
        n_seen += imgs.size(0)
        pbar.set_postfix({'avg_loss': train_loss / n_seen})
    avg_loss = train_loss / len(dataloader.dataset)
    print(f'Epoch {epoch} Average loss: {avg_loss:.4f}')

    # Sample from the prior with BatchNorm in inference mode, then restore training mode.
    model.eval()
    with torch.no_grad():
        z = torch.randn(64, latent_dim, device=device)
        samples = model.decode(z)
    show_tensor_images(samples, title=f'Generated samples (epoch {epoch})')
    # save_image builds the grid itself; passing a pre-built grid would wrap it
    # in a second make_grid call and add an extra border.
    torchvision.utils.save_image(
        samples.cpu(),
        save_dir / f'generated_epoch_{epoch}.png',
        nrow=8,
        normalize=True,
    )
    model.train()

print('Training finished. Saved outputs in:', save_dir)