Skip to content

Commit

Permalink
update readme & linting
Browse files Browse the repository at this point in the history
  • Loading branch information
joshuawe committed Dec 27, 2023
1 parent 81d89c5 commit 160876a
Show file tree
Hide file tree
Showing 9 changed files with 345 additions and 234 deletions.
42 changes: 18 additions & 24 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ Furthermore, this library presents other useful visualizations, such as **compar
- Classification Report
- Confusion Matrix
- ROC curve (AUROC)
- y_prob histogram
- y_score histogram

- *multi-class classifier*

Expand All @@ -61,7 +61,7 @@ Furthermore, this library presents other useful visualizations, such as **compar

| <img src="https://github.com/joshuawe/plots_and_graphs/blob/main/images/roc_curve_bootstrap.png?raw=true" width="300" alt="Your Image"> | <img src="https://github.com/joshuawe/plots_and_graphs/blob/main/images/pr_curve.png?raw=true" width="300" alt="Your Image"> | <img src="https://github.com/joshuawe/plots_and_graphs/blob/main/images/y_prob_histogram.png?raw=true" width="300" alt="Your Image"> |
|:--------------------------------------------------:|:----------------------------------------------------------:|:-------------------------------------------------:|
| ROC Curve (AUROC) with bootstrapping | Precision-Recall Curve | y_prob histogram |
| ROC Curve (AUROC) with bootstrapping | Precision-Recall Curve | y_score histogram |


| <img src="https://github.com/joshuawe/plots_and_graphs/blob/main/images/multiclass/histogram_4_classes.png?raw=true" width="300" alt="Your Image"> | <img src="https://github.com/joshuawe/plots_and_graphs/blob/main/images/multiclass/roc_curves_multiclass.png?raw=true" width="300" alt=""> | <img src="data:image/gif;base64,R0lGODlhAQABAIAAAAAAAP///yH5BAEAAAAALAAAAAABAAEAAAIBRAA7" width="300" height="300" alt=""> |
Expand Down Expand Up @@ -95,34 +95,28 @@ pip install -e .

# Usage

Example usage of results from a binary classifier for a calibration curve.
Get all classification metrics with **ONE** line of code. Here, for a binary classifier:

```python
import matplotlib.pyplot as plt
import numpy as np
import plotsandgraphs as pandg
# ...
pandg.pipeline.binary_classifier(y_true, y_score)
```

# create some predictions of a hypothetical binary classifier
n_samples = 1000
y_true = np.random.choice([0,1], n_samples, p=[0.4, 0.6]) # the true class labels 0 or 1, with class imbalance 40:60

y_prob = np.zeros(y_true.shape) # a model's probability of class 1 predictions
y_prob[y_true==1] = np.random.beta(1, 0.6, y_prob[y_true==1].shape)
y_prob[y_true==0] = np.random.beta(0.5, 1, y_prob[y_true==0].shape)

# show prob distribution
fig_hist = pandg.binary_classifier.plot_y_prob_histogram(y_prob, y_true, save_fig_path=None)

# create calibration curve
fig_auroc = pandg.binary_classifier.plot_calibration_curve(y_prob, y_true, save_fig_path=None)
Or with some more configs:
```Python
configs = {
'roc': {'n_bootstraps': 10000},
'pr': {'figsize': (8,10)}
}
pandg.pipeline.binary_classifier(y_true, y_score, save_fig_path='results/metrics', file_type='png', plot_kwargs=configs)
```

For multiclass classification:

# --- OPTIONAL: Customize figure ---
# get axis of figure and change title
axes = fig_auroc.get_axes()
ax0 = axes[0]
ax0.set_title('New Title for Calibration Plot')
fig_auroc.show()
```Python
# with multiclass data y_true (one-hot encoded) and y_score
pandg.pipeline.multiclass_classifier(y_true, y_score)
```

# Requirements
Expand Down
40 changes: 30 additions & 10 deletions plotsandgraphs/binary_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,9 @@ def plot_accuracy(y_true, y_pred, name="", save_fig_path=None) -> Figure:
return fig


def plot_confusion_matrix(y_true: np.ndarray, y_pred: np.ndarray, save_fig_path=None) -> Figure:
def plot_confusion_matrix(
y_true: np.ndarray, y_pred: np.ndarray, save_fig_path=None
) -> Figure:
import matplotlib.colors as colors

# Compute the confusion matrix
Expand All @@ -54,7 +56,9 @@ def plot_confusion_matrix(y_true: np.ndarray, y_pred: np.ndarray, save_fig_path=
cm = cm.astype("float") / cm.sum(axis=1)[:, np.newaxis]

# Create the ConfusionMatrixDisplay instance and plot it
cmd = ConfusionMatrixDisplay(cm, display_labels=["class 0\nnegative", "class 1\npositive"])
cmd = ConfusionMatrixDisplay(
cm, display_labels=["class 0\nnegative", "class 1\npositive"]
)
fig, ax = plt.subplots(figsize=(4, 4))
cmd.plot(
cmap="YlOrRd",
Expand Down Expand Up @@ -144,8 +148,10 @@ def plot_classification_report(
ax : Matplotlib.pyplot.Axe
Axe object from matplotlib
"""
print("Warning: plot_classification_report is not experiencing a bug and is, hence, currently skipped.")
return
print(
"Warning: plot_classification_report is not experiencing a bug and is, hence, currently skipped."
)
return

import matplotlib as mpl
import matplotlib.colors as colors
Expand All @@ -156,7 +162,11 @@ def plot_classification_report(
cmap = "YlOrRd"

clf_report = classification_report(y_true, y_pred, output_dict=True, **kwargs)
keys_to_plot = [key for key in clf_report.keys() if key not in ("accuracy", "macro avg", "weighted avg")]
keys_to_plot = [
key
for key in clf_report.keys()
if key not in ("accuracy", "macro avg", "weighted avg")
]
df = pd.DataFrame(clf_report, columns=keys_to_plot).T
# the following line ensures that dataframe are sorted from the majority classes to the minority classes
df.sort_values(by=["support"], inplace=True)
Expand Down Expand Up @@ -325,7 +335,9 @@ def plot_roc_curve(
auc_upper = np.quantile(bootstrap_aucs, CI_upper)
auc_lower = np.quantile(bootstrap_aucs, CI_lower)
label = f"{confidence_interval:.0%} CI: [{auc_lower:.2f}, {auc_upper:.2f}]"
plt.fill_between(base_fpr, tprs_lower, tprs_upper, alpha=0.3, label=label, zorder=2)
plt.fill_between(
base_fpr, tprs_lower, tprs_upper, alpha=0.3, label=label, zorder=2
)

if highlight_roc_area is True:
print(
Expand Down Expand Up @@ -357,7 +369,9 @@ def plot_roc_curve(
return fig


def plot_calibration_curve(y_true: np.ndarray, y_score: np.ndarray, n_bins=10, save_fig_path=None) -> Figure:
def plot_calibration_curve(
y_true: np.ndarray, y_score: np.ndarray, n_bins=10, save_fig_path=None
) -> Figure:
"""
Creates calibration plot for a binary classifier and calculates the ECE.
Expand All @@ -379,7 +393,9 @@ def plot_calibration_curve(y_true: np.ndarray, y_score: np.ndarray, n_bins=10, s
ece : float
The expected calibration error.
"""
prob_true, prob_pred = calibration_curve(y_true, y_score, n_bins=n_bins, strategy="uniform")
prob_true, prob_pred = calibration_curve(
y_true, y_score, n_bins=n_bins, strategy="uniform"
)

# Find the number of samples in each bin
bin_counts = np.histogram(y_score, bins=n_bins, range=(0, 1))[0]
Expand Down Expand Up @@ -452,7 +468,9 @@ def plot_calibration_curve(y_true: np.ndarray, y_score: np.ndarray, n_bins=10, s
return fig


def plot_y_score_histogram(y_true: Optional[np.ndarray], y_score: np.ndarray = None, save_fig_path=None) -> Figure:
def plot_y_score_histogram(
y_true: Optional[np.ndarray], y_score: np.ndarray = None, save_fig_path=None
) -> Figure:
"""
Provides a histogram for the predicted probabilities of a binary classifier. If ```y_true``` is provided, it divides the ```y_score``` values into the two classes and plots them jointly into the same plot with different colors.
Expand All @@ -474,7 +492,9 @@ def plot_y_score_histogram(y_true: Optional[np.ndarray], y_score: np.ndarray = N
ax = fig.add_subplot(111)

if y_true is None:
ax.hist(y_score, bins=10, alpha=0.9, edgecolor="midnightblue", linewidth=2, rwidth=1)
ax.hist(
y_score, bins=10, alpha=0.9, edgecolor="midnightblue", linewidth=2, rwidth=1
)
# same histogram as above, but with border lines
# ax.hist(y_prob, bins=10, alpha=0.5, edgecolor='black', linewidth=1.2)
else:
Expand Down
32 changes: 20 additions & 12 deletions plotsandgraphs/multiclass_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,12 @@
from sklearn.utils import resample
from tqdm import tqdm

from plotsandgraphs.utils import bootstrap, set_black_title_boxes, scale_ax_bbox, get_cmap
from plotsandgraphs.utils import (
bootstrap,
set_black_title_boxes,
scale_ax_bbox,
get_cmap,
)


def plot_roc_curve(
Expand All @@ -32,7 +37,7 @@ def plot_roc_curve(
figsize: Optional[Tuple[float, float]] = None,
class_labels: Optional[List[str]] = None,
split_plots: bool = False,
save_fig_path:Optional[Union[str, Tuple[str, str]]] = None,
save_fig_path: Optional[Union[str, Tuple[str, str]]] = None,
) -> Tuple[Figure, Union[Figure, None]]:
"""
Creates two plots.
Expand Down Expand Up @@ -190,9 +195,6 @@ def roc_metric_function(y_true, y_score):

# create the subplot tiles (and black boxes)
set_black_title_boxes(axes.flat[:num_classes], class_labels)




# ---------- AUROC overview plot comparing classes ----------
# Make an AUROC overview plot comparing the aurocs per class and combined
Expand Down Expand Up @@ -268,10 +270,12 @@ def auroc_metric_function(y_true, y_score, average, multi_class):


def plot_y_score_histogram(
y_true: np.ndarray, y_score: Optional[np.ndarray] = None, save_fig_path: Optional[str]=None
y_true: np.ndarray,
y_score: Optional[np.ndarray] = None,
save_fig_path: Optional[str] = None,
) -> Figure:
"""
Histogram plot that is intended to show the distribution of the predicted probabilities for different classes, where the the different classes (y_true==0 and y_true==1) are plotted in different colors.
Histogram plot that is intended to show the distribution of the predicted probabilities for different classes, where the the different classes (y_true==0 and y_true==1) are plotted in different colors.
Limitations: Does not work for samples, that can be part of multiple classes (e.g. multilabel classification).
Parameters
Expand All @@ -288,15 +292,19 @@ def plot_y_score_histogram(
Figure
The figure of the histogram plot.
"""

num_classes = y_true.shape[-1]
class_labels = [f"Class {i}" for i in range(num_classes)]

cmap, colors = get_cmap("roma", n_colors=2) # 2 colors for y==0 and y==1 per class

# Aiming for a square plot
plot_cols = np.ceil(np.sqrt(num_classes)).astype(int) # Number of plots in a row # noqa
plot_rows = np.ceil(num_classes / plot_cols).astype(int) # Number of plots in a column # noqa
plot_cols = np.ceil(np.sqrt(num_classes)).astype(
int
) # Number of plots in a row # noqa
plot_rows = np.ceil(num_classes / plot_cols).astype(
int
) # Number of plots in a column # noqa
fig, axes = plt.subplots(
nrows=plot_rows,
ncols=plot_cols,
Expand Down
90 changes: 43 additions & 47 deletions plotsandgraphs/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,109 +6,105 @@
from . import multiclass_classifier as mc


FILE_ENDINGS = Literal["pdf", "png", "jpg", "jpeg", "svg"]

FILE_ENDINGS = Literal['pdf', 'png', 'jpg', 'jpeg', 'svg']



def binary_classifier(y_true, y_score, save_fig_path=None, plot_kwargs={}, file_type:FILE_ENDINGS='png'):


def binary_classifier(
y_true, y_score, save_fig_path=None, plot_kwargs={}, file_type: FILE_ENDINGS = "png"
):
# Create new tqdm instance
tqdm_instance = tqdm(total=6, desc='Binary classifier metrics', leave=True)
tqdm_instance = tqdm(total=6, desc="Binary classifier metrics", leave=True)

# Update tqdm instance
tqdm_instance.update()

# 1) Plot ROC curve
roc_kwargs = plot_kwargs.get('roc', {})
save_path = get_file_path(save_fig_path, 'roc_curve', file_type)
roc_kwargs = plot_kwargs.get("roc", {})
save_path = get_file_path(save_fig_path, "roc_curve", file_type)
bc.plot_roc_curve(y_true, y_score, save_fig_path=save_path, **roc_kwargs)
tqdm_instance.update()

# 2) Plot precision-recall curve
pr_kwargs = plot_kwargs.get('pr', {})
save_path = get_file_path(save_fig_path, 'pr_curve', file_type)
pr_kwargs = plot_kwargs.get("pr", {})
save_path = get_file_path(save_fig_path, "pr_curve", file_type)
bc.plot_pr_curve(y_true, y_score, save_fig_path=save_path, **pr_kwargs)
tqdm_instance.update()

# 3) Plot calibration curve
cal_kwargs = plot_kwargs.get('cal', {})
save_path = get_file_path(save_fig_path, 'calibration_curve', file_type)
cal_kwargs = plot_kwargs.get("cal", {})
save_path = get_file_path(save_fig_path, "calibration_curve", file_type)
bc.plot_calibration_curve(y_true, y_score, save_fig_path=save_path, **cal_kwargs)
tqdm_instance.update()

# 3) Plot confusion matrix
cm_kwargs = plot_kwargs.get('cm', {})
save_path = get_file_path(save_fig_path, 'confusion_matrix', file_type)
cm_kwargs = plot_kwargs.get("cm", {})
save_path = get_file_path(save_fig_path, "confusion_matrix", file_type)
bc.plot_confusion_matrix(y_true, y_score, save_fig_path=save_path, **cm_kwargs)
tqdm_instance.update()

# 5) Plot classification report
cr_kwargs = plot_kwargs.get('cr', {})
save_path = get_file_path(save_fig_path, 'classification_report', file_type)
cr_kwargs = plot_kwargs.get("cr", {})
save_path = get_file_path(save_fig_path, "classification_report", file_type)
bc.plot_classification_report(y_true, y_score, save_fig_path=save_path, **cr_kwargs)
tqdm_instance.update()

# 6) Plot y_score histogram
hist_kwargs = plot_kwargs.get('hist', {})
save_path = get_file_path(save_fig_path, 'y_score_histogram', file_type)
hist_kwargs = plot_kwargs.get("hist", {})
save_path = get_file_path(save_fig_path, "y_score_histogram", file_type)
bc.plot_y_score_histogram(y_true, y_score, save_fig_path=save_path, **hist_kwargs)
tqdm_instance.update()

return


return


def multiclass_classifier(y_true, y_score, save_fig_path=None, plot_kwargs={}, file_type:FILE_ENDINGS = 'png'):

def multiclass_classifier(
y_true, y_score, save_fig_path=None, plot_kwargs={}, file_type: FILE_ENDINGS = "png"
):
# Create new tqdm instance
tqdm_instance = tqdm(total=6, desc='Binary classifier metrics', leave=True)
tqdm_instance = tqdm(total=6, desc="Binary classifier metrics", leave=True)

# Update tqdm instance
tqdm_instance.update()

# 1) Plot ROC curve
roc_kwargs = plot_kwargs.get('roc', {})
save_path = get_file_path(save_fig_path, 'roc_curve', '')
roc_kwargs = plot_kwargs.get("roc", {})
save_path = get_file_path(save_fig_path, "roc_curve", "")
mc.plot_roc_curve(y_true, y_score, save_fig_path=save_path, **roc_kwargs)
tqdm_instance.update()

# 2) Plot precision-recall curve
# pr_kwargs = plot_kwargs.get('pr', {})
# mc.plot_pr_curve(y_true, y_score, save_fig_path=save_fig_path, **pr_kwargs)
# tqdm_instance.update()

# 3) Plot calibration curve
# cal_kwargs = plot_kwargs.get('cal', {})
# mc.plot_calibration_curve(y_true, y_score, save_fig_path=save_fig_path, **cal_kwargs)
# tqdm_instance.update()

# 3) Plot confusion matrix
# cm_kwargs = plot_kwargs.get('cm', {})
# mc.plot_confusion_matrix(y_true, y_score, save_fig_path=save_fig_path, **cm_kwargs)
# tqdm_instance.update()

# 5) Plot classification report
# cr_kwargs = plot_kwargs.get('cr', {})
# mc.plot_classification_report(y_true, y_score, save_fig_path=save_fig_path, **cr_kwargs)
# tqdm_instance.update()

# 6) Plot y_score histogram
hist_kwargs = plot_kwargs.get('hist', {})
save_path = get_file_path(save_fig_path, 'y_score_histogram', file_type)
hist_kwargs = plot_kwargs.get("hist", {})
save_path = get_file_path(save_fig_path, "y_score_histogram", file_type)
mc.plot_y_score_histogram(y_true, y_score, save_fig_path=save_path, **hist_kwargs)
tqdm_instance.update()

return

return


def get_file_path(save_fig_path: Union[Path,None, str], name:str, ending:str):
def get_file_path(save_fig_path: Union[Path, None, str], name: str, ending: str):
if save_fig_path is None:
return None
else:
result = Path(save_fig_path) / f"{name}.{ending}"
print(result)
return str(result)
return str(result)

0 comments on commit 160876a

Please sign in to comment.