diff --git a/README.md b/README.md index 40cc8b2581..f4657f94df 100644 --- a/README.md +++ b/README.md @@ -5,17 +5,17 @@ - [ Astro :rocket: ](#astro-rocket) +- [Overview](#overview) - [Philosophy](#philosophy) - [Setup](#setup) - [Using Astro as a SQL Engineer](#using-astro-as-a-sql-engineer) - [Schemas](#schemas) - [Setting up SQL files](#setting-up-sql-files) - [Using Astro as a Python Engineer](#using-astro-as-a-python-engineer) - - [The output_table parameter](#the-output_table-parameter) + - [Setting Input and Output Tables](#setting-input-and-output-tables) - [Loading Data](#loading-data) - [Transform](#transform) - - [Transform File](#transform-file) - - [Raw SQL](#raw-sql) + - [Putting it All Together](#putting-it-all-together) - [Other SQL functions](#other-sql-functions) - [Appending data](#appending-data) - [Merging data](#merging-data) diff --git a/src/astro/sql/operators/agnostic_load_file.py b/src/astro/sql/operators/agnostic_load_file.py index 6d0f27949a..4255697cd8 100644 --- a/src/astro/sql/operators/agnostic_load_file.py +++ b/src/astro/sql/operators/agnostic_load_file.py @@ -18,14 +18,13 @@ from typing import Union from urllib.parse import urlparse -import boto3 import pandas as pd from airflow.hooks.base import BaseHook -from airflow.models import BaseOperator, DagRun, TaskInstance -from google.cloud.storage import Client +from airflow.models import BaseOperator from smart_open import open from astro.sql.table import Table, TempTable, create_table_name +from astro.utils.cloud_storage_creds import gcs_client, s3fs_creds from astro.utils.load_dataframe import move_dataframe_to_sql from astro.utils.schema_util import get_schema from astro.utils.task_id_helper import get_task_id @@ -110,8 +109,8 @@ def _load_dataframe(self, path): file_type = path.split(".")[-1] transport_params = { - "s3": self._s3fs_creds, - "gs": self._gcs_client, + "s3": s3fs_creds, + "gs": gcs_client, "": lambda: None, }[urlparse(path).scheme]() deserialiser = { @@ -126,32 +125,6 @@ def _load_dataframe(self, path): stream, **deserialiser_params.get(file_type, {}) ) - def _s3fs_creds(self): - # To-do: reuse this method from sql decorator - """Structure s3fs credentials from Airflow connection. - s3fs enables pandas to write to s3 - """ - # To-do: clean-up how S3 creds are passed to s3fs - k, v = ( - os.environ["AIRFLOW__ASTRO__CONN_AWS_DEFAULT"] - .replace("%2F", "/") - .replace("aws://", "") - .replace("@", "") - .split(":") - ) - session = boto3.Session( - aws_access_key_id=k, - aws_secret_access_key=v, - ) - return dict(client=session.client("s3")) - - def _gcs_client(self): - """ - get GCS credentials for storage. 
- """ - client = Client() - return dict(client=client) - def load_file( path, diff --git a/src/astro/sql/operators/agnostic_save_file.py b/src/astro/sql/operators/agnostic_save_file.py index 4a370314d3..7cf393eb5b 100644 --- a/src/astro/sql/operators/agnostic_save_file.py +++ b/src/astro/sql/operators/agnostic_save_file.py @@ -27,6 +27,7 @@ from astro.sql.operators.temp_hooks import TempPostgresHook, TempSnowflakeHook from astro.sql.table import Table +from astro.utils.cloud_storage_creds import gcs_client, s3fs_creds from astro.utils.task_id_helper import get_task_id @@ -105,8 +106,8 @@ def execute(self, context): def file_exists(self, output_file_path, output_conn_id=None): transport_params = { - "s3": self._s3fs_creds, - "gs": self._gcs_client, + "s3": s3fs_creds, + "gs": gcs_client, "": lambda: None, }[urlparse(output_file_path).scheme]() try: @@ -121,8 +122,8 @@ def agnostic_write_file(self, df, output_file_path, output_conn_id=None): Select output file format based on param output_file_format to class. """ transport_params = { - "s3": self._s3fs_creds, - "gs": self._gcs_client, + "s3": s3fs_creds, + "gs": gcs_client, "": lambda: None, }[urlparse(output_file_path).scheme]() @@ -142,32 +143,6 @@ def agnostic_write_file(self, df, output_file_path, output_conn_id=None): stream, **serialiser_params.get(self.output_file_format, {}) ) - def _s3fs_creds(self): - # To-do: reuse this method from sql decorator - """Structure s3fs credentials from Airflow connection. - s3fs enables pandas to write to s3 - """ - # To-do: clean-up how S3 creds are passed to s3fs - k, v = ( - os.environ["AIRFLOW__ASTRO__CONN_AWS_DEFAULT"] - .replace("%2F", "/") - .replace("aws://", "") - .replace("@", "") - .split(":") - ) - session = boto3.Session( - aws_access_key_id=k, - aws_secret_access_key=v, - ) - return dict(client=session.client("s3")) - - def _gcs_client(self): - """ - get GCS credentials for storage - """ - client = Client() - return dict(client=client) - @staticmethod def create_table_name(context): ti: TaskInstance = context["ti"] diff --git a/src/astro/sql/operators/temp_hooks.py b/src/astro/sql/operators/temp_hooks.py index bb1048f5a0..29ef861efe 100644 --- a/src/astro/sql/operators/temp_hooks.py +++ b/src/astro/sql/operators/temp_hooks.py @@ -28,7 +28,7 @@ def get_uri(self) -> str: """Override DbApiHook get_uri method for get_sqlalchemy_engine()""" conn_config = self._get_conn_params() uri = ( - "snowflake://{user}:{password}@{account}.{region}/{database}/{schema}" + "snowflake://{user}:{password}@{account}/{database}/{schema}" "?warehouse={warehouse}&role={role}&authenticator={authenticator}" ) return uri.format(**conn_config) diff --git a/src/astro/utils/cloud_storage_creds.py b/src/astro/utils/cloud_storage_creds.py new file mode 100644 index 0000000000..9cdb81769e --- /dev/null +++ b/src/astro/utils/cloud_storage_creds.py @@ -0,0 +1,39 @@ +import os +from urllib import parse + +import boto3 +from google.cloud.storage import Client + + +def parse_s3_env_var(): + raw_data = ( + os.environ["AIRFLOW__ASTRO__CONN_AWS_DEFAULT"] + .replace("%2F", "/") + .replace("aws://", "") + .replace("@", "") + .split(":") + ) + return [parse.unquote(r) for r in raw_data] + + +def s3fs_creds(): + # To-do: reuse this method from sql decorator + """Structure s3fs credentials from Airflow connection. 
+ s3fs enables pandas to write to s3 + """ + # To-do: clean-up how S3 creds are passed to s3fs + + k, v = parse_s3_env_var() + session = boto3.Session( + aws_access_key_id=k, + aws_secret_access_key=v, + ) + return dict(client=session.client("s3")) + + +def gcs_client(): + """ + get GCS credentials for storage. + """ + client = Client() + return dict(client=client) diff --git a/tests/operators/test_agnostic_load_file.py b/tests/operators/test_agnostic_load_file.py index 7409889981..35d2a4b632 100644 --- a/tests/operators/test_agnostic_load_file.py +++ b/tests/operators/test_agnostic_load_file.py @@ -28,6 +28,7 @@ import os import pathlib import unittest.mock +from unittest import mock import pandas as pd import pytest @@ -506,6 +507,19 @@ def sql_server(request): hook.run(f"DROP TABLE IF EXISTS {schema}.{OUTPUT_TABLE_NAME}") +@mock.patch.dict( + os.environ, + { + "AIRFLOW__ASTRO__CONN_AWS_DEFAULT": "abcd:%40%23%24%25%40%24%23ASDH%40Ksd23%25SD546@" + }, +) +def test_aws_decode(): + from astro.utils.cloud_storage_creds import parse_s3_env_var + + k, v = parse_s3_env_var() + assert v == "@#$%@$#ASDH@Ksd23%SD546" + + @pytest.mark.parametrize("sql_server", ["snowflake", "postgres"], indirect=True) @pytest.mark.parametrize("file_type", ["ndjson", "json", "csv"]) def test_load_file(sample_dag, sql_server, file_type):
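
Note: the sketch below is illustrative and not part of the diff. It shows how the helpers consolidated into the new astro/utils/cloud_storage_creds.py module are consumed by both operators: parse_s3_env_var() decodes the percent-encoded key/secret from AIRFLOW__ASTRO__CONN_AWS_DEFAULT (the behavior test_aws_decode asserts), and s3fs_creds() / gcs_client() build the transport_params dict that smart_open expects. The connection string, region, bucket, and file path used here are hypothetical placeholders, not values from this change.

import os
from urllib.parse import urlparse

from astro.utils.cloud_storage_creds import gcs_client, parse_s3_env_var, s3fs_creds

# Placeholder credentials; percent-encoded characters in the secret are
# decoded by parse_s3_env_var (the same behavior test_aws_decode checks).
os.environ["AIRFLOW__ASTRO__CONN_AWS_DEFAULT"] = "aws://my_key:my%40secret@"
os.environ.setdefault("AWS_DEFAULT_REGION", "us-east-1")  # boto3 needs a resolvable region

access_key, secret_key = parse_s3_env_var()
assert (access_key, secret_key) == ("my_key", "my@secret")

# Both operators select smart_open transport params by URI scheme, so the
# boto3 / google-cloud-storage plumbing now lives in a single module.
path = "s3://my-bucket/data.csv"  # placeholder path
transport_params = {
    "s3": s3fs_creds,   # -> dict(client=<boto3 S3 client>)
    "gs": gcs_client,   # -> dict(client=<google.cloud.storage.Client>)
    "": lambda: None,   # local paths need no transport params
}[urlparse(path).scheme]()

# With valid credentials and an existing object, the operators then read or
# write through smart_open, e.g.:
#     with smart_open.open(path, transport_params=transport_params) as stream:
#         df = pd.read_csv(stream)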