Skip to content

Commit

Permalink
fix: support content_type in FileSystemInput (#1073)
Browse files Browse the repository at this point in the history
  • Loading branch information
chuyang-deng committed Oct 2, 2019
1 parent a1b63b4 commit b006d73
Show file tree
Hide file tree
Showing 4 changed files with 45 additions and 3 deletions.
10 changes: 9 additions & 1 deletion src/sagemaker/inputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,12 @@ class FileSystemInput(object):
"""

def __init__(
self, file_system_id, file_system_type, directory_path, file_system_access_mode="ro"
self,
file_system_id,
file_system_type,
directory_path,
file_system_access_mode="ro",
content_type=None,
):
"""Create a new file system input used by an SageMaker training job.
Expand Down Expand Up @@ -144,3 +149,6 @@ def __init__(
}
}
}

if content_type:
self.config["ContentType"] = content_type
5 changes: 4 additions & 1 deletion tests/integ/file_system_input_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,10 @@ def _ami_id_for_region(sagemaker_session):
def _connect_ec2_instance(ec2_instance):
public_ip_address = ec2_instance.public_ip_address
connected_instance = Connection(
host=public_ip_address, port=22, user="ec2-user", connect_kwargs={"key_filename": KEY_PATH}
host=public_ip_address,
port=22,
user="ec2-user",
connect_kwargs={"key_filename": [KEY_PATH]},
)
return connected_instance

Expand Down
6 changes: 5 additions & 1 deletion tests/integ/test_tf_efs_fsx.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,8 +73,12 @@ def test_mnist_efs(efs_fsx_setup, sagemaker_session, cpu_instance_type):
)

file_system_efs_id = efs_fsx_setup["file_system_efs_id"]
content_type = "application/json"
file_system_input = FileSystemInput(
file_system_id=file_system_efs_id, file_system_type="EFS", directory_path=EFS_DIR_PATH
file_system_id=file_system_efs_id,
file_system_type="EFS",
directory_path=EFS_DIR_PATH,
content_type=content_type,
)
with timeout(minutes=TRAINING_DEFAULT_TIMEOUT_MINUTES):
estimator.fit(inputs=file_system_input, job_name=unique_name_from_base("test-mnist-efs"))
Expand Down
27 changes: 27 additions & 0 deletions tests/unit/test_inputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,33 @@ def test_file_system_input_all_arguments():
assert actual.config == expected


def test_file_system_input_content_type():
file_system_id = "fs-0a48d2a1"
file_system_type = "FSxLustre"
directory_path = "tensorflow"
file_system_access_mode = "rw"
content_type = "application/json"
actual = FileSystemInput(
file_system_id=file_system_id,
file_system_type=file_system_type,
directory_path=directory_path,
file_system_access_mode=file_system_access_mode,
content_type=content_type,
)
expected = {
"DataSource": {
"FileSystemDataSource": {
"FileSystemId": file_system_id,
"FileSystemType": file_system_type,
"DirectoryPath": directory_path,
"FileSystemAccessMode": "rw",
}
},
"ContentType": content_type,
}
assert actual.config == expected


def test_file_system_input_type_invalid():
with pytest.raises(ValueError) as excinfo:
file_system_id = "fs-0a48d2a1"
Expand Down

0 comments on commit b006d73

Please sign in to comment.