## Version 0.1.1
- Architectural changes
  - Dilated convolutions (+)
  - SENet (+)

- Stable training
  - Spectral normalization (+)
  - Label smoothing (+)
  - TTUR (+)

In [1]:
from google.colab import drive
import os
drive.mount('/content/drive/')
path_basic = 'drive/My Drive/gan_experiments'
os.chdir(path_basic)

Go to this URL in a browser: https://accounts.google.com/o/oauth2/auth?client_id=947318989803-6bn6qk8qdgf4n4g3pfee6491hc0brc4i.apps.googleusercontent.com&redirect_uri=urn%3aietf%3awg%3aoauth%3a2.0%3aoob&response_type=code&scope=email%20https%3a%2f%2fwww.googleapis.com%2fauth%2fdocs.test%20https%3a%2f%2fwww.googleapis.com%2fauth%2fdrive%20https%3a%2f%2fwww.googleapis.com%2fauth%2fdrive.photos.readonly%20https%3a%2f%2fwww.googleapis.com%2fauth%2fpeopleapi.readonly

Enter your authorization code:
··········
Mounted at /content/drive/


In [0]:
!ln -sf /opt/bin/nvidia-smi /usr/bin/nvidia-smi
!pip install gputil

!pip install psutil
!pip install humanize
import psutil
import humanize
import os
import GPUtil as GPU
GPUs = GPU.getGPUs()
# XXX: only one GPU on Colab and isn’t guaranteed
gpu = GPUs[0]
def printm():
 process = psutil.Process(os.getpid())
 print(f"GPU name: {gpu.name}")
 print("Gen RAM Free: " + humanize.naturalsize( psutil.virtual_memory().available ), " | Proc size: " + humanize.naturalsize( process.memory_info().rss))
 print("GPU RAM Free: {0:.0f}MB | Used: {1:.0f}MB | Util {2:3.0f}% | Total {3:.0f}MB".format(gpu.memoryFree, gpu.memoryUsed, gpu.memoryUtil*100, gpu.memoryTotal))
printm() 

Collecting gputil
  Downloading https://files.pythonhosted.org/packages/ed/0e/5c61eedde9f6c87713e89d794f01e378cfd9565847d4576fa627d758c554/GPUtil-1.4.0.tar.gz
Building wheels for collected packages: gputil
  Building wheel for gputil (setup.py) ... [?25l[?25hdone
  Created wheel for gputil: filename=GPUtil-1.4.0-cp36-none-any.whl size=7413 sha256=b37b26575b7f87f8c6c60c1e72ff3a76958a39997fe2fb9abe351038238508ae
  Stored in directory: /root/.cache/pip/wheels/3d/77/07/80562de4bb0786e5ea186911a2c831fdd0018bda69beab71fd
Successfully built gputil
Installing collected packages: gputil
Successfully installed gputil-1.4.0
GPU name: Tesla P100-PCIE-16GB
Gen RAM Free: 12.8 GB  | Proc size: 157.9 MB
GPU RAM Free: 16280MB | Used: 0MB | Util   0% | Total 16280MB


In [0]:
import torch
from torch.nn import init
from torch import nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from torch.utils.data import DataLoader,Dataset
from torch.autograd import Variable
from torch.utils.tensorboard import SummaryWriter 

from skimage.filters import threshold_otsu
import cv2
from PIL import Image
import functools
import itertools
import numpy as np
import random

import os
import sys
import glob
from tqdm import tqdm_notebook
import re

import matplotlib.pyplot as plt
import pickle
import time

### Additional modules implementation

#### Dilated convolutions

In [0]:
# test block(not used further)
class DilatedBlock(nn.Module):
  def __init__(self, dilations):
    super(DilatedBlock ,self).__init__()
    self.dilations = dilations
    k=3
    for d in dilations:
      p = int(((k - 1) * (d - 1) + k - 1) / 2)
      setattr(self, f"conv_{d}", nn.Conv2d(3,3,3,dilation=d,padding=p))

  def forward(self, x):
    print(f"Input shape: {x.shape}")
    h = self.conv_1(x)

    for d in self.dilations:
      x = getattr(self, f"conv_{d}")(x)
      print(f"Dilation-{d} shape: {x.shape}")
    return h

In [4]:
inp = torch.rand(1,3,64,64,)
dilations = [1,1,2,4,8,16,64,1]
model = DilatedBlock(dilations)
_ = model(inp)

Input shape: torch.Size([1, 3, 64, 64])
Dilation-1 shape: torch.Size([1, 3, 64, 64])
Dilation-1 shape: torch.Size([1, 3, 64, 64])
Dilation-2 shape: torch.Size([1, 3, 64, 64])
Dilation-4 shape: torch.Size([1, 3, 64, 64])
Dilation-8 shape: torch.Size([1, 3, 64, 64])
Dilation-16 shape: torch.Size([1, 3, 64, 64])
Dilation-64 shape: torch.Size([1, 3, 64, 64])
Dilation-1 shape: torch.Size([1, 3, 64, 64])


In [0]:
class SENet(nn.Module):
  def __init__(self, input_nc, r=16):
    super(SENet, self).__init__()
    self.input_nc = input_nc
    self.fc1 = nn.utils.spectral_norm(nn.Linear(input_nc, int(input_nc / r)))
    self.relu = nn.ReLU()
    self.fc2 = nn.utils.spectral_norm(nn.Linear(int(input_nc / r), input_nc))
    self.sigmoid = nn.Sigmoid()

  def forward(self, x):
    input_dim = x.shape[2]
    gap = F.avg_pool2d(x, input_dim)
    gap = gap.view(gap.shape[0], -1)

    h = self.fc1(gap)
    h = self.relu(h)
    h = self.fc2(h)
    attn = self.sigmoid(h)

    output = x * attn.unsqueeze(2).unsqueeze(3)
    return output

In [0]:
in_data = torch.rand(1,32,32,32)
model = SENet(32)
out = model(in_data)

In [0]:
# timer utils
import time
import functools
def timer(f):
  @functools.wraps(f)
  def wrapper(*args, **kwargs):
    start_time = time.time()
    r = f(*args, **kwargs)
    duration = time.time() - start_time
    result = {
        'result':r,
        'time':duration
    }
    return result
  return wrapper

### Blocks

In [0]:
class ResidualBlock(nn.Module):
    """
      SEDResidualBlock
      Parameters:
        dilation_factor -- dilated convolution hyperparam
        r -- reduction factor for SE bottleneck
    """
    def __init__(self, in_features, dilation_factor, r=4):
        super(ResidualBlock, self).__init__()
        k = 3
        d = dilation_factor
        pad = int(((k - 1) * (d - 1) + k - 1) / 2)

        conv_block = [  nn.ReflectionPad2d(pad),
                        nn.utils.spectral_norm(nn.Conv2d(in_features, in_features, 3, dilation=d)),
                        nn.InstanceNorm2d(in_features),
                        nn.ReLU(inplace=True),
                      
                        nn.ReflectionPad2d(pad),
                        nn.utils.spectral_norm(nn.Conv2d(in_features, in_features, 3, dilation=d)),
                        nn.InstanceNorm2d(in_features),
                        SENet(in_features, r=r)
                      ]

        self.conv_block = nn.Sequential(*conv_block)

    def forward(self, x):
        return x + self.conv_block(x)


class Generator_S2F(nn.Module):
    def __init__(self, input_nc, output_nc, n_residual_blocks=9,
                 dilation_factors=[1,1,2,4,8,16,2,1,1],
                 reduction=4):
        super(Generator_S2F, self).__init__()
        assert len(dilation_factors) == n_residual_blocks

        # Initial convolution block
        model = [   nn.ReflectionPad2d(3),
                    nn.utils.spectral_norm(nn.Conv2d(input_nc, 64, 7)),
                    nn.InstanceNorm2d(64),
                    nn.ReLU(inplace=True) ]

        # Downsampling
        in_features = 64
        out_features = in_features*2
        for _ in range(2):
            model += [  nn.utils.spectral_norm(nn.Conv2d(in_features, out_features, 3, stride=2, padding=1)),
                        nn.InstanceNorm2d(out_features),
                        nn.ReLU(inplace=True) ]
            in_features = out_features
            out_features = in_features*2

        # Residual blocks
        for i in range(n_residual_blocks):
          d = dilation_factors[i]
          model += [ResidualBlock(in_features,dilation_factor=d,r=reduction)]

        # Upsampling
        out_features = in_features//2
        for _ in range(2):
            model += [  nn.utils.spectral_norm(nn.ConvTranspose2d(in_features, out_features, 3, stride=2, padding=1, output_padding=1)),
                        nn.InstanceNorm2d(out_features),
                        nn.ReLU(inplace=True) ]
            in_features = out_features
            out_features = in_features//2

        # Output layer
        model += [  nn.ReflectionPad2d(3),
                    nn.utils.spectral_norm(nn.Conv2d(64, output_nc, 7)) ]
                    #nn.Tanh() ]

        self.model = nn.Sequential(*model)

    def forward(self, x):
        return (self.model(x) + x).tanh() #(min=-1, max=1) #just learn a residual


class Generator_F2S(nn.Module):
    def __init__(self, input_nc, output_nc, n_residual_blocks=9,
                 dilation_factors=[1,1,2,4,8,16,2,1,1],
                 reduction=4):
        super(Generator_F2S, self).__init__()
        assert len(dilation_factors) == n_residual_blocks

        # Initial convolution block
        model = [   nn.ReflectionPad2d(3),
                    nn.utils.spectral_norm(nn.Conv2d(input_nc+1, 64, 7)), # + mask
                    nn.InstanceNorm2d(64),
                    nn.ReLU(inplace=True) ]

        # Downsampling
        in_features = 64
        out_features = in_features*2
        for _ in range(2):
            model += [  nn.utils.spectral_norm(nn.Conv2d(in_features, out_features, 3, stride=2, padding=1)),
                        nn.InstanceNorm2d(out_features),
                        nn.ReLU(inplace=True) ]
            in_features = out_features
            out_features = in_features*2

        # Residual blocks
        for i in range(n_residual_blocks):
          d = dilation_factors[i]
          model += [ResidualBlock(in_features, dilation_factor=d, r=reduction)]

        # Upsampling
        out_features = in_features//2
        for _ in range(2):
            model += [  nn.utils.spectral_norm(nn.ConvTranspose2d(in_features, out_features, 3, stride=2, padding=1, output_padding=1)),
                        nn.InstanceNorm2d(out_features),
                        nn.ReLU(inplace=True) ]
            in_features = out_features
            out_features = in_features//2

        # Output layer
        model += [  nn.ReflectionPad2d(3),
                    nn.utils.spectral_norm(nn.Conv2d(64, output_nc, 7)) ]
                    #nn.Tanh() ]

        self.model = nn.Sequential(*model)

    def forward(self, x, mask):
        return (self.model(torch.cat((x, mask), 1)) + x).tanh() #(min=-1, max=1) #just learn a residual


class Discriminator(nn.Module):
    def __init__(self, input_nc):
        super(Discriminator, self).__init__()

        # A bunch of convolutions one after another
        
        model = [nn.utils.spectral_norm(nn.Conv2d(input_nc, 64, 4, stride=2, padding=1)),
                 nn.LeakyReLU(0.2, inplace=True) ]

        model += [nn.utils.spectral_norm(nn.Conv2d(64, 128, 4, stride=2, padding=1)),
                  nn.InstanceNorm2d(128),
                  nn.LeakyReLU(0.2, inplace=True) ]

        model += [nn.utils.spectral_norm(nn.Conv2d(128, 256, 4, stride=2, padding=1)),
                  nn.InstanceNorm2d(256),
                  nn.LeakyReLU(0.2, inplace=True) ]

        model += [nn.utils.spectral_norm(nn.Conv2d(256, 512, 4, padding=1)),
                  nn.InstanceNorm2d(512),
                  nn.LeakyReLU(0.2, inplace=True) ]

        # FCN classification layer
        model += [nn.utils.spectral_norm(nn.Conv2d(512, 1, 4, padding=1))]

        self.model = nn.Sequential(*model)

    def forward(self, x):
        x =  self.model(x)
        # Average pooling and flatten
        return F.avg_pool2d(x, x.size()[2:]).view(x.size()[0], -1) #global avg pool

In [0]:
# V0.1.0 num params
V010_G = 45.512
V010_D = 11.058
V010_S = 56.57

In [10]:
# summary of the models
def num_params(model):
  model_parameters = filter(lambda p: p.requires_grad, model.parameters())
  params = sum([np.prod(p.size()) for p in model_parameters])
  return params

dL = Discriminator(3)
g = Generator_S2F(3,3, reduction=4)

V011_D = round(num_params(dL) * 2 / 1e6, 3) * 2
V011_G = round(num_params(g) * 2 / 1e6, 3) * 2
V011_S = round(V011_D + V011_G, 3)

print('---- Summary models ----')
print("Number of parameters (in millions):")
print("{:10}{:10}{:20}".format("D", 'G', "Overall"))
print("{:10}{:10}{:20}".format(str(V011_D), str(V011_G), str(V011_S)))

print(f"GAIN PARAMS: {round((V011_S / V010_S - 1) * 100, 3)}%")

---- Summary models ----
Number of parameters (in millions):
D         G         Overall             
11.058    46.704    57.762              
GAIN PARAMS: 2.107%


### Buffers

In [0]:
class ReplayBuffer():
    def __init__(self, max_size=50):
        assert (max_size > 0), 'Empty buffer or trying to create a black hole. Be careful.'
        self.max_size = max_size
        self.data = []

    def push_and_pop(self, data):
        to_return = []
        for element in data.data:
            element = torch.unsqueeze(element, 0)
            if len(self.data) < self.max_size:
                self.data.append(element)
                to_return.append(element)
            else:
                if random.uniform(0,1) > 0.5:
                    i = random.randint(0, self.max_size-1)
                    to_return.append(self.data[i].clone())
                    self.data[i] = element
                else:
                    to_return.append(element)
        return Variable(torch.cat(to_return))

# class ImageDataset(Dataset):
#     def __init__(self, root, transforms_=None, unaligned=False, mode='train'):
#         self.transform = transforms.Compose(transforms_)
#         self.unaligned = unaligned

#         self.files_A = sorted(glob.glob(os.path.join(root, 'shadow_train') + '/*.*'))
#         self.files_B = sorted(glob.glob(os.path.join(root, 'shadow_free') + '/*.*'))

#     def __getitem__(self, index):
#         item_A = self.transform(Image.open(self.files_A[index % len(self.files_A)]))

#         if self.unaligned:
#             item_B = self.transform(Image.open(self.files_B[random.randint(0, len(self.files_B) - 1)]))
#         else:
#             item_B = self.transform(Image.open(self.files_B[index % len(self.files_B)]))

#         return {'A': item_A, 'B': item_B}

#     def __len__(self):
#         return max(len(self.files_A), len(self.files_B))



class ImageDataset(Dataset):
    def __init__(self, root, transforms_=None, mode='train'):
        self.transform = transforms.Compose(transforms_)
        self.files_A = sorted(glob.glob(os.path.join(root, '%s/train_A' % mode) + '/*.*'))
        self.files_B = sorted(glob.glob(os.path.join(root, '%s/train_C' % mode) + '/*.*'))

    @timer
    def __getitem__(self, index):
        item_A = self.transform(Image.open(self.files_A[index % len(self.files_A)]))
        item_B = self.transform(Image.open(self.files_B[random.randint(0, len(self.files_B) - 1)]))
        return {'A': item_A, 'B': item_B}

    def __len__(self):
        return max(len(self.files_A), len(self.files_B))

### Utils

In [0]:
to_pil = transforms.ToPILImage()
to_gray = transforms.Grayscale(num_output_channels=1)

class QueueMask():
    def __init__(self, length):
        self.max_length = length
        self.queue = []

    def insert(self, mask):
        if self.queue.__len__() >= self.max_length:
            self.queue.pop(0)

        self.queue.append(mask)

    def rand_item(self):
        assert self.queue.__len__() > 0, 'Error! Empty queue!'
        return self.queue[np.random.randint(0, self.queue.__len__())]

    def last_item(self):
        assert self.queue.__len__() > 0, 'Error! Empty queue!'
        return self.queue[self.queue.__len__()-1]


def mask_generator(shadow, shadow_free):
	im_f = to_gray(to_pil(((shadow_free.data.squeeze(0) + 1.0) * 0.5).cpu()))
	im_s = to_gray(to_pil(((shadow.data.squeeze(0) + 1.0) * 0.5).cpu()))

	diff = (np.asarray(im_f, dtype='float32')- np.asarray(im_s, dtype='float32')) # difference between shadow image and shadow_free image
	L = threshold_otsu(diff)
	mask = torch.tensor((np.float32(diff >= L)-0.5)/0.5).unsqueeze(0).unsqueeze(0).cuda() #-1.0:non-shadow, 1.0:shadow
	mask.requires_grad = False
	return mask



def tensor2image(tensor):
    image = 127.5*(tensor[0].cpu().float().numpy() + 1.0)
    if image.shape[0] == 1:
        image = np.tile(image, (3,1,1))
    return image.astype(np.uint8)


class LambdaLR():
    def __init__(self, n_epochs, offset, decay_start_epoch):
        assert ((n_epochs - decay_start_epoch) > 0), "Decay must start before the training session ends!"
        self.n_epochs = n_epochs
        self.offset = offset
        self.decay_start_epoch = decay_start_epoch

    def step(self, epoch):
        return 1.0 - max(0, epoch + self.offset - self.decay_start_epoch)/(self.n_epochs - self.decay_start_epoch)

def weights_init_normal(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        torch.nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find('BatchNorm2d') != -1:
        torch.nn.init.normal_(m.weight.data, 1.0, 0.02)
        torch.nn.init.constant_(m.bias.data, 0.0)

In [0]:
def read_paths(path):
  with open(path, 'rb') as f:
    paths = pickle.load(f)
  return paths

def mkdir(path):
  try:
    os.mkdir(path)
  except FileExistsError as e:
    pass

def save_test(A_path, B_path, save_path):
  start_time = time.time()
  A_paths = read_paths(A_path)
  B_paths = read_paths(B_path)
  
  assert len(A_paths) == len(B_paths)
  # read and preprocess the test images
  os.makedirs(save_path, exist_ok=True)

  mkdir(os.path.join(save_path, 'A_B'))
  mkdir(os.path.join(save_path, 'B_A'))

  mkdir(os.path.join(save_path, 'masks'))

  # load the model
  netG_A2B = Generator_S2F(input_nc, output_nc).to(device)
  netG_B2A = Generator_F2S(input_nc, output_nc).to(device)


  # load latest checkpoint
  netG_A2B.load_state_dict(torch.load(os.path.join(checkpoint_dir, 'netG_A2B.pth')))
  netG_B2A.load_state_dict(torch.load(os.path.join(checkpoint_dir, 'netG_B2A.pth')))

  # turn the validation mode
  netG_A2B.eval()
  netG_B2A.eval()

  # input tensors
  Tensor = torch.cuda.FloatTensor if torch.cuda.is_available() else torch.Tensor
  input_A = Tensor(batch_size, input_nc, image_size, image_size, 3)
  input_B = Tensor(batch_size, output_nc, image_size, image_size, 3)

  # input transformations
  img_transforms = transforms.Compose([
                                       transforms.Resize((image_size, image_size), interpolation=Image.BICUBIC),
                                       transforms.ToTensor(),
                                       transforms.Normalize((.5,.5,.5),(.5,.5,.5))
  ])
  to_pil = transforms.ToPILImage()

  image_queue = QueueMask(length=mask_queue_size)
  for i,(path_A, path_B) in enumerate(zip(A_paths, B_paths)):
    image_A = Image.open(path_A).convert("RGB")
    image_B = Image.open(path_A).convert("RGB")

    im_A = (img_transforms(image_A).unsqueeze(0)).to(device)
    im_B = (img_transforms(image_B).unsqueeze(0)).to(device)

    
    # Image.fromarray(np.array(transforms.Resize((image_size, image_size))(image_A))).save(os.path.join(save_path, 'A', 'A_{}.jpg'.format(i)))
    # Image.fromarray(np.array(transforms.Resize((image_size, image_size))(image_B))).save(os.path.join(save_path, 'B', 'B_{}.jpg'.format(i)))

    # generate A -> B
    A_B = netG_A2B(im_A)
    w,h = image_A.size

    current_mask = mask_generator(A_B, im_A)
    image_queue.insert(current_mask)
    A_B = .5 * (A_B + 1)
    A_B = np.array((to_pil(A_B.data.squeeze(0).cpu())))
    Image.fromarray(A_B).save(os.path.join(save_path, 'A_B', 'A_B_{}.jpg'.format(i)))

    # generate B -> A
    mask = image_queue.rand_item()
    B_A = netG_B2A(im_B, mask)
    w,h = image_B.size

    B_A = .5 * (B_A + 1)
    B_A = np.array((to_pil(B_A.data.squeeze(0).cpu())))
    Image.fromarray(B_A).save(os.path.join(save_path, 'B_A', 'B_A_{}.jpg'.format(i)))

    mask_cpu = np.array((to_pil(.5 * (current_mask.data + 1).squeeze(0).cpu())))
    Image.fromarray(mask_cpu).save(os.path.join(save_path, 'masks', 'mask_{}.jpg'.format(i)))

  print("---- Inference finished. Time : {} ----".format(time.time() - start_time).upper())
  return image_queue


def save_images(image_paths, save_path, domain):
  try:
    os.mkdir(save_path)
  except:
    pass

  for i in range(len(image_paths)):
    img = plt.imread(image_paths[i])
    plt.imsave(os.path.join(save_path,'{}_{}.jpg'.format(domain, i)), img)

### Train

In [0]:
checkpoint_dir = 'mask_shadow_gan/output/checkpoints/checkpoints_v0.1.1/1/'
images_dir = 'mask_shadow_gan/output/images/images_v0.1.1/1/'
summary_dir = 'mask_shadow_gan/output/summary/summary_v0.1.1/1/'
root_dir = 'data/ISTD_Dataset/'

# validation pathes
A_dir_inf = 'mask_shadow_gan/output/results/test_set_meta/test_paths/ISTD/shadow_path.pickle'
B_dir_inf = 'mask_shadow_gan/output/results/test_set_meta/test_paths/ISTD/free_path.pickle'
results_dir = 'mask_shadow_gan/output/results/v.0.1.1/1'



load_model = True
batch_size=1
image_size=256
ngf=64
ndf=64

lambda1=10
lambda2=10
identity_lambda = 0.5
learning_rate=2e-4
lr_D = 4e-4  # TTUR
lr_G = 1e-4  # TTUR

beta1=.5
mask_queue_size=50
slope=0.2
stddev=0.02

input_nc=3
output_nc=3

n_res_blocks=9
dilation_factors = [1,1,1,2,4,8,16,1,1]
REAL_LABEL=0.9  # label smoothing
reduction_factor=4

device=torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

n_epochs=202
decay_epoch=100
epoch=0
img_snapshot=500
model_snapshot=5
log_snapshot=10

coef_identity = 5
coef_cycle = 10
coef_adv = 1

In [0]:
class MaskShadowGAN(object):
  def __init__(self):
    pass

  def build(self):
    ###### Definition of variables ######
    # Networks
    print(f"------------------- Definition of variables -------------------")
    
    self.netG_A2B = Generator_S2F(input_nc, output_nc, 
                                  n_residual_blocks=n_res_blocks,
                                  dilation_factors=dilation_factors,
                                  reduction=reduction_factor)  # shadow to shadow_free
    self.netG_B2A = Generator_F2S(output_nc, input_nc,
                                  n_residual_blocks=n_res_blocks,
                                  dilation_factors=dilation_factors,
                                  reduction=reduction_factor)  # shadow_free to shadow
    self.netD_A = Discriminator(input_nc)
    self.netD_B = Discriminator(output_nc)


    self.netG_A2B.cuda()
    self.netG_B2A.cuda()
    self.netD_A.cuda()
    self.netD_B.cuda()

    self.netG_A2B.apply(weights_init_normal)
    self.netG_B2A.apply(weights_init_normal)
    self.netD_A.apply(weights_init_normal)
    self.netD_B.apply(weights_init_normal)

    # Lossess
    self.criterion_GAN = torch.nn.MSELoss()  # lsgan
    # criterion_GAN = torch.nn.BCEWithLogitsLoss() #vanilla
    self.criterion_cycle = torch.nn.L1Loss()
    self.criterion_identity = torch.nn.L1Loss()

    # Optimizers & LR schedulers
    self.optimizer_G = torch.optim.Adam(itertools.chain(self.netG_A2B.parameters(), self.netG_B2A.parameters()),
                    lr=lr_G, betas=(0.5, 0.999))
    self.optimizer_D_A = torch.optim.Adam(self.netD_A.parameters(), lr=lr_D, betas=(0.5, 0.999))
    self.optimizer_D_B = torch.optim.Adam(self.netD_B.parameters(), lr=lr_D, betas=(0.5, 0.999))

    self.lr_scheduler_G = torch.optim.lr_scheduler.LambdaLR(self.optimizer_G,
                              lr_lambda=LambdaLR(n_epochs, epoch, decay_epoch).step)
    self.lr_scheduler_D_A = torch.optim.lr_scheduler.LambdaLR(self.optimizer_D_A,
                              lr_lambda=LambdaLR(n_epochs, epoch, decay_epoch).step)
    self.lr_scheduler_D_B = torch.optim.lr_scheduler.LambdaLR(self.optimizer_D_B,
                              lr_lambda=LambdaLR(n_epochs, epoch, decay_epoch).step)
    
    Tensor = torch.cuda.FloatTensor
    self.input_A = Tensor(batch_size, input_nc, image_size, image_size)
    self.input_B = Tensor(batch_size, output_nc, image_size, image_size)
    self.target_real = Variable(Tensor(batch_size).fill_(REAL_LABEL), requires_grad=False)
    self.target_fake = Variable(Tensor(batch_size).fill_(0.0), requires_grad=False)
    self.mask_non_shadow = Variable(Tensor(batch_size, 1, image_size, image_size).fill_(-1.0), requires_grad=False) #-1.0 non-shadow

    self.fake_A_buffer = ReplayBuffer()
    self.fake_B_buffer = ReplayBuffer()

    # Dataset loader
    self.transforms_ = [
            transforms.Resize(int(image_size * 1.12), Image.BICUBIC),
            transforms.RandomCrop(image_size),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]

    self.dataloader = DataLoader(ImageDataset(root_dir, transforms_=self.transforms_),
                                              batch_size=batch_size, shuffle=True, num_workers=1)

    """Summary writing to tensorboard"""
    self.writer = SummaryWriter(log_dir=summary_dir)
    return self
    

  def train(self):
    step = 0
    epoch = 0
    if load_model:
      print("Resume training")
      epoch, step = self.load_model()
      # epoch+=1

    to_pil = transforms.ToPILImage()
    mask_queue =  QueueMask(self.dataloader.__len__()/4)

    ###### Training ######
    print(f"------------------- Start Training -------------------")
    epoch_cp = epoch
    try:
      for ep in tqdm_notebook(range(epoch, n_epochs), total=n_epochs-epoch):
        for i, _batch in enumerate(self.dataloader):
          # Set model input
          batch = _batch['result']
          time_batch = _batch['time']
          real_A = Variable(self.input_A.copy_(batch['A']))
          real_B = Variable(self.input_B.copy_(batch['B']))

          ###### Generators A2B and B2A ######
          start_time = time.time()
          self.optimizer_G.zero_grad()

          # Identity loss
          # G_A2B(B) should equal B if real B is fed
          same_B = self.netG_A2B(real_B)
          loss_identity_B = self.criterion_identity(same_B, real_B)  # ||Gb(b)-b||1
          # G_B2A(A) should equal A if real A is fed, so the mask should be all zeros
          same_A = self.netG_B2A(real_A, self.mask_non_shadow)
          loss_identity_A = self.criterion_identity(same_A, real_A)  # ||Ga(a)-a||1

          # GAN loss
          fake_B = self.netG_A2B(real_A)
          pred_fake = self.netD_B(fake_B)
          loss_GAN_A2B = self.criterion_GAN(pred_fake, self.target_real)  # log(Db(Gb(a)))

          mask_queue.insert(mask_generator(real_A, fake_B))

          fake_A = self.netG_B2A(real_B, mask_queue.rand_item())
          pred_fake = self.netD_A(fake_A)
          loss_GAN_B2A = self.criterion_GAN(pred_fake, self.target_real)  # log(Da(Ga(b)))

          # Cycle loss
          recovered_A = self.netG_B2A(fake_B, mask_queue.last_item()) # real shadow, false shadow free
          loss_cycle_ABA = self.criterion_cycle(recovered_A, real_A) * 10.0  # ||Ga(Gb(a))-a||1

          recovered_B = self.netG_A2B(fake_A)
          loss_cycle_BAB = self.criterion_cycle(recovered_B, real_B) * 10.0  # ||Gb(Ga(b))-b||1

          # Total loss
          loss_G = coef_identity * (loss_identity_A + loss_identity_B) + \
            coef_adv * (loss_GAN_A2B + loss_GAN_B2A) + coef_cycle * (loss_cycle_ABA + loss_cycle_BAB)

          time_G_forward = time.time() - start_time

          start_time = time.time()
          loss_G.backward()
          self.optimizer_G.step()
          time_G_backward = time.time() - start_time

          ###################################

          ###### Discriminator A ######
          start_time = time.time()
          self.optimizer_D_A.zero_grad()

          # Real loss
          pred_real = self.netD_A(real_A)
          loss_D_real = self.criterion_GAN(pred_real, self.target_real)  # log(Da(a))

          # Fake loss
          fake_A = self.fake_A_buffer.push_and_pop(fake_A)
          pred_fake = self.netD_A(fake_A.detach())
          loss_D_fake = self.criterion_GAN(pred_fake, self.target_fake)  # log(1-Da(G(b)))

          # Total loss
          loss_D_A = (loss_D_real + loss_D_fake) * 0.5 * coef_adv
          time_DA_forward = time.time() - start_time
          
          start_time = time.time()
          loss_D_A.backward()

          self.optimizer_D_A.step()
          time_DA_backward = time.time() - start_time

          ###################################

          ###### Discriminator B ######
          start_time = time.time()
          self.optimizer_D_B.zero_grad()

          # Real loss
          pred_real = self.netD_B(real_B)
          loss_D_real = self.criterion_GAN(pred_real, self.target_real)  # log(Db(b))

          # Fake loss
          fake_B = self.fake_B_buffer.push_and_pop(fake_B)
          pred_fake = self.netD_B(fake_B.detach())
          loss_D_fake = self.criterion_GAN(pred_fake, self.target_fake)  # log(1-Db(G(a)))

          # Total loss
          loss_D_B = (loss_D_real + loss_D_fake) * 0.5 * coef_adv
          time_DB_forward = time.time() - start_time

          start_time = time.time()

          loss_D_B.backward()
          self.optimizer_D_B.step()

          time_DB_backward = time.time() - start_time
          ###################################

          step += 1
          if step % img_snapshot == 0:
            img_fake_A = 0.5 * (fake_A.detach().data + 1.0)
            img_fake_A = (to_pil(img_fake_A.data.squeeze(0).cpu()))
            img_fake_A.save(os.path.join(images_dir, f"fake_A_{step}.png"))

            img_fake_B = 0.5 * (fake_B.detach().data + 1.0)
            img_fake_B = (to_pil(img_fake_B.data.squeeze(0).cpu()))
            img_fake_B.save(os.path.join(images_dir, f"fake_B_{step}.png"))

          # logging
          if step % log_snapshot == 0:
            self.writer.add_scalar('G/GAN_A2B', loss_GAN_A2B, global_step=step)
            self.writer.add_scalar('G/GAN_B2A', loss_GAN_B2A, global_step=step)
            self.writer.add_scalar('G/cycle_ABA', loss_cycle_ABA, global_step=step)
            self.writer.add_scalar('G/cycle_BAB', loss_cycle_BAB, global_step=step)
            self.writer.add_scalar('G/idt_A', loss_identity_A, global_step=step)
            self.writer.add_scalar('G/idt_B', loss_identity_B, global_step=step)

            self.writer.add_scalar('D/D_A', loss_D_A, global_step=step)
            self.writer.add_scalar('D/D_B', loss_D_B, global_step=step)

            self.writer.add_scalar('T/batch', time_batch, global_step=step)
            self.writer.add_scalar('T/G_forward', time_G_forward, global_step=step)
            self.writer.add_scalar('T/G_backward', time_G_backward, global_step=step)

            self.writer.add_scalar('T/D_A_forward', time_DA_forward, global_step=step)
            self.writer.add_scalar('T/D_A_backward', time_DA_backward, global_step=step)

            self.writer.add_scalar('T/D_B_forward', time_DB_forward, global_step=step)
            self.writer.add_scalar('T/D_B_backward', time_DB_backward, global_step=step)


          

        # Update learning rates
        self.lr_scheduler_G.step()
        self.lr_scheduler_D_A.step()
        self.lr_scheduler_D_B.step()


        # Save models checkpoints
        if ep % model_snapshot == 0:
          print(f"Save model -- epoch: {ep}, step: {step}")
          self.save_model(ep, step)

          print('Run validation...')
          save_test(A_dir_inf, B_dir_inf, os.path.join(results_dir, f'msg_{ep}'))

        epoch_cp = ep


    except Exception as e:
      raise e
    finally:
      print(f"Save model -- epoch: {epoch_cp}, step: {step}")
      self.save_model(epoch_cp, step)



  def save_model(self, epoch_num, step):
    np.savetxt(os.path.join(checkpoint_dir, 'epoch_num.txt'), [epoch_num])
    np.savetxt(os.path.join(checkpoint_dir, 'step.txt'), [step])

    torch.save(self.netG_A2B.state_dict(), os.path.join(checkpoint_dir, 'netG_A2B.pth'))
    torch.save(self.netG_B2A.state_dict(), os.path.join(checkpoint_dir, 'netG_B2A.pth'))
    torch.save(self.netD_A.state_dict(), os.path.join(checkpoint_dir, 'netD_A.pth'))
    torch.save(self.netD_B.state_dict(), os.path.join(checkpoint_dir, 'netD_B.pth'))

    torch.save(self.optimizer_G.state_dict(), os.path.join(checkpoint_dir, 'optimizer_G.pth'))
    torch.save(self.optimizer_D_A.state_dict(), os.path.join(checkpoint_dir, 'optimizer_D_A.pth'))
    torch.save(self.optimizer_D_B.state_dict(), os.path.join(checkpoint_dir, 'optimizer_D_B.pth'))

    torch.save(self.lr_scheduler_G.state_dict(), os.path.join(checkpoint_dir, 'lr_scheduler_G.pth'))
    torch.save(self.lr_scheduler_D_A.state_dict(), os.path.join(checkpoint_dir, 'lr_scheduler_D_A.pth'))
    torch.save(self.lr_scheduler_D_B.state_dict(), os.path.join(checkpoint_dir, 'lr_scheduler_D_B.pth'))


  def load_model(self):
    epoch = int(np.loadtxt(os.path.join(checkpoint_dir, 'epoch_num.txt')))
    step  = int(np.loadtxt(os.path.join(checkpoint_dir, 'step.txt')))
    
    self.netG_A2B.load_state_dict(torch.load(os.path.join(checkpoint_dir, 'netG_A2B.pth')))
    self.netG_B2A.load_state_dict(torch.load(os.path.join(checkpoint_dir, 'netG_B2A.pth')))
    self.netD_A.load_state_dict(torch.load(os.path.join(checkpoint_dir, 'netD_A.pth')))
    self.netD_B.load_state_dict(torch.load((os.path.join(checkpoint_dir, 'netD_B.pth'))))

    self.optimizer_G.load_state_dict(torch.load((os.path.join(checkpoint_dir, 'optimizer_G.pth'))))
    self.optimizer_D_A.load_state_dict(torch.load((os.path.join(checkpoint_dir, 'optimizer_D_A.pth'))))
    self.optimizer_D_B.load_state_dict(torch.load(os.path.join(checkpoint_dir, 'optimizer_D_B.pth')))

    self.lr_scheduler_G.load_state_dict(torch.load(os.path.join(checkpoint_dir, 'lr_scheduler_G.pth')))
    self.lr_scheduler_D_A.load_state_dict(torch.load(os.path.join(checkpoint_dir, 'lr_scheduler_D_A.pth')))
    self.lr_scheduler_D_B.load_state_dict(torch.load(os.path.join(checkpoint_dir, 'lr_scheduler_D_B.pth')))
    print(f"Model loaded -- {epoch}:{step}")
    return epoch, step


In [19]:
model = MaskShadowGAN().build()

------------------- Definition of variables -------------------


In [0]:
# turn on save
model.train()

Resume training
Model loaded -- 199:283402
------------------- Start Training -------------------


HBox(children=(IntProgress(value=0, max=3), HTML(value='')))

  return F.mse_loss(input, target, reduction=self.reduction)


In [0]:
"""function preventFromOff(){
  console.log("Click button");
  document.querySelector("colab-connect-button").shadowRoot.getElementById("connect").click()
}
var timeout = 8 * 60 * 60 * 1000;
var delay = 3 * 60 * 1000;
var refreshId = setInterval(preventFromOff, delay);
setTimeout(() => {clearInterval(refreshId); console.log("Stopped script");}, timeout);"""