Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: parse array field in mysql #498

Merged
merged 1 commit into from
Mar 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pygwalker/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from pygwalker.services.global_var import GlobalVarManager
from pygwalker.services.kaggle import show_tips_user_kaggle as __show_tips_user_kaggle

__version__ = "0.4.8a5"
__version__ = "0.4.8a6"
__hash__ = __rand_str()

from pygwalker.api.jupyter import walk, render, table
Expand Down
11 changes: 8 additions & 3 deletions pygwalker/data_parsers/database_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,10 @@ class Connector:
- engine_params: engine params, refer to sqlalchemy doc for params. example: {"pool_size": 10}
"""
engine_map = {}
JSON_TYPE_CODE_SET_MAP = {
"snowflake": {9, 10},
"mysql": {245}
}

def __init__(self, url: str, view_sql: str, engine_params: Optional[Dict[str, Any]] = None) -> "Connector":
_check_view_sql(view_sql)
Expand All @@ -59,6 +63,7 @@ def __init__(self, url: str, view_sql: str, engine_params: Optional[Dict[str, An
self.url = url
self.engine = self._get_engine(engine_params)
self.view_sql = view_sql
self._json_type_code_set = self.JSON_TYPE_CODE_SET_MAP.get(self.dialect_name, set())

def _get_engine(self, engine_params: Dict[str, Any]) -> Engine:
if self.url not in self.engine_map:
Expand All @@ -72,14 +77,14 @@ def query_datas(self, sql: str) -> List[Dict[str, Any]]:
field_type_map = {}
with self.engine.connect() as connection:
result = connection.execute(text(sql))
if self.dialect_name == "snowflake":
if self.dialect_name in self.JSON_TYPE_CODE_SET_MAP:
field_type_map = {
column_desc.name: column_desc.type_code
column_desc[0]: column_desc[1]
for column_desc in result.cursor.description
}
return [
{
key: json.loads(value) if field_type_map.get(key, -1) in {9, 10} else value
key: json.loads(value) if field_type_map.get(key, -1) in self._json_type_code_set else value
for key, value in item.items()
}
for item in result.mappings()
Expand Down
3 changes: 3 additions & 0 deletions pygwalker/utils/custom_sqlglot.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from sqlglot.dialects.duckdb import DuckDB as DuckdbDialect
from sqlglot.dialects.postgres import Postgres as PostgresDialect
from sqlglot.dialects.mysql import MySQL as MysqlDialect
from sqlglot import exp
from sqlglot.helper import seq_get

Expand All @@ -19,3 +20,5 @@ def _postgres_round_generator(e: exp.Round) -> str:
)

PostgresDialect.Generator.TRANSFORMS[exp.Round] = lambda _, e: _postgres_round_generator(e)

MysqlDialect.Generator.TRANSFORMS[exp.Array] = lambda self, e: self.func("JSON_ARRAY", *e.expressions)
Loading