diff --git a/.env.example b/.env.example
index ce2c71d..5724bc8 100644
--- a/.env.example
+++ b/.env.example
@@ -153,3 +153,12 @@ PGVECTOR_COLLECTION=table_info_db
 
 # VectorDB 설정
 VECTORDB_TYPE=faiss # faiss 또는 pgvector
+
+
+# TRINO_HOST=localhost
+# TRINO_PORT=8080
+# TRINO_USER=admin
+# TRINO_PASSWORD=password
+# TRINO_HTTP_SCHEME=http # password is only used when this is https
+# TRINO_CATALOG=delta
+# TRINO_SCHEMA=default
diff --git a/db_utils/__init__.py b/db_utils/__init__.py
index e8502b8..f5a3901 100644
--- a/db_utils/__init__.py
+++ b/db_utils/__init__.py
@@ -13,6 +13,7 @@ from .duckdb_connector import DuckDBConnector
 from .databricks_connector import DatabricksConnector
 from .snowflake_connector import SnowflakeConnector
+from .trino_connector import TrinoConnector
 
 env_path = os.path.join(os.getcwd(), ".env")
 
 
@@ -54,6 +55,7 @@ def get_db_connector(db_type: Optional[str] = None, config: Optional[DBConfig] =
         "duckdb": DuckDBConnector,
         "databricks": DatabricksConnector,
         "snowflake": SnowflakeConnector,
+        "trino": TrinoConnector,
     }
 
     if db_type not in connector_map:
diff --git a/db_utils/trino_connector.py b/db_utils/trino_connector.py
new file mode 100644
index 0000000..b471674
--- /dev/null
+++ b/db_utils/trino_connector.py
@@ -0,0 +1,149 @@
+"""Trino connector that runs SQL and returns pandas DataFrames."""
+
+import pandas as pd
+from .base_connector import BaseConnector
+from .config import DBConfig
+from .logger import logger
+
+try:
+    import trino
+except Exception as e:  # pragma: no cover
+    # Keep this module importable without the driver; connect() re-raises.
+    trino = None
+    _import_error = e
+
+
+class TrinoConnector(BaseConnector):
+    """
+    Connect to Trino and execute SQL queries.
+    """
+
+    # DB-API connection; None until connect() succeeds and again after close().
+    connection = None
+
+    def __init__(self, config: DBConfig):
+        """
+        Initialize the TrinoConnector with connection parameters.
+
+        Parameters:
+            config (DBConfig): Configuration object containing connection
+                parameters. "host" is required; "port" falls back to 8080 and
+                "user" to "anonymous". "extra" may carry "http_scheme"
+                (default "http"), "catalog" and "schema".
+        """
+        self.host = config["host"]
+        # .get() so a missing key falls back to the default like a None value.
+        self.port = config.get("port") or 8080
+        self.user = config.get("user") or "anonymous"
+        self.password = config.get("password")
+        self.database = config.get("database")  # e.g., catalog.schema
+        self.extra = config.get("extra") or {}
+        # Normalize case so "HTTPS"/"HTTP" select the intended auth branch.
+        self.http_scheme = str(self.extra.get("http_scheme") or "http").lower()
+        self.catalog = self.extra.get("catalog")
+        self.schema = self.extra.get("schema")
+
+        # If database given as "catalog.schema", split into fields
+        if self.database and (not self.catalog or not self.schema):
+            if "." in self.database:
+                db_catalog, db_schema = self.database.split(".", 1)
+                self.catalog = self.catalog or db_catalog
+                self.schema = self.schema or db_schema
+
+        self.connect()
+
+    def connect(self) -> None:
+        """
+        Establish a connection to the Trino cluster.
+
+        Basic authentication is only configured when the scheme is not plain
+        HTTP; over HTTP a configured password is ignored with a warning.
+
+        Raises:
+            Exception: The stored import error when the trino driver could not
+                be imported, or any error raised while creating the connection.
+        """
+        if trino is None:
+            logger.error(f"Failed to import trino driver: {_import_error}")
+            raise _import_error
+
+        try:
+            auth = None
+            if self.password:
+                # If HTTP, ignore password to avoid insecure auth error
+                if self.http_scheme == "http":
+                    logger.warning(
+                        "Password provided for HTTP; ignoring password. "
+                        "Set TRINO_HTTP_SCHEME=https to enable password authentication."
+                    )
+                else:
+                    # Basic auth over HTTPS
+                    auth = trino.auth.BasicAuthentication(self.user, self.password)
+
+            self.connection = trino.dbapi.connect(
+                host=self.host,
+                port=self.port,
+                user=self.user,
+                http_scheme=self.http_scheme,
+                catalog=self.catalog,
+                schema=self.schema,
+                auth=auth,
+                # Optional: session properties
+                # session_properties={}
+            )
+            logger.info("Successfully connected to Trino.")
+        except Exception as e:
+            logger.error(f"Failed to connect to Trino: {e}")
+            raise
+
+    def run_sql(self, sql: str) -> pd.DataFrame:
+        """
+        Execute a SQL query and return the result as a pandas DataFrame.
+
+        Parameters:
+            sql (str): SQL query string to be executed.
+
+        Returns:
+            pd.DataFrame: Result of the SQL query as a pandas DataFrame; empty
+                when the cursor reports no result description.
+
+        Raises:
+            Exception: Any failure is logged and re-raised unchanged.
+        """
+        # Bind before the try so the finally block never sees an unbound name.
+        cursor = None
+        try:
+            cursor = self.connection.cursor()
+            cursor.execute(sql)
+            description = cursor.description
+            columns = [desc[0] for desc in description] if description else []
+            rows = cursor.fetchall() if description else []
+            return pd.DataFrame(rows, columns=columns)
+        except Exception as e:
+            logger.error(f"Failed to execute SQL query on Trino: {e}")
+            raise
+        finally:
+            if cursor is not None:
+                try:
+                    cursor.close()
+                except Exception as close_error:
+                    # Best-effort cleanup; must not mask the query's outcome.
+                    logger.warning(
+                        f"Failed to close Trino cursor: {close_error}"
+                    )
+
+    def close(self) -> None:
+        """
+        Close the connection to the Trino cluster.
+
+        A failure while closing is logged; the connection reference is
+        cleared either way, so repeated calls are no-ops.
+        """
+        if self.connection:
+            try:
+                self.connection.close()
+            except Exception as e:
+                # Best-effort: report the failure but still drop the reference.
+                logger.warning(f"Error while closing Trino connection: {e}")
+            logger.info("Connection to Trino closed.")
+            self.connection = None
diff --git a/pyproject.toml b/pyproject.toml
index 3452f24..3f310f7 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -45,6 +45,7 @@ dependencies = [
     "google-cloud-bigquery>=3.20.1,<4.0.0",
     "pgvector==0.3.6",
     "langchain-postgres==0.0.15",
+    "trino>=0.329.0,<1.0.0",
 ]
 
 [project.scripts]
@@ -82,4 +83,3 @@ dev-dependencies = [
     "pytest>=8.3.5",
 ]
 
-