In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import os
import time
import datetime
from torch.autograd import grad
from torch.autograd import Variable
from torchvision.utils import save_image
from torchvision import transforms
from PIL import Image
import argparse
from torch.backends import cudnn
from torchvision.datasets import ImageFolder
from torch.utils.data import Dataset
from torch.utils.data import DataLoader

In [3]:
class ResidualBlock(nn.Module):
    """Residual unit computing ``x + body(x)``.

    ``body`` is conv3x3 -> InstanceNorm -> LeakyReLU(0.01) -> Dropout ->
    conv3x3 -> InstanceNorm. Both convolutions use stride 1 / padding 1, so the
    spatial size is preserved; the identity shortcut therefore needs
    ``dim_in == dim_out``.
    """

    def __init__(self, dim_in, dim_out):
        super(ResidualBlock, self).__init__()
        same_size = dict(kernel_size=3, stride=1, padding=1, bias=False)
        body = [
            nn.Conv2d(dim_in, dim_out, **same_size),
            nn.InstanceNorm2d(dim_out, affine=True),
            nn.LeakyReLU(0.01, inplace=True),
            nn.Dropout(),
            nn.Conv2d(dim_out, dim_out, **same_size),
            nn.InstanceNorm2d(dim_out, affine=True),
        ]
        # The attribute name and layer order define the state_dict keys
        # (main.0, main.1, main.4, main.5), so saved checkpoints depend on them.
        self.main = nn.Sequential(*body)

    def forward(self, x):
        residual = self.main(x)
        return x + residual

In [4]:
class Generator(nn.Module):
    """Image-to-image generator conditioned on a target-domain code.

    The condition ``c`` is tiled over the image grid and concatenated with the
    RGB input. The network is a 7x7 stem, two stride-2 downsampling convs,
    ``repeat_num`` residual blocks, two stride-2 transposed convs and a 7x7
    conv + tanh. Matching encoder/decoder stages are joined by additive skips,
    and the input image is added to the output.
    """

    @staticmethod
    def _conv_norm_act(conv):
        """Follow ``conv`` with InstanceNorm (affine) and LeakyReLU(0.01)."""
        return nn.Sequential(conv,
                             nn.InstanceNorm2d(conv.out_channels, affine=True),
                             nn.LeakyReLU(0.01, inplace=True))

    def __init__(self, conv_dim=64, c_dim=5, repeat_num=6):
        super(Generator, self).__init__()
        wide = dict(kernel_size=7, stride=1, padding=3, bias=False)
        halve = dict(kernel_size=4, stride=2, padding=1, bias=False)

        # Encoder (construction order kept: it fixes parameter init under a seed).
        self.conv1 = self._conv_norm_act(nn.Conv2d(3 + c_dim, conv_dim, **wide))
        self.conv2 = self._conv_norm_act(nn.Conv2d(conv_dim, conv_dim * 2, **halve))
        self.conv3 = self._conv_norm_act(nn.Conv2d(conv_dim * 2, conv_dim * 4, **halve))

        # Bottleneck.
        bottleneck = conv_dim * 4
        self.main_res = nn.Sequential(
            *[ResidualBlock(dim_in=bottleneck, dim_out=bottleneck) for _ in range(repeat_num)])

        # Decoder.
        self.dconv3 = self._conv_norm_act(nn.ConvTranspose2d(bottleneck, bottleneck // 2, **halve))
        self.dconv2 = self._conv_norm_act(nn.ConvTranspose2d(bottleneck // 2, bottleneck // 4, **halve))
        self.dconv1 = nn.Sequential(nn.Conv2d(bottleneck // 4, 3, **wide), nn.Tanh())

    def forward(self, x, c):
        """Translate ``x`` towards condition ``c``.

        ``c`` is a per-sample condition vector; Solver passes one-hot codes of
        shape (batch, c_dim).
        """
        c_map = c[:, :, None, None].expand(c.size(0), c.size(1), x.size(2), x.size(3))
        stacked = torch.cat([x, c_map], dim=1)

        enc1 = self.conv1(stacked)
        enc2 = self.conv2(enc1)
        enc3 = self.conv3(enc2)
        dec3 = self.main_res(enc3) + enc3
        dec2 = self.dconv3(dec3) + enc2
        dec1 = self.dconv2(dec2) + enc1
        return self.dconv1(dec1) + x

In [5]:
class Discriminator(nn.Module):
    """Critic with an auxiliary domain classifier.

    ``main`` stacks ``repeat_num`` 4x4 stride-2 convolutions (channels doubling
    from ``conv_dim``), each followed by LeakyReLU(0.01). Two heads read the
    final feature map:
    - ``conv1`` (3x3, one channel) produces the real/fake score map.
    - ``conv2`` (kernel ``image_size / 2**repeat_num``) produces ``c_dim`` class
      logits.
    """

    def __init__(self, image_size=128, conv_dim=64, c_dim=5, repeat_num=6):
        super(Discriminator, self).__init__()

        self.repeat_num = repeat_num

        layers = []
        layers.append(nn.Conv2d(3, conv_dim, kernel_size=4, stride=2, padding=1))
        layers.append(nn.LeakyReLU(0.01, inplace=True))

        curr_dim = conv_dim
        for i in range(1, repeat_num):
            layers.append(nn.Conv2d(curr_dim, curr_dim*2, kernel_size=4, stride=2, padding=1))
            layers.append(nn.LeakyReLU(0.01, inplace=True))
            curr_dim = curr_dim * 2

        # Expected spatial size after repeat_num halvings; conv2 covers all of it.
        k_size = int(image_size / np.power(2, repeat_num))
        self.main = nn.Sequential(*layers)
        self.conv1 = nn.Conv2d(curr_dim, 1, kernel_size=3, stride=1, padding=1, bias=False)
        self.conv2 = nn.Conv2d(curr_dim, c_dim, kernel_size=k_size, bias=False)

    @staticmethod
    def _squeeze_keep_batch(t):
        """Drop size-1 dims like ``t.squeeze()``, but never dim 0.

        A plain ``squeeze()`` turns a batch of one into an unbatched tensor
        (e.g. logits of shape ``(c_dim,)``). ``F.cross_entropy`` then rejects
        the ``(1,)`` target, which breaks training on a trailing batch of one.
        """
        return t.reshape(t.size(0), *[s for s in t.shape[1:] if s != 1])

    def forward(self, x):
        """Score ``x`` and classify its domain.

        Returns:
            out_real: critic scores with non-batch size-1 dims removed
                (``(batch,)`` when the final map is 1x1).
            out_aux: class logits, typically ``(batch, c_dim)``.
            out_feats: activations after each of the first four downsampling
                stages.
        """
        # Iterate the existing (conv, activation) pairs directly instead of
        # wrapping each pair in a new nn.Sequential on every call.
        stages = list(self.main.children())
        h = x
        out_feats = []
        for i in range(self.repeat_num):
            conv, act = stages[2 * i], stages[2 * i + 1]
            h = act(conv(h))
            if i < 4:
                out_feats.append(self._squeeze_keep_batch(h))
        out_real = self.conv1(h)
        out_aux = self.conv2(h)
        return self._squeeze_keep_batch(out_real), self._squeeze_keep_batch(out_aux), out_feats

In [6]:
class Solver(object):
    """Train/test driver for a StarGAN-style expression translation model.

    Holds the generator ``G``, the discriminator ``D`` and their Adam optimizers.

    ``train`` alternates two kinds of update:
    - Discriminator: an adversarial + auxiliary-classification update, followed
      by a separate WGAN-GP gradient-penalty update.
    - Generator: adversarial + classification + cycle-reconstruction loss.

    ``test`` translates every loader image into each of the ``c_dim`` classes
    and writes the results under ``result_path``.
    """

    def __init__(self, fer2013_loader, config):
        """Store hyper-parameters from ``config`` and build the networks.

        Args:
            fer2013_loader: iterable of ``(images, labels)`` batches, used by both
                ``train`` and ``test``.
            config: object exposing the attributes read below (e.g. ``Config``).
        """
        self.fer2013_loader = fer2013_loader

        # Model configuration.
        self.c_dim = config.c_dim
        self.c2_dim = config.c2_dim
        self.image_size = config.image_size
        self.g_conv_dim = config.g_conv_dim
        self.d_conv_dim = config.d_conv_dim
        self.g_repeat_num = config.g_repeat_num
        self.d_repeat_num = config.d_repeat_num
        self.d_train_repeat = config.d_train_repeat

        # Loss weights and optimizer settings.
        self.lambda_cls = config.lambda_cls
        self.lambda_rec = config.lambda_rec
        self.lambda_gp = config.lambda_gp
        self.lambda_feat_rec = config.lambda_feat_rec
        self.g_lr = config.g_lr
        self.d_lr = config.d_lr
        self.beta1 = config.beta1
        self.beta2 = config.beta2

        # Training schedule.
        self.dataset = config.dataset
        self.num_epochs = config.num_epochs
        self.num_epochs_decay = config.num_epochs_decay
        self.num_iters = config.num_iters
        self.num_iters_decay = config.num_iters_decay
        self.batch_size = config.batch_size
        # Config objects may not define use_tensorboard (the notebook's Config
        # class does not); reading it unconditionally raised AttributeError.
        self.use_tensorboard = getattr(config, 'use_tensorboard', False)
        self.pretrained_model = config.pretrained_model

        self.test_model = config.test_model

        # Output locations and frequencies.
        self.sample_path = config.sample_path
        self.model_save_path = config.model_save_path
        self.result_path = config.result_path

        self.sample_step = config.sample_step
        self.model_save_step = config.model_save_step

        self.build_model()

        if self.pretrained_model:
            self.load_pretrained_model()

    def build_model(self):
        """Create G, D and their Adam optimizers; move both nets to GPU if available."""
        self.G = Generator(self.g_conv_dim, self.c_dim, self.g_repeat_num)
        self.D = Discriminator(self.image_size, self.d_conv_dim, self.c_dim, self.d_repeat_num)

        self.g_optimizer = torch.optim.Adam(self.G.parameters(), self.g_lr, [self.beta1, self.beta2])
        self.d_optimizer = torch.optim.Adam(self.D.parameters(), self.d_lr, [self.beta1, self.beta2])

        self.print_network(self.G, 'G')
        self.print_network(self.D, 'D')

        if torch.cuda.is_available():
            self.G.cuda()
            self.D.cuda()

    def print_network(self, model, name):
        """Print ``name``, the module tree of ``model`` and its parameter count."""
        num_params = sum(p.numel() for p in model.parameters())
        print(name)
        print(model)
        print("The number of parameters: {}".format(num_params))

    @staticmethod
    def _load_weights(path):
        """Load a state dict from ``path`` with every tensor mapped to CPU.

        ``load_state_dict`` then copies the values into the module's own
        parameters, so GPU-written checkpoints also load on CPU-only hosts.
        """
        return torch.load(path, map_location=lambda storage, loc: storage)

    def load_pretrained_model(self):
        """Restore G and D from '<pretrained_model>_G.pth' / '_D.pth' in model_save_path."""
        self.G.load_state_dict(self._load_weights(os.path.join(
            self.model_save_path, '{}_G.pth'.format(self.pretrained_model))))
        self.D.load_state_dict(self._load_weights(os.path.join(
            self.model_save_path, '{}_D.pth'.format(self.pretrained_model))))
        print('loaded trained models')

    def update_lr(self, g_lr, d_lr):
        """Set the learning rate of every param group of both optimizers."""
        for param_group in self.g_optimizer.param_groups:
            param_group['lr'] = g_lr
        for param_group in self.d_optimizer.param_groups:
            param_group['lr'] = d_lr

    def reset_grad(self):
        """Zero the gradients held by both optimizers."""
        self.g_optimizer.zero_grad()
        self.d_optimizer.zero_grad()

    def to_var(self, x, volatile=False):
        """Move ``x`` to the GPU when one is available.

        ``volatile`` is accepted for backward compatibility only:
        ``Variable(volatile=True)`` has had no effect since PyTorch 0.4.
        Inference code in this class uses ``torch.no_grad()`` instead.
        """
        if torch.cuda.is_available():
            x = x.cuda()
        return x

    def denorm(self, x):
        """Map values from [-1, 1] to [0, 1] (clamping anything outside)."""
        out = (x + 1) / 2
        return out.clamp_(0, 1)

    def threshold(self, x):
        """Return a copy of ``x`` binarised at 0.5 (>= 0.5 -> 1, else 0)."""
        x = x.clone()
        x[x >= 0.5] = 1
        x[x < 0.5] = 0
        return x

    def compute_accuracy(self, x, y, dataset):
        """Percentage of rows of logits ``x`` whose argmax equals label ``y``.

        ``dataset`` is unused; it is kept for interface compatibility.
        """
        _, predicted = torch.max(x, dim=1)
        correct = (predicted == y).float()
        accuracy = torch.mean(correct) * 100.0
        return accuracy

    def one_hot(self, labels, dim):
        """Encode a 1-D tensor of class indices as a (len(labels), dim) float one-hot matrix."""
        batch_size = labels.size(0)
        out = torch.zeros(batch_size, dim)
        out[torch.arange(batch_size), labels.long()] = 1
        return out

    def _fixed_samples(self, num_batches=4):
        """Collect the first ``num_batches`` loader batches plus one constant code per class.

        These are reused by ``train`` to render comparable progress snapshots.
        """
        batches = []
        for i, (images, _) in enumerate(self.data_loader):
            batches.append(images)
            if i + 1 == num_batches:
                break
        fixed_x = self.to_var(torch.cat(batches, dim=0))
        fixed_c_list = [self.to_var(self.one_hot(torch.ones(fixed_x.size(0)) * c, self.c_dim))
                        for c in range(self.c_dim)]
        return fixed_x, fixed_c_list

    def _resumed_lr(self, start):
        """Learning rates in effect at the start of epoch ``start``.

        Replays the per-epoch linear decay that ``train`` applies at the end of
        every completed epoch, so resuming inside the decay window does not
        restart at the full learning rate.
        """
        g_lr, d_lr = self.g_lr, self.d_lr
        if self.num_epochs_decay <= 0:
            return g_lr, d_lr
        decay_after = self.num_epochs - self.num_epochs_decay
        for e in range(min(start, self.num_epochs)):
            if (e + 1) > decay_after:
                g_lr -= (self.g_lr / float(self.num_epochs_decay))
                d_lr -= (self.d_lr / float(self.num_epochs_decay))
        return g_lr, d_lr

    def _gradient_penalty(self, real_x, fake_x):
        """WGAN-GP term: mean over the batch of (||dD(x_hat)/dx_hat||_2 - 1)^2.

        ``x_hat`` is a random per-sample interpolation between real and fake.
        Random weights are created on ``real_x``'s device, so this also runs
        without CUDA.
        """
        alpha = torch.rand(real_x.size(0), 1, 1, 1, device=real_x.device).expand_as(real_x)
        interpolated = (alpha * real_x.detach() + (1 - alpha) * fake_x.detach()).requires_grad_(True)
        out, _, _ = self.D(interpolated)

        grad = torch.autograd.grad(outputs=out,
                                   inputs=interpolated,
                                   grad_outputs=torch.ones_like(out),
                                   retain_graph=True,
                                   create_graph=True)[0]

        grad = grad.view(grad.size(0), -1)
        grad_l2norm = torch.sqrt(torch.sum(grad ** 2, dim=1))
        return torch.mean((grad_l2norm - 1) ** 2)

    def _train_discriminator(self, real_x, real_label, fake_c):
        """Run one adversarial+classification D update, then one gradient-penalty D update."""
        out_src, out_cls, _ = self.D(real_x)
        d_loss_real = - torch.mean(out_src)
        d_loss_cls = F.cross_entropy(out_cls, real_label)

        # G is not updated in this step, so its forward pass needs no graph.
        with torch.no_grad():
            fake_x = self.G(real_x, fake_c)
        out_src, _, _ = self.D(fake_x)
        d_loss_fake = torch.mean(out_src)

        d_loss = d_loss_real + d_loss_fake + self.lambda_cls * d_loss_cls
        self.reset_grad()
        d_loss.backward()
        self.d_optimizer.step()

        d_loss_gp = self._gradient_penalty(real_x, fake_x)
        self.reset_grad()
        (self.lambda_gp * d_loss_gp).backward()
        self.d_optimizer.step()

    def _train_generator(self, real_x, real_c, fake_c, fake_label):
        """Run one G update: fool D, hit the target class, and reconstruct the input."""
        fake_x = self.G(real_x, fake_c)
        rec_x = self.G(fake_x, real_c)

        out_src, out_cls, _ = self.D(fake_x)
        g_loss_fake = - torch.mean(out_src)
        g_loss_cls = F.cross_entropy(out_cls, fake_label)
        g_loss_rec = torch.mean(torch.abs(real_x - rec_x))

        g_loss = g_loss_fake + self.lambda_cls * g_loss_cls + self.lambda_rec * g_loss_rec
        self.reset_grad()
        g_loss.backward()
        self.g_optimizer.step()

    def _save_samples(self, fixed_x, fixed_c_list, e, i):
        """Save a grid of each fixed input followed by its translation into every class."""
        with torch.no_grad():
            fake_image_list = [fixed_x]
            for fixed_c in fixed_c_list:
                fake_image_list.append(self.G(fixed_x, fixed_c))
            fake_images = torch.cat(fake_image_list, dim=3)
        save_image(self.denorm(fake_images.data),
            os.path.join(self.sample_path, '{}_{}_fake.png'.format(e+1, i+1)), nrow=1, padding=0)
        print('Images saved into {}..!'.format(self.sample_path))

    def train(self):
        """Train from epoch ``start`` up to ``num_epochs``.

        On each iteration:
        - D is updated.
        - G is updated every ``d_train_repeat`` iterations.
        - A sample grid is saved every ``sample_step`` iterations, and both
          networks every ``model_save_step`` iterations.

        Learning rates decay linearly over the last ``num_epochs_decay`` epochs.
        """
        self.data_loader = self.fer2013_loader
        iters_per_epoch = len(self.data_loader)

        fixed_x, fixed_c_list = self._fixed_samples()

        # Checkpoints are named '<epoch+1>_<iter+1>', so this resumes at the next epoch.
        if self.pretrained_model:
            start = int(self.pretrained_model.split('_')[0])
        else:
            start = 0

        g_lr, d_lr = self._resumed_lr(start)
        self.update_lr(g_lr, d_lr)

        for e in range(start, self.num_epochs):
            print("LEN: " + str(iters_per_epoch))
            for i, (real_x, real_label) in enumerate(self.data_loader):
                # Target domains are a random permutation of the batch's own labels.
                rand_idx = torch.randperm(real_label.size(0))
                fake_label = real_label[rand_idx]

                real_c = self.to_var(self.one_hot(real_label, self.c_dim))
                fake_c = self.to_var(self.one_hot(fake_label, self.c_dim))
                real_x = self.to_var(real_x)
                real_label = self.to_var(real_label)
                fake_label = self.to_var(fake_label)

                self._train_discriminator(real_x, real_label, fake_c)

                if (i+1) % self.d_train_repeat == 0:
                    self._train_generator(real_x, real_c, fake_c, fake_label)

                if (i+1) % self.sample_step == 0:
                    self._save_samples(fixed_x, fixed_c_list, e, i)

                if (i+1) % self.model_save_step == 0:
                    torch.save(self.G.state_dict(),
                        os.path.join(self.model_save_path, '{}_{}_G.pth'.format(e+1, i+1)))
                    torch.save(self.D.state_dict(),
                        os.path.join(self.model_save_path, '{}_{}_D.pth'.format(e+1, i+1)))

            if (e+1) > (self.num_epochs - self.num_epochs_decay):
                g_lr -= (self.g_lr / float(self.num_epochs_decay))
                d_lr -= (self.d_lr / float(self.num_epochs_decay))
                self.update_lr(g_lr, d_lr)
                print ('Decay learning rate to g_lr: {}, d_lr: {}.'.format(g_lr, d_lr))

    def test(self):
        """Translate every loader image into each of the ``c_dim`` classes.

        Loads '<test_model>_G.pth' and writes two kinds of output:
        - Each translation to
          ``result_path/fer2013/test/<class name>/<index>_fake.png``.
        - Per batch, a grid (each input followed by its translations) to
          ``result_path/<batch>_fake.png``.
        """
        G_path = os.path.join(self.model_save_path, '{}_G.pth'.format(self.test_model))
        self.G.load_state_dict(self._load_weights(G_path))
        self.G.eval()

        data_loader = self.fer2013_loader

        # NOTE(review): four names are hard-coded while the notebook configs use
        # c_dim=2; class index j is written to label_name_list[j]. Confirm this
        # matches the class order the model was trained with.
        label_name_list = ['Angry', 'Happy', 'Sad', 'Surprise']
        for label_name in label_name_list:
            label_path = os.path.join(self.result_path, 'fer2013', 'test', label_name)
            if not os.path.exists(label_path):
                os.makedirs(label_path)

        # Inference only: no autograd bookkeeping (replaces the removed volatile flag).
        with torch.no_grad():
            for i, (real_x, org_c) in enumerate(data_loader):
                real_x = self.to_var(real_x)
                batch = real_x.size(0)

                fake_image_list = [real_x]
                for j in range(self.c_dim):
                    target_c = self.to_var(self.one_hot(torch.ones(batch) * j, self.c_dim))
                    fake_x = self.G(real_x, target_c)
                    fake_image_list.append(fake_x)

                    # Every row of target_c is the one-hot code of class j.
                    for k in range(batch):
                        save_path = os.path.join(self.result_path, 'fer2013', 'test', label_name_list[j], '{}_fake.png'.format(i*batch+k))
                        save_image(self.denorm(fake_x.data[k,:,:,:]), save_path, nrow=1, padding=0)
                fake_images = torch.cat(fake_image_list, dim=3)
                save_path = os.path.join(self.result_path, '{}_fake.png'.format(i+1))
                save_image(self.denorm(fake_images.data), save_path, nrow=1, padding=0)
                print('Images saved into "{}"..!'.format(save_path))


In [7]:
def get_loader(image_path, metadata_path, crop_size, image_size, batch_size, dataset='fer2013', mode='train'):
    """Build a ``DataLoader`` over an ``ImageFolder`` rooted at ``image_path``.

    Images are resized with ``Resize(image_size)``, converted to tensors and
    normalised with mean/std 0.5 per channel (mapping [0, 1] to [-1, 1]).
    In 'train' mode, images are also randomly flipped horizontally and the
    loader shuffles.

    ``metadata_path``, ``crop_size`` and ``dataset`` are accepted for interface
    compatibility but are not used.
    """
    is_train = mode == 'train'

    steps = [transforms.Resize(image_size)]
    if is_train:
        steps.append(transforms.RandomHorizontalFlip())
    steps += [transforms.ToTensor(),
              transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]

    image_set = ImageFolder(image_path, transforms.Compose(steps))
    return DataLoader(dataset=image_set, batch_size=batch_size, shuffle=is_train)

In [8]:
def str2bool(v):
    """Parse a flag string: True only for 'true' (case-insensitive), else False."""
    # Exact comparison: `in ('true')` is a substring test against the string
    # 'true' (no tuple without a trailing comma), so 't', 'rue' or '' passed.
    return v.lower() == 'true'

In [9]:
def main(config):
    """Create output directories, build the FER2013 loader and run ``config.mode``.

    'train' calls ``Solver.train`` and 'test' calls ``Solver.test``; any other
    mode only constructs the solver.
    """
    cudnn.benchmark = True

    for directory in (config.log_path, config.model_save_path,
                      config.sample_path, config.result_path):
        if not os.path.exists(directory):
            os.makedirs(directory)

    fer2013_loader = get_loader(config.fer2013_image_path, None, config.fer2013_crop_size,
                                config.image_size, config.batch_size, 'fer2013', config.mode)
    solver = Solver(fer2013_loader, config)

    run = {'train': solver.train, 'test': solver.test}.get(config.mode)
    if run is not None:
        run()

In [10]:
class Config(object):
    """Plain attribute container: every constructor argument is stored unchanged
    as an attribute of the same name."""

    def __init__(self, batch_size, beta1, beta2, c2_dim, c_dim, d_conv_dim, d_lr, d_repeat_num, d_train_repeat, dataset, fer2013_crop_size,
                 fer2013_image_path, g_conv_dim, g_lr, g_repeat_num, image_size, lambda_cls, lambda_feat_rec, lambda_gp, lambda_rec, log_path,
                 log_step, mode, model_save_path, model_save_step, num_epochs, num_epochs_decay, num_iters, num_iters_decay, num_workers,
                 pretrained_model, result_path, sample_path, sample_step, test_model):
        # Data and run mode.
        self.mode = mode
        self.dataset = dataset
        self.fer2013_image_path = fer2013_image_path
        self.fer2013_crop_size = fer2013_crop_size
        self.image_size = image_size
        self.batch_size = batch_size
        self.num_workers = num_workers

        # Model sizes.
        self.c_dim = c_dim
        self.c2_dim = c2_dim
        self.g_conv_dim = g_conv_dim
        self.d_conv_dim = d_conv_dim
        self.g_repeat_num = g_repeat_num
        self.d_repeat_num = d_repeat_num

        # Optimisation.
        self.g_lr = g_lr
        self.d_lr = d_lr
        self.beta1 = beta1
        self.beta2 = beta2
        self.d_train_repeat = d_train_repeat

        # Loss weights.
        self.lambda_cls = lambda_cls
        self.lambda_rec = lambda_rec
        self.lambda_gp = lambda_gp
        self.lambda_feat_rec = lambda_feat_rec

        # Schedule.
        self.num_epochs = num_epochs
        self.num_epochs_decay = num_epochs_decay
        self.num_iters = num_iters
        self.num_iters_decay = num_iters_decay

        # Checkpoints, logging and outputs.
        self.pretrained_model = pretrained_model
        self.test_model = test_model
        self.log_path = log_path
        self.log_step = log_step
        self.model_save_path = model_save_path
        self.model_save_step = model_save_step
        self.sample_path = sample_path
        self.sample_step = sample_step
        self.result_path = result_path

In [11]:
# TRAIN configuration: images from ./fer2013/train.
# pretrained_model='300_1000' makes Solver load '300_1000_G.pth' / '300_1000_D.pth'
# from model_save_path and resume at epoch 300.
config_train = Config(
    # data / run mode
    mode='train', dataset='fer2013', fer2013_image_path='./fer2013/train',
    fer2013_crop_size=48, image_size=64, batch_size=8, num_workers=8,
    # model sizes
    c_dim=2, c2_dim=8, g_conv_dim=64, d_conv_dim=64, g_repeat_num=6, d_repeat_num=6,
    # optimisation
    g_lr=0.0001, d_lr=0.0001, beta1=0.5, beta2=0.999, d_train_repeat=5,
    # loss weights
    lambda_cls=1, lambda_rec=10, lambda_gp=10, lambda_feat_rec=10,
    # schedule
    num_epochs=400, num_epochs_decay=100, num_iters=2000, num_iters_decay=100000,
    # checkpoints, logging and outputs
    pretrained_model='300_1000', test_model='20_1000',
    log_path='./stargan/logs', log_step=50,
    model_save_path='./stargan/models', model_save_step=1000,
    sample_path='./stargan/samples', sample_step=1000,
    result_path='./stargan/results',
)
# main(config_train)  # uncomment to launch training

In [12]:
# TEST configuration: images from ./fer2013/test, one image per batch.
# test_model='400_1000' makes Solver.test load '400_1000_G.pth' from model_save_path.
config_test = Config(
    # data / run mode
    mode='test', dataset='fer2013', fer2013_image_path='./fer2013/test',
    fer2013_crop_size=48, image_size=64, batch_size=1, num_workers=8,
    # model sizes (must match the trained checkpoint)
    c_dim=2, c2_dim=8, g_conv_dim=64, d_conv_dim=64, g_repeat_num=6, d_repeat_num=6,
    # optimisation
    g_lr=0.0001, d_lr=0.0001, beta1=0.5, beta2=0.999, d_train_repeat=5,
    # loss weights
    lambda_cls=1, lambda_rec=10, lambda_gp=10, lambda_feat_rec=10,
    # schedule
    num_epochs=20, num_epochs_decay=10, num_iters=200000, num_iters_decay=100000,
    # checkpoints, logging and outputs
    pretrained_model=None, test_model='400_1000',
    log_path='./stargan/logs', log_step=10,
    model_save_path='./stargan/models', model_save_step=4,
    sample_path='./stargan/samples', sample_step=500,
    result_path='./stargan/results',
)
# main(config_test)  # uncomment to run inference