1
+ import numpy as np
2
+ from sklearn .metrics import confusion_matrix , precision_score , recall_score
3
+ import matplotlib .pyplot as plt
4
+ import matplotlib .patches as ptch
5
+
6
+ # Appendix A - working with single threshold
7
+ pred_scores = [0.7 , 0.3 , 0.5 , 0.6 , 0.55 , 0.9 , 0.4 , 0.2 , 0.4 , 0.3 ]
8
+ y_true = ["positive" , "negative" , "negative" , "positive" , "positive" , "positive" , "negative" , "positive" , "negative" , "positive" ]
9
+
10
+ # To convert the scores into a class label, a threshold is used.
11
+ # When the score is equal to or above the threshold, the sample is classified as one class.
12
+ # Otherwise, it is classified as the other class.
13
+ # Suppose a sample is Positive if its score is above or equal to the threshold. Otherwise, it is Negative.
14
+ # The next block of code converts the scores into class labels with a threshold of 0.5.
15
+
16
+ threshold = 0.5
17
+
18
+ y_pred = ["positive" if score >= threshold else "negative" for score in pred_scores ]
19
+ print (y_pred )
20
+
21
+ r = np .flip (confusion_matrix (y_true , y_pred ))
22
+ print ("\n # Confusion Matrix (From Left to Right & Top to Bottom: \n True Positive, False Negative, \n False Positive, True Negative)" )
23
+ print (r )
24
+
25
+ # Remember that the higher the precision, the more confident the model is when it classifies a sample as Positive.
26
+ # Higher the recall, the more positive samples the model correctly classified as Positive.
27
+
28
+ precision = precision_score (y_true = y_true , y_pred = y_pred , pos_label = "positive" )
29
+ print ("\n # Precision = 4/(4+1)" )
30
+ print (precision )
31
+
32
+ recall = recall_score (y_true = y_true , y_pred = y_pred , pos_label = "positive" )
33
+ print ("\n # Recall = 4/(4+2)" )
34
+ print (recall )
35
+
36
+ # Appendix B - working with multiple thresholds
37
+ y_true = ["positive" , "negative" , "negative" , "positive" , "positive" , "positive" , "negative" , "positive" , "negative" , "positive" , "positive" , "positive" , "positive" , "negative" , "negative" , "negative" ]
38
+
39
+ pred_scores = [0.7 , 0.3 , 0.5 , 0.6 , 0.55 , 0.9 , 0.4 , 0.2 , 0.4 , 0.3 , 0.7 , 0.5 , 0.8 , 0.2 , 0.3 , 0.35 ]
40
+
41
+ thresholds = np .arange (start = 0.2 , stop = 0.7 , step = 0.05 )
42
+
43
+ # Due to the importance of both precision and recall, there is a precision-recall curve that shows
44
+ # the tradeoff between the precision and recall values for different thresholds.
45
+ # This curve helps to select the best threshold to maximize both metrics
46
+
47
+ def precision_recall_curve (y_true , pred_scores , thresholds ):
48
+ precisions = []
49
+ recalls = []
50
+ f1_scores = []
51
+
52
+ for threshold in thresholds :
53
+ y_pred = ["positive" if score >= threshold else "negative" for score in pred_scores ]
54
+
55
+ precision = precision_score (y_true = y_true , y_pred = y_pred , pos_label = "positive" )
56
+ recall = recall_score (y_true = y_true , y_pred = y_pred , pos_label = "positive" )
57
+ f1_score = (2 * precision * recall ) / (precision + recall )
58
+
59
+ precisions .append (precision )
60
+ recalls .append (recall )
61
+ f1_scores .append (f1_score )
62
+
63
+ return precisions , recalls , f1_scores
64
+
65
+ precisions , recalls , f1_scores = precision_recall_curve (y_true = y_true ,
66
+ pred_scores = pred_scores ,
67
+ thresholds = thresholds )
68
+
69
+ print ("\n Recall:: Precision :: F1-Score" ,)
70
+ for p , r , f in zip (precisions , recalls , f1_scores ):
71
+ print (round (r ,4 ),"\t ::\t " ,round (p ,4 ),"\t ::\t " ,round (f ,4 ))
72
+
73
+ # np.max() returns the max. value in the array
74
+ # np.argmax() will return the index of the value found by np.max()
75
+
76
+ print ('Best F1-Score: ' , np .max (f1_scores ))
77
+ idx_best_f1 = np .argmax (f1_scores )
78
+ print ('\n Best threshold: ' , thresholds [idx_best_f1 ])
79
+ print ('Index of threshold: ' , idx_best_f1 )
80
+
81
+ # Can disable comment to display the plot
82
+
83
+ # plt.plot(recalls, precisions, linewidth=4, color="red")
84
+ # plt.scatter(recalls[idx_best_f1], precisions[idx_best_f1], zorder=1, linewidth=6)
85
+ # plt.xlabel("Recall", fontsize=12, fontweight='bold')
86
+ # plt.ylabel("Precision", fontsize=12, fontweight='bold')
87
+ # plt.title("Precision-Recall Curve", fontsize=15, fontweight="bold")
88
+ # plt.show()
89
+
90
+ # Appendix C - average precision (AP)
91
+ precisions , recalls , f1_scores = precision_recall_curve (y_true = y_true ,
92
+ pred_scores = pred_scores ,
93
+ thresholds = thresholds )
94
+
95
+ precisions .append (1 )
96
+ recalls .append (0 )
97
+
98
+ precisions = np .array (precisions )
99
+ recalls = np .array (recalls )
100
+
101
+ print ('\n Recall ::' ,recalls )
102
+ print ('Precision ::' ,precisions )
103
+
104
+ AP = np .sum ((recalls [:- 1 ] - recalls [1 :]) * precisions [:- 1 ])
105
+ print ("\n AP --" , AP )
106
+
107
+ # Appendix D - Intersection over Union
108
+
109
+ # gt_box -- ground-truth bounding box
110
+ # pred_box -- prediction bounding box
111
+ def intersection_over_union (gt_box , pred_box ):
112
+
113
+ inter_box_top_left = [max (gt_box [0 ], pred_box [0 ]), max (gt_box [1 ], pred_box [1 ])]
114
+
115
+ print ("\n inter_box_top_left:" , inter_box_top_left )
116
+ print ("gt_box:" , gt_box )
117
+ print ("pred_box:" , pred_box )
118
+ inter_box_bottom_right = [min (gt_box [0 ]+ gt_box [2 ], pred_box [0 ]+ pred_box [2 ]), min (gt_box [1 ]+ gt_box [3 ], pred_box [1 ]+ pred_box [3 ])]
119
+ print ("inter_box_bottom_right:" , inter_box_bottom_right )
120
+
121
+ inter_box_w = inter_box_bottom_right [0 ] - inter_box_top_left [0 ]
122
+ print ("inter_box_w:" , inter_box_w )
123
+ inter_box_h = inter_box_bottom_right [1 ] - inter_box_top_left [1 ]
124
+ print ("inter_box_h:" , inter_box_h )
125
+
126
+ intersection = inter_box_w * inter_box_h
127
+ union = gt_box [2 ] * gt_box [3 ] + pred_box [2 ] * pred_box [3 ] - intersection
128
+
129
+ iou = intersection / union
130
+
131
+ return iou , intersection , union
132
+
133
+ gt_box1 = [320 , 220 , 680 , 900 ]
134
+ pred_box1 = [500 , 320 , 550 , 700 ]
135
+
136
+ gt_box2 = [645 , 130 , 310 , 320 ]
137
+ pred_box2 = [500 , 60 , 310 , 320 ]
138
+
139
+ iou1 = intersection_over_union (gt_box1 , pred_box1 )
140
+ print ("\n IOU1 ::" , iou1 )
141
+
142
+ iou2 = intersection_over_union (gt_box2 , pred_box2 )
143
+ print ("\n IOU2 ::" , iou2 )
0 commit comments