# Digit Classification App

In [10]:
import warnings
warnings.filterwarnings('ignore')

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import Model
from tensorflow.keras.datasets import mnist
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Dropout, Flatten
from tensorflow.keras.layers import Conv2D, MaxPooling2D
from tensorflow.keras import backend as K
import numpy as np
import pandas as pandas
from sklearn.model_selection import KFold
#import import_ipynb


import time, sys
from IPython.display import clear_output

%matplotlib inline
%matplotlib widget

import matplotlib as matplotlib
import matplotlib.pyplot as pyplot
from matplotlib.pyplot import figure
from matplotlib.pyplot import legend as Legend
from matplotlib.widgets import Button as mButton
from matplotlib.image import *
import matplotlib.widgets as m_widgets
import matplotlib.animation
import ipywidgets as widgets
from ipywidgets import interactive as interactive
from IPython.display import display
#from tensorflow.keras import backend as K
from tensorflow.keras.models import load_model
import random
from numpy.random import rand
import ipympl as ipl
from sklearn.manifold import TSNE
from sklearn import decomposition
from sklearn.preprocessing import StandardScaler
import seaborn as seaborn


In [6]:
#load training data + labels and testing data + labels
(train_data, train_labels), (test_data, test_labels) = mnist.load_data()

In [7]:
#load trained model
model = load_model('../digit_classification_model.h5')

In [8]:
#Loads the model history 
history=np.load('../my_history.npy',allow_pickle='TRUE').item()

In [5]:
def reshape_images(test_data, test_labels, train_data, train_labels):
    
    train_data = train_data.reshape(train_data.shape[0], 28, 28, 1)
    test_data = test_data.reshape(test_data.shape[0], 28, 28, 1)
    
    image_shape = (28, 28, 1)
    train_labels = keras.utils.to_categorical(train_labels)
    test_labels = keras.utils.to_categorical(test_labels)
    
    train_data = train_data.astype('float32')
    test_data = test_data.astype('float32')
    
    train_data /= 255
    test_data /= 255

In [6]:
reshape_images(test_data, test_labels, train_data, train_labels)

In [7]:
def select_image(test_data, index):
    
    '''Classifies the clicked image and displays the digit'''    
    
    image = test_data[index].reshape(1, 28, 28, 1)
    result = model.predict(image)
    return np.argmax(result), max(result)

In [8]:
def show_eval_score(test_data, test_labels, train_data, train_labels, rate_id):

    '''Dispalys the accuracy or loss percentage''' 

    train_data = train_data.reshape(train_data.shape[0], 28, 28, 1)
    test_data = test_data.reshape(test_data.shape[0], 28, 28, 1)
    
    image_shape = (28, 28, 1)

    train_labels = keras.utils.to_categorical(train_labels)
    test_labels = keras.utils.to_categorical(test_labels)
    
    train_data = train_data.astype('float32')
    test_data = test_data.astype('float32')
    
    train_data /= 255
    test_data /= 255


  
    rate = model.evaluate(test_data, test_labels, verbose=0)
    

    if(rate_id == 0):
        return rate[0]*100
    else:
        return rate[1]*100

In [9]:
class figure_image(object):
    
    
    def __init__(self, fig, index, test_data):
        '''A dummy class for storing references to figures'''

        self.fig = fig #Pyplot figure
 
        self.index = index #Image index in MNIST Test dataset
    
        self.test_data = test_data #MNIST test dataset, object copy
      
        self.ax = pyplot.gca() #Get current axes (rather than instantiate a new one)

        self.reset_button = widgets.Button(description="Reset", fontsize=16) #Clears classification output and resets
        
        self.reset_button.on_click(self.reset_button_clicked) #Function to reset 
        
                     
        self.label = self.ax.set_title("Click image to classify", fontsize=18) #Title of fig; classification displayed here
        
        self.cid = self.fig.canvas.mpl_connect('button_press_event', self.image_clicked) #Classify image when clicked
        
        
    def image_clicked(self, event):
        
        '''When the figure is clicked, the model is called to classify the image and output the predicted digit'''
        
        self.ax.set_title(select_image(test_data, self.index)[0], fontsize=18)
    
    def reset_button_clicked(self, event):
        '''Clears the output of a classification event'''
        
        
        self.ax.set_title("Click image to classify", fontsize=18)
        

        
    def show(self):
        pyplot.show()
        self.fig.show()

In [10]:
def create_progress_bar():
    
    '''Creates a standard progress bar to dispay during long computations.'''    
    
    progress_bar = widgets.IntProgress(value=0,
    min=0,
    max=10,
    step=1,
    description='Loading:',
    bar_style='', 
    orientation='horizontal')
    count = 0
    
    return progress_bar

In [11]:
def load_progress_bar(progress_bar, rate_id):
    
    '''Increments a progress bar while the accuracy or loss rate is being evaluated'''
    
    count=0
    while count <= 3:
        progress_bar.value += 1
        time.sleep(.1)
        count+=1

    score = show_eval_score(test_data, test_labels, train_data, train_labels, rate_id)

    while count <= 10:
        progress_bar.value += 1 # signal to increment the progress bar
        time.sleep(.4)
        count += 1
        
    return score

In [12]:
def get_loss_rate(event, button, vbox):
    
    '''Evaluates the loss of the model and displays it. If loss rate is already displayed,
    it is cleared and reset.'''    
    
    percent_label = widgets.Label(value='')
    total_tested_label = widgets.Label(value='')
    total_loss_label = widgets.Label(value='')
    
    
    
    
    progress_bar = create_progress_bar()
    
    if(len(button.description) > 6):
        
        vbox.children += (progress_bar, percent_label, total_tested_label, total_loss_label,)
        
        score = load_progress_bar(progress_bar, 0)
        total_tested = len(test_data)
        total_loss = int(len(test_data) * (score/100))
        
        percent_label.value = str(score)+' % loss'
        total_tested_label.value = "Number of images tested: "+str(total_tested)
        total_loss_label.value = "Number of missed predictions: "+str(total_loss)
        
        button.description = "Clear"
        
        progress_bar.value = 10
        time.sleep(.2)
        progress_bar.layout.visibility = 'hidden'
        
        
    else:
        button.description = "Show Loss Rate:"
        vbox.children=[]



In [13]:
def get_accuracy_rate(event, button, vbox):
    
    '''Evaluates the accuracy of the model and displays it. If accuracy rate is already displayed,
    it is cleared and reset.'''    
    
    percent_label = widgets.Label(value='')
    total_tested_label = widgets.Label(value='')
    total_correct_label = widgets.Label(value='')
    
    
    progress_bar = create_progress_bar()
    
    
    
    if(len(button.description) > 6):
        
        vbox.children += (progress_bar, percent_label, total_tested_label, total_correct_label) 
        score = load_progress_bar(progress_bar, 1)
        
        total_correct = int(len(test_data)*(score/100))
        total_tested = len(test_data)
        
        percent_label.value = str(score)+ "% accuracy"
        total_tested_label.value = "Number of images tested: "+str(total_tested)
        total_correct_label.value = "Number of correct predictions: "+str(total_correct)
        
        button.description = "Clear"


        progress_bar.value=10
        time.sleep(.2)
        progress_bar.layout.visibility = 'hidden'
        
       

    else:
        button.description = "Show Accuracy Rate:"
        vbox.children=[]

In [14]:
def get_images():
    
    
    '''Randomly selects 10 images from the test dataset. Each image can be run through the model 
    for classification by clicking it, and they are each paired with a reset button to clear the 
    output of the classification. The images and corresponding reset buttons are placed in a container, and each 
    container is placed inside a larger container to be displayed upon app initialization.'''

    hbox_list = []
    hbox_list2 = []
    
    num_samples = 10
    
    indexes = random.sample(range(len(test_data)), num_samples)

    
    for i in range(len(indexes)):
        output = widgets.Output()
        with output:
            fig = pyplot.figure()
        pyplot.imshow(test_data[indexes[i]], cmap=pyplot.get_cmap('gray'), picker=True)
        new_test = figure_image(fig, indexes[i], test_data)
        fig.canvas.toolbar_position = 'bottom'
            
            
        hbox = widgets.HBox([new_test.reset_button, output])
        hbox_list.append(hbox)
        
        if (i+1)%2 == 0:
            n_hbox = widgets.HBox([])
            for k in range(i-1, i+1):
                n_hbox.children += (hbox_list[k],)
            hbox_list2.append(n_hbox)
            
    vbox = widgets.VBox([])
    
    for box in range(0, len(hbox_list2)):
        vbox.children += (hbox_list2[box],)
    return vbox
        
        #display(widgets.HBox([new_test.reset_button, output]))

In [15]:
def refresh_images(event, vbox):
     
    #for i in range(1, 11):
     #   pyplot.close(i)
        
    pyplot.close('all')
    
    new_vbox = get_images()
    
    
    vbox.children=new_vbox.children

In [16]:
def show_loss_button():
    
    '''Creates and displays a button for evaluating the loss rate of the model'''    
    
    style = {'description_width': 'initial'}

    loss_button = widgets.Button(description="Click to Show Loss Rate:", style=style)
    
    loss_button.layout.height = '60px'
    loss_button.layout.width = '50%'
    
    loss_button.style.button_color="palegoldenrod"
    
    res_vbox = widgets.VBox([])

    
    loss_hbox = widgets.HBox([loss_button, res_vbox])
    loss_button.on_click(lambda event: get_loss_rate('button_press_event', loss_button, res_vbox))
    
    return loss_hbox

In [17]:
def show_acc_button():
    
    
    '''Creates and displays a button for evaluating the accuracy rate of the model'''
    
    style = {'description_width': 'initial'}

    acc_button = widgets.Button(description="Click to Show Accuracy Rate:", style=style)
    
    acc_button.layout.height = '60px'
    acc_button.layout.width = '50%'
    
    acc_button.style.button_color="palegoldenrod"
    
    res_vbox = widgets.VBox([])
    
    
    acc_hbox = widgets.HBox([acc_button, res_vbox])
    acc_button.on_click(lambda event: get_accuracy_rate('button_press_event', acc_button, res_vbox))
    
    return acc_hbox

In [9]:
def plot_pca_cluster():
    
    '''Plot the principal component analysis of the training data. The data is compressed into two dimensions, 
    and the resulting plot clusters images with similar features.'''

    output = widgets.Output()
    #l_pca = widgets.Label(value="Plotting Principle Component Analysis...")


    
    with output:
        fig = pyplot.figure()
    
    test_data_csv=pandas.read_csv("../mnist_test.csv")
    
    label=test_data_csv['label']
    drop = test_data_csv.drop('label',axis=1)
    
    pca_data = drop
    
    pca= decomposition.PCA()

    standard_scaler_data= StandardScaler().fit_transform(pca_data.astype(int))

    pca.n_components = 2
    fitted_data= pca.fit_transform(standard_scaler_data)


    fitted_data=np.vstack((fitted_data.T,label)).T


    data_frame = pandas.DataFrame(data=fitted_data, columns=("1st Principal", "2nd Principal", "Key"))

    seaborn.FacetGrid(data_frame, hue='Key',height=6).map(pyplot.scatter,'1st Principal','2nd Principal').add_legend()
    
    

In [None]:
def plot_acc_curve():
    
    '''Plot the classification accuracy in the training set. 
    The accuracy is inerpreted as the percentage of correctly classified images during training.'''    
    
    output = widgets.Output()
    with output:
        fig = pyplot.figure()
    pyplot.gca().set_title("Classification Accuracy Curve")
    pyplot.plot(history['val_acc'], color='tab:pink')
   
    l_acc = widgets.Label(value="Plotting Classification Accuracy Curve...")
    
    vbox = widgets.VBox([l_acc, output])
    
    return vbox

    
    

In [20]:
def plot_loss_curve():
    
    '''Plot the classification loss in the training set.
    The loss is interpreted as the percentage of classification errors during training.'''
    
    output=widgets.Output()
    with output:
        fig = pyplot.figure()
    pyplot.gca().set_title("Cross Entropy Loss Curve")
    pyplot.plot(history['val_loss'], color='tab:purple')
    
    l_loss = widgets.Label(value="Plotting Classification Cross Entropy Loss Curve...")

    
    vbox = widgets.VBox([l_loss, output])
    return vbox

In [21]:
def plot_acc_analysis():
    

    '''Plot the validation accuracy against the training accuracy for comparison.
    The validation accuracy represents the percentage of correctly classified images
    in the validation dataset, which contains images that the model has never seen before.'''
    
    output=widgets.Output()
    with output:
        fig = pyplot.figure()
    pyplot.gca().set_title("Classification Accuracy Analysis")
    pyplot.plot(history['acc'], color='tab:cyan')
    pyplot.plot(history['val_acc'], color='tab:red')
    
    
    l_acc_analysis = widgets.Label(value="Plotting Classification Accuracy Analysis...")

    vbox = widgets.VBox([l_acc_analysis, output])
    return vbox

In [22]:
def plot_loss_analysis():
    
    '''Plot the validation loss against the training loss for comparison.
    The validation loss represents the percentage of incorrectly classified images in the validation dataset,
    which contains images that the model has never seen before.'''

    output=widgets.Output()
    with output:
        fig = pyplot.figure()
    pyplot.gca().set_title("Cross Entropy Loss Analysis")
    pyplot.plot(history['loss'], color='tab:purple')
    pyplot.plot(history['val_loss'], color='tab:green')
    
    l_loss_analysis = widgets.Label(value="Plotting Classification Cross Entropy Loss Analysis...")

    vbox = widgets.VBox([l_loss_analysis, output])
    return vbox

In [23]:
def plot_digit_frequency():
    
    '''Plots the number of occurrences of each digit in the test data'''

    digits = ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9']
    
    digit_freq = [0]*10
    
    for label in test_labels:
        digit_freq[label]+=1
    output = widgets.Output()
    
    with output:
        fig = pyplot.figure()
   
    pyplot.gca().set_title("Frequency of Digits in Test Data")
    pyplot.gca().set_xlabel("Digit")
    pyplot.gca().set_ylabel("Frequency")
    pyplot.gca().set_yticks(np.arange(0, 1140, 100))
    
    
    pyplot.bar(digits, digit_freq, color=['tab:blue', 'tab:orange', 'tab:green', 'tab:red', 
                                          'tab:purple', 'tab:brown', 'tab:pink', 'tab:gray', 'tab:olive', 'tab:cyan'])

    
    l_digit_freq = widgets.Label(value="Plotting digit frequency...")
    
    vbox=widgets.VBox([l_digit_freq, output])
    
    return vbox

In [24]:
def first_progr_bar_display(progress_bar, description):
    
    '''Initialize a progress bar to signal that figures are currently loading'''
    
    style={'description_width': 'initial'}

    progress_bar.description=description
    progress_bar.style=style
    
    display(progress_bar)
    
    count=0
    while count <= 3:
        progress_bar.value += 1
        time.sleep(.1)
        count+=1
    return progress_bar

In [25]:
def second_progr_bar_display(progress_bar):
    
    '''Moves progress bar to the end to signal loading completion'''
    
    
    count=3
    while count <= 10:
        progress_bar.value += 1 
        time.sleep(.4)
        count += 1

    progress_bar.value = 10
    time.sleep(.3)
    progress_bar.layout.visibility='hidden'


In [26]:
#THE WORKING CODE

def start_app():
    pyplot.close('all')    
####################################################################################################
    
    style={'description_width': 'initial'}   
####################################################################################################

    
    progress_bar1 = create_progress_bar() #Progress bar to track loading of dataset sample images
    
    first_progr_bar_display(progress_bar1, "Loading Dataset Samples...")  
####################################################################################################
    
    
    dataset_label = widgets.Label(value= "Sample data; click images to classify digit")
    
####################################################################################################
  

    images_vbox = get_images() #Container of 10 randomly selected test data samples
    
    generate_new_samples_button = widgets.Button(description="Click to generate new data samples") #Randomly generate 10 new samples
    
    generate_new_samples_button.layout.height='50px'
    
    generate_new_samples_button.layout.width='30%'
    
    generate_new_samples_button.style.button_color="lightsalmon"
    
    generate_new_samples_button.on_click(lambda event: refresh_images('button_press_event', images_vbox))
    
    display(generate_new_samples_button) #Display sample re-generation button
    
    dataset_vbox = widgets.VBox([dataset_label, images_vbox]) #Container with dataset label and data sample container

    display(dataset_vbox) #Display the label and data sample images
    

    ##################################################################################################################
        
    acc_hbox = show_acc_button() #Button/label pair to show accuracy rate
    
    loss_hbox = show_loss_button() #button/label pair to show loss rate
    
    rate_vbox = widgets.VBox([acc_hbox, loss_hbox]) #vbox with both buttons
    
    display(rate_vbox)
    
    ######################################################################################################################
    

    progress_bar2 = create_progress_bar() #Progress bar to track loading of graphs
    first_progr_bar_display(progress_bar2, "Plotting Graphs...")
        

    digit_freq_box = plot_digit_frequency()
    display(digit_freq_box)
        
    acc_box = plot_acc_curve()
    display(acc_box)
    
    loss_box = plot_loss_curve()
    display(loss_box)
    
    acc_analysis_box = plot_acc_analysis()
    display(acc_analysis_box)
    
    loss_analysis_box = plot_loss_analysis()
    display(loss_analysis_box)
    
    display(widgets.Label(value="Plotting Principal Component Analysis Cluster...", style=style))
    display(plot_pca_cluster()) #Plot PCA scatter plot
    
    
    #Move progress bar to end after rendering is complete
    second_progr_bar_display(progress_bar2) 
    second_progr_bar_display(progress_bar1)
    
    
    ######################################################################################################################

    
    
    
    ######################################################################################################################


    ######################################################################################################################

    
    
    

In [27]:
pyplot.close('all') #Close all figures
start_app()

IntProgress(value=0, description='Loading Dataset Samples...', max=10, style=ProgressStyle(description_width='…

Button(description='Click to generate new data samples', layout=Layout(height='50px', width='30%'), style=Butt…

VBox(children=(Label(value='Sample data; click images to classify digit'), VBox(children=(HBox(children=(HBox(…

VBox(children=(HBox(children=(Button(description='Click to Show Accuracy Rate:', layout=Layout(height='60px', …

IntProgress(value=0, description='Plotting Graphs...', max=10, style=ProgressStyle(description_width='initial'…

VBox(children=(Label(value='Plotting digit frequency...'), Output()))

VBox(children=(Label(value='Plotting Classification Accuracy Curve...'), Output()))

VBox(children=(Label(value='Plotting Classification Cross Entropy Loss Curve...'), Output()))

VBox(children=(Label(value='Plotting Classification Accuracy Analysis...'), Output()))

VBox(children=(Label(value='Plotting Classification Cross Entropy Loss Analysis...'), Output()))

Label(value='Plotting Principal Component Analysis Cluster...', style=DescriptionStyle(description_width='init…

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

None