In [1]:
import argparse
import math
import os
import random

from tqdm import tqdm
import numpy as np
from PIL import Image

import torch
from torch import nn, optim
from torch.nn import functional as F
from torch.autograd import Variable, grad
from torch.utils.data import DataLoader
from torchvision import datasets, transforms, utils

from UTKFaceDataset import UTKFaceDataset

from model_face_UTK import StyledGenerator, Discriminator


def requires_grad(model, flag=True):
    """Turn gradient tracking on (flag=True) or off (flag=False) for all of `model`'s parameters.

    The training loop uses this to freeze one network while the other one is updated.
    """
    params = model.parameters()
    for tensor in params:
        tensor.requires_grad = flag


def accumulate(model1, model2, decay=0.999):
    """Blend model2's parameters into model1 in place as an exponential moving average.

    For every parameter name of model1: p1 <- decay * p1 + (1 - decay) * p2.
    model2 is not modified. Both models must have matching parameter names and shapes;
    a name present in model1 but missing from model2 raises KeyError.

    Args:
        model1: module receiving the running average (e.g. an EMA copy of the generator).
        model2: module whose current parameters are blended in.
        decay: weight kept from model1's current value.
    """
    par1 = dict(model1.named_parameters())
    par2 = dict(model2.named_parameters())

    for k, p1 in par1.items():
        # `.data` keeps the in-place update out of autograd. The keyword `alpha=` form
        # replaces the deprecated positional add_(scalar, tensor) overload.
        p1.data.mul_(decay).add_(par2[k].data, alpha=1 - decay)


def sample_data(PATH_IMG, batch_size, image_size=32, num_workers=16):
    """Build a shuffling DataLoader over the UTKFace images in `PATH_IMG`.

    Each image is resized so its shorter side equals `image_size`, center-cropped to
    `image_size` x `image_size`, and randomly flipped horizontally. It is then converted
    to a tensor and normalized from [0, 1] to [-1, 1].

    Args:
        PATH_IMG: image directory, forwarded to UTKFaceDataset.
        batch_size: samples per batch.
        image_size: output side length in pixels.
        num_workers: DataLoader worker processes (previously hard-coded to 16).

    Returns:
        A DataLoader over UTKFaceDataset. The training loop unpacks each batch as
        (images, labels).
    """
    transform = transforms.Compose(
        [
            transforms.Resize(image_size),
            transforms.CenterCrop(image_size),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            # mean 0.5 / std 0.5 per channel maps [0, 1] to [-1, 1].
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ]
    )

    dataset = UTKFaceDataset(PATH_IMG, transform=transform)
    return DataLoader(
        dataset, shuffle=True, batch_size=batch_size, num_workers=num_workers
    )

def adjust_lr(optimizer, lr):
    """Set every param group's learning rate to `lr` times that group's optional 'mult'.

    Groups without a 'mult' key receive `lr` unchanged. The multiplier lets a group
    (here the style network) keep a fixed fraction of the main learning rate when
    the rate changes.
    """
    for param_group in optimizer.param_groups:
        scale = param_group['mult'] if 'mult' in param_group else 1
        param_group['lr'] = lr * scale


### Parameters, models and optimizers

In [3]:
# Run configuration, model construction and optimizers.
code_size = 512  # generator input size: tiled one-hot age label + random z (z gets the remainder)
batch_size = 16  # NOTE(review): not read by the visible cells; batch sizes come from args.batch
n_critic = 1  # generator is updated once every n_critic discriminator steps

# Plain attribute container used in place of an argparse.Namespace.
class Args:
    n_gpu = 4  # NOTE(review): not read in the visible cells; DataParallel below uses all visible GPUs
    phase = 600_000  # images per fade-in phase; alpha ramps 0 -> 1 over this many images
    lr = 0.001  # base lr used to build the optimizers below; rebound to a per-resolution dict further down
    init_size = 8  # starting resolution; the training cell derives step = log2(init_size) - 3
    max_size = 128  # final resolution; training stops growing at step = log2(max_size) - 3
    mixing = False  # style mixing is not used with the age-conditioned generator input
    loss = 'wgan-gp'  # discriminator/generator objective: 'wgan-gp' or 'r1'
    data = 'folder'  # NOTE(review): not read in the visible cells
    path = '/home/quang/working/dataset/dataFace/'  # absolute local dataset path; adjust per machine
    sched = None  # truthy -> use the large-batch table below
    
args = Args()
# Earlier dict-style version of the config (kept for reference only):
# args = {'n_gpu': 4, '': 600_000, 'lr': 0.001, 'init_size': 64, 'max_size': 1024, 'mixing': False, 'loss': 'wgan-gp', 'data': 'folder'}

# Wrap both networks in DataParallel and move them to the GPU.
generator = nn.DataParallel(StyledGenerator(code_size)).cuda()
discriminator = nn.DataParallel(Discriminator()).cuda()

class_loss = nn.CrossEntropyLoss()  # NOTE(review): not used by the visible training loop

# Adam with beta1 = 0. The synthesis part (generator.module.generator) trains at args.lr.
# The style part (generator.module.style, presumably the StyleGAN mapping network)
# gets 1/100 of it. The custom 'mult' key is read by adjust_lr, so the ratio
# survives later per-resolution lr changes.
g_optimizer = optim.Adam(
    generator.module.generator.parameters(), lr=args.lr, betas=(0.0, 0.99)
)    
g_optimizer.add_param_group(
    {
        'params': generator.module.style.parameters(),
        'lr': args.lr * 0.01,
        'mult': 0.01,
    }
)

d_optimizer = optim.Adam(discriminator.parameters(), lr=args.lr, betas=(0.0, 0.99))

# Per-resolution learning rates and batch sizes, keyed by image resolution.
# NOTE: from here on args.lr is a dict, not the float the optimizers were built with.
# Both branches share the same lr table; only the batch-size table differs.
if args.sched:
    args.lr = {128: 0.0015, 256: 0.002, 512: 0.003, 1024: 0.003}
    args.batch = {4: 512, 8: 256, 16: 128, 32: 64, 64: 32, 128: 32, 256: 32}

else:
    args.lr = {128: 0.0015, 256: 0.002, 512: 0.003, 1024: 0.003}
    args.batch = {4: 32, 8: 32, 16: 32, 32: 32, 64: 32, 128: 16, 256: 8}

# Sample-grid shape per resolution; other resolutions fall back to a default in the training cell.
# NOTE(review): the sampler treats the first value as the number of age classes to render,
# and only 6 age classes exist. The 8 looks wrong (it only matters once max_size reaches 512).
args.gen_sample = {512: (8, 4), 1024: (4, 2)}

args.batch_default = 32  # fallback batch size for resolutions missing from args.batch

### Load model (optional: resume from a saved checkpoint)

In [4]:
# Load trained model (optional). Uncomment to resume from a saved checkpoint.
# The checkpoint is a dict with keys 'generator', 'discriminator', 'g_optimizer' and
# 'd_optimizer', matching what the training cell saves with torch.save.
# NOTE(review): the training cell writes checkpoints under 'checkpoint/', not
# './save_model/'; point the path at an existing file.
# When resuming, also set args.init_size (which sets the starting step/resolution) and
# used_sample in the training cell so the progressive-growing schedule continues.

# checkpoint = torch.load('./save_model/train_step-4-159999_128.model')
# generator.module.load_state_dict(checkpoint['generator'])
# discriminator.module.load_state_dict(checkpoint['discriminator'])
# g_optimizer.load_state_dict(checkpoint['g_optimizer'])
# d_optimizer.load_state_dict(checkpoint['d_optimizer'])

### Start training

In [None]:
# Progressive-growing training loop.
# Training starts at args.init_size and doubles the resolution (up to args.max_size)
# after every 2 * args.phase images. During the first args.phase images of each
# resolution, alpha ramps from 0 to 1 to fade in the newly added layers.
if args.loss not in ('wgan-gp', 'r1'):
    raise ValueError(f"args.loss must be 'wgan-gp' or 'r1', got {args.loss!r}")

max_step = int(math.log2(args.max_size)) - 3
step = int(math.log2(args.init_size)) - 3
resolution = 8 * 2 ** step

n_class_age = 6  # number of age classes in the dataset labels
n_repeat = 2  # one-hot age label is tiled n_repeat times in the generator input
z_size = code_size - n_class_age * n_repeat  # random part of the generator input
n_iter = 500_000
checkpoint_every = 10_000
sample_every = 500

# Create the output directories up front; otherwise torch.save / save_image fail hours into the run.
os.makedirs('checkpoint', exist_ok=True)
os.makedirs('sample', exist_ok=True)


def make_loader(resolution):
    """Build the training DataLoader for `resolution` with that resolution's batch size."""
    return sample_data(
        args.path, args.batch.get(resolution, args.batch_default), resolution
    )


def set_resolution_lr(resolution):
    """Apply the per-resolution learning rate (default 0.001) to both optimizers."""
    lr = args.lr.get(resolution, 0.001)
    adjust_lr(g_optimizer, lr)
    adjust_lr(d_optimizer, lr)


def print_phase_info(resolution, step):
    """Print the current resolution, step, batch size and generator/style learning rates."""
    print('Resolution: ', resolution, '|Step: ', step, '|Batch_size: ', args.batch.get(resolution, args.batch_default), ' |Generator lr: ',
          g_optimizer.param_groups[0]['lr'], ' |Style lr: ', g_optimizer.param_groups[1]['lr'])


def save_checkpoint(path):
    """Save both networks' weights and both optimizer states to `path`."""
    torch.save(
        {
            'generator': generator.module.state_dict(),
            'discriminator': discriminator.module.state_dict(),
            'g_optimizer': g_optimizer.state_dict(),
            'd_optimizer': d_optimizer.state_dict(),
        },
        path,
    )


def age_condition(label):
    """Convert integer age labels into (label_age, label_code), both on the GPU.

    label_age is the (b, n_class_age) float one-hot passed to the discriminator.
    label_code is that one-hot tiled n_repeat times along dim 1, used as the
    conditioning prefix of the generator input.
    Assumes `label` is an int64 tensor of shape (b,), the same requirement the
    previous scatter_-based encoding had.
    """
    label_age = F.one_hot(label, n_class_age).float().cuda()
    return label_age, label_age.repeat(1, n_repeat)


def save_samples(i, resolution, step, alpha):
    """Save a sample grid: one row per age class, sharing the same z across rows."""
    gen_i, gen_j = args.gen_sample.get(resolution, (n_class_age, 10))
    # Rows index age classes; never request more rows than there are classes.
    gen_i = min(gen_i, n_class_age)
    random_z = torch.randn(gen_j, z_size).cuda()
    images = []
    with torch.no_grad():
        for age_code in range(gen_i):
            ages = torch.full((gen_j,), age_code, dtype=torch.long)
            label_test_age = F.one_hot(ages, n_class_age).float().cuda().repeat(1, n_repeat)
            gen_test = torch.cat((label_test_age, random_z), 1)
            images.append(generator(gen_test, step=step, alpha=alpha))
    utils.save_image(
        torch.cat(images, 0),
        f'sample/{str(i + 1).zfill(6)}.png',
        nrow=gen_j,
        normalize=True,
        range=(-1, 1),  # newer torchvision versions rename this argument to `value_range`
    )


loader = make_loader(resolution)
data_loader = iter(loader)
set_resolution_lr(resolution)

pbar = tqdm(range(n_iter))

# Start in discriminator-training mode; the generator is unfrozen only for its own step.
requires_grad(generator, False)
requires_grad(discriminator, True)

disc_loss_val = 0
gen_loss_val = 0
grad_loss_val = 0

alpha = 0
used_sample = 0
# used_sample = 9972*16  # set to the images already seen when resuming mid-phase

print_phase_info(resolution, step)
for i in pbar:
    discriminator.zero_grad()

    # Fade-in weight of the newest layers: grows linearly over args.phase images, capped at 1.
    alpha = min(1, 1 / args.phase * (used_sample + 1))

    if (i + 1) % checkpoint_every == 0:
        save_checkpoint(f'checkpoint/train_step-{step}-{i}-{alpha}.model')

    # After fade-in plus an equally long stabilization period, grow to the next resolution.
    if used_sample > args.phase * 2 and step < max_step:
        step += 1
        alpha = 0
        used_sample = 0
        resolution = 8 * 2 ** step

        loader = make_loader(resolution)
        data_loader = iter(loader)

        save_checkpoint(f'checkpoint/train_step-{step}.model')

        set_resolution_lr(resolution)
        print_phase_info(resolution, step)

    try:
        real_image, label = next(data_loader)
    except (OSError, StopIteration):
        # End of epoch (or a worker I/O error): restart the loader and retry once.
        data_loader = iter(loader)
        real_image, label = next(data_loader)

    b_size = real_image.size(0)
    used_sample += b_size
    real_image = real_image.cuda()

    label_age, label_code = age_condition(label)

    # ---- Discriminator: real images ----
    if args.loss == 'wgan-gp':
        real_predict = discriminator(real_image, label_age=label_age, step=step, alpha=alpha)
        # Small drift penalty keeps critic outputs from wandering far from zero.
        real_predict = real_predict.mean() - 0.001 * (real_predict ** 2).mean()
        (-real_predict).backward()

    elif args.loss == 'r1':
        real_image.requires_grad = True
        real_scores = discriminator(real_image, label_age=label_age, step=step, alpha=alpha)
        real_predict = F.softplus(-real_scores).mean()
        real_predict.backward(retain_graph=True)

        # R1 penalty is the squared gradient norm of the raw critic scores (not of the
        # softplus loss) with respect to the real images.
        grad_real = grad(
            outputs=real_scores.sum(), inputs=real_image, create_graph=True
        )[0]
        grad_penalty = (
            grad_real.view(grad_real.size(0), -1).norm(2, dim=1) ** 2
        ).mean()
        grad_penalty = 10 / 2 * grad_penalty
        grad_penalty.backward()
        grad_loss_val = grad_penalty.item()

    # Generator inputs: [tiled age one-hot | z]. Style mixing (args.mixing) is not
    # supported with this conditioned input, so a single latent per image is used.
    gen_in1, gen_in2 = torch.randn(2, b_size, z_size, device='cuda').chunk(2, 0)
    gen_in1 = torch.cat((label_code, gen_in1.squeeze(0)), 1)
    gen_in2 = torch.cat((label_code, gen_in2.squeeze(0)), 1)

    # ---- Discriminator: fake images ----
    fake_image = generator(gen_in1, step=step, alpha=alpha)
    fake_predict = discriminator(fake_image, label_age=label_age, step=step, alpha=alpha)

    if args.loss == 'wgan-gp':
        fake_predict = fake_predict.mean()
        fake_predict.backward()

        # WGAN-GP: push the critic's gradient norm toward 1 on random real/fake interpolates.
        eps = torch.rand(b_size, 1, 1, 1).cuda()
        x_hat = eps * real_image.data + (1 - eps) * fake_image.data
        x_hat.requires_grad = True
        hat_predict = discriminator(x_hat, label_age=label_age, step=step, alpha=alpha)
        grad_x_hat = grad(
            outputs=hat_predict.sum(), inputs=x_hat, create_graph=True
        )[0]
        grad_penalty = (
            (grad_x_hat.view(grad_x_hat.size(0), -1).norm(2, dim=1) - 1) ** 2
        ).mean()
        grad_penalty = 10 * grad_penalty
        grad_penalty.backward()
        grad_loss_val = grad_penalty.item()
        disc_loss_val = (real_predict - fake_predict).item()

    elif args.loss == 'r1':
        fake_predict = F.softplus(fake_predict).mean()
        fake_predict.backward()
        disc_loss_val = (real_predict + fake_predict).item()

    d_optimizer.step()

    # ---- Generator step (once every n_critic discriminator steps) ----
    if (i + 1) % n_critic == 0:
        generator.zero_grad()

        requires_grad(generator, True)
        requires_grad(discriminator, False)

        fake_image = generator(gen_in2, step=step, alpha=alpha)
        predict = discriminator(fake_image, label_age=label_age, step=step, alpha=alpha)

        if args.loss == 'wgan-gp':
            loss = -predict.mean()
        elif args.loss == 'r1':
            loss = F.softplus(-predict).mean()

        gen_loss_val = loss.item()

        loss.backward()
        g_optimizer.step()

        requires_grad(generator, False)
        requires_grad(discriminator, True)

    if (i + 1) % sample_every == 0:
        save_samples(i, resolution, step, alpha)

    pbar.set_description(
        f'Size: {8 * 2 ** step}; G: {gen_loss_val:.3f}; D: {disc_loss_val:.3f};'
        f' Grad: {grad_loss_val:.3f}; Alpha: {alpha:.5f}'
    )

  0%|          | 0/500000 [00:00<?, ?it/s]

Resolution:  128 |Step:  4 |Batch_size:  16  |Generator lr:  0.0015  |Style lr:  1.5e-05


Size: 128; G: 0.636; D: 0.653; Grad: 0.046; Alpha: 1.00000:  24%|██▎       | 118064/500000 [30:04:16<99:59:36,  1.06it/s]