In [1]:
###############################################################################
# CELL 1: IMPORTS AND LIBRARIES
###############################################################################
"""
Face GAN - Deep Convolutional GAN for Face Generation
----------------------------------------------------
Import required libraries and modules for dataset processing,
neural network creation, training, and visualization.
"""
import os
import re
import sys
import time
import json
import random
import datetime
import numpy as np
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torch.utils.data
import torchvision.transforms as transforms
import torchvision.utils as vutils
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from pathlib import Path
from tqdm.notebook import tqdm  # Use tqdm.notebook for Jupyter progress bars

# Try importing optional libraries
try:
    from torch.utils.tensorboard import SummaryWriter
    TENSORBOARD_AVAILABLE = True
except ImportError:
    TENSORBOARD_AVAILABLE = False
    print("TensorBoard not available. Install with: pip install tensorboard")

# For resource monitoring
try:
    import psutil
    import GPUtil
    RESOURCE_MONITORING = True
except ImportError:
    RESOURCE_MONITORING = False
    print("Resource monitoring unavailable. Install with: pip install psutil gputil")

# Set display settings for the notebook
%matplotlib inline
plt.rcParams['figure.figsize'] = (10.0, 8.0)


In [2]:
###############################################################################
# CELL 2: CONFIGURATION SETTINGS
###############################################################################
"""
Configuration Settings
---------------------
Define paths, hyperparameters, and training settings.
Modify these values to adapt the model to your specific requirements.
"""
# Paths
BASE_DIR = os.path.dirname(os.path.abspath("__file__"))
CELEBDF_PATH = os.path.join(BASE_DIR, "faces\\Real\\Celeb_V2\\Train\\real")
FF_PATH = os.path.join(BASE_DIR, "faces\\Real\\FaceForensics++\\original_sequences\\youtube\\c23\\frames")
PROCESSED_PATH = os.path.join(BASE_DIR, "processed_faces")
OUTPUT_PATH = os.path.join(BASE_DIR, "output")
CHECKPOINT_DIR = os.path.join(OUTPUT_PATH, "checkpoints")
LOG_DIR = os.path.join(OUTPUT_PATH, "logs")

# Dataset processing settings
PROCESS_DATASETS = False  # Set to False to skip dataset processing if already processed
CELEBDF_MAX_PER_FACE = None  # Max images per face for CelebDF
FF_MAX_PER_FACE = None  # Max images per face for FaceForensics++
FF_MAX_FACES = None  # Maximum number of different faces from FaceForensics++
TARGET_SIZE = 128  # Size to resize images to

# Model hyperparameters
CUDA = True  # Use CUDA (will be auto-detected later)
BATCH_SIZE = 64
IMAGE_CHANNEL = 3  # RGB images
Z_DIM = 100  # Latent vector dimension
G_HIDDEN = 64  # Double from 64 to 128
D_HIDDEN = 64  # Double from 64 to 128
X_DIM = 128  # Target image size
EPOCH_NUM = 10  # Number of training epochs
REAL_LABEL = 1
FAKE_LABEL = 0
lr = 2e-4  # Learning rate
seed = 1  # Random seed for reproducibility

# Training control settings
CHECKPOINT_FREQ = 5  # Save checkpoints every N epochs
CHECKPOINT_SAMPLES = 1000  # Generate image samples every N iterations
RESUME_TRAINING = True  # Try to resume from checkpoint if available
EARLY_STOPPING_PATIENCE = 5  # Early stopping after N epochs without improvement
EARLY_STOPPING_THRESHOLD = 0.01  # Minimum improvement to reset patience counter
RESOURCE_CHECK_FREQ = 50  # Check system resources every N batches
MAX_TRAINING_TIME = None  # Max training time in hours (None for no limit)
ENABLE_TENSORBOARD = TENSORBOARD_AVAILABLE  # Enable TensorBoard logging
EXPERIMENT_NAME = f"face_gan_{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}"  # Unique name for this run

# Print some key configuration values
print(f"Dataset paths:\n- CelebDF: {CELEBDF_PATH}\n- FaceForensics++: {FF_PATH}")
print(f"Processed data will be saved to: {PROCESSED_PATH}")
print(f"Output will be saved to: {OUTPUT_PATH}")
print(f"Target image size: {TARGET_SIZE}x{TARGET_SIZE}")
print(f"Batch size: {BATCH_SIZE}")
print(f"Training epochs: {EPOCH_NUM}")


Dataset paths:
- CelebDF: c:\Users\vinay\Documents\mnist\faces\Real\Celeb_V2\Train\real
- FaceForensics++: c:\Users\vinay\Documents\mnist\faces\Real\FaceForensics++\original_sequences\youtube\c23\frames
Processed data will be saved to: c:\Users\vinay\Documents\mnist\processed_faces
Output will be saved to: c:\Users\vinay\Documents\mnist\output
Target image size: 128x128
Batch size: 64
Training epochs: 10


In [3]:
###############################################################################
# CELL 3: HELPER FUNCTIONS
###############################################################################
"""
Helper Functions
--------------
Create output directories, verify dataset paths,
save configuration, and monitor system resources.
"""
def create_output_directories():
    """Create necessary output directories"""
    os.makedirs(PROCESSED_PATH, exist_ok=True)
    os.makedirs(OUTPUT_PATH, exist_ok=True)
    os.makedirs(CHECKPOINT_DIR, exist_ok=True)
    os.makedirs(LOG_DIR, exist_ok=True)
    
    # Create separate directories for each dataset
    celebdf_dir = os.path.join(PROCESSED_PATH, "celebdf")
    ff_dir = os.path.join(PROCESSED_PATH, "faceforensics")
    combined_dir = os.path.join(PROCESSED_PATH, "combined")
    
    os.makedirs(celebdf_dir, exist_ok=True)
    os.makedirs(ff_dir, exist_ok=True)
    os.makedirs(combined_dir, exist_ok=True)
    
    print(f"Created output directories:")
    print(f" - CelebDF: {celebdf_dir}")
    print(f" - FaceForensics++: {ff_dir}")
    print(f" - Combined: {combined_dir}")
    print(f" - Results: {OUTPUT_PATH}")
    print(f" - Checkpoints: {CHECKPOINT_DIR}")
    print(f" - Logs: {LOG_DIR}")
    
    return celebdf_dir, ff_dir, combined_dir

def verify_dataset_paths():
    """Verify that dataset paths exist"""
    celebdf_exists = os.path.exists(CELEBDF_PATH)
    ff_exists = os.path.exists(FF_PATH)
    
    print(f"CelebDF dataset: {'Found' if celebdf_exists else 'Not found'} at {CELEBDF_PATH}")
    print(f"FaceForensics++ dataset: {'Found' if ff_exists else 'Not found'} at {FF_PATH}")
    
    return celebdf_exists, ff_exists

def save_config():
    """Save current configuration as a JSON file"""
    config = {
        'CELEBDF_PATH': CELEBDF_PATH,
        'FF_PATH': FF_PATH,
        'PROCESSED_PATH': PROCESSED_PATH,
        'OUTPUT_PATH': OUTPUT_PATH,
        'CHECKPOINT_DIR': CHECKPOINT_DIR,
        'PROCESS_DATASETS': PROCESS_DATASETS,
        'CELEBDF_MAX_PER_FACE': CELEBDF_MAX_PER_FACE,
        'FF_MAX_PER_FACE': FF_MAX_PER_FACE,
        'FF_MAX_FACES': FF_MAX_FACES,
        'TARGET_SIZE': TARGET_SIZE,
        'BATCH_SIZE': BATCH_SIZE,
        'IMAGE_CHANNEL': IMAGE_CHANNEL,
        'Z_DIM': Z_DIM,
        'G_HIDDEN': G_HIDDEN,
        'D_HIDDEN': D_HIDDEN,
        'X_DIM': X_DIM,
        'EPOCH_NUM': EPOCH_NUM,
        'lr': lr,
        'seed': seed,
        'EXPERIMENT_NAME': EXPERIMENT_NAME,
        'timestamp': datetime.datetime.now().isoformat()
    }
    
    config_path = os.path.join(OUTPUT_PATH, f"{EXPERIMENT_NAME}_config.json")
    with open(config_path, 'w') as f:
        json.dump(config, f, indent=4)
    
    print(f"Configuration saved to {config_path}")
    return config_path

def monitor_resources():
    """Monitor system resources and return a report"""
    if not RESOURCE_MONITORING:
        return "Resource monitoring unavailable. Install psutil and gputil."
    
    # CPU info
    cpu_percent = psutil.cpu_percent(interval=1)
    memory = psutil.virtual_memory()
    memory_percent = memory.percent
    
    # GPU info
    gpu_info = "No GPU available"
    if torch.cuda.is_available():
        try:
            gpus = GPUtil.getGPUs()
            if gpus:
                gpu = gpus[0]  # Get the first GPU
                gpu_name = gpu.name
                gpu_load = f"{gpu.load * 100:.1f}%"
                gpu_mem_used = f"{gpu.memoryUsed:.0f}MB"
                gpu_mem_total = f"{gpu.memoryTotal:.0f}MB"
                gpu_mem_percent = f"{(gpu.memoryUsed / gpu.memoryTotal) * 100:.1f}%"
                gpu_temp = f"{gpu.temperature}°C"
                gpu_info = f"{gpu_name}: {gpu_load} load, {gpu_mem_used}/{gpu_mem_total} ({gpu_mem_percent}), {gpu_temp}"
        except Exception as e:
            gpu_info = f"Error getting GPU info: {e}"
    
    return {
        "cpu_percent": cpu_percent,
        "memory_percent": memory_percent,
        "gpu_info": gpu_info,
        "timestamp": datetime.datetime.now().isoformat()
    }

# Create directories and verify paths
celebdf_dir, ff_dir, combined_dir = create_output_directories()
celebdf_exists, ff_exists = verify_dataset_paths()
save_config()


Created output directories:
 - CelebDF: c:\Users\vinay\Documents\mnist\processed_faces\celebdf
 - FaceForensics++: c:\Users\vinay\Documents\mnist\processed_faces\faceforensics
 - Combined: c:\Users\vinay\Documents\mnist\processed_faces\combined
 - Results: c:\Users\vinay\Documents\mnist\output
 - Checkpoints: c:\Users\vinay\Documents\mnist\output\checkpoints
 - Logs: c:\Users\vinay\Documents\mnist\output\logs
CelebDF dataset: Found at c:\Users\vinay\Documents\mnist\faces\Real\Celeb_V2\Train\real
FaceForensics++ dataset: Found at c:\Users\vinay\Documents\mnist\faces\Real\FaceForensics++\original_sequences\youtube\c23\frames
Configuration saved to c:\Users\vinay\Documents\mnist\output\face_gan_20250306_113635_config.json


'c:\\Users\\vinay\\Documents\\mnist\\output\\face_gan_20250306_113635_config.json'

In [4]:

###############################################################################
# CELL 4: CELEBDF DATASET PROCESSING
###############################################################################
"""
CelebDF Dataset Processing
------------------------
Process the CelebDF dataset by extracting faces, resizing them,
and saving them to a standardized format.
"""
def process_celebdf_dataset(source_dir, target_dir, target_size=(128, 128), max_images=None):
    """Process CelebDF dataset where all images are in a single folder"""
    source_path = Path(source_dir)
    target_path = Path(target_dir)
    
    if not source_path.exists():
        print(f"Source directory {source_path} does not exist.")
        return 0
    
    # Get all image files
    image_extensions = ['.jpg', '.jpeg', '.png']
    image_files = []
    for ext in image_extensions:
        image_files.extend(source_path.glob(f"*{ext}"))
        image_files.extend(source_path.glob(f"*{ext.upper()}"))
    
    # If maximum image limit is set, randomly sample
    if max_images and len(image_files) > max_images:
        random.seed(1)  # For reproducible results
        image_files = random.sample(image_files, max_images)
    
    print(f"Found {len(image_files)} total images in {source_path}")
    
    # Process all images directly with progress bar
    count = 0
    with tqdm(total=len(image_files), desc="Processing CelebDF images", unit="img") as pbar:
        for img_path in image_files:
            try:
                # Open and resize image
                img = Image.open(img_path)
                img = img.resize(target_size, Image.LANCZOS)
                
                # Save to target directory
                target_file = target_path / f"celebdf_{count:06d}{img_path.suffix}"
                img.save(target_file)
                count += 1
                pbar.update(1)
                    
            except Exception as e:
                print(f"\nError processing {img_path.name}: {e}")
    
    print(f"Successfully processed {count} images from CelebDF to {target_path}")
    return count

# Only run if PROCESS_DATASETS is True and CelebDF dataset exists
if PROCESS_DATASETS and celebdf_exists:
    count = process_celebdf_dataset(
        CELEBDF_PATH, 
        celebdf_dir, 
        target_size=(TARGET_SIZE, TARGET_SIZE),
        max_images=None  # Set to None to process all images, or a number to limit
    )
    print(f"Processed {count} CelebDF images")
else:
    print("Skipping CelebDF processing. Set PROCESS_DATASETS=True to process.")


Skipping CelebDF processing. Set PROCESS_DATASETS=True to process.


In [5]:
###############################################################################
# CELL 5: FACEFORENSICS++ DATASET PROCESSING
###############################################################################
"""
FaceForensics++ Dataset Processing
-------------------------------
Process the FaceForensics++ dataset by extracting face folders,
resizing images, and saving them to a standardized format.
"""
def process_faceforensics_dataset(source_dir, target_dir, target_size=(128, 128), max_folders=None, max_images_per_folder=None):
    """Process FaceForensics++ dataset where images are organized in numbered folders"""
    source_path = Path(source_dir)
    target_path = Path(target_dir)
    
    if not source_path.exists():
        print(f"Source directory {source_path} does not exist.")
        return 0
    
    # Get all folders in the source directory
    face_folders = [f for f in source_path.iterdir() if f.is_dir()]
    
    # Limit number of folders if specified
    if max_folders and len(face_folders) > max_folders:
        face_folders = random.sample(face_folders, max_folders)
    
    print(f"Found {len(face_folders)} face folders in {source_path}")
    
    # Count total images to process for progress bar
    total_images = 0
    folder_image_counts = []
    for folder in face_folders:
        image_files = []
        for ext in ['.jpg', '.jpeg', '.png']:
            image_files.extend(folder.glob(f"*{ext}"))
            image_files.extend(folder.glob(f"*{ext.upper()}"))
        
        count = len(image_files)
        if max_images_per_folder and count > max_images_per_folder:
            count = max_images_per_folder
        
        folder_image_counts.append(count)
        total_images += count
    
    # Process images in each folder with progress bar
    count = 0
    with tqdm(total=total_images, desc="Processing FaceForensics++ images", unit="img") as pbar:
        for folder_idx, folder in enumerate(face_folders):
            # Get image files in this folder
            image_files = []
            for ext in ['.jpg', '.jpeg', '.png']:
                image_files.extend(folder.glob(f"*{ext}"))
                image_files.extend(folder.glob(f"*{ext.upper()}"))
            
            # Limit number of images per folder if specified
            if max_images_per_folder and len(image_files) > max_images_per_folder:
                image_files = random.sample(image_files, max_images_per_folder)
            
                try:
                    # Open and resize image
                    img = Image.open(img_path)
                    img = img.resize(target_size, Image.LANCZOS)
                    
                    # Save to target directory with folder index as face ID
                    target_file = target_path / f"ff_{folder_idx:04d}_{img_idx:04d}{img_path.suffix}"
                    img.save(target_file)
                    count += 1
                    pbar.update(1)
                        
                except Exception as e:
                    print(f"\nError processing {img_path.name}: {e}")
    
    print(f"Successfully processed {count} images from FaceForensics++ to {target_path}")
    return count

In [6]:
# Only run if PROCESS_DATASETS is True and FaceForensics++ dataset exists
if PROCESS_DATASETS and ff_exists:
    count = process_faceforensics_dataset(
        FF_PATH, 
        ff_dir, 
        target_size=(TARGET_SIZE, TARGET_SIZE),
        max_folders=FF_MAX_FACES,
        max_images_per_folder=FF_MAX_PER_FACE
    )
    print(f"Processed {count} FaceForensics++ images")
else:
    print("Skipping FaceForensics++ processing. Set PROCESS_DATASETS=True to process.")


Skipping FaceForensics++ processing. Set PROCESS_DATASETS=True to process.


In [7]:
###############################################################################
# CELL 6: COMBINE DATASETS
###############################################################################
"""
Dataset Combination
----------------
Combine the processed CelebDF and FaceForensics++ datasets 
into a single dataset for training. This ensures we have a diverse
set of facial images.
"""
def combine_datasets(celebdf_dir, ff_dir, combined_dir):
    """Combine processed datasets into one directory"""
    # Copy all images from CelebDF directory
    celebdf_files = list(Path(celebdf_dir).glob("*.jpg")) + list(Path(celebdf_dir).glob("*.png"))
    ff_files = list(Path(ff_dir).glob("*.jpg")) + list(Path(ff_dir).glob("*.png"))
    
    total_files = len(celebdf_files) + len(ff_files)
    print(f"Combining {len(celebdf_files)} CelebDF images and {len(ff_files)} FaceForensics++ images...")
    
    with tqdm(total=total_files, desc="Combining datasets", unit="img") as pbar:
        # Copy CelebDF images
        for img_path in celebdf_files:
            try:
                target_path = Path(combined_dir) / img_path.name
                img = Image.open(img_path)
                img.save(target_path)
                pbar.update(1)
            except Exception as e:
                print(f"\nError copying {img_path.name}: {e}")
        
        # Copy FaceForensics++ images
        for img_path in ff_files:
            try:
                target_path = Path(combined_dir) / img_path.name
                img = Image.open(img_path)
                img.save(target_path)
                pbar.update(1)
            except Exception as e:
                print(f"\nError copying {img_path.name}: {e}")
    
    print(f"Combined dataset created with {total_files} total images at {combined_dir}")
    return total_files

# Only run if both datasets have been processed
if PROCESS_DATASETS and celebdf_exists and ff_exists:
    total_images = combine_datasets(celebdf_dir, ff_dir, combined_dir)
    dataset_path = combined_dir
    print(f"Combined dataset created with {total_images} images")
elif celebdf_exists and os.path.exists(celebdf_dir) and len(os.listdir(celebdf_dir)) > 0:
    dataset_path = celebdf_dir
    print(f"Using CelebDF dataset at {celebdf_dir}")
elif ff_exists and os.path.exists(ff_dir) and len(os.listdir(ff_dir)) > 0:
    dataset_path = ff_dir
    print(f"Using FaceForensics++ dataset at {ff_dir}")
elif os.path.exists(combined_dir) and len(os.listdir(combined_dir)) > 0:
    dataset_path = combined_dir
    print(f"Using existing combined dataset at {combined_dir}")
else:
    print("Error: No datasets were processed or found")
    dataset_path = None

print(f"Dataset path for training: {dataset_path}")


Using CelebDF dataset at c:\Users\vinay\Documents\mnist\processed_faces\celebdf
Dataset path for training: c:\Users\vinay\Documents\mnist\processed_faces\celebdf


In [8]:

###############################################################################
# CELL 7: DATASET CLASS AND CUDA SETUP
###############################################################################
"""
Dataset Class and CUDA Setup
-------------------------
Define the custom dataset class for loading face images
and set up CUDA for GPU acceleration if available.
"""
class FaceDataset(Dataset):
    """Custom Dataset for loading face images"""
    def __init__(self, root_dir, transform=None, file_extensions=('.png', '.jpg', '.jpeg')):
        self.root_dir = root_dir
        self.transform = transform
        self.file_extensions = file_extensions
        
        # Count all files first for debugging
        all_files = os.listdir(root_dir)
        print(f"Total files in directory: {len(all_files)}")
        
        # Get valid image files
        self.image_files = [
            f for f in all_files 
            if os.path.isfile(os.path.join(root_dir, f)) and 
               any(f.lower().endswith(ext) for ext in self.file_extensions)
        ]
        
        # Log file extension stats
        extensions = {}
        for f in self.image_files:
            ext = os.path.splitext(f)[1].lower()
            extensions[ext] = extensions.get(ext, 0) + 1
        
        print(f"Found {len(self.image_files)} valid images out of {len(all_files)} total files")
        print(f"Image extension breakdown: {extensions}")
        
        # Check for potentially problematic files
        non_image_files = len(all_files) - len(self.image_files)
        if non_image_files > 0:
            print(f"Warning: {non_image_files} files were skipped (not recognized as images)")
            # Sample some skipped files for debugging
            skipped = [f for f in all_files if f not in self.image_files][:5]
            print(f"Sample skipped files: {skipped}")
        
    def __len__(self):
        return len(self.image_files)
    
    def __getitem__(self, idx):
        img_path = os.path.join(self.root_dir, self.image_files[idx])
        try:
            image = Image.open(img_path).convert('RGB')
            
            if self.transform:
                image = self.transform(image)
                
            return image, 0  # Return 0 as dummy label
        except Exception as e:
            print(f"Error loading {img_path}: {e}")
            # Return a placeholder black image if loading fails
            if self.transform:
                return torch.zeros(3, 128, 128), 0
            else:
                return Image.new('RGB', (128, 128), color='black'), 0

def setup_cuda():
    """Setup CUDA if available"""
    global CUDA, device
    # Set random seed for reproducibility
    random.seed(seed)
    torch.manual_seed(seed)
    
    if not torch.cuda.is_available():
        CUDA = False
        print("CUDA is not available. Running on CPU.")
        device = torch.device("cpu")
    else:
        device = torch.device("cuda:0")
        # Enable deterministic behavior for reproducibility
        torch.cuda.manual_seed(seed)
        cudnn.benchmark = True
    
    print(f"PyTorch version: {torch.__version__}")
    if CUDA:
        print(f"CUDA version: {torch.version.cuda}")
        print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"Using device: {device}")
    return device

# Setup CUDA
device = setup_cuda()


PyTorch version: 2.6.0+cu126
CUDA version: 12.6
GPU: NVIDIA GeForce RTX 2070 Super with Max-Q Design
Using device: cuda:0


In [9]:
###############################################################################
# CELL 8: GAN MODEL ARCHITECTURE
###############################################################################
"""
GAN Model Architecture
-------------------
Define the Generator and Discriminator architecture
for the Deep Convolutional GAN (DCGAN).
"""
def weights_init(m):
    """Initialize network weights"""
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        m.weight.data.normal_(0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        m.weight.data.normal_(1.0, 0.02)
        m.bias.data.fill_(0)


# Replace your current Generator class in Cell 8 with this code:
class Generator(nn.Module):
    """Generator Network with LeakyReLU for better gradient flow"""
    def __init__(self):
        super(Generator, self).__init__()
        self.main = nn.Sequential(
            # Input: Z_DIM x 1 x 1
            nn.ConvTranspose2d(Z_DIM, G_HIDDEN * 16, kernel_size=4, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(G_HIDDEN * 16),
            nn.LeakyReLU(0.2, inplace=True),  # Changed from ReLU to LeakyReLU
            # 4x4
            
            nn.ConvTranspose2d(G_HIDDEN * 16, G_HIDDEN * 8, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(G_HIDDEN * 8),
            nn.LeakyReLU(0.2, inplace=True),  # Changed from ReLU to LeakyReLU
            # 8x8
            
            nn.ConvTranspose2d(G_HIDDEN * 8, G_HIDDEN * 4, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(G_HIDDEN * 4),
            nn.LeakyReLU(0.2, inplace=True),  # Changed from ReLU to LeakyReLU
            # 16x16
            
            nn.ConvTranspose2d(G_HIDDEN * 4, G_HIDDEN * 2, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(G_HIDDEN * 2),
            nn.LeakyReLU(0.2, inplace=True),  # Changed from ReLU to LeakyReLU
            # 32x32
            
            nn.ConvTranspose2d(G_HIDDEN * 2, G_HIDDEN, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(G_HIDDEN),
            nn.LeakyReLU(0.2, inplace=True),  # Changed from ReLU to LeakyReLU
            # 64x64
            
            nn.ConvTranspose2d(G_HIDDEN, IMAGE_CHANNEL, kernel_size=4, stride=2, padding=1, bias=False),
            nn.Tanh()
            # 128x128
        )

    def forward(self, input):
        return self.main(input)


class Discriminator(nn.Module):
    """Discriminator Network"""
    def __init__(self):
        super(Discriminator, self).__init__()
        self.main = nn.Sequential(
            # Input: IMAGE_CHANNEL x 128 x 128
            nn.Conv2d(IMAGE_CHANNEL, D_HIDDEN, kernel_size=4, stride=2, padding=1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            # 64x64
            
            nn.Conv2d(D_HIDDEN, D_HIDDEN * 2, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(D_HIDDEN * 2),
            nn.LeakyReLU(0.2, inplace=True),
            # 32x32
            
            nn.Conv2d(D_HIDDEN * 2, D_HIDDEN * 4, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(D_HIDDEN * 4),
            nn.LeakyReLU(0.2, inplace=True),
            # 16x16
            
            nn.Conv2d(D_HIDDEN * 4, D_HIDDEN * 8, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(D_HIDDEN * 8),
            nn.LeakyReLU(0.2, inplace=True),
            # 8x8
            
            nn.Conv2d(D_HIDDEN * 8, D_HIDDEN * 16, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(D_HIDDEN * 16),
            nn.LeakyReLU(0.2, inplace=True),
            # 4x4
            
            nn.Conv2d(D_HIDDEN * 16, 1, kernel_size=4, stride=1, padding=0, bias=False),
            nn.Sigmoid()
            # 1x1
        )

    def forward(self, input):
        return self.main(input).view(-1, 1).squeeze(1)

# Initialize networks
netG = Generator().to(device)
netG.apply(weights_init)
print("Generator initialized")

netD = Discriminator().to(device)
netD.apply(weights_init)
print("Discriminator initialized")

# Display network architecture summaries
print("\nGenerator Architecture:")
print(netG)
print("\nDiscriminator Architecture:")
print(netD)


Generator initialized
Discriminator initialized

Generator Architecture:
Generator(
  (main): Sequential(
    (0): ConvTranspose2d(100, 1024, kernel_size=(4, 4), stride=(1, 1), bias=False)
    (1): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): LeakyReLU(negative_slope=0.2, inplace=True)
    (3): ConvTranspose2d(1024, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (4): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): LeakyReLU(negative_slope=0.2, inplace=True)
    (6): ConvTranspose2d(512, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (7): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (8): LeakyReLU(negative_slope=0.2, inplace=True)
    (9): ConvTranspose2d(256, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (10): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=T

In [10]:

###############################################################################
# CELL 9: DATA LOADING WITH ENHANCED DIAGNOSTICS (VISUALIZATION REMOVED)
###############################################################################
"""
Data Loading with Enhanced Diagnostics
---------------------------
Load the processed face dataset with better diagnostics and file handling.
Includes checks for the combined directory and detailed file reporting.
Visualization has been removed to improve performance.
"""
def load_dataset(data_path):
    """Load face dataset and create dataloader with detailed diagnostics"""
    # Image transformations
    transform = transforms.Compose([
        transforms.Resize(X_DIM),
        transforms.CenterCrop(X_DIM),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))  # Normalize for RGB
    ])
    
    # Create dataset and dataloader
    try:
        print(f"Looking for images in: {data_path}")
        
        # Check if the directory exists and is readable
        if not os.path.exists(data_path):
            print(f"Error: Directory {data_path} does not exist")
            return None, None
        elif not os.access(data_path, os.R_OK):
            print(f"Error: No read permissions for {data_path}")
            return None, None
            
        # Check if path is combined dir and has more images
        combined_dir = os.path.join(PROCESSED_PATH, "combined")
        if os.path.exists(combined_dir) and data_path != combined_dir:
            combined_count = len([f for f in os.listdir(combined_dir) 
                               if os.path.isfile(os.path.join(combined_dir, f)) and 
                               f.lower().endswith(('.png', '.jpg', '.jpeg'))])
            current_count = len([f for f in os.listdir(data_path) 
                               if os.path.isfile(os.path.join(data_path, f)) and 
                               f.lower().endswith(('.png', '.jpg', '.jpeg'))])
            
            if combined_count > current_count:
                print(f"Note: The combined directory has {combined_count} images, but you're using {data_path} with {current_count} images")
                print(f"Consider using the combined dataset at {combined_dir} instead")
                
                # Optionally, switch to the combined directory automatically
                # Uncomment the line below to use the combined directory with more images
                # data_path = combined_dir
                
        # Instantiate the dataset with additional file extensions
        dataset = FaceDataset(
            root_dir=data_path, 
            transform=transform,
            file_extensions=('.png', '.jpg', '.jpeg', '.bmp', '.tiff', '.webp')  # Add more image types
        )
        
        if len(dataset) == 0:
            print(f"Error: No valid images found in {data_path}")
            return None, None
            
        dataloader = DataLoader(
            dataset, 
            batch_size=BATCH_SIZE, 
            shuffle=True,
            num_workers=0, 
            pin_memory=True if CUDA else False
        )
        print(f"Dataset loaded with {len(dataset)} images")
        print(f"Number of batches: {len(dataloader)}")
        return dataset, dataloader
    except Exception as e:
        print(f"Error loading dataset: {str(e)}")
        import traceback
        traceback.print_exc()
        return None, None

# Load dataset if path exists
if dataset_path and os.path.exists(dataset_path):
    # Check if combined dataset exists and has more images than current dataset
    combined_dir = os.path.join(PROCESSED_PATH, "combined")
    if os.path.exists(combined_dir):
        # Quick count of valid images in both directories
        combined_files = [f for f in os.listdir(combined_dir) 
                         if os.path.isfile(os.path.join(combined_dir, f)) and 
                         f.lower().endswith(('.png', '.jpg', '.jpeg'))]
        current_files = [f for f in os.listdir(dataset_path) 
                        if os.path.isfile(os.path.join(dataset_path, f)) and 
                        f.lower().endswith(('.png', '.jpg', '.jpeg'))]
        
        print(f"Dataset comparison: Current path has {len(current_files)} images, Combined directory has {len(combined_files)} images")
        
        if len(combined_files) > len(current_files):
            print(f"Switching to combined dataset with more images: {combined_dir}")
            dataset_path = combined_dir
    
    dataset, dataloader = load_dataset(dataset_path)
    if not dataset or not dataloader:
        print("Failed to load dataset or create dataloader.")
else:
    print("Dataset path not available. Please check your configurations.")
    dataloader = None


Dataset comparison: Current path has 80576 images, Combined directory has 144473 images
Switching to combined dataset with more images: c:\Users\vinay\Documents\mnist\processed_faces\combined
Looking for images in: c:\Users\vinay\Documents\mnist\processed_faces\combined
Total files in directory: 144473
Found 144473 valid images out of 144473 total files
Image extension breakdown: {'.jpg': 80576, '.png': 63897}
Dataset loaded with 144473 images
Number of batches: 2258


In [11]:
###############################################################################
# CELL 10: CHECKPOINT FUNCTIONS
###############################################################################
"""
Checkpoint Functions
----------------
Functions for saving and loading model checkpoints
to support resumable training and best model selection.
"""
def save_checkpoint(netG, netD, optimG, optimD, epoch, iteration, losses, img_list, filename=None):
    """Save a checkpoint of the current training state"""
    if not filename:
        filename = os.path.join(CHECKPOINT_DIR, f"checkpoint_e{epoch}_i{iteration}.pt")
    
    checkpoint = {
        'epoch': epoch,
        'iteration': iteration,
        'netG_state_dict': netG.state_dict(),
        'netD_state_dict': netD.state_dict(),
        'optimG_state_dict': optimG.state_dict(),
        'optimD_state_dict': optimD.state_dict(),
        'G_losses': losses['G'],
        'D_losses': losses['D'],
        'img_list': img_list,
        'timestamp': datetime.datetime.now().isoformat()
    }
    
    torch.save(checkpoint, filename)
    print(f"Checkpoint saved: {filename}")
    
    # Save the latest checkpoint also (for resuming)
    latest_path = os.path.join(CHECKPOINT_DIR, "latest_checkpoint.pt")
    torch.save(checkpoint, latest_path)
    
    return filename

def load_checkpoint(netG, netD, optimG, optimD, filename=None):
    """Load a checkpoint to resume training"""
    if not filename:
        # Try to load the latest checkpoint
        filename = os.path.join(CHECKPOINT_DIR, "latest_checkpoint.pt")
    
    if not os.path.exists(filename):
        print(f"No checkpoint found at {filename}")
        return None, 0, 0, {'G': [], 'D': []}, []
    
    print(f"Loading checkpoint: {filename}")
    checkpoint = torch.load(filename, map_location=device)
    
    netG.load_state_dict(checkpoint['netG_state_dict'])
    netD.load_state_dict(checkpoint['netD_state_dict'])
    optimG.load_state_dict(checkpoint['optimG_state_dict'])
    optimD.load_state_dict(checkpoint['optimD_state_dict'])
    
    # For optimizers loaded from checkpoint, move to correct device
    if CUDA:
        for state in optimG.state.values():
            for k, v in state.items():
                if torch.is_tensor(v):
                    state[k] = v.to(device)
                    
        for state in optimD.state.values():
            for k, v in state.items():
                if torch.is_tensor(v):
                    state[k] = v.to(device)
    
    epoch = checkpoint['epoch']
    iteration = checkpoint['iteration']
    G_losses = checkpoint['G_losses']
    D_losses = checkpoint['D_losses']
    img_list = checkpoint.get('img_list', [])
    
    print(f"Resuming from epoch {epoch+1}, iteration {iteration}")
    
    return {'netG': netG, 'netD': netD, 'optimG': optimG, 'optimD': optimD}, epoch, iteration, {'G': G_losses, 'D': D_losses}, img_list


In [12]:
###############################################################################
# CELL 11: TRAINING INITIALIZATION
###############################################################################
"""
Training Initialization
--------------------
Initialize GAN training by setting up optimizers, 
creating fixed noise for visualization,
and loading a checkpoint if resuming training.
"""
# Setup loss function and optimizers
criterion = nn.BCELoss()
# Fixed noise for visualization
fixed_noise = torch.randn(16, Z_DIM, 1, 1, device=device)

# Initialize optimizers
optimizerD = optim.Adam(netD.parameters(), lr=lr, betas=(0.5, 0.999))
optimizerG = optim.Adam(netG.parameters(), lr=lr, betas=(0.5, 0.999))

# Setup TensorBoard
if ENABLE_TENSORBOARD:
    tb_writer = SummaryWriter(log_dir=os.path.join(LOG_DIR, EXPERIMENT_NAME))
    # Add model graph to TensorBoard
    try:
        sample_input = torch.randn(1, Z_DIM, 1, 1, device=device)
        tb_writer.add_graph(netG, sample_input)
    except Exception as e:
        print(f"Could not add model graph to TensorBoard: {e}")
else:
    tb_writer = None

# Try to load checkpoint if resuming training
start_epoch = 0
start_iter = 0
G_losses = []
D_losses = []
img_list = []

if RESUME_TRAINING:
    models, start_epoch, start_iter, losses, prev_img_list = load_checkpoint(netG, netD, optimizerG, optimizerD)
    if models:  # If checkpoint was loaded successfully
        netG, netD = models['netG'], models['netD']
        optimizerG, optimizerD = models['optimG'], models['optimD']
        G_losses = losses['G']
        D_losses = losses['D']
        img_list = prev_img_list
        print(f"Resuming from epoch {start_epoch+1}, iteration {start_iter}")

print(f"\n{'='*50}\nREADY FOR TRAINING\n{'='*50}")


No checkpoint found at c:\Users\vinay\Documents\mnist\output\checkpoints\latest_checkpoint.pt

READY FOR TRAINING


In [None]:

###############################################################################
# CELL 12: ENHANCED TRAINING LOOP WITH STABILIZATION TECHNIQUES
###############################################################################
"""
Enhanced Training Loop for Face GAN
----------------------------------
Features:
- Label smoothing (real_label=0.9)
- Adaptive learning rates
- Noisy label training
- Conditional discriminator updates
- Two-timescale update rule (TTUR)
- Multiple generator updates for difficult batches
"""
import time
import os
import csv
import torch
import torch.nn as nn
import torch.utils.data
import torchvision.utils as vutils
import numpy as np
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm
import datetime

# Create output directory structure
timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
run_name = f"face_gan_run_{timestamp}"
run_dir = os.path.join(OUTPUT_PATH, run_name)

# Create subdirectories for all outputs
images_dir = os.path.join(run_dir, "images")
#checkpoints_dir = os.path.join(run_dir, "checkpoints") this was a bug.
checkpoints_dir = CHECKPOINT_DIR 
logs_dir = os.path.join(run_dir, "logs")
timelapse_dir = os.path.join(run_dir, "timelapse")

# Create timelapse subdirectories
angles = ["grid", "row", "individual", "interpolation"]
timelapse_subdirs = {}
for angle in angles:
    dir_path = os.path.join(timelapse_dir, angle)
    os.makedirs(dir_path, exist_ok=True)
    timelapse_subdirs[angle] = dir_path

# Create all other directories
for dir_path in [run_dir, images_dir, checkpoints_dir, logs_dir]:
    os.makedirs(dir_path, exist_ok=True)

print(f"Created output directory structure at {run_dir}")

# GPU Optimization Steps
if torch.cuda.is_available():
    torch.backends.cudnn.benchmark = True
    torch.backends.cudnn.enabled = True
    torch.set_float32_matmul_precision('high')

# Create dataloader
dataloader = torch.utils.data.DataLoader(
    dataset, 
    batch_size=BATCH_SIZE, 
    shuffle=True,
    num_workers=0,
    pin_memory=True if CUDA else False,
    drop_last=True
)

# Calculate total iterations
total_iterations = len(dataloader) * EPOCH_NUM
print(f"Total iterations: {total_iterations}")

# Setup CSV logging file
csv_log_path = os.path.join(logs_dir, "training_metrics.csv")
with open(csv_log_path, 'w', newline='') as csvfile:
    writer = csv.writer(csvfile)
    writer.writerow(['Epoch', 'Iteration', 'G_loss', 'D_loss', 'D_x', 'D_G_z1', 'D_G_z2', 
                     'Time_elapsed', 'ETA'])

# Create fixed noise vectors for visualization
fixed_noise = torch.randn(25, Z_DIM, 1, 1, device=device)
fixed_noise_row = torch.randn(8, Z_DIM, 1, 1, device=device)
interpolation_points = 10
fixed_noise_start = torch.randn(1, Z_DIM, 1, 1, device=device)
fixed_noise_end = torch.randn(1, Z_DIM, 1, 1, device=device)
interpolation_noises = []
for i in range(interpolation_points):
    t = i / (interpolation_points - 1)
    interp_noise = (1-t) * fixed_noise_start + t * fixed_noise_end
    interpolation_noises.append(interp_noise)


# Modified save_timelapse_images function to use fixed seeds for individual images
def save_timelapse_images(netG, epoch, iteration, total_iterations):
    filename_base = f"{epoch:03d}_{iteration:05d}"
    
    with torch.no_grad():
        # Grid of images - use the existing fixed_noise
        fake_grid = netG(fixed_noise).detach().cpu()
        grid_img = vutils.make_grid(fake_grid, padding=2, normalize=True, nrow=5)
        utils_filename = os.path.join(timelapse_subdirs["grid"], f"grid_{filename_base}.png") 
        vutils.save_image(fake_grid, utils_filename, nrow=5, padding=2, normalize=True)
        
        # Row of images - use the existing fixed_noise_row
        fake_row = netG(fixed_noise_row).detach().cpu()
        row_filename = os.path.join(timelapse_subdirs["row"], f"row_{filename_base}.png")
        vutils.save_image(fake_row, row_filename, nrow=8, padding=2, normalize=True)
        
        # Individual large images - use individual_fixed_noises
        # We'll use the first 4 samples from our fixed_noise to maintain consistency
        indiv_samples = netG(fixed_noise[:4]).detach().cpu()
        for i, sample in enumerate(indiv_samples):
            indiv_filename = os.path.join(timelapse_subdirs["individual"], f"img{i}_{filename_base}.png")
            vutils.save_image(sample, indiv_filename, normalize=True)
        
        # Latent space interpolation - keep the existing approach
        interp_images = []
        for noise in interpolation_noises:
            interp_images.append(netG(noise).detach().cpu())
        interp_tensor = torch.cat(interp_images, 0)
        interp_filename = os.path.join(timelapse_subdirs["interpolation"], f"interp_{filename_base}.png")
        vutils.save_image(interp_tensor, interp_filename, nrow=interpolation_points, normalize=True)
    
    return grid_img


# Function to add noise to labels
def noisy_labels(size, value, device):
    """Add noise to labels to prevent discriminator overconfidence"""
    if value > 0.5:  # real label
        # Real labels: value-0.2 to value range instead of exactly value (0.9)
        return value - 0.2 * torch.rand(size, device=device)
    else:  # fake label
        # Fake labels: 0 to 0.3 range instead of exactly 0
        return value + 0.3 * torch.rand(size, device=device)

# Lists to track progress
G_losses = []
D_losses = []
D_x_history = []  # Real accuracy
D_G_z_history = []  # Fake accuracy
img_list = []

# Training frequency settings
record_freq = 100
save_image_freq = 200
checkpoint_freq = 1

# Initialize optimizers with different learning rates (TTUR)
lr_g = lr * 0.5  # Lower learning rate for generator (1e-4)
lr_d = lr * 1.5  # Higher learning rate for discriminator (3e-4)

optimizerD = optim.Adam(netD.parameters(), lr=lr_d, betas=(0.5, 0.999))
optimizerG = optim.Adam(netG.parameters(), lr=lr_g, betas=(0.5, 0.999))

# Set up labels for training
real_label = 0.9  # Use label smoothing
fake_label = 0.0

# Start timing
start_time = time.time()
global_step = 0

# Create progress bars
main_progress = tqdm(total=total_iterations, desc="Total progress")

try:
    # For each epoch
    for epoch in range(EPOCH_NUM):
        epoch_start_time = time.time()
        
        # TQDM bar for this epoch
        pbar = tqdm(total=len(dataloader), desc=f"Epoch {epoch+1}/{EPOCH_NUM}")
        
        # For each batch in the dataloader
        for i, data in enumerate(dataloader):
            # Calculate ETA and elapsed time
            elapsed_time = time.time() - start_time
            if global_step > 0:
                iterations_left = total_iterations - global_step
                eta_seconds = elapsed_time * (iterations_left / global_step)
                eta = str(datetime.timedelta(seconds=int(eta_seconds)))
            else:
                eta = "N/A"
                
            # Move batch data to device
            real_images = data[0].to(device)
            current_batch_size = real_images.size(0)
            
            # --------------------
            # Train Discriminator
            # --------------------
            netD.zero_grad()
            
            # Real images with noisy labels
            real_label_values = noisy_labels(current_batch_size, real_label, device)
            output = netD(real_images)
            errD_real = criterion(output, real_label_values)
            errD_real.backward()
            D_x = output.mean().item()
            
            # Fake images with noisy labels
            noise = torch.randn(current_batch_size, Z_DIM, 1, 1, device=device)
            fake = netG(noise)
            fake_label_values = noisy_labels(current_batch_size, fake_label, device)
            output = netD(fake.detach())
            errD_fake = criterion(output, fake_label_values)
            errD_fake.backward()
            D_G_z1 = output.mean().item()
            
            errD = errD_real + errD_fake
            
            # Only update discriminator if it's not too powerful
            if D_x < 0.8 or D_G_z1 > 0.1:
                optimizerD.step()

            # --------------------
            # Train Generator
            # --------------------
            # Train generator multiple times if it's struggling
            generator_updates = 2 if errD.item() < 1.0 and D_G_z1 < 0.3 else 1
            
            for _ in range(generator_updates):
                netG.zero_grad()
                label = torch.full((current_batch_size,), real_label, dtype=torch.float, device=device)
                
                # If we're doing multiple updates, generate fresh noise for second update
                if _ > 0:
                    noise = torch.randn(current_batch_size, Z_DIM, 1, 1, device=device)
                    fake = netG(noise)
                    
                output = netD(fake)
                errG = criterion(output, label)
                errG.backward()
                optimizerG.step()
                
            # Get final D(G(z)) after G update
            D_G_z2 = output.mean().item()
            
            # Save losses for plotting later
            G_losses.append(errG.item())
            D_losses.append(errD.item())
            D_x_history.append(D_x)
            D_G_z_history.append(D_G_z2)
            
            # Update TQDM with critical statistics
            pbar.set_postfix({
                'G_loss': f"{errG.item():.3f}",
                'D_loss': f"{errD.item():.3f}", 
                'D(x)': f"{D_x:.3f}",
                'D(G(z))': f"{D_G_z2:.3f}"
            })
            pbar.update(1)
            main_progress.update(1)
            
            # Record metrics to CSV log every record_freq batches
            if global_step % record_freq == 0:
                with open(csv_log_path, 'a', newline='') as csvfile:
                    writer = csv.writer(csvfile)
                    writer.writerow([
                        epoch+1, i, errG.item(), errD.item(), 
                        D_x, D_G_z1, D_G_z2, 
                        elapsed_time, eta
                    ])
            
            # Generate and save images for timelapse
            if global_step % save_image_freq == 0 or (epoch == EPOCH_NUM-1 and i == len(dataloader)-1):
                img_grid = save_timelapse_images(netG, epoch+1, global_step, total_iterations)
                # Store grid for final evaluation (but don't display it)
                img_list.append(img_grid)
                
                # Save loss curves without displaying
                plt.figure(figsize=(10, 8))
                plt.subplot(2, 1, 1)
                plt.title("Generator and Discriminator Loss")
                plt.plot(G_losses, label="G")
                plt.plot(D_losses, label="D")
                plt.legend()
                plt.subplot(2, 1, 2)
                plt.title("D(x) and D(G(z))")
                plt.plot(D_x_history, label="D(x)")
                plt.plot(D_G_z_history, label="D(G(z))")
                plt.legend()
                plt.savefig(os.path.join(logs_dir, f"loss_curves_{epoch+1}_{global_step}.png"))
                plt.close()
                
            # Increment global step counter
            global_step += 1
                
        # Close progress bar for this epoch
        pbar.close()
        
        # End of epoch stats
        epoch_time = time.time() - epoch_start_time
        print(f"Epoch {epoch+1}/{EPOCH_NUM} completed in {epoch_time:.2f}s")
        
        # Save checkpoint at end of epoch
        if (epoch+1) % checkpoint_freq == 0:
            latest_path = os.path.join(CHECKPOINT_DIR, "latest_checkpoint.pt")
            torch.save({
                'epoch': epoch,
                'global_step': global_step,
                'netG_state_dict': netG.state_dict(),
                'netD_state_dict': netD.state_dict(),
                'optimizerG_state_dict': optimizerG.state_dict(),
                'optimizerD_state_dict': optimizerD.state_dict(),
                'G_losses': G_losses,
                'D_losses': D_losses,
                'D_x_history': D_x_history,
                'D_G_z_history': D_G_z_history
            }, latest_path)
            print(f"Checkpoint saved: {latest_path}")
            '''
            checkpoint_path = os.path.join(checkpoints_dir, f"checkpoint_epoch{epoch+1}.pt")
            torch.save({
                'epoch': epoch,
                'global_step': global_step,
                'netG_state_dict': netG.state_dict(),
                'netD_state_dict': netD.state_dict(),
                'optimizerG_state_dict': optimizerG.state_dict(),
                'optimizerD_state_dict': optimizerD.state_dict(),
                'G_losses': G_losses,
                'D_losses': D_losses,
                'D_x_history': D_x_history,
                'D_G_z_history': D_G_z_history
            }, checkpoint_path)
            print(f"Checkpoint saved: {checkpoint_path}")'''
    
    # Close main progress bar
    main_progress.close()
    
    # Training complete
    total_time = time.time() - start_time
    print(f"Training completed in {total_time/60:.2f} minutes")
    
    # Save final summary statistics
    plt.figure(figsize=(16, 12))
    # Loss plot
    plt.subplot(2, 2, 1)
    plt.plot(G_losses, label="Generator")
    plt.plot(D_losses, label="Discriminator")
    plt.title("Loss over Training")
    plt.legend()
    
    # D(x)/D(G(z)) plot
    plt.subplot(2, 2, 2)
    plt.plot(D_x_history, label="D(x) - Real")
    plt.plot(D_G_z_history, label="D(G(z)) - Fake")
    plt.title("Discriminator Accuracy")
    plt.legend()
    
    # Final generated images grid
    if len(img_list) > 0:
        plt.subplot(2, 1, 2)
        plt.axis("off")
        plt.title("Final Generated Images")
        plt.imshow(np.transpose(img_list[-1], (1, 2, 0)))
    
    plt.savefig(os.path.join(run_dir, "final_summary.png"))
    plt.close()
    
    # Save final model
    torch.save(netG.state_dict(), os.path.join(run_dir, "generator_final.pt"))
    torch.save(netD.state_dict(), os.path.join(run_dir, "discriminator_final.pt"))
    
    print(f"Training data saved to {run_dir}")
    print("Training completed successfully!")
    
except Exception as e:
    main_progress.close()
    print(f"Error during training: {e}")
    import traceback
    traceback.print_exc()
    
    # Try to save emergency checkpoint
    try:
        error_checkpoint_path = os.path.join(checkpoints_dir, "error_checkpoint.pt")
        torch.save({
            'netG_state_dict': netG.state_dict(),
            'netD_state_dict': netD.state_dict(),
            'optimizerG_state_dict': optimizerG.state_dict(),
            'optimizerD_state_dict': optimizerD.state_dict(),
            'G_losses': G_losses,
            'D_losses': D_losses,
            'epoch': epoch,
            'global_step': global_step
        }, error_checkpoint_path)
        print(f"Emergency checkpoint saved to {error_checkpoint_path}")
    except:
        print("Failed to save emergency checkpoint")


Created output directory structure at c:\Users\vinay\Documents\mnist\output\face_gan_run_20250306_113703
Total iterations: 22570


Total progress:   0%|          | 0/22570 [00:00<?, ?it/s]

Epoch 1/10:   0%|          | 0/2257 [00:00<?, ?it/s]