Skip to content

Commit

Permalink
histogram binary classifier: allow seperating class 1 and 0, optionally
Browse files Browse the repository at this point in the history
  • Loading branch information
joshuawe committed Nov 10, 2023
1 parent cadc236 commit b80c009
Showing 1 changed file with 29 additions and 4 deletions.
33 changes: 29 additions & 4 deletions plotsandgraphs/binary_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -465,12 +465,37 @@ def plot_calibration_curve(
return fig, ece


def plot_y_prob_histogram(y_prob: np.ndarray, save_fig_path=None) -> Figure:
def plot_y_prob_histogram(y_prob: np.ndarray, y_true: Optional[np.ndarray]=None, save_fig_path=None) -> Figure:
"""
Provides a histogram for the predicted probabilities of a binary classifier. If ```y_true``` is provided, it divides the ```y_prob``` values into the two classes and plots them jointly into the same plot with different colors.
Parameters
----------
y_prob : np.ndarray
The output probabilities of the classifier. Between 0 and 1.
y_true : Optional[np.ndarray], optional
The true class labels, by default None
save_fig_path : _type_, optional
Path where to save figure, by default None
Returns
-------
Figure
The histrogram as a matplotlib figure
"""
fig = plt.figure(figsize=(5, 5))
ax = fig.add_subplot(111)
ax.hist(y_prob, bins=10, alpha=0.9, edgecolor="midnightblue", linewidth=2, rwidth=1)
# same histogram as above, but with border lines
# ax.hist(y_prob, bins=10, alpha=0.5, edgecolor='black', linewidth=1.2)

if y_true is None:
ax.hist(y_prob, bins=10, alpha=0.9, edgecolor="midnightblue", linewidth=2, rwidth=1)
# same histogram as above, but with border lines
# ax.hist(y_prob, bins=10, alpha=0.5, edgecolor='black', linewidth=1.2)
else:
alpha = 0.6
ax.hist(y_prob[y_true==0], bins=10, alpha=alpha, edgecolor="midnightblue", linewidth=2, rwidth=1, label="$\\hat{y} = 0$")
ax.hist(y_prob[y_true==1], bins=10, alpha=alpha, edgecolor="darkred", linewidth=2, rwidth=1, label="$\\hat{y} = 1$")

plt.legend()
ax.set(xlabel="Predicted probability [-]", ylabel="Count [-]", xlim=(-0.01, 1.0))
ax.set_title("Histogram of predicted probabilities")

Expand Down

0 comments on commit b80c009

Please sign in to comment.