From 4a8c4ac088b6f84a10569ee89db3a938b48922b4 Mon Sep 17 00:00:00 2001
From: ptrblck
Date: Wed, 7 Aug 2019 19:34:12 +0200
Subject: [PATCH] Add DCGAN example (#413)

* initial commit

* add default O1 mode, enable other modes, add README

* add carilli's review suggestions to README
---
 examples/dcgan/README.md   |  42 +++++-
 examples/dcgan/main_amp.py | 274 +++++++++++++++++++++++++++++++++++++
 2 files changed, 315 insertions(+), 1 deletion(-)
 create mode 100644 examples/dcgan/main_amp.py

diff --git a/examples/dcgan/README.md b/examples/dcgan/README.md
index 9e86fd8fc..21baa5371 100644
--- a/examples/dcgan/README.md
+++ b/examples/dcgan/README.md
@@ -1 +1,41 @@
-Under construction...
+# Mixed Precision DCGAN Training in PyTorch
+
+`main_amp.py` is based on [https://github.com/pytorch/examples/tree/master/dcgan](https://github.com/pytorch/examples/tree/master/dcgan).
+It implements Automatic Mixed Precision (Amp) training of the DCGAN example for different datasets. Command-line flags forwarded to `amp.initialize` are used to easily manipulate and switch between various pure and mixed precision "optimization levels" or `opt_level`s. For a detailed explanation of `opt_level`s, see the [updated API guide](https://nvidia.github.io/apex/amp.html).
+
+We introduce these changes to the PyTorch DCGAN example as described in the [Multiple models/optimizers/losses](https://nvidia.github.io/apex/advanced.html#multiple-models-optimizers-losses) section of the documentation:
+```
+# Added after models and optimizers construction
+[netD, netG], [optimizerD, optimizerG] = amp.initialize(
+    [netD, netG], [optimizerD, optimizerG], opt_level=opt.opt_level, num_losses=3)
+...
+# loss.backward() changed to:
+with amp.scale_loss(errD_real, optimizerD, loss_id=0) as errD_real_scaled:
+    errD_real_scaled.backward()
+...
+with amp.scale_loss(errD_fake, optimizerD, loss_id=1) as errD_fake_scaled:
+    errD_fake_scaled.backward()
+...
+with amp.scale_loss(errG, optimizerG, loss_id=2) as errG_scaled:
+    errG_scaled.backward()
+```
+
+Note that we use a separate loss scaler for each computed loss.
+Using a separate loss scaler per loss is [optional, not required](https://nvidia.github.io/apex/advanced.html#optionally-have-amp-use-a-different-loss-scaler-per-loss).
+
+To improve numerical stability, we replaced `nn.Sigmoid() + nn.BCELoss()` with `nn.BCEWithLogitsLoss()`.
+
+With the new Amp API **you never need to explicitly convert your model, or the input data, to half().**
+
+To enable mixed precision training, we introduce the `--opt_level` argument.
+
+"Pure FP32" training:
+```
+$ python main_amp.py --opt_level O0
+```
+Recommended mixed precision training:
+```
+$ python main_amp.py --opt_level O1
+```
+
+Have a look at the original [DCGAN example](https://github.com/pytorch/examples/tree/master/dcgan) for more information about the used arguments.
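Two illustrative sketches, not part of the patch itself, for the points the README makes above.

First, the per-loss scalers are optional. A minimal sketch of the single-scaler alternative, assuming the same `netD`/`netG` and optimizers as in `main_amp.py`: `num_losses` defaults to `1` and `loss_id` defaults to `0`, so all three backward passes would share one loss scaler.

```python
# Minimal sketch (not part of this patch): one shared loss scaler.
# With num_losses left at its default of 1, loss_id can be omitted
# and every scale_loss call reuses the same scaler.
[netD, netG], [optimizerD, optimizerG] = amp.initialize(
    [netD, netG], [optimizerD, optimizerG], opt_level=opt.opt_level)

with amp.scale_loss(errD_real, optimizerD) as scaled_loss:
    scaled_loss.backward()
```

Second, the `nn.BCEWithLogitsLoss` swap: it fuses the sigmoid into the loss via the log-sum-exp trick, avoiding the saturated sigmoid values that make `nn.Sigmoid() + nn.BCELoss()` fragile in reduced precision. A self-contained check that the two formulations agree in FP32:

```python
import torch
import torch.nn as nn

logits = torch.randn(8)                     # raw discriminator outputs
target = torch.randint(0, 2, (8,)).float()  # real/fake labels

loss_fused = nn.BCEWithLogitsLoss()(logits, target)
loss_split = nn.BCELoss()(torch.sigmoid(logits), target)
assert torch.allclose(loss_fused, loss_split)
```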
diff --git a/examples/dcgan/main_amp.py b/examples/dcgan/main_amp.py
new file mode 100644
index 000000000..be1a2894f
--- /dev/null
+++ b/examples/dcgan/main_amp.py
@@ -0,0 +1,274 @@
+from __future__ import print_function
+import argparse
+import os
+import random
+import torch
+import torch.nn as nn
+import torch.nn.parallel
+import torch.backends.cudnn as cudnn
+import torch.optim as optim
+import torch.utils.data
+import torchvision.datasets as dset
+import torchvision.transforms as transforms
+import torchvision.utils as vutils
+
+try:
+    from apex import amp
+except ImportError:
+    raise ImportError("Please install apex from https://www.github.com/nvidia/apex to run this example.")
+
+
+parser = argparse.ArgumentParser()
+parser.add_argument('--dataset', default='cifar10', help='cifar10 | lsun | mnist | imagenet | folder | lfw | fake')
+parser.add_argument('--dataroot', default='./', help='path to dataset')
+parser.add_argument('--workers', type=int, help='number of data loading workers', default=2)
+parser.add_argument('--batchSize', type=int, default=64, help='input batch size')
+parser.add_argument('--imageSize', type=int, default=64, help='the height / width of the input image to network')
+parser.add_argument('--nz', type=int, default=100, help='size of the latent z vector')
+parser.add_argument('--ngf', type=int, default=64)
+parser.add_argument('--ndf', type=int, default=64)
+parser.add_argument('--niter', type=int, default=25, help='number of epochs to train for')
+parser.add_argument('--lr', type=float, default=0.0002, help='learning rate, default=0.0002')
+parser.add_argument('--beta1', type=float, default=0.5, help='beta1 for adam. default=0.5')
+parser.add_argument('--ngpu', type=int, default=1, help='number of GPUs to use')
+parser.add_argument('--netG', default='', help="path to netG (to continue training)")
+parser.add_argument('--netD', default='', help="path to netD (to continue training)")
+parser.add_argument('--outf', default='.', help='folder to output images and model checkpoints')
+parser.add_argument('--manualSeed', type=int, help='manual seed')
+parser.add_argument('--classes', default='bedroom', help='comma separated list of classes for the lsun data set')
+parser.add_argument('--opt_level', default='O1', help='amp opt_level, default="O1"')
+
+opt = parser.parse_args()
+print(opt)
+
+
+try:
+    os.makedirs(opt.outf)
+except OSError:
+    pass
+
+if opt.manualSeed is None:
+    opt.manualSeed = 2809
+print("Random Seed: ", opt.manualSeed)
+random.seed(opt.manualSeed)
+torch.manual_seed(opt.manualSeed)
+
+cudnn.benchmark = True
+
+
+if opt.dataset in ['imagenet', 'folder', 'lfw']:
+    # folder dataset
+    dataset = dset.ImageFolder(root=opt.dataroot,
+                               transform=transforms.Compose([
+                                   transforms.Resize(opt.imageSize),
+                                   transforms.CenterCrop(opt.imageSize),
+                                   transforms.ToTensor(),
+                                   transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
+                               ]))
+    nc = 3
+elif opt.dataset == 'lsun':
+    classes = [c + '_train' for c in opt.classes.split(',')]
+    dataset = dset.LSUN(root=opt.dataroot, classes=classes,
+                        transform=transforms.Compose([
+                            transforms.Resize(opt.imageSize),
+                            transforms.CenterCrop(opt.imageSize),
+                            transforms.ToTensor(),
+                            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
+                        ]))
+    nc = 3
+elif opt.dataset == 'cifar10':
+    dataset = dset.CIFAR10(root=opt.dataroot, download=True,
+                           transform=transforms.Compose([
+                               transforms.Resize(opt.imageSize),
+                               transforms.ToTensor(),
+                               transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
+                           ]))
+    nc = 3
+
+elif opt.dataset == 'mnist':
+    dataset = dset.MNIST(root=opt.dataroot, download=True,
+                         transform=transforms.Compose([
+                             transforms.Resize(opt.imageSize),
+                             transforms.ToTensor(),
+                             transforms.Normalize((0.5,), (0.5,)),
+                         ]))
+    nc = 1
+
+elif opt.dataset == 'fake':
+    dataset = dset.FakeData(image_size=(3, opt.imageSize, opt.imageSize),
+                            transform=transforms.ToTensor())
+    nc = 3
+
+assert dataset
+dataloader = torch.utils.data.DataLoader(dataset, batch_size=opt.batchSize,
+                                         shuffle=True, num_workers=int(opt.workers))
+
+device = torch.device("cuda:0")
+ngpu = int(opt.ngpu)
+nz = int(opt.nz)
+ngf = int(opt.ngf)
+ndf = int(opt.ndf)
+
+
+# custom weights initialization called on netG and netD
+def weights_init(m):
+    classname = m.__class__.__name__
+    if classname.find('Conv') != -1:
+        m.weight.data.normal_(0.0, 0.02)
+    elif classname.find('BatchNorm') != -1:
+        m.weight.data.normal_(1.0, 0.02)
+        m.bias.data.fill_(0)
+
+
+class Generator(nn.Module):
+    def __init__(self, ngpu):
+        super(Generator, self).__init__()
+        self.ngpu = ngpu
+        self.main = nn.Sequential(
+            # input is Z, going into a convolution
+            nn.ConvTranspose2d(nz, ngf * 8, 4, 1, 0, bias=False),
+            nn.BatchNorm2d(ngf * 8),
+            nn.ReLU(True),
+            # state size. (ngf*8) x 4 x 4
+            nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False),
+            nn.BatchNorm2d(ngf * 4),
+            nn.ReLU(True),
+            # state size. (ngf*4) x 8 x 8
+            nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1, bias=False),
+            nn.BatchNorm2d(ngf * 2),
+            nn.ReLU(True),
+            # state size. (ngf*2) x 16 x 16
+            nn.ConvTranspose2d(ngf * 2, ngf, 4, 2, 1, bias=False),
+            nn.BatchNorm2d(ngf),
+            nn.ReLU(True),
+            # state size. (ngf) x 32 x 32
+            nn.ConvTranspose2d(ngf, nc, 4, 2, 1, bias=False),
+            nn.Tanh()
+            # state size. (nc) x 64 x 64
+        )
+
+    def forward(self, input):
+        if input.is_cuda and self.ngpu > 1:
+            output = nn.parallel.data_parallel(self.main, input, range(self.ngpu))
+        else:
+            output = self.main(input)
+        return output
+
+
+netG = Generator(ngpu).to(device)
+netG.apply(weights_init)
+if opt.netG != '':
+    netG.load_state_dict(torch.load(opt.netG))
+print(netG)
+
+
+class Discriminator(nn.Module):
+    def __init__(self, ngpu):
+        super(Discriminator, self).__init__()
+        self.ngpu = ngpu
+        self.main = nn.Sequential(
+            # input is (nc) x 64 x 64
+            nn.Conv2d(nc, ndf, 4, 2, 1, bias=False),
+            nn.LeakyReLU(0.2, inplace=True),
+            # state size. (ndf) x 32 x 32
+            nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),
+            nn.BatchNorm2d(ndf * 2),
+            nn.LeakyReLU(0.2, inplace=True),
+            # state size. (ndf*2) x 16 x 16
+            nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False),
+            nn.BatchNorm2d(ndf * 4),
+            nn.LeakyReLU(0.2, inplace=True),
+            # state size. (ndf*4) x 8 x 8
+            nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False),
+            nn.BatchNorm2d(ndf * 8),
+            nn.LeakyReLU(0.2, inplace=True),
+            # state size. (ndf*8) x 4 x 4
+            nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=False),
+        )
+
+    def forward(self, input):
+        if input.is_cuda and self.ngpu > 1:
+            output = nn.parallel.data_parallel(self.main, input, range(self.ngpu))
+        else:
+            output = self.main(input)
+
+        return output.view(-1, 1).squeeze(1)
+
+
+netD = Discriminator(ngpu).to(device)
+netD.apply(weights_init)
+if opt.netD != '':
+    netD.load_state_dict(torch.load(opt.netD))
+print(netD)
+
+criterion = nn.BCEWithLogitsLoss()
+
+fixed_noise = torch.randn(opt.batchSize, nz, 1, 1, device=device)
+real_label = 1
+fake_label = 0
+
+# setup optimizer
+optimizerD = optim.Adam(netD.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))
+optimizerG = optim.Adam(netG.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))
+
+[netD, netG], [optimizerD, optimizerG] = amp.initialize(
+    [netD, netG], [optimizerD, optimizerG], opt_level=opt.opt_level, num_losses=3)
+
+for epoch in range(opt.niter):
+    for i, data in enumerate(dataloader, 0):
+        ############################
+        # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
+        ###########################
+        # train with real
+        netD.zero_grad()
+        real_cpu = data[0].to(device)
+        batch_size = real_cpu.size(0)
+        label = torch.full((batch_size,), real_label, device=device)
+
+        output = netD(real_cpu)
+        errD_real = criterion(output, label)
+        with amp.scale_loss(errD_real, optimizerD, loss_id=0) as errD_real_scaled:
+            errD_real_scaled.backward()
+        D_x = output.mean().item()
+
+        # train with fake
+        noise = torch.randn(batch_size, nz, 1, 1, device=device)
+        fake = netG(noise)
+        label.fill_(fake_label)
+        output = netD(fake.detach())
+        errD_fake = criterion(output, label)
+        with amp.scale_loss(errD_fake, optimizerD, loss_id=1) as errD_fake_scaled:
+            errD_fake_scaled.backward()
+        D_G_z1 = output.mean().item()
+        errD = errD_real + errD_fake
+        optimizerD.step()
+
+        ############################
+        # (2) Update G network: maximize log(D(G(z)))
+        ###########################
+        netG.zero_grad()
+        label.fill_(real_label)  # fake labels are real for generator cost
+        output = netD(fake)
+        errG = criterion(output, label)
+        with amp.scale_loss(errG, optimizerG, loss_id=2) as errG_scaled:
+            errG_scaled.backward()
+        D_G_z2 = output.mean().item()
+        optimizerG.step()
+
+        print('[%d/%d][%d/%d] Loss_D: %.4f Loss_G: %.4f D(x): %.4f D(G(z)): %.4f / %.4f'
+              % (epoch, opt.niter, i, len(dataloader),
+                 errD.item(), errG.item(), D_x, D_G_z1, D_G_z2))
+        if i % 100 == 0:
+            vutils.save_image(real_cpu,
+                              '%s/real_samples.png' % opt.outf,
+                              normalize=True)
+            fake = netG(fixed_noise)
+            vutils.save_image(fake.detach(),
+                              '%s/amp_fake_samples_epoch_%03d.png' % (opt.outf, epoch),
+                              normalize=True)
+
+    # do checkpointing
+    torch.save(netG.state_dict(), '%s/netG_epoch_%d.pth' % (opt.outf, epoch))
+    torch.save(netD.state_dict(), '%s/netD_epoch_%d.pth' % (opt.outf, epoch))
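The patch checkpoints only the model weights each epoch. A minimal sketch, not part of the patch, of how one could also persist optimizer and loss-scaler state using apex's `amp.state_dict()`/`amp.load_state_dict()`; the bundled checkpoint layout and file name are assumptions made here for illustration:

```python
# Minimal resume sketch (assumes netD, netG, the optimizers, and amp are
# set up exactly as in main_amp.py, i.e. amp.initialize has been called).
checkpoint = {
    'netG': netG.state_dict(),
    'netD': netD.state_dict(),
    'optimizerG': optimizerG.state_dict(),
    'optimizerD': optimizerD.state_dict(),
    'amp': amp.state_dict(),  # loss scaler state for each loss_id
}
torch.save(checkpoint, '%s/amp_checkpoint_epoch_%d.pt' % (opt.outf, epoch))

# Later, after rebuilding the models and optimizers and calling
# amp.initialize with the same opt_level:
checkpoint = torch.load('%s/amp_checkpoint_epoch_%d.pt' % (opt.outf, epoch))
netG.load_state_dict(checkpoint['netG'])
netD.load_state_dict(checkpoint['netD'])
optimizerG.load_state_dict(checkpoint['optimizerG'])
optimizerD.load_state_dict(checkpoint['optimizerD'])
amp.load_state_dict(checkpoint['amp'])
```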