Skip to content

Commit

Permalink
merged master
Browse files Browse the repository at this point in the history
  • Loading branch information
tmbo committed Apr 17, 2018
2 parents 9e8e988 + 3450aee commit 1661ef5
Show file tree
Hide file tree
Showing 5 changed files with 146 additions and 5 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ Added
- added docs blurb on handling contextual dialogue
- distribute package as wheel file in addition to source distribution (faster install)
- allow a component to specify which languages it supports
- support for persisting models to Azure Storage
- added tokenizer for CHINESE (``zh``) as well as instructions on how to load
MITIE model

Expand Down
1 change: 1 addition & 0 deletions alt_requirements/requirements_dev.txt
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ moto==1.2.0
mock==2.0.0
# other
google-cloud-storage==1.7.0
azure-storage-blob==1.0.0

# docs
sphinx==1.5.2
Expand Down
16 changes: 15 additions & 1 deletion docs/persist.rst
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ Rasa NLU supports using `S3 <https://aws.amazon.com/s3/>`_ and
- ``BUCKET_NAME``
- ``AWS_ENDPOINT_URL``

If there is no bucket with the name ``BUCKET_NAME`` Rasa will create it.

* Google Cloud Storage
GCS is supported using the ``google-cloud-storage`` package
Expand All @@ -37,5 +38,18 @@ Rasa NLU supports using `S3 <https://aws.amazon.com/s3/>`_ and
and setting the ``GOOGLE_APPLICATION_CREDENTIALS`` environment
variable to the path of that key file.

If there is no bucket with the name ``BUCKET_NAME`` Rasa will create it.
* Azure Storage
Azure is supported using the ``azure-storage-blob`` package
which you can install with ``pip install azure-storage-blob``

Start the Rasa NLU server with ``storage`` option set to ``azure``.

The following environment variables must be set:

- ``AZURE_CONTAINER``
- ``AZURE_ACCOUNT_NAME``
- ``AZURE_ACCOUNT_KEY``

If there is no container with the name ``AZURE_CONTAINER`` Rasa will create it.

Models are gzipped before saving to cloud.
92 changes: 88 additions & 4 deletions rasa_nlu/persistor.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,15 +23,20 @@ def get_persistor(name):
# type: (Text) -> Optional[Persistor]
"""Returns an instance of the requested persistor.
Currently, `aws` and `gcs` are supported"""
Currently, `aws`, `gcs` and `azure` are supported"""

if name == 'aws':
return AWSPersistor(os.environ.get("BUCKET_NAME"),
os.environ.get("AWS_ENDPOINT_URL"))
elif name == 'gcs':
if name == 'gcs':
return GCSPersistor(os.environ.get("BUCKET_NAME"))
else:
return None

if name == 'azure':
return AzurePersistor(os.environ.get("AZURE_CONTAINER"),
os.environ.get("AZURE_ACCOUNT_NAME"),
os.environ.get("AZURE_ACCOUNT_KEY"))

return None


class Persistor(object):
Expand Down Expand Up @@ -255,3 +260,82 @@ def _retrieve_tar(self, target_filename):

blob = self.bucket.blob(target_filename)
blob.download_to_filename(target_filename)


class AzurePersistor(Persistor):
"""Store models on Azure"""

def __init__(self,
azure_container,
azure_account_name,
azure_account_key):
from azure.storage import blob as azureblob
from azure.storage.common import models as storageModel

super(AzurePersistor, self).__init__()

self.blob_client = azureblob.BlockBlobService(
account_name=azure_account_name,
account_key=azure_account_key,
endpoint_suffix="core.windows.net")

self._ensure_container_exists(azure_container)
self.container_name = azure_container

def _ensure_container_exists(self, container_name):
# type: (Text) -> None

exists = self.blob_client.exists(container_name)
if not exists:
self.blob_client.create_container(container_name)

def list_models(self, project):
# type: (Text) -> List[Text]

try:
blob_iterator = self.blob_client.list_blobs(
self.container_name,
prefix=self._project_prefix(project)
)
return [self._project_and_model_from_filename(b.name)[1]
for b in blob_iterator]
except Exception as e:
logger.warning("Failed to list models for project {} in "
"azure blob storage. {}".format(project, e))
return []

def list_projects(self):
# type: () -> List[Text]

try:
blob_iterator = self.blob_client.list_blobs(
self.container_name,
prefix=None
)
projects_set = {self._project_and_model_from_filename(b.name)[0]
for b in blob_iterator}
return list(projects_set)
except Exception as e:
logger.warning("Failed to list projects in "
"Azure. {}".format(e))
return []

def _persist_tar(self, file_key, tar_path):
# type: (Text, Text) -> None
"""Uploads a model persisted in the `target_dir` to Azure."""

self.blob_client.create_blob_from_path(
self.container_name,
file_key,
tar_path
)

def _retrieve_tar(self, target_filename):
# type: (Text) -> None
"""Downloads a model that has previously been persisted to Azure."""

self.blob_client.get_blob_to_path(
self.container_name,
target_filename,
target_filename
)
41 changes: 41 additions & 0 deletions tests/base/test_persistor.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,3 +87,44 @@ def mocked_list_blobs():
result = persistor.GCSPersistor("").list_projects()

assert result == []


def test_list_projects_method_in_AzurePersistor():
def mocked_init(self, *args, **kwargs):
self._project_and_model_from_filename = lambda x: {'blob_name': ('project', )}[x]
self.blob_client = Object()
self.container_name = 'test'

def mocked_list_blobs(
container_name,
prefix=None
):
filter_result = Object()
filter_result.name = 'blob_name'
return filter_result,

self.blob_client.list_blobs = mocked_list_blobs

with mock.patch.object(persistor.AzurePersistor, "__init__", mocked_init):
result = persistor.AzurePersistor("").list_projects()

assert result == ['project']


def test_list_projects_method_raise_exeception_in_AzurePersistor():
def mocked_init(self, *args, **kwargs):
self._project_and_model_from_filename = lambda x: {'blob_name': ('project', )}[x]
self.blob_client = Object()

def mocked_list_blobs(
container_name,
prefix=None
):
raise ValueError

self.blob_client.list_blobs = mocked_list_blobs

with mock.patch.object(persistor.AzurePersistor, "__init__", mocked_init):
result = persistor.AzurePersistor("").list_projects()

assert result == []

0 comments on commit 1661ef5

Please sign in to comment.