This repository has been archived by the owner on Sep 19, 2023. It is now read-only.

extended storage functions: ls, stream, rename, delete #38

Merged
merged 23 commits on Jan 14, 2019
1 change: 1 addition & 0 deletions .gitignore
@@ -1 +1,2 @@
*.pyc
.pytest_cache
6 changes: 4 additions & 2 deletions README.md
@@ -90,6 +90,7 @@ Multiple storage destinations can be defined with the `--storage_config` option
"storage_id_2": {
"type": "ssh",
"server": "my-server.com",
"basedir": "myrepo",
"user": "root",
"password": "root"
}
@@ -106,7 +107,8 @@ python entrypoint.py --storage_config storages.json --model_storage storage_id_2
If the configuration is not provided or a storage identifier is not set, the host filesystem is used.

Available storage types are:
* `ssh`: transfer files or directories using ssh, requires `server` name, `user` and `password`
* `ssh`: transfer files or directories using ssh, requires `server` name, `user` and `password` or `pkey`. `basedir` (optional) defines the base directory for relative paths.
* `local`: local file storage, `basedir` (optional) defines the base directory for relative paths
* `s3`: transfer files or directories using Amazon S3, requires `bucket` and `aws_credentials`
* `http`: transfer files only, using simple GET and POST requests. Requires `get_pattern` and `push_pattern`, which are URLs containing `%s` string placeholders expanded with the Python `%` operator, for instance `http://opennmt.net/%s/`
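
For illustration, here is a minimal sketch of how such a configuration could be dispatched to the storage classes added by this PR. The actual dispatch lives in `nmtwizard/storage.py`, whose diff is not rendered below, so the key-to-argument mapping (in particular the shape of `aws_credentials`) is an assumption; the `ssh` type is omitted since its class is not shown in this diff:

```python
import json

from nmtwizard.storages.http import HTTPStorage
from nmtwizard.storages.local import LocalStorage
from nmtwizard.storages.s3 import S3Storage

def build_storage(storage_id, config):
    """Map one storages.json entry to a storage object (assumed mapping)."""
    storage_type = config['type']
    if storage_type == 'local':
        return LocalStorage(storage_id, basedir=config.get('basedir'))
    if storage_type == 'http':
        return HTTPStorage(storage_id,
                           config['get_pattern'],
                           pattern_push=config.get('push_pattern'))
    if storage_type == 's3':
        credentials = config.get('aws_credentials', {})  # key names are assumptions
        return S3Storage(storage_id,
                         config['bucket'],
                         access_key_id=credentials.get('access_key_id'),
                         secret_access_key=credentials.get('secret_access_key'),
                         region_name=credentials.get('region_name'))
    raise ValueError('unsupported storage type: %s' % storage_type)

# Build every storage declared in the configuration file.
with open('storages.json') as config_file:
    storages = {storage_id: build_storage(storage_id, config)
                for storage_id, config in json.load(config_file).items()}
```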

@@ -320,7 +322,7 @@ Reload the model on the reserved resource. In its simplest form, this route will

Serving is currently supported by the following frameworks:

* `google_transate`
* `google_translate`
* `opennmt_lua`
* `opennmt_tf`

275 changes: 82 additions & 193 deletions nmtwizard/storage.py

Large diffs are not rendered by default.

Empty file added nmtwizard/storages/__init__.py
Empty file.
63 changes: 63 additions & 0 deletions nmtwizard/storages/generic.py
@@ -0,0 +1,63 @@
import os
import abc
import six

@six.add_metaclass(abc.ABCMeta)
class Storage(object):
"""Abstract class for storage implementations."""

def __init__(self, storage_id):
self._storage_id = storage_id

    # Non-conventional storage implementations might need to override these.
def join(self, path, *paths):
"""Build a path respecting storage prefix
"""
return os.path.join(path, *paths)

def split(self, path):
"""Split a path
"""
return os.path.split(path)

@abc.abstractmethod
def get(self, remote_path, local_path, directory=False):
"""Get a file or a directory from a storage to a local file
"""
raise NotImplementedError()

@abc.abstractmethod
def stream(self, remote_path, buffer_size=1024):
"""return a generator on a remote file
"""
raise NotImplementedError()

@abc.abstractmethod
def push(self, local_path, remote_path):
"""Push a local file on a remote storage
"""
raise NotImplementedError()

@abc.abstractmethod
def listdir(self, remote_path, recursive=False):
"""Return a dictionary with all files in the remote directory
"""
raise NotImplementedError()

@abc.abstractmethod
def delete(self, remote_path, recursive=False):
"""Delete a file or a directory from a storage
"""
raise NotImplementedError()

@abc.abstractmethod
def rename(self, old_remote_path, new_remote_path):
"""Delete a file or a directory from a storage
"""
raise NotImplementedError()

@abc.abstractmethod
def exists(self, remote_path):
"""Check if path is existing
"""
raise NotImplementedError()
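
As an aside, a hypothetical minimal subclass (not part of this PR) can make the contract above concrete; every abstract method must be provided, here backed by a plain in-memory dict:

```python
from nmtwizard.storages.generic import Storage

class MemoryStorage(Storage):
    """Toy storage keeping file contents in a dict, keyed by remote path."""

    def __init__(self, storage_id='memory'):
        super(MemoryStorage, self).__init__(storage_id)
        self._files = {}

    def get(self, remote_path, local_path, directory=False):
        # Directory download is omitted in this toy example.
        with open(local_path, 'wb') as f:
            f.write(self._files[remote_path])

    def stream(self, remote_path, buffer_size=1024):
        data = self._files[remote_path]
        def generate():
            for i in range(0, len(data), buffer_size):
                yield data[i:i + buffer_size]
        return generate()

    def push(self, local_path, remote_path):
        with open(local_path, 'rb') as f:
            self._files[remote_path] = f.read()

    def listdir(self, remote_path, recursive=False):
        return [p for p in self._files if p.startswith(remote_path)]

    def delete(self, remote_path, recursive=False):
        del self._files[remote_path]

    def rename(self, old_remote_path, new_remote_path):
        self._files[new_remote_path] = self._files.pop(old_remote_path)

    def exists(self, remote_path):
        return remote_path in self._files
```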
78 changes: 78 additions & 0 deletions nmtwizard/storages/http.py
@@ -0,0 +1,78 @@
"""Definition of `http` storage class"""

import os
import requests

from nmtwizard.storages.generic import Storage

class HTTPStorage(Storage):
"""Simple http file-only storage."""

def __init__(self, storage_id, pattern_get, pattern_push=None, pattern_list=None):
super(HTTPStorage, self).__init__(storage_id)
self._pattern_get = pattern_get
self._pattern_push = pattern_push
self._pattern_list = pattern_list

def get(self, remote_path, local_path, directory=False):
if not directory:
res = requests.get(self._pattern_get % remote_path)
if res.status_code != 200:
                raise RuntimeError('cannot get %s (response code %d)' % (remote_path, res.status_code))
if os.path.isdir(local_path):
local_path = os.path.join(local_path, os.path.basename(remote_path))
elif not os.path.exists(os.path.dirname(local_path)):
os.makedirs(os.path.dirname(local_path))
with open(local_path, "wb") as f:
f.write(res.content)
elif self._pattern_list is None:
            raise ValueError('http storage %s cannot handle directories' % self._storage_id)
else:
res = requests.get(self._pattern_list % remote_path)
if res.status_code != 200:
raise RuntimeError('Error when listing remote directory %s (status %d)' % (
remote_path, res.status_code))
data = res.json()
for f in data:
path = f["path"]
self.get(os.path.join(remote_path, path), os.path.join(local_path, path))

def stream(self, remote_path, buffer_size=1024):
res = requests.get(self._pattern_get % remote_path, stream=True)
if res.status_code != 200:
            raise RuntimeError('cannot get %s (response code %d)' % (remote_path, res.status_code))

def generate():
            for chunk in res.iter_content(chunk_size=buffer_size, decode_unicode=False):
yield chunk

return generate()

def push(self, local_path, remote_path):
if self._pattern_push is None:
            raise ValueError('http storage %s cannot handle POST requests' % self._storage_id)
if os.path.isfile(local_path):
with open(local_path, "rb") as f:
data = f.read()
res = requests.post(url=self._pattern_push % remote_path,
data=data,
headers={'Content-Type': 'application/octet-stream'})
if res.status_code != 200:
                raise RuntimeError('cannot post %s to %s (response code %d)' % (
local_path,
remote_path,
res.status_code))
else:
            raise NotImplementedError('http storage cannot handle directories')

def listdir(self, remote_path, recursive=False):
raise NotImplementedError()

def delete(self, remote_path, recursive=False):
raise NotImplementedError()

def rename(self, old_remote_path, new_remote_path):
raise NotImplementedError()

def exists(self, remote_path):
raise NotImplementedError()
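
A usage sketch, assuming a server that answers the patterns below; the get pattern reuses the README example, while the push URL and the `handle_chunk` helper are placeholders:

```python
from nmtwizard.storages.http import HTTPStorage

storage = HTTPStorage('storage_id_3',
                      'http://opennmt.net/%s/',
                      pattern_push='http://opennmt.net/push/%s/')

# Download a single file, then stream another one chunk by chunk.
storage.get('models/model.bin', '/tmp/model.bin')
for chunk in storage.stream('data/corpus.txt', buffer_size=4096):
    handle_chunk(chunk)  # placeholder for application code
```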
100 changes: 100 additions & 0 deletions nmtwizard/storages/local.py
@@ -0,0 +1,100 @@
"""Definition of `local` storage class"""

import shutil
import os

from nmtwizard.storages.generic import Storage

class LocalStorage(Storage):
"""Storage using the local filesystem."""

def __init__(self, storage_id=None, basedir=None):
super(LocalStorage, self).__init__(storage_id or "local")
self._basedir = basedir

def get(self, remote_path, local_path, directory=False):
if self._basedir:
remote_path = os.path.join(self._basedir, remote_path)
if directory:
shutil.copytree(remote_path, local_path)
else:
shutil.copy(remote_path, local_path)

def stream(self, remote_path, buffer_size=1024):
if self._basedir:
remote_path = os.path.join(self._basedir, remote_path)

def generate():
"""generator function to stream local file"""
with open(remote_path, "rb") as f:
for chunk in iter(lambda: f.read(buffer_size), b''):
yield chunk
return generate()

def push(self, local_path, remote_path):
if self._basedir:
remote_path = os.path.join(self._basedir, remote_path)
if os.path.isdir(local_path):
shutil.copytree(local_path, remote_path)
else:
if remote_path.endswith('/') or os.path.isdir(remote_path):
remote_path = os.path.join(remote_path, os.path.basename(local_path))
dirname = os.path.dirname(remote_path)
if os.path.exists(dirname):
if not os.path.isdir(dirname):
raise ValueError("%s is not a directory" % dirname)
else:
os.makedirs(dirname)
shutil.copy(local_path, remote_path)

def delete(self, remote_path, recursive=False):
if self._basedir:
remote_path = os.path.join(self._basedir, remote_path)
if recursive:
if not os.path.isdir(remote_path):
os.remove(remote_path)
else:
shutil.rmtree(remote_path, ignore_errors=True)
else:
if not os.path.isfile(remote_path):
raise ValueError("%s is not a file" % remote_path)
os.remove(remote_path)

def listdir(self, remote_path, recursive=False):
if self._basedir:
remote_path = os.path.join(self._basedir, remote_path)
listfile = []
if not os.path.isdir(remote_path):
raise ValueError("%s is not a directory" % remote_path)

def getfiles_rec(path):
"""recursive listdir"""
for f in os.listdir(path):
fullpath = os.path.join(path, f)
if self._basedir:
rel_fullpath = os.path.relpath(fullpath, self._basedir)
else:
rel_fullpath = fullpath
if os.path.isdir(fullpath):
if recursive:
getfiles_rec(fullpath)
else:
listfile.append(rel_fullpath+'/')
else:
listfile.append(rel_fullpath)

getfiles_rec(remote_path)

return listfile

    def rename(self, old_remote_path, new_remote_path):
        if self._basedir:
            old_remote_path = os.path.join(self._basedir, old_remote_path)
            new_remote_path = os.path.join(self._basedir, new_remote_path)
        os.rename(old_remote_path, new_remote_path)

def exists(self, remote_path):
if self._basedir:
remote_path = os.path.join(self._basedir, remote_path)
return os.path.exists(remote_path)
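
A round-trip sketch of the resulting API; the base directory and file names are illustrative:

```python
from nmtwizard.storages.local import LocalStorage

storage = LocalStorage(basedir='/tmp/storage')

# Pushing onto a path ending with '/' appends the file name.
storage.push('model.bin', 'models/')        # -> /tmp/storage/models/model.bin
print(storage.listdir('models'))            # ['models/model.bin']

# Stream the stored file back, chunk by chunk.
with open('copy.bin', 'wb') as f:
    for chunk in storage.stream('models/model.bin'):
        f.write(chunk)

storage.rename('models/model.bin', 'models/model.old')
storage.delete('models/model.old')
```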
110 changes: 110 additions & 0 deletions nmtwizard/storages/s3.py
@@ -0,0 +1,110 @@
"""Definition of `s3` storage class"""

import os
import boto3

from nmtwizard.storages.generic import Storage

class S3Storage(Storage):
"""Storage on Amazon S3."""

def __init__(self, storage_id, bucket_name, access_key_id=None, secret_access_key=None, region_name=None):
super(S3Storage, self).__init__(storage_id)
if access_key_id is None and secret_access_key is None and region_name is None:
session = boto3
else:
session = boto3.Session(
aws_access_key_id=access_key_id,
aws_secret_access_key=secret_access_key,
region_name=region_name)
self._s3 = session.resource('s3')
self._bucket_name = bucket_name
self._bucket = self._s3.Bucket(bucket_name)

def get(self, remote_path, local_path, directory=False):
if not directory:
if os.path.isdir(local_path):
local_path = os.path.join(local_path, os.path.basename(remote_path))
self._bucket.download_file(remote_path, local_path)
else:
objects = list(self._bucket.objects.filter(Prefix=remote_path))
if not objects:
raise RuntimeError('%s not found' % remote_path)
os.mkdir(local_path)
for obj in objects:
directories = obj.key.split('/')[1:-1]
if directories:
directory_path = os.path.join(local_path, os.path.join(*directories))
if not os.path.exists(directory_path):
os.makedirs(directory_path)
path = os.path.join(local_path, os.path.join(*obj.key.split('/')[1:]))
self._bucket.download_file(obj.key, path)

def push(self, local_path, remote_path):
if os.path.isfile(local_path):
if remote_path.endswith('/') or self.exists(remote_path+'/'):
remote_path = os.path.join(remote_path, os.path.basename(local_path))
self._bucket.upload_file(local_path, remote_path)
else:
for root, _, files in os.walk(local_path):
for filename in files:
path = os.path.join(root, filename)
relative_path = os.path.relpath(path, local_path)
s3_path = os.path.join(remote_path, relative_path)
self._bucket.upload_file(path, s3_path)

def stream(self, remote_path, buffer_size=1024):
body = self._s3.Object(self._bucket_name, remote_path).get()['Body']

def generate():
for chunk in iter(lambda: body.read(buffer_size), b''):
yield chunk

return generate()

def listdir(self, remote_path, recursive=False):
objects = list(self._bucket.objects.filter(Prefix=remote_path))
lsdir = {}
for obj in objects:
path = obj.key
p = path.find('/', len(remote_path)+1)
if not recursive and p != -1:
path = path[0:p+1]
lsdir[path] = 1
else:
lsdir[path] = 0
return lsdir.keys()

def delete(self, remote_path, recursive=False):
lsdir = self.listdir(remote_path, recursive)
if recursive:
if remote_path in lsdir or not lsdir:
raise ValueError("%s is not a directory" % remote_path)
else:
if remote_path not in lsdir:
raise ValueError("%s is not a file" % remote_path)

for key in lsdir:
self._s3.meta.client.delete_object(Bucket=self._bucket_name, Key=key)

def rename(self, old_remote_path, new_remote_path):
for obj in self._bucket.objects.filter(Prefix=old_remote_path):
src_key = obj.key
if not src_key.endswith('/'):
copy_source = self._bucket_name + '/' + src_key
if src_key == old_remote_path:
# it is a file that we are copying
dest_file_key = new_remote_path
else:
filename = src_key.split('/')[-1]
dest_file_key = new_remote_path + '/' + filename
self._s3.Object(self._bucket_name, dest_file_key).copy_from(CopySource=copy_source)
self._s3.Object(self._bucket_name, src_key).delete()

def exists(self, remote_path):
result = self._bucket.objects.filter(Prefix=remote_path)
try:
            obj = next(iter(result))
except StopIteration:
return False
return obj.key == remote_path or obj.key.startswith(remote_path+"/")
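
Finally, a usage sketch of the S3 backend; the bucket name and credentials are placeholders and must refer to a real bucket for this to run:

```python
from nmtwizard.storages.s3 import S3Storage

storage = S3Storage('storage_id_4', 'my-bucket',
                    access_key_id='AKIA...',          # placeholder credentials
                    secret_access_key='...',
                    region_name='us-east-1')

storage.push('model.bin', 'models/model.bin')
print(list(storage.listdir('models/', recursive=True)))

# rename copies each matching object then deletes the source key.
storage.rename('models/model.bin', 'archive/model.bin')
storage.delete('archive/model.bin')
```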