Make asset requirements for remote storage optional #167

Merged
merged 10 commits on Sep 1, 2022
2 changes: 2 additions & 0 deletions README.md
@@ -70,5 +70,7 @@ Install with `pip`:
pip install modelkit
```

Optional dependencies are available for remote storage providers ([see documentation](https://cornerstone-ondemand.github.io/modelkit/assets/storage_provider/#using-different-providers))

## Community
Join our [community](https://discord.gg/ayj5wdAArV) on Discord to get support and leave feedback
13 changes: 8 additions & 5 deletions docs/assets/storage_provider.md
@@ -28,11 +28,14 @@ Developers may additionally need to be able to push new assets and/or update exi

## Using different providers

The flavor of the remote store that is used depends on the `MODELKIT_STORAGE_PROVIDER` environment variables
The flavor of the remote store that is used depends on the optional dependencies installed with `pip` and on the `MODELKIT_STORAGE_PROVIDER` environment variable.

The default `pip install modelkit` will only allow you to target a local directory.
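To target a remote provider, install the matching extra, for example:

```
pip install modelkit[assets-s3]   # AWS S3
pip install modelkit[assets-gcs]  # Google Cloud Storage
pip install modelkit[assets-az]   # Azure blob storage
```

(Quote the argument, e.g. `pip install "modelkit[assets-s3]"`, if your shell expands square brackets.)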


### Using AWS S3 storage

Use `MODELKIT_STORAGE_PROVIDER=s3` to connect to S3 storage.
Install with `pip install modelkit[assets-s3]` and set the environment variable `MODELKIT_STORAGE_PROVIDER=s3` to connect to S3 storage.

We use [boto3](https://boto3.amazonaws.com/v1/documentation/api/latest/index.html) under the hood.
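A typical setup might look like the following; the region value is a placeholder, and boto3 will also pick up credentials from its usual sources (environment variables, `~/.aws/credentials`, instance roles):

```
pip install modelkit[assets-s3]
export MODELKIT_STORAGE_PROVIDER=s3
export AWS_DEFAULT_REGION=eu-west-1
```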

@@ -53,7 +56,7 @@ Use `AWS_KMS_KEY_ID` environment variable to set your key and be able to upload

### GCS storage

Use `MODELKIT_STORAGE_PROVIDER=gcs` to connect to GCS storage.
Install with `pip install modelkit[assets-gcs]` and set the environment variable `MODELKIT_STORAGE_PROVIDER=gcs` to connect to GCS storage.

We use [google-cloud-storage](https://googleapis.dev/python/storage/latest/index.html).

@@ -67,7 +70,7 @@ If `GOOGLE_APPLICATION_CREDENTIALS` is provided, it should point to a local JSON
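
For example (the credentials path is a placeholder):

```
pip install modelkit[assets-gcs]
export MODELKIT_STORAGE_PROVIDER=gcs
export GOOGLE_APPLICATION_CREDENTIALS=/path/to/service-account.json
```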

### Using Azure blob storage

Use `MODELKIT_STORAGE_PROVIDER=az` to connect to Azure blob storage.
Install with `pip install modelkit[assets-az]` and set the environment variable `MODELKIT_STORAGE_PROVIDER=az` to connect to Azure blob storage.

We use [azure-storage-blob](https://docs.microsoft.com/en-us/azure/storage/blobs/storage-quickstart-blobs-python) under the hood.

@@ -80,7 +83,7 @@ The client is created by passing the authentication information to `BlobServiceC
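
For example (the connection string value is a placeholder, as provided by the Azure portal):

```
pip install modelkit[assets-az]
export MODELKIT_STORAGE_PROVIDER=az
export AZURE_STORAGE_CONNECTION_STRING="DefaultEndpointsProtocol=https;AccountName=...;AccountKey=...;EndpointSuffix=core.windows.net"
```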

### `local` mode

Use `MODELKIT_STORAGE_PROVIDER=local` to treat a local folder as a remote source.
Set the environment variable `MODELKIT_STORAGE_PROVIDER=local` to treat a local folder as a remote source.

Assets will be downloaded from this folder to the configured asset dir.
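A minimal sketch; `MODELKIT_STORAGE_BUCKET` naming the source folder is an assumption here, check the settings documentation for the exact variables:

```
export MODELKIT_STORAGE_PROVIDER=local
export MODELKIT_STORAGE_BUCKET=/path/to/local/storage  # assumed variable name
```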

32 changes: 29 additions & 3 deletions modelkit/assets/cli.py
@@ -10,11 +10,21 @@
from rich.table import Table
from rich.tree import Tree

from modelkit.assets.drivers.gcs import GCSStorageDriver
from modelkit.assets.drivers.s3 import S3StorageDriver
try:
    from modelkit.assets.drivers.gcs import GCSStorageDriver

    has_gcs = True
except ModuleNotFoundError:
    has_gcs = False
try:
    from modelkit.assets.drivers.s3 import S3StorageDriver

    has_s3 = True
except ModuleNotFoundError:
    has_s3 = False
from modelkit.assets.errors import ObjectDoesNotExistError
from modelkit.assets.manager import AssetsManager
from modelkit.assets.remote import StorageProvider
from modelkit.assets.remote import DriverNotInstalledError, StorageProvider
from modelkit.assets.settings import AssetSpec


@@ -121,8 +131,16 @@ def new_(asset_path, asset_spec, storage_prefix, dry_run):
    if not os.path.exists(asset_path):
        parsed_path = parse_remote_url(asset_path)
        if parsed_path["storage_prefix"] == "gs":
            if not has_gcs:
                raise DriverNotInstalledError(
                    "GCS driver not installed, install modelkit[assets-gcs]"
                )
            driver = GCSStorageDriver(bucket=parsed_path["bucket_name"])
        elif parsed_path["storage_prefix"] == "s3":
            if not has_s3:
                raise DriverNotInstalledError(
                    "S3 driver not installed, install modelkit[assets-s3]"
                )
            driver = S3StorageDriver(bucket=parsed_path["bucket_name"])
        else:
            raise ValueError(
@@ -212,8 +230,16 @@ def update_(asset_path, asset_spec, storage_prefix, bump_major, dry_run):
    if not os.path.exists(asset_path):
        parsed_path = parse_remote_url(asset_path)
        if parsed_path["storage_prefix"] == "gs":
            if not has_gcs:
                raise DriverNotInstalledError(
                    "GCS driver not installed, install modelkit[assets-gcs]"
                )
            driver = GCSStorageDriver(bucket=parsed_path["bucket_name"])
        elif parsed_path["storage_prefix"] == "s3":
            if not has_s3:
                raise DriverNotInstalledError(
                    "S3 driver not installed, install modelkit[assets-s3]"
                )
            driver = S3StorageDriver(bucket=parsed_path["bucket_name"])
        else:
            raise ValueError(
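The guarded imports above make the CLI degrade gracefully when an extra is missing: the error is raised when a driver is actually requested, not at import time. A minimal self-contained sketch of the pattern (module and function names here are hypothetical, not modelkit's API):

```python
try:
    import some_optional_driver  # hypothetical optional dependency

    has_driver = True
except ModuleNotFoundError:
    has_driver = False


def get_driver():
    # Fail at use time, not import time, so the base install stays usable
    if not has_driver:
        raise RuntimeError("driver not installed, install the matching extra")
    return some_optional_driver
```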
14 changes: 8 additions & 6 deletions modelkit/assets/drivers/azure.py
@@ -7,10 +7,12 @@

from modelkit.assets import errors
from modelkit.assets.drivers.abc import StorageDriver
from modelkit.assets.drivers.retry import RETRY_POLICY
from modelkit.assets.drivers.retry import retry_policy

logger = get_logger(__name__)

AZURE_RETRY_POLICY = retry_policy()


class AzureStorageDriver(StorageDriver):
    bucket: str
@@ -34,13 +36,13 @@ def __init__(
            os.environ["AZURE_STORAGE_CONNECTION_STRING"]
        )

    @retry(**RETRY_POLICY)
    @retry(**AZURE_RETRY_POLICY)
    def iterate_objects(self, prefix=None):
        container = self.client.get_container_client(self.bucket)
        for blob in container.list_blobs(prefix=prefix):
            yield blob["name"]

    @retry(**RETRY_POLICY)
    @retry(**AZURE_RETRY_POLICY)
    def upload_object(self, file_path, object_name):
        blob_client = self.client.get_blob_client(
            container=self.bucket, blob=object_name
@@ -50,7 +52,7 @@ def upload_object(self, file_path, object_name):
        with open(file_path, "rb") as f:
            blob_client.upload_blob(f)

    @retry(**RETRY_POLICY)
    @retry(**AZURE_RETRY_POLICY)
    def download_object(self, object_name, destination_path):
        blob_client = self.client.get_blob_client(
            container=self.bucket, blob=object_name
@@ -67,14 +69,14 @@ def download_object(self, object_name, destination_path):
        with open(destination_path, "wb") as f:
            f.write(blob_client.download_blob().readall())

    @retry(**RETRY_POLICY)
    @retry(**AZURE_RETRY_POLICY)
    def delete_object(self, object_name):
        blob_client = self.client.get_blob_client(
            container=self.bucket, blob=object_name
        )
        blob_client.delete_blob()

    @retry(**RETRY_POLICY)
    @retry(**AZURE_RETRY_POLICY)
    def exists(self, object_name):
        blob_client = self.client.get_blob_client(
            container=self.bucket, blob=object_name
16 changes: 9 additions & 7 deletions modelkit/assets/drivers/gcs.py
@@ -1,18 +1,20 @@
import os
from typing import Optional

from google.api_core.exceptions import NotFound
from google.api_core.exceptions import GoogleAPIError, NotFound
from google.cloud import storage
from google.cloud.storage import Client
from structlog import get_logger
from tenacity import retry

from modelkit.assets import errors
from modelkit.assets.drivers.abc import StorageDriver
from modelkit.assets.drivers.retry import RETRY_POLICY
from modelkit.assets.drivers.retry import retry_policy

logger = get_logger(__name__)

GCS_RETRY_POLICY = retry_policy(GoogleAPIError)


class GCSStorageDriver(StorageDriver):
    bucket: str
@@ -35,13 +37,13 @@ def __init__(
        else:
            self.client = Client()

    @retry(**RETRY_POLICY)
    @retry(**GCS_RETRY_POLICY)
    def iterate_objects(self, prefix=None):
        bucket = self.client.bucket(self.bucket)
        for blob in bucket.list_blobs(prefix=prefix):
            yield blob.name

    @retry(**RETRY_POLICY)
    @retry(**GCS_RETRY_POLICY)
    def upload_object(self, file_path, object_name):
        bucket = self.client.bucket(self.bucket)
        blob = bucket.blob(object_name)
@@ -50,7 +52,7 @@ def upload_object(self, file_path, object_name):
        with open(file_path, "rb") as f:
            blob.upload_from_file(f)

    @retry(**RETRY_POLICY)
    @retry(**GCS_RETRY_POLICY)
    def download_object(self, object_name, destination_path):
        bucket = self.client.bucket(self.bucket)
        blob = bucket.blob(object_name)
@@ -66,13 +68,13 @@ def download_object(self, object_name, destination_path):
                driver=self, bucket=self.bucket, object_name=object_name
            )

    @retry(**RETRY_POLICY)
    @retry(**GCS_RETRY_POLICY)
    def delete_object(self, object_name):
        bucket = self.client.bucket(self.bucket)
        blob = bucket.blob(object_name)
        blob.delete()

    @retry(**RETRY_POLICY)
    @retry(**GCS_RETRY_POLICY)
    def exists(self, object_name):
        bucket = self.client.bucket(self.bucket)
        blob = bucket.blob(object_name)
37 changes: 20 additions & 17 deletions modelkit/assets/drivers/retry.py
@@ -1,20 +1,10 @@
import botocore
import google
import requests
from structlog import get_logger
from tenacity import retry_if_exception, stop_after_attempt, wait_random_exponential

logger = get_logger(__name__)


def retriable_error(exception):
    return (
        isinstance(exception, botocore.exceptions.ClientError)
        or isinstance(exception, google.api_core.exceptions.GoogleAPIError)
        or isinstance(exception, requests.exceptions.ChunkedEncodingError)
    )


def log_after_retry(retry_state):
    logger.info(
        "Retrying",
@@ -24,10 +14,23 @@ def log_after_retry(retry_state):
    )


RETRY_POLICY = {
    "wait": wait_random_exponential(multiplier=1, min=4, max=10),
    "stop": stop_after_attempt(5),
    "retry": retry_if_exception(retriable_error),
    "after": log_after_retry,
    "reraise": True,
}
def retry_policy(type_error=None):
    if not type_error:

        def is_retry_eligible(error):
            return isinstance(error, requests.exceptions.ChunkedEncodingError)

    else:

        def is_retry_eligible(error):
            return isinstance(error, type_error) or isinstance(
                error, requests.exceptions.ChunkedEncodingError
            )

    return {
        "wait": wait_random_exponential(multiplier=1, min=4, max=10),
        "stop": stop_after_attempt(5),
        "retry": retry_if_exception(is_retry_eligible),
        "after": log_after_retry,
        "reraise": True,
    }
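A minimal sketch of how a driver consumes the new factory, mirroring `S3_RETRY_POLICY = retry_policy(botocore.exceptions.ClientError)` below; the decorated function is made up for illustration:

```python
from tenacity import retry

from modelkit.assets.drivers.retry import retry_policy

# Retry on a provider-specific error type in addition to the
# ChunkedEncodingError that retry_policy always considers eligible
MY_RETRY_POLICY = retry_policy(ConnectionError)


@retry(**MY_RETRY_POLICY)
def flaky_download():
    # Up to 5 attempts with random exponential backoff; other
    # exception types are re-raised immediately (reraise=True)
    ...
```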
14 changes: 8 additions & 6 deletions modelkit/assets/drivers/s3.py
@@ -8,10 +8,12 @@

from modelkit.assets import errors
from modelkit.assets.drivers.abc import StorageDriver
from modelkit.assets.drivers.retry import RETRY_POLICY
from modelkit.assets.drivers.retry import retry_policy

logger = get_logger(__name__)

S3_RETRY_POLICY = retry_policy(botocore.exceptions.ClientError)


class S3StorageDriver(StorageDriver):
    bucket: str
@@ -44,15 +46,15 @@ def __init__(
            region_name=aws_default_region or os.environ.get("AWS_DEFAULT_REGION"),
        )

    @retry(**RETRY_POLICY)
    @retry(**S3_RETRY_POLICY)
    def iterate_objects(self, prefix=None):
        paginator = self.client.get_paginator("list_objects_v2")
        pages = paginator.paginate(Bucket=self.bucket, Prefix=prefix or "")
        for page in pages:
            for obj in page.get("Contents", []):
                yield obj["Key"]

    @retry(**RETRY_POLICY)
    @retry(**S3_RETRY_POLICY)
    def upload_object(self, file_path, object_name):
        if self.aws_kms_key_id:
            self.client.upload_file(  # pragma: no cover
@@ -67,7 +69,7 @@ def upload_object(self, file_path, object_name):
        else:
            self.client.upload_file(file_path, self.bucket, object_name)

    @retry(**RETRY_POLICY)
    @retry(**S3_RETRY_POLICY)
    def download_object(self, object_name, destination_path):
        try:
            with open(destination_path, "wb") as f:
@@ -81,11 +83,11 @@ def download_object(self, object_name, destination_path):
                driver=self, bucket=self.bucket, object_name=object_name
            )

    @retry(**RETRY_POLICY)
    @retry(**S3_RETRY_POLICY)
    def delete_object(self, object_name):
        self.client.delete_object(Bucket=self.bucket, Key=object_name)

    @retry(**RETRY_POLICY)
    @retry(**S3_RETRY_POLICY)
    def exists(self, object_name):
        try:
            self.client.head_object(Bucket=self.bucket, Key=object_name)