Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 32 additions & 1 deletion web/server/tests/test_main.py → tests/web/test_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,9 @@
import pytest
from fastapi.testclient import TestClient

from sqlmesh.core.context import Context
from web.server.main import app
from web.server.settings import Settings, get_settings
from web.server.settings import Settings, get_loaded_context, get_settings

client = TestClient(app)

Expand All @@ -26,6 +27,15 @@ def get_settings_override() -> Settings:
return tmp_path


@pytest.fixture
def web_sushi_context(sushi_context: Context) -> Context:
def get_context_override() -> Context:
return sushi_context

app.dependency_overrides[get_loaded_context] = get_context_override
return sushi_context


def test_get_files(project_tmp_path: Path) -> None:
models_dir = project_tmp_path / "models"
models_dir.mkdir()
Expand Down Expand Up @@ -213,3 +223,24 @@ def test_cancel_no_task() -> None:
response = client.post("/api/plan/cancel")
assert response.status_code == 422
assert response.json() == {"detail": "No active task found."}


def test_evaluate(web_sushi_context: Context) -> None:
response = client.post(
"/api/evaluate",
json={
"model": "sushi.top_waiters",
"start": "2022-01-01",
"end": "now",
"latest": "now",
"limit": 100,
},
)
assert response.status_code == 200
assert response.json()


def test_fetchdf(web_sushi_context: Context) -> None:
response = client.post("/api/fetchdf", content='"SELECT * from sushi.top_waiters"')
assert response.status_code == 200
assert response.json()
39 changes: 39 additions & 0 deletions web/server/api/endpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,11 @@
import functools
import json
import os
import traceback
import typing as t
from pathlib import Path

import pandas as pd
from fastapi import APIRouter, Body, Depends, HTTPException, Request, Response, status
from starlette.status import HTTP_404_NOT_FOUND, HTTP_422_UNPROCESSABLE_ENTITY

Expand Down Expand Up @@ -291,3 +293,40 @@ async def cancel(
status_code=HTTP_422_UNPROCESSABLE_ENTITY, detail="No active task found."
)
response.status_code = status.HTTP_204_NO_CONTENT


@router.post("/evaluate")
async def evaluate(
options: models.EvaluateInput,
context: Context = Depends(get_loaded_context),
) -> t.Optional[str]:
"""Evaluate a model with a default limit of 1000"""
try:
df = context.evaluate(
options.model,
start=options.start,
end=options.end,
latest=options.latest,
limit=options.limit,
)
except Exception:
raise HTTPException(
status_code=HTTP_422_UNPROCESSABLE_ENTITY, detail=traceback.format_exc()
)
if isinstance(df, pd.DataFrame):
return df.to_json()
return df.toPandas().to_json()


@router.post("/fetchdf")
async def fetchdf(
sql: str = Body(),
context: Context = Depends(get_loaded_context),
) -> t.Optional[str]:
"""Fetches a dataframe given a sql string"""
try:
return context.fetchdf(sql).to_json()
except Exception:
raise HTTPException(
status_code=HTTP_422_UNPROCESSABLE_ENTITY, detail=traceback.format_exc()
)
9 changes: 9 additions & 0 deletions web/server/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from pydantic import validator

from sqlmesh.core.context_diff import ContextDiff
from sqlmesh.utils.date import TimeLike
from sqlmesh.utils.pydantic import PydanticModel

SUPPORTED_EXTENSIONS = {".py", ".sql", ".yaml"}
Expand Down Expand Up @@ -99,3 +100,11 @@ class ContextEnvironment(PydanticModel):
environment: str
changes: t.Optional[ContextEnvironmentChanges]
backfills: t.List[ContextEnvironmentBackfill] = []


class EvaluateInput(PydanticModel):
model: str
start: TimeLike
end: TimeLike
latest: TimeLike
limit: t.Optional[int] = None
Empty file removed web/server/tests/__init__.py
Empty file.