Skip to content

Commit

Permalink
Black box titles for multiclass histograms (#24)
Browse files Browse the repository at this point in the history
* update readme

* add black boxes to histplot, update test
  • Loading branch information
joshuawe committed Dec 17, 2023
1 parent 23c83cc commit de8a750
Show file tree
Hide file tree
Showing 6 changed files with 88 additions and 203 deletions.
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -64,9 +64,9 @@ Furthermore, this library presents other useful visualizations, such as **compar
| ROC Curve (AUROC) with bootstrapping | Precision-Recall Curve | y_prob histogram |


| <img src="https://github.com/joshuawe/plots_and_graphs/blob/main/images/raincloud.png?raw=true" width="300" alt="Your Image"> | <img src="data:image/gif;base64,R0lGODlhAQABAIAAAAAAAP///yH5BAEAAAAALAAAAAABAAEAAAIBRAA7" width="300" height="300" alt=""> | <img src="data:image/gif;base64,R0lGODlhAQABAIAAAAAAAP///yH5BAEAAAAALAAAAAABAAEAAAIBRAA7" width="300" height="300" alt=""> |
| <img src="https://github.com/joshuawe/plots_and_graphs/blob/main/images/multiclass/histogram_4_classes.png?raw=true" width="300" alt="Your Image"> | <img src="https://github.com/joshuawe/plots_and_graphs/blob/main/images/multiclass/roc_curves_multiclass.png?raw=true" width="300" alt=""> | <img src="data:image/gif;base64,R0lGODlhAQABAIAAAAAAAP///yH5BAEAAAAALAAAAAABAAEAAAIBRAA7" width="300" height="300" alt=""> |
|:--------------------------------------------------:|:-------------------------------------------------:| :-------------------------------------------------:|
| Raincloud | | |
| Histogram (y_scores) | ROC curves (AUROC) with bootstrapping | |



Expand Down
Binary file modified images/multiclass/histogram_4_classes.png
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
194 changes: 26 additions & 168 deletions notebooks/multiclass_classification.ipynb

Large diffs are not rendered by default.

45 changes: 17 additions & 28 deletions plotsandgraphs/multiclass_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from sklearn.utils import resample
from tqdm import tqdm

from plotsandgraphs.utils import bootstrap, set_black_title_box, scale_ax_bbox, get_cmap
from plotsandgraphs.utils import bootstrap, set_black_title_boxes, scale_ax_bbox, get_cmap


def plot_roc_curve(
Expand All @@ -32,7 +32,7 @@ def plot_roc_curve(
figsize: Optional[Tuple[float, float]] = None,
class_labels: Optional[List[str]] = None,
split_plots: bool = True,
save_fig_path=Optional[Union[str, Tuple[str, str]]],
save_fig_path:Optional[Union[str, Tuple[str, str]]] = None,
) -> Tuple[Figure, Union[Figure, None]]:
"""
Creates two plots.
Expand Down Expand Up @@ -188,22 +188,11 @@ def roc_metric_function(y_true, y_score):
for i in range(num_classes, len(axes.flat)):
axes.flat[i].axis("off")

# make the subplot tiles (and black boxes)
for i in range(num_classes):
set_black_title_box(axes.flat[i], f"Class {i}")
plt.tight_layout(h_pad=1.5)
# make the subplot tiles (and black boxes)
# First time to get the approx. correct spacing with plt.tight_layout()
# Second time to get the correct width of the black box
# Thank you matplotlib ...
for i in range(num_classes):
set_black_title_box(
axes.flat[i],
f"Class {i}",
set_title_kwargs={
"fontdict": {"fontname": "Arial Black", "fontweight": "bold"}
},
)
# create the subplot tiles (and black boxes)
set_black_title_boxes(axes.flat[:num_classes], class_labels)




# ---------- AUROC overview plot comparing classes ----------
# Make an AUROC overview plot comparing the aurocs per class and combined
Expand Down Expand Up @@ -281,13 +270,12 @@ def auroc_metric_function(y_true, y_score, average, multi_class):
def plot_y_prob_histogram(
y_true: np.ndarray, y_prob: Optional[np.ndarray] = None, save_fig_path=None
) -> Figure:
num_classes = y_true.shape[-1]
class_labels = [f"Class {i}" for i in range(num_classes)]

# Aiming for a square plot
plot_cols = np.ceil(np.sqrt(y_true.shape[-1])).astype(
int
) # Number of plots in a row
plot_rows = np.ceil(y_true.shape[-1] / plot_cols).astype(
int
) # Number of plots in a column
plot_cols = np.ceil(np.sqrt(num_classes)).astype(int) # Number of plots in a row # noqa
plot_rows = np.ceil(num_classes / plot_cols).astype(int) # Number of plots in a column # noqa
fig, axes = plt.subplots(
nrows=plot_rows,
ncols=plot_cols,
Expand All @@ -298,11 +286,11 @@ def plot_y_prob_histogram(
plt.suptitle("Predicted probability histogram")

# Flatten axes if there is only one class, even though this function is designed for multiclasses
if y_true.shape[-1] == 1:
if num_classes == 1:
axes = np.array([axes])

for i, ax in enumerate(axes.flat):
if i >= y_true.shape[-1]:
if i >= num_classes:
ax.axis("off")
continue

Expand All @@ -327,7 +315,7 @@ def plot_y_prob_histogram(
linewidth=2,
rwidth=1,
)
ax.set_title(f"Class {i}")
ax.set_title(class_labels[i])
ax.set_xlim((-0.005, 1.0))
# if subplot in first column
if (i % plot_cols) == 0:
Expand All @@ -342,7 +330,8 @@ def plot_y_prob_histogram(
if i == 0:
ax.legend()

plt.tight_layout()
# create the subplot tiles (and black boxes)
set_black_title_boxes(axes.flat[:num_classes], class_labels)

# save plot
if save_fig_path is not None:
Expand Down
46 changes: 42 additions & 4 deletions plotsandgraphs/utils.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from typing import Optional, List, Callable, Dict, Tuple, Union, TYPE_CHECKING
from typing import Optional, List, Callable, Dict, Tuple, Union, TYPE_CHECKING, Literal
from tqdm import tqdm
from sklearn.utils import resample
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.path import Path
from matplotlib.patches import BoxStyle
from matplotlib.colors import LinearSegmentedColormap
Expand Down Expand Up @@ -99,8 +100,9 @@ def __call__(self, x0, y0, width, height, mutation_size):
closed=True)


def set_black_title_box(ax: "Axes", title=str, backgroundcolor='black', color='white', set_title_kwargs: Dict={}):
def _set_black_title_box(ax: "Axes", title:str, backgroundcolor='black', color='white', title_kwargs: Optional[Dict]=None):
"""
Note: Do not use this function by itself, instead use `set_black_title_boxes()`.
Sets the title of the given axes with a black bounding box.
Note: When using `plt.tight_layout()` the box might not have the correct width. First call `plt.tight_layout()` and then `set_black_title_box()`.
Expand All @@ -111,14 +113,50 @@ def set_black_title_box(ax: "Axes", title=str, backgroundcolor='black', color='w
- color: The color of the title text (default: 'white').
- set_title_kwargs: Keyword arguments to pass to `ax.set_title()`.
"""
if title_kwargs is None:
title_kwargs = {'fontdict': {"fontname": "Arial Black", "fontweight": "bold"}}
BoxStyle._style_list["ext"] = ExtendedTextBox_v2
ax_width = ax.get_window_extent().width
# make title with black bounding box
title = ax.set_title(title, backgroundcolor=backgroundcolor, color=color, **set_title_kwargs)
bb = title.get_bbox_patch() # get bbox from title
title_instance = ax.set_title(title, backgroundcolor=backgroundcolor, color=color, **title_kwargs)
bb = title_instance.get_bbox_patch() # get bbox from title
bb.set_boxstyle("ext", pad=0.1, width=ax_width) # use custom style


def set_black_title_boxes(axes: "np.ndarray[Axes]", titles: List[str], backgroundcolor='black', color='white', title_kwargs: Optional[Dict]=None, tight_layout_kwargs: Dict={}):
"""
Creates black boxes for the subtitles above the given axes with the given titles. The subtitles are centered above the axes.
Parameters
----------
axes : np.ndarray["Axes"]
np.ndarray of matplotlib.axes.Axes objects. (Usually returned by plt.subplots() call)
titles : List[str]
List of titles for the axes. Same length as axes.
backgroundcolor : str, optional
Background color of boxes, by default 'black'
color : str, optional
Font color, by default 'white'
title_kwargs : Dict, optional
kwargs for the `ax.set_title()` call, by default {}
tight_layout_kwargs : Dict, optional
kwargs for the `plt.tight_layout()` call, by default {}
"""

for i, ax in enumerate(axes.flat):
_set_black_title_box(ax, titles[i], backgroundcolor, color, title_kwargs)

plt.tight_layout(**tight_layout_kwargs)

for i, ax in enumerate(axes.flat):
_set_black_title_box(ax, titles[i], backgroundcolor, color, title_kwargs)


return




def scale_ax_bbox(ax: "Axes", factor: float):
# Get the current position of the subplot
box = ax.get_position()
Expand Down
2 changes: 1 addition & 1 deletion tests/test_multiclass_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def test_hist_plot():
random_data_binary_classifier : Tuple[np.ndarray, np.ndarray]
The simulated data.
"""
for num_classes in [2, 3, 4, 5, 10, 16, 25]:
for num_classes in [1, 2, 3, 4, 5, 10, 16, 25]:
y_true, y_prob = random_data_multiclass_classifier(num_classes=num_classes)
print(TEST_RESULTS_PATH)
multiclass.plot_y_prob_histogram(y_true=y_true, y_prob=y_prob, save_fig_path=TEST_RESULTS_PATH / f"histogram_{num_classes}_classes.png")
Expand Down

0 comments on commit de8a750

Please sign in to comment.