In [1]:
# TODO LIST FROM GANHACKS
# https://github.com/soumith/ganhacks
# ✓ Normalize the inputs
# ✓ A modified loss function
# ✓ Use a spherical Z
# ✓ BatchNorm
# Avoid sparse gradients
# ✓ Use soft and noisy labels
# ✓ DCGAN
# Use stability tricks from RL
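# ✓ Use SGD for discriminator, ADAM for generator (the SGD line for D is commented out in the optimizer cell; both currently use Adam)
# ✓ Add noise to inputs (add_noise is defined, but its call in the training loop is commented out)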
# Batch Discrimination (for diversity)
# ✓ Use dropouts in G in both train and test phase

In [2]:
import torch.nn as nn
import torchvision.transforms as transforms
import torch.utils.data
import torchvision.datasets as dset
from Folder import ImageFeatureFolder

# ---- hyper-parameters ------------------------------------------------------
nc = 3               # image channels (ImageFolder loads RGB images)
nz = 512             # length of the noise vector fed to netG
lr = 0.00023         # learning rate shared by both Adam optimizers
beta1 = 0.0          # Adam betas
beta2 = 0.99
imageSize = 128      # side of the square training images
batchSize = 64

# ---- paths -----------------------------------------------------------------
outf = "./celeba_result/"   # sample grids and checkpoints are written here
des_dir = "./celeba/"       # ImageFolder root (expects images inside sub-directories)

# ---- data ------------------------------------------------------------------
# Take the central 178x178 square, scale it to imageSize x imageSize,
# then map pixel values from [0, 1] to [-1, 1].
dataset = dset.ImageFolder(
    root=des_dir,
    transform=transforms.Compose(
        [
            transforms.CenterCrop(178),
            transforms.Resize(imageSize),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ]
    ),
)

dataloader = torch.utils.data.DataLoader(dataset, batch_size=batchSize, shuffle=True)

In [3]:
from torch.autograd import Variable
import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torchvision.utils as vutils
from Models import Generator, Discriminator
import net_sphere

In [4]:
netG = Generator()
netD = Discriminator()
# Face-feature extractor (net_sphere.sphere20a built with feature=True); it is loaded
# from pretrained weights below and never optimized in this notebook.
feature_net = net_sphere.sphere20a(feature=True)

# MSE against 0 / 1 targets, i.e. a least-squares GAN objective.
criterion = nn.MSELoss()

# Reusable buffers: an image batch, per-step noise, and noise held fixed for sample grids.
input = torch.FloatTensor(batchSize, 3, imageSize, imageSize)
noise = torch.FloatTensor(batchSize, nz, 1, 1)
fixed_noise = torch.FloatTensor(batchSize, nz, 1, 1).normal_(0, 1)

netD.cuda()
netG.cuda()
criterion.cuda()
feature_net.cuda()
feature_net.load_state_dict(torch.load('./models/sphere20a_20171020.pth'))
feature_net.eval()
# eval() does not stop gradient tracking. The extractor's outputs are always detached
# during training, so freeze its parameters to keep autograd from recording its forward.
for param in feature_net.parameters():
    param.requires_grad = False

input, noise = input.cuda(), noise.cuda()
fixed_noise = fixed_noise.cuda()

# Regression targets of shape (batchSize, 1):
# 1 = real, 0 = fake, 0.9 = one-sided smoothed "real" used when training D.
label_real = Variable(torch.ones(batchSize, 1).cuda())
label_fake = Variable(torch.zeros(batchSize, 1).cuda())
label_real_smooth = Variable(torch.ones(batchSize, 1).fill_(0.9).cuda())




In [5]:
# netD.load_state_dict(torch.load(outf + 'netD_epoch_016.pth'))
# netG.load_state_dict(torch.load(outf + 'netG_epoch_016.pth'))

In [6]:
fixed_noise = Variable(fixed_noise)  # wrapped once; fed to netG for the periodic sample grids

# Adam with the shared lr / betas for both networks (D is built first, then G).
# Commented-out alternative for the discriminator:
#     optim.SGD(netD.parameters(), lr = lr, momentum=0.9)
optimizerD, optimizerG = (
    optim.Adam(net.parameters(), lr=lr, betas=(beta1, beta2)) for net in (netD, netG)
)

# Step decay: lr is multiplied by 0.87 at each milestone (counted in scheduler.step() calls).
# NOTE: as written, the .step() calls in the training cell are commented out,
# so these schedules never advance.
schedulerD = optim.lr_scheduler.MultiStepLR(
    optimizerD, gamma=0.87, milestones=[4, 7, 11, 17])
schedulerG = optim.lr_scheduler.MultiStepLR(
    optimizerG, gamma=0.87, milestones=[4, 6, 8, 10, 12, 14, 17])

In [7]:
import numpy as np
_d_ = None  # running (EMA) mean of D(G(z)); None until add_noise() is first called


def add_noise(x, d_fake):
    """Add Gaussian instance noise to ``x``, scaled by how confidently D scores fakes.

    Keeps an exponential moving average ``_d_`` (decay 0.9) of ``mean(d_fake)``.
    The noise standard deviation is ``0.2 * max(1e-5, _d_ - 0.5) ** 2``. The noise is
    therefore negligible while the average fake score is at or below 0.5, and it grows
    as the score rises above 0.5.

    The very first call only initialises the average to 0.0 and returns ``x`` unchanged.
    ``d_fake`` is not read on that call, so it may be ``None``.

    Args:
        x: tensor to perturb. The noise is drawn with x's shape, dtype and device.
        d_fake: discriminator output on fake samples (any shape; only its mean is used).

    Returns:
        ``x + noise``, or ``x`` itself on the first call.
    """
    global _d_
    if _d_ is None:
        _d_ = 0.0
        return x
    # .item() yields a Python float. The old `.data[0]` indexed a 0-dim tensor, which
    # is deprecated since PyTorch 0.4 and could leave _d_ as a (CUDA) tensor.
    _d_ = _d_ * 0.9 + torch.mean(d_fake).item() * 0.1
    strength = 0.2 * max(0.00001, _d_ - 0.5) ** 2
    # randn_like follows x's device/dtype instead of building on CPU and forcing .cuda().
    return x + torch.randn_like(x) * strength


In [8]:
# Per-image pipeline feeding feature_net: center-crop to 128x128, scale the short side
# to 112, then center-crop to 112x96 (height x width). ToPILImage multiplies float
# tensors by 255 before casting to uint8, so the input must already lie in [0, 1].
feature_input_transform = transforms.Compose([
    transforms.ToPILImage(),
    transforms.CenterCrop(128),
    transforms.Resize(112),
    transforms.CenterCrop((112, 96)),
    transforms.ToTensor()
])


def transform_batch(batch, transform=feature_input_transform, mean=0.5, std=0.5):
    """Apply ``transform`` to every image of ``batch`` and stack the results.

    Dataset images are normalised to [-1, 1] by ``Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))``.
    Passed straight to ``ToPILImage``, negative values would wrap around in the uint8 cast
    and corrupt the feature-net input. Each image is therefore mapped back with
    ``x * std + mean`` and clamped to [0, 1] first.

    NOTE(review): generated batches are assumed to share the [-1, 1] range of the real
    data. Confirm this against Models.Generator's output layer.
    NOTE(review): confirm whether the pretrained sphere20a weights expect an input scale
    other than ToTensor's [0, 1].

    Args:
        batch: NCHW float tensor on the CPU (the transform goes through PIL).
        transform: per-image transform that returns a CHW tensor; defaults to
            ``feature_input_transform`` (112x96 output).
        mean, std: Normalize parameters to undo. Pass ``mean=0, std=1`` for a batch
            that is already in [0, 1].

    Returns:
        Tensor of shape ``(N, C, H', W')`` holding the transformed images.
    """
    images = (batch * std + mean).clamp(0, 1)
    return torch.stack([transform(img) for img in images], 0)

In [9]:
import os

niter = 100
# save_image() and torch.save() below fail if the output directory does not exist yet.
os.makedirs(outf, exist_ok=True)

# NOTE(review): the recorded run of this cell failed with
#   "size mismatch, m1: [64 x 2048], m2: [512 x 1]"
# i.e. a layer with 512 inputs and 1 output received 2048 (= 512 * 2 * 2) features per
# sample. A plausible cause is a network in Models sized for a smaller imageSize than 128.
# TODO confirm against Models.Generator / Models.Discriminator.

d_fake_save = None  # latest D(G(z)); only read by the disabled add_noise() call below
for epoch in range(niter):
    # Disabled LR scheduling / manual resume (presumably paired with the commented-out
    # epoch-016 checkpoint load):
    #     schedulerD.step()
    #     schedulerG.step()
    #     if epoch < 17:
    #         continue
    for i, data in enumerate(dataloader, 0):
        real_cpu, _ = data
        batch_size = real_cpu.size(0)

        # The label tensors hold batchSize rows, but the last batch of an epoch can be
        # smaller. Slice the targets so MSELoss sees matching shapes.
        target_real = label_real[:batch_size]
        target_real_smooth = label_real_smooth[:batch_size]
        target_fake = label_fake[:batch_size]

        # Identity features of the real faces. feature_net is a fixed extractor,
        # so no autograd graph is needed.
        with torch.no_grad():
            feature_real = feature_net(Variable(transform_batch(real_cpu).cuda()))

        # ---- train D ----
        netD.zero_grad()
        inputv = Variable(real_cpu.cuda())
        # inputv = add_noise(inputv, d_fake_save)
        d_real = netD(inputv, feature_real, feature_real)
        d_real_mean = d_real.mean().item()

        noise.resize_(batch_size, nz, 1, 1).normal_(0, 1)
        fake = netG(Variable(noise), feature_real)

        # Features of the generated faces. They are computed from fake.data through PIL,
        # so no gradient reaches netG along this path; G learns only via netD's image input.
        with torch.no_grad():
            feature_fake = feature_net(Variable(transform_batch(fake.data.cpu()).cuda()))

        d_fake = netD(fake.detach(), feature_real, feature_fake)
        d_fake_save = d_fake.detach()  # keep the value without holding D's autograd graph
        d_fake_mean = d_fake.mean().item()

        loss_d = criterion(d_real, target_real_smooth) + criterion(d_fake, target_fake)
        loss_d.backward()
        optimizerD.step()

        # ---- train G ----
        netG.zero_grad()
        d_fake = netD(fake, feature_real, feature_fake)
        loss_g = criterion(d_fake, target_real)
        loss_g.backward()
        optimizerG.step()

        if i % 500 == 0:
            # Samples from the fixed noise, conditioned on this batch's real-face features.
            # Slice the noise so both inputs have the same number of rows on a short batch.
            with torch.no_grad():
                samples = netG(fixed_noise[:batch_size], feature_real)
            vutils.save_image(samples.data,
                              '%s/fake_samples_epoch_%03dstep_%04d.png' % (outf, epoch, i),
                              normalize=True)
            print('[%d/%d][%d/%d] Loss_D: %.4f Loss_G: %.4f D(x): %.4f D(G(z)): %.4f'
                  % (epoch, niter, i, len(dataloader),
                     loss_d.item(), loss_g.item(), d_real_mean, d_fake_mean))

    torch.save(netG.state_dict(), '%s/netG_epoch_%03d.pth' % (outf, epoch))
    torch.save(netD.state_dict(), '%s/netD_epoch_%03d.pth' % (outf, epoch))

RuntimeError: size mismatch, m1: [64 x 2048], m2: [512 x 1] at c:\programdata\miniconda3\conda-bld\pytorch_1524549877902\work\aten\src\thc\generic/THCTensorMathBlas.cu:249