Merge pull request #1768 from davidmarin/google-download-chunks
GCS cat() now streams chunks incrementally (fixes #1674)
David Marin committed May 7, 2018
2 parents de79472 + 5474848 commit 540e442
Showing 7 changed files with 209 additions and 45 deletions.
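In a nutshell: rather than downloading an entire GCS blob to a temp file before yielding anything, cat() now asks for the blob in 8192-byte ranges and yields each range as soon as it arrives. A minimal sketch of that loop, independent of the real GCS client (fetch_range() and RangeError are hypothetical stand-ins for the ranged-download shim and google.api_core.exceptions.RequestRangeNotSatisfiable):

CHUNK_SIZE = 8192


class RangeError(Exception):
    """Stand-in for RequestRangeNotSatisfiable (HTTP 416)."""


def fetch_range(data, start, end):
    # stand-in for a ranged GET: returns bytes [start, end),
    # and answers 416 if start is already past the end
    if start >= len(data):
        raise RangeError()
    return data[start:end]


def stream(data):
    # the shape of the new GCSFilesystem._cat_blob() loop
    start = 0
    while True:
        try:
            chunk = fetch_range(data, start, start + CHUNK_SIZE)
        except RangeError:
            return  # we started exactly at the end of the blob
        yield chunk
        if len(chunk) < CHUNK_SIZE:
            return  # short read: end of blob
        start += CHUNK_SIZE

The two exit conditions mirror the real code below: a short chunk means the blob is exhausted, and a range that starts at the very end comes back as an HTTP 416.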
11 changes: 11 additions & 0 deletions mrjob/cat.py
@@ -80,6 +80,9 @@ def decompress(readable, path, bufsize=1024):
     possibly decompressing based on *path*.
 
     if *readable* appears to be a fileobj, pass it through as-is.
+
+    if *readable* does not have a ``read()`` method, assume that it's
+    a generator that yields chunks of bytes
     """
     if path.endswith('.gz'):
         return gunzip_stream(readable)
@@ -103,7 +106,15 @@ def is_compressed(path):
 def to_chunks(readable, bufsize=1024):
     """Convert *readable*, which is any object supporting ``read()``
     (e.g. fileobjs) to a stream of non-empty ``bytes``.
+
+    If *readable* has an ``__iter__`` method but not a ``read`` method,
+    pass through as-is.
     """
+    if hasattr(readable, '__iter__') and not hasattr(readable, 'read'):
+        for chunk in readable:
+            yield chunk
+        return
+
     while True:
         chunk = readable.read(bufsize)
         if chunk:
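The new dispatch means to_chunks() accepts either a fileobj or a bare iterable of byte chunks, such as the generator GCSFilesystem._cat_blob() now returns. A quick usage illustration (assumes mrjob at this commit):

from io import BytesIO

from mrjob.cat import to_chunks

# fileobj: read() is called repeatedly with bufsize
chunks = list(to_chunks(BytesIO(b'x' * 2500), bufsize=1024))
# chunk lengths: [1024, 1024, 452]

# iterable with no read() method: chunks pass through as-is,
# so already-chunked downloads aren't re-buffered
passthrough = list(to_chunks(iter([b'foo', b'bar'])))
# passthrough == [b'foo', b'bar']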
9 changes: 6 additions & 3 deletions mrjob/dataproc.py
@@ -38,6 +38,7 @@
 from mrjob.conf import combine_dicts
 from mrjob.fs.composite import CompositeFilesystem
 from mrjob.fs.gcs import GCSFilesystem
+from mrjob.fs.gcs import _download_as_string
 from mrjob.fs.gcs import is_gcs_uri
 from mrjob.fs.gcs import parse_gcs_uri
 from mrjob.fs.local import LocalFilesystem
@@ -891,9 +892,11 @@ def _get_new_driver_output_lines(self, driver_output_uri):
 
         try:
             # TODO: use start= kwarg once google-cloud-storage 1.9 is out
-            new_data = log_blob.download_as_string()[state['pos']:]
-        except google.api_core.exceptions.NotFound:
-            # handle race condition where blob was just created
+            #new_data = log_blob.download_as_string()[state['pos']:]
+            new_data = _download_as_string(log_blob, start=state['pos'])
+        except (google.api_core.exceptions.NotFound,
+                google.api_core.exceptions.RequestRangeNotSatisfiable):
+            # blob was just created, or no more data is available
             break
 
         state['buffer'] += new_data
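_get_new_driver_output_lines() polls the driver log repeatedly, so the win is per poll: the old line re-downloaded the whole log and sliced off the already-seen prefix, while the new call transfers only the bytes past state['pos'], and a 416 now just means "nothing new yet". A runnable sketch of that contract, with the two Fake exceptions as stand-ins for the google.api_core ones and the position bookkeeping simplified:

class FakeNotFound(Exception):
    """Stand-in for google.api_core.exceptions.NotFound."""


class FakeRangeError(Exception):
    """Stand-in for RequestRangeNotSatisfiable."""


def download_from(blob, start):
    # stand-in for _download_as_string(log_blob, start=start)
    if blob is None:
        raise FakeNotFound()      # blob not created yet
    if start >= len(blob):
        raise FakeRangeError()    # no bytes past start
    return blob[start:]


def poll(blob, state):
    # one pass of the polling loop above
    try:
        new_data = download_from(blob, state['pos'])
    except (FakeNotFound, FakeRangeError):
        return                    # the `break` in the diff
    state['buffer'] += new_data
    state['pos'] += len(new_data)


state = {'pos': 0, 'buffer': b''}
poll(None, state)                 # log not written yet: no-op
poll(b'step 1 of 2...\n', state)  # picks up only the unseen bytes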
109 changes: 99 additions & 10 deletions mrjob/fs/gcs.py
@@ -14,10 +14,13 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import base64
 import binascii
 import fnmatch
+import hashlib
 import logging
 from base64 import b64decode
+from io import BytesIO
 from tempfile import TemporaryFile
 
 from mrjob.cat import decompress
@@ -26,15 +29,20 @@
 from mrjob.runner import GLOB_RE
 
 try:
+    import google.api_core.exceptions
+    import google.cloud._helpers
+    import google.cloud.exceptions
     import google.cloud.storage.client
-    from google.api_core.exceptions import NotFound
+    import google.resumable_media
+    import google.resumable_media.requests
 except ImportError:
-    NotFound = None
     google = None
 
 
 log = logging.getLogger(__name__)
 
+# download this many bytes at once from cat()
+_CAT_CHUNK_SIZE = 8192
 
 
 def _path_glob_to_parsed_gcs_uri(path_glob):
     # support globs
@@ -115,7 +123,7 @@ def _ls(self, path_glob):
 
         try:
             bucket = self.get_bucket(bucket_name)
-        except NotFound:
+        except google.api_core.exceptions.NotFound:
             return  # treat nonexistent buckets as empty
 
         for blob in bucket.list_blobs(prefix=base_name):
@@ -135,19 +143,30 @@ def md5sum(self, path):
         return binascii.hexlify(b64decode(blob.md5_hash)).decode('ascii')
 
     def _cat_file(self, gcs_uri):
+        return decompress(self._cat_blob(gcs_uri), gcs_uri)
+
+    def _cat_blob(self, gcs_uri):
+        """:py:meth:`cat_file`, minus decompression."""
         blob = self._get_blob(gcs_uri)
 
         if not blob:
             return  # don't cat nonexistent files
 
-        with TemporaryFile(dir=self._local_tmp_dir) as temp:
-            blob.download_to_file(temp)
+        start = 0
 
-            # now read from that file
-            temp.seek(0)
+        while True:
+            end = start + _CAT_CHUNK_SIZE
+            try:
+                chunk = _download_as_string(blob, start=start, end=end)
+            except google.api_core.exceptions.RequestRangeNotSatisfiable:
+                return
 
-            for chunk in decompress(temp, gcs_uri):
-                yield chunk
+            yield chunk
+
+            if len(chunk) < _CAT_CHUNK_SIZE:
+                return
+
+            start = end
 
     def mkdir(self, dest):
         """Make a directory. This does nothing on GCS because there are
@@ -274,3 +293,73 @@ def parse_gcs_uri(uri):
         raise ValueError('Invalid GCS URI: %s' % uri)
 
     return components.netloc, components.path[1:]
+
+
+# temporary shim for incremental download, taken from
+# https://github.com/GoogleCloudPlatform/google-cloud-python
+# Remove this once google-cloud-storage>1.8.0 comes out.
+
+# note that this raises RequestRangeNotSatisfiable if start is at the
+# end of the blob
+def _download_as_string(blob, client=None, start=None, end=None):
+    string_buffer = BytesIO()
+    _download_to_file(
+        blob, string_buffer, client=client, start=start, end=end)
+    return string_buffer.getvalue()
+
+
+# don't call the functions below directly; they're just to support
+# _download_as_string()
+
+def _download_to_file(blob, file_obj, client=None, start=None, end=None):
+    download_url = blob._get_download_url()
+    headers = _get_encryption_headers(blob._encryption_key)
+    headers['accept-encoding'] = 'gzip'
+
+    transport = blob._get_transport(client)
+    try:
+        _do_download(
+            blob, transport, file_obj, download_url, headers, start, end)
+    except google.resumable_media.InvalidResponse as exc:
+        _raise_from_invalid_response(exc)
+
+
+def _do_download(blob, transport, file_obj, download_url, headers,
+                 start=None, end=None):
+    if blob.chunk_size is None:
+        download = google.resumable_media.requests.Download(
+            download_url, stream=file_obj, headers=headers,
+            start=start, end=end)
+        download.consume(transport)
+    else:
+        download = google.resumable_media.requests.ChunkedDownload(
+            download_url, blob.chunk_size, file_obj, headers=headers,
+            start=start if start else 0, end=end)
+
+        while not download.finished:
+            download.consume_next_chunk(transport)
+
+
+def _get_encryption_headers(key, source=False):
+    if key is None:
+        return {}
+
+    key = google.cloud._helpers._to_bytes(key)
+    key_hash = hashlib.sha256(key).digest()
+    key_hash = base64.b64encode(key_hash)
+    key = base64.b64encode(key)
+
+    if source:
+        prefix = 'X-Goog-Copy-Source-Encryption-'
+    else:
+        prefix = 'X-Goog-Encryption-'
+
+    return {
+        prefix + 'Algorithm': 'AES256',
+        prefix + 'Key': google.cloud._helpers._bytes_to_unicode(key),
+        prefix + 'Key-Sha256': (
+            google.cloud._helpers._bytes_to_unicode(key_hash)),
+    }
+
+
+def _raise_from_invalid_response(error):
+    raise google.cloud.exceptions.from_http_response(error.response)
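Worth noting how the two halves compose: _cat_blob() yields raw byte ranges, and decompress() (per the mrjob/cat.py change above) accepts that generator directly, so even gzipped blobs decompress on the fly without a temp file. A small end-to-end demo, with an in-memory generator standing in for the ranged GETs (assumes mrjob at this commit):

import gzip

from mrjob.cat import decompress

payload = b'hello\nworld\n'
gz = gzip.compress(payload)

# pretend these 5-byte pieces came back from successive ranged GETs
chunks = (gz[i:i + 5] for i in range(0, len(gz), 5))

# decompress() sees no read() method and treats it as a chunk generator
assert b''.join(decompress(chunks, 'gs://walrus/data/foo.gz')) == payload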
15 changes: 14 additions & 1 deletion tests/fs/test_gcs.py
@@ -17,6 +17,7 @@
 from hashlib import md5
 
 from mrjob.fs.gcs import GCSFilesystem
+from mrjob.fs.gcs import _CAT_CHUNK_SIZE
 
 from tests.compress import gzip_compress
 from tests.mock_google import MockGoogleTestCase
@@ -57,13 +58,25 @@ def test_cat_gz(self):
 
     def test_chunks_file(self):
         self.put_gcs_multi({
-            'gs://walrus/data/foo': b'foo\nfoo\n' * 1000
+            'gs://walrus/data/foo': b'foo\nfoo\n' * 10000
         })
 
         self.assertGreater(
             len(list(self.fs._cat_file('gs://walrus/data/foo'))),
             1)
 
+    def test_chunk_boundary(self):
+        # trying to read from end of file raises an exception, which we catch
+        data = b'a' * _CAT_CHUNK_SIZE + b'b' * _CAT_CHUNK_SIZE
+
+        self.put_gcs_multi({
+            'gs://walrus/data/foo': data,
+        })
+
+        self.assertEqual(
+            list(self.fs._cat_file('gs://walrus/data/foo')),
+            [b'a' * _CAT_CHUNK_SIZE, b'b' * _CAT_CHUNK_SIZE])
+
 
 class GCSFSTestCase(MockGoogleTestCase):
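The fixture bump from * 1000 to * 10000 is load-bearing: the old 8,000-byte payload now fits in a single 8,192-byte chunk, which would make assertGreater(..., 1) fail. The arithmetic:

_CAT_CHUNK_SIZE = 8192                        # from mrjob/fs/gcs.py above

assert len(b'foo\nfoo\n' * 1000) == 8000      # < 8192: only one chunk
assert len(b'foo\nfoo\n' * 10000) == 80000    # ~10 chunks, test passes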
7 changes: 6 additions & 1 deletion tests/mock_google/case.py
@@ -25,6 +25,7 @@
 from .dataproc import MockGoogleDataprocJobClient
 from .logging import MockGoogleLoggingClient
 from .storage import MockGoogleStorageClient
+from .storage import _mock_download_as_string_shim
 from tests.mr_two_step_job import MRTwoStepJob
 from tests.py2 import Mock
 from tests.py2 import patch
@@ -75,13 +76,17 @@ def setUp(self):
         self.start(patch('google.cloud.dataproc_v1.JobControllerClient',
                          self.job_client))
 
-        # TODO: mock this
         self.start(patch('google.cloud.logging.Client',
                          self.logging_client))
 
         self.start(patch('google.cloud.storage.client.Client',
                          self.storage_client))
 
+        self.start(patch('mrjob.dataproc._download_as_string',
+                         _mock_download_as_string_shim))
+        self.start(patch('mrjob.fs.gcs._download_as_string',
+                         _mock_download_as_string_shim))
+
         self.start(patch('time.sleep'))
 
     def auth_default(self, scopes=None):
75 changes: 48 additions & 27 deletions tests/mock_google/storage.py
@@ -18,6 +18,7 @@
 
 from google.api_core.exceptions import Conflict
 from google.api_core.exceptions import NotFound
+from google.api_core.exceptions import RequestRangeNotSatisfiable
 
 
 class MockGoogleStorageClient(object):
@@ -139,38 +140,34 @@ def __init__(self, name, bucket, chunk_size=None):
         self.md5_hash = None
 
     def delete(self):
-        fs = self.bucket.client.mock_gcs_fs
-
-        if (self.bucket.name not in fs or
-                self.name not in fs[self.bucket.name]['blobs']):
-            raise NotFound('DELETE https://www.googleapis.com/storage/v1/b'
-                           '/%s/o/%s: Not Found' %
-                           (self.bucket.name, self.name))
+        if (self.bucket.name not in self._fs or
+                self.name not in self._fs[self.bucket.name]['blobs']):
+            raise NotFound('DELETE %s: Not Found' % self._blob_uri())
 
-        del fs[self.bucket.name]['blobs'][self.name]
+        del self._fs[self.bucket.name]['blobs'][self.name]
 
-    def download_as_string(self):
-        fs = self.bucket.client.mock_gcs_fs
-
+    # this mocks a future version of this method which is
+    # currently only available in dev. Our code accesses the start and end
+    # keywords through mrjob.fs.gcs._download_as_string(). See
+    # _mock_download_as_string_shim() below.
+    def download_as_string(self, client=None, start=None, end=None):
         try:
-            return fs[self.bucket.name]['blobs'][self.name]['data']
+            data = self._fs[self.bucket.name]['blobs'][self.name]['data']
         except KeyError:
-            raise NotFound('GET https://www.googleapis.com/download/storage'
-                           '/v1/b/%s/o/%s?alt=media: Not Found' %
-                           (self.bucket.name, self.name))
+            raise NotFound('GET %s?alt=media: Not Found' % self._blob_uri())
+
+        if start is not None and start >= len(data):
+            # it doesn't care if *end* exceeds the range
+            raise RequestRangeNotSatisfiable(
+                'GET %s?alt=media: Request range not satisfiable' %
+                self._blob_uri())
+
+        return data[start:end]
 
     def download_to_file(self, file_obj):
         data = self.download_as_string()
         file_obj.write(data)
 
-    def _set_md5_hash(self):
-        # call this when we upload data, or when we _get_blob
-        try:
-            self.md5_hash = b64encode(
-                md5(self.download_as_string()).digest())
-        except NotFound:
-            pass
-
     @property
     def size(self):
         try:
@@ -185,15 +182,39 @@ def upload_from_filename(self, filename):
         self.upload_from_string(data)
 
     def upload_from_string(self, data):
-        fs = self.bucket.client.mock_gcs_fs
-
-        if self.bucket.name not in fs:
+        if self.bucket.name not in self._fs:
             raise NotFound('POST https://www.googleapis.com/upload/storage'
                            '/v1/b/%s/o?uploadType=multipart: Not Found' %
                            self.bucket.name)
 
-        fs_objs = fs[self.bucket.name]['blobs']
+        fs_objs = self._fs[self.bucket.name]['blobs']
         fs_obj = fs_objs.setdefault(self.name, dict(data=b''))
         fs_obj['data'] = data
 
         self._set_md5_hash()
+
+    def _blob_uri(self):
+        # used for error messages
+        return ('https://www.googleapis.com/download/storage'
+                '/v1/b/%s/o/%s' % (self.bucket.name, self.name))
+
+    @property
+    def _fs(self):
+        return self.bucket.client.mock_gcs_fs
+
+    def _set_md5_hash(self):
+        # call this when we upload data, or when we Bucket.get_blob()
+
+        try:
+            # don't call download_as_string() because we need to mock
+            # exceptions from it
+            data = self._fs[self.bucket.name]['blobs'][self.name]['data']
+        except KeyError:
+            return
+
+        self.md5_hash = b64encode(md5(data).digest())
+
+
+# mock mrjob.fs.gcs._download_as_string(), which is a shim
+def _mock_download_as_string_shim(blob, client=None, start=None, end=None):
+    return blob.download_as_string(client=client, start=start, end=end)
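The mock leans on Python slice semantics to imitate HTTP range behavior: an *end* past the blob is harmless, much like an overshooting Range header, so only *start* needs policing. Concretely:

data = b'0123456789'

assert data[8:9999] == b'89'   # end overshoots: returns what exists
assert data[10:20] == b''      # start == len(data): the real service
                               # answers 416, so the mock raises
                               # RequestRangeNotSatisfiable instead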
