@@ -217,4 +217,37 @@ def heatmap(df, n:int,target:str,columns:None):
217
217
218
218
plt .figure (figsize = (20 ,10 ))
219
219
hm = sns .heatmap (cm , cbar = True , annot = True , cmap = 'YlOrBr' , fmt = '.2f' , yticklabels = cols .values , xticklabels = cols .values )
220
- return hm
220
+ return hm
221
+
222
+ def plot_roc_curve (y_true , y_pred , pos_label = 1 , figsize = (8 , 8 )):
223
+ '''
224
+ Function to plot the ROC curve of a binary classifier
225
+
226
+ Parameters:
227
+
228
+ y_true: true labels
229
+ y_pred: model predictions
230
+ pos_label: positive label (default: 1)
231
+ figsize: figure size (default: (8, 8))
232
+
233
+ Returns:
234
+ Lineplot of the ROC curve
235
+
236
+ '''
237
+ # Compute the false positive rate, true positive rate, and thresholds
238
+ fpr , tpr , thresholds = roc_curve (y_true , y_pred , pos_label = pos_label )
239
+
240
+ # Compute the area under the curve (AUC)
241
+ roc_auc = auc (fpr , tpr )
242
+
243
+ # Create the ROC curve plot
244
+ plt .figure (figsize = figsize )
245
+ plt .plot (fpr , tpr , color = 'darkorange' , lw = 2 , label = 'ROC curve (AUC = %0.2f)' % roc_auc )
246
+ plt .plot ([0 , 1 ], [0 , 1 ], color = 'navy' , lw = 2 , linestyle = '--' )
247
+ plt .xlim ([0.0 , 1.0 ])
248
+ plt .ylim ([0.0 , 1.05 ])
249
+ plt .xlabel ('False Positive Rate' )
250
+ plt .ylabel ('True Positive Rate' )
251
+ plt .title ('Receiver operating characteristic (ROC) curve' )
252
+ plt .legend (loc = "lower right" )
253
+ plt .show ()
0 commit comments