In [1]:
from tensorflow.keras.models import load_model
from tensorflow.keras.utils import load_img, img_to_array
from tensorflow.keras.preprocessing import image
import tensorflow as tf
from os.path import join
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
from IPython.display import HTML

from PIL import Image as pImage, ImageOps, ExifTags
from ipyfilechooser import FileChooser

import ipywidgets as widgets

In [2]:
# Globals
# File name of the trained classifier; it is loaded from the "models" directory below.
MODEL_PATH = "dice_type_classification1.keras"
# Labels looked up by the argmax index of the model output (see predict_dice).
# NOTE(review): the order must match the class order used at training time - confirm.
model_classes = ["d4", "d6", "d8", "d10", "d12", "d20"]

# Load the trained Keras model once; later cells use it as a module-level global.
model = tf.keras.models.load_model(join("models", MODEL_PATH))

In [3]:
def load_image_to_input(image_path):
    """Centre-crop, resize and store an image in the ``input`` directory.

    The image is first rotated according to its EXIF orientation, then the
    central region (each side divided by 1.5) is cut out and resized to
    448x448. The result is saved under ``input/`` using the source file
    name; if that name is already taken, ``(1)``, ``(2)``, ... is inserted
    before the extension until an unused name is found.

    Parameters
    ----------
    image_path : str
        Path of the source image. Both ``\\`` and ``/`` separators are
        accepted when extracting the file name.

    Returns
    -------
    str
        The file name (without directory) the image was saved as.
    """
    # Normalise separators so the name is extracted correctly for Windows-style
    # paths even when running on a platform whose separator is "/".
    file_name = os.path.basename(str(image_path).replace("\\", "/"))
    # splitext copes with names containing several dots (or none at all),
    # where a plain split(".") would fail to unpack into two values.
    stem, ext = os.path.splitext(file_name)

    with pImage.open(image_path) as img:
        upright = ImageOps.exif_transpose(img)

        # Zoom in on the image assuming the dice is in the middle
        width, height = upright.size
        crop_width = int(width // 1.5)
        crop_height = int(height // 1.5)
        left = (width - crop_width) // 2
        top = (height - crop_height) // 2
        crop_box = (left, top, left + crop_width, top + crop_height)
        cropped_img = upright.crop(crop_box)

        # Same 448x448 size that show_predictions loads images at.
        resized_image = cropped_img.resize((448, 448))

        # Make sure the target directory exists before probing/saving into it.
        os.makedirs("input", exist_ok=True)

        # Never overwrite an existing file: append a version counter instead.
        version = 0
        while os.path.isfile(join("input", file_name)):
            version += 1
            file_name = f"{stem}({version}){ext}"

        resized_image.save(join("input", file_name))
        return file_name

    

In [123]:
def predict_dice(image_array, model = model):
    """Run the classifier on a batch and print the result for its first item.

    Parameters
    ----------
    image_array : array-like
        Input passed unchanged to ``model.predict``. Only the first row of
        the returned predictions is reported.
    model : optional
        Object exposing ``predict``. Defaults to the module-level ``model``
        as it was bound when this function was defined; re-run this cell
        after reloading the model to pick up the new one.

    Returns
    -------
    None
        The predicted index array, its label from ``model_classes`` and the
        top score (as a percentage) are printed rather than returned.
    """
    predictions = model.predict(image_array)

    # Index of the highest score for every row of the batch.
    predicted_class = np.argmax(predictions, axis=1)
    top_index = predicted_class[0]

    # NOTE(review): labelled "Confidence" on the assumption that the model
    # outputs probabilities (e.g. softmax) - confirm against the model.
    confidence = predictions[0][top_index] * 100
    print(f"Predicted class index: {predicted_class}\n Predicted label: {model_classes[top_index]}\n Confidence: {confidence:.2f}%")
    


In [203]:
def show_predictions(image_path_list):
    """Classify each image in ``image_path_list`` and print the outcome.

    For every path the image is loaded at 448x448, converted to an array,
    given a leading batch axis and passed through the VGG16
    ``preprocess_input`` function before being handed to ``predict_dice``.
    The file name is printed ahead of each prediction.

    Parameters
    ----------
    image_path_list : iterable
        Paths of the images to classify.

    Returns
    -------
    None
    """
    for img_p in image_path_list:
        img = image.load_img(img_p, target_size=(448, 448))
        img_array = image.img_to_array(img)
        # The model expects a batch, so add a leading axis of size 1.
        img_array = np.expand_dims(img_array, axis=0)
        # NOTE(review): presumably matches the preprocessing used in training - confirm.
        img_array = tf.keras.applications.vgg16.preprocess_input(img_array)

        # Show only the file name; normalising separators keeps this working
        # for "\\"-separated paths on platforms that use "/".
        print(os.path.basename(str(img_p).replace("\\", "/")))
        predict_dice(img_array)

    

In [204]:
# Sample images, one path per entry, all located below the "data" directory.
images_chosen = [
    join("data", *relative_parts)
    for relative_parts in (
        ("dicetype", "d4", "d4_angle_color031.jpg"),
        ("dicetype", "d6", "d6_45angle_0136.jpg"),
        ("dicetype", "d10", "d10_color129.jpg"),
        ("dice", "20", "IMG_9763.jpg"),
    )
]

# Uncomment to classify the sample images:
# show_predictions(images_chosen)


In [205]:
def show_image(image_path):
    """Plot the image at ``image_path`` with its EXIF orientation applied.

    Parameters
    ----------
    image_path : str or None
        Path of the image to show, or ``None`` when nothing is selected.

    Returns
    -------
    ipywidgets.Label or None
        When ``image_path`` is ``None`` a "No image selected" label is
        returned without being displayed, so the caller has to display it.
        Otherwise the image is drawn with matplotlib and ``None`` is
        returned.
    """
    if image_path is None:
        return widgets.Label("No image selected")

    # The context manager closes the underlying file once plotting is done.
    with pImage.open(image_path) as img:
        # Same orientation handling as load_image_to_input; relies on the
        # public Pillow helper instead of the private _getexif accessor.
        upright = ImageOps.exif_transpose(img)

        # Display the image without axes
        plt.figure(figsize=(5,5))
        plt.imshow(upright)
        plt.axis('off')
        plt.show()

In [213]:
# Widgets
# Builds every widget used by the page assembled in a later cell. Each widget
# gets its own Layout instance.

# Page banner: an <h1> titled "Dice Classification", centred inside an HBox.
header = widgets.HTML("<h1 style='text-align: center; background-color: lightgrey; height: 80px;'>Dice Classification</h1>", 
                      layout = widgets.Layout(width="100%", height="80px"))
header_box = widgets.HBox([header], layout=widgets.Layout(align_items='center', justify_content='center', width="100%"))

# File picker opened at ./data.
file_choose = FileChooser("./data", layout = widgets.Layout(height="auto", width="100%", border="2px solid red"))
file_choose.layout.overflow = 'hidden'

# Output area that on_file_selected draws the chosen image into.
image_output = widgets.Output(layout = widgets.Layout(height="auto", width="100%", border="2px solid red"))
image_output.layout.overflow = 'hidden'

# Output area that update_prediction prints the classification into.
prediction_output = widgets.Output(layout = widgets.Layout(height="auto", width="100%", border="2px solid red"))

# Buttons that switch the content shown in visual_box_output.
vbtn1 = widgets.Button(description="Confusion Matrix",
                    layout=widgets.Layout(width="100%", height="auto", margin="10px"), button_style="info")
vbtn2 = widgets.Button(description="Model Layers",
                    layout=widgets.Layout(width="100%", height="auto", margin="10px"), button_style="info")
vbtn3 = widgets.Button(description="Visual 3",
                    layout=widgets.Layout(width="100%", height="auto", margin="10px"), button_style="info")
# change_visual_content is defined in a later cell; the lambdas only look the
# name up when a button is clicked, so that cell must have run by then.
vbtn1.on_click(lambda x: change_visual_content(1))
vbtn2.on_click(lambda x: change_visual_content(2))
vbtn3.on_click(lambda x: change_visual_content(3))

visual_buttons = widgets.HBox([vbtn1, vbtn2, vbtn3], layout=widgets.Layout(align_items='center', justify_content='center', width="100%"))

# One visual per button. visual3 is not given any image data in this cell.
visual1 = widgets.Image()
visual2 = widgets.Output()
visual3 = widgets.Image()

# Capture the printed model summary inside visual2.
with visual2:
    model.summary()

# visual1 shows the confusion-matrix PNG read from the assets directory.
with open(join("assets","Confusion Matrix.png"), "rb") as f:
    visual1.value = f.read()


# Text shown next to each visual.
# NOTE(review): "Test1".."Test3" look like placeholder copy - confirm.
content1 = widgets.HTML("<div>Test1</div>")
content2 = widgets.HTML("<div>Test2</div>")
content3 = widgets.HTML("<div>Test3</div>")

# Widths of each visual/content pair add up to 100%.
visual1.layout = widgets.Layout(width="50%")
visual2.layout = widgets.Layout(width="60%")
visual3.layout = widgets.Layout(width="50%")

content1.layout = widgets.Layout(width="50%")
content2.layout = widgets.Layout(width="40%")
content3.layout = widgets.Layout(width="50%")

# Side-by-side visual + text rows; change_visual_content displays one of them.
visual_box1 = widgets.HBox([visual1,content1], layout = widgets.Layout(height="auto", width="100%", border="2px solid red"))
visual_box2 = widgets.HBox([visual2,content2], layout = widgets.Layout(height="auto", width="100%", border="2px solid blue"))
visual_box3 = widgets.HBox([visual3,content3], layout = widgets.Layout(height="auto", width="100%", border="2px solid green"))

# Container that change_visual_content renders the selected row into.
visual_box_output = widgets.Output(layout = widgets.Layout(height="auto", width="100%", border="2px solid red"))

# image_box = widgets.HBox([image_output], layout=widgets.Layout(height="400px"))

In [207]:
def update_prediction(file_path):
    """Replace the contents of ``prediction_output`` with a new prediction.

    Parameters
    ----------
    file_path : str or None
        Image to classify. When ``None``, the path currently selected in
        ``file_choose`` is used instead.

    Returns
    -------
    None
    """
    # Honour the argument; fall back to the chooser only when nothing was given.
    if file_path is None:
        file_path = file_choose.selected

    prediction_output.clear_output()

    # Everything printed by show_predictions is captured by the Output widget.
    with prediction_output:
        show_predictions([file_path])

In [208]:
def on_file_selected(change):
    """FileChooser callback: refresh the image preview and the prediction.

    Parameters
    ----------
    change
        Argument supplied by the file chooser when it invokes the callback.
        It is not used; the selection is read from ``file_choose`` directly.

    Returns
    -------
    None
    """
    selected_path = file_choose.selected

    # Clear previous output
    image_output.clear_output()
    with image_output:
        # show_image plots the image itself, but when no path is selected it
        # returns a label widget instead, which has to be displayed explicitly.
        fallback_widget = show_image(selected_path)
        if fallback_widget is not None:
            display(fallback_widget)
    update_prediction(selected_path)

file_choose.register_callback(on_file_selected)

In [209]:
def change_visual_content(content_num):
    """Show the visual row matching ``content_num`` in ``visual_box_output``.

    ``1`` selects ``visual_box1``, ``2`` selects ``visual_box2`` and any other
    value selects ``visual_box3``. Whatever was shown before is cleared first.
    """
    visual_box_output.clear_output()

    with visual_box_output:
        selected_box = (
            visual_box1 if content_num == 1
            else visual_box2 if content_num == 2
            else visual_box3
        )
        display(selected_box)
    

In [214]:
# Full page, top to bottom: banner, chooser/preview/prediction row,
# visual-selection buttons, and the area showing the selected visual.
page = widgets.VBox(
    children=[
        header_box,
        widgets.HBox(children=[file_choose, image_output, prediction_output]),
        visual_buttons,
        visual_box_output,
    ]
)
page
# image_grid

VBox(children=(HBox(children=(HTML(value="<h1 style='text-align: center; background-color: lightgrey; height: …

In [110]:
# Ad-hoc inspection of the path currently selected in the file chooser.
# NOTE(review): the saved output of this cell contains an absolute local path - consider clearing it before sharing.
file_choose.selected

'C:\\Users\\joshu\\OneDrive\\Documents\\Dev\\Jupyter\\data\\dice\\11\\IMG_0209.JPG'