This notebook is adapted from the vanilla GAN example in the PyTorch-GAN repository: https://github.com/eriklindernoren/PyTorch-GAN/blob/master/implementations/gan/gan.py

In [None]:
import torch
import torch.nn as nn

from torchvision import datasets
import torchvision.transforms as transforms
from torchvision.utils import save_image
from torch.utils.data import DataLoader

import os
import numpy as np
from tqdm import tqdm

# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# 1. Dataset

In [None]:
# Root folder handed to torchvision for the MNIST download/cache.
data_dir = './data'
os.makedirs(name=data_dir, exist_ok=True)  # no-op if the folder already exists

In [None]:
# ToTensor scales pixels to [0, 1]; Normalize with mean=0.5, std=0.5 then maps
# them to [-1, 1], the same range as the generator's Tanh output.
train_transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5], std=[0.5]),
    ]
)

# MNIST training split, downloaded into data_dir on first use.
train_dataset = datasets.MNIST(
    root=data_dir,
    train=True,
    transform=train_transform,
    download=True,
)

# Shuffled mini-batches of 64 images for the GAN training loop.
train_dataloader = DataLoader(
    dataset=train_dataset,
    batch_size=64,
    shuffle=True,
)

# 2. Model

In [None]:
# Hyper-parameters shared by the model constructors and the training loop.
params = dict(
    num_classes=10,          # stored by Generator but unused by these unconditional models
    latent_space=100,        # length of the noise vector z fed to the generator
    input_size=(1, 28, 28),  # shape of one image; generator output is viewed to this
    lr=2e-4,                 # Adam learning rate (both optimizers)
    b1=0.5,                  # Adam beta1
    b2=0.999,                # Adam beta2
    epochs=100,              # passes over the training set
)

In [None]:
class Generator(nn.Module):
    """Fully connected generator mapping a latent noise vector to an image.

    ``params`` must provide:
        'num_classes'  -- stored on the instance; not used by this model.
        'latent_space' -- length of the input noise vector.
        'input_size'   -- shape of one generated image, e.g. (1, 28, 28).

    The final layer is Tanh, so generated pixel values lie in [-1, 1].
    """

    def __init__(self, params):
        super().__init__()
        self.num_classes = params['num_classes']
        self.latent_dim = params['latent_space']
        self.input_size = params['input_size']

        # Hidden widths: latent_dim -> 128 -> 256 -> 512 -> 1024.
        widths = (self.latent_dim, 128, 256, 512, 1024)
        layers = []
        for depth, (fan_in, fan_out) in enumerate(zip(widths, widths[1:])):
            # Only the very first stage is built without BatchNorm.
            layers += self.block(fan_in, fan_out, normalize=depth != 0)
        out_features = int(np.prod(self.input_size))
        layers += [nn.Linear(widths[-1], out_features), nn.Tanh()]
        self.model = nn.Sequential(*layers)

    def block(self, in_channels, out_channels, normalize=True):
        """Build one hidden stage: Linear -> [BatchNorm1d] -> LeakyReLU(0.2).

        Returns the layers as a list so the caller can splice them into an
        ``nn.Sequential``.
        """
        stage = [nn.Linear(in_channels, out_channels)]
        if normalize:
            # NOTE(review): BatchNorm1d's second positional argument is ``eps``,
            # so this sets eps=0.8 (PyTorch default is 1e-5). It may have been
            # intended as ``momentum``; kept unchanged here -- confirm intent.
            stage.append(nn.BatchNorm1d(out_channels, eps=0.8))
        stage.append(nn.LeakyReLU(0.2))
        return stage

    def forward(self, noise):
        """Return a batch of images shaped (batch, *input_size) generated from ``noise``."""
        flat = self.model(noise)
        batch = flat.size(0)
        return flat.view(batch, *self.input_size)

In [None]:
class Discriminator(nn.Module):
    """Fully connected discriminator scoring how "real" each input image looks.

    ``params`` must provide 'input_size', the shape of one image (e.g. (1, 28, 28)).
    Each image is flattened, passed through the MLP, and squashed by a Sigmoid
    into a single probability-like score in [0, 1].
    """

    def __init__(self, params):
        super().__init__()
        self.input_size = params['input_size']

        # Hidden widths: prod(input_size) -> 512 -> 256, each followed by LeakyReLU(0.2).
        widths = (int(np.prod(self.input_size)), 512, 256)
        layers = []
        for fan_in, fan_out in zip(widths, widths[1:]):
            layers.extend([nn.Linear(fan_in, fan_out), nn.LeakyReLU(0.2)])
        layers.extend([nn.Linear(widths[-1], 1), nn.Sigmoid()])
        self.model = nn.Sequential(*layers)

    def forward(self, img):
        """Return a (batch, 1) tensor of scores for a batch of images."""
        batch = img.size(0)
        return self.model(img.view(batch, -1))

# 3. Train

In [None]:
# Loss function: binary cross-entropy on the discriminator's Sigmoid output.
adversarial_loss = torch.nn.BCELoss()

# Initialize generator and discriminator and move them to the same device the
# training loop sends its batches/noise/labels to. Without .to(device), running
# on a CUDA machine fails with a device-mismatch error at the first forward pass.
generator = Generator(params).to(device)
discriminator = Discriminator(params).to(device)

# Optimizers -- built after the models are on their final device, as the
# torch.optim documentation recommends.
optimizer_G = torch.optim.Adam(generator.parameters(), lr=params['lr'], betas=(params['b1'], params['b2']))
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=params['lr'], betas=(params['b1'], params['b2']))

In [None]:
# save_image does not create parent folders; make sure the sample output
# directory exists before the first save (which happens at batches_done == 0).
os.makedirs("images", exist_ok=True)

batch_count = 0  # NOTE(review): never read below; kept in case other cells use it

for epoch in tqdm(range(params['epochs'])):
    for i, (imgs, _) in enumerate(train_dataloader):
        batch_size = imgs.size(0)

        # Adversarial ground truths (1 = real, 0 = generated), allocated
        # directly on the target device.
        valid = torch.ones(batch_size, 1, device=device)
        fake = torch.zeros(batch_size, 1, device=device)

        real_imgs = imgs.to(device)

        #-----------------
        # Train Generator
        #-----------------
        optimizer_G.zero_grad()

        # Sample noise as Generator input
        z = torch.randn(batch_size, params['latent_space'], device=device)

        # Generate a batch of images
        gen_imgs = generator(z)

        # Loss measures the generator's ability to fool the discriminator.
        # This backward also writes grads into the discriminator's parameters;
        # they are cleared by optimizer_D.zero_grad() before the D step below.
        g_loss = adversarial_loss(discriminator(gen_imgs), valid)

        g_loss.backward()
        optimizer_G.step()

        #---------------------
        # Train Discriminator
        #---------------------
        optimizer_D.zero_grad()

        # Measure discriminator's ability to classify real from generated samples.
        # gen_imgs is detached so this loss does not backpropagate into the generator.
        real_loss = adversarial_loss(discriminator(real_imgs), valid)
        fake_loss = adversarial_loss(discriminator(gen_imgs.detach()), fake)
        d_loss = (real_loss + fake_loss) / 2

        d_loss.backward()
        optimizer_D.step()

        # Periodically log losses and save a 5x5 grid of generated samples.
        batches_done = epoch * len(train_dataloader) + i
        if batches_done % 1000 == 0:
            print(
                "[Epoch %d/%d] [D loss: %f] [G loss: %f]"
                % (epoch, params['epochs'], d_loss.item(), g_loss.item())
            )
            save_image(gen_imgs.detach()[:25], "images/%d.png" % batches_done, nrow=5, normalize=True)