In [None]:
# Mount Google Drive so the dataset and checkpoint paths under /content/drive/MyDrive resolve.
from google.colab import drive

MOUNT_POINT = '/content/drive'
drive.mount(MOUNT_POINT)

In [None]:
import numpy as np
import natsort
import cv2
import os

import torch
import torch.nn as nn
import torch.nn.parallel
import torch.optim as optim
import torch.utils.data
import torchvision.datasets as dset
import torchvision.models as models
import torchvision.transforms as transforms
import torchvision.utils as vutils
from torch.utils.data import DataLoader, Dataset

import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from IPython.display import HTML
from PIL import Image, ImageFilter
from torchvision.utils import save_image


In [None]:
def extract_key(filename):
    """Map a photo or sketch path to the key shared by its matching pair.

    Rules, applied to the base name with its extension removed:
      * ``f-...`` names containing ``-sz1`` drop that marker,
        e.g. ``f-005-01-sz1`` -> ``f-005-01``.
      * Names starting with ``sketch`` swap that prefix for ``image``,
        e.g. ``sketch_12`` -> ``image_12``.
      * Anything else is returned unchanged (its stem).

    Args:
        filename: A path or bare file name.

    Returns:
        The normalized key string.
    """
    stem, _ = os.path.splitext(os.path.basename(filename))
    if stem.startswith("f-") and "-sz1" in stem:
        return stem.replace("-sz1", "")
    if stem.startswith("sketch"):
        # Replace only the prefix. The extension is already stripped above, so
        # no png->jpg rewrite is needed (doing it on the stem would corrupt
        # names that merely contain "png").
        return "image" + stem[len("sketch"):]
    return stem

def load_matched_filenames(photo_dir, sketch_dir):
    """Pair photo and sketch files whose names normalize to the same key.

    Keys come from ``extract_key``. Only regular files are considered, so
    subdirectories (e.g. ``.ipynb_checkpoints``) are never paired and cannot
    later fail inside ``Image.open``. Files present in only one directory are
    dropped. If two files in one directory map to the same key, the one listed
    later by ``os.listdir`` wins.

    Args:
        photo_dir: Directory containing real face photos.
        sketch_dir: Directory containing sketches.

    Returns:
        ``(matched_photos, matched_sketches)``: parallel lists of full paths,
        ordered by key, so index ``i`` in both lists belongs to the same pair.
    """
    def _index(directory):
        # Build key -> full path for every regular file in `directory`.
        index = {}
        for entry in os.listdir(directory):
            path = os.path.join(directory, entry)
            if os.path.isfile(path):
                index[extract_key(path)] = path
        return index

    photo_dict = _index(photo_dir)
    sketch_dict = _index(sketch_dir)

    common_keys = sorted(photo_dict.keys() & sketch_dict.keys())
    matched_photos = [photo_dict[k] for k in common_keys]
    matched_sketches = [sketch_dict[k] for k in common_keys]
    return matched_photos, matched_sketches

class ImageDataset(Dataset):
    """Paired ``(sketch, photo)`` dataset for the sketch-to-face generator.

    Each item is ``(sketch_tensor, photo_tensor)``. Both images are loaded as
    RGB, resized to ``size`` and converted with ``ToTensor`` (values in [0, 1];
    the ``Normalize`` step to [-1, 1] is commented out).

    Args:
        photo_paths: Paths of real photos.
        sketch_paths: Paths of sketches; ``sketch_paths[i]`` must pair with
            ``photo_paths[i]`` (e.g. as produced by ``load_matched_filenames``).
        size: ``(height, width)`` passed to ``transforms.Resize``.

    Raises:
        ValueError: If the two path lists have different lengths.
    """

    def __init__(self, photo_paths, sketch_paths, size=(256, 256)):
        # Validate the pairing up front. A length mismatch would otherwise either
        # silently ignore extra sketches or raise IndexError inside a DataLoader worker.
        if len(photo_paths) != len(sketch_paths):
            raise ValueError(
                f"photo_paths ({len(photo_paths)}) and sketch_paths "
                f"({len(sketch_paths)}) must have the same length"
            )
        self.photo_paths = photo_paths
        self.sketch_paths = sketch_paths
        self.transform = transforms.Compose([
            transforms.Resize(size),
            transforms.ToTensor()
            #transforms.Normalize((0.5,), (0.5,))
        ])

    def __len__(self):
        return len(self.photo_paths)

    @staticmethod
    def _load_rgb(path):
        # The context manager releases the file handle promptly; convert()
        # returns a fully loaded copy that remains valid after the file closes.
        with Image.open(path) as img:
            return img.convert('RGB')

    def __getitem__(self, idx):
        photo = self._load_rgb(self.photo_paths[idx])
        sketch = self._load_rgb(self.sketch_paths[idx])
        return self.transform(sketch), self.transform(photo)

In [None]:
class ResnetBlock(nn.Module):
    """Residual block computing ``x + F(x)``.

    F is conv3x3 -> InstanceNorm -> ReLU -> conv3x3 -> InstanceNorm. It uses
    stride 1 and padding 1, so channel count and spatial size are preserved.
    The sub-layers live in ``self.block`` at indices 0..4, which keeps the
    state-dict keys (``block.0.*``, ``block.3.*``) matching saved checkpoints.
    """

    def __init__(self, dim):
        super().__init__()

        def conv3x3():
            return nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=1)

        self.block = nn.Sequential(
            conv3x3(), nn.InstanceNorm2d(dim), nn.ReLU(inplace=True),
            conv3x3(), nn.InstanceNorm2d(dim),
        )

    def forward(self, x):
        residual = self.block(x)
        return x + residual

# Generator network: stem -> 2x downsample -> residual bottleneck -> 2x upsample -> tanh head.
class GlobalGenerator(nn.Module):
    """Image-to-image generator.

    Layout: 7x7 stem conv, two stride-2 downsampling convs, ``num_blocks``
    residual blocks, two stride-2 transposed convs, then a 7x7 conv and
    ``Tanh`` (outputs in [-1, 1]). When H and W are divisible by 4, the output
    has the input's spatial size and ``img_channels`` channels.

    All layers are stored flat in ``self.model`` in a fixed order, because
    saved checkpoints address them by index (``model.<i>.*`` keys).

    Args:
        img_channels: Channel count of both the input and output images.
        ngf: Filters in the stem conv; doubled at each downsample.
        num_blocks: Number of ``ResnetBlock`` modules in the bottleneck.
    """

    def __init__(self, img_channels=3, ngf=64, num_blocks=9):
        super().__init__()

        def conv_norm_relu(conv, channels):
            # The repeated conv -> InstanceNorm -> ReLU unit of encoder and decoder.
            return [conv, nn.InstanceNorm2d(channels), nn.ReLU(inplace=True)]

        bottleneck = ngf * 4
        layers = []
        # Stem.
        layers += conv_norm_relu(nn.Conv2d(img_channels, ngf, kernel_size=7, stride=1, padding=3), ngf)
        # Encoder: two stride-2 downsamples.
        layers += conv_norm_relu(nn.Conv2d(ngf, ngf * 2, kernel_size=3, stride=2, padding=1), ngf * 2)
        layers += conv_norm_relu(nn.Conv2d(ngf * 2, bottleneck, kernel_size=3, stride=2, padding=1), bottleneck)
        # Residual bottleneck.
        layers += [ResnetBlock(bottleneck) for _ in range(num_blocks)]
        # Decoder: two stride-2 upsamples.
        layers += conv_norm_relu(
            nn.ConvTranspose2d(bottleneck, ngf * 2, kernel_size=3, stride=2, padding=1, output_padding=1), ngf * 2)
        layers += conv_norm_relu(
            nn.ConvTranspose2d(ngf * 2, ngf, kernel_size=3, stride=2, padding=1, output_padding=1), ngf)
        # Head.
        layers += [nn.Conv2d(ngf, img_channels, kernel_size=7, stride=1, padding=3), nn.Tanh()]

        self.model = nn.Sequential(*layers)

    def forward(self, x):
        out = self.model(x)
        return out

In [None]:
device = "cuda" if torch.cuda.is_available() else "cpu"
print('Device:', device)

model_path = '/content/drive/MyDrive/FOCE/Combined Dataset/trained models/g_model_final_v3.pth'

model = GlobalGenerator()
# map_location="cpu" lets a checkpoint saved from GPU tensors load on a CPU-only
# runtime; without it torch.load tries to restore those tensors onto CUDA and fails.
# load_state_dict copies the weights into the model, which stays on CPU here
# (it is not moved to `device`).
# NOTE(review): torch.load unpickles the file, so only load checkpoints you trust.
state_dict = torch.load(model_path, map_location="cpu")
model.load_state_dict(state_dict)  # load the saved weights
model.eval()  # set model to evaluation mode



In [None]:
photo_path = '/content/drive/MyDrive/FOCE/Combined Dataset/test/photo/'
sketch_path = '/content/drive/MyDrive/FOCE/Combined Dataset/test/sketch/'
photo, sketch = load_matched_filenames(photo_path, sketch_path)
# Fail loudly here instead of silently iterating over an empty DataLoader below.
assert photo, f"No matching photo/sketch pairs found in {photo_path!r} and {sketch_path!r}"
dataset = ImageDataset(photo, sketch)
batch_size = 1

# DataLoader warns when num_workers exceeds the available CPUs, so cap it.
# Pinned memory only helps host->GPU copies, so enable it only when CUDA exists.
num_workers = min(6, os.cpu_count() or 1)
dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=num_workers,
                        pin_memory=torch.cuda.is_available(), shuffle=False)


def display_images(sketch, photo):
    """Show one sketch and its real photo side by side.

    Both arguments are single-image ``[C, H, W]`` tensors. They are plotted
    as-is with no rescaling, so they should already be in [0, 1], as the
    dataset's ``ToTensor`` output is.
    """
    images = [t.cpu().numpy().transpose(1, 2, 0) for t in (sketch, photo)]

    fig, axes = plt.subplots(1, 2, figsize=(12, 4))
    for ax, img, title in zip(axes, images, ("Sketch", "Real Photo")):
        ax.imshow(img)
        ax.set_title(title)
        ax.axis("off")
    plt.show()


# Preview the dataset: show the first (sketch, photo) pair of every batch.
for sketches, real_faces in dataloader:
    display_images(sketches[0], real_faces[0])

In [None]:
def generate_face(model, sketch_img, i,
                  output_dir='/content/drive/MyDrive/FOCE/Combined Dataset/test/Generated Faces'):
    """Run the generator on a sketch batch and save the first output as a PNG.

    Args:
        model: Generator module. Its output is assumed to be in [-1, 1]
            (``GlobalGenerator`` ends in ``Tanh``).
        sketch_img: Input batch, presumably ``[N, C, H, W]``; only the first
            output is kept. It is moved to the model's device, so the model
            may live on CPU or GPU.
        i: Index used in the file name ``generated_face_{i}.png``.
        output_dir: Directory the PNG is written to; created if missing.

    Returns:
        The first generated image as a detached tensor, still in [-1, 1].
    """
    model_device = next(model.parameters()).device
    # Inference only: no_grad stops autograd from recording a graph per image.
    with torch.no_grad():
        generated_face = model(sketch_img.to(model_device))[0]

    os.makedirs(output_dir, exist_ok=True)
    # Map the tanh range [-1, 1] to [0, 1] before save_image writes 8-bit pixels.
    save_image((generated_face + 1) / 2, os.path.join(output_dir, f'generated_face_{i}.png'))

    return generated_face.detach()


In [None]:
def display_images(sketch, photo, display_size=(200, 250)):
    """Show a generated face next to its real photo.

    Despite its name, ``sketch`` receives the generated image in the inference
    loop that calls this function. It is assumed to be tanh output in
    [-1, 1] and is rescaled to [0, 1]. ``photo`` is assumed to already be in
    [0, 1] (``ToTensor`` output) and is shown as-is.

    Args:
        sketch: Generated image tensor, ``[C, H, W]`` or ``[1, C, H, W]``.
        photo: Real photo tensor, ``[C, H, W]`` or ``[1, C, H, W]``.
        display_size: ``(width, height)`` passed to ``cv2.resize`` for both images.
    """
    def _to_hwc(img):
        # Drop a leading batch dim of size 1, broadcast grayscale to 3 channels,
        # and return a contiguous [H, W, C] numpy array. detach() lets this also
        # accept tensors that still require grad.
        if img.dim() == 4:
            img = img.squeeze(0)
        if img.shape[0] == 1:
            img = img.expand(3, -1, -1)
        return np.ascontiguousarray(img.detach().cpu().numpy().transpose(1, 2, 0))

    generated = cv2.resize((_to_hwc(sketch) + 1) / 2, display_size)
    real = cv2.resize(_to_hwc(photo), display_size)

    fig, axs = plt.subplots(1, 2, figsize=(8, 4))
    axs[0].imshow(generated)
    axs[0].set_title('Generated Image')
    axs[1].imshow(real)
    axs[1].set_title('Real Face')
    for ax in axs:
        ax.axis('off')
    plt.tight_layout()
    plt.show()

In [None]:
# Rebuild the test DataLoader so this cell can run on its own.
photo_path = '/content/drive/MyDrive/FOCE/Combined Dataset/test/photo/'
sketch_path = '/content/drive/MyDrive/FOCE/Combined Dataset/test/sketch/'
photo, sketch = load_matched_filenames(photo_path, sketch_path)
# Fail loudly here instead of silently iterating over an empty DataLoader below.
assert photo, f"No matching photo/sketch pairs found in {photo_path!r} and {sketch_path!r}"

dataset = ImageDataset(photo, sketch)
batch_size = 1

# DataLoader warns when num_workers exceeds the available CPUs, so cap it.
# Pinned memory only helps host->GPU copies, so enable it only when CUDA exists.
num_workers = min(6, os.cpu_count() or 1)
dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=num_workers,
                        pin_memory=torch.cuda.is_available(), shuffle=False)


# Generate a face for every test sketch, save it, and show it next to the real photo.
# no_grad keeps autograd from recording the forward passes (inference only).
with torch.no_grad():
    for i, (sketch, real_face) in enumerate(dataloader):
        generated_face = generate_face(model, sketch, i)
        print(generated_face.shape)
        display_images(generated_face, real_face)

In [None]:
def save_images(sketch, photo, generated, output_dir, epoch):
    """Plot sketch, real face and generated face side by side, then save the generated one.

    Args:
        sketch: Input sketch tensor, ``[C, H, W]`` or ``[1, C, H, W]``, as
            produced by ``ImageDataset``. Its values are in [0, 1] because the
            ``Normalize`` step is commented out, so it is NOT shifted from
            [-1, 1]; the old ``(x + 1) / 2`` washed it out to [0.5, 1].
        photo: Real face tensor with the same layout and [0, 1] range.
        generated: PIL image of the generated face; plotted as-is and saved.
        output_dir: Directory for the saved file; created if missing.
        epoch: Index used in the file name ``generated_image_{epoch}.png``.
    """
    def _to_pil(img):
        # Drop a leading batch dim of size 1, broadcast grayscale to RGB, and
        # convert a [0, 1] tensor to an 8-bit PIL image. Clipping stops
        # out-of-range values from wrapping around in the uint8 cast.
        if img.dim() == 4:
            img = img.squeeze(0)
        if img.shape[0] == 1:
            img = img.expand(3, -1, -1)
        arr = np.clip(img.permute(1, 2, 0).detach().cpu().numpy(), 0.0, 1.0)
        return Image.fromarray((arr * 255).astype('uint8'))

    os.makedirs(output_dir, exist_ok=True)

    sketch_img = _to_pil(sketch)
    photo_img = _to_pil(photo)

    # Three panels, one per image. The original allocated a fourth, unused slot.
    fig, axs = plt.subplots(1, 3, figsize=(8, 4))
    for ax, img, title in zip(axs, (sketch_img, photo_img, generated),
                              ('Sketch', 'Real Face', 'Generated')):
        ax.imshow(img)
        ax.set_title(title)
        ax.axis('off')
    plt.tight_layout()
    plt.show()

    generated.save(os.path.join(output_dir, f'generated_image_{epoch}.png'))

In [None]:
output_dir = '/content/drive/MyDrive/FOCE/Combined Dataset/test/Generated Faces'

# Re-open each face written by generate_face and plot it with its sketch and photo.
# The `with` block closes each PNG's file handle instead of leaking one per image.
for i, (sketch, real_face) in enumerate(dataloader):
    with Image.open(os.path.join(output_dir, f'generated_face_{i}.png')) as generated:
        save_images(sketch, real_face, generated, output_dir, i)