In [1]:
import os
import cv2
import json
import torch
import numpy as np
from PIL import Image
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt
from torchvision import models
from torchvision.transforms import transforms
import diffusers
from diffusers import AutoencoderKL, UNet2DConditionModel, DPMSolverMultistepScheduler
from accelerate import Accelerator
import torch.nn.functional as F
from huggingface_hub import login
import google.generativeai as genai
from transformers import AutoModel, AutoImageProcessor, CLIPTextModel, CLIPTokenizer

2025-07-20 11:53:38.267625: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1753012418.291241     292 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1753012418.298309     292 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


In [None]:
os.environ["HF_Token"] = ""
login(token=os.environ["HF_Token"])


os.environ["GEMINI_API_KEY"] = ""
genai.configure(api_key=os.environ["GEMINI_API_KEY"])

gemini = genai.GenerativeModel("gemini-2.0-flash")

In [3]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

Using device: cuda


In [4]:
model_id = "stabilityai/stable-diffusion-2-base"
torch_dtype = torch.float16

In [5]:
torch.manual_seed(1)

<torch._C.Generator at 0x7e6df4d96ff0>

# Dataset Pipeline

In [6]:
tokenizer = CLIPTokenizer.from_pretrained(model_id, subfolder="tokenizer")

In [7]:
class CustomDataset(Dataset):

    def __init__(self, folder_path):
        
        self.image = os.path.join(folder_path, "image")
        self.cloth = os.path.join(folder_path, "cloth")
        self.cloth_mask = os.path.join(folder_path, "cloth-mask")
        self.segmentation = os.path.join(folder_path, "image-parse-v3")
        self.agnostic = os.path.join(folder_path, "agnostic-v3.2")

        self.image_names = sorted([f for f in os.listdir(self.image)])
        self.image_names.sort(key=lambda f: int(f.split("_")[0]))

        self.cloth_names = sorted([f for f in os.listdir(self.cloth)])
        self.cloth_names.sort(key=lambda f: int(f.split("_")[0]))

        self.cloth_mask_names = sorted([f for f in os.listdir(self.cloth_mask)])
        self.cloth_mask_names.sort(key=lambda f: int(f.split("_")[0]))

        self.segmentation_names = sorted([f for f in os.listdir(self.segmentation)])
        self.segmentation_names.sort(key=lambda f: int(f.split("_")[0]))

        self.agnostic_names = sorted([f for f in os.listdir(self.agnostic)])
        self.agnostic_names.sort(key=lambda f: int(f.split("_")[0]))
    
        self.transform = transforms.Compose([
            transforms.Resize((512, 512), interpolation=transforms.InterpolationMode.BICUBIC),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
        ])

        self.mask_transform = transforms.Compose([
            transforms.Resize((64, 64), interpolation=transforms.InterpolationMode.NEAREST_EXACT),
            transforms.ToTensor()
        ])

        with open("/kaggle/input/prompts/generated_prompts.json") as file:
            self.generated_prompts = json.load(file)

    
    def __len__(self):
        return 6506

    
    def __getitem__(self, index):

        image = Image.open(os.path.join(self.image, self.image_names[index]))
        cloth = Image.open(os.path.join(self.cloth, self.cloth_names[index]))
        cloth_mask = Image.open(os.path.join(self.cloth_mask, self.cloth_mask_names[index]))
        segmentation = Image.open(os.path.join(self.segmentation, self.segmentation_names[index]))
        agnostic = Image.open(os.path.join(self.agnostic, self.agnostic_names[index]))

        prompt = self.generated_prompts[self.cloth_names[index]]
        
        text_input_ids = tokenizer(prompt, padding="max_length", truncation=True, return_tensors="pt").input_ids.squeeze(0)
        
        image = self.transform(image).to(torch_dtype).to(device)
        cloth = self.transform(cloth).to(torch_dtype).to(device)
        cloth_mask = self.mask_transform(cloth_mask).to(torch_dtype).to(device)
        segmentation = self.mask_transform(segmentation).to(torch_dtype).to(device)
        agnostic = self.transform(agnostic).to(torch_dtype).to(device)

        return text_input_ids.to(device), cloth, cloth_mask, segmentation, agnostic, image

In [8]:
train_dataset = CustomDataset("/kaggle/input/vton-hd/train")
test_dataset = CustomDataset("/kaggle/input/vton-hd/test")

In [9]:
batch_size = 8
train_loader = DataLoader(train_dataset, batch_size=batch_size)
test_loader = DataLoader(test_dataset, batch_size=batch_size)

# Model

In [10]:
vae = AutoencoderKL.from_pretrained(model_id, subfolder="vae", torch_dtype=torch_dtype).to(device)
vae.requires_grad_(False)

unet = UNet2DConditionModel.from_pretrained(model_id, subfolder="unet", torch_dtype=torch_dtype).to(device)

text_encoder = CLIPTextModel.from_pretrained(model_id, subfolder="text_encoder", torch_dtype=torch_dtype).to(device)
text_encoder.requires_grad_(False)

scheduler = DPMSolverMultistepScheduler.from_pretrained(model_id, subfolder="scheduler")

In [11]:
class CustomUNet(nn.Module):

    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(12, 4, kernel_size=(1, 1), stride=1)
        self.unet = unet

    def forward(self, latent_model_input, timesteps, text_embeddings):

        latent_model_input = self.conv(latent_model_input)
        model_pred = self.unet(latent_model_input, timesteps, encoder_hidden_states=text_embeddings)

        return model_pred

In [12]:
custom_diff_model = CustomUNet()
custom_diff_model = custom_diff_model.to(torch_dtype).to(device)

# Training Loop

In [13]:
lr = 1e-5
epochs = 2
GRADIENT_ACCUMULATION_STEPS = 1


loss_fxn = nn.MSELoss(reduction="mean")
optimizer = torch.optim.AdamW(custom_diff_model.parameters(), lr=lr)

accelerator = Accelerator(mixed_precision="no", gradient_accumulation_steps=GRADIENT_ACCUMULATION_STEPS)
custom_diff_model, optimizer, vae, text_encoder, train_loader = accelerator.prepare(custom_diff_model, optimizer, vae, text_encoder, train_loader)

In [None]:
custom_diff_model.train()

losses = []
try:
    for epoch in range(epochs):
        total_loss = 0
    
        for step, (text_input_ids, cloth, cloth_mask, segmentation, agnostic, target_img) in enumerate(train_loader):
            with accelerator.accumulate(custom_diff_model):
                
                text_embeddings = text_encoder(text_input_ids).last_hidden_state
                # print(f"text_embeddings: {text_embeddings}")
        
                target_latents = vae.encode(target_img).latent_dist.sample()
                target_latents = target_latents * vae.config.scaling_factor
                # print(f"target_latents: {target_latents}")
        
                cloth_latents = vae.encode(cloth).latent_dist.sample()
                cloth_latents = cloth_latents * vae.config.scaling_factor
                # print(f"cloth_latents: {cloth_latents}")
        
                agnostic_latents = vae.encode(agnostic).latent_dist.sample()
                agnostic_latents = agnostic_latents * vae.config.scaling_factor
                # print(f"agnostic: {agnostic_latents}")
        
                noise = torch.randn_like(target_latents, device=device)
                timesteps = torch.randint(0, scheduler.config.num_train_timesteps, (target_latents.shape[0],), device=device)
                # print(f"noise: {noise}, timesteps: {timesteps}")
                noisy_latents = scheduler.add_noise(target_latents, noise, timesteps)
                # print(f"noisy_latents: {noisy_latents.shape} : {noisy_latents}")
                
                latent_model_input = torch.cat([noisy_latents, cloth_latents, agnostic_latents], dim=1)
                # print(f"latent_model_input: {latent_model_input.shape} : {latent_model_input}")
    
                model_pred = custom_diff_model(latent_model_input, timesteps, text_embeddings).sample
                # print(f"model_pred: {model_pred.shape} : {model_pred}")
                
                loss = loss_fxn(model_pred.float(), noise.float())
                # loss = loss.to(torch.float32)
                
                accelerator.backward(loss)

                accelerator.clip_grad_norm_(custom_diff_model.parameters(), 1.0)
                
                optimizer.zero_grad()
                optimizer.step()
    
                total_loss += loss.item()
        
                if accelerator.is_main_process:
                    if step % 50 == 0:
                        accelerator.print(f"Epoch {epoch+1}/{epochs}, Step {step}, Loss: {(total_loss / (step+0.00001)):.4f}")
                        losses.append(total_loss / (step+0.00001))
                        
    
        avg_loss = total_loss / len(train_loader)
        accelerator.print(f"Epoch {epoch+1} finished. Average Loss: {avg_loss:.4f}")
    
        # --- NEW: Save Checkpoint ---
        if accelerator.is_main_process:
            accelerator.wait_for_everyone()
            unwrapped_model = accelerator.unwrap_model(custom_diff_model)
            torch.save(unwrapped_model.state_dict(), f"model_weights_epoch_{epoch+1}.pth")
            print(f"Saved UNet checkpoint for epoch {epoch+1}")
    
    accelerator.end_training()
    print("\nFine-tuning completed.")

except KeyboardInterrupt:
    print("Exited")

    accelerator.wait_for_everyone()
    custom_diff_model = accelerator.unwrap_model(custom_diff_model)
    accelerator.end_training()

Epoch 1/2, Step 0, Loss: 130573.0343
Epoch 1/2, Step 50, Loss: 1.2985
Epoch 1/2, Step 100, Loss: 1.2966
Epoch 1/2, Step 150, Loss: 1.2956
Epoch 1/2, Step 200, Loss: 1.2874
Epoch 1/2, Step 250, Loss: 1.2830
Epoch 1/2, Step 300, Loss: 1.2781
Epoch 1/2, Step 350, Loss: 1.2800
Epoch 1/2, Step 400, Loss: 1.2768
Epoch 1/2, Step 450, Loss: 1.2766
Epoch 1/2, Step 500, Loss: 1.2771
Epoch 1/2, Step 550, Loss: 1.2762
Epoch 1/2, Step 600, Loss: 1.2758
Epoch 1/2, Step 650, Loss: 1.2776
Epoch 1/2, Step 700, Loss: 1.2775
Epoch 1/2, Step 750, Loss: 1.2775
Epoch 1/2, Step 800, Loss: 1.2778
Epoch 1 finished. Average Loss: 1.2759
Saved UNet checkpoint for epoch 1
Epoch 2/2, Step 0, Loss: 122379.4937
Epoch 2/2, Step 50, Loss: 1.2966
Epoch 2/2, Step 100, Loss: 1.2730
Epoch 2/2, Step 150, Loss: 1.2749


# Inference Loop

In [None]:
# torch_dtype = torch.float16

custom_diff_model = CustomUNet()
custom_diff_model = custom_diff_model.to(torch_dtype).to(device)
custom_diff_model.load_state_dict(torch.load('/kaggle/working/model_weights_epoch1.pth'))

In [None]:
transform = transforms.Compose([
    transforms.Resize((512, 512), interpolation=transforms.InterpolationMode.BICUBIC),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])

mask_transform = transforms.Compose([
    transforms.Resize((64, 64), interpolation=transforms.InterpolationMode.NEAREST_EXACT),
    transforms.ToTensor()
])


cloth = Image.open("/kaggle/input/vton-hd/test/cloth/00006_00.jpg")
cloth_mask = Image.open("/kaggle/input/vton-hd/test/cloth-mask/00006_00.jpg")
segmentation = Image.open("/kaggle/input/vton-hd/test/image-parse-v3/00006_00.png")
agnostic = Image.open("/kaggle/input/vton-hd/test/agnostic-v3.2/00006_00.jpg")

cloth_tensor = transform(cloth).to(torch_dtype).to(device).unsqueeze(0)
cloth_mask_tensor = mask_transform(cloth_mask).to(torch_dtype).to(device).unsqueeze(0)
segmentation_tensor = mask_transform(segmentation).to(torch_dtype).to(device).unsqueeze(0)
agnostic_tensor = transform(agnostic).to(torch_dtype).to(device).unsqueeze(0)

vae = AutoencoderKL.from_pretrained(model_id, subfolder="vae", torch_dtype=torch_dtype).to(device)
vae.requires_grad_(False)

with torch.no_grad():
    
    cloth_latents = vae.encode(cloth_tensor).latent_dist.sample()
    cloth_latents = cloth_latents * vae.config.scaling_factor
    
    agnostic_latents = vae.encode(agnostic_tensor).latent_dist.sample()
    agnostic_latents = agnostic_latents * vae.config.scaling_factor

In [None]:
gemini_prompt = "Describe the image in short. Include only necessary details about the cloth that the person is wearing such that a person is able to understand the overall appearance of the cloth. Start you sentence with: A person wearing .... against a plane white background. I will use this output for text conditioning of a diffusion model so make the output appropriate."

result = gemini.generate_content([cloth, gemini_prompt])
prompt = result.text
negative_prompt = ""

print(prompt)

text_encoder = CLIPTextModel.from_pretrained(model_id, subfolder="text_encoder", torch_dtype=torch_dtype).to(device)
text_encoder.requires_grad_(False)

with torch.no_grad():
    text_input_ids = tokenizer(prompt, padding="max_length", truncation=True, return_tensors="pt").input_ids.to(device)
    conditional_embeddings = text_encoder(text_input_ids)[0]
    
    uncond_input_ids = tokenizer(negative_prompt, padding="max_length", truncation=True, return_tensors="pt").input_ids.to(device)
    unconditional_embeddings = text_encoder(uncond_input_ids)[0]

text_embeddings = torch.cat([unconditional_embeddings, conditional_embeddings])
print(text_embeddings)

In [None]:
guidance_scale = 8.0
num_inference_steps = 50
seed = 42
generator = torch.Generator(device=device).manual_seed(seed)

In [None]:
custom_diff_model.eval()

latents = torch.randn(1, 4, 64, 64, generator=generator, device=device, dtype=torch_dtype)

scheduler.set_timesteps(num_inference_steps, device=device)
latents = latents * scheduler.init_noise_sigma


for i, t in enumerate(scheduler.timesteps):
    with torch.no_grad():
        
        scaled_latents = scheduler.scale_model_input(latents, t)
    
        latent_model_input = torch.cat([scaled_latents, cloth_latents, cloth_mask_tensor, segmentation_tensor, agnostic_latents], dim=1)
    
        latent_model_input = torch.cat([latent_model_input] * 2)
    
        noise_pred = custom_diff_model(latent_model_input, t, text_embeddings).sample
        
        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
        noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
    
        latents = scheduler.step(noise_pred, t, latents).prev_sample
        print(f"  Step {i+1} (Timestep {t.item()}): Latents Min: {latents.min().item():.4f}, Max: {latents.max().item():.4f}, Mean: {latents.mean().item():.4f}, Std: {latents.std().item():.4f}")


In [None]:
latents = 1 / vae.config.scaling_factor * latents
print(f"Latents (before VAE decode) min: {latents.min().item():.4f}, max: {latents.max().item():.4f}, mean: {latents.mean().item():.4f}, std: {latents.std().item():.4f}")


image = vae.decode(latents).sample
print("Image decoded.")
print(f"Decoded image min: {image.min().item():.4f}, max: {image.max().item():.4f}, mean: {image.mean().item():.4f}")


image = torch.clamp(image, -1, 1)
image = (image / 2 + 0.5)
image = torch.clamp(image, 0, 1)
img = image.squeeze(0)
img = (img * 255).byte()
img = img.permute(1, 2, 0).detach().cpu().numpy()
img = Image.fromarray(img)
print("Image prepared for display.")

In [None]:
img