# Number of samples (above a certain threshold) vs. accuracy

In [None]:
import mlflow
import pandas as pd
import ipywidgets as widgets
from ipywidgets import interact
import numpy as np
from sklearn.metrics import roc_auc_score
from tqdm import tqdm

import sys
import glob
sys.path.append('../')

from src import GLOBAL_ARTIFACTS_PATH
import matplotlib.pyplot as plt 
# %matplotlib notebook

# Point the MLflow client at the tracking server on localhost, port 5005.
mlflow.set_tracking_uri('http://127.0.0.1:5005')
# Other candidate run ids are kept commented out; only the last assignment
# below is active and selects the run analysed in the rest of the notebook.
# run_id = 'dd9271cad9c4492b96905d2c16bfdcda'
# run_id = 'b310e158a398481f8c30175b7571139b' # PECertainty Multi, Uniform Dropout, 0.25
# run_id = '46a5847a72694b539a76896e5a9e6b09' # SoftMax Multi, Uniform Dropout, 0.25
run_id = '4d9affd3520e40d1ab77957307020c85' # SoftMax Multi, not finished, Uniform Dropout, 0.25

## Prepare Dataframe

In [None]:
# Look up the selected run and map its artifact URI onto the local artifact
# root: the last four URI components are appended to GLOBAL_ARTIFACTS_PATH.
run = mlflow.get_run(run_id)
uri_tail = run.info.artifact_uri.split('/')[-4:]
artifacts_path = GLOBAL_ARTIFACTS_PATH + '/' + '/'.join(uri_tail)
dfs_paths = glob.glob(artifacts_path + '/*')
# Bare expression kept from the original cell; it is not the last statement
# of the cell, so it has no visible effect.
artifacts_path

# Read every file found in the artifact directory as CSV (with a progress
# bar), then stack the frames in reverse glob order under a fresh index.
dfs = [pd.read_csv(csv_path) for csv_path in tqdm(dfs_paths)]
df = pd.concat(dfs[::-1], ignore_index=True)
# Store correctness as integers so that its mean can be read as an accuracy.
df['correctness'] = df['correctness'].astype(int)

In [None]:
# Echo the concatenated dataframe as the cell output for a quick visual check.
df

In [None]:
# Distinct values of the `strategy` and `epoch` columns, in order of first
# appearance; the plotting cell iterates over both.
strategies = df['strategy'].unique()
epochs = df['epoch'].unique()

## Plotting

In [None]:
# Per-strategy line appearance as (name, linestyle, marker, markevery).
# The name is informational only; plotting reads positions 1-3.  Every entry
# carries all four fields so that any strategy index can be styled (the
# original trailing entries were 2-tuples and raised IndexError when used).
linestyle_tuple = [
    ('solid', '-', 'v', 80),
    ('dotted', (0, (1, 1)), '', 100),
    ('densely dotted', 'solid', 'D', 120),
    ('loosely dashed', 'solid', '*', 140),
    ('dashed', (0, (5, 5)), '', 160),
    ('densely dashed', (0, (5, 1)), '', 180),
    ('loosely dotted', (0, (1, 10)), 'o', 200),

    ('loosely dashdotted', (0, (3, 10, 1, 10)), '', None),
    ('dashdotted', (0, (3, 5, 1, 5)), '', None),
    ('densely dashdotted', (0, (3, 1, 1, 1)), '', None),

    ('dashdotdotted', (0, (3, 5, 1, 5, 1, 5)), '', None),
    ('loosely dashdotdotted', (0, (3, 10, 1, 10, 1, 10)), '', None),
    ('densely dashdotdotted', (0, (3, 1, 1, 1, 1, 1)), '', None),
]

# One figure per epoch.  Each curve shows, for one strategy, the mean
# correctness of the top-x samples ranked by descending score, as a function
# of x.
# Iterate over the epochs actually present in the data: range(max(epochs))
# never reached the highest epoch and would draw empty figures for epoch ids
# that do not occur.
for epoch in sorted(epochs):
    plt.figure(figsize=(15, 10))
    group = df[df.epoch == epoch]
    for idx, strategy in enumerate(strategies):
        # Wrap around so more strategies than styles cannot raise IndexError.
        _, linestyle, marker, markevery = linestyle_tuple[idx % len(linestyle_tuple)]

        df_strat = group[group.strategy == strategy]
        total_number = len(df_strat)

        # Rank once per strategy; the ordering does not depend on x.
        ranked_correctness = df_strat.sort_values('score', ascending=False).correctness

        # 1000 sample counts from 10 up to the full group size.  The builtin
        # `int` replaces `np.int`, which was removed in NumPy 1.24.
        x_values = np.linspace(10, total_number, 1000, endpoint=True, dtype=int)

        # Accuracy among the x highest-scored samples.
        y_values = [ranked_correctness.head(x).mean() for x in tqdm(x_values)]

        plt.plot(
            x_values,
            y_values,
            label=strategy,
            linestyle=linestyle,
            marker=marker,
            linewidth=2,
            markevery=markevery,
        )
    plt.legend(loc=4)
    plt.title(f'epoch: {epoch}')
    plt.show()