Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 28 additions & 3 deletions sqlmesh/core/console.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
from __future__ import annotations

import abc
import asyncio
import datetime
import json
import time
import typing as t
import unittest
Expand All @@ -18,7 +20,8 @@
from sqlmesh.core.snapshot import Snapshot, SnapshotChangeCategory
from sqlmesh.core.test import ModelTest
from sqlmesh.utils import rich as srich
from sqlmesh.utils.date import to_date
from sqlmesh.utils.date import now_timestamp, to_date
from web.server.sse import Event

if t.TYPE_CHECKING:
import ipywidgets as widgets
Expand Down Expand Up @@ -683,13 +686,14 @@ def __init__(self) -> None:
super().__init__()
self.current_task_status: t.Dict[str, t.Dict[str, int]] = {}
self.previous_task_status: t.Dict[str, t.Dict[str, int]] = {}
self.queue: asyncio.Queue = asyncio.Queue()

def start_snapshot_progress(self, snapshot_name: str, total_batches: int) -> None:
"""Indicates that a new load progress has begun."""
self.current_task_status[snapshot_name] = {
"completed": 0,
"total": total_batches,
"start": int(time.time()),
"start": now_timestamp(),
}

def update_snapshot_progress(self, snapshot_name: str, num_batches: int) -> None:
Expand All @@ -701,17 +705,38 @@ def update_snapshot_progress(self, snapshot_name: str, num_batches: int) -> None
>= self.current_task_status[snapshot_name]["total"]
):
self.current_task_status[snapshot_name]["end"] = int(time.time())
self.queue.put_nowait(
Event(
event="tasks",
data=json.dumps(
{
"ok": True,
"tasks": self.current_task_status,
"timestamp": now_timestamp(),
}
),
)
)

def complete_snapshot_progress(self) -> None:
"""Indicates that load progress is complete"""
self.log_success("All model batches have been executed successfully")
self.queue.put_nowait("All model batches have been executed successfully")
self.stop_snapshot_progress()

def stop_snapshot_progress(self) -> None:
"""Stop the load progress"""
self.previous_task_status = self.current_task_status.copy()
self.current_task_status = {}

def log_test_results(
self, result: unittest.result.TestResult, output: str, target_dialect: str
) -> None:
self.queue.put_nowait(
Event(
data=f"Successfully ran {str(result.testsRun)} tests against {target_dialect}"
)
)


def get_console() -> TerminalConsole:
"""
Expand Down
17 changes: 17 additions & 0 deletions web/server/api/endpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -285,6 +285,23 @@ async def running_tasks() -> t.AsyncGenerator:
return SSEResponse(running_tasks())


@router.get("/events")
async def events(
request: Request,
) -> SSEResponse:
async def generator() -> t.AsyncGenerator:
queue: asyncio.Queue = asyncio.Queue()
request.app.state.console_listeners.append(queue)
try:
while True:
yield await queue.get()
queue.task_done()
finally:
request.app.state.console_listeners.remove(queue)

return SSEResponse(generator())


@router.post("/plan/cancel")
async def cancel(
request: Request,
Expand Down
22 changes: 22 additions & 0 deletions web/server/main.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,34 @@
import asyncio

from fastapi import FastAPI

from sqlmesh.core.console import ApiConsole
from web.server.api.endpoints import router

app = FastAPI()
api_console = ApiConsole()

app.include_router(router, prefix="/api")


@app.on_event("startup")
async def startup_event() -> None:
async def dispatch() -> None:
while True:
item = await api_console.queue.get()
for listener in app.state.console_listeners:
await listener.put(item)
api_console.queue.task_done()

app.state.console_listeners = []
app.state.dispatch_task = asyncio.create_task(dispatch())


@app.on_event("shutdown")
def shutdown_event() -> None:
app.state.dispatch_task.cancel()


@app.get("/health")
def health() -> str:
return "ok"
5 changes: 3 additions & 2 deletions web/server/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
from fastapi import Depends
from pydantic import BaseSettings

from sqlmesh.core.console import ApiConsole
from sqlmesh.core.context import Context


Expand All @@ -25,7 +24,9 @@ def _get_context(path: str) -> Context:

@lru_cache()
def _get_loaded_context(path: str, config: str) -> Context:
return Context(path=path, config=config, console=ApiConsole())
from web.server.main import api_console

return Context(path=path, config=config, console=api_console)


def get_loaded_context(settings: Settings = Depends(get_settings)) -> Context:
Expand Down