In [1]:
# TODO LIST FROM GANHACKS
# https://github.com/soumith/ganhacks
# ✓ Normalize the inputs
# ✓ A modified loss function
# ✓ Use a spherical Z
# ✓ BatchNorm
# Avoid sparse gradients
# ✓ Use soft and noisy labels
# ✓ DCGAN
# Use stability tricks from RL
# ✓ Use SGD for discriminator ADAM for generator
# ✓ Add noise to inputs
# Batch Discrimination (for diversity)
# ✓ Use dropouts in G in both train and test phase

In [2]:
import network as layer
import torch.nn as nn

# Model dimensions shared by the generator and discriminator cells:
#   nc       - image channels (RGB)
#   nz       - length of the latent noise vector fed to the generator
#   nfeature - not referenced in the visible cells (presumably an embedding width; TODO confirm)
nc, nz, nfeature = 3, 512, 736

In [3]:
import torchvision.transforms as transforms
import torch.utils.data
from Folder import ImageFeatureFolder

# Adam settings (learning rate, betas) and the image / batch geometry.
lr, beta1, beta2 = 0.0002, 0.0, 0.99
imageSize, batchSize = 64, 64

# outf: where sample grids and checkpoints are written;
# des_dir: image root; embed_dir: feature files read by ImageFeatureFolder.
outf, des_dir, embed_dir = "./celeba_result/", "./celeba/", "./celeba_embed/"

# Crop the central 178x178 region, shrink to imageSize, convert to a [0, 1] tensor,
# then Normalize(mean=0.5, std=0.5) per channel maps pixels to [-1, 1] (matching G's Tanh).
dataset = ImageFeatureFolder(
    root=des_dir,
    feature_dir=embed_dir,
    transform=transforms.Compose([
        transforms.CenterCrop(178),
        transforms.Resize(imageSize),
        transforms.ToTensor(),
        transforms.Normalize((0.5,) * 3, (0.5,) * 3),
    ]),
)

dataloader = torch.utils.data.DataLoader(dataset, batch_size=batchSize, shuffle=True)

In [4]:
from custom_layers import *

class Generator(nn.Module):
    """Conditional generator: (noise, 512 x 7 x 7 feature map) -> nc x 64 x 64 image in [-1, 1].

    The spatial-size comments assume ``layer.deconv`` / ``layer.conv`` take
    (in_ch, out_ch, kernel, stride, padding) like ``nn.ConvTranspose2d`` / ``nn.Conv2d``.
    The option flags are interpreted by the ``network`` module (see network.py).
    """

    def make_dense(self, k_in, k_growth, n, options):
        """Return ``n`` ``Dense``-wrapped 3x3 convs; layer j reads ``k_in + j * k_growth`` channels.

        Not used by ``__init__`` below.
        """
        return nn.Sequential(*[
            Dense(layer.conv(k_in + j * k_growth, k_growth, 3, 1, 1, **options))
            for j in range(n)
        ])

    def __init__(self):
        super(Generator, self).__init__()
        opts = dict(leaky=True, bn=True, wn=False, pixel=True, gdrop=True)

        # noise (nz x 1 x 1) -> 512 x 4 x 4
        self.deconv1 = layer.deconv(nz, 512, 4, 1, 0, **opts)
        # feature (512 x 7 x 7) -> 256 x 4 x 4 via a stride-2 conv
        self.conv_feature = layer.conv(512, 256, 3, 2, 1, leaky=True)

        # Trunk input is cat(deconv1, conv_feature) = (512 + 256) x 4 x 4;
        # each stride-2 deconv doubles H and W: 4 -> 8 -> 16 -> 32.
        widths = [512 + 256, 256, 128, 64]
        trunk = [layer.deconv(c_in, c_out, 4, 2, 1, **opts)
                 for c_in, c_out in zip(widths[:-1], widths[1:])]
        # 32 -> 64 straight to image channels, then Tanh squashes to [-1, 1].
        trunk += [layer.deconv(64, nc, 4, 2, 1, gdrop=opts['gdrop'], only=True), nn.Tanh()]
        self.deconv2 = nn.Sequential(*trunk)

    def forward(self, x, feature):
        h = self.deconv1(x)
        f = self.conv_feature(feature.view(-1, 512, 7, 7))
        # Concatenate along channels, then upsample to the final image.
        return self.deconv2(torch.cat([h, f], 1))

netG = Generator()

In [5]:
class Discriminator(nn.Module):
    """Conditional discriminator: scores an nc x 64 x 64 image given a 512 x 7 x 7 feature map.

    ``forward`` returns one score per sample, shaped (N, 1). The spatial-size comments assume
    ``layer.conv`` takes (in_ch, out_ch, kernel, stride, padding) like ``nn.Conv2d``.
    """

    def make_dense(self, k_in, k_growth, n, options):
        """Return ``n`` ``Dense``-wrapped 3x3 convs; layer j reads ``k_in + j * k_growth`` channels.

        Not used by ``__init__`` below.
        """
        return nn.Sequential(*[
            Dense(layer.conv(k_in + j * k_growth, k_growth, 3, 1, 1, **options))
            for j in range(n)
        ])

    def __init__(self):
        super(Discriminator, self).__init__()
        opts = dict(leaky=True, bn=False, wn=False, pixel=False, gdrop=False)

        # Four stride-2 convs halve H and W each time: 64 -> 32 -> 16 -> 8 -> 4.
        widths = [nc, 64, 128, 256, 512]
        self.conv1 = nn.Sequential(*[layer.conv(c_in, c_out, 4, 2, 1, **opts)
                                     for c_in, c_out in zip(widths, widths[1:])])
        # feature (512 x 7 x 7) -> 256 x 4 x 4 so it can be concatenated with the image path.
        self.conv_feature = layer.conv(512, 256, 3, 2, 1, leaky=True)
        # (512 + 256) x 4 x 4 -> 1 x 1 x 1 score.
        self.conv2 = layer.conv(512 + 256, 1, 4, 1, 0, **opts)

    def forward(self, x, feature):
        h = self.conv1(x)
        f = self.conv_feature(feature.view(-1, 512, 7, 7))
        score = self.conv2(torch.cat([h, f], 1))
        return score.view(-1, 1)

netD = Discriminator()

In [6]:
from torch.autograd import Variable
import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torchvision.utils as vutils

In [7]:
# Least-squares objective: D outputs are regressed onto the target labels below.
criterion = nn.MSELoss()

# Reusable input buffers; the training loop resizes and copies each batch into them.
input = torch.FloatTensor(batchSize, 3, imageSize, imageSize)
noise = torch.FloatTensor(batchSize, nz, 1, 1)
feature = torch.FloatTensor(batchSize, 512, 7, 7)

# Fixed inputs for the periodic sample grids: one noise batch, and dataset feature #1
# tiled across the whole batch.
fixed_noise = torch.FloatTensor(batchSize, nz, 1, 1).normal_(0, 1)
fixed_feature = dataset.feature[1].repeat(batchSize, 1, 1, 1)

# Move models, loss and buffers to the GPU.
netD.cuda()
netG.cuda()
criterion.cuda()
input, feature, noise = (buf.cuda() for buf in (input, feature, noise))
fixed_noise, fixed_feature = fixed_noise.cuda(), fixed_feature.cuda()

# (batchSize, 1) target columns: hard real = 1, smoothed real = 0.8, fake = 0.
label_real, label_real_smooth, label_fake = (
    Variable(torch.FloatTensor(batchSize, 1).fill_(target).cuda())
    for target in (1, 0.8, 0)
)
print()




In [8]:
# netD.load_state_dict(torch.load(outf + 'netD_epoch_042.pth'))
# netG.load_state_dict(torch.load(outf + 'netG_epoch_042.pth'))

In [9]:
# Wrap the fixed sampling inputs once so they can be passed straight to netG.
fixed_noise, fixed_feature = Variable(fixed_noise), Variable(fixed_feature)

# Adam for both networks. The ganhacks alternative for D would be:
#   optim.SGD(netD.parameters(), lr=lr, momentum=0.9)
optimizerD, optimizerG = (
    optim.Adam(net.parameters(), lr=lr, betas=(beta1, beta2))
    for net in (netD, netG)
)
# Decay both learning rates by 0.87 at scheduler steps 4, 7, 11 and 17.
schedulerD, schedulerG = (
    optim.lr_scheduler.MultiStepLR(opt, milestones=[4, 7, 11, 17], gamma=0.87)
    for opt in (optimizerD, optimizerG)
)

# Bookkeeping containers (not used by the visible training loop).
result_dict = {}
loss_D, loss_G, score_D, score_G1, score_G2 = ([] for _ in range(5))

In [None]:
import numpy as np

# Exponential moving average of the discriminator's mean output on fake samples.
# None until add_noise has been called once.
_d_ = None

def add_noise(x, d_fake):
    """Add Gaussian noise to a discriminator input batch, scaled by D's confidence on fakes.

    Keeps a running average ``_d_`` (decay 0.9) of ``mean(d_fake)``. The noise standard
    deviation is ``0.2 * max(0, _d_ - 0.5) ** 2``, so nothing is added while the average
    stays at or below 0.5.

    Args:
        x: batch to perturb (Variable / tensor).
        d_fake: discriminator output on the previous fake batch, or None if there is none yet.

    Returns:
        ``x`` itself on the first call or whenever ``d_fake`` is None. Otherwise ``x + noise``,
        where the noise has ``x``'s shape and is converted to ``x``'s type/device.
    """
    global _d_
    if _d_ is None or d_fake is None:
        # First call, or no previous D output. For example, the training cell was re-run:
        # it resets d_fake_save to None while _d_ keeps its old value. Previously this
        # crashed in torch.mean(None).
        if _d_ is None:
            _d_ = 0.0
        return x
    # .data.mean() is a Python float on old PyTorch and a 0-dim tensor on newer releases;
    # float() accepts both, unlike .data[0], which fails on 0-dim tensors.
    _d_ = _d_ * 0.9 + float(d_fake.data.mean()) * 0.1
    strength = 0.2 * max(0, _d_ - 0.5) ** 2
    z = np.random.randn(*x.size()).astype(np.float32) * strength
    # type_as follows x onto its device/dtype instead of hard-coding .cuda().
    return x + Variable(torch.from_numpy(z)).type_as(x)


In [None]:
niter = 100
# D's output on the previous fake batch; add_noise uses it to scale the input noise.
d_fake_save = None
for epoch in range(niter):
    # NOTE(review): the schedulers are stepped at the *start* of each epoch. This is the
    # convention of the PyTorch 0.3-era API this notebook uses (Variable, .data). PyTorch >= 1.1
    # expects scheduler.step() after optimizer.step(); there this order warns and applies
    # each milestone one epoch early.
    schedulerD.step()
    schedulerG.step()

    for i, data in enumerate(dataloader, 0):
        real_cpu, embed_cpu = data
        batch_size = real_cpu.size(0)
        # The label Variables hold batchSize rows, but the last batch of an epoch can be
        # smaller when the dataset size is not a multiple of batchSize. Slice them so that
        # MSELoss always compares tensors of the same size.
        target_real = label_real[:batch_size]
        target_real_smooth = label_real_smooth[:batch_size]
        target_fake = label_fake[:batch_size]

        # ---- train D: (noisy) real images vs. detached fakes ----
        netD.zero_grad()
        real = real_cpu.cuda()
        embed = embed_cpu.cuda()
        input.resize_as_(real).copy_(real)
        feature.resize_as_(embed).copy_(embed)

        inputv = Variable(input)
        # Noise on real inputs, scaled from D's output on the previous fake batch.
        inputv = add_noise(inputv, d_fake_save)
        featurev = Variable(feature)

        d_real = netD(inputv, featurev)
        d_real_mean = d_real.data.mean()

        noise.resize_(batch_size, nz, 1, 1).normal_(0, 1)
        noisev = Variable(noise)
        fake = netG(noisev, featurev)

        # fake.detach(): the D update must not backpropagate into G.
        d_fake = netD(fake.detach(), featurev)
        d_fake_mean = d_fake.data.mean()
        # Keep only the values for add_noise, not the autograd graph, until the next iteration.
        d_fake_save = d_fake.detach()

        # Least-squares D loss with one-sided label smoothing (real target 0.8, fake target 0).
        loss_d = criterion(d_real, target_real_smooth) + criterion(d_fake, target_fake)
        loss_d.backward()
        optimizerD.step()

        # ---- train G: push D's score on the fakes toward the real target 1 ----
        netG.zero_grad()
        d_fake = netD(fake, featurev)
        loss_g = criterion(d_fake, target_real)
        loss_g.backward()
        optimizerG.step()

        if i % 250 == 0:
            # Save a sample grid generated from the fixed noise / fixed feature.
            fake = netG(fixed_noise, fixed_feature)
            vutils.save_image(fake.data,
                              '%s/fake_samples_epoch_%03dstep_%04d.png' % (outf, epoch, i),
                              normalize=True)
            # .data.mean() of a single-element loss is a float on old PyTorch and a 0-dim
            # tensor on newer releases. '%.4f' formats either; .data[0] fails on 0-dim tensors.
            print('[%d/%d][%d/%d] Loss_D: %.4f Loss_G: %.4f D(x): %.4f D(G(z)): %.4f'
                  % (epoch, niter, i, len(dataloader),
                     loss_d.data.mean(), loss_g.data.mean(), d_real_mean, d_fake_mean))

    torch.save(netG.state_dict(), '%s/netG_epoch_%03d.pth' % (outf, epoch))
    torch.save(netD.state_dict(), '%s/netD_epoch_%03d.pth' % (outf, epoch))



[0/100][0/3165] Loss_D: 0.5498 Loss_G: 0.1301 D(x): 0.0669 D(G(z)): 0.0635


  


[0/100][250/3165] Loss_D: 0.3185 Loss_G: 0.3430 D(x): 0.3739 D(G(z)): 0.3672
[0/100][500/3165] Loss_D: 0.3127 Loss_G: 0.4305 D(x): 0.4227 D(G(z)): 0.4098
[0/100][750/3165] Loss_D: 0.3132 Loss_G: 0.3371 D(x): 0.3543 D(G(z)): 0.3354
[0/100][1000/3165] Loss_D: 0.3102 Loss_G: 0.4232 D(x): 0.3917 D(G(z)): 0.3734
[0/100][1250/3165] Loss_D: 0.3024 Loss_G: 0.4495 D(x): 0.4215 D(G(z)): 0.3930
[0/100][1500/3165] Loss_D: 0.2974 Loss_G: 0.6385 D(x): 0.5053 D(G(z)): 0.4498
[0/100][1750/3165] Loss_D: 0.3989 Loss_G: 0.9594 D(x): 0.5601 D(G(z)): 0.5620
[0/100][2000/3165] Loss_D: 0.2647 Loss_G: 0.4333 D(x): 0.4214 D(G(z)): 0.3334
[0/100][2250/3165] Loss_D: 0.2875 Loss_G: 0.3899 D(x): 0.3842 D(G(z)): 0.3266
[0/100][2500/3165] Loss_D: 0.2444 Loss_G: 0.6588 D(x): 0.5049 D(G(z)): 0.3678
[0/100][2750/3165] Loss_D: 0.2418 Loss_G: 0.4910 D(x): 0.4320 D(G(z)): 0.2991
[0/100][3000/3165] Loss_D: 0.2944 Loss_G: 0.4998 D(x): 0.4692 D(G(z)): 0.4135
[1/100][0/3165] Loss_D: 0.3052 Loss_G: 0.6838 D(x): 0.5933 D(G(z)):

[8/100][500/3165] Loss_D: 0.0881 Loss_G: 0.7723 D(x): 0.6092 D(G(z)): 0.0778
[8/100][750/3165] Loss_D: 0.0718 Loss_G: 0.7811 D(x): 0.6309 D(G(z)): 0.0553
[8/100][1000/3165] Loss_D: 0.0592 Loss_G: 0.9680 D(x): 0.6970 D(G(z)): 0.1168
[8/100][1250/3165] Loss_D: 0.0835 Loss_G: 0.8367 D(x): 0.6430 D(G(z)): 0.1120
[8/100][1500/3165] Loss_D: 0.0850 Loss_G: 1.0326 D(x): 0.7398 D(G(z)): 0.1418
[8/100][1750/3165] Loss_D: 0.0494 Loss_G: 1.0452 D(x): 0.7617 D(G(z)): 0.0687
[8/100][2000/3165] Loss_D: 0.1482 Loss_G: 0.6538 D(x): 0.4860 D(G(z)): 0.0115
[8/100][2250/3165] Loss_D: 0.0767 Loss_G: 0.8839 D(x): 0.6479 D(G(z)): 0.1000
[8/100][2500/3165] Loss_D: 0.0678 Loss_G: 1.0789 D(x): 0.8181 D(G(z)): 0.1301
[8/100][2750/3165] Loss_D: 0.1483 Loss_G: 0.5067 D(x): 0.4845 D(G(z)): 0.0229
[8/100][3000/3165] Loss_D: 0.1026 Loss_G: 1.0574 D(x): 0.7094 D(G(z)): 0.2130
[9/100][0/3165] Loss_D: 0.1093 Loss_G: 0.5446 D(x): 0.5388 D(G(z)): 0.0215
[9/100][250/3165] Loss_D: 0.1274 Loss_G: 1.0513 D(x): 0.7394 D(G(z)):

[16/100][500/3165] Loss_D: 0.0681 Loss_G: 1.1524 D(x): 0.8289 D(G(z)): 0.1409
[16/100][750/3165] Loss_D: 0.0737 Loss_G: 0.8053 D(x): 0.6116 D(G(z)): 0.0085
[16/100][1000/3165] Loss_D: 0.0534 Loss_G: 0.9350 D(x): 0.6638 D(G(z)): 0.0603
[16/100][1250/3165] Loss_D: 0.0713 Loss_G: 1.0642 D(x): 0.7275 D(G(z)): 0.1095
[16/100][1500/3165] Loss_D: 0.0538 Loss_G: 0.8708 D(x): 0.6643 D(G(z)): 0.0224
[16/100][1750/3165] Loss_D: 0.0497 Loss_G: 0.9943 D(x): 0.7095 D(G(z)): 0.0315
[16/100][2000/3165] Loss_D: 0.0447 Loss_G: 1.0390 D(x): 0.7271 D(G(z)): 0.0307
[16/100][2250/3165] Loss_D: 0.0489 Loss_G: 1.0873 D(x): 0.7798 D(G(z)): 0.0929
[16/100][2500/3165] Loss_D: 0.0527 Loss_G: 0.9250 D(x): 0.6865 D(G(z)): 0.0497
[16/100][2750/3165] Loss_D: 0.0490 Loss_G: 1.0642 D(x): 0.7782 D(G(z)): 0.0690
[16/100][3000/3165] Loss_D: 0.0720 Loss_G: 0.6075 D(x): 0.5792 D(G(z)): -0.0101
[17/100][0/3165] Loss_D: 0.0384 Loss_G: 1.0864 D(x): 0.8010 D(G(z)): 0.0671
[17/100][250/3165] Loss_D: 0.0258 Loss_G: 1.0658 D(x): 0

[24/100][500/3165] Loss_D: 0.0261 Loss_G: 1.0676 D(x): 0.7487 D(G(z)): 0.0324
[24/100][750/3165] Loss_D: 0.0336 Loss_G: 1.0008 D(x): 0.7452 D(G(z)): 0.0305
[24/100][1000/3165] Loss_D: 0.0403 Loss_G: 1.0471 D(x): 0.7298 D(G(z)): 0.0516
[24/100][1250/3165] Loss_D: 0.0340 Loss_G: 1.0662 D(x): 0.7570 D(G(z)): 0.0520
[24/100][1500/3165] Loss_D: 0.0327 Loss_G: 1.0803 D(x): 0.7851 D(G(z)): 0.0200
[24/100][1750/3165] Loss_D: 0.0476 Loss_G: 0.9699 D(x): 0.7056 D(G(z)): 0.0350
[24/100][2000/3165] Loss_D: 0.0833 Loss_G: 1.1536 D(x): 0.8585 D(G(z)): 0.1921
[24/100][2250/3165] Loss_D: 0.0427 Loss_G: 0.9410 D(x): 0.6889 D(G(z)): 0.0402
[24/100][2500/3165] Loss_D: 0.0429 Loss_G: 1.0031 D(x): 0.7087 D(G(z)): 0.0392
[24/100][2750/3165] Loss_D: 0.0375 Loss_G: 1.0227 D(x): 0.7389 D(G(z)): 0.0334
[24/100][3000/3165] Loss_D: 0.0490 Loss_G: 1.1718 D(x): 0.8374 D(G(z)): 0.1006
[25/100][0/3165] Loss_D: 0.0634 Loss_G: 1.1635 D(x): 0.8274 D(G(z)): 0.1449
[25/100][250/3165] Loss_D: 0.0416 Loss_G: 0.8545 D(x): 0.