## Preparing Modules

In [1]:
import json
import re
import subprocess
import sys
from pathlib import Path

# Adapted from
# https://stackoverflow.com/questions/68154312/check-if-module-is-installed-in-jupyter-rather-than-in-kernel
# but queries the kernel's own interpreter, because that is where the
# `!{sys.executable} -m pip install ...` below installs packages.
def is_installed(pkg_name):
    """Return True if distribution ``pkg_name`` is installed for this kernel's Python.

    Runs ``sys.executable -m pip list --format=json``, which uses the same
    interpreter as the notebook's ``!{sys.executable} -m pip install`` line.
    The check and the install therefore agree even inside a virtualenv/conda
    env, where ``sys.base_prefix`` points at the base installation. This also
    works on Windows, where pip is not at ``bin/pip``.

    Names are compared after PEP 503 normalisation. ``"Typing_Extensions"``
    matches ``"typing-extensions"``.

    Args:
        pkg_name: distribution name as given to ``pip install``.

    Returns:
        True if pip lists the package. False if it does not, or if pip cannot
        be run or its output cannot be parsed.
    """
    def _normalize(name):
        # PEP 503: runs of "-", "_" and "." are equivalent; comparison is case-insensitive.
        return re.sub(r"[-_.]+", "-", name).lower()

    try:
        proc = subprocess.run(
            [sys.executable, "-m", "pip", "list", "--format=json"],
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            check=False,
        )
    except OSError:
        return False
    if proc.returncode != 0:
        return False
    try:
        packages = json.loads(proc.stdout.decode())
    except ValueError:
        return False

    wanted = _normalize(pkg_name)
    return any(_normalize(pkg.get("name", "")) == wanted for pkg in packages)

In [2]:
import tensorflow as tf
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
from IPython.display import display, HTML
from os import getcwd
# Install ipyfilechooser into the interpreter running this kernel
# (sys.executable) when is_installed() does not report it, so that the
# `from ipyfilechooser import FileChooser` below succeeds.
# NOTE(review): the install is unpinned; consider pinning a known-good version.
if not is_installed("ipyfilechooser"):
    !{sys.executable} -m pip install ipyfilechooser
from ipyfilechooser import FileChooser
import ipywidgets as widgets
import requests
import io
import imageio
import os

## Import the Classification Model

In [3]:
import classifier
%load_ext autoreload
%autoreload
filename = "efficientnetv2b0_feature_extract_transfer.h5"
target_size = (128, 128, 3)
classifier = classifier.SimpleClassifier(filename, target_size)
classifier.set_labels("labels.txt")

2022-08-06 02:46:26.398936: I tensorflow/core/platform/cpu_feature_guard.cc:151] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  SSE4.1 SSE4.2 AVX AVX2 AVX512F FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [4]:
def resize(image, target_size=(128,128), pad=False, method="bilinear"):
    """Resize an image to ``target_size`` and return it as a NumPy array.

    Args:
        image: image accepted by ``tf.image.resize``. The callers in this
            notebook pass an H x W x 3 float array scaled to [0, 1].
        target_size: ``(height, width)``. Extra trailing entries, such as the
            channel count in ``(128, 128, 3)``, are ignored.
        pad: if True, keep the aspect ratio and pad with zeros using
            ``tf.image.resize_with_pad`` instead of stretching the image.
        method: resize method name passed through to TensorFlow.

    Returns:
        The resized image as a NumPy array. It is float32 unless ``method`` is
        nearest-neighbour, which keeps the input dtype.
    """
    # tf.image.resize only accepts a 2-element size, so drop any channel entry.
    height, width = target_size[0], target_size[1]
    if pad:
        resized = tf.image.resize_with_pad(image, height, width, method=method)
    else:
        resized = tf.image.resize(image, (height, width), method=method)
    return resized.numpy()

def predict(resized_img, classes_dict = None):
    """Classify an image that has already been resized for the model.

    Looks up the notebook-level ``classifier`` object at call time and returns
    whatever its ``predict`` method returns. ``classes_dict`` is accepted for
    call compatibility but is not used.
    """
    model = classifier
    return model.predict(resized_img)
    
def predict_from_byte(image_data):
    """Decode raw image bytes and run the display/predict/feedback flow on them.

    The bytes are opened with PIL, forced to 3-channel RGB, scaled to floats in
    [0, 1] and handed to ``display_result``.
    """
    rgb_image = Image.open(io.BytesIO(image_data)).convert("RGB")
    pixels = np.array(rgb_image) / 255.0
    display_result(pixels)

In [5]:
# Prepare directories for user feedback images:
# Images/Correct receives confirmed predictions, Images/Incorrect corrected ones.
# makedirs creates the missing parent "Images" too, and exist_ok=True makes the
# cell idempotent without a separate (racy) existence check.
os.makedirs("Images/Correct", exist_ok=True)
os.makedirs("Images/Incorrect", exist_ok=True)

## Helper Functions and Display

In [6]:
# Shared, bordered output area. The handlers decorated with @out.capture render
# into it. It starts out showing a "Prediction Result" heading.
out = widgets.Output(layout=widgets.Layout(border='1px solid black'))
with out:
    display(HTML("<h2>Prediction Result</h2>"))

In [7]:
def check_correctness_interface(source_image, prediction):
    """Render a prompt plus "Correct"/"Incorrect" buttons for a prediction.

    Clicking either button calls ``check_correctness`` with the verdict
    (True/False), the original image, the predicted label and the list of both
    buttons, so the handler can disable them.
    """
    prompt = widgets.Output()
    with prompt:
        display(HTML("Is the above prediction correct?"))

    yes_button = widgets.Button(description = "Correct", button_style = "success", icon = "check")
    no_button = widgets.Button(description = "Incorrect", button_style = "danger", icon = "times")
    both_buttons = [yes_button, no_button]

    def on_yes(sender):
        check_correctness(True, source_image, prediction, both_buttons)

    def on_no(sender):
        check_correctness(False, source_image, prediction, both_buttons)

    yes_button.on_click(on_yes)
    no_button.on_click(on_no)
    display(prompt, widgets.HBox([yes_button, no_button]))

def display_result(source_image):
    """Show an image beside its resized version, classify it and ask for feedback.

    Steps:
      1. Resize the image with ``resize`` (default size).
      2. Plot original and resized images side by side in a 1x2 figure.
      3. Classify the resized image with ``predict`` and display the label as
         highlighted HTML.
      4. Render the Correct/Incorrect buttons via ``check_correctness_interface``.
    """
    resized_img = resize(source_image)
    panels = (("Original Image", source_image), ("Resized Image", resized_img))
    _, axes = plt.subplots(1, 2)
    for ax, (title, picture) in zip(axes, panels):
        ax.imshow(picture)
        ax.set_title(title)
    plt.show()

    prediction = predict(resized_img)
    banner = "<b>Prediction: <span style=\"background-color:#00407a; color:white;\">%s</span></b>" % (prediction)
    display(HTML(banner))
    check_correctness_interface(source_image, prediction)

In [8]:
def save_image(path, source_image, button = None):
    """Save an image under a numbered, non-clobbering variant of ``path``.

    ``"dir/name.jpg"`` is written as ``"dir/name_1.jpg"``. If that exists,
    ``name_2.jpg``, ``name_3.jpg`` and so on are tried. Only the real file
    extension is split off, so dots elsewhere in the path (``"St. Bernard.jpg"``,
    ``"./Images/x.jpg"``) are left intact. A path without an extension simply
    gets ``_N`` appended instead of looping forever.

    Floating-point images (this notebook produces values in [0, 1]) are
    converted to uint8 explicitly before writing, rather than relying on
    imageio's implicit lossy conversion.

    Args:
        path: target path including extension (imageio infers the format from it).
        source_image: image array, e.g. H x W x 3.
        button: optional widget to disable once the image has been written.
    """
    root, ext = os.path.splitext(path)
    idx = 1
    while os.path.exists("%s_%d%s" % (root, idx, ext)):
        idx += 1

    image = np.asarray(source_image)
    if np.issubdtype(image.dtype, np.floating):
        # Scale [0, 1] floats to 0..255 and clamp so out-of-range values cannot wrap around.
        image = np.clip(np.rint(image * 255.0), 0, 255).astype(np.uint8)

    imageio.imwrite("%s_%d%s" % (root, idx, ext), image)
    if button is not None:
        button.disabled = True

In [9]:
def _breed_to_filename(name):
    """Turn a breed label into a file-name stem that cannot leave its directory.

    Surrounding whitespace and leading dots are removed. Path separators are
    replaced with "_".
    """
    return name.strip().replace("/", "_").replace("\\", "_").lstrip(".").strip()


@out.capture(clear_output=False)
def check_correctness(correct, source_image, label, buttons):
    """Handle the user's verdict on a prediction and store the image accordingly.

    Both feedback buttons are disabled first, so a verdict is recorded once.

    Args:
        correct: True if the user confirmed the prediction.
        source_image: the original image array that was classified.
        label: the predicted label. It is used as the file name when the
            prediction is confirmed.
        buttons: the feedback buttons to disable.

    If the prediction is correct, the image is saved under Images/Correct/.
    Otherwise a text box asks for the right breed. The image is saved under
    Images/Incorrect/ with that name once it is non-empty after sanitising.
    """
    for b in buttons:
        b.disabled = True
    if correct:
        save_image("Images/Correct/" + _breed_to_filename(label) + ".jpg", source_image)
    else:
        display(HTML("<b>Please indicate the correct breed name:</b>"))
        text = widgets.Text(placeholder="Correct breed name")
        button = widgets.Button(button_style="success", icon="check")

        @out.capture(clear_output=False)
        def save_corrected(sender):
            # The name is typed by the user: sanitise it so it cannot add path
            # components, and refuse to save under an empty name.
            breed = _breed_to_filename(text.value)
            if not breed:
                display(HTML("<i>Please enter a breed name before saving.</i>"))
                return
            save_image("Images/Incorrect/" + breed + ".jpg", source_image, button)

        button.on_click(save_corrected)
        display(widgets.HBox([text, button]))

In [10]:
@out.capture(clear_output=True)
def predict_from_url(url, timeout=10):
    """Download an image from ``url`` and classify it.

    Output goes to the shared ``out`` widget, which is cleared first.
    Network/HTTP failures and non-image responses are reported with a short
    message. Errors raised while decoding or classifying the downloaded image
    are not caught here, so they are not mistaken for download failures.

    Args:
        url: address of the image.
        timeout: seconds to wait for the server before giving up, so a dead
            host cannot hang the callback.
    """
    print("Predicting from URL..." + url)
    try:
        response = requests.get(url, timeout=timeout)
        response.raise_for_status()
    except requests.RequestException as exc:
        print("Cannot Retrieve Image: %s" % (exc,))
        return

    # Content-Type may be missing or carry parameters, e.g. "image/jpeg; charset=...".
    mime_type = (response.headers.get("content-type") or "").split(";")[0].strip().lower()
    if not mime_type.startswith("image/"):
        print("URL did not return an image (content-type: %r)" % (mime_type,))
        return
    predict_from_byte(response.content)
        
@out.capture(clear_output=True)
def predict_from_upload(image_data, name):
    """Classify an image received through the FileUpload widget.

    The shared output area is cleared, the file name is reported, and the raw
    bytes are passed on to ``predict_from_byte``.
    """
    message = "Predicting from the uploaded file: %s" % (name)
    print(message)
    predict_from_byte(image_data)
    
@out.capture(clear_output=True)
def predict_from_uploaded_images(chooser):
    """Classify the image file currently selected in a FileChooser.

    Registered as the FileChooser callback. ``chooser.selected`` is the chosen
    file's path. It is guarded in case nothing is selected (presumably None;
    confirm against ipyfilechooser).
    The file is decoded to RGB, scaled to floats in [0, 1] and passed to
    ``display_result``. A progress line is printed first, like the other
    prediction handlers do.
    """
    path = chooser.selected
    if not path:
        print("No file selected")
        return
    print("Predicting from the selected file: %s" % (path,))
    # Context manager closes the file handle once the pixels have been read.
    with Image.open(path) as picture:
        img = np.array(picture.convert("RGB")) / 255.0
    display_result(img)

In [12]:
# Textarea + Submit Button: classify an image from a URL typed by the user.
text = widgets.Textarea()
submit_button = widgets.Button(description="Submit", button_style='success')
# .strip(): a Textarea easily picks up a trailing newline, which could break the request.
submit_button.on_click(lambda sender: predict_from_url(text.value.strip()))

# File Upload: classify an image file sent from the browser.
upload = widgets.FileUpload(accept='.png, .jpg, .jpeg')


def _on_upload(change):
    """Forward the first uploaded file's bytes and name to predict_from_upload.

    Observes ``value``, which bundles each file's name with its content. This
    avoids reading ``data`` and ``value`` separately, which depends on the
    order their updates arrive. Both layouts are handled:
      - ipywidgets 7: dict mapping file name -> {"metadata": ..., "content": bytes}
      - ipywidgets 8: tuple of dicts with "name" and "content" (a memoryview)
    """
    value = change["new"]
    if not value:
        return  # upload cleared; nothing to classify
    if isinstance(value, dict):
        name, entry = next(iter(value.items()))
    else:
        entry = value[0]
        name = entry["name"]
    predict_from_upload(bytes(entry["content"]), name)


upload.observe(_on_upload, names='value')

# FileChooser: classify an image file already on the server, starting at the cwd.
fc = FileChooser(getcwd())
fc.filter_pattern = ['*.jpg', '*.png', '*.jpeg']
fc.register_callback(predict_from_uploaded_images)

# Prompt Label
label = widgets.Output()
with label:
    display(HTML("<h2>Select an image using one of these options</h2>"))

# Create tabs, one per input method; titles are assigned by position.
titles = ["From URL", "Upload an Image", "Use Uploaded Images"]
tabs = widgets.Tab()
tabs.children = [widgets.VBox([text, submit_button]), upload, fc]
for i, title in enumerate(titles):
    tabs.set_title(i, title)
display(out, label, tabs)

Output(layout=Layout(border='1px solid black'), outputs=({'output_type': 'display_data', 'data': {'text/plain'…

Output()

Tab(children=(VBox(children=(Textarea(value=''), Button(button_style='success', description='Submit', style=Bu…