# Single Image DR-GAN デモ

1. CFP 顔画像データセットの前処理，入力データの作成
2. Single Image DR-GANモデルの定義
3. 学習の定義
4. パラメータ指定，学習の実行
5. 学習結果の読み込み， 画像の生成

## 1. CFP 顔画像データセットの前処理，入力データの作成

In [9]:
import sys
sys.path.append('..')

In [10]:
import os
import glob
import numpy as np
from skimage import io, transform
from matplotlib import pylab as plt
%matplotlib inline
from tqdm import tqdm
import pdb

In [11]:
# CFP の画像を 長辺を指定した長さに， 短辺は 変換後に リサイズするクラス

class Resize(object):
    #  assume image  as H x W x C numpy array
    def __init__(self, output_size):
        assert isinstance(output_size, int)
        self.output_size = output_size
        
    def __call__(self, image):
        h, w = image.shape[:2]
        if h > w:
            new_h, new_w = self.output_size, int(self.output_size * w / h)
        else:
            new_h, new_w = int(self.output_size * h / w), self.output_size

        resized_image = transform.resize(image, (new_h, new_w))
        
        if h>w:
            diff = self.output_size - new_w
            if diff%2 == 0:
                pad_l = int(diff/2)
                pad_s = int(diff/2)
            else:
                pad_l = int(diff/2)+1
                pad_s = int(diff/2)

            padded_image = np.lib.pad(resized_image, ((0,0), (pad_l,pad_s), (0,0)), 'edge')

        else:
            diff = self.output_size - new_h
            if diff%2==0:
                pad_l = int(diff/2)
                pad_s = int(diff/2)
            else:
                pad_l = int(diff/2)+1
                pad_s = int(diff/2)

            padded_image = np.lib.pad(resized_image, ((pad_l,pad_s), (0,0),  (0,0)), 'edge')

        return padded_image

In [12]:
# 画像をロードし，長辺 110pix 短編 110pix になるようにエッジの画素値で padding

image_dir = "../cfp-dataset/Data/Images/"
rsz = Resize(110)

Indv_dir = []
for x in os.listdir(image_dir):
    if os.path.isdir(os.path.join(image_dir, x)):
        Indv_dir.append(x)
        
Indv_dir=np.sort(Indv_dir)

images = np.zeros((7000, 110, 110, 3))
id_labels = np.zeros(7000)
pose_labels = np.zeros(7000)
count = 0
gray_count = 0

for i in tqdm(range(len(Indv_dir))):
    Frontal_dir = os.path.join(image_dir, Indv_dir[i], 'frontal')
    Profile_dir = os.path.join(image_dir, Indv_dir[i], 'profile')
    
    front_img_files = os.listdir(Frontal_dir)
    prof_img_files = os.listdir(Profile_dir)
    
    for img_file in front_img_files:
        img = io.imread(os.path.join(Frontal_dir, img_file))
        if len(img.shape)==2:
            gray_count = gray_count+1
            continue
        img_rsz = rsz(img)
        images[count] = img_rsz
        id_labels[count] = i
        pose_labels[count] = 0
        count = count + 1
    
    for img_file in prof_img_files:
        img = io.imread(os.path.join(Profile_dir, img_file))
        if len(img.shape)==2:
            gray_count = gray_count+1
            continue
        img_rsz = rsz(img)
        images[count] = img_rsz
        id_labels[count] = i
        pose_labels[count] = 1
        count = count + 1
    
id_labels = id_labels.astype('int64')
pose_labels = pose_labels.astype('int64')

#[0,255] -> [-1,1]
images = images *2 - 1
# RGB -> BGR
images = images[:,:,:,[2,1,0]]
# B x H x W x C-> B x C x H x W
images = images.transpose(0, 3, 1, 2)

# 白黒画像データを取り除く
images = images[:gray_count*-1]
id_labels = id_labels[:gray_count*-1]
pose_labels = pose_labels[:gray_count*-1]
Np = int(pose_labels.max() + 1)
Nd = int(id_labels.max() + 1)
Nz = 50
channel_num = 3

100%|██████████| 500/500 [00:54<00:00,  9.13it/s]


## 2. Single-Image DRGANモデルの定義 (model/single_DR_GAN_model.py)


In [13]:
#!/usr/bin/env python
# encoding: utf-8

import torch
from torch import nn, optim
from torch.autograd import Variable
import pdb


class Discriminator(nn.Module):
    """
    multi-task CNN for identity and pose classification

    ### init
    Nd : Number of identitiy to classify
    Np : Number of pose to classify

    """

    def __init__(self, Nd, Np, channel_num):
        super(Discriminator, self).__init__()
        convLayers = [
            nn.Conv2d(channel_num, 32, 3, 1, 1, bias=False), # Bxchx96x96 -> Bx32x96x96
            nn.BatchNorm2d(32),
            nn.ELU(),
            nn.Conv2d(32, 64, 3, 1, 1, bias=False), # Bx32x96x96 -> Bx64x96x96
            nn.BatchNorm2d(64),
            nn.ELU(),
            nn.ZeroPad2d((0, 1, 0, 1)),                      # Bx64x96x96 -> Bx64x97x97
            nn.Conv2d(64, 64, 3, 2, 0, bias=False), # Bx64x97x97 -> Bx64x48x48
            nn.BatchNorm2d(64),
            nn.ELU(),
            nn.Conv2d(64, 64, 3, 1, 1, bias=False), # Bx64x48x48 -> Bx64x48x48
            nn.BatchNorm2d(64),
            nn.ELU(),
            nn.Conv2d(64, 128, 3, 1, 1, bias=False), # Bx64x48x48 -> Bx128x48x48
            nn.BatchNorm2d(128),
            nn.ELU(),
            nn.ZeroPad2d((0, 1, 0, 1)),                      # Bx128x48x48 -> Bx128x49x49
            nn.Conv2d(128, 128, 3, 2, 0, bias=False), #  Bx128x49x49 -> Bx128x24x24
            nn.BatchNorm2d(128),
            nn.ELU(),
            nn.Conv2d(128, 96, 3, 1, 1, bias=False), #  Bx128x24x24 -> Bx96x24x24
            nn.BatchNorm2d(96),
            nn.ELU(),
            nn.Conv2d(96, 192, 3, 1, 1, bias=False), #  Bx96x24x24 -> Bx192x24x24
            nn.BatchNorm2d(192),
            nn.ELU(),
            nn.ZeroPad2d((0, 1, 0, 1)),                      # Bx192x24x24 -> Bx192x25x25
            nn.Conv2d(192, 192, 3, 2, 0, bias=False), # Bx192x25x25 -> Bx192x12x12
            nn.BatchNorm2d(192),
            nn.ELU(),
            nn.Conv2d(192, 128, 3, 1, 1, bias=False), # Bx192x12x12 -> Bx128x12x12
            nn.BatchNorm2d(128),
            nn.ELU(),
            nn.Conv2d(128, 256, 3, 1, 1, bias=False), # Bx128x12x12 -> Bx256x12x12
            nn.BatchNorm2d(256),
            nn.ELU(),
            nn.ZeroPad2d((0, 1, 0, 1)),                      # Bx256x12x12 -> Bx256x13x13
            nn.Conv2d(256, 256, 3, 2, 0, bias=False),  # Bx256x13x13 -> Bx256x6x6
            nn.BatchNorm2d(256),
            nn.ELU(),
            nn.Conv2d(256, 160, 3, 1, 1, bias=False), # Bx256x6x6 -> Bx160x6x6
            nn.BatchNorm2d(160),
            nn.ELU(),
            nn.Conv2d(160, 320, 3, 1, 1, bias=False), # Bx160x6x6 -> Bx320x6x6
            nn.BatchNorm2d(320),
            nn.ELU(),
            nn.AvgPool2d(6, stride=1), #  Bx320x6x6 -> Bx320x1x1
        ]

        self.convLayers = nn.Sequential(*convLayers)
        self.fc = nn.Linear(320, Nd+1+Np)

        # 重みは全て N(0, 0.02) で初期化
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                m.weight.data.normal_(0, 0.02)

            elif isinstance(m, nn.Linear):
                m.weight.data.normal_(0, 0.02)

    def forward(self, input):
        # 畳み込み -> 平均プーリングの結果 B x 320 x 1 x 1の出力を得る
        x = self.convLayers(input)

        # バッチ数次元を消さないように１次元の次元を削除　
        x = x.squeeze(2)
        x = x.squeeze(2)

        # 全結合
        x = self.fc(x) # Bx320 -> B x (Nd+1+Np)

        return x

class Crop(nn.Module):
    """
    Generator でのアップサンプリング時に， ダウンサンプル時のZeroPad2d と逆の事をするための関数
    論文著者が Tensorflow で padding='SAME' オプションで自動的にパディングしているのを
    ダウンサンプル時にはZeroPad2dで，アップサンプリング時には Crop で実現

    ### init
    crop_list : データの上下左右をそれぞれどれくらい削るか指定
    """

    def __init__(self, crop_list):
        super(Crop, self).__init__()

        # crop_lsit = [crop_top, crop_bottom, crop_left, crop_right]
        self.crop_list = crop_list

    def forward(self, x):
        B,C,H,W = x.size()
        x = x[:,:, self.crop_list[0] : H - self.crop_list[1] , self.crop_list[2] : W - self.crop_list[3]]

        return x


class Generator(nn.Module):
    """
    Encoder/Decoder conditional GAN conditioned with pose vector and noise vector

    ### init
    Np : Dimension of pose vector (Corresponds to number of dicrete pose classes of the data)
    Nz : Dimension of noise vector

    """

    def __init__(self, Np, Nz, channel_num):
        super(Generator, self).__init__()
        self.features = []

        G_enc_convLayers = [
            nn.Conv2d(channel_num, 32, 3, 1, 1, bias=False), # Bx3x96x96 -> Bx32x96x96
            nn.BatchNorm2d(32),
            nn.ELU(),
            nn.Conv2d(32, 64, 3, 1, 1, bias=False), # Bx32x96x96 -> Bx64x96x96
            nn.BatchNorm2d(64),
            nn.ELU(),
            nn.ZeroPad2d((0, 1, 0, 1)),                      # Bx64x96x96 -> Bx64x97x97
            nn.Conv2d(64, 64, 3, 2, 0, bias=False), # Bx64x97x97 -> Bx64x48x48
            nn.BatchNorm2d(64),
            nn.ELU(),
            nn.Conv2d(64, 64, 3, 1, 1, bias=False), # Bx64x48x48 -> Bx64x48x48
            nn.BatchNorm2d(64),
            nn.ELU(),
            nn.Conv2d(64, 128, 3, 1, 1, bias=False), # Bx64x48x48 -> Bx128x48x48
            nn.BatchNorm2d(128),
            nn.ELU(),
            nn.ZeroPad2d((0, 1, 0, 1)),                      # Bx128x48x48 -> Bx128x49x49
            nn.Conv2d(128, 128, 3, 2, 0, bias=False), #  Bx128x49x49 -> Bx128x24x24
            nn.BatchNorm2d(128),
            nn.ELU(),
            nn.Conv2d(128, 96, 3, 1, 1, bias=False), #  Bx128x24x24 -> Bx96x24x24
            nn.BatchNorm2d(96),
            nn.ELU(),
            nn.Conv2d(96, 192, 3, 1, 1, bias=False), #  Bx96x24x24 -> Bx192x24x24
            nn.BatchNorm2d(192),
            nn.ELU(),
            nn.ZeroPad2d((0, 1, 0, 1)),                      # Bx192x24x24 -> Bx192x25x25
            nn.Conv2d(192, 192, 3, 2, 0, bias=False), # Bx192x25x25 -> Bx192x12x12
            nn.BatchNorm2d(192),
            nn.ELU(),
            nn.Conv2d(192, 128, 3, 1, 1, bias=False), # Bx192x12x12 -> Bx128x12x12
            nn.BatchNorm2d(128),
            nn.ELU(),
            nn.Conv2d(128, 256, 3, 1, 1, bias=False), # Bx128x12x12 -> Bx256x12x12
            nn.BatchNorm2d(256),
            nn.ELU(),
            nn.ZeroPad2d((0, 1, 0, 1)),                      # Bx256x12x12 -> Bx256x13x13
            nn.Conv2d(256, 256, 3, 2, 0, bias=False),  # Bx256x13x13 -> Bx256x6x6
            nn.BatchNorm2d(256),
            nn.ELU(),
            nn.Conv2d(256, 160, 3, 1, 1, bias=False), # Bx256x6x6 -> Bx160x6x6
            nn.BatchNorm2d(160),
            nn.ELU(),
            nn.Conv2d(160, 320, 3, 1, 1, bias=False), # Bx160x6x6 -> Bx320x6x6
            nn.BatchNorm2d(320),
            nn.ELU(),
            nn.AvgPool2d(6, stride=1), #  Bx320x6x6 -> Bx320x1x1

        ]
        self.G_enc_convLayers = nn.Sequential(*G_enc_convLayers)

        G_dec_convLayers = [
            nn.ConvTranspose2d(320,160, 3,1,1, bias=False), # Bx320x6x6 -> Bx160x6x6
            nn.BatchNorm2d(160),
            nn.ELU(),
            nn.ConvTranspose2d(160, 256, 3,1,1, bias=False), # Bx160x6x6 -> Bx256x6x6
            nn.BatchNorm2d(256),
            nn.ELU(),
            nn.ConvTranspose2d(256, 256, 3,2,0, bias=False), # Bx256x6x6 -> Bx256x13x13
            nn.BatchNorm2d(256),
            nn.ELU(),
            Crop([0, 1, 0, 1]),
            nn.ConvTranspose2d(256, 128, 3,1,1, bias=False), # Bx256x12x12 -> Bx128x12x12
            nn.BatchNorm2d(128),
            nn.ELU(),
            nn.ConvTranspose2d(128, 192,  3,1,1, bias=False), # Bx128x12x12 -> Bx192x12x12
            nn.BatchNorm2d(192),
            nn.ELU(),
            nn.ConvTranspose2d(192, 192,  3,2,0, bias=False), # Bx128x12x12 -> Bx192x25x25
            nn.BatchNorm2d(192),
            nn.ELU(),
            Crop([0, 1, 0, 1]),
            nn.ConvTranspose2d(192, 96,  3,1,1, bias=False), # Bx192x24x24 -> Bx96x24x24
            nn.BatchNorm2d(96),
            nn.ELU(),
            nn.ConvTranspose2d(96, 128,  3,1,1, bias=False), # Bx96x24x24 -> Bx128x24x24
            nn.BatchNorm2d(128),
            nn.ELU(),
            nn.ConvTranspose2d(128, 128,  3,2,0, bias=False), # Bx128x24x24 -> Bx128x49x49
            nn.BatchNorm2d(128),
            nn.ELU(),
            Crop([0, 1, 0, 1]),
            nn.ConvTranspose2d(128, 64,  3,1,1, bias=False), # Bx128x48x48 -> Bx64x48x48
            nn.BatchNorm2d(64),
            nn.ELU(),
            nn.ConvTranspose2d(64, 64,  3,1,1, bias=False), # Bx64x48x48 -> Bx64x48x48
            nn.BatchNorm2d(64),
            nn.ELU(),
            nn.ConvTranspose2d(64, 64,  3,2,0, bias=False), # Bx64x48x48 -> Bx64x97x97
            nn.BatchNorm2d(64),
            nn.ELU(),
            Crop([0, 1, 0, 1]),
            nn.ConvTranspose2d(64, 32,  3,1,1, bias=False), # Bx64x96x96 -> Bx32x96x96
            nn.BatchNorm2d(32),
            nn.ELU(),
            nn.ConvTranspose2d(32, channel_num,  3,1,1, bias=False), # Bx32x96x96 -> Bxchx96x96
            nn.Tanh(),
        ]

        self.G_dec_convLayers = nn.Sequential(*G_dec_convLayers)

        self.G_dec_fc = nn.Linear(320+Np+Nz, 320*6*6)

        # 重みは全て N(0, 0.02) で初期化
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                m.weight.data.normal_(0, 0.02)

            elif isinstance(m, nn.ConvTranspose2d):
                m.weight.data.normal_(0, 0.02)

            elif isinstance(m, nn.Linear):
                m.weight.data.normal_(0, 0.02)



    def forward(self, input, pose, noise):
        
        x = self.G_enc_convLayers(input) # Bxchx96x96 -> Bx320x1x1

        x = x.squeeze(2)
        x = x.squeeze(2)

        self.features = x

        x = torch.cat([x, pose, noise], 1)  # Bx320 -> B x (320+Np+Nz)

        x = self.G_dec_fc(x) # B x (320+Np+Nz) -> B x (320x6x6)

        x = x.view(-1, 320, 6, 6) # B x (320x6x6) -> B x 320 x 6 x 6

        x = self.G_dec_convLayers(x) #  B x 320 x 6 x 6 -> Bxchx96x96

        return x


## 3. 学習の定義 (train_single_DRGAN.py)

In [14]:
#!/usr/bin/env python
# encoding: utf-8

import os
import numpy as np
from scipy import misc
import pdb
import matplotlib as mpl
mpl.use('Agg')
import matplotlib.pyplot as plt
import torch
from torch import nn, optim
from torch.autograd import Variable
from torch.utils.data import DataLoader
from torchvision import transforms
from util.one_hot import one_hot
from util.Is_D_strong import Is_D_strong
from util.log_learning import log_learning
from util.convert_image import convert_image
from util.DataAugmentation import FaceIdPoseDataset, Resize, RandomCrop



def train_single_DRGAN(images, id_labels, pose_labels, Nd, Np, Nz, D_model, G_model, args):
    if args.cuda:
        D_model.cuda()
        G_model.cuda()

    D_model.train()
    G_model.train()

    lr_Adam    = args.lr
    beta1_Adam = args.beta1
    beta2_Adam = args.beta2

    image_size = images.shape[0]
    epoch_time = np.ceil(image_size / args.batch_size).astype(int)

    optimizer_D = optim.Adam(D_model.parameters(), lr = lr_Adam, betas=(beta1_Adam, beta2_Adam))
    optimizer_G = optim.Adam(G_model.parameters(), lr = lr_Adam, betas=(beta1_Adam, beta2_Adam))
    loss_criterion = nn.CrossEntropyLoss()
    loss_criterion_gan = nn.BCEWithLogitsLoss()

    loss_log = []
    steps = 0

    flag_D_strong  = False
    for epoch in range(1,args.epochs+1):

        # Load augmented data
        transformed_dataset = FaceIdPoseDataset(images, id_labels, pose_labels,
                                        transform = transforms.Compose([Resize((110,110)), RandomCrop((96,96))]))
        dataloader = DataLoader(transformed_dataset, batch_size = args.batch_size, shuffle=True)

        for i, batch_data in enumerate(dataloader):
            D_model.zero_grad()
            G_model.zero_grad()

            batch_image = torch.FloatTensor(batch_data[0].float())
            batch_id_label = batch_data[1]
            batch_pose_label = batch_data[2]
            minibatch_size = len(batch_image)

            batch_ones_label = torch.ones(minibatch_size)   # 真偽判別用のラベル
            batch_zeros_label = torch.zeros(minibatch_size)


            # ノイズと姿勢コードを生成
            fixed_noise = torch.FloatTensor(np.random.uniform(-1,1, (minibatch_size, Nz)))
            tmp  = torch.LongTensor(np.random.randint(Np, size=minibatch_size))
            pose_code = one_hot(tmp, Np) # Condition 付に使用
            pose_code_label = torch.LongTensor(tmp) # CrossEntropy 誤差に使用


            if args.cuda:
                batch_image, batch_id_label, batch_pose_label, batch_ones_label, batch_zeros_label = \
                    batch_image.cuda(), batch_id_label.cuda(), batch_pose_label.cuda(), batch_ones_label.cuda(), batch_zeros_label.cuda()

                fixed_noise, pose_code, pose_code_label = \
                    fixed_noise.cuda(), pose_code.cuda(), pose_code_label.cuda()

            batch_image, batch_id_label, batch_pose_label, batch_ones_label, batch_zeros_label = \
                Variable(batch_image), Variable(batch_id_label), Variable(batch_pose_label), Variable(batch_ones_label), Variable(batch_zeros_label)

            fixed_noise, pose_code, pose_code_label = \
                Variable(fixed_noise), Variable(pose_code), Variable(pose_code_label)

            # Generatorでイメージ生成
            generated = G_model(batch_image, pose_code, fixed_noise)

            steps += 1

            # バッチ毎に交互に D と G の学習，　Dが90%以上の精度の場合は 1:4の比率で学習
            if flag_D_strong:

                if i%5 == 0:
                    # Discriminator の学習
                    flag_D_strong = Learn_D(D_model, loss_criterion, loss_criterion_gan, optimizer_D, batch_image, generated, \
                                            batch_id_label, batch_pose_label, batch_ones_label, batch_zeros_label, epoch, steps, Nd, args)

                else:
                    # Generatorの学習
                    Learn_G(D_model, loss_criterion, loss_criterion_gan, optimizer_G ,generated,\
                            batch_id_label, batch_ones_label, pose_code_label, epoch, steps, Nd, args)
            else:

                if i%2==0:
                    # Discriminator の学習
                    flag_D_strong = Learn_D(D_model, loss_criterion, loss_criterion_gan, optimizer_D, batch_image, generated, \
                                            batch_id_label, batch_pose_label, batch_ones_label, batch_zeros_label, epoch, steps, Nd, args)

                else:
                    # Generatorの学習
                    Learn_G(D_model, loss_criterion, loss_criterion_gan, optimizer_G ,generated, \
                            batch_id_label, batch_ones_label, pose_code_label, epoch, steps, Nd, args)


        if epoch%args.save_freq == 0:
            # 各エポックで学習したモデルを保存
            if not os.path.isdir(args.save_dir): os.makedirs(args.save_dir)
            save_path_D = os.path.join(args.save_dir,'epoch{}_D.pt'.format(epoch))
            torch.save(D_model, save_path_D)
            save_path_G = os.path.join(args.save_dir,'epoch{}_G.pt'.format(epoch))
            torch.save(G_model, save_path_G)
            # 最後のエポックの学習前に生成した画像を１枚保存（学習の確認用）
            save_generated_image = convert_image(generated[0].cpu().data.numpy())
            save_path_image = os.path.join(args.save_dir, 'epoch{}_generatedimage.jpg'.format(epoch))
            misc.imsave(save_path_image, save_generated_image.astype(np.uint8))



def Learn_D(D_model, loss_criterion, loss_criterion_gan, optimizer_D, batch_image, generated, \
            batch_id_label, batch_pose_label, batch_ones_label, batch_zeros_label, epoch, steps, Nd, args):

    real_output = D_model(batch_image)
    syn_output = D_model(generated.detach()) # .detach() をすることで Generatorまでの逆伝播計算省略

    # id,真偽, pose それぞれのロスを計算
    L_id    = loss_criterion(real_output[:, :Nd], batch_id_label)
    L_gan   = loss_criterion_gan(real_output[:, Nd], batch_ones_label) + loss_criterion_gan(syn_output[:, Nd], batch_zeros_label)
    L_pose  = loss_criterion(real_output[:, Nd+1:], batch_pose_label)

    d_loss = L_gan + L_id + L_pose

    d_loss.backward()
    optimizer_D.step()
    log_learning(epoch, steps, 'D', d_loss.data[0], args)

    # Discriminator の強さを判別
    flag_D_strong = Is_D_strong(real_output, syn_output, batch_id_label, batch_pose_label, Nd)

    return flag_D_strong



def Learn_G(D_model, loss_criterion, loss_criterion_gan, optimizer_G ,generated, \
            batch_id_label, batch_ones_label, pose_code_label, epoch, steps, Nd, args):

    syn_output=D_model(generated)

    # id についての出力と元画像のラベル, 真偽, poseについての出力と生成時に与えたposeコード の ロスを計算
    L_id    = loss_criterion(syn_output[:, :Nd], batch_id_label)
    L_gan   = loss_criterion_gan(syn_output[:, Nd], batch_ones_label)
    L_pose  = loss_criterion(syn_output[:, Nd+1:], pose_code_label)

    g_loss = L_gan + L_id + L_pose

    g_loss.backward()
    optimizer_G.step()
    log_learning(epoch, steps, 'G', g_loss.data[0], args)


## 4. パラメータ指定，学習の実行 (main.py)

In [None]:
#!/usr/bin/env python
# encoding: utf-8

import os
import argparse
import datetime
import numpy as np
import torch
from torch import nn, optim
from torch.autograd import Variable
from model import single_DR_GAN_model as single_model
from model import multiple_DR_GAN_model as multi_model
from util.create_randomdata import create_randomdata
from train_single_DRGAN import train_single_DRGAN
from train_multiple_DRGAN import train_multiple_DRGAN
from Generate_Image import Generate_Image
import pdb

import easydict
args = easydict.EasyDict({
    "lr": 0.0002,
    "beta1": 0.5,
    "beta2": 0.999,
    "epochs": 1000,
    "batch_size": 64,
    "save_dir": 'snapshot',
    "save_freq": 1,
    "cuda": False,
})

# update args and print
args.save_dir = os.path.join(args.save_dir, 'Single',datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S'))

os.makedirs(args.save_dir)

print("Parameters:")
for attr, value in sorted(args.__dict__.items()):
    text ="\t{}={}\n".format(attr.upper(), value)
    print(text)
    with open('{}/Parameters.txt'.format(args.save_dir),'a') as f:
        f.write(text)

# Define model
D = single_model.Discriminator(Nd, Np, channel_num)
G = single_model.Generator(Np, Nz, channel_num)

# Train model
train_single_DRGAN(images, id_labels, pose_labels, Nd, Np, Nz, D, G, args)


Parameters:
	BATCH_SIZE=64

	BETA1=0.5

	BETA2=0.999

	CUDA=False

	EPOCHS=1000

	LR=0.0002

	SAVE_DIR=snapshot/Single/2019-11-15_03-44-06

	SAVE_FREQ=1

EPOCH : 1, step : 1, D : 8.307655334472656
EPOCH : 1, step : 2, G : 7.594618797302246
EPOCH : 1, step : 3, D : 8.253284454345703
EPOCH : 1, step : 4, G : 7.582085609436035
EPOCH : 1, step : 5, D : 8.263090133666992
EPOCH : 1, step : 6, G : 7.533958435058594
EPOCH : 1, step : 7, D : 8.260887145996094
EPOCH : 1, step : 8, G : 7.553605556488037
EPOCH : 1, step : 9, D : 8.144478797912598
EPOCH : 1, step : 10, G : 7.513525009155273
EPOCH : 1, step : 11, D : 8.19669246673584
EPOCH : 1, step : 12, G : 7.496384620666504
EPOCH : 1, step : 13, D : 8.145478248596191
EPOCH : 1, step : 14, G : 7.540350437164307
EPOCH : 1, step : 15, D : 8.320524215698242
EPOCH : 1, step : 16, G : 7.629148960113525
EPOCH : 1, step : 17, D : 8.062918663024902
EPOCH : 1, step : 18, G : 7.447247505187988
EPOCH : 1, step : 19, D : 8.099799156188965
EPOCH : 1, step : 20

  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "


EPOCH : 2, step : 111, D : 7.59882926940918
EPOCH : 2, step : 112, G : 7.023697853088379
EPOCH : 2, step : 113, D : 7.624639511108398
EPOCH : 2, step : 114, G : 6.989665985107422
EPOCH : 2, step : 115, D : 7.670737266540527
EPOCH : 2, step : 116, G : 7.01533842086792
EPOCH : 2, step : 117, D : 7.754487991333008
EPOCH : 2, step : 118, G : 6.966924667358398
EPOCH : 2, step : 119, D : 7.677892208099365
EPOCH : 2, step : 120, G : 7.015960216522217
EPOCH : 2, step : 121, D : 7.6473164558410645
EPOCH : 2, step : 122, G : 6.941712856292725
EPOCH : 2, step : 123, D : 7.7122483253479
EPOCH : 2, step : 124, G : 7.088972091674805
EPOCH : 2, step : 125, D : 7.565810203552246
EPOCH : 2, step : 126, G : 6.940740585327148
EPOCH : 2, step : 127, D : 7.620291709899902
EPOCH : 2, step : 128, G : 6.900816440582275
EPOCH : 2, step : 129, D : 7.576470851898193
EPOCH : 2, step : 130, G : 6.973528861999512
EPOCH : 2, step : 131, D : 7.575857639312744
EPOCH : 2, step : 132, G : 6.942293167114258
EPOCH : 2, st

EPOCH : 3, step : 294, G : 7.1057448387146
EPOCH : 3, step : 295, D : 6.9707746505737305
EPOCH : 3, step : 296, G : 7.13057279586792
EPOCH : 3, step : 297, D : 7.112574577331543
EPOCH : 3, step : 298, G : 7.185234546661377
EPOCH : 3, step : 299, D : 7.038104057312012
EPOCH : 3, step : 300, G : 7.12141752243042
EPOCH : 3, step : 301, D : 6.946699142456055
EPOCH : 3, step : 302, G : 7.1052045822143555
EPOCH : 3, step : 303, D : 7.052834987640381
EPOCH : 3, step : 304, G : 7.172644138336182
EPOCH : 3, step : 305, D : 7.073223114013672
EPOCH : 3, step : 306, G : 7.399919033050537
EPOCH : 3, step : 307, D : 6.97652530670166
EPOCH : 3, step : 308, G : 7.168344020843506
EPOCH : 3, step : 309, D : 6.971804618835449
EPOCH : 3, step : 310, G : 7.189900875091553
EPOCH : 3, step : 311, D : 6.961858749389648
EPOCH : 3, step : 312, G : 7.19865608215332
EPOCH : 3, step : 313, D : 7.0014119148254395
EPOCH : 3, step : 314, G : 7.19551944732666
EPOCH : 3, step : 315, D : 6.955749034881592
EPOCH : 3, ste

EPOCH : 5, step : 478, G : 7.211894512176514
EPOCH : 5, step : 479, D : 6.682022571563721
EPOCH : 5, step : 480, G : 7.23668909072876
EPOCH : 5, step : 481, D : 6.778932094573975
EPOCH : 5, step : 482, G : 7.169198989868164
EPOCH : 5, step : 483, D : 6.781635284423828
EPOCH : 5, step : 484, G : 7.182280540466309
EPOCH : 5, step : 485, D : 6.661975383758545
EPOCH : 5, step : 486, G : 7.153035640716553
EPOCH : 5, step : 487, D : 6.615057468414307
EPOCH : 5, step : 488, G : 7.30815315246582
EPOCH : 5, step : 489, D : 6.507323265075684
EPOCH : 5, step : 490, G : 7.206066131591797
EPOCH : 5, step : 491, D : 6.710513114929199
EPOCH : 5, step : 492, G : 7.195493698120117
EPOCH : 5, step : 493, D : 6.5861382484436035
EPOCH : 5, step : 494, G : 7.179818153381348
EPOCH : 5, step : 495, D : 6.610709190368652
EPOCH : 5, step : 496, G : 7.12639045715332
EPOCH : 5, step : 497, D : 6.634424686431885
EPOCH : 5, step : 498, G : 7.158304214477539
EPOCH : 5, step : 499, D : 6.714442253112793
EPOCH : 5, s

EPOCH : 7, step : 661, D : 6.79098653793335
EPOCH : 7, step : 662, G : 6.992258071899414
EPOCH : 7, step : 663, D : 6.583299160003662
EPOCH : 7, step : 664, G : 6.991307258605957
EPOCH : 7, step : 665, D : 6.671349048614502
EPOCH : 7, step : 666, G : 6.880144119262695
EPOCH : 7, step : 667, D : 6.72780704498291
EPOCH : 7, step : 668, G : 6.901692867279053
EPOCH : 7, step : 669, D : 6.670871257781982
EPOCH : 7, step : 670, G : 6.972987174987793
EPOCH : 7, step : 671, D : 6.487018585205078
EPOCH : 7, step : 672, G : 7.0492730140686035
EPOCH : 7, step : 673, D : 6.6367011070251465
EPOCH : 7, step : 674, G : 7.142626762390137
EPOCH : 7, step : 675, D : 6.4650959968566895
EPOCH : 7, step : 676, G : 7.014721870422363
EPOCH : 7, step : 677, D : 6.475342750549316
EPOCH : 7, step : 678, G : 7.09406042098999
EPOCH : 7, step : 679, D : 6.575969696044922
EPOCH : 7, step : 680, G : 6.941701412200928
EPOCH : 7, step : 681, D : 6.604556083679199
EPOCH : 7, step : 682, G : 7.110828876495361
EPOCH : 7,

EPOCH : 8, step : 844, G : 6.840269565582275
EPOCH : 8, step : 845, D : 6.407989025115967
EPOCH : 8, step : 846, G : 6.873669147491455
EPOCH : 8, step : 847, D : 6.6183180809021
EPOCH : 8, step : 848, G : 7.121105194091797
EPOCH : 8, step : 849, D : 6.41027307510376
EPOCH : 8, step : 850, G : 6.903271675109863
EPOCH : 8, step : 851, D : 6.490921974182129
EPOCH : 8, step : 852, G : 6.938608646392822
EPOCH : 8, step : 853, D : 6.432162284851074
EPOCH : 8, step : 854, G : 7.072232246398926
EPOCH : 8, step : 855, D : 6.397902965545654
EPOCH : 8, step : 856, G : 7.071691513061523
EPOCH : 8, step : 857, D : 6.385356426239014
EPOCH : 8, step : 858, G : 7.109584331512451
EPOCH : 8, step : 859, D : 6.304523944854736
EPOCH : 8, step : 860, G : 7.104043483734131
EPOCH : 8, step : 861, D : 6.405430793762207
EPOCH : 8, step : 862, G : 6.693262100219727
EPOCH : 8, step : 863, D : 6.713984966278076
EPOCH : 8, step : 864, G : 6.513049602508545
EPOCH : 8, step : 865, D : 6.656358242034912
EPOCH : 8, st

EPOCH : 10, step : 1025, D : 6.152266502380371
EPOCH : 10, step : 1026, G : 7.216885566711426
EPOCH : 10, step : 1027, D : 6.1125006675720215
EPOCH : 10, step : 1028, G : 7.027321815490723
EPOCH : 10, step : 1029, D : 6.064948558807373
EPOCH : 10, step : 1030, G : 7.001389503479004
EPOCH : 10, step : 1031, D : 5.868066310882568
EPOCH : 10, step : 1032, G : 7.198128700256348
EPOCH : 10, step : 1033, D : 6.150226593017578
EPOCH : 10, step : 1034, G : 7.49027156829834
EPOCH : 10, step : 1035, D : 6.661647319793701
EPOCH : 10, step : 1036, G : 6.733563423156738
EPOCH : 10, step : 1037, D : 6.45571231842041
EPOCH : 10, step : 1038, G : 6.7764129638671875
EPOCH : 10, step : 1039, D : 6.177024841308594
EPOCH : 10, step : 1040, G : 6.920872211456299
EPOCH : 10, step : 1041, D : 6.241705894470215
EPOCH : 10, step : 1042, G : 7.134250164031982
EPOCH : 10, step : 1043, D : 6.1908650398254395
EPOCH : 10, step : 1044, G : 6.944128513336182
EPOCH : 10, step : 1045, D : 6.12034797668457
EPOCH : 10, s

EPOCH : 11, step : 1200, G : 6.571808815002441
EPOCH : 11, step : 1201, D : 6.6071457862854
EPOCH : 11, step : 1202, G : 7.169093608856201
EPOCH : 11, step : 1203, D : 6.567954063415527
EPOCH : 11, step : 1204, G : 6.556551933288574
EPOCH : 11, step : 1205, D : 6.401123046875
EPOCH : 11, step : 1206, G : 6.726420879364014
EPOCH : 11, step : 1207, D : 6.201216697692871
EPOCH : 11, step : 1208, G : 6.8428144454956055
EPOCH : 11, step : 1209, D : 6.041626453399658
EPOCH : 11, step : 1210, G : 6.564837455749512
EPOCH : 12, step : 1211, D : 6.212036609649658
EPOCH : 12, step : 1212, G : 7.115494251251221
EPOCH : 12, step : 1213, D : 6.083163261413574
EPOCH : 12, step : 1214, G : 7.176898002624512
EPOCH : 12, step : 1215, D : 5.940834999084473
EPOCH : 12, step : 1216, G : 6.817367076873779
EPOCH : 12, step : 1217, D : 6.130165100097656
EPOCH : 12, step : 1218, G : 7.156674385070801
EPOCH : 12, step : 1219, D : 5.963457107543945
EPOCH : 12, step : 1220, G : 7.013247966766357
EPOCH : 12, step 

EPOCH : 13, step : 1375, D : 5.87406063079834
EPOCH : 13, step : 1376, G : 7.163100242614746
EPOCH : 13, step : 1377, D : 5.8632659912109375
EPOCH : 13, step : 1378, G : 6.59572696685791
EPOCH : 13, step : 1379, D : 6.078794002532959
EPOCH : 13, step : 1380, G : 7.13106632232666
EPOCH : 13, step : 1381, D : 5.974859714508057
EPOCH : 13, step : 1382, G : 6.669900417327881
EPOCH : 13, step : 1383, D : 5.763886451721191
EPOCH : 13, step : 1384, G : 6.8882317543029785
EPOCH : 13, step : 1385, D : 5.918089389801025
EPOCH : 13, step : 1386, G : 6.926290988922119
EPOCH : 13, step : 1387, D : 5.883264064788818
EPOCH : 13, step : 1388, G : 6.954230785369873
EPOCH : 13, step : 1389, D : 5.874432563781738
EPOCH : 13, step : 1390, G : 7.001430034637451
EPOCH : 13, step : 1391, D : 5.852739334106445
EPOCH : 13, step : 1392, G : 7.098095417022705
EPOCH : 13, step : 1393, D : 5.946601867675781
EPOCH : 13, step : 1394, G : 7.021731853485107
EPOCH : 13, step : 1395, D : 5.792816638946533
EPOCH : 13, st

EPOCH : 15, step : 1550, G : 5.931161403656006
EPOCH : 15, step : 1551, D : 6.355254173278809
EPOCH : 15, step : 1552, G : 6.323249340057373
EPOCH : 15, step : 1553, D : 5.907946586608887
EPOCH : 15, step : 1554, G : 6.855928421020508
EPOCH : 15, step : 1555, D : 5.879683971405029
EPOCH : 15, step : 1556, G : 6.626267910003662
EPOCH : 15, step : 1557, D : 5.710831165313721
EPOCH : 15, step : 1558, G : 6.514534950256348
EPOCH : 15, step : 1559, D : 5.831459045410156
EPOCH : 15, step : 1560, G : 6.986136436462402
EPOCH : 15, step : 1561, D : 5.6182684898376465
EPOCH : 15, step : 1562, G : 6.940507411956787
EPOCH : 15, step : 1563, D : 5.692627906799316
EPOCH : 15, step : 1564, G : 6.598626613616943
EPOCH : 15, step : 1565, D : 5.752634048461914
EPOCH : 15, step : 1566, G : 7.064855575561523
EPOCH : 15, step : 1567, D : 5.615880012512207
EPOCH : 15, step : 1568, G : 6.8637590408325195
EPOCH : 15, step : 1569, D : 5.761837482452393
EPOCH : 15, step : 1570, G : 6.8672051429748535
EPOCH : 15

EPOCH : 16, step : 1724, G : 6.683197498321533
EPOCH : 16, step : 1725, D : 5.623957633972168
EPOCH : 16, step : 1726, G : 6.5950751304626465
EPOCH : 16, step : 1727, D : 5.707012176513672
EPOCH : 16, step : 1728, G : 7.0923566818237305
EPOCH : 16, step : 1729, D : 5.916419506072998
EPOCH : 16, step : 1730, G : 6.466928482055664
EPOCH : 16, step : 1731, D : 5.799925804138184
EPOCH : 16, step : 1732, G : 6.790423393249512
EPOCH : 16, step : 1733, D : 5.503425598144531
EPOCH : 16, step : 1734, G : 6.9385600090026855
EPOCH : 16, step : 1735, D : 5.53016471862793
EPOCH : 16, step : 1736, G : 6.460559844970703
EPOCH : 16, step : 1737, D : 5.843725681304932
EPOCH : 16, step : 1738, G : 7.305351257324219
EPOCH : 16, step : 1739, D : 5.825404167175293
EPOCH : 16, step : 1740, G : 6.639400005340576
EPOCH : 16, step : 1741, D : 5.806029319763184
EPOCH : 16, step : 1742, G : 7.14340353012085
EPOCH : 16, step : 1743, D : 5.580953598022461
EPOCH : 16, step : 1744, G : 6.86839485168457
EPOCH : 16, s

EPOCH : 18, step : 1899, D : 5.493350982666016
EPOCH : 18, step : 1900, G : 6.513060092926025
EPOCH : 18, step : 1901, D : 5.811444282531738
EPOCH : 18, step : 1902, G : 6.968311786651611
EPOCH : 18, step : 1903, D : 5.721766948699951
EPOCH : 18, step : 1904, G : 6.639192581176758
EPOCH : 18, step : 1905, D : 5.5512495040893555
EPOCH : 18, step : 1906, G : 6.510802268981934
EPOCH : 18, step : 1907, D : 5.560451984405518
EPOCH : 18, step : 1908, G : 6.562515735626221
EPOCH : 18, step : 1909, D : 5.573428630828857
EPOCH : 18, step : 1910, G : 6.315622806549072
EPOCH : 18, step : 1911, D : 5.861034393310547
EPOCH : 18, step : 1912, G : 7.420148849487305
EPOCH : 18, step : 1913, D : 6.302074909210205
EPOCH : 18, step : 1914, G : 7.230831146240234
EPOCH : 18, step : 1915, D : 5.684798240661621
EPOCH : 18, step : 1916, G : 5.868161678314209
EPOCH : 18, step : 1917, D : 5.971778869628906
EPOCH : 18, step : 1918, G : 6.66243314743042
EPOCH : 18, step : 1919, D : 5.381469249725342
EPOCH : 18, s

In [23]:
print(images.shape)
print(id_labels.shape)

(6987, 3, 110, 110)
(6987,)


## 5. 学習結果の読み込み， 画像の生成

In [24]:
import os
import numpy as np
import pdb
import matplotlib.pyplot as plt
import torch
from torch import nn, optim
from torch.autograd import Variable

#  DR-GAN の Generator を用いる画像生成関数
def Generate_Image(images, pose_number, Np,Nz, G_model, args):
    """
    Generate_Image with learned Generator

    ### input
    images      : source images
    pose_number : integer which specify pose to generate image from source image
    Nz          : size of noise vecotr
    G_model     : learned Generator
    args        : options


    """
    if args.cuda:
        G_model.cuda()
        
    G_model.eval()

    features = []
    
    batch_size = images.shape[0]
    pose_code = np.zeros([batch_size, Np])
    pose_code[:, pose_number] = 1
    batch_image = torch.FloatTensor(images)
    batch_pose_code = torch.FloatTensor(pose_code) # Condition 付に使用
    fixed_noise = torch.FloatTensor(np.random.uniform(-1,1, (batch_size, Nz)))

    batch_image, fixed_noise, batch_pose_code = \
        batch_image.cuda(), fixed_noise.cuda(), batch_pose_code.cuda()

    batch_image, fixed_noise, batch_pose_code = \
         Variable(batch_image), Variable(fixed_noise), Variable(batch_pose_code)

    # Generatorでイメージ生成
    generated = G_model(batch_image, batch_pose_code, fixed_noise)
    features.append(G_model.features)

    features = torch.cat(features)
    
    return convert_image(generated.data.cpu().numpy())

def convert_image(data):

    img = data.transpose(0, 2, 3, 1)+1
    img = img / 2.0
    img = img * 255.
    img = img[:,:,:,[2,1,0]]
    
    return img.astype(np.uint8)

In [25]:
G = torch.load('./epoch780_G.pt')
jpg_image = convert_image(images)



In [28]:
def recursion_change_bn(module):
    if isinstance(module, torch.nn.BatchNorm2d):
        module.track_running_stats = 1
    else:
        for i, (name, module1) in enumerate(module._modules.items()):
            module1 = recursion_change_bn(module1)
    return module

# check_point = torch.load(check_point_file_path)
model = G
for i, (name, module) in enumerate(model._modules.items()):
    module = recursion_change_bn(model)
# model.eval()

In [30]:
model.dump_patches = True

In [31]:
# 同一ポーズで複数の顔写真を生成テスト

import easydict

Np = 2
n = 4
pose = 0
image_list = np.random.randint(0,6900, (1,n))[0]
args = easydict.EasyDict({
    "cuda": False,
})
generated_image = Generate_Image(images[image_list], pose, Np, 50, G, args)

plt.rcParams['figure.figsize'] = (15.0, 15.0)
for i in range(n):
    plt.subplot(2, n, i+1)
    plt.title('No.:{}, id:{}, pose:{}'.format(image_list[i], id_labels[image_list[i]], pose_labels[image_list[i]]))
    plt.imshow(jpg_image[image_list[i]])
    plt.subplot(2, n, n+i+1)
    plt.imshow(generated_image[i])

axes = plt.gcf().get_axes()
for ax in axes:
    ax.tick_params(labelbottom="off",bottom="off") # x軸の削除
    ax.tick_params(labelleft="off",left="off") # y軸の削除
    ax.set_xticklabels([]) 
    
plt.subplots_adjust(left=None, bottom=None, right=None, top=None, wspace=0, hspace=0)
plt.show()


AttributeError: 'ZeroPad2d' object has no attribute 'value'

In [32]:
# 同一写真で複数のポーズを生成テスト

Np = 2
n = 1
image_list = np.random.randint(0,6900, (1,n))[0]


generated_image = []

for i in range(Np):
        pose = i
        generated_image_pose = Generate_Image(images[image_list], pose, Np,  50, G , args)
        generated_image.append(generated_image_pose)

plt.rcParams['figure.figsize'] = (15.0, 15.0)
for i in range(Np):
    plt.subplot(2, Np, i+1)
    plt.title('No.:{}, pose:{}'.format(image_list[0], pose_labels[image_list[0]]))
    plt.imshow(jpg_image[image_list[0]])
    
    plt.subplot(2, Np, i+1+Np)
    plt.title('pose:{}'.format(i))
    plt.imshow(generated_image[i].squeeze())

axes = plt.gcf().get_axes()
for ax in axes:
    ax.tick_params(labelbottom="off",bottom="off") # x軸の削除
    ax.tick_params(labelleft="off",left="off") # y軸の削除
    ax.set_xticklabels([]) 

plt.subplots_adjust(left=None, bottom=None, right=None, top=None, wspace=0, hspace=0)
fig = plt.gcf()
fig.tight_layout()
plt.show()


AttributeError: 'ZeroPad2d' object has no attribute 'value'