In [1]:
import os
import cv2
import json
import torch
import numpy as np
from PIL import Image
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt
from torchvision import models
from torchvision.transforms import transforms
import diffusers
from diffusers import AutoencoderKL, UNet2DConditionModel, ControlNetModel, StableDiffusionControlNetPipeline, DPMSolverMultistepScheduler
from accelerate import Accelerator
import torch.nn.functional as F
from huggingface_hub import login
import google.generativeai as genai
from transformers import CLIPTextModel, CLIPTokenizer

2025-07-20 21:00:56.958374: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1753045257.405109      36 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1753045257.537593      36 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


In [None]:
# Credentials are taken from the process environment rather than written into
# the notebook, so a saved/shared copy never contains a token.
# `setdefault` keeps the keys present (as the original cell did) without
# clobbering a value that was supplied from outside the notebook.
hf_token = os.environ.setdefault("HF_Token", "")
if hf_token:
    login(token=hf_token)
else:
    # No token available: skip the login call instead of submitting an empty token.
    print("HF_Token is not set; skipping Hugging Face login.")


gemini_api_key = os.environ.setdefault("GEMINI_API_KEY", "")
if not gemini_api_key:
    print("GEMINI_API_KEY is not set; provide it before sending requests to Gemini.")
genai.configure(api_key=gemini_api_key)

gemini = genai.GenerativeModel("gemini-2.0-flash")

In [5]:
# Use the GPU when PyTorch can see one, otherwise run on the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f"Using device: {device}")

Using device: cuda


In [6]:
# --- Model ---------------------------------------------------------------
# Hub identifier of the checkpoint and the dtype its components are loaded in.
MODEL_ID = "stabilityai/stable-diffusion-2-base"
TORCH_DTYPE = torch.float16

# --- Optimisation --------------------------------------------------------
NUM_EPOCHS = 1
BATCH_SIZE = 4
GRADIENT_ACCUMULATION_STEPS = 2
LEARNING_RATE = 1e-5

# Seeding is currently disabled; uncomment for repeatable runs.
# torch.manual_seed(42)

# Pretrained Model Imports

In [17]:
# Tokenizer used to turn text prompts into input ids.
tokenizer = CLIPTokenizer.from_pretrained(MODEL_ID, subfolder="tokenizer")

# Text encoder, VAE and UNet are loaded in TORCH_DTYPE, moved to `device`
# and frozen (requires_grad_(False)), so none of them receives gradients.
text_encoder = CLIPTextModel.from_pretrained(
    MODEL_ID,
    subfolder="text_encoder",
    torch_dtype=TORCH_DTYPE,
)
text_encoder = text_encoder.to(device)
text_encoder.requires_grad_(False)

vae = AutoencoderKL.from_pretrained(
    MODEL_ID,
    subfolder="vae",
    torch_dtype=TORCH_DTYPE,
)
vae = vae.to(device)
vae.requires_grad_(False)

unet = UNet2DConditionModel.from_pretrained(
    MODEL_ID,
    subfolder="unet",
    torch_dtype=TORCH_DTYPE,
)
unet = unet.to(device)
unet.requires_grad_(False)

# The ControlNet is initialised from the UNet and is the only trainable model.
# NOTE(review): its weights are cast to TORCH_DTYPE (float16) while the Accelerator
# below is created with mixed_precision="no" — confirm that training directly on
# half-precision parameters is intended.
controlnet = ControlNetModel.from_unet(unet)
controlnet = controlnet.to(TORCH_DTYPE).to(device)

# Replace the conditioning embedding with a 1x1 convolution that maps the
# 10-channel concatenated condition to the width of the first ControlNet block.
# (The training loop concatenates cloth latents, agnostic latents, cloth mask and
# segmentation along dim=1 — presumably 4 + 4 + 1 + 1 channels; TODO confirm.)
controlnet.controlnet_cond_embedding = nn.Conv2d(
    10,
    controlnet.config.block_out_channels[0],
    kernel_size=1,
    stride=1,
    padding=0,
).to(TORCH_DTYPE).to(device)

# Dataset Pipeline

In [8]:
class CustomDataset(Dataset):
    """Paired try-on samples read from one dataset split directory.

    The split directory must contain the sub-folders ``image``, ``cloth``,
    ``cloth-mask``, ``image-parse-v3`` and ``agnostic-v3.2``. Files in each
    folder are ordered by the integer prefix before the first ``_`` of their
    name, and the i-th file of every folder is treated as one sample.

    Args:
        folder_path: Root directory of the split.
        max_samples: Upper bound on the number of samples exposed by the
            dataset (default 6506, the value previously hard-coded in
            ``__len__``). Pass ``None`` to expose every image found.
        prompts_path: JSON file mapping cloth file names to text prompts.

    Note:
        ``__getitem__`` relies on the notebook-level ``tokenizer``,
        ``TORCH_DTYPE`` and ``device`` and returns tensors that already live on
        ``device``, so the dataset is meant to be used with the DataLoader
        default of ``num_workers=0``.
    """

    def __init__(self, folder_path, max_samples=6506,
                 prompts_path="/kaggle/input/prompt/generated_prompts.json"):

        self.image = os.path.join(folder_path, "image")
        self.cloth = os.path.join(folder_path, "cloth")
        self.cloth_mask = os.path.join(folder_path, "cloth-mask")
        self.segmentation = os.path.join(folder_path, "image-parse-v3")
        self.agnostic = os.path.join(folder_path, "agnostic-v3.2")

        self.image_names = self._list_sorted(self.image)
        self.cloth_names = self._list_sorted(self.cloth)
        self.cloth_mask_names = self._list_sorted(self.cloth_mask)
        self.segmentation_names = self._list_sorted(self.segmentation)
        self.agnostic_names = self._list_sorted(self.agnostic)

        self.max_samples = max_samples

        # RGB inputs: resize to 512x512 and scale to [-1, 1].
        self.transform = transforms.Compose([
            transforms.Resize((512, 512), interpolation=transforms.InterpolationMode.BICUBIC),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
        ])

        # Masks: resize to 64x64 with nearest-neighbour sampling so that label
        # values are not blended; no normalisation is applied.
        self.mask_transform = transforms.Compose([
            transforms.Resize((64, 64), interpolation=transforms.InterpolationMode.NEAREST_EXACT),
            transforms.ToTensor()
        ])

        with open(prompts_path) as file:
            self.generated_prompts = json.load(file)

    @staticmethod
    def _list_sorted(directory):
        """Return the file names in ``directory`` ordered by their integer prefix."""
        names = sorted(os.listdir(directory))
        # Stable sort: names sharing a numeric prefix keep their alphabetical order.
        names.sort(key=lambda f: int(f.split("_")[0]))
        return names

    def _load(self, directory, name, transform, mode=None):
        """Open one image, optionally convert its mode, and return the transformed tensor."""
        # The context manager releases the file handle once the pixels are read.
        with Image.open(os.path.join(directory, name)) as img:
            if mode is not None:
                img = img.convert(mode)
            return transform(img).to(TORCH_DTYPE).to(device)

    def __len__(self):
        """Number of samples: the images found, capped by ``max_samples`` when set."""
        available = len(self.image_names)
        if self.max_samples is None:
            return available
        # Never report more samples than exist, otherwise a smaller split
        # would be indexed past the end of its file lists.
        return min(available, self.max_samples)

    def __getitem__(self, index):
        """Return ``(text_input_ids, cloth, cloth_mask, segmentation, agnostic, image)``.

        All returned tensors are on ``device``; the image tensors are cast to
        ``TORCH_DTYPE``.
        """
        prompt = self.generated_prompts[self.cloth_names[index]]

        text_input_ids = tokenizer(prompt, padding="max_length", truncation=True, return_tensors="pt").input_ids.squeeze(0)

        # Colour inputs are forced to RGB so the 3-channel Normalize always applies.
        image = self._load(self.image, self.image_names[index], self.transform, mode="RGB")
        cloth = self._load(self.cloth, self.cloth_names[index], self.transform, mode="RGB")
        agnostic = self._load(self.agnostic, self.agnostic_names[index], self.transform, mode="RGB")

        # Masks keep their stored mode so their channel layout and values are unchanged.
        cloth_mask = self._load(self.cloth_mask, self.cloth_mask_names[index], self.mask_transform)
        segmentation = self._load(self.segmentation, self.segmentation_names[index], self.mask_transform)

        return text_input_ids.to(device), cloth, cloth_mask, segmentation, agnostic, image
        

In [9]:
# Build one dataset per split (train first, then test) from the VTON-HD input folders.
train_dataset, test_dataset = map(
    CustomDataset,
    ("/kaggle/input/vton-hd/train", "/kaggle/input/vton-hd/test"),
)

In [13]:
# Sanity check: show which input datasets are mounted. `os` is already imported
# in the first cell; sorting makes the displayed order independent of the filesystem.
sorted(os.listdir('/kaggle/input'))

['prompt', 'vton-hd']

In [14]:
# Take the batch size from the configuration cell so the loaders cannot drift
# from BATCH_SIZE; `batch_size` is kept as an alias for cells that refer to it.
batch_size = BATCH_SIZE

# The dataset is ordered by numeric file ID, so the training loader shuffles
# samples every epoch; the test loader keeps a fixed order.
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size)

# Training Loop

In [15]:
accelerator = Accelerator(mixed_precision="no", gradient_accumulation_steps=GRADIENT_ACCUMULATION_STEPS)

# Only the ControlNet is trained, so only its parameters are given to the optimizer.
optimizer = torch.optim.AdamW(controlnet.parameters(), lr=LEARNING_RATE)

# Prepare only the objects that take part in optimisation. The frozen unet,
# text_encoder and vae were already moved to `device` when they were loaded, and
# handing modules without trainable parameters to the distributed wrapper is
# unnecessary (DistributedDataParallel rejects such modules).
controlnet, optimizer, train_loader = accelerator.prepare(
    controlnet, optimizer, train_loader
)

# Training noise scheduler. DDPMScheduler provides the forward-process
# `add_noise` used by the training loop; multistep solvers such as
# DPMSolverMultistepScheduler are sampling-time schedulers.
# It is reached through the `diffusers` module imported in the first cell.
scheduler = diffusers.DDPMScheduler.from_pretrained(MODEL_ID, subfolder="scheduler")

In [None]:
"""
Training loop with gradient anomaly detection and fault tolerance.
Detects and skips steps with gradient issues to continue training robustly.
"""

controlnet.train()
# The UNet is frozen and only supplies the denoising prediction, so it stays in eval mode.
unet.eval()

print("\nStarting training...")
for epoch in range(NUM_EPOCHS):
    total_loss = 0
    completed_steps = 0  # steps that produced a usable loss (skipped steps are excluded)
    for step, (text_input_ids, cloth_tensor, cloth_mask_tensor, segmentation_tensor, agnostic_tensor, target_img_tensor) in enumerate(train_loader):
        with accelerator.accumulate(controlnet):
            # The frozen encoders run without autograd bookkeeping.
            with torch.no_grad():
                text_embeddings = text_encoder(text_input_ids).last_hidden_state
                target_latents = vae.encode(target_img_tensor).latent_dist.sample() * vae.config.scaling_factor
                cloth_latents = vae.encode(cloth_tensor).latent_dist.sample() * vae.config.scaling_factor
                agnostic_latents = vae.encode(agnostic_tensor).latent_dist.sample() * vae.config.scaling_factor

            # Sample one timestep per batch element and noise the target latents accordingly.
            noise = torch.randn_like(target_latents)
            timesteps = torch.randint(
                0,
                scheduler.config.num_train_timesteps,
                (target_latents.shape[0],),
                device=target_latents.device,
            ).long()
            noisy_latents = scheduler.add_noise(target_latents, noise, timesteps)

            # Channel-wise concatenation of all conditions fed to the ControlNet.
            additional_controlnet_cond = torch.cat([
                cloth_latents,
                agnostic_latents,
                cloth_mask_tensor,
                segmentation_tensor
            ], dim=1)

            try:
                down_block_res_samples, mid_block_res_sample = controlnet(
                    sample=noisy_latents,
                    timestep=timesteps,
                    encoder_hidden_states=text_embeddings,
                    controlnet_cond=additional_controlnet_cond,
                    return_dict=False,
                )

                model_pred = unet(
                    noisy_latents,
                    timesteps,
                    encoder_hidden_states=text_embeddings,
                    down_block_additional_residuals=down_block_res_samples,
                    mid_block_additional_residual=mid_block_res_sample,
                    return_dict=False,
                )[0]

                loss = F.mse_loss(model_pred.float(), noise.float(), reduction="mean")

                # A NaN/inf loss does not raise on its own; turn it into the
                # FloatingPointError handled below so the step is skipped
                # before it can reach the weights.
                if not torch.isfinite(loss):
                    raise FloatingPointError(f"non-finite loss ({loss.item()})")

                accelerator.backward(loss)
                # Clip only when the accumulated gradients are about to be applied.
                if accelerator.sync_gradients:
                    accelerator.clip_grad_norm_(controlnet.parameters(), 1.0)
                optimizer.step()
                optimizer.zero_grad()

            except (RuntimeError, ValueError, FloatingPointError) as e:
                accelerator.print(f"Gradient anomaly at Epoch {epoch+1}, Step {step}: {e}")
                optimizer.zero_grad(set_to_none=True)
                continue

            total_loss += loss.item()
            completed_steps += 1
            if accelerator.is_main_process and step % 10 == 0:
                accelerator.print(f"Epoch {epoch+1}/{NUM_EPOCHS}, Step {step}, Loss: {total_loss / completed_steps:.4f}")

    # Average over the steps that actually contributed a loss.
    avg_loss = total_loss / max(completed_steps, 1)
    accelerator.print(f"Epoch {epoch+1} finished. Average Loss: {avg_loss:.4f}")

    # Every process must reach the barrier; calling it only on the main process
    # would leave that process waiting forever in a multi-process run.
    accelerator.wait_for_everyone()
    if accelerator.is_main_process:
        torch.save(accelerator.unwrap_model(controlnet).state_dict(), f"controlnet_epoch{epoch+1}.pth")
        print(f"Saved ControlNet checkpoint for epoch {epoch+1}")

accelerator.end_training()
print("\nTraining completed.")



Starting training...
Epoch 1/1, Step 0, Loss: 0.3439
Epoch 1/1, Step 10, Loss: 0.2739
