In [1]:
import os
import numpy as np

import torch
import torch.nn as nn

import torchvision.transforms as transforms
from torchvision.utils import save_image

from torch.utils.data import DataLoader
from torchvision import datasets

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [2]:
img_size = 28 # size of each image dimension
channels = 1  # number of image channels
img_shape = (channels, img_size, img_size)
input_dim  = 784  # 28*28 for MNIST
hidden_dim = 128
lr = 0.0002 
n_cpu = os.cpu_count()//2 # number of cpu threads to use during batch generation
batch_size = 64 
n_epochs = 10 
noise_dim = 100 # dimensionality of the input noise

In [3]:
data_path = '../data'
os.makedirs(data_path, exist_ok=True)

In [4]:
# Configure data loader
transform = transforms.Compose([transforms.Resize(img_size), transforms.ToTensor(), transforms.Normalize([0.5], [0.5])])

loader_kwargs = {'num_workers': os.cpu_count()//2, 'pin_memory': True} 

train_data = datasets.MNIST(root=data_path, train=True, download=True, transform=transform)
test_data = datasets.MNIST(root=data_path, train=False, download=True, transform=transform)

train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=False, **loader_kwargs)
test_loader = DataLoader(test_data, batch_size=batch_size, shuffle=False, **loader_kwargs)

In [5]:
class Generator(nn.Module):
    def __init__(self, hidden_dim=hidden_dim):
        super(Generator, self).__init__()

        self.model = nn.Sequential(
            nn.Linear(noise_dim, hidden_dim),
            nn.LeakyReLU(True),
            nn.Linear(hidden_dim, hidden_dim),
            nn.LeakyReLU(True),
            nn.Linear(hidden_dim, hidden_dim),
            nn.LeakyReLU(True),
            nn.Linear(hidden_dim, input_dim, img_shape),
        )

    def forward(self, z):
        img = self.model(z)
        img = img.view(img.size(0), *img_shape)
        return img

In [6]:
class Discriminator(nn.Module):
    def __init__(self, hidden_dim=hidden_dim):
        super(Discriminator, self).__init__()

        self.model = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.LeakyReLU(True),
            nn.Linear(hidden_dim, hidden_dim),
            nn.LeakyReLU(True),
            nn.Linear(hidden_dim, 1),
            nn.Sigmoid(),
        )

    def forward(self, img):
        img_flat = img.view(img.size(0), -1)
        validity = self.model(img_flat)
        return validity

In [7]:
criterion = torch.nn.BCELoss()

generator = Generator()
discriminator = Discriminator()

optimizer_G = torch.optim.AdamW(generator.parameters(), lr=lr)
optimizer_D = torch.optim.AdamW(discriminator.parameters(), lr=lr)

generator.to(device)
discriminator.to(device)

Discriminator(
  (model): Sequential(
    (0): Linear(in_features=784, out_features=128, bias=True)
    (1): LeakyReLU(negative_slope=True)
    (2): Linear(in_features=128, out_features=128, bias=True)
    (3): LeakyReLU(negative_slope=True)
    (4): Linear(in_features=128, out_features=1, bias=True)
    (5): Sigmoid()
  )
)

# Training

In [8]:
saved_dir = 'gan_images'
os.makedirs(saved_dir, exist_ok = True)

In [9]:
loss_estimate = []
batches_done = 0
for epoch in range(n_epochs):
    for i, (imgs, _) in enumerate(train_loader):

        # ground truths
        real = torch.ones(imgs.size(0), 1, requires_grad=False, device=device)
        fake = torch.zeros(imgs.size(0), 1, requires_grad=False, device=device)

        real_imgs = imgs.to(device)

        # ===== Generator =====
        optimizer_G.zero_grad()

        z = torch.normal(0, 1, (imgs.shape[0], noise_dim), device=device)


        gen_imgs = generator(z)

        loss_G = criterion(discriminator(gen_imgs), real)

        loss_G.backward()
        optimizer_G.step()

        # ===== Discriminator =====
        optimizer_D.zero_grad()

        real_loss = criterion(discriminator(real_imgs), real)
        fake_loss = criterion(discriminator(gen_imgs.detach()), fake)
        loss_D = (real_loss + fake_loss) / 2

        loss_D.backward()
        optimizer_D.step()

    # ===== save images and print logs =====
    save_image(gen_imgs.data[:25], f"{saved_dir}/{epoch+1}.png", nrow=5, normalize=True)
    print(f"epoch: {epoch}/{n_epochs}, D loss: {loss_D.item():.4f}, G loss: {loss_G.item():.4f}")

epoch: 0/10, D loss: 43.8907, G loss: 0.2839
epoch: 1/10, D loss: 50.0000, G loss: 0.0000
epoch: 2/10, D loss: 50.0000, G loss: 0.0000
epoch: 3/10, D loss: 50.0000, G loss: 0.0000
epoch: 4/10, D loss: 50.0000, G loss: 0.0000
epoch: 5/10, D loss: 48.4375, G loss: 1.0840
epoch: 6/10, D loss: 50.0000, G loss: 0.0000
epoch: 7/10, D loss: 50.0000, G loss: 0.0000
epoch: 8/10, D loss: 50.0000, G loss: 0.0000
epoch: 9/10, D loss: 50.0000, G loss: 0.0000
