In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import os

In [2]:
# Custom Dataset
class FaceSketchDataset(Dataset):
    """Paired sketch/photo dataset backed by two directories.

    A sketch and its photo are matched by file name: every regular file
    ``name`` found in ``sketch_dir`` is paired with ``name`` in ``real_dir``.

    Args:
        sketch_dir: Directory containing the sketch images.
        real_dir: Directory containing the photos, stored under the same
            file names as the sketches.
        transform: Optional callable applied independently to the sketch and
            to the photo after both have been converted to RGB.
    """

    def __init__(self, sketch_dir, real_dir, transform=None):
        self.sketch_dir = sketch_dir
        self.real_dir = real_dir
        self.transform = transform
        # os.listdir yields entries in arbitrary order and also lists
        # sub-directories (e.g. ".ipynb_checkpoints"). Keep regular files only
        # and sort them so the index -> sample mapping is stable across runs.
        self.image_files = sorted(
            name
            for name in os.listdir(sketch_dir)
            if os.path.isfile(os.path.join(sketch_dir, name))
        )

    def __len__(self):
        """Return the number of sketch files found in ``sketch_dir``."""
        return len(self.image_files)

    def __getitem__(self, idx):
        """Return the ``(sketch, real)`` pair for the file at position ``idx``.

        Both images are converted to RGB; ``transform`` is applied to each
        when one was supplied.
        """
        file_name = self.image_files[idx]
        sketch_path = os.path.join(self.sketch_dir, file_name)
        real_path = os.path.join(self.real_dir, file_name)

        # convert() returns a fully loaded copy, so the underlying file can be
        # closed as soon as the ``with`` block exits.
        with Image.open(sketch_path) as sketch_file:
            sketch = sketch_file.convert('RGB')
        with Image.open(real_path) as real_file:
            real = real_file.convert('RGB')

        if self.transform:
            sketch = self.transform(sketch)
            real = self.transform(real)

        return sketch, real

In [3]:
# Generator
class Generator(nn.Module):
    """Convolutional encoder-decoder that maps an image batch to an image batch.

    Four stride-2 convolutions (``conv1``..``conv4``) are followed by four
    stride-2 transposed convolutions (``deconv1``..``deconv4``). Every layer
    is followed by ReLU except the last one, which is followed by tanh.

    Args:
        input_channels: Number of channels of the input batch.
        output_channels: Number of channels of the generated batch.
        n_filters: Output channels of the first convolution; the encoder
            widens to ``8 * n_filters`` and the decoder narrows back.
    """

    def __init__(self, input_channels=3, output_channels=3, n_filters=64):
        super().__init__()
        # All eight layers share the same kernel geometry.
        geometry = dict(kernel_size=4, stride=2, padding=1)
        w1, w2, w4, w8 = (n_filters * k for k in (1, 2, 4, 8))

        # Encoder.
        self.conv1 = nn.Conv2d(input_channels, w1, **geometry)
        self.conv2 = nn.Conv2d(w1, w2, **geometry)
        self.conv3 = nn.Conv2d(w2, w4, **geometry)
        self.conv4 = nn.Conv2d(w4, w8, **geometry)

        # Decoder.
        self.deconv1 = nn.ConvTranspose2d(w8, w4, **geometry)
        self.deconv2 = nn.ConvTranspose2d(w4, w2, **geometry)
        self.deconv3 = nn.ConvTranspose2d(w2, w1, **geometry)
        self.deconv4 = nn.ConvTranspose2d(w1, output_channels, **geometry)

        self.relu = nn.ReLU()
        self.tanh = nn.Tanh()

    def forward(self, x):
        """Run ``x`` through the encoder and decoder; the output is tanh-activated."""
        relu_stages = (
            self.conv1, self.conv2, self.conv3, self.conv4,
            self.deconv1, self.deconv2, self.deconv3,
        )
        for stage in relu_stages:
            x = self.relu(stage(x))
        # The final layer uses tanh instead of ReLU.
        return self.tanh(self.deconv4(x))


In [4]:
# Discriminator
class Discriminator(nn.Module):
    """Conditional discriminator that scores a (sketch, image) pair.

    The sketch and the image are concatenated along the channel axis, passed
    through five stride-2 convolutions, flattened per sample and classified
    by a two-layer MLP ending in a sigmoid.

    Args:
        input_channels: Channels of the sketch and of the image individually;
            the first convolution receives ``2 * input_channels`` channels.
        n_filters: Output channels of the first convolution; the last
            convolution produces ``16 * n_filters`` channels.
        image_size: Height and width of the (square) inputs. Each of the five
            convolutions halves the spatial size (rounding down), so the value
            must be at least 32. Defaults to 256, the size this model was
            originally hard-coded for.

    Raises:
        ValueError: If ``image_size`` is smaller than 32.
    """

    def __init__(self, input_channels=3, n_filters=64, image_size=256):
        super(Discriminator, self).__init__()
        if image_size < 32:
            raise ValueError(
                f"image_size must be at least 32 for five stride-2 convolutions, got {image_size}"
            )

        self.conv_layers = nn.Sequential(
            # input is (input_channels*2) x S x S, with S = image_size
            nn.Conv2d(input_channels * 2, n_filters, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (n_filters) x S/2 x S/2
            nn.Conv2d(n_filters, n_filters * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(n_filters * 2),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (n_filters*2) x S/4 x S/4
            nn.Conv2d(n_filters * 2, n_filters * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(n_filters * 4),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (n_filters*4) x S/8 x S/8
            nn.Conv2d(n_filters * 4, n_filters * 8, 4, 2, 1, bias=False),
            nn.BatchNorm2d(n_filters * 8),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (n_filters*8) x S/16 x S/16
            nn.Conv2d(n_filters * 8, n_filters * 16, 4, 2, 1, bias=False),
            nn.BatchNorm2d(n_filters * 16),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (n_filters*16) x S/32 x S/32
        )

        # Size of the flattened features: five halvings leave image_size // 32
        # pixels per side (8 x 8 for the default 256 x 256 input).
        final_side = image_size // 32
        self.feature_size = n_filters * 16 * final_side * final_side

        self.classifier = nn.Sequential(
            nn.Linear(self.feature_size, 1024),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(1024, 1),
            nn.Sigmoid()
        )

    def forward(self, sketch, image):
        """Return one sigmoid score per sample for the given pair.

        Raises:
            RuntimeError: If the inputs do not have the spatial size the model
                was built for.
        """
        # Concatenate the sketch and the image
        x = torch.cat([sketch, image], dim=1)
        # Pass through convolutional layers
        x = self.conv_layers(x)
        # Flatten per sample. Unlike view(-1, feature_size), this can never
        # fold surplus spatial features into the batch dimension.
        x = torch.flatten(x, start_dim=1)
        if x.size(1) != self.feature_size:
            raise RuntimeError(
                f"Discriminator expected {self.feature_size} features per sample "
                f"but got {x.size(1)}; input spatial size does not match image_size"
            )
        # Pass through the classifier
        return self.classifier(x)

In [5]:
# Training function
def train_cgan(generator, discriminator, dataloader, num_epochs, device,
               l1_lambda=0.0, log_every=100):
    """Train a conditional GAN with alternating discriminator/generator steps.

    For every batch the discriminator is updated first on the real pair and on
    the detached generated pair, then the generator is updated so that the
    discriminator labels its output as real. Both optimizers are Adam with
    ``lr=0.0002`` and ``betas=(0.5, 0.999)``; the adversarial loss is BCE.

    Args:
        generator: Module called as ``generator(sketches)``.
        discriminator: Module called as ``discriminator(sketches, images)``
            and returning probabilities (BCELoss is applied to its output).
        dataloader: Iterable of ``(sketches, real_images)`` tensor batches.
        num_epochs: Number of passes over ``dataloader``.
        device: Device the batches are moved to.
        l1_lambda: Weight of an optional L1 term between generated and real
            images added to the generator loss. ``0.0`` (default) keeps the
            purely adversarial objective.
        log_every: Print the losses every ``log_every`` steps; ``0`` or
            ``None`` disables printing.

    Returns:
        dict with keys ``"d_loss"`` and ``"g_loss"``, each a list holding one
        float per training step.
    """
    criterion = nn.BCELoss()
    reconstruction = nn.L1Loss()
    g_optimizer = optim.Adam(generator.parameters(), lr=0.0002, betas=(0.5, 0.999))
    d_optimizer = optim.Adam(discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999))

    # Make sure mode-dependent layers (e.g. BatchNorm) behave as in training,
    # even if the caller switched the models to eval mode beforehand.
    generator.train()
    discriminator.train()

    # Losses are stored as detached tensors and converted to floats once at
    # the end, so recording them does not force a device sync on every step.
    d_losses, g_losses = [], []

    for epoch in range(num_epochs):
        for i, (sketches, real_images) in enumerate(dataloader):
            sketches = sketches.to(device)
            real_images = real_images.to(device)

            # Train Discriminator
            d_optimizer.zero_grad()

            # Real images
            d_real_output = discriminator(sketches, real_images)
            # Build the targets directly with the output's shape, dtype and
            # device instead of allocating on CPU and reshaping the outputs.
            real_label = torch.ones_like(d_real_output)
            fake_label = torch.zeros_like(d_real_output)
            d_real_loss = criterion(d_real_output, real_label)

            # Fake images (detached: this step must not update the generator)
            fake_images = generator(sketches)
            d_fake_output = discriminator(sketches, fake_images.detach())
            d_fake_loss = criterion(d_fake_output, fake_label)

            d_loss = d_real_loss + d_fake_loss
            d_loss.backward()
            d_optimizer.step()

            # Train Generator
            g_optimizer.zero_grad()
            g_fake_output = discriminator(sketches, fake_images)
            g_loss = criterion(g_fake_output, real_label)
            if l1_lambda:
                g_loss = g_loss + l1_lambda * reconstruction(fake_images, real_images)
            g_loss.backward()
            g_optimizer.step()

            d_losses.append(d_loss.detach())
            g_losses.append(g_loss.detach())

            if log_every and i % log_every == 0:
                print(f"Epoch [{epoch+1}/{num_epochs}], Step [{i+1}/{len(dataloader)}], "
                      f"D_loss: {d_loss.item():.4f}, G_loss: {g_loss.item():.4f}")

    return {
        "d_loss": [float(value) for value in d_losses],
        "g_loss": [float(value) for value in g_losses],
    }

In [2]:

# Main execution
def main(sketch_dir="/kaggle/input/person-face-sketches/train/sketches",
         photo_dir="/kaggle/input/person-face-sketches/train/photos",
         num_epochs=100, batch_size=64, num_workers=4):
    """Build the dataset and models, train the cGAN and return the models.

    Args:
        sketch_dir: Directory with the training sketches.
        photo_dir: Directory with the matching training photos.
        num_epochs: Number of epochs passed to ``train_cgan``.
        batch_size: Batch size of the training ``DataLoader``.
        num_workers: Worker processes used by the ``DataLoader``.

    Returns:
        ``(generator, discriminator)`` after training, so the caller can keep
        using or saving them instead of losing them when this function exits.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Define transforms: resize to 256x256 and scale each channel with
    # mean 0.5 / std 0.5.
    transform = transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])

    # Create dataset and dataloader
    dataset = FaceSketchDataset(sketch_dir, photo_dir, transform=transform)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)

    # Initialize models
    generator = Generator().to(device)
    discriminator = Discriminator().to(device)

    # Train the model
    train_cgan(generator, discriminator, dataloader, num_epochs=num_epochs, device=device)

    return generator, discriminator

if __name__ == "__main__":
    # Keep the trained models in the notebook namespace for later cells.
    generator, discriminator = main()


Epoch [1/30], Step [1/323], D_loss: 1.3841, G_loss: 12.6832
Epoch [1/30], Step [101/323], D_loss: 0.9044, G_loss: 15.9051
Epoch [1/30], Step [201/323], D_loss: 0.7231, G_loss: 18.4128
Epoch [1/30], Step [301/323], D_loss: 0.6102, G_loss: 20.0723
Epoch [2/30], Step [1/323], D_loss: 0.5734, G_loss: 22.7562
Epoch [2/30], Step [101/323], D_loss: 0.5292, G_loss: 24.8213
Epoch [2/30], Step [201/323], D_loss: 0.4981, G_loss: 26.6089
Epoch [2/30], Step [301/323], D_loss: 0.4723, G_loss: 27.9731
Epoch [3/30], Step [1/323], D_loss: 0.4512, G_loss: 28.3494
Epoch [3/30], Step [101/323], D_loss: 0.4301, G_loss: 27.2007
Epoch [3/30], Step [201/323], D_loss: 0.4156, G_loss: 25.7641
Epoch [3/30], Step [301/323], D_loss: 0.3987, G_loss: 24.4154
Epoch [4/30], Step [1/323], D_loss: 0.3845, G_loss: 23.0856
Epoch [4/30], Step [101/323], D_loss: 0.3712, G_loss: 21.6347
Epoch [4/30], Step [201/323], D_loss: 0.3601, G_loss: 20.7710
Epoch [4/30], Step [301/323], D_loss: 0.3498, G_loss: 19.0016
Epoch [5/30], St