Reference: https://github.com/eriklindernoren/PyTorch-GAN

In [2]:
import argparse
import os
import numpy as np
import math

import torchvision.transforms as transforms
from torchvision.utils import save_image

from torch.utils.data import DataLoader
from torchvision import datasets
from torch.autograd import Variable

import torch.nn as nn
import torch.nn.functional as F
import torch

In [3]:
class Generator(nn.Module):
    """MLP generator that maps a latent noise vector to an image.

    Architecture: three hidden blocks (Linear -> ReLU -> BatchNorm1d) with
    widths 256, 512 and 1024, followed by Dropout, a Linear projection to the
    flattened image size, and Tanh.

    Args:
        latent_dim: Size of the input noise vector ``z``. Defaults to 100.
        img_shape: ``(channels, height, width)`` of the generated images.
            Defaults to ``(1, 28, 28)``.
    """

    def __init__(self, latent_dim=100, img_shape=(1, 28, 28)):
        super(Generator, self).__init__()
        self.latent_dim = latent_dim
        self.img_shape = tuple(img_shape)
        n_pixels = int(np.prod(self.img_shape))

        def block(in_feature, out_feature):
            """Return the layers of one hidden block: Linear -> ReLU -> BatchNorm1d."""
            return [
                nn.Linear(in_feature, out_feature),
                nn.ReLU(),
                nn.BatchNorm1d(out_feature),
            ]

        # The layer order is identical to the original hard-coded version, so
        # state_dict keys (model.0 ... model.11) remain checkpoint-compatible.
        self.model = nn.Sequential(
            *block(latent_dim, 256),
            *block(256, 512),
            *block(512, 1024),
            nn.Dropout(),
            nn.Linear(1024, n_pixels),
            # Tanh bounds the output to (-1, 1) so it matches the value range of
            # the real images (assumes they are normalized to [-1, 1]).
            nn.Tanh(),
        )

    def forward(self, z):
        """Generate a batch of images from noise.

        Args:
            z: Noise tensor of shape ``[batch, latent_dim]``.

        Returns:
            Tensor of shape ``[batch, *img_shape]`` with values in (-1, 1).
        """
        img = self.model(z)  # [batch, prod(img_shape)]
        return img.view(img.shape[0], *self.img_shape)


class Discriminator(nn.Module):
    """MLP discriminator that scores how likely an image is to be real.

    The image is flattened and passed through Dropout, two hidden
    Linear -> ReLU layers (512 and 256 units), and a final Linear -> Sigmoid
    that produces one probability per image.

    Args:
        img_shape: ``(channels, height, width)`` of the input images.
            Defaults to ``(1, 28, 28)``.
    """

    def __init__(self, img_shape=(1, 28, 28)):
        super(Discriminator, self).__init__()
        self.img_shape = tuple(img_shape)
        n_pixels = int(np.prod(self.img_shape))
        # The layer order matches the original, so state_dict keys are unchanged.
        self.model = nn.Sequential(
            nn.Dropout(),  # applied directly to the flattened input pixels
            nn.Linear(n_pixels, 512),
            nn.ReLU(),
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Linear(256, 1),
            # Sigmoid produces a probability, which is what nn.BCELoss expects.
            nn.Sigmoid(),
        )

    def forward(self, img):
        """Score a batch of images.

        Args:
            img: Tensor of shape ``[batch, *img_shape]``.

        Returns:
            Tensor of shape ``[batch, 1]`` with values in [0, 1].
        """
        # reshape (unlike view) also accepts non-contiguous inputs.
        flat = img.reshape(img.shape[0], -1)  # [batch, prod(img_shape)]
        return self.model(flat)

In [10]:
# Binary cross-entropy computed on the discriminator's sigmoid outputs.
adversarial_loss = nn.BCELoss()

# Build both networks (generator first) and move them to the current CUDA device.
generator, discriminator = Generator().cuda(), Discriminator().cuda()

In [11]:
# MNIST training split, downloaded into ./mnist on first use.
dataset = datasets.MNIST(
    "./mnist",
    train=True,
    download=True,
    transform=transforms.Compose(
        [
            transforms.Resize(28),  # resize so the shorter side is 28 px
            transforms.ToTensor(),  # PIL image -> float tensor in [0, 1]
            transforms.Normalize([0.5], [0.5]),  # (x - 0.5) / 0.5 -> [-1, 1]
        ]
    ),
)
# drop_last=True discards the final partial batch, so every batch has 64 samples.
dataloader = DataLoader(dataset, batch_size=64, shuffle=True, drop_last=True)


In [12]:
# Both networks use identical Adam settings (lr=2e-4, beta1=0.5).
# NOTE(review): beta2 is 0.9999 here, whereas Adam's default is 0.999; confirm this is intentional.
optimizer_G, optimizer_D = (
    torch.optim.Adam(net.parameters(), lr=0.0002, betas=(0.5, 0.9999))
    for net in (generator, discriminator)
)

In [None]:
# Training hyperparameters.
n_epochs = 200
latent_dim = 100  # must match the generator's input size

os.makedirs("./mlpgan", exist_ok=True)
# Create new tensors on whatever device the models live on.
device = next(generator.parameters()).device

for epoch in range(n_epochs):
    for i, (real_imgs, _) in enumerate(dataloader):
        real_imgs = real_imgs.to(device)
        batch_size = real_imgs.shape[0]

        # Adversarial targets, built once per batch: 1 = real, 0 = fake.
        valid = torch.ones((batch_size, 1), device=device)
        fake = torch.zeros((batch_size, 1), device=device)

        # -----------------
        #  Train Generator
        # -----------------
        optimizer_G.zero_grad()

        # Sample noise z ~ N(0, 1) directly on the device: [batch, latent_dim].
        z = torch.randn((batch_size, latent_dim), device=device)

        # Generate a batch of images: [batch, latent_dim] -> [batch, 1, 28, 28].
        gen_imgs = generator(z)

        # The generator is rewarded when D labels its samples as real.
        g_loss = adversarial_loss(discriminator(gen_imgs), valid)
        g_loss.backward()  # compute gradients
        optimizer_G.step()  # update generator weights

        # ---------------------
        #  Train Discriminator
        # ---------------------
        # zero_grad also clears the gradients that g_loss.backward() left in D.
        optimizer_D.zero_grad()

        real_loss = adversarial_loss(discriminator(real_imgs), valid)
        # detach() keeps the discriminator loss from backpropagating into G.
        fake_loss = adversarial_loss(discriminator(gen_imgs.detach()), fake)
        d_loss = (real_loss + fake_loss) / 2
        d_loss.backward()
        optimizer_D.step()

        batches_done = epoch * len(dataloader) + i
        if batches_done % 100 == 0:
            # Report the real epoch count (previously a hard-coded 64).
            print(
                "[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f]"
                % (epoch, n_epochs, i, len(dataloader), d_loss.item(), g_loss.item())
            )
        if batches_done % 2000 == 0:
            save_image(gen_imgs.detach()[:25], "mlpgan/%d.png" % batches_done, nrow=5, normalize=True)

[Epoch 0/64] [Batch 0/937] [D loss: 0.713945] [G loss: 0.627999]
[Epoch 0/64] [Batch 100/937] [D loss: 0.366326] [G loss: 2.716374]
[Epoch 0/64] [Batch 200/937] [D loss: 0.482585] [G loss: 2.162812]
[Epoch 0/64] [Batch 300/937] [D loss: 0.388261] [G loss: 2.229098]
[Epoch 0/64] [Batch 400/937] [D loss: 0.390624] [G loss: 2.104918]
[Epoch 0/64] [Batch 500/937] [D loss: 0.423960] [G loss: 1.506979]
[Epoch 0/64] [Batch 600/937] [D loss: 0.447793] [G loss: 1.867022]
[Epoch 0/64] [Batch 700/937] [D loss: 0.497519] [G loss: 1.089848]
[Epoch 0/64] [Batch 800/937] [D loss: 0.498748] [G loss: 0.638261]
[Epoch 0/64] [Batch 900/937] [D loss: 0.514719] [G loss: 1.108158]
[Epoch 1/64] [Batch 63/937] [D loss: 0.442374] [G loss: 1.247671]
[Epoch 1/64] [Batch 163/937] [D loss: 0.542918] [G loss: 1.148885]
[Epoch 1/64] [Batch 263/937] [D loss: 0.472351] [G loss: 1.340585]
[Epoch 1/64] [Batch 363/937] [D loss: 0.570362] [G loss: 0.804038]
[Epoch 1/64] [Batch 463/937] [D loss: 0.510023] [G loss: 1.345256