In [9]:
import sys
import torch
from torch import nn
from PyQt5.QtCore import Qt
from PyQt5 import QtWidgets, QtGui, QtCore
from PyQt5.QtWidgets import QFileDialog, QLabel, QPushButton, QVBoxLayout, QHBoxLayout, QFrame
from PIL import Image
import torchvision.transforms as T
import numpy as np

In [11]:
# --- UNet-style Generator ---
class UNetGenerator(nn.Module):
    def __init__(self, input_nc=3, output_nc=3, ngf=64):
        super(UNetGenerator, self).__init__()
        self.encoder = nn.Sequential(
            nn.Conv2d(input_nc, ngf, 4, 2, 1),
            nn.ReLU(True),
            nn.Conv2d(ngf, ngf * 2, 4, 2, 1),
            nn.BatchNorm2d(ngf * 2),
            nn.ReLU(True),
            nn.Conv2d(ngf * 2, ngf * 4, 4, 2, 1),
            nn.BatchNorm2d(ngf * 4),
            nn.ReLU(True),
            nn.Conv2d(ngf * 4, ngf * 8, 4, 2, 1),
            nn.BatchNorm2d(ngf * 8),
            nn.ReLU(True),
        )
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1),
            nn.BatchNorm2d(ngf * 4),
            nn.ReLU(True),
            nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1),
            nn.BatchNorm2d(ngf * 2),
            nn.ReLU(True),
            nn.ConvTranspose2d(ngf * 2, ngf, 4, 2, 1),
            nn.BatchNorm2d(ngf),
            nn.ReLU(True),
            nn.ConvTranspose2d(ngf, output_nc, 4, 2, 1),
            nn.Tanh()
        )

    def forward(self, x):
        x = self.encoder(x)
        x = self.decoder(x)
        return x

In [13]:
# --- Transforms ---
transform = T.Compose([
    T.Resize((256, 256)),
    T.ToTensor(),
    T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

inv_transform = T.Compose([
    T.Normalize(mean=[-1, -1, -1], std=[2, 2, 2]),
    T.ToPILImage()
])


def load_generator(path):
    model = UNetGenerator()
    state_dict = torch.load(path, map_location='cpu')
    if isinstance(state_dict, dict) and 'model_state_dict' in state_dict:
        state_dict = state_dict['model_state_dict']
    model.load_state_dict(state_dict)
    model.eval()
    return model


def is_sketch(img_pil):
    """Check if image is mostly grayscale by analyzing std across RGB before normalization."""
    img_arr = np.array(img_pil.resize((64, 64))).astype(np.float32) / 255.0
    std_per_channel = np.std(img_arr, axis=(0, 1))
    return np.mean(std_per_channel) < 0.08

In [15]:
class MainWindow(QtWidgets.QWidget):
    def __init__(self, model_G, model_F):
        super().__init__()
        self.setWindowTitle("Sketch ↔ Face Converter")
        self.setGeometry(200, 150, 800, 600)
        self.setStyleSheet("background-color: #2d2d2d; color: white; font-family: Arial; font-size: 14px;")

        self.model_G = model_G
        self.model_F = model_F

        # Upload button (centered)
        self.button = QPushButton("Upload Image")
        self.button.setStyleSheet("background-color: #5c5c8a; padding: 12px; font-size: 16px;")
        self.button.clicked.connect(self.load_image)

        # Labels for input/output
        self.input_label = QLabel("Input Image")
        self.output_label = QLabel("Output Image")

        # Image display
        self.input_frame = QLabel()
        self.output_frame = QLabel()
        for frame in [self.input_frame, self.output_frame]:
            frame.setFixedSize(256, 256)
            frame.setFrameShape(QFrame.Box)
            frame.setStyleSheet("background-color: white; border-radius: 10px;")

        # Layout for images
        img_layout = QHBoxLayout()
        img_layout.addWidget(self.input_frame)
        img_layout.addWidget(self.output_frame)

        # Layout for labels
        label_layout = QHBoxLayout()
        label_layout.addWidget(self.input_label)
        label_layout.addStretch()
        label_layout.addWidget(self.output_label)

        # Converter buttons (below images)
        self.face2sketch_button = QPushButton("Face → Sketch")
        self.sketch2face_button = QPushButton("Sketch → Face")
        self.face2sketch_button.setStyleSheet("background-color: #ff6f61; padding: 10px; font-size: 14px;")
        self.sketch2face_button.setStyleSheet("background-color: #4caf50; padding: 10px; font-size: 14px;")
        self.face2sketch_button.clicked.connect(self.convert_face_to_sketch)
        self.sketch2face_button.clicked.connect(self.convert_sketch_to_face)

        # Layout for buttons
        button_layout = QHBoxLayout()
        button_layout.addWidget(self.face2sketch_button)
        button_layout.addWidget(self.sketch2face_button)

        # Main layout
        main_layout = QVBoxLayout()
        main_layout.addWidget(self.button, alignment=QtCore.Qt.AlignCenter)
        main_layout.addLayout(img_layout)
        main_layout.addLayout(label_layout)
        main_layout.addLayout(button_layout)

        self.setLayout(main_layout)

    def load_image(self):
        fname, _ = QFileDialog.getOpenFileName(self, 'Open file', '', 'Images (*.png *.jpg *.jpeg)')
        if fname:
            pil_img = Image.open(fname).convert("RGB")
            input_tensor = transform(pil_img).unsqueeze(0)

            # Detect type using pre-normalized PIL image
            if is_sketch(pil_img):
                self.input_label.setText("Input: Sketch")
                self.output_label.setText("Output: Face")
            else:
                self.input_label.setText("Input: Face")
                self.output_label.setText("Output: Sketch")

            # Show input image
            input_pixmap = QtGui.QPixmap(fname).scaled(256, 256, Qt.KeepAspectRatio)
            self.input_frame.setPixmap(input_pixmap)

            self.image_path = fname

    def convert_face_to_sketch(self):
        if hasattr(self, 'image_path'):
            pil_img = Image.open(self.image_path).convert("RGB")
            input_tensor = transform(pil_img).unsqueeze(0)

            with torch.no_grad():
                output = self.model_F(input_tensor).squeeze(0).cpu().clamp_(-1, 1)

            output_img = inv_transform(output)
            output_img.save("output.png")

            # Show output image
            output_pixmap = QtGui.QPixmap("output.png").scaled(256, 256, Qt.KeepAspectRatio)
            self.output_frame.setPixmap(output_pixmap)
            self.output_label.setText("Output: Sketch")

    def convert_sketch_to_face(self):
        if hasattr(self, 'image_path'):
            pil_img = Image.open(self.image_path).convert("RGB")
            input_tensor = transform(pil_img).unsqueeze(0)

            with torch.no_grad():
                output = self.model_G(input_tensor).squeeze(0).cpu().clamp_(-1, 1)

            output_img = inv_transform(output)
            output_img.save("output.png")

            # Show output image
            output_pixmap = QtGui.QPixmap("output.png").scaled(256, 256, Qt.KeepAspectRatio)
            self.output_frame.setPixmap(output_pixmap)
            self.output_label.setText("Output: Face")

In [None]:
model_G = load_generator("Downloads/generator_sketch2face.pth")
model_F = load_generator("Downloads/generator_face2sketch.pth")

app = QtWidgets.QApplication(sys.argv)
win = MainWindow(model_G, model_F)
win.show()
sys.exit(app.exec_())


  state_dict = torch.load(path, map_location='cpu')
