In [None]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here are several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
# Summarise the read-only input tree one directory at a time. The later cells
# read whole folders of images from here, so printing every single path would
# flood the cell output; show a file count and a few sample names instead.
MAX_LISTED_FILES_PER_DIR = 5
for dirname, _, filenames in os.walk('/kaggle/input'):
    if not filenames:
        continue
    print(f"{dirname}: {len(filenames)} file(s)")
    for filename in sorted(filenames)[:MAX_LISTED_FILES_PER_DIR]:
        print(os.path.join(dirname, filename))
    remaining = len(filenames) - MAX_LISTED_FILES_PER_DIR
    if remaining > 0:
        print(f"  ... and {remaining} more")

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

In [None]:
import torch
import torch.utils.data as data
import torchvision.transforms as transforms
from PIL import Image
import os.path as osp
import numpy as np

class CustomDataset(data.Dataset):
    """Paired person / garment dataset for the try-on pipeline.

    ``opt`` must provide ``dataroot``, ``datamode``, ``data_list``,
    ``fine_height`` and ``fine_width``. The pair list is read from
    ``<dataroot>/<data_list>`` (one ``<person image> <cloth image>`` pair per
    line) and the image files from ``<dataroot>/<datamode>/<sub-folder>/``.

    ``fine_height`` / ``fine_width`` are only stored on the instance: no
    resizing happens here, so every sample keeps the resolution of its files.
    """

    def __init__(self, opt):
        super(CustomDataset, self).__init__()
        self.opt = opt
        self.dataroot = opt.dataroot
        self.datamode = opt.datamode
        self.data_list = opt.data_list
        self.fine_height = opt.fine_height
        self.fine_width = opt.fine_width
        self.data_path = osp.join(self.dataroot, self.datamode)
        # ToTensor maps pixel values to [0, 1]; the Normalize step then shifts
        # each of the three channels to [-1, 1].
        self.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])

        # Load data list
        self.im_names, self.c_names = self._load_pairs(
            osp.join(self.dataroot, self.data_list))

    @staticmethod
    def _load_pairs(list_path):
        """Parse the pair list into ``(im_names, c_names)``.

        Blank lines are skipped. Any other line must contain exactly two
        whitespace-separated fields.

        Raises:
            ValueError: if a non-blank line does not have exactly two fields.
        """
        im_names = []
        c_names = []
        with open(list_path, 'r') as f:
            for line_no, line in enumerate(f, start=1):
                fields = line.split()
                if not fields:
                    # A trailing newline or an empty separator line is not a pair.
                    continue
                if len(fields) != 2:
                    raise ValueError(
                        f"{list_path}:{line_no}: expected '<image> <cloth>', "
                        f"got {line.strip()!r}")
                im_names.append(fields[0])
                c_names.append(fields[1])
        return im_names, c_names

    def _load_rgb(self, folder, file_name):
        """Open ``<data_path>/<folder>/<file_name>`` as RGB and apply ``self.transform``."""
        picture = Image.open(osp.join(self.data_path, folder, file_name)).convert('RGB')
        return self.transform(picture)

    def __getitem__(self, index):
        """Return one sample as a dict of file names and tensors.

        Keys: ``c_name``, ``im_name`` (strings); ``cloth``, ``densepose``,
        ``pose``, ``image`` (3-channel, normalised to [-1, 1]);
        ``cloth_mask`` (1-channel float32 holding 0.0 / 1.0);
        ``parse_agnostic`` (1-channel, ToTensor-scaled to [0, 1]).
        """
        c_name = self.c_names[index]
        im_name = self.im_names[index]
        # Derived files share the person image's base name but use another
        # suffix; splitext only touches the real extension, whatever it is.
        im_stem = osp.splitext(im_name)[0]

        # Clothing image
        c = self._load_rgb('cloth', c_name)

        # Clothing mask. Force single-channel 8-bit first: an RGB mask would
        # otherwise give a 4-D tensor after unsqueeze, and a mode-'1' mask
        # loads as booleans, for which the ">= 128" threshold is always False.
        cloth_mask_path = osp.join(self.data_path, 'cloth-mask', c_name)
        cm_array = np.array(Image.open(cloth_mask_path).convert('L'))
        cm_array = (cm_array >= 128).astype(np.float32)
        cm = torch.from_numpy(cm_array).unsqueeze(0)

        # Person image (for visualization and ground truth)
        im = self._load_rgb('image', im_name)

        # Person's parse-agnostic representation
        parse_agnostic_path = osp.join(self.data_path, 'image-parse-agnostic-v3.2', im_stem + '.png')
        parse_agnostic = Image.open(parse_agnostic_path).convert('L')
        parse_agnostic = transforms.ToTensor()(parse_agnostic)

        # Densepose
        densepose = self._load_rgb('image-densepose', im_name)

        # OpenPose image (rendered keypoints)
        pose_map = self._load_rgb('openpose_img', im_stem + '_rendered.png')

        result = {
            # Inputs for the Condition Generator (TOCG)
            'c_name':         c_name,
            'im_name':        im_name,
            'cloth':          c,
            'cloth_mask':     cm,
            'parse_agnostic': parse_agnostic,
            'densepose':      densepose,
            'pose':           pose_map,
            # Ground truth image for the final Generator
            'image':          im,
        }

        return result

    def __len__(self):
        """Number of (person, cloth) pairs read from the pair list."""
        return len(self.im_names)


class CPDataLoader:
    """Wraps a ``torch.utils.data.DataLoader`` and hands out batches endlessly.

    ``opt`` supplies ``shuffle``, ``batch_size`` and ``workers``. Memory
    pinning and dropping of the last incomplete batch are always enabled.
    The underlying loader is exposed as ``data_loader`` and the dataset as
    ``dataset``.
    """

    def __init__(self, opt, dataset):
        super(CPDataLoader, self).__init__()

        loader_settings = dict(
            batch_size=opt.batch_size,
            shuffle=opt.shuffle,
            num_workers=opt.workers,
            pin_memory=True,
            drop_last=True,
        )
        self.data_loader = torch.utils.data.DataLoader(dataset, **loader_settings)
        self.dataset = dataset
        # An iterator is opened right away so next_batch() can be called immediately.
        self.data_iter = iter(self.data_loader)

    def next_batch(self):
        """Return the next batch, starting a new pass once the current one is exhausted."""
        try:
            return next(self.data_iter)
        except StopIteration:
            self.data_iter = iter(self.data_loader)
            return next(self.data_iter)

In [None]:
import argparse
import torch


import traceback

# Name of the dataset folder under /kaggle/input.
YOUR_KAGGLE_DATASET_NAME = 'clothes_tryon_dataset'


opt = argparse.Namespace()

# Build the path from the constant so that editing the name above takes effect.
opt.dataroot = f'/kaggle/input/{YOUR_KAGGLE_DATASET_NAME}'
opt.datamode = 'train'
opt.data_list = 'train_pairs.txt'

# --- Image & Batch Settings ---
opt.fine_height = 1024
opt.fine_width = 768
opt.batch_size = 4
opt.workers = 2


# Keep the order fixed so the inspected batch is the same on every run.
opt.shuffle = False


print("--- Test Script Initialized ---")
print(f"Dataroot set to: {opt.dataroot}")
print(f"Loading data list: {opt.data_list}")


try:
    # --- Initialize the Dataset ---
    print("Initializing CustomDataset...")
    # This assumes CustomDataset and CPDataLoader are defined in the previous cell
    train_dataset = CustomDataset(opt)
    print(f"Dataset initialized successfully. Found {len(train_dataset)} image pairs.")

    # --- Initialize the DataLoader ---
    print("\nInitializing CPDataLoader...")
    train_loader = CPDataLoader(opt, train_dataset)
    print("Dataloader initialized successfully.")

    # --- Fetch One Batch ---
    print("\nAttempting to load one batch of data...")
    first_batch = train_loader.next_batch()
    print("Successfully loaded one batch!")

    # --- Inspect the Batch Content ---
    print("\n--- Batch Content Verification ---")
    for key, value in first_batch.items():
        if isinstance(value, torch.Tensor):
            print(f"  - Key: '{key}', Shape: {value.shape}, DType: {value.dtype}")
        else:
            # This will print the list of image/cloth names
            print(f"  - Key: '{key}', Type: {type(value)}, Length: {len(value)}")
    print("-" * 30)
    print("\n DATA LOADING TEST PASSED! ")
    print("You are now ready to define the network architectures.")

except FileNotFoundError as e:
    print("\n--- ❌ DATA LOADING TEST FAILED: File Not Found ❌ ---")
    print(f"ERROR: {e}")
    print("\nPlease double-check the following:")
    print(f"1. Is YOUR_KAGGLE_DATASET_NAME set correctly to '{YOUR_KAGGLE_DATASET_NAME}'?")
    print("2. Does your Kaggle dataset have the exact folder structure we discussed?")
    print("   (e.g., /train/cloth, /train/image, etc.)")

except Exception as e:
    print("\n---  AN UNEXPECTED ERROR OCCURRED  ---")
    print(f"ERROR: {e}")
    # The message alone rarely locates the problem; show where it was raised
    # while still letting the notebook continue.
    traceback.print_exc()

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import re
import functools
import numpy as np
from torch.nn.utils import spectral_norm

#==============================================================================
# Helper Functions and Classes
#==============================================================================

class UnetSkipConnectionBlock(nn.Module):
    """One level of a U-Net: stride-2 down-conv, optional submodule, stride-2 up-conv.

    Args:
        outer_nc: channels produced by this block's transposed convolution.
        inner_nc: channels produced by this block's down-convolution.
        input_nc: channels entering the block; defaults to ``outer_nc``.
        submodule: the block nested between the down and up paths
            (not used when ``innermost`` is set).
        outermost: if True the block ends in ``Tanh`` and its forward pass
            returns the result without concatenating the input.
        innermost: if True there is no submodule and no norm after the
            down-convolution.
        norm_layer: normalisation layer class or ``functools.partial`` of one.
        use_dropout: append ``Dropout(0.5)`` (intermediate blocks only).
    """

    def __init__(self, outer_nc, inner_nc, input_nc=None,
                 submodule=None, outermost=False, innermost=False,
                 norm_layer=nn.BatchNorm2d, use_dropout=False):
        super(UnetSkipConnectionBlock, self).__init__()
        self.outermost = outermost

        # The convolutions get a bias term only when the norm is InstanceNorm2d.
        norm_cls = norm_layer.func if type(norm_layer) == functools.partial else norm_layer
        use_bias = norm_cls == nn.InstanceNorm2d

        if input_nc is None:
            input_nc = outer_nc

        down_conv = nn.Conv2d(input_nc, inner_nc, kernel_size=4,
                              stride=2, padding=1, bias=use_bias)
        down_act = nn.LeakyReLU(0.2, True)
        down_norm = norm_layer(inner_nc)
        up_act = nn.ReLU(True)
        up_norm = norm_layer(outer_nc)

        def build_up_conv(in_channels, **extra):
            # Every variant shares the 4x4 / stride-2 / padding-1 geometry.
            return nn.ConvTranspose2d(in_channels, outer_nc,
                                      kernel_size=4, stride=2,
                                      padding=1, **extra)

        if outermost:
            # The up-conv sees inner_nc * 2 channels because the submodule
            # concatenates its own input onto its output.
            layers = [down_conv, submodule,
                      up_act, build_up_conv(inner_nc * 2), nn.Tanh()]
        elif innermost:
            layers = [down_act, down_conv,
                      up_act, build_up_conv(inner_nc, bias=use_bias), up_norm]
        else:
            layers = [down_act, down_conv, down_norm, submodule,
                      up_act, build_up_conv(inner_nc * 2, bias=use_bias), up_norm]
            if use_dropout:
                layers.append(nn.Dropout(0.5))

        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Run the block; every block except the outermost concatenates ``x`` on dim 1."""
        out = self.model(x)
        if self.outermost:
            return out
        return torch.cat([x, out], 1)



class UnetGenerator(nn.Module):
    """
    U-Net generator assembled from nested ``UnetSkipConnectionBlock`` levels.
    This is used for the Try-On Condition Generator (TOCG).

    Args:
        input_nc: channels of the network input.
        output_nc: channels of the network output.
        num_downs: total number of down-sampling levels; levels beyond the
            five fixed ones are added as ``ngf * 8``-wide blocks.
        ngf: base channel width of the outermost level.
        norm_layer: normalisation layer passed to every block.
        use_dropout: enable dropout in the extra ``ngf * 8`` blocks.
    """
    def __init__(self, input_nc, output_nc, num_downs, ngf=64,
                 norm_layer=nn.BatchNorm2d, use_dropout=False):
        super(UnetGenerator, self).__init__()

        # Build from the bottleneck outwards, wrapping the previous block each time.
        widest = ngf * 8
        block = UnetSkipConnectionBlock(widest, widest, input_nc=None, submodule=None,
                                        norm_layer=norm_layer, innermost=True)

        extra_levels = num_downs - 5
        for _ in range(extra_levels):
            block = UnetSkipConnectionBlock(widest, widest, input_nc=None, submodule=block,
                                            norm_layer=norm_layer, use_dropout=use_dropout)

        # Channel width halves at each level on the way out: 8 -> 4 -> 2 -> 1 times ngf.
        for outer_mult, inner_mult in ((4, 8), (2, 4), (1, 2)):
            block = UnetSkipConnectionBlock(ngf * outer_mult, ngf * inner_mult, input_nc=None,
                                            submodule=block, norm_layer=norm_layer)

        self.model = UnetSkipConnectionBlock(output_nc, ngf, input_nc=input_nc, submodule=block,
                                             outermost=True, norm_layer=norm_layer)

    def forward(self, input):
        """Pass ``input`` through the nested blocks."""
        return self.model(input)


class NLayerDiscriminator(nn.Module):
    """Fully convolutional discriminator that outputs a one-channel score map.

    Args:
        input_nc: channels of the input.
        ndf: channel width of the first convolution.
        n_layers: number of stride-2 convolutions (including the first one).
        norm_layer: normalisation layer class or ``functools.partial`` of one.
        use_sigmoid: append a ``Sigmoid`` after the final convolution.
    """

    def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d, use_sigmoid=False):
        super(NLayerDiscriminator, self).__init__()
        # The normalised convolutions get a bias term only with InstanceNorm2d.
        norm_cls = norm_layer.func if type(norm_layer) == functools.partial else norm_layer
        use_bias = norm_cls == nn.InstanceNorm2d

        kw, padw = 4, 1
        layers = [
            nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw),
            nn.LeakyReLU(0.2, True),
        ]

        # Stride-2 stages: the width doubles each time, capped at 8 * ndf.
        mult = 1
        for depth in range(1, n_layers):
            prev_mult, mult = mult, min(2 ** depth, 8)
            layers.extend([
                nn.Conv2d(ndf * prev_mult, ndf * mult,
                          kernel_size=kw, stride=2, padding=padw, bias=use_bias),
                norm_layer(ndf * mult),
                nn.LeakyReLU(0.2, True),
            ])

        # One more widening stage, this time at stride 1.
        prev_mult, mult = mult, min(2 ** n_layers, 8)
        layers.extend([
            nn.Conv2d(ndf * prev_mult, ndf * mult,
                      kernel_size=kw, stride=1, padding=padw, bias=use_bias),
            norm_layer(ndf * mult),
            nn.LeakyReLU(0.2, True),
        ])

        # Project to a single-channel score map.
        layers.append(nn.Conv2d(ndf * mult, 1, kernel_size=kw, stride=1, padding=padw))

        if use_sigmoid:
            layers.append(nn.Sigmoid())

        self.model = nn.Sequential(*layers)

    def forward(self, input):
        """Return the score map for ``input``."""
        return self.model(input)




class SPADE(nn.Module):
    """Normalisation layer whose scale and shift are predicted from a label map.

    The input is normalised with a parameter-free ``InstanceNorm2d``; a small
    convolutional network then maps ``segmap`` to per-pixel ``gamma`` and
    ``beta`` and the output is ``normalized * (1 + gamma) + beta``.

    Args:
        config_text: string of the form ``spade<K>x<K>_norm_<type>``. Only the
            first kernel size ``K`` is used; the normalisation is always
            instance norm regardless of ``<type>``.
        norm_nc: channels of the activation being normalised.
        label_nc: channels of the conditioning map.

    Raises:
        ValueError: if ``config_text`` does not match the expected pattern.
    """

    def __init__(self, config_text, norm_nc, label_nc):
        super().__init__()
        parsed = re.search(r'spade(\d+)x(\d+)_norm_(\w+)', config_text)
        if parsed is None:
            # Fail with a readable message instead of an AttributeError on None.
            raise ValueError(
                f"Unrecognised SPADE config {config_text!r}; "
                "expected 'spade<K>x<K>_norm_<type>'")
        ks = int(parsed.group(1))
        pw = ks // 2  # keeps the spatial size unchanged for odd kernel sizes
        self.param_free_norm = nn.InstanceNorm2d(norm_nc, affine=False)

        nhidden = 128
        self.mlp_shared = nn.Sequential(
            nn.Conv2d(label_nc, nhidden, kernel_size=ks, padding=pw),
            nn.ReLU()
        )
        self.mlp_gamma = nn.Conv2d(nhidden, norm_nc, kernel_size=ks, padding=pw)
        self.mlp_beta = nn.Conv2d(nhidden, norm_nc, kernel_size=ks, padding=pw)

    def forward(self, x, segmap):
        """Normalise ``x`` and modulate it with parameters derived from ``segmap``.

        ``segmap`` is resized (nearest neighbour) to the spatial size of ``x``.
        """
        normalized = self.param_free_norm(x)
        segmap = F.interpolate(segmap, size=x.size()[2:], mode='nearest')
        actv = self.mlp_shared(segmap)
        gamma = self.mlp_gamma(actv)
        beta = self.mlp_beta(actv)
        out = normalized * (1 + gamma) + beta
        return out

class SPADEResnetBlock(nn.Module):
    """Residual block in which both convolutions are preceded by SPADE + leaky ReLU.

    Args:
        fin: input channels.
        fout: output channels.
        opt: options object; ``opt.semantic_nc`` gives the conditioning-map channels.

    When ``fin != fout`` the skip path goes through its own SPADE layer and a
    bias-free 1x1 convolution; otherwise the input is added unchanged.
    """

    def __init__(self, fin, fout, opt):
        super().__init__()
        needs_projection = fin != fout
        self.learned_shortcut = needs_projection
        hidden_nc = min(fin, fout)

        # Convolutions
        self.conv_0 = nn.Conv2d(fin, hidden_nc, kernel_size=3, padding=1)
        self.conv_1 = nn.Conv2d(hidden_nc, fout, kernel_size=3, padding=1)
        if needs_projection:
            self.conv_s = nn.Conv2d(fin, fout, kernel_size=1, bias=False)

        # Conditional normalisation layers, one in front of each convolution
        spade_cfg = 'spade3x3_norm_in'
        label_nc = opt.semantic_nc
        self.norm_0 = SPADE(spade_cfg, fin, label_nc)
        self.norm_1 = SPADE(spade_cfg, hidden_nc, label_nc)
        if needs_projection:
            self.norm_s = SPADE(spade_cfg, fin, label_nc)

    def forward(self, x, seg):
        """Return ``shortcut(x) + residual(x)``, both conditioned on ``seg``."""
        skip = self.shortcut(x, seg)
        residual = self.norm_0(x, seg)
        residual = self.conv_0(self.actvn(residual))
        residual = self.norm_1(residual, seg)
        residual = self.conv_1(self.actvn(residual))
        return skip + residual

    def shortcut(self, x, seg):
        """Skip path: identity, or SPADE followed by a 1x1 conv when channels change."""
        if not self.learned_shortcut:
            return x
        return self.conv_s(self.norm_s(x, seg))

    def actvn(self, x):
        """Leaky ReLU with negative slope 0.2."""
        return F.leaky_relu(x, 2e-1)


class SPADEGenerator(nn.Module):
    """
    The main SPADE-based generator for the final image synthesis.

    ``opt`` must provide ``ngf`` (base channel width), ``semantic_nc``
    (channels of the conditioning map), ``fine_width`` and ``fine_height``.

    The conditioning map is first shrunk to a ``(sh, sw)`` grid, then passed
    through SPADE residual blocks interleaved with five 2x upsamplings, so the
    output is ``(32 * sh, 32 * sw)`` pixels with three channels in [-1, 1].
    """
    def __init__(self, opt):
        super().__init__()
        self.opt = opt
        nf = opt.ngf
        self.sw, self.sh = self.compute_latent_vector_size(opt)

        self.fc = nn.Conv2d(self.opt.semantic_nc, 16 * nf, 3, padding=1)
        self.head_0 = SPADEResnetBlock(16 * nf, 16 * nf, opt)
        self.G_middle_0 = SPADEResnetBlock(16 * nf, 16 * nf, opt)
        self.G_middle_1 = SPADEResnetBlock(16 * nf, 16 * nf, opt)
        self.up_0 = SPADEResnetBlock(16 * nf, 8 * nf, opt)
        self.up_1 = SPADEResnetBlock(8 * nf, 4 * nf, opt)
        self.up_2 = SPADEResnetBlock(4 * nf, 2 * nf, opt)
        self.up_3 = SPADEResnetBlock(2 * nf, 1 * nf, opt)
        self.conv_img = nn.Conv2d(nf, 3, 3, padding=1)
        self.up = nn.Upsample(scale_factor=2)

    def compute_latent_vector_size(self, opt):
        """Return ``(sw, sh)``: the starting grid size before the five upsamplings.

        ``sw`` is ``fine_width // 32``; ``sh`` follows from the aspect ratio.
        """
        num_up_layers = 5  # must match the number of self.up calls in forward
        sw = opt.fine_width // (2**num_up_layers)
        sh = round(sw / (opt.fine_width / opt.fine_height))
        return sw, sh

    def forward(self, input, z=None):
        """Synthesise an image from the conditioning map ``input``.

        ``z`` is accepted for interface compatibility and is not used.
        """
        seg = input
        # Start from the small (sh, sw) grid. Without this step the five
        # upsamplings below would enlarge a full-resolution map 32x per side.
        x = F.interpolate(seg, size=(self.sh, self.sw))
        x = self.fc(x)
        x = self.head_0(x, seg)
        x = self.up(x)
        x = self.G_middle_0(x, seg)
        x = self.G_middle_1(x, seg)
        x = self.up(x)
        x = self.up_0(x, seg)
        x = self.up(x)
        x = self.up_1(x, seg)
        x = self.up(x)
        x = self.up_2(x, seg)
        x = self.up(x)
        x = self.up_3(x, seg)
        x = self.conv_img(F.leaky_relu(x, 2e-1))
        # torch.tanh replaces the deprecated F.tanh alias.
        x = torch.tanh(x)
        return x


In [None]:
import argparse
import time
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.tensorboard import SummaryWriter


# Name of the dataset folder under /kaggle/input.
YOUR_KAGGLE_DATASET_NAME = 'clothes_tryon_dataset'

opt = argparse.Namespace()

# --- Path Settings ---
# Built from the constant so that editing the name above takes effect.
opt.dataroot = f'/kaggle/input/{YOUR_KAGGLE_DATASET_NAME}'
opt.name = 'TOCG_Training_Run'  # sub-folder name for checkpoints and logs
opt.checkpoint_dir = '/kaggle/working/checkpoints'

opt.datamode = 'train'
opt.data_list = 'train_pairs.txt'
opt.workers = 2
opt.batch_size = 4
opt.shuffle = True

# --- Model & Training Settings ---
opt.fine_height = 1024
opt.fine_width = 768

# Generator input: pose (3) + densepose (3) + cloth (3) + cloth mask (1)
# + parse-agnostic map (1) channels, concatenated in the training loop.
opt.input_nc = 11
opt.output_nc = 1   # single-channel mask
opt.ngf = 64        # base generator width
opt.ndf = 64        # base discriminator width
opt.num_downs = 7   # U-Net down-sampling levels
opt.norm = 'batch'
opt.lr = 0.0004     # Adam learning rate for both networks
opt.beta1 = 0.5     # Adam beta1
opt.lambda_l1 = 10.0  # weight of the L1 term in the generator loss

opt.display_freq = 200   # log every N training steps
opt.save_epoch_freq = 1  # checkpoint every N epochs
opt.niter = 10           # epochs at the initial learning rate
opt.niter_decay = 20     # further epochs with linearly decaying learning rate
opt.tensorboard_dir = '/kaggle/working/tensorboard_logs'



class GANLoss(nn.Module):
    """Defines the GAN loss which uses either LSGAN or the regular GAN.

    Args:
        use_lsgan: if True use ``nn.MSELoss`` (least-squares GAN), otherwise
            ``nn.BCELoss``. BCELoss expects inputs in [0, 1], so the
            discriminator output must already be a probability in that case.
        target_real_label: target value for real samples.
        target_fake_label: target value for fake samples.

    The labels are registered as buffers so they follow the module across
    ``.to(device)`` calls.
    """
    def __init__(self, use_lsgan=True, target_real_label=1.0, target_fake_label=0.0):
        super(GANLoss, self).__init__()
        self.register_buffer('real_label', torch.tensor(target_real_label))
        self.register_buffer('fake_label', torch.tensor(target_fake_label))
        if use_lsgan:
            self.loss = nn.MSELoss()
        else:
            self.loss = nn.BCELoss()

    def get_target_tensor(self, input, target_is_real):
        """Return the real or fake label expanded to the shape of ``input``."""
        if target_is_real:
            target_tensor = self.real_label
        else:
            target_tensor = self.fake_label
        return target_tensor.expand_as(input)

    def forward(self, input, target_is_real):
        """Compute the loss of ``input`` against an all-real or all-fake target.

        Implemented as ``forward`` rather than by overriding ``__call__`` so
        that ``nn.Module.__call__`` still runs and registered hooks fire;
        ``criterion(pred, True)`` works as before.
        """
        target_tensor = self.get_target_tensor(input, target_is_real)
        return self.loss(input, target_tensor)

def weights_init(m):
    """Initialise one module in place; meant to be passed to ``Module.apply``.

    Modules whose class name contains ``Conv`` get weights drawn from
    N(0, 0.02). Modules whose class name contains ``BatchNorm`` get weights
    drawn from N(1, 0.02) and a zero bias. Everything else is left untouched.
    """
    layer_kind = type(m).__name__
    if 'Conv' in layer_kind:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif 'BatchNorm' in layer_kind:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0.0)
def _cpu_state_dict(net):
    """Return ``net.state_dict()`` with every tensor copied to the CPU.

    The module itself stays on its current device, so checkpoints can be
    written without moving the whole network to the CPU and back.
    """
    state = net.state_dict()
    for key in list(state.keys()):
        state[key] = state[key].cpu()
    return state


print("--- Initializing Stage 1 Training: Try-On Condition Generator (TOCG) ---")

# Create directories for checkpoints and logs
run_checkpoint_dir = os.path.join(opt.checkpoint_dir, opt.name)
os.makedirs(run_checkpoint_dir, exist_ok=True)
os.makedirs(opt.tensorboard_dir, exist_ok=True)

train_dataset = CustomDataset(opt)
train_loader = CPDataLoader(opt, train_dataset)
print(f"Dataset loaded with {len(train_dataset)} pairs.")

# Define the models from the network-architecture cell above
# Generator (TOCG)
netG = UnetGenerator(opt.input_nc, opt.output_nc, opt.num_downs, opt.ngf, norm_layer=nn.BatchNorm2d)
netG.apply(weights_init)

# Discriminator
# The discriminator input is the generator input + the generator output
netD = NLayerDiscriminator(opt.input_nc + opt.output_nc, opt.ndf, n_layers=3, norm_layer=nn.BatchNorm2d)
netD.apply(weights_init)

# Move models to the GPU when one is available, otherwise stay on the CPU
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
netG.to(device)
netD.to(device)
print(f"Models initialized and moved to {device}.")

# Define loss functions
criterionGAN = GANLoss().to(device)
criterionL1 = nn.L1Loss().to(device)

# Setup optimizers
optimizer_G = optim.Adam(netG.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))
optimizer_D = optim.Adam(netD.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))

# Learning rate schedulers: constant for opt.niter epochs, then linear decay
# over the following opt.niter_decay epochs.
scheduler_G = torch.optim.lr_scheduler.LambdaLR(optimizer_G, lr_lambda=lambda epoch: 1.0 - max(0, epoch + 1 - opt.niter) / float(opt.niter_decay + 1))
scheduler_D = torch.optim.lr_scheduler.LambdaLR(optimizer_D, lr_lambda=lambda epoch: 1.0 - max(0, epoch + 1 - opt.niter) / float(opt.niter_decay + 1))

# Open the writer only once everything else is set up, and close it in the
# `finally` below so an interrupted or failed run still flushes its logs.
writer = SummaryWriter(log_dir=os.path.join(opt.tensorboard_dir, opt.name))

print("--- Starting Training Loop ---")
total_steps = 0
try:
    for epoch in range(1, opt.niter + opt.niter_decay + 1):
        epoch_start_time = time.time()
        # The loop variable is named `batch`, not `data`, so it does not
        # rebind the `torch.utils.data` alias imported in the dataset cell.
        for batch in train_loader.data_loader:
            total_steps += 1

            # Unpack data
            cloth = batch['cloth'].to(device)
            cloth_mask = batch['cloth_mask'].to(device)
            pose_map = batch['pose'].to(device)
            densepose = batch['densepose'].to(device)
            parse_agnostic = batch['parse_agnostic'].to(device)

            # Concatenate inputs for the generator
            inputs = torch.cat([pose_map, densepose, cloth, cloth_mask, parse_agnostic], 1)

            # Ground truth is the cloth mask
            # NOTE(review): cloth_mask is also one of the input channels, and
            # netG ends in Tanh ([-1, 1]) while the mask holds 0/1 values.
            # Confirm that this is the intended training target.
            real_mask = cloth_mask

            # --- Forward pass ---
            fake_mask = netG(inputs)

            # --- Train Discriminator (D) ---
            optimizer_D.zero_grad()
            # Real
            real_AB = torch.cat((inputs, real_mask), 1)
            pred_real = netD(real_AB.detach())
            loss_D_real = criterionGAN(pred_real, True)
            # Fake (detached so this step does not back-propagate into netG)
            fake_AB = torch.cat((inputs, fake_mask), 1)
            pred_fake = netD(fake_AB.detach())
            loss_D_fake = criterionGAN(pred_fake, False)
            # Combine
            loss_D = (loss_D_real + loss_D_fake) * 0.5
            loss_D.backward()
            optimizer_D.step()

            # --- Train Generator (G) ---
            optimizer_G.zero_grad()
            # GAN loss
            fake_AB = torch.cat((inputs, fake_mask), 1)
            pred_fake = netD(fake_AB)
            loss_G_GAN = criterionGAN(pred_fake, True)
            # L1 loss
            loss_G_L1 = criterionL1(fake_mask, real_mask) * opt.lambda_l1
            # Combine
            loss_G = loss_G_GAN + loss_G_L1
            loss_G.backward()
            optimizer_G.step()

            # --- Logging and Visualization ---
            if total_steps % opt.display_freq == 0:
                print(f"Epoch: {epoch}, Step: {total_steps}, Loss D: {loss_D.item():.4f}, Loss G: {loss_G.item():.4f}, Loss G_L1: {loss_G_L1.item():.4f}")
                writer.add_scalar('Loss/Discriminator', loss_D.item(), total_steps)
                writer.add_scalar('Loss/Generator', loss_G.item(), total_steps)
                writer.add_scalar('Loss/Generator_L1', loss_G_L1.item(), total_steps)

        # --- End of Epoch ---
        scheduler_G.step()
        scheduler_D.step()
        print(f"End of Epoch {epoch} / {opt.niter + opt.niter_decay} \t Time Taken: {time.time() - epoch_start_time:.2f}s")

        # Save checkpoint (CPU copies of the weights; the networks stay on `device`)
        if epoch % opt.save_epoch_freq == 0:
            save_path_G = os.path.join(run_checkpoint_dir, f'epoch_{epoch}_net_G.pth')
            save_path_D = os.path.join(run_checkpoint_dir, f'epoch_{epoch}_net_D.pth')
            torch.save(_cpu_state_dict(netG), save_path_G)
            torch.save(_cpu_state_dict(netD), save_path_D)
            print(f"Saved checkpoint for epoch {epoch} at {save_path_G}")

    # Save final model. As before, this leaves netG on the CPU afterwards.
    final_save_path_G = os.path.join(run_checkpoint_dir, 'tocg_final.pth')
    torch.save(netG.cpu().state_dict(), final_save_path_G)
    print(f"--- Training Complete. Final model saved to {final_save_path_G} ---")
finally:
    writer.close()