In [20]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import os
import numpy as np
from torch.autograd import Variable
import torchvision.utils as vutils

In [21]:
class Generator(nn.Module):
    def __init__(self, latent_dim=100):
        super(Generator, self).__init__()
        
        self.latent_dim = latent_dim
        
        # Process condition (sketch) through conv layers
        self.condition_processor = nn.Sequential(
            nn.Conv2d(1, 64, 4, 2, 1),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, 128, 4, 2, 1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True)
        )
        
        # Process noise
        self.noise_processor = nn.Sequential(
            nn.Linear(latent_dim, 128 * 16 * 16),
            nn.LeakyReLU(0.2, inplace=True)
        )
        
        # Combined processing
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(256, 128, 4, 2, 1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),
            nn.ConvTranspose2d(128, 64, 4, 2, 1),
            nn.BatchNorm2d(64),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, 3, 3, 1, 1),
            nn.Tanh()
        )

    def forward(self, noise, condition):
        # Process condition
        processed_condition = self.condition_processor(condition)  # Output: B x 128 x 16 x 16
        
        # Process noise
        processed_noise = self.noise_processor(noise)
        processed_noise = processed_noise.view(-1, 128, 16, 16)  # Reshape to match condition
        
        # Concatenate along channel dimension
        combined = torch.cat([processed_noise, processed_condition], dim=1)  # B x 256 x 16 x 16
        
        # Generate image
        output = self.decoder(combined)
        return output

In [22]:
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        
        # Process condition (sketch)
        self.condition_processor = nn.Sequential(
            nn.Conv2d(1, 64, 4, 2, 1),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, 128, 4, 2, 1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True)
        )
        
        # Process image
        self.image_processor = nn.Sequential(
            nn.Conv2d(3, 64, 4, 2, 1),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, 128, 4, 2, 1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True)
        )
        
        # Combined processing
        self.combined_processor = nn.Sequential(
            nn.Conv2d(256, 512, 4, 2, 1),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2, inplace=True)
        )

        # Final classification layer
        self.classifier = nn.Sequential(
            nn.Linear(512 * 8 * 8, 1),
            nn.Sigmoid()
        )

    def forward(self, image, condition):
        # Process condition and image
        processed_condition = self.condition_processor(condition)
        processed_image = self.image_processor(image)
        
        # Combine features
        combined = torch.cat([processed_image, processed_condition], dim=1)
        features = self.combined_processor(combined)
        
        # Flatten and classify
        features = features.view(features.size(0), -1)
        validity = self.classifier(features)
        
        return validity

In [23]:
class FaceSketchDataset(Dataset):
    def __init__(self, root_dir, split='train', transform=None):
        self.root_dir = root_dir
        self.split = split
        self.transform = transform

        self.photo_dir = os.path.join(root_dir, split, 'photos')
        self.sketch_dir = os.path.join(root_dir, split, 'sketches')

        self.photos = sorted([f for f in os.listdir(self.photo_dir) if f.endswith(('.jpg', '.png'))])

    def __len__(self):
        return len(self.photos)

    def __getitem__(self, idx):
        photo_name = self.photos[idx]
        photo_path = os.path.join(self.photo_dir, photo_name)
        sketch_path = os.path.join(self.sketch_dir, photo_name)

        photo = Image.open(photo_path).convert('RGB')
        sketch = Image.open(sketch_path).convert('L')

        if self.transform:
            photo = self.transform(photo)
            sketch = self.transform(sketch)

        return sketch, photo

In [24]:
def save_model(generator, discriminator, epoch, optimizer_G, optimizer_D, path='checkpoints'):
    os.makedirs(path, exist_ok=True)
    torch.save({
        'epoch': epoch,
        'generator_state_dict': generator.state_dict(),
        'discriminator_state_dict': discriminator.state_dict(),
        'optimizer_G_state_dict': optimizer_G.state_dict(),
        'optimizer_D_state_dict': optimizer_D.state_dict(),
    }, os.path.join(path, f'checkpoint_epoch_{epoch}.pth'))

def load_model(generator, discriminator, optimizer_G, optimizer_D, checkpoint_path):
    checkpoint = torch.load(checkpoint_path)
    generator.load_state_dict(checkpoint['generator_state_dict'])
    discriminator.load_state_dict(checkpoint['discriminator_state_dict'])
    optimizer_G.load_state_dict(checkpoint['optimizer_G_state_dict'])
    optimizer_D.load_state_dict(checkpoint['optimizer_D_state_dict'])
    return checkpoint['epoch']

In [25]:
def train_cgan(generator, discriminator, dataloader, num_epochs, device, save_interval=10):
    # Loss functions
    adversarial_loss = nn.BCELoss()

    # Optimizers
    optimizer_G = optim.Adam(generator.parameters(), lr=0.0002, betas=(0.5, 0.999))
    optimizer_D = optim.Adam(discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999))

    generator = generator.to(device)
    discriminator = discriminator.to(device)

    for epoch in range(num_epochs):
        for i, (sketches, real_imgs) in enumerate(dataloader):
            batch_size = real_imgs.size(0)
            
            real_imgs = real_imgs.to(device)
            sketches = sketches.to(device)

            valid = torch.ones(batch_size, 1).to(device)
            fake = torch.zeros(batch_size, 1).to(device)

            # Train Generator
            optimizer_G.zero_grad()
            z = torch.randn(batch_size, generator.latent_dim).to(device)
            gen_imgs = generator(z, sketches)
            g_loss = adversarial_loss(discriminator(gen_imgs, sketches), valid)
            g_loss.backward()
            optimizer_G.step()

            # Train Discriminator
            optimizer_D.zero_grad()
            real_loss = adversarial_loss(discriminator(real_imgs, sketches), valid)
            fake_loss = adversarial_loss(discriminator(gen_imgs.detach(), sketches), fake)
            d_loss = (real_loss + fake_loss) / 2
            d_loss.backward()
            optimizer_D.step()

            if i % 100 == 0:
                print(f"[Epoch {epoch}/{num_epochs}] [Batch {i}/{len(dataloader)}] "
                      f"[D loss: {d_loss.item():.4f}] [G loss: {g_loss.item():.4f}]")

        # Save model checkpoints
        if (epoch + 1) % save_interval == 0:
            save_model(generator, discriminator, epoch + 1, optimizer_G, optimizer_D)

In [26]:
def generate_sketch(model_path, image_path, output_path, device='cuda'):
    # Initialize model
    generator = Generator()
    discriminator = Discriminator()
    optimizer_G = optim.Adam(generator.parameters(), lr=0.0002, betas=(0.5, 0.999))
    optimizer_D = optim.Adam(discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999))

    # Load trained model
    load_model(generator, discriminator, optimizer_G, optimizer_D, model_path)
    generator.to(device)
    generator.eval()

    # Prepare image
    transform = transforms.Compose([
        transforms.Resize((64, 64)),
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5])
    ])

    image = Image.open(image_path).convert('L')
    image = transform(image).unsqueeze(0).to(device)

    # Generate sketch
    with torch.no_grad():
        z = torch.randn(1, generator.latent_dim).to(device)
        generated = generator(z, image)

    # Save generated image
    vutils.save_image(generated, output_path, normalize=True)
    return generated

In [27]:
def main():
    # Hyperparameters
    latent_dim = 100
    batch_size = 64
    num_epochs = 200
    image_size = 64
    root_dir = "archive"  # Update with your dataset path

    # Set device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Create directories for saving results
    os.makedirs("checkpoints", exist_ok=True)
    os.makedirs("images", exist_ok=True)

    # Data transforms
    transform = transforms.Compose([
        transforms.Resize((image_size, image_size)),
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5])
    ])

    # Create datasets
    train_dataset = FaceSketchDataset(root_dir, split='train', transform=transform)
    val_dataset = FaceSketchDataset(root_dir, split='val', transform=transform)

    # Create dataloaders
    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=2
    )

    val_loader = DataLoader(
        val_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=2
    )

    # Initialize models
    generator = Generator(latent_dim)
    discriminator = Discriminator()

    # Train the model
    train_cgan(generator, discriminator, train_loader, num_epochs, device)

In [28]:
# For training
main()

[Epoch 0/200] [Batch 0/323] [D loss: 0.6780] [G loss: 0.8016]
[Epoch 0/200] [Batch 100/323] [D loss: 0.3331] [G loss: 3.9636]
[Epoch 0/200] [Batch 200/323] [D loss: 0.1987] [G loss: 1.8771]
[Epoch 0/200] [Batch 300/323] [D loss: 0.2922] [G loss: 1.5256]
[Epoch 1/200] [Batch 0/323] [D loss: 0.2810] [G loss: 3.2125]
[Epoch 1/200] [Batch 100/323] [D loss: 0.1683] [G loss: 2.6461]
[Epoch 1/200] [Batch 200/323] [D loss: 0.1015] [G loss: 3.7749]
[Epoch 1/200] [Batch 300/323] [D loss: 0.2175] [G loss: 2.9279]
[Epoch 2/200] [Batch 0/323] [D loss: 0.5609] [G loss: 0.7474]
[Epoch 2/200] [Batch 100/323] [D loss: 0.1444] [G loss: 2.7633]
[Epoch 2/200] [Batch 200/323] [D loss: 0.1319] [G loss: 2.2718]
[Epoch 2/200] [Batch 300/323] [D loss: 0.1677] [G loss: 3.2981]
[Epoch 3/200] [Batch 0/323] [D loss: 0.1282] [G loss: 2.5638]
[Epoch 3/200] [Batch 100/323] [D loss: 0.0733] [G loss: 3.0565]
[Epoch 3/200] [Batch 200/323] [D loss: 0.0904] [G loss: 2.8127]
[Epoch 3/200] [Batch 300/323] [D loss: 0.0505] [

In [31]:

# For inference
model_path = "checkpoints/checkpoint_epoch_200.pth"
input_image = "archive/train/photos/0.jpg"
output_path = "sketch.jpg"
generate_sketch(model_path, input_image, output_path)

  checkpoint = torch.load(checkpoint_path)


tensor([[[[-0.3519, -0.4555, -0.7304,  ..., -0.9049, -0.7889, -0.6978],
          [-0.3437, -0.4842, -0.4858,  ..., -0.8872, -0.6428, -0.7665],
          [-0.4431, -0.2874, -0.2692,  ..., -0.9008, -0.8853, -0.7058],
          ...,
          [ 0.8666,  0.8406,  0.8020,  ...,  0.8288,  0.7549,  0.7914],
          [ 0.9003,  0.8660,  0.8699,  ...,  0.8830,  0.8250,  0.8737],
          [ 0.8118,  0.7935,  0.8548,  ...,  0.9584,  0.8868,  0.8340]],

         [[-0.1771, -0.1356, -0.4497,  ..., -0.5185, -0.5195, -0.2646],
          [-0.1665, -0.4211, -0.6521,  ..., -0.6834, -0.5058, -0.6442],
          [-0.2465, -0.1737, -0.5108,  ..., -0.7988, -0.2872, -0.4867],
          ...,
          [ 0.8510,  0.7742,  0.8802,  ...,  0.6724,  0.6564,  0.4917],
          [ 0.8790,  0.8711,  0.8904,  ...,  0.8169,  0.7204,  0.7035],
          [ 0.7299,  0.8462,  0.8252,  ...,  0.9093,  0.8975,  0.7782]],

         [[-0.1894, -0.2082, -0.7538,  ..., -0.5733, -0.5413, -0.1853],
          [-0.0874, -0.3652, -