In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torchvision.utils import save_image
from PIL import Image
import os
from tqdm import tqdm
import numpy as np
import json

In [2]:
class AgeDataset(Dataset):

    def __init__(self, data_path, transform=None):

        self.transform = transform

        self.images = []
        self.ages = []
        
        print(f"Loading dataset from: {data_path}")

        

        # Load from JSON file
        with open(data_path, 'r') as f:
            data = json.load(f)
        
        # Check if nested structure (has "root" key)
        if "root" in data:
            data = data["root"]
        
        # Flatten the nested dictionary into lists
        for age_str, content in data.items():
            age_group = int(age_str)  # Age group is already the key

            for img_id, img_path in content.items():
                self.images.append(img_path)
                self.ages.append(age_group)

        
        print(f"Loaded {len(self.images)} image paths from JSON")
            
       
        # Print age distribution
        # age_counts = [0] * (26)
        # #print(self.ages)
        # for age in self.ages:
        #     # print(age)
        #     # print(age//10)
        #     age_counts[age] += 1
        
        # print("\nAge distribution:")
        # for i, count in enumerate(age_counts):
        #     print(f"  {i*4}-{i*4+4} years: {count} images")
        print(max(self.ages))
    
    def __len__(self):
        return len(self.images)
    
    def __getitem__(self, idx):

        # Load from path stored in JSON
        image_path = self.images[idx]
        image = np.load(image_path)
        
        
        age = self.ages[idx]
        
        # Convert numpy array to PIL Image
        # Handle different array formats
        if image.dtype == np.float32 or image.dtype == np.float64:
            # If normalized to [0, 1], scale to [0, 255]
            if image.max() <= 1.0:
                image = (image * 255).astype(np.uint8)
            else:
                image = image.astype(np.uint8)
        elif image.dtype != np.uint8:
            image = image.astype(np.uint8)
        
        # Handle channel order: (H, W, C) or (C, H, W)
        if image.ndim == 3:
            if image.shape[0] == 3 or image.shape[0] == 1:
                # (C, H, W) -> (H, W, C)
                image = np.transpose(image, (1, 2, 0))
        
        # Convert to PIL Image
        if image.shape[-1] == 1:
            image = Image.fromarray(image.squeeze(), mode='L').convert('RGB')
        else:
            image = Image.fromarray(image, mode='RGB')
        resized_pil_image = image.resize((128, 128), Image.LANCZOS)
        
        # Apply transforms
        if self.transform:
            image = self.transform(image)
        
        return image, age



In [3]:
class Generator(nn.Module):
    """
    Takes: input face image + target age
    Returns: face aged to target age
    """
    def __init__(self, num_age_classes=111):
        super(Generator, self).__init__()
        
        # ENCODER: Compress image to features
        self.encoder = nn.Sequential(
            # 128x128 -> 64x64
            nn.Conv2d(3, 64, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
            
            # 64x64 -> 32x32
            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),
            
            # 32x32 -> 16x16
            nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2, inplace=True),
            
            # 16x16 -> 8x8
            nn.Conv2d(256, 512, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2, inplace=True),
        )
        
        # AGE CONDITIONING: Add age information
        self.age_embedding = nn.Embedding(num_age_classes, 512 * 8 * 8)
        
        # DECODER: Reconstruct image with new age
        self.decoder = nn.Sequential(
            # 8x8 -> 16x16
            nn.ConvTranspose2d(512, 256, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(True),
            
            # 16x16 -> 32x32
            nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(True),
            
            # 32x32 -> 64x64
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(True),
            
            # 64x64 -> 128x128
            nn.ConvTranspose2d(64, 3, kernel_size=4, stride=2, padding=1),
            nn.Tanh()  # Output in range [-1, 1]
        )
    
    def forward(self, img, target_age):
        # Encode the face
        features = self.encoder(img)
        
        # Add age information
        batch_size = img.size(0)
        age_vec = self.age_embedding(target_age)
        age_vec = age_vec.view(batch_size, 512, 8, 8)
        #print(features.shape)
        #print(age_vec.shape)
        combined = features + age_vec
        
        # Decode with new age
        output = self.decoder(combined)
        return output

In [4]:
class Discriminator(nn.Module):
    """
    Takes: an image
    Returns: 
        - Is it real or fake? (validity)
        - What age is this person? (age classification)
    """
    def __init__(self, num_age_classes=10):
        super(Discriminator, self).__init__()
        
        # Feature extraction
        self.features = nn.Sequential(
            # 128x128 -> 64x64
            nn.Conv2d(3, 64, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
            
            # 64x64 -> 32x32
            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),
            
            # 32x32 -> 16x16
            nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2, inplace=True),
            
            # 16x16 -> 8x8
            nn.Conv2d(256, 512, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2, inplace=True),
        )
        
        # Real/Fake head
        self.validity = nn.Sequential(
            nn.Conv2d(512, 1, kernel_size=8, stride=1, padding=0),
            nn.Sigmoid()
        )
        
        # Age classification head
        self.age_classifier = nn.Sequential(
            nn.Conv2d(512, num_age_classes, kernel_size=8, stride=1, padding=0),
        )
    
    def forward(self, img):
        features = self.features(img)
        #print(features.shape)
        validity = self.validity(features).view(-1, 1)
        age_logits = self.age_classifier(features).view(-1, 111)
        return validity, age_logits


In [5]:
dataset = AgeDataset("/projects/standard/csci5561/shared/G8/data/qtk.json")

Loading dataset from: /projects/standard/csci5561/shared/G8/data/qtk.json
Loaded 66982 image paths from JSON
116


In [6]:
dataset = AgeDataset("/projects/standard/csci5561/shared/G8/data/face_age.json")

Loading dataset from: /projects/standard/csci5561/shared/G8/data/face_age.json
Loaded 19556 image paths from JSON
110


In [7]:
def train_age_progression_gan(
    data_path,
    num_epochs=100,
    batch_size=4,
    lr_g=0.0002,
    lr_d=0.0002,
    image_size=128,
    save_interval=10,
    sample_interval=5,
    output_dir="training_output_CNN",
    device='cuda'
):
    """
    Complete training pipeline
    
    Args:
        data_path: Path to JSON file (mode='json'), folder (mode='individual'), 
                   or data files (mode='single'/'dict')

    """
    # Create output directories
    os.makedirs(f"{output_dir}/checkpoints", exist_ok=True)
    os.makedirs(f"{output_dir}/samples", exist_ok=True)
    
    print("="*60)
    print("TRAINING AGE PROGRESSION GAN")
    print("="*60)
    print(f"Device: {device}")
    print(f"Data path: {data_path}")
    print(f"Epochs: {num_epochs}")
    print(f"Batch size: {batch_size}")
    print(f"Image size: {image_size}x{image_size}")
    print("="*60 + "\n")
    
    # Data preprocessing
    transform = transforms.Compose([
        transforms.Resize(image_size),
        transforms.CenterCrop(image_size),
        transforms.ToTensor(),
        transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
    ])
    
    # Create dataset and dataloader
    dataset = AgeDataset(data_path, transform=transform)
    dataloader = DataLoader(
        dataset, 
        batch_size=batch_size, 
        shuffle=True, 
        num_workers=4,
        pin_memory=True
    )
    
    print(f"\nTotal batches per epoch: {len(dataloader)}\n")
    
    # Initialize models
    generator = Generator(num_age_classes=111).to(device)
    discriminator = Discriminator(num_age_classes=111).to(device)
    
    print(f"Generator parameters: {sum(p.numel() for p in generator.parameters()):,}")
    print(f"Discriminator parameters: {sum(p.numel() for p in discriminator.parameters()):,}\n")
    
    # Loss functions
    adversarial_loss = nn.BCELoss()
    age_loss = nn.CrossEntropyLoss()
    reconstruction_loss = nn.L1Loss()
    
    # Optimizers
    optimizer_g = optim.Adam(generator.parameters(), lr=lr_g, betas=(0.5, 0.999))
    optimizer_d = optim.Adam(discriminator.parameters(), lr=lr_d, betas=(0.5, 0.999))
    
    # Training loop
    for epoch in range(num_epochs):
        generator.train()
        discriminator.train()
        
        epoch_d_loss = 0
        epoch_g_loss = 0
        
        progress_bar = tqdm(dataloader, desc=f"Epoch {epoch+1}/{num_epochs}")
        
        for i, (real_imgs, real_ages) in enumerate(progress_bar):
            batch_size_actual = real_imgs.size(0)
            real_imgs = real_imgs.to(device)
            real_ages = real_ages.to(device)
            
            # Sample random target ages for aging
            target_ages = torch.randint(0, 111, (batch_size_actual,), device=device)
            
            # Labels
            real_labels = torch.ones(batch_size_actual, 1, device=device)
            fake_labels = torch.zeros(batch_size_actual, 1, device=device)
            
            # =============================
            # TRAIN DISCRIMINATOR
            # =============================
            optimizer_d.zero_grad()
            
            # Real images
            
            real_validity, real_age_pred = discriminator(real_imgs)
            d_real_adv_loss = adversarial_loss(real_validity, real_labels)
            #print(real_age_pred.shape)
            d_real_age_loss = age_loss(real_age_pred, real_ages)
            
            # Fake images
            
            fake_imgs = generator(real_imgs, target_ages)
            fake_validity, fake_age_pred = discriminator(fake_imgs.detach())
            d_fake_adv_loss = adversarial_loss(fake_validity, fake_labels)
            #print("got to here")
            d_fake_age_loss = age_loss(fake_age_pred, target_ages)
            #print("maybe error here")
            # Total discriminator loss
            d_loss = (d_real_adv_loss + d_fake_adv_loss) + 0.5 * (d_real_age_loss + d_fake_age_loss)
            d_loss.backward()
            optimizer_d.step()
            
            # =============================
            # TRAIN GENERATOR
            # =============================
            optimizer_g.zero_grad()
            
            # Generate fake images
            fake_imgs = generator(real_imgs, target_ages)
            fake_validity, fake_age_pred = discriminator(fake_imgs)
            
            # Adversarial loss: fool discriminator
            
            g_adv_loss = adversarial_loss(fake_validity, real_labels)
            
            # Age loss: correct age
            
            g_age_loss = age_loss(fake_age_pred, target_ages)
            
            # Identity preservation: when age doesn't change, preserve identity
            same_age_mask = (target_ages == real_ages)
            if same_age_mask.any():
                g_recon_loss = reconstruction_loss(
                    fake_imgs[same_age_mask],
                    real_imgs[same_age_mask]
                )
            else:
                g_recon_loss = torch.tensor(0.0, device=device)
            
            # Total generator loss
            g_loss = g_adv_loss + g_age_loss + 10.0 * g_recon_loss
            g_loss.backward()
            optimizer_g.step()
            
            # Update progress bar
            epoch_d_loss += d_loss.item()
            epoch_g_loss += g_loss.item()
            progress_bar.set_postfix({
                'D_loss': f'{d_loss.item():.4f}',
                'G_loss': f'{g_loss.item():.4f}'
            })
        
        # Print epoch summary
        avg_d_loss = epoch_d_loss / len(dataloader)
        avg_g_loss = epoch_g_loss / len(dataloader)
        print(f"\nEpoch {epoch+1} Summary:")
        print(f"  Average D Loss: {avg_d_loss:.4f}")
        print(f"  Average G Loss: {avg_g_loss:.4f}\n")
        
        # Save sample images
        if (epoch + 1) % sample_interval == 0:
            save_samples(generator, real_imgs[:8], epoch+1, output_dir, device)
        
        # Save checkpoint
        if (epoch + 1) % save_interval == 0:
            checkpoint = {
                'epoch': epoch + 1,
                'generator': generator.state_dict(),
                'discriminator': discriminator.state_dict(),
                'optimizer_g': optimizer_g.state_dict(),
                'optimizer_d': optimizer_d.state_dict(),
            }
            checkpoint_path = f"{output_dir}/checkpoints/checkpoint_epoch_{epoch+1}.pth"
            torch.save(checkpoint, checkpoint_path)
            print(f"✓ Saved checkpoint: {checkpoint_path}\n")
    
    print("="*60)
    print("TRAINING COMPLETE!")
    print("="*60)
    
    return generator, discriminator

In [8]:
def save_samples(generator, real_imgs, epoch, output_dir, device):
    """Generate and save sample aged faces"""
    generator.eval()
    with torch.no_grad():
        # Show aging progression: 20s, 40s, 60s, 80s
        age_groups = [20, 40, 60, 80]  # 20s, 40s, 60s, 80s
        
        samples = [real_imgs]
        for age in age_groups:
            target_ages = torch.full((real_imgs.size(0),), age, device=device)
            aged_imgs = generator(real_imgs, target_ages)
            samples.append(aged_imgs)
        
        # Concatenate all samples
        samples = torch.cat(samples, dim=0)
        
        # Save grid
        save_image(
            samples,
            f"{output_dir}/samples/epoch_{epoch}.png",
            nrow=real_imgs.size(0),
            normalize=True,
            value_range=(-1, 1)
        )
    generator.train()
    print(f"✓ Saved samples: {output_dir}/samples/epoch_{epoch}.png")

In [None]:
DATA_PATH = "/projects/standard/csci5561/shared/G8/data/face_age.json"

NUM_EPOCHS = 100                    # Number of training epochs
BATCH_SIZE = 4                     # Batch size (reduce if out of memory)
IMAGE_SIZE = 128                    # Image resolution
LEARNING_RATE_G = 0.0002           # Generator learning rate
LEARNING_RATE_D = 0.0002           # Discriminator learning rate
DEVICE = 'cuda:1' if torch.cuda.is_available() else 'cpu'
#DEVICE = "cpu"
# =============================
# START TRAINING!
# =============================

generator, discriminator = train_age_progression_gan(
    data_path=DATA_PATH,
    num_epochs=NUM_EPOCHS,
    batch_size=BATCH_SIZE,
    lr_g=LEARNING_RATE_G,
    lr_d=LEARNING_RATE_D,
    image_size=IMAGE_SIZE,
    save_interval=10,      # Save checkpoint every 10 epochs
    sample_interval=5,     # Generate samples every 5 epochs
    output_dir="training_output_CNN",
    device=DEVICE
)


TRAINING AGE PROGRESSION GAN
Device: cuda:1
Data path: /projects/standard/csci5561/shared/G8/data/face_age.json
Epochs: 100
Batch size: 4
Image size: 128x128

Loading dataset from: /projects/standard/csci5561/shared/G8/data/face_age.json
Loaded 19556 image paths from JSON
110

Total batches per epoch: 4889

Generator parameters: 9,152,515
Discriminator parameters: 6,428,464



  image = Image.fromarray(image, mode='RGB')
  image = Image.fromarray(image, mode='RGB')
  image = Image.fromarray(image, mode='RGB')
  image = Image.fromarray(image, mode='RGB')
Epoch 1/100: 100%|█████████████████████████████████████████████████████████████████████████████████████| 4889/4889 [01:00<00:00, 81.12it/s, D_loss=2.2087, G_loss=3.7569]



Epoch 1 Summary:
  Average D Loss: 3.2899
  Average G Loss: 4.5316



  image = Image.fromarray(image, mode='RGB')
  image = Image.fromarray(image, mode='RGB')
  image = Image.fromarray(image, mode='RGB')
  image = Image.fromarray(image, mode='RGB')
Epoch 2/100: 100%|█████████████████████████████████████████████████████████████████████████████████████| 4889/4889 [00:59<00:00, 81.97it/s, D_loss=0.2034, G_loss=8.2064]



Epoch 2 Summary:
  Average D Loss: 1.3804
  Average G Loss: 5.2435



  image = Image.fromarray(image, mode='RGB')
  image = Image.fromarray(image, mode='RGB')
  image = Image.fromarray(image, mode='RGB')
  image = Image.fromarray(image, mode='RGB')
Epoch 3/100: 100%|█████████████████████████████████████████████████████████████████████████████████████| 4889/4889 [01:00<00:00, 80.80it/s, D_loss=1.8574, G_loss=4.2490]



Epoch 3 Summary:
  Average D Loss: 0.8945
  Average G Loss: 5.8511



  image = Image.fromarray(image, mode='RGB')
  image = Image.fromarray(image, mode='RGB')
  image = Image.fromarray(image, mode='RGB')
  image = Image.fromarray(image, mode='RGB')
Epoch 4/100: 100%|█████████████████████████████████████████████████████████████████████████████████████| 4889/4889 [01:00<00:00, 80.99it/s, D_loss=1.0354, G_loss=8.5311]



Epoch 4 Summary:
  Average D Loss: 0.8382
  Average G Loss: 5.4840



  image = Image.fromarray(image, mode='RGB')
  image = Image.fromarray(image, mode='RGB')
  image = Image.fromarray(image, mode='RGB')
  image = Image.fromarray(image, mode='RGB')
Epoch 5/100: 100%|█████████████████████████████████████████████████████████████████████████████████████| 4889/4889 [01:00<00:00, 80.60it/s, D_loss=0.2009, G_loss=4.6354]



Epoch 5 Summary:
  Average D Loss: 0.8079
  Average G Loss: 5.1712

✓ Saved samples: training_output_CNN/samples/epoch_5.png


  image = Image.fromarray(image, mode='RGB')
  image = Image.fromarray(image, mode='RGB')
  image = Image.fromarray(image, mode='RGB')
  image = Image.fromarray(image, mode='RGB')
Epoch 6/100: 100%|█████████████████████████████████████████████████████████████████████████████████████| 4889/4889 [00:59<00:00, 82.33it/s, D_loss=0.1298, G_loss=3.5367]



Epoch 6 Summary:
  Average D Loss: 0.8453
  Average G Loss: 4.6377



  image = Image.fromarray(image, mode='RGB')
  image = Image.fromarray(image, mode='RGB')
  image = Image.fromarray(image, mode='RGB')
  image = Image.fromarray(image, mode='RGB')
Epoch 7/100:  26%|█████████████████████▊                                                               | 1252/4889 [00:15<00:43, 83.28it/s, D_loss=2.0145, G_loss=5.9109]

In [None]:
import numpy as np
image = np.load("/projects/standard/csci5561/shared/G8/data/face_age_Numpy/9773_100.npy")
if image.dtype == np.float32 or image.dtype == np.float64:
    # If normalized to [0, 1], scale to [0, 255]
    if image.max() <= 1.0:
        image = (image * 255).astype(np.uint8)
    else:
        image = image.astype(np.uint8)
elif image.dtype != np.uint8:
    image = image.astype(np.uint8)

# Handle channel order: (H, W, C) or (C, H, W)
if image.ndim == 3:
    if image.shape[0] == 3 or image.shape[0] == 1:
        # (C, H, W) -> (H, W, C)
        image = np.transpose(image, (1, 2, 0))
print(image.shape)
pil_image = Image.fromarray(image)
resized_pil_image = pil_image.resize((128, 128), Image.LANCZOS) # LANCZOS is a high-quality filter

# Convert the resized PIL Image back to a NumPy array
x = np.array(resized_pil_image)

print(x.shape)