diff --git a/sqlmesh/core/console.py b/sqlmesh/core/console.py index 088a0661bc..0a3dd06b7f 100644 --- a/sqlmesh/core/console.py +++ b/sqlmesh/core/console.py @@ -1,7 +1,9 @@ from __future__ import annotations import abc +import asyncio import datetime +import json import time import typing as t import unittest @@ -18,7 +20,8 @@ from sqlmesh.core.snapshot import Snapshot, SnapshotChangeCategory from sqlmesh.core.test import ModelTest from sqlmesh.utils import rich as srich -from sqlmesh.utils.date import to_date +from sqlmesh.utils.date import now_timestamp, to_date +from web.server.sse import Event if t.TYPE_CHECKING: import ipywidgets as widgets @@ -683,13 +686,14 @@ def __init__(self) -> None: super().__init__() self.current_task_status: t.Dict[str, t.Dict[str, int]] = {} self.previous_task_status: t.Dict[str, t.Dict[str, int]] = {} + self.queue: asyncio.Queue = asyncio.Queue() def start_snapshot_progress(self, snapshot_name: str, total_batches: int) -> None: """Indicates that a new load progress has begun.""" self.current_task_status[snapshot_name] = { "completed": 0, "total": total_batches, - "start": int(time.time()), + "start": now_timestamp(), } def update_snapshot_progress(self, snapshot_name: str, num_batches: int) -> None: @@ -701,10 +705,22 @@ def update_snapshot_progress(self, snapshot_name: str, num_batches: int) -> None >= self.current_task_status[snapshot_name]["total"] ): self.current_task_status[snapshot_name]["end"] = int(time.time()) + self.queue.put_nowait( + Event( + event="tasks", + data=json.dumps( + { + "ok": True, + "tasks": self.current_task_status, + "timestamp": now_timestamp(), + } + ), + ) + ) def complete_snapshot_progress(self) -> None: """Indicates that load progress is complete""" - self.log_success("All model batches have been executed successfully") + self.queue.put_nowait("All model batches have been executed successfully") self.stop_snapshot_progress() def stop_snapshot_progress(self) -> None: @@ -712,6 +728,15 @@ def stop_snapshot_progress(self) -> None: self.previous_task_status = self.current_task_status.copy() self.current_task_status = {} + def log_test_results( + self, result: unittest.result.TestResult, output: str, target_dialect: str + ) -> None: + self.queue.put_nowait( + Event( + data=f"Successfully ran {str(result.testsRun)} tests against {target_dialect}" + ) + ) + def get_console() -> TerminalConsole: """ diff --git a/web/server/api/endpoints.py b/web/server/api/endpoints.py index 34cf020186..b9b874b41b 100644 --- a/web/server/api/endpoints.py +++ b/web/server/api/endpoints.py @@ -285,6 +285,23 @@ async def running_tasks() -> t.AsyncGenerator: return SSEResponse(running_tasks()) +@router.get("/events") +async def events( + request: Request, +) -> SSEResponse: + async def generator() -> t.AsyncGenerator: + queue: asyncio.Queue = asyncio.Queue() + request.app.state.console_listeners.append(queue) + try: + while True: + yield await queue.get() + queue.task_done() + finally: + request.app.state.console_listeners.remove(queue) + + return SSEResponse(generator()) + + @router.post("/plan/cancel") async def cancel( request: Request, diff --git a/web/server/main.py b/web/server/main.py index edc4c5e1c8..4fb4f0ce82 100644 --- a/web/server/main.py +++ b/web/server/main.py @@ -1,12 +1,34 @@ +import asyncio + from fastapi import FastAPI +from sqlmesh.core.console import ApiConsole from web.server.api.endpoints import router app = FastAPI() +api_console = ApiConsole() app.include_router(router, prefix="/api") +@app.on_event("startup") +async def startup_event() -> None: + async def dispatch() -> None: + while True: + item = await api_console.queue.get() + for listener in app.state.console_listeners: + await listener.put(item) + api_console.queue.task_done() + + app.state.console_listeners = [] + app.state.dispatch_task = asyncio.create_task(dispatch()) + + +@app.on_event("shutdown") +def shutdown_event() -> None: + app.state.dispatch_task.cancel() + + @app.get("/health") def health() -> str: return "ok" diff --git a/web/server/settings.py b/web/server/settings.py index d361edb37b..9f8744f657 100644 --- a/web/server/settings.py +++ b/web/server/settings.py @@ -4,7 +4,6 @@ from fastapi import Depends from pydantic import BaseSettings -from sqlmesh.core.console import ApiConsole from sqlmesh.core.context import Context @@ -25,7 +24,9 @@ def _get_context(path: str) -> Context: @lru_cache() def _get_loaded_context(path: str, config: str) -> Context: - return Context(path=path, config=config, console=ApiConsole()) + from web.server.main import api_console + + return Context(path=path, config=config, console=api_console) def get_loaded_context(settings: Settings = Depends(get_settings)) -> Context: