In [0]:
import torch
from torch.nn import init
from torch import nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torch.autograd import Variable

from skimage.filters import threshold_otsu
from PIL import Image
import functools
import itertools
import numpy as np
import random

import os
import sys
import glob
from tqdm import tqdm_notebook
import re

import matplotlib.pyplot as plt
import pickle
import time

### Layers

In [0]:
def init_weights(net, stddev=.02):

  def weights_initializer(m):
    classname = m.__class__.__name__
    if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1):
      init.normal_(m.weight.data, mean=0.0, std=stddev)

      if hasattr(m, 'bias') and m.bias is not None:
        init.constant_(m.bias.data, 0.0)

  net.apply(weights_initializer)
  return net


def init_net(net, device, stddev=0.02):
  net.to(device)
  init_weights(net, stddev=stddev)
  return net


def build_generator(input_nc, device, n_blocks=9, ngf=64, stddev=0.02):
  model = ResNetGenerator(input_nc, n_blocks=n_blocks, ngf=ngf)
  return init_net(model, device=device, stddev=stddev)

def build_discriminator(input_nc, device, ndf=64, stddev=0.02, slope=.2):
  model = Disriminator(input_nc, ndf=ndf, slope=slope)
  return init_net(model, device=device, stddev=stddev)


class Generator_S2F(nn.Module):
  def __init__(self, input_nc, n_blocks=9, ngf=64):
    super(Generator_S2F, self).__init__()

    # c7s1_k
    model = [nn.ReflectionPad2d(3),
             nn.Conv2d(input_nc, ngf, 7, stride=1, padding=0),
             nn.InstanceNorm2d(ngf),
             nn.ReLU(inplace=True)]
    # 2x downsample
    model += [nn.Conv2d(ngf, ngf * 2, 3, stride=2, padding=1),
              nn.InstanceNorm2d(ngf * 2),
              nn.ReLU(True)]

    model += [nn.Conv2d(ngf * 2, ngf * 4, 3, stride=2, padding=1),
              nn.InstanceNorm2d(ngf * 4),
              nn.ReLU(True)]

    # resnet blocks
    for n in range(n_blocks):
      model += [ResNetBlock(ngf * 4, ngf * 4)]

    # 2x upsampling
    model += [nn.ConvTranspose2d(ngf * 4, ngf * 2, 3, stride=2, padding=1, 
                                 output_padding=1),
              nn.InstanceNorm2d(ngf * 2),
              nn.ReLU(True)]

    model += [nn.ConvTranspose2d(ngf * 2, ngf, 3, stride=2, padding=1, 
                                 output_padding=1),
              nn.InstanceNorm2d(ngf),
              nn.ReLU(True)]

    
    # c7s1_3
    model += [nn.ReflectionPad2d(3),
              nn.Conv2d(ngf, 3, 7, padding=0, stride=1)]

    self.model = nn.Sequential(*model)

  def forward(self, x):
    return (self.model(x) + x).tanh()

class Generator_F2S(nn.Module):
  def __init__(self, input_nc, output_nc, ngf=64, n_blocks=9):
    super(Generator_F2S, self).__init__()
    
    # c7s1_k
    model = [nn.ReflectionPad2d(3),
             nn.Conv2d(input_nc, ngf, 7, stride=1, padding=0),
             nn.InstanceNorm2d(ngf),
             nn.ReLU(inplace=True)]
    # 2x downsample
    model += [nn.Conv2d(ngf, ngf * 2, 3, stride=2, padding=1),
              nn.InstanceNorm2d(ngf * 2),
              nn.ReLU(True)]

    model += [nn.Conv2d(ngf * 2, ngf * 4, 3, stride=2, padding=1),
              nn.InstanceNorm2d(ngf * 4),
              nn.ReLU(True)]

    # resnet blocks
    for n in range(n_blocks):
      model += [ResNetBlock(ngf * 4, ngf * 4)]

    # 2x upsampling
    model += [nn.ConvTranspose2d(ngf * 4, ngf * 2, 3, stride=2, padding=1, 
                                 output_padding=1),
              nn.InstanceNorm2d(ngf * 2),
              nn.ReLU(True)]

    model += [nn.ConvTranspose2d(ngf * 2, ngf, 3, stride=2, padding=1, 
                                 output_padding=1),
              nn.InstanceNorm2d(ngf),
              nn.ReLU(True)]

    
    # c7s1_3
    model += [nn.ReflectionPad2d(3),
              nn.Conv2d(ngf, 3, 7, padding=0, stride=1)]

    self.model = nn.Sequential(*model)

  def forward(self, x, mask):
    gen = self.model(torch.cat((x, mask), 1))
    return (gen + x).tanh()


class ResNetBlock(nn.Module):
  def __init__(self, input_nc, output_nc, use_bias=True):
    super(ResNetBlock, self).__init__()

    model  = [nn.ReflectionPad2d(1),
              nn.Conv2d(input_nc, output_nc, 3, bias=use_bias),
              nn.InstanceNorm2d(output_nc),
              nn.ReLU(True)]

    
    model += [nn.ReflectionPad2d(1),
              nn.Conv2d(input_nc, output_nc, 3, bias=use_bias),
              nn.InstanceNorm2d(output_nc),
              nn.ReLU(True)]
    
    self.conv_block = nn.Sequential(*model)
  
  def forward(self, input):
    output = input + self.conv_block(input)
    return output


class Disriminator(nn.Module):
  # WARNING: Implemented 94x94 Patch Discriminator.
  def __init__(self, input_nc, ndf=64, slope=.2):
    super(Disriminator, self).__init__()
    model = [
             nn.Conv2d(input_nc, ndf, 4, stride=2, padding=1, bias=True),
             nn.LeakyReLU(slope, True)
    ]

    model += [
             nn.Conv2d(ndf, ndf * 2, 4, stride=2, padding=1),
             nn.InstanceNorm2d(ndf * 4),
             nn.LeakyReLU(slope, inplace=True),

             nn.Conv2d(ndf *2, ndf * 4, 4, stride=2, padding=1),
             nn.InstanceNorm2d(ndf * 4),
             nn.LeakyReLU(slope, inplace=True),

             # Use of ReflectionPadding and bias=False from orig. impl.
             # nn.Conv2d(ndf * 4, ndf * 8, stride=1, padding=1) -- must stride=1
             nn.Conv2d(ndf * 4, ndf * 8, 4, stride=2, padding=1),
             nn.InstanceNorm2d(ndf * 8),
             nn.LeakyReLU(slope, inplace=True)
    ]

    model += [
              nn.Conv2d(ndf * 8, 1, 4, stride=1, padding=1)
    ]

    self.model = nn.Sequential(*model)

  def forward(self, x):
    x = self.model(x)
    return F.avg_pool2d(x, x.size()[2:]).view(x.size()[0], -1)

In [0]:
class GANLoss(nn.Module):
  def __init__(self, device, real_target_label=1.0, fake_target_label=0.0):
    super(GANLoss, self).__init__()
    self.loss = nn.MSELoss()
    self.register_buffer('real_label', torch.tensor(real_target_label))
    self.register_buffer('fake_label', torch.tensor(fake_target_label))
    self.device = device


  def get_target_tensor(self, prediction, target_is_real):
    if target_is_real:
      target_tensor = self.real_label
    else:
      target_tensor = self.fake_label

    return target_tensor.expand_as(prediction)

  def __call__(self, prediction, target_is_real):
    target_tensor = self.get_target_tensor(prediction, target_is_real)
    loss = self.loss(prediction, target_tensor.to(self.device))
    return loss

In [5]:
# summary of the models
def num_params(model):
  model_parameters = filter(lambda p: p.requires_grad, model.parameters())
  params = sum([np.prod(p.size()) for p in model_parameters])
  return params

dL = Disriminator(3)
g = Generator_S2F(3)

d_num = round(num_params(dL) * 2 / 1e6, 3) * 2
g_num = round(num_params(g) * 2 / 1e6, 3) * 2
overall = round(d_num + g_num, 3)

print('---- Summary models ----')
print("Number of parameters (in millions):")
print("{:10}{:10}{:20}".format("D", 'G', "Overall"))
print("{:10}{:10}{:20}".format(str(d_num), str(g_num), str(overall)))

---- Summary models ----
Number of parameters (in millions):
D         G         Overall             
11.058    45.512    56.57               


### DataLoader

In [0]:
class UnalignedDataset(torch.utils.data.Dataset):
  def __init__(self, A_dir, B_dir, image_size):
    self.A_paths = sorted(glob.glob(os.path.join(A_dir, '*.jpg')))
    self.B_paths = sorted(glob.glob(os.path.join(B_dir, '*.jpg')))

    self.A_size = len(self.A_paths)
    self.B_size = len(self.B_paths)

    self.image_size=image_size
    self.transform_A = self.get_transform()
    self.transform_B = self.get_transform()

  def __getitem__(self, index_A):
    index_B = random.randint(0, self.B_size-1)
    A_path = self.A_paths[index_A % self.A_size]
    B_path = self.B_paths[index_B]

    A_img = Image.open(A_path).convert('RGB')
    B_img = Image.open(B_path).convert('RGB')

    A = self.transform_A(A_img)
    B = self.transform_B(B_img)


    return {'A':A, 'B':B, 'A_path':A_path, 'B_path':B_path}

    
  def __len__(self):
    return max(self.A_size, self.B_size)


  def get_transform(self):
    transform_list = [
                      transforms.Resize(int(self.image_size * 1.12), Image.BICUBIC),
                      transforms.RandomCrop(self.image_size),
                      transforms.RandomHorizontalFlip(),
                      transforms.ToTensor(),
                      transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ]
    return transforms.Compose(transform_list)

### Pool,queue,sheduler


In [0]:
class ImagePool():
  def __init__(self, pool_size):
    self.pool_size=pool_size
    self.images = []

  def sample(self, image_data):
    if len(self.images) < self.pool_size:
      self.images.append(image_data)
      return image_data
    
    p = random.random()
    if p > 0.5:
      idx = random.randrange(0, self.pool_size)
      tmp_data = self.images[idx].clone()
      self.images[idx] = image_data.clone()
      return tmp_data

    return image_data

In [0]:
class QueueMask():
  def __init__(self, queue_size):
    self.queue_size=queue_size
    self.queue = []

  def insert(self, mask):
    if len(self.queue) >= self.queue_size:
      self.queue.pop(0)

    self.queue.append(mask)
  
  def rand_item(self):
    assert len(self.queue) > 0, 'Error! Empty queue.'
    return self.queue[random.randint(0, len(self.queue) - 1)]

  def last_item(self):
    assert len(self.queue) > 0, 'Error! Empty queue.'
    return self.queue[len(self.queue) - 1]

In [0]:
class LambdaLR():
    def __init__(self, n_epochs, offset, decay_start_epoch):
        assert ((n_epochs - decay_start_epoch) > 0), "Decay must start before the training session ends!"
        self.n_epochs = n_epochs
        self.offset = offset
        self.decay_start_epoch = decay_start_epoch

    def step(self, epoch):
        return 1.0 - max(0, epoch + self.offset - self.decay_start_epoch)/(self.n_epochs - self.decay_start_epoch)

### Mask generator

In [0]:
to_pil = transforms.ToPILImage()
to_gray = transforms.Grayscale(num_output_channels=1)

def mask_generator(shadow, shadow_free):
  shadow_gray = to_gray(to_pil( ((shadow.data.squeeze(0) + 1) * .5).cpu()) )
  shadow_free_gray = to_gray(to_pil(((shadow_free.data.squeeze(0) + 1) * .5).cpu()))
  diff = np.asarray(shadow_free_gray, dtype='float32') - np.asarray(shadow_gray, dtype='float32')

  T = threshold_otsu(diff)
  mask = torch.tensor((np.float32(diff >= T) - .5) / .5).unsqueeze(0).unsqueeze(0).cuda()
  mask.requires_grad = False

  return mask

### Model

In [0]:

checkpoint_dir = 'mask_shadow_gan/output/checkpoints_v1/'
images_dir = 'mask_shadow_gan/output/images_v1/'
summary_dir = 'mask_shadow_gan/output/summary_v1/'

A_dir = 'data/shadow_USR/shadow_train/'
B_dir = 'data/shadow_USR/shadow_free/'


load_model = True
batch_size=1
image_size=256
ngf=64
ndf=64

lambda1=10
lambda2=10
identity_lambda = 0.5
learning_rate=2e-4
beta1=.5
pool_size=50
mask_queue_size=50
n_blocks=9 if image_size==256 else 6
slope=0.2
stddev=0.02

input_nc=3
output_nc=3
device=torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
real_label=1.

n_epochs=200
decay_start=100
offset=0

In [0]:
def get_step(filename):
  match = re.findall(r'(\d+).pth', filename)[0]
  return int(match)

def latest_checkpoint_files(check_dir, f):
  return max(map(f, os.listdir(check_dir)))

In [0]:
def train():

  ### Definition of variables ###
  print("---- Define the networks ----".upper())
  # Networks
  netG_A2B = Generator_S2F(input_nc, n_blocks=n_blocks, ngf=ngf)
  netG_B2A = Generator_F2S(input_nc+1, output_nc, n_blocks=n_blocks, ngf=ngf)
  netD_B = Disriminator(input_nc, ndf=ndf, slope=slope)
  netD_A = Disriminator(input_nc, ndf=ndf, slope=slope)

  # initializing the nets and transferring to gpu
  print("---- Initializing the networks ----".upper())
  init_net(netG_A2B, device, stddev=stddev)
  init_net(netG_B2A, device, stddev=stddev)
  init_net(netD_B, device, stddev=stddev)
  init_net(netD_A, device, stddev=stddev)

  # Losses
  print("---- Define the losses ----".upper())
  criterion_GAN = torch.nn.MSELoss()
  criterion_cycle = torch.nn.L1Loss()
  criterion_identity = torch.nn.L1Loss()

  # Shedulers and Optimizers
  print("---- Define the schedulers and optimizers ----".upper())
  optimizer_G = torch.optim.Adam(itertools.chain(netG_A2B.parameters(), netG_B2A.parameters()), 
                                 betas=(beta1, 0.999), lr=learning_rate)
  optimizer_D_A = torch.optim.Adam(netD_A.parameters(), 
                                 betas=(beta1, 0.999), lr=learning_rate)
  
  optimizer_D_B = torch.optim.Adam(netD_B.parameters(), 
                                 betas=(beta1, 0.999), lr=learning_rate)

  lr_scheduler_G = torch.optim.lr_scheduler.LambdaLR(optimizer_G,
                                                    lr_lambda=LambdaLR(n_epochs, offset, decay_start).step)
  lr_scheduler_D_A = torch.optim.lr_scheduler.LambdaLR(optimizer_D_A,
                                                    lr_lambda=LambdaLR(n_epochs, offset, decay_start).step)
  lr_scheduler_D_B = torch.optim.lr_scheduler.LambdaLR(optimizer_D_B,
                                                    lr_lambda=LambdaLR(n_epochs, offset, decay_start).step)
  
  # Resume training - loading the state dicts
  latest_step = 0
  if load_model:
    print("---- Resume training from the latest checkpoint ----".upper())
    latest_step = latest_checkpoint_files(checkpoint_dir, get_step)
    netG_A2B.load_state_dict(torch.load(os.path.join(checkpoint_dir, 'netG_A2B_{}.pth'.format(latest_step))))
    netG_B2A.load_state_dict(torch.load(os.path.join(checkpoint_dir, 'netG_B2A_{}.pth'.format(latest_step))))
    netD_A.load_state_dict(torch.load(os.path.join(checkpoint_dir, 'netD_A_{}.pth'.format(latest_step))))
    netD_B.load_state_dict(torch.load(os.path.join(checkpoint_dir, 'netD_B_{}.pth'.format(latest_step))))

    optimizer_G.load_state_dict(torch.load(os.path.join(checkpoint_dir, 'optimizer_G_{}.pth'.format(latest_step))))
    optimizer_D_A.load_state_dict(torch.load(os.path.join(checkpoint_dir, 'optimizer_D_A_{}.pth'.format(latest_step))))
    optimizer_D_B.load_state_dict(torch.load(os.path.join(checkpoint_dir, 'optimizer_D_B_{}.pth'.format(latest_step))))

    lr_scheduler_G.load_state_dict(torch.load(os.path.join(checkpoint_dir, 'lr_scheduler_G_{}.pth'.format(latest_step))))
    lr_scheduler_D_A.load_state_dict(torch.load(os.path.join(checkpoint_dir, 'lr_scheduler_D_A_{}.pth'.format(latest_step))))
    lr_scheduler_D_B.load_state_dict(torch.load(os.path.join(checkpoint_dir, 'lr_scheduler_D_B_{}.pth'.format(latest_step))))
  
  ### Inputs and targets allocations
  print("---- Allocating the networks ----".upper())
  Tensor = torch.cuda.FloatTensor if torch.cuda.is_available() else torch.Tensor
  input_A = Tensor(batch_size, 3, image_size, image_size)
  input_B = Tensor(batch_size, 3, image_size, image_size)
  
  target_real = Variable(Tensor(batch_size).fill_(real_label), requires_grad=False)
  target_fake = Variable(Tensor(batch_size).fill_(0.), requires_grad=False)

  mask_non_shadow = Variable(Tensor(batch_size, 1, image_size, image_size).fill_(-1.0), 
                             requires_grad=False)  # Ml
  fake_A_buffer = ImagePool(pool_size)
  fake_B_buffer = ImagePool(pool_size)

  # Data loader
  dataloader = DataLoader(UnalignedDataset(A_dir, B_dir, image_size), 
                          batch_size=batch_size, shuffle=True, num_workers=2)
  mask_queue = QueueMask(queue_size=dataloader.__len__() / 4)

  print('---- Start training ----'.upper())
  iter_num = latest_step * dataloader.__len__()
  start_time = time.time()
  try:
    for epoch in range(latest_step, n_epochs):
      for i, batch in enumerate(dataloader):

        real_A = Variable(input_A.copy_(batch['A']))
        real_B = Variable(input_B.copy_(batch['B']))
        
        # Generators optimizing
        optimizer_G.zero_grad()

        ## Identity loss
        same_B = netG_A2B(real_B)
        identity_G_A2B_loss = criterion_identity(same_B, real_B)
        identity_G_B2A_loss = criterion_identity(netG_B2A(real_A, mask_non_shadow), real_A)

        ## GAN Loss
        fake_B = netG_A2B(real_A)
        gan_G_A2B_loss = criterion_GAN(netD_B(fake_B), target_real)

        mask_queue.insert(mask_generator(real_A, fake_B))

        mask_use = mask_queue.rand_item()
        fake_A = netG_B2A(real_B, mask_use)
        gan_G_B2A_loss = criterion_GAN(netD_A(fake_A), target_real)

        ## Cycle loss
        recovered_A = netG_B2A(fake_B, mask_queue.last_item())
        cycle_G_ABA_loss = criterion_cycle(recovered_A, real_A)

        recovered_B = netG_A2B(fake_A)
        cycle_G_BAB_loss = criterion_cycle(recovered_B, real_B)

        ## Total loss and optimizing
        loss_G = identity_lambda * (identity_G_A2B_loss + identity_G_B2A_loss) + \
          (lambda1 * cycle_G_ABA_loss + lambda2 * cycle_G_BAB_loss) + (gan_G_A2B_loss + gan_G_B2A_loss)
        loss_G.backward()

        optimizer_G.step()

        # Discriminator A optimizing 
        optimizer_D_A.zero_grad()
        
        # real loss
        real_D_A = netD_A(real_A)
        real_D_A_loss = criterion_GAN(real_D_A, target_real)
        
        # fake loss
        fake_A = fake_A_buffer.sample(fake_A)
        fake_D_A = netD_A(fake_A.detach())
        fake_D_A_loss = criterion_GAN(fake_D_A, target_fake)

        # Total loss
        loss_D_A = 0.5 * (real_D_A_loss + fake_D_A_loss)
        loss_D_A.backward()
        optimizer_D_A.step()
        
        
        # Discriminator B optimizing
        optimizer_D_B.zero_grad()

        # real loss
        real_D_B = netD_B(real_B)
        real_D_B_loss = criterion_GAN(real_D_B, target_real)
        
        # fake loss
        fake_B = fake_B_buffer.sample(fake_B)
        fake_D_B = netD_B(fake_B.detach())
        fake_D_B_loss = criterion_GAN(fake_D_B, target_fake)

        # Total loss
        loss_D_B = 0.5 * (real_D_B_loss + fake_D_B_loss)
        loss_D_B.backward()

        optimizer_D_B.step()
        iter_num += 1

        if i % 1000 == 0:
          log = '[iter %d], [loss_G %.5f], [loss_G_identity %.5f], [loss_G_GAN %.5f],' \
              '[loss_G_cycle %.5f], [loss_D %.5f]' % \
              (iter_num, loss_G, (identity_G_A2B_loss + identity_G_B2A_loss), (gan_G_A2B_loss + gan_G_B2A_loss),
              (cycle_G_ABA_loss + cycle_G_BAB_loss), (loss_D_A + loss_D_B))
          print(log)

          img_fake_A = 0.5 * (fake_A.detach().data + 1.0)
          img_fake_A = (to_pil(img_fake_A.data.squeeze(0).cpu()))
          img_fake_A.save(os.path.join(images_dir, 'fake_A_{}.png'.format(iter_num)))

          img_fake_B = 0.5 * (fake_B.detach().data + 1.0)
          img_fake_B = (to_pil(img_fake_B.data.squeeze(0).cpu()))
          img_fake_B.save(os.path.join(images_dir, 'fake_B_{}.png'.format(iter_num)))
          duration = time.time() - start_time
          if duration > 6 * 3600:
            print("---- 6 hours limit reached ----".upper())
            break
          print("Time from start : ", time.time() - start_time)


      # schedulers
      lr_scheduler_G.step()
      lr_scheduler_D_A.step()
      lr_scheduler_D_B.step()

      # checkpoints
      if (epoch + 1) % 5 == 0:
        print("Saving the checkpoint - {}".format(epoch + 1))
        torch.save(netG_A2B.state_dict(), os.path.join(checkpoint_dir, 'netG_A2B_{}.pth'.format(epoch+1)))
        torch.save(netG_B2A.state_dict(), os.path.join(checkpoint_dir, 'netG_B2A_{}.pth'.format(epoch+1)))
        torch.save(netD_A.state_dict(), os.path.join(checkpoint_dir, 'netD_A_{}.pth'.format(epoch+1)))
        torch.save(netD_B.state_dict(), os.path.join(checkpoint_dir, 'netD_B_{}.pth'.format(epoch+1)))

        torch.save(optimizer_G.state_dict(), os.path.join(checkpoint_dir, 'optimizer_G_{}.pth'.format(epoch+1)))
        torch.save(optimizer_D_A.state_dict(), os.path.join(checkpoint_dir, 'optimizer_D_A_{}.pth'.format(epoch+1)))
        torch.save(optimizer_D_B.state_dict(), os.path.join(checkpoint_dir, 'optimizer_D_B_{}.pth'.format(epoch+1)))

        torch.save(lr_scheduler_G.state_dict(), os.path.join(checkpoint_dir, 'lr_scheduler_G_{}.pth'.format(epoch+1)))
        torch.save(lr_scheduler_D_A.state_dict(), os.path.join(checkpoint_dir, 'lr_scheduler_D_A_{}.pth'.format(epoch+1)))
        torch.save(lr_scheduler_D_B.state_dict(), os.path.join(checkpoint_dir, 'lr_scheduler_D_B_{}.pth'.format(epoch+1)))
      
      if epoch > 200:
        print("---- Reached the 200 epochs ----".upper())
        break

  except Exception as e:
    print("---- EXCEPTION ----")
    print(e)
    raise e
    print("@ line {}".format(sys.exc_info()[-1].tb_lineno))
  finally:
    # save the checkpoint
    
    print("Saving the checkpoint - {}".format(epoch + 1))
    torch.save(netG_A2B.state_dict(), os.path.join(checkpoint_dir, 'netG_A2B_{}.pth'.format(epoch+1)))
    torch.save(netG_B2A.state_dict(), os.path.join(checkpoint_dir, 'netG_B2A_{}.pth'.format(epoch+1)))
    torch.save(netD_A.state_dict(), os.path.join(checkpoint_dir, 'netD_A_{}.pth'.format(epoch+1)))
    torch.save(netD_B.state_dict(), os.path.join(checkpoint_dir, 'netD_B_{}.pth'.format(epoch+1)))

    torch.save(optimizer_G.state_dict(), os.path.join(checkpoint_dir, 'optimizer_G_{}.pth'.format(epoch+1)))
    torch.save(optimizer_D_A.state_dict(), os.path.join(checkpoint_dir, 'optimizer_D_A_{}.pth'.format(epoch+1)))
    torch.save(optimizer_D_B.state_dict(), os.path.join(checkpoint_dir, 'optimizer_D_B_{}.pth'.format(epoch+1)))

    torch.save(lr_scheduler_G.state_dict(), os.path.join(checkpoint_dir, 'lr_scheduler_G_{}.pth'.format(epoch+1)))
    torch.save(lr_scheduler_D_A.state_dict(), os.path.join(checkpoint_dir, 'lr_scheduler_D_A_{}.pth'.format(epoch+1)))
    torch.save(lr_scheduler_D_B.state_dict(), os.path.join(checkpoint_dir, 'lr_scheduler_D_B_{}.pth'.format(epoch+1)))

    print("---- Finish ----".upper())

In [0]:
train()

---- DEFINE THE NETWORKS ----
---- INITIALIZING THE NETWORKS ----
---- DEFINE THE LOSSES ----
---- DEFINE THE SCHEDULERS AND OPTIMIZERS ----
---- RESUME TRAINING FROM THE LATEST CHECKPOINT ----
---- ALLOCATING THE NETWORKS ----
---- START TRAINING ----


  return F.mse_loss(input, target, reduction=self.reduction)


[iter 342301], [loss_G 2.39170], [loss_G_identity 0.07929], [loss_G_GAN 1.67406],[loss_G_cycle 0.06780], [loss_D 0.01205]
Time from start :  2.43892502784729
[iter 343301], [loss_G 2.40191], [loss_G_identity 0.07386], [loss_G_GAN 1.61789],[loss_G_cycle 0.07471], [loss_D 0.09941]
Time from start :  707.6419966220856
[iter 344257], [loss_G 2.15431], [loss_G_identity 0.20001], [loss_G_GAN 1.06660],[loss_G_cycle 0.09877], [loss_D 0.08525]
Time from start :  1426.5825719833374
[iter 345257], [loss_G 2.27064], [loss_G_identity 0.10003], [loss_G_GAN 1.42336],[loss_G_cycle 0.07973], [loss_D 0.03857]
Time from start :  2134.9565091133118
[iter 346213], [loss_G 1.95389], [loss_G_identity 0.09544], [loss_G_GAN 1.00708],[loss_G_cycle 0.08991], [loss_D 0.06158]
Time from start :  2813.0236287117004
[iter 347213], [loss_G 1.89872], [loss_G_identity 0.10401], [loss_G_GAN 1.21937],[loss_G_cycle 0.06273], [loss_D 0.01771]
Time from start :  3522.4053750038147
[iter 348169], [loss_G 2.34890], [loss_G_id

In [0]:
images_paths = glob.glob(os.path.join(images_dir, '*.png'))

In [0]:
def get_step_(filename):
  match = re.findall(r'(\d+).png', filename)[0]
  return int(match)

In [0]:
ind = list(map(get_step, images_paths))

In [0]:
k = 30 * 2
im_ind = sorted(zip(images_paths, ind), reverse=True, key=lambda x: x[1])[k:]

In [0]:
to_del_path = [x[0] for x in im_ind]
_ = [os.remove(path) for path in to_del_path]

### Inference

In [0]:
def read_paths(path):
  with open(path, 'rb') as f:
    paths = pickle.load(f)
  return paths

def mkdir(path):
  try:
    os.mkdir(path)
  except FileExistsError as e:
    pass

def save_test(A_path, B_path, save_path):
  start_time = time.time()
  A_paths = read_paths(A_path)
  B_paths = read_paths(B_path)
  
  assert len(A_paths) == len(B_paths)
  # read and preprocess the test images
  mkdir(save_path)
  mkdir(os.path.join(save_path, 'A'))
  mkdir(os.path.join(save_path, 'B'))

  mkdir(os.path.join(save_path, 'A_B'))
  mkdir(os.path.join(save_path, 'B_A'))

  mkdir(os.path.join(save_path, 'masks'))

  # save domain images
  # print('---- Saving the images for domains ----'.upper())
  # save_images(A_paths, os.path.join(save_path, 'A'), 'A')
  # save_images(B_paths, os.path.join(save_path, 'B'), 'B')

  # load the model
  print("---- Loading the models ----".upper())
  netG_A2B = Generator_S2F(input_nc).to(device)
  netG_B2A = Generator_F2S(input_nc+1, output_nc).to(device)


  # load latest checkpoint
  latest_step = latest_checkpoint_files(checkpoint_dir, get_step)
  netG_A2B.load_state_dict(torch.load(os.path.join(checkpoint_dir, 'netG_A2B_{}.pth'.format(latest_step))))
  netG_B2A.load_state_dict(torch.load(os.path.join(checkpoint_dir, 'netG_B2A_{}.pth'.format(latest_step))))

  # turn the validation mode
  netG_A2B.eval()
  netG_B2A.eval()

  # input tensors
  Tensor = torch.cuda.FloatTensor if torch.cuda.is_available() else torch.Tensor
  input_A = Tensor(batch_size, input_nc, image_size, image_size, 3)
  input_B = Tensor(batch_size, output_nc, image_size, image_size, 3)

  # input transformations
  img_transforms = transforms.Compose([
                                       transforms.Resize((image_size, image_size), interpolation=Image.BICUBIC),
                                       transforms.ToTensor(),
                                       transforms.Normalize((.5,.5,.5),(.5,.5,.5))
  ])
  to_pil = transforms.ToPILImage()

  # start inference
  print("--- Start inference ----".upper())
  image_queue = QueueMask(queue_size=mask_queue_size)
  for i,(path_A, path_B) in enumerate(zip(A_paths, B_paths)):
    image_A = Image.open(path_A).convert("RGB")
    image_B = Image.open(path_A).convert("RGB")

    im_A = (img_transforms(image_A).unsqueeze(0)).to(device)
    im_B = (img_transforms(image_B).unsqueeze(0)).to(device)

    Image.fromarray(np.array(transforms.Resize((image_size, image_size))(image_A))).save(os.path.join(save_path, 'A', 'A_{}.jpg'.format(i)))
    Image.fromarray(np.array(transforms.Resize((image_size, image_size))(image_B))).save(os.path.join(save_path, 'B', 'B_{}.jpg'.format(i)))

    # generate A -> B
    A_B = netG_A2B(im_A)
    w,h = image_A.size

    current_mask = mask_generator(A_B, im_A)
    image_queue.insert(current_mask)
    A_B = .5 * (A_B + 1)
    A_B = np.array((to_pil(A_B.data.squeeze(0).cpu())))
    Image.fromarray(A_B).save(os.path.join(save_path, 'A_B', 'A_B_{}.jpg'.format(i)))

    # generate B -> A
    mask = image_queue.rand_item()
    B_A = netG_B2A(im_B, mask)
    w,h = image_B.size

    B_A = .5 * (B_A + 1)
    B_A = np.array((to_pil(B_A.data.squeeze(0).cpu())))
    Image.fromarray(B_A).save(os.path.join(save_path, 'B_A', 'B_A_{}.jpg'.format(i)))

    mask_cpu = .5 * (current_mask + 1)
    mask_cpu = np.array((to_pil(mask_cpu.data.squeeze(0).cpu())))
    Image.fromarray(mask_cpu).save(os.path.join(save_path, 'masks', 'mask_{}.jpg'.format(i)))

  print("---- Inference time : {} ----".format(time.time() - start_time).upper())
  


def save_images(image_paths, save_path, domain):
  try:
    os.mkdir(save_path)
  except:
    pass

  for i in range(len(image_paths)):
    img = plt.imread(image_paths[i])
    plt.imsave(os.path.join(save_path,'{}_{}.jpg'.format(domain, i)), img)

In [0]:
A_dir_inf = 'mask_shadow_gan/results/test_paths/shadow_path.pickle'
B_dir_inf = 'mask_shadow_gan/results/test_paths/free_path.pickle'
results_dir = 'mask_shadow_gan/results/msg_200/'

In [16]:
save_test(A_dir_inf, B_dir_inf, results_dir)

---- LOADING THE MODELS ----
--- START INFERENCE ----
---- INFERENCE TIME : 26.246599912643433 ----
