In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Subset

transform = transforms.Compose([
    transforms.ToTensor()
])

dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
subset = Subset(dataset, range(1000))
dataloader = DataLoader(subset, batch_size=10, shuffle=True)

class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self.gen = nn.Sequential(
            nn.Linear(100, 256),
            nn.ReLU(),
            nn.Linear(256, 28*28),
            nn.Tanh()
        )

    def forward(self, x):
        return self.gen(x)

class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.disc = nn.Sequential(
            nn.Linear(28*28, 256),
            nn.ReLU(),
            nn.Linear(256, 1),
            nn.Sigmoid()
        )

    def forward(self, x):
        return self.disc(x)

generator = Generator()
discriminator = Discriminator()
criterion = nn.BCELoss()
optim_gen = optim.Adam(generator.parameters(), lr=2e-4)
optim_disc = optim.Adam(discriminator.parameters(), lr=2e-4)

def train(num_epochs):
    for epoch in range(num_epochs):
        generator.train()
        discriminator.train()
        for real, _ in dataloader:
            real = real.view(-1, 28*28)
            batch_size = real.size(0)
            
            # Train Discriminator
            noise = torch.randn(batch_size, 100)
            fake = generator(noise)
            disc_real = discriminator(real)
            loss_disc_real = criterion(disc_real, torch.ones_like(disc_real))
            disc_fake = discriminator(fake)
            loss_disc_fake = criterion(disc_fake, torch.zeros_like(disc_fake))
            
            loss_disc = (loss_disc_real + loss_disc_fake) / 2
            
            # Backprop
            optim_disc.zero_grad()
            loss_disc.backward()
            optim_disc.step()
            
            # Train Generator
            noise = torch.randn(batch_size, 100)
            fake = generator(noise)
            disc_fake = discriminator(fake)
            loss_gen = criterion(disc_fake, torch.ones_like(disc_fake))
            
            # Backprop
            optim_gen.zero_grad()
            loss_gen.backward()
            optim_gen.step()
            
        print(f'Epoch {epoch+1}, Loss D: {loss_disc.item():.4f}, Loss G: {loss_gen.item():.4f}')

train(15)


Epoch 1, Loss D: 0.6074, Loss G: 1.1406
Epoch 2, Loss D: 0.2637, Loss G: 1.9539
Epoch 3, Loss D: 0.4815, Loss G: 1.5118
Epoch 4, Loss D: 0.5446, Loss G: 1.4493
Epoch 5, Loss D: 0.3713, Loss G: 1.6276
Epoch 6, Loss D: 0.3152, Loss G: 1.7755
Epoch 7, Loss D: 0.5351, Loss G: 1.1940
Epoch 8, Loss D: 0.3224, Loss G: 1.8870
Epoch 9, Loss D: 0.3258, Loss G: 1.4795
Epoch 10, Loss D: 0.3115, Loss G: 1.1164
Epoch 11, Loss D: 0.4486, Loss G: 0.9738
Epoch 12, Loss D: 0.5317, Loss G: 0.6488
Epoch 13, Loss D: 0.5472, Loss G: 0.6966
Epoch 14, Loss D: 0.5009, Loss G: 0.6672
Epoch 15, Loss D: 0.4654, Loss G: 0.7201
