# Library

In [30]:
!pip install pytorch-fid-wrapper
!pip install sewar

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


In [31]:
import tensorflow as tf
import torch
import torchvision
import numpy as np
import cv2
import torch.nn as nn
import functools
import itertools

import os
import pathlib
import time
import datetime
import json
import time

from matplotlib import pyplot as plt
from glob import glob
from pathlib import Path
from google.colab import files

from torch.autograd import Variable
import torch.optim as optim
from torchvision.utils import save_image
from torchsummary import summary
from tqdm.notebook import tqdm

from os import listdir
import pandas as pd
from numpy import asarray
from keras.preprocessing.image import img_to_array
from keras.preprocessing.image import load_img

from sewar.full_ref import mse, rmse, psnr, uqi, ssim, ergas, scc, rase, sam, msssim, vifp
import  pytorch_fid_wrapper as pfw

from skimage import io, transform
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils

from google.colab import drive
from google.colab import files
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


# Dataset

In [32]:
def random_crop(image, dim):
  """Cut the same randomly placed window out of every image of a stack.

  :params:
      image (np.ndarray): Stack indexed as (stack, height, width, ...). Only
          axes 1 and 2 are cropped, every other axis is kept whole.
      dim (sequence): Size of the window as (crop_height, crop_width, ...).
          Entries after the first two are ignored.
  :returns: the cropped stack, shape (stack, crop_height, crop_width, ...).
  :raises ValueError: if the window is larger than the image.
  """
  crop_height, crop_width = int(dim[0]), int(dim[1])
  # The range of valid offsets depends on how much larger the image is than
  # the window. Deriving it from `dim` alone always gave offset 0 when the
  # caller passed the crop size, i.e. no jitter at all.
  max_top = image.shape[1] - crop_height
  max_left = image.shape[2] - crop_width
  if max_top < 0 or max_left < 0:
    raise ValueError(
        f"crop {crop_height}x{crop_width} does not fit in image "
        f"{image.shape[1]}x{image.shape[2]}")
  # randint's upper bound is exclusive, hence the +1 so the last offset is reachable
  top = np.random.randint(0, max_top + 1)
  left = np.random.randint(0, max_left + 1)
  return image[:, top:top + crop_height, left:left + crop_width]

def random_jittering_mirroring(input_image, target_image, height=286, width=286):
  """Apply one shared random jitter and mirroring to an input/target pair.

  Both images are resized to (height, width), stacked, cropped with a single
  `random_crop` call (so they receive the same window) and, with probability
  0.5, flipped together along axis 1.

  :params:
      input_image, target_image: images accepted by cv2.resize
          (assumed to be (H, W, 3) arrays -- TODO confirm against callers).
      height (int), width (int): size of the intermediate resized image.
  :returns: tuple (input_image, target_image) after augmentation.
  """
  # cv2.resize expects dsize as (width, height); passing (height, width)
  # only worked because both defaults are equal.
  dsize = (width, height)
  input_image = cv2.resize(input_image, dsize, interpolation=cv2.INTER_NEAREST)
  target_image = cv2.resize(target_image, dsize, interpolation=cv2.INTER_NEAREST)

  # cropping (random jittering) to 256x256; stacking first guarantees both
  # images are cut at the same position
  IMG_HEIGHT, IMG_WIDTH = 256, 256
  stacked_image = np.stack([input_image, target_image], axis=0)
  cropped_image = random_crop(stacked_image, dim=[IMG_HEIGHT, IMG_WIDTH, 3])
  input_image, target_image = cropped_image[0], cropped_image[1]

  # random mirroring: one draw decides for both images
  if torch.rand(()) > 0.5:
    input_image = np.fliplr(input_image)
    target_image = np.fliplr(target_image)
  return input_image, target_image

def normalize(inp):
  """Rescale linearly so that 0 maps to -1, 127.5 to 0 and 255 to 1."""
  half_range = 127.5
  scaled = inp / half_range
  return scaled - 1

def preprocess(inp):
    """Normalize a 3-axis array and return it as a tensor with the last axis first.

    The array is passed through `normalize`, copied, and its axes are permuted
    from (i, j, k) to (k, i, j) before being wrapped with `torch.from_numpy`.
    """
    normalized = normalize(inp)
    # Work on a copy so the tensor does not share memory with the caller's array.
    channels_first = np.transpose(normalized.copy(), (2, 0, 1))
    return torch.from_numpy(channels_first)

# load all images in a directory into memory
def load_images(path, size=(512,512)):
  """Load `body`, `apparel_head_foot` and `line` images for every matched folder.

  :params:
      path (str): glob pattern; each match is treated as a folder holding
          `body.jpg`, `apparel_head_foot.jpg` and `line.jpg`.
      size (tuple): accepted but currently unused -- images are loaded with
          `target_size=None`, i.e. without resizing.
  :returns: dict mapping each of the three names to `asarray(...)` of the
      list of preprocessed tensors, in glob order.
  """
  image_names = ("body", "apparel_head_foot", "line")
  collected = {name: [] for name in image_names}
  # every match of the pattern is assumed to be a sample folder
  for sample_dir in tqdm(glob(path, recursive = True)):
    for name in collected.keys():
      raw_image = load_img(sample_dir+f"/{name}.jpg", target_size=None)
      # PIL image -> numpy array -> normalized channel-first tensor
      collected[name].append(preprocess(img_to_array(raw_image)))
  return {name: asarray(tensors) for name, tensors in collected.items()}

In [33]:
class Custom_DataLoader():
  """
  Hand-rolled loader for per-sample image folders.

  Every folder matched by the glob pattern is expected to contain one
  `<name>.jpg` per entry of `input_name_list` plus `body.jpg` (the target).
  Images are rescaled with `x / 127.5 - 1` and returned channel-first.
  """

  def __init__(self, dataset_folder_path, input_name_list = ["line","apparel_head_foot"], img_size=(256, 256),training=True):
    """
    Constructs a DataLoader object

    :params:
        dataset_folder_path (str): Glob pattern matching the sample folders
            to be loaded.
        input_name_list (list): Base names of the images concatenated
            (channel-wise, in this order) to form the network input.
        img_size (tuple): Passed to `load_img` as `target_size`.
        training (bool): True selects the first 80% of the folders,
            False the remaining 20%.
    """
    self.dataset_folder_path = dataset_folder_path
    self.img_size = img_size
    # copy so the shared default list can never be mutated through an instance
    self.input_name_list = list(input_name_list)
    self.target_name_list = ["body"]
    self.train_test_split_percentage = 0.8
    # glob returns matches in arbitrary order; sort so the train/test split
    # is the same on every run and for every instance
    all_folders = sorted(glob(self.dataset_folder_path, recursive = True))
    split_index = int(len(all_folders)*self.train_test_split_percentage)
    if training:
      self.image_folder = all_folders[:split_index]
    else:
      self.image_folder = all_folders[split_index:]
    # set seed (global NumPy RNG, used by load_data)
    np.random.seed(0)

  def _load_image_tensor(self, img_folder, name):
    """
    Load `<img_folder>/<name>.jpg` as a normalized channel-first tensor
    Input : str, str
    Output : Tensor
    """
    # load and resize the image
    img = load_img(img_folder+f"/{name}.jpg", target_size=self.img_size)
    # convert to numpy array
    img = img_to_array(img)
    # normalize
    img = (img / 127.5) - 1
    # change axis position: (i, j, k) -> (k, i, j)
    return torch.from_numpy(np.einsum('ijk->kij', img.copy()))

  def _preprocess(self,image_folder_list):
    """
    Preprocess data
    Input : List of sample folders
    Output : Tensor, Tensor (stacked inputs, stacked targets)
    """
    if not self.input_name_list:
      raise ValueError("input_name_list must contain at least one image name")
    # one list of tensors per image name
    images_dict = {name:[] for name in self.input_name_list + self.target_name_list}
    for img_folder in image_folder_list:
      for name in images_dict.keys():
        images_dict[name].append(self._load_image_tensor(img_folder, name))
    # Tensors stay tensors all the way: going through numpy.asarray first made
    # torch.cat / torch.stack depend on how NumPy converts a list of tensors.
    per_sample_inputs = zip(*(images_dict[name] for name in self.input_name_list))
    input_data = torch.stack([torch.cat(parts, dim=0) for parts in per_sample_inputs])
    target_data = torch.stack(images_dict[self.target_name_list[0]])
    return input_data, target_data

  def load_data(self, batch_size=1):
    """
    Loads a random batch
    Output : Tensor, Tensor
    """
    # randomly sample batch_size folders (np.random.choice default: with replacement)
    batch_image_folder = np.random.choice(self.image_folder, size=batch_size)
    #preprocess
    input_data, target_data = self._preprocess(batch_image_folder)

    return input_data, target_data

  def load_batch(self, batch_size=20):
      """
      A batch load generator yielding every full batch once, in folder order.
      A trailing partial batch is dropped.
      Output : Tensor, Tensor
      """
      #get nbr of full batches
      nbr_batches = len(self.image_folder) // batch_size
      # load batch (iterating to nbr_batches-1 silently skipped the last full batch)
      for i in range(nbr_batches):
        batch_image_folder = self.image_folder[i*batch_size:(i+1)*batch_size]
        input_data, target_data = self._preprocess(batch_image_folder)
        yield input_data, target_data

  def load_target_data(self):
    """
    Load the target image of every folder of this split
    Output : result of numpy.asarray over the list of per-image tensors
    """
    target_name = self.target_name_list[0]
    target_images = [self._load_image_tensor(img_folder, target_name)
                     for img_folder in tqdm(self.image_folder)]
    # kept as asarray(list of tensors) so existing callers see the same container
    target_data = asarray(target_images)
    return target_data

In [34]:
# Create a pytorch Dataset

class DeepFashionDataset(Dataset):
    """Deepfashion dataset: one sample per folder matched by a glob pattern.

    Every matched folder is expected to contain one `<name>.jpg` per entry of
    `input_name_list` plus `body.jpg` (the target). Images are rescaled with
    `x / 127.5 - 1` and returned channel-first.
    """

    def __init__(self, dataset_folder_path, input_name_list, img_size, training=True):
      """
      :params:
          dataset_folder_path (str): Glob pattern matching the sample folders
              to be loaded.
          input_name_list (list): Base names of the images concatenated
              (channel-wise, in this order) to form the input. May be empty,
              in which case only the target is returned by `__getitem__`.
          img_size (tuple or None): Passed to `load_img` as `target_size`;
              None keeps the size stored on disk.
          training (bool): True selects the first 80% of the folders,
              False the remaining 20%.
      """
      self.dataset_folder_path = dataset_folder_path
      self.img_size = img_size
      self.input_name_list = list(input_name_list)
      self.target_name_list = ["body"]
      self.train_test_split_percentage = 0.8
      # glob returns matches in arbitrary order; sort so the train and test
      # instances always partition the folders the same way
      all_folders = sorted(glob(self.dataset_folder_path, recursive = True))
      split_index = int(len(all_folders)*self.train_test_split_percentage)
      if training:
        self.image_folder = all_folders[:split_index]
      else:
        self.image_folder = all_folders[split_index:]
      # set seed (global NumPy RNG)
      np.random.seed(0)

    def __len__(self):
        """Number of sample folders in this split."""
        return len(self.image_folder)

    def __getitem__(self, idx):
        """
        Load one sample.
        Output : (input Tensor, target Tensor), or the target Tensor alone
                 when `input_name_list` is empty.
        """
        if torch.is_tensor(idx):
            idx = idx.tolist()

        images_dict = {}
        #image folder
        img_folder_path = self.image_folder[idx]

        for name in self.input_name_list + self.target_name_list:
          # load and resize the image; img_size was previously stored but
          # never applied here, unlike in Custom_DataLoader
          img = load_img(img_folder_path+f"/{name}.jpg", target_size=self.img_size)
          # convert to numpy array
          img = img_to_array(img)
          # normalize
          img = (img / 127.5) - 1
          # change axis position: (i, j, k) -> (k, i, j)
          images_dict[name] = torch.from_numpy(np.einsum('ijk->kij', img))

        target_data = images_dict[self.target_name_list[0]]
        # no input requested: target only
        if not self.input_name_list:
          return target_data
        # concatenating a single tensor returns an equal tensor, so one code
        # path serves both the single-input and the multi-input case
        input_data = torch.cat([images_dict[name] for name in self.input_name_list], dim=0)
        return input_data, target_data


# Discriminator and Generator

In [35]:
# custom weights initialization called on generator and discriminator
def init_weights(net, init_type='normal', scaling=0.02):
  """Re-initialise convolution and BatchNorm2d parameters of `net` in place.

  Modules whose class name contains 'Conv' (and that have a `weight`) get
  weights drawn from N(0, scaling). Modules whose class name contains
  'BatchNorm2d' get weights drawn from N(1, scaling) and a zero bias.
  `init_type` is accepted but not used. Returns None.
  """
  def init_func(module):
    """Initialise a single module according to its class name."""
    name = type(module).__name__
    if hasattr(module, 'weight') and 'Conv' in name:
      torch.nn.init.normal_(module.weight.data, 0.0, scaling)
      return
    if 'BatchNorm2d' in name:
      # BatchNorm weight is a per-channel scale, so it is centred on 1
      torch.nn.init.normal_(module.weight.data, 1.0, scaling)
      torch.nn.init.constant_(module.bias.data, 0.0)

  # nn.Module.apply visits every submodule, then the module itself
  net.apply(init_func)

In [36]:
def get_norm_layer():
    """Return the normalization-layer factory used by the networks.

    The factory is `nn.BatchNorm2d` pre-bound with learnable affine
    parameters and running-statistics tracking; call it with the number of
    channels to obtain a layer.
    """
    return functools.partial(nn.BatchNorm2d, affine=True, track_running_stats=True)

In [37]:
class UnetSkipConnectionBlock(nn.Module):
    """One level of the U-Net: downsample, run the nested submodule, upsample.

        X -------------------identity----------------------
        |-- downsampling -- |submodule| -- upsampling --|

    Every block except the outermost one concatenates its input with its
    output along the channel axis (the skip connection).
    """

    def __init__(self, outer_nc, inner_nc, input_nc=None,
                 submodule=None, outermost=False, innermost=False, norm_layer=nn.BatchNorm2d, use_dropout=False):
        """Build the layers of one U-Net level.

        Parameters:
            outer_nc (int) -- number of channels produced by the upsampling conv
            inner_nc (int) -- number of channels produced by the downsampling conv
            input_nc (int) -- number of channels of the block input; defaults to outer_nc
            submodule (UnetSkipConnectionBlock) -- block nested between down and up path
            outermost (bool)    -- build the outermost block (no norm, Tanh output, no skip)
            innermost (bool)    -- build the innermost block (no submodule)
            norm_layer          -- normalization layer class or functools.partial of one
            use_dropout (bool)  -- append Dropout(0.5); only honoured for intermediate blocks
        """
        super().__init__()
        self.outermost = outermost

        # Convolutions followed by InstanceNorm keep their bias; otherwise it is dropped.
        norm_cls = norm_layer.func if type(norm_layer) == functools.partial else norm_layer
        use_bias = norm_cls == nn.InstanceNorm2d
        in_channels = outer_nc if input_nc is None else input_nc

        downconv = nn.Conv2d(in_channels, inner_nc, kernel_size=4,
                             stride=2, padding=1, bias=use_bias)
        downrelu = nn.LeakyReLU(0.2, True)
        downnorm = norm_layer(inner_nc)
        uprelu = nn.ReLU(True)
        upnorm = norm_layer(outer_nc)
        up_kwargs = dict(kernel_size=4, stride=2, padding=1)

        if outermost:
            # input of the up path is the submodule output: inner_nc skip + inner_nc computed
            upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc, **up_kwargs)
            layers = [downconv, submodule, uprelu, upconv, nn.Tanh()]
        elif innermost:
            # no submodule, hence no concatenation before the up path
            upconv = nn.ConvTranspose2d(inner_nc, outer_nc, bias=use_bias, **up_kwargs)
            layers = [downrelu, downconv, uprelu, upconv, upnorm]
        else:
            upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc, bias=use_bias, **up_kwargs)
            layers = [downrelu, downconv, downnorm, submodule, uprelu, upconv, upnorm]
            if use_dropout:
                layers.append(nn.Dropout(0.5))

        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Run the block; non-outermost blocks return cat([x, block(x)]) on axis 1."""
        result = self.model(x)
        if self.outermost:
            return result
        # add skip connection
        return torch.cat([x, result], 1)

In [38]:
class UnetGenerator(nn.Module):
  """U-Net generator assembled from nested UnetSkipConnectionBlock levels.

  The network is built from the inside out: one innermost block, three
  intermediate blocks at nf * 8 filters, three blocks narrowing from nf * 8
  to nf, and the outermost block mapping input_nc to output_nc channels.
  """

  def __init__(self, input_nc, output_nc, nf=64, norm_layer=nn.BatchNorm2d, use_dropout=False):
    """
    Parameters:
        input_nc (int)     -- number of channels of the input image
        output_nc (int)    -- number of channels of the generated image
        nf (int)           -- number of filters of the outermost level
        norm_layer         -- normalization layer forwarded to every block
        use_dropout (bool) -- forwarded to the three nf * 8 intermediate blocks only
    """
    super().__init__()
    widest = nf * 8

    # innermost block
    block = UnetSkipConnectionBlock(widest, widest, input_nc=None, submodule=None, norm_layer=norm_layer, innermost=True)

    # three intermediate blocks at the widest setting
    for _ in range(3):
      block = UnetSkipConnectionBlock(widest, widest, input_nc=None, submodule=block, norm_layer=norm_layer, use_dropout=use_dropout)

    # gradually reduce the number of filters from nf * 8 to nf
    for outer_filters, inner_filters in ((nf * 4, nf * 8), (nf * 2, nf * 4), (nf, nf * 2)):
      block = UnetSkipConnectionBlock(outer_filters, inner_filters, input_nc=None, submodule=block, norm_layer=norm_layer)

    # outermost block
    self.model = UnetSkipConnectionBlock(output_nc, nf, input_nc=input_nc, submodule=block, outermost=True, norm_layer=norm_layer)

  def forward(self, input):
    """Standard forward"""
    return self.model(input)

In [39]:
class Discriminator(nn.Module):
  """Convolutional discriminator producing a one-channel map of sigmoid scores.

  Layout: one stride-2 conv, (n_layers - 1) stride-2 conv/norm blocks whose
  width doubles up to ndf * 8, one stride-1 conv/norm block, and a final
  stride-1 conv to a single channel followed by Sigmoid.
  """

  def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d):
    """
    Parameters:
        input_nc (int) -- number of channels of the discriminator input
        ndf (int)      -- number of filters of the first conv layer
        n_layers (int) -- number of conv stages before the output conv
        norm_layer     -- normalization layer factory
    """
    super().__init__()
    conv_kwargs = dict(kernel_size=4, padding=1)

    layers = [nn.Conv2d(input_nc, ndf, stride=2, **conv_kwargs), nn.LeakyReLU(0.2, True)]

    in_mult = 1
    # gradually increase the number of filters, capped at 8 * ndf
    for depth in range(1, n_layers):
      out_mult = min(2 ** depth, 8)
      layers.extend([
          nn.Conv2d(ndf * in_mult, ndf * out_mult, stride=2, bias=False, **conv_kwargs),
          norm_layer(ndf * out_mult),
          nn.LeakyReLU(0.2, True),
      ])
      in_mult = out_mult

    # last widening stage keeps the spatial stride at 1
    out_mult = min(2 ** n_layers, 8)
    layers.extend([
        nn.Conv2d(ndf * in_mult, ndf * out_mult, stride=1, bias=False, **conv_kwargs),
        norm_layer(ndf * out_mult),
        nn.LeakyReLU(0.2, True),
    ])

    # output 1 channel prediction map
    layers.extend([nn.Conv2d(ndf * out_mult, 1, stride=1, **conv_kwargs), nn.Sigmoid()])
    self.model = nn.Sequential(*layers)

  def forward(self, input):
    """Standard forward."""
    return self.model(input)

In [40]:
def generator_loss(generated_image, target_img, D_fake_gen, real_target, adv_loss, l1_loss, lb_rec=0.999, lb_adv=0.001):
  """Weighted sum of the adversarial (BCE) and reconstruction (L1) generator losses.

  :params:
      generated_image, target_img: tensors compared with the mean L1 loss.
      D_fake_gen, real_target: discriminator output for the generated image
          and the label tensor it is compared to with mean BCE.
      adv_loss (bool): include the adversarial term.
      l1_loss (bool): include the L1 term.
      lb_rec (float), lb_adv (float): weights of the L1 / adversarial term.
  :returns: lb_adv * BCE + lb_rec * L1; a disabled term contributes 0, so the
      result is the plain number 0 when both are disabled.
  """
  adversarial_term = 0
  reconstruction_term = 0
  #BCE
  if adv_loss:
    adversarial_term = nn.functional.binary_cross_entropy(D_fake_gen, real_target)
  # L1
  if l1_loss:
    reconstruction_term = nn.functional.l1_loss(generated_image, target_img)
  # total loss
  return (lb_adv * adversarial_term) + (lb_rec * reconstruction_term)

In [41]:
def discriminator_loss(output, label):
  """Mean binary cross-entropy between the discriminator output and its labels."""
  return nn.functional.binary_cross_entropy(output, label)

## Check architecture

In [42]:
# device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# input_nc = 6
# generator_input_nc = input_nc
# discriminator_input_nc = 3 + input_nc
# generator = UnetGenerator(input_nc=generator_input_nc, output_nc=3, nf=64, norm_layer=get_norm_layer(), use_dropout=False).cuda().float()
# init_weights(generator, 'normal', scaling=0.02)
# discriminator = Discriminator(discriminator_input_nc, ndf=64, n_layers=4, norm_layer=get_norm_layer()).cuda().float()
# init_weights(discriminator, 'normal', scaling=0.02)
# generator = generator.to('cuda')
# print(summary(generator,(generator_input_nc,512,512)))
# discriminator = discriminator.to('cuda')
# summary(discriminator,(discriminator_input_nc,512,512))

# Evaluation metrics

## Similarity

In [43]:
def tensor_to_np(tensor_images,device):
  """
  Transform a tensor of multiple images (N,C,H,W) into a list of numpy arrays
  with the channel in the last dimension (H,W,C).

  :params:
      tensor_images (Tensor): batch of images, one image per entry of axis 0.
      device: kept for backward compatibility; no longer needed because the
          tensor is always detached and moved to the CPU before conversion.
  :returns: list with one numpy array per image.
  """
  # detach() and cpu() are no-ops for CPU tensors without grad, so a single
  # path covers every case, including CPU tensors that require grad (for
  # which a bare .numpy() raises).
  images_np = tensor_images.detach().cpu().numpy()
  # (C,H,W) -> (H,W,C)
  return [np.transpose(image, (1, 2, 0)) for image in images_np]

In [44]:
def compute_eval_metrics(list_img_1,list_img_2):
  """Average MSE, SSIM and PSNR over pairs of images.

  Input : two equally ordered lists of images with size (x,y,color); pairs are
          formed with zip, so extra images of the longer list are ignored.
  Output : dict with keys "MSE", "SSIM" and "PSNR", each the mean over all pairs.

  Score ranges:
    MSE  : [0,+inf[ with 0 being the best score
    SSIM : [0,1] with 1 being the best score
    PSNR : [0,+inf] with 20-40 considered to be a good score
  MAX is the maximum value possible for a pixel; it is passed as 1 here
  (NOTE(review): confirm this matches the value range of the images passed in).
  """
  mse_scores, ssim_scores, psnr_scores = [], [], []
  for first, second in zip(list_img_1,list_img_2):
    mse_scores.append(mse(first,second))
    # ssim returns a tuple; only its first entry is kept
    ssim_scores.append(ssim(first,second,MAX=1)[0])
    psnr_scores.append(psnr(first,second,MAX=1))
  return {
      "MSE": np.mean(mse_scores),
      "SSIM": np.mean(ssim_scores),
      "PSNR": np.mean(psnr_scores),
  }

In [45]:
#imageA_np = np.einsum('ijk->jki', np.array(target_data[0]))
#imageB_np = np.einsum('ijk->jki', np.array(target_data[1]))
#compute_eval_metrics([imageA_np],[imageB_np])

## FID

In [46]:
# Optional: set pfw's configuration with your parameters once and for all
# dims is the dimension of the activation layer of inceptionv3
#pfw.set_config(batch_size=8, dims=2048, device="cuda:0")

# Optional: compute real_m (mean) and real_s (covariance) of the real images only once; they do not change during training.
# pfw.get_stats expects a tensor of size N*C*H*W, where N is the number of images.

# calculate fid for real data 
#tar_images_tensor = torch.stack([t for t in tar_images])
#real_m, real_s = pfw.get_stats(tar_images_tensor)

# Training and testing

In [47]:
def train_gen_discr(input_img,target_img,generator,discriminator,G_optimizer,D_optimizer,add_discriminator_input_data,mask,device,adv_loss,l1_loss):
  """
  Run one optimisation step of the discriminator, then one of the generator.

  :params:
      input_img, target_img (Tensor): batch of inputs and of target images.
      generator, discriminator (nn.Module): the two networks.
      G_optimizer, D_optimizer: their optimizers.
      add_discriminator_input_data (bool): if True the discriminator sees the
          input concatenated (channel axis) with the real/generated image.
      mask (bool): if True the generated image is post-processed with
          modify_generated_image_with_skin_segementation.
      device: device the batches are moved to.
      adv_loss, l1_loss (bool): forwarded to generator_loss.
  Return : Discriminator loss and Generator loss
  """
  # switch to train mode
  generator.train()
  discriminator.train()

  # input & target
  D_optimizer.zero_grad()
  input_img = input_img.to(device)
  target_img = target_img.to(device)

  def discriminator_input(image):
    """Prepend the conditioning input when the discriminator is conditional."""
    if add_discriminator_input_data:
      return torch.cat((input_img, image), 1)
    return image

  # generator forward pass
  generated_image = generator(input_img)

  # overwrite generated image
  if mask:
    generated_image = modify_generated_image_with_skin_segementation(input_img,generated_image)

  # x is the input , y is the target image
  ############################
  # (1) Update D network: maximise log(D(y,x)) + log(1 - D(G(x),x))
  ###########################

  # discriminator on fake/generated images; detach so this step does not
  # back-propagate into the generator
  D_fake = discriminator(discriminator_input(generated_image).detach())

  # Ground truth labels are shaped after the discriminator output instead of
  # a fixed (N, 1, 30, 30), which only matched one image size / depth.
  real_target = torch.ones_like(D_fake)
  fake_target = torch.zeros_like(D_fake)

  D_fake_loss = discriminator_loss(D_fake, fake_target)

  # discriminator on real images
  D_real = discriminator(discriminator_input(target_img))
  D_real_loss = discriminator_loss(D_real,  real_target)

  # average discriminator loss
  D_loss = (D_real_loss + D_fake_loss) / 2

  # compute gradients and run optimizer step
  D_loss.backward()
  D_optimizer.step()

  ############################
  # (2) Update G network: maximise log(D(G(x),x)) + L1
  ###########################

  # Train generator with real labels, using the freshly updated discriminator
  G_optimizer.zero_grad()
  D_fake_gen = discriminator(discriminator_input(generated_image))

  G_loss = generator_loss(generated_image, target_img, D_fake_gen, real_target,adv_loss,l1_loss)
  # compute gradients and run optimizer step
  G_loss.backward()
  G_optimizer.step()

  return D_loss, G_loss

def test_gen_discr(input_img,target_img,generator,discriminator,add_discriminator_input_data,mask,device,adv_loss,l1_loss):
  """
  Test the generator and the discriminator given the inputs and the target image 
  Return : Discrimator loss, Generator loss and generated images
  """

  # switch to eval mode
  generator.eval()
  discriminator.eval()

  with torch.no_grad():
    input_img = input_img.to(device)
    target_img = target_img.to(device)

    # ground truth labels real and fake
    real_target = Variable(torch.ones(input_img.size(0), 1, 30, 30).to(device))
    fake_target = Variable(torch.zeros(input_img.size(0), 1, 30, 30).to(device))

    # generator forward pass
    generated_image = generator(input_img)

    if mask:
      generated_image = modify_generated_image_with_skin_segementation(input_img,generated_image)
    
    # train discriminator with fake/generated images
    if add_discriminator_input_data:
      disc_inp_fake = torch.cat((input_img, generated_image), 1)
    else:
      disc_inp_fake = generated_image
    D_fake = discriminator(disc_inp_fake.detach())
    D_fake_loss = discriminator_loss(D_fake, fake_target)
    
    # train discriminator with real images
    if add_discriminator_input_data:
      disc_inp_real = torch.cat((input_img, target_img), 1)        
    else:
      disc_inp_real = target_img              
    D_real = discriminator(disc_inp_real)
    D_real_loss = discriminator_loss(D_real,  real_target)

    # average discriminator loss
    D_loss = (D_real_loss + D_fake_loss) / 2

    # Train generator with real labels
    if add_discriminator_input_data:
      fake_gen = torch.cat((input_img, generated_image), 1)
    else:
      fake_gen = generated_image
    G = discriminator(fake_gen)
    G_loss = generator_loss(generated_image, target_img, G, real_target,adv_loss,l1_loss)  
                                 
    return D_loss, G_loss, generated_image

def modify_generated_image_with_skin_segementation(input_img,generated_image):
  """
  Blend the generated image with the input using a mask taken from the input.

  Channels 0:3 of `input_img` are used as the image to keep outside the mask
  and channels 6:9 as the mask itself, so the input must carry at least 9
  channels (assumes the third input image is the skin segmentation and the
  first one the apparel/head/foot image -- confirm against input_name_list).
  Both tensors are mapped from [-1,1] to [0,1], blended as
  generated*mask + (1-mask)*input, and mapped back to [-1,1].
  Return : Tensor
  """

  def to_unit_range(tensor):
    """[-1,1] -> [0,1]"""
    return (tensor + 1) / 2

  inputs_unit = to_unit_range(input_img)
  generated_unit = to_unit_range(generated_image)

  kept_outside_mask = inputs_unit[:, 0:3, :, :]
  skin_mask = inputs_unit[:, 6:9, :, :]

  blended = generated_unit*skin_mask+(1-skin_mask)*kept_outside_mask

  # back to [-1,1]
  return ( blended*2 ) - 1

def generate_sample(input_tensor,target_tensor,generator,device,epoch,suffix,mask,path):
  """
  Run the generator on a batch and save three image grids under {path}/{suffix}/:
  channels 0:3 of the input, the targets, and the generated output for `epoch`.
  When `mask` is set, the output is first passed through
  modify_generated_image_with_skin_segementation. Returns None.
  """
  # switch to eval mode
  generator.eval()
  with torch.no_grad():
    inputs = input_tensor.to(device)
    targets = target_tensor.to(device)
    generated_output = generator(inputs)
    if mask :
      generated_output = modify_generated_image_with_skin_segementation(inputs,generated_output)
    # only the first three channels of the (possibly multi-image) input are saved
    apparel_head_foot = input_tensor[:,0:3,:,:]
    grids = (
        (apparel_head_foot, f'{path}/{suffix}/inputs_{suffix}' + '.png'),
        (targets, f'{path}/{suffix}/targets_{suffix}' + '.png'),
        (generated_output.data, f'{path}/{suffix}/{suffix}_{epoch}'+ '.png'),
    )
    for images, file_name in grids:
      save_image(images, file_name, nrow=5, normalize=True)

# Pipeline

In [48]:
def main(config_file):
  """
  Run one full experiment described by `config_file`.

  Builds the train/test data loaders, computes the FID reference statistics, creates the
  generator and discriminator, trains for `num_epochs` epochs and, every `eval_every` epochs,
  evaluates on the test loader (losses, FID, MSE/SSIM/PSNR) and saves sample image grids.
  At the end, `train_metrics.csv`, `test_metrics.csv` and `config.json` are written in a
  results folder whose name encodes the main hyper-parameters.

  config_file : dict with the keys "num_epochs", "input_name_list", "dataset_folder_path",
                "img_size", "train_batch_size", "test_batch_size", "add_discriminator_input_data",
                "adv_loss", "l1_loss", "learning_rate_generator", "learning_rate_discriminator",
                "betas", "mask", "path_results" and "num_workers".
                Optional key "eval_every" (default 20): period, in epochs, of testing and sampling.
  Return : None
  """
  # set variables from config file
  num_epochs = config_file["num_epochs"]
  input_name_list = config_file["input_name_list"]
  dataset_folder_path = config_file["dataset_folder_path"]
  img_size = config_file["img_size"]
  train_batch_size = config_file["train_batch_size"]
  test_batch_size = config_file["test_batch_size"]
  add_discriminator_input_data = config_file["add_discriminator_input_data"]
  adv_loss = config_file["adv_loss"]
  l1_loss = config_file["l1_loss"]
  learning_rate_generator = config_file["learning_rate_generator"]
  learning_rate_discriminator = config_file["learning_rate_discriminator"]
  betas = config_file["betas"]
  mask = config_file["mask"]
  path_results = config_file["path_results"] + f"/{len(input_name_list)}_inputs_{add_discriminator_input_data}_add_D_inputs_{adv_loss}_adv_loss_{l1_loss}_l1_loss_{mask}_mask_{learning_rate_generator}_lrg_{learning_rate_discriminator}_lrd"
  num_workers = config_file["num_workers"]
  # period (in epochs) of testing / sample generation; 20 was previously hard-coded
  eval_every = config_file.get("eval_every", 20)

  device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
  # a torch.device does not compare equal to a plain string, so check the device type instead
  pin_memory = device.type == 'cuda'

  ##### Load train and test dataloader #####

  train_dataset = DeepFashionDataset( dataset_folder_path=dataset_folder_path,
                                      input_name_list=input_name_list,
                                      img_size=img_size,
                                      training=True)

  # NOTE(review): the training loader is not shuffled; kept as-is so the fixed training
  # sample below stays the same between runs - confirm this is intentional
  train_dataloader = DataLoader(train_dataset, batch_size=train_batch_size,
                          shuffle=False, num_workers=num_workers, pin_memory=pin_memory)

  test_dataset = DeepFashionDataset(dataset_folder_path=dataset_folder_path,
                                      input_name_list=input_name_list,
                                      img_size=img_size,
                                      training=False)

  test_dataloader = DataLoader(test_dataset, batch_size=test_batch_size,
                        shuffle=False, num_workers=num_workers, pin_memory=pin_memory, drop_last=True)

  # fixed batches used to render the sample grids
  # training sample
  training_input_tensor,training_target_tensor = next(iter(train_dataloader))
  #testing sample
  testing_input_tensor,testing_target_tensor = next(iter(test_dataloader))

  ###### Calculate FID ######
  print("Calculating FID")
  # FID real data
  # get all training target data (empty input list: each item is stacked directly as an image tensor);
  # the dataset location and size come from the config instead of a duplicated literal
  target_data_loader = DeepFashionDataset( dataset_folder_path=dataset_folder_path,
                                           input_name_list=[],
                                           img_size=img_size,
                                           training=True)
  tar_images_tensor = torch.stack([target_data_loader[i] for i in range(len(target_data_loader))])
  # get real FID
  pfw.set_config(batch_size=8, dims=2048, device=device)
  real_m, real_s = pfw.get_stats(tar_images_tensor)
  print("Finished Calculating FID")

  # set discriminator and generator
  input_nc = training_input_tensor[0].shape[0]
  generator_input_nc = input_nc
  if add_discriminator_input_data:
    # the discriminator sees the generator inputs concatenated with the 3-channel image
    discriminator_input_nc = 3 + input_nc
  else :
    discriminator_input_nc = 3
  # move the networks to the selected device (works on CPU too, unlike a bare .cuda())
  generator = UnetGenerator(input_nc=generator_input_nc, output_nc=3, nf=64, norm_layer=get_norm_layer(), use_dropout=False).to(device).float()
  init_weights(generator, 'normal', scaling=0.02)
  discriminator = Discriminator(discriminator_input_nc, ndf=64, n_layers=4, norm_layer=get_norm_layer()).to(device).float()
  init_weights(discriminator, 'normal', scaling=0.02)

  # set metrics
  D_loss_plot, G_loss_plot= [], []
  D_loss_test_plot, G_loss_test_plot, FID_test_plot = [], [], []
  score_plot = {"MSE":[],"SSIM":[],"PSNR":[]}
  # set optimizer
  G_optimizer = optim.Adam(generator.parameters(), lr = learning_rate_generator, betas=betas)
  D_optimizer = optim.Adam(discriminator.parameters(), lr = learning_rate_discriminator, betas=betas)

  # create repo if not exist
  Path(path_results+"/train").mkdir(parents=True, exist_ok=True)
  Path(path_results+"/test").mkdir(parents=True, exist_ok=True)

  # train and test
  for epoch in tqdm(range(1, num_epochs+1)):
    ############### TRAIN  #################
    D_loss_list, G_loss_list = [], []

    for (input_img, target_img) in train_dataloader:
      D_loss, G_loss = train_gen_discr(input_img,target_img,generator,discriminator,G_optimizer,D_optimizer,add_discriminator_input_data,mask,device,adv_loss,l1_loss)
      D_loss_list.append(D_loss)
      G_loss_list.append(G_loss)

    # save D&G loss in a list
    D_loss_plot.append(torch.mean(torch.FloatTensor(D_loss_list)))
    G_loss_plot.append(torch.mean(torch.FloatTensor(G_loss_list)))

    print('Epoch: [%d/%d] Training: D_loss: %.3f, G_loss: %.5f'%(epoch, num_epochs, D_loss_plot[-1],G_loss_plot[-1]))

    # sampling and testing only happen every `eval_every` epochs
    if epoch % eval_every == 0:
      # generate sample for training data
      generate_sample(training_input_tensor,training_target_tensor,generator,device,epoch,"train",mask,path_results)

      ############ TESTING ###########
      D_loss_test_list, G_loss_test_list, FID_test_list = [], [], []
      score_list = {"MSE":[],"SSIM":[],"PSNR":[]}

      for (input_img, target_img) in test_dataloader:
        # test
        D_loss, G_loss, generated_image = test_gen_discr(input_img,target_img,generator,discriminator,add_discriminator_input_data,mask,device,adv_loss,l1_loss)
        D_loss_test_list.append(D_loss)
        G_loss_test_list.append(G_loss)

        # calculate fid score; a failing batch is reported and skipped (best effort),
        # but KeyboardInterrupt / SystemExit are no longer swallowed
        try:
          fid = pfw.fid(generated_image, real_m=real_m, real_s=real_s)
          FID_test_list.append(fid)
        except Exception as err:
          print(f"fid calculation pb: {err!r}")

        # calculate evaluation metrics
        generated_image_np, target_np = tensor_to_np(generated_image,device),tensor_to_np(target_img,device)
        eval_score_dict = compute_eval_metrics(generated_image_np,target_np)
        for key in eval_score_dict:
          score_list[key].append(eval_score_dict[key])

      # store metrics in list
      D_loss_test_plot.append(torch.mean(torch.FloatTensor(D_loss_test_list)))
      G_loss_test_plot.append(torch.mean(torch.FloatTensor(G_loss_test_list)))
      FID_test_plot.append(torch.mean(torch.FloatTensor(FID_test_list)))
      for key in score_list:
        score_plot[key].append(np.mean(score_list[key]))
      print("------Testing-------")
      print('Epoch: [%d/%d] Testing: D_loss: %.3f, G_loss: %.5f, FID_score %.3f'%(epoch, num_epochs, D_loss_test_plot[-1],G_loss_test_plot[-1],FID_test_plot[-1]))
      print(f'Testing evaluation metrics: {[(item[0],item[1][-1])for item in score_plot.items()]}')
      print("--------------------")

      #testing sample
      generate_sample(testing_input_tensor,testing_target_tensor,generator,device,epoch,"test",mask,path_results)

  # one row per training epoch
  training_metrics_df = pd.DataFrame({"epoch" : list(range(1,num_epochs+1)),
                        "D_train_loss" : [i.item() for i in D_loss_plot],
                        "G_train_loss" : [i.item() for i in G_loss_plot]})

  # one row per evaluated epoch (same schedule as the `epoch % eval_every` test above)
  testing_metrics_df = pd.DataFrame({ "epoch" : list(range(eval_every,num_epochs+1,eval_every)),
                              "D_test_loss" : [i.item() for i in D_loss_test_plot],
                              "G_test_loss" : [i.item() for i in G_loss_test_plot],
                              "FID_test": [i.item() for i in FID_test_plot],
                              **score_plot})
  training_metrics_df.to_csv(path_results+"/train_metrics.csv")
  testing_metrics_df.to_csv(path_results+"/test_metrics.csv")

  # keep the exact configuration next to the results
  with open(path_results+'/config.json', 'w', encoding ='utf8') as json_file:
    json.dump(config_file, json_file)

In [79]:
# Input ablation: one full training run per input combination
# (apparel + line, then apparel + line + skin segmentation)
for input_name_list in [["apparel_head_foot","line"],["apparel_head_foot","line","segmentation_skin"]]:
  config_file = dict(
      num_epochs=120,
      dataset_folder_path="drive/MyDrive/ProjetS8/data/512/full_body/*",
      img_size=(512,512),
      train_batch_size=32,
      test_batch_size=20,
      learning_rate_generator=2e-4,
      learning_rate_discriminator=2e-5,
      betas=(0.5, 0.999),
      path_results="drive/MyDrive/ProjetS8/results/512/full_body",
      num_workers=4,
      input_name_list=input_name_list,
      add_discriminator_input_data=True,
      adv_loss=True,
      l1_loss=True,
      mask=False,
  )
  main(config_file)

Calculating FID
Finished Calculating FID


  0%|          | 0/120 [00:00<?, ?it/s]

Epoch: [1/120] Training: D_loss: 0.667, G_loss: 0.46450
Epoch: [2/120] Training: D_loss: 0.588, G_loss: 0.09900
Epoch: [3/120] Training: D_loss: 0.565, G_loss: 0.05926
Epoch: [4/120] Training: D_loss: 0.562, G_loss: 0.04778
Epoch: [5/120] Training: D_loss: 0.598, G_loss: 0.04195
Epoch: [6/120] Training: D_loss: 0.597, G_loss: 0.03901
Epoch: [7/120] Training: D_loss: 0.617, G_loss: 0.03586
Epoch: [8/120] Training: D_loss: 0.620, G_loss: 0.03516
Epoch: [9/120] Training: D_loss: 0.617, G_loss: 0.03326
Epoch: [10/120] Training: D_loss: 0.613, G_loss: 0.03209
Epoch: [11/120] Training: D_loss: 0.610, G_loss: 0.03085
Epoch: [12/120] Training: D_loss: 0.592, G_loss: 0.03008
Epoch: [13/120] Training: D_loss: 0.583, G_loss: 0.02968
Epoch: [14/120] Training: D_loss: 0.571, G_loss: 0.02865
Epoch: [15/120] Training: D_loss: 0.545, G_loss: 0.02817
Epoch: [16/120] Training: D_loss: 0.551, G_loss: 0.02734
Epoch: [17/120] Training: D_loss: 0.554, G_loss: 0.02674
Epoch: [18/120] Training: D_loss: 0.536,

  ret = umr_sum(arr, axis, dtype, out, keepdims, where=where)
  X -= avg[:, None]


fid calculation pb
------Testing-------
Epoch: [120/120] Testing: D_loss: 0.619, G_loss: 0.02660, FID_score 166.202
Testing evaluation metrics: [('MSE', 0.011865744640758713), ('SSIM', 0.9187389718435109), ('PSNR', 20.14871873777711)]
--------------------
Calculating FID
Finished Calculating FID


  0%|          | 0/120 [00:00<?, ?it/s]

Epoch: [1/120] Training: D_loss: 0.642, G_loss: 0.46175
Epoch: [2/120] Training: D_loss: 0.631, G_loss: 0.08629
Epoch: [3/120] Training: D_loss: 0.683, G_loss: 0.04351
Epoch: [4/120] Training: D_loss: 0.699, G_loss: 0.03445
Epoch: [5/120] Training: D_loss: 0.690, G_loss: 0.03109
Epoch: [6/120] Training: D_loss: 0.680, G_loss: 0.02921
Epoch: [7/120] Training: D_loss: 0.671, G_loss: 0.02795
Epoch: [8/120] Training: D_loss: 0.665, G_loss: 0.02686
Epoch: [9/120] Training: D_loss: 0.660, G_loss: 0.02612
Epoch: [10/120] Training: D_loss: 0.647, G_loss: 0.02553
Epoch: [11/120] Training: D_loss: 0.643, G_loss: 0.02476
Epoch: [12/120] Training: D_loss: 0.636, G_loss: 0.02424
Epoch: [13/120] Training: D_loss: 0.620, G_loss: 0.02364
Epoch: [14/120] Training: D_loss: 0.614, G_loss: 0.02319
Epoch: [15/120] Training: D_loss: 0.613, G_loss: 0.02325
Epoch: [16/120] Training: D_loss: 0.599, G_loss: 0.02243
Epoch: [17/120] Training: D_loss: 0.587, G_loss: 0.02208
Epoch: [18/120] Training: D_loss: 0.572,

In [49]:
# Discriminator-input ablation: this run does NOT concatenate the generator inputs
# to the discriminator input (add_discriminator_input_data is False)
config_file = dict(
    num_epochs=120,
    dataset_folder_path="drive/MyDrive/ProjetS8/data/512/full_body/*",
    img_size=(512,512),
    train_batch_size=32,
    test_batch_size=20,
    learning_rate_generator=2e-4,
    learning_rate_discriminator=2e-5,
    betas=(0.5, 0.999),
    path_results="drive/MyDrive/ProjetS8/results/512/full_body",
    num_workers=4,
    input_name_list=["apparel_head_foot","line"],
    add_discriminator_input_data=False,
    adv_loss=True,
    l1_loss=True,
    mask=False,
)
main(config_file)

Calculating FID


Downloading: "https://github.com/mseitzer/pytorch-fid/releases/download/fid_weights/pt_inception-2015-12-05-6726825d.pth" to /root/.cache/torch/hub/checkpoints/pt_inception-2015-12-05-6726825d.pth


  0%|          | 0.00/91.2M [00:00<?, ?B/s]

Finished Calculating FID


  0%|          | 0/120 [00:00<?, ?it/s]

Epoch: [1/120] Training: D_loss: 0.560, G_loss: 0.52715
Epoch: [2/120] Training: D_loss: 0.481, G_loss: 0.11170
Epoch: [3/120] Training: D_loss: 0.422, G_loss: 0.06049
Epoch: [4/120] Training: D_loss: 0.361, G_loss: 0.04870
Epoch: [5/120] Training: D_loss: 0.316, G_loss: 0.04316
Epoch: [6/120] Training: D_loss: 0.298, G_loss: 0.03985
Epoch: [7/120] Training: D_loss: 0.276, G_loss: 0.03798
Epoch: [8/120] Training: D_loss: 0.297, G_loss: 0.03609
Epoch: [9/120] Training: D_loss: 0.286, G_loss: 0.03471
Epoch: [10/120] Training: D_loss: 0.293, G_loss: 0.03370
Epoch: [11/120] Training: D_loss: 0.307, G_loss: 0.03273
Epoch: [12/120] Training: D_loss: 0.334, G_loss: 0.03178
Epoch: [13/120] Training: D_loss: 0.375, G_loss: 0.03055
Epoch: [14/120] Training: D_loss: 0.357, G_loss: 0.03023
Epoch: [15/120] Training: D_loss: 0.389, G_loss: 0.02936
Epoch: [16/120] Training: D_loss: 0.351, G_loss: 0.02898
Epoch: [17/120] Training: D_loss: 0.402, G_loss: 0.02792
Epoch: [18/120] Training: D_loss: 0.382,

In [50]:
# Mask run: the skin segmentation is part of the inputs and `mask` is enabled
config_file = dict(
    num_epochs=120,
    dataset_folder_path="drive/MyDrive/ProjetS8/data/512/full_body/*",
    img_size=(512,512),
    train_batch_size=32,
    test_batch_size=20,
    learning_rate_generator=2e-4,
    learning_rate_discriminator=2e-5,
    betas=(0.5, 0.999),
    path_results="drive/MyDrive/ProjetS8/results/512/full_body",
    num_workers=4,
    input_name_list=["apparel_head_foot","line","segmentation_skin"],
    add_discriminator_input_data=True,
    adv_loss=True,
    l1_loss=True,
    mask=True,
)
main(config_file)

Calculating FID
Finished Calculating FID


  0%|          | 0/120 [00:00<?, ?it/s]

Epoch: [1/120] Training: D_loss: 0.756, G_loss: 0.00983
Epoch: [2/120] Training: D_loss: 0.722, G_loss: 0.00866
Epoch: [3/120] Training: D_loss: 0.710, G_loss: 0.00824
Epoch: [4/120] Training: D_loss: 0.704, G_loss: 0.00795
Epoch: [5/120] Training: D_loss: 0.701, G_loss: 0.00752
Epoch: [6/120] Training: D_loss: 0.699, G_loss: 0.00725
Epoch: [7/120] Training: D_loss: 0.698, G_loss: 0.00702
Epoch: [8/120] Training: D_loss: 0.697, G_loss: 0.00678
Epoch: [9/120] Training: D_loss: 0.697, G_loss: 0.00652
Epoch: [10/120] Training: D_loss: 0.696, G_loss: 0.00639
Epoch: [11/120] Training: D_loss: 0.696, G_loss: 0.00626
Epoch: [12/120] Training: D_loss: 0.695, G_loss: 0.00613
Epoch: [13/120] Training: D_loss: 0.695, G_loss: 0.00594
Epoch: [14/120] Training: D_loss: 0.695, G_loss: 0.00572
Epoch: [15/120] Training: D_loss: 0.695, G_loss: 0.00564
Epoch: [16/120] Training: D_loss: 0.695, G_loss: 0.00554
Epoch: [17/120] Training: D_loss: 0.695, G_loss: 0.00540
Epoch: [18/120] Training: D_loss: 0.695,

In [20]:
from IPython.display import Image, display

def get_df(config_file):
  """
  Load the test metrics of the run described by `config_file`.

  The results folder name is rebuilt from the config values (number of inputs, discriminator
  input flag, loss flags, mask flag and both learning rates) and `test_metrics.csv` is read
  from `<path_results>/<folder_name>/`.

  config_file : dict with the keys "input_name_list", "add_discriminator_input_data", "adv_loss",
                "l1_loss", "learning_rate_generator", "learning_rate_discriminator", "mask"
                and "path_results"
  Return : (DataFrame, str) the test metrics and the folder name of the run
  """
  input_name_list = config_file["input_name_list"]
  add_discriminator_input_data = config_file["add_discriminator_input_data"]
  adv_loss = config_file["adv_loss"]
  l1_loss = config_file["l1_loss"]
  learning_rate_generator = config_file["learning_rate_generator"]
  learning_rate_discriminator = config_file["learning_rate_discriminator"]
  mask = config_file["mask"]
  # built once and reused for both the path and the returned name, so the two cannot drift apart
  folder_name = f"{len(input_name_list)}_inputs_{add_discriminator_input_data}_add_D_inputs_{adv_loss}_adv_loss_{l1_loss}_l1_loss_{mask}_mask_{learning_rate_generator}_lrg_{learning_rate_discriminator}_lrd"
  path_results = config_file["path_results"] + f"/{folder_name}"
  df = pd.read_csv(path_results+"/test_metrics.csv")
  return(df,folder_name)

In [21]:
def dispay_img(config_file,epoch):
  """
  Show the test sample grid saved at `epoch` for the run described by `config_file`, then download it.

  The image is read from `<path_results>/<folder_name>/test/test_<epoch>.png`, where the folder
  name is rebuilt from the config values (number of inputs, discriminator input flag, loss flags,
  mask flag and both learning rates).

  config_file : dict with the keys "input_name_list", "add_discriminator_input_data", "adv_loss",
                "l1_loss", "learning_rate_generator", "learning_rate_discriminator", "mask"
                and "path_results"
  epoch : epoch number used in the image file name
  Return : None
  """
  input_name_list = config_file["input_name_list"]
  add_discriminator_input_data = config_file["add_discriminator_input_data"]
  adv_loss = config_file["adv_loss"]
  l1_loss = config_file["l1_loss"]
  learning_rate_generator = config_file["learning_rate_generator"]
  learning_rate_discriminator = config_file["learning_rate_discriminator"]
  mask = config_file["mask"]
  path_results = config_file["path_results"] + f"/{len(input_name_list)}_inputs_{add_discriminator_input_data}_add_D_inputs_{adv_loss}_adv_loss_{l1_loss}_l1_loss_{mask}_mask_{learning_rate_generator}_lrg_{learning_rate_discriminator}_lrd"
  # single source for the image path, used for both the inline display and the download
  img_path = path_results+f"/test/test_{epoch}.png"
  display(Image(img_path))
  files.download(img_path)

In [23]:
%matplotlib inline
# Compare the test metrics of the runs trained with 1, 2 and 3 inputs: one figure per metric,
# one curve per input combination. The best value of each run is printed (lowest for
# MSE / FID, highest for SSIM / PSNR) and, for FID, the sample grid of the best epoch is shown.
for metric in ["MSE","SSIM","PSNR","FID_test"]:
  for val in [["apparel_head_foot"],["apparel_head_foot","line"],["apparel_head_foot","line","segmentation_skin"]]:
    config_file = {
              "num_epochs": 80,
              "dataset_folder_path":"drive/MyDrive/ProjetS8/data/512/full_body/*",
              "img_size":(512,512),
              "train_batch_size":32,
              "test_batch_size":20,
              "learning_rate_generator":2e-4,
              "learning_rate_discriminator":2e-5,
              "betas":(0.5, 0.999),
              "path_results":"drive/MyDrive/ProjetS8/results/512/full_body",
              "num_workers":0,
              "input_name_list":val,
              "add_discriminator_input_data":True,
              "adv_loss":True,
              "l1_loss":True,
              "mask":False,
    }
    df,folder_name = get_df(config_file)
    if metric == "FID_test":
      # idxmin ignores NaN rows (FID can be NaN when its computation failed), whereas the
      # builtin min() over a Series gives an order-dependent result when NaN is present
      epoch_val = df.loc[df["FID_test"].idxmin(), "epoch"]
      print(epoch_val)
      dispay_img(config_file,epoch_val)
    plt.plot(df["epoch"],df[metric],label=f"{val}")
    plt.ylabel(metric)
    plt.xlabel("epoch")
    # lower is better for MSE and FID, higher is better for SSIM and PSNR
    if metric == "MSE" or metric == "FID_test":
      best_value = df[metric].min()
    else:
      best_value = df[metric].max()
    print(f"{val} : {best_value}")
  # the title is the folder name of the last run plotted in this figure
  plt.title(folder_name)
  plt.legend(loc="upper left")
  plt.show()


Output hidden; open in https://colab.research.google.com to view.

In [52]:
%matplotlib inline
# Compare the test metrics of the runs trained with and without the generator inputs
# concatenated to the discriminator input: one figure per metric, one curve per setting.
# The best value of each run is printed (lowest for MSE / FID, highest for SSIM / PSNR)
# and, for FID, the sample grid of the best epoch is shown.
for metric in ["MSE","SSIM","PSNR","FID_test"]:
  for val in [True,False]:
    config_file = {
              "num_epochs": 80,
              "dataset_folder_path":"drive/MyDrive/ProjetS8/data/512/full_body/*",
              "img_size":(512,512),
              "train_batch_size":32,
              "test_batch_size":20,
              "learning_rate_generator":2e-4,
              "learning_rate_discriminator":2e-5,
              "betas":(0.5, 0.999),
              "path_results":"drive/MyDrive/ProjetS8/results/512/full_body",
              "num_workers":0,
              "input_name_list":["apparel_head_foot","line"],
              "add_discriminator_input_data":val,
              "adv_loss":True,
              "l1_loss":True,
              "mask":False,
    }
    df,folder_name = get_df(config_file)
    if metric == "FID_test":
      # idxmin ignores NaN rows (FID can be NaN when its computation failed), whereas the
      # builtin min() over a Series gives an order-dependent result when NaN is present
      epoch_val = df.loc[df["FID_test"].idxmin(), "epoch"]
      dispay_img(config_file,epoch_val)
    plt.plot(df["epoch"],df[metric],label=f"{val}")
    plt.ylabel(metric)
    plt.xlabel("epoch")
    # lower is better for MSE and FID, higher is better for SSIM and PSNR
    if metric == "MSE" or metric == "FID_test":
      best_value = df[metric].min()
    else:
      best_value = df[metric].max()
    print(f"{val} : {best_value}")

  # the title is the folder name of the last run plotted in this figure
  plt.title(folder_name)
  plt.legend(loc="upper left")
  plt.show()

Output hidden; open in https://colab.research.google.com to view.