In [1]:
import os
import math
import numpy as np
from torchvision import datasets, transforms
import torchvision.transforms as transforms
from torchvision.utils import save_image
from torch.utils.data import DataLoader
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F
import torch

import imageio
from torchvision.utils import save_image, make_grid


import itertools
import datetime
import time


from models import *
from dataset import *
from utils import *

In [2]:
# ---- Training schedule ----
# Upper bound of the epoch loop (`range(epoch, epochs)`).
epochs = 2
# Passed to the LambdaLR helper from utils; presumably the epoch at which the
# learning-rate decay starts — confirm against utils.LambdaLR.
decay_epoch = 1

# ---- Data / model geometry ----
# Batch size of the training DataLoader.
batch_size = 6
# Images are resized and center-cropped to img_size x img_size.
img_size = 128
# Forwarded to both GeneratorResNet constructors.
n_residual_blocks = 5 #9
# Domain A (the "pencil_sketch" directory) is treated as single-channel,
# domain B (the "orig" directory) as three-channel.
input_shape_A = (1, img_size, img_size)
input_shape_B = (3, img_size, img_size)
# Save model checkpoints every N epochs; -1 disables checkpointing.
checkpoint_interval = 1
# Save a sample image grid every N training batches.
sample_interval = 100

# ---- Adam optimizer settings (shared by all three optimizers) ----
lr = 0.0002
b1 = 0.5
b2 = 0.999

# ---- Loss weights in the total generator loss ----
# Weight of the cycle-consistency loss.
lambda_cyc = 4.0
# Weight of the identity loss (the training loop currently fixes that loss to 0).
lambda_id = 5.0

# Epoch to start from; any non-zero value makes the notebook load saved weights
# instead of initializing fresh ones.
epoch = 0

In [3]:
# True when a CUDA device is usable; gates every GPU-specific step below.
cuda = torch.cuda.is_available()

In [4]:
# Loss functions:
#   - adversarial term: mean squared error against the valid/fake targets
#   - cycle and identity terms: mean absolute error
criterion_GAN = nn.MSELoss()
criterion_cycle = nn.L1Loss()
criterion_identity = nn.L1Loss()

In [5]:
# Two generators, one per translation direction (A -> B and B -> A), built in
# that order, followed by one discriminator per domain.
G_AB, G_BA = (
    GeneratorResNet(source_shape, target_shape, n_residual_blocks)
    for source_shape, target_shape in (
        (input_shape_A, input_shape_B),
        (input_shape_B, input_shape_A),
    )
)
D_A, D_B = (
    Discriminator(domain_shape) for domain_shape in (input_shape_A, input_shape_B)
)

In [6]:
if cuda:
    # Networks are rebound to the result of .cuda(), exactly as before.
    G_AB, G_BA, D_A, D_B = (net.cuda() for net in (G_AB, G_BA, D_A, D_B))
    # The criteria are moved without rebinding their names.
    criterion_GAN.cuda()
    criterion_cycle.cuda()
    criterion_identity.cuda()

In [7]:
if epoch != 0:
    # Resume from the checkpoints of epoch `epoch`.
    # The training loop writes them as "saved_models/<net>_<epoch>.pth", so the
    # paths here must use the same layout (the previous
    # "saved_models/cycle_gan/..." paths never matched what was saved).
    # map_location lets checkpoints saved on a GPU be loaded on a CPU-only
    # machine and vice versa.
    _ckpt_device = "cuda" if cuda else "cpu"
    G_AB.load_state_dict(torch.load("saved_models/G_AB_%d.pth" % (epoch,), map_location=_ckpt_device))
    G_BA.load_state_dict(torch.load("saved_models/G_BA_%d.pth" % (epoch,), map_location=_ckpt_device))
    D_A.load_state_dict(torch.load("saved_models/D_A_%d.pth" % (epoch,), map_location=_ckpt_device))
    D_B.load_state_dict(torch.load("saved_models/D_B_%d.pth" % (epoch,), map_location=_ckpt_device))
else:
    # Fresh run: apply the weights_init_normal initializer to every submodule.
    G_AB.apply(weights_init_normal)
    G_BA.apply(weights_init_normal)
    D_A.apply(weights_init_normal)
    D_B.apply(weights_init_normal)

In [8]:
# A single Adam optimizer updates both generators jointly; each discriminator
# gets its own optimizer. All three share the same learning rate and betas.
optimizer_G = torch.optim.Adam(
    list(G_AB.parameters()) + list(G_BA.parameters()),
    lr=lr,
    betas=(b1, b2),
)
optimizer_D_A = torch.optim.Adam(
    D_A.parameters(),
    lr=lr,
    betas=(b1, b2),
)
optimizer_D_B = torch.optim.Adam(
    D_B.parameters(),
    lr=lr,
    betas=(b1, b2),
)

In [9]:
# One scheduler per optimizer. Each gets its own instance of the LambdaLR
# helper (from the star imports); its `step` method supplies the per-epoch
# learning-rate multiplier to torch's LambdaLR scheduler.
lr_scheduler_G, lr_scheduler_D_A, lr_scheduler_D_B = [
    torch.optim.lr_scheduler.LambdaLR(
        optimizer, lr_lambda=LambdaLR(epochs, epoch, decay_epoch).step
    )
    for optimizer in (optimizer_G, optimizer_D_A, optimizer_D_B)
]

In [10]:
# Tensor type that batches are cast to: CUDA float tensors when a GPU is in
# use, the default CPU tensor type otherwise.
if cuda:
    Tensor = torch.cuda.FloatTensor
else:
    Tensor = torch.Tensor

In [11]:
# Replay buffers holding previously generated images, one per domain; the
# discriminators are trained on samples drawn through them.
fake_A_buffer, fake_B_buffer = ReplayBuffer(), ReplayBuffer()


In [12]:
# Preprocessing for domain A: bicubic resize, center crop to a square,
# tensor conversion, then single-channel normalization.
transforms_A = [
    transforms.Resize(img_size, interpolation=Image.BICUBIC),
    transforms.CenterCrop(img_size),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5], std=[0.5]),
]

# Preprocessing for domain B: same geometry, three-channel normalization.
transforms_B = [
    transforms.Resize(img_size, interpolation=Image.BICUBIC),
    transforms.CenterCrop(img_size),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
]


In [13]:
# Training data: sketches (domain A) paired with original photos (domain B).
# shuffle=True so that batches are not drawn in the same fixed file order on
# every epoch.
dataloader = DataLoader(
    ImageDataset(
        "../data/",
        transforms_A=transforms_A,
        transforms_B=transforms_B,
        dir_A="pencil_sketch",
        dir_B="orig",
    ),
    batch_size=batch_size,
    shuffle=True,
)

# Validation data: left unshuffled so that sample_images(), which takes the
# first batch, renders the same images every time and progress is comparable.
# The batch size of 6 matches the 6-image grids built in sample_images().
val_dataloader = DataLoader(
    ImageDataset(
        "../data/val/",
        transforms_A=transforms_A,
        transforms_B=transforms_B,
        dir_A="pencil_sketch",
        dir_B="orig",
    ),
    batch_size=6,
)

202599 202599
33 56


In [14]:
def sample_images(batches_done):
    """Save a grid of real and translated validation images.

    Takes the first batch of ``val_dataloader``, translates it with both
    generators and writes an image with four stacked rows of grids
    (real A, fake B, real B, fake A) to ``images/<batches_done>.png``.

    Both generators are switched to eval mode and are left in eval mode on
    return.

    Args:
        batches_done: Identifier used as the output file name (the training
            loop passes the global batch counter). The parameter was
            previously misspelled, so the body silently read the global
            variable of the same name instead of the argument.
    """
    os.makedirs("images", exist_ok=True)

    imgs = next(iter(val_dataloader))
    G_AB.eval()
    G_BA.eval()

    # Inference only: do not build autograd graphs for the sample images.
    with torch.no_grad():
        real_A = imgs["A"].type(Tensor)
        fake_B = G_AB(real_A)
        real_B = imgs["B"].type(Tensor)
        fake_A = G_BA(real_B)

    # Arrange images along the x-axis (at most 6 images, 3 per row)
    real_A = make_grid(real_A[:6], nrow=3, normalize=True)
    real_B = make_grid(real_B[:6], nrow=3, normalize=True)
    fake_A = make_grid(fake_A[:6], nrow=3, normalize=True)
    fake_B = make_grid(fake_B[:6], nrow=3, normalize=True)

    # Stack the four grids along the y-axis
    image_grid = torch.cat((real_A, fake_B, real_B, fake_A), 1)
    save_image(image_grid, "images/%s.png" % (batches_done,), normalize=False)


In [None]:
# Create the checkpoint directory up front so that torch.save cannot fail
# after a full epoch of training has already been spent.
if checkpoint_interval != -1:
    os.makedirs("saved_models", exist_ok=True)

prev_time = time.time()
for epoch in range(epoch, epochs):
    for i, batch in enumerate(dataloader):
        # Set model input (cast to the tensor type selected for this device)
        real_A = batch["A"].type(Tensor)
        real_B = batch["B"].type(Tensor)

        # Adversarial ground truths, shaped like the discriminator output.
        # D_A.output_shape is used as the target shape for both discriminators.
        target_shape = (real_A.size(0), *D_A.output_shape)
        valid = real_A.new_ones(target_shape)
        fake = real_A.new_zeros(target_shape)

        # ------------------
        #  Train Generators
        # ------------------

        # sample_images() leaves the generators in eval mode, so switch back
        # to train mode on every iteration.
        G_AB.train()
        G_BA.train()

        optimizer_G.zero_grad()

        # Identity loss is disabled. As written below it feeds domain-A images
        # to G_BA (built for domain-B-shaped input) and vice versa, while the
        # two domains are configured with different channel counts.
        # loss_id_A = criterion_identity(G_BA(real_A), real_A)
        # loss_id_B = criterion_identity(G_AB(real_B), real_B)
        # loss_identity = (loss_id_A + loss_id_B) / 2
        loss_identity = 0

        # GAN loss: each generator tries to make its discriminator output "valid"
        fake_B = G_AB(real_A)
        loss_GAN_AB = criterion_GAN(D_B(fake_B), valid)
        fake_A = G_BA(real_B)
        loss_GAN_BA = criterion_GAN(D_A(fake_A), valid)

        loss_GAN = (loss_GAN_AB + loss_GAN_BA) / 2

        # Cycle loss: translating there and back should recover the input
        recov_A = G_BA(fake_B)
        loss_cycle_A = criterion_cycle(recov_A, real_A)
        recov_B = G_AB(fake_A)
        loss_cycle_B = criterion_cycle(recov_B, real_B)

        loss_cycle = (loss_cycle_A + loss_cycle_B) / 2

        # Total loss
        loss_G = loss_GAN + lambda_cyc * loss_cycle + lambda_id * loss_identity

        loss_G.backward()
        optimizer_G.step()

        # -----------------------
        #  Train Discriminator A
        # -----------------------

        optimizer_D_A.zero_grad()

        # Real loss
        loss_real = criterion_GAN(D_A(real_A), valid)
        # Fake loss (on batch of previously generated samples); detached so
        # that no gradient flows back into the generators
        fake_A_ = fake_A_buffer.push_and_pop(fake_A)
        loss_fake = criterion_GAN(D_A(fake_A_.detach()), fake)
        # Total loss
        loss_D_A = (loss_real + loss_fake) / 2

        loss_D_A.backward()
        optimizer_D_A.step()

        # -----------------------
        #  Train Discriminator B
        # -----------------------

        optimizer_D_B.zero_grad()

        # Real loss
        loss_real = criterion_GAN(D_B(real_B), valid)
        # Fake loss (on batch of previously generated samples)
        fake_B_ = fake_B_buffer.push_and_pop(fake_B)
        loss_fake = criterion_GAN(D_B(fake_B_.detach()), fake)
        # Total loss
        loss_D_B = (loss_real + loss_fake) / 2

        loss_D_B.backward()
        optimizer_D_B.step()

        # Only used for logging
        loss_D = (loss_D_A + loss_D_B) / 2

        # --------------
        #  Log Progress
        # --------------

        # Determine approximate time left from the duration of the last batch
        batches_done = epoch * len(dataloader) + i
        batches_left = epochs * len(dataloader) - batches_done
        time_left = datetime.timedelta(seconds=batches_left * (time.time() - prev_time))
        prev_time = time.time()

        # Print log; "\r" with end="" keeps the log on one self-overwriting
        # line. print() is used so the cell does not rely on `sys`, which this
        # notebook never imports explicitly.
        print(
            "\r[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f, adv: %f, cycle: %f, identity: %f] ETA: %s"
            % (
                epoch,
                epochs,
                i,
                len(dataloader),
                loss_D.item(),
                loss_G.item(),
                loss_GAN.item(),
                loss_cycle.item(),
                loss_identity,
                time_left,
            ),
            end="",
        )

        # If at sample interval save image
        if batches_done % sample_interval == 0:
            sample_images(batches_done)

    # Update learning rates (once per epoch)
    lr_scheduler_G.step()
    lr_scheduler_D_A.step()
    lr_scheduler_D_B.step()

    if checkpoint_interval != -1 and epoch % checkpoint_interval == 0:
        # Save model checkpoints
        torch.save(G_AB.state_dict(), "saved_models/G_AB_%d.pth" % (epoch,))
        torch.save(G_BA.state_dict(), "saved_models/G_BA_%d.pth" % (epoch,))
        torch.save(D_A.state_dict(), "saved_models/D_A_%d.pth" % (epoch,))
        torch.save(D_B.state_dict(), "saved_models/D_B_%d.pth" % (epoch,))

[Epoch 0/2] [Batch 8080/33767] [D loss: 0.222293] [G loss: 0.708718, adv: 0.293782, cycle: 0.103734, identity: 0.000000] ETA: 14:02:06.581396