In [1]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np
from torch.autograd import Variable, grad

# Create folder for generated images
os.makedirs("generated_images", exist_ok=True)

# Transform to normalize the data between -1 and 1
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

# Load MNIST dataset
trainset = torchvision.datasets.MNIST(
    root='./data',
    train=True,
    download=True,
    transform=transform
)
trainloader = torch.utils.data.DataLoader(
    trainset, batch_size=64, shuffle=True, num_workers=2
)

# Set device
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Generator with convolutional layers
class Generator(nn.Module):
    def __init__(self, latent_dim=100):
        super(Generator, self).__init__()
        
        self.init_size = 7  # Initial size before upsampling
        self.l1 = nn.Sequential(
            nn.Linear(latent_dim, 128 * self.init_size ** 2)
        )

        self.conv_blocks = nn.Sequential(
            nn.BatchNorm2d(128),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(128, 128, 3, stride=1, padding=1),
            nn.BatchNorm2d(128, 0.8),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(128, 64, 3, stride=1, padding=1),
            nn.BatchNorm2d(64, 0.8),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, 1, 3, stride=1, padding=1),
            nn.Tanh()
        )

    def forward(self, z):
        out = self.l1(z)
        out = out.view(out.shape[0], 128, self.init_size, self.init_size)
        img = self.conv_blocks(out)
        return img

# Discriminator (WGAN Critic) with convolutional layers
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()

        def discriminator_block(in_filters, out_filters, bn=True):
            block = [nn.Conv2d(in_filters, out_filters, 3, 2, 1), 
                    nn.LeakyReLU(0.2, inplace=True),
                    nn.Dropout2d(0.25)]
            if bn:
                block.append(nn.BatchNorm2d(out_filters, 0.8))
            return block

        self.model = nn.Sequential(
            *discriminator_block(1, 16, bn=False),
            *discriminator_block(16, 32),
            *discriminator_block(32, 64),
            *discriminator_block(64, 128),
        )

        # The height and width of downsampled image
        ds_size = 2
        self.adv_layer = nn.Linear(128 * ds_size ** 2, 1)

    def forward(self, img):
        out = self.model(img)
        out = out.view(out.shape[0], -1)
        validity = self.adv_layer(out)
        return validity

# Compute gradient penalty
def compute_gradient_penalty(D, real_samples, fake_samples):
    # Random weight term for interpolation between real and fake samples
    alpha = torch.rand(real_samples.size(0), 1, 1, 1, device=device)
    # Get random interpolation between real and fake samples
    interpolates = (alpha * real_samples + ((1 - alpha) * fake_samples)).requires_grad_(True)
    d_interpolates = D(interpolates)
    fake = torch.ones(real_samples.size(0), 1, device=device)
    
    # Get gradient w.r.t. interpolates
    gradients = grad(
        outputs=d_interpolates,
        inputs=interpolates,
        grad_outputs=fake,
        create_graph=True,
        retain_graph=True,
        only_inputs=True,
    )[0]
    gradients = gradients.view(gradients.size(0), -1)
    gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean()
    return gradient_penalty

# Initialize models
latent_dim = 100
generator = Generator(latent_dim).to(device)
discriminator = Discriminator().to(device)

# Initialize weights
def weights_init_normal(m):
    classname = m.__class__.__name__
    if classname.find("Conv") != -1:
        torch.nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find("BatchNorm2d") != -1:
        torch.nn.init.normal_(m.weight.data, 1.0, 0.02)
        torch.nn.init.constant_(m.bias.data, 0.0)
    elif classname.find("Linear") != -1:
        torch.nn.init.normal_(m.weight.data, 0.0, 0.02)
        torch.nn.init.constant_(m.bias.data, 0.0)

generator.apply(weights_init_normal)
discriminator.apply(weights_init_normal)

# Optimizers
optimizer_G = optim.Adam(generator.parameters(), lr=0.0002, betas=(0.5, 0.999))
optimizer_D = optim.Adam(discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999))

# WGAN-GP hyperparameters
lambda_gp = 10  # Gradient penalty lambda
n_critic = 5   # Number of critic iterations per generator iteration
num_epochs = 100
fixed_noise = torch.randn(64, latent_dim, device=device)
G_losses = []
D_losses = []
Wasserstein_distances = []
img_list = []

print("Starting Training Loop...")
for epoch in range(num_epochs):
    for i, (imgs, _) in enumerate(trainloader):
        real_imgs = imgs.to(device)
        batch_size = real_imgs.size(0)

        # ---------------------
        # Train Discriminator
        # ---------------------
        optimizer_D.zero_grad()

        # Generate a batch of images
        z = torch.randn(batch_size, latent_dim, device=device)
        fake_imgs = generator(z)
        
        # Real images
        real_validity = discriminator(real_imgs)
        # Fake images
        fake_validity = discriminator(fake_imgs.detach())
        
        # Gradient penalty
        gradient_penalty = compute_gradient_penalty(discriminator, real_imgs.data, fake_imgs.data)
        
        # Wasserstein distance with gradient penalty
        wasserstein_dist = -torch.mean(real_validity) + torch.mean(fake_validity)
        d_loss = wasserstein_dist + lambda_gp * gradient_penalty
        
        d_loss.backward()
        optimizer_D.step()

        # Train the generator every n_critic steps
        if i % n_critic == 0:
            # -----------------
            # Train Generator
            # -----------------
            optimizer_G.zero_grad()
            
            # Generate a batch of images
            z = torch.randn(batch_size, latent_dim, device=device)
            gen_imgs = generator(z)
            
            # Loss measures generator's ability to fool the discriminator
            gen_validity = discriminator(gen_imgs)
            g_loss = -torch.mean(gen_validity)
            
            g_loss.backward()
            optimizer_G.step()
            
            # Save losses
            G_losses.append(g_loss.item())
            D_losses.append(d_loss.item())
            Wasserstein_distances.append(wasserstein_dist.item())

        if i % 100 == 0:
            print(f"[Epoch {epoch}/{num_epochs}] [Batch {i}/{len(trainloader)}] "
                  f"[D loss: {d_loss.item():.4f}] [G loss: {g_loss.item():.4f}] "
                  f"[Wasserstein: {wasserstein_dist.item():.4f}]")

    # Save generated images for the current epoch
    if epoch % 5 == 0 or epoch == num_epochs - 1:
        with torch.no_grad():
            fake = generator(fixed_noise).detach().cpu()
        grid = torchvision.utils.make_grid(fake, padding=2, normalize=True)
        img_list.append(grid)

        plt.figure(figsize=(8, 8))
        plt.axis("off")
        plt.title(f"Generated Images - Epoch {epoch}")
        plt.imshow(np.transpose(grid, (1, 2, 0)))
        plt.savefig(f"generated_images/epoch_{epoch:03d}.png")
        plt.close()

# Plot Wasserstein distance
plt.figure(figsize=(10, 5))
plt.plot(Wasserstein_distances)
plt.title("Wasserstein Distance During Training")
plt.xlabel("Iterations")
plt.ylabel("Wasserstein Distance")
plt.savefig("wasserstein_distance_plot.png")
plt.close()

# Plot loss curves
plt.figure(figsize=(10, 5))
plt.plot(G_losses, label='Generator Loss')
plt.plot(D_losses, label='Discriminator Loss')
plt.title("WGAN-GP Training Losses")
plt.xlabel("Iterations")
plt.ylabel("Loss")
plt.legend()
plt.savefig("wgan_gp_loss_plot.png")
plt.close()

# Create animation of training progress
if len(img_list) > 1:
    try:
        import imageio
        
        # Create a GIF of training progress
        frames = []
        for epoch_idx in range(0, num_epochs, 5):
            if epoch_idx >= len(img_list) * 5:
                break
            idx = epoch_idx // 5
            img_array = np.transpose(img_list[idx].numpy(), (1, 2, 0))
            # Convert to uint8
            img_array = ((img_array + 1) / 2 * 255).astype(np.uint8)
            frames.append(img_array)
            
        imageio.mimsave('training_progress.gif', frames, fps=3)
        print("Training progress GIF saved.")
    except ImportError:
        print("Warning: imageio not installed. Skipping GIF creation.")

# Final comparison
plt.figure(figsize=(12, 6))

plt.subplot(1, 2, 1)
plt.axis("off")
plt.title("Real Images")
real_batch = next(iter(trainloader))[0].to(device)
real_grid = torchvision.utils.make_grid(real_batch[:64], padding=5, normalize=True)
plt.imshow(np.transpose(real_grid.cpu(), (1, 2, 0)))

plt.subplot(1, 2, 2)
plt.axis("off")
plt.title("Generated Images")
plt.imshow(np.transpose(img_list[-1], (1, 2, 0)))
plt.savefig("wgan_gp_final_comparison.png")
plt.close()

print("Training complete!")

Using device: cpu
Starting Training Loop...
[Epoch 0/100] [Batch 0/938] [D loss: 9.9894] [G loss: -0.0000] [Wasserstein: -0.0002]
[Epoch 0/100] [Batch 100/938] [D loss: -2.4196] [G loss: -4.6022] [Wasserstein: -3.5350]
[Epoch 0/100] [Batch 200/938] [D loss: -12.3110] [G loss: 1.2107] [Wasserstein: -13.5060]
[Epoch 0/100] [Batch 300/938] [D loss: -2.9057] [G loss: -2.1352] [Wasserstein: -3.6861]
[Epoch 0/100] [Batch 400/938] [D loss: -2.1115] [G loss: -3.6115] [Wasserstein: -2.5324]
[Epoch 0/100] [Batch 500/938] [D loss: -0.7044] [G loss: -1.2658] [Wasserstein: -1.0982]
[Epoch 0/100] [Batch 600/938] [D loss: -2.0738] [G loss: 0.2977] [Wasserstein: -2.4001]
[Epoch 0/100] [Batch 700/938] [D loss: -1.9834] [G loss: 1.0243] [Wasserstein: -2.3062]
[Epoch 0/100] [Batch 800/938] [D loss: -0.9204] [G loss: -0.2387] [Wasserstein: -1.2349]
[Epoch 0/100] [Batch 900/938] [D loss: -1.0425] [G loss: 1.0889] [Wasserstein: -1.2430]
[Epoch 1/100] [Batch 0/938] [D loss: -1.7588] [G loss: 0.6406] [Wassers