In [None]:
import os
import cv2
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow.keras.models import load_model

In [None]:
# Directories holding the sample images, relative to the notebook's working directory.
webcam_path = "data/webcam"
drawing_path = "data/drawing"

# File extensions treated as images; any other directory entry is skipped.
IMAGE_EXTENSIONS = (".png", ".jpg", ".jpeg", ".bmp", ".tif", ".tiff", ".webp")


def list_image_files(directory):
    """Return the image file names found directly inside `directory`, sorted by name.

    os.listdir() returns entries in arbitrary order and also reports non-image
    entries (for example .DS_Store or Thumbs.db). Sorting keeps positional
    lookups such as `webcam_pics[0]` reproducible between runs, and the
    extension filter keeps entries without an image extension out of the lists.

    Args:
        directory: Path of the directory to scan (not recursive).

    Returns:
        A sorted list of file names (not full paths) whose extension is in
        IMAGE_EXTENSIONS, compared case-insensitively.
    """
    return sorted(
        name
        for name in os.listdir(directory)
        if name.lower().endswith(IMAGE_EXTENSIONS)
    )


webcam_pics = list_image_files(webcam_path)
drawing_pics = list_image_files(drawing_path)
webcam_pics, drawing_pics

In [None]:
def preprocess_webcam_image(image_path):
    """Load an image as grayscale, resize it to 28x28 and scale it by 1/255.

    Args:
        image_path: Path of the image file to read.

    Returns:
        A float array of shape (28, 28). For 8-bit input the values lie in
        [0.0, 1.0].

    Raises:
        FileNotFoundError: If `image_path` does not point to an existing file.
        ValueError: If the file exists but OpenCV could not decode it.
    """
    img = cv2.imread(image_path, cv2.IMREAD_GRAYSCALE)

    # cv2.imread signals failure by returning None instead of raising, which
    # would otherwise surface later as an opaque error inside cv2.resize.
    if img is None:
        if not os.path.isfile(image_path):
            raise FileNotFoundError("Image file not found: {}".format(image_path))
        raise ValueError("Could not decode image: {}".format(image_path))

    img = cv2.resize(img, (28, 28))

    # NOTE(review): unlike preprocess_drawing_image, no inversion is applied
    # here -- confirm the polarity (dark-on-light vs. light-on-dark) the model expects.
    img = img / 255.0
    return img

In [None]:
# Run the webcam preprocessing on the first listed picture and display the
# resulting array shape as the cell output.
first_webcam_pic = webcam_pics[0]
full_path = os.path.join(webcam_path, first_webcam_pic)
processed_webcam_image = preprocess_webcam_image(image_path=full_path)
processed_webcam_image.shape

In [None]:
# Display the preprocessed webcam image in grayscale. An explicit figure/axes
# pair replaces the former bare plt.plot() call, which drew nothing.
fig, ax = plt.subplots()
ax.imshow(processed_webcam_image, cmap=plt.get_cmap("gray"))
ax.set_title("Preprocessed webcam image")
plt.show()

In [None]:
def preprocess_drawing_image(image_path):
    """Load a drawing, flatten it to an inverted, equalized 28x28 grayscale image.

    Steps: read the file with its channels unchanged, resize to 28x28, paint
    fully transparent pixels white (when an alpha channel is present), convert
    to grayscale, equalize the histogram, scale by 1/255 and invert.

    Args:
        image_path: Path of the image file to read. Images with four channels
            (with alpha), three channels, or a single channel are accepted.

    Returns:
        A float array of shape (28, 28) holding `1 - gray / 255`, so pixels
        that were white (or fully transparent) map to 0.0.

    Raises:
        FileNotFoundError: If `image_path` does not point to an existing file.
        ValueError: If the file exists but OpenCV could not decode it.
    """
    img = cv2.imread(image_path, cv2.IMREAD_UNCHANGED)

    # cv2.imread signals failure by returning None instead of raising.
    if img is None:
        if not os.path.isfile(image_path):
            raise FileNotFoundError("Image file not found: {}".format(image_path))
        raise ValueError("Could not decode image: {}".format(image_path))

    img = cv2.resize(img, (28, 28))

    if img.ndim == 2:
        # Already single-channel: nothing to composite or convert.
        img_gray = img
    else:
        if img.shape[2] == 4:
            # Fully transparent pixels become opaque white so they read as
            # background, then the alpha channel is dropped.
            img[img[:, :, 3] == 0] = [255, 255, 255, 255]
            img = img[:, :, :3]

        img_gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)

    # NOTE(review): cv2.equalizeHist expects 8-bit single-channel input;
    # 16-bit source images are not handled here -- confirm inputs are 8-bit.
    img_gray = cv2.equalizeHist(img_gray)

    img_gray = img_gray / 255.0

    # Invert so that a light background becomes 0 and dark strokes become bright.
    img_gray = 1 - img_gray

    return img_gray

In [None]:
# Show every preprocessed drawing in a grid with three columns. The number of
# rows is derived from the number of images so each image gets its own panel
# (a fixed 3x3 layout would draw image 10 on top of image 1, and so on).
GRID_COLUMNS = 3
num_pics = len(drawing_pics)

if num_pics == 0:
    print("No drawing images found in '{}'.".format(drawing_path))
else:
    # Ceiling division: enough rows to hold all images.
    grid_rows = (num_pics + GRID_COLUMNS - 1) // GRID_COLUMNS

    # Keep the 10x10 size for a 3x3 grid and scale the height with the row count.
    # squeeze=False guarantees a 2-D axes array even when there is a single row.
    fig, axes = plt.subplots(
        grid_rows,
        GRID_COLUMNS,
        figsize=(10, 10 * grid_rows / GRID_COLUMNS),
        squeeze=False,
    )
    flat_axes = axes.ravel()

    for i, (ax, pic) in enumerate(zip(flat_axes, drawing_pics)):
        image_path = os.path.join(drawing_path, pic)
        preprocessed_image = preprocess_drawing_image(image_path)

        ax.imshow(preprocessed_image, cmap=plt.get_cmap("gray"))
        ax.set_title("Image {}".format(i + 1))

    # Hide the unused panels in the last row.
    for ax in flat_axes[num_pics:]:
        ax.axis("off")

    fig.tight_layout()
    plt.show()

In [None]:
# Position in `drawing_pics` of the drawing to classify. The bounds check and
# the fallback message below are both derived from this value so the guard,
# the lookup and the message always agree.
SAMPLE_INDEX = 7


def predict_image(image):
    """Run the notebook-level `model` on a single image.

    Args:
        image: Array for one sample; a leading batch axis is added before
            it is handed to the model.

    Returns:
        The value returned by `model.predict` for the one-sample batch.
    """
    # NOTE(review): assumes the model accepts input of shape (1, 28, 28) --
    # confirm; a model with a channel axis or flat input needs a reshape here.
    image = np.expand_dims(image, axis=0)
    prediction = model.predict(image)
    return prediction


if len(drawing_pics) > SAMPLE_INDEX:
    image_path = os.path.join(drawing_path, drawing_pics[SAMPLE_INDEX])
    preprocessed_image = preprocess_drawing_image(image_path)

    # Load the model before building the figure so a failed load does not
    # leave a half-drawn figure behind.
    model = load_model("../models/mnist_model_digits.h5")

    prediction = predict_image(preprocessed_image)
    probabilities = prediction[0]
    predicted_label = np.argmax(prediction)

    fig, (image_ax, bar_ax) = plt.subplots(1, 2, figsize=(8, 8))

    image_ax.imshow(preprocessed_image, cmap=plt.get_cmap("gray"))
    image_ax.set_title("Preprocessed Image")

    # One bar per model output; the tick positions follow the output length.
    class_ids = range(len(probabilities))
    bar_ax.bar(class_ids, probabilities)
    bar_ax.set_xticks(class_ids)
    bar_ax.set_xlabel("Digit")
    bar_ax.set_ylabel("Probability")
    bar_ax.set_title("Prediction: {}".format(predicted_label))

    fig.tight_layout()
    plt.show()
else:
    print(
        "The 'drawing_pics' list does not have an element at index {}.".format(
            SAMPLE_INDEX
        )
    )