In [2]:
"""
GAN Image Generation System for MNIST/Fashion-MNIST
Step-by-step implementation with debugging
"""

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torchvision.utils import save_image, make_grid
import numpy as np
import matplotlib.pyplot as plt
import os
from datetime import datetime

# ============================================================================
# STEP 1: Configuration and Input Parameters
# ============================================================================

class Config:
    """Configuration class to store all hyperparameters.

    The user-tunable fields start as ``None`` and are filled in
    interactively by :meth:`get_user_input`. ``img_size`` and ``channels``
    are fixed to 28x28 single-channel images, and ``device`` is CUDA when
    ``torch.cuda.is_available()`` reports it, otherwise CPU.
    """

    def __init__(self):
        self.dataset_choice = None   # 'mnist' or 'fashion' once set
        self.epochs = None
        self.batch_size = None
        self.noise_dim = None
        self.learning_rate = None
        self.save_interval = None    # save a sample grid every N epochs
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.img_size = 28
        self.channels = 1

    @staticmethod
    def _prompt_positive(prompt, cast, positive_msg, invalid_msg):
        """Prompt until the user enters a finite value greater than zero.

        Args:
            prompt: text passed to ``input``.
            cast: converter applied to the raw reply (``int`` or ``float``).
            positive_msg: printed when the parsed value is <= 0.
            invalid_msg: printed when ``cast`` raises ValueError or the
                parsed float is not finite (``float`` accepts 'inf'/'nan').

        Returns:
            The first accepted value, of the type produced by ``cast``.
        """
        import math

        while True:
            try:
                value = cast(input(prompt))
            except ValueError:
                print(invalid_msg)
                continue
            # Only floats can be inf/nan; math.isfinite would overflow on huge ints.
            if isinstance(value, float) and not math.isfinite(value):
                print(invalid_msg)
            elif value > 0:
                return value
            else:
                print(positive_msg)

    def get_user_input(self):
        """Interactively fill in every hyperparameter, re-prompting on bad input.

        Afterwards prints the selected device and the configuration summary.
        """
        print("=" * 60)
        print("GAN IMAGE GENERATION SYSTEM")
        print("=" * 60)

        # Dataset choice
        while True:
            choice = input("\nChoose dataset (mnist/fashion): ").lower().strip()
            if choice in ['mnist', 'fashion']:
                self.dataset_choice = choice
                break
            print("Invalid choice. Please enter 'mnist' or 'fashion'")

        bad_int = "Please enter a valid integer"

        self.epochs = self._prompt_positive(
            "Enter number of epochs (recommended 30-100): ", int,
            "Epochs must be positive", bad_int)
        self.batch_size = self._prompt_positive(
            "Enter batch size (recommended 64 or 128): ", int,
            "Batch size must be positive", bad_int)
        self.noise_dim = self._prompt_positive(
            "Enter noise dimension (recommended 50 or 100): ", int,
            "Noise dimension must be positive", bad_int)
        self.learning_rate = self._prompt_positive(
            "Enter learning rate (recommended 0.0002): ", float,
            "Learning rate must be positive", "Please enter a valid number")
        self.save_interval = self._prompt_positive(
            "Enter save interval (e.g., 5 for every 5 epochs): ", int,
            "Save interval must be positive", bad_int)

        print(f"\n✓ Configuration complete. Using device: {self.device}")
        self.print_config()

    def print_config(self):
        """Print the current configuration; fields not yet set print as None."""
        # Guard the .upper() call so printing works before get_user_input().
        dataset = self.dataset_choice.upper() if self.dataset_choice is not None else None
        print("\n" + "=" * 60)
        print("CONFIGURATION SUMMARY")
        print("=" * 60)
        print(f"Dataset: {dataset}")
        print(f"Epochs: {self.epochs}")
        print(f"Batch Size: {self.batch_size}")
        print(f"Noise Dimension: {self.noise_dim}")
        print(f"Learning Rate: {self.learning_rate}")
        print(f"Save Interval: {self.save_interval}")
        print(f"Device: {self.device}")
        print("=" * 60 + "\n")


# ============================================================================
# STEP 2: Generator Network
# ============================================================================

class Generator(nn.Module):
    """
    Map latent noise vectors to images.

    ``fc`` projects each noise vector to a 128-channel feature map whose
    side is ``img_shape[1] // 4``. ``conv_blocks`` doubles that side twice
    (7 -> 14 -> 28 for MNIST-sized images) and then applies a final
    convolution to ``img_shape[0]`` channels followed by Tanh, so every
    pixel lies in [-1, 1].

    Input:  z of shape (batch_size, noise_dim)
    Output: images of shape (batch_size, img_shape[0], 4 * init_size,
            4 * init_size), which equals img_shape for square images
            whose side is a multiple of 4
    """

    def __init__(self, noise_dim, img_shape):
        super().__init__()
        self.img_shape = img_shape
        # Side length of the projected map; the two x2 upsamplings restore it.
        self.init_size = img_shape[1] // 4
        width = 128

        self.fc = nn.Sequential(
            nn.Linear(noise_dim, width * self.init_size * self.init_size),
        )

        # BatchNorm2d's second positional parameter is eps; it is spelled
        # out by keyword here so the 0.8 is not mistaken for momentum.
        layers = [
            nn.BatchNorm2d(width),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(width, width, 3, stride=1, padding=1),
            nn.BatchNorm2d(width, eps=0.8),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(width, 64, 3, stride=1, padding=1),
            nn.BatchNorm2d(64, eps=0.8),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, img_shape[0], 3, stride=1, padding=1),
            nn.Tanh(),
        ]
        self.conv_blocks = nn.Sequential(*layers)

    def forward(self, z):
        """Turn noise of shape (batch_size, noise_dim) into a batch of images."""
        seed_map = self.fc(z)
        seed_map = seed_map.view(seed_map.size(0), 128, self.init_size, self.init_size)
        return self.conv_blocks(seed_map)


# ============================================================================
# STEP 3: Discriminator Network
# ============================================================================

class Discriminator(nn.Module):
    """
    Discriminator Network: Classifies images as real or fake
    Input: images of shape (batch_size, *img_shape), e.g. (batch_size, 1, 28, 28)
    Output: probability of being real, shape (batch_size, 1)

    Four 3x3 convolution stages (the first three with stride 2, the last
    with stride 1) with LeakyReLU and Dropout2d reduce the image to a
    256-channel map; BatchNorm follows every conv except the first. A
    linear layer plus Sigmoid turns that map into one probability per image.
    """
    def __init__(self, img_shape):
        super(Discriminator, self).__init__()

        # NOTE(review): BatchNorm2d's second positional argument is ``eps``,
        # so ``BatchNorm2d(C, 0.8)`` sets eps=0.8 (not momentum). Kept as-is
        # to preserve the existing model; confirm this was intended.
        self.model = nn.Sequential(
            nn.Conv2d(img_shape[0], 32, 3, 2, 1),  # 28 -> 14
            nn.LeakyReLU(0.2, inplace=True),
            nn.Dropout2d(0.25),
            nn.Conv2d(32, 64, 3, 2, 1),  # 14 -> 7
            nn.ZeroPad2d((0, 1, 0, 1)),  # 7 -> 8 (one extra right column / bottom row)
            nn.BatchNorm2d(64, 0.8),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Dropout2d(0.25),
            nn.Conv2d(64, 128, 3, 2, 1),  # 8 -> 4
            nn.BatchNorm2d(128, 0.8),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Dropout2d(0.25),
            nn.Conv2d(128, 256, 3, 1, 1),  # 4 -> 4
            nn.BatchNorm2d(256, 0.8),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Dropout2d(0.25),
        )

        # Spatial size after self.model, derived from the layer arithmetic
        # instead of hard-coding 4x4 (which only holds for 28x28 inputs).
        out_h = self._feature_size(img_shape[1])
        out_w = self._feature_size(img_shape[2])
        self.adv_layer = nn.Sequential(
            nn.Linear(256 * out_h * out_w, 1),
            nn.Sigmoid()
        )

    @staticmethod
    def _feature_size(size):
        """Return the height (or width) that ``self.model`` produces from *size*."""
        def conv(n, stride):
            # Every conv above uses kernel_size=3, padding=1:
            # floor((n + 2*1 - 3) / stride) + 1
            return (n + 2 - 3) // stride + 1

        size = conv(size, 2)
        size = conv(size, 2) + 1  # ZeroPad2d adds one row / one column
        size = conv(size, 2)
        return conv(size, 1)

    def forward(self, img):
        """
        Forward pass
        img: images (batch_size, *img_shape)
        returns: probability of being real (batch_size, 1)
        """
        out = self.model(img)
        out = out.view(out.shape[0], -1)
        validity = self.adv_layer(out)
        return validity


# ============================================================================
# STEP 4: Pre-trained Classifier for Label Prediction
# ============================================================================

class Classifier(nn.Module):
    """Compact CNN that maps single-channel 28x28 images to class logits.

    Two stages of 3x3 conv (padding 1) + ReLU + 2x2 max-pool shrink the
    input from 28x28 to 7x7 with 64 channels. A dropout-regularized
    two-layer MLP then produces ``num_classes`` unnormalized scores per image.
    """

    def __init__(self, num_classes=10):
        super().__init__()
        self.features = nn.Sequential(
            *self._conv_stage(1, 32),   # 28x28 -> 14x14
            *self._conv_stage(32, 64),  # 14x14 -> 7x7
        )
        hidden = 128
        self.classifier = nn.Sequential(
            nn.Dropout(0.5),
            nn.Linear(64 * 7 * 7, hidden),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(hidden, num_classes),
        )

    @staticmethod
    def _conv_stage(in_ch, out_ch):
        """Return [3x3 conv with padding 1, ReLU, 2x2 max-pool] for one stage."""
        return [nn.Conv2d(in_ch, out_ch, 3, 1, 1), nn.ReLU(), nn.MaxPool2d(2, 2)]

    def forward(self, x):
        """Return logits of shape (batch_size, num_classes) for the images in x."""
        flat = torch.flatten(self.features(x), start_dim=1)
        return self.classifier(flat)


# ============================================================================
# STEP 5: Data Loading
# ============================================================================

def load_dataset(config, root='./data', num_workers=2):
    """Load the training split of the chosen dataset and wrap it in a DataLoader.

    Pixels are converted with ``ToTensor()`` and then ``Normalize([0.5], [0.5])``,
    mapping [0, 1] to [-1, 1] to match the Generator's Tanh output range.

    Args:
        config: object with ``dataset_choice`` ('mnist' or 'fashion') and
            ``batch_size``.
        root: download/cache directory (default './data').
        num_workers: DataLoader worker processes (default 2).

    Returns:
        (dataloader, dataset): a shuffled DataLoader that drops the last
        incomplete batch, and the underlying torchvision dataset.

    Raises:
        ValueError: if ``config.dataset_choice`` is not 'mnist' or 'fashion'
            (previously any other value silently loaded Fashion-MNIST).
    """
    choice = config.dataset_choice
    if choice not in ('mnist', 'fashion'):
        raise ValueError(
            f"Unknown dataset_choice {choice!r}; expected 'mnist' or 'fashion'"
        )

    print("\n" + "=" * 60)
    print("LOADING DATASET")
    print("=" * 60)

    # Normalize to [-1, 1] to match Generator output
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5])
    ])

    if choice == 'mnist':
        dataset_cls, display_name = datasets.MNIST, "MNIST"
    else:
        dataset_cls, display_name = datasets.FashionMNIST, "Fashion-MNIST"

    dataset = dataset_cls(root=root, train=True, download=True, transform=transform)
    print(f"✓ {display_name} dataset loaded")

    dataloader = DataLoader(
        dataset,
        batch_size=config.batch_size,
        shuffle=True,
        num_workers=num_workers,
        drop_last=True
    )

    print(f"✓ Total samples: {len(dataset)}")
    print(f"✓ Batches per epoch: {len(dataloader)}")
    print("=" * 60 + "\n")

    return dataloader, dataset


# ============================================================================
# STEP 6: Train Classifier (for label prediction)
# ============================================================================

def train_classifier(dataloader, config, num_epochs=5, lr=0.001):
    """Train a Classifier on labelled batches, for labelling generated images later.

    Args:
        dataloader: sized iterable of (images, labels) batches accepted by
            ``Classifier`` (single-channel 28x28 images).
        config: object providing ``device``.
        num_epochs: passes over ``dataloader`` (default 5, the previous
            hard-coded value).
        lr: Adam learning rate (default 0.001, the previous hard-coded value).

    Returns:
        The trained Classifier, switched to eval mode for inference.

    Raises:
        ValueError: if ``dataloader`` has no batches. The per-epoch averages
            would otherwise divide by zero.
    """
    num_batches = len(dataloader)
    if num_batches == 0:
        raise ValueError("dataloader has no batches; cannot train the classifier")

    print("\n" + "=" * 60)
    print("TRAINING CLASSIFIER FOR LABEL PREDICTION")
    print("=" * 60)

    classifier = Classifier().to(config.device)
    optimizer = optim.Adam(classifier.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()

    classifier.train()

    for epoch in range(num_epochs):
        total_loss = 0.0
        correct = 0
        total = 0

        for imgs, labels in dataloader:
            imgs = imgs.to(config.device)
            labels = labels.to(config.device)

            optimizer.zero_grad()
            outputs = classifier(imgs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            correct += (outputs.argmax(dim=1) == labels).sum().item()
            total += labels.size(0)

        acc = 100. * correct / total
        print(f"Classifier Epoch {epoch+1}/{num_epochs} | Loss: {total_loss/num_batches:.4f} | Acc: {acc:.2f}%")

    # Callers only use the classifier for prediction; disable dropout.
    classifier.eval()

    print("✓ Classifier training complete")
    print("=" * 60 + "\n")

    return classifier


# ============================================================================
# STEP 7: GAN Training
# ============================================================================

def train_gan(config, dataloader, classifier):
    """Train a Generator/Discriminator pair adversarially and return the generator.

    Each batch does one discriminator update, then one generator update:
    - Discriminator: BCE on real images vs. 1 and on detached fake images
      vs. 0, averaged over the two halves.
    - Generator: BCE on the discriminator's output for freshly generated
      images vs. 1.

    Every ``config.save_interval`` epochs, a 5-column grid generated from a
    fixed batch of 25 noise vectors is written to
    ``generated_samples/epoch_XXX.png``.

    Args:
        config: object providing ``channels``, ``img_size``, ``noise_dim``,
            ``device``, ``learning_rate``, ``epochs`` and ``save_interval``.
        dataloader: iterable of (images, labels) batches; labels are ignored.
        classifier: accepted for interface compatibility; not used here.

    Returns:
        The trained Generator, left in train mode.

    Raises:
        ValueError: if ``dataloader`` yields no batches in an epoch. The
            epoch averages would otherwise be NaN.
    """
    print("\n" + "=" * 60)
    print("STARTING GAN TRAINING")
    print("=" * 60 + "\n")

    # Create directories
    os.makedirs('generated_samples', exist_ok=True)
    os.makedirs('final_generated_images', exist_ok=True)

    device = config.device

    # Initialize networks
    img_shape = (config.channels, config.img_size, config.img_size)
    generator = Generator(config.noise_dim, img_shape).to(device)
    discriminator = Discriminator(img_shape).to(device)

    print(f"✓ Generator parameters: {sum(p.numel() for p in generator.parameters()):,}")
    print(f"✓ Discriminator parameters: {sum(p.numel() for p in discriminator.parameters()):,}\n")

    optimizer_G = optim.Adam(generator.parameters(), lr=config.learning_rate, betas=(0.5, 0.999))
    optimizer_D = optim.Adam(discriminator.parameters(), lr=config.learning_rate, betas=(0.5, 0.999))

    # Plain BCE because the Discriminator already ends in a Sigmoid.
    adversarial_loss = nn.BCELoss()

    # Same noise at every save, so the saved grids are comparable across epochs.
    fixed_noise = torch.randn(25, config.noise_dim, device=device)

    print("Training Progress:")
    print("-" * 60)

    for epoch in range(config.epochs):
        d_losses = []
        g_losses = []
        d_accuracies = []

        for real_imgs, _ in dataloader:
            real_imgs = real_imgs.to(device)
            batch_size = real_imgs.size(0)

            # Allocate targets/noise directly on the device (no per-step host copy).
            real_labels = torch.ones(batch_size, 1, device=device)
            fake_labels = torch.zeros(batch_size, 1, device=device)

            # ---------------------
            # Train Discriminator
            # ---------------------
            optimizer_D.zero_grad()

            real_output = discriminator(real_imgs)
            d_loss_real = adversarial_loss(real_output, real_labels)

            noise = torch.randn(batch_size, config.noise_dim, device=device)
            fake_imgs = generator(noise)
            # detach(): the discriminator step must not backprop into the generator.
            fake_output = discriminator(fake_imgs.detach())
            d_loss_fake = adversarial_loss(fake_output, fake_labels)

            d_loss = (d_loss_real + d_loss_fake) / 2
            d_loss.backward()
            optimizer_D.step()

            # Accuracy of the pre-update outputs at a 0.5 threshold.
            real_acc = (real_output > 0.5).float().mean()
            fake_acc = (fake_output < 0.5).float().mean()
            d_acc = (real_acc + fake_acc) / 2

            # ---------------------
            # Train Generator
            # ---------------------
            optimizer_G.zero_grad()

            noise = torch.randn(batch_size, config.noise_dim, device=device)
            gen_imgs = generator(noise)

            # Generator tries to make the discriminator output "real".
            g_loss = adversarial_loss(discriminator(gen_imgs), real_labels)
            g_loss.backward()
            optimizer_G.step()

            d_losses.append(d_loss.item())
            g_losses.append(g_loss.item())
            d_accuracies.append(d_acc.item() * 100)

        if not d_losses:
            raise ValueError("dataloader yielded no batches; cannot train the GAN")

        avg_d_loss = np.mean(d_losses)
        avg_g_loss = np.mean(g_losses)
        avg_d_acc = np.mean(d_accuracies)

        print(f"Epoch {epoch+1}/{config.epochs} | D_loss: {avg_d_loss:.2f} | D_acc: {avg_d_acc:.2f}% | G_loss: {avg_g_loss:.2f}")

        # Save generated samples at intervals
        if (epoch + 1) % config.save_interval == 0:
            sample_path = f'generated_samples/epoch_{epoch+1:03d}.png'
            generator.eval()
            with torch.no_grad():
                save_image(generator(fixed_noise), sample_path, nrow=5, normalize=True)
            generator.train()
            print(f"  → Saved samples to {sample_path}")

    print("\n" + "=" * 60)
    print("✓ GAN TRAINING COMPLETE")
    print("=" * 60 + "\n")

    return generator


# ============================================================================
# STEP 8: Generate Final Images and Predict Labels
# ============================================================================

def generate_final_images(generator, classifier, config, num_images=100):
    """Sample images from the generator, save them, and report predicted labels.

    Images are generated in chunks of at most ``config.batch_size`` under
    ``torch.no_grad()``, with both networks in eval mode. Outputs:
    - each image as ``final_generated_images/image_XXX.png``;
    - all images as a 10-column grid in
      ``final_generated_images/grid_<num_images>.png``;
    - a printed text histogram of the classifier's argmax predictions,
      using MNIST digit or Fashion-MNIST class names.

    Args:
        generator: model mapping (n, config.noise_dim) noise to images.
        classifier: model returning per-class scores for those images.
        config: object providing ``batch_size``, ``noise_dim``, ``device``
            and ``dataset_choice`` ('mnist' or 'fashion').
        num_images: how many images to generate (default 100).

    Raises:
        ValueError: if ``num_images`` is less than 1.
    """
    if num_images < 1:
        raise ValueError(f"num_images must be at least 1, got {num_images}")

    out_dir = 'final_generated_images'
    grid_path = f'{out_dir}/grid_{num_images}.png'

    print("\n" + "=" * 60)
    print("GENERATING FINAL IMAGES")
    print("=" * 60)

    # Create the directory here too, so this step also works when called on its own.
    os.makedirs(out_dir, exist_ok=True)

    generator.eval()
    classifier.eval()

    image_batches = []
    label_batches = []

    with torch.no_grad():
        for start in range(0, num_images, config.batch_size):
            count = min(config.batch_size, num_images - start)
            noise = torch.randn(count, config.noise_dim, device=config.device)
            imgs = generator(noise)
            image_batches.append(imgs)
            label_batches.append(classifier(imgs).argmax(dim=1).cpu())

        all_images = torch.cat(image_batches, dim=0)

        # Save individual images
        for idx, img in enumerate(all_images, start=1):
            save_image(img, f'{out_dir}/image_{idx:03d}.png', normalize=True)

        # Save grid
        save_image(all_images, grid_path, nrow=10, normalize=True)

    print(f"✓ Generated {num_images} images saved to {out_dir}/")
    print(f"✓ Grid visualization saved to {grid_path}")

    # Print label distribution
    print("\n" + "=" * 60)
    print("LABEL DISTRIBUTION OF GENERATED IMAGES")
    print("=" * 60)

    label_names = {
        'mnist': ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9'],
        'fashion': ['T-shirt', 'Trouser', 'Pullover', 'Dress', 'Coat',
                   'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot']
    }

    names = label_names[config.dataset_choice]
    predicted = torch.cat(label_batches).numpy()
    labels_count = np.bincount(predicted, minlength=len(names))

    for i, (name, count) in enumerate(zip(names, labels_count)):
        percentage = (count / num_images) * 100
        bar = '█' * int(percentage / 2)
        print(f"{i} ({name:12s}): {count:3d} ({percentage:5.1f}%) {bar}")

    print("=" * 60 + "\n")


# ============================================================================
# MAIN EXECUTION
# ============================================================================

def main():
    """Run the whole pipeline, from interactive configuration to image export.

    Steps, in order:
    1. Prompt for settings.
    2. Load the dataset.
    3. Fit the label classifier.
    4. Train the GAN.
    5. Save the final generated images with their predicted-label histogram.
    6. Print a summary of where outputs were written.
    """
    config = Config()
    config.get_user_input()

    dataloader, _dataset = load_dataset(config)
    classifier = train_classifier(dataloader, config)
    generator = train_gan(config, dataloader, classifier)
    generate_final_images(generator, classifier, config)

    rule = "=" * 60
    closing_lines = (
        "\n" + rule,
        "ALL TASKS COMPLETED SUCCESSFULLY!",
        rule,
        "\nOutput Summary:",
        "1. Training logs printed above",
        "2. Generated samples saved in: generated_samples/",
        "3. Final 100 images saved in: final_generated_images/",
        "4. Label distribution printed above",
        rule + "\n",
    )
    for line in closing_lines:
        print(line)


# Run the interactive pipeline only when executed as a script, not on import.
if __name__ == "__main__":
    main()

GAN IMAGE GENERATION SYSTEM

Choose dataset (mnist/fashion): mnist
Enter number of epochs (recommended 30-100): 50
Enter batch size (recommended 64 or 128): 80
Enter noise dimension (recommended 50 or 100): 50
Enter learning rate (recommended 0.0002): 0.0002
Enter save interval (e.g., 5 for every 5 epochs): 5

✓ Configuration complete. Using device: cuda

CONFIGURATION SUMMARY
Dataset: MNIST
Epochs: 50
Batch Size: 80
Noise Dimension: 50
Learning Rate: 0.0002
Save Interval: 5
Device: cuda


LOADING DATASET


100%|██████████| 9.91M/9.91M [00:00<00:00, 17.9MB/s]
100%|██████████| 28.9k/28.9k [00:00<00:00, 487kB/s]
100%|██████████| 1.65M/1.65M [00:00<00:00, 4.40MB/s]
100%|██████████| 4.54k/4.54k [00:00<00:00, 3.85MB/s]


✓ MNIST dataset loaded
✓ Total samples: 60000
✓ Batches per epoch: 750


TRAINING CLASSIFIER FOR LABEL PREDICTION
Classifier Epoch 1/5 | Loss: 0.2980 | Acc: 90.77%
Classifier Epoch 2/5 | Loss: 0.1171 | Acc: 96.55%
Classifier Epoch 3/5 | Loss: 0.0945 | Acc: 97.25%
Classifier Epoch 4/5 | Loss: 0.0834 | Acc: 97.51%
Classifier Epoch 5/5 | Loss: 0.0730 | Acc: 97.86%
✓ Classifier training complete


STARTING GAN TRAINING

✓ Generator parameters: 542,465
✓ Discriminator parameters: 392,833

Training Progress:
------------------------------------------------------------
Epoch 1/50 | D_loss: 0.64 | D_acc: 63.47% | G_loss: 0.83
Epoch 2/50 | D_loss: 0.62 | D_acc: 66.24% | G_loss: 0.90
Epoch 3/50 | D_loss: 0.59 | D_acc: 69.11% | G_loss: 0.99
Epoch 4/50 | D_loss: 0.56 | D_acc: 71.45% | G_loss: 1.05
Epoch 5/50 | D_loss: 0.55 | D_acc: 72.22% | G_loss: 1.09
  → Saved samples to generated_samples/epoch_005.png
Epoch 6/50 | D_loss: 0.53 | D_acc: 73.80% | G_loss: 1.14
Epoch 7/50 | D_loss: 0.53 | D_acc: 7