Skip to content

Commit

Permalink
fix: properly close files in lineage queries and tests (#4587)
Browse files Browse the repository at this point in the history
Closes #4458
  • Loading branch information
jmahlik committed Apr 25, 2024
1 parent ed390dd commit 2a52478
Show file tree
Hide file tree
Showing 7 changed files with 24 additions and 14 deletions.
4 changes: 2 additions & 2 deletions src/sagemaker/lineage/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -335,8 +335,8 @@ def _get_legend_line(self, component_name):

def _add_legend(self, path):
"""Embed legend to html file generated by pyvis."""
f = open(path, "r")
content = self.BeautifulSoup(f, "html.parser")
with open(path, "r") as f:
content = self.BeautifulSoup(f, "html.parser")

legend = """
<div style="display: inline-block; font-size: 1vw; font-family: verdana;
Expand Down
3 changes: 2 additions & 1 deletion tests/data/sip/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,8 @@ def main():
)

model_dir = os.environ.get("SM_MODEL_DIR")
pkl.dump(bst, open(model_dir + "/model.bin", "wb"))
with open(model_dir + "/model.bin", "wb") as f:
pkl.dump(bst, f)


if __name__ == "__main__":
Expand Down
4 changes: 2 additions & 2 deletions tests/integ/sagemaker/lineage/test_lineage_visualize.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,8 +142,8 @@ def test_graph_visualize(sagemaker_session, extract_data_from_html):
lq_result.visualize(path="testGraph.html")

# check generated graph info
fo = open("testGraph.html", "r")
lines = fo.readlines()
with open("testGraph.html", "r") as fo:
lines = fo.readlines()
for line in lines:
if "nodes = " in line:
node = line
Expand Down
3 changes: 2 additions & 1 deletion tests/integ/sagemaker/workflow/test_workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -527,7 +527,8 @@ def test_one_step_ingestion_pipeline(

temp_flow_path = "./ingestion.flow"
with cleanup_feature_group(feature_group):
json.dump(ingestion_only_flow, open(temp_flow_path, "w"))
with open(temp_flow_path, "w") as f:
json.dump(ingestion_only_flow, f)

data_wrangler_processor = DataWranglerProcessor(
role=role,
Expand Down
6 changes: 4 additions & 2 deletions tests/integ/test_sagemaker_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,8 @@ def expected_merged_config():
expected_merged_config_file_path = os.path.join(
CONFIG_DATA_DIR, "expected_output_config_after_merge.yaml"
)
return yaml.safe_load(open(expected_merged_config_file_path, "r").read())
with open(expected_merged_config_file_path, "r") as f:
return yaml.safe_load(f.read())


@pytest.fixture(scope="module")
Expand Down Expand Up @@ -171,7 +172,8 @@ def test_config_download_from_s3_and_merge(
CONFIG_DATA_DIR, "sample_additional_config_for_merge.yaml"
)

config_file_1_as_yaml = open(config_file_1_local_path, "r").read()
with open(config_file_1_local_path, "r") as f:
config_file_1_as_yaml = f.read()
s3_uri_config_1 = os.path.join(s3_uri_prefix, "config_1.yaml")

# Upload S3 files in case they dont already exist
Expand Down
12 changes: 8 additions & 4 deletions tests/unit/sagemaker/local/test_local_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,9 +221,12 @@ def test_write_config_file(LocalSession, tmpdir):
assert os.path.exists(resource_config_file)
assert os.path.exists(input_data_config_file)

hyperparameters_data = json.load(open(hyperparameters_file))
resource_config_data = json.load(open(resource_config_file))
input_data_config_data = json.load(open(input_data_config_file))
with open(hyperparameters_file) as f:
hyperparameters_data = json.load(f)
with open(resource_config_file) as f:
resource_config_data = json.load(f)
with open(input_data_config_file) as f:
input_data_config_data = json.load(f)

# Validate HyperParameters
for k, v in HYPERPARAMETERS.items():
Expand Down Expand Up @@ -280,7 +283,8 @@ def test_write_config_files_input_content_type(LocalSession, tmpdir):
sagemaker_container.write_config_files(host, HYPERPARAMETERS, input_data_config)

assert os.path.exists(input_data_config_file)
parsed_input_config = json.load(open(input_data_config_file))
with open(input_data_config_file) as f:
parsed_input_config = json.load(f)
# Validate Input Data Config
for channel in input_data_config:
assert channel["ChannelName"] in parsed_input_config
Expand Down
6 changes: 4 additions & 2 deletions tests/unit/sagemaker/serializers/test_serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -345,13 +345,15 @@ def test_data_serializer_raw(data_serializer):
input_image = image.read()
input_image_data = data_serializer.serialize(input_image)
validation_image_file_path = os.path.join(DATA_DIR, "", "cuteCat.raw")
validation_image_data = open(validation_image_file_path, "rb").read()
with open(validation_image_file_path, "rb") as f:
validation_image_data = f.read()
assert input_image_data == validation_image_data


def test_data_serializer_file_like(data_serializer):
input_image_file_path = os.path.join(DATA_DIR, "", "cuteCat.jpg")
validation_image_file_path = os.path.join(DATA_DIR, "", "cuteCat.raw")
input_image_data = data_serializer.serialize(input_image_file_path)
validation_image_data = open(validation_image_file_path, "rb").read()
with open(validation_image_file_path, "rb") as f:
validation_image_data = f.read()
assert input_image_data == validation_image_data

0 comments on commit 2a52478

Please sign in to comment.