# In [None]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here are several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "/kaggle/input/" directory (also reachable as "../input/" from /kaggle/working/)
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
# Recursively print the full path of every file found under /kaggle/input.
for root_dir, _subdirs, file_names in os.walk('/kaggle/input'):
    for file_name in file_names:
        full_path = os.path.join(root_dir, file_name)
        print(full_path)

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

# In [None]:
import json
import os
import warnings

import cv2
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
from torch.utils.data import Dataset, DataLoader, Subset
from torchvision import transforms, models
from tqdm import tqdm

DATA_DIR = "/kaggle/input/clothes-tryon/clothes_tryon_dataset"
GENERATOR_SAVE_PATH = "isolated_to_flat_generator1.pth"
DISCRIMINATOR_SAVE_PATH = "isolated_to_flat_discriminator1.pth"
PROGRESS_DIR = "isolated_to_flat_progress"

# Dataset split the model is trained on. NOTE: this is the "test" split (an
# earlier comment wrongly called it the train set). Change to "train" if that
# split is available and you want to keep "test" held out.
DATA_SPLIT = "test"
IMAGE_DIR = os.path.join(DATA_DIR, DATA_SPLIT, "image")           # person images
CLOTH_DIR = os.path.join(DATA_DIR, DATA_SPLIT, "cloth")           # target cloth images
PARSE_DIR = os.path.join(DATA_DIR, DATA_SPLIT, "image-parse-v3")  # parse (label) maps

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Training hyper-parameters (Adam is configured with betas=(BETA1, 0.999)).
LEARNING_RATE = 2e-4
BETA1 = 0.5
BATCH_SIZE = 4
EPOCHS = 200

# Every image is resized to IMAGE_HEIGHT x IMAGE_WIDTH.
IMAGE_HEIGHT = 256
IMAGE_WIDTH = 192

# Upper bound on the number of matched samples used (train + validation).
SUBSET_SIZE = 2000

print(f"Using device: {DEVICE}")
print("--- RUNNING IN TRAINING MODE (Isolated Cloth to Flat Cloth GAN) ---")
os.makedirs(PROGRESS_DIR, exist_ok=True)

# Resize, convert to a [0, 1] tensor, then normalise every channel to [-1, 1]
# so targets share the range of the generator's Tanh output; visualisation
# undoes this with `x * 0.5 + 0.5`.
transform = transforms.Compose([
    transforms.Resize((IMAGE_HEIGHT, IMAGE_WIDTH)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])

class IsolatedClothDataset(Dataset):
    """Pairs a garment cut out of a person image with its cloth image.

    Each sample is built on the fly:
      * ``source_cloth``: the person image resized to IMAGE_WIDTH x
        IMAGE_HEIGHT, with every pixel whose parse-map label is not in
        ``CLOTH_LABELS`` set to zero;
      * ``target_cloth``: the cloth image sharing the person image's base name.
    Both are passed through the module-level ``transform``.

    Args:
        image_dir: directory of person images (default: IMAGE_DIR).
        cloth_dir: directory of ``<base>.jpg`` cloth images (default: CLOTH_DIR).
        parse_dir: directory of ``<base>.png`` parse maps (default: PARSE_DIR).
    """

    # Parse-map label values treated as "cloth".
    # NOTE(review): presumably the garment classes of the dataset's labelling — confirm.
    CLOTH_LABELS = (5, 6, 7)

    def __init__(self, image_dir=None, cloth_dir=None, parse_dir=None):
        super().__init__()
        self.image_dir = IMAGE_DIR if image_dir is None else image_dir
        self.cloth_dir = CLOTH_DIR if cloth_dir is None else cloth_dir
        self.parse_dir = PARSE_DIR if parse_dir is None else parse_dir

        self.valid_pairs = []
        print("Verifying dataset integrity and creating correct pairs...")

        for person_fn in tqdm(sorted(os.listdir(self.image_dir))):
            base_name, _ = os.path.splitext(person_fn)
            # Cloth image and parse map share the person image's base name.
            cloth_fn = base_name + ".jpg"
            parse_fn = base_name + ".png"

            if all(os.path.exists(os.path.join(d, f)) for d, f in (
                (self.image_dir, person_fn),
                (self.cloth_dir, cloth_fn),
                (self.parse_dir, parse_fn),
            )):
                self.valid_pairs.append((person_fn, cloth_fn, parse_fn))

        print(f"Found {len(self.valid_pairs)} complete and correctly matched data pairs.")

    def __len__(self):
        return len(self.valid_pairs)

    def _load_sample(self, idx):
        """Build the ``{'source_cloth', 'target_cloth'}`` tensor dict for pair ``idx``."""
        person_fn, cloth_fn, parse_fn = self.valid_pairs[idx]
        size = (IMAGE_WIDTH, IMAGE_HEIGHT)  # PIL sizes are (width, height)

        with Image.open(os.path.join(self.image_dir, person_fn)) as img:
            person_arr = np.array(img.convert("RGB").resize(size))
        with Image.open(os.path.join(self.cloth_dir, cloth_fn)) as img:
            target_cloth = img.convert("RGB")
        with Image.open(os.path.join(self.parse_dir, parse_fn)) as img:
            # NEAREST keeps label values intact (no interpolation between classes).
            parse_arr = np.array(img.resize(size, Image.NEAREST))

        cloth_mask = np.isin(parse_arr, self.CLOTH_LABELS).astype(np.uint8)
        isolated_cloth = Image.fromarray(person_arr * cloth_mask[:, :, np.newaxis])
        return {
            'source_cloth': transform(isolated_cloth),
            'target_cloth': transform(target_cloth),
        }

    def __getitem__(self, idx):
        """Return sample ``idx``, skipping forward past samples that fail to load.

        If loading sample ``idx`` raises, a warning is emitted and the next
        index (wrapping around) is tried. Each index is tried at most once.
        Unlike an unbounded recursive retry, this cannot hit RecursionError.

        Raises:
            IndexError: if ``idx`` is out of range (also for an empty dataset).
            RuntimeError: if no sample in the dataset can be loaded.
        """
        n = len(self)
        if not -n <= idx < n:
            raise IndexError(f"index {idx} out of range for dataset of size {n}")

        for offset in range(n):
            cur = (idx + offset) % n
            try:
                return self._load_sample(cur)
            except Exception as err:  # deliberate best-effort: skip unreadable/corrupt files
                warnings.warn(f"Skipping sample {cur} ({self.valid_pairs[cur][0]}): {err!r}")
        raise RuntimeError("No sample in IsolatedClothDataset could be loaded")


class UnetGenerator(nn.Module):
    """U-Net generator with a torchvision ResNet-34 encoder.

    Encoder features (stem output and layer1-layer3) are concatenated with
    the transposed-conv-upsampled decoder features; a final Tanh maps the
    output to [-1, 1].

    The encoder downsamples by 32 in total and the decoder upsamples five
    times by 2. Input height and width should therefore be multiples of 32
    for the skip tensors to line up (e.g. 256 x 192).

    Args:
        input_nc: number of input channels. With 3, the ResNet stem conv is
            used as-is. Otherwise it is replaced by a freshly initialised conv
            with the same geometry (previously this argument was ignored).
        output_nc: number of output channels.
        ngf: channel width of the last two decoder stages.
        pretrained: if True (default), load torchvision's default pretrained
            ResNet-34 weights; if False, the encoder is randomly initialised.
    """

    def __init__(self, input_nc=3, output_nc=3, ngf=64, pretrained=True):
        super(UnetGenerator, self).__init__()
        weights = models.ResNet34_Weights.DEFAULT if pretrained else None
        resnet = models.resnet34(weights=weights)

        stem = resnet.conv1
        if input_nc == stem.in_channels:
            self.encoder_conv1 = stem
        else:
            self.encoder_conv1 = nn.Conv2d(
                input_nc, stem.out_channels,
                kernel_size=stem.kernel_size, stride=stem.stride,
                padding=stem.padding, bias=stem.bias is not None,
            )
        self.encoder_bn1 = resnet.bn1
        self.encoder_relu = resnet.relu
        self.encoder_maxpool = resnet.maxpool
        self.encoder2 = resnet.layer1   # 64 ch
        self.encoder3 = resnet.layer2   # 128 ch
        self.encoder4 = resnet.layer3   # 256 ch
        self.encoder5 = resnet.layer4   # 512 ch

        # Each decoder stage: upsample x2, concat the skip, then 3x3 conv + ReLU.
        self.upconv4 = nn.ConvTranspose2d(512, 256, 4, 2, 1)
        self.decoder4 = nn.Sequential(nn.Conv2d(512, 256, 3, 1, 1), nn.ReLU())
        self.upconv3 = nn.ConvTranspose2d(256, 128, 4, 2, 1)
        self.decoder3 = nn.Sequential(nn.Conv2d(256, 128, 3, 1, 1), nn.ReLU())
        self.upconv2 = nn.ConvTranspose2d(128, 64, 4, 2, 1)
        self.decoder2 = nn.Sequential(nn.Conv2d(128, 64, 3, 1, 1), nn.ReLU())
        self.upconv1 = nn.ConvTranspose2d(64, ngf, 4, 2, 1)
        self.decoder1 = nn.Sequential(nn.Conv2d(ngf + 64, ngf, 3, 1, 1), nn.ReLU())
        # Last upsample has no skip: it restores the stem's stride-2 reduction.
        self.upconv_final = nn.ConvTranspose2d(ngf, ngf // 2, 4, 2, 1)
        self.decoder_final = nn.Sequential(nn.Conv2d(ngf // 2, ngf // 2, 3, 1, 1), nn.ReLU())
        self.final = nn.Sequential(nn.Conv2d(ngf // 2, output_nc, 3, 1, 1), nn.Tanh())

    def forward(self, x):
        skip1 = self.encoder_relu(self.encoder_bn1(self.encoder_conv1(x)))
        e1 = self.encoder_maxpool(skip1)
        e2 = self.encoder2(e1)
        e3 = self.encoder3(e2)
        e4 = self.encoder4(e3)
        e5 = self.encoder5(e4)
        d4 = self.decoder4(torch.cat([self.upconv4(e5), e4], 1))
        d3 = self.decoder3(torch.cat([self.upconv3(d4), e3], 1))
        d2 = self.decoder2(torch.cat([self.upconv2(d3), e2], 1))
        d1 = self.decoder1(torch.cat([self.upconv1(d2), skip1], 1))
        d_final = self.decoder_final(self.upconv_final(d1))
        return self.final(d_final)

class PatchGANDiscriminator(nn.Module):
    """Convolutional PatchGAN discriminator producing a grid of real/fake logits.

    Args:
        input_nc: channels of the input; 6 by default because the caller
            concatenates a 3-channel source with a 3-channel real/fake image.
        ndf: channel width of the first conv layer.
        n_layers: number of conv + BatchNorm + LeakyReLU stages after the
            first conv; all but the last use stride 2.
    """

    def __init__(self, input_nc=6, ndf=64, n_layers=3):
        super().__init__()
        blocks = [nn.Conv2d(input_nc, ndf, 4, 2, 1), nn.LeakyReLU(0.2, True)]
        in_ch = ndf
        for depth in range(1, n_layers + 1):
            # Width doubles per stage, capped at 8 * ndf.
            out_ch = ndf * min(2 ** depth, 8)
            step = 2 if depth < n_layers else 1
            blocks.extend([
                nn.Conv2d(in_ch, out_ch, 4, step, 1, bias=False),
                nn.BatchNorm2d(out_ch),
                nn.LeakyReLU(0.2, True),
            ])
            in_ch = out_ch
        # One logit per spatial patch.
        blocks.append(nn.Conv2d(in_ch, 1, 4, 1, 1))
        self.model = nn.Sequential(*blocks)

    def forward(self, x):
        return self.model(x)


def validate_and_visualize(generator, val_loader, epoch):
    """Plot source / generated / ground-truth cloth for every validation sample.

    Runs ``generator`` without gradients on each batch of ``val_loader`` and
    draws one row of three panels per sample. The figure is saved under
    PROGRESS_DIR and then shown. A negative ``epoch`` labels the figure as
    the final result and saves it as ``final.png`` instead of ``epoch_000.png``.

    The generator's previous train/eval mode is restored afterwards, even if
    plotting fails.

    Args:
        generator: model mapping a ``source_cloth`` batch to images in [-1, 1].
        val_loader: DataLoader yielding dicts with 'source_cloth' and
            'target_cloth' tensors; any batch size is supported.
        epoch: zero-based epoch index, or a negative value for the final plot.
    """
    n_rows = len(val_loader.dataset)
    if n_rows == 0:
        print("validate_and_visualize: validation set is empty, nothing to plot.")
        return

    if epoch < 0:
        title, filename = "Final Validation Results", "final.png"
    else:
        title, filename = f"Epoch {epoch+1} Validation Progress", f"epoch_{epoch+1:03d}.png"

    def to_image(t):
        # Undo Normalize(mean=0.5, std=0.5): [-1, 1] -> [0, 1], and CHW -> HWC.
        return torch.clamp(t.cpu().permute(1, 2, 0) * 0.5 + 0.5, 0, 1)

    was_training = generator.training
    generator.eval()
    # squeeze=False keeps `axes` 2-D even when there is only one row.
    fig, axes = plt.subplots(n_rows, 3, figsize=(10, n_rows * 5), squeeze=False)
    fig.suptitle(title, fontsize=16)
    try:
        row = 0
        with torch.no_grad():
            for batch in val_loader:
                source_cloth = batch['source_cloth'].to(DEVICE)
                target_cloth = batch['target_cloth'].to(DEVICE)
                fake_cloth = generator(source_cloth)
                # Draw every sample of the batch, not just the first one.
                for src, fake, tgt in zip(source_cloth, fake_cloth, target_cloth):
                    if row >= n_rows:
                        break
                    panels = (
                        (src, "Source (Isolated Cloth)"),
                        (fake, "Generated (Unwarped)"),
                        (tgt, "Ground Truth"),
                    )
                    for col, (img, label) in enumerate(panels):
                        ax = axes[row, col]
                        ax.imshow(to_image(img))
                        ax.set_title(label)
                        ax.axis("off")
                    row += 1

        plt.tight_layout(rect=[0, 0, 1, 0.96])
        plt.savefig(os.path.join(PROGRESS_DIR, filename))
        plt.show()
    finally:
        plt.close(fig)
        generator.train(was_training)


def _make_loaders(full_dataset, val_size=5, seed=42):
    """Return ``(train_loader, val_loader)`` built from the first SUBSET_SIZE samples.

    Up to ``val_size`` samples are held out for validation, and at least one
    sample is always left for training. The split is seeded, so every run
    (including runs that only reload saved weights) uses the same validation
    images.

    Raises:
        RuntimeError: if fewer than two samples are available.
    """
    subset = Subset(full_dataset, list(range(min(SUBSET_SIZE, len(full_dataset)))))
    if len(subset) < 2:
        raise RuntimeError(f"Need at least 2 samples for a train/val split, found {len(subset)}.")

    val_size = min(val_size, len(subset) - 1)
    train_size = len(subset) - val_size
    split_rng = torch.Generator().manual_seed(seed)
    train_dataset, val_dataset = torch.utils.data.random_split(
        subset, [train_size, val_size], generator=split_rng
    )

    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=2)
    val_loader = DataLoader(val_dataset, batch_size=1, shuffle=False)
    return train_loader, val_loader


def _load_checkpoints(generator, discriminator):
    """Load saved state dicts onto DEVICE (so GPU-saved checkpoints load on CPU)."""
    print("Loading pre-trained models...")
    generator.load_state_dict(torch.load(GENERATOR_SAVE_PATH, map_location=DEVICE))
    # main() does not use the discriminator after loading, so a missing file is not fatal.
    if os.path.exists(DISCRIMINATOR_SAVE_PATH):
        discriminator.load_state_dict(torch.load(DISCRIMINATOR_SAVE_PATH, map_location=DEVICE))
    else:
        print(f"{DISCRIMINATOR_SAVE_PATH} not found; discriminator weights not loaded.")


def _train(generator, discriminator, train_loader, val_loader):
    """Train both networks for EPOCHS epochs, then save their state dicts.

    Discriminator: BCE-with-logits pushing D(source, target) to 1 and
    D(source, fake) to 0. Generator: BCE pushing D(source, fake) to 1, plus
    an L1 term (weight ``lambda_l1``) pulling the output towards the target.
    A validation figure is produced after every epoch.
    """
    optimizer_G = torch.optim.Adam(generator.parameters(), lr=LEARNING_RATE, betas=(BETA1, 0.999))
    optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=LEARNING_RATE, betas=(BETA1, 0.999))
    gan_loss_fn = nn.BCEWithLogitsLoss()
    l1_loss_fn = nn.L1Loss()
    lambda_l1 = 100  # weight of the L1 reconstruction term relative to the GAN term

    print("--- Starting GAN Training ---")
    for epoch in range(EPOCHS):
        loop = tqdm(train_loader, leave=True)
        for batch in loop:
            source_cloth = batch['source_cloth'].to(DEVICE)
            target_cloth = batch['target_cloth'].to(DEVICE)

            # --- Discriminator step ---
            optimizer_D.zero_grad()
            fake_cloth = generator(source_cloth)
            pred_real = discriminator(torch.cat((source_cloth, target_cloth), 1))
            loss_D_real = gan_loss_fn(pred_real, torch.ones_like(pred_real))
            # detach(): the discriminator update must not back-propagate into G.
            pred_fake = discriminator(torch.cat((source_cloth, fake_cloth.detach()), 1))
            loss_D_fake = gan_loss_fn(pred_fake, torch.zeros_like(pred_fake))
            loss_D = (loss_D_real + loss_D_fake) * 0.5
            loss_D.backward()
            optimizer_D.step()

            # --- Generator step ---
            # Gradients this leaves on D's parameters are cleared by
            # optimizer_D.zero_grad() at the start of the next iteration.
            optimizer_G.zero_grad()
            pred_fake_for_G = discriminator(torch.cat((source_cloth, fake_cloth), 1))
            loss_G_gan = gan_loss_fn(pred_fake_for_G, torch.ones_like(pred_fake_for_G))
            loss_G_l1 = l1_loss_fn(fake_cloth, target_cloth) * lambda_l1
            loss_G = loss_G_gan + loss_G_l1
            loss_G.backward()
            optimizer_G.step()

            loop.set_postfix(loss_D=loss_D.item(), loss_G=loss_G.item())

        print(f"Epoch {epoch+1}/{EPOCHS} complete.")
        validate_and_visualize(generator, val_loader, epoch)

    torch.save(generator.state_dict(), GENERATOR_SAVE_PATH)
    torch.save(discriminator.state_dict(), DISCRIMINATOR_SAVE_PATH)
    print("Models saved.")


def main():
    """Train the isolated-to-flat cloth GAN, or load saved weights if present.

    Training is skipped when GENERATOR_SAVE_PATH exists. In either case a
    final validation figure is shown at the end.
    """
    generator = UnetGenerator().to(DEVICE)
    discriminator = PatchGANDiscriminator().to(DEVICE)

    full_dataset = IsolatedClothDataset()
    train_loader, val_loader = _make_loaders(full_dataset)

    if os.path.exists(GENERATOR_SAVE_PATH):
        _load_checkpoints(generator, discriminator)
    else:
        _train(generator, discriminator, train_loader, val_loader)

    print("--- Displaying Final Results ---")
    validate_and_visualize(generator, val_loader, epoch=-1)


if __name__ == "__main__":
    main()
