finalised workflow structure
keegansmith21 committed Aug 1, 2023
1 parent ebd5e8d commit 9ce141c
Showing 3 changed files with 203 additions and 256 deletions.
academic_observatory_workflows/s5cmd.py: 139 changes (81 additions, 58 deletions)
@@ -1,66 +1,72 @@
 import re
 import shlex
 import logging
 from dataclasses import dataclass
-from typing import List, Tuple
+from typing import List, Tuple, Union
 from subprocess import Popen, PIPE
 from contextlib import contextmanager
 from tempfile import NamedTemporaryFile


 @dataclass
 class S5CmdCpConfig:
-    flatten_dir: bool = False
-    no_overwrite: bool = False
-    overwrite_if_size: bool = False
-    overwrite_if_newer: bool = False
-
-    @property
-    def cp_config_str(self):
-        cfg = [
-            "--flatten " if self.flatten_dir else "",
-            "--no-clobber " if self.no_overwrite else "",
-            "--if-size-differ " if self.overwrite_if_size else "",
-            "--if-source-newer " if self.overwrite_if_newer else "",
-        ]
-        return "".join(cfg)
-
-    def __str__(self):
-        return self.cp_config_str
-
-    def __bool__(self):
-        return bool(self.cp_config_str)
+    """Configuration for S5Cmd cp command
+
+    :param flatten_dir: Whether to flatten the directory structure
+    :param no_overwrite: Whether to not overwrite files if they already exist
+    :param overwrite_if_size: Whether to overwrite files only if source size differs
+    :param overwrite_if_newer: Whether to overwrite files only if source is newer
+    """
+
+    def __init__(
+        self,
+        flatten_dir: bool = False,
+        no_overwrite: bool = False,
+        overwrite_if_size: bool = False,
+        overwrite_if_newer: bool = False,
+    ):
+        self.flatten_dir = flatten_dir
+        self.no_overwrite = no_overwrite
+        self.overwrite_if_size = overwrite_if_size
+        self.overwrite_if_newer = overwrite_if_newer
+
+    def __str__(self):
+        cfg = [
+            "--flatten" if self.flatten_dir else "",
+            "--no-clobber" if self.no_overwrite else "",
+            "--if-size-differ" if self.overwrite_if_size else "",
+            "--if-source-newer" if self.overwrite_if_newer else "",
+        ]
+        cfg = [i for i in cfg if i]  # Remove empty strings
+        return " ".join(cfg)


 class S5Cmd:
     def __init__(
         self,
         access_credentials: Tuple[str, str],
-        logging_level: str = "debug",
+        logging_level: str = "info",
         out_stream: str = PIPE,
         cp_config: S5CmdCpConfig = None,
     ):
         if not cp_config:
             self.cp_config = S5CmdCpConfig()

         self.access_credentials = access_credentials
         self.logging_level = logging_level
-        self.output_stream = out_stream
+        self.out_stream = out_stream
         self.uri_identifier_regex = r"^[a-zA-Z0-9_]{2}://"

+    def _uri(self, uri: str):
+        return uri.replace("gs://", "s3://")
+
     def _initialise_command(self, uri: str):
         """Initializes the command for the given bucket URI.
         :param uri: The URI being accessed.
         :return: The initialized command."""
         cmd = "s5cmd"

         # Check that the uri prefixes are supported
-        bucket_prefix = re.match(self.uri_identifier_regex, uri).group(0)
-        if bucket_prefix not in ["gs://", "s3://"]:
-            raise ValueError(f"Only gs:// and s3:// URIs are supported. Found prefix: {bucket_prefix}")
+        uri_prefix = re.match(self.uri_identifier_regex, uri).group(0)
+        if uri_prefix not in ["gs://", "s3://"]:
+            raise ValueError(f"Only gs:// and s3:// URIs are supported. Found prefix: {uri_prefix}")

         # Amend the URIs with the s3:// prefix and add endpoint URL if required
-        if bucket_prefix == "gs://":
+        if uri_prefix == "gs://":
             cmd = " ".join([cmd, "--endpoint-url https://storage.googleapis.com"])

         return cmd
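
Note (illustration, not part of the commit): a minimal sketch of how the revised S5CmdCpConfig renders its flags; the import path is assumed from the file location.

from academic_observatory_workflows.s5cmd import S5CmdCpConfig  # assumed import path

# Each enabled option maps to one s5cmd cp flag; disabled options are dropped.
cfg = S5CmdCpConfig(flatten_dir=True, no_overwrite=True)
print(str(cfg))              # --flatten --no-clobber
print(str(S5CmdCpConfig()))  # empty string, so no extra flags are appended to cp
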
@@ -77,8 +83,15 @@ def _bucket_credentials(self):
         finally:
             pass

-    def download_from_bucket(self, uris: List[str], local_path: str):
-        """Downloads file(s) from a bucket using s5cmd and a supplied list of URIs."""
+    def download_from_bucket(self, uris: Union[List[str], str], local_path: str) -> Tuple[bytes, bytes, int]:
+        """Downloads file(s) from a bucket using s5cmd and a supplied list of URIs.
+        :param uris: The URI or list of URIs to download.
+        :param local_path: The local path to download to.
+        :return: A tuple of (stdout, stderr, s5cmd exit code).
+        """
+        if not isinstance(uris, list):
+            uris = [uris]

         # Check the integrity of the supplied URIs
         uri_prefixes = [re.match(self.uri_identifier_regex, uri) for uri in uris]
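
Note (illustration, not part of the commit): the body of _bucket_credentials is collapsed in this hunk. Judging from the imports (contextmanager, NamedTemporaryFile) and the call sites, it yields the path of a temporary credentials file that is handed to s5cmd via --credentials-file. A rough sketch of that pattern follows; the file format and field names here are assumptions, not necessarily what the repository writes.

from contextlib import contextmanager
from tempfile import NamedTemporaryFile
from typing import Tuple

@contextmanager
def bucket_credentials(access_credentials: Tuple[str, str]):
    # Hypothetical: write an AWS-style credentials file and yield its path
    # so the caller can pass it to s5cmd as --credentials-file <path>.
    key_id, secret = access_credentials
    with NamedTemporaryFile(mode="w", suffix=".txt") as credentials_file:
        credentials_file.write(f"[default]\naws_access_key_id = {key_id}\naws_secret_access_key = {secret}\n")
        credentials_file.flush()
        yield credentials_file.name
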
@@ -92,18 +105,13 @@ def download_from_bucket(self, uris: List[str], local_path: str):
             raise ValueError(f"Only gs:// and s3:// URIs are supported. Found prefix: {uri_prefix}")

         # Amend the URIs with the s3:// prefix and add endpoint URL if required
-        cmd = "s5cmd"
-        if uri_prefix == "gs://":
-            cmd += " --endpoint-url https://storage.googleapis.com"
-        for i, uri in enumerate(uris):
-            uris[i] = uri.replace("gs://", "s3://")
+        cmd = self._initialise_command(uris[0])

         # Make the run commands
         blob_cmds = []
-        for uri in uris:
+        for uri in map(self._uri, uris):
             blob_cmd = "cp"
-            if self.cp_config:
-                blob_cmd += f" {str(self.cp_config)}"
+            blob_cmd += f" {str(self.cp_config)}"
             blob_cmd += f" {uri} {local_path}"
             blob_cmds.append(blob_cmd)
         blob_stdin = "\n".join(blob_cmds)
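
Note (illustration, not part of the commit): after this change every URI becomes one cp line, and the whole batch is piped to a single s5cmd run process over stdin. A sketch mirroring the loop above, with made-up bucket and paths:

# Hypothetical inputs
uris = ["gs://example-bucket/a.jsonl.gz", "gs://example-bucket/b.jsonl.gz"]
local_path = "/tmp/downloads/"
cp_flags = "--no-clobber"  # i.e. str(S5CmdCpConfig(no_overwrite=True))

blob_cmds = [f"cp {cp_flags} {uri.replace('gs://', 's3://')} {local_path}" for uri in uris]
blob_stdin = "\n".join(blob_cmds)
# blob_stdin is written to the stdin of a command like:
#   s5cmd --endpoint-url https://storage.googleapis.com --credentials-file <tmp file> run
# and contains:
#   cp --no-clobber s3://example-bucket/a.jsonl.gz /tmp/downloads/
#   cp --no-clobber s3://example-bucket/b.jsonl.gz /tmp/downloads/
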
@@ -112,35 +120,50 @@ def download_from_bucket(self, uris: List[str], local_path: str):
         # Initialise credentials and execute
         with self._bucket_credentials() as credentials:
             cmd += f" --credentials-file {credentials} run"
             logging.info(f"Executing download command: {cmd}")
             proc = Popen(shlex.split(cmd), stdout=self.out_stream, stderr=self.out_stream, stdin=PIPE)
-            proc.communicate(input=blob_stdin.encode())
-            returncode = proc.wait()
-            return returncode
+            stdout, stderr = proc.communicate(input=blob_stdin.encode())
+            returncode = proc.wait()
+            if returncode > 0:
+                logging.warn(f"s5cmd cp failed with return code {returncode}: {stderr}")
+            return stdout, stderr, returncode

-    def upload_to_bucket(self, files: List[str], bucket_uri: str):
-        """Downloads file(s) from a bucket using s5cmd and a supplied list of URIs."""
+    def upload_to_bucket(self, files: Union[List[str], str], bucket_uri: str, cp_config: S5CmdCpConfig = None):
+        """Uploads file(s) to a bucket using s5cmd and a supplied bucket URI.
+        :param files: The file(s) to upload.
+        :param bucket_uri: The URI to upload to.
+        :return: A tuple of (stdout, stderr, s5cmd exit code).
+        """
+        if not isinstance(files, list):
+            files = [files]
+        if not cp_config:
+            cp_config = S5CmdCpConfig()
         cmd = self._initialise_command(bucket_uri)
         blob_stdin = "\n".join(files)

         # Initialise credentials and execute
         with self._bucket_credentials() as credentials:
-            logging.info(f"Executing download command: {cmd}")
-            cmd = " ".join([cmd, f" --credentials_file {credentials} cp {self.cp_config_str} {bucket_uri}"])
+            cmd = " ".join([cmd, f" --credentials-file {credentials} cp {cp_config} {self._uri(bucket_uri)}"])
             proc = Popen(shlex.split(cmd), shell=False, stdout=self.out_stream, stderr=self.out_stream)
-            proc.communicate(input=blob_stdin.encode())
+            stdout, stderr = proc.communicate(input=blob_stdin.encode())
             returncode = proc.wait()
-            return returncode
+            if returncode > 0:
+                logging.warn(f"s5cmd cp failed with return code {returncode}: {stderr}")
+            return stdout, stderr, returncode

-    def cat(self, blob_uri: str) -> Tuple[int, bytes, bytes]:
-        """Executes a s5cmd cat operation on a remote file"""
+    def cat(self, blob_uri: str) -> Tuple[bytes, bytes, int]:
+        """Executes a s5cmd cat operation on a remote file
+        :param blob_uri: The URI to execute the cat on.
+        :return: A tuple of (stdout, stderr, s5cmd exit code).
+        """
         cmd = self._initialise_command(blob_uri)

         # Initialise credentials and execute
         with self._bucket_credentials() as credentials:
-            logging.info(f"Executing download command: {cmd}")
-            cmd = " ".join([cmd, f" --credentials_file {credentials} cat {blob_uri}"])
-            proc = Popen(shlex.split(cmd), shell=False, stdout=self.out_stream, stderr=self.out_stream)
+            cmd = " ".join([cmd, f" --credentials-file {credentials} cat {self._uri(blob_uri)}"])
+            proc = Popen(shlex.split(cmd), shell=False, stdout=PIPE, stderr=PIPE)
             stdout, stderr = proc.communicate()
             returncode = proc.wait()
-            return returncode, stdout, stderr
+            if returncode > 0:
+                logging.warn(f"s5cmd cat failed with return code {returncode}: {stderr}")
+            return stdout, stderr, returncode
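
Note (illustration, not part of the commit): a hedged usage sketch of the class as it stands after this commit. Credential values, bucket names and paths are placeholders, the import path is assumed from the file location, and the (key, secret) pair is assumed to be whatever S3-compatible credentials the target bucket accepts.

from academic_observatory_workflows.s5cmd import S5Cmd, S5CmdCpConfig  # assumed import path

s5 = S5Cmd(access_credentials=("MY_KEY_ID", "MY_SECRET"))

# download_from_bucket now accepts a single URI or a list and returns (stdout, stderr, return code)
stdout, stderr, code = s5.download_from_bucket("gs://example-bucket/data.jsonl.gz", "/tmp/downloads/")
if code > 0:
    print(stderr)

# upload_to_bucket takes an optional per-call cp configuration
cfg = S5CmdCpConfig(overwrite_if_size=True)
stdout, stderr, code = s5.upload_to_bucket("/tmp/downloads/data.jsonl.gz", "gs://example-bucket/", cp_config=cfg)

# cat streams a remote object and returns its contents on stdout
stdout, stderr, code = s5.cat("gs://example-bucket/data.jsonl.gz")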