In [38]:
import os
import random
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Dataset
import matplotlib.pyplot as plt
from PIL import Image


In [39]:
# Fix the seeds of torch's and Python's global random number generators.
SEED = 42
torch.manual_seed(SEED)
random.seed(SEED)

# Training hyperparameters.
BATCH_SIZE = 10
LR = 2e-4
NUM_EPOCHS = 3
IMAGE_SIZE = 256

# Run on the GPU when one is available, otherwise fall back to the CPU.
DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')


In [40]:
class DatasetClass(Dataset):
    """Paired sketch/photo dataset pooled from the train, val and test splits.

    For every split the constructor reads ``root_dir/<split>/sketches`` and
    pairs each sketch with the file of the same name in
    ``root_dir/<split>/photos``.

    Args:
        root_dir: dataset root containing the ``train``, ``val`` and ``test``
            folders.
        transform: optional callable applied to both the sketch and the photo.

    Raises:
        OSError: if a split's ``sketches`` folder cannot be listed.
    """

    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.sketch_paths = []
        self.photo_paths = []

        for split in ['train', 'val', 'test']:
            sketch_folder = os.path.join(root_dir, split, 'sketches')
            photo_folder = os.path.join(root_dir, split, 'photos')

            # os.listdir returns entries in arbitrary order; sorting keeps the
            # index -> file mapping identical from run to run.
            for file_name in sorted(os.listdir(sketch_folder)):
                sketch_path = os.path.join(sketch_folder, file_name)
                photo_path = os.path.join(photo_folder, file_name)

                # Keep complete pairs only, so a sketch without a photo (or a
                # stray sub-directory) cannot crash __getitem__ mid-training.
                if os.path.isfile(sketch_path) and os.path.isfile(photo_path):
                    self.sketch_paths.append(sketch_path)
                    self.photo_paths.append(photo_path)

    def __len__(self):
        """Return the number of sketch/photo pairs."""
        return len(self.sketch_paths)

    def __getitem__(self, idx):
        """Load pair ``idx`` as RGB images and return ``(sketch, photo)``."""
        # convert() returns an independent, fully loaded image, so the file
        # handles can be released as soon as the with-blocks exit.
        with Image.open(self.sketch_paths[idx]) as sketch_file:
            sketch = sketch_file.convert('RGB')
        with Image.open(self.photo_paths[idx]) as photo_file:
            photo = photo_file.convert('RGB')

        if self.transform:
            sketch = self.transform(sketch)
            photo = self.transform(photo)

        return sketch, photo


In [41]:
# Preprocessing shared by sketches and photos: resize to IMAGE_SIZE x IMAGE_SIZE,
# convert to a tensor, then shift/scale every RGB channel with mean 0.5 and
# std 0.5. The statistics are spelled out per channel instead of relying on a
# single value being broadcast across all three channels.
preprocess_steps = transforms.Compose([
    transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
])


In [42]:
# Build the paired dataset and a shuffling loader over it. The batch size comes
# from the BATCH_SIZE config constant so the loader and the config cell cannot
# drift apart.
data_set = DatasetClass(root_dir='/kaggle/input/person-face-sketches', transform=preprocess_steps)
train_loader = DataLoader(data_set, batch_size=BATCH_SIZE, shuffle=True)

In [43]:
class Generator(nn.Module):
    """Convolutional encoder-decoder mapping a 3-channel image to a 3-channel image.

    Two stride-2 convolutions halve the spatial size twice, two stride-2
    transposed convolutions double it twice, and a final tanh bounds the
    output to [-1, 1].
    """

    def __init__(self):
        super().__init__()

        # Layers are created in network order so parameter names
        # (model.0, model.2, ...) and initialisation order stay fixed.
        encoder = [
            nn.Conv2d(in_channels=3, out_channels=64, kernel_size=7, stride=1, padding=3),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, stride=2, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_channels=128, out_channels=256, kernel_size=3, stride=2, padding=1),
            nn.ReLU(inplace=True),
        ]
        decoder = [
            nn.ConvTranspose2d(in_channels=256, out_channels=128, kernel_size=3, stride=2,
                               padding=1, output_padding=1),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(in_channels=128, out_channels=64, kernel_size=3, stride=2,
                               padding=1, output_padding=1),
            nn.ReLU(inplace=True),
        ]
        output_head = [
            nn.Conv2d(in_channels=64, out_channels=3, kernel_size=7, stride=1, padding=3),
            nn.Tanh(),
        ]

        self.model = nn.Sequential(*encoder, *decoder, *output_head)

    def forward(self, x):
        """Run ``x`` through the encoder-decoder stack."""
        translated = self.model(x)
        return translated
    
    


In [44]:
class Discriminator(nn.Module):
    """Convolutional discriminator producing a grid of scores in [0, 1].

    Four stride-2 convolutions with LeakyReLU(0.2) widen the channels
    3 -> 64 -> 128 -> 256 -> 512, then a stride-1 convolution reduces to a
    single channel that is passed through a sigmoid.
    """

    def __init__(self):
        super().__init__()

        widths = [3, 64, 128, 256, 512]
        layers = []
        # Downsampling stages: each one is a 4x4 stride-2 conv plus LeakyReLU.
        for in_ch, out_ch in zip(widths[:-1], widths[1:]):
            layers.append(nn.Conv2d(in_ch, out_ch, kernel_size=4, stride=2, padding=1))
            layers.append(nn.LeakyReLU(0.2, inplace=True))

        # Scoring head: one output channel squashed into [0, 1].
        layers.append(nn.Conv2d(widths[-1], 1, kernel_size=4, stride=1, padding=1))
        layers.append(nn.Sigmoid())

        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Return the score map for the image batch ``x``."""
        scores = self.model(x)
        return scores


In [45]:
# One generator per translation direction and one discriminator per domain,
# all placed on DEVICE. Networks are still constructed in the order
# A2B, B2A, A, B.
generator_A2B, generator_B2A = (Generator().to(DEVICE) for _ in range(2))
discriminator_A, discriminator_B = (Discriminator().to(DEVICE) for _ in range(2))

In [46]:
# With more than one visible GPU, wrap every network in nn.DataParallel.
if torch.cuda.device_count() > 1:
    generator_A2B, generator_B2A = nn.DataParallel(generator_A2B), nn.DataParallel(generator_B2A)
    discriminator_A, discriminator_B = nn.DataParallel(discriminator_A), nn.DataParallel(discriminator_B)

# Adversarial objective is mean-squared error; cycle consistency uses L1.
criterion_GAN = nn.MSELoss()
criterion_cycle = nn.L1Loss()

# Both generators share one optimizer; each discriminator gets its own.
optimizer_G = optim.Adam(
    [*generator_A2B.parameters(), *generator_B2A.parameters()],
    lr=LR,
    betas=(0.5, 0.999),
)
optimizer_D_A = optim.Adam(discriminator_A.parameters(), lr=LR, betas=(0.5, 0.999))
optimizer_D_B = optim.Adam(discriminator_B.parameters(), lr=LR, betas=(0.5, 0.999))

In [47]:
def _to_display_image(tensor):
    """Turn a channel-first tensor into an (H, W, C) array with values in [0, 1]."""
    image = tensor.detach().cpu().numpy().transpose(1, 2, 0)
    # (x + 1) / 2 maps [-1, 1] onto [0, 1]; clipping keeps imshow from
    # receiving values that fall slightly outside that range.
    return ((image + 1) / 2).clip(0, 1)


def visualize_results(sketch, generated_photo, real_photo, generated_sketch, step):
    """Show a 2x2 grid comparing one real and generated sketch/photo sample.

    Args:
        sketch: input sketch tensor, channel-first (C, H, W).
        generated_photo: photo produced from the sketch, same layout.
        real_photo: ground-truth photo tensor, same layout.
        generated_sketch: sketch produced from the photo, same layout.
        step: value shown in the figure title.

    All four tensors are assumed to be normalised to [-1, 1].
    """
    panels = (
        ((0, 0), 'Sketch', sketch),
        ((0, 1), 'Generated Photo', generated_photo),
        ((1, 0), 'Real Photo', real_photo),
        ((1, 1), 'Generated Sketch', generated_sketch),
    )

    fig, axes = plt.subplots(2, 2, figsize=(10, 10))
    try:
        for (row, col), title, tensor in panels:
            ax = axes[row, col]
            ax.imshow(_to_display_image(tensor))
            ax.axis('off')
            ax.set_title(title)

        plt.suptitle(f'Step {step}')
        plt.tight_layout()
        plt.show()
    finally:
        # This runs repeatedly during training; close the figure explicitly so
        # open figures do not pile up in memory.
        plt.close(fig)



In [48]:
def save_model_weights(epoch, save_dir="model_weights"):
    """Checkpoint both generators and both discriminators for one epoch.

    Args:
        epoch: zero-based epoch index; file names are labelled ``epoch + 1``.
        save_dir: directory receiving the ``.pth`` files; created (including
            missing parents) when it does not exist yet.
    """
    # exist_ok=True replaces the check-then-create pattern, which can fail if
    # the directory appears between the check and the makedirs call.
    os.makedirs(save_dir, exist_ok=True)

    networks = {
        "generator_A2B": generator_A2B,
        "generator_B2A": generator_B2A,
        "discriminator_A": discriminator_A,
        "discriminator_B": discriminator_B,
    }
    for name, network in networks.items():
        # An nn.DataParallel wrapper prefixes every state_dict key with
        # "module.", which a plain (unwrapped) model cannot load. Save the
        # wrapped module's weights so checkpoints load either way.
        bare = network.module if isinstance(network, nn.DataParallel) else network
        torch.save(bare.state_dict(), os.path.join(save_dir, f"{name}_epoch_{epoch+1}.pth"))


In [None]:

# CycleGAN-style training loop.
# Domain A = sketches, domain B = photos:
#   generator_A2B: sketch -> photo, judged by discriminator_B
#   generator_B2A: photo -> sketch, judged by discriminator_A
for epoch in range(NUM_EPOCHS):
    for i, (sketch, photo) in enumerate(train_loader):
        sketch = sketch.to(DEVICE)
        photo = photo.to(DEVICE)

        # ---------------- Generator update ----------------
        optimizer_G.zero_grad()

        # Cycle A -> B -> A: sketch to photo and back to sketch.
        fake_photo = generator_A2B(sketch)
        reconstructed_sketch = generator_B2A(fake_photo)
        loss_cycle_A = criterion_cycle(reconstructed_sketch, sketch)

        # Cycle B -> A -> B: photo to sketch and back to photo.
        fake_sketch = generator_B2A(photo)
        reconstructed_photo = generator_A2B(fake_sketch)
        loss_cycle_B = criterion_cycle(reconstructed_photo, photo)

        # Adversarial loss: the generators want their fakes scored as real (1).
        # Each discriminator runs once and its output is reused for the target.
        pred_fake_photo = discriminator_B(fake_photo)
        pred_fake_sketch = discriminator_A(fake_sketch)
        loss_GAN_A = criterion_GAN(pred_fake_photo, torch.ones_like(pred_fake_photo))
        loss_GAN_B = criterion_GAN(pred_fake_sketch, torch.ones_like(pred_fake_sketch))

        # Total generator loss; the cycle-consistency terms are weighted by 10.
        loss_G = loss_GAN_A + loss_GAN_B + 10.0 * (loss_cycle_A + loss_cycle_B)
        loss_G.backward()
        optimizer_G.step()

        # -------------- Discriminator updates --------------
        # zero_grad also discards the gradients that loss_G.backward() left on
        # the discriminator parameters.
        optimizer_D_A.zero_grad()
        optimizer_D_B.zero_grad()

        # discriminator_A scores sketches (it judges fake_sketch above), so
        # its real examples must be real sketches, not photos.
        pred_real_sketch = discriminator_A(sketch)
        pred_fake_sketch = discriminator_A(fake_sketch.detach())
        loss_D_A = criterion_GAN(pred_real_sketch, torch.ones_like(pred_real_sketch)) + \
                    criterion_GAN(pred_fake_sketch, torch.zeros_like(pred_fake_sketch))

        # discriminator_B scores photos (it judges fake_photo above), so its
        # real examples must be real photos, not sketches.
        pred_real_photo = discriminator_B(photo)
        pred_fake_photo = discriminator_B(fake_photo.detach())
        loss_D_B = criterion_GAN(pred_real_photo, torch.ones_like(pred_real_photo)) + \
                    criterion_GAN(pred_fake_photo, torch.zeros_like(pred_fake_photo))

        loss_D_A.backward()
        loss_D_B.backward()
        optimizer_D_A.step()
        optimizer_D_B.step()

        # Show the first sample of every second batch.
        if (i + 1) % 2 == 0:
            visualize_results(sketch[0], fake_photo[0], photo[0], fake_sketch[0], i+1)

    # Checkpoint all four networks at the end of every epoch.
    save_model_weights(epoch)


In [50]:

# Save the final weights of all four networks to the working directory.
# An nn.DataParallel wrapper prefixes every state_dict key with "module.",
# which an unwrapped model cannot load, so the wrapped module is saved instead.
for _name, _network in (
    ("generator_A2B", generator_A2B),
    ("generator_B2A", generator_B2A),
    ("discriminator_A", discriminator_A),
    ("discriminator_B", discriminator_B),
):
    _bare = _network.module if isinstance(_network, nn.DataParallel) else _network
    torch.save(_bare.state_dict(), f"{_name}.pth")
