
Commit a13e1c1

plot_curve_jared
1 parent d9ac3cb commit a13e1c1

File tree

1 file changed (+34, -1 lines)


toolkit/plot.py

Lines changed: 34 additions & 1 deletion
@@ -217,4 +217,37 @@ def heatmap(df, n:int,target:str,columns:None):
 
     plt.figure(figsize=(20,10))
     hm = sns.heatmap(cm, cbar=True, annot=True, cmap='YlOrBr', fmt='.2f', yticklabels=cols.values, xticklabels=cols.values)
-    return hm
+    return hm
+
+def plot_roc_curve(y_true, y_pred, pos_label=1, figsize=(8, 8)):
+    '''
+    Plot the ROC curve of a binary classifier.
+
+    Parameters:
+
+        y_true: true labels
+        y_pred: predicted scores or probabilities for the positive class
+        pos_label: positive label (default: 1)
+        figsize: figure size (default: (8, 8))
+
+    Returns:
+        Line plot of the ROC curve
+
+    '''
+    # Compute the false positive rate, true positive rate, and thresholds
+    fpr, tpr, thresholds = roc_curve(y_true, y_pred, pos_label=pos_label)
+
+    # Compute the area under the curve (AUC)
+    roc_auc = auc(fpr, tpr)
+
+    # Create the ROC curve plot
+    plt.figure(figsize=figsize)
+    plt.plot(fpr, tpr, color='darkorange', lw=2, label='ROC curve (AUC = %0.2f)' % roc_auc)
+    plt.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--')
+    plt.xlim([0.0, 1.0])
+    plt.ylim([0.0, 1.05])
+    plt.xlabel('False Positive Rate')
+    plt.ylabel('True Positive Rate')
+    plt.title('Receiver operating characteristic (ROC) curve')
+    plt.legend(loc="lower right")
+    plt.show()
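
The new helper calls roc_curve and auc, which are not imported within this hunk; they presumably come from sklearn.metrics at the top of toolkit/plot.py, alongside the matplotlib.pyplot (plt) and seaborn (sns) imports already used by heatmap. Below is a minimal usage sketch under those assumptions; the synthetic dataset, the LogisticRegression model, and the import path are illustrative and not part of the commit.

# Usage sketch for the new plot_roc_curve helper (illustrative only).
# Assumes toolkit/plot.py has module-level imports such as:
#     from sklearn.metrics import roc_curve, auc
#     import matplotlib.pyplot as plt
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split

from toolkit.plot import plot_roc_curve  # hypothetical import path, based on the file in this commit

# Synthetic binary classification data (not from the commit)
X, y = make_classification(n_samples=500, n_features=10, random_state=0)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)

# Fit a simple classifier; roc_curve expects scores or probabilities,
# so pass the positive-class probability rather than hard 0/1 predictions
clf = LogisticRegression(max_iter=1000).fit(X_train, y_train)
y_scores = clf.predict_proba(X_test)[:, 1]

# Draws the ROC curve with the AUC shown in the legend
plot_roc_curve(y_test, y_scores, pos_label=1, figsize=(8, 8))

Passing probabilities rather than hard class labels matters because roc_curve sweeps a decision threshold over the scores; with 0/1 predictions the curve collapses to a single operating point.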
