In [1]:
import os
import numpy as np
import math
import torch

from torchvision.datasets import MNIST
from torchvision import transforms
from torchvision.utils import save_image
from torch.utils.data import DataLoader
from torch.autograd import Variable

from torch import nn
from torch import optim
import torch.nn.functional as F

In [2]:
os.makedirs('../out_img/GAN_MNIST', exist_ok=True)  # exist_ok means that create the directory only the directory is not exist
config = {"n_epochs": 30, "batch_size": 64, "lr" : 0.0001, "b1": 0.5, "b2": 0.999, "n_cpu": 8, "latent_dim": 100, "channels": 1, "img_size": 28, "sample_interval": 400}

img_shape = (config.get("channels"), config.get("img_size"), config.get("img_size"))

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


In [3]:
print(type(img_shape))

<class 'tuple'>


### Data

In [6]:
dataloader = DataLoader(
    dataset=MNIST('./data/MNIST', train=True,
        download=False,
        transform=transforms.Compose(
            [transforms.Resize(config.get('img_size')), transforms.ToTensor(), transforms.Normalize([0.5], [0.5])]
        ),
    ),
    batch_size = config.get('batch_size'),
    shuffle = True
)

### Model Construct

In [5]:
from model import Discriminator, Generator
model_gen = Generator(img_shape, latent_dim=config.get("latent_dim")).to(device)
model_dis = Discriminator(img_shape).to(device)

ad_loss = nn.BCELoss().to(device)

# since adadelta have a good performance in comparison on MNIST classification
# optim.Adadelta(model_dis.parameters(), lr=1.0, rho=0.9, eps=1e-06)
# optimizer_G = optim.Adadelta(model_gen.parameters(), lr=1.0, rho=0.9, eps=1e-06)
# optimizer_D = optim.Adadelta(model_dis.parameters(), lr=1.0, rho=0.9, eps=1e-06)

optimizer_G = optim.Adam(model_gen.parameters(), lr=config.get("lr"), betas=(config.get("b1"), config.get("b2")))
optimizer_D = optim.Adam(model_dis.parameters(), lr=config.get("lr"), betas=(config.get("b1"), config.get("b2")))

In [6]:
Tensor = torch.cuda.FloatTensor if device == 'cuda' else torch.FloatTensor

### Trainning

In [7]:
for epoch in range(config.get('n_epochs')):
    for i, (imgs, _) in enumerate(dataloader):

        # Adversarial ground truths
        valid = Variable(Tensor(imgs.size(0), 1).fill_(1.0), requires_grad=False)
        fake = Variable(Tensor(imgs.size(0), 1).fill_(0.0), requires_grad=False)

        # Configure input
        real_imgs = Variable(imgs.type(Tensor))

        # **Train Generator**
        optimizer_G.zero_grad()
        # Sample noise as generator input
        z = Variable(Tensor(np.random.normal(0, 1, (imgs.shape[0], config.get('latent_dim')))))
        # Generate a batch of images
        # gen_imgs = model_gen.forward(z)
        gen_imgs = model_gen(z)
        # Loss measures generator's ability to fool the discriminator
                # pred = model_dis(gen_imgs)
                # print(gen_imgs)
        gen_loss = ad_loss(model_dis(gen_imgs), valid)

        gen_loss.backward()
        optimizer_G.step()


        # **Train Discriminator**
        optimizer_D.zero_grad()
        # Measure discriminator's ability to classify real from generate samples
        real_loss = ad_loss(model_dis(real_imgs), valid)
        fake_loss = ad_loss(model_dis(gen_imgs.detach()), fake) # ??? why we use detach here
        loss = (real_loss + fake_loss) / 2

        loss.backward()
        optimizer_D.step()


        print(
            "[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f]"
            % (epoch, config.get("n_epochs"), i, len(dataloader), loss.item(), gen_loss.item())
        )
        batches_done = epoch * len(dataloader) + i
        if batches_done % config.get("sample_interval") == 0:
            save_image(gen_imgs.data[:25], "../out_img/GAN_MNIST/%d.png" % batches_done, nrow=5, normalize=True)

[Epoch 0/30] [Batch 0/938] [D loss: 0.741871] [G loss: 0.707292]
[Epoch 0/30] [Batch 1/938] [D loss: 0.693320] [G loss: 0.705763]
[Epoch 0/30] [Batch 2/938] [D loss: 0.654969] [G loss: 0.704276]
[Epoch 0/30] [Batch 3/938] [D loss: 0.617747] [G loss: 0.702961]
[Epoch 0/30] [Batch 4/938] [D loss: 0.581673] [G loss: 0.701743]
[Epoch 0/30] [Batch 5/938] [D loss: 0.551511] [G loss: 0.700296]
[Epoch 0/30] [Batch 6/938] [D loss: 0.524657] [G loss: 0.698943]
[Epoch 0/30] [Batch 7/938] [D loss: 0.500298] [G loss: 0.697460]
[Epoch 0/30] [Batch 8/938] [D loss: 0.476328] [G loss: 0.695684]
[Epoch 0/30] [Batch 9/938] [D loss: 0.455547] [G loss: 0.693554]
[Epoch 0/30] [Batch 10/938] [D loss: 0.441888] [G loss: 0.691391]
[Epoch 0/30] [Batch 11/938] [D loss: 0.425480] [G loss: 0.688777]
[Epoch 0/30] [Batch 12/938] [D loss: 0.411928] [G loss: 0.685855]
[Epoch 0/30] [Batch 13/938] [D loss: 0.403105] [G loss: 0.682338]
[Epoch 0/30] [Batch 14/938] [D loss: 0.397918] [G loss: 0.678784]
[Epoch 0/30] [Batch 

KeyboardInterrupt: 