In [None]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here's several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
for dirname, _, filenames in os.walk('/kaggle/input'):
    for filename in filenames:
        print(os.path.join(dirname, filename))

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

In [None]:
import os
import json
import torch
import numpy as np
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms
import cv2

class FashionVTONDataset(Dataset):
    """
    Custom PyTorch Dataset for the Fashion VTON task.
    This version dynamically finds the clothing label to be more robust.
    """
    def __init__(self, data_root, image_size=(256, 192), original_image_size=(768, 1024)):
        self.data_root = data_root
        self.image_size = image_size
        self.original_image_size = original_image_size
        
        self.image_dir = os.path.join(data_root, 'image')
        self.cloth_dir = os.path.join(data_root, 'cloth')
        self.cloth_mask_dir = os.path.join(data_root, 'cloth-mask')
        self.pose_dir = os.path.join(data_root, 'openpose_json')
        self.parse_dir = os.path.join(data_root, 'image-parse-v3')

        self.image_files = sorted([f for f in os.listdir(self.image_dir) if f.endswith(('.jpg', '.png'))])

        self.transform = transforms.Compose([
            transforms.Resize(self.image_size, interpolation=transforms.InterpolationMode.BILINEAR),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])
        
        self.mask_transform = transforms.Compose([
            transforms.Resize(self.image_size, interpolation=transforms.InterpolationMode.NEAREST),
            transforms.ToTensor()
        ])

    def __len__(self):
        return len(self.image_files)

    def find_upper_cloth_label(self, pose_keypoints, parse_array):
        """
        Dynamically finds the clothing label by looking at the most common
        pixel value in the torso area defined by pose keypoints.
        """
        # BODY_25 keypoint indices for the torso area
        torso_indices = [1, 2, 5, 8, 9, 12]  # Neck, Shoulders, MidHip, Hips
        
        visible_points = []
        for i in torso_indices:
            x, y, conf = pose_keypoints[i]
            if conf > 0.1:  # Use only confident keypoints
                visible_points.append((int(x), int(y)))
        
        # If not enough points to define an area, return a fallback
        if len(visible_points) < 3:
            return 5  # Fallback to the original hardcoded label

        # Define a bounding box around the torso
        x_coords, y_coords = zip(*visible_points)
        min_x, max_x = min(x_coords), max(x_coords)
        min_y, max_y = min(y_coords), max(y_coords)

        # Ensure the bounding box has a valid area
        if max_x <= min_x or max_y <= min_y:
            return 5

        # Crop the parse map to the torso area
        torso_parse_area = parse_array[min_y:max_y, min_x:max_x]

        # Find the most frequent non-zero label
        unique_labels, counts = np.unique(torso_parse_area, return_counts=True)
        non_zero_mask = (unique_labels != 0) # Ignore background
        
        if np.any(non_zero_mask):
            # Get the label with the highest count among non-zero labels
            mode_label = unique_labels[non_zero_mask][np.argmax(counts[non_zero_mask])]
            return mode_label
        else:
            return 5 # Fallback if only background is found

    def __getitem__(self, idx):
        image_name = self.image_files[idx]
        base_name = os.path.splitext(image_name)[0]

        # Load images
        person_image = Image.open(os.path.join(self.image_dir, image_name)).convert('RGB')
        cloth_image = Image.open(os.path.join(self.cloth_dir, image_name)).convert('RGB')
        cloth_mask = Image.open(os.path.join(self.cloth_mask_dir, image_name)).convert('L')

        # Load pose data
        pose_path = os.path.join(self.pose_dir, f"{base_name}_keypoints.json")
        try:
            with open(pose_path, 'r') as f:
                pose_data = json.load(f)
            pose_keypoints = np.array(pose_data['people'][0]['pose_keypoints_2d']).reshape(-1, 3)
        except (FileNotFoundError, IndexError):
            pose_keypoints = np.zeros((25, 3), dtype=np.float32)

        # Load segmentation map (at original resolution for label finding)
        parse_path = os.path.join(self.parse_dir, f"{base_name}.png")
        parse_image_orig = Image.open(parse_path).convert('L')
        parse_array_orig = np.array(parse_image_orig)

        # --- Use the new robust method to find the clothing label ---
        upper_cloth_label = self.find_upper_cloth_label(pose_keypoints, parse_array_orig)
        
        # Resize parse map for creating masks
        parse_array_resized = cv2.resize(parse_array_orig, self.image_size[::-1], interpolation=cv2.INTER_NEAREST)

        # Create masks using the dynamically found label
        person_cloth_mask = (parse_array_resized == upper_cloth_label).astype(np.float32)
        
        # --- The rest of the processing remains the same ---
        person_image_tensor = self.transform(person_image)
        cloth_image_tensor = self.transform(cloth_image)
        cloth_mask_tensor = self.mask_transform(cloth_mask)
        pose_map = self.create_pose_map(pose_keypoints) # Assuming create_pose_map is part of this class
        pose_map_tensor = torch.from_numpy(pose_map).float()
        
        blurred_mask = cv2.GaussianBlur(person_cloth_mask, (5, 5), 0)
        blurred_mask_tensor = torch.from_numpy(blurred_mask).unsqueeze(0)
        
        agnostic_person_tensor = person_image_tensor * (1 - blurred_mask_tensor)
        warped_cloth_tensor = person_image_tensor * torch.from_numpy(person_cloth_mask).unsqueeze(0)

        return {
            'person_image': person_image_tensor,
            'cloth_image': cloth_image_tensor,
            'cloth_mask': cloth_mask_tensor,
            'agnostic_person': agnostic_person_tensor,
            'pose_map': pose_map_tensor,
            'warped_cloth': warped_cloth_tensor
        }

    def create_pose_map(self, keypoints):
        h, w = self.image_size
        orig_w, orig_h = self.original_image_size
        num_keypoints = keypoints.shape[0]
        pose_map = np.zeros((num_keypoints, h, w), dtype=np.float32)
        
        for i, point in enumerate(keypoints):
            if point[2] > 0.1:
                x = int(point[0] * w / orig_w)
                y = int(point[1] * h / orig_h)
                if 0 <= x < w and 0 <= y < h:
                    cv2.circle(pose_map[i], (x, y), radius=3, color=1, thickness=-1)
        return pose_map

In [None]:
import torch
# Corrected import line:
from torch.utils.data import DataLoader 
import matplotlib.pyplot as plt
import numpy as np
# This transform is just for visualization, so we define it locally here
from torchvision import transforms 

def tensor_to_pil(tensor):
    """Converts a [-1, 1] tensor to a PIL Image for visualization."""
    tensor = (tensor + 1) / 2
    tensor = tensor.clamp(0, 1)
    # Ensure tensor is on CPU before converting to numpy/PIL
    return transforms.ToPILImage()(tensor.cpu())

# --- Configuration ---
DATA_ROOT = '/kaggle/input/clothe/clothes_tryon_dataset/train'
BATCH_SIZE = 1
ORIGINAL_SIZE = (768, 1024) 

# --- Create Dataset and DataLoader ---
# The class name is FashionVTONDataset
dataset = FashionVTONDataset(data_root=DATA_ROOT, original_image_size=ORIGINAL_SIZE)
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)

# --- Fetch and Visualize One Sample ---
# Make sure to handle the case where the dataloader might be empty
try:
    sample = next(iter(dataloader))
    print("Successfully loaded a sample from the DataLoader.")
except StopIteration:
    print("DataLoader is empty. Please check your data directory and file names.")
    exit()


print("Keys in the sample batch:", sample.keys())
print("Shape of person_image:", sample['person_image'].shape)
print("Shape of cloth_image:", sample['cloth_image'].shape)
print("Shape of pose_map:", sample['pose_map'].shape)
print("Shape of agnostic_person:", sample['agnostic_person'].shape)

# --- Visualization ---
fig, axs = plt.subplots(1, 5, figsize=(20, 4))
axs[0].imshow(tensor_to_pil(sample['person_image'][0]))
axs[0].set_title("Original Person")
axs[0].axis('off')

axs[1].imshow(tensor_to_pil(sample['cloth_image'][0]))
axs[1].set_title("Cloth Item")
axs[1].axis('off')

# Summing along the channel dimension to visualize the pose map
pose_visualization = np.sum(sample['pose_map'][0].cpu().numpy(), axis=0)
axs[2].imshow(pose_visualization, cmap='gray')
axs[2].set_title("Pose Map")
axs[2].axis('off')

axs[3].imshow(tensor_to_pil(sample['agnostic_person'][0]))
axs[3].set_title("Agnostic Person")
axs[3].axis('off')

axs[4].imshow(tensor_to_pil(sample['warped_cloth'][0]))
axs[4].set_title("Ground Truth Warped")
axs[4].axis('off')

plt.tight_layout()
plt.show()

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class GMM(nn.Module):
    """
    Geometric Matching Module.
    A U-Net-like architecture that takes a cloth image and a pose map
    and predicts a flow field to warp the cloth.
    """
    def __init__(self, in_channels_cloth=3, in_channels_pose=18, out_channels_flow=2):
        """
        Args:
            in_channels_cloth (int): Number of channels in the cloth image (3 for RGB).
            in_channels_pose (int): Number of channels in the pose map (e.g., 18 for 18 keypoints).
            out_channels_flow (int): Number of channels for the output flow (2 for x and y).
        """
        super(GMM, self).__init__()

        # Encoder part
        self.encoder1 = self.conv_block(in_channels_cloth + in_channels_pose, 64)
        self.encoder2 = self.conv_block(64, 128)
        self.encoder3 = self.conv_block(128, 256)
        self.encoder4 = self.conv_block(256, 512)
        self.pool = nn.MaxPool2d(2, 2)

        # Bottleneck
        self.bottleneck = self.conv_block(512, 1024)

        # Decoder part
        self.upconv4 = nn.ConvTranspose2d(1024, 512, kernel_size=2, stride=2)
        self.decoder4 = self.conv_block(1024, 512)
        
        self.upconv3 = nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2)
        self.decoder3 = self.conv_block(512, 256)
        
        self.upconv2 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
        self.decoder2 = self.conv_block(256, 128)
        
        self.upconv1 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
        self.decoder1 = self.conv_block(128, 64)

        # Final output layer
        # This layer predicts the 2-channel flow field.
        # The output is activated with tanh to keep values between -1 and 1.
        self.conv_out = nn.Conv2d(64, out_channels_flow, kernel_size=1)
        self.tanh = nn.Tanh()

    def conv_block(self, in_channels, out_channels):
        """Helper function for a standard convolutional block."""
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True)
        )

    def forward(self, cloth_image, pose_map):
        """
        Forward pass for the GMM.
        
        Args:
            cloth_image (Tensor): The clothing image tensor. Shape: [B, 3, H, W]
            pose_map (Tensor): The pose map tensor. Shape: [B, 18, H, W]
            
        Returns:
            Tensor: The predicted flow field. Shape: [B, 2, H, W]
        """
        # Concatenate inputs along the channel dimension
        x = torch.cat([cloth_image, pose_map], dim=1)

        # Encoder path
        e1 = self.encoder1(x)
        p1 = self.pool(e1)
        
        e2 = self.encoder2(p1)
        p2 = self.pool(e2)
        
        e3 = self.encoder3(p2)
        p3 = self.pool(e3)
        
        e4 = self.encoder4(p3)
        p4 = self.pool(e4)

        # Bottleneck
        b = self.bottleneck(p4)

        # Decoder path
        d4 = self.upconv4(b)
        d4 = torch.cat([d4, e4], dim=1)
        d4 = self.decoder4(d4)
        
        d3 = self.upconv3(d4)
        d3 = torch.cat([d3, e3], dim=1)
        d3 = self.decoder3(d3)
        
        d2 = self.upconv2(d3)
        d2 = torch.cat([d2, e2], dim=1)
        d2 = self.decoder2(d2)
        
        d1 = self.upconv1(d2)
        d1 = torch.cat([d1, e1], dim=1)
        d1 = self.decoder1(d1)

        # Output flow field
        flow_field = self.conv_out(d1)
        flow_field = self.tanh(flow_field)
        
        return flow_field


def warp_cloth_with_flow(cloth_image, flow_field):
    """
    Warps a cloth image using a predicted flow field.
    
    Args:
        cloth_image (Tensor): The clothing image to warp. Shape: [B, C, H, W]
        flow_field (Tensor): The predicted flow field. Shape: [B, 2, H, W]
        
    Returns:
        Tensor: The warped cloth image.
    """
    # The flow field from the network is in range [-1, 1].
    # grid_sample expects the flow field to be permuted to [B, H, W, 2]
    # where the last dimension contains (x, y) coordinates.
    flow_field = flow_field.permute(0, 2, 3, 1)
    
    # torch.nn.functional.grid_sample applies the warping.
    # 'align_corners=True' is important for consistency.
    warped_image = F.grid_sample(cloth_image, flow_field, mode='bilinear', padding_mode='zeros', align_corners=True)
    
    return warped_image

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision.utils import save_image
import os
from tqdm import tqdm
import torchvision.models as models

# This assumes the FashionVTONDataset, GMM, and warp_cloth_with_flow are defined
# from dataset import FashionVTONDataset
# from geometric_matching import GMM, warp_cloth_with_flow

class VGGPerceptualLoss(nn.Module):
    # ... (This class definition remains exactly the same as the previous version) ...
    def __init__(self, resize=True):
        super(VGGPerceptualLoss, self).__init__()
        vgg = models.vgg19(weights=models.VGG19_Weights.IMAGENET1K_V1).features
        self.vgg_layers = vgg[:35].eval()
        for param in self.vgg_layers.parameters(): param.requires_grad = False
        self.l1 = nn.L1Loss()
        self.transform = nn.functional.interpolate
        self.resize = resize
        self.register_buffer("mean", torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
        self.register_buffer("std", torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))
    def forward(self, pred, target):
        pred = (pred + 1) / 2; target = (target + 1) / 2
        pred = (pred - self.mean) / self.std; target = (target - self.mean) / self.std
        if self.resize:
            pred = self.transform(pred, mode='bilinear', size=(224, 224), align_corners=False)
            target = self.transform(target, mode='bilinear', size=(224, 224), align_corners=False)
        pred_features = self.vgg_layers(pred); target_features = self.vgg_layers(target)
        return self.l1(pred_features, target_features)

class TVLoss(nn.Module):
    # ... (This class definition remains unchanged) ...
    def __init__(self): super(TVLoss, self).__init__()
    def forward(self, x):
        batch_size, c, h, w = x.size(); tv_h = torch.pow(x[:,:,1:,:] - x[:,:,:-1,:], 2).sum(); tv_w = torch.pow(x[:,:,:,1:] - x[:,:,:,:-1], 2).sum()
        return (tv_h + tv_w) / (batch_size * c * h * w)


def main():
    """Main training loop for the GMM with MASKED L1 Loss and re-balanced weights."""
    # --- Setup & Config (Same as before) ---
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu"); print(f"Using device: {device}")
    DATA_ROOT = '/kaggle/input/clothe/clothes_tryon_dataset/train'; CHECKPOINT_DIR = '/kaggle/working/'; VISUALIZATION_DIR = '/kaggle/working/'
    BATCH_SIZE = 8; NUM_EPOCHS = 50; LEARNING_RATE = 2e-5; IMAGE_SIZE = (256, 192); ORIGINAL_IMAGE_SIZE = (768, 1024)
    os.makedirs(CHECKPOINT_DIR, exist_ok=True); os.makedirs(VISUALIZATION_DIR, exist_ok=True)

    # --- Data Loading (Same as before) ---
    train_dataset = FashionVTONDataset(data_root=DATA_ROOT, image_size=IMAGE_SIZE, original_image_size=ORIGINAL_IMAGE_SIZE)
    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=2, pin_memory=True)
    print(f"Dataset loaded with {len(train_dataset)} samples.")

    # --- Model, Optimizer, and Loss (Same as before) ---
    gmm = GMM(in_channels_pose=25).to(device)
    optimizer = optim.Adam(gmm.parameters(), lr=LEARNING_RATE, betas=(0.5, 0.999))
    l1_loss_fn = nn.L1Loss().to(device)
    perceptual_loss_fn = VGGPerceptualLoss().to(device)
    tv_loss_fn = TVLoss().to(device)
    
    # --- Resumption Logic (Same as before) ---
    start_epoch = 0; checkpoint_path = os.path.join(CHECKPOINT_DIR, 'gmm_latest.pth')
    if os.path.isfile(checkpoint_path):
        print(f"Resuming GMM training from checkpoint: {checkpoint_path}")
        checkpoint = torch.load(checkpoint_path); gmm.load_state_dict(checkpoint['model_state_dict']); optimizer.load_state_dict(checkpoint['optimizer_state_dict']); start_epoch = checkpoint['epoch'] + 1
        print(f"Resumed from Epoch {start_epoch}")
    else: print("Starting GMM training from scratch.")

    # --- Training Loop ---
    for epoch in range(start_epoch, NUM_EPOCHS):
        gmm.train(); epoch_loss = 0.0
        progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{NUM_EPOCHS}")

        for i, batch in enumerate(progress_bar):
            cloth_image = batch['cloth_image'].to(device)
            cloth_mask = batch['cloth_mask'].to(device) # We need the cloth mask now
            pose_map = batch['pose_map'].to(device)
            ground_truth_warped = batch['warped_cloth'].to(device)
            
            optimizer.zero_grad()
            predicted_flow = gmm(cloth_image, pose_map)
            predicted_warped = warp_cloth_with_flow(cloth_image, predicted_flow)
            
            # Warp the cloth mask as well to know where the cloth is in the warped image
            warped_cloth_mask = warp_cloth_with_flow(cloth_mask, predicted_flow)

            # =================== UPDATED LOSS CALCULATION ===================
            # 1. Masked L1 Loss: Only calculate loss on the cloth pixels
            loss_l1 = l1_loss_fn(predicted_warped * warped_cloth_mask, ground_truth_warped * warped_cloth_mask)
            
            # 2. Perceptual Loss (same as before)
            loss_p = perceptual_loss_fn(predicted_warped, ground_truth_warped)
            
            # 3. TV Loss (same as before)
            loss_tv = tv_loss_fn(predicted_flow)
            
            # 4. Re-balanced Total Loss: Give L1 loss a stronger weight
            total_loss = (10 * loss_l1) + loss_p + (0.5 * loss_tv)
            # ===============================================================

            total_loss.backward()
            torch.nn.utils.clip_grad_norm_(gmm.parameters(), max_norm=1.0)
            optimizer.step()
            
            epoch_loss += total_loss.item()
            progress_bar.set_postfix(loss=total_loss.item(), l1=loss_l1.item(), perceptual=loss_p.item())
            
            # (Visualization code is the same)
            if i == 0:
                visual_comparison = torch.cat([(cloth_image.cpu() + 1) / 2, (predicted_warped.cpu().detach() + 1) / 2, (ground_truth_warped.cpu() + 1) / 2], dim=0)
                save_image(visual_comparison, os.path.join(VISUALIZATION_DIR, f'epoch_{epoch+1}_comparison.png'), nrow=BATCH_SIZE)

        avg_epoch_loss = epoch_loss / len(train_loader)
        print(f"End of Epoch {epoch+1}/{NUM_EPOCHS}, Average Loss: {avg_epoch_loss:.4f}")

        # (Checkpointing code is the same)
        latest_checkpoint_state = {'epoch': epoch, 'model_state_dict': gmm.state_dict(), 'optimizer_state_dict': optimizer.state_dict()}
        torch.save(latest_checkpoint_state, checkpoint_path)

    final_model_path = os.path.join(CHECKPOINT_DIR, 'gmm_final.pth')
    torch.save(gmm.state_dict(), final_model_path)
    print(f"Final model saved to {final_model_path}")

if __name__ == '__main__':
    main()

In [8]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision.utils import save_image
import os
from tqdm import tqdm

# # Import our custom modules
# from dataset import FashionVTONDataset
# from geometric_matching import GMM, warp_cloth_with_flow

# --- Training Configuration ---
DATA_ROOT = '/kaggle/input/clothe/clothes_tryon_dataset/train'
CHECKPOINT_DIR = '/kaggle/working/'
VISUALIZATION_DIR = '/kaggle/working/'
BATCH_SIZE = 8 # Adjust based on your VRAM
NUM_EPOCHS = 5 # Start with 50 and increase if needed
LEARNING_RATE = 0.0001
IMAGE_SIZE = (256, 192) # Use a smaller size for faster training initially
ORIGINAL_IMAGE_SIZE = (768, 1024) # Adjust if your source is different

# Create directories if they don't exist
os.makedirs(CHECKPOINT_DIR, exist_ok=True)
os.makedirs(VISUALIZATION_DIR, exist_ok=True)


class TVLoss(nn.Module):
    """Total Variation Loss for flow field regularization."""
    def __init__(self):
        super(TVLoss, self).__init__()

    def forward(self, x):
        batch_size, c, h, w = x.size()
        tv_h = torch.pow(x[:,:,1:,:] - x[:,:,:-1,:], 2).sum()
        tv_w = torch.pow(x[:,:,:,1:] - x[:,:,:,:-1], 2).sum()
        return (tv_h + tv_w) / (batch_size * c * h * w)


def main():
    """Main training loop for the GMM."""
    # --- Setup ---
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # --- Data Loading ---
    print("Loading dataset...")
    train_dataset = FashionVTONDataset(
        data_root=DATA_ROOT,
        image_size=IMAGE_SIZE,
        original_image_size=ORIGINAL_IMAGE_SIZE
    )
    train_loader = DataLoader(
        train_dataset,
        batch_size=BATCH_SIZE,
        shuffle=True,
        num_workers=4,
        pin_memory=True
    )
    print(f"Dataset loaded with {len(train_dataset)} samples.")

    # --- Model, Loss, and Optimizer ---
    gmm = GMM(in_channels_pose=25).to(device)
    l1_loss = nn.L1Loss().to(device)
    tv_loss = TVLoss().to(device)
    optimizer = optim.Adam(gmm.parameters(), lr=LEARNING_RATE, betas=(0.5, 0.999)) # LEARNING_RATE is now lower

    # --- Training Loop ---
    print("Starting training...")
    for epoch in range(NUM_EPOCHS):
        gmm.train()
        epoch_loss = 0.0
        progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{NUM_EPOCHS}")

        for i, batch in enumerate(progress_bar):
            cloth_image = batch['cloth_image'].to(device)
            pose_map = batch['pose_map'].to(device)
            ground_truth_warped = batch['warped_cloth'].to(device)
            
            optimizer.zero_grad()

            predicted_flow = gmm(cloth_image, pose_map)
            predicted_warped = warp_cloth_with_flow(cloth_image, predicted_flow)

            loss_l1 = l1_loss(predicted_warped, ground_truth_warped)
            loss_tv = tv_loss(predicted_flow)
            total_loss = loss_l1 + 0.5 * loss_tv

            total_loss.backward()
            
            # Add gradient clipping as a safety measure
            torch.nn.utils.clip_grad_norm_(gmm.parameters(), max_norm=1.0)
            
            optimizer.step()
            
            epoch_loss += total_loss.item()
            progress_bar.set_postfix(loss=total_loss.item(), l1=loss_l1.item(), tv=loss_tv.item())
            
            if i == 0:
                cloth_image_vis = (cloth_image + 1) / 2
                predicted_warped_vis = (predicted_warped + 1) / 2
                ground_truth_vis = (ground_truth_warped + 1) / 2
                visual_comparison = torch.cat([cloth_image_vis, predicted_warped_vis, ground_truth_vis], dim=0)
                save_image(visual_comparison, os.path.join(VISUALIZATION_DIR, f'epoch_{epoch+1}_comparison.png'), nrow=BATCH_SIZE)

        avg_epoch_loss = epoch_loss / len(train_loader)
        print(f"End of Epoch {epoch+1}/{NUM_EPOCHS}, Average Loss: {avg_epoch_loss:.4f}")

        if (epoch + 1) % 5 == 0:
            checkpoint_path = os.path.join(CHECKPOINT_DIR, f'gmm_epoch_{epoch+1}.pth')
            torch.save(gmm.state_dict(), checkpoint_path)
            print(f"Checkpoint saved to {checkpoint_path}")

    print("Training finished.")
    final_model_path = os.path.join(CHECKPOINT_DIR, 'gmm_final.pth')
    torch.save(gmm.state_dict(), final_model_path)
    print(f"Final model saved to {final_model_path}")

if __name__ == '__main__':
    main()

Using device: cuda
Loading dataset...
Dataset loaded with 11647 samples.
Starting training...


Epoch 1/5: 100%|██████████| 1456/1456 [08:33<00:00,  2.83it/s, l1=0.4, loss=0.4, tv=4.29e-6]     


End of Epoch 1/5, Average Loss: 0.5723


Epoch 2/5: 100%|██████████| 1456/1456 [08:26<00:00,  2.87it/s, l1=0.608, loss=0.608, tv=5.42e-5] 


End of Epoch 2/5, Average Loss: 0.5744


Epoch 3/5: 100%|██████████| 1456/1456 [08:25<00:00,  2.88it/s, l1=0.648, loss=0.648, tv=0.000177]


End of Epoch 3/5, Average Loss: 0.5731


Epoch 4/5: 100%|██████████| 1456/1456 [08:24<00:00,  2.89it/s, l1=0.559, loss=0.559, tv=0.000269]


End of Epoch 4/5, Average Loss: 0.5705


Epoch 5/5: 100%|██████████| 1456/1456 [08:22<00:00,  2.90it/s, l1=0.632, loss=0.632, tv=0.000102]


End of Epoch 5/5, Average Loss: 0.5759
Checkpoint saved to /kaggle/working/gmm_epoch_5.pth
Training finished.
Final model saved to /kaggle/working/gmm_final.pth


In [12]:
# %%writefile train_gmm_ddp.py
# import torch
# import torch.nn as nn
# import torch.optim as optim
# from torch.utils.data import DataLoader, Dataset
# from torchvision.utils import save_image
# import os
# from tqdm import tqdm
# import torch.distributed as dist
# from torch.nn.parallel import DistributedDataParallel as DDP
# from torch.utils.data.distributed import DistributedSampler
# import torchvision.models as models
# import json
# import cv2
# import numpy as np
# from PIL import Image
# from torchvision import transforms
# import socket # NEW IMPORT
# import torch.nn.functional as F # <<< THIS IS THE FIX

# # ===================================================================
# # ALL CLASS DEFINITIONS (These are unchanged)
# # ===================================================================
# class FashionVTONDataset(Dataset):
#     def __init__(self, data_root, image_size=(256, 192), original_image_size=(768, 1024)):
#         self.data_root = data_root; self.image_size = image_size; self.original_image_size = original_image_size; self.image_dir = os.path.join(data_root, 'image'); self.cloth_dir = os.path.join(data_root, 'cloth'); self.cloth_mask_dir = os.path.join(data_root, 'cloth-mask'); self.pose_dir = os.path.join(data_root, 'openpose_json'); self.parse_dir = os.path.join(data_root, 'image-parse-v3'); self.image_files = sorted([f for f in os.listdir(self.image_dir) if f.endswith(('.jpg', '.png'))]); self.transform = transforms.Compose([transforms.Resize(self.image_size, interpolation=transforms.InterpolationMode.BILINEAR), transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]); self.mask_transform = transforms.Compose([transforms.Resize(self.image_size, interpolation=transforms.InterpolationMode.NEAREST), transforms.ToTensor()])
#     def __len__(self): return len(self.image_files)
#     def find_upper_cloth_label(self, pose_keypoints, parse_array):
#         torso_indices = [1, 2, 5, 8, 9, 12]; visible_points = []
#         for i in torso_indices:
#             if i < len(pose_keypoints): x, y, conf = pose_keypoints[i];
#             if conf > 0.1: visible_points.append((int(x), int(y)))
#         if len(visible_points) < 3: return 5
#         x_coords, y_coords = zip(*visible_points); min_x, max_x = min(x_coords), max(x_coords); min_y, max_y = min(y_coords), max(y_coords)
#         if max_x <= min_x or max_y <= min_y: return 5
#         torso_parse_area = parse_array[min_y:max_y, min_x:max_x]; unique_labels, counts = np.unique(torso_parse_area, return_counts=True); non_zero_mask = (unique_labels != 0)
#         if np.any(non_zero_mask): return unique_labels[non_zero_mask][np.argmax(counts[non_zero_mask])]
#         else: return 5
#     def __getitem__(self, idx):
#         image_name = self.image_files[idx]; base_name = os.path.splitext(image_name)[0]
#         person_image = Image.open(os.path.join(self.image_dir, image_name)).convert('RGB'); cloth_image = Image.open(os.path.join(self.cloth_dir, image_name)).convert('RGB'); cloth_mask = Image.open(os.path.join(self.cloth_mask_dir, image_name)).convert('L')
#         try:
#             with open(os.path.join(self.pose_dir, f"{base_name}_keypoints.json"), 'r') as f: pose_data = json.load(f)
#             pose_keypoints = np.array(pose_data['people'][0]['pose_keypoints_2d']).reshape(-1, 3)
#         except (FileNotFoundError, IndexError): pose_keypoints = np.zeros((25, 3), dtype=np.float32)
#         parse_path = os.path.join(self.parse_dir, f"{base_name}.png")
#         if not os.path.exists(parse_path): parse_path = os.path.join(self.parse_dir, f"{base_name}.jpg")
#         parse_array_orig = np.array(Image.open(parse_path).convert('L'))
#         upper_cloth_label = self.find_upper_cloth_label(pose_keypoints, parse_array_orig)
#         parse_array_resized = cv2.resize(parse_array_orig, self.image_size[::-1], interpolation=cv2.INTER_NEAREST)
#         person_cloth_mask = (parse_array_resized == upper_cloth_label).astype(np.float32)
#         person_image_tensor = self.transform(person_image); cloth_image_tensor = self.transform(cloth_image); cloth_mask_tensor = self.mask_transform(cloth_mask)
#         pose_map_tensor = torch.from_numpy(self.create_pose_map(pose_keypoints)).float()
#         blurred_mask_tensor = torch.from_numpy(cv2.GaussianBlur(person_cloth_mask, (5, 5), 0)).unsqueeze(0)
#         agnostic_person_tensor = person_image_tensor * (1 - blurred_mask_tensor)
#         warped_cloth_tensor = person_image_tensor * torch.from_numpy(person_cloth_mask).unsqueeze(0)
#         return {'person_image': person_image_tensor, 'cloth_image': cloth_image_tensor, 'cloth_mask': cloth_mask_tensor, 'agnostic_person': agnostic_person_tensor, 'pose_map': pose_map_tensor, 'warped_cloth': warped_cloth_tensor}
#     def create_pose_map(self, keypoints):
#         h, w = self.image_size; orig_w, orig_h = self.original_image_size; num_keypoints = keypoints.shape[0]
#         pose_map = np.zeros((num_keypoints, h, w), dtype=np.float32)
#         for i, point in enumerate(keypoints):
#             if point[2] > 0.1:
#                 x, y = int(point[0] * w / orig_w), int(point[1] * h / orig_h)
#                 if 0 <= x < w and 0 <= y < h: cv2.circle(pose_map[i], (x, y), radius=3, color=1, thickness=-1)
#         return pose_map

# class GMM(nn.Module):
#     def __init__(self, in_channels_cloth=3, in_channels_pose=25, out_channels_flow=2):
#         super(GMM, self).__init__(); self.encoder1 = self.conv_block(in_channels_cloth + in_channels_pose, 64); self.encoder2 = self.conv_block(64, 128); self.encoder3 = self.conv_block(128, 256); self.encoder4 = self.conv_block(256, 512); self.pool = nn.MaxPool2d(2, 2); self.bottleneck = self.conv_block(512, 1024); self.upconv4 = nn.ConvTranspose2d(1024, 512, kernel_size=2, stride=2); self.decoder4 = self.conv_block(1024, 512); self.upconv3 = nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2); self.decoder3 = self.conv_block(512, 256); self.upconv2 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2); self.decoder2 = self.conv_block(256, 128); self.upconv1 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2); self.decoder1 = self.conv_block(128, 64); self.conv_out = nn.Conv2d(64, out_channels_flow, kernel_size=1); self.tanh = nn.Tanh()
#     def conv_block(self, in_channels, out_channels): return nn.Sequential(nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1), nn.ReLU(inplace=True), nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1), nn.ReLU(inplace=True))
#     def forward(self, cloth_image, pose_map):
#         x = torch.cat([cloth_image, pose_map], dim=1); e1 = self.encoder1(x); p1 = self.pool(e1); e2 = self.encoder2(p1); p2 = self.pool(e2); e3 = self.encoder3(p2); p3 = self.pool(e3); e4 = self.encoder4(p3); p4 = self.pool(e4); b = self.bottleneck(p4); d4 = self.upconv4(b); d4 = torch.cat([d4, e4], dim=1); d4 = self.decoder4(d4); d3 = self.upconv3(d4); d3 = torch.cat([d3, e3], dim=1); d3 = self.decoder3(d3); d2 = self.upconv2(d3); d2 = torch.cat([d2, e2], dim=1); d2 = self.decoder2(d2); d1 = self.upconv1(d2); d1 = torch.cat([d1, e1], dim=1); d1 = self.decoder1(d1); flow_field = self.conv_out(d1); flow_field = self.tanh(flow_field); return flow_field

# def warp_cloth_with_flow(cloth_image, flow_field):
#     flow_field = flow_field.permute(0, 2, 3, 1); warped_image = F.grid_sample(cloth_image, flow_field, mode='bilinear', padding_mode='zeros', align_corners=True); return warped_image

# class VGGPerceptualLoss(nn.Module):
#     def __init__(self, resize=True):
#         super(VGGPerceptualLoss, self).__init__(); vgg = models.vgg19(weights=models.VGG19_Weights.IMAGENET1K_V1).features; self.vgg_layers = vgg[:35].eval();
#         for param in self.vgg_layers.parameters(): param.requires_grad = False
#         self.l1 = nn.L1Loss(); self.transform = nn.functional.interpolate; self.resize = resize
#         self.register_buffer("mean", torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)); self.register_buffer("std", torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))
#     def forward(self, pred, target):
#         pred = (pred + 1) / 2; target = (target + 1) / 2; pred = (pred - self.mean) / self.std; target = (target - self.mean) / self.std
#         if self.resize: pred = self.transform(pred, mode='bilinear', size=(224, 224), align_corners=False); target = self.transform(target, mode='bilinear', size=(224, 224), align_corners=False)
#         pred_features = self.vgg_layers(pred); target_features = self.vgg_layers(target); return self.l1(pred_features, target_features)

# class TVLoss(nn.Module):
#     def __init__(self): super(TVLoss, self).__init__()
#     def forward(self, x):
#         batch_size, c, h, w = x.size(); tv_h = torch.pow(x[:,:,1:,:] - x[:,:,:-1,:], 2).sum(); tv_w = torch.pow(x[:,:,:,1:] - x[:,:,:,:-1], 2).sum()
#         return (tv_h + tv_w) / (batch_size * c * h * w)

# # --- CORRECTED Distributed Training Setup ---
# def setup(rank, world_size):
#     # Try to get the master address from the environment variables set by torchrun
#     master_addr = os.environ.get("MASTER_ADDR", "localhost")
#     master_port = os.environ.get("MASTER_PORT", "12355")
    
#     # If the master address is still localhost, try to find the actual IP.
#     # This is a robust fallback for environments like Kaggle.
#     if master_addr == "localhost":
#         try:
#             hostname = socket.gethostname()
#             master_addr = socket.gethostbyname(hostname)
#         except socket.gaierror:
#             # If that fails, fall back to the loopback address.
#             master_addr = "127.0.0.1"

#     os.environ['MASTER_ADDR'] = master_addr
#     os.environ['MASTER_PORT'] = master_port
    
#     if rank == 0:
#         print(f"Initializing process group... MASTER_ADDR={master_addr}, MASTER_PORT={master_port}")
        
#     dist.init_process_group("nccl", rank=rank, world_size=world_size)

# def cleanup():
#     dist.destroy_process_group()

# # --- Main Training Function (Unchanged) ---
# def train(rank, world_size):
#     setup(rank, world_size)
    
#     # --- Configuration ---
#     DATA_ROOT = '/kaggle/input/clothe/clothes_tryon_dataset/train'
#     CHECKPOINT_DIR = '/kaggle/working/'
#     VISUALIZATION_DIR = '/kaggle/working/'
#     BATCH_SIZE = 8; NUM_EPOCHS = 50; LEARNING_RATE = 2e-5
#     IMAGE_SIZE = (256, 192); ORIGINAL_IMAGE_SIZE = (768, 1024); ACCUMULATION_STEPS = 4

#     # --- Setup Device and Model ---
#     torch.cuda.set_device(rank)
#     gmm = GMM(in_channels_pose=25).to(rank)
#     gmm = DDP(gmm, device_ids=[rank], find_unused_parameters=True) # find_unused_parameters can help with complex models
    
#     optimizer = optim.Adam(gmm.parameters(), lr=LEARNING_RATE, betas=(0.5, 0.999))
#     l1_loss_fn = nn.L1Loss().to(rank); perceptual_loss_fn = VGGPerceptualLoss().to(rank); tv_loss_fn = TVLoss().to(rank)
#     scaler = torch.cuda.amp.GradScaler()

#     # --- Resumption Logic ---
#     start_epoch = 0; checkpoint_path = os.path.join(CHECKPOINT_DIR, 'gmm_latest_ddp.pth')
#     if os.path.isfile(checkpoint_path) and rank == 0:
#         print(f"Loading checkpoint: {checkpoint_path}")
#         checkpoint = torch.load(checkpoint_path, map_location={'cuda:0': f'cuda:{rank}'})
#         gmm.module.load_state_dict(checkpoint['model_state_dict'])
#         optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
#         scaler.load_state_dict(checkpoint['scaler_state_dict'])
#         start_epoch = checkpoint['epoch'] + 1
#         print(f"Resumed from Epoch {start_epoch}")
    
#     dist.barrier() 

#     # --- Data Loading with Distributed Sampler ---
#     dataset = FashionVTONDataset(data_root=DATA_ROOT, image_size=IMAGE_SIZE, original_image_size=ORIGINAL_IMAGE_SIZE)
#     sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank, shuffle=True)
#     loader = DataLoader(dataset, batch_size=BATCH_SIZE, sampler=sampler, num_workers=4, pin_memory=True)

#     # --- Training Loop ---
#     for epoch in range(start_epoch, NUM_EPOCHS):
#         sampler.set_epoch(epoch); gmm.train()
#         progress_bar = tqdm(loader, desc=f"Epoch {epoch+1}/{NUM_EPOCHS}", disable=(rank != 0))

#         for i, batch in enumerate(progress_bar):
#             cloth_image = batch['cloth_image'].to(rank); cloth_mask = batch['cloth_mask'].to(rank)
#             pose_map = batch['pose_map'].to(rank); ground_truth_warped = batch['warped_cloth'].to(rank)
            
#             with torch.cuda.amp.autocast():
#                 predicted_flow = gmm(cloth_image, pose_map)
#                 predicted_warped = warp_cloth_with_flow(cloth_image, predicted_flow)
#                 warped_cloth_mask = warp_cloth_with_flow(cloth_mask, predicted_flow)
#                 loss_l1 = l1_loss_fn(predicted_warped * warped_cloth_mask, ground_truth_warped * warped_cloth_mask)
#                 loss_p = perceptual_loss_fn(predicted_warped, ground_truth_warped)
#                 loss_tv = tv_loss_fn(predicted_flow)
#                 total_loss = (10 * loss_l1) + loss_p + (0.5 * loss_tv)
#                 total_loss = total_loss / ACCUMULATION_STEPS

#             scaler.scale(total_loss).backward()
            
#             if (i + 1) % ACCUMULATION_STEPS == 0:
#                 scaler.step(optimizer); scaler.update(); optimizer.zero_grad()
            
#             if rank == 0: progress_bar.set_postfix(loss=total_loss.item() * ACCUMULATION_STEPS)
        
#         if rank == 0:
#             latest_checkpoint_state = {
#                 'epoch': epoch, 'model_state_dict': gmm.module.state_dict(),
#                 'optimizer_state_dict': optimizer.state_dict(), 'scaler_state_dict': scaler.state_dict(),
#             }
#             torch.save(latest_checkpoint_state, checkpoint_path)
            
#             with torch.no_grad():
#                 vis_batch = torch.cat([(cloth_image.cpu() + 1) / 2, (predicted_warped.cpu().detach() + 1) / 2, (ground_truth_warped.cpu() + 1) / 2], dim=0)
#                 save_image(vis_batch, os.path.join(VISUALIZATION_DIR, f'epoch_{epoch+1}_comparison.png'), nrow=BATCH_SIZE)

#             print(f"Epoch {epoch+1} finished. Checkpoint saved.")

#     if rank == 0:
#         final_model_path = os.path.join(CHECKPOINT_DIR, 'gmm_final.pth')
#         torch.save(gmm.module.state_dict(), final_model_path)
#         print(f"Final model saved to {final_model_path}")
        
#     cleanup()

# if __name__ == '__main__':
#     world_size = torch.cuda.device_count()
#     if world_size < 2:
#         print("Distributed training requires at least 2 GPUs.")
#     else:
#         # Use torch.multiprocessing.spawn which is often more stable in notebooks than torchrun
#         torch.multiprocessing.spawn(train, args=(world_size,), nprocs=world_size, join=True)

In [11]:
!python train_gmm_ddp.py

Found 2 GPUs. Starting DDP with file sync at /kaggle/working/ddp_sync_file
Initializing process group with: file:///kaggle/working/ddp_sync_file
  scaler = torch.cuda.amp.GradScaler(); start_epoch = 0; checkpoint_path = os.path.join(CHECKPOINT_DIR, 'gmm_latest_ddp.pth')
  scaler = torch.cuda.amp.GradScaler(); start_epoch = 0; checkpoint_path = os.path.join(CHECKPOINT_DIR, 'gmm_latest_ddp.pth')
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
Epoch 1/50: 100%|██████████████████| 728/728 [04:18<00:00,  2.82it/s, loss=8.14]
Epoch 1 finished. Checkpoint saved.
Epoch 2/50: 100%|██████████████████| 728/728 [03:55<00:00,  3.09it/s, loss=7.64]
Epoch 2 finished. Checkpoint saved.
Epoch 3/50: 100%|██████████████████| 728/728 [03:54<00:00,  3.10it/s, loss=6.68]
Epoch 3 finished. Checkpoint saved.
Epoch 4/50: 100%|██████████████████| 728/728 [03:53<00:00,  3.12it/s, loss=6.15]
Epoch 4 finished. Checkpoint saved.
Epoch 5/50: 100%|██████████████████| 728/728 [04:00<00:00,  3.03it/

In [10]:
%%writefile train_gmm_ddp.py
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torchvision.utils import save_image
import os
from tqdm import tqdm
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data.distributed import DistributedSampler
import torchvision.models as models
import json
import cv2
import numpy as np
from PIL import Image
from torchvision import transforms
import time # Import time for the wait mechanism
import torch.nn.functional as F # <<< THIS IS THE FIX

# ===================================================================
# ALL CLASS DEFINITIONS (These are unchanged)
# ===================================================================
class FashionVTONDataset(Dataset):
    def __init__(self, data_root, image_size=(256, 192), original_image_size=(768, 1024)):
        self.data_root = data_root; self.image_size = image_size; self.original_image_size = original_image_size; self.image_dir = os.path.join(data_root, 'image'); self.cloth_dir = os.path.join(data_root, 'cloth'); self.cloth_mask_dir = os.path.join(data_root, 'cloth-mask'); self.pose_dir = os.path.join(data_root, 'openpose_json'); self.parse_dir = os.path.join(data_root, 'image-parse-v3'); self.image_files = sorted([f for f in os.listdir(self.image_dir) if f.endswith(('.jpg', '.png'))]); self.transform = transforms.Compose([transforms.Resize(self.image_size, interpolation=transforms.InterpolationMode.BILINEAR), transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]); self.mask_transform = transforms.Compose([transforms.Resize(self.image_size, interpolation=transforms.InterpolationMode.NEAREST), transforms.ToTensor()])
    def __len__(self): return len(self.image_files)
    def find_upper_cloth_label(self, pose_keypoints, parse_array):
        torso_indices = [1, 2, 5, 8, 9, 12]; visible_points = []
        for i in torso_indices:
            if i < len(pose_keypoints): x, y, conf = pose_keypoints[i];
            if conf > 0.1: visible_points.append((int(x), int(y)))
        if len(visible_points) < 3: return 5
        x_coords, y_coords = zip(*visible_points); min_x, max_x = min(x_coords), max(x_coords); min_y, max_y = min(y_coords), max(y_coords)
        if max_x <= min_x or max_y <= min_y: return 5
        torso_parse_area = parse_array[min_y:max_y, min_x:max_x]; unique_labels, counts = np.unique(torso_parse_area, return_counts=True); non_zero_mask = (unique_labels != 0)
        if np.any(non_zero_mask): return unique_labels[non_zero_mask][np.argmax(counts[non_zero_mask])]
        else: return 5
    def __getitem__(self, idx):
        image_name = self.image_files[idx]; base_name = os.path.splitext(image_name)[0]
        person_image = Image.open(os.path.join(self.image_dir, image_name)).convert('RGB'); cloth_image = Image.open(os.path.join(self.cloth_dir, image_name)).convert('RGB'); cloth_mask = Image.open(os.path.join(self.cloth_mask_dir, image_name)).convert('L')
        try:
            with open(os.path.join(self.pose_dir, f"{base_name}_keypoints.json"), 'r') as f: pose_data = json.load(f)
            pose_keypoints = np.array(pose_data['people'][0]['pose_keypoints_2d']).reshape(-1, 3)
        except (FileNotFoundError, IndexError): pose_keypoints = np.zeros((25, 3), dtype=np.float32)
        parse_path = os.path.join(self.parse_dir, f"{base_name}.png")
        if not os.path.exists(parse_path): parse_path = os.path.join(self.parse_dir, f"{base_name}.jpg")
        parse_array_orig = np.array(Image.open(parse_path).convert('L'))
        upper_cloth_label = self.find_upper_cloth_label(pose_keypoints, parse_array_orig)
        parse_array_resized = cv2.resize(parse_array_orig, self.image_size[::-1], interpolation=cv2.INTER_NEAREST)
        person_cloth_mask = (parse_array_resized == upper_cloth_label).astype(np.float32)
        person_image_tensor = self.transform(person_image); cloth_image_tensor = self.transform(cloth_image); cloth_mask_tensor = self.mask_transform(cloth_mask)
        pose_map_tensor = torch.from_numpy(self.create_pose_map(pose_keypoints)).float()
        blurred_mask_tensor = torch.from_numpy(cv2.GaussianBlur(person_cloth_mask, (5, 5), 0)).unsqueeze(0)
        agnostic_person_tensor = person_image_tensor * (1 - blurred_mask_tensor)
        warped_cloth_tensor = person_image_tensor * torch.from_numpy(person_cloth_mask).unsqueeze(0)
        return {'person_image': person_image_tensor, 'cloth_image': cloth_image_tensor, 'cloth_mask': cloth_mask_tensor, 'agnostic_person': agnostic_person_tensor, 'pose_map': pose_map_tensor, 'warped_cloth': warped_cloth_tensor}
    def create_pose_map(self, keypoints):
        h, w = self.image_size; orig_w, orig_h = self.original_image_size; num_keypoints = keypoints.shape[0]
        pose_map = np.zeros((num_keypoints, h, w), dtype=np.float32)
        for i, point in enumerate(keypoints):
            if point[2] > 0.1:
                x, y = int(point[0] * w / orig_w), int(point[1] * h / orig_h)
                if 0 <= x < w and 0 <= y < h: cv2.circle(pose_map[i], (x, y), radius=3, color=1, thickness=-1)
        return pose_map

class GMM(nn.Module):
    def __init__(self, in_channels_cloth=3, in_channels_pose=25, out_channels_flow=2):
        super(GMM, self).__init__(); self.encoder1 = self.conv_block(in_channels_cloth + in_channels_pose, 64); self.encoder2 = self.conv_block(64, 128); self.encoder3 = self.conv_block(128, 256); self.encoder4 = self.conv_block(256, 512); self.pool = nn.MaxPool2d(2, 2); self.bottleneck = self.conv_block(512, 1024); self.upconv4 = nn.ConvTranspose2d(1024, 512, kernel_size=2, stride=2); self.decoder4 = self.conv_block(1024, 512); self.upconv3 = nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2); self.decoder3 = self.conv_block(512, 256); self.upconv2 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2); self.decoder2 = self.conv_block(256, 128); self.upconv1 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2); self.decoder1 = self.conv_block(128, 64); self.conv_out = nn.Conv2d(64, out_channels_flow, kernel_size=1); self.tanh = nn.Tanh()
    def conv_block(self, in_channels, out_channels): return nn.Sequential(nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1), nn.ReLU(inplace=True), nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1), nn.ReLU(inplace=True))
    def forward(self, cloth_image, pose_map):
        x = torch.cat([cloth_image, pose_map], dim=1); e1 = self.encoder1(x); p1 = self.pool(e1); e2 = self.encoder2(p1); p2 = self.pool(e2); e3 = self.encoder3(p2); p3 = self.pool(e3); e4 = self.encoder4(p3); p4 = self.pool(e4); b = self.bottleneck(p4); d4 = self.upconv4(b); d4 = torch.cat([d4, e4], dim=1); d4 = self.decoder4(d4); d3 = self.upconv3(d4); d3 = torch.cat([d3, e3], dim=1); d3 = self.decoder3(d3); d2 = self.upconv2(d3); d2 = torch.cat([d2, e2], dim=1); d2 = self.decoder2(d2); d1 = self.upconv1(d2); d1 = torch.cat([d1, e1], dim=1); d1 = self.decoder1(d1); flow_field = self.conv_out(d1); flow_field = self.tanh(flow_field); return flow_field

def warp_cloth_with_flow(cloth_image, flow_field):
    flow_field = flow_field.permute(0, 2, 3, 1); warped_image = F.grid_sample(cloth_image, flow_field, mode='bilinear', padding_mode='zeros', align_corners=True); return warped_image

class VGGPerceptualLoss(nn.Module):
    def __init__(self, resize=True):
        super(VGGPerceptualLoss, self).__init__(); vgg = models.vgg19(weights=models.VGG19_Weights.IMAGENET1K_V1).features; self.vgg_layers = vgg[:35].eval();
        for param in self.vgg_layers.parameters(): param.requires_grad = False
        self.l1 = nn.L1Loss(); self.transform = nn.functional.interpolate; self.resize = resize
        self.register_buffer("mean", torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)); self.register_buffer("std", torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))
    def forward(self, pred, target):
        pred = (pred + 1) / 2; target = (target + 1) / 2; pred = (pred - self.mean) / self.std; target = (target - self.mean) / self.std
        if self.resize: pred = self.transform(pred, mode='bilinear', size=(224, 224), align_corners=False); target = self.transform(target, mode='bilinear', size=(224, 224), align_corners=False)
        pred_features = self.vgg_layers(pred); target_features = self.vgg_layers(target); return self.l1(pred_features, target_features)

class TVLoss(nn.Module):
    def __init__(self): super(TVLoss, self).__init__()
    def forward(self, x):
        batch_size, c, h, w = x.size(); tv_h = torch.pow(x[:,:,1:,:] - x[:,:,:-1,:], 2).sum(); tv_w = torch.pow(x[:,:,:,1:] - x[:,:,:,:-1], 2).sum()
        return (tv_h + tv_w) / (batch_size * c * h * w)

# --- CORRECTED Distributed Training Setup using File Store ---
def setup(rank, world_size, sync_file):
    # Set the initialization method to use the shared file system
    init_method = f'file://{sync_file}'
    
    if rank == 0:
        print(f"Initializing process group with: {init_method}")
        
    dist.init_process_group("nccl", init_method=init_method, rank=rank, world_size=world_size)

def cleanup():
    dist.destroy_process_group()

# --- Main Training Function (Unchanged) ---
def train(rank, world_size, sync_file):
    setup(rank, world_size, sync_file)
    
    # ... (The rest of the train function is exactly the same as the previous version) ...
    DATA_ROOT = '/kaggle/input/clothe/clothes_tryon_dataset/train'; CHECKPOINT_DIR = '/kaggle/working/'; VISUALIZATION_DIR = '/kaggle/working/'
    BATCH_SIZE = 8; NUM_EPOCHS = 50; LEARNING_RATE = 2e-5; IMAGE_SIZE = (256, 192); ORIGINAL_IMAGE_SIZE = (768, 1024); ACCUMULATION_STEPS = 4
    torch.cuda.set_device(rank); gmm = GMM(in_channels_pose=25).to(rank); gmm = DDP(gmm, device_ids=[rank], find_unused_parameters=True)
    optimizer = optim.Adam(gmm.parameters(), lr=LEARNING_RATE, betas=(0.5, 0.999)); l1_loss_fn = nn.L1Loss().to(rank); perceptual_loss_fn = VGGPerceptualLoss().to(rank); tv_loss_fn = TVLoss().to(rank)
    scaler = torch.cuda.amp.GradScaler(); start_epoch = 0; checkpoint_path = os.path.join(CHECKPOINT_DIR, 'gmm_latest_ddp.pth')
    if os.path.isfile(checkpoint_path) and rank == 0:
        print(f"Loading checkpoint: {checkpoint_path}"); checkpoint = torch.load(checkpoint_path, map_location={'cuda:0': f'cuda:{rank}'})
        gmm.module.load_state_dict(checkpoint['model_state_dict']); optimizer.load_state_dict(checkpoint['optimizer_state_dict']); scaler.load_state_dict(checkpoint['scaler_state_dict']); start_epoch = checkpoint['epoch'] + 1
        print(f"Resumed from Epoch {start_epoch}")
    dist.barrier(); dataset = FashionVTONDataset(data_root=DATA_ROOT, image_size=IMAGE_SIZE, original_image_size=ORIGINAL_IMAGE_SIZE)
    sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank, shuffle=True); loader = DataLoader(dataset, batch_size=BATCH_SIZE, sampler=sampler, num_workers=4, pin_memory=True)
    for epoch in range(start_epoch, NUM_EPOCHS):
        sampler.set_epoch(epoch); gmm.train()
        progress_bar = tqdm(loader, desc=f"Epoch {epoch+1}/{NUM_EPOCHS}", disable=(rank != 0))
        for i, batch in enumerate(progress_bar):
            cloth_image = batch['cloth_image'].to(rank); cloth_mask = batch['cloth_mask'].to(rank); pose_map = batch['pose_map'].to(rank); ground_truth_warped = batch['warped_cloth'].to(rank)
            with torch.cuda.amp.autocast():
                predicted_flow = gmm(cloth_image, pose_map); predicted_warped = warp_cloth_with_flow(cloth_image, predicted_flow); warped_cloth_mask = warp_cloth_with_flow(cloth_mask, predicted_flow)
                loss_l1 = l1_loss_fn(predicted_warped * warped_cloth_mask, ground_truth_warped * warped_cloth_mask); loss_p = perceptual_loss_fn(predicted_warped, ground_truth_warped); loss_tv = tv_loss_fn(predicted_flow)
                total_loss = (10 * loss_l1) + loss_p + (0.5 * loss_tv); total_loss = total_loss / ACCUMULATION_STEPS
            scaler.scale(total_loss).backward()
            if (i + 1) % ACCUMULATION_STEPS == 0: scaler.step(optimizer); scaler.update(); optimizer.zero_grad()
            if rank == 0: progress_bar.set_postfix(loss=total_loss.item() * ACCUMULATION_STEPS)
        if rank == 0:
            latest_checkpoint_state = {'epoch': epoch, 'model_state_dict': gmm.module.state_dict(), 'optimizer_state_dict': optimizer.state_dict(), 'scaler_state_dict': scaler.state_dict()}
            torch.save(latest_checkpoint_state, checkpoint_path)
            with torch.no_grad():
                vis_batch = torch.cat([(cloth_image.cpu() + 1) / 2, (predicted_warped.cpu().detach() + 1) / 2, (ground_truth_warped.cpu() + 1) / 2], dim=0)
                save_image(vis_batch, os.path.join(VISUALIZATION_DIR, f'epoch_{epoch+1}_comparison.png'), nrow=BATCH_SIZE)
            print(f"Epoch {epoch+1} finished. Checkpoint saved.")
    if rank == 0:
        final_model_path = os.path.join(CHECKPOINT_DIR, 'gmm_final.pth')
        torch.save(gmm.module.state_dict(), final_model_path)
        print(f"Final model saved to {final_model_path}")
    cleanup()

if __name__ == '__main__':
    world_size = torch.cuda.device_count()
    if world_size < 2:
        print("Distributed training requires at least 2 GPUs.")
    else:
        # Define the path for the sync file in a writable directory
        sync_file_path = os.path.join('/kaggle/working', 'ddp_sync_file')
        
        # Ensure the sync file doesn't exist from a previous failed run
        if os.path.exists(sync_file_path):
            os.remove(sync_file_path)
            
        print(f"Found {world_size} GPUs. Starting DDP with file sync at {sync_file_path}")
        # Pass the sync_file_path to the train function
        torch.multiprocessing.spawn(train, args=(world_size, sync_file_path), nprocs=world_size, join=True)

Overwriting train_gmm_ddp.py


In [1]:
# ===================================================================
# FINAL KAGGLE NOTEBOOK CELL
# Includes all class definitions and Option 2 (robust, resumable training)
# ===================================================================
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import os
import json
import cv2
import numpy as np
from PIL import Image
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision import transforms
from torchvision.utils import save_image
from tqdm import tqdm
from diffusers import AutoencoderKL, DDPMScheduler
from transformers import logging

# Suppress verbose messages from transformers
logging.set_verbosity_error()


# --- CLASS DEFINITIONS ---
class FashionVTONDataset(Dataset):
    def __init__(self, data_root, image_size=(256, 192), original_image_size=(768, 1024)):
        self.data_root = data_root; self.image_size = image_size; self.original_image_size = original_image_size; self.image_dir = os.path.join(data_root, 'image'); self.cloth_dir = os.path.join(data_root, 'cloth'); self.cloth_mask_dir = os.path.join(data_root, 'cloth-mask'); self.pose_dir = os.path.join(data_root, 'openpose_json'); self.parse_dir = os.path.join(data_root, 'image-parse-v3'); self.image_files = sorted([f for f in os.listdir(self.image_dir) if f.endswith(('.jpg', '.png'))]); self.transform = transforms.Compose([transforms.Resize(self.image_size, interpolation=transforms.InterpolationMode.BILINEAR), transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]); self.mask_transform = transforms.Compose([transforms.Resize(self.image_size, interpolation=transforms.InterpolationMode.NEAREST), transforms.ToTensor()])
    def __len__(self): return len(self.image_files)
    def find_upper_cloth_label(self, pose_keypoints, parse_array):
        torso_indices = [1, 2, 5, 8, 9, 12]; visible_points = []
        for i in torso_indices:
            if i < len(pose_keypoints):
                x, y, conf = pose_keypoints[i]
                if conf > 0.1: visible_points.append((int(x), int(y)))
        if len(visible_points) < 3: return 5
        x_coords, y_coords = zip(*visible_points); min_x, max_x = min(x_coords), max(x_coords); min_y, max_y = min(y_coords), max(y_coords)
        if max_x <= min_x or max_y <= min_y: return 5
        torso_parse_area = parse_array[min_y:max_y, min_x:max_x]; unique_labels, counts = np.unique(torso_parse_area, return_counts=True); non_zero_mask = (unique_labels != 0)
        if np.any(non_zero_mask): return unique_labels[non_zero_mask][np.argmax(counts[non_zero_mask])]
        else: return 5
    def __getitem__(self, idx):
        image_name = self.image_files[idx]; base_name = os.path.splitext(image_name)[0]
        person_image = Image.open(os.path.join(self.image_dir, image_name)).convert('RGB'); cloth_image = Image.open(os.path.join(self.cloth_dir, image_name)).convert('RGB'); cloth_mask = Image.open(os.path.join(self.cloth_mask_dir, image_name)).convert('L')
        try:
            with open(os.path.join(self.pose_dir, f"{base_name}_keypoints.json"), 'r') as f: pose_data = json.load(f)
            pose_keypoints = np.array(pose_data['people'][0]['pose_keypoints_2d']).reshape(-1, 3)
        except (FileNotFoundError, IndexError): pose_keypoints = np.zeros((25, 3), dtype=np.float32)
        parse_path = os.path.join(self.parse_dir, f"{base_name}.png")
        if not os.path.exists(parse_path): parse_path = os.path.join(self.parse_dir, f"{base_name}.jpg")
        parse_array_orig = np.array(Image.open(parse_path).convert('L'))
        upper_cloth_label = self.find_upper_cloth_label(pose_keypoints, parse_array_orig)
        parse_array_resized = cv2.resize(parse_array_orig, self.image_size[::-1], interpolation=cv2.INTER_NEAREST)
        person_cloth_mask = (parse_array_resized == upper_cloth_label).astype(np.float32)
        person_image_tensor = self.transform(person_image); cloth_image_tensor = self.transform(cloth_image); cloth_mask_tensor = self.mask_transform(cloth_mask)
        pose_map_tensor = torch.from_numpy(self.create_pose_map(pose_keypoints)).float()
        blurred_mask_tensor = torch.from_numpy(cv2.GaussianBlur(person_cloth_mask, (5, 5), 0)).unsqueeze(0)
        agnostic_person_tensor = person_image_tensor * (1 - blurred_mask_tensor)
        warped_cloth_tensor = person_image_tensor * torch.from_numpy(person_cloth_mask).unsqueeze(0)
        return {'person_image': person_image_tensor, 'cloth_image': cloth_image_tensor, 'cloth_mask': cloth_mask_tensor, 'agnostic_person': agnostic_person_tensor, 'pose_map': pose_map_tensor, 'warped_cloth': warped_cloth_tensor}
    def create_pose_map(self, keypoints):
        h, w = self.image_size; orig_w, orig_h = self.original_image_size; num_keypoints = keypoints.shape[0]
        pose_map = np.zeros((num_keypoints, h, w), dtype=np.float32)
        for i, point in enumerate(keypoints):
            if point[2] > 0.1:
                x, y = int(point[0] * w / orig_w), int(point[1] * h / orig_h)
                if 0 <= x < w and 0 <= y < h: cv2.circle(pose_map[i], (x, y), radius=3, color=1, thickness=-1)
        return pose_map

class GMM(nn.Module):
    def __init__(self, in_channels_cloth=3, in_channels_pose=25, out_channels_flow=2):
        super(GMM, self).__init__(); self.encoder1 = self.conv_block(in_channels_cloth + in_channels_pose, 64); self.encoder2 = self.conv_block(64, 128); self.encoder3 = self.conv_block(128, 256); self.encoder4 = self.conv_block(256, 512); self.pool = nn.MaxPool2d(2, 2); self.bottleneck = self.conv_block(512, 1024); self.upconv4 = nn.ConvTranspose2d(1024, 512, kernel_size=2, stride=2); self.decoder4 = self.conv_block(1024, 512); self.upconv3 = nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2); self.decoder3 = self.conv_block(512, 256); self.upconv2 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2); self.decoder2 = self.conv_block(256, 128); self.upconv1 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2); self.decoder1 = self.conv_block(128, 64); self.conv_out = nn.Conv2d(64, out_channels_flow, kernel_size=1); self.tanh = nn.Tanh()
    def conv_block(self, in_channels, out_channels): return nn.Sequential(nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1), nn.ReLU(inplace=True), nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1), nn.ReLU(inplace=True))
    def forward(self, cloth_image, pose_map):
        x = torch.cat([cloth_image, pose_map], dim=1); e1 = self.encoder1(x); p1 = self.pool(e1); e2 = self.encoder2(p1); p2 = self.pool(e2); e3 = self.encoder3(p2); p3 = self.pool(e3); e4 = self.encoder4(p3); p4 = self.pool(e4); b = self.bottleneck(p4); d4 = self.upconv4(b); d4 = torch.cat([d4, e4], dim=1); d4 = self.decoder4(d4); d3 = self.upconv3(d4); d3 = torch.cat([d3, e3], dim=1); d3 = self.decoder3(d3); d2 = self.upconv2(d3); d2 = torch.cat([d2, e2], dim=1); d2 = self.decoder2(d2); d1 = self.upconv1(d2); d1 = torch.cat([d1, e1], dim=1); d1 = self.decoder1(d1); flow_field = self.conv_out(d1); flow_field = self.tanh(flow_field); return flow_field

def warp_cloth_with_flow(cloth_image, flow_field):
    flow_field = flow_field.permute(0, 2, 3, 1); warped_image = F.grid_sample(cloth_image, flow_field, mode='bilinear', padding_mode='zeros', align_corners=True); return warped_image

class SinusoidalPositionEmbeddings(nn.Module):
    def __init__(self, dim): super().__init__(); self.dim = dim
    def forward(self, time): device = time.device; half_dim = self.dim // 2; embeddings = math.log(10000) / (half_dim - 1); embeddings = torch.exp(torch.arange(half_dim, device=device) * -embeddings); embeddings = time[:, None] * embeddings[None, :]; embeddings = torch.cat((embeddings.sin(), embeddings.cos()), dim=-1); return embeddings
class ResnetBlock(nn.Module):
    def __init__(self, in_channels, out_channels, *, time_emb_dim=None):
        super().__init__(); self.mlp = (nn.Sequential(nn.SiLU(), nn.Linear(time_emb_dim, out_channels)) if time_emb_dim is not None else None); self.block1 = nn.Sequential(nn.Conv2d(in_channels, out_channels, 3, padding=1), nn.GroupNorm(8, out_channels), nn.SiLU()); self.block2 = nn.Sequential(nn.Conv2d(out_channels, out_channels, 3, padding=1), nn.GroupNorm(8, out_channels), nn.SiLU()); self.res_conv = nn.Conv2d(in_channels, out_channels, 1) if in_channels != out_channels else nn.Identity()
    def forward(self, x, time_emb=None):
        h = self.block1(x)
        if self.mlp is not None and time_emb is not None:
            time_emb = self.mlp(time_emb); h = h + time_emb.unsqueeze(-1).unsqueeze(-1)
        h = self.block2(h); return h + self.res_conv(x)
class AttentionBlock(nn.Module):
    def __init__(self, channels): super().__init__(); self.gn = nn.GroupNorm(8, channels); self.qkv = nn.Conv2d(channels, channels * 3, 1); self.out = nn.Conv2d(channels, channels, 1)
    def forward(self, x, time_emb=None):
        b, c, h, w = x.shape; x_in = x; x = self.gn(x); x = self.qkv(x); q, k, v = torch.chunk(x, 3, dim=1); q = q.view(b, c, h * w); k = k.view(b, c, h * w); v = v.view(b, c, h * w); k = k.softmax(dim=-1); attn = torch.einsum("b c i, b c j -> b i j", q, k); out = torch.einsum("b i j, b c j -> b c i", attn, v); out = out.view(b, c, h, w); return self.out(out) + x_in
class ConditionalUNet(nn.Module):
    def __init__(self, in_channels, model_channels, out_channels, time_emb_dim=256, condition_channels=6):
        super().__init__(); self.time_mlp = nn.Sequential(SinusoidalPositionEmbeddings(time_emb_dim), nn.Linear(time_emb_dim, time_emb_dim), nn.ReLU()); self.init_conv = nn.Conv2d(in_channels + condition_channels, model_channels, kernel_size=3, padding=1); self.down1 = ResnetBlock(model_channels, 128, time_emb_dim=time_emb_dim); self.down2 = ResnetBlock(128, 128, time_emb_dim=time_emb_dim); self.down3 = ResnetBlock(128, 256, time_emb_dim=time_emb_dim); self.down4 = AttentionBlock(256); self.down5 = ResnetBlock(256, 256, time_emb_dim=time_emb_dim); self.down6 = ResnetBlock(256, 512, time_emb_dim=time_emb_dim); self.pool = nn.MaxPool2d(2); self.mid1 = ResnetBlock(512, 1024, time_emb_dim=time_emb_dim); self.mid_attn = AttentionBlock(1024); self.mid2 = ResnetBlock(1024, 512, time_emb_dim=time_emb_dim); self.up1 = nn.ConvTranspose2d(512, 256, 2, 2); self.up_res1 = ResnetBlock(512, 256, time_emb_dim=time_emb_dim); self.up_attn1 = AttentionBlock(256); self.up_res2 = ResnetBlock(256, 256, time_emb_dim=time_emb_dim); self.up2 = nn.ConvTranspose2d(256, 128, 2, 2); self.up_res3 = ResnetBlock(256, 128, time_emb_dim=time_emb_dim); self.up_res4 = ResnetBlock(128, 128, time_emb_dim=time_emb_dim); self.out_res = ResnetBlock(128 + model_channels, 64, time_emb_dim=time_emb_dim); self.out_conv = nn.Conv2d(64, out_channels, 1)
    def forward(self, x, time, condition):
        t = self.time_mlp(time); condition_downsampled = F.interpolate(condition, size=x.shape[2:], mode='bilinear', align_corners=False); x = torch.cat([x, condition_downsampled], dim=1); x = self.init_conv(x); r0 = x.clone(); x = self.down1(x, t); x = self.down2(x, t); r1 = x.clone(); x = self.pool(x); x = self.down3(x, t); x = self.down4(x, t); x = self.down5(x, t); r2 = x.clone(); x = self.pool(x); x = self.down6(x, t); x = self.mid1(x, t); x = self.mid_attn(x, t); x = self.mid2(x, t); x = self.up1(x); x = torch.cat([x, r2], dim=1); x = self.up_res1(x, t); x = self.up_attn1(x, t); x = self.up_res2(x, t); x = self.up2(x); x = torch.cat([x, r1], dim=1); x = self.up_res3(x, t); x = self.up_res4(x, t); x = torch.cat([x, r0], dim=1); x = self.out_res(x, t); return self.out_conv(x)


# --- CONFIGURATION & TRAINING SCRIPT ---
DATA_ROOT = '/kaggle/input/clothe/clothes_tryon_dataset/train'
OUTPUT_DIR = '/kaggle/working/'
GMM_CHECKPOINT_PATH = '/kaggle/input/gmm_5epoch/pytorch/default/1/gmm_final (works but low epochs).pth'
CHECKPOINT_DIR_DIFF = os.path.join(OUTPUT_DIR, 'checkpoints_diffusion')
VISUALIZATION_DIR_DIFF = os.path.join(OUTPUT_DIR, 'visualizations_diffusion')

BATCH_SIZE = 2; NUM_EPOCHS = 100; LEARNING_RATE = 1e-4; IMAGE_SIZE = (256, 192); ORIGINAL_IMAGE_SIZE = (768, 1024); VALIDATION_SPLIT = 0.1
VAE_MODEL_ID = "stabilityai/stable-diffusion-2-1-base"; VAE_SUBFOLDER = "vae"; NUM_TRAIN_TIMESTEPS = 1000

os.makedirs(CHECKPOINT_DIR_DIFF, exist_ok=True); os.makedirs(VISUALIZATION_DIR_DIFF, exist_ok=True)

@torch.no_grad()
def evaluate_and_visualize(epoch, unet, gmm, vae, noise_scheduler, val_loader, device):
    unet.eval(); gmm.eval(); val_loss = 0.0; progress_bar = tqdm(val_loader, desc="Validating", leave=False)
    for batch in progress_bar:
        person_image = batch['person_image'].to(device); cloth_image = batch['cloth_image'].to(device); pose_map = batch['pose_map'].to(device); agnostic_person = batch['agnostic_person'].to(device)
        flow = gmm(cloth_image, pose_map); warped_cloth = warp_cloth_with_flow(cloth_image, flow); condition = torch.cat([agnostic_person, warped_cloth], dim=1)
        latents = vae.encode(person_image).latent_dist.sample() * vae.config.scaling_factor
        noise = torch.randn_like(latents); timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (latents.shape[0],), device=device).long()
        noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps); predicted_noise = unet(noisy_latents, timesteps, condition); loss = F.mse_loss(predicted_noise, noise); val_loss += loss.item()
    avg_val_loss = val_loss / len(val_loader)
    sample_condition = condition[0:1]; original_vis_image = person_image[0:1]; sample_latents = torch.randn(1, 4, IMAGE_SIZE[0] // 8, IMAGE_SIZE[1] // 8).to(device)
    for t in tqdm(noise_scheduler.timesteps, desc="Generating sample", leave=False):
        pred_noise = unet(sample_latents, t.unsqueeze(0).to(device), sample_condition); sample_latents = noise_scheduler.step(pred_noise, t, sample_latents).prev_sample
    sample_latents = 1 / vae.config.scaling_factor * sample_latents; generated_image = vae.decode(sample_latents).sample
    original_vis = (original_vis_image + 1) / 2; condition_vis = (sample_condition[:, :3] + 1) / 2; warped_vis = (sample_condition[:, 3:] + 1) / 2; generated_vis = (generated_image + 1) / 2
    comparison = torch.cat([original_vis, condition_vis, warped_vis, generated_vis], dim=0); save_image(comparison, os.path.join(VISUALIZATION_DIR_DIFF, f'epoch_{epoch+1}_sample.png'), nrow=4)
    unet.train(); return avg_val_loss

def main():
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu"); print(f"Using device: {device}")
    vae = AutoencoderKL.from_pretrained(VAE_MODEL_ID, subfolder=VAE_SUBFOLDER).to(device)
    gmm = GMM(in_channels_pose=25).to(device)
    unet = ConditionalUNet(in_channels=4, model_channels=128, out_channels=4, condition_channels=6).to(device)
    optimizer = torch.optim.AdamW(unet.parameters(), lr=LEARNING_RATE)
    noise_scheduler = DDPMScheduler(num_train_timesteps=NUM_TRAIN_TIMESTEPS, beta_schedule='squaredcos_cap_v2')
    
    if not os.path.isfile(GMM_CHECKPOINT_PATH): print(f"FATAL ERROR: GMM checkpoint file not found at '{GMM_CHECKPOINT_PATH}'"); return
    gmm.load_state_dict(torch.load(GMM_CHECKPOINT_PATH, map_location=device))
    vae.requires_grad_(False); gmm.requires_grad_(False)

    start_epoch = 0; best_val_loss = float('inf')
    checkpoint_path = os.path.join(CHECKPOINT_DIR_DIFF, 'latest_checkpoint.pth')
    if os.path.isfile(checkpoint_path):
        print(f"Resuming training from checkpoint: {checkpoint_path}")
        checkpoint = torch.load(checkpoint_path)
        unet.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        start_epoch = checkpoint['epoch'] + 1
        best_val_loss = checkpoint['best_val_loss']
        print(f"Resumed from Epoch {start_epoch}, Best Val Loss: {best_val_loss:.4f}")
    else: print("Starting training from scratch.")

    full_dataset = FashionVTONDataset(data_root=DATA_ROOT, image_size=IMAGE_SIZE, original_image_size=ORIGINAL_IMAGE_SIZE)
    val_size = int(len(full_dataset) * VALIDATION_SPLIT); train_size = len(full_dataset) - val_size
    train_dataset, val_dataset = random_split(full_dataset, [train_size, val_size])
    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=2, pin_memory=True)
    val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=2, pin_memory=True)
    print(f"Dataset loaded: {train_size} training samples, {val_size} validation samples.")
    
    print(f"Starting training from epoch {start_epoch + 1}...")
    for epoch in range(start_epoch, NUM_EPOCHS):
        unet.train(); train_loss = 0.0
        progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{NUM_EPOCHS}")
        for batch in progress_bar:
            with torch.no_grad():
                person_image = batch['person_image'].to(device); cloth_image = batch['cloth_image'].to(device); pose_map = batch['pose_map'].to(device); agnostic_person = batch['agnostic_person'].to(device)
                flow = gmm(cloth_image, pose_map); warped_cloth = warp_cloth_with_flow(cloth_image, flow); condition = torch.cat([agnostic_person, warped_cloth], dim=1)
                latents = vae.encode(person_image).latent_dist.sample() * vae.config.scaling_factor
            noise = torch.randn_like(latents); timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (latents.shape[0],), device=device).long()
            noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
            optimizer.zero_grad(); predicted_noise = unet(noisy_latents, timesteps, condition); loss = F.mse_loss(predicted_noise, noise); loss.backward(); optimizer.step()
            train_loss += loss.item(); progress_bar.set_postfix(loss=loss.item())
        
        avg_train_loss = train_loss / len(train_loader)
        avg_val_loss = evaluate_and_visualize(epoch, unet, gmm, vae, noise_scheduler, val_loader, device)
        print(f"Epoch {epoch+1}/{NUM_EPOCHS} | Train Loss: {avg_train_loss:.4f} | Val Loss: {avg_val_loss:.4f}")

        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            best_model_path = os.path.join(CHECKPOINT_DIR_DIFF, 'unet_best.pth')
            torch.save(unet.state_dict(), best_model_path)
            print(f"🎉 New best model saved with validation loss: {best_val_loss:.4f}")

        latest_checkpoint_state = {'epoch': epoch, 'model_state_dict': unet.state_dict(), 'optimizer_state_dict': optimizer.state_dict(), 'best_val_loss': best_val_loss}
        torch.save(latest_checkpoint_state, checkpoint_path)
        print(f"Latest checkpoint updated at {checkpoint_path}")

    print("Training finished.")

if __name__ == '__main__':
    main()

2025-07-15 13:04:34.297754: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1752584674.666142      36 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1752584674.775957      36 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


Using device: cuda


config.json:   0%|          | 0.00/553 [00:00<?, ?B/s]

vae/diffusion_pytorch_model.safetensors:   0%|          | 0.00/335M [00:00<?, ?B/s]

Starting training from scratch.
Dataset loaded: 10483 training samples, 1164 validation samples.
Starting training from epoch 1...


Epoch 1/100:   1%|          | 63/5242 [00:12<17:10,  5.03it/s, loss=0.672]


KeyboardInterrupt: 

In [16]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

class SinusoidalPositionEmbeddings(nn.Module):
    """
    Module to generate sinusoidal time embeddings.
    Used to inform the U-Net of the current noise level (timestep).
    """
    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, time):
        device = time.device
        half_dim = self.dim // 2
        embeddings = math.log(10000) / (half_dim - 1)
        embeddings = torch.exp(torch.arange(half_dim, device=device) * -embeddings)
        embeddings = time[:, None] * embeddings[None, :]
        embeddings = torch.cat((embeddings.sin(), embeddings.cos()), dim=-1)
        return embeddings

class ResnetBlock(nn.Module):
    """A standard ResNet block with two convolutions."""
    def __init__(self, in_channels, out_channels, *, time_emb_dim=None):
        super().__init__()
        self.mlp = (
            nn.Sequential(nn.SiLU(), nn.Linear(time_emb_dim, out_channels))
            if time_emb_dim is not None
            else None
        )

        self.block1 = nn.Sequential(nn.Conv2d(in_channels, out_channels, 3, padding=1), nn.GroupNorm(8, out_channels), nn.SiLU())
        self.block2 = nn.Sequential(nn.Conv2d(out_channels, out_channels, 3, padding=1), nn.GroupNorm(8, out_channels), nn.SiLU())
        self.res_conv = nn.Conv2d(in_channels, out_channels, 1) if in_channels != out_channels else nn.Identity()

    def forward(self, x, time_emb=None):
        h = self.block1(x)
        if self.mlp is not None and time_emb is not None:
            time_emb = self.mlp(time_emb)
            h = h + time_emb.unsqueeze(-1).unsqueeze(-1)
        h = self.block2(h)
        return h + self.res_conv(x)

class AttentionBlock(nn.Module):
    """Self-attention block."""
    def __init__(self, channels):
        super().__init__()
        self.gn = nn.GroupNorm(8, channels)
        self.qkv = nn.Conv2d(channels, channels * 3, 1)
        self.out = nn.Conv2d(channels, channels, 1)

    class AttentionBlock(nn.Module):
    """Self-attention block."""
    def __init__(self, channels):
        super().__init__()
        self.gn = nn.GroupNorm(8, channels)
        self.qkv = nn.Conv2d(channels, channels * 3, 1)
        self.out = nn.Conv2d(channels, channels, 1)

    # CORRECTED LINE: Add the time_emb argument
    def forward(self, x, time_emb=None):
        b, c, h, w = x.shape
        x_in = x
        x = self.gn(x)
        x = self.qkv(x)
        q, k, v = torch.chunk(x, 3, dim=1)

        q = q.view(b, c, h * w)
        k = k.view(b, c, h * w)
        v = v.view(b, c, h * w)

        # The rest of the logic remains exactly the same
        k = k.softmax(dim=-1)
        attn = torch.einsum("b c i, b c j -> b i j", q, k)
        out = torch.einsum("b i j, b c j -> b c i", attn, v)
        out = out.view(b, c, h, w)
        return self.out(out) + x_in
class ConditionalUNet(nn.Module):
    """
    The main U-Net for the diffusion model.
    This version has the corrected upsampling path.
    """
    def __init__(self, in_channels, model_channels, out_channels, time_emb_dim=256, condition_channels=6):
        super().__init__()
        
        # --- Time embedding and Initial Conv (No change) ---
        self.time_mlp = nn.Sequential(
            SinusoidalPositionEmbeddings(time_emb_dim),
            nn.Linear(time_emb_dim, time_emb_dim),
            nn.ReLU()
        )
        self.init_conv = nn.Conv2d(in_channels + condition_channels, model_channels, kernel_size=3, padding=1)

        # --- Downsampling path (No change) ---
        self.down_blocks = nn.ModuleList([
            ResnetBlock(model_channels, 128, time_emb_dim=time_emb_dim),
            ResnetBlock(128, 128, time_emb_dim=time_emb_dim),
            nn.Conv2d(128, 128, 3, 2, 1),
            ResnetBlock(128, 256, time_emb_dim=time_emb_dim),
            AttentionBlock(256, time_emb_dim=time_emb_dim),
            ResnetBlock(256, 256, time_emb_dim=time_emb_dim),
            nn.Conv2d(256, 256, 3, 2, 1),
            ResnetBlock(256, 512, time_emb_dim=time_emb_dim),
        ])

        # --- Bottleneck (No change) ---
        self.mid_block1 = ResnetBlock(512, 1024, time_emb_dim=time_emb_dim)
        self.mid_attn = AttentionBlock(1024, time_emb_dim=time_emb_dim)
        self.mid_block2 = ResnetBlock(1024, 512, time_emb_dim=time_emb_dim)

        # =================== FIX STARTS HERE: CORRECTED UPSAMPLING ARCHITECTURE ===================
        #
        # Adjust the in_channels for the ResNet blocks to account for the concatenation
        # with the skip connection from the downsampling path.
        #
        self.up_blocks = nn.ModuleList([
            # Input to this block is 512 (from mid_block2) + 512 (from residual) = 1024
            ResnetBlock(1024, 256, time_emb_dim=time_emb_dim),
            nn.ConvTranspose2d(256, 256, 4, 2, 1), # Upsample
            
            # Input to this block is 256 (from above) + 256 (from residual) = 512
            ResnetBlock(512, 128, time_emb_dim=time_emb_dim),
            AttentionBlock(128, time_emb_dim=time_emb_dim),
            
            # Input to this block is 128 (from above) + 128 (from residual) = 256
            ResnetBlock(256, 128, time_emb_dim=time_emb_dim),
            nn.ConvTranspose2d(128, 128, 4, 2, 1), # Upsample
            
            # Input to this block is 128 (from above) + 128 (from residual) = 256
            ResnetBlock(256, 64, time_emb_dim=time_emb_dim),

            # Input to this block is 64 (from above) + 64 (from residual) = 128
            ResnetBlock(128, 64, time_emb_dim=time_emb_dim),
        ])

        self.final_res_block = ResnetBlock(model_channels + 64, model_channels, time_emb_dim=time_emb_dim)
        self.out_conv = nn.Conv2d(model_channels, out_channels, 1)

    def forward(self, x, time, condition):
        # Initial projection, time embeddings, and downsampling path (No change)
        t = self.time_mlp(time)
        condition_downsampled = F.interpolate(condition, size=x.shape[2:], mode='bilinear', align_corners=False)
        x = torch.cat([x, condition_downsampled], dim=1)
        x = self.init_conv(x)
        
        residuals = [x.clone()]
        for block in self.down_blocks:
            x = block(x, t)
            residuals.append(x)
            
        x = self.mid_block1(x, t)
        x = self.mid_attn(x, t)
        x = self.mid_block2(x, t)

        # =================== FIX STARTS HERE: CORRECTED FORWARD PASS LOGIC ===================
        #
        # Implement the correct U-Net upsampling logic:
        # 1. Process with ResNet/Attention blocks.
        # 2. If it's an upsampling layer, apply it.
        # 3. Concatenate with the skip connection AFTER upsampling.
        #
        for block in self.up_blocks:
            # Concatenate the skip connection from the downsampling path
            res = residuals.pop()
            x = torch.cat([x, res], dim=1)
            
            # Pass through the ResNet/Attention blocks
            x = block(x, t)
            
            # If the block is an upsampling layer, the architecture definition handles it.
            # We just need to ensure the logic flow is correct. Let's simplify the loop.
            # The architecture definition is the main fix. Let's rewrite the loop to match.
            
        # Let's use a more explicit loop that matches the corrected architecture
        x = self.up_blocks[0](torch.cat([x, residuals.pop()], dim=1), t) # ResNet: 1024 -> 256
        x = self.up_blocks[1](x) # Upsample: 256 -> 256
        
        x = self.up_blocks[2](torch.cat([x, residuals.pop()], dim=1), t) # ResNet: 512 -> 128
        x = self.up_blocks[3](x, t) # Attention
        x = self.up_blocks[4](torch.cat([x, residuals.pop()], dim=1), t) # ResNet: 256 -> 128
        x = self.up_blocks[5](x) # Upsample: 128 -> 128
        
        x = self.up_blocks[6](torch.cat([x, residuals.pop()], dim=1), t) # ResNet: 256 -> 64
        x = self.up_blocks[7](torch.cat([x, residuals.pop()], dim=1), t) # ResNet: 128 -> 64
        
        # --- Final block ---
        x = self.final_res_block(torch.cat([x, residuals.pop()], dim=1), t)
        return self.out_conv(x)
        
    def forward(self, x, time, condition):
        # x: noisy latents [B, 4, H/8, W/8]
        # time: timestep [B]
        # condition: concatenated agnostic_person and warped_cloth [B, 6, H, W]
        
        # Initial projection and time embeddings
        t = self.time_mlp(time)
        
        # We need to downsample the condition to match the latent space dimensions
        condition_downsampled = F.interpolate(condition, size=x.shape[2:], mode='bilinear', align_corners=False)
        x = torch.cat([x, condition_downsampled], dim=1)
        x = self.init_conv(x)
        
        # Store residuals for skip connections
        residuals = [x.clone()]
        
        # Downsampling
        for block in self.down_blocks:
            if isinstance(block, nn.Conv2d): # Downsampling layer
                x = block(x)
            else: # ResNet or Attention block
                x = block(x, t)
            residuals.append(x)
            
        # Bottleneck
        x = self.mid_block1(x, t)
        x = self.mid_attn(x)
        x = self.mid_block2(x, t)

        # Upsampling
        for block in self.up_blocks:
            if isinstance(block, nn.ConvTranspose2d): # Upsampling layer
                res = residuals.pop()
                x = torch.cat([x, res], dim=1)
                x = block(x)
            else: # ResNet or Attention block
                res = residuals.pop()
                x = torch.cat([x, res], dim=1)
                x = block(x, t)
        
        # Final block
        x = torch.cat([x, residuals.pop()], dim=1)
        x = self.final_res_block(x, t)
        
        return self.out_conv(x)

IndentationError: expected an indented block after class definition on line 54 (4015150922.py, line 55)

In [6]:
from torch.utils.data import DataLoader, random_split
from torchvision.utils import save_image
from tqdm import tqdm
from diffusers import AutoencoderKL, DDPMScheduler
from transformers import logging

# Suppress verbose messages from transformers
#logging.set_verbosity_error()

# =================== FIX STARTS HERE: CORRECT PATHS AND LOGIC ===================

# --- Training Configuration ---
# =================== FIX STARTS HERE: CORRECT PATHS FOR KAGGLE INPUT ===================

# --- Training Configuration ---
DATA_ROOT = '/kaggle/input/clothe/clothes_tryon_dataset/train'
OUTPUT_DIR = '/kaggle/working/' # The main directory for all our new outputs

# CORRECTED PATH: Point directly to the uploaded model file in your Kaggle dataset
GMM_CHECKPOINT_PATH = '/kaggle/input/gmm_5epoch/pytorch/default/1/gmm_final (works but low epochs).pth'

# Define subdirectories for our NEW diffusion model outputs in the working directory
CHECKPOINT_DIR_DIFF = os.path.join(OUTPUT_DIR, 'checkpoints_diffusion')
VISUALIZATION_DIR_DIFF = os.path.join(OUTPUT_DIR, 'visualizations_diffusion')

# --- The rest of your configuration ---
BATCH_SIZE = 2
NUM_EPOCHS = 100
LEARNING_RATE = 1e-4
IMAGE_SIZE = (256, 192)
ORIGINAL_IMAGE_SIZE = (768, 1024)
VALIDATION_SPLIT = 0.1
VAE_MODEL_ID = "stabilityai/stable-diffusion-2-1-base"
VAE_SUBFOLDER = "vae"
NUM_TRAIN_TIMESTEPS = 1000

# =================== FIX ENDS HERE ===================
# =================== FIX ENDS HERE ===================


# Create directories if they don't exist
os.makedirs(CHECKPOINT_DIR_DIFF, exist_ok=True)
os.makedirs(VISUALIZATION_DIR_DIFF, exist_ok=True)


@torch.no_grad()
def evaluate_and_visualize(epoch, unet, gmm, vae, noise_scheduler, val_loader, device):
    # ... (paste the full evaluate_and_visualize function here, it's correct) ...
    unet.eval(); gmm.eval(); val_loss = 0.0; progress_bar = tqdm(val_loader, desc="Validating", leave=False)
    for batch in progress_bar:
        person_image = batch['person_image'].to(device); cloth_image = batch['cloth_image'].to(device); pose_map = batch['pose_map'].to(device); agnostic_person = batch['agnostic_person'].to(device)
        flow = gmm(cloth_image, pose_map); warped_cloth = warp_cloth_with_flow(cloth_image, flow); condition = torch.cat([agnostic_person, warped_cloth], dim=1)
        latents = vae.encode(person_image).latent_dist.sample() * vae.config.scaling_factor
        noise = torch.randn_like(latents); timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (latents.shape[0],), device=device).long()
        noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps); predicted_noise = unet(noisy_latents, timesteps, condition); loss = F.mse_loss(predicted_noise, noise); val_loss += loss.item()
    avg_val_loss = val_loss / len(val_loader)
    sample_condition = condition[0:1]; original_vis_image = person_image[0:1]; sample_latents = torch.randn(1, 4, IMAGE_SIZE[0] // 8, IMAGE_SIZE[1] // 8).to(device)
    for t in tqdm(noise_scheduler.timesteps, desc="Generating sample", leave=False):
        pred_noise = unet(sample_latents, t.unsqueeze(0).to(device), sample_condition); sample_latents = noise_scheduler.step(pred_noise, t, sample_latents).prev_sample
    sample_latents = 1 / vae.config.scaling_factor * sample_latents; generated_image = vae.decode(sample_latents).sample
    original_vis = (original_vis_image + 1) / 2; condition_vis = (sample_condition[:, :3] + 1) / 2; warped_vis = (sample_condition[:, 3:] + 1) / 2; generated_vis = (generated_image + 1) / 2
    comparison = torch.cat([original_vis, condition_vis, warped_vis, generated_vis], dim=0); save_image(comparison, os.path.join(VISUALIZATION_DIR_DIFF, f'epoch_{epoch+1}_sample.png'), nrow=4)
    unet.train(); return avg_val_loss


def main():
    """Main training loop with validation and best model saving."""
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # --- Load Pre-trained Components ---
    print("Loading pre-trained VAE and GMM...")
    vae = AutoencoderKL.from_pretrained(VAE_MODEL_ID, subfolder=VAE_SUBFOLDER).to(device)
    
    # Check if the GMM checkpoint file exists
    if not os.path.isfile(GMM_CHECKPOINT_PATH):
        print(f"FATAL ERROR: GMM checkpoint file not found at '{GMM_CHECKPOINT_PATH}'")
        print("Please ensure the GMM model was trained and its output file 'gmm_final.pth' is in '/kaggle/working/'.")
        return # Use return instead of exit() in notebooks
    
    print(f"Found GMM checkpoint at: {GMM_CHECKPOINT_PATH}")
    gmm = GMM(in_channels_pose=25).to(device)
    gmm.load_state_dict(torch.load(GMM_CHECKPOINT_PATH, map_location=device))
    
    vae.requires_grad_(False)
    gmm.requires_grad_(False)
    noise_scheduler = DDPMScheduler(num_train_timesteps=NUM_TRAIN_TIMESTEPS, beta_schedule='squaredcos_cap_v2')

    # --- Data Loading and Splitting ---
    print("Loading and splitting dataset...")
    full_dataset = FashionVTONDataset(
        data_root=DATA_ROOT, image_size=IMAGE_SIZE, original_image_size=ORIGINAL_IMAGE_SIZE
    )
    val_size = int(len(full_dataset) * VALIDATION_SPLIT); train_size = len(full_dataset) - val_size
    train_dataset, val_dataset = random_split(full_dataset, [train_size, val_size])
    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=2, pin_memory=True) # Reduced num_workers for Kaggle
    val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=2, pin_memory=True)
    print(f"Dataset loaded: {train_size} training samples, {val_size} validation samples.")

    # --- Model, Optimizer ---
    unet = ConditionalUNet(in_channels=4, model_channels=128, out_channels=4, condition_channels=6).to(device)
    optimizer = optim.AdamW(unet.parameters(), lr=LEARNING_RATE)

    # --- Training Loop ---
    print("Starting Diffusion Model training...")
    best_val_loss = float('inf')

    for epoch in range(NUM_EPOCHS):
        unet.train(); train_loss = 0.0
        progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{NUM_EPOCHS}")
        for batch in progress_bar:
            with torch.no_grad():
                person_image = batch['person_image'].to(device); cloth_image = batch['cloth_image'].to(device); pose_map = batch['pose_map'].to(device); agnostic_person = batch['agnostic_person'].to(device)
                flow = gmm(cloth_image, pose_map); warped_cloth = warp_cloth_with_flow(cloth_image, flow); condition = torch.cat([agnostic_person, warped_cloth], dim=1)
                latents = vae.encode(person_image).latent_dist.sample() * vae.config.scaling_factor
            noise = torch.randn_like(latents); timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (latents.shape[0],), device=device).long()
            noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
            optimizer.zero_grad(); predicted_noise = unet(noisy_latents, timesteps, condition); loss = F.mse_loss(predicted_noise, noise); loss.backward(); optimizer.step()
            train_loss += loss.item(); progress_bar.set_postfix(loss=loss.item())
        
        avg_train_loss = train_loss / len(train_loader)
        avg_val_loss = evaluate_and_visualize(epoch, unet, gmm, vae, noise_scheduler, val_loader, device)
        print(f"Epoch {epoch+1}/{NUM_EPOCHS} | Train Loss: {avg_train_loss:.4f} | Val Loss: {avg_val_loss:.4f}")
        
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            best_model_path = os.path.join(CHECKPOINT_DIR_DIFF, 'unet_best.pth')
            torch.save(unet.state_dict(), best_model_path)
            print(f"🎉 New best model saved with validation loss: {best_val_loss:.4f} at {best_model_path}")
    
    print("Diffusion model training finished.")
    print(f"Best validation loss achieved: {best_val_loss:.4f}")

if __name__ == '__main__':
    main()

Using device: cuda
Loading pre-trained VAE and GMM...
Found GMM checkpoint at: /kaggle/input/gmm_5epoch/pytorch/default/1/gmm_final (works but low epochs).pth


NameError: name 'GMM' is not defined

In [21]:
# ===================================================================
# CELL 1: ALL IMPORTS AND CLASS DEFINITIONS
# This part fixes the "NameError: name 'GMM' is not defined"
# ===================================================================
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import os
import json
import cv2
import numpy as np
from PIL import Image
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision import transforms
from torchvision.utils import save_image
from tqdm import tqdm
from diffusers import AutoencoderKL, DDPMScheduler
from transformers import logging
from torch.optim import AdamW
# Suppress verbose messages from transformers
logging.set_verbosity_error()


# --- Class Definition from: src/dataset.py ---
class FashionVTONDataset(Dataset):
    def __init__(self, data_root, image_size=(256, 192), original_image_size=(768, 1024)):
        self.data_root = data_root
        self.image_size = image_size
        self.original_image_size = original_image_size
        self.image_dir = os.path.join(data_root, 'image')
        self.cloth_dir = os.path.join(data_root, 'cloth')
        self.cloth_mask_dir = os.path.join(data_root, 'cloth-mask')
        self.pose_dir = os.path.join(data_root, 'openpose_json')
        self.parse_dir = os.path.join(data_root, 'image-parse-v3')
        self.image_files = sorted([f for f in os.listdir(self.image_dir) if f.endswith(('.jpg', '.png'))])
        self.transform = transforms.Compose([transforms.Resize(self.image_size, interpolation=transforms.InterpolationMode.BILINEAR), transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
        self.mask_transform = transforms.Compose([transforms.Resize(self.image_size, interpolation=transforms.InterpolationMode.NEAREST), transforms.ToTensor()])
    def __len__(self):
        return len(self.image_files)
    def find_upper_cloth_label(self, pose_keypoints, parse_array):
        torso_indices = [1, 2, 5, 8, 9, 12]; visible_points = []
        for i in torso_indices:
            if i < len(pose_keypoints):
                x, y, conf = pose_keypoints[i]
                if conf > 0.1: visible_points.append((int(x), int(y)))
        if len(visible_points) < 3: return 5
        x_coords, y_coords = zip(*visible_points); min_x, max_x = min(x_coords), max(x_coords); min_y, max_y = min(y_coords), max(y_coords)
        if max_x <= min_x or max_y <= min_y: return 5
        torso_parse_area = parse_array[min_y:max_y, min_x:max_x]; unique_labels, counts = np.unique(torso_parse_area, return_counts=True); non_zero_mask = (unique_labels != 0)
        if np.any(non_zero_mask): return unique_labels[non_zero_mask][np.argmax(counts[non_zero_mask])]
        else: return 5
    def __getitem__(self, idx):
        image_name = self.image_files[idx]; base_name = os.path.splitext(image_name)[0]
        person_image = Image.open(os.path.join(self.image_dir, image_name)).convert('RGB'); cloth_image = Image.open(os.path.join(self.cloth_dir, image_name)).convert('RGB'); cloth_mask = Image.open(os.path.join(self.cloth_mask_dir, image_name)).convert('L')
        try:
            with open(os.path.join(self.pose_dir, f"{base_name}_keypoints.json"), 'r') as f: pose_data = json.load(f)
            pose_keypoints = np.array(pose_data['people'][0]['pose_keypoints_2d']).reshape(-1, 3)
        except (FileNotFoundError, IndexError): pose_keypoints = np.zeros((25, 3), dtype=np.float32)
        parse_path = os.path.join(self.parse_dir, f"{base_name}.png")
        if not os.path.exists(parse_path): parse_path = os.path.join(self.parse_dir, f"{base_name}.jpg") # fallback
        parse_array_orig = np.array(Image.open(parse_path).convert('L'))
        upper_cloth_label = self.find_upper_cloth_label(pose_keypoints, parse_array_orig)
        parse_array_resized = cv2.resize(parse_array_orig, self.image_size[::-1], interpolation=cv2.INTER_NEAREST)
        person_cloth_mask = (parse_array_resized == upper_cloth_label).astype(np.float32)
        person_image_tensor = self.transform(person_image); cloth_image_tensor = self.transform(cloth_image); cloth_mask_tensor = self.mask_transform(cloth_mask)
        pose_map_tensor = torch.from_numpy(self.create_pose_map(pose_keypoints)).float()
        blurred_mask_tensor = torch.from_numpy(cv2.GaussianBlur(person_cloth_mask, (5, 5), 0)).unsqueeze(0)
        agnostic_person_tensor = person_image_tensor * (1 - blurred_mask_tensor)
        warped_cloth_tensor = person_image_tensor * torch.from_numpy(person_cloth_mask).unsqueeze(0)
        return {'person_image': person_image_tensor, 'cloth_image': cloth_image_tensor, 'cloth_mask': cloth_mask_tensor, 'agnostic_person': agnostic_person_tensor, 'pose_map': pose_map_tensor, 'warped_cloth': warped_cloth_tensor}
    def create_pose_map(self, keypoints):
        h, w = self.image_size; orig_w, orig_h = self.original_image_size; num_keypoints = keypoints.shape[0]
        pose_map = np.zeros((num_keypoints, h, w), dtype=np.float32)
        for i, point in enumerate(keypoints):
            if point[2] > 0.1:
                x, y = int(point[0] * w / orig_w), int(point[1] * h / orig_h)
                if 0 <= x < w and 0 <= y < h: cv2.circle(pose_map[i], (x, y), radius=3, color=1, thickness=-1)
        return pose_map

# --- Class Definition from: src/geometric_matching.py ---
class GMM(nn.Module):
    def __init__(self, in_channels_cloth=3, in_channels_pose=25, out_channels_flow=2):
        super(GMM, self).__init__(); self.encoder1 = self.conv_block(in_channels_cloth + in_channels_pose, 64); self.encoder2 = self.conv_block(64, 128); self.encoder3 = self.conv_block(128, 256); self.encoder4 = self.conv_block(256, 512); self.pool = nn.MaxPool2d(2, 2); self.bottleneck = self.conv_block(512, 1024); self.upconv4 = nn.ConvTranspose2d(1024, 512, kernel_size=2, stride=2); self.decoder4 = self.conv_block(1024, 512); self.upconv3 = nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2); self.decoder3 = self.conv_block(512, 256); self.upconv2 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2); self.decoder2 = self.conv_block(256, 128); self.upconv1 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2); self.decoder1 = self.conv_block(128, 64); self.conv_out = nn.Conv2d(64, out_channels_flow, kernel_size=1); self.tanh = nn.Tanh()
    def conv_block(self, in_channels, out_channels): return nn.Sequential(nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1), nn.ReLU(inplace=True), nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1), nn.ReLU(inplace=True))
    def forward(self, cloth_image, pose_map):
        x = torch.cat([cloth_image, pose_map], dim=1); e1 = self.encoder1(x); p1 = self.pool(e1); e2 = self.encoder2(p1); p2 = self.pool(e2); e3 = self.encoder3(p2); p3 = self.pool(e3); e4 = self.encoder4(p3); p4 = self.pool(e4); b = self.bottleneck(p4); d4 = self.upconv4(b); d4 = torch.cat([d4, e4], dim=1); d4 = self.decoder4(d4); d3 = self.upconv3(d4); d3 = torch.cat([d3, e3], dim=1); d3 = self.decoder3(d3); d2 = self.upconv2(d3); d2 = torch.cat([d2, e2], dim=1); d2 = self.decoder2(d2); d1 = self.upconv1(d2); d1 = torch.cat([d1, e1], dim=1); d1 = self.decoder1(d1); flow_field = self.conv_out(d1); flow_field = self.tanh(flow_field); return flow_field

def warp_cloth_with_flow(cloth_image, flow_field):
    flow_field = flow_field.permute(0, 2, 3, 1); warped_image = F.grid_sample(cloth_image, flow_field, mode='bilinear', padding_mode='zeros', align_corners=True); return warped_image

# --- Class Definition from: src/diffusion_model.py ---
class SinusoidalPositionEmbeddings(nn.Module):
    def __init__(self, dim): super().__init__(); self.dim = dim
    def forward(self, time): device = time.device; half_dim = self.dim // 2; embeddings = math.log(10000) / (half_dim - 1); embeddings = torch.exp(torch.arange(half_dim, device=device) * -embeddings); embeddings = time[:, None] * embeddings[None, :]; embeddings = torch.cat((embeddings.sin(), embeddings.cos()), dim=-1); return embeddings
class ResnetBlock(nn.Module):
    def __init__(self, in_channels, out_channels, *, time_emb_dim=None):
        super().__init__(); self.mlp = (nn.Sequential(nn.SiLU(), nn.Linear(time_emb_dim, out_channels)) if time_emb_dim is not None else None); self.block1 = nn.Sequential(nn.Conv2d(in_channels, out_channels, 3, padding=1), nn.GroupNorm(8, out_channels), nn.SiLU()); self.block2 = nn.Sequential(nn.Conv2d(out_channels, out_channels, 3, padding=1), nn.GroupNorm(8, out_channels), nn.SiLU()); self.res_conv = nn.Conv2d(in_channels, out_channels, 1) if in_channels != out_channels else nn.Identity()
    def forward(self, x, time_emb=None):
        h = self.block1(x)
        if self.mlp is not None and time_emb is not None:
            time_emb = self.mlp(time_emb)
            h = h + time_emb.unsqueeze(-1).unsqueeze(-1)
        h = self.block2(h)
        return h + self.res_conv(x)
class AttentionBlock(nn.Module):
    def __init__(self, channels, time_emb=None): super().__init__(); self.gn = nn.GroupNorm(8, channels); self.qkv = nn.Conv2d(channels, channels * 3, 1); self.out = nn.Conv2d(channels, channels, 1)
    def forward(self, x, time_emb=None): b, c, h, w = x.shape; x_in = x; x = self.gn(x); x = self.qkv(x); q, k, v = torch.chunk(x, 3, dim=1); q = q.view(b, c, h * w); k = k.view(b, c, h * w); v = v.view(b, c, h * w); k = k.softmax(dim=-1); attn = torch.einsum("b c i, b c j -> b i j", q, k); out = torch.einsum("b i j, b c j -> b c i", attn, v); out = out.view(b, c, h, w); return self.out(out) + x_in
class ConditionalUNet(nn.Module):
    """
    The main U-Net for the diffusion model.
    This version has the FINAL corrected upsampling path logic.
    """
    def __init__(self, in_channels, model_channels, out_channels, time_emb_dim=256, condition_channels=6):
        super().__init__()
        
        # --- Time embedding and Initial Conv ---
        self.time_mlp = nn.Sequential(
            SinusoidalPositionEmbeddings(time_emb_dim),
            nn.Linear(time_emb_dim, time_emb_dim),
            nn.ReLU()
        )
        self.init_conv = nn.Conv2d(in_channels + condition_channels, model_channels, kernel_size=3, padding=1)

        # --- Downsampling Path ---
        self.down1 = ResnetBlock(model_channels, 128, time_emb_dim=time_emb_dim)
        self.down2 = ResnetBlock(128, 128, time_emb_dim=time_emb_dim)
        self.down3 = ResnetBlock(128, 256, time_emb_dim=time_emb_dim)
        self.down4 = AttentionBlock(256)
        self.down5 = ResnetBlock(256, 256, time_emb_dim=time_emb_dim)
        self.down6 = ResnetBlock(256, 512, time_emb_dim=time_emb_dim)
        
        self.pool = nn.MaxPool2d(2)

        # --- Bottleneck ---
        self.mid1 = ResnetBlock(512, 1024, time_emb_dim=time_emb_dim)
        self.mid_attn = AttentionBlock(1024)
        self.mid2 = ResnetBlock(1024, 512, time_emb_dim=time_emb_dim)

        # --- Upsampling Path ---
        self.up1 = nn.ConvTranspose2d(512, 256, 2, 2)
        self.up_res1 = ResnetBlock(512, 256, time_emb_dim=time_emb_dim)
        self.up_attn1 = AttentionBlock(256)
        self.up_res2 = ResnetBlock(256, 256, time_emb_dim=time_emb_dim)
        
        self.up2 = nn.ConvTranspose2d(256, 128, 2, 2)
        self.up_res3 = ResnetBlock(256, 128, time_emb_dim=time_emb_dim)
        self.up_res4 = ResnetBlock(128, 128, time_emb_dim=time_emb_dim)

        # --- Final Output ---
        self.out_res = ResnetBlock(128 + model_channels, 64, time_emb_dim=time_emb_dim)
        self.out_conv = nn.Conv2d(64, out_channels, 1)

    def forward(self, x, time, condition):
        # 1. Initial processing
        t = self.time_mlp(time)
        condition_downsampled = F.interpolate(condition, size=x.shape[2:], mode='bilinear', align_corners=False)
        x = torch.cat([x, condition_downsampled], dim=1)
        x = self.init_conv(x)
        
        # Store for skip connection
        r0 = x.clone()

        # 2. Downsampling
        x = self.down1(x, t)
        x = self.down2(x, t)
        r1 = x.clone() # Residual after first block
        x = self.pool(x)
        
        x = self.down3(x, t)
        x = self.down4(x, t)
        x = self.down5(x, t)
        r2 = x.clone() # Residual after second block
        x = self.pool(x)

        x = self.down6(x, t)
        
        # 3. Bottleneck
        x = self.mid1(x, t)
        x = self.mid_attn(x, t)
        x = self.mid2(x, t)

        # 4. Upsampling
        x = self.up1(x) # Upsample
        x = torch.cat([x, r2], dim=1) # Concatenate with skip connection
        x = self.up_res1(x, t) # Process combined tensor
        x = self.up_attn1(x, t)
        x = self.up_res2(x, t)

        x = self.up2(x) # Upsample
        x = torch.cat([x, r1], dim=1) # Concatenate with skip connection
        x = self.up_res3(x, t) # Process combined tensor
        x = self.up_res4(x, t)
        
        # 5. Final output stage
        x = torch.cat([x, r0], dim=1) # Concatenate with initial residual
        x = self.out_res(x, t)
        
        return self.out_conv(x)
# ===================================================================
# CELL 2: CONFIGURATION AND TRAINING SCRIPT
# This part now has the correct paths and logic.
# ===================================================================

# --- Training Configuration ---
DATA_ROOT = '/kaggle/input/clothe/clothes_tryon_dataset/train'
OUTPUT_DIR = '/kaggle/working/'

# CORRECTED PATH: Point directly to the uploaded model file
GMM_CHECKPOINT_PATH = '/kaggle/input/gmm_5epoch/pytorch/default/1/gmm_final (works but low epochs).pth'

# Define subdirectories for our diffusion model outputs
CHECKPOINT_DIR_DIFF = os.path.join(OUTPUT_DIR, 'checkpoints_diffusion')
VISUALIZATION_DIR_DIFF = os.path.join(OUTPUT_DIR, 'visualizations_diffusion')

BATCH_SIZE = 2
NUM_EPOCHS = 5
LEARNING_RATE = 1e-4
IMAGE_SIZE = (256, 192)
ORIGINAL_IMAGE_SIZE = (768, 1024)
VALIDATION_SPLIT = 0.1
VAE_MODEL_ID = "stabilityai/stable-diffusion-2-1-base"
VAE_SUBFOLDER = "vae"
NUM_TRAIN_TIMESTEPS = 1000

# Create directories if they don't exist
os.makedirs(CHECKPOINT_DIR_DIFF, exist_ok=True)
os.makedirs(VISUALIZATION_DIR_DIFF, exist_ok=True)

@torch.no_grad()
def evaluate_and_visualize(epoch, unet, gmm, vae, noise_scheduler, val_loader, device):
    unet.eval(); gmm.eval(); val_loss = 0.0; progress_bar = tqdm(val_loader, desc="Validating", leave=False)
    for batch in progress_bar:
        person_image = batch['person_image'].to(device); cloth_image = batch['cloth_image'].to(device); pose_map = batch['pose_map'].to(device); agnostic_person = batch['agnostic_person'].to(device)
        flow = gmm(cloth_image, pose_map); warped_cloth = warp_cloth_with_flow(cloth_image, flow); condition = torch.cat([agnostic_person, warped_cloth], dim=1)
        latents = vae.encode(person_image).latent_dist.sample() * vae.config.scaling_factor
        noise = torch.randn_like(latents); timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (latents.shape[0],), device=device).long()
        noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps); predicted_noise = unet(noisy_latents, timesteps, condition); loss = F.mse_loss(predicted_noise, noise); val_loss += loss.item()
    avg_val_loss = val_loss / len(val_loader)
    sample_condition = condition[0:1]; original_vis_image = person_image[0:1]; sample_latents = torch.randn(1, 4, IMAGE_SIZE[0] // 8, IMAGE_SIZE[1] // 8).to(device)
    for t in tqdm(noise_scheduler.timesteps, desc="Generating sample", leave=False):
        pred_noise = unet(sample_latents, t.unsqueeze(0).to(device), sample_condition); sample_latents = noise_scheduler.step(pred_noise, t, sample_latents).prev_sample
    sample_latents = 1 / vae.config.scaling_factor * sample_latents; generated_image = vae.decode(sample_latents).sample
    original_vis = (original_vis_image + 1) / 2; condition_vis = (sample_condition[:, :3] + 1) / 2; warped_vis = (sample_condition[:, 3:] + 1) / 2; generated_vis = (generated_image + 1) / 2
    comparison = torch.cat([original_vis, condition_vis, warped_vis, generated_vis], dim=0); save_image(comparison, os.path.join(VISUALIZATION_DIR_DIFF, f'epoch_{epoch+1}_sample.png'), nrow=4)
    unet.train(); return avg_val_loss

def main():
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")
    print("Loading pre-trained VAE and GMM...")
    vae = AutoencoderKL.from_pretrained(VAE_MODEL_ID, subfolder=VAE_SUBFOLDER).to(device)
    if not os.path.isfile(GMM_CHECKPOINT_PATH):
        print(f"FATAL ERROR: GMM checkpoint file not found at '{GMM_CHECKPOINT_PATH}'")
        return
    print(f"Found GMM checkpoint at: {GMM_CHECKPOINT_PATH}")
    gmm = GMM(in_channels_pose=25).to(device)
    gmm.load_state_dict(torch.load(GMM_CHECKPOINT_PATH, map_location=device))
    vae.requires_grad_(False); gmm.requires_grad_(False)
    noise_scheduler = DDPMScheduler(num_train_timesteps=NUM_TRAIN_TIMESTEPS, beta_schedule='squaredcos_cap_v2')
    print("Loading and splitting dataset...")
    full_dataset = FashionVTONDataset(data_root=DATA_ROOT, image_size=IMAGE_SIZE, original_image_size=ORIGINAL_IMAGE_SIZE)
    val_size = int(len(full_dataset) * VALIDATION_SPLIT); train_size = len(full_dataset) - val_size
    train_dataset, val_dataset = random_split(full_dataset, [train_size, val_size])
    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=2, pin_memory=True)
    val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=2, pin_memory=True)
    print(f"Dataset loaded: {train_size} training samples, {val_size} validation samples.")
    unet = ConditionalUNet(in_channels=4, model_channels=128, out_channels=4, condition_channels=6).to(device)
    optimizer = torch.optim.AdamW(unet.parameters(), lr=LEARNING_RATE)
    print("Starting Diffusion Model training...")
    best_val_loss = float('inf')
    for epoch in range(NUM_EPOCHS):
        unet.train(); train_loss = 0.0
        progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{NUM_EPOCHS}")
        for batch in progress_bar:
            with torch.no_grad():
                person_image = batch['person_image'].to(device); cloth_image = batch['cloth_image'].to(device); pose_map = batch['pose_map'].to(device); agnostic_person = batch['agnostic_person'].to(device)
                flow = gmm(cloth_image, pose_map); warped_cloth = warp_cloth_with_flow(cloth_image, flow); condition = torch.cat([agnostic_person, warped_cloth], dim=1)
                latents = vae.encode(person_image).latent_dist.sample() * vae.config.scaling_factor
            noise = torch.randn_like(latents); timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (latents.shape[0],), device=device).long()
            noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
            optimizer.zero_grad(); predicted_noise = unet(noisy_latents, timesteps, condition); loss = F.mse_loss(predicted_noise, noise); loss.backward(); optimizer.step()
            train_loss += loss.item(); progress_bar.set_postfix(loss=loss.item())
        avg_train_loss = train_loss / len(train_loader)
        avg_val_loss = evaluate_and_visualize(epoch, unet, gmm, vae, noise_scheduler, val_loader, device)
        print(f"Epoch {epoch+1}/{NUM_EPOCHS} | Train Loss: {avg_train_loss:.4f} | Val Loss: {avg_val_loss:.4f}")
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            best_model_path = os.path.join(CHECKPOINT_DIR_DIFF, 'unet_best.pth')
            torch.save(unet.state_dict(), best_model_path)
            print(f"🎉 New best model saved with validation loss: {best_val_loss:.4f} at {best_model_path}")
    print("Diffusion model training finished."); print(f"Best validation loss achieved: {best_val_loss:.4f}")

if __name__ == '__main__':
    main()

Using device: cuda
Loading pre-trained VAE and GMM...
Found GMM checkpoint at: /kaggle/input/gmm_5epoch/pytorch/default/1/gmm_final (works but low epochs).pth
Loading and splitting dataset...
Dataset loaded: 10483 training samples, 1164 validation samples.
Starting Diffusion Model training...


Epoch 1/5: 100%|██████████| 5242/5242 [15:17<00:00,  5.71it/s, loss=0.0227] 
                                                                      

Epoch 1/5 | Train Loss: 0.2367 | Val Loss: 0.2078
🎉 New best model saved with validation loss: 0.2078 at /kaggle/working/checkpoints_diffusion/unet_best.pth


Epoch 2/5: 100%|██████████| 5242/5242 [15:24<00:00,  5.67it/s, loss=0.199]  
                                                                      

Epoch 2/5 | Train Loss: 0.2005 | Val Loss: 0.1962
🎉 New best model saved with validation loss: 0.1962 at /kaggle/working/checkpoints_diffusion/unet_best.pth


Epoch 3/5: 100%|██████████| 5242/5242 [15:24<00:00,  5.67it/s, loss=0.00719]
                                                                      

Epoch 3/5 | Train Loss: 0.1925 | Val Loss: 0.1936
🎉 New best model saved with validation loss: 0.1936 at /kaggle/working/checkpoints_diffusion/unet_best.pth


Epoch 4/5: 100%|██████████| 5242/5242 [15:23<00:00,  5.67it/s, loss=0.475]  
                                                                      

Epoch 4/5 | Train Loss: 0.1897 | Val Loss: 0.1893
🎉 New best model saved with validation loss: 0.1893 at /kaggle/working/checkpoints_diffusion/unet_best.pth


Epoch 5/5: 100%|██████████| 5242/5242 [15:23<00:00,  5.68it/s, loss=0.00616]
                                                                      

Epoch 5/5 | Train Loss: 0.1809 | Val Loss: 0.1778
🎉 New best model saved with validation loss: 0.1778 at /kaggle/working/checkpoints_diffusion/unet_best.pth
Diffusion model training finished.
Best validation loss achieved: 0.1778


In [None]:
# ===================================================================
# FINAL KAGGLE NOTEBOOK CELL FOR HD-VITON TRAINING
# Includes all class definitions and a robust, resumable GAN training loop
# ===================================================================
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import os
import json
import cv2
import numpy as np
from PIL import Image
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision import transforms
from torchvision.utils import save_image
import torchvision.models as models
from tqdm import tqdm

# --- CLASS DEFINITIONS (Dataset, GMM, VGG Loss) ---
# These are the components we need from the previous steps.
class FashionVTONDataset(Dataset):
    def __init__(self, data_root, image_size=(256, 192), original_image_size=(768, 1024)):
        self.data_root = data_root; self.image_size = image_size; self.original_image_size = original_image_size; self.image_dir = os.path.join(data_root, 'image'); self.cloth_dir = os.path.join(data_root, 'cloth'); self.cloth_mask_dir = os.path.join(data_root, 'cloth-mask'); self.pose_dir = os.path.join(data_root, 'openpose_json'); self.parse_dir = os.path.join(data_root, 'image-parse-v3'); self.image_files = sorted([f for f in os.listdir(self.image_dir) if f.endswith(('.jpg', '.png'))]); self.transform = transforms.Compose([transforms.Resize(self.image_size, interpolation=transforms.InterpolationMode.BILINEAR), transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]); self.mask_transform = transforms.Compose([transforms.Resize(self.image_size, interpolation=transforms.InterpolationMode.NEAREST), transforms.ToTensor()])
    def __len__(self): return len(self.image_files)
    def find_upper_cloth_label(self, pose_keypoints, parse_array):
        torso_indices = [1, 2, 5, 8, 9, 12]; visible_points = []
        for i in torso_indices:
            if i < len(pose_keypoints): x, y, conf = pose_keypoints[i];
            if conf > 0.1: visible_points.append((int(x), int(y)))
        if len(visible_points) < 3: return 5
        x_coords, y_coords = zip(*visible_points); min_x, max_x = min(x_coords), max(x_coords); min_y, max_y = min(y_coords), max(y_coords)
        if max_x <= min_x or max_y <= min_y: return 5
        torso_parse_area = parse_array[min_y:max_y, min_x:max_x]; unique_labels, counts = np.unique(torso_parse_area, return_counts=True); non_zero_mask = (unique_labels != 0)
        if np.any(non_zero_mask): return unique_labels[non_zero_mask][np.argmax(counts[non_zero_mask])]
        else: return 5
    def __getitem__(self, idx):
        image_name = self.image_files[idx]; base_name = os.path.splitext(image_name)[0]
        person_image = Image.open(os.path.join(self.image_dir, image_name)).convert('RGB'); cloth_image = Image.open(os.path.join(self.cloth_dir, image_name)).convert('RGB'); cloth_mask = Image.open(os.path.join(self.cloth_mask_dir, image_name)).convert('L')
        try:
            with open(os.path.join(self.pose_dir, f"{base_name}_keypoints.json"), 'r') as f: pose_data = json.load(f)
            pose_keypoints = np.array(pose_data['people'][0]['pose_keypoints_2d']).reshape(-1, 3)
        except (FileNotFoundError, IndexError): pose_keypoints = np.zeros((25, 3), dtype=np.float32)
        parse_path = os.path.join(self.parse_dir, f"{base_name}.png")
        if not os.path.exists(parse_path): parse_path = os.path.join(self.parse_dir, f"{base_name}.jpg")
        parse_array_orig = np.array(Image.open(parse_path).convert('L'))
        upper_cloth_label = self.find_upper_cloth_label(pose_keypoints, parse_array_orig)
        parse_array_resized = cv2.resize(parse_array_orig, self.image_size[::-1], interpolation=cv2.INTER_NEAREST)
        person_cloth_mask = (parse_array_resized == upper_cloth_label).astype(np.float32)
        person_image_tensor = self.transform(person_image); cloth_image_tensor = self.transform(cloth_image); cloth_mask_tensor = self.mask_transform(cloth_mask)
        pose_map_tensor = torch.from_numpy(self.create_pose_map(pose_keypoints)).float()
        blurred_mask_tensor = torch.from_numpy(cv2.GaussianBlur(person_cloth_mask, (5, 5), 0)).unsqueeze(0)
        agnostic_person_tensor = person_image_tensor * (1 - blurred_mask_tensor)
        warped_cloth_tensor = person_image_tensor * torch.from_numpy(person_cloth_mask).unsqueeze(0)
        return {'person_image': person_image_tensor, 'cloth_image': cloth_image_tensor, 'cloth_mask': cloth_mask_tensor, 'agnostic_person': agnostic_person_tensor, 'pose_map': pose_map_tensor, 'warped_cloth': warped_cloth_tensor}
    def create_pose_map(self, keypoints):
        h, w = self.image_size; orig_w, orig_h = self.original_image_size; num_keypoints = keypoints.shape[0]
        pose_map = np.zeros((num_keypoints, h, w), dtype=np.float32)
        for i, point in enumerate(keypoints):
            if point[2] > 0.1:
                x, y = int(point[0] * w / orig_w), int(point[1] * h / orig_h)
                if 0 <= x < w and 0 <= y < h: cv2.circle(pose_map[i], (x, y), radius=3, color=1, thickness=-1)
        return pose_map

class GMM(nn.Module):
    def __init__(self, in_channels_cloth=3, in_channels_pose=25, out_channels_flow=2):
        super(GMM, self).__init__(); self.encoder1 = self.conv_block(in_channels_cloth + in_channels_pose, 64); self.encoder2 = self.conv_block(64, 128); self.encoder3 = self.conv_block(128, 256); self.encoder4 = self.conv_block(256, 512); self.pool = nn.MaxPool2d(2, 2); self.bottleneck = self.conv_block(512, 1024); self.upconv4 = nn.ConvTranspose2d(1024, 512, kernel_size=2, stride=2); self.decoder4 = self.conv_block(1024, 512); self.upconv3 = nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2); self.decoder3 = self.conv_block(512, 256); self.upconv2 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2); self.decoder2 = self.conv_block(256, 128); self.upconv1 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2); self.decoder1 = self.conv_block(128, 64); self.conv_out = nn.Conv2d(64, out_channels_flow, kernel_size=1); self.tanh = nn.Tanh()
    def conv_block(self, in_channels, out_channels): return nn.Sequential(nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1), nn.ReLU(inplace=True), nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1), nn.ReLU(inplace=True))
    def forward(self, cloth_image, pose_map):
        x = torch.cat([cloth_image, pose_map], dim=1); e1 = self.encoder1(x); p1 = self.pool(e1); e2 = self.encoder2(p1); p2 = self.pool(e2); e3 = self.encoder3(p2); p3 = self.pool(e3); e4 = self.encoder4(p3); p4 = self.pool(e4); b = self.bottleneck(p4); d4 = self.upconv4(b); d4 = torch.cat([d4, e4], dim=1); d4 = self.decoder4(d4); d3 = self.upconv3(d4); d3 = torch.cat([d3, e3], dim=1); d3 = self.decoder3(d3); d2 = self.upconv2(d3); d2 = torch.cat([d2, e2], dim=1); d2 = self.decoder2(d2); d1 = self.upconv1(d2); d1 = torch.cat([d1, e1], dim=1); d1 = self.decoder1(d1); flow_field = self.conv_out(d1); flow_field = self.tanh(flow_field); return flow_field

def warp_cloth_with_flow(cloth_image, flow_field):
    flow_field = flow_field.permute(0, 2, 3, 1); warped_image = F.grid_sample(cloth_image, flow_field, mode='bilinear', padding_mode='zeros', align_corners=True); return warped_image

class VGGPerceptualLoss(nn.Module):
    def __init__(self, resize=True):
        super(VGGPerceptualLoss, self).__init__(); vgg = models.vgg19(weights=models.VGG19_Weights.IMAGENET1K_V1).features; self.vgg_layers = vgg[:35].eval();
        for param in self.vgg_layers.parameters(): param.requires_grad = False
        self.l1 = nn.L1Loss(); self.transform = nn.functional.interpolate; self.resize = resize
        self.register_buffer("mean", torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)); self.register_buffer("std", torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))
    def forward(self, pred, target):
        pred = (pred + 1) / 2; target = (target + 1) / 2; pred = (pred - self.mean) / self.std; target = (target - self.mean) / self.std
        if self.resize: pred = self.transform(pred, mode='bilinear', size=(224, 224), align_corners=False); target = self.transform(target, mode='bilinear', size=(224, 224), align_corners=False)
        pred_features = self.vgg_layers(pred); target_features = self.vgg_layers(target); return self.l1(pred_features, target_features)

# --- NEW HD-VITON MODEL DEFINITIONS ---
class GeneratorHD(nn.Module):
    """The main U-Net Generator for HD-VITON. Takes agnostic person and warped cloth."""
    def __init__(self, in_channels=6, out_channels=3):
        super(GeneratorHD, self).__init__()
        self.encoder1 = nn.Sequential(nn.Conv2d(in_channels, 64, kernel_size=4, stride=2, padding=1), nn.LeakyReLU(0.2, inplace=True))
        self.encoder2 = nn.Sequential(nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1), nn.InstanceNorm2d(128), nn.LeakyReLU(0.2, inplace=True))
        self.encoder3 = nn.Sequential(nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1), nn.InstanceNorm2d(256), nn.LeakyReLU(0.2, inplace=True))
        self.encoder4 = nn.Sequential(nn.Conv2d(256, 512, kernel_size=4, stride=2, padding=1), nn.InstanceNorm2d(512), nn.LeakyReLU(0.2, inplace=True))
        self.bottleneck = nn.Sequential(nn.Conv2d(512, 512, kernel_size=4, stride=2, padding=1), nn.ReLU(inplace=True))
        self.decoder1 = nn.Sequential(nn.ConvTranspose2d(512, 512, kernel_size=4, stride=2, padding=1), nn.InstanceNorm2d(512), nn.ReLU(inplace=True))
        self.decoder2 = nn.Sequential(nn.ConvTranspose2d(1024, 256, kernel_size=4, stride=2, padding=1), nn.InstanceNorm2d(256), nn.ReLU(inplace=True))
        self.decoder3 = nn.Sequential(nn.ConvTranspose2d(512, 128, kernel_size=4, stride=2, padding=1), nn.InstanceNorm2d(128), nn.ReLU(inplace=True))
        self.decoder4 = nn.Sequential(nn.ConvTranspose2d(256, 64, kernel_size=4, stride=2, padding=1), nn.InstanceNorm2d(64), nn.ReLU(inplace=True))
        self.final_layer = nn.Sequential(nn.ConvTranspose2d(128, out_channels, kernel_size=4, stride=2, padding=1), nn.Tanh())

    def forward(self, x_agnostic, x_warped_cloth):
        x = torch.cat([x_agnostic, x_warped_cloth], 1)
        e1 = self.encoder1(x)
        e2 = self.encoder2(e1)
        e3 = self.encoder3(e2)
        e4 = self.encoder4(e3)
        b = self.bottleneck(e4)
        d1 = self.decoder1(b)
        d2 = self.decoder2(torch.cat([d1, e4], 1))
        d3 = self.decoder3(torch.cat([d2, e3], 1))
        d4 = self.decoder4(torch.cat([d3, e2], 1))
        out = self.final_layer(torch.cat([d4, e1], 1))
        return out

class Discriminator(nn.Module):
    """A PatchGAN Discriminator to classify image patches as real or fake."""
    def __init__(self, in_channels=6):
        super(Discriminator, self).__init__()
        def discriminator_block(in_filters, out_filters, stride=2, norm=True):
            layers = [nn.Conv2d(in_filters, out_filters, 4, stride=stride, padding=1)]
            if norm: layers.append(nn.InstanceNorm2d(out_filters))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            return layers
        self.model = nn.Sequential(
            *discriminator_block(in_channels, 64, norm=False),
            *discriminator_block(64, 128),
            *discriminator_block(128, 256),
            *discriminator_block(256, 512, stride=1),
            nn.Conv2d(512, 1, 4, padding=1)
        )
    def forward(self, img_A, img_B):
        img_input = torch.cat((img_A, img_B), 1)
        return self.model(img_input)

# --- CONFIGURATION & TRAINING SCRIPT ---
# Paths
DATA_ROOT = '/kaggle/input/clothe/clothes_tryon_dataset/train'
GMM_CHECKPOINT_PATH = '/kaggle/input/gmm_5epoch/pytorch/default/1/gmm_final (works but low epochs).pth'
OUTPUT_DIR = '/kaggle/working/'
CHECKPOINT_DIR_HD = os.path.join(OUTPUT_DIR, 'checkpoints_hd_viton')
VISUALIZATION_DIR_HD = os.path.join(OUTPUT_DIR, 'visualizations_hd_viton')

# Hyperparameters
BATCH_SIZE = 4; NUM_EPOCHS = 200; LEARNING_RATE_G = 2e-4; LEARNING_RATE_D = 2e-4
IMAGE_SIZE = (256, 192); ORIGINAL_IMAGE_SIZE = (768, 1024)
LAMBDA_L1 = 10.0; LAMBDA_VGG = 10.0

os.makedirs(CHECKPOINT_DIR_HD, exist_ok=True); os.makedirs(VISUALIZATION_DIR_HD, exist_ok=True)

def main():
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu"); print(f"Using device: {device}")
    
    # --- Initialize Models ---
    generator = GeneratorHD().to(device)
    discriminator = Discriminator().to(device)
    gmm = GMM(in_channels_pose=25).to(device)
    vgg_loss_fn = VGGPerceptualLoss().to(device)
    
    # --- Load Pre-trained GMM ---
    if not os.path.isfile(GMM_CHECKPOINT_PATH): print(f"FATAL ERROR: GMM checkpoint not found at '{GMM_CHECKPOINT_PATH}'"); return
    gmm.load_state_dict(torch.load(GMM_CHECKPOINT_PATH, map_location=device)); gmm.eval(); gmm.requires_grad_(False)
    
    # --- Optimizers and Loss ---
    optimizer_G = torch.optim.Adam(generator.parameters(), lr=LEARNING_RATE_G, betas=(0.5, 0.999))
    optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=LEARNING_RATE_D, betas=(0.5, 0.999))
    criterion_GAN = nn.BCEWithLogitsLoss().to(device)
    criterion_L1 = nn.L1Loss().to(device)

    # --- Resumption Logic ---
    start_epoch = 0; checkpoint_path = os.path.join(CHECKPOINT_DIR_HD, 'hd_latest.pth')
    if os.path.isfile(checkpoint_path):
        print(f"Resuming HD-VITON training from checkpoint: {checkpoint_path}")
        checkpoint = torch.load(checkpoint_path)
        generator.load_state_dict(checkpoint['generator_state_dict'])
        discriminator.load_state_dict(checkpoint['discriminator_state_dict'])
        optimizer_G.load_state_dict(checkpoint['optimizer_G_state_dict'])
        optimizer_D.load_state_dict(checkpoint['optimizer_D_state_dict'])
        start_epoch = checkpoint['epoch'] + 1
        print(f"Resumed from Epoch {start_epoch}")
    else: print("Starting HD-VITON training from scratch.")
    
    # --- Data Loading ---
    train_dataset = FashionVTONDataset(data_root=DATA_ROOT, image_size=IMAGE_SIZE, original_image_size=ORIGINAL_IMAGE_SIZE)
    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=2, pin_memory=True)
    
    for epoch in range(start_epoch, NUM_EPOCHS):
        generator.train(); discriminator.train()
        progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{NUM_EPOCHS}")
        
        for i, batch in enumerate(progress_bar):
            # --- Prepare Inputs ---
            real_image = batch['person_image'].to(device)
            cloth_image = batch['cloth_image'].to(device)
            pose_map = batch['pose_map'].to(device)
            agnostic_person = batch['agnostic_person'].to(device)
            with torch.no_grad():
                flow = gmm(cloth_image, pose_map)
                warped_cloth = warp_cloth_with_flow(cloth_image, flow)

            # --- Train Discriminator ---
            optimizer_D.zero_grad()
            fake_image = generator(agnostic_person, warped_cloth)
            # Real
            pred_real = discriminator(agnostic_person, real_image)
            loss_D_real = criterion_GAN(pred_real, torch.ones_like(pred_real))
            # Fake
            pred_fake = discriminator(agnostic_person, fake_image.detach())
            loss_D_fake = criterion_GAN(pred_fake, torch.zeros_like(pred_fake))
            # Total D loss
            loss_D = (loss_D_real + loss_D_fake) * 0.5
            loss_D.backward(); optimizer_D.step()

            # --- Train Generator ---
            optimizer_G.zero_grad()
            pred_fake_for_G = discriminator(agnostic_person, fake_image)
            loss_G_gan = criterion_GAN(pred_fake_for_G, torch.ones_like(pred_fake_for_G))
            loss_G_l1 = criterion_L1(fake_image, real_image) * LAMBDA_L1
            loss_G_vgg = vgg_loss_fn(fake_image, real_image) * LAMBDA_VGG
            # Total G loss
            loss_G = loss_G_gan + loss_G_l1 + loss_G_vgg
            loss_G.backward(); optimizer_G.step()
            
            progress_bar.set_postfix(D_loss=loss_D.item(), G_loss=loss_G.item(), G_gan=loss_G_gan.item(), G_l1=loss_G_l1.item())

        # --- Visualization & Checkpointing ---
        if epoch % 5 == 0:
            with torch.no_grad():
                vis_batch = torch.cat([real_image.cpu(), agnostic_person.cpu(), warped_cloth.cpu(), fake_image.cpu()], 0)
                save_image((vis_batch + 1) / 2.0, os.path.join(VISUALIZATION_DIR_HD, f'epoch_{epoch+1}.png'), nrow=BATCH_SIZE)
        
        latest_checkpoint_state = {
            'epoch': epoch, 'generator_state_dict': generator.state_dict(), 'discriminator_state_dict': discriminator.state_dict(),
            'optimizer_G_state_dict': optimizer_G.state_dict(), 'optimizer_D_state_dict': optimizer_D.state_dict()
        }
        torch.save(latest_checkpoint_state, checkpoint_path)
        torch.save(generator.state_dict(), os.path.join(CHECKPOINT_DIR_HD, 'generator_latest.pth'))
        
    print("Training finished.")

if __name__ == '__main__':
    main()

In [30]:
%%writefile train_tps_warper.py
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torchvision.utils import save_image
import os
from tqdm import tqdm
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data.distributed import DistributedSampler
import json
import cv2
import numpy as np
from PIL import Image
from torchvision import transforms
import time

# ===================================================================
# ALL CLASS DEFINITIONS FOR THIS STAGE
# ===================================================================

# --- Dataset Class (Unchanged) ---
class FashionVTONDataset(Dataset):
    def __init__(self, data_root, image_size=(256, 192), original_image_size=(768, 1024)):
        self.data_root = data_root; self.image_size = image_size; self.original_image_size = original_image_size; self.image_dir = os.path.join(data_root, 'image'); self.cloth_dir = os.path.join(data_root, 'cloth'); self.cloth_mask_dir = os.path.join(data_root, 'cloth-mask'); self.pose_dir = os.path.join(data_root, 'openpose_json'); self.parse_dir = os.path.join(data_root, 'image-parse-v3'); self.image_files = sorted([f for f in os.listdir(self.image_dir) if f.endswith(('.jpg', '.png'))]); self.transform = transforms.Compose([transforms.Resize(self.image_size, interpolation=transforms.InterpolationMode.BILINEAR), transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]); self.mask_transform = transforms.Compose([transforms.Resize(self.image_size, interpolation=transforms.InterpolationMode.NEAREST), transforms.ToTensor()])
    def __len__(self): return len(self.image_files)
    def find_upper_cloth_label(self, pose_keypoints, parse_array):
        torso_indices = [1, 2, 5, 8, 9, 12]; visible_points = []
        for i in torso_indices:
            if i < len(pose_keypoints): x, y, conf = pose_keypoints[i];
            if conf > 0.1: visible_points.append((int(x), int(y)))
        if len(visible_points) < 3: return 5
        x_coords, y_coords = zip(*visible_points); min_x, max_x = min(x_coords), max(x_coords); min_y, max_y = min(y_coords), max(y_coords)
        if max_x <= min_x or max_y <= min_y: return 5
        torso_parse_area = parse_array[min_y:max_y, min_x:max_x]; unique_labels, counts = np.unique(torso_parse_area, return_counts=True); non_zero_mask = (unique_labels != 0)
        if np.any(non_zero_mask): return unique_labels[non_zero_mask][np.argmax(counts[non_zero_mask])]
        else: return 5
    def __getitem__(self, idx):
        image_name = self.image_files[idx]; base_name = os.path.splitext(image_name)[0]
        person_image = Image.open(os.path.join(self.image_dir, image_name)).convert('RGB'); cloth_image = Image.open(os.path.join(self.cloth_dir, image_name)).convert('RGB'); cloth_mask = Image.open(os.path.join(self.cloth_mask_dir, image_name)).convert('L')
        try:
            with open(os.path.join(self.pose_dir, f"{base_name}_keypoints.json"), 'r') as f: pose_data = json.load(f)
            pose_keypoints = np.array(pose_data['people'][0]['pose_keypoints_2d']).reshape(-1, 3)
        except (FileNotFoundError, IndexError): pose_keypoints = np.zeros((25, 3), dtype=np.float32)
        parse_path = os.path.join(self.parse_dir, f"{base_name}.png")
        if not os.path.exists(parse_path): parse_path = os.path.join(self.parse_dir, f"{base_name}.jpg")
        parse_array_orig = np.array(Image.open(parse_path).convert('L'))
        upper_cloth_label = self.find_upper_cloth_label(pose_keypoints, parse_array_orig)
        parse_array_resized = cv2.resize(parse_array_orig, self.image_size[::-1], interpolation=cv2.INTER_NEAREST)
        person_cloth_mask = (parse_array_resized == upper_cloth_label).astype(np.float32)
        person_image_tensor = self.transform(person_image); cloth_image_tensor = self.transform(cloth_image); cloth_mask_tensor = self.mask_transform(cloth_mask)
        pose_map_tensor = torch.from_numpy(self.create_pose_map(pose_keypoints)).float()
        blurred_mask_tensor = torch.from_numpy(cv2.GaussianBlur(person_cloth_mask, (5, 5), 0)).unsqueeze(0)
        agnostic_person_tensor = person_image_tensor * (1 - blurred_mask_tensor)
        warped_cloth_tensor = person_image_tensor * torch.from_numpy(person_cloth_mask).unsqueeze(0)
        return {'person_image': person_image_tensor, 'cloth_image': cloth_image_tensor, 'cloth_mask': cloth_mask_tensor, 'agnostic_person': agnostic_person_tensor, 'pose_map': pose_map_tensor, 'warped_cloth': warped_cloth_tensor}
    def create_pose_map(self, keypoints):
        h, w = self.image_size; orig_w, orig_h = self.original_image_size; num_keypoints = keypoints.shape[0]
        pose_map = np.zeros((num_keypoints, h, w), dtype=np.float32)
        for i, point in enumerate(keypoints):
            if point[2] > 0.1:
                x, y = int(point[0] * w / orig_w), int(point[1] * h / orig_h)
                if 0 <= x < w and 0 <= y < h: cv2.circle(pose_map[i], (x, y), radius=3, color=1, thickness=-1)
        return pose_map

# --- NEW TPS Warper Network Definition ---
# --- NEW TPS Warper Network Definition ---
class TPSWarper(nn.Module):
    """
    A network that predicts TPS transformation parameters.
    This version is simplified and designed to work directly with the pure PyTorch warp function.
    """
    def __init__(self, feature_channels=256, num_control_points=25):
        super().__init__()
        self.num_control_points = num_control_points
        
        self.cloth_feature_extractor = self._make_feature_extractor(3, feature_channels)
        self.person_feature_extractor = self._make_feature_extractor(28, feature_channels)
        
        self.regressor = nn.Sequential(
            nn.Linear(feature_channels * 2, 256), nn.ReLU(inplace=True),
            nn.Linear(256, 128), nn.ReLU(inplace=True),
            nn.Linear(128, num_control_points * 2),
            nn.Tanh() 
        )
        
        grid_size = int(np.sqrt(num_control_points))
        if grid_size * grid_size != num_control_points:
            raise ValueError(f"num_control_points must be a perfect square. Got {num_control_points}")
            
        grid = torch.linspace(-0.5, 0.5, grid_size) # Use a smaller default grid
        x, y = torch.meshgrid(grid, grid, indexing='ij')
        self.register_buffer('source_control_points', torch.stack([x.flatten(), y.flatten()], dim=-1))

    def _make_feature_extractor(self, in_channels, out_channels):
        return nn.Sequential(
            nn.Conv2d(in_channels, 64, 3, 2, 1), nn.ReLU(),
            nn.Conv2d(64, 128, 3, 2, 1), nn.ReLU(),
            nn.Conv2d(128, 256, 3, 2, 1), nn.ReLU(),
            nn.AdaptiveAvgPool2d(1)
        )

    def forward(self, cloth_image, pose_map, agnostic_person):
        cloth_features = self.cloth_feature_extractor(cloth_image).view(cloth_image.size(0), -1)
        person_input = torch.cat([agnostic_person, pose_map], dim=1)
        person_features = self.person_feature_extractor(person_input).view(person_input.size(0), -1)

        # Concatenate features instead of multiplying for more stability
        correlation_features = torch.cat([cloth_features, person_features], dim=1)
        offsets = self.regressor(correlation_features).view(-1, self.num_control_points, 2)
        
        destination_control_points = self.source_control_points + (offsets * 0.5) # Scale offsets
        
        return self.source_control_points, destination_control_points
# Helper function to perform TPS warp given control points
# Helper function to perform TPS warp given control points
# Helper function to perform TPS warp given control points
# Helper function to perform TPS warp given control points
def tps_warp(source_image, src_pts, dst_pts):
    """
    Performs a batch-friendly affine warp using PyTorch only.
    It solves for the best-fit affine matrix for each item in the batch
    and then applies it with F.grid_sample.
    """
    # List to store the affine matrices for each item in the batch
    M_list = []
    
    # --- FIX: Iterate over the batch ---
    for i in range(src_pts.size(0)):
        # Get the points for the current item in the batch
        src_item_pts = src_pts[i]
        dst_item_pts = dst_pts[i]
        
        # Pad source and destination points to create homogeneous coordinates
        # These are now 2D tensors [25, 3]
        src_homo = F.pad(src_item_pts, (0, 1), "constant", 1.0)
        dst_homo = F.pad(dst_item_pts, (0, 1), "constant", 1.0)
        
        # Solve the least squares problem to find the affine matrix
        # These are now simple 2D inputs, which lstsq handles perfectly
        try:
            A_inv_B = torch.linalg.lstsq(src_homo, dst_homo).solution
            M = A_inv_B.transpose(0, 1)[:2, :] # Extract the 2x3 affine matrix
            M_list.append(M.unsqueeze(0))
        except torch.linalg.LinAlgError:
            # If the full solve fails, use a robust 3-point estimate
            M_3pt = cv2.getAffineTransform(src_item_pts[:3].cpu().numpy().astype(np.float32), 
                                           dst_item_pts[:3].cpu().numpy().astype(np.float32))
            M_list.append(torch.from_numpy(M_3pt).unsqueeze(0).to(source_image.device))


    # Recombine the list of matrices into a single batch tensor
    M_tensor = torch.cat(M_list, dim=0).to(source_image.device).float()

    # Create the grid and warp the image for the entire batch at once
    grid = F.affine_grid(M_tensor, source_image.size(), align_corners=False)
    warped_image = F.grid_sample(source_image, grid, align_corners=False)
    
    return warped_image
# --- Distributed Training Setup ---
def setup(rank, world_size, sync_file):
    init_method = f'file://{sync_file}'
    if rank == 0: print(f"Initializing process group with: {init_method}")
    dist.init_process_group("nccl", init_method=init_method, rank=rank, world_size=world_size)

def cleanup():
    dist.destroy_process_group()

# --- Main Training Function ---
def train(rank, world_size, sync_file):
    setup(rank, world_size, sync_file)
    
    # --- Configuration ---
    DATA_ROOT = '/kaggle/input/clothe/clothes_tryon_dataset/train'
    CHECKPOINT_DIR = '/kaggle/working/tps_warper_checkpoints/'
    VISUALIZATION_DIR = '/kaggle/working/tps_warper_visuals/'
    os.makedirs(CHECKPOINT_DIR, exist_ok=True); os.makedirs(VISUALIZATION_DIR, exist_ok=True)
    
    BATCH_SIZE = 32; NUM_EPOCHS = 50; LEARNING_RATE = 1e-4
    IMAGE_SIZE = (256, 192); ORIGINAL_IMAGE_SIZE = (768, 1024)

    # --- Setup Device, Model, Optimizer, Loss ---
    torch.cuda.set_device(rank)
    model = TPSWarper().to(rank)
    model = DDP(model, device_ids=[rank])
    optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
    l1_loss_fn = nn.L1Loss().to(rank)
    scaler = torch.cuda.amp.GradScaler()

    # --- Resumption Logic ---
    start_epoch = 0; checkpoint_path = os.path.join(CHECKPOINT_DIR, 'tps_latest_ddp.pth')
    if os.path.isfile(checkpoint_path) and rank == 0:
        print(f"Loading checkpoint: {checkpoint_path}")
        checkpoint = torch.load(checkpoint_path, map_location={'cuda:0': f'cuda:{rank}'})
        model.module.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        scaler.load_state_dict(checkpoint['scaler_state_dict'])
        start_epoch = checkpoint['epoch'] + 1
    dist.barrier() 

    # --- Data Loading ---
    dataset = FashionVTONDataset(data_root=DATA_ROOT, image_size=IMAGE_SIZE, original_image_size=ORIGINAL_IMAGE_SIZE)
    sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank, shuffle=True)
    loader = DataLoader(dataset, batch_size=BATCH_SIZE, sampler=sampler, num_workers=4, pin_memory=True)

    # --- Training Loop ---
    for epoch in range(start_epoch, NUM_EPOCHS):
        sampler.set_epoch(epoch); model.train()
        progress_bar = tqdm(loader, desc=f"Epoch {epoch+1}/{NUM_EPOCHS}", disable=(rank != 0))

        for batch in progress_bar:
            optimizer.zero_grad()
            cloth_image = batch['cloth_image'].to(rank); cloth_mask = batch['cloth_mask'].to(rank)
            pose_map = batch['pose_map'].to(rank); agnostic_person = batch['agnostic_person'].to(rank)
            ground_truth_warped = batch['warped_cloth'].to(rank)
            
            with torch.cuda.amp.autocast():
                src_pts, dst_pts = model(cloth_image, pose_map, agnostic_person)
                # Use kornia/fallback to warp the cloth
                predicted_warped = tps_warp(cloth_image, src_pts, dst_pts)
                
                # FOCUS ON L1 LOSS ONLY FOR GEOMETRY
                total_loss = l1_loss_fn(predicted_warped, ground_truth_warped)

            scaler.scale(total_loss).backward()
            scaler.step(optimizer); scaler.update()
            
            if rank == 0: progress_bar.set_postfix(loss=total_loss.item())
        
        # --- Checkpointing & Visualization (only from rank 0) ---
        if rank == 0:
            latest_checkpoint_state = {
                'epoch': epoch, 'model_state_dict': model.module.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(), 'scaler_state_dict': scaler.state_dict(),
            }
            torch.save(latest_checkpoint_state, checkpoint_path)
            
            with torch.no_grad():
                vis_batch = torch.cat([(cloth_image.cpu() + 1) / 2, (predicted_warped.cpu().detach() + 1) / 2, (ground_truth_warped.cpu() + 1) / 2], dim=0)
                save_image(vis_batch, os.path.join(VISUALIZATION_DIR, f'epoch_{epoch+1}_comparison.png'), nrow=BATCH_SIZE)

            print(f"Epoch {epoch+1} finished. Checkpoint saved.")

    if rank == 0:
        final_model_path = os.path.join(CHECKPOINT_DIR, 'tps_warper_final.pth')
        torch.save(model.module.state_dict(), final_model_path)
        print(f"Final model saved to {final_model_path}")
        
    cleanup()

if __name__ == '__main__':
    world_size = torch.cuda.device_count()
    if world_size < 2: print("Distributed training requires at least 2 GPUs.")
    else:
        sync_file_path = os.path.join('/kaggle/working', 'ddp_sync_file')
        if os.path.exists(sync_file_path): os.remove(sync_file_path)
        torch.multiprocessing.spawn(train, args=(world_size, sync_file_path), nprocs=world_size, join=True)

Overwriting train_tps_warper.py


In [14]:
!pip install kornia --quiet

[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m363.4/363.4 MB[0m [31m4.7 MB/s[0m eta [36m0:00:00[0m:00:01[0m00:01[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m13.8/13.8 MB[0m [31m87.4 MB/s[0m eta [36m0:00:00[0m:00:01[0m00:01[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m24.6/24.6 MB[0m [31m66.3 MB/s[0m eta [36m0:00:00[0m:00:01[0m00:01[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m883.7/883.7 kB[0m [31m52.7 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m664.8/664.8 MB[0m [31m2.5 MB/s[0m eta [36m0:00:00[0m:00:01[0m00:01[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m211.5/211.5 MB[0m [31m8.0 MB/s[0m eta [36m0:00:00[0m:00:01[0m00:01[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m56.3/56.3 MB[0m [31m14.4 MB/s[0m eta [36m0:00:00[0m:00:01[0m00:01[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

In [31]:
!python train_tps_warper.py


Initializing process group with: file:///kaggle/working/ddp_sync_file
  scaler = torch.cuda.amp.GradScaler()
  scaler = torch.cuda.amp.GradScaler()
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
terminate called without an active exception
terminate called without an active exception
Epoch 1/50:   0%|                                       | 0/182 [00:24<?, ?it/s]
W0716 15:51:24.656000 1081 torch/multiprocessing/spawn.py:169] Terminating process 1085 via signal SIGTERM
Traceback (most recent call last):
  File "/kaggle/working/train_tps_warper.py", line 258, in <module>
    torch.multiprocessing.spawn(train, args=(world_size, sync_file_path), nprocs=world_size, join=True)
  File "/usr/local/lib/python3.11/dist-packages/torch/multiprocessing/spawn.py", line 340, in spawn
    return start_processes(fn, args, nprocs, join, daemon, start_method="spawn")
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dis

In [32]:
%%writefile train_tps_warper.py
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torchvision.utils import save_image
import os
from tqdm import tqdm
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data.distributed import DistributedSampler
import json
import cv2
import numpy as np
from PIL import Image
from torchvision import transforms

# ===================================================================
# ALL CLASS DEFINITIONS FOR STAGE 1
# ===================================================================

class FashionVTONDataset(Dataset):
    def __init__(self, data_root, image_size=(256, 192), original_image_size=(768, 1024)):
        self.data_root = data_root; self.image_size = image_size; self.original_image_size = original_image_size; self.image_dir = os.path.join(data_root, 'image'); self.cloth_dir = os.path.join(data_root, 'cloth'); self.cloth_mask_dir = os.path.join(data_root, 'cloth-mask'); self.pose_dir = os.path.join(data_root, 'openpose_json'); self.parse_dir = os.path.join(data_root, 'image-parse-v3'); self.image_files = sorted([f for f in os.listdir(self.image_dir) if f.endswith(('.jpg', '.png'))]); self.transform = transforms.Compose([transforms.Resize(self.image_size, interpolation=transforms.InterpolationMode.BILINEAR), transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]); self.mask_transform = transforms.Compose([transforms.Resize(self.image_size, interpolation=transforms.InterpolationMode.NEAREST), transforms.ToTensor()])
    def __len__(self): return len(self.image_files)
    def find_upper_cloth_label(self, pose_keypoints, parse_array):
        torso_indices = [1, 2, 5, 8, 9, 12]; visible_points = []
        for i in torso_indices:
            if i < len(pose_keypoints): x, y, conf = pose_keypoints[i];
            if conf > 0.1: visible_points.append((int(x), int(y)))
        if len(visible_points) < 3: return 5
        x_coords, y_coords = zip(*visible_points); min_x, max_x = min(x_coords), max(x_coords); min_y, max_y = min(y_coords), max(y_coords)
        if max_x <= min_x or max_y <= min_y: return 5
        torso_parse_area = parse_array[min_y:max_y, min_x:max_x]; unique_labels, counts = np.unique(torso_parse_area, return_counts=True); non_zero_mask = (unique_labels != 0)
        if np.any(non_zero_mask): return unique_labels[non_zero_mask][np.argmax(counts[non_zero_mask])]
        else: return 5
    def __getitem__(self, idx):
        image_name = self.image_files[idx]; base_name = os.path.splitext(image_name)[0]
        person_image = Image.open(os.path.join(self.image_dir, image_name)).convert('RGB'); cloth_image = Image.open(os.path.join(self.cloth_dir, image_name)).convert('RGB'); cloth_mask = Image.open(os.path.join(self.cloth_mask_dir, image_name)).convert('L')
        try:
            with open(os.path.join(self.pose_dir, f"{base_name}_keypoints.json"), 'r') as f: pose_data = json.load(f)
            pose_keypoints = np.array(pose_data['people'][0]['pose_keypoints_2d']).reshape(-1, 3)
        except (FileNotFoundError, IndexError): pose_keypoints = np.zeros((25, 3), dtype=np.float32)
        parse_path = os.path.join(self.parse_dir, f"{base_name}.png")
        if not os.path.exists(parse_path): parse_path = os.path.join(self.parse_dir, f"{base_name}.jpg")
        parse_array_orig = np.array(Image.open(parse_path).convert('L'))
        upper_cloth_label = self.find_upper_cloth_label(pose_keypoints, parse_array_orig)
        parse_array_resized = cv2.resize(parse_array_orig, self.image_size[::-1], interpolation=cv2.INTER_NEAREST)
        person_cloth_mask = (parse_array_resized == upper_cloth_label).astype(np.float32)
        person_image_tensor = self.transform(person_image); cloth_image_tensor = self.transform(cloth_image); cloth_mask_tensor = self.mask_transform(cloth_mask)
        pose_map_tensor = torch.from_numpy(self.create_pose_map(pose_keypoints)).float()
        blurred_mask_tensor = torch.from_numpy(cv2.GaussianBlur(person_cloth_mask, (5, 5), 0)).unsqueeze(0)
        agnostic_person_tensor = person_image_tensor * (1 - blurred_mask_tensor)
        warped_cloth_tensor = person_image_tensor * torch.from_numpy(person_cloth_mask).unsqueeze(0)
        return {'person_image': person_image_tensor, 'cloth_image': cloth_image_tensor, 'cloth_mask': cloth_mask_tensor, 'agnostic_person': agnostic_person_tensor, 'pose_map': pose_map_tensor, 'warped_cloth': warped_cloth_tensor}
    def create_pose_map(self, keypoints):
        h, w = self.image_size; orig_w, orig_h = self.original_image_size; num_keypoints = keypoints.shape[0]
        pose_map = np.zeros((num_keypoints, h, w), dtype=np.float32)
        for i, point in enumerate(keypoints):
            if point[2] > 0.1:
                x, y = int(point[0] * w / orig_w), int(point[1] * h / orig_h)
                if 0 <= x < w and 0 <= y < h: cv2.circle(pose_map[i], (x, y), radius=3, color=1, thickness=-1)
        return pose_map

def tps_warp(source_image, src_pts, dst_pts):
    M_list = []
    for i in range(src_pts.size(0)):
        src_homo = F.pad(src_pts[i], (0, 1), "constant", 1.0)
        dst_homo = F.pad(dst_pts[i], (0, 1), "constant", 1.0)
        try:
            A_inv_B = torch.linalg.lstsq(src_homo, dst_homo).solution
            M = A_inv_B.transpose(0, 1)[:2, :]
            M_list.append(M.unsqueeze(0))
        except torch.linalg.LinAlgError:
            M_3pt = cv2.getAffineTransform(src_pts[i, :3].detach().cpu().numpy().astype(np.float32), 
                                           dst_pts[i, :3].detach().cpu().numpy().astype(np.float32))
            M_list.append(torch.from_numpy(M_3pt).unsqueeze(0).to(source_image.device))
    M_tensor = torch.cat(M_list, dim=0).to(source_image.device).float()
    grid = F.affine_grid(M_tensor, source_image.size(), align_corners=False)
    return F.grid_sample(source_image, grid, align_corners=False)

class TPSWarper(nn.Module):
    def __init__(self, feature_channels=256, num_control_points=25):
        super().__init__()
        self.num_control_points = num_control_points
        self.cloth_feature_extractor = self._make_feature_extractor(3, feature_channels)
        self.person_feature_extractor = self._make_feature_extractor(28, feature_channels)
        self.regressor = nn.Sequential(nn.Linear(feature_channels * 2, 256), nn.ReLU(inplace=True), nn.Linear(256, 128), nn.ReLU(inplace=True), nn.Linear(128, num_control_points * 2), nn.Tanh())
        grid_size = int(np.sqrt(num_control_points))
        if grid_size * grid_size != num_control_points: raise ValueError(f"num_control_points must be a perfect square. Got {num_control_points}")
        grid = torch.linspace(-0.5, 0.5, grid_size)
        x, y = torch.meshgrid(grid, grid, indexing='ij')
        self.register_buffer('source_control_points', torch.stack([x.flatten(), y.flatten()], dim=-1))
    def _make_feature_extractor(self, in_channels, out_channels):
        return nn.Sequential(nn.Conv2d(in_channels, 64, 3, 2, 1), nn.ReLU(), nn.Conv2d(64, 128, 3, 2, 1), nn.ReLU(), nn.Conv2d(128, 256, 3, 2, 1), nn.ReLU(), nn.AdaptiveAvgPool2d(1))
    def forward(self, cloth_image, pose_map, agnostic_person):
        batch_size = cloth_image.size(0)
        cloth_features = self.cloth_feature_extractor(cloth_image).view(batch_size, -1)
        person_input = torch.cat([agnostic_person, pose_map], dim=1)
        person_features = self.person_feature_extractor(person_input).view(batch_size, -1)
        correlation_features = torch.cat([cloth_features, person_features], dim=1)
        offsets = self.regressor(correlation_features).view(batch_size, self.num_control_points, 2)
        # --- THIS IS THE FIX: Expand the source points to match the batch size ---
        src_pts_batch = self.source_control_points.unsqueeze(0).expand(batch_size, -1, -1)
        dst_pts_batch = src_pts_batch + (offsets * 0.5)
        return src_pts_batch, dst_pts_batch

def setup(rank, world_size, sync_file):
    init_method = f'file://{sync_file}'
    dist.init_process_group("nccl", init_method=init_method, rank=rank, world_size=world_size)

def cleanup():
    dist.destroy_process_group()

def train(rank, world_size, sync_file):
    setup(rank, world_size, sync_file)
    DATA_ROOT = '/kaggle/input/clothe/clothes_tryon_dataset/train'; CHECKPOINT_DIR = '/kaggle/working/tps_warper_checkpoints/'; VISUALIZATION_DIR = '/kaggle/working/tps_warper_visuals/'
    os.makedirs(CHECKPOINT_DIR, exist_ok=True); os.makedirs(VISUALIZATION_DIR, exist_ok=True)
    BATCH_SIZE = 32; NUM_EPOCHS = 50; LEARNING_RATE = 1e-4; IMAGE_SIZE = (256, 192); ORIGINAL_IMAGE_SIZE = (768, 1024)
    torch.cuda.set_device(rank)
    model = TPSWarper().to(rank); model = DDP(model, device_ids=[rank])
    optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE); l1_loss_fn = nn.L1Loss().to(rank); scaler = torch.cuda.amp.GradScaler()
    start_epoch = 0; checkpoint_path = os.path.join(CHECKPOINT_DIR, 'tps_latest_ddp.pth')
    if os.path.isfile(checkpoint_path) and rank == 0:
        checkpoint = torch.load(checkpoint_path, map_location={'cuda:0': f'cuda:{rank}'})
        model.module.load_state_dict(checkpoint['model_state_dict']); optimizer.load_state_dict(checkpoint['optimizer_state_dict']); scaler.load_state_dict(checkpoint['scaler_state_dict']); start_epoch = checkpoint['epoch'] + 1
    dist.barrier()
    dataset = FashionVTONDataset(data_root=DATA_ROOT, image_size=IMAGE_SIZE, original_image_size=ORIGINAL_IMAGE_SIZE)
    sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank, shuffle=True)
    loader = DataLoader(dataset, batch_size=BATCH_SIZE, sampler=sampler, num_workers=4, pin_memory=True)
    for epoch in range(start_epoch, NUM_EPOCHS):
        sampler.set_epoch(epoch); model.train()
        progress_bar = tqdm(loader, desc=f"Epoch {epoch+1}/{NUM_EPOCHS}", disable=(rank != 0))
        for batch in progress_bar:
            optimizer.zero_grad()
            cloth_image = batch['cloth_image'].to(rank); pose_map = batch['pose_map'].to(rank); agnostic_person = batch['agnostic_person'].to(rank); ground_truth_warped = batch['warped_cloth'].to(rank)
            with torch.cuda.amp.autocast():
                src_pts, dst_pts = model(cloth_image, pose_map, agnostic_person)
                predicted_warped = tps_warp(cloth_image, src_pts, dst_pts)
                total_loss = l1_loss_fn(predicted_warped, ground_truth_warped)
            scaler.scale(total_loss).backward(); scaler.step(optimizer); scaler.update()
            if rank == 0: progress_bar.set_postfix(loss=total_loss.item())
        if rank == 0:
            latest_checkpoint_state = {'epoch': epoch, 'model_state_dict': model.module.state_dict(), 'optimizer_state_dict': optimizer.state_dict(), 'scaler_state_dict': scaler.state_dict()}
            torch.save(latest_checkpoint_state, checkpoint_path)
            with torch.no_grad():
                vis_batch = torch.cat([(cloth_image.cpu() + 1) / 2, (predicted_warped.cpu().detach() + 1) / 2, (ground_truth_warped.cpu() + 1) / 2], dim=0)
                save_image(vis_batch, os.path.join(VISUALIZATION_DIR, f'epoch_{epoch+1}_comparison.png'), nrow=BATCH_SIZE)
            print(f"Epoch {epoch+1} finished. Checkpoint saved.")
    if rank == 0:
        final_model_path = os.path.join(CHECKPOINT_DIR, 'tps_warper_final.pth'); torch.save(model.module.state_dict(), final_model_path); print(f"Final model saved to {final_model_path}")
    cleanup()

if __name__ == '__main__':
    world_size = torch.cuda.device_count()
    if world_size < 2: print("Distributed training requires at least 2 GPUs.")
    else:
        sync_file_path = os.path.join('/kaggle/working', 'ddp_sync_file')
        if os.path.exists(sync_file_path): os.remove(sync_file_path)
        print(f"Found {world_size} GPUs. Starting DDP with file sync at {sync_file_path}")
        torch.multiprocessing.spawn(train, args=(world_size, sync_file_path), nprocs=world_size, join=True)

Overwriting train_tps_warper.py


In [33]:
!python train_tps_warper.py

Found 2 GPUs. Starting DDP with file sync at /kaggle/working/ddp_sync_file
  optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE); l1_loss_fn = nn.L1Loss().to(rank); scaler = torch.cuda.amp.GradScaler()
  optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE); l1_loss_fn = nn.L1Loss().to(rank); scaler = torch.cuda.amp.GradScaler()
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
Epoch 1/50: 100%|█████████████████| 182/182 [03:39<00:00,  1.21s/it, loss=0.296]
Epoch 1 finished. Checkpoint saved.
Epoch 2/50: 100%|█████████████████| 182/182 [03:41<00:00,  1.21s/it, loss=0.297]
Epoch 2 finished. Checkpoint saved.
Epoch 3/50: 100%|█████████████████| 182/182 [03:36<00:00,  1.19s/it, loss=0.312]
Epoch 3 finished. Checkpoint saved.
Epoch 4/50: 100%|█████████████████| 182/182 [03:36<00:00,  1.19s/it, loss=0.302]
Epoch 4 finished. Checkpoint saved.
Epoch 5/50: 100%|█████████████████| 182/182 [03:32<00:00,  1.17s/it, loss=0.311]
Epoch 5 finished. Checkpoint saved.


In [2]:
%%writefile train_coarse_gmm.py
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torchvision.utils import save_image
import os
from tqdm import tqdm
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data.distributed import DistributedSampler
import json
import cv2
import numpy as np
from PIL import Image
from torchvision import transforms

# ===================================================================
# CLASS DEFINITIONS
# ===================================================================

class FashionVTONDataset(Dataset):
    def __init__(self, data_root, image_size=(256, 192), original_image_size=(768, 1024)):
        self.data_root = data_root; self.image_size = image_size; self.original_image_size = original_image_size; self.image_dir = os.path.join(data_root, 'image'); self.cloth_dir = os.path.join(data_root, 'cloth'); self.cloth_mask_dir = os.path.join(data_root, 'cloth-mask'); self.pose_dir = os.path.join(data_root, 'openpose_json'); self.parse_dir = os.path.join(data_root, 'image-parse-v3'); self.image_files = sorted([f for f in os.listdir(self.image_dir) if f.endswith(('.jpg', '.png'))]); self.transform = transforms.Compose([transforms.Resize(self.image_size, interpolation=transforms.InterpolationMode.BILINEAR), transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]); self.mask_transform = transforms.Compose([transforms.Resize(self.image_size, interpolation=transforms.InterpolationMode.NEAREST), transforms.ToTensor()])
    def __len__(self): return len(self.image_files)
    def find_upper_cloth_label(self, pose_keypoints, parse_array):
        torso_indices = [1, 2, 5, 8, 9, 12]; visible_points = []
        for i in torso_indices:
            if i < len(pose_keypoints): x, y, conf = pose_keypoints[i];
            if conf > 0.1: visible_points.append((int(x), int(y)))
        if len(visible_points) < 3: return 5
        x_coords, y_coords = zip(*visible_points); min_x, max_x = min(x_coords), max(x_coords); min_y, max_y = min(y_coords), max(y_coords)
        if max_x <= min_x or max_y <= min_y: return 5
        torso_parse_area = parse_array[min_y:max_y, min_x:max_x]; unique_labels, counts = np.unique(torso_parse_area, return_counts=True); non_zero_mask = (unique_labels != 0)
        if np.any(non_zero_mask): return unique_labels[non_zero_mask][np.argmax(counts[non_zero_mask])]
        else: return 5
    def __getitem__(self, idx):
        image_name = self.image_files[idx]; base_name = os.path.splitext(image_name)[0]
        person_image = Image.open(os.path.join(self.image_dir, image_name)).convert('RGB'); cloth_image = Image.open(os.path.join(self.cloth_dir, image_name)).convert('RGB'); cloth_mask = Image.open(os.path.join(self.cloth_mask_dir, image_name)).convert('L')
        try:
            with open(os.path.join(self.pose_dir, f"{base_name}_keypoints.json"), 'r') as f: pose_data = json.load(f)
            pose_keypoints = np.array(pose_data['people'][0]['pose_keypoints_2d']).reshape(-1, 3)
        except (FileNotFoundError, IndexError): pose_keypoints = np.zeros((25, 3), dtype=np.float32)
        parse_path = os.path.join(self.parse_dir, f"{base_name}.png")
        if not os.path.exists(parse_path): parse_path = os.path.join(self.parse_dir, f"{base_name}.jpg")
        parse_array_orig = np.array(Image.open(parse_path).convert('L'))
        upper_cloth_label = self.find_upper_cloth_label(pose_keypoints, parse_array_orig)
        parse_array_resized = cv2.resize(parse_array_orig, self.image_size[::-1], interpolation=cv2.INTER_NEAREST)
        person_cloth_mask = (parse_array_resized == upper_cloth_label).astype(np.float32)
        person_image_tensor = self.transform(person_image); cloth_image_tensor = self.transform(cloth_image); cloth_mask_tensor = self.mask_transform(cloth_mask)
        pose_map_tensor = torch.from_numpy(self.create_pose_map(pose_keypoints)).float()
        blurred_mask_tensor = torch.from_numpy(cv2.GaussianBlur(person_cloth_mask, (5, 5), 0)).unsqueeze(0)
        agnostic_person_tensor = person_image_tensor * (1 - blurred_mask_tensor)
        warped_cloth_tensor = person_image_tensor * torch.from_numpy(person_cloth_mask).unsqueeze(0)
        return {'person_image': person_image_tensor, 'cloth_image': cloth_image_tensor, 'cloth_mask': cloth_mask_tensor, 'agnostic_person': agnostic_person_tensor, 'pose_map': pose_map_tensor, 'warped_cloth': warped_cloth_tensor}
    def create_pose_map(self, keypoints):
        h, w = self.image_size; orig_w, orig_h = self.original_image_size; num_keypoints = keypoints.shape[0]
        pose_map = np.zeros((num_keypoints, h, w), dtype=np.float32)
        for i, point in enumerate(keypoints):
            if point[2] > 0.1:
                x, y = int(point[0] * w / orig_w), int(point[1] * h / orig_h)
                if 0 <= x < w and 0 <= y < h: cv2.circle(pose_map[i], (x, y), radius=3, color=1, thickness=-1)
        return pose_map

# Using the robust U-Net GMM architecture from our first successful attempt
class GMM(nn.Module):
    def __init__(self, in_channels_cloth=3, in_channels_pose=25, out_channels_flow=2):
        super(GMM, self).__init__(); self.encoder1 = self.conv_block(in_channels_cloth + in_channels_pose, 64); self.encoder2 = self.conv_block(64, 128); self.encoder3 = self.conv_block(128, 256); self.encoder4 = self.conv_block(256, 512); self.pool = nn.MaxPool2d(2, 2); self.bottleneck = self.conv_block(512, 1024); self.upconv4 = nn.ConvTranspose2d(1024, 512, kernel_size=2, stride=2); self.decoder4 = self.conv_block(1024, 512); self.upconv3 = nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2); self.decoder3 = self.conv_block(512, 256); self.upconv2 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2); self.decoder2 = self.conv_block(256, 128); self.upconv1 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2); self.decoder1 = self.conv_block(128, 64); self.conv_out = nn.Conv2d(64, out_channels_flow, kernel_size=1); self.tanh = nn.Tanh()
    def conv_block(self, in_channels, out_channels): return nn.Sequential(nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1), nn.ReLU(inplace=True), nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1), nn.ReLU(inplace=True))
    def forward(self, cloth_image, pose_map):
        x = torch.cat([cloth_image, pose_map], dim=1); e1 = self.encoder1(x); p1 = self.pool(e1); e2 = self.encoder2(p1); p2 = self.pool(e2); e3 = self.encoder3(p2); p3 = self.pool(e3); e4 = self.encoder4(p3); p4 = self.pool(e4); b = self.bottleneck(p4); d4 = self.upconv4(b); d4 = torch.cat([d4, e4], dim=1); d4 = self.decoder4(d4); d3 = self.upconv3(d4); d3 = torch.cat([d3, e3], dim=1); d3 = self.decoder3(d3); d2 = self.upconv2(d3); d2 = torch.cat([d2, e2], dim=1); d2 = self.decoder2(d2); d1 = self.upconv1(d2); d1 = torch.cat([d1, e1], dim=1); d1 = self.decoder1(d1); flow_field = self.conv_out(d1); flow_field = self.tanh(flow_field); return flow_field

def warp_cloth_with_flow(cloth_image, flow_field):
    flow_field = flow_field.permute(0, 2, 3, 1); warped_image = F.grid_sample(cloth_image, flow_field, mode='bilinear', padding_mode='zeros', align_corners=True); return warped_image

class TVLoss(nn.Module):
    def __init__(self): super(TVLoss, self).__init__()
    def forward(self, x):
        batch_size, c, h, w = x.size(); tv_h = torch.pow(x[:,:,1:,:] - x[:,:,:-1,:], 2).sum(); tv_w = torch.pow(x[:,:,:,1:] - x[:,:,:,:-1], 2).sum()
        return (tv_h + tv_w) / (batch_size * c * h * w)

def setup(rank, world_size, sync_file):
    init_method = f'file://{sync_file}'
    dist.init_process_group("nccl", init_method=init_method, rank=rank, world_size=world_size)

def cleanup():
    dist.destroy_process_group()

def train(rank, world_size, sync_file):
    setup(rank, world_size, sync_file)
    DATA_ROOT = '/kaggle/input/clothe/clothes_tryon_dataset/train'; CHECKPOINT_DIR = '/kaggle/working/coarse_gmm_checkpoints/'; VISUALIZATION_DIR = '/kaggle/working/coarse_gmm_visuals/'
    os.makedirs(CHECKPOINT_DIR, exist_ok=True); os.makedirs(VISUALIZATION_DIR, exist_ok=True)
    BATCH_SIZE = 16; NUM_EPOCHS = 50; LEARNING_RATE = 2e-5; IMAGE_SIZE = (256, 192); ORIGINAL_IMAGE_SIZE = (768, 1024)
    torch.cuda.set_device(rank)
    model = GMM(in_channels_pose=25).to(rank) # Using the stable GMM architecture
    model = DDP(model, device_ids=[rank])
    optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE, betas=(0.5, 0.999))
    l1_loss_fn = nn.L1Loss().to(rank); tv_loss_fn = TVLoss().to(rank)
    scaler = torch.cuda.amp.GradScaler()
    start_epoch = 0; checkpoint_path = os.path.join(CHECKPOINT_DIR, 'coarse_gmm_latest.pth')
    if os.path.isfile(checkpoint_path) and rank == 0:
        checkpoint = torch.load(checkpoint_path, map_location={'cuda:0': f'cuda:{rank}'})
        model.module.load_state_dict(checkpoint['model_state_dict']); optimizer.load_state_dict(checkpoint['optimizer_state_dict']); scaler.load_state_dict(checkpoint['scaler_state_dict']); start_epoch = checkpoint['epoch'] + 1
    dist.barrier()
    dataset = FashionVTONDataset(data_root=DATA_ROOT, image_size=IMAGE_SIZE, original_image_size=ORIGINAL_IMAGE_SIZE)
    sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank, shuffle=True)
    loader = DataLoader(dataset, batch_size=BATCH_SIZE, sampler=sampler, num_workers=4, pin_memory=True)
    for epoch in range(start_epoch, NUM_EPOCHS):
        sampler.set_epoch(epoch); model.train()
        progress_bar = tqdm(loader, desc=f"Epoch {epoch+1}/{NUM_EPOCHS}", disable=(rank != 0))
        for batch in progress_bar:
            optimizer.zero_grad()
            cloth_image = batch['cloth_image'].to(rank); pose_map = batch['pose_map'].to(rank); ground_truth_warped = batch['warped_cloth'].to(rank)
            with torch.cuda.amp.autocast():
                predicted_flow = model(cloth_image, pose_map)
                predicted_warped = warp_cloth_with_flow(cloth_image, predicted_flow)
                # --- LOSS IS L1 + TV ONLY ---
                loss_l1 = l1_loss_fn(predicted_warped, ground_truth_warped)
                loss_tv = tv_loss_fn(predicted_flow)
                total_loss = loss_l1 + 0.5 * loss_tv
            scaler.scale(total_loss).backward(); scaler.step(optimizer); scaler.update()
            if rank == 0: progress_bar.set_postfix(loss=total_loss.item())
        if rank == 0:
            latest_checkpoint_state = {'epoch': epoch, 'model_state_dict': model.module.state_dict(), 'optimizer_state_dict': optimizer.state_dict(), 'scaler_state_dict': scaler.state_dict()}
            torch.save(latest_checkpoint_state, checkpoint_path)
            with torch.no_grad():
                vis_batch = torch.cat([(cloth_image.cpu() + 1) / 2, (predicted_warped.cpu().detach() + 1) / 2, (ground_truth_warped.cpu() + 1) / 2], dim=0)
                save_image(vis_batch, os.path.join(VISUALIZATION_DIR, f'epoch_{epoch+1}_comparison.png'), nrow=BATCH_SIZE)
            print(f"Epoch {epoch+1} finished. Checkpoint saved.")
    if rank == 0:
        final_model_path = os.path.join(CHECKPOINT_DIR, 'coarse_gmm_final.pth'); torch.save(model.module.state_dict(), final_model_path); print(f"Final model saved to {final_model_path}")
    cleanup()

if __name__ == '__main__':
    world_size = torch.cuda.device_count()
    if world_size < 2: print("Distributed training requires at least 2 GPUs.")
    else:
        sync_file_path = os.path.join('/kaggle/working', 'ddp_sync_file')
        if os.path.exists(sync_file_path): os.remove(sync_file_path)
        torch.multiprocessing.spawn(train, args=(world_size, sync_file_path), nprocs=world_size, join=True)

Writing train_coarse_gmm.py


In [1]:
!python train_coarse_gmm.py

python3: can't open file '/kaggle/working/train_coarse_gmm.py': [Errno 2] No such file or directory


In [51]:
%%writefile train_coarse_gmm.py
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torchvision.utils import save_image
import os
from tqdm import tqdm
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data.distributed import DistributedSampler
import json
import cv2
import numpy as np
from PIL import Image
from torchvision import transforms

# ===================================================================
# CLASS DEFINITIONS (Unchanged)
# ===================================================================
class FashionVTONDataset(Dataset):
    def __init__(self, data_root, image_size=(256, 192), original_image_size=(768, 1024)):
        self.data_root = data_root; self.image_size = image_size; self.original_image_size = original_image_size; self.image_dir = os.path.join(data_root, 'image'); self.cloth_dir = os.path.join(data_root, 'cloth'); self.cloth_mask_dir = os.path.join(data_root, 'cloth-mask'); self.pose_dir = os.path.join(data_root, 'openpose_json'); self.parse_dir = os.path.join(data_root, 'image-parse-v3'); self.image_files = sorted([f for f in os.listdir(self.image_dir) if f.endswith(('.jpg', '.png'))]); self.transform = transforms.Compose([transforms.Resize(self.image_size, interpolation=transforms.InterpolationMode.BILINEAR), transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]); self.mask_transform = transforms.Compose([transforms.Resize(self.image_size, interpolation=transforms.InterpolationMode.NEAREST), transforms.ToTensor()])
    def __len__(self): return len(self.image_files)
    def find_upper_cloth_label(self, pose_keypoints, parse_array):
        torso_indices = [1, 2, 5, 8, 9, 12]; visible_points = []
        for i in torso_indices:
            if i < len(pose_keypoints): x, y, conf = pose_keypoints[i];
            if conf > 0.1: visible_points.append((int(x), int(y)))
        if len(visible_points) < 3: return 5
        x_coords, y_coords = zip(*visible_points); min_x, max_x = min(x_coords), max(x_coords); min_y, max_y = min(y_coords), max(y_coords)
        if max_x <= min_x or max_y <= min_y: return 5
        torso_parse_area = parse_array[min_y:max_y, min_x:max_x]; unique_labels, counts = np.unique(torso_parse_area, return_counts=True); non_zero_mask = (unique_labels != 0)
        if np.any(non_zero_mask): return unique_labels[non_zero_mask][np.argmax(counts[non_zero_mask])]
        else: return 5
    def __getitem__(self, idx):
        image_name = self.image_files[idx]; base_name = os.path.splitext(image_name)[0]
        person_image = Image.open(os.path.join(self.image_dir, image_name)).convert('RGB'); cloth_image = Image.open(os.path.join(self.cloth_dir, image_name)).convert('RGB'); cloth_mask = Image.open(os.path.join(self.cloth_mask_dir, image_name)).convert('L')
        try:
            with open(os.path.join(self.pose_dir, f"{base_name}_keypoints.json"), 'r') as f: pose_data = json.load(f)
            pose_keypoints = np.array(pose_data['people'][0]['pose_keypoints_2d']).reshape(-1, 3)
        except (FileNotFoundError, IndexError): pose_keypoints = np.zeros((25, 3), dtype=np.float32)
        parse_path = os.path.join(self.parse_dir, f"{base_name}.png")
        if not os.path.exists(parse_path): parse_path = os.path.join(self.parse_dir, f"{base_name}.jpg")
        parse_array_orig = np.array(Image.open(parse_path).convert('L'))
        upper_cloth_label = self.find_upper_cloth_label(pose_keypoints, parse_array_orig)
        parse_array_resized = cv2.resize(parse_array_orig, self.image_size[::-1], interpolation=cv2.INTER_NEAREST)
        person_cloth_mask = (parse_array_resized == upper_cloth_label).astype(np.float32)
        person_image_tensor = self.transform(person_image); cloth_image_tensor = self.transform(cloth_image); cloth_mask_tensor = self.mask_transform(cloth_mask)
        pose_map_tensor = torch.from_numpy(self.create_pose_map(pose_keypoints)).float()
        blurred_mask_tensor = torch.from_numpy(cv2.GaussianBlur(person_cloth_mask, (5, 5), 0)).unsqueeze(0)
        agnostic_person_tensor = person_image_tensor * (1 - blurred_mask_tensor)
        warped_cloth_tensor = person_image_tensor * torch.from_numpy(person_cloth_mask).unsqueeze(0)
        return {'person_image': person_image_tensor, 'cloth_image': cloth_image_tensor, 'cloth_mask': cloth_mask_tensor, 'agnostic_person': agnostic_person_tensor, 'pose_map': pose_map_tensor, 'warped_cloth': warped_cloth_tensor}
    def create_pose_map(self, keypoints):
        h, w = self.image_size; orig_w, orig_h = self.original_image_size; num_keypoints = keypoints.shape[0]
        pose_map = np.zeros((num_keypoints, h, w), dtype=np.float32)
        for i, point in enumerate(keypoints):
            if point[2] > 0.1:
                x, y = int(point[0] * w / orig_w), int(point[1] * h / orig_h)
                if 0 <= x < w and 0 <= y < h: cv2.circle(pose_map[i], (x, y), radius=3, color=1, thickness=-1)
        return pose_map

class GMM(nn.Module):
    def __init__(self, in_channels_cloth=3, in_channels_pose=25, out_channels_flow=2):
        super(GMM, self).__init__(); self.encoder1 = self.conv_block(in_channels_cloth + in_channels_pose, 64); self.encoder2 = self.conv_block(64, 128); self.encoder3 = self.conv_block(128, 256); self.encoder4 = self.conv_block(256, 512); self.pool = nn.MaxPool2d(2, 2); self.bottleneck = self.conv_block(512, 1024); self.upconv4 = nn.ConvTranspose2d(1024, 512, kernel_size=2, stride=2); self.decoder4 = self.conv_block(1024, 512); self.upconv3 = nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2); self.decoder3 = self.conv_block(512, 256); self.upconv2 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2); self.decoder2 = self.conv_block(256, 128); self.upconv1 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2); self.decoder1 = self.conv_block(128, 64); self.conv_out = nn.Conv2d(64, out_channels_flow, kernel_size=1); self.tanh = nn.Tanh()
    def conv_block(self, in_channels, out_channels): return nn.Sequential(nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1), nn.ReLU(inplace=True), nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1), nn.ReLU(inplace=True))
    def forward(self, cloth_image, pose_map):
        x = torch.cat([cloth_image, pose_map], dim=1); e1 = self.encoder1(x); p1 = self.pool(e1); e2 = self.encoder2(p1); p2 = self.pool(e2); e3 = self.encoder3(p2); p3 = self.pool(e3); e4 = self.encoder4(p3); p4 = self.pool(e4); b = self.bottleneck(p4); d4 = self.upconv4(b); d4 = torch.cat([d4, e4], dim=1); d4 = self.decoder4(d4); d3 = self.upconv3(d4); d3 = torch.cat([d3, e3], dim=1); d3 = self.decoder3(d3); d2 = self.upconv2(d3); d2 = torch.cat([d2, e2], dim=1); d2 = self.decoder2(d2); d1 = self.upconv1(d2); d1 = torch.cat([d1, e1], dim=1); d1 = self.decoder1(d1); flow_field = self.conv_out(d1); flow_field = self.tanh(flow_field); return flow_field

def warp_cloth_with_flow(cloth_image, flow_field):
    flow_field = flow_field.permute(0, 2, 3, 1); warped_image = F.grid_sample(cloth_image, flow_field, mode='bilinear', padding_mode='zeros', align_corners=True); return warped_image

class TVLoss(nn.Module):
    def __init__(self): super(TVLoss, self).__init__()
    def forward(self, x):
        batch_size, c, h, w = x.size(); tv_h = torch.pow(x[:,:,1:,:] - x[:,:,:-1,:], 2).sum(); tv_w = torch.pow(x[:,:,:,1:] - x[:,:,:,:-1], 2).sum()
        return (tv_h + tv_w) / (batch_size * c * h * w)

def setup(rank, world_size, sync_file):
    init_method = f'file://{sync_file}'
    dist.init_process_group("nccl", init_method=init_method, rank=rank, world_size=world_size)

def cleanup():
    dist.destroy_process_group()

def train(rank, world_size, sync_file):
    setup(rank, world_size, sync_file)
    DATA_ROOT = '/kaggle/input/clothe/clothes_tryon_dataset/train'; CHECKPOINT_DIR = '/kaggle/working/coarse_gmm_checkpoints_new1/'; VISUALIZATION_DIR = '/kaggle/working/coarse_gmm_visuals_new/'
    os.makedirs(CHECKPOINT_DIR, exist_ok=True); os.makedirs(VISUALIZATION_DIR, exist_ok=True)
    BATCH_SIZE = 16; NUM_EPOCHS = 50; LEARNING_RATE = 2e-5; IMAGE_SIZE = (256, 192); ORIGINAL_IMAGE_SIZE = (768, 1024)
    torch.cuda.set_device(rank)
    model = GMM(in_channels_pose=25).to(rank) 
    model = DDP(model, device_ids=[rank])
    optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE, betas=(0.5, 0.999))
    l1_loss_fn = nn.L1Loss().to(rank); tv_loss_fn = TVLoss().to(rank)
    scaler = torch.cuda.amp.GradScaler()
    start_epoch = 0; checkpoint_path = os.path.join(CHECKPOINT_DIR, 'coarse_gmm_latest.pth')
    if os.path.isfile(checkpoint_path) and rank == 0:
        checkpoint = torch.load(checkpoint_path, map_location={'cuda:0': f'cuda:{rank}'})
        model.module.load_state_dict(checkpoint['model_state_dict']); optimizer.load_state_dict(checkpoint['optimizer_state_dict']); scaler.load_state_dict(checkpoint['scaler_state_dict']); start_epoch = checkpoint['epoch'] + 1
    dist.barrier()
    dataset = FashionVTONDataset(data_root=DATA_ROOT, image_size=IMAGE_SIZE, original_image_size=ORIGINAL_IMAGE_SIZE)
    sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank, shuffle=True)
    loader = DataLoader(dataset, batch_size=BATCH_SIZE, sampler=sampler, num_workers=4, pin_memory=True)
    for epoch in range(start_epoch, NUM_EPOCHS):
        sampler.set_epoch(epoch); model.train()
        progress_bar = tqdm(loader, desc=f"Epoch {epoch+1}/{NUM_EPOCHS}", disable=(rank != 0))
        for batch in progress_bar:
            optimizer.zero_grad()
            cloth_image = batch['cloth_image'].to(rank)
            cloth_mask = batch['cloth_mask'].to(rank) # <<< LOAD THE MASK
            pose_map = batch['pose_map'].to(rank)
            ground_truth_warped = batch['warped_cloth'].to(rank)
            
            with torch.cuda.amp.autocast():
                predicted_flow = model(cloth_image, pose_map)
                predicted_warped = warp_cloth_with_flow(cloth_image, predicted_flow)
                
                # --- THIS IS THE FIX: Warp the cloth mask as well ---
                warped_cloth_mask = F.grid_sample(cloth_mask, predicted_flow.permute(0, 2, 3, 1), mode='bilinear', padding_mode='zeros', align_corners=True)
                
                # --- And apply it to the L1 Loss ---
                loss_l1 = l1_loss_fn(predicted_warped * warped_cloth_mask, ground_truth_warped * warped_cloth_mask)
                loss_tv = tv_loss_fn(predicted_flow)
                total_loss = loss_l1 + 0.5 * loss_tv

            scaler.scale(total_loss).backward(); scaler.step(optimizer); scaler.update()
            if rank == 0: progress_bar.set_postfix(loss=total_loss.item())
                    # ... inside the "if rank == 0:" block at the end of the epoch loop ...
            with torch.no_grad():
                # --- THIS IS THE FIX ---
                # We should visualize the ground_truth_warped image, not the mask.
                vis_batch = torch.cat([
                    (cloth_image.cpu() + 1) / 2, 
                    (predicted_warped.cpu().detach() + 1) / 2, 
                    (ground_truth_warped.cpu() + 1) / 2
                ], dim=0)
                save_image(vis_batch, os.path.join(VISUALIZATION_DIR, f'epoch_{epoch+1}_comparison.png'), nrow=BATCH_SIZE)
    if rank == 0:
        final_model_path = os.path.join(CHECKPOINT_DIR, 'coarse_gmm_final.pth'); torch.save(model.module.state_dict(), final_model_path); print(f"Final model saved to {final_model_path}")
    cleanup()

if __name__ == '__main__':
    world_size = torch.cuda.device_count()
    if world_size < 2: print("Distributed training requires at least 2 GPUs.")
    else:
        sync_file_path = os.path.join('/kaggle/working', 'ddp_sync_file')
        if os.path.exists(sync_file_path): os.remove(sync_file_path)
        torch.multiprocessing.spawn(train, args=(world_size, sync_file_path), nprocs=world_size, join=True)

Overwriting train_coarse_gmm.py


In [52]:
!python train_coarse_gmm.py

  scaler = torch.cuda.amp.GradScaler()
  scaler = torch.cuda.amp.GradScaler()
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
Epoch 1/50: 100%|█████████████████| 364/364 [05:53<00:00,  1.03it/s, loss=0.537]
Epoch 2/50: 100%|█████████████████| 364/364 [05:52<00:00,  1.03it/s, loss=0.573]
Epoch 3/50: 100%|█████████████████| 364/364 [05:49<00:00,  1.04it/s, loss=0.576]
Epoch 4/50: 100%|█████████████████| 364/364 [05:47<00:00,  1.05it/s, loss=0.493]
Epoch 5/50:  23%|████              | 83/364 [01:31<04:17,  1.09it/s, loss=0.509]^C
Traceback (most recent call last):
  File "/kaggle/working/train_coarse_gmm.py", line 150, in <module>
    torch.multiprocessing.spawn(train, args=(world_size, sync_file_path), nprocs=world_size, join=True)
  File "/usr/local/lib/python3.11/dist-packages/torch/multiprocessing/spawn.py", line 340, in spawn
    return start_processes(fn, args, nprocs, join, daemon, start_method="spawn")
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^