In [2]:
import argparse
import os
import numpy as np
import math

import torchvision.transforms as transforms
from torchvision.utils import save_image

from torch.utils.data import DataLoader
from torchvision import datasets
from torch.autograd import Variable

import torch.nn as nn
import torch.nn.functional as F
import torch

os.makedirs("images", exist_ok=True)

parser = argparse.ArgumentParser()
parser.add_argument("--n_epochs", type=int, default=200, help="number of epochs of training")
parser.add_argument("--batch_size", type=int, default=64, help="size of the batches")
parser.add_argument("--lr", type=float, default=0.0002, help="adam: learning rate")
parser.add_argument("--b1", type=float, default=0.5, help="adam: decay of first order momentum of gradient")
parser.add_argument("--b2", type=float, default=0.999, help="adam: decay of first order momentum of gradient")
parser.add_argument("--n_cpu", type=int, default=8, help="number of cpu threads to use during batch generation")
parser.add_argument("--latent_dim", type=int, default=100, help="dimensionality of the latent space")
parser.add_argument("--n_classes", type=int, default=10, help="number of classes for dataset")
parser.add_argument("--img_size", type=int, default=32, help="size of each image dimension")
parser.add_argument("--channels", type=int, default=1, help="number of image channels")
parser.add_argument("--sample_interval", type=int, default=400, help="interval between image sampling")

config_list = ['--n_epochs', '50', '--batch_size', '64']
opt = parser.parse_args(config_list)
print(opt)

cuda = True if torch.cuda.is_available() else False


def weights_init_normal(m):
    classname = m.__class__.__name__
    if classname.find("Conv") != -1:
        torch.nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find("BatchNorm2d") != -1:
        torch.nn.init.normal_(m.weight.data, 1.0, 0.02)
        torch.nn.init.constant_(m.bias.data, 0.0)


class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()

        self.label_emb = nn.Embedding(opt.n_classes, opt.latent_dim)

        self.init_size = opt.img_size // 4  # Initial size before upsampling
        self.l1 = nn.Sequential(nn.Linear(opt.latent_dim, 128 * self.init_size ** 2))

        self.conv_blocks = nn.Sequential(
            nn.BatchNorm2d(128),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(128, 128, 3, stride=1, padding=1),
            nn.BatchNorm2d(128, 0.8),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(128, 64, 3, stride=1, padding=1),
            nn.BatchNorm2d(64, 0.8),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, opt.channels, 3, stride=1, padding=1),
            nn.Tanh(),
        )

    def forward(self, noise, labels):
        gen_input = torch.mul(self.label_emb(labels), noise)
        out = self.l1(gen_input)
        out = out.view(out.shape[0], 128, self.init_size, self.init_size)
        img = self.conv_blocks(out)
        return img


class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()

        def discriminator_block(in_filters, out_filters, bn=True):
            """Returns layers of each discriminator block"""
            block = [nn.Conv2d(in_filters, out_filters, 3, 2, 1), nn.LeakyReLU(0.2, inplace=True), nn.Dropout2d(0.25)]
            if bn:
                block.append(nn.BatchNorm2d(out_filters, 0.8))
            return block

        self.conv_blocks = nn.Sequential(
            *discriminator_block(opt.channels, 16, bn=False),
            *discriminator_block(16, 32),
            *discriminator_block(32, 64),
            *discriminator_block(64, 128),
        )

        # The height and width of downsampled image
        ds_size = opt.img_size // 2 ** 4

        # Output layers
        self.adv_layer = nn.Sequential(nn.Linear(128 * ds_size ** 2, 1), nn.Sigmoid())
        self.aux_layer = nn.Sequential(nn.Linear(128 * ds_size ** 2, opt.n_classes), nn.Softmax())

    def forward(self, img):
        out = self.conv_blocks(img)
        out = out.view(out.shape[0], -1)
        validity = self.adv_layer(out)
        label = self.aux_layer(out)

        return validity, label


# Loss functions
adversarial_loss = torch.nn.BCELoss()
auxiliary_loss = torch.nn.CrossEntropyLoss()

# Initialize generator and discriminator
generator = Generator()
discriminator = Discriminator()

if cuda:
    generator.cuda()
    discriminator.cuda()
    adversarial_loss.cuda()
    auxiliary_loss.cuda()

# Initialize weights
generator.apply(weights_init_normal)
discriminator.apply(weights_init_normal)

# Configure data loader
os.makedirs("../data", exist_ok=True)
dataloader = torch.utils.data.DataLoader(
    datasets.MNIST(
        "../data",
        train=True,
        download=True,
        transform=transforms.Compose(
            [transforms.Resize(opt.img_size), transforms.ToTensor(), transforms.Normalize([0.5], [0.5])]
        ),
    ),
    batch_size=opt.batch_size,
    shuffle=True,
)

# Optimizers
optimizer_G = torch.optim.Adam(generator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))

FloatTensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor
LongTensor = torch.cuda.LongTensor if cuda else torch.LongTensor


def sample_image(n_row, batches_done):
    """Saves a grid of generated digits ranging from 0 to n_classes"""
    # Sample noise
    z = Variable(FloatTensor(np.random.normal(0, 1, (n_row ** 2, opt.latent_dim))))
    # Get labels ranging from 0 to n_classes for n rows
    labels = np.array([num for _ in range(n_row) for num in range(n_row)])
    labels = Variable(LongTensor(labels))
    gen_imgs = generator(z, labels)
    save_image(gen_imgs.data, "images/%d.png" % batches_done, nrow=n_row, normalize=True)


# ----------
#  Training
# ----------

for epoch in range(opt.n_epochs):
    for i, (imgs, labels) in enumerate(dataloader):

        batch_size = imgs.shape[0]

        # Adversarial ground truths
        valid = Variable(FloatTensor(batch_size, 1).fill_(1.0), requires_grad=False)
        fake = Variable(FloatTensor(batch_size, 1).fill_(0.0), requires_grad=False)

        # Configure input
        real_imgs = Variable(imgs.type(FloatTensor))
        labels = Variable(labels.type(LongTensor))

        # -----------------
        #  Train Generator
        # -----------------

        optimizer_G.zero_grad()

        # Sample noise and labels as generator input
        z = Variable(FloatTensor(np.random.normal(0, 1, (batch_size, opt.latent_dim))))
        gen_labels = Variable(LongTensor(np.random.randint(0, opt.n_classes, batch_size)))

        # Generate a batch of images
        gen_imgs = generator(z, gen_labels)

        # Loss measures generator's ability to fool the discriminator
        validity, pred_label = discriminator(gen_imgs)
        g_loss = 0.5 * (adversarial_loss(validity, valid) + auxiliary_loss(pred_label, gen_labels))

        g_loss.backward()
        optimizer_G.step()

        # ---------------------
        #  Train Discriminator
        # ---------------------

        optimizer_D.zero_grad()

        # Loss for real images
        real_pred, real_aux = discriminator(real_imgs)
        d_real_loss = (adversarial_loss(real_pred, valid) + auxiliary_loss(real_aux, labels)) / 2

        # Loss for fake images
        fake_pred, fake_aux = discriminator(gen_imgs.detach())
        d_fake_loss = (adversarial_loss(fake_pred, fake) + auxiliary_loss(fake_aux, gen_labels)) / 2

        # Total discriminator loss
        d_loss = (d_real_loss + d_fake_loss) / 2

        # Calculate discriminator accuracy
        pred = np.concatenate([real_aux.data.cpu().numpy(), fake_aux.data.cpu().numpy()], axis=0)
        gt = np.concatenate([labels.data.cpu().numpy(), gen_labels.data.cpu().numpy()], axis=0)
        d_acc = np.mean(np.argmax(pred, axis=1) == gt)

        d_loss.backward()
        optimizer_D.step()

        print(
            "[Epoch %d/%d] [Batch %d/%d] [D loss: %f, acc: %d%%] [G loss: %f]"
            % (epoch, opt.n_epochs, i, len(dataloader), d_loss.item(), 100 * d_acc, g_loss.item())
        )
        batches_done = epoch * len(dataloader) + i
        if batches_done % opt.sample_interval == 0:
            sample_image(n_row=10, batches_done=batches_done)

Namespace(b1=0.5, b2=0.999, batch_size=64, channels=1, img_size=32, latent_dim=100, lr=0.0002, n_classes=10, n_cpu=8, n_epochs=50, sample_interval=400)


  input = module(input)


[Epoch 0/50] [Batch 0/938] [D loss: 1.497996, acc: 11%] [G loss: 1.490640]
[Epoch 0/50] [Batch 1/938] [D loss: 1.497958, acc: 9%] [G loss: 1.490550]
[Epoch 0/50] [Batch 2/938] [D loss: 1.497775, acc: 10%] [G loss: 1.490907]
[Epoch 0/50] [Batch 3/938] [D loss: 1.498048, acc: 10%] [G loss: 1.491552]
[Epoch 0/50] [Batch 4/938] [D loss: 1.497958, acc: 6%] [G loss: 1.491710]
[Epoch 0/50] [Batch 5/938] [D loss: 1.497511, acc: 10%] [G loss: 1.491505]
[Epoch 0/50] [Batch 6/938] [D loss: 1.498003, acc: 7%] [G loss: 1.492139]
[Epoch 0/50] [Batch 7/938] [D loss: 1.497864, acc: 13%] [G loss: 1.492322]
[Epoch 0/50] [Batch 8/938] [D loss: 1.497711, acc: 10%] [G loss: 1.492143]
[Epoch 0/50] [Batch 9/938] [D loss: 1.497556, acc: 10%] [G loss: 1.492333]
[Epoch 0/50] [Batch 10/938] [D loss: 1.497551, acc: 12%] [G loss: 1.492511]
[Epoch 0/50] [Batch 11/938] [D loss: 1.497652, acc: 9%] [G loss: 1.493167]
[Epoch 0/50] [Batch 12/938] [D loss: 1.497694, acc: 10%] [G loss: 1.493444]
[Epoch 0/50] [Batch 13/938

[Epoch 0/50] [Batch 110/938] [D loss: 1.499318, acc: 11%] [G loss: 1.466677]
[Epoch 0/50] [Batch 111/938] [D loss: 1.502099, acc: 10%] [G loss: 1.468029]
[Epoch 0/50] [Batch 112/938] [D loss: 1.502769, acc: 22%] [G loss: 1.472178]
[Epoch 0/50] [Batch 113/938] [D loss: 1.503663, acc: 10%] [G loss: 1.477081]
[Epoch 0/50] [Batch 114/938] [D loss: 1.499975, acc: 13%] [G loss: 1.486194]
[Epoch 0/50] [Batch 115/938] [D loss: 1.499349, acc: 11%] [G loss: 1.493463]
[Epoch 0/50] [Batch 116/938] [D loss: 1.501559, acc: 14%] [G loss: 1.495735]
[Epoch 0/50] [Batch 117/938] [D loss: 1.499971, acc: 11%] [G loss: 1.503789]
[Epoch 0/50] [Batch 118/938] [D loss: 1.497524, acc: 10%] [G loss: 1.509943]
[Epoch 0/50] [Batch 119/938] [D loss: 1.500595, acc: 18%] [G loss: 1.512823]
[Epoch 0/50] [Batch 120/938] [D loss: 1.496360, acc: 10%] [G loss: 1.515543]
[Epoch 0/50] [Batch 121/938] [D loss: 1.493623, acc: 20%] [G loss: 1.520279]
[Epoch 0/50] [Batch 122/938] [D loss: 1.496781, acc: 16%] [G loss: 1.518486]

[Epoch 0/50] [Batch 222/938] [D loss: 1.500472, acc: 26%] [G loss: 1.497333]
[Epoch 0/50] [Batch 223/938] [D loss: 1.501295, acc: 24%] [G loss: 1.501004]
[Epoch 0/50] [Batch 224/938] [D loss: 1.500174, acc: 26%] [G loss: 1.497547]
[Epoch 0/50] [Batch 225/938] [D loss: 1.494352, acc: 26%] [G loss: 1.500774]
[Epoch 0/50] [Batch 226/938] [D loss: 1.494701, acc: 30%] [G loss: 1.500471]
[Epoch 0/50] [Batch 227/938] [D loss: 1.492400, acc: 31%] [G loss: 1.495899]
[Epoch 0/50] [Batch 228/938] [D loss: 1.494085, acc: 28%] [G loss: 1.496784]
[Epoch 0/50] [Batch 229/938] [D loss: 1.492884, acc: 22%] [G loss: 1.498771]
[Epoch 0/50] [Batch 230/938] [D loss: 1.485852, acc: 31%] [G loss: 1.497828]
[Epoch 0/50] [Batch 231/938] [D loss: 1.486344, acc: 34%] [G loss: 1.487892]
[Epoch 0/50] [Batch 232/938] [D loss: 1.488444, acc: 24%] [G loss: 1.493075]
[Epoch 0/50] [Batch 233/938] [D loss: 1.483617, acc: 32%] [G loss: 1.490319]
[Epoch 0/50] [Batch 234/938] [D loss: 1.484059, acc: 32%] [G loss: 1.481113]

[Epoch 0/50] [Batch 334/938] [D loss: 1.379000, acc: 47%] [G loss: 1.418265]
[Epoch 0/50] [Batch 335/938] [D loss: 1.396970, acc: 38%] [G loss: 1.449878]
[Epoch 0/50] [Batch 336/938] [D loss: 1.382226, acc: 45%] [G loss: 1.421604]
[Epoch 0/50] [Batch 337/938] [D loss: 1.393972, acc: 38%] [G loss: 1.491690]
[Epoch 0/50] [Batch 338/938] [D loss: 1.388107, acc: 39%] [G loss: 1.458247]
[Epoch 0/50] [Batch 339/938] [D loss: 1.372398, acc: 47%] [G loss: 1.430389]
[Epoch 0/50] [Batch 340/938] [D loss: 1.373102, acc: 42%] [G loss: 1.432564]
[Epoch 0/50] [Batch 341/938] [D loss: 1.375321, acc: 44%] [G loss: 1.462382]
[Epoch 0/50] [Batch 342/938] [D loss: 1.363744, acc: 45%] [G loss: 1.492967]
[Epoch 0/50] [Batch 343/938] [D loss: 1.353146, acc: 46%] [G loss: 1.452021]
[Epoch 0/50] [Batch 344/938] [D loss: 1.370804, acc: 45%] [G loss: 1.446389]
[Epoch 0/50] [Batch 345/938] [D loss: 1.341062, acc: 50%] [G loss: 1.466136]
[Epoch 0/50] [Batch 346/938] [D loss: 1.377402, acc: 43%] [G loss: 1.494237]

[Epoch 0/50] [Batch 445/938] [D loss: 1.318557, acc: 53%] [G loss: 1.422038]
[Epoch 0/50] [Batch 446/938] [D loss: 1.289460, acc: 57%] [G loss: 1.414374]
[Epoch 0/50] [Batch 447/938] [D loss: 1.299673, acc: 58%] [G loss: 1.374526]
[Epoch 0/50] [Batch 448/938] [D loss: 1.310941, acc: 55%] [G loss: 1.405298]
[Epoch 0/50] [Batch 449/938] [D loss: 1.293378, acc: 59%] [G loss: 1.412966]
[Epoch 0/50] [Batch 450/938] [D loss: 1.271735, acc: 61%] [G loss: 1.420752]
[Epoch 0/50] [Batch 451/938] [D loss: 1.272011, acc: 57%] [G loss: 1.389796]
[Epoch 0/50] [Batch 452/938] [D loss: 1.260887, acc: 65%] [G loss: 1.336004]
[Epoch 0/50] [Batch 453/938] [D loss: 1.307809, acc: 55%] [G loss: 1.413485]
[Epoch 0/50] [Batch 454/938] [D loss: 1.305077, acc: 57%] [G loss: 1.391756]
[Epoch 0/50] [Batch 455/938] [D loss: 1.264672, acc: 61%] [G loss: 1.339488]
[Epoch 0/50] [Batch 456/938] [D loss: 1.279811, acc: 60%] [G loss: 1.418109]
[Epoch 0/50] [Batch 457/938] [D loss: 1.293202, acc: 59%] [G loss: 1.376208]

[Epoch 0/50] [Batch 557/938] [D loss: 1.249811, acc: 63%] [G loss: 1.368356]
[Epoch 0/50] [Batch 558/938] [D loss: 1.210265, acc: 72%] [G loss: 1.318721]
[Epoch 0/50] [Batch 559/938] [D loss: 1.247462, acc: 67%] [G loss: 1.286583]
[Epoch 0/50] [Batch 560/938] [D loss: 1.235542, acc: 67%] [G loss: 1.347811]
[Epoch 0/50] [Batch 561/938] [D loss: 1.253295, acc: 64%] [G loss: 1.341233]
[Epoch 0/50] [Batch 562/938] [D loss: 1.215436, acc: 71%] [G loss: 1.291972]
[Epoch 0/50] [Batch 563/938] [D loss: 1.212409, acc: 69%] [G loss: 1.357070]
[Epoch 0/50] [Batch 564/938] [D loss: 1.205850, acc: 75%] [G loss: 1.281514]
[Epoch 0/50] [Batch 565/938] [D loss: 1.205143, acc: 73%] [G loss: 1.322948]
[Epoch 0/50] [Batch 566/938] [D loss: 1.201415, acc: 71%] [G loss: 1.314462]
[Epoch 0/50] [Batch 567/938] [D loss: 1.206472, acc: 71%] [G loss: 1.339463]
[Epoch 0/50] [Batch 568/938] [D loss: 1.231106, acc: 66%] [G loss: 1.333269]
[Epoch 0/50] [Batch 569/938] [D loss: 1.208922, acc: 70%] [G loss: 1.309179]

[Epoch 0/50] [Batch 669/938] [D loss: 1.161547, acc: 76%] [G loss: 1.290624]
[Epoch 0/50] [Batch 670/938] [D loss: 1.155250, acc: 81%] [G loss: 1.243828]
[Epoch 0/50] [Batch 671/938] [D loss: 1.165529, acc: 79%] [G loss: 1.222974]
[Epoch 0/50] [Batch 672/938] [D loss: 1.179815, acc: 71%] [G loss: 1.321985]
[Epoch 0/50] [Batch 673/938] [D loss: 1.133043, acc: 79%] [G loss: 1.267390]
[Epoch 0/50] [Batch 674/938] [D loss: 1.175322, acc: 74%] [G loss: 1.303156]
[Epoch 0/50] [Batch 675/938] [D loss: 1.140501, acc: 85%] [G loss: 1.269597]
[Epoch 0/50] [Batch 676/938] [D loss: 1.152860, acc: 76%] [G loss: 1.290836]
[Epoch 0/50] [Batch 677/938] [D loss: 1.161430, acc: 77%] [G loss: 1.309044]
[Epoch 0/50] [Batch 678/938] [D loss: 1.156308, acc: 82%] [G loss: 1.254215]
[Epoch 0/50] [Batch 679/938] [D loss: 1.151240, acc: 81%] [G loss: 1.237357]
[Epoch 0/50] [Batch 680/938] [D loss: 1.131635, acc: 83%] [G loss: 1.271891]
[Epoch 0/50] [Batch 681/938] [D loss: 1.156160, acc: 78%] [G loss: 1.304822]

[Epoch 0/50] [Batch 781/938] [D loss: 1.139185, acc: 84%] [G loss: 1.183530]
[Epoch 0/50] [Batch 782/938] [D loss: 1.110156, acc: 86%] [G loss: 1.238013]
[Epoch 0/50] [Batch 783/938] [D loss: 1.124671, acc: 84%] [G loss: 1.271017]
[Epoch 0/50] [Batch 784/938] [D loss: 1.102775, acc: 92%] [G loss: 1.245632]
[Epoch 0/50] [Batch 785/938] [D loss: 1.097831, acc: 86%] [G loss: 1.232757]
[Epoch 0/50] [Batch 786/938] [D loss: 1.153043, acc: 84%] [G loss: 1.204104]
[Epoch 0/50] [Batch 787/938] [D loss: 1.145089, acc: 85%] [G loss: 1.174116]
[Epoch 0/50] [Batch 788/938] [D loss: 1.113702, acc: 85%] [G loss: 1.227180]
[Epoch 0/50] [Batch 789/938] [D loss: 1.121340, acc: 85%] [G loss: 1.215502]
[Epoch 0/50] [Batch 790/938] [D loss: 1.171730, acc: 74%] [G loss: 1.263789]
[Epoch 0/50] [Batch 791/938] [D loss: 1.117337, acc: 83%] [G loss: 1.267567]
[Epoch 0/50] [Batch 792/938] [D loss: 1.114852, acc: 85%] [G loss: 1.208439]
[Epoch 0/50] [Batch 793/938] [D loss: 1.098983, acc: 91%] [G loss: 1.191018]

[Epoch 0/50] [Batch 892/938] [D loss: 1.114382, acc: 90%] [G loss: 1.210890]
[Epoch 0/50] [Batch 893/938] [D loss: 1.093527, acc: 90%] [G loss: 1.239601]
[Epoch 0/50] [Batch 894/938] [D loss: 1.111109, acc: 90%] [G loss: 1.184512]
[Epoch 0/50] [Batch 895/938] [D loss: 1.104240, acc: 90%] [G loss: 1.169980]
[Epoch 0/50] [Batch 896/938] [D loss: 1.083358, acc: 88%] [G loss: 1.211347]
[Epoch 0/50] [Batch 897/938] [D loss: 1.119747, acc: 87%] [G loss: 1.196724]
[Epoch 0/50] [Batch 898/938] [D loss: 1.112629, acc: 89%] [G loss: 1.215811]
[Epoch 0/50] [Batch 899/938] [D loss: 1.121751, acc: 84%] [G loss: 1.246258]
[Epoch 0/50] [Batch 900/938] [D loss: 1.119949, acc: 88%] [G loss: 1.173769]
[Epoch 0/50] [Batch 901/938] [D loss: 1.125484, acc: 85%] [G loss: 1.259220]
[Epoch 0/50] [Batch 902/938] [D loss: 1.106081, acc: 87%] [G loss: 1.223268]
[Epoch 0/50] [Batch 903/938] [D loss: 1.111878, acc: 90%] [G loss: 1.251742]
[Epoch 0/50] [Batch 904/938] [D loss: 1.124807, acc: 87%] [G loss: 1.242568]

[Epoch 1/50] [Batch 66/938] [D loss: 1.114855, acc: 88%] [G loss: 1.194634]
[Epoch 1/50] [Batch 67/938] [D loss: 1.103023, acc: 91%] [G loss: 1.149803]
[Epoch 1/50] [Batch 68/938] [D loss: 1.122215, acc: 88%] [G loss: 1.186125]
[Epoch 1/50] [Batch 69/938] [D loss: 1.103862, acc: 85%] [G loss: 1.205298]
[Epoch 1/50] [Batch 70/938] [D loss: 1.122338, acc: 93%] [G loss: 1.190117]
[Epoch 1/50] [Batch 71/938] [D loss: 1.131681, acc: 82%] [G loss: 1.302590]
[Epoch 1/50] [Batch 72/938] [D loss: 1.094315, acc: 87%] [G loss: 1.179724]
[Epoch 1/50] [Batch 73/938] [D loss: 1.103163, acc: 89%] [G loss: 1.216027]
[Epoch 1/50] [Batch 74/938] [D loss: 1.109790, acc: 91%] [G loss: 1.169244]
[Epoch 1/50] [Batch 75/938] [D loss: 1.101091, acc: 88%] [G loss: 1.221270]
[Epoch 1/50] [Batch 76/938] [D loss: 1.100112, acc: 90%] [G loss: 1.223912]
[Epoch 1/50] [Batch 77/938] [D loss: 1.099638, acc: 90%] [G loss: 1.207265]
[Epoch 1/50] [Batch 78/938] [D loss: 1.117592, acc: 86%] [G loss: 1.171158]
[Epoch 1/50]

KeyboardInterrupt: 