This repository has been archived by the owner on Aug 26, 2020. It is now read-only.
/
_files.py
164 lines (129 loc) · 5.33 KB
/
_files.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
"""Placeholder docstring"""
from __future__ import absolute_import
import contextlib
import json
import os
import shutil
import tarfile
import tempfile
import boto3
from six.moves.urllib import parse
from sagemaker_containers import _env, _params
def write_success_file():  # type: () -> None
    """Write an empty 'success' marker file to the output directory.

    SageMaker treats the mere presence of this file as the signal that
    training completed successfully; its content does not matter.

    See: https://docs.aws.amazon.com/sagemaker/latest/dg/your-algorithms-training-algo.html
    """
    success_path = os.path.join(_env.output_dir, "success")
    write_file(success_path, "")
def write_failure_file(failure_msg):  # type: (str) -> None
    """Write the failure description to the 'failure' marker file.

    Intended to be called after all algorithm output (for example, logging)
    completes when training fails. In a DescribeTrainingJob response, Amazon
    SageMaker returns the first 1024 characters from this file as
    FailureReason.

    See: https://docs.aws.amazon.com/sagemaker/latest/dg/your-algorithms-training-algo.html

    Args:
        failure_msg: The description of failure
    """
    write_file(os.path.join(_env.output_dir, "failure"), failure_msg)
@contextlib.contextmanager
def tmpdir(suffix="", prefix="tmp", directory=None):  # type: (str, str, str) -> None
    """Create a temporary directory with a context manager.

    The directory is deleted when the context exits, even if the body of the
    ``with`` block raises an exception (the previous implementation only
    removed it on the success path, leaking the directory on error).

    The prefix, suffix, and dir arguments are the same as for mkstemp().

    Args:
        suffix (str): If suffix is specified, the file name will end with that suffix,
            otherwise there will be no suffix.
        prefix (str): If prefix is specified, the file name will begin with that prefix;
            otherwise, a default prefix is used.
        directory (str): If directory is specified, the file will be created in that directory;
            otherwise, a default directory is used.

    Returns:
        str: path to the directory
    """
    tmp = tempfile.mkdtemp(suffix=suffix, prefix=prefix, dir=directory)
    try:
        yield tmp
    finally:
        # Always clean up, including when the caller's block raised.
        shutil.rmtree(tmp)
def write_file(path, data, mode="w"):  # type: (str, str, str) -> None
    """Persist *data* to the file at *path*.

    Args:
        path (str): destination file path.
        data (str): content to write.
        mode (str): mode passed straight through to ``open`` (default "w").
    """
    with open(path, mode) as target:
        target.write(data)
def read_file(path, mode="r"):
    """Return the full contents of the file at *path*.

    Args:
        path (str): path to the file.
        mode (str): mode in which the file is opened (default "r").

    Returns:
        The entire file content, as returned by ``file.read()``.
    """
    with open(path, mode) as source:
        return source.read()
def read_json(path):  # type: (str) -> dict
    """Parse the JSON file at *path*.

    Args:
        path (str): Path to the file.

    Returns:
        (dict[object, object]): A dictionary representation of the JSON file.
    """
    with open(path, "r") as json_file:
        return json.load(json_file)
def download_and_extract(uri, path):  # type: (str, str) -> None
    """Download, prepare and install a compressed tar file from S3 or local directory as
    an entry point.

    SageMaker Python SDK saves the user provided entry points as compressed tar files in S3.

    Args:
        uri (str): the location of the entry point: an s3:// url, a local
            directory, a local tarball, or a single local file.
        path (str): The path where the script will be installed. Nothing is
            downloaded or installed if the path already has the user entry point.
    """
    if not os.path.exists(path):
        os.makedirs(path)
    if os.listdir(path):
        # Entry point already installed; nothing to do.
        return
    with tmpdir() as tmp:
        if uri.startswith("s3://"):
            dst = os.path.join(tmp, "tar_file")
            s3_download(uri, dst)
            # The SDK uploads entry points as gzip tarballs.
            with tarfile.open(name=dst, mode="r:gz") as t:
                _safe_extractall(t, path)
        elif os.path.isdir(uri):
            if uri == path:
                return
            # copytree requires that the destination not exist yet; path was
            # created (or already present) and verified empty above.
            shutil.rmtree(path)
            shutil.copytree(uri, path)
        elif tarfile.is_tarfile(uri):
            # "r:*" auto-detects plain/gzip/bzip2 tars. is_tarfile() accepts
            # all of them, so the previous hard-coded "r:gz" crashed on any
            # non-gzip archive that had just passed the check.
            with tarfile.open(name=uri, mode="r:*") as t:
                _safe_extractall(t, path)
        else:
            shutil.copy2(uri, path)


def _safe_extractall(tar, path):  # type: (tarfile.TarFile, str) -> None
    """Extract *tar* into *path*, rejecting members whose resolved location
    escapes *path* (tar path traversal, CVE-2007-4559)."""
    base = os.path.abspath(path)
    for member in tar.getmembers():
        target = os.path.abspath(os.path.join(base, member.name))
        if target != base and not target.startswith(base + os.sep):
            raise ValueError("Tar member outside of extraction path: %s" % member.name)
    tar.extractall(path=path)
def s3_download(url, dst):  # type: (str, str) -> None
    """Download a single object from S3 to a local file.

    Args:
        url (str): the s3 url of the file.
        dst (str): the destination where the file will be saved.

    Raises:
        ValueError: if *url* does not use the "s3" scheme.
    """
    parsed = parse.urlparse(url)
    if parsed.scheme != "s3":
        raise ValueError("Expecting 's3' scheme, got: %s in %s" % (parsed.scheme, parsed))
    bucket = parsed.netloc
    key = parsed.path.lstrip("/")
    # Explicit AWS_REGION takes precedence over the SageMaker region env var.
    region = os.environ.get("AWS_REGION", os.environ.get(_params.REGION_NAME_ENV))
    boto3.resource("s3", region_name=region).Bucket(bucket).download_file(key, dst)