# SAGAN (Self-Attention GAN) Implementation
This notebook implements a Self-Attention GAN (SAGAN) for image generation on the CIFAR-10 dataset. Every section is thoroughly commented for clarity, so the notebook can serve as a reference for future GAN-based projects.

In [1]:
# Import PyTorch and its submodules for building models and data loading
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms, utils  # For dataset and image utilities
import matplotlib.pyplot as plt  # For visualizing generated images
import os  # For file and directory handling


In [2]:
# Preprocessing pipeline applied to every image, in order:
#   1. resize to 32 pixels
#   2. convert to a tensor
#   3. normalize each of the three channels with mean 0.5 and std 0.5,
#      which maps ToTensor's [0, 1] output onto [-1, 1]
preprocess_steps = [
    transforms.Resize(32),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
]
transform = transforms.Compose(preprocess_steps)

# CIFAR-10 training split, downloaded into ./data if it is not already there
dataset = datasets.CIFAR10(
    root="./data",
    train=True,
    download=True,
    transform=transform,
)

# Shuffled mini-batches of 128 images for the training loop
dataloader = torch.utils.data.DataLoader(dataset, shuffle=True, batch_size=128)

print("Dataset loaded successfully!")  # Signals that the data pipeline is ready


Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz


100%|██████████| 170M/170M [00:04<00:00, 35.2MB/s]


Extracting ./data/cifar-10-python.tar.gz to ./data
Dataset loaded successfully!


In [3]:
# Step 3: Define the Self-Attention Layer
class SelfAttention(nn.Module):
    """Self-attention over the spatial positions of a feature map.

    Query and key are 1x1 convolutions that project the input down to
    ``in_channels // 8`` channels; value is a 1x1 convolution that keeps all
    ``in_channels``. The attended features are added back to the input,
    scaled by the learnable scalar ``gamma``. ``gamma`` starts at zero, so a
    freshly constructed layer returns its input unchanged.

    Args:
        in_channels: number of channels of the incoming feature map.
    """

    def __init__(self, in_channels):
        super().__init__()
        reduced_channels = in_channels // 8  # narrower space for the similarity scores

        self.query = nn.Conv2d(in_channels, reduced_channels, kernel_size=1)
        self.key = nn.Conv2d(in_channels, reduced_channels, kernel_size=1)
        self.value = nn.Conv2d(in_channels, in_channels, kernel_size=1)

        # Learnable weight of the attention branch in the residual sum
        self.gamma = nn.Parameter(torch.zeros(1))

    def forward(self, x):
        """Return ``gamma * attention(x) + x`` with the same shape as ``x`` (B, C, H, W)."""
        n, c, h, w = x.size()
        positions = h * w  # number of spatial locations attending to each other

        # Flatten the spatial grid so each location becomes one token
        q = self.query(x).view(n, -1, positions).transpose(1, 2)  # (B, HW, C//8)
        k = self.key(x).view(n, -1, positions)                    # (B, C//8, HW)
        v = self.value(x).view(n, -1, positions)                  # (B, C, HW)

        # Row i holds the weights location i assigns to every location (rows sum to 1)
        scores = torch.bmm(q, k)                                  # (B, HW, HW)
        attention = scores.softmax(dim=-1)

        # Mix the value vectors with those weights, then restore the spatial layout
        attended = torch.bmm(v, attention.transpose(1, 2))        # (B, C, HW)
        attended = attended.view(n, c, h, w)

        return self.gamma * attended + x


In [4]:
# Step 4: Define the Generator
class Generator(nn.Module):
    """Maps a noise vector to a 32x32 image with values in [-1, 1].

    A linear layer expands the noise into a 4x4 feature map, which is then
    doubled in resolution three times (4 -> 8 -> 16 -> 32). A self-attention
    layer sits after the first upsampling stage, and a final Tanh bounds the
    output.

    Args:
        z_dim: length of the input noise vector.
        img_channels: number of channels of the generated image.
        feature_maps: base channel width; the stages use 8x, 4x and 2x this value.
    """

    def __init__(self, z_dim, img_channels, feature_maps):
        super().__init__()

        # Spatial side length of the feature map before any upsampling
        self.init_size = 4

        # Channel widths of the three stages, widest first
        wide = feature_maps * 8
        mid = feature_maps * 4
        narrow = feature_maps * 2

        # Projects the noise vector onto a (wide, 4, 4) block of features
        self.fc = nn.Sequential(
            nn.Linear(z_dim, wide * self.init_size * self.init_size)
        )

        stages = [
            nn.BatchNorm2d(wide),  # normalize the projected features

            # Stage 1: 4x4 -> 8x8
            nn.Upsample(scale_factor=2),
            nn.Conv2d(wide, mid, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(mid),
            nn.ReLU(inplace=True),

            # Self-attention on the 8x8 map
            SelfAttention(mid),

            # Stage 2: 8x8 -> 16x16
            nn.Upsample(scale_factor=2),
            nn.Conv2d(mid, narrow, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(narrow),
            nn.ReLU(inplace=True),

            # Stage 3: 16x16 -> 32x32, projecting down to image channels
            nn.Upsample(scale_factor=2),
            nn.Conv2d(narrow, img_channels, kernel_size=3, stride=1, padding=1),
            nn.Tanh(),  # squash into [-1, 1]
        ]
        self.gen = nn.Sequential(*stages)

    def forward(self, z):
        """Generate a batch of images from noise ``z`` of shape (batch, z_dim)."""
        side = self.init_size
        seed_maps = self.fc(z)
        # Reinterpret the flat projection as (batch, channels, 4, 4)
        seed_maps = seed_maps.view(z.size(0), -1, side, side)
        return self.gen(seed_maps)


In [5]:
# Step 5: Define the Discriminator
class Discriminator(nn.Module):
    """Scores images with a value in [0, 1] (trained so that 1 means real).

    Three stride-2 convolutions halve the resolution each time
    (32 -> 16 -> 8 -> 4), with a self-attention layer on the 8x8 map. The
    final linear layer is sized for a 4x4 map, so inputs must be 32x32.

    Args:
        img_channels: number of channels of the input images.
        feature_maps: base channel width; the stages use 1x, 2x and 4x this value.
    """

    def __init__(self, img_channels, feature_maps):
        super().__init__()

        # Channel widths of the three downsampling stages
        c1 = feature_maps
        c2 = feature_maps * 2
        c3 = feature_maps * 4

        def downsample(in_c, out_c):
            # 4x4 convolution with stride 2 and padding 1: halves height and width
            return nn.Conv2d(in_c, out_c, kernel_size=4, stride=2, padding=1)

        blocks = [
            # 32x32 -> 16x16 (no batch norm on the first stage)
            downsample(img_channels, c1),
            nn.LeakyReLU(0.2, inplace=True),  # small negative slope keeps gradients alive

            # 16x16 -> 8x8
            downsample(c1, c2),
            nn.BatchNorm2d(c2),
            nn.LeakyReLU(0.2, inplace=True),

            # Self-attention on the 8x8 map
            SelfAttention(c2),

            # 8x8 -> 4x4
            downsample(c2, c3),
            nn.BatchNorm2d(c3),
            nn.LeakyReLU(0.2, inplace=True),
        ]
        self.disc = nn.Sequential(*blocks)

        # Classifier head: flattened (c3, 4, 4) features -> one probability
        self.fc = nn.Sequential(
            nn.Linear(c3 * 4 * 4, 1),
            nn.Sigmoid(),  # squash the score into [0, 1]
        )

    def forward(self, img):
        """Return a (batch, 1) tensor of scores for the batch ``img``."""
        conv_out = self.disc(img)
        flat = conv_out.view(img.size(0), -1)  # one feature row per image
        return self.fc(flat)


In [None]:
# Step 6: Training Loop

# Train on the GPU when one is available, otherwise on the CPU. Every model
# and tensor below is created on / moved to this single device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Initialize models
z_dim = 100  # Dimensionality of the noise vector fed to the generator
gen = Generator(z_dim, img_channels=3, feature_maps=64).to(device)  # Generator model
disc = Discriminator(img_channels=3, feature_maps=64).to(device)  # Discriminator model

# Define optimizers for both networks
lr = 0.0002  # Learning rate for Adam optimizer
optim_gen = optim.Adam(gen.parameters(), lr=lr, betas=(0.5, 0.999))  # Optimizer for generator
optim_disc = optim.Adam(disc.parameters(), lr=lr, betas=(0.5, 0.999))  # Optimizer for discriminator

# Binary cross-entropy loss; the discriminator already ends in a Sigmoid
criterion = nn.BCELoss()

# Training loop parameters
epochs = 50  # Number of passes over the training set

# Training loop
for epoch in range(epochs):
    for real, _ in dataloader:
        real = real.to(device)  # The DataLoader yields CPU tensors
        batch_size = real.size(0)  # The last batch of an epoch may be smaller

        # Targets for the BCE loss: 1 = real, 0 = fake
        real_labels = torch.ones(batch_size, 1, device=device)
        fake_labels = torch.zeros(batch_size, 1, device=device)

        # ---- Train discriminator ----
        z = torch.randn(batch_size, z_dim, device=device)  # Random noise vectors
        # detach(): the discriminator step must not backpropagate into the generator
        fake = gen(z).detach()

        disc_loss_real = criterion(disc(real), real_labels)  # Loss for real images
        disc_loss_fake = criterion(disc(fake), fake_labels)  # Loss for fake images
        disc_loss = disc_loss_real + disc_loss_fake  # Total discriminator loss

        optim_disc.zero_grad()  # Clear gradients left over from the generator step
        disc_loss.backward()  # Backpropagate discriminator loss
        optim_disc.step()  # Update discriminator weights

        # ---- Train generator ----
        z = torch.randn(batch_size, z_dim, device=device)  # Fresh noise vectors
        fake = gen(z)  # Keep the graph this time so gradients reach the generator
        # The generator is rewarded when the discriminator labels its output as real
        gen_loss = criterion(disc(fake), real_labels)

        optim_gen.zero_grad()  # Zero out gradients for generator
        gen_loss.backward()  # Backpropagate generator loss
        optim_gen.step()  # Update generator weights

    # Print progress for each epoch (losses are those of the epoch's final batch)
    print(f"Epoch [{epoch+1}/{epochs}], Disc Loss: {disc_loss.item():.4f}, Gen Loss: {gen_loss.item():.4f}")


Epoch [1/50], Disc Loss: 1.1468, Gen Loss: 1.4797
Epoch [2/50], Disc Loss: 0.9381, Gen Loss: 1.2716
Epoch [3/50], Disc Loss: 0.9619, Gen Loss: 1.9820
Epoch [4/50], Disc Loss: 0.8605, Gen Loss: 1.8880
Epoch [5/50], Disc Loss: 0.7269, Gen Loss: 1.3613


In [None]:
# Step 7 : Generate and visualize images
# Draw the noise on whichever device the generator's weights live on, so this
# cell runs on CPU-only machines as well as on GPUs.
gen_device = next(gen.parameters()).device
z = torch.randn(16, z_dim, device=gen_device)

# No gradients are needed for sampling; move the result to the CPU for plotting
with torch.no_grad():
    samples = gen(z).cpu()
samples = (samples + 1) / 2  # Rescale from the generator's [-1, 1] range to [0, 1]

# Create a 4x4 grid of images
grid = utils.make_grid(samples, nrow=4)
plt.imshow(grid.permute(1, 2, 0))  # Convert (C, H, W) to (H, W, C) for matplotlib
plt.axis("off")
plt.show()
