In [None]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here's several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
for dirname, _, filenames in os.walk('/kaggle/input'):
    for filename in filenames:
        print(os.path.join(dirname, filename))

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

In [4]:
# --- Step 1.1: Clone a reliable HD-VITON repository ---
print("Cloning the HD-VITON repository...")
!git clone https://github.com/shadow2496/VITON-HD
print("✅ Repository cloned.")



Cloning the HD-VITON repository...
Cloning into 'VITON-HD'...
remote: Enumerating objects: 52, done.[K
remote: Counting objects: 100% (19/19), done.[K
remote: Compressing objects: 100% (13/13), done.[K
remote: Total 52 (delta 12), reused 6 (delta 6), pack-reused 33 (from 3)[K
Receiving objects: 100% (52/52), 5.03 MiB | 24.89 MiB/s, done.
Resolving deltas: 100% (19/19), done.
✅ Repository cloned.


In [12]:
import os

# --- Environment Setup ---

print("🚀 Setting up the environment...")

# 1.1: Navigate into the repository directory
repo_path = "/kaggle/working/VITON-HD"
%cd {repo_path}
print(f"Current directory: {os.getcwd()}")

# 1.2: Link your dataset folder
target_data_dir = os.path.join(repo_path, "data")
source_data_dir = "/kaggle/input/clothe/clothes_tryon_dataset"
if not os.path.exists(target_data_dir):
    print("\nLinking dataset...")
    os.symlink(source_data_dir, target_data_dir)
    print("✅ Dataset linked successfully.")
else:
    print("\nDataset link check: OK.")

print("\n---------------------------------")
print("✅ Environment is ready.")
print("---------------------------------")

🚀 Setting up the environment...
/kaggle/working/VITON-HD
Current directory: /kaggle/working/VITON-HD

Dataset link check: OK.

---------------------------------
✅ Environment is ready.
---------------------------------


In [14]:
# --- Part 1: GMM Inference ---
!pip install torchgeometry
print("\n🚀 Starting Part 1: GMM Inference (Geometric Warping)...")

# Define the absolute path to YOUR uploaded GMM model.
# The single quotes are important to handle the space and parentheses.
gmm_checkpoint_path = "'/kaggle/input/weights/pytorch/default/1/gmm_final (1).pth'"
# Define a name for this experiment.
gmm_experiment_name = "GMM_test_inference_run"

print(f"Loading GMM model from: {gmm_checkpoint_path}")

# Run the test script for the GMM stage
!python test.py \
    --name {gmm_experiment_name} \
    --stage GMM \
    --dataroot ./data \
    --test_pairs ./data/test_pairs.txt \
    --workers 4 \
    --batch_size 4 \
    --checkpoint {gmm_checkpoint_path}

print("\n---------------------------------")
print("✅ GMM inference complete.")
print("Warped clothes for the test set have been generated.")
print("---------------------------------")

Collecting torchgeometry
  Downloading torchgeometry-0.1.2-py2.py3-none-any.whl.metadata (2.9 kB)
Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch>=1.0.0->torchgeometry)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch>=1.0.0->torchgeometry)
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch>=1.0.0->torchgeometry)
  Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch>=1.0.0->torchgeometry)
  Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu12==12.4.5.8 (from torch>=1.0.0->torchgeometry)
  Downloading nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cufft-cu12==1

In [18]:
import os

print("--- Data Integrity Debugger ---")
print("This script will check the first 5 entries of your pairs file against your data.")

# Define the paths to the source data directories
source_test_dir = "/kaggle/input/clothe/clothes_tryon_dataset/test"
json_dir = os.path.join(source_test_dir, "openpose_json")
img_dir = os.path.join(source_test_dir, "openpose_img")
pairs_file_path = "/kaggle/input/clothe/clothes_tryon_dataset/test_pairs.txt"

# Let's check the first 5 entries from the pairs file
num_to_check = 5

print(f"\nSource JSON directory being checked: {json_dir}")
print(f"Source IMG directory being checked: {img_dir}")
print("-" * 50)

# Verify the directories themselves exist
if not os.path.exists(json_dir):
    print(f"FATAL ERROR: JSON directory not found at {json_dir}")
if not os.path.exists(img_dir):
    print(f"FATAL ERROR: IMG directory not found at {img_dir}")
if not os.path.exists(pairs_file_path):
    print(f"FATAL ERROR: Pairs file not found at {pairs_file_path}")

print("\n--- Checking first 5 pairs from test_pairs.txt ---")

try:
    with open(pairs_file_path, "r") as f:
        for i, line in enumerate(f):
            if i >= num_to_check:
                break

            print(f"\n--- Pair #{i+1} ---")
            print(f"Original line from pairs file: '{line.strip()}'")

            person_fn, cloth_fn = line.strip().split()
            base_name, _ = os.path.splitext(person_fn)
            print(f"Extracted base name for checks: '{base_name}'")

            # --- Check 1: The JSON file ---
            # My script assumed the JSON file is named like '01234_00.json'. Let's verify.
            expected_json_path = os.path.join(json_dir, f"{base_name}.json")
            print(f"Checking for JSON file at: {expected_json_path}")
            json_exists = os.path.exists(expected_json_path)
            print(f"Found? -> {json_exists}")

            # --- Check 2: The rendered PNG file ---
            # My script assumed the IMG file is named like '01234_00_rendered.png'. Let's verify.
            expected_img_path = os.path.join(img_dir, f"{base_name}_rendered.png")
            print(f"Checking for IMG file at: {expected_img_path}")
            img_exists = os.path.exists(expected_img_path)
            print(f"Found? -> {img_exists}")

            if not json_exists or not img_exists:
                print(">>> STATUS: This pair is INVALID and would be filtered out. <<<")

except Exception as e:
    print(f"\nAn error occurred while reading the pairs file: {e}")

print("\n" + "="*50)
print("--- For comparison, here are the ACTUAL filenames ---")
print("\nActual JSON filenames (first 5):")
!ls {json_dir} | head -n 5

print("\nActual IMG filenames (first 5):")
!ls {img_dir} | head -n 5
print("="*50)

--- Data Integrity Debugger ---
This script will check the first 5 entries of your pairs file against your data.

Source JSON directory being checked: /kaggle/input/clothe/clothes_tryon_dataset/test/openpose_json
Source IMG directory being checked: /kaggle/input/clothe/clothes_tryon_dataset/test/openpose_img
--------------------------------------------------

--- Checking first 5 pairs from test_pairs.txt ---

--- Pair #1 ---
Original line from pairs file: '05006_00.jpg 11001_00.jpg'
Extracted base name for checks: '05006_00'
Checking for JSON file at: /kaggle/input/clothe/clothes_tryon_dataset/test/openpose_json/05006_00.json
Found? -> False
Checking for IMG file at: /kaggle/input/clothe/clothes_tryon_dataset/test/openpose_img/05006_00_rendered.png
Found? -> True
>>> STATUS: This pair is INVALID and would be filtered out. <<<

--- Pair #2 ---
Original line from pairs file: '02532_00.jpg 14096_00.jpg'
Extracted base name for checks: '02532_00'
Checking for JSON file at: /kaggle/input/c

In [19]:
import os
import shutil
import matplotlib.pyplot as plt
from PIL import Image
from tqdm import tqdm

# --- Step 1: Nuke, Pave, and Link ---
# This setup part is now correct and stable.

print("🚀 Preparing complete environment...")

repo_path = "/kaggle/working/VITON-HD"
%cd {repo_path}

# NUKE previous attempts
data_root_to_nuke = os.path.join(repo_path, "data")
if os.path.lexists(data_root_to_nuke):
    shutil.rmtree(data_root_to_nuke)

# Safe Model Links
safe_link_dir = "/kaggle/working/safe_model_links"
os.makedirs(safe_link_dir, exist_ok=True)
source_gmm_path = "/kaggle/input/weights/pytorch/default/1/gmm_final (1).pth"
source_alias_path = "/kaggle/input/weights/pytorch/default/1/alias_final.pth"
source_seg_path = "/kaggle/input/weights/pytorch/default/1/seg_final.pth"
safe_gmm_path = os.path.join(safe_link_dir, "gmm_final.pth")
safe_alias_path = os.path.join(safe_link_dir, "alias_final.pth")
safe_seg_path = os.path.join(safe_link_dir, "seg_final.pth")
if not os.path.exists(safe_gmm_path): os.symlink(source_gmm_path, safe_gmm_path)
if not os.path.exists(safe_alias_path): os.symlink(source_alias_path, safe_alias_path)
if not os.path.exists(safe_seg_path): os.symlink(source_seg_path, safe_seg_path)
print("✅ Safe model links are ready.")

# PAVE & LINK with CORRECT NAMING
data_root = os.path.join(repo_path, "data")
test_dir = os.path.join(data_root, "test")
os.makedirs(test_dir, exist_ok=True)
source_test_dir = "/kaggle/input/clothe/clothes_tryon_dataset/test"
for folder_name in os.listdir(source_test_dir):
    source_path = os.path.join(source_test_dir, folder_name)
    dest_path = os.path.join(test_dir, folder_name if folder_name != "openpose_img" else "openpose-img")
    if os.path.isdir(source_path) and not os.path.lexists(dest_path):
        os.symlink(source_path, dest_path)
print("✅ Data links are ready.")

# --- THE FINAL FIX: CORRECTED DATA CLEANING LOGIC ---
print("\nCleaning test pairs file with CORRECTED filename logic...")
original_pairs_path = "/kaggle/input/clothe/clothes_tryon_dataset/test_pairs.txt"
clean_pairs_path = os.path.join(data_root, "test_pairs_clean.txt")
json_dir = os.path.join(source_test_dir, "openpose_json")
img_dir = os.path.join(source_test_dir, "openpose_img")
valid_pairs = []
total_pairs = 0

with open(original_pairs_path, "r") as f:
    pairs = f.readlines()
    total_pairs = len(pairs)
    for pair in tqdm(pairs, desc="Verifying pairs"):
        person_fn, _ = pair.strip().split()
        base_name, _ = os.path.splitext(person_fn)
        
        # THE FIX IS HERE: We now check for the CORRECT json filename format
        expected_json = os.path.join(json_dir, f"{base_name}_keypoints.json")
        expected_img = os.path.join(img_dir, f"{base_name}_rendered.png")
        
        if os.path.exists(expected_json) and os.path.exists(expected_img):
            valid_pairs.append(pair)

if not valid_pairs:
    raise ValueError(f"CRITICAL ERROR: Data cleaning resulted in 0 valid pairs out of {total_pairs}.\n"
                     "Even with the corrected filename check, no valid pairs were found. Please manually inspect your dataset.")

with open(clean_pairs_path, "w") as f:
    f.writelines(valid_pairs)
print(f"✅ Data cleaning complete. Found {len(valid_pairs)} valid pairs out of {total_pairs}.")


# --- Step 2: Run Inference with the CLEAN file ---
print("\n🚀 Starting the Full End-to-End Inference Pipeline...")
experiment_name = "final_inference_run_final_fix"
final_output_dir = "/kaggle/working/final_tryon_images/"
os.makedirs(final_output_dir, exist_ok=True)

!python test.py \
    --name {experiment_name} \
    --dataset_dir {data_root} \
    --dataset_list {clean_pairs_path} \
    --gmm_checkpoint {safe_gmm_path} \
    --alias_checkpoint {safe_alias_path} \
    --seg_checkpoint {safe_seg_path} \
    --save_dir {final_output_dir} \
    --workers 4 \
    --batch_size 4

print("\n---------------------------------")
print("✅ Full inference pipeline complete.")
print(f"Check for final images in: {final_output_dir}")
print("---------------------------------")


# --- Step 3: Visualize the Results ---
print("\n🖼️ Displaying Final Results...")
num_examples = 3
with open(clean_pairs_path, "r") as f:
    test_pairs = [line.strip().split() for line in f.readlines()]
for i in range(min(num_examples, len(test_pairs))):
    person_fn, cloth_fn = test_pairs[i]
    result_path = os.path.join(final_output_dir, person_fn)
    person_path = os.path.join(test_dir, "image", person_fn)
    cloth_path = os.path.join(test_dir, "cloth", cloth_fn)
    if not os.path.exists(result_path):
        print(f"Result file not found, skipping: {result_path}")
        continue
    person_img = Image.open(person_path).convert("RGB")
    cloth_img = Image.open(cloth_path).convert("RGB")
    result_img = Image.open(result_path).convert("RGB")
    fig, axes = plt.subplots(1, 3, figsize=(15, 6))
    axes[0].imshow(person_img); axes[0].set_title(f"Original Person\n({person_fn})"); axes[0].axis('off')
    axes[1].imshow(cloth_img); axes[1].set_title(f"Garment\n({cloth_fn})"); axes[1].axis('off')
    axes[2].imshow(result_img); axes[2].set_title("Generated Try-On Result"); axes[2].axis('off')
    plt.tight_layout()
    plt.show()

🚀 Preparing complete environment...
/kaggle/working/VITON-HD
✅ Safe model links are ready.
✅ Data links are ready.

Cleaning test pairs file with CORRECTED filename logic...


Verifying pairs: 100%|██████████| 2032/2032 [00:02<00:00, 977.18it/s] 


✅ Data cleaning complete. Found 2032 valid pairs out of 2032.

🚀 Starting the Full End-to-End Inference Pipeline...
Namespace(name='final_inference_run_final_fix', batch_size=4, workers=4, load_height=1024, load_width=768, shuffle=False, dataset_dir='/kaggle/working/VITON-HD/data', dataset_mode='test', dataset_list='/kaggle/working/VITON-HD/data/test_pairs_clean.txt', checkpoint_dir='./checkpoints/', save_dir='/kaggle/working/final_tryon_images/', display_freq=1, seg_checkpoint='/kaggle/working/safe_model_links/seg_final.pth', gmm_checkpoint='/kaggle/working/safe_model_links/gmm_final.pth', alias_checkpoint='/kaggle/working/safe_model_links/alias_final.pth', semantic_nc=13, init_type='xavier', init_variance=0.02, grid_size=5, norm_G='spectralaliasinstance', ngf=64, num_upsampling_layers='most')
Network [SegGenerator] was created. Total number of parameters: 34.5 million. To see the architecture, do print(network).
Network [ALIASGenerator] was created. Total number of parameters: 100.5 

In [8]:
%%writefile dataset.py
# In dataset.py

import torch
from torch.utils.data import Dataset
from torchvision import transforms
from PIL import Image
import os
import json
import numpy as np

class VitonHDDataset(Dataset):
    """
    VitonHD Dataset for training and testing.
    Loads all necessary data inputs for GMM and TOM.
    """
    def __init__(self, opt, is_train=True):
        self.opt = opt
        self.data_root = opt.dataroot
        self.data_mode = 'train' if is_train else 'test'

        # Define paths to data directories
        self.image_dir = os.path.join(self.data_root, self.data_mode, 'image')
        self.cloth_dir = os.path.join(self.data_root, self.data_mode, 'cloth')
        self.cloth_mask_dir = os.path.join(self.data_root, self.data_mode, 'cloth-mask')
        self.image_parse_dir = os.path.join(self.data_root, self.data_mode, 'image-parse-v3')
        self.openpose_img_dir = os.path.join(self.data_root, self.data_mode, 'openpose_img')

        # Load the list of image pairs (person, cloth) from a text file
        self.pair_list_path = os.path.join(self.data_root, f'{self.data_mode}_pairs.txt')
        self.image_pairs = self._load_pairs()

        # Define standard image transformations
        self.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])
        self.to_tensor = transforms.ToTensor()

    def _load_pairs(self):
        pairs = []
        with open(self.pair_list_path, 'r') as f:
            for line in f.readlines():
                person_name, cloth_name = line.strip().split()
                pairs.append((person_name, cloth_name))
        return pairs

    def __len__(self):
        return len(self.image_pairs)

    def __getitem__(self, idx):
        person_name, cloth_name = self.image_pairs[idx]

        # 1. Load Cloth image and its mask
        cloth_path = os.path.join(self.cloth_dir, cloth_name)
        cloth_image = Image.open(cloth_path).convert('RGB').resize((self.opt.load_width, self.opt.load_height))
        cloth_mask_path = os.path.join(self.cloth_mask_dir, cloth_name)
        cloth_mask = Image.open(cloth_mask_path).convert('L').resize((self.opt.load_width, self.opt.load_height))

        # 2. Load Person Image
        person_path = os.path.join(self.image_dir, person_name)
        person_image = Image.open(person_path).convert('RGB').resize((self.opt.load_width, self.opt.load_height))

        # 3. Load Person Segmentation Map (Parse Map)
        parse_path = os.path.join(self.image_parse_dir, person_name.replace('.jpg', '.png'))
        parse_map = Image.open(parse_path).convert('L').resize((self.opt.load_width, self.opt.load_height), Image.NEAREST)
        parse_array = np.array(parse_map)

        # Create the agnostic person image (person with original clothes masked out)
        # These parse labels correspond to clothing parts
        parse_cloth_labels = [5, 6, 7] 
        parse_cloth_mask = np.isin(parse_array, parse_cloth_labels)
        agnostic_image = Image.fromarray((np.array(person_image) * (1 - np.expand_dims(parse_cloth_mask, -1))).astype(np.uint8))
        # 4. Load Pose Map
        pose_img_path = os.path.join(self.openpose_img_dir, person_name.replace('.jpg', '_rendered.png'))
        pose_map = Image.open(pose_img_path).convert('RGB').resize((self.opt.load_width, self.opt.load_height))

        # Create the person representation for the GMM
        # This combines the person's shape (from parse map) and pose
        # In dataset.py, inside the VitonHDDataset class's __getitem__ method

# ... (all the loading code remains the same) ...

# Apply final transformations to all images
        cloth_tensor = self.transform(cloth_image)
        agnostic_image_tensor = self.transform(agnostic_image)
        pose_map_tensor = self.transform(pose_map)
        person_image_tensor = self.transform(person_image)
        cloth_mask_tensor = self.to_tensor(cloth_mask)
        
        # Create the agnostic person parse map (1 channel)
        agnostic_parse_array = parse_array * (1 - parse_cloth_mask)
        agnostic_parse_map = Image.fromarray(agnostic_parse_array.astype(np.uint8))
        agnostic_parse_tensor = self.to_tensor(agnostic_parse_map) # This is already 1 channel
        
        # Create the NEW 7-channel person representation for the GMM
        # It combines the agnostic image (3ch), pose map (3ch), and agnostic parse map (1ch)
        gmm_person_representation = torch.cat([agnostic_image_tensor, pose_map_tensor, agnostic_parse_tensor], 0)
        
        data = {
            'cloth': cloth_tensor,
            'cloth_mask': cloth_mask_tensor,
            'person_image': person_image_tensor,
            'agnostic_person': agnostic_image_tensor,
            'gmm_person_representation': gmm_person_representation, # This is now 7 channels
            'person_name': person_name,
            'cloth_name': cloth_name,
}

        return data
        


Overwriting dataset.py


In [14]:
%%writefile models.py
# In models.py
# In models.py
import torch.nn as nn  # <-- ADD THIS LINE
import torch.nn.functional as F

import torch
# ... (imports remain the same) ...

# --- Geometric Matching Module (GMM) - FINAL ARCHITECTURE ---

class GMM(nn.Module):
    """
    This GMM architecture EXACTLY matches the layer names and structure
    from the pre-trained weights file, solving both the size mismatch
    and runtime errors.
    """
    def __init__(self, opt):
        super(GMM, self).__init__()
        
        # Person Representation Feature Extraction ('extractionA')
        # Input: 7 channels (as per original error). Output: progresses to 512 channels.
        extractionA_model = [
            nn.Conv2d(opt.person_rep_channels, 64, kernel_size=4, stride=2, padding=1), nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1, bias=False), nn.BatchNorm2d(128), nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1, bias=False), nn.BatchNorm2d(256), nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(256, 512, kernel_size=4, stride=2, padding=1, bias=False), nn.BatchNorm2d(512), nn.LeakyReLU(0.2, inplace=True),
        ]
        self.extractionA = nn.Sequential(*extractionA_model)

        # Cloth Feature Extraction ('extractionB')
        # Input: 3 channels. Output: progresses to 512 channels.
        extractionB_model = [
            nn.Conv2d(opt.cloth_channels, 64, kernel_size=4, stride=2, padding=1), nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1, bias=False), nn.BatchNorm2d(128), nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1, bias=False), nn.BatchNorm2d(256), nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(256, 512, kernel_size=4, stride=2, padding=1, bias=False), nn.BatchNorm2d(512), nn.LeakyReLU(0.2, inplace=True),
        ]
        self.extractionB = nn.Sequential(*extractionB_model)
        
        # Flow Regression Network ('regression')
        # This part requires more careful reconstruction based on typical architectures.
        # It likely takes the two 512-channel feature maps and processes them.
        # Let's assume a feature-correlation and upsampling path.
        # A simple concatenation and regression is a good starting point.
        regression_conv = [
            # The input will be the concatenated features: 512 (person) + 512 (cloth) = 1024
            nn.Conv2d(1024, 512, kernel_size=3, padding=1, bias=False), nn.BatchNorm2d(512), nn.ReLU(inplace=True),
            nn.Conv2d(512, 256, kernel_size=3, padding=1, bias=False), nn.BatchNorm2d(256), nn.ReLU(inplace=True),
            # Upsampling layers to restore original dimensions
            nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1, bias=False), nn.BatchNorm2d(128), nn.ReLU(inplace=True),
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1, bias=False), nn.BatchNorm2d(64), nn.ReLU(inplace=True),
            # Final layers to get to the 2-channel flow field
            nn.Conv2d(64, 32, kernel_size=3, padding=1, bias=False), nn.BatchNorm2d(32), nn.ReLU(inplace=True),
            nn.Conv2d(32, 2, kernel_size=3, padding=1),
            nn.Tanh()
        ]
        self.regression = nn.Sequential(*regression_conv)
        
    def forward(self, person_rep, cloth):
        # Pass inputs through their respective feature extractors
        featureA = self.extractionA(person_rep)
        featureB = self.extractionB(cloth)
        
        # Concatenate the final feature maps
        x = torch.cat([featureA, featureB], 1)
        
        # Regress the flow field
        flow = self.regression(x)
        
        # We need to upsample the flow to the original image size
        flow = F.interpolate(flow, size=(person_rep.size(2), person_rep.size(3)), mode='bilinear', align_corners=True)
        
        return flow


# --- Try-On Module (TOM) ---
# ... (The TOM code can remain the same) ...
    # ...
# ... (rest of TOM code) ...

# --- Try-On Module (TOM) ---
# ... (The TOM code remains the same as it was not part of the error) ...
class UnetGenerator(nn.Module):
    """A standard U-Net generator for image-to-image translation."""
    def __init__(self, input_nc, output_nc, num_downs, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False):
        super(UnetGenerator, self).__init__()
        unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=None, norm_layer=norm_layer, innermost=True)
        for i in range(num_downs - 5):
            unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, submodule=unet_block, norm_layer=norm_layer, use_dropout=use_dropout)
        unet_block = UnetSkipConnectionBlock(ngf * 4, ngf * 8, submodule=unet_block, norm_layer=norm_layer)
        unet_block = UnetSkipConnectionBlock(ngf * 2, ngf * 4, submodule=unet_block, norm_layer=norm_layer)
        unet_block = UnetSkipConnectionBlock(ngf, ngf * 2, submodule=unet_block, norm_layer=norm_layer)
        self.model = UnetSkipConnectionBlock(output_nc, ngf, input_nc=input_nc, submodule=unet_block, outermost=True, norm_layer=norm_layer)

    def forward(self, x):
        return self.model(x)

class UnetSkipConnectionBlock(nn.Module):
    def __init__(self, outer_nc, inner_nc, input_nc=None, submodule=None, outermost=False, innermost=False, norm_layer=nn.BatchNorm2d, use_dropout=False):
        super(UnetSkipConnectionBlock, self).__init__()
        self.outermost = outermost
        if input_nc is None:
            input_nc = outer_nc
        downconv = nn.Conv2d(input_nc, inner_nc, kernel_size=4, stride=2, padding=1, bias=False)
        downrelu = nn.LeakyReLU(0.2, True)
        downnorm = norm_layer(inner_nc)
        uprelu = nn.ReLU(True)
        upnorm = norm_layer(outer_nc)
        if outermost:
            upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc, kernel_size=4, stride=2, padding=1)
            down = [downconv]
            up = [uprelu, upconv, nn.Tanh()]
            model = down + [submodule] + up
        elif innermost:
            upconv = nn.ConvTranspose2d(inner_nc, outer_nc, kernel_size=4, stride=2, padding=1, bias=False)
            down = [downrelu, downconv]
            up = [uprelu, upconv, upnorm]
            model = down + up
        else:
            upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc, kernel_size=4, stride=2, padding=1, bias=False)
            down = [downrelu, downconv, downnorm]
            up = [uprelu, upconv, upnorm]
            model = down + [submodule] + up
        self.model = nn.Sequential(*model)

    def forward(self, x):
        if self.outermost:
            return self.model(x)
        else:
            return torch.cat([x, self.model(x)], 1)

class TOM(nn.Module):
    def __init__(self, opt):
        super(TOM, self).__init__()
        # Input to TOM: agnostic person (3ch) + warped cloth (3ch) = 6 channels
        input_nc = opt.agnostic_channels + opt.cloth_channels
        self.generator = UnetGenerator(input_nc, opt.output_channels, num_downs=7, ngf=64) # Increased num_downs for higher res

    def forward(self, agnostic_person, warped_cloth):
        x = torch.cat([agnostic_person, warped_cloth], 1)
        return self.generator(x)

Overwriting models.py


In [21]:
%%writefile train.py
# In train.py

import torch
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader
import os
import torch.nn as nn  # <-- ADD THIS LINE

def train_gmm(opt, gmm_model, dataset):
    gmm_model.train()
    dataloader = DataLoader(dataset, batch_size=opt.batch_size, shuffle=True, num_workers=4)
    optimizer = optim.Adam(gmm_model.parameters(), lr=opt.lr)
    
    # GMM loss function: L1 loss between the warped cloth mask and the true cloth mask
    criterionL1 = nn.L1Loss()

    for epoch in range(opt.gmm_epochs):
        for i, data in enumerate(dataloader):
            cloth = data['cloth'].cuda()
            cloth_mask = data['cloth_mask'].cuda()
            person_rep = data['gmm_person_representation'].cuda()
            
            optimizer.zero_grad()
            
            # Forward pass
            flow = gmm_model(person_rep, cloth)
            
            # Warp the cloth mask using the predicted flow
            warped_mask = F.grid_sample(cloth_mask, flow.permute(0, 2, 3, 1), mode='bilinear', padding_mode='zeros', align_corners=True)

            # The loss encourages the warped cloth to align with the ground truth clothing area
            loss = criterionL1(warped_mask, cloth_mask)
            
            loss.backward()
            optimizer.step()

            if (i+1) % 100 == 0:
                print(f"GMM - Epoch [{epoch+1}/{opt.gmm_epochs}], Step [{i+1}/{len(dataloader)}], Loss: {loss.item():.4f}")

        # Save model checkpoint after each epoch
        torch.save(gmm_model.state_dict(), os.path.join(opt.checkpoint_dir, opt.name, f'gmm_epoch_{epoch+1}.pth'))

def train_tom(opt, tom_model, gmm_model, dataset):
    tom_model.train()
    gmm_model.eval()  # GMM is frozen and used for inference only
    dataloader = DataLoader(dataset, batch_size=opt.batch_size, shuffle=True, num_workers=4)
    optimizer = optim.Adam(tom_model.parameters(), lr=opt.lr)

    # TOM loss: L1 loss + Perceptual loss (VGG) for realism
    criterionL1 = nn.L1Loss()
    # For a real project, implementing a VGGPerceptualLoss is highly recommended
    # criterionVGG = VGGPerceptualLoss().cuda() 

    for epoch in range(opt.tom_epochs):
        for i, data in enumerate(dataloader):
            cloth = data['cloth'].cuda()
            agnostic_person = data['agnostic_person'].cuda()
            person_image = data['person_image'].cuda()
            person_rep = data['gmm_person_representation'].cuda()

            # Get the warped cloth from the (pre-trained or newly trained) GMM
            with torch.no_grad():
                flow = gmm_model(person_rep, cloth)
                warped_cloth = F.grid_sample(cloth, flow.permute(0, 2, 3, 1), mode='bilinear', padding_mode='border', align_corners=True)

            optimizer.zero_grad()

            # Forward pass through TOM
            generated_image = tom_model(agnostic_person, warped_cloth)

            # Calculate loss against the ground truth person image
            loss = criterionL1(generated_image, person_image)

            loss.backward()
            optimizer.step()

            if (i+1) % 100 == 0:
                print(f"TOM - Epoch [{epoch+1}/{opt.tom_epochs}], Step [{i+1}/{len(dataloader)}], L1 Loss: {loss.item():.4f}")

        # Save model checkpoint
        torch.save(tom_model.state_dict(), os.path.join(opt.checkpoint_dir, opt.name, f'tom_epoch_{epoch+1}.pth'))

Overwriting train.py


In [22]:
%%writefile test.py
# In test.py

import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision.utils import save_image
import os

def test(opt, gmm_model, tom_model, dataset):
    gmm_model.eval()
    tom_model.eval()
    dataloader = DataLoader(dataset, batch_size=opt.batch_size_test, shuffle=False, num_workers=4)
    
    output_dir = os.path.join(opt.result_dir, opt.name)
    os.makedirs(output_dir, exist_ok=True)

    for i, data in enumerate(dataloader):
        cloth = data['cloth'].cuda()
        agnostic_person = data['agnostic_person'].cuda()
        person_rep = data['gmm_person_representation'].cuda()
        person_name = data['person_name']
        cloth_name = data['cloth_name']
        
        with torch.no_grad():
            # Run GMM to get the warped cloth
            flow = gmm_model(person_rep, cloth)
            warped_cloth = F.grid_sample(cloth, flow.permute(0, 2, 3, 1), mode='bilinear', padding_mode='border', align_corners=True)
            
            # Run TOM to get the final try-on image
            generated_image = tom_model(agnostic_person, warped_cloth)

            # Save each result in the batch
            for j in range(len(person_name)):
                p_name = os.path.splitext(person_name[j])[0]
                c_name = os.path.splitext(cloth_name[j])[0]
                
                # Create a visual comparison grid and save
                visuals = torch.cat([
                    (data['person_image'][j].cpu() + 1) / 2, # Original Person
                    (cloth[j].cpu() + 1) / 2,               # Target Cloth
                    (warped_cloth[j].cpu() + 1) / 2,         # Warped Cloth
                    (generated_image[j].cpu() + 1) / 2       # Final Result
                ], dim=2) # Concatenate horizontally for comparison
                
                save_path = os.path.join(output_dir, f"{p_name}_tries_{c_name}.png")
                save_image(visuals, save_path)
        
        print(f"Processed and saved results for batch {i+1}/{len(dataloader)}")

Overwriting test.py


In [42]:
%%writefile main.py
# In main.py
import argparse
import os
import torch

from dataset import VitonHDDataset
from models import GMM, TOM
from train import train_gmm, train_tom
from test import test
def clean_dataparallel_keys(state_dict):
    new_state_dict = {}
    for k, v in state_dict.items():
        if k.startswith('module.'):
            name = k[7:]  # remove 'module.'
            new_state_dict[name] = v
        else:
            new_state_dict[k] = v
    return new_state_dict
def get_opt():
    parser = argparse.ArgumentParser()
    # --- Experiment and Data Options ---
    parser.add_argument("--name", default="HD-VITON_run")
    parser.add_argument("--gpu_ids", default="0", help="e.g., 0,1,2. use -1 for CPU")
    parser.add_argument("--mode", default="test", help="train | test")
    parser.add_argument("--dataroot", required=True, help="path to the dataset folder")
    parser.add_argument("--checkpoint_dir", default="./checkpoints")
    parser.add_argument("--result_dir", default="./results")

    # --- Data Loading and Model Options ---
    parser.add_argument("--batch_size", type=int, default=4)
    parser.add_argument("--batch_size_test", type=int, default=1)
    parser.add_argument("--load_height", type=int, default=1024)
    parser.add_argument("--load_width", type=int, default=768)
    parser.add_argument("--person_rep_channels", type=int, default=4) # 1ch parse map + 3ch pose map
    parser.add_argument("--cloth_channels", type=int, default=3)
    parser.add_argument("--agnostic_channels", type=int, default=3)
    parser.add_argument("--output_channels", type=int, default=3)

    # --- Training Specific Options ---
    parser.add_argument("--lr", type=float, default=0.0001, help="learning rate")
    parser.add_argument("--gmm_epochs", type=int, default=50)
    parser.add_argument("--tom_epochs", type=int, default=50)

    # --- Checkpoint Loading ---
    parser.add_argument("--gmm_checkpoint", default=None, help="path to GMM pre-trained weights")
    parser.add_argument("--tom_checkpoint", default=None, help="path to TOM pre-trained weights")

    opt = parser.parse_args()
    return opt

# In main.py

def main():
    opt = get_opt()
    
    # Create directories for checkpoints and results
    os.makedirs(os.path.join(opt.checkpoint_dir, opt.name), exist_ok=True)
    os.makedirs(os.path.join(opt.result_dir, opt.name), exist_ok=True)

    # --- THIS SECTION IS CORRECTED ---
    # Correctly parse the GPU IDs string into a list of integers
    gpu_ids_list = []
    if opt.gpu_ids:
        str_ids = opt.gpu_ids.split(',')
        for str_id in str_ids:
            try:
                # Use strip() to remove any accidental whitespace
                gpu_id = int(str_id.strip())
                if gpu_id >= 0:
                    gpu_ids_list.append(gpu_id)
            except ValueError:
                print(f"Warning: Could not parse GPU ID '{str_id}'. Skipping.")
    
    # Set the primary device. For DataParallel, this is typically cuda:0
    # The wrapper will handle distributing to other GPUs.
    device = torch.device("cuda:0" if torch.cuda.is_available() and gpu_ids_list else "cpu")

    # --- Initialize Dataset ---
    is_train = opt.mode == 'train'
    dataset = VitonHDDataset(opt, is_train=is_train)
    print(f"Dataset created for mode: '{opt.mode}'. Found {len(dataset)} samples.")

    # --- Initialize Models ---
    gmm_model = GMM(opt).to(device)
    tom_model = TOM(opt).to(device)
    
    # --- Wrap models for Multi-GPU if more than one ID is provided ---
    if len(gpu_ids_list) > 1:
        # This print statement is now correct
        print(f"Using {len(gpu_ids_list)} GPUs: {gpu_ids_list}")
        gmm_model = torch.nn.DataParallel(gmm_model, device_ids=gpu_ids_list)
        tom_model = torch.nn.DataParallel(tom_model, device_ids=gpu_ids_list)
    
    # --- Checkpoint loading and the rest of the script remain the same ---
    if opt.gmm_checkpoint:
        gmm_model.load_state_dict(torch.load(opt.gmm_checkpoint, map_location=device), strict=False)
        print(f"GMM checkpoint loaded from: {opt.gmm_checkpoint}")
    if opt.tom_checkpoint:
        tom_model.load_state_dict(torch.load(opt.tom_checkpoint, map_location=device), strict=False)
        print(f"TOM checkpoint loaded from: {opt.tom_checkpoint}")

    # ... (rest of the main function: calling train or test) ...
    if opt.mode == 'train':
        print("Starting GMM training...")
        train_gmm(opt, gmm_model, dataset)
        print("GMM training complete. Starting TOM training...")
        
        # When saving/loading in a multi-gpu setup, it's safer to access the underlying model
        # However, for this project structure, a simple load will often work.
        best_gmm_path = os.path.join(opt.checkpoint_dir, opt.name, f'gmm_epoch_{opt.gmm_epochs}.pth')
        gmm_model.load_state_dict(torch.load(best_gmm_path, map_location=device))
        
        train_tom(opt, tom_model, gmm_model, dataset)
        print("All training finished.")
        
    elif opt.mode == 'test':
        if not opt.gmm_checkpoint or not opt.tom_checkpoint:
            raise ValueError("In test mode, you must provide paths to pre-trained --gmm_checkpoint and --tom_checkpoint.")
        print("Starting testing...")
        test(opt, gmm_model, tom_model, dataset)
        print(f"Testing complete. Results saved in: {os.path.join(opt.result_dir, opt.name)}")
if __name__ == '__main__':
    main()

Overwriting main.py


In [19]:
!python main.py \
  --mode test \
  --dataroot "/kaggle/input/clothe/clothes_tryon_dataset" \
  --name "HD-VITON-Final-Attempt" \
  --gpu_ids 0 \
  --gmm_checkpoint "/kaggle/input/weights/pytorch/default/1/gmm_final (1).pth" \
  --tom_checkpoint "/kaggle/input/weights/pytorch/default/1/alias_final.pth" \
  --result_dir "/kaggle/working/" \
  --person_rep_channels 7

Dataset created for mode: 'test'. Found 2032 samples.
GMM checkpoint loaded from: /kaggle/input/weights/pytorch/default/1/gmm_final (1).pth
TOM checkpoint loaded from: /kaggle/input/weights/pytorch/default/1/alias_final.pth
Starting testing...
^C
Traceback (most recent call last):
  File "/kaggle/working/main.py", line 91, in <module>
    main()
  File "/kaggle/working/main.py", line 87, in main
    test(opt, gmm_model, tom_model, dataset)
  File "/kaggle/working/test.py", line 17, in test
    for i, data in enumerate(dataloader):
  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 708, in __next__
    data = self._next_data()
           ^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1458, in _next_data
    idx, data = self._get_data()
                ^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1420, in _get_data
    success, data = self.

In [45]:
!python main.py \
  --mode train \
  --dataroot "/kaggle/input/clothe/clothes_tryon_dataset" \
  --name "FineTuned-MultiGPU" \
  --gpu_ids 0,1 \
  --person_rep_channels 7 \
  --batch_size 4 \
  --lr 0.00005 \
  --gmm_epochs 15 \
  --tom_epochs 5 \
  --checkpoint_dir "/kaggle/working/finetuned_checkpoints_multi_gpu" \
  --gmm_checkpoint "/kaggle/input/weights/pytorch/default/1/gmm_final (1).pth" \
  --tom_checkpoint "/kaggle/input/weights/pytorch/default/1/alias_final.pth"

Dataset created for mode: 'train'. Found 11647 samples.
Using 2 GPUs: [0, 1]
GMM checkpoint loaded from: /kaggle/input/weights/pytorch/default/1/gmm_final (1).pth
TOM checkpoint loaded from: /kaggle/input/weights/pytorch/default/1/alias_final.pth
Starting GMM training...
GMM - Epoch [1/15], Step [100/2912], Loss: 0.0284
GMM - Epoch [1/15], Step [200/2912], Loss: 0.0782
GMM - Epoch [1/15], Step [300/2912], Loss: 0.0850
GMM - Epoch [1/15], Step [400/2912], Loss: 0.0438
GMM - Epoch [1/15], Step [500/2912], Loss: 0.0312
GMM - Epoch [1/15], Step [600/2912], Loss: 0.0083
GMM - Epoch [1/15], Step [700/2912], Loss: 0.0168
GMM - Epoch [1/15], Step [800/2912], Loss: 0.0111
GMM - Epoch [1/15], Step [900/2912], Loss: 0.0546
GMM - Epoch [1/15], Step [1000/2912], Loss: 0.0093
GMM - Epoch [1/15], Step [1100/2912], Loss: 0.0418
GMM - Epoch [1/15], Step [1200/2912], Loss: 0.0050
GMM - Epoch [1/15], Step [1300/2912], Loss: 0.0058
GMM - Epoch [1/15], Step [1400/2912], Loss: 0.0467
GMM - Epoch [1/15], Ste

In [32]:
!python main.py \
  --mode test \
  --dataroot "/kaggle/input/clothe/clothes_tryon_dataset" \
  --name "Results_Epoch_5" \
  --gpu_ids 0 \
  --person_rep_channels 7 \
  --result_dir "/kaggle/working/visualization_results" \
  --gmm_checkpoint "/kaggle/working/finetuned_checkpoints_multi_gpu/FineTuned-MultiGPU/gmm_epoch_5.pth" \
  --tom_checkpoint "/kaggle/working/finetuned_checkpoints_multi_gpu/FineTuned-MultiGPU/tom_epoch_5.pth"

Traceback (most recent call last):
  File "/kaggle/working/main.py", line 123, in <module>
    if opt.gmm_checkpoint:
       ^^^
NameError: name 'opt' is not defined. Did you mean: 'oct'?


In [48]:
# Cell 1: --- CONFIGURATION FOR VISUALIZATION ---
# Run this cell first.

import argparse

# We use argparse.Namespace to create a simple object that holds all our settings.
opt = argparse.Namespace(
    
    # --- Mode and Paths ---
    # Set the mode to 'test' to generate images
    mode="test", 
    
    # Path to the dataset to get test images from
    dataroot="/kaggle/input/clothe/clothes_tryon_dataset", 
    
    # Give a unique name for the output folder
    name="Visualization_From_FineTuned_Epoch_5", 
    
    # Where to save the final images
    result_dir="/kaggle/working/my_final_results2", 
    
    # --- Checkpoints to Load for Visualization ---
    # Path to YOUR new GMM weight file
    gmm_checkpoint="/kaggle/working/finetuned_checkpoints_multi_gpu/FineTuned-MultiGPU/gmm_epoch_15.pth", 
    
    # Path to YOUR new TOM weight file
    tom_checkpoint="/kaggle/working/finetuned_checkpoints_multi_gpu/FineTuned-MultiGPU/tom_epoch_5.pth",

    # --- Hardware and Model Shape ---
    gpu_ids="0", # Use a single GPU for testing
    person_rep_channels=7,
    cloth_channels=3,
    agnostic_channels=3,
    output_channels=3,

    # --- Data Loading Params (can be left as is for testing) ---
    batch_size_test=1,
    load_height=1024,
    load_width=768
)

print("✅ Configuration loaded successfully!")
print(f"   Mode set to: {opt.mode}")
print(f"   Loading GMM from: {opt.gmm_checkpoint}")
print(f"   Loading TOM from: {opt.tom_checkpoint}")

✅ Configuration loaded successfully!
   Mode set to: test
   Loading GMM from: /kaggle/working/finetuned_checkpoints_multi_gpu/FineTuned-MultiGPU/gmm_epoch_15.pth
   Loading TOM from: /kaggle/working/finetuned_checkpoints_multi_gpu/FineTuned-MultiGPU/tom_epoch_5.pth


In [49]:
# Cell 2: --- VISUALIZATION EXECUTION CELL ---
# Run this cell AFTER running the configuration cell above.

# 1. All necessary imports
import os
import torch
import torch.nn as nn
from dataset import VitonHDDataset
from models import GMM, TOM
from test import test # We only need the 'test' function

# 2. Helper function to clean model keys from multi-GPU training
def clean_dataparallel_keys(state_dict):
    new_state_dict = {}
    for k, v in state_dict.items():
        if k.startswith('module.'):
            name = k[7:]  # remove 'module.'
            new_state_dict[name] = v
        else:
            new_state_dict[k] = v
    return new_state_dict

# --- Main Logic Starts Here ---
# The 'opt' object is used directly from the cell above.

# 3. Setup GPU and create the results directory
os.makedirs(os.path.join(opt.result_dir, opt.name), exist_ok=True)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# 4. Load the test dataset
# We set is_train=False because opt.mode is 'test'
test_dataset = VitonHDDataset(opt, is_train=False) 
print(f"Test dataset loaded. Found {len(test_dataset)} samples.")

# 5. Initialize the models
gmm_model = GMM(opt).to(device)
tom_model = TOM(opt).to(device)

# 6. Load your fine-tuned weights
print("Loading fine-tuned weights...")
if opt.gmm_checkpoint:
    state_dict = torch.load(opt.gmm_checkpoint, map_location=device)
    state_dict = clean_dataparallel_keys(state_dict)
    gmm_model.load_state_dict(state_dict)
    print("   ✅ GMM weights loaded.")
    
if opt.tom_checkpoint:
    state_dict = torch.load(opt.tom_checkpoint, map_location=device)
    state_dict = clean_dataparallel_keys(state_dict)
    tom_model.load_state_dict(state_dict)
    print("   ✅ TOM weights loaded.")

# 7. Run the visualization function
print("\nStarting visualization...")
test(opt, gmm_model, tom_model, test_dataset)

print(f"\n✅ Visualization complete!")
print(f"Check for your results in the directory: {os.path.join(opt.result_dir, opt.name)}")

Test dataset loaded. Found 2032 samples.
Loading fine-tuned weights...
   ✅ GMM weights loaded.
   ✅ TOM weights loaded.

Starting visualization...
Processed and saved results for batch 1/2032
Processed and saved results for batch 2/2032
Processed and saved results for batch 3/2032
Processed and saved results for batch 4/2032
Processed and saved results for batch 5/2032
Processed and saved results for batch 6/2032
Processed and saved results for batch 7/2032
Processed and saved results for batch 8/2032
Processed and saved results for batch 9/2032
Processed and saved results for batch 10/2032
Processed and saved results for batch 11/2032
Processed and saved results for batch 12/2032
Processed and saved results for batch 13/2032
Processed and saved results for batch 14/2032
Processed and saved results for batch 15/2032
Processed and saved results for batch 16/2032
Processed and saved results for batch 17/2032
Processed and saved results for batch 18/2032
Processed and saved results for b

KeyboardInterrupt: 