In [1]:
import torch
import torch.nn as nn
from torchvision import datasets,transforms
from accelerate import Accelerator
import matplotlib.pyplot as plt
import os

In [2]:
def load_mnist_dataset(batch_size=64):
    """Build a shuffled DataLoader over the MNIST training split.

    Each image is converted to a tensor and normalized with mean 0.5 and
    std 0.5, so pixel values end up centred around zero (the generator in
    this file ends in ``Tanh``, which produces the same [-1, 1] range).

    Args:
        batch_size: Number of images per batch.

    Returns:
        A ``torch.utils.data.DataLoader`` over the MNIST training set. The
        dataset is stored under ``data/`` and downloaded if not present.
    """
    # Trailing commas matter: ``(0.5)`` is just the float 0.5, while Normalize
    # expects one mean/std entry per channel (MNIST has a single channel).
    transform_for_gan = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,)),
    ])
    train_dataset = datasets.MNIST(
        'data', train=True, download=True, transform=transform_for_gan
    )
    train_dataloader = torch.utils.data.DataLoader(
        train_dataset, batch_size=batch_size, shuffle=True
    )

    return train_dataloader

In [17]:
class SimpleGenerator(nn.Module):
    """Fully connected generator that maps noise vectors to 1x28x28 images.

    The network widens through 256, 512 and 1024 hidden units (each followed
    by BatchNorm and LeakyReLU) before projecting to 28*28 outputs squashed
    by ``Tanh`` into [-1, 1].

    Args:
        noise_dimension: Length of each input noise vector. It is also stored
            as ``self.noise_dimension`` so training code can sample noise of
            the right size.
    """

    def __init__(self, noise_dimension=100):
        super().__init__()
        self.noise_dimension = noise_dimension

        self.noise_to_image = nn.Sequential(
            nn.Linear(noise_dimension, 256),
            nn.BatchNorm1d(256),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Linear(256, 512),
            nn.BatchNorm1d(512),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Linear(512, 1024),
            nn.BatchNorm1d(1024),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Linear(1024, 28*28),
            nn.Tanh()
        )
        # Keep this summary in sync with the layer widths defined above.
        print(f"Created Generator: {noise_dimension} -> 256 -> 512 -> 1024 -> {28*28}")

    def forward(self, random_noise):
        """Generate images from noise.

        Args:
            random_noise: Tensor of shape ``(batch_size, noise_dimension)``.

        Returns:
            Tensor of shape ``(batch_size, 1, 28, 28)`` with values in [-1, 1].
        """
        flat_images = self.noise_to_image(random_noise)
        batch_size = flat_images.size(0)
        generated_images = flat_images.view(batch_size, 1, 28, 28)

        return generated_images

In [15]:
class SimpleDiscriminator(nn.Module):
    """Fully connected discriminator scoring 28x28 images as real or fake.

    The network narrows through 1024, 512 and 256 hidden units (each followed
    by LeakyReLU and Dropout) and ends in a single ``Sigmoid`` output, i.e.
    the estimated probability that the input image is real.
    """

    def __init__(self):
        super().__init__()

        self.image_to_decision = nn.Sequential(
            nn.Linear(28*28, 1024),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Dropout(0.3),

            nn.Linear(1024, 512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Dropout(0.3),

            nn.Linear(512, 256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Dropout(0.3),

            nn.Linear(256, 1),
            nn.Sigmoid()
        )

        # Keep this summary in sync with the layer widths defined above.
        print(f"Created Discriminator :{28*28} -> 1024 -> 512 -> 256 -> 1")

    def forward(self, images):
        """Classify images as real (1) or fake (0)

        Args:
            images: Tensor whose first dimension is the batch and whose
                remaining dimensions hold 28*28 values per image.

        Returns:
            Tensor of shape ``(batch_size, 1)`` with probabilities in [0, 1].
        """
        batch_size = images.size(0)
        # reshape (unlike view) also accepts non-contiguous inputs.
        flat_images = images.reshape(batch_size, -1)  # Flatten to (batch_size, 784)
        probability_real = self.image_to_decision(flat_images)
        return probability_real

In [5]:
def train_generator(generator,discriminator, real_image_batch, optimizer_generator, accelerator):
    """Perform one optimisation step on the generator.

    Fresh fake images are produced from random noise and scored by the
    discriminator; the generator is rewarded when those scores are close to
    1 ("real").

    Args:
        generator: Model exposing ``noise_dimension`` and mapping noise to images.
        discriminator: Model returning a probability per image.
        real_image_batch: Only consulted for its batch size.
        optimizer_generator: Optimizer over the generator's parameters.
        accelerator: Provides ``device`` and ``backward``.

    Returns:
        The generator loss for this step as a Python float.
    """
    num_fakes = real_image_batch.size(0)

    # Sample latent vectors and let the discriminator judge the resulting fakes.
    latent = torch.randn(num_fakes, generator.noise_dimension, device=accelerator.device)
    fake_scores = discriminator(generator(latent))

    # The generator "wins" when fakes are labelled as real, hence all-ones targets.
    criterion = nn.BCELoss()
    loss = criterion(fake_scores, torch.ones_like(fake_scores))

    optimizer_generator.zero_grad()
    accelerator.backward(loss)
    optimizer_generator.step()

    return loss.item()
    
    

In [6]:
def train_discriminator(generator,discriminator,real_images_batch,optimizer_discriminator,accelerator):
    """Perform one optimisation step on the discriminator.

    The discriminator is scored on a batch of real images (target 1) and on
    an equally sized batch of freshly generated fakes (target 0); the two
    binary cross-entropy losses are averaged.

    Args:
        generator: Model exposing ``noise_dimension``; used only to produce
            fakes and not updated here.
        discriminator: Model returning a probability per image.
        real_images_batch: Batch of real images.
        optimizer_discriminator: Optimizer over the discriminator's parameters.
        accelerator: Provides ``device`` and ``backward``.

    Returns:
        The averaged discriminator loss for this step as a Python float.
    """
    batch_size = real_images_batch.size(0)
    loss_function = nn.BCELoss()

    # Step 1 - Train on REAL images (should be classified as 1)
    discriminator_opinion_on_real_images = discriminator(real_images_batch)
    target_labels_real = torch.ones_like(discriminator_opinion_on_real_images)
    loss_on_real_images = loss_function(discriminator_opinion_on_real_images, target_labels_real)

    # Step 2 - Train on FAKE images (should be classified as 0)
    random_noise = torch.randn(batch_size, generator.noise_dimension, device=accelerator.device)
    # The generator is not updated in this step, so skip building its graph.
    with torch.no_grad():
        generated_fake_images = generator(random_noise)
    discriminator_opinion_on_fakes = discriminator(generated_fake_images)
    target_labels_fake = torch.zeros_like(discriminator_opinion_on_fakes)
    loss_on_fake_images = loss_function(discriminator_opinion_on_fakes, target_labels_fake)

    # Step 3: Combined discriminator loss
    total_discriminator_loss = (loss_on_real_images + loss_on_fake_images) / 2

    optimizer_discriminator.zero_grad()
    accelerator.backward(total_discriminator_loss)
    optimizer_discriminator.step()

    return total_discriminator_loss.item()

In [7]:
def save_generated_samples(generator, epoch, accelerator, num_samples=16):
    """Save a grid of generated images to ``generated_images/``.

    The generator is switched to eval mode while sampling and restored to
    whatever mode it was in before the call, even if plotting fails.

    Args:
        generator: Model exposing ``noise_dimension`` and mapping noise to
            images with values in [-1, 1].
        epoch: Either an int (file is named ``epoch_NNN.png``) or a string
            (used verbatim as the file stem).
        accelerator: Provides the ``device`` on which noise is sampled.
        num_samples: Number of images to draw; the grid grows to fit them
            (the default of 16 yields a 4x4 grid).
    """
    was_training = generator.training
    generator.eval()

    try:
        with torch.no_grad():
            sample_noise = torch.randn(num_samples, generator.noise_dimension, device=accelerator.device)
            samples_cpu = generator(sample_noise).cpu().float()

        # Smallest near-square grid that holds num_samples images.
        cols = 1
        while cols * cols < num_samples:
            cols += 1
        rows = max(1, -(-num_samples // cols))  # ceiling division

        # squeeze=False keeps `axes` a 2-D array even for a 1x1 grid.
        fig, axes = plt.subplots(rows, cols, figsize=(2 * cols, 2 * rows), squeeze=False)
        try:
            for i, ax in enumerate(axes.flat):
                if i < num_samples:
                    # Denormalize from [-1,1] to [0,1] for display
                    img = (samples_cpu[i].squeeze() + 1) / 2
                    # Fix the colour range so the brightness is not
                    # auto-rescaled per image.
                    ax.imshow(img.numpy(), cmap='gray', vmin=0, vmax=1)
                ax.axis('off')

            fig.suptitle(f'Generated Images - Epoch {epoch}', fontsize=16)
            fig.tight_layout()

            # Save image
            os.makedirs('generated_images', exist_ok=True)
            if isinstance(epoch, str):
                fig.savefig(f'generated_images/{epoch}.png')
            else:
                fig.savefig(f'generated_images/epoch_{epoch:03d}.png')
        finally:
            # Close this specific figure so repeated calls do not leak memory.
            plt.close(fig)
    finally:
        generator.train(was_training)


In [8]:
def overfit_single_batch_gan(generator, discriminator, train_dataloader, accelerator, iterations=1000):
    """
    TESTING FUNCTION: Overfit GAN on a single batch to verify learning capability
    If the networks can't overfit one batch, they won't work on full dataset

    Args:
        generator: Generator model (exposes ``noise_dimension``).
        discriminator: Discriminator model.
        train_dataloader: Loader yielding ``(images, labels)``; only its
            first batch is used.
        accelerator: Accelerate helper used for device placement, backward
            passes and printing.
        iterations: Number of discriminator/generator update pairs to run
            on that single batch.
    """
    accelerator.print("=== OVERFITTING TEST: Training GAN on single batch ===")

    # Get one single batch and keep using it
    single_batch_images, _ = next(iter(train_dataloader))
    visualize_single_batch(single_batch_images)
    accelerator.print(f"Using single batch with {single_batch_images.size(0)} images")

    # Create optimizers
    optimizer_generator = torch.optim.Adam(generator.parameters(), lr=0.0002, betas=(0.5, 0.999))
    optimizer_discriminator = torch.optim.Adam(discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999))
    # Optimizers must be prepared as well so that they cooperate with
    # accelerator.backward (e.g. gradient scaling under mixed precision).
    optimizer_generator, optimizer_discriminator = accelerator.prepare(
        optimizer_generator, optimizer_discriminator
    )

    # Train on the same batch repeatedly
    for iteration in range(iterations):
        # Train discriminator on the same batch
        discriminator_loss = train_discriminator(
            generator, discriminator, single_batch_images,
            optimizer_discriminator, accelerator
        )

        # Train generator on the same batch
        generator_loss = train_generator(
            generator, discriminator, single_batch_images,
            optimizer_generator, accelerator
        )

        # Print progress every 50 iterations
        if (iteration + 1) % 50 == 0:
            accelerator.print(f'Iteration {iteration+1:3d}: Gen Loss = {generator_loss:.4f}, Disc Loss = {discriminator_loss:.4f}')

        # Save samples every 100 iterations
        if (iteration + 1) % 100 == 0:
            save_generated_samples(generator, f'overfit_iter_{iteration+1}', accelerator)

    accelerator.print("Overfitting test completed! Check if losses decreased and images improved.")
    accelerator.print("If GAN can overfit one batch, it should work on full dataset.")


In [9]:
def train_gan(generator, discriminator, train_dataloader, accelerator, epochs=50):
    """Main training loop for the GAN

    For every batch the discriminator is updated first, then the generator.
    Average losses are printed once per epoch and sample images are saved
    every 10 epochs.

    Args:
        generator: Generator model (exposes ``noise_dimension``).
        discriminator: Discriminator model.
        train_dataloader: Loader yielding ``(images, labels)`` batches.
        accelerator: Accelerate helper used for device placement, backward
            passes and printing.
        epochs: Number of passes over ``train_dataloader``.

    Raises:
        ValueError: If ``train_dataloader`` yields no batches in an epoch.
    """
    # Create optimizers
    optimizer_generator = torch.optim.Adam(generator.parameters(), lr=0.0002, betas=(0.5, 0.999))
    optimizer_discriminator = torch.optim.Adam(discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999))
    # Optimizers must be prepared as well so that they cooperate with
    # accelerator.backward (e.g. gradient scaling under mixed precision).
    optimizer_generator, optimizer_discriminator = accelerator.prepare(
        optimizer_generator, optimizer_discriminator
    )

    # Training loop
    for epoch in range(epochs):
        total_generator_loss = 0
        total_discriminator_loss = 0
        num_batches = 0

        for real_images_batch, _ in train_dataloader:  # We don't need labels for GAN

            # Train discriminator first
            discriminator_loss = train_discriminator(
                generator, discriminator, real_images_batch,
                optimizer_discriminator, accelerator
            )

            # Train generator second
            generator_loss = train_generator(
                generator, discriminator, real_images_batch,
                optimizer_generator, accelerator
            )

            total_generator_loss += generator_loss
            total_discriminator_loss += discriminator_loss
            num_batches += 1

        # An empty loader would otherwise surface as an opaque ZeroDivisionError.
        if num_batches == 0:
            raise ValueError("train_dataloader yielded no batches; cannot compute average losses")

        # Print progress
        avg_gen_loss = total_generator_loss / num_batches
        avg_disc_loss = total_discriminator_loss / num_batches
        accelerator.print(f'Epoch {epoch+1:3d}: Gen Loss = {avg_gen_loss:.4f}, Disc Loss = {avg_disc_loss:.4f}')

        # Save samples every 10 epochs
        if (epoch + 1) % 10 == 0:
            save_generated_samples(generator, epoch + 1, accelerator)


In [10]:
def main():
    """Entry point: build the data pipeline and models, then train the GAN.

    Creates an ``Accelerator``, loads MNIST in batches of 128, builds the
    generator (noise dimension 100) and discriminator, hands all three to
    ``accelerator.prepare`` and runs full-dataset training for 100 epochs.
    """
    accelerator = Accelerator()
    accelerator.print("Starting Simple Educational GAN Training")

    # Data
    train_dataloader = load_mnist_dataset(batch_size=128)
    accelerator.print(f"Loaded MNIST dataset with {len(train_dataloader)} batches")

    # Models
    generator = SimpleGenerator(100)
    discriminator = SimpleDiscriminator()

    # Let Accelerate handle device placement for models and data
    generator, discriminator, train_dataloader = accelerator.prepare(
        generator, discriminator, train_dataloader
    )

    generator_param_count = sum(p.numel() for p in generator.parameters())
    discriminator_param_count = sum(p.numel() for p in discriminator.parameters())
    accelerator.print(f'Generator parameters: {generator_param_count:,}')
    accelerator.print(f'Discriminator parameters: {discriminator_param_count:,}')

    # The menu is informational only; the mode is selected by editing the
    # code below, not at runtime.
    for menu_line in (
        "\nChoose training mode:",
        "1. Overfit single batch (testing mode)",
        "2. Full dataset training",
    ):
        accelerator.print(menu_line)

    # Mode 1 (disabled): sanity-check by overfitting a single batch.
    # overfit_single_batch_gan(generator, discriminator, train_dataloader, accelerator, iterations=1000)

    # Mode 2 (active): full dataset training.
    train_gan(generator, discriminator, train_dataloader, accelerator, epochs=100)

    accelerator.print("Training completed! Check 'generated_images' folder for results.")


In [13]:
def visualize_single_batch(single_batch_images, n=8):
    """Show up to ``n * n`` images from a batch in an ``n`` x ``n`` grid.

    Args:
        single_batch_images: Batch of images indexed along the first
            dimension; each image is squeezed before display.
        n: Grid side length. If the batch holds fewer than ``n * n`` images,
            the remaining grid cells are left empty.
    """
    # Never index past the end of the batch (e.g. a short final batch).
    num_to_show = min(n * n, len(single_batch_images))

    plt.figure(figsize=(8, 8))
    for i in range(num_to_show):
        plt.subplot(n, n, i + 1)
        plt.imshow(single_batch_images[i].squeeze().cpu(), cmap='gray')
        plt.axis('off')
    plt.show()


In [18]:
# Run the full pipeline only when this code is executed directly
# (not when it is imported as a module).
if __name__ == "__main__":
    main()

Starting Simple Educational GAN Training
Loaded MNIST dataset with 469 batches
Created Generator: 100 -> 256 -> 512 -> 784
Created Discriminator :784 -> 512 -> 256 -> 1
Generator parameters: 1,489,936
Discriminator parameters: 1,460,225

Choose training mode:
1. Overfit single batch (testing mode)
2. Full dataset training
Epoch   1: Gen Loss = 0.9334, Disc Loss = 0.6339
Epoch   2: Gen Loss = 1.1016, Disc Loss = 0.5968
Epoch   3: Gen Loss = 1.1846, Disc Loss = 0.5726
Epoch   4: Gen Loss = 1.2016, Disc Loss = 0.5529
Epoch   5: Gen Loss = 1.1736, Disc Loss = 0.5592
Epoch   6: Gen Loss = 1.1174, Disc Loss = 0.5726
Epoch   7: Gen Loss = 1.0612, Disc Loss = 0.5856
Epoch   8: Gen Loss = 1.0470, Disc Loss = 0.5906
Epoch   9: Gen Loss = 1.0261, Disc Loss = 0.5966
Epoch  10: Gen Loss = 1.0095, Disc Loss = 0.6036
Epoch  11: Gen Loss = 1.0011, Disc Loss = 0.6046
Epoch  12: Gen Loss = 0.9838, Disc Loss = 0.6101
Epoch  13: Gen Loss = 0.9761, Disc Loss = 0.6140
Epoch  14: Gen Loss = 0.9645, Disc Loss