In [1]:
import easydict
import os
from torch.backends import cudnn
import easydict
import sys
sys.path.append(r'C:\Users\YongTaek\Desktop')
from Solver import Solver
from data_loader import get_loader

def str2bool(v):
    return v.lower() in ('true')

def main(args):
    # For fast training.
    cudnn.benchmark = True

    # Create directories if not exist.
    if not os.path.exists(args['log_dir']):
        os.makedirs(args['log_dir'])
    if not os.path.exists(args['model_save_dir']):
        os.makedirs(args['model_save_dir'])
    if not os.path.exists(args['sample_dir']):
        os.makedirs(args['sample_dir'])
    if not os.path.exists(args['result_dir']):
        os.makedirs(args['result_dir'])

    # Data loader.
    celeba_loader = None
    rafd_loader = None

    celeba_loader = get_loader(args['celeba_image_dir'], args['attr_path'], args['selected_attrs'],
                              args['celeba_crop_size'], args['image_size'], args['batch_size'],
                              'CelebA', args['mode'], args['num_workers'])
    rafd_loader = get_loader(args['rafd_image_dir'], None, None,
                                     args['rafd_crop_size'], args['image_size'], args['batch_size'],
                                     'RaFD', args['mode'], args['num_workers'])
    

    # Solver for training and testing StarGAN.
    solver = Solver(celeba_loader, rafd_loader, args)

    if args['mode'] == 'train':
        if args['dataset'] in ['CelebA', 'RaFD']:
            solver.train()
        elif args['dataset'] in ['Both']:
            solver.train_multi()
    elif args['mode'] == 'test':
        if args['dataset'] in ['CelebA', 'RaFD']:
            solver.test()
        elif args['dataset'] in ['Both']:
            solver.test_multi()
if __name__ == '__main__':
    args = easydict.EasyDict({
        'c_dim' : 5, # dimension of domain labels (1st dataset)
        'c2_dim' : 8, # dimension of domain labels (2nd dataset)
        'celeba_crop_size' : 178, # crop size for the CelebA dataset
        'rafd_crop_size' : 256, # crop size for the RaFD dataset
        'image_size' : 256, # image resolution
        'g_conv_dim' : 64, # number of conv filters in the first layer of G
        'd_conv_dim' : 64, # number of conv filters in the first layer of D
        'g_repeat_num' : 6, # number of conv filters in the first layer of D
        'd_repeat_num' : 6, # number of strided conv layers in D
        'lambda_cls' : 1, # weight for domain classification loss
        'lambda_rec' : 10, # weight for reconstruction loss
        'lambda_gp' : 10, # weight for gradient penalty
        
        # Training configuration.
        'dataset' : 'Both', #choices=['CelebA', ' RaFD', 'Both']
        'batch_size' : 16, # mini-batch size
        'num_iters' : 200000, # number of total iterations for training D
        'num_iters_decay' : 100000, #number of iterations for decaying lr
        'g_lr' : 0.0001, #learning rate for G
        'd_lr' : 0.0001, #learning rate for D
        'n_critic' : 5, # number of D updates per each G update
        'beta1' : 0.5, # beta1 for Adam optimizer
        'beta2' : 0.999, # beta2 for Adam optimizer
        'resume_iters' : None, #resume training from this step
        'selected_attrs' : ['Black_Hair', 'Blond_Hair', 'Brown_Hair', 'Male', 'Young'], #selected attributes for the CelebA dataset
        
        # Test configuration.
        'test_iters' : 200000,
        
        # Miscellaneous.
        'num_workers' : 2,
        'mode' : 'train', #choices = ['train', 'test']
        'use_tensorboard' : True,
        
        # Directories.
        'celeba_image_dir' : r'C:\Users\YongTaek\Desktop\img_align_celeba\img_align_celeba',
        'attr_path' : r'C:\Users\YongTaek\Desktop\GAN\stargan-master\list_attr_celeba.txt',
        'rafd_image_dir' : r'C:\Users\YongTaek\Desktop\finalimage\train',
        'log_dir' : r'C:\Users\YongTaek\Desktop\stargan_both/logs',
        'model_save_dir' : r'C:\Users\YongTaek\Desktop\stargan_both/models',
        'sample_dir' : r'C:\Users\YongTaek\Desktop\stargan_both/samples',
        'result_dir' : r'C:\Users\YongTaek\Desktop\stargan_both/results',
        
        # Step size.
        'log_step' : 10,
        'sample_step' : 1000,
        'model_save_step' : 10000,
        'lr_update_step' : 1000,
        
        # ngpu사용
        'ngpu' : 2
        })
main(args)

Finished preprocessing the CelebA dataset...
DataParallel(
  (module): Generator(
    (main): Sequential(
      (0): Conv2d(9, 15, kernel_size=(7, 7), stride=(1, 1), padding=(3, 3), bias=False)
      (1): InstanceNorm2d(15, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
      (3): Conv2d(15, 30, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
      (4): InstanceNorm2d(30, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (5): ReLU(inplace=True)
      (6): Conv2d(30, 60, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
      (7): InstanceNorm2d(60, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (8): ReLU(inplace=True)
      (9): ResidualBlock(
        (main): Sequential(
          (0): Conv2d(60, 60, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (1): InstanceNorm2d(60, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (

  _np_qint8 = np.dtype([("qint8", np.int8, 1)])
  _np_quint8 = np.dtype([("quint8", np.uint8, 1)])
  _np_qint16 = np.dtype([("qint16", np.int16, 1)])
  _np_quint16 = np.dtype([("quint16", np.uint16, 1)])
  _np_qint32 = np.dtype([("qint32", np.int32, 1)])
  np_resource = np.dtype([("resource", np.ubyte, 1)])


Start training...




RuntimeError: Caught RuntimeError in replica 0 on device 0.
Original Traceback (most recent call last):
  File "C:\Users\YongTaek\Anaconda3\envs\Caps\lib\site-packages\torch\nn\parallel\parallel_apply.py", line 60, in _worker
    output = module(*input, **kwargs)
  File "C:\Users\YongTaek\Anaconda3\envs\Caps\lib\site-packages\torch\nn\modules\module.py", line 547, in __call__
    result = self.forward(*input, **kwargs)
  File "C:\Users\YongTaek\Desktop\model1.py", line 70, in forward
    return self.main(x)
  File "C:\Users\YongTaek\Anaconda3\envs\Caps\lib\site-packages\torch\nn\modules\module.py", line 547, in __call__
    result = self.forward(*input, **kwargs)
  File "C:\Users\YongTaek\Anaconda3\envs\Caps\lib\site-packages\torch\nn\modules\container.py", line 92, in forward
    input = module(input)
  File "C:\Users\YongTaek\Anaconda3\envs\Caps\lib\site-packages\torch\nn\modules\module.py", line 547, in __call__
    result = self.forward(*input, **kwargs)
  File "C:\Users\YongTaek\Anaconda3\envs\Caps\lib\site-packages\torch\nn\modules\conv.py", line 343, in forward
    return self.conv2d_forward(input, self.weight)
  File "C:\Users\YongTaek\Anaconda3\envs\Caps\lib\site-packages\torch\nn\modules\conv.py", line 340, in conv2d_forward
    self.padding, self.dilation, self.groups)
RuntimeError: Given groups=1, weight of size 15 9 7 7, expected input[8, 18, 256, 256] to have 9 channels, but got 18 channels instead


### main 함수  첫 단계 DATA LOAD 함수

In [None]:
from torch.utils import data
from torchvision import transforms as T
from torchvision.datasets import ImageFolder
from PIL import Image
import torch
import os
import random


class CelebA(data.Dataset):
    """Dataset class for the CelebA dataset."""

    def __init__(self, image_dir, attr_path, selected_attrs, transform, mode):
        """Initialize and preprocess the CelebA dataset."""
        self.image_dir = image_dir
        self.attr_path = attr_path
        self.selected_attrs = selected_attrs
        self.transform = transform
        self.mode = mode
        self.train_dataset = []
        self.test_dataset = []
        self.attr2idx = {}
        self.idx2attr = {}
        self.preprocess()

        if mode == 'train':
            self.num_images = len(self.train_dataset)
        else:
            self.num_images = len(self.test_dataset)

    def preprocess(self):
        """Preprocess the CelebA attribute file."""
        lines = [line.rstrip() for line in open(self.attr_path, 'r')]
        all_attr_names = lines[1].split()
        for i, attr_name in enumerate(all_attr_names):
            self.attr2idx[attr_name] = i
            self.idx2attr[i] = attr_name

        lines = lines[2:]
        random.seed(1234)
        random.shuffle(lines)
        for i, line in enumerate(lines):
            split = line.split()
            filename = split[0]
            values = split[1:]

            label = []
            for attr_name in self.selected_attrs:
                idx = self.attr2idx[attr_name]
                label.append(values[idx] == '1')

            if (i+1) < 2000:
                self.test_dataset.append([filename, label])
            else:
                self.train_dataset.append([filename, label])

        print('Finished preprocessing the CelebA dataset...')

    def __getitem__(self, index):
        """Return one image and its corresponding attribute label."""
        dataset = self.train_dataset if self.mode == 'train' else self.test_dataset
        filename, label = dataset[index]
        image = Image.open(os.path.join(self.image_dir, filename))
        return self.transform(image), torch.FloatTensor(label)

    def __len__(self):
        """Return the number of images."""
        return self.num_images


def get_loader(image_dir, attr_path, selected_attrs, crop_size=178, image_size=128, 
               batch_size=16, dataset='CelebA', mode='train', num_workers=1):
    """Build and return a data loader."""
    transform = []
    if mode == 'train':
        transform.append(T.RandomHorizontalFlip())
    transform.append(T.CenterCrop(crop_size))
    transform.append(T.Resize(image_size))
    transform.append(T.ToTensor())
    transform.append(T.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)))
    transform = T.Compose(transform)

    if dataset == 'CelebA':
        dataset = CelebA(image_dir, attr_path, selected_attrs, transform, mode)
    elif dataset == 'RaFD':
        dataset = ImageFolder(image_dir, transform)

    data_loader = data.DataLoader(dataset=dataset,
                                  batch_size=batch_size,
                                  shuffle=(mode=='train'),
                                  num_workers=num_workers)
    return data_loader

### DATA LOADING

In [None]:
celeba_loader = get_loader(args['celeba_image_dir'], args['attr_path'], args['selected_attrs'],
                          args['celeba_crop_size'], args['image_size'], args['batch_size'],
                          'CelebA', args['mode'], args['num_workers'])
rafd_loader = get_loader(args['rafd_image_dir'], None, None,
                                 args['rafd_crop_size'], args['image_size'], args['batch_size'],
                                 'RaFD', args['mode'], args['num_workers'])

### Model 정의 ngpu코드 추가

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np


class ResidualBlock(nn.Module):
    """Residual Block with instance normalization."""
    def __init__(self, dim_in, dim_out):#, ngpu):  ngpu 받는 코드 추가
        super(ResidualBlock, self).__init__()
        #self.ngpu = ngpu  ngpu 추가
        self.main = nn.Sequential(
            nn.Conv2d(dim_in, dim_out, kernel_size=3, stride=1, padding=1, bias=False),
            nn.InstanceNorm2d(dim_out, affine=True, track_running_stats=True),
            nn.ReLU(inplace=True),
            nn.Conv2d(dim_out, dim_out, kernel_size=3, stride=1, padding=1, bias=False),
            nn.InstanceNorm2d(dim_out, affine=True, track_running_stats=True))

    def forward(self, x):
        return x + self.main(x)


class Generator(nn.Module):
    """Generator network."""
    def __init__(self, ngpu, conv_dim=64, c_dim=5, repeat_num=6 ): # ngpu 받는 코드 추가
        super(Generator, self).__init__()
        self.ngpu = ngpu # ngpu 추가

        layers = []
        layers.append(nn.Conv2d(3+c_dim, conv_dim, kernel_size=7, stride=1, padding=3, bias=False))
        layers.append(nn.InstanceNorm2d(conv_dim, affine=True, track_running_stats=True))
        layers.append(nn.ReLU(inplace=True))

        # Down-sampling layers.
        curr_dim = conv_dim
        for i in range(2):
            layers.append(nn.Conv2d(curr_dim, curr_dim*2, kernel_size=4, stride=2, padding=1, bias=False))
            layers.append(nn.InstanceNorm2d(curr_dim*2, affine=True, track_running_stats=True))
            layers.append(nn.ReLU(inplace=True))
            curr_dim = curr_dim * 2

        # Bottleneck layers.
        for i in range(repeat_num):
            layers.append(ResidualBlock(dim_in=curr_dim, dim_out=curr_dim))

        # Up-sampling layers.
        for i in range(2):
            layers.append(nn.ConvTranspose2d(curr_dim, curr_dim//2, kernel_size=4, stride=2, padding=1, bias=False))
            layers.append(nn.InstanceNorm2d(curr_dim//2, affine=True, track_running_stats=True))
            layers.append(nn.ReLU(inplace=True))
            curr_dim = curr_dim // 2

        layers.append(nn.Conv2d(curr_dim, 3, kernel_size=7, stride=1, padding=3, bias=False))
        layers.append(nn.Tanh())
        self.main = nn.Sequential(*layers)

    def forward(self, x, c):
        # Replicate spatially and concatenate domain information.
        # Note that this type of label conditioning does not work at all if we use reflection padding in Conv2d.
        # This is because instance normalization ignores the shifting (or bias) effect.
        c = c.view(c.size(0), c.size(1), 1, 1)
        c = c.repeat(1, 1, x.size(2), x.size(3))
        x = torch.cat([x, c], dim=1)
        return self.main(x)


class Discriminator(nn.Module):
    """Discriminator network with PatchGAN."""
    def __init__(self, ngpu, image_size=128, conv_dim=64, c_dim=5, repeat_num=6): # ngpu 받는 코드 추가
        super(Discriminator, self).__init__()
        self.ngpu = ngpu # ngpu 추가
        layers = []
        layers.append(nn.Conv2d(3, conv_dim, kernel_size=4, stride=2, padding=1))
        layers.append(nn.LeakyReLU(0.01))

        curr_dim = conv_dim
        for i in range(1, repeat_num):
            layers.append(nn.Conv2d(curr_dim, curr_dim*2, kernel_size=4, stride=2, padding=1))
            layers.append(nn.LeakyReLU(0.01))
            curr_dim = curr_dim * 2

        kernel_size = int(image_size / np.power(2, repeat_num))
        self.main = nn.Sequential(*layers)
        self.conv1 = nn.Conv2d(curr_dim, 1, kernel_size=3, stride=1, padding=1, bias=False)
        self.conv2 = nn.Conv2d(curr_dim, c_dim, kernel_size=kernel_size, bias=False)
        
    def forward(self, x):
        h = self.main(x)
        out_src = self.conv1(h)
        out_cls = self.conv2(h)
        return out_src, out_cls.view(out_cls.size(0), out_cls.size(1))

## SOLVER 정의

In [None]:
from torch.autograd import Variable
from torchvision.utils import save_image
import torch
import torch.nn.functional as F
import numpy as np
import os
import time
import datetime


class Solver(object):
    """Solver for training and testing StarGAN."""

    def __init__(self, celeba_loader, rafd_loader, args):
        """Initialize configurations."""

        # Data loader.
        self.celeba_loader = celeba_loader
        self.rafd_loader = rafd_loader

        # Model configurations.
        self.c_dim = args['c_dim']
        self.c2_dim = args['c2_dim']
        self.image_size = args['image_size']
        self.g_conv_dim = args['g_conv_dim']
        self.d_conv_dim = args['d_conv_dim']
        self.g_repeat_num = args['g_repeat_num']
        self.d_repeat_num = args['d_repeat_num']
        self.lambda_cls = args['lambda_cls']
        self.lambda_rec = args['lambda_rec']
        self.lambda_gp = args['lambda_gp']

        # Training configurations.
        self.dataset = args['dataset']
        self.batch_size = args['batch_size']
        self.num_iters = args['num_iters']
        self.num_iters_decay = args['num_iters_decay']
        self.g_lr = args['g_lr']
        self.d_lr = args['d_lr']
        self.n_critic = args['n_critic']
        self.beta1 = args['beta1']
        self.beta2 = args['beta2']
        self.resume_iters = args['resume_iters']
        self.selected_attrs = args['selected_attrs']

        # Test configurations.
        self.test_iters = args['test_iters']

        # Miscellaneous.
        self.use_tensorboard = args['use_tensorboard']
        
        # GPU 설정.
        self.ngpu = args['ngpu']
        self.device = torch.device("cuda:0" if (torch.cuda.is_available() and self.ngpu > 0) else "cpu")

        # Directories.
        self.log_dir = args['log_dir']
        self.sample_dir = args['sample_dir']
        self.model_save_dir = args['model_save_dir']
        self.result_dir = args['result_dir']

        # Step size.
        self.log_step = args['log_step']
        self.sample_step = args['sample_step']
        self.model_save_step = args['model_save_step']
        self.lr_update_step = args['lr_update_step']

        # Build the model and tensorboard.
        self.build_model()
        if self.use_tensorboard:
            self.build_tensorboard()

    def build_model(self):
        """Create a generator and a discriminator."""
        if self.dataset in ['CelebA', 'RaFD']:
            self.G = Generator(self.g_conv_dim, self.c_dim, self.g_repeat_num, self.ngpu).to(self.device)
            self.D = Discriminator(self.image_size, self.d_conv_dim, self.c_dim, self.d_repeat_num, self.ngpu).to(self.device)
            if (self.device.type == 'cuda') and (self.ngpu > 1):
                self.G = nn.DataParallel(self.G, list(range(self.ngpu)))
            if (self.device.type == 'cuda') and (self.ngpu > 1):
                self.D = nn.DataParallel(self.D, list(range(self.ngpu)))
                
        elif self.dataset in ['Both']:
            self.G = Generator(self.g_conv_dim, self.c_dim+self.c2_dim+2, self.g_repeat_num, self.ngpu).to(self.device)   # 2 for mask vector.
            self.D = Discriminator(self.image_size, self.d_conv_dim, self.c_dim+self.c2_dim, self.d_repeat_num, self.ngpu).to(self.device)
            if (self.device.type == 'cuda') and (self.ngpu > 1):
                self.G = nn.DataParallel(self.G, list(range(self.ngpu)))
            if (self.device.type == 'cuda') and (self.ngpu > 1):
                self.G = nn.DataParallel(self.G, list(range(self.ngpu)))
            
            
        self.g_optimizer = torch.optim.Adam(self.G.parameters(), self.g_lr, [self.beta1, self.beta2])
        self.d_optimizer = torch.optim.Adam(self.D.parameters(), self.d_lr, [self.beta1, self.beta2])
        self.print_network(self.G, 'G')
        self.print_network(self.D, 'D')
            
        #self.G.to(self.device)
        #self.D.to(self.device)

    def print_network(self, model, name):
        """Print out the network information."""
        num_params = 0
        for p in model.parameters():
            num_params += p.numel()
        print(model)
        print(name)
        print("The number of parameters: {}".format(num_params))

    def restore_model(self, resume_iters):
        """Restore the trained generator and discriminator."""
        print('Loading the trained models from step {}...'.format(resume_iters))
        G_path = os.path.join(self.model_save_dir, '{}-G.ckpt'.format(resume_iters))
        D_path = os.path.join(self.model_save_dir, '{}-D.ckpt'.format(resume_iters))
        self.G.load_state_dict(torch.load(G_path, map_location=lambda storage, loc: storage))
        self.D.load_state_dict(torch.load(D_path, map_location=lambda storage, loc: storage))

    def build_tensorboard(self):
        """Build a tensorboard logger."""
        from logger import Loggerclass
        self.logger = Loggerclass(self.log_dir)

    def update_lr(self, g_lr, d_lr):
        """Decay learning rates of the generator and discriminator."""
        for param_group in self.g_optimizer.param_groups:
            param_group['lr'] = g_lr
        for param_group in self.d_optimizer.param_groups:
            param_group['lr'] = d_lr

    def reset_grad(self):
        """Reset the gradient buffers."""
        self.g_optimizer.zero_grad()
        self.d_optimizer.zero_grad()

    def denorm(self, x):
        """Convert the range from [-1, 1] to [0, 1]."""
        out = (x + 1) / 2
        return out.clamp_(0, 1)

    def gradient_penalty(self, y, x):
        """Compute gradient penalty: (L2_norm(dy/dx) - 1)**2."""
        weight = torch.ones(y.size()).to(self.device)
        dydx = torch.autograd.grad(outputs=y,
                                   inputs=x,
                                   grad_outputs=weight,
                                   retain_graph=True,
                                   create_graph=True,
                                   only_inputs=True)[0]

        dydx = dydx.view(dydx.size(0), -1)
        dydx_l2norm = torch.sqrt(torch.sum(dydx**2, dim=1))
        return torch.mean((dydx_l2norm-1)**2)

    def label2onehot(self, labels, dim):
        """Convert label indices to one-hot vectors."""
        batch_size = labels.size(0)
        out = torch.zeros(batch_size, dim)
        out[np.arange(batch_size), labels.long()] = 1
        return out

    def create_labels(self, c_org, c_dim=5, dataset='CelebA', selected_attrs=None):
        """Generate target domain labels for debugging and testing."""
        # Get hair color indices.
        if dataset == 'CelebA':
            hair_color_indices = []
            for i, attr_name in enumerate(selected_attrs):
                if attr_name in ['Black_Hair', 'Blond_Hair', 'Brown_Hair', 'Gray_Hair']:
                    hair_color_indices.append(i)

        c_trg_list = []
        for i in range(c_dim):
            if dataset == 'CelebA':
                c_trg = c_org.clone()
                if i in hair_color_indices:  # Set one hair color to 1 and the rest to 0.
                    c_trg[:, i] = 1
                    for j in hair_color_indices:
                        if j != i:
                            c_trg[:, j] = 0
                else:
                    c_trg[:, i] = (c_trg[:, i] == 0)  # Reverse attribute value.
            elif dataset == 'RaFD':
                c_trg = self.label2onehot(torch.ones(c_org.size(0))*i, c_dim)

            c_trg_list.append(c_trg.to(self.device))
        return c_trg_list

    def classification_loss(self, logit, target, dataset='CelebA'):
        """Compute binary or softmax cross entropy loss."""
        if dataset == 'CelebA':
            return F.binary_cross_entropy_with_logits(logit, target, size_average=False) / logit.size(0)
        elif dataset == 'RaFD':
            return F.cross_entropy(logit, target)

    

    def train_multi(self):
        """Train StarGAN with multiple datasets."""        
        # Data iterators.
        celeba_iter = iter(self.celeba_loader)
        rafd_iter = iter(self.rafd_loader)

        # Fetch fixed inputs for debugging.
        x_fixed, c_org = next(celeba_iter)
        x_fixed = x_fixed.to(self.device)
        c_celeba_list = self.create_labels(c_org, self.c_dim, 'CelebA', self.selected_attrs)
        c_rafd_list = self.create_labels(c_org, self.c2_dim, 'RaFD')
        
        ### 여기도 병행처리로 넘겨봄.
        zero_celeba = torch.zeros(x_fixed.size(0), self.c_dim).to("cuda:0")           # Zero vector for CelebA.
        zero_rafd = torch.zeros(x_fixed.size(0), self.c2_dim).to("cuda:0")             # Zero vector for RaFD.
        mask_celeba = self.label2onehot(torch.zeros(x_fixed.size(0)), 2).to("cuda:1")  # Mask vector: [1, 0].
        mask_rafd = self.label2onehot(torch.ones(x_fixed.size(0)), 2).to("cuda:1")     # Mask vector: [0, 1].

        # Learning rate cache for decaying.
        g_lr = self.g_lr
        d_lr = self.d_lr

        # Start training from scratch or resume training.
        start_iters = 0
        if self.resume_iters:
            start_iters = self.resume_iters
            self.restore_model(self.resume_iters)

        # Start training.
        print('Start training...')
        start_time = time.time()
        for i in range(start_iters, self.num_iters):
            for dataset in ['CelebA', 'RaFD']:

                # =================================================================================== #
                #                             1. Preprocess input data                                #
                # =================================================================================== #
                
                # Fetch real images and labels.
                data_iter = celeba_iter if dataset == 'CelebA' else rafd_iter
                
                try:
                    x_real, label_org = next(data_iter)
                except:
                    if dataset == 'CelebA':
                        celeba_iter = iter(self.celeba_loader)
                        x_real, label_org = next(celeba_iter)
                    elif dataset == 'RaFD':
                        rafd_iter = iter(self.rafd_loader)
                        x_real, label_org = next(rafd_iter)

                # Generate target domain labels randomly.
                rand_idx = torch.randperm(label_org.size(0))
                label_trg = label_org[rand_idx]

                if dataset == 'CelebA':
                    c_org = label_org.clone()
                    c_trg = label_trg.clone()
                    zero = torch.zeros(x_real.size(0), self.c2_dim)
                    mask = self.label2onehot(torch.zeros(x_real.size(0)), 2)
                    c_org = torch.cat([c_org, zero, mask], dim=1)
                    c_trg = torch.cat([c_trg, zero, mask], dim=1)
                elif dataset == 'RaFD':
                    c_org = self.label2onehot(label_org, self.c2_dim)
                    c_trg = self.label2onehot(label_trg, self.c2_dim)
                    zero = torch.zeros(x_real.size(0), self.c_dim)
                    mask = self.label2onehot(torch.ones(x_real.size(0)), 2)
                    c_org = torch.cat([zero, c_org, mask], dim=1)
                    c_trg = torch.cat([zero, c_trg, mask], dim=1)

                x_real = x_real.to(self.device)             # Input images.
                c_org = c_org.to(self.device)               # Original domain labels.
                c_trg = c_trg.to(self.device)               # Target domain labels.
                label_org = label_org.to(self.device)       # Labels for computing classification loss.
                label_trg = label_trg.to(self.device)       # Labels for computing classification loss.

                # =================================================================================== #
                #                             2. Train the discriminator                              #
                # =================================================================================== #

                # Compute loss with real images.
                out_src, out_cls = self.D(x_real)
                out_cls = out_cls[:, :self.c_dim] if dataset == 'CelebA' else out_cls[:, self.c_dim:]
                d_loss_real = - torch.mean(out_src)
                d_loss_cls = self.classification_loss(out_cls, label_org, dataset)

                # Compute loss with fake images.
                x_fake = self.G(x_real, c_trg)
                out_src, _ = self.D(x_fake.detach())
                d_loss_fake = torch.mean(out_src)

                # Compute loss for gradient penalty.
                alpha = torch.rand(x_real.size(0), 1, 1, 1).to(self.device)
                x_hat = (alpha * x_real.data + (1 - alpha) * x_fake.data).requires_grad_(True)
                out_src, _ = self.D(x_hat)
                d_loss_gp = self.gradient_penalty(out_src, x_hat)

                # Backward and optimize.
                d_loss = d_loss_real + d_loss_fake + self.lambda_cls * d_loss_cls + self.lambda_gp * d_loss_gp
                self.reset_grad()
                d_loss.backward()
                self.d_optimizer.step()

                # Logging.
                loss = {}
                loss['D/loss_real'] = d_loss_real.item()
                loss['D/loss_fake'] = d_loss_fake.item()
                loss['D/loss_cls'] = d_loss_cls.item()
                loss['D/loss_gp'] = d_loss_gp.item()
            
                # =================================================================================== #
                #                               3. Train the generator                                #
                # =================================================================================== #

                if (i+1) % self.n_critic == 0:
                    # Original-to-target domain.
                    x_fake = self.G(x_real, c_trg)
                    out_src, out_cls = self.D(x_fake)
                    out_cls = out_cls[:, :self.c_dim] if dataset == 'CelebA' else out_cls[:, self.c_dim:]
                    g_loss_fake = - torch.mean(out_src)
                    g_loss_cls = self.classification_loss(out_cls, label_trg, dataset)

                    # Target-to-original domain.
                    x_reconst = self.G(x_fake, c_org)
                    g_loss_rec = torch.mean(torch.abs(x_real - x_reconst))

                    # Backward and optimize.
                    g_loss = g_loss_fake + self.lambda_rec * g_loss_rec + self.lambda_cls * g_loss_cls
                    self.reset_grad()
                    g_loss.backward()
                    self.g_optimizer.step()

                    # Logging.
                    loss['G/loss_fake'] = g_loss_fake.item()
                    loss['G/loss_rec'] = g_loss_rec.item()
                    loss['G/loss_cls'] = g_loss_cls.item()

                # =================================================================================== #
                #                                 4. Miscellaneous                                    #
                # =================================================================================== #

                # Print out training info.
                if (i+1) % self.log_step == 0:
                    et = time.time() - start_time
                    et = str(datetime.timedelta(seconds=et))[:-7]
                    log = "Elapsed [{}], Iteration [{}/{}], Dataset [{}]".format(et, i+1, self.num_iters, dataset)
                    for tag, value in loss.items():
                        log += ", {}: {:.4f}".format(tag, value)
                    print(log)

                    if self.use_tensorboard:
                        for tag, value in loss.items():
                            self.logger.scalar_summary(tag, value, i+1)

            # Translate fixed images for debugging.
            if (i+1) % self.sample_step == 0:
                with torch.no_grad():
                    x_fake_list = [x_fixed]
                    for c_fixed in c_celeba_list:
                        c_trg = torch.cat([c_fixed, zero_rafd, mask_celeba], dim=1)
                        x_fake_list.append(self.G(x_fixed, c_trg))
                    for c_fixed in c_rafd_list:
                        c_trg = torch.cat([zero_celeba, c_fixed, mask_rafd], dim=1)
                        x_fake_list.append(self.G(x_fixed, c_trg))
                    x_concat = torch.cat(x_fake_list, dim=3)
                    sample_path = os.path.join(self.sample_dir, '{}-images.jpg'.format(i+1))
                    save_image(self.denorm(x_concat.data.cpu()), sample_path, nrow=1, padding=0)
                    print('Saved real and fake images into {}...'.format(sample_path))

            # Save model checkpoints.
            if (i+1) % self.model_save_step == 0:
                G_path = os.path.join(self.model_save_dir, '{}-G.ckpt'.format(i+1))
                D_path = os.path.join(self.model_save_dir, '{}-D.ckpt'.format(i+1))
                torch.save(self.G.state_dict(), G_path)
                torch.save(self.D.state_dict(), D_path)
                print('Saved model checkpoints into {}...'.format(self.model_save_dir))

            # Decay learning rates.
            if (i+1) % self.lr_update_step == 0 and (i+1) > (self.num_iters - self.num_iters_decay):
                g_lr -= (self.g_lr / float(self.num_iters_decay))
                d_lr -= (self.d_lr / float(self.num_iters_decay))
                self.update_lr(g_lr, d_lr)
                print ('Decayed learning rates, g_lr: {}, d_lr: {}.'.format(g_lr, d_lr))