Skip to content

Commit

Permalink
Improve testing (#16)
Browse files Browse the repository at this point in the history
* upload tests

* cosmetic adjustments

* cosmetic adjustments (black)

* pylint optimization

* update pylint

* add tests accuracy + roc bootstrapping

* fix tests and added plot

* improve pylint score
  • Loading branch information
joshuawe committed Nov 10, 2023
1 parent 9e91555 commit 4ac598c
Show file tree
Hide file tree
Showing 13 changed files with 208 additions and 66 deletions.
5 changes: 0 additions & 5 deletions .github/workflows/main.yml
Original file line number Diff line number Diff line change
Expand Up @@ -55,11 +55,6 @@ jobs:
- uses: actions/setup-python@v4
with:
python-version: ${{ matrix.python-version }}
- name: Debugging
run: |
ls -la
cat Makefile
make virtualenv
- name: Install project
run: |
make virtualenv
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/pylint.yml
Original file line number Diff line number Diff line change
Expand Up @@ -21,5 +21,5 @@ jobs:
pip install -r requirements.txt
- name: Analysing the code with pylint
run: |
pylint --fail-under=4 $(git ls-files '*.py')
pylint --fail-under=6 $(git ls-files '*.py')
# pylint $(git ls-files '*.py')
6 changes: 3 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -55,9 +55,9 @@ These could include visualizing the results for a binary classifier, for which p
|:--------------------------------------------------:|:----------------------------------------------------------:|:-------------------------------------------------:|
| Calibration Curve | Classification Report | Confusion Matrix |

| <img src="https://github.com/joshuawe/plots_and_graphs/blob/main/images/roc_curve.png?raw=true" width="300" alt="Your Image"> | <img src="https://github.com/joshuawe/plots_and_graphs/blob/main/images/roc_curve_bootstrap.png?raw=true" width="300" alt="Your Image"> | <img src="https://github.com/joshuawe/plots_and_graphs/blob/main/images/y_prob_histogram.png?raw=true" width="300" alt="Your Image"> |
| <img src="https://github.com/joshuawe/plots_and_graphs/blob/main/images/roc_curve_bootstrap.png?raw=true" width="300" alt="Your Image"> | <img src="https://github.com/joshuawe/plots_and_graphs/blob/main/images/pr_curve.png?raw=true" width="300" alt="Your Image"> | <img src="https://github.com/joshuawe/plots_and_graphs/blob/main/images/y_prob_histogram.png?raw=true" width="300" alt="Your Image"> |
|:--------------------------------------------------:|:----------------------------------------------------------:|:-------------------------------------------------:|
| ROC Curve (AUROC) | ROC Curve (AUROC) with bootstrapping | y_prob histogram |
| ROC Curve (AUROC) with bootstrapping | Precision-Recall Curve | y_prob histogram |


| <img src="https://github.com/joshuawe/plots_and_graphs/blob/main/images/raincloud.png?raw=true" width="300" alt="Your Image"> | <img src="data:image/gif;base64,R0lGODlhAQABAIAAAAAAAP///yH5BAEAAAAALAAAAAABAAEAAAIBRAA7" width="300" height="300" alt=""> | <img src="data:image/gif;base64,R0lGODlhAQABAIAAAAAAAP///yH5BAEAAAAALAAAAAABAAEAAAIBRAA7" width="300" height="300" alt=""> |
Expand All @@ -82,7 +82,7 @@ Install the package via pip.
pip install plotsandgraphs
```

Alternativelynstall the package from git.
Alternatively install the package from git.
```bash
git clone https://github.com/joshuawe/plots_and_graphs
cd plots_and_graphs
Expand Down
Binary file added images/pr_curve.png
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified images/y_prob_histogram.png
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
78 changes: 39 additions & 39 deletions plotsandgraphs/binary_classifier.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from pathlib import Path
from typing import Optional
import matplotlib.pyplot as plt
from matplotlib.colors import to_rgba
from matplotlib.figure import Figure
import seaborn as sns
import numpy as np
import pandas as pd
from sklearn.metrics import (
Expand All @@ -15,9 +16,7 @@
)
from sklearn.calibration import calibration_curve
from sklearn.utils import resample
from pathlib import Path
from tqdm import tqdm
from typing import Optional


def plot_accuracy(y_true, y_pred, name="", save_fig_path=None) -> Figure:
Expand All @@ -39,16 +38,14 @@ def plot_accuracy(y_true, y_pred, name="", save_fig_path=None) -> Figure:
plt.title(title)
plt.tight_layout()

if save_fig_path != None:
if save_fig_path is not None:
path = Path(save_fig_path)
path.parent.mkdir(parents=True, exist_ok=True)
fig.savefig(save_fig_path, bbox_inches="tight")
return fig, accuracy
return fig


def plot_confusion_matrix(
y_true: np.ndarray, y_pred: np.ndarray, save_fig_path=None
) -> Figure:
def plot_confusion_matrix(y_true: np.ndarray, y_pred: np.ndarray, save_fig_path=None) -> Figure:
import matplotlib.colors as colors

# Compute the confusion matrix
Expand All @@ -57,16 +54,14 @@ def plot_confusion_matrix(
cm = cm.astype("float") / cm.sum(axis=1)[:, np.newaxis]

# Create the ConfusionMatrixDisplay instance and plot it
cmd = ConfusionMatrixDisplay(
cm, display_labels=["class 0\nnegative", "class 1\npositive"]
)
cmd = ConfusionMatrixDisplay(cm, display_labels=["class 0\nnegative", "class 1\npositive"])
fig, ax = plt.subplots(figsize=(4, 4))
cmd.plot(
cmap="YlOrRd",
values_format="",
colorbar=False,
ax=ax,
text_kw={"visible": False},
# text_kw={"visible": False},
)
cmd.texts_ = []
cmd.text_ = []
Expand Down Expand Up @@ -106,7 +101,7 @@ def plot_confusion_matrix(
cbar.outline.set_visible(False)
plt.tight_layout()

if save_fig_path != None:
if save_fig_path is not None:
path = Path(save_fig_path)
path.parent.mkdir(parents=True, exist_ok=True)
fig.savefig(save_fig_path, bbox_inches="tight")
Expand All @@ -115,7 +110,7 @@ def plot_confusion_matrix(


def plot_classification_report(
y_test: np.ndarray,
y_true: np.ndarray,
y_pred: np.ndarray,
title="Classification Report",
figsize=(8, 4),
Expand Down Expand Up @@ -152,18 +147,13 @@ def plot_classification_report(
import matplotlib as mpl
import matplotlib.colors as colors
import seaborn as sns
import pathlib

fig, ax = plt.subplots(figsize=figsize)

cmap = "YlOrRd"

clf_report = classification_report(y_test, y_pred, output_dict=True, **kwargs)
keys_to_plot = [
key
for key in clf_report.keys()
if key not in ("accuracy", "macro avg", "weighted avg")
]
clf_report = classification_report(y_true, y_pred, output_dict=True, **kwargs)
keys_to_plot = [key for key in clf_report.keys() if key not in ("accuracy", "macro avg", "weighted avg")]
df = pd.DataFrame(clf_report, columns=keys_to_plot).T
# the following line ensures that dataframe are sorted from the majority classes to the minority classes
df.sort_values(by=["support"], inplace=True)
Expand All @@ -174,8 +164,8 @@ def plot_classification_report(
mask[:, cols - 1] = True

bounds = np.linspace(0, 1, 11)
cmap = plt.cm.get_cmap("YlOrRd", len(bounds) + 1)
norm = colors.BoundaryNorm(bounds, cmap.N) # type: ignore[attr-defined]
cmap = plt.cm.get_cmap("YlOrRd", len(bounds) + 1) # type: ignore[assignment]
norm = colors.BoundaryNorm(bounds, cmap.N) # type: ignore[attr-defined]

ax = sns.heatmap(
df,
Expand Down Expand Up @@ -247,7 +237,7 @@ def plot_classification_report(
plt.yticks(rotation=360)
plt.tight_layout()

if save_fig_path != None:
if save_fig_path is not None:
path = Path(save_fig_path)
path.parent.mkdir(parents=True, exist_ok=True)
fig.savefig(save_fig_path, bbox_inches="tight")
Expand Down Expand Up @@ -332,9 +322,7 @@ def plot_roc_curve(
auc_upper = np.quantile(bootstrap_aucs, CI_upper)
auc_lower = np.quantile(bootstrap_aucs, CI_lower)
label = f"{confidence_interval:.0%} CI: [{auc_lower:.2f}, {auc_upper:.2f}]"
plt.fill_between(
base_fpr, tprs_lower, tprs_upper, alpha=0.3, label=label, zorder=2
)
plt.fill_between(base_fpr, tprs_lower, tprs_upper, alpha=0.3, label=label, zorder=2)

if highlight_roc_area is True:
print(
Expand Down Expand Up @@ -366,9 +354,7 @@ def plot_roc_curve(
return fig


def plot_calibration_curve(
y_prob: np.ndarray, y_true: np.ndarray, n_bins=10, save_fig_path=None
) -> Figure:
def plot_calibration_curve(y_prob: np.ndarray, y_true: np.ndarray, n_bins=10, save_fig_path=None) -> Figure:
"""
Creates calibration plot for a binary classifier and calculates the ECE.
Expand All @@ -390,9 +376,7 @@ def plot_calibration_curve(
ece : float
The expected calibration error.
"""
prob_true, prob_pred = calibration_curve(
y_true, y_prob, n_bins=n_bins, strategy="uniform"
)
prob_true, prob_pred = calibration_curve(y_true, y_prob, n_bins=n_bins, strategy="uniform")

# Find the number of samples in each bin
bin_counts = np.histogram(y_prob, bins=n_bins, range=(0, 1))[0]
Expand Down Expand Up @@ -465,7 +449,7 @@ def plot_calibration_curve(
return fig


def plot_y_prob_histogram(y_prob: np.ndarray, y_true: Optional[np.ndarray]=None, save_fig_path=None) -> Figure:
def plot_y_prob_histogram(y_prob: np.ndarray, y_true: Optional[np.ndarray] = None, save_fig_path=None) -> Figure:
"""
Provides a histogram for the predicted probabilities of a binary classifier. If ```y_true``` is provided, it divides the ```y_prob``` values into the two classes and plots them jointly into the same plot with different colors.
Expand All @@ -485,16 +469,32 @@ def plot_y_prob_histogram(y_prob: np.ndarray, y_true: Optional[np.ndarray]=None,
"""
fig = plt.figure(figsize=(5, 5))
ax = fig.add_subplot(111)

if y_true is None:
ax.hist(y_prob, bins=10, alpha=0.9, edgecolor="midnightblue", linewidth=2, rwidth=1)
# same histogram as above, but with border lines
# ax.hist(y_prob, bins=10, alpha=0.5, edgecolor='black', linewidth=1.2)
else:
alpha = 0.6
ax.hist(y_prob[y_true==0], bins=10, alpha=alpha, edgecolor="midnightblue", linewidth=2, rwidth=1, label="$\\hat{y} = 0$")
ax.hist(y_prob[y_true==1], bins=10, alpha=alpha, edgecolor="darkred", linewidth=2, rwidth=1, label="$\\hat{y} = 1$")

ax.hist(
y_prob[y_true == 0],
bins=10,
alpha=alpha,
edgecolor="midnightblue",
linewidth=2,
rwidth=1,
label="$\\hat{y} = 0$",
)
ax.hist(
y_prob[y_true == 1],
bins=10,
alpha=alpha,
edgecolor="darkred",
linewidth=2,
rwidth=1,
label="$\\hat{y} = 1$",
)

plt.legend()
ax.set(xlabel="Predicted probability [-]", ylabel="Count [-]", xlim=(-0.01, 1.0))
ax.set_title("Histogram of predicted probabilities")
Expand All @@ -505,7 +505,7 @@ def plot_y_prob_histogram(y_prob: np.ndarray, y_true: Optional[np.ndarray]=None,
plt.tight_layout()

# save plot
if save_fig_path != None:
if save_fig_path is not None:
path = Path(save_fig_path)
path.parent.mkdir(parents=True, exist_ok=True)
fig.savefig(save_fig_path, bbox_inches="tight")
Expand Down
2 changes: 1 addition & 1 deletion plotsandgraphs/compare_distributions.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
from typing import List, Tuple, Optional
import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl
import pandas as pd
from typing import List, Tuple, Optional


def plot_raincloud(
Expand Down
3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,5 +4,8 @@ max-line-length = 120
[tool.pylint."BASIC"]
variable-rgx = "[a-z_][a-z0-9_]{0,30}$|[a-z0-9_]+([A-Z][a-z0-9_]+)*$" # Allow snake case and camel case for variable names

[tool.pylint."MESSAGES CONTROL"]
disable = "W0621" # Allow redefining names in outer scope

[flake8]
max-line-length = 120
7 changes: 3 additions & 4 deletions src/binary_classifier.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,14 @@
from typing import Optional
from pathlib import Path
import matplotlib.pyplot as plt
from matplotlib.colors import to_rgba
from matplotlib.figure import Figure
import seaborn as sns
import numpy as np
import pandas as pd
from sklearn.metrics import confusion_matrix, classification_report, ConfusionMatrixDisplay, roc_curve, auc, accuracy_score, precision_recall_curve
from sklearn.calibration import calibration_curve
from sklearn.utils import resample
from pathlib import Path
from tqdm import tqdm
from typing import Optional


def plot_accuracy(y_true, y_pred, name='', save_fig_path=None) -> Figure:
Expand Down Expand Up @@ -381,7 +380,7 @@ def plot_y_prob_histogram(y_prob: np.ndarray, save_fig_path=None) -> Figure:
plt.tight_layout()

# save plot
if (save_fig_path != None):
if (save_fig_path is not None):
path = Path(save_fig_path)
path.parent.mkdir(parents=True, exist_ok=True)
fig.savefig(save_fig_path, bbox_inches='tight')
Expand Down
17 changes: 8 additions & 9 deletions src/compare_distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,14 @@

def plot_raincloud(df: pd.DataFrame,
x_col: str,
y_col: str,
colors: List[str] = None,
order: List[str] = None,
title: str = None,
x_label: str = None,
x_range: Tuple[float, float] = None,
show_violin = True,
show_scatter = True,
y_col: str,
colors: List[str] = None,
order: List[str] = None,
title: str = None,
x_label: str = None,
x_range: Tuple[float, float] = None,
show_violin = True,
show_scatter = True,
show_boxplot = True):

"""
Expand Down Expand Up @@ -49,7 +49,6 @@ def plot_raincloud(df: pd.DataFrame,
colors = [mpl.colors.to_hex(cmap(i)) for i in np.linspace(0, 1, len(order))]
else:
assert len(colors) == len(order), 'colors and order must be the same length'
colors = colors

# Boxplot
if show_boxplot:
Expand Down
7 changes: 7 additions & 0 deletions tests/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
import os

TEST_RESULTS_PATH = os.path.join(os.path.dirname(__file__), "test_results")

# print cwd in console

# print os.path.dirname(__file__)

0 comments on commit 4ac598c

Please sign in to comment.