Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement AzureResultHandler #1421

Merged
merged 3 commits into from
Aug 29, 2019
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ These changes are available in the [master branch](https://github.com/PrefectHQ/

- Added Local, Kubernetes, and Nomad agents - [#1341](https://github.com/PrefectHQ/prefect/pull/1341)
- Add the ability for Tasks to sequentially loop - [#1356](https://github.com/PrefectHQ/prefect/pull/1356)
- Add `AzureResultHandler` for handling results to / from Azure Blob storage containers - [#1421](https://github.com/PrefectHQ/prefect/pull/1421)

### Enhancements

Expand Down
2 changes: 1 addition & 1 deletion docs/outline.toml
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@ classes = ["Result", "SafeResult", "NoResultType"]
[pages.engine.result_handlers]
title = "Result Handlers"
module = "prefect.engine.result_handlers"
classes = ["JSONResultHandler", "GCSResultHandler", "LocalResultHandler", "S3ResultHandler"]
classes = ["JSONResultHandler", "GCSResultHandler", "LocalResultHandler", "S3ResultHandler", "AzureResultHandler"]

[pages.engine.cloud]
title = "Cloud"
Expand Down
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
extras = {
"airtable": ["airtable-python-wrapper >= 0.11, < 0.12"],
"aws": ["boto3 >= 1.9, < 2.0"],
"azure": ["azure-storage-blob >= 2.1.0, < 3.0"],
"dev": dev_requires,
"dropbox": ["dropbox ~= 9.0"],
"google": [
Expand Down
5 changes: 5 additions & 0 deletions src/prefect/engine/result_handlers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,3 +16,8 @@
from prefect.engine.result_handlers.s3_result_handler import S3ResultHandler
except ImportError:
pass

try:
from prefect.engine.result_handlers.azure_result_handler import AzureResultHandler
except ImportError:
pass
128 changes: 128 additions & 0 deletions src/prefect/engine/result_handlers/azure_result_handler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
import base64
import json
import uuid
from typing import TYPE_CHECKING, Any

import cloudpickle
import pendulum

from prefect.client import Secret
from prefect.engine.result_handlers import ResultHandler

if TYPE_CHECKING:
import azure.storage.blob


class AzureResultHandler(ResultHandler):
    """
    Result Handler that stores results in, and retrieves them from, an Azure Blob
    storage container.

    Results are pickled with `cloudpickle`, base64-encoded and uploaded as text
    blobs.  The blob service client is created lazily on first access and is left
    out of the handler's pickled state.

    Args:
        - container (str): the name of the container to write to / read from
        - azure_credentials_secret (str, optional): the name of the Prefect Secret
            which stores your Azure credentials; the Secret's value must be a JSON
            string (or an already-decoded mapping) with two keys: `ACCOUNT_NAME`
            and `ACCOUNT_KEY`.  Defaults to `"AZ_CREDENTIALS"`.
    """

    def __init__(
        self, container: str = None, azure_credentials_secret: str = "AZ_CREDENTIALS"
    ) -> None:
        self.container = container
        self.azure_credentials_secret = azure_credentials_secret
        super().__init__()

    def initialize_service(self) -> None:
        """
        Build a `BlockBlobService` from the configured credentials Secret and
        store it on this handler.
        """
        # imported here so the module can be imported without the azure package
        import azure.storage.blob

        credentials = Secret(self.azure_credentials_secret).get()
        if isinstance(credentials, str):
            # Secret values may arrive as a raw JSON string
            credentials = json.loads(credentials)

        account_name, account_key = (
            credentials["ACCOUNT_NAME"],
            credentials["ACCOUNT_KEY"],
        )
        self.service = azure.storage.blob.BlockBlobService(
            account_name=account_name, account_key=account_key
        )

    @property
    def service(self) -> "azure.storage.blob.BlockBlobService":
        """The blob service client, initialized on first access."""
        if hasattr(self, "_service"):
            return self._service
        self.initialize_service()
        return self._service

    @service.setter
    def service(self, val: Any) -> None:
        self._service = val

    def __getstate__(self) -> dict:
        # the service client is dropped from the pickled state; it is rebuilt
        # lazily through the `service` property after unpickling
        return {
            key: value for key, value in self.__dict__.items() if key != "_service"
        }

    def __setstate__(self, state: dict) -> None:
        self.__dict__.update(state)

    def write(self, result: Any) -> str:
        """
        Upload a result to Azure Blob storage and return the name of the blob it
        was written to.

        Args:
            - result (Any): the result to write

        Returns:
            - str: the Blob URI
        """
        today = pendulum.now("utc").format("Y/M/D")
        uri = "{date}/{uuid}.prefect_result".format(date=today, uuid=uuid.uuid4())
        self.logger.debug("Starting to upload result to {}...".format(uri))

        # pickle, then base64-encode so the payload can be uploaded as text
        pickled = cloudpickle.dumps(result)
        encoded = base64.b64encode(pickled).decode()

        self.service.create_blob_from_text(
            container_name=self.container, blob_name=uri, text=encoded
        )

        self.logger.debug("Finished uploading result to {}.".format(uri))
        return uri

    def read(self, uri: str) -> Any:
        """
        Download the blob stored under `uri` from Azure Blob storage and return
        the result it contains.

        Any error raised while reading is logged and `None` is returned instead.

        Args:
            - uri (str): the Azure Blob URI

        Returns:
            - Any: the read result, or `None` if it could not be read
        """
        try:
            self.logger.debug("Starting to download result from {}...".format(uri))
            blob = self.service.get_blob_to_text(
                container_name=self.container, blob_name=uri
            )
            encoded = blob.content
            try:
                value = cloudpickle.loads(base64.b64decode(encoded))
            except EOFError:
                # nothing to unpickle
                value = None
            self.logger.debug("Finished downloading result from {}.".format(uri))
            return value
        except Exception as exc:
            self.logger.exception(
                "Unexpected error while reading from result handler: {}".format(
                    repr(exc)
                )
            )
            return None
10 changes: 10 additions & 0 deletions src/prefect/serialization/result_handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
LocalResultHandler,
ResultHandler,
S3ResultHandler,
AzureResultHandler,
)
from prefect.utilities.serialization import (
JSONCompatible,
Expand Down Expand Up @@ -66,6 +67,14 @@ class Meta:
aws_credentials_secret = fields.String(allow_none=True)


class AzureResultHandlerSchema(BaseResultHandlerSchema):
    """Schema for serializing and deserializing `AzureResultHandler` objects."""

    class Meta:
        # class instantiated when this schema loads serialized data
        object_class = AzureResultHandler

    # the container name must be present; the secret name may be null
    container = fields.String(allow_none=False)
    azure_credentials_secret = fields.String(allow_none=True)


class ResultHandlerSchema(OneOfSchema):
"""
Field that chooses between several nested schemas
Expand All @@ -78,6 +87,7 @@ class ResultHandlerSchema(OneOfSchema):
"S3ResultHandler": S3ResultHandlerSchema,
"JSONResultHandler": JSONResultHandlerSchema,
"LocalResultHandler": LocalResultHandlerSchema,
"AzureResultHandler": AzureResultHandlerSchema,
}

def get_obj_type(self, obj: Any) -> str:
Expand Down
80 changes: 80 additions & 0 deletions tests/engine/result_handlers/test_result_handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import prefect
from prefect.client import Client
from prefect.engine.result_handlers import (
AzureResultHandler,
GCSResultHandler,
JSONResultHandler,
LocalResultHandler,
Expand Down Expand Up @@ -256,3 +257,82 @@ def __getstate__(self):
handler = S3ResultHandler(bucket="foo")
res = cloudpickle.loads(cloudpickle.dumps(handler))
assert isinstance(res, S3ResultHandler)


@pytest.mark.xfail(raises=ImportError, reason="azure extras not installed.")
class TestAzureResultHandler:
    """Unit tests for `AzureResultHandler` run against a mocked Azure Blob service."""

    @pytest.fixture
    def azure_service(self):
        """Yield a mock that stands in for `azure.storage.blob.BlockBlobService`."""
        # Imported only for its side effect: it raises ImportError (turning the
        # test into an expected failure) when the azure extras are not installed.
        import azure.storage.blob  # noqa: F401

        service = MagicMock()
        blob = MagicMock(BlockBlobService=service)
        storage = MagicMock(blob=blob)

        with patch.dict("sys.modules", {"azure": MagicMock(storage=storage)}):
            yield service

    def test_azure_service_init_uses_secrets(self, azure_service):
        handler = AzureResultHandler(container="bob")
        assert handler.container == "bob"
        # constructing the handler must not build the service
        assert azure_service.called is False

        with prefect.context(
            secrets=dict(AZ_CREDENTIALS=dict(ACCOUNT_NAME="1", ACCOUNT_KEY="42"))
        ):
            with set_temporary_config({"cloud.use_local_secrets": True}):
                handler.initialize_service()

        assert azure_service.call_args[1] == {"account_name": "1", "account_key": "42"}

    def test_azure_service_init_uses_custom_secrets(self, azure_service):
        handler = AzureResultHandler(container="bob", azure_credentials_secret="MY_FOO")

        with prefect.context(
            secrets=dict(MY_FOO=dict(ACCOUNT_NAME=1, ACCOUNT_KEY=999))
        ):
            with set_temporary_config({"cloud.use_local_secrets": True}):
                handler.initialize_service()

        assert handler.container == "bob"
        assert azure_service.call_args[1] == {"account_name": 1, "account_key": 999}

    def test_azure_service_writes_to_blob_prefixed_by_date_suffixed_by_prefect(
        self, azure_service
    ):
        handler = AzureResultHandler(container="foo")

        with prefect.context(
            secrets=dict(AZ_CREDENTIALS=dict(ACCOUNT_NAME=1, ACCOUNT_KEY=42))
        ):
            with set_temporary_config({"cloud.use_local_secrets": True}):
                uri = handler.write("so-much-data")

        used_uri = azure_service.return_value.create_blob_from_text.call_args[1][
            "blob_name"
        ]

        assert used_uri == uri
        assert used_uri.startswith(pendulum.now("utc").format("Y/M/D"))
        assert used_uri.endswith("prefect_result")

    def test_azure_service_handler_is_pickleable(self):
        class service:
            def __init__(self, *args, **kwargs):
                pass

            def __getstate__(self):
                raise ValueError("I cannot be pickled.")

        with patch.dict(
            "sys.modules", {"azure.storage.blob": MagicMock(BlockBlobService=service)}
        ):
            with prefect.context(
                secrets=dict(AZ_CREDENTIALS=dict(ACCOUNT_NAME=1, ACCOUNT_KEY=42))
            ):
                with set_temporary_config({"cloud.use_local_secrets": True}):
                    handler = AzureResultHandler(container="foo")
                    # Attach an unpicklable service: the round trip can only
                    # succeed if the handler leaves it out of its pickled state.
                    handler.service = service()
                    res = cloudpickle.loads(cloudpickle.dumps(handler))
                    assert isinstance(res, AzureResultHandler)
45 changes: 45 additions & 0 deletions tests/serialization/test_result_handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
LocalResultHandler,
ResultHandler,
S3ResultHandler,
AzureResultHandler,
)
from prefect.serialization.result_handlers import (
CustomResultHandlerSchema,
Expand Down Expand Up @@ -211,3 +212,47 @@ def raise_me(*args, **kwargs):
assert isinstance(handler, S3ResultHandler)
assert handler.bucket == "bucket3"
assert handler.aws_credentials_secret == "FOO"


@pytest.mark.xfail(raises=ImportError, reason="azure extras not installed.")
class TestAzureResultHandler:
    """Serialization tests for `AzureResultHandler` through `ResultHandlerSchema`."""

    def test_serialize(self):
        """Dumping a handler records its type, container and secret name."""
        payload = ResultHandlerSchema().dump(
            AzureResultHandler(
                container="my-container", azure_credentials_secret="FOO"
            )
        )
        assert payload["type"] == "AzureResultHandler"
        assert payload["container"] == "my-container"
        assert payload["azure_credentials_secret"] == "FOO"

    def test_deserialize_from_dict(self):
        """Loading without a secret name falls back to the default secret."""
        raw = {"type": "AzureResultHandler", "container": "foo-bar"}
        loaded = ResultHandlerSchema().load(raw)
        assert isinstance(loaded, AzureResultHandler)
        assert loaded.container == "foo-bar"
        assert loaded.azure_credentials_secret == "AZ_CREDENTIALS"

    def test_roundtrip(self):
        """A dump followed by a load yields an equivalent handler."""
        schema = ResultHandlerSchema()
        original = AzureResultHandler(container="container3")
        dumped = schema.dump(original)
        restored = schema.load(dumped)
        assert isinstance(restored, AzureResultHandler)
        assert restored.container == "container3"

    def test_roundtrip_never_loads_client(self, monkeypatch):
        """Serialization must not initialize the blob service."""
        schema = ResultHandlerSchema()

        def raise_me(*args, **kwargs):
            raise SyntaxError("oops")

        # any attempt to build the service during the round trip would raise
        monkeypatch.setattr(AzureResultHandler, "initialize_service", raise_me)

        original = AzureResultHandler(
            container="container3", azure_credentials_secret="FOO"
        )
        restored = schema.load(schema.dump(original))
        assert isinstance(restored, AzureResultHandler)
        assert restored.container == "container3"
        assert restored.azure_credentials_secret == "FOO"