In [None]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np

# Directory that receives one sample grid per training epoch.
os.makedirs("generated_images", exist_ok=True)

# ToTensor yields pixel values in [0, 1]; Normalize((0.5,), (0.5,)) computes
# (x - 0.5) / 0.5, mapping them onto [-1, 1] -- the same range as the
# generator's final Tanh.
transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))]
)

# MNIST training split under ./data (download=True fetches it if missing),
# served as shuffled mini-batches of 64 by two loader worker processes.
trainset = torchvision.datasets.MNIST(
    root='./data', train=True, download=True, transform=transform,
)
trainloader = torch.utils.data.DataLoader(
    dataset=trainset, batch_size=64, shuffle=True, num_workers=2,
)

# Use the first CUDA device when one is available, otherwise fall back to CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")

# Generator: fully-connected network mapping latent noise to 28x28 images.
class Generator(nn.Module):
    """MLP generator producing single-channel 28x28 images with values in [-1, 1].

    Args:
        latent_dim: length of each input noise vector.
    """

    def __init__(self, latent_dim=100):
        super().__init__()
        hidden_widths = [latent_dim, 256, 512, 1024]
        layers = []
        # Linear -> ReLU for each consecutive pair of widths.
        for fan_in, fan_out in zip(hidden_widths[:-1], hidden_widths[1:]):
            layers += [nn.Linear(fan_in, fan_out), nn.ReLU(True)]
        # Project to 28*28 pixels; Tanh bounds the output to [-1, 1].
        layers += [nn.Linear(hidden_widths[-1], 28 * 28), nn.Tanh()]
        # Attribute name and layer order are what fix the state_dict keys
        # (main.0.weight, main.2.weight, ...); keep them stable for checkpoints.
        self.main = nn.Sequential(*layers)

    def forward(self, input):
        """Map a batch of latent vectors to an image batch of shape (-1, 1, 28, 28)."""
        pixels = self.main(input)
        return pixels.view(-1, 1, 28, 28)

# Discriminator (WGAN Critic): scores images with an unbounded real value.
class Discriminator(nn.Module):
    """MLP critic that maps each 28x28 image to a single raw score.

    The last layer has no activation, so scores are not squashed into a
    probability range.
    """

    def __init__(self):
        super().__init__()
        widths = [28 * 28, 1024, 512, 256]
        layers = []
        # Linear -> LeakyReLU(0.2) for each consecutive pair of widths.
        for fan_in, fan_out in zip(widths, widths[1:]):
            layers.extend((nn.Linear(fan_in, fan_out), nn.LeakyReLU(0.2, inplace=True)))
        layers.append(nn.Linear(widths[-1], 1))
        # Attribute name and layer order determine the state_dict keys.
        self.main = nn.Sequential(*layers)

    def forward(self, input):
        """Flatten the batch to (-1, 784) and return one score per row, shape (-1, 1)."""
        flattened = input.view(-1, 28 * 28)
        return self.main(flattened)

# Weight initialization
def weights_init(m):
    """Re-initialise ``nn.Linear`` layers in place; meant for ``Module.apply``.

    Linear weights get Kaiming-uniform initialisation with negative slope
    ``a=0.2`` and biases (when present) are set to zero. Any other module type
    is left untouched.

    Args:
        m: a module visited by ``nn.Module.apply``.
    """
    # isinstance (rather than matching "Linear" in the class name) also covers
    # nn.Linear subclasses and never picks up unrelated classes by name.
    if isinstance(m, nn.Linear):
        # nn.init functions run under no_grad themselves, so there is no need
        # to reach into `.data`.
        nn.init.kaiming_uniform_(m.weight, a=0.2)
        if m.bias is not None:
            nn.init.zeros_(m.bias)

# --- Models -------------------------------------------------------------
latent_dim = 100  # length of the generator's input noise vector
generator = Generator(latent_dim).to(device)
discriminator = Discriminator().to(device)
# Custom Linear-layer init for both networks. Construction, init and the
# fixed-noise draw below stay in this order so RNG consumption is unchanged.
generator.apply(weights_init)
discriminator.apply(weights_init)
# Reused every epoch so sample grids are comparable across epochs.
fixed_noise = torch.randn(64, latent_dim, device=device)

# --- Optimizers ---------------------------------------------------------
# RMSprop with the same learning rate (5e-5) for both networks, G first.
optimizer_G, optimizer_D = (
    optim.RMSprop(net.parameters(), lr=5e-5) for net in (generator, discriminator)
)

# --- WGAN schedule ------------------------------------------------------
n_critic = 5        # critic updates per generator update
clip_value = 0.01   # critic weights are clamped to [-clip_value, clip_value]
num_epochs = 50

# --- Training history ---------------------------------------------------
G_losses, D_losses, img_list = [], [], []

print("Starting Training Loop...")
# WGAN training: the critic minimises mean(D(fake)) - mean(D(real)) and the
# generator minimises -mean(D(fake)).
#
# WGAN Algorithm 1 samples a fresh real minibatch for every critic update.
# The previous version ran all n_critic critic updates on the *same* batch.
# Here each DataLoader batch drives exactly one critic update, and the
# generator is updated once every n_critic batches.
# NOTE(review): this gives ~len(trainloader) / n_critic generator updates per
# epoch (previously len(trainloader)); raise num_epochs if the old
# generator-update budget matters.
for epoch in range(num_epochs):
    for i, (real_images, _) in enumerate(trainloader):
        real_images = real_images.to(device)
        batch_size = real_images.size(0)

        # ---------------------
        # Train Discriminator (critic): one update per real batch
        # ---------------------
        optimizer_D.zero_grad()
        noise = torch.randn(batch_size, latent_dim, device=device)
        # The critic loss never back-propagates into the generator, so build
        # the fake batch without recording an autograd graph for it.
        with torch.no_grad():
            fake_images = generator(noise)
        real_output = discriminator(real_images).view(-1)
        fake_output = discriminator(fake_images).view(-1)
        d_loss = -(torch.mean(real_output) - torch.mean(fake_output))
        d_loss.backward()
        optimizer_D.step()

        # WGAN weight clipping, done in place outside autograd.
        with torch.no_grad():
            for p in discriminator.parameters():
                p.clamp_(-clip_value, clip_value)

        # -----------------
        # Train Generator: once every n_critic critic updates.
        # i == 0 always qualifies, so g_loss is defined before the first log line.
        # -----------------
        if i % n_critic == 0:
            optimizer_G.zero_grad()
            noise = torch.randn(batch_size, latent_dim, device=device)
            fake_images = generator(noise)
            fake_output = discriminator(fake_images).view(-1)
            g_loss = -torch.mean(fake_output)
            # Gradients this leaves on the critic's parameters are cleared by
            # optimizer_D.zero_grad() at the next critic update.
            g_loss.backward()
            optimizer_G.step()

            # One (G, D) pair per generator update, so both curves share the
            # same x-axis in the loss plot; D is the most recent critic loss.
            G_losses.append(g_loss.item())
            D_losses.append(d_loss.item())

        if i % 100 == 0:
            print(f"[Epoch {epoch}/{num_epochs}] Step {i}/{len(trainloader)} "
                  f"Loss_D: {d_loss.item():.4f} Loss_G: {g_loss.item():.4f}")

    # Save generated images for the current epoch (same fixed noise every epoch).
    with torch.no_grad():
        fake = generator(fixed_noise).cpu()
    grid = torchvision.utils.make_grid(fake, padding=2, normalize=True)
    img_list.append(grid)  # kept as a CHW tensor

    plt.figure(figsize=(6, 6))
    plt.axis("off")
    plt.title(f"Epoch {epoch}")
    plt.imshow(grid.permute(1, 2, 0).numpy())  # CHW -> HWC for imshow
    plt.savefig(f"generated_images/epoch_{epoch:03d}.png")
    plt.close()

# Plot loss curves: the training loop appends one (G, D) pair per generator update.
fig, ax = plt.subplots(figsize=(10, 5))
ax.plot(G_losses, label='Generator Loss')
ax.plot(D_losses, label='Discriminator Loss')
ax.set_title("WGAN Training Losses")
ax.set_xlabel("Iterations")
ax.set_ylabel("Loss")
ax.legend()
fig.savefig("wgan_loss_plot.png")
plt.close(fig)

# Final comparison: one real batch next to the last epoch's generated grid.
# make_grid and imshow both work on CPU tensors, so the real batch is never
# copied to `device` (the old code moved it there only to bring it straight back).
real_batch = next(iter(trainloader))[0]
real_grid = torchvision.utils.make_grid(real_batch[:64], padding=5, normalize=True)

fig, (ax_real, ax_fake) = plt.subplots(1, 2, figsize=(12, 6))
ax_real.axis("off")
ax_real.set_title("Real Images")
ax_real.imshow(real_grid.permute(1, 2, 0).numpy())  # CHW -> HWC

ax_fake.axis("off")
ax_fake.set_title("Fake Images")
ax_fake.imshow(img_list[-1].permute(1, 2, 0).numpy())  # grids are CHW tensors
fig.savefig("wgan_final_comparison.png")
plt.close(fig)


Using device: cpu
Starting Training Loop...
[Epoch 0/50] Step 0/938 Loss_D: -0.0072 Loss_G: 0.0027
[Epoch 0/50] Step 100/938 Loss_D: -1.8784 Loss_G: -1.6879
[Epoch 0/50] Step 200/938 Loss_D: -3.6362 Loss_G: -3.8740
[Epoch 0/50] Step 300/938 Loss_D: -3.3774 Loss_G: -4.9600
[Epoch 0/50] Step 400/938 Loss_D: -2.7655 Loss_G: -4.6600
[Epoch 0/50] Step 500/938 Loss_D: -2.7063 Loss_G: -3.6568
[Epoch 0/50] Step 600/938 Loss_D: -2.2661 Loss_G: -1.0638
[Epoch 0/50] Step 700/938 Loss_D: -2.2940 Loss_G: -3.7315
[Epoch 0/50] Step 800/938 Loss_D: -2.0941 Loss_G: -3.1257
[Epoch 0/50] Step 900/938 Loss_D: -2.2867 Loss_G: -4.2933
[Epoch 1/50] Step 0/938 Loss_D: -1.6409 Loss_G: -4.4634
[Epoch 1/50] Step 100/938 Loss_D: -2.0911 Loss_G: -4.8055
[Epoch 1/50] Step 200/938 Loss_D: -1.8259 Loss_G: -5.6522
[Epoch 1/50] Step 300/938 Loss_D: -1.7344 Loss_G: -4.4042
[Epoch 1/50] Step 400/938 Loss_D: -1.5641 Loss_G: -4.7958
[Epoch 1/50] Step 500/938 Loss_D: -1.9372 Loss_G: -7.2175
[Epoch 1/50] Step 600/938 Loss_D: