# **BigGAN Implmentation for Oxford 102 Flowers**

### **Imports**

In [1]:
!pip install pytorch-fid

Collecting pytorch-fid
  Downloading pytorch_fid-0.3.0-py3-none-any.whl.metadata (5.3 kB)
Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch>=1.0.1->pytorch-fid)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch>=1.0.1->pytorch-fid)
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch>=1.0.1->pytorch-fid)
  Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch>=1.0.1->pytorch-fid)
  Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu12==12.4.5.8 (from torch>=1.0.1->pytorch-fid)
  Downloading nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cufft-cu12==11.2.1.3 (from torc

In [2]:
import urllib.request
import tarfile
import os
import time
import datetime
import glob
import random
import numpy as np
import argparse
import types
import torch
import torch.nn as nn
from torch.nn import init
from torch.nn.parameter import Parameter
from torch.nn import functional as F
from torch.backends import cudnn
from torchvision.utils import save_image
import torchvision.datasets as dsets
from torchvision import transforms
from torch.utils.tensorboard import SummaryWriter
from PIL import Image
import imageio.v2 as imageio
import re
import matplotlib.animation as animation
import matplotlib.pyplot as plt
from torchvision.transforms import ToPILImage
import shutil
from tqdm.auto import tqdm
import pytorch_fid

### **Download Data**

In [3]:
os.makedirs("/content/flowers_data", exist_ok=True)
url = "https://www.robots.ox.ac.uk/~vgg/data/flowers/102/102flowers.tgz"
local_path = "/content/flowers_data/102flowers.tgz"

urllib.request.urlretrieve(url, local_path)

with tarfile.open(local_path) as tar:
    tar.extractall(path="/content/flowers")


### **Helper functions**

In [4]:
def create_output_directory(base_path: str, sub_version: str = "") -> str:
    full_path = os.path.join(base_path, sub_version) if sub_version else base_path
    os.makedirs(full_path, exist_ok=True)
    return full_path

def move_tensor_to_device(data: torch.Tensor, device: torch.device) -> torch.Tensor:
    return data.to(device)

def denormalize_image(x: torch.Tensor) -> torch.Tensor:
    return ((x + 1) / 2).clamp_(0, 1)

def initialize_network_weights(module):
    classname = module.__class__.__name__
    if classname.find('Conv') != -1:
        init.xavier_normal_(module.weight.data)
        if module.bias is not None:
            init.constant_(module.bias.data, 0.0)
    elif classname.find('Linear') != -1:
        init.xavier_normal_(module.weight.data)
        if module.bias is not None:
            init.constant_(module.bias.data, 0.0)

def l2_normalize_vector(v, epsilon=1e-12):
    return v / (v.norm() + epsilon)


### **Normalization**

In [5]:
class SpectralNormalization(nn.Module):
    def __init__(self, module, name='weight', power_iterations=1):
        super(SpectralNormalization, self).__init__()
        self.module = module
        self.name = name
        self.power_iterations = power_iterations
        if not hasattr(self.module, self.name + "_u"):
            self._add_spectral_params()

    def _update_uv_vectors(self):
        u = getattr(self.module, self.name + "_u")
        v = getattr(self.module, self.name + "_v")
        w_bar = getattr(self.module, self.name + "_bar")

        height = w_bar.data.shape[0]
        w_matrix = w_bar.view(height, -1).data

        for _ in range(self.power_iterations):
            v.data = l2_normalize_vector(torch.mv(w_matrix.T, u.data))
            u.data = l2_normalize_vector(torch.mv(w_matrix, v.data))

        sigma = u.dot(w_matrix.mv(v))

        setattr(self.module, self.name, w_bar / sigma.expand_as(w_bar))

    def _add_spectral_params(self):
        original_weight = getattr(self.module, self.name)

        height = original_weight.data.shape[0]
        width = original_weight.view(height, -1).data.shape[1]

        u = Parameter(original_weight.data.new(height).normal_(0, 1), requires_grad=False)
        v = Parameter(original_weight.data.new(width).normal_(0, 1), requires_grad=False)
        u.data = l2_normalize_vector(u.data)
        v.data = l2_normalize_vector(v.data)

        w_bar = Parameter(original_weight.data)
        del self.module._parameters[self.name]

        self.module.register_parameter(self.name + "_u", u)
        self.module.register_parameter(self.name + "_v", v)
        self.module.register_parameter(self.name + "_bar", w_bar)

    def forward(self, *args):
        self._update_uv_vectors()
        return self.module.forward(*args)

class AdaptiveNormalization(nn.Module):
    def __init__(self, channels, condition_dim=148):
        super().__init__()
        self.batch_norm = nn.BatchNorm2d(channels, affine=False)
        self.embedding_layer = nn.Linear(condition_dim, channels * 2)

        nn.init.constant_(self.embedding_layer.weight.data[:, :channels], 1.0)
        nn.init.constant_(self.embedding_layer.weight.data[:, channels:], 0.0)
        if self.embedding_layer.bias is not None:
            nn.init.constant_(self.embedding_layer.bias.data[:channels], 1.0) # Gamma bias
            nn.init.constant_(self.embedding_layer.bias.data[channels:], 0.0) # Beta bias

    def forward(self, input_tensor, condition_vector):
        normalized_output = self.batch_norm(input_tensor)

        gamma_beta = self.embedding_layer(condition_vector)
        gamma, beta = gamma_beta.chunk(2, 1)

        # Reshape gamma and beta for broadcasting
        gamma = gamma.unsqueeze(2).unsqueeze(3)
        beta = beta.unsqueeze(2).unsqueeze(3)

        output = gamma * normalized_output + beta
        return output



### **Attention Module**

In [6]:
class SpatialAttentionBlock(nn.Module):
    def __init__(self, in_channels):
        super(SpatialAttentionBlock,self).__init__()
        self.in_channels = in_channels

        self.query_conv = SpectralNormalization(nn.Conv2d(in_channels, in_channels // 8, kernel_size=1))
        self.key_conv = SpectralNormalization(nn.Conv2d(in_channels, in_channels // 8, kernel_size=1))
        self.value_conv = SpectralNormalization(nn.Conv2d(in_channels, in_channels, kernel_size=1))

        # Learnable gamma parameter for weighted sum
        self.gamma_param = nn.Parameter(torch.zeros(1))

        self.softmax_fn = nn.Softmax(dim=-1)
        self.post_attention_conv = SpectralNormalization(nn.Conv2d(in_channels, in_channels, kernel_size=1))


    def forward(self, x):
        batch_size, C, H, W = x.size()

        # Reshape for matrix multiplication
        proj_query = self.query_conv(x).view(batch_size, -1, H * W).permute(0, 2, 1) # B x HW x C'
        proj_key = self.key_conv(x).view(batch_size, -1, H * W)                     # B x C' x HW
        proj_value = self.value_conv(x).view(batch_size, -1, H * W)                 # B x C x HW

        # Calculate attention map
        energy = torch.bmm(proj_query, proj_key) # B x HW x HW
        attention = self.softmax_fn(energy)      # B x HW x HW

        # Apply attention to value
        output_attention = torch.bmm(proj_value, attention.permute(0, 2, 1)) # B x C x HW
        output_attention = output_attention.view(batch_size, C, H, W)        # B x C x H x W

        # Pass through an additional convolution
        output_attention = self.post_attention_conv(output_attention)

        # Residual connection with learnable gamma
        output = self.gamma_param * output_attention + x
        return output



### **Residual Blocks**

In [7]:
class GenResBlock(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size=3, padding=1,
                 condition_dim=148, use_upsample=True):
        super().__init__()
        self.use_upsample = use_upsample

        # Convolutional layers with Spectral Normalization
        self.conv1 = SpectralNormalization(nn.Conv2d(in_channels, out_channels, kernel_size, padding=padding))
        self.conv2 = SpectralNormalization(nn.Conv2d(out_channels, out_channels, kernel_size, padding=padding))
        self.conv3 = SpectralNormalization(nn.Conv2d(out_channels, out_channels, kernel_size, padding=padding))

        # Adaptive Normalization layers
        self.norm1 = AdaptiveNormalization(out_channels, condition_dim)
        self.norm2 = AdaptiveNormalization(out_channels, condition_dim)

        self.skip_projection = False
        # 1x1 convolution for skip connection if channels change or upsampling is used
        if in_channels != out_channels or use_upsample:
            self.conv_skip = SpectralNormalization(nn.Conv2d(in_channels, out_channels, 1, padding=0))
            self.skip_projection = True

        self.activation = nn.ReLU()

    def forward(self, input_tensor, condition_vector):
        out = self.conv1(input_tensor)
        out = self.activation(out)
        out = self.norm1(out, condition_vector)

        if self.use_upsample:
            out = F.interpolate(out, scale_factor=2, mode='nearest') # Upsample

        out = self.conv2(out)
        out = self.activation(out)
        out = self.norm2(out, condition_vector)
        out = self.conv3(out) # Added conv3

        skip = input_tensor
        if self.skip_projection:
            skip = self.activation(skip) # Added activation to skip
            if self.use_upsample:
                skip = F.interpolate(skip, scale_factor=2, mode='nearest') # Upsample skip
            skip = self.conv_skip(skip)

        return out + skip # Add skip connection output


class DiscResBlock(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size=3, padding=1,
                 use_downsample=True):
        super().__init__()
        self.use_downsample = use_downsample

        # Convolutional layers with Spectral Normalization
        self.conv1 = SpectralNormalization(nn.Conv2d(in_channels, out_channels, kernel_size, padding=padding))
        self.conv2 = SpectralNormalization(nn.Conv2d(out_channels, out_channels, kernel_size, padding=padding))
        self.conv3 = SpectralNormalization(nn.Conv2d(out_channels, out_channels, kernel_size, padding=padding)) # Added layer

        self.skip_projection = False
        # 1x1 convolution for skip connection if channels change or downsampling is used
        if in_channels != out_channels or use_downsample:
            self.conv_skip = SpectralNormalization(nn.Conv2d(in_channels, out_channels, 1, padding=0))
            self.skip_projection = True

        self.activation = nn.LeakyReLU(0.2)

    def forward(self, input_tensor):
        out = self.conv1(input_tensor)
        out = self.activation(out)
        if self.use_downsample:
            out = F.avg_pool2d(out, 2) # Downsample

        out = self.conv2(out)
        out = self.activation(out)
        out = self.conv3(out) # Added conv3

        # Skip connection: input_tensor -> LeakyReLU -> conv_skip -> (downsample if needed)
        skip = input_tensor
        if self.skip_projection:
            skip = self.activation(skip) # Added activation to skip
            skip = self.conv_skip(skip)
            if self.use_downsample:
                skip = F.avg_pool2d(skip, 2) # Downsample skip

        return out + skip # Add skip connection output



### **Generator and Discriminator Model**

In [8]:
class ImageGenerator(nn.Module):
    def __init__(self, latent_dim=120, num_classes=1000, base_channels=96):
        super().__init__()
        self.latent_dim = latent_dim
        self.num_classes = num_classes
        self.base_channels = base_channels

        # Class embedding layer for conditioning
        self.class_embedding_layer = SpectralNormalization(nn.Linear(num_classes, 128, bias=False))

        self.initial_dense_layer = SpectralNormalization(nn.Linear(20, 8 * 8 * (16 * base_channels)))
        self.initial_feature_map_channels = 16 * base_channels

        # Generator blocks
        self.gen_blocks = nn.ModuleList([
            # GenResBlock without upsample for the first block if starting at 8x8
            GenResBlock(16 * base_channels, 16 * base_channels, condition_dim=128 + 20, use_upsample=False), # 8x8 -> 8x8
            SpatialAttentionBlock(16 * base_channels), # Attention at 8x8 (Moved earlier)
            GenResBlock(16 * base_channels, 8 * base_channels, condition_dim=128 + 20),  # 8x8 -> 16x16
            GenResBlock(8 * base_channels, 4 * base_channels, condition_dim=128 + 20),   # 16x16 -> 32x32
            GenResBlock(4 * base_channels, 2 * base_channels, condition_dim=128 + 20),   # 32x32 -> 64x64
            GenResBlock(2 * base_channels, 1 * base_channels, condition_dim=128 + 20)    # 64x64 -> 128x128
        ])

        # Final layers
        self.final_batch_norm = nn.BatchNorm2d(1 * base_channels)
        self.output_conv_pre_tanh = SpectralNormalization(nn.Conv2d(1 * base_channels, 3, kernel_size=3, padding=1))

    def forward(self, latent_code, class_one_hot):
        # Split latent code into parts for different blocks
        latent_code_parts = torch.split(latent_code, 20, 1) # 120 / 20 = 6 parts

        # Get class embedding
        class_embedding = self.class_embedding_layer(class_one_hot)

        # Initial projection from latent to feature map (now 8x8)
        out = self.initial_dense_layer(latent_code_parts[0])
        out = out.view(-1, self.initial_feature_map_channels, 8, 8) # Reshape to 8x8 feature map

        latent_part_idx = 1
        # Pass through generator blocks
        for block in self.gen_blocks:
            if isinstance(block, GenResBlock):
                # Concatenate current latent part and class embedding for conditioning
                condition_vector = torch.cat([latent_code_parts[latent_part_idx], class_embedding], 1)
                out = block(out, condition_vector)
                latent_part_idx += 1
            else: # SpatialAttentionBlock
                out = block(out)

        # Final layers
        out = self.final_batch_norm(out)
        out = F.relu(out)
        out = self.output_conv_pre_tanh(out) # Pass through new final conv layer

        return torch.tanh(out) # Output image in [-1, 1] range


class ImageDiscriminator(nn.Module):
    def __init__(self, num_classes=1000, base_channels=96):
        super().__init__()
        self.num_classes = num_classes
        self.base_channels = base_channels

        self.initial_block = nn.Sequential(
            SpectralNormalization(nn.Conv2d(3, 1 * base_channels, kernel_size=3, padding=1)),
            nn.LeakyReLU(0.2),
            SpectralNormalization(nn.Conv2d(1 * base_channels, 1 * base_channels, kernel_size=3, padding=1)),
            nn.AvgPool2d(2)
        )
        # Skip connection for the initial block
        self.initial_skip_conv = SpectralNormalization(nn.Conv2d(3, 1 * base_channels, kernel_size=1, padding=0))

        self.disc_blocks = nn.Sequential(
            DiscResBlock(1 * base_channels, 1 * base_channels, use_downsample=True), # 64x64 -> 32x32
            DiscResBlock(1 * base_channels, 2 * base_channels, use_downsample=True), # 32x32 -> 16x16
            SpatialAttentionBlock(2 * base_channels), # Attention at 16x16 (Moved later)
            DiscResBlock(2 * base_channels, 4 * base_channels, use_downsample=True), # 16x16 -> 8x8
            DiscResBlock(4 * base_channels, 8 * base_channels, use_downsample=True), # 8x8 -> 4x4
            DiscResBlock(8 * base_channels, 16 * base_channels, use_downsample=True), # 4x4 -> 2x2
            DiscResBlock(16 * base_channels, 16 * base_channels, use_downsample=False) # 2x2 (no downsample)
        )

        # Output linear layer for real/fake score
        self.output_linear = SpectralNormalization(nn.Linear(16 * base_channels, 1))

        # Class embedding layer for projection discriminator
        self.class_embedding_layer = SpectralNormalization(nn.Embedding(num_classes, 16 * base_channels))
        # Initialize embedding weights
        self.class_embedding_layer.module.weight_bar.data.uniform_(-0.1, 0.1)

        self.activation = nn.LeakyReLU(0.2)


    def forward(self, image, class_id_int):
        # Initial block processing
        out = self.initial_block(image)
        # Add skip connection for initial block
        out = out + self.initial_skip_conv(F.avg_pool2d(image, 2))

        # Pass through discriminator blocks
        out = self.disc_blocks(out)
        out = F.relu(out) # Final ReLU before pooling

        # Adaptive average pooling before linear layer
        out = F.adaptive_avg_pool2d(out, 1).view(out.size(0), -1) # Global average pooling

        # Get real/fake score
        real_fake_score = self.output_linear(out).squeeze(1)

        # Get class conditional score (projection discriminator)
        class_embedding = self.class_embedding_layer(class_id_int)
        # Apply LeakyReLU to output features before dot product for conditional score
        class_conditional_score = (self.activation(out) * class_embedding).sum(1)

        return real_fake_score + class_conditional_score # Combined score


### **Model Architecture**

In [9]:
G = ImageGenerator(latent_dim=120, num_classes = 1, base_channels=64).to('cuda')
D = ImageDiscriminator(num_classes=1, base_channels=64).to('cuda')


print("\n--- Generator Architecture ---")
print(G)
print("\n--- Discriminator Architecture ---")
print(D)

del G
del D
torch.cuda.empty_cache()


--- Generator Architecture ---
ImageGenerator(
  (class_embedding_layer): SpectralNormalization(
    (module): Linear(in_features=1, out_features=128, bias=False)
  )
  (initial_dense_layer): SpectralNormalization(
    (module): Linear(in_features=20, out_features=65536, bias=True)
  )
  (gen_blocks): ModuleList(
    (0): GenResBlock(
      (conv1): SpectralNormalization(
        (module): Conv2d(1024, 1024, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      )
      (conv2): SpectralNormalization(
        (module): Conv2d(1024, 1024, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      )
      (conv3): SpectralNormalization(
        (module): Conv2d(1024, 1024, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      )
      (norm1): AdaptiveNormalization(
        (batch_norm): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=False, track_running_stats=True)
        (embedding_layer): Linear(in_features=148, out_features=2048, bias=True)
      )
      (norm2): AdaptiveNorma

### **Training the model**

In [12]:
imsize = 128
z_dim = 120
chn = 64 # Base channels for Generator and Discriminator
lambda_gp = 10.0 # Gradient penalty weight
version = 'Gan_flower102' # Version string for output directories
total_step = 200000 # Total training steps
d_iters = 5 # Discriminator iterations per generator iteration
batch_size = 32
num_workers = 2
g_lr = 0.0001 # Generator learning rate
d_lr = 0.0004 # Discriminator learning rate
beta1 = 0.0 # Adam beta1
beta2 = 0.9 # Adam beta2
seed = 42 # Random seed
image_path = '/content/flowers' # Path to extracted flower images
log_path = '/content/logs' # Path for TensorBoard logs
model_save_path = '/content/models' # Path for saving model
sample_path = '/content/samples' # Path for saving generated samples
attn_path = '/content/attn' # Path for attention maps (not used in this version)

log_step = 100 # Frequency for logging training progress
sample_step = 5000 # Frequency for saving generated samples

pretrained_gen_path = '100000_G.pth'
pretrained_disc_path = '100000_D.pth'

#pretrained_gen_path = None
#pretrained_disc_path = None

# --- Global Setup and Directory Creation ---
cudnn.benchmark = True # Optimize cuDNN for faster training
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
np.random.seed(seed)
random.seed(seed)
torch.backends.cudnn.deterministic = True # Ensure reproducibility

# Create output directories
model_output_dir = create_output_directory(model_save_path, version)
sample_output_dir = create_output_directory(sample_path, version)
log_output_dir = create_output_directory(log_path, version)

# Determine device for training
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Training on device: {device}")


# --- Data Loader Initialization ---
# Define image transformations
transform_list = [
    transforms.CenterCrop(160), # Crop to a square
    transforms.Resize((imsize, imsize)), # Resize to target size
    transforms.ToTensor(), # Convert to tensor
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)) # Normalize to [-1, 1]
]
image_transforms = transforms.Compose(transform_list)

# Load dataset from image folder structure
dataset = dsets.ImageFolder(image_path, transform=image_transforms)
num_classes = len(dataset.classes) # Get number of classes

# Create data loader
data_loader = torch.utils.data.DataLoader(
    dataset=dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=num_workers,
    drop_last=True # Drop the last incomplete batch
)


# --- Model and Optimizer Initialization---
generator = ImageGenerator(
    latent_dim=z_dim,
    num_classes=num_classes,
    base_channels=chn
).to(device)

discriminator = ImageDiscriminator(
    num_classes=num_classes,
    base_channels=chn
).to(device)

gen_optimizer = torch.optim.Adam(
    filter(lambda p: p.requires_grad, generator.parameters()),
    lr=g_lr, betas=[beta1, beta2]
)
disc_optimizer = torch.optim.Adam(
    filter(lambda p: p.requires_grad, discriminator.parameters()),
    lr=d_lr, betas=[beta1, beta2]
)

# --- Load Pre-trained Models ---
start_step = 0
if pretrained_gen_path:
    gen_load_path = os.path.join(model_output_dir, pretrained_gen_path)
    if os.path.exists(gen_load_path):
        generator.load_state_dict(torch.load(gen_load_path, map_location=device))
        match = re.search(r'(\d+)_G\.pth', pretrained_gen_path)
        start_step = int(match.group(1))

if pretrained_disc_path:
    disc_load_path = os.path.join(model_output_dir, pretrained_disc_path)
    if os.path.exists(disc_load_path):
        discriminator.load_state_dict(torch.load(disc_load_path, map_location=device))

# --- TensorBoard Setup ---
tf_log_path = os.path.join(log_output_dir, 'tensorboard_logs')
summary_writer = SummaryWriter(log_dir=tf_log_path)

# --- Helper Functions for Training Loop ---
def zero_grad_optimizers():
    disc_optimizer.zero_grad()
    gen_optimizer.zero_grad()

def generate_random_labels(batch_size_val, num_classes_val, device_val):
    labels_int = torch.randint(0, num_classes_val, (batch_size_val,)).to(device_val)
    labels_one_hot = F.one_hot(labels_int, num_classes=num_classes_val).float().to(device_val)
    return labels_int, labels_one_hot

def save_real_image_sample_func(data_loader_val, sample_output_dir_val):
    real_images, _ = next(iter(data_loader_val))
    save_image(denormalize_image(real_images), os.path.join(sample_output_dir_val, 'real_images_sample.png'))

# Save a sample of real images
save_real_image_sample_func(data_loader, sample_output_dir)


# --- Training Loop ---
data_iterator = iter(data_loader)

# Fixed latent code and labels for consistent sample generation
fixed_latent_z = move_tensor_to_device(torch.randn(batch_size, z_dim), device)
fixed_labels_int, fixed_labels_one_hot = generate_random_labels(batch_size, num_classes, device)

start_time = time.time()

# History lists for plotting losses
d_total_loss_history = []
g_total_loss_history = []
steps_history = []

print('Starting GAN training...')
for step in range(start_step, total_step):
    generator.train()
    discriminator.train()

    # --- Train Discriminator ---
    # Fetch real images and labels
    try:
        real_images, real_labels_int = next(data_iterator)
    except StopIteration:
        # Reset iterator if end of dataset is reached
        data_iterator = iter(data_loader)
        real_images, real_labels_int = next(data_iterator)

    real_images = move_tensor_to_device(real_images, device)
    real_labels_int = move_tensor_to_device(real_labels_int, device)

    # Discriminator output for real images
    d_out_real = discriminator(real_images, real_labels_int)
    d_loss_real = - torch.mean(d_out_real) # Maximize D(real)

    # Generate fake images
    z_latent = move_tensor_to_device(torch.randn(batch_size, z_dim), device)
    fake_labels_int, fake_labels_one_hot = generate_random_labels(batch_size, num_classes, device)
    fake_images = generator(z_latent, fake_labels_one_hot)

    # Discriminator output for fake images
    d_out_fake = discriminator(fake_images.detach(), fake_labels_int) # Detach fake_images to prevent G from being updated
    d_loss_fake = d_out_fake.mean() # Minimize D(fake)

    # Calculate Gradient Penalty
    # Interpolate between real and fake images
    alpha = torch.rand(real_images.size(0), 1, 1, 1, device=device)
    alpha = alpha.expand_as(real_images)
    interpolated_images = (alpha * real_images.data + (1 - alpha) * fake_images.data).requires_grad_(True)

    # Discriminator output for interpolated images
    d_out_interpolated = discriminator(interpolated_images, real_labels_int)

    # Compute gradients of D_out_interpolated with respect to interpolated_images
    gradients = torch.autograd.grad(
        outputs=d_out_interpolated,
        inputs=interpolated_images,
        grad_outputs=torch.ones_like(d_out_interpolated, device=device),
        retain_graph=True,
        create_graph=True,
        only_inputs=True
    )[0]

    # Calculate gradient norm and penalty
    gradients = gradients.view(gradients.size(0), -1)
    grad_norm = gradients.norm(2, dim=1)
    d_loss_gp = torch.mean((grad_norm - 1) ** 2)

    # Total Discriminator Loss
    total_d_loss = d_loss_real + d_loss_fake + lambda_gp * d_loss_gp

    # Backward pass and optimize Discriminator
    zero_grad_optimizers()
    total_d_loss.backward()
    disc_optimizer.step()

    # --- Train Generator (every d_iters steps) ---
    if (step + 1) % d_iters == 0:
        # Generate new fake images
        z_latent = move_tensor_to_device(torch.randn(batch_size, z_dim), device)
        fake_labels_int, fake_labels_one_hot = generate_random_labels(batch_size, num_classes, device)

        fake_images = generator(z_latent, fake_labels_one_hot)
        g_out_fake = discriminator(fake_images, fake_labels_int) # D(G(z))

        gen_loss = - g_out_fake.mean() # Maximize D(G(z))

        # Backward pass and optimize Generator
        zero_grad_optimizers()
        gen_loss.backward()
        gen_optimizer.step()

        # --- Logging and Monitoring ---
        if (step + 1) % log_step == 0:
            elapsed_time = time.time() - start_time
            elapsed_time_str = str(datetime.timedelta(seconds=elapsed_time))
            log_message = (
                f"Step [{step + 1}/{total_step}], "
                f"D_loss_real: {d_loss_real.item():.4f}, D_loss_fake: {d_loss_fake.item():.4f}, "
            )
            log_message += f"D_loss_gp: {d_loss_gp.item():.4f}, "
            log_message += f"G_loss: {gen_loss.item():.4f}"
            print(log_message)

            steps_history.append(step)
            d_total_loss_history.append(total_d_loss.item())
            g_total_loss_history.append(gen_loss.item())

            # Log to TensorBoard
            summary_writer.add_scalar('Loss/D_real', d_loss_real.item(), (step + 1))
            summary_writer.add_scalar('Loss/D_fake', d_loss_fake.item(), (step + 1))
            summary_writer.add_scalar('Loss/D_total', total_d_loss.item(), (step + 1))
            summary_writer.add_scalar('Loss/D_gp', d_loss_gp.item(), (step + 1))
            summary_writer.add_scalar('Loss/G_total', gen_loss.item(), (step + 1))

        # --- Save Sample Images ---
        if (step + 1) % sample_step == 0:
            generator.eval() # Set generator to evaluation mode
            with torch.no_grad(): # Disable gradient calculation
                generated_samples = generator(fixed_latent_z, fixed_labels_one_hot)
            save_image(denormalize_image(generated_samples.data),
                       os.path.join(sample_output_dir, f'{step + 1}_generated.png'))
            generator.train() # Set generator back to training mode

# Save final models after training completes
torch.save(generator.state_dict(), os.path.join(model_output_dir, f'{total_step}_G.pth'))
torch.save(discriminator.state_dict(), os.path.join(model_output_dir, f'{total_step}_D.pth'))


# Close TensorBoard writer
summary_writer.close()




Training on device: cuda
Starting GAN training...
Step [100100/200000], D_loss_real: -3502.6516, D_loss_fake: 3502.4717, D_loss_gp: 0.0048, G_loss: -3475.3096
Step [100200/200000], D_loss_real: -3225.1311, D_loss_fake: 3222.5388, D_loss_gp: 0.0195, G_loss: -3300.8660
Step [100300/200000], D_loss_real: -2534.8735, D_loss_fake: 2531.3931, D_loss_gp: 0.0382, G_loss: -2533.1965
Step [100400/200000], D_loss_real: -1940.4595, D_loss_fake: 1944.1321, D_loss_gp: 0.0533, G_loss: -1974.2791
Step [100500/200000], D_loss_real: -1117.1329, D_loss_fake: 1116.0376, D_loss_gp: 0.0248, G_loss: -1168.6519
Step [100600/200000], D_loss_real: -2898.5308, D_loss_fake: 2896.0227, D_loss_gp: 0.0167, G_loss: -2999.8657
Step [100700/200000], D_loss_real: -2588.0010, D_loss_fake: 2585.8208, D_loss_gp: 0.0111, G_loss: -2580.2551
Step [100800/200000], D_loss_real: -2103.0059, D_loss_fake: 2103.6592, D_loss_gp: 0.0257, G_loss: -2214.3633
Step [100900/200000], D_loss_real: -3892.9695, D_loss_fake: 3895.2832, D_loss_

KeyboardInterrupt: 

### **Plot Generator and Discriminator losses**

In [None]:
SAMPLE_OUTPUT_DIR = os.path.join('/content/samples', 'Gan_flower102')

# --- Plot 1: Discriminator Total Loss ---
plot_path_d = os.path.join(SAMPLE_OUTPUT_DIR, 'discriminator_total_loss.png')

plt.figure(figsize=(10, 5))
plt.plot(steps_history, d_total_loss_history, label='Discriminator Total Loss', color='blue', alpha=0.8)
plt.xlabel('Training Step')
plt.ylabel('Loss')
plt.title('Discriminator Total Loss Over Time')
plt.legend()
plt.grid(True)
plt.tight_layout()
plt.savefig(plot_path_d)
plt.show()
print(f"Discriminator loss plot saved to: {plot_path_d}")

# --- Plot 2: Generator Loss ---
plot_path_g = os.path.join(SAMPLE_OUTPUT_DIR, 'generator_loss.png')

plt.figure(figsize=(10, 5))
plt.plot(steps_history, g_total_loss_history, label='Generator Loss', color='red', alpha=0.8)
plt.xlabel('Training Step')
plt.ylabel('Loss')
plt.title('Generator Loss Over Time')
plt.legend()
plt.grid(True)
plt.tight_layout()
plt.savefig(plot_path_g)
plt.show()
print(f"Generator loss plot saved to: {plot_path_g}")



### **Animation of Training Progress**

In [None]:
def create_animation(sample_dir, output_gif_name="training_progress.gif", fps=10):
    image_files = []
    sorted_files = sorted(
        [f for f in os.listdir(sample_dir) if f.endswith('_generated.png')],
        key=lambda x: int(re.findall(r'(\d+)_generated\.png', x)[0]) if re.findall(r'(\d+)_generated\.png', x) else 0
    )

    for f in sorted_files:
        image_files.append(os.path.join(sample_dir, f))

    images_for_gif = []
    for image_file in image_files:
      img = Image.open(image_file).convert('RGB')
      images_for_gif.append(np.array(img))

    fps = 10

    output_path = os.path.join(sample_dir, output_gif_name)

    imageio.mimsave(output_path, images_for_gif, fps=fps)

    fig, ax = plt.subplots()
    im = ax.imshow(images_for_gif[0])
    ax.axis('off')
    plt.title('GAN Training Progress Animation')

    def update(frame):
        im.set_array(images_for_gif[frame])
        return [im]

    ani = animation.FuncAnimation(
        fig, update, frames=len(images_for_gif),
        interval=1000 / fps, blit=True, repeat=False
    )

    plt.show(block=False)

create_animation(SAMPLE_OUTPUT_DIR)


### **Plot Real Images vs Fake Images**

In [None]:
def display_real_vs_fake(real_image, generated_image):
    real_img = Image.open(real_image)
    final_fake_img = Image.open(generated_image)

    fig, axes = plt.subplots(1, 2, figsize=(12, 6))

    axes[0].imshow(real_img)
    axes[0].set_title('Sample Real Images')
    axes[0].axis('off')

    axes[1].imshow(final_fake_img)
    axes[1].set_title('Final Generated Images')
    axes[1].axis('off')

    plt.suptitle('Real vs. Generated Images After Training', fontsize=16)
    plt.tight_layout(rect=[0, 0.03, 1, 0.95])
    output_comparison_path = os.path.join(SAMPLE_OUTPUT_DIR, 'real_vs_fake_comparison.png')
    plt.savefig(output_comparison_path)
    plt.show()

final_step_str = str(100000)
generated_image_sample = os.path.join(SAMPLE_OUTPUT_DIR, f'{final_step_str}_generated.png')
real_image_sample = os.path.join(SAMPLE_OUTPUT_DIR, 'real_images_sample.png')

display_real_vs_fake(real_image=real_image_sample, generated_image=generated_image_sample)
