In [6]:
import os
import cv2
import dlib
import glob
import torch
import imutils
import warnings

import numpy as np
import pandas as pd

from tqdm import tqdm
from pathlib import Path
from imutils import face_utils
from PIL import Image as Image
from skimage.metrics import structural_similarity as compare_ssim

import sys
sys.path.insert(0, r"D:\Coding\NTU\FYP\Source\StyleGAN2_ADA_TORCH")

import dnnlib
import legacy

from matplotlib import pyplot as plt
from matplotlib.backends.backend_qt5agg import FigureCanvasQTAgg as FigureCanvas

from PyQt5.QtWidgets import QApplication, QMainWindow, QDesktopWidget, QMenuBar, QMenu, QAction, QLayoutItem
from PyQt5.QtWidgets import QGroupBox, QHBoxLayout, QVBoxLayout, QGridLayout, QFrame, QSpacerItem, QScrollArea
from PyQt5.QtWidgets import QLabel, QLineEdit, QTextEdit, QSpinBox, QDoubleSpinBox, QComboBox 
from PyQt5.QtWidgets import QSlider
from PyQt5.QtWidgets import QListWidget, QListWidgetItem, QTabWidget
from PyQt5.QtWidgets import QGraphicsView, QGraphicsScene
from PyQt5.QtWidgets import QPushButton, QRadioButton, QFileDialog, QDialog, QMessageBox

from PyQt5.QtGui import QImage, QPixmap, QIcon, QTextCursor, QPen, QBrush, QFont
from PyQt5.QtGui import QMouseEvent

from PyQt5.QtCore import Qt, QSize, QObject, QPoint, QRect, QTimer, QCoreApplication, QEventLoop
from PyQt5.QtCore import pyqtSignal, pyqtSignal, pyqtSlot

from pyqtgraph import PlotWidget, plot
import pyqtgraph as pg


In [7]:
class StyleGAN_img():
    """Render images from a StyleGAN2-ADA generator.

    Args:
        G: generator module. It is called as ``G(z, label, truncation_psi=..., noise_mode=...)``
            and must expose ``synthesis``, ``num_ws`` and ``w_dim``.
        seed: seed for the ``numpy.random.RandomState`` used by ``generate_latent_vector``.
        shape: latent dimensionality (number of columns of the generated z).
        label: conditioning label passed unchanged to ``G``.
        device: torch device that latents are moved to.
        truncation_psi: truncation value forwarded to ``G``.
        noise_mode: noise mode forwarded to ``G`` and ``G.synthesis``.
    """

    def __init__(self, G, seed, shape, label, device, truncation_psi=0.5, noise_mode="const"):
        self.G = G
        self.seed = seed
        self.shape = shape
        self.label = label
        self.device = device
        self.truncation_psi = truncation_psi
        self.noise_mode = noise_mode

    def generate_latent_vector(self):
        """Return a deterministic ``(1, shape)`` float64 latent drawn with ``self.seed``."""
        rnd = np.random.RandomState(self.seed)
        return rnd.randn(1, self.shape)

    @staticmethod
    def _to_uint8_hwc(img):
        """Map a 4-D ``(N, C, H, W)`` generator output to an ``(H, W, C)`` uint8 array (first item).

        Values are scaled by ``x * 127.5 + 128`` and clamped to 0..255 (presumably the
        generator output lies roughly in [-1, 1]).
        """
        img = (img.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8)
        return img[0].cpu().numpy()

    def generate_image(self, z):
        """Render the image for numpy latent ``z`` (e.g. from ``generate_latent_vector``).

        Returns:
            ``(H, W, C)`` uint8 numpy array of the first image in the batch.
        """
        z = torch.from_numpy(z).to(self.device)
        # Inference only: no autograd graph is needed for display or saving.
        with torch.no_grad():
            img = self.G(z, self.label, truncation_psi=self.truncation_psi, noise_mode=self.noise_mode)
        return self._to_uint8_hwc(img)

    def generate_projected(self, ws):
        """Render the image for projected W latents ``ws``.

        Args:
            ws: array shaped ``(N, G.num_ws, G.w_dim)``.

        Returns:
            ``(H, W, C)`` uint8 numpy array for the LAST latent in ``ws`` (the previous
            implementation synthesized every latent but kept only the final one).

        Raises:
            AssertionError: if the trailing dimensions do not match the generator.
            ValueError: if ``ws`` holds no latents.
        """
        ws = torch.tensor(ws, device=self.device)
        assert ws.shape[1:] == (self.G.num_ws, self.G.w_dim)
        if ws.shape[0] == 0:
            raise ValueError('generate_projected() needs at least one latent in ws')
        # Only the last latent's image is returned, so synthesize just that one.
        with torch.no_grad():
            img = self.G.synthesis(ws[-1].unsqueeze(0), noise_mode=self.noise_mode)
        return self._to_uint8_hwc(img)

class StyleGAN_init():
    """Load a StyleGAN2-ADA network pickle and prepare it for image generation.

    Construction picks the device, loads the ``'G_ema'`` generator, prepares the latent
    size and an all-zeros label, and renders one seed-0 image whose result is discarded.

    Attributes:
        network_path: path/URL handed to ``dnnlib.util.open_url``.
        device: ``torch.device`` ('cuda' when available, otherwise 'cpu').
        G: generator loaded from the pickle, on ``device``.
        z_shape: ``G.z_dim``.
        zero_label: all-zeros ``(1, G.c_dim)`` tensor on ``device``.
    """

    def __init__(self, path_to_network):
        self.network_path = path_to_network
        self.device = torch.device(self.torch_is_cuda_available())
        self.G = self.load_network()
        self.z_shape, self.zero_label = self.setup_image_generation()

        # One throw-away render with seed 0; the image itself is not kept.
        first_render = StyleGAN_img(self.G, 0, self.z_shape, self.zero_label, self.device)
        first_render.generate_image(first_render.generate_latent_vector())

    def torch_is_cuda_available(self):
        """Return 'cuda' when CUDA is usable; otherwise emit a warning and return 'cpu'."""
        if not torch.cuda.is_available():
            warnings.warn('Warning! CUDA is not available... Using CPU...')
            return 'cpu'
        gpu_name = torch.cuda.get_device_name(0)
        print('CUDA is available. Using device \'0\': %s' % gpu_name)
        return 'cuda'

    def load_network(self):
        """Unpickle ``self.network_path`` and return its ``'G_ema'`` module on ``self.device``."""
        print('Unpacking network file...', end=' ')
        with dnnlib.util.open_url(self.network_path) as stream:
            network = legacy.load_network_pkl(stream)
        generator = network['G_ema'].to(self.device)
        print('Done.')
        return generator

    def setup_image_generation(self):
        """Return ``(G.z_dim, zeros([1, G.c_dim]))`` with the label tensor on ``self.device``."""
        latent_size = self.G.z_dim
        label = torch.zeros([1, self.G.c_dim], device=self.device)
        return latent_size, label

In [8]:
class MainWindow(QMainWindow):
    def __init__(self):
        """Create the main window, initialise GUI state and build all widgets."""
        super().__init__()
        self.setWindowTitle("FYP StyleGAN2")
        self.setMinimumSize(QSize(1280, 768))
        self.showMaximized()

        self.StyleGAN_init_state = False  # becomes True once a network has been loaded
        self.allow_add_new = True
        self.add_new_count = 0
        # row position -> add_new_count id of the transformation box in that row
        self.add_count_row = dict(zip(range(5),[0]*5))
        self.transform_count = 0
        self.max_transform_count = 5
        self.img_size = 264

        # Random-noise placeholder shown before any image is generated. It must be uint8:
        # image_np_to_pixmap hands the raw bytes to QImage as RGB888 (one byte per channel),
        # so a float64 array would be read as 8x too many, meaningless bytes.
        placeholder = np.random.rand(self.img_size, self.img_size, 3) * 255
        self.placeholder_img = self.image_np_to_pixmap(placeholder.astype(np.uint8))

        self.init_window()

    def init_window(self):
        """Build the menu bar and central widget, then wire signals and load saved settings.

        The central widget stacks the file toolbar, the controls/image area and the
        console vertically with stretch factors 1:4:1.
        """
        menubar = QMenuBar()
        for menu_title in ('File', 'Edit', 'View', 'Help'):
            menubar.addMenu(menu_title)
        self.setMenuBar(menubar)

        main_layout = QVBoxLayout()
        main_layout.addLayout(self.build_toolbar_layout(), 1)
        main_layout.addLayout(self.build_central_layout(), 4)
        main_layout.addWidget(self.build_console_layout(), 1)

        central_widget = QFrame()
        central_widget.setLayout(main_layout)
        self.setCentralWidget(central_widget)

        # Both steps below need the widgets created above to exist.
        self.connect_signals()
        self.load_default_settings()

    def load_default_settings(self):
        """Fill the path fields from ``default_settings.npy`` in the CWD and act on them.

        Reads the keys 'default_network_path', 'default_storage_path' and
        'default_database_path'. A value of ' ' (or an empty/missing value) means "not set".
        A set network path is validated and loaded; a set database path has its attribute
        controls built. When the settings file does not exist, nothing is loaded.
        """
        default_settings_path = Path.cwd() / 'default_settings.npy'
        if not default_settings_path.exists():
            print('No default settings found at %s.' % default_settings_path)
            return
        # NOTE(review): allow_pickle=True can run code from a crafted file; only load trusted settings.
        self.default_settings = np.load(default_settings_path, allow_pickle=True)
        settings = self.default_settings[0]
        # ' ' is the "unset" sentinel; strip it so the fields are truly empty when unset.
        default_network_path = settings.get('default_network_path', ' ').strip()
        default_storage_path = settings.get('default_storage_path', ' ').strip()
        default_database_path = settings.get('default_database_path', ' ').strip()
        self.network_lineedit.setText(default_network_path)
        self.img_control_save_lineedit.setText(default_storage_path)
        self.database_lineedit.setText(default_database_path)

        if default_network_path:
            self.check_network_file_type(default_network_path)
        if default_database_path:
            self.load_attributes_data(default_database_path)
            # build_attributes_layout() already adds its layout to self.control_group;
            # adding the returned layout again only triggers a Qt "already has a parent" warning.
            self.build_attributes_layout()
        

    def build_toolbar_layout(self):
        """Create the top toolbar: read-only path fields plus 'Select' buttons for the
        network pickle and the attribute-database directory.

        Stores network_lineedit, network_file_button, database_lineedit and
        database_button on self and returns the toolbar QHBoxLayout.
        """
        self.network_lineedit = QLineEdit()
        self.network_lineedit.setReadOnly(True)
        self.network_file_button = QPushButton('Select')

        self.database_lineedit = QLineEdit()
        self.database_lineedit.setReadOnly(True)
        self.database_button = QPushButton('Select')

        file_group = QHBoxLayout()
        toolbar_widgets = (QLabel('Network File:'), self.network_lineedit, self.network_file_button,
                           QLabel('Database Directory:'), self.database_lineedit, self.database_button)
        for widget in toolbar_widgets:
            file_group.addWidget(widget)

        toolbar_group = QHBoxLayout()
        toolbar_group.addLayout(file_group)
        return toolbar_group

    def load_attributes_data(self, path):
        """Load every ``*.npy`` direction file in directory ``path`` into self.attributes_data.

        Each entry maps the file's stem to ``[direction, slider, scale]`` with slider 0 and
        scale 1. Any previously loaded attributes are replaced.
        """
        self.attributes_data = {}
        # Path.glob handles '/' and '\\' separators alike (QFileDialog returns '/'), avoiding
        # the invalid '\*' escape and manual '\\' splitting; sorting keeps the UI order stable.
        for npy_path in sorted(Path(path).glob('*.npy')):
            direction = np.load(npy_path)
            slider, scale = 0, 1
            self.attributes_data[npy_path.stem] = [direction, slider, scale]

    def build_attribute_box(self, attribute, intensity_slider, scaler_spinbox, intensity_lineedit):
        """Lay out one attribute row: read-only name | slider | scale spinbox | intensity field.

        Stretch factors are 1:4:1:1, the same as the header row in build_central_layout.
        """
        name_field = QLineEdit(attribute)
        name_field.setAlignment(Qt.AlignLeft)
        name_field.setFixedWidth(125)
        name_field.setReadOnly(True)
        name_field.setMinimumHeight(20)

        row = QHBoxLayout()
        for widget, stretch in ((name_field, 1), (intensity_slider, 4),
                                (scaler_spinbox, 1), (intensity_lineedit, 1)):
            row.addWidget(widget, stretch)
        return row

    def build_attributes_controls(self):
        """Create the widgets for one attribute and return ``[slider, scale_spinbox, intensity_lineedit]``.

        Slider: horizontal, -20..20, step 1, starts at 0, tracking off. Spinbox: 0..20,
        3 decimals, step 0.025, starts at 1. Both start disabled until a StyleGAN model
        is initialised. The intensity field is read-only and shows '0.000'.
        """
        slider = QSlider()
        slider.setOrientation(Qt.Horizontal)
        slider.setTickPosition(QSlider.TicksBelow)
        slider.setRange(-20, 20)
        slider.setSingleStep(1)
        slider.setSliderPosition(0)
        slider.setTracking(False)

        spinbox = QDoubleSpinBox()
        spinbox.setAlignment(Qt.AlignCenter)
        spinbox.setRange(0, 20)
        spinbox.setValue(1)
        spinbox.setDecimals(3)
        spinbox.setSingleStep(0.025)

        # Nothing can be driven until a network is loaded.
        if not self.StyleGAN_init_state:
            for widget in (spinbox, slider):
                widget.setDisabled(True)

        intensity_field = QLineEdit()
        intensity_field.setReadOnly(True)
        intensity_field.setAlignment(Qt.AlignCenter)
        intensity_field.setText('0.000')

        return [slider, spinbox, intensity_field]

    def build_transformation_control(self):
        """Create the widgets of one transformation box.

        Returns ``[seed_spinbox, generate_btn, load_lineedit, load_btn, save_btn,
        delete_btn, original_label, transformed_label]``. The spinbox and all four
        buttons start disabled; both image labels show the placeholder pixmap.
        """
        seed_spinbox = QSpinBox()
        seed_spinbox.setAlignment(Qt.AlignCenter)
        seed_spinbox.setRange(0, 999999)

        generate_btn, load_btn, save_btn, delete_btn = (
            QPushButton(caption) for caption in ('Generate', 'Load', 'Save', 'Delete'))
        for widget in (seed_spinbox, generate_btn, load_btn, save_btn, delete_btn):
            widget.setDisabled(True)

        load_lineedit = QLineEdit()
        load_lineedit.setReadOnly(True)

        original_label, transformed_label = QLabel(), QLabel()
        for image_label in (original_label, transformed_label):
            image_label.setAlignment(Qt.AlignCenter)
            image_label.setPixmap(self.placeholder_img)
            image_label.setMinimumSize(QSize(self.img_size, self.img_size))

        return [seed_spinbox, generate_btn,
                load_lineedit, load_btn,
                save_btn, delete_btn,
                original_label, transformed_label]

    def add_transformation_layout(self):
        """Build one transformation box and register its controls.

        Stores the list from build_transformation_control() in
        self.image_controls[self.transform_count] and connects its Generate/Load/Save/Delete
        buttons and the file 'Select' button to their handlers. The handlers receive
        self.add_new_count as captured when this box is built, not the row index.

        Returns:
            QVBoxLayout: a separator line above a row holding the seed/load tabs with the
            Save/Delete buttons (stretch 1) and the Original/Transformed images (stretch 2).
        """
        controls = self.build_transformation_control()
        self.image_controls[self.transform_count] = controls

        # `idx=self.add_new_count` freezes the box id at build time; `_` receives clicked()'s bool.
        # Indices: 1 = Generate, 3 = Load, -4 (= 4) = Save, -3 (= 5) = Delete.
        self.image_controls[self.transform_count][1].clicked.connect(
            lambda _, idx=self.add_new_count: self.on_img_seed_button_clicked(idx))
        self.image_controls[self.transform_count][3].clicked.connect(
            lambda _, idx=self.add_new_count: self.on_img_load_button_clicked(idx))
        self.image_controls[self.transform_count][-4].clicked.connect(
            lambda _, idx=self.add_new_count: self.on_img_save_button_clicked(idx))
        self.image_controls[self.transform_count][-3].clicked.connect(
            lambda _, idx=self.add_new_count: self.on_img_delete_button_clicked(idx))

        # "Generate Seed" tab: seed spinbox + Generate button.
        img_seed_generate_widgets = QHBoxLayout()
        img_seed_generate_widgets.addWidget(QLabel('Seed:'), 1)
        img_seed_generate_widgets.addWidget(controls[0], 2)

        img_seed_generate_layout = QVBoxLayout()
        img_seed_generate_layout.addLayout(img_seed_generate_widgets)
        img_seed_generate_layout.addWidget(controls[1])

        img_seed_generate_frame = QFrame()
        img_seed_generate_frame.setLayout(img_seed_generate_layout)

        # "Load Projected Face" tab: file field + Select button + Load button.
        img_seed_load_select = QPushButton('Select')
        img_seed_load_select.clicked.connect(
            lambda _, idx=self.add_new_count: self.on_img_load_select_clicked(idx))

        img_seed_load_widgets = QHBoxLayout()
        img_seed_load_widgets.addWidget(QLabel('File:'), 1)
        img_seed_load_widgets.addWidget(controls[2], 2)
        img_seed_load_widgets.addWidget(img_seed_load_select, 1)

        img_seed_load_layout = QVBoxLayout()
        img_seed_load_layout.addLayout(img_seed_load_widgets)
        img_seed_load_layout.addWidget(controls[3])

        img_seed_load_frame = QFrame()
        img_seed_load_frame.setLayout(img_seed_load_layout)

        img_seed_tab_layout = QTabWidget()
        img_seed_tab_layout.addTab(img_seed_generate_frame, 'Generate Seed')
        img_seed_tab_layout.addTab(img_seed_load_frame, 'Load Projected Face')

        # Left column: tabs, then Save (controls[4]) and Delete (controls[5]).
        img_seed_layout = QVBoxLayout()
        img_seed_layout.addWidget(img_seed_tab_layout, 1)
        img_seed_layout.addWidget(controls[4], 1)
        img_seed_layout.addWidget(controls[5], 1)

        img_original_label = QLabel('Original')
        self.set_font(img_original_label, weight='Bold', align=Qt.AlignCenter, underline=True)
        img_original_label.setAlignment(Qt.AlignCenter)

        img_transformed_label = QLabel('Transformed')
        self.set_font(img_transformed_label, weight='Bold', align=Qt.AlignCenter, underline=True)
        img_transformed_label.setAlignment(Qt.AlignCenter)

        img_header_group = QHBoxLayout()
        img_header_group.addWidget(img_original_label)
        img_header_group.addWidget(img_transformed_label)

        # Right column: the two image QLabels (original, transformed).
        img_label_group = QHBoxLayout()
        img_label_group.addWidget(controls[-2])
        img_label_group.addWidget(controls[-1])

        imgs_group = QVBoxLayout()
        imgs_group.addLayout(img_header_group)
        imgs_group.addLayout(img_label_group)

        control_group = QHBoxLayout()
        control_group.addLayout(img_seed_layout, 1)
        control_group.addLayout(imgs_group, 2)

        # Horizontal separator drawn above each box.
        line = QFrame()
        line.setFrameShape(QFrame.HLine)
        line.setStyleSheet('border-bottom: 1px solid gray')

        transformation_box = QVBoxLayout()
        transformation_box.addWidget(line)
        transformation_box.addLayout(control_group)
        return transformation_box

    def build_attributes_layout(self):
        """Build a scrollable column with one control row per loaded attribute.

        For each attribute in self.attributes_data (in dict order), new controls are stored
        in self.attributes_controls[attribute]. The slider's and spinbox's valueChanged
        signals are routed to get_slider_value / get_scale_value together with the
        attribute's position. The resulting layout is added to self.control_group here
        AND returned.
        """
        self.attribute_group = QVBoxLayout()
        self.attribute_group.setSpacing(20)
        self.attribute_group.setAlignment(Qt.AlignTop)

        for position, name in enumerate(self.attributes_data):
            row_controls = self.build_attributes_controls()
            self.attributes_controls[name] = row_controls
            slider, spinbox, intensity_field = row_controls
            self.attribute_group.addLayout(
                self.build_attribute_box(name, slider, spinbox, intensity_field))

            # Default-arg binding freezes `position` per iteration.
            slider.valueChanged.connect(
                lambda val, idx=position: self.get_slider_value(val, idx))
            spinbox.valueChanged.connect(
                lambda val, idx=position: self.get_scale_value(val, idx))

        frame = QFrame()
        frame.setLayout(self.attribute_group)

        scrollarea = QScrollArea()
        scrollarea.setHorizontalScrollBarPolicy(Qt.ScrollBarAlwaysOff)
        scrollarea.setVerticalScrollBarPolicy(Qt.ScrollBarAsNeeded)
        scrollarea.setWidgetResizable(True)
        scrollarea.setWidget(frame)

        control_layout = QVBoxLayout()
        control_layout.setContentsMargins(0, 0, 0, 0)
        control_layout.addWidget(scrollarea)

        self.control_group.addLayout(control_layout)
        return control_layout

    def build_central_layout(self):
        """Build the middle area: 'Controls' group box (left) and 'Image' group box (right).

        Initialises self.attributes_controls, self.image_controls, self.image_data and
        self.transformation_data. Creates the attribute header row in self.control_group
        and the first transformation box in self.transformation_group. Stores the save
        directory field and the Select / 'Add New Seed' / Reset buttons on self; the last
        two start disabled.

        Returns:
            QHBoxLayout with both group boxes at equal stretch.
        """
        # init dictionaries (add_transformation_layout below writes into image_controls)
        self.attributes_controls = {}
        self.image_controls = {}
        self.image_data = {}

        attribute_label = QLabel('Attributes')
        self.set_font(attribute_label, weight='Bold', align=Qt.AlignCenter, underline=True)

        slider_label = QLabel('Direction')
        self.set_font(slider_label, weight='Bold', align=Qt.AlignCenter, underline=True)

        scale_label = QLabel('Scale')
        self.set_font(scale_label, weight='Bold', align=Qt.AlignCenter, underline=True)

        intensity_label = QLabel('Intensity')
        self.set_font(intensity_label, weight='Bold', align=Qt.AlignCenter, underline=True)

        # Header stretches 1:4:1:1 line up with the rows from build_attribute_box.
        header_box = QHBoxLayout()
        header_box.addWidget(attribute_label, 1)
        header_box.addWidget(slider_label, 4)
        header_box.addWidget(scale_label, 1)
        header_box.addWidget(intensity_label, 1)

        # The header is the first child layout of control_group; attribute rows are added later.
        self.control_group = QVBoxLayout()
        self.control_group.addLayout(header_box)

        control_groupbox = QGroupBox('Controls')
        control_groupbox.setLayout(self.control_group)

        self.transformation_data = {}
        default_transformation_layout = self.add_transformation_layout()

        # Boxes stack from the top; the trailing stretch absorbs leftover space.
        self.transformation_group = QVBoxLayout()
        self.transformation_group.setAlignment(Qt.AlignTop)
        self.transformation_group.addLayout(default_transformation_layout)
        self.transformation_group.addStretch(1)

        img_frame = QFrame()
        img_frame.setLayout(self.transformation_group)

        img_scrollarea = QScrollArea()
        img_scrollarea.setHorizontalScrollBarPolicy(Qt.ScrollBarAlwaysOff)
        img_scrollarea.setVerticalScrollBarPolicy(Qt.ScrollBarAsNeeded)
        img_scrollarea.setWidgetResizable(True)
        img_scrollarea.setWidget(img_frame)

        img_control_save_label = QLabel('Save Directory:')
        img_control_save_label.setAlignment(Qt.AlignCenter)

        self.img_control_save_lineedit = QLineEdit()
        self.img_control_save_lineedit.setReadOnly(True)

        # Add/Reset stay disabled until an image has been generated.
        self.img_control_select_btn = QPushButton('Select')
        self.img_control_add_btn = QPushButton('Add New Seed')
        self.img_control_add_btn.setDisabled(True)
        self.img_control_reset_btn = QPushButton('Reset')
        self.img_control_reset_btn.setDisabled(True)

        img_control = QHBoxLayout()
        img_control.addWidget(img_control_save_label, 1)
        img_control.addWidget(self.img_control_save_lineedit, 2)
        img_control.addWidget(self.img_control_select_btn, 1)
        img_control.addWidget(self.img_control_add_btn, 1)
        img_control.addWidget(self.img_control_reset_btn, 1)

        img_layout = QVBoxLayout()
        img_layout.setContentsMargins(0, 0, 0, 0)
        img_layout.addLayout(img_control)
        img_layout.addWidget(img_scrollarea)

        img_groupbox = QGroupBox('Image')
        img_groupbox.setLayout(img_layout)

        # --------------------------------------------------
        central_hbox_layout = QHBoxLayout()
        central_hbox_layout.addWidget(control_groupbox, 1)
        central_hbox_layout.addWidget(img_groupbox, 1)
        return central_hbox_layout

    def build_console_layout(self):
        """Return a 'Console' QGroupBox wrapping a read-only QTextEdit (self.console_textedit)."""
        self.console_textedit = QTextEdit()
        self.console_textedit.setReadOnly(True)

        inner_layout = QVBoxLayout()
        inner_layout.addWidget(self.console_textedit)

        box = QGroupBox('Console')
        box.setLayout(inner_layout)
        return box

    def connect_signals(self):
        """Connect the window-level buttons' clicked signals to their handlers."""
        button_handlers = (
            (self.network_file_button, self.on_network_file_button_clicked),
            (self.database_button, self.on_database_folder_button_clicked),
            (self.img_control_select_btn, self.on_img_control_select_button_clicked),
            (self.img_control_add_btn, self.on_img_control_add_button_clicked),
            (self.img_control_reset_btn, self.on_img_control_reset_button_clicked),
        )
        for button, handler in button_handlers:
            button.clicked.connect(handler)

    def set_font(self, text, weight='Normal', align=Qt.AlignLeft, underline=False):
        """Give widget ``text`` a fresh QFont with the named weight/underline and set its alignment.

        Args:
            text: widget providing setFont/setAlignment (e.g. QLabel).
            weight: one of the names in ``weights`` (Qt5 QFont::Weight values).
            align: Qt alignment flags.
            underline: underline the font when truthy.

        Raises:
            ValueError: if ``weight`` is not a known weight name.
        """
        weights = {'Thin':0, 'ExtraLight':12, 'Light':25, 'Normal':50, 'Medium':57,
                   'DemiBold':63, 'Bold':75, 'ExtraBold':81, 'Black':87}
        # Direct dict lookup (was a parallel-list .index search); keep ValueError for callers.
        if weight not in weights:
            raise ValueError('Unknown font weight %r; expected one of: %s'
                             % (weight, ', '.join(weights)))

        font = QFont()
        font.setUnderline(bool(underline))
        font.setWeight(weights[weight])

        text.setFont(font)
        text.setAlignment(align)

    def set_delay(self, time_in_msec):
        """Wait ``time_in_msec`` milliseconds without freezing the GUI.

        A local QEventLoop runs until a single-shot timer quits it, so pending events
        (repaints, clicks) keep being processed while waiting. The earlier version started
        a parentless timer and returned immediately, so it never delayed anything.
        NOTE(review): this re-enters the event loop; handlers may run during the wait.
        """
        loop = QEventLoop()
        QTimer.singleShot(time_in_msec, loop.quit)
        loop.exec_()

    def stop_time(self, timer):
        """Stop ``timer`` by calling its ``stop()`` method."""
        timer.stop()

    def squeeze_layout(self, layout):
        """Set ``layout``'s item spacing to 0 and append a stretch (factor 1) after its items,
        so the existing items pack together at the start of the layout."""
        layout.setSpacing(0)
        layout.addStretch(1)

    def image_np_to_pixmap(self, img):
        """Convert an ``(H, W, 3)`` RGB numpy image with values in 0..255 to a QPixmap.

        QImage.Format_RGB888 reads exactly three bytes per pixel in row-major order, so
        non-uint8 input (e.g. float) is clipped to 0..255 and cast to uint8, and the
        array is made C-contiguous before its bytes are handed over.
        """
        img = np.asarray(img)
        if img.dtype != np.uint8:
            img = np.clip(img, 0, 255).astype(np.uint8)
        img = np.ascontiguousarray(img)
        (height, width, _) = img.shape
        # Keep the byte buffer referenced until QPixmap.fromImage has copied it.
        buffer = img.tobytes()
        img_qt = QImage(buffer, width, height, 3*width, QImage.Format_RGB888)
        return QPixmap.fromImage(img_qt)

    def image_pixmap_to_icon(self, img):
        """Wrap pixmap ``img`` in a QIcon and return it."""
        return QIcon(img)

    def update_image_label(self, img, label):
        """Show numpy image ``img`` in QLabel ``label``, scaled to fit img_size x img_size
        with its aspect ratio kept."""
        label.setPixmap(
            self.image_np_to_pixmap(img).scaled(self.img_size, self.img_size, Qt.KeepAspectRatio))

    def update_default_settings(self, item_to_update, data):
        """Persist ``data`` under key ``item_to_update`` in ``default_settings.npy`` (CWD).

        The file is re-read on every call and written back whole. When it does not exist
        yet, a fresh settings array holding one empty dict is created instead of failing.
        """
        default_settings_path = Path.cwd() / 'default_settings.npy'
        if default_settings_path.exists():
            # NOTE(review): allow_pickle=True can run code from a crafted file; only use trusted settings.
            self.default_settings = np.load(default_settings_path, allow_pickle=True)
        else:
            self.default_settings = np.array([{}], dtype=object)
        self.default_settings[0][item_to_update] = data
        np.save(default_settings_path, self.default_settings)

    @pyqtSlot(str)
    def on_stream_message(self, message):
        """Append ``message`` verbatim at the end of the console text box."""
        console = self.console_textedit
        console.moveCursor(QTextCursor.End)
        console.insertPlainText(message)

    @pyqtSlot()
    def on_network_file_button_clicked(self):
        """Ask for an existing network pickle and pass it to check_network_file_type."""
        file_dialog = QFileDialog()
        # ExistingFile rather than AnyFile: the chosen path is opened right away, so a
        # typed-in, non-existent file name must be rejected by the dialog itself.
        file_dialog.setFileMode(QFileDialog.ExistingFile)
        file_types = 'Pickle File (*.pkl)'
        file_dialog.setNameFilter(file_types)

        if file_dialog.exec() == QDialog.Accepted:
            self.check_network_file_type(file_dialog.selectedFiles()[0])

    def check_network_file_type(self, path):
        """Validate ``path`` as a ``.pkl`` network file and (re)initialise StyleGAN from it.

        For a ``.pkl`` path:
        - show it in the network field;
        - un-disable the transformation-box controls (row 0 without its Save/Delete);
        - store it as the default network and build a new StyleGAN_init;
        - un-disable row 0's Save button when a save directory is set;
        - un-disable all attribute controls.
        Otherwise, disable the same transformation-box controls and print a warning.
        """
        rows = list(self.image_controls.values())
        # Row 0's Save/Delete buttons (indices 4 and 5) are excluded from the bulk toggling;
        # Save is handled separately once the model is loaded.
        controls_image = [rows[0][:4] + rows[0][6:]] + rows[1:]
        print('Verifying selected network file...', end=' ')
        self.selected_file = path
        # Real extension, case-insensitive and independent of the path separator.
        selected_file_extension = os.path.splitext(self.selected_file)[1].lower()
        if selected_file_extension == '.pkl':
            self.network_lineedit.setText(self.selected_file)
            for controls in controls_image:
                self.set_widget_disabled(controls, False)
            self.update_default_settings('default_network_path', self.selected_file)
            print('Valid.')
            print('Selected File: %s'%(self.selected_file))
            print('Initialising StyleGAN model.... %s'%(self.selected_file), end=' ')
            # Always rebuild so a newly selected network replaces any loaded one.
            self.StyleGAN_init_state = False
            self.StyleGAN_helper = StyleGAN_init(self.network_lineedit.text())
            self.StyleGAN_init_state = True
            if self.img_control_save_lineedit.text():
                # Use the unsliced row: in controls_image[0] index 4 is the original-image
                # label, not the Save button, so Save was never enabled here before.
                self.set_widget_disabled([rows[0][4]], False)
            for controls in list(self.attributes_controls.values()):
                self.set_widget_disabled(controls, False)
            print('Done.')
        else:
            for controls in controls_image:
                self.set_widget_disabled(controls, True)
            print('Invalid.')
            print('Warning! Wrong file type selected. \'*.pkl\' only.')

    def on_database_folder_button_clicked(self):
        """Ask for the attribute-direction directory, rebuild the attribute controls from it,
        and store it as the default database path."""
        file_dialog = QFileDialog()
        file_dialog.setFileMode(QFileDialog.DirectoryOnly)

        if file_dialog.exec() == QDialog.Accepted:
            selected_folder = file_dialog.selectedFiles()[0]
            self.database_lineedit.setText(selected_folder)
            self.load_attributes_data(selected_folder)
            # control_group's first child layout is the header row; another child means
            # attribute controls were built before and must be torn down first.
            if len(self.control_group.children()) > 1:
                controls_layout = self.control_group.children()[-1]
                self.delete_layout(controls_layout)
                self.attributes_controls.clear()
            # build_attributes_layout() adds its layout to self.control_group itself;
            # adding the returned layout again only triggers a Qt "already has a parent" warning.
            self.build_attributes_layout()

            self.update_default_settings('default_database_path', selected_folder)

    @pyqtSlot()
    def on_img_control_select_button_clicked(self):
        """Ask for the save directory, persist it as the default, and un-disable every row's Save button."""
        file_dialog = QFileDialog()
        file_dialog.setFileMode(QFileDialog.DirectoryOnly)
        if file_dialog.exec() != QDialog.Accepted:
            return

        selected_folder = file_dialog.selectedFiles()[0]
        self.img_control_save_lineedit.setText(selected_folder)
        self.update_default_settings('default_storage_path', selected_folder)

        for row_controls in self.image_controls.values():
            row_controls[4].setDisabled(False)  # index 4 is the row's Save button

    @pyqtSlot()
    def on_img_control_add_button_clicked(self):
        """Append a new transformation box below the existing ones.

        Does nothing when another box beyond the first is still waiting for an image
        (allow_add_new is False) or when the row limit is reached. The new box's
        controls, except its two image labels, are un-disabled, and 'Add New Seed' is
        disabled again.
        """
        waiting_for_previous = self.transform_count > 0 and not self.allow_add_new
        at_row_limit = self.transform_count == self.max_transform_count - 1
        if waiting_for_previous or at_row_limit:
            return

        self.allow_add_new = False
        self.transform_count += 1
        self.add_new_count += 1
        self.add_count_row[self.transform_count] = self.add_new_count
        self.transformation_group.addLayout(self.add_transformation_layout())
        self.image_data[self.transform_count] = [0, 0, 0]

        new_row_controls = list(self.image_controls.values())[self.transform_count]
        for widget in new_row_controls[:-2]:
            widget.setDisabled(False)
        self.img_control_add_btn.setDisabled(True)

    @pyqtSlot()
    def on_img_control_reset_button_clicked(self):
        """After user confirmation, reset attribute controls and remove extra transformation boxes.

        - Sliders go back to 0 and scales to 1, both in the widgets and in self.attributes_data.
        - Every transformation box except one is deleted; the remaining one gets the
          placeholder pixmaps again.
        - Latents in the remaining self.image_data entry are zeroed.
        - The add/row bookkeeping is reset.
        """
        message_box = QMessageBox()
        message_box.setWindowTitle(' ')
        message_box.setIcon(QMessageBox.Warning)
        message_box.setText('All data will be lost!')
        message_box.setStandardButtons(QMessageBox.Yes | QMessageBox.Cancel)

        if message_box.exec() != QMessageBox.Yes:
            return

        print('Resetting...', end=' ')
        for controls in self.attributes_controls.values():
            controls[0].setValue(0)
            controls[1].setValue(1)
            controls[2].setText('0.000')

        for data in self.attributes_data.values():
            data[1] = 0
            # Scale must match the spinbox value set above (and the initial value used by
            # load_attributes_data); resetting it to 0 left data and UI out of sync.
            data[2] = 1

        for controls in reversed(list(self.image_controls.keys())):
            if len(self.image_controls) != 1:
                row_idx_to_count = self.add_count_row[len(self.image_controls)-1]
                self.delete_image_control_box(row_idx_to_count)
            else:
                self.image_controls[controls][-2].setPixmap(self.placeholder_img)
                self.image_controls[controls][-1].setPixmap(self.placeholder_img)

        for data in reversed(list(self.image_data.keys())):
            if len(self.image_data) != 1:
                self.image_data.pop(data)
            else:
                self.image_data[data][1] = 0
                self.image_data[data][2] = 0

        self.add_new_count = 0
        self.add_count_row = dict(zip(range(5),[0]*5))
        self.allow_add_new = True
        print('Done.')

    @pyqtSlot()
    def on_img_seed_button_clicked(self, count):
        """Generate the image for the seed chosen in a transformation box and show it.

        Args:
            count: ``add_new_count`` id of the box; mapped to its row position through
                self.add_count_row.

        On first use this builds StyleGAN_init from the network path field. It then
        un-disables the add/reset buttons, row 0's Save button (when a save directory is
        set) and all attribute controls. The latent is stored as
        ``self.image_data[row] = [seed, z, z]``: original and transformed start equal.
        The image is shown in both labels and ``apply_direction()`` is called.
        NOTE(review): @pyqtSlot() declares no arguments, but this is invoked directly from a
        lambda with ``count``, so the decorator does not affect the call.
        """
        if not self.StyleGAN_init_state:
            self.StyleGAN_helper = StyleGAN_init(self.network_lineedit.text())
            self.StyleGAN_init_state = True

            self.img_control_add_btn.setDisabled(False)
            self.img_control_reset_btn.setDisabled(False)
            controls_image = list(self.image_controls.values())
            if self.img_control_save_lineedit.text():
                # index 4 of a row's control list is its Save button
                self.set_widget_disabled([controls_image[0][4]], False)
            controls_attributes = list(self.attributes_controls.values())
            for controls in controls_attributes:
                self.set_widget_disabled(controls, False)

        row_idx = list(self.add_count_row.values()).index(count)
        # Only the newest (bottom) box unlocks adding another one.
        if row_idx == self.transform_count:
            self.allow_add_new = True
            self.img_control_add_btn.setDisabled(False)
            self.img_control_reset_btn.setDisabled(False)
            self.image_controls[row_idx][4].setDisabled(False)

        controls = list(self.image_controls.values())[row_idx]
        selected_seed = controls[0].value()
        self.stylegan = StyleGAN_img(self.StyleGAN_helper.G,
                                           selected_seed,
                                           self.StyleGAN_helper.z_shape,
                                           self.StyleGAN_helper.zero_label,
                                           self.StyleGAN_helper.device)
        z_original = self.stylegan.generate_latent_vector()
        original_image = self.stylegan.generate_image(z_original)
        # [seed, original z, transformed z]
        self.image_data[row_idx] = [selected_seed, z_original, z_original]

        self.update_image_label(original_image, controls[-2])
        self.update_image_label(original_image, controls[-1])
        self.apply_direction()

    @pyqtSlot()
    def on_img_load_select_clicked(self, count):
        """Let the user pick an existing projector ``.npz`` file and show its path in the box.

        Args:
            count: ``add_new_count`` id of the transformation box.
        """
        file_dialog = QFileDialog()
        # ExistingFile: the selected file is np.load-ed later by the Load button.
        file_dialog.setFileMode(QFileDialog.ExistingFile)
        file_types = 'Numpy File (*.npz)'
        file_dialog.setNameFilter(file_types)

        if file_dialog.exec() == QDialog.Accepted:
            selected_file = file_dialog.selectedFiles()[0]
            row_idx = list(self.add_count_row.values()).index(count)
            seed_load_lineedit = self.image_controls[row_idx][2]
            seed_load_lineedit.setText(str(selected_file))

    @pyqtSlot()
    def on_img_load_button_clicked(self, count):
        """Render the projected face from the box's selected ``.npz`` file.

        Args:
            count: ``add_new_count`` id of the transformation box (mapped to its row via
                self.add_count_row).

        The file's ``'w'`` array is rendered with ``StyleGAN_img.generate_projected`` and
        shown in both image labels. self.image_data[row] becomes ``[-1, ws, ws]``; -1
        marks a projected face, which the save handler checks for. ``apply_direction()``
        is called afterwards. Does nothing when no file has been selected.
        """
        row_idx = list(self.add_count_row.values()).index(count)
        controls = list(self.image_controls.values())[row_idx]

        project_ws_path = self.image_controls[row_idx][2].text()
        # Check first: without a file this box has no image, so "Add New Seed" must stay locked.
        if not project_ws_path:
            return

        if row_idx == self.transform_count:
            self.allow_add_new = True
            self.img_control_add_btn.setDisabled(False)
            self.img_control_reset_btn.setDisabled(False)

        # np.load on an .npz returns an open archive; close it once 'w' has been read.
        with np.load(project_ws_path) as projected_npz:
            projected_ws = projected_npz['w']
        self.stylegan = StyleGAN_img(self.StyleGAN_helper.G,
                                     0,
                                     self.StyleGAN_helper.z_shape,
                                     self.StyleGAN_helper.zero_label,
                                     self.StyleGAN_helper.device)
        projected_face = self.stylegan.generate_projected(projected_ws)
        self.image_data[row_idx] = [-1, projected_ws, projected_ws]

        self.update_image_label(projected_face, controls[-2])
        self.update_image_label(projected_face, controls[-1])
        self.apply_direction()

    @pyqtSlot()
    def on_img_save_button_clicked(self, count):
        """Save the row identified by ``count`` into the folder from the save line edit.

        Two files are written that share one base name:
          * ``<name>.npy``: an object array ``[row image_data, attributes_data]``;
          * ``<name>.png``: a side-by-side figure of the original and the
            transformed image.

        The base name is ``seed_<seed>`` for seeded rows. For projected rows
        (seed ``-1``) it is the file stem of the loaded ``.npz`` path. If
        ``<name>.npy`` already exists, ``_1``, ``_2``, ... is appended until
        the name is free, so earlier saves are never overwritten.

        The button shows 'Saving...' while working. It is reset to 'Save' one
        second later, even if saving fails.

        Args:
            count: the row identifier; mapped to a positional row index via
                ``add_count_row``.
        """
        row_idx = list(self.add_count_row.values()).index(count)
        row_controls = list(self.image_controls.values())[row_idx]
        save_btn = row_controls[4]
        save_btn.setText('Saving...')
        # Let Qt repaint the button text before the blocking rendering work below.
        QCoreApplication.processEvents()

        try:
            folder_path = self.img_control_save_lineedit.text()
            if not folder_path:
                print('No save folder selected.')
                return

            row_data = list(self.image_data.values())[row_idx]
            img_seed, original_z, transformed_z = row_data[0], row_data[1], row_data[2]
            is_projected = img_seed == -1

            if is_projected:
                projected_path = row_controls[2].text()
                save_name = os.path.splitext(os.path.basename(projected_path))[0]
                print('Saving Projected: %s...'%(save_name), end='')
            else:
                save_name = 'seed_%d'%(img_seed)
                print('Saving Seed %d...'%(img_seed), end='')

            # Pick the first free name: base, base_1, base_2, ... This is an
            # exact-name check, so an existing 'seed_10.npy' does not affect
            # 'seed_1', and the result does not depend on directory order.
            base_name = save_name
            suffix = 1
            while os.path.exists(os.path.join(folder_path, save_name + '.npy')):
                save_name = '%s_%d'%(base_name, suffix)
                suffix += 1

            img_data = np.array(row_data, dtype=object)
            attributes_data = np.array(self.attributes_data, dtype=object)
            save_data = np.array([img_data, attributes_data], dtype=object)
            np.save(os.path.join(folder_path, save_name + '.npy'), save_data)

            # Projected rows hold W latents; seeded rows hold Z latents.
            render = (self.stylegan.generate_projected if is_projected
                      else self.stylegan.generate_image)
            save_img_original = render(original_z)
            save_img_transformed = render(transformed_z)

            fig, ax = plt.subplots(nrows=1, ncols=2)
            try:
                panels = ((ax[0], save_img_original, 'Original'),
                          (ax[1], save_img_transformed, 'Transformed'))
                for axis, image, title in panels:
                    axis.imshow(image)
                    axis.set_title(title)
                    axis.axis('off')
                save_path = os.path.join(folder_path, save_name + '.png')
                fig.savefig(save_path, format='png', dpi=300)
            finally:
                plt.close(fig)

            print('Saved to: %s...'%(save_path))
        finally:
            QTimer.singleShot(1000, lambda: save_btn.setText('Save'))

    @pyqtSlot()
    def on_img_delete_button_clicked(self, idx):
        """Slot for a row's delete button: remove the row whose identifier is ``idx``.

        ``idx`` is a row identifier as stored in ``add_count_row``, not a
        positional index; ``delete_image_control_box`` resolves it.
        """
        row_identifier = idx
        self.delete_image_control_box(row_identifier)

    def delete_image_control_box(self, count):
        """Remove one image-control row and compact the per-row bookkeeping.

        ``count`` is the row's identifier as stored in ``add_count_row``; its
        position among the dict's values is the row index.

        Steps:
          * the identifiers after that position are shifted one slot left;
          * the row's layout is emptied and removed from ``transformation_group``;
          * ``image_data`` and ``image_controls`` drop that row and are re-keyed
            to ``0..n-1``, where ``n`` is the remaining number of child layouts;
          * ``transform_count`` is decremented.

        NOTE(review): the last ``add_count_row`` slot is not cleared by the
        shift, so it keeps a duplicate of its previous value.

        Raises:
            ValueError: if ``count`` is not present in ``add_count_row``.
        """
        counts_snapshot = list(self.add_count_row.values())
        position = counts_snapshot.index(count)

        # Shift identifiers left, starting at the deleted slot.
        for slot in range(position, self.max_transform_count - 1):
            self.add_count_row[slot] = counts_snapshot[slot + 1]

        row_layout = self.transformation_group.children()[position]
        self.delete_layout(row_layout)
        self.transformation_group.removeItem(row_layout)

        fresh_keys = list(range(len(self.transformation_group.children())))

        self.image_data.pop(position)
        self.image_controls.pop(position)

        # Re-key both dicts so positional and key lookups agree again.
        self.image_data = dict(zip(fresh_keys, list(self.image_data.values())))
        self.image_controls = dict(zip(fresh_keys, list(self.image_controls.values())))

        self.transform_count -= 1

    def delete_layout(self, layout):
        """Recursively empty ``layout``.

        Every item is taken out of the layout. Widgets are scheduled for
        deletion with ``deleteLater()``, and nested layouts are emptied the
        same way. Items with neither a widget nor a layout (e.g. spacers) pass
        ``None`` to the recursive call, which ignores it. A ``None`` layout is
        a no-op.
        """
        if layout is None:
            return
        while layout.count():
            child = layout.takeAt(0)
            child_widget = child.widget()
            if child_widget is None:
                self.delete_layout(child.layout())
                continue
            child_widget.deleteLater()

    def set_widget_disabled(self, widgets, state):
        """Call ``setDisabled(state)`` on each object in ``widgets``, in iteration order."""
        for target_widget in widgets:
            target_widget.setDisabled(state)

    def calculate_intensity(self, slider, scale):
        """Return the attribute intensity ``slider * scale``, rounded to 3 decimals."""
        return round(slider * scale, 3)

    def apply_direction(self):
        """Re-render every active row with all attribute directions applied.

        For each row the transformed latent is
        ``original + sum(direction * calculate_intensity(slider, scale))``,
        summed over the entries of ``attributes_data``. Each entry is indexed
        as ``[direction, slider, scale, ...]``.

        The result is stored as the third element of ``image_data[idx]``
        (seed and original latent are kept) and drawn into the row's last
        control. Rows with seed ``-1`` are rendered with ``generate_projected``;
        all others use ``generate_image``.

        Iteration stops at row ``transform_count`` unless ``allow_add_new`` is set.
        """
        for idx, controls in enumerate(list(self.image_controls.values())):
            reached_pending_row = not self.allow_add_new and idx == self.transform_count
            if reached_pending_row:
                break

            row = list(self.image_data.values())[idx]
            original_z = row[1]

            # Rebinding (not +=) so the stored original latent is never mutated.
            z = original_z
            for entry in self.attributes_data.values():
                z = z + entry[0] * self.calculate_intensity(entry[1], entry[2])

            seed = row[0]
            self.image_data[idx] = [seed, original_z, z]
            render = (self.stylegan.generate_projected if seed == -1
                      else self.stylegan.generate_image)
            self.update_image_label(render(z), controls[-1])

    def get_slider_value(self, value, idx):
        """Handle a slider change for the ``idx``-th attribute.

        Steps:
          * read that attribute's scale from ``attributes_controls[idx][1]``;
          * show the resulting intensity in ``attributes_controls[idx][2]``;
          * record the new slider/scale pair in ``attributes_data``;
          * re-render all rows.
        """
        row_widgets = list(self.attributes_controls.values())[idx]
        scale = row_widgets[1].value()
        intensity_text = str(self.calculate_intensity(value, scale))
        row_widgets[2].setText(intensity_text)
        self.update_attributes_data(idx, value, scale)
        self.apply_direction()

    def get_scale_value(self, value, idx):
        """Handle a scale change for the ``idx``-th attribute.

        Steps:
          * read that attribute's slider position from ``attributes_controls[idx][0]``;
          * show the resulting intensity in ``attributes_controls[idx][2]``;
          * record the new slider/scale pair in ``attributes_data``;
          * re-render all rows.
        """
        row_widgets = list(self.attributes_controls.values())[idx]
        slider = row_widgets[0].value()
        intensity_text = str(self.calculate_intensity(slider, value))
        row_widgets[2].setText(intensity_text)
        self.update_attributes_data(idx, slider, value)
        self.apply_direction()

    def update_attributes_data(self, idx, slider, scale):
        """Store ``slider`` and ``scale`` at positions 1 and 2 of the ``idx``-th
        entry of ``attributes_data``. The entry list is mutated in place."""
        entry = list(self.attributes_data.values())[idx]
        entry[1] = slider
        entry[2] = scale

In [9]:
class stdout_override(QObject):
    """Minimal file-like object that re-emits everything written to it as a Qt signal.

    ``main`` installs an instance as ``sys.stdout`` and connects ``message``
    to the window, so ``print`` output ends up wherever that signal is connected.
    """
    message = pyqtSignal(str)

    def __init__(self):
        super().__init__()

    def write(self, message):
        """Emit ``str(message)`` on the ``message`` signal."""
        text = str(message)
        self.message.emit(text)

    def flush(self):
        """No-op; present only so the object satisfies the file-like interface."""
        return None

In [10]:
def main():
    """Launch the editor GUI and block until the window is closed.

    Steps:
      * apply ``stylesheet.css`` (resolved against the current working directory);
      * show ``MainWindow``;
      * redirect ``sys.stdout`` into the window through ``stdout_override``
        while the Qt event loop runs;
      * restore the previous ``sys.stdout`` afterwards.

    Returns:
        The exit code returned by the Qt event loop.
    """
    # Qt allows a single QApplication per process; reuse it when the notebook
    # cell is re-run instead of failing on a second construction.
    app = QApplication.instance() or QApplication(sys.argv)

    stylesheet = "stylesheet.css"
    with open(stylesheet, "r") as f:
        app.setStyleSheet(f.read())

    window = MainWindow()
    window.show()

    stream = stdout_override()
    stream.message.connect(window.on_stream_message)
    original_stdout = sys.stdout
    sys.stdout = stream
    try:
        welcome_message()
        return app.exec()
    finally:
        # Without this, every later print (e.g. in the notebook) would keep
        # going to the closed window instead of the original stream.
        sys.stdout = original_stdout

def welcome_message():
    """Print the first instruction shown to the user when the app starts."""
    greeting = 'Select a network file (*.pkl) to start.'
    print(greeting)

# Entry point. In a Jupyter/IPython kernel __name__ is also '__main__', so
# executing this cell starts the GUI. main() does not return until the Qt
# event loop finishes.
if __name__ == '__main__':
    main()