In [0]:
#!/usr/bin/env python3

import argparse
import os
import random
import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torch.utils.data
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torchvision.utils as vutils
from dcgan import DCGAN 
from sagan import SAGAN 

# Training options, declared as (flag, add_argument keyword arguments) pairs.
# They are registered in this order, which is also the order in which the
# options appear in the parsed namespace.
_OPTION_SPECS = (
    ('--dataroot', dict(default="data", help='path to dataset')),
    ('--workers', dict(type=int, help='number of data loading workers', default=2)),
    ('--batchSize', dict(type=int, default=64, help='input batch size')),
    ('--imageSize', dict(type=int, default=64, help='the height / width of the input image to network')),
    ('--gan', dict(default="sa", help='dc | sa: use either Deep Convolutional GAN or Self Attention GAN')),
    ('--nz', dict(type=int, default=100, help='size of the latent z vector')),
    ('--ngf', dict(type=int, default=64)),
    ('--ndf', dict(type=int, default=64)),
    ('--niter', dict(type=int, default=1000, help='number of epochs to train for')),
    ('--lr', dict(type=float, default=0.0002, help='learning rate, default=0.0002')),
    ('--beta1', dict(type=float, default=0.5, help='beta1 for adam. default=0.5')),
    ('--cuda', dict(action='store_true', help='enables cuda')),
    ('--ngpu', dict(type=int, default=1, help='number of GPUs to use')),
    ('--outf', dict(default='output', help='folder to output images')),
    ('--checkpoint', dict(default='checkpoint', help='folder to output model checkpoints')),
    ('--manualSeed', dict(type=int, help='manual seed')),
)

parser = argparse.ArgumentParser()
for _flag, _kwargs in _OPTION_SPECS:
    parser.add_argument(_flag, **_kwargs)

# An explicit empty argument list is parsed, so every option takes its
# default value and sys.argv is not consulted.
opt = parser.parse_args([])
print(opt)

cudnn.benchmark = True

# The parsed --cuda value is replaced by whether CUDA is actually available.
opt.cuda = torch.cuda.is_available()
if opt.cuda:
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

# Integer copies of the sizing options, exposed as module-level names.
ngpu, nz, ngf, ndf = (
    int(getattr(opt, option)) for option in ("ngpu", "nz", "ngf", "ndf")
)
# Number of label classes; used as the one-hot width when sampling labels.
num_classes = 10
# Passed to the GAN constructors as ``nc`` (presumably image channels).
nc = 3


def load_dataset():
    """Return the CIFAR-10 training split with the preprocessing pipeline applied.

    The dataset is stored under ``opt.dataroot`` and downloaded there if it
    is not already present. Each image is resized according to
    ``opt.imageSize``, converted to a tensor, and normalised per channel
    with mean 0.5 and standard deviation 0.5.

    Returns:
        The ``torchvision.datasets.CIFAR10`` training dataset.
    """
    transform = transforms.Compose([
        transforms.Resize(opt.imageSize),
        transforms.ToTensor(),
        # (x - 0.5) / 0.5 on each of the three channels.
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])
    return dset.CIFAR10(root=opt.dataroot, download=True, train=True, transform=transform)


def to_dataloader(trainset):
    """Wrap ``trainset`` in a shuffling ``DataLoader`` configured from ``opt``.

    Batch size comes from ``opt.batchSize`` and the number of worker
    processes from ``opt.workers``.
    """
    loader_options = {
        "batch_size": opt.batchSize,
        "shuffle": True,
        "num_workers": int(opt.workers),
    }
    return torch.utils.data.DataLoader(trainset, **loader_options)

def _get_one_hot_vector(class_indices, num_classes, batch_size):
    y_onehot = torch.FloatTensor(batch_size, num_classes)
    y_onehot.zero_()
    return y_onehot.scatter_(1, class_indices.unsqueeze(1), 1)

def train(gan, train_dataloader):
    """Train ``gan`` for ``opt.niter`` epochs, saving samples and checkpoints.

    Every 100 batches the generator output for a fixed noise batch is
    written as an image grid to ``opt.outf`` (the file for the current epoch
    is overwritten each time). After every epoch the generator and
    discriminator state dicts are saved under ``<opt.checkpoint>/netG`` and
    ``<opt.checkpoint>/netD``.

    Args:
        gan: Object exposing ``netG``, ``netD`` and
            ``train_on_batch(data, device)``.
        train_dataloader: Sized iterable yielding training batches.
    """
    # torch.save does not create missing directories, so make sure the
    # checkpoint folders exist before the first epoch finishes.
    netG_dir = os.path.join(opt.checkpoint, 'netG')
    netD_dir = os.path.join(opt.checkpoint, 'netD')
    os.makedirs(netG_dir, exist_ok=True)
    os.makedirs(netD_dir, exist_ok=True)

    # The self-attention generator is called with an extra one-hot label input.
    use_labels = opt.gan == "sa"

    # Fixed generator inputs, reused for every sample grid so that images
    # from different points in training are directly comparable.
    fixed_noise = torch.randn(opt.batchSize, nz, 1, 1, device=device)
    class_one_hot = None
    if use_labels:
        labels = torch.zeros(opt.batchSize).long().random_(0, num_classes)
        # Reshape (batch, classes) -> (batch, classes, 1, 1) to match the noise layout.
        class_one_hot = _get_one_hot_vector(labels, num_classes, opt.batchSize)\
            .unsqueeze(2).unsqueeze(3).to(device, non_blocking=True)

    netG = gan.netG
    netD = gan.netD
    num_batches = len(train_dataloader)
    for epoch in range(opt.niter):
        for i, data in enumerate(train_dataloader):
            # No trailing newline here.
            # NOTE(review): presumably train_on_batch prints the rest of the line - confirm.
            print('[%d/%d][%d/%d] '
                  % (epoch, opt.niter, i, num_batches), end="")

            gan.train_on_batch(data, device)

            if i % 100 == 0:
                # Sampling is for inspection only; skip autograd bookkeeping.
                with torch.no_grad():
                    if use_labels:
                        fake = netG(fixed_noise, class_one_hot)
                    else:
                        fake = netG(fixed_noise)
                vutils.save_image(fake,
                        '%s/fake_samples_epoch_%03d.png' % (opt.outf, epoch),
                        normalize=True)

        # save checkpoint
        torch.save(netG.state_dict(), os.path.join(netG_dir, 'netG_epoch_%d.pth' % epoch))
        torch.save(netD.state_dict(), os.path.join(netD_dir, 'netD_epoch_%d.pth' % epoch))

###### MAIN

# Make sure the sample-image folder exists. An already-existing folder is
# fine; any other failure (e.g. permissions) is reported instead of being
# silently ignored.
os.makedirs(opt.outf, exist_ok=True)

# Pick a seed when none was given and print it so the run can be repeated.
if opt.manualSeed is None:
    opt.manualSeed = random.randint(1, 10000)
print("Random Seed: ", opt.manualSeed)
random.seed(opt.manualSeed)
torch.manual_seed(opt.manualSeed)

if torch.cuda.is_available() and not opt.cuda:
    print("WARNING: You have a CUDA device, so you should probably run with --cuda")

# Prepare dataset
train_dataloader = to_dataloader(load_dataset())

# Build the model selected by opt.gan ("sa" -> SAGAN, anything else -> DCGAN).
if opt.gan == "sa":
    print("Loading Self Attention GAN")
    gan = SAGAN(nc=nc, nz=nz, ngf=ngf, ndf=ndf, ngpu=ngpu)
else:
    print("Loading Deep Convolutional GAN")
    gan = DCGAN(nc=nc, nz=nz, ngf=ngf, ndf=ndf, ngpu=ngpu)

print(gan.netG)
print(gan.netD)

# Move both networks to the device that train() uses for its inputs.
if opt.cuda:
    gan.netD.to(device)
    gan.netG.to(device)

train(gan, train_dataloader)



Namespace(batchSize=64, beta1=0.5, checkpoint='checkpoint', cuda=False, dataroot='data', gan='sa', imageSize=64, lr=0.0002, manualSeed=None, ndf=64, netD='', netG='', ngf=64, ngpu=1, niter=1000, nz=100, outf='output', workers=2)
Random Seed:  5263
Files already downloaded and verified
Loading Self Attention GAN
_Generator(
  (deconv1): ConvTranspose2d(110, 512, kernel_size=(4, 4), stride=(1, 1), bias=False)
  (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (deconv2): ConvTranspose2d(512, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
  (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (deconv3): ConvTranspose2d(256, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
  (bn3): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (sa3): _SelfAttention(
    (query): Conv2d(128, 16, kernel_size=(1, 1), stride=(1, 1

KeyboardInterrupt: ignored