In [1]:
import time, pickle, argparse, network, utils, itertools
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
from torchvision import transforms
from torch.autograd import Variable
import test
import train

In [2]:
def _str2bool(value):
    """Convert a command-line token to a real boolean.

    ``argparse`` applies ``type`` to the raw string, and ``bool('False')`` is
    ``True`` because any non-empty string is truthy.  With ``type=bool`` the
    boolean options below could therefore never be switched off with a
    non-empty value.  This converter accepts the usual spellings instead.

    Args:
        value: the raw option string (or an actual ``bool``).

    Returns:
        ``True`` or ``False``.

    Raises:
        argparse.ArgumentTypeError: if the text is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    text = value.strip().lower()
    if text in ('true', 't', 'yes', 'y', '1'):
        return True
    # The empty string is kept as False because bool('') was already False.
    if text in ('false', 'f', 'no', 'n', '0', ''):
        return False
    raise argparse.ArgumentTypeError('boolean value expected, got %r' % (value,))


# Every hyper-parameter of the run is declared here so that the notebook can be
# re-configured from a single place (see the parse_args call at the bottom).
parser = argparse.ArgumentParser()
parser.add_argument('--dataset', required=False, default='apple2orange',  help='the name of train set')
parser.add_argument('--train_subfolder', required=False, default='train',  help='the subfolder name of train set')
parser.add_argument('--test_subfolder', required=False, default='test',  help='the subfolder name of test set')
parser.add_argument('--input_ngc', type=int, default=3, help='number of input channel for generator')
parser.add_argument('--output_ngc', type=int, default=3, help='number of output channel for generator')
parser.add_argument('--input_ndc', type=int, default=3, help='number of input channel for discriminator')
parser.add_argument('--output_ndc', type=int, default=1, help='number of output channel for discriminator')
parser.add_argument('--batch_size', type=int, default=1, help='batch size')
parser.add_argument('--ngf', type=int, default=32)
parser.add_argument('--ndf', type=int, default=64)
parser.add_argument('--nb', type=int, default=9, help='the number of resnet block layer for generator')
parser.add_argument('--input_size', type=int, default=256, help='input size')
parser.add_argument('--resize_scale', type=int, default=286, help='resize scale (0 is false)')
parser.add_argument('--crop', type=_str2bool, default=True, help='random crop True or False')
parser.add_argument('--fliplr', type=_str2bool, default=True, help='random fliplr True or False')
parser.add_argument('--train_epoch', type=int, default=200, help='train epochs num')
parser.add_argument('--decay_epoch', type=int, default=100, help='learning rate decay start epoch num')
parser.add_argument('--lrD', type=float, default=0.0002, help='learning rate for discriminator')
parser.add_argument('--lrG', type=float, default=0.0002, help='learning rate for generator')
parser.add_argument('--lambdaA', type=float, default=10, help='lambdaA for cycle loss')
parser.add_argument('--lambdaB', type=float, default=10, help='lambdaB for cycle loss')
parser.add_argument('--beta1', type=float, default=0.5, help='beta1 for Adam optimizer')
parser.add_argument('--beta2', type=float, default=0.999, help='beta2 for Adam optimizer')
parser.add_argument('--save_root', required=False, default='results', help='results save path')
parser.add_argument('--cuda', type=_str2bool, default=True, help='use GPU computation')
# Inside a notebook there is no real command line, so the argument list is
# passed explicitly; edit this list to change the configuration of the run.
opt = parser.parse_args(args=['--dataset', 'facades'])

In [3]:
# Echo the parsed configuration, one "name = value" line per option in
# alphabetical order, framed by two banner lines.
print('------------ Arguments -------------')
for option_name in sorted(vars(opt)):
    option_value = getattr(opt, option_name)
    print('%s = %s' % (str(option_name), str(option_value)))
print('-------------- End ----------------')

------------ Arguments -------------
batch_size = 1
beta1 = 0.5
beta2 = 0.999
crop = True
cuda = True
dataset = facades
decay_epoch = 100
fliplr = True
input_ndc = 3
input_ngc = 3
input_size = 256
lambdaA = 10
lambdaB = 10
lrD = 0.0002
lrG = 0.0002
nb = 9
ndf = 64
ngf = 32
output_ndc = 1
output_ngc = 3
resize_scale = 286
save_root = results
test_subfolder = test
train_epoch = 200
train_subfolder = train
-------------- End ----------------


### results save path

In [4]:
# Obtain the two path fragments used for every artefact of this run.  They are
# later concatenated as `root + model + <file name>` when the checkpoints, the
# training history and the loss plot are saved.
# NOTE(review): presumably the helper also creates the result directories under
# opt.save_root for opt.dataset -- confirm in utils.filepath_check_and_initialize.
root, model = utils.filepath_check_and_initialize(opt.dataset, opt.save_root)

### data_loader

In [5]:
# Preprocessing shared by all four loaders: convert the image to a tensor and
# then normalise each of the three channels with mean 0.5 and std 0.5.
channel_normalize = transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
transform = transforms.Compose([transforms.ToTensor(), channel_normalize])

# Root folder of the selected dataset; the per-domain subfolders are named
# <train|test subfolder> + 'A' / 'B'.
dataset_dir = 'data/' + opt.dataset

# Training loaders are shuffled, test loaders keep their order.
train_loader_A, train_loader_B = (
    utils.data_load(dataset_dir, opt.train_subfolder + domain, transform, opt.batch_size, shuffle=True)
    for domain in ('A', 'B')
)
test_loader_A, test_loader_B = (
    utils.data_load(dataset_dir, opt.test_subfolder + domain, transform, opt.batch_size, shuffle=False)
    for domain in ('A', 'B')
)

### initialize generators and discriminators

In [6]:
# Build the two generators and the two discriminators of the model.
# In the training loop G_A maps domain A -> B and G_B maps B -> A, while D_A
# scores domain-B images (real B vs. G_A output) and D_B scores domain-A images.
# NOTE(review): opt.cuda is forwarded to both helpers; presumably it decides
# whether the networks are moved to the GPU -- confirm in network.py.
G_A, G_B = network.initialize_generators(opt.input_ngc, opt.output_ngc, opt.ngf, opt.nb, opt.cuda)
D_A, D_B = network.initialize_discriminators(opt.input_ndc, opt.output_ndc, opt.ndf, opt.cuda)

### Initialized Networks

In [7]:
# Show the architecture of every network: both generators first, then both
# discriminators.
for net in (G_A, G_B, D_A, D_B):
    utils.print_network(net)

generator(
  (conv1): Conv2d(3, 32, kernel_size=(7, 7), stride=(1, 1))
  (conv1_norm): InstanceNorm2d(32, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
  (conv2): Conv2d(32, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
  (conv2_norm): InstanceNorm2d(64, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
  (conv3): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
  (conv3_norm): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
  (resnet_blocks): Sequential(
    (0): resnet_block(
      (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1))
      (conv1_norm): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
      (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1))
      (conv2_norm): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
    )
    (1): resnet_block(
      (conv1): Conv2d(128, 128, kernel_si

### Initialize loss

In [8]:
# Instantiate the three criteria in one pass and apply .cuda() to each.
# MSE_loss is the adversarial criterion and L1_loss the cycle-consistency
# criterion in the training loop; BCE_loss is created but not used there.
BCE_loss, MSE_loss, L1_loss = (
    criterion_cls().cuda()
    for criterion_cls in (nn.BCELoss, nn.MSELoss, nn.L1Loss)
)

### Adam optimizer

In [9]:
# All three Adam optimizers share the same (beta1, beta2) pair.
adam_betas = (opt.beta1, opt.beta2)

# A single optimizer updates both generators together, because their losses
# are summed into one G_loss during training.
generator_params = itertools.chain(G_A.parameters(), G_B.parameters())
G_optimizer = optim.Adam(generator_params, lr=opt.lrG, betas=adam_betas)

# Each discriminator is optimised on its own.
D_A_optimizer = optim.Adam(D_A.parameters(), lr=opt.lrD, betas=adam_betas)
D_B_optimizer = optim.Adam(D_B.parameters(), lr=opt.lrD, betas=adam_betas)

### Image store

In [10]:
# One pool per domain, each constructed with the argument 50; the training loop
# routes generated images through them before the discriminator sees them.
# NOTE(review): presumably 50 is the pool capacity -- confirm in utils.ImagePool.
fakeA_store, fakeB_store = (utils.ImagePool(50) for _ in range(2))

# Container for the per-iteration losses and the timing statistics.
train_hist = utils.train_histogram_initialize()

### Training

In [11]:
# Main training loop.
#
# For every epoch:
#   1. once the decay phase has started, set the learning rates for the epoch;
#   2. for every pair of (A, B) mini-batches: augment, update both generators
#      with one joint step, then update D_A and D_B separately;
#   3. print the mean losses of the epoch and dump sample translations.
#
# Reads:  opt, the four networks, the three optimizers, MSE_loss / L1_loss,
#         the data loaders and the two image pools.
# Writes: train_hist (losses as plain Python floats, per-epoch time and
#         total time) and the globals start_time / end_time / total_time.
print('**************************start training!**************************')

# Honour --cuda instead of calling .cuda() unconditionally.
# NOTE(review): assumes the networks were placed on the matching device by
# network.initialize_* (which also receive opt.cuda) -- confirm.
device = torch.device('cuda' if opt.cuda else 'cpu')

start_time = time.time()
for epoch in range(opt.train_epoch):
    D_A_losses = []
    D_B_losses = []
    G_A_losses = []
    G_B_losses = []
    A_cycle_losses = []
    B_cycle_losses = []
    epoch_start_time = time.time()
    num_iter = 0
    if (epoch+1) > opt.decay_epoch:
        # Linear decay from the initial rate down to zero.  The rate is
        # computed in closed form rather than by repeated in-place subtraction,
        # which accumulates floating-point error (and can end slightly below
        # zero) and would decay further if this cell were executed again.
        # The schedule is the same as before: the final epoch runs with a
        # learning rate of exactly 0.
        decay_factor = (opt.train_epoch - (epoch + 1)) / float(opt.train_epoch - opt.decay_epoch)
        D_A_optimizer.param_groups[0]['lr'] = opt.lrD * decay_factor
        D_B_optimizer.param_groups[0]['lr'] = opt.lrD * decay_factor
        G_optimizer.param_groups[0]['lr'] = opt.lrG * decay_factor

    # zip() stops at the shorter loader, so an epoch has
    # min(len(train_loader_A), len(train_loader_B)) iterations.
    for (realA, _), (realB, _) in zip(train_loader_A, train_loader_B):
        # Data augmentation: resize, then random crop, then random flip.
        if opt.resize_scale:
            realA = utils.imgs_resize(realA, opt.resize_scale)
            realB = utils.imgs_resize(realB, opt.resize_scale)

        if opt.crop:
            realA = utils.random_crop(realA, opt.input_size)
            realB = utils.random_crop(realB, opt.input_size)

        if opt.fliplr:
            realA = utils.random_fliplr(realA)
            realB = utils.random_fliplr(realB)

        realA, realB = realA.to(device), realB.to(device)

        # train generator G
        G_optimizer.zero_grad()

        # generate real A to fake B; D_A(G_A(A))
        fakeB = G_A(realA)
        D_A_result = D_A(fakeB)
        # The generator wants the discriminator to answer "real" (all ones).
        # ones_like creates the target directly on the right device.
        G_A_loss = MSE_loss(D_A_result, torch.ones_like(D_A_result))

        # reconstruct fake B to rec A; G_B(G_A(A))
        recA = G_B(fakeB)
        A_cycle_loss = L1_loss(recA, realA) * opt.lambdaA

        # generate real B to fake A; D_B(G_B(B))
        fakeA = G_B(realB)
        D_B_result = D_B(fakeA)
        G_B_loss = MSE_loss(D_B_result, torch.ones_like(D_B_result))

        # reconstruct fake A to rec B; G_A(G_B(B))
        recB = G_A(fakeA)
        B_cycle_loss = L1_loss(recB, realB) * opt.lambdaB

        G_loss = G_A_loss + G_B_loss + A_cycle_loss + B_cycle_loss
        G_loss.backward()
        G_optimizer.step()

        # .item() stores plain Python floats: nothing stays on the GPU and the
        # pickled history can be loaded on a machine without CUDA.
        G_A_loss_value = G_A_loss.item()
        G_B_loss_value = G_B_loss.item()
        A_cycle_loss_value = A_cycle_loss.item()
        B_cycle_loss_value = B_cycle_loss.item()

        train_hist['G_A_losses'].append(G_A_loss_value)
        train_hist['G_B_losses'].append(G_B_loss_value)
        train_hist['A_cycle_losses'].append(A_cycle_loss_value)
        train_hist['B_cycle_losses'].append(B_cycle_loss_value)

        G_A_losses.append(G_A_loss_value)
        G_B_losses.append(G_B_loss_value)
        A_cycle_losses.append(A_cycle_loss_value)
        B_cycle_losses.append(B_cycle_loss_value)

        # train discriminator D_A
        D_A_optimizer.zero_grad()

        D_A_real = D_A(realB)
        D_A_real_loss = MSE_loss(D_A_real, torch.ones_like(D_A_real))

        # Detach before pooling: the discriminator update must not reach back
        # into the generator graph (already consumed by G_loss.backward()), and
        # the pool must not keep generator graphs alive.
        # NOTE(review): ImagePool.query may already detach internally; detaching
        # here makes this cell correct either way.
        fakeB = fakeB_store.query(fakeB.detach())
        D_A_fake = D_A(fakeB)
        D_A_fake_loss = MSE_loss(D_A_fake, torch.zeros_like(D_A_fake))

        D_A_loss = (D_A_real_loss + D_A_fake_loss) * 0.5
        D_A_loss.backward()
        D_A_optimizer.step()

        D_A_loss_value = D_A_loss.item()
        train_hist['D_A_losses'].append(D_A_loss_value)
        D_A_losses.append(D_A_loss_value)

        # train discriminator D_B
        D_B_optimizer.zero_grad()

        D_B_real = D_B(realA)
        D_B_real_loss = MSE_loss(D_B_real, torch.ones_like(D_B_real))

        # Same reasoning as for fakeB above.
        fakeA = fakeA_store.query(fakeA.detach())
        D_B_fake = D_B(fakeA)
        D_B_fake_loss = MSE_loss(D_B_fake, torch.zeros_like(D_B_fake))

        D_B_loss = (D_B_real_loss + D_B_fake_loss) * 0.5
        D_B_loss.backward()
        D_B_optimizer.step()

        D_B_loss_value = D_B_loss.item()
        train_hist['D_B_losses'].append(D_B_loss_value)
        D_B_losses.append(D_B_loss_value)

        num_iter += 1

    epoch_end_time = time.time()
    per_epoch_ptime = epoch_end_time - epoch_start_time
    train_hist['per_epoch_ptimes'].append(per_epoch_ptime)
    print(
    '[%d/%d] - ptime: %.2f, loss_D_A: %.3f, loss_D_B: %.3f, loss_G_A: %.3f, loss_G_B: %.3f, loss_A_cycle: %.3f, loss_B_cycle: %.3f' % (
        (epoch + 1), opt.train_epoch, per_epoch_ptime, torch.mean(torch.FloatTensor(D_A_losses)),
        torch.mean(torch.FloatTensor(D_B_losses)), torch.mean(torch.FloatTensor(G_A_losses)),
        torch.mean(torch.FloatTensor(G_B_losses)), torch.mean(torch.FloatTensor(A_cycle_losses)),
        torch.mean(torch.FloatTensor(B_cycle_losses))))


    # Every 10th epoch produce results from the test loaders, otherwise from
    # the training loaders.
    if (epoch+1) % 10 == 0:
        test.test_results_network(test_loader_A, test_loader_B, G_A, G_B, opt.dataset)
    else:
        train.train_results_network(train_loader_A, train_loader_B, G_A, G_B, opt.dataset)

end_time = time.time()
total_time = end_time - start_time
train_hist['total_time'].append(total_time)

**************************start training!**************************




[1/200] - ptime: 160.26, loss_D_A: 0.428, loss_D_B: 0.420, loss_G_A: 0.566, loss_G_B: 0.536, loss_A_cycle: 2.496, loss_B_cycle: 2.217


  realA = Variable(realA.cuda(), volatile=True)
  realB = Variable(realB.cuda(), volatile=True)


[2/200] - ptime: 159.58, loss_D_A: 0.250, loss_D_B: 0.248, loss_G_A: 0.462, loss_G_B: 0.464, loss_A_cycle: 2.191, loss_B_cycle: 1.554
[3/200] - ptime: 159.49, loss_D_A: 0.221, loss_D_B: 0.228, loss_G_A: 0.465, loss_G_B: 0.464, loss_A_cycle: 2.140, loss_B_cycle: 1.454
[4/200] - ptime: 159.79, loss_D_A: 0.219, loss_D_B: 0.220, loss_G_A: 0.484, loss_G_B: 0.452, loss_A_cycle: 1.996, loss_B_cycle: 1.215
[5/200] - ptime: 159.46, loss_D_A: 0.209, loss_D_B: 0.197, loss_G_A: 0.517, loss_G_B: 0.512, loss_A_cycle: 2.057, loss_B_cycle: 1.336
[6/200] - ptime: 160.16, loss_D_A: 0.196, loss_D_B: 0.173, loss_G_A: 0.529, loss_G_B: 0.512, loss_A_cycle: 1.966, loss_B_cycle: 1.171
[7/200] - ptime: 166.18, loss_D_A: 0.206, loss_D_B: 0.176, loss_G_A: 0.526, loss_G_B: 0.494, loss_A_cycle: 1.921, loss_B_cycle: 1.110
[8/200] - ptime: 159.77, loss_D_A: 0.179, loss_D_B: 0.170, loss_G_A: 0.548, loss_G_B: 0.484, loss_A_cycle: 1.931, loss_B_cycle: 1.119
[9/200] - ptime: 159.86, loss_D_A: 0.178, loss_D_B: 0.182, los

  realA = Variable(realA.cuda(), volatile=True)
  realB = Variable(realB.cuda(), volatile=True)


[11/200] - ptime: 159.56, loss_D_A: 0.153, loss_D_B: 0.166, loss_G_A: 0.605, loss_G_B: 0.510, loss_A_cycle: 1.860, loss_B_cycle: 1.018
[12/200] - ptime: 159.58, loss_D_A: 0.152, loss_D_B: 0.159, loss_G_A: 0.588, loss_G_B: 0.499, loss_A_cycle: 1.815, loss_B_cycle: 0.971
[13/200] - ptime: 159.58, loss_D_A: 0.157, loss_D_B: 0.161, loss_G_A: 0.612, loss_G_B: 0.494, loss_A_cycle: 1.809, loss_B_cycle: 0.941
[14/200] - ptime: 159.75, loss_D_A: 0.156, loss_D_B: 0.157, loss_G_A: 0.586, loss_G_B: 0.478, loss_A_cycle: 1.823, loss_B_cycle: 0.911
[15/200] - ptime: 159.57, loss_D_A: 0.134, loss_D_B: 0.145, loss_G_A: 0.612, loss_G_B: 0.501, loss_A_cycle: 1.771, loss_B_cycle: 0.899
[16/200] - ptime: 159.86, loss_D_A: 0.132, loss_D_B: 0.134, loss_G_A: 0.616, loss_G_B: 0.524, loss_A_cycle: 1.836, loss_B_cycle: 0.999
[17/200] - ptime: 159.54, loss_D_A: 0.140, loss_D_B: 0.147, loss_G_A: 0.606, loss_G_B: 0.484, loss_A_cycle: 1.712, loss_B_cycle: 0.839
[18/200] - ptime: 158.87, loss_D_A: 0.133, loss_D_B: 0.

[72/200] - ptime: 160.38, loss_D_A: 0.077, loss_D_B: 0.122, loss_G_A: 0.757, loss_G_B: 0.516, loss_A_cycle: 1.284, loss_B_cycle: 0.594
[73/200] - ptime: 160.58, loss_D_A: 0.074, loss_D_B: 0.121, loss_G_A: 0.755, loss_G_B: 0.514, loss_A_cycle: 1.266, loss_B_cycle: 0.631
[74/200] - ptime: 160.39, loss_D_A: 0.074, loss_D_B: 0.127, loss_G_A: 0.743, loss_G_B: 0.523, loss_A_cycle: 1.267, loss_B_cycle: 0.622
[75/200] - ptime: 160.58, loss_D_A: 0.064, loss_D_B: 0.119, loss_G_A: 0.774, loss_G_B: 0.543, loss_A_cycle: 1.282, loss_B_cycle: 0.634
[76/200] - ptime: 160.26, loss_D_A: 0.069, loss_D_B: 0.117, loss_G_A: 0.742, loss_G_B: 0.522, loss_A_cycle: 1.261, loss_B_cycle: 0.622
[77/200] - ptime: 160.32, loss_D_A: 0.072, loss_D_B: 0.126, loss_G_A: 0.764, loss_G_B: 0.511, loss_A_cycle: 1.254, loss_B_cycle: 0.611
[78/200] - ptime: 159.97, loss_D_A: 0.066, loss_D_B: 0.119, loss_G_A: 0.780, loss_G_B: 0.540, loss_A_cycle: 1.276, loss_B_cycle: 0.636
[79/200] - ptime: 161.17, loss_D_A: 0.060, loss_D_B: 0.

[133/200] - ptime: 160.16, loss_D_A: 0.047, loss_D_B: 0.103, loss_G_A: 0.810, loss_G_B: 0.553, loss_A_cycle: 1.034, loss_B_cycle: 0.511
[134/200] - ptime: 161.35, loss_D_A: 0.046, loss_D_B: 0.105, loss_G_A: 0.836, loss_G_B: 0.553, loss_A_cycle: 1.023, loss_B_cycle: 0.513
[135/200] - ptime: 161.76, loss_D_A: 0.043, loss_D_B: 0.100, loss_G_A: 0.827, loss_G_B: 0.552, loss_A_cycle: 0.989, loss_B_cycle: 0.517
[136/200] - ptime: 160.48, loss_D_A: 0.042, loss_D_B: 0.099, loss_G_A: 0.826, loss_G_B: 0.554, loss_A_cycle: 1.028, loss_B_cycle: 0.534
[137/200] - ptime: 160.25, loss_D_A: 0.041, loss_D_B: 0.100, loss_G_A: 0.837, loss_G_B: 0.558, loss_A_cycle: 1.005, loss_B_cycle: 0.523
[138/200] - ptime: 160.58, loss_D_A: 0.042, loss_D_B: 0.103, loss_G_A: 0.819, loss_G_B: 0.552, loss_A_cycle: 0.981, loss_B_cycle: 0.514
[139/200] - ptime: 160.82, loss_D_A: 0.042, loss_D_B: 0.104, loss_G_A: 0.826, loss_G_B: 0.537, loss_A_cycle: 1.003, loss_B_cycle: 0.517
[140/200] - ptime: 160.58, loss_D_A: 0.041, loss

[194/200] - ptime: 159.98, loss_D_A: 0.022, loss_D_B: 0.090, loss_G_A: 0.924, loss_G_B: 0.616, loss_A_cycle: 0.771, loss_B_cycle: 0.425
[195/200] - ptime: 160.08, loss_D_A: 0.022, loss_D_B: 0.088, loss_G_A: 0.924, loss_G_B: 0.634, loss_A_cycle: 0.772, loss_B_cycle: 0.417
[196/200] - ptime: 160.10, loss_D_A: 0.022, loss_D_B: 0.091, loss_G_A: 0.910, loss_G_B: 0.637, loss_A_cycle: 0.763, loss_B_cycle: 0.425
[197/200] - ptime: 160.09, loss_D_A: 0.022, loss_D_B: 0.092, loss_G_A: 0.922, loss_G_B: 0.631, loss_A_cycle: 0.766, loss_B_cycle: 0.421
[198/200] - ptime: 160.31, loss_D_A: 0.022, loss_D_B: 0.090, loss_G_A: 0.927, loss_G_B: 0.648, loss_A_cycle: 0.771, loss_B_cycle: 0.416
[199/200] - ptime: 160.17, loss_D_A: 0.021, loss_D_B: 0.093, loss_G_A: 0.920, loss_G_B: 0.638, loss_A_cycle: 0.760, loss_B_cycle: 0.411
[200/200] - ptime: 160.26, loss_D_A: 0.019, loss_D_B: 0.090, loss_G_A: 0.919, loss_G_B: 0.628, loss_A_cycle: 0.764, loss_B_cycle: 0.415


In [12]:
# Summarise the timing of the finished run.
avg_epoch_time = torch.mean(torch.FloatTensor(train_hist['per_epoch_ptimes']))
print("Avg one epoch passing time: %.2f, total %d epochs passing time: %.2f" % (avg_epoch_time, opt.train_epoch, total_time))
print("Training finish!... save training results")

# Every artefact is written under the same root + model prefix.
save_prefix = root + model

# Persist the weights of the four networks, generators first.
for net, weights_name in (
    (G_A, 'generatorA_param.pkl'),
    (G_B, 'generatorB_param.pkl'),
    (D_A, 'discriminatorA_param.pkl'),
    (D_B, 'discriminatorB_param.pkl'),
):
    torch.save(net.state_dict(), save_prefix + weights_name)

# Persist the recorded training history next to the weights.
with open(save_prefix + 'train_hist.pkl', 'wb') as hist_file:
    pickle.dump(train_hist, hist_file)

# Pass the history to the plotting helper, asking it to save to train_hist.png.
utils.show_train_hist(train_hist, save=True, path=save_prefix + 'train_hist.png')

Avg one epoch passing time: 160.20, total 200 epochs passing time: 39036.69
Training finish!... save training results
