In [None]:
def plot_roc_pr_ovr(df, classes=('Negative', 'Neutral', 'Positive')):
    """Plot one-vs-rest probability histograms, ROC curves and PR curves.

    For every class ``c`` in ``classes`` one figure column of three panels is
    drawn: the histogram of the predicted probability ``df[c]`` split into
    "Class: c" vs "Rest", the one-vs-rest ROC curve, and the one-vs-rest
    precision-recall curve.

    Parameters
    ----------
    df : pandas.DataFrame
        Must contain a ``'label'`` column with the true class of each row and
        one column per entry of ``classes`` holding the predicted
        probability/score for that class.
    classes : sequence of str
        Class names to plot; the figure gets one column per class.

    Returns
    -------
    dict
        Mapping ``class name -> one-vs-rest ROC AUC``.
    """
    n_classes = len(classes)
    # Same 12x8 figure as before for three classes; widen for more classes.
    plt.figure(figsize=(4 * n_classes, 8))
    bins = [i / 20 for i in range(20)] + [1]
    roc_auc_ovr = {}

    for col, c in enumerate(classes, start=1):
        # Binary targets / scores for the "c vs rest" problem.
        y_true = (df['label'] == c).astype(int).to_numpy()
        y_score = df[c].to_numpy()

        # --- Row 1: probability distribution, class vs rest ---------------
        # Readable hue labels let seaborn build a legend whose colours match
        # the bars. Relabelling with ax.legend([...]) would attach both labels
        # to the first two bar patches, which share one colour.
        class_label = f"Class: {c}"
        df_aux = df[[c]].rename(columns={c: 'Probability'}).reset_index(drop=True)
        df_aux['label'] = [class_label if y == c else "Rest" for y in df['label']]

        ax_hist = plt.subplot(3, n_classes, col)
        sns.histplot(data=df_aux, x='Probability', hue='label',
                     hue_order=[class_label, "Rest"], bins=bins, ax=ax_hist)
        legend = ax_hist.get_legend()
        if legend is not None:
            legend.set_title("")  # the hue column name adds nothing here
        ax_hist.set_title(c)
        ax_hist.set_xlabel(f"P(x = {c})")

        # --- Row 2: ROC curve ---------------------------------------------
        ax_roc = plt.subplot(3, n_classes, n_classes + col)
        fpr, tpr, _ = roc_curve(y_true, y_score)
        roc_auc = roc_auc_score(y_true, y_score)
        roc_auc_ovr[c] = roc_auc
        # Plain Axes.plot: sns.lineplot would average tpr over repeated fpr
        # values (and add a bootstrap CI band), distorting the step curve.
        ax_roc.plot(fpr, tpr)
        ax_roc.plot([0, 1], [0, 1], color='green', linestyle='--')  # chance line
        ax_roc.set_xlim(-0.05, 1.05)
        ax_roc.set_ylim(-0.05, 1.05)
        ax_roc.set_xlabel("False Positive Rate")
        ax_roc.set_ylabel("True Positive Rate")
        ax_roc.set_title("ROC Curve OvR (AUC={:.4f})".format(roc_auc))

        # --- Row 3: precision-recall curve --------------------------------
        ax_pr = plt.subplot(3, n_classes, 2 * n_classes + col)
        precision, recall, _ = precision_recall_curve(y_true, y_score)
        # `average` is irrelevant for binary targets, so it is not passed.
        ap = average_precision_score(y_true, y_score)
        # Step drawing matches the step-wise definition of average precision.
        ax_pr.plot(recall, precision, color='b', drawstyle='steps-post')
        ax_pr.set_title("Precision-Recall Curve OvR (AP={:.4f})".format(ap))
        ax_pr.set_ylabel("Precision")
        ax_pr.set_xlabel("Recall")

    plt.tight_layout()
    return roc_auc_ovr