Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: using batch api when query cloud dataset #452

Merged
merged 2 commits into from
Mar 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
58 changes: 48 additions & 10 deletions app/src/dataSource/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -78,25 +78,63 @@ export function finishDataService(msg: any) {
)
}

/**
 * One queued query plus the promise callbacks used to settle the
 * caller's Promise once the batched kernel response arrives.
 */
interface IBatchGetDatasTask {
    query: any;
    resolve: (value: any) => void;
    reject: (reason?: any) => void;
}

/**
 * Creates a batching client for the given comm `action`: getDatas() calls
 * made within a short window are merged into one kernel round trip, and the
 * positionally aligned responses settle each caller's Promise.
 */
function initBatchGetDatas(action: string) {
    // Requests arriving within this many ms are flushed as a single batch.
    const BATCH_DELAY_MS = 100;
    const taskList = [] as IBatchGetDatasTask[];

    const batchGetDatas = async (tasks: IBatchGetDatasTask[]) => {
        try {
            const result = await communicationStore.comm?.sendMsg(
                action,
                {"queryList": tasks.map(task => task.query)}
            );
            if (result) {
                // responses are positionally aligned with the submitted queries
                tasks.forEach((task, index) => task.resolve(result["data"]["datas"][index]));
            } else {
                tasks.forEach(task => task.reject("get result error"));
            }
        } catch (err) {
            // Without this, a throwing sendMsg would surface as an unhandled
            // rejection and leave every queued promise pending forever.
            tasks.forEach(task => task.reject(err));
        }
    }

    const getDatas = (query: any) => {
        return new Promise<any>((resolve, reject) => {
            taskList.push({ query, resolve, reject });
            // The first task pushed onto an empty queue schedules the flush
            // for the whole batch; later tasks just piggyback on it.
            if (taskList.length === 1) {
                setTimeout(() => {
                    batchGetDatas(taskList.splice(0, taskList.length));
                }, BATCH_DELAY_MS);
            }
        })
    }

    return {
        getDatas
    }
}

const batchGetDatasBySql = initBatchGetDatas("batch_get_datas_by_sql");
const batchGetDatasByPayload = initBatchGetDatas("batch_get_datas_by_payload");

export function getDatasFromKernelBySql(fieldMetas: any) {
return async (payload: IDataQueryPayload) => {
const sql = parser_dsl_with_meta(
"pygwalker_mid_table",
JSON.stringify(payload),
JSON.stringify({"pygwalker_mid_table": fieldMetas})
);
const result = await communicationStore.comm?.sendMsg(
"get_datas",
{"sql": sql}
);
return (result ? result["data"]["datas"] : []) as IRow[];
const result = await batchGetDatasBySql.getDatas(sql);
return (result ?? []) as IRow[];
}
}

export async function getDatasFromKernelByPayload(payload: IDataQueryPayload) {
const result = await communicationStore.comm?.sendMsg(
"get_datas_by_payload",
{payload}
);
return (result ? result["data"]["datas"] : []) as IRow[];
const result = await batchGetDatasByPayload.getDatas(payload);
return (result ?? []) as IRow[];
}
2 changes: 1 addition & 1 deletion app/src/index.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ const MainApp = (props: {children: React.ReactNode, darkMode: "dark" | "light" |
<style>{style}</style>
{ props.children }
{!props.hideToolBar && (
<div className="flex w-full p-1 overflow-hidden border-t border-border">
<div className="flex w-full mt-1 p-1 overflow-hidden border-t border-border">
<ToggleGroup
type="single"
value={selectedDarkMode}
Expand Down
93 changes: 17 additions & 76 deletions app/src/utils/communication.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -77,10 +77,17 @@ const initJupyterCommunication = (gid: string) => {
return document.getElementsByClassName(`hacker-comm-pyg-kernel-store-${gid}-${index}`)[0].childNodes[1] as HTMLInputElement;
})

const requestTask = [] as any[];
const endpoints = new Map<string, (data: any) => any>();
const bufferMap = new Map<string, any>();

const fetchOnJupyter = (value: string) => {
const event = new Event("input", { bubbles: true })
const kernelText = kernelTextList[curKernelTextIndex];
kernelText.value = value;
kernelText.dispatchEvent(event);
curKernelTextIndex = (curKernelTextIndex + 1) % kernelTextCount;
}

const onMessage = (msg: string) => {
const data = JSON.parse(msg);
const action = data.action;
Expand All @@ -89,13 +96,6 @@ const initJupyterCommunication = (gid: string) => {
document.dispatchEvent(new CustomEvent(getSignalName(data.rid)));
return
}
if (action === "finish_batch_request") {
data.data.forEach((resp: any) => {
bufferMap.set(resp.rid, resp.data);
document.dispatchEvent(new CustomEvent(getSignalName(resp.rid)));
})
return
}
const callback = endpoints.get(action);
if (callback) {
const resp = callback(data.data) ?? {};
Expand Down Expand Up @@ -129,21 +129,7 @@ const initJupyterCommunication = (gid: string) => {

const sendMsgAsync = (action: string, data: any, rid: string | null) => {
rid = rid ?? uuidv4();
requestTask.push({ action, data, rid });
if (requestTask.length === 1) {
setTimeout(() => {
batchSendMsgAsync(requestTask.splice(0, requestTask.length));
}, 100);
}
}

const batchSendMsgAsync = (data: any) => {
const rid = uuidv4();
const event = new Event("input", { bubbles: true })
const kernelText = kernelTextList[curKernelTextIndex];
kernelText.value = JSON.stringify({ gid: gid, rid: rid, action: "batch_request", data });
kernelText.dispatchEvent(event);
curKernelTextIndex = (curKernelTextIndex + 1) % kernelTextCount;
fetchOnJupyter(JSON.stringify({ gid, rid, action, data }));
}

const registerEndpoint = (action: string, callback: (data: any) => any) => {
Expand Down Expand Up @@ -187,17 +173,6 @@ const initJupyterCommunication = (gid: string) => {
}
}

interface IHttpRequestTask {
data: {
action: string;
data: any;
rid: string;
gid: string;
};
resolve: (resp: any) => void;
reject: (err: any) => void;
}

const initHttpCommunication = (gid: string, baseUrl: string) => {
// temporary solution in streamlit cloud
const domain = window.parent.document.location.host.split(".").slice(-2).join('.');
Expand All @@ -208,8 +183,6 @@ const initHttpCommunication = (gid: string, baseUrl: string) => {
url = `/${baseUrl}/${gid}`
}

const requestTask = [] as IHttpRequestTask[];

const sendMsg = async(action: string, data: any, timeout: number = 30_000) => {
const timer = setTimeout(() => {
raiseRequestError("communication timeout", 0);
Expand All @@ -227,48 +200,16 @@ const initHttpCommunication = (gid: string, baseUrl: string) => {
}
}

const sendMsgAsync = (action: string, data: any) => {
const sendMsgAsync = async(action: string, data: any) => {
const rid = uuidv4();
const promise = new Promise<any>((resolve, reject) => {
requestTask.push({
data: { action, data, rid, gid },
resolve,
reject,
});
if (requestTask.length === 1) {
setTimeout(() => {
batchSendMsgAsync(requestTask.splice(0, requestTask.length));
}, 100);
return await (await fetch(
url,
{
method: "POST",
headers: { "Content-Type": "application/json" },
body: JSON.stringify({ action, data, rid, gid }),
}
});
return promise;
}

const batchSendMsgAsync = async(taskList: IHttpRequestTask[]) => {
try {
const resp = await fetch(
url,
{
method: "POST",
headers: { "Content-Type": "application/json" },
body: JSON.stringify({
gid: gid,
rid: uuidv4(),
action: "batch_request",
data: taskList.map(task => task.data)
}),
}
)
const respJson = await resp.json();
taskList.forEach((task, index) => {
const taskResp = respJson[index];
task.resolve(taskResp);
})
} catch (err) {
taskList.forEach(task => {
task.reject(err);
})
}
)).json();
}

const registerEndpoint = (_: string, __: (data: any) => any) => {}
Expand Down
26 changes: 24 additions & 2 deletions pygwalker/api/pygwalker.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,8 @@
from pygwalker.errors import DataCountLimitError
from pygwalker import __version__

RESPONSE_MAX_DATA_LENGTH = 1 * 1000 * 1000


class PygWalker:
"""PygWalker"""
Expand Down Expand Up @@ -334,20 +336,38 @@ def upload_spec_to_cloud(data: Dict[str, Any]):
def _get_datas(data: Dict[str, Any]):
sql = data["sql"]
datas = self.data_parser.get_datas_by_sql(sql)
if len(datas) > 1 * 1000 * 1000:
if len(datas) > RESPONSE_MAX_DATA_LENGTH:
raise DataCountLimitError()
return {
"datas": datas
}

def _get_datas_by_payload(data: Dict[str, Any]):
datas = self.data_parser.get_datas_by_payload(data["payload"])
if len(datas) > 1 * 1000 * 1000:
if len(datas) > RESPONSE_MAX_DATA_LENGTH:
raise DataCountLimitError()
return {
"datas": datas
}

def _batch_get_datas_by_sql(data: Dict[str, Any]):
    """Comm endpoint: run every SQL query in data["queryList"] as one batch."""
    result = self.data_parser.batch_get_datas_by_sql(data["queryList"])
    # refuse to ship any oversized result set back to the frontend
    if any(len(datas) > RESPONSE_MAX_DATA_LENGTH for datas in result):
        raise DataCountLimitError()
    return {
        "datas": result
    }

def _batch_get_datas_by_payload(data: Dict[str, Any]):
    """Comm endpoint: execute every query payload in data["queryList"] as one batch."""
    result = self.data_parser.batch_get_datas_by_payload(data["queryList"])
    # refuse to ship any oversized result set back to the frontend
    if any(len(datas) > RESPONSE_MAX_DATA_LENGTH for datas in result):
        raise DataCountLimitError()
    return {
        "datas": result
    }

def _get_spec_by_text(data: Dict[str, Any]):
return {
"data": self.cloud_service.get_spec_by_text(data["metas"], data["query"])
Expand Down Expand Up @@ -403,6 +423,8 @@ def _upload_to_cloud_dashboard(data: Dict[str, Any]):
if self.use_kernel_calc:
comm.register("get_datas", _get_datas)
comm.register("get_datas_by_payload", _get_datas_by_payload)
comm.register("batch_get_datas_by_sql", _batch_get_datas_by_sql)
comm.register("batch_get_datas_by_payload", _batch_get_datas_by_payload)

if self.is_export_dataframe:
comm.register("export_dataframe_by_payload", _export_dataframe_by_payload)
Expand Down
8 changes: 1 addition & 7 deletions pygwalker/communications/gradio_comm.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,7 @@ async def _pygwalker_router(req: Request) -> Response:
json_data = await req.json()

# pylint: disable=protected-access
if json_data["action"] == "batch_request":
result = [
comm_obj._receive_msg(request["action"], request["data"])
for request in json_data["data"]
]
else:
result = comm_obj._receive_msg(json_data["action"], json_data["data"])
result = comm_obj._receive_msg(json_data["action"], json_data["data"])
# pylint: enable=protected-access

result = json.dumps(result, cls=DataFrameEncoder)
Expand Down
15 changes: 2 additions & 13 deletions pygwalker/communications/hacker_comm.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,19 +61,8 @@ def _on_mesage(self, info: Dict[str, Any]):
if action == "finish_request":
return

if action == "batch_request":
resp = [
{
"data": self._receive_msg(request["action"], request["data"]),
"rid": request["rid"]
}
for request in data
if request["action"] != "finish_request"
]
self.send_msg_async("finish_batch_request", resp, rid)
else:
resp = self._receive_msg(action, data)
self.send_msg_async("finish_request", resp, rid)
resp = self._receive_msg(action, data)
self.send_msg_async("finish_request", resp, rid)

def _get_html_widget(self) -> Text:
text = Text(value="", placeholder="")
Expand Down
8 changes: 1 addition & 7 deletions pygwalker/communications/streamlit_comm.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,13 +35,7 @@ def post(self, gid: str):
json_data = json.loads(self.request.body)

# pylint: disable=protected-access
if json_data["action"] == "batch_request":
result = [
comm_obj._receive_msg(request["action"], request["data"])
for request in json_data["data"]
]
else:
result = comm_obj._receive_msg(json_data["action"], json_data["data"])
result = comm_obj._receive_msg(json_data["action"], json_data["data"])
# pylint: enable=protected-access

self.write(json.dumps(result, cls=DataFrameEncoder))
Expand Down
24 changes: 24 additions & 0 deletions pygwalker/data_parsers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,16 @@ def get_datas_by_payload(self, payload: Dict[str, Any]) -> List[Dict[str, Any]]:
"""get records"""
raise NotImplementedError

@abc.abstractmethod
def batch_get_datas_by_sql(self, sql_list: List[str]) -> List[List[Dict[str, Any]]]:
    """Batch variant of get_datas_by_sql: return one record list per input SQL query."""
    raise NotImplementedError

@abc.abstractmethod
def batch_get_datas_by_payload(self, payload_list: List[Dict[str, Any]]) -> List[List[Dict[str, Any]]]:
    """Batch variant of get_datas_by_payload: return one record list per input payload."""
    raise NotImplementedError

@abc.abstractmethod
def to_csv(self) -> io.BytesIO:
"""get records"""
Expand Down Expand Up @@ -186,6 +196,20 @@ def get_datas_by_payload(self, payload: Dict[str, Any]) -> List[Dict[str, Any]]:
)
return self.get_datas_by_sql(sql)

def batch_get_datas_by_sql(self, sql_list: List[str]) -> List[List[Dict[str, Any]]]:
    """Run each SQL query in order; returns one record list per query."""
    results = []
    for sql in sql_list:
        results.append(self.get_datas_by_sql(sql))
    return results

def batch_get_datas_by_payload(self, payload_list: List[Dict[str, Any]]) -> List[List[Dict[str, Any]]]:
    """Execute each query payload in order; returns one record list per payload."""
    results = []
    for payload in payload_list:
        results.append(self.get_datas_by_payload(payload))
    return results

@property
def dataset_tpye(self) -> str:
return "dataframe_default"
Expand Down
12 changes: 12 additions & 0 deletions pygwalker/data_parsers/cloud_dataset_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,18 @@ def _get_all_datas(self, limit: int) -> List[Dict[str, Any]]:
payload = {"workflow": [{"type": "view", "query": [{"op": "raw", "fields": ["*"]}]}], "limit": limit, "offset": 0}
return self.get_datas_by_payload(payload)

def batch_get_datas_by_sql(self, sql_list: List[str]) -> List[List[Dict[str, Any]]]:
    """batch get records

    This parser does not implement raw-SQL batching (cloud datasets are
    queried via payloads; see batch_get_datas_by_payload).  Raise instead
    of the original silent `pass`, which returned None and violated the
    declared List[List[Dict]] return type.
    """
    raise NotImplementedError("batch_get_datas_by_sql is not supported for cloud datasets")

def batch_get_datas_by_payload(self, payload_list: List[Dict[str, Any]]) -> List[List[Dict[str, Any]]]:
    """Issue one batched cloud query; returns the row list for each payload."""
    responses = self._cloud_service.batch_query_from_dataset(self.dataset_id, payload_list)
    rows_per_query = []
    for response in responses:
        rows_per_query.append(response["rows"])
    return rows_per_query

@property
def dataset_tpye(self) -> str:
return "cloud_dataset"
Expand Down