In [8]:
# Name: Vikas Saahil
# Student no.: 239408810
# Simple local UI to draw a digit and automatically show predictions.
%matplotlib qt
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import tensorflow as tf

class DigitDrawer:
    """Interactive matplotlib window for drawing a digit and classifying it live.

    Left panel: a 28x28 drawing canvas (white strokes on black, values 0 or 1).
    Middle panel: the top predicted class and its confidence.
    Right panel: one horizontal bar per class 0-9.

    A GUI timer periodically feeds the canvas to a saved Keras model, but only
    when the canvas has changed since the previous prediction.

    Controls: drag with the mouse inside the canvas to draw; press 'c' to clear.

    Note: ``__init__`` changes global matplotlib state (``dark_background``
    style and ``rcParams["toolbar"]``), which also affects figures created
    later in the same kernel.
    """

    def __init__(self, model_path="mnist_cnn_model.h5"):
        """Load the model and build the figure, event handlers and timer.

        Args:
            model_path: Path to a model file readable by
                ``tf.keras.models.load_model``. NOTE(review): ``predict_current``
                feeds it a (1, 28, 28, 1) float32 array in [0, 1] and reads at
                least 10 outputs -- presumably the MNIST CNN trained earlier;
                confirm against the training notebook.
        """
        plt.style.use("dark_background")
        plt.rcParams["toolbar"] = "none"  # must be set before the figure is created

        # Load the trained model once; every prediction reuses it.
        self.model = tf.keras.models.load_model(model_path)
        print("Loaded model from", model_path)

        # 28x28 drawing surface (MNIST resolution); 0.0 = black background.
        self.canvas = np.zeros((28, 28), dtype=np.float32)
        # True when the canvas changed since the last prediction, so the timer
        # only runs inference when there is something new to classify.
        self._dirty = False

        # Figure with panels for the canvas, prediction text and probability bars.
        self.fig = plt.figure(figsize=(8, 3))
        gs = self.fig.add_gridspec(1, 3, width_ratios=[1.3, 1, 1.3])
        self.ax_canvas = self.fig.add_subplot(gs[0, 0])
        self.ax_text = self.fig.add_subplot(gs[0, 1])
        self.ax_probs = self.fig.add_subplot(gs[0, 2])

        # Canvas panel.
        self.im = self.ax_canvas.imshow(self.canvas, cmap="gray", vmin=0, vmax=1)
        self.ax_canvas.set_title("Draw a digit", fontsize=12)
        self.ax_canvas.axis("off")
        self.ax_canvas.add_patch(
            patches.Rectangle((0, 0), 27, 27, fill=False, edgecolor="white", linewidth=0.5)
        )

        # Prediction text panel.
        self.ax_text.axis("off")
        self.text_obj = self.ax_text.text(
            0.5, 0.6, "Prediction:\n-",
            ha="center", va="center", fontsize=16
        )
        self.ax_text.text(
            0.5, 0.1,
            "Instructions:\n- Draw with mouse\n- Press 'c' to clear",
            ha="center", va="center", fontsize=8, color="0.8"
        )

        # Probability bar chart panel (classes 0-9).
        self.ax_probs.set_title("Class probabilities", fontsize=10)
        self.ax_probs.set_xlim(0, 1)
        self.ax_probs.set_ylim(-0.5, 9.5)
        self.ax_probs.set_yticks(range(10))
        self.ax_probs.set_yticklabels([str(i) for i in range(10)])
        self.ax_probs.set_xlabel("Probability")
        self.prob_bars = self.ax_probs.barh(
            range(10),
            [0] * 10,
            color="#00ff88"
        )

        # Mouse / keyboard handling. mpl_connect keeps only weak references to
        # bound methods, so this object must stay referenced (e.g. by a
        # module-level variable) for the handlers to keep firing.
        self.drawing = False
        self.brush_size = 1  # brush half-width in pixels: 1 -> 3x3 square
        self.fig.canvas.mpl_connect("button_press_event", self.on_press)
        self.fig.canvas.mpl_connect("button_release_event", self.on_release)
        self.fig.canvas.mpl_connect("motion_notify_event", self.on_move)
        self.fig.canvas.mpl_connect("key_press_event", self.on_key)
        self.fig.canvas.mpl_connect("close_event", self._on_close)

        # Periodic prediction: every 500 ms, classify the canvas if it changed.
        self.timer = self.fig.canvas.new_timer(interval=500)
        self.timer.add_callback(self.predict_current)
        self.timer.start()

        plt.tight_layout()

    def _on_close(self, event):
        """Stop the prediction timer once the window is closed."""
        self.timer.stop()

    def on_press(self, event):
        """Start a stroke when a mouse button is pressed inside the canvas."""
        if event.inaxes == self.ax_canvas:
            self.drawing = True
            self.draw_at(event.xdata, event.ydata)

    def on_release(self, event):
        """End the current stroke on any mouse-button release."""
        self.drawing = False

    def on_move(self, event):
        """Continue the stroke while the button is held over the canvas."""
        if self.drawing and event.inaxes == self.ax_canvas:
            self.draw_at(event.xdata, event.ydata)

    def draw_at(self, x, y):
        """Paint a filled square brush centred on data coordinates (x, y).

        ``x`` is the image column and ``y`` the image row; rounding picks the
        pixel under the cursor. Pixels outside the canvas are skipped. The
        canvas is marked as changed so the next timer tick re-runs prediction.

        Args:
            x: Column coordinate from the mouse event (None outside the axes).
            y: Row coordinate from the mouse event (None outside the axes).
        """
        if x is None or y is None:
            return
        n_rows, n_cols = self.canvas.shape
        row, col = int(round(y)), int(round(x))
        for d_row in range(-self.brush_size, self.brush_size + 1):
            for d_col in range(-self.brush_size, self.brush_size + 1):
                r, c = row + d_row, col + d_col
                if 0 <= r < n_rows and 0 <= c < n_cols:
                    self.canvas[r, c] = 1.0
        self._dirty = True
        self.im.set_data(self.canvas)
        self.fig.canvas.draw_idle()

    def on_key(self, event):
        """Clear the canvas and reset the prediction panels when 'c' is pressed."""
        if event.key == "c":
            self.canvas[:] = 0.0
            self._dirty = False  # nothing to classify until the user draws again
            self.im.set_data(self.canvas)
            self.text_obj.set_text("Prediction:\n-")
            for bar in self.prob_bars:
                bar.set_width(0)
            self.fig.canvas.draw_idle()

    def predict_current(self):
        """Classify the canvas and refresh the text and bar panels.

        Called by the GUI timer. Returns without running the model when the
        canvas is blank or unchanged since the previous prediction, so an idle
        window does not keep doing inference.
        """
        if not self._dirty or not self.canvas.any():
            return
        self._dirty = False
        img = self.canvas.astype("float32")[np.newaxis, :, :, np.newaxis]  # (1,28,28,1)
        probs = self.model.predict(img, verbose=0)[0]
        pred_label = int(np.argmax(probs))
        confidence = float(np.max(probs)) * 100.0
        self.text_obj.set_text(f"Prediction:\n{pred_label}\n({confidence:.1f}%)")
        for bar, p in zip(self.prob_bars, probs):
            bar.set_width(float(p))
        self.fig.canvas.draw_idle()

    def show(self):
        """Print usage instructions and display the window."""
        print("Draw with mouse, press 'c' to clear. Prediction + probability graph update automatically.")
        plt.show()
if __name__ == "__main__":
    # Keep a module-level reference: matplotlib's mpl_connect holds only weak
    # references to bound-method callbacks, so the drawer must stay alive for
    # the window to keep responding to mouse and keyboard input.
    drawer = DigitDrawer(model_path="mnist_cnn_model.h5")
    drawer.show()




Loaded model from mnist_cnn_model.h5
Draw with mouse, press 'c' to clear. Prediction + probability graph update automatically.
