-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
68ef2cd
commit 121eee0
Showing
3 changed files
with
190 additions
and
71 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,131 @@ | ||
from dataclasses import dataclass | ||
from typing import List, Tuple | ||
from subprocess import Popen, PIPE | ||
import logging | ||
from contextlib import contextmanager | ||
from tempfile import NamedTemporaryFile | ||
import re | ||
import shlex | ||
|
||
|
||
@dataclass
class S5CmdCpConfig:
    """Boolean switches that are rendered into flags for an s5cmd ``cp`` command.

    Each field that is truthy contributes one flag (followed by a single
    space) to :attr:`cp_config_str`, in declaration order:

    * ``flatten_dir``        -> ``--flatten``
    * ``no_overwrite``       -> ``--no-clobber``
    * ``overwrite_if_size``  -> ``--if-size-differ``
    * ``overwrite_if_newer`` -> ``--if-source-newer``

    An instance is falsy when no flag is enabled.
    """

    flatten_dir: bool = False
    no_overwrite: bool = False
    overwrite_if_size: bool = False
    overwrite_if_newer: bool = False

    @property
    def cp_config_str(self):
        """Return the enabled flags concatenated; every flag keeps a trailing space."""
        rendered = ""
        if self.flatten_dir:
            rendered += "--flatten "
        if self.no_overwrite:
            rendered += "--no-clobber "
        if self.overwrite_if_size:
            rendered += "--if-size-differ "
        if self.overwrite_if_newer:
            rendered += "--if-source-newer "
        return rendered

    def __str__(self):
        """Return the same text as :attr:`cp_config_str`."""
        return self.cp_config_str

    def __bool__(self):
        """Return True when at least one flag is enabled."""
        return self.cp_config_str != ""
|
||
|
||
class S5Cmd:
    """Small wrapper that drives the ``s5cmd`` command-line tool for bulk copies.

    Transfers are executed as a batch: one ``cp`` line per object is written
    to the standard input of ``s5cmd ... run``. ``gs://`` URIs are rewritten
    to ``s3://`` and the Google Cloud Storage endpoint is passed to s5cmd.
    Credentials are handed over through a temporary credentials file that
    only exists while the command is running.
    """

    # Endpoint passed to s5cmd when the bucket URIs use the gs:// scheme.
    _GCS_ENDPOINT_URL = "https://storage.googleapis.com"
    _SUPPORTED_PREFIXES = ("gs://", "s3://")

    def __init__(
        self,
        access_credentials: Tuple[str, str],
        logging_level: str = "debug",
        cp_config: "S5CmdCpConfig" = None,
    ):
        """Store the credentials and copy options.

        Args:
            access_credentials: ``(access_key_id, secret_access_key)`` pair
                written to the temporary credentials file.
            logging_level: Stored on the instance; it is not currently
                forwarded to s5cmd.
            cp_config: Options rendered into every ``cp`` line. A default
                ``S5CmdCpConfig()`` is used when omitted.
        """
        # Always assign the attribute: previously a caller-supplied config
        # was silently dropped and later accesses raised AttributeError.
        self.cp_config = S5CmdCpConfig() if cp_config is None else cp_config

        self.access_credentials = access_credentials
        self.logging_level = logging_level
        self.uri_identifier_regex = r"^[a-zA-Z0-9_]{2}://"

    @contextmanager
    def _bucket_credentials(self):
        """Yield the path of a temporary credentials file.

        The file holds a ``[default]`` profile with the stored key pair and
        is removed when the context exits.
        """
        with NamedTemporaryFile(mode="w") as tmp:
            tmp.write("[default]\n")
            tmp.write(f"aws_access_key_id = {self.access_credentials[0]}\n")
            tmp.write(f"aws_secret_access_key = {self.access_credentials[1]}\n")
            # Flush so the child process sees the content when it opens the path.
            tmp.flush()
            yield tmp.name

    def _uri_prefix(self, uri: str) -> str:
        """Return the ``xx://`` prefix of *uri* or raise ValueError if it has none."""
        found = re.match(self.uri_identifier_regex, uri)
        if found is None:
            raise ValueError("All URIs must begin with a qualified bucket prefix.")
        return found.group()

    def _common_prefix(self, uris: List[str]) -> str:
        """Validate that all *uris* share one supported prefix and return it."""
        if not uris:
            raise ValueError("At least one URI must be supplied.")
        prefixes = {self._uri_prefix(uri) for uri in uris}
        if len(prefixes) != 1:
            raise ValueError(f"All URIs must begin with the same prefix. Found prefixes: {prefixes}")
        prefix = prefixes.pop()
        if prefix not in self._SUPPORTED_PREFIXES:
            raise ValueError(f"Only gs:// and s3:// URIs are supported. Found prefix: {prefix}")
        return prefix

    @staticmethod
    def _as_s3_uri(uri: str) -> str:
        """Rewrite a leading ``gs://`` to ``s3://``; leave the rest of the URI untouched."""
        if uri.startswith("gs://"):
            return "s3://" + uri[len("gs://"):]
        return uri

    def _base_args(self, prefix: str) -> List[str]:
        """Return the leading s5cmd arguments for buckets using *prefix*."""
        args = ["s5cmd"]
        if prefix == "gs://":
            args += ["--endpoint-url", self._GCS_ENDPOINT_URL]
        return args

    def _cp_line(self, source: str, destination: str) -> str:
        """Build one ``cp`` line for the ``s5cmd run`` batch."""
        # str(cp_config) may end with a space; split() normalises the flags.
        return " ".join(["cp", *str(self.cp_config).split(), source, destination])

    def _run_batch(self, base_args: List[str], commands: List[str], out_stream, action: str) -> int:
        """Feed *commands* to ``s5cmd run`` on stdin and return its exit code."""
        # An explicit None check keeps falsy-but-valid streams (e.g. fd 0) usable.
        sink = PIPE if out_stream is None else out_stream
        with self._bucket_credentials() as credentials:
            # Pass an argument list so a temp path containing whitespace
            # cannot be split into several arguments.
            args = [*base_args, "--credentials-file", credentials, "run"]
            logging.info("Executing %s command: %s", action, " ".join(args))
            proc = Popen(args, stdin=PIPE, stdout=sink, stderr=sink)
            proc.communicate(input="\n".join(commands).encode())
        return proc.returncode

    def download_from_bucket(self, uris: List[str], local_path: str, out_stream=None):
        """Download object(s) from a bucket to *local_path* using s5cmd.

        Args:
            uris: Bucket URIs to fetch. All must share the same ``gs://`` or
                ``s3://`` prefix. The list is not modified.
            local_path: Destination passed to every ``cp`` line.
            out_stream: Optional target for the process's stdout and stderr;
                when omitted both are captured and discarded.

        Returns:
            The exit code of the s5cmd process.

        Raises:
            ValueError: If *uris* is empty, a URI lacks a prefix, prefixes
                differ, or the prefix is not ``gs://`` / ``s3://``.
        """
        prefix = self._common_prefix(uris)

        # Work on rewritten copies rather than mutating the caller's list.
        commands = [self._cp_line(self._as_s3_uri(uri), local_path) for uri in uris]
        logging.info("s5cmd blob download command example: %s", commands[0])

        return self._run_batch(self._base_args(prefix), commands, out_stream, "download")

    def upload_to_bucket(self, files: List[str], bucket_uri: str, out_stream=None):
        """Upload local file(s) to *bucket_uri* using s5cmd.

        Args:
            files: Local paths to upload; one ``cp`` line is issued per path.
            bucket_uri: Destination URI beginning with ``gs://`` or ``s3://``.
            out_stream: Optional target for the process's stdout and stderr;
                when omitted both are captured and discarded.

        Returns:
            The exit code of the s5cmd process.

        Raises:
            ValueError: If *files* is empty, or *bucket_uri* lacks a prefix
                or uses an unsupported one.
        """
        prefix = self._common_prefix([bucket_uri])
        if not files:
            raise ValueError("At least one file must be supplied.")

        destination = self._as_s3_uri(bucket_uri)
        commands = [self._cp_line(path, destination) for path in files]
        logging.info("s5cmd blob upload command example: %s", commands[0])

        return self._run_batch(self._base_args(prefix), commands, out_stream, "upload")
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters