## 실습에 필요한 파이썬 라이브러리를 불러옵니다.

!pip install easydict

In [1]:
# pyTorch 관련 된 라이브러리.
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim # optimization에 관한 모듈.
import torchvision # 이미지 관련 전처리, pretrained된 모델, 데이터 로딩에 관한 패키지입니다.
from torch.autograd import Variable # 미분자동화가 된 변수 설정을 위한 모듈
from torchvision.utils import save_image # 이미지 저장을 위한 torchvision의 모듈
import torchvision.datasets as vision_dsets
import torchvision.transforms as T # 이미지 전처리 모듈입니다.
from torchvision.datasets import ImageFolder
from torch.utils import data

# 기타 필요한 라이브러리.
from PIL import Image
import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt
import os
import sys
import time
import datetime
import random
import easydict # ipynb에서는 다루기 귀찮은 argparse를 대체해주는 라이브러리

## 버전을 한 번 확인해 줍니다.

In [2]:
# Report the interpreter and library versions this notebook runs against.
version_report = (
    ('python version : ', sys.version),
    ('numpy version : ', np.version.version),
    ('matplotlib version :', mpl.__version__),
    ('pytorch version : ', torch.__version__),
    ('torchvision version : ', torchvision.__version__),
    # Confirms that the GPU (CUDA) setup is usable.
    ('Cuda : ', torch.cuda.is_available()),
)
for label, value in version_report:
    print(label, value)

python version :  3.7.1 (default, Dec 14 2018, 19:28:38) 
[GCC 7.3.0]
numpy version :  1.15.4
matplotlib version : 3.0.2
pytorch version :  1.0.1.post2
torchvision version :  0.2.2
Cuda :  True


## StarGAN

#### StarGAN은 multi-domain의 image translation이 가능하다는 것에 가장 큰 contribution을 갖는다.

![구조](https://github.com/yunjey/stargan/blob/master/jpg/model.jpg?raw=true)

---

#### 위 구조를 실제 dataset 상황에서 보자면 아래와 같다.



![구조2](https://github.com/yunjey/stargan/blob/master/jpg/model2.jpg?raw=true)

## CelebA Dataset Download 

!bash download.sh celeba

## Data Loader

CelebA의 경우에는 pytorch에서 기본으로 제공하는 data loader가 없기에 조금 복잡한 data loader를 구현해줘야 합니다. dataloader의 경우는 각각의 dataset들마다 코딩해줘야하는 방향이 천차만별입니다.

In [3]:
class CelebA(data.Dataset):
    """Dataset class for the CelebA dataset.

    The attribute file is parsed once at construction time. Its records are
    shuffled with a fixed seed and split into a test part and a train part;
    indexing returns ``(transformed image, attribute label)`` pairs from the
    part selected by ``mode``.
    """

    def __init__(self, image_dir, attr_path, selected_attrs, transform, mode):
        """Initialize and preprocess the CelebA dataset.

        Args:
            image_dir: Directory that contains the image files.
            attr_path: Path of the attribute annotation text file.
            selected_attrs: Attribute names used to build each label, in order.
            transform: Callable applied to every loaded PIL image.
            mode: 'train' selects the training split; any other value selects
                the test split.
        """
        self.image_dir = image_dir
        self.attr_path = attr_path
        self.selected_attrs = selected_attrs
        self.transform = transform
        self.mode = mode
        self.train_dataset = []
        self.test_dataset = []
        self.attr2idx = {}
        self.idx2attr = {}
        self.preprocess()

        if mode == 'train':
            self.num_images = len(self.train_dataset)
        else:
            self.num_images = len(self.test_dataset)

    def preprocess(self):
        """Preprocess the CelebA attribute file.

        The second line of the file lists the attribute names; every line from
        the third on holds a file name followed by one value per attribute,
        where the string '1' marks the attribute as present. The first line is
        ignored.

        Side effect: reseeds the global ``random`` module with 1234 so that
        the train/test split is identical on every run.

        Raises:
            KeyError: If a name in ``selected_attrs`` is not in the file.
        """
        # Use a context manager so the file handle is closed deterministically.
        with open(self.attr_path, 'r') as attr_file:
            lines = [line.rstrip() for line in attr_file]

        all_attr_names = lines[1].split()
        for i, attr_name in enumerate(all_attr_names):
            self.attr2idx[attr_name] = i
            self.idx2attr[i] = attr_name

        # Drop blank lines (e.g. a trailing empty line) that would otherwise
        # fail when the file name is read from the split record.
        lines = [line for line in lines[2:] if line]
        random.seed(1234)
        random.shuffle(lines)
        for i, line in enumerate(lines):
            split = line.split()
            filename = split[0]
            values = split[1:]

            label = [values[self.attr2idx[attr_name]] == '1'
                     for attr_name in self.selected_attrs]

            # The first 1999 shuffled records form the test split. The boundary
            # is kept as-is so the split stays the same as in earlier runs.
            if (i+1) < 2000:
                self.test_dataset.append([filename, label])
            else:
                self.train_dataset.append([filename, label])

        print('Finished preprocessing the CelebA dataset...')

    def __getitem__(self, index):
        """Return one image and its corresponding attribute label.

        Returns:
            A tuple ``(transform(image), label)`` where ``label`` is a
            ``torch.FloatTensor`` with one entry per selected attribute.
        """
        dataset = self.train_dataset if self.mode == 'train' else self.test_dataset
        filename, label = dataset[index]
        image = Image.open(os.path.join(self.image_dir, filename))
        return self.transform(image), torch.FloatTensor(label)

    def __len__(self):
        """Return the number of images in the selected split."""
        return self.num_images
    
def get_loader(image_dir, attr_path, selected_attrs, crop_size=178, image_size=128, 
               batch_size=16, dataset='CelebA', mode='train', num_workers=1):
    """Build and return a data loader over the CelebA dataset.

    Images are center-cropped to ``crop_size``, resized to ``image_size``,
    converted to tensors and normalized with mean 0.5 / std 0.5 per channel.
    In 'train' mode a random horizontal flip is applied first and the loader
    shuffles its samples. The ``dataset`` argument is accepted but not used.
    """
    is_training = (mode == 'train')

    # Random flipping is an augmentation, so it is only added for training.
    steps = [T.RandomHorizontalFlip()] if is_training else []
    steps.extend([
        T.CenterCrop(crop_size),
        T.Resize(image_size),
        T.ToTensor(),
        T.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    ])

    celeba = CelebA(image_dir, attr_path, selected_attrs, T.Compose(steps), mode)

    return data.DataLoader(dataset=celeba,
                           batch_size=batch_size,
                           shuffle=is_training,
                           num_workers=num_workers)

## Model Implementation

### ResidualBlock

Residual Block은 2015년 ImageNet Challenge에서 ResNet이 우승을 차지하며 그 효과가 입증된 구조이다.

![resnet](https://neurohive.io/wp-content/uploads/2019/01/resnet-e1548261477164.png)

### Instance Normalization

Batch Normalization은 batch 크기의 영향을 크게 받으며, 특히 batch 크기가 1처럼 작은 경우에는 batch 통계량(mean, variance)의 추정이 불안정해져 noisy한 결과를 낼 수 있습니다. 이러한 batch size 의존성이라는 batch norm의 한계를 극복하기 위해 여러 normalization 방법론이 제안되었고, 그중 하나인 Instance Normalization은 image-to-image translation에서 자주 사용되는 방법입니다.

![IN](https://bloglunit.files.wordpress.com/2018/04/ec8aa4ed81aceba6b0ec83b7-2018-04-11-ec98a4ed9b84-3-07-41.png?w=1400)

#### Batch Normalization

![BN equation](https://i.stack.imgur.com/VDqKY.jpg)

#### Instance Normalization

![IN equation](https://i.stack.imgur.com/X5z48.jpg)

In [4]:
class ResidualBlock(nn.Module):
    """Residual Block with instance normalization.

    Applies conv -> instance norm -> ReLU -> conv -> instance norm and adds
    the result to the block input (identity skip connection).
    """
    def __init__(self, dim_in, dim_out):
        """Build the block.

        Args:
            dim_in: Number of input channels.
            dim_out: Number of output channels. Must equal ``dim_in`` for the
                skip connection in ``forward`` to be shape-compatible.
        """
        super(ResidualBlock, self).__init__()
        # 3x3 convolutions with stride 1 and padding 1 keep the spatial size,
        # which the element-wise addition in forward() relies on.
        self.main = nn.Sequential(
            nn.Conv2d(dim_in, dim_out, kernel_size=3, stride=1, padding=1, bias=False),
            nn.InstanceNorm2d(dim_out, affine=True, track_running_stats=True),
            nn.ReLU(inplace=True),
            nn.Conv2d(dim_out, dim_out, kernel_size=3, stride=1, padding=1, bias=False),
            nn.InstanceNorm2d(dim_out, affine=True, track_running_stats=True))

    def forward(self, x):
        """Return ``x`` plus the output of the residual branch."""
        return x + self.main(x)


class Generator(nn.Module):
    """Generator network.

    Takes an image together with a target domain label and produces an image
    of the same spatial size with values in [-1, 1] (final Tanh).
    """
    def __init__(self, conv_dim=64, c_dim=5, repeat_num=6):
        """Build the generator.

        Args:
            conv_dim: Number of filters of the first convolution.
            c_dim: Dimension of the domain label that is concatenated to the
                image as extra input channels.
            repeat_num: Number of residual blocks in the bottleneck.
        """
        super(Generator, self).__init__()

        layers = []
        # The input holds the 3 image channels plus one channel per label entry.
        layers.append(nn.Conv2d(3+c_dim, conv_dim, kernel_size=7, stride=1, padding=3, bias=False))
        layers.append(nn.InstanceNorm2d(conv_dim, affine=True, track_running_stats=True))
        layers.append(nn.ReLU(inplace=True))

        # Down-sampling layers.
        curr_dim = conv_dim
        for i in range(2):
            layers.append(nn.Conv2d(curr_dim, curr_dim*2, kernel_size=4, stride=2, padding=1, bias=False))
            layers.append(nn.InstanceNorm2d(curr_dim*2, affine=True, track_running_stats=True))
            layers.append(nn.ReLU(inplace=True))
            curr_dim = curr_dim * 2
        # After this loop curr_dim == conv_dim * 4 (256 for the default conv_dim=64).

        # Bottleneck layers.
        for i in range(repeat_num):
            layers.append(ResidualBlock(dim_in=curr_dim, dim_out=curr_dim))

        # Up-sampling layers: mirror the down-sampling path, doubling the
        # spatial size and halving the channel count at each step.
        for i in range(2):
            layers.append(nn.ConvTranspose2d(curr_dim, curr_dim//2, kernel_size=4, stride=2, padding=1, bias=False))
            layers.append(nn.InstanceNorm2d(curr_dim//2, affine=True, track_running_stats=True))
            layers.append(nn.ReLU(inplace=True))
            curr_dim = curr_dim // 2

        layers.append(nn.Conv2d(curr_dim, 3, kernel_size=7, stride=1, padding=3, bias=False))
        layers.append(nn.Tanh())
        self.main = nn.Sequential(*layers)

    def forward(self, x, c):
        """Translate images ``x`` towards the domain labels ``c``.

        Args:
            x: Image batch with 3 channels.
            c: Label batch with ``c_dim`` entries per sample.

        Returns:
            The generated image batch, same spatial size as ``x``.
        """
        # Replicate spatially and concatenate domain information.
        # Note that this type of label conditioning does not work at all if we use reflection padding in Conv2d.
        # This is because instance normalization ignores the shifting (or bias) effect.
        c = c.view(c.size(0), c.size(1), 1, 1)
        c = c.repeat(1, 1, x.size(2), x.size(3))
        x = torch.cat([x, c], dim=1)
        return self.main(x)


class Discriminator(nn.Module):
    """Discriminator network with PatchGAN.

    A stack of strided convolutions feeds two heads: a real/fake score map
    and a domain classification vector.
    """
    def __init__(self, image_size=128, conv_dim=64, c_dim=5, repeat_num=6):
        """Build the discriminator.

        Args:
            image_size: Height/width of the (square) input images.
            conv_dim: Number of filters of the first convolution.
            c_dim: Number of domain labels predicted by the classifier head.
            repeat_num: Number of stride-2 convolutions in the shared trunk.
        """
        super(Discriminator, self).__init__()
        layers = []
        layers.append(nn.Conv2d(3, conv_dim, kernel_size=4, stride=2, padding=1))
        layers.append(nn.LeakyReLU(0.01))

        curr_dim = conv_dim
        for i in range(1, repeat_num):
            layers.append(nn.Conv2d(curr_dim, curr_dim*2, kernel_size=4, stride=2, padding=1))
            layers.append(nn.LeakyReLU(0.01))
            curr_dim = curr_dim * 2

        # Each of the repeat_num layers halves the spatial size, so the trunk
        # output is image_size / 2**repeat_num wide. A kernel of exactly that
        # size collapses the map to 1x1 for the classification head.
        kernel_size = image_size // (2 ** repeat_num)
        self.main = nn.Sequential(*layers)
        self.conv1 = nn.Conv2d(curr_dim, 1, kernel_size=3, stride=1, padding=1, bias=False)
        self.conv2 = nn.Conv2d(curr_dim, c_dim, kernel_size=kernel_size, bias=False)
        
    def forward(self, x):
        """Return ``(out_src, out_cls)`` for the image batch ``x``.

        ``out_src`` is the single-channel real/fake score map and ``out_cls``
        is reshaped to ``(batch, c_dim)`` classification logits.
        """
        h = self.main(x)
        out_src = self.conv1(h)
        out_cls = self.conv2(h)
        return out_src, out_cls.view(out_cls.size(0), out_cls.size(1))

## Configuration, hyper parameter

필요한 여러가지 설정들과 hyper parameter들을 정의해준다. 이렇게 설정들을 따로 정해두는 것은 다른 데이터셋의 학습과 같은 상황에서 유연하게 대처할 수 있도록 코드를 모듈화해두는 작업이다.

In [5]:
# All settings and hyper-parameters for the training run, gathered in one
# attribute-accessible object (config.c_dim, config.batch_size, ...).
config = easydict.EasyDict(dict(

    # Model configuration
    c_dim=5,                # dimension of domain labels
    celeba_crop_size=178,   # crop size for the CelebA dataset
    image_size=128,         # image resolution
    g_conv_dim=64,          # number of conv filters in the first layer of G
    d_conv_dim=64,          # number of conv filters in the first layer of D
    g_repeat_num=6,         # number of residual blocks in G
    d_repeat_num=6,         # number of strided conv layers in D
    lambda_cls=1.0,         # weight for domain classification loss
    lambda_rec=10.0,        # weight for reconstruction loss
    lambda_gp=10.0,         # weight for gradient penalty

    # Training configuration
    # dataset='CelebA' (disabled: only CelebA is handled in this notebook)
    batch_size=16,          # mini-batch size
    num_iters=200000,       # number of total iterations for training D
    num_iters_decay=100000, # number of iterations for decaying lr
    g_lr=0.0001,            # learning rate for G
    d_lr=0.0001,            # learning rate for D
    n_critic=5,             # number of D updates per each G update
    beta1=0.5,              # beta1 for Adam optimizer
    beta2=0.999,            # beta2 for Adam optimizer
    resume_iters=None,      # resume training from this step
    selected_attrs=['Black_Hair', 'Blond_Hair', 'Brown_Hair', 'Male', 'Young'],  # selected attributes for the CelebA dataset

    # Test configuration
    test_iters=200000,      # test model from this step

    # Miscellaneous
    num_workers=1,          # num_workers value passed to the DataLoader
    mode='train',           # train or test mode

    # Directories
    celeba_image_dir='data/celeba/images',         # directory of the CelebA images
    attr_path='data/celeba/list_attr_celeba.txt',  # CelebA attribute annotation file

    # log_dir='stargan_celeba/logs' (disabled)
    model_save_dir='stargan_celeba/models',   # where model checkpoints are written / read
    sample_dir='stargan_celeba/samples',      # where training-time sample images are written
    result_dir='stargan_celeba/results',      # where test-time translated images are written

    # Step size
    log_step=10,            # iterations between log lines
    sample_step=1000,       # iterations between sample image dumps
    model_save_step=10000,  # iterations between checkpoints
    lr_update_step=1000,    # iterations between learning-rate decay steps
))

In [6]:
# Create the output directories up front. exist_ok=True makes the call a no-op
# for directories that already exist and avoids the check-then-create race of
# a separate os.path.exists() test.
for output_dir in (config.model_save_dir, config.sample_dir, config.result_dir):
    os.makedirs(output_dir, exist_ok=True)

## Loss

#### Adversarial Loss

![loss](https://drive.google.com/uc?export=view&id=1eOBGMlNiw2Obn6nHiVYYRFYi3zrTFD4X)

#### Domain Classification Loss

흔히 사용되는 classification loss (binary cross entropy)

#### Reconstruction Loss

![rec loss](https://drive.google.com/uc?export=view&id=1hntL27V8kJSdmBkOIN3JJ5bJu8k6pRUO)

#### Full Objective
![loss2](https://drive.google.com/uc?export=view&id=1LYq5233lnCxISttXtINlLkgV0voHvNiC)


In [7]:
class Solver(object):
    """Solver for training and testing StarGAN.

    Owns the generator ``G``, the discriminator ``D`` and their Adam
    optimizers, and implements the training loop (``train``) and the
    translation of a dataset with a saved generator (``test``).
    """

    def __init__(self, celeba_loader, config):
        """Initialize configurations and build the models.

        Args:
            celeba_loader: Data loader yielding ``(images, labels)`` batches.
            config: Object exposing the settings read below as attributes.
        """

        # Data loader.
        self.celeba_loader = celeba_loader

        # Model configurations.
        self.c_dim = config.c_dim
        self.image_size = config.image_size
        self.g_conv_dim = config.g_conv_dim
        self.d_conv_dim = config.d_conv_dim
        self.g_repeat_num = config.g_repeat_num
        self.d_repeat_num = config.d_repeat_num
        self.lambda_cls = config.lambda_cls
        self.lambda_rec = config.lambda_rec
        self.lambda_gp = config.lambda_gp

        # Training configurations.
        self.batch_size = config.batch_size
        self.num_iters = config.num_iters
        self.num_iters_decay = config.num_iters_decay
        self.g_lr = config.g_lr
        self.d_lr = config.d_lr
        self.n_critic = config.n_critic
        self.beta1 = config.beta1
        self.beta2 = config.beta2
        self.resume_iters = config.resume_iters
        self.selected_attrs = config.selected_attrs

        # Test configurations.
        self.test_iters = config.test_iters

        # Miscellaneous: run on the GPU when one is available, else on the CPU.
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        # Directories.
        self.sample_dir = config.sample_dir
        self.model_save_dir = config.model_save_dir
        self.result_dir = config.result_dir

        # Step size.
        self.log_step = config.log_step
        self.sample_step = config.sample_step
        self.model_save_step = config.model_save_step
        self.lr_update_step = config.lr_update_step

        # Build the model.
        self.build_model()

    def build_model(self):
        """Create a generator and a discriminator with their optimizers."""
        self.G = Generator(self.g_conv_dim, self.c_dim, self.g_repeat_num)
        self.D = Discriminator(self.image_size, self.d_conv_dim, self.c_dim, self.d_repeat_num)

        self.g_optimizer = torch.optim.Adam(self.G.parameters(), self.g_lr, [self.beta1, self.beta2])
        self.d_optimizer = torch.optim.Adam(self.D.parameters(), self.d_lr, [self.beta1, self.beta2])
        self.print_network(self.G, 'G')
        self.print_network(self.D, 'D')

        self.G.to(self.device)
        self.D.to(self.device)

    def print_network(self, model, name):
        """Print the model, its name and its total number of parameters."""
        num_params = sum(p.numel() for p in model.parameters())
        print(model)
        print(name)
        print("The number of parameters: {}".format(num_params))

    def restore_model(self, resume_iters):
        """Restore the trained generator and discriminator.

        Loads ``<resume_iters>-G.ckpt`` and ``<resume_iters>-D.ckpt`` from the
        model save directory.
        """
        print('Loading the trained models from step {}...'.format(resume_iters))
        G_path = os.path.join(self.model_save_dir, '{}-G.ckpt'.format(resume_iters))
        D_path = os.path.join(self.model_save_dir, '{}-D.ckpt'.format(resume_iters))
        # map_location keeps the loaded tensors on the CPU; load_state_dict
        # then copies them into the models' existing parameters.
        self.G.load_state_dict(torch.load(G_path, map_location=lambda storage, loc: storage))
        self.D.load_state_dict(torch.load(D_path, map_location=lambda storage, loc: storage))

    def update_lr(self, g_lr, d_lr):
        """Set the learning rates of the generator and discriminator optimizers."""
        for param_group in self.g_optimizer.param_groups:
            param_group['lr'] = g_lr
        for param_group in self.d_optimizer.param_groups:
            param_group['lr'] = d_lr

    def reset_grad(self):
        """Reset the gradient buffers of both optimizers."""
        self.g_optimizer.zero_grad()
        self.d_optimizer.zero_grad()

    def denorm(self, x):
        """Convert the range from [-1, 1] to [0, 1], clamping outliers."""
        out = (x + 1) / 2
        return out.clamp_(0, 1)

    def gradient_penalty(self, y, x):
        """Compute gradient penalty: (L2_norm(dy/dx) - 1)**2.

        Args:
            y: Output computed from ``x``.
            x: Input tensor that requires grad.

        Returns:
            Scalar tensor: the penalty averaged over the batch dimension.
        """
        weight = torch.ones(y.size()).to(self.device)
        # create_graph=True keeps the gradient differentiable so the penalty
        # itself can be back-propagated.
        dydx = torch.autograd.grad(outputs=y,
                                   inputs=x,
                                   grad_outputs=weight,
                                   retain_graph=True,
                                   create_graph=True,
                                   only_inputs=True)[0]

        dydx = dydx.view(dydx.size(0), -1)
        dydx_l2norm = torch.sqrt(torch.sum(dydx**2, dim=1))
        return torch.mean((dydx_l2norm-1)**2)

    def label2onehot(self, labels, dim):
        """Convert label indices to one-hot vectors of width ``dim``."""
        batch_size = labels.size(0)
        out = torch.zeros(batch_size, dim)
        out[torch.arange(batch_size), labels.long()] = 1
        return out

    def create_labels(self, c_org, c_dim=5, selected_attrs=None):
        """Generate target domain labels for debugging and testing.

        For every attribute index a copy of ``c_org`` is produced. Hair-colour
        attributes are treated as mutually exclusive (the chosen one is set to
        1, the others to 0); any other attribute is inverted.

        Args:
            c_org: Original labels, one row per sample.
            c_dim: Number of attributes (columns) to generate targets for.
            selected_attrs: Attribute names, aligned with the label columns.
                ``None`` is treated as "no hair-colour attributes".

        Returns:
            List of ``c_dim`` label tensors moved to ``self.device``.
        """
        # Get hair color indices.
        hair_color_indices = [
            i for i, attr_name in enumerate(selected_attrs or [])
            if attr_name in ('Black_Hair', 'Blond_Hair', 'Brown_Hair', 'Gray_Hair')
        ]

        c_trg_list = []
        for i in range(c_dim):
            c_trg = c_org.clone()
            if i in hair_color_indices:  # Set one hair color to 1 and the rest to 0.
                c_trg[:, i] = 1
                for j in hair_color_indices:
                    if j != i:
                        c_trg[:, j] = 0
            else:
                # Reverse attribute value (cast the comparison back to the label dtype).
                c_trg[:, i] = (c_trg[:, i] == 0).to(c_trg.dtype)

            c_trg_list.append(c_trg.to(self.device))
        return c_trg_list

    def classification_loss(self, logit, target):
        """Compute the binary cross entropy loss, summed and divided by the batch size."""
        # reduction='sum' replaces the deprecated size_average=False.
        return F.binary_cross_entropy_with_logits(logit, target, reduction='sum') / logit.size(0)

    def train(self):
        """Train StarGAN within a CelebA dataset."""
        # Set data loader.
        data_loader = self.celeba_loader

        # Fetch fixed inputs for debugging.
        data_iter = iter(data_loader)
        x_fixed, c_org = next(data_iter)
        x_fixed = x_fixed.to(self.device)
        c_fixed_list = self.create_labels(c_org, self.c_dim, self.selected_attrs)

        # Learning rate cache for decaying.
        g_lr = self.g_lr
        d_lr = self.d_lr

        # Start training from scratch or resume training.
        start_iters = 0
        if self.resume_iters:
            start_iters = self.resume_iters
            self.restore_model(self.resume_iters)

        # Start training.
        print('Start training...')
        start_time = time.time()
        for i in range(start_iters, self.num_iters):

            # =================================================================================== #
            #                             1. Preprocess input data                                #
            # =================================================================================== #

            # Fetch real images and labels; restart the iterator once the
            # loader is exhausted (end of an epoch).
            try:
                x_real, label_org = next(data_iter)
            except StopIteration:
                data_iter = iter(data_loader)
                x_real, label_org = next(data_iter)

            # Generate target domain labels randomly by permuting the batch's own labels.
            rand_idx = torch.randperm(label_org.size(0))
            label_trg = label_org[rand_idx]

            c_org = label_org.clone()
            c_trg = label_trg.clone()

            x_real = x_real.to(self.device)           # Input images.
            c_org = c_org.to(self.device)             # Original domain labels.
            c_trg = c_trg.to(self.device)             # Target domain labels.
            label_org = label_org.to(self.device)     # Labels for computing classification loss.
            label_trg = label_trg.to(self.device)     # Labels for computing classification loss.

            # =================================================================================== #
            #                             2. Train the discriminator                              #
            # =================================================================================== #

            # Compute loss with real images.
            out_src, out_cls = self.D(x_real)
            d_loss_real = - torch.mean(out_src)
            d_loss_cls = self.classification_loss(out_cls, label_org)

            # Compute loss with fake images. detach() stops gradients from
            # flowing into the generator during the discriminator update.
            x_fake = self.G(x_real, c_trg)
            out_src, out_cls = self.D(x_fake.detach())
            d_loss_fake = torch.mean(out_src)

            # Compute loss for gradient penalty.
            alpha = torch.rand(x_real.size(0), 1, 1, 1).to(self.device)
            x_hat = (alpha * x_real.data + (1 - alpha) * x_fake.data).requires_grad_(True)
            out_src, _ = self.D(x_hat)
            d_loss_gp = self.gradient_penalty(out_src, x_hat)

            # Backward and optimize.
            d_loss = d_loss_real + d_loss_fake + self.lambda_cls * d_loss_cls + self.lambda_gp * d_loss_gp
            self.reset_grad()
            d_loss.backward()
            self.d_optimizer.step()

            # Logging.
            loss = {}
            loss['D/loss_real'] = d_loss_real.item()
            loss['D/loss_fake'] = d_loss_fake.item()
            loss['D/loss_cls'] = d_loss_cls.item()
            loss['D/loss_gp'] = d_loss_gp.item()

            # =================================================================================== #
            #                               3. Train the generator                                #
            # =================================================================================== #

            # The generator is updated once per n_critic discriminator updates.
            if (i+1) % self.n_critic == 0:
                # Original-to-target domain.
                x_fake = self.G(x_real, c_trg)
                out_src, out_cls = self.D(x_fake)
                g_loss_fake = - torch.mean(out_src)
                g_loss_cls = self.classification_loss(out_cls, label_trg)

                # Target-to-original domain (L1 reconstruction of the input).
                x_reconst = self.G(x_fake, c_org)
                g_loss_rec = torch.mean(torch.abs(x_real - x_reconst))

                # Backward and optimize.
                g_loss = g_loss_fake + self.lambda_rec * g_loss_rec + self.lambda_cls * g_loss_cls
                self.reset_grad()
                g_loss.backward()
                self.g_optimizer.step()

                # Logging.
                loss['G/loss_fake'] = g_loss_fake.item()
                loss['G/loss_rec'] = g_loss_rec.item()
                loss['G/loss_cls'] = g_loss_cls.item()

            # =================================================================================== #
            #                                 4. Miscellaneous                                    #
            # =================================================================================== #

            # Print out training information.
            if (i+1) % self.log_step == 0:
                et = time.time() - start_time
                et = str(datetime.timedelta(seconds=et))[:-7]
                log = "Elapsed [{}], Iteration [{}/{}]".format(et, i+1, self.num_iters)
                for tag, value in loss.items():
                    log += ", {}: {:.4f}".format(tag, value)
                print(log)

            # Translate fixed images for debugging.
            if (i+1) % self.sample_step == 0:
                with torch.no_grad():
                    x_fake_list = [x_fixed]
                    for c_fixed in c_fixed_list:
                        x_fake_list.append(self.G(x_fixed, c_fixed))
                    x_concat = torch.cat(x_fake_list, dim=3)
                    sample_path = os.path.join(self.sample_dir, '{}-images.jpg'.format(i+1))
                    save_image(self.denorm(x_concat.data.cpu()), sample_path, nrow=1, padding=0)
                    print('Saved real and fake images into {}...'.format(sample_path))

            # Save model checkpoints.
            if (i+1) % self.model_save_step == 0:
                G_path = os.path.join(self.model_save_dir, '{}-G.ckpt'.format(i+1))
                D_path = os.path.join(self.model_save_dir, '{}-D.ckpt'.format(i+1))
                torch.save(self.G.state_dict(), G_path)
                torch.save(self.D.state_dict(), D_path)
                print('Saved model checkpoints into {}...'.format(self.model_save_dir))

            # Decay learning rates.
            # NOTE(review): each step subtracts lr / num_iters_decay but only
            # runs every lr_update_step iterations, so the total decrease is
            # lr / lr_update_step rather than a decay to zero. Confirm whether
            # a full linear decay was intended before changing this schedule.
            if (i+1) % self.lr_update_step == 0 and (i+1) > (self.num_iters - self.num_iters_decay):
                g_lr -= (self.g_lr / float(self.num_iters_decay))
                d_lr -= (self.d_lr / float(self.num_iters_decay))
                self.update_lr(g_lr, d_lr)
                print ('Decayed learning rates, g_lr: {}, d_lr: {}.'.format(g_lr, d_lr))

    def test(self):
        """Translate images using StarGAN trained on a single dataset.

        Restores the checkpoint of step ``test_iters`` and, for every batch of
        the loader, saves the input next to one translation per attribute.
        """
        # Load the trained generator.
        self.restore_model(self.test_iters)

        # Set data loader.
        data_loader = self.celeba_loader

        with torch.no_grad():
            for i, (x_real, c_org) in enumerate(data_loader):

                # Prepare input images and target domain labels.
                x_real = x_real.to(self.device)
                c_trg_list = self.create_labels(c_org, self.c_dim, self.selected_attrs)

                # Translate images.
                x_fake_list = [x_real]
                for c_trg in c_trg_list:
                    x_fake_list.append(self.G(x_real, c_trg))

                # Save the translated images.
                x_concat = torch.cat(x_fake_list, dim=3)
                result_path = os.path.join(self.result_dir, '{}-images.jpg'.format(i+1))
                save_image(self.denorm(x_concat.data.cpu()), result_path, nrow=1, padding=0)
                print('Saved real and fake images into {}...'.format(result_path))

## Main Codes

In [8]:
from torch.backends import cudnn

# Let cuDNN benchmark the available convolution algorithms and pick the
# fastest one for the input sizes it sees.
cudnn.benchmark = True

# Build the CelebA loader for the configured mode.
celeba_loader = get_loader(config.celeba_image_dir, config.attr_path, config.selected_attrs,
                           config.celeba_crop_size, config.image_size, config.batch_size,
                           'CelebA', config.mode, config.num_workers)

solver = Solver(celeba_loader, config)

# Run the entry point that matches the configured mode.
if config.mode == 'train':
    solver.train()
elif config.mode == 'test':
    solver.test()

Finished preprocessing the CelebA dataset...
Generator(
  (main): Sequential(
    (0): Conv2d(8, 64, kernel_size=(7, 7), stride=(1, 1), padding=(3, 3), bias=False)
    (1): InstanceNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace)
    (3): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (4): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU(inplace)
    (6): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (7): InstanceNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (8): ReLU(inplace)
    (9): ResidualBlock(
      (main): Sequential(
        (0): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (1): InstanceNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU(inplace)
        (3): Conv2d(256, 256, kernel_size=(3,



Elapsed [0:00:04], Iteration [10/200000], D/loss_real: -24.0860, D/loss_fake: 5.0567, D/loss_cls: 4.2179, D/loss_gp: 0.0656, G/loss_fake: -3.8394, G/loss_rec: 0.5249, G/loss_cls: 3.0116
Elapsed [0:00:05], Iteration [20/200000], D/loss_real: -38.6063, D/loss_fake: 6.1324, D/loss_cls: 4.6678, D/loss_gp: 0.7411, G/loss_fake: -4.1855, G/loss_rec: 0.5080, G/loss_cls: 2.8512
Elapsed [0:00:07], Iteration [30/200000], D/loss_real: -30.2640, D/loss_fake: 8.8760, D/loss_cls: 3.9003, D/loss_gp: 0.2229, G/loss_fake: -9.5772, G/loss_rec: 0.4262, G/loss_cls: 2.8190
Elapsed [0:00:09], Iteration [40/200000], D/loss_real: -17.4891, D/loss_fake: 3.3436, D/loss_cls: 3.6712, D/loss_gp: 0.1782, G/loss_fake: 1.3596, G/loss_rec: 0.3969, G/loss_cls: 3.1906
Elapsed [0:00:11], Iteration [50/200000], D/loss_real: -27.2563, D/loss_fake: 15.7733, D/loss_cls: 3.4486, D/loss_gp: 0.3871, G/loss_fake: -7.6021, G/loss_rec: 0.3785, G/loss_cls: 3.2814


KeyboardInterrupt: 

## Pretrained Weight Download

StarGAN은 multi-domain 상황에서 그 contribution이 잘 드러나는 모델이므로, 이전 실습들처럼 MNIST와 같은 작은 데이터만으로는 진행할 수 없다. 또한 학습에 필요한 iteration 수가 많아 실습 시간 내에 결과를 보기 어렵다. 따라서 이번 실습에서는 위 train code가 정상적으로 동작하는 것만 확인하고, 이후 과정은 pretrained weight를 사용해 진행한다.

Pretrained Weight는 보통 pt, pth, ckpt 등의 확장자를 가지고 있는 파일로, 이전에 트레이닝 해둔 모델의 weight 값들을 모델의 텐서 형태에 맞추어 저장하고 있다.

이미 학습된 모델은 다방면으로 활용할 수 있다. (대표적으로 pretrained model을 이용한 transfer learning이 있으며, 이는 딥러닝의 거의 모든 분야에서 쓰인다.)

!bash download.sh pretrained-celeba-128x128

In [9]:
# Same settings as the training run, switched to test mode so the pretrained
# checkpoint of step `test_iters` is loaded and applied.
config = easydict.EasyDict(dict(

    # Model configuration
    c_dim=5,                # dimension of domain labels
    celeba_crop_size=178,   # crop size for the CelebA dataset
    image_size=128,         # image resolution
    g_conv_dim=64,          # number of conv filters in the first layer of G
    d_conv_dim=64,          # number of conv filters in the first layer of D
    g_repeat_num=6,         # number of residual blocks in G
    d_repeat_num=6,         # number of strided conv layers in D
    lambda_cls=1.0,         # weight for domain classification loss
    lambda_rec=10.0,        # weight for reconstruction loss
    lambda_gp=10.0,         # weight for gradient penalty

    # Training configuration
    # dataset='CelebA' (disabled: only CelebA is handled in this notebook)
    batch_size=16,          # mini-batch size
    num_iters=200000,       # number of total iterations for training D
    num_iters_decay=100000, # number of iterations for decaying lr
    g_lr=0.0001,            # learning rate for G
    d_lr=0.0001,            # learning rate for D
    n_critic=5,             # number of D updates per each G update
    beta1=0.5,              # beta1 for Adam optimizer
    beta2=0.999,            # beta2 for Adam optimizer
    resume_iters=None,      # resume training from this step
    selected_attrs=['Black_Hair', 'Blond_Hair', 'Brown_Hair', 'Male', 'Young'],  # selected attributes for the CelebA dataset

    # Test configuration
    test_iters=200000,      # test model from this step

    # Miscellaneous
    num_workers=1,          # num_workers value passed to the DataLoader
    mode='test',            # train or test mode

    # Directories
    celeba_image_dir='data/celeba/images',         # directory of the CelebA images
    attr_path='data/celeba/list_attr_celeba.txt',  # CelebA attribute annotation file

    # log_dir='stargan_celeba/logs' (disabled)
    model_save_dir='stargan_celeba/models',   # where model checkpoints are written / read
    sample_dir='stargan_celeba/samples',      # where training-time sample images are written
    result_dir='stargan_celeba/results',      # where test-time translated images are written

    # Step size
    log_step=10,            # iterations between log lines
    sample_step=1000,       # iterations between sample image dumps
    model_save_step=10000,  # iterations between checkpoints
    lr_update_step=1000,    # iterations between learning-rate decay steps
))

In [10]:
from torch.backends import cudnn

# Let cuDNN benchmark the available convolution algorithms and pick the
# fastest one for the input sizes it sees.
cudnn.benchmark = True

# Build the CelebA loader for the configured mode.
celeba_loader = get_loader(
    image_dir=config.celeba_image_dir,
    attr_path=config.attr_path,
    selected_attrs=config.selected_attrs,
    crop_size=config.celeba_crop_size,
    image_size=config.image_size,
    batch_size=config.batch_size,
    dataset='CelebA',
    mode=config.mode,
    num_workers=config.num_workers,
)

solver = Solver(celeba_loader, config)

# Dispatch on the configured mode; any other mode value runs nothing.
entry_points = {'train': solver.train, 'test': solver.test}
run_selected_mode = entry_points.get(config.mode)
if run_selected_mode is not None:
    run_selected_mode()

Finished preprocessing the CelebA dataset...
Generator(
  (main): Sequential(
    (0): Conv2d(8, 64, kernel_size=(7, 7), stride=(1, 1), padding=(3, 3), bias=False)
    (1): InstanceNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace)
    (3): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (4): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU(inplace)
    (6): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (7): InstanceNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (8): ReLU(inplace)
    (9): ResidualBlock(
      (main): Sequential(
        (0): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (1): InstanceNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU(inplace)
        (3): Conv2d(256, 256, kernel_size=(3,

Saved real and fake images into stargan_celeba/results/45-images.jpg...
Saved real and fake images into stargan_celeba/results/46-images.jpg...
Saved real and fake images into stargan_celeba/results/47-images.jpg...
Saved real and fake images into stargan_celeba/results/48-images.jpg...
Saved real and fake images into stargan_celeba/results/49-images.jpg...
Saved real and fake images into stargan_celeba/results/50-images.jpg...
Saved real and fake images into stargan_celeba/results/51-images.jpg...
Saved real and fake images into stargan_celeba/results/52-images.jpg...
Saved real and fake images into stargan_celeba/results/53-images.jpg...
Saved real and fake images into stargan_celeba/results/54-images.jpg...
Saved real and fake images into stargan_celeba/results/55-images.jpg...
Saved real and fake images into stargan_celeba/results/56-images.jpg...
Saved real and fake images into stargan_celeba/results/57-images.jpg...
Saved real and fake images into stargan_celeba/results/58-images