In [1]:
# Section 1: Imports and Device Setup
import os
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt

# Check for GPU availability
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

Using device: cuda


In [2]:

# Section 2: Model Definitions
class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self.encoder = nn.Sequential(
            nn.Conv2d(6, 64, kernel_size=4, stride=2, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),
            nn.ReLU(inplace=True)
        )
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(64, 3, kernel_size=4, stride=2, padding=1),
            nn.Tanh()
        )

    def forward(self, x):
        x = self.encoder(x)
        x = self.decoder(x)
        return x

class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.model = nn.Sequential(
            nn.Conv2d(9, 64, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(128, 1, kernel_size=4, stride=1, padding=1),
            nn.Sigmoid()
        )

    def forward(self, x):
        return self.model(x)

In [4]:


# Section 3: Dataset Preparation
class VirtualTryOnDataset(Dataset):
    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.images = os.listdir(os.path.join(root_dir, 'image'))

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        img_name = self.images[idx]
        
        try:
            # Load images from the respective directories
            image = Image.open(os.path.join(self.root_dir, 'image', img_name))
            cloth = Image.open(os.path.join(self.root_dir, 'cloth', img_name))
            cloth_mask = Image.open(os.path.join(self.root_dir, 'cloth-mask', img_name))
            agnostic = Image.open(os.path.join(self.root_dir, 'agnostic-v3.2', img_name))
            
            if self.transform:
                image = self.transform(image)
                cloth = self.transform(cloth)
                cloth_mask = self.transform(cloth_mask)
                agnostic = self.transform(agnostic)

            return image, cloth, cloth_mask, agnostic

        except FileNotFoundError:
            print(f"File {img_name} not found. Skipping.")
            return self.__getitem__((idx + 1) % len(self))  # Skip to the next item if file is missing

# Define transformations
transform = transforms.Compose([
    transforms.Resize((256, 192)),
    transforms.ToTensor()
])

train_dataset = VirtualTryOnDataset(root_dir='/kaggle/input/high-resolution-viton-zalando-dataset/train', transform=transform)
train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True)


FileNotFoundError: [WinError 3] The system cannot find the path specified: '/kaggle/input/high-resolution-viton-zalando-dataset/train\\image'

In [None]:
# Section 4: Model Initialization
generator = Generator().to(device)
discriminator = Discriminator().to(device)

criterion_gan = nn.BCELoss()
criterion_l1 = nn.L1Loss()

optimizer_g = optim.Adam(generator.parameters(), lr=0.0002, betas=(0.5, 0.999))
optimizer_d = optim.Adam(discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999))

# Section 5: Training Loop
num_epochs = 10
for epoch in range(num_epochs):
    for i, (image, cloth, cloth_mask, agnostic) in enumerate(train_loader):
        
        # Move data to GPU if available
        image = image.to(device)
        cloth = cloth.to(device)
        cloth_mask = cloth_mask.to(device)
        agnostic = agnostic.to(device)
        
        # Generate fake images using the generator
        warped_cloth = cloth  # Simplified step (you’d warp it in practice)
        input_to_generator = torch.cat((agnostic, warped_cloth), dim=1)
        
        # Train Discriminator
        optimizer_d.zero_grad()
        
        # Real images
        real_input = torch.cat((image, agnostic, warped_cloth), dim=1)
        real_output = discriminator(real_input)
        real_label = torch.ones_like(real_output).to(device)
        loss_d_real = criterion_gan(real_output, real_label)
        
        # Fake images
        fake_image = generator(input_to_generator)
        fake_input = torch.cat((fake_image, agnostic, warped_cloth), dim=1)
        fake_output = discriminator(fake_input.detach())
        fake_label = torch.zeros_like(fake_output).to(device)
        loss_d_fake = criterion_gan(fake_output, fake_label)
        
        # Total discriminator loss
        loss_d = (loss_d_real + loss_d_fake) / 2
        loss_d.backward()
        optimizer_d.step()
        
        # Train Generator
        optimizer_g.zero_grad()
        
        fake_output = discriminator(fake_input)
        loss_g_gan = criterion_gan(fake_output, real_label)
        loss_g_l1 = criterion_l1(fake_image, image) * 10.0  # L1 loss for pixel-level alignment
        
        # Total generator loss
        loss_g = loss_g_gan + loss_g_l1
        loss_g.backward()
        optimizer_g.step()
        
        # Print loss every few iterations
        if i % 100 == 0:
            print(f"Epoch [{epoch+1}/{num_epochs}], Step [{i+1}/{len(train_loader)}], "
                  f"D Loss: {loss_d.item():.4f}, G Loss: {loss_g.item():.4f}")

    # Save the model after each epoch
    torch.save(generator.state_dict(), f"/kaggle/working/generator_epoch_{epoch+1}.pth")
    torch.save(discriminator.state_dict(), f"/kaggle/working/discriminator_epoch_{epoch+1}.pth")
    print(f"Model saved for epoch {epoch+1}")

    # Display a sample output image
    with torch.no_grad():
        sample_fake_image = fake_image[0].cpu().permute(1, 2, 0) * 0.5 + 0.5  # Re-normalize for display
        sample_fake_image = np.clip(sample_fake_image.numpy(), 0, 1)
        
        plt.figure(figsize=(4, 4))
        plt.imshow(sample_fake_image)
        plt.axis('off')
        plt.title(f"Generated Image after Epoch {epoch+1}")
        plt.show()

# Section 6: Training Complete
print("Training complete!")
