In [30]:
import os
import argparse
import numpy as np
import torch
import torchvision
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.autograd import Variable
from torchvision import datasets, transforms
from torchvision.utils import save_image

os.makedirs('images', exist_ok=True)

parser = argparse.ArgumentParser()
parser.add_argument("--n_epochs", type=int, default=50, help="number of epochs of training")
parser.add_argument("--batch_size", type=int, default=64, help="size of the batches")
parser.add_argument("--lr", type=float, default=0.0002, help="adam: learning rate")
parser.add_argument("--b1", type=float, default=0.5, help="adam: decay of first order momentum of gradient")
parser.add_argument("--b2", type=float, default=0.999, help="adam: decay of second order momentum of gradient")
parser.add_argument("--n_cpu", type=int, default=8, help="number of cpu threads to use during batch generation")
parser.add_argument("--latent_dim", type=int, default=100, help="dimensionality of the latent space")
parser.add_argument("--n_classes", type=int, default=10, help="number of classes for dataset")
parser.add_argument("--img_size", type=int, default=32, help="size of each image dimension")
parser.add_argument("--channels", type=int, default=1, help="number of image channels")
parser.add_argument("--sample_interval", type=int, default=400, help="interval between image sampling")
parser.add_argument("--cuda", type=bool, default=torch.cuda.is_available(), help="use cuda or not")
config_list = ['--n_epochs', '50', '--batch_size', '64']
opt = parser.parse_args(config_list)
print(opt)

device = 'cuda' if opt.cuda else 'cpu'

# G
class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self.label_emb = nn.Embedding(opt.n_classes, opt.latent_dim)  # 词向量字典
        self.init_size = opt.img_size // 4  # Initial size before upsampling
        self.L1 = nn.Linear(opt.latent_dim, 128*self.init_size**2)
        
        self.conv_blocks = nn.Sequential(
            nn.BatchNorm2d(128),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(128, 128, 3, stride=1, padding=1),
            nn.BatchNorm2d(128, 0.8),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(128, 64, 3, stride=1, padding=1),
            nn.BatchNorm2d(64, 0.8),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, opt.channels, 3, stride=1, padding=1),
            nn.Tanh(),
        )
        
    def forward(self, noise, labels):
        gen_input = torch.mul(self.label_emb(labels), noise)  # key of ACGAN?
        out = self.L1(gen_input)
        out = out.view(out.size(0), 128, self.init_size, self.init_size)
        img = self.conv_blocks(out)
        return img

# D
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        
        def discriminator_block(in_filters, out_filters, bn=True):
            # returns layers of each discriminator block
            block = [nn.Conv2d(in_filters, out_filters, 3, 2, 1), 
                     nn.LeakyReLU(0.2, inplace=True),
                     nn.Dropout2d(0.25)]
            if bn:
                block.append(nn.BatchNorm2d(out_filters, 0.8))
            return block
        
        self.conv_blocks = nn.Sequential(
            *discriminator_block(opt.channels, 16, bn=False),
            *discriminator_block(16, 32),
            *discriminator_block(32, 64),
            *discriminator_block(64, 128),
        )
        
        # The height and width of downsampled image
        ds_size = opt.img_size // 2 ** 4 
        
        # output layers
        self.adv_layer = nn.Sequential(nn.Linear(128*ds_size**2, 1), nn.Sigmoid())
        self.aux_layer = nn.Sequential(nn.Linear(128*ds_size**2, opt.n_classes), nn.Softmax())
        
    def forward(self, img):
        out = self.conv_blocks(img)
        out = out.view(out.size(0), -1)
        validity = self.adv_layer(out)

        label = self.aux_layer(out)
        
        return validity, label     

Namespace(b1=0.5, b2=0.999, batch_size=64, channels=1, cuda=True, img_size=32, latent_dim=100, lr=0.0002, n_classes=10, n_cpu=8, n_epochs=50, sample_interval=400)


In [42]:
def weights_init_normal(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find('BatchNorm2d') != -1:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0.0)

# Loss functions
adversarial_loss = torch.nn.BCELoss().to(device)
auxiliary_loss = torch.nn.CrossEntropyLoss().to(device)

# Initialize generator and discriminator
generator = Generator().to(device)
discriminator = Discriminator().to(device)

# Initialize weights
generator.apply(weights_init_normal)
discriminator.apply(weights_init_normal)

# Configure data loader
transform = transforms.Compose([
    transforms.Resize(opt.img_size), 
    transforms.ToTensor(), 
    transforms.Normalize([0.5], [0.5])])
dataset = datasets.MNIST(root='../data', train=True, download=False, transform=transform)
dataloader = DataLoader(dataset, batch_size=opt.batch_size, shuffle=True)

# Optimizers
optimizer_G = optim.Adam(generator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))
optimizer_D = optim.Adam(discriminator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))

FloatTensor = torch.cuda.FloatTensor if opt.cuda else torch.FloatTensor
LongTensor = torch.cuda.LongTensor if opt.cuda else torch.LongTensor

def sample_image(n_row, batches_done):
    """Saves a grid of generated digits ranging from 0 to n_classes"""
    # Sample noise
    z = Variable(FloatTensor(np.random.normal(0, 1, (n_row ** 2, opt.latent_dim))))
    # Get labels ranging from 0 to n_classes for n rows
    labels = np.array([num for _ in range(n_row) for num in range(n_row)])
    labels = Variable(LongTensor(labels))
    gen_imgs = generator(z, labels)
    save_image(gen_imgs.data, "images/%d.png" % batches_done, nrow=n_row, normalize=True)


In [44]:
# ----------
#  Training
# ----------
for epoch in range(opt.n_epochs):
    for i, (imgs, labels) in enumerate(dataloader):
        batch_size = imgs.size(0)
        
        # Adversarial ground truths
        valid = Variable(FloatTensor(batch_size, 1).fill_(1.0), requires_grad=False)
        fake = Variable(FloatTensor(batch_size, 1).fill_(0.0), requires_grad=False)
        
        # configure input
        imgs = Variable(imgs.type(FloatTensor))
        labels = Variable(labels.type(LongTensor))
        
        # -----------------
        #  Train Generator
        # -----------------

        optimizer_G.zero_grad()
        
        # Sample noise and labels as generator input
        z = Variable(FloatTensor(np.random.normal(0, 1, (batch_size, opt.latent_dim))))
        gen_labels = Variable(LongTensor(np.random.randint(0, opt.n_classes, batch_size)))
        
        # Generate a batch of images
        gen_imgs = generator(z, gen_labels)
        
        # Loss measures generator's ability to fool the discriminator
        validity, pred_label = discriminator(gen_imgs)
        g_loss = 0.5 * (adversarial_loss(validity, valid) + auxiliary_loss(pred_label, gen_labels))
        
        g_loss.backward()
        optimizer_G.step()
        
        # ---------------------
        #  Train Discriminator
        # ---------------------
        
        optimizer_D.zero_grad()
        
        # Loss for real images
        real_pred, real_aux = discriminator(real_imgs)
        d_real_loss = (adversarial_loss(real_pred, valid) + auxiliary_loss(real_aux, labels)) / 2

        # Loss for fake images
        fake_pred, fake_aux = discriminator(gen_imgs.detach())
        d_fake_loss = (adversarial_loss(fake_pred, fake) + auxiliary_loss(fake_aux, gen_labels)) / 2

        # Total discriminator loss
        d_loss = (d_real_loss + d_fake_loss) * 0.5
        
        # Calculate discriminator accuracy
        pred = np.concatenate([real_aux.data.cpu().numpy(), fake_aux.data.cpu().numpy()], axis=0)
        gt = np.concatenate([labels.data.cpu().numpy(), gen_labels.data.cpu().numpy()], axis=0)
        d_acc = np.mean(np.argmax(pred, axis=1) == gt)
        
        d_loss.backward()
        optimizer_D.step()
        
        print(
            "[Epoch %d/%d] [Batch %d/%d] [D loss: %f, acc: %d%%] [G loss: %f]"
            % (epoch, opt.n_epochs, i, len(dataloader), d_loss.item(), 100 * d_acc, g_loss.item())
        )
        
        batches_done = epoch * len(dataloader) + i
        if batches_done % opt.sample_interval == 0:
            sample_image(n_row=10, batches_done=batches_done)

[Epoch 0/50] [Batch 0/938] [D loss: 1.499155, acc: 12%] [G loss: 1.499084]
[Epoch 0/50] [Batch 1/938] [D loss: 1.497711, acc: 14%] [G loss: 1.498943]
[Epoch 0/50] [Batch 2/938] [D loss: 1.497965, acc: 4%] [G loss: 1.497679]
[Epoch 0/50] [Batch 3/938] [D loss: 1.498230, acc: 10%] [G loss: 1.495442]
[Epoch 0/50] [Batch 4/938] [D loss: 1.497657, acc: 8%] [G loss: 1.492451]
[Epoch 0/50] [Batch 5/938] [D loss: 1.494534, acc: 11%] [G loss: 1.494542]
[Epoch 0/50] [Batch 6/938] [D loss: 1.497229, acc: 10%] [G loss: 1.492394]
[Epoch 0/50] [Batch 7/938] [D loss: 1.495846, acc: 10%] [G loss: 1.491083]
[Epoch 0/50] [Batch 8/938] [D loss: 1.494712, acc: 9%] [G loss: 1.487605]
[Epoch 0/50] [Batch 9/938] [D loss: 1.496131, acc: 14%] [G loss: 1.489227]
[Epoch 0/50] [Batch 10/938] [D loss: 1.497370, acc: 14%] [G loss: 1.480361]
[Epoch 0/50] [Batch 11/938] [D loss: 1.496854, acc: 11%] [G loss: 1.485178]
[Epoch 0/50] [Batch 12/938] [D loss: 1.498870, acc: 8%] [G loss: 1.486218]
[Epoch 0/50] [Batch 13/938

[Epoch 0/50] [Batch 109/938] [D loss: 1.497828, acc: 9%] [G loss: 1.502709]
[Epoch 0/50] [Batch 110/938] [D loss: 1.499715, acc: 7%] [G loss: 1.500378]
[Epoch 0/50] [Batch 111/938] [D loss: 1.497654, acc: 9%] [G loss: 1.499464]
[Epoch 0/50] [Batch 112/938] [D loss: 1.497206, acc: 8%] [G loss: 1.499712]
[Epoch 0/50] [Batch 113/938] [D loss: 1.496492, acc: 13%] [G loss: 1.502630]
[Epoch 0/50] [Batch 114/938] [D loss: 1.497810, acc: 7%] [G loss: 1.500668]
[Epoch 0/50] [Batch 115/938] [D loss: 1.496356, acc: 10%] [G loss: 1.502636]
[Epoch 0/50] [Batch 116/938] [D loss: 1.496833, acc: 10%] [G loss: 1.501657]
[Epoch 0/50] [Batch 117/938] [D loss: 1.496523, acc: 9%] [G loss: 1.501226]
[Epoch 0/50] [Batch 118/938] [D loss: 1.496984, acc: 8%] [G loss: 1.499249]
[Epoch 0/50] [Batch 119/938] [D loss: 1.495960, acc: 14%] [G loss: 1.497038]
[Epoch 0/50] [Batch 120/938] [D loss: 1.498482, acc: 8%] [G loss: 1.497526]
[Epoch 0/50] [Batch 121/938] [D loss: 1.495639, acc: 11%] [G loss: 1.496695]
[Epoch 

[Epoch 0/50] [Batch 221/938] [D loss: 1.496258, acc: 16%] [G loss: 1.512390]
[Epoch 0/50] [Batch 222/938] [D loss: 1.495094, acc: 12%] [G loss: 1.514421]
[Epoch 0/50] [Batch 223/938] [D loss: 1.496073, acc: 9%] [G loss: 1.517240]
[Epoch 0/50] [Batch 224/938] [D loss: 1.494208, acc: 8%] [G loss: 1.512021]
[Epoch 0/50] [Batch 225/938] [D loss: 1.495912, acc: 10%] [G loss: 1.508091]
[Epoch 0/50] [Batch 226/938] [D loss: 1.495001, acc: 12%] [G loss: 1.506085]
[Epoch 0/50] [Batch 227/938] [D loss: 1.497427, acc: 7%] [G loss: 1.500367]
[Epoch 0/50] [Batch 228/938] [D loss: 1.495669, acc: 10%] [G loss: 1.495365]
[Epoch 0/50] [Batch 229/938] [D loss: 1.494844, acc: 10%] [G loss: 1.491439]
[Epoch 0/50] [Batch 230/938] [D loss: 1.496108, acc: 8%] [G loss: 1.488719]
[Epoch 0/50] [Batch 231/938] [D loss: 1.493229, acc: 11%] [G loss: 1.491121]
[Epoch 0/50] [Batch 232/938] [D loss: 1.489741, acc: 11%] [G loss: 1.484467]
[Epoch 0/50] [Batch 233/938] [D loss: 1.491751, acc: 7%] [G loss: 1.480133]
[Epo

[Epoch 0/50] [Batch 333/938] [D loss: 1.497985, acc: 12%] [G loss: 1.491590]
[Epoch 0/50] [Batch 334/938] [D loss: 1.500565, acc: 9%] [G loss: 1.489872]
[Epoch 0/50] [Batch 335/938] [D loss: 1.499517, acc: 16%] [G loss: 1.490872]
[Epoch 0/50] [Batch 336/938] [D loss: 1.499820, acc: 15%] [G loss: 1.490443]
[Epoch 0/50] [Batch 337/938] [D loss: 1.497273, acc: 10%] [G loss: 1.496791]
[Epoch 0/50] [Batch 338/938] [D loss: 1.500413, acc: 7%] [G loss: 1.492920]
[Epoch 0/50] [Batch 339/938] [D loss: 1.497823, acc: 9%] [G loss: 1.496196]
[Epoch 0/50] [Batch 340/938] [D loss: 1.499501, acc: 14%] [G loss: 1.501667]
[Epoch 0/50] [Batch 341/938] [D loss: 1.493570, acc: 9%] [G loss: 1.504094]
[Epoch 0/50] [Batch 342/938] [D loss: 1.492983, acc: 8%] [G loss: 1.506842]
[Epoch 0/50] [Batch 343/938] [D loss: 1.496549, acc: 9%] [G loss: 1.505634]
[Epoch 0/50] [Batch 344/938] [D loss: 1.494685, acc: 10%] [G loss: 1.505931]
[Epoch 0/50] [Batch 345/938] [D loss: 1.494031, acc: 10%] [G loss: 1.500984]
[Epoc

[Epoch 0/50] [Batch 444/938] [D loss: 1.493438, acc: 9%] [G loss: 1.494959]
[Epoch 0/50] [Batch 445/938] [D loss: 1.494071, acc: 13%] [G loss: 1.490752]
[Epoch 0/50] [Batch 446/938] [D loss: 1.491823, acc: 10%] [G loss: 1.490394]
[Epoch 0/50] [Batch 447/938] [D loss: 1.490403, acc: 10%] [G loss: 1.497918]
[Epoch 0/50] [Batch 448/938] [D loss: 1.487457, acc: 14%] [G loss: 1.508523]
[Epoch 0/50] [Batch 449/938] [D loss: 1.491272, acc: 7%] [G loss: 1.508950]
[Epoch 0/50] [Batch 450/938] [D loss: 1.486463, acc: 13%] [G loss: 1.516187]
[Epoch 0/50] [Batch 451/938] [D loss: 1.490133, acc: 14%] [G loss: 1.522218]
[Epoch 0/50] [Batch 452/938] [D loss: 1.491556, acc: 12%] [G loss: 1.509793]
[Epoch 0/50] [Batch 453/938] [D loss: 1.491361, acc: 11%] [G loss: 1.509741]
[Epoch 0/50] [Batch 454/938] [D loss: 1.497013, acc: 10%] [G loss: 1.501536]
[Epoch 0/50] [Batch 455/938] [D loss: 1.490333, acc: 11%] [G loss: 1.489760]
[Epoch 0/50] [Batch 456/938] [D loss: 1.502344, acc: 11%] [G loss: 1.480479]
[

[Epoch 0/50] [Batch 556/938] [D loss: 1.486392, acc: 7%] [G loss: 1.533645]
[Epoch 0/50] [Batch 557/938] [D loss: 1.486788, acc: 7%] [G loss: 1.538254]
[Epoch 0/50] [Batch 558/938] [D loss: 1.494843, acc: 10%] [G loss: 1.542469]
[Epoch 0/50] [Batch 559/938] [D loss: 1.493993, acc: 9%] [G loss: 1.526935]
[Epoch 0/50] [Batch 560/938] [D loss: 1.501953, acc: 10%] [G loss: 1.507260]
[Epoch 0/50] [Batch 561/938] [D loss: 1.499775, acc: 8%] [G loss: 1.491048]
[Epoch 0/50] [Batch 562/938] [D loss: 1.493308, acc: 11%] [G loss: 1.472588]
[Epoch 0/50] [Batch 563/938] [D loss: 1.492430, acc: 12%] [G loss: 1.481126]
[Epoch 0/50] [Batch 564/938] [D loss: 1.503189, acc: 13%] [G loss: 1.488338]
[Epoch 0/50] [Batch 565/938] [D loss: 1.490838, acc: 12%] [G loss: 1.508767]
[Epoch 0/50] [Batch 566/938] [D loss: 1.478609, acc: 12%] [G loss: 1.541586]
[Epoch 0/50] [Batch 567/938] [D loss: 1.481227, acc: 8%] [G loss: 1.547475]
[Epoch 0/50] [Batch 568/938] [D loss: 1.460467, acc: 11%] [G loss: 1.576144]
[Epo

[Epoch 0/50] [Batch 668/938] [D loss: 1.467480, acc: 13%] [G loss: 1.559131]
[Epoch 0/50] [Batch 669/938] [D loss: 1.462562, acc: 8%] [G loss: 1.552436]
[Epoch 0/50] [Batch 670/938] [D loss: 1.438025, acc: 5%] [G loss: 1.533970]
[Epoch 0/50] [Batch 671/938] [D loss: 1.437545, acc: 8%] [G loss: 1.567460]
[Epoch 0/50] [Batch 672/938] [D loss: 1.458219, acc: 13%] [G loss: 1.522333]
[Epoch 0/50] [Batch 673/938] [D loss: 1.448864, acc: 10%] [G loss: 1.516344]
[Epoch 0/50] [Batch 674/938] [D loss: 1.458945, acc: 8%] [G loss: 1.609990]
[Epoch 0/50] [Batch 675/938] [D loss: 1.454081, acc: 10%] [G loss: 1.576171]
[Epoch 0/50] [Batch 676/938] [D loss: 1.482520, acc: 10%] [G loss: 1.576512]
[Epoch 0/50] [Batch 677/938] [D loss: 1.458246, acc: 10%] [G loss: 1.561733]
[Epoch 0/50] [Batch 678/938] [D loss: 1.472590, acc: 11%] [G loss: 1.554508]
[Epoch 0/50] [Batch 679/938] [D loss: 1.458266, acc: 8%] [G loss: 1.607667]
[Epoch 0/50] [Batch 680/938] [D loss: 1.461702, acc: 7%] [G loss: 1.588530]
[Epoc

[Epoch 0/50] [Batch 780/938] [D loss: 1.419242, acc: 14%] [G loss: 1.670048]
[Epoch 0/50] [Batch 781/938] [D loss: 1.419101, acc: 9%] [G loss: 1.604331]
[Epoch 0/50] [Batch 782/938] [D loss: 1.437392, acc: 14%] [G loss: 1.582093]
[Epoch 0/50] [Batch 783/938] [D loss: 1.415623, acc: 9%] [G loss: 1.648271]
[Epoch 0/50] [Batch 784/938] [D loss: 1.385115, acc: 11%] [G loss: 1.622810]
[Epoch 0/50] [Batch 785/938] [D loss: 1.446378, acc: 9%] [G loss: 1.602826]
[Epoch 0/50] [Batch 786/938] [D loss: 1.429156, acc: 9%] [G loss: 1.478711]
[Epoch 0/50] [Batch 787/938] [D loss: 1.407942, acc: 15%] [G loss: 1.564596]
[Epoch 0/50] [Batch 788/938] [D loss: 1.444771, acc: 7%] [G loss: 1.632799]
[Epoch 0/50] [Batch 789/938] [D loss: 1.403374, acc: 7%] [G loss: 1.740734]
[Epoch 0/50] [Batch 790/938] [D loss: 1.441702, acc: 12%] [G loss: 1.611368]
[Epoch 0/50] [Batch 791/938] [D loss: 1.379027, acc: 16%] [G loss: 1.666849]
[Epoch 0/50] [Batch 792/938] [D loss: 1.425197, acc: 14%] [G loss: 1.574586]
[Epoc

[Epoch 0/50] [Batch 891/938] [D loss: 1.423796, acc: 10%] [G loss: 1.781185]
[Epoch 0/50] [Batch 892/938] [D loss: 1.403902, acc: 10%] [G loss: 1.657867]
[Epoch 0/50] [Batch 893/938] [D loss: 1.373930, acc: 13%] [G loss: 1.797174]
[Epoch 0/50] [Batch 894/938] [D loss: 1.422479, acc: 10%] [G loss: 1.788549]
[Epoch 0/50] [Batch 895/938] [D loss: 1.369762, acc: 15%] [G loss: 1.632747]
[Epoch 0/50] [Batch 896/938] [D loss: 1.382539, acc: 14%] [G loss: 1.683119]
[Epoch 0/50] [Batch 897/938] [D loss: 1.392308, acc: 8%] [G loss: 1.933320]
[Epoch 0/50] [Batch 898/938] [D loss: 1.393514, acc: 14%] [G loss: 1.867513]
[Epoch 0/50] [Batch 899/938] [D loss: 1.395696, acc: 10%] [G loss: 1.704756]
[Epoch 0/50] [Batch 900/938] [D loss: 1.367898, acc: 16%] [G loss: 1.595179]
[Epoch 0/50] [Batch 901/938] [D loss: 1.415831, acc: 8%] [G loss: 1.634093]
[Epoch 0/50] [Batch 902/938] [D loss: 1.378266, acc: 11%] [G loss: 1.863014]
[Epoch 0/50] [Batch 903/938] [D loss: 1.390057, acc: 17%] [G loss: 1.576953]
[

  "Please ensure they have the same size.".format(target.size(), input.size()))
  "Please ensure they have the same size.".format(target.size(), input.size()))


ValueError: Target and input must have the same number of elements. target nelement (32) != input nelement (64)