Skip to content

Commit

Permalink
Get the file location based on path and conn_id (#1478)
Browse files Browse the repository at this point in the history
# Description
## What is the current behavior?
<!-- Please describe the current behavior that you are modifying. -->
Currently `create_file_location` picks the module path based on
`filetype.value` even though a `conn_id` is passed.
https://github.com/astronomer/astro-sdk/blob/main/python-sdk/src/astro/files/locations/__init__.py#L23
`get_location_type` doesn't fetch the file location based on the `conn_type`
of the given `conn_id` at all; it simply derives it from the path.

<!--
Issues are required for both bug fixes and features.
Reference it using one of the following:

closes: #ISSUE
related: #ISSUE
-->
closes: #1471 


## What is the new behavior?
<!-- Please describe the behavior or changes that are being added by
this PR. -->
- Add `validate_conn()` to the `BaseFileLocation` class to validate that the
connection type matches the file path

## Does this introduce a breaking change?
No

### Checklist
- [x] Created tests which fail without the change (if possible)
- [x] Extended the README / documentation, if necessary

Co-authored-by: Kaxil Naik <kaxilnaik@gmail.com>
  • Loading branch information
sunank200 and kaxil committed Dec 29, 2022
1 parent e96ab04 commit 91a8b86
Show file tree
Hide file tree
Showing 16 changed files with 68 additions and 37 deletions.
3 changes: 2 additions & 1 deletion python-sdk/docs/development/CONTRIBUTING.md
Original file line number Diff line number Diff line change
Expand Up @@ -79,5 +79,6 @@ Before you submit a pull request (PR), check that it meets these guidelines:

- Run tests locally before opening PR.

- Adhere to guidelines for commit messages described in this [article](http://chris.beams.io/posts/git-commit/).
- Adhere to guidelines for commit messages described in this
[article](https://cbea.ms/git-commit/).
This makes the lives of those who come after you a lot easier.
10 changes: 5 additions & 5 deletions python-sdk/src/astro/files/locations/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,10 @@
from astro.files.locations.base import BaseFileLocation
from astro.utils.path import get_class_name, get_dict_with_module_names_to_dot_notations

DEFAULT_CONN_TYPE_TO_MODULE_PATH = get_dict_with_module_names_to_dot_notations(Path(__file__))
DEFAULT_CONN_TYPE_TO_MODULE_PATH["https"] = DEFAULT_CONN_TYPE_TO_MODULE_PATH["http"]
DEFAULT_CONN_TYPE_TO_MODULE_PATH["gs"] = DEFAULT_CONN_TYPE_TO_MODULE_PATH["gcs"]
DEFAULT_CONN_TYPE_TO_MODULE_PATH["wasbs"] = DEFAULT_CONN_TYPE_TO_MODULE_PATH["wasb"]
DEFAULT_FILE_SCHEME_TO_MODULE_PATH = get_dict_with_module_names_to_dot_notations(Path(__file__))
DEFAULT_FILE_SCHEME_TO_MODULE_PATH["https"] = DEFAULT_FILE_SCHEME_TO_MODULE_PATH["http"]
DEFAULT_FILE_SCHEME_TO_MODULE_PATH["gs"] = DEFAULT_FILE_SCHEME_TO_MODULE_PATH["gcs"]
DEFAULT_FILE_SCHEME_TO_MODULE_PATH["wasbs"] = DEFAULT_FILE_SCHEME_TO_MODULE_PATH["wasb"]


def create_file_location(path: str, conn_id: Optional[str] = None) -> BaseFileLocation:
Expand All @@ -20,7 +20,7 @@ def create_file_location(path: str, conn_id: Optional[str] = None) -> BaseFileLo
:param conn_id: Airflow connection ID
"""
filetype: FileLocation = BaseFileLocation.get_location_type(path)
module_path = DEFAULT_CONN_TYPE_TO_MODULE_PATH[filetype.value]
module_path = DEFAULT_FILE_SCHEME_TO_MODULE_PATH[filetype.value]
module_ref = importlib.import_module(module_path)
class_name = get_class_name(module_ref)
location: BaseFileLocation = getattr(module_ref, class_name)(path, conn_id)
Expand Down
1 change: 1 addition & 0 deletions python-sdk/src/astro/files/locations/amazon/s3.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ class S3Location(BaseFileLocation):
"""Handler S3 object store operations"""

location_type = FileLocation.S3
supported_conn_type = {S3Hook.conn_type, "aws"}

@property
def hook(self) -> S3Hook:
Expand Down
1 change: 1 addition & 0 deletions python-sdk/src/astro/files/locations/azure/wasb.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ class WASBLocation(BaseFileLocation):
"""Handler WASB object store operations"""

location_type = FileLocation.WASB
supported_conn_type = {WasbHook.conn_type, "wasbs"}

@property
def hook(self) -> WasbHook:
Expand Down
15 changes: 15 additions & 0 deletions python-sdk/src/astro/files/locations/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from urllib.parse import urlparse

import smart_open
from airflow.hooks.base import BaseHook

from astro.constants import FileLocation

Expand All @@ -15,6 +16,7 @@ class BaseFileLocation(ABC):
"""Base Location abstract class"""

template_fields = ("path", "conn_id")
supported_conn_type: set[str] = set()

def __init__(self, path: str, conn_id: str | None = None):
"""
Expand All @@ -25,6 +27,19 @@ def __init__(self, path: str, conn_id: str | None = None):
"""
self.path: str = path
self.conn_id: str | None = conn_id
self.validate_conn()

def validate_conn(self):
    """Check if the conn_id matches with provided path.

    Validation is skipped when no ``conn_id`` was supplied. Otherwise the
    Airflow connection's ``conn_type`` must appear in this location class's
    ``supported_conn_type`` set.

    :raises ValueError: if the connection type is not supported for the path.
    """
    if not self.conn_id:
        return

    # Resolve the connection from Airflow's metadata store and compare its
    # type against the set declared on the concrete location subclass.
    connection_type = BaseHook.get_connection(self.conn_id).conn_type
    if connection_type in self.supported_conn_type:
        return
    raise ValueError(
        f"Connection type {connection_type} is not supported for {self.path}. "
        f"Supported types are {self.supported_conn_type}"
    )

@property
def hook(self):
Expand Down
2 changes: 2 additions & 0 deletions python-sdk/src/astro/files/locations/google/gcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@ class GCSLocation(BaseFileLocation):
"""Handler GS object store operations"""

location_type = FileLocation.GS
# TODO: Restrict the supported conn_type to only GCSHook.conn_type
supported_conn_type = {GCSHook.conn_type, "gcpbigquery", "bigquery"}

@property
def hook(self) -> GCSHook:
Expand Down
1 change: 1 addition & 0 deletions python-sdk/src/astro/files/locations/google/gdrive.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ class GdriveLocation(BaseFileLocation):
"""Handler for Google Drive operators."""

location_type = FileLocation.GOOGLE_DRIVE
supported_conn_type = {GoogleDriveHook.conn_type}

@property
def hook(self) -> GoogleDriveHook:
Expand Down
3 changes: 3 additions & 0 deletions python-sdk/src/astro/files/locations/http.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@ def paths(self) -> list[str]:
"""Resolve patterns in path"""
return [self.path]

def validate_conn(self):
    """Override the base check as a no-op: a conn_id is not always required for an HTTP location."""

@property
def size(self) -> int:
"""Return file size for HTTP location"""
Expand Down
3 changes: 3 additions & 0 deletions python-sdk/src/astro/files/locations/local.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,9 @@ def paths(self) -> list[str]:
paths = glob.glob(url.path)
return paths

def validate_conn(self):
    """Override the base check as a no-op: a conn_id is not always required for a local location."""

@property
def size(self) -> int:
"""Return the size in bytes of the given file.
Expand Down
6 changes: 4 additions & 2 deletions python-sdk/tests/files/locations/test_gcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,8 @@ def test_pull_from_json_dict():


@pytest.mark.parametrize("input_dict", [{}], ids=["empty_dict"])
def test_databricks_auth_no_value(input_dict):
@mock.patch("astro.files.locations.base.BaseFileLocation.validate_conn", return_value=None)
def test_databricks_auth_no_value(validate_conn, input_dict):
with mock.patch(
"astro.files.locations.google.gcs.GCSLocation.hook", new_callable=PropertyMock
), pytest.raises(ValueError) as exc_info:
Expand All @@ -53,7 +54,8 @@ def test_databricks_auth_no_value(input_dict):
@pytest.mark.parametrize(
"input_dict", [{"key_path": "foo/bar"}, {"keyfile_dict": mock_creds}], ids=["key_path", "key_dict"]
)
def test_databricks_auth(input_dict):
@mock.patch("astro.files.locations.base.BaseFileLocation.validate_conn", return_value=None)
def test_databricks_auth(validate_conn, input_dict):
with mock.patch(
"astro.files.locations.google.gcs.GCSLocation.hook", new_callable=PropertyMock
), mock.patch("astro.files.locations.google.gcs._pull_credentials_from_keypath") as mock_keypath:
Expand Down
26 changes: 14 additions & 12 deletions python-sdk/tests/files/locations/test_location_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

from astro.constants import FileLocation
from astro.files.locations import create_file_location, get_class_name
from astro.files.locations.google.gcs import GCSLocation
from astro.files.locations.local import LocalLocation

LOCAL_FILENAME = str(uuid.uuid4())
Expand Down Expand Up @@ -108,25 +109,20 @@ class SomethingElseLocation: # skipcq: PY-D0002
@pytest.mark.parametrize(
"loc_1,loc_2,equality",
[
(LocalLocation("/tmp/file_a.csv"), LocalLocation("/tmp/file_a.csv"), True),
(GCSLocation("gs://tmp/file_a.csv"), GCSLocation("gs://tmp/file_a.csv"), True),
(
LocalLocation("/tmp/file_a.csv", conn_id="test"),
LocalLocation("/tmp/file_a.csv", conn_id="test"),
GCSLocation("gs://tmp/file_a.csv", conn_id="google_cloud_default"),
GCSLocation("gs://tmp/file_a.csv", conn_id="google_cloud_default"),
True,
),
(
LocalLocation("/tmp/file_a.csv", conn_id="test"),
LocalLocation("/tmp/file_a.csv", conn_id="test"),
GCSLocation("gs://tmp/file_a.csv", conn_id="google_cloud_default"),
GCSLocation("gs://tmp/file_a.csv", conn_id="google_cloud_default"),
True,
),
(
LocalLocation("/tmp/file_a.csv", conn_id="test"),
LocalLocation("/tmp/file_a.csv", conn_id="test2"),
False,
),
(
LocalLocation("/tmp/file_a.csv", conn_id="test"),
LocalLocation("/tmp/file_b.csv", conn_id="test"),
GCSLocation("gs://tmp/file_a.csv", conn_id="google_cloud_default"),
GCSLocation("gs://tmp/file_b.csv", conn_id="google_cloud_default"),
False,
),
],
Expand All @@ -142,3 +138,9 @@ def test_location_eq(loc_1, loc_2, equality):
def test_location_hash():
"""Test that hashing works"""
assert isinstance(hash(LocalLocation("/tmp/file_a.csv")), int)


def test_invalid_conn_id_with_file_path():
    """Raise a ValueError when the connection type doesn't match the file path's scheme."""
    with pytest.raises(ValueError, match=r".* is not supported for .*"):
        GCSLocation("gs://tmp/file_a.csv", conn_id="aws_default")
8 changes: 4 additions & 4 deletions python-sdk/tests/files/operators/test_files.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,10 @@ def test_get_file_list_execute_local():
@patch("astro.files.locations.google.gcs.GCSLocation.hook")
def test_get_file_list_execute_gcs(hook):
"""Assert that when file object location point to GCS then get_file_list using GCSHook"""
hook.return_value = Connection(conn_id="conn", conn_type="google_cloud_platform")
hook.return_value = Connection(conn_id="google_cloud_default", conn_type="google_cloud_platform")
op = ListFileOperator(
task_id="task_id",
conn_id="conn",
conn_id="google_cloud_default",
path="gs://bucket/some-file",
)
op.execute(None)
Expand All @@ -32,10 +32,10 @@ def test_get_file_list_execute_gcs(hook):
@patch("astro.files.locations.amazon.s3.S3Location.hook")
def test_get_file_list_s3(hook):
"""Assert that when file object location point to s3 then get_file_list using S3Hook"""
hook.return_value = Connection(conn_id="conn", conn_type="s3")
hook.return_value = Connection(conn_id="aws_default", conn_type="s3")
op = ListFileOperator(
task_id="task_id",
conn_id="conn",
conn_id="aws_default",
path="s3://bucket/some-file",
)
op.execute(None)
Expand Down
16 changes: 8 additions & 8 deletions python-sdk/tests/files/test_file.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,23 +94,23 @@ def test_get_file_list():
[
(File("/tmp/file_a.csv"), File("/tmp/file_a.csv"), True),
(
File("/tmp/file_a.csv", conn_id="test"),
File("/tmp/file_a.csv", conn_id="test"),
File("gs://tmp/file_a.csv", conn_id="google_cloud_default"),
File("gs://tmp/file_a.csv", conn_id="google_cloud_default"),
True,
),
(
File("/tmp/file_a.csv", conn_id="test", filetype=constants.FileType.CSV),
File("/tmp/file_a.csv", conn_id="test", filetype=constants.FileType.CSV),
File("gs://tmp/file_a.csv", conn_id="google_cloud_default", filetype=constants.FileType.CSV),
File("gs://tmp/file_a.csv", conn_id="google_cloud_default", filetype=constants.FileType.CSV),
True,
),
(
File("/tmp/file_a.csv", conn_id="test", filetype=constants.FileType.CSV),
File("/tmp/file_a.csv", conn_id="test2", filetype=constants.FileType.JSON),
File("gs://tmp/file_a.csv", conn_id="google_cloud_default", filetype=constants.FileType.CSV),
File("gs://tmp/file_a.csv", conn_id="google_cloud_default", filetype=constants.FileType.JSON),
False,
),
(
File("/tmp/file_a.csv", conn_id="test", filetype=constants.FileType.CSV),
File("/tmp/file_b.csv", conn_id="test", filetype=constants.FileType.CSV),
File("gs://tmp/file_a.csv", conn_id="google_cloud_default", filetype=constants.FileType.CSV),
File("gs://tmp/file_b.csv", conn_id="google_cloud_default", filetype=constants.FileType.CSV),
False,
),
],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ def test_existing_table_exists(database_table_fixture):
[
{
"database": Database.DELTA,
"file": File("s3://tmp9/databricks-test/", conn_id="default_aws", filetype=FileType.CSV),
"file": File("s3://tmp9/databricks-test/", conn_id="aws_default", filetype=FileType.CSV),
},
],
indirect=True,
Expand Down
6 changes: 3 additions & 3 deletions python-sdk/tests_integration/databricks_tests/test_load.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def test_autoloader_load_file_local(database_table_fixture):
ids=["delta"],
)
def test_autoloader_load_file_s3(database_table_fixture):
file = File("s3://tmp9/databricks-test/", conn_id="default_aws", filetype=FileType.CSV)
file = File("s3://tmp9/databricks-test/", conn_id="aws_default", filetype=FileType.CSV)
database, table = database_table_fixture
database.load_file_to_table(
input_file=file,
Expand All @@ -72,7 +72,7 @@ def test_delta_load_file_gcs(database_table_fixture):

file = File(
"gs://astro-sdk/benchmark/trimmed/covid_overview/covid_overview_10kb.csv",
conn_id="databricks_gcs",
conn_id="google_cloud_default",
filetype=FileType.CSV,
)
database, table = database_table_fixture
Expand All @@ -99,7 +99,7 @@ def test_delta_load_file_gcs_autoloader(database_table_fixture):

file = File(
"gs://astro-sdk/benchmark/trimmed/covid_overview/",
conn_id="databricks_gcs",
conn_id="google_cloud_default",
filetype=FileType.CSV,
)
database, table = database_table_fixture
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,7 @@ def test_save_table_remote_file_exists_overwrite_false(
with sample_dag:
export_table_to_file(
input_data=test_table,
output_file=File(path=remote_files_fixture[0], conn_id="aws_default"),
output_file=File(path=remote_files_fixture[0]),
if_exists="exception",
)
test_utils.run_dag(sample_dag)
Expand Down

0 comments on commit 91a8b86

Please sign in to comment.