-
Notifications
You must be signed in to change notification settings - Fork 0
/
utils.py
91 lines (76 loc) · 3.35 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
import seaborn as sns
import matplotlib.pyplot as plt
import numpy as np
from sklearn.metrics import classification_report, confusion_matrix, roc_auc_score, roc_curve
def get_all_roc_coordinates(y_real, y_proba):
    '''
    Computes the full set of ROC curve coordinates (TPR and FPR), treating each
    predicted probability in turn as the decision threshold.

    Args:
        y_real: The list or series with the real classes.
        y_proba: The array with the probabilities for each class, obtained by using the `.predict_proba()` method.

    Returns:
        tpr_list: The list of TPRs representing each threshold.
        fpr_list: The list of FPRs representing each threshold.
    '''
    # Seed both lists with the (0, 0) corner of the ROC curve.
    tpr_list, fpr_list = [0], [0]
    for threshold in y_proba:
        # Classify every sample against the current threshold.
        predictions = y_proba >= threshold
        tpr, fpr = calculate_tpr_fpr(y_real, predictions)
        tpr_list.append(tpr)
        fpr_list.append(fpr)
    return tpr_list, fpr_list
def calculate_tpr_fpr(y_real, y_pred):
    '''
    Calculates the True Positive Rate (tpr) and the False Positive Rate (fpr)
    based on real and predicted observations.

    Args:
        y_real: The list or series with the real binary classes (0/1 or bool).
        y_pred: The list or series with the predicted binary classes.

    Returns:
        tpr: The True Positive Rate of the classifier.
        fpr: The False Positive Rate of the classifier.
    '''
    # Count the confusion-matrix cells directly with numpy. The original used
    # sklearn's confusion_matrix, which returns a 1x1 matrix when only one
    # class is present in the data, making the cm[1, 1] indexing crash.
    real = np.asarray(y_real).astype(bool)
    pred = np.asarray(y_pred).astype(bool)
    TP = np.sum(real & pred)
    FN = np.sum(real & ~pred)
    TN = np.sum(~real & ~pred)
    FP = np.sum(~real & pred)
    # Guard the degenerate case of an empty class instead of dividing by zero.
    tpr = TP / (TP + FN) if (TP + FN) > 0 else 0.0   # sensitivity - true positive rate
    fpr = FP / (FP + TN) if (FP + TN) > 0 else 0.0   # 1-specificity - false positive rate
    return tpr, fpr
def plot_roc_curve(tpr, fpr, scatter = True, ax = None):
    '''
    Plots the ROC Curve by using the list of coordinates (tpr and fpr).

    Args:
        tpr: The list of TPRs representing each coordinate.
        fpr: The list of FPRs representing each coordinate.
        scatter: When True, the points used on the calculation will be plotted with the line (default = True).
        ax: An existing matplotlib Axes to draw on; when None, a new 5x5-inch
            figure and axes are created.
    '''
    if ax is None:  # identity check is the idiomatic comparison against None
        plt.figure(figsize = (5, 5))
        ax = plt.axes()
    if scatter:
        sns.scatterplot(x = fpr, y = tpr, ax = ax)
    sns.lineplot(x = fpr, y = tpr, ax = ax)
    # Chance-level diagonal as a reference line.
    sns.lineplot(x = [0, 1], y = [0, 1], color = 'green', ax = ax)
    # Configure the target axes directly: the plt.* state-machine calls in the
    # original acted on the *current* axes, so limits and labels were lost
    # whenever a caller passed in its own `ax`.
    ax.set_xlim(-0.05, 1.05)
    ax.set_ylim(-0.05, 1.05)
    ax.set_xlabel("False Positive Rate")
    ax.set_ylabel("True Positive Rate")
def print_conf_matrix(conf_matrix, model):
    '''
    Renders a 3x3 confusion matrix (normal / suspect / pathological classes)
    as an annotated heatmap and saves it to "<model> conf matrix.png".

    Args:
        conf_matrix: The 3x3 confusion matrix array (true classes on rows,
            predicted classes on columns).
        model: The model name, used to build the output file name.
    '''
    group_names = ['c-normal, p-normal','c-normal, p-suspect','c-normal, p-path',
                   'c-suspect, p-normal', 'c-suspect, p-suspect', 'c-suspect, p-path',
                   'c-path, p-normal','c-path, p-suspect','c-path, p-path']
    # Each cell is annotated with its class pair, raw count, and share of all samples.
    total = np.sum(conf_matrix)
    cell_texts = [
        f"{name}\n{count:0.0f}\n{count / total:.2%}"
        for name, count in zip(group_names, conf_matrix.flatten())
    ]
    annotations = np.asarray(cell_texts).reshape(3, 3)
    plt.figure(figsize=(10,10))
    plt.xlabel("Predicted class")
    plt.ylabel("True class")
    heatmap = sns.heatmap(conf_matrix, annot=annotations, fmt='', cmap='Blues', cbar=False)
    heatmap.get_figure().savefig(f"{model} conf matrix.png")