Skip to content

Commit

Permalink
airbyte-lib: Escape column names (#34969)
Browse files Browse the repository at this point in the history
Co-authored-by: Aaron ("AJ") Steers <aj@airbyte.io>
  • Loading branch information
Joe Reuter and aaronsteers committed Feb 9, 2024
1 parent cb81cb4 commit fca2e66
Show file tree
Hide file tree
Showing 4 changed files with 40 additions and 12 deletions.
26 changes: 20 additions & 6 deletions airbyte-lib/airbyte_lib/caches/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,14 @@ def get_sql_engine(self) -> Engine:

return self._engine

def _init_connection_settings(self, connection: Connection) -> None:
"""This is called automatically whenever a new connection is created.
By default this is a no-op. Subclasses can use this to set connection settings, such as
timezone, case-sensitivity settings, and other session-level variables.
"""
pass

@contextmanager
def get_sql_connection(self) -> Generator[sqlalchemy.engine.Connection, None, None]:
"""A context manager which returns a new SQL connection for running queries.
Expand All @@ -206,10 +214,12 @@ def get_sql_connection(self) -> Generator[sqlalchemy.engine.Connection, None, No
"""
if self.use_singleton_connection and self._connection_to_reuse is not None:
connection = self._connection_to_reuse
self._init_connection_settings(connection)
yield connection

else:
with self.get_sql_engine().begin() as connection:
self._init_connection_settings(connection)
yield connection

if not self.use_singleton_connection:
Expand Down Expand Up @@ -312,6 +322,10 @@ def _ensure_schema_exists(
schema_name in found_schemas
), f"Schema {schema_name} was not created. Found: {found_schemas}"

def _quote_identifier(self, identifier: str) -> str:
"""Return the given identifier, quoted."""
return f'"{identifier}"'

@final
def _get_temp_table_name(
self,
Expand All @@ -327,7 +341,7 @@ def _fully_qualified(
table_name: str,
) -> str:
"""Return the fully qualified name of the given table."""
return f"{self.config.schema_name}.{table_name}"
return f"{self.config.schema_name}.{self._quote_identifier(table_name)}"

@final
def _create_table_for_loading(
Expand All @@ -339,7 +353,7 @@ def _create_table_for_loading(
"""Create a new table for loading data."""
temp_table_name = self._get_temp_table_name(stream_name, batch_id)
column_definition_str = ",\n ".join(
f"{column_name} {sql_type}"
f"{self._quote_identifier(column_name)} {sql_type}"
for column_name, sql_type in self._get_sql_column_definitions(stream_name).items()
)
self._create_table(temp_table_name, column_definition_str)
Expand Down Expand Up @@ -383,7 +397,7 @@ def _ensure_final_table_exists(
did_exist = self._table_exists(table_name)
if not did_exist and create_if_missing:
column_definition_str = ",\n ".join(
f"{column_name} {sql_type}"
f"{self._quote_identifier(column_name)} {sql_type}"
for column_name, sql_type in self._get_sql_column_definitions(
stream_name,
).items()
Expand Down Expand Up @@ -743,7 +757,7 @@ def _append_temp_table_to_final_table(
stream_name: str,
) -> None:
nl = "\n"
columns = self._get_sql_column_definitions(stream_name).keys()
columns = [self._quote_identifier(c) for c in self._get_sql_column_definitions(stream_name)]
self._execute_sql(
f"""
INSERT INTO {self._fully_qualified(final_table_name)} (
Expand Down Expand Up @@ -815,8 +829,8 @@ def _merge_temp_table_to_final_table(
Databases that do not support this syntax can override this method.
"""
nl = "\n"
columns = self._get_sql_column_definitions(stream_name).keys()
pk_columns = self._get_primary_keys(stream_name)
columns = {self._quote_identifier(c) for c in self._get_sql_column_definitions(stream_name)}
pk_columns = {self._quote_identifier(c) for c in self._get_primary_keys(stream_name)}
non_pk_columns = columns - pk_columns
join_clause = "{nl} AND ".join(f"tmp.{pk_col} = final.{pk_col}" for pk_col in pk_columns)
set_clause = "{nl} ".join(f"{col} = tmp.{col}" for col in non_pk_columns)
Expand Down
5 changes: 4 additions & 1 deletion airbyte-lib/airbyte_lib/caches/duckdb.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,10 @@ def _write_files_to_new_table(
stream_name=stream_name,
batch_id=batch_id,
)
columns_list = list(self._get_sql_column_definitions(stream_name).keys())
columns_list = [
self._quote_identifier(c)
for c in list(self._get_sql_column_definitions(stream_name).keys())
]
columns_list_str = indent("\n, ".join(columns_list), " ")
files_list = ", ".join([f"'{f!s}'" for f in files])
insert_statement = dedent(
Expand Down
18 changes: 14 additions & 4 deletions airbyte-lib/airbyte_lib/caches/snowflake.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,6 @@
# Copyright (c) 2023 Airbyte, Inc., all rights reserved.

"""A Snowflake implementation of the cache.
TODO: FIXME: Snowflake Cache doesn't work yet. It's a work in progress.
"""
"""A Snowflake implementation of the cache."""

from __future__ import annotations

Expand All @@ -20,6 +17,8 @@
if TYPE_CHECKING:
from pathlib import Path

from sqlalchemy.engine import Connection


class SnowflakeCacheConfig(SQLCacheConfigBase, ParquetWriterConfig):
"""Configuration for the Snowflake cache.
Expand Down Expand Up @@ -82,6 +81,17 @@ def _write_files_to_new_table(
"""
return super()._write_files_to_new_table(files, stream_name, batch_id)

@overrides
def _init_connection_settings(self, connection: Connection) -> None:
    """Set QUOTED_IDENTIFIERS_IGNORE_CASE to TRUE for the session.

    This is necessary because Snowflake otherwise treats quoted table and
    column references as case-sensitive.

    More info: https://docs.snowflake.com/en/sql-reference/identifiers-syntax

    Args:
        connection: The SQL connection whose session should be configured.
    """
    # Imported locally so this block does not depend on the module's
    # top-level imports.
    from sqlalchemy import text

    # Wrap the statement in text(): passing a plain string to
    # Connection.execute() is deprecated in SQLAlchemy 1.4 and no longer
    # accepted in 2.0.
    connection.execute(text("ALTER SESSION SET QUOTED_IDENTIFIERS_IGNORE_CASE = TRUE"))

@overrides
def get_telemetry_info(self) -> CacheTelemetryInfo:
    """Return the telemetry info object for this cache.

    The object is constructed from the string ``"snowflake"``.
    """
    telemetry_info = CacheTelemetryInfo("snowflake")
    return telemetry_info
3 changes: 2 additions & 1 deletion airbyte-lib/examples/run_pokeapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,4 +19,5 @@
)
source.check()

print(list(source.get_records("pokemon")))
# print(list(source.get_records("pokemon")))
source.read(cache=ab.new_local_cache("poke"))

0 comments on commit fca2e66

Please sign in to comment.