#### Load Hashes and Scores

- Load in a set of hashes and distance scores that have already been calculated

In [None]:
import os
import ipywidgets as widgets
import pandas as pd
import numpy as np
from joblib import load
from phaser.utils import load_labelencoders


# Directory that holds the previously generated encoders, hashes and distances
hash_dist_dir = r"./demo_outputs/"

# Label encoders, indexed by "t" (transforms), "m" (metrics) and "a" (algorithms)
le = load_labelencoders(filename="LabelEncoders.bz2", path=hash_dist_dir)

# Class names of each encoder; used later to build the plot grids and dropdowns
TRANSFORMS, METRICS, ALGORITHMS = (le[key].classes_ for key in ("t", "m", "a"))

# Alternative: load from CSV
#df_h = pd.read_csv(os.path.join(hash_dist_dir , "Hashes.csv.bz2"))
#df_d = pd.read_csv(os.path.join(hash_dist_dir , "Distances.csv.bz2"))

# Default: load the serialised dataframes (a better option for larger datasets)
df_h, df_d = (
    load(os.path.join(hash_dist_dir, fname))
    for fname in ("Hashes.df.bz2", "Distances.df.bz2")
)

# Split the distances by class: Intra (1), Inter (0)
intra_df = df_d.loc[df_d["class"] == 1]
inter_df = df_d.loc[df_d["class"] == 0]

print(df_h)


#### Prepare Plots

- Run this before the plot segments below.
- Allows for some configurability of plots

In [None]:
from phaser.evaluation import MetricMaker
from phaser.plotting import  hist_fig, kde_ax, eer_ax, roc_ax
import matplotlib.pyplot as plt
from ipywidgets import interactive
import warnings
warnings.simplefilter(action='ignore', category=FutureWarning) # Ignore Seaborn warnings due to underlying package using future deprecated calls


# Dropdown selectors shared by the interactive plots further down.
# The transform list leaves out the last entry of TRANSFORMS.
tselect = widgets.Dropdown(description='Transform', options=TRANSFORMS[:-1])
mselect = widgets.Dropdown(description='Metric', options=METRICS)
aselect = widgets.Dropdown(description='Algorithm', options=ALGORITHMS)
modeselect = widgets.Dropdown(description='Comparison Mode', options=["inter", "intra"])


### Hist plots, separate for intra/inter
def plot_image(transform, mode, bins=25, width=8, height=6):
    """Draw similarity-score histograms for one transform and comparison mode.

    Args:
        transform: Name of the transform to plot. The placeholder value
            ``'Select'`` draws nothing.
        mode: ``"inter"`` to plot from ``inter_df`` or ``"intra"`` to plot
            from ``intra_df``.
        bins: Number of histogram bins; nothing is drawn unless it is above 1.
        width: Figure width, passed on as part of ``figsize``.
        height: Figure height, passed on as part of ``figsize``.

    Raises:
        ValueError: If ``mode`` is neither ``"inter"`` nor ``"intra"``.
    """
    # Nothing to draw for the placeholder transform or a degenerate bin count.
    if transform == 'Select' or bins <= 1:
        return

    if mode == "inter":
        distances = inter_df
    elif mode == "intra":
        distances = intra_df
    else:
        # Previously an unknown mode fell through and failed on an unbound `fig`.
        raise ValueError(f"mode must be 'inter' or 'intra', got {mode!r}")

    fig = hist_fig(distances, label_encoding=le, transform=transform, interactive=True,
                   bins=bins, figsize=(width, height))
    fig.suptitle("Similarity Histograms")
        

### KDE multi plot
def kde_plot_multi(transform, width=8, height=6):
    """Draw a grid of inter/intra KDE plots for one transform.

    The grid has one column per entry of ``METRICS`` and one row per entry of
    ``ALGORITHMS``; each cell is drawn by ``kde_ax`` from the rows of ``df_d``
    that match that algorithm/metric pair.

    Args:
        transform: Name of the transform to plot. The placeholder value
            ``'Select'`` draws nothing.
        width: Figure width, passed on as part of ``figsize``.
        height: Figure height, passed on as part of ``figsize``.
    """
    if transform == 'Select':
        return

    # Encode every metric/algorithm name once instead of once per grid cell.
    metric_labels = le["m"].transform(METRICS)
    algo_labels = le["a"].transform(ALGORITHMS)

    fig, axes = plt.subplots(ncols=len(METRICS), nrows=len(ALGORITHMS),
                             figsize=(width, height), constrained_layout=True,
                             sharex=True, sharey=False, squeeze=False)

    for col_i, (metric, m_label) in enumerate(zip(METRICS, metric_labels)):
        for row_i, (algo, a_label) in enumerate(zip(ALGORITHMS, algo_labels)):
            # Subset the distances to this algorithm/metric pair. Plain ints keep
            # the query string independent of how NumPy arrays are rendered.
            _X = df_d.query(f"algo == {int(a_label)} and metric == {int(m_label)}")

            kde_ax(_X, transform, label_encoding=le, fill=True,
                   title=f"{algo}-{metric}", ax=axes[row_i, col_i])
    fig.suptitle("Inter/Intra-Score KDE Plots")
        

### EER multi plot
def eer_plot_multi(transform, width=8, height=6):
    """Draw a grid of Equal Error Rate (EER) plots for one transform.

    The grid has one column per entry of ``METRICS`` and one row per entry of
    ``ALGORITHMS``. For each cell a ``MetricMaker`` is built from the matching
    rows of ``df_d`` and plotted with ``eer_ax`` at its ``eer_thresh``.

    Args:
        transform: Name of the transform (a column of ``df_d``) to plot. The
            placeholder value ``'Select'`` draws nothing.
        width: Figure width, passed on as part of ``figsize``.
        height: Figure height, passed on as part of ``figsize``.
    """
    if transform == 'Select':
        return

    # Encode every metric/algorithm name once instead of once per grid cell.
    metric_labels = le["m"].transform(METRICS)
    algo_labels = le["a"].transform(ALGORITHMS)

    fig, axes = plt.subplots(ncols=len(METRICS), nrows=len(ALGORITHMS),
                             figsize=(width, height), constrained_layout=True,
                             sharex=True, sharey=False, squeeze=False)

    for col_i, (metric, m_label) in enumerate(zip(METRICS, metric_labels)):
        for row_i, (algo, a_label) in enumerate(zip(ALGORITHMS, algo_labels)):
            # Subset the distances to this algorithm/metric pair. Plain ints keep
            # the query string independent of how NumPy arrays are rendered.
            _X = df_d.query(f"algo == {int(a_label)} and metric == {int(m_label)}")

            # True class labels and the similarity scores for the chosen transform
            mm = MetricMaker(y_true=_X["class"], y_similarity=_X[transform], weighted=False)

            # Plot the error-rate curves, marking the EER threshold
            eer_ax(mm.fpr, mm.tpr, mm.thresholds, threshold=mm.eer_thresh, legend="",
                   title=f"{algo}-{metric}", ax=axes[row_i, col_i])
    fig.suptitle("EER Plots")
        

### ROC multi plot
def roc_plot_multi(transform, width=8, height=6):
    """Draw a grid of ROC plots for one transform.

    The grid has one column per entry of ``METRICS`` and one row per entry of
    ``ALGORITHMS``. For each cell a ``MetricMaker`` is built from the matching
    rows of ``df_d`` and its ``fpr``, ``tpr`` and ``auc`` are passed to ``roc_ax``.

    Args:
        transform: Name of the transform (a column of ``df_d``) to plot. The
            placeholder value ``'Select'`` draws nothing.
        width: Figure width, passed on as part of ``figsize``.
        height: Figure height, passed on as part of ``figsize``.
    """
    if transform == 'Select':
        return

    # Encode every metric/algorithm name once instead of once per grid cell.
    metric_labels = le["m"].transform(METRICS)
    algo_labels = le["a"].transform(ALGORITHMS)

    fig, axes = plt.subplots(ncols=len(METRICS), nrows=len(ALGORITHMS),
                             figsize=(width, height), constrained_layout=True,
                             sharex=True, sharey=False, squeeze=False)

    for col_i, (metric, m_label) in enumerate(zip(METRICS, metric_labels)):
        for row_i, (algo, a_label) in enumerate(zip(ALGORITHMS, algo_labels)):
            # Subset the distances to this algorithm/metric pair. Plain ints keep
            # the query string independent of how NumPy arrays are rendered.
            _X = df_d.query(f"algo == {int(a_label)} and metric == {int(m_label)}")

            # True class labels and the similarity scores for the chosen transform
            mm = MetricMaker(y_true=_X["class"], y_similarity=_X[transform], weighted=False)

            # Plot the ROC curve together with its AUC
            roc_ax(mm.fpr, mm.tpr, mm.auc, title=f"{algo}-{metric}", ax=axes[row_i, col_i])
    fig.suptitle("ROC Plots")
        

#### Similarity Score Histograms

- Normalised counts of scores for each hash/transform
- Same data as the KDE plots, but allows for a better understanding of any gaps in the distributions.
- Updates as long as the number of bins is over 1.
- Ideally: Inter distribution is normally distributed around 0.5, while the intra similarity is as high as possible.

In [None]:
# Similarity Score Histograms
# Wrap plot_image in an interactive widget driven by the transform and
# comparison-mode dropdowns, then show it in the notebook output.
# NOTE(review): an earlier comment mentioned an optional save_location argument,
# but plot_image as defined in this notebook takes no such parameter.
h = interactive(plot_image, transform=tselect, mode=modeselect)
display(h)

#### Kernel Density Estimation (KDE)

- Combined plot for Inter/Intra scores.
- More or less the same as Histograms, but estimates probability density.
- Ideally, both classes should be completely non-overlapping. Overlap is expected for difficult transforms and indicates difficulty in setting a threshold to separate them, resulting in False Positives / False Negatives.

In [None]:
# Similarity Kernel Density Estimation (KDE) for inter/intra classes
# Wrap kde_plot_multi in an interactive widget driven by the transform dropdown,
# then show it in the notebook output.
k = interactive(kde_plot_multi, transform=tselect)
display(k)

#### Error Rate

- Visualise the False Positive Rate (FPR) and False Negative Rate (FNR) trade-offs across the similarity score spectrum.
- The vertical line represents the score at which FPR == FNR, i.e. the Equal Error Rate threshold (EERt).

In [None]:
# Equal Error Rate (EER) similarity plots
# Wrap eer_plot_multi in an interactive widget driven by the transform dropdown,
# then show it in the notebook output.
eer = interactive(eer_plot_multi, transform=tselect)
display(eer)


#### Receiver Operating Characteristic (ROC)

- Plot TPR vs FPR to visualise trade-offs.
- Provides Area Under the Curve (AUC) as a means of summarising overall performance. Larger AUC (up to 1.0) is better.

In [None]:
# Receiver Operating Characteristic (ROC) similarity plots
# Wrap roc_plot_multi in an interactive widget driven by the transform dropdown,
# then show it in the notebook output.
roc = interactive(roc_plot_multi, transform=tselect)
display(roc)
