<h1>Table of Contents<span class="tocSkip"></span></h1>
<div class="toc" style="margin-top: 1em;"><ul class="toc-item"><li><span><a href="#Designate-a-Set-of-Movies" data-toc-modified-id="Designate-a-Set-of-Movies-1"><span class="toc-item-num">1&nbsp;&nbsp;</span>Designate a Set of Movies</a></span></li><li><span><a href="#Designate-a-Set-of-Characters" data-toc-modified-id="Designate-a-Set-of-Characters-2"><span class="toc-item-num">2&nbsp;&nbsp;</span>Designate a Set of Characters</a></span></li><li><span><a href="#Initialize-Notebook-State" data-toc-modified-id="Initialize-Notebook-State-3"><span class="toc-item-num">3&nbsp;&nbsp;</span>Initialize Notebook State</a></span></li><li><span><a href="#Do-Clustering-Pass" data-toc-modified-id="Do-Clustering-Pass-4"><span class="toc-item-num">4&nbsp;&nbsp;</span>Do Clustering Pass</a></span></li><li><span><a href="#Do-Search-Pass" data-toc-modified-id="Do-Search-Pass-5"><span class="toc-item-num">5&nbsp;&nbsp;</span>Do Search Pass</a></span><ul class="toc-item"><li><span><a href="#Save-the-Results" data-toc-modified-id="Save-the-Results-5.1"><span class="toc-item-num">5.1&nbsp;&nbsp;</span>Save the Results</a></span></li></ul></li></ul></div>

<b>Imports and Setup. Run the hidden cell below!</b>

In [None]:
from IPython.display import display, clear_output
import ipywidgets as widgets
import datetime
import itertools
import io
import os
import time
import math
import numpy as np
np.warnings.filterwarnings('ignore')
import pickle
from collections import defaultdict
from django.db import transaction

from esper.stdlib import *
from esper.prelude import *
import esper.face_embeddings as face_embeddings

# Number of results per page passed to the esper widgets
# (the second step of the search pass uses 4x this value).
NUM_PER_PAGE = 25

def flatten(l):
    """Concatenate the items of every sub-iterable of `l` into one flat list."""
    flat = []
    for sublist in l:
        flat.extend(sublist)
    return flat

def query_faces(ids):
    """Look up the Face rows whose id is in `ids`.

    Returns a `values()` queryset of dicts holding the bounding box, frame
    number, video id/fps and shot frame range of each face.
    """
    fields = (
        'id', 'bbox_y1', 'bbox_y2', 'bbox_x1', 'bbox_x2',
        'frame__number', 'frame__video__id', 'frame__video__fps',
        'shot__min_frame', 'shot__max_frame',
    )
    return Face.objects.filter(id__in=ids).values(*fields)

def query_sample(qs, n):
    """Return the first `n` rows of `qs` after ordering it by '?'
    (Django's random ordering)."""
    shuffled = qs.order_by('?')
    return shuffled[:n]

def query_faces_result(faces, expand_bbox=0.05):
    """Build the result payload for `esper_widget` from face dicts (replaces
    qs_to_result).

    Each face becomes one 'flat' entry with a single bbox object. The box is
    grown by `expand_bbox` on every side and clamped to the [0, 1] range.
    """
    def to_entry(face):
        frame_number = face['frame__number']
        bbox = {
            'id': face['id'],
            'background': False,
            'type': 'bbox',
            'bbox_y1': max(face['bbox_y1'] - expand_bbox, 0),
            'bbox_y2': min(face['bbox_y2'] + expand_bbox, 1),
            'bbox_x1': max(face['bbox_x1'] - expand_bbox, 0),
            'bbox_x2': min(face['bbox_x2'] + expand_bbox, 1),
        }
        element = {
            'objects': [bbox],
            'min_frame': frame_number,
            'video': face['frame__video__id'],
        }
        return {'type': 'flat', 'label': '', 'elements': [element]}

    entries = [to_entry(face) for face in faces]
    return {'type': 'Face', 'count': 0, 'result': entries}

def sort_ids_by_distance(ids, embs):
    """Return `ids` ordered by ascending `face_embeddings.dist` to the target
    embeddings `embs` (ties are broken by the id itself)."""
    distances = face_embeddings.dist(ids, targets=embs)
    ranked = sorted(zip(distances, ids))
    return [face_id for _, face_id in ranked]

def sort_faces_by_distance(faces, embs, ascending=False):
    """Sort the list of face dicts `faces` in place by distance to `embs`.

    Distances come from `face_embeddings.dist`. The order is descending unless
    `ascending` is true. The same (now sorted) list is returned.
    """
    ids = [face['id'] for face in faces]
    dist_by_id = dict(zip(ids, face_embeddings.dist(ids, targets=embs)))
    sign = 1 if ascending else -1

    def sort_key(face):
        return sign * dist_by_id[face['id']]

    faces.sort(key=sort_key)
    return faces

def get_clusters(face_ids, k):
    """Group `face_ids` into `k` lists using `face_embeddings.kmeans`.

    Clusters are returned largest first; a cluster that received no faces is
    returned as an empty list.
    """
    buckets = [[] for _ in range(k)]
    for face_id, label in face_embeddings.kmeans(face_ids, k):
        buckets[label].append(face_id)
    # A reversed sort is still stable, so equally sized clusters keep their order.
    buckets.sort(key=len, reverse=True)
    return buckets

def get_clusters_recursive(face_ids, max_size=1000):
    """Cluster `face_ids` until no cluster holds more than `max_size` faces.

    The ids are split into ceil(len / max_size) k-means clusters with
    `get_clusters`; any cluster that is still too large is clustered again.

    Args:
        face_ids: list of face ids to cluster.
        max_size: maximum number of faces allowed in a returned cluster.

    Returns:
        A list of clusters (lists of face ids), largest first. An empty
        input yields an empty list.
    """
    # Nothing to cluster; without this guard k-means would be asked for 0 clusters.
    if len(face_ids) == 0:
        return []

    clusters = []
    branch = math.ceil(len(face_ids) / max_size)
    for c in get_clusters(face_ids, branch):
        if len(c) <= max_size:
            clusters.append(c)
        elif len(c) == len(face_ids):
            # k-means put every face into one cluster, so recursing would
            # repeat the same call forever. Fall back to fixed-size chunks.
            clusters.extend(
                c[start:start + max_size]
                for start in range(0, len(c), max_size))
        else:
            clusters.extend(get_clusters_recursive(c, max_size))
    clusters.sort(key=lambda x: -len(x))
    return clusters

def get_faces(videos):
    """Collect every face in `videos`.

    Prints the face count and warns on stderr when some faces have no
    embedding. Returns `(face_ids, faces_by_id)`, where `faces_by_id` maps a
    face id to its `query_faces` row dict.
    """
    face_rows = Face.objects.filter(frame__video__in=videos).values('id')
    face_ids = [row['id'] for row in face_rows]
    print('Selected films contain {} faces'.format(len(face_ids)))

    embs_exist = face_embeddings.exists(face_ids)
    if not all(embs_exist):
        missing = len(face_ids) - sum(embs_exist)
        print('Missing {} face embeddings'.format(missing),
              file=sys.stderr)

    faces_by_id = {}
    for face in query_faces(face_ids):
        faces_by_id[face['id']] = face
    assert len(face_ids) == len(faces_by_id)
    return face_ids, faces_by_id

def parse_identity_list(text):
    """Parse "name, character" lines into a set of lower-cased
    (name, character) tuples.

    Blank lines are skipped. Raises ValueError when a line does not have
    exactly one comma, when either part is empty, or when no identity is given.
    """
    identities = set()
    for raw_line in text.split('\n'):
        entry = raw_line.strip()
        if not entry:
            continue
        # Unpacking fails with ValueError unless there are exactly two parts.
        name, character = (part.strip() for part in entry.lower().split(','))
        if not name:
            raise ValueError('Name cannot be empty')
        if not character:
            raise ValueError('Character cannot be empty')
        identities.add((name, character))
    if not identities:
        raise ValueError('No identities specified')
    return identities

def exclude_labeled_faces(face_ids):
    """Return the ids from `face_ids` that have no FaceIdentity row, keeping
    their original order."""
    labeled_rows = FaceIdentity.objects.filter(
        face__id__in=face_ids
    ).distinct('face__id').values('face__id')
    labeled_ids = set()
    for row in labeled_rows:
        labeled_ids.add(row['face__id'])
    return [face_id for face_id in face_ids if face_id not in labeled_ids]

def show_people_textbox():
    """Display a text area for entering identities plus a validity indicator.

    Every edit re-parses the text with `parse_identity_list` and stores the
    result in the global `people` (a set of (name, character) tuples). When
    the text does not parse, `people` becomes an empty set and the indicator
    shows invalid.
    """
    # Make sure `people` exists before the first edit, so the identity
    # dropdowns work even if nothing was typed. An existing value is kept.
    globals().setdefault('people', set())

    people_textbox = widgets.Textarea(
        value='',
        layout=widgets.Layout(width='auto', height='400px'),
        style={'description_width': 'initial'},
        placeholder='e.g., Daniel Radclife, Harry Potter',
        description='<b>People:</b> name, identity (1 per line)',
        disabled=False
    )
    valid_checkbox = widgets.Valid(
        value=False,
        style={'description_width': 'initial'},
        description='<b>Valid?</b>',
    )
    def update(b):
        global people
        try:
            people = parse_identity_list(people_textbox.value)
            valid_checkbox.value = True
        except ValueError:
            # Only parse errors mean "invalid input"; anything else should surface.
            # Use a set so `people` has the same type as a successful parse.
            people = set()
            valid_checkbox.value = False
    people_textbox.observe(update, names='value')
    display(people_textbox)
    display(valid_checkbox)
    
def format_identity(person, character):
    """Render an identity as the dropdown label 'person :: character'."""
    label = '{} :: {}'.format(person, character)
    return label

def parse_identity(s):
    """Split a 'person :: character' label back into a tuple of its parts."""
    parts = s.split(' :: ')
    return tuple(parts)

def get_identity_options():
    """Return the dropdown options: an empty choice followed by every identity
    in the global `people`, formatted and ordered by character name."""
    by_character = sorted(people, key=lambda identity: identity[1])
    options = ['']
    for identity in by_character:
        options.append(format_identity(*identity))
    return options

def get_searchable_identity_dropdown():
    """Create a search box linked to an identity dropdown.

    On every change of the search text the dropdown options are refreshed
    from `get_identity_options()` and the first option containing the
    lower-cased search text is selected (the first option when none match).

    Returns `(search_text_widget, dropdown_widget)`.
    """
    dropdown = widgets.Dropdown(
        options=get_identity_options(),
        value='',
        description='Person:',
        disabled=False
    )
    search_box = widgets.Text(
        value='',
        placeholder='search...',
        continuous_update=True,
        disabled=False
    )

    def on_update(b):
        options = get_identity_options()
        dropdown.options = options
        query = search_box.value.strip().lower()
        first_match = next((opt for opt in options if query in opt), options[0])
        dropdown.value = first_match

    search_box.observe(on_update, names='value')
    return search_box, dropdown

def int_prompt(msg, min_val, max_val, default):
    """Ask the user for an integer via `input()`.

    An empty answer selects `default`. Raises ValueError when the answer is
    not an integer or the resulting value lies outside [min_val, max_val].
    """
    prompt = '{}, range=[{}, {}], default={}: '.format(
        msg, min_val, max_val, default)
    answer = input(prompt).strip()
    value = default if answer == '' else int(answer)
    if value < min_val or value > max_val:
        raise ValueError('Out of range.')
    return value

def prepare_orm_objects(person_to_clusters):
    """Create the labelers and identities and build the FaceIdentity rows.

    Two Labeler rows (one for actors, one for characters) are created right
    away, named with the current timestamp. For every identity that has at
    least one face, the actor and character Identity rows are fetched or
    created, and two FaceIdentity objects per face are instantiated (one per
    labeler). The FaceIdentity objects are NOT saved here.

    Args:
        person_to_clusters: mapping of (actor name, character name) to a
            list of face-id lists.

    Returns:
        ((actor_labeler, character_labeler), face_identities), where
        face_identities is the list of unsaved FaceIdentity objects.
    """
    # Need 2 labelers to get around unique constraint
    time_str = datetime.datetime.now().strftime('%Y-%m-%d-%H:%M:%S')
    actor_labeler_name = 'cluster-actor-{}'.format(time_str)
    actor_labeler = Labeler.objects.create(name=actor_labeler_name)
    print('Labeler created:', actor_labeler_name)
    
    character_labeler_name = 'cluster-role-{}'.format(time_str)
    character_labeler = Labeler.objects.create(name=character_labeler_name)
    print('Labeler created:', character_labeler_name)
    
    face_identities = []
    for identity, clusters in person_to_clusters.items():
        # A face may appear in several clusters of one identity; label it once.
        cluster_face_ids = set(flatten(clusters))
        if len(cluster_face_ids) == 0:
            print('Skipping: no faces for {}'.format(identity), file=sys.stderr)
            continue

        actor_name, character_name = identity
        actor, created = Identity.objects.get_or_create(name=actor_name)
        if created:
            print('Identity created (actor/actress):', actor_name)
        character, created = Identity.objects.get_or_create(name=character_name)
        if created:
            print('Identity created (character):', character_name)

        # Each face gets one actor label and one character label.
        for i in cluster_face_ids:
            face_identities.append(FaceIdentity(
                labeler=actor_labeler,
                face_id=i,
                probability=1.,
                identity=actor
            ))
            face_identities.append(FaceIdentity(
                labeler=character_labeler,
                face_id=i,
                probability=1.,
                identity=character
            ))
    return (actor_labeler, character_labeler), face_identities

def save_orm_objects(face_identities):
    """Bulk-insert the given FaceIdentity objects inside a single database
    transaction, printing progress before and after."""
    label_count = len(face_identities)
    print('Commiting {} face identity labels'.format(label_count))
    with transaction.atomic():
        FaceIdentity.objects.bulk_create(face_identities)
    print('Done!')

def print_videos(videos):
    """Print how many videos were selected, followed by each video's name."""
    header = 'You selected {} videos:'.format(videos.count())
    print(header)
    for video in videos:
        print('  ', video.name)
    
def print_videos_with_faces():
    """Print the distinct names of videos referenced by FaceFeatures rows."""
    print('The following videos have faces:')
    name_field = 'face__frame__video__name'
    rows = FaceFeatures.objects.distinct(name_field).values(name_field)
    for row in rows:
        print('  ', row[name_field])
        
# Print the names of the videos that have FaceFeatures rows.
print_videos_with_faces()

# Designate a Set of Movies

`videos` needs to be a Django `QuerySet`.

In [None]:
# Videos whose faces will be labeled; edit the filter to choose a different set.
videos = Video.objects.filter(name__contains='star wars episode ii')

In [None]:
# Check which videos matched the filter.
print_videos(videos)

# Designate a Set of Characters

Use the get_imdb_actor_list notebook to fetch all of the actors/actresses. Or enter them manually.

In [None]:
# Show the text area for entering "name, character" pairs (one per line).
show_people_textbox()

# Initialize Notebook State

Global variables to track labeling state.

In [None]:
# (actor name, character name) -> list of face-id lists assigned to that identity
person_to_clusters = defaultdict(list)
# All face ids in the selected videos, plus a lookup from face id to its row dict
unassigned_face_ids, face_dict = get_faces(videos)
# Kept only for progress reporting in print_assignment_state()
initial_num_faces = len(unassigned_face_ids)

<b>Optional: filter to only unlabeled faces</b>

In [None]:
# Drop faces that already have at least one FaceIdentity row.
unassigned_face_ids = exclude_labeled_faces(unassigned_face_ids)

# Do Clustering Pass

The idea of this section is that by clustering across all of the unassigned faces and making assignments, main characters will be captured. `do_clustering_pass()` uses k-means clustering to compute clusters of the maximum size specified.

When evaluating the clusters, the options are to accept, ignore (removes the faces from being clustered), or split (returns the faces to unassigned). You can add more names at any time; just be sure to refresh the list of people so that the new person appears.

Pressing "finish clustering pass" will return any unselected clusters to the unassigned pool, while committing your other selections. After you click this, some stats will be printed.

You can re-run the cell to do additional clustering passes.

<b>To expand a frame, hover and press '='. Press '=' again to shrink it.</b> 

<b>Run the hidden cell below!</b>

In [None]:
# Clusters with more faces than this are split into sub-clusters before
# their faces are sorted by distance.
SUB_CLUSTER_THRESH = 100
# Number of k-means sub-clusters used for such large clusters.
SUB_CLUSTER_COUNT = 5

def visualize_cluster(cluster_id, face_ids, ignored_ids, unassigned_ids, 
                      clusters_done, sort='random'):
    """Render one cluster with sort, accept, ignore and split controls.

    Args:
        cluster_id: identifier added to `clusters_done` once the user has
            accepted, ignored or split the cluster.
        face_ids: list of face ids in the cluster (keys of the global
            `face_dict`). The caller's list is not modified.
        ignored_ids: set that receives the face ids when the cluster is
            ignored, or accepted without a person selected.
        unassigned_ids: set that receives the face ids when the cluster is
            split.
        clusters_done: set of cluster ids the user has already decided on.
        sort: 'random' shuffles the faces; 'ascending' / 'descending' order
            them by distance to the mean embedding(s) of the cluster.
    """
    print('Cluster has {} faces'.format(len(face_ids)))
    if sort != 'random':
        if len(face_ids) > SUB_CLUSTER_THRESH:
            sub_clusters = get_clusters(face_ids, SUB_CLUSTER_COUNT)
        else:
            # A small cluster is its own single sub-cluster. It must be
            # wrapped in a list, otherwise the loop below would iterate over
            # individual ids instead of over lists of ids.
            sub_clusters = [face_ids]
        # get_clusters() may return empty clusters; they have no mean.
        mean_embs = [face_embeddings.mean(c) for c in sub_clusters if len(c) > 0]
        face_ids = sort_ids_by_distance(face_ids, mean_embs)
        if sort == 'descending':
            face_ids = face_ids[::-1]
    else:
        # Shuffle a copy so the caller's cluster list keeps its order.
        face_ids = list(face_ids)
        random.shuffle(face_ids)
    faces = [face_dict[i] for i in face_ids]

    sort_button = widgets.ToggleButtons(
        options=['random', 'ascending', 'descending'],
        value=sort,
        style={'description_width': 'initial'},
        description='Sort (by L2-distance to centers):',
        disabled=False,
        button_style=''
    )
    def refresh(b):
        # Redraw the whole cluster with the newly selected ordering.
        new_sort = sort_button.value
        clear_output()
        visualize_cluster(cluster_id, face_ids, ignored_ids, unassigned_ids, 
                          clusters_done, sort=new_sort)
    sort_button.observe(refresh, names='value')
    display(sort_button)
    
    identity_text, identity_dropdown = get_searchable_identity_dropdown()
    
    accept_button = widgets.Button(description='Accept cluster', button_style='success')
    def on_accept(b):
        person = identity_dropdown.value
        clear_output()
        if person == '':
            # Accepting without choosing a person is treated as "ignore".
            print('No person specified. Cluster ignored.', file=sys.stderr)
            ignored_ids.update(face_ids)
        else:
            person = parse_identity(person)
            person_to_clusters[person].append(face_ids)
            print('Assigned cluster to {} ({}).'.format(person, len(face_ids)))
        clusters_done.add(cluster_id)
    accept_button.on_click(on_accept)
    ignore_button = widgets.Button(description='Ignore cluster', button_style='danger')
    def on_ignore(b):
        ignored_ids.update(face_ids)
        clusters_done.add(cluster_id)
        clear_output()
        print('Ignored cluster ({} faces).'.format(len(face_ids)))
    ignore_button.on_click(on_ignore)
    reject_button = widgets.Button(description='Split cluster', button_style='warning')
    def on_reject(b):
        # "Split" returns the faces to the unassigned pool for a later pass.
        unassigned_ids.update(face_ids)
        clusters_done.add(cluster_id)
        clear_output()
        print('Returned cluster to unassigned set ({} faces).'.format(len(face_ids)))
    reject_button.on_click(on_reject)
    
    cluster_widget = esper_widget(
        query_faces_result(faces), results_per_page=NUM_PER_PAGE, 
        crop_bboxes=True, jupyter_keybindings=True, disable_playback=True)
    display(widgets.HBox([accept_button, ignore_button, reject_button,
                          identity_dropdown, identity_text]))
    display(cluster_widget)
    
def do_clustering_pass(ids, branch=10):
    """Cluster `ids` and show one tab per cluster for the user to label.

    The user is prompted for a maximum cluster size, the faces are clustered
    with `get_clusters_recursive`, and each cluster is rendered lazily (the
    first time its tab is selected) with `visualize_cluster`. Pressing
    "Finish clustering pass" returns undecided clusters to the unassigned
    pool, prints statistics and replaces the global `unassigned_face_ids`
    with the faces that are still unassigned.

    Args:
        ids: list of face ids to cluster.
        branch: the default maximum cluster size is len(ids) / branch,
            rounded up to a multiple of 100.
    """
    print('Clustering {} faces.'.format(len(ids)))
    if len(ids) == 0:
        # With no faces the default size would be 0, which is outside the
        # prompt's allowed range, and there would be nothing to show.
        print('Nothing to cluster.')
        return

    default_size = math.ceil(len(ids) / branch / 100) * 100
    cluster_size = int_prompt('Specify a maximum cluster size',
                              100, max(default_size, 5000), default_size)
    
    clusters = get_clusters_recursive(ids, max_size=cluster_size)
    print('Found {} clusters. (Ordered from largest to smallest).'.format(len(clusters)))
    
    unassigned_ids = set()
    ignored_ids = set()
    clusters_done = set()
    
    outputs = [widgets.Output() for _ in range(len(clusters))]
    tabs = widgets.Tab(children=outputs)
    for i in range(len(clusters)):
        tabs.set_title(i, str(i))
    tabs_loaded = set()
    def load_current_tab(b):
        i = tabs.selected_index
        # Render each tab only once; revisiting a tab must not draw the
        # cluster again. `i` is None when no tab is selected.
        if i is None or i in tabs_loaded:
            return
        tabs_loaded.add(i)
        with outputs[i]:
            visualize_cluster(i, clusters[i], ignored_ids, unassigned_ids,
                              clusters_done)
    tabs.observe(load_current_tab, names='selected_index')
    load_current_tab(None)
    
    finish_button = widgets.Button(description='Finish clustering pass',
                                   button_style='success')
    def on_finish(b):
        clear_output()
        if not len(clusters_done) == len(clusters):
            remaining_clusters = set(range(len(clusters))) - clusters_done
            print('Not all clusters selected... (treating them as unassigned)')
            for c in remaining_clusters:
                unassigned_ids.update(clusters[c])
        finish_button.disabled = True
        # Every face is either assigned, ignored or unassigned.
        assigned_count = len(ids) - len(ignored_ids) - len(unassigned_ids)
        print('Clustering pass statistics:')
        print('  ', 'Assigned: {}'.format(assigned_count))
        print('  ', 'Ignored: {}'.format(len(ignored_ids)))
        print('  ', 'Unassigned: {}'.format(len(unassigned_ids)))
        global unassigned_face_ids
        unassigned_face_ids = list(unassigned_ids)
        
        print()
        print_assignment_state()
    finish_button.on_click(on_finish)
    display(tabs)
    display(finish_button)
    
def print_assignment_state():
    """Print how many faces are still unassigned and, for every identity in
    `person_to_clusters`, the number of distinct faces assigned to it."""
    print('{} faces are unassigned (of {} initially)'.format(
          len(unassigned_face_ids), initial_num_faces))
    if len(person_to_clusters) == 0:
        print('No people have faces assigned.')
        return
    print('The following people have faces assigned:')
    for identity, clusters in sorted(person_to_clusters.items()):
        # A face may appear in several clusters; count it once.
        distinct_ids = set().union(*clusters)
        print('  {}: {}'.format(identity, len(distinct_ids)))

In [None]:
# Show how many faces are unassigned and who already has faces assigned.
print_assignment_state()

In [None]:
# Clusters whatever is still unassigned; re-run this cell for additional passes.
do_clustering_pass(unassigned_face_ids)

# Do Search Pass

This is useful for selecting specific unclustered people. The idea is to find 1 or more examples of a person, and then to sort all the unassigned faces by distance to these examples.

First, you will be shown unassigned faces at random. Select 1 or more examples of the person you want to label from this set. <b>To select an example, press '['.</b> Hit submit when you are ready to move on to the next step.

Next, you will be shown the remaining unassigned faces in order of ascending distance. Select all the faces that are of this person. <b>To select a page of examples, press '{' (i.e., 'shift + ['). To deselect a face, press '[' while hovering over it.</b> Hit submit to finish the assignment.

<b>Run the hidden cell below!</b>

In [None]:
def do_search_pass(ids):
    """Label faces by example: pick sample faces, then select their matches.

    Step 1 shows the faces for `ids` in random order; the user selects one or
    more examples and submits. Step 2 shows the same faces ordered by
    ascending distance to the examples; the user selects the matching faces,
    chooses a person and submits. The selection is appended to
    `person_to_clusters` and the global `unassigned_face_ids` is replaced by
    `ids` minus the selected faces.

    Args:
        ids: list of face ids to search over (keys of the global `face_dict`).
    """
    faces = [face_dict[i] for i in ids]
    random.shuffle(faces)
    
    select_widget = esper_widget(
        query_faces_result(faces), results_per_page=NUM_PER_PAGE, 
        crop_bboxes=True, jupyter_keybindings=True, disable_playback=True)
    
    submit_button = widgets.Button(description='Submit examples', button_style='success')
    def on_submit(b):
        # The widget reports selections as indices into `faces`.
        selected_idxs = select_widget.selected
        clear_output()
        if len(selected_idxs) == 0:
            print('No examples selected. Aborting.', file=sys.stderr)
            return
        else:
            print('Selected {} examples. Ordering unassigned faces by ascending distance.'.format(
                  len(selected_idxs)))
            print()
        
        # `faces` is rebound below, so it has to be declared nonlocal.
        nonlocal faces
        # Resolve the indices to ids BEFORE sorting, while they still match `faces`.
        selected_ids = [faces[i]['id'] for i in selected_idxs]
        selected_embs = [x for _, x in face_embeddings.get(selected_ids)]
        faces = sort_faces_by_distance(faces, selected_embs, ascending=True)
        
        identity_text, identity_dropdown = get_searchable_identity_dropdown()
        
        select_widget2 = esper_widget(
            query_faces_result(faces), results_per_page= 4 * NUM_PER_PAGE, 
            crop_bboxes=True, jupyter_keybindings=True, disable_playback=True)
        
        submit_button2 = widgets.Button(description='Submit selections', button_style='success')
        def on_submit2(b):
            # Indices here refer to the distance-sorted `faces`.
            selected_idxs2 = select_widget2.selected
            selected_ids2 = [faces[i]['id'] for i in selected_idxs2]
            person = identity_dropdown.value
            if person == '':
                # Keep the widget on screen so the user can pick a person and retry.
                print('No person selected... Try again.', file=sys.stderr)
                return

            clear_output()
            person = parse_identity(person)
            person_to_clusters[person].append(selected_ids2)
            print('Assigned {} faces to {}.'.format(len(selected_ids2), person))

            # Everything that was searched but not selected stays unassigned.
            global unassigned_face_ids
            unassigned_face_ids = list(set(ids) - set(selected_ids2))

            print()
            print_assignment_state()
        submit_button2.on_click(on_submit2)
        
        print('Select faces matching the target person')
        display(widgets.HBox([submit_button2, identity_dropdown, identity_text]))
        display(select_widget2)
        
    submit_button.on_click(on_submit)
    print('Select face search examples:')
    display(submit_button)
    display(select_widget)

In [None]:
# Show the assignment state before starting a search pass.
print_assignment_state()

In [None]:
# Search over the faces that are still unassigned.
do_search_pass(unassigned_face_ids)

## Save the Results

This will save your labels to the database. `prepare_orm_objects()` will create labeler rows, tied to this session. The new face identity rows, corresponding to your assignments, will have foreign keys to these labelers.

`save_orm_objects()` will save the identity labels to the database.

In [None]:
# Review the assignments before writing anything to the database.
print_assignment_state()

In [None]:
# Creates the Labeler and Identity rows right away; the FaceIdentity rows
# are only built in memory at this point.
labelers, face_identities = prepare_orm_objects(person_to_clusters)

In [None]:
# Bulk-insert the prepared FaceIdentity rows inside one transaction.
save_orm_objects(face_identities)

<i>Help! I messed up and committed bad labels... What should I do?</i> Hopefully, if the notebook is still running, the labelers are still defined. Then, you can run the following.

In [None]:
# for labeler in labelers:
#     display(labeler.delete())