# How good are methods at correctly detecting collisions?

In [None]:
# Notebook setup: autoreload mode 2 re-imports edited modules before each cell
# executes, and the inline backend renders matplotlib figures in the notebook.
%load_ext autoreload
%autoreload 2
%matplotlib inline

In [None]:
from __future__ import print_function, absolute_import, unicode_literals, division
import sys, six; from six.moves import (zip, filter, map, reduce, input, range)
sys.path.append('..');import pathcustomize, about
about.about()

import collections

import matplotlib.pyplot as plt
import pandas as pd


#from multiworm import Experiment
from waldo.wio import Experiment
from waldo import collider
from waldo.collider.viz import direct_degree_distribution as ddd

In [None]:
# Experiment to analyse. Alternative IDs are kept commented out so the notebook
# can be pointed at a different recording by swapping which line is active.
#ex_id = '20130614_120518'
#ex_id = '20130318_131111'
#ex_id = '20130414_140704'
#ex_id = '20130702_135704' # many pics
ex_id = '20130702_135652' # many pics

# NOTE: data_root is a hard-coded absolute path; adjust it for your machine.
experiment = Experiment(experiment_id=ex_id, data_root='/home/projects/worm_movement/Data/MWT_RawData')
# Work on a copy so the filtering below does not touch experiment.graph.
graph = experiment.graph.copy()
# NOTE(review): the return value is discarded, so remove_nodes_outside_roi
# presumably prunes `graph` in place -- confirm against waldo.collider.
collider.remove_nodes_outside_roi(graph, experiment)
# NOTE(review): trailing no-op; presumably here so the cell does not echo the
# previous call's return value.
pass

## Screen Data

In [None]:
# Read every manual screening answer, then narrow the table down to the rows
# recorded for the experiment selected in `ex_id`.
screen = pd.read_csv('../../data/prep/collision_validate.csv')
screen = screen.loc[screen['eid'] == ex_id]
# Stop here if nothing was screened for this experiment; the later cells
# would only produce empty comparisons.
assert len(screen) > 0, "No data for experiment ID"

In [None]:
# Blob-ID sets keyed by screening outcome: '2' holds blobs answered 20 and
# '1' holds blobs answered 10. Blobs answered 30 are deliberately not merged
# into '2'. 'all' is every screened blob regardless of answer.
screen_results = {
    label: set(screen.loc[screen['ans'] == code, 'bid'].values)
    for label, code in (('2', 20), ('1', 10))
}
screen_results['all'] = set(screen['bid'].values)

In [None]:
# `ddd` is waldo.collider.viz.direct_degree_distribution (see the imports).
# NOTE(review): presumably each call draws one degree-distribution plot -- confirm.
# 1) every node in the ROI-filtered graph
ddd(graph)
# 2) only the nodes that were manually screened
ddd(graph, nodes=screen['bid'].values)
# 3-5) screened nodes split by answer code: 10, 20, then 30
ddd(graph, nodes=screen[screen['ans'] == 10]['bid'].values)
ddd(graph, nodes=screen[screen['ans'] == 20]['bid'].values)
ddd(graph, nodes=screen[screen['ans'] == 30]['bid'].values)

In [None]:
def calc_performance(suspects):
    """Score an algorithm's suspected collisions against the manual screen.

    Blobs in ``screen_results['2']`` are the positive class and
    ``screen_results['all']`` is the set of every screened blob.

    Parameters
    ----------
    suspects : iterable
        Blob IDs the algorithm flagged.

    Returns
    -------
    list of (str, set)
        ``'TP'``, ``'FP'``, ``'FN'`` and ``'TN'`` entries, in that order,
        each paired with the set of blob IDs in that category.
    """
    alg_results = set(suspects)
    positives = screen_results['2']
    negatives = screen_results['all'] - positives
    alg_performance = [
        ('TP', alg_results & positives),
        # NOTE(review): this also counts suspects that were never screened,
        # so FP can include blobs whose true status is unknown.
        ('FP', alg_results - positives),
        ('FN', positives - alg_results),
        # True negatives are screened non-positives the algorithm did NOT
        # flag. The previous `& alg_results` selected the flagged ones,
        # which merely duplicated part of FP.
        ('TN', negatives - alg_results),
    ]
    return alg_performance
    
def show_performance(alg_performance):
    """Plot the size of each performance category as a horizontal bar chart.

    Parameters
    ----------
    alg_performance : sequence of (label, collection)
        Pairs such as those returned by ``calc_performance``; each bar's
        length is ``len(collection)``.

    Returns
    -------
    (figure, axes)
        The matplotlib objects created by ``plt.subplots()``.
    """
    labels, members = zip(*alg_performance)
    counts = list(map(len, members))
    positions = list(range(len(alg_performance)))

    fig, axes = plt.subplots()
    # One centred bar per category, labelled with the category name.
    axes.barh(positions, counts, align='center')
    axes.set_yticks(positions)
    axes.set_yticklabels(labels)
    return fig, axes

def show_result_type(suspects):
    """Pie-chart the screening answers recorded for the suspected blobs.

    Parameters
    ----------
    suspects : iterable
        Blob IDs to look up in the ``bid`` column of ``screen``.

    Returns
    -------
    (figure, axes)
        The matplotlib objects created by ``plt.subplots()``.
    """
    # Only suspects that appear in the screening table contribute a wedge.
    is_suspect = screen['bid'].isin(suspects)
    tally = collections.Counter(screen.loc[is_suspect, 'ans'])
    labels, amounts = zip(*tally.items())

    fig, axes = plt.subplots()
    axes.pie(amounts, labels=labels)
    return fig, axes

## Method 1: Time

In [None]:
# Flag suspects with collider.suspected_collisions using a threshold of 30.
# NOTE(review): the trailing "#?" shows the author was unsure of this value;
# its units are not visible from this notebook.
threshold = 30 #?
suspects = collider.suspected_collisions(graph, threshold)
# Preview at most the first ten suspects (followed by '...' when there are
# more), then report the total count.
print(', '.join(str(x) for x in suspects[:10]), '...' if len(suspects) > 10 else '')
print('{} total suspects'.format(len(suspects)))

In [None]:
# Bar chart of the TP/FP/FN/TN counts for these suspects.
show_performance(calc_performance(suspects))

In [None]:
# Pie chart of the screening answers given to these suspects.
show_result_type(suspects)

## Method 2: Area

In [None]:
# Rebinds `suspects`, replacing the previous method's result.
suspects = collider.find_area_based_collisions(graph, experiment)

In [None]:
# Preview at most the first ten suspects (followed by '...' when there are
# more), then report the total count.
print(', '.join(str(x) for x in suspects[:10]), '...' if len(suspects) > 10 else '')
print('{} total suspects'.format(len(suspects)))

In [None]:
# Bar chart of the TP/FP/FN/TN counts for these suspects.
show_performance(calc_performance(suspects))

In [None]:
# Pie chart of the screening answers given to these suspects.
show_result_type(suspects)

## Method 3: Bounding Box

In [None]:
# Rebinds `suspects`, replacing the previous method's result.
# Unlike the other two methods, no textual preview of the suspects is printed here.
suspects = collider.find_bbox_based_collisions(graph, experiment)

In [None]:
# Bar chart of the TP/FP/FN/TN counts for these suspects.
show_performance(calc_performance(suspects))

In [None]:
# Pie chart of the screening answers given to these suspects.
show_result_type(suspects)