
Commit

Made get_prediction_vs_actual_data public
christopherbunn committed Dec 14, 2020
1 parent 0f1f1c6 commit 23bdbea
Showing 4 changed files with 49 additions and 3 deletions.
1 change: 1 addition & 0 deletions docs/source/api_reference.rst
@@ -267,6 +267,7 @@ Utility Methods
binary_objective_vs_threshold
get_prediction_vs_actual_over_time_data
partial_dependence
get_prediction_vs_actual_data


Graph Utility Methods
1 change: 1 addition & 0 deletions evalml/model_understanding/__init__.py
@@ -12,6 +12,7 @@
graph_binary_objective_vs_threshold,
partial_dependence,
graph_partial_dependence,
get_prediction_vs_actual_data,
graph_prediction_vs_actual,
get_prediction_vs_actual_over_time_data,
graph_prediction_vs_actual_over_time,
24 changes: 21 additions & 3 deletions evalml/model_understanding/graphs.py
@@ -507,8 +507,26 @@ def _calculate_axis_range(arr):
    return [min_value - margins, max_value + margins]


def _get_prediction_vs_actual_data(y_true, y_pred, outlier_threshold):
    """Helper method to help calculate the y_true and y_pred dataframe, with a column for outliers"""
def get_prediction_vs_actual_data(y_true, y_pred, outlier_threshold=None):
    """Combines y_true and y_pred into a single dataframe and adds a column for outliers. Used in `graph_prediction_vs_actual()`.

    Arguments:
        y_true (pd.Series): The real target values of the data.
        y_pred (pd.Series): The predicted values produced by the regression model.
        outlier_threshold (int, float): A positive threshold for what is considered an outlier value. This value is compared to the absolute
            difference between each value of y_true and y_pred. Values within this threshold will be blue, otherwise they will be yellow.
            Defaults to None.

    Returns:
        pd.DataFrame with the following columns:
            * `prediction`: Predicted values from the regression model.
            * `actual`: Real target values.
            * `outlier`: Color codes indicating which values are within the threshold (blue) and which are considered outliers (yellow).
    """
    if outlier_threshold and outlier_threshold <= 0:
        raise ValueError(f"Threshold must be positive! Provided threshold is {outlier_threshold}")

    y_true = _convert_to_woodwork_structure(y_true)
    y_true = _convert_woodwork_types_wrapper(y_true.to_series())
    y_pred = _convert_to_woodwork_structure(y_pred)
@@ -547,7 +565,7 @@ def graph_prediction_vs_actual(y_true, y_pred, outlier_threshold=None):
    if outlier_threshold and outlier_threshold <= 0:
        raise ValueError(f"Threshold must be positive! Provided threshold is {outlier_threshold}")

    df = _get_prediction_vs_actual_data(y_true, y_pred, outlier_threshold)
    df = get_prediction_vs_actual_data(y_true, y_pred, outlier_threshold)
    data = []

    x_axis = _calculate_axis_range(df['prediction'])
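With the helper now exported from evalml.model_understanding, it can be called directly to inspect the dataframe behind the plot. A minimal sketch follows; the sample values are illustrative and not part of this commit:

import pandas as pd
from evalml.model_understanding import get_prediction_vs_actual_data

# Illustrative targets and regression predictions; the third pair differs by
# more than the threshold, so its row is flagged with the outlier color.
y_true = pd.Series([1, 2, 3000, 4, 5])
y_pred = pd.Series([5, 4, 2, 8, 6])

df = get_prediction_vs_actual_data(y_true, y_pred, outlier_threshold=2000)
print(df[['prediction', 'actual', 'outlier']])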
26 changes: 26 additions & 0 deletions evalml/tests/model_understanding_tests/test_graphs.py
@@ -20,6 +20,7 @@
confusion_matrix,
decision_tree_data_from_estimator,
decision_tree_data_from_pipeline,
get_prediction_vs_actual_data,
graph_binary_objective_vs_threshold,
graph_confusion_matrix,
graph_partial_dependence,
@@ -931,6 +932,31 @@ def test_jupyter_graph_check(import_check, jupyter_check, X_y_binary, X_y_regres
    import_check.assert_called_with('ipywidgets', warning=True)


def test_get_prediction_vs_actual_data():
    y_true = [1, 2, 3000, 4, 5, 6, 7, 8, 9, 10, 11, 12]
    y_pred = [5, 4, 2, 8, 6, 6, 5, 1, 7, 2, 1, 3000]

    with pytest.raises(ValueError, match="Threshold must be positive!"):
        get_prediction_vs_actual_data(y_true, y_pred, outlier_threshold=-1)

    outlier_loc = [2, 11]
    results = get_prediction_vs_actual_data(y_true, y_pred, outlier_threshold=2000)
    assert isinstance(results, pd.DataFrame)
    assert np.array_equal(results['prediction'], y_pred)
    assert np.array_equal(results['actual'], y_true)
    for i, value in enumerate(results['outlier']):
        if i in outlier_loc:
            assert value == "#ffff00"
        else:
            assert value == '#0000ff'

    results = get_prediction_vs_actual_data(y_true, y_pred)
    assert isinstance(results, pd.DataFrame)
    assert np.array_equal(results['prediction'], y_pred)
    assert np.array_equal(results['actual'], y_true)
    assert (results['outlier'] == '#0000ff').all()


def test_graph_prediction_vs_actual_default():
    go = pytest.importorskip('plotly.graph_objects', reason='Skipping plotting test because plotly not installed')
    y_true = [1, 2, 3000, 4, 5, 6, 7, 8, 9, 10, 11, 12]