In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms, datasets
from torch.utils.data import DataLoader

# Define the generator network
class Generator(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(Generator, self).__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.fc2 = nn.Linear(hidden_size, hidden_size)
        self.fc3 = nn.Linear(hidden_size, output_size)
        self.relu = nn.ReLU()

    def forward(self, x):
        x = self.relu(self.fc1(x))
        x = self.relu(self.fc2(x))
        x = torch.tanh(self.fc3(x))
        return x

# Define the discriminator network
class Discriminator(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(Discriminator, self).__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.fc2 = nn.Linear(hidden_size, hidden_size)
        self.fc3 = nn.Linear(hidden_size, output_size)
        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        x = self.relu(self.fc1(x))
        x = self.relu(self.fc2(x))
        x = self.sigmoid(self.fc3(x))
        return x

# Hyperparameters
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
lr = 0.0002
batch_size = 64
num_epochs = 100
input_size = 100
hidden_size = 256
output_size = 784

# Load MNIST data
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5,), std=(0.5,))
])

train_dataset = datasets.MNIST(root='./data', train=True, transform=transform, download=True)
train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)

# Initialize generator and discriminator
generator = Generator(input_size, hidden_size, output_size).to(device)
discriminator = Discriminator(output_size, hidden_size, 1).to(device)

# Loss function and optimizers
criterion = nn.BCELoss()
gen_optimizer = optim.Adam(generator.parameters(), lr=lr)
discriminator_optimizer = optim.Adam(discriminator.parameters(), lr=lr)

# Training loop
for epoch in range(num_epochs):
    for batch_idx, (real_images, _) in enumerate(train_loader):
        real_images = real_images.view(-1, 28 * 28).to(device)
        batch_size = real_images.size(0)

        # Train discriminator
        discriminator_optimizer.zero_grad()
        real_labels = torch.ones(batch_size, 1).to(device)
        fake_labels = torch.zeros(batch_size, 1).to(device)

        # Train on real images
        real_output = discriminator(real_images)
        disc_real_loss = criterion(real_output, real_labels)
        disc_real_loss.backward()

        # Train on fake images
        noise = torch.randn(batch_size, input_size).to(device)
        fake_images = generator(noise)
        fake_output = discriminator(fake_images.detach())
        disc_fake_loss = criterion(fake_output, fake_labels)
        disc_fake_loss.backward()

        discriminator_optimizer.step()

        # Train generator
        gen_optimizer.zero_grad()
        output = discriminator(fake_images)
        gen_loss = criterion(output, real_labels)
        gen_loss.backward()
        gen_optimizer.step()

    # Print losses
    print(f"Epoch [{epoch+1}/{num_epochs}], "
          f"Generator Loss: {gen_loss.item():.4f}, "
          f"Discriminator Loss: {(disc_real_loss + disc_fake_loss).item():.4f}")

# Generate sample images
import matplotlib.pyplot as plt

num_samples = 16
noise = torch.randn(num_samples, input_size).to(device)
generated_images = generator(noise).cpu().detach().view(-1, 28, 28)

plt.figure(figsize=(10, 10))
for i in range(num_samples):
    plt.subplot(4, 4, i + 1)
    plt.imshow(generated_images[i], cmap='gray')
    plt.axis('off')
plt.show()


Epoch [1/100], Generator Loss: 2.4479, Discriminator Loss: 0.5286
Epoch [2/100], Generator Loss: 2.4936, Discriminator Loss: 0.9887
Epoch [3/100], Generator Loss: 1.8801, Discriminator Loss: 0.6164
Epoch [4/100], Generator Loss: 3.7401, Discriminator Loss: 0.5654
Epoch [5/100], Generator Loss: 3.5422, Discriminator Loss: 0.3471
