### **Imports**

In [1]:
import os 
from typing import Any, Callable, Dict, List, Optional, Tuple, Union, Literal
from ip_adapter.ip_adapter import Resampler

import argparse
import logging
import os
import torch.utils.data as data
import torchvision
import json
import accelerate
import numpy as np
import torch
from PIL import Image, ImageDraw
import torch.nn.functional as F
import transformers
from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import ProjectConfiguration, set_seed
from packaging import version
from torchvision import transforms
import diffusers
from diffusers import AutoencoderKL, DDPMScheduler, StableDiffusionPipeline, StableDiffusionXLControlNetInpaintPipeline
from transformers import AutoTokenizer, PretrainedConfig,CLIPImageProcessor, CLIPVisionModelWithProjection,CLIPTextModelWithProjection, CLIPTextModel, CLIPTokenizer
import cv2
from diffusers.utils.import_utils import is_xformers_available
from numpy.linalg import lstsq
import yaml
from src.unet_hacked_tryon import UNet2DConditionModel
from src.unet_hacked_garmnet import UNet2DConditionModel as UNet2DConditionModel_ref
from src.tryon_pipeline import StableDiffusionXLInpaintPipeline as TryonPipeline

# Module-level logger created through accelerate's `get_logger`, at INFO level.
logger = get_logger(__name__, log_level="INFO")
# NOTE(review): hard-coded, machine-specific absolute path. The relative paths used later in this
# notebook ('viton_train_config.yaml', 'archive/...') are resolved against this directory, so it
# must be adjusted before running on another machine.
os.chdir('/home/bala/Desktop/sri_krishna/IDM-VTON')

  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  torch.utils._pytree._register_pytree_node(
  torch.utils._pytree._register_pytree_node(


In [2]:
# Mapping from label name to integer id (18 entries, ids 0-17).
# NOTE(review): presumably the class ids of a human-parsing / segmentation map; it is not
# referenced by any code visible in this notebook — confirm whether it is still needed.
label_map={
    "background": 0,
    "hat": 1,
    "hair": 2,
    "sunglasses": 3,
    "upper_clothes": 4,
    "skirt": 5,
    "pants": 6,
    "dress": 7,
    "belt": 8,
    "left_shoe": 9,
    "right_shoe": 10,
    "head": 11,
    "left_leg": 12,
    "right_leg": 13,
    "left_arm": 14,
    "right_arm": 15,
    "bag": 16,
    "scarf": 17,
}

# def parse_args():
#     parser = argparse.ArgumentParser(description="Simple example of a training script.")
#     parser.add_argument("--pretrained_model_name_or_path",type=str,default= "yisol/IDM-VTON",required=False,)
#     parser.add_argument("--width",type=int,default=768,)
#     parser.add_argument("--height",type=int,default=1024,)
#     parser.add_argument("--num_inference_steps",type=int,default=30,)
#     parser.add_argument("--output_dir",type=str,default="result",)
#     parser.add_argument("--category",type=str,default="upper_body",choices=["upper_body", "lower_body", "dresses"])
#     parser.add_argument("--unpaired",action="store_true",)
#     parser.add_argument("--data_dir",type=str,default="archive")
#     parser.add_argument("--seed", type=int, default=42,)
#     parser.add_argument("--train_batch_size", type=int, default=8,)
#     parser.add_argument("--train_epochs", type=int, default=10,)
#     parser.add_argument("--test_batch_size", type=int, default=2,)
#     parser.add_argument("--guidance_scale",type=float,default=2.0,)
#     parser.add_argument("--mixed_precision",type=str,default=None,choices=["no", "fp16", "bf16"],)
#     parser.add_argument("--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers.")
#     args = parser.parse_args()


#     return args

# args = parse_args()

# Read the run configuration from YAML (path is relative to the current working directory).
with open('viton_train_config.yaml', 'r') as config_file:
    config = yaml.safe_load(config_file)

# `args` is an alias of `config`: both names refer to the same dict and are used
# interchangeably by the cells below.
args = config

# Seed the random number generators only when the configuration provides a seed.
if args["seed"] is not None:
    set_seed(args["seed"])

def pil_to_tensor(images):
    """Convert a single H x W x C image into a float32 C x H x W tensor.

    Args:
        images: Anything `np.array` can turn into a 3-D array laid out as
            (height, width, channels), e.g. a PIL image or a uint8 ndarray.

    Returns:
        A `torch.Tensor` of dtype float32 whose values are the input divided by 255,
        with the channel axis moved to the front.
    """
    pixels = np.array(images).astype(np.float32)
    pixels = pixels / 255.0
    # (H, W, C) -> (C, H, W)
    channels_first = pixels.transpose(2, 0, 1)
    return torch.from_numpy(channels_first)


In [3]:
class VitonHDDataset(data.Dataset):
    """Try-on dataset that reads person images, garments, masks and DensePose maps from disk.

    Files read, all relative to ``dataroot_path``:

    * ``{phase}_pairs.txt`` -- one ``<image name> <cloth name>`` pair per line.
    * ``{phase}/image``, ``{phase}/cloth``, ``{phase}/agnostic-mask`` and
      ``{phase}/image-densepose`` -- the per-sample image files.
    * ``{phase}/vitonhd_{phase}_tagged.json`` -- optional garment tags used to build captions.

    Args:
        dataroot_path: Root directory of the dataset.
        transformations: Callable applied to the person image, the garment and the pose image.
        phase: ``"train"`` or ``"test"``; selects the sub-directory and the pairs file.
        order: Only relevant for the test phase. ``"paired"`` uses the image name as the garment
            name; ``"unpaired"`` uses the garment listed in the pairs file. The train phase
            always uses the image name as the garment name.
        size: ``(height, width)`` the person image and the mask are resized to.
    """

    def __init__(
        self,
        dataroot_path: str,
        transformations ,
        phase: Literal["train", "test"],
        order: Literal["paired", "unpaired"] = "paired",
        size: Tuple[int, int] = (512, 384),

    ):
        super(VitonHDDataset, self).__init__()
        self.dataroot = dataroot_path
        self.phase = phase
        self.height = size[0]
        self.width = size[1]
        self.size = size
        self.transform = transformations
        self.toTensor = transforms.ToTensor()
        self.order = order

        # Garment captions are optional: without a usable tag file every sample falls back to
        # the default caption chosen in __getitem__.
        self.annotation_pair = {}
        annotation_path = os.path.join(dataroot_path, phase, "vitonhd_" + phase + "_tagged.json")
        try:
            self.annotation_pair = self._load_annotations(annotation_path)
        except FileNotFoundError:
            print(f"No annotation file found for {self.phase} phase in {self.dataroot}")
        except (OSError, ValueError, KeyError, TypeError, AttributeError) as err:
            # Unreadable or malformed tag file: stay best-effort, but report the real cause
            # instead of claiming that the file is missing.
            print(f"Could not read annotations for {self.phase} phase in {self.dataroot}: {err!r}")

        pairs_path = os.path.join(dataroot_path, f"{phase}_pairs.txt")
        self.im_names, self.c_names = self._read_pairs(pairs_path, phase, order)
        self.dataroot_names = [dataroot_path] * len(self.im_names)
        self.clip_processor = CLIPImageProcessor()

    @staticmethod
    def _load_annotations(annotation_path):
        """Build a ``file_name -> tag string`` mapping from the tagged-JSON file."""
        with open(annotation_path, "r") as annotation_file:
            tagged = json.load(annotation_file)

        # Tag names that contribute to the caption, in the order they are concatenated.
        annotation_list = ["sleeveLength", "neckLine", "item"]

        annotations = {}
        for entries in tagged.values():
            for elem in entries:
                annotation_str = ""
                for template in annotation_list:
                    for tag in elem["tag_info"]:
                        if tag["tag_name"] == template and tag["tag_category"] is not None:
                            annotation_str += tag["tag_category"] + " "
                annotations[elem["file_name"]] = annotation_str
        return annotations

    @staticmethod
    def _read_pairs(pairs_path, phase, order):
        """Parse the pairs file into parallel lists of image names and garment names."""
        im_names = []
        c_names = []
        # The garment column is only honoured for unpaired evaluation outside the train phase.
        use_listed_cloth = phase != "train" and order != "paired"
        with open(pairs_path, "r") as pairs_file:
            for line in pairs_file:
                fields = line.strip().split()
                if not fields:
                    # Blank lines (e.g. a trailing newline at the end of the file) carry no pair.
                    continue
                im_name, listed_cloth = fields
                im_names.append(im_name)
                c_names.append(listed_cloth if use_listed_cloth else im_name)
        return im_names, c_names

    def __getitem__(self, index):
        """Load one sample and return it as a dict of tensors, names and captions."""
        c_name = self.c_names[index]
        im_name = self.im_names[index]
        # Default caption text when the garment has no tag entry.
        cloth_annotation = self.annotation_pair.get(c_name, "shirts")

        cloth = Image.open(os.path.join(self.dataroot, self.phase, "cloth", c_name))

        im_pil_big = Image.open(
            os.path.join(self.dataroot, self.phase, "image", im_name)
        ).resize((self.width, self.height))
        image = self.transform(im_pil_big)

        mask = Image.open(
            os.path.join(self.dataroot, self.phase, "agnostic-mask", im_name)
        ).resize((self.width, self.height))
        mask = self.toTensor(mask)
        mask = mask[:1]        # keep a single channel
        mask = 1 - mask        # inverted mask: multiplying by it blanks the masked region
        im_mask = image * mask

        # NOTE(review): unlike the person image and the mask, the pose image and the garment are
        # not resized here — presumably the files on disk already have the expected size.
        pose_img = Image.open(
            os.path.join(self.dataroot, self.phase, "image-densepose", im_name)
        )
        pose_img = self.transform(pose_img)  # [-1,1]

        result = {}
        result["c_name"] = c_name
        result["im_name"] = im_name
        result["image"] = image
        result["cloth_pure"] = self.transform(cloth).half()

        result["cloth"] = self.clip_processor(images=cloth, return_tensors="pt").pixel_values.half()

        # Inverting again restores the mask as loaded from disk.
        result["inpaint_mask"] = 1 - mask
        result["im_mask"] = im_mask
        result["caption_cloth"] = "a photo of " + cloth_annotation
        result["caption"] = "model is wearing a " + cloth_annotation
        result["pose_img"] = pose_img

        return result

    def __len__(self):
        """Number of (image, garment) pairs."""
        return len(self.im_names)




## Reducing Train Images to 4000 to train faster

- In the code below I tried to reduce the image size so that a larger batch size would fit in memory.
- However, the model only accepts images of size 768×1024 (width × height).

- Instead, I reduced the training set from ~11k to 4k images.

In [4]:
import torch
from torchvision import transforms
from PIL import Image
import numpy as np

# Define the transformations

def reduce_pixels(img_pth):
    """Load an image, force it to RGB and shrink it to 400 x 280 (height x width).

    Args:
        img_pth: Path of the image file to read.

    Returns:
        The resized image as a PIL image.
    """
    # Normalisation is intentionally left out: the result is converted straight back to PIL.
    shrink = transforms.Compose([transforms.ToTensor(), transforms.Resize((400, 280))])
    to_pil = transforms.ToPILImage()

    rgb_image = Image.open(img_pth).convert('RGB')  # three-channel RGB, not grayscale
    return to_pil(shrink(rgb_image))


# One-time dataset preparation: runs only while 'archive/train/image' does not exist yet.
if not os.path.exists('archive/train/image'):
    os.makedirs('archive/train/image')
    # NOTE(review): VitonHDDataset reads garments from '<phase>/cloth', while the resized garments
    # are written to 'cloths' — confirm which directory is intended.
    os.makedirs('archive/train/cloths', exist_ok=True)

    for f_name in os.listdir('archive/train/image_original'):
        img = reduce_pixels(os.path.join('archive/train/image_original', f_name))
        img.save(f'archive/train/image/{f_name}')

    for f_name in os.listdir('archive/train/cloth_original'):
        img = reduce_pixels(os.path.join('archive/train/cloth_original', f_name))
        img.save(f'archive/train/cloths/{f_name}')

    if not os.path.exists('archive/train_pairs.txt'):
        NUM = 4000  # number of training pairs to keep

        with open('archive/train_pairs_original.txt', 'r') as f:
            # Ignore blank lines so that every sampled entry is a real pair.
            train_pairs = [line for line in f if line.strip()]

        # Sample before opening the output file: if sampling fails, no empty
        # 'train_pairs.txt' is left behind to block a later re-run. Clamp the sample size
        # because sampling without replacement raises when fewer than NUM pairs exist.
        sample_size = min(NUM, len(train_pairs))
        indices = np.random.choice(len(train_pairs), sample_size, replace=False)

        with open('archive/train_pairs.txt', 'w') as f:
            for i in indices:
                # The last line of the source file may lack a newline; normalise so that two
                # pairs can never end up on the same output line.
                f.write(train_pairs[i].rstrip("\n") + "\n")


#### Creating DataLoaders

In [5]:
# Shared preprocessing: convert to a tensor, then shift/scale each channel with mean 0.5, std 0.5.
transformations = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5]),
    ]
)

# Both datasets share the data root, the pairing mode and the target size from the configuration.
pairing_order = "unpaired" if args["unpaired"] else "paired"
image_size = (args["height"], args["width"])

train_dataset = VitonHDDataset(
    dataroot_path=args["data_dir"],
    transformations=transformations,
    phase="train",
    order=pairing_order,
    size=image_size,
)

train_dataloader = torch.utils.data.DataLoader(
    train_dataset,
    shuffle=True,
    batch_size=args["train_batch_size"],
    num_workers=4,
)

test_dataset = VitonHDDataset(
    dataroot_path=args["data_dir"],
    transformations=transformations,
    phase="test",
    order=pairing_order,
    size=image_size,
)

# The test loader must wrap the test split; it previously wrapped `train_dataset`, so
# validation and testing silently ran on the training data.
test_dataloader = torch.utils.data.DataLoader(
    test_dataset,
    shuffle=False,
    batch_size=args["test_batch_size"],
    num_workers=4,
)

No annotation file found for train phase in archive


#### Get Models

In [6]:
def get_models():
    """Create the accelerator and load every pretrained component in float16.

    All components are loaded from ``args["pretrained_model_name_or_path"]``. After loading,
    the try-on UNet and the image encoder have gradients enabled; the VAE, the garment UNet
    encoder and both text encoders are frozen. Only the garment UNet encoder is moved to the
    accelerator device (and put in eval mode) here.

    Returns:
        dict with the keys ``accelerator``, ``noise_scheduler``, ``vae``, ``unet``,
        ``image_encoder``, ``UNet_Encoder``, ``text_encoder_one``, ``text_encoder_two``,
        ``tokenizer_one`` and ``tokenizer_two``.
    """
    project_config = ProjectConfiguration(project_dir=args["output_dir"])
    accelerator = Accelerator(
        mixed_precision=args["mixed_precision"],
        project_config=project_config,
    )

    # Verbose library logging on the local main process only; errors only elsewhere.
    if not accelerator.is_local_main_process:
        transformers.utils.logging.set_verbosity_error()
        diffusers.utils.logging.set_verbosity_error()
    else:
        transformers.utils.logging.set_verbosity_warning()
        diffusers.utils.logging.set_verbosity_info()

    # Make sure the output directory exists (main process only).
    if accelerator.is_main_process and args["output_dir"] is not None:
        os.makedirs(args["output_dir"], exist_ok=True)

    weight_dtype = torch.float16
    model_id = args["pretrained_model_name_or_path"]

    def load_half(model_cls, subfolder):
        # Every network component is loaded from its sub-folder in half precision.
        return model_cls.from_pretrained(model_id, subfolder=subfolder, torch_dtype=weight_dtype)

    def load_tokenizer(subfolder):
        return AutoTokenizer.from_pretrained(
            model_id, subfolder=subfolder, revision=None, use_fast=False
        )

    noise_scheduler = DDPMScheduler.from_pretrained(model_id, subfolder="scheduler")
    vae = load_half(AutoencoderKL, "vae")                                  # frozen VAE: images <-> latents
    unet = load_half(UNet2DConditionModel, "unet")                         # TryonNet (trainable UNet)
    image_encoder = load_half(CLIPVisionModelWithProjection, "image_encoder")  # IP adapter (trainable)
    UNet_Encoder = load_half(UNet2DConditionModel_ref, "unet_encoder")     # GarmentNet (frozen UNet encoder)
    text_encoder_one = load_half(CLIPTextModel, "text_encoder")
    text_encoder_two = load_half(CLIPTextModelWithProjection, "text_encoder_2")

    # Tokenizers matching the two text encoders.
    tokenizer_one = load_tokenizer("tokenizer")
    tokenizer_two = load_tokenizer("tokenizer_2")

    for trainable in (unet, image_encoder):
        trainable.requires_grad_(True)
    for frozen in (vae, UNet_Encoder, text_encoder_one, text_encoder_two):
        frozen.requires_grad_(False)

    UNet_Encoder.to(accelerator.device, weight_dtype)
    UNet_Encoder.eval()

    return dict(
        accelerator=accelerator,
        noise_scheduler=noise_scheduler,
        vae=vae,
        unet=unet,
        image_encoder=image_encoder,
        UNet_Encoder=UNet_Encoder,
        text_encoder_one=text_encoder_one,
        text_encoder_two=text_encoder_two,
        tokenizer_one=tokenizer_one,
        tokenizer_two=tokenizer_two,
    )

In [7]:
def main():
    """Load all models and assemble them into the try-on pipeline.

    When ``args["enable_xformers_memory_efficient_attention"]`` is set, xformers attention is
    enabled on the try-on UNet.

    Returns:
        tuple ``(pipe, models)``: the ``TryonPipeline`` moved to the accelerator device, and the
        dict returned by ``get_models``.

    Raises:
        ValueError: if xformers attention is requested but xformers is not available.
    """
    models = get_models()

    if args["enable_xformers_memory_efficient_attention"]:
        if not is_xformers_available():
            raise ValueError("xformers is not available. Make sure it is installed correctly")

        import xformers

        xformers_version = version.parse(xformers.__version__)
        if xformers_version == version.parse("0.0.16"):
            # `warning` replaces the deprecated `warn` alias, which is removed in Python 3.13.
            logger.warning(
                "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
            )
        models["unet"].enable_xformers_memory_efficient_attention()

    pipe = TryonPipeline.from_pretrained(
            args["pretrained_model_name_or_path"],
            unet=models["unet"],
            vae=models["vae"],
            feature_extractor= CLIPImageProcessor(),
            text_encoder = models["text_encoder_one"],
            text_encoder_2 = models["text_encoder_two"],
            tokenizer = models["tokenizer_one"],
            tokenizer_2 = models["tokenizer_two"],
            scheduler = models["noise_scheduler"],
            image_encoder=models["image_encoder"],
            torch_dtype=torch.float16,
    ).to(models["accelerator"].device)
    # The garment UNet encoder is attached after construction instead of being passed to
    # `from_pretrained`.
    pipe.unet_encoder = models["UNet_Encoder"]

    # Optional memory-saving switches, currently disabled:
    # pipe.enable_sequential_cpu_offload()
    # pipe.enable_model_cpu_offload()
    # pipe.enable_vae_slicing()

    return  pipe, models

## Loss Functions : 

There are various loss functions available for training diffusion models, such as:
#### L1 loss
#### LPIPS
A lower LPIPS score indicates that the generated image is more similar 
to the real image, and thus the diffusion model is performing better.
LPIPS uses a deep neural network (typically a version of AlexNet or VGG) to extract features 
from image patches, and computes distances in this feature space to measure image similarity. 

#### SSIM
A higher SSIM score indicates that the generated image is more similar to the real image, 
and thus the diffusion model is performing better.

 SSIM considers changes in structural information, luminance, and contrast of the images.  

##### **I chose to go with L1 loss, but any of the other losses can be used in the training loop in its place.**      

In [1]:

# import lpips
# import torch

# # Initialize the LPIPS metric
# loss_fn_vgg = lpips.LPIPS(net='vgg')

# # Define your input and target tensors
# input = torch.randn(1, 3, 64, 64)
# target = torch.randn(1, 3, 64, 64)

# # Compute the LPIPS distance
# distance = loss_fn_vgg(input, target)

# _________________________________________________________________________________

# from pytorch_msssim import ssim
# import torch

# # Define your input and target tensors
# input = torch.randn(1, 1, 256, 256)
# target = torch.randn(1, 1, 256, 256)

# # Compute the SSIM
# ssim_value = ssim(input, target, data_range=1.0)

# print("SSIM: ", ssim_value.item())

#____________________________________________________________________________________


### Setting up Model passes and evaluation

**Training Scenario**
* Since the images are high resolution, I could only use a batch size of 2 (anything larger caused an OutOfMemory error).
* To increase the effective batch size, I set **accumulate_grad_batches = 50**, so a gradient update is applied once every 50 steps.
* As in the paper, only **TryonNet** and the **IP-Adapter** are trained and the rest of the models are kept frozen; I followed the same approach.
* I also use a learning-rate scheduler, configured in `configure_optimizers`.

In [9]:
import torch
import pytorch_lightning as pl

class TrainEval(pl.LightningModule):
    """LightningModule that runs the try-on pipeline and scores it with an L1 image loss.

    The pipeline and its component models are created once in ``__init__`` via ``main()``.
    Every step generates images for the batch with ``get_output`` and compares them with the
    ground-truth person images.

    NOTE(review): ``get_output`` rebuilds tensors from the PIL images returned by the pipeline,
    so the loss does not appear to be connected to the UNet / image-encoder graph;
    ``training_step`` sets ``requires_grad`` on the generated tensor only so that ``backward()``
    can run. Confirm that the optimizer receives gradients before relying on this loop for
    fine-tuning.

    Args:
        config: dict read in this class under the keys ``accumulate_grad_batches``, ``lr`` and
            ``weight_decay``.
    """

    def __init__(self, config):
        super().__init__()
        self.criterion = torch.nn.L1Loss()
        self.config = config
        self.pipe, self.models = main()
        self.log_every_n_steps = self.config['accumulate_grad_batches']

    @staticmethod
    def _as_prompt_list(prompt, count):
        """Return ``prompt`` unchanged if it is a list, otherwise repeat it ``count`` times."""
        return prompt if isinstance(prompt, list) else [prompt] * count

    def _image_loss(self, generated_images, batch):
        """L1 loss between generated images and ground truth, both expressed in [0, 1].

        ``get_output`` returns ``ToTensor`` output in [0, 1], whereas ``batch['image']`` was
        normalised with mean 0.5 / std 0.5 (i.e. [-1, 1]); the target is mapped back to [0, 1]
        with the same rescaling that is applied before the image is handed to the pipeline.
        """
        target = (batch['image'].to(self.models["accelerator"].device) + 1.0) / 2.0
        return self.criterion(generated_images, target)

    def get_output(self, batch):
        """Run the full try-on pipeline on one batch.

        Args:
            batch: dict produced by the dataloader; the keys read here are ``cloth``,
                ``cloth_pure``, ``caption``, ``caption_cloth``, ``pose_img``, ``inpaint_mask``
                and ``image``.

        Returns:
            The generated images stacked into one tensor (``ToTensor`` output, one entry per
            generated image) on the accelerator device.
        """
        device = self.models["accelerator"].device
        num_prompts = batch['cloth'].shape[0]
        negative_prompt = self._as_prompt_list(
            "monochrome, lowres, bad anatomy, worst quality, low quality", num_prompts
        )

        # Concatenate the per-sample garment inputs along the first axis.
        image_embeds = torch.cat([batch['cloth'][i] for i in range(num_prompts)], dim=0)

        # Text conditioning for the person prompt, with classifier-free guidance.
        (
            prompt_embeds,
            negative_prompt_embeds,
            pooled_prompt_embeds,
            negative_pooled_prompt_embeds,
        ) = self.pipe.encode_prompt(
            self._as_prompt_list(batch["caption"], num_prompts),
            num_images_per_prompt=1,
            do_classifier_free_guidance=True,
            negative_prompt=negative_prompt,
        )

        # Text conditioning for the garment prompt; only the prompt embedding is kept.
        prompt_embeds_c, _, _, _ = self.pipe.encode_prompt(
            self._as_prompt_list(batch["caption_cloth"], num_prompts),
            num_images_per_prompt=1,
            do_classifier_free_guidance=False,
            negative_prompt=negative_prompt,
        )

        generator = torch.Generator(self.pipe.device).manual_seed(args["seed"]) if args["seed"] is not None else None
        images = self.pipe(
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            pooled_prompt_embeds=pooled_prompt_embeds,
            negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
            num_inference_steps=args["num_inference_steps"],
            generator=generator,
            strength = 1.0,
            pose_img = batch['pose_img'],
            text_embeds_cloth=prompt_embeds_c,
            cloth = batch["cloth_pure"].to(device),
            mask_image=batch['inpaint_mask'],
            image=(batch['image']+1.0)/2.0,
            height=args["height"],
            width=args["width"],
            guidance_scale=args["guidance_scale"],
            ip_adapter_image = image_embeds,
        )[0]

        # Convert the list of PIL images to tensors and stack them into a single tensor.
        to_tensor = transforms.ToTensor()
        stacked_tensor_images = torch.stack([to_tensor(image) for image in images])
        return stacked_tensor_images.to(device)

    def training_step(self, batch, batch_idx):
        """Generate images for the batch, log and return the L1 training loss."""
        generated_images = self.get_output(batch)
        # Needed for `backward()` to run on this tensor; see the class-level NOTE(review).
        generated_images.requires_grad = True
        loss = self._image_loss(generated_images, batch)
        self.log("train_loss", loss,on_step = True ,on_epoch=True, prog_bar=True, logger=True)

        return loss

    def validation_step(self, batch, batch_idx):
        """Generate images for the batch, log and return the L1 validation loss."""
        generated_images = self.get_output(batch)
        loss = self._image_loss(generated_images, batch)
        self.log("val_loss", loss, on_epoch=True, prog_bar=True, logger=True)
        return loss

    def test_step(self, batch, batch_idx):
        """Log the L1 test loss and save every generated image under ``args["output_dir"]``."""
        generated_images = self.get_output(batch)
        loss = self._image_loss(generated_images, batch)
        self.log("test_loss", loss, on_epoch=True, prog_bar=True, logger=True)

        # `generated_images` already holds channel-first float tensors, so each one is saved
        # directly; routing it through `pil_to_tensor` (which expects an H x W x C image) would
        # fail on CUDA tensors and permute the axes otherwise.
        for i in range(len(generated_images)):
            torchvision.utils.save_image(
                generated_images[i].detach().cpu(),
                os.path.join(args["output_dir"], batch['im_name'][i]),
            )
        return loss

    def configure_optimizers(self):
        """Adam over the try-on UNet and image-encoder parameters, with ReduceLROnPlateau.

        Returns:
            ``([optimizer], [scheduler_config])``; the scheduler steps once per epoch and
            monitors ``train_loss``.
        """
        parameters = []
        for model_key in ("unet", "image_encoder"):
            for param in self.models[model_key].parameters():
                assert param.requires_grad, "All parameters should require gradients."
                parameters.append(param)

        optim =  torch.optim.Adam(params = parameters, lr = self.config['lr'], weight_decay = self.config['weight_decay'])   # https://pytorch.org/docs/stable/optim.html
        lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optim, patience=3, factor=0.7,
                                                                  threshold=0.005, cooldown =2,verbose=True)
        # lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optim,gamma = 0.995 ,last_epoch=-1,   verbose=True)

        return [optim], [{'scheduler': lr_scheduler, 'interval': 'epoch', 'monitor': 'train_loss', 'name': 'lr_scheduler'}]

#### Callbacks

In [10]:
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from pytorch_lightning.callbacks import ModelCheckpoint, RichProgressBar, RichModelSummary
from torchvision import transforms

from  pytorch_lightning.callbacks.progress.rich_progress import RichProgressBarTheme

# Early stopping on the validation loss: a change smaller than 1e-5 does not count as an
# improvement, and training stops after 20 checks without one.
early_stop_callback = EarlyStopping(
    monitor='val_loss', mode='min', min_delta=0.00001, patience=20, verbose=True
)

# Colour theme for the rich progress bar.
theme = RichProgressBarTheme(
    metrics='green',
    time='yellow',
    progress_bar_finished='#8c53e0',
    progress_bar='#c99e38',
)
rich_progress_bar = RichProgressBar(theme=theme)

# rich_model_summary = RichModelSummary(max_depth=5)

# Keep the three best checkpoints as ranked by the validation loss.
checkpoint_callback = ModelCheckpoint(monitor='val_loss', save_top_k=3, verbose=True)

#### Trainer

Here the model and the data loaders come together to train and evaluate on the dataset.

* I configured training for 10 epochs, but a single epoch took more than a day to complete, so I had to terminate it.
* Logging is done to wandb; on the first run, the cell below asks for a wandb API key.
* Checkpoints are also saved.

In [None]:
from pytorch_lightning import Trainer
from pytorch_lightning.loggers import WandbLogger
import yaml 


torch.set_float32_matmul_precision('high')

# Building the module loads every model and assembles the pipeline (see TrainEval.__init__).
model = TrainEval(config)

# Checkpoints are written to '<dir>/ckpts' and named '<model_name>__<ckpt_file_name>'.
NAME = config['model_name']
checkpoint_callback.dirpath = os.path.join(config['dir'], 'ckpts')
checkpoint_callback.filename = NAME+'__' + config['ckpt_file_name']

# The wandb run name encodes the main hyper-parameters.
run_name = f"lr_{config['lr']} *** bs{config['train_batch_size']} *** decay_{config['weight_decay']}"
wandb_logger = WandbLogger(project= NAME, name = run_name)

trainer = Trainer(
    accelerator = 'gpu',
    max_epochs=args["train_epochs"],
    accumulate_grad_batches = config['accumulate_grad_batches'],
    logger=[wandb_logger],
    callbacks=[early_stop_callback, checkpoint_callback, rich_progress_bar],
)

# NOTE(review): the same loader serves as validation loader during fit and as test loader.
trainer.fit(model, train_dataloader, test_dataloader)
trainer.test(model, test_dataloader)

In [None]:
import torch

# Environment report: torch build, the CUDA version it was built against, and GPU availability.
print(torch.__version__)
print(torch.version.cuda)
print(torch.cuda.is_available())

# Device queries raise when no CUDA device is usable, so only run them when one is available.
if torch.cuda.is_available():
    device_index = torch.cuda.current_device()
    print(device_index)
    print(torch.cuda.get_device_name(device_index))


2.3.0+cu121
12.1
True
0
NVIDIA RTX A5000
