From 903260584c376a61bfb71409c02ea89d4018637b Mon Sep 17 00:00:00 2001 From: Vincent Chan Date: Tue, 7 Feb 2023 11:42:54 -0800 Subject: [PATCH 1/2] Add evaluate and fetchdf API endpoints --- {web/server/tests => tests/web}/test_main.py | 33 +++++++++++++++++++- web/server/api/endpoints.py | 28 +++++++++++++++++ web/server/models.py | 9 ++++++ web/server/tests/__init__.py | 0 4 files changed, 69 insertions(+), 1 deletion(-) rename {web/server/tests => tests/web}/test_main.py (86%) delete mode 100644 web/server/tests/__init__.py diff --git a/web/server/tests/test_main.py b/tests/web/test_main.py similarity index 86% rename from web/server/tests/test_main.py rename to tests/web/test_main.py index 1ccb6bd0cc..a295f841fc 100644 --- a/web/server/tests/test_main.py +++ b/tests/web/test_main.py @@ -4,8 +4,9 @@ import pytest from fastapi.testclient import TestClient +from sqlmesh.core.context import Context from web.server.main import app -from web.server.settings import Settings, get_settings +from web.server.settings import Settings, get_loaded_context, get_settings client = TestClient(app) @@ -26,6 +27,15 @@ def get_settings_override() -> Settings: return tmp_path +@pytest.fixture +def web_sushi_context(sushi_context: Context) -> Context: + def get_context_override() -> Context: + return sushi_context + + app.dependency_overrides[get_loaded_context] = get_context_override + return sushi_context + + def test_get_files(project_tmp_path: Path) -> None: models_dir = project_tmp_path / "models" models_dir.mkdir() @@ -213,3 +223,24 @@ def test_cancel_no_task() -> None: response = client.post("/api/plan/cancel") assert response.status_code == 422 assert response.json() == {"detail": "No active task found."} + + +def test_evaluate(web_sushi_context: Context) -> None: + response = client.post( + "/api/evaluate", + json={ + "model": "sushi.top_waiters", + "start": "2022-01-01", + "end": "now", + "latest": "now", + "limit": 100, + }, + ) + assert response.status_code == 200 + assert response.json() + + +def test_fetchdf(web_sushi_context: Context) -> None: + response = client.post("/api/fetchdf", content='"SELECT * from sushi.top_waiters"') + assert response.status_code == 200 + assert response.json() diff --git a/web/server/api/endpoints.py b/web/server/api/endpoints.py index b39a46fa42..8e615b66f0 100644 --- a/web/server/api/endpoints.py +++ b/web/server/api/endpoints.py @@ -7,6 +7,7 @@ import typing as t from pathlib import Path +import pandas as pd from fastapi import APIRouter, Body, Depends, HTTPException, Request, Response, status from starlette.status import HTTP_404_NOT_FOUND, HTTP_422_UNPROCESSABLE_ENTITY @@ -291,3 +292,30 @@ async def cancel( status_code=HTTP_422_UNPROCESSABLE_ENTITY, detail="No active task found." ) response.status_code = status.HTTP_204_NO_CONTENT + + +@router.post("/evaluate") +async def evaluate( + options: models.EvaluateInput, + context: Context = Depends(get_loaded_context), +) -> str: + """Evaluate a model with a default limit of 1000""" + df = context.evaluate( + options.model, + start=options.start, + end=options.end, + latest=options.latest, + limit=options.limit, + ) + if isinstance(df, pd.DataFrame): + return df.to_json() + return df.toPandas().to_json() + + +@router.post("/fetchdf") +async def fetchdf( + sql: str = Body(), + context: Context = Depends(get_loaded_context), +) -> str: + """Fetches a dataframe given a sql string""" + return context.fetchdf(sql).to_json() diff --git a/web/server/models.py b/web/server/models.py index d3b40c0f9b..fa0de61456 100644 --- a/web/server/models.py +++ b/web/server/models.py @@ -6,6 +6,7 @@ from pydantic import validator from sqlmesh.core.context_diff import ContextDiff +from sqlmesh.utils.date import TimeLike from sqlmesh.utils.pydantic import PydanticModel SUPPORTED_EXTENSIONS = {".py", ".sql", ".yaml"} @@ -99,3 +100,11 @@ class ContextEnvironment(PydanticModel): environment: str changes: t.Optional[ContextEnvironmentChanges] backfills: t.List[ContextEnvironmentBackfill] = [] + + +class EvaluateInput(PydanticModel): + model: str + start: TimeLike + end: TimeLike + latest: TimeLike + limit: t.Optional[int] = None diff --git a/web/server/tests/__init__.py b/web/server/tests/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 From 7b9d3d4980e7a37a91c519e9117162acbce7cb50 Mon Sep 17 00:00:00 2001 From: Vincent Chan Date: Tue, 7 Feb 2023 12:48:53 -0800 Subject: [PATCH 2/2] Return exception stack trace to frontend --- web/server/api/endpoints.py | 31 +++++++++++++++++++++---------- 1 file changed, 21 insertions(+), 10 deletions(-) diff --git a/web/server/api/endpoints.py b/web/server/api/endpoints.py index 8e615b66f0..1d3c97dc5d 100644 --- a/web/server/api/endpoints.py +++ b/web/server/api/endpoints.py @@ -4,6 +4,7 @@ import functools import json import os +import traceback import typing as t from pathlib import Path @@ -298,15 +299,20 @@ async def cancel( async def evaluate( options: models.EvaluateInput, context: Context = Depends(get_loaded_context), -) -> str: +) -> t.Optional[str]: """Evaluate a model with a default limit of 1000""" - df = context.evaluate( - options.model, - start=options.start, - end=options.end, - latest=options.latest, - limit=options.limit, - ) + try: + df = context.evaluate( + options.model, + start=options.start, + end=options.end, + latest=options.latest, + limit=options.limit, + ) + except Exception: + raise HTTPException( + status_code=HTTP_422_UNPROCESSABLE_ENTITY, detail=traceback.format_exc() + ) if isinstance(df, pd.DataFrame): return df.to_json() return df.toPandas().to_json() @@ -316,6 +322,11 @@ async def evaluate( async def fetchdf( sql: str = Body(), context: Context = Depends(get_loaded_context), -) -> str: +) -> t.Optional[str]: """Fetches a dataframe given a sql string""" - return context.fetchdf(sql).to_json() + try: + return context.fetchdf(sql).to_json() + except Exception: + raise HTTPException( + status_code=HTTP_422_UNPROCESSABLE_ENTITY, detail=traceback.format_exc() + )