Skip to content

Commit

Permalink
fix: Bump prophet, re-enable tests, and remedy column eligibility log…
Browse files Browse the repository at this point in the history
…ic (#24129)
  • Loading branch information
john-bodley committed Jul 5, 2023
1 parent 0836000 commit 383dac6
Show file tree
Hide file tree
Showing 6 changed files with 35 additions and 12 deletions.
2 changes: 1 addition & 1 deletion requirements/testing.in
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
#
-r development.in
-r integration.in
-e file:.[bigquery,hive,presto,trino]
-e file:.[bigquery,hive,presto,prophet,trino]
docker
flask-testing
freezegun
Expand Down
26 changes: 23 additions & 3 deletions requirements/testing.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# SHA1:623feb0dd2b6bd376238ecf75069bc82136c2d70
# SHA1:78fe89f88adf34ac75513d363d7d9d0b5cc8cd1c
#
# This file is autogenerated by pip-compile-multi
# To update, run:
Expand All @@ -12,16 +12,26 @@
# -r requirements/base.in
# -r requirements/development.in
# -r requirements/testing.in
cmdstanpy==1.1.0
# via prophet
contourpy==1.0.7
# via matplotlib
coverage[toml]==7.2.5
# via pytest-cov
cycler==0.11.0
# via matplotlib
db-dtypes==1.1.1
# via pandas-gbq
docker==6.1.1
# via -r requirements/testing.in
ephem==4.1.4
# via lunarcalendar
exceptiongroup==1.1.1
# via pytest
flask-testing==0.8.1
# via -r requirements/testing.in
fonttools==4.39.4
# via matplotlib
freezegun==1.2.2
# via -r requirements/testing.in
google-api-core[grpc]==2.11.0
Expand Down Expand Up @@ -73,6 +83,12 @@ iniconfig==2.0.0
# via pytest
jsonschema-spec==0.1.4
# via openapi-spec-validator
kiwisolver==1.4.4
# via matplotlib
lunarcalendar==0.0.9
# via prophet
matplotlib==3.7.1
# via prophet
oauthlib==3.2.2
# via requests-oauthlib
openapi-schema-validator==0.4.4
Expand All @@ -85,6 +101,8 @@ parameterized==0.9.0
# via -r requirements/testing.in
pathable==0.4.3
# via jsonschema-spec
prophet==1.1.3
# via apache-superset
proto-plus==1.22.2
# via
# google-cloud-bigquery
Expand All @@ -107,8 +125,6 @@ pydata-google-auth==1.7.0
# via pandas-gbq
pyfakefs==5.2.2
# via -r requirements/testing.in
pyhive[presto]==0.6.5
# via apache-superset
pytest==7.3.1
# via
# -r requirements/testing.in
Expand All @@ -130,6 +146,10 @@ sqlalchemy-bigquery==1.6.1
# via apache-superset
statsd==4.0.1
# via -r requirements/testing.in
tqdm==4.65.0
# via
# cmdstanpy
# prophet
trino==0.324.0
# via apache-superset
tzdata==2023.3
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,7 @@ def get_git_sha() -> str:
"postgres": ["psycopg2-binary==2.9.6"],
"presto": ["pyhive[presto]>=0.6.5"],
"trino": ["trino>=0.324.0"],
"prophet": ["prophet>=1.0.1, <1.1", "pystan<3.0"],
"prophet": ["prophet>=1.1.0, <2.0.0"],
"redshift": ["sqlalchemy-redshift>=0.8.1, < 0.9"],
"rockset": ["rockset>=0.8.10, <0.9"],
"shillelagh": [
Expand Down
9 changes: 8 additions & 1 deletion superset/utils/pandas_postprocessing/prophet.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import logging
from typing import Optional, Union

import pandas as pd
from flask_babel import gettext as _
from pandas import DataFrame

Expand Down Expand Up @@ -134,7 +135,13 @@ def prophet( # pylint: disable=too-many-arguments
raise InvalidPostProcessingError(_("DataFrame include at least one series"))

target_df = DataFrame()
for column in [column for column in df.columns if column != index]:

for column in [
column
for column in df.columns
if column != index
and pd.to_numeric(df[column], errors="coerce").notnull().all()
]:
fit_df = _prophet_fit_and_predict(
df=df[[index, column]].rename(columns={index: "ds", column: "y"}),
confidence_interval=confidence_interval,
Expand Down
4 changes: 2 additions & 2 deletions tests/integration_tests/charts/data/api_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -444,11 +444,11 @@ def test_chart_data_dttm_filter(self):
else:
raise Exception("ds column not found")

@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
def test_chart_data_prophet(self):
"""
Chart data API: Ensure prophet post transformation works
"""
pytest.importorskip("prophet")
time_grain = "P1Y"
self.query_context_payload["queries"][0]["is_timeseries"] = True
self.query_context_payload["queries"][0]["groupby"] = []
Expand Down Expand Up @@ -476,7 +476,7 @@ def test_chart_data_prophet(self):
self.assertIn("sum__num__yhat", row)
self.assertIn("sum__num__yhat_upper", row)
self.assertIn("sum__num__yhat_lower", row)
self.assertEqual(result["rowcount"], 47)
self.assertEqual(result["rowcount"], 103)

@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
def test_chart_data_invalid_post_processing(self):
Expand Down
4 changes: 0 additions & 4 deletions tests/unit_tests/pandas_postprocessing/test_prophet.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,6 @@


def test_prophet_valid():
pytest.importorskip("prophet")

df = prophet(df=prophet_df, time_grain="P1M", periods=3, confidence_interval=0.9)
columns = {column for column in df.columns}
assert columns == {
Expand Down Expand Up @@ -113,8 +111,6 @@ def test_prophet_valid():


def test_prophet_valid_zero_periods():
pytest.importorskip("prophet")

df = prophet(df=prophet_df, time_grain="P1M", periods=0, confidence_interval=0.9)
columns = {column for column in df.columns}
assert columns == {
Expand Down

0 comments on commit 383dac6

Please sign in to comment.