In [6]:
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
import pandas as pd
from pathlib import PurePath
from json import load

In [1]:
SMALL_SIZE = 12
MEDIUM_SIZE = 18
BIGGER_SIZE = 26
CHONK_SIZE = 32
font = {'family' : 'DIN Condensed',
        'weight' : 'bold',
        'size'   : SMALL_SIZE}
plt.rc('font', **font)
plt.rc('axes', titlesize=BIGGER_SIZE, labelsize=MEDIUM_SIZE, facecolor="xkcd:white")
plt.rc('xtick', labelsize=SMALL_SIZE)    # fontsize of the tick labels
plt.rc('ytick', labelsize=SMALL_SIZE)    # fontsize of the tick labels
plt.rc('legend', fontsize=SMALL_SIZE)    # legend fontsize
plt.rc('figure', titlesize=CHONK_SIZE, facecolor="xkcd:white", edgecolor="xkcd:black") #  powder blue

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

In [3]:
def calc_model_performance(tpr, fpr, thresholds):
    performance_dict = {}
    
    # calculating auc
    performance_dict["AUC"] = np.trapz(y = tpr, x = fpr)
    
    # calculating eer
    tnr = 1 - np.array(tpr)
    scores = np.array((tnr, fpr)).T
    diffs = np.absolute(scores[:, 0] - scores[:, 1])
    min_index = np.argmin(diffs)
    lowest_threshold = thresholds[min_index]
    eer = (tnr[min_index] + fpr[min_index]) / 2
    performance_dict["EER"] = eer
    performance_dict["Threshold"] = lowest_threshold

    return performance_dict

In [4]:

def plot_ROC_curve(tpr, fpr, thresholds, performance, model_name, output_folder):
    fig, ax = plt.subplots()
    ax.fill_between(fpr, tpr)
    sns.scatterplot(x = fpr, y = tpr, ax = ax)
    ax.set_xlabel("False Positive Rate")
    ax.set_ylabel("True Positive Rate")

    v, h = .1, .1
    ax.set_xlim(0-h, 1+h)
    ax.set_ylim(0-v, 1+v)

    # Loop through the data points 
    for i, threshold in enumerate (thresholds):
        plt.text(fpr[i], tpr[i], threshold)

    vals = [i for i in np.arange(0, 1, 0.01)]
    xp = [i for i in np.arange(1, 0, -0.01)]
    sns.lineplot(x = vals, y = vals, ax = ax, color = "red")
    sns.lineplot(x = vals, y = xp, ax = ax, color = "green")
    fig.suptitle(f"ROC Curve: {model_name}")
    fig.set_size_inches(10, 7)
    start = .2
    gap = .2
    height = 1.05
    plt.text(start, height, f"EER: {round(performance['EER'] * 100, 1)}%")
    plt.text(start + gap, height, f"Threshold: {round(performance['Threshold'], 3)}")
    plt.text(start + 2*gap + 0.05, height, f"AUC: {round(performance['AUC'], 3)}")

    plt.savefig(f"{output_folder}{model_name}", dpi = 400)

    pass

In [10]:
read_path = PurePath("/Users/joshuaelms/Desktop/github_repos/nsf-reu2022/data/simulation_results/tpr_fpr_Manhattan.json")
with open(read_path, "r") as f:
    data = load(f)

In [9]:
t_start, t_stop, t_step = 0, 10, 1
thresholds = [round(i, 2) for i in np.arange(t_start, t_stop, t_step)]
aggregate_data = {str(threshold): {"tpr": [], "fpr": []} for threshold in thresholds}
for user in data:
    tpr, fpr = data[user][threshold]["tpr"], data[user][threshold]["fpr"]
    if tpr:
        aggregate_data[user]["tpr"].append(tpr)

    if fpr: 
        aggregate_data[user]["fpr"].append(fpr)

KeyError: 'tpr'

In [12]:
data

{'84500': {'0': {'tpr': 0.0, 'fpr': 0.0},
  '1': {'tpr': 0.11392405063291139, 'fpr': 0.07263594138759497},
  '2': {'tpr': 0.4177215189873418, 'fpr': 0.2993871855379043},
  '3': {'tpr': 0.4936708860759494, 'fpr': 0.5410502389008431},
  '4': {'tpr': 0.6582278481012658, 'fpr': 0.6957801688762092},
  '5': {'tpr': 0.759493670886076, 'fpr': 0.7921340920106242},
  '6': {'tpr': 0.7974683544303798, 'fpr': 0.862652147873684},
  '7': {'tpr': 0.8481012658227848, 'fpr': 0.9027339078986492},
  '8': {'tpr': 0.8860759493670886, 'fpr': 0.9204549110334455},
  '9': {'tpr': 0.9367088607594937, 'fpr': 0.9365268750514364}},
 '864651': {'0': {'tpr': 0.0, 'fpr': 0.0},
  '1': {'tpr': 0.06329113924050633, 'fpr': 0.014887670605434588},
  '2': {'tpr': 0.21518987341772153, 'fpr': 0.08483435153395881},
  '3': {'tpr': 0.3670886075949367, 'fpr': 0.19450867372558803},
  '4': {'tpr': 0.5189873417721519, 'fpr': 0.3311187412460628},
  '5': {'tpr': 0.6582278481012658, 'fpr': 0.4971906008001003},
  '6': {'tpr': 0.772151898