<a href="https://colab.research.google.com/github/MrCelestial/cGAN_MNIST/blob/main/cGAN_MNIST.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [12]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
import numpy as np
import ipywidgets as widgets
from IPython.display import display, clear_output, Image
import imageio
import os
from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas

# Set random seed for reproducibility. torch.manual_seed seeds the generators
# on all devices (CPU and CUDA). The trailing ';' stops Jupyter from echoing
# the returned torch.Generator object as cell output.
SEED = 42
torch.manual_seed(SEED);

<torch._C.Generator at 0x7f2c13f4d9d0>

## Generator and Discriminator networks

In [2]:
# Define the Generator network
class Generator(nn.Module):
    """Conditional MLP generator: (noise, digit label) -> 1x28x28 image.

    The digit label is looked up in a 10-entry table of learned 10-d vectors
    and appended to the 100-d noise vector, giving the 110-d network input.
    """

    def __init__(self):
        super().__init__()
        # One learned 10-d embedding per digit class 0-9.
        self.label_emb = nn.Embedding(10, 10)

        # Hidden widths 110 -> 256 -> 512 -> 1024, each followed by LeakyReLU(0.2).
        # The layer order fixes the state_dict keys (model.0/2/4/6 for the
        # Linear layers), which saved checkpoints depend on.
        widths = [110, 256, 512, 1024]
        stack = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            stack.append(nn.Linear(fan_in, fan_out))
            stack.append(nn.LeakyReLU(0.2, inplace=True))
        # Project to 28 * 28 = 784 pixels and squash into (-1, 1) with Tanh.
        stack.append(nn.Linear(1024, 784))
        stack.append(nn.Tanh())
        self.model = nn.Sequential(*stack)

    def forward(self, z, labels):
        """Generate images for noise ``z`` conditioned on digit ``labels``.

        Args:
            z: noise tensor; its width plus the 10-d label embedding must be
                110, i.e. shape (batch, 100) -- assumed float.
            labels: integer digit indices 0-9, shape (batch,).

        Returns:
            Tensor of shape (batch, 1, 28, 28) with values in (-1, 1).
        """
        conditioned_input = torch.cat((z, self.label_emb(labels)), dim=1)
        return self.model(conditioned_input).view(-1, 1, 28, 28)

# Define the Discriminator network
class Discriminator(nn.Module):
    """Conditional MLP discriminator: (image, digit label) -> probability in (0, 1).

    The image is flattened to 784 values and concatenated with a learned 10-d
    label embedding (794 inputs). Three hidden Linear layers (1024, 512, 256
    units), each followed by LeakyReLU(0.2) and Dropout(0.3), feed a single
    Sigmoid output unit.
    """

    def __init__(self):
        super().__init__()
        # One learned 10-d embedding per digit class 0-9.
        self.label_emb = nn.Embedding(10, 10)

        # The layer order fixes the state_dict keys (model.0/3/6/9 for the
        # Linear layers); keep it stable.
        sizes = (794, 1024, 512, 256)
        blocks = []
        for n_in, n_out in zip(sizes, sizes[1:]):
            blocks.extend([
                nn.Linear(n_in, n_out),
                nn.LeakyReLU(0.2, inplace=True),
                nn.Dropout(0.3),
            ])
        blocks.extend([nn.Linear(256, 1), nn.Sigmoid()])
        self.model = nn.Sequential(*blocks)

    def forward(self, x, labels):
        """Score images ``x`` as real/fake given their digit ``labels``.

        Args:
            x: image tensor with 784 elements per sample (e.g. (batch, 1, 28, 28));
                it is flattened with ``view``, so it must be view-compatible.
            labels: integer digit indices 0-9, shape (batch,).

        Returns:
            Tensor of shape (batch, 1) with probabilities in (0, 1).
        """
        flat_pixels = x.view(-1, 784)
        joint = torch.cat((flat_pixels, self.label_emb(labels)), dim=1)
        return self.model(joint)

# Folder that train_gan() writes its per-epoch progress frames into
# (exist_ok=True makes re-running this cell a no-op).
os.makedirs(name='progress_images', exist_ok=True)

## Training

In [3]:
# Function to train the GAN and create a progress GIF
def train_gan(generator, discriminator, dataloader, num_epochs=10, log_every=1,
              image_dir='progress_images', gif_path='digit_progress.gif',
              model_path='mnist_generator.pth'):
    """Train a conditional GAN and record a per-epoch progress GIF.

    Every ``log_every`` epochs (and after the final epoch) the losses of the
    epoch's last batch are printed. A row of ten samples is then shown and
    saved as a PNG in ``image_dir``: one sample per digit 0-9, all drawn from
    the same fixed noise so successive frames are comparable. After training,
    the PNGs are combined into an animated GIF at ``gif_path``, and the
    generator's ``state_dict`` is saved to ``model_path``.

    Args:
        generator: conditional generator called as ``generator(noise, labels)``
            with ``noise`` of shape (batch, 100).
        discriminator: conditional discriminator called as
            ``discriminator(images, labels)``; its output is scored with
            ``nn.BCELoss`` against (batch, 1) targets, so it must return
            probabilities of that shape.
        dataloader: iterable of ``(images, labels)`` batches. Presumably MNIST
            images normalized to [-1, 1] to match a Tanh-output generator --
            TODO confirm for other datasets.
        num_epochs: number of passes over ``dataloader``.
        log_every: print, sample and save a frame every this many epochs.
        image_dir: folder for the per-epoch PNG frames (created if missing).
        gif_path: output path of the animated progress GIF.
        model_path: output path of the generator checkpoint.

    Returns:
        The trained generator (left on the training device).

    Raises:
        ValueError: if ``log_every`` < 1 or ``dataloader`` yields no batches.
    """
    if log_every < 1:
        raise ValueError(f"log_every must be >= 1, got {log_every!r}")

    # Loss function and optimizers (standard DCGAN-style Adam settings).
    criterion = nn.BCELoss()
    optimizer_g = optim.Adam(generator.parameters(), lr=0.0002, betas=(0.5, 0.999))
    optimizer_d = optim.Adam(discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999))

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    generator.to(device)
    discriminator.to(device)
    # Callers may pass models left in eval mode (e.g. after sampling); the
    # discriminator's Dropout layers must be active during training.
    generator.train()
    discriminator.train()

    # Frames are written here; don't rely on another cell having created it.
    os.makedirs(image_dir, exist_ok=True)

    # Fixed noise/labels so every saved frame shows the same latent points.
    fixed_noise = torch.randn(10, 100, device=device)
    fixed_labels = torch.arange(10, device=device)

    progress_images = []
    d_loss = g_loss = None

    for epoch in range(num_epochs):
        for real_images, labels in dataloader:
            batch_size = real_images.size(0)

            # BCE targets: 1 = "real", 0 = "fake" (not the digit labels).
            real_targets = torch.ones(batch_size, 1, device=device)
            fake_targets = torch.zeros(batch_size, 1, device=device)

            real_images = real_images.to(device)
            labels = labels.to(device)

            # ----- Discriminator step -----
            optimizer_d.zero_grad()
            d_loss_real = criterion(discriminator(real_images, labels), real_targets)

            noise = torch.randn(batch_size, 100, device=device)
            fake_images = generator(noise, labels)
            # detach(): the discriminator step must not backprop into the generator.
            d_loss_fake = criterion(discriminator(fake_images.detach(), labels), fake_targets)

            d_loss = d_loss_real + d_loss_fake
            d_loss.backward()
            optimizer_d.step()

            # ----- Generator step -----
            # The generator wants the discriminator to call its fakes "real".
            optimizer_g.zero_grad()
            g_loss = criterion(discriminator(fake_images, labels), real_targets)
            g_loss.backward()
            optimizer_g.step()

        if d_loss is None:
            raise ValueError("dataloader yielded no batches; nothing to train on")

        if (epoch + 1) % log_every != 0 and epoch + 1 != num_epochs:
            continue

        clear_output(wait=True)
        print(f"Epoch [{epoch+1}/{num_epochs}] | D Loss: {d_loss.item():.4f} | G Loss: {g_loss.item():.4f}")

        with torch.no_grad():
            generated = generator(fixed_noise, fixed_labels).cpu()

        fig, axs = plt.subplots(1, 10, figsize=(12, 1.2))
        for digit, ax in enumerate(axs):
            ax.imshow(generated[digit].squeeze().numpy(), cmap='gray')
            ax.set_title(f"Digit: {digit}")
            ax.axis('off')
        fig.tight_layout()

        # savefig renders the figure itself; no separate canvas draw is needed.
        img_path = os.path.join(image_dir, f'epoch_{epoch+1:03d}.png')
        fig.savefig(img_path)
        plt.show()
        # Close explicitly so figures don't pile up under non-inline backends.
        plt.close(fig)

        progress_images.append(imageio.imread(img_path))

    # mimsave cannot write an empty animation (e.g. num_epochs == 0).
    if progress_images:
        imageio.mimsave(gif_path, progress_images, fps=2)
        display(Image(gif_path))

    torch.save(generator.state_dict(), model_path)
    print("Training complete and model saved!")
    return generator


## Training-progress GIF for a single digit

In [4]:
# Function to create a progress GIF for a specific digit
def create_digit_progress_gif(digit, num_epochs=15, batch_size=32):
    """Train a fresh conditional GAN on one MNIST digit and animate its progress.

    A new Generator/Discriminator pair is trained only on MNIST training images
    of ``digit``, always conditioned on that label. After every epoch, one
    sample generated from a fixed noise vector is saved as a PNG under
    ``progress_images_digit_<digit>/``. On the first epoch and every third one,
    the sample is also shown inline. The frames are combined into
    ``digit_<digit>_progress.gif``, which is displayed at the end.

    Args:
        digit: the digit class to train on, 0-9.
        num_epochs: number of passes over the single-digit subset.
        batch_size: mini-batch size of the subset loader.

    Returns:
        The trained Generator (left on the training device).

    Raises:
        ValueError: if ``digit`` is not between 0 and 9.
    """
    if not 0 <= digit <= 9:
        raise ValueError(f"digit must be between 0 and 9, got {digit!r}")

    generator = Generator()
    discriminator = Discriminator()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    generator.to(device)
    discriminator.to(device)

    # Loss function and optimizers (same settings as train_gan).
    criterion = nn.BCELoss()
    optimizer_g = optim.Adam(generator.parameters(), lr=0.0002, betas=(0.5, 0.999))
    optimizer_d = optim.Adam(discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999))

    # Fixed noise and label so every frame shows the same latent point.
    fixed_noise = torch.randn(1, 100, device=device)
    fixed_label = torch.tensor([digit], device=device)

    # MNIST scaled to [-1, 1] to match the generator's Tanh output.
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,))
    ])
    dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)

    # Select the digit via the stored label tensor. Iterating `dataset` instead
    # would decode and transform every training image just to read its label.
    digit_indices = torch.where(dataset.targets == digit)[0].tolist()
    digit_subset = torch.utils.data.Subset(dataset, digit_indices)
    dataloader = DataLoader(digit_subset, batch_size=batch_size, shuffle=True)

    frame_dir = f'progress_images_digit_{digit}'
    os.makedirs(frame_dir, exist_ok=True)
    progress_images = []

    for epoch in range(num_epochs):
        for real_images, _ in dataloader:
            batch_size_now = real_images.size(0)

            # BCE targets: 1 = "real", 0 = "fake" (not the digit labels).
            real_targets = torch.ones(batch_size_now, 1, device=device)
            fake_targets = torch.zeros(batch_size_now, 1, device=device)

            real_images = real_images.to(device)
            # Every sample in the subset has the same class, so condition on it.
            labels = torch.full((batch_size_now,), digit, dtype=torch.long, device=device)

            # ----- Discriminator step -----
            optimizer_d.zero_grad()
            d_loss_real = criterion(discriminator(real_images, labels), real_targets)

            noise = torch.randn(batch_size_now, 100, device=device)
            fake_images = generator(noise, labels)
            # detach(): the discriminator step must not backprop into the generator.
            d_loss_fake = criterion(discriminator(fake_images.detach(), labels), fake_targets)

            d_loss = d_loss_real + d_loss_fake
            d_loss.backward()
            optimizer_d.step()

            # ----- Generator step -----
            optimizer_g.zero_grad()
            g_loss = criterion(discriminator(fake_images, labels), real_targets)
            g_loss.backward()
            optimizer_g.step()

        # Sample with the current model state.
        with torch.no_grad():
            generated = generator(fixed_noise, fixed_label).cpu().squeeze().numpy()

        # Build the frame once; it is both saved and (sometimes) shown.
        fig, ax = plt.subplots(figsize=(3, 3))
        ax.imshow(generated, cmap='gray')
        ax.set_title(f"Digit {digit} - Epoch {epoch+1}")
        ax.axis('off')

        img_path = os.path.join(frame_dir, f'epoch_{epoch+1:03d}.png')
        fig.savefig(img_path)
        progress_images.append(imageio.imread(img_path))

        if (epoch + 1) % 3 == 0 or epoch == 0:
            clear_output(wait=True)
            print(f"Training for digit {digit}: Epoch {epoch+1}/{num_epochs}")
            plt.show()
        plt.close(fig)

    gif_path = f'digit_{digit}_progress.gif'
    # mimsave cannot write an empty animation (e.g. num_epochs == 0).
    if progress_images:
        imageio.mimsave(gif_path, progress_images, fps=2)
        display(Image(gif_path))

    print(f"Training complete for digit {digit}!")
    return generator


## Load the data and generate digit images interactively

In [14]:
# Function to load and preprocess MNIST data
def load_mnist_data(batch_size=64, root='./data', train=True):
    """Return a shuffled DataLoader over MNIST, with pixels scaled to [-1, 1].

    ``ToTensor`` scales pixels to [0, 1]. ``Normalize((0.5,), (0.5,))`` then
    maps them to [-1, 1], the same range as the Generator's Tanh output. The
    dataset is downloaded into ``root`` if it is not already there.

    Args:
        batch_size: samples per batch (default 64, as before).
        root: directory holding (or receiving) the MNIST files.
        train: load the training split (True) or the test split (False).

    Returns:
        A ``DataLoader`` yielding ``(images, labels)`` batches.
    """
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,)),
    ])
    dataset = datasets.MNIST(root=root, train=train, download=True, transform=transform)
    return DataLoader(dataset, batch_size=batch_size, shuffle=True)

# Function to generate images based on user input
def generate_digit_image(generator, digit):
    """Sample one image of ``digit`` from ``generator`` and display it.

    The generator is moved to CUDA when available (the move persists). It is
    run in eval mode without gradients, and its previous train/eval mode is
    restored afterwards so the caller's model is not silently left in eval mode.

    Args:
        generator: conditional generator called as ``generator(noise, labels)``
            with noise of shape (1, 100).
        digit: digit class 0-9 to condition on.

    Raises:
        ValueError: if ``digit`` is not between 0 and 9.
    """
    if not 0 <= digit <= 9:
        raise ValueError(f"digit must be between 0 and 9, got {digit!r}")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    generator.to(device)

    was_training = generator.training
    generator.eval()
    try:
        with torch.no_grad():
            noise = torch.randn(1, 100, device=device)
            label = torch.tensor([digit], device=device)
            img = generator(noise, label).cpu().squeeze().numpy()
    finally:
        # Restore whatever mode the caller had the model in.
        generator.train(was_training)

    fig, ax = plt.subplots(figsize=(3, 3))
    ax.imshow(img, cmap='gray')
    ax.set_title(f"Generated Digit: {digit}")
    ax.axis('off')
    plt.show()
    plt.close(fig)



In [15]:
# Main function to handle user input and generation
def interactive_digit_generation():
    """Show a widget UI that generates MNIST-style digits on demand.

    Loads generator weights from ``mnist_generator.pth`` if that file exists.
    Otherwise it trains a new model on all of MNIST for 15 epochs via
    ``train_gan``. After a digit is generated, Yes/No buttons offer to train a
    single-digit model and show its progress GIF. All callback output is routed
    into one ``Output`` widget.

    A checkpoint that exists but cannot be loaded (e.g. its architecture
    mismatches) raises instead of being silently replaced by a fresh training run.
    """
    checkpoint = 'mnist_generator.pth'
    generator = Generator()
    if os.path.isfile(checkpoint):
        # map_location='cpu': train_gan moves the model to CUDA when available,
        # so the checkpoint may hold CUDA tensors that fail to load on a CPU-only
        # machine. generate_digit_image moves the model to the right device later.
        generator.load_state_dict(torch.load(checkpoint, map_location='cpu'))
        print("Pre-trained model loaded successfully!")
    else:
        print("No pre-trained model found. Training a new model for all digits...")
        dataloader = load_mnist_data()
        generator = train_gan(generator, Discriminator(), dataloader, num_epochs=15)

    text_input = widgets.Text(description="Enter a digit (0-9):", placeholder="Enter a digit")
    submit_button = widgets.Button(description="Generate")
    # Widget callbacks run outside any cell, so their output must be captured
    # by an Output widget; otherwise it may not be displayed (e.g. JupyterLab).
    output = widgets.Output()

    def generate_callback(digit_str):
        """Validate the text input, show a sample, and offer a progress GIF."""
        # Only the parse is guarded, so errors raised during generation are
        # not misreported as bad input.
        try:
            digit = int(digit_str)
        except ValueError:
            print("Please enter a valid integer.")
            return
        if not 0 <= digit <= 9:
            print("Please enter a digit between 0 and 9.")
            return

        print(f"Generating image for digit {digit}...")
        generate_digit_image(generator, digit)

        yes_button = widgets.Button(description="Yes")
        no_button = widgets.Button(description="No")
        print("Would you like to see a GIF of training progress for this digit?")

        def disable_choice():
            # Prevent a second click from starting another (long) training run.
            yes_button.disabled = True
            no_button.disabled = True

        def on_yes_click(b):
            disable_choice()
            with output:
                clear_output(wait=True)
                print(f"Creating training progress GIF for digit {digit}...")
                create_digit_progress_gif(digit, num_epochs=15)

        def on_no_click(b):
            disable_choice()
            with output:
                clear_output(wait=True)
                print("You can enter another digit to generate.")

        yes_button.on_click(on_yes_click)
        no_button.on_click(on_no_click)
        display(widgets.HBox([yes_button, no_button]))

    def on_submit_click(b):
        with output:
            clear_output()
            generate_callback(text_input.value)

    submit_button.on_click(on_submit_click)

    display(widgets.VBox([widgets.HBox([text_input, submit_button]), output]))

# Run the interactive application: loads the saved generator from
# 'mnist_generator.pth' (training a new one when no saved model is available)
# and displays the digit-entry widgets.
interactive_digit_generation()

Pre-trained model loaded successfully!


VBox(children=(HBox(children=(Text(value='', description='Enter a digit (0-9):', placeholder='Enter a digit'),…