In [1]:
import customtkinter as tk
from tkinter import *
from tkinter import filedialog
from PIL import ImageTk, Image
import tensorflow as tf
from tensorflow.keras.preprocessing.image import array_to_img, img_to_array, load_img
from tensorflow.keras import layers, losses
from skimage.transform import resize
import numpy as np
from skimage.color import rgb2lab, lab2rgb, rgb2gray, gray2rgb
from skimage.io import imsave

In [2]:
# Defining the loss function
def ssim_loss(y_true, y_pred):
    """Structural-dissimilarity loss: one minus the batch-mean SSIM.

    ``tf.image.ssim`` is evaluated per image with ``max_val=1.0`` as the
    dynamic range, the per-image scores are averaged, and the mean is
    subtracted from 1 so identical inputs give a loss of 0.
    NOTE(review): assumes the compared tensors span a range of 1.0 -- confirm
    against how the model was trained.
    """
    per_image_ssim = tf.image.ssim(y_true, y_pred, max_val=1.0)
    mean_ssim = tf.reduce_mean(per_image_ssim)
    return 1 - mean_ssim

In [3]:
# Define a custom ImageLabel widget that inherits from Label
class ImageLabel(Label):
    """A plain tkinter Label that displays an image and keeps it referenced.

    A Tk image vanishes from the widget once its Python ``PhotoImage`` object
    is garbage-collected, so the currently shown image is stored on
    ``self.image``.
    """

    def __init__(self, master=None, **kwargs):
        super().__init__(master, **kwargs)
        self.image = None  # PhotoImage currently displayed (kept alive here)

    def set_image(self, image_path):
        """Display an image on this label.

        Args:
            image_path: either a filesystem path (``str`` / ``os.PathLike``)
                that is opened with PIL, a PIL ``Image.Image``, or an image
                object Tk can already display (e.g. ``ImageTk.PhotoImage``).
                Non-path values are accepted so callers that have already
                resized/converted the picture can pass it directly.
        """
        import os

        if isinstance(image_path, (str, os.PathLike)):
            # Context manager closes the underlying file once the pixels
            # have been copied into the PhotoImage.
            with Image.open(image_path) as pil_image:
                photo = ImageTk.PhotoImage(pil_image)
        elif isinstance(image_path, Image.Image):
            photo = ImageTk.PhotoImage(image_path)
        else:
            # Already a Tk-compatible image (e.g. ImageTk.PhotoImage).
            photo = image_path
        self.image = photo
        self.configure(image=self.image)

In [None]:
# Command Functions 
def selectPic():
    """Ask the user for an image file, preview it, and remember its path.

    Sets the module globals ``filename`` (the chosen path, read by
    ``colorize``) and ``img`` (the 256x256 preview PhotoImage). Does nothing
    if the file dialog is cancelled.
    """
    global img
    global filename
    chosen = filedialog.askopenfilename(initialdir='/images', title='Select Image', filetypes=(('png images', '*.png'), ('jpg images', '*.jpg')))
    # An empty result means the dialog was cancelled; keep the previous state.
    if not chosen:
        return
    filename = chosen
    # Image.ANTIALIAS was removed in Pillow 10; LANCZOS is the same filter.
    with Image.open(filename) as source:
        preview = source.resize((256, 256), Image.LANCZOS)
    img = ImageTk.PhotoImage(preview)
    # Configure the label with the ready-made PhotoImage and keep a reference
    # on the widget so the image is not garbage-collected.
    lbl_show_pic.configure(image=img)
    lbl_show_pic.image = img
    # Replace (rather than prepend to) any previously displayed path.
    entry_pic_path.delete(0, END)
    entry_pic_path.insert(0, filename)

# Cache for the colorization model so it is read from disk once rather than
# on every press of the "Colorize" button.
_colorizer_model = None


def _get_colorizer_model(model_path='my_model.model'):
    """Load the Keras colorization model on first use and return the cached copy."""
    global _colorizer_model
    if _colorizer_model is None:
        # ssim_loss must be registered because the model was saved with it.
        _colorizer_model = tf.keras.models.load_model(
            model_path, custom_objects={'ssim_loss': ssim_loss})
    return _colorizer_model


def colorize(output_path='out.png'):
    """Colorize the image chosen via ``selectPic`` and show the result.

    The image is resized to 160x160, its CIELAB lightness channel is fed to
    the model, the predicted a/b channels are combined with it, and the RGB
    result is saved to ``output_path`` and displayed (upscaled to 256x256)
    in ``lbl_show_result``. Sets the global ``result_image``.

    Args:
        output_path: where the 160x160 colorized image is written.
    """
    from tkinter import messagebox

    global result_image
    # `filename` is set by selectPic(); guard against Colorize being pressed first.
    path = globals().get('filename')
    if not path:
        messagebox.showwarning('Colorize', 'Please select an image first.')
        return

    model = _get_colorizer_model()

    rgb = img_to_array(load_img(path))  # RGB array, values 0..255
    rgb = resize(rgb, (160, 160))
    batch = np.array([rgb], dtype=float)
    # Keep only the L (lightness) channel, adding a trailing channel axis.
    lightness = rgb2lab(1.0 / 255 * batch)[:, :, :, 0]
    lightness = lightness.reshape(lightness.shape + (1,))

    # NOTE(review): *128 assumes the model outputs a/b scaled to [-1, 1].
    ab = model.predict(lightness) * 128

    lab = np.zeros((160, 160, 3))
    lab[:, :, 0] = lightness[0][:, :, 0]
    lab[:, :, 1:] = ab[0]
    # lab2rgb returns floats in [0, 1]; convert once to uint8 for saving/display.
    rgb_out = (lab2rgb(lab) * 255).astype(np.uint8)

    imsave(output_path, rgb_out)
    result = Image.fromarray(rgb_out).resize((256, 256), resample=Image.LANCZOS)
    result_image = ImageTk.PhotoImage(result)
    # Configure the label directly and keep a reference so Tk keeps the image.
    lbl_show_result.configure(image=result_image)
    lbl_show_result.image = result_image


In [None]:
tk.set_appearance_mode('dark')
tk.set_default_color_theme('dark-blue')

root = tk.CTk()
frame = tk.CTkFrame(root)

# GUI Objects
# Plain tkinter labels that show the selected picture and the colorized result.
lbl_show_pic = ImageLabel(frame, width=10, height=10)
lbl_show_result = ImageLabel(frame, width=10, height=10)
lbl_pic_path = tk.CTkLabel(frame, text='Path: ', padx=25, pady=25, font=('verdana', 16))
# CTkEntry's width is in pixels (unlike tkinter.Entry, where it counts
# characters), so it must be wide enough to show a file path.
entry_pic_path = tk.CTkEntry(frame, font=('verdana', 16), width=400)
btn_browse = tk.CTkButton(frame, text='Select Image', font=('verdana', 16), command=selectPic)
btn_colorize = tk.CTkButton(frame, text='Colorize', font=('verdana', 16), command=colorize)
frame.pack()

# Layout: path row, button row, then input/result images side by side.
lbl_pic_path.grid(row=0, column=0, padx=10, pady=10)
entry_pic_path.grid(row=0, column=1, padx=10, pady=10)
btn_browse.grid(row=1, column=0, padx=10, pady=10)
btn_colorize.grid(row=1, column=1, padx=10, pady=10)
lbl_show_pic.grid(row=2, column=0)
lbl_show_result.grid(row=2, column=1)

# Blocks until the window is closed.
root.mainloop()