Skip to content

Commit

Permalink
Fix complex numeric query binds, clean up raw API (#336)
Browse files Browse the repository at this point in the history
  • Loading branch information
genzgd committed Apr 3, 2024
1 parent 622924d commit 6408c2a
Show file tree
Hide file tree
Showing 9 changed files with 90 additions and 34 deletions.
3 changes: 2 additions & 1 deletion .github/workflows/on_push.yml
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ on:
- '*_test'
- '*_dev'
- '*_build'
- 'release_*'
- main
paths-ignore:
- 'VERSION'
Expand Down Expand Up @@ -63,11 +64,11 @@ jobs:
- '3.11'
- '3.12'
clickhouse-version:
- '23.3'
- '23.8'
- '23.12'
- '24.1'
- '24.2'
- '24.3'
- latest

name: Local Tests Py=${{ matrix.python-version }} CH=${{ matrix.clickhouse-version }}
Expand Down
9 changes: 9 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,15 @@ ClickHouse Connect has been included as an official Apache Superset database con
However, if you need compatibility with older versions of Superset, you may need clickhouse-connect
v0.5.25, which dynamically loads the EngineSpec from the clickhouse-connect project.

## 0.7.7, 2024-04-03
### Bug Fix
- Fixed client side binding for complex types containing floats or integers that was broken in version 0.7.5.
Closes https://github.com/ClickHouse/clickhouse-connect/issues/335.
### Improvement
- Added a `raw_stream` method to the Client that returns an io.IOBase. Use this instead of the `raw_query` method
with the (now removed) optional `stream` keyword boolean. Thanks to [Martijn Thé](https://github.com/martijnthe) for
the PR that highlighted the somewhat messy public API.

## 0.7.6, 2024-04-01
### Bug Fix
- Fixed issue with SQLAlchemy Point type. Closes https://github.com/ClickHouse/clickhouse-connect/issues/332.
Expand Down
2 changes: 1 addition & 1 deletion clickhouse_connect/__version__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
version = '0.7.6'
version = '0.7.7'
6 changes: 3 additions & 3 deletions clickhouse_connect/cc_sqlalchemy/datatypes/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from clickhouse_connect.datatypes.base import ClickHouseType, TypeDef, EMPTY_TYPE_DEF
from clickhouse_connect.datatypes.registry import parse_name, type_map
from clickhouse_connect.driver.query import format_query_value
from clickhouse_connect.driver.query import str_query_value

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -96,12 +96,12 @@ def _cached_literal_processor(*_):
method and should be able to ignore literal_processor definitions in the dialect, which are verbose and
confusing.
"""
return format_query_value
return str_query_value

def _compiler_dispatch(self, _visitor, **_):
"""
Override for the SqlAlchemy TypeEngine _compiler_dispatch method to sidestep unnecessary layers and complexity
when generating the type name. The underlying ClickHouseType generates the correct name
when generating the type name. The underlying ClickHouseType generates the correct name for the type
:return: Name generated by the underlying driver.
"""
return self.name
Expand Down
33 changes: 25 additions & 8 deletions clickhouse_connect/driver/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,8 +256,7 @@ def raw_query(self, query: str,
settings: Optional[Dict[str, Any]] = None,
fmt: str = None,
use_database: bool = True,
external_data: Optional[ExternalData] = None,
stream: bool = False) -> Union[bytes, io.IOBase]:
external_data: Optional[ExternalData] = None) -> bytes:
"""
Query method that simply returns the raw ClickHouse format bytes
:param query: Query statement/format string
Expand All @@ -270,6 +269,25 @@ def raw_query(self, query: str,
:return: bytes representing raw ClickHouse return value based on format
"""

@abstractmethod
def raw_stream(self, query: str,
parameters: Optional[Union[Sequence, Dict[str, Any]]] = None,
settings: Optional[Dict[str, Any]] = None,
fmt: str = None,
use_database: bool = True,
external_data: Optional[ExternalData] = None) -> io.IOBase:
"""
Query method that returns the result as an io.IOBase iterator
:param query: Query statement/format string
:param parameters: Optional dictionary used to format the query
:param settings: Optional dictionary of ClickHouse settings (key/string values)
:param fmt: ClickHouse output format
:param use_database: Send the database parameter to ClickHouse so the command will be executed in the client
database context.
:param external_data: External data to send with the query
:return: io.IOBase stream/iterator for the result
"""

# pylint: disable=duplicate-code,too-many-arguments,unused-argument
def query_np(self,
query: Optional[str] = None,
Expand Down Expand Up @@ -487,12 +505,11 @@ def query_arrow_stream(self,
:return: Generator that yields a PyArrow.Table per block representing the result set
"""
settings = self._update_arrow_settings(settings, use_strings)
return to_arrow_batches(self.raw_query(query,
parameters,
settings,
fmt='ArrowStream',
external_data=external_data,
stream=True))
return to_arrow_batches(self.raw_stream(query,
parameters,
settings,
fmt='ArrowStream',
external_data=external_data))

def _update_arrow_settings(self,
settings: Optional[Dict[str, Any]],
Expand Down
50 changes: 35 additions & 15 deletions clickhouse_connect/driver/httpclient.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import io
import json
import logging
import re
Expand Down Expand Up @@ -436,27 +437,36 @@ def _raw_request(self,
else:
self._error_handler(response)

def ping(self):
"""
See BaseClient doc_string for this method
"""
try:
response = self.http.request('GET', f'{self.url}/ping', timeout=3)
return 200 <= response.status < 300
except HTTPError:
logger.debug('ping failed', exc_info=True)
return False

def raw_query(self, query: str,
parameters: Optional[Union[Sequence, Dict[str, Any]]] = None,
settings: Optional[Dict[str, Any]] = None,
fmt: str = None,
use_database: bool = True,
external_data: Optional[ExternalData] = None,
stream: bool = False) -> Union[bytes, HTTPResponse]:
external_data: Optional[ExternalData] = None) -> bytes:
"""
See BaseClient doc_string for this method
"""
body, params, fields = self._prep_raw_query(query, parameters, settings, fmt, use_database, external_data)
return self._raw_request(body, params, fields=fields).data

def raw_stream(self, query: str,
parameters: Optional[Union[Sequence, Dict[str, Any]]] = None,
settings: Optional[Dict[str, Any]] = None,
fmt: str = None,
use_database: bool = True,
external_data: Optional[ExternalData] = None) -> io.IOBase:
"""
See BaseClient doc_string for this method
"""
body, params, fields = self._prep_raw_query(query, parameters, settings, fmt, use_database, external_data)
return self._raw_request(body, params, fields=fields, stream=True)

def _prep_raw_query(self, query: str,
parameters: Optional[Union[Sequence, Dict[str, Any]]],
settings: Optional[Dict[str, Any]],
fmt: str,
use_database: bool,
external_data: Optional[ExternalData]):
final_query, bind_params = bind_query(query, parameters, self.server_tz)
if fmt:
final_query += f'\n FORMAT {fmt}'
Expand All @@ -472,8 +482,18 @@ def raw_query(self, query: str,
else:
body = final_query
fields = None
response = self._raw_request(body, params, fields=fields, stream=stream)
return response if stream else response.data
return body, params, fields

def ping(self):
"""
See BaseClient doc_string for this method
"""
try:
response = self.http.request('GET', f'{self.url}/ping', timeout=3)
return 200 <= response.status < 300
except HTTPError:
logger.debug('ping failed', exc_info=True)
return False

def close(self):
if self._owns_pool_manager:
Expand Down
10 changes: 7 additions & 3 deletions clickhouse_connect/driver/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -399,13 +399,13 @@ def format_query_value(value: Any, server_tz: tzinfo = pytz.UTC):
if isinstance(value, date):
return f"'{value.isoformat()}'"
if isinstance(value, list):
return f"[{', '.join(format_query_value(x, server_tz) for x in value)}]"
return f"[{', '.join(str_query_value(x, server_tz) for x in value)}]"
if isinstance(value, tuple):
return f"({', '.join(format_query_value(x, server_tz) for x in value)})"
return f"({', '.join(str_query_value(x, server_tz) for x in value)})"
if isinstance(value, dict):
if common.get_setting('dict_parameter_format') == 'json':
return format_str(any_to_json(value).decode())
pairs = [format_query_value(k, server_tz) + ':' + format_query_value(v, server_tz)
pairs = [str_query_value(k, server_tz) + ':' + str_query_value(v, server_tz)
for k, v in value.items()]
return f"{{{', '.join(pairs)}}}"
if isinstance(value, Enum):
Expand All @@ -415,6 +415,10 @@ def format_query_value(value: Any, server_tz: tzinfo = pytz.UTC):
return value


def str_query_value(value: Any, server_tz: tzinfo = pytz.UTC):
return str(format_query_value(value, server_tz))


# pylint: disable=too-many-branches
def format_bind_value(value: Any, server_tz: tzinfo = pytz.UTC, top_level: bool = True):
"""
Expand Down
5 changes: 2 additions & 3 deletions clickhouse_connect/tools/testing.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from typing import Sequence, Optional, Union, Dict, Any

from clickhouse_connect.driver import Client
from clickhouse_connect.driver.query import format_query_value, quote_identifier
from clickhouse_connect.driver.query import quote_identifier, str_query_value


class TableContext:
Expand Down Expand Up @@ -44,8 +44,7 @@ def __enter__(self):
if self.settings:
create_cmd += ' SETTINGS '
for key, value in self.settings.items():

create_cmd += f'{key} = {format_query_value(value)}, '
create_cmd += f'{key} = {str_query_value(value)}, '
if create_cmd.endswith(', '):
create_cmd = create_cmd[:-2]
self.client.command(create_cmd)
Expand Down
6 changes: 6 additions & 0 deletions tests/integration_tests/test_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,12 @@ def test_params(test_client: Client, table_context: Callable):
result = test_client.query('SELECT {l:Array(DateTime)}', parameters={'l': dt_params}).first_row
assert dt_params == result[0]

num_array_params = [2.5, 5.3, 7.4]
result = test_client.query('SELECT {l:Array(Float64)}', parameters={'l': num_array_params}).first_row
assert num_array_params == result[0]
result = test_client.query('SELECT %(l)s', parameters={'l': num_array_params}).first_row
assert num_array_params == result[0]

tp_params = ('str1', 'str2')
result = test_client.query('SELECT %(tp)s', parameters={'tp': tp_params}).first_row
assert tp_params == result[0]
Expand Down

0 comments on commit 6408c2a

Please sign in to comment.