In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as T
import matplotlib.pyplot as plt
import numpy as np
import os
import random
from torchvision.transforms import ToPILImage
from torchvision import models
from torchvision import transforms
from PIL import Image
from warpingnetwork import WarpingProcess
from losses import VGGLoss, flow_loss

In [None]:
def make_grid(N, iW, iH, device):
    """Build an identity sampling grid in the layout ``F.grid_sample`` expects.

    Args:
        N: batch size (number of repeated copies).
        iW: grid width (number of columns) -- note width comes BEFORE height.
        iH: grid height (number of rows).
        device: device the returned grid is moved to.

    Returns:
        Float tensor of shape (N, iH, iW, 2). Channel 0 holds x coordinates
        spaced evenly from -1 to 1 across the width; channel 1 holds y
        coordinates spaced evenly from -1 to 1 across the height.
    """
    xs = torch.linspace(-1.0, 1.0, iW)
    ys = torch.linspace(-1.0, 1.0, iH)
    x_plane = xs.view(1, iW).expand(iH, iW)
    y_plane = ys.view(iH, 1).expand(iH, iW)
    single = torch.stack((x_plane, y_plane), dim=-1)  # (iH, iW, 2)
    return single.unsqueeze(0).repeat(N, 1, 1, 1).to(device)

def weights_init(m):
    """Initialise one module's parameters; intended for ``model.apply(weights_init)``.

    - Conv2d / ConvTranspose2d: Xavier-normal weights (biases left untouched).
    - BatchNorm2d (affine only): weight ~ N(1.0, 0.02), bias = 0.

    Xavier init needs a tensor with at least 2 dimensions, so it cannot be used
    on the 1-D BatchNorm weight (torch raises ValueError). The BN scale is drawn
    around 1 instead, the common choice for BN layers (as in DCGAN).
    Other module types are left unchanged.
    """
    if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
        torch.nn.init.xavier_normal_(m.weight)
    elif isinstance(m, nn.BatchNorm2d) and m.affine:
        # affine=False BN layers have weight/bias set to None; nothing to init.
        torch.nn.init.normal_(m.weight, 1.0, 0.02)
        torch.nn.init.constant_(m.bias, 0)

def warped_cloth_into_agnostic(agnostic_image, warped_cloth_image, warped_cloth_mask):
  """Paste the warped cloth onto the agnostic image wherever the mask is > 0.5.

  Args:
    agnostic_image: image tensor to paste onto, e.g. (N, 3, H, W). Not modified.
    warped_cloth_image: cloth pixels; same shape as ``agnostic_image``.
    warped_cloth_mask: cloth mask broadcastable to ``agnostic_image``,
      e.g. (N, 1, H, W). Not modified.

  Returns:
    A new tensor holding ``warped_cloth_image`` where mask > 0.5 and
    ``agnostic_image`` everywhere else.
  """
  # Binarise into a fresh boolean tensor instead of writing 1s into the
  # caller's mask. Broadcasting the single-channel mask across the image's
  # channels replaces an explicit repeat, so any channel count works.
  cloth_region = warped_cloth_mask > 0.5

  return torch.where(cloth_region, warped_cloth_image, agnostic_image)

In [None]:
# Use the GPU when torch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
# Load the pre-computed warping tensors directly onto `device`.
# map_location=device, instead of torch.load(...).to(device), lets files saved
# from CUDA tensors load on a CPU-only machine. It also avoids first
# materialising the data on the device it was saved from.
# NOTE(review): torch.load unpickles the file -- only load trusted data.
warped_masks = torch.load('data/warping/warped_masks.pth', map_location=device)
agnostics = torch.load('data/warping/agnostics.pth', map_location=device)
agnostic_masks = torch.load('data/warping/agnostic_masks.pth', map_location=device)
inputs_1 = torch.load('data/warping/inputs_1.pth', map_location=device)
inputs_2 = torch.load('data/warping/inputs_2.pth', map_location=device)

In [None]:
# Split every loaded batch along dim 0 into a tuple of per-sample tensors.
warped_masks, agnostics, agnostic_masks, inputs_1, inputs_2 = (
    torch.unbind(batch, dim=0)
    for batch in (warped_masks, agnostics, agnostic_masks, inputs_1, inputs_2)
)

In [None]:
# Ground-truth warped cloth for each sample: the warped mask, rescaled with
# (mask + 1) / 2, multiplied element-wise with the matching entry of `images`.
# (mask + 1) / 2 maps [-1, 1] to [0, 1]. Presumably the saved masks are in
# [-1, 1]; TODO confirm, because the training loop compares warped_mask
# directly (without this rescaling) against the predicted mask.
# NOTE(review): `images` is not defined anywhere in this notebook, so this
# cell raises NameError after Restart Kernel -> Run All. It only works if
# `images` leaked from an earlier session. TODO: load the ground-truth images
# explicitly (and unbind them like the tensors above) before this cell.
warped_cloths = [((i + 1 )/ 2)*j for i, j in zip(warped_masks, images)]

In [None]:
# Give every per-sample tensor an explicit leading batch dimension so it is
# 4-D, e.g. (1, C, H, W) if the saved tensors were (N, C, H, W); TODO confirm
# the saved layout. The training loop compares its 4-D predictions (e.g.
# warped[:, -1:, :, :]) against warped_mask / warped_cloth, so those targets
# are batched too. With 3-D targets nn.L1Loss silently broadcasts (PyTorch
# only warns about the size mismatch) instead of comparing like-for-like
# shapes.
def _to_batched_4d(t):
  """Prepend singleton dims until ``t`` is 4-D; 4-D tensors pass through unchanged."""
  while t.dim() < 4:
    t = t.unsqueeze(0)
  return t

agnostics = [_to_batched_4d(i) for i in agnostics]
agnostic_masks = [_to_batched_4d(i) for i in agnostic_masks]
inputs_1 = [_to_batched_4d(i) for i in inputs_1]
inputs_2 = [_to_batched_4d(i) for i in inputs_2]
warped_masks = [_to_batched_4d(i) for i in warped_masks]
warped_cloths = [_to_batched_4d(i) for i in warped_cloths]

In [None]:
# Optimisation settings.
lr = 5e-5          # Adam learning rate
epochs = 30        # used as the training loop's range bound

# Loss weights, combined in the training loop as
# l1_lambda * L1(mask) + vgg_lambda * VGG(cloth) + tvlambda * flow_loss(flows).
l1_lambda = 1
vgg_lambda = 0.1
tvlambda = 1       # weight of flow_loss (presumably a TV/smoothness term; see losses.py)

# Loss modules, moved to the training device.
criterionL1, criterionVGG = nn.L1Loss().to(device), VGGLoss().to(device)

In [None]:
# Build the warping network on `device`, then attach Adam to its parameters.
# 96 is WarpingProcess's constructor argument; its meaning is defined in
# warpingnetwork (TODO confirm).
generator = WarpingProcess(96).to(device)
gen_opt = torch.optim.Adam(
    generator.parameters(),
    lr=lr,
    betas=(0.5, 0.999),
)
# weights_init re-initialises the existing parameter tensors in place, so the
# optimiser created above still tracks them. apply() returns the same module,
# so no re-assignment is needed.
generator.apply(weights_init)

In [None]:
# Train the warping network.
# Per sample:
#   - L1 between the predicted mask and warped_mask;
#   - VGG loss between (predicted mask * predicted cloth) and warped_cloth;
#   - flow_loss over all predicted flows.
# The L1 / VGG terms are also applied, down-weighted, to every intermediate
# flow in flow_list except the last one.

losses = []     # loss of every optimisation step, across all epochs
vis_every = 0   # set > 0 to plot the last sample of every `vis_every`-th epoch


def _first_item(t):
  """Return the first sample of a batched (4-D) tensor; pass other tensors through."""
  return t[0] if t.dim() == 4 else t


def _show_warp_sample(warped_c, warped_cloth, warped_cm, warped_mask, input_1, input_2):
  """Plot predictions, targets and input channel slices for one training sample."""
  panels = [
      _first_item(warped_c), _first_item(warped_cloth),
      _first_item(warped_cm), _first_item(warped_mask),
      input_1[0, 0:3, ...], input_2[0, 0:3, ...], input_2[0, 3:6, ...],
  ]
  fig, axs = plt.subplots(1, len(panels), figsize=(10, 5))
  for ax, panel in zip(axs, panels):
    # squeeze(-1) turns single-channel (H, W, 1) panels into (H, W) for imshow.
    ax.imshow(panel.permute(1, 2, 0).squeeze(-1).detach().cpu().numpy())
    ax.axis('off')
  plt.tight_layout()
  plt.show()
  plt.close(fig)


# Make sure Dropout/BatchNorm layers (if any) are in training mode, e.g. when
# this cell is re-run after the model was switched to eval().
generator.train()

for i in range(epochs+1):  # NOTE: epochs + 1 passes (i = 0..epochs)

  epoch_losses = []

  for input_1, input_2, warped_cloth, warped_mask in zip(inputs_1, inputs_2, warped_cloths, warped_masks):

    warped_c, warped_cm, flow_list = generator(input_1, input_2, device)

    loss_l1_cloth = criterionL1(warped_cm, warped_mask)

    loss_vgg = criterionVGG(warped_cm*warped_c, warped_cloth)

    loss_tv = flow_loss(flow_list)

    # Multi-scale flow supervision: upsample each intermediate flow to the
    # input resolution, warp input_1 with it and add down-weighted L1/VGG terms.
    N, _, iH, iW = input_1.shape
    # make_grid takes (N, width, height) and returns an (N, iH, iW, 2) grid.
    # It depends only on the input size, so build it once per sample.
    # NOTE(review): a linspace(-1, 1) identity grid matches pixel centres only
    # with align_corners=True; grid_sample below uses the default (False).
    grid = make_grid(N, iW, iH, device)

    for j in range(len(flow_list)-1):
      flow = F.interpolate(flow_list[j].permute(0, 3, 1, 2), size=(iH, iW), mode='bilinear').permute(0, 2, 3, 1)
      # Pixel offsets -> grid_sample's normalised units: channel 0 (x) scales
      # with the width and channel 1 (y) with the height.
      # NOTE(review): this assumes flow values are in output-resolution pixels.
      # If WarpingProcess emits them in each flow's native pixels, divide by that
      # flow's own (fW - 1) / 2 and (fH - 1) / 2 instead -- confirm.
      flow_norm = torch.cat([flow[:, :, :, 0:1] / ((iW - 1.0) / 2.0), flow[:, :, :, 1:2] / ((iH - 1.0) / 2.0)], 3)
      warped = F.grid_sample(input_1, grid + flow_norm, padding_mode='border')
      warped_c_flow = warped[:, :-1, :, :]   # all but the last channel: compared to the cloth target
      warped_cm_flow = warped[:, -1:, :, :]  # last channel: compared to the mask target

      # Weight 1 / 2**(4 - j): later entries of flow_list count more
      # (assumes flow_list has at most 5 entries).
      loss_l1_cloth += criterionL1(warped_cm_flow, warped_mask) / (2 ** (4-j))
      loss_vgg += criterionVGG(warped_cm_flow*warped_c_flow, warped_cloth) / (2 ** (4-j))

    loss_G = ((l1_lambda * loss_l1_cloth) + (vgg_lambda * loss_vgg) + (tvlambda * loss_tv))

    step_loss = loss_G.item()
    losses.append(step_loss)
    epoch_losses.append(step_loss)

    gen_opt.zero_grad()
    loss_G.backward()
    gen_opt.step()

  # Mean over this epoch only (`losses` keeps the full history).
  print(f'Epoch: {i}, Mean Loss: {np.mean(epoch_losses)}')

  if vis_every and epoch_losses and i % vis_every == 0:
    _show_warp_sample(warped_c, warped_cloth, warped_cm, warped_mask, input_1, input_2)

In [None]:
# Add the warped cloth output of the model into the agnostic image, then
# resize each composite to 512x512 for the diffusion model.

inputs_diffusion_model = []

assert len(agnostics) == len(inputs_1) == len(inputs_2), "per-sample lists must align"

# Inference mode for Dropout/BatchNorm layers, if the network has any.
generator.eval()

with torch.no_grad():
  for agnostic_image, input_1, input_2 in zip(agnostics, inputs_1, inputs_2):

    warped_c, warped_cm, _ = generator(input_1, input_2, device)

    new_im = warped_cloth_into_agnostic(agnostic_image, warped_c, warped_cm)

    new_im = F.interpolate(new_im, size=(512, 512), mode='bilinear', align_corners=False)

    # Keep each 512x512 result on the CPU; otherwise every sample's output
    # stays resident in GPU memory until the list is saved.
    inputs_diffusion_model.append(new_im.cpu())

In [None]:
# Sample of 2k images for the diffusion model.
# Indices are drawn from the inputs actually generated above, not a
# hard-coded range(3000), so the sample can never index past the end of
# `inputs_diffusion_model`. A fixed seed makes the chosen subset reproducible.

sample_size = 2000
sample_seed = 0

n_generated = len(inputs_diffusion_model)
list_sample_inputs_diffusion_model = random.Random(sample_seed).sample(
    range(n_generated), min(sample_size, n_generated))

sample_inputs_diffusion_model = [inputs_diffusion_model[i] for i in list_sample_inputs_diffusion_model]

In [None]:
# Persist the sampled subset (matching the "_sample" file name) as one batched tensor.
# NOTE: the "difussion" spelling in the path is kept so existing loaders still find the file.
torch.save(torch.cat(sample_inputs_diffusion_model, dim=0), 'data/input_difussion_model_sample.pth')