diff --git a/duckdb_engine/__init__.py b/duckdb_engine/__init__.py index 1b7204c7..1dfba79c 100644 --- a/duckdb_engine/__init__.py +++ b/duckdb_engine/__init__.py @@ -28,6 +28,7 @@ ) from sqlalchemy.dialects.postgresql.psycopg2 import PGDialect_psycopg2 from sqlalchemy.engine.default import DefaultDialect +from sqlalchemy.engine.interfaces import Dialect as RootDialect from sqlalchemy.engine.reflection import cache from sqlalchemy.engine.url import URL from sqlalchemy.exc import NoSuchTableError @@ -47,7 +48,7 @@ if TYPE_CHECKING: from sqlalchemy.base import Connection from sqlalchemy.engine.interfaces import _IndexDict - + from sqlalchemy.sql.type_api import _ResultProcessor register_extension_types() @@ -215,6 +216,16 @@ def quote_schema(self, schema: str, force: Any = None) -> str: return self.format_schema(schema) +class DuckDBNullType(sqltypes.NullType): + def result_processor( + self, dialect: RootDialect, coltype: sqltypes.TypeEngine + ) -> Optional["_ResultProcessor"]: + if coltype == "JSON": + return sqltypes.JSON().result_processor(dialect, coltype) + else: + return super().result_processor(dialect, coltype) + + class Dialect(PGDialect_psycopg2): name = "duckdb" driver = "duckdb_engine" @@ -247,6 +258,14 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: kwargs["use_native_hstore"] = False super().__init__(*args, **kwargs) + def type_descriptor(self, typeobj: Type[sqltypes.TypeEngine]) -> Any: # type: ignore[override] + res = super().type_descriptor(typeobj) + + if isinstance(res, sqltypes.NullType): + return DuckDBNullType() + + return res + def connect(self, *cargs: Any, **cparams: Any) -> "Connection": core_keys = get_core_config() preload_extensions = cparams.pop("preload_extensions", []) diff --git a/duckdb_engine/datatypes.py b/duckdb_engine/datatypes.py index b2f7cc1e..bec1e3c3 100644 --- a/duckdb_engine/datatypes.py +++ b/duckdb_engine/datatypes.py @@ -188,6 +188,7 @@ def __init__(self, fields: Dict[str, TV]): "timestamp_ms": sqltypes.TIMESTAMP, "timestamp_ns": sqltypes.TIMESTAMP, "enum": sqltypes.Enum, + "json": sqltypes.JSON, } diff --git a/duckdb_engine/tests/test_datatypes.py b/duckdb_engine/tests/test_datatypes.py index 2f9eebc7..950825d9 100644 --- a/duckdb_engine/tests/test_datatypes.py +++ b/duckdb_engine/tests/test_datatypes.py @@ -45,6 +45,18 @@ def test_unsigned_integer_type( assert session.query(table).one() +@mark.remote_data() +def test_raw_json(engine: Engine) -> None: + importorskip("duckdb", "0.9.3.dev4040") + + with engine.connect() as conn: + assert conn.execute(text("load json")) + + assert conn.execute(text("select {'Hello': 'world'}::JSON")).fetchone() == ( + {"Hello": "world"}, + ) + + def test_json(engine: Engine, session: Session) -> None: base = declarative_base()