Incorrect Compile Code from DBT Power User #1154

Open
1 task
ashutran opened this issue May 22, 2024 · 3 comments
Assignees
Labels
bug Something isn't working sweep

Comments


ashutran commented May 22, 2024

Expected behavior

While using the Compiled dbt Preview feature, the result is not as expected:

  1. In my model, I'm using a macro for the incremental logic, which contains various if-else statements to determine the compiled code. However, it seems that part of the logic is being skipped, and I'm not sure why this is happening.

[screenshots: dim_incremental_merge_by_pk.sql open in Visual Studio Code]

  2. I compared the compiled code generated by DBT with the code generated by the plugin, and I noticed that several pieces of logic that should be present in the compiled code are missing.

[screenshot: srv_temp____temp__dim-appraisal_pps_t.sql open in Visual Studio]

  3. In my model, after implementing the incremental logic, I use another macro to apply the casting logic to ensure each column is cast to its appropriate type. However, this step is failing.


  4. While executing the query, it fails with the following error. My dbt project places all models under a folder called CDW; however, dbt deps downloads the project dependencies into a core folder, as specified in the packages.yml file.

{
  "code": -1,
  "message": "Database Error\n Syntax error: SELECT list must not be empty at [305:1]",
  "data": "Error: Database Error\n Syntax error: SELECT list must not be empty at [305:1]\n\tat DBTCoreProjectIntegration_1. (c:\\Users\\.vscode\\extensions\\innoverio.vscode-dbt-power-user-0.39.12\\dist\\extension.js:18398:127)\n\tat Generator.throw ()\n\tat rejected (c:\\Users\\.vscode\\extensions\\innoverio.vscode-dbt-power-user-0.39.12\\dist\\extension.js:25808:28)"
}

Actual behavior

After compiling the code

  1. It should read the macro logic and, based on that, return the correct compiled code.
  2. It should apply the casting logic defined in the model via the specific macro.
  3. The query should execute and return results.

Steps To Reproduce

Using poetry shell to activate the poetry virtual env
Using python version 3.11.6
dbt-core 1.6.1
dbt-bigquery 1.6.4

Create any model that uses a macro for its incremental logic, with an if-else block that checks whether the table exists and returns different SQL for each branch (see the comparison sketch below).
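
A minimal sketch of how the two outputs can be compared (a hypothetical helper; the model name and paths are placeholders, not taken from this project):

# Compile via the dbt CLI and read the artifact dbt writes under
# target/compiled, then diff it against the extension's compiled preview.
import subprocess
from pathlib import Path

def dbt_compiled_sql(project_dir: str, model: str) -> str:
    subprocess.run(["dbt", "compile", "--select", model], cwd=project_dir, check=True)
    matches = list(Path(project_dir, "target", "compiled").rglob(f"{model}.sql"))
    if not matches:
        raise FileNotFoundError(f"no compiled artifact for {model}")
    return matches[0].read_text()

print(dbt_compiled_sql(".", "dim_incremental_merge_by_pk"))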

Log output/Screenshots

No response

Operating System

Windows 10

dbt version

1.6.1

dbt Adapter

dbt-bigquery 1.6.4

dbt Power User version

v0.39.12

Are you willing to submit a PR?

  • Yes I am willing to submit a PR!
ashutran added the bug label May 22, 2024
Contributor

sweep-ai bot commented May 23, 2024

🚀 Here's the PR! #1159


Step 1: 🔎 Searching

Here are the code search results. I'm now analyzing these search results to write the PR.

Relevant files:

from decimal import Decimal
import dbt.adapters.factory
# This is critical because `get_adapter` is all over dbt-core
# as they expect a singleton adapter instance per plugin,
# so dbt-niceDatabase will have one adapter instance named niceDatabase.
# This makes sense in dbt-land where we have a single Project/Profile
# combination executed in process from start to finish or a single tenant RPC
# This doesn't fit our paradigm of one adapter per DbtProject in a multitenant server,
# so we create an adapter instance **independent** of the FACTORY cache
# and attach it directly to our RuntimeConfig which is passed through
# anywhere dbt-core needs config including in all `get_adapter` calls
dbt.adapters.factory.get_adapter = lambda config: config.adapter
import os
import threading
import uuid
import sys
import contextlib
from collections import UserDict
from collections.abc import Iterable
from datetime import date, datetime, time
from copy import copy
from functools import lru_cache, partial
from hashlib import md5
from typing import (
TYPE_CHECKING,
Any,
Callable,
Dict,
List,
Optional,
Tuple,
TypeVar,
Union,
)
import agate
import json
from dbt.adapters.factory import get_adapter_class_by_name
from dbt.config.runtime import RuntimeConfig
from dbt.flags import set_from_args
from dbt.parser.manifest import ManifestLoader, process_node
from dbt.parser.sql import SqlBlockParser, SqlMacroParser
from dbt.task.sql import SqlCompileRunner, SqlExecuteRunner
from dbt.tracking import disable_tracking
from dbt.version import __version__ as dbt_version
DBT_MAJOR_VER, DBT_MINOR_VER, DBT_PATCH_VER = (
int(v) if v.isnumeric() else v for v in dbt_version.split(".")
)
if DBT_MAJOR_VER >= 1 and DBT_MINOR_VER >= 8:
from dbt.contracts.graph.manifest import Manifest # type: ignore
from dbt.contracts.graph.nodes import ManifestNode, CompiledNode # type: ignore
from dbt.artifacts.resources.v1.components import ColumnInfo # type: ignore
from dbt.artifacts.resources.types import NodeType # type: ignore
from dbt_common.events.functions import fire_event # type: ignore
from dbt.artifacts.schemas.manifest import WritableManifest # type: ignore
elif DBT_MAJOR_VER >= 1 and DBT_MINOR_VER > 3:
from dbt.contracts.graph.nodes import ColumnInfo, ManifestNode, CompiledNode # type: ignore
from dbt.node_types import NodeType # type: ignore
from dbt.contracts.graph.manifest import WritableManifest # type: ignore
from dbt.events.functions import fire_event # type: ignore
else:
from dbt.contracts.graph.compiled import ManifestNode, CompiledNode # type: ignore
from dbt.contracts.graph.parsed import ColumnInfo # type: ignore
from dbt.node_types import NodeType # type: ignore
from dbt.events.functions import fire_event # type: ignore
if TYPE_CHECKING:
# These imports are only used for type checking
from dbt.adapters.base import BaseRelation # type: ignore
if DBT_MAJOR_VER >= 1 and DBT_MINOR_VER >= 8:
from dbt.adapters.contracts.connection import AdapterResponse
else:
from dbt.contracts.connection import AdapterResponse
Primitive = Union[bool, str, float, None]
PrimitiveDict = Dict[str, Primitive]
CACHE = {}
CACHE_VERSION = 1
SQL_CACHE_SIZE = 1024
MANIFEST_ARTIFACT = "manifest.json"
RAW_CODE = "raw_code" if DBT_MAJOR_VER >= 1 and DBT_MINOR_VER >= 3 else "raw_sql"
COMPILED_CODE = (
"compiled_code" if DBT_MAJOR_VER >= 1 and DBT_MINOR_VER >= 3 else "compiled_sql"
)
JINJA_CONTROL_SEQS = ["{{", "}}", "{%", "%}", "{#", "#}"]
T = TypeVar("T")
REQUIRE_RESOURCE_NAMES_WITHOUT_SPACES = "REQUIRE_RESOURCE_NAMES_WITHOUT_SPACES"
DBT_DEBUG = "DBT_DEBUG"
DBT_DEFER = "DBT_DEFER"
DBT_STATE = "DBT_STATE"
DBT_FAVOR_STATE = "DBT_FAVOR_STATE"
@contextlib.contextmanager
def add_path(path):
sys.path.append(path)
try:
yield
finally:
sys.path.remove(path)
def validate_sql(
sql: str,
dialect: str,
models: List[Dict],
):
try:
ALTIMATE_PACKAGE_PATH = os.path.join(
os.path.dirname(os.path.abspath(__file__)), "altimate_packages"
)
with add_path(ALTIMATE_PACKAGE_PATH):
from altimate.validate_sql import validate_sql_from_models
return validate_sql_from_models(sql, dialect, models)
except Exception as e:
raise Exception(str(e))
def to_dict(obj):
if isinstance(obj, agate.Table):
return {
"rows": [to_dict(row) for row in obj.rows],
"column_names": obj.column_names,
"column_types": list(map(lambda x: x.__class__.__name__, obj.column_types)),
}
if isinstance(obj, str):
return obj
if isinstance(obj, Decimal):
return float(obj)
if isinstance(obj, (datetime, date, time)):
return obj.isoformat()
elif isinstance(obj, dict):
return dict((key, to_dict(val)) for key, val in obj.items())
elif isinstance(obj, Iterable):
return [to_dict(val) for val in obj]
elif hasattr(obj, "__dict__"):
return to_dict(vars(obj))
elif hasattr(obj, "__slots__"):
return to_dict(
dict((name, getattr(obj, name)) for name in getattr(obj, "__slots__"))
)
return obj
def has_jinja(query: str) -> bool:
"""Utility to check for jinja prior to certain compilation procedures"""
return any(seq in query for seq in JINJA_CONTROL_SEQS)
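# e.g. has_jinja("select 1") -> False
#      has_jinja("select * from {{ ref('my_model') }}") -> True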
def memoize_get_rendered(function):
"""Custom memoization function for dbt-core jinja interface"""
def wrapper(
string: str,
ctx: Dict[str, Any],
node: "ManifestNode" = None,
capture_macros: bool = False,
native: bool = False,
):
v = md5(string.strip().encode("utf-8")).hexdigest()
v += "__" + str(CACHE_VERSION)
if capture_macros and node is not None:
if node.is_ephemeral:
return function(string, ctx, node, capture_macros, native)
v += "__" + node.unique_id
rv = CACHE.get(v)
if rv is not None:
return rv
else:
rv = function(string, ctx, node, capture_macros, native)
CACHE[v] = rv
return rv
return wrapper
def default_profiles_dir(project_dir):
if "DBT_PROFILES_DIR" in os.environ:
profiles_dir = os.path.expanduser(os.environ["DBT_PROFILES_DIR"])
if os.path.isabs(profiles_dir):
return os.path.normpath(profiles_dir)
return os.path.join(project_dir, profiles_dir)
project_profiles_file = os.path.normpath(os.path.join(project_dir, "profiles.yml"))
return (
project_dir
if os.path.exists(project_profiles_file)
else os.path.join(os.path.expanduser("~"), ".dbt")
)
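# Resolution order, illustratively:
#   DBT_PROFILES_DIR set             -> that directory (relative values resolved against project_dir)
#   <project_dir>/profiles.yml found -> project_dir
#   otherwise                        -> ~/.dbt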
def target_path(project_dir):
if "DBT_TARGET_PATH" in os.environ:
target_path = os.path.expanduser(os.environ["DBT_TARGET_PATH"])
if os.path.isabs(target_path):
return os.path.normpath(target_path)
return os.path.normpath(os.path.join(project_dir, target_path))
return None
def find_package_paths(project_directories):
def get_package_path(project_dir):
try:
project = DbtProject(
project_dir=project_dir,
profiles_dir=default_profiles_dir(project_dir),
target_path=target_path(project_dir),
)
project.init_config()
packages_path = project.config.packages_install_path
if os.path.isabs(packages_path):
return os.path.normpath(packages_path)
return os.path.normpath(os.path.join(project_dir, packages_path))
except Exception as e:
# We don't care about exceptions here, that is dealt with later when the project is loaded
pass
return list(map(get_package_path, project_directories))
# Performance hacks
# jinja.get_rendered = memoize_get_rendered(jinja.get_rendered)
disable_tracking()
fire_event = lambda e: None
class ConfigInterface:
"""This mimic dbt-core args based interface for dbt-core
class instantiation"""
def __init__(
self,
threads: Optional[int] = 1,
target: Optional[str] = None,
profiles_dir: Optional[str] = None,
project_dir: Optional[str] = None,
profile: Optional[str] = None,
target_path: Optional[str] = None,
defer: Optional[bool] = False,
state: Optional[str] = None,
favor_state: Optional[bool] = False,
):
self.threads = threads
self.target = target
self.profiles_dir = profiles_dir
self.project_dir = project_dir
self.dependencies = []
self.single_threaded = threads == 1
self.quiet = True
self.profile = profile
self.target_path = target_path
self.defer = defer
self.state = state
self.favor_state = favor_state
if DBT_MAJOR_VER >= 1 and DBT_MINOR_VER >= 8:
self.REQUIRE_RESOURCE_NAMES_WITHOUT_SPACES = os.environ.get(REQUIRE_RESOURCE_NAMES_WITHOUT_SPACES, True)
self.DEBUG = os.environ.get(DBT_DEBUG, False)
def __str__(self):
return f"ConfigInterface(threads={self.threads}, target={self.target}, profiles_dir={self.profiles_dir}, project_dir={self.project_dir}, profile={self.profile}, target_path={self.target_path})"
class ManifestProxy(UserDict):
"""Proxy for manifest dictionary (`flat_graph`), if we need mutation then we should
create a copy of the dict or interface with the dbt-core manifest object instead"""
def _readonly(self, *args, **kwargs):
raise RuntimeError("Cannot modify ManifestProxy")
__setitem__ = _readonly
__delitem__ = _readonly
pop = _readonly
popitem = _readonly
clear = _readonly
update = _readonly
setdefault = _readonly
class DbtAdapterExecutionResult:
"""Interface for execution results, this keeps us 1 layer removed from dbt interfaces which may change"""
def __init__(
self,
adapter_response: "AdapterResponse",
table: agate.Table,
raw_sql: str,
compiled_sql: str,
) -> None:
self.adapter_response = adapter_response
self.table = table
self.raw_sql = raw_sql
self.compiled_sql = compiled_sql
class DbtAdapterCompilationResult:
"""Interface for compilation results, this keeps us 1 layer removed from dbt interfaces which may change"""
def __init__(self, raw_sql: str, compiled_sql: str, node: "ManifestNode") -> None:
self.raw_sql = raw_sql
self.compiled_sql = compiled_sql
self.node = node
class DbtProject:
"""Container for a dbt project. The dbt attribute is the primary interface for
dbt-core. The adapter attribute is the primary interface for the dbt adapter"""
def __init__(
self,
target: Optional[str] = None,
profiles_dir: Optional[str] = None,
project_dir: Optional[str] = None,
threads: Optional[int] = 1,
profile: Optional[str] = None,
target_path: Optional[str] = None,
defer_to_prod: bool = False,
manifest_path: Optional[str] = None,
favor_state: bool = False,
):
self.args = ConfigInterface(
threads=threads,
target=target,
profiles_dir=profiles_dir,
project_dir=project_dir,
profile=profile,
target_path=target_path,
defer=defer_to_prod,
state=manifest_path,
favor_state=favor_state,
)
# Utilities
self._sql_parser: Optional[SqlBlockParser] = None
self._macro_parser: Optional[SqlMacroParser] = None
self._sql_runner: Optional[SqlExecuteRunner] = None
self._sql_compiler: Optional[SqlCompileRunner] = None
# Tracks internal state version
self._version: int = 1
self.mutex = threading.Lock()
self.defer_to_prod = defer_to_prod
self.defer_to_prod_manifest_path = manifest_path
self.favor_state = favor_state
def get_adapter(self):
"""This inits a new Adapter which is fundamentally different than
the singleton approach in the core lib"""
adapter_name = self.config.credentials.type
adapter_type = get_adapter_class_by_name(adapter_name)
if DBT_MAJOR_VER >= 1 and DBT_MINOR_VER >= 8:
from dbt.mp_context import get_mp_context
return adapter_type(self.config, get_mp_context())
return adapter_type(self.config)
def init_config(self):
if DBT_MAJOR_VER >= 1 and DBT_MINOR_VER >= 8:
from dbt_common.context import set_invocation_context
set_invocation_context(os.environ)
set_from_args(self.args, None)
else:
set_from_args(self.args, self.args)
self.config = RuntimeConfig.from_args(self.args)
if hasattr(self.config, "source_paths"):
self.config.model_paths = self.config.source_paths
def init_project(self):
try:
self.init_config()
self.adapter = self.get_adapter()
self.adapter.connections.set_connection_name()
if DBT_MAJOR_VER >= 1 and DBT_MINOR_VER >= 8:
from dbt.context.providers import generate_runtime_macro_context
self.adapter.set_macro_context_generator(generate_runtime_macro_context)
self.config.adapter = self.adapter
except Exception as e:
# reset project
self.config = None
self.dbt = None
raise Exception(str(e))
def parse_project(self) -> None:
try:
project_parser = ManifestLoader(
self.config,
self.config.load_dependencies(),
self.adapter.connections.set_query_header,
)
self.dbt = project_parser.load()
project_parser.save_macros_to_adapter(self.adapter)
self.dbt.build_flat_graph()
except Exception as e:
# reset manifest
self.dbt = None
raise Exception(str(e))
self._sql_parser = None
self._macro_parser = None
self._sql_compiler = None
self._sql_runner = None
def set_defer_config(
self, defer_to_prod: bool, manifest_path: str, favor_state: bool
) -> None:
if DBT_MAJOR_VER >= 1 and DBT_MINOR_VER >= 8:
self.args.defer = defer_to_prod
self.args.state = manifest_path
self.args.favor_state = favor_state
self.defer_to_prod = defer_to_prod
self.defer_to_prod_manifest_path = manifest_path
self.favor_state = favor_state
@classmethod
def from_args(cls, args: ConfigInterface) -> "DbtProject":
"""Instatiate the DbtProject directly from a ConfigInterface instance"""
return cls(
target=args.target,
profiles_dir=args.profiles_dir,
project_dir=args.project_dir,
threads=args.threads,
profile=args.profile,
target_path=args.target_path,
)
@property
def sql_parser(self) -> SqlBlockParser:
"""A dbt-core SQL parser capable of parsing and adding nodes to the manifest via `parse_remote` which will
also return the added node to the caller. Note that post-parsing this still typically requires calls to
`_process_nodes_for_ref` and `_process_sources_for_ref` from `dbt.parser.manifest`
"""
if self._sql_parser is None:
self._sql_parser = SqlBlockParser(self.config, self.dbt, self.config)
return self._sql_parser
@property
def macro_parser(self) -> SqlMacroParser:
"""A dbt-core macro parser"""
if self._macro_parser is None:
self._macro_parser = SqlMacroParser(self.config, self.dbt)
return self._macro_parser
@property
def sql_runner(self) -> SqlExecuteRunner:
"""A runner which is used internally by the `execute_sql` function of `dbt.lib`.
The runner's `node` attribute can be updated before calling `compile` or `compile_and_execute`.
"""
if self._sql_runner is None:
self._sql_runner = SqlExecuteRunner(
self.config, self.adapter, node=None, node_index=1, num_nodes=1
)
return self._sql_runner
@property
def sql_compiler(self) -> SqlCompileRunner:
"""A runner which is used internally by the `compile_sql` function of `dbt.lib`.
The runner's `node` attribute can be updated before calling `compile` or `compile_and_execute`.
"""
if self._sql_compiler is None:
self._sql_compiler = SqlCompileRunner(
self.config, self.adapter, node=None, node_index=1, num_nodes=1
)
return self._sql_compiler
@property
def project_name(self) -> str:
"""dbt project name"""
return self.config.project_name
@property
def project_root(self) -> str:
"""dbt project root"""
return self.config.project_root
@property
def manifest(self) -> ManifestProxy:
"""dbt manifest dict"""
return ManifestProxy(self.dbt.flat_graph)
def safe_parse_project(self) -> None:
self.clear_caches()
# reinit the project because config may change
# this operation is cheap anyway
self.init_project()
# doing this so that we can allow inits to fail when config is
# bad and restart after the user sets it up correctly
if hasattr(self, "config"):
_config_pointer = copy(self.config)
else:
_config_pointer = None
try:
self.parse_project()
self.write_manifest_artifact()
if self.defer_to_prod:
if DBT_MAJOR_VER >= 1 and DBT_MINOR_VER >= 8:
writable_manifest = WritableManifest.read_and_check_versions(self.defer_to_prod_manifest_path)
manifest = Manifest.from_writable_manifest(writable_manifest)
self.dbt.merge_from_artifact(
other=manifest,
)
else:
with open(self.defer_to_prod_manifest_path) as f:
manifest = WritableManifest.from_dict(json.load(f))
selected = set()
self.dbt.merge_from_artifact(
self.adapter,
other=manifest,
selected=selected,
favor_state=self.favor_state,
)
except Exception as e:
self.config = _config_pointer
raise Exception(str(e))
def write_manifest_artifact(self) -> None:
"""Write a manifest.json to disk"""
artifact_path = os.path.join(
self.config.project_root, self.config.target_path, MANIFEST_ARTIFACT
)
self.dbt.write(artifact_path)
def clear_caches(self) -> None:
"""Clear least recently used caches and reinstantiable container objects"""
self.get_ref_node.cache_clear()
self.get_source_node.cache_clear()
self.get_macro_function.cache_clear()
self.get_columns.cache_clear()
@lru_cache(maxsize=10)
def get_ref_node(self, target_model_name: str) -> "ManifestNode":
"""Get a `"ManifestNode"` from a dbt project model name"""
try:
if DBT_MAJOR_VER >= 1 and DBT_MINOR_VER >= 6:
return self.dbt.resolve_ref(
source_node=None,
target_model_name=target_model_name,
target_model_version=None,
target_model_package=None,
current_project=self.config.project_name,
node_package=self.config.project_name,
)
if DBT_MAJOR_VER == 1 and DBT_MINOR_VER >= 5:
return self.dbt.resolve_ref(
target_model_name=target_model_name,
target_model_version=None,
target_model_package=None,
current_project=self.config.project_name,
node_package=self.config.project_name,
)
return self.dbt.resolve_ref(
target_model_name=target_model_name,
target_model_package=None,
current_project=self.config.project_name,
node_package=self.config.project_name,
)
except Exception as e:
raise Exception(str(e))
@lru_cache(maxsize=10)
def get_source_node(
self, target_source_name: str, target_table_name: str
) -> "ManifestNode":
"""Get a `"ManifestNode"` from a dbt project source name and table name"""
try:
return self.dbt.resolve_source(
target_source_name=target_source_name,
target_table_name=target_table_name,
current_project=self.config.project_name,
node_package=self.config.project_name,
)
except Exception as e:
raise Exception(str(e))
def get_server_node(self, sql: str, node_name="name"):
"""Get a node for SQL execution against adapter"""
self._clear_node(node_name)
sql_node = self.sql_parser.parse_remote(sql, node_name)
process_node(self.config, self.dbt, sql_node)
return sql_node
@lru_cache(maxsize=100)
def get_macro_function(self, macro_name: str) -> Callable[[Dict[str, Any]], Any]:
"""Get macro as a function which takes a dict via argument named `kwargs`,
ie: `kwargs={"relation": ...}`
make_schema_fn = get_macro_function('make_schema')\n
make_schema_fn({'name': '__test_schema_1'})\n
make_schema_fn({'name': '__test_schema_2'})"""
if DBT_MAJOR_VER >= 1 and DBT_MINOR_VER >= 8:
return partial(
self.adapter.execute_macro, macro_name=macro_name
)
else:
return partial(
self.adapter.execute_macro, macro_name=macro_name, manifest=self.dbt
)
def adapter_execute(
self, sql: str, auto_begin: bool = True, fetch: bool = False
) -> Tuple["AdapterResponse", agate.Table]:
"""Wraps adapter.execute. Execute SQL against database"""
return self.adapter.execute(sql, auto_begin, fetch)
def execute_macro(
self,
macro: str,
kwargs: Optional[Dict[str, Any]] = None,
) -> Any:
"""Wraps adapter execute_macro. Execute a macro like a function."""
return self.get_macro_function(macro)(kwargs=kwargs)
def execute_sql(self, raw_sql: str) -> DbtAdapterExecutionResult:
"""Execute dbt SQL statement against database"""
with self.adapter.connection_named("master"):
# if no jinja chars then these are synonymous
compiled_sql = raw_sql
if has_jinja(raw_sql):
# jinja found, compile it
compilation_result = self._compile_sql(raw_sql)
compiled_sql = compilation_result.compiled_sql
return DbtAdapterExecutionResult(
*self.adapter_execute(compiled_sql, fetch=True),
raw_sql,
compiled_sql,
)
def execute_node(self, node: "ManifestNode") -> DbtAdapterExecutionResult:
"""Execute dbt SQL statement against database from a"ManifestNode"""
try:
raw_sql: str = getattr(node, RAW_CODE)
compiled_sql: Optional[str] = getattr(node, COMPILED_CODE, None)
if compiled_sql:
# node is compiled, execute the SQL
return self.execute_sql(compiled_sql)
# node not compiled
if has_jinja(raw_sql):
# node has jinja in its SQL, compile it
compiled_sql = self._compile_node(node).compiled_sql
# execute the SQL
return self.execute_sql(compiled_sql or raw_sql)
except Exception as e:
raise Exception(str(e))
def compile_sql(self, raw_sql: str) -> DbtAdapterCompilationResult:
try:
with self.adapter.connection_named("master"):
return self._compile_sql(raw_sql)
except Exception as e:
raise Exception(str(e))
def compile_node(
self, node: "ManifestNode"
) -> Optional[DbtAdapterCompilationResult]:
try:
with self.adapter.connection_named("master"):
return self._compile_node(node)
except Exception as e:
raise Exception(str(e))
def _compile_sql(self, raw_sql: str) -> DbtAdapterCompilationResult:
"""Creates a node with a `dbt.parser.sql` class. Compile generated node."""
try:
temp_node_id = str("t_" + uuid.uuid4().hex)
node = self._compile_node(self.get_server_node(raw_sql, temp_node_id))
self._clear_node(temp_node_id)
return node
except Exception as e:
raise Exception(str(e))
def _compile_node(
self, node: Union["ManifestNode", "CompiledNode"]
) -> Optional[DbtAdapterCompilationResult]:
"""Compiles existing node."""
try:
self.sql_compiler.node = copy(node)
if DBT_MAJOR_VER == 1 and DBT_MINOR_VER <= 3:
compiled_node = (
node
if isinstance(node, CompiledNode)
else self.sql_compiler.compile(self.dbt)
)
else:
# this is essentially a convenient wrapper to adapter.get_compiler
compiled_node = self.sql_compiler.compile(self.dbt)
return DbtAdapterCompilationResult(
getattr(compiled_node, RAW_CODE),
getattr(compiled_node, COMPILED_CODE),
compiled_node,
)
except Exception as e:
raise Exception(str(e))
def _clear_node(self, name="name"):
"""Removes the statically named node created by `execute_sql` and `compile_sql` in `dbt.lib`"""
if self.dbt is not None:
self.dbt.nodes.pop(
f"{NodeType.SqlOperation}.{self.project_name}.{name}", None
)
def get_relation(
self, database: Optional[str], schema: Optional[str], name: Optional[str]
) -> Optional["BaseRelation"]:
"""Wrapper for `adapter.get_relation`"""
return self.adapter.get_relation(database, schema, name)
def create_relation(
self, database: Optional[str], schema: Optional[str], name: Optional[str]
) -> "BaseRelation":
"""Wrapper for `adapter.Relation.create`"""
return self.adapter.Relation.create(database, schema, name)
def create_relation_from_node(self, node: "ManifestNode") -> "BaseRelation":
"""Wrapper for `adapter.Relation.create_from`"""
return self.adapter.Relation.create_from(self.config, node)
def get_columns_in_relation(self, relation: "BaseRelation") -> List[str]:
"""Wrapper for `adapter.get_columns_in_relation`"""
try:
with self.adapter.connection_named("master"):
return self.adapter.get_columns_in_relation(relation)
except Exception as e:
raise Exception(str(e))
@lru_cache(maxsize=5)
def get_columns(self, node: "ManifestNode") -> List["ColumnInfo"]:
"""Get a list of columns from a compiled node"""
columns = []
try:
columns.extend(
[
c.name
for c in self.get_columns_in_relation(
self.create_relation_from_node(node)
)
]
)
except Exception:
original_sql = str(getattr(node, RAW_CODE))
# TODO: account for `TOP` syntax
setattr(node, RAW_CODE, f"select * from ({original_sql}) limit 0")
result = self.execute_node(node)
setattr(node, RAW_CODE, original_sql)
delattr(node, COMPILED_CODE)
columns.extend(result.table.column_names)
return columns
def get_catalog(self) -> Dict[str, Any]:
"""Get catalog from adapter"""
catalog_table: agate.Table = agate.Table([])
catalog_data: List[PrimitiveDict] = []
exceptions: List[Exception] = []
try:
with self.adapter.connection_named("generate_catalog"):
catalog_table, exceptions = self.adapter.get_catalog(self.dbt)
if exceptions:
raise Exception(str(exceptions))
catalog_data = [
dict(
zip(catalog_table.column_names, map(dbt.utils._coerce_decimal, row))
)
for row in catalog_table
]
except Exception as e:
raise Exception(str(e))
return catalog_data
def get_or_create_relation(
self, database: str, schema: str, name: str
) -> Tuple["BaseRelation", bool]:
"""Get relation or create if not exists. Returns tuple of relation and
boolean result of whether it existed, i.e. (relation, did_exist)"""
ref = self.get_relation(database, schema, name)
return (
(ref, True)
if ref
else (self.create_relation(database, schema, name), False)
)
def create_schema(self, node: "ManifestNode"):
"""Create a schema in the database"""
return self.execute_macro(
"create_schema",
kwargs={"relation": self.create_relation_from_node(node)},
)
def materialize(
self, node: "ManifestNode", temporary: bool = True
) -> Tuple["AdapterResponse", None]:
"""Materialize a table in the database"""
return self.adapter_execute(
# Returns CTAS string so send to adapter.execute
self.execute_macro(
"create_table_as",
kwargs={
"sql": getattr(node, COMPILED_CODE),
"relation": self.create_relation_from_node(node),
"temporary": temporary,
},
),
auto_begin=True,
)
def get_dbt_version(self):
return [DBT_MAJOR_VER, DBT_MINOR_VER, DBT_PATCH_VER]
def validate_sql_dry_run(self, compiled_sql: str):
if DBT_MAJOR_VER < 1:
return None
if DBT_MINOR_VER < 6:
return None
try:
return self.adapter.validate_sql(compiled_sql)
except Exception as e:
# re-raise, matching the error handling used throughout this module
raise Exception(str(e))
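
To make the compilation path above concrete, a minimal usage sketch of DbtProject (the project directory and model name are hypothetical):

project = DbtProject(
    project_dir="./my_project",
    profiles_dir=default_profiles_dir("./my_project"),
)
project.init_project()   # builds the RuntimeConfig and adapter
project.parse_project()  # loads the manifest
node = project.get_ref_node("dim_incremental_merge_by_pk")
result = project.compile_node(node)
print(result.compiled_sql)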

import {
CancellationToken,
Diagnostic,
DiagnosticCollection,
Disposable,
languages,
Range,
RelativePattern,
Uri,
window,
workspace,
} from "vscode";
import {
extendErrorWithSupportLinks,
getFirstWorkspacePath,
getProjectRelativePath,
provideSingleton,
setupWatcherHandler,
} from "../utils";
import {
Catalog,
CompilationResult,
DBColumn,
DBTNode,
DBTCommand,
DBTCommandExecutionInfrastructure,
DBTDetection,
DBTProjectDetection,
DBTProjectIntegration,
ExecuteSQLResult,
PythonDBTCommandExecutionStrategy,
QueryExecution,
SourceNode,
Node,
ExecuteSQLError,
HealthcheckArgs,
} from "./dbtIntegration";
import { PythonEnvironment } from "../manifest/pythonEnvironment";
import { CommandProcessExecutionFactory } from "../commandProcessExecution";
import { PythonBridge, PythonException } from "python-bridge";
import * as path from "path";
import { DBTProject } from "../manifest/dbtProject";
import { existsSync, readFileSync } from "fs";
import { parse } from "yaml";
import { TelemetryService } from "../telemetry";
import {
AltimateRequest,
NotFoundError,
ValidateSqlParseErrorResponse,
} from "../altimate";
import { DBTProjectContainer } from "../manifest/dbtProjectContainer";
import { ManifestPathType } from "../constants";
import { DBTTerminal } from "./dbtTerminal";
import { ValidationProvider } from "../validation_provider";
import { DeferToProdService } from "../services/deferToProdService";
const DEFAULT_QUERY_TEMPLATE = "select * from ({query}) as query limit {limit}";
// TODO: we should really get these from manifest directly
interface ResolveReferenceNodeResult {
database: string;
schema: string;
alias: string;
}
interface ResolveReferenceSourceResult {
database: string;
schema: string;
alias: string;
resource_type: string;
identifier: string;
}
interface DeferConfig {
deferToProduction: boolean;
favorState: boolean;
manifestPathForDeferral: string;
manifestPathType?: ManifestPathType;
dbtCoreIntegrationId?: number;
}
type InsightType = "Modelling" | "Test" | "structure";
interface Insight {
name: string;
type: InsightType;
message: string;
recommendation: string;
reason_to_flag: string;
metadata: {
model?: string;
model_unique_id?: string;
model_type?: string;
convention?: string | null;
};
}
type Severity = "ERROR" | "WARNING";
interface ModelInsight {
insight: Insight;
severity: Severity;
unique_id: string;
package_name: string;
path: string;
original_file_path: string;
}
export interface ProjectHealthcheck {
model_insights: Record<string, ModelInsight[]>;
// package_insights: any;
}
@provideSingleton(DBTCoreDetection)
export class DBTCoreDetection implements DBTDetection {
constructor(
private pythonEnvironment: PythonEnvironment,
private commandProcessExecutionFactory: CommandProcessExecutionFactory,
) {}
async detectDBT(): Promise<boolean> {
try {
const checkDBTInstalledProcess =
this.commandProcessExecutionFactory.createCommandProcessExecution({
command: this.pythonEnvironment.pythonPath,
args: ["-c", "import dbt"],
cwd: getFirstWorkspacePath(),
envVars: this.pythonEnvironment.environmentVariables,
});
const { stderr } = await checkDBTInstalledProcess.complete();
if (stderr) {
throw new Error(stderr);
}
return true;
} catch (error) {
return false;
}
}
}
@provideSingleton(DBTCoreProjectDetection)
export class DBTCoreProjectDetection
implements DBTProjectDetection, Disposable
{
constructor(
private executionInfrastructure: DBTCommandExecutionInfrastructure,
private dbtTerminal: DBTTerminal,
) {}
private getPackageInstallPathFallback(
projectDirectory: Uri,
packageInstallPath: string,
): string {
const dbtProjectFile = path.join(
projectDirectory.fsPath,
"dbt_project.yml",
);
if (existsSync(dbtProjectFile)) {
const dbtProjectConfig: any = parse(readFileSync(dbtProjectFile, "utf8"));
const packagesInstallPath = dbtProjectConfig["packages-install-path"];
if (packagesInstallPath) {
if (path.isAbsolute(packagesInstallPath)) {
return packagesInstallPath;
} else {
return path.join(projectDirectory.fsPath, packagesInstallPath);
}
}
}
return packageInstallPath;
}
async discoverProjects(projectDirectories: Uri[]): Promise<Uri[]> {
let packagesInstallPaths = projectDirectories.map((projectDirectory) =>
path.join(projectDirectory.fsPath, "dbt_packages"),
);
let python: PythonBridge | undefined;
try {
python = this.executionInfrastructure.createPythonBridge(
getFirstWorkspacePath(),
);
await python.ex`from dbt_core_integration import *`;
const packagesInstallPathsFromPython = await python.lock<string[]>(
(python) =>
python`to_dict(find_package_paths(${projectDirectories.map(
(projectDirectory) => projectDirectory.fsPath,
)}))`,
);
packagesInstallPaths = packagesInstallPaths.map(
(packageInstallPath, index) => {
const packageInstallPathFromPython =
packagesInstallPathsFromPython[index];
if (packageInstallPathFromPython) {
return Uri.file(packageInstallPathFromPython).fsPath;
}
return packageInstallPath;
},
);
} catch (error) {
this.dbtTerminal.debug(
"dbtCoreIntegration:discoverProjects",
"An error occured while finding package paths: " + error,
);
// Fallback to reading yaml files
packagesInstallPaths = projectDirectories.map((projectDirectory, idx) =>
this.getPackageInstallPathFallback(
projectDirectory,
packagesInstallPaths[idx],
),
);
} finally {
if (python) {
this.executionInfrastructure.closePythonBridge(python);
}
}
const filteredProjectFiles = projectDirectories.filter((uri) => {
return !packagesInstallPaths.some((packageInstallPath) => {
return uri.fsPath.startsWith(packageInstallPath!);
});
});
if (filteredProjectFiles.length > 20) {
window.showWarningMessage(
`dbt Power User detected ${filteredProjectFiles.length} projects in your workspace; this will negatively affect performance.`,
);
}
return filteredProjectFiles;
}
async dispose() {}
}
@provideSingleton(DBTCoreProjectIntegration)
export class DBTCoreProjectIntegration
implements DBTProjectIntegration, Disposable
{
static DBT_PROFILES_FILE = "profiles.yml";
private profilesDir?: string;
private targetPath?: string;
private adapterType?: string;
private version?: number[];
private packagesInstallPath?: string;
private modelPaths?: string[];
private seedPaths?: string[];
private macroPaths?: string[];
private python: PythonBridge;
private disposables: Disposable[] = [];
private readonly rebuildManifestDiagnostics =
languages.createDiagnosticCollection("dbt");
private readonly pythonBridgeDiagnostics =
languages.createDiagnosticCollection("dbt");
private static QUEUE_ALL = "all";
constructor(
private executionInfrastructure: DBTCommandExecutionInfrastructure,
private pythonEnvironment: PythonEnvironment,
private telemetry: TelemetryService,
private pythonDBTCommandExecutionStrategy: PythonDBTCommandExecutionStrategy,
private dbtProjectContainer: DBTProjectContainer,
private altimateRequest: AltimateRequest,
private dbtTerminal: DBTTerminal,
private validationProvider: ValidationProvider,
private deferToProdService: DeferToProdService,
private projectRoot: Uri,
private projectConfigDiagnostics: DiagnosticCollection,
) {
this.dbtTerminal.debug(
"DBTCoreProjectIntegration",
`Registering dbt core project at ${this.projectRoot}`,
);
this.python = this.executionInfrastructure.createPythonBridge(
this.projectRoot.fsPath,
);
this.executionInfrastructure.createQueue(
DBTCoreProjectIntegration.QUEUE_ALL,
);
this.disposables.push(
this.pythonEnvironment.onPythonEnvironmentChanged(() => {
this.python = this.executionInfrastructure.createPythonBridge(
this.projectRoot.fsPath,
);
}),
this.rebuildManifestDiagnostics,
this.pythonBridgeDiagnostics,
);
}
// remove trailing slashes if they exist,
// otherwise the quote gets escaped when passing to python
private removeTrailingSlashes(input: string | undefined) {
return input?.replace(/\\+$/, "");
}
private getLimitQuery(queryTemplate: string, query: string, limit: number) {
return queryTemplate
.replace("{query}", () => query)
.replace("{limit}", () => limit.toString());
}
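// e.g. getLimitQuery(DEFAULT_QUERY_TEMPLATE, "select 1", 500)
//   -> "select * from (select 1) as query limit 500"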
private async getQuery(
query: string,
limit: number,
): Promise<{ queryTemplate: string; limitQuery: string }> {
try {
const dbtVersion = await this.version;
// dbt supports the limit macro from v1.5 onwards
if (dbtVersion && dbtVersion[0] >= 1 && dbtVersion[1] >= 5) {
const args = { sql: query, limit };
const queryTemplateFromMacro = await this.python?.lock(
(python) =>
python!`to_dict(project.execute_macro('get_limit_subquery_sql', ${args}))`,
);
this.dbtTerminal.debug(
"DBTCoreProjectIntegration",
"Using query template from macro",
queryTemplateFromMacro,
);
return {
queryTemplate: queryTemplateFromMacro,
limitQuery: queryTemplateFromMacro,
};
}
} catch (err) {
console.error("Error while getting get_limit_subquery_sql macro", err);
this.telemetry.sendTelemetryError(
"executeMacroGetLimitSubquerySQLError",
err,
{ adapter: this.adapterType || "unknown" },
);
}
const queryTemplate = workspace
.getConfiguration("dbt")
.get<string>("queryTemplate");
if (queryTemplate && queryTemplate !== DEFAULT_QUERY_TEMPLATE) {
console.log("Using user provided query template", queryTemplate);
const limitQuery = this.getLimitQuery(queryTemplate, query, limit);
return { queryTemplate, limitQuery };
}
return {
queryTemplate: DEFAULT_QUERY_TEMPLATE,
limitQuery: this.getLimitQuery(DEFAULT_QUERY_TEMPLATE, query, limit),
};
}
async refreshProjectConfig(): Promise<void> {
await this.createPythonDbtProject(this.python);
await this.python.ex`project.init_project()`;
this.targetPath = await this.findTargetPath();
this.modelPaths = await this.findModelPaths();
this.seedPaths = await this.findSeedPaths();
this.macroPaths = await this.findMacroPaths();
this.packagesInstallPath = await this.findPackagesInstallPath();
this.version = await this.findVersion();
this.adapterType = await this.findAdapterType();
}
async executeSQL(query: string, limit: number): Promise<QueryExecution> {
this.throwBridgeErrorIfAvailable();
const { limitQuery } = await this.getQuery(query, limit);
const queryThread = this.executionInfrastructure.createPythonBridge(
this.projectRoot.fsPath,
);
await this.createPythonDbtProject(queryThread);
await queryThread.ex`project.init_project()`;
return new QueryExecution(
async () => {
queryThread.kill(2);
},
async () => {
// compile query
const compiledQuery = await this.unsafeCompileQuery(limitQuery);
// execute query
let result: ExecuteSQLResult;
try {
result = await queryThread!.lock<ExecuteSQLResult>(
(python) => python`to_dict(project.execute_sql(${compiledQuery}))`,
);
const { manifestPathType } =
this.deferToProdService.getDeferConfigByProjectRoot(
this.projectRoot.fsPath,
);
if (manifestPathType === ManifestPathType.REMOTE) {
this.altimateRequest.sendDeferToProdEvent(ManifestPathType.REMOTE);
}
} catch (err) {
const message = `Error while executing sql: ${compiledQuery}`;
this.dbtTerminal.error("dbtCore:executeSQL", message, err);
if (err instanceof PythonException) {
throw new ExecuteSQLError(err.exception.message, compiledQuery!);
}
throw new ExecuteSQLError((err as Error).message, compiledQuery!);
}
return { ...result, compiled_stmt: compiledQuery };
},
);
}
private async createPythonDbtProject(bridge: PythonBridge) {
await bridge.ex`from dbt_core_integration import *`;
const targetPath = this.removeTrailingSlashes(
await bridge.lock(
(python) => python`target_path(${this.projectRoot.fsPath})`,
),
);
const { deferToProduction, manifestPath, favorState } =
await this.getDeferConfig();
await bridge.ex`project = DbtProject(project_dir=${this.projectRoot.fsPath}, profiles_dir=${this.profilesDir}, target_path=${targetPath}, defer_to_prod=${deferToProduction}, manifest_path=${manifestPath}, favor_state=${favorState}) if 'project' not in locals() else project`;
}
async initializeProject(): Promise<void> {
try {
await this.python
.ex`from dbt_core_integration import default_profiles_dir`;
await this.python.ex`from dbt_healthcheck import *`;
this.profilesDir = this.removeTrailingSlashes(
await this.python.lock(
(python) => python`default_profiles_dir(${this.projectRoot.fsPath})`,
),
);
if (this.profilesDir) {
const dbtProfileWatcher = workspace.createFileSystemWatcher(
new RelativePattern(
this.profilesDir,
DBTCoreProjectIntegration.DBT_PROFILES_FILE,
),
);
this.disposables.push(
dbtProfileWatcher,
// when the project config changes we need to re-init the dbt project
...setupWatcherHandler(dbtProfileWatcher, () =>
this.rebuildManifest(),
),
);
}
await this.createPythonDbtProject(this.python);
this.pythonBridgeDiagnostics.clear();
} catch (exc: any) {
if (exc instanceof PythonException) {
// python errors can be about anything, so we just associate the error with the project file
// with a fixed range
if (exc.message.includes("No module named 'dbt'")) {
// Let's not create an error for each project if dbt is not detected
// This is already displayed in the status bar
return;
}
let errorMessage =
"An error occurred while initializing the dbt project: " +
exc.exception.message;
if (exc.exception.type.module === "dbt.exceptions") {
// TODO: we could provide solutions per type of dbt exception
errorMessage =
"An error occurred while initializing the dbt project; dbt found the following issue: " +
exc.exception.message;
}
this.pythonBridgeDiagnostics.set(
Uri.joinPath(this.projectRoot, DBTProject.DBT_PROJECT_FILE),
[new Diagnostic(new Range(0, 0, 999, 999), errorMessage)],
);
this.telemetry.sendTelemetryError("pythonBridgeInitPythonError", exc);
} else {
window.showErrorMessage(
extendErrorWithSupportLinks(
"An unexpected error occured while initializing the dbt project at " +
this.projectRoot +
": " +
exc +
".",
),
);
this.telemetry.sendTelemetryError("pythonBridgeInitError", exc);
}
}
}
getTargetPath(): string | undefined {
return this.targetPath;
}
getModelPaths(): string[] | undefined {
return this.modelPaths;
}
getSeedPaths(): string[] | undefined {
return this.seedPaths;
}
getMacroPaths(): string[] | undefined {
return this.macroPaths;
}
getPackageInstallPath(): string | undefined {
return this.packagesInstallPath;
}
getAdapterType(): string | undefined {
return this.adapterType;
}
getVersion(): number[] | undefined {
return this.version;
}
async findAdapterType(): Promise<string | undefined> {
return this.python.lock<string>(
(python) => python`project.config.credentials.type`,
);
}
getPythonBridgeStatus(): boolean {
return this.python.connected;
}
getAllDiagnostic(): Diagnostic[] {
const projectURI = Uri.joinPath(
this.projectRoot,
DBTProject.DBT_PROJECT_FILE,
);
return [
...(this.pythonBridgeDiagnostics.get(projectURI) || []),
...(this.projectConfigDiagnostics.get(projectURI) || []),
...(this.rebuildManifestDiagnostics.get(projectURI) || []),
];
}
async rebuildManifest(): Promise<void> {
const errors = this.projectConfigDiagnostics.get(
Uri.joinPath(this.projectRoot, DBTProject.DBT_PROJECT_FILE),
);
if (errors !== undefined && errors.length > 0) {
// No point in trying to rebuild the manifest if the config is not valid
return;
}
try {
await this.python.lock(
(python) => python`to_dict(project.safe_parse_project())`,
);
this.rebuildManifestDiagnostics.clear();
} catch (exc) {
if (exc instanceof PythonException) {
// dbt errors can be about anything, so we just associate the error with the project file
// with a fixed range
this.rebuildManifestDiagnostics.set(
Uri.joinPath(this.projectRoot, DBTProject.DBT_PROJECT_FILE),
[
new Diagnostic(
new Range(0, 0, 999, 999),
"There is a problem in your dbt project. Compilation failed: " +
exc.exception.message,
),
],
);
this.telemetry.sendTelemetryEvent(
"pythonBridgeCannotParseProjectUserError",
{
error: exc.exception.message,
adapter: this.getAdapterType() || "unknown", // TODO: this should be moved to dbtProject
},
);
return;
}
// if we get here, it is not a dbt error but an extension error.
this.telemetry.sendTelemetryError(
"pythonBridgeCannotParseProjectUnknownError",
exc,
{
adapter: this.adapterType || "unknown", // TODO: this should be moved to dbtProject
},
);
window.showErrorMessage(
extendErrorWithSupportLinks(
"An error occured while rebuilding the dbt manifest: " + exc + ".",
),
);
}
}
async runModel(command: DBTCommand) {
this.addCommandToQueue(
await this.addDeferParams(this.dbtCoreCommand(command)),
);
}
async buildModel(command: DBTCommand) {
this.addCommandToQueue(
await this.addDeferParams(this.dbtCoreCommand(command)),
);
}
async buildProject(command: DBTCommand) {
this.addCommandToQueue(
await this.addDeferParams(this.dbtCoreCommand(command)),
);
}
async runTest(command: DBTCommand) {
this.addCommandToQueue(
await this.addDeferParams(this.dbtCoreCommand(command)),
);
}
async runModelTest(command: DBTCommand) {
this.addCommandToQueue(
await this.addDeferParams(this.dbtCoreCommand(command)),
);
}
async compileModel(command: DBTCommand) {
this.addCommandToQueue(
await this.addDeferParams(this.dbtCoreCommand(command)),
);
}
async generateDocs(command: DBTCommand) {
this.addCommandToQueue(this.dbtCoreCommand(command));
}
async executeCommandImmediately(command: DBTCommand) {
return await this.dbtCoreCommand(command).execute();
}
async deps(command: DBTCommand) {
const { stdout, stderr } = await this.dbtCoreCommand(command).execute();
if (stderr) {
throw new Error(stderr);
}
return stdout;
}
async debug(command: DBTCommand) {
const { stdout, stderr } = await this.dbtCoreCommand(command).execute();
if (stderr) {
throw new Error(stderr);
}
return stdout;
}
private addCommandToQueue(command: DBTCommand) {
const isInstalled =
this.dbtProjectContainer.showErrorIfDbtOrPythonNotInstalled();
if (!isInstalled) {
return;
}
this.executionInfrastructure.addCommandToQueue(
DBTCoreProjectIntegration.QUEUE_ALL,
command,
);
}
private async getDeferManifestPath(
manifestPathType: ManifestPathType | undefined,
manifestPathForDeferral: string,
dbtCoreIntegrationId: number | undefined,
): Promise<string> {
if (!manifestPathType) {
const configNotPresent = new Error(
"Please configure defer to production functionality by specifying manifest path in Actions panel before using it.",
);
throw configNotPresent;
}
if (manifestPathType === ManifestPathType.LOCAL) {
if (!manifestPathForDeferral) {
const configNotPresent = new Error(
"manifestPathForDeferral config is not present, use the actions panel to set the Defer to production configuration.",
);
this.dbtTerminal.error(
"manifestPathForDeferral",
"manifestPathForDeferral is not present",
configNotPresent,
);
throw configNotPresent;
}
return manifestPathForDeferral;
}
if (manifestPathType === ManifestPathType.REMOTE) {
try {
this.validationProvider.throwIfNotAuthenticated();
} catch (err) {
throw new Error(
"Defer to production is currently enabled with 'DataPilot dbt integration' mode. It requires a valid Altimate AI API key and instance name in the settings. In order to run dbt commands, please either switch to Local Path mode or disable the feature or add an API key / instance name.",
);
}
this.dbtTerminal.debug(
"remoteManifest",
`fetching artifact url for dbtCoreIntegrationId: ${dbtCoreIntegrationId}`,
);
try {
const response = await this.altimateRequest.fetchArtifactUrl(
"manifest",
dbtCoreIntegrationId!,
);
const manifestPath = await this.altimateRequest.downloadFileLocally(
response.url,
this.projectRoot,
);
console.log(`Set remote manifest path: ${manifestPath}`);
return manifestPath;
} catch (error) {
if (error instanceof NotFoundError) {
const manifestNotFoundError = new Error(
"Unable to download remote manifest file. Did you upload your manifest using the Altimate DataPilot CLI?",
);
this.dbtTerminal.error(
"remoteManifestError",
"Unable to download remote manifest file.",
manifestNotFoundError,
);
throw manifestNotFoundError;
}
throw error;
}
}
throw new Error(`Invalid manifestPathType: ${manifestPathType}`);
}
private async getDeferParams(): Promise<string[]> {
const deferConfig = this.deferToProdService.getDeferConfigByProjectRoot(
this.projectRoot.fsPath,
);
const {
deferToProduction,
manifestPathForDeferral,
favorState,
manifestPathType,
dbtCoreIntegrationId,
} = deferConfig;
if (!deferToProduction) {
this.dbtTerminal.debug("deferToProd", "defer to prod not enabled");
return [];
}
const manifestPath = await this.getDeferManifestPath(
manifestPathType,
manifestPathForDeferral,
dbtCoreIntegrationId,
);
const args = ["--defer", "--state", manifestPath];
if (favorState) {
args.push("--favor-state");
}
this.dbtTerminal.debug(
"deferToProd",
`executing dbt command with defer params ${manifestPathType} mode`,
true,
args,
);
if (manifestPathType === ManifestPathType.REMOTE) {
this.altimateRequest.sendDeferToProdEvent(ManifestPathType.REMOTE);
}
return args;
}
private async addDeferParams(command: DBTCommand) {
const deferParams = await this.getDeferParams();
deferParams.forEach((param) => command.addArgument(param));
return command;
}
private dbtCoreCommand(command: DBTCommand) {
command.addArgument("--project-dir");
command.addArgument(this.projectRoot.fsPath);
if (this.profilesDir) {
command.addArgument("--profiles-dir");
command.addArgument(this.profilesDir);
}
command.setExecutionStrategy(this.pythonDBTCommandExecutionStrategy);
return command;
}
// internal commands
async unsafeCompileNode(modelName: string): Promise<string> {
this.throwBridgeErrorIfAvailable();
const output = await this.python?.lock<CompilationResult>(
(python) =>
python!`to_dict(project.compile_node(project.get_ref_node(${modelName})))`,
);
return output.compiled_sql;
}
async unsafeCompileQuery(query: string): Promise<string> {
this.throwBridgeErrorIfAvailable();
const output = await this.python?.lock<CompilationResult>(
(python) => python!`to_dict(project.compile_sql(${query}))`,
);
return output.compiled_sql;
}
async validateSql(query: string, dialect: string, models: any) {
this.throwBridgeErrorIfAvailable();
const result = await this.python?.lock<ValidateSqlParseErrorResponse>(
(python) =>
python!`to_dict(validate_sql(${query}, ${dialect}, ${models}))`,
);
return result;
}
async validateSQLDryRun(query: string) {
this.throwBridgeErrorIfAvailable();
const result = await this.python?.lock<{ bytes_processed: string }>(
(python) => python!`to_dict(project.validate_sql_dry_run(${query}))`,
);
return result;
}
async getColumnsOfModel(modelName: string) {
this.throwBridgeErrorIfAvailable();
// Get database and schema
const node = (await this.python?.lock(
(python) => python!`to_dict(project.get_ref_node(${modelName}))`,
)) as ResolveReferenceNodeResult;
// Get columns
if (!node) {
return [];
}
// TODO: fix this type
return this.getColumsOfRelation(
node.database,
node.schema,
node.alias || modelName,
);
}
async getColumnsOfSource(sourceName: string, tableName: string) {
this.throwBridgeErrorIfAvailable();
// Get database and schema
const node = (await this.python?.lock(
(python) =>
python!`to_dict(project.get_source_node(${sourceName}, ${tableName}))`,
)) as ResolveReferenceSourceResult;
// Get columns
if (!node) {
return [];
}
return this.getColumsOfRelation(
node.database,
node.schema,
node.identifier,
);
}
private async getColumsOfRelation(
database: string | undefined,
schema: string | undefined,
objectName: string,
): Promise<DBColumn[]> {
this.throwBridgeErrorIfAvailable();
return this.python?.lock<DBColumn[]>(
(python) =>
python!`to_dict(project.get_columns_in_relation(project.create_relation(${database}, ${schema}, ${objectName})))`,
);
}
async getBulkSchema(
nodes: DBTNode[],
cancellationToken: CancellationToken,
): Promise<Record<string, DBColumn[]>> {
const result: Record<string, DBColumn[]> = {};
for (const n of nodes) {
if (cancellationToken.isCancellationRequested) {
break;
}
if (n.resource_type === DBTProject.RESOURCE_TYPE_SOURCE) {
const source = n as SourceNode;
result[n.unique_id] = await this.getColumnsOfSource(
source.name,
source.table,
);
} else {
const model = n as Node;
result[n.unique_id] = await this.getColumnsOfModel(model.name);
}
}
return result;
}
async getCatalog(): Promise<Catalog> {
this.throwBridgeErrorIfAvailable();
return await this.python?.lock<Catalog>(
(python) => python!`to_dict(project.get_catalog())`,
);
}
// get dbt config
private async findModelPaths(): Promise<string[]> {
return (
await this.python.lock<string[]>(
(python) => python`to_dict(project.config.model_paths)`,
)
).map((modelPath: string) => {
if (!path.isAbsolute(modelPath)) {
return path.join(this.projectRoot.fsPath, modelPath);
}
return modelPath;
});
}
private async findSeedPaths(): Promise<string[]> {
return (
await this.python.lock<string[]>(
(python) => python`to_dict(project.config.seed_paths)`,
)
).map((seedPath: string) => {
if (!path.isAbsolute(seedPath)) {
return path.join(this.projectRoot.fsPath, seedPath);
}
return seedPath;
});
}
getDebounceForRebuildManifest() {
return 2000;
}
private async findMacroPaths(): Promise<string[]> {
return (
await this.python.lock<string[]>(
(python) => python`to_dict(project.config.macro_paths)`,
)
).map((macroPath: string) => {
if (!path.isAbsolute(macroPath)) {
return path.join(this.projectRoot.fsPath, macroPath);
}
return macroPath;
});
}
private async findTargetPath(): Promise<string> {
let targetPath = await this.python.lock(
(python) => python`to_dict(project.config.target_path)`,
);
if (!path.isAbsolute(targetPath)) {
targetPath = path.join(this.projectRoot.fsPath, targetPath);
}
return targetPath;
}
private async findPackagesInstallPath(): Promise<string> {
let packageInstallPath = await this.python.lock(
(python) => python`to_dict(project.config.packages_install_path)`,
);
if (!path.isAbsolute(packageInstallPath)) {
packageInstallPath = path.join(
this.projectRoot.fsPath,
packageInstallPath,
);
}
return packageInstallPath;
}
private async findVersion(): Promise<number[]> {
return this.python?.lock<number[]>(
(python) => python!`to_dict(project.get_dbt_version())`,
);
}
private throwBridgeErrorIfAvailable() {
const allDiagnostics: DiagnosticCollection[] = [
this.pythonBridgeDiagnostics,
this.projectConfigDiagnostics,
this.rebuildManifestDiagnostics,
];
for (const diagnosticCollection of allDiagnostics) {
for (const [_, diagnostics] of diagnosticCollection) {
if (diagnostics.length > 0) {
const firstError = diagnostics[0];
throw new Error(firstError.message);
}
}
}
}
findPackageVersion(packageName: string) {
if (!this.packagesInstallPath) {
throw new Error("Missing packages install path");
}
if (!packageName) {
throw new Error("Invalid package name");
}
const dbtProjectYmlFilePath = path.join(
this.packagesInstallPath,
packageName,
"dbt_project.yml",
);
if (!existsSync(dbtProjectYmlFilePath)) {
throw new Error("Package not installed");
}
const fileContents = readFileSync(dbtProjectYmlFilePath, {
encoding: "utf-8",
});
if (!fileContents) {
throw new Error(`${packageName} has empty dbt_project.yml`);
}
const parsedConfig = parse(fileContents, {
strict: false,
uniqueKeys: false,
maxAliasCount: -1,
});
if (!parsedConfig?.version) {
throw new Error(`Missing version in ${dbtProjectYmlFilePath}`);
}
return parsedConfig.version;
}
async dispose() {
try {
await this.executionInfrastructure.closePythonBridge(this.python);
} catch (error) {} // We don't care about errors here.
this.rebuildManifestDiagnostics.clear();
this.pythonBridgeDiagnostics.clear();
while (this.disposables.length) {
const x = this.disposables.pop();
if (x) {
x.dispose();
}
}
}
async performDatapilotHealthcheck({
manifestPath,
catalogPath,
config,
configPath,
}: HealthcheckArgs): Promise<ProjectHealthcheck> {
this.throwBridgeErrorIfAvailable();
const healthCheckThread = this.executionInfrastructure.createPythonBridge(
this.projectRoot.fsPath,
);
await this.createPythonDbtProject(healthCheckThread);
await healthCheckThread.ex`from dbt_healthcheck import *`;
const result = await healthCheckThread.lock<ProjectHealthcheck>(
(python) =>
python!`to_dict(project_healthcheck(${manifestPath}, ${catalogPath}, ${configPath}, ${config}))`,
);
return result;
}
private async getDeferConfig() {
try {
const root = getProjectRelativePath(this.projectRoot);
const currentConfig: Record<string, DeferConfig> =
this.deferToProdService.getDeferConfigByWorkspace();
const {
deferToProduction,
manifestPathForDeferral,
favorState,
manifestPathType,
dbtCoreIntegrationId,
} = currentConfig[root];
const manifestFolder = await this.getDeferManifestPath(
manifestPathType,
manifestPathForDeferral,
dbtCoreIntegrationId,
);
const manifestPath = path.join(manifestFolder, DBTProject.MANIFEST_FILE);
return { deferToProduction, manifestPath, favorState };
} catch (error) {
this.dbtTerminal.debug(
"dbtCoreIntegration:getDeferConfig",
"An error occured while getting defer config: " +
(error as Error).message,
);
}
return { deferToProduction: false, manifestPath: null, favorState: false };
}
async applyDeferConfig(): Promise<void> {
const { deferToProduction, manifestPath, favorState } =
await this.getDeferConfig();
await this.python?.lock<void>(
(python) =>
python!`project.set_defer_config(${deferToProduction}, ${manifestPath}, ${favorState})`,
);
await this.rebuildManifest();
}
throwDiagnosticsErrorIfAvailable(): void {
this.throwBridgeErrorIfAvailable();
}

import re
import sqlglot
from sqlglot.executor import execute
from sqlglot.expressions import Table
from sqlglot.optimizer import traverse_scope
from sqlglot.optimizer.qualify import qualify
ADAPTER_MAPPING = {
"bigquery": "bigquery",
"clickhouse": "clickhouse",
"databricks": "databricks",
"duckdb": "duckdb",
"hive": "hive",
"mysql": "mysql",
"oracle": "oracle",
"postgres": "postgres",
"redshift": "redshift",
"snowflake": "snowflake",
"spark": "spark",
"starrocks": "starrocks",
"teradata": "teradata",
"trino": "trino",
"synapse": "tsql",
"sqlserver": "tsql",
"doris": "doris",
}
MULTIPLE_OCCURENCES_STR = "Unable to highlight the exact location of '{invalid_entity}' in the SQL code due to multiple occurrences."
MAPPING_FAILED_STR = "Unable to highlight the exact location of '{invalid_entity}' in the SQL code."
def extract_column_name(text):
# List of regex patterns
regex_patterns = [
r"Column '\"(\w+)\"' could not be resolved",
r"Unknown column: (\w+)",
r"Column '(\w+)' could not be resolved",
r"Unknown output column: (\w+)",
r"Cannot automatically join: (\w+)",
]
# Iterate over each regex pattern
for regex in regex_patterns:
matches = re.findall(regex, text)
if matches:
return matches[0]
return None
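# Illustrative example: extract_column_name("Unknown column: customer_id")
# returns "customer_id"; messages matching no pattern return None.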
def find_single_occurrence_indices(main_string, substring):
# Convert both strings to lowercase for case-insensitive comparison
main_string = main_string.lower()
substring = substring.lower() if substring else ""
if not substring:
return None, None, 0  # preserve the three-tuple shape returned below
num_occurrences = main_string.count(substring)
# Check if the substring occurs only once in the main string
if num_occurrences == 1:
start_index = main_string.find(substring)
return start_index, start_index + len(substring), num_occurrences
# Otherwise return None indices along with the occurrence count
return None, None, num_occurrences
def map_adapter_to_dialect(adapter: str):
return ADAPTER_MAPPING.get(adapter, adapter)
def get_str_position(str, row, col):
"""
Convert a 1-indexed (row, col) grid position into an absolute character index in the string.
"""
lines = str.split("\n")
position = 0
for i in range(row - 1):
position += len(lines[i]) + 1
position += col
return position
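# Worked example (illustrative): get_str_position("ab\ncd", 2, 1) == 4,
# i.e. len("ab") + 1 for the newline, plus col.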
def get_line_and_column_from_position(text, start_index):
"""
Finds the grid position (row and column) in a multiline string given a Python start index.
Rows and columns are 1-indexed.
:param text: Multiline string.
:param start_index: Python start index (0-indexed).
:return: Tuple of (row, column).
"""
row = 1  # rows are 1-indexed, per the docstring and get_str_position
current_length = 0
# Split the text into lines
lines = text.split("\n")
for line in lines:
# Check if the start_index is in the current line
if current_length + len(line) >= start_index:
# Column is the difference between start_index and the length of processed characters
column = start_index - current_length + 1
return row, column
# Update the row and current length for the next iteration
row += 1
current_length += len(line) + 1 # +1 for the newline character
return None, None
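# Worked example (illustrative): with 1-indexed rows, in "ab\ncd" the
# start_index 4 (the "d") corresponds to row 2, column 2.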
def _build_message(sql: str, error: dict):
len_highlight = len(error.get("highlight", ""))
len_prefix = len(error.get("start_context", ""))
if error.get("line") and error.get("col"):
end_position = get_str_position(sql, error["line"], error["col"])
start_position = end_position - len_highlight - len_prefix
row, col = get_line_and_column_from_position(sql, start_position)
return {
"description": "Failed to parse the sql query",
"start_position": [row, col],
"end_position": [error["line"], error["col"]],
}
return {"description": "Failed to parse the sql query"}
def sql_parse_errors(sql: str, dialect: str):
errors = []
try:
sqlglot.transpile(sql, read=dialect)
ast = sqlglot.parse_one(sql, read=dialect)
if isinstance(ast, sqlglot.exp.Alias):
return [
{
"description": "Failed to parse the sql query.",
}
]
except sqlglot.errors.ParseError as e:
for error in e.errors:
errors.append(_build_message(sql, error))
return errors
def get_start_and_end_position(sql: str, invalid_string: str):
start, end, num_occurences = find_single_occurrence_indices(sql, invalid_string)
if start is not None and end is not None:  # start may legitimately be 0
return (
list(get_line_and_column_from_position(sql, start)),
list(get_line_and_column_from_position(sql, end)),
num_occurences,
)
return None, None, num_occurences
def form_error(
error: str, invalid_entity: str, start_position, end_position, num_occurences
):
if num_occurences > 1:
error = (
f"{error}\n {MULTIPLE_OCCURENCES_STR.format(invalid_entity=invalid_entity)}"
)
return {
"description": error,
}
if not start_position or not end_position:
error = (
f"{error}\n {MAPPING_FAILED_STR.format(invalid_entity=invalid_entity)}"
if invalid_entity
else error
)
return {
"description": error,
}
return {
"description": error,
"start_position": start_position,
"end_position": end_position,
}
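# Illustrative example: form_error("Unknown column: x.", "x", [1, 8], [1, 9], 1)
# -> {"description": "Unknown column: x.", "start_position": [1, 8], "end_position": [1, 9]}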
def validate_tables_and_columns(
sql: str,
dialect: str,
schemas: dict,
):
try:
parsed_sql = sqlglot.parse_one(sql, read=dialect)
qualify(parsed_sql, dialect=dialect, schema=schemas)
except sqlglot.errors.OptimizeError as e:
error = str(e)
if "sqlglot" in error:
error = "Failed to validate the query."
invalid_entity = extract_column_name(error)
if not invalid_entity:
return [
{
"description": error,
}
]
start_position, end_position, num_occurences = get_start_and_end_position(
sql, invalid_entity
)
error = error if error[-1] == "." else error + "."
return [
form_error(
error, invalid_entity, start_position, end_position, num_occurences
)
]
return None
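# Minimal sketch of a call (the schema shape {db: {schema: {table: {column: type}}}}
# is assumed from the nested loops elsewhere in this module):
# validate_tables_and_columns(
#     "select missing_col from db.sch.t", "bigquery",
#     {"db": {"sch": {"t": {"id": "INT64"}}}},
# )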
def sql_execute_errors(
sql: str,
dialect: str,
schemas: dict,
):
tables = {}
for db in schemas:
if db not in tables:
tables[db] = {}
for schema in schemas[db]:
if schema not in tables[db]:
tables[db][schema] = {}
for table in schemas[db][schema]:
tables[db][schema][table] = []
try:
execute(
sql=sql,
read=dialect,
schema=schemas,
tables=tables,
)
except sqlglot.errors.ExecuteError as e:
return [
{
"description": str(e),
}
]
return None
def qualify_columns(expression):
"""
Qualify the columns in the given SQL expression.
"""
try:
return qualify(
expression,
qualify_columns=True,
isolate_tables=True,
validate_qualify_columns=False,
)
except sqlglot.errors.OptimizeError as error:
return expression
def parse_sql_query(sql_query, dialect):
"""
Parses the SQL query and returns an AST.
"""
return sqlglot.parse_one(sql_query, read=dialect)
def extract_physical_columns(ast):
"""
Extracts physical columns from the given AST.
"""
physical_columns = {}
for scope in traverse_scope(ast):
for column in scope.columns:
table = scope.sources.get(column.table)
if isinstance(table, Table):
db, schema, table_name = table.catalog, table.db, table.name
if not db or not schema:  # catalog/db are empty strings when unqualified
continue
path = f"{db}.{schema}.{table_name}".lower()
physical_columns.setdefault(path, set()).add(column.name)
return physical_columns
def get_columns_used(sql_query, dialect):
"""
Process the SQL query to extract physical columns.
"""
ast = parse_sql_query(sql_query, dialect)
qualified_ast = qualify_columns(ast)
return extract_physical_columns(qualified_ast)
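# Illustrative example: get_columns_used("select t.a from db.sch.t as t", "bigquery")
# would yield {"db.sch.t": {"a"}} (fully qualified, lowercased paths).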
def validate_columns_present_in_schema(sql_query, dialect, schemas, model_mapping):
"""
Validate that the columns in the SQL query are present in the schema.
"""
errors = []
new_schemas = {}
for db in schemas:
for schema in schemas[db]:
for table in schemas[db][schema]:
path = f"{db}.{schema}.{table}".lower()
new_schemas.setdefault(path, set()).update(
[column.lower() for column in schemas[db][schema][table].keys()]
)
schemas = new_schemas
try:
columns_used = get_columns_used(sql_query, dialect)
for table, columns_set in columns_used.items():
if table not in schemas:
(
start_position,
end_position,
num_occurences,
) = get_start_and_end_position(sql_query, table)
error = f"Error: Table '{table}' not found. This issue often occurs when a table is used directly\n in dbt instead of being referenced through the appropriate syntax.\n To resolve this, ensure that '{table}' is propaerly defined in your project and use the 'ref()' function to reference it in your models."
errors.append(
form_error(
error, table, start_position, end_position, num_occurences
)
)
continue
columns = schemas[table]
for column in columns_set:
if column.lower() not in columns:
(
start_position,
end_position,
num_occurences,
) = get_start_and_end_position(sql_query, column)
table = model_mapping.get(table, table)
error = f"Error: Column '{column}' not found in '{table}'. \nPossible causes: 1) Typo in column name. 2) Column not materialized. 3) Column not selected in parent cte."
errors.append(
form_error(
error,
column,
start_position,
end_position,
num_occurences,
)
)
except Exception:
pass  # column validation is best-effort; ignore analysis failures
return errors

from __future__ import annotations
import logging
import typing as t
from collections import defaultdict
from sqlglot import exp
from sqlglot.errors import ErrorLevel, UnsupportedError, concat_messages
from sqlglot.helper import apply_index_offset, csv, seq_get
from sqlglot.time import format_time
from sqlglot.tokens import Tokenizer, TokenType
logger = logging.getLogger("sqlglot")
class Generator:
"""
Generator converts a given syntax tree to the corresponding SQL string.
Args:
pretty: Whether or not to format the produced SQL string.
Default: False.
identify: Determines when an identifier should be quoted. Possible values are:
False (default): Never quote, except in cases where it's mandatory by the dialect.
True or 'always': Always quote.
'safe': Only quote identifiers that are case insensitive.
normalize: Whether or not to normalize identifiers to lowercase.
Default: False.
pad: Determines the pad size in a formatted string.
Default: 2.
indent: Determines the indentation size in a formatted string.
Default: 2.
normalize_functions: Whether or not to normalize all function names. Possible values are:
"upper" or True (default): Convert names to uppercase.
"lower": Convert names to lowercase.
False: Disables function name normalization.
unsupported_level: Determines the generator's behavior when it encounters unsupported expressions.
Default ErrorLevel.WARN.
max_unsupported: Maximum number of unsupported messages to include in a raised UnsupportedError.
This is only relevant if unsupported_level is ErrorLevel.RAISE.
Default: 3
leading_comma: Determines whether or not the comma is leading or trailing in select expressions.
This is only relevant when generating in pretty mode.
Default: False
max_text_width: The max number of characters in a segment before creating new lines in pretty mode.
The default is on the smaller end because the length only represents a segment and not the true
line length.
Default: 80
comments: Whether or not to preserve comments in the output SQL code.
Default: True
"""
TRANSFORMS = {
exp.DateAdd: lambda self, e: self.func(
"DATE_ADD", e.this, e.expression, exp.Literal.string(e.text("unit"))
),
exp.TsOrDsAdd: lambda self, e: self.func(
"TS_OR_DS_ADD", e.this, e.expression, exp.Literal.string(e.text("unit"))
),
exp.CaseSpecificColumnConstraint: lambda self, e: f"{'NOT ' if e.args.get('not_') else ''}CASESPECIFIC",
exp.CharacterSetColumnConstraint: lambda self, e: f"CHARACTER SET {self.sql(e, 'this')}",
exp.CharacterSetProperty: lambda self, e: f"{'DEFAULT ' if e.args.get('default') else ''}CHARACTER SET={self.sql(e, 'this')}",
exp.CheckColumnConstraint: lambda self, e: f"CHECK ({self.sql(e, 'this')})",
exp.ClusteredColumnConstraint: lambda self, e: f"CLUSTERED ({self.expressions(e, 'this', indent=False)})",
exp.CollateColumnConstraint: lambda self, e: f"COLLATE {self.sql(e, 'this')}",
exp.CopyGrantsProperty: lambda self, e: "COPY GRANTS",
exp.CommentColumnConstraint: lambda self, e: f"COMMENT {self.sql(e, 'this')}",
exp.DateFormatColumnConstraint: lambda self, e: f"FORMAT {self.sql(e, 'this')}",
exp.DefaultColumnConstraint: lambda self, e: f"DEFAULT {self.sql(e, 'this')}",
exp.EncodeColumnConstraint: lambda self, e: f"ENCODE {self.sql(e, 'this')}",
exp.ExecuteAsProperty: lambda self, e: self.naked_property(e),
exp.ExternalProperty: lambda self, e: "EXTERNAL",
exp.HeapProperty: lambda self, e: "HEAP",
exp.InlineLengthColumnConstraint: lambda self, e: f"INLINE LENGTH {self.sql(e, 'this')}",
exp.IntervalSpan: lambda self, e: f"{self.sql(e, 'this')} TO {self.sql(e, 'expression')}",
exp.LanguageProperty: lambda self, e: self.naked_property(e),
exp.LocationProperty: lambda self, e: self.naked_property(e),
exp.LogProperty: lambda self, e: f"{'NO ' if e.args.get('no') else ''}LOG",
exp.MaterializedProperty: lambda self, e: "MATERIALIZED",
exp.NoPrimaryIndexProperty: lambda self, e: "NO PRIMARY INDEX",
exp.NonClusteredColumnConstraint: lambda self, e: f"NONCLUSTERED ({self.expressions(e, 'this', indent=False)})",
exp.NotForReplicationColumnConstraint: lambda self, e: "NOT FOR REPLICATION",
exp.OnCommitProperty: lambda self, e: f"ON COMMIT {'DELETE' if e.args.get('delete') else 'PRESERVE'} ROWS",
exp.OnProperty: lambda self, e: f"ON {self.sql(e, 'this')}",
exp.OnUpdateColumnConstraint: lambda self, e: f"ON UPDATE {self.sql(e, 'this')}",
exp.PathColumnConstraint: lambda self, e: f"PATH {self.sql(e, 'this')}",
exp.ReturnsProperty: lambda self, e: self.naked_property(e),
exp.SetProperty: lambda self, e: f"{'MULTI' if e.args.get('multi') else ''}SET",
exp.SettingsProperty: lambda self, e: f"SETTINGS{self.seg('')}{(self.expressions(e))}",
exp.SqlSecurityProperty: lambda self, e: f"SQL SECURITY {'DEFINER' if e.args.get('definer') else 'INVOKER'}",
exp.StabilityProperty: lambda self, e: e.name,
exp.TemporaryProperty: lambda self, e: f"TEMPORARY",
exp.ToTableProperty: lambda self, e: f"TO {self.sql(e.this)}",
exp.TransientProperty: lambda self, e: "TRANSIENT",
exp.TitleColumnConstraint: lambda self, e: f"TITLE {self.sql(e, 'this')}",
exp.UppercaseColumnConstraint: lambda self, e: f"UPPERCASE",
exp.VarMap: lambda self, e: self.func("MAP", e.args["keys"], e.args["values"]),
exp.VolatileProperty: lambda self, e: "VOLATILE",
exp.WithJournalTableProperty: lambda self, e: f"WITH JOURNAL TABLE={self.sql(e, 'this')}",
}
# Whether or not null ordering is supported in order by
NULL_ORDERING_SUPPORTED = True
# Whether or not locking reads (i.e. SELECT ... FOR UPDATE/SHARE) are supported
LOCKING_READS_SUPPORTED = False
# Always do union distinct or union all
EXPLICIT_UNION = False
# Wrap derived values in parens, usually standard but spark doesn't support it
WRAP_DERIVED_VALUES = True
# Whether or not create function uses an AS before the RETURN
CREATE_FUNCTION_RETURN_AS = True
# Whether or not MERGE ... WHEN MATCHED BY SOURCE is allowed
MATCHED_BY_SOURCE = True
# Whether or not the INTERVAL expression works only with values like '1 day'
SINGLE_STRING_INTERVAL = False
# Whether or not the plural form of date parts like day (i.e. "days") is supported in INTERVALs
INTERVAL_ALLOWS_PLURAL_FORM = True
# Whether or not the TABLESAMPLE clause supports a method name, like BERNOULLI
TABLESAMPLE_WITH_METHOD = True
# Whether or not to treat the number in TABLESAMPLE (50) as a percentage
TABLESAMPLE_SIZE_IS_PERCENT = False
# Whether or not limit and fetch are supported (possible values: "ALL", "LIMIT", "FETCH")
LIMIT_FETCH = "ALL"
# Whether or not a table is allowed to be renamed with a db
RENAME_TABLE_WITH_DB = True
# The separator for grouping sets and rollups
GROUPINGS_SEP = ","
# The string used for creating an index on a table
INDEX_ON = "ON"
# Whether or not join hints should be generated
JOIN_HINTS = True
# Whether or not table hints should be generated
TABLE_HINTS = True
# Whether or not query hints should be generated
QUERY_HINTS = True
# What kind of separator to use for query hints
QUERY_HINT_SEP = ", "
# Whether or not comparing against booleans (e.g. x IS TRUE) is supported
IS_BOOL_ALLOWED = True
# Whether or not to include the "SET" keyword in the "INSERT ... ON DUPLICATE KEY UPDATE" statement
DUPLICATE_KEY_UPDATE_WITH_SET = True
# Whether or not to generate the limit as TOP <value> instead of LIMIT <value>
LIMIT_IS_TOP = False
# Whether or not to generate INSERT INTO ... RETURNING or INSERT INTO RETURNING ...
RETURNING_END = True
# Whether or not to generate the (+) suffix for columns used in old-style join conditions
COLUMN_JOIN_MARKS_SUPPORTED = False
# Whether or not to generate an unquoted value for EXTRACT's date part argument
EXTRACT_ALLOWS_QUOTES = True
# Whether or not TIMETZ / TIMESTAMPTZ will be generated using the "WITH TIME ZONE" syntax
TZ_TO_WITH_TIME_ZONE = False
# Whether or not the NVL2 function is supported
NVL2_SUPPORTED = True
# https://cloud.google.com/bigquery/docs/reference/standard-sql/query-syntax
SELECT_KINDS: t.Tuple[str, ...] = ("STRUCT", "VALUE")
# Whether or not VALUES statements can be used as derived tables.
# MySQL 5 and Redshift do not allow this, so when False, it will convert
# SELECT * VALUES into SELECT UNION
VALUES_AS_TABLE = True
# Whether or not the word COLUMN is included when adding a column with ALTER TABLE
ALTER_TABLE_ADD_COLUMN_KEYWORD = True
TYPE_MAPPING = {
exp.DataType.Type.NCHAR: "CHAR",
exp.DataType.Type.NVARCHAR: "VARCHAR",
exp.DataType.Type.MEDIUMTEXT: "TEXT",
exp.DataType.Type.LONGTEXT: "TEXT",
exp.DataType.Type.MEDIUMBLOB: "BLOB",
exp.DataType.Type.LONGBLOB: "BLOB",
exp.DataType.Type.INET: "INET",
}
STAR_MAPPING = {
"except": "EXCEPT",
"replace": "REPLACE",
}
TIME_PART_SINGULARS = {
"microseconds": "microsecond",
"seconds": "second",
"minutes": "minute",
"hours": "hour",
"days": "day",
"weeks": "week",
"months": "month",
"quarters": "quarter",
"years": "year",
}
TOKEN_MAPPING: t.Dict[TokenType, str] = {}
STRUCT_DELIMITER = ("<", ">")
PARAMETER_TOKEN = "@"
PROPERTIES_LOCATION = {
exp.AlgorithmProperty: exp.Properties.Location.POST_CREATE,
exp.AutoIncrementProperty: exp.Properties.Location.POST_SCHEMA,
exp.BlockCompressionProperty: exp.Properties.Location.POST_NAME,
exp.CharacterSetProperty: exp.Properties.Location.POST_SCHEMA,
exp.ChecksumProperty: exp.Properties.Location.POST_NAME,
exp.CollateProperty: exp.Properties.Location.POST_SCHEMA,
exp.CopyGrantsProperty: exp.Properties.Location.POST_SCHEMA,
exp.Cluster: exp.Properties.Location.POST_SCHEMA,
exp.ClusteredByProperty: exp.Properties.Location.POST_SCHEMA,
exp.DataBlocksizeProperty: exp.Properties.Location.POST_NAME,
exp.DefinerProperty: exp.Properties.Location.POST_CREATE,
exp.DictRange: exp.Properties.Location.POST_SCHEMA,
exp.DictProperty: exp.Properties.Location.POST_SCHEMA,
exp.DistKeyProperty: exp.Properties.Location.POST_SCHEMA,
exp.DistStyleProperty: exp.Properties.Location.POST_SCHEMA,
exp.EngineProperty: exp.Properties.Location.POST_SCHEMA,
exp.ExecuteAsProperty: exp.Properties.Location.POST_SCHEMA,
exp.ExternalProperty: exp.Properties.Location.POST_CREATE,
exp.FallbackProperty: exp.Properties.Location.POST_NAME,
exp.FileFormatProperty: exp.Properties.Location.POST_WITH,
exp.FreespaceProperty: exp.Properties.Location.POST_NAME,
exp.HeapProperty: exp.Properties.Location.POST_WITH,
exp.IsolatedLoadingProperty: exp.Properties.Location.POST_NAME,
exp.JournalProperty: exp.Properties.Location.POST_NAME,
exp.LanguageProperty: exp.Properties.Location.POST_SCHEMA,
exp.LikeProperty: exp.Properties.Location.POST_SCHEMA,
exp.LocationProperty: exp.Properties.Location.POST_SCHEMA,
exp.LockingProperty: exp.Properties.Location.POST_ALIAS,
exp.LogProperty: exp.Properties.Location.POST_NAME,
exp.MaterializedProperty: exp.Properties.Location.POST_CREATE,
exp.MergeBlockRatioProperty: exp.Properties.Location.POST_NAME,
exp.NoPrimaryIndexProperty: exp.Properties.Location.POST_EXPRESSION,
exp.OnProperty: exp.Properties.Location.POST_SCHEMA,
exp.OnCommitProperty: exp.Properties.Location.POST_EXPRESSION,
exp.Order: exp.Properties.Location.POST_SCHEMA,
exp.PartitionedByProperty: exp.Properties.Location.POST_WITH,
exp.PrimaryKey: exp.Properties.Location.POST_SCHEMA,
exp.Property: exp.Properties.Location.POST_WITH,
exp.ReturnsProperty: exp.Properties.Location.POST_SCHEMA,
exp.RowFormatProperty: exp.Properties.Location.POST_SCHEMA,
exp.RowFormatDelimitedProperty: exp.Properties.Location.POST_SCHEMA,
exp.RowFormatSerdeProperty: exp.Properties.Location.POST_SCHEMA,
exp.SchemaCommentProperty: exp.Properties.Location.POST_SCHEMA,
exp.SerdeProperties: exp.Properties.Location.POST_SCHEMA,
exp.Set: exp.Properties.Location.POST_SCHEMA,
exp.SettingsProperty: exp.Properties.Location.POST_SCHEMA,
exp.SetProperty: exp.Properties.Location.POST_CREATE,
exp.SortKeyProperty: exp.Properties.Location.POST_SCHEMA,
exp.SqlSecurityProperty: exp.Properties.Location.POST_CREATE,
exp.StabilityProperty: exp.Properties.Location.POST_SCHEMA,
exp.TemporaryProperty: exp.Properties.Location.POST_CREATE,
exp.ToTableProperty: exp.Properties.Location.POST_SCHEMA,
exp.TransientProperty: exp.Properties.Location.POST_CREATE,
exp.MergeTreeTTL: exp.Properties.Location.POST_SCHEMA,
exp.VolatileProperty: exp.Properties.Location.POST_CREATE,
exp.WithDataProperty: exp.Properties.Location.POST_EXPRESSION,
exp.WithJournalTableProperty: exp.Properties.Location.POST_NAME,
}
# Keywords that can't be used as unquoted identifier names
RESERVED_KEYWORDS: t.Set[str] = set()
# Expressions whose comments are separated from them for better formatting
WITH_SEPARATED_COMMENTS: t.Tuple[t.Type[exp.Expression], ...] = (
exp.Create,
exp.Delete,
exp.Drop,
exp.From,
exp.Insert,
exp.Join,
exp.Select,
exp.Update,
exp.Where,
exp.With,
)
# Expressions that can remain unwrapped when appearing in the context of an INTERVAL
UNWRAPPED_INTERVAL_VALUES: t.Tuple[t.Type[exp.Expression], ...] = (
exp.Column,
exp.Literal,
exp.Neg,
exp.Paren,
)
SENTINEL_LINE_BREAK = "__SQLGLOT__LB__"
# Autofilled
INVERSE_TIME_MAPPING: t.Dict[str, str] = {}
INVERSE_TIME_TRIE: t.Dict = {}
INDEX_OFFSET = 0
UNNEST_COLUMN_ONLY = False
ALIAS_POST_TABLESAMPLE = False
IDENTIFIERS_CAN_START_WITH_DIGIT = False
STRICT_STRING_CONCAT = False
NORMALIZE_FUNCTIONS: bool | str = "upper"
NULL_ORDERING = "nulls_are_small"
ESCAPE_LINE_BREAK = False
can_identify: t.Callable[[str, str | bool], bool]
# Delimiters for quotes, identifiers and the corresponding escape characters
QUOTE_START = "'"
QUOTE_END = "'"
IDENTIFIER_START = '"'
IDENTIFIER_END = '"'
TOKENIZER_CLASS = Tokenizer
# Delimiters for bit, hex, byte and raw literals
BIT_START: t.Optional[str] = None
BIT_END: t.Optional[str] = None
HEX_START: t.Optional[str] = None
HEX_END: t.Optional[str] = None
BYTE_START: t.Optional[str] = None
BYTE_END: t.Optional[str] = None
__slots__ = (
"pretty",
"identify",
"normalize",
"pad",
"_indent",
"normalize_functions",
"unsupported_level",
"max_unsupported",
"leading_comma",
"max_text_width",
"comments",
"unsupported_messages",
"_escaped_quote_end",
"_escaped_identifier_end",
"_cache",
)
def __init__(
self,
pretty: t.Optional[bool] = None,
identify: str | bool = False,
normalize: bool = False,
pad: int = 2,
indent: int = 2,
normalize_functions: t.Optional[str | bool] = None,
unsupported_level: ErrorLevel = ErrorLevel.WARN,
max_unsupported: int = 3,
leading_comma: bool = False,
max_text_width: int = 80,
comments: bool = True,
):
import sqlglot  # deferred import: sqlglot's package __init__ imports this module
self.pretty = pretty if pretty is not None else sqlglot.pretty
self.identify = identify
self.normalize = normalize
self.pad = pad
self._indent = indent
self.unsupported_level = unsupported_level
self.max_unsupported = max_unsupported
self.leading_comma = leading_comma
self.max_text_width = max_text_width
self.comments = comments
# This is both a Dialect property and a Generator argument, so we prioritize the latter
self.normalize_functions = (
self.NORMALIZE_FUNCTIONS if normalize_functions is None else normalize_functions
)
self.unsupported_messages: t.List[str] = []
self._escaped_quote_end: str = self.TOKENIZER_CLASS.STRING_ESCAPES[0] + self.QUOTE_END
self._escaped_identifier_end: str = (
self.TOKENIZER_CLASS.IDENTIFIER_ESCAPES[0] + self.IDENTIFIER_END
)
self._cache: t.Optional[t.Dict[int, str]] = None
def generate(
self,
expression: t.Optional[exp.Expression],
cache: t.Optional[t.Dict[int, str]] = None,
) -> str:
"""
Generates the SQL string corresponding to the given syntax tree.
Args:
expression: The syntax tree.
cache: An optional sql string cache. This leverages the hash of an Expression
which can be slow to compute, so only use it if you set _hash on each node.
Returns:
The SQL string corresponding to `expression`.
"""
if cache is not None:
self._cache = cache
self.unsupported_messages = []
sql = self.sql(expression).strip()
self._cache = None
if self.unsupported_level == ErrorLevel.IGNORE:
return sql
if self.unsupported_level == ErrorLevel.WARN:
for msg in self.unsupported_messages:
logger.warning(msg)
elif self.unsupported_level == ErrorLevel.RAISE and self.unsupported_messages:
raise UnsupportedError(concat_messages(self.unsupported_messages, self.max_unsupported))
if self.pretty:
sql = sql.replace(self.SENTINEL_LINE_BREAK, "\n")
return sql
def unsupported(self, message: str) -> None:
if self.unsupported_level == ErrorLevel.IMMEDIATE:
raise UnsupportedError(message)
self.unsupported_messages.append(message)
def sep(self, sep: str = " ") -> str:
return f"{sep.strip()}\n" if self.pretty else sep
def seg(self, sql: str, sep: str = " ") -> str:
return f"{self.sep(sep)}{sql}"
def pad_comment(self, comment: str) -> str:
comment = " " + comment if comment[0].strip() else comment
comment = comment + " " if comment[-1].strip() else comment
return comment
def maybe_comment(
self,
sql: str,
expression: t.Optional[exp.Expression] = None,
comments: t.Optional[t.List[str]] = None,
) -> str:
comments = (
((expression and expression.comments) if comments is None else comments) # type: ignore
if self.comments
else None
)
if not comments or isinstance(expression, exp.Binary):
return sql
sep = "\n" if self.pretty else " "
comments_sql = sep.join(
f"/*{self.pad_comment(comment)}*/" for comment in comments if comment
)
if not comments_sql:
return sql
if isinstance(expression, self.WITH_SEPARATED_COMMENTS):
return (
f"{self.sep()}{comments_sql}{sql}"
if sql[0].isspace()
else f"{comments_sql}{self.sep()}{sql}"
)
return f"{sql} {comments_sql}"
def wrap(self, expression: exp.Expression | str) -> str:
this_sql = self.indent(
self.sql(expression)
if isinstance(expression, (exp.Select, exp.Union))
else self.sql(expression, "this"),
level=1,
pad=0,
)
return f"({self.sep('')}{this_sql}{self.seg(')', sep='')}"
def no_identify(self, func: t.Callable[..., str], *args, **kwargs) -> str:
original = self.identify
self.identify = False
result = func(*args, **kwargs)
self.identify = original
return result
def normalize_func(self, name: str) -> str:
if self.normalize_functions == "upper" or self.normalize_functions is True:
return name.upper()
if self.normalize_functions == "lower":
return name.lower()
return name
def indent(
self,
sql: str,
level: int = 0,
pad: t.Optional[int] = None,
skip_first: bool = False,
skip_last: bool = False,
) -> str:
if not self.pretty:
return sql
pad = self.pad if pad is None else pad
lines = sql.split("\n")
return "\n".join(
line
if (skip_first and i == 0) or (skip_last and i == len(lines) - 1)
else f"{' ' * (level * self._indent + pad)}{line}"
for i, line in enumerate(lines)
)
def sql(
self,
expression: t.Optional[str | exp.Expression],
key: t.Optional[str] = None,
comment: bool = True,
) -> str:
if not expression:
return ""
if isinstance(expression, str):
return expression
if key:
value = expression.args.get(key)
if value:
return self.sql(value)
return ""
if self._cache is not None:
expression_id = hash(expression)
if expression_id in self._cache:
return self._cache[expression_id]
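# Dispatch order: class-level TRANSFORMS first, then a `<key>_sql` handler
# on the generator, then the generic Func/Property fallbacks below.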
transform = self.TRANSFORMS.get(expression.__class__)
if callable(transform):
sql = transform(self, expression)
elif transform:
sql = transform
elif isinstance(expression, exp.Expression):
exp_handler_name = f"{expression.key}_sql"
if hasattr(self, exp_handler_name):
sql = getattr(self, exp_handler_name)(expression)
elif isinstance(expression, exp.Func):
sql = self.function_fallback_sql(expression)
elif isinstance(expression, exp.Property):
sql = self.property_sql(expression)
else:
raise ValueError(f"Unsupported expression type {expression.__class__.__name__}")
else:
raise ValueError(f"Expected an Expression. Received {type(expression)}: {expression}")
sql = self.maybe_comment(sql, expression) if self.comments and comment else sql
if self._cache is not None:
self._cache[expression_id] = sql
return sql
def uncache_sql(self, expression: exp.Uncache) -> str:
table = self.sql(expression, "this")
exists_sql = " IF EXISTS" if expression.args.get("exists") else ""
return f"UNCACHE TABLE{exists_sql} {table}"
def cache_sql(self, expression: exp.Cache) -> str:
lazy = " LAZY" if expression.args.get("lazy") else ""
table = self.sql(expression, "this")
options = expression.args.get("options")
options = f" OPTIONS({self.sql(options[0])} = {self.sql(options[1])})" if options else ""
sql = self.sql(expression, "expression")
sql = f" AS{self.sep()}{sql}" if sql else ""
sql = f"CACHE{lazy} TABLE {table}{options}{sql}"
return self.prepend_ctes(expression, sql)
def characterset_sql(self, expression: exp.CharacterSet) -> str:
if isinstance(expression.parent, exp.Cast):
return f"CHAR CHARACTER SET {self.sql(expression, 'this')}"
default = "DEFAULT " if expression.args.get("default") else ""
return f"{default}CHARACTER SET={self.sql(expression, 'this')}"
def column_sql(self, expression: exp.Column) -> str:
join_mark = " (+)" if expression.args.get("join_mark") else ""
if join_mark and not self.COLUMN_JOIN_MARKS_SUPPORTED:
join_mark = ""
self.unsupported("Outer join syntax using the (+) operator is not supported.")
column = ".".join(
self.sql(part)
for part in (
expression.args.get("catalog"),
expression.args.get("db"),
expression.args.get("table"),
expression.args.get("this"),
)
if part
)
return f"{column}{join_mark}"
def columnposition_sql(self, expression: exp.ColumnPosition) -> str:
this = self.sql(expression, "this")
this = f" {this}" if this else ""
position = self.sql(expression, "position")
return f"{position}{this}"
def columndef_sql(self, expression: exp.ColumnDef, sep: str = " ") -> str:
column = self.sql(expression, "this")
kind = self.sql(expression, "kind")
constraints = self.expressions(expression, key="constraints", sep=" ", flat=True)
exists = "IF NOT EXISTS " if expression.args.get("exists") else ""
kind = f"{sep}{kind}" if kind else ""
constraints = f" {constraints}" if constraints else ""
position = self.sql(expression, "position")
position = f" {position}" if position else ""
return f"{exists}{column}{kind}{constraints}{position}"
def columnconstraint_sql(self, expression: exp.ColumnConstraint) -> str:
this = self.sql(expression, "this")
kind_sql = self.sql(expression, "kind").strip()
return f"CONSTRAINT {this} {kind_sql}" if this else kind_sql
def computedcolumnconstraint_sql(self, expression: exp.ComputedColumnConstraint) -> str:
this = self.sql(expression, "this")
if expression.args.get("not_null"):
persisted = " PERSISTED NOT NULL"
elif expression.args.get("persisted"):
persisted = " PERSISTED"
else:
persisted = ""
return f"AS {this}{persisted}"
def autoincrementcolumnconstraint_sql(self, _) -> str:
return self.token_sql(TokenType.AUTO_INCREMENT)
def compresscolumnconstraint_sql(self, expression: exp.CompressColumnConstraint) -> str:
if isinstance(expression.this, list):
this = self.wrap(self.expressions(expression, key="this", flat=True))
else:
this = self.sql(expression, "this")
return f"COMPRESS {this}"
def generatedasidentitycolumnconstraint_sql(
self, expression: exp.GeneratedAsIdentityColumnConstraint
) -> str:
this = ""
if expression.this is not None:
on_null = " ON NULL" if expression.args.get("on_null") else ""
this = " ALWAYS" if expression.this else f" BY DEFAULT{on_null}"
start = expression.args.get("start")
start = f"START WITH {start}" if start else ""
increment = expression.args.get("increment")
increment = f" INCREMENT BY {increment}" if increment else ""
minvalue = expression.args.get("minvalue")
minvalue = f" MINVALUE {minvalue}" if minvalue else ""
maxvalue = expression.args.get("maxvalue")
maxvalue = f" MAXVALUE {maxvalue}" if maxvalue else ""
cycle = expression.args.get("cycle")
cycle_sql = ""
if cycle is not None:
cycle_sql = f"{' NO' if not cycle else ''} CYCLE"
cycle_sql = cycle_sql.strip() if not start and not increment else cycle_sql
sequence_opts = ""
if start or increment or cycle_sql:
sequence_opts = f"{start}{increment}{minvalue}{maxvalue}{cycle_sql}"
sequence_opts = f" ({sequence_opts.strip()})"
expr = self.sql(expression, "expression")
expr = f"({expr})" if expr else "IDENTITY"
return f"GENERATED{this} AS {expr}{sequence_opts}"
def notnullcolumnconstraint_sql(self, expression: exp.NotNullColumnConstraint) -> str:
return f"{'' if expression.args.get('allow_null') else 'NOT '}NULL"
def primarykeycolumnconstraint_sql(self, expression: exp.PrimaryKeyColumnConstraint) -> str:
desc = expression.args.get("desc")
if desc is not None:
return f"PRIMARY KEY{' DESC' if desc else ' ASC'}"
return f"PRIMARY KEY"
def uniquecolumnconstraint_sql(self, expression: exp.UniqueColumnConstraint) -> str:
this = self.sql(expression, "this")
this = f" {this}" if this else ""
return f"UNIQUE{this}"
def createable_sql(self, expression: exp.Create, locations: t.DefaultDict) -> str:
return self.sql(expression, "this")
def create_sql(self, expression: exp.Create) -> str:
kind = self.sql(expression, "kind").upper()
properties = expression.args.get("properties")
properties_locs = self.locate_properties(properties) if properties else defaultdict()
this = self.createable_sql(expression, properties_locs)
properties_sql = ""
if properties_locs.get(exp.Properties.Location.POST_SCHEMA) or properties_locs.get(
exp.Properties.Location.POST_WITH
):
properties_sql = self.sql(
exp.Properties(
expressions=[
*properties_locs[exp.Properties.Location.POST_SCHEMA],
*properties_locs[exp.Properties.Location.POST_WITH],
]
)
)
begin = " BEGIN" if expression.args.get("begin") else ""
expression_sql = self.sql(expression, "expression")
if expression_sql:
expression_sql = f"{begin}{self.sep()}{expression_sql}"
if self.CREATE_FUNCTION_RETURN_AS or not isinstance(expression.expression, exp.Return):
if properties_locs.get(exp.Properties.Location.POST_ALIAS):
postalias_props_sql = self.properties(
exp.Properties(
expressions=properties_locs[exp.Properties.Location.POST_ALIAS]
),
wrapped=False,
)
expression_sql = f" AS {postalias_props_sql}{expression_sql}"
else:
expression_sql = f" AS{expression_sql}"
postindex_props_sql = ""
if properties_locs.get(exp.Properties.Location.POST_INDEX):
postindex_props_sql = self.properties(
exp.Properties(expressions=properties_locs[exp.Properties.Location.POST_INDEX]),
wrapped=False,
prefix=" ",
)
indexes = self.expressions(expression, key="indexes", indent=False, sep=" ")
indexes = f" {indexes}" if indexes else ""
index_sql = indexes + postindex_props_sql
replace = " OR REPLACE" if expression.args.get("replace") else ""
unique = " UNIQUE" if expression.args.get("unique") else ""
postcreate_props_sql = ""
if properties_locs.get(exp.Properties.Location.POST_CREATE):
postcreate_props_sql = self.properties(
exp.Properties(expressions=properties_locs[exp.Properties.Location.POST_CREATE]),
sep=" ",
prefix=" ",
wrapped=False,
)
modifiers = "".join((replace, unique, postcreate_props_sql))
postexpression_props_sql = ""
if properties_locs.get(exp.Properties.Location.POST_EXPRESSION):
postexpression_props_sql = self.properties(
exp.Properties(
expressions=properties_locs[exp.Properties.Location.POST_EXPRESSION]
),
sep=" ",
prefix=" ",
wrapped=False,
)
exists_sql = " IF NOT EXISTS" if expression.args.get("exists") else ""
no_schema_binding = (
" WITH NO SCHEMA BINDING" if expression.args.get("no_schema_binding") else ""
)
clone = self.sql(expression, "clone")
clone = f" {clone}" if clone else ""
expression_sql = f"CREATE{modifiers} {kind}{exists_sql} {this}{properties_sql}{expression_sql}{postexpression_props_sql}{index_sql}{no_schema_binding}{clone}"
return self.prepend_ctes(expression, expression_sql)
def clone_sql(self, expression: exp.Clone) -> str:
this = self.sql(expression, "this")
shallow = "SHALLOW " if expression.args.get("shallow") else ""
this = f"{shallow}CLONE {this}"
when = self.sql(expression, "when")
if when:
kind = self.sql(expression, "kind")
expr = self.sql(expression, "expression")
return f"{this} {when} ({kind} => {expr})"
return this
def describe_sql(self, expression: exp.Describe) -> str:
return f"DESCRIBE {self.sql(expression, 'this')}"
def prepend_ctes(self, expression: exp.Expression, sql: str) -> str:
with_ = self.sql(expression, "with")
if with_:
sql = f"{with_}{self.sep()}{sql}"
return sql
def with_sql(self, expression: exp.With) -> str:
sql = self.expressions(expression, flat=True)
recursive = "RECURSIVE " if expression.args.get("recursive") else ""
return f"WITH {recursive}{sql}"
def cte_sql(self, expression: exp.CTE) -> str:
alias = self.sql(expression, "alias")
return f"{alias} AS {self.wrap(expression)}"
def tablealias_sql(self, expression: exp.TableAlias) -> str:
alias = self.sql(expression, "this")
columns = self.expressions(expression, key="columns", flat=True)
columns = f"({columns})" if columns else ""
return f"{alias}{columns}"
def bitstring_sql(self, expression: exp.BitString) -> str:
this = self.sql(expression, "this")
if self.BIT_START:
return f"{self.BIT_START}{this}{self.BIT_END}"
return f"{int(this, 2)}"
def hexstring_sql(self, expression: exp.HexString) -> str:
this = self.sql(expression, "this")
if self.HEX_START:
return f"{self.HEX_START}{this}{self.HEX_END}"
return f"{int(this, 16)}"
def bytestring_sql(self, expression: exp.ByteString) -> str:
this = self.sql(expression, "this")
if self.BYTE_START:
return f"{self.BYTE_START}{this}{self.BYTE_END}"
return this
def rawstring_sql(self, expression: exp.RawString) -> str:
string = self.escape_str(expression.this.replace("\\", "\\\\"))
return f"{self.QUOTE_START}{string}{self.QUOTE_END}"
def datatypeparam_sql(self, expression: exp.DataTypeParam) -> str:
this = self.sql(expression, "this")
specifier = self.sql(expression, "expression")
specifier = f" {specifier}" if specifier else ""
return f"{this}{specifier}"
def datatype_sql(self, expression: exp.DataType) -> str:
type_value = expression.this
if type_value == exp.DataType.Type.USERDEFINED and expression.args.get("kind"):
type_sql = self.sql(expression, "kind")
else:
type_sql = (
self.TYPE_MAPPING.get(type_value, type_value.value)
if isinstance(type_value, exp.DataType.Type)
else type_value
)
nested = ""
interior = self.expressions(expression, flat=True)
values = ""
if interior:
if expression.args.get("nested"):
nested = f"{self.STRUCT_DELIMITER[0]}{interior}{self.STRUCT_DELIMITER[1]}"
if expression.args.get("values") is not None:
delimiters = ("[", "]") if type_value == exp.DataType.Type.ARRAY else ("(", ")")
values = self.expressions(expression, key="values", flat=True)
values = f"{delimiters[0]}{values}{delimiters[1]}"
elif type_value == exp.DataType.Type.INTERVAL:
nested = f" {interior}"
else:
nested = f"({interior})"
type_sql = f"{type_sql}{nested}{values}"
if self.TZ_TO_WITH_TIME_ZONE and type_value in (
exp.DataType.Type.TIMETZ,
exp.DataType.Type.TIMESTAMPTZ,
):
type_sql = f"{type_sql} WITH TIME ZONE"
return type_sql
def directory_sql(self, expression: exp.Directory) -> str:
local = "LOCAL " if expression.args.get("local") else ""
row_format = self.sql(expression, "row_format")
row_format = f" {row_format}" if row_format else ""
return f"{local}DIRECTORY {self.sql(expression, 'this')}{row_format}"
def delete_sql(self, expression: exp.Delete) -> str:
this = self.sql(expression, "this")
this = f" FROM {this}" if this else ""
using = self.sql(expression, "using")
using = f" USING {using}" if using else ""
where = self.sql(expression, "where")
returning = self.sql(expression, "returning")
limit = self.sql(expression, "limit")
tables = self.expressions(expression, key="tables")
tables = f" {tables}" if tables else ""
if self.RETURNING_END:
expression_sql = f"{this}{using}{where}{returning}{limit}"
else:
expression_sql = f"{returning}{this}{using}{where}{limit}"
return self.prepend_ctes(expression, f"DELETE{tables}{expression_sql}")
def drop_sql(self, expression: exp.Drop) -> str:
this = self.sql(expression, "this")
kind = expression.args["kind"]
exists_sql = " IF EXISTS " if expression.args.get("exists") else " "
temporary = " TEMPORARY" if expression.args.get("temporary") else ""
materialized = " MATERIALIZED" if expression.args.get("materialized") else ""
cascade = " CASCADE" if expression.args.get("cascade") else ""
constraints = " CONSTRAINTS" if expression.args.get("constraints") else ""
purge = " PURGE" if expression.args.get("purge") else ""
return (
f"DROP{temporary}{materialized} {kind}{exists_sql}{this}{cascade}{constraints}{purge}"
)
def except_sql(self, expression: exp.Except) -> str:
return self.prepend_ctes(
expression,
self.set_operation(expression, self.except_op(expression)),
)
def except_op(self, expression: exp.Except) -> str:
return f"EXCEPT{'' if expression.args.get('distinct') else ' ALL'}"
def fetch_sql(self, expression: exp.Fetch) -> str:
direction = expression.args.get("direction")
direction = f" {direction.upper()}" if direction else ""
count = expression.args.get("count")
count = f" {count}" if count else ""
if expression.args.get("percent"):
count = f"{count} PERCENT"
with_ties_or_only = "WITH TIES" if expression.args.get("with_ties") else "ONLY"
return f"{self.seg('FETCH')}{direction}{count} ROWS {with_ties_or_only}"
def filter_sql(self, expression: exp.Filter) -> str:
this = self.sql(expression, "this")
where = self.sql(expression, "expression").strip()
return f"{this} FILTER({where})"
def hint_sql(self, expression: exp.Hint) -> str:
if not self.QUERY_HINTS:
self.unsupported("Hints are not supported")
return ""
return f" /*+ {self.expressions(expression, sep=self.QUERY_HINT_SEP).strip()} */"
def index_sql(self, expression: exp.Index) -> str:
unique = "UNIQUE " if expression.args.get("unique") else ""
primary = "PRIMARY " if expression.args.get("primary") else ""
amp = "AMP " if expression.args.get("amp") else ""
name = self.sql(expression, "this")
name = f"{name} " if name else ""
table = self.sql(expression, "table")
table = f"{self.INDEX_ON} {table}" if table else ""
using = self.sql(expression, "using")
using = f" USING {using} " if using else ""
index = "INDEX " if not table else ""
columns = self.expressions(expression, key="columns", flat=True)
columns = f"({columns})" if columns else ""
partition_by = self.expressions(expression, key="partition_by", flat=True)
partition_by = f" PARTITION BY {partition_by}" if partition_by else ""
return f"{unique}{primary}{amp}{index}{name}{table}{using}{columns}{partition_by}"
def identifier_sql(self, expression: exp.Identifier) -> str:
text = expression.name
lower = text.lower()
text = lower if self.normalize and not expression.quoted else text
text = text.replace(self.IDENTIFIER_END, self._escaped_identifier_end)
if (
expression.quoted
or self.can_identify(text, self.identify)
or lower in self.RESERVED_KEYWORDS
or (not self.IDENTIFIERS_CAN_START_WITH_DIGIT and text[:1].isdigit())
):
text = f"{self.IDENTIFIER_START}{text}{self.IDENTIFIER_END}"
return text
def inputoutputformat_sql(self, expression: exp.InputOutputFormat) -> str:
input_format = self.sql(expression, "input_format")
input_format = f"INPUTFORMAT {input_format}" if input_format else ""
output_format = self.sql(expression, "output_format")
output_format = f"OUTPUTFORMAT {output_format}" if output_format else ""
return self.sep().join((input_format, output_format))
def national_sql(self, expression: exp.National, prefix: str = "N") -> str:
string = self.sql(exp.Literal.string(expression.name))
return f"{prefix}{string}"
def partition_sql(self, expression: exp.Partition) -> str:
return f"PARTITION({self.expressions(expression, flat=True)})"
def properties_sql(self, expression: exp.Properties) -> str:
root_properties = []
with_properties = []
for p in expression.expressions:
p_loc = self.PROPERTIES_LOCATION[p.__class__]
if p_loc == exp.Properties.Location.POST_WITH:
with_properties.append(p.copy())
elif p_loc == exp.Properties.Location.POST_SCHEMA:
root_properties.append(p.copy())
return self.root_properties(
exp.Properties(expressions=root_properties)
) + self.with_properties(exp.Properties(expressions=with_properties))
def root_properties(self, properties: exp.Properties) -> str:
if properties.expressions:
return self.sep() + self.expressions(properties, indent=False, sep=" ")
return ""
def properties(
self,
properties: exp.Properties,
prefix: str = "",
sep: str = ", ",
suffix: str = "",
wrapped: bool = True,
) -> str:
if properties.expressions:
expressions = self.expressions(properties, sep=sep, indent=False)
if expressions:
expressions = self.wrap(expressions) if wrapped else expressions
return f"{prefix}{' ' if prefix and prefix != ' ' else ''}{expressions}{suffix}"
return ""
def with_properties(self, properties: exp.Properties) -> str:
return self.properties(properties, prefix=self.seg("WITH"))
def locate_properties(self, properties: exp.Properties) -> t.DefaultDict:
properties_locs = defaultdict(list)
for p in properties.expressions:
p_loc = self.PROPERTIES_LOCATION[p.__class__]
if p_loc != exp.Properties.Location.UNSUPPORTED:
properties_locs[p_loc].append(p.copy())
else:
self.unsupported(f"Unsupported property {p.key}")
return properties_locs
def property_sql(self, expression: exp.Property) -> str:
property_cls = expression.__class__
if property_cls == exp.Property:
return f"{expression.name}={self.sql(expression, 'value')}"
property_name = exp.Properties.PROPERTY_TO_NAME.get(property_cls)
if not property_name:
self.unsupported(f"Unsupported property {expression.key}")
return f"{property_name}={self.sql(expression, 'this')}"
def likeproperty_sql(self, expression: exp.LikeProperty) -> str:
options = " ".join(f"{e.name} {self.sql(e, 'value')}" for e in expression.expressions)
options = f" {options}" if options else ""
return f"LIKE {self.sql(expression, 'this')}{options}"
def fallbackproperty_sql(self, expression: exp.FallbackProperty) -> str:
no = "NO " if expression.args.get("no") else ""
protection = " PROTECTION" if expression.args.get("protection") else ""
return f"{no}FALLBACK{protection}"
def journalproperty_sql(self, expression: exp.JournalProperty) -> str:
no = "NO " if expression.args.get("no") else ""
local = expression.args.get("local")
local = f"{local} " if local else ""
dual = "DUAL " if expression.args.get("dual") else ""
before = "BEFORE " if expression.args.get("before") else ""
after = "AFTER " if expression.args.get("after") else ""
return f"{no}{local}{dual}{before}{after}JOURNAL"
def freespaceproperty_sql(self, expression: exp.FreespaceProperty) -> str:
freespace = self.sql(expression, "this")
percent = " PERCENT" if expression.args.get("percent") else ""
return f"FREESPACE={freespace}{percent}"
def checksumproperty_sql(self, expression: exp.ChecksumProperty) -> str:
if expression.args.get("default"):
property = "DEFAULT"
elif expression.args.get("on"):
property = "ON"
else:
property = "OFF"
return f"CHECKSUM={property}"
def mergeblockratioproperty_sql(self, expression: exp.MergeBlockRatioProperty) -> str:
if expression.args.get("no"):
return "NO MERGEBLOCKRATIO"
if expression.args.get("default"):
return "DEFAULT MERGEBLOCKRATIO"
percent = " PERCENT" if expression.args.get("percent") else ""
return f"MERGEBLOCKRATIO={self.sql(expression, 'this')}{percent}"
def datablocksizeproperty_sql(self, expression: exp.DataBlocksizeProperty) -> str:
default = expression.args.get("default")
minimum = expression.args.get("minimum")
maximum = expression.args.get("maximum")
if default or minimum or maximum:
if default:
prop = "DEFAULT"
elif minimum:
prop = "MINIMUM"
else:
prop = "MAXIMUM"
return f"{prop} DATABLOCKSIZE"
units = expression.args.get("units")
units = f" {units}" if units else ""
return f"DATABLOCKSIZE={self.sql(expression, 'size')}{units}"
def blockcompressionproperty_sql(self, expression: exp.BlockCompressionProperty) -> str:
autotemp = expression.args.get("autotemp")
always = expression.args.get("always")
default = expression.args.get("default")
manual = expression.args.get("manual")
never = expression.args.get("never")
if autotemp is not None:
prop = f"AUTOTEMP({self.expressions(autotemp)})"
elif always:
prop = "ALWAYS"
elif default:
prop = "DEFAULT"
elif manual:
prop = "MANUAL"
elif never:
prop = "NEVER"
return f"BLOCKCOMPRESSION={prop}"
def isolatedloadingproperty_sql(self, expression: exp.IsolatedLoadingProperty) -> str:
no = expression.args.get("no")
no = " NO" if no else ""
concurrent = expression.args.get("concurrent")
concurrent = " CONCURRENT" if concurrent else ""
for_ = ""
if expression.args.get("for_all"):
for_ = " FOR ALL"
elif expression.args.get("for_insert"):
for_ = " FOR INSERT"
elif expression.args.get("for_none"):
for_ = " FOR NONE"
return f"WITH{no}{concurrent} ISOLATED LOADING{for_}"
def lockingproperty_sql(self, expression: exp.LockingProperty) -> str:
kind = expression.args.get("kind")
this = f" {self.sql(expression, 'this')}" if expression.this else ""
for_or_in = expression.args.get("for_or_in")
lock_type = expression.args.get("lock_type")
override = " OVERRIDE" if expression.args.get("override") else ""
return f"LOCKING {kind}{this} {for_or_in} {lock_type}{override}"
def withdataproperty_sql(self, expression: exp.WithDataProperty) -> str:
data_sql = f"WITH {'NO ' if expression.args.get('no') else ''}DATA"
statistics = expression.args.get("statistics")
statistics_sql = ""
if statistics is not None:
statistics_sql = f" AND {'NO ' if not statistics else ''}STATISTICS"
return f"{data_sql}{statistics_sql}"
def insert_sql(self, expression: exp.Insert) -> str:
overwrite = expression.args.get("overwrite")
if isinstance(expression.this, exp.Directory):
this = " OVERWRITE" if overwrite else " INTO"
else:
this = " OVERWRITE TABLE" if overwrite else " INTO"
alternative = expression.args.get("alternative")
alternative = f" OR {alternative}" if alternative else ""
ignore = " IGNORE" if expression.args.get("ignore") else ""
this = f"{this} {self.sql(expression, 'this')}"
exists = " IF EXISTS" if expression.args.get("exists") else ""
partition_sql = (
f" {self.sql(expression, 'partition')}" if expression.args.get("partition") else ""
)
where = self.sql(expression, "where")
where = f"{self.sep()}REPLACE WHERE {where}" if where else ""
expression_sql = f"{self.sep()}{self.sql(expression, 'expression')}"
conflict = self.sql(expression, "conflict")
by_name = " BY NAME" if expression.args.get("by_name") else ""
returning = self.sql(expression, "returning")
if self.RETURNING_END:
expression_sql = f"{expression_sql}{conflict}{returning}"
else:
expression_sql = f"{returning}{expression_sql}{conflict}"
sql = f"INSERT{alternative}{ignore}{this}{by_name}{exists}{partition_sql}{where}{expression_sql}"
return self.prepend_ctes(expression, sql)
def intersect_sql(self, expression: exp.Intersect) -> str:
return self.prepend_ctes(
expression,
self.set_operation(expression, self.intersect_op(expression)),
)
def intersect_op(self, expression: exp.Intersect) -> str:
return f"INTERSECT{'' if expression.args.get('distinct') else ' ALL'}"
def introducer_sql(self, expression: exp.Introducer) -> str:
return f"{self.sql(expression, 'this')} {self.sql(expression, 'expression')}"
def pseudotype_sql(self, expression: exp.PseudoType) -> str:
return expression.name.upper()
def objectidentifier_sql(self, expression: exp.ObjectIdentifier) -> str:
return expression.name.upper()
def onconflict_sql(self, expression: exp.OnConflict) -> str:
conflict = "ON DUPLICATE KEY" if expression.args.get("duplicate") else "ON CONFLICT"
constraint = self.sql(expression, "constraint")
if constraint:
constraint = f"ON CONSTRAINT {constraint}"
key = self.expressions(expression, key="key", flat=True)
do = "" if expression.args.get("duplicate") else " DO "
nothing = "NOTHING" if expression.args.get("nothing") else ""
expressions = self.expressions(expression, flat=True)
set_keyword = "SET " if self.DUPLICATE_KEY_UPDATE_WITH_SET else ""
if expressions:
expressions = f"UPDATE {set_keyword}{expressions}"
return f"{self.seg(conflict)} {constraint}{key}{do}{nothing}{expressions}"
def returning_sql(self, expression: exp.Returning) -> str:
return f"{self.seg('RETURNING')} {self.expressions(expression, flat=True)}"
def rowformatdelimitedproperty_sql(self, expression: exp.RowFormatDelimitedProperty) -> str:
fields = expression.args.get("fields")
fields = f" FIELDS TERMINATED BY {fields}" if fields else ""
escaped = expression.args.get("escaped")
escaped = f" ESCAPED BY {escaped}" if escaped else ""
items = expression.args.get("collection_items")
items = f" COLLECTION ITEMS TERMINATED BY {items}" if items else ""
keys = expression.args.get("map_keys")
keys = f" MAP KEYS TERMINATED BY {keys}" if keys else ""
lines = expression.args.get("lines")
lines = f" LINES TERMINATED BY {lines}" if lines else ""
null = expression.args.get("null")
null = f" NULL DEFINED AS {null}" if null else ""
return f"ROW FORMAT DELIMITED{fields}{escaped}{items}{keys}{lines}{null}"
def withtablehint_sql(self, expression: exp.WithTableHint) -> str:
return f"WITH ({self.expressions(expression, flat=True)})"
def indextablehint_sql(self, expression: exp.IndexTableHint) -> str:
this = f"{self.sql(expression, 'this')} INDEX"
target = self.sql(expression, "target")
target = f" FOR {target}" if target else ""
return f"{this}{target} ({self.expressions(expression, flat=True)})"
def table_sql(self, expression: exp.Table, sep: str = " AS ") -> str:
table = ".".join(
part
for part in [
self.sql(expression, "catalog"),
self.sql(expression, "db"),
self.sql(expression, "this"),
]
if part
)
version = self.sql(expression, "version")
version = f" {version}" if version else ""
alias = self.sql(expression, "alias")
alias = f"{sep}{alias}" if alias else ""
hints = self.expressions(expression, key="hints", sep=" ")
hints = f" {hints}" if hints and self.TABLE_HINTS else ""
pivots = self.expressions(expression, key="pivots", sep=" ", flat=True)
pivots = f" {pivots}" if pivots else ""
joins = self.expressions(expression, key="joins", sep="", skip_first=True)
laterals = self.expressions(expression, key="laterals", sep="")
return f"{table}{version}{alias}{hints}{pivots}{joins}{laterals}"
def tablesample_sql(
self, expression: exp.TableSample, seed_prefix: str = "SEED", sep=" AS "
) -> str:
if self.ALIAS_POST_TABLESAMPLE and expression.this.alias:
table = expression.this.copy()
table.set("alias", None)
this = self.sql(table)
alias = f"{sep}{self.sql(expression.this, 'alias')}"
else:
this = self.sql(expression, "this")
alias = ""
method = self.sql(expression, "method")
method = f"{method.upper()} " if method and self.TABLESAMPLE_WITH_METHOD else ""
numerator = self.sql(expression, "bucket_numerator")
denominator = self.sql(expression, "bucket_denominator")
field = self.sql(expression, "bucket_field")
field = f" ON {field}" if field else ""
bucket = f"BUCKET {numerator} OUT OF {denominator}{field}" if numerator else ""
percent = self.sql(expression, "percent")
percent = f"{percent} PERCENT" if percent else ""
rows = self.sql(expression, "rows")
rows = f"{rows} ROWS" if rows else ""
size = self.sql(expression, "size")
if size and self.TABLESAMPLE_SIZE_IS_PERCENT:
size = f"{size} PERCENT"
seed = self.sql(expression, "seed")
seed = f" {seed_prefix} ({seed})" if seed else ""
kind = expression.args.get("kind", "TABLESAMPLE")
return f"{this} {kind} {method}({bucket}{percent}{rows}{size}){seed}{alias}"
def pivot_sql(self, expression: exp.Pivot) -> str:
expressions = self.expressions(expression, flat=True)
if expression.this:
this = self.sql(expression, "this")
on = f"{self.seg('ON')} {expressions}"
using = self.expressions(expression, key="using", flat=True)
using = f"{self.seg('USING')} {using}" if using else ""
group = self.sql(expression, "group")
return f"PIVOT {this}{on}{using}{group}"
alias = self.sql(expression, "alias")
alias = f" AS {alias}" if alias else ""
unpivot = expression.args.get("unpivot")
direction = "UNPIVOT" if unpivot else "PIVOT"
field = self.sql(expression, "field")
include_nulls = expression.args.get("include_nulls")
if include_nulls is not None:
nulls = " INCLUDE NULLS " if include_nulls else " EXCLUDE NULLS "
else:
nulls = ""
return f"{direction}{nulls}({expressions} FOR {field}){alias}"
def version_sql(self, expression: exp.Version) -> str:
this = f"FOR {expression.name}"
kind = expression.text("kind")
expr = self.sql(expression, "expression")
return f"{this} {kind} {expr}"
def tuple_sql(self, expression: exp.Tuple) -> str:
return f"({self.expressions(expression, flat=True)})"
def update_sql(self, expression: exp.Update) -> str:
this = self.sql(expression, "this")
set_sql = self.expressions(expression, flat=True)
from_sql = self.sql(expression, "from")
where_sql = self.sql(expression, "where")
returning = self.sql(expression, "returning")
order = self.sql(expression, "order")
limit = self.sql(expression, "limit")
if self.RETURNING_END:
expression_sql = f"{from_sql}{where_sql}{returning}"
else:
expression_sql = f"{returning}{from_sql}{where_sql}"
sql = f"UPDATE {this} SET {set_sql}{expression_sql}{order}{limit}"
return self.prepend_ctes(expression, sql)
def values_sql(self, expression: exp.Values) -> str:
# The VALUES clause is still valid in an `INSERT INTO ..` statement, for example
if self.VALUES_AS_TABLE or not expression.find_ancestor(exp.From, exp.Join):
args = self.expressions(expression)
alias = self.sql(expression, "alias")
values = f"VALUES{self.seg('')}{args}"
values = (
f"({values})"
if self.WRAP_DERIVED_VALUES and (alias or isinstance(expression.parent, exp.From))
else values
)
return f"{values} AS {alias}" if alias else values
# Converts `VALUES...` expression into a series of select unions.
# Note: If you have a lot of unions then this will result in a large number of recursive statements to
# evaluate the expression. You may need to increase `sys.setrecursionlimit` to run and it can also be
# very slow.
expression = expression.copy()
column_names = expression.alias and expression.args["alias"].columns
selects = []
for i, tup in enumerate(expression.expressions):
row = tup.expressions
if i == 0 and column_names:
row = [
exp.alias_(value, column_name) for value, column_name in zip(row, column_names)
]
selects.append(exp.Select(expressions=row))
subquery_expression: exp.Select | exp.Union = selects[0]
if len(selects) > 1:
for select in selects[1:]:
subquery_expression = exp.union(
subquery_expression, select, distinct=False, copy=False
)
return self.subquery_sql(subquery_expression.subquery(expression.alias, copy=False))
def var_sql(self, expression: exp.Var) -> str:
return self.sql(expression, "this")
def into_sql(self, expression: exp.Into) -> str:
temporary = " TEMPORARY" if expression.args.get("temporary") else ""
unlogged = " UNLOGGED" if expression.args.get("unlogged") else ""
return f"{self.seg('INTO')}{temporary or unlogged} {self.sql(expression, 'this')}"
def from_sql(self, expression: exp.From) -> str:
return f"{self.seg('FROM')} {self.sql(expression, 'this')}"
def group_sql(self, expression: exp.Group) -> str:
group_by = self.op_expressions("GROUP BY", expression)
if expression.args.get("all"):
return f"{group_by} ALL"
grouping_sets = self.expressions(expression, key="grouping_sets", indent=False)
grouping_sets = (
f"{self.seg('GROUPING SETS')} {self.wrap(grouping_sets)}" if grouping_sets else ""
)
cube = expression.args.get("cube", [])
if seq_get(cube, 0) is True:
return f"{group_by}{self.seg('WITH CUBE')}"
else:
cube_sql = self.expressions(expression, key="cube", indent=False)
cube_sql = f"{self.seg('CUBE')} {self.wrap(cube_sql)}" if cube_sql else ""
rollup = expression.args.get("rollup", [])
if seq_get(rollup, 0) is True:
return f"{group_by}{self.seg('WITH ROLLUP')}"
else:
rollup_sql = self.expressions(expression, key="rollup", indent=False)
rollup_sql = f"{self.seg('ROLLUP')} {self.wrap(rollup_sql)}" if rollup_sql else ""
groupings = csv(
grouping_sets,
cube_sql,
rollup_sql,
self.seg("WITH TOTALS") if expression.args.get("totals") else "",
sep=self.GROUPINGS_SEP,
)
if expression.args.get("expressions") and groupings:
group_by = f"{group_by}{self.GROUPINGS_SEP}"
return f"{group_by}{groupings}"
def having_sql(self, expression: exp.Having) -> str:
this = self.indent(self.sql(expression, "this"))
return f"{self.seg('HAVING')}{self.sep()}{this}"
def connect_sql(self, expression: exp.Connect) -> str:
start = self.sql(expression, "start")
start = self.seg(f"START WITH {start}") if start else ""
connect = self.sql(expression, "connect")
connect = self.seg(f"CONNECT BY {connect}")
return start + connect
def prior_sql(self, expression: exp.Prior) -> str:
return f"PRIOR {self.sql(expression, 'this')}"
def join_sql(self, expression: exp.Join) -> str:
op_sql = " ".join(
op
for op in (
expression.method,
"GLOBAL" if expression.args.get("global") else None,
expression.side,
expression.kind,
expression.hint if self.JOIN_HINTS else None,
)
if op
)
on_sql = self.sql(expression, "on")
using = expression.args.get("using")
if not on_sql and using:
on_sql = csv(*(self.sql(column) for column in using))
this_sql = self.sql(expression, "this")
if on_sql:
on_sql = self.indent(on_sql, skip_first=True)
space = self.seg(" " * self.pad) if self.pretty else " "
if using:
on_sql = f"{space}USING ({on_sql})"
else:
on_sql = f"{space}ON {on_sql}"
elif not op_sql:
return f", {this_sql}"
op_sql = f"{op_sql} JOIN" if op_sql else "JOIN"
return f"{self.seg(op_sql)} {this_sql}{on_sql}"
def lambda_sql(self, expression: exp.Lambda, arrow_sep: str = "->") -> str:
args = self.expressions(expression, flat=True)
args = f"({args})" if len(args.split(",")) > 1 else args
return f"{args} {arrow_sep} {self.sql(expression, 'this')}"
def lateral_sql(self, expression: exp.Lateral) -> str:
this = self.sql(expression, "this")
if isinstance(expression.this, exp.Subquery):
return f"LATERAL {this}"
if expression.args.get("view"):
alias = expression.args["alias"]
columns = self.expressions(alias, key="columns", flat=True)
table = f" {alias.name}" if alias.name else ""
columns = f" AS {columns}" if columns else ""
op_sql = self.seg(f"LATERAL VIEW{' OUTER' if expression.args.get('outer') else ''}")
return f"{op_sql}{self.sep()}{this}{table}{columns}"
alias = self.sql(expression, "alias")
alias = f" AS {alias}" if alias else ""
return f"LATERAL {this}{alias}"
def limit_sql(self, expression: exp.Limit, top: bool = False) -> str:
this = self.sql(expression, "this")
args = ", ".join(
sql
for sql in (
self.sql(expression, "offset"),
self.sql(expression, "expression"),
)
if sql
)
return f"{this}{self.seg('TOP' if top else 'LIMIT')} {args}"
def offset_sql(self, expression: exp.Offset) -> str:
this = self.sql(expression, "this")
return f"{this}{self.seg('OFFSET')} {self.sql(expression, 'expression')}"
def setitem_sql(self, expression: exp.SetItem) -> str:
kind = self.sql(expression, "kind")
kind = f"{kind} " if kind else ""
this = self.sql(expression, "this")
expressions = self.expressions(expression)
collate = self.sql(expression, "collate")
collate = f" COLLATE {collate}" if collate else ""
global_ = "GLOBAL " if expression.args.get("global") else ""
return f"{global_}{kind}{this}{expressions}{collate}"
def set_sql(self, expression: exp.Set) -> str:
expressions = (
f" {self.expressions(expression, flat=True)}" if expression.expressions else ""
)
tag = " TAG" if expression.args.get("tag") else ""
return f"{'UNSET' if expression.args.get('unset') else 'SET'}{tag}{expressions}"
def pragma_sql(self, expression: exp.Pragma) -> str:
return f"PRAGMA {self.sql(expression, 'this')}"
def lock_sql(self, expression: exp.Lock) -> str:
if not self.LOCKING_READS_SUPPORTED:
self.unsupported("Locking reads using 'FOR UPDATE/SHARE' are not supported")
return ""
lock_type = "FOR UPDATE" if expression.args["update"] else "FOR SHARE"
expressions = self.expressions(expression, flat=True)
expressions = f" OF {expressions}" if expressions else ""
wait = expression.args.get("wait")
if wait is not None:
if isinstance(wait, exp.Literal):
wait = f" WAIT {self.sql(wait)}"
else:
wait = " NOWAIT" if wait else " SKIP LOCKED"
return f"{lock_type}{expressions}{wait or ''}"
def literal_sql(self, expression: exp.Literal) -> str:
text = expression.this or ""
if expression.is_string:
text = f"{self.QUOTE_START}{self.escape_str(text)}{self.QUOTE_END}"
return text
def escape_str(self, text: str) -> str:
text = text.replace(self.QUOTE_END, self._escaped_quote_end)
if self.ESCAPE_LINE_BREAK:
text = text.replace("\n", "\\n")
elif self.pretty:
text = text.replace("\n", self.SENTINEL_LINE_BREAK)
return text
def loaddata_sql(self, expression: exp.LoadData) -> str:
local = " LOCAL" if expression.args.get("local") else ""
inpath = f" INPATH {self.sql(expression, 'inpath')}"
overwrite = " OVERWRITE" if expression.args.get("overwrite") else ""
this = f" INTO TABLE {self.sql(expression, 'this')}"
partition = self.sql(expression, "partition")
partition = f" {partition}" if partition else ""
input_format = self.sql(expression, "input_format")
input_format = f" INPUTFORMAT {input_format}" if input_format else ""
serde = self.sql(expression, "serde")
serde = f" SERDE {serde}" if serde else ""
return f"LOAD DATA{local}{inpath}{overwrite}{this}{partition}{input_format}{serde}"
def null_sql(self, *_) -> str:
return "NULL"
def boolean_sql(self, expression: exp.Boolean) -> str:
return "TRUE" if expression.this else "FALSE"
def order_sql(self, expression: exp.Order, flat: bool = False) -> str:
this = self.sql(expression, "this")
this = f"{this} " if this else this
return self.op_expressions(f"{this}ORDER BY", expression, flat=this or flat) # type: ignore
def cluster_sql(self, expression: exp.Cluster) -> str:
return self.op_expressions("CLUSTER BY", expression)
def distribute_sql(self, expression: exp.Distribute) -> str:
return self.op_expressions("DISTRIBUTE BY", expression)
def sort_sql(self, expression: exp.Sort) -> str:
return self.op_expressions("SORT BY", expression)
def ordered_sql(self, expression: exp.Ordered) -> str:
desc = expression.args.get("desc")
asc = not desc
nulls_first = expression.args.get("nulls_first")
nulls_last = not nulls_first
nulls_are_large = self.NULL_ORDERING == "nulls_are_large"
nulls_are_small = self.NULL_ORDERING == "nulls_are_small"
nulls_are_last = self.NULL_ORDERING == "nulls_are_last"
sort_order = " DESC" if desc else ""
nulls_sort_change = ""
if nulls_first and (
(asc and nulls_are_large) or (desc and nulls_are_small) or nulls_are_last
):
nulls_sort_change = " NULLS FIRST"
elif (
nulls_last
and ((asc and nulls_are_small) or (desc and nulls_are_large))
and not nulls_are_last
):
nulls_sort_change = " NULLS LAST"
if nulls_sort_change and not self.NULL_ORDERING_SUPPORTED:
self.unsupported(
"Sorting in an ORDER BY on NULLS FIRST/NULLS LAST is not supported by this dialect"
)
nulls_sort_change = ""
return f"{self.sql(expression, 'this')}{sort_order}{nulls_sort_change}"
def matchrecognize_sql(self, expression: exp.MatchRecognize) -> str:
partition = self.partition_by_sql(expression)
order = self.sql(expression, "order")
measures = self.expressions(expression, key="measures")
measures = self.seg(f"MEASURES{self.seg(measures)}") if measures else ""
rows = self.sql(expression, "rows")
rows = self.seg(rows) if rows else ""
after = self.sql(expression, "after")
after = self.seg(after) if after else ""
pattern = self.sql(expression, "pattern")
pattern = self.seg(f"PATTERN ({pattern})") if pattern else ""
definition_sqls = [
f"{self.sql(definition, 'alias')} AS {self.sql(definition, 'this')}"
for definition in expression.args.get("define", [])
]
definitions = self.expressions(sqls=definition_sqls)
define = self.seg(f"DEFINE{self.seg(definitions)}") if definitions else ""
body = "".join(
(
partition,
order,
measures,
rows,
after,
pattern,
define,
)
)
alias = self.sql(expression, "alias")
alias = f" {alias}" if alias else ""
return f"{self.seg('MATCH_RECOGNIZE')} {self.wrap(body)}{alias}"
def query_modifiers(self, expression: exp.Expression, *sqls: str) -> str:
limit: t.Optional[exp.Fetch | exp.Limit] = expression.args.get("limit")
# If the limit is generated as TOP, we need to ensure it's not generated twice
with_offset_limit_modifiers = not isinstance(limit, exp.Limit) or not self.LIMIT_IS_TOP
if self.LIMIT_FETCH == "LIMIT" and isinstance(limit, exp.Fetch):
limit = exp.Limit(expression=exp.maybe_copy(limit.args.get("count")))
elif self.LIMIT_FETCH == "FETCH" and isinstance(limit, exp.Limit):
limit = exp.Fetch(direction="FIRST", count=exp.maybe_copy(limit.expression))
fetch = isinstance(limit, exp.Fetch)
offset_limit_modifiers = (
self.offset_limit_modifiers(expression, fetch, limit)
if with_offset_limit_modifiers
else []
)
return csv(
*sqls,
*[self.sql(join) for join in expression.args.get("joins") or []],
self.sql(expression, "connect"),
self.sql(expression, "match"),
*[self.sql(lateral) for lateral in expression.args.get("laterals") or []],
self.sql(expression, "where"),
self.sql(expression, "group"),
self.sql(expression, "having"),
*self.after_having_modifiers(expression),
self.sql(expression, "order"),
*offset_limit_modifiers,
*self.after_limit_modifiers(expression),
sep="",
)
def offset_limit_modifiers(
self, expression: exp.Expression, fetch: bool, limit: t.Optional[exp.Fetch | exp.Limit]
) -> t.List[str]:
return [
self.sql(expression, "offset") if fetch else self.sql(limit),
self.sql(limit) if fetch else self.sql(expression, "offset"),
]
def after_having_modifiers(self, expression: exp.Expression) -> t.List[str]:
return [
self.sql(expression, "qualify"),
self.seg("WINDOW ") + self.expressions(expression, key="windows", flat=True)
if expression.args.get("windows")
else "",
self.sql(expression, "distribute"),
self.sql(expression, "sort"),
self.sql(expression, "cluster"),
]
def after_limit_modifiers(self, expression: exp.Expression) -> t.List[str]:
locks = self.expressions(expression, key="locks", sep=" ")
locks = f" {locks}" if locks else ""
return [locks, self.sql(expression, "sample")]
def select_sql(self, expression: exp.Select) -> str:
hint = self.sql(expression, "hint")
distinct = self.sql(expression, "distinct")
distinct = f" {distinct}" if distinct else ""
kind = self.sql(expression, "kind").upper()
limit = expression.args.get("limit")
top = (
self.limit_sql(limit, top=True)
if isinstance(limit, exp.Limit) and self.LIMIT_IS_TOP
else ""
)
expressions = self.expressions(expression)
if kind:
if kind in self.SELECT_KINDS:
kind = f" AS {kind}"
else:
if kind == "STRUCT":
expressions = self.expressions(
sqls=[
self.sql(
exp.Struct(
expressions=[
exp.column(e.output_name).eq(
e.this if isinstance(e, exp.Alias) else e
)
for e in expression.expressions
]
)
)
]
)
kind = ""
expressions = f"{self.sep()}{expressions}" if expressions else expressions
sql = self.query_modifiers(
expression,
f"SELECT{top}{hint}{distinct}{kind}{expressions}",
self.sql(expression, "into", comment=False),
self.sql(expression, "from", comment=False),
)
return self.prepend_ctes(expression, sql)
def schema_sql(self, expression: exp.Schema) -> str:
this = self.sql(expression, "this")
this = f"{this} " if this else ""
sql = self.schema_columns_sql(expression)
return f"{this}{sql}"
def schema_columns_sql(self, expression: exp.Schema) -> str:
return f"({self.sep('')}{self.expressions(expression)}{self.seg(')', sep='')}"
def star_sql(self, expression: exp.Star) -> str:
except_ = self.expressions(expression, key="except", flat=True)
except_ = f"{self.seg(self.STAR_MAPPING['except'])} ({except_})" if except_ else ""
replace = self.expressions(expression, key="replace", flat=True)
replace = f"{self.seg(self.STAR_MAPPING['replace'])} ({replace})" if replace else ""
return f"*{except_}{replace}"
def parameter_sql(self, expression: exp.Parameter) -> str:
this = self.sql(expression, "this")
this = f"{{{this}}}" if expression.args.get("wrapped") else f"{this}"
return f"{self.PARAMETER_TOKEN}{this}"
def sessionparameter_sql(self, expression: exp.SessionParameter) -> str:
this = self.sql(expression, "this")
kind = expression.text("kind")
if kind:
kind = f"{kind}."
return f"@@{kind}{this}"
def placeholder_sql(self, expression: exp.Placeholder) -> str:
return f":{expression.name}" if expression.name else "?"
def subquery_sql(self, expression: exp.Subquery, sep: str = " AS ") -> str:
alias = self.sql(expression, "alias")
alias = f"{sep}{alias}" if alias else ""
pivots = self.expressions(expression, key="pivots", sep=" ", flat=True)
pivots = f" {pivots}" if pivots else ""
sql = self.query_modifiers(expression, self.wrap(expression), alias, pivots)
return self.prepend_ctes(expression, sql)
def qualify_sql(self, expression: exp.Qualify) -> str:
this = self.indent(self.sql(expression, "this"))
return f"{self.seg('QUALIFY')}{self.sep()}{this}"
def union_sql(self, expression: exp.Union) -> str:
return self.prepend_ctes(
expression,
self.set_operation(expression, self.union_op(expression)),
)
def union_op(self, expression: exp.Union) -> str:
kind = " DISTINCT" if self.EXPLICIT_UNION else ""
kind = kind if expression.args.get("distinct") else " ALL"
by_name = " BY NAME" if expression.args.get("by_name") else ""
return f"UNION{kind}{by_name}"
def unnest_sql(self, expression: exp.Unnest) -> str:
args = self.expressions(expression, flat=True)
alias = expression.args.get("alias")
if alias and self.UNNEST_COLUMN_ONLY:
columns = alias.columns
alias = self.sql(columns[0]) if columns else ""
else:
alias = self.sql(expression, "alias")
alias = f" AS {alias}" if alias else alias
ordinality = " WITH ORDINALITY" if expression.args.get("ordinality") else ""
offset = expression.args.get("offset")
offset = f" WITH OFFSET AS {self.sql(offset)}" if offset else ""
return f"UNNEST({args}){ordinality}{alias}{offset}"
def where_sql(self, expression: exp.Where) -> str:
this = self.indent(self.sql(expression, "this"))
return f"{self.seg('WHERE')}{self.sep()}{this}"
def window_sql(self, expression: exp.Window) -> str:
this = self.sql(expression, "this")
partition = self.partition_by_sql(expression)
order = expression.args.get("order")
order = self.order_sql(order, flat=True) if order else ""
spec = self.sql(expression, "spec")
alias = self.sql(expression, "alias")
over = self.sql(expression, "over") or "OVER"
this = f"{this} {'AS' if expression.arg_key == 'windows' else over}"
first = expression.args.get("first")
if first is None:
first = ""
else:
first = "FIRST" if first else "LAST"
if not partition and not order and not spec and alias:
return f"{this} {alias}"
args = " ".join(arg for arg in (alias, first, partition, order, spec) if arg)
return f"{this} ({args})"
def partition_by_sql(self, expression: exp.Window | exp.MatchRecognize) -> str:
partition = self.expressions(expression, key="partition_by", flat=True)
return f"PARTITION BY {partition}" if partition else ""
def windowspec_sql(self, expression: exp.WindowSpec) -> str:
kind = self.sql(expression, "kind")
start = csv(self.sql(expression, "start"), self.sql(expression, "start_side"), sep=" ")
end = (
csv(self.sql(expression, "end"), self.sql(expression, "end_side"), sep=" ")
or "CURRENT ROW"
)
return f"{kind} BETWEEN {start} AND {end}"
def withingroup_sql(self, expression: exp.WithinGroup) -> str:
this = self.sql(expression, "this")
expression_sql = self.sql(expression, "expression")[1:] # order has a leading space
return f"{this} WITHIN GROUP ({expression_sql})"
def between_sql(self, expression: exp.Between) -> str:
this = self.sql(expression, "this")
low = self.sql(expression, "low")
high = self.sql(expression, "high")
return f"{this} BETWEEN {low} AND {high}"
def bracket_sql(self, expression: exp.Bracket) -> str:
expressions = apply_index_offset(expression.this, expression.expressions, self.INDEX_OFFSET)
expressions_sql = ", ".join(self.sql(e) for e in expressions)
return f"{self.sql(expression, 'this')}[{expressions_sql}]"
def safebracket_sql(self, expression: exp.SafeBracket) -> str:
return self.bracket_sql(expression)
def all_sql(self, expression: exp.All) -> str:
return f"ALL {self.wrap(expression)}"
def any_sql(self, expression: exp.Any) -> str:
this = self.sql(expression, "this")
if isinstance(expression.this, exp.Subqueryable):
this = self.wrap(this)
return f"ANY {this}"
def exists_sql(self, expression: exp.Exists) -> str:
return f"EXISTS{self.wrap(expression)}"
def case_sql(self, expression: exp.Case) -> str:
this = self.sql(expression, "this")
statements = [f"CASE {this}" if this else "CASE"]
for e in expression.args["ifs"]:
statements.append(f"WHEN {self.sql(e, 'this')}")
statements.append(f"THEN {self.sql(e, 'true')}")
default = self.sql(expression, "default")
if default:
statements.append(f"ELSE {default}")
statements.append("END")
if self.pretty and self.text_width(statements) > self.max_text_width:
return self.indent("\n".join(statements), skip_first=True, skip_last=True)
return " ".join(statements)
def constraint_sql(self, expression: exp.Constraint) -> str:
this = self.sql(expression, "this")
expressions = self.expressions(expression, flat=True)
return f"CONSTRAINT {this} {expressions}"
def nextvaluefor_sql(self, expression: exp.NextValueFor) -> str:
order = expression.args.get("order")
order = f" OVER ({self.order_sql(order, flat=True)})" if order else ""
return f"NEXT VALUE FOR {self.sql(expression, 'this')}{order}"
def extract_sql(self, expression: exp.Extract) -> str:
this = self.sql(expression, "this") if self.EXTRACT_ALLOWS_QUOTES else expression.this.name
expression_sql = self.sql(expression, "expression")
return f"EXTRACT({this} FROM {expression_sql})"
def trim_sql(self, expression: exp.Trim) -> str:
trim_type = self.sql(expression, "position")
if trim_type == "LEADING":
return self.func("LTRIM", expression.this)
elif trim_type == "TRAILING":
return self.func("RTRIM", expression.this)
else:
return self.func("TRIM", expression.this, expression.expression)
def safeconcat_sql(self, expression: exp.SafeConcat) -> str:
expressions = expression.expressions
if self.STRICT_STRING_CONCAT:
expressions = (exp.cast(e, "text") for e in expressions)
return self.func("CONCAT", *expressions)
def check_sql(self, expression: exp.Check) -> str:
this = self.sql(expression, key="this")
return f"CHECK ({this})"
def foreignkey_sql(self, expression: exp.ForeignKey) -> str:
expressions = self.expressions(expression, flat=True)
reference = self.sql(expression, "reference")
reference = f" {reference}" if reference else ""
delete = self.sql(expression, "delete")
delete = f" ON DELETE {delete}" if delete else ""
update = self.sql(expression, "update")
update = f" ON UPDATE {update}" if update else ""
return f"FOREIGN KEY ({expressions}){reference}{delete}{update}"
def primarykey_sql(self, expression: exp.ForeignKey) -> str:
expressions = self.expressions(expression, flat=True)
options = self.expressions(expression, key="options", flat=True, sep=" ")
options = f" {options}" if options else ""
return f"PRIMARY KEY ({expressions}){options}"
def if_sql(self, expression: exp.If) -> str:
expression = expression.copy()
return self.case_sql(exp.Case(ifs=[expression], default=expression.args.get("false")))
def matchagainst_sql(self, expression: exp.MatchAgainst) -> str:
modifier = expression.args.get("modifier")
modifier = f" {modifier}" if modifier else ""
return f"{self.func('MATCH', *expression.expressions)} AGAINST({self.sql(expression, 'this')}{modifier})"
def jsonkeyvalue_sql(self, expression: exp.JSONKeyValue) -> str:
return f"{self.sql(expression, 'this')}: {self.sql(expression, 'expression')}"
def jsonobject_sql(self, expression: exp.JSONObject) -> str:
null_handling = expression.args.get("null_handling")
null_handling = f" {null_handling}" if null_handling else ""
unique_keys = expression.args.get("unique_keys")
if unique_keys is not None:
unique_keys = f" {'WITH' if unique_keys else 'WITHOUT'} UNIQUE KEYS"
else:
unique_keys = ""
return_type = self.sql(expression, "return_type")
return_type = f" RETURNING {return_type}" if return_type else ""
format_json = " FORMAT JSON" if expression.args.get("format_json") else ""
encoding = self.sql(expression, "encoding")
encoding = f" ENCODING {encoding}" if encoding else ""
return self.func(
"JSON_OBJECT",
*expression.expressions,
suffix=f"{null_handling}{unique_keys}{return_type}{format_json}{encoding})",
)
def openjsoncolumndef_sql(self, expression: exp.OpenJSONColumnDef) -> str:
this = self.sql(expression, "this")
kind = self.sql(expression, "kind")
path = self.sql(expression, "path")
path = f" {path}" if path else ""
as_json = " AS JSON" if expression.args.get("as_json") else ""
return f"{this} {kind}{path}{as_json}"
def openjson_sql(self, expression: exp.OpenJSON) -> str:
this = self.sql(expression, "this")
path = self.sql(expression, "path")
path = f", {path}" if path else ""
expressions = self.expressions(expression)
with_ = (
f" WITH ({self.seg(self.indent(expressions), sep='')}{self.seg(')', sep='')}"
if expressions
else ""
)
return f"OPENJSON({this}{path}){with_}"
def in_sql(self, expression: exp.In) -> str:
query = expression.args.get("query")
unnest = expression.args.get("unnest")
field = expression.args.get("field")
is_global = " GLOBAL" if expression.args.get("is_global") else ""
if query:
in_sql = self.wrap(query)
elif unnest:
in_sql = self.in_unnest_op(unnest)
elif field:
in_sql = self.sql(field)
else:
in_sql = f"({self.expressions(expression, flat=True)})"
return f"{self.sql(expression, 'this')}{is_global} IN {in_sql}"
def in_unnest_op(self, unnest: exp.Unnest) -> str:
return f"(SELECT {self.sql(unnest)})"
def interval_sql(self, expression: exp.Interval) -> str:
unit = self.sql(expression, "unit")
if not self.INTERVAL_ALLOWS_PLURAL_FORM:
unit = self.TIME_PART_SINGULARS.get(unit.lower(), unit)
unit = f" {unit}" if unit else ""
if self.SINGLE_STRING_INTERVAL:
this = expression.this.name if expression.this else ""
return f"INTERVAL '{this}{unit}'" if this else f"INTERVAL{unit}"
this = self.sql(expression, "this")
if this:
unwrapped = isinstance(expression.this, self.UNWRAPPED_INTERVAL_VALUES)
this = f" {this}" if unwrapped else f" ({this})"
return f"INTERVAL{this}{unit}"
def return_sql(self, expression: exp.Return) -> str:
return f"RETURN {self.sql(expression, 'this')}"
def reference_sql(self, expression: exp.Reference) -> str:
this = self.sql(expression, "this")
expressions = self.expressions(expression, flat=True)
expressions = f"({expressions})" if expressions else ""
options = self.expressions(expression, key="options", flat=True, sep=" ")
options = f" {options}" if options else ""
return f"REFERENCES {this}{expressions}{options}"
def anonymous_sql(self, expression: exp.Anonymous) -> str:
return self.func(expression.name, *expression.expressions)
def paren_sql(self, expression: exp.Paren) -> str:
if isinstance(expression.unnest(), exp.Select):
sql = self.wrap(expression)
else:
sql = self.seg(self.indent(self.sql(expression, "this")), sep="")
sql = f"({sql}{self.seg(')', sep='')}"
return self.prepend_ctes(expression, sql)
def neg_sql(self, expression: exp.Neg) -> str:
# This makes sure we don't convert "- - 5" to "--5", which is a comment
this_sql = self.sql(expression, "this")
sep = " " if this_sql[0] == "-" else ""
return f"-{sep}{this_sql}"
def not_sql(self, expression: exp.Not) -> str:
return f"NOT {self.sql(expression, 'this')}"
def alias_sql(self, expression: exp.Alias) -> str:
alias = self.sql(expression, "alias")
alias = f" AS {alias}" if alias else ""
return f"{self.sql(expression, 'this')}{alias}"
def aliases_sql(self, expression: exp.Aliases) -> str:
return f"{self.sql(expression, 'this')} AS ({self.expressions(expression, flat=True)})"
def attimezone_sql(self, expression: exp.AtTimeZone) -> str:
this = self.sql(expression, "this")
zone = self.sql(expression, "zone")
return f"{this} AT TIME ZONE {zone}"
def add_sql(self, expression: exp.Add) -> str:
return self.binary(expression, "+")
def and_sql(self, expression: exp.And) -> str:
return self.connector_sql(expression, "AND")
def xor_sql(self, expression: exp.Xor) -> str:
return self.connector_sql(expression, "XOR")
def connector_sql(self, expression: exp.Connector, op: str) -> str:
if not self.pretty:
return self.binary(expression, op)
sqls = tuple(
self.maybe_comment(self.sql(e), e, e.parent.comments or []) if i != 1 else self.sql(e)
for i, e in enumerate(expression.flatten(unnest=False))
)
sep = "\n" if self.text_width(sqls) > self.max_text_width else " "
return f"{sep}{op} ".join(sqls)
def bitwiseand_sql(self, expression: exp.BitwiseAnd) -> str:
return self.binary(expression, "&")
def bitwiseleftshift_sql(self, expression: exp.BitwiseLeftShift) -> str:
return self.binary(expression, "<<")
def bitwisenot_sql(self, expression: exp.BitwiseNot) -> str:
return f"~{self.sql(expression, 'this')}"
def bitwiseor_sql(self, expression: exp.BitwiseOr) -> str:
return self.binary(expression, "|")
def bitwiserightshift_sql(self, expression: exp.BitwiseRightShift) -> str:
return self.binary(expression, ">>")
def bitwisexor_sql(self, expression: exp.BitwiseXor) -> str:
return self.binary(expression, "^")
def cast_sql(self, expression: exp.Cast, safe_prefix: t.Optional[str] = None) -> str:
format_sql = self.sql(expression, "format")
format_sql = f" FORMAT {format_sql}" if format_sql else ""
return f"{safe_prefix or ''}CAST({self.sql(expression, 'this')} AS {self.sql(expression, 'to')}{format_sql})"
def currentdate_sql(self, expression: exp.CurrentDate) -> str:
zone = self.sql(expression, "this")
return f"CURRENT_DATE({zone})" if zone else "CURRENT_DATE"
def collate_sql(self, expression: exp.Collate) -> str:
return self.binary(expression, "COLLATE")
def command_sql(self, expression: exp.Command) -> str:
return f"{self.sql(expression, 'this').upper()} {expression.text('expression').strip()}"
def comment_sql(self, expression: exp.Comment) -> str:
this = self.sql(expression, "this")
kind = expression.args["kind"]
exists_sql = " IF EXISTS " if expression.args.get("exists") else " "
expression_sql = self.sql(expression, "expression")
return f"COMMENT{exists_sql}ON {kind} {this} IS {expression_sql}"
def mergetreettlaction_sql(self, expression: exp.MergeTreeTTLAction) -> str:
this = self.sql(expression, "this")
delete = " DELETE" if expression.args.get("delete") else ""
recompress = self.sql(expression, "recompress")
recompress = f" RECOMPRESS {recompress}" if recompress else ""
to_disk = self.sql(expression, "to_disk")
to_disk = f" TO DISK {to_disk}" if to_disk else ""
to_volume = self.sql(expression, "to_volume")
to_volume = f" TO VOLUME {to_volume}" if to_volume else ""
return f"{this}{delete}{recompress}{to_disk}{to_volume}"
def mergetreettl_sql(self, expression: exp.MergeTreeTTL) -> str:
where = self.sql(expression, "where")
group = self.sql(expression, "group")
aggregates = self.expressions(expression, key="aggregates")
aggregates = self.seg("SET") + self.seg(aggregates) if aggregates else ""
if not (where or group or aggregates) and len(expression.expressions) == 1:
return f"TTL {self.expressions(expression, flat=True)}"
return f"TTL{self.seg(self.expressions(expression))}{where}{group}{aggregates}"
def transaction_sql(self, expression: exp.Transaction) -> str:
return "BEGIN"
def commit_sql(self, expression: exp.Commit) -> str:
chain = expression.args.get("chain")
if chain is not None:
chain = " AND CHAIN" if chain else " AND NO CHAIN"
return f"COMMIT{chain or ''}"
def rollback_sql(self, expression: exp.Rollback) -> str:
savepoint = expression.args.get("savepoint")
savepoint = f" TO {savepoint}" if savepoint else ""
return f"ROLLBACK{savepoint}"
def altercolumn_sql(self, expression: exp.AlterColumn) -> str:
this = self.sql(expression, "this")
dtype = self.sql(expression, "dtype")
if dtype:
collate = self.sql(expression, "collate")
collate = f" COLLATE {collate}" if collate else ""
using = self.sql(expression, "using")
using = f" USING {using}" if using else ""
return f"ALTER COLUMN {this} TYPE {dtype}{collate}{using}"
default = self.sql(expression, "default")
if default:
return f"ALTER COLUMN {this} SET DEFAULT {default}"
if not expression.args.get("drop"):
self.unsupported("Unsupported ALTER COLUMN syntax")
return f"ALTER COLUMN {this} DROP DEFAULT"
def renametable_sql(self, expression: exp.RenameTable) -> str:
if not self.RENAME_TABLE_WITH_DB:
# Remove db from tables
expression = expression.transform(
lambda n: exp.table_(n.this) if isinstance(n, exp.Table) else n
)
this = self.sql(expression, "this")
return f"RENAME TO {this}"
def altertable_sql(self, expression: exp.AlterTable) -> str:
actions = expression.args["actions"]
if isinstance(actions[0], exp.ColumnDef):
if self.ALTER_TABLE_ADD_COLUMN_KEYWORD:
actions = self.expressions(
expression,
key="actions",
prefix="ADD COLUMN ",
)
else:
actions = f"ADD {self.expressions(expression, key='actions')}"
elif isinstance(actions[0], exp.Schema):
actions = self.expressions(expression, key="actions", prefix="ADD COLUMNS ")
elif isinstance(actions[0], exp.Delete):
actions = self.expressions(expression, key="actions", flat=True)
else:
actions = self.expressions(expression, key="actions")
exists = " IF EXISTS" if expression.args.get("exists") else ""
only = " ONLY" if expression.args.get("only") else ""
return f"ALTER TABLE{exists}{only} {self.sql(expression, 'this')} {actions}"
def droppartition_sql(self, expression: exp.DropPartition) -> str:
expressions = self.expressions(expression)
exists = " IF EXISTS " if expression.args.get("exists") else " "
return f"DROP{exists}{expressions}"
def addconstraint_sql(self, expression: exp.AddConstraint) -> str:
this = self.sql(expression, "this")
expression_ = self.sql(expression, "expression")
add_constraint = f"ADD CONSTRAINT {this}" if this else "ADD"
enforced = expression.args.get("enforced")
if enforced is not None:
return f"{add_constraint} CHECK ({expression_}){' ENFORCED' if enforced else ''}"
return f"{add_constraint} {expression_}"
def distinct_sql(self, expression: exp.Distinct) -> str:
this = self.expressions(expression, flat=True)
this = f" {this}" if this else ""
on = self.sql(expression, "on")
on = f" ON {on}" if on else ""
return f"DISTINCT{this}{on}"
def ignorenulls_sql(self, expression: exp.IgnoreNulls) -> str:
return f"{self.sql(expression, 'this')} IGNORE NULLS"
def respectnulls_sql(self, expression: exp.RespectNulls) -> str:
return f"{self.sql(expression, 'this')} RESPECT NULLS"
def intdiv_sql(self, expression: exp.IntDiv) -> str:
return self.sql(
exp.Cast(
this=exp.Div(this=expression.this.copy(), expression=expression.expression.copy()),
to=exp.DataType(this=exp.DataType.Type.INT),
)
)
def dpipe_sql(self, expression: exp.DPipe) -> str:
return self.binary(expression, "||")
def safedpipe_sql(self, expression: exp.SafeDPipe) -> str:
if self.STRICT_STRING_CONCAT:
return self.func("CONCAT", *(exp.cast(e, "text") for e in expression.flatten()))
return self.dpipe_sql(expression)
def div_sql(self, expression: exp.Div) -> str:
return self.binary(expression, "/")
def overlaps_sql(self, expression: exp.Overlaps) -> str:
return self.binary(expression, "OVERLAPS")
def distance_sql(self, expression: exp.Distance) -> str:
return self.binary(expression, "<->")
def dot_sql(self, expression: exp.Dot) -> str:
return f"{self.sql(expression, 'this')}.{self.sql(expression, 'expression')}"
def eq_sql(self, expression: exp.EQ) -> str:
return self.binary(expression, "=")
def escape_sql(self, expression: exp.Escape) -> str:
return self.binary(expression, "ESCAPE")
def glob_sql(self, expression: exp.Glob) -> str:
return self.binary(expression, "GLOB")
def gt_sql(self, expression: exp.GT) -> str:
return self.binary(expression, ">")
def gte_sql(self, expression: exp.GTE) -> str:
return self.binary(expression, ">=")
def ilike_sql(self, expression: exp.ILike) -> str:
return self.binary(expression, "ILIKE")
def ilikeany_sql(self, expression: exp.ILikeAny) -> str:
return self.binary(expression, "ILIKE ANY")
def is_sql(self, expression: exp.Is) -> str:
if not self.IS_BOOL_ALLOWED and isinstance(expression.expression, exp.Boolean):
return self.sql(
expression.this if expression.expression.this else exp.not_(expression.this)
)
return self.binary(expression, "IS")
def like_sql(self, expression: exp.Like) -> str:
return self.binary(expression, "LIKE")
def likeany_sql(self, expression: exp.LikeAny) -> str:
return self.binary(expression, "LIKE ANY")
def similarto_sql(self, expression: exp.SimilarTo) -> str:
return self.binary(expression, "SIMILAR TO")
def lt_sql(self, expression: exp.LT) -> str:
return self.binary(expression, "<")
def lte_sql(self, expression: exp.LTE) -> str:
return self.binary(expression, "<=")
def mod_sql(self, expression: exp.Mod) -> str:
return self.binary(expression, "%")
def mul_sql(self, expression: exp.Mul) -> str:
return self.binary(expression, "*")
def neq_sql(self, expression: exp.NEQ) -> str:
return self.binary(expression, "<>")
def nullsafeeq_sql(self, expression: exp.NullSafeEQ) -> str:
return self.binary(expression, "IS NOT DISTINCT FROM")
def nullsafeneq_sql(self, expression: exp.NullSafeNEQ) -> str:
return self.binary(expression, "IS DISTINCT FROM")
def or_sql(self, expression: exp.Or) -> str:
return self.connector_sql(expression, "OR")
def slice_sql(self, expression: exp.Slice) -> str:
return self.binary(expression, ":")
def sub_sql(self, expression: exp.Sub) -> str:
return self.binary(expression, "-")
def trycast_sql(self, expression: exp.TryCast) -> str:
return self.cast_sql(expression, safe_prefix="TRY_")
def use_sql(self, expression: exp.Use) -> str:
kind = self.sql(expression, "kind")
kind = f" {kind}" if kind else ""
this = self.sql(expression, "this")
this = f" {this}" if this else ""
return f"USE{kind}{this}"
def binary(self, expression: exp.Binary, op: str) -> str:
op = self.maybe_comment(op, comments=expression.comments)
return f"{self.sql(expression, 'this')} {op} {self.sql(expression, 'expression')}"
def function_fallback_sql(self, expression: exp.Func) -> str:
args = []
for key in expression.arg_types:
arg_value = expression.args.get(key)
if isinstance(arg_value, list):
for value in arg_value:
args.append(value)
elif arg_value is not None:
args.append(arg_value)
if self.normalize_functions:
name = expression.sql_name()
else:
name = (expression._meta and expression.meta.get("name")) or expression.sql_name()
return self.func(name, *args)
def func(
self,
name: str,
*args: t.Optional[exp.Expression | str],
prefix: str = "(",
suffix: str = ")",
) -> str:
return f"{self.normalize_func(name)}{prefix}{self.format_args(*args)}{suffix}"
def format_args(self, *args: t.Optional[str | exp.Expression]) -> str:
arg_sqls = tuple(self.sql(arg) for arg in args if arg is not None)
if self.pretty and self.text_width(arg_sqls) > self.max_text_width:
return self.indent("\n" + f",\n".join(arg_sqls) + "\n", skip_first=True, skip_last=True)
return ", ".join(arg_sqls)
def text_width(self, args: t.Iterable) -> int:
return sum(len(arg) for arg in args)
def format_time(self, expression: exp.Expression) -> t.Optional[str]:
return format_time(
self.sql(expression, "format"), self.INVERSE_TIME_MAPPING, self.INVERSE_TIME_TRIE
)
def expressions(
self,
expression: t.Optional[exp.Expression] = None,
key: t.Optional[str] = None,
sqls: t.Optional[t.List[str]] = None,
flat: bool = False,
indent: bool = True,
skip_first: bool = False,
sep: str = ", ",
prefix: str = "",
) -> str:
expressions = expression.args.get(key or "expressions") if expression else sqls
if not expressions:
return ""
if flat:
return sep.join(sql for sql in (self.sql(e) for e in expressions) if sql)
num_sqls = len(expressions)
# These are calculated once in case we have the leading_comma / pretty option set, correspondingly
pad = " " * self.pad
stripped_sep = sep.strip()
result_sqls = []
for i, e in enumerate(expressions):
sql = self.sql(e, comment=False)
if not sql:
continue
comments = self.maybe_comment("", e) if isinstance(e, exp.Expression) else ""
if self.pretty:
if self.leading_comma:
result_sqls.append(f"{sep if i > 0 else pad}{prefix}{sql}{comments}")
else:
result_sqls.append(
f"{prefix}{sql}{stripped_sep if i + 1 < num_sqls else ''}{comments}"
)
else:
result_sqls.append(f"{prefix}{sql}{comments}{sep if i + 1 < num_sqls else ''}")
result_sql = "\n".join(result_sqls) if self.pretty else "".join(result_sqls)
return self.indent(result_sql, skip_first=skip_first) if indent else result_sql
def op_expressions(self, op: str, expression: exp.Expression, flat: bool = False) -> str:
flat = flat or isinstance(expression.parent, exp.Properties)
expressions_sql = self.expressions(expression, flat=flat)
if flat:
return f"{op} {expressions_sql}"
return f"{self.seg(op)}{self.sep() if expressions_sql else ''}{expressions_sql}"
def naked_property(self, expression: exp.Property) -> str:
property_name = exp.Properties.PROPERTY_TO_NAME.get(expression.__class__)
if not property_name:
self.unsupported(f"Unsupported property {expression.__class__.__name__}")
return f"{property_name} {self.sql(expression, 'this')}"
def set_operation(self, expression: exp.Expression, op: str) -> str:
this = self.sql(expression, "this")
op = self.seg(op)
return self.query_modifiers(
expression, f"{this}{op}{self.sep()}{self.sql(expression, 'expression')}"
)
def tag_sql(self, expression: exp.Tag) -> str:
return f"{expression.args.get('prefix')}{self.sql(expression.this)}{expression.args.get('postfix')}"
def token_sql(self, token_type: TokenType) -> str:
return self.TOKEN_MAPPING.get(token_type, token_type.name)
def userdefinedfunction_sql(self, expression: exp.UserDefinedFunction) -> str:
this = self.sql(expression, "this")
expressions = self.no_identify(self.expressions, expression)
expressions = (
self.wrap(expressions) if expression.args.get("wrapped") else f" {expressions}"
)
return f"{this}{expressions}"
def joinhint_sql(self, expression: exp.JoinHint) -> str:
this = self.sql(expression, "this")
expressions = self.expressions(expression, flat=True)
return f"{this}({expressions})"
def kwarg_sql(self, expression: exp.Kwarg) -> str:
return self.binary(expression, "=>")
def when_sql(self, expression: exp.When) -> str:
matched = "MATCHED" if expression.args["matched"] else "NOT MATCHED"
source = " BY SOURCE" if self.MATCHED_BY_SOURCE and expression.args.get("source") else ""
condition = self.sql(expression, "condition")
condition = f" AND {condition}" if condition else ""
then_expression = expression.args.get("then")
if isinstance(then_expression, exp.Insert):
then = f"INSERT {self.sql(then_expression, 'this')}"
if "expression" in then_expression.args:
then += f" VALUES {self.sql(then_expression, 'expression')}"
elif isinstance(then_expression, exp.Update):
if isinstance(then_expression.args.get("expressions"), exp.Star):
then = f"UPDATE {self.sql(then_expression, 'expressions')}"
else:
then = f"UPDATE SET {self.expressions(then_expression, flat=True)}"
else:
then = self.sql(then_expression)
return f"WHEN {matched}{source}{condition} THEN {then}"
def merge_sql(self, expression: exp.Merge) -> str:
table = expression.this
table_alias = ""
hints = table.args.get("hints")
if hints and table.alias and isinstance(hints[0], exp.WithTableHint):
# T-SQL syntax is MERGE ... <target_table> [WITH (<merge_hint>)] [[AS] table_alias]
table = table.copy()
table_alias = f" AS {self.sql(table.args['alias'].pop())}"
this = self.sql(table)
using = f"USING {self.sql(expression, 'using')}"
on = f"ON {self.sql(expression, 'on')}"
expressions = self.expressions(expression, sep=" ")
return f"MERGE INTO {this}{table_alias} {using} {on} {expressions}"
def tochar_sql(self, expression: exp.ToChar) -> str:
if expression.args.get("format"):
self.unsupported("Format argument unsupported for TO_CHAR/TO_VARCHAR function")
return self.sql(exp.cast(expression.this, "text"))
def dictproperty_sql(self, expression: exp.DictProperty) -> str:
this = self.sql(expression, "this")
kind = self.sql(expression, "kind")
settings_sql = self.expressions(expression, key="settings", sep=" ")
args = f"({self.sep('')}{settings_sql}{self.seg(')', sep='')}" if settings_sql else "()"
return f"{this}({kind}{args})"
def dictrange_sql(self, expression: exp.DictRange) -> str:
this = self.sql(expression, "this")
max = self.sql(expression, "max")
min = self.sql(expression, "min")
return f"{this}(MIN {min} MAX {max})"
def dictsubproperty_sql(self, expression: exp.DictSubProperty) -> str:
return f"{self.sql(expression, 'this')} {self.sql(expression, 'value')}"
def oncluster_sql(self, expression: exp.OnCluster) -> str:
return ""
def clusteredbyproperty_sql(self, expression: exp.ClusteredByProperty) -> str:
expressions = self.expressions(expression, key="expressions", flat=True)
sorted_by = self.expressions(expression, key="sorted_by", flat=True)
sorted_by = f" SORTED BY ({sorted_by})" if sorted_by else ""
buckets = self.sql(expression, "buckets")
return f"CLUSTERED BY ({expressions}){sorted_by} INTO {buckets} BUCKETS"
def anyvalue_sql(self, expression: exp.AnyValue) -> str:
this = self.sql(expression, "this")
having = self.sql(expression, "having")
if having:
this = f"{this} HAVING {'MAX' if expression.args.get('max') else 'MIN'} {having}"
return self.func("ANY_VALUE", this)
def querytransform_sql(self, expression: exp.QueryTransform) -> str:
transform = self.func("TRANSFORM", *expression.expressions)
row_format_before = self.sql(expression, "row_format_before")
row_format_before = f" {row_format_before}" if row_format_before else ""
record_writer = self.sql(expression, "record_writer")
record_writer = f" RECORDWRITER {record_writer}" if record_writer else ""
using = f" USING {self.sql(expression, 'command_script')}"
schema = self.sql(expression, "schema")
schema = f" AS {schema}" if schema else ""
row_format_after = self.sql(expression, "row_format_after")
row_format_after = f" {row_format_after}" if row_format_after else ""
record_reader = self.sql(expression, "record_reader")
record_reader = f" RECORDREADER {record_reader}" if record_reader else ""
return f"{transform}{row_format_before}{record_writer}{using}{schema}{row_format_after}{record_reader}"
def indexconstraintoption_sql(self, expression: exp.IndexConstraintOption) -> str:
key_block_size = self.sql(expression, "key_block_size")
if key_block_size:
return f"KEY_BLOCK_SIZE = {key_block_size}"
using = self.sql(expression, "using")
if using:
return f"USING {using}"
parser = self.sql(expression, "parser")
if parser:
return f"WITH PARSER {parser}"
comment = self.sql(expression, "comment")
if comment:
return f"COMMENT {comment}"
visible = expression.args.get("visible")
if visible is not None:
return "VISIBLE" if visible else "INVISIBLE"
engine_attr = self.sql(expression, "engine_attr")
if engine_attr:
return f"ENGINE_ATTRIBUTE = {engine_attr}"
secondary_engine_attr = self.sql(expression, "secondary_engine_attr")
if secondary_engine_attr:
return f"SECONDARY_ENGINE_ATTRIBUTE = {secondary_engine_attr}"
self.unsupported("Unsupported index constraint option.")
return ""
def indexcolumnconstraint_sql(self, expression: exp.IndexColumnConstraint) -> str:
kind = self.sql(expression, "kind")
kind = f"{kind} INDEX" if kind else "INDEX"
this = self.sql(expression, "this")
this = f" {this}" if this else ""
type_ = self.sql(expression, "type")
type_ = f" USING {type_}" if type_ else ""
schema = self.sql(expression, "schema")
schema = f" {schema}" if schema else ""
options = self.expressions(expression, key="options", sep=" ")
options = f" {options}" if options else ""
return f"{kind}{this}{type_}{schema}{options}"
def nvl2_sql(self, expression: exp.Nvl2) -> str:
if self.NVL2_SUPPORTED:
return self.function_fallback_sql(expression)
case = exp.Case().when(
expression.this.is_(exp.null()).not_(copy=False),
expression.args["true"].copy(),
copy=False,
)
else_cond = expression.args.get("false")
if else_cond:
case.else_(else_cond.copy(), copy=False)
return self.sql(case)
def comprehension_sql(self, expression: exp.Comprehension) -> str:
this = self.sql(expression, "this")
expr = self.sql(expression, "expression")
iterator = self.sql(expression, "iterator")
condition = self.sql(expression, "condition")
condition = f" IF {condition}" if condition else ""
return f"{this} FOR {expr} IN {iterator}{condition}"
def cached_generator(
cache: t.Optional[t.Dict[int, str]] = None
) -> t.Callable[[exp.Expression], str]:
"""Returns a cached generator."""
cache = {} if cache is None else cache
generator = Generator(normalize=True, identify="safe")

Step 2: ⌨️ Coding

dbt_core_integration.py

Add debugging statements to the `_compile_node` method to log the progress of the incremental logic macro execution.
--- 
+++ 
@@ -13,6 +13,21 @@
             else:
                 # this is essentially a convenient wrapper to adapter.get_compiler
                 compiled_node = self.sql_compiler.compile(self.dbt)
+
+            # Add debugging statements
+            self.dbtTerminal.debug(
+                "_compile_node",
+                f"Compiling node: {node.unique_id}",
+            )
+            self.dbtTerminal.debug(
+                "_compile_node",
+                f"Raw SQL: {getattr(compiled_node, RAW_CODE)}",
+            )
+            self.dbtTerminal.debug(
+                "_compile_node",
+                f"Compiled SQL: {getattr(compiled_node, COMPILED_CODE)}",
+            )
+
             return DbtAdapterCompilationResult(
                 getattr(compiled_node, RAW_CODE),
                 getattr(compiled_node, COMPILED_CODE),
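
The symptom in the report is that a whole branch of the macro's if/else disappears from the compiled SQL. At the Jinja level a branch is dropped whenever its guard evaluates falsy at compile time, which is exactly what the debug output above is meant to surface. A toy illustration with plain Jinja2 (not dbt's wrapper, just the underlying templating behavior):

# Toy example: a falsy guard removes the whole branch from the rendered SQL.
from jinja2 import Template

template = Template("select * from t{% if table_exists %} where id > 10{% endif %}")
print(template.render(table_exists=False))  # select * from t
print(template.render(table_exists=True))   # select * from t where id > 10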

dbt_core_integration.py

Verify that the `sql_compiler` is properly set up and configured to handle macros with if-else statements.
--- 
+++ 
@@ -7,4 +7,13 @@
             self._sql_compiler = SqlCompileRunner(
                 self.config, self.adapter, node=None, node_index=1, num_nodes=1
             )
+
+            # Verify sql_compiler configuration
+            assert self._sql_compiler.config.has_option("compile", "handle_macros"), (
+                "sql_compiler configuration missing 'handle_macros' option"
+            )
+            assert self._sql_compiler.config.get("compile", "handle_macros"), (
+                "sql_compiler not configured to handle macros"
+            )
+
         return self._sql_compiler

dbt_core_integration.py

Ensure that the `compile` method of the `sql_compiler` is being called with the correct arguments, including the dbt manifest.
--- 
+++ 
@@ -12,7 +12,7 @@
                 )
             else:
                 # this is essentially a convenient wrapper to adapter.get_compiler
-                compiled_node = self.sql_compiler.compile(self.dbt)
+                compiled_node = self.sql_compiler.compile(manifest=self.dbt)
             return DbtAdapterCompilationResult(
                 getattr(compiled_node, RAW_CODE),
                 getattr(compiled_node, COMPILED_CODE),
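
For reference, dbt-core's compile runner takes the manifest as its only argument, so the positional and keyword calls are interchangeable only if that parameter is actually named manifest. A quick way to confirm this against the installed dbt-core version (a sketch, assuming dbt-core 1.6, where SqlCompileRunner lives in dbt.task.sql):

# Print the signature of the compile method to confirm the parameter name
# before relying on the keyword form.
import inspect

from dbt.task.sql import SqlCompileRunner

print(inspect.signature(SqlCompileRunner.compile))  # expected: (self, manifest)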

dbt_core_integration.py

Add logging statements in the `_compile_node` method to capture the compiled code generated by dbt and the plugin.
--- 
+++ 
@@ -13,10 +13,25 @@
             else:
                 # this is essentially a convenient wrapper to adapter.get_compiler
                 compiled_node = self.sql_compiler.compile(self.dbt)
-            return DbtAdapterCompilationResult(
+
+            # Log compiled code
+            self.dbtTerminal.debug(
+                "_compile_node",
+                f"DBT compiled code: {getattr(compiled_node, COMPILED_CODE)}",
+            )
+
+            result = DbtAdapterCompilationResult(
                 getattr(compiled_node, RAW_CODE),
                 getattr(compiled_node, COMPILED_CODE),
                 compiled_node,
             )
+
+            # Log plugin compiled code
+            self.dbtTerminal.debug(
+                "_compile_node",
+                f"Plugin compiled code: {result.compiled_sql}",
+            )
+
+            return result
         except Exception as e:
             raise Exception(str(e))
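
Independently of the extension, the two compiled outputs can also be diffed directly: dbt compile writes its result under target/compiled/, and the extension's preview can be saved to a file. A minimal comparison sketch (both file paths below are made-up placeholders):

# Compare the SQL written by `dbt compile` with the SQL the extension previews.
import difflib
from pathlib import Path

dbt_sql = Path("target/compiled/my_project/models/dim_example.sql").read_text()
plugin_sql = Path("plugin_preview.sql").read_text()  # saved from the preview pane

for line in difflib.unified_diff(
    dbt_sql.splitlines(),
    plugin_sql.splitlines(),
    fromfile="dbt compile",
    tofile="power user preview",
    lineterm="",
):
    print(line)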

Step 3: 🔄️ Validating

Your changes have been successfully made to the branch sweep/incorrect_compile_code_from_dbt_power_us. I have validated these changes using a syntax checker and a linter.


Tip

To recreate the pull request, edit the issue title or description.

This is an automated message generated by Sweep AI.

@anandgupta42
Contributor

@mdesmet Can you have a look?

@mdesmet
Contributor

mdesmet commented Jul 5, 2024

@ashutran: Can you provide us with sample code that reproduces the issue(s)? I don't think we have the full picture of what you are trying to achieve.
