# Reproducing the original U-GAT-IT implementation with the following changes
1. Explicit weights initialization
2. Queue of generated images
3. Activation added at the end of each ResNet block
4. Added parameters to the AdaILN layer
5. Label smoothing

Also provides a monitoring stage (loss summaries written to TensorBoard).

In [0]:
import torch
from torch.nn import init
from torch import nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torch.autograd import Variable
from torch import Tensor
from torch.utils.tensorboard import SummaryWriter
from torch.nn.parameter import Parameter
from torch.utils.tensorboard import SummaryWriter 

from skimage.filters import threshold_otsu
from PIL import Image
import cv2
import functools
import itertools
import numpy as np
import random

import os
import sys
import glob
from tqdm import tqdm_notebook
import re

import matplotlib.pyplot as plt
import pickle
import time
from scipy import misc

### Networks

In [0]:
class ResnetGenerator(nn.Module):
    """Encoder / CAM-attention / AdaILN-decoder image generator.

    The input is down-sampled by strided convolutions and refined by
    ``n_blocks`` residual blocks.  Two auxiliary one-logit classifiers (fed by
    global average and global max pooling) provide weights that re-scale the
    feature maps; a 1x1 convolution fuses both re-weighted copies.  From the
    fused features a small fully connected stack produces per-sample
    ``gamma``/``beta`` vectors that drive ``n_blocks`` AdaILN residual blocks,
    after which two up-sampling stages and a final 7x7 convolution with
    ``tanh`` produce the output image.

    Args:
        input_nc: number of channels of the input image.
        output_nc: number of channels of the generated image.
        ngf: number of filters produced by the first convolution.
        n_blocks: number of residual blocks used in the encoder and, again, in
            the decoder (must be >= 0).
        img_size: side length of the input; only used to size the first
            fully connected layer when ``light`` is False.
        light: when True, gamma/beta are computed from globally
            average-pooled features instead of the flattened feature map.
    """

    def __init__(self, input_nc, output_nc, ngf=64, n_blocks=6, img_size=256, light=False):
        assert(n_blocks >= 0)
        super(ResnetGenerator, self).__init__()
        self.input_nc = input_nc
        self.output_nc = output_nc
        self.ngf = ngf
        self.n_blocks = n_blocks
        self.img_size = img_size
        self.light = light

        # Stem: 7x7 convolution that keeps the spatial size (reflection pad 3).
        DownBlock = []
        DownBlock += [nn.ReflectionPad2d(3),
                      nn.Conv2d(input_nc, ngf, kernel_size=7, stride=1, padding=0, bias=False),
                      nn.InstanceNorm2d(ngf),
                      nn.ReLU(True)]

        # Down-Sampling: each stage uses a stride-2 convolution and doubles the channel count.
        n_downsampling = 2
        for i in range(n_downsampling):
            mult = 2**i
            DownBlock += [nn.ReflectionPad2d(1),
                          nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3, stride=2, padding=0, bias=False),
                          nn.InstanceNorm2d(ngf * mult * 2),
                          nn.ReLU(True)]

        # Down-Sampling Bottleneck: residual blocks at ngf * 2**n_downsampling channels.
        mult = 2**n_downsampling
        for i in range(n_blocks):
            DownBlock += [ResnetBlock(ngf * mult, use_bias=False)]

        # Class Activation Map: one-logit classifiers whose weights are reused
        # in forward() to re-scale the feature maps; conv1x1 merges the two
        # re-weighted copies back to ngf * mult channels.
        self.gap_fc = nn.Linear(ngf * mult, 1, bias=False)
        self.gmp_fc = nn.Linear(ngf * mult, 1, bias=False)
        self.conv1x1 = nn.Conv2d(ngf * mult * 2, ngf * mult, kernel_size=1, stride=1, bias=True)
        self.relu = nn.ReLU(True)

        # Gamma, Beta block
        if self.light:
            # added more parameters from ngf * mult to:
            # (hidden width is ngf * ngf; input is the pooled feature vector)
            FC = [nn.Linear(ngf * mult, ngf * ngf, bias=False),
                  nn.ReLU(True),
                  nn.Linear(ngf * ngf, ngf * mult, bias=False),
                  nn.ReLU(True)]
        else:
            # Input is the flattened feature map.
            # NOTE(review): the size expression is evaluated left to right; it
            # equals (img_size // mult) ** 2 * ngf * mult only when img_size is
            # divisible by mult.
            FC = [nn.Linear(img_size // mult * img_size // mult * ngf * mult, ngf * mult, bias=False),
                  nn.ReLU(True),
                  nn.Linear(ngf * mult, ngf * mult, bias=False),
                  nn.ReLU(True)]
        self.gamma = nn.Linear(ngf * mult, ngf * mult, bias=False)
        self.beta = nn.Linear(ngf * mult, ngf * mult, bias=False)

        # Up-Sampling Bottleneck: registered as attributes UpBlock1_1 .. UpBlock1_n
        # because each block needs gamma/beta as extra forward arguments.
        for i in range(n_blocks):
            setattr(self, 'UpBlock1_' + str(i+1), ResnetAdaILNBlock(ngf * mult, use_bias=False))

        # Up-Sampling: nearest-neighbour x2 followed by a convolution that halves the channels.
        UpBlock2 = []
        for i in range(n_downsampling):
            mult = 2**(n_downsampling - i)
            UpBlock2 += [nn.Upsample(scale_factor=2, mode='nearest'),
                         nn.ReflectionPad2d(1),
                         nn.Conv2d(ngf * mult, int(ngf * mult / 2), kernel_size=3, stride=1, padding=0, bias=False),
                         ILN(int(ngf * mult / 2)),
                         nn.ReLU(True)]

        # Output head: 7x7 convolution to output_nc channels, squashed by tanh.
        UpBlock2 += [nn.ReflectionPad2d(3),
                     nn.Conv2d(ngf, output_nc, kernel_size=7, stride=1, padding=0, bias=False),
                     nn.Tanh()]

        self.DownBlock = nn.Sequential(*DownBlock)
        self.FC = nn.Sequential(*FC)
        self.UpBlock2 = nn.Sequential(*UpBlock2)

    def forward(self, input):
        """Translate ``input`` and expose the attention by-products.

        Returns:
            A tuple ``(out, cam_logit, heatmap)``: the generated image, the
            concatenated logits of the two auxiliary classifiers (two columns),
            and the channel-wise sum of the fused attention features.
        """
        x = self.DownBlock(input)

        # Average-pool branch: classifier logit, then re-scale every channel by
        # the classifier weight of that channel.
        gap = torch.nn.functional.adaptive_avg_pool2d(x, 1)
        gap_logit = self.gap_fc(gap.view(x.shape[0], -1))
        gap_weight = list(self.gap_fc.parameters())[0]
        gap = x * gap_weight.unsqueeze(2).unsqueeze(3)

        # Max-pool branch, same scheme.
        gmp = torch.nn.functional.adaptive_max_pool2d(x, 1)
        gmp_logit = self.gmp_fc(gmp.view(x.shape[0], -1))
        gmp_weight = list(self.gmp_fc.parameters())[0]
        gmp = x * gmp_weight.unsqueeze(2).unsqueeze(3)

        cam_logit = torch.cat([gap_logit, gmp_logit], 1)
        x = torch.cat([gap, gmp], 1)
        x = self.relu(self.conv1x1(x))

        # Single-channel map obtained by summing the fused features over channels.
        heatmap = torch.sum(x, dim=1, keepdim=True)

        if self.light:
            # Light mode: pool to one value per channel before the FC stack.
            x_ = torch.nn.functional.adaptive_avg_pool2d(x, 1)
            x_ = self.FC(x_.view(x_.shape[0], -1))
        else:
            x_ = self.FC(x.view(x.shape[0], -1))
        gamma, beta = self.gamma(x_), self.beta(x_)


        # The same gamma/beta pair is passed to every AdaILN residual block.
        for i in range(self.n_blocks):
            x = getattr(self, 'UpBlock1_' + str(i+1))(x, gamma, beta)
        out = self.UpBlock2(x)

        return out, cam_logit, heatmap


class ResnetBlock(nn.Module):
    """Residual block made of two identical conv stages plus a skip connection.

    Each stage is reflection padding (1), a 3x3 convolution that keeps the
    channel count, instance normalisation and an in-place ReLU.  The block
    output is the input plus the result of both stages.

    Args:
        dim: number of input (and output) channels.
        use_bias: whether the convolutions carry a bias term.
    """

    def __init__(self, dim, use_bias):
        super(ResnetBlock, self).__init__()
        stages = []
        for _ in range(2):
            # Both stages are built the same way, including the trailing ReLU.
            stages.extend([
                nn.ReflectionPad2d(1),
                nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=0, bias=use_bias),
                nn.InstanceNorm2d(dim),
                nn.ReLU(True),
            ])

        self.conv_block = nn.Sequential(*stages)

    def forward(self, x):
        """Return ``x`` plus the output of the two conv stages."""
        residual = self.conv_block(x)
        return x + residual


class ResnetAdaILNBlock(nn.Module):
    """Residual block whose normalisation layers are driven by external gamma/beta.

    Two stages of reflection padding (1), 3x3 convolution, ``adaILN`` and an
    in-place ReLU are applied in sequence; the block input is added to the
    result.

    Args:
        dim: number of input (and output) channels.
        use_bias: whether the convolutions carry a bias term.
    """

    def __init__(self, dim, use_bias):
        super(ResnetAdaILNBlock, self).__init__()
        conv_options = dict(kernel_size=3, stride=1, padding=0, bias=use_bias)

        # First stage.
        self.pad1 = nn.ReflectionPad2d(1)
        self.conv1 = nn.Conv2d(dim, dim, **conv_options)
        self.norm1 = adaILN(dim)
        self.relu1 = nn.ReLU(True)

        # Second stage.
        self.pad2 = nn.ReflectionPad2d(1)
        self.conv2 = nn.Conv2d(dim, dim, **conv_options)
        self.norm2 = adaILN(dim)
        self.relu2 = nn.ReLU(True)

    def forward(self, x, gamma, beta):
        """Run both stages with the given ``gamma``/``beta`` and add the input back."""
        stages = (
            (self.pad1, self.conv1, self.norm1, self.relu1),
            (self.pad2, self.conv2, self.norm2, self.relu2),
        )
        out = x
        for pad, conv, norm, relu in stages:
            out = relu(norm(conv(pad(out)), gamma, beta))

        return out + x


class adaILN(nn.Module):
    """Adaptive instance/layer normalisation with externally supplied scale and shift.

    The input is normalised twice, once per (sample, channel) over the spatial
    axes and once per sample over all non-batch axes.  The two results are
    blended with the learnable per-channel factor ``rho`` (initialised to 0.9)
    and then scaled by ``gamma`` and shifted by ``beta``.

    Args:
        num_features: number of channels; ``rho`` has shape (1, num_features, 1, 1).
        eps: constant added to the variance before taking the square root.
    """

    def __init__(self, num_features, eps=1e-5):
        super(adaILN, self).__init__()
        self.eps = eps
        self.rho = Parameter(torch.Tensor(1, num_features, 1, 1))
        self.rho.data.fill_(0.9)

    def forward(self, input, gamma, beta):
        """Normalise ``input`` and apply per-sample, per-channel ``gamma``/``beta``.

        ``gamma`` and ``beta`` are 2-D tensors; two trailing singleton axes are
        appended so they broadcast over the spatial axes.
        """
        spatial_axes, sample_axes = [2, 3], [1, 2, 3]

        # Instance statistics: one mean/variance per sample and channel.
        inst_mean = input.mean(dim=spatial_axes, keepdim=True)
        inst_var = input.var(dim=spatial_axes, keepdim=True)
        inst_normed = (input - inst_mean) / torch.sqrt(inst_var + self.eps)

        # Layer statistics: one mean/variance per sample.
        layer_mean = input.mean(dim=sample_axes, keepdim=True)
        layer_var = input.var(dim=sample_axes, keepdim=True)
        layer_normed = (input - layer_mean) / torch.sqrt(layer_var + self.eps)

        rho = self.rho.expand(input.shape[0], -1, -1, -1)
        blended = rho * inst_normed + (1 - rho) * layer_normed

        scale = gamma.unsqueeze(2).unsqueeze(3)
        shift = beta.unsqueeze(2).unsqueeze(3)
        return blended * scale + shift


class ILN(nn.Module):
    """Instance/layer normalisation blend with its own learnable scale and shift.

    Works like ``adaILN`` but owns ``gamma`` (initialised to 1) and ``beta``
    (initialised to 0) as parameters.  ``rho`` starts at 0, so initially only
    the per-sample (layer) normalisation contributes.

    Args:
        num_features: number of channels; all parameters have shape
            (1, num_features, 1, 1).
        eps: constant added to the variance before taking the square root.
    """

    def __init__(self, num_features, eps=1e-5):
        super(ILN, self).__init__()
        self.eps = eps
        param_shape = (1, num_features, 1, 1)
        # Registration order (rho, gamma, beta) is kept on purpose.
        self.rho = Parameter(torch.Tensor(*param_shape))
        self.gamma = Parameter(torch.Tensor(*param_shape))
        self.beta = Parameter(torch.Tensor(*param_shape))
        self.rho.data.fill_(0.0)
        self.gamma.data.fill_(1.0)
        self.beta.data.fill_(0.0)

    def forward(self, input):
        """Normalise ``input`` and apply the learned per-channel scale and shift."""
        batch = input.shape[0]

        # Instance statistics: one mean/variance per sample and channel.
        inst_mean = input.mean(dim=[2, 3], keepdim=True)
        inst_var = input.var(dim=[2, 3], keepdim=True)
        inst_normed = (input - inst_mean) / torch.sqrt(inst_var + self.eps)

        # Layer statistics: one mean/variance per sample.
        layer_mean = input.mean(dim=[1, 2, 3], keepdim=True)
        layer_var = input.var(dim=[1, 2, 3], keepdim=True)
        layer_normed = (input - layer_mean) / torch.sqrt(layer_var + self.eps)

        rho = self.rho.expand(batch, -1, -1, -1)
        blended = rho * inst_normed + (1 - rho) * layer_normed

        scale = self.gamma.expand(batch, -1, -1, -1)
        shift = self.beta.expand(batch, -1, -1, -1)
        return blended * scale + shift


class Discriminator(nn.Module):
    """Spectrally normalised convolutional discriminator with a CAM attention branch.

    The feature extractor is a stack of (reflection pad 1, 4x4 convolution,
    LeakyReLU 0.2) stages: ``n_layers - 2`` stages with stride 2 followed by
    one with stride 1.  Two one-logit classifiers on globally pooled features
    provide weights that re-scale the feature maps; a 1x1 convolution fuses
    both copies and a final 4x4 convolution yields the one-channel logit map.

    Args:
        input_nc: number of channels of the input image.
        ndf: number of filters produced by the first convolution.
        n_layers: controls the depth of the feature extractor.
    """

    def __init__(self, input_nc, ndf=64, n_layers=5):
        super(Discriminator, self).__init__()

        def conv_stage(in_channels, out_channels, stride):
            # pad -> spectrally normalised 4x4 conv -> LeakyReLU
            return [nn.ReflectionPad2d(1),
                    nn.utils.spectral_norm(
                        nn.Conv2d(in_channels, out_channels, kernel_size=4,
                                  stride=stride, padding=0, bias=True)),
                    nn.LeakyReLU(0.2, True)]

        layers = conv_stage(input_nc, ndf, 2)

        # Stride-2 stages, doubling the channel count each time.
        for depth in range(1, n_layers - 2):
            factor = 2 ** (depth - 1)
            layers.extend(conv_stage(ndf * factor, ndf * factor * 2, 2))

        # Last feature stage keeps the resolution (stride 1).
        factor = 2 ** (n_layers - 2 - 1)
        layers.extend(conv_stage(ndf * factor, ndf * factor * 2, 1))

        # Class Activation Map
        cam_channels = ndf * 2 ** (n_layers - 2)
        self.gap_fc = nn.utils.spectral_norm(nn.Linear(cam_channels, 1, bias=False))
        self.gmp_fc = nn.utils.spectral_norm(nn.Linear(cam_channels, 1, bias=False))
        self.conv1x1 = nn.Conv2d(cam_channels * 2, cam_channels, kernel_size=1, stride=1, bias=True)
        self.leaky_relu = nn.LeakyReLU(0.2, True)

        # Output head producing the one-channel logit map.
        self.pad = nn.ReflectionPad2d(1)
        self.conv = nn.utils.spectral_norm(
            nn.Conv2d(cam_channels, 1, kernel_size=4, stride=1, padding=0, bias=False))

        # Registered last so the submodule order matches existing checkpoints.
        self.model = nn.Sequential(*layers)

    def forward(self, input):
        """Score ``input``.

        Returns:
            A tuple ``(out, cam_logit, heatmap)``: the logit map of the output
            head, the concatenated logits of the two pooled classifiers (two
            columns), and the channel-wise sum of the fused attention features.
        """
        features = self.model(input)
        batch = features.shape[0]

        # Average-pool branch.  The scaling weight is the first registered
        # parameter of the classifier (NOTE(review): with spectral_norm applied
        # this is presumably the un-normalised ``weight_orig`` - confirm).
        avg_pooled = torch.nn.functional.adaptive_avg_pool2d(features, 1)
        gap_logit = self.gap_fc(avg_pooled.view(batch, -1))
        gap_weight = next(iter(self.gap_fc.parameters()))
        avg_attended = features * gap_weight.unsqueeze(2).unsqueeze(3)

        # Max-pool branch, same scheme.
        max_pooled = torch.nn.functional.adaptive_max_pool2d(features, 1)
        gmp_logit = self.gmp_fc(max_pooled.view(batch, -1))
        gmp_weight = next(iter(self.gmp_fc.parameters()))
        max_attended = features * gmp_weight.unsqueeze(2).unsqueeze(3)

        cam_logit = torch.cat([gap_logit, gmp_logit], 1)
        fused = torch.cat([avg_attended, max_attended], 1)
        fused = self.leaky_relu(self.conv1x1(fused))

        heatmap = torch.sum(fused, dim=1, keepdim=True)

        out = self.conv(self.pad(fused))

        return out, cam_logit, heatmap


class RhoClipper(object):
    """Callable for ``Module.apply`` that clamps every ``rho`` attribute into a range.

    Args:
        min: lower bound of the allowed range.
        max: upper bound of the allowed range; must be greater than ``min``.
    """

    def __init__(self, min, max):
        assert min < max
        self.clip_min, self.clip_max = min, max

    def __call__(self, module):
        # Modules without a ``rho`` attribute are left untouched.
        if not hasattr(module, 'rho'):
            return
        module.rho.data = module.rho.data.clamp(self.clip_min, self.clip_max)

In [5]:
# summary of the models
def num_params(model):
  """Return the total number of elements in the trainable parameters of *model*.

  Only parameters with ``requires_grad`` set are counted.
  """
  element_counts = [np.prod(p.size()) for p in model.parameters() if p.requires_grad]
  return sum(element_counts)

# Build one instance of every architecture solely to count its parameters:
# the two discriminator depths and the light / non-light generator variants.
dL = Discriminator(3, n_layers=5)
dG = Discriminator(3, n_layers=7)
g_light = ResnetGenerator(3,3, light=True)
g_hard = ResnetGenerator(3,3, light=False)

# Parameter counts in millions, rounded to three decimals.  Every count is
# doubled - presumably because training uses two copies of each network (one
# per translation direction); confirm against the model-building code.
dL_num = round(num_params(dL) * 2 / 1e6, 3)
dG_num = round(num_params(dG) * 2 / 1e6, 3)
g_light_num = round(num_params(g_light) * 2 / 1e6, 3)
g_hard_num = round(num_params(g_hard) * 2 / 1e6, 3)
# Totals for a full model with either generator variant.
overall_light = round(dL_num + dG_num + g_light_num, 3)
overall_hard = round(dL_num + dG_num + g_hard_num, 3)

# Print a fixed-width two-row table: header names, then the values.
print('---- Summary models ----')
print("Number of parameters (in millions):")
print("{:10}{:10}{:10}{:10}{:20}{:10}".format("D_L", 'D_G', "G_LIGHT", "G_HARD","Overall_LIGHT", "Overall_HARD"))
print("{:10}{:10}{:10}{:10}{:20}{:10}".format(str(dL_num), str(dG_num), str(g_light_num), 
                                    str(g_hard_num), str(overall_light), str(overall_hard)))

---- Summary models ----
Number of parameters (in millions):
D_L       D_G       G_LIGHT   G_HARD    Overall_LIGHT       Overall_HARD
6.581     106.26    34.551    567.359   147.392             680.2     


### Utils

In [0]:
def load_test_data(image_path, size=256):
    """Load an image file as a normalised RGB batch containing one image.

    The image is read with PIL, converted to RGB, resized to ``size`` x
    ``size`` with bilinear interpolation, given a leading batch axis and passed
    through ``preprocessing``.

    ``scipy.misc.imread`` / ``scipy.misc.imresize`` (used previously) were
    deprecated in SciPy 1.0 and are no longer available in current releases,
    so PIL - already imported by this file - is used instead.

    Args:
        image_path: path of the image file to read.
        size: side length of the square output image.

    Returns:
        Array of shape (1, size, size, 3) as returned by ``preprocessing``.
    """
    with Image.open(image_path) as handle:
        resized = handle.convert('RGB').resize((size, size), Image.BILINEAR)
    img = np.expand_dims(np.asarray(resized), axis=0)

    return preprocessing(img)

def preprocessing(x):
    """Rescale pixel values linearly so that 0 maps to -1 and 255 maps to 1."""
    scaled = x / 127.5
    return scaled - 1

def save_images(images, size, image_path):
    """Map ``images`` back to the 0..1 range and write them as one tiled image.

    Returns whatever ``imsave`` returns.
    """
    rescaled = inverse_transform(images)
    return imsave(rescaled, size, image_path)

def inverse_transform(images):
    """Map values from the -1..1 range to the 0..1 range."""
    shifted = images + 1.
    return shifted / 2

def imsave(images, size, path):
    """Tile ``images`` into a grid with ``merge`` and write the grid to ``path``.

    ``scipy.misc.imsave`` (used previously) was deprecated in SciPy 1.0 and is
    no longer available in current releases, so the grid is written with PIL,
    which this file already imports.

    Non-uint8 data is stretched linearly so that its minimum maps to 0 and its
    maximum to 255 before being saved.
    NOTE(review): this is meant to mirror the scaling the legacy SciPy writer
    applied - confirm against an old SciPy if exact pixel parity matters.

    Args:
        images: batch of images accepted by ``merge``.
        size: (rows, cols) of the grid.
        path: destination file; the format is chosen by PIL from the extension.

    Returns:
        None.
    """
    grid = np.asarray(merge(images, size))
    if grid.dtype != np.uint8:
        low, high = grid.min(), grid.max()
        span = high - low
        if span == 0:
            # Constant image: avoid dividing by zero.
            span = 1
        stretched = (grid - low) * (255.0 / span)
        grid = (stretched.clip(0, 255) + 0.5).astype(np.uint8)
    return Image.fromarray(grid).save(path)

def merge(images, size):
    """Tile a batch of equally sized 3-channel images into one grid image.

    Args:
        images: array indexed as (image, height, width, channel).
        size: (rows, cols) of the grid; images are placed row by row.

    Returns:
        Float array of shape (height * rows, width * cols, 3); cells without an
        image stay zero.
    """
    rows, cols = size[0], size[1]
    tile_h, tile_w = images.shape[1], images.shape[2]
    canvas = np.zeros((tile_h * rows, tile_w * cols, 3))
    for index, tile in enumerate(images):
        row, col = divmod(index, cols)
        top, left = tile_h * row, tile_w * col
        canvas[top:top + tile_h, left:left + tile_w, :] = tile

    return canvas

def check_folder(log_dir):
    """Create ``log_dir`` (including parents) if it is missing and return it.

    ``exist_ok=True`` replaces the previous exists-then-create sequence, which
    could raise if another process created the directory between the check and
    the ``makedirs`` call.

    Args:
        log_dir: directory path to ensure.

    Returns:
        ``log_dir`` unchanged.
    """
    os.makedirs(log_dir, exist_ok=True)
    return log_dir

def str2bool(x):
    """Return True only when ``x`` spells "true" (case-insensitive).

    The previous ``x.lower() in ('true')`` tested membership in the *string*
    ``'true'`` (the parentheses do not make a tuple), so any substring such as
    ``''``, ``'t'`` or ``'rue'`` was accepted as True.
    """
    return x.lower() == 'true'

def cam(x, size = 256):
    """Render an activation map as a JET-coloured image with values in 0..1.

    The map is shifted so its minimum is 0, divided by its maximum, quantised
    to 8 bits, resized to ``size`` x ``size`` and colour-mapped with OpenCV.

    A constant map (maximum equal to minimum) previously caused a 0/0 division;
    it is now rendered as an all-zero map.

    Args:
        x: numeric array holding the activation map.
        size: side length of the square output image.

    Returns:
        Float array produced by ``cv2.applyColorMap`` divided by 255.
    """
    shifted = x - np.min(x)
    peak = np.max(shifted)
    if peak > 0:
        cam_img = shifted / peak
    else:
        # Nothing to normalise: every value equals the minimum.
        cam_img = np.zeros_like(shifted, dtype=float)
    cam_img = np.uint8(255 * cam_img)
    cam_img = cv2.resize(cam_img, (size, size))
    cam_img = cv2.applyColorMap(cam_img, cv2.COLORMAP_JET)
    return cam_img / 255.0

def imagenet_norm(x):
    """Normalise a batch of RGB images with the ImageNet channel statistics.

    The first standard deviation was previously written as 0.299; the ImageNet
    value is 0.229 (mean 0.485/0.456/0.406, std 0.229/0.224/0.225).

    Args:
        x: tensor whose axis 1 holds the three RGB channels.

    Returns:
        ``(x - mean) / std`` with mean/std broadcast over the other axes.
    """
    mean = [0.485, 0.456, 0.406]
    std = [0.229, 0.224, 0.225]
    # Shape (1, 3, 1, 1) so the statistics broadcast over batch and spatial axes.
    mean = torch.tensor(mean, dtype=torch.float32, device=x.device).view(1, -1, 1, 1)
    std = torch.tensor(std, dtype=torch.float32, device=x.device).view(1, -1, 1, 1)
    return (x - mean) / std

def denorm(x):
    """Undo a mean-0.5 / std-0.5 normalisation: map -1..1 back to 0..1."""
    half = 0.5
    return x * half + half

def tensor2numpy(x):
    """Convert a 3-D tensor to a NumPy array with axis 0 moved to the end.

    The tensor is detached from the autograd graph and copied to the CPU first.
    """
    array = x.detach().cpu().numpy()
    return array.transpose(1, 2, 0)

def RGB2BGR(x):
    """Convert an image from RGB to BGR channel order using OpenCV."""
    conversion = cv2.COLOR_RGB2BGR
    return cv2.cvtColor(x, conversion)


# weights initialization
def init_weights(net, stddev=.02):
  """Re-initialise the convolution and linear layers of ``net`` in place.

  For every sub-module whose class name contains "Conv" or "Linear" and that
  has a ``weight`` attribute, the weight is drawn from N(0, stddev) and a
  present bias is set to zero.  Other modules are left untouched.

  NOTE(review): for layers wrapped with ``spectral_norm`` the ``weight``
  attribute is presumably recomputed from ``weight_orig`` on every forward
  pass, in which case this initialisation would not persist - confirm.

  Args:
      net: module to initialise.
      stddev: standard deviation of the normal distribution.

  Returns:
      ``net`` itself.
  """

  def weights_initializer(m):
    layer_type = type(m).__name__
    is_target = 'Conv' in layer_type or 'Linear' in layer_type
    if not (hasattr(m, 'weight') and is_target):
      return

    init.normal_(m.weight.data, mean=0.0, std=stddev)
    if getattr(m, 'bias', None) is not None:
      init.constant_(m.bias.data, 0.0)

  net.apply(weights_initializer)
  return net


def init_net(net, device, stddev=0.02):
  """Move ``net`` to ``device``, re-initialise its weights and return it."""
  return init_weights(net.to(device), stddev=stddev)


def ground_truth_tensor(tensor, label):
  """Return a tensor shaped like ``tensor`` in which every element equals ``label``."""
  template = torch.ones_like(tensor)
  return template * label

### Dataset and Images Queue

In [0]:
def read_pil_image(path):
  """Open the image file at ``path`` and return it as an RGB PIL image.

  The conversion happens while the file is still open; the file is closed
  when the function returns.
  """
  with open(path, 'rb') as stream:
    return Image.open(stream).convert("RGB")


class DatasetFolder(torch.utils.data.Dataset):
  """Dataset of all images with a given extension found directly inside a folder.

  Args:
      root: directory that is searched (not recursively).
      extension: file extension without the dot, e.g. ``'jpg'``.
      transform: optional callable applied to every loaded PIL image.  When it
          is None the RGB PIL image is returned unchanged (previously the
          default of None made ``__getitem__`` fail with a TypeError).

  Raises:
      AssertionError: if no matching file is found.
  """

  def __init__(self, root, extension, transform=None):
    self.img_paths = glob.glob(os.path.join(root, '*.{}'.format(extension)))
    self.img_len = len(self.img_paths)
    assert self.img_len > 0, 'Found 0 images in {}'.format(root)
    self.transform = transform


  def __getitem__(self, index):
    """Load image ``index`` (wrapping around the dataset length) and transform it."""
    img = read_pil_image(self.img_paths[index % self.img_len])
    if self.transform is not None:
      img = self.transform(img)
    return img


  def __len__(self):
    """Return the number of image files found at construction time."""
    return self.img_len


class ImageFolder(DatasetFolder):
  """``DatasetFolder`` variant whose ``transform`` argument is mandatory."""

  def __init__(self, root, extension, transform):
    super(ImageFolder, self).__init__(root, extension, transform)

In [0]:
class ImagePool():
  """History buffer of previously generated images.

  While the pool is filling up, every sample is stored and returned as is.
  Once it is full, each call either (with probability 0.5) swaps the new
  sample with a randomly chosen stored one and returns the stored one, or
  returns the new sample without touching the pool.

  Changes compared to the previous version:
    * stored tensors are detached copies, so the pool no longer keeps the
      autograd graph of old generator outputs alive;
    * a non-positive ``pool_size`` disables the pool instead of crashing in
      ``random.randrange(0, 0)``.

  Args:
      pool_size: maximum number of tensors kept; <= 0 disables buffering.
  """

  def __init__(self, pool_size):
    self.pool_size=pool_size
    self.images = []

  def sample(self, image_data):
    """Return either ``image_data`` or an older stored tensor (see class docs)."""
    if self.pool_size <= 0:
      return image_data

    if len(self.images) < self.pool_size:
      self.images.append(image_data.detach().clone())
      return image_data

    if random.random() > 0.5:
      idx = random.randrange(self.pool_size)
      # The slot is overwritten, so the old entry can be handed out directly.
      previous = self.images[idx]
      self.images[idx] = image_data.detach().clone()
      return previous

    return image_data

### Model

In [0]:
# Output locations for this run (version tag v0.2.1); the directories are
# created immediately when this cell runs.
checkpoint_dir = 'ugatit/output/checkpoints/checkpoints_v0.2.1/'
images_dir = 'ugatit/output/images/images_v0.2.1/'
summary_dir = 'ugatit/output/summary/summary_v0.2.1/'
os.makedirs(checkpoint_dir, exist_ok=True)
os.makedirs(images_dir, exist_ok=True)
os.makedirs(summary_dir, exist_ok=True)


# Input image folders: domain A (training), domain B (training), domain A (test).
A_dir = 'data/shadow_USR/shadow_train/'
B_dir = 'data/shadow_USR/shadow_free/'
A_test_dir = 'data/shadow_USR/shadow_test/'



# When True, training resumes from the newest checkpoint in checkpoint_dir.
load_model = False
batch_size=1
image_size=256
ngf=64
# NOTE(review): ndf is not read by the code visible here (the discriminators
# are built with ngf) - confirm whether that is intended.
ndf=64
# Selects the light variant of the generator's gamma/beta FC stack.
light = True

# Loss weights: adversarial, cycle-consistency, identity and CAM terms.
weight_adv = 1
weight_cyc = 10
weight_idt = 10
# The CAM weight only scales the generator's CAM loss; the discriminator loss
# has no separately weighted CAM term.
weight_cam = 1000  


learning_rate=2e-4
# NOTE(review): beta1 is not read by the code visible here (the Adam betas are
# written out literally where the optimizers are created) - confirm.
beta1=.5
# Capacity of each buffer of previously generated images.
pool_size=50
mask_queue_size=50
# Number of residual blocks in the generator encoder and decoder.
n_blocks=4
slope=0.2
# Standard deviation used by the explicit weight initialisation.
stddev=0.02
weight_decay=0.0001
# NOTE(review): duplicate of the pool_size assignment above.
pool_size=50

input_nc=3
output_nc=3
device=torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
# Target value for "real" in the adversarial losses (label smoothing: 0.9 instead of 1).
real_label=0.9

# Training length and schedule: the learning rate is decayed during the second
# half of training when decay_flag is set.
n_steps = int(3e5)
decay_start = n_steps // 2
decay_flag=True
# Frequencies, in steps, of console/summary logging, checkpointing and sample rendering.
print_freq = 10
save_freq = 5000
test_freq = 100

In [0]:
def get_step(filename):
  """Extract the step number that directly precedes the ``.pt`` suffix.

  The previous pattern ``'(\\d+).pt'`` was not a raw string, left the dot
  unescaped and took the first match anywhere in the name, so a name such as
  ``run1xpt_params_0000300.pt`` yielded 1 instead of 300.

  Args:
      filename: checkpoint file name or path ending in ``<digits>.pt``.

  Returns:
      The step as an int.

  Raises:
      ValueError: if the name does not end in ``<digits>.pt``.
  """
  found = re.search(r'(\d+)\.pt$', filename)
  if found is None:
    raise ValueError('No step number found in checkpoint name: {}'.format(filename))
  return int(found.group(1))
  
def latest_checkpoint_files(check_dir, f):
  """Return the largest value of ``f(name)`` over the ``.pt`` files in ``check_dir``.

  Only entries ending in ``.pt`` are considered; previously every directory
  entry was passed to ``f``, so any unrelated file made the lookup fail.

  Args:
      check_dir: directory holding the checkpoints.
      f: callable mapping a file name to a comparable key (e.g. ``get_step``).

  Raises:
      ValueError: if the directory contains no ``.pt`` file.
  """
  checkpoint_names = [name for name in os.listdir(check_dir) if name.endswith('.pt')]
  return max(map(f, checkpoint_names))

In [0]:
class UGATIT(object):
  def __init__(self):
    """Copy the module-level hyper-parameters onto the instance."""
    # Data and hardware settings.
    self.img_size, self.batch_size = image_size, batch_size
    self.device = device

    # Network and optimizer settings.
    self.n_res, self.ch = n_blocks, ngf
    self.light = light
    self.lr, self.weight_decay = learning_rate, weight_decay

    # Loss weights.
    self.adv_weight, self.cycle_weight = weight_adv, weight_cyc
    self.identity_weight, self.cam_weight = weight_idt, weight_cam


  def build_model(self):
    """Create data loaders, networks, losses, optimizers and logging helpers.

    Everything is stored on ``self``: the two training datasets and loaders,
    both generators, the four discriminators (two with 7 layers, two with 5),
    the L1 / MSE / BCE-with-logits criteria, one Adam optimizer for the
    generators and one for the discriminators, the rho clipper, the two
    generated-image pools and the TensorBoard writer.
    """
    # --- DataLoader ---
    norm_stats = dict(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
    enlarged = self.img_size + 30
    train_transform = transforms.Compose([
        transforms.RandomHorizontalFlip(),
        # Resize slightly larger, then take a random crop of the target size.
        transforms.Resize((enlarged, enlarged)),
        transforms.RandomCrop(self.img_size),
        transforms.ToTensor(),
        transforms.Normalize(**norm_stats)
    ])
    test_transform = transforms.Compose([
        transforms.Resize((self.img_size, self.img_size)),
        transforms.ToTensor(),
        transforms.Normalize(**norm_stats)
    ])

    self.trainA = ImageFolder(A_dir, 'jpg', train_transform)
    self.trainB = ImageFolder(B_dir, 'jpg', train_transform)
    loader_options = dict(batch_size=self.batch_size, shuffle=True)
    self.trainA_loader = DataLoader(self.trainA, **loader_options)
    self.trainB_loader = DataLoader(self.trainB, **loader_options)

    # --- Generators and discriminators ---
    generator_options = dict(input_nc=3, output_nc=3, ngf=self.ch,
                             n_blocks=self.n_res, img_size=self.img_size,
                             light=self.light)
    self.genA2B = ResnetGenerator(**generator_options)
    self.genB2A = ResnetGenerator(**generator_options)
    self.disGA = Discriminator(input_nc=3, ndf=self.ch, n_layers=7)
    self.disGB = Discriminator(input_nc=3, ndf=self.ch, n_layers=7)
    self.disLA = Discriminator(input_nc=3, ndf=self.ch, n_layers=5)
    self.disLB = Discriminator(input_nc=3, ndf=self.ch, n_layers=5)

    # Explicit initialisation step; the order of the networks is kept.
    networks = (self.genA2B, self.genB2A,
                self.disGA, self.disGB,
                self.disLA, self.disLB)
    for network in networks:
        init_net(network, self.device, stddev=stddev)

    # --- Losses ---
    self.L1_loss = nn.L1Loss().to(self.device)
    self.MSE_loss = nn.MSELoss().to(self.device)
    self.BCE_loss = nn.BCEWithLogitsLoss().to(self.device)

    # --- Optimizers: one for both generators, one for all discriminators ---
    adam_options = dict(lr=self.lr, betas=(0.5, 0.999), weight_decay=self.weight_decay)
    generator_params = itertools.chain(self.genA2B.parameters(), self.genB2A.parameters())
    discriminator_params = itertools.chain(self.disGA.parameters(), self.disGB.parameters(),
                                           self.disLA.parameters(), self.disLB.parameters())
    self.G_optim = torch.optim.Adam(generator_params, **adam_options)
    self.D_optim = torch.optim.Adam(discriminator_params, **adam_options)

    # Rho clipper constraining the value of rho in AdaILN and ILN to [0, 1].
    self.Rho_clipper = RhoClipper(0, 1)

    # Image buffers so the discriminators are also updated on older fakes.
    self.fake_A_buffer = ImagePool(pool_size=pool_size)
    self.fake_B_buffer = ImagePool(pool_size=pool_size)

    # Summary writing to TensorBoard.
    self.writer = SummaryWriter(log_dir=summary_dir)


  def train(self):
    self.genA2B.train(), self.genB2A.train(), self.disGA.train(), self.disGB.train(), self.disLA.train(), self.disLB.train()

    start_iter = 1
    if load_model:
      print("---- Loading the model ----".upper())
      model_list = glob.glob(os.path.join(checkpoint_dir, '*.pt'))
      if len(model_list) > 0:
        start_iter = latest_checkpoint_files(checkpoint_dir, get_step)
        print('Load the model from {}'.format(start_iter))
        self.load(checkpoint_dir, start_iter)
        if decay_flag and start_iter > decay_start:
          self.G_optim.param_groups[0]['lr'] -= (lr / decay_start) * (start_iter - decay_start)
          self.D_optim.param_groups[0]['lr'] -= (lr / decay_start) * (start_iter - decay_start)
        start_iter += 1  # for not repeating the start_iter step

    # training loop
    try:
      print('training start !')
      start_time = time.time()
      counter=start_iter
      for step in tqdm_notebook(range(start_iter, n_steps + 1), total=n_steps-start_iter):
        if decay_flag and step > (n_steps // 2):
            self.G_optim.param_groups[0]['lr'] -= (self.lr / (n_steps // 2))
            self.D_optim.param_groups[0]['lr'] -= (self.lr / (n_steps // 2))

        try:
            real_A = trainA_iter.next()
        except:
            trainA_iter = iter(self.trainA_loader)
            real_A = trainA_iter.next()

        try:
            real_B = trainB_iter.next()
        except:
            trainB_iter = iter(self.trainB_loader)
            real_B = trainB_iter.next()

        real_A, real_B = real_A.to(self.device), real_B.to(self.device)

        # Update G
        self.G_optim.zero_grad()

        fake_A2B, fake_A2B_cam_logit, _ = self.genA2B(real_A)
        fake_B2A, fake_B2A_cam_logit, _ = self.genB2A(real_B)

        fake_A2B2A, _, _ = self.genB2A(fake_A2B)
        fake_B2A2B, _, _ = self.genA2B(fake_B2A)

        fake_A2A, fake_A2A_cam_logit, _ = self.genB2A(real_A)
        fake_B2B, fake_B2B_cam_logit, _ = self.genA2B(real_B)

        fake_GA_logit, fake_GA_cam_logit, _ = self.disGA(fake_B2A)
        fake_LA_logit, fake_LA_cam_logit, _ = self.disLA(fake_B2A)
        fake_GB_logit, fake_GB_cam_logit, _ = self.disGB(fake_A2B)
        fake_LB_logit, fake_LB_cam_logit, _ = self.disLB(fake_A2B)

        G_ad_loss_GA = self.MSE_loss(fake_GA_logit, ground_truth_tensor(fake_GA_logit, real_label))
        G_ad_cam_loss_GA = self.MSE_loss(fake_GA_cam_logit, ground_truth_tensor(fake_GA_cam_logit, real_label))
        G_ad_loss_LA = self.MSE_loss(fake_LA_logit, ground_truth_tensor(fake_LA_logit, real_label))
        G_ad_cam_loss_LA = self.MSE_loss(fake_LA_cam_logit, ground_truth_tensor(fake_LA_cam_logit, real_label))
        G_ad_loss_GB = self.MSE_loss(fake_GB_logit, ground_truth_tensor(fake_GB_logit, real_label))
        G_ad_cam_loss_GB = self.MSE_loss(fake_GB_cam_logit, ground_truth_tensor(fake_GB_cam_logit, real_label))
        G_ad_loss_LB = self.MSE_loss(fake_LB_logit, ground_truth_tensor(fake_LB_logit, real_label))
        G_ad_cam_loss_LB = self.MSE_loss(fake_LB_cam_logit, ground_truth_tensor(fake_LB_cam_logit, real_label))

        # Cycle loss
        G_recon_loss_A = self.L1_loss(fake_A2B2A, real_A)
        G_recon_loss_B = self.L1_loss(fake_B2A2B, real_B)

        # Identity loss
        G_identity_loss_A = self.L1_loss(fake_A2A, real_A)
        G_identity_loss_B = self.L1_loss(fake_B2B, real_B)

        # CAM loss
        G_cam_loss_A = self.BCE_loss(fake_B2A_cam_logit, torch.ones_like(fake_B2A_cam_logit).to(self.device)) + self.BCE_loss(fake_A2A_cam_logit, torch.zeros_like(fake_A2A_cam_logit).to(self.device))
        G_cam_loss_B = self.BCE_loss(fake_A2B_cam_logit, torch.ones_like(fake_A2B_cam_logit).to(self.device)) + self.BCE_loss(fake_B2B_cam_logit, torch.zeros_like(fake_B2B_cam_logit).to(self.device))

        # Combine all losses together
        G_loss_A =  self.adv_weight * (G_ad_loss_GA + G_ad_cam_loss_GA + G_ad_loss_LA + G_ad_cam_loss_LA) + self.cycle_weight * G_recon_loss_A + self.identity_weight * G_identity_loss_A + self.cam_weight * G_cam_loss_A
        G_loss_B = self.adv_weight * (G_ad_loss_GB + G_ad_cam_loss_GB + G_ad_loss_LB + G_ad_cam_loss_LB) + self.cycle_weight * G_recon_loss_B + self.identity_weight * G_identity_loss_B + self.cam_weight * G_cam_loss_B

        Generator_loss = G_loss_A + G_loss_B
        Generator_loss.backward()
        self.G_optim.step()

        # Update D
        self.D_optim.zero_grad()

        # update discriminators on the history of the generated samples
        fake_A2B = self.fake_B_buffer.sample(fake_A2B)
        fake_B2A = self.fake_A_buffer.sample(fake_B2A)

        real_GA_logit, real_GA_cam_logit, _ = self.disGA(real_A)
        real_LA_logit, real_LA_cam_logit, _ = self.disLA(real_A)
        real_GB_logit, real_GB_cam_logit, _ = self.disGB(real_B)
        real_LB_logit, real_LB_cam_logit, _ = self.disLB(real_B)

        fake_GA_logit, fake_GA_cam_logit, _ = self.disGA(fake_B2A.detach())
        fake_LA_logit, fake_LA_cam_logit, _ = self.disLA(fake_B2A.detach())
        fake_GB_logit, fake_GB_cam_logit, _ = self.disGB(fake_A2B.detach())
        fake_LB_logit, fake_LB_cam_logit, _ = self.disLB(fake_A2B.detach())

        D_ad_loss_GA = self.MSE_loss(real_GA_logit, ground_truth_tensor(real_GA_logit, real_label).to(self.device)) + \
          self.MSE_loss(fake_GA_logit, torch.zeros_like(fake_GA_logit).to(self.device))

        D_ad_cam_loss_GA = self.MSE_loss(real_GA_cam_logit, ground_truth_tensor(real_GA_cam_logit, real_label).to(self.device)) + \
          self.MSE_loss(fake_GA_cam_logit, torch.zeros_like(fake_GA_cam_logit).to(self.device))

        D_ad_loss_LA = self.MSE_loss(real_LA_logit, ground_truth_tensor(real_LA_logit, real_label)) + \
          self.MSE_loss(fake_LA_logit, torch.zeros_like(fake_LA_logit).to(self.device))

        D_ad_cam_loss_LA = self.MSE_loss(real_LA_cam_logit, ground_truth_tensor(real_LA_cam_logit, real_label)) + \
          self.MSE_loss(fake_LA_cam_logit, torch.zeros_like(fake_LA_cam_logit).to(self.device))

        D_ad_loss_GB = self.MSE_loss(real_GB_logit, ground_truth_tensor(real_GB_logit, real_label)) + \
          self.MSE_loss(fake_GB_logit, torch.zeros_like(fake_GB_logit).to(self.device))

        D_ad_cam_loss_GB = self.MSE_loss(real_GB_cam_logit, ground_truth_tensor(real_GB_cam_logit, real_label)) + \
          self.MSE_loss(fake_GB_cam_logit, torch.zeros_like(fake_GB_cam_logit).to(self.device))

        D_ad_loss_LB = self.MSE_loss(real_LB_logit, ground_truth_tensor(real_LB_logit, real_label)) + \
          self.MSE_loss(fake_LB_logit, torch.zeros_like(fake_LB_logit).to(self.device))

        D_ad_cam_loss_LB = self.MSE_loss(real_LB_cam_logit, ground_truth_tensor(real_LB_cam_logit, real_label)) + \
          self.MSE_loss(fake_LB_cam_logit, torch.zeros_like(fake_LB_cam_logit).to(self.device))

        D_loss_A = self.adv_weight * (D_ad_loss_GA + D_ad_cam_loss_GA + D_ad_loss_LA + D_ad_cam_loss_LA)
        D_loss_B = self.adv_weight * (D_ad_loss_GB + D_ad_cam_loss_GB + D_ad_loss_LB + D_ad_cam_loss_LB)

        Discriminator_loss = D_loss_A + D_loss_B
        Discriminator_loss.backward()
        self.D_optim.step()


        # clip parameter of AdaILN and ILN, applied after optimizer step
        self.genA2B.apply(self.Rho_clipper)
        self.genB2A.apply(self.Rho_clipper)

        if step % print_freq == 0:
          time_spent = time.time() - start_time
          h = time_spent // 3600
          m = (time_spent - h * 3600) // 60
          s = int(time_spent - h * 3600 - m * 60)
          print("[{}/{}]\nD_loss : {:.4f}\nG_loss: {:.4f}".format(step, n_steps, Discriminator_loss, Generator_loss))
          print("Time: {}:{}:{}".format(h, m, s))

          # Summary writing
          G_adv_loss_A = (G_ad_loss_GA + G_ad_cam_loss_GA + G_ad_loss_LA + G_ad_cam_loss_LA)
          G_adv_loss_B = (G_ad_loss_GB + G_ad_cam_loss_GB + G_ad_loss_LB + G_ad_cam_loss_LB)
          print(("loss_G_AB: {:.3f}\nloss_G_AB_GAN: {:.3f}\nloss_G_AB_identity: {:.3f}\n" + 
                "loss_G_AB_cycle: {:.3f}\nloss_G_AB_cam: {:.3f}\nloss_G_BA: {:.3f}\nloss_G_BA_GAN: {:.3f}\n" + 
                "loss_G_BA_identity: {:.3f}\nloss_G_BA_cycle: {:.3f}\nloss_G_BA_cam: {:.3f}\n").format(
                    G_loss_B, G_adv_loss_B, G_identity_loss_B, G_recon_loss_B, G_cam_loss_B, 
                    G_loss_A, G_adv_loss_A, G_identity_loss_A, G_recon_loss_A, G_cam_loss_A
                ))
          self.writer.add_scalar('G/G_AB', G_loss_B, global_step=step)
          self.writer.add_scalar('G/G_AB_GAN', G_adv_loss_B, global_step=step)
          self.writer.add_scalar('G/G_AB_identity', G_identity_loss_B, global_step=step)
          self.writer.add_scalar('G/G_AB_cycle', G_recon_loss_B, global_step=step)
          self.writer.add_scalar('G/G_AB_cam', G_cam_loss_B, global_step=step)

          self.writer.add_scalar('G/G_BA', G_loss_A, global_step=step)
          self.writer.add_scalar('G/G_BA_GAN', G_adv_loss_A, global_step=step)
          self.writer.add_scalar('G/G_BA_identity', G_identity_loss_A, global_step=step)
          self.writer.add_scalar('G/G_BA_cycle', G_recon_loss_A, global_step=step)
          self.writer.add_scalar('G/G_BA_cam', G_cam_loss_A, global_step=step)

          # Discriminator monitoring
          D_LA = D_ad_loss_LA + D_ad_cam_loss_LA
          D_GA = D_ad_loss_GA + D_ad_cam_loss_GA
          
          D_LB = D_ad_loss_LB + D_ad_cam_loss_LB
          D_GB = D_ad_loss_GB + D_ad_cam_loss_GB
          print(("loss_D_LA: {:.3f}\nloss_D_LA_GAN: {:.3f}\nloss_D_LA_cam: {:.3f}\n" + 
                "loss_D_GA: {:.3f}\nloss_D_GA_GAN: {:.3f}\nloss_D_GA_cam: {:.3f}\n" + 
                "loss_D_LB: {:.3f}\nloss_D_LB_GAN: {:.3f}\nloss_D_LB_cam: {:.3f}\n" +
                "loss_D_GB: {:.3f}\nloss_D_GB_GAN: {:.3f}\nloss_D_GB_cam: {:.3f}\n").format(
                    D_LA, D_ad_loss_LA, D_ad_cam_loss_LA, D_GA, D_ad_loss_GA, D_ad_cam_loss_GA,
                    D_LB, D_ad_loss_LB, D_ad_cam_loss_LB, D_GB, D_ad_loss_GB, D_ad_cam_loss_GB
                ))
          self.writer.add_scalar("D/D_LA", D_LA, global_step=step)
          self.writer.add_scalar("D/D_LA_GAN", D_ad_loss_LA, global_step=step)
          self.writer.add_scalar("D/D_LA_cam", D_ad_cam_loss_LA, global_step=step)
          self.writer.add_scalar("D/D_GA", D_GA, global_step=step)
          self.writer.add_scalar("D/D_GA_GAN", D_ad_loss_GA, global_step=step)
          self.writer.add_scalar("D/D_GA_cam", D_ad_cam_loss_GA, global_step=step)
          self.writer.add_scalar("D/D_LB", D_LB, global_step=step)
          self.writer.add_scalar("D/D_LB_GAN", D_ad_loss_LB, global_step=step)
          self.writer.add_scalar("D/D_LB_cam", D_ad_cam_loss_LB, global_step=step)
          self.writer.add_scalar("D/D_GB", D_GB, global_step=step)
          self.writer.add_scalar("D/D_GB_GAN", D_ad_loss_GB, global_step=step)
          self.writer.add_scalar("D/D_GB_cam", D_ad_cam_loss_GB, global_step=step)
          
        if step % save_freq == 0:
          print("Save the checkpoint : ", step)
          self.save(checkpoint_dir, step)

        if step % test_freq == 0:
        # copied from original implementation
          train_sample_num = 5

          A2B = np.zeros((image_size * 4, 0, 3))
          B2A = np.zeros((image_size * 4, 0, 3))

          self.genA2B.eval(), self.genB2A.eval(), self.disGA.eval(), self.disGB.eval(), self.disLA.eval(), self.disLB.eval()
          
          for _ in range(train_sample_num):
            try:
                real_A = trainA_iter.next()
            except:
                trainA_iter = iter(self.trainA_loader)
                real_A = trainA_iter.next()

            try:
                real_B = trainB_iter.next()
            except:
                trainB_iter = iter(self.trainB_loader)
                real_B = trainB_iter.next()
            real_A, real_B = real_A.to(device), real_B.to(device)

            fake_A2B, _, fake_A2B_heatmap = self.genA2B(real_A)
            fake_B2A, _, fake_B2A_heatmap = self.genB2A(real_B)

            fake_A2B2A, _, fake_A2B2A_heatmap = self.genB2A(fake_A2B)
            fake_B2A2B, _, fake_B2A2B_heatmap = self.genA2B(fake_B2A)

            fake_A2A, _, fake_A2A_heatmap = self.genB2A(real_A)
            fake_B2B, _, fake_B2B_heatmap = self.genA2B(real_B)

            A2B = np.concatenate((A2B, np.concatenate((RGB2BGR(tensor2numpy(denorm(real_A[0]))),
                                                        cam(tensor2numpy(fake_A2B_heatmap[0]), image_size),
                                                        RGB2BGR(tensor2numpy(denorm(fake_A2B[0]))),
                                                        RGB2BGR(tensor2numpy(denorm(fake_A2B2A[0])))), 0)), 1)

            B2A = np.concatenate((B2A, np.concatenate((RGB2BGR(tensor2numpy(denorm(real_B[0]))),
                                                        cam(tensor2numpy(fake_B2A_heatmap[0]), image_size),
                                                        RGB2BGR(tensor2numpy(denorm(fake_B2A[0]))),
                                                        RGB2BGR(tensor2numpy(denorm(fake_B2A2B[0])))), 0)), 1)
          cv2.imwrite(os.path.join(images_dir,'A2B_%07d.png' % step), A2B * 255.0)
          cv2.imwrite(os.path.join(images_dir,'B2A_%07d.png' % step), B2A * 255.0)
          self.genA2B.train(), self.genB2A.train(), self.disLA.train(), self.disLB.train(), self.disGA.train(), self.disGB.train()
        counter += 1
    except Exception as e:
      raise e
    finally:
      print("---- Save the checkpoint ----")
      self.save(checkpoint_dir, counter)
      print("---- Finished ----".upper())



  def load(self, path, step):
    """Restore the generators and discriminators from a saved checkpoint.

    Reads ``params_{step}.pt`` from ``path`` (the file layout produced by
    ``save``) and loads each stored state dict into the matching network.

    Args:
      path: Directory that contains the checkpoint file.
      step: Step number embedded in the checkpoint file name.

    Raises:
      FileNotFoundError: If the checkpoint file does not exist.
      KeyError: If the checkpoint lacks an entry for one of the networks.
    """
    checkpoint_file = os.path.join(path, 'params_{}.pt'.format(step))
    # Deserialize onto the CPU so a checkpoint written on a GPU can still be
    # restored on a machine without one. load_state_dict copies the values
    # into each network's existing parameters, so the networks stay on
    # whatever device they already live on.
    params = torch.load(checkpoint_file, map_location='cpu')
    for name in ('genA2B', 'genB2A', 'disGA', 'disGB', 'disLA', 'disLB'):
      getattr(self, name).load_state_dict(params[name])

  def save(self, path, step):
    """Write the state dicts of all six networks to one checkpoint file.

    The checkpoint is a dict keyed by network attribute name and is stored
    as ``params_{step}.pt`` inside ``path``.

    Args:
      path: Directory in which the checkpoint file is written.
      step: Step number embedded in the checkpoint file name.
    """
    network_names = ('genA2B', 'genB2A', 'disGA', 'disGB', 'disLA', 'disLB')
    snapshot = {name: getattr(self, name).state_dict() for name in network_names}
    target_file = os.path.join(path, 'params_{}.pt'.format(step))
    torch.save(snapshot, target_file)

In [12]:
# Create the trainer with its default constructor arguments, then call
# build_model(). build_model() is not shown in this cell; presumably it sets
# up the networks and data loaders that train() relies on -- confirm above.
model = UGATIT()
model.build_model()

In [0]:
# Start training. train() saves a checkpoint from its `finally` clause, so a
# checkpoint write is attempted even when the run fails or is interrupted.
model.train()

training start !


HBox(children=(IntProgress(value=0, max=299999), HTML(value='')))

[10/300000]
D_loss : 4.7980
G_loss: 2725.9902
Time: 0.0:0.0:16
loss_G_AB: 1578.548
loss_G_AB_GAN: 2.423
loss_G_AB_identity: 0.599
loss_G_AB_cycle: 0.588
loss_G_AB_cam: 1.564
loss_G_BA: 1147.442
loss_G_BA_GAN: 2.719
loss_G_BA_identity: 0.402
loss_G_BA_cycle: 0.437
loss_G_BA_cam: 1.136

loss_D_LA: 0.968
loss_D_LA_GAN: 0.695
loss_D_LA_cam: 0.273
loss_D_GA: 1.112
loss_D_GA_GAN: 0.578
loss_D_GA_cam: 0.533
loss_D_LB: 1.325
loss_D_LB_GAN: 0.713
loss_D_LB_cam: 0.612
loss_D_GB: 1.393
loss_D_GB_GAN: 0.582
loss_D_GB_cam: 0.811

[20/300000]
D_loss : 3.6381
G_loss: 1448.8079
Time: 0.0:0.0:33
loss_G_AB: 729.986
loss_G_AB_GAN: 1.858
loss_G_AB_identity: 0.172
loss_G_AB_cycle: 0.180
loss_G_AB_cam: 0.725
loss_G_BA: 718.822
loss_G_BA_GAN: 1.393
loss_G_BA_identity: 0.373
loss_G_BA_cycle: 0.408
loss_G_BA_cam: 0.710

loss_D_LA: 0.654
loss_D_LA_GAN: 0.323
loss_D_LA_cam: 0.332
loss_D_GA: 0.880
loss_D_GA_GAN: 0.401
loss_D_GA_cam: 0.479
loss_D_LB: 0.935
loss_D_LB_GAN: 0.479
loss_D_LB_cam: 0.456
loss_D_GB: 1.169

In [0]:
# Not Python: the string below is a JavaScript snippet kept inert as a string
# literal. It clicks the Colab "connect" toolbar button every 5 minutes and
# stops itself after 8 hours. Presumably it is meant to be pasted into the
# browser console to keep the session from disconnecting.
"""
function preventFromOff(){
  console.log("Click button");
  document.querySelector("colab-toolbar-button#connect").click() 
}
var timeout = 8 * 60 * 60 * 1000;
var delay = 5 * 60 * 1000;
var refreshId = setInterval(preventFromOff, delay);
setTimeout(() => {clearInterval(refreshId); console.log("Stopped script");}, timeout);
"""

In [0]:
%load_ext tensorboard

In [0]:
%tensorboard --logdir=ugatit/output/summary/summary_v0.2.1/