In [None]:
from google.colab import files
files.upload()
! mkdir ~/.kaggle
! cp kaggle.json ~/.kaggle/
! chmod 600 ~/.kaggle/kaggle.json

! kaggle datasets download -d arbazkhan971/cuhk-face-sketch-database-cufs
! unzip cuhk-face-sketch-database-cufs.zip

In [None]:
import glob
import os
import re

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from PIL import Image
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms

In [None]:
# ---------------------------------------------------------------------------
# Hyperparameters and paths used throughout the notebook
# ---------------------------------------------------------------------------

# Image geometry
IMG_WIDTH, IMG_HEIGHT = 128, 128
IMG_CHANNELS = 3

# Size of the VAE latent vector
LATENT_DIM = 128

# Learning rates, one per optimizer
VAE_LR = 1e-4
GEN_LR = 1e-4
DISC_LR = 1e-4

# Weight applied to the L1 terms of the generator loss
LAMBDA_CYCLE = 10

# Training schedule
EPOCHS = 25
BATCH_SIZE = 16

# Dataset locations
REAL_IMAGES_DIR = '/content/photos'
PENCIL_SKETCHES_DIR = '/content/sketches'

# Preprocessing applied to every photo and sketch: resize to the configured
# geometry, then convert to a tensor. No Normalize step is applied.
transform = transforms.Compose(
    [
        transforms.Resize((IMG_HEIGHT, IMG_WIDTH)),
        transforms.ToTensor(),
    ]
)

In [None]:
def sorted_alphanumeric(data):
    """Return the strings in ``data`` in natural (human) order.

    Runs of ASCII digits are compared numerically and everything else is
    compared case-insensitively, so ``'img2'`` sorts before ``'img10'``.

    Args:
        data: Iterable of strings (e.g. file paths).

    Returns:
        A new list with the items of ``data`` in natural order.
    """
    def natural_key(name):
        # Splitting on a single capturing group yields alternating
        # text / digit-run tokens, so odd positions are exactly the
        # ``[0-9]+`` matches. Using the position instead of ``str.isdigit``
        # avoids calling ``int`` on characters such as '²', which report
        # ``isdigit()`` but are not valid integer literals.
        tokens = re.split('([0-9]+)', name)
        return [
            int(token) if position % 2 else token.lower()
            for position, token in enumerate(tokens)
        ]

    return sorted(data, key=natural_key)

class ImageDataset(Dataset):
    """Paired dataset of real photos and their pencil sketches.

    Both glob patterns are expanded and naturally sorted; item ``idx`` pairs
    the ``idx``-th photo with the ``idx``-th sketch. Pairing therefore relies
    on the two directories sorting into the same order.

    Args:
        real_images_path: Glob pattern matching the photo files.
        sketches_path: Glob pattern matching the sketch files.
        transform: Optional callable applied to each loaded PIL image.

    Raises:
        ValueError: If the two patterns match a different number of files.
    """

    def __init__(self, real_images_path, sketches_path, transform=None):
        # Expand the patterns and put both lists in natural order.
        self.real_images = sorted_alphanumeric(glob.glob(real_images_path))
        self.sketches = sorted_alphanumeric(glob.glob(sketches_path))

        # Ensure that lengths of real_images and sketches match
        if len(self.real_images) != len(self.sketches):
            raise ValueError("The number of real images and sketches must be the same")

        self.transform = transform

    def __len__(self):
        """Return the number of (photo, sketch) pairs."""
        return len(self.real_images)

    @staticmethod
    def _load_rgb(path):
        """Open ``path`` and return it as a 3-channel RGB PIL image."""
        # The context manager closes the file handle; convert() returns a
        # separate image, so it remains usable afterwards.
        with Image.open(path) as image:
            return image.convert('RGB')

    def __getitem__(self, idx):
        """Return the ``idx``-th ``(real_image, sketch)`` pair.

        Both images are forced to RGB so that grayscale files still yield
        three channels, matching the 3-channel input of the models.
        """
        real_image = self._load_rgb(self.real_images[idx])
        sketch = self._load_rgb(self.sketches[idx])

        if self.transform:
            real_image = self.transform(real_image)
            sketch = self.transform(sketch)

        return real_image, sketch

In [None]:
class Sampling(nn.Module):
    """Reparameterization step: draw ``z = mean + std * noise``."""

    def forward(self, z_mean, z_log_var):
        """Sample a latent tensor shaped like ``z_mean``.

        Args:
            z_mean: Mean of the latent distribution.
            z_log_var: Log-variance of the latent distribution.

        Returns:
            ``z_mean`` plus standard-normal noise scaled by
            ``exp(0.5 * z_log_var)``.
        """
        std = torch.exp(0.5 * z_log_var)
        noise = torch.randn_like(z_mean)
        return z_mean + std * noise

class Encoder(nn.Module):
    """Convolutional VAE encoder.

    Two stride-2 convolutions halve the spatial size twice; the flattened
    result feeds a 16-unit hidden layer and two linear heads producing the
    latent mean and log-variance. ``fc1`` expects ``64 * 32 * 32`` features,
    which corresponds to 128x128 inputs.

    Args:
        latent_dim: Size of the latent vector.
    """

    def __init__(self, latent_dim):
        super().__init__()
        # Feature extractor: 3 -> 32 -> 64 channels, each step halves H and W.
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, stride=2, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1)
        # Hidden layer shared by both latent heads.
        self.fc1 = nn.Linear(64 * 32 * 32, 16)
        self.fc2_mean = nn.Linear(16, latent_dim)
        self.fc2_log_var = nn.Linear(16, latent_dim)
        self.sampling = Sampling()

    def forward(self, x):
        """Encode ``x`` and return ``(z_mean, z_log_var, z)``."""
        features = torch.relu(self.conv2(torch.relu(self.conv1(x))))
        flat = features.view(features.size(0), -1)
        hidden = torch.relu(self.fc1(flat))
        z_mean = self.fc2_mean(hidden)
        z_log_var = self.fc2_log_var(hidden)
        return z_mean, z_log_var, self.sampling(z_mean, z_log_var)

class Decoder(nn.Module):
    """Convolutional VAE decoder producing 3x128x128 images in [0, 1].

    A linear layer expands the latent vector to a 64x32x32 feature map, which
    two stride-2 transposed convolutions upsample to 64x64 and then 128x128.

    Args:
        latent_dim: Size of the latent vector.
    """

    def __init__(self, latent_dim):
        super(Decoder, self).__init__()
        self.fc1 = nn.Linear(latent_dim, 64 * 32 * 32)
        # The previous ``fc2`` layer (65536 -> 49152 features, ~3.2e9
        # weights) was never used in forward() and has been dropped.
        # output_padding=1 makes each layer exactly double the spatial size
        # (32 -> 64 -> 128); without it the output would be 125x125.
        self.deconv1 = nn.ConvTranspose2d(64, 32, 3, stride=2, padding=1, output_padding=1)
        self.deconv2 = nn.ConvTranspose2d(32, 3, 3, stride=2, padding=1, output_padding=1)

    def forward(self, z):
        """Decode latent batch ``z`` of shape (N, latent_dim) to (N, 3, 128, 128)."""
        x = torch.relu(self.fc1(z))
        x = x.view(x.size(0), 64, 32, 32)
        x = torch.relu(self.deconv1(x))
        # Sigmoid keeps the reconstruction in [0, 1].
        x = torch.sigmoid(self.deconv2(x))
        return x

In [None]:
class Generator(nn.Module):
    """Encoder-decoder image-to-image generator.

    ``down_stack`` halves the spatial size four times (3 -> 512 channels).
    ``up_stack`` first applies two size-preserving 512-channel layers with
    dropout, then doubles the size three times; ``final_layer`` doubles it
    once more and applies tanh. The output therefore has the same height and
    width as the input (which should be divisible by 16).
    """

    def __init__(self):
        super(Generator, self).__init__()
        self.down_stack = nn.Sequential(
            nn.Conv2d(3, 64, 4, stride=2, padding=1),
            nn.LeakyReLU(),
            nn.Conv2d(64, 128, 4, stride=2, padding=1),
            nn.InstanceNorm2d(128),
            nn.LeakyReLU(),
            nn.Conv2d(128, 256, 4, stride=2, padding=1),
            nn.InstanceNorm2d(256),
            nn.LeakyReLU(),
            nn.Conv2d(256, 512, 4, stride=2, padding=1),
            nn.InstanceNorm2d(512),
            nn.LeakyReLU()
        )
        self.up_stack = nn.Sequential(
            # Bottleneck layers: kernel 3 / stride 1 / padding 1 keeps the
            # spatial size. With stride 2 here the network would upsample six
            # times against four downsamplings and emit images 4x too large.
            nn.ConvTranspose2d(512, 512, 3, stride=1, padding=1),
            nn.Dropout(0.5),
            nn.ReLU(),
            nn.ConvTranspose2d(512, 512, 3, stride=1, padding=1),
            nn.Dropout(0.5),
            nn.ReLU(),
            nn.ConvTranspose2d(512, 256, 4, stride=2, padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(256, 128, 4, stride=2, padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(128, 64, 4, stride=2, padding=1),
            nn.ReLU()
        )
        # nn.ConvTranspose2d has no ``activation`` argument, so tanh is
        # applied as a separate module.
        # NOTE(review): tanh outputs lie in [-1, 1]; confirm this matches the
        # value range of the training images.
        self.final_layer = nn.Sequential(
            nn.ConvTranspose2d(64, 3, 4, stride=2, padding=1),
            nn.Tanh()
        )

    def forward(self, x):
        """Map an (N, 3, H, W) batch to an (N, 3, H, W) batch in [-1, 1]."""
        down = self.down_stack(x)
        x = self.up_stack(down)
        return self.final_layer(x)

class Discriminator(nn.Module):
    """Convolutional discriminator producing a map of per-patch logits.

    Four stride-2 convolutions (3 -> 64 -> 128 -> 256 -> 512 channels) are
    followed by a stride-1 convolution down to a single output channel.
    No sigmoid is applied, so the outputs are raw logits.
    """

    def __init__(self):
        super().__init__()
        # First stage has no normalization.
        layers = [nn.Conv2d(3, 64, 4, stride=2, padding=1), nn.LeakyReLU()]
        # Remaining downsampling stages: conv -> instance norm -> leaky ReLU.
        for in_channels, out_channels in ((64, 128), (128, 256), (256, 512)):
            layers.extend([
                nn.Conv2d(in_channels, out_channels, 4, stride=2, padding=1),
                nn.InstanceNorm2d(out_channels),
                nn.LeakyReLU(),
            ])
        # Output head: one logit per spatial location.
        layers.append(nn.Conv2d(512, 1, 4, padding=1))
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Return the logit map for image batch ``x``."""
        return self.model(x)

In [None]:
def vae_loss(real, reconstruction, z_mean, z_log_var):
    """Return the VAE objective: summed squared error plus KL divergence.

    Args:
        real: Target images.
        reconstruction: Decoder output, same shape as ``real``.
        z_mean: Latent means from the encoder.
        z_log_var: Latent log-variances from the encoder.

    Returns:
        Scalar tensor; both terms are summed (not averaged) over all elements.
    """
    reconstruction_term = nn.functional.mse_loss(
        reconstruction, real, reduction='sum'
    )
    # KL divergence between N(z_mean, exp(z_log_var)) and N(0, 1).
    kl_elements = 1 + z_log_var - z_mean.pow(2) - z_log_var.exp()
    kl_term = -0.5 * torch.sum(kl_elements)
    return reconstruction_term + kl_term

def train_vae(vae, dataloader, optimizer):
    """Train the VAE for one pass over ``dataloader``.

    Args:
        vae: Module exposing ``encoder`` and ``decoder`` sub-modules; the
            encoder must return ``(z_mean, z_log_var, z)``.
        dataloader: Iterable yielding ``(real_images, sketches)`` batches;
            only the real images are used.
        optimizer: Optimizer over the VAE parameters.

    Returns:
        Mean per-batch loss as a float (``0.0`` if the loader is empty).
    """
    vae.train()
    # Use the device the model lives on, so batches are moved next to the
    # weights without relying on a notebook-level global.
    device = next(vae.parameters()).device
    total_loss = 0.0
    num_batches = 0
    for real_images, _ in dataloader:
        real_images = real_images.to(device)
        optimizer.zero_grad()
        z_mean, z_log_var, z = vae.encoder(real_images)
        reconstruction = vae.decoder(z)
        loss = vae_loss(real_images, reconstruction, z_mean, z_log_var)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        num_batches += 1
    # Guard against an empty loader instead of dividing by zero.
    return total_loss / max(num_batches, 1)

def gan_loss_fn(real, fake):
    """Binary cross-entropy on logits.

    Note the argument roles: ``fake`` holds the logits being scored and
    ``real`` holds the target labels.
    """
    logits, targets = fake, real
    return nn.functional.binary_cross_entropy_with_logits(input=logits, target=targets)

def train_gan(gan, dataloader, vae_optimizer, gen_optimizer, disc_optimizer):
    """Train the generator and discriminator for one pass over ``dataloader``.

    Each batch performs one generator update followed by one discriminator
    update.

    Args:
        gan: Module exposing ``generator`` and ``discriminator`` sub-modules.
        dataloader: Iterable yielding ``(real_images, sketches)`` batches.
        vae_optimizer: Unused; kept so existing call sites keep working. The
            VAE is no longer trained here: the previous version ran a full
            VAE epoch over ``dataloader`` for every GAN batch.
        gen_optimizer: Optimizer over the generator parameters.
        disc_optimizer: Optimizer over the discriminator parameters.

    Returns:
        ``(mean_generator_loss, mean_discriminator_loss)`` per batch, as
        floats (``0.0`` each if the loader is empty).
    """
    gan.train()
    # Move batches to wherever the generator's weights live.
    device = next(gan.generator.parameters()).device
    total_gen_loss = 0.0
    total_disc_loss = 0.0
    num_batches = 0
    for real_images, sketches in dataloader:
        real_images = real_images.to(device)
        sketches = sketches.to(device)

        # ---- Generator update ----
        fake_sketches = gan.generator(real_images)
        # NOTE(review): the same generator is used for both directions of the
        # cycle term; confirm this is intended.
        cycled_images = gan.generator(fake_sketches)
        same_sketches = gan.generator(sketches)

        disc_fake_for_gen = gan.discriminator(fake_sketches)
        gen_loss = gan_loss_fn(torch.ones_like(disc_fake_for_gen), disc_fake_for_gen) + \
                   LAMBDA_CYCLE * nn.functional.l1_loss(real_images, cycled_images) + \
                   LAMBDA_CYCLE * nn.functional.l1_loss(sketches, same_sketches)

        gen_optimizer.zero_grad()
        gen_loss.backward()
        gen_optimizer.step()

        # ---- Discriminator update ----
        # The fake batch is detached so this backward pass stays inside the
        # discriminator. Reusing the generator graph would fail, because it
        # was freed by gen_loss.backward() above.
        disc_real_sketches = gan.discriminator(sketches)
        disc_fake_sketches = gan.discriminator(fake_sketches.detach())

        valid = torch.ones_like(disc_real_sketches)
        fake = torch.zeros_like(disc_fake_sketches)
        disc_loss = (gan_loss_fn(valid, disc_real_sketches) +
                     gan_loss_fn(fake, disc_fake_sketches)) * 0.5

        # zero_grad also clears the gradients that gen_loss.backward() left
        # on the discriminator parameters.
        disc_optimizer.zero_grad()
        disc_loss.backward()
        disc_optimizer.step()

        total_gen_loss += gen_loss.item()
        total_disc_loss += disc_loss.item()
        num_batches += 1

    denominator = max(num_batches, 1)
    return total_gen_loss / denominator, total_disc_loss / denominator

In [None]:
# Pick the compute device once; every model below is moved onto it.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


class VAE(nn.Module):
    """Container pairing an encoder and a decoder.

    The training code accesses ``vae.encoder`` and ``vae.decoder`` directly.
    """

    def __init__(self, encoder, decoder):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder

    def forward(self, x):
        """Return ``(reconstruction, z_mean, z_log_var)`` for batch ``x``."""
        z_mean, z_log_var, z = self.encoder(x)
        return self.decoder(z), z_mean, z_log_var


class VAEGAN(nn.Module):
    """Container exposing ``vae``, ``generator`` and ``discriminator``.

    Registering them as sub-modules lets ``gan.train()`` switch all three
    networks into training mode at once.
    """

    def __init__(self, vae, generator, discriminator):
        super().__init__()
        self.vae = vae
        self.generator = generator
        self.discriminator = discriminator


# Load dataset
dataset = ImageDataset(
    real_images_path=os.path.join(REAL_IMAGES_DIR, '*'),
    sketches_path=os.path.join(PENCIL_SKETCHES_DIR, '*'),
    transform=transform
)
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)

# Initialize models
encoder = Encoder(LATENT_DIM).to(device)
decoder = Decoder(LATENT_DIM).to(device)
vae = VAE(encoder, decoder).to(device)
generator = Generator().to(device)
discriminator = Discriminator().to(device)
gan = VAEGAN(vae, generator, discriminator).to(device)

# Optimizers
vae_optimizer = optim.Adam(vae.parameters(), lr=VAE_LR)
gen_optimizer = optim.Adam(generator.parameters(), lr=GEN_LR)
disc_optimizer = optim.Adam(discriminator.parameters(), lr=DISC_LR)

# Training loop
for epoch in range(EPOCHS):
    # The per-epoch results use distinct names: assigning to ``vae_loss``
    # here would overwrite the vae_loss() function that train_vae() calls.
    vae_epoch_loss = train_vae(vae, dataloader, vae_optimizer)
    gen_epoch_loss, disc_epoch_loss = train_gan(gan, dataloader, vae_optimizer, gen_optimizer, disc_optimizer)
    print(f'Epoch {epoch + 1}/{EPOCHS} - VAE Loss: {vae_epoch_loss:.4f} - Gen Loss: {gen_epoch_loss:.4f} - Disc Loss: {disc_epoch_loss:.4f}')