In [17]:
import numpy as np
import pandas as pd 
import torch
import torch.nn as nn
import torch.optim as optim

import torchvision.datasets as datasets
import torchvision.transforms as transforms
import torchvision.utils as vutils

import os 

In [18]:
# Run on the GPU when PyTorch can see one, otherwise stay on the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

# Convert each image to a tensor, then normalize with mean 0.5 / std 0.5.
# The generator ends in Tanh, so its samples live in [-1, 1]; this
# normalization is what puts the real images on a comparable scale.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,)),
    ]
)

# MNIST training split, fetched into ./mnist_data/ if it is not already there.
train_dataset = datasets.MNIST(
    root='./mnist_data/',
    train=True,
    transform=transform,
    download=True,
)

# Reshuffled mini-batches of 64 images per step.
train_loader = torch.utils.data.DataLoader(
    dataset=train_dataset,
    batch_size=64,
    shuffle=True,
)

In [19]:
class Generator(nn.Module):
    """Fully connected generator mapping a noise vector to a 1x28x28 image.

    The network is a stack of Linear+ReLU layers widening from ``noise_dim``
    to 1024 units, followed by a Linear projection to 28*28 values and a Tanh,
    so every output pixel lies in [-1, 1].

    Args:
        noise_dim: Size of the input noise vector (default 100).
    """

    def __init__(self, noise_dim=100):
        super().__init__()
        widths = [noise_dim, 256, 512, 1024]
        layers = []
        # Hidden stack: each consecutive pair of widths becomes Linear -> ReLU.
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            layers.append(nn.Linear(fan_in, fan_out))
            layers.append(nn.ReLU())
        # Output head: one value per pixel, squashed into [-1, 1].
        layers.append(nn.Linear(widths[-1], 28 * 28))
        layers.append(nn.Tanh())
        self.net = nn.Sequential(*layers)

    def forward(self, z):
        """Generate images from noise ``z``; returns a (-1, 1, 28, 28) view."""
        flat_pixels = self.net(z)
        return flat_pixels.view(-1, 1, 28, 28)
    
class Discriminator(nn.Module):
    """Fully connected discriminator scoring 28x28 images as real or fake.

    The input is flattened to 28*28 values and passed through Linear+ReLU
    layers narrowing from 1024 to 256 units, then a single-unit Linear layer
    with a Sigmoid, so each score lies in [0, 1].
    """

    def __init__(self):
        super().__init__()
        widths = (28 * 28, 1024, 512, 256)
        stack = []
        # Hidden stack: each consecutive pair of widths becomes Linear -> ReLU.
        for n_in, n_out in zip(widths, widths[1:]):
            stack += [nn.Linear(n_in, n_out), nn.ReLU()]
        # Output head: one sigmoid-activated score per image.
        stack += [nn.Linear(widths[-1], 1), nn.Sigmoid()]
        self.net = nn.Sequential(*stack)

    def forward(self, x):
        """Flatten ``x`` to rows of 28*28 values and return one score per row."""
        flattened = x.view(-1, 28 * 28)
        return self.net(flattened)
        

In [20]:
# Build both networks and move their parameters to the selected device.
# The generator is created first, then the discriminator.
generator = Generator().to(device)
discriminator = Discriminator().to(device)

# Binary cross-entropy between the discriminator's sigmoid scores and the
# 0/1 real-vs-fake targets.
criterion = nn.BCELoss()

# A separate Adam optimizer per network, both with learning rate 2e-4.
optimizer_g = optim.Adam(params=generator.parameters(), lr=2e-4)
optimizer_d = optim.Adam(params=discriminator.parameters(), lr=2e-4)

def get_noise(batch_size, noise_dim):
    """Draw a batch of noise vectors from a standard normal distribution.

    Args:
        batch_size: Number of noise vectors to sample.
        noise_dim: Length of each noise vector.

    Returns:
        A tensor of shape (batch_size, noise_dim), created on the
        module-level ``device``.
    """
    shape = (batch_size, noise_dim)
    return torch.randn(shape, device=device)

In [None]:
num_epochs = 50
noise_dim = 100  # length of each noise vector fed to the generator
os.makedirs('mnist_results', exist_ok=True)

# Alternating GAN training: on every mini-batch the discriminator is updated
# first (real batch + detached fake batch), then the generator is updated to
# make the discriminator label its samples as real.
for epoch in range(num_epochs):
    for i, (images, _) in enumerate(train_loader):
        # Use the actual batch size: the final batch of an epoch can be smaller.
        batch_size = images.size(0)
        # Allocate the targets directly on the training device instead of
        # building them on the host and copying them over.
        real_labels = torch.ones(batch_size, 1, device=device)
        fake_labels = torch.zeros(batch_size, 1, device=device)

        # ---- Discriminator step -------------------------------------------
        discriminator.zero_grad()
        outputs = discriminator(images.to(device)).view(-1, 1)
        loss_d_real = criterion(outputs, real_labels)
        loss_d_real.backward()

        z = get_noise(batch_size, noise_dim)
        fake_images = generator(z)
        # detach() keeps this backward pass from reaching the generator.
        outputs = discriminator(fake_images.detach()).view(-1, 1)
        loss_d_fake = criterion(outputs, fake_labels)
        loss_d_fake.backward()

        # The two backward() calls accumulated both gradients; the sum is only
        # used for logging.
        loss_d = loss_d_fake + loss_d_real
        optimizer_d.step()

        # ---- Generator step -----------------------------------------------
        generator.zero_grad()
        # No detach here: gradients must flow through the discriminator back
        # into the generator. The discriminator gradients this produces are
        # cleared by discriminator.zero_grad() at the start of the next step.
        outputs = discriminator(fake_images).view(-1, 1)
        loss_g = criterion(outputs, real_labels)
        loss_g.backward()
        optimizer_g.step()

        if (i+1) % 100 == 0:
            print(f'Epoch [{epoch+1}/{num_epochs}], Step [{i+1}/{len(train_loader)}], Loss D: {loss_d.item():.4f}, Loss G: {loss_g.item():.4f}')

    # Save a grid of the last generated batch every 5 epochs.
    # The original condition `(epoch + 1) % 5 == 5` could never be true
    # (a remainder mod 5 is at most 4), so no image was ever written.
    # The generator already returns (N, 1, 28, 28) tensors, so no reshape is
    # needed before saving.
    if (epoch + 1) % 5 == 0:
        vutils.save_image(fake_images.detach(), f'mnist_results/fake_img_epoch_{epoch+1}.png', normalize=True)
        

Epoch [1/50], Step [100/938], Loss D: 0.4048, Loss G: 1.3342
Epoch [1/50], Step [200/938], Loss D: 0.0100, Loss G: 4.7167
Epoch [1/50], Step [300/938], Loss D: 0.0829, Loss G: 10.9219
Epoch [1/50], Step [400/938], Loss D: 0.0135, Loss G: 7.7133
Epoch [1/50], Step [500/938], Loss D: 0.0907, Loss G: 9.8181
Epoch [1/50], Step [600/938], Loss D: 0.0000, Loss G: 25.0819
Epoch [1/50], Step [700/938], Loss D: 0.0285, Loss G: 6.6445
Epoch [1/50], Step [800/938], Loss D: 0.1209, Loss G: 9.8888
Epoch [1/50], Step [900/938], Loss D: 0.0559, Loss G: 13.1616
Epoch [2/50], Step [100/938], Loss D: 0.1413, Loss G: 6.7659
Epoch [2/50], Step [200/938], Loss D: 0.0445, Loss G: 5.6852
Epoch [2/50], Step [300/938], Loss D: 0.4565, Loss G: 7.0318
Epoch [2/50], Step [400/938], Loss D: 0.3708, Loss G: 6.3103
Epoch [2/50], Step [500/938], Loss D: 0.4115, Loss G: 3.3201
Epoch [2/50], Step [600/938], Loss D: 0.1499, Loss G: 2.8872
Epoch [2/50], Step [700/938], Loss D: 0.0553, Loss G: 4.8515
Epoch [2/50], Step [8