In [None]:
pip install torch torchvision pillow

In [None]:
import torch
import gc

def clear_gpu_memory():
    """Free unreachable Python objects, then release cached CUDA memory.

    The garbage collector runs first: tensors that are only kept alive by
    reference cycles still occupy their CUDA blocks until they are collected,
    and ``torch.cuda.empty_cache()`` can only release blocks that are no
    longer occupied.

    Returns:
        None.
    """
    gc.collect()
    torch.cuda.empty_cache()


In [None]:
from torch.utils.data import DataLoader, Dataset
from torchvision import datasets, transforms
from PIL import Image
import os

# Colab only: mount Google Drive first if the dataset lives there. The path
# below points at a Kaggle input directory, so the mount is left disabled.
# from google.colab import drive
# drive.mount('/content/drive')

# --- Root directory of the dataset (currently a Kaggle input path) ---
data_root = '/kaggle/input/fashion-gan/dataset_new_1500'

# --- Define custom dataset for pairing ---
class PairedDataset(Dataset):
    """Pairs two indexable datasets item-by-item by position.

    Each wrapped dataset must return a 2-tuple per index (the second element,
    e.g. a class label, is discarded). Item ``i`` of this dataset is
    ``(sketch_dataset[i][0], real_dataset[i][0])``.

    NOTE(review): pairing is purely positional, so it presumably relies on both
    datasets listing corresponding images in the same order — confirm against
    the dataset folder layout.

    Args:
        sketch_dataset: dataset yielding ``(sketch_image, _)`` tuples.
        real_dataset: dataset yielding ``(real_image, _)`` tuples.
    """

    def __init__(self, sketch_dataset, real_dataset):
        self.sketch_dataset = sketch_dataset
        self.real_dataset = real_dataset

    def __len__(self):
        # Only indices present in BOTH datasets can form a pair.
        return min(len(self.sketch_dataset), len(self.real_dataset))

    def __getitem__(self, idx):
        """Return ``(sketch_img, real_img)`` for position ``idx``.

        Negative indices count from the end of the *paired* range, so both
        halves always come from the same position even when the wrapped
        datasets differ in length.

        Raises:
            IndexError: if ``idx`` is outside ``[-len(self), len(self))``.
        """
        n_pairs = len(self)
        if idx < 0:
            idx += n_pairs
        if not 0 <= idx < n_pairs:
            raise IndexError(f"PairedDataset index out of range (size {n_pairs})")

        sketch_img, _ = self.sketch_dataset[idx]
        real_img, _ = self.real_dataset[idx]
        return sketch_img, real_img

# --- Preprocessing applied to every image of every split ---
# Resize + centre crop give 256x256 inputs; ToTensor followed by
# Normalize(mean=0.5, std=0.5) maps each RGB channel to [-1, 1].
transform = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.CenterCrop(256),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5] * 3, std=[0.5] * 3),
    ]
)


def _load_image_folder(relative_dir):
    """Open ``relative_dir`` (relative to ``data_root``) as an ImageFolder using ``transform``."""
    return datasets.ImageFolder(root=os.path.join(data_root, relative_dir), transform=transform)


# --- Training split: shuffled every epoch ---
train_sketch = _load_image_folder('train/train_sketch_1500/')
train_real = _load_image_folder('train/train_real_1500/')
train_dataset = PairedDataset(train_sketch, train_real)
train_loader = DataLoader(train_dataset, batch_size=7, shuffle=True)

# --- Validation split: fixed order so exported file names are stable ---
val_sketch = _load_image_folder('valid/valid_sketch_1500/')
val_real = _load_image_folder('valid/valid_real_1500/')
val_dataset = PairedDataset(val_sketch, val_real)
val_loader = DataLoader(val_dataset, batch_size=7, shuffle=False)


In [None]:
import torch
import torch.nn as nn

# ========================
# Residual Block
# ========================
class ResidualBlock(nn.Module):
    """Residual unit: two reflection-padded 3x3 convolutions plus a skip connection.

    Each convolution is followed by InstanceNorm; a ReLU sits between the two.
    The reflection padding of 1 offsets the unpadded 3x3 kernels, so the output
    has the same channel count and spatial size as the input.

    Args:
        channels: number of input (and output) feature channels.
    """

    def __init__(self, channels):
        super().__init__()

        def conv3x3():
            # Unpadded 3x3 conv; the ReflectionPad2d(1) in front of it keeps the size.
            return nn.Conv2d(channels, channels, 3)

        layers = [
            nn.ReflectionPad2d(1),
            conv3x3(),
            nn.InstanceNorm2d(channels),
            nn.ReLU(inplace=True),
            nn.ReflectionPad2d(1),
            conv3x3(),
            nn.InstanceNorm2d(channels),
        ]
        self.block = nn.Sequential(*layers)

    def forward(self, x):
        """Return ``x`` plus the output of the convolutional branch applied to ``x``."""
        residual = self.block(x)
        return x + residual

# ========================
# Generator: ResNet-based
# ========================
class Generator(nn.Module):
    """ResNet-based image-to-image generator.

    Layout of ``self.model`` (an ``nn.Sequential``):
      1. stem: reflection pad 3 + 7x7 conv to 64 channels, InstanceNorm, ReLU
      2. two stride-2 3x3 convs (64 -> 128 -> 256), each with InstanceNorm + ReLU
      3. ``n_residual_blocks`` ``ResidualBlock(256)`` modules
      4. two stride-2 transposed 3x3 convs (256 -> 128 -> 64), each with
         InstanceNorm + ReLU
      5. head: reflection pad 3 + 7x7 conv to ``output_channels``, then Tanh,
         so outputs lie in [-1, 1]

    Args:
        input_channels: channels of the input image.
        output_channels: channels of the generated image.
        n_residual_blocks: number of residual blocks in the bottleneck.
    """

    def __init__(self, input_channels=3, output_channels=3, n_residual_blocks=9):
        super().__init__()

        layers = []

        # Stem: 7x7 convolution at full resolution.
        layers.extend([
            nn.ReflectionPad2d(3),
            nn.Conv2d(input_channels, 64, 7),
            nn.InstanceNorm2d(64),
            nn.ReLU(inplace=True),
        ])

        # Downsampling: each stage halves the resolution and doubles the channels.
        for in_ch, out_ch in ((64, 128), (128, 256)):
            layers.extend([
                nn.Conv2d(in_ch, out_ch, 3, stride=2, padding=1),
                nn.InstanceNorm2d(out_ch),
                nn.ReLU(inplace=True),
            ])

        # Bottleneck of residual blocks at 256 channels.
        layers.extend(ResidualBlock(256) for _ in range(n_residual_blocks))

        # Upsampling: mirror of the downsampling stages.
        for in_ch, out_ch in ((256, 128), (128, 64)):
            layers.extend([
                nn.ConvTranspose2d(in_ch, out_ch, 3, stride=2, padding=1, output_padding=1),
                nn.InstanceNorm2d(out_ch),
                nn.ReLU(inplace=True),
            ])

        # Head: project back to image channels and squash with Tanh.
        layers.extend([
            nn.ReflectionPad2d(3),
            nn.Conv2d(64, output_channels, 7),
            nn.Tanh(),
        ])

        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Run ``x`` through the full encoder / residual / decoder stack."""
        return self.model(x)

# ========================
# Discriminator: PatchGAN
# ========================
class Discriminator(nn.Module):
    """PatchGAN discriminator producing a 1-channel map of per-patch scores.

    Four stride-2 4x4 convolution stages (64, 128, 256, 512 channels), each
    followed by LeakyReLU(0.2); every stage except the first also applies
    InstanceNorm. A final 4x4 convolution reduces to a single channel. The
    output is a raw score map (no sigmoid is applied).

    Args:
        input_channels: channels of the image being judged.
    """

    def __init__(self, input_channels=3):
        super().__init__()

        # (output channels, apply InstanceNorm?) for each downsampling stage.
        stage_specs = ((64, False), (128, True), (256, True), (512, True))

        layers = []
        in_ch = input_channels
        for out_ch, use_norm in stage_specs:
            layers.append(nn.Conv2d(in_ch, out_ch, 4, stride=2, padding=1))
            if use_norm:
                layers.append(nn.InstanceNorm2d(out_ch))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            in_ch = out_ch

        # PatchGAN output: one score per receptive-field patch.
        layers.append(nn.Conv2d(in_ch, 1, 4, padding=1))

        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Return the patch score map for image batch ``x``."""
        return self.model(x)


In [None]:
import os

# Ensure the checkpoint directory exists; exist_ok=True makes re-running this cell a no-op.
os.makedirs("checkpoints", exist_ok=True)

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import itertools
import torchvision.utils as vutils
import os
import random

# ============ Setup ============
# Train on the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Seed PyTorch and Python's `random` before the networks are built so that
# weight initialisation is repeatable.
torch.manual_seed(42)
random.seed(42)

# Two generators (sketch -> real, real -> sketch) and one discriminator per domain.
# Construction order is kept fixed because it determines the initial weights.
G_S2R, G_R2S = Generator().to(device), Generator().to(device)
D_R, D_S = Discriminator().to(device), Discriminator().to(device)

# Losses: MSE for the adversarial terms, L1 for cycle-consistency and identity.
criterion_GAN = nn.MSELoss()
criterion_cycle = nn.L1Loss()
criterion_identity = nn.L1Loss()

# Adam hyper-parameters shared by all three optimizers.
lr = 0.0001
beta1, beta2 = 0.5, 0.999
adam_options = {"lr": lr, "betas": (beta1, beta2)}

# Both generators are updated together by a single optimizer.
optimizer_G = optim.Adam(list(G_S2R.parameters()) + list(G_R2S.parameters()), **adam_options)
optimizer_D_R = optim.Adam(D_R.parameters(), **adam_options)
optimizer_D_S = optim.Adam(D_S.parameters(), **adam_options)

# ============ LR Scheduler ============ 
warmup_epochs = 5
total_epochs = 50


def lr_lambda(current_epoch):
    """Return the learning-rate multiplier for ``current_epoch``.

    During the first ``warmup_epochs`` epochs the multiplier rises linearly
    from ``1 / warmup_epochs`` to 1. Afterwards it decays by a factor of 0.95
    per epoch and never drops below 0.1.

    Args:
        current_epoch: integer epoch counter; the ``+ 1`` in the warm-up branch
            treats it as starting at 0.

    Returns:
        float multiplier applied to the optimizer's base learning rate.
    """
    epochs_past_warmup = current_epoch - warmup_epochs
    if epochs_past_warmup >= 0:
        # Exponential decay, floored at 10% of the base learning rate.
        return max(0.1, 0.95 ** epochs_past_warmup)
    # Linear warm-up.
    return float(current_epoch + 1) / warmup_epochs

# One LambdaLR per optimizer; all three follow the same `lr_lambda` multiplier.
scheduler_G, scheduler_D_R, scheduler_D_S = [
    optim.lr_scheduler.LambdaLR(opt, lr_lambda)
    for opt in (optimizer_G, optimizer_D_R, optimizer_D_S)
]

# Target values the discriminators are trained towards for real / generated images.
real_label_val, fake_label_val = 1.0, 0.0

# Weights of the cycle-consistency and identity terms in the generator loss.
lambda_cycle, lambda_id = 10.0, 5.0

# ============ Fix Input ============
def fix_input_shape(img_tensor, name):
    """Validate an image tensor and make sure it carries a batch dimension.

    A 3-D tensor gets a leading batch axis of size 1; a 4-D tensor is returned
    unchanged.

    Args:
        img_tensor: tensor with 3 or 4 dimensions.
        name: label used in error messages.

    Returns:
        The 4-D tensor.

    Raises:
        ValueError: if the tensor is not 4-D after the optional unsqueeze, or
            if either of its last two dimensions is smaller than 7.
    """
    batched = img_tensor.unsqueeze(0) if img_tensor.dim() == 3 else img_tensor

    if batched.dim() != 4:
        raise ValueError(f"[ERROR] {name} tensor shape invalid: {batched.shape}")

    height, width = batched.shape[2], batched.shape[3]
    if min(height, width) < 7:
        raise ValueError(f"[ERROR] {name} too small for ReflectionPad2d: {batched.shape}")

    return batched

# ============ Save Samples ============
def save_sample_images(sketch, fake_real, real_image, epoch, output_dir="samples"):
    """Save a sketch | generated | real strip built from the first item of each batch.

    The image is written to ``{output_dir}/epoch_{epoch+1}.png``; because 1 is
    added to ``epoch`` for the file name, pass the 0-based epoch index to get
    1-based file names.

    Args:
        sketch: batch of input sketches.
        fake_real: batch of generated images.
        real_image: batch of ground-truth images.
        epoch: epoch number used (plus one) in the file name.
        output_dir: directory to write to; created if missing.
    """
    os.makedirs(output_dir, exist_ok=True)

    # Slicing with 0:1 keeps the batch axis so the three items can be concatenated.
    first_of_each = [batch[0:1] for batch in (sketch, fake_real, real_image)]
    strip = torch.cat(first_of_each, dim=0)

    target_path = f"{output_dir}/epoch_{epoch+1}.png"
    vutils.save_image(strip, target_path, nrow=3, normalize=True)


# ============ Early Stopping ============
# Training stops once the average generator loss has failed to improve by more
# than `min_delta` for `patience` consecutive epochs.
best_loss = float('inf')
patience = 15
min_delta = 0.001
patience_counter = 0


def set_requires_grad(networks, requires_grad):
    """Switch gradient tracking on or off for every parameter of ``networks``.

    Args:
        networks: iterable of ``nn.Module`` instances.
        requires_grad: bool assigned to each parameter's ``requires_grad``.
    """
    for network in networks:
        for param in network.parameters():
            param.requires_grad_(requires_grad)


# ============ Training Loop ============
num_epochs = 50

# Other cells call .eval() on the generators; switch everything back so that
# re-running this cell always trains in training mode.
for network in (G_S2R, G_R2S, D_R, D_S):
    network.train()

for epoch in range(num_epochs):
    epoch_g_loss = 0
    total_batches = len(train_loader)

    for i, (real_sketch, real_image) in enumerate(train_loader):
        real_sketch = fix_input_shape(real_sketch.to(device), "Sketch")
        real_image = fix_input_shape(real_image.to(device), "Real")

        # ---------------- Train Generators ----------------
        # The discriminators only score the fakes here. Freezing them keeps
        # backward() from computing discriminator parameter gradients that
        # would be thrown away by the zero_grad() calls below.
        set_requires_grad((D_R, D_S), False)
        optimizer_G.zero_grad()

        # Adversarial terms: each generator wants its fakes scored as "real".
        fake_real = G_S2R(real_sketch)
        pred_fake = D_R(fake_real)
        loss_GAN_S2R = criterion_GAN(pred_fake, torch.full_like(pred_fake, real_label_val))

        fake_sketch = G_R2S(real_image)
        pred_fake2 = D_S(fake_sketch)
        loss_GAN_R2S = criterion_GAN(pred_fake2, torch.full_like(pred_fake2, real_label_val))

        # Cycle consistency: translating there and back should reproduce the input.
        rec_sketch = G_R2S(fake_real)
        rec_real = G_S2R(fake_sketch)
        loss_cycle_sketch = criterion_cycle(rec_sketch, real_sketch)
        loss_cycle_real = criterion_cycle(rec_real, real_image)

        # Identity loss: a generator fed an image of its target domain should leave it unchanged.
        idt_real = G_S2R(real_image)
        idt_sketch = G_R2S(real_sketch)
        loss_idt_real = criterion_identity(idt_real, real_image)
        loss_idt_sketch = criterion_identity(idt_sketch, real_sketch)

        # Total generator loss
        loss_G = (
            loss_GAN_S2R + loss_GAN_R2S +
            lambda_cycle * (loss_cycle_sketch + loss_cycle_real) +
            lambda_id * (loss_idt_real + loss_idt_sketch)
        )
        loss_G.backward()
        optimizer_G.step()

        # ---------------- Train Discriminators ----------------
        set_requires_grad((D_R, D_S), True)

        # Discriminator R: real photos vs. generated photos. The fakes are
        # detached so no gradient reaches the generators. Targets are shaped
        # from each prediction itself rather than from a stale tensor.
        optimizer_D_R.zero_grad()
        pred_real_R = D_R(real_image)
        pred_fake_R = D_R(fake_real.detach())
        loss_real = criterion_GAN(pred_real_R, torch.full_like(pred_real_R, real_label_val))
        loss_fake = criterion_GAN(pred_fake_R, torch.full_like(pred_fake_R, fake_label_val))
        loss_D_R = (loss_real + loss_fake) * 0.5
        loss_D_R.backward()
        optimizer_D_R.step()

        # Discriminator S: real sketches vs. generated sketches.
        optimizer_D_S.zero_grad()
        pred_real_S = D_S(real_sketch)
        pred_fake_S = D_S(fake_sketch.detach())
        loss_real = criterion_GAN(pred_real_S, torch.full_like(pred_real_S, real_label_val))
        loss_fake = criterion_GAN(pred_fake_S, torch.full_like(pred_fake_S, fake_label_val))
        loss_D_S = (loss_real + loss_fake) * 0.5
        loss_D_S.backward()
        optimizer_D_S.step()

        epoch_g_loss += loss_G.item()

        # Log: First 3 and Last 3 Batches
        if i < 3 or i >= total_batches - 3:
            print(f"[Epoch {epoch+1}/{num_epochs}] [Batch {i+1}/{total_batches}] "
                  f"[D_R: {loss_D_R.item():.4f}] [D_S: {loss_D_S.item():.4f}] [G: {loss_G.item():.4f}]")

    avg_g_loss = epoch_g_loss / total_batches
    print(f"Epoch {epoch+1} Average G Loss: {avg_g_loss:.4f}")

    # Save sample outputs every 5th epoch (from the last batch of the epoch).
    # save_sample_images() adds 1 to the epoch for the file name itself, so the
    # 0-based index is passed; passing epoch + 1 would label epoch 5 as epoch_6.png.
    if (epoch + 1) % 5 == 0:
        save_sample_images(real_sketch, fake_real, real_image, epoch)

    scheduler_G.step()
    scheduler_D_R.step()
    scheduler_D_S.step()

    # Save best model (lowest average generator loss so far); otherwise count
    # towards early stopping.
    if best_loss - avg_g_loss > min_delta:
        best_loss = avg_g_loss
        patience_counter = 0
        os.makedirs("checkpoints", exist_ok=True)
        torch.save(G_S2R.state_dict(), "checkpoints/G_S2R_best.pth")
        torch.save(G_R2S.state_dict(), "checkpoints/G_R2S_best.pth")
    else:
        patience_counter += 1
        if patience_counter >= patience:
            print(f"Early stopping at epoch {epoch+1} due to no improvement in generator loss.")
            break

print("Training done ✅")


In [None]:
import os
import torchvision.utils as vutils

G_S2R.eval()
G_R2S.eval()

output_dir = "output_images/val_sketch_to_real"
os.makedirs(output_dir, exist_ok=True)

with torch.no_grad():
    # The loop variable is deliberately not called `val_sketch`: that name is
    # the ImageFolder dataset from the data-loading cell and must not be
    # overwritten with a tensor batch.
    for idx, (sketch_batch, _) in enumerate(val_loader):
        sketch_batch = sketch_batch.to(device)

        # Generate fake real images from sketch
        fake_real = G_S2R(sketch_batch)

        # Save each generated image as fake_real_<batch>_<index-in-batch>.png
        for i in range(fake_real.size(0)):
            vutils.save_image(fake_real[i],
                              os.path.join(output_dir, f"fake_real_{idx}_{i}.png"),
                              normalize=True)


In [None]:
import matplotlib.pyplot as plt
import torchvision.transforms as T

G_S2R.eval()
to_pil = T.ToPILImage()

with torch.no_grad():
    # Loop names avoid `val_sketch` / `val_real`, which are the ImageFolder
    # datasets from the data-loading cell.
    for idx, (sketch_batch, real_batch) in enumerate(val_loader):
        sketch_batch = sketch_batch.to(device)
        fake_batch = G_S2R(sketch_batch)

        # ToPILImage expects float tensors in [0, 1], but the dataset transform
        # (Normalize with mean=0.5, std=0.5) and the generator's Tanh both
        # produce values in [-1, 1]. Undo the normalisation before converting,
        # otherwise negative values are corrupted in the uint8 conversion.
        panels = [
            (batch * 0.5 + 0.5).clamp(0, 1).cpu()
            for batch in (sketch_batch, fake_batch, real_batch)
        ]
        titles = ("Input Sketch", "Generated Image", "Ground Truth")

        # Display Sketch, Generated, and Ground Truth side-by-side
        for i in range(fake_batch.size(0)):
            fig, axs = plt.subplots(1, 3, figsize=(9, 3))
            for ax, title, panel in zip(axs, titles, panels):
                ax.imshow(to_pil(panel[i]))
                ax.set_title(title)
                ax.axis("off")

            plt.tight_layout()
            plt.show()
            # Release the figure so a full pass over the loader does not pile up open figures.
            plt.close(fig)

        # Show only first few batches (adjust as needed)
        # if idx >= 2:
        #     break


In [None]:
import os
import torch
import torchvision.utils as vutils
import torchvision.transforms as transforms
from PIL import Image
from zipfile import ZipFile

# Create directory to save output images
output_dir = "generated_samples"
os.makedirs(output_dir, exist_ok=True)

G_S2R.eval()  # set generator to eval mode

# Save first N batches (each has batch_size images)
num_batches_to_save = 5

to_pil = transforms.ToPILImage()

with torch.no_grad():
    for idx, (sketch, real) in enumerate(val_loader):
        if idx >= num_batches_to_save:
            break

        fake_real = G_S2R(sketch.to(device))

        # ToPILImage expects float tensors in [0, 1]; the loader output and the
        # generator's Tanh output are in [-1, 1]. Map back to [0, 1] first so
        # negative values are not corrupted in the uint8 conversion.
        sketch_vis, fake_vis, real_vis = (
            (batch * 0.5 + 0.5).clamp(0, 1).cpu()
            for batch in (sketch, fake_real, real)
        )

        # Loop over batch
        for i in range(sketch_vis.size(0)):
            # Combine: Sketch | Generated | Real, each resized to 128x128 for a
            # compact view (e.g. in slides).
            combined = Image.new("RGB", (128 * 3, 128))
            for slot, panel in enumerate((sketch_vis, fake_vis, real_vis)):
                tile = to_pil(panel[i]).resize((128, 128))
                combined.paste(tile, (128 * slot, 0))

            combined.save(f"{output_dir}/sample_{idx}_{i}.png")

# ✅ Zip the output folder for download on Kaggle
# Entries are stored as "<folder name>/<file name>" inside the archive.
zip_path = "generated_samples.zip"
with ZipFile(zip_path, "w") as zipf:
    for root, _, files in os.walk(output_dir):
        for file in files:
            zipf.write(os.path.join(root, file),
                       arcname=os.path.join(os.path.basename(root), file))

print("✅ Images saved & zipped! You can now download 'generated_samples.zip'")


In [None]:
import os
import torch
import torchvision.utils as vutils
import matplotlib.pyplot as plt
from torchvision import transforms
from PIL import Image
import zipfile

G_S2R.eval()
G_R2S.eval()

output_dir = "generated_samples"
os.makedirs(output_dir, exist_ok=True)

# Named `to_pil` rather than `transform`: `transform` is the preprocessing
# pipeline defined in the data-loading cell and must not be overwritten.
to_pil = transforms.ToPILImage()

# Loop through entire val_loader
with torch.no_grad():
    # Loop names avoid `val_sketch` / `val_real`, which are the ImageFolder
    # datasets from the data-loading cell.
    for idx, (sketch_batch, real_batch) in enumerate(val_loader):
        fake_batch = G_S2R(sketch_batch.to(device))

        # ToPILImage expects float tensors in [0, 1]; the loader output and the
        # generator's Tanh output are in [-1, 1]. Map back to [0, 1] first so
        # negative values are not corrupted in the uint8 conversion.
        panels = [
            (batch * 0.5 + 0.5).clamp(0, 1).cpu()
            for batch in (sketch_batch, fake_batch, real_batch)
        ]

        for i in range(fake_batch.size(0)):
            sketch_img, generated_img, real_img = (to_pil(panel[i]) for panel in panels)

            # Combine into one image: sketch | generated | real, left to right
            combined = Image.new("RGB", (sketch_img.width * 3, sketch_img.height))
            combined.paste(sketch_img, (0, 0))
            combined.paste(generated_img, (sketch_img.width, 0))
            combined.paste(real_img, (sketch_img.width * 2, 0))

            # Save combined image
            combined.save(os.path.join(output_dir, f"combined_{idx}_{i}.png"))

print("✅ All images saved!")

# Zip the folder; entries are stored with paths relative to `output_dir`.
zip_name = "generated_samples.zip"
with zipfile.ZipFile(zip_name, 'w') as zipf:
    for root, _, files in os.walk(output_dir):
        for file in files:
            zipf.write(os.path.join(root, file), os.path.relpath(os.path.join(root, file), output_dir))

print(f"📦 Zipped all into {zip_name}")
