# In [5]:
import os
import cv2
import dash
from dash import dcc
from dash import html
from dash.dependencies import Input, Output, State
import base64
import io
from PIL import Image
import dash_bootstrap_components as dbc
import numpy as np
import pandas as pd
import tensorflow as tf
from tensorflow.keras.models import load_model # type: ignore
from scipy import ndimage as nd
from skimage.segmentation import slic
from skimage import img_as_float
import torchvision.transforms as transforms
import torch
import torch
from torch import nn

# Getting the models
# GANS
# Keras loss objects shared by the GAN loss functions below.
cross_entropy = tf.keras.losses.BinaryCrossentropy()
mse = tf.keras.losses.MeanSquaredError()


def discriminator_loss(real_output, fake_output):
    """Binary cross-entropy loss of the discriminator with noisy labels.

    Real predictions are scored against targets drawn from (0.9, 1] and fake
    predictions against targets drawn from [0, 0.1); the two losses are summed.

    Args:
        real_output: Discriminator predictions for real images.
        fake_output: Discriminator predictions for generated images.

    Returns:
        The loss tensor ``real_loss + fake_loss``.
    """
    # tf.shape() is the dynamic shape, so this also works when the batch
    # dimension is unknown (symbolic tensors); the static ``.shape`` would
    # contain None there, which tf.random.uniform cannot use.
    real_noise = tf.random.uniform(shape=tf.shape(real_output), maxval=0.1)
    fake_noise = tf.random.uniform(shape=tf.shape(fake_output), maxval=0.1)
    real_loss = cross_entropy(tf.ones_like(real_output) - real_noise, real_output)
    fake_loss = cross_entropy(tf.zeros_like(fake_output) + fake_noise, fake_output)
    return real_loss + fake_loss
def generator_loss(fake_output, real_y):
    """Mean-squared error between the generator output and the target image.

    ``real_y`` is cast to float32 before it is compared with ``fake_output``.
    """
    return mse(fake_output, tf.cast(real_y, 'float32'))
# Optimizers are passed as custom objects so the saved GAN models deserialize.
generator_optimizer = tf.keras.optimizers.Adam(0.0005)
discriminator_optimizer = tf.keras.optimizers.Adam(0.0005)
# Paths are built with os.path.join so they resolve on every OS; a literal
# backslash is only a path separator on Windows.
GANS_generator = load_model(
    os.path.join('GANS', 'generator.keras'),
    custom_objects={'generator_loss': generator_loss, 'generator_optimizer': generator_optimizer},
)
GANS_discriminator = load_model(
    os.path.join('GANS', 'discriminator.keras'),
    custom_objects={'discriminator_loss': discriminator_loss, 'discriminator_optimizer': discriminator_optimizer},
)

# OpenCV
# Custom Keras model; the colorization callback feeds it the L channel and
# uses its output as the ab channels.
openCVModel = load_model(os.path.join("OpenCV", "colorize_opencv.keras"))

# OpenCV Improved
# Caffe colorization network (deploy prototxt + trained weights) run through
# cv2.dnn.
DIR = "OpenCV"
PROTOTXT = os.path.join(DIR, "model", "colorization_deploy_v2.prototxt")
POINTS = os.path.join(DIR, "model", "pts_in_hull.npy")
MODEL = os.path.join(DIR, "model", "colorization_release_v2.caffemodel")
net = cv2.dnn.readNetFromCaffe(PROTOTXT, MODEL)
pts = np.load(POINTS)
class8 = net.getLayerId("class8_ab")
conv8 = net.getLayerId("conv8_313_rh")
# The 313 points (presumably the quantised ab bins) become the weights of the
# class8_ab layer, reshaped to (2, 313, 1, 1).
pts = pts.transpose().reshape(2, 313, 1, 1)
net.getLayer(class8).blobs = [pts.astype("float32")]
# conv8_313_rh gets a constant 2.606 weight; presumably the value from the
# reference colorization demo -- confirm before changing.
net.getLayer(conv8).blobs = [np.full([1, 313], 2.606, dtype="float32")]

# Autoencoders
class ColorAutoEncoder(nn.Module):
    """U-Net-style convolutional autoencoder from grayscale to 3 channels.

    Input is an ``(N, 1, H, W)`` tensor; output is ``(N, 3, H, W)`` with values
    in (0, 1) from the final sigmoid. Four stride-2 convolutions halve H and W
    and four transposed convolutions double them back; every decoder stage is
    concatenated with the matching encoder feature map, so H and W should be
    multiples of 16 for those shapes to line up.
    """

    def __init__(self):
        # Must be the dunder ``__init__``: Python only calls that name when the
        # module is constructed, so the layers below would otherwise never exist.
        super().__init__()
        # Encoder: 1 -> 64 -> 128 -> 256 -> 512 channels, halving H and W each step.
        self.down1 = nn.Conv2d(1, 64, 3, stride=2, padding=1)
        self.down2 = nn.Conv2d(64, 128, 3, stride=2, padding=1)
        self.down3 = nn.Conv2d(128, 256, 3, stride=2, padding=1)
        self.down4 = nn.Conv2d(256, 512, 3, stride=2, padding=1)
        # Decoder: up2..up4 take twice the channels because of the skip concatenation.
        self.up1 = nn.ConvTranspose2d(512, 256, 3, stride=2, padding=1, output_padding=1)
        self.up2 = nn.ConvTranspose2d(512, 128, 3, stride=2, padding=1, output_padding=1)
        self.up3 = nn.ConvTranspose2d(256, 64, 3, stride=2, padding=1, output_padding=1)
        self.up4 = nn.ConvTranspose2d(128, 3, 3, stride=2, padding=1, output_padding=1)
        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Run the encoder/decoder; returns a 3-channel tensor in (0, 1)."""
        d1 = self.relu(self.down1(x))
        d2 = self.relu(self.down2(d1))
        d3 = self.relu(self.down3(d2))
        d4 = self.relu(self.down4(d3))
        u1 = self.relu(self.up1(d4))
        u2 = self.relu(self.up2(torch.cat((u1, d3), dim=1)))
        u3 = self.relu(self.up3(torch.cat((u2, d2), dim=1)))
        u4 = self.sigmoid(self.up4(torch.cat((u3, d1), dim=1)))
        return u4
# The checkpoint is presumably a fully pickled ColorAutoEncoder (.eval() is
# called on it). torch>=2.6 defaults to weights_only=True, which refuses such
# pickles, so opt out explicitly (needs torch>=1.13). Unpickling can execute
# arbitrary code: only load trusted, locally produced files.
# map_location='cpu' lets a checkpoint saved on a GPU load on CPU-only hosts.
autoEncoderModel = torch.load(
    os.path.join('AutoEncoder', 'model.pth'),
    map_location='cpu',
    weights_only=False,
)
autoEncoderModel.eval()



# KNN
from joblib import load
# joblib files are pickles as well: only load trusted, locally produced models.
# os.path.join keeps the path valid on non-Windows systems.
KNNModel = load(os.path.join("KNN", "knn_model.joblib"))
def extract_features(image):
    """Assemble the per-pixel feature matrix consumed by the KNN colorizer.

    Columns, left to right: the filter features from ``extract_all``, the SLIC
    superpixel label (grayscale settings) and the neighbourhood mean from
    ``extract_neighbors_features``. Each part's shape is printed for debugging
    before the parts are concatenated along axis 1.
    """
    filter_part = extract_all(image)
    segment_part = superpixel(image, False).reshape(-1, 1)
    neighbour_part = extract_neighbors_features(image).reshape(-1, 1)
    # Print shapes for debugging
    for label, part in (("X1 shape:", filter_part),
                        ("X2 shape:", segment_part),
                        ("X3 shape:", neighbour_part)):
        print(label, part.shape)
    return np.concatenate((filter_part, segment_part, neighbour_part), axis=1)
def superpixel(image, status):
    """Segment *image* into about 100 SLIC superpixels and return the label map.

    Both modes call ``slic`` on ``img_as_float(image)`` with ``n_segments=100``
    and ``sigma=5``. A truthy ``status`` keeps slic's defaults for
    ``compactness`` and ``channel_axis``; a falsy one passes
    ``compactness=0.1`` and ``channel_axis=None`` (single-channel input).
    """
    grayscale_options = {} if status else {'compactness': 0.1, 'channel_axis': None}
    return slic(img_as_float(image), n_segments=100, sigma=5, **grayscale_options)
def extract_neighbors_features(img, distance=8):
    """Mean of each pixel's neighbours in a ``(2*distance+1)``-wide square window.

    The pixel itself is excluded. Positions outside the image count as 0 but
    still contribute to the denominator, so every pixel is divided by
    ``(2*distance+1)**2 - 1``.

    Args:
        img: 2-D ``(H, W)`` grayscale array or 3-D ``(H, W, C)`` array.
        distance: Half-width of the window; must be >= 1.

    Returns:
        float64 array of shape ``(H*W, 1)`` for 2-D input or ``(H*W, C)`` for
        3-D input, in row-major pixel order.

    Raises:
        ValueError: if ``distance < 1`` (the window would have no neighbours).
    """
    if distance < 1:
        raise ValueError("distance must be at least 1")
    # Accumulate in float64: summing uint8 values as uint8 overflows (NumPy 2
    # promotes Python-int + uint8 to uint8). Integer sums stay exact in float64.
    arr = np.asarray(img, dtype=np.float64)
    height, width = arr.shape[:2]
    channels = int(np.prod(arr.shape[2:]))  # 1 for 2-D input
    size = 2 * distance + 1
    pad = [(distance, distance), (distance, distance)] + [(0, 0)] * (arr.ndim - 2)
    padded = np.pad(arr, pad)  # zero padding: out-of-bounds neighbours count as 0
    # Summed-area table with a leading zero row/column: each window sum is four
    # lookups instead of a (2*distance+1)**2 Python loop per pixel.
    table = np.zeros((padded.shape[0] + 1, padded.shape[1] + 1) + padded.shape[2:])
    table[1:, 1:] = padded.cumsum(axis=0).cumsum(axis=1)
    window_sum = (table[size:, size:] - table[:-size, size:]
                  - table[size:, :-size] + table[:-size, :-size])
    neighbour_mean = (window_sum - arr) / (size * size - 1)
    return neighbour_mean.reshape(height * width, channels)
def extract_all(img):
    """Compute four per-pixel filter features and return them as a DataFrame.

    Every column is an image flattened in row-major order, so an ``(H, W)``
    input yields ``H * W`` rows. The scipy.ndimage filters return arrays of the
    input's dtype.

    Columns:
        'GrayValue(I)' -- the raw pixel values
        'Gaussian s3'  -- Gaussian blur with sigma=3
        'Gaussian s7'  -- Gaussian blur with sigma=7
        'Variance s3'  -- local variance over a size-3 window
    """
    features = pd.DataFrame()
    # flatten() copies in C order, exactly like reshaping to 2-D and flattening.
    features['GrayValue(I)'] = img.flatten()
    for column, sigma in (('Gaussian s3', 3), ('Gaussian s7', 7)):
        features[column] = nd.gaussian_filter(img, sigma=sigma).reshape(-1)
    features['Variance s3'] = nd.generic_filter(img, np.var, size=3).reshape(-1)
    return features
def colorize_image(grayscale_image, knn_model):
    """Predict a colour image for *grayscale_image* with the KNN model.

    Args:
        grayscale_image: 2-D ``(H, W)`` image; its per-pixel features come from
            ``extract_features``.
        knn_model: Fitted estimator exposing ``predict`` (presumably trained on
            the same feature layout -- confirm).

    Returns:
        ``(H, W, 3)`` array of predicted colours, in the model's output dtype.

    Raises:
        ValueError: if the model predicts neither 1 nor 3 values per pixel.
    """
    features = extract_features(grayscale_image)
    predicted = knn_model.predict(features)
    height, width = grayscale_image.shape[0], grayscale_image.shape[1]
    predicted = predicted.reshape(height, width, -1)
    channels = predicted.shape[-1]
    if channels == 1:
        # Replicate the single channel three times. cv2.cvtColor(GRAY2BGR) only
        # accepts 8-bit, 16-bit or float32 input and rejects the float64/int64
        # arrays a predict() call may return; np.repeat works for any dtype.
        return np.repeat(predicted, 3, axis=-1)
    if channels == 3:
        return predicted
    raise ValueError("Invalid number of channels in predicted colors")


# Define the Dash app
app = dash.Dash(external_stylesheets=[dbc.themes.BOOTSTRAP])


def _build_layout():
    """Build the page: a header card with the model buttons, then a row with
    the upload box on the left and the output panel on the right.

    The component ids ('gans', 'opencv', 'knn', 'opencvimproved',
    'upload-image', 'output-image-upload') are the ones the callbacks use.
    """
    def image_box_style():
        # Dashed 450x500 frame shared by the upload box and the output panel.
        return {
            'width': '450px',
            'height': '500px',
            'lineHeight': '300px',
            'borderWidth': '2px',
            'borderStyle': 'dashed',
            'borderRadius': '5px',
            'textAlign': 'center',
            'margin': '10px',
            'padding': '20px',
        }

    # (label, id) of each model button. The autoencoder button
    # ('Using Autoencoders', 'autoencoders') is currently disabled.
    model_buttons = [
        ('Using GANs', 'gans'),
        ('Using OpenCV', 'opencv'),
        ('Using KNN', 'knn'),
        ('Using OpenCV_Improved', 'opencvimproved'),
    ]

    header = dbc.Card(
        dbc.CardBody([
            html.H1('Revive your pictures with a splash of colors!', className="display-3"),
            html.P(
                "Welcome to Colorify, where black and white memories find their vibrant voice! Reveal the hidden beauty of the past with our innovative models!",
                className="lead",
            ),
            html.Hr(className="my-2"),
            html.P("Select a model:"),
            dbc.ButtonGroup(
                [dbc.Button(label, id=button_id, color="primary", className="mr-1")
                 for label, button_id in model_buttons],
                size="lg",
            ),
        ]),
        className="mb-3",
    )

    upload_column = dbc.Col(
        [
            html.H2('Original Image:', className="mb-2 mt-2"),
            dcc.Upload(
                id='upload-image',
                children=html.Div(['Drag and Drop or ', html.A('Select an Image')]),
                style=image_box_style(),
                accept='image/*',
            ),
        ],
        width=6,
        className="text-left",
    )

    output_column = dbc.Col(
        [
            html.H2('Output Image:', className="mb-2 mt-2"),
            html.Div(id='output-image-upload', style=image_box_style()),
        ],
        width=3,
        className="text-left",
    )

    return dbc.Container(
        fluid=True,
        style={
            'backgroundImage': 'url("/assets/splash.jpg")',
            'backgroundRepeat': 'no-repeat',
            'backgroundPosition': 'center',
            'backgroundSize': 'cover',
            'position': 'relative',
            'minHeight': '100vh',
        },
        children=[
            dbc.Row([dbc.Col(header, width=12, className="text-center")]),
            dbc.Row(
                [upload_column, dbc.Col(width=3), output_column],
                justify="between",
            ),
        ],
    )


# Define the layout of the web page
app.layout = _build_layout()

@app.callback(
    Output('upload-image', 'children'),
    [Input('upload-image', 'contents')],
    [State('upload-image', 'filename')]
)
def update_output(contents, filename):
    """Show a preview of the uploaded image inside the upload box.

    Args:
        contents: ``dcc.Upload`` data URL (``"data:<mime>;base64,<payload>"``),
            or ``None`` when nothing has been uploaded.
        filename: Name of the uploaded file, shown under the preview.

    Returns:
        A preview ``html.Div``; an error ``html.Div`` when the upload cannot be
        decoded and opened as an image; or the drop-zone prompt when
        ``contents`` is ``None``.
    """
    if contents is None:
        return html.Div(['Drag and Drop or ', html.A('Select an Image')])

    try:
        # Decode every upload (whatever its file name) and let Pillow identify
        # it; malformed data URLs are caught here as well.
        _, content_string = contents.split(',', 1)
        decoded = base64.b64decode(content_string)
        Image.open(io.BytesIO(decoded)).close()
    except Exception as e:  # callback boundary: report instead of crashing
        print(e)
        return html.Div([
            'There was an error processing this file.'
        ])

    return html.Div([
        html.Img(src=contents, style={'height': '100%', 'width': '100%'}),
        html.Hr(),
        html.H5(filename)
    ])

# Callback to display the colorized image for whichever model button is clicked


def _decode_upload(contents, size=(120, 120)):
    """Decode a ``dcc.Upload`` data URL into a PIL image resized to *size*."""
    _, content_string = contents.split(',', 1)
    decoded = base64.b64decode(content_string)
    return Image.open(io.BytesIO(decoded)).resize(size)


def _pil_to_cv2(image, flags=cv2.IMREAD_COLOR):
    """Encode *image* as PNG in memory and decode it with OpenCV using *flags*.

    Avoids a shared temporary file in the working directory, which concurrent
    requests could overwrite and which would be left behind on errors.
    """
    buffer = io.BytesIO()
    image.save(buffer, format='PNG')
    encoded = np.frombuffer(buffer.getvalue(), dtype=np.uint8)
    return cv2.imdecode(encoded, flags)


def _to_data_url(image, fmt='PNG'):
    """Encode a PIL image as a ``data:image/<fmt>;base64,...`` URL for html.Img."""
    buffer = io.BytesIO()
    image.save(buffer, format=fmt)
    payload = base64.b64encode(buffer.getvalue()).decode()
    return f'data:image/{fmt.lower()};base64,{payload}'


def _colorize_gans(image):
    """Colorize with the GAN generator; returns a 400x400 PIL image."""
    # The generator takes a (1, 120, 120, 1) batch scaled to [0, 1];
    # convert('L') makes RGB/RGBA uploads fit (L-mode input is unchanged).
    gray = np.asarray(image.convert('L')).reshape((1, 120, 120, 1)) / 255
    y = GANS_generator(gray).numpy()
    # Clip before the uint8 cast so out-of-range outputs saturate, not wrap.
    rgb = np.clip(y[0] * 255, 0, 255).astype('uint8')
    return Image.fromarray(rgb).resize((400, 400))


def _colorize_opencv(image):
    """Colorize with the custom Keras L -> ab model; returns a 400x400 RGB PIL image."""
    lab = cv2.cvtColor(_pil_to_cv2(image), cv2.COLOR_BGR2Lab)
    lab = cv2.resize(lab, (256, 256))
    # Model input: the 8-bit L channel scaled to [0, 1], shaped (1, 256, 256, 1).
    l_channel = lab[:, :, 0].astype('float32') / 255.0
    l_channel = l_channel[np.newaxis, :, :, np.newaxis]
    ab_channels = openCVModel.predict(l_channel)
    colorized = np.concatenate((l_channel[0], ab_channels[0]), axis=-1)
    # Back to 8-bit Lab; clip so out-of-range predictions saturate, not wrap.
    colorized = np.clip(colorized * 255.0, 0, 255).astype('uint8')
    if colorized.shape[-1] == 2:
        # Pad a missing third channel with zeros of the same dtype (cvtColor
        # rejects the float64 array a default np.zeros would produce).
        padding = np.zeros(colorized.shape[:2] + (1,), dtype=colorized.dtype)
        colorized = np.concatenate((colorized, padding), axis=-1)
    # PIL treats 3-channel arrays as RGB, so convert Lab -> RGB (not BGR).
    rgb = cv2.cvtColor(colorized, cv2.COLOR_Lab2RGB)
    return Image.fromarray(cv2.resize(rgb, (400, 400)))


def _colorize_knn(image):
    """Colorize with the KNN pixel model; returns a 400x400 PIL image."""
    gray = _pil_to_cv2(image, cv2.IMREAD_GRAYSCALE)
    colorized = colorize_image(gray, KNNModel).astype(np.uint8)
    return Image.fromarray(colorized).resize((400, 400))


def _colorize_opencv_improved(image):
    """Colorize with the Caffe network ``net``; returns a 400x400 RGB PIL image."""
    bgr = _pil_to_cv2(image)
    lab = cv2.cvtColor(bgr.astype("float32") / 255.0, cv2.COLOR_BGR2LAB)
    # The network gets a 224x224 L channel shifted by -50 and predicts ab.
    L = cv2.split(cv2.resize(lab, (224, 224)))[0]
    L -= 50
    net.setInput(cv2.dnn.blobFromImage(L))
    ab = net.forward()[0, :, :, :].transpose((1, 2, 0))
    ab = cv2.resize(ab, (bgr.shape[1], bgr.shape[0]))
    # Recombine the full-resolution L channel with the upsampled ab prediction.
    colorized = np.concatenate((lab[:, :, :1], ab), axis=2)
    # PIL treats 3-channel arrays as RGB, so convert Lab -> RGB (not BGR).
    rgb = np.clip(cv2.cvtColor(colorized, cv2.COLOR_LAB2RGB), 0, 1)
    rgb = (255 * rgb).astype("uint8")
    return Image.fromarray(cv2.resize(rgb, (400, 400)))


def _colorize_autoencoder(image):
    """Colorize with the PyTorch autoencoder; returns a 256x256 PIL image.

    Not wired to the UI because the 'autoencoders' button is disabled in the
    layout. To enable it: restore the button, add
    ``Input('autoencoders', 'n_clicks')`` to the callback below (plus a matching
    parameter), and register this function in ``_COLORIZERS``.
    """
    transform = transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.ToTensor(),
    ])
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")  # use GPU if available
    batch = transform(image.convert('L')).unsqueeze(0).to(device)
    model = autoEncoderModel.to(device)
    with torch.no_grad():
        output = model(batch).cpu().squeeze().numpy()
    # (3, H, W) -> (H, W, 3) for PIL
    output = np.transpose(output, (1, 2, 0))
    return Image.fromarray((output * 255).astype(np.uint8))


# Button id -> (colorizer, output image format).
_COLORIZERS = {
    'gans': (_colorize_gans, 'JPEG'),
    'opencv': (_colorize_opencv, 'PNG'),
    'knn': (_colorize_knn, 'PNG'),
    'opencvimproved': (_colorize_opencv_improved, 'PNG'),
}


@app.callback(
    Output('output-image-upload', 'children'),
    # Only buttons that exist in the layout may be listed as Inputs, and the
    # order must match the function parameters below.
    [Input('gans', 'n_clicks'), Input('opencv', 'n_clicks'), Input('knn', 'n_clicks'), Input('opencvimproved', 'n_clicks')],
    [State('upload-image', 'contents')]
)
def GANS_output(gans_n_clicks, opencv_n_clicks, knn_n_clicks, opencvimproved_n_clicks, contents):
    """Colorize the uploaded image with the model whose button was clicked.

    Despite its name this callback serves every model button.

    Args:
        *_n_clicks: ``n_clicks`` of each model button (``None`` until clicked).
        contents: The upload's data URL, or ``None`` if nothing was uploaded.

    Returns:
        An ``html.Img`` with the colorized image, or ``None`` (empty output
        panel) when no known button triggered the call or no image is uploaded.
    """
    ctx = dash.callback_context
    if not ctx.triggered:
        return None
    button_id = ctx.triggered[0]['prop_id'].split('.')[0]

    clicks = {
        'gans': gans_n_clicks,
        'opencv': opencv_n_clicks,
        'knn': knn_n_clicks,
        'opencvimproved': opencvimproved_n_clicks,
    }
    if button_id not in _COLORIZERS or clicks.get(button_id) is None or contents is None:
        return None

    colorize, image_format = _COLORIZERS[button_id]
    colorized = colorize(_decode_upload(contents))
    return html.Img(src=_to_data_url(colorized, image_format))


if __name__ == '__main__':
    # ``app.run`` is the current entry point. ``run_server`` is its deprecated
    # predecessor and was removed in Dash 3, so use it only on Dash versions
    # that lack ``run``.
    run = app.run if hasattr(app, 'run') else app.run_server
    run(port=4051)