In [33]:
import os
import numpy as np
import math

import torchvision.transforms as transforms
from torchvision.utils import save_image

from torch.utils.data import DataLoader
from torchvision import datasets
from torch.autograd import Variable

import torch.nn as nn
import torch.nn.functional as F
import torch

In [34]:
cuda = True if torch.cuda.is_available() else False

In [35]:
img_shape = [1,28,28]

In [36]:
class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()

        def block(in_feat, out_feat, normalize=True):
            layers = [nn.Linear(in_feat, out_feat)]
            if normalize:
                layers.append(nn.BatchNorm1d(out_feat, 0.8))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            return layers

        self.model = nn.Sequential(
            *block(100, 128, normalize=False),
            *block(128, 256),
            *block(256, 512),
            *block(512, 1024),
            nn.Linear(1024, int(np.prod(img_shape))),
            nn.Tanh()
        )

    def forward(self, z):
        img = self.model(z)
        img = img.view(img.size(0), *img_shape)
        return img

In [37]:
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()

        self.model = nn.Sequential(
            nn.Linear(int(np.prod(img_shape)), 512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(256, 1),
            nn.Sigmoid(),
        )

    def forward(self, img):
        img_flat = img.view(img.size(0), -1)
        validity = self.model(img_flat)

        return validity

In [38]:
# Loss function
adversarial_loss = torch.nn.BCELoss()

# Initialize generator and discriminator
generator = Generator()
discriminator = Discriminator()

if cuda:
    generator.cuda()
    discriminator.cuda()
    adversarial_loss.cuda()

In [39]:
# Configure data loader
os.makedirs("./data/mnist", exist_ok=True)
os.makedirs("./images", exist_ok = True)
dataloader = torch.utils.data.DataLoader(
    datasets.MNIST(
        "./data/mnist",
        train=True,
        download=True,
        transform=transforms.Compose(
            [transforms.Resize(28), transforms.ToTensor(), transforms.Normalize([0.5], [0.5])]
        ),
    ),
    batch_size= 128,
    shuffle=True,
    num_workers = 4
)

In [40]:
# Optimizers
optimizer_G = torch.optim.Adam(generator.parameters(), lr=0.0002, betas=(0.5, 0.999))
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999))

Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor

In [41]:
for epoch in range(200):
    for i, (imgs, _) in enumerate(dataloader):

        # Adversarial ground truths
        valid = Variable(Tensor(imgs.size(0), 1).fill_(1.0), requires_grad=False)
        fake = Variable(Tensor(imgs.size(0), 1).fill_(0.0), requires_grad=False)

        # Configure input
        real_imgs = Variable(imgs.type(Tensor))

        # -----------------
        #  Train Generator
        # -----------------

        optimizer_G.zero_grad()

        # Sample noise as generator input
        z = Variable(Tensor(np.random.normal(0, 1, (imgs.shape[0], 100))))

        # Generate a batch of images
        gen_imgs = generator(z)

        # Loss measures generator's ability to fool the discriminator
        g_loss = adversarial_loss(discriminator(gen_imgs), valid)

        g_loss.backward()
        optimizer_G.step()

        # ---------------------
        #  Train Discriminator
        # ---------------------

        optimizer_D.zero_grad()

        # Measure discriminator's ability to classify real from generated samples
        real_loss = adversarial_loss(discriminator(real_imgs), valid)
        fake_loss = adversarial_loss(discriminator(gen_imgs.detach()), fake)
        d_loss = (real_loss + fake_loss) / 2

        d_loss.backward()
        optimizer_D.step()

        print(
            "[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f]"
            % (epoch, 200, i, len(dataloader), d_loss.item(), g_loss.item())
        )

        batches_done = epoch * len(dataloader) + i
        if batches_done % 10 == 0:
            save_image(gen_imgs.data[:25], "images/%d.png" % batches_done, nrow=5, normalize=True)

[Epoch 0/200] [Batch 0/469] [D loss: 0.724849] [G loss: 0.696622]
[Epoch 0/200] [Batch 1/469] [D loss: 0.630939] [G loss: 0.693587]
[Epoch 0/200] [Batch 2/469] [D loss: 0.561310] [G loss: 0.690880]
[Epoch 0/200] [Batch 3/469] [D loss: 0.505745] [G loss: 0.688213]
[Epoch 0/200] [Batch 4/469] [D loss: 0.461377] [G loss: 0.684848]
[Epoch 0/200] [Batch 5/469] [D loss: 0.429889] [G loss: 0.680114]
[Epoch 0/200] [Batch 6/469] [D loss: 0.408132] [G loss: 0.674331]
[Epoch 0/200] [Batch 7/469] [D loss: 0.395487] [G loss: 0.667583]
[Epoch 0/200] [Batch 8/469] [D loss: 0.389711] [G loss: 0.658517]
[Epoch 0/200] [Batch 9/469] [D loss: 0.388282] [G loss: 0.648758]
[Epoch 0/200] [Batch 10/469] [D loss: 0.392770] [G loss: 0.637124]
[Epoch 0/200] [Batch 11/469] [D loss: 0.398350] [G loss: 0.623930]
[Epoch 0/200] [Batch 12/469] [D loss: 0.403310] [G loss: 0.610474]
[Epoch 0/200] [Batch 13/469] [D loss: 0.409586] [G loss: 0.599781]
[Epoch 0/200] [Batch 14/469] [D loss: 0.416483] [G loss: 0.588464]
[Epoc

[Epoch 0/200] [Batch 131/469] [D loss: 0.560447] [G loss: 0.705321]
[Epoch 0/200] [Batch 132/469] [D loss: 0.581612] [G loss: 1.127679]
[Epoch 0/200] [Batch 133/469] [D loss: 0.669071] [G loss: 0.455952]
[Epoch 0/200] [Batch 134/469] [D loss: 0.610622] [G loss: 1.317782]
[Epoch 0/200] [Batch 135/469] [D loss: 0.601936] [G loss: 0.566389]
[Epoch 0/200] [Batch 136/469] [D loss: 0.500685] [G loss: 1.145540]
[Epoch 0/200] [Batch 137/469] [D loss: 0.483496] [G loss: 1.041122]
[Epoch 0/200] [Batch 138/469] [D loss: 0.470790] [G loss: 0.925278]
[Epoch 0/200] [Batch 139/469] [D loss: 0.494365] [G loss: 1.030829]
[Epoch 0/200] [Batch 140/469] [D loss: 0.528993] [G loss: 0.785312]
[Epoch 0/200] [Batch 141/469] [D loss: 0.554385] [G loss: 1.053295]
[Epoch 0/200] [Batch 142/469] [D loss: 0.627359] [G loss: 0.539959]
[Epoch 0/200] [Batch 143/469] [D loss: 0.632106] [G loss: 1.094083]
[Epoch 0/200] [Batch 144/469] [D loss: 0.743012] [G loss: 0.349513]
[Epoch 0/200] [Batch 145/469] [D loss: 0.680841]

[Epoch 0/200] [Batch 261/469] [D loss: 0.542313] [G loss: 0.822623]
[Epoch 0/200] [Batch 262/469] [D loss: 0.551716] [G loss: 0.602413]
[Epoch 0/200] [Batch 263/469] [D loss: 0.565246] [G loss: 0.957908]
[Epoch 0/200] [Batch 264/469] [D loss: 0.601810] [G loss: 0.499061]
[Epoch 0/200] [Batch 265/469] [D loss: 0.603712] [G loss: 1.004744]
[Epoch 0/200] [Batch 266/469] [D loss: 0.635820] [G loss: 0.451632]
[Epoch 0/200] [Batch 267/469] [D loss: 0.581545] [G loss: 1.003213]
[Epoch 0/200] [Batch 268/469] [D loss: 0.562607] [G loss: 0.559312]
[Epoch 0/200] [Batch 269/469] [D loss: 0.527508] [G loss: 1.001637]
[Epoch 0/200] [Batch 270/469] [D loss: 0.556225] [G loss: 0.635649]
[Epoch 0/200] [Batch 271/469] [D loss: 0.564077] [G loss: 0.877012]
[Epoch 0/200] [Batch 272/469] [D loss: 0.557413] [G loss: 0.655387]
[Epoch 0/200] [Batch 273/469] [D loss: 0.562284] [G loss: 0.928673]
[Epoch 0/200] [Batch 274/469] [D loss: 0.582246] [G loss: 0.607902]
[Epoch 0/200] [Batch 275/469] [D loss: 0.574157]

[Epoch 0/200] [Batch 383/469] [D loss: 0.495261] [G loss: 0.976531]
[Epoch 0/200] [Batch 384/469] [D loss: 0.534115] [G loss: 0.606192]
[Epoch 0/200] [Batch 385/469] [D loss: 0.593867] [G loss: 1.317514]
[Epoch 0/200] [Batch 386/469] [D loss: 0.789377] [G loss: 0.302476]
[Epoch 0/200] [Batch 387/469] [D loss: 0.577560] [G loss: 1.104918]
[Epoch 0/200] [Batch 388/469] [D loss: 0.576022] [G loss: 0.754804]
[Epoch 0/200] [Batch 389/469] [D loss: 0.581772] [G loss: 0.690908]
[Epoch 0/200] [Batch 390/469] [D loss: 0.596950] [G loss: 0.857434]
[Epoch 0/200] [Batch 391/469] [D loss: 0.614691] [G loss: 0.661198]
[Epoch 0/200] [Batch 392/469] [D loss: 0.604495] [G loss: 0.777232]
[Epoch 0/200] [Batch 393/469] [D loss: 0.588466] [G loss: 0.736637]
[Epoch 0/200] [Batch 394/469] [D loss: 0.567743] [G loss: 0.792482]
[Epoch 0/200] [Batch 395/469] [D loss: 0.553639] [G loss: 0.730305]
[Epoch 0/200] [Batch 396/469] [D loss: 0.524035] [G loss: 0.892844]
[Epoch 0/200] [Batch 397/469] [D loss: 0.525790]

[Epoch 1/200] [Batch 40/469] [D loss: 0.635230] [G loss: 0.381984]
[Epoch 1/200] [Batch 41/469] [D loss: 0.367096] [G loss: 1.507630]
[Epoch 1/200] [Batch 42/469] [D loss: 0.341622] [G loss: 1.158306]
[Epoch 1/200] [Batch 43/469] [D loss: 0.366492] [G loss: 0.929082]
[Epoch 1/200] [Batch 44/469] [D loss: 0.356781] [G loss: 1.262566]
[Epoch 1/200] [Batch 45/469] [D loss: 0.361517] [G loss: 0.965257]
[Epoch 1/200] [Batch 46/469] [D loss: 0.374594] [G loss: 1.121864]
[Epoch 1/200] [Batch 47/469] [D loss: 0.388268] [G loss: 0.964411]
[Epoch 1/200] [Batch 48/469] [D loss: 0.381801] [G loss: 1.077340]
[Epoch 1/200] [Batch 49/469] [D loss: 0.405330] [G loss: 0.863467]
[Epoch 1/200] [Batch 50/469] [D loss: 0.407263] [G loss: 1.186506]
[Epoch 1/200] [Batch 51/469] [D loss: 0.427349] [G loss: 0.704260]
[Epoch 1/200] [Batch 52/469] [D loss: 0.408363] [G loss: 1.431227]
[Epoch 1/200] [Batch 53/469] [D loss: 0.440880] [G loss: 0.664591]
[Epoch 1/200] [Batch 54/469] [D loss: 0.400384] [G loss: 1.419

[Epoch 1/200] [Batch 165/469] [D loss: 0.470695] [G loss: 1.769358]
[Epoch 1/200] [Batch 166/469] [D loss: 0.606013] [G loss: 0.456546]
[Epoch 1/200] [Batch 167/469] [D loss: 0.460651] [G loss: 1.777956]
[Epoch 1/200] [Batch 168/469] [D loss: 0.450969] [G loss: 0.733152]
[Epoch 1/200] [Batch 169/469] [D loss: 0.346074] [G loss: 1.386763]
[Epoch 1/200] [Batch 170/469] [D loss: 0.381660] [G loss: 1.146583]
[Epoch 1/200] [Batch 171/469] [D loss: 0.439842] [G loss: 0.863319]
[Epoch 1/200] [Batch 172/469] [D loss: 0.425606] [G loss: 1.343529]
[Epoch 1/200] [Batch 173/469] [D loss: 0.496758] [G loss: 0.700015]
[Epoch 1/200] [Batch 174/469] [D loss: 0.490177] [G loss: 1.613183]
[Epoch 1/200] [Batch 175/469] [D loss: 0.600258] [G loss: 0.483243]
[Epoch 1/200] [Batch 176/469] [D loss: 0.462394] [G loss: 1.636812]
[Epoch 1/200] [Batch 177/469] [D loss: 0.466867] [G loss: 0.761545]
[Epoch 1/200] [Batch 178/469] [D loss: 0.392553] [G loss: 1.387416]
[Epoch 1/200] [Batch 179/469] [D loss: 0.378076]

[Epoch 1/200] [Batch 288/469] [D loss: 0.532779] [G loss: 0.635873]
[Epoch 1/200] [Batch 289/469] [D loss: 0.536750] [G loss: 1.561669]
[Epoch 1/200] [Batch 290/469] [D loss: 0.685221] [G loss: 0.383034]
[Epoch 1/200] [Batch 291/469] [D loss: 0.590060] [G loss: 1.933318]
[Epoch 1/200] [Batch 292/469] [D loss: 0.622755] [G loss: 0.445991]
[Epoch 1/200] [Batch 293/469] [D loss: 0.522949] [G loss: 1.507712]
[Epoch 1/200] [Batch 294/469] [D loss: 0.501058] [G loss: 0.704666]
[Epoch 1/200] [Batch 295/469] [D loss: 0.469594] [G loss: 1.149466]
[Epoch 1/200] [Batch 296/469] [D loss: 0.459767] [G loss: 0.902356]
[Epoch 1/200] [Batch 297/469] [D loss: 0.460158] [G loss: 1.047071]
[Epoch 1/200] [Batch 298/469] [D loss: 0.467828] [G loss: 0.930004]
[Epoch 1/200] [Batch 299/469] [D loss: 0.441110] [G loss: 1.012365]
[Epoch 1/200] [Batch 300/469] [D loss: 0.453304] [G loss: 1.010952]
[Epoch 1/200] [Batch 301/469] [D loss: 0.458615] [G loss: 0.902117]
[Epoch 1/200] [Batch 302/469] [D loss: 0.459042]

[Epoch 1/200] [Batch 411/469] [D loss: 0.354339] [G loss: 1.245337]
[Epoch 1/200] [Batch 412/469] [D loss: 0.377988] [G loss: 1.142257]
[Epoch 1/200] [Batch 413/469] [D loss: 0.359048] [G loss: 1.168160]
[Epoch 1/200] [Batch 414/469] [D loss: 0.344671] [G loss: 1.251360]
[Epoch 1/200] [Batch 415/469] [D loss: 0.349794] [G loss: 1.154816]
[Epoch 1/200] [Batch 416/469] [D loss: 0.338367] [G loss: 1.281850]
[Epoch 1/200] [Batch 417/469] [D loss: 0.348547] [G loss: 1.206413]
[Epoch 1/200] [Batch 418/469] [D loss: 0.358071] [G loss: 1.189598]
[Epoch 1/200] [Batch 419/469] [D loss: 0.340645] [G loss: 1.174511]
[Epoch 1/200] [Batch 420/469] [D loss: 0.305204] [G loss: 1.338374]
[Epoch 1/200] [Batch 421/469] [D loss: 0.306329] [G loss: 1.228981]
[Epoch 1/200] [Batch 422/469] [D loss: 0.306355] [G loss: 1.321090]
[Epoch 1/200] [Batch 423/469] [D loss: 0.318381] [G loss: 1.181433]
[Epoch 1/200] [Batch 424/469] [D loss: 0.315354] [G loss: 1.311347]
[Epoch 1/200] [Batch 425/469] [D loss: 0.305253]

[Epoch 2/200] [Batch 70/469] [D loss: 0.657250] [G loss: 0.418053]
[Epoch 2/200] [Batch 71/469] [D loss: 0.840408] [G loss: 3.383815]
[Epoch 2/200] [Batch 72/469] [D loss: 1.057527] [G loss: 0.183389]
[Epoch 2/200] [Batch 73/469] [D loss: 0.350616] [G loss: 2.185173]
[Epoch 2/200] [Batch 74/469] [D loss: 0.384687] [G loss: 2.326027]
[Epoch 2/200] [Batch 75/469] [D loss: 0.492040] [G loss: 0.707381]
[Epoch 2/200] [Batch 76/469] [D loss: 0.361821] [G loss: 1.874710]
[Epoch 2/200] [Batch 77/469] [D loss: 0.339871] [G loss: 1.176298]
[Epoch 2/200] [Batch 78/469] [D loss: 0.382077] [G loss: 1.230598]
[Epoch 2/200] [Batch 79/469] [D loss: 0.425638] [G loss: 1.219650]
[Epoch 2/200] [Batch 80/469] [D loss: 0.489615] [G loss: 0.953088]
[Epoch 2/200] [Batch 81/469] [D loss: 0.473998] [G loss: 1.185665]
[Epoch 2/200] [Batch 82/469] [D loss: 0.499102] [G loss: 0.850977]
[Epoch 2/200] [Batch 83/469] [D loss: 0.499016] [G loss: 1.571013]
[Epoch 2/200] [Batch 84/469] [D loss: 0.643658] [G loss: 0.453

[Epoch 2/200] [Batch 200/469] [D loss: 0.372542] [G loss: 1.033501]
[Epoch 2/200] [Batch 201/469] [D loss: 0.372064] [G loss: 1.129232]
[Epoch 2/200] [Batch 202/469] [D loss: 0.378667] [G loss: 1.362943]
[Epoch 2/200] [Batch 203/469] [D loss: 0.352146] [G loss: 1.011712]
[Epoch 2/200] [Batch 204/469] [D loss: 0.367898] [G loss: 1.629767]
[Epoch 2/200] [Batch 205/469] [D loss: 0.388695] [G loss: 0.881772]
[Epoch 2/200] [Batch 206/469] [D loss: 0.389937] [G loss: 1.797219]
[Epoch 2/200] [Batch 207/469] [D loss: 0.427859] [G loss: 0.774659]
[Epoch 2/200] [Batch 208/469] [D loss: 0.380402] [G loss: 1.985776]
[Epoch 2/200] [Batch 209/469] [D loss: 0.411522] [G loss: 0.761360]
[Epoch 2/200] [Batch 210/469] [D loss: 0.400330] [G loss: 1.795785]
[Epoch 2/200] [Batch 211/469] [D loss: 0.474648] [G loss: 0.738016]
[Epoch 2/200] [Batch 212/469] [D loss: 0.425259] [G loss: 1.888831]
[Epoch 2/200] [Batch 213/469] [D loss: 0.524491] [G loss: 0.583899]
[Epoch 2/200] [Batch 214/469] [D loss: 0.476331]

[Epoch 2/200] [Batch 322/469] [D loss: 0.435783] [G loss: 1.243188]
[Epoch 2/200] [Batch 323/469] [D loss: 0.490595] [G loss: 0.794984]
[Epoch 2/200] [Batch 324/469] [D loss: 0.470320] [G loss: 1.476798]
[Epoch 2/200] [Batch 325/469] [D loss: 0.572501] [G loss: 0.516588]
[Epoch 2/200] [Batch 326/469] [D loss: 0.648591] [G loss: 2.379569]
[Epoch 2/200] [Batch 327/469] [D loss: 0.805370] [G loss: 0.244768]
[Epoch 2/200] [Batch 328/469] [D loss: 0.539676] [G loss: 2.731040]
[Epoch 2/200] [Batch 329/469] [D loss: 0.353869] [G loss: 1.101723]
[Epoch 2/200] [Batch 330/469] [D loss: 0.364429] [G loss: 0.919139]
[Epoch 2/200] [Batch 331/469] [D loss: 0.417162] [G loss: 1.896034]
[Epoch 2/200] [Batch 332/469] [D loss: 0.437015] [G loss: 0.705978]
[Epoch 2/200] [Batch 333/469] [D loss: 0.412651] [G loss: 1.753827]
[Epoch 2/200] [Batch 334/469] [D loss: 0.438967] [G loss: 0.751614]
[Epoch 2/200] [Batch 335/469] [D loss: 0.459768] [G loss: 1.712366]
[Epoch 2/200] [Batch 336/469] [D loss: 0.607777]

KeyboardInterrupt: 