In [1]:
from __future__ import print_function
import argparse
import os
import random
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torch.utils.data
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torchvision.utils as vutils
from torch.autograd import Variable
import torch.nn.functional as F

In [2]:
parser = argparse.ArgumentParser()
parser.add_argument('--dataset', required=True, help='cifar10 | lsun | imagenet | folder | lfw ')
parser.add_argument('--dataroot', required=True, help='path to dataset')
parser.add_argument('--workers', type=int, help='number of data loading workers', default=2)
parser.add_argument('--batchSize', type=int, default=16, help='input batch size')
parser.add_argument('--imageSize', type=int, default=64, help='the height / width of the input image to network')
parser.add_argument('--nz', type=int, default=64, help='size of the latent z vector')
parser.add_argument('--ngf', type=int, default=64)
parser.add_argument('--ndf', type=int, default=64)
parser.add_argument('--niter', type=int, default=25, help='number of epochs to train for')
parser.add_argument('--lr', type=float, default=0.0001, help='learning rate, default=0.0002')
parser.add_argument('--beta1', type=float, default=0.5, help='beta1 for adam. default=0.5')
parser.add_argument('--cuda', action='store_true', help='enables cuda')
parser.add_argument('--netD', default='', help="path to netD (to continue training)")
parser.add_argument('--netG', default='', help="path to netG (to continue training)")
parser.add_argument('--outf', default='.', help='folder to output images and model checkpoints')
parser.add_argument('--manualSeed', type=int, help='manual seed')

_StoreAction(option_strings=['--manualSeed'], dest='manualSeed', nargs=None, const=None, default=None, type=<type 'int'>, choices=None, help='manual seed', metavar=None)

In [3]:
opt = parser.parse_args(['--dataset', 'folder', '--dataroot', '/home/ubuntu/Python/BEGAN_Pytorch/data/work/', '--cuda', '--niter', '25'])
print(opt)

Namespace(batchSize=16, beta1=0.5, cuda=True, dataroot='/home/ubuntu/Python/BEGAN_Pytorch/data/work/', dataset='folder', imageSize=64, lr=0.0001, manualSeed=None, ndf=64, netD='', netG='', ngf=64, niter=25, nz=64, outf='.', workers=2)


In [4]:
try:
    os.makedirs(opt.outf)
except OSError:
    pass

In [5]:
if opt.manualSeed is None:
    opt.manualSeed = random.randint(1, 10000)
print("Random Seed: ", opt.manualSeed)
random.seed(opt.manualSeed)
torch.manual_seed(opt.manualSeed)
if opt.cuda:
    torch.cuda.manual_seed_all(opt.manualSeed)

Random Seed:  4848


In [6]:
cudnn.benchmark = True
if torch.cuda.is_available() and not opt.cuda:
    print("WARNING: You have a CUDA device, so you should probably run with --cuda")

In [7]:
torch.cuda.set_device(0)

In [8]:
dataset = dset.ImageFolder(root=opt.dataroot,
                               transform=transforms.Compose([
                                   transforms.Scale(opt.imageSize),
                                   transforms.CenterCrop(opt.imageSize),
                                   transforms.ToTensor(),
                                   transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                               ]))
assert dataset

In [9]:
dataloader = torch.utils.data.DataLoader(dataset, batch_size=opt.batchSize,
                                         shuffle=True, num_workers=int(opt.workers))

In [10]:
ngpu = 1
nz = int(opt.nz)
ngf = int(opt.ngf)
ndf = int(opt.ndf)
nc = 3
gamma = 0.5
k = 0
lr_decay = 0.9

In [11]:
# custom weights initialization called on netG and netD
def G_weights_init(m):
    classname = m.__class__.__name__

    if classname.find('Conv') != -1:
        m.weight.data.normal_(0.0, 0.002)

    elif classname.find('BatchNorm') != -1:
        m.weight.data.normal_(1.0, 0.02)
        m.bias.data.fill_(0)

def D_weights_init(m):
    classname = m.__class__.__name__

    if classname.find('Conv') != -1:
        m.weight.data.normal_(0.0, 0.02)

    elif classname.find('BatchNorm') != -1:
        m.weight.data.normal_(1.0, 0.02)
        m.bias.data.fill_(0)

In [12]:
class G_CNN(nn.Module):
    def __init__(self, embedding_size, kernel_num):
        super(G_CNN, self).__init__()
        self.kernel_num = kernel_num
        FC_start_size = 8 * 8 * self.kernel_num
        self.start = nn.Linear(embedding_size, FC_start_size)
        #all layers
        self.main_arch = nn.Sequential(
            #repeated units
            nn.Conv2d(in_channels=kernel_num, out_channels=kernel_num, kernel_size=3, padding=1),
            nn.ELU(inplace=True),
            nn.Conv2d(in_channels=kernel_num, out_channels=kernel_num, kernel_size=3, padding=1),
            nn.ELU(inplace=True),
            nn.UpsamplingNearest2d(scale_factor=2),
            #repeated units
            nn.Conv2d(in_channels=kernel_num, out_channels=kernel_num, kernel_size=3, padding=1),
            nn.ELU(inplace=True),
            nn.Conv2d(in_channels=kernel_num, out_channels=kernel_num, kernel_size=3, padding=1),
            nn.ELU(inplace=True),
            nn.UpsamplingNearest2d(scale_factor=2),
            #repeated units
            nn.Conv2d(in_channels=kernel_num, out_channels=kernel_num, kernel_size=3, padding=1),
            nn.ELU(inplace=True),
            nn.Conv2d(in_channels=kernel_num, out_channels=kernel_num, kernel_size=3, padding=1),
            nn.ELU(inplace=True),
            nn.UpsamplingNearest2d(scale_factor=2),
            #last layer
            #nn.Conv2d(in_channels=kernel_num, out_channels=3, kernel_size=3, padding=1)
            nn.Conv2d(in_channels=kernel_num, out_channels=kernel_num, kernel_size=3, padding=1),
            nn.ELU(inplace=True),
            nn.Conv2d(in_channels=kernel_num, out_channels=kernel_num, kernel_size=3, padding=1),
            nn.ELU(inplace=True),
            nn.Conv2d(in_channels=kernel_num, out_channels=3, kernel_size=3, padding=1)
        )

    def forward(self, in_data):
        in_data = self.start(in_data.view(-1, 64))
        in_data = in_data.view(-1, self.kernel_num, 8, 8)
        output = self.main_arch(in_data)
        return output

In [13]:
netG = G_CNN(embedding_size=nz, kernel_num=64)
netG.apply(G_weights_init)
if opt.netG != '':
    netG.load_state_dict(torch.load(opt.netG))

In [14]:
#Discriminator
class D_CNN(nn.Module):
    def __init__(self, embedding_size, base_kernel_num):
        super(D_CNN, self).__init__()
        self.base_kernel_num = base_kernel_num
        
        #encoder w/o flatten and FC
        self.encoder_wo_FC = nn.Sequential(
            nn.Conv2d(in_channels=3, out_channels=base_kernel_num, kernel_size=3, padding=1),
            nn.ELU(inplace=True),
            
            nn.Conv2d(in_channels=base_kernel_num, out_channels=base_kernel_num, kernel_size=3, padding=1),
            nn.ELU(inplace=True),
            nn.Conv2d(in_channels=base_kernel_num, out_channels=2*base_kernel_num, kernel_size=3, padding=1),
            nn.ELU(inplace=True),
            #downsampling
            nn.Conv2d(in_channels=2*base_kernel_num, out_channels=2*base_kernel_num, kernel_size=3, padding=1, stride=2),
            nn.ELU(inplace=True),
            
            nn.Conv2d(in_channels=2*base_kernel_num, out_channels=2*base_kernel_num, kernel_size=3, padding=1),
            nn.ELU(inplace=True),
            nn.Conv2d(in_channels=2*base_kernel_num, out_channels=2*base_kernel_num, kernel_size=3, padding=1),
            nn.ELU(inplace=True),
            #downsampling
            nn.Conv2d(in_channels=2*base_kernel_num, out_channels=3*base_kernel_num, kernel_size=3, padding=1, stride=2),
            nn.ELU(inplace=True),
            
            nn.Conv2d(in_channels=3*base_kernel_num, out_channels=3*base_kernel_num, kernel_size=3, padding=1),
            nn.ELU(inplace=True),
            nn.Conv2d(in_channels=3*base_kernel_num, out_channels=3*base_kernel_num, kernel_size=3, padding=1),
            nn.ELU(inplace=True),
            #downsampling
            nn.Conv2d(in_channels=3*base_kernel_num, out_channels=4*base_kernel_num, kernel_size=3, padding=1, stride=2),
            nn.ELU(inplace=True),
            
            nn.Conv2d(in_channels=4*base_kernel_num, out_channels=4*base_kernel_num, kernel_size=3, padding=1),
            nn.ELU(inplace=True),
            nn.Conv2d(in_channels=4*base_kernel_num, out_channels=4*base_kernel_num, kernel_size=3, padding=1),
            nn.ELU(inplace=True)
        )
        FC_mid1_size = 8 * 8 * 4 * self.base_kernel_num
        FC_mid2_size = 8 * 8 * self.base_kernel_num
        self.FC1 = nn.Linear(FC_mid1_size, embedding_size)
        self.FC2 = nn.Linear(embedding_size, FC_mid2_size)
        self.decoder_wo_FC = nn.Sequential(
            #repeated units
            nn.Conv2d(in_channels=base_kernel_num, out_channels=base_kernel_num, kernel_size=3, padding=1),
            nn.ELU(inplace=True),
            nn.Conv2d(in_channels=base_kernel_num, out_channels=base_kernel_num, kernel_size=3, padding=1),
            nn.ELU(inplace=True),
            nn.UpsamplingNearest2d(scale_factor=2),
            #repeated units
            nn.Conv2d(in_channels=base_kernel_num, out_channels=base_kernel_num, kernel_size=3, padding=1),
            nn.ELU(inplace=True),
            nn.Conv2d(in_channels=base_kernel_num, out_channels=base_kernel_num, kernel_size=3, padding=1),
            nn.ELU(inplace=True),
            nn.UpsamplingNearest2d(scale_factor=2),
            #repeated units
            nn.Conv2d(in_channels=base_kernel_num, out_channels=base_kernel_num, kernel_size=3, padding=1),
            nn.ELU(inplace=True),
            nn.Conv2d(in_channels=base_kernel_num, out_channels=base_kernel_num, kernel_size=3, padding=1),
            nn.ELU(inplace=True),
            nn.UpsamplingNearest2d(scale_factor=2),
            #last layer
            #nn.Conv2d(in_channels=base_kernel_num, out_channels=3, kernel_size=3, padding=1)
            nn.Conv2d(in_channels=base_kernel_num, out_channels=base_kernel_num, kernel_size=3, padding=1),
            nn.ELU(inplace=True),
            nn.Conv2d(in_channels=base_kernel_num, out_channels=base_kernel_num, kernel_size=3, padding=1),
            nn.ELU(inplace=True),
            nn.Conv2d(in_channels=base_kernel_num, out_channels=3, kernel_size=3, padding=1)
        )

    def forward(self, in_data):
        in_data = self.encoder_wo_FC(in_data)
        
        in_data = in_data.view(-1, 4*self.base_kernel_num * 8 * 8);
        in_data = self.FC1(in_data)
        in_data = self.FC2(in_data)
        in_data = in_data.view(-1, self.base_kernel_num, 8, 8)
        
        output = self.decoder_wo_FC(in_data)
        return output

In [15]:
netD = D_CNN(embedding_size=nz, base_kernel_num=64)
netD.apply(D_weights_init)
if opt.netD != '':
    netD.load_state_dict(torch.load(opt.netD))

In [None]:
criterion_L1 = nn.L1Loss()

input = torch.FloatTensor(opt.batchSize, 3, opt.imageSize, opt.imageSize)
noise = torch.FloatTensor(opt.batchSize, nz)
fixed_noise = torch.FloatTensor(opt.batchSize, nz, 1, 1).normal_(0, 1)

if opt.cuda:
    netD.cuda()
    netG.cuda()
    criterion_L1.cuda()
    input = input.cuda()
    noise, fixed_noise = noise.cuda(), fixed_noise.cuda()

input = Variable(input)
noise = Variable(noise)
fixed_noise = Variable(fixed_noise)

# setup optimizer
optimizerD = optim.Adam(netD.parameters(), lr=opt.lr)
optimizerG = optim.Adam(netG.parameters(), lr=opt.lr)

In [None]:
data_loss_D = []
data_loss_G = []
data_loss_k = []
for epoch in range(opt.niter):
    for i, data in enumerate(dataloader, 0):

        ############################
        # (1) Update D network
        ###########################
        netD.zero_grad()

        # prepare real
        real_cpu, _ = data
        batch_size = real_cpu.size(0)
        input.data.resize_(real_cpu.size()).copy_(real_cpu)

        # train with real
        output = netD(input)
        errD_real = criterion_L1(output, input)  # score on real
        errD_real.backward()  # backward on score on real
        L_x = errD_real.data[0]  # score fore supervision
        
        # generate fake
        #noise.data.resize_(batch_size, nz, 1, 1)
        noise.data.normal_(0, 1)
        fake = netG(noise)
        
        # train with fake
        output = netD(fake.detach())
        errD_fake = criterion_L1(output, fake.detach())  # score on fake
        errD_fake_use = - k * errD_fake
        errD_fake_use.backward()  # backward on score on fake
        optimizerD.step()
        
        #D_G_z1 = errD_fake.data.mean()  # score fore supervision <- generated when calc D loss
        errD = errD_real + errD_fake_use  # score fore supervision

        

        ############################
        # (2) Update G network
        ############################

        netG.zero_grad()

        # generate fake
        noise.data.resize_(batch_size, nz, 1, 1)
        fake = netG(noise)
        
        # NOT reuse generated fake samples
        output = netD(fake)
        errG = (output - fake).abs().mean()  # L1
        errG.backward()
        
        L_G = errG.data[0] # score fore supervision <- generated when calc G loss

        optimizerG.step()
        
        # K STEP
        #left -> real
        output = netD(input)
        errD_real_k = criterion_L1(output, input)  # score on real
        #right -> fake
        fake = netG(noise)
        output = netD(fake.detach())
        errD_fake_k = criterion_L1(output, fake.detach())  # score on fake
        
        #Convergence Measure
        cm = errD_real_k.data[0] + abs(gamma * errD_real_k.data[0] - L_G)
        
        k += 0.001 * (gamma * errD_real_k.data[0] - errD_fake_k.data[0])
        k = max(min(k, 1), 0)
        ############################
        # (3) Report & 100 Batch checkpoint
        ############################
        data_loss_G.append(errG.data[0])
        data_loss_D.append(errD.data[0])
        data_loss_k.append(k)
        if i % 2000 == 0:
            optimizerD = optim.Adam(netD.parameters(), lr=opt.lr*lr_decay)
            optimizerG = optim.Adam(netG.parameters(), lr=opt.lr*lr_decay)
            opt.lr = opt.lr * lr_decay
        if i % 10 == 0:
            print('[%d/%d][%d/%d] Loss_D: %.4f Loss_G: %.4f Convergence: %.4f k= %.6f'
                    % (epoch, opt.niter, i, len(dataloader),
                     errD.data[0], errG.data[0], cm, k))
        
        if i % 100 == 0:
            fake = netG(fixed_noise)
            vutils.save_image(fake.data,
                              '%s/%d_fake_samples_epoch_%03d.png' % (opt.outf, i, epoch))

    # do checkpointing
    torch.save(netG.state_dict(), '%s/netG_epoch_%d.pth' % (opt.outf, epoch))
    torch.save(netD.state_dict(), '%s/netD_epoch_%d.pth' % (opt.outf, epoch))

[0/25][0/12663] Loss_D: 0.5029 Loss_G: 0.0282 Convergence: 0.7249 k= 0.000224
[0/25][10/12663] Loss_D: 0.5050 Loss_G: 0.0686 Convergence: 0.6862 k= 0.002328
[0/25][20/12663] Loss_D: 0.4624 Loss_G: 0.1094 Convergence: 0.5880 k= 0.003515
[0/25][30/12663] Loss_D: 0.5074 Loss_G: 0.1382 Convergence: 0.6202 k= 0.004865
[0/25][40/12663] Loss_D: 0.4338 Loss_G: 0.1859 Convergence: 0.4643 k= 0.005938
[0/25][50/12663] Loss_D: 0.4514 Loss_G: 0.2008 Convergence: 0.4724 k= 0.006462
[0/25][60/12663] Loss_D: 0.3892 Loss_G: 0.2266 Convergence: 0.4215 k= 0.006372
[0/25][70/12663] Loss_D: 0.4043 Loss_G: 0.3506 Convergence: 0.5511 k= 0.005957
[0/25][80/12663] Loss_D: 0.3280 Loss_G: 0.1493 Convergence: 0.3365 k= 0.005277
[0/25][90/12663] Loss_D: 0.3017 Loss_G: 0.2128 Convergence: 0.3591 k= 0.005011
[0/25][100/12663] Loss_D: 0.3108 Loss_G: 0.2065 Convergence: 0.3590 k= 0.004695
[0/25][110/12663] Loss_D: 0.3145 Loss_G: 0.1187 Convergence: 0.3431 k= 0.004744
[0/25][120/12663] Loss_D: 0.2767 Loss_G: 0.1214 Con

[0/25][1030/12663] Loss_D: 0.1697 Loss_G: 0.1103 Convergence: 0.1970 k= 0.003198
[0/25][1040/12663] Loss_D: 0.2085 Loss_G: 0.0915 Convergence: 0.2183 k= 0.003130
[0/25][1050/12663] Loss_D: 0.1926 Loss_G: 0.0827 Convergence: 0.2046 k= 0.003143
[0/25][1060/12663] Loss_D: 0.1788 Loss_G: 0.0979 Convergence: 0.1867 k= 0.003200
[0/25][1070/12663] Loss_D: 0.1799 Loss_G: 0.1001 Convergence: 0.1893 k= 0.003245
[0/25][1080/12663] Loss_D: 0.1687 Loss_G: 0.0983 Convergence: 0.1829 k= 0.003368
[0/25][1090/12663] Loss_D: 0.1755 Loss_G: 0.0701 Convergence: 0.1855 k= 0.003542
[0/25][1100/12663] Loss_D: 0.1773 Loss_G: 0.0942 Convergence: 0.1831 k= 0.003571
[0/25][1110/12663] Loss_D: 0.1768 Loss_G: 0.0846 Convergence: 0.1754 k= 0.003593
[0/25][1120/12663] Loss_D: 0.1882 Loss_G: 0.0778 Convergence: 0.2028 k= 0.003711
[0/25][1130/12663] Loss_D: 0.1692 Loss_G: 0.0762 Convergence: 0.1762 k= 0.003794
[0/25][1140/12663] Loss_D: 0.1631 Loss_G: 0.0780 Convergence: 0.1634 k= 0.003790
[0/25][1150/12663] Loss_D: 0

[0/25][2050/12663] Loss_D: 0.1677 Loss_G: 0.0591 Convergence: 0.1924 k= 0.012937
[0/25][2060/12663] Loss_D: 0.1816 Loss_G: 0.0537 Convergence: 0.2178 k= 0.013205
[0/25][2070/12663] Loss_D: 0.1503 Loss_G: 0.0638 Convergence: 0.1573 k= 0.013454
[0/25][2080/12663] Loss_D: 0.1672 Loss_G: 0.0550 Convergence: 0.1976 k= 0.013689
[0/25][2090/12663] Loss_D: 0.1783 Loss_G: 0.0613 Convergence: 0.2062 k= 0.013905
[0/25][2100/12663] Loss_D: 0.1681 Loss_G: 0.0600 Convergence: 0.1960 k= 0.014123
[0/25][2110/12663] Loss_D: 0.1567 Loss_G: 0.0612 Convergence: 0.1733 k= 0.014311
[0/25][2120/12663] Loss_D: 0.1694 Loss_G: 0.0634 Convergence: 0.1909 k= 0.014524
[0/25][2130/12663] Loss_D: 0.1649 Loss_G: 0.0663 Convergence: 0.1815 k= 0.014708
[0/25][2140/12663] Loss_D: 0.1861 Loss_G: 0.0620 Convergence: 0.2181 k= 0.014895
[0/25][2150/12663] Loss_D: 0.1597 Loss_G: 0.0592 Convergence: 0.1832 k= 0.015073
[0/25][2160/12663] Loss_D: 0.1680 Loss_G: 0.0791 Convergence: 0.1713 k= 0.015223
[0/25][2170/12663] Loss_D: 0

[0/25][3070/12663] Loss_D: 0.1478 Loss_G: 0.0686 Convergence: 0.1559 k= 0.027628
[0/25][3080/12663] Loss_D: 0.1649 Loss_G: 0.0601 Convergence: 0.1861 k= 0.027762
[0/25][3090/12663] Loss_D: 0.1611 Loss_G: 0.0774 Convergence: 0.1710 k= 0.027925
[0/25][3100/12663] Loss_D: 0.1363 Loss_G: 0.1070 Convergence: 0.1783 k= 0.027857
[0/25][3110/12663] Loss_D: 0.1469 Loss_G: 0.0833 Convergence: 0.1578 k= 0.027876
[0/25][3120/12663] Loss_D: 0.1500 Loss_G: 0.0606 Convergence: 0.1633 k= 0.028056
[0/25][3130/12663] Loss_D: 0.1603 Loss_G: 0.0722 Convergence: 0.1687 k= 0.028254
[0/25][3140/12663] Loss_D: 0.1479 Loss_G: 0.0585 Convergence: 0.1669 k= 0.028405
[0/25][3150/12663] Loss_D: 0.1560 Loss_G: 0.0581 Convergence: 0.1776 k= 0.028545
[0/25][3160/12663] Loss_D: 0.1537 Loss_G: 0.0712 Convergence: 0.1610 k= 0.028660
[0/25][3170/12663] Loss_D: 0.1674 Loss_G: 0.0519 Convergence: 0.2010 k= 0.028756
[0/25][3180/12663] Loss_D: 0.1471 Loss_G: 0.0632 Convergence: 0.1576 k= 0.028850
[0/25][3190/12663] Loss_D: 0