In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as T
import torch.optim.lr_scheduler as lr_scheduler
import matplotlib.pyplot as plt
import numpy as np
import os
import math
from torchvision.transforms import ToPILImage
from torchvision import models
from torchvision import transforms
from PIL import Image
from diffusers import AutoencoderKL
from diffusionmodel import UNet, NoiseScheduler
from losses import VGGLoss

In [None]:
# weights initialization

def weights_init(m):
    """Initialise a single sub-module in place (suitable for ``module.apply``).

    * ``nn.Conv2d`` / ``nn.ConvTranspose2d``: Xavier-normal weights; the bias,
      when the layer has one, is set to zero.
    * ``nn.GroupNorm``: weight set to one and bias set to zero, when present.

    Modules of any other type are left untouched. Returns ``None``.
    """
    if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
        nn.init.xavier_normal_(m.weight)
        if m.bias is not None:
            nn.init.zeros_(m.bias)
    elif isinstance(m, nn.GroupNorm):
        # Non-affine GroupNorm layers expose these attributes as None.
        scale = getattr(m, 'weight', None)
        shift = getattr(m, 'bias', None)
        if scale is not None:
            nn.init.ones_(scale)
        if shift is not None:
            nn.init.zeros_(shift)

# Use of the VAE to encode and decode images into latent space

def encode_latents(images):
    # images: [B,3,H,W], range [-1,1]
    with torch.no_grad():
        latents = vae.encode(images).latent_dist.sample()
        latents = latents * 0.18215  # SD scaling factor
    return latents

def decode_latents(latents):
    """Decode scaled latents back to images with the module-level ``vae``.

    The 0.18215 scaling applied by ``encode_latents`` is undone first, then the
    VAE decoder is run and its ``.sample`` output is returned. Decoding happens
    under ``torch.no_grad()``, so the returned images are detached from any
    autograd graph.
    """
    with torch.no_grad():
        return vae.decode(latents / 0.18215).sample

In [None]:
# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
# Load the pretrained VAE that ships with Stable Diffusion v1.5 (the "vae"
# subfolder of the model repo), move it to the selected device and switch it to
# evaluation mode.
vae = AutoencoderKL.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="vae")
vae = vae.to(device)
vae.eval()


In [None]:
# Load the three pre-batched tensors/collections saved with torch.save:
# target images, try-on network inputs and the corresponding masks.
original_images, inputs_viton, mask_images = map(
    torch.load,
    (
        'data/original_images_batch.pth',
        'data/inputs_viton_batch.pth',
        'data/mask_images_batch.pth',
    ),
)

In [None]:
# --- Training configuration -------------------------------------------------
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Scalar hyper-parameters.
lr = 1e-5
vgg_lambda = 1e-3   # weight of the VGG (perceptual) term in the total loss
epochs = 150

# Loss modules, placed on the training device.
mse_loss = nn.MSELoss().to(device)
vgg_loss = VGGLoss().to(device)

# Denoising network. The positional arguments are presumably the input and
# output channel counts (TODO confirm against diffusionmodel.UNet). Its conv
# and GroupNorm layers are re-initialised in place by weights_init.
viton = UNet(9, 4).to(device).apply(weights_init)

viton_opt = torch.optim.AdamW(viton.parameters(), lr=lr)

In [None]:
# Diffusion noise schedule: 1000 timesteps with the "cosine" beta schedule,
# with its tensors created on the training device.
scheduler = NoiseScheduler(
    timesteps=1000,
    beta_schedule="cosine",
    device=device,
)

In [None]:
# Convert images into latent space
#
# For every batch this produces:
#   * the VAE latent of the original image,
#   * the VAE latent of the try-on input,
#   * the mask resized (nearest neighbour) to the latent's spatial size.
# The three result lists are index-aligned with the lists loaded from disk.

z_original_images = []
z_inputs_viton = []
z_mask_images = []

# The loaded collections are named `original_images`, `inputs_viton` and
# `mask_images`; the previous `*_batch` names were never defined and raised a
# NameError on a fresh kernel.
for original_image, input_viton, mask_image in zip(original_images, inputs_viton, mask_images):
    # The VAE lives on `device`, so its inputs must too (.to is a no-op when
    # the tensor is already there). The mask is moved as well because it is
    # later concatenated with the latents.
    z_original_image = encode_latents(original_image.to(device))
    z_input_viton = encode_latents(input_viton.to(device))
    z_mask_image = F.interpolate(mask_image.to(device), size=z_original_image.shape[-2:], mode="nearest")

    z_original_images.append(z_original_image)
    z_inputs_viton.append(z_input_viton)
    z_mask_images.append(z_mask_image)

In [None]:
# Epoch from which the perceptual (VGG) term is added to the noise-prediction
# loss, and the epoch whose batches are plotted. Both used to be the literal 150.
VGG_START_EPOCH = 150
PLOT_EPOCH = 150

# Per-step loss history since this cell started. The "Mean Loss" printed below
# is the running mean over this whole list, not a per-epoch mean.
losses = []

# The VAE is only used as a fixed decoder here. Freezing its parameters lets the
# VGG branch decode with autograd enabled without accumulating VAE gradients.
vae.requires_grad_(False)

# NOTE(review): the range starts at 101, which looks like a resumed run, but no
# checkpoint is loaded in the visible cells — confirm the weights/optimizer
# state are restored before running this cell.
for epoch in range(101, epochs+1):

  use_vgg = epoch >= VGG_START_EPOCH

  for original_image, z_original_image, input_viton, z_input_viton, mask_image, z_mask_image in zip(original_images, z_original_images, inputs_viton, z_inputs_viton, mask_images, z_mask_images):

    x_0_original = original_image.to(device)

    # One random diffusion timestep for the whole batch.
    t = torch.randint(0, scheduler.timesteps, (1, )).to(device)

    # get_noisy_image returns the noised latent and the noise that was added.
    z_t_original, noise = scheduler.get_noisy_image(z_original_image, t)

    # Only the first mask channel is fed to the network.
    mask_latent = z_mask_image[:, 0:1, :, :]

    input_original = torch.concat([z_t_original, z_input_viton, mask_latent], dim=1)

    # `t` is already a 1-element tensor on `device`; re-wrapping it in
    # torch.tensor([...]) only added a copy and a host/device sync.
    unet_output_original = viton(input_original, t)

    # Noise-prediction loss, used in every epoch.
    loss = mse_loss(unet_output_original, noise)

    if use_vgg:

      # Noise the try-on input with the same timestep and the same noise.
      z_t_agnostic, _ = scheduler.get_noisy_image(z_input_viton, t, noise)

      input_agnostic = torch.concat([z_t_agnostic, z_input_viton, mask_latent], dim=1)

      unet_output_agnostic = viton(input_agnostic, t)

      z_denoise_agnostic = scheduler.denoise_image(z_t_agnostic, t, unet_output_agnostic)

      # decode_latents() runs the decoder under torch.no_grad(), which detaches
      # the decoded image from the UNet and makes the VGG term contribute no
      # gradient. Decode here with autograd enabled instead (same 0.18215
      # un-scaling as decode_latents).
      # NOTE(review): assumes scheduler.denoise_image is differentiable — confirm.
      denoise_agnostic = vae.decode(z_denoise_agnostic / 0.18215).sample

      # Mask images needs to be rounded because the edges of the mask are not
      # exactly one and we rescalar it as well to be [0,1]
      # (kept in a separate name so the loop variable is not overwritten).
      mask_binary = ((mask_image+1)/2).round().repeat(1, 3, 1, 1).to(device)
      in_mask = mask_binary == 1

      denoise_agnostic_masked_area = x_0_original.clone()

      # For the VGG only the mask area will be part of the loss calculation:
      # outside the mask the prediction is replaced by the ground truth.
      denoise_agnostic_masked_area[in_mask] = denoise_agnostic[in_mask]

      loss = loss + (vgg_lambda * vgg_loss(denoise_agnostic_masked_area, x_0_original))

    losses.append(loss.item())

    viton_opt.zero_grad()
    loss.backward()
    viton_opt.step()

    # Plot the output images to check the visual performance of the model.
    # The decoded images only exist when the VGG branch ran for this step.

    if epoch == PLOT_EPOCH and use_vgg:

      print(f"Epoch {epoch}, Mean Loss: {np.mean(losses)} at timestep {t} and lr: {lr}")

      fig, axs = plt.subplots(1, 3, figsize=(10, 5))
      axs[0].imshow(((original_image[0]).permute(1, 2, 0).detach().cpu().numpy()+1)/2)
      axs[0].axis('off')

      axs[1].imshow(((denoise_agnostic_masked_area[0]).permute(1, 2, 0).detach().cpu().numpy()+1)/2)
      axs[1].axis('off')

      axs[2].imshow(((denoise_agnostic[0]).permute(1, 2, 0).detach().cpu().numpy()+1)/2)
      axs[2].axis('off')

      plt.show()
      # One figure is created per batch; close it so figures do not pile up.
      plt.close(fig)

  # if epoch % 10 == 0:

  #   # Save the model checkpoint each 10 epochs
  #   checkpoint = {
  #       'epoch': epoch,
  #       'model_state_dict': viton.state_dict(),
  #       'optimizer_state_dict': viton_opt.state_dict(),
  #       'loss': losses,
  #   }

  #   torch.save(checkpoint, f'/content/drive/MyDrive/AIClothes/DDPM/Models/Diffusion/latent512_2000_batch5_SelfAttention_cosinescheduler_resnet_adamW_lr{lr}_vgg{vgg_lambda}_epoch_{epoch}.pth')

  print(f"Epoch {epoch}, Mean Loss: {np.mean(losses)}")