# __Project 3a: Advanced GAN Crystal Ball__

In [None]:
# Basic Libraries
import scipy.io
import pandas as pd
import numpy as np
import matplotlib as plt
import datetime
from PIL import Image
import os

# Torch Libraries
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
import torchgan.models as models
from torchvision.utils import save_image

print(dir(models))
# Report the GPU PyTorch can see, or 'cpu' when CUDA is unavailable.
print(torch.cuda.get_device_name() if torch.cuda.is_available() else 'cpu')

### __Download__ and __Extract__ the IMDB-WIKI (wiki_crop) Dataset

In [None]:
#!curl -L https://data.vision.ee.ethz.ch/cvl/rrothe/imdb-wiki/static/wiki_crop.tar -o wiki_crop.tar

In [None]:
#!tar -xvzf wiki_crop.tar -C .

### __Load the WIKI Dataset Metadata (wiki.mat)__

In [None]:
# Metadata file that ships inside the extracted wiki_crop folder.
mat_file_path = './wiki_crop/wiki.mat'
mat = scipy.io.loadmat(file_name=mat_file_path)

# Top-level entries of the loaded .mat dictionary.
mat.keys()

In [None]:
# Inspect every (key, value) pair of the loaded .mat dict; the 'wiki' entry is the one unpacked below.
mat.items()

In [None]:
# Unpack the MATLAB 'wiki' struct into a DataFrame. Every field is nested,
# so its values are reached through wiki[field][0][0][0].
wiki = mat['wiki']


def _wiki_field(field):
    """Return the per-image array stored under *field* in the loaded 'wiki' struct."""
    return wiki[field][0][0][0]


full_path = _wiki_field('full_path')
gender = _wiki_field('gender')
dob = _wiki_field('dob')
photo_taken = _wiki_field('photo_taken')
face_location = _wiki_field('face_location')
name = _wiki_field('name')
face_score = _wiki_field('face_score')
second_face_score = _wiki_field('second_face_score')

df = pd.DataFrame({
    'full_path': full_path,
    'gender': gender.flatten(),
    'dob': dob.flatten(),
    'photo_taken': photo_taken.flatten(),
    'face_location': face_location.tolist(),
    'name': name.flatten(),
    'face_score': face_score.flatten(),
    'second_face_score': second_face_score.flatten(),
})

### __Data Cleaning__

In [None]:
# A face_score of -inf means the detector found no face in the image at all,
# so those rows are useless for training and are dropped.
no_face = df['face_score'].eq(-np.inf)
num_neg_inf = no_face.sum()
print(num_neg_inf)
df_filtered = df.loc[~no_face]
df_filtered

In [None]:
# Keep only rows whose second_face_score is missing, i.e. (presumably) images in
# which the detector did not find a second face.
no_second_face = df_filtered['second_face_score'].isna()
num_nans = no_second_face.sum()
df_filtered = df_filtered.loc[no_second_face]
df_filtered

In [None]:
# Drop rows without a gender label, then confirm that none remain.
gender_missing = df_filtered['gender'].isna()
num_nans = gender_missing.sum()
df_filtered = df_filtered.loc[~gender_missing]
num_nans = df_filtered['gender'].isna().sum()
print(num_nans)
df_filtered

In [None]:
# After the single-face filter every remaining second_face_score is NaN, so the column is dropped.
df_filtered = df_filtered.drop('second_face_score', axis=1)
df_filtered

In [None]:
def matlab_serial_to_year(serial_date):
    """Return the calendar year of a MATLAB serial date number (``datenum``).

    MATLAB day 1 is 0000-01-01 and year 0 is a leap year (366 days). Python
    proleptic-Gregorian ordinal 1 is 0001-01-01. MATLAB serial ``d`` is therefore
    Python ordinal ``d - 366`` (e.g. datenum 730486 == 2000-01-01).
    Any fractional (time-of-day) part is truncated.

    Args:
        serial_date: MATLAB datenum (int or float, including numpy scalars).

    Returns:
        int: the year of that date.

    Raises:
        ValueError: if the date falls before 0001-01-01 or after 9999-12-31.
    """
    # The previous origin+timedelta form was one day late (ordinal d - 365),
    # which pushed 31 December dates into the following year.
    return datetime.date.fromordinal(int(serial_date) - 366).year

# Six zero-based age buckets; the bucket index is later one-hot encoded as the generator's condition.
def get_age_bucket(age):
    """Map an age in years to a zero-based age bucket.

    Buckets:
        0: age <= 18
        1: 18 < age < 30
        2: 30 <= age < 40
        3: 40 <= age < 50
        4: 50 <= age < 60
        5: everything else (60+, and values such as NaN that fail every comparison)

    Only upper bounds are tested, so non-integer ages (e.g. 18.5 or 29.5) land in
    the neighbouring bucket instead of falling through to bucket 5. Integer ages
    map exactly as before.
    """
    if age <= 18:
        return 0
    if age < 30:
        return 1
    if age < 40:
        return 2
    if age < 50:
        return 3
    if age < 60:
        return 4
    return 5

final_df = df_filtered.copy()

# Birth dates are MATLAB serial day numbers; keep only the birth year.
final_df['dob'] = df_filtered['dob'].map(matlab_serial_to_year)

# Age at photo time (year granularity); the raw year columns are then no longer needed.
final_df['age'] = final_df['photo_taken'] - final_df['dob']
final_df = final_df.drop(columns=['dob', 'photo_taken'])

# Zero-based age bucket per row.
# NOTE(review): ages are not range-checked; a negative age (photo year before the
# birth year) lands in bucket 0 — consider filtering implausible ages first.
final_df['age_bucket'] = final_df['age'].map(get_age_bucket)

# Image paths in the .mat file are relative to the extracted wiki_crop folder.
final_df['full_path'] = final_df['full_path'].map(lambda entry: f"./wiki_crop/{entry[0]}")

final_df['gender'] = final_df['gender'].astype(int)

final_df

In [None]:
def _first_row_as_list(loc):
    """Return the first row of a 2-D ndarray as a plain list; pass other values through."""
    if isinstance(loc, np.ndarray) and loc.ndim == 2:
        return loc[0].tolist()
    return loc


def _first_element(value):
    """Unwrap a 1-D ndarray to its first element; pass other values through.

    NOTE(review): a zero-length array would raise IndexError here.
    """
    if isinstance(value, np.ndarray) and value.ndim == 1:
        return value[0]
    return value


# Strip the single-element array wrappers that loadmat leaves around these fields.
final_df['face_location'] = final_df['face_location'].apply(_first_row_as_list)
final_df['name'] = final_df['name'].apply(_first_element)

print(final_df.dtypes)

final_df

### __Data Preprocessing__

In [None]:
# Preprocessing for every face crop: fixed 64x64 size, conversion to a [0, 1] tensor,
# then per-channel (x - 0.5) / 0.5, which maps pixels to [-1, 1] — the same range
# as the generator's Tanh output.
image_size = (64, 64)
channel_stats = [0.5, 0.5, 0.5]
transform = transforms.Compose([
    transforms.Resize(image_size),
    transforms.ToTensor(),
    transforms.Normalize(mean=channel_stats, std=channel_stats),
])

In [None]:
class FaceAgingDataset(Dataset):
    """Map-style dataset yielding ``(image, age_bucket, gender)`` triples.

    Args:
        dataframe (pd.DataFrame): must provide "full_path", "age_bucket" and "gender" columns.
        transform (callable, optional): applied to the loaded RGB PIL image.
    """

    def __init__(self, dataframe, transform=None):
        self.dataframe = dataframe
        self.transform = transform

    def __len__(self):
        return len(self.dataframe)

    @staticmethod
    def _labels(row):
        """Return ``(age_bucket, gender)`` for one metadata row.

        ``get_age_bucket`` already yields zero-based buckets (0..5), which is what the
        one-hot lookup ``torch.eye(condition_dim)[age_buckets]`` expects, so the bucket is
        passed through unchanged. Subtracting 1 would turn bucket 0 into -1, and
        negative indexing would silently map it onto the last bucket.
        """
        return row["age_bucket"], row["gender"]

    def __getitem__(self, idx):
        row = self.dataframe.iloc[idx]

        # Open in a context manager so the file handle is released right away;
        # convert() returns an independent, fully loaded image.
        with Image.open(row["full_path"]) as src:
            img = src.convert("RGB")

        if self.transform:
            img = self.transform(img)

        age_bucket, gender = self._labels(row)
        return img, age_bucket, gender

In [None]:
# Build the dataset once and hand that same instance to the DataLoader,
# rather than constructing a second identical FaceAgingDataset.
dataset = FaceAgingDataset(final_df, transform=transform)
dataloader = DataLoader(
    dataset,
    batch_size=64,
    shuffle=True,
    num_workers=4,
)

### __Modeling__

In [None]:
class Discriminator(nn.Module):
    """DCGAN-style image discriminator for 3x64x64 inputs.

    Four stride-2 4x4 convolutions (3 -> 64 -> 128 -> 256 -> 512 channels,
    64x64 -> 4x4 spatially) are followed by a 4x4 valid convolution down to one
    value and a sigmoid. The output is one real/fake probability per image.
    Layers are appended in a fixed order (``self.model`` indices 0-12), so
    state_dict keys stay stable.
    """

    def __init__(self):
        super().__init__()
        base = 64
        layers = [
            nn.Conv2d(3, base, 4, 2, 1, bias=False),  # -> [N, 64, 32, 32]
            nn.LeakyReLU(0.2, inplace=True),
        ]
        # Three downsampling stages, each doubling channels and halving H/W:
        # 64->128 (16x16), 128->256 (8x8), 256->512 (4x4).
        for mult in (1, 2, 4):
            in_ch, out_ch = base * mult, base * mult * 2
            layers += [
                nn.Conv2d(in_ch, out_ch, 4, 2, 1, bias=False),
                nn.BatchNorm2d(out_ch),
                nn.LeakyReLU(0.2, inplace=True),
            ]
        layers += [
            nn.Conv2d(base * 8, 1, 4, 1, 0, bias=False),  # -> [N, 1, 1, 1]
            nn.Sigmoid(),
        ]
        self.model = nn.Sequential(*layers)

    def forward(self, input):
        scores = self.model(input)
        return scores.view(-1)  # one probability per image: [N]

In [None]:
class Generator(nn.Module):
    """Fully connected conditional generator.

    Concatenates ``[noise | condition | additional_features]`` and maps it through an MLP
    to an ``output_channels x 64 x 64`` image with values in [-1, 1] (Tanh output).

    Args:
        noise_dim: length of the latent noise vector.
        condition_dim: width of the condition vector (the one-hot age bucket here).
        additional_features: width of the extra feature vector (gender here).
        output_channels: number of image channels to generate (default 3, RGB).
            Previously accepted but ignored.
    """

    def __init__(self, noise_dim, condition_dim, additional_features=1, output_channels=3):
        super(Generator, self).__init__()
        self.output_channels = output_channels
        self.fc = nn.Sequential(
            nn.Linear(noise_dim + condition_dim + additional_features, 256),
            nn.ReLU(True),
            nn.Linear(256, 512),
            nn.ReLU(True),
            nn.Linear(512, 1024),
            nn.ReLU(True),
            nn.Linear(1024, output_channels * 64 * 64),  # flattened 64x64 image
            nn.Tanh(),
        )

    def forward(self, noise, condition, additional_features):
        # Cast the conditioning inputs to the noise dtype so integer labels
        # (e.g. a LongTensor of genders) can be concatenated with float noise.
        x = torch.cat(
            (noise, condition.to(noise.dtype), additional_features.to(noise.dtype)),
            dim=1,
        )
        img = self.fc(x)
        return img.view(-1, self.output_channels, 64, 64)

In [None]:
# Prefer the GPU when one is visible; otherwise run on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

noise_dim = 100      # length of the latent noise vector
condition_dim = 6    # one-hot width: number of age buckets

# Discriminator is built first, then the generator.
discriminator = Discriminator().to(device)
generator = Generator(noise_dim=noise_dim, condition_dim=condition_dim).to(device)

# Both networks share the same Adam hyper-parameters.
adam_kwargs = dict(lr=0.0002, betas=(0.5, 0.999))
d_optimizer = optim.Adam(discriminator.parameters(), **adam_kwargs)
g_optimizer = optim.Adam(generator.parameters(), **adam_kwargs)

# Binary cross-entropy on the discriminator's sigmoid probabilities.
loss_fn = nn.BCELoss()

In [None]:
def discriminator_train_step(real_data, fake_data):
    """Run one optimizer step for the discriminator.

    Real images are scored against target 1. Fake images are detached, so no
    gradient reaches the generator, and scored against target 0. Each BCE term is
    backpropagated separately before a single ``d_optimizer.step()``.

    Returns:
        The summed real + fake BCE loss tensor.
    """
    d_optimizer.zero_grad()

    losses = []
    for batch, make_targets in ((real_data, torch.ones), (fake_data.detach(), torch.zeros)):
        scores = discriminator(batch)
        targets = make_targets(batch.size(0)).to(device)
        loss = loss_fn(scores, targets)
        loss.backward()
        losses.append(loss)

    d_optimizer.step()

    real_loss, fake_loss = losses
    return real_loss + fake_loss

In [None]:
def generator_train_step(fake_data):
    """Run one optimizer step for the generator (non-saturating GAN loss).

    The discriminator scores ``fake_data``, and BCE against target 1 rewards the
    generator when those scores approach "real". Backprop also fills the
    discriminator's ``.grad`` buffers, but only ``g_optimizer`` steps here.

    Returns:
        The generator's BCE loss tensor.
    """
    g_optimizer.zero_grad()

    scores = discriminator(fake_data)
    wanted = torch.ones(fake_data.size(0)).to(device)  # pretend the fakes are real
    loss = loss_fn(scores, wanted)
    loss.backward()

    g_optimizer.step()
    return loss

In [None]:
epochs = 50
batch_size = 64
noise_dim = 100
condition_dim = 6  # Number of age buckets
output_dir = "output"

# Dataset and DataLoader
dataloader = DataLoader(FaceAgingDataset(final_df, transform=transform), batch_size=batch_size, shuffle=True)
if len(dataloader) == 0:
    # Without this check the epoch summary below fails with a NameError on d_loss.
    raise RuntimeError("No training samples: final_df is empty after filtering.")

# Create the sample directory once instead of on every epoch.
os.makedirs(output_dir, exist_ok=True)

# One-hot lookup table for the age conditions, built once rather than per batch.
age_one_hot = torch.eye(condition_dim, device=device)

# NOTE(review): the Discriminator only ever sees images, never the age condition,
# so nothing in the adversarial loss rewards the generator for matching the
# requested age bucket. A conditional discriminator (or an auxiliary age
# classifier) would be needed for the conditioning to actually be learned.
for epoch in range(epochs):
    for real_images, age_buckets, gender in dataloader:
        real_images = real_images.to(device)
        age_buckets = age_buckets.to(device)
        # Gender arrives as integer labels; the generator's Linear layers need float input.
        gender = gender.to(device).float().unsqueeze(1)

        age_conditions = age_one_hot[age_buckets]

        # Train Discriminator on a freshly generated batch.
        noise = torch.randn(real_images.size(0), noise_dim).to(device)
        fake_images = generator(noise, age_conditions, gender)
        d_loss = discriminator_train_step(real_images, fake_images)

        # Train Generator on a newly sampled batch.
        noise = torch.randn(real_images.size(0), noise_dim).to(device)
        fake_images = generator(noise, age_conditions, gender)
        g_loss = generator_train_step(fake_images)

    print(f"Epoch [{epoch + 1}/{epochs}] | D Loss: {d_loss.item():.4f} | G Loss: {g_loss.item():.4f}")

    # Save sample images
    if (epoch + 1) % 10 == 0:
        save_image(fake_images.detach()[:25], f"{output_dir}/epoch_{epoch + 1}.png", nrow=5, normalize=True)

In [None]:
def visualize_age_buckets(generator, dataset, num_age_buckets=6, noise_dim=100):
    """Plot one random real face next to the generator's output for every age bucket.

    A single noise vector is drawn and reused for all buckets, so the generated panels
    differ only in their one-hot age condition. The gender of the sampled real image
    is fed to the generator as its extra feature.

    Args:
        generator: conditional generator called as ``generator(noise, condition, gender)``;
            it is switched to eval mode and left there.
        dataset: indexable dataset of ``(image, age_bucket, gender)`` items whose images
            are CHW tensors normalized to [-1, 1].
        num_age_buckets: width of the one-hot age condition; must match the generator.
        noise_dim: latent size the generator was built with (100 in this notebook).

    Raises:
        ValueError: if ``dataset`` is empty.
    """
    # The file-level `plt` is bound to bare `matplotlib`, which has no
    # subplots()/show(); import pyplot explicitly.
    import matplotlib.pyplot as plt

    if len(dataset) == 0:
        raise ValueError("dataset is empty; nothing to visualize")

    generator.eval()
    # Build inputs on the generator's own device instead of relying on a global.
    gen_device = next(generator.parameters()).device

    # Select a random image from the dataset
    idx = torch.randint(0, len(dataset), (1,)).item()
    original_img, _, gender = dataset[idx]

    fixed_noise = torch.randn(1, noise_dim).to(gen_device)  # shared by every bucket
    # Float so it concatenates cleanly with the float noise/condition inputs.
    gender_condition = torch.tensor([[float(gender)]], device=gen_device)

    generated_images = []
    with torch.no_grad():
        for bucket in range(num_age_buckets):
            condition = torch.zeros(1, num_age_buckets, device=gen_device)
            condition[0, bucket] = 1  # one-hot age bucket
            fake_img = generator(fixed_noise, condition, gender_condition).cpu()
            generated_images.append(fake_img.squeeze(0))

    def _to_display(img):
        # CHW tensor in [-1, 1] -> HWC array in [0, 1] for imshow.
        return (img.cpu().permute(1, 2, 0) * 0.5 + 0.5).clamp(0, 1).numpy()

    fig, axes = plt.subplots(1, num_age_buckets + 1, figsize=(15, 5))
    axes[0].imshow(_to_display(original_img))
    axes[0].set_title("Original Image")
    axes[0].axis("off")

    for i, fake_img in enumerate(generated_images):
        axes[i + 1].imshow(_to_display(fake_img))
        axes[i + 1].set_title(f"Age Bucket {i + 1}")
        axes[i + 1].axis("off")

    plt.tight_layout()
    plt.show()

In [None]:
# Show one real face alongside the generator's output for each age bucket.
preview_dataset = FaceAgingDataset(final_df, transform=transform)
visualize_age_buckets(generator, preview_dataset)