Skip to content

Commit

Permalink
Merge pull request #28 from joshuawe/pipeline
Browse files Browse the repository at this point in the history
Add Pipeline - get metrics with one line of code
  • Loading branch information
joshuawe committed Dec 27, 2023
2 parents a1caedd + 37d51a3 commit 7bc618b
Show file tree
Hide file tree
Showing 12 changed files with 878 additions and 193 deletions.
8 changes: 8 additions & 0 deletions .vscode/settings.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
{
"python.testing.pytestArgs": [
"tests"
],
"python.testing.unittestEnabled": false,
"python.testing.pytestEnabled": true,
"python.analysis.typeCheckingMode": "basic"
}
42 changes: 18 additions & 24 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ Furthermore, this library presents other useful visualizations, such as **compar
- Classification Report
- Confusion Matrix
- ROC curve (AUROC)
- y_prob histogram
- y_score histogram

- *multi-class classifier*

Expand All @@ -61,7 +61,7 @@ Furthermore, this library presents other useful visualizations, such as **compar

| <img src="https://github.com/joshuawe/plots_and_graphs/blob/main/images/roc_curve_bootstrap.png?raw=true" width="300" alt="Your Image"> | <img src="https://github.com/joshuawe/plots_and_graphs/blob/main/images/pr_curve.png?raw=true" width="300" alt="Your Image"> | <img src="https://github.com/joshuawe/plots_and_graphs/blob/main/images/y_prob_histogram.png?raw=true" width="300" alt="Your Image"> |
|:--------------------------------------------------:|:----------------------------------------------------------:|:-------------------------------------------------:|
| ROC Curve (AUROC) with bootstrapping | Precision-Recall Curve | y_prob histogram |
| ROC Curve (AUROC) with bootstrapping | Precision-Recall Curve | y_score histogram |


| <img src="https://github.com/joshuawe/plots_and_graphs/blob/main/images/multiclass/histogram_4_classes.png?raw=true" width="300" alt="Your Image"> | <img src="https://github.com/joshuawe/plots_and_graphs/blob/main/images/multiclass/roc_curves_multiclass.png?raw=true" width="300" alt=""> | <img src="" width="300" height="300" alt=""> |
Expand Down Expand Up @@ -95,34 +95,28 @@ pip install -e .

# Usage

Example usage of results from a binary classifier for a calibration curve.
Get all classification metrics with **ONE** line of code. Here, for a binary classifier:

```python
import matplotlib.pyplot as plt
import numpy as np
import plotsandgraphs as pandg
# ...
pandg.pipeline.binary_classifier(y_true, y_score)
```

# create some predictions of a hypothetical binary classifier
n_samples = 1000
y_true = np.random.choice([0,1], n_samples, p=[0.4, 0.6]) # the true class labels 0 or 1, with class imbalance 40:60

y_prob = np.zeros(y_true.shape) # a model's probability of class 1 predictions
y_prob[y_true==1] = np.random.beta(1, 0.6, y_prob[y_true==1].shape)
y_prob[y_true==0] = np.random.beta(0.5, 1, y_prob[y_true==0].shape)

# show prob distribution
fig_hist = pandg.binary_classifier.plot_y_prob_histogram(y_prob, y_true, save_fig_path=None)

# create calibration curve
fig_auroc = pandg.binary_classifier.plot_calibration_curve(y_prob, y_true, save_fig_path=None)
Or with some more configs:
```Python
configs = {
'roc': {'n_bootstraps': 10000},
'pr': {'figsize': (8,10)}
}
pandg.pipeline.binary_classifier(y_true, y_score, save_fig_path='results/metrics', file_type='png', plot_kwargs=configs)
```

For multiclass classification:

# --- OPTIONAL: Customize figure ---
# get axis of figure and change title
axes = fig_auroc.get_axes()
ax0 = axes[0]
ax0.set_title('New Title for Calibration Plot')
fig_auroc.show()
```Python
# with multiclass data y_true (one-hot encoded) and y_score
pandg.pipeline.multiclass_classifier(y_true, y_score)
```

# Requirements
Expand Down
331 changes: 331 additions & 0 deletions notebooks/pipeline.ipynb

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions plotsandgraphs/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from . import binary_classifier
from . import compare_distributions
from . import multiclass_classifier
from . import pipeline
63 changes: 44 additions & 19 deletions plotsandgraphs/binary_classifier.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from pathlib import Path
from typing import Optional
from typing import Optional, Any, Union
import matplotlib.pyplot as plt
from matplotlib.colors import to_rgba
from matplotlib.figure import Figure
Expand Down Expand Up @@ -45,7 +45,9 @@ def plot_accuracy(y_true, y_pred, name="", save_fig_path=None) -> Figure:
return fig


def plot_confusion_matrix(y_true: np.ndarray, y_pred: np.ndarray, save_fig_path=None) -> Figure:
def plot_confusion_matrix(
y_true: np.ndarray, y_pred: np.ndarray, save_fig_path=None
) -> Figure:
import matplotlib.colors as colors

# Compute the confusion matrix
Expand All @@ -54,7 +56,9 @@ def plot_confusion_matrix(y_true: np.ndarray, y_pred: np.ndarray, save_fig_path=
cm = cm.astype("float") / cm.sum(axis=1)[:, np.newaxis]

# Create the ConfusionMatrixDisplay instance and plot it
cmd = ConfusionMatrixDisplay(cm, display_labels=["class 0\nnegative", "class 1\npositive"])
cmd = ConfusionMatrixDisplay(
cm, display_labels=["class 0\nnegative", "class 1\npositive"]
)
fig, ax = plt.subplots(figsize=(4, 4))
cmd.plot(
cmap="YlOrRd",
Expand Down Expand Up @@ -144,6 +148,11 @@ def plot_classification_report(
ax : Matplotlib.pyplot.Axe
Axe object from matplotlib
"""
print(
"Warning: plot_classification_report is not experiencing a bug and is, hence, currently skipped."
)
return

import matplotlib as mpl
import matplotlib.colors as colors
import seaborn as sns
Expand All @@ -153,7 +162,11 @@ def plot_classification_report(
cmap = "YlOrRd"

clf_report = classification_report(y_true, y_pred, output_dict=True, **kwargs)
keys_to_plot = [key for key in clf_report.keys() if key not in ("accuracy", "macro avg", "weighted avg")]
keys_to_plot = [
key
for key in clf_report.keys()
if key not in ("accuracy", "macro avg", "weighted avg")
]
df = pd.DataFrame(clf_report, columns=keys_to_plot).T
# the following line ensures that dataframe are sorted from the majority classes to the minority classes
df.sort_values(by=["support"], inplace=True)
Expand Down Expand Up @@ -322,7 +335,9 @@ def plot_roc_curve(
auc_upper = np.quantile(bootstrap_aucs, CI_upper)
auc_lower = np.quantile(bootstrap_aucs, CI_lower)
label = f"{confidence_interval:.0%} CI: [{auc_lower:.2f}, {auc_upper:.2f}]"
plt.fill_between(base_fpr, tprs_lower, tprs_upper, alpha=0.3, label=label, zorder=2)
plt.fill_between(
base_fpr, tprs_lower, tprs_upper, alpha=0.3, label=label, zorder=2
)

if highlight_roc_area is True:
print(
Expand Down Expand Up @@ -354,16 +369,18 @@ def plot_roc_curve(
return fig


def plot_calibration_curve(y_prob: np.ndarray, y_true: np.ndarray, n_bins=10, save_fig_path=None) -> Figure:
def plot_calibration_curve(
y_true: np.ndarray, y_score: np.ndarray, n_bins=10, save_fig_path=None
) -> Figure:
"""
Creates calibration plot for a binary classifier and calculates the ECE.
Parameters
----------
y_prob : np.ndarray
The output probabilities of the classifier. Between 0 and 1.
y_true : np.ndarray
The actual labels of the data. Either 0 or 1.
y_score : np.ndarray
The output probabilities of the classifier. Between 0 and 1.
n_bins : int
The number of bins to use for the calibration curve.
save_fig_path : str, optional
Expand All @@ -376,13 +393,15 @@ def plot_calibration_curve(y_prob: np.ndarray, y_true: np.ndarray, n_bins=10, sa
ece : float
The expected calibration error.
"""
prob_true, prob_pred = calibration_curve(y_true, y_prob, n_bins=n_bins, strategy="uniform")
prob_true, prob_pred = calibration_curve(
y_true, y_score, n_bins=n_bins, strategy="uniform"
)

# Find the number of samples in each bin
bin_counts = np.histogram(y_prob, bins=n_bins, range=(0, 1))[0]
bin_counts = np.histogram(y_score, bins=n_bins, range=(0, 1))[0]

# Calculate the weighted absolute difference (ECE)
ece = np.abs(prob_pred - prob_true) * (bin_counts / len(y_prob))
ece = np.abs(prob_pred - prob_true) * (bin_counts / len(y_score))
ece = ece.sum().round(2)

fig = plt.figure(figsize=(5, 5))
Expand Down Expand Up @@ -449,16 +468,18 @@ def plot_calibration_curve(y_prob: np.ndarray, y_true: np.ndarray, n_bins=10, sa
return fig


def plot_y_prob_histogram(y_prob: np.ndarray, y_true: Optional[np.ndarray] = None, save_fig_path=None) -> Figure:
def plot_y_score_histogram(
y_true: Union[np.ndarray[Any, Any], None], y_score: np.ndarray[Any, Any], save_fig_path=None
) -> Figure:
"""
Provides a histogram for the predicted probabilities of a binary classifier. If ```y_true``` is provided, it divides the ```y_prob``` values into the two classes and plots them jointly into the same plot with different colors.
Provides a histogram for the predicted probabilities of a binary classifier. If ```y_true``` is provided, it divides the ```y_score``` values into the two classes and plots them jointly into the same plot with different colors.
Parameters
----------
y_prob : np.ndarray
The output probabilities of the classifier. Between 0 and 1.
y_true : Optional[np.ndarray], optional
The true class labels, by default None
y_score : np.ndarray
The output probabilities of the classifier. Between 0 and 1.
save_fig_path : _type_, optional
Path where to save figure, by default None
Expand All @@ -471,13 +492,15 @@ def plot_y_prob_histogram(y_prob: np.ndarray, y_true: Optional[np.ndarray] = Non
ax = fig.add_subplot(111)

if y_true is None:
ax.hist(y_prob, bins=10, alpha=0.9, edgecolor="midnightblue", linewidth=2, rwidth=1)
ax.hist(
y_score, bins=10, alpha=0.9, edgecolor="midnightblue", linewidth=2, rwidth=1
)
# same histogram as above, but with border lines
# ax.hist(y_prob, bins=10, alpha=0.5, edgecolor='black', linewidth=1.2)
else:
alpha = 0.6
ax.hist(
y_prob[y_true == 0],
y_score[y_true == 0],
bins=10,
alpha=alpha,
edgecolor="midnightblue",
Expand All @@ -486,7 +509,7 @@ def plot_y_prob_histogram(y_prob: np.ndarray, y_true: Optional[np.ndarray] = Non
label="$\\hat{y} = 0$",
)
ax.hist(
y_prob[y_true == 1],
y_score[y_true == 1],
bins=10,
alpha=alpha,
edgecolor="darkred",
Expand Down Expand Up @@ -577,6 +600,8 @@ def plot_pr_curve(

# Save the figure if save_fig_path is specified
if save_fig_path:
plt.savefig(save_fig_path, bbox_inches="tight")
path = Path(save_fig_path)
path.parent.mkdir(parents=True, exist_ok=True)
fig.savefig(save_fig_path, bbox_inches="tight")

return fig
42 changes: 25 additions & 17 deletions plotsandgraphs/multiclass_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,12 @@
from sklearn.utils import resample
from tqdm import tqdm

from plotsandgraphs.utils import bootstrap, set_black_title_boxes, scale_ax_bbox, get_cmap
from plotsandgraphs.utils import (
bootstrap,
set_black_title_boxes,
scale_ax_bbox,
get_cmap,
)


def plot_roc_curve(
Expand All @@ -31,8 +36,8 @@ def plot_roc_curve(
n_bootstraps: int = 1,
figsize: Optional[Tuple[float, float]] = None,
class_labels: Optional[List[str]] = None,
split_plots: bool = True,
save_fig_path:Optional[Union[str, Tuple[str, str]]] = None,
split_plots: bool = False,
save_fig_path: Optional[Union[str, Tuple[str, str]]] = None,
) -> Tuple[Figure, Union[Figure, None]]:
"""
Creates two plots.
Expand All @@ -59,7 +64,7 @@ def plot_roc_curve(
class_labels : List[str], optional
The labels of the classes. By default None.
split_plots : bool, optional
Whether to split the plots into two separate figures. By default True.
Whether to split the plots into two separate figures. By default False.
save_fig_path : Optional[Union[str, Tuple[str, str]]], optional
Path to folder where the figure(s) should be saved. If None then plot is not saved, by default None. If `split_plots` is False, then a single str is required. If True, then a tuple of strings (Pos 1 Roc curves comparison, Pos 2 AUROCs comparison). E.g. `save_fig_path=('figures/roc_curves.png', 'figures/aurocs_comparison.png')`.
Expand Down Expand Up @@ -190,9 +195,6 @@ def roc_metric_function(y_true, y_score):

# create the subplot tiles (and black boxes)
set_black_title_boxes(axes.flat[:num_classes], class_labels)




# ---------- AUROC overview plot comparing classes ----------
# Make an AUROC overview plot comparing the aurocs per class and combined
Expand Down Expand Up @@ -260,18 +262,20 @@ def auroc_metric_function(y_true, y_score, average, multi_class):
fig_aurocs.savefig(path, bbox_inches="tight")
# save roc curves plot
if save_fig_path is not None:
path = save_fig_path[0] if split_plots is True else save_fig_path
path = Path(path)
path = save_fig_path[0] if split_plots is True else save_fig_path # type: ignore
path = Path(path) # type: ignore
path.parent.mkdir(parents=True, exist_ok=True)
fig.savefig(path, bbox_inches="tight")
return fig, fig_aurocs


def plot_y_prob_histogram(
y_true: np.ndarray, y_score: Optional[np.ndarray] = None, save_fig_path: Optional[str]=None
def plot_y_score_histogram(
y_true: np.ndarray,
y_score: Optional[np.ndarray] = None,
save_fig_path: Optional[str] = None,
) -> Figure:
"""
Histogram plot that is intended to show the distribution of the predicted probabilities for different classes, where the the different classes (y_true==0 and y_true==1) are plotted in different colors.
Histogram plot that is intended to show the distribution of the predicted probabilities for different classes, where the the different classes (y_true==0 and y_true==1) are plotted in different colors.
Limitations: Does not work for samples, that can be part of multiple classes (e.g. multilabel classification).
Parameters
Expand All @@ -288,15 +292,19 @@ def plot_y_prob_histogram(
Figure
The figure of the histogram plot.
"""

num_classes = y_true.shape[-1]
class_labels = [f"Class {i}" for i in range(num_classes)]

cmap, colors = get_cmap("roma", n_colors=2) # 2 colors for y==0 and y==1 per class

# Aiming for a square plot
plot_cols = np.ceil(np.sqrt(num_classes)).astype(int) # Number of plots in a row # noqa
plot_rows = np.ceil(num_classes / plot_cols).astype(int) # Number of plots in a column # noqa
plot_cols = np.ceil(np.sqrt(num_classes)).astype(
int
) # Number of plots in a row # noqa
plot_rows = np.ceil(num_classes / plot_cols).astype(
int
) # Number of plots in a column # noqa
fig, axes = plt.subplots(
nrows=plot_rows,
ncols=plot_cols,
Expand Down

0 comments on commit 7bc618b

Please sign in to comment.