In [1]:
!pip install torch torchvision numpy matplotlib tqdm opencv-python

Defaulting to user installation because normal site-packages is not writeable

[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m24.2[0m[39;49m -> [0m[32;49m25.0.1[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpython3 -m pip install --upgrade pip[0m


In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from torch.utils.data import DataLoader
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
import os

In [3]:
import torch
print("GPU Available:", torch.cuda.is_available())

GPU Available: True


In [5]:
import urllib.request
import tarfile

# Define dataset URL and paths
dataset_url = "https://efrosgans.eecs.berkeley.edu/pix2pix/datasets/edges2handbags.tar.gz"
dataset_path = "datasets/"

# Create directory if not exists
os.makedirs(dataset_path, exist_ok=True)

# Download dataset
dataset_file = dataset_path + "edges2handbags.tar.gz"
urllib.request.urlretrieve(dataset_url, dataset_file)

# Extract dataset
with tarfile.open(dataset_file, "r:gz") as tar:
    tar.extractall(dataset_path)

# Remove tar file
os.remove(dataset_file)

print("Dataset downloaded and extracted successfully!")

Dataset downloaded and extracted successfully!


In [5]:
import os
import torch
from PIL import Image
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader

In [6]:
dataset_paths = {
    "Handbags": r"/home/dell/AMLdata/edges2handbags",
    "Shoes": r"/home/dell/AMLdata/edges2shoes"
}

In [7]:
import os

for name, path in dataset_paths.items():
    if not os.path.exists(path):
        print(f"Error: {path} does not exist!")

In [24]:
base_path = "/home/dell/AMLdata"
folders = [os.path.join(base_path, f) for f in os.listdir(base_path) if os.path.isdir(os.path.join(base_path, f))]
print(folders)

['/home/dell/AMLdata/places', '/home/dell/AMLdata/edges2shoes', '/home/dell/AMLdata/edges2handbags']


In [21]:
folder_name = "edges2shoes"
folder_path = os.path.abspath(folder_name)
print(folder_path)


/home/dell/edges2shoes


In [8]:
# image transformations
transform = transforms.Compose([
    transforms.Resize((256, 256)),  # Resize for training
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5])  # Normalize between -1 and 1
])

class PairedImageDataset(Dataset):
    def __init__(self, root, transform=None):
        self.root = root
        self.transform = transform
        self.image_files = [os.path.join(root, f) for f in os.listdir(root) if f.endswith('.jpg')]

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, idx):
        img_path = self.image_files[idx]
        img = Image.open(img_path).convert("RGB")

        # Split into (edge, real photo)
        w, h = img.size
        edge = img.crop((0, 0, w//2, h))  # Left half = Edge
        real = img.crop((w//2, 0, w, h))  # Right half = Realistic photo

        if self.transform:
            edge = self.transform(edge)
            real = self.transform(real)

        return edge, real

In [9]:
import os

for name, path in dataset_paths.items():
    print(f"Checking path: {path}")
    if not os.path.exists(path):
        print(f"Error: {path} does not exist!")

Checking path: /home/dell/AMLdata/edges2handbags
Checking path: /home/dell/AMLdata/edges2shoes


In [10]:
class PairedImageDataset(Dataset):
    def __init__(self, root, transform=None):
        self.root = root
        self.transform = transform

        print(f"Checking path: {root}")  # Debugging

        if not os.path.exists(root):
            raise FileNotFoundError(f"Dataset directory not found: {root}")

        self.image_files = [os.path.join(root, f) for f in os.listdir(root) if f.endswith('.jpg')]

        if len(self.image_files) == 0:
            print(f"Warning: No images found in {root}!")

    def __len__(self):
        return len(self.image_files)

In [11]:
import os

dataset_path = "/home/dell/AMLdata/edges2shoes"  # Change to the dataset folder you want to check
print("Files in dataset folder:", os.listdir(dataset_path))

Files in dataset folder: ['val', 'train']


In [12]:
class PairedImageDataset(Dataset):
    def __init__(self, root, transform=None):
        self.root = root
        self.transform = transform

        # Recursively collect all image file paths inside train/ and val/
        self.image_files = []
        for subdir in ["train", "val"]:  # Look inside both train/ and val/
            folder_path = os.path.join(root, subdir)
            if os.path.exists(folder_path):  # Ensure the folder exists
                self.image_files.extend(
                    [os.path.join(folder_path, f) for f in os.listdir(folder_path) if f.lower().endswith(('.jpg', '.png', '.jpeg'))]
                )

        # Debug: Check if images are found
        print(f"Found {len(self.image_files)} images in {root}")

        if len(self.image_files) == 0:
            raise FileNotFoundError(f"No images found in {root}/train or {root}/val")

    def __len__(self):
        return len(self.image_files)

In [13]:
class PairedImageDataset(Dataset):
    def __init__(self, root, transform=None):
        self.root = root
        self.transform = transform

        # Recursively collect all image file paths from any subfolders
        self.image_files = []
        for dirpath, _, filenames in os.walk(root):  # Walk through all subfolders
            self.image_files.extend(
                [os.path.join(dirpath, f) for f in filenames if f.lower().endswith(('.jpg', '.png', '.jpeg'))]
            )

        # Debug: Print how many images were found
        print(f"Found {len(self.image_files)} images in {root}")

        if len(self.image_files) == 0:
            raise FileNotFoundError(f"No images found in {root} or its subdirectories!")

    def __len__(self):
        return len(self.image_files)

In [14]:
datasets_dict = {name: PairedImageDataset(path, transform=transform) for name, path in dataset_paths.items()}
batch_size = 16
dataloaders = {name: DataLoader(ds, batch_size=batch_size, shuffle=True) for name, ds in datasets_dict.items()}

for name, ds in datasets_dict.items():
    print(f"{name} Dataset Loaded: {len(ds)} paired images")

Found 138767 images in /home/dell/AMLdata/edges2handbags
Found 50025 images in /home/dell/AMLdata/edges2shoes
Handbags Dataset Loaded: 138767 paired images
Shoes Dataset Loaded: 50025 paired images


In [15]:
import torch.nn as nn

class UNetGenerator(nn.Module):
    def __init__(self, input_channels=3, output_channels=3, features=64):
        super(UNetGenerator, self).__init__()

        # Encoder
        self.encoder = nn.Sequential(
            nn.Conv2d(input_channels, features, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv2d(features, features * 2, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(features * 2),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv2d(features * 2, features * 4, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(features * 4),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv2d(features * 4, features * 8, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(features * 8),
            nn.LeakyReLU(0.2, inplace=True),
        )

        # Decoder
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(features * 8, features * 4, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(features * 4),
            nn.ReLU(inplace=True),

            nn.ConvTranspose2d(features * 4, features * 2, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(features * 2),
            nn.ReLU(inplace=True),

            nn.ConvTranspose2d(features * 2, features, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(features),
            nn.ReLU(inplace=True),

            nn.ConvTranspose2d(features, output_channels, kernel_size=4, stride=2, padding=1),
            nn.Tanh()
        )

    def forward(self, x):
        x = self.encoder(x)
        x = self.decoder(x)
        return x

# Initialize Generator
generator = UNetGenerator().cuda()

print("Generator Model Created!")

Generator Model Created!


In [16]:
class PatchDiscriminator(nn.Module):
    def __init__(self, input_channels=6, features=64):
        super(PatchDiscriminator, self).__init__()

        self.model = nn.Sequential(
            nn.Conv2d(input_channels, features, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv2d(features, features * 2, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(features * 2),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv2d(features * 2, features * 4, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(features * 4),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv2d(features * 4, 1, kernel_size=4, stride=1, padding=1),
            nn.Sigmoid()
        )

    def forward(self, edge, real):
        x = torch.cat([edge, real], dim=1)  # Concatenate (edge, real/generated)
        return self.model(x)

# Initialize Discriminator
discriminator = PatchDiscriminator().cuda()

print("Discriminator Model Created!")

Discriminator Model Created!


In [17]:
class PairedImageDataset(Dataset):
    def __init__(self, root, transform=None):
        self.root = root
        self.transform = transform
        self.image_files = sorted([os.path.join(root, f) for f in os.listdir(root) if f.endswith('.jpg')])

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, idx):
        img_path = self.image_files[idx]

        # Load image and ensure it's valid
        try:
            img = Image.open(img_path).convert("RGB")
        except Exception as e:
            print(f"Error loading image {img_path}: {e}")
            return None  # Skip bad images

        # Ensure the image is split correctly into (edge, real)
        w, h = img.size
        if w < 2:  # Ensure width is large enough to split
            print(f"Skipping image {img_path}, width too small!")
            return None

        edge = img.crop((0, 0, w//2, h))  # Left half
        real = img.crop((w//2, 0, w, h))  # Right half

        if self.transform:
            edge = self.transform(edge)
            real = self.transform(real)

        return edge, real  # Ensure this returns correctly

In [18]:
import os
import torch
from PIL import Image
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader
import glob

# Define image transformations
transform = transforms.Compose([
    transforms.Resize((256, 256)),  # Resize for training
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5])  # Normalize between -1 and 1
])

class PairedImageDataset(Dataset):
    def __init__(self, root, transform=None):
        self.root = root
        self.transform = transform

        # Recursively find all images inside subfolders
        self.image_files = sorted(glob.glob(os.path.join(root, "**/*.jpg"), recursive=True))

        if len(self.image_files) == 0:
            raise RuntimeError(f"No images found in {root}. Check folder structure!")

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, idx):
        img_path = self.image_files[idx]

        try:
            img = Image.open(img_path).convert("RGB")
        except Exception as e:
            print(f"Error loading image {img_path}: {e}")
            return self.__getitem__((idx + 1) % len(self))  # Skip to next image

        # Ensure images are split correctly (edges | real photo)
        w, h = img.size
        if w < 2:
            print(f"Skipping {img_path}, width too small!")
            return self.__getitem__((idx + 1) % len(self))

        edge = img.crop((0, 0, w//2, h))  # Left half = Edge
        real = img.crop((w//2, 0, w, h))  # Right half = Realistic photo

        if self.transform:
            edge = self.transform(edge)
            real = self.transform(real)

        return edge, real  # Ensure this returns valid tensors


In [19]:
dataset_paths = {
    "Handbags": r"/home/dell/AMLdata/edges2handbags",
    "Shoes": r"/home/dell/AMLdata/edges2shoes"
}

# Load datasets
datasets_dict = {name: PairedImageDataset(path, transform=transform) for name, path in dataset_paths.items()}
batch_size = 16
dataloaders = {name: DataLoader(ds, batch_size=batch_size, shuffle=True) for name, ds in datasets_dict.items()}

# Print dataset sizes
for name, ds in datasets_dict.items():
    print(f"{name} Dataset Loaded: {len(ds)} paired images")


Handbags Dataset Loaded: 138767 paired images
Shoes Dataset Loaded: 50025 paired images


In [20]:
for edge, real in dataloaders["Shoes"]:
    print("Edge batch shape:", edge.shape)
    print("Real batch shape:", real.shape)
    break

Edge batch shape: torch.Size([16, 3, 256, 256])
Real batch shape: torch.Size([16, 3, 256, 256])


In [21]:
criterion_gan = nn.BCELoss()
l1_loss = nn.L1Loss()

optimizer_G = optim.Adam(generator.parameters(), lr=0.0002, betas=(0.5, 0.999))
optimizer_D = optim.Adam(discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999))

print("Loss functions and optimizers initialized!")


Loss functions and optimizers initialized!


In [22]:
for name, ds in datasets_dict.items():
    print(f" {name} Dataset Loaded: {len(ds)} paired images")

 Handbags Dataset Loaded: 138767 paired images
 Shoes Dataset Loaded: 50025 paired images


In [24]:
# Define Save Function BEFORE Training Function
def save_trained_models(dataset_name, generator, discriminator):
    save_dir = os.getcwd()  # Get current working directory
   
    # Save full models
    torch.save(generator, os.path.join(save_dir, f"generator_{dataset_name}.pth"))
    torch.save(discriminator, os.path.join(save_dir, f"discriminator_{dataset_name}.pth"))

    # Save only model weights
    torch.save(generator.state_dict(), os.path.join(save_dir, f"generator_weights_{dataset_name}.pth"))
    torch.save(discriminator.state_dict(), os.path.join(save_dir, f"discriminator_weights_{dataset_name}.pth"))

    print(f" {dataset_name} models saved successfully at {save_dir}!")

In [25]:
import time

epochs = 100

def train_cgan(dataset_name, dataloader):
    print(f"\n Training cGAN on {dataset_name} Dataset...\n")

    for epoch in range(epochs):
        start_time = time.time()

        for i, (edge, real) in enumerate(dataloader):
            edge, real = edge.cuda(), real.cuda()

            # Train Discriminator
            optimizer_D.zero_grad()
            fake_images = generator(edge)

            real_output = discriminator(edge, real)
            fake_output = discriminator(edge, fake_images.detach())

            real_labels = torch.ones_like(real_output).cuda()
            fake_labels = torch.zeros_like(fake_output).cuda()

            real_loss = criterion_gan(real_output, real_labels)
            fake_loss = criterion_gan(fake_output, fake_labels)
            d_loss = (real_loss + fake_loss) / 2

            d_loss.backward()
            optimizer_D.step()

            # Train Generator
            optimizer_G.zero_grad()
            fake_output = discriminator(edge, fake_images)
            g_loss = criterion_gan(fake_output, real_labels) + 100 * l1_loss(fake_images, real)

            g_loss.backward()
            optimizer_G.step()

        print(f"Epoch {epoch+1}/{epochs} - D Loss: {d_loss.item():.4f} - G Loss: {g_loss.item():.4f}")
        
    save_trained_models(dataset_name, generator, discriminator )



# Train on all datasets
for dataset_name in ["Shoes", "Handbags"]:
    train_cgan(dataset_name, dataloaders[dataset_name])



 Training cGAN on Shoes Dataset...

Epoch 1/100 - D Loss: 0.1250 - G Loss: 20.9759
Epoch 2/100 - D Loss: 0.2635 - G Loss: 20.1871
Epoch 3/100 - D Loss: 0.2256 - G Loss: 16.1376
Epoch 4/100 - D Loss: 0.7274 - G Loss: 12.9691
Epoch 5/100 - D Loss: 0.2081 - G Loss: 19.6297
Epoch 6/100 - D Loss: 0.0206 - G Loss: 26.2806
Epoch 7/100 - D Loss: 0.0074 - G Loss: 21.7140
Epoch 8/100 - D Loss: 0.0358 - G Loss: 22.1803
Epoch 9/100 - D Loss: 0.0167 - G Loss: 24.1718
Epoch 10/100 - D Loss: 0.0057 - G Loss: 29.2349
Epoch 11/100 - D Loss: 0.2091 - G Loss: 16.4510
Epoch 12/100 - D Loss: 0.1238 - G Loss: 16.1530
Epoch 13/100 - D Loss: 0.0004 - G Loss: 26.3516
Epoch 14/100 - D Loss: 0.0187 - G Loss: 27.8696
Epoch 15/100 - D Loss: 0.0117 - G Loss: 31.7295
Epoch 16/100 - D Loss: 0.1466 - G Loss: 17.2259
Epoch 17/100 - D Loss: 0.2813 - G Loss: 22.7406
Epoch 18/100 - D Loss: 0.0154 - G Loss: 17.2596
Epoch 19/100 - D Loss: 0.5401 - G Loss: 17.7276
Epoch 20/100 - D Loss: 0.0435 - G Loss: 22.7132
Epoch 21/100