In [None]:
###############################################################################
# CELL 1: IMPORTS AND LIBRARIES
###############################################################################
"""
Face GAN - Deep Convolutional GAN for Face Generation
----------------------------------------------------
Import required libraries and modules for dataset processing,
neural network creation, training, and visualization.
"""
import datetime
import json
import multiprocessing
import os
import random
import re
import shutil
import sys
import time
from pathlib import Path

import matplotlib.animation as animation
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.backends.cudnn as cudnn
import torch.nn as nn
import torch.nn.parallel
import torch.optim as optim
import torch.utils.data
import torchvision.transforms as transforms
import torchvision.utils as vutils
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from tqdm.notebook import tqdm  # Use tqdm.notebook for Jupyter progress bars

# Try importing optional libraries
try:
    from torch.utils.tensorboard import SummaryWriter
    TENSORBOARD_AVAILABLE = True
except ImportError:
    TENSORBOARD_AVAILABLE = False
    print("TensorBoard not available. Install with: pip install tensorboard")

# For resource monitoring
try:
    import psutil
    import GPUtil
    RESOURCE_MONITORING = True
except ImportError:
    RESOURCE_MONITORING = False
    print("Resource monitoring unavailable. Install with: pip install psutil gputil")

# Set display settings for the notebook
%matplotlib inline
plt.rcParams['figure.figsize'] = (10.0, 8.0)


In [None]:
###############################################################################
# CELL 2: CONFIGURATION SETTINGS
###############################################################################
"""
Configuration Settings
---------------------
Define paths, hyperparameters, and training settings.
Modify these values to adapt the model to your specific requirements.
"""
# Paths
# A notebook has no __file__, so every path is anchored at the kernel's
# working directory.
BASE_DIR = os.path.abspath(os.getcwd())
# Components are joined one by one so the separators match the host OS
# (hard-coded backslashes only resolve on Windows).
CELEBDF_PATH = os.path.join(BASE_DIR, "faces", "Real", "Celeb_V2", "Train", "real")
FF_PATH = os.path.join(
    BASE_DIR, "faces", "Real", "FaceForensics++",
    "original_sequences", "youtube", "c23", "frames"
)
PROCESSED_PATH = os.path.join(BASE_DIR, "processed_faces")
OUTPUT_PATH = os.path.join(BASE_DIR, "output")
CHECKPOINT_DIR = os.path.join(OUTPUT_PATH, "checkpoints")
LOG_DIR = os.path.join(OUTPUT_PATH, "logs")

# Dataset processing settings
PROCESS_DATASETS = True  # Set to False to skip dataset processing if already processed
CELEBDF_MAX_PER_FACE = None  # Max images per face for CelebDF (None = no limit)
FF_MAX_PER_FACE = None  # Max images per face for FaceForensics++ (None = no limit)
FF_MAX_FACES = None  # Maximum number of different faces from FaceForensics++ (None = no limit)
TARGET_SIZE = 128  # Size to resize images to

# Model hyperparameters
CUDA = True  # Use CUDA (will be auto-detected later)
BATCH_SIZE = 32
IMAGE_CHANNEL = 3  # RGB images
Z_DIM = 100  # Latent vector dimension
G_HIDDEN = 64  # Generator hidden dimension
D_HIDDEN = 64  # Discriminator hidden dimension
X_DIM = 128  # Target image size
EPOCH_NUM = 10  # Number of training epochs
REAL_LABEL = 1
FAKE_LABEL = 0
lr = 2e-4  # Learning rate
seed = 1  # Random seed for reproducibility

# Training control settings
CHECKPOINT_FREQ = 5  # Save checkpoints every N epochs
CHECKPOINT_SAMPLES = 1000  # Generate image samples every N iterations
RESUME_TRAINING = True  # Try to resume from checkpoint if available
EARLY_STOPPING_PATIENCE = 5  # Early stopping after N epochs without improvement
EARLY_STOPPING_THRESHOLD = 0.01  # Minimum improvement to reset patience counter
RESOURCE_CHECK_FREQ = 50  # Check system resources every N batches
MAX_TRAINING_TIME = None  # Max training time in hours (None for no limit)
ENABLE_TENSORBOARD = TENSORBOARD_AVAILABLE  # Enable TensorBoard logging
EXPERIMENT_NAME = f"face_gan_{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}"  # Unique name for this run

# Print some key configuration values
print(f"Dataset paths:\n- CelebDF: {CELEBDF_PATH}\n- FaceForensics++: {FF_PATH}")
print(f"Processed data will be saved to: {PROCESSED_PATH}")
print(f"Output will be saved to: {OUTPUT_PATH}")
print(f"Target image size: {TARGET_SIZE}x{TARGET_SIZE}")
print(f"Batch size: {BATCH_SIZE}")
print(f"Training epochs: {EPOCH_NUM}")


Dataset paths:
- CelebDF: c:\Users\vinay\Documents\mnist\faces\Real\Celeb_V2\Train\real
- FaceForensics++: c:\Users\vinay\Documents\mnist\faces\Real\FaceForensics++\original_sequences\youtube\c23\frames
Processed data will be saved to: c:\Users\vinay\Documents\mnist\processed_faces
Output will be saved to: c:\Users\vinay\Documents\mnist\output
Target image size: 128x128
Batch size: 32
Training epochs: 10


In [None]:
###############################################################################
# CELL 3: HELPER FUNCTIONS
###############################################################################
"""
Helper Functions
--------------
Create output directories, verify dataset paths,
save configuration, and monitor system resources.
"""
def create_output_directories():
    """Ensure every folder used by the pipeline exists and report them.

    Creates PROCESSED_PATH, OUTPUT_PATH, CHECKPOINT_DIR and LOG_DIR, plus the
    "celebdf", "faceforensics" and "combined" folders under PROCESSED_PATH.
    Existing folders are left untouched.

    Returns:
        tuple: (celebdf_dir, ff_dir, combined_dir) as path strings.
    """
    # One processed-data folder per dataset, plus the merged one
    dataset_dirs = tuple(
        os.path.join(PROCESSED_PATH, leaf)
        for leaf in ("celebdf", "faceforensics", "combined")
    )

    # Shared folders first, then the per-dataset folders
    for folder in (PROCESSED_PATH, OUTPUT_PATH, CHECKPOINT_DIR, LOG_DIR, *dataset_dirs):
        os.makedirs(folder, exist_ok=True)

    celebdf_dir, ff_dir, combined_dir = dataset_dirs

    print(f"Created output directories:")
    print(f" - CelebDF: {celebdf_dir}")
    print(f" - FaceForensics++: {ff_dir}")
    print(f" - Combined: {combined_dir}")
    print(f" - Results: {OUTPUT_PATH}")
    print(f" - Checkpoints: {CHECKPOINT_DIR}")
    print(f" - Logs: {LOG_DIR}")

    return celebdf_dir, ff_dir, combined_dir

def verify_dataset_paths():
    """Report whether the raw CelebDF and FaceForensics++ paths are present.

    Returns:
        tuple: (celebdf_exists, ff_exists) booleans for CELEBDF_PATH and FF_PATH.
    """
    celebdf_exists, ff_exists = (
        os.path.exists(candidate) for candidate in (CELEBDF_PATH, FF_PATH)
    )

    print(f"CelebDF dataset: {'Found' if celebdf_exists else 'Not found'} at {CELEBDF_PATH}")
    print(f"FaceForensics++ dataset: {'Found' if ff_exists else 'Not found'} at {FF_PATH}")

    return celebdf_exists, ff_exists

def save_config():
    """Write the current run settings to a JSON file in OUTPUT_PATH.

    The file is named "<EXPERIMENT_NAME>_config.json" and holds the listed
    module-level settings plus an ISO-format timestamp.

    Returns:
        str: Path of the written JSON file.
    """
    # Module-level settings to record, in the order they appear in the file
    recorded_settings = (
        'CELEBDF_PATH',
        'FF_PATH',
        'PROCESSED_PATH',
        'OUTPUT_PATH',
        'CHECKPOINT_DIR',
        'PROCESS_DATASETS',
        'CELEBDF_MAX_PER_FACE',
        'FF_MAX_PER_FACE',
        'FF_MAX_FACES',
        'TARGET_SIZE',
        'BATCH_SIZE',
        'IMAGE_CHANNEL',
        'Z_DIM',
        'G_HIDDEN',
        'D_HIDDEN',
        'X_DIM',
        'EPOCH_NUM',
        'lr',
        'seed',
        'EXPERIMENT_NAME',
    )
    module_settings = globals()
    config = {name: module_settings[name] for name in recorded_settings}
    config['timestamp'] = datetime.datetime.now().isoformat()

    config_path = os.path.join(OUTPUT_PATH, f"{EXPERIMENT_NAME}_config.json")
    with open(config_path, 'w') as f:
        json.dump(config, f, indent=4)

    print(f"Configuration saved to {config_path}")
    return config_path

def monitor_resources():
    """Take a snapshot of CPU, memory and GPU usage.

    Returns:
        str: A notice when psutil/GPUtil could not be imported.
        dict: Otherwise, with keys "cpu_percent", "memory_percent",
        "gpu_info" (a one-line description or an error/fallback message)
        and "timestamp" (ISO format). Callers must handle both shapes.
    """
    if not RESOURCE_MONITORING:
        return "Resource monitoring unavailable. Install psutil and gputil."

    # CPU usage is sampled over a one-second interval
    cpu_percent = psutil.cpu_percent(interval=1)
    memory_percent = psutil.virtual_memory().percent

    report = {
        "cpu_percent": cpu_percent,
        "memory_percent": memory_percent,
        "gpu_info": _describe_first_gpu(),
    }
    report["timestamp"] = datetime.datetime.now().isoformat()
    return report


def _describe_first_gpu():
    """Return a one-line status string for the first GPU, or a fallback message."""
    fallback = "No GPU available"
    if not torch.cuda.is_available():
        return fallback

    try:
        gpus = GPUtil.getGPUs()
        if not gpus:
            return fallback

        gpu = gpus[0]  # Only the first GPU is reported
        gpu_name = gpu.name
        gpu_load = f"{gpu.load * 100:.1f}%"
        gpu_mem_used = f"{gpu.memoryUsed:.0f}MB"
        gpu_mem_total = f"{gpu.memoryTotal:.0f}MB"
        gpu_mem_percent = f"{(gpu.memoryUsed / gpu.memoryTotal) * 100:.1f}%"
        gpu_temp = f"{gpu.temperature}°C"
        return f"{gpu_name}: {gpu_load} load, {gpu_mem_used}/{gpu_mem_total} ({gpu_mem_percent}), {gpu_temp}"
    except Exception as e:
        # Any GPUtil failure is reported in the string rather than raised
        return f"Error getting GPU info: {e}"

# Create directories and verify paths
# The returned folder paths and existence flags are module-level state that
# the processing, combining and loading cells below read.
celebdf_dir, ff_dir, combined_dir = create_output_directories()
celebdf_exists, ff_exists = verify_dataset_paths()
# As the cell's final expression, the returned config path is echoed as output.
save_config()


Created output directories:
 - CelebDF: c:\Users\vinay\Documents\mnist\processed_faces\celebdf
 - FaceForensics++: c:\Users\vinay\Documents\mnist\processed_faces\faceforensics
 - Combined: c:\Users\vinay\Documents\mnist\processed_faces\combined
 - Results: c:\Users\vinay\Documents\mnist\output
 - Checkpoints: c:\Users\vinay\Documents\mnist\output\checkpoints
 - Logs: c:\Users\vinay\Documents\mnist\output\logs
CelebDF dataset: Found at c:\Users\vinay\Documents\mnist\faces\Real\Celeb_V2\Train\real
FaceForensics++ dataset: Found at c:\Users\vinay\Documents\mnist\faces\Real\FaceForensics++\original_sequences\youtube\c23\frames
Configuration saved to c:\Users\vinay\Documents\mnist\output\face_gan_20250304_161130_config.json


'c:\\Users\\vinay\\Documents\\mnist\\output\\face_gan_20250304_161130_config.json'

In [None]:

###############################################################################
# CELL 4: CELEBDF DATASET PROCESSING
###############################################################################
"""
CelebDF Dataset Processing
------------------------
Process the CelebDF dataset by resizing every image to the target size
and saving it under a standardized, sequentially numbered file name.
No face detection or cropping is performed here.
"""
def process_celebdf_dataset(source_dir, target_dir, target_size=(128, 128), max_images=None):
    """Resize every image in a flat CelebDF folder and save numbered copies.

    Args:
        source_dir: Folder holding the raw images (not searched recursively).
        target_dir: Folder receiving the resized images; created if missing.
        target_size: (width, height) passed to PIL's ``resize``.
        max_images: Optional cap. When set and exceeded, a reproducible random
            subset of this size is processed. ``None`` (or 0) means no cap.

    Returns:
        int: Number of images written. Files that fail to open, resize or
        save are reported and skipped.
    """
    source_path = Path(source_dir)
    target_path = Path(target_dir)

    if not source_path.is_dir():
        print(f"Source directory {source_path} does not exist.")
        return 0

    target_path.mkdir(parents=True, exist_ok=True)

    # Match on the lower-cased suffix in a single directory pass. Globbing each
    # extension in both lower and upper case lists every file twice on
    # case-insensitive filesystems such as Windows.
    image_extensions = {'.jpg', '.jpeg', '.png'}
    image_files = sorted(
        p for p in source_path.iterdir()
        if p.is_file() and p.suffix.lower() in image_extensions
    )

    # If maximum image limit is set, randomly sample
    if max_images and len(image_files) > max_images:
        # A private RNG keeps the subset reproducible without reseeding the
        # global generator used elsewhere in the notebook.
        image_files = random.Random(1).sample(image_files, max_images)

    print(f"Found {len(image_files)} total images in {source_path}")

    # Process all images directly with progress bar
    count = 0
    with tqdm(total=len(image_files), desc="Processing CelebDF images", unit="img") as pbar:
        for img_path in image_files:
            try:
                # The context manager closes the source file right away;
                # resize() returns an independent in-memory image.
                with Image.open(img_path) as img:
                    resized = img.resize(target_size, Image.LANCZOS)

                # `count` only advances on success, so output names stay contiguous
                target_file = target_path / f"celebdf_{count:06d}{img_path.suffix}"
                resized.save(target_file)
                count += 1
            except Exception as e:
                print(f"\nError processing {img_path.name}: {e}")
            finally:
                # Advance for failures too so the bar reaches its total
                pbar.update(1)

    print(f"Successfully processed {count} images from CelebDF to {target_path}")
    return count

# Only run if PROCESS_DATASETS is True and CelebDF dataset exists
if PROCESS_DATASETS and celebdf_exists:
    count = process_celebdf_dataset(
        CELEBDF_PATH,
        celebdf_dir,
        target_size=(TARGET_SIZE, TARGET_SIZE),
        # Honor the limit from the configuration cell (None processes all images)
        max_images=CELEBDF_MAX_PER_FACE
    )
    print(f"Processed {count} CelebDF images")
elif not PROCESS_DATASETS:
    print("Skipping CelebDF processing. Set PROCESS_DATASETS=True to process.")
else:
    # Processing was requested but the raw dataset folder is missing
    print(f"Skipping CelebDF processing. Dataset not found at {CELEBDF_PATH}")


Found 80576 total images in c:\Users\vinay\Documents\mnist\faces\Real\Celeb_V2\Train\real


Processing CelebDF images:   0%|          | 0/80576 [00:00<?, ?img/s]

Successfully processed 80576 images from CelebDF to c:\Users\vinay\Documents\mnist\processed_faces\celebdf
Processed 80576 CelebDF images


In [None]:
###############################################################################
# CELL 5: FACEFORENSICS++ DATASET PROCESSING
###############################################################################
"""
FaceForensics++ Dataset Processing
-------------------------------
Process the FaceForensics++ dataset by extracting face folders,
resizing images, and saving them to a standardized format.
"""
def process_faceforensics_dataset(source_dir, target_dir, target_size=(128, 128), max_folders=None, max_images_per_folder=None):
    """Resize the images of a FaceForensics++ frame dump and save numbered copies.

    The source is expected to contain one sub-folder per face/video, each
    holding image files directly. Outputs are named
    ``ff_<folder index>_<image index><suffix>``.

    Args:
        source_dir: Folder containing the per-face sub-folders.
        target_dir: Folder receiving the resized images; created if missing.
        target_size: (width, height) passed to PIL's ``resize``.
        max_folders: Optional cap on the number of sub-folders; a random
            subset is taken when exceeded. ``None`` (or 0) means no cap.
        max_images_per_folder: Optional cap on images per sub-folder; a random
            subset is taken when exceeded. ``None`` (or 0) means no cap.

    Returns:
        int: Number of images written. Files that fail to open, resize or
        save are reported and skipped.
    """
    source_path = Path(source_dir)
    target_path = Path(target_dir)

    if not source_path.is_dir():
        print(f"Source directory {source_path} does not exist.")
        return 0

    target_path.mkdir(parents=True, exist_ok=True)
    image_extensions = {'.jpg', '.jpeg', '.png'}

    # Sorted so folder indices in the output names are stable across runs
    face_folders = sorted(f for f in source_path.iterdir() if f.is_dir())

    # Limit number of folders if specified
    if max_folders and len(face_folders) > max_folders:
        face_folders = random.sample(face_folders, max_folders)

    print(f"Found {len(face_folders)} face folders in {source_path}")

    # Scan each folder once: the same lists drive both the progress-bar total
    # and the processing loop. Matching on the lower-cased suffix avoids the
    # duplicate hits that per-case globbing produces on case-insensitive
    # filesystems such as Windows.
    images_per_folder = []
    for folder in face_folders:
        image_files = sorted(
            p for p in folder.iterdir()
            if p.is_file() and p.suffix.lower() in image_extensions
        )

        # Limit number of images per folder if specified
        if max_images_per_folder and len(image_files) > max_images_per_folder:
            image_files = random.sample(image_files, max_images_per_folder)

        images_per_folder.append(image_files)

    total_images = sum(len(image_files) for image_files in images_per_folder)

    # Process images in each folder with progress bar
    count = 0
    with tqdm(total=total_images, desc="Processing FaceForensics++ images", unit="img") as pbar:
        for folder_idx, image_files in enumerate(images_per_folder):
            # This per-image loop was missing: without it no image was ever
            # read and img_path / img_idx were undefined.
            for img_idx, img_path in enumerate(image_files):
                try:
                    # Close the source file right away; resize() returns an
                    # independent in-memory image.
                    with Image.open(img_path) as img:
                        resized = img.resize(target_size, Image.LANCZOS)

                    # Save to target directory with folder index as face ID
                    target_file = target_path / f"ff_{folder_idx:04d}_{img_idx:04d}{img_path.suffix}"
                    resized.save(target_file)
                    count += 1
                except Exception as e:
                    print(f"\nError processing {img_path.name}: {e}")
                finally:
                    # Advance for failures too so the bar reaches its total
                    pbar.update(1)

    print(f"Successfully processed {count} images from FaceForensics++ to {target_path}")
    return count

In [None]:
# Only run if PROCESS_DATASETS is True and FaceForensics++ dataset exists
if PROCESS_DATASETS and ff_exists:
    count = process_faceforensics_dataset(
        FF_PATH,
        ff_dir,
        target_size=(TARGET_SIZE, TARGET_SIZE),
        max_folders=FF_MAX_FACES,
        max_images_per_folder=FF_MAX_PER_FACE
    )
    print(f"Processed {count} FaceForensics++ images")
elif not PROCESS_DATASETS:
    print("Skipping FaceForensics++ processing. Set PROCESS_DATASETS=True to process.")
else:
    # Processing was requested but the raw dataset folder is missing
    print(f"Skipping FaceForensics++ processing. Dataset not found at {FF_PATH}")


Found 999 face folders in c:\Users\vinay\Documents\mnist\faces\Real\FaceForensics++\original_sequences\youtube\c23\frames


Processing FaceForensics++ images:   0%|          | 0/63898 [00:00<?, ?img/s]

Successfully processed 63898 images from FaceForensics++ to c:\Users\vinay\Documents\mnist\processed_faces\faceforensics
Processed 63898 FaceForensics++ images


In [None]:
###############################################################################
# CELL 6: COMBINE DATASETS
###############################################################################
"""
Dataset Combination
----------------
Combine the processed CelebDF and FaceForensics++ datasets 
into a single dataset for training. This ensures we have a diverse
set of facial images.
"""
def combine_datasets(celebdf_dir, ff_dir, combined_dir):
    """Copy the processed CelebDF and FaceForensics++ images into one folder.

    Args:
        celebdf_dir: Folder with processed CelebDF images.
        ff_dir: Folder with processed FaceForensics++ images.
        combined_dir: Destination folder; created if missing. Files keep
            their names, so an existing file of the same name is overwritten.

    Returns:
        int: Number of files actually copied. Failed copies are reported
        and not counted.
    """
    image_extensions = {'.jpg', '.jpeg', '.png'}

    def list_images(folder):
        # Case-insensitive suffix match: the processing step keeps the source
        # suffix, so .jpeg and upper-case files must be picked up as well.
        folder = Path(folder)
        if not folder.is_dir():
            return []
        return sorted(
            p for p in folder.iterdir()
            if p.is_file() and p.suffix.lower() in image_extensions
        )

    celebdf_files = list_images(celebdf_dir)
    ff_files = list_images(ff_dir)

    destination = Path(combined_dir)
    destination.mkdir(parents=True, exist_ok=True)

    total_files = len(celebdf_files) + len(ff_files)
    print(f"Combining {len(celebdf_files)} CelebDF images and {len(ff_files)} FaceForensics++ images...")

    copied = 0
    with tqdm(total=total_files, desc="Combining datasets", unit="img") as pbar:
        for img_path in celebdf_files + ff_files:
            try:
                # Byte-for-byte copy: re-saving through PIL would decode and
                # re-compress every JPEG a second time.
                shutil.copy2(img_path, destination / img_path.name)
                copied += 1
            except OSError as e:
                print(f"\nError copying {img_path.name}: {e}")
            finally:
                pbar.update(1)

    print(f"Combined dataset created with {copied} total images at {combined_dir}")
    return copied

# Only run if both datasets have been processed
if PROCESS_DATASETS and celebdf_exists and ff_exists:
    total_images = combine_datasets(celebdf_dir, ff_dir, combined_dir)
    dataset_path = combined_dir
    print(f"Combined dataset created with {total_images} images")
elif not PROCESS_DATASETS and os.path.exists(combined_dir) and len(os.listdir(combined_dir)) > 0:
    # Processing was skipped on purpose: reuse the previously combined data
    # instead of silently falling back to a single dataset.
    dataset_path = combined_dir
    print(f"Using existing combined dataset at {combined_dir}")
elif celebdf_exists and os.path.exists(celebdf_dir) and len(os.listdir(celebdf_dir)) > 0:
    dataset_path = celebdf_dir
    print(f"Using CelebDF dataset at {celebdf_dir}")
elif ff_exists and os.path.exists(ff_dir) and len(os.listdir(ff_dir)) > 0:
    dataset_path = ff_dir
    print(f"Using FaceForensics++ dataset at {ff_dir}")
elif os.path.exists(combined_dir) and len(os.listdir(combined_dir)) > 0:
    dataset_path = combined_dir
    print(f"Using existing combined dataset at {combined_dir}")
else:
    print("Error: No datasets were processed or found")
    dataset_path = None

print(f"Dataset path for training: {dataset_path}")


Combining 80576 CelebDF images and 63898 FaceForensics++ images...


Combining datasets:   0%|          | 0/144474 [00:00<?, ?img/s]

Combined dataset created with 144474 total images at c:\Users\vinay\Documents\mnist\processed_faces\combined
Combined dataset created with 144474 images
Dataset path for training: c:\Users\vinay\Documents\mnist\processed_faces\combined


In [None]:
###############################################################################
# CELL 7: DATASET CLASS AND CUDA SETUP
###############################################################################
"""
Dataset Class and CUDA Setup
-------------------------
Define the custom dataset class for loading face images
and set up CUDA for GPU acceleration if available.
"""
class FaceDataset(Dataset):
    """Dataset over the image files stored directly inside one folder.

    Each item is ``(image, 0)``: the image converted to RGB (and passed
    through ``transform`` when one is given) and a dummy label of 0.

    Args:
        root_dir: Folder containing the images (not searched recursively).
        transform: Optional callable applied to each PIL image.
    """
    # Accepted file suffixes, compared case-insensitively
    IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg')

    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        # os.listdir returns entries in arbitrary order; sorting keeps the
        # index -> file mapping stable across runs and platforms.
        self.image_files = sorted(
            f for f in os.listdir(root_dir)
            if f.lower().endswith(self.IMAGE_EXTENSIONS)
            and os.path.isfile(os.path.join(root_dir, f))
        )

    def __len__(self):
        """Return the number of image files found in ``root_dir``."""
        return len(self.image_files)

    def __getitem__(self, idx):
        """Load image ``idx`` as RGB, apply the transform, and return ``(image, 0)``."""
        img_path = os.path.join(self.root_dir, self.image_files[idx])
        # The context manager closes the file handle; convert() returns an
        # independent in-memory image, so it stays valid afterwards.
        with Image.open(img_path) as img:
            image = img.convert('RGB')

        if self.transform:
            image = self.transform(image)

        return image, 0  # Return 0 as dummy label

def setup_cuda():
    """Seed the random number generators and select the training device.

    The GPU is used only when the ``CUDA`` config flag is true and a CUDA
    device is available; otherwise ``CUDA`` is set to False and the CPU is
    used, so the flag and ``device`` always agree. Updates the module-level
    ``CUDA`` and ``device``.

    Returns:
        torch.device: The selected device.
    """
    global CUDA, device
    # Set random seed for reproducibility
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)

    cuda_available = torch.cuda.is_available()
    if CUDA and cuda_available:
        device = torch.device("cuda:0")
        torch.cuda.manual_seed_all(seed)
        # benchmark mode lets cuDNN auto-tune its kernels for speed. It does
        # not make GPU runs deterministic; seeding alone is what is done here.
        cudnn.benchmark = True
    else:
        if not cuda_available:
            print("CUDA is not available. Running on CPU.")
        else:
            print("CUDA is disabled in the configuration. Running on CPU.")
        CUDA = False
        device = torch.device("cpu")

    print(f"PyTorch version: {torch.__version__}")
    if CUDA:
        print(f"CUDA version: {torch.version.cuda}")
        print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"Using device: {device}")
    return device

# Setup CUDA
# setup_cuda() seeds the RNGs and also assigns the module-level `device`
# itself; the explicit assignment keeps the data flow visible in this cell.
device = setup_cuda()

PyTorch version: 2.6.0+cu126
CUDA version: 12.6
GPU: NVIDIA GeForce RTX 2070 Super with Max-Q Design
Using device: cuda:0


In [None]:
###############################################################################
# CELL 8: GAN MODEL ARCHITECTURE
###############################################################################
"""
GAN Model Architecture
-------------------
Define the Generator and Discriminator architecture
for the Deep Convolutional GAN (DCGAN).
"""
def weights_init(m):
    """Re-initialize one module in place; meant to be passed to ``Module.apply``.

    Modules whose class name contains "Conv" get weights drawn from
    N(0.0, 0.02). Otherwise, modules whose class name contains "BatchNorm"
    get weights drawn from N(1.0, 0.02) and a zero bias. All other modules
    are left unchanged.
    """
    layer_type = type(m).__name__

    if 'Conv' in layer_type:
        m.weight.data.normal_(mean=0.0, std=0.02)
        return

    if 'BatchNorm' in layer_type:
        m.weight.data.normal_(mean=1.0, std=0.02)
        m.bias.data.zero_()

class Generator(nn.Module):
    """Generator network.

    Turns a latent tensor of shape (N, Z_DIM, 1, 1) into an image tensor of
    shape (N, IMAGE_CHANNEL, 128, 128) through a stack of transposed
    convolutions ending in Tanh.
    """
    def __init__(self):
        super().__init__()
        # Channel widths after each stage, from the 4x4 projection up to 64x64
        widths = [G_HIDDEN * factor for factor in (16, 8, 4, 2, 1)]

        # Input: Z_DIM x 1 x 1 -> 4x4
        layers = [
            nn.ConvTranspose2d(Z_DIM, widths[0], kernel_size=4, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(widths[0]),
            nn.ReLU(True),
        ]

        # Each stage doubles the resolution: 8x8, 16x16, 32x32, 64x64
        for in_channels, out_channels in zip(widths, widths[1:]):
            layers.extend([
                nn.ConvTranspose2d(in_channels, out_channels, kernel_size=4, stride=2, padding=1, bias=False),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(True),
            ])

        # Output stage: 128x128, no normalization, Tanh activation
        layers.extend([
            nn.ConvTranspose2d(widths[-1], IMAGE_CHANNEL, kernel_size=4, stride=2, padding=1, bias=False),
            nn.Tanh(),
        ])

        # Layer order (and therefore the state_dict keys main.0 ... main.16)
        # is the same as when the layers are listed out one by one.
        self.main = nn.Sequential(*layers)

    def forward(self, input):
        """Run the latent batch through the layer stack and return the images."""
        return self.main(input)

class Discriminator(nn.Module):
    """Discriminator network.

    Reduces an image tensor of shape (N, IMAGE_CHANNEL, 128, 128) to one
    sigmoid score per image through a stack of strided convolutions.
    """
    def __init__(self):
        super().__init__()
        # Channel widths after each downsampling stage, from 64x64 down to 4x4
        widths = [D_HIDDEN * factor for factor in (1, 2, 4, 8, 16)]

        # Input: IMAGE_CHANNEL x 128 x 128 -> 64x64 (no batch norm on this stage)
        layers = [
            nn.Conv2d(IMAGE_CHANNEL, widths[0], kernel_size=4, stride=2, padding=1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
        ]

        # Each stage halves the resolution: 32x32, 16x16, 8x8, 4x4
        for in_channels, out_channels in zip(widths, widths[1:]):
            layers.extend([
                nn.Conv2d(in_channels, out_channels, kernel_size=4, stride=2, padding=1, bias=False),
                nn.BatchNorm2d(out_channels),
                nn.LeakyReLU(0.2, inplace=True),
            ])

        # Output stage: 4x4 -> 1x1 score squashed by a sigmoid
        layers.extend([
            nn.Conv2d(widths[-1], 1, kernel_size=4, stride=1, padding=0, bias=False),
            nn.Sigmoid(),
        ])

        # Layer order (and therefore the state_dict keys main.0 ... main.15)
        # is the same as when the layers are listed out one by one.
        self.main = nn.Sequential(*layers)

    def forward(self, input):
        """Score a batch of images; returns a 1-D tensor with one value per image."""
        scores = self.main(input)
        return scores.view(-1, 1).squeeze(1)

# Initialize networks
# Module.apply() calls weights_init on every submodule, so the Conv and
# BatchNorm layers get the custom initial values defined above.
netG = Generator().to(device)
netG.apply(weights_init)
print("Generator initialized")

netD = Discriminator().to(device)
netD.apply(weights_init)
print("Discriminator initialized")

# Display network architecture summaries
# Printing a module shows its layers in registration order.
print("\nGenerator Architecture:")
print(netG)
print("\nDiscriminator Architecture:")
print(netD)


Generator initialized
Discriminator initialized

Generator Architecture:
Generator(
  (main): Sequential(
    (0): ConvTranspose2d(100, 1024, kernel_size=(4, 4), stride=(1, 1), bias=False)
    (1): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): ConvTranspose2d(1024, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (4): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU(inplace=True)
    (6): ConvTranspose2d(512, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (7): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (8): ReLU(inplace=True)
    (9): ConvTranspose2d(256, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (10): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (11): ReLU(inplace=True)
    (12): ConvTranspose2d(128, 64, kernel

In [None]:
###############################################################################
# CELL 9: DATA LOADING AND VISUALIZATION
###############################################################################
"""
Data Loading and Visualization
---------------------------
Load the processed face dataset, apply transformations,
and visualize sample images.
"""
def load_dataset(data_path):
    """Build the face dataset and its DataLoader.

    Images are resized and center-cropped to X_DIM, converted to tensors and
    normalized per channel with mean 0.5 and std 0.5.

    Args:
        data_path: Folder containing the training images.

    Returns:
        tuple: (dataset, dataloader), or (None, None) when the folder holds no
        images or the dataset cannot be created (the error is printed).
    """
    # Image transformations
    transform = transforms.Compose([
        transforms.Resize(X_DIM),
        transforms.CenterCrop(X_DIM),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))  # Normalize for RGB
    ])

    # Worker processes only inherit notebook-defined classes when they are
    # forked. Under "spawn" (the default on Windows) the workers cannot import
    # a Dataset class defined in a notebook cell and the loader fails or
    # hangs, so loading is done in the main process in that case.
    num_workers = 2 if multiprocessing.get_start_method() == "fork" else 0

    # Create dataset and dataloader
    try:
        dataset = FaceDataset(root_dir=data_path, transform=transform)
        if len(dataset) == 0:
            print(f"Error: No images found in {data_path}")
            return None, None

        dataloader = DataLoader(
            dataset,
            batch_size=BATCH_SIZE,
            shuffle=True,
            num_workers=num_workers,
            pin_memory=bool(CUDA)
        )
        print(f"Dataset loaded with {len(dataset)} images")
        print(f"Number of batches: {len(dataloader)}")
        return dataset, dataloader
    except Exception as e:
        print(f"Error loading dataset: {e}")
        return None, None

def visualize_dataset(dataloader, output_path, title="Training Images"):
    """Plot up to 16 images from the first batch and save them as a PNG.

    Args:
        dataloader: Iterable whose batches carry the image tensor at index 0.
        output_path: Folder in which "real_samples.png" is written.
        title: Title drawn above the image grid.

    Returns:
        str: Path of the saved figure, or None if anything failed (the error
        is printed).
    """
    try:
        first_batch = next(iter(dataloader))
        images = first_batch[0]

        plt.figure(figsize=(10,10))
        plt.axis("off")
        plt.title(title)

        # Tile at most 16 images into one grid, then move channels last for imshow
        preview = images[:min(16, len(images))]
        grid = vutils.make_grid(preview, padding=2, normalize=True).cpu()
        plt.imshow(np.transpose(grid, (1,2,0)))

        filepath = os.path.join(output_path, "real_samples.png")
        plt.savefig(filepath)
        plt.show()
        print(f"Saved sample of real images to {filepath}")
        return filepath
    except Exception as e:
        print(f"Error visualizing dataset: {e}")
        return None

# Load dataset if path exists
# Both names are always defined, so later cells can test them against None
# even when loading is skipped or fails.
dataset, dataloader = None, None
if dataset_path and os.path.exists(dataset_path):
    dataset, dataloader = load_dataset(dataset_path)
    if dataset is not None and dataloader is not None:
        # Visualize some samples
        visualize_dataset(dataloader, OUTPUT_PATH, "Sample Training Images")
else:
    print("Dataset path not available. Please check your configurations.")


Dataset loaded with 144473 images
Number of batches: 4515


In [None]:
###############################################################################
# CELL 10: CHECKPOINT FUNCTIONS
###############################################################################
"""
Checkpoint Functions
----------------
Functions for saving and loading model checkpoints
to support resumable training and best model selection.
"""
def save_checkpoint(netG, netD, optimG, optimD, epoch, iteration, losses, img_list, filename=None):
    """Write the training state to disk and refresh "latest_checkpoint.pt".

    Args:
        netG, netD: Generator and discriminator whose state_dicts are stored.
        optimG, optimD: Their optimizers, stored via state_dict as well.
        epoch, iteration: Progress counters stored with the checkpoint.
        losses: Dict with keys 'G' and 'D' holding the loss histories.
        img_list: Sample images collected so far; stored as given.
        filename: Target file. Defaults to
            "checkpoint_e<epoch>_i<iteration>.pt" inside CHECKPOINT_DIR.

    Returns:
        str: The path the checkpoint was written to.
    """
    filename = filename or os.path.join(CHECKPOINT_DIR, f"checkpoint_e{epoch}_i{iteration}.pt")

    progress = {'epoch': epoch, 'iteration': iteration}
    weights = {
        'netG_state_dict': netG.state_dict(),
        'netD_state_dict': netD.state_dict(),
        'optimG_state_dict': optimG.state_dict(),
        'optimD_state_dict': optimD.state_dict(),
    }
    history = {
        'G_losses': losses['G'],
        'D_losses': losses['D'],
        'img_list': img_list,
        'timestamp': datetime.datetime.now().isoformat(),
    }
    checkpoint = {**progress, **weights, **history}

    torch.save(checkpoint, filename)
    print(f"Checkpoint saved: {filename}")

    # The same state is also written under a fixed name used when resuming
    torch.save(checkpoint, os.path.join(CHECKPOINT_DIR, "latest_checkpoint.pt"))

    return filename

def load_checkpoint(netG, netD, optimG, optimD, filename=None):
    """Restore models, optimizers and training history from a checkpoint.

    Args:
        netG: Generator that receives the stored ``netG_state_dict``.
        netD: Discriminator that receives the stored ``netD_state_dict``.
        optimG: Generator optimizer that receives ``optimG_state_dict``.
        optimD: Discriminator optimizer that receives ``optimD_state_dict``.
        filename: Checkpoint path. When falsy, ``latest_checkpoint.pt`` in
            ``CHECKPOINT_DIR`` is used.

    Returns:
        A 5-tuple ``(models, epoch, iteration, losses, img_list)``. ``models``
        is a dict with keys ``'netG'``, ``'netD'``, ``'optimG'``, ``'optimD'``
        and ``losses`` a dict with keys ``'G'`` and ``'D'``. When the file does
        not exist the tuple is ``(None, 0, 0, {'G': [], 'D': []}, [])``.
    """
    # Default to the rolling "latest" checkpoint when no explicit path is given.
    filename = filename or os.path.join(CHECKPOINT_DIR, "latest_checkpoint.pt")

    if not os.path.exists(filename):
        print(f"No checkpoint found at {filename}")
        return None, 0, 0, {'G': [], 'D': []}, []

    print(f"Loading checkpoint: {filename}")
    checkpoint = torch.load(filename, map_location=device)

    # Restore networks first, then optimizers (same order as they were saved).
    restore_plan = (
        (netG, 'netG_state_dict'),
        (netD, 'netD_state_dict'),
        (optimG, 'optimG_state_dict'),
        (optimD, 'optimD_state_dict'),
    )
    for target, state_key in restore_plan:
        target.load_state_dict(checkpoint[state_key])

    # When running on CUDA, move every tensor held in the optimizer state
    # onto the training device.
    if CUDA:
        for optimizer in (optimG, optimD):
            for param_state in optimizer.state.values():
                for name, value in param_state.items():
                    if torch.is_tensor(value):
                        param_state[name] = value.to(device)

    epoch = checkpoint['epoch']
    iteration = checkpoint['iteration']
    loss_history = {'G': checkpoint['G_losses'], 'D': checkpoint['D_losses']}
    saved_images = checkpoint.get('img_list', [])

    print(f"Resuming from epoch {epoch+1}, iteration {iteration}")

    restored = {'netG': netG, 'netD': netD, 'optimG': optimG, 'optimD': optimD}
    return restored, epoch, iteration, loss_history, saved_images


In [None]:
###############################################################################
# CELL 11: TRAINING INITIALIZATION
###############################################################################
"""
Training Initialization
--------------------
Initialize GAN training by setting up optimizers, 
creating fixed noise for visualization,
and loading a checkpoint if resuming training.
"""
# Setup loss function and optimizers
criterion = nn.BCELoss()
# Fixed noise for visualization: the same 16 latent vectors are reused for
# every progress snapshot so the images are comparable over time.
fixed_noise = torch.randn(16, Z_DIM, 1, 1, device=device)

# Initialize optimizers
optimizerD = optim.Adam(netD.parameters(), lr=lr, betas=(0.5, 0.999))
optimizerG = optim.Adam(netG.parameters(), lr=lr, betas=(0.5, 0.999))

# Setup TensorBoard
# The tensorboard import in CELL 1 is optional; if it failed, SummaryWriter is
# undefined, so switch logging off instead of crashing with a NameError here.
if ENABLE_TENSORBOARD and not TENSORBOARD_AVAILABLE:
    print("TensorBoard logging was requested but tensorboard is not installed; disabling it.")
    ENABLE_TENSORBOARD = False

if ENABLE_TENSORBOARD:
    tb_writer = SummaryWriter(log_dir=os.path.join(LOG_DIR, EXPERIMENT_NAME))
    # Add model graph to TensorBoard (best effort: tracing can fail for some models)
    try:
        sample_input = torch.randn(1, Z_DIM, 1, 1, device=device)
        tb_writer.add_graph(netG, sample_input)
    except Exception as e:
        print(f"Could not add model graph to TensorBoard: {e}")
else:
    tb_writer = None

# Training state defaults; overwritten below when a checkpoint is restored
start_epoch = 0
start_iter = 0
G_losses = []
D_losses = []
img_list = []

# Try to load checkpoint if resuming training
if RESUME_TRAINING:
    models, start_epoch, start_iter, losses, prev_img_list = load_checkpoint(netG, netD, optimizerG, optimizerD)
    if models:  # If checkpoint was loaded successfully
        netG, netD = models['netG'], models['netD']
        optimizerG, optimizerD = models['optimG'], models['optimD']
        G_losses = losses['G']
        D_losses = losses['D']
        img_list = prev_img_list
        print(f"Resuming from epoch {start_epoch+1}, iteration {start_iter}")

print(f"\n{'='*50}\nREADY FOR TRAINING\n{'='*50}")


In [None]:
###############################################################################
# CELL 12: TRAINING LOOP
###############################################################################
"""
GAN Training Loop
--------------
Train the GAN model with comprehensive tracking, visualization,
early stopping, and checkpoint saving. This cell runs the main 
training process through all epochs.
"""
def train_epoch(epoch, iters):
    """Train the GAN for one epoch.

    For every batch of the global ``dataloader`` this performs one
    discriminator update (real batch + generated batch) followed by one
    generator update. It relies on notebook globals defined in earlier cells
    (``netG``, ``netD``, ``criterion``, ``optimizerG``, ``optimizerD``,
    ``fixed_noise``, ``tb_writer``) and appends to the global ``G_losses``,
    ``D_losses`` and ``img_list`` lists in place.

    Args:
        epoch: Zero-based index of the epoch being trained.
        iters: Global iteration counter at the start of this epoch.

    Returns:
        Tuple ``(epoch_g_loss, epoch_d_loss, iters)``: the mean generator and
        discriminator loss over the batches of this epoch (0.0 when there were
        no batches) and the updated global iteration counter.
    """
    # Running sums for the per-epoch averages
    epoch_g_loss = 0.0
    epoch_d_loss = 0.0
    batch_count = 0

    # tb_writer is None when TensorBoard is disabled; never call into it then.
    log_to_tb = bool(ENABLE_TENSORBOARD) and tb_writer is not None
    last_batch_idx = len(dataloader) - 1

    with tqdm(dataloader, unit="batch", desc=f"Epoch {epoch+1}/{EPOCH_NUM}", dynamic_ncols=True) as tepoch:
        for i, data in enumerate(tepoch):
            real_images = data[0].to(device, non_blocking=True)
            b_size = real_images.size(0)

            # ---- Discriminator: real batch ----
            netD.zero_grad()
            label = torch.full((b_size,), REAL_LABEL, dtype=torch.float, device=device)
            output = netD(real_images)
            errD_real = criterion(output, label)
            errD_real.backward()
            D_x = output.mean().item()

            # ---- Discriminator: generated batch ----
            noise = torch.randn(b_size, Z_DIM, 1, 1, device=device)
            fake = netG(noise)
            label.fill_(FAKE_LABEL)
            # detach() keeps this backward pass from reaching the generator
            output = netD(fake.detach())
            errD_fake = criterion(output, label)
            errD_fake.backward()
            D_G_z1 = output.mean().item()

            errD = errD_real + errD_fake
            optimizerD.step()

            # ---- Generator ----
            netG.zero_grad()
            label.fill_(REAL_LABEL)  # Fake labels are real for generator cost
            output = netD(fake)
            errG = criterion(output, label)
            errG.backward()
            D_G_z2 = output.mean().item()
            optimizerG.step()

            # Read each loss off the device once and reuse the Python floats
            g_loss = errG.item()
            d_loss = errD.item()

            # Save losses for plotting
            G_losses.append(g_loss)
            D_losses.append(d_loss)

            # Accumulate losses for this epoch
            epoch_g_loss += g_loss
            epoch_d_loss += d_loss
            batch_count += 1

            # Add losses to TensorBoard
            if log_to_tb:
                tb_writer.add_scalar('Loss/Generator', g_loss, iters)
                tb_writer.add_scalar('Loss/Discriminator', d_loss, iters)
                tb_writer.add_scalar('D_x', D_x, iters)
                tb_writer.add_scalar('D_G_z1', D_G_z1, iters)
                tb_writer.add_scalar('D_G_z2', D_G_z2, iters)

            # Resource monitoring
            if RESOURCE_MONITORING and iters % RESOURCE_CHECK_FREQ == 0:
                resources = monitor_resources()
                print(f"\nResource usage: CPU {resources['cpu_percent']}%, Memory {resources['memory_percent']}%")
                print(f"GPU: {resources['gpu_info']}")

                if log_to_tb:
                    tb_writer.add_scalar('Resources/CPU', resources['cpu_percent'], iters)
                    tb_writer.add_scalar('Resources/Memory', resources['memory_percent'], iters)

            # Save generated images periodically, and always on the very last batch
            is_final_batch = (epoch == EPOCH_NUM - 1) and (i == last_batch_idx)
            if (iters % CHECKPOINT_SAMPLES == 0) or is_final_batch:
                with torch.no_grad():
                    samples = netG(fixed_noise).detach().cpu()
                grid = vutils.make_grid(samples, padding=2, normalize=True)
                img_list.append(grid)

                # Save current generator output (matplotlib wants HWC)
                fig = plt.figure(figsize=(8,8))
                plt.axis("off")
                plt.title(f"Generated Images (Epoch {epoch+1}, Iter {iters})")
                plt.imshow(np.transpose(grid, (1,2,0)))
                plt.savefig(os.path.join(OUTPUT_PATH, f"generated_e{epoch+1}_i{iters}.png"))
                plt.close(fig)

                # Add images to TensorBoard. make_grid already returns a CHW
                # tensor, which is add_image's default layout; permuting the
                # axes here would hand TensorBoard a (W, C, H) array.
                if log_to_tb:
                    tb_writer.add_image('Generated Images', grid, iters)

            iters += 1

    # Calculate average losses for this epoch
    if batch_count > 0:
        epoch_g_loss /= batch_count
        epoch_d_loss /= batch_count

    return epoch_g_loss, epoch_d_loss, iters

# Variables for early stopping
best_loss = float('inf')
patience_counter = 0
best_model_path = None

# Start training from the last epoch if resuming
iters = start_iter
# Bind `epoch` before the loop: if range(start_epoch, EPOCH_NUM) is empty the
# loop never assigns it, and the final checkpoint below would hit a NameError.
# NOTE(review): checkpoints store the index of the epoch that just finished and
# this loop starts at that same index, so a resumed run trains that epoch
# again - confirm whether that is intended.
epoch = start_epoch
for epoch in range(start_epoch, EPOCH_NUM):
    # Train for one epoch
    epoch_g_loss, epoch_d_loss, iters = train_epoch(epoch, iters)

    print(f"Epoch {epoch+1} average losses - Generator: {epoch_g_loss:.4f}, Discriminator: {epoch_d_loss:.4f}")

    # Save a checkpoint every CHECKPOINT_FREQ epochs
    if (epoch + 1) % CHECKPOINT_FREQ == 0:
        checkpoint_path = save_checkpoint(
            netG, netD, optimizerG, optimizerD, epoch,
            iters, {'G': G_losses, 'D': D_losses}, img_list
        )

    # Early stopping check based on the mean generator loss of this epoch:
    # it only counts as an improvement when it beats the best value so far by
    # more than EARLY_STOPPING_THRESHOLD.
    current_loss = epoch_g_loss
    if current_loss < best_loss - EARLY_STOPPING_THRESHOLD:
        best_loss = current_loss
        patience_counter = 0
        # Save the best model
        best_model_path = save_checkpoint(
            netG, netD, optimizerG, optimizerD, epoch,
            iters, {'G': G_losses, 'D': D_losses}, img_list,
            os.path.join(CHECKPOINT_DIR, "best_model.pt")
        )
        print(f"New best model saved with G loss: {best_loss:.4f}")
    else:
        patience_counter += 1
        print(f"Early stopping patience: {patience_counter}/{EARLY_STOPPING_PATIENCE}")

        if patience_counter >= EARLY_STOPPING_PATIENCE:
            print(f"Early stopping triggered! No improvement for {EARLY_STOPPING_PATIENCE} epochs.")
            break

# Final checkpoint (epoch index clamped to the last valid epoch)
save_checkpoint(
    netG, netD, optimizerG, optimizerD,
    min(epoch, EPOCH_NUM-1), iters,
    {'G': G_losses, 'D': D_losses}, img_list,
    os.path.join(CHECKPOINT_DIR, "final_model.pt")
)

# Close TensorBoard writer
if ENABLE_TENSORBOARD and tb_writer:
    tb_writer.close()


In [None]:
###############################################################################
# CELL 13: VISUALIZATIONS
###############################################################################
"""
Result Visualizations
------------------
Generate visualizations of training progress, loss curves,
and sample images from the trained generator.
"""
# Loss curves: one point per training iteration for each network
loss_fig, loss_ax = plt.subplots(figsize=(10,5))
loss_ax.set_title("Generator and Discriminator Loss")
loss_ax.plot(G_losses, label="Generator")
loss_ax.plot(D_losses, label="Discriminator")
loss_ax.set_xlabel("Iterations")
loss_ax.set_ylabel("Loss")
loss_ax.legend()
# Write the figure to disk before showing it
loss_fig.savefig(os.path.join(OUTPUT_PATH, "loss_curves.png"))
plt.show()

# Visualize training progress images
def visualize_training_progress(img_list, output_path, title="Training Progress"):
    """Create a grid of generated images showing progression.

    Up to nine snapshots are picked at evenly spaced positions between the
    first and the last entry of ``img_list`` (both included), drawn in a
    near-square grid and saved as ``training_progress.png``.

    Args:
        img_list: Sequence of image grids; each entry is transposed with
            ``(1, 2, 0)`` before display, i.e. it is treated as channel-first.
        output_path: Directory the PNG is written to.
        title: Heading drawn above the grid.

    Returns:
        Path of the saved figure, or ``None`` when ``img_list`` is empty.
    """
    if not img_list:
        print("No images available for progress visualization")
        return None

    # Evenly spaced snapshot indices from 0 to len-1. Integer arithmetic keeps
    # them distinct and guarantees the final snapshot is part of the selection
    # (a plain `range(0, len, step)` only covers the leading images whenever
    # len // 9 == 1).
    total = len(img_list)
    num_samples = min(9, total)
    if num_samples == 1:
        indices = [0]
    else:
        indices = [(k * (total - 1)) // (num_samples - 1) for k in range(num_samples)]

    # Near-square grid
    rows = int(np.sqrt(len(indices)))
    cols = int(np.ceil(len(indices) / rows))

    # squeeze=False always yields a 2-D array of axes, so flatten() also works
    # for a 1x1 grid.
    fig, axes = plt.subplots(rows, cols, figsize=(15, 15), squeeze=False)
    axes = axes.flatten()

    for ax, idx in zip(axes, indices):
        ax.imshow(np.transpose(img_list[idx], (1, 2, 0)))
        ax.set_title(f"Sample {idx}")
        ax.axis('off')

    # Hide unused subplots
    for ax in axes[len(indices):]:
        ax.axis('off')

    fig.suptitle(title, fontsize=16)
    plt.tight_layout()
    filepath = os.path.join(output_path, "training_progress.png")
    plt.savefig(filepath)
    plt.show()
    print(f"Saved training progress visualization to {filepath}")
    return filepath

# A progression grid only makes sense with at least two snapshots
if len(img_list) > 1:
    visualize_training_progress(img_list, OUTPUT_PATH)

# Final visualization: Real vs Generated
# Take the first 16 images of one batch from the training dataloader
real_batch = next(iter(dataloader))[0][:16].to(device)

# Generate a batch of fake images from the fixed visualization noise
with torch.no_grad():
    fake_batch = netG(fixed_noise).detach().cpu()

# Build both image grids up front, then draw them side by side
real_grid = vutils.make_grid(real_batch.cpu(), padding=2, normalize=True)
fake_grid = vutils.make_grid(fake_batch, padding=2, normalize=True)

comparison_fig, (real_ax, fake_ax) = plt.subplots(1, 2, figsize=(15,7))

real_ax.axis("off")
real_ax.set_title("Real Images")
real_ax.imshow(np.transpose(real_grid, (1,2,0)))

fake_ax.axis("off")
fake_ax.set_title("Generated Images")
fake_ax.imshow(np.transpose(fake_grid, (1,2,0)))

comparison_fig.savefig(os.path.join(OUTPUT_PATH, "final_comparison.png"))
plt.show()


In [None]:

###############################################################################
# CELL 14: GENERATE SAMPLES WITH TRAINED MODEL
###############################################################################
"""
Generate Samples with Trained Model
---------------------------------
Generate and visualize multiple samples using the trained generator.
This demonstrates how to use the model for inference.
"""
def generate_samples(netG, n_samples=16, grid_rows=4):
    """Generate multiple samples from the trained generator.

    Each sample is produced from its own random latent vector of shape
    ``(1, Z_DIM, 1, 1)`` and drawn into one cell of a ``grid_rows``-row grid,
    which is saved as ``generated_samples.png`` in ``OUTPUT_PATH``.

    Args:
        netG: Generator network called with the latent vector.
        n_samples: Number of images to generate (must be >= 1).
        grid_rows: Number of rows of the image grid (must be >= 1).

    Raises:
        ValueError: If ``n_samples`` or ``grid_rows`` is smaller than 1.
    """
    if n_samples < 1 or grid_rows < 1:
        raise ValueError("n_samples and grid_rows must both be at least 1")

    # Ceiling division: every sample needs a cell even when n_samples is not a
    # multiple of grid_rows (floor division would run out of axes).
    grid_cols = -(-n_samples // grid_rows)
    # squeeze=False keeps `axes` a 2-D array so flatten() also works for 1x1.
    fig, axes = plt.subplots(grid_rows, grid_cols, figsize=(grid_cols*2, grid_rows*2), squeeze=False)
    axes = axes.flatten()

    with torch.no_grad():
        for ax in axes[:n_samples]:
            # Generate a random latent vector
            z = torch.randn(1, Z_DIM, 1, 1, device=device)

            # Generate a fake image
            fake = netG(z).detach().cpu()

            # make_grid is used here only for its normalization; the (1, 2, 0)
            # transpose converts channel-first to the HWC layout imshow expects.
            img = np.transpose(vutils.make_grid(fake, padding=0, normalize=True), (1, 2, 0))
            ax.imshow(img)
            ax.axis('off')

    # Blank out any leftover cells of the grid
    for ax in axes[n_samples:]:
        ax.axis('off')

    plt.tight_layout()
    plt.savefig(os.path.join(OUTPUT_PATH, "generated_samples.png"))
    plt.show()
    print(f"Generated {n_samples} samples using the trained model")

# Generate samples with the trained model: 16 images arranged in 4 rows.
# generate_samples saves the figure as generated_samples.png in OUTPUT_PATH.
generate_samples(netG, n_samples=16, grid_rows=4)


In [None]:
###############################################################################
# CELL 15: CREATE LATENT SPACE ANIMATION
###############################################################################
"""
Latent Space Animation
-------------------
Create an animation by interpolating between points in the latent space.
This helps visualize how the latent space is structured.
"""
def latent_space_interpolation(netG, n_steps=30):
    """Create an animation by interpolating between two latent vectors.

    Two random latent vectors are drawn, ``n_steps`` vectors are placed on the
    straight line between them (end points included), and the generator
    output for each becomes one animation frame. The animation is saved as
    ``latent_space_animation.gif`` and the two end-point images as
    ``latent_space_endpoints.png``, both in ``OUTPUT_PATH``.

    Args:
        netG: Generator network called with latent vectors of shape
            ``(1, Z_DIM, 1, 1)``.
        n_steps: Number of frames (must be >= 1). With a single step only the
            start vector is rendered.

    Returns:
        The ``matplotlib.animation.ArtistAnimation`` that was saved.

    Raises:
        ValueError: If ``n_steps`` is smaller than 1.
    """
    if n_steps < 1:
        raise ValueError("n_steps must be at least 1")

    # Generate two random latent vectors
    z_start = torch.randn(1, Z_DIM, 1, 1, device=device)
    z_end = torch.randn(1, Z_DIM, 1, 1, device=device)

    # Linear interpolation. The divisor is clamped to 1 so that n_steps == 1
    # yields the single frame z_start instead of dividing by zero.
    denom = max(n_steps - 1, 1)
    z_vectors = [z_start + (z_end - z_start) * (step / denom) for step in range(n_steps)]

    gif_path = os.path.join(OUTPUT_PATH, "latent_space_animation.gif")

    # Generate images for each interpolated vector; every frame is an imshow
    # artist drawn on `fig`, which is the current figure at this point.
    fig = plt.figure(figsize=(8, 8))
    plt.axis("off")
    images = []

    with torch.no_grad():
        for z in tqdm(z_vectors, desc="Generating interpolated images"):
            fake = netG(z).detach().cpu()
            img = vutils.make_grid(fake, padding=2, normalize=True)
            images.append([plt.imshow(np.transpose(img, (1, 2, 0)), animated=True)])

    # Create animation
    ani = animation.ArtistAnimation(fig, images, interval=200, blit=True)
    ani.save(gif_path, writer='pillow', fps=10)

    # Display the first and last images side by side
    plt.figure(figsize=(12, 6))
    plt.subplot(1, 2, 1)
    plt.axis("off")
    plt.title("Start Point")
    with torch.no_grad():
        fake = netG(z_start).detach().cpu()
    plt.imshow(np.transpose(vutils.make_grid(fake, padding=2, normalize=True), (1, 2, 0)))

    plt.subplot(1, 2, 2)
    plt.axis("off")
    plt.title("End Point")
    with torch.no_grad():
        fake = netG(z_end).detach().cpu()
    plt.imshow(np.transpose(vutils.make_grid(fake, padding=2, normalize=True), (1, 2, 0)))

    plt.tight_layout()
    plt.savefig(os.path.join(OUTPUT_PATH, "latent_space_endpoints.png"))

    plt.show()

    print(f"Latent space animation saved to {gif_path}")
    return ani

# Create a latent space interpolation animation (fewer steps for quicker execution):
# n_steps=10 instead of the function default of 30.
latent_space_interpolation(netG, n_steps=10)


In [None]:
###############################################################################
# CELL 16: SAVE FINAL MODELS
###############################################################################
"""
Save Final Models
--------------
Save the trained generator and discriminator models for later use.
This allows you to load and use the models without retraining.
"""
# Persist only the learned weights (state_dict) of both networks
generator_weights_path = os.path.join(OUTPUT_PATH, "generator_final.pth")
discriminator_weights_path = os.path.join(OUTPUT_PATH, "discriminator_final.pth")

torch.save(netG.state_dict(), generator_weights_path)
torch.save(netD.state_dict(), discriminator_weights_path)

# Report what was written
print(f"Final models saved to {OUTPUT_PATH}")
print("- generator_final.pth")
print("- discriminator_final.pth")

# Function to load the trained generator for inference
def load_generator(model_path):
    """Build a Generator, load weights from ``model_path`` and return it in eval mode.

    Args:
        model_path: Path to a saved generator ``state_dict``.

    Returns:
        The ``Generator`` instance on ``device``, switched to evaluation mode.
    """
    generator = Generator().to(device)
    saved_weights = torch.load(model_path, map_location=device)
    generator.load_state_dict(saved_weights)
    # eval() returns the module itself, so it can be returned directly
    return generator.eval()

###############################################################################
# Example of how to load and use the saved generator.
# The snippet is only printed, not executed. (A stray, never-closed triple
# quote used to sit here and made the whole cell fail with a SyntaxError.)
print("\nExample of loading the saved generator:")
print("loaded_generator = load_generator(os.path.join(OUTPUT_PATH, 'generator_final.pth'))")
print("# Generate images with loaded model:")
print("with torch.no_grad():")
print("    z = torch.randn(1, Z_DIM, 1, 1, device=device)")
print("    fake_img = loaded_generator(z)")


In [None]:
###############################################################################
# CELL 17: TRAINING SUMMARY AND METRICS
###############################################################################
"""
Training Summary and Metrics
-------------------------
Display a summary of the training process including timing information,
final loss values, and resource usage statistics.
"""
# Banner for the summary section
print(f"\n{'='*50}")
print("TRAINING SUMMARY")
print(f"{'='*50}")

# Calculate total images processed
# NOTE(review): this assumes all EPOCH_NUM epochs ran in this session; the
# training loop can break early (early stopping), in which case this is an
# overestimate - confirm how the value is used.
total_images = len(dataset) * EPOCH_NUM if dataset else 0

# Get final losses (average of last 100 iterations); NaN when no losses were recorded
final_g_loss = np.mean(G_losses[-100:]) if G_losses else float('nan')
print(f"{'='*50}")
