Skip to content

Commit

Permalink
pipeline multiclass
Browse files Browse the repository at this point in the history
  • Loading branch information
joshuawe committed Dec 26, 2023
1 parent 59aa72b commit 1cc259d
Show file tree
Hide file tree
Showing 4 changed files with 81 additions and 12 deletions.
4 changes: 3 additions & 1 deletion plotsandgraphs/binary_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -580,6 +580,8 @@ def plot_pr_curve(

# Save the figure if save_fig_path is specified
if save_fig_path:
plt.savefig(save_fig_path, bbox_inches="tight")
path = Path(save_fig_path)
path.parent.mkdir(parents=True, exist_ok=True)
fig.savefig(save_fig_path, bbox_inches="tight")

return fig
2 changes: 1 addition & 1 deletion plotsandgraphs/multiclass_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,7 +267,7 @@ def auroc_metric_function(y_true, y_score, average, multi_class):
return fig, fig_aurocs


def plot_y_prob_histogram(
def plot_y_score_histogram(
y_true: np.ndarray, y_score: Optional[np.ndarray] = None, save_fig_path: Optional[str]=None
) -> Figure:
"""
Expand Down
85 changes: 76 additions & 9 deletions plotsandgraphs/pipeline.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,18 @@
from typing import Literal, Union
from pathlib import Path
from tqdm.auto import tqdm

from . import binary_classifier as bc
from . import multiclass_classifier as mc

from tqdm.auto import tqdm


FILE_ENDINGS = Literal['pdf', 'png', 'jpg', 'jpeg', 'svg']



def binary_classifier(y_true, y_score, save_fig_path=None, plot_kwargs={}):
def binary_classifier(y_true, y_score, save_fig_path=None, plot_kwargs={}, file_type:FILE_ENDINGS='png'):


# Create new tqdm instance
tqdm_instance = tqdm(total=6, desc='Binary classifier metrics', leave=True)
Expand All @@ -14,32 +22,91 @@ def binary_classifier(y_true, y_score, save_fig_path=None, plot_kwargs={}):

# 1) Plot ROC curve
roc_kwargs = plot_kwargs.get('roc', {})
bc.plot_roc_curve(y_true, y_score, save_fig_path=save_fig_path, **roc_kwargs)
save_path = get_file_path(save_fig_path, 'roc_curve', file_type)
bc.plot_roc_curve(y_true, y_score, save_fig_path=save_path, **roc_kwargs)
tqdm_instance.update()

# 2) Plot precision-recall curve
pr_kwargs = plot_kwargs.get('pr', {})
bc.plot_pr_curve(y_true, y_score, save_fig_path=save_fig_path, **pr_kwargs)
save_path = get_file_path(save_fig_path, 'pr_curve', file_type)
bc.plot_pr_curve(y_true, y_score, save_fig_path=save_path, **pr_kwargs)
tqdm_instance.update()

# 3) Plot calibration curve
cal_kwargs = plot_kwargs.get('cal', {})
bc.plot_calibration_curve(y_true, y_score, save_fig_path=save_fig_path, **cal_kwargs)
save_path = get_file_path(save_fig_path, 'calibration_curve', file_type)
bc.plot_calibration_curve(y_true, y_score, save_fig_path=save_path, **cal_kwargs)
tqdm_instance.update()

# 3) Plot confusion matrix
cm_kwargs = plot_kwargs.get('cm', {})
bc.plot_confusion_matrix(y_true, y_score, save_fig_path=save_fig_path, **cm_kwargs)
save_path = get_file_path(save_fig_path, 'confusion_matrix', file_type)
bc.plot_confusion_matrix(y_true, y_score, save_fig_path=save_path, **cm_kwargs)
tqdm_instance.update()

# 5) Plot classification report
cr_kwargs = plot_kwargs.get('cr', {})
bc.plot_classification_report(y_true, y_score, save_fig_path=save_fig_path, **cr_kwargs)
save_path = get_file_path(save_fig_path, 'classification_report', file_type)
bc.plot_classification_report(y_true, y_score, save_fig_path=save_path, **cr_kwargs)
tqdm_instance.update()

# 6) Plot y_score histogram
hist_kwargs = plot_kwargs.get('hist', {})
save_path = get_file_path(save_fig_path, 'y_score_histogram', file_type)
bc.plot_y_score_histogram(y_true, y_score, save_fig_path=save_path, **hist_kwargs)
tqdm_instance.update()

return




def multiclass_classifier(y_true, y_score, save_fig_path=None, plot_kwargs={}):

# Create new tqdm instance
tqdm_instance = tqdm(total=6, desc='Binary classifier metrics', leave=True)

# Update tqdm instance
tqdm_instance.update()

# 1) Plot ROC curve
roc_kwargs = plot_kwargs.get('roc', {})
mc.plot_roc_curve(y_true, y_score, save_fig_path=save_fig_path, **roc_kwargs)
tqdm_instance.update()

# 2) Plot precision-recall curve
# pr_kwargs = plot_kwargs.get('pr', {})
# mc.plot_pr_curve(y_true, y_score, save_fig_path=save_fig_path, **pr_kwargs)
# tqdm_instance.update()

# 3) Plot calibration curve
# cal_kwargs = plot_kwargs.get('cal', {})
# mc.plot_calibration_curve(y_true, y_score, save_fig_path=save_fig_path, **cal_kwargs)
# tqdm_instance.update()

# 3) Plot confusion matrix
# cm_kwargs = plot_kwargs.get('cm', {})
# mc.plot_confusion_matrix(y_true, y_score, save_fig_path=save_fig_path, **cm_kwargs)
# tqdm_instance.update()

# 5) Plot classification report
# cr_kwargs = plot_kwargs.get('cr', {})
# mc.plot_classification_report(y_true, y_score, save_fig_path=save_fig_path, **cr_kwargs)
# tqdm_instance.update()

# 6) Plot y_score histogram
hist_kwargs = plot_kwargs.get('hist', {})
bc.plot_y_score_histogram(y_true, y_score, save_fig_path=save_fig_path, **hist_kwargs)
mc.plot_y_score_histogram(y_true, y_score, save_fig_path=save_fig_path, **hist_kwargs)
tqdm_instance.update()

return
return



def get_file_path(save_fig_path: Union[Path,None, str], name:str, ending:str):
if save_fig_path is None:
return None
else:
result = Path(save_fig_path) / f"{name}.{ending}"
print(result)
return str(result)
2 changes: 1 addition & 1 deletion tests/test_binary_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import pytest
import plotsandgraphs.binary_classifier as binary

TEST_RESULTS_PATH = Path(r"tests\test_results")
TEST_RESULTS_PATH = Path(r"tests\test_results\binary_classifier")


@pytest.fixture(scope="module")
Expand Down

0 comments on commit 1cc259d

Please sign in to comment.