Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add configuration for Presto cursor poll interval #10191

Merged
merged 1 commit into from Jun 30, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
8 changes: 7 additions & 1 deletion superset/db_engine_specs/presto.py
Expand Up @@ -51,6 +51,9 @@
config = app.config
logger = logging.getLogger(__name__)

# See here: https://github.com/dropbox/PyHive/blob/8eb0aeab8ca300f3024655419b93dad926c1a351/pyhive/presto.py#L93 # pylint: disable=line-too-long
DEFAULT_PYHIVE_POLL_INTERVAL = 1


def get_children(column: Dict[str, str]) -> List[Dict[str, str]]:
"""
Expand Down Expand Up @@ -729,6 +732,9 @@ def get_create_view(
def handle_cursor(cls, cursor: Any, query: Query, session: Session) -> None:
"""Updates progress information"""
query_id = query.id
poll_interval = query.database.connect_args.get(
"poll_interval", DEFAULT_PYHIVE_POLL_INTERVAL
)
logger.info("Query %i: Polling the cursor for progress", query_id)
polled = cursor.poll()
# poll returns dict -- JSON status information or ``None``
Expand Down Expand Up @@ -762,7 +768,7 @@ def handle_cursor(cls, cursor: Any, query: Query, session: Session) -> None:
if progress > query.progress:
query.progress = progress
session.commit()
time.sleep(1)
time.sleep(poll_interval)
logger.info("Query %i: Polling the cursor for progress", query_id)
polled = cursor.poll()

Expand Down
4 changes: 4 additions & 0 deletions superset/models/core.py
Expand Up @@ -244,6 +244,10 @@ def table_cache_timeout(self) -> Optional[int]:
def default_schemas(self) -> List[str]:
return self.get_extra().get("default_schemas", [])

@property
def connect_args(self) -> Dict[str, Any]:
return self.get_extra().get("engine_params", {}).get("connect_args", {})

@classmethod
def get_password_masked_url_from_uri( # pylint: disable=invalid-name
cls, uri: str
Expand Down