Skip to content

Commit 4be2cd4

Browse files
authored
Add files via upload
1 parent 25e8b59 commit 4be2cd4

File tree

1 file changed

+143
-0
lines changed

1 file changed

+143
-0
lines changed

prec_rec_curve.py

Lines changed: 143 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,143 @@
1+
import numpy as np
2+
from sklearn.metrics import confusion_matrix, precision_score, recall_score
3+
import matplotlib.pyplot as plt
4+
import matplotlib.patches as ptch
5+
6+
# Appendix A - working with single threshold
7+
pred_scores = [0.7, 0.3, 0.5, 0.6, 0.55, 0.9, 0.4, 0.2, 0.4, 0.3]
8+
y_true = ["positive", "negative", "negative", "positive", "positive", "positive", "negative", "positive", "negative", "positive"]
9+
10+
# To convert the scores into a class label, a threshold is used.
11+
# When the score is equal to or above the threshold, the sample is classified as one class.
12+
# Otherwise, it is classified as the other class.
13+
# Suppose a sample is Positive if its score is above or equal to the threshold. Otherwise, it is Negative.
14+
# The next block of code converts the scores into class labels with a threshold of 0.5.
15+
16+
threshold = 0.5
17+
18+
y_pred = ["positive" if score >= threshold else "negative" for score in pred_scores]
19+
print(y_pred)
20+
21+
r = np.flip(confusion_matrix(y_true, y_pred))
22+
print("\n# Confusion Matrix (From Left to Right & Top to Bottom: \nTrue Positive, False Negative, \nFalse Positive, True Negative)")
23+
print(r)
24+
25+
# Remember that the higher the precision, the more confident the model is when it classifies a sample as Positive.
26+
# Higher the recall, the more positive samples the model correctly classified as Positive.
27+
28+
precision = precision_score(y_true=y_true, y_pred=y_pred, pos_label="positive")
29+
print("\n# Precision = 4/(4+1)")
30+
print(precision)
31+
32+
recall = recall_score(y_true=y_true, y_pred=y_pred, pos_label="positive")
33+
print("\n# Recall = 4/(4+2)")
34+
print(recall)
35+
36+
# Appendix B - working with multiple thresholds
37+
y_true = ["positive", "negative", "negative", "positive", "positive", "positive", "negative", "positive", "negative", "positive", "positive", "positive", "positive", "negative", "negative", "negative"]
38+
39+
pred_scores = [0.7, 0.3, 0.5, 0.6, 0.55, 0.9, 0.4, 0.2, 0.4, 0.3, 0.7, 0.5, 0.8, 0.2, 0.3, 0.35]
40+
41+
thresholds = np.arange(start=0.2, stop=0.7, step=0.05)
42+
43+
# Due to the importance of both precision and recall, there is a precision-recall curve that shows
44+
# the tradeoff between the precision and recall values for different thresholds.
45+
# This curve helps to select the best threshold to maximize both metrics
46+
47+
def precision_recall_curve(y_true, pred_scores, thresholds):
48+
precisions = []
49+
recalls = []
50+
f1_scores = []
51+
52+
for threshold in thresholds:
53+
y_pred = ["positive" if score >= threshold else "negative" for score in pred_scores]
54+
55+
precision = precision_score(y_true=y_true, y_pred=y_pred, pos_label="positive")
56+
recall = recall_score(y_true=y_true, y_pred=y_pred, pos_label="positive")
57+
f1_score = (2 * precision * recall) / (precision + recall)
58+
59+
precisions.append(precision)
60+
recalls.append(recall)
61+
f1_scores.append(f1_score)
62+
63+
return precisions, recalls, f1_scores
64+
65+
precisions, recalls, f1_scores = precision_recall_curve(y_true=y_true,
66+
pred_scores=pred_scores,
67+
thresholds=thresholds)
68+
69+
print("\nRecall:: Precision :: F1-Score",)
70+
for p, r, f in zip(precisions, recalls, f1_scores):
71+
print(round(r,4),"\t::\t",round(p,4),"\t::\t",round(f,4))
72+
73+
# np.max() returns the max. value in the array
74+
# np.argmax() will return the index of the value found by np.max()
75+
76+
print('Best F1-Score: ', np.max(f1_scores))
77+
idx_best_f1 = np.argmax(f1_scores)
78+
print('\nBest threshold: ', thresholds[idx_best_f1])
79+
print('Index of threshold: ', idx_best_f1)
80+
81+
# Can disable comment to display the plot
82+
83+
# plt.plot(recalls, precisions, linewidth=4, color="red")
84+
# plt.scatter(recalls[idx_best_f1], precisions[idx_best_f1], zorder=1, linewidth=6)
85+
# plt.xlabel("Recall", fontsize=12, fontweight='bold')
86+
# plt.ylabel("Precision", fontsize=12, fontweight='bold')
87+
# plt.title("Precision-Recall Curve", fontsize=15, fontweight="bold")
88+
# plt.show()
89+
90+
# Appendix C - average precision (AP)
91+
precisions, recalls, f1_scores = precision_recall_curve(y_true=y_true,
92+
pred_scores=pred_scores,
93+
thresholds=thresholds)
94+
95+
precisions.append(1)
96+
recalls.append(0)
97+
98+
precisions = np.array(precisions)
99+
recalls = np.array(recalls)
100+
101+
print('\nRecall ::',recalls)
102+
print('Precision ::',precisions)
103+
104+
AP = np.sum((recalls[:-1] - recalls[1:]) * precisions[:-1])
105+
print("\nAP --", AP)
106+
107+
# Appendix D - Intersection over Union
108+
109+
# gt_box -- ground-truth bounding box
110+
# pred_box -- prediction bounding box
111+
def intersection_over_union(gt_box, pred_box):
112+
113+
inter_box_top_left = [max(gt_box[0], pred_box[0]), max(gt_box[1], pred_box[1])]
114+
115+
print("\ninter_box_top_left:", inter_box_top_left)
116+
print("gt_box:", gt_box)
117+
print("pred_box:", pred_box)
118+
inter_box_bottom_right = [min(gt_box[0]+gt_box[2], pred_box[0]+pred_box[2]), min(gt_box[1]+gt_box[3], pred_box[1]+pred_box[3])]
119+
print("inter_box_bottom_right:", inter_box_bottom_right)
120+
121+
inter_box_w = inter_box_bottom_right[0] - inter_box_top_left[0]
122+
print("inter_box_w:", inter_box_w)
123+
inter_box_h = inter_box_bottom_right[1] - inter_box_top_left[1]
124+
print("inter_box_h:", inter_box_h)
125+
126+
intersection = inter_box_w * inter_box_h
127+
union = gt_box[2] * gt_box[3] + pred_box[2] * pred_box[3] - intersection
128+
129+
iou = intersection / union
130+
131+
return iou, intersection, union
132+
133+
gt_box1 = [320, 220, 680, 900]
134+
pred_box1 = [500, 320, 550, 700]
135+
136+
gt_box2 = [645, 130, 310, 320]
137+
pred_box2 = [500, 60, 310, 320]
138+
139+
iou1 = intersection_over_union(gt_box1, pred_box1)
140+
print("\nIOU1 ::", iou1)
141+
142+
iou2 = intersection_over_union(gt_box2, pred_box2)
143+
print("\nIOU2 ::", iou2)

0 commit comments

Comments
 (0)