From dd5a092fba983692883f7f4efe14276eb0059349 Mon Sep 17 00:00:00 2001
From: longxiaofei
Date: Fri, 1 Mar 2024 19:25:53 +0800
Subject: [PATCH 1/2] refactor: remove common batch request

---
 app/src/dataSource/index.ts                | 58 ++++++++++--
 app/src/index.tsx                          |  2 +-
 app/src/utils/communication.tsx            | 93 ++++---------------
 pygwalker/api/pygwalker.py                 | 26 +++++-
 pygwalker/communications/gradio_comm.py    |  8 +-
 pygwalker/communications/hacker_comm.py    | 15 +--
 pygwalker/communications/streamlit_comm.py |  8 +-
 pygwalker/data_parsers/base.py             | 24 +++++
 .../data_parsers/cloud_dataset_parser.py   | 14 +++
 pygwalker/data_parsers/database_parser.py  | 14 +++
 pygwalker/data_parsers/spark_parser.py     | 14 +++
 11 files changed, 160 insertions(+), 116 deletions(-)

diff --git a/app/src/dataSource/index.ts b/app/src/dataSource/index.ts
index 1f6b11a..35d64c6 100644
--- a/app/src/dataSource/index.ts
+++ b/app/src/dataSource/index.ts
@@ -78,6 +78,50 @@ export function finishDataService(msg: any) {
     )
 }
 
+interface IBatchGetDatasTask {
+    query: any;
+    resolve: (value: any) => void;
+    reject: (reason?: any) => void;
+}
+
+function initBatchGetDatas(action: string) {
+    const taskList = [] as IBatchGetDatasTask[];
+
+    const batchGetDatas = async(taskList: IBatchGetDatasTask[]) => {
+        const result = await communicationStore.comm?.sendMsg(
+            action,
+            {"queryList": taskList.map(task => task.query)}
+        );
+        if (result) {
+            for (let i = 0; i < taskList.length; i++) {
+                taskList[i].resolve(result["data"]["datas"][i]);
+            }
+        } else {
+            for (let i = 0; i < taskList.length; i++) {
+                taskList[i].reject("get result error");
+            }
+        }
+    }
+
+    const getDatas = (query: any) => {
+        return new Promise((resolve, reject) => {
+            taskList.push({ query, resolve, reject });
+            if (taskList.length === 1) {
+                setTimeout(() => {
+                    batchGetDatas(taskList.splice(0, taskList.length));
+                }, 100);
+            }
+        })
+    }
+
+    return {
+        getDatas
+    }
+}
+
+const batchGetDatasBySql = initBatchGetDatas("batch_get_datas_by_sql");
+const batchGetDatasByPayload = initBatchGetDatas("batch_get_datas_by_payload");
+
 export function getDatasFromKernelBySql(fieldMetas: any) {
     return async (payload: IDataQueryPayload) => {
         const sql = parser_dsl_with_meta(
@@ -85,18 +129,12 @@ export function getDatasFromKernelBySql(fieldMetas: any)
             JSON.stringify(payload),
             JSON.stringify({"pygwalker_mid_table": fieldMetas})
         );
-        const result = await communicationStore.comm?.sendMsg(
-            "get_datas",
-            {"sql": sql}
-        );
-        return (result ? result["data"]["datas"] : []) as IRow[];
+        const result = await batchGetDatasBySql.getDatas(sql);
+        return (result ?? []) as IRow[];
     }
 }
 
 export async function getDatasFromKernelByPayload(payload: IDataQueryPayload) {
-    const result = await communicationStore.comm?.sendMsg(
-        "get_datas_by_payload",
-        {payload}
-    );
-    return (result ? result["data"]["datas"] : []) as IRow[];
+    const result = await batchGetDatasByPayload.getDatas(payload);
+    return (result ?? []) as IRow[];
 }

diff --git a/app/src/index.tsx b/app/src/index.tsx
index 0f5d18a..9d77bf6 100644
--- a/app/src/index.tsx
+++ b/app/src/index.tsx
@@ -105,7 +105,7 @@ const MainApp = (props: {children: React.ReactNode, darkMode: "dark" | "light" |
             { props.children }
             {!props.hideToolBar && (
-
+
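Both wrappers above speak the same batched contract: a single message whose data is {"queryList": [...]}, answered with {"datas": [...]} in matching order. The kernel-side handlers for these actions, _batch_get_datas_by_sql and _batch_get_datas_by_payload, are registered in pygwalker/api/pygwalker.py later in this patch; the standalone function below is only an illustrative sketch of that contract, not code from the patch.

    from typing import Any, Dict, List

    def batch_get_datas_by_sql(data: Dict[str, Any], parser: Any) -> Dict[str, List[Any]]:
        # One entry per query that the frontend coalesced within its 100 ms window.
        sql_list: List[str] = data["queryList"]
        # Order must be preserved: the frontend resolves task i with
        # result["data"]["datas"][i].
        return {"datas": [parser.get_datas_by_sql(sql) for sql in sql_list]}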
diff --git a/app/src/utils/communication.tsx b/app/src/utils/communication.tsx
--- a/app/src/utils/communication.tsx
+++ b/app/src/utils/communication.tsx
@@ ... @@ const initJupyterCommunication = (gid: string) => {
     const kernelTextList = Array.from({ length: kernelTextCount }, (_, index) => {
         return document.getElementsByClassName(`hacker-comm-pyg-kernel-store-${gid}-${index}`)[0].childNodes[1] as HTMLInputElement;
     })
-    const requestTask = [] as any[];
     const endpoints = new Map<string, (data: any) => any>();
     const bufferMap = new Map<string, any>();
 
+    const fetchOnJupyter = (value: string) => {
+        const event = new Event("input", { bubbles: true })
+        const kernelText = kernelTextList[curKernelTextIndex];
+        kernelText.value = value;
+        kernelText.dispatchEvent(event);
+        curKernelTextIndex = (curKernelTextIndex + 1) % kernelTextCount;
+    }
+
     const onMessage = (msg: string) => {
         const data = JSON.parse(msg);
         const action = data.action;
@@ -89,13 +96,6 @@ const initJupyterCommunication = (gid: string) => {
             document.dispatchEvent(new CustomEvent(getSignalName(data.rid)));
             return
         }
-        if (action === "finish_batch_request") {
-            data.data.forEach((resp: any) => {
-                bufferMap.set(resp.rid, resp.data);
-                document.dispatchEvent(new CustomEvent(getSignalName(resp.rid)));
-            })
-            return
-        }
         const callback = endpoints.get(action);
         if (callback) {
             const resp = callback(data.data) ?? {};
@@ -129,21 +129,7 @@ const initJupyterCommunication = (gid: string) => {
 
     const sendMsgAsync = (action: string, data: any, rid: string | null) => {
         rid = rid ?? uuidv4();
-        requestTask.push({ action, data, rid });
-        if (requestTask.length === 1) {
-            setTimeout(() => {
-                batchSendMsgAsync(requestTask.splice(0, requestTask.length));
-            }, 100);
-        }
-    }
-
-    const batchSendMsgAsync = (data: any) => {
-        const rid = uuidv4();
-        const event = new Event("input", { bubbles: true })
-        const kernelText = kernelTextList[curKernelTextIndex];
-        kernelText.value = JSON.stringify({ gid: gid, rid: rid, action: "batch_request", data });
-        kernelText.dispatchEvent(event);
-        curKernelTextIndex = (curKernelTextIndex + 1) % kernelTextCount;
+        fetchOnJupyter(JSON.stringify({ gid, rid, action, data }));
     }
 
     const registerEndpoint = (action: string, callback: (data: any) => any) => {
@@ -187,17 +173,6 @@ const initJupyterCommunication = (gid: string) => {
     }
 }
 
-interface IHttpRequestTask {
-    data: {
-        action: string;
-        data: any;
-        rid: string;
-        gid: string;
-    };
-    resolve: (resp: any) => void;
-    reject: (err: any) => void;
-}
-
 const initHttpCommunication = (gid: string, baseUrl: string) => {
     // temporary solution in streamlit could
     const domain = window.parent.document.location.host.split(".").slice(-2).join('.');
@@ -208,8 +183,6 @@ const initHttpCommunication = (gid: string, baseUrl: string) => {
         url = `/${baseUrl}/${gid}`
     }
 
-    const requestTask = [] as IHttpRequestTask[];
-
     const sendMsg = async(action: string, data: any, timeout: number = 30_000) => {
         const timer = setTimeout(() => {
             raiseRequestError("communication timeout", 0);
@@ -227,48 +200,16 @@ const initHttpCommunication = (gid: string, baseUrl: string) => {
         }
     }
 
-    const sendMsgAsync = (action: string, data: any) => {
+    const sendMsgAsync = async(action: string, data: any) => {
         const rid = uuidv4();
-        const promise = new Promise((resolve, reject) => {
-            requestTask.push({
-                data: { action, data, rid, gid },
-                resolve,
-                reject,
-            });
-            if (requestTask.length === 1) {
-                setTimeout(() => {
-                    batchSendMsgAsync(requestTask.splice(0, requestTask.length));
-                }, 100);
+        return await (await fetch(
+            url,
+            {
+                method: "POST",
+                headers: { "Content-Type": "application/json" },
+                body: JSON.stringify({ action, data, rid, gid }),
             }
-        });
-        return promise;
-    }
-
-    const batchSendMsgAsync = async(taskList: IHttpRequestTask[]) => {
-        try {
-            const resp = await fetch(
-                url,
-                {
-                    method: "POST",
-                    headers: { "Content-Type": "application/json" },
-                    body: JSON.stringify({
-                        gid: gid,
-                        rid: uuidv4(),
-                        action: "batch_request",
-                        data: taskList.map(task => task.data)
-                    }),
-                }
-            )
-            const respJson = await resp.json();
-            taskList.forEach((task, index) => {
-                const taskResp = respJson[index];
-                task.resolve(taskResp);
-            })
-        } catch (err) {
-            taskList.forEach(task => {
-                task.reject(err);
-            })
-        }
+        )).json();
     }
 
     const registerEndpoint = (_: string, __: (data: any) => any) => {}
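With batch_request gone, every message is a single {gid, rid, action, data} envelope answered by one finish_request. For illustration, a stripped-down dispatcher in the spirit of HackerCommunication._on_mesage (see pygwalker/communications/hacker_comm.py below); this toy class is hypothetical and omits the widget plumbing, buffering, and error handling of the real one:

    from typing import Any, Callable, Dict

    class MiniComm:
        def __init__(self) -> None:
            self._endpoints: Dict[str, Callable[[Dict[str, Any]], Any]] = {}

        def register(self, action: str, handler: Callable[[Dict[str, Any]], Any]) -> None:
            self._endpoints[action] = handler

        def on_message(self, msg: Dict[str, Any]) -> Dict[str, Any]:
            # Each envelope carries exactly one action; the reply echoes the
            # rid so the frontend can resolve the matching Promise.
            resp = self._endpoints[msg["action"]](msg["data"])
            return {"action": "finish_request", "rid": msg["rid"], "data": resp}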
"application/json" }, - body: JSON.stringify({ - gid: gid, - rid: uuidv4(), - action: "batch_request", - data: taskList.map(task => task.data) - }), - } - ) - const respJson = await resp.json(); - taskList.forEach((task, index) => { - const taskResp = respJson[index]; - task.resolve(taskResp); - }) - } catch (err) { - taskList.forEach(task => { - task.reject(err); - }) - } + )).json(); } const registerEndpoint = (_: string, __: (data: any) => any) => {} diff --git a/pygwalker/api/pygwalker.py b/pygwalker/api/pygwalker.py index e5c714d..59c238e 100644 --- a/pygwalker/api/pygwalker.py +++ b/pygwalker/api/pygwalker.py @@ -41,6 +41,8 @@ from pygwalker.errors import DataCountLimitError from pygwalker import __version__ +RESPONSE_MAX_DATA_LENGTH = 1 * 1000 * 1000 + class PygWalker: """PygWalker""" @@ -334,7 +336,7 @@ def upload_spec_to_cloud(data: Dict[str, Any]): def _get_datas(data: Dict[str, Any]): sql = data["sql"] datas = self.data_parser.get_datas_by_sql(sql) - if len(datas) > 1 * 1000 * 1000: + if len(datas) > RESPONSE_MAX_DATA_LENGTH: raise DataCountLimitError() return { "datas": datas @@ -342,12 +344,30 @@ def _get_datas(data: Dict[str, Any]): def _get_datas_by_payload(data: Dict[str, Any]): datas = self.data_parser.get_datas_by_payload(data["payload"]) - if len(datas) > 1 * 1000 * 1000: + if len(datas) > RESPONSE_MAX_DATA_LENGTH: raise DataCountLimitError() return { "datas": datas } + def _batch_get_datas_by_sql(data: Dict[str, Any]): + result = self.data_parser.batch_get_datas_by_sql(data["queryList"]) + for datas in result: + if len(datas) > RESPONSE_MAX_DATA_LENGTH: + raise DataCountLimitError() + return { + "datas": result + } + + def _batch_get_datas_by_payload(data: Dict[str, Any]): + result = self.data_parser.batch_get_datas_by_payload(data["queryList"]) + for datas in result: + if len(datas) > RESPONSE_MAX_DATA_LENGTH: + raise DataCountLimitError() + return { + "datas": result + } + def _get_spec_by_text(data: Dict[str, Any]): return { "data": self.cloud_service.get_spec_by_text(data["metas"], data["query"]) @@ -403,6 +423,8 @@ def _upload_to_cloud_dashboard(data: Dict[str, Any]): if self.use_kernel_calc: comm.register("get_datas", _get_datas) comm.register("get_datas_by_payload", _get_datas_by_payload) + comm.register("batch_get_datas_by_sql", _batch_get_datas_by_sql) + comm.register("batch_get_datas_by_payload", _batch_get_datas_by_payload) if self.is_export_dataframe: comm.register("export_dataframe_by_payload", _export_dataframe_by_payload) diff --git a/pygwalker/communications/gradio_comm.py b/pygwalker/communications/gradio_comm.py index 9511121..7fd93ab 100644 --- a/pygwalker/communications/gradio_comm.py +++ b/pygwalker/communications/gradio_comm.py @@ -22,13 +22,7 @@ async def _pygwalker_router(req: Request) -> Response: json_data = await req.json() # pylint: disable=protected-access - if json_data["action"] == "batch_request": - result = [ - comm_obj._receive_msg(request["action"], request["data"]) - for request in json_data["data"] - ] - else: - result = comm_obj._receive_msg(json_data["action"], json_data["data"]) + result = comm_obj._receive_msg(json_data["action"], json_data["data"]) # pylint: enable=protected-access result = json.dumps(result, cls=DataFrameEncoder) diff --git a/pygwalker/communications/hacker_comm.py b/pygwalker/communications/hacker_comm.py index 65228cc..185ef05 100644 --- a/pygwalker/communications/hacker_comm.py +++ b/pygwalker/communications/hacker_comm.py @@ -61,19 +61,8 @@ def _on_mesage(self, info: Dict[str, Any]): if action == 
"finish_request": return - if action == "batch_request": - resp = [ - { - "data": self._receive_msg(request["action"], request["data"]), - "rid": request["rid"] - } - for request in data - if request["action"] != "finish_request" - ] - self.send_msg_async("finish_batch_request", resp, rid) - else: - resp = self._receive_msg(action, data) - self.send_msg_async("finish_request", resp, rid) + resp = self._receive_msg(action, data) + self.send_msg_async("finish_request", resp, rid) def _get_html_widget(self) -> Text: text = Text(value="", placeholder="") diff --git a/pygwalker/communications/streamlit_comm.py b/pygwalker/communications/streamlit_comm.py index d70f965..3520a42 100644 --- a/pygwalker/communications/streamlit_comm.py +++ b/pygwalker/communications/streamlit_comm.py @@ -35,13 +35,7 @@ def post(self, gid: str): json_data = json.loads(self.request.body) # pylint: disable=protected-access - if json_data["action"] == "batch_request": - result = [ - comm_obj._receive_msg(request["action"], request["data"]) - for request in json_data["data"] - ] - else: - result = comm_obj._receive_msg(json_data["action"], json_data["data"]) + result = comm_obj._receive_msg(json_data["action"], json_data["data"]) # pylint: enable=protected-access self.write(json.dumps(result, cls=DataFrameEncoder)) diff --git a/pygwalker/data_parsers/base.py b/pygwalker/data_parsers/base.py index 47d9031..c63faa1 100644 --- a/pygwalker/data_parsers/base.py +++ b/pygwalker/data_parsers/base.py @@ -69,6 +69,16 @@ def get_datas_by_payload(self, payload: Dict[str, Any]) -> List[Dict[str, Any]]: """get records""" raise NotImplementedError + @abc.abstractmethod + def batch_get_datas_by_sql(self, sql_list: List[str]) -> List[List[Dict[str, Any]]]: + """batch get records""" + raise NotImplementedError + + @abc.abstractmethod + def batch_get_datas_by_payload(self, payload_list: List[Dict[str, Any]]) -> List[List[Dict[str, Any]]]: + """batch get records""" + raise NotImplementedError + @abc.abstractmethod def to_csv(self) -> io.BytesIO: """get records""" @@ -186,6 +196,20 @@ def get_datas_by_payload(self, payload: Dict[str, Any]) -> List[Dict[str, Any]]: ) return self.get_datas_by_sql(sql) + def batch_get_datas_by_sql(self, sql_list: List[str]) -> List[List[Dict[str, Any]]]: + """batch get records""" + return [ + self.get_datas_by_sql(sql) + for sql in sql_list + ] + + def batch_get_datas_by_payload(self, payload_list: List[Dict[str, Any]]) -> List[List[Dict[str, Any]]]: + """batch get records""" + return [ + self.get_datas_by_payload(payload) + for payload in payload_list + ] + @property def dataset_tpye(self) -> str: return "dataframe_default" diff --git a/pygwalker/data_parsers/cloud_dataset_parser.py b/pygwalker/data_parsers/cloud_dataset_parser.py index a564065..fa647e7 100644 --- a/pygwalker/data_parsers/cloud_dataset_parser.py +++ b/pygwalker/data_parsers/cloud_dataset_parser.py @@ -90,6 +90,20 @@ def _get_all_datas(self, limit: int) -> List[Dict[str, Any]]: payload = {"workflow": [{"type": "view", "query": [{"op": "raw", "fields": ["*"]}]}], "limit": limit, "offset": 0} return self.get_datas_by_payload(payload) + def batch_get_datas_by_sql(self, sql_list: List[str]) -> List[List[Dict[str, Any]]]: + """batch get records""" + return [ + self.get_datas_by_sql(sql) + for sql in sql_list + ] + + def batch_get_datas_by_payload(self, payload_list: List[Dict[str, Any]]) -> List[List[Dict[str, Any]]]: + """batch get records""" + return [ + self.get_datas_by_payload(payload) + for payload in payload_list + ] + @property def 
diff --git a/pygwalker/data_parsers/cloud_dataset_parser.py b/pygwalker/data_parsers/cloud_dataset_parser.py
index a564065..fa647e7 100644
--- a/pygwalker/data_parsers/cloud_dataset_parser.py
+++ b/pygwalker/data_parsers/cloud_dataset_parser.py
@@ -90,6 +90,20 @@ def _get_all_datas(self, limit: int) -> List[Dict[str, Any]]:
         payload = {"workflow": [{"type": "view", "query": [{"op": "raw", "fields": ["*"]}]}], "limit": limit, "offset": 0}
         return self.get_datas_by_payload(payload)
 
+    def batch_get_datas_by_sql(self, sql_list: List[str]) -> List[List[Dict[str, Any]]]:
+        """batch get records"""
+        return [
+            self.get_datas_by_sql(sql)
+            for sql in sql_list
+        ]
+
+    def batch_get_datas_by_payload(self, payload_list: List[Dict[str, Any]]) -> List[List[Dict[str, Any]]]:
+        """batch get records"""
+        return [
+            self.get_datas_by_payload(payload)
+            for payload in payload_list
+        ]
+
     @property
     def dataset_tpye(self) -> str:
         return "cloud_dataset"

diff --git a/pygwalker/data_parsers/database_parser.py b/pygwalker/data_parsers/database_parser.py
index bd8bd75..64723f8 100644
--- a/pygwalker/data_parsers/database_parser.py
+++ b/pygwalker/data_parsers/database_parser.py
@@ -185,6 +185,20 @@ def to_parquet(self) -> io.BytesIO:
         self.example_pandas_df.toPandas().to_parquet(content, index=False, compression="snappy")
         return content
 
+    def batch_get_datas_by_sql(self, sql_list: List[str]) -> List[List[Dict[str, Any]]]:
+        """batch get records"""
+        return [
+            self.get_datas_by_sql(sql)
+            for sql in sql_list
+        ]
+
+    def batch_get_datas_by_payload(self, payload_list: List[Dict[str, Any]]) -> List[List[Dict[str, Any]]]:
+        """batch get records"""
+        return [
+            self.get_datas_by_payload(payload)
+            for payload in payload_list
+        ]
+
     @property
     def dataset_tpye(self) -> str:
         return f"connector_{self.conn.dialect_name}"

diff --git a/pygwalker/data_parsers/spark_parser.py b/pygwalker/data_parsers/spark_parser.py
index 248d8d4..e7c187f 100644
--- a/pygwalker/data_parsers/spark_parser.py
+++ b/pygwalker/data_parsers/spark_parser.py
@@ -76,6 +76,20 @@ def get_datas_by_payload(self, payload: Dict[str, Any]) -> List[Dict[str, Any]]:
         )
         return self.get_datas_by_sql(sql)
 
+    def batch_get_datas_by_sql(self, sql_list: List[str]) -> List[List[Dict[str, Any]]]:
+        """batch get records"""
+        return [
+            self.get_datas_by_sql(sql)
+            for sql in sql_list
+        ]
+
+    def batch_get_datas_by_payload(self, payload_list: List[Dict[str, Any]]) -> List[List[Dict[str, Any]]]:
+        """batch get records"""
+        return [
+            self.get_datas_by_payload(payload)
+            for payload in payload_list
+        ]
+
     def to_csv(self) -> io.BytesIO:
         content = io.BytesIO()
         self.df.toPandas().to_csv(content, index=False)

From 77be1167e2f9456976d3377835ebfd249c71244e Mon Sep 17 00:00:00 2001
From: longxiaofei
Date: Fri, 1 Mar 2024 20:57:31 +0800
Subject: [PATCH 2/2] feat: use batch API when querying cloud dataset

---
 .../data_parsers/cloud_dataset_parser.py | 10 +++----
 pygwalker/services/cloud_service.py      | 27 +++++++++++++++----
 2 files changed, 26 insertions(+), 11 deletions(-)

diff --git a/pygwalker/data_parsers/cloud_dataset_parser.py b/pygwalker/data_parsers/cloud_dataset_parser.py
index fa647e7..7700c73 100644
--- a/pygwalker/data_parsers/cloud_dataset_parser.py
+++ b/pygwalker/data_parsers/cloud_dataset_parser.py
@@ -92,16 +92,14 @@ def _get_all_datas(self, limit: int) -> List[Dict[str, Any]]:
 
     def batch_get_datas_by_sql(self, sql_list: List[str]) -> List[List[Dict[str, Any]]]:
         """batch get records"""
-        return [
-            self.get_datas_by_sql(sql)
-            for sql in sql_list
-        ]
+        pass
 
     def batch_get_datas_by_payload(self, payload_list: List[Dict[str, Any]]) -> List[List[Dict[str, Any]]]:
         """batch get records"""
+        result = self._cloud_service.batch_query_from_dataset(self.dataset_id, payload_list)
         return [
-            self.get_datas_by_payload(payload)
-            for payload in payload_list
+            item["rows"]
+            for item in result
         ]
 
     @property
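Here batch_get_datas_by_payload delegates the whole list to one HTTP call and unpacks item["rows"] per query, while the SQL variant is stubbed out (cloud datasets are queried through payloads). A sketch of the response shape this unpacking assumes; the shape is inferred from the list comprehension above, and any field other than "rows" is unknown:

    # Assumed result of CloudService.batch_query_from_dataset for two payloads:
    # one entry per query, index-aligned with payload_list.
    batch_response = [
        {"rows": [{"name": "a", "price": 10}]},   # result for payload_list[0]
        {"rows": [{"name": "b", "price": 20}]},   # result for payload_list[1]
    ]
    datas = [item["rows"] for item in batch_response]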
diff --git a/pygwalker/services/cloud_service.py b/pygwalker/services/cloud_service.py
index 053446d..c6e5835 100644
--- a/pygwalker/services/cloud_service.py
+++ b/pygwalker/services/cloud_service.py
@@ -64,11 +64,20 @@ def send(self, request: requests.PreparedRequest, **kwargs) -> requests.Response:
             resp_json = resp.json()
         except Exception as e:
             raise CloudFunctionError(f"Request failed: {resp.text}") from e
-        if resp_json["success"] is False:
-            raise CloudFunctionError(
-                f"Request failed: {resp_json['message']}",
-                code=resp_json["code"] if resp_json["code"] != 0 else ErrorCode.UNKNOWN_ERROR
-            )
+
+        if "success" in resp_json:
+            if resp_json["success"] is False:
+                raise CloudFunctionError(
+                    f"Request failed: {resp_json['message']}",
+                    code=resp_json["code"] if resp_json["code"] != 0 else ErrorCode.UNKNOWN_ERROR
+                )
+        else:
+            if resp.status_code != 200:
+                raise CloudFunctionError(
+                    f"Request failed: {resp_json['error']['message']}",
+                    code=resp_json["error"]["code"]
+                )
+
         return resp
 
@@ -316,6 +325,14 @@ def query_from_dataset(self, dataset_id: str, payload: Dict[str, Any]) -> List[Dict[str, Any]]:
         resp = self.session.post(url, json=params, timeout=15)
         return resp.json()["data"]
 
+    def batch_query_from_dataset(self, dataset_id: str, query_list: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
+        url = f"{GlobalVarManager.kanaries_api_host}/v1/dataset/{dataset_id}/query"
+        params = {
+            "query": query_list,
+        }
+        resp = self.session.post(url, json=params, timeout=40)
+        return resp.json()["data"]
+
     def create_cloud_dataset(
         self,
         data_parser: BaseDataParser,