In [7]:
import numpy as np
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torch.autograd import Variable

# Generator network
class Generator(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(Generator, self).__init__()
        self.fc = nn.Sequential(
            nn.Linear(input_size + 10, hidden_size),  # Extra 10 for the one-hot encoded labels
            nn.ReLU(),
            nn.Linear(hidden_size, output_size),
            nn.Tanh()
        )

    def forward(self, x, labels):
        x = torch.cat([x, labels], 1)  # Concatenate the labels to the noise vector
        return self.fc(x)

# Discriminator network
class Discriminator(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(Discriminator, self).__init__()
        self.fc = nn.Sequential(
            nn.Linear(input_size + 10, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, output_size),
            nn.Sigmoid()  # Use sigmoid for binary classification
        )

    def forward(self, x, labels):
        x = torch.cat([x, labels], 1)
        return self.fc(x)

# Hyperparameters
batch_size = 64
lr = 0.0002
z_size = 100  # Size of the random noise vector
hidden_size = 128

# Load dataset
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
# if the MNIST data is already downloaded
train_dataset = datasets.MNIST('../data', train=True, download=False, transform=transform)

if not train_dataset:
    # If not, download the data
    train_dataset = datasets.MNIST('../data', train=True, download=True, transform=transform)

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

# Initialize networks and optimizers
generator = Generator(z_size, hidden_size, 28*28)
discriminator = Discriminator(28*28, hidden_size, 1)

# Use proper weight initialization
def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Linear') != -1:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
        nn.init.constant_(m.bias.data, 0)

generator.apply(weights_init)
discriminator.apply(weights_init)

optimizer_G = optim.Adam(generator.parameters(), lr=lr)
optimizer_D = optim.Adam(discriminator.parameters(), lr=lr)

# Training loop
num_epochs = 10
for epoch in range(num_epochs):
    for batch, (real_images, real_labels) in enumerate(train_loader):
        actual_batch_size = real_images.size(0)
        real_labels_one_hot = torch.zeros(actual_batch_size, 10)
        real_labels_one_hot[torch.arange(actual_batch_size), real_labels] = 1

        optimizer_D.zero_grad()

        real_images = real_images.view(-1, 28*28)
        real_labels = real_labels_one_hot

        # Soft labels for real samples (label smoothing)
        real_labels_smooth = 0.9 * torch.ones(actual_batch_size, 1)
        fake_labels_smooth = 0.1 * torch.ones(actual_batch_size, 1)

        # Forward pass real batch through discriminator
        output_real = discriminator(real_images, real_labels_one_hot)
        loss_real = nn.BCELoss()(output_real, real_labels_smooth)

        # Generate fake images
        noise = Variable(torch.randn(actual_batch_size, z_size))
        fake_images = generator(noise, real_labels_one_hot)

        # Forward pass fake batch through discriminator
        output_fake = discriminator(fake_images.detach(), real_labels_one_hot)
        loss_fake = nn.BCELoss()(output_fake, fake_labels_smooth)

        # Backpropagation
        loss_d = loss_real + loss_fake
        loss_d.backward()
        optimizer_D.step()

        # Train Generator
        optimizer_G.zero_grad()

        # Forward pass fake batch through discriminator again
        output_fake = discriminator(fake_images, real_labels_one_hot)
        loss_g = nn.BCELoss()(output_fake, real_labels_smooth)

        # Backpropagation
        loss_g.backward()
        optimizer_G.step()

        # Print loss
        if batch % 900 == 0:
            print(f'Epoch [{epoch}/{num_epochs}], Batch [{batch}/{len(train_loader)}], '
                  f'D Loss: {loss_d.item():.4f}, G Loss: {loss_g.item():.4f}')

Epoch [0/10], Batch [0/938], D Loss: 1.3968, G Loss: 0.6935
Epoch [0/10], Batch [900/938], D Loss: 1.1721, G Loss: 0.7949
Epoch [1/10], Batch [0/938], D Loss: 1.2235, G Loss: 0.7505
Epoch [1/10], Batch [900/938], D Loss: 0.9341, G Loss: 1.0426
Epoch [2/10], Batch [0/938], D Loss: 1.1941, G Loss: 0.7952
Epoch [2/10], Batch [900/938], D Loss: 1.6778, G Loss: 0.5687
Epoch [3/10], Batch [0/938], D Loss: 1.5920, G Loss: 0.6017
Epoch [3/10], Batch [900/938], D Loss: 1.2262, G Loss: 0.8026
Epoch [4/10], Batch [0/938], D Loss: 1.3127, G Loss: 0.7521
Epoch [4/10], Batch [900/938], D Loss: 1.5375, G Loss: 0.6345
Epoch [5/10], Batch [0/938], D Loss: 1.2383, G Loss: 0.8614
Epoch [5/10], Batch [900/938], D Loss: 1.4934, G Loss: 0.6959
Epoch [6/10], Batch [0/938], D Loss: 1.5761, G Loss: 0.6780
Epoch [6/10], Batch [900/938], D Loss: 1.4038, G Loss: 0.7608
Epoch [7/10], Batch [0/938], D Loss: 1.3838, G Loss: 0.7631
Epoch [7/10], Batch [900/938], D Loss: 0.8980, G Loss: 1.2652
Epoch [8/10], Batch [0/9