# Binary vs Order Classifiers

Notebook to compare the outputs of the binary and order classifiers.

In [1]:
import os
import pandas as pd
import matplotlib.pyplot as plt

In [None]:
# Load in all the data
# Every CSV file in the data directory is read into a dict of dataframes,
# keyed by file name.
data = {}

data_dir = '/home/users/katriona/amber-inferences/examples/example_outputs/'

# os.listdir returns entries in arbitrary order; sorting keeps the order in
# which files are processed (and plotted) reproducible between runs.
for file in sorted(os.listdir(data_dir)):
    if file.endswith('.csv'):
        # os.path.join works whether or not data_dir ends with a separator
        data[file] = pd.read_csv(os.path.join(data_dir, file))

In [None]:
# Report the (rows, columns) shape of every loaded dataframe, one per line
for key, frame in data.items():
    print(key, frame.shape)

In [20]:
def plot_confidence_by_class(df, order_name='Lepidoptera Macros'):
    """Scatter a derived moth confidence against the order confidence.

    Rows of ``df`` whose ``order_name`` equals ``order_name`` are selected.
    For each selected row a moth confidence is derived from
    ``class_confidence``: the value itself where ``class_name`` is ``'moth'``
    and ``1 - class_confidence`` otherwise. The derived values are then
    shifted so that their minimum is zero and plotted against
    ``order_confidence``.

    Args:
        df: DataFrame with ``order_name``, ``class_name``,
            ``class_confidence`` and ``order_confidence`` columns.
        order_name: Exact ``order_name`` value to select.

    Returns:
        None. The figure is displayed with ``plt.show()``.
    """
    # Work on an explicit copy: adding a column to a .loc slice of the
    # caller's frame triggers SettingWithCopyWarning and leaves it ambiguous
    # whether the caller's data is modified.
    df_lep = df.loc[df['order_name'] == order_name].copy()

    # Confidence that the detection is a moth: keep the confidence where the
    # class is 'moth', otherwise use its complement.
    class_conf = df_lep['class_confidence'].astype(float)
    df_lep['moth_conf'] = class_conf.where(df_lep['class_name'] == 'moth',
                                           1 - class_conf)

    # Shift so the smallest value sits at zero (an offset only; the values
    # are not rescaled to a fixed range).
    df_lep['moth_conf'] = df_lep['moth_conf'] - df_lep['moth_conf'].min()

    plt.scatter(df_lep['order_confidence'], df_lep['moth_conf'])
    plt.xlabel('Order Confidence')
    plt.ylabel('Moth Confidence')
    plt.title('Moth Confidence vs Order Confidence')
    plt.show()


In [37]:
# get accuracy metrics
def _safe_ratio(numerator, denominator):
    """Return numerator / denominator as a float, or NaN if the denominator is 0."""
    if not denominator:
        return float('nan')
    return float(numerator) / float(denominator)


def get_accuracy_metrics(df):
    """Cross-tabulate the binary class against the order-derived moth label.

    An ``order_moth`` column is written to ``df`` (added or overwritten, so
    the caller's frame is modified): ``'moth'`` where ``order_name`` contains
    ``'Lepi'`` and ``'nonmoth'`` otherwise, including missing order names.
    ``class_name`` (rows) is then cross-tabulated against ``order_moth``
    (columns).

    The rates are computed within each ``class_name`` row:

    * ``fpr``: share of ``class_name == 'moth'`` rows labelled ``'nonmoth'``
      by the order column.
    * ``fnr``: share of ``class_name == 'nonmoth'`` rows labelled ``'moth'``
      by the order column.
    * ``tpr = 1 - fpr`` and ``tnr = 1 - fnr``.
    * ``accuracy``: share of all cross-tabulated rows on the
      moth/moth and nonmoth/nonmoth diagonal.

    A rate whose denominator is zero is returned as NaN.

    Args:
        df: DataFrame with ``class_name`` and ``order_name`` columns.

    Returns:
        list: ``[accuracy, tpr, tnr, fpr, fnr, confusion_matrix]`` where
        ``confusion_matrix`` is the cross-tabulation as a DataFrame that
        always contains ``'moth'`` and ``'nonmoth'`` on both axes.
    """
    labels = ['moth', 'nonmoth']

    # Assign through df.loc in a single step; the chained form
    # df['order_moth'].loc[mask] = ... writes to a temporary object and is a
    # no-op under pandas Copy-on-Write. na=False treats a missing order name
    # as not matching instead of producing a NaN mask.
    is_lepidoptera = df['order_name'].str.contains('Lepi', na=False)
    df['order_moth'] = 'nonmoth'
    df.loc[is_lepidoptera, 'order_moth'] = 'moth'

    # create a confusion matrix
    confusion_matrix = pd.crosstab(df['class_name'], df['order_moth'])

    # crosstab only emits labels that actually occur; make sure both labels
    # exist on both axes (filled with zero counts) so the lookups below cannot
    # raise KeyError. Any other class_name values are kept so the accuracy
    # denominator still counts every cross-tabulated row.
    confusion_matrix = confusion_matrix.reindex(
        index=confusion_matrix.index.union(labels),
        columns=labels,
        fill_value=0,
    )

    class_moth = confusion_matrix.loc['moth']
    class_nonmoth = confusion_matrix.loc['nonmoth']

    fpr = _safe_ratio(class_moth['nonmoth'], class_moth.sum())
    fnr = _safe_ratio(class_nonmoth['moth'], class_nonmoth.sum())
    tpr = 1 - fpr
    tnr = 1 - fnr

    agreeing = class_moth['moth'] + class_nonmoth['nonmoth']
    accuracy = _safe_ratio(agreeing, confusion_matrix.sum().sum())

    return [accuracy, tpr, tnr, fpr, fnr, confusion_matrix]

In [38]:
def get_breakdown(df):
    """Count each (class_name, order_name) pair and flag whether it agrees.

    A pair is 'aligned' when ``class_name`` is ``'moth'`` and ``order_name``
    contains ``'Lepidoptera'``, or when ``class_name`` is ``'nonmoth'`` and
    ``order_name`` does not contain it; every other pair is 'not aligned'.
    The overall aligned percentage is printed as a side effect.

    Args:
        df: DataFrame with ``class_name`` and ``order_name`` columns.

    Returns:
        tuple: ``(breakdown, aligned)`` where ``breakdown`` has the columns
        ``class_name``, ``order_name``, ``count``, ``percentage`` (of all
        counted rows, rounded to 2 decimals) and ``aligned``, and
        ``aligned`` is the unrounded percentage of counted rows that fall in
        aligned pairs.
    """
    # One row per distinct (class_name, order_name) combination
    pair_counts = df[['class_name', 'order_name']].value_counts().reset_index()
    pair_counts.columns = ['class_name', 'order_name', 'count']

    total = pair_counts['count'].sum()

    # Share of all counted rows held by each combination
    pair_counts['percentage'] = (pair_counts['count'] / total * 100).round(2)

    # The two classifiers agree when both say moth or both say not-moth
    is_lepidoptera = pair_counts['order_name'].str.contains('Lepidoptera')
    moth_agrees = (pair_counts['class_name'] == 'moth') & is_lepidoptera
    nonmoth_agrees = (pair_counts['class_name'] == 'nonmoth') & ~ is_lepidoptera

    pair_counts['aligned'] = 'not aligned'
    pair_counts.loc[moth_agrees, 'aligned'] = 'aligned'
    pair_counts.loc[nonmoth_agrees, 'aligned'] = 'aligned'

    aligned_rows = pair_counts['aligned'] == 'aligned'
    aligned = pair_counts.loc[aligned_rows, 'count'].sum() / total * 100
    print(f'{aligned.round(2)}% of predictions are aligned')

    return pair_counts, aligned

In [44]:
def plot_breakdown(breakdown, title, scale=False):
    """Draw a stacked bar chart of order counts for each binary class.

    One bar is drawn per ``class_name``, stacked by ``order_name``. Segments
    whose order name contains ``'Lepidoptera'`` are hatched with ``'*'``.

    Args:
        breakdown: DataFrame with ``class_name``, ``order_name`` and
            ``count`` columns, one row per (class_name, order_name) pair.
        title: Title of the figure.
        scale: If True, each bar is divided by its own total so the segments
            show proportions within the class instead of raw counts.

    Returns:
        None. The figure is displayed with ``plt.show()``.
    """
    # Drop rows whose order name contains 'data' or 'gbr'.
    # NOTE(review): presumably these are non-taxonomic placeholder labels -
    # confirm against the classifier output.
    order_names = breakdown['order_name']
    excluded = order_names.str.contains('data') | order_names.str.contains('gbr')
    breakdown = breakdown.loc[~ excluded]

    # Create pivot table for stacked bar chart: rows are classes, columns
    # are orders
    pivot_table = breakdown.pivot(index='class_name', columns='order_name', values='count')

    if scale:
        pivot_table = pivot_table.div(pivot_table.sum(axis=1), axis=0)

    # Plot stacked bar chart
    ax = pivot_table.plot(kind='bar', stacked=True)

    # Add hatching manually: the plot creates one bar container per pivot
    # column, in column order, so the two can be walked together.
    for bar_group, order in zip(ax.containers, pivot_table.columns):
        if 'Lepidoptera' in order:
            for bar in bar_group:
                bar.set_hatch('*')

    # The y axis holds proportions once the bars are scaled, so it must not
    # be labelled as a count in that case.
    ax.set_ylabel('Proportion' if scale else 'Count')
    ax.set_title(title)
    plt.show()

In [None]:
# Collect one summary row per file and build the table once at the end.
# Growing a DataFrame with pd.concat inside the loop re-copies the table on
# every iteration, and concatenating onto an initially empty frame raises a
# FutureWarning in recent pandas versions.
summary_columns = ['file', 'percent_aligned', 'accuracy',
                   'tpr', 'tnr', 'fpr', 'fnr']
summary_rows = []

# for each dataframe in data, get the breakdown and plot
for key in data:
    print(f"Dataframe: {key}")
    breakdown, aligned = get_breakdown(data[key])

    accuracy, tpr, tnr, fpr, fnr, confusion_matrix = get_accuracy_metrics(data[key])

    summary_rows.append({'file': key, 'percent_aligned': aligned,
                         'accuracy': accuracy, 'tpr': tpr, 'tnr': tnr,
                         'fpr': fpr, 'fnr': fnr})

    # Use the file name, without extension and underscores, as the plot title
    plot_breakdown(breakdown, key.replace('.csv', '').replace('_', ' '))
    #plot_confidence_by_class(data[key])
    print('\n\n\n')

# Passing columns explicitly keeps the expected layout even when no file
# was loaded.
align_df = pd.DataFrame(summary_rows, columns=summary_columns)

In [None]:
# Display the per-file summary table built in the previous cell
align_df