Referecnce : https://github.com/eriklindernoren/PyTorch-GAN

In [1]:
import argparse
import os
import numpy as np
import math

import torchvision.transforms as transforms
from torchvision.utils import save_image

from torch.utils.data import DataLoader
from torchvision import datasets
from torch.autograd import Variable

import torch.nn as nn
import torch.nn.functional as F
import torch

In [2]:
class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        
        def block(in_feat, out_feat, normalize=True):
            layers = [nn.Linear(in_feat, out_feat)]
            
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            
            if normalize:
                layers.append(nn.BatchNorm1d(out_feat, 0.8))
            
            return layers
        
        self.model = nn.Sequential(
                    *block(100,256,normalize=True),
                    *block(256,512,normalize=True),
                    *block(512,1024,normalize=True),
                    nn.Dropout(),
                    nn.Linear(1024, 784),
                    nn.Tanh()
                    )


    def forward(self, z):
        
        img = self.model(z)
        img = img.view(img.size(0), 1, 28, 28)
        
        return img


class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        
        self.model = nn.Sequential(
                    nn.Dropout(),
                    nn.Linear(784, 512),
                    nn.LeakyReLU(0.2),
                    nn.Linear(512, 256),
                    nn.LeakyReLU(0.2),
                    nn.Linear(256, 1),
                    nn.Sigmoid()
                    )

    def forward(self, img):
        
        img_flat = img.view(img.size(0), -1)
        out = self.model(img_flat)

        return out

In [3]:
# Loss function
adversarial_loss = torch.nn.BCELoss()

# Initialize generator and discriminator
generator = Generator().cuda()
discriminator = Discriminator().cuda()

In [4]:
dataset = datasets.MNIST("./mnist", train=True, download=True, 
                         transform=transforms.Compose([transforms.Resize(28), transforms.ToTensor(), transforms.Normalize([0.5], [0.5])]))
dataloader = torch.utils.data.DataLoader(dataset, batch_size=64, shuffle=True, drop_last=True)


Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./mnist/MNIST/raw/train-images-idx3-ubyte.gz


  0%|          | 0/9912422 [00:00<?, ?it/s]

Extracting ./mnist/MNIST/raw/train-images-idx3-ubyte.gz to ./mnist/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ./mnist/MNIST/raw/train-labels-idx1-ubyte.gz


  0%|          | 0/28881 [00:00<?, ?it/s]

Extracting ./mnist/MNIST/raw/train-labels-idx1-ubyte.gz to ./mnist/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ./mnist/MNIST/raw/t10k-images-idx3-ubyte.gz


  0%|          | 0/1648877 [00:00<?, ?it/s]

Extracting ./mnist/MNIST/raw/t10k-images-idx3-ubyte.gz to ./mnist/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ./mnist/MNIST/raw/t10k-labels-idx1-ubyte.gz


  0%|          | 0/4542 [00:00<?, ?it/s]

Extracting ./mnist/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./mnist/MNIST/raw



In [5]:
optimizer_G = torch.optim.Adam(generator.parameters(), lr=0.0002, betas=(0.5, 0.9999))
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=0.0002, betas=(0.5, 0.9999))

In [6]:
os.makedirs("./mlpgan", exist_ok=True)
for epoch in range(200):
    for i, (real_imgs, _) in enumerate(dataloader):

        # -----------------
        #  Train Generator
        # -----------------
        real_imgs = real_imgs.cuda()

        optimizer_G.zero_grad()

        # Sample noise as generator input
        z = torch.Tensor(np.random.normal(0, 1, (real_imgs.shape[0], 100))).cuda()

        # Generate a batch of images
        gen_imgs = generator(z)
        
        # Loss measures generator's ability to fool the discriminator
        g_loss = adversarial_loss(discriminator(gen_imgs), torch.ones((gen_imgs.size(0),1)).cuda())

        g_loss.backward()
        optimizer_G.step()

        # ---------------------
        #  Train Discriminator
        # ---------------------

        optimizer_D.zero_grad()

        # Measure discriminator's ability to classify real from generated samples
        real_loss = adversarial_loss(discriminator(real_imgs), torch.ones((gen_imgs.size(0),1)).cuda())
        fake_loss = adversarial_loss(discriminator(gen_imgs.detach()), torch.zeros((gen_imgs.size(0),1)).cuda())
        d_loss = (real_loss + fake_loss) / 2

        d_loss.backward()
        optimizer_D.step()

        batches_done = epoch * len(dataloader) + i
        
        if batches_done % 100 == 0:
            print(
                "[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f]"
                % (epoch, 64, i, len(dataloader), d_loss.item(), g_loss.item())
            )

        batches_done = epoch * len(dataloader) + i
        if batches_done % 2000 == 0:
            save_image(gen_imgs.data[:25], "mlpgan/%d.png" % batches_done, nrow=5, normalize=True)

[Epoch 0/64] [Batch 0/937] [D loss: 0.695439] [G loss: 0.675042]
[Epoch 0/64] [Batch 100/937] [D loss: 0.321986] [G loss: 0.846146]
[Epoch 0/64] [Batch 200/937] [D loss: 0.473571] [G loss: 0.766429]
[Epoch 0/64] [Batch 300/937] [D loss: 0.386632] [G loss: 0.967008]
[Epoch 0/64] [Batch 400/937] [D loss: 0.585910] [G loss: 0.692223]
[Epoch 0/64] [Batch 500/937] [D loss: 0.434597] [G loss: 1.185722]
[Epoch 0/64] [Batch 600/937] [D loss: 0.385554] [G loss: 1.080754]
[Epoch 0/64] [Batch 700/937] [D loss: 0.390274] [G loss: 1.138490]
[Epoch 0/64] [Batch 800/937] [D loss: 0.366427] [G loss: 2.059252]
[Epoch 0/64] [Batch 900/937] [D loss: 0.415237] [G loss: 1.506824]
[Epoch 1/64] [Batch 63/937] [D loss: 0.372891] [G loss: 1.246332]
[Epoch 1/64] [Batch 163/937] [D loss: 0.498919] [G loss: 0.651340]
[Epoch 1/64] [Batch 263/937] [D loss: 0.426200] [G loss: 1.418907]
[Epoch 1/64] [Batch 363/937] [D loss: 0.628441] [G loss: 0.586570]
[Epoch 1/64] [Batch 463/937] [D loss: 0.619569] [G loss: 0.533368