# **Metrics**

In [1]:
import cv2
import numpy as np
import pandas as pd
import os
import itertools
import matplotlib.pyplot as plt

%matplotlib inline

## **Metrics computation**

In [2]:
def get_accuracy(args):
    """Return accuracy ``(tp + tn) / (tp + fp + tn + fn)`` from confusion counts.

    Args:
        args: 4-sequence of counts ``(tps, fps, tns, fns)``.

    Returns:
        Accuracy as a float, or 0 when every count is zero (empty confusion
        matrix), instead of dividing by zero.
    """
    assert len(args) == 4, f'Expected (tps, fps, tns, fns) counts, but got {len(args)} values.'
    tps, fps, tns, fns = args

    if all(count == 0 for count in (tps, fps, tns, fns)):
        return 0
    return (tps + tns) / (tps + tns + fps + fns)

def get_precision(args):
    """Return precision ``tp / (tp + fp)`` from confusion counts.

    Args:
        args: 4-sequence of counts ``(tps, fps, tns, fns)``.

    Returns:
        Precision as a float, or 0 when there are no predicted positives
        (``tps == fps == 0``).
    """
    assert len(args) == 4, f'Expected (tps, fps, tns, fns) counts, but got {len(args)} values.'
    tps, fps, _tns, _fns = args

    if tps == 0 and fps == 0:
        return 0
    return tps / (tps + fps)

def get_recall(args):
    """Return recall ``tp / (tp + fn)`` from confusion counts.

    Args:
        args: 4-sequence of counts ``(tps, fps, tns, fns)``.

    Returns:
        Recall as a float, or 0 when there are no actual positives
        (``tps == fns == 0``).
    """
    assert len(args) == 4, f'Expected (tps, fps, tns, fns) counts, but got {len(args)} values.'
    tps, _fps, _tns, fns = args

    if tps == 0 and fns == 0:
        return 0
    return tps / (tps + fns)

def get_F1_score(args):
    """Return the F1 score (harmonic mean of precision and recall).

    Args:
        args: 4-sequence of counts ``(tps, fps, tns, fns)``; passed unchanged
            to ``get_precision`` and ``get_recall``.

    Returns:
        ``2 * P * R / (P + R)``, or 0 when both precision and recall are 0.
    """
    assert len(args) == 4, f'Expected (tps, fps, tns, fns) counts, but got {len(args)} values.'

    precision = get_precision(args)
    recall = get_recall(args)
    if precision == 0 and recall == 0:
        return 0
    return (2 * precision * recall) / (precision + recall)

## **Metrics for threshold combinations**

In [3]:
def check_thresh(tps_df, fps_df, fields, thresholds):
    """Count rows of each frame that fall strictly below every threshold.

    A row passes when ``row[field] < threshold`` holds for every paired
    ``(field, threshold)``; pairs whose threshold is ``None`` are ignored.
    Passing rows of ``tps_df`` are true positives (the rest false negatives);
    passing rows of ``fps_df`` are false positives (the rest true negatives).

    Returns:
        Tuple ``(tps, fps, tns, fns)`` of ints.
    """
    active = [(field, limit) for field, limit in zip(fields, thresholds) if limit is not None]

    def count_below(frame):
        # Narrow the frame field by field; whatever survives passed every check.
        survivors = frame
        for field, limit in active:
            survivors = survivors[survivors[field] < limit]
        return len(survivors)

    n_true_pass = count_below(tps_df)
    n_false_pass = count_below(fps_df)
    return n_true_pass, n_false_pass, len(fps_df) - n_false_pass, len(tps_df) - n_true_pass

def get_all_field_thresholds(field_mean, field_std, upper_bound=1.0):
    """Return candidate thresholds ``field_mean + k * field_std`` for k = 0, 1, 2, ...

    Candidates are generated in increasing order and generation stops at the
    first one that reaches ``upper_bound`` (that one is not included), except
    that the first two candidates are always kept, so a well-defined field
    gets at least two options.

    If ``field_std`` is zero, negative, NaN or infinite (e.g. the sample std of
    a single-row column is NaN), every step would repeat the same or an
    undefined value and the bound would never be reached; ``[field_mean]`` is
    returned instead of looping forever.

    Args:
        field_mean: starting threshold.
        field_std: step between consecutive thresholds.
        upper_bound: generation stops once a candidate is ``>= upper_bound``.

    Raises:
        ValueError: if ``field_mean`` is NaN/infinite or ``upper_bound`` is NaN
            or +inf, since the stopping condition could then never be met.
    """
    if np.isnan(upper_bound) or upper_bound == np.inf:
        raise ValueError(f'upper_bound must be a finite number, but got {upper_bound}.')
    if not np.isfinite(field_mean):
        raise ValueError(f'field_mean must be a finite number, but got {field_mean}.')
    if not (np.isfinite(field_std) and field_std > 0):
        # Degenerate spread: all candidates would equal (or be undefined around) the mean.
        return [field_mean]

    thresholds = []
    alpha = 0
    while True:
        threshold = field_mean + alpha * field_std

        # Always keep at least two candidates, even if the mean is already past the bound.
        if len(thresholds) > 1 and threshold >= upper_bound:
            break

        thresholds.append(threshold)
        alpha += 1

    return thresholds
        
    
def get_all_combinations(fields_thresholds):
    """Return every way of picking one threshold per field.

    Args:
        fields_thresholds: sequence of per-field candidate lists.

    Returns:
        List of tuples forming the Cartesian product, with the last field
        varying fastest; ``[()]`` for an empty input.
    """
    cartesian = itertools.product(*fields_thresholds)
    return [*cartesian]

def get_thresholds(fields, tps_df, fps_df):
    """Build the grid of candidate threshold combinations for the given fields.

    For each field, candidates start at the mean of that column in ``tps_df``
    and step by its std (see ``get_all_field_thresholds`` for the stopping
    rule). The upper bound defaults to the column mean in ``fps_df`` and can
    be overridden per field with an ``'upper_bound'`` entry.

    Args:
        fields: non-empty list of dicts, each with a ``'name'`` key (a column
            name) and an optional real-number ``'upper_bound'`` (int, float or
            numpy scalar; bool is rejected).
        tps_df: DataFrame whose column statistics seed the candidates.
        fps_df: DataFrame whose column mean is the default upper bound.

    Returns:
        ``(fields_names, threshold_combinations)``: the field names in input
        order and a list of tuples, one per combination, with values ordered
        like ``fields_names``.

    Raises:
        AssertionError: if ``fields`` or one of its entries is malformed.
    """
    import numbers

    assert isinstance(fields, list), f'Expected a list of fields data, but got {type(fields)}.'
    assert len(fields) > 0, "Fields can't be null."
    fields_names = []
    fields_thresholds = []

    for field in fields:
        assert isinstance(field, dict), f'Expected field data to be a dictionary, but got {type(field)}.'
        assert 'name' in field, f'Expected "name" to be a field in data for field {field}.'
        name = field['name']

        if 'upper_bound' in field:
            upper_bound = field['upper_bound']
            # Accept ints and numpy scalars as well as floats; bool is a number but not a bound.
            assert isinstance(upper_bound, numbers.Real) and not isinstance(upper_bound, bool), \
                f'Expected upper_bound to be a real number (int or float), but got {type(upper_bound)}.'
        else:
            # Only computed when needed, so an explicit bound doesn't require the column in fps_df.
            upper_bound = fps_df[name].mean()

        fields_names.append(name)
        fields_thresholds.append(
            get_all_field_thresholds(tps_df[name].mean(), tps_df[name].std(), upper_bound=upper_bound)
        )

    threshold_combinations = get_all_combinations(fields_thresholds)

    return fields_names, threshold_combinations

def get_metrics(tps_df, fps_df, fields):
    """Score every combination of candidate thresholds for the given fields.

    Args:
        tps_df: DataFrame of samples that should pass the thresholds.
        fps_df: DataFrame of samples that should be rejected.
        fields: list of dicts, each with a ``'name'`` key (a column present in
            both frames) and an optional ``'upper_bound'``. The candidate
            thresholds themselves are derived from that column's mean/std in
            ``tps_df`` by ``get_thresholds``; callers do not pass mean or std.

    Returns:
        DataFrame with one row per threshold combination: a ``<name>_thresh``
        column per field, the percentages ``%tp``, ``%fn`` (of ``tps_df``) and
        ``%tn``, ``%fp`` (of ``fps_df``) rounded to 4 decimals (0 when that
        frame is empty), and ``accuracy``, ``f1_score``, ``precision``,
        ``recall``.
    """
    def percent(count, total):
        # Guard empty frames (e.g. an empty fps_df with explicit upper bounds),
        # which would otherwise raise ZeroDivisionError.
        return round((count / total) * 100, 4) if total else 0.0

    metrics = []
    fields_names, threshold_combinations = get_thresholds(fields, tps_df, fps_df)
    n_true, n_false = len(tps_df), len(fps_df)

    for fields_thresholds in threshold_combinations:
        counts = check_thresh(tps_df, fps_df, fields_names, fields_thresholds)
        tps, fps, tns, fns = counts

        curr_metric = {field + '_thresh': threshold
                       for field, threshold in zip(fields_names, fields_thresholds)}
        curr_metric.update({'%tp': percent(tps, n_true),
                            '%fn': percent(fns, n_true),
                            '%tn': percent(tns, n_false),
                            '%fp': percent(fps, n_false),
                            'accuracy': get_accuracy(counts),
                            'f1_score': get_F1_score(counts),
                            'precision': get_precision(counts),
                            'recall': get_recall(counts)})

        metrics.append(curr_metric)

    return pd.DataFrame(metrics)

In [4]:
def get_best_thresholds(metrics_df, metrics, n=1, return_type='df'):
    """Return the ``n`` best-scoring threshold combinations.

    Args:
        metrics_df: DataFrame with one row per threshold combination, the
            thresholds stored in ``<field>_thresh`` columns (as built by
            ``get_metrics``).
        metrics: column name or list of column names to rank by, descending.
            Equally-scored rows keep their original relative order.
        n: number of top rows to keep.
        return_type: ``'df'`` for the top rows as a DataFrame, or ``'dict'`` for
            a list of ``{field: threshold}`` dicts, one per row in rank order.

    Raises:
        AssertionError: if ``return_type`` is neither ``'df'`` nor ``'dict'``.
    """
    suffix = '_thresh'

    # mergesort is pandas' stable sort, so ties don't get reshuffled.
    ranked = metrics_df.sort_values(by=metrics, ascending=False, kind='mergesort')
    best_thresholds_df = ranked.iloc[:n]

    if return_type == 'df':
        return best_thresholds_df

    assert return_type == 'dict', f'Expected return_type to be an str of "df" or "dict", but got {return_type}.'

    # Only "<field>_thresh" columns hold thresholds; a bare "...thresh" name
    # (no underscore) would otherwise map to a bogus field name.
    thresh_cols = [col for col in best_thresholds_df.columns
                   if isinstance(col, str) and col.endswith(suffix)]
    return [
        {col[:-len(suffix)]: row[col] for col in thresh_cols}
        for _, row in best_thresholds_df.iterrows()
    ]

In [5]:
def validate_thresh(true_df, false_df, list_thresholds):
    """Evaluate each candidate threshold set on separate positive/negative samples.

    A row passes a threshold set when ``row[field] < threshold`` for every
    ``field: threshold`` pair. This is the same strict rule ``check_thresh``
    uses when the thresholds are selected, so a value equal to its threshold
    fails, as does a missing/NaN value. ``None`` thresholds are skipped, also
    matching ``check_thresh``.

    Args:
        true_df: DataFrame of rows that should pass (positives).
        false_df: DataFrame of rows that should fail (negatives).
        list_thresholds: iterable of ``{field_name: threshold}`` dicts, e.g. the
            output of ``get_best_thresholds(..., return_type='dict')``.

    Returns:
        DataFrame with one row per threshold dict: the thresholds plus
        ``tps%``, ``fns%`` (of ``true_df``), ``tns%``, ``fps%`` (of
        ``false_df``; 0 for an empty frame), ``accuracy``, ``f1_score``,
        ``precision`` and ``recall``.
    """
    def passes_mask(df, thresholds):
        # One boolean per row: True when every checked field is strictly below
        # its threshold. Missing values are mapped to False (row fails).
        mask = np.ones(len(df), dtype=bool)
        if len(df) == 0:
            return mask
        for field, threshold in thresholds.items():
            if threshold is not None:
                below = df[field] < threshold
                mask &= below.to_numpy(dtype=bool, na_value=False)
        return mask

    def percent(count, total):
        return count / total * 100 if total else 0.0

    validation = []
    for thresholds in list_thresholds:
        tps = int(passes_mask(true_df, thresholds).sum())
        fns = len(true_df) - tps
        fps = int(passes_mask(false_df, thresholds).sum())
        tns = len(false_df) - fps

        counts = (tps, fps, tns, fns)
        curr_threshold = thresholds.copy()
        curr_threshold.update({'tps%': percent(tps, len(true_df)), 'fns%': percent(fns, len(true_df)),
                               'tns%': percent(tns, len(false_df)), 'fps%': percent(fps, len(false_df)),
                               'accuracy': get_accuracy(counts), 'f1_score': get_F1_score(counts),
                               'precision': get_precision(counts), 'recall': get_recall(counts)})
        validation.append(curr_threshold)

    return pd.DataFrame(validation)

## **Plot distributions**

In [6]:
def plot(true, false, title='', cols=None):
    """Overlay KDE curves of `true` and `false` for each column, two axes per row.

    Args:
        true: DataFrame; for each plotted column its KDE is drawn first on the
            axis (presumably the positive samples — confirm with callers).
        false: DataFrame whose KDE for the same column is drawn on the same axis.
        title: figure title, drawn with fontsize 40.
        cols: column names to plot; defaults to ``true.columns``.
    """
    def get_num_rows(cols, ncols):
        # Number of grid rows needed for len(cols) axes at `ncols` per row.
        # NOTE(review): `len//ncols + len%ncols` equals ceil(len/ncols) only
        # while ncols == 2 (the value used below); with a larger ncols it can
        # allocate too many rows.
        if len(cols)<=ncols:
            nrows=1
        else:
            nrows=len(cols)//ncols + len(cols)%ncols
        
        return nrows
    
    def plot_on_axes(true, false, axes, title):
        # Draw both frames' KDE for column `title` on the same axis.
        ax = true[title].plot(kind="kde", ax=axes, title=title, fontsize=20)
        false[title].plot(kind="kde", ax=ax, title=title, fontsize=20)
    
    if cols is None:
        cols = true.columns
    
    ncols = 2
    nrows = get_num_rows(cols, ncols)
    fig, axes = plt.subplots(nrows=nrows, ncols=ncols, figsize=(30,nrows*5))
    fig.suptitle(title, fontsize=40)
    # NOTE(review): this runs after plt.subplots() and suptitle() above, so it
    # may not affect text that already exists; the inline comment suggests it
    # was meant to run first.
    plt.rcParams.update({'font.size': 20}) # must set in top
    
    # Fill the grid row-major. With a single row the axes array is indexed by
    # column only (1-D), hence the two indexing forms. With an odd number of
    # columns the last axis is left empty in the code shown here.
    nrow, ncol = 0, 0
    for col in cols:
        ax=axes[nrow,ncol] if nrows>1 else axes[ncol]
        plot_on_axes(true, false, ax, col)
        
        ncol+=1
        if ncol%ncols==0:
            ncol=0
            nrow+=1