<a href="https://colab.research.google.com/github/actionsolve/article25/blob/master/20191200_Webcam_acquire.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Demo Webcam acquisition

## Todo
  - Retrain CNN

## Done
  - capture image sequence
  - display thumbnails
  - Layouts
  
    

## Imports

In [0]:
# Colab-only magic: select the TensorFlow 1.x runtime. Later cells use TF1-only APIs
# such as tf.logging, together with the standalone `keras` package.
%tensorflow_version 1.x

TensorFlow 1.x selected.


In [0]:
# Imports
import datetime as dt  ; 
import os
import sys              ; print(f'Python ver    ver: %s.%s.%s   %s' % (*sys.version_info[0:3], sys.platform))
import threading
import time

from cv2 import *       ; print(f'OpenCV        ver: {cv2.__version__}')
import numpy as np      ; print(f'Numpy         ver: {np.__version__}')
import pandas as pd     ; print(f'Pandas        ver: {pd.__version__}' )

import ipywidgets as widgets
#from ipywidgets import Button, GridBox, Layout, ButtonStyle
from google.colab import widgets as colab_widgets  #  For working tabs

from IPython.display import Image, HTML, display, display_html, Javascript, Markdown, clear_output
#from PIL import Image                            XXX BEWARE also using Image from HTML
from google.colab.output import eval_js
from base64 import b64decode

import PIL
import io
from base64 import b64encode

if False:
    import tensorflow as tf ; print(f'Tensorflow    ver: {tf.__version__}' )  # native: ver: 2.2.0-rc1
    import tensorflow.keras as keras           ; print(f'Keras         ver: {keras.__version__}' )  #  ver: 2.2.4-tf

    #from tensorflow.keras.layers import Input, Dense
    from tensorflow.keras.models import model_from_json
    # from tensorflow.keras import models, layers

    from tensorflow.keras.preprocessing.image import load_img, img_to_array
    from tensorflow.keras.preprocessing.image import ImageDataGenerator

    from tensorflow.keras.applications.vgg16 import preprocess_input
    from tensorflow.keras.applications.vgg16 import decode_predictions
    from tensorflow.keras.applications.vgg16 import VGG16
    from tensorflow.keras import backend as K

    #import tensorflow.python.util.deprecation as deprecation

else:
    import tensorflow as tf ; print(f'Tensorflow    ver: {tf.__version__}' )  # native: ver: 2.2.0-rc1
    import keras           ; print(f'Keras         ver: {keras.__version__}' )  # native ver: 2.2.5

    #from keras.layers import Input, Dense
    from keras.models import model_from_json

    from keras.preprocessing.image import load_img, img_to_array
    from keras.preprocessing.image import ImageDataGenerator

    from keras.applications.vgg16 import preprocess_input
    from keras.applications.vgg16 import decode_predictions
    from keras.applications.vgg16 import VGG16
    from keras import backend as K

    #import tensorflow.python.util.deprecation as deprecation

## Tools - Image Acquisition

In [0]:
# Webcam acquisition tools
# Adapted from the Colab "advanced outputs" snippets notebook:
# https://colab.research.google.com/notebooks/snippets/advanced_outputs.ipynb#scrollTo=buJCl90WhNfq
def get_html_live_video(width=640, height=480):
    """
    Build an HTML fragment that shows the live webcam feed.

    The fragment is a ``<video id="video">`` element of the requested size, followed by a
    ``<script>``. The script requests a video-only stream through ``getUserMedia`` and
    attaches it to that element. It does nothing when
    ``navigator.mediaDevices.getUserMedia`` is unavailable.

    Returns the fragment as a string; wrap it in ``IPython.display.HTML`` to render it.
    """
    start_script = """
    <script>
        // Display webcam in video tag.  No audio required
        if(navigator.mediaDevices && navigator.mediaDevices.getUserMedia) {
            navigator.mediaDevices.getUserMedia({ video: true, audio: false }).then(function(stream) {
                let video = document.getElementById('video');
                video.srcObject = stream;
                video.play();
            });
        }
    </script>
    """
    video_element = '<video id="video" width="{}" height="{}" autoplay></video>'.format(width, height)
    return ''.join((video_element, start_script))

def acquire_webcam_image_to_file(filename, image_dims=(640, 480), await_button_click=True, quality=0.8, verbose=False):
    """
    Grab one webcam frame in the browser, resize it and save it to ``filename``.

    A temporary ``<video>`` preview is injected into the Colab output area. When
    ``await_button_click`` is true, a 'Capture Image' button is added as well. The frame
    is drawn onto a canvas and returned to Python as a base64 JPEG data URL. It is then
    decoded with PIL, resized to ``image_dims`` if it differs, and written to disk.

    Args:
        filename: destination path; PIL picks the output format from the extension.
        image_dims: (width, height) of the saved image.
        await_button_click: wait for the user to press 'Capture Image' before grabbing.
        quality: JPEG quality in 0..1, used by the browser's ``canvas.toDataURL``.
        verbose: print progress messages.

    Returns:
        (filename, width, height) of the saved image.

    Raises:
        Errors from ``eval_js`` propagate. Per the notes elsewhere in this notebook, this
        happens when there is no webcam or camera permission is refused.
    """
    js = Javascript('''
        async function takePhoto(await_button_click, quality) {

            const div = document.createElement('div');
            const capture = document.createElement('button');
            if(await_button_click) {               
                capture.textContent = 'Capture Image';
                div.appendChild(capture);
            }

            const video = document.createElement('video');
            video.style.display = 'block';
            const stream = await navigator.mediaDevices.getUserMedia({video: true});
            document.body.appendChild(div);
            div.appendChild(video);
            video.srcObject = stream;
            await video.play();

            // Add button after image canvas
            if(await_button_click) {
                div.appendChild(capture);
            }

            // Resize the output to fit the video element.
            google.colab.output.setIframeHeight(document.documentElement.scrollHeight, true);

            // Wait for Capture to be clicked.
            if(await_button_click) {
                await new Promise((resolve) => capture.onclick = resolve);
            }

            // Draw video image to canvas, to capture
            const canvas = document.createElement('canvas');
            canvas.width = video.videoWidth;
            canvas.height = video.videoHeight;
            canvas.getContext('2d').drawImage(video, 0, 0);

            // Stop video
            stream.getVideoTracks()[0].stop();
            div.remove();

            return canvas.toDataURL('image/jpeg', quality);
        }
    ''')
    display(js)
    data_url = eval_js(f'takePhoto({str(await_button_click).lower()}, {quality})')

    # A data URL has the form "data:image/jpeg;base64,<payload>"; split it once.
    header, payload = data_url.split(',', 1)
    if verbose: print(f'    Acquired {len(payload)} bytes, type: {header}')
    image = PIL.Image.open(io.BytesIO(b64decode(payload)))
    old_width, old_height = image.size

    # Resize only when the browser frame differs from the requested size.
    new_width, new_height = image_dims
    if (old_width, old_height) != (new_width, new_height):
        if verbose: print(f'    Resizing image {old_width} x {old_height}  ->  {new_width} x {new_height}')
        # LANCZOS is the filter the old PIL.Image.ANTIALIAS alias pointed to.
        # ANTIALIAS was removed in Pillow 10, so using it raises AttributeError there.
        image = image.resize((new_width, new_height), PIL.Image.LANCZOS)

    # Save to file
    image.save(filename, optimize=True, quality=95)
    if verbose: print(f'    Saved to {filename}    {image.size[0]} x {image.size[1]}')
    return filename, new_width, new_height

def stop_live_webcam_video():
    """
    Stop the webcam tracks attached to the page's ``<video id="video">`` element.

    That element is the one produced by ``get_html_live_video``. If the element is
    missing, or no stream is attached to it, the JavaScript returns without doing
    anything instead of throwing.
    """
    js = Javascript('''
        function stop_video() {
            console.log("Stopping HTML video, via JS " );
            const video = document.getElementById('video');
            // Guard: the element may not exist, or its stream may already be detached.
            if (!video || !video.srcObject) {
                return;
            }
            const stream = video.srcObject;
            stream.getVideoTracks().forEach(track => track.stop());
        }
    ''')
    display(js)
    eval_js('stop_video()')
    
# Manual webcam tests: they need a browser webcam. Set a flag to True to run that test.
RUN_WEBCAM_CAPTURE_TEST = False
RUN_LIVE_VIDEO_STOP_TEST = False

if RUN_WEBCAM_CAPTURE_TEST:
    try:
        print('Started')
        # Show live video
        display(HTML(get_html_live_video(width=100, height=100)))

        # Acquire a frame at the default size
        filename, img_width, img_height = acquire_webcam_image_to_file('photo.big.jpg', await_button_click=False, verbose=True)
        display(Image(filename))

        # Acquire a frame resized to the model input size
        filename, img_width, img_height = acquire_webcam_image_to_file('photo.sml.jpg', (224, 224), await_button_click=False, verbose=True)
        display(Image(filename))

    except Exception as err:
        # Errors will be thrown if the user does not have a webcam or if they do not
        # grant the page permission to access it.
        print(str(err))

if RUN_LIVE_VIDEO_STOP_TEST:
    print('Started')
    display(HTML(get_html_live_video(width=100, height=100)))
    print('waiting')
    time.sleep(3)
    stop_live_webcam_video()
    print('stopped')


In [0]:
# Simple display of thumbnail images from filename list, in a row, left to right.  Keep to < 10 for visibility
def get_image_as_html_tag(filename, width=200, height=200):
    """
    Return an ``<img>`` tag that embeds the image at ``filename`` inline as base64 PNG.

    Args:
        filename: path of any image PIL can open.
        width, height: display size in CSS pixels. The embedded data keeps its original
            resolution and the browser scales it.

    Returns:
        The HTML tag as a string. It is floated left with a 5px margin and a thin black
        border, so consecutive tags line up as a row of thumbnails.
    """
    # Re-encode as PNG in memory; the context manager closes the file handle PIL opens.
    with PIL.Image.open(filename) as image:
        png_buffer = io.BytesIO()
        image.save(png_buffer, format='png')
    image_data = b64encode(png_buffer.getvalue()).decode('utf-8')

    # The ';' after the height matters. Without it the browser parses
    # "height: Npx margin: 5px" as one invalid declaration and drops both.
    return (f"<img style='width: {width}px; height: {height}px; margin: 5px; float: left; "
            f"border: 1px solid black;' src='data:image/png;base64,{image_data}'/>")

def get_results_as_html_tag(class_name, certainty):
    """
    Return a one-row HTML table showing a predicted class name and its certainty.

    Args:
        class_name: label to show. It is HTML-escaped because class names come from
            user-typed widget text.
        certainty: expected to be a probability in [0, 1]; shown as a percentage with
            two decimals.
    """
    import html  # local import: the notebook's import cell does not import it

    # <b> is the real bold element (<bold> is unknown to browsers).
    # A border needs a style keyword ('solid') to be drawn at all.
    return ('<table style="width:180px; border:1px solid black">'
            '<tr>'
            f'<td><big><b>{html.escape(str(class_name))}</b></big></td>'
            f'<td><big>{certainty*100:.2f} %</big></td>'
            '</tr>'
            '</table>')

def display_image_thumbnails(filenames, width=200, height=200):
    '''
    Render the given image files as one row of inline thumbnails, left to right.

    Each file becomes a ``get_image_as_html_tag`` tag of ``width`` x ``height``.
    Keep the list short (under ~10) so the row stays readable.
    '''
    tags = []
    for path in filenames:
        tags.append(get_image_as_html_tag(path, width, height))
    display(HTML(''.join(tags)))

def capture_images_for_class(base_dir, class_name, num_images_to_capture, image_dims, verbose=False):
    """
    Capture ``num_images_to_capture`` webcam frames for one class.

    Frames are saved as ``image_000.jpg``, ``image_001.jpg``, ... inside
    ``base_dir/class_name``, and the folder is created if missing. Other files already
    in the folder are not removed. ``verbose`` is accepted but currently unused.

    Returns:
        List of the saved file paths, in capture order.
    """
    class_dir = os.path.join(base_dir, class_name)
    # XXX Previous images for this class are not cleared first
    if not os.path.exists(class_dir):
        os.makedirs(class_dir)

    saved_paths = []
    for index in range(num_images_to_capture):
        target = os.path.join(class_dir, f'image_{index:03}.jpg')
        fn, w, h = acquire_webcam_image_to_file(target, image_dims, await_button_click=False)
        log_msg(f'  Captured: {fn}  {w} x {h}')
        saved_paths.append(target)
    return saved_paths

# Manual capture tests: they need a browser webcam. Set a flag to True to run that test.
RUN_THUMBNAIL_CAPTURE_TEST = False
RUN_CLASS_CAPTURE_TEST = False

if RUN_THUMBNAIL_CAPTURE_TEST:
    filenames = []
    for index in range(3):
        filename = 'aa' + str(index) + '.jpg'
        acquire_webcam_image_to_file(filename, (224, 224), await_button_click=False, verbose=True)
        filenames.append(filename)

    print(filenames)
    display_image_thumbnails(filenames, width=50, height=50)

if RUN_CLASS_CAPTURE_TEST:
    image_width, image_height = 224, 224
    class_names = ('AAA', 'BBB')
    base_dir = 'test_images'
    num_images_to_capture = 3  # Test
    image_dims = (image_width, image_height)

    filenames = []
    for class_name in class_names:
        print(f'Class {class_name}')
        filenames.extend(capture_images_for_class(base_dir, class_name, num_images_to_capture, image_dims))
        # Pause between classes (presumably to allow swapping the object in front of the camera)
        time.sleep(2)

    display_image_thumbnails(filenames, width=50, height=50)

## Tools - Display and Classification

In [0]:
# Config -- user-tunable settings
import os

# Capture / training settings
num_images_to_acquire = 2       # webcam photos taken per object (class)
min_classes_to_train = 2        # 'Learn objects' is enabled once this many classes have photos
num_training_epochs = 40
save_path = 'images'            # one sub-folder of photos per class is created here
image_width, image_height = 224, 224
image_dims = (image_width, image_height)

# Initial (pre-trained) VGG16 weights, classifier top included.
# Source: https://github.com/fchollet/deep-learning-models/releases/download/v0.1/vgg16_weights_tf_dim_ordering_tf_kernels.h5
# The 'no top' variant is 'vgg16_weights_tf_dim_ordering_tf_kernels_notop.h5' (58889256 bytes).
WEIGHTS_FILE_VGG16 = 'vgg16_weights_tf_dim_ordering_tf_kernels.h5'
WEIGHTS_FILE_SIZE = 553467096   # expected size in bytes of the file above
WEIGHTS_PATH_VGG16 = 'https://github.com/fchollet/deep-learning-models/releases/download/v0.1'

# Local weights directory; '~' is expanded to the real home directory right away.
WEIGHTS_PATH_LOCAL = os.path.expanduser('~/.keras/models/')
print(f'  WEIGHTS_PATH_LOCAL: {WEIGHTS_PATH_LOCAL}')
WEIGHTS_FILE_LOCAL = os.path.join(WEIGHTS_PATH_LOCAL, WEIGHTS_FILE_VGG16)
print(f'  WEIGHTS_FILE_LOCAL: {WEIGHTS_FILE_LOCAL}')

# Mutable globals shared by the training / identification code (candidates for refactoring)
custom_model = None
base_model = None
weights_file_available = False

In [0]:
def diag_print_dir(dir_name, tag_name=''):
    """Print ``dir_name`` and the entries ``os.listdir`` returns for it, on one diagnostic line."""
    entries = os.listdir(dir_name)
    print("    {} Dir: {},  Files: {}".format(tag_name, dir_name, entries))

In [0]:
# Debug helper: delete any cached VGG16 weights so the background download path is exercised.
# Set CLEAR_WEIGHTS_CACHE = False to keep the ~553 MB cached weights file between runs.
CLEAR_WEIGHTS_CACHE = True

if CLEAR_WEIGHTS_CACHE:
    import glob

    # Ensure path exists
    os.makedirs(WEIGHTS_PATH_LOCAL, mode=0o777, exist_ok=True)
    diag_print_dir(WEIGHTS_PATH_LOCAL, tag_name='')
    # Pure-Python equivalent of `rm -f <dir>/*`: removes non-hidden files and leaves
    # sub-directories alone. It reuses WEIGHTS_PATH_LOCAL instead of repeating the
    # hard-coded '~/.keras/models' path.
    for cached_path in glob.glob(os.path.join(WEIGHTS_PATH_LOCAL, '*')):
        if os.path.islink(cached_path) or not os.path.isdir(cached_path):
            os.remove(cached_path)
    diag_print_dir(WEIGHTS_PATH_LOCAL, tag_name='')

In [0]:
# Build custom model
def get_custom_model(base_model, num_output_classes, verbose=False):
    """
    Build a transfer-learning classifier on top of a frozen ``base_model``.

    Steps:
    1. Every layer of ``base_model`` is frozen (``trainable`` is changed in place).
    2. The output of ``base_model.layers[-3]`` is used as the feature vector. For the
       full VGG16 this drops the final two layers.
    3. A new softmax ``Dense`` layer with ``num_output_classes`` units is attached.

    Args:
        base_model: Keras model with at least three layers, e.g. ``VGG16()`` with its top.
        num_output_classes: number of classes to predict.
        verbose: print progress and the model summary.

    Returns:
        A compiled Keras model (categorical cross-entropy loss, Adadelta, accuracy metric).
    """
    if verbose: print(f'  Customising model to {num_output_classes} outputs')

    # Only the new classification head is trained.
    for layer in base_model.layers:
        layer.trainable = False

    # Lose final/top 2 layers
    features = base_model.layers[-3].output

    # New final layer with softmax activation, one unit per class
    predictions = keras.layers.Dense(num_output_classes, activation='softmax')(features)
    custom_model = keras.models.Model(inputs=base_model.input, outputs=predictions)

    # XXX Check optimisers
    custom_model.compile(loss='categorical_crossentropy', optimizer='AdaDelta', metrics=['accuracy',])

    if verbose:
        # summary() prints the table itself and returns None; print(summary()) added a stray "None".
        custom_model.summary()
    return custom_model

In [0]:
# History plots
def plot_training(loss_train, acc_train, loss_valid=None, acc_valid=None):
    """
    Plot per-epoch training curves: loss on the left, accuracy on the right.

    Training series are drawn in red. Validation series, when given, use the default
    colour. The figure is shown immediately with ``plt.show()``.
    """
    import matplotlib.pyplot as plt

    fig, (ax_loss, ax_acc) = plt.subplots(1, 2, figsize=(14, 4))
    panels = (
        (ax_loss, loss_train, loss_valid, 'Loss'),
        (ax_acc, acc_train, acc_valid, 'Accuracy'),
    )
    for ax, train_series, valid_series, metric in panels:
        ax.plot(train_series, label='Train', color='red')
        if valid_series is not None:
            ax.plot(valid_series, label='Valid')
        ax.set_xlabel('Epochs')
        ax.set_ylabel(metric)
        ax.set_title(metric)
        ax.legend()

    plt.show()

def log_msg(msg):
    """
    Print ``msg`` prefixed with a local ``YYYY-mm-dd HH:MM:SS`` timestamp.

    If a module-level ``widget_output`` exists, the line is printed inside it. That name is
    presumably an ipywidgets ``Output`` created by a later layout cell. Otherwise the
    line goes to stdout.
    """
    stamped = '{}: {}'.format(dt.datetime.now().strftime("%Y-%m-%d %H:%M:%S"), msg)
    if 'widget_output' not in globals():
        print(stamped)
        return
    with widget_output:
        print(stamped)

In [0]:
# Predict 
#def get_class_from_image(filename, class_names, verbose=True):
#    global base_model, custom_model
#    print(f'* get_class_from_image({filename}, {class_names})    model={custom_model}')
#    most_likely_class_label, max_pred = get_class_label_from_image(custom_model, filename, class_names, show_possibles=False)
#    # XXX If not good, should alos query base_model
#    return most_likely_class_label, max_pred
def get_class_label_from_image(model, filename, class_names, show_possibles=False):
    """
    Classify the image file ``filename`` with ``model`` and return the best label.

    The image is loaded at the global (image_width, image_height) size, made into a batch
    of one, and passed through ``preprocess_input`` before ``model.predict``.

    Args:
        model: trained Keras model with one output probability per entry of ``class_names``.
        filename: image path.
        class_names: labels in the same order as the model's output units.
        show_possibles: also log the probability of every class.

    Returns:
        (label, probability) of the highest-scoring class; ties go to the lowest index.

    Raises:
        AssertionError if the model's output size differs from ``len(class_names)``.
    """
    log_msg(f'* get_class_label_from_image({filename}, {class_names})    model={model}')

    # Load, convert to an array, add the batch axis and normalise.
    pixels = img_to_array(load_img(filename, target_size=(image_width, image_height)))
    batch = preprocess_input(np.expand_dims(pixels, axis=0))

    # Predict probabilities across all output classes
    scores = model.predict(batch)
    assert len(scores[0]) == len(class_names), 'Inconsistent class list, with trainer model'
    predictions = scores[0]
    log_msg(predictions)

    if show_possibles:
        log_msg(f'  Top {len(predictions)} possibles' )
        for name, probability in zip(class_names, predictions):
            log_msg(f'    {name:20}     ({probability*100:.2f}%)' )

    top = np.max(predictions)
    best = [i for i, p in enumerate(predictions) if p == top][0]
    return class_names[best], top

In [0]:
# Train model from images 
def train_model(verbose=False):
    """
    Train a fresh global ``custom_model`` on the photos stored under ``save_path``.

    Each sub-folder of ``save_path`` is one class for Keras ``flow_from_directory``
    (class indices follow the sorted folder names). The function then:
    - blocks until the global ``weights_file_available`` flag is set
      (``init_load_weights_background`` sets it);
    - builds ``VGG16()`` into the global ``base_model``;
    - adds a new softmax head via ``get_custom_model``;
    - fits it for ``num_training_epochs`` epochs.
    With ``verbose`` it also logs timings and plots the loss/accuracy history.

    Returns None. If fewer than ``min_classes_to_train`` class folders are found, it logs
    a message and returns without touching the models.
    """
    global weights_file_available, base_model, custom_model

    def _now():
        return dt.datetime.now().strftime("%Y-%m-%d %H:%M:%S")

    if verbose: log_msg(f'* train_model(from images path: {save_path})')

    # Load images into generator
    # https://machinelearningmastery.com/how-to-configure-image-data-augmentation-when-training-deep-learning-neural-networks/
    time_start = time.time()
    if verbose: log_msg(f'    Loading images from path: "{save_path}"   {_now()} ')
    # Preprocess exactly like get_class_label_from_image does at prediction time
    # (VGG16 preprocess_input). The previous rescale=1./255 trained the new head on a
    # different input range from the one it is later asked to classify.
    datagen = ImageDataGenerator(preprocessing_function=preprocess_input)

    def _flow_from_save_path():
        # Keras target_size is (height, width).
        return datagen.flow_from_directory(save_path, class_mode='categorical',
                                           target_size=(image_height, image_width))

    # XXX Hack to direct unwanted logging to 'logging' tab, if available
    if 'widget_output' in globals():
        with widget_output:
            train_generator = _flow_from_save_path()
    else:
        train_generator = _flow_from_save_path()
    class_names = list(train_generator.class_indices.keys())
    if verbose:
        log_msg(f'      Class indices map: {train_generator.class_indices}')
        log_msg(f'    Finished  {_now()}   Elapsed: {(time.time() - time_start):.3f}')

    # Check enough classes to train; stop here rather than build a degenerate model.
    if len(class_names) < min_classes_to_train:
        log_msg(f'    ** Not enough classes to train on (only found: {class_names}) **')
        return
    if verbose: log_msg(f'    Found classes: {class_names}')

    # Wait for the background weights download
    while not weights_file_available:
        if verbose: log_msg(f'    T: Waiting for weights download ... ')
        time.sleep(2)

    # Silence TF while building VGG16. The previous verbosity is saved *before* it is set
    # to ERROR, so it is really restored afterwards.
    previous_verbosity = tf.logging.get_verbosity()
    tf.logging.set_verbosity(tf.logging.ERROR)
    try:
        base_model = VGG16()  # full model including the classifier top (~500MB weights)
    finally:
        tf.logging.set_verbosity(previous_verbosity)

    # Custom model
    log_msg(f'  T: {base_model}  len(class_names): {len(class_names)}')
    custom_model = get_custom_model(base_model, num_output_classes=len(class_names), verbose=False)

    # Train
    if verbose: log_msg(f'    Training  {_now()}:  ')
    time_start = time.time()
    # fit_generator is standalone Keras 2.2.x's entry point for training from a generator.
    history = custom_model.fit_generator(train_generator, steps_per_epoch=len(train_generator),
                                         epochs=num_training_epochs, verbose=0)
    if verbose: log_msg(f'    Finished  {_now()}   Elapsed: {(time.time() - time_start):.3f}')

    # History plots
    if verbose:
        log_msg(f'    Avail history fields: {history.history.keys()}' )
        # Keras < 2.3 records accuracy under 'acc'; later versions use 'accuracy'.
        acc_key = 'acc' if 'acc' in history.history else 'accuracy'
        plot_training(history.history['loss'], history.history[acc_key])

In [0]:
# Initiate background download of weights in thread
def init_load_weights_background(verbose=False):
    """
    Make sure the VGG16 weights are available locally, loading them in a background thread.

    Fast path: if ``WEIGHTS_FILE_LOCAL`` is readable and exactly ``WEIGHTS_FILE_SIZE`` bytes,
    the global ``weights_file_available`` is set to True and the function returns at once.

    Otherwise:
    - the non-hidden files in ``WEIGHTS_PATH_LOCAL`` are deleted (dropping any partial
      download);
    - ``weights_file_available`` is set to False;
    - a thread is started that instantiates ``VGG16()`` and then sets
      ``weights_file_available`` to True. Keras fetches the weights when they are
      missing.

    ``train_model`` polls that flag. If the background load raises, the error is logged
    and the flag stays False, so anything waiting on it keeps waiting.
    """
    import glob

    if verbose: log_msg(f'* init_load_weights_background()')
    # https://arxiv.org/abs/1409.1556
    # https://machinelearningmastery.com/use-pre-trained-vgg-model-classify-objects-photographs/

    # Check if weights file already exists, and correct size
    global weights_file_available
    log_msg(f'  WEIGHTS_FILE_LOCAL: {WEIGHTS_FILE_LOCAL},  expected size: {WEIGHTS_FILE_SIZE}')
    if os.access(WEIGHTS_FILE_LOCAL, os.R_OK) and os.stat(WEIGHTS_FILE_LOCAL).st_size == WEIGHTS_FILE_SIZE:
        if verbose: log_msg(f'    Initial model weights already downloaded : {WEIGHTS_FILE_LOCAL},  {WEIGHTS_FILE_SIZE} bytes')
        weights_file_available = True
        return

    # Clean slate. Pure-Python equivalent of `rm -f $WEIGHTS_PATH_LOCAL/*`:
    # non-hidden entries only, and sub-directories are left alone.
    if verbose: log_msg(f'    Clearing models weights directory {WEIGHTS_PATH_LOCAL}')
    for cached_path in glob.glob(os.path.join(WEIGHTS_PATH_LOCAL, '*')):
        if os.path.islink(cached_path) or not os.path.isdir(cached_path):
            os.remove(cached_path)
    weights_file_available = False

    # Ensure path exists
    os.makedirs(WEIGHTS_PATH_LOCAL, mode=0o777, exist_ok=True)

    def load_weights_background():
        """Thread body: build VGG16() (fetching weights if needed), then raise the flag."""
        global base_model, weights_file_available
        if verbose: log_msg(f'    B: Loading VGG16()   {dt.datetime.now().strftime("%Y-%m-%d %H:%M:%S")}  ')
        time_start = time.time()
        # Save the previous verbosity *before* silencing TF, so it is really restored.
        previous_verbosity = tf.logging.get_verbosity()
        tf.logging.set_verbosity(tf.logging.ERROR)
        try:
            K.clear_session()  # Does NOT delete pretrained weights
            base_model = VGG16()  # include_top=True: full weights file
        except Exception as err:
            # An uncaught exception would only reach the thread's stderr, which is easy to miss.
            log_msg(f'  B: ** Loading VGG16 weights failed: {err!r} **')
            return
        finally:
            tf.logging.set_verbosity(previous_verbosity)
        if verbose: log_msg(f'    B: Finished loading VGG16 {dt.datetime.now().strftime("%Y-%m-%d %H:%M:%S")}   Elapsed: {(time.time() - time_start):.3f} ')

        K.clear_session()  # Does NOT delete pretrained weights
        log_msg(f'  B: {base_model}')
        weights_file_available = True

    # Start background download
    threading.Thread(target=load_weights_background).start()
    #load_weights_background(*args_list)

In [0]:
# Test init, train, predict
# Disabled end-to-end smoke test. It starts the weights load, optionally downloads sample
# fruit images into save_path/yyy and save_path/zzz, trains, then classifies one sample
# image. Change False to True to run it.
if False:
    print(f'Starting')

    # Init models
    # NOTE(review): 'global' at module level has no effect; these names are already module globals.
    global base_model, custom_model
    global weights_file_available  ; weights_file_available = False
    init_load_weights_background(verbose=True)
 
    # When False, the sample images are assumed to already exist under save_path/yyy and save_path/zzz.
    download_source_images = False
    # Setup
    class_names = ('yyy', 'zzz')
    if download_source_images:
        # Clean out previous
        !/bin/rm -rf $save_path
        def download_source_file(class_name, file_names, path_src):
            # Fetch each file with wget into the working directory, then move it into
            # save_path/<class_name> (IPython shell escapes).
            target_path = os.path.join(save_path, class_name) 
            os.makedirs(os.path.join(target_path), exist_ok=True)
            for file_name in file_names:
                #!wget https://raw.githubusercontent.com/Horea94/Fruit-Images-Dataset/master/Test/Apple%20Braeburn/321_100.jpg
                full_url = os.path.join(path_src, file_name)
                !/usr/bin/wget $full_url
                !/bin/mv $file_name $target_path
        # Explore:  https://github.com/Horea94/Fruit-Images-Dataset/tree/master/Training/Apple%20Braeburn
        download_source_file('yyy', ('100_100.jpg', '123_100.jpg', '307_100.jpg'), 
                            'https://github.com/Horea94/Fruit-Images-Dataset/raw/master/Training/Apple%20Braeburn/')
        download_source_file('zzz', ('104_100.jpg', '122_100.jpg', '131_100.jpg', ), 
                            'https://github.com/Horea94/Fruit-Images-Dataset/raw/master/Training/Banana/')

    # Block until the background thread reports the weights are loaded.
    while not weights_file_available:
        print(f'    I: Waiting for weights download ... ')
        time.sleep(2)

    # Train models
    print(f'Training model')
    train_model(verbose=True)

    # Classify one of the 'zzz' sample images with the freshly trained model.
    filename = os.path.join(save_path, 'zzz/104_100.jpg')
    # most_likely_class_label, max_pred = get_class_from_image(filename, class_names, verbose=True)
    most_likely_class_label, max_pred = get_class_label_from_image(custom_model, filename, class_names, show_possibles=False)
    print(f'{most_likely_class_label} ({max_pred*100:.3f}%)' )
    
    print(f'Finished')

In [0]:
# Tools
def rem_dir(base_dir, class_name=None, verbose=False):
    """
    Recursively delete ``base_dir/class_name``, or ``base_dir`` itself if ``class_name`` is falsy.

    A missing directory is not an error; with ``verbose`` it is reported instead.
    """
    import shutil

    target = base_dir if not class_name else os.path.join(base_dir, class_name)
    if not os.path.exists(target):
        if verbose: print(f'    Missing dir: {target}')
        return
    if verbose: print(f'    Removing {target}')
    shutil.rmtree(target)

def get_only_classes_with_images(base_dir, class_names, min_num_images, verbose=False):
    """
    Keep the class names whose folder ``base_dir/<name>`` holds at least ``min_num_images``
    files ending in '.jpg'.

    Missing folders are skipped. Input order is preserved, and a new list is returned.
    """
    def _jpg_files(folder):
        # XXX Only the extension is checked, not the real file type
        return [name for name in os.listdir(folder) if name.endswith('.jpg')]

    kept = []
    for name in class_names:
        folder = os.path.join(base_dir, name)
        if verbose: log_msg(f'  Seaching: {folder}')
        if os.path.exists(folder):
            jpgs = _jpg_files(folder)
            if verbose: log_msg(f'    Found: {jpgs}')
            if len(jpgs) >= min_num_images:
                kept.append(name)
        elif verbose:
            print(f'    Missing dir: {folder}')
    return kept

def get_class_list_from_widgets(v_box_classes, verbose=False):
    """
    Collect the stripped, non-empty class names typed into the rows of ``v_box_classes``.

    Each child of ``v_box_classes`` is a row box whose first child is the name ``Text``
    widget. Rows with blank names are skipped. Returns a list in row order.
    """
    names = []
    for row in v_box_classes.children:
        candidate = row.children[0].value.strip()
        if not candidate:
            if verbose: log_msg(f'  Empty class name - ignored')
            continue
        if verbose: log_msg(f'  Adding class name [{candidate}]')
        names.append(candidate)
    return names

def get_valid_class_list(base_dir, v_box_classes, verbose=False, min_num_images=None):
    """
    Return the class names typed into the widget rows that also have enough photos on disk.

    Args:
        base_dir: folder holding one sub-folder of '.jpg' photos per class.
        v_box_classes: VBox of class rows (first child of each row is the name widget).
        verbose: log progress.
        min_num_images: photos a class needs to count as valid. Defaults to
            ``num_images_to_acquire``, the number of photos captured per class.
    """
    if min_num_images is None:
        # The image threshold used to be min_classes_to_train, which is a class count, not
        # an image count. The photos-per-class setting is the matching quantity.
        min_num_images = num_images_to_acquire
    class_names = get_class_list_from_widgets(v_box_classes, verbose)
    return get_only_classes_with_images(base_dir, class_names, min_num_images, verbose)

# Manual check of the class-folder filter; it needs captured photos under test_images/.
RUN_CLASS_FILTER_TEST = False

if RUN_CLASS_FILTER_TEST:
    # Expects 3 class dirs, at least 2 of them with images present
    base_dir = 'test_images'
    class_names = ('AAA', 'BBB', 'kkk') ; print(f'  classes 0: {class_names}')
    class_names = get_only_classes_with_images(base_dir, class_names, min_num_images=3, verbose=True)
    print(f'  classes 1: {class_names}')

    rem_dir(base_dir, class_name='kkk', verbose=True)

In [0]:
# Handlers
def handler_add_widgets_for_new_class(v_box_classes, row_num, num_images_to_acquire):
    """
    Append one "object" row to ``v_box_classes`` and wire up its handlers.

    The row holds a name ``Text``, a 'Take Photos' button, a thumbnail strip and a 'Clear'
    button. Its handlers update the module-level ``model_trainable`` / ``model_trained``
    flags and the global ``btn_train`` / ``btn_identify`` buttons. It reads a
    module-level ``model_initialised`` flag, assumed to be defined by a setup cell.
    ``row_num`` is currently unused.
    """
    log_msg(f'Adding new object to train')

    global model_initialised, model_trainable, model_trained
    if not model_initialised:
        log_msg(f'Initialising model ...')
        # XXX
        model_initialised = True

    # Widgets for one row: name, capture button, thumbnail strip, clear button
    txt_class_name = widgets.Text(value='', placeholder='Type a name for the new object', description='Object Name:', disabled=False)
    btn_acquire = widgets.Button(description='Take Photos', button_style='success', disabled = True, icon='camera')
    box_images = widgets.HBox()
    btn_remove = widgets.Button(description='Clear', button_style='danger', disabled = True, icon='remove')

    def acquire_images_and_add(save_path, class_name, num_images_to_acquire, image_dims, box_images, verbose=False):
        """Capture photos for ``class_name`` and append a 50x50 thumbnail of each to ``box_images``."""
        if verbose: log_msg(f'* acquire_images_and_add({num_images_to_acquire} images to {save_path}, {class_name})')
        img_filenames = capture_images_for_class(save_path, class_name, num_images_to_acquire, image_dims)
        for filename in img_filenames:
            html_image = get_image_as_html_tag(filename, width=50, height=50)
            w = widgets.HTML(value = html_image,  placeholder='placeholder HTML' ) # XXX
            box_images.children += (w,)
        return img_filenames

    def handler_name_change(obj):
        # Strip leading/trailing whitespace (inner spaces are kept). The row's buttons are
        # only usable once a name is present.
        txt_class_name.value = txt_class_name.value.strip()
        has_name = len(txt_class_name.value) > 0
        btn_remove.disabled = not has_name
        btn_acquire.disabled = not has_name

    def handler_acquire_for_class(obj):
        # Without this declaration the assignments below would create locals, and the
        # module-level flags would never change.
        global model_trainable, model_trained
        txt_class_name.disabled = True
        btn_acquire.disabled = True
        class_name = txt_class_name.value.strip()
        acquire_images_and_add(save_path, class_name, num_images_to_acquire, image_dims, box_images, verbose=True)
        # New photos mean the current model (if any) no longer reflects the data.
        model_trained = False ; btn_identify.disabled = True
        if len(get_valid_class_list(save_path, v_box_classes)) >= min_classes_to_train:
            model_trainable = True ; btn_train.disabled = False

    def handler_clear_class(obj):
        global model_trainable, model_trained
        # Delete image files. Guard against an empty name: rem_dir(save_path, '') would
        # delete the whole save_path, i.e. every class's photos.
        class_name = txt_class_name.value.strip()
        if class_name:
            rem_dir(save_path, class_name, verbose=False)
        # Clear widget
        txt_class_name.value = ''
        txt_class_name.disabled = False
        # Close the discarded thumbnail widgets so their front-end views/comms are released.
        old_thumbnails = box_images.children
        box_images.children = ()
        for thumbnail in old_thumbnails:
            thumbnail.close()
        # Update other widgets
        model_trained = False ; btn_identify.disabled = True
        if len(get_valid_class_list(save_path, v_box_classes)) < min_classes_to_train:
            model_trainable = False ; btn_train.disabled = True

    # Wiring
    txt_class_name.observe(handler_name_change, names='value')
    btn_acquire.on_click(handler_acquire_for_class)
    btn_remove.on_click(handler_clear_class)

    # Layout for row
    h_box = widgets.HBox((txt_class_name, btn_acquire, box_images, btn_remove))
    v_box_classes.children += (h_box,)

def handler_learn_classes(v_box_classes):
    """Click handler for the "Learn objects" button: train the model.

    Logs how many class rows exist, trains the model, then marks the model as
    trained and enables the "Identify objects" button.

    Args:
        v_box_classes: VBox whose children are the per-class widget rows.
    """
    # Without this declaration the assignment below only created a local
    # variable, so the notebook-level ``model_trained`` flag never changed.
    global model_trained

    log_msg(f'Getting classes from {len(v_box_classes.children)} rows')
    # Get class list.  Ignore empty rows.
    # NOTE(review): the list is not passed to train_model(); presumably
    # train_model discovers the classes from save_path itself -- confirm.
    class_names = get_valid_class_list(save_path, v_box_classes)
    # Train
    train_model(verbose=False)
    model_trained = True
    btn_identify.disabled = False

def handler_identify(v_box_classes, widget_snapshot):
    """Click handler for "Identify objects": snapshot the webcam and classify it.

    Requests a still frame saved as ``photo.jpg`` (without waiting for a button
    press) and shows it in ``widget_snapshot``. It then classifies the frame
    with the module-level ``custom_model`` and appends the most likely label
    (also sent to the log) to ``widget_snapshot``.

    Args:
        v_box_classes: VBox whose children are the per-class widget rows; used
            to obtain the list of class names.
        widget_snapshot: HTML widget that receives the snapshot and the result.
    """
    global base_model, custom_model

    log_msg(f'Identifying')

    # Class names, skipping rows that were left empty
    known_classes = get_valid_class_list(save_path, v_box_classes)

    # Grab a frame and show it
    snapshot_file, _, _ = acquire_webcam_image_to_file(
        'photo.jpg', (image_width, image_height), await_button_click=False)
    widget_snapshot.value = (
        '<h3>Snapshot to Identify</h3>'
        + get_image_as_html_tag(snapshot_file, width=100, height=100))

    # Classify with the custom model
    best_label, best_prob = get_class_label_from_image(
        custom_model, snapshot_file, known_classes, show_possibles=True)

    # Report to the log and below the snapshot
    log_msg(f'Most likely:   {best_label}     (approximately {best_prob*100:.2f}%)')
    widget_snapshot.value += get_results_as_html_tag(best_label, best_prob)

def handler_stop_video(obj):
    """Click handler for the "Stop Webcam" button.

    Args:
        obj: the clicked Button, supplied by ``on_click``; not used.
    """
    del obj  # the button instance itself is not needed
    stop_live_webcam_video()

In [0]:
# Widgets
# Icon names are Font Awesome 4 (https://www.w3schools.com/icons/icons_reference.asp);
# the stylesheet that renders them is loaded in the "Run" cell.

# -- Action buttons --------------------------------------------------------
# "Learn" and "Identify" start disabled; the handlers enable them later.
btn_add_class = widgets.Button(
    description='Add new object', button_style='success',
    disabled=False, icon='plus')
btn_train = widgets.Button(
    description='Learn objects', button_style='success',
    disabled=True, icon='eye')
btn_identify = widgets.Button(
    description='Identify objects', button_style='success',
    disabled=True, icon='search-plus')
btn_stop_video = widgets.Button(
    description='Stop Webcam', button_style='danger',
    disabled=False, icon='hand-stop-o')

# -- Display areas ---------------------------------------------------------
# Live webcam view and the snapshot/result panel, each placed by grid_area.
widget_video = widgets.HTML(
    get_html_live_video(width=500, height=300),
    layout=widgets.Layout(width='auto', grid_area='widget_video'))
widget_snapshot = widgets.HTML(
    value=" ",
    layout=widgets.Layout(width='auto', grid_area='widget_snapshot'))

# Scrollable log panel (shown in the "Logging" tab by the Run cell)
widget_output = widgets.Output(layout={
    'height': '100px',
    'width': '100%',
    'border': '1px solid black',
    'overflow': 'scroll',
})

# -- Containers ------------------------------------------------------------
v_box_main = widgets.VBox()
v_box_classes = widgets.VBox()  # one row per object class is appended here

# -- Wiring ----------------------------------------------------------------
btn_add_class.on_click(
    lambda obj: handler_add_widgets_for_new_class(
        v_box_classes, len(v_box_classes.children), num_images_to_acquire))
btn_train.on_click(lambda obj: handler_learn_classes(v_box_classes))
btn_identify.on_click(lambda obj: handler_identify(v_box_classes, widget_snapshot))
btn_stop_video.on_click(handler_stop_video)

# -- Layout ----------------------------------------------------------------
# See https://ipywidgets.readthedocs.io/en/latest/examples/Output%20Widget.html
v_box_main.children = (btn_add_class, v_box_classes, btn_train, btn_identify)

# Four equal columns: video spans columns 1-2, column 3 is empty,
# snapshot/result panel is column 4.
children = [widget_video, widget_snapshot]
areas = '''
        "widget_video widget_video . widget_snapshot"
        '''
layout = widgets.Layout(
    width='800px',
    grid_template_rows='auto auto auto auto auto auto auto',
    grid_template_columns='25% 25% 25% 25%',
    grid_template_areas=areas)
gridbox = widgets.GridBox(children=children, layout=layout)

NameError: ignored

## Run

In [0]:
# Widget enablement flags.
# A `global` statement at module (cell) level is a no-op, so it was dropped.
# Handlers that update these flags must themselves declare them `global`;
# otherwise an assignment inside the handler only creates a local variable.
model_initialised = False
model_trainable = False
model_trained = False

# Clear the decks: remove images saved by a previous run, then clear cell output
rem_dir(save_path)
clear_output()

# Display
# Load Font Awesome 4 so the button icons render:  https://fontawesome.com/icons?d=gallery
display(HTML('<link rel="stylesheet" href="//stackpath.bootstrapcdn.com/font-awesome/4.7.0/css/font-awesome.min.css"/>'))
display(v_box_main, gridbox, btn_stop_video)

# Ensure logs visible in tab
tab_bar = colab_widgets.TabBar(['Main', 'Logging'])
with tab_bar.output_to(0, select=True):
    print('Click on "Logging" tab to see logs')
with tab_bar.output_to(1, select=False):
    # NOTE(review): presumably log_msg writes into widget_output -- confirm.
    display(widget_output)

# Load pre-trained CNN, in background
log_msg('Loading pre-trained CNN, in background')
init_load_weights_background(verbose=True)

log_msg('Fin layout')