Skip to content

Commit

Permalink
add pipeline test
Browse files Browse the repository at this point in the history
  • Loading branch information
joshuawe committed Dec 22, 2023
1 parent bd34e12 commit 7fa2c3f
Show file tree
Hide file tree
Showing 3 changed files with 54 additions and 6 deletions.
12 changes: 6 additions & 6 deletions plotsandgraphs/binary_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -354,16 +354,16 @@ def plot_roc_curve(
return fig


def plot_calibration_curve(y_prob: np.ndarray, y_true: np.ndarray, n_bins=10, save_fig_path=None) -> Figure:
def plot_calibration_curve(y_true: np.ndarray, y_score: np.ndarray, n_bins=10, save_fig_path=None) -> Figure:
"""
Creates calibration plot for a binary classifier and calculates the ECE.
Parameters
----------
y_prob : np.ndarray
The output probabilities of the classifier. Between 0 and 1.
y_true : np.ndarray
The actual labels of the data. Either 0 or 1.
y_score : np.ndarray
The output probabilities of the classifier. Between 0 and 1.
n_bins : int
The number of bins to use for the calibration curve.
save_fig_path : str, optional
Expand All @@ -376,13 +376,13 @@ def plot_calibration_curve(y_prob: np.ndarray, y_true: np.ndarray, n_bins=10, sa
ece : float
The expected calibration error.
"""
prob_true, prob_pred = calibration_curve(y_true, y_prob, n_bins=n_bins, strategy="uniform")
prob_true, prob_pred = calibration_curve(y_true, y_score, n_bins=n_bins, strategy="uniform")

# Find the number of samples in each bin
bin_counts = np.histogram(y_prob, bins=n_bins, range=(0, 1))[0]
bin_counts = np.histogram(y_score, bins=n_bins, range=(0, 1))[0]

# Calculate the weighted absolute difference (ECE)
ece = np.abs(prob_pred - prob_true) * (bin_counts / len(y_prob))
ece = np.abs(prob_pred - prob_true) * (bin_counts / len(y_score))
ece = ece.sum().round(2)

fig = plt.figure(figsize=(5, 5))
Expand Down
23 changes: 23 additions & 0 deletions tests/test_pipeline.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
from pathlib import Path
from typing import Tuple
import numpy as np
import pytest

from plotsandgraphs import pipeline


from .utils import random_data_binary_classifier

TEST_RESULTS_PATH = Path("tests/test_results/pipeline")

def test_binary_classification_pipeline(random_data_binary_classifier):
    """
    Run the binary classification pipeline once on simulated data.

    The test makes no assertions; it passes as long as the pipeline call
    completes without raising.

    Parameters
    ----------
    random_data_binary_classifier : Tuple[np.ndarray, np.ndarray]
        The simulated data, unpacked as ``(y_true, y_score)``.
    """
    labels, scores = random_data_binary_classifier
    # Build the output location first so the pipeline call stays on one short line.
    output_file = TEST_RESULTS_PATH / "pipeline.png"
    pipeline.binary_classifier(labels, scores, save_fig_path=output_file)
25 changes: 25 additions & 0 deletions tests/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
from typing import Tuple
import numpy as np
import pytest


@pytest.fixture(scope="module")
def random_data_binary_classifier() -> Tuple[np.ndarray, np.ndarray]:
    """
    Create reproducible random data for binary classifier tests.

    The data is drawn from a locally seeded generator, so every test run sees
    the same sample. With NumPy's unseeded global state, a failure caused by
    one particular draw could not be reproduced.

    Returns
    -------
    Tuple[np.ndarray, np.ndarray]
        The simulated data. y_true, y_score
    """
    n_samples = 1000
    # Local generator: deterministic output without touching NumPy's global random state.
    rng = np.random.default_rng(42)

    # the true class labels 0 or 1, with class imbalance 40:60
    y_true = rng.choice([0, 1], n_samples, p=[0.4, 0.6])

    # a model's probability of class 1 predictions
    y_score = np.zeros(y_true.shape)

    # Compute each class mask once and reuse it for both the draw size and the assignment.
    positive = y_true == 1
    negative = y_true == 0

    # Beta(1, 0.6) puts more mass near 1 (positives); Beta(0.5, 1) puts more mass near 0 (negatives).
    y_score[positive] = rng.beta(1, 0.6, int(positive.sum()))
    y_score[negative] = rng.beta(0.5, 1, int(negative.sum()))
    return y_true, y_score

0 comments on commit 7fa2c3f

Please sign in to comment.