From 2b2863f5ca4da78495cc105e78b87c7e020e1a77 Mon Sep 17 00:00:00 2001 From: Elliana May Date: Wed, 14 Feb 2024 16:55:37 +0800 Subject: [PATCH] test: add ser/de json test --- duckdb_engine/tests/test_datatypes.py | 58 +++++++++++++++++++++++++-- 1 file changed, 55 insertions(+), 3 deletions(-) diff --git a/duckdb_engine/tests/test_datatypes.py b/duckdb_engine/tests/test_datatypes.py index 950825d9..71c27987 100644 --- a/duckdb_engine/tests/test_datatypes.py +++ b/duckdb_engine/tests/test_datatypes.py @@ -1,12 +1,24 @@ +import decimal +import json import warnings -from typing import Type +from typing import Any, Dict, Type from uuid import uuid4 import duckdb from pytest import importorskip, mark -from sqlalchemy import Column, Integer, MetaData, String, Table, inspect, text +from sqlalchemy import ( + Column, + Integer, + MetaData, + Sequence, + String, + Table, + inspect, + select, + text, +) from sqlalchemy.dialects.postgresql import UUID -from sqlalchemy.engine import Engine +from sqlalchemy.engine import Engine, create_engine from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.orm import Session from sqlalchemy.sql import sqltypes @@ -57,6 +69,46 @@ def test_raw_json(engine: Engine) -> None: ) +@mark.remote_data() +def test_custom_json_serializer() -> None: + def default(o: Any) -> Any: + if isinstance(o, decimal.Decimal): + return {"__tag": "decimal", "value": str(o)} + + def object_hook(pairs: Dict[str, Any]) -> Any: + if pairs.get("__tag", None) == "decimal": + return decimal.Decimal(pairs["value"]) + else: + return pairs + + engine = create_engine( + "duckdb://", + json_serializer=json.JSONEncoder(default=default).encode, + json_deserializer=json.JSONDecoder(object_hook=object_hook).decode, + ) + + Base = declarative_base() + + class Entry(Base): + __tablename__ = "test_json" + id = Column(Integer, Sequence("id_seq"), primary_key=True) + data = Column(JSON, nullable=False) + + Base.metadata.create_all(engine) + + with engine.connect() as conn: + session = Session(bind=conn) + + data = {"hello": decimal.Decimal("42")} + + session.add(Entry(data=data)) # type: ignore[call-arg] + session.commit() + + (res,) = session.execute(select(Entry)).one() + + assert res.data == data + + def test_json(engine: Engine, session: Session) -> None: base = declarative_base()