<h1>Table of Contents<span class="tocSkip"></span></h1>
<div class="toc" style="margin-top: 1em;"><ul class="toc-item"><li><span><a href="#Setup" data-toc-modified-id="Setup-1"><span class="toc-item-num">1&nbsp;&nbsp;</span>Setup</a></span></li><li><span><a href="#Choose-a-person" data-toc-modified-id="Choose-a-person-2"><span class="toc-item-num">2&nbsp;&nbsp;</span>Choose a person</a></span></li><li><span><a href="#Get-starting-image-references" data-toc-modified-id="Get-starting-image-references-3"><span class="toc-item-num">3&nbsp;&nbsp;</span>Get starting image references</a></span><ul class="toc-item"><li><span><a href="#Option-1:-fetch-them-from-Google-Images" data-toc-modified-id="Option-1:-fetch-them-from-Google-Images-3.1"><span class="toc-item-num">3.1&nbsp;&nbsp;</span>Option 1: fetch them from Google Images</a></span></li><li><span><a href="#Option-2:-specify-face-ids-from-Esper" data-toc-modified-id="Option-2:-specify-face-ids-from-Esper-3.2"><span class="toc-item-num">3.2&nbsp;&nbsp;</span>Option 2: specify face ids from Esper</a></span></li><li><span><a href="#Option-3:-load-references-and-dataset-from-existing-model" data-toc-modified-id="Option-3:-load-references-and-dataset-from-existing-model-3.3"><span class="toc-item-num">3.3&nbsp;&nbsp;</span>Option 3: load references and dataset from existing model</a></span></li></ul></li><li><span><a href="#Your-initial-reference-images" data-toc-modified-id="Your-initial-reference-images-4"><span class="toc-item-num">4&nbsp;&nbsp;</span>Your initial reference images</a></span></li><li><span><a href="#Building-a-Training-Set" data-toc-modified-id="Building-a-Training-Set-5"><span class="toc-item-num">5&nbsp;&nbsp;</span>Building a Training Set</a></span><ul class="toc-item"><li><span><a href="#Getting-negative-examples-(via-sampling)" data-toc-modified-id="Getting-negative-examples-(via-sampling)-5.1"><span class="toc-item-num">5.1&nbsp;&nbsp;</span>Getting negative examples (via sampling)</a></span></li><li><span><a href="#Getting-positive-examples-(via-k-NN)" 
data-toc-modified-id="Getting-positive-examples-(via-k-NN)-5.2"><span class="toc-item-num">5.2&nbsp;&nbsp;</span>Getting positive examples (via k-NN)</a></span></li></ul></li><li><span><a href="#Training-a-Classifier" data-toc-modified-id="Training-a-Classifier-6"><span class="toc-item-num">6&nbsp;&nbsp;</span>Training a Classifier</a></span><ul class="toc-item"><li><span><a href="#Training-and-obtaining-predictions" data-toc-modified-id="Training-and-obtaining-predictions-6.1"><span class="toc-item-num">6.1&nbsp;&nbsp;</span>Training and obtaining predictions</a></span></li><li><span><a href="#Visualize-predictions" data-toc-modified-id="Visualize-predictions-6.2"><span class="toc-item-num">6.2&nbsp;&nbsp;</span>Visualize predictions</a></span></li></ul></li><li><span><a href="#Saving-your-model" data-toc-modified-id="Saving-your-model-7"><span class="toc-item-num">7&nbsp;&nbsp;</span>Saving your model</a></span></li></ul></div>

# Setup

Before we begin, we need to load some dependencies and define some utility functions. Run the cell below.

In [None]:
%matplotlib inline

from IPython.display import display, clear_output
from IPython.core.pylabtools import figsize
figsize(12, 5)
import ipywidgets as widgets
import os
import pickle
import time
import traceback
import random
import math
import numpy as np
np.warnings.filterwarnings('ignore')
import matplotlib.pyplot as plt
from sklearn import metrics
from collections import namedtuple

from esper.prelude import *
from esper.stdlib import *
from esper.captions import phrase_search
from esper.identity import load_and_select_faces_from_images
from esper.plot_util import tile_images
from esper.major_canonical_shows import MAJOR_CANONICAL_SHOWS
from esper import embed_google_images

import esper.face_embeddings as face_embeddings

print('Done importing')

MODEL_DIR = '/app/data/identity_models_v2'

# Bundle describing the initial reference faces for one person: the person's
# name, the Esper face ids (an empty set for Google-image references), the
# face embeddings, and the face images.
ReferenceFaces = namedtuple('ReferenceFaces', 'name ids embs imgs')

def flatten(l):
    """Concatenate an iterable of iterables into one flat list."""
    flat = []
    for sublist in l:
        flat.extend(sublist)
    return flat

def split_list(l, idx):
    """Return ``(l[:idx], l[idx:])``: the prefix before ``idx`` and the rest."""
    head = l[:idx]
    tail = l[idx:]
    return head, tail

def query_faces(ids):
    """Fetch the fields needed to display the faces with the given ids.

    Returns a ``values()`` queryset of dicts. This call applies no ordering,
    so callers that need a particular order must sort the rows themselves.
    """
    fields = (
        'id', 'bbox_y1', 'bbox_y2', 'bbox_x1', 'bbox_x2',
        'frame__number', 'frame__video__id', 'frame__video__fps',
        'shot__min_frame', 'shot__max_frame',
    )
    matching = Face.objects.filter(id__in=ids)
    return matching.values(*fields)

def query_sample(qs, n):
    """Return up to ``n`` rows of ``qs`` in random order (``order_by('?')``)."""
    shuffled = qs.order_by('?')
    return shuffled[:n]

def query_faces_result(faces, expand_bbox=0.05):
    """Build an esper_widget result payload from face value dicts.

    Replaces qs_to_result. Each face becomes one 'flat' group holding a single
    bbox object. The box is grown by ``expand_bbox`` on every side and clamped
    to [0, 1]. The frame shown is the midpoint of the face's shot when both
    shot bounds are known; otherwise it is the face's own frame.
    """
    def display_frame(face):
        shot_start = face.get('shot__min_frame')
        shot_end = face.get('shot__max_frame')
        if shot_start is None or shot_end is None:
            return face['frame__number']
        return int((shot_start + shot_end) / 2)

    groups = []
    for face in faces:
        frame = display_frame(face)
        bbox_object = {
            'id': face['id'],
            'background': False,
            'type': 'bbox',
            'bbox_y1': max(face['bbox_y1'] - expand_bbox, 0),
            'bbox_y2': min(face['bbox_y2'] + expand_bbox, 1),
            'bbox_x1': max(face['bbox_x1'] - expand_bbox, 0),
            'bbox_x2': min(face['bbox_x2'] + expand_bbox, 1),
        }
        element = {
            'objects': [bbox_object],
            'min_frame': frame,
            'video': face['frame__video__id'],
        }
        groups.append({'type': 'flat', 'label': '', 'elements': [element]})
    return {'type': 'Face', 'count': 0, 'result': groups}

def load_face_img(face):
    """Load the frame containing ``face`` and crop it to the face."""
    frame = face.frame
    full_frame = load_frame(frame.video, frame.number, [])
    return crop(full_frame, face)

def sort_ids_by_distance(ids, embs):
    """Return ``ids`` ordered by ascending ``face_embeddings.dist`` to ``embs``.

    Equal distances are ordered by id, because (dist, id) tuples are sorted.
    """
    distances = face_embeddings.dist(ids, targets=embs)
    ranked = sorted(zip(distances, ids))
    return [face_id for _dist, face_id in ranked]

def sort_faces_by_distance(faces, embs, ascending=False):
    """Sort face dicts in place by distance to ``embs`` and return the list.

    ``faces`` is a list of dicts with an 'id' key. The order is descending
    distance unless ``ascending`` is True.
    """
    face_ids = [face['id'] for face in faces]
    dist_by_id = dict(zip(face_ids, face_embeddings.dist(face_ids, targets=embs)))
    sign = 1 if ascending else -1

    def sort_key(face):
        return sign * dist_by_id[face['id']]

    faces.sort(key=sort_key)
    return faces

def video_ids_with_mentions(phrase):
    """Return the set of document (video) ids whose captions match ``phrase``."""
    documents = phrase_search(phrase).documents
    return set(doc.id for doc in documents)

def continue_yn_prompt(msg):
    """Ask the user to confirm and raise ValueError unless they answer 'y'.

    The answer is stripped and lower-cased, so ' Y ' also confirms. Any other
    answer, including an empty one, aborts.
    """
    answer = input('{} Continue? (y/N): '.format(msg)).strip().lower()
    if answer == 'y':
        return
    raise ValueError('User aborted.')

print('Done loading utils')

def load_female_face_ids(cache_path='/tmp/female_face_ids.pkl'):
    """Return the set of face ids with gender 'F' in the three-years dataset.

    The set is cached as a pickle at ``cache_path`` so that later runs skip
    the database query. Delete the cache file to force a refresh.

    NOTE(review): the default cache lives in world-writable /tmp and is
    unpickled without validation. A file planted there by another user could
    execute arbitrary code, so point ``cache_path`` at a private directory on
    shared machines.
    """
    if os.path.exists(cache_path):
        with open(cache_path, 'rb') as f:
            return pickle.load(f)

    ids = {
        f['face__id'] for f in FaceGender.objects.filter(
            gender__name='F', face__frame__video__threeyears_dataset=True
        ).values('face__id')
    }
    # Write to a temporary file and rename it into place, so an interrupted
    # run cannot leave a truncated cache that breaks every later load.
    tmp_path = cache_path + '.tmp'
    with open(tmp_path, 'wb') as f:
        pickle.dump(ids, f)
    os.replace(tmp_path, cache_path)
    return ids
# Loaded once (from the cache file when present); read by the gender filter
# in the visualization section.
female_face_ids = load_female_face_ids()

print('Done loading data')

# Choose a person

Enter the name of the person you would like to build a model for.

In [None]:
# Name of the target person. It labels the model and, when no explicit path
# is given, determines the model file name in load_model().
name = input('Enter a name: ').strip()
# Explicit check instead of `assert`, which is skipped under `python -O`.
if not name:
    raise ValueError('Name cannot be the empty string')

# Get starting image references

In order to train a binary model to identify a person, we need to find some initial visual examples of the target person. We will use these initial example images to build a training set to train the model.

We have provided three options for how to obtain these starting images:
 - Google Image Search
 - face ids from Esper
 - existing model on disk
 
You only need to do one of the options.

## Option 1: fetch them from Google Images 

The following code fetches images using Google Image Search. You will be asked to select, from the faces in the results, which faces are your target person.

In [None]:
def get_google_images(name):
    """Build ReferenceFaces for ``name`` from Google Image Search results.

    Fetches images for ``name`` and picks faces from them via
    ``load_and_select_faces_from_images``, then embeds the chosen faces. The
    result carries no Esper face ids (``ids`` is an empty set).

    If the images returned are not satisfactory, rerun
    ``embed_google_images.fetch_images`` with extra params:
        query_extras=''  # additional keywords to add to search
        force=True       # ignore cached images
    """
    # TODO: use Esper to select images
    chosen_faces = flatten(
        load_and_select_faces_from_images(embed_google_images.fetch_images(name)))
    chosen_embs = embed_google_images.embed_images(chosen_faces)
    assert len(chosen_faces) == len(chosen_embs)
    return ReferenceFaces(name=name, ids=set(), embs=chosen_embs, imgs=chosen_faces)

face_references = get_google_images(name)

## Option 2: specify face ids from Esper

If you already know a set of Esper face identifiers corresponding to the target person, you can specify an initial set of face identifiers below. Running `confirm_selected_faces()` afterwards will display the faces in an interactive widget and let you ignore some of the ids (press the '\]' key while hovering). Any faces not marked as "ignore" will be selected.

In [None]:
def confirm_selected_faces(face_ids):
    """Show ``face_ids`` in a widget and build ReferenceFaces from those kept.

    Faces the user marks as ignored in the widget are dropped. On
    'Confirm selection', the images and embeddings of the accepted faces only
    are loaded and stored in the global ``face_references``.
    """
    submit_button = widgets.Button(
        layout=widgets.Layout(width='auto'),
        style={'description_width': 'initial'},
        description='Confirm selection',
        disabled=False,
        button_style='danger',
        tooltip='Submit labels'
    )

    # Materialize once so that the widget's indices and the enumerate() in
    # on_submit are guaranteed to refer to the same row order.
    example_faces = list(query_faces(sorted(face_ids)))
    example_selection_widget = esper_widget(
        query_faces_result(example_faces),
        crop_bboxes=True, jupyter_keybindings=True
    )

    def on_submit(b):
        global face_references
        ignored_example_face_idxs = set(example_selection_widget.ignored)
        example_selection_widget.close()
        clear_output()

        ids = {
            f['id'] for i, f in enumerate(example_faces)
            if i not in ignored_example_face_idxs
        }
        print('You deselected {} and accepted {} faces.'.format(
              len(ignored_example_face_idxs), len(ids)))

        # Only the accepted faces become references. Loading from the full
        # `face_ids` list would also pull in the faces the user ignored.
        accepted_ids = sorted(ids)
        imgs = par_for(load_face_img, Face.objects.filter(id__in=accepted_ids))
        embs = [x for _, x in face_embeddings.get(accepted_ids)]
        face_references = ReferenceFaces(name=name, ids=ids, imgs=imgs, embs=embs)

    submit_button.on_click(on_submit)

    display(widgets.HBox([widgets.Label('Controls:'), submit_button]))
    display(example_selection_widget)

In [None]:
# Example Esper face ids for the target person (replace with your own). They
# are reviewed in the widget before being used as references.
face_ids = [
    644710, 4686364, 2678025, 62032, 13248, 4846879, 4804861, 561270,
    2651257, 2083010, 2117202, 1848221, 2495606, 4465870, 3801638, 865102,
    3861979, 4146727, 3358820, 2087225, 1032403, 1137346, 2220864, 5384396,
    3885087, 5107580, 2856632, 335131, 4371949, 533850, 5384760, 3335516,
]
confirm_selected_faces(face_ids)

## Option 3: load references and dataset from existing model

If you have previously saved a model for the target person, you can load its reference images and the corresponding dataset.

In [None]:
def load_model(path=None):
    """Load reference faces and labeled examples from a saved model pickle.

    ``path`` defaults to MODEL_DIR/<name>.pkl, where <name> is the global
    ``name`` lower-cased with spaces replaced by underscores. The stored model
    name must equal the global ``name``.

    Returns a tuple of (ReferenceFaces, positive example id set, negative
    example id set).

    NOTE: unpickling can execute code from the file; only load trusted models.
    """
    if path is None:
        file_name = '{}.pkl'.format(name.lower().replace(' ', '_'))
        path = os.path.join(MODEL_DIR, file_name)
    print('Loading model: {}'.format(path))
    with open(path, 'rb') as f:
        model = pickle.load(f)
    assert model['name'] == name, 'Model name does not match {} != {}'.format(model['name'], name)

    references = ReferenceFaces(
        name=model['name'],
        embs=model['init_embs'],
        ids=set(model['init_ids']),
        imgs=model['init_imgs'],
    )
    pos_examples = set(model['pos_examples'])
    neg_examples = set(model['neg_examples'])
    print('Done! Loaded {} reference faces; {} positive and {} negative examples'.format(
          len(references.embs), len(pos_examples), len(neg_examples)))
    return references, pos_examples, neg_examples

In [None]:
# Option 3: restore the reference faces and the saved positive/negative
# example sets from the model file.
face_references, pos_examples, neg_examples = load_model()

# Your initial reference images

You can always view your initial set of reference images by running `show_reference_imgs(face_references)`. This can come in handy if you are unsure whether a face is the target or not.

In [None]:
def show_reference_imgs(refs):
    """Display ``refs.imgs`` as a grid of 100x100 thumbnails, 10 per row."""
    thumbnails = [cv2.resize(img, (100, 100)) for img in refs.imgs]
    grid = tile_images(thumbnails, cols=10, blank_value=255)
    print('Your reference images for {}.'.format(refs.name))
    plt.figure()
    imshow(grid)
    plt.tight_layout()
    plt.show()

In [None]:
# `face_references` only exists once one of the three options above has run.
# Look it up defensively so that a fresh kernel gets this message rather than
# a NameError.
if globals().get('face_references') is None:
    raise ValueError('Missing initial reference images')
show_reference_imgs(face_references)

# Building a Training Set

Hooray! We have our initial reference images. However, these are still too few to train a model. This section will rectify that by constructing a dataset of sufficient diversity and size to begin training.

In [None]:
# Start with an empty negative set. If one already exists (e.g. loaded via
# Option 3), ask before discarding it.
if 'neg_examples' in globals():
    continue_yn_prompt('neg_examples will be overwritten')
neg_examples = set()

In [None]:
# Seed the positive set with the reference face ids. If a positive set
# already exists, ask before discarding it.
if 'pos_examples' in globals():
    continue_yn_prompt('pos_examples will be overwritten')
pos_examples = set(face_references.ids)

## Getting negative examples (via sampling)

We will obtain negative examples by randomly sampling the dataset. You will also be presented with the opportunity to clean these sampled faces. 

<b>Cleaning:</b> Faces on TV news do not receive equal screen time. Political figures such as Donald Trump and Hillary Clinton can comprise up to 2% of the total faces in the dataset, which is frequent enough for them to show up in random negative samples. `get_negative_samples()` lets you select the sampled faces that are the target person, remove them from the negative samples, and add them to the positive examples set. To select a single face, hover and press '\['; to select an entire page of faces, press '\{'.

You can rerun `get_negative_samples()` as many times as needed.

In [None]:
def get_negative_samples(neg_samples=None, k=None):
    """Show sampled faces for cleaning before they become negative examples.

    Args:
        neg_samples: face ids to review. If None, ``k`` ids are drawn with
            ``face_embeddings.sample(k)``.
        k: number of ids to sample when ``neg_samples`` is None.

    Faces are shown in ascending distance to the reference embeddings.
    Selected faces are moved to the global ``pos_examples`` (and added to the
    references). Ignored faces are dropped.
    - 'Commit selections' applies the labels and re-displays the remaining
      samples for another pass.
    - 'Commit selections and dismiss widget' applies the labels and adds the
      remaining samples to the global ``neg_examples``.
    """
    if neg_samples is None:
        neg_samples = face_embeddings.sample(k)

    update_button = widgets.Button(
        layout=widgets.Layout(width='auto'),
        style={'description_width': 'initial'},
        description='Commit selections',
        disabled=False,
        button_style='danger',
    )

    commit_button = widgets.Button(
        layout=widgets.Layout(width='auto'),
        style={'description_width': 'initial'},
        description='Commit selections and dismiss widget',
        disabled=False,
        button_style='warning',
    )

    neg_samples_ord = sort_ids_by_distance(
        neg_samples, face_references.embs
    )
    neg_samples_ord_idxs = {
        face_id: i for i, face_id in enumerate(neg_samples_ord)
    }

    neg_samples_faces = list(query_faces(neg_samples_ord))
    neg_samples_faces.sort(key=lambda f: neg_samples_ord_idxs[f['id']])
    # The widget indexes into neg_samples_faces, which is shorter than
    # neg_samples_ord when some ids are missing from the database. Map widget
    # indices back to ids through this list, not through neg_samples_ord.
    displayed_ids = [f['id'] for f in neg_samples_faces]
    selection_widget = esper_widget(
        query_faces_result(neg_samples_faces),
        crop_bboxes=True, jupyter_keybindings=True
    )

    def _update():
        """Apply the widget's selections and return the remaining sample ids."""
        ignored_idxs = set(selection_widget.ignored)
        selected_idxs = set(selection_widget.selected)
        selection_widget.close()
        clear_output()

        # Selected faces show the target person: make them positives and use
        # them as additional references.
        for i in selected_idxs:
            face_id = displayed_ids[i]
            if face_id not in face_references.ids:
                _id, emb = face_embeddings.get([face_id])[0]
                assert _id == face_id
                face_references.ids.add(face_id)
                # NOTE(review): assumes face_references.embs is a list.
                face_references.embs.append(emb)
            pos_examples.add(face_id)

        # Whatever was neither selected nor ignored stays a negative sample.
        remaining = [
            face_id for i, face_id in enumerate(displayed_ids)
            if i not in ignored_idxs and i not in selected_idxs
        ]
        print('You selected {} and ignored {} faces; {} samples remain.'.format(
              len(selected_idxs), len(ignored_idxs), len(remaining)))
        return remaining

    def on_update(b):
        get_negative_samples(neg_samples=_update())
    update_button.on_click(on_update)

    def on_commit(b):
        global neg_examples
        committed = _update()
        neg_examples |= set(committed)
        print('Committed {} samples. There are now {} negative examples.'.format(
              len(committed), len(neg_examples)))
    commit_button.on_click(on_commit)

    display(widgets.HBox([widgets.Label('Controls:'), update_button, commit_button]))
    display(selection_widget)

In [None]:
# Draw 10,000 random faces for review; rerun the cell to review another batch.
get_negative_samples(k=10000)

## Getting positive examples (via k-NN)

To obtain an initial set of positive examples, we use the reference images that you selected earlier and find the k-nearest neighbors in the dataset. `get_positive_examples()` will sample these face ids and load a widget to confirm these positive examples.

The goal is to select a clean set of positive examples. The widget will display faces in order of ascending distance from the reference images. All images past your highest index selection (hover and press '\[') will be discarded. If no selections are made, then all of the faces are accepted.

You can rerun `get_positive_examples()` as many times as needed.

In [None]:
def get_positive_examples(k):
    """Review nearest neighbors of the references as candidate positives.

    Candidates come from ``face_embeddings.knn`` over the reference embeddings
    (``k=len(pos_examples) + k``, ``max_threshold=1.``). Faces already in
    ``pos_examples`` are skipped, and the rest are shown in the order knn
    returns them (increasing distance).

    On 'Confirm selections', every face up to and including the highest-index
    selected face is accepted, except faces marked ignored. With no selection,
    all shown faces are accepted. Accepted ids are added to the global
    ``pos_examples``.
    """
    submit_button = widgets.Button(
        layout=widgets.Layout(width='auto'),
        style={'description_width': 'initial'},
        description='Confirm selections',
        disabled=False,
        button_style='danger'
    )

    # Order by increasing distance, excluding already selected
    pos_samples_and_dists = [
        x for x in face_embeddings.knn(
            targets=face_references.embs,
            k=len(pos_examples) + k, max_threshold=1.)
        if x[0] not in pos_examples
    ]
    pos_samples_to_idx = {
        x[0]: i for i, x in enumerate(pos_samples_and_dists)
    }
    pos_samples_faces = list(query_faces(
        [x[0] for x in pos_samples_and_dists]))
    pos_samples_faces.sort(key=lambda f: pos_samples_to_idx[f['id']])
    for p in pos_samples_faces:
        _, p['dist'] = pos_samples_and_dists[pos_samples_to_idx[p['id']]]

    selection_widget = esper_widget(
        query_faces_result(pos_samples_faces),
        crop_bboxes=True, jupyter_keybindings=True
    )

    def on_submit(b):
        global pos_examples
        selected_idxs = selection_widget.selected
        ignored_idxs = set(selection_widget.ignored)
        selection_widget.close()
        clear_output()

        # Keep everything up to AND including the furthest selected face.
        # Slicing to max(selected_idxs) alone would drop that face, and would
        # accept nothing when only index 0 is selected.
        if len(selected_idxs) > 0:
            cutoff = max(selected_idxs) + 1
        else:
            cutoff = len(pos_samples_faces)

        pos_samples = {
            x['id'] for i, x in enumerate(pos_samples_faces[:cutoff])
            if i not in ignored_idxs
        }

        pos_examples |= pos_samples
        print('Accepted {} labels. There are now {} positive examples.'.format(
              len(pos_samples), len(pos_examples)))
    submit_button.on_click(on_submit)

    display(widgets.HBox([widgets.Label('Controls:'), submit_button]))
    display(selection_widget)

In [None]:
# Review roughly the 1,000 nearest neighbors of the reference faces that are
# not already positive examples; rerun as needed.
get_positive_examples(k=1000)

# Training a Classifier

This section will train a model based on the examples that you selected previously. Before proceeding, make sure that you have run the negative sampling cell above and generated a set of initial positive examples. If you have not, the following cell will throw a ValueError telling you to do so.

In [None]:
# Both example sets must exist and be non-empty before training. They are
# sets, so `is None` never triggers; check for emptiness instead. Look them
# up defensively so that a skipped cell gives this message, not a NameError.
if not globals().get('pos_examples'):
    raise ValueError('No positive training examples! Did you confirm the selection above?')
if not globals().get('neg_examples'):
    raise ValueError('No negative training examples!')
print('Proceeding with {} positive and {} negative training examples'.format(
    len(pos_examples), len(neg_examples)))

## Training and obtaining predictions

`train_model()` will train a binary classifier using the face identifiers in `pos_examples` and `neg_examples`. This will take roughly 20 seconds. Each time you rerun `train_model()`, your model will be retrained, so be sure to rerun it when you have added new labels and are ready to retrain. This will also generate some debugging charts to diagnose model performance.

In [None]:
# Label values for training targets and the per-class plots.
POS_LABEL, NEG_LABEL = 1, 0

# Hyperparameters passed to face_embeddings.logreg in train_model().
NUM_EPOCHS = 40
LEARNING_RATE = 1
L2_PENALTY = 1e-5

def plot_roc(y_true, y_pred, title='Receiver Operating Characteristic'):
    """Plot the ROC curve of scores ``y_pred`` against labels ``y_true``.

    The AUC is shown in the legend, and the chance diagonal is drawn for
    reference.
    """
    fpr, tpr, _thresholds = metrics.roc_curve(y_true, y_pred)
    auc_label = 'AUC = %0.2f' % metrics.auc(fpr, tpr)
    plt.figure()
    plt.title(title)
    plt.plot(fpr, tpr, 'b', label=auc_label)
    plt.legend(loc='lower right')
    plt.plot([0, 1], [0, 1], 'r--')
    plt.xlim([0, 1])
    plt.ylim([0, 1])
    plt.ylabel('True Positive Rate')
    plt.xlabel('False Positive Rate')
    plt.show()
    
def plot_binary_score_histograms(y_true, y_pred, y_max=None,
                                 title='Score Distribution by Class'):
    """Overlay score histograms for the positive and negative labels.

    ``y_true[i]`` is the label of ``y_pred[i]``. ``y_max`` optionally caps the
    y axis.
    """
    bins = np.linspace(0, 1, 100)
    plt.figure()
    pos_scores = [score for i, score in enumerate(y_pred) if y_true[i] == POS_LABEL]
    plt.hist(pos_scores, bins, alpha=0.5, label=face_references.name)
    neg_scores = [score for i, score in enumerate(y_pred) if y_true[i] == NEG_LABEL]
    plt.hist(neg_scores, bins, alpha=0.5,
             label='Not {}'.format(face_references.name))
    plt.title(title)
    plt.xlabel('Predicted Score')
    if y_max is not None:
        plt.ylim(0, y_max)
    plt.legend()
    plt.show()
    
def plot_score_histogram(predictions, sample):
    """Plot a log-scale histogram of scores for a random sample of predictions.

    Args:
        predictions: sequence of (face_id, score) pairs.
        sample: maximum number of predictions to draw. All are used when
            there are no more than ``sample``.
    """
    bins = np.linspace(0, 1, 100)
    plt.figure()
    sampled_pred = (
        random.sample(predictions, sample)
        if sample < len(predictions) else predictions
    )
    plt.hist([s for _, s in sampled_pred], bins, alpha=1)
    plt.title('Predicted Score Distribution (sample={})'.format(
              min(sample, len(predictions))))
    plt.xlabel('Predicted Score')
    # Matplotlib 3.3 renamed `nonposy` to `nonpositive`, and the old name was
    # later removed. Use the new keyword and fall back for older versions,
    # which reject it with ValueError (<3.3) or TypeError.
    try:
        plt.yscale('log', nonpositive='clip')
    except (TypeError, ValueError):
        plt.yscale('log', nonposy='clip')
    plt.show()
    
def plot_estimated_cdf(predictions, sample):
    """Plot the estimated distribution of true positives across score bins.

    Each sampled score is treated as the probability that its face is the
    target person and is added to its score bin, so a bin's total is the
    expected number of positives in it. The plot shows each bin's share and
    the cumulative share. The sample total is extrapolated to all of
    ``predictions`` and shown in the title.

    Args:
        predictions: sequence of (face_id, score) pairs.
        sample: maximum number of predictions to draw at random. All are used
            when there are no more than ``sample``.
    """
    n_bins = 100

    def score_to_bin(s):
        # Clamp so that s == 1.0 (and any out-of-range score) maps to a valid
        # bin instead of overflowing or wrapping to a negative index.
        v = math.floor(s * n_bins)
        return min(max(v, 0), n_bins - 1)

    bins = np.zeros(n_bins)
    sampled_pred = (
        random.sample(predictions, sample)
        if sample < len(predictions) else predictions
    )
    for _, s in sampled_pred:
        bins[score_to_bin(s)] += s

    sample_est_pos = np.sum(bins)
    # Extrapolate from the number of predictions actually sampled. That is
    # fewer than `sample` whenever there are fewer predictions than requested.
    n_sampled = len(sampled_pred)
    total_est_pos = (
        int(sample_est_pos / n_sampled * len(predictions)) if n_sampled else 0)

    # Avoid 0/0 (all-NaN curves) when every sampled score is zero.
    norm_bins = bins / sample_est_pos if sample_est_pos > 0 else bins
    cdf_bins = np.cumsum(norm_bins)
    inds = np.arange(bins.size) / n_bins
    plt.figure()
    plt.title('CDF of Positive Predictions ' +
              '(total estimated positives={})'.format(
              total_est_pos))
    plt.plot(inds, cdf_bins, label='Est. Cumulative Proportion')
    plt.plot(inds, norm_bins, label='Est. Bin Proportion ({} bins)'.format(n_bins))
    plt.ylabel('Proportion')
    plt.xlabel('Predicted Score')
    plt.ylim(bottom=0)
    plt.legend()
    plt.show()
    
def train_model(train_val_ratio=10):
    """Train a logistic classifier on the labeled examples and show diagnostics.

    Each class is shuffled, and its first len/``train_val_ratio`` items are
    held out for validation. Returns ``(weights, predictions)`` from
    ``face_embeddings.logreg``. ``predictions`` is iterated as (face_id, score)
    pairs.

    Diagnostics appear in three tabs:
    - the score distribution over the whole dataset;
    - ROC and per-class score histograms for the training set;
    - the same for the validation set.
    """
    print('Training logistic classifier with {}:1 train to validation split'.format(
          train_val_ratio))

    def shuffled_split(examples):
        # Shuffle a copy, then hold out the first len/train_val_ratio items.
        items = list(examples)
        random.shuffle(items)
        return split_list(items, int(len(items) / train_val_ratio))

    val_pos, train_pos = shuffled_split(pos_examples)
    val_neg, train_neg = shuffled_split(neg_examples)

    train_ids = train_pos + train_neg
    train_y = [POS_LABEL] * len(train_pos) + [NEG_LABEL] * len(train_neg)
    val_ids = val_pos + val_neg
    val_y = [POS_LABEL] * len(val_pos) + [NEG_LABEL] * len(val_neg)

    weights, predictions = face_embeddings.logreg(
        train_ids, train_y, 0, 1,
        num_epochs=NUM_EPOCHS, learning_rate=LEARNING_RATE,
        l2_penalty=L2_PENALTY)

    print('Generating debugging plots')

    # Scatter the model scores into lists aligned with train_ids / val_ids.
    # Ids absent from `predictions` keep a score of 0.
    train_slot = {face_id: i for i, face_id in enumerate(train_ids)}
    val_slot = {face_id: i for i, face_id in enumerate(val_ids)}
    train_pred_y = [0] * len(train_ids)
    val_pred_y = [0] * len(val_ids)
    for face_id, score in predictions:
        if face_id in train_slot:
            train_pred_y[train_slot[face_id]] = score
        if face_id in val_slot:
            val_pred_y[val_slot[face_id]] = score

    outputs = [widgets.Output() for _ in range(3)]
    tabs = widgets.Tab(children=outputs)

    with outputs[0]:
        tabs.set_title(0, 'Entire Dataset')
        plot_score_histogram(predictions, sample=100000)
        print('If we interpret the scores produced by the model as probabilities, '
              'we can estimate the number of true positives that we expect to find '
              'in the dataset. The following plot makes this assumption and shows '
              'the expected contribution of faces of varying scores to the total.')
        plot_estimated_cdf(predictions, sample=100000)

    split_panels = (
        (1, 'Training Set', train_y, train_pred_y),
        (2, 'Validation Set', val_y, val_pred_y),
    )
    for tab_idx, tab_title, y_true, y_score in split_panels:
        with outputs[tab_idx]:
            tabs.set_title(tab_idx, tab_title)
            plot_roc(y_true, y_score)
            plot_binary_score_histograms(y_true, y_score)

    display(tabs)
    return weights, predictions

# Echo the training hyperparameters defined at the top of this cell.
print('\n'.join([
    'Hyperparameters',
    '  Epochs: {}'.format(NUM_EPOCHS),
    '  Learning rate: {}'.format(LEARNING_RATE),
    '  L2 penalty: {}'.format(L2_PENALTY),
]))

In [None]:
# Retrains from scratch on the current pos_examples / neg_examples each time
# this cell runs.
weights, predictions = train_model()

## Visualize predictions

In [None]:
# Shared widget style for the visualization controls below.
STYLE_ARGS = {'description_width': 'initial'}

# Number of faces to sample for the labeling widget (1 to 10000).
sample_size_text = widgets.BoundedIntText(
    style=STYLE_ARGS,
    value=100,
    min=1,
    max=10000,
    description='Sample size:',
    disabled=False
)

# When checked, faces already in pos_examples or neg_examples are not shown.
hide_already_labeled_checkbox = widgets.Checkbox(
    style=STYLE_ARGS,
    value=True,
    description='Hide already labeled examples',
    disabled=False
)

# Order in which sampled faces are shown in the labeling widget.
sample_sort_button = widgets.ToggleButtons(
    style=STYLE_ARGS,
    options=['random', 'descending distance', 'ascending distance'],
    value='descending distance',
    description='Sample sort:',
    disabled=False,
    orientation='horizontal'
)

# Only faces whose predicted score lies in this inclusive range are shown.
# The default is a narrow band around 0.5.
score_range_slider = widgets.FloatRangeSlider(
    layout=widgets.Layout(width='100%'),
    style=STYLE_ARGS,
    value=[0.45, 0.55],
    min=0,
    max=1,
    step=0.05,
    description='Predicted scores:',
    disabled=False,
    continuous_update=False,
    orientation='horizontal',
    readout=True,
    readout_format='.2f',
)

# 'select' keeps only faces in commercial shots, 'exclude' keeps only faces
# outside them, and 'disabled' applies no commercial filter.
commercial_filter_button = widgets.ToggleButtons(
    style=STYLE_ARGS,
    options=['disabled', 'select', 'exclude'],
    value='disabled',
    description='Commercial filter:',
    disabled=False,
    orientation='horizontal'
)

# Gender filter labels mapped to the codes checked in visualize():
# 0 = no filter, 1 = exclude female-labeled faces, 2 = female-labeled only.
GENDER_OPTIONS = {'disabled': 0, 'male': 1, 'female': 2}

gender_filter_button = widgets.ToggleButtons(
    style=STYLE_ARGS,
    # Button labels are the GENDER_OPTIONS keys, in insertion order.
    options=list(GENDER_OPTIONS),
    value='disabled',
    description='Gender filter:',
    disabled=False,
    orientation='horizontal'
)

# Bounds of the face-height slider (height as a proportion of the frame).
# visualize() only applies a height filter when the slider is narrowed
# inside these bounds.
MIN_HEIGHT, MAX_HEIGHT = 0., 1.
# Bounds of the face-sharpness slider.
MIN_SHARPNESS, MAX_SHARPNESS = 0., 1000.

face_height_slider = widgets.FloatRangeSlider(
    layout=widgets.Layout(width='100%'),
    style=STYLE_ARGS,
    value=[MIN_HEIGHT, MAX_HEIGHT],
    min=MIN_HEIGHT,
    max=MAX_HEIGHT,
    step=0.05,
    description='Face height (proportion):',
    disabled=False,
    continuous_update=False,
    orientation='horizontal',
    readout=True,
    readout_format='.2f',
)

face_sharpness_slider = widgets.FloatRangeSlider(
    layout=widgets.Layout(width='100%'),
    style=STYLE_ARGS,
    value=[MIN_SHARPNESS, MAX_SHARPNESS],
    min=MIN_SHARPNESS,
    max=MAX_SHARPNESS,
    step=0.5,
    description='Face sharpness:',
    disabled=False,
    continuous_update=False,
    orientation='horizontal',
    readout=True,
    readout_format='.1f',
)

# Caption filter mode. NOTE(review): presumably keeps faces whose video
# captions do / do not mention the phrases below; confirm in visualize().
caption_filter_button = widgets.ToggleButtons(
    style=STYLE_ARGS,
    options=['disabled', 'mentioned', 'not mentioned'],
    value='disabled',
    description='Captions filter:',
    disabled=False,
    orientation='horizontal'
)

# Comma-separated caption phrases; defaults to the target person's name.
caption_filter_text = widgets.Text(
    layout=widgets.Layout(width='100%'),
    style=STYLE_ARGS,
    value=face_references.name,
    placeholder='Type something...',
    description='Caption phrases (separated by commas):',
    disabled=False
)

# Choose one of the major canonical shows, or 'All'.
canonical_show_dropdown = widgets.Dropdown(
    layout=widgets.Layout(width='100%'),
    style=STYLE_ARGS,
    options=['All'] + list(sorted(MAJOR_CANONICAL_SHOWS)),
    value='All',
    description='Show filter:',
    disabled=False,
)

def get_vis_args():
    """Snapshot the current values of the visualization control widgets.

    Returns a dict of filter and sort settings. ``caption_text`` is the list
    of comma-separated phrases from the caption box, stripped, with empty
    entries removed. An empty box or a stray/trailing comma would otherwise
    produce '' phrases that get searched.
    """
    phrases = [t.strip() for t in caption_filter_text.value.split(',')]
    return {
        'hide_already_labeled': hide_already_labeled_checkbox.value,
        'sample_size': sample_size_text.value,
        'sample_sort': sample_sort_button.value,
        'score_range': score_range_slider.value,
        'commercial_filter': commercial_filter_button.value,
        'gender_filter': gender_filter_button.value,
        'height_range': face_height_slider.value,
        'sharpness_range': face_sharpness_slider.value,
        'caption_filter': caption_filter_button.value,
        'caption_text': [p for p in phrases if p],
        'canonical_show': canonical_show_dropdown.value,
    }

In [None]:
# Lay out the visualization controls; get_vis_args() reads their values.
display(
    widgets.HBox([sample_size_text, hide_already_labeled_checkbox]),
    sample_sort_button,
    score_range_slider,
    commercial_filter_button,
    gender_filter_button,
    canonical_show_dropdown,
    face_height_slider,
    face_sharpness_slider,
    caption_filter_button,
    caption_filter_text,
)
# Default no-op filter, replaced when the custom filter cell below is run.
custom_filter_fn = lambda x: x

You may also define a custom query filter function using any of the Esper models. Note that this cell will need to be re-run each time the function is updated.

In [None]:
def custom_filter_fn(qs):
    """User hook for narrowing the candidate queryset ``qs``.

    Edit the marked region to add filters. As written, it only prints a
    notice and returns ``qs`` unchanged.
    """
    # BEGIN: INSERT CODE HERE
    print('No custom filter defined.')
    # END
    return qs

The labeling widget will display below once `visualize()` is called. If you update the selections and sliders above, you can reload the visualization by confirming your current labels by hitting the <b>Confirm selections</b> button or by hitting the <b>Refresh</b> button. Alternatively, rerun the cell.

In [None]:
def visualize():
    """Render the interactive labeling widget for the current filter settings.

    Reads the filter widgets via `get_vis_args()`, keeps the entries of the
    module-level `predictions` list of `(face_id, score)` pairs that pass the
    in-memory and query-set filters, samples at most `sample_size` of them,
    and shows them in an Esper selection widget with three buttons:

    * Confirm selections: appends selected faces to `pos_examples` and
      ignored faces to `neg_examples` (skipping faces that are already
      labeled), then re-renders.
    * Refresh (w/o confirming): re-renders without recording any labels.
    * Dismiss widget: clears the output.
    """
    show_reference_imgs(face_references)
    print('Loading widget...')

    vis_args = get_vis_args()

    submit_button = widgets.Button(
        layout=widgets.Layout(width='auto'),
        style={'description_width': 'initial'},
        description='Confirm selections',
        disabled=False,
        button_style='danger'
    )

    refresh_button = widgets.Button(
        layout=widgets.Layout(width='auto'),
        style={'description_width': 'initial'},
        description='Refresh (w/o confirming)',
        disabled=False,
        button_style='info',
    )

    dismiss_button = widgets.Button(
        layout=widgets.Layout(width='auto'),
        style={'description_width': 'initial'},
        description='Dismiss widget',
        disabled=False,
        button_style='warning'
    )

    labeled_ids_set = set(pos_examples) | set(neg_examples)

    min_score, max_score = vis_args['score_range']
    gender_filter = GENDER_OPTIONS[vis_args['gender_filter']]
    hide_already_labeled = vis_args['hide_already_labeled']

    def pre_query_filter_fn(face_id_and_score):
        """In-memory filter on a (face_id, score) pair, applied before querying."""
        face_id, score = face_id_and_score
        if hide_already_labeled and face_id in labeled_ids_set:
            return False
        if gender_filter == 1:    # Male
            if face_id in female_face_ids:
                return False
        elif gender_filter == 2:  # Female
            if face_id not in female_face_ids:
                return False
        return min_score <= score <= max_score

    def query_filter_fn(qs):
        """Apply the commercial, height, sharpness, show and caption filters to `qs`."""
        if vis_args['commercial_filter'] != 'disabled':
            qs = qs.filter(
                shot__in_commercial=vis_args['commercial_filter'] == 'select')

        min_height, max_height = vis_args['height_range']
        if min_height > MIN_HEIGHT or max_height < MAX_HEIGHT:
            # Annotate a `height` derived from the bounding box so it can be
            # filtered on; only done when the slider is not at its full range.
            qs = qs.annotate(height=BoundingBox.height_expr())
            if min_height > MIN_HEIGHT:
                qs = qs.filter(height__gte=min_height)
            if max_height < MAX_HEIGHT:
                qs = qs.filter(height__lte=max_height)

        # The sharpness slider is applied to the `blurriness` field.
        min_sharpness, max_sharpness = vis_args['sharpness_range']
        if min_sharpness > MIN_SHARPNESS:
            qs = qs.filter(blurriness__gte=min_sharpness)
        if max_sharpness < MAX_SHARPNESS:
            qs = qs.filter(blurriness__lte=max_sharpness)

        if vis_args['canonical_show'] != 'All':
            qs = qs.filter(
                frame__video__show__canonical_show__name=vis_args['canonical_show'])

        if vis_args['caption_filter'] != 'disabled':
            video_ids = set()
            for phrase in vis_args['caption_text']:
                if not phrase:
                    # An empty phrase (e.g. from a trailing comma) is not a
                    # meaningful caption search term.
                    continue
                # Search both the phrase as typed and its upper-case form.
                video_ids |= video_ids_with_mentions(phrase)
                video_ids |= video_ids_with_mentions(phrase.upper())
            if vis_args['caption_filter'] == 'mentioned':
                qs = qs.filter(frame__video__id__in=video_ids)
            else:
                qs = qs.exclude(frame__video__id__in=video_ids)

        return qs

    filtered_pred = list(filter(pre_query_filter_fn, predictions))
    filtered_pred_faces = query_faces([x[0] for x in filtered_pred])
    filtered_pred_faces = query_filter_fn(filtered_pred_faces)

    # Execute custom Esper query; on failure report it and keep the
    # widget-filtered query set so the widget still renders.
    try:
        filtered_pred_faces = custom_filter_fn(filtered_pred_faces)
    except Exception:
        traceback.print_exc()

    filtered_count = filtered_pred_faces.count()
    sample_size = vis_args['sample_size']
    if filtered_count > sample_size:
        filtered_pred_faces = query_sample(filtered_pred_faces, sample_size)
    # Materialize once: indices reported by the widget refer to this list.
    filtered_pred_faces = list(filtered_pred_faces)

    print('Showing {} of {} faces'.format(
          min(sample_size, filtered_count), filtered_count))

    # Reorder the samples
    if vis_args['sample_sort'] != 'disabled':
        filtered_pred_faces = sort_faces_by_distance(
            filtered_pred_faces, face_references.embs,
            'ascending' in vis_args['sample_sort'])

    selection_widget = esper_widget(
        query_faces_result(filtered_pred_faces),
        crop_bboxes=True, jupyter_keybindings=True)

    def on_submit(b):
        """Record the widget's selected/ignored faces as labels, then re-render."""
        selected_idxs = set(selection_widget.selected)
        ignored_idxs = set(selection_widget.ignored)
        clear_output()

        selected_face_ids = [f['id'] for i, f in enumerate(filtered_pred_faces)
                             if i in selected_idxs]
        ignored_face_ids = [f['id'] for i, f in enumerate(filtered_pred_faces)
                            if i in ignored_idxs]

        new_pos_labels = 0
        for face_id in selected_face_ids:
            if face_id not in labeled_ids_set:
                pos_examples.append(face_id)
                # Track newly labeled faces so the same face is never appended
                # twice (e.g. to both the positive and the negative list).
                labeled_ids_set.add(face_id)
                new_pos_labels += 1
        new_neg_labels = 0
        for face_id in ignored_face_ids:
            if face_id not in labeled_ids_set:
                neg_examples.append(face_id)
                labeled_ids_set.add(face_id)
                new_neg_labels += 1

        print('Added {} new positive and {} new negative examples'.format(
              new_pos_labels, new_neg_labels))
        visualize()
    submit_button.on_click(on_submit)

    def on_refresh(b):
        clear_output()
        print('Refreshed without updating examples')
        visualize()
    refresh_button.on_click(on_refresh)

    def on_dismiss(b):
        clear_output()
        print('Dismissed widget. Rerun the cell to get it back.')
    dismiss_button.on_click(on_dismiss)

    display(widgets.HBox(
        [widgets.Label('Controls:'), submit_button, refresh_button, dismiss_button]))
    display(selection_widget)

In [None]:
# Render the labeling widget for the current filter settings.
visualize()

In addition to visualizing raw images, it can be helpful to examine debugging plots on varying slices of the dataset. The following cell calls `debug_charts()`, which generates plots specific to the selected slice. 

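Note: "sample size", "score range", "hide already labeled", and the sort order are not applicable to slices. For efficiency, predictions with scores below `MIN_SCORE` are skipped when computing the charts.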

In [None]:
# Predictions scoring below this are skipped when building the slice charts.
MIN_SCORE = 0.05

def debug_charts():
    """Render diagnostic plots and labeled examples for the current filter slice.

    Uses the same filter widgets as `visualize()` (via `get_vis_args()`), but
    ignores the sample size, score range, hide-already-labeled and sort
    settings. Three tabs are shown:

    0. A score histogram and estimated CDF of the predictions in the slice
       (only predictions with score >= MIN_SCORE are considered).
    1. Positive training examples (`pos_examples`) that fall in the slice.
    2. Negative training examples (`neg_examples`) that fall in the slice.
    """
    print('Loading charts for slice... Note: for efficiency, min_score={}'.format(MIN_SCORE))

    vis_args = get_vis_args()

    refresh_button = widgets.Button(
        layout=widgets.Layout(width='auto'),
        style={'description_width': 'initial'},
        description='Refresh (w/o confirming)',
        disabled=False,
        button_style='info',
        tooltip='Refresh examples'
    )
    def on_refresh(b):
        clear_output()
        debug_charts()
    refresh_button.on_click(on_refresh)

    dismiss_button = widgets.Button(
        layout=widgets.Layout(width='auto'),
        style={'description_width': 'initial'},
        description='Dismiss widget',
        disabled=False,
        button_style='warning',
        tooltip='Dismiss widget'
    )
    def on_dismiss(b):
        clear_output()
        print('Dismissed charts. Rerun the cell to get them back.')
    dismiss_button.on_click(on_dismiss)

    gender_filter = GENDER_OPTIONS[vis_args['gender_filter']]

    def passes_gender_filter(face_id):
        """Return True if `face_id` matches the selected gender option."""
        if gender_filter == 1:    # Male
            return face_id not in female_face_ids
        if gender_filter == 2:    # Female
            return face_id in female_face_ids
        return True

    # Videos whose captions mention any caption phrase, using the same lookup
    # as `visualize()`. Computed once and shared by every query below.
    caption_video_ids = set()
    if vis_args['caption_filter'] != 'disabled':
        for phrase in vis_args['caption_text']:
            if not phrase:
                continue
            caption_video_ids |= video_ids_with_mentions(phrase)
            caption_video_ids |= video_ids_with_mentions(phrase.upper())

    def query_filter_fn(qs):
        """Apply the commercial, height, sharpness, show and caption filters to `qs`."""
        if vis_args['commercial_filter'] != 'disabled':
            qs = qs.filter(
                shot__in_commercial=vis_args['commercial_filter'] == 'select')

        min_height, max_height = vis_args['height_range']
        if min_height > MIN_HEIGHT or max_height < MAX_HEIGHT:
            # Annotate a `height` derived from the bounding box so it can be
            # filtered on; only done when the slider is not at its full range.
            qs = qs.annotate(height=BoundingBox.height_expr())
            if min_height > MIN_HEIGHT:
                qs = qs.filter(height__gte=min_height)
            if max_height < MAX_HEIGHT:
                qs = qs.filter(height__lte=max_height)

        # The sharpness slider is applied to the `blurriness` field.
        min_sharpness, max_sharpness = vis_args['sharpness_range']
        if min_sharpness > MIN_SHARPNESS:
            qs = qs.filter(blurriness__gte=min_sharpness)
        if max_sharpness < MAX_SHARPNESS:
            qs = qs.filter(blurriness__lte=max_sharpness)

        if vis_args['canonical_show'] != 'All':
            qs = qs.filter(
                frame__video__show__canonical_show__name=vis_args['canonical_show'])

        if vis_args['caption_filter'] != 'disabled':
            if vis_args['caption_filter'] == 'mentioned':
                qs = qs.filter(frame__video__id__in=caption_video_ids)
            else:
                qs = qs.exclude(frame__video__id__in=caption_video_ids)
        return qs

    def apply_custom_filter(qs):
        """Run the user's `custom_filter_fn`; on error, report it and return `qs`."""
        try:
            return custom_filter_fn(qs)
        except Exception:
            traceback.print_exc()
            return qs

    def show_training_examples(face_ids):
        """Display the labeled faces from `face_ids` that fall in the slice."""
        faces = apply_custom_filter(query_filter_fn(query_faces(face_ids)))
        if gender_filter != 0:
            # Materialize as a list (as `visualize()` does) rather than
            # passing on the one-shot iterator that `filter()` returns.
            faces = [f for f in faces if passes_gender_filter(f['id'])]
        display(esper_widget(
            query_faces_result(faces),
            crop_bboxes=True, jupyter_keybindings=True))

    num_tabs = 3
    outputs = [widgets.Output() for _ in range(num_tabs)]
    tabs = widgets.Tab(children=outputs)
    tabs.set_title(0, 'CDF and Positive Count Estimate')
    tabs.set_title(1, 'Positive Training Examples')
    tabs.set_title(2, 'Negative Training Examples')

    display(widgets.HBox([widgets.Label('Controls:'), refresh_button, dismiss_button]))
    display(tabs)

    with outputs[0]:
        print('Computing prediction distribution for slice')
        slice_pred = [
            (face_id, score) for face_id, score in predictions
            if score >= MIN_SCORE and passes_gender_filter(face_id)
        ]
        slice_pred_faces = Face.objects.filter(id__in=[x[0] for x in slice_pred])
        slice_pred_faces = apply_custom_filter(query_filter_fn(slice_pred_faces))

        # Keep only the predictions whose faces survived the query filters.
        slice_pred_face_ids = {f['id'] for f in slice_pred_faces.values('id')}
        slice_pred = [x for x in slice_pred if x[0] in slice_pred_face_ids]

        print('Face count: {}'.format(len(slice_pred_face_ids)))
        plot_score_histogram(slice_pred, sample=100000)
        plot_estimated_cdf(slice_pred, sample=100000)

    with outputs[1]:
        print('Showing positive training examples from slice')
        show_training_examples(pos_examples)

    with outputs[2]:
        print('Showing negative training examples from slice')
        show_training_examples(neg_examples)

In [None]:
# Render the diagnostic charts for the slice selected by the filter widgets.
debug_charts()

# Saving your model

Serialize the model: the reference faces (name, embeddings, ids, and images), the positive and negative training examples, and the model weights.

In [None]:
def save_model(path=None):
    """Pickle the reference faces, labeled examples and weights to disk.

    Args:
        path: Destination file. If None, defaults to
            `MODEL_DIR/<name>.pkl`, where `<name>` is the notebook-level
            `name` lower-cased with spaces replaced by underscores.
            `MODEL_DIR` is created if it does not exist.

    If the destination already exists, `continue_yn_prompt` is called before
    it is overwritten. The pickle is first written to `<path>.tmp` and then
    moved over `path`, so a failure while pickling does not leave a
    truncated file in place of a previously saved model.
    """
    if path is None:
        # exist_ok avoids the exists()/makedirs() race of a separate check.
        os.makedirs(MODEL_DIR, exist_ok=True)
        path = os.path.join(MODEL_DIR, '{}.pkl'.format(
            name.lower().replace(' ', '_')))
    print('Saving model: {}'.format(path))
    if os.path.exists(path):
        continue_yn_prompt('Existing file will be overwritten')

    model = {
        'name': face_references.name,
        'init_embs': face_references.embs,
        'init_ids': face_references.ids,
        'init_imgs': face_references.imgs,
        'pos_examples': pos_examples,
        'neg_examples': neg_examples,
        'weights': weights
    }

    tmp_path = path + '.tmp'
    try:
        with open(tmp_path, 'wb') as f:
            pickle.dump(model, f)
        # Atomically replace any previous model only once the dump succeeded.
        os.replace(tmp_path, path)
    finally:
        # After a successful replace the temp file is gone; otherwise remove
        # the partial file so it is not mistaken for a valid model.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
    print('Done!')

In [None]:
# Save to the default location (MODEL_DIR/<name>.pkl); pass `path=...` to override.
save_model(path=None)