Skip to content

Commit

Permalink
Merge pull request #216 from Cornerstone-OnDemand/215-update-drivers-cli
Browse files Browse the repository at this point in the history
Update assets cli after drivers breaking changes
  • Loading branch information
tgenin committed Mar 5, 2024
2 parents 7e2457c + 99e6ba5 commit 65c4fdc
Showing 1 changed file with 14 additions and 10 deletions.
24 changes: 14 additions & 10 deletions modelkit/assets/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,16 +10,14 @@
from rich.table import Table
from rich.tree import Tree

from modelkit.assets.drivers.abc import StorageDriverSettings

try:
from modelkit.assets.drivers.gcs import GCSStorageDriver
from modelkit.assets.drivers.gcs import GCSStorageDriver, GCSStorageDriverSettings

has_gcs = True
except ModuleNotFoundError:
has_gcs = False
try:
from modelkit.assets.drivers.s3 import S3StorageDriver
from modelkit.assets.drivers.s3 import S3StorageDriver, S3StorageDriverSettings

has_s3 = True
except ModuleNotFoundError:
Expand Down Expand Up @@ -132,20 +130,23 @@ def new_(asset_path, asset_spec, storage_prefix, dry_run):
with tempfile.TemporaryDirectory() as tmp_dir:
if not os.path.exists(asset_path):
parsed_path = parse_remote_url(asset_path)
driver_settings = StorageDriverSettings(
bucket=parsed_path["bucket_name"]
)
if parsed_path["storage_prefix"] == "gs":
if not has_gcs:
raise DriverNotInstalledError(
"GCS driver not installed, install modelkit[assets-gcs]"
)
driver_settings = GCSStorageDriverSettings(
bucket=parsed_path["bucket_name"]
)
driver = GCSStorageDriver(driver_settings)
elif parsed_path["storage_prefix"] == "s3":
if not has_s3:
raise DriverNotInstalledError(
"S3 driver not installed, install modelkit[assets-s3]"
)
driver_settings = S3StorageDriverSettings(
bucket=parsed_path["bucket_name"]
)
driver = S3StorageDriver(driver_settings)
else:
raise ValueError(
Expand Down Expand Up @@ -234,20 +235,23 @@ def update_(asset_path, asset_spec, storage_prefix, bump_major, dry_run):
with tempfile.TemporaryDirectory() as tmp_dir:
if not os.path.exists(asset_path):
parsed_path = parse_remote_url(asset_path)
driver_settings = StorageDriverSettings(
bucket=parsed_path["bucket_name"]
)
if parsed_path["storage_prefix"] == "gs":
if not has_gcs:
raise DriverNotInstalledError(
"GCS driver not installed, install modelkit[assets-gcs]"
)
driver_settings = GCSStorageDriverSettings(
bucket=parsed_path["bucket_name"]
)
driver = GCSStorageDriver(driver_settings)
elif parsed_path["storage_prefix"] == "s3":
if not has_s3:
raise DriverNotInstalledError(
"S3 driver not installed, install modelkit[assets-s3]"
)
driver_settings = S3StorageDriverSettings(
bucket=parsed_path["bucket_name"]
)
driver = S3StorageDriver(driver_settings)
else:
raise ValueError(
Expand Down

0 comments on commit 65c4fdc

Please sign in to comment.