In [7]:
import time

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from tqdm.auto import tqdm


In [5]:

# Preprocessing: ToTensor yields pixels in [0, 1]; Normalize(0.5, 0.5) then maps
# them to [-1, 1] via (x - 0.5) / 0.5 = 2x - 1.
transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize(mean=0.5, std=0.5)]
)

# Both splits share the root directory and the preprocessing pipeline;
# download=True lets torchvision fetch the files when they are missing.
mnist_train, mnist_test = (
    torchvision.datasets.MNIST(
        root='./mnist_data',
        train=is_train,
        download=True,
        transform=transform,
    )
    for is_train in (True, False)
)

print(f"Training samples: {len(mnist_train)}")
print(f"Test samples: {len(mnist_test)}")

100%|██████████| 9.91M/9.91M [00:02<00:00, 4.48MB/s]
100%|██████████| 28.9k/28.9k [00:00<00:00, 113kB/s]
100%|██████████| 1.65M/1.65M [00:01<00:00, 909kB/s] 
100%|██████████| 4.54k/4.54k [00:00<00:00, 1.13MB/s]


Training samples: 60000
Test samples: 10000


In [6]:
# --- DataLoaders -------------------------------------------------------------
# Both loaders share the batch size and worker count; only the training split
# is shuffled. Tune batch_size / num_workers to your CPU.
batch_size = 32

train_loader = DataLoader(mnist_train, batch_size=batch_size, shuffle=True, num_workers=2)
test_loader = DataLoader(mnist_test, batch_size=batch_size, shuffle=False, num_workers=2)

print(f"Number of training batches: {len(train_loader)}")
print(f"Number of test batches: {len(test_loader)}")

Number of training batches: 1875
Number of test batches: 313


In [None]:
class Generator(nn.Module):
    """
    Generator network for the MNIST GAN.

    Maps a noise vector z of size ``noise_dim`` to a fake image.
    Architecture: noise_dim -> hidden_dim1 -> hidden_dim2 -> output_dim
    (defaults 100 -> 256 -> 512 -> 784), with BatchNorm + ReLU on the two
    hidden layers and tanh on the output so pixels lie in [-1, 1].

    Args:
        noise_dim: Size of the input noise vector.
        hidden_dim1: Width of the first hidden layer.
        hidden_dim2: Width of the second hidden layer.
        output_dim: Number of output values; must equal the element count
            of ``img_shape``.
        img_shape: (channels, height, width) of the produced images.
            Defaults to (1, 28, 28) for MNIST.

    Raises:
        ValueError: if ``output_dim`` does not match ``img_shape``.
    """
    def __init__(self, noise_dim=100, hidden_dim1=256, hidden_dim2=512, output_dim=784,
                 img_shape=(1, 28, 28)):
        super(Generator, self).__init__()

        self.img_shape = tuple(img_shape)
        num_pixels = torch.Size(self.img_shape).numel()
        if output_dim != num_pixels:
            # Reject up front: a mismatched reshape would otherwise fail (or
            # silently fold pixels into the batch axis) only at forward time.
            raise ValueError(
                f"output_dim ({output_dim}) must equal the number of elements "
                f"in img_shape {self.img_shape} ({num_pixels})"
            )

        self.noise_dim = noise_dim
        self.output_dim = output_dim

        # Define layers (creation order fixes parameter names and init RNG order)
        self.fc1 = nn.Linear(noise_dim, hidden_dim1)
        self.bn1 = nn.BatchNorm1d(hidden_dim1)

        self.fc2 = nn.Linear(hidden_dim1, hidden_dim2)
        self.bn2 = nn.BatchNorm1d(hidden_dim2)

        self.fc3 = nn.Linear(hidden_dim2, output_dim)

        self._initialize_weights()

    def _initialize_weights(self):
        """DCGAN-style init: Linear weights ~ N(0, 0.02), BatchNorm gains ~ N(1, 0.02), biases 0."""
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0.0, 0.02)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm1d):
                nn.init.normal_(m.weight, 1.0, 0.02)
                nn.init.constant_(m.bias, 0)

    def forward(self, z):
        """
        Forward pass through the generator.

        Args:
            z: Noise tensor of shape [batch_size, noise_dim].
        Returns:
            Images of shape [batch_size, *img_shape] (default
            [batch_size, 1, 28, 28]) with values in [-1, 1].
        """
        x = F.relu(self.bn1(self.fc1(z)))
        x = F.relu(self.bn2(self.fc2(x)))

        # tanh keeps outputs in [-1, 1] to match the normalized MNIST pixels
        x = torch.tanh(self.fc3(x))

        # Use the real batch size rather than -1 so a shape mismatch raises
        # instead of being absorbed into the batch dimension.
        return x.view(z.size(0), *self.img_shape)

class Discriminator(nn.Module):
    """
    Discriminator network for the MNIST GAN.

    Scores each input with a sigmoid probability of being a real image.
    Inputs may be image-shaped [batch_size, 1, 28, 28] or already flattened
    [batch_size, 784].
    Architecture: input_dim -> hidden_dim1 -> hidden_dim2 -> 1
    (defaults 784 -> 256 -> 256 -> 1), LeakyReLU(0.2) followed by Dropout on
    each hidden layer.
    """
    def __init__(self, input_dim=784, hidden_dim1=256, hidden_dim2=256, dropout_prob=0.3):
        super(Discriminator, self).__init__()
        self.input_dim = input_dim

        # Attribute names define the state_dict keys; creation order fixes the
        # order in which parameters draw from the RNG.
        self.fc1 = nn.Linear(input_dim, hidden_dim1)
        self.dropout1 = nn.Dropout(dropout_prob)
        self.fc2 = nn.Linear(hidden_dim1, hidden_dim2)
        self.dropout2 = nn.Dropout(dropout_prob)
        self.fc3 = nn.Linear(hidden_dim2, 1)

        self._initialize_weights()

    def _initialize_weights(self):
        """DCGAN-style init: every Linear weight ~ N(0, 0.02), every bias set to 0."""
        linear_layers = [module for module in self.modules() if isinstance(module, nn.Linear)]
        for layer in linear_layers:
            nn.init.normal_(layer.weight, 0.0, 0.02)
            if layer.bias is not None:
                nn.init.constant_(layer.bias, 0)

    def forward(self, x):
        """
        Score a batch of images.

        Args:
            x: Tensor of shape [batch_size, 1, 28, 28] or [batch_size, 784].
        Returns:
            Tensor of shape [batch_size, 1] with probabilities in [0, 1].
        """
        if x.dim() == 4:
            # Collapse channel/height/width into one feature axis
            x = x.view(x.size(0), -1)

        hidden_stages = ((self.fc1, self.dropout1), (self.fc2, self.dropout2))
        for linear, dropout in hidden_stages:
            x = dropout(F.leaky_relu(linear(x), negative_slope=0.2))

        return torch.sigmoid(self.fc3(x))


In [None]:
class GANTrainer:
    """
    Alternating trainer for a (generator, discriminator) pair with BCE loss.

    Each step first updates the discriminator on a real batch and a detached
    fake batch, then updates the generator with the non-saturating loss.
    Optional one-sided label smoothing lowers only the discriminator's
    real-image target (default 0.9). Fake targets (0.0) and the generator's
    target (1.0) are never smoothed. Per-epoch average losses are stored in
    ``losses_G``, ``losses_D``, ``losses_D_real`` and ``losses_D_fake``.
    """

    def __init__(self, generator, discriminator, device, lr=2e-4, beta1=0.5, beta2=0.999,
                 label_smoothing=True, smooth_real_labels=0.9):
        """
        GAN Trainer with optimized hyperparameters

        Args:
            generator: Generator model
            discriminator: Discriminator model; must output probabilities in
                [0, 1] because the loss is nn.BCELoss
            device: Training device (cuda/cpu)
            lr: Learning rate for both optimizers
            beta1: Beta1 parameter for Adam (reduced for GAN stability)
            beta2: Beta2 parameter for Adam
            label_smoothing: Whether to apply one-sided label smoothing
            smooth_real_labels: Value for smoothed real labels (default 0.9)
        """
        self.generator = generator.to(device)
        self.discriminator = discriminator.to(device)
        self.device = device

        # Optimizers with GAN-specific hyperparameters
        self.optimizer_G = optim.Adam(
            self.generator.parameters(),
            lr=lr,
            betas=(beta1, beta2)
        )

        self.optimizer_D = optim.Adam(
            self.discriminator.parameters(),
            lr=lr,
            betas=(beta1, beta2)
        )

        # Loss function (expects probabilities, not logits)
        self.criterion = nn.BCELoss()

        # Label smoothing settings
        self.label_smoothing = label_smoothing
        self.smooth_real_labels = smooth_real_labels

        # Per-epoch averages, appended by train_epoch()
        self.losses_G = []
        self.losses_D = []
        self.losses_D_real = []
        self.losses_D_fake = []

        print(f"GAN Trainer initialized:")
        print(f"  Learning rate: {lr}")
        print(f"  Adam betas: ({beta1}, {beta2})")
        print(f"  Label smoothing: {label_smoothing}")
        if label_smoothing:
            print(f"  Smooth real labels: {smooth_real_labels}")
        print(f"  Device: {device}")

    def get_labels(self, batch_size, real=True):
        """
        Build discriminator targets for real/fake data, with optional smoothing.

        Args:
            batch_size: Size of the batch
            real: True for real-data labels, False for fake-data labels

        Returns:
            Float tensor of shape [batch_size, 1] on self.device: the smoothed
            value (or 1.0) for real data, 0.0 for fake data.
        """
        if real:
            if self.label_smoothing:
                # One-sided label smoothing: real labels = smooth_real_labels instead of 1.0
                labels = torch.full((batch_size, 1), self.smooth_real_labels,
                                    dtype=torch.float32, device=self.device)
            else:
                labels = torch.ones(batch_size, 1, dtype=torch.float32, device=self.device)
        else:
            # Always use 0.0 for fake labels (no smoothing)
            labels = torch.zeros(batch_size, 1, dtype=torch.float32, device=self.device)

        return labels

    def train_discriminator(self, real_images, noise):
        """
        Run one discriminator update on a real batch and a generated batch.

        Args:
            real_images: Batch of real images
            noise: Random noise for generating fake images

        Returns:
            Dict with 'loss_D_total', 'loss_D_real', 'loss_D_fake' and the mean
            predictions 'real_predictions' / 'fake_predictions' (Python floats).
        """
        batch_size = real_images.size(0)

        self.optimizer_D.zero_grad()

        # === Real images ===
        real_labels = self.get_labels(batch_size, real=True)
        real_predictions = self.discriminator(real_images)
        loss_D_real = self.criterion(real_predictions, real_labels)

        # === Fake images ===
        fake_images = self.generator(noise)
        fake_labels = self.get_labels(batch_size, real=False)
        fake_predictions = self.discriminator(fake_images.detach())  # Detach so G gets no gradient here
        loss_D_fake = self.criterion(fake_predictions, fake_labels)

        loss_D = loss_D_real + loss_D_fake
        loss_D.backward()
        self.optimizer_D.step()

        return {
            'loss_D_total': loss_D.item(),
            'loss_D_real': loss_D_real.item(),
            'loss_D_fake': loss_D_fake.item(),
            'real_predictions': real_predictions.mean().item(),
            'fake_predictions': fake_predictions.mean().item()
        }

    def train_generator(self, noise):
        """
        Run one generator update with the non-saturating loss.

        Args:
            noise: Random noise for generating fake images

        Returns:
            Dict with 'loss_G' and 'fake_predictions_for_G' (mean D(G(z)) before the step).
        """
        batch_size = noise.size(0)

        self.optimizer_G.zero_grad()

        fake_images = self.generator(noise)

        # Non-saturating loss: target 1.0 for fake images, i.e. maximise
        # log(D(G(z))) instead of minimising log(1 - D(G(z))).
        # One-sided label smoothing regularises only the discriminator, so the
        # generator target is always exactly 1.0 (get_labels would return the
        # smoothed value here).
        target_labels = torch.ones(batch_size, 1, dtype=torch.float32, device=self.device)
        fake_predictions = self.discriminator(fake_images)
        loss_G = self.criterion(fake_predictions, target_labels)

        # backward() also writes grads into D's parameters; optimizer_D.zero_grad()
        # at the start of the next train_discriminator call clears them.
        loss_G.backward()
        self.optimizer_G.step()

        return {
            'loss_G': loss_G.item(),
            'fake_predictions_for_G': fake_predictions.mean().item()
        }

    def train_epoch(self, dataloader, noise_dim=100):
        """
        Train both networks for one pass over ``dataloader``.

        Args:
            dataloader: Iterable of (images, labels) batches; labels are ignored
            noise_dim: Dimension of noise vector

        Returns:
            Dict of epoch-average losses: 'avg_loss_G', 'avg_loss_D',
            'avg_loss_D_real', 'avg_loss_D_fake'.

        Raises:
            ValueError: if the dataloader yields no batches.
        """
        self.generator.train()
        self.discriminator.train()

        epoch_losses_G = []
        epoch_losses_D = []
        epoch_losses_D_real = []
        epoch_losses_D_fake = []

        progress_bar = tqdm(dataloader, desc="Training")

        for real_images, _ in progress_bar:
            batch_size = real_images.size(0)
            real_images = real_images.to(self.device)

            # The same noise batch feeds both the D step and the G step
            noise = torch.randn(batch_size, noise_dim, device=self.device)

            d_stats = self.train_discriminator(real_images, noise)
            g_stats = self.train_generator(noise)

            epoch_losses_G.append(g_stats['loss_G'])
            epoch_losses_D.append(d_stats['loss_D_total'])
            epoch_losses_D_real.append(d_stats['loss_D_real'])
            epoch_losses_D_fake.append(d_stats['loss_D_fake'])

            progress_bar.set_postfix({
                'D_loss': f"{d_stats['loss_D_total']:.4f}",
                'G_loss': f"{g_stats['loss_G']:.4f}",
                'D(x)': f"{d_stats['real_predictions']:.3f}",
                'D(G(z))': f"{d_stats['fake_predictions']:.3f}"
            })

        if not epoch_losses_G:
            # np.mean([]) would silently record NaN losses for this epoch
            raise ValueError("dataloader yielded no batches")

        avg_loss_G = np.mean(epoch_losses_G)
        avg_loss_D = np.mean(epoch_losses_D)
        avg_loss_D_real = np.mean(epoch_losses_D_real)
        avg_loss_D_fake = np.mean(epoch_losses_D_fake)

        # Store for plotting
        self.losses_G.append(avg_loss_G)
        self.losses_D.append(avg_loss_D)
        self.losses_D_real.append(avg_loss_D_real)
        self.losses_D_fake.append(avg_loss_D_fake)

        return {
            'avg_loss_G': avg_loss_G,
            'avg_loss_D': avg_loss_D,
            'avg_loss_D_real': avg_loss_D_real,
            'avg_loss_D_fake': avg_loss_D_fake
        }

    def generate_samples(self, num_samples=64, noise_dim=100):
        """
        Generate ``num_samples`` images from fresh noise without tracking gradients.

        The generator is switched to eval mode for sampling and then restored
        to whatever mode (train/eval) it was in before the call.
        """
        was_training = self.generator.training
        self.generator.eval()
        try:
            with torch.no_grad():
                noise = torch.randn(num_samples, noise_dim, device=self.device)
                fake_images = self.generator(noise)
        finally:
            self.generator.train(was_training)

        return fake_images

    def plot_losses(self):
        """Plot per-epoch G/D losses and the discriminator's real/fake breakdown."""
        fig, axes = plt.subplots(1, 2, figsize=(15, 5))

        epochs = range(1, len(self.losses_G) + 1)

        # Generator vs discriminator losses
        axes[0].plot(epochs, self.losses_G, label='Generator Loss', color='blue')
        axes[0].plot(epochs, self.losses_D, label='Discriminator Loss', color='red')
        axes[0].set_xlabel('Epoch')
        axes[0].set_ylabel('Loss')
        axes[0].set_title('Generator vs Discriminator Loss')
        axes[0].legend()
        axes[0].grid(True, alpha=0.3)

        # Discriminator breakdown
        axes[1].plot(epochs, self.losses_D_real, label='D Loss (Real)', color='green')
        axes[1].plot(epochs, self.losses_D_fake, label='D Loss (Fake)', color='orange')
        axes[1].set_xlabel('Epoch')
        axes[1].set_ylabel('Loss')
        axes[1].set_title('Discriminator Loss Breakdown')
        axes[1].legend()
        axes[1].grid(True, alpha=0.3)

        plt.tight_layout()
        plt.show()

def visualize_samples(fake_images, epoch=None, num_samples=16):
    """
    Show generated samples in a grid with up to 4 columns.

    Args:
        fake_images: Tensor of generated images in [-1, 1], indexed along
            dim 0; each image is squeezed before display, so [N, 1, H, W] works.
        epoch: Optional epoch number for the title (0 is a valid epoch).
        num_samples: Maximum number of images to show; capped at
            len(fake_images). Nothing is drawn when this is <= 0.
    """
    n_show = min(num_samples, len(fake_images))
    if n_show <= 0:
        return

    # Map [-1, 1] back to [0, 1] for imshow; detach in case grads are tracked.
    images = ((fake_images[:n_show].detach() + 1) / 2).clamp(0, 1).cpu()

    # 4 columns like the original 4x4 layout; rows grow with n_show.
    ncols = min(4, n_show)
    nrows = -(-n_show // ncols)  # ceiling division
    fig, axes = plt.subplots(nrows, ncols, figsize=(2 * ncols, 2 * nrows), squeeze=False)

    for i, ax in enumerate(axes.flat):
        ax.axis('off')  # also hides any unused trailing cells
        if i < n_show:
            ax.imshow(images[i].squeeze(), cmap='gray')

    title = f'Generated Samples (Epoch {epoch})' if epoch is not None else 'Generated Samples'
    fig.suptitle(title, fontsize=14)
    fig.tight_layout()
    plt.show()

def full_training_loop(generator, discriminator, train_loader, epochs=50, noise_dim=100,
                       lr=2e-4, label_smoothing=True, device=None, sample_every=5):
    """
    Train a GAN end to end, showing samples periodically and plotting losses.

    Args:
        generator: Generator model (moved to ``device`` by GANTrainer).
        discriminator: Discriminator model (moved to ``device`` by GANTrainer).
        train_loader: Iterable of (images, labels) training batches.
        epochs: Number of epochs to train; must be >= 1.
        noise_dim: Noise vector size; must match the generator's input size.
        lr: Adam learning rate for both networks.
        label_smoothing: Enable one-sided label smoothing for the discriminator.
        device: Torch device; defaults to CUDA when available, else CPU.
        sample_every: Show a sample grid after epoch 1 and every
            ``sample_every`` epochs; must be >= 1.

    Returns:
        The GANTrainer used for training (optimizers and loss history).

    Raises:
        ValueError: if ``epochs`` or ``sample_every`` is less than 1.
    """
    if epochs < 1:
        # Also avoids a ZeroDivisionError in the per-epoch timing below
        raise ValueError(f"epochs must be >= 1, got {epochs}")
    if sample_every < 1:
        raise ValueError(f"sample_every must be >= 1, got {sample_every}")

    if device is None:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    trainer = GANTrainer(
        generator, discriminator, device,
        lr=lr, label_smoothing=label_smoothing
    )

    print(f"\nStarting GAN training for {epochs} epochs...")
    print("="*60)

    # perf_counter is monotonic, so system clock changes cannot skew the timing
    start_time = time.perf_counter()

    for epoch in range(1, epochs + 1):
        print(f"\nEpoch {epoch}/{epochs}")

        epoch_stats = trainer.train_epoch(train_loader, noise_dim)

        print(f"Epoch {epoch} Summary:")
        print(f"  Generator Loss: {epoch_stats['avg_loss_G']:.4f}")
        print(f"  Discriminator Loss: {epoch_stats['avg_loss_D']:.4f}")
        print(f"    - Real Loss: {epoch_stats['avg_loss_D_real']:.4f}")
        print(f"    - Fake Loss: {epoch_stats['avg_loss_D_fake']:.4f}")

        # Show samples after the first epoch and then every `sample_every` epochs
        if epoch == 1 or epoch % sample_every == 0:
            fake_images = trainer.generate_samples(16, noise_dim)
            visualize_samples(fake_images, epoch, num_samples=16)

    training_time = time.perf_counter() - start_time

    print(f"\nTraining completed in {training_time:.2f} seconds")
    print(f"Average time per epoch: {training_time/epochs:.2f} seconds")

    trainer.plot_losses()

    print("\nFinal generated samples:")
    final_samples = trainer.generate_samples(16, noise_dim)
    visualize_samples(final_samples, epoch=epochs, num_samples=16)

    return trainer


device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Seed PyTorch's global RNG so weight init and noise draws are repeatable.
SEED = 42
torch.manual_seed(SEED)

print("Testing training setup...")

# --- Smoke test ---------------------------------------------------------------
# train_discriminator/train_generator take real optimizer steps, and the dummy
# "images" below are pure noise. Run them on throwaway models so the networks
# trained later are not nudged toward noise before training starts.
batch_size = 32
noise_dim = 100

smoke_generator = Generator(noise_dim=noise_dim)
smoke_discriminator = Discriminator()
smoke_trainer = GANTrainer(smoke_generator, smoke_discriminator, device, label_smoothing=True)

dummy_images = torch.randn(batch_size, 1, 28, 28, device=device)
dummy_noise = torch.randn(batch_size, noise_dim, device=device)

print("\nTesting discriminator training...")
d_stats = smoke_trainer.train_discriminator(dummy_images, dummy_noise)
print(f"D stats: {d_stats}")

print("\nTesting generator training...")
g_stats = smoke_trainer.train_generator(dummy_noise)
print(f"G stats: {g_stats}")

print("\nTesting sample generation...")
samples = smoke_trainer.generate_samples(16, noise_dim)
print(f"Generated samples shape: {samples.shape}")
print(f"Samples range: [{samples.min():.3f}, {samples.max():.3f}]")
assert samples.shape == (16, 1, 28, 28), "unexpected generator output shape"
assert samples.min() >= -1 and samples.max() <= 1, "generator output outside tanh range"

print("\n✓ Training setup verified and ready!")

# --- Full training (50 epochs; slow on CPU) -------------------------------------
# Fresh, untouched models for the real run.
generator = Generator(noise_dim=noise_dim).to(device)
discriminator = Discriminator().to(device)
trainer = full_training_loop(generator, discriminator, train_loader,
                             epochs=50, noise_dim=noise_dim, device=device)