In [None]:
import sys
import argparse
import os.path
import random
import time
import datetime

import numpy as np

import torch
from torch import nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
import torchvision.utils as vutils
from PIL import Image

# 教師データの準備

In [None]:
# Download the facades archive from the pix2pix dataset server.
!wget http://efrosgans.eecs.berkeley.edu/pix2pix/datasets/facades.tar.gz
# Make sure the extraction target exists (-p: create parents, no error if present).
!mkdir -p ./datasets/facades
# Unpack the archive; -C switches to ./datasets/ before extracting.
# NOTE(review): presumably the archive unpacks into a facades/ folder, which is
# what data_dir = 'datasets/facades' expects -- confirm against the archive layout.
!tar -zxvf facades.tar.gz -C ./datasets/
# Remove the downloaded archive once it has been extracted.
!rm facades.tar.gz

## 画像の下準備

In [None]:
class generateDataset:
    """Map-style dataset of side-by-side image pairs for pix2pix training.

    Every image file found under ``<params.data_dir>/<params.phase>`` is treated
    as one picture holding two images next to each other. ``__getitem__``
    splits the picture down the middle and applies the *same* resize, random
    crop and random flip to both halves so the pair stays pixel-aligned.

    ``params`` must provide ``data_dir``, ``phase``, ``load_num`` (side length
    after resizing) and ``crop_size`` (side length after cropping).
    """

    # Recognised image suffixes; matching is case-insensitive.
    pic_extention = [".png", ".jpg"]

    def __init__(self, params):
        self.params = params
        image_dir = os.path.join(params.data_dir, params.phase)
        # Sorted so that an index always refers to the same file.
        self.AB_paths = sorted(self.make_dataset(image_dir))

    @classmethod
    def is_image_file(cls, fname):
        """Return True when ``fname`` ends with a known image suffix (any case)."""
        suffixes = tuple(ext.lower() for ext in cls.pic_extention)
        return fname.lower().endswith(suffixes)

    @classmethod
    def make_dataset(cls, dir):
        """Recursively collect the paths of all image files below ``dir``.

        Raises:
            AssertionError: if ``dir`` is not an existing directory.
        """
        assert os.path.isdir(dir), '%s is not a valid directory' % dir

        return [
            os.path.join(root, fname)
            for root, _, fnames in sorted(os.walk(dir))
            for fname in fnames
            if cls.is_image_file(fname)
        ]

    def __transform(self, param):
        """Build the PIL -> normalised tensor pipeline for one sample.

        ``param`` is the dict produced by ``__transform_param``; its crop
        position and flip flag are baked into the returned transform so that
        both halves of a pair receive identical augmentation.
        """
        load_num = self.params.load_num
        crop_size = self.params.crop_size
        left, top = param['crop_pos']

        steps = [
            transforms.Resize([load_num, load_num], Image.BICUBIC),
            transforms.Lambda(
                lambda img: img.crop((left, top, left + crop_size, top + crop_size))),
        ]

        if param['flip']:
            steps.append(transforms.Lambda(
                lambda img: img.transpose(Image.FLIP_LEFT_RIGHT)))

        # ToTensor yields values in [0, 1]; Normalize(0.5, 0.5) maps them to [-1, 1].
        steps += [transforms.ToTensor(),
                  transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]

        return transforms.Compose(steps)

    def __transform_param(self):
        """Draw the random crop offset and flip decision for one sample."""
        # Largest valid top-left offset; clamped so crop_size > load_num
        # cannot produce a negative range.
        max_offset = max(0, self.params.load_num - self.params.crop_size)
        x = random.randint(0, max_offset)
        y = random.randint(0, max_offset)

        flip = random.random() > 0.5

        return {'crop_pos': (x, y), 'flip': flip}

    def __getitem__(self, index):
        """Load picture ``index`` and return its two halves as tensors.

        Returns a dict in which ``'A'`` is the transformed RIGHT half and
        ``'B'`` the transformed LEFT half of the picture; ``'A_paths'`` and
        ``'B_paths'`` both hold the path of the source file.
        """
        AB_path = self.AB_paths[index]
        # Close the file handle deterministically; convert() returns an
        # independent in-memory copy that stays valid afterwards.
        with Image.open(AB_path) as picture:
            AB = picture.convert('RGB')

        param = self.__transform_param()
        w, h = AB.size
        w2 = w // 2

        transform = self.__transform(param)
        left_half = transform(AB.crop((0, 0, w2, h)))
        right_half = transform(AB.crop((w2, 0, w, h)))

        return {'A': right_half, 'B': left_half,
                'A_paths': AB_path, 'B_paths': AB_path}

    def __len__(self):
        """Number of picture files found."""
        return len(self.AB_paths)

## 識別器の定義

### 実効的な受容野を７０とする様に設計。

In [None]:
class Discriminator(nn.Module):
    """Convolutional discriminator scoring a channel-stacked image pair.

    The network is fully convolutional and ends in a 1-channel convolution
    without a sigmoid, so ``forward`` returns a map of raw logits rather than
    a single score.
    """

    def __init__(self):
        super().__init__()
        # Two 3-channel images are stacked along the channel axis, hence 6 inputs.
        stem = [
            nn.Conv2d(6, 64, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2, True),
        ]
        trunk = [
            self.__nest_layer(64, 128),
            self.__nest_layer(128, 256),
            # Stride 1 here: only the kernel trims the spatial size.
            self.__nest_layer(256, 512, stride=1),
        ]
        head = [nn.Conv2d(512, 1, kernel_size=4, stride=1, padding=1)]
        self.model = nn.Sequential(*stem, *trunk, *head)

    @staticmethod
    def __nest_layer(in_chn, out_chn, stride=2):
        """Conv(4x4) -> BatchNorm -> LeakyReLU(0.2); stride 2 halves the size."""
        conv = nn.Conv2d(in_chn, out_chn, kernel_size=4, stride=stride, padding=1)
        norm = nn.BatchNorm2d(out_chn)
        act = nn.LeakyReLU(0.2, True)
        return nn.Sequential(conv, norm, act)

    def forward(self, x):
        """Return the logit map for the stacked pair ``x``."""
        return self.model(x)

## 生成器の定義

In [None]:
class Generator(nn.Module):
    """Encoder-decoder generator with skip connections (U-Net layout).

    Eight stride-2 convolutions (``down0`` .. ``down7``) are followed by eight
    stride-2 transposed convolutions (``up7`` .. ``up0``). From ``up6`` on, each
    decoder stage receives the previous decoder output concatenated with the
    encoder output of the same level. The output passes through ``tanh``.

    Note: all activations are created with ``inplace=True``. The leading
    activation of an encoder stage therefore overwrites the tensor it is given,
    so the skip tensors concatenated later are the activated ones.
    """

    def __init__(self):
        super().__init__()

        # ---- encoder: down0 is a bare convolution, down7 has no BatchNorm ----
        self.down0 = nn.Conv2d(3, 64, kernel_size=4, stride=2, padding=1)
        encoder_plan = [
            (64, 128), (128, 256), (256, 512),
            (512, 512), (512, 512), (512, 512),
        ]
        for level, (c_in, c_out) in enumerate(encoder_plan, start=1):
            setattr(self, 'down%d' % level, self.__encode_layer(c_in, c_out))
        self.down7 = self.__encode_layer(512, 512, with_norm=False)

        # ---- decoder: inputs double from up6 on because of the skip concat ----
        self.up7 = self.__decode_layer(512, 512)
        decoder_plan = [
            # (level, in, out, dropout)
            (6, 1024, 512, True),
            (5, 1024, 512, True),
            (4, 1024, 512, True),
            (3, 1024, 256, False),
            (2, 512, 128, False),
            (1, 256, 64, False),
        ]
        for level, c_in, c_out, use_dropout in decoder_plan:
            stage = self.__decode_layer(c_in, c_out, with_dropout=use_dropout)
            setattr(self, 'up%d' % level, stage)
        self.up0 = nn.Sequential(
            self.__decode_layer(128, 3, with_norm=False),
            nn.Tanh(),
        )

    # Encoder stage
    def __encode_layer(self, in_chn, out_chn, with_norm=True):
        """LeakyReLU -> stride-2 conv (halves the size) -> optional BatchNorm."""
        stage = nn.Sequential(
            nn.LeakyReLU(0.2, True),
            nn.Conv2d(in_chn, out_chn, kernel_size=4, stride=2, padding=1),
        )
        if with_norm:
            stage.append(nn.BatchNorm2d(out_chn))
        return stage

    # Decoder stage
    def __decode_layer(self, in_chn, out_chn, with_norm=True, with_dropout=False):
        """ReLU -> stride-2 transposed conv -> optional BatchNorm / Dropout(0.5)."""
        stage = nn.Sequential(
            nn.ReLU(True),
            nn.ConvTranspose2d(in_chn, out_chn, kernel_size=4, stride=2, padding=1),
        )
        if with_norm:
            stage.append(nn.BatchNorm2d(out_chn))
        if with_dropout:
            stage.append(nn.Dropout(0.5))
        return stage

    def forward(self, x):
        """Run the encoder, then the decoder with skip connections."""
        skips = []
        hidden = x
        for level in range(8):
            hidden = getattr(self, 'down%d' % level)(hidden)
            skips.append(hidden)

        # The innermost feature map feeds up7 directly (no skip at this level).
        hidden = self.up7(skips.pop())
        for level in range(6, -1, -1):
            hidden = getattr(self, 'up%d' % level)(self.cat(skips.pop(), hidden))

        return hidden

    def cat(self, x, y):
        """Concatenate ``x`` and ``y`` along the channel dimension."""
        return torch.cat([x, y], dim=1)

## 損失関数の導入

In [None]:
class pix2pix_Loss(nn.Module):
    """Adversarial loss: binary cross-entropy on raw logits against a constant label.

    The labels 1.0 ("real") and 0.0 ("fake") are stored as buffers so they
    follow the module across ``.to(device)`` calls.
    """

    def __init__(self):
        super().__init__()
        self.loss = nn.BCEWithLogitsLoss()
        self.register_buffer("real", torch.tensor(1.0))
        self.register_buffer("fake", torch.tensor(0.0))

    # Implemented as forward() rather than by overriding __call__, so that
    # nn.Module.__call__ keeps running registered hooks. Calling the module
    # as ``loss(prediction, is_real)`` works exactly as before.
    def forward(self, prediction, is_real):
        """Return the BCE-with-logits loss of ``prediction`` against one label.

        Args:
            prediction: tensor of raw discriminator logits (any shape).
            is_real: truthy to compare against 1.0, falsy to compare against 0.0.

        Returns:
            Scalar loss tensor (mean over all elements of ``prediction``).
        """
        target_tensor = self.real if is_real else self.fake
        # expand_as broadcasts the scalar label without allocating a full tensor.
        return self.loss(prediction, target_tensor.expand_as(prediction))

# モデルの構築

In [None]:
class Pix2pix:
    """Bundles generator, discriminator, optimizers, schedulers and losses.

    ``params`` must provide ``device``, ``device_name``, ``path2generator``,
    ``path2discriminator``, ``lambda_L1``, ``epochs_lr_decay``,
    ``epochs_lr_decay_start`` and ``output_dir``.
    """

    def __init__(self, params):
        self.params = params

        # Generator
        self.Gnet = Generator().to(self.params.device)
        self.Gnet.apply(self.weights_init)
        if self.params.path2generator is not None:
            # NOTE: torch.load unpickles the file -- only load checkpoints you trust.
            self.Gnet.load_state_dict(torch.load(self.params.path2generator,
                                                 map_location=self.params.device_name), strict=False)
        # Discriminator
        self.Dnet = Discriminator().to(self.params.device)
        self.Dnet.apply(self.weights_init)
        if self.params.path2discriminator is not None:
            # NOTE: torch.load unpickles the file -- only load checkpoints you trust.
            self.Dnet.load_state_dict(torch.load(self.params.path2discriminator,
                                                 map_location=self.params.device_name), strict=False)
        # Optimizers
        self.optimizerG = optim.Adam(self.Gnet.parameters(),
                                     lr=0.00002, betas=(0.5, 0.999))
        self.optimizerD = optim.Adam(self.Dnet.parameters(),
                                     lr=0.00002, betas=(0.5, 0.999))

        # Losses
        self.loss_pixel = pix2pix_Loss().to(self.params.device)
        self.loss_L1 = nn.L1Loss()

        # Learning-rate schedulers (the factor is computed by __change_lr)
        self.schedulerG = optim.lr_scheduler.LambdaLR(self.optimizerG,
                                                      self.__change_lr)
        self.schedulerD = optim.lr_scheduler.LambdaLR(self.optimizerD,
                                                      self.__change_lr)

    def update_lr(self):
        """Advance both learning-rate schedulers by one step."""
        self.schedulerG.step()
        self.schedulerD.step()

    def __change_lr(self, epoch):
        """Return the multiplicative LR factor for ``epoch``.

        The factor is 1.0 until ``epochs_lr_decay_start`` and then falls
        linearly to 0.0 over ``epochs_lr_decay`` epochs. Decay is disabled
        (factor 1.0) when the start is negative or the decay length is not
        positive; the latter used to raise ZeroDivisionError.
        """
        decay_start = self.params.epochs_lr_decay_start
        decay_span = self.params.epochs_lr_decay
        if decay_start < 0 or decay_span <= 0:
            return 1.
        delta = max(0, epoch - decay_start) / float(decay_span)
        return max(0.0, 1.0 - delta)

    def weights_init(self, m):
        """Initialiser for ``Module.apply``: N(0, 0.02) conv weights,
        N(1, 0.02) BatchNorm weights with zero bias."""
        classname = m.__class__.__name__
        if 'Conv' in classname:
            m.weight.data.normal_(0.0, 0.02)
        elif 'BatchNorm' in classname:
            m.weight.data.normal_(1.0, 0.02)
            m.bias.data.fill_(0)

    @staticmethod
    def _set_requires_grad(net, flag):
        """Switch gradient tracking for every parameter of ``net``."""
        for p in net.parameters():
            p.requires_grad_(flag)

    def train(self, data, batches_done):
        """Run one discriminator update followed by one generator update.

        Args:
            data: batch dict with tensors under ``"A"`` (generator input) and
                ``"B"`` (target image).
            batches_done: global batch counter; accepted for interface
                compatibility and not used.

        Returns:
            ``(lossD, lossG)`` as detached scalar tensors.
        """
        self.realA = data["A"].to(self.params.device)
        self.realB = data["B"].to(self.params.device)

        # realA --> fakeB
        fakeB = self.Gnet(self.realA)
        # Kept for save_image(); detached so the stored copy does not pin the graph.
        self.fakeB = fakeB.detach()

        # ---- discriminator update ----
        self._set_requires_grad(self.Dnet, True)

        # Loss on the (real input, generated output) pair; detach() keeps this
        # step from back-propagating into the generator.
        real_fake = torch.cat((self.realA, fakeB), dim=1)
        prediction_fake = self.Dnet(real_fake.detach())
        loss_fake = self.loss_pixel(prediction_fake, False)

        # Loss on the (real input, real target) pair.
        real_real = torch.cat((self.realA, self.realB), dim=1)
        prediction_real = self.Dnet(real_real)
        loss_real = self.loss_pixel(prediction_real, True)

        lossD = 0.5 * (loss_real + loss_fake)

        self.optimizerD.zero_grad()
        lossD.backward()
        self.optimizerD.step()

        # ---- generator update ----
        # The adversarial term must back-propagate THROUGH the discriminator
        # into the generator. Running this forward pass under torch.no_grad()
        # (as before) cut the graph, so the generator was trained on L1 only.
        # Instead, freeze the discriminator's parameters and keep the graph.
        self._set_requires_grad(self.Dnet, False)
        try:
            prediction_fake = self.Dnet(real_fake)
            lossG_pix = self.loss_pixel(prediction_fake, True)
            lossG_L1 = self.loss_L1(fakeB, self.realB)
            lossG = lossG_pix + lossG_L1 * self.params.lambda_L1

            self.optimizerG.zero_grad()
            lossG.backward()
            self.optimizerG.step()
        finally:
            # Leave the discriminator trainable even if the step failed.
            self._set_requires_grad(self.Dnet, True)

        return lossD.detach(), lossG.detach()

    def save_model(self, epoch):
        """Write both state dicts to ``params.output_dir``, tagged with ``epoch``."""
        output_dir = self.params.output_dir
        os.makedirs(output_dir, exist_ok=True)
        torch.save(self.Gnet.state_dict(), '{}/pix2pix_G_epoch_{}'.format(output_dir, epoch))
        torch.save(self.Dnet.state_dict(), '{}/pix2pix_D_epoch_{}'.format(output_dir, epoch))

    def save_image(self, epoch):
        """Save input, generated and target images of the last batch side by side.

        Requires ``train`` to have been called at least once.
        """
        os.makedirs(self.params.output_dir, exist_ok=True)
        # dim=3 is the width axis, so the three images sit next to each other.
        output_image = torch.cat([self.realA, self.fakeB, self.realB], dim=3)
        vutils.save_image(output_image,
                          '{}/pix2pix_epoch_{}.png'.format(self.params.output_dir, epoch),
                          normalize=True)



## 細かい準備

In [None]:
# Create the output directory if it is missing. Pure Python instead of the
# `!mkdir -p` shell escape, so the cell also works where no POSIX shell exists.
os.makedirs('output', exist_ok=True)

In [None]:
import json
def save_json(file, save_path, mode):
    """Serialize ``file`` as indented JSON into the file at ``save_path``.

    Args:
        file: JSON-serializable object to write.
        save_path: destination path. (Previously ignored: the function opened
            the global ``param_save_path`` instead.)
        mode: mode string handed to ``open`` (e.g. ``'w'``).
    """
    with open(save_path, mode) as outfile:
        json.dump(file, outfile, indent=4)

## パラメータの設定

In [None]:
class Parameter_set:
    """Run settings and hyper-parameters for the notebook.

    ``dict()`` exports every setting except the ``torch.device`` object, so
    the result can be written out as JSON.
    """

    def __init__(self):
        # --- schedule ---
        self.epochs = 100

        # --- intervals (in epochs) ---
        self.save_data_interval = 10
        self.save_image_interval = 1
        self.sample_interval = 10

        # --- data geometry ---
        self.batch_num = 64
        self.load_num = 286
        self.crop_size = 256

        # --- locations ---
        # NOTE(review): `cpu` does not influence the device chosen below -- confirm intent.
        self.cpu = True
        self.data_dir = 'datasets/facades'
        self.output_dir = 'output'
        self.phase = 'train'

        # --- optimisation ---
        self.lambda_L1 = 100.
        self.epochs_lr_decay = 0
        self.epochs_lr_decay_start = -1

        # --- optional checkpoints ---
        self.path2generator = None
        self.path2discriminator = None

        # --- device: first CUDA device when available, otherwise the CPU ---
        if torch.cuda.is_available():
            self.device_name = 'cuda:0'
        else:
            self.device_name = 'cpu'
        self.device = torch.device(self.device_name)

    def dict(self):
        """Return the settings as a plain dict, in a fixed key order."""
        exported = (
            'epochs',
            'save_data_interval',
            'save_image_interval',
            'sample_interval',
            'batch_num',
            'load_num',
            'crop_size',
            'cpu',
            'data_dir',
            'output_dir',
            'phase',
            'lambda_L1',
            'epochs_lr_decay',
            'epochs_lr_decay_start',
            'path2generator',
            'path2discriminator',
            'device_name',
        )
        return {key: getattr(self, key) for key in exported}

In [None]:
# Build the run configuration and store a JSON copy of it in the output folder.
params_set = Parameter_set()
param_save_path = os.path.join('output', 'param.json')
save_json(file=params_set.dict(), save_path=param_save_path, mode='w')

## モデルの準備

In [None]:
# Build generator, discriminator, optimizers and schedulers from the settings.
model = Pix2pix(params=params_set)

In [None]:
# Wrap the paired-image dataset in a shuffling loader.
dataset = generateDataset(params=params_set)
dataloader = DataLoader(
    dataset=dataset,
    batch_size=params_set.batch_num,
    shuffle=True,
)

# 学習

In [None]:
# Main training loop: one discriminator and one generator update per batch,
# periodic checkpoints / sample images, then a scheduler step per epoch.
for epoch in range(1, params_set.epochs + 1):
    # Number of batches processed in all earlier epochs.
    epoch_offset = (epoch - 1) * len(dataloader)

    for batch_num, data in enumerate(dataloader):
        batches_done = epoch_offset + batch_num
        loss1, loss2 = model.train(data, batches_done)

        # Log only every 20th batch.
        if batch_num % 20 != 0:
            continue
        print("Current Epoch {} :  Loss_D: {:.4f} Loss_G: {:.4f}".format(epoch, loss1, loss2))

    checkpoint_due = epoch % params_set.save_data_interval == 0
    if checkpoint_due:
        model.save_model(epoch)

    image_due = epoch % params_set.save_image_interval == 0
    if image_due:
        model.save_image(epoch)

    model.update_lr()