Fix bugs in loading and saving from s3 + snowflake region (#46)
* Fix bugs in loading and saving from s3

* fix lint

* Fix minor bug

* Fix minor bug
dimberman committed Jan 21, 2022
1 parent 1a59828 commit 541038c
Showing 6 changed files with 66 additions and 65 deletions.
6 changes: 3 additions & 3 deletions README.md
@@ -5,17 +5,17 @@
- [
Astro :rocket:
](#astro-rocket)
- [Overview](#overview)
- [Philosophy](#philosophy)
- [Setup](#setup)
- [Using Astro as a SQL Engineer](#using-astro-as-a-sql-engineer)
- [Schemas](#schemas)
- [Setting up SQL files](#setting-up-sql-files)
- [Using Astro as a Python Engineer](#using-astro-as-a-python-engineer)
- [The output_table parameter](#the-output_table-parameter)
- [Setting Input and Output Tables](#setting-input-and-output-tables)
- [Loading Data](#loading-data)
- [Transform](#transform)
- [Transform File](#transform-file)
- [Raw SQL](#raw-sql)
- [Putting it All Together](#putting-it-all-together)
- [Other SQL functions](#other-sql-functions)
- [Appending data](#appending-data)
- [Merging data](#merging-data)
35 changes: 4 additions & 31 deletions src/astro/sql/operators/agnostic_load_file.py
@@ -18,14 +18,13 @@
 from typing import Union
 from urllib.parse import urlparse
 
-import boto3
 import pandas as pd
 from airflow.hooks.base import BaseHook
-from airflow.models import BaseOperator, DagRun, TaskInstance
-from google.cloud.storage import Client
+from airflow.models import BaseOperator
 from smart_open import open
 
 from astro.sql.table import Table, TempTable, create_table_name
+from astro.utils.cloud_storage_creds import gcs_client, s3fs_creds
 from astro.utils.load_dataframe import move_dataframe_to_sql
 from astro.utils.schema_util import get_schema
 from astro.utils.task_id_helper import get_task_id
@@ -110,8 +109,8 @@ def _load_dataframe(self, path):

         file_type = path.split(".")[-1]
         transport_params = {
-            "s3": self._s3fs_creds,
-            "gs": self._gcs_client,
+            "s3": s3fs_creds,
+            "gs": gcs_client,
             "": lambda: None,
         }[urlparse(path).scheme]()
         deserialiser = {
@@ -126,32 +125,6 @@ def _load_dataframe(self, path):
             stream, **deserialiser_params.get(file_type, {})
         )

-    def _s3fs_creds(self):
-        # To-do: reuse this method from sql decorator
-        """Structure s3fs credentials from Airflow connection.
-        s3fs enables pandas to write to s3
-        """
-        # To-do: clean-up how S3 creds are passed to s3fs
-        k, v = (
-            os.environ["AIRFLOW__ASTRO__CONN_AWS_DEFAULT"]
-            .replace("%2F", "/")
-            .replace("aws://", "")
-            .replace("@", "")
-            .split(":")
-        )
-        session = boto3.Session(
-            aws_access_key_id=k,
-            aws_secret_access_key=v,
-        )
-        return dict(client=session.client("s3"))
-
-    def _gcs_client(self):
-        """
-        get GCS credentials for storage.
-        """
-        client = Client()
-        return dict(client=client)


 def load_file(
     path,
35 changes: 5 additions & 30 deletions src/astro/sql/operators/agnostic_save_file.py
@@ -27,6 +27,7 @@

 from astro.sql.operators.temp_hooks import TempPostgresHook, TempSnowflakeHook
 from astro.sql.table import Table
+from astro.utils.cloud_storage_creds import gcs_client, s3fs_creds
 from astro.utils.task_id_helper import get_task_id


@@ -105,8 +106,8 @@ def execute(self, context):

     def file_exists(self, output_file_path, output_conn_id=None):
         transport_params = {
-            "s3": self._s3fs_creds,
-            "gs": self._gcs_client,
+            "s3": s3fs_creds,
+            "gs": gcs_client,
             "": lambda: None,
         }[urlparse(output_file_path).scheme]()
         try:
@@ -121,8 +122,8 @@ def agnostic_write_file(self, df, output_file_path, output_conn_id=None):
         Select output file format based on param output_file_format to class.
         """
         transport_params = {
-            "s3": self._s3fs_creds,
-            "gs": self._gcs_client,
+            "s3": s3fs_creds,
+            "gs": gcs_client,
             "": lambda: None,
         }[urlparse(output_file_path).scheme]()

@@ -142,32 +143,6 @@ def agnostic_write_file(self, df, output_file_path, output_conn_id=None):
             stream, **serialiser_params.get(self.output_file_format, {})
         )

-    def _s3fs_creds(self):
-        # To-do: reuse this method from sql decorator
-        """Structure s3fs credentials from Airflow connection.
-        s3fs enables pandas to write to s3
-        """
-        # To-do: clean-up how S3 creds are passed to s3fs
-        k, v = (
-            os.environ["AIRFLOW__ASTRO__CONN_AWS_DEFAULT"]
-            .replace("%2F", "/")
-            .replace("aws://", "")
-            .replace("@", "")
-            .split(":")
-        )
-        session = boto3.Session(
-            aws_access_key_id=k,
-            aws_secret_access_key=v,
-        )
-        return dict(client=session.client("s3"))
-
-    def _gcs_client(self):
-        """
-        get GCS credentials for storage
-        """
-        client = Client()
-        return dict(client=client)

     @staticmethod
     def create_table_name(context):
         ti: TaskInstance = context["ti"]
2 changes: 1 addition & 1 deletion src/astro/sql/operators/temp_hooks.py
@@ -28,7 +28,7 @@ def get_uri(self) -> str:
         """Override DbApiHook get_uri method for get_sqlalchemy_engine()"""
         conn_config = self._get_conn_params()
         uri = (
-            "snowflake://{user}:{password}@{account}.{region}/{database}/{schema}"
+            "snowflake://{user}:{password}@{account}/{database}/{schema}"
             "?warehouse={warehouse}&role={role}&authenticator={authenticator}"
         )
         return uri.format(**conn_config)
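For context, a minimal sketch of the URI the corrected template now produces. The connection values below are hypothetical, and the sketch assumes the Snowflake account identifier already carries the region locator (e.g. an account of the form "xy12345.us-east-1"), so a separate {region} path segment is no longer appended:

conn_config = {
    "user": "astro_user",            # hypothetical values, not from this commit
    "password": "secret",
    "account": "xy12345.us-east-1",  # region folded into the account identifier
    "database": "ASTRO_DB",
    "schema": "PUBLIC",
    "warehouse": "COMPUTE_WH",
    "role": "SYSADMIN",
    "authenticator": "snowflake",
}
uri = (
    "snowflake://{user}:{password}@{account}/{database}/{schema}"
    "?warehouse={warehouse}&role={role}&authenticator={authenticator}"
).format(**conn_config)
# -> snowflake://astro_user:secret@xy12345.us-east-1/ASTRO_DB/PUBLIC?warehouse=COMPUTE_WH&role=SYSADMIN&authenticator=snowflake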
39 changes: 39 additions & 0 deletions src/astro/utils/cloud_storage_creds.py
@@ -0,0 +1,39 @@
import os
from urllib import parse

import boto3
from google.cloud.storage import Client


def parse_s3_env_var():
    raw_data = (
        os.environ["AIRFLOW__ASTRO__CONN_AWS_DEFAULT"]
        .replace("%2F", "/")
        .replace("aws://", "")
        .replace("@", "")
        .split(":")
    )
    return [parse.unquote(r) for r in raw_data]


def s3fs_creds():
    # To-do: reuse this method from sql decorator
    """Structure s3fs credentials from Airflow connection.
    s3fs enables pandas to write to s3
    """
    # To-do: clean-up how S3 creds are passed to s3fs

    k, v = parse_s3_env_var()
    session = boto3.Session(
        aws_access_key_id=k,
        aws_secret_access_key=v,
    )
    return dict(client=session.client("s3"))


def gcs_client():
    """
    get GCS credentials for storage.
    """
    client = Client()
    return dict(client=client)
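These helpers are consumed by the operators above through smart_open's transport_params. A minimal usage sketch, assuming smart_open 5.x (which accepts a client object in transport_params) and AIRFLOW__ASTRO__CONN_AWS_DEFAULT being set; read_remote_csv is a hypothetical helper, not part of this commit:

from urllib.parse import urlparse

import pandas as pd
from smart_open import open as smart_open_file

from astro.utils.cloud_storage_creds import gcs_client, s3fs_creds


def read_remote_csv(path: str) -> pd.DataFrame:
    # Pick credentials by URI scheme, mirroring the operators above;
    # local paths (empty scheme) need no transport parameters.
    transport_params = {
        "s3": s3fs_creds,
        "gs": gcs_client,
        "": lambda: None,
    }[urlparse(path).scheme]()
    with smart_open_file(path, "r", transport_params=transport_params) as stream:
        return pd.read_csv(stream)

# e.g. read_remote_csv("s3://my-bucket/data.csv")  # hypothetical bucket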
14 changes: 14 additions & 0 deletions tests/operators/test_agnostic_load_file.py
@@ -28,6 +28,7 @@
 import os
 import pathlib
 import unittest.mock
+from unittest import mock
 
 import pandas as pd
 import pytest
@@ -506,6 +507,19 @@ def sql_server(request):
     hook.run(f"DROP TABLE IF EXISTS {schema}.{OUTPUT_TABLE_NAME}")
 
 
+@mock.patch.dict(
+    os.environ,
+    {
+        "AIRFLOW__ASTRO__CONN_AWS_DEFAULT": "abcd:%40%23%24%25%40%24%23ASDH%40Ksd23%25SD546@"
+    },
+)
+def test_aws_decode():
+    from astro.utils.cloud_storage_creds import parse_s3_env_var
+
+    k, v = parse_s3_env_var()
+    assert v == "@#$%@$#ASDH@Ksd23%SD546"


 @pytest.mark.parametrize("sql_server", ["snowflake", "postgres"], indirect=True)
 @pytest.mark.parametrize("file_type", ["ndjson", "json", "csv"])
 def test_load_file(sample_dag, sql_server, file_type):
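The test_aws_decode test above exercises the underlying fix: the AWS key and secret arrive percent-encoded in the connection string, and the old operator code handed them to boto3 un-decoded, which is presumably the S3 load/save bug this commit addresses; parse_s3_env_var now URL-decodes them first. A minimal standalone sketch of that decoding, reusing the same fixture value as the test:

from urllib import parse

# Same value the test patches into AIRFLOW__ASTRO__CONN_AWS_DEFAULT.
raw = "abcd:%40%23%24%25%40%24%23ASDH%40Ksd23%25SD546@"
key, secret = [
    parse.unquote(part)
    for part in raw.replace("%2F", "/").replace("aws://", "").replace("@", "").split(":")
]
assert key == "abcd"
assert secret == "@#$%@$#ASDH@Ksd23%SD546"  # decoded secret, ready for boto3.Session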
