In [1]:
import torch
import torch.nn as nn
import parser
import numpy as np
from torch.autograd import Variable
from model_utils import *
from image_utils import *
from networks import *
from dataloader import *

In [2]:
import os, time, pickle
import torch
import torch.utils.data as data
import torchvision.transforms as transforms
from image_utils import is_image, load_img
from torch.utils.data import DataLoader

In [3]:
# def argument():
#     parser = argparse.ArgumentParser(description='pix2pix-PyTorch-implementation')
#     parser.add_argument('--dataset', required=True, default='facades')
#     parser.add_argument('--trainBatchSize', type=int, default=1, help='training batch size')
#     parser.add_argument('--testBatchSize', type=int, default=1, help='testing batch size')
#     parser.add_argument('--maxepochs', type=int, default=200, help='number of epochs')
#     parser.add_argument('--ch_in', type=int, default=3, help='input image channels')
#     parser.add_argument('--ch_out', type=int, default=3, help='output image channels')
#     parser.add_argument('--ngf', type=int, default=64, help='generator filters')
#     parser.add_argument('--ndf', type=int, default=64, help='discriminator filters')
#     parser.add_argument('--lrD', type=float, default=0.0002, help='Learning Rate D')
#     parser.add_argument('--lrG', type=float, default=0.0002, help='Learning Rate G')
#     parser.add_argument('--beta1', type=float, default=0.5, help='beta1 for adam. default=0.5')
#     parser.add_argument('--cuda', action='store_true', default=True, help='use cuda?')
#     parser.add_argument('--threads', type=int, default=4, help='number of threads for data loader to use')
#     parser.add_argument('--seed', type=int, default=123, help='random seed to use. Default=123')
#     parser.add_argument('--lamb', type=int, default=10, help='weight on L1 term in objective')
#     parser.add_argument('--save_dir', type=str, default='result')
#     parser.add_argument('--model', type=str, default='pix2pix')
#     parser.add_argument('--model_type_D', type=str, default='n_layer')
#     parser.add_argument('--model_type_G', type=str, default='resnet_9blocks')
#     parser.add_argument('--norm_type', type=str, default='instance')
#     parser.add_argument('--use_dropout', type=bool, default='False')
#     parser.add_argument('--init_type', type=str, default='normal')
    
#     opt = parser.parse_args()
#     return opt

In [4]:
class arguments():
    """Notebook stand-in for the (commented-out) argparse CLI.

    Every hyper-parameter becomes a plain instance attribute so the object can be
    passed wherever an ``argparse.Namespace`` would be.
    """

    def __init__(self):
        options = dict(
            # data / batching
            dataset='facades',
            trainBatchSize=1,
            testBatchSize=1,
            maxepoch=200,
            # image channels and network widths
            ch_in=3,
            ch_out=3,
            ngf=64,
            ndf=64,
            # optimisation
            lrD=1e-6,
            lrG=1e-5,
            beta1=0.5,
            # runtime
            cuda=True,
            threads=4,
            seed=123,
            # weight on the L1 reconstruction term
            lamb=10,
            # output location and model selection
            save_dir='result',
            model='pix2pix',
            model_type_D='n_layer',
            model_type_G='resnet_9blocks',
            norm_type='instance',
            use_dropout=False,
            init_type='normal',
            use_sigmoid=True,
        )
        for name, value in options.items():
            setattr(self, name, value)


args = arguments()

In [5]:
class pix2pix(object):
    """pix2pix trainer: generator G, conditional discriminator D, data loaders and losses.

    ``args`` must provide the attributes set by ``arguments`` (ch_in, ch_out, ngf, ndf,
    model_type_D/G, norm_type, use_dropout, init_type, use_sigmoid, lrD, lrG, beta1,
    trainBatchSize, testBatchSize, threads, seed, maxepoch, lamb, save_dir, model, dataset).
    An optional ``args.data_dir`` overrides ``DEFAULT_DATA_DIR``.
    Tensors and losses are moved to CUDA unconditionally.
    """

    # Dataset root used when ``args`` has no ``data_dir`` attribute.
    DEFAULT_DATA_DIR = '/data/jehyuk/imgdata/datasets/facades/'

    def __init__(self, args):
        self.args = args
        seed = getattr(args, 'seed', None)
        if seed is not None:
            # Seed before the networks are built so weight initialisation is reproducible.
            torch.manual_seed(seed)
            np.random.seed(seed)

        # D is conditional: it scores (input, output) pairs stacked along the channel
        # axis, so it needs ch_in + ch_out input channels.
        # NOTE: D checkpoints saved when D was built with ch_in channels will not load.
        self.D = define_D(args.ch_in + args.ch_out, args.ndf, args.model_type_D, 3, args.norm_type,
                          args.use_sigmoid, args.init_type, [0])
        self.G = define_G(args.ch_in, args.ch_out, args.ngf, args.model_type_G, args.norm_type,
                          args.use_dropout, args.init_type, [0])
        betas = (args.beta1, 0.999)
        self.optim_D = torch.optim.Adam(params=self.D.parameters(), lr=args.lrD, betas=betas)
        self.optim_G = torch.optim.Adam(params=self.G.parameters(), lr=args.lrG, betas=betas)

        data_dir = getattr(args, 'data_dir', self.DEFAULT_DATA_DIR)
        trn_dset, tst_dset = get_train_set(data_dir), get_test_set(data_dir)
        # Fixed test pair rendered after every epoch to visualise progress.
        self.sampleA = tst_dset[0][0].unsqueeze(dim=0).cuda()
        self.sampleB = tst_dset[0][1].unsqueeze(dim=0).cuda()
        self.trn_loader = DataLoader(dataset=trn_dset, num_workers=args.threads,
                                     batch_size=args.trainBatchSize, shuffle=True)
        self.tst_loader = DataLoader(dataset=tst_dset, num_workers=args.threads,
                                     batch_size=args.testBatchSize, shuffle=False)

        self.bce_loss, self.mse_loss, self.L1_loss = nn.BCELoss().cuda(), nn.MSELoss().cuda(), nn.L1Loss().cuda()
        # NOTE(review): GANLoss is presumably GANLoss(use_lsgan, target_is_real), inferred
        # from the attribute names below -- confirm against networks.py.
        self.bce_true_gan_loss, self.bce_fake_gan_loss = GANLoss(False, True).cuda(), GANLoss(False, False).cuda()
        self.ls_true_gan_loss, self.ls_fake_gan_loss = GANLoss(True, True).cuda(), GANLoss(True, False).cuda()
        # Use ONE adversarial objective for both D terms and for G: BCE when D ends in a
        # sigmoid, least-squares otherwise.
        if args.use_sigmoid:
            self.true_gan_loss, self.fake_gan_loss = self.bce_true_gan_loss, self.bce_fake_gan_loss
        else:
            self.true_gan_loss, self.fake_gan_loss = self.ls_true_gan_loss, self.ls_fake_gan_loss

        # Per-iteration loss history (Python floats); also lets save_model run before train.
        self.loss_dict = {'G_loss': [], 'D_fake_loss': [], 'D_real_loss': []}

    @staticmethod
    def _set_requires_grad(net, flag):
        """Enable or disable gradient tracking for every parameter of ``net``."""
        for param in net.parameters():
            param.requires_grad = flag

    def _exp_dir(self, root):
        """Return ``root/<model>/<dataset>/ngf_.._ndf_.._lambda_.._norm_..`` for this config."""
        exp_config = "ngf_{}_ndf_{}_lambda_{}_norm_{}".format(
            self.args.ngf, self.args.ndf, self.args.lamb, self.args.norm_type)
        return os.path.join(root, self.args.model, self.args.dataset, exp_config)

    def train(self):
        """Alternate D and G updates for ``args.maxepoch`` epochs.

        Writes one sample image per epoch, then saves the models.
        Per-iteration losses are appended as floats to ``self.loss_dict``.
        """
        self.loss_dict = {'G_loss': [], 'D_fake_loss': [], 'D_real_loss': []}
        print('------------------Start training------------------')
        for epoch in range(self.args.maxepoch):
            # save_results() puts G in eval mode, so re-enter train mode every epoch.
            self.G.train()
            self.D.train()
            print(">>>>Epoch: {}".format(epoch+1))
            start_time = time.time()
            n_batches = 0
            for batch in self.trn_loader:
                n_batches += 1
                real_a, real_b = batch[0].cuda(), batch[1].cuda()
                fake_b = self.G(real_a)

                ###### Train D ######
                self._set_requires_grad(self.D, True)
                self.optim_D.zero_grad()
                # Pairs are concatenated on dim 1, assuming NCHW batches (channel axis).
                fake_ab = torch.cat((real_a, fake_b), 1)
                # detach(): D's loss must not backpropagate into G.
                D_fake_loss = self.fake_gan_loss(self.D(fake_ab.detach()))
                real_ab = torch.cat((real_a, real_b), 1)
                D_real_loss = self.true_gan_loss(self.D(real_ab))
                D_loss = 0.5 * (D_fake_loss + D_real_loss)
                D_loss.backward()
                self.optim_D.step()

                ###### Train G ######
                # Freeze D so G's backward pass does not accumulate (unused) D gradients.
                self._set_requires_grad(self.D, False)
                self.optim_G.zero_grad()
                # G tries to fool D: its generated pair is scored against the REAL target.
                G_fake_loss = self.true_gan_loss(self.D(fake_ab))
                # Reconstruction term: G(A) should match B.
                L1_loss = self.L1_loss(fake_b, real_b) * self.args.lamb
                G_loss = G_fake_loss + L1_loss
                G_loss.backward()
                self.optim_G.step()

                # Store floats so the history holds no autograd graph / GPU memory.
                self.loss_dict['G_loss'].append(G_loss.item())
                self.loss_dict['D_fake_loss'].append(D_fake_loss.item())
                self.loss_dict['D_real_loss'].append(D_real_loss.item())

            if n_batches == 0:
                print("In epoch: {}, no training batches".format(epoch+1))
            else:
                print("In epoch: {}, D_fake_loss: {:.4f}, D_real_loss: {:.4f}, G_loss: {:.4f}, time: {:.1f}s".format(
                    epoch+1, self.loss_dict['D_fake_loss'][-1], self.loss_dict['D_real_loss'][-1],
                    self.loss_dict['G_loss'][-1], time.time() - start_time))
            self.save_results(epoch, self.sampleA, self.sampleB)
        # Leave D trainable for any later use of the model.
        self._set_requires_grad(self.D, True)
        self.save_model()

    def save_results(self, epoch, realA, realB):
        """Save [realA; G(realA); realB], stacked along dim 2, as ``epoch%03d.png``.

        The file goes under ``<save_dir>/logs/...``.
        """
        result_dir = self._exp_dir(os.path.join(self.args.save_dir, 'logs'))
        os.makedirs(result_dir, exist_ok=True)
        fake_filename = os.path.join(result_dir, 'epoch%03d.png' % epoch)

        was_training = self.G.training
        self.G.eval()
        with torch.no_grad():  # inference only: no graph needed
            fakeB = self.G(realA)
        self.G.train(was_training)

        results = torch.cat((realA, fakeB, realB), 2)
        save_image(tensor2img(results), fake_filename)

    def save_model(self):
        """Save G/D state_dicts (``G.pkl``/``D.pkl``) and the pickled loss history.

        Files are written under ``<save_dir>/<model>/<dataset>/<config>``.
        """
        model_dir = self._exp_dir(self.args.save_dir)
        os.makedirs(model_dir, exist_ok=True)
        torch.save(self.G.state_dict(), os.path.join(model_dir, 'G.pkl'))
        torch.save(self.D.state_dict(), os.path.join(model_dir, 'D.pkl'))
        with open(os.path.join(model_dir, 'loss_dict'), 'wb') as f:
            pickle.dump(self.loss_dict, f)

    def load_model(self):
        """Load G/D state_dicts written by ``save_model`` for the current config."""
        model_dir = self._exp_dir(self.args.save_dir)
        self.G.load_state_dict(torch.load(os.path.join(model_dir, 'G.pkl')))
        self.D.load_state_dict(torch.load(os.path.join(model_dir, 'D.pkl')))

        

In [None]:
# Build G/D, optimizers, data loaders and losses from the hyper-parameters above.
model = pix2pix(args=args)

In [None]:
# Run the full training loop: prints per-epoch losses, writes a sample image each epoch, saves the models at the end.
model.train()

------------------Start training------------------
>>>>Epoch: 1
