In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np
from IPython.display import clear_output

In [2]:
# Seed torch's global RNG first so weight initialisation, noise sampling and
# data shuffling are repeatable from a fresh kernel.
torch.manual_seed(42)

# --- Model dimensions ---
latent_dim = 100        # length of the noise vector fed to the generator
hidden_dim = 256        # width of every hidden layer in both networks
image_dim = 28 * 28     # flattened 28x28 image (784 values)

# --- Optimisation settings ---
batch_size = 64
num_epochs = 100
lr = 2e-4

In [3]:
# Generator Network
class Generator(nn.Module):
    """Fully connected generator: latent vector in, flattened image out.

    Layer widths are read from the module-level ``latent_dim``,
    ``hidden_dim`` and ``image_dim``. The closing Tanh bounds every output
    value to [-1, 1].
    """

    def __init__(self):
        super().__init__()
        # Layers are listed input-to-output. Keep this order: it fixes the
        # indices inside ``self.model`` and the order in which the Linear
        # layers are constructed.
        stack = [
            nn.Linear(latent_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, image_dim),
            nn.Tanh(),
        ]
        self.model = nn.Sequential(*stack)

    def forward(self, z):
        """Run the batch of latent vectors ``z`` through the layer stack."""
        return self.model(z)

# Discriminator Network
class Discriminator(nn.Module):
    """Fully connected discriminator: flattened image in, one score out.

    Layer widths are read from the module-level ``image_dim`` and
    ``hidden_dim``. The closing Sigmoid keeps the single output in [0, 1].
    """

    def __init__(self):
        super().__init__()
        # Layers are listed input-to-output. Keep this order: it fixes the
        # indices inside ``self.model`` and the order in which the Linear
        # layers are constructed.
        stack = [
            nn.Linear(image_dim, hidden_dim),
            nn.LeakyReLU(0.2),
            nn.Linear(hidden_dim, hidden_dim),
            nn.LeakyReLU(0.2),
            nn.Linear(hidden_dim, 1),
            nn.Sigmoid(),
        ]
        self.model = nn.Sequential(*stack)

    def forward(self, x):
        """Score the batch of flattened images ``x``."""
        return self.model(x)

In [4]:
# Function to generate and save sample images
def save_samples(generator, fixed_noise, epoch, fig_size=(10, 10)):
    """Render a 4x4 grid of generator outputs and return the figure.

    Despite the name, nothing is written to disk; the figure object itself
    is the result.

    Args:
        generator: Callable that maps ``fixed_noise`` to a tensor whose
            elements can be reshaped to ``(-1, 28, 28)``.
        fixed_noise: Noise batch passed to ``generator``. At most the first
            16 generated images are shown; smaller batches leave the
            remaining grid cells empty instead of raising ``IndexError``.
        epoch: Epoch number shown in the figure title.
        fig_size: Figure size in inches, forwarded to matplotlib.

    Returns:
        The matplotlib figure holding the grid. Its canvas has already been
        drawn, so callers that read the rendered pixels do not depend on a
        backend having displayed the figure first.
    """
    max_panels = 16  # capacity of the 4x4 grid

    # Sampling is for inspection only, so no autograd graph is needed.
    with torch.no_grad():
        samples = generator(fixed_noise).view(-1, 28, 28).cpu().numpy()

    # plt.figure keeps the new figure as pyplot's current figure.
    fig = plt.figure(figsize=fig_size)
    for i in range(min(max_panels, len(samples))):
        ax = fig.add_subplot(4, 4, i + 1)
        ax.imshow(samples[i], cmap='gray')
        ax.axis('off')
    fig.suptitle(f'Generated Images at Epoch {epoch}', fontsize=20)
    fig.tight_layout()

    # Render now: the figure is later consumed through its canvas buffer,
    # which stays empty until the canvas has been drawn.
    fig.canvas.draw()
    return fig

# Load the MNIST training split.
# Normalize with mean 0.5 / std 0.5 rescales the ToTensor output so real
# images share the [-1, 1] range of the generator's Tanh output.
transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))]
)

mnist = torchvision.datasets.MNIST(
    root='./data',
    train=True,
    transform=transform,
    download=True,
)

# Reshuffled every epoch; batch size comes from the hyperparameter cell.
dataloader = DataLoader(mnist, shuffle=True, batch_size=batch_size)

In [5]:
# Initialize networks and optimizers.
# GPU 3 is the preferred device, but only when the machine actually exposes
# that many GPUs; otherwise fall back to the first GPU, or to the CPU when
# CUDA is unavailable. A bare "cuda:3" fails on hosts with fewer devices.
preferred_gpu_index = 3
if torch.cuda.is_available():
    gpu_index = preferred_gpu_index if preferred_gpu_index < torch.cuda.device_count() else 0
    device = torch.device(f"cuda:{gpu_index}")
else:
    device = torch.device("cpu")

generator = Generator().to(device)
discriminator = Discriminator().to(device)
g_optimizer = optim.Adam(generator.parameters(), lr=lr)
d_optimizer = optim.Adam(discriminator.parameters(), lr=lr)
criterion = nn.BCELoss()

# Fixed noise for visualization: the same 16 latent vectors are reused at
# every checkpoint so the sample grids are comparable across epochs.
fixed_noise = torch.randn(16, latent_dim, device=device)

# Training history, filled in by train_and_visualize()
d_losses = []           # mean discriminator loss per epoch
g_losses = []           # mean generator loss per epoch
evolution_images = []   # one sample figure per checkpoint epoch
checkpoints = [1, 10, 25, 50, 75, 100]  # Epochs (1-based) to visualize

In [6]:
def train_and_visualize():
    """Run the adversarial training loop and record per-epoch diagnostics.

    Operates on module-level state: ``generator`` / ``discriminator`` with
    their optimizers, ``criterion``, ``dataloader``, ``device`` and
    ``fixed_noise``. After every epoch the mean discriminator and generator
    losses are appended to ``d_losses`` and ``g_losses``. On epochs listed
    in ``checkpoints`` a sample figure is appended to ``evolution_images``
    and a progress line is printed.

    The history lists are only ever appended to, so a second call in the
    same session extends the earlier results rather than replacing them.
    """
    for epoch in range(num_epochs):
        d_running = 0
        g_running = 0
        n_batches = 0

        for real_batch, _ in dataloader:
            # Use the actual batch length rather than the configured size.
            n = real_batch.size(0)
            real_flat = real_batch.view(n, -1).to(device)
            target_real = torch.ones(n, 1, device=device)
            target_fake = torch.zeros(n, 1, device=device)

            # ---- Discriminator update ----
            discriminator.zero_grad()
            loss_on_real = criterion(discriminator(real_flat), target_real)

            noise = torch.randn(n, latent_dim, device=device)
            generated = generator(noise)
            # detach(): this step must not push gradients into the generator.
            loss_on_fake = criterion(discriminator(generated.detach()), target_fake)

            d_loss = loss_on_real + loss_on_fake
            d_loss.backward()
            d_optimizer.step()

            # ---- Generator update ----
            # The generator is rewarded when its samples are scored as real.
            generator.zero_grad()
            g_loss = criterion(discriminator(generated), target_real)
            g_loss.backward()
            g_optimizer.step()

            d_running += d_loss.item()
            g_running += g_loss.item()
            n_batches += 1

        # Mean loss over the epoch's batches
        d_losses.append(d_running / n_batches)
        g_losses.append(g_running / n_batches)

        # Only checkpoint epochs produce a figure and a log line.
        if (epoch + 1) not in checkpoints:
            continue
        evolution_images.append(save_samples(generator, fixed_noise, epoch + 1))
        print(f'Epoch [{epoch+1}/{num_epochs}], d_loss: {d_losses[-1]:.4f}, g_loss: {g_losses[-1]:.4f}')

In [None]:
# Train the model. This fills d_losses, g_losses and evolution_images,
# which the plots below consume.
train_and_visualize()

# Plot loss curves (one point per epoch)
loss_fig, loss_ax = plt.subplots(figsize=(10, 5))
loss_ax.plot(d_losses, label='Discriminator Loss')
loss_ax.plot(g_losses, label='Generator Loss')
loss_ax.set_xlabel('Epoch')
loss_ax.set_ylabel('Loss')
loss_ax.set_title('Training Losses')
loss_ax.legend()
loss_ax.grid(True)
plt.show()

# Display evolution grid: each checkpoint figure is rasterised and shown as
# one tile. The row count follows the number of figures (never fewer than
# the original two rows), so extra checkpoints do not overflow the grid.
grid_cols = 3
grid_rows = max(2, -(-len(evolution_images) // grid_cols))  # ceiling division
grid_fig = plt.figure(figsize=(20, 6 * grid_rows))
for i, img in enumerate(evolution_images):
    # Draw explicitly so the pixel buffer is populated regardless of whether
    # a backend has already rendered this figure.
    img.canvas.draw()
    tile_ax = grid_fig.add_subplot(grid_rows, grid_cols, i + 1)
    tile_ax.imshow(np.asarray(img.canvas.buffer_rgba()))
    tile_ax.axis('off')
grid_fig.suptitle('Evolution of Generated Images During Training', fontsize=16)
grid_fig.tight_layout()
plt.show()