In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as T
import matplotlib.pyplot as plt
import numpy as np
import os
from torchvision.transforms import ToPILImage
from torchvision import models
from torchvision import transforms
from PIL import Image
from google.colab import drive
# Mount Google Drive; the dataset and checkpoint paths used further down all
# live under /content/drive/MyDrive.
drive.mount('/content/drive')

In [None]:
def make_grid(N, iW, iH, device):
    """Build an identity sampling grid spanning [-1, 1] on both axes.

    Note the argument order: width comes before height.

    Args:
        N: batch size of the returned grid.
        iW: number of columns (samples along x).
        iH: number of rows (samples along y).
        device: device the returned grid is moved to.

    Returns:
        Tensor of shape (N, iH, iW, 2). Channel 0 of the last dimension holds
        the x coordinate (varies along the width axis), channel 1 the y
        coordinate (varies along the height axis).
    """
    xs = torch.linspace(-1.0, 1.0, iW)
    ys = torch.linspace(-1.0, 1.0, iH)

    # Broadcast each 1-D ramp over the full (N, iH, iW) lattice.
    x_plane = xs.reshape(1, 1, iW, 1).expand(N, iH, iW, 1).to(device)
    y_plane = ys.reshape(1, iH, 1, 1).expand(N, iH, iW, 1).to(device)

    return torch.cat((x_plane, y_plane), dim=3)

def save_tensor_as_image(tensor, save_path):
    """Write an image tensor to disk.

    Args:
        tensor: image tensor shaped [C, H, W], or [1, C, H, W] (the leading
            batch axis is squeezed away). Values are clamped to [0, 1]
            before conversion.
        save_path: destination file path handed to ``PIL.Image.save``.
    """
    # Only a 4-D input carries a batch axis that needs to be dropped.
    image_tensor = tensor.squeeze(0) if tensor.dim() == 4 else tensor

    clamped = image_tensor.cpu().clamp(0, 1)
    ToPILImage()(clamped).save(save_path)

def flow_loss(flow_list):
    """Total-variation smoothness penalty summed over a list of flow fields.

    For each flow the mean absolute difference between neighbours along
    dimension 1 and along dimension 2 is added to the running total.

    Args:
        flow_list: iterable of 4-D flow tensors.

    Returns:
        The accumulated penalty (a scalar tensor), or the int ``0`` when
        ``flow_list`` is empty.
    """
    def _mean_abs_step(field, axis):
        # Mean magnitude of the forward difference along ``axis``.
        return torch.diff(field, dim=axis).abs().mean()

    total = 0
    for field in flow_list:
        step_dim1 = _mean_abs_step(field, 1)
        step_dim2 = _mean_abs_step(field, 2)
        total = total + step_dim1 + step_dim2
    return total

In [None]:
class Vgg19(nn.Module):
    """Pretrained VGG-19 feature extractor split into five consecutive stages.

    The ``features`` stack of torchvision's VGG-19 is cut at layer indices
    2, 7, 12, 21 and 30. Each stage keeps the original layer index as the
    child name, and the stages are exposed as ``slice1`` ... ``slice5``.

    Args:
        requires_grad: when False (default) every parameter is frozen.
    """

    # (first layer index, one-past-last layer index) for slice1..slice5.
    _SLICE_BOUNDS = ((0, 2), (2, 7), (7, 12), (12, 21), (21, 30))

    def __init__(self, requires_grad=False):
        super().__init__()
        features = models.vgg19(pretrained=True).features

        for number, (start, stop) in enumerate(self._SLICE_BOUNDS, start=1):
            stage = torch.nn.Sequential()
            for layer_idx in range(start, stop):
                stage.add_module(str(layer_idx), features[layer_idx])
            setattr(self, "slice%d" % number, stage)

        if not requires_grad:
            for param in self.parameters():
                param.requires_grad = False

    def forward(self, X):
        """Return the list of the five stage outputs, shallowest first."""
        stages = (self.slice1, self.slice2, self.slice3, self.slice4, self.slice5)
        outputs = []
        activation = X
        for stage in stages:
            activation = stage(activation)
            outputs.append(activation)
        return outputs

class VGGLoss(nn.Module):
    """Weighted L1 distance between VGG-19 feature maps of two images.

    Args:
        layids: indices of the feature stages to compare. ``None`` means
            "all stages"; it is resolved (and stored back on the instance)
            on the first forward call.
    """

    def __init__(self, layids=None):
        super().__init__()
        self.vgg = Vgg19()
        self.criterion = nn.L1Loss()
        # Per-stage weights, shallowest stage first.
        self.weights = [1.0 / 32, 1.0 / 16, 1.0 / 8, 1.0 / 4, 1.0]
        self.layids = layids

    def forward(self, x, y):
        """Return the weighted sum of per-stage L1 feature distances.

        The features of ``y`` are detached, so gradients only flow through
        ``x``.
        """
        feats_x = self.vgg(x)
        feats_y = self.vgg(y)

        if self.layids is None:
            self.layids = list(range(len(feats_x)))

        loss = 0
        for stage in self.layids:
            target = feats_y[stage].detach()
            loss += self.weights[stage] * self.criterion(feats_x[stage], target)
        return loss

In [None]:
class ResNetEncoderBlock(nn.Module):
    """Residual block preceded by a rescaling convolution.

    The input first goes through ``scale`` (which also adapts the channel
    count); the result is both the residual shortcut and the input of two
    3x3 convolutions.

    Args:
        input_channels: channels of the incoming tensor.
        output_channels: channels produced by ``scale`` and kept afterwards.
        use_dropout: apply ``nn.Dropout`` after each normalisation step.
        use_bn: apply normalisation after each 3x3 convolution. Despite the
            attribute name ``batchnorm`` the layer is ``nn.InstanceNorm2d``.
        down: if True, ``scale`` is a stride-2 3x3 convolution. Takes
            precedence over ``up``.
        up: if True (and ``down`` is False), ``scale`` is a 2x bilinear
            upsample followed by a 3x3 convolution. If both flags are False
            ``scale`` is a 1x1 convolution.
    """

    def __init__(self, input_channels, output_channels, use_dropout=False, use_bn=True, down=True, up=False):
        super().__init__()

        # Module creation order is kept stable so seeded initialisation
        # consumes the random stream in the same sequence.
        self.scale = self._make_scale(input_channels, output_channels, down, up)

        self.activation = nn.ReLU()

        self.use_bn = use_bn
        if self.use_bn:
            self.batchnorm = nn.InstanceNorm2d(output_channels)

        self.use_dropout = use_dropout
        if self.use_dropout:
            self.dropout = nn.Dropout()

        self.conv_1 = nn.Conv2d(output_channels, output_channels, kernel_size=3, padding=1)
        self.conv_2 = nn.Conv2d(output_channels, output_channels, kernel_size=3, padding=1)

    @staticmethod
    def _make_scale(input_channels, output_channels, down, up):
        """Create the rescaling layer selected by the ``down`` / ``up`` flags."""
        if down:
            return nn.Conv2d(input_channels, output_channels, kernel_size=3, stride=2, padding=1)
        if up:
            return nn.Sequential(
                nn.Upsample(scale_factor=2, mode='bilinear'),
                nn.Conv2d(input_channels, output_channels, kernel_size=3, padding=1),
            )
        return nn.Conv2d(input_channels, output_channels, kernel_size=1)

    def _post_conv(self, features):
        """Apply the optional normalisation and dropout, in that order."""
        if self.use_bn:
            features = self.batchnorm(features)
        if self.use_dropout:
            features = self.dropout(features)
        return features

    def forward(self, x):
        """Return ``ReLU(conv_2(ReLU(conv_1(s))) + s)`` with ``s = scale(x)``.

        The optional normalisation / dropout steps follow each convolution.
        """
        shortcut = self.scale(x)

        hidden = self._post_conv(self.conv_1(shortcut))
        hidden = self.activation(hidden)
        hidden = self._post_conv(self.conv_2(hidden))

        return self.activation(hidden + shortcut)

class ClothingEncoder(nn.Module):
    """Five-stage downsampling encoder built from ``ResNetEncoderBlock``.

    Every stage uses the block's default stride-2 rescaling. Channel widths,
    in units of ``output_channels``, are 1, 2, 4, 4, 4.

    Args:
        input_channels: channels of the input tensor.
        output_channels: base channel width of the first stage.
    """

    def __init__(self, input_channels, output_channels):
        super().__init__()
        base = output_channels
        self.resnet1 = ResNetEncoderBlock(input_channels, base)
        self.resnet2 = ResNetEncoderBlock(base, base * 2)
        self.resnet3 = ResNetEncoderBlock(base * 2, base * 4)
        self.resnet4 = ResNetEncoderBlock(base * 4, base * 4)
        self.resnet5 = ResNetEncoderBlock(base * 4, base * 4)

    def forward(self, x):
        """Return a tuple with the output of each stage, shallowest first."""
        stages = (self.resnet1, self.resnet2, self.resnet3, self.resnet4, self.resnet5)
        pyramid = []
        for stage in stages:
            x = stage(x)
            pyramid.append(x)
        return tuple(pyramid)

class SegmentEncoder(nn.Module):
    """Seven-stage encoder: five downsampling stages, a bottleneck, one upsample.

    Stages 1-5 use the default stride-2 ``ResNetEncoderBlock``. Stage 6 keeps
    the resolution (1x1 rescaling) and widens to ``8 * output_channels``.
    Stage 7 upsamples by 2 and narrows back to ``4 * output_channels``.

    Args:
        input_channels: channels of the input tensor.
        output_channels: base channel width of the first stage.
    """

    def __init__(self, input_channels, output_channels):
        super().__init__()
        base = output_channels
        self.resnet1 = ResNetEncoderBlock(input_channels, base)
        self.resnet2 = ResNetEncoderBlock(base, base * 2)
        self.resnet3 = ResNetEncoderBlock(base * 2, base * 4)
        self.resnet4 = ResNetEncoderBlock(base * 4, base * 4)
        self.resnet5 = ResNetEncoderBlock(base * 4, base * 4)
        self.resnet6 = ResNetEncoderBlock(base * 4, base * 8, down=False)
        self.resnet7 = ResNetEncoderBlock(base * 8, base * 4, down=False, up=True)

    def forward(self, x):
        """Return a tuple with the output of each of the seven stages, in order."""
        stages = (
            self.resnet1, self.resnet2, self.resnet3, self.resnet4,
            self.resnet5, self.resnet6, self.resnet7,
        )
        outputs = []
        for stage in stages:
            x = stage(x)
            outputs.append(x)
        return tuple(outputs)

In [None]:
# Target size every loaded image is resized to by the shared transform.
size = (256, 256)

# Shared preprocessing pipeline: resize to `size`, then convert the PIL image
# to a float tensor [C, H, W] with values in [0, 1].
# (A Normalize([0.5]*3, [0.5]*3) step is intentionally left out; it would only
# be needed with a Tanh output.)
transform = transforms.Compose(
    [
        transforms.Resize(size),
        transforms.ToTensor(),
    ]
)

def load_image(image_path, size=(576, 576), transform=transform):
    """Load one image file as a batched tensor.

    The file is opened, forced to three RGB channels, passed through
    ``transform`` and given a leading batch axis of length 1.

    Args:
        image_path: path of the image file.
        size: NOTE(review): currently unused. The output resolution is
            decided solely by ``transform``, so the (576, 576) default has
            no effect.
        transform: callable applied to the RGB PIL image. Defaults to the
            module-level ``transform`` as it exists when this function is
            defined.

    Returns:
        ``transform(image).unsqueeze(0)``; with the default transform this
        is a [1, 3, H, W] float tensor.
    """
    rgb_image = Image.open(image_path).convert('RGB')  # guarantee 3 channels
    return transform(rgb_image).unsqueeze(0)

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# os.listdir returns entries in arbitrary order. Sorting makes the sample
# order (and therefore training order and the index -> file-name mapping used
# when saving results) reproducible across runs.
list_names = sorted(os.listdir("/content/drive/MyDrive/AIClothes/Inputs_VITON/cloth"))

# One entry per sample, all index-aligned with `list_names`.
images = []
cloths = []
cloth_masks = []        # single-channel
denseposes = []
parse_agnostics = []
warped_masks = []       # single-channel
warped_cloths = []
agnostics = []

inputs_1 = []           # cloth (3 ch) + cloth mask (1 ch)
inputs_2 = []           # densepose (3 ch) + parse_agnostic (3 ch)


def _load_input(folder, file_name):
    """Load `file_name` from the given dataset sub-folder onto `device`."""
    return load_image("/content/drive/MyDrive/AIClothes/Inputs_VITON/" + folder + "/" + file_name).to(device)


for i in list_names:

    # Every modality is stored under the same file name in its own folder.
    cloth = _load_input("cloth", i)
    cloth_mask = _load_input("cloth_mask", i)
    densepose = _load_input("densepose", i)
    parse_agnostic = _load_input("parse_agnostic", i)
    warped_mask = _load_input("warped_mask", i)
    image = _load_input("images", i)
    warped_cloth = _load_input("warped_cloth", i)
    agnostic = _load_input("agnostic", i)

    # load_image always yields 3 channels; masks only need the first one.
    cloth_mask_1ch = cloth_mask[:, 0:1, ...]
    warped_mask_1ch = warped_mask[:, 0:1, ...]

    cloths.append(cloth)
    cloth_masks.append(cloth_mask_1ch)
    denseposes.append(densepose)
    parse_agnostics.append(parse_agnostic)
    warped_masks.append(warped_mask_1ch)
    images.append(image)
    warped_cloths.append(warped_cloth)
    agnostics.append(agnostic)

    input_1 = torch.cat([cloth, cloth_mask_1ch], dim=1)
    input_2 = torch.cat([densepose, parse_agnostic], dim=1)

    inputs_1.append(input_1)
    inputs_2.append(input_2)

In [None]:
class WarpingProcess(nn.Module):
    """Coarse-to-fine flow estimator that warps a clothing image and its mask.

    A clothing encoder and a pose/segmentation encoder produce feature
    pyramids. Starting from the deepest level, a 2-channel flow is predicted
    and then refined over four stages, each at twice the resolution of the
    previous one. The final flow is upsampled once more and used to warp
    ``input_1`` itself.

    Args:
        oc: base channel width shared by both encoders.
    """

    def __init__(self, oc):
        super(WarpingProcess, self).__init__()

        # Sub-module creation order is deliberately unchanged so that seeded
        # initialisation and existing checkpoints (state-dict keys) still match.

        # Initial flow from the concatenated deepest clothing / pose features.
        self.conv_after_concat_1 = nn.Conv2d(oc * 8, 2, kernel_size=3, padding=1)

        # 1x1 projections of the clothing skip features (deep -> shallow) to oc*4.
        self.convs_clothing = nn.ModuleList([
            nn.Conv2d(oc * width, oc * 4, kernel_size=1) for width in (4, 4, 2, 1)
        ])

        # 3x3 convolutions on the running pose feature before flow refinement.
        self.convs_segment = nn.ModuleList([
            nn.Conv2d(oc * 4, oc * 4, kernel_size=3, padding=1) for _ in range(4)
        ])

        # Flow-residual heads, one per refinement stage.
        self.conv_after_concat_2 = nn.ModuleList([
            nn.Conv2d(oc * 8, 2, kernel_size=3, padding=1) for _ in range(4)
        ])

        # Upsampling blocks that fuse [pose state, warped clothing, pose skip].
        self.convs_segment_after_concat = nn.ModuleList([
            ResNetEncoderBlock(oc * width, oc * 4, down=False, up=True) for width in (12, 12, 10, 9)
        ])

        self.clothingencoder = ClothingEncoder(4, oc)
        self.posencoder = SegmentEncoder(6, oc)

    @staticmethod
    def _warp(source, flow, device):
        """Warp ``source`` with a pixel-unit flow of shape (N, 2, H, W).

        Returns ``(warped, flow_norm)`` where ``flow_norm`` is the flow in
        normalised grid units, laid out as (N, H, W, 2) with x in channel 0.
        """
        height, width = flow.shape[2], flow.shape[3]
        flow_norm = torch.cat(
            [
                flow[:, 0:1, :, :] / ((width - 1.0) / 2.0),
                flow[:, 1:2, :, :] / ((height - 1.0) / 2.0),
            ],
            1,
        ).permute(0, 2, 3, 1)

        # make_grid takes (N, width, height, device). Passing (height, width)
        # only works for square inputs. The batch size of 1 broadcasts.
        grid = make_grid(1, width, height, device)

        # NOTE(review): grid_sample is called with its default align_corners
        # while make_grid spans exactly [-1, 1] - confirm this is intended.
        warped = F.grid_sample(source, grid + flow_norm, padding_mode='border')
        return warped, flow_norm

    def forward(self, input_1, input_2, device):
        """Estimate the flow pyramid and warp ``input_1`` with the finest flow.

        Args:
            input_1: clothing input; the last channel is treated as the mask
                (the clothing encoder is built for 4 input channels).
            input_2: pose / segmentation input (the pose encoder is built for
                6 input channels).
            device: device on which the sampling grids are created.

        Returns:
            ``(warped_c, warped_cm, flow_list)``: all but the last channel of
            the warped ``input_1``, its last channel, and the five normalised
            flows (coarsest first) used for warping.

        Presumably H and W must be multiples of 32 so that the upsampled
        features line up with the encoder skips - TODO confirm.
        """
        clothing_features = list(self.clothingencoder(input_1))  # 5 levels
        pose_features = list(self.posencoder(input_2))           # 7 levels

        flow_list = []

        # Coarsest flow, predicted from the deepest features of both encoders.
        flow = self.conv_after_concat_1(
            torch.cat([clothing_features[4], pose_features[4]], dim=1)
        )

        clothing_state = clothing_features[4]
        pose_state = pose_features[6]

        num_stages = 4
        for i in range(num_stages):
            skip = 3 - i  # encoder level matching this stage's resolution

            # Top-down clothing feature: projected skip + upsampled previous state.
            up = F.interpolate(clothing_state, scale_factor=2, mode='bilinear')
            clothing_state = self.convs_clothing[i](clothing_features[skip]) + up

            # Bring the running flow to this resolution and warp the clothing feature.
            up_flow = F.interpolate(flow, scale_factor=2, mode='bilinear')
            warped_feature, flow_norm = self._warp(clothing_state, up_flow, device)
            flow_list.append(flow_norm)

            # Refine the flow from the pose state and the warped clothing feature.
            refine_input = torch.cat([self.convs_segment[i](pose_state), warped_feature], dim=1)
            flow = up_flow + self.conv_after_concat_2[i](refine_input)

            # The pose state produced by the last stage was never consumed, so
            # that (full-resolution, expensive) block call is skipped. Outputs
            # and gradients are unaffected.
            if i < num_stages - 1:
                pose_state = self.convs_segment_after_concat[i](
                    torch.cat([pose_state, warped_feature, pose_features[skip]], dim=1)
                )

        # Final flow at input resolution warps the raw clothing image + mask.
        last_up = F.interpolate(flow, scale_factor=2, mode='bilinear')
        warped_input, flow_norm = self._warp(input_1, last_up, device)
        flow_list.append(flow_norm)

        warped_c = warped_input[:, :-1, :, :]
        warped_cm = warped_input[:, -1:, :, :]

        return warped_c, warped_cm, flow_list


In [None]:
# --- optimisation schedule ---
epochs = 200
lr = 5e-5

# --- weights of the individual loss terms in the generator objective ---
l1_lambda = 1      # L1 loss on the warped mask
vgg_lambda = 0.1   # VGG feature loss on the masked warped cloth
tvlambda = 1       # total-variation smoothness of the flows

# --- loss modules ---
criterionL1 = nn.L1Loss().to(device)
criterionVGG = VGGLoss().to(device)

In [None]:
# Warping network with a base width of 96 channels, plus its Adam optimiser.
generator = WarpingProcess(oc=96).to(device)
gen_opt = torch.optim.Adam(
    params=generator.parameters(),
    lr=lr,
    betas=(0.5, 0.999),
)

In [None]:
def weights_init(m):
    """Initialise one module in place; meant to be used with ``nn.Module.apply``.

    * ``Conv2d`` / ``ConvTranspose2d``: Xavier-normal weights (biases are
      left at their default initialisation).
    * ``BatchNorm2d``: weights drawn from N(1, 0.02), biases set to 0.

    Any other module type is left untouched.

    Args:
        m: the module to initialise.
    """
    if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
        torch.nn.init.xavier_normal_(m.weight)
    elif isinstance(m, nn.BatchNorm2d):
        # BatchNorm weights are 1-D; xavier_normal_ cannot compute fan-in /
        # fan-out for them and raises ValueError. Use the conventional
        # DCGAN-style scale initialisation instead. Both tensors are None
        # when the layer was built with affine=False.
        if m.weight is not None:
            torch.nn.init.normal_(m.weight, mean=1.0, std=0.02)
        if m.bias is not None:
            torch.nn.init.constant_(m.bias, 0)

# Run weights_init over the generator and all of its sub-modules. The
# parameters are modified in place, so the optimiser created earlier keeps
# pointing at the same tensors.
generator = generator.apply(weights_init)

In [None]:
count = 0

# Total generator loss of every optimisation step, across all epochs.
losses = []

for i in range(epochs):

    # One optimisation step per sample (each list entry is a batch of size 1).
    for input_1, input_2, warped_cloth, warped_mask in zip(inputs_1, inputs_2, warped_cloths, warped_masks):

        warped_c, warped_cm, flow_list = generator(input_1, input_2, device)

        # Losses on the final (full-resolution) warp.
        loss_l1_cloth = criterionL1(warped_cm, warped_mask)
        loss_vgg = criterionVGG(warped_cm * warped_c, warped_cloth)
        loss_tv = flow_loss(flow_list)

        # Deep supervision: warp the input with every intermediate flow too.
        for j in range(len(flow_list) - 1):
            # The flows returned by the generator are already expressed in
            # normalised [-1, 1] grid units, which are resolution independent.
            # After upsampling they can be added to the identity grid as-is;
            # dividing by (size - 1) / 2 again would shrink the supervised
            # flow by two orders of magnitude.
            flow = F.interpolate(
                flow_list[j].permute(0, 3, 1, 2), size=input_1.shape[2:], mode='bilinear'
            ).permute(0, 2, 3, 1)

            # make_grid takes (N, width, height, device).
            grid = make_grid(flow.shape[0], input_1.shape[3], input_1.shape[2], device)

            warped = F.grid_sample(input_1, grid + flow, padding_mode='border')
            warped_c_flow = warped[:, :-1, :, :]
            warped_cm_flow = warped[:, -1:, :, :]

            # Coarser flows count less: with the generator's five flows the
            # four intermediate ones are weighted 1/16, 1/8, 1/4 and 1/2.
            loss_l1_cloth += criterionL1(warped_cm_flow, warped_mask) / (2 ** (4 - j))
            loss_vgg += criterionVGG(warped_cm_flow * warped_c_flow, warped_cloth) / (2 ** (4 - j))

        loss_G = ((l1_lambda * loss_l1_cloth) + (vgg_lambda * loss_vgg) + (tvlambda * loss_tv))

        losses.append(loss_G.item())

        gen_opt.zero_grad()
        loss_G.backward()
        gen_opt.step()

        # NOTE(review): this check sits inside the per-sample loop, so on
        # every 10th epoch a figure is drawn for each sample - confirm that
        # is intended. The printed mean covers all steps since training began.
        if i % 10 == 0:

            print(f'Epoch: {i}, Mean Loss: {np.mean(losses)}')

            # Left to right: predicted cloth, target cloth, predicted mask,
            # target mask, first 3 channels of input_1, both halves of input_2.
            panels = [
                warped_c[0],
                warped_cloth[0],
                warped_cm[0],
                warped_mask[0],
                input_1[0, 0:3, ...],
                input_2[0, 0:3, ...],
                input_2[0, 3:6, ...],
            ]

            fig, axs = plt.subplots(1, 7, figsize=(10, 5))
            for ax, panel in zip(axs, panels):
                ax.imshow(panel.permute(1, 2, 0).detach().cpu().numpy())
                ax.axis('off')

            plt.tight_layout()
            plt.show()

In [None]:
# Snapshot of the training state. `i` is the index of the last epoch the
# training loop ran, and `losses` is the full per-step loss history.
checkpoint = dict(
    epoch=i,
    model_state_dict=generator.state_dict(),
    optimizer_state_dict=gen_opt.state_dict(),
    loss=losses,
)

# The file name encodes the main hyper-parameters and the epoch index.
checkpoint_path = f'/content/drive/MyDrive/AIClothes/Models/warping_lr{lr}_vgg{vgg_lambda}_tvlmabda_{tvlambda}epoch_{i}.pth'

torch.save(checkpoint, checkpoint_path)

In [None]:
# CHECKPOINT_PATH ='/content/drive/MyDrive/AIClothes/Models/Warping/warping_lr1e-05_vgg0.001_epoch_2009.pth'

# checkpoint = torch.load(CHECKPOINT_PATH)

# generator.load_state_dict(checkpoint['model_state_dict'])
# gen_opt.load_state_dict(checkpoint['optimizer_state_dict'])

# losses = checkpoint['loss']

In [None]:
def warped_cloth_into_agnostic(agnostic_image, warped_cloth_image, warped_cloth_mask):
    """Paste the warped cloth onto the agnostic image wherever the mask is set.

    A pixel belongs to the cloth when its mask value is greater than 0.5.
    None of the inputs is modified.

    Args:
        agnostic_image: image tensor the cloth is pasted onto.
        warped_cloth_image: warped cloth, same shape as ``agnostic_image``.
        warped_cloth_mask: soft mask; it is repeated 3 times along
            dimension 1 to cover the colour channels.

    Returns:
        A new tensor equal to ``agnostic_image`` outside the mask and to
        ``warped_cloth_image`` inside it.
    """
    # Build a boolean mask instead of thresholding the caller's tensor in
    # place: the mask is typically a view of the network output.
    cloth_region = (warped_cloth_mask > 0.5).repeat(1, 3, 1, 1)

    agnostic_image_warped_cloth = agnostic_image.clone()
    agnostic_image_warped_cloth[cloth_region] = warped_cloth_image[cloth_region]

    return agnostic_image_warped_cloth

In [None]:
# Inference only: disable autograd so no graph is kept alive for each sample.
with torch.no_grad():
    for img in range(len(inputs_1)):

        agnostic_image = agnostics[img]

        # Warp the cloth for this sample; the flow pyramid is not needed here.
        warped_c, warped_cm, flow_list = generator(inputs_1[img], inputs_2[img], device)

        # Composite the warped cloth onto the cloth-agnostic person image.
        new_im = warped_cloth_into_agnostic(agnostic_image, warped_c, warped_cm)

        plt.imshow(new_im[0].permute(1, 2, 0).detach().cpu().numpy())
        plt.show()

        # Saved under the same file name as the source cloth image.
        save_tensor_as_image(new_im[0], "/content/drive/MyDrive/AIClothes/Inputs_VITON/inputs_difussion_model/" + list_names[img])