Skip to content

Commit

Permalink
Merge pull request #1122 from TeamCOMPAS/fix_h5sample_dimension_error
Browse files Browse the repository at this point in the history
add try-catch for h5sample if there are issues parsing the file
  • Loading branch information
avivajpeyi committed May 10, 2024
2 parents 637dad6 + 3c3a693 commit abcfe78
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 29 deletions.
18 changes: 13 additions & 5 deletions compas_python_utils/h5sample.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

import h5py
import numpy as np
import os

from compas_python_utils.h5view import printSummary
from compas_python_utils.h5copy import copyHDF5File
Expand Down Expand Up @@ -86,10 +87,17 @@ def sample_h5(
with h5py.File(compas_h5_filepath, "r") as compas_h5_file:
printSummary(compas_h5_filepath, compas_h5_file)

with h5py.File(output_filepath, 'w') as out_h5_file:
copyHDF5File(compas_h5_filepath, out_h5_file)
sampled_binary_seeds = np.random.choice(binary_seeds, size=n, replace=replace)
_sample(out_h5_file, seed_key, sampled_binary_seeds)
try:
with h5py.File(output_filepath, 'w') as out_h5_file:
copyHDF5File(compas_h5_filepath, out_h5_file)
sampled_binary_seeds = np.random.choice(binary_seeds, size=n, replace=replace)
_sample(out_h5_file, seed_key, sampled_binary_seeds)
except Exception as e:
print(f"Error sampling COMPAS h5 file: {e}")
# remove the output file if it was created
if os.path.exists(output_filepath):
os.remove(output_filepath)
return

print("Sampled file summary:")
with h5py.File(output_filepath, "r") as output_h5_file:
Expand Down Expand Up @@ -187,4 +195,4 @@ def main(): # pragma: no cover


if __name__ == "__main__": # pragma: no cover
main()
main()
48 changes: 24 additions & 24 deletions py_tests/test_h5sample.py
Original file line number Diff line number Diff line change
@@ -1,38 +1,38 @@
import os

import numpy as np
import pytest
from conftest import get_compas_data
from deepdiff import DeepDiff

from compas_python_utils import h5sample
import pytest


def test_sample(tmp_path, example_compas_output_path):
@pytest.mark.parametrize(
"test_kwargs",
[
dict(n=100, replace=True),
dict(frac=0.5, replace=False),
dict(frac=1.2, replace=True),
dict(seed_group="BSE_Double_Compact_Objects", frac=2.0, replace=True),
],
)
def test_sample(test_kwargs, tmp_path, fake_compas_output):
"""Test that h5sample can sample a file"""
np.random.seed(0)
init_file = example_compas_output_path

test_kwargs = [
dict(n=100, replace=True),
dict(frac=0.5, replace=False),
dict(frac=1.2, replace=True),
dict(seed_group="BSE_Double_Compact_Objects", frac=2.0, replace=True),
]

for kwg in test_kwargs:
# creating new h5 file to copy contents to
new_file = f"{tmp_path}/sampled_compas_out.h5"
if os.path.exists(new_file):
os.remove(new_file)
h5sample.sample_h5(init_file, new_file, **kwg)
assert os.path.exists(new_file), f"File {new_file} does not exist"
init_data = get_compas_data(init_file)
new_data = get_compas_data(new_file)
diff = DeepDiff(init_data, new_data)
assert (
len(diff) > 0
), f"The sampled file is the same as the original when using kwgs: {kwg}"
init_file = fake_compas_output

new_file = f"{tmp_path}/sampled_compas_out.h5"
if os.path.exists(new_file):
os.remove(new_file)
h5sample.sample_h5(init_file, new_file, **test_kwargs)
assert os.path.exists(new_file), f"File {new_file} does not exist"
init_data = get_compas_data(init_file)
new_data = get_compas_data(new_file)
diff = DeepDiff(init_data, new_data)
assert (
len(diff) > 0
), f"The sampled file is the same as the original when using kwgs: {test_kwargs}"


def test_argparser(capsys):
Expand Down

0 comments on commit abcfe78

Please sign in to comment.