In [None]:
# Make Google Drive visible under /content/drive so the dataset archive and the
# output/checkpoint folders under MyDrive can be read and written.
from google.colab import drive

drive.mount(mountpoint='/content/drive')


Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
# Unpack the UTKFace archive from Drive onto the local Colab disk.
# A marker file is written only after a successful extraction, which makes this
# cell idempotent: re-running it (e.g. Restart & Run All) skips the slow unzip,
# while an interrupted extraction (no marker yet) is simply redone.
# Delete the marker to force a fresh extraction (e.g. after replacing the zip).
import zipfile
import os

zip_path = '/content/drive/MyDrive/Data/UTKFace.zip'
extract_to = '/content/images'
extract_marker = os.path.join(extract_to, '.extracted_ok')

if os.path.exists(extract_marker):
    print(f"{zip_path} already extracted to {extract_to}; skipping.")
else:
    with zipfile.ZipFile(zip_path, 'r') as zip_ref:
        zip_ref.extractall(extract_to)
    # Only reached if extractall finished without raising.
    with open(extract_marker, 'w'):
        pass


In [None]:
import os
import glob
import random
import math
from PIL import Image
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
import torchvision.utils as vutils

In [None]:
# Configuration and Hyperparameters

img_size = 128
batch_size = 64
num_epochs = 100
lr = 0.0002
beta1 = 0.5
lambda_cls = 1    # weight of the age-classification losses
lambda_rec = 10   # weight of the cycle-reconstruction (L1) loss

# Seed every RNG the notebook draws from: Python `random`, NumPy, and torch
# (weight initialisation, DataLoader shuffling, random target sampling), so that
# Restart & Run All reproduces the same run. torch.manual_seed also seeds CUDA.
seed = 42
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)

# Paths of the datasets and output directories
data_root = "/content/images/UTKFace"
output_dir = "/content/drive/MyDrive/UTKFace_Outputs"
ckpt_dir   = "/content/drive/MyDrive/UTKFace_Outputs"

os.makedirs(output_dir, exist_ok=True)
os.makedirs(ckpt_dir, exist_ok=True)

In [None]:
# Dataset Definition for UTKFace

class UTKFaceDataset(Dataset):
    """UTKFace images labelled with a 10-way age group parsed from the filename.

    UTKFace filenames start with the subject's age (``"26_0_1_....jpg"``); only
    that leading ``_``-separated field is used. Ages map to groups as
    ``(age - 1) // 10`` clamped to ``[0, 9]`` (1-10 -> 0, 11-20 -> 1, ...,
    91+ -> 9); age 0 is treated as 1.

    Items are ``(image, one_hot, age_group)``: the (optionally transformed) RGB
    image, a float one-hot vector of length ``NUM_AGE_GROUPS`` and the int index.

    Args:
        root_dir: directory containing the ``.jpg`` / ``.png`` face images.
        transform: optional callable applied to the RGB ``PIL.Image``.

    Raises:
        ValueError: if a filename does not start with an integer age.
    """

    NUM_AGE_GROUPS = 10

    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform

        # Sorted so the index order -- and therefore any seeded split built on
        # top of it -- does not depend on the filesystem's listdir order.
        fnames = sorted(
            fname for fname in os.listdir(root_dir)
            if fname.lower().endswith(('.jpg', '.png'))
        )
        self.image_paths = [os.path.join(root_dir, fname) for fname in fnames]
        self.labels = [self._age_to_group(self._parse_age(fname)) for fname in fnames]

    @staticmethod
    def _parse_age(fname):
        """Return the integer age encoded as the first ``_``-separated field of ``fname``."""
        head = fname.split('_', 1)[0]
        try:
            return int(head)
        except ValueError as err:
            raise ValueError(f"Cannot parse age from UTKFace filename {fname!r}") from err

    @classmethod
    def _age_to_group(cls, age):
        """Map an age in years to a group index in ``[0, NUM_AGE_GROUPS - 1]``."""
        age = max(age, 1)  # age 0 shares group 0 with ages 1-10
        return min((age - 1) // 10, cls.NUM_AGE_GROUPS - 1)

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        img_path = self.image_paths[idx]
        age_group = self.labels[idx]

        # The context manager closes the file deterministically; convert()
        # returns a new, fully loaded image that outlives the with-block.
        with Image.open(img_path) as src:
            img = src.convert('RGB')
        if self.transform:
            img = self.transform(img)
        one_hot = F.one_hot(torch.tensor(age_group, dtype=torch.long),
                            num_classes=self.NUM_AGE_GROUPS).float()
        return img, one_hot, age_group


# Preprocessing: center-crop to at most 200 px, resize to img_size x img_size,
# convert to a tensor and rescale each channel from [0, 1] to [-1, 1].
transform_ops = transforms.Compose(
    [
        transforms.CenterCrop(min(img_size, 200)),
        transforms.Resize((img_size, img_size)),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5,) * 3, std=(0.5,) * 3),
    ]
)


full_dataset = UTKFaceDataset(data_root, transform=transform_ops)

# Seeded 90/10 train/validation split over dataset indices (reseeds the
# global `random` module as a side effect).
random.seed(42)
indices = list(range(len(full_dataset)))
random.shuffle(indices)
split_idx = int(0.9 * len(indices))
train_indices, val_indices = indices[:split_idx], indices[split_idx:]
train_subset = torch.utils.data.Subset(dataset=full_dataset, indices=train_indices)
val_subset = torch.utils.data.Subset(dataset=full_dataset, indices=val_indices)

# Training batches are shuffled and the ragged last batch dropped; validation
# keeps a fixed order so its first batch is stable for monitoring.
train_loader = DataLoader(
    train_subset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=2,
    drop_last=True,
)
val_loader = DataLoader(val_subset, batch_size=8, shuffle=False, num_workers=1)

In [None]:
# For monitoring, take a fixed set of sample images from the validation split.
# val_loader is not shuffled, so its first batch is the same on every run.
# A DataLoader never yields an empty batch -- an empty split makes next() raise
# StopIteration instead -- so fall back to the training loader in that case.

num_monitor_samples = 5

sample_batch = next(iter(val_loader), None)
if sample_batch is None:
    sample_batch = next(iter(train_loader))
sample_images, sample_labels, sample_age_groups = sample_batch

sample_images = sample_images[:num_monitor_samples]
sample_labels = sample_labels[:num_monitor_samples]
sample_age_groups = sample_age_groups[:num_monitor_samples]

In [None]:
# Model Definitions

# Generator Architecture
class Generator(nn.Module):
    """Age-conditioned image-to-image generator: encoder -> 6 residual blocks -> decoder.

    The target-label vector is tiled over the image plane and concatenated to
    the image along the channel axis, so the first convolution sees
    ``img_channels + label_dim`` channels. Two stride-2 convolutions halve the
    resolution twice and two stride-2 transposed convolutions double it twice,
    so the output keeps the input's H and W whenever both are divisible by 4.

    NOTE: attribute names and construction order define the ``state_dict`` keys
    (e.g. ``down1.weight``, ``res_blocks.0.0.weight``) saved in the
    ``generator_epoch_*.pth`` checkpoints; renaming or reordering layers breaks
    loading those files.

    Args:
        img_channels: channels of the input and output image.
        label_dim: length of the condition vector (number of age groups).
        feature_dim: base width ``nf``; the residual bottleneck runs at ``4 * nf``.
    """
    def __init__(self, img_channels=3, label_dim=10, feature_dim=64):
        super(Generator, self).__init__()
        self.label_dim = label_dim
        # Label maps are concatenated to the image as extra input channels.
        in_channels = img_channels + label_dim
        nf = feature_dim
        # Encoder: 7x7 stem at full resolution, then two 4x4 stride-2 downsamplings.
        self.down1 = nn.Conv2d(in_channels, nf, kernel_size=7, stride=1, padding=3)
        self.down2 = nn.Conv2d(nf, nf*2, kernel_size=4, stride=2, padding=1)
        self.down3 = nn.Conv2d(nf*2, nf*4, kernel_size=4, stride=2, padding=1)

        # Six residual bodies (conv-IN-ReLU-conv-IN) at 4*nf channels; the skip
        # addition and the trailing ReLU are applied in forward().
        res_blocks = []
        res_channels = nf * 4
        for _ in range(6):
            res_blocks.append(nn.Sequential(
                nn.Conv2d(res_channels, res_channels, kernel_size=3, stride=1, padding=1),
                nn.InstanceNorm2d(res_channels, affine=False),
                nn.ReLU(inplace=True),
                nn.Conv2d(res_channels, res_channels, kernel_size=3, stride=1, padding=1),
                nn.InstanceNorm2d(res_channels, affine=False)
            ))
        self.res_blocks = nn.ModuleList(res_blocks)
        # Decoder: two 4x4 stride-2 transposed convs undo the downsampling,
        # then a 7x7 conv projects back to image channels.
        self.up1 = nn.ConvTranspose2d(res_channels, nf*2, kernel_size=4, stride=2, padding=1)
        self.up2 = nn.ConvTranspose2d(nf*2, nf, kernel_size=4, stride=2, padding=1)
        self.out_conv = nn.Conv2d(nf, img_channels, kernel_size=7, stride=1, padding=3)
        # Shared in-place ReLU used after every encoder/decoder conv and residual sum.
        self.actvn = nn.ReLU(inplace=True)
        # Squashes the final output into (-1, 1).
        self.tanh = nn.Tanh()

        self._initialize_weights()

    def _initialize_weights(self):
        """Draw conv / transposed-conv weights from N(0, 0.02) and zero their biases.

        The norm branch only touches norm layers that have affine parameters.
        Every InstanceNorm2d in this model uses ``affine=False`` (its ``weight``
        is None), so for this architecture that branch changes nothing.
        """
        for m in self.modules():
            if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d):
                nn.init.normal_(m.weight, 0.0, 0.02)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.InstanceNorm2d) or isinstance(m, nn.BatchNorm2d):
                if hasattr(m, 'weight') and m.weight is not None:
                    nn.init.normal_(m.weight, 1.0, 0.02)
                    nn.init.constant_(m.bias, 0)

    def forward(self, img, target_label):
        """Translate ``img`` toward the condition given by ``target_label``.

        Args:
            img: image batch; dims 2 and 3 are used as H and W, i.e. (N, C, H, W).
            target_label: condition indexed as ``[:, :, None, None]``, i.e.
                (N, label_dim); this notebook passes float one-hot vectors.

        Returns:
            Generated images with ``img_channels`` channels and values in (-1, 1).
        """
        # Tile the (N, label_dim) condition to (N, label_dim, H, W) and append
        # it after the image channels.
        label_maps = target_label[:, :, None, None]
        label_maps = label_maps.expand(-1, -1, img.size(2), img.size(3))
        x = torch.cat([img, label_maps], dim=1)

        x = self.actvn(self.down1(x))
        x = self.actvn(self.down2(x))
        x = self.actvn(self.down3(x))

        for res_block in self.res_blocks:
            residual = x
            out = res_block(x)
            # Post-activation residual: ReLU(body(x) + x).
            x = self.actvn(out + residual)

        x = self.actvn(self.up1(x))
        x = self.actvn(self.up2(x))
        x = self.tanh(self.out_conv(x))
        return x

# Discriminator Architecture

class Discriminator(nn.Module):
    """Two-headed critic: a real/fake logit and age-group logits from shared features.

    Five 4x4 stride-2 convolutions (LeakyReLU 0.2, widths ``base_features`` up to
    ``16 * base_features``) feed a global average pool. The pooled vector drives
    both linear heads.

    Args:
        img_channels: channels of the input image.
        base_features: width of the first convolution.
        num_classes: number of age-group logits produced by the classifier head.
    """

    def __init__(self, img_channels=3, base_features=64, num_classes=10):
        super().__init__()
        width = base_features

        # Registration order (conv1..conv5, gap, adv_out, cls_out) fixes the
        # state_dict keys and the order of init RNG draws -- keep it stable.
        self.conv1 = nn.Conv2d(img_channels, width, kernel_size=4, stride=2, padding=1)
        self.conv2 = nn.Conv2d(width, width * 2, kernel_size=4, stride=2, padding=1)
        self.conv3 = nn.Conv2d(width * 2, width * 4, kernel_size=4, stride=2, padding=1)
        self.conv4 = nn.Conv2d(width * 4, width * 8, kernel_size=4, stride=2, padding=1)
        self.conv5 = nn.Conv2d(width * 8, width * 16, kernel_size=4, stride=2, padding=1)

        self.gap = nn.AdaptiveAvgPool2d(1)

        pooled_dim = width * 16
        self.adv_out = nn.Linear(pooled_dim, 1)
        self.cls_out = nn.Linear(pooled_dim, num_classes)

        self._initialize_weights()

    def _initialize_weights(self):
        """Re-draw every conv / linear weight from N(0, 0.02) and zero its bias."""
        for module in self.modules():
            if not isinstance(module, (nn.Conv2d, nn.Linear)):
                continue
            nn.init.normal_(module.weight, 0.0, 0.02)
            if module.bias is not None:
                nn.init.constant_(module.bias, 0)

    def forward(self, img):
        """Return ``(adv_logits, cls_logits)`` with shapes (N, 1) and (N, num_classes)."""
        h = img
        for conv in (self.conv1, self.conv2, self.conv3, self.conv4, self.conv5):
            h = F.leaky_relu(conv(h), 0.2)

        pooled = self.gap(h)
        pooled = pooled.view(pooled.size(0), -1)
        return self.adv_out(pooled), self.cls_out(pooled)

In [None]:
# Instantiate models and optimizers
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# G is built before D, as before, so the weight-init RNG draws happen in the same order.
G = Generator(img_channels=3, label_dim=10, feature_dim=64).to(device)
D = Discriminator(img_channels=3, base_features=64, num_classes=10).to(device)

# Separate Adam optimizers for the two players, both using beta1 from the config cell.
optimizer_G = torch.optim.Adam(params=G.parameters(), lr=lr, betas=(beta1, 0.999))
optimizer_D = torch.optim.Adam(params=D.parameters(), lr=lr, betas=(beta1, 0.999))

# Real/fake head emits raw logits -> BCEWithLogitsLoss; age head -> CrossEntropyLoss.
adv_loss_fn = nn.BCEWithLogitsLoss()
cls_loss_fn = nn.CrossEntropyLoss()

# Target values for the adversarial loss.
real_label, fake_label = 1.0, 0.0

In [None]:
# Report the training device. A 100-epoch GAN of this size is impractically slow
# on CPU, so say so loudly instead of silently starting a run that won't finish.
print(device)
if device.type == "cpu":
    print("WARNING: CUDA is not available - training will run on the CPU and be very slow. "
          "In Colab, select a GPU via Runtime > Change runtime type and re-run.")

cpu


In [None]:
# Training Loop
#
# D is trained to separate real from generated images and to classify the age
# group of real images. G is trained to fool D, to make D classify its output as
# the requested target group, and to reconstruct the input when its output is
# mapped back to the source group (L1 cycle loss).

NUM_AGE_GROUPS = G.label_dim


def _sample_target_groups(age_groups, num_classes):
    """Draw, per sample, a target group uniformly from the groups other than its own.

    Adding an offset in [1, num_classes - 1] modulo num_classes never returns the
    source group and reaches every other group with equal probability. It is one
    on-device call, with no per-sample ``.item()`` host sync.
    """
    offsets = torch.randint(1, num_classes, age_groups.shape, device=age_groups.device)
    return (age_groups + offsets) % num_classes


def _fixed_monitor_targets(age_groups, seed=0):
    """Choose one target group per monitoring image, once, so epoch grids are comparable.

    Young subjects (group < 5) are aged to a group in 5-9 and older ones are made
    younger (group 0-4). A private RNG avoids perturbing the global `random` stream.
    """
    rng = random.Random(seed)
    return torch.tensor(
        [rng.randint(5, 9) if int(ag) < 5 else rng.randint(0, 4) for ag in age_groups],
        dtype=torch.long,
    )


def _save_comparison_grid(epoch, target_labels):
    """Save a grid of (original, transformed) pairs for the fixed monitoring images."""
    with torch.no_grad():
        inputs = sample_images.to(device)
        outputs = G(inputs, target_labels)
        # Interleave to [orig_0, fake_0, orig_1, fake_1, ...] so nrow=2 puts one pair per row.
        pairs = torch.stack([inputs, outputs], dim=1).flatten(0, 1).cpu()
        grid = vutils.make_grid(pairs, nrow=2, normalize=True, value_range=(-1, 1))
        vutils.save_image(grid, os.path.join(output_dir, f"epoch_{epoch}_comparison.png"))


sample_target_labels = F.one_hot(
    _fixed_monitor_targets(sample_age_groups), num_classes=NUM_AGE_GROUPS
).float().to(device)

print("Starting training...")
for epoch in range(1, num_epochs+1):
    G.train()
    D.train()
    for i, (real_imgs, real_labels, real_age_groups) in enumerate(train_loader):
        real_imgs = real_imgs.to(device)
        real_labels = real_labels.to(device)         # one-hot source groups
        real_age_groups = real_age_groups.to(device) # integer source groups

        target_age_groups = _sample_target_groups(real_age_groups, NUM_AGE_GROUPS)
        target_labels = F.one_hot(target_age_groups, num_classes=NUM_AGE_GROUPS).float()

        #  Train Discriminator

        optimizer_D.zero_grad()

        adv_logits_real, cls_logits_real = D(real_imgs)
        adv_loss_real = adv_loss_fn(adv_logits_real, torch.full_like(adv_logits_real, real_label))
        # The age classifier is learned from real images only.
        cls_loss_real = cls_loss_fn(cls_logits_real, real_age_groups)

        # Fakes need no graph here: this step only updates D.
        with torch.no_grad():
            fake_imgs = G(real_imgs, target_labels)
        adv_logits_fake, _ = D(fake_imgs)
        adv_loss_fake = adv_loss_fn(adv_logits_fake, torch.full_like(adv_logits_fake, fake_label))

        d_loss = adv_loss_real + adv_loss_fake + lambda_cls * cls_loss_real
        d_loss.backward()
        optimizer_D.step()

        #  Train Generator

        optimizer_G.zero_grad()

        fake_imgs = G(real_imgs, target_labels)

        # Freeze D while it scores the fakes so G's backward pass does not
        # accumulate gradients into D's parameters.
        D.requires_grad_(False)
        adv_logits_fake, cls_logits_fake = D(fake_imgs)
        D.requires_grad_(True)

        g_adv_loss = adv_loss_fn(adv_logits_fake, torch.full_like(adv_logits_fake, real_label))
        g_cls_loss = cls_loss_fn(cls_logits_fake, target_age_groups)
        # Cycle reconstruction: translate the fake back to the source group.
        rec_imgs = G(fake_imgs, real_labels)
        rec_loss = F.l1_loss(rec_imgs, real_imgs)

        g_loss = g_adv_loss + lambda_cls * g_cls_loss + lambda_rec * rec_loss
        g_loss.backward()
        optimizer_G.step()

        if (i+1) % 50 == 0:
            print(f"Epoch [{epoch}/{num_epochs}] Batch {i+1}/{len(train_loader)} "
                  f"D_loss: {d_loss.item():.3f}  G_loss: {g_loss.item():.3f}  "
                  f"(Adv: {g_adv_loss.item():.3f}, Cls: {g_cls_loss.item():.3f}, Rec: {rec_loss.item():.3f})")

    # Checkpoints and the fixed-sample comparison grid

    G.eval()

    if epoch % 5 == 0:
        torch.save(G.state_dict(), os.path.join(ckpt_dir, f"generator_epoch_{epoch}.pth"))
        torch.save(D.state_dict(), os.path.join(ckpt_dir, f"discriminator_epoch_{epoch}.pth"))
        print(f"Saved checkpoints at epoch {epoch}")

    _save_comparison_grid(epoch, sample_target_labels)
    if epoch % 5 == 0:
        print(f"Saved sample comparison image at epoch {epoch}")

print("Training complete.")


Starting training...
Epoch [1/100] Batch 50/333 D_loss: 3.236  G_loss: 5.198  (Adv: 0.856, Cls: 2.380, Rec: 0.196)
Epoch [1/100] Batch 100/333 D_loss: 3.372  G_loss: 5.039  (Adv: 0.842, Cls: 2.554, Rec: 0.164)
Epoch [1/100] Batch 150/333 D_loss: 3.057  G_loss: 4.633  (Adv: 0.679, Cls: 2.580, Rec: 0.137)
Epoch [1/100] Batch 200/333 D_loss: 3.090  G_loss: 4.755  (Adv: 1.008, Cls: 2.460, Rec: 0.129)
Epoch [1/100] Batch 250/333 D_loss: 2.614  G_loss: 4.709  (Adv: 1.188, Cls: 2.297, Rec: 0.122)
Epoch [1/100] Batch 300/333 D_loss: 3.047  G_loss: 3.943  (Adv: 0.636, Cls: 2.092, Rec: 0.122)
Epoch [2/100] Batch 50/333 D_loss: 2.656  G_loss: 4.005  (Adv: 1.365, Cls: 1.494, Rec: 0.115)
Epoch [2/100] Batch 100/333 D_loss: 2.847  G_loss: 4.473  (Adv: 1.586, Cls: 1.548, Rec: 0.134)
Epoch [2/100] Batch 150/333 D_loss: 2.723  G_loss: 4.232  (Adv: 1.535, Cls: 1.544, Rec: 0.115)
Epoch [2/100] Batch 200/333 D_loss: 2.587  G_loss: 3.839  (Adv: 1.241, Cls: 1.485, Rec: 0.111)
Epoch [2/100] Batch 250/333 D_l

In [None]:
# Testing configuration for the single-image cell below.
# Set test_image_path to None to skip that cell; set pretrained_model_path to
# None to use the in-memory generator `G` instead of a saved checkpoint.
pretrained_model_path = "/content/drive/MyDrive/UTKFace_Outputs/generator_epoch_65.pth"
test_image_path = "/test_10_year.jpg"
# Age group 0-9: group g covers ages 10*g+1 .. 10*g+10 (group 9 is 91+), so 6 -> 61-70.
target_age_group = 6

In [None]:
# Testing on a Single Image
#
# The test image must go through the same preprocessing as the training data.
# Training images are UTKFace faces -- presumably the 200x200 aligned & cropped
# release, matching the 200 in transform_ops -- which transform_ops center-crops
# to min(img_size, 200) px. Resizing the test image straight to 128x128 made
# that CenterCrop a no-op, so the generator saw a zoomed-out face at a scale it
# was never trained on. Resizing to the UTKFace size first reproduces the
# training crop. This assumes the test photo is framed like a UTKFace crop
# (centred face, similar margins).

UTKFACE_IMG_SIZE = 200  # side length of UTKFace aligned & cropped images

if test_image_path is not None:
    G_model = Generator(img_channels=3, label_dim=10, feature_dim=64).to(device)
    if pretrained_model_path is not None:
        # The checkpoint is a plain state_dict; weights_only=True refuses
        # arbitrary pickled objects in the file.
        state_dict = torch.load(pretrained_model_path, map_location=device, weights_only=True)
        G_model.load_state_dict(state_dict)
    else:
        G_model.load_state_dict(G.state_dict())
    G_model.eval()

    with Image.open(test_image_path) as src:
        test_img = src.convert('RGB')
    test_img = test_img.resize((UTKFACE_IMG_SIZE, UTKFACE_IMG_SIZE), Image.BICUBIC)

    test_img_tensor = transform_ops(test_img).unsqueeze(0).to(device)

    # Clamp to a valid group index and build the (1, 10) one-hot condition.
    tgt_group = max(0, min(9, target_age_group))
    tgt_label = F.one_hot(torch.tensor([tgt_group], device=device), num_classes=10).float()

    with torch.no_grad():
        output_face = G_model(test_img_tensor, tgt_label)

    out_path = "transformed_face1.png"
    vutils.save_image(output_face, out_path, normalize=True, value_range=(-1, 1))
    print(f"Single image transformed. Saved result to {out_path}")


Single image transformed. Saved result to transformed_face1.png
