In [1]:
# TODO LIST FROM GANHACKS
# https://github.com/soumith/ganhacks
# ✓ Normalize the inputs
# ✓ A modified loss function
# ✓ Use a spherical Z
# ✓ BatchNorm
# Avoid sparse gradients
# ✓ Use soft and noisy labels
# ✓ DCGAN
# Use stability tricks from RL
# ✓ Use SGD for discriminator, Adam for generator (NOTE: optimizerD below is currently Adam, not SGD)
# ✓ Add noise to inputs
# Minibatch discrimination (for diversity)
# ✓ Use dropouts in G in both train and test phase

In [2]:
import os

import torch.nn as nn
import torch.utils.data
import torchvision.datasets as dset

from ImageFeatureFolder import ImageFeatureFolder

# ---- Hyper-parameters ----
nc = 3              # number of image channels
nz = 100            # length of the generator's noise vector
lr     = 0.00017    # shared learning rate for both Adam optimizers
beta1  = 0.0
beta2  = 0.99
imageSize = 64      # side length passed to ImageFeatureFolder
batchSize = 64

# ---- Paths ----
outf = "./celeba_result/"      # sample grids and checkpoints are written here
des_dir = "./celeba_normal/"   # image root passed to ImageFeatureFolder
landmark_file = './list_landmarks_align_celeba.txt'

# vutils.save_image and torch.save do not create missing directories, so make
# sure the output folder exists before training starts writing into it.
os.makedirs(outf, exist_ok=True)

# ---- Data ----
dataset = ImageFeatureFolder(root=des_dir, landmark_file=landmark_file, imageSize=imageSize)

# NOTE: the final batch of an epoch may be smaller than batchSize when the
# dataset size is not a multiple of it; the training loop must cope with that.
dataloader = torch.utils.data.DataLoader(dataset,
                                         batch_size=batchSize,
                                         shuffle=True)

In [3]:
from torch.autograd import Variable
import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torchvision.utils as vutils
from Models import Generator, Discriminator
import net_sphere

In [4]:
# ---- Networks ----
# Conditional generator / discriminator, plus a sphere20a network used only to
# produce conditioning features (presumably a pretrained SphereFace model).
netG = Generator()
netD = Discriminator()
feature_net = net_sphere.sphere20a(feature=True)

# Least-squares objective: D's outputs are regressed onto constant targets.
criterion = nn.MSELoss()

# Fixed latent batch reused for every sample grid. It is drawn on the CPU
# right after the networks are initialised, then copied to the GPU below.
fixed_noise = torch.FloatTensor(batchSize, nz, 1, 1).normal_(0, 1)

for module in (netD, netG, criterion, feature_net):
    module.cuda()

feature_net.load_state_dict(torch.load('./models/sphere20a_20171020.pth'))
feature_net.eval()

# Preallocated GPU buffers. NOTE(review): `input` shadows the Python builtin
# and does not appear to be read by the training loop -- confirm before removal.
input = torch.FloatTensor(batchSize, 3, imageSize, imageSize).cuda()
noise = torch.FloatTensor(batchSize, nz, 1, 1).cuda()
fixed_noise = Variable(fixed_noise.cuda())

# Constant (batchSize, 1) regression targets:
#   1.0 -> real, 0.0 -> fake, 0.9 -> one-sided smoothed "real" for training D.
label_real = Variable(torch.FloatTensor(batchSize, 1).fill_(1).cuda())
label_fake = Variable(torch.FloatTensor(batchSize, 1).fill_(0).cuda())
label_real_smooth = Variable(torch.FloatTensor(batchSize, 1).fill_(0.9).cuda())

# Both players use Adam with the shared hyper-parameters from the config cell.
optimizerD = optim.Adam(netD.parameters(), lr=lr, betas=(beta1, beta2))
optimizerG = optim.Adam(netG.parameters(), lr=lr, betas=(beta1, beta2))

In [None]:
# netD.load_state_dict(torch.load(outf + 'netD_epoch_003.pth'))
# netG.load_state_dict(torch.load(outf + 'netG_epoch_003.pth'))
import numpy as np
# Exponential moving average of the discriminator's mean output on generated
# images; updated by add_noise and used to scale the instance noise.
_d_ = 0.0
def add_noise(x, d_fake):
    """Add adaptive Gaussian instance noise to a discriminator input.

    For a non-None ``d_fake`` this first updates the module-level average
    ``_d_ <- 0.9 * _d_ + 0.1 * mean(d_fake)``. It then returns
    ``x + N(0, 1) * strength`` with ``strength = 0.2 * max(0, _d_ - 0.5) ** 2``.
    Noise therefore only appears once the averaged D(G(z)) exceeds 0.5, and it
    grows quadratically beyond that.

    NOTE(review): the training loop calls this twice per iteration with the
    same ``d_fake`` (once for the real batch, once for the fake batch). The
    average is therefore advanced twice per step -- confirm that is intended.

    Args:
        x: discriminator input (Variable / tensor) to perturb. The noise is
            created on the same device and with the same dtype as ``x``.
        d_fake: discriminator output on generated images from an earlier
            step, or None (e.g. on the first iteration). None skips both the
            average update and the noise.

    Returns:
        ``x`` itself when ``d_fake`` is None or the strength is 0. Otherwise a
        new tensor ``x + noise`` with the same shape as ``x``.
    """
    global _d_
    if d_fake is None:
        return x
    # float(t.data.mean()) works for both 1-element (PyTorch 0.3) and 0-dim
    # (0.4+) results, whereas `.data[0]` is an error on 0-dim tensors from 0.5.
    _d_ = _d_ * 0.9 + float(d_fake.data.mean()) * 0.1
    strength = 0.2 * max(0, _d_ - 0.5) ** 2
    if strength == 0:
        # All-zero noise would be a no-op; skip the allocation entirely.
        return x
    # Draw directly on x's device/dtype instead of generating with NumPy on
    # the CPU and copying with a hard-coded .cuda() (which broke CPU inputs).
    z = x.data.new(x.size()).normal_(0, 1) * strength
    return x + Variable(z)

In [None]:
niter = 100
# First epoch to run. When resuming from saved checkpoints, set this to the
# next epoch index (e.g. 4 after loading the *_epoch_003.pth files).
start_epoch = 0
# Discriminator output on the previous step's fakes. add_noise uses it to
# scale the instance noise (None on the first step => no noise).
d_fake_save = None
for epoch in range(start_epoch, niter):
    for i, data in enumerate(dataloader, 0):
        # ---------------- train D ----------------
        netD.zero_grad()
        real_cpu, feature_input_cpu = data
        # The final batch of an epoch can be smaller than batchSize. Slice the
        # preallocated (batchSize, 1) targets to the actual batch size so the
        # MSE loss always compares equally-sized tensors.
        batch_size = real_cpu.size(0)
        target_real_smooth = label_real_smooth[:batch_size]
        target_real = label_real[:batch_size]
        target_fake = label_fake[:batch_size]

        inputv = Variable(real_cpu.cuda())

        feature_inputv = Variable(feature_input_cpu.cuda())
        feature = feature_net(feature_inputv)
        # Multiplicative jitter (mean 1, std 0.05) on the conditioning feature.
        # It is sized from the feature itself rather than a hard-coded 512.
        feature_noise = (1 - torch.FloatTensor(feature.size()).normal_(0, 0.05)).cuda()
        feature *= feature_noise
        # The feature only conditions G and D. No gradient is sent back into
        # feature_net, so detach once and reuse the result everywhere below.
        feature = feature.detach()

        d_real = netD(add_noise(inputv, d_fake_save), feature)

        noise.resize_(batch_size, nz, 1, 1).normal_(0, 1)
        noisev = Variable(noise)
        fake = netG(noisev, feature)

        # add_noise must see the *previous* step's D(G(z)), so d_fake_save is
        # refreshed only after this call. A detached copy is kept so the old
        # autograd graph is not held alive until the next iteration.
        d_fake = netD(add_noise(fake.detach(), d_fake_save), feature)
        d_fake_save = d_fake.detach()

        loss_d = criterion(d_real, target_real_smooth) + criterion(d_fake, target_fake)
        loss_d.backward()
        optimizerD.step()

        # ---------------- train G ----------------
        netG.zero_grad()
        # D judges G's output here without instance noise.
        d_fake = netD(fake, feature)
        loss_g = criterion(d_fake, target_real)
        loss_g.backward()
        optimizerG.step()

        if i % 500 == 0:
            # Sample grid from the fixed noise, conditioned on this batch's
            # features (fixed_noise is sliced to match a possibly short batch).
            # netG is left in training mode, consistent with the TODO item
            # asking for dropout in G at test time as well.
            fake_samples = netG(fixed_noise[:batch_size], feature)
            vutils.save_image(fake_samples.data,
                    '%s/fake_samples_epoch_%03dstep_%04d.png' % (outf, epoch, i),
                    normalize=True)
            vutils.save_image(feature_inputv.data,
                    '%s/fake_samples_epoch_%03dstep_%04d_feature.png' % (outf, epoch, i),
                    normalize=True)
            # float(t.data.mean()) reads a scalar on both PyTorch 0.3 (1-element
            # tensor) and 0.4+ (0-dim tensor, where `.data[0]` fails from 0.5).
            # d_fake_save holds the D-step output, which is what D(G(z)) reports.
            print('[%d/%d][%d/%d] Loss_D: %.4f Loss_G: %.4f D(x): %.4f D(G(z)): %.4f'
                  % (epoch, niter, i, len(dataloader),
                     float(loss_d.data.mean()), float(loss_g.data.mean()),
                     float(d_real.data.mean()), float(d_fake_save.data.mean())))

    torch.save(netG.state_dict(), '%s/netG_epoch_%03d.pth' % (outf, epoch))
    torch.save(netD.state_dict(), '%s/netD_epoch_%03d.pth' % (outf, epoch))



[0/100][0/3165] Loss_D: 0.7651 Loss_G: 1.2701 D(x): 0.1561 D(G(z)): 0.1269


  


[0/100][500/3165] Loss_D: 0.4454 Loss_G: 0.6369 D(x): 0.4681 D(G(z)): 0.4671
[0/100][1000/3165] Loss_D: 0.4481 Loss_G: 0.3634 D(x): 0.3486 D(G(z)): 0.2678
[0/100][1500/3165] Loss_D: 0.4525 Loss_G: 1.6847 D(x): 0.7345 D(G(z)): 0.5720
[0/100][2000/3165] Loss_D: 0.3940 Loss_G: 1.4144 D(x): 0.6945 D(G(z)): 0.4372
[0/100][2500/3165] Loss_D: 0.2334 Loss_G: 0.9225 D(x): 0.6341 D(G(z)): 0.3313
[0/100][3000/3165] Loss_D: 0.2328 Loss_G: 1.6349 D(x): 0.8236 D(G(z)): 0.3878
[1/100][0/3165] Loss_D: 0.1576 Loss_G: 0.9187 D(x): 0.6692 D(G(z)): 0.2204
[1/100][500/3165] Loss_D: 0.1732 Loss_G: 1.4006 D(x): 0.8099 D(G(z)): 0.3604
[1/100][1000/3165] Loss_D: 0.1826 Loss_G: 0.9280 D(x): 0.6132 D(G(z)): 0.2579
[1/100][1500/3165] Loss_D: 0.1275 Loss_G: 1.4249 D(x): 0.7750 D(G(z)): 0.2325
[1/100][2000/3165] Loss_D: 0.1598 Loss_G: 0.3571 D(x): 0.5881 D(G(z)): -0.1352
[1/100][2500/3165] Loss_D: 0.2773 Loss_G: 1.9509 D(x): 0.7722 D(G(z)): 0.4591
[1/100][3000/3165] Loss_D: 0.1852 Loss_G: 1.5033 D(x): 0.7575 D(G(z)

[15/100][500/3165] Loss_D: 0.0308 Loss_G: 0.7741 D(x): 0.8793 D(G(z)): -0.1491
[15/100][1000/3165] Loss_D: 0.0809 Loss_G: 0.5403 D(x): 0.7407 D(G(z)): -0.2121
[15/100][1500/3165] Loss_D: 0.0259 Loss_G: 0.9680 D(x): 0.8036 D(G(z)): 0.0584
[15/100][2000/3165] Loss_D: 0.0252 Loss_G: 0.7422 D(x): 0.7693 D(G(z)): -0.0374
[15/100][2500/3165] Loss_D: 0.0169 Loss_G: 0.9965 D(x): 0.8333 D(G(z)): 0.0418
[15/100][3000/3165] Loss_D: 0.0186 Loss_G: 0.8734 D(x): 0.7977 D(G(z)): -0.0243
[16/100][0/3165] Loss_D: 0.0160 Loss_G: 1.1331 D(x): 0.9638 D(G(z)): 0.0455
[16/100][500/3165] Loss_D: 0.1159 Loss_G: 1.6873 D(x): 0.7911 D(G(z)): 0.2982
[16/100][1000/3165] Loss_D: 0.1681 Loss_G: 0.5625 D(x): 0.5199 D(G(z)): -0.1125
[16/100][1500/3165] Loss_D: 0.0751 Loss_G: 1.6319 D(x): 1.0005 D(G(z)): 0.2308
[16/100][2000/3165] Loss_D: 0.2516 Loss_G: 0.4993 D(x): 0.4733 D(G(z)): -0.2421
[16/100][2500/3165] Loss_D: 0.0560 Loss_G: 1.3756 D(x): 0.8859 D(G(z)): 0.2147
[16/100][3000/3165] Loss_D: 0.0159 Loss_G: 1.0182 D

[29/100][3000/3165] Loss_D: 0.0047 Loss_G: 1.0188 D(x): 0.9347 D(G(z)): -0.0003
[30/100][0/3165] Loss_D: 0.0176 Loss_G: 0.8912 D(x): 0.8493 D(G(z)): -0.0985
[30/100][500/3165] Loss_D: 0.0146 Loss_G: 0.9182 D(x): 0.9669 D(G(z)): -0.0683
[30/100][1000/3165] Loss_D: 0.1679 Loss_G: 1.8707 D(x): 0.9852 D(G(z)): 0.3909
[30/100][1500/3165] Loss_D: 0.0209 Loss_G: 0.9580 D(x): 0.7918 D(G(z)): 0.0559
[30/100][2000/3165] Loss_D: 0.0058 Loss_G: 1.0201 D(x): 0.9482 D(G(z)): 0.0124
[30/100][2500/3165] Loss_D: 0.0093 Loss_G: 1.1446 D(x): 0.9706 D(G(z)): 0.0395
[30/100][3000/3165] Loss_D: 0.1453 Loss_G: 1.5438 D(x): 0.9339 D(G(z)): 0.3702
[31/100][0/3165] Loss_D: 0.0409 Loss_G: 1.4273 D(x): 0.9104 D(G(z)): 0.1760
[31/100][500/3165] Loss_D: 0.0666 Loss_G: 1.7334 D(x): 0.9794 D(G(z)): 0.2261
[31/100][1000/3165] Loss_D: 0.0165 Loss_G: 1.1877 D(x): 0.8978 D(G(z)): 0.1111
[31/100][1500/3165] Loss_D: 0.0245 Loss_G: 1.2428 D(x): 0.9993 D(G(z)): 0.1059
[31/100][2000/3165] Loss_D: 0.3097 Loss_G: 1.5803 D(x): 0

[44/100][2000/3165] Loss_D: 0.0230 Loss_G: 1.0895 D(x): 0.8362 D(G(z)): 0.1122
[44/100][2500/3165] Loss_D: 0.0193 Loss_G: 0.9585 D(x): 0.8815 D(G(z)): -0.1262
[44/100][3000/3165] Loss_D: 0.0230 Loss_G: 0.6128 D(x): 0.8058 D(G(z)): -0.0857
[45/100][0/3165] Loss_D: 0.0054 Loss_G: 0.9970 D(x): 0.8904 D(G(z)): -0.0115
[45/100][500/3165] Loss_D: 0.0143 Loss_G: 1.3226 D(x): 0.8938 D(G(z)): 0.1008
[45/100][1000/3165] Loss_D: 0.0044 Loss_G: 1.0008 D(x): 0.8995 D(G(z)): -0.0274
[45/100][1500/3165] Loss_D: 0.0234 Loss_G: 1.2306 D(x): 0.9510 D(G(z)): 0.1203
[45/100][2000/3165] Loss_D: 0.0097 Loss_G: 0.9804 D(x): 0.8190 D(G(z)): -0.0032
[45/100][2500/3165] Loss_D: 0.1407 Loss_G: 0.5459 D(x): 0.5544 D(G(z)): -0.1080
[45/100][3000/3165] Loss_D: 0.0035 Loss_G: 1.0552 D(x): 0.9381 D(G(z)): 0.0231
[46/100][0/3165] Loss_D: 0.0042 Loss_G: 1.0829 D(x): 0.8848 D(G(z)): 0.0446
[46/100][500/3165] Loss_D: 0.0186 Loss_G: 0.7963 D(x): 0.8416 D(G(z)): -0.1099
[46/100][1000/3165] Loss_D: 0.0296 Loss_G: 0.8111 D(x