In [None]:
# Original GAN (noise + label conditioned DCGAN) with a face-cropping preprocessing step

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
from torch.utils.data import DataLoader, random_split
from torchvision import transforms
from torchvision.datasets import ImageFolder
import os
import matplotlib.pyplot as plt
from tqdm import tqdm
import shutil
import cv2
from pathlib import Path
from PIL import Image

# Helper function for weight initialization
def weights_init(m):
    """Initialise one module in place; intended for use with ``Module.apply``.

    Modules whose class name contains ``Conv`` get weights drawn from
    N(0, 0.02). Modules whose class name contains ``BatchNorm`` get weights
    drawn from N(1, 0.02) and a zero bias. Every other module is left as is.
    """
    layer_kind = type(m).__name__
    if 'Conv' in layer_kind:
        nn.init.normal_(m.weight.data, mean=0.0, std=0.02)
        return
    if 'BatchNorm' in layer_kind:
        nn.init.normal_(m.weight.data, mean=1.0, std=0.02)
        nn.init.constant_(m.bias.data, 0)

class Discriminator(nn.Module):
    """Label-conditioned DCGAN discriminator for 3x64x64 images.

    Four stride-2 convolutions reduce the image to a flat feature vector,
    which is concatenated with a learned label embedding and scored by a small
    MLP ending in a Sigmoid.

    Args:
        emb_size: Width of the label embedding.
        num_classes: Number of distinct labels the embedding table holds.
    """

    def __init__(self, emb_size=32, num_classes=4):
        super(Discriminator, self).__init__()
        self.emb_size = emb_size
        self.label_embeddings = nn.Embedding(num_classes, self.emb_size)
        self.model = nn.Sequential(
            nn.Conv2d(3, 64, 4, 2, 1, bias=False),  # Output: (64, 32, 32)
            nn.LeakyReLU(0.2, inplace=True),
            nn.Dropout(0.3),

            nn.Conv2d(64, 128, 4, 2, 1, bias=False),  # Output: (128, 16, 16)
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Dropout(0.3),

            nn.Conv2d(128, 256, 4, 2, 1, bias=False),  # Output: (256, 8, 8)
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Dropout(0.3),

            nn.Conv2d(256, 512, 4, 2, 1, bias=False),  # Output: (512, 4, 4)
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Dropout(0.3),

            nn.Flatten()
        )

        # Infer the flattened feature size with a probe pass. The probe runs in
        # eval mode so it neither updates the BatchNorm running statistics nor
        # applies Dropout; train mode is restored right after.
        self.model.eval()
        with torch.no_grad():
            dummy_input = torch.zeros(1, 3, 64, 64)  # Match input dimensions
            self.flattened_size = self.model(dummy_input).size(1)  # 512*4*4 = 8192
        self.model.train()
        print(f"Discriminator Flattened Size: {self.flattened_size}")

        # Classifier head over image features + label embedding.
        self.model2 = nn.Sequential(
            nn.Linear(self.flattened_size + self.emb_size, 100),  # 8192 + 32 = 8224 by default
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(100, 1),
            nn.Sigmoid()
        )
        self.apply(weights_init)

    def forward(self, input, labels):
        """Return the real/fake probability for each image.

        Args:
            input: Image batch accepted by the convolutional stack
                (3 channels; 64x64 for the flattened size computed above).
            labels: Integer class indices, one per image.

        Returns:
            Tensor of shape (batch, 1) with values in [0, 1].
        """
        x = self.model(input)  # Convolutional features, flattened
        y = self.label_embeddings(labels)  # (batch, emb_size)
        combined = torch.cat([x, y], dim=1)
        # No per-call logging here: forward runs once or more per training
        # batch, and printing from it floods the output.
        return self.model2(combined)

class Generator(nn.Module):
    """Label-conditioned DCGAN generator producing 3x64x64 images in [-1, 1].

    The latent vector is concatenated with a learned label embedding along the
    channel axis and upsampled by five transposed convolutions
    (1x1 -> 4 -> 8 -> 16 -> 32 -> 64), followed by a 3x3 convolution and Tanh.

    Args:
        emb_size: Width of the label embedding.
        num_classes: Number of distinct labels the embedding table holds.
        latent_dim: Number of channels of the noise input. Defaults to 100,
            the value that was previously hard-coded.
    """

    def __init__(self, emb_size=32, num_classes=4, latent_dim=100):
        super(Generator, self).__init__()
        self.emb_size = emb_size
        self.latent_dim = latent_dim
        self.label_embeddings = nn.Embedding(num_classes, self.emb_size)
        self.model = nn.Sequential(
            # 1. From latent vector + label embedding
            nn.ConvTranspose2d(latent_dim + self.emb_size, 1024, 4, 1, 0, bias=False),  # (latent+emb, 1, 1) -> (1024, 4, 4)
            nn.BatchNorm2d(1024),
            nn.ReLU(True),
            nn.Dropout(0.3),

            # 2. Upsample to 8x8
            nn.ConvTranspose2d(1024, 512, 4, 2, 1, bias=False),  # (512, 8, 8)
            nn.BatchNorm2d(512),
            nn.ReLU(True),
            nn.Dropout(0.3),

            # 3. Upsample to 16x16
            nn.ConvTranspose2d(512, 256, 4, 2, 1, bias=False),  # (256, 16, 16)
            nn.BatchNorm2d(256),
            nn.ReLU(True),
            nn.Dropout(0.3),

            # 4. Upsample to 32x32
            nn.ConvTranspose2d(256, 128, 4, 2, 1, bias=False),  # (128, 32, 32)
            nn.BatchNorm2d(128),
            nn.ReLU(True),
            nn.Dropout(0.3),

            # 5. Upsample to 64x64
            nn.ConvTranspose2d(128, 64, 4, 2, 1, bias=False),  # (64, 64, 64)
            nn.BatchNorm2d(64),
            nn.ReLU(True),
            nn.Dropout(0.3),

            # 6. Convert to 3 channels without upsampling
            nn.Conv2d(64, 3, 3, 1, 1, bias=False),  # (3, 64, 64)
            nn.Tanh()
        )
        self.apply(weights_init)

    def forward(self, input_noise, labels):
        """Generate one image per (noise, label) pair.

        Args:
            input_noise: Tensor of shape (batch, latent_dim, 1, 1).
            labels: Integer class indices of shape (batch,).

        Returns:
            Tensor of shape (batch, 3, 64, 64); Tanh bounds values to [-1, 1].
        """
        # Reshape the embedding to (batch, emb_size, 1, 1) so it can be stacked
        # onto the noise along the channel axis.
        label_embeddings = self.label_embeddings(labels).view(labels.size(0), self.emb_size, 1, 1)
        latent = torch.cat([input_noise, label_embeddings], dim=1)
        return self.model(latent)

# Function to create a new experiment directory with an incremental index
def create_experiment_dir(base_dir):
    """Return the path of the next ``experiment_<n>`` folder under ``base_dir``.

    ``base_dir`` is created when missing. The returned experiment folder
    itself is NOT created; the caller is expected to do that. ``n`` is one
    more than the largest numeric index among existing ``experiment_*``
    sub-directories, or 1 when there is none.
    """
    prefix = "experiment_"
    if not os.path.exists(base_dir):
        os.makedirs(base_dir)
        return os.path.join(base_dir, f"{prefix}1")

    highest = 0
    for entry in os.listdir(base_dir):
        if not entry.startswith(prefix):
            continue
        if not os.path.isdir(os.path.join(base_dir, entry)):
            continue
        token = entry.split("_")[1]
        if token.isdigit():
            highest = max(highest, int(token))
    return os.path.join(base_dir, f"{prefix}{highest + 1}")

# Function to crop faces from images
def crop_faces(input_dir, output_dir, face_cascade_path='haarcascade_frontalface_default.xml'):
    """
    Detect one face per image and save it as a 64x64 RGB crop, mirroring the
    class subdirectories of the input directory.

    Images that cannot be read, or in which no face is detected, are skipped
    with a printed message.

    Args:
        input_dir (str): Directory containing input images organized in class subdirectories.
        output_dir (str): Directory where cropped faces are written; class subdirectories are created as needed.
        face_cascade_path (str): Either a path to a Haar cascade XML file, or the bare
            file name of one of the cascades bundled with OpenCV (the default).

    Raises:
        IOError: If the Haar cascade file cannot be loaded.
    """
    # Honour a real path when one is given; otherwise fall back to the
    # cascades shipped with OpenCV, which is what a bare file name refers to.
    cascade_file = Path(face_cascade_path)
    if not cascade_file.is_file():
        cascade_file = Path(cv2.data.haarcascades) / face_cascade_path
    face_cascade = cv2.CascadeClassifier(str(cascade_file))
    if face_cascade.empty():
        raise IOError("Cannot load Haar cascade xml file for face detection.")

    # Get list of class directories
    class_dirs = [d for d in Path(input_dir).iterdir() if d.is_dir()]

    for class_dir in class_dirs:
        class_name = class_dir.name
        output_class_dir = Path(output_dir) / class_name
        output_class_dir.mkdir(parents=True, exist_ok=True)

        # Sorted so that the index used in the output file name is stable across runs.
        image_paths = sorted(class_dir.glob("*.jpg"))  # Adjust extension if needed

        for i, img_path in enumerate(tqdm(image_paths, desc=f"Cropping Faces in {class_name}")):
            img = cv2.imread(str(img_path))
            if img is None:
                print(f"Warning: Unable to read image {img_path}. Skipping.")
                continue

            gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
            faces = face_cascade.detectMultiScale(gray, scaleFactor=1.3, minNeighbors=5)

            if len(faces) == 0:
                print(f"No face detected in image {img_path}. Skipping.")
                continue

            # Keep a single face per image. The order of detections carries no
            # meaning, so pick the largest box rather than whichever came first;
            # spurious detections tend to be the small ones.
            x, y, w, h = max(faces, key=lambda box: box[2] * box[3])
            face = img[y:y+h, x:x+w]

            # OpenCV loads BGR; PIL expects RGB.
            face_rgb = cv2.cvtColor(face, cv2.COLOR_BGR2RGB)
            face_pil = Image.fromarray(face_rgb)

            # Resize the face to a consistent size before saving
            face_pil = face_pil.resize((64, 64), Image.Resampling.LANCZOS)

            # Save the cropped face
            save_path = output_class_dir / f"{class_name}_{i}.jpg"
            face_pil.save(save_path)

# Preprocess images by cropping faces.
# NOTE: this runs on every execution of the cell and rewrites the crops in cropped_dir.
input_dir = "C:/Users/ryan9/Documents/CV/Team Project 3/structured_images"
cropped_dir = "C:/Users/ryan9/Documents/CV/Team Project 3/cropped_faces"

print("Starting face cropping...")
crop_faces(input_dir, cropped_dir)
print("Face cropping completed.")

# Define dataset and transforms without CenterCrop
transform = transforms.Compose([
    transforms.Resize((64, 64)),  # crop_faces already writes 64x64 files; this only guards against other sizes
    transforms.RandomHorizontalFlip(),  # Data augmentation
    transforms.RandomRotation(10),      # Data augmentation
    transforms.ToTensor(),
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])  # Per-channel (x - 0.5) / 0.5 for the 3 RGB channels
])

# Dataset over the cropped faces (same directory as cropped_dir above);
# ImageFolder derives the class label from the subdirectory name.
dataset_path = "C:/Users/ryan9/Documents/CV/Team Project 3/cropped_faces"
dataset = ImageFolder(root=dataset_path, transform=transform)

# Split dataset into training and validation sets (80-20 split).
# Both subsets wrap the same `dataset`, so the random flip/rotation above is
# also applied to validation samples. No seed is set, so the split differs per run.
train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size
train_dataset, val_dataset = random_split(dataset, [train_size, val_size])

# DataLoaders for training and validation (only the training loader shuffles)
train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True, num_workers=4, pin_memory=True)
val_loader = DataLoader(val_dataset, batch_size=128, shuffle=False, num_workers=4, pin_memory=True)

# Initialize models, optimizers, and loss function
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
generator = Generator(emb_size=32, num_classes=4).to(device)
discriminator = Discriminator(emb_size=32, num_classes=4).to(device)
# BCELoss takes probabilities; the Discriminator's head ends in a Sigmoid.
criterion = nn.BCELoss()

# Unified lower learning rate for both Generator and Discriminator
learning_rate = 0.0001
optimizer_G = optim.Adam(generator.parameters(), lr=learning_rate, betas=(0.5, 0.999))
optimizer_D = optim.Adam(discriminator.parameters(), lr=learning_rate, betas=(0.5, 0.999))

# Paths for saving models and examples with indexed subfolders
base_model_save_dir = "C:/Users/ryan9/Documents/CV/Team Project 3/saved_models"
base_example_save_dir = "C:/Users/ryan9/Documents/CV/Team Project 3/examples"

# Create new experiment directories.
# create_experiment_dir only returns the next free path (it creates the base
# directory at most), so the experiment folders themselves are created here.
# The two base directories are indexed independently of each other.
model_save_dir = create_experiment_dir(base_model_save_dir)
example_save_dir = create_experiment_dir(base_example_save_dir)
os.makedirs(model_save_dir, exist_ok=True)
os.makedirs(example_save_dir, exist_ok=True)

print(f"Models will be saved to: {model_save_dir}")
print(f"Examples will be saved to: {example_save_dir}")

# Dummy Forward Pass Test
# This is a shape check only. It runs in eval mode and without autograd so the
# probe batch does not update BatchNorm running statistics or keep a graph
# alive before training starts. Train mode is restored afterwards.
print("Performing dummy forward pass to verify model dimensions...")
generator.eval()
discriminator.eval()
with torch.no_grad():
    dummy_noise = torch.randn(128, 100, 1, 1, device=device)
    dummy_labels = torch.randint(0, 4, (128,), device=device)
    fake_imgs = generator(dummy_noise, dummy_labels)
    output = discriminator(fake_imgs.detach(), dummy_labels)
print(f"Discriminator Output Shape: {output.shape}")  # Should be (128, 1)
generator.train()
discriminator.train()

# Training loop with increased epochs and validation
epochs = 100  # Increased number of epochs
latent_dim = 100
noise_std = 0.1  # Standard deviation for added Gaussian noise

# Preview inputs are sampled ONCE, before training, so the example grid saved
# after each epoch shows the same latent vectors and can be compared between
# epochs. Re-sampling them per epoch would make the grids incomparable.
fixed_noise = torch.randn(4, latent_dim, 1, 1, device=device)  # One per class
fixed_labels = torch.arange(0, 4, device=device)  # 0,1,2,3

for epoch in range(epochs):
    generator.train()
    discriminator.train()
    g_loss_epoch = 0
    d_loss_epoch = 0

    for i, (imgs, labels) in enumerate(tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs}")):
        imgs, labels = imgs.to(device), labels.to(device)
        batch_size = imgs.size(0)

        # One-sided label smoothing: real targets are 0.9, fake targets stay 0
        valid = torch.ones((batch_size, 1), device=device) * 0.9
        fake = torch.zeros((batch_size, 1), device=device)

        # ---- Train Discriminator ----
        optimizer_D.zero_grad()

        # Add Gaussian noise to real images
        real_imgs_noisy = imgs + noise_std * torch.randn_like(imgs)
        real_imgs_noisy = torch.clamp(real_imgs_noisy, -1, 1)  # Ensure images are still in [-1,1]

        noise = torch.randn(batch_size, latent_dim, 1, 1, device=device)
        gen_labels = torch.randint(0, 4, (batch_size,), device=device)  # 4 classes
        fake_imgs = generator(noise, gen_labels)

        real_loss = criterion(discriminator(real_imgs_noisy, labels), valid)
        # detach(): the discriminator step must not back-propagate into the generator
        fake_loss = criterion(discriminator(fake_imgs.detach(), gen_labels), fake)
        d_loss = (real_loss + fake_loss) / 2
        d_loss.backward()
        optimizer_D.step()

        # ---- Train Generator ----
        optimizer_G.zero_grad()

        # Add Gaussian noise to fake images
        fake_imgs_noisy = fake_imgs + noise_std * torch.randn_like(fake_imgs)
        fake_imgs_noisy = torch.clamp(fake_imgs_noisy, -1, 1)  # Ensure images are still in [-1,1]

        g_loss = criterion(discriminator(fake_imgs_noisy, gen_labels), valid)  # Target generator outputs as real
        g_loss.backward()
        optimizer_G.step()

        g_loss_epoch += g_loss.item()
        d_loss_epoch += d_loss.item()

    # Log and print losses for the epoch
    g_loss_epoch /= len(train_loader)
    d_loss_epoch /= len(train_loader)
    print(f"Epoch [{epoch+1}/{epochs}] - Generator Loss: {g_loss_epoch:.4f}, Discriminator Loss: {d_loss_epoch:.4f}")

    # Save models after each epoch
    torch.save(generator.state_dict(), os.path.join(model_save_dir, f"generator_epoch_{epoch+1}.pth"))
    torch.save(discriminator.state_dict(), os.path.join(model_save_dir, f"discriminator_epoch_{epoch+1}.pth"))

    # Validation step (eval mode: Dropout off, BatchNorm uses running statistics)
    generator.eval()
    discriminator.eval()
    val_d_loss = 0
    val_g_loss = 0
    with torch.no_grad():
        for imgs, labels in val_loader:
            imgs, labels = imgs.to(device), labels.to(device)
            batch_size = imgs.size(0)

            # Real images
            valid = torch.ones((batch_size, 1), device=device) * 0.9
            # Fake images
            fake = torch.zeros((batch_size, 1), device=device)

            # Add noise to real images
            real_imgs_noisy = imgs + noise_std * torch.randn_like(imgs)
            real_imgs_noisy = torch.clamp(real_imgs_noisy, -1, 1)

            # Discriminator loss on real images
            real_loss = criterion(discriminator(real_imgs_noisy, labels), valid)

            # Generate fake images
            noise = torch.randn(batch_size, latent_dim, 1, 1, device=device)
            gen_labels = torch.randint(0, 4, (batch_size,), device=device)  # 4 classes
            fake_imgs = generator(noise, gen_labels)

            # Add noise to fake images
            fake_imgs_noisy = fake_imgs + noise_std * torch.randn_like(fake_imgs)
            fake_imgs_noisy = torch.clamp(fake_imgs_noisy, -1, 1)

            # Discriminator loss on fake images
            fake_loss = criterion(discriminator(fake_imgs_noisy, gen_labels), fake)
            d_loss = (real_loss + fake_loss) / 2
            val_d_loss += d_loss.item()

            # Generator loss
            g_loss = criterion(discriminator(fake_imgs_noisy, gen_labels), valid)
            val_g_loss += g_loss.item()

    val_d_loss /= len(val_loader)
    val_g_loss /= len(val_loader)
    print(f"Validation - Generator Loss: {val_g_loss:.4f}, Discriminator Loss: {val_d_loss:.4f}")

    # Save generated examples after each epoch (generator is still in eval mode)
    with torch.no_grad():
        generated_imgs = generator(fixed_noise, fixed_labels)
        generated_imgs = (generated_imgs + 1) / 2  # Rescale to [0, 1]

        grid = torchvision.utils.make_grid(generated_imgs.cpu(), nrow=4)
        plt.figure(figsize=(16, 4))
        plt.imshow(grid.permute(1, 2, 0))
        plt.axis('off')
        plt.savefig(os.path.join(example_save_dir, f"epoch_{epoch+1}.png"))
        plt.close()


In [None]:
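# cGAN: image-to-image generator conditioned on a target label, with an auxiliary-classifier discriminator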

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
from torch.utils.data import DataLoader, random_split
from torchvision import transforms
from torchvision.datasets import ImageFolder
import os
import matplotlib.pyplot as plt
from tqdm import tqdm
from pathlib import Path
from PIL import Image
import torchvision.models as models
from torchmetrics.image.fid import FrechetInceptionDistance
from torchmetrics.image.inception import InceptionScore

# ------------------------------
# Helper Classes and Functions
# ------------------------------

class ResidualBlock(nn.Module):
    """Two 3x3 conv + batch-norm stages wrapped in an identity skip connection.

    Channel count and spatial size are preserved (kernel 3, padding 1), so the
    output has the same shape as the input.

    Args:
        in_channels: Number of input (and output) channels.
    """

    def __init__(self, in_channels):
        super().__init__()

        def conv3x3():
            # Bias-free because each convolution is followed by BatchNorm.
            return nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1, bias=False)

        stages = [
            conv3x3(),
            nn.BatchNorm2d(in_channels),
            nn.ReLU(inplace=True),
            conv3x3(),
            nn.BatchNorm2d(in_channels),
        ]
        self.block = nn.Sequential(*stages)

    def forward(self, x):
        """Return ``x`` plus the residual computed from ``x``."""
        residual = self.block(x)
        return x + residual

def preprocess_for_fid(images):
    """Convert images from the [-1, 1] float range to uint8 in [0, 255].

    Values are rounded to the nearest integer before the cast. A plain cast
    truncates toward zero, which shifts every pixel down by half a level on
    average and maps values just below 1.0 to 254 instead of 255.
    Out-of-range inputs are clamped.

    Args:
        images: Float tensor with values nominally in [-1, 1].

    Returns:
        uint8 tensor of the same shape.
    """
    images = (images + 1) * 127.5  # Scale from [-1, 1] to [0, 255]
    return images.round().clamp(0, 255).byte()
    
# Helper function for weight initialization
def weights_init(m):
    """Initialise one module in place; intended for use with ``Module.apply``.

    Any module whose class name contains ``Conv`` (which also covers
    ``ConvTranspose``) gets weights drawn from N(0, 0.02). Any module whose
    class name contains ``BatchNorm`` gets weights drawn from N(1, 0.02) and a
    zero bias. Other modules, and convolution biases, keep their defaults.
    """
    kind = type(m).__name__
    is_conv = 'Conv' in kind
    is_norm = 'BatchNorm' in kind
    if is_conv:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif is_norm:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0)

# Function to create a new experiment directory with an incremental index
def create_experiment_dir(base_dir):
    """Compute the next free ``experiment_<n>`` path inside ``base_dir``.

    Creates ``base_dir`` if it does not exist, but never creates the returned
    experiment folder itself. ``n`` is the largest numeric index found among
    existing ``experiment_*`` sub-directories plus one, or 1 if none is found.
    """
    if not os.path.exists(base_dir):
        os.makedirs(base_dir)
        return os.path.join(base_dir, "experiment_1")

    suffixes = (
        name.split("_")[1]
        for name in os.listdir(base_dir)
        if os.path.isdir(os.path.join(base_dir, name)) and name.startswith("experiment_")
    )
    last_index = max((int(s) for s in suffixes if s.isdigit()), default=0)
    return os.path.join(base_dir, f"experiment_{last_index + 1}")

# ------------------------------
# Updated Generator
# ------------------------------

class Generator(nn.Module):
    """Encoder / residual / decoder generator that re-renders an image for a target label.

    The encoder halves the spatial size four times (64 -> 4 for 64x64 inputs),
    three residual blocks refine the 512-channel bottleneck, the target-label
    embedding is tiled over the bottleneck and concatenated along the channel
    axis, and the decoder doubles the spatial size four times back to a
    3-channel Tanh output.

    Args:
        emb_size: Width of the label embedding.
        num_classes: Number of distinct target labels.
    """

    def __init__(self, emb_size=32, num_classes=4):
        super().__init__()
        self.emb_size = emb_size
        self.num_classes = num_classes
        self.label_embeddings = nn.Embedding(num_classes, emb_size)

        def down(c_in, c_out, norm=True):
            # Stride-2 conv (+ optional BatchNorm) + LeakyReLU: halves H and W.
            stage = [nn.Conv2d(c_in, c_out, 4, 2, 1)]
            if norm:
                stage.append(nn.BatchNorm2d(c_out))
            stage.append(nn.LeakyReLU(0.2, inplace=True))
            return stage

        def up(c_in, c_out):
            # Stride-2 transposed conv + BatchNorm + ReLU: doubles H and W.
            return [nn.ConvTranspose2d(c_in, c_out, 4, 2, 1), nn.BatchNorm2d(c_out), nn.ReLU(True)]

        # Encoder: (3, 64, 64) -> (512, 4, 4)
        self.encoder = nn.Sequential(
            *down(3, 64, norm=False),   # (64, 32, 32)
            *down(64, 128),             # (128, 16, 16)
            *down(128, 256),            # (256, 8, 8)
            *down(256, 512),            # (512, 4, 4)
        )

        # Bottleneck refinement
        self.residual_blocks = nn.Sequential(*[ResidualBlock(512) for _ in range(3)])

        # Decoder: (512 + emb_size, 4, 4) -> (3, 64, 64)
        self.decoder = nn.Sequential(
            *up(512 + emb_size, 256),            # (256, 8, 8)
            *up(256, 128),                       # (128, 16, 16)
            *up(128, 64),                        # (64, 32, 32)
            nn.ConvTranspose2d(64, 3, 4, 2, 1),  # (3, 64, 64)
            nn.Tanh(),
        )

        self.apply(weights_init)

    def forward(self, input_images, target_labels):
        """Translate ``input_images`` toward ``target_labels``.

        Args:
            input_images: Image batch with 3 channels.
            target_labels: Integer class indices, one per image.

        Returns:
            Image batch with 3 channels; Tanh bounds values to [-1, 1].
        """
        features = self.residual_blocks(self.encoder(input_images))
        height, width = features.shape[2:]

        # Tile the (batch, emb_size) embedding over the bottleneck's spatial grid.
        label_map = self.label_embeddings(target_labels)[:, :, None, None]
        label_map = label_map.repeat(1, 1, height, width)

        return self.decoder(torch.cat((features, label_map), dim=1))

# ------------------------------
# Updated Discriminator
# ------------------------------

class Discriminator(nn.Module):
    """Label-conditioned discriminator with an auxiliary class head (ACGAN style).

    The label embedding is tiled over the image and stacked onto its channels
    before the convolutional stack, and concatenated once more onto the
    flattened features before the two linear heads.

    Args:
        emb_size: Width of the label embedding.
        num_classes: Number of classes predicted by the auxiliary head.

    NOTE(review): the auxiliary head sees the conditioning label through the
    embedding, so it may be able to predict the class without looking at the
    image. Confirm this is intended.
    """

    def __init__(self, emb_size=32, num_classes=4):
        super(Discriminator, self).__init__()
        self.emb_size = emb_size
        self.num_classes = num_classes
        self.label_embeddings = nn.Embedding(num_classes, emb_size)

        self.model = nn.Sequential(
            nn.Conv2d(3 + emb_size, 64, 4, 2, 1, bias=False),  # (64, 32, 32)
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv2d(64, 128, 4, 2, 1, bias=False),           # (128, 16, 16)
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv2d(128, 256, 4, 2, 1, bias=False),          # (256, 8, 8)
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv2d(256, 512, 4, 2, 1, bias=False),          # (512, 4, 4)
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Flatten()
        )

        # Real/fake head: probability, to be used with BCELoss.
        self.adv_layer = nn.Sequential(
            nn.Linear(512 * 4 * 4 + emb_size, 1),
            nn.Sigmoid()
        )

        # Class head: emits raw logits. nn.CrossEntropyLoss applies log-softmax
        # itself, so a Softmax here would normalise twice and flatten the
        # gradients. Kept as a one-layer Sequential so the parameter names
        # (aux_layer.0.*) stay compatible with existing checkpoints.
        self.aux_layer = nn.Sequential(
            nn.Linear(512 * 4 * 4 + emb_size, num_classes)
        )

        self.apply(weights_init)

    def forward(self, input_images, labels):
        """Score a batch of images under the given labels.

        Args:
            input_images: Image batch with 3 channels; 64x64 is required for
                the 512*4*4 feature size the heads are built for.
            labels: Integer class indices, one per image.

        Returns:
            Tuple ``(validity, label_logits)``: ``validity`` has shape
            (batch, 1) with values in [0, 1]; ``label_logits`` has shape
            (batch, num_classes) and is unnormalised (apply ``softmax`` to
            obtain probabilities).
        """
        # Look the embedding up once and reuse it for both conditioning paths.
        label_vec = self.label_embeddings(labels)  # (batch_size, emb_size)
        label_map = label_vec[:, :, None, None].expand(
            -1, -1, input_images.size(2), input_images.size(3)
        )  # (batch_size, emb_size, H, W)

        # Concatenate image and label map along the channel axis
        x = torch.cat([input_images, label_map], dim=1)  # (batch_size, 3 + emb_size, H, W)
        x = self.model(x)  # (batch_size, 512 * 4 * 4)

        # Concatenate flattened features with the label embedding
        x = torch.cat([x, label_vec], dim=1)  # (batch_size, 512*4*4 + emb_size)

        validity = self.adv_layer(x)     # (batch_size, 1)
        label_pred = self.aux_layer(x)   # (batch_size, num_classes), logits
        return validity, label_pred

# ------------------------------
# Perceptual Loss
# ------------------------------

class PerceptualLoss(nn.Module):
    """MSE between frozen VGG19 feature maps of two image batches.

    Uses the first 35 layers of ``vgg19().features``. The pretrained VGG
    expects inputs in [0, 1] normalised with the ImageNet mean/std, so inputs
    are first mapped from ``input_range`` to [0, 1]. Without that step, images
    in [-1, 1] (Tanh output, ``Normalize(0.5, 0.5)`` data) would be fed to VGG
    on the wrong scale.

    Args:
        device: Device the module is moved to.
        input_range: ``(low, high)`` value range of the images passed to
            ``forward``. Defaults to (-1, 1); pass (0, 1) for inputs that are
            already in the unit range.

    Raises:
        ValueError: If ``input_range`` is not an increasing pair.
    """

    def __init__(self, device, input_range=(-1.0, 1.0)):
        super(PerceptualLoss, self).__init__()
        low, high = float(input_range[0]), float(input_range[1])
        if not high > low:
            raise ValueError(f"input_range must satisfy low < high, got {input_range!r}")
        self.input_low = low
        self.input_span = high - low

        vgg = models.vgg19(pretrained=True).features
        self.vgg_layers = nn.Sequential(*[vgg[i] for i in range(35)]).eval()
        # The feature extractor is a fixed metric, never trained.
        for param in self.vgg_layers.parameters():
            param.requires_grad = False
        self.criterion = nn.MSELoss()

        # ImageNet statistics, built once and moved with the module.
        # Non-persistent so they do not appear in the state_dict.
        self.register_buffer(
            "mean", torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1), persistent=False
        )
        self.register_buffer(
            "std", torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1), persistent=False
        )
        self.to(device)  # Move VGG layers and buffers to the specified device

    def _to_vgg_input(self, images):
        """Map images from ``input_range`` to [0, 1], then apply ImageNet normalisation."""
        unit = (images - self.input_low) / self.input_span
        return (unit - self.mean) / self.std

    def forward(self, x, y):
        """Return the mean squared error between the VGG features of ``x`` and ``y``.

        Args:
            x: Image batch of shape (batch, 3, H, W) in ``input_range``.
            y: Image batch of the same shape and range.
        """
        x_features = self.vgg_layers(self._to_vgg_input(x))
        y_features = self.vgg_layers(self._to_vgg_input(y))
        return self.criterion(x_features, y_features)

# ------------------------------
# Minibatch Discrimination (Removed)
# ------------------------------
# The MinibatchDiscrimination layer has been removed as it's not commonly used
# and can complicate training. The updated Discriminator uses Auxiliary Classifier (ACGAN) approach.

# ------------------------------
# Data Preparation
# ------------------------------

# Define dataset and transforms without CenterCrop
transform = transforms.Compose([
    transforms.Resize((64, 64)),  # Ensure consistent size
    transforms.RandomHorizontalFlip(),  # Data augmentation
    transforms.RandomRotation(10),      # Data augmentation
    transforms.ToTensor(),
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])  # Normalize to [-1, 1]
])

# Dataset over the cropped-faces directory; ImageFolder derives the class
# label from the subdirectory name.
dataset_path = "C:/Users/ryan9/Documents/CV/Team Project 3/cropped_faces"
dataset = ImageFolder(root=dataset_path, transform=transform)

# Split dataset into training and validation sets (80-20 split).
# Both subsets wrap the same `dataset`, so the random flip/rotation above is
# also applied to validation samples. No seed is set, so the split differs per run.
train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size
train_dataset, val_dataset = random_split(dataset, [train_size, val_size])

# DataLoaders for training and validation (only the training loader shuffles)
# To reduce verbosity, set num_workers to 0 if encountering issues with multiple workers
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True, num_workers=2, pin_memory=True)
val_loader = DataLoader(val_dataset, batch_size=64, shuffle=False, num_workers=2, pin_memory=True)

# ------------------------------
# Initialize Models, Optimizers, and Loss Functions
# ------------------------------

# Use the first GPU when available, otherwise fall back to the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

if device.type == 'cuda':
    # Limit the GPU memory usage to 80%
    torch.cuda.set_per_process_memory_fraction(0.8, device.index)

# Initialize Generator and Discriminator
generator = Generator(emb_size=32, num_classes=4).to(device)
discriminator = Discriminator(emb_size=32, num_classes=4).to(device)

# Loss functions
criterion_GAN = nn.BCELoss()            # real/fake head (probabilities)
# NOTE(review): CrossEntropyLoss expects raw logits; verify that the
# discriminator's auxiliary head does not apply its own Softmax.
criterion_aux = nn.CrossEntropyLoss()   # auxiliary class head
criterion_perceptual = PerceptualLoss(device).to(device)  # VGG19 feature distance
criterion_L1 = nn.L1Loss()              # pixel-wise reconstruction

# Optimizers (same learning rate and betas for both networks)
learning_rate = 0.0001
optimizer_G = optim.Adam(generator.parameters(), lr=learning_rate, betas=(0.5, 0.999))
optimizer_D = optim.Adam(discriminator.parameters(), lr=learning_rate, betas=(0.5, 0.999))

# ------------------------------
# Evaluation Metrics
# ------------------------------

# torchmetrics accumulators; they are fed with uint8 images produced by
# preprocess_for_fid during validation.
fid = FrechetInceptionDistance(feature=2048).to(device)
inception_score = InceptionScore().to(device)

# ------------------------------
# Paths for Saving Models and Examples
# ------------------------------

# Root folders for checkpoints, example grids and metric logs.
base_model_save_dir = "C:/Users/ryan9/Documents/CV/Team Project 3/saved_models"
base_example_save_dir = "C:/Users/ryan9/Documents/CV/Team Project 3/examples"
base_fid_save_dir = "C:/Users/ryan9/Documents/CV/Team Project 3/fid_scores"

# Create new experiment directories.
# Each root is indexed independently, so the three experiment numbers can
# differ if the roots do not contain the same set of earlier runs.
model_save_dir = create_experiment_dir(base_model_save_dir)
example_save_dir = create_experiment_dir(base_example_save_dir)
fid_save_dir = create_experiment_dir(base_fid_save_dir)

# create_experiment_dir only returns the path; the folders are created here.
os.makedirs(model_save_dir, exist_ok=True)
os.makedirs(example_save_dir, exist_ok=True)
os.makedirs(fid_save_dir, exist_ok=True)

print(f"Models will be saved to: {model_save_dir}")
print(f"Examples will be saved to: {example_save_dir}")
print(f"FID scores will be saved to: {fid_save_dir}")

# ------------------------------
# Dummy Forward Pass Test
# ------------------------------
# This is a shape check only. It runs in eval mode and without autograd so the
# random probe batch does not update BatchNorm running statistics or keep a
# graph alive before training starts. Train mode is restored afterwards.
print("Performing dummy forward pass to verify model dimensions...")
generator.eval()
discriminator.eval()
with torch.no_grad():
    dummy_noise = torch.randn(128, 100, 1, 1, device=device)  # Not used in updated Generator
    dummy_labels = torch.randint(0, 4, (128,), device=device)
    # Since Generator now requires input images, perform accordingly
    dummy_input = torch.randn(128, 3, 64, 64, device=device)
    fake_imgs = generator(dummy_input, dummy_labels)
    validity, label_pred = discriminator(fake_imgs.detach(), dummy_labels)
print(f"Discriminator Output Shape (Validity): {validity.shape}")    # Should be (128, 1)
print(f"Discriminator Output Shape (Label): {label_pred.shape}")    # Should be (128, num_classes)
generator.train()
discriminator.train()

# ------------------------------
# Training Loop
# ------------------------------

epochs = 100  # Number of epochs
latent_dim = 100  # Not used in updated Generator
noise_std = 0.1  # Standard deviation for Gaussian noise added to discriminator inputs
alpha = 0.5  # Weight for perceptual loss in the generator objective
beta = 0.5   # Weight for L1 loss in the generator objective

# Inputs reused for every saved example grid so that grids are comparable.
# NOTE: fixed_input is Gaussian noise shaped like an image batch, not real
# faces, although the generator is otherwise fed dataset images.
fixed_input = torch.randn(4, 3, 64, 64, device=device)  # Random input images
fixed_labels = torch.arange(0, 4, device=device)        # One label per class

# Initialize tqdm for epoch progress
epoch_pbar = tqdm(range(epochs), desc="Training Epochs", dynamic_ncols=True)

# Each epoch: train on train_loader, then validate on val_loader (losses, FID,
# Inception Score), update the progress bar, checkpoint both networks, save an
# example grid and append the metrics to the score log. All of these steps
# belong INSIDE the epoch loop so they run once per epoch, not once at the end.
for epoch in epoch_pbar:
    generator.train()
    discriminator.train()
    g_loss_epoch = 0
    d_loss_epoch = 0

    for imgs, labels in train_loader:
        imgs, labels = imgs.to(device), labels.to(device)
        batch_size = imgs.size(0)

        # ---------------------
        #  Train Discriminator
        # ---------------------
        optimizer_D.zero_grad()

        # Real images
        real_valid = torch.ones((batch_size, 1), device=device) * 0.9  # Label smoothing
        real_imgs_noisy = imgs + noise_std * torch.randn_like(imgs)
        real_imgs_noisy = torch.clamp(real_imgs_noisy, -1, 1)  # Keep within [-1,1]

        # Discriminator forward pass on real images
        real_pred, real_aux = discriminator(real_imgs_noisy, labels)
        d_real_loss = criterion_GAN(real_pred, real_valid)
        d_real_aux_loss = criterion_aux(real_aux, labels)

        # Generate fake images
        target_labels = torch.randint(0, 4, (batch_size,), device=device)
        fake_imgs = generator(imgs, target_labels)
        fake_imgs_noisy = fake_imgs + noise_std * torch.randn_like(fake_imgs)
        fake_imgs_noisy = torch.clamp(fake_imgs_noisy, -1, 1)

        # Discriminator forward pass on fake images
        # (detach: the discriminator step must not back-propagate into the generator)
        fake_valid = torch.zeros((batch_size, 1), device=device)
        fake_pred, fake_aux = discriminator(fake_imgs_noisy.detach(), target_labels)
        d_fake_loss = criterion_GAN(fake_pred, fake_valid)
        d_fake_aux_loss = criterion_aux(fake_aux, target_labels)

        # Total Discriminator loss
        d_loss = (d_real_loss + d_fake_loss) + (d_real_aux_loss + d_fake_aux_loss)
        d_loss.backward()
        optimizer_D.step()

        # -----------------
        #  Train Generator
        # -----------------
        optimizer_G.zero_grad()

        # Generate fake images again so the graph reaches the generator
        fake_imgs = generator(imgs, target_labels)
        fake_imgs_noisy = fake_imgs + noise_std * torch.randn_like(fake_imgs)
        fake_imgs_noisy = torch.clamp(fake_imgs_noisy, -1, 1)

        # Discriminator evaluates the fake images
        pred_fake, aux_fake = discriminator(fake_imgs_noisy, target_labels)

        # Generator adversarial loss
        g_adv = criterion_GAN(pred_fake, real_valid)  # Wants discriminator to believe it's real

        # Auxiliary loss
        g_aux = criterion_aux(aux_fake, target_labels)

        # Perceptual loss
        g_perc = criterion_perceptual(fake_imgs, imgs)

        # L1 loss (keeps the output close to the input image)
        g_L1 = criterion_L1(fake_imgs, imgs)

        # Total Generator loss
        g_loss = g_adv + alpha * g_perc + beta * g_L1 + g_aux
        g_loss.backward()
        optimizer_G.step()

        # Accumulate epoch losses
        g_loss_epoch += g_loss.item()
        d_loss_epoch += d_loss.item()

    # Calculate average losses for the epoch
    g_loss_epoch /= len(train_loader)
    d_loss_epoch /= len(train_loader)

    # ---------------------
    #  Validation and Evaluation
    # ---------------------
    generator.eval()
    discriminator.eval()
    val_d_loss = 0
    val_g_loss = 0
    fid.reset()
    inception_score.reset()

    with torch.no_grad():
        for imgs, labels in val_loader:
            imgs, labels = imgs.to(device), labels.to(device)
            batch_size = imgs.size(0)

            # Preprocess real images for FID
            real_imgs_uint8 = preprocess_for_fid(imgs)

            # Generate fake images
            target_labels = torch.randint(0, 4, (batch_size,), device=device)
            fake_imgs = generator(imgs, target_labels)

            # Preprocess fake images once; used for both FID and Inception Score
            fake_imgs_uint8 = preprocess_for_fid(fake_imgs)

            # Update FID
            fid.update(fake_imgs_uint8, real=False)
            fid.update(real_imgs_uint8, real=True)

            # Update Inception Score (uint8 input)
            inception_score.update(fake_imgs_uint8)

            # Discriminator loss
            real_valid = torch.ones((batch_size, 1), device=device) * 0.9
            real_imgs_noisy = imgs + noise_std * torch.randn_like(imgs)
            real_imgs_noisy = torch.clamp(real_imgs_noisy, -1, 1)
            real_pred, real_aux = discriminator(real_imgs_noisy, labels)
            d_real_loss = criterion_GAN(real_pred, real_valid)
            d_real_aux_loss = criterion_aux(real_aux, labels)

            fake_valid = torch.zeros((batch_size, 1), device=device)
            fake_imgs_noisy = fake_imgs + noise_std * torch.randn_like(fake_imgs)
            fake_imgs_noisy = torch.clamp(fake_imgs_noisy, -1, 1)
            fake_pred, fake_aux = discriminator(fake_imgs_noisy, target_labels)
            d_fake_loss = criterion_GAN(fake_pred, fake_valid)
            d_fake_aux_loss = criterion_aux(fake_aux, target_labels)

            # Generator loss. In eval mode the discriminator is deterministic,
            # so the predictions computed just above are reused instead of
            # running a second identical forward pass.
            pred_fake, aux_fake = fake_pred, fake_aux
            g_adv = criterion_GAN(pred_fake, real_valid)
            g_aux = criterion_aux(aux_fake, target_labels)
            g_perc = criterion_perceptual(fake_imgs, imgs)
            g_L1 = criterion_L1(fake_imgs, imgs)
            g_loss = g_adv + alpha * g_perc + beta * g_L1 + g_aux

            # Update validation losses
            val_d_loss += (d_real_loss + d_fake_loss + d_real_aux_loss + d_fake_aux_loss).item()
            val_g_loss += g_loss.item()

    # Average validation losses over the validation batches
    val_d_loss /= len(val_loader)
    val_g_loss /= len(val_loader)

    # Compute FID and Inception Score
    current_fid = fid.compute().item()
    current_is, current_is_std = inception_score.compute()

    # Reset metrics
    fid.reset()
    inception_score.reset()

    # Update progress bar with epoch-level metrics
    epoch_pbar.set_postfix({
        'G_Loss': f"{g_loss_epoch:.4f}",
        'D_Loss': f"{d_loss_epoch:.4f}",
        'Val_G_Loss': f"{val_g_loss:.4f}",
        'Val_D_Loss': f"{val_d_loss:.4f}",
        'FID': f"{current_fid:.2f}",
        'IS': f"{current_is:.2f}±{current_is_std:.2f}"
    })

    # ---------------------
    #  Save Models
    # ---------------------
    torch.save(generator.state_dict(), os.path.join(model_save_dir, f"generator_epoch_{epoch+1}.pth"))
    torch.save(discriminator.state_dict(), os.path.join(model_save_dir, f"discriminator_epoch_{epoch+1}.pth"))

    # ---------------------
    #  Save Generated Examples
    # ---------------------
    with torch.no_grad():
        generated_imgs = generator(fixed_input, fixed_labels)
        generated_imgs = (generated_imgs + 1) / 2  # Rescale to [0, 1]

        grid = torchvision.utils.make_grid(generated_imgs.cpu(), nrow=4)
        plt.figure(figsize=(16, 4))
        plt.imshow(grid.permute(1, 2, 0))
        plt.axis('off')
        plt.savefig(os.path.join(example_save_dir, f"epoch_{epoch+1}.png"))
        plt.close()

    # ---------------------
    #  Save FID Score
    # ---------------------
    with open(os.path.join(fid_save_dir, "fid_scores.txt"), "a") as f:
        f.write(f"Epoch {epoch+1}: FID = {current_fid:.2f}, IS = {current_is:.2f}±{current_is_std:.2f}\n")

print("Training completed!")

# ------------------------------
# Re-aging Function for All Categories
# ------------------------------

def re_age_faces_all_categories(generator, input_images, original_ages, save_dir, device, num_categories=4):
    """
    Re-age every input face to each age category and save the results as PNG files.

    Each image is passed through the generator once per target category. The
    outputs are written to ``save_dir`` as
    ``image_{i}_original_age_{original}_target_age_{target}.png``.

    Args:
        generator (nn.Module): Trained generator model, called as
            ``generator(image, label)``. It is switched to eval mode and is
            left in eval mode when the function returns. It is not moved to
            ``device`` here, so it is expected to already be on that device.
        input_images (torch.Tensor): Batch of input images, iterated along dim 0.
        original_ages (torch.Tensor or sequence of int): Original age label for
            each input image, in the same order as ``input_images``.
        save_dir (str): Directory to save re-aged images; created if missing.
        device (torch.device): Device to run computations on.
        num_categories (int): Number of target age categories. Labels
            ``0 .. num_categories - 1`` are generated. Defaults to 4.
    """
    generator.eval()
    os.makedirs(save_dir, exist_ok=True)

    with torch.no_grad():
        # The target labels do not depend on the image, so build them once.
        target_labels = [
            torch.tensor([target_age], device=device)
            for target_age in range(num_categories)
        ]

        # Iterate through all examples
        for i, (img, original_age) in enumerate(zip(input_images, original_ages)):
            img = img.unsqueeze(0).to(device)  # Add batch dimension

            # Labels may arrive as single-element tensors or as plain Python ints.
            if torch.is_tensor(original_age):
                original_age = original_age.item()

            # Generate one re-aged image per target category
            for target_age, target_age_tensor in enumerate(target_labels):
                re_aged_img = generator(img, target_age_tensor)

                # Rescale to [0, 1] (assumes generator output in [-1, 1]) and
                # clamp: any value outside [0, 1] would otherwise wrap around
                # in the uint8 cast below and show up as speckle artifacts.
                re_aged_img = ((re_aged_img + 1) / 2).clamp(0, 1)

                # Drop the batch dimension, move channels last, and convert to
                # an 8-bit image before saving with a descriptive filename.
                img_np = re_aged_img.squeeze(0).cpu().permute(1, 2, 0).numpy()
                img_pil = Image.fromarray((img_np * 255).astype("uint8"))
                filename = f"image_{i}_original_age_{original_age}_target_age_{target_age}.png"
                img_pil.save(os.path.join(save_dir, filename))

# ------------------------------
# Test Re-aging on Dataset
# ------------------------------

# Take up to 20 examples from the validation set
examples_to_process = 20

# A single batch may hold fewer than `examples_to_process` images, so keep
# pulling batches until enough have been gathered (or the loader runs out).
image_batches, label_batches = [], []
num_collected = 0
for batch_images, batch_labels in val_loader:
    image_batches.append(batch_images)
    label_batches.append(batch_labels)
    num_collected += len(batch_images)
    if num_collected >= examples_to_process:
        break

# Limit to 20 examples
sample_images = torch.cat(image_batches)[:examples_to_process]
sample_labels = torch.cat(label_batches)[:examples_to_process]

# Save re-aged images
re_age_faces_all_categories(
    generator,
    sample_images,
    original_ages=sample_labels,
    save_dir="C:/Users/ryan9/Documents/CV/Team Project 3/re_aged_faces_all_categories",
    device=device
)
