In [2]:
import io
import os
import numpy as np
import pandas as pd
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import backend as K
import matplotlib.pyplot as plt
from PIL import Image
import PySimpleGUI as sg
import os.path

def create_model(xSize, ySize):
    """Build the (uncompiled) encoder/decoder segmentation network.

    Args:
        xSize: size of the first spatial axis of the input, i.e. the number of
            rows (image height for a NumPy image array). May be ``None``.
        ySize: size of the second spatial axis, i.e. the number of columns
            (image width). May be ``None``.

    Returns:
        A ``keras.models.Sequential`` taking ``(batch, xSize, ySize, 3)`` inputs
        and producing 3-channel outputs passed through a sigmoid.
    """
    layers = keras.layers

    def double_conv(filters, first_kwargs=None):
        # Conv -> ReLU -> Conv -> BatchNorm -> ReLU, all 3x3 with "same" padding.
        # List elements are created left to right, preserving layer creation order.
        return [
            layers.Conv2D(filters, (3, 3), padding="same", **(first_kwargs or {})),
            layers.Activation("relu"),
            layers.Conv2D(filters, (3, 3), padding="same"),
            layers.BatchNormalization(),
            layers.Activation("relu"),
        ]

    stack = []

    # Encoder: four double-conv stages, each followed by 2x2 max pooling.
    # Only the very first layer carries the input shape.
    for depth, filters in enumerate((64, 128, 256, 512)):
        extra = {"input_shape": [xSize, ySize, 3]} if depth == 0 else None
        stack += double_conv(filters, extra)
        stack.append(layers.MaxPooling2D())

    # Bottleneck.
    stack += double_conv(1024)

    # Decoder: 2x upsampling with a transposed conv, then a double-conv stage.
    for filters in (512, 256, 128, 64):
        stack.append(
            layers.Conv2DTranspose(filters, (2, 2), strides=(2, 2), padding="same")
        )
        stack += double_conv(filters)

    # 1x1 projection down to 3 channels, squashed into [0, 1].
    stack.append(layers.Conv2D(3, (1, 1), activation="sigmoid"))

    return keras.models.Sequential(stack)

def dice_coef(y_true, y_pred, smooth=1):
    """Soft Dice coefficient, reduced over the last axis only.

    Returns ``(2 * sum|t*p| + smooth) / (sum(t^2) + sum(p^2) + smooth)``, where
    every sum runs over axis -1, so all leading axes are kept in the result.
    ``smooth`` keeps the ratio defined when both inputs are all zeros.

    NOTE(review): if the tensors are (batch, H, W, C), this is a per-pixel Dice
    across channels rather than a spatial Dice — confirm that is intended.
    """
    overlap = K.sum(K.abs(y_true * y_pred), axis=-1)
    numerator = 2. * overlap + smooth
    denominator = K.sum(K.square(y_true), -1) + K.sum(K.square(y_pred), -1) + smooth
    return numerator / denominator

def dice_coef_loss(y_true, y_pred):
    """Dice loss: ``1 - dice_coef(y_true, y_pred)`` (0 when y_pred equals a binary y_true)."""
    coef = dice_coef(y_true, y_pred)
    return 1 - coef

# ----- Window layout: two columns side by side -----

# File-dialog filter: JPEG images first, with a fallback for any file.
file_types = [
    ("JPEG (*.jpg)", "*.jpg"),
    ("All files (*.*)", "*.*"),
]

# Left column: path box + browse button + trigger button on one row, so the
# FileBrowse button fills the path box next to it; the input preview is below.
file_list_column = [
    [
        sg.Text("Choose image"),
        sg.In(size=(25, 1), enable_events=True, key="-INPUT FILE-"),
        sg.FileBrowse(file_types=file_types),
        sg.Button("Start segmentation"),
    ],
    [sg.Image(key="-PREVIEW-")],
]

# Right column: placeholder for the segmentation output image.
image_viewer_column = [[sg.Image(key="-IMAGE-")]]

layout = [[
    sg.Column(file_list_column),
    sg.VSeperator(),
    sg.Column(image_viewer_column),
]]

window = sg.Window("Image Segmentation", layout)

# ----- Event loop -----

# At most one compiled model is kept alive. create_model bakes the spatial size
# into input_shape, so a model is rebuilt (and 'model_dc.h5' reloaded) only when
# the image size changes, not on every button press.
_model_cache = {}


def _to_png_bytes(img, max_size=(400, 400)):
    """Shrink *img* in place to fit *max_size* and return it encoded as PNG bytes."""
    img.thumbnail(max_size)
    bio = io.BytesIO()
    img.save(bio, format="PNG")
    return bio.getvalue()


def _get_model(height, width):
    """Return a model for (height, width, 3) inputs with 'model_dc.h5' weights loaded."""
    key = (height, width)
    if key not in _model_cache:
        _model_cache.clear()  # free the previous model before building a new one
        # create_model's first spatial argument is the number of rows (height).
        model = create_model(height, width)
        model.compile(optimizer="adam", loss=dice_coef_loss, metrics=["accuracy"])
        model.load_weights('model_dc.h5')
        _model_cache[key] = model
    return _model_cache[key]


def _segment(rgb):
    """Run the network on an RGB PIL image and return its prediction as an RGB PIL image.

    The encoder pools four times and the decoder upsamples four times, so the
    output's height and width are the input's rounded down to a multiple of 16.
    """
    # PIL -> NumPy gives (H, W, 3); add the batch axis and scale to [0, 1].
    x = np.asarray(rgb, dtype=np.float32)[np.newaxis] / 255.0
    height, width = x.shape[1:3]
    pred = _get_model(height, width).predict(x)[0]  # drop only the batch axis
    return Image.fromarray(np.uint8(pred * 255))


try:
    while True:
        event, values = window.read()
        if event in ("Exit", sg.WIN_CLOSED):
            break
        # Only the button triggers work; path-box keystrokes are ignored.
        if event != "Start segmentation":
            continue

        filename = values["-INPUT FILE-"]
        if not filename:
            sg.popup_error("Please choose an image first.")
            continue
        if not os.path.isfile(filename):
            sg.popup_error(f"File not found: {filename}")
            continue

        try:
            # Convert to RGB: the network expects exactly 3 channels, and
            # grayscale/RGBA/palette/CMYK inputs would otherwise break it.
            with Image.open(filename) as src:
                rgb = src.convert("RGB")
            window["-PREVIEW-"].update(data=_to_png_bytes(rgb.copy()))
            window.refresh()  # show the preview while the model runs
            result = _segment(rgb)
            window["-IMAGE-"].update(data=_to_png_bytes(result))
        except Exception as err:  # report failures instead of swallowing them
            sg.popup_error(f"Segmentation failed: {err}")
finally:
    window.close()

