diff --git a/.ci/scripts/check_deprecated_terms.py b/.ci/scripts/check_deprecated_terms.py index e255677e..23b9bc58 100644 --- a/.ci/scripts/check_deprecated_terms.py +++ b/.ci/scripts/check_deprecated_terms.py @@ -1,11 +1,11 @@ #!/usr/bin/env python3 # .ci/scripts/check_deprecated_terms.py - + import os import re import sys import argparse - + # ---- Term sets ---- # Block these when you're on the *2.x* branch (i.e., forbid legacy 1.x names): TERMS_1X = [ @@ -16,7 +16,7 @@ "load-worker-coordinator-hosts", "execute-test", ] - + # Block these when you're on the *1.x* branch (i.e., forbid 2.x names): TERMS_2X = [ "cluster-configs", @@ -25,15 +25,13 @@ "run-test", "test-run", ] - + SKIP_DIRS = {".git", "venv", "__pycache__", ".pytest_cache", ".ci", "tests"} VALID_EXTENSIONS = (".py", ".yml", ".yaml", ".md", ".sh", ".json", ".txt") - -SUPRESS_MARKERS = { - "block-1x": "check-deprecated-terms-disable-1x", - "block-2x": "check-deprecated-terms-disable-2x" -} - + +SUPRESS_MARKERS = {"block-1x": "check-deprecated-terms-disable-1x", "block-2x": "check-deprecated-terms-disable-2x"} + + def generate_variants(term: str) -> set[str]: base = term.replace("-", " ").replace("_", " ") words = base.split() @@ -43,24 +41,27 @@ def generate_variants(term: str) -> set[str]: variants.add("_".join(words)) variants.add("".join([w.capitalize() for w in words])) # PascalCase variants.add(words[0] + "".join([w.capitalize() for w in words[1:]])) # camelCase - + # Optional: flip order for 2-word terms, but avoid silly "-ip" flips creating noise if len(words) == 2 and not words[1].lower() == "ip": variants.add("-".join(words[::-1])) variants.add("_".join(words[::-1])) variants.add(words[1] + words[0].capitalize()) # camelCase reverse return variants - + + def build_patterns(terms: list[str]) -> list[re.Pattern]: pats = [] for t in terms: for v in generate_variants(t): pats.append(re.compile(re.escape(v), re.IGNORECASE)) return pats - + + def should_check_file(path: str) -> bool: return path.endswith(VALID_EXTENSIONS) - + + def walk_and_check(patterns: list[re.Pattern], mode: str) -> int: error_found = 0 suppress_marker = SUPRESS_MARKERS.get(mode) @@ -86,24 +87,25 @@ def walk_and_check(patterns: list[re.Pattern], mode: str) -> int: except Exception as e: print(f"[Warning] Skipped file {full_path}: {e}") return error_found - + + def main(): p = argparse.ArgumentParser(description="Check forbidden term set by mode or env.") p.add_argument("--mode", choices=["block-1x", "block-2x"], default=os.getenv("OSB_TERM_MODE")) args = p.parse_args() - + mode = args.mode if not mode: print("No mode provided (use --mode block-1x | block-2x or set OSB_TERM_MODE). Exiting 0.") sys.exit(0) - + if mode == "block-1x": terms = TERMS_1X banner = "❌ 1.x terms found in 2.x branch. Replace with 2.x names." else: terms = TERMS_2X banner = "❌ 2.x terms found in 1.x branch. Replace with 1.x names." - + patterns = build_patterns(terms) failed = walk_and_check(patterns, mode) if failed: @@ -111,6 +113,7 @@ def main(): sys.exit(1) print("✅ No forbidden terms found for", mode) sys.exit(0) - + + if __name__ == "__main__": main() diff --git a/AGENTS.md b/AGENTS.md index 97ef3790..8ca68902 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -17,7 +17,8 @@ source .venv/bin/activate # Activate virtual environment ## Common Commands ```bash -make lint # Run ruff check on all Python source files +make format # Auto-format all Python source files (run before committing) +make lint # Run ruff check + ruff format --check (enforced in CI) make test # Run unit tests (pytest tests/) pytest tests/path/to/test_file.py::TestClass::test_method # Run a single test make it # Run integration tests via tox (requires Java, Docker; ~30 min) @@ -29,9 +30,9 @@ make clean # Remove build artifacts, caches, tox environments ## Code Style -- **Linter**: [ruff](https://docs.astral.sh/ruff/) (configured in `pyproject.toml` under `[tool.ruff]`) +- **Linter + formatter**: [ruff](https://docs.astral.sh/ruff/) (configured in `pyproject.toml` under `[tool.ruff]`) - **Max line length**: 180 characters -- Run `make lint` before committing; CI enforces this on every PR +- Run `ruff format .` then `make lint` before committing; CI enforces both on every PR ## Architecture diff --git a/DEVELOPER_GUIDE.md b/DEVELOPER_GUIDE.md index 9860fd69..af1d2d9e 100644 --- a/DEVELOPER_GUIDE.md +++ b/DEVELOPER_GUIDE.md @@ -9,6 +9,7 @@ Apache Solr Orbit. - [Importing the project into an IDE](#importing-the-project-into-an-ide) - [Setting Up a Local Solr Instance (Optional)](#setting-up-a-local-solr-instance-optional) - [Running Tests](#running-tests) +- [Code Style](#code-style) - [Submitting a Pull Request](#submitting-a-pull-request) - [Developing Breaking Changes](#developing-breaking-changes) - [Miscellaneous](#miscellaneous) @@ -110,9 +111,29 @@ Integration tests require a running Solr instance (local or Docker). make it ``` +## Code Style + +The project uses [ruff](https://docs.astral.sh/ruff/) for both linting and formatting. +`make lint` runs both checks and is enforced in CI on every PR. + +### Before committing + +Always run: + +```bash +make format # auto-format all Python files +make lint # verify lint and format are clean +``` + +Or set up your editor to format on save — the [ruff VS Code extension](https://marketplace.visualstudio.com/items?itemName=charliermarsh.ruff) and the [ruff PyCharm plugin](https://plugins.jetbrains.com/plugin/20574-ruff) both support this. + +Configuration lives in `pyproject.toml` under `[tool.ruff]`. Key settings: +- **Max line length**: 180 characters +- **Rules**: pycodestyle (`E`) and pyflakes (`F`) + ## Submitting a Pull Request -1. **Run tests**: `make test` (and `make it` if applicable). +1. **Run tests and lint**: `make test` and `make lint` (and `make it` if applicable). 2. **Rebase** onto the latest `main` before opening a PR. 3. Open the PR, referencing the related issue (`Closes #123`). 4. Respond to review comments; squash commits if asked. diff --git a/Makefile b/Makefile index 6f927f82..110ae19c 100644 --- a/Makefile +++ b/Makefile @@ -78,7 +78,10 @@ tox-env-clean: lint: ruff check . - # ruff format --check . # uncomment once the codebase has been formatted + ruff format --check . + +format: + ruff format . test: develop pytest tests/ @@ -105,4 +108,4 @@ release-checks: release: release-checks clean it ./release.sh $(release_version) $(next_version) -.PHONY: install clean python-caches-clean tox-env-clean test it it312 benchmark coverage release release-checks pyinst +.PHONY: install clean python-caches-clean tox-env-clean lint format test it it312 benchmark coverage release release-checks pyinst diff --git a/benchmarks/__init__.py b/benchmarks/__init__.py index 5047a451..f5768141 100644 --- a/benchmarks/__init__.py +++ b/benchmarks/__init__.py @@ -16,7 +16,7 @@ # not use this file except in compliance with the License. # You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an diff --git a/benchmarks/worker_coordinator/__init__.py b/benchmarks/worker_coordinator/__init__.py index 5047a451..f5768141 100644 --- a/benchmarks/worker_coordinator/__init__.py +++ b/benchmarks/worker_coordinator/__init__.py @@ -16,7 +16,7 @@ # not use this file except in compliance with the License. # You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an diff --git a/benchmarks/worker_coordinator/parsing_test.py b/benchmarks/worker_coordinator/parsing_test.py index 5ecebb71..23691c6d 100644 --- a/benchmarks/worker_coordinator/parsing_test.py +++ b/benchmarks/worker_coordinator/parsing_test.py @@ -16,7 +16,7 @@ # not use this file except in compliance with the License. # You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an @@ -35,49 +35,35 @@ from solrorbit.worker_coordinator import runner -@pytest.mark.benchmark( - group="parse", - warmup="on", - warmup_iterations=10000, - disable_gc=True -) + +@pytest.mark.benchmark(group="parse", warmup="on", warmup_iterations=10000, disable_gc=True) def test_sort_reverse_and_regexp_small(benchmark): benchmark(sort_parsing_candidate_reverse_and_regexp, ParsingBenchmarks.small_page) -@pytest.mark.benchmark( - group="parse_large", - warmup="on", - warmup_iterations=10000, - disable_gc=True -) + +@pytest.mark.benchmark(group="parse_large", warmup="on", warmup_iterations=10000, disable_gc=True) def test_sort_reverse_and_regexp_large(benchmark): benchmark(sort_parsing_candidate_reverse_and_regexp, ParsingBenchmarks.large_page) + def sort_parsing_candidate_reverse_and_regexp(response): reversed_response = response[::-1] sort_pattern = r"(\][^\]]*?\[):\"tros\"" x = re.search(sort_pattern, reversed_response) # return json.loads(x.group(1)[::-1]) # mean 3.6 ms - return ujson.loads(x.group(1)[::-1]) # mean 1.7 ms - -@pytest.mark.benchmark( - group="parse", - warmup="on", - warmup_iterations=10000, - disable_gc=True -) + return ujson.loads(x.group(1)[::-1]) # mean 1.7 ms + + +@pytest.mark.benchmark(group="parse", warmup="on", warmup_iterations=10000, disable_gc=True) def test_sort_rfind_and_regexp_small(benchmark): benchmark(sort_parsing_candidate_rfind_and_regexp, ParsingBenchmarks.small_page) -@pytest.mark.benchmark( - group="parse_large", - warmup="on", - warmup_iterations=10000, - disable_gc=True -) + +@pytest.mark.benchmark(group="parse_large", warmup="on", warmup_iterations=10000, disable_gc=True) def test_sort_rfind_and_regexp_large(benchmark): benchmark(sort_parsing_candidate_rfind_and_regexp, ParsingBenchmarks.large_page) + def sort_parsing_candidate_rfind_and_regexp(response): index_of_last_sort = response.rfind('"sort"') sort_pattern = r"sort\":([^\]]*])" @@ -85,146 +71,104 @@ def sort_parsing_candidate_rfind_and_regexp(response): # return json.loads(x.group(1)[::-1]) return ujson.loads(x.group(1)) -@pytest.mark.benchmark( - group="parse", - warmup="on", - warmup_iterations=10000, - disable_gc=True -) + +@pytest.mark.benchmark(group="parse", warmup="on", warmup_iterations=10000, disable_gc=True) def test_sort_end_anchor_regexp(benchmark): benchmark(sort_parsing_candidate_end_anchor_regexp, ParsingBenchmarks.small_page) -@pytest.mark.benchmark( - group="parse_large", - warmup="on", - warmup_iterations=10000, - disable_gc=True -) + +@pytest.mark.benchmark(group="parse_large", warmup="on", warmup_iterations=10000, disable_gc=True) def test_sort_end_anchor_regexp_large(benchmark): benchmark(sort_parsing_candidate_end_anchor_regexp, ParsingBenchmarks.large_page) + def sort_parsing_candidate_end_anchor_regexp(response): # predictably, no difference in using a literal lookahead vs just a surrounding pattern. room for improvement? sort_pattern = r"\"sort\":([^\]]*])\}\]\}\}$" x = re.search(sort_pattern, response) # return ast.literal_eval(x.group(1)) # mean 8.6 ms # return json.loads(x.group(1)) # mean 3.2 ms - return ujson.loads(x.group(1)) # mean 1.5 ms - -@pytest.mark.benchmark( - group="parse", - warmup="on", - warmup_iterations=10000, - disable_gc=True -) + return ujson.loads(x.group(1)) # mean 1.5 ms + + +@pytest.mark.benchmark(group="parse", warmup="on", warmup_iterations=10000, disable_gc=True) def test_sort_find_all_regexp_small(benchmark): benchmark(sort_parsing_candidate_find_all, ParsingBenchmarks.small_page) -@pytest.mark.benchmark( - group="parse_large", - warmup="on", - warmup_iterations=10000, - disable_gc=True -) + +@pytest.mark.benchmark(group="parse_large", warmup="on", warmup_iterations=10000, disable_gc=True) def test_sort_find_all_regexp_large(benchmark): benchmark(sort_parsing_candidate_find_all, ParsingBenchmarks.large_page) + def sort_parsing_candidate_find_all(response): sort_pattern = r"\"sort\":([^\]]+])" x = re.findall(sort_pattern, response) return ujson.loads(x[-1]) -@pytest.mark.benchmark( - group="parse", - warmup="on", - warmup_iterations=10000, - disable_gc=True -) + +@pytest.mark.benchmark(group="parse", warmup="on", warmup_iterations=10000, disable_gc=True) def test_pit_id_regexp_small(benchmark): benchmark(pit_id_parsing_candidate_regexp, ParsingBenchmarks.small_page) -@pytest.mark.benchmark( - group="parse_large", - warmup="on", - warmup_iterations=10000, - disable_gc=True -) + +@pytest.mark.benchmark(group="parse_large", warmup="on", warmup_iterations=10000, disable_gc=True) def test_pit_id_regexp_large(benchmark): benchmark(pit_id_parsing_candidate_regexp, ParsingBenchmarks.large_page) + def pit_id_parsing_candidate_regexp(response): - pit_id_pattern = r'"pit_id":"([^"]*)"' # 0.9 ms + pit_id_pattern = r'"pit_id":"([^"]*)"' # 0.9 ms x = re.search(pit_id_pattern, response) return x.group(1) -@pytest.mark.benchmark( - group="parse", - warmup="on", - warmup_iterations=10000, - disable_gc=True -) + +@pytest.mark.benchmark(group="parse", warmup="on", warmup_iterations=10000, disable_gc=True) def test_combined_json_small(benchmark): benchmark(combined_parsing_candidate_json_loads, ParsingBenchmarks.small_page) -@pytest.mark.benchmark( - group="parse_large", - warmup="on", - warmup_iterations=10000, - disable_gc=True -) + +@pytest.mark.benchmark(group="parse_large", warmup="on", warmup_iterations=10000, disable_gc=True) def test_combined_json_large(benchmark): benchmark(combined_parsing_candidate_json_loads, ParsingBenchmarks.large_page) + def combined_parsing_candidate_json_loads(response): parsed_response = json.loads(response) pit_id = parsed_response.get("pit_id") sort = parsed_response.get("hits").get("hits")[-1].get("sort") return pit_id, sort -@pytest.mark.benchmark( - group="parse_large", - warmup="on", - warmup_iterations=10000, - disable_gc=True -) + +@pytest.mark.benchmark(group="parse_large", warmup="on", warmup_iterations=10000, disable_gc=True) def test_combined_ijson_large(benchmark): benchmark(combined_parsing_candidate_json_loads, ParsingBenchmarks.large_page) -@pytest.mark.benchmark( - group="parse", - warmup="on", - warmup_iterations=10000, - disable_gc=True -) + +@pytest.mark.benchmark(group="parse", warmup="on", warmup_iterations=10000, disable_gc=True) def test_combined_ijson_small(benchmark): benchmark(combined_parsing_candidate_json_loads, ParsingBenchmarks.small_page) + def combined_parsing_candidate_ijson_loads(response): parsed_response = ujson.loads(response) pit_id = parsed_response.get("pit_id") sort = parsed_response.get("hits").get("hits")[-1].get("sort") return pit_id, sort -@pytest.mark.benchmark( - group="parse", - warmup="on", - warmup_iterations=10000, - disable_gc=True -) + +@pytest.mark.benchmark(group="parse", warmup="on", warmup_iterations=10000, disable_gc=True) def test_pit_id_parse_small(benchmark): page = ParsingBenchmarks.small_page.encode() benchmark(pit_id_parsing_candidate_runner_parse, page) -@pytest.mark.benchmark( - group="parse_large", - warmup="on", - warmup_iterations=10000, - disable_gc=True -) + +@pytest.mark.benchmark(group="parse_large", warmup="on", warmup_iterations=10000, disable_gc=True) def test_pit_id_parse_large(benchmark): page = ParsingBenchmarks.large_page.encode() benchmark(pit_id_parsing_candidate_runner_parse, page) + def pit_id_parsing_candidate_runner_parse(response): response_bytes = io.BytesIO(response) parsed = runner.parse(response_bytes, ["pit_id"]) @@ -233,7 +177,6 @@ def pit_id_parsing_candidate_runner_parse(response): class ParsingBenchmarks(TestCase): - def test_all_candidates(self): """ Quick utility test to ensure all benchmark cases are correct @@ -243,16 +186,16 @@ def test_all_candidates(self): self.assertEqual("fedcba9876543210", pit_id) sort = sort_parsing_candidate_reverse_and_regexp(self.small_page) - self.assertEqual([1609780186,"2"], sort) + self.assertEqual([1609780186, "2"], sort) sort = sort_parsing_candidate_rfind_and_regexp(self.large_page) self.assertEqual([1609780186, "2"], sort) sort = sort_parsing_candidate_end_anchor_regexp(self.small_page) - self.assertEqual([1609780186,"2"], sort) + self.assertEqual([1609780186, "2"], sort) sort = sort_parsing_candidate_find_all(self.large_page) - self.assertEqual([1609780186,"2"], sort) + self.assertEqual([1609780186, "2"], sort) pit_id = pit_id_parsing_candidate_regexp(self.large_page) self.assertEqual("fedcba9876543210", pit_id) @@ -282,21 +225,26 @@ def test_all_candidates(self): ] } } - """.replace("\n", "").replace(" ", "") # assume client never calls ?pretty :) + """.replace("\n", "").replace(" ", "") # assume client never calls ?pretty :) - large_page = (""" + large_page = ( + ( + """ { "pit_id": "fedcba9876543210", "took": 10, "timed_out": false, "hits": { "total": 2, - "hits": [""" + """ + "hits": [""" + + """ { "_id": "1", "timestamp": 1609780186, "sort": [1609780186, "1"] - },""" * 100 + """ + },""" + * 100 + + """ { "_id": "2", "timestamp": 1609780186, @@ -305,4 +253,8 @@ def test_all_candidates(self): ] } } - """).replace("\n", "").replace(" ", "") + """ + ) + .replace("\n", "") + .replace(" ", "") + ) diff --git a/benchmarks/worker_coordinator/runner_test.py b/benchmarks/worker_coordinator/runner_test.py index 61348a75..ea66e197 100644 --- a/benchmarks/worker_coordinator/runner_test.py +++ b/benchmarks/worker_coordinator/runner_test.py @@ -16,7 +16,7 @@ # not use this file except in compliance with the License. # You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an @@ -36,29 +36,23 @@ class SolrMock: def __init__(self, bulk_size): - self.no_errors = { - "took": 500, - "errors": False, - "items": [] - } + self.no_errors = {"took": 500, "errors": False, "items": []} for idx in range(0, bulk_size): - self.no_errors["items"].append({ - "index": { - "_index": "test", - "_type": "type1", - "_id": str(idx), - "_version": 1, - "result": "created", - "_shards": { - "total": 2, - "successful": 1, - "failed": 0 - }, - "created": True, - "status": 201, - "_seq_no": 0 + self.no_errors["items"].append( + { + "index": { + "_index": "test", + "_type": "type1", + "_id": str(idx), + "_version": 1, + "result": "created", + "_shards": {"total": 2, "successful": 1, "failed": 0}, + "created": True, + "status": 201, + "_seq_no": 0, + } } - }) + ) def bulk(self, body=None, index=None, doc_type=None, params=None): return self.no_errors @@ -67,30 +61,11 @@ def bulk(self, body=None, index=None, doc_type=None, params=None): solr_mock = SolrMock(bulk_size=BULK_SIZE) -@pytest.mark.benchmark( - group="bulk-runner", - warmup="on", - warmup_iterations=10000, - disable_gc=True -) +@pytest.mark.benchmark(group="bulk-runner", warmup="on", warmup_iterations=10000, disable_gc=True) def test_bulk_runner_without_errors_no_detailed_results(benchmark): - benchmark(bulk_index, solr_mock, { - "action-metadata-present": True, - "body": "bulk API body", - "bulk-size": BULK_SIZE - }) + benchmark(bulk_index, solr_mock, {"action-metadata-present": True, "body": "bulk API body", "bulk-size": BULK_SIZE}) -@pytest.mark.benchmark( - group="bulk-runner", - warmup="on", - warmup_iterations=1000, - disable_gc=True -) +@pytest.mark.benchmark(group="bulk-runner", warmup="on", warmup_iterations=1000, disable_gc=True) def test_bulk_runner_without_errors_with_detailed_results(benchmark): - benchmark(bulk_index, solr_mock, { - "action-metadata-present": True, - "body": "bulk API body", - "bulk-size": BULK_SIZE, - "detailed-results": True - }) + benchmark(bulk_index, solr_mock, {"action-metadata-present": True, "body": "bulk API body", "bulk-size": BULK_SIZE, "detailed-results": True}) diff --git a/benchmarks/workload/__init__.py b/benchmarks/workload/__init__.py index 5047a451..f5768141 100644 --- a/benchmarks/workload/__init__.py +++ b/benchmarks/workload/__init__.py @@ -16,7 +16,7 @@ # not use this file except in compliance with the License. # You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an diff --git a/benchmarks/workload/bulk_params_test.py b/benchmarks/workload/bulk_params_test.py index 00ea4630..8030990a 100644 --- a/benchmarks/workload/bulk_params_test.py +++ b/benchmarks/workload/bulk_params_test.py @@ -16,7 +16,7 @@ # not use this file except in compliance with the License. # You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an @@ -78,13 +78,9 @@ def create_reader(bulk_size): metadata = params.GenerateActionMetaData(index_name="test-idx", type_name=None) source = params.Slice(StaticSource, 0, sys.maxsize, None, None) - reader = params.MetadataIndexDataReader(data_file="bogus", - batch_size=bulk_size, - bulk_size=bulk_size, - file_source=source, - action_metadata=metadata, - index_name="test-idx", - type_name=None) + reader = params.MetadataIndexDataReader( + data_file="bogus", batch_size=bulk_size, bulk_size=bulk_size, file_source=source, action_metadata=metadata, index_name="test-idx", type_name=None + ) return reader diff --git a/it/__init__.py b/it/__init__.py index e438811b..4c2178bd 100644 --- a/it/__init__.py +++ b/it/__init__.py @@ -16,7 +16,7 @@ # not use this file except in compliance with the License. # You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an @@ -43,7 +43,7 @@ DISTRIBUTIONS = ["9.10.1", "10.1.0"] WORKLOADS = ["geonames", "nyc_taxis"] BASE_COMMANDS = ["solr-orbit"] -ROOT_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), '..')) +ROOT_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) def all_benchmark_configs(t): @@ -92,7 +92,7 @@ def solrorbit(cfg, command_line): These commands may have different CLI options than test_run. """ cmd = solrorbit_command_line_for(cfg, command_line) - print(f'\n{datetime.datetime.now().strftime("%Y-%m-%dT%H:%M:%S")} Invoking solr-orbit: {cmd}') + print(f"\n{datetime.datetime.now().strftime('%Y-%m-%dT%H:%M:%S')} Invoking solr-orbit: {cmd}") err, retcode = process.run_subprocess_with_stderr(cmd) if retcode != 0: print(err) @@ -177,12 +177,10 @@ def install(self, distribution_version, node_name, cluster_config, http_port): "solr-orbit install --configuration-name={cfg} --distribution-version={dist} --build-type=tar " "--http-port={http_port} --node={node_name} --master-nodes=" "{node_name} --cluster-config={cluster_config} " - "--seed-hosts=\"127.0.0.1:{transport_port}\"".format(cfg=self.cfg, - dist=distribution_version, - http_port=http_port, - node_name=node_name, - cluster_config=cluster_config, - transport_port=transport_port)) + '--seed-hosts="127.0.0.1:{transport_port}"'.format( + cfg=self.cfg, dist=distribution_version, http_port=http_port, node_name=node_name, cluster_config=cluster_config, transport_port=transport_port + ) + ) if retcode != 0: raise AssertionError("Failed to install node {}.".format(distribution_version), err) self.installation_id = json.loads(err)["installation-id"] @@ -190,7 +188,7 @@ def install(self, distribution_version, node_name, cluster_config, http_port): raise AssertionError("Failed to install node {}.".format(distribution_version), e) def start(self, test_run_id): - cmd = "start --runtime-jdk=\"bundled\" --installation-id={} --test-run-id={}".format(self.installation_id, test_run_id) + cmd = 'start --runtime-jdk="bundled" --installation-id={} --test-run-id={}'.format(self.installation_id, test_run_id) if solrorbit(self.cfg, cmd) != 0: raise AssertionError("Failed to start test cluster.") solr_client = client.ClientFactory(hosts=[{"host": "127.0.0.1", "port": self.http_port}], client_options={}).create() @@ -212,10 +210,7 @@ def __init__(self): self.cluster = TestCluster("in-memory-it") def start(self): - self.cluster.install(distribution_version=MetricsStore.VERSION, - node_name="metrics-store", - cluster_config="defaults", - http_port=10200) + self.cluster.install(distribution_version=MetricsStore.VERSION, node_name="metrics-store", cluster_config="defaults", http_port=10200) self.cluster.start(test_run_id="metrics-store") def stop(self): @@ -241,7 +236,7 @@ def remove_integration_test_config(): def get_license(): - with open(os.path.join(ROOT_DIR, 'LICENSE')) as license_file: + with open(os.path.join(ROOT_DIR, "LICENSE")) as license_file: return license_file.readlines()[1].strip() @@ -249,12 +244,14 @@ def build_docker_image(): benchmark_version = version.__version__ env_variables = os.environ.copy() - env_variables['BENCHMARK_VERSION'] = benchmark_version - env_variables['BENCHMARK_LICENSE'] = get_license() - - command = f"docker build -t apache/solr-orbit:{benchmark_version}" \ - f" --build-arg BENCHMARK_VERSION --build-arg BENCHMARK_LICENSE " \ - f"-f {ROOT_DIR}/docker/Dockerfiles/Dockerfile-dev {ROOT_DIR}" + env_variables["BENCHMARK_VERSION"] = benchmark_version + env_variables["BENCHMARK_LICENSE"] = get_license() + + command = ( + f"docker build -t apache/solr-orbit:{benchmark_version}" + f" --build-arg BENCHMARK_VERSION --build-arg BENCHMARK_LICENSE " + f"-f {ROOT_DIR}/docker/Dockerfiles/Dockerfile-dev {ROOT_DIR}" + ) if process.run_subprocess_with_logging(command, env=env_variables) != 0: raise AssertionError("It was not possible to build the docker image from Dockerfile-dev") diff --git a/it/distribution_test.py b/it/distribution_test.py index 9030950d..f620dd13 100644 --- a/it/distribution_test.py +++ b/it/distribution_test.py @@ -16,7 +16,7 @@ # not use this file except in compliance with the License. # You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an @@ -34,8 +34,7 @@ def test_tar_distributions(cfg): for dist in it.DISTRIBUTIONS: for workload in it.WORKLOADS: it.wait_until_port_is_free(port_number=port) - assert it.run_test(cfg, f"--distribution-version=\"{dist}\" --workload=\"{workload}\" " - f"--test-mode --cluster-config=4gheap --target-hosts=127.0.0.1:{port}") == 0 + assert it.run_test(cfg, f'--distribution-version="{dist}" --workload="{workload}" --test-mode --cluster-config=4gheap --target-hosts=127.0.0.1:{port}') == 0 @it.random_benchmark_config @@ -44,6 +43,12 @@ def test_docker_distribution(cfg): # only test the most recent Docker distribution dist = it.DISTRIBUTIONS[-1] it.wait_until_port_is_free(port_number=port) - assert it.run_test(cfg, f"--pipeline=\"docker\" --distribution-version=\"{dist}\" " - f"--workload=\"geonames\" --test-procedure=\"append-no-conflicts-index-only\" --test-mode " - f"--cluster-config=4gheap --target-hosts=127.0.0.1:{port}") == 0 + assert ( + it.run_test( + cfg, + f'--pipeline="docker" --distribution-version="{dist}" ' + f'--workload="geonames" --test-procedure="append-no-conflicts-index-only" --test-mode ' + f"--cluster-config=4gheap --target-hosts=127.0.0.1:{port}", + ) + == 0 + ) diff --git a/it/download_test.py b/it/download_test.py index 53d5ff64..58d5c00e 100644 --- a/it/download_test.py +++ b/it/download_test.py @@ -16,7 +16,7 @@ # not use this file except in compliance with the License. # You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an @@ -31,9 +31,9 @@ @it.random_benchmark_config def test_download_distribution(cfg): for d in it.DISTRIBUTIONS: - assert it.solrorbit(cfg, f"download --distribution-version=\"{d}\" --quiet") == 0 + assert it.solrorbit(cfg, f'download --distribution-version="{d}" --quiet') == 0 @it.random_benchmark_config def test_does_not_download_unsupported_distribution(cfg): - assert it.solrorbit(cfg, "download --distribution-version=\"1.7.6\" --quiet") != 0 + assert it.solrorbit(cfg, 'download --distribution-version="1.7.6" --quiet') != 0 diff --git a/it/info_test.py b/it/info_test.py index 624922ba..d783ea22 100644 --- a/it/info_test.py +++ b/it/info_test.py @@ -16,7 +16,7 @@ # not use this file except in compliance with the License. # You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an @@ -41,7 +41,7 @@ def test_workload_info_with_workload_repo(cfg): @it.benchmark_in_mem def test_workload_info_with_task_filter(cfg): - assert it.solrorbit(cfg, "info --workload=geonames --test-procedure=append-no-conflicts --include-tasks=\"type:search\"") == 0 + assert it.solrorbit(cfg, 'info --workload=geonames --test-procedure=append-no-conflicts --include-tasks="type:search"') == 0 @it.benchmark_in_mem @@ -49,8 +49,10 @@ def test_workload_info_fails_with_wrong_workload_params(cfg): # simulate a typo in workload parameter cmd = it.solrorbit_command_line_for(cfg, "info --workload=geonames --workload-params='conflict_probability:5,number-of-replicas:1'") output = process.run_subprocess_with_output(cmd) - expected = "Some of your workload parameter(s) \"number-of-replicas\" are not used by this workload; " \ - "perhaps you intend to use \"number_of_replicas\" instead.\n\nAll workload parameters you " \ - "provided are:\n- conflict_probability\n- number-of-replicas\n\nAll parameters exposed by this workload" + expected = ( + 'Some of your workload parameter(s) "number-of-replicas" are not used by this workload; ' + 'perhaps you intend to use "number_of_replicas" instead.\n\nAll workload parameters you ' + "provided are:\n- conflict_probability\n- number-of-replicas\n\nAll parameters exposed by this workload" + ) assert expected in "\n".join(output) diff --git a/it/list_test.py b/it/list_test.py index 4455f239..f0554ada 100644 --- a/it/list_test.py +++ b/it/list_test.py @@ -16,7 +16,7 @@ # not use this file except in compliance with the License. # You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an @@ -39,12 +39,10 @@ def test_list_cluster_configs(cfg): assert it.solrorbit(cfg, "list cluster-configs --cluster-config-repository=default") == 0 - @it.benchmark_in_mem def test_list_workloads(cfg): assert it.solrorbit(cfg, "list workloads") == 0 - assert it.solrorbit(cfg, "list workloads --workload-repository=default " - "--workload-revision=cba4e45dda37ac03abbd3c9dd4532475dac355e9") == 0 + assert it.solrorbit(cfg, "list workloads --workload-repository=default --workload-revision=cba4e45dda37ac03abbd3c9dd4532475dac355e9") == 0 @it.benchmark_in_mem diff --git a/it/proxy_test.py b/it/proxy_test.py index 1ca4e83c..529ab40e 100644 --- a/it/proxy_test.py +++ b/it/proxy_test.py @@ -16,7 +16,7 @@ # not use this file except in compliance with the License. # You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an @@ -42,13 +42,11 @@ def http_proxy(): config_dir = os.path.join(os.path.dirname(__file__), "resources", "squid") - lines = process.run_subprocess_with_output(f"docker run --rm --name squid -d " - f"-v {config_dir}/squidpasswords:/etc/squid/squidpasswords " - f"-v {config_dir}/squid.conf:/etc/squid/squid.conf " - f"-p 3128:3128 ubuntu/squid") + lines = process.run_subprocess_with_output( + f"docker run --rm --name squid -d -v {config_dir}/squidpasswords:/etc/squid/squidpasswords -v {config_dir}/squid.conf:/etc/squid/squid.conf -p 3128:3128 ubuntu/squid" + ) proxy_container_id = lines[0].strip() - proxy = HttpProxy(authenticated_url="http://testuser:testuser@127.0.0.1:3128", - anonymous_url="http://127.0.0.1:3128") + proxy = HttpProxy(authenticated_url="http://testuser:testuser@127.0.0.1:3128", anonymous_url="http://127.0.0.1:3128") yield proxy process.run_subprocess(f"docker stop {proxy_container_id}") @@ -98,7 +96,6 @@ def test_authenticated_proxy_user_can_connect(cfg, http_proxy, fresh_log_file): env = dict(os.environ) env["http_proxy"] = http_proxy.authenticated_url assert process.run_subprocess_with_logging(it.solrorbit_command_line_for(cfg, "list workloads"), env=env) == 0 - assert_log_line_present(fresh_log_file, - f"Connecting via proxy URL [{http_proxy.authenticated_url}] to the Internet") + assert_log_line_present(fresh_log_file, f"Connecting via proxy URL [{http_proxy.authenticated_url}] to the Internet") # authenticated proxy access is allowed assert_log_line_present(fresh_log_file, "Detected a working Internet connection") diff --git a/it/sources_test.py b/it/sources_test.py index dcc28012..d793a06e 100644 --- a/it/sources_test.py +++ b/it/sources_test.py @@ -16,7 +16,7 @@ # not use this file except in compliance with the License. # You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an @@ -27,15 +27,27 @@ import it + @it.random_benchmark_config def test_sources(cfg): port = 19200 it.wait_until_port_is_free(port_number=port) - assert it.run_test(cfg, f"--pipeline=from-sources --revision=latest \ + assert ( + it.run_test( + cfg, + f"--pipeline=from-sources --revision=latest \ --workload=geonames --test-mode --target-hosts=127.0.0.1:{port} " - f"--test-procedure=append-no-conflicts --cluster-config=4gheap " - f"--opensearch-plugins=analysis-icu") == 0 + f"--test-procedure=append-no-conflicts --cluster-config=4gheap " + f"--opensearch-plugins=analysis-icu", + ) + == 0 + ) it.wait_until_port_is_free(port_number=port) - assert it.run_test(cfg, f"--pipeline=from-sources --workload=geonames --test-mode --target-hosts=127.0.0.1:{port} " - f"--test-procedure=append-no-conflicts-index-only --cluster-config=\"4gheap,ea\"") == 0 + assert ( + it.run_test( + cfg, + f'--pipeline=from-sources --workload=geonames --test-mode --target-hosts=127.0.0.1:{port} --test-procedure=append-no-conflicts-index-only --cluster-config="4gheap,ea"', + ) + == 0 + ) diff --git a/it/tracker_test.py b/it/tracker_test.py index 994652e2..11e3aed9 100644 --- a/it/tracker_test.py +++ b/it/tracker_test.py @@ -16,7 +16,7 @@ # not use this file except in compliance with the License. # You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an @@ -50,23 +50,16 @@ def test_cluster(): @it.benchmark_in_mem def test_create_workload(cfg, tmp_path, test_cluster): # prepare some data - cmd = f"--test-mode --pipeline=benchmark-only --target-hosts=127.0.0.1:{test_cluster.http_port} " \ - f" --workload=geonames --test-procedure=append-no-conflicts-index-only --quiet" + cmd = f"--test-mode --pipeline=benchmark-only --target-hosts=127.0.0.1:{test_cluster.http_port} --workload=geonames --test-procedure=append-no-conflicts-index-only --quiet" assert it.run_test(cfg, cmd) == 0 # create the workload workload_name = f"test-workload-{uuid.uuid4()}" workload_path = tmp_path / workload_name - assert it.solrorbit(cfg, f"create-workload --target-hosts=127.0.0.1:{test_cluster.http_port} --indices=geonames " - f"--workload={workload_name} --output-path={tmp_path}") == 0 + assert it.solrorbit(cfg, f"create-workload --target-hosts=127.0.0.1:{test_cluster.http_port} --indices=geonames --workload={workload_name} --output-path={tmp_path}") == 0 - expected_files = ["workload.json", - "geonames.json", - "geonames-documents-1k.json", - "geonames-documents.json", - "geonames-documents-1k.json.bz2", - "geonames-documents.json.bz2"] + expected_files = ["workload.json", "geonames.json", "geonames-documents-1k.json", "geonames-documents.json", "geonames-documents-1k.json.bz2", "geonames-documents.json.bz2"] for f in expected_files: full_path = workload_path / f diff --git a/scripts/analyze.py b/scripts/analyze.py index c0cd70cc..a0c2ef8d 100644 --- a/scripts/analyze.py +++ b/scripts/analyze.py @@ -13,7 +13,7 @@ # not use this file except in compliance with the License. # You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an @@ -61,7 +61,7 @@ def create_plot(): def present(a_plot, name): - a_plot.savefig("%s.png" % name, bbox_inches='tight') + a_plot.savefig("%s.png" % name, bbox_inches="tight") # plt.show() # alternatively only show it # explicitly close to free resources a_plot.close() @@ -96,11 +96,13 @@ def plot_service_time(raw_data, label_key): service_time_metrics = op_metrics["service_time"] if operation not in service_time_per_op: service_time_per_op[operation] = [] - service_time_per_op[operation].append({ - "data_series": data_series, - "percentiles": [decode_percentile_key(p) for p in service_time_metrics.keys()], - "percentile_values": list(service_time_metrics.values()), - }) + service_time_per_op[operation].append( + { + "data_series": data_series, + "percentiles": [decode_percentile_key(p) for p in service_time_metrics.keys()], + "percentile_values": list(service_time_metrics.values()), + } + ) for op, results in service_time_per_op.items(): _, ax = create_plot() @@ -109,7 +111,7 @@ def plot_service_time(raw_data, label_key): for candidate in results: label = candidate["data_series"] - series = ax.plot(candidate["percentiles"], candidate["percentile_values"], marker='.', label=label) + series = ax.plot(candidate["percentiles"], candidate["percentile_values"], marker=".", label=label) legend_handles.append(series[0]) legend_labels.append(label) @@ -120,7 +122,7 @@ def plot_service_time(raw_data, label_key): box = ax.get_position() ax.set_position([box.x0, box.y0, box.width * 0.8, box.height]) - ax.legend(legend_handles, legend_labels, loc='center left', bbox_to_anchor=(1, 0.5)) + ax.legend(legend_handles, legend_labels, loc="center left", bbox_to_anchor=(1, 0.5)) present(plt, "service_time_%s" % op) @@ -136,13 +138,15 @@ def plot_throughput(raw_data, label_key): throughput_metrics = op_metrics["throughput"] if operation not in throughput_per_op: throughput_per_op[operation] = [] - throughput_per_op[operation].append({ - "data_series": data_series, - "max": throughput_metrics["max"], - "median": throughput_metrics["median"], - "min": throughput_metrics["min"], - "unit": throughput_metrics["unit"] - }) + throughput_per_op[operation].append( + { + "data_series": data_series, + "max": throughput_metrics["max"], + "median": throughput_metrics["median"], + "min": throughput_metrics["min"], + "unit": throughput_metrics["unit"], + } + ) for op, results in throughput_per_op.items(): _, ax = create_plot() @@ -211,7 +215,7 @@ def plot_gc_times(raw_data, label_key): box = ax.get_position() ax.set_position([box.x0, box.y0, box.width * 0.8, box.height]) - ax.legend([old_bar[0], young_bar[0]], ["Old GC", "Young GC"], loc='center left', bbox_to_anchor=(1, 0.5)) + ax.legend([old_bar[0], young_bar[0]], ["Old GC", "Young GC"], loc="center left", bbox_to_anchor=(1, 0.5)) ax.set_ylim(ymin=0) present(plt, "gc_times") @@ -230,11 +234,10 @@ def parse_args(): "--label", help="defines which attribute to use for labelling data series (default: test-run-timestamp).", # choices=["environment", "test-run-timestamp", "user-tags", "test_procedure", "cluster-config-instance"], - default="test-run-timestamp") + default="test-run-timestamp", + ) - parser.add_argument("path", - nargs="+", - help="Full path to one or more test_run.json files") + parser.add_argument("path", nargs="+", help="Full path to one or more test_run.json files") return parser.parse_args() @@ -250,5 +253,5 @@ def main(): plot(series, args.label) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/scripts/expand-data-corpus.py b/scripts/expand-data-corpus.py index 5988b60d..7cee0389 100755 --- a/scripts/expand-data-corpus.py +++ b/scripts/expand-data-corpus.py @@ -99,21 +99,18 @@ """ + def handler(signum, frame): sys.exit(1) def error_exit(script_name, message): - print(f'{script_name}: {message}', file=sys.stderr) + print(f"{script_name}: {message}", file=sys.stderr) sys.exit(1) class DocGenerator: - - def __init__(self, - input_file: str, - start_timestamp: int, - interval:int) -> None: + def __init__(self, input_file: str, start_timestamp: int, interval: int) -> None: self.input_file = input_file self.timestamp = start_timestamp self.interval = interval @@ -140,7 +137,7 @@ def get_next_doc(self): class ArgParser(argparse.ArgumentParser): - def usage_msg(self, message: str=None) -> None: + def usage_msg(self, message: str = None) -> None: if message: print(message, file=sys.stderr) print(file=sys.stderr) @@ -148,49 +145,41 @@ def usage_msg(self, message: str=None) -> None: sys.exit(1) def error(self, message): - print('error: %s' % message, file=sys.stderr) + print("error: %s" % message, file=sys.stderr) self.usage_msg() -def generate_docs(script_name: str, - workload: str, - repository: str, - input_file: str, - output_file_suffix: str, - n_docs: int, - corpus_size: int, - interval: int, - start_timestamp: int, - batch_size: int): +def generate_docs( + script_name: str, workload: str, repository: str, input_file: str, output_file_suffix: str, n_docs: int, corpus_size: int, interval: int, start_timestamp: int, batch_size: int +): # # Set up for generation. # config = configparser.ConfigParser() - benchmark_home = os.environ.get('BENCHMARK_HOME') or os.environ['HOME'] - benchmark_ini = benchmark_home + '/.benchmark/benchmark.ini' + benchmark_home = os.environ.get("BENCHMARK_HOME") or os.environ["HOME"] + benchmark_ini = benchmark_home + "/.benchmark/benchmark.ini" if not os.path.isfile(benchmark_ini): error_exit(script_name, f"could not find benchmark config file {benchmark_ini}, run a workload first to create it") config.read(benchmark_ini) - root_dir = config['node']['root.dir'] - workload_dir= root_dir + '/workloads/' + repository + '/' + workload - data_dir = config['benchmarks']['local.dataset.cache'] + '/' + workload + root_dir = config["node"]["root.dir"] + workload_dir = root_dir + "/workloads/" + repository + "/" + workload + data_dir = config["benchmarks"]["local.dataset.cache"] + "/" + workload if not os.path.exists(data_dir): error_exit(script_name, f"workload data directory {data_dir} does not exist, run the appropriate workload first to create it") - output_file = data_dir + '/documents-' + output_file_suffix + '.json' - if '/' not in input_file: - input_file = data_dir + '/' + input_file + output_file = data_dir + "/documents-" + output_file_suffix + ".json" + if "/" not in input_file: + input_file = data_dir + "/" + input_file - out = open(output_file, 'w') - offsets = open(output_file + '.offset', 'w') + out = open(output_file, "w") + offsets = open(output_file + ".offset", "w") # # Obtain the generator to synthesize the documents. # - g = DocGenerator(input_file, start_timestamp, interval).\ - get_next_doc() + g = DocGenerator(input_file, start_timestamp, interval).get_next_doc() # # Generate the desired number of documents. @@ -207,7 +196,7 @@ def generate_docs(script_name: str, # Offset file entry. if line_num > 0 and line_num % batch_size == 0: - s = str(line_num) + ';' + str(offset) + '\n' + s = str(line_num) + ";" + str(offset) + "\n" offsets.write(s) line = next(g) @@ -222,21 +211,21 @@ def generate_docs(script_name: str, # Create the metadata files. # corpus_spec = dict() - corpus_spec['target-index'] = 'logs-' + output_file_suffix - corpus_spec['source-file'] = output_file - corpus_spec['document-count'] = line_num - corpus_spec['uncompressed-bytes'] = offset + corpus_spec["target-index"] = "logs-" + output_file_suffix + corpus_spec["source-file"] = output_file + corpus_spec["document-count"] = line_num + corpus_spec["uncompressed-bytes"] = offset - out = open(workload_dir + '/gen-docs-' + output_file_suffix + '.json', 'w') - out.write(json.dumps(corpus_spec) + '\n') + out = open(workload_dir + "/gen-docs-" + output_file_suffix + ".json", "w") + out.write(json.dumps(corpus_spec) + "\n") out.close() idx_spec = dict() - idx_spec['name'] = 'logs-' + output_file_suffix - idx_spec['body'] = 'index.json' + idx_spec["name"] = "logs-" + output_file_suffix + idx_spec["body"] = "index.json" - out = open(workload_dir + '/gen-idx-' + output_file_suffix + '.json', 'w') - out.write(json.dumps(idx_spec) + '\n') + out = open(workload_dir + "/gen-idx-" + output_file_suffix + ".json", "w") + out.write(json.dumps(idx_spec) + "\n") out.close() @@ -245,34 +234,16 @@ def main(args: list) -> None: signal.signal(signal.SIGINT, handler) script_name = os.path.basename(__file__) - parser = ArgParser(description=help_msg, - formatter_class=argparse.RawTextHelpFormatter) - parser.add_argument('-w', '--workload', - default='http_logs', - help="workload name, default: %(default)s") - parser.add_argument('-r', '--workload-repository', default='default', - help="workload name, default: %(default)s") - parser.add_argument('-c', '--corpus-size', type=int, - help="size of corpus to generate in GB") - parser.add_argument('-o', '--output-file-suffix', - default='generated', - help="suffix for output file name, " - "documents-SUFFIX.json, default: %(default)s") - parser.add_argument('-f', '--input-file', - default='documents-241998.json', - help="[EXPERT] input file name, default: %(default)s") - parser.add_argument('-n', '--number-of-docs', type=int, - help="[EXPERT] number of documents to generate") - parser.add_argument('-i', '--interval', type=int, - help="[EXPERT] interval between consecutive " - "timestamps, use a negative number to specify multiple " - "docs per timestamp") - parser.add_argument('-t', '--start-timestamp', type=int, - default=893964618, - help="[EXPERT] start timestamp, default: %(default)d") - parser.add_argument('-b', '--batch-size', default=50000, - help="[EXPERT] batch size per benchmark client thread, " - "default: %(default)d") + parser = ArgParser(description=help_msg, formatter_class=argparse.RawTextHelpFormatter) + parser.add_argument("-w", "--workload", default="http_logs", help="workload name, default: %(default)s") + parser.add_argument("-r", "--workload-repository", default="default", help="workload name, default: %(default)s") + parser.add_argument("-c", "--corpus-size", type=int, help="size of corpus to generate in GB") + parser.add_argument("-o", "--output-file-suffix", default="generated", help="suffix for output file name, documents-SUFFIX.json, default: %(default)s") + parser.add_argument("-f", "--input-file", default="documents-241998.json", help="[EXPERT] input file name, default: %(default)s") + parser.add_argument("-n", "--number-of-docs", type=int, help="[EXPERT] number of documents to generate") + parser.add_argument("-i", "--interval", type=int, help="[EXPERT] interval between consecutive timestamps, use a negative number to specify multiple docs per timestamp") + parser.add_argument("-t", "--start-timestamp", type=int, default=893964618, help="[EXPERT] start timestamp, default: %(default)d") + parser.add_argument("-b", "--batch-size", default=50000, help="[EXPERT] batch size per benchmark client thread, default: %(default)d") args = parser.parse_args() @@ -286,29 +257,15 @@ def main(args: list) -> None: batch_size = args.batch_size if n_docs and corpus_size: - parser.usage_msg(script_name + - ": can specify either number of documents" - "or corpus size, but not both") + parser.usage_msg(script_name + ": can specify either number of documentsor corpus size, but not both") elif not n_docs and not corpus_size: - parser.usage_msg(script_name + - ": must specify number of documents or corpus size") - interval = args.interval if args.interval is not None else \ - corpus_size * -2 - if workload != 'http_logs': - parser.usage_msg(script_name + - ': only the "http_logs" workload is currently supported') - - generate_docs(script_name, - workload, - repository, - input_file, - output_file_suffix, - n_docs, - corpus_size, - interval, - start_timestamp, - batch_size) - - -if __name__ == '__main__': + parser.usage_msg(script_name + ": must specify number of documents or corpus size") + interval = args.interval if args.interval is not None else corpus_size * -2 + if workload != "http_logs": + parser.usage_msg(script_name + ': only the "http_logs" workload is currently supported') + + generate_docs(script_name, workload, repository, input_file, output_file_suffix, n_docs, corpus_size, interval, start_timestamp, batch_size) + + +if __name__ == "__main__": sys.exit(main(sys.argv[1:])) diff --git a/solrorbit/__init__.py b/solrorbit/__init__.py index 2faaa094..e3f251a9 100644 --- a/solrorbit/__init__.py +++ b/solrorbit/__init__.py @@ -16,7 +16,7 @@ # not use this file except in compliance with the License. # You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an diff --git a/solrorbit/actor.py b/solrorbit/actor.py index 76e27815..0768efe3 100644 --- a/solrorbit/actor.py +++ b/solrorbit/actor.py @@ -16,7 +16,7 @@ # not use this file except in compliance with the License. # You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an @@ -39,6 +39,7 @@ class BenchmarkFailure: """ Indicates a failure in the benchmark execution due to an exception """ + def __init__(self, message, cause=None): self.message = message self.cause = cause @@ -57,10 +58,13 @@ def parametrized(decorator): :param decorator: The decorator that should accept parameters. """ + def inner(*args, **kwargs): def g(f): return decorator(f, *args, **kwargs) + return g + return inner @@ -90,6 +94,7 @@ def receiveMsg_DefuseBomb(self, msg, sender): :param f: The message handler. Does not need to passed directly, this is handled by the decorator infrastructure. :param actor_name: A human readable name of the current actor that should be used in the exception message. """ + def guard(self, msg, sender): try: return f(self, msg, sender) @@ -100,6 +105,7 @@ def guard(self, msg, sender): # don't forward the exception as is because the main process might not have this class available on the load path # and will fail then while deserializing the cause. self.send(sender, BenchmarkFailure("{} ({})".format(msg, str(e)))) + return guard @@ -143,22 +149,20 @@ def transition_when_all_children_responded(self, sender, msg, expected_status, n response_count = len(self.received_responses) expected_count = len(self.children) - self.logger.debug("[%d] of [%d] child actors have responded for transition from [%s] to [%s].", - response_count, expected_count, self.status, new_status) + self.logger.debug("[%d] of [%d] child actors have responded for transition from [%s] to [%s].", response_count, expected_count, self.status, new_status) if response_count == expected_count: - self.logger.debug("All [%d] child actors have responded. Transitioning now from [%s] to [%s].", - expected_count, self.status, new_status) + self.logger.debug("All [%d] child actors have responded. Transitioning now from [%s] to [%s].", expected_count, self.status, new_status) # all nodes have responded, change status self.status = new_status self.received_responses = [] transition() elif response_count > expected_count: raise exceptions.BenchmarkAssertionError( - "Received [%d] responses but only [%d] were expected to transition from [%s] to [%s]. The responses are: %s" % - (response_count, expected_count, self.status, new_status, self.received_responses)) + "Received [%d] responses but only [%d] were expected to transition from [%s] to [%s]. The responses are: %s" + % (response_count, expected_count, self.status, new_status, self.received_responses) + ) else: - raise exceptions.BenchmarkAssertionError("Received [%s] from [%s] but we are in status [%s] instead of [%s]." % - (type(msg), sender, self.status, expected_status)) + raise exceptions.BenchmarkAssertionError("Received [%s] from [%s] but we are in status [%s] instead of [%s]." % (type(msg), sender, self.status, expected_status)) def send_to_children_and_transition(self, sender, msg, expected_status, new_status): """ @@ -176,8 +180,7 @@ def send_to_children_and_transition(self, sender, msg, expected_status, new_stat for m in filter(None, self.children): self.send(m, msg) else: - raise exceptions.BenchmarkAssertionError("Received [%s] from [%s] but we are in status [%s] instead of [%s]." % - (type(msg), sender, self.status, expected_status)) + raise exceptions.BenchmarkAssertionError("Received [%s] from [%s] but we are in status [%s] instead of [%s]." % (type(msg), sender, self.status, expected_status)) def is_current_status_expected(self, expected_status): # if we don't expect anything, we're always in the right status @@ -254,9 +257,7 @@ def bootstrap_actor_system(try_join=False, prefer_local_only=False, local_ip=Non # Make the coordinator node the convention leader capabilities["Convention Address.IPv4"] = "%s:1900" % coordinator_ip logger.info("Starting actor system with system base [%s] and capabilities [%s].", system_base, capabilities) - return thespian.actors.ActorSystem(system_base, - logDefs=log.load_configuration(), - capabilities=capabilities) + return thespian.actors.ActorSystem(system_base, logDefs=log.load_configuration(), capabilities=capabilities) except thespian.actors.ActorSystemException: logger.exception("Could not initialize internal actor system.") raise diff --git a/solrorbit/aggregator.py b/solrorbit/aggregator.py index 7f099314..be3f4a52 100644 --- a/solrorbit/aggregator.py +++ b/solrorbit/aggregator.py @@ -7,6 +7,7 @@ from solrorbit import metrics, workload, config from solrorbit.utils import io as rio + class Aggregator: def __init__(self, cfg, test_runs_dict, args) -> None: self.config = cfg @@ -80,7 +81,7 @@ def aggregate_json_elements(json_elements: List[Any]) -> Any: return next((obj for obj in json_elements if obj is not None), None) if isinstance(key_path, str): - key_path = key_path.split('.') + key_path = key_path.split(".") nested_values = [get_nested_value(json_result, key_path) for json_result in all_json_results] return aggregate_json_elements(nested_values) @@ -122,7 +123,7 @@ def build_aggregated_results_dict(self) -> Dict[str, Any]: "total_transform_search_times": self.aggregate_json_by_key("total_transform_search_times"), "total_transform_index_times": self.aggregate_json_by_key("total_transform_index_times"), "total_transform_processing_times": self.aggregate_json_by_key("total_transform_processing_times"), - "total_transform_throughput": self.aggregate_json_by_key("total_transform_throughput") + "total_transform_throughput": self.aggregate_json_by_key("total_transform_throughput"), } for task, task_metrics in self.accumulated_results.items(): @@ -138,9 +139,9 @@ def build_aggregated_results_dict(self) -> Dict[str, Any]: if isinstance(aggregated_task_metrics[metric], dict): # Calculate RSD for the mean values across all test runs # We use mean here as it's more sensitive to outliers, which is desirable for assessing variability - mean_values = [v['mean'] for v in task_metrics[metric]] + mean_values = [v["mean"] for v in task_metrics[metric]] rsd = self.calculate_rsd(mean_values, f"{task}.{metric}.mean") - op_metric[metric]['mean_rsd'] = rsd + op_metric[metric]["mean_rsd"] = rsd # Handle derived metrics (like error_rate, duration) which are stored as simple values else: @@ -158,15 +159,12 @@ def update_config_object(self, test_run: TestRun) -> None: Uses the first test run as reference since configurations should be identical """ current_timestamp = self.config.opts("system", "time.start") - self.config.add(config.Scope.applicationOverride, "builder", - "cluster_config.names", test_run.cluster_config) - self.config.add(config.Scope.applicationOverride, "system", - "env.name", test_run.environment_name) + self.config.add(config.Scope.applicationOverride, "builder", "cluster_config.names", test_run.cluster_config) + self.config.add(config.Scope.applicationOverride, "system", "env.name", test_run.environment_name) self.config.add(config.Scope.applicationOverride, "system", "time.start", current_timestamp) self.config.add(config.Scope.applicationOverride, "test_run", "pipeline", test_run.pipeline) self.config.add(config.Scope.applicationOverride, "workload", "params", test_run.workload_params) - self.config.add(config.Scope.applicationOverride, "builder", - "cluster_config.params", test_run.cluster_config_instance_params) + self.config.add(config.Scope.applicationOverride, "builder", "cluster_config.params", test_run.cluster_config_instance_params) self.config.add(config.Scope.applicationOverride, "builder", "plugin.params", test_run.plugin_params) self.config.add(config.Scope.applicationOverride, "workload", "latency.percentiles", test_run.latency_percentiles) self.config.add(config.Scope.applicationOverride, "workload", "throughput.percentiles", test_run.throughput_percentiles) @@ -175,13 +173,13 @@ def build_aggregated_results(self) -> TestRun: test_run = self.test_store.find_by_test_run_id(list(self.test_runs.keys())[0]) aggregated_results = self.build_aggregated_results_dict() - if hasattr(self.args, 'results_file') and self.args.results_file != "": + if hasattr(self.args, "results_file") and self.args.results_file != "": normalized_results_file = rio.normalize_path(self.args.results_file, self.cwd) # ensure that the parent folder already exists when we try to write the file... rio.ensure_dir(rio.dirname(normalized_results_file)) test_run_id = os.path.basename(normalized_results_file) self.config.add(config.Scope.applicationOverride, "system", "test_run.id", normalized_results_file) - elif hasattr(self.args, 'test_run_id') and self.args.test_run_id: + elif hasattr(self.args, "test_run_id") and self.args.test_run_id: test_run_id = f"aggregate_results_{test_run.workload}_{self.args.test_run_id}" self.config.add(config.Scope.applicationOverride, "system", "test_run.id", test_run_id) else: @@ -196,9 +194,7 @@ def build_aggregated_results(self) -> TestRun: test_procedure_object = loaded_workload.find_test_procedure_or_default(self.test_procedure_name) test_run = metrics.create_test_run(self.config, loaded_workload, test_procedure_object, test_run.workload_revision) - test_run.user_tags = { - "aggregation-of-runs": list(self.test_runs.keys()) - } + test_run.user_tags = {"aggregation-of-runs": list(self.test_runs.keys())} test_run.add_results(AggregatedResults(aggregated_results)) test_run.distribution_version = test_run.distribution_version test_run.revision = test_run.revision @@ -211,20 +207,19 @@ def calculate_weighted_average(self, task_metrics: Dict[str, List[Any]], task_na weighted_metrics = {} # Get iterations for each test run - iterations_per_run = [self.accumulated_iterations[test_id][task_name] - for test_id in self.test_runs.keys()] + iterations_per_run = [self.accumulated_iterations[test_id][task_name] for test_id in self.test_runs.keys()] total_iterations = sum(iterations_per_run) for metric, values in task_metrics.items(): if isinstance(values[0], dict): weighted_metrics[metric] = {} for metric_field in values[0].keys(): - if metric_field == 'unit': + if metric_field == "unit": weighted_metrics[metric][metric_field] = values[0][metric_field] - elif metric_field == 'min': - weighted_metrics[metric]['overall_min'] = min(value.get(metric_field, 0) for value in values) - elif metric_field == 'max': - weighted_metrics[metric]['overall_max'] = max(value.get(metric_field, 0) for value in values) + elif metric_field == "min": + weighted_metrics[metric]["overall_min"] = min(value.get(metric_field, 0) for value in values) + elif metric_field == "max": + weighted_metrics[metric]["overall_max"] = max(value.get(metric_field, 0) for value in values) else: # for items like median or containing percentile values item_values = [value.get(metric_field, 0) for value in values] @@ -243,7 +238,7 @@ def calculate_rsd(self, values: List[Union[int, float]], metric_name: str) -> Un return "NA" # RSD is not applicable for a single value mean = statistics.mean(values) std_dev = statistics.stdev(values) - return (std_dev / mean) * 100 if mean != 0 else float('inf') + return (std_dev / mean) * 100 if mean != 0 else float("inf") def test_run_compatibility_check(self) -> None: first_test_run = self.test_store.find_by_test_run_id(list(self.test_runs.keys())[0]) @@ -254,8 +249,7 @@ def test_run_compatibility_check(self) -> None: if test_run: if test_run.workload != workload: raise ValueError( - f"Incompatible workload: test {id} has workload '{test_run.workload}' instead of '{workload}'. " - f"Ensure that all test IDs have the same workload." + f"Incompatible workload: test {id} has workload '{test_run.workload}' instead of '{workload}'. Ensure that all test IDs have the same workload." ) if test_run.test_procedure != test_procedure: raise ValueError( @@ -287,6 +281,7 @@ def aggregate(self) -> None: else: raise ValueError("Incompatible test run results") + class AggregatedResults: def __init__(self, results): self.results = results diff --git a/solrorbit/benchmark.py b/solrorbit/benchmark.py index 574991e3..e23cf224 100644 --- a/solrorbit/benchmark.py +++ b/solrorbit/benchmark.py @@ -16,7 +16,7 @@ # not use this file except in compliance with the License. # You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an @@ -46,15 +46,14 @@ import thespian.actors from solrorbit import PROGRAM_NAME, BANNER, FORUM_LINK, SKULL, check_python_version, doc_link, telemetry -from solrorbit import version, actor, config, paths, \ - test_run_orchestrator, publisher, \ - metrics, workload, exceptions, log +from solrorbit import version, actor, config, paths, test_run_orchestrator, publisher, metrics, workload, exceptions, log from solrorbit.builder import cluster_config, builder from solrorbit.synthetic_data_generator import synthetic_data_generator_orchestrator from solrorbit.workload_generator import workload_generator from solrorbit.utils import io, convert, process, console, net, opts, versions from solrorbit import aggregator + def create_arg_parser(): def positive_number(v): value = int(v) @@ -91,15 +90,10 @@ def add_workload_source(subparser): "--workload-repository", help="Define the repository from where solr-orbit will load workloads (default: default).", # argparse is smart enough to use this default only if the user did not use --workload-path and also did not specify anything - default="default" + default="default", ) - workload_source_group.add_argument( - "--workload-path", - help="Define the path to a workload.") - subparser.add_argument( - "--workload-revision", - help="Define a specific revision in the workload repository that solr-orbit should use.", - default=None) + workload_source_group.add_argument("--workload-path", help="Define the path to a workload.") + subparser.add_argument("--workload-revision", help="Define a specific revision in the workload repository that solr-orbit should use.", default=None) # try to preload configurable defaults, but this does not work together with `--configuration-name` (which is undocumented anyway) cfg = config.Config() @@ -109,20 +103,19 @@ def add_workload_source(subparser): else: preserve_install = False - parser = argparse.ArgumentParser(prog=PROGRAM_NAME, - description=BANNER + "\n\n A macrobenchmarking tool for Apache Solr", - epilog="Find out more at {}".format(console.format.link(doc_link())), - formatter_class=argparse.RawDescriptionHelpFormatter) - parser.add_argument('-v', '--version', action='version', version="%(prog)s " + version.version()) + parser = argparse.ArgumentParser( + prog=PROGRAM_NAME, + description=BANNER + "\n\n A macrobenchmarking tool for Apache Solr", + epilog="Find out more at {}".format(console.format.link(doc_link())), + formatter_class=argparse.RawDescriptionHelpFormatter, + ) + parser.add_argument("-v", "--version", action="version", version="%(prog)s " + version.version()) if len(sys.argv) == 1: parser.print_help() sys.exit(1) - subparsers = parser.add_subparsers( - title="subcommands", - dest="subcommand", - help="") + subparsers = parser.add_subparsers(title="subcommands", dest="subcommand", help="") test_run_parser = subparsers.add_parser("run", help="Run a benchmark") # change in favor of "list telemetry", "list workloads", "list pipelines" @@ -130,10 +123,9 @@ def add_workload_source(subparser): list_parser.add_argument( "configuration", metavar="configuration", - help="The configuration for which the tool should show the available options. " - "Possible values are: telemetry, workloads, pipelines, test-runs, cluster-configs", - choices=["telemetry", "workloads", "pipelines", "test-runs", "aggregated-results", - "cluster-configs"]) + help="The configuration for which the tool should show the available options. Possible values are: telemetry, workloads, pipelines, test-runs, cluster-configs", + choices=["telemetry", "workloads", "pipelines", "test-runs", "aggregated-results", "cluster-configs"], + ) list_parser.add_argument( "--limit", help="Limit the number of search results for recent test-runs (default: 10).", @@ -146,338 +138,182 @@ def add_workload_source(subparser): info_parser.add_argument( "--workload", "-w", - help=f"Define the workload to use. List possible workloads with `{PROGRAM_NAME} list workloads`." + help=f"Define the workload to use. List possible workloads with `{PROGRAM_NAME} list workloads`.", # we set the default value later on because we need to determine whether the user has provided this value. # default="geonames" ) info_parser.add_argument( - "--workload-params", - "-wp", - help="Define a comma-separated list of key:value pairs that are injected verbatim to the workload as variables.", - default="" - ) - info_parser.add_argument( - "--test-procedure", - help=f"Define the test_procedure to use. List possible test_procedures for workloads with `{PROGRAM_NAME} list workloads`." + "--workload-params", "-wp", help="Define a comma-separated list of key:value pairs that are injected verbatim to the workload as variables.", default="" ) + info_parser.add_argument("--test-procedure", help=f"Define the test_procedure to use. List possible test_procedures for workloads with `{PROGRAM_NAME} list workloads`.") info_task_filter_group = info_parser.add_mutually_exclusive_group() - info_task_filter_group.add_argument( - "--include-tasks", - help="Defines a comma-separated list of tasks to run. By default all tasks of a test_procedure are run.") - info_task_filter_group.add_argument( - "--exclude-tasks", - help="Defines a comma-separated list of tasks not to run. By default all tasks of a test_procedure are run.") + info_task_filter_group.add_argument("--include-tasks", help="Defines a comma-separated list of tasks to run. By default all tasks of a test_procedure are run.") + info_task_filter_group.add_argument("--exclude-tasks", help="Defines a comma-separated list of tasks not to run. By default all tasks of a test_procedure are run.") - synthetic_data_generator_parser = subparsers.add_parser("generate-data", - help="Generate synthetic data based on existing index mappings or custom module." + - "This data can be ported into Solr Orbit workloads." ) + synthetic_data_generator_parser = subparsers.add_parser( + "generate-data", help="Generate synthetic data based on existing index mappings or custom module." + "This data can be ported into Solr Orbit workloads." + ) exclusive_file_inputs = synthetic_data_generator_parser.add_mutually_exclusive_group(required=True) - exclusive_file_inputs.add_argument( - "--index-mappings", - "-i", - help="Index mappings (JSON) to generate synthetic data from." - ) + exclusive_file_inputs.add_argument("--index-mappings", "-i", help="Index mappings (JSON) to generate synthetic data from.") exclusive_file_inputs.add_argument( "--custom-module", "-m", - help="Custom Python module that defines how to generate documents. " + - "It can contain function definitions and even class definitions. " + - "This gives users more granular control over how data is generated. " + - "This module must contain generate_synthetic_document() definition." + help="Custom Python module that defines how to generate documents. " + + "It can contain function definitions and even class definitions. " + + "This gives users more granular control over how data is generated. " + + "This module must contain generate_synthetic_document() definition.", ) exclusive_params = synthetic_data_generator_parser.add_mutually_exclusive_group(required=True) - exclusive_params.add_argument( - "--total-size", - "-s", - type=int, - help="Total size in GB of synthetically generated data corpora" - ) + exclusive_params.add_argument("--total-size", "-s", type=int, help="Total size in GB of synthetically generated data corpora") + synthetic_data_generator_parser.add_argument("--index-name", "-n", required=True, help="Index name associated with generated corpora") synthetic_data_generator_parser.add_argument( - "--index-name", - "-n", - required=True, - help="Index name associated with generated corpora" + "--output-path", "-p", default=os.path.join(os.getcwd(), "generated_corpora"), help="Output path for data corpora. Data corpora will be written in a directory." ) synthetic_data_generator_parser.add_argument( - "--output-path", - "-p", - default=os.path.join(os.getcwd(), "generated_corpora"), - help="Output path for data corpora. Data corpora will be written in a directory." - ) - synthetic_data_generator_parser.add_argument( - "--custom-config", - "-c", - default=None, - help="Optional config where users can specify overrides for mapping synthetic data generator or values that module should use." + "--custom-config", "-c", default=None, help="Optional config where users can specify overrides for mapping synthetic data generator or values that module should use." ) synthetic_data_generator_parser.add_argument( "--test-document", "-t", default=False, action="store_true", - help="Generates a single synthetic document and displays it to the console so that users can validate generated values and output." + help="Generates a single synthetic document and displays it to the console so that users can validate generated values and output.", ) create_workload_parser = subparsers.add_parser("create-workload", help="Create a workload from existing data") - create_workload_parser.add_argument( - "--workload", - "-w", - required=True, - help="Name of the generated workload") - create_workload_parser.add_argument( - "--indices", - "-i", - type=non_empty_list, - required=True, - help="Comma-separated list of indices to include in the workload") - create_workload_parser.add_argument( - "--target-hosts", - "-t", - default="", - required=True, - help="Comma-separated list of host:port pairs which should be targeted") + create_workload_parser.add_argument("--workload", "-w", required=True, help="Name of the generated workload") + create_workload_parser.add_argument("--indices", "-i", type=non_empty_list, required=True, help="Comma-separated list of indices to include in the workload") + create_workload_parser.add_argument("--target-hosts", "-t", default="", required=True, help="Comma-separated list of host:port pairs which should be targeted") create_workload_parser.add_argument( "--client-options", "-c", default=opts.ClientOptions.DEFAULT_CLIENT_OPTIONS, - help=f"Comma-separated list of client options to use. (default: {opts.ClientOptions.DEFAULT_CLIENT_OPTIONS})") - create_workload_parser.add_argument( - "--output-path", - default=os.path.join(os.getcwd(), "workloads"), - help="Workload output directory (default: workloads/)") + help=f"Comma-separated list of client options to use. (default: {opts.ClientOptions.DEFAULT_CLIENT_OPTIONS})", + ) + create_workload_parser.add_argument("--output-path", default=os.path.join(os.getcwd(), "workloads"), help="Workload output directory (default: workloads/)") create_workload_parser.add_argument( - "--custom-queries", - type=argparse.FileType('r'), - help="Input JSON file to use containing custom workload queries that override the default match_all query") + "--custom-queries", type=argparse.FileType("r"), help="Input JSON file to use containing custom workload queries that override the default match_all query" + ) create_workload_parser.add_argument( "--number-of-docs", action=opts.StoreKeyPairAsDict, - nargs='+', + nargs="+", metavar="KEY:VAL", - help="Map of index name and integer doc count to extract. Ensure that index name also exists in --indices parameter. " + - "To specify several indices and doc counts, use format: : : ...") + help="Map of index name and integer doc count to extract. Ensure that index name also exists in --indices parameter. " + + "To specify several indices and doc counts, use format: : : ...", + ) create_workload_parser.add_argument( "--sample-frequency", action=opts.StoreKeyPairAsDict, - nargs='+', + nargs="+", metavar="KEY:VAL", - help="Map of index name and an integer, representing the sample frequency of docs that should be extracted per index. " + - "Ensure that index name also exists in --indices parameter. " + - "To specify several indices and doc counts, use format: : : ...") - - convert_workload_parser = subparsers.add_parser( - "convert-workload", - help="Convert an OpenSearch Benchmark workload to Solr-native format" - ) - convert_workload_parser.add_argument( - "--workload-path", - required=True, - help="Path to the source OpenSearch Benchmark workload directory (must contain workload.json)." - ) - convert_workload_parser.add_argument( - "--output-path", - default=None, - help="Path where the converted Solr workload will be written " - "(default: -solr)." - ) - convert_workload_parser.add_argument( - "--force", - action="store_true", - default=False, - help="Overwrite an existing converted workload directory." + help="Map of index name and an integer, representing the sample frequency of docs that should be extracted per index. " + + "Ensure that index name also exists in --indices parameter. " + + "To specify several indices and doc counts, use format: : : ...", ) + convert_workload_parser = subparsers.add_parser("convert-workload", help="Convert an OpenSearch Benchmark workload to Solr-native format") + convert_workload_parser.add_argument("--workload-path", required=True, help="Path to the source OpenSearch Benchmark workload directory (must contain workload.json).") + convert_workload_parser.add_argument("--output-path", default=None, help="Path where the converted Solr workload will be written (default: -solr).") + convert_workload_parser.add_argument("--force", action="store_true", default=False, help="Overwrite an existing converted workload directory.") + compare_parser = subparsers.add_parser("compare", help="Compare two test_runs") - compare_parser.add_argument( - "--baseline", - "-b", - required=True, - help=f"TestRun ID of the baseline (see {PROGRAM_NAME} list test-runs).") - compare_parser.add_argument( - "--contender", - "-c", - required=True, - help=f"TestRun ID of the contender (see {PROGRAM_NAME} list test-runs).") + compare_parser.add_argument("--baseline", "-b", required=True, help=f"TestRun ID of the baseline (see {PROGRAM_NAME} list test-runs).") + compare_parser.add_argument("--contender", "-c", required=True, help=f"TestRun ID of the contender (see {PROGRAM_NAME} list test-runs).") compare_parser.add_argument( "--percentiles", - help=f"A comma-separated list of percentiles to report latency and service time." - f"(default: {metrics.GlobalStatsCalculator.DEFAULT_LATENCY_PERCENTILES}).", - default=metrics.GlobalStatsCalculator.DEFAULT_LATENCY_PERCENTILES) + help=f"A comma-separated list of percentiles to report latency and service time.(default: {metrics.GlobalStatsCalculator.DEFAULT_LATENCY_PERCENTILES}).", + default=metrics.GlobalStatsCalculator.DEFAULT_LATENCY_PERCENTILES, + ) compare_parser.add_argument( - "--results-format", - help="Define the output format for the command line results (default: markdown).", - choices=["markdown", "csv"], - default="markdown") + "--results-format", help="Define the output format for the command line results (default: markdown).", choices=["markdown", "csv"], default="markdown" + ) compare_parser.add_argument( "--results-numbers-align", help="Define the output column number alignment for the command line results (default: right).", choices=["right", "center", "left", "decimal"], - default="right") - compare_parser.add_argument( - "--results-file", - help="Write the command line results also to the provided file.", - default="") - compare_parser.add_argument( - "--show-in-results", - help="Whether to include the comparison in the results file.", - default=True) + default="right", + ) + compare_parser.add_argument("--results-file", help="Write the command line results also to the provided file.", default="") + compare_parser.add_argument("--show-in-results", help="Whether to include the comparison in the results file.", default=True) visualize_parser = subparsers.add_parser("visualize", help="Generate HTML visualization for a test run") - visualize_parser.add_argument( - "--test-run-id", - "-tid", - dest="test_run_id", - required=True, - help=f"TestRun ID to visualize (see {PROGRAM_NAME} list test_runs).") + visualize_parser.add_argument("--test-run-id", "-tid", dest="test_run_id", required=True, help=f"TestRun ID to visualize (see {PROGRAM_NAME} list test_runs).") visualize_parser.add_argument( "--output-path", dest="output_path", help="Path where the HTML report should be saved. If not specified, it will be saved in the test run directory, where test_run.json can be found.", - default=None) + default=None, + ) aggregate_parser = subparsers.add_parser("aggregate", help="Aggregate multiple test-runs") - aggregate_parser.add_argument( - "--test-runs", - type=non_empty_list, - required=True, - help="Comma-separated list of TestRun IDs to aggregate") - aggregate_parser.add_argument( - "--test-runs-id", - "-tid", - help="Define a unique id for this aggregated test-run.", - default="") - aggregate_parser.add_argument( - "--results-file", - help="Write the aggregated results to the provided file.", - default="") - aggregate_parser.add_argument( - "--workload-repository", - help="Define the repository from where solr-orbit will load workloads (default: default).", - default="default") + aggregate_parser.add_argument("--test-runs", type=non_empty_list, required=True, help="Comma-separated list of TestRun IDs to aggregate") + aggregate_parser.add_argument("--test-runs-id", "-tid", help="Define a unique id for this aggregated test-run.", default="") + aggregate_parser.add_argument("--results-file", help="Write the aggregated results to the provided file.", default="") + aggregate_parser.add_argument("--workload-repository", help="Define the repository from where solr-orbit will load workloads (default: default).", default="default") download_parser = subparsers.add_parser("download", help="Downloads an artifact") - download_parser.add_argument( - "--cluster-config-repository", - help="Define the repository from where solr-orbit will load cluster-configs (default: default).", - default="default") - download_parser.add_argument( - "--cluster-config-revision", - help="Define a specific revision in the cluster-config repository that solr-orbit should use.", - default=None) - download_parser.add_argument( - "--cluster-config-path", - help="Define the path to the cluster-config and plugin configurations to use.") + download_parser.add_argument("--cluster-config-repository", help="Define the repository from where solr-orbit will load cluster-configs (default: default).", default="default") + download_parser.add_argument("--cluster-config-revision", help="Define a specific revision in the cluster-config repository that solr-orbit should use.", default=None) + download_parser.add_argument("--cluster-config-path", help="Define the path to the cluster-config and plugin configurations to use.") download_parser.add_argument( "--distribution-version", type=supported_os_version, - help="Define the version of the distribution to download. " - "Check https://projects.apache.org/project.html?solr for released versions.", - default="") - download_parser.add_argument( - "--distribution-repository", - help="Define the repository from where the distribution should be downloaded (default: release).", - default="release") + help="Define the version of the distribution to download. Check https://projects.apache.org/project.html?solr for released versions.", + default="", + ) + download_parser.add_argument("--distribution-repository", help="Define the repository from where the distribution should be downloaded (default: release).", default="release") download_parser.add_argument( "--cluster-config", - help=f"Define the cluster-config to use. List possible " - f"cluster-configs with `{PROGRAM_NAME} list " - f"cluster-configs` (default: defaults).", - default="defaults") # optimized for local usage + help=f"Define the cluster-config to use. List possible cluster-configs with `{PROGRAM_NAME} list cluster-configs` (default: defaults).", + default="defaults", + ) # optimized for local usage download_parser.add_argument( - "--cluster-config-params", - help="Define a comma-separated list of key:value pairs that are injected verbatim as variables for the cluster-config.", - default="" + "--cluster-config-params", help="Define a comma-separated list of key:value pairs that are injected verbatim as variables for the cluster-config.", default="" ) install_parser = subparsers.add_parser("install", help="Installs a Solr node locally") install_parser.add_argument( "--revision", help="Define the source code revision for building the benchmark candidate. 'current' uses the source tree as is," - " 'latest' fetches the latest version on main. It is also possible to specify a commit id or a timestamp." - " The timestamp must be specified as: \"@ts\" where \"ts\" must be a valid ISO 8601 timestamp, " - "e.g. \"@2013-07-27T10:37:00Z\" (default: current).", - default="current") # optimized for local usage, don't fetch sources + " 'latest' fetches the latest version on main. It is also possible to specify a commit id or a timestamp." + ' The timestamp must be specified as: "@ts" where "ts" must be a valid ISO 8601 timestamp, ' + 'e.g. "@2013-07-27T10:37:00Z" (default: current).', + default="current", + ) # optimized for local usage, don't fetch sources # Intentionally undocumented as we do not consider Docker a fully supported option. - install_parser.add_argument( - "--build-type", - help=argparse.SUPPRESS, - choices=["tar", "docker"], - default="tar") - install_parser.add_argument( - "--cluster-config-repository", - help="Define the repository from where solr-orbit will load cluster-configs (default: default).", - default="default") - install_parser.add_argument( - "--cluster-config-revision", - help="Define a specific revision in the cluster-config repository that solr-orbit should use.", - default=None) - install_parser.add_argument( - "--cluster-config-path", - help="Define the path to the cluster-config and plugin configurations to use.") - install_parser.add_argument( - "--runtime-jdk", - type=runtime_jdk, - help="The major version of the runtime JDK to use during installation.", - default=None) - install_parser.add_argument( - "--distribution-repository", - help="Define the repository from where the distribution should be downloaded (default: release).", - default="release") + install_parser.add_argument("--build-type", help=argparse.SUPPRESS, choices=["tar", "docker"], default="tar") + install_parser.add_argument("--cluster-config-repository", help="Define the repository from where solr-orbit will load cluster-configs (default: default).", default="default") + install_parser.add_argument("--cluster-config-revision", help="Define a specific revision in the cluster-config repository that solr-orbit should use.", default=None) + install_parser.add_argument("--cluster-config-path", help="Define the path to the cluster-config and plugin configurations to use.") + install_parser.add_argument("--runtime-jdk", type=runtime_jdk, help="The major version of the runtime JDK to use during installation.", default=None) + install_parser.add_argument("--distribution-repository", help="Define the repository from where the distribution should be downloaded (default: release).", default="release") install_parser.add_argument( "--distribution-version", type=supported_os_version, - help="Define the version of the distribution to download. " - "Check https://archive.apache.org/dist/solr/solr/ for released versions.", - default="") - install_parser.add_argument( - "--cluster-config", - help=f"Define the cluster-config to use. List possible " - f"cluster-configs with `{PROGRAM_NAME} list " - f"cluster-configs` (default: defaults).", - default="defaults") # optimized for local usage - install_parser.add_argument( - "--cluster-config-params", - help="Define a comma-separated list of key:value pairs that are injected verbatim as variables for the cluster-config.", - default="" - ) - install_parser.add_argument( - "--solr-modules", - help="Comma-separated list of Solr modules to enable (sets SOLR_MODULES). " - "Example: --solr-modules=analytics,extraction", - default="") - install_parser.add_argument( - "--plugin-params", - help="Define a comma-separated list of key:value pairs that are injected verbatim to all plugins as variables.", - default="" - ) - install_parser.add_argument( - "--network-host", - help="The IP address to bind to and publish", - default="127.0.0.1" - ) - install_parser.add_argument( - "--http-port", - help="The port to expose for HTTP traffic", - default="38983" + help="Define the version of the distribution to download. Check https://archive.apache.org/dist/solr/solr/ for released versions.", + default="", ) install_parser.add_argument( - "--node-name", - help="The name of this Solr node", - default="benchmark-node-0" - ) + "--cluster-config", + help=f"Define the cluster-config to use. List possible cluster-configs with `{PROGRAM_NAME} list cluster-configs` (default: defaults).", + default="defaults", + ) # optimized for local usage install_parser.add_argument( - "--master-nodes", - help="A comma-separated list of the initial master node names", - default="" + "--cluster-config-params", help="Define a comma-separated list of key:value pairs that are injected verbatim as variables for the cluster-config.", default="" ) install_parser.add_argument( - "--seed-hosts", - help="A comma-separated list of the initial seed host IPs", - default="" + "--solr-modules", help="Comma-separated list of Solr modules to enable (sets SOLR_MODULES). Example: --solr-modules=analytics,extraction", default="" ) + install_parser.add_argument("--plugin-params", help="Define a comma-separated list of key:value pairs that are injected verbatim to all plugins as variables.", default="") + install_parser.add_argument("--network-host", help="The IP address to bind to and publish", default="127.0.0.1") + install_parser.add_argument("--http-port", help="The port to expose for HTTP traffic", default="38983") + install_parser.add_argument("--node-name", help="The name of this Solr node", default="benchmark-node-0") + install_parser.add_argument("--master-nodes", help="A comma-separated list of the initial master node names", default="") + install_parser.add_argument("--seed-hosts", help="A comma-separated list of the initial seed host IPs", default="") start_parser = subparsers.add_parser("start", help="Starts a Solr node locally") start_parser.add_argument( @@ -487,27 +323,17 @@ def add_workload_source(subparser): # the default will be dynamically derived by # test_run_orchestrator based on the # presence / absence of other command line options - default="") - start_parser.add_argument( - "--test-run-id", - "-tid", - required=True, - help="Define a unique id for this test_run.", - default="") - start_parser.add_argument( - "--runtime-jdk", - type=runtime_jdk, - help="The major version of the runtime JDK to use.", - default=None) + default="", + ) + start_parser.add_argument("--test-run-id", "-tid", required=True, help="Define a unique id for this test_run.", default="") + start_parser.add_argument("--runtime-jdk", type=runtime_jdk, help="The major version of the runtime JDK to use.", default=None) start_parser.add_argument( "--telemetry", - help=f"Enable the provided telemetry devices, provided as a comma-separated list. List possible telemetry " - f"devices with `{PROGRAM_NAME} list telemetry`.", - default="") + help=f"Enable the provided telemetry devices, provided as a comma-separated list. List possible telemetry devices with `{PROGRAM_NAME} list telemetry`.", + default="", + ) start_parser.add_argument( - "--telemetry-params", - help="Define a comma-separated list of key:value pairs that are injected verbatim to the telemetry devices as parameters.", - default="" + "--telemetry-params", help="Define a comma-separated list of key:value pairs that are injected verbatim to the telemetry devices as parameters.", default="" ) stop_parser = subparsers.add_parser("stop", help="Stops a Solr node locally") @@ -518,305 +344,201 @@ def add_workload_source(subparser): # the default will be dynamically derived by # test_run_orchestrator based on the # presence / absence of other command line options - default="") + default="", + ) stop_parser.add_argument( - "--preserve-install", - help=f"Keep the benchmark candidate and its index. (default: {str(preserve_install).lower()}).", - default=preserve_install, - action="store_true") + "--preserve-install", help=f"Keep the benchmark candidate and its index. (default: {str(preserve_install).lower()}).", default=preserve_install, action="store_true" + ) for p in [list_parser, test_run_parser]: - p.add_argument( - "--distribution-version", - type=supported_os_version, - help="Define the version of the distribution to download.", - default="") - p.add_argument( - "--cluster-config-path", - help="Define the path to the cluster-config and plugin configurations to use.") - p.add_argument( - "--cluster-config-repository", - help="Define repository from where solr-orbit will load cluster-configs (default: default).", - default="default") - p.add_argument( - "--cluster-config-revision", - help="Define a specific revision in the cluster-config repository that solr-orbit should use.", - default=None) + p.add_argument("--distribution-version", type=supported_os_version, help="Define the version of the distribution to download.", default="") + p.add_argument("--cluster-config-path", help="Define the path to the cluster-config and plugin configurations to use.") + p.add_argument("--cluster-config-repository", help="Define repository from where solr-orbit will load cluster-configs (default: default).", default="default") + p.add_argument("--cluster-config-revision", help="Define a specific revision in the cluster-config repository that solr-orbit should use.", default=None) - test_run_parser.add_argument( - "--test-run-id", - "-tid", - help="Define a unique id for this test-run.", - default=str(uuid.uuid4())) + test_run_parser.add_argument("--test-run-id", "-tid", help="Define a unique id for this test-run.", default=str(uuid.uuid4())) test_run_parser.add_argument( "--pipeline", help="Select the pipeline to run.", # the default will be dynamically derived by # test_run_orchestrator based on the # presence / absence of other command line options - default="") + default="", + ) test_run_parser.add_argument( "--revision", help="Define the source code revision for building the benchmark candidate. 'current' uses the source tree as is," - " 'latest' fetches the latest version on main. It is also possible to specify a commit id or a timestamp." - " The timestamp must be specified as: \"@ts\" where \"ts\" must be a valid ISO 8601 timestamp, " - "e.g. \"@2013-07-27T10:37:00Z\" (default: current).", - default="current") # optimized for local usage, don't fetch sources + " 'latest' fetches the latest version on main. It is also possible to specify a commit id or a timestamp." + ' The timestamp must be specified as: "@ts" where "ts" must be a valid ISO 8601 timestamp, ' + 'e.g. "@2013-07-27T10:37:00Z" (default: current).', + default="current", + ) # optimized for local usage, don't fetch sources add_workload_source(test_run_parser) + test_run_parser.add_argument("--workload", "-w", help=f"Define the workload to use. List possible workloads with `{PROGRAM_NAME} list workloads`.") test_run_parser.add_argument( - "--workload", - "-w", - help=f"Define the workload to use. List possible workloads with `{PROGRAM_NAME} list workloads`." - ) - test_run_parser.add_argument( - "--workload-params", - "-wp", - help="Define a comma-separated list of key:value pairs that are injected verbatim to the workload as variables.", - default="" + "--workload-params", "-wp", help="Define a comma-separated list of key:value pairs that are injected verbatim to the workload as variables.", default="" ) - test_run_parser.add_argument( - "--test-procedure", - help=f"Define the test_procedure to use. List possible test_procedures for workloads with `{PROGRAM_NAME} list workloads`.") + test_run_parser.add_argument("--test-procedure", help=f"Define the test_procedure to use. List possible test_procedures for workloads with `{PROGRAM_NAME} list workloads`.") test_run_parser.add_argument( "--cluster-config", - help=f"Define the cluster-config to use. List possible " - f"cluster-configs with `{PROGRAM_NAME} list " - f"cluster-configs` (default: defaults).", - default="defaults") # optimized for local usage + help=f"Define the cluster-config to use. List possible cluster-configs with `{PROGRAM_NAME} list cluster-configs` (default: defaults).", + default="defaults", + ) # optimized for local usage test_run_parser.add_argument( - "--cluster-config-params", - help="Define a comma-separated list of key:value pairs that are injected verbatim as variables for the cluster-config.", - default="" + "--cluster-config-params", help="Define a comma-separated list of key:value pairs that are injected verbatim as variables for the cluster-config.", default="" ) + test_run_parser.add_argument("--runtime-jdk", type=runtime_jdk, help="The major version of the runtime JDK to use.", default=None) test_run_parser.add_argument( - "--runtime-jdk", - type=runtime_jdk, - help="The major version of the runtime JDK to use.", - default=None) - test_run_parser.add_argument( - "--solr-modules", - help="Comma-separated list of Solr modules to enable (sets SOLR_MODULES). " - "Example: --solr-modules=analytics,extraction", - default="") - test_run_parser.add_argument( - "--plugin-params", - help="Define a comma-separated list of key:value pairs that are injected verbatim to all plugins as variables.", - default="" + "--solr-modules", help="Comma-separated list of Solr modules to enable (sets SOLR_MODULES). Example: --solr-modules=analytics,extraction", default="" ) + test_run_parser.add_argument("--plugin-params", help="Define a comma-separated list of key:value pairs that are injected verbatim to all plugins as variables.", default="") test_run_parser.add_argument( "--target-hosts", "-t", - help="Define a comma-separated list of host:port pairs which should be targeted if using the pipeline 'benchmark-only' " - "(default: localhost:9200).", - default="") # actually the default is pipeline specific and it is set later - test_run_parser.add_argument( - "--worker-ips", - help="Define a comma-separated list of hosts which should generate load (default: localhost).", - default="localhost") - test_run_parser.add_argument( - "--grpc-target-hosts", - help="Define a comma-separated list of host:port pairs for gRPC endpoints " - "(default: localhost:9400).", - default="") + help="Define a comma-separated list of host:port pairs which should be targeted if using the pipeline 'benchmark-only' (default: localhost:9200).", + default="", + ) # actually the default is pipeline specific and it is set later + test_run_parser.add_argument("--worker-ips", help="Define a comma-separated list of hosts which should generate load (default: localhost).", default="localhost") + test_run_parser.add_argument("--grpc-target-hosts", help="Define a comma-separated list of host:port pairs for gRPC endpoints (default: localhost:9400).", default="") test_run_parser.add_argument( "--client-options", "-c", - help=f"Define a comma-separated list of client options to use. The options will be passed to the benchmark " - f"client (default: {opts.ClientOptions.DEFAULT_CLIENT_OPTIONS}).", - default=opts.ClientOptions.DEFAULT_CLIENT_OPTIONS) - test_run_parser.add_argument("--on-error", - choices=["continue", "abort"], - help="Controls how solr-orbit behaves on response errors (default: continue).", - default="continue") + help=f"Define a comma-separated list of client options to use. The options will be passed to the benchmark client (default: {opts.ClientOptions.DEFAULT_CLIENT_OPTIONS}).", + default=opts.ClientOptions.DEFAULT_CLIENT_OPTIONS, + ) + test_run_parser.add_argument("--on-error", choices=["continue", "abort"], help="Controls how solr-orbit behaves on response errors (default: continue).", default="continue") test_run_parser.add_argument( "--telemetry", - help=f"Enable the provided telemetry devices, provided as a comma-separated list. List possible telemetry " - f"devices with `{PROGRAM_NAME} list telemetry`.", - default="") - test_run_parser.add_argument( - "--telemetry-params", - help="Define a comma-separated list of key:value pairs that are injected verbatim to the telemetry devices as parameters.", - default="" + help=f"Enable the provided telemetry devices, provided as a comma-separated list. List possible telemetry devices with `{PROGRAM_NAME} list telemetry`.", + default="", ) test_run_parser.add_argument( - "--distribution-repository", - help="Define the repository from where the distribution should be downloaded (default: release).", - default="release") + "--telemetry-params", help="Define a comma-separated list of key:value pairs that are injected verbatim to the telemetry devices as parameters.", default="" + ) + test_run_parser.add_argument("--distribution-repository", help="Define the repository from where the distribution should be downloaded (default: release).", default="release") task_filter_group = test_run_parser.add_mutually_exclusive_group() - task_filter_group.add_argument( - "--include-tasks", - help="Defines a comma-separated list of tasks to run. By default all tasks of a test_procedure are run.") - task_filter_group.add_argument( - "--exclude-tasks", - help="Defines a comma-separated list of tasks not to run. By default all tasks of a test_procedure are run.") + task_filter_group.add_argument("--include-tasks", help="Defines a comma-separated list of tasks to run. By default all tasks of a test_procedure are run.") + task_filter_group.add_argument("--exclude-tasks", help="Defines a comma-separated list of tasks not to run. By default all tasks of a test_procedure are run.") test_run_parser.add_argument( "--user-tag", - help="Define a user-specific key-value pair (separated by ':'). It is added to each metric record as meta info. " - "Example: intention:baseline-ticket-12345", - default="") + help="Define a user-specific key-value pair (separated by ':'). It is added to each metric record as meta info. Example: intention:baseline-ticket-12345", + default="", + ) test_run_parser.add_argument( - "--results-format", - help="Define the output format for the command line results (default: markdown).", - choices=["markdown", "csv"], - default="markdown") + "--results-format", help="Define the output format for the command line results (default: markdown).", choices=["markdown", "csv"], default="markdown" + ) test_run_parser.add_argument( "--results-numbers-align", help="Define the output column number alignment for the command line results (default: right).", choices=["right", "center", "left", "decimal"], - default="right") + default="right", + ) test_run_parser.add_argument( "--show-in-results", help="Define which values are shown in the summary results published (default: available).", choices=["available", "all-percentiles", "all"], - default="available") - test_run_parser.add_argument( - "--results-file", - help="Write the command line results also to the provided file.", - default="") + default="available", + ) + test_run_parser.add_argument("--results-file", help="Write the command line results also to the provided file.", default="") test_run_parser.add_argument( - "--preserve-install", - help=f"Keep the benchmark candidate and its index. (default: {str(preserve_install).lower()}).", - default=preserve_install, - action="store_true") + "--preserve-install", help=f"Keep the benchmark candidate and its index. (default: {str(preserve_install).lower()}).", default=preserve_install, action="store_true" + ) test_run_parser.add_argument( "--test-mode", help="Runs the given workload in 'test mode'. Meant to check a workload for errors but not for real benchmarks (default: false).", default=False, - action="store_true") + action="store_true", + ) test_run_parser.add_argument( "--enable-worker-coordinator-profiling", help="Enables a profiler for analyzing the performance of calls in solr-orbit's worker coordinator (default: false).", default=False, - action="store_true") - test_run_parser.add_argument( - "--enable-assertions", - help="Enables assertion checks for tasks (default: false).", - default=False, - action="store_true") - test_run_parser.add_argument( - "--kill-running-processes", - "-k", action="store_true", - default=False, - help="If any processes is running, it is going to kill them and allow solr-orbit to continue to run." + ) + test_run_parser.add_argument("--enable-assertions", help="Enables assertion checks for tasks (default: false).", default=False, action="store_true") + test_run_parser.add_argument( + "--kill-running-processes", "-k", action="store_true", default=False, help="If any processes is running, it is going to kill them and allow solr-orbit to continue to run." ) test_run_parser.add_argument( "--latency-percentiles", - help=f"A comma-separated list of percentiles to report for latency " - f"(default: {metrics.GlobalStatsCalculator.DEFAULT_LATENCY_PERCENTILES}).", - default=metrics.GlobalStatsCalculator.DEFAULT_LATENCY_PERCENTILES + help=f"A comma-separated list of percentiles to report for latency (default: {metrics.GlobalStatsCalculator.DEFAULT_LATENCY_PERCENTILES}).", + default=metrics.GlobalStatsCalculator.DEFAULT_LATENCY_PERCENTILES, ) test_run_parser.add_argument( "--throughput-percentiles", help=f"A comma-separated list of percentiles to report for throughput, in addition to mean/median/max/min " - f"(default: {metrics.GlobalStatsCalculator.DEFAULT_THROUGHPUT_PERCENTILES}).", - default=metrics.GlobalStatsCalculator.DEFAULT_THROUGHPUT_PERCENTILES + f"(default: {metrics.GlobalStatsCalculator.DEFAULT_THROUGHPUT_PERCENTILES}).", + default=metrics.GlobalStatsCalculator.DEFAULT_THROUGHPUT_PERCENTILES, ) - test_run_parser.add_argument( - "--randomization-enabled", - help="Runs the given workload with query randomization enabled (default: false).", - default=False, - action="store_true") + test_run_parser.add_argument("--randomization-enabled", help="Runs the given workload with query randomization enabled (default: false).", default=False, action="store_true") test_run_parser.add_argument( "--randomization-repeat-frequency", - help=f"The repeat_frequency for query randomization. Ignored if randomization is off" - f"(default: {workload.loader.QueryRandomizerWorkloadProcessor.DEFAULT_RF}).", - default=workload.loader.QueryRandomizerWorkloadProcessor.DEFAULT_RF) + help=f"The repeat_frequency for query randomization. Ignored if randomization is off(default: {workload.loader.QueryRandomizerWorkloadProcessor.DEFAULT_RF}).", + default=workload.loader.QueryRandomizerWorkloadProcessor.DEFAULT_RF, + ) test_run_parser.add_argument( "--randomization-n", help=f"The number of standard values to generate for each field for query randomization." - f"Ignored if randomization is off (default: {workload.loader.QueryRandomizerWorkloadProcessor.DEFAULT_N}).", - default=workload.loader.QueryRandomizerWorkloadProcessor.DEFAULT_N) + f"Ignored if randomization is off (default: {workload.loader.QueryRandomizerWorkloadProcessor.DEFAULT_N}).", + default=workload.loader.QueryRandomizerWorkloadProcessor.DEFAULT_N, + ) test_run_parser.add_argument( "--randomization-alpha", help=f"The alpha parameter used for the Zipf distribution for query randomization. Low values spread the distribution out, " - f"high values favor the most common queries. " - f"Ignored if randomization is off (default: {workload.loader.QueryRandomizerWorkloadProcessor.DEFAULT_ALPHA}).", - default=workload.loader.QueryRandomizerWorkloadProcessor.DEFAULT_ALPHA) - test_run_parser.add_argument( - "--test-iterations", - help="The number of times to run the workload (default: 1).", - default=1) - test_run_parser.add_argument( - "--aggregate", - type=lambda x: (str(x).lower() in ['true', '1', 'yes', 'y']), - help="Aggregate the results of multiple test runs (default: true).", - default=True) + f"high values favor the most common queries. " + f"Ignored if randomization is off (default: {workload.loader.QueryRandomizerWorkloadProcessor.DEFAULT_ALPHA}).", + default=workload.loader.QueryRandomizerWorkloadProcessor.DEFAULT_ALPHA, + ) + test_run_parser.add_argument("--test-iterations", help="The number of times to run the workload (default: 1).", default=1) test_run_parser.add_argument( - "--sleep-timer", - help="Sleep for the specified number of seconds before starting the next test run (default: 5).", - default=5) + "--aggregate", type=lambda x: str(x).lower() in ["true", "1", "yes", "y"], help="Aggregate the results of multiple test runs (default: true).", default=True + ) + test_run_parser.add_argument("--sleep-timer", help="Sleep for the specified number of seconds before starting the next test run (default: 5).", default=5) test_run_parser.add_argument( "--cancel-on-error", action="store_true", help="Stop running tests if an error occurs in one of the test iterations (default: false).", ) - test_run_parser.add_argument( - "--load-test-qps", - help="Run a load test on your cluster, up to a certain QPS value (default: 0)", - default=0 - ) + test_run_parser.add_argument("--load-test-qps", help="Run a load test on your cluster, up to a certain QPS value (default: 0)", default=0) test_run_parser.add_argument( "--redline-test", help="Run a redline test on your cluster, up to a certain QPS value (default: 1000)", - nargs='?', + nargs="?", const=1000, # Value to use when flag is present but no value given default=0, # Value to use when flag is not present - type=int - ) - test_run_parser.add_argument( - "--redline-scale-step", type=int, - help="How many clients to add while scaling up during redline testing (default: 5).", - default=None ) + test_run_parser.add_argument("--redline-scale-step", type=int, help="How many clients to add while scaling up during redline testing (default: 5).", default=None) + test_run_parser.add_argument("--redline-scaledown-percentage", type=float, help="What percentage of clients to remove when errors occur (default: 10%%).", default=None) test_run_parser.add_argument( - "--redline-scaledown-percentage", - type=float, - help="What percentage of clients to remove when errors occur (default: 10%%).", - default=None - ) - test_run_parser.add_argument( - "--redline-post-scaledown-sleep", - type=int, - help="How many seconds to wait before scaling up again after a scale down (default: 30).", - default=None + "--redline-post-scaledown-sleep", type=int, help="How many seconds to wait before scaling up again after a scale down (default: 30).", default=None ) test_run_parser.add_argument( "--redline-max-clients", type=int, help="Maximum number of clients to allow during redline testing. If not set, will default to clients defined in the test procedure.", - default=None + default=None, ) test_run_parser.add_argument( - "--redline-max-cpu-usage", - type=int, - help="Maximum CPU utilization before scaling back client numbers. Used to activate CPU-based feedback in solr-orbit.", - default=None + "--redline-max-cpu-usage", type=int, help="Maximum CPU utilization before scaling back client numbers. Used to activate CPU-based feedback in solr-orbit.", default=None ) test_run_parser.add_argument( "--redline-cpu-window-seconds", type=int, help="How many seconds the window for average CPU load should be in seconds during CPU-based redline testing. (Default: 30)", - default=None + default=None, ) test_run_parser.add_argument( - "--redline-cpu-check-interval", - type=int, - help="How many seconds between CPU checks there should be during CPU-based redline testing. (Default: 30)", - default=None + "--redline-cpu-check-interval", type=int, help="How many seconds between CPU checks there should be during CPU-based redline testing. (Default: 30)", default=None ) test_run_parser.add_argument( - "--visualize", - help="Generate HTML visualizations for benchmark results. Stored in the test runs directory by default", - action="store_true", - default=False + "--visualize", help="Generate HTML visualizations for benchmark results. Stored in the test runs directory by default", action="store_true", default=False ) test_run_parser.add_argument( - "--visualize-output-path", - help="Path where the HTML visualization should be saved when --visualize is enabled. If not specified, it will be saved in the test run directory.", - default=None + "--visualize-output-path", + help="Path where the HTML visualization should be saved when --visualize is enabled. If not specified, it will be saved in the test run directory.", + default=None, ) ############################################################################### @@ -826,37 +548,29 @@ def add_workload_source(subparser): ############################################################################### # This option is intended to tell solr-orbit to assume a different start date than 'now'. This is effectively just useful for things like # backtesting or a benchmark run across environments (think: comparison of EC2 and bare metal) but never for the typical user. - test_run_parser.add_argument( - "--effective-start-date", - help=argparse.SUPPRESS, - type=lambda s: datetime.datetime.strptime(s, "%Y-%m-%d %H:%M:%S"), - default=None) + test_run_parser.add_argument("--effective-start-date", help=argparse.SUPPRESS, type=lambda s: datetime.datetime.strptime(s, "%Y-%m-%d %H:%M:%S"), default=None) # Skips checking that the REST API is available before proceeding with the benchmark - test_run_parser.add_argument( - "--skip-rest-api-check", - help=argparse.SUPPRESS, - action="store_true", - default=False) - - for p in [list_parser, test_run_parser, compare_parser, aggregate_parser, - download_parser, install_parser, start_parser, stop_parser, info_parser, - synthetic_data_generator_parser, create_workload_parser, visualize_parser, - convert_workload_parser]: + test_run_parser.add_argument("--skip-rest-api-check", help=argparse.SUPPRESS, action="store_true", default=False) + + for p in [ + list_parser, + test_run_parser, + compare_parser, + aggregate_parser, + download_parser, + install_parser, + start_parser, + stop_parser, + info_parser, + synthetic_data_generator_parser, + create_workload_parser, + visualize_parser, + convert_workload_parser, + ]: # This option is needed to support a separate configuration for the integration tests on the same machine - p.add_argument( - "--configuration-name", - help=argparse.SUPPRESS, - default=None) - p.add_argument( - "--quiet", - help="Suppress as much as output as possible (default: false).", - default=False, - action="store_true") - p.add_argument( - "--offline", - help="Assume that solr-orbit has no connection to the Internet (default: false).", - default=False, - action="store_true") + p.add_argument("--configuration-name", help=argparse.SUPPRESS, default=None) + p.add_argument("--quiet", help="Suppress as much as output as possible (default: false).", default=False, action="store_true") + p.add_argument("--offline", help="Assume that solr-orbit has no connection to the Internet (default: false).", default=False, action="store_true") return parser @@ -878,6 +592,7 @@ def dispatch_list(cfg): else: raise exceptions.SystemSetupError("Cannot list unknown configuration option [%s]" % what) + def dispatch_visualize(cfg): test_run_id = cfg.opts("system", "test_run.id") output_path = cfg.opts("visualize", "output.path", mandatory=False, default_value=None) @@ -888,11 +603,7 @@ def dispatch_visualize(cfg): te = store.find_by_test_run_id(test_run_id) # render, write, and open the HTML - html_path = ( - store.file_store.store_html_results(te) - if isinstance(store, metrics.CompositeTestRunStore) - else store.store_html_results(te) - ) + html_path = store.file_store.store_html_results(te) if isinstance(store, metrics.CompositeTestRunStore) else store.store_html_results(te) # if the user asked for --output-path, just copy the file there if output_path: @@ -906,6 +617,7 @@ def dispatch_visualize(cfg): except Exception as e: raise exceptions.SystemSetupError(f"Error visualizing test run: {e}") + def print_help_on_errors(): heading = "Getting further help:" console.println(console.format.bold(heading)) @@ -913,8 +625,7 @@ def print_help_on_errors(): console.println(f"* Check the log files in {paths.logs()} for errors.") console.println(f"* Read the documentation at {console.format.link(doc_link())}.") console.println(f"* Ask a question on the forum at {console.format.link(FORUM_LINK)}.") - console.println(f"* Raise an issue in the project issue tracker " - f"and include the log files in {paths.logs()}.") + console.println(f"* Raise an issue in the project issue tracker and include the log files in {paths.logs()}.") def run_test(cfg, kill_running_processes=False): @@ -928,17 +639,18 @@ def run_test(cfg, kill_running_processes=False): try: process.kill_running_benchmark_instances() except BaseException: - logger.exception( - "Could not terminate potentially running solr-orbit instances correctly. Attempting to go on anyway.") + logger.exception("Could not terminate potentially running solr-orbit instances correctly. Attempting to go on anyway.") else: other_benchmark_processes = process.find_all_other_benchmark_processes() if other_benchmark_processes: pids = [p.pid for p in other_benchmark_processes] - msg = f"There are other solr-orbit processes running on this machine (PIDs: {pids}) but only one " \ - f"benchmark is allowed to run at the same time.\n\nYou can use --kill-running-processes flag " \ - f"to kill running processes automatically and allow solr-orbit to continue to run a new benchmark. " \ - f"Otherwise, you need to manually kill them." + msg = ( + f"There are other solr-orbit processes running on this machine (PIDs: {pids}) but only one " + f"benchmark is allowed to run at the same time.\n\nYou can use --kill-running-processes flag " + f"to kill running processes automatically and allow solr-orbit to continue to run a new benchmark. " + f"Otherwise, you need to manually kill them." + ) raise exceptions.BenchmarkError(msg) # redline testing: check metrics store type before running cpu based feedback test @@ -948,8 +660,7 @@ def run_test(cfg, kill_running_processes=False): try: if isinstance(store, metrics.InMemoryMetricsStore): raise exceptions.SystemSetupError( - "CPU-based feedback requires a metrics store, but you're using the in-memory store. " - "Specify a metrics store in your benchmark.ini or via CLI to continue." + "CPU-based feedback requires a metrics store, but you're using the in-memory store. Specify a metrics store in your benchmark.ini or via CLI to continue." ) finally: store.close() @@ -973,8 +684,7 @@ def with_actor_system(runnable, cfg): except Exception as e: logger.exception("Could not bootstrap actor system.") if str(e) == "Unable to determine valid external socket address.": - console.warn("Could not determine a socket address. Are you running without any network? Switching to degraded mode.", - logger=logger) + console.warn("Could not determine a socket address. Are you running without any network? Switching to degraded mode.", logger=logger) logger.info("Falling back to offline actor system.") actor.use_offline_actor_system() actors = actor.bootstrap_actor_system(try_join=False, prefer_local_only=True) @@ -1010,8 +720,7 @@ def with_actor_system(runnable, cfg): logger.warning("User interrupted shutdown of internal actor system.") console.info("Please wait a moment for solr-orbit's internal components to shutdown.") if not shutdown_complete and times_interrupted > 0: - logger.warning("Terminating after user has interrupted actor system shutdown explicitly for [%d] times.", - times_interrupted) + logger.warning("Terminating after user has interrupted actor system shutdown explicitly for [%d] times.", times_interrupted) console.println("") console.warn("Terminating now at the risk of leaving child processes behind.") console.println("") @@ -1020,9 +729,7 @@ def with_actor_system(runnable, cfg): console.println(SKULL) console.println("") elif not shutdown_complete: - console.warn("Could not terminate all internal processes within timeout. Please check and force-terminate " - "all solr-orbit processes.") - + console.warn("Could not terminate all internal processes within timeout. Please check and force-terminate all solr-orbit processes.") def configure_telemetry_params(args, cfg): @@ -1059,10 +766,7 @@ def configure_workload_params(arg_parser, args, cfg, command_requires_workload=T def configure_builder_params(args, cfg, command_requires_cluster_config=True): if args.cluster_config_path: - cfg.add( - config.Scope.applicationOverride, "builder", - "cluster_config.path", os.path.abspath( - io.normalize_path(args.cluster_config_path))) + cfg.add(config.Scope.applicationOverride, "builder", "cluster_config.path", os.path.abspath(io.normalize_path(args.cluster_config_path))) cfg.add(config.Scope.applicationOverride, "builder", "repository.name", None) cfg.add(config.Scope.applicationOverride, "builder", "repository.revision", None) else: @@ -1073,19 +777,13 @@ def configure_builder_params(args, cfg, command_requires_cluster_config=True): if args.distribution_version: cfg.add(config.Scope.applicationOverride, "builder", "distribution.version", args.distribution_version) cfg.add(config.Scope.applicationOverride, "builder", "distribution.repository", args.distribution_repository) - cfg.add(config.Scope.applicationOverride, "builder", - "cluster_config.names", opts.csv_to_list( - args.cluster_config)) - cfg.add(config.Scope.applicationOverride, "builder", - "cluster_config.params", opts.to_dict( - args.cluster_config_params)) + cfg.add(config.Scope.applicationOverride, "builder", "cluster_config.names", opts.csv_to_list(args.cluster_config)) + cfg.add(config.Scope.applicationOverride, "builder", "cluster_config.params", opts.to_dict(args.cluster_config_params)) cfg.add(config.Scope.applicationOverride, "solr", "modules", getattr(args, "solr_modules", "")) pipeline = getattr(args, "pipeline", None) if pipeline == "benchmark-only" and args.cluster_config != "defaults": raise SystemExit( - "ERROR: --cluster-config is only valid for provisioning pipelines " - "(from-distribution, docker, from-sources). " - "It cannot be used with the 'benchmark-only' pipeline." + "ERROR: --cluster-config is only valid for provisioning pipelines (from-distribution, docker, from-sources). It cannot be used with the 'benchmark-only' pipeline." ) @@ -1124,6 +822,7 @@ def configure_reporting_params(args, cfg): cfg.add(config.Scope.applicationOverride, "reporting", "output.path", args.results_file) cfg.add(config.Scope.applicationOverride, "reporting", "numbers.align", args.results_numbers_align) + def prepare_test_runs_dict(args, cfg): cfg.add(config.Scope.applicationOverride, "reporting", "output.path", args.results_file) test_runs_dict = {} @@ -1134,6 +833,7 @@ def prepare_test_runs_dict(args, cfg): test_runs_dict[run] = None return test_runs_dict + def configure_test(arg_parser, args, cfg): # As the run command is doing more work than necessary at the moment, we duplicate several parameters # in this section that actually belong to dedicated subcommands (like install, start or stop). Over time @@ -1150,11 +850,7 @@ def configure_test(arg_parser, args, cfg): cfg.add(config.Scope.applicationOverride, "worker_coordinator", "profiling", args.enable_worker_coordinator_profiling) cfg.add(config.Scope.applicationOverride, "worker_coordinator", "assertions", args.enable_assertions) cfg.add(config.Scope.applicationOverride, "worker_coordinator", "on.error", args.on_error) - cfg.add( - config.Scope.applicationOverride, - "worker_coordinator", - "worker_ips", - opts.csv_to_list(args.worker_ips)) + cfg.add(config.Scope.applicationOverride, "worker_coordinator", "worker_ips", opts.csv_to_list(args.worker_ips)) cfg.add(config.Scope.applicationOverride, "workload", "test.mode.enabled", args.test_mode) cfg.add(config.Scope.applicationOverride, "workload", "load.test.clients", int(args.load_test_qps)) if args.redline_test: @@ -1180,18 +876,20 @@ def configure_test(arg_parser, args, cfg): configure_builder_params(args, cfg) cfg.add(config.Scope.applicationOverride, "builder", "runtime.jdk", args.runtime_jdk) cfg.add(config.Scope.applicationOverride, "builder", "source.revision", args.revision) -# cfg.add(config.Scope.applicationOverride, "builder", -# "cluster_config_instance.plugins", opts.csv_to_list( -# args.opensearch_plugins)) + # cfg.add(config.Scope.applicationOverride, "builder", + # "cluster_config_instance.plugins", opts.csv_to_list( + # args.opensearch_plugins)) cfg.add(config.Scope.applicationOverride, "builder", "plugin.params", opts.to_dict(args.plugin_params)) cfg.add(config.Scope.applicationOverride, "builder", "preserve.install", convert.to_bool(args.preserve_install)) cfg.add(config.Scope.applicationOverride, "builder", "skip.rest.api.check", convert.to_bool(args.skip_rest_api_check)) configure_reporting_params(args, cfg) + def print_test_run_id(args): console.info(f"[Test Run ID]: {args.test_run_id}") + def dispatch_sub_command(arg_parser, args, cfg): sub_command = args.subcommand @@ -1226,9 +924,9 @@ def dispatch_sub_command(arg_parser, args, cfg): cfg.add(config.Scope.applicationOverride, "builder", "node.name", args.node_name) cfg.add(config.Scope.applicationOverride, "builder", "master.nodes", opts.csv_to_list(args.master_nodes)) cfg.add(config.Scope.applicationOverride, "builder", "seed.hosts", opts.csv_to_list(args.seed_hosts)) -# cfg.add(config.Scope.applicationOverride, "builder", -# "cluster_config.plugins", opts.csv_to_list( -# args.opensearch_plugins)) + # cfg.add(config.Scope.applicationOverride, "builder", + # "cluster_config.plugins", opts.csv_to_list( + # args.opensearch_plugins)) cfg.add(config.Scope.applicationOverride, "builder", "plugin.params", opts.to_dict(args.plugin_params)) configure_builder_params(args, cfg) builder.install(cfg) @@ -1255,7 +953,7 @@ def dispatch_sub_command(arg_parser, args, cfg): test_runs.append(args.test_run_id) args.test_run_id = str(uuid.uuid4()) except Exception as e: - console.error(f"Error occurred during test run {_+1}: {str(e)}") + console.error(f"Error occurred during test run {_ + 1}: {str(e)}") if args.cancel_on_error: console.info("Cancelling remaining test runs.") break @@ -1299,15 +997,13 @@ def dispatch_sub_command(arg_parser, args, cfg): dispatch_visualize(cfg) elif sub_command == "convert-workload": from solrorbit.conversion import workload_converter + source_dir = os.path.abspath(args.workload_path) output_dir = os.path.abspath(args.output_path) if args.output_path else source_dir.rstrip("/") + "-solr" force = getattr(args, "force", False) if workload_converter.is_already_converted(output_dir) and not force: - console.info( - f"Workload already converted at: {output_dir}\n" - "Use --force to overwrite." - ) + console.info(f"Workload already converted at: {output_dir}\nUse --force to overwrite.") return True console.info(f"Converting workload: {source_dir} → {output_dir}") @@ -1404,8 +1100,7 @@ def main(): if not args.offline: probing_url = cfg.opts("system", "probing.url", default_value="https://github.com", mandatory=False) if not net.has_internet_connection(probing_url): - console.warn("No Internet connection detected. Automatic download of workload data sets etc. is disabled.", - logger=logger) + console.warn("No Internet connection detected. Automatic download of workload data sets etc. is disabled.", logger=logger) cfg.add(config.Scope.applicationOverride, "system", "offline.mode", True) else: logger.info("Detected a working Internet connection.") diff --git a/solrorbit/benchmarkd.py b/solrorbit/benchmarkd.py index d2b9abf8..45310d4d 100644 --- a/solrorbit/benchmarkd.py +++ b/solrorbit/benchmarkd.py @@ -16,7 +16,7 @@ # not use this file except in compliance with the License. # You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an @@ -83,30 +83,22 @@ def main(): log.configure_logging() console.init(assume_tty=False) - parser = argparse.ArgumentParser(prog=PROGRAM_NAME, - description=BANNER + "\n\n Solr Orbit daemon to support remote benchmarks", - epilog="Find out more about Solr Orbit at {}".format(console.format.link(doc_link())), - formatter_class=argparse.RawDescriptionHelpFormatter) - parser.add_argument('--version', action='version', version="%(prog)s " + version.version()) + parser = argparse.ArgumentParser( + prog=PROGRAM_NAME, + description=BANNER + "\n\n Solr Orbit daemon to support remote benchmarks", + epilog="Find out more about Solr Orbit at {}".format(console.format.link(doc_link())), + formatter_class=argparse.RawDescriptionHelpFormatter, + ) + parser.add_argument("--version", action="version", version="%(prog)s " + version.version()) - subparsers = parser.add_subparsers( - title="subcommands", - dest="subcommand", - help="") + subparsers = parser.add_subparsers(title="subcommands", dest="subcommand", help="") subparsers.required = True start_command = subparsers.add_parser("start", help="Starts the Solr Orbit daemon") restart_command = subparsers.add_parser("restart", help="Restarts the Solr Orbit daemon") for p in [start_command, restart_command]: - p.add_argument( - "--node-ip", - required=True, - help="The IP of this node.") - p.add_argument( - "--coordinator-ip", - required=True, - help="The IP of the coordinator node." - ) + p.add_argument("--node-ip", required=True, help="The IP of this node.") + p.add_argument("--coordinator-ip", required=True, help="The IP of the coordinator node.") subparsers.add_parser("stop", help="Stops the Solr Orbit daemon") subparsers.add_parser("status", help="Shows the current status of the local Solr Orbit daemon") @@ -125,5 +117,5 @@ def main(): raise exceptions.BenchmarkError("Unknown subcommand [%s]" % args.subcommand) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/solrorbit/builder/__init__.py b/solrorbit/builder/__init__.py index 8e1a69cc..40e86e05 100644 --- a/solrorbit/builder/__init__.py +++ b/solrorbit/builder/__init__.py @@ -16,7 +16,7 @@ # not use this file except in compliance with the License. # You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an @@ -26,5 +26,4 @@ # under the License. # expose only the minimum API -from .builder import StartEngine, EngineStarted, StopEngine, EngineStopped, ResetRelativeTime, BuilderActor, \ - cluster_distribution_version, download, install, start, stop +from .builder import StartEngine, EngineStarted, StopEngine, EngineStopped, ResetRelativeTime, BuilderActor, cluster_distribution_version, download, install, start, stop diff --git a/solrorbit/builder/builder.py b/solrorbit/builder/builder.py index 160caa5f..97713dc0 100644 --- a/solrorbit/builder/builder.py +++ b/solrorbit/builder/builder.py @@ -16,7 +16,7 @@ # not use this file except in compliance with the License. # You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an @@ -40,10 +40,9 @@ class NotFoundError(Exception): pass -from solrorbit import (PROGRAM_NAME, actor, client, config, exceptions, - metrics, paths) -from solrorbit.builder import (launcher, provisioner, - supplier) + +from solrorbit import PROGRAM_NAME, actor, client, config, exceptions, metrics, paths +from solrorbit.builder import launcher, provisioner, supplier from solrorbit.builder import cluster_config as cc from solrorbit.utils import console, net @@ -74,25 +73,19 @@ def install(cfg): # Ensure node_name and master_nodes match, using node_name as the default if node_name not in master_nodes: - print( - f"The provided --node-name '{node_name}' and --master-nodes '{master_nodes}' are different. " - f"Using '{node_name}' for both node name and initial master node." - ) + print(f"The provided --node-name '{node_name}' and --master-nodes '{master_nodes}' are different. Using '{node_name}' for both node name and initial master node.") master_nodes = [node_name] if build_type == "tar": binary_supplier = supplier.create(cfg, sources, cluster_config) - p = provisioner.local(cfg=cfg, cluster_config=cluster_config, ip=ip, http_port=http_port, - all_node_ips=seed_hosts, all_node_names=master_nodes, target_root=root_path, - node_name=node_name) + p = provisioner.local( + cfg=cfg, cluster_config=cluster_config, ip=ip, http_port=http_port, all_node_ips=seed_hosts, all_node_names=master_nodes, target_root=root_path, node_name=node_name + ) node_config = p.prepare(binary=binary_supplier()) elif build_type == "docker": if len(plugins) > 0: - raise exceptions.SystemSetupError("You cannot specify any plugins for Docker clusters. Please remove " - "\"--plugins\" and try again.") - p = provisioner.docker( - cfg=cfg, cluster_config=cluster_config, - ip=ip, http_port=http_port, target_root=root_path, node_name=node_name) + raise exceptions.SystemSetupError('You cannot specify any plugins for Docker clusters. Please remove "--plugins" and try again.') + p = provisioner.docker(cfg=cfg, cluster_config=cluster_config, ip=ip, http_port=http_port, target_root=root_path, node_name=node_name) # there is no binary for Docker that can be downloaded / built upfront node_config = p.prepare(binary=None) else: @@ -109,8 +102,9 @@ def start(cfg): with contextlib.suppress(FileNotFoundError): _load_node_file(root_path) install_id = cfg.opts("system", "install.id") - raise exceptions.SystemSetupError("A node with this installation id is already running. Please stop it first " - "with {} stop --installation-id={}".format(PROGRAM_NAME, install_id)) + raise exceptions.SystemSetupError( + "A node with this installation id is already running. Please stop it first with {} stop --installation-id={}".format(PROGRAM_NAME, install_id) + ) node_config = provisioner.load_node_configuration(root_path) @@ -146,7 +140,7 @@ def stop(cfg): test_run_id=current_test_run.test_run_id, test_run_timestamp=current_test_run.test_run_timestamp, workload_name=current_test_run.workload_name, - test_procedure_name=current_test_run.test_procedure_name + test_procedure_name=current_test_run.test_procedure_name, ) except exceptions.NotFound: logging.getLogger(__name__).info("Could not find test_run [%s] and will thus not persist system metrics.", test_run_id) @@ -167,9 +161,7 @@ def stop(cfg): metrics_store.close() # TODO: Do we need to expose this as a separate command as well? - provisioner.cleanup(preserve=cfg.opts("builder", "preserve.install"), - install_dir=node_config.binary_path, - data_paths=node_config.data_paths) + provisioner.cleanup(preserve=cfg.opts("builder", "preserve.install"), install_dir=node_config.binary_path, data_paths=node_config.data_paths) def _load_node_file(root_path): @@ -190,9 +182,9 @@ def _delete_node_file(root_path): # Public Messages ############################## + class StartEngine: - def __init__(self, cfg, open_metrics_context, sources, distribution, external, docker, ip=None, port=None, - node_id=None): + def __init__(self, cfg, open_metrics_context, sources, distribution, external, docker, ip=None, port=None, node_id=None): self.cfg = cfg self.open_metrics_context = open_metrics_context self.sources = sources @@ -215,8 +207,7 @@ def for_nodes(self, all_node_ips=None, all_node_ids=None, ip=None, port=None, no :param node_ids: A list of node id to set. :return: A corresponding ``StartNodes`` message with the specified IP, port number and node ids. """ - return StartNodes(self.cfg, self.open_metrics_context, self.sources, self.distribution, - self.external, self.docker, all_node_ips, all_node_ids, ip, port, node_ids) + return StartNodes(self.cfg, self.open_metrics_context, self.sources, self.distribution, self.external, self.docker, all_node_ips, all_node_ids, ip, port, node_ids) class EngineStarted: @@ -241,9 +232,9 @@ def __init__(self, reset_in_seconds): # Builder internal messages ############################## + class StartNodes: - def __init__(self, cfg, open_metrics_context, sources, distribution, external, docker, - all_node_ips, all_node_ids, ip, port, node_ids): + def __init__(self, cfg, open_metrics_context, sources, distribution, external, docker, all_node_ips, all_node_ids, ip, port, node_ids): self.cfg = cfg self.open_metrics_context = open_metrics_context self.sources = sources @@ -293,9 +284,9 @@ def to_ip_port(hosts): host_or_ip = host.pop("host") port = host.pop("port", 8983) if host: - raise exceptions.SystemSetupError("When specifying nodes to be managed by solr-orbit you can only supply " - "hostname:port pairs (e.g. 'localhost:8983'), any additional options cannot " - "be supported.") + raise exceptions.SystemSetupError( + "When specifying nodes to be managed by solr-orbit you can only supply hostname:port pairs (e.g. 'localhost:8983'), any additional options cannot be supported." + ) ip = net.resolve(host_or_ip) ip_port_pairs.append((ip, port)) return ip_port_pairs @@ -453,17 +444,17 @@ def on_all_nodes_stopped(self): # do not self-terminate, let the parent actor handle this -@thespian.actors.requireCapability('coordinator') +@thespian.actors.requireCapability("coordinator") class Dispatcher(actor.BenchmarkActor): """This Actor receives a copy of the startmsg (with the computed hosts - attached) and creates a NodeBuilderActor on each targeted - remote host. It uses Thespian SystemRegistration to get - notification of when remote nodes are available. As a special - case, if an IP address is localhost, the NodeBuilderActor is - immediately created locally. Once All NodeBuilderActors are - started, it will send them all their startup message, with a - reply-to back to the actor that made the request of the - Dispatcher. + attached) and creates a NodeBuilderActor on each targeted + remote host. It uses Thespian SystemRegistration to get + notification of when remote nodes are available. As a special + case, if an IP address is localhost, the NodeBuilderActor is + immediately created locally. Once All NodeBuilderActors are + started, it will send them all their startup message, with a + reply-to back to the actor that made the request of the + Dispatcher. """ def __init__(self): @@ -485,9 +476,8 @@ def receiveMsg_StartEngine(self, startmsg, sender): for (ip, port), node in all_nodes_by_host.items(): submsg = startmsg.for_nodes(all_node_ips, all_node_ids, ip, port, node) submsg.reply_to = sender - if ip == '127.0.0.1': - m = self.createActor(NodeBuilderActor, - targetActorRequirements={"coordinator": True}) + if ip == "127.0.0.1": + m = self.createActor(NodeBuilderActor, targetActorRequirements={"coordinator": True}) self.pending.append((m, submsg)) else: self.remotes[ip].append(submsg) @@ -505,16 +495,13 @@ def receiveMsg_StartEngine(self, startmsg, sender): def receiveMsg_ActorSystemConventionUpdate(self, convmsg, sender): if not convmsg.remoteAdded: self.logger.warning("Remote Solr Orbit node [%s] exited during NodeBuilderActor startup process.", convmsg.remoteAdminAddress) - self.start_sender(actor.BenchmarkFailure( - "Remote Solr Orbit node [%s] has been shutdown prematurely." % convmsg.remoteAdminAddress)) + self.start_sender(actor.BenchmarkFailure("Remote Solr Orbit node [%s] has been shutdown prematurely." % convmsg.remoteAdminAddress)) else: - remote_ip = convmsg.remoteCapabilities.get('ip', None) + remote_ip = convmsg.remoteCapabilities.get("ip", None) self.logger.info("Remote Solr Orbit node [%s] has started.", remote_ip) for eachmsg in self.remotes[remote_ip]: - self.pending.append((self.createActor(NodeBuilderActor, - targetActorRequirements={"ip": remote_ip}), - eachmsg)) + self.pending.append((self.createActor(NodeBuilderActor, targetActorRequirements={"ip": remote_ip}), eachmsg)) if remote_ip in self.remotes: del self.remotes[remote_ip] if not self.remotes: @@ -559,13 +546,19 @@ def receiveMsg_StartNodes(self, msg, sender): self.logger.info("Starting node(s) %s on [%s].", msg.node_ids, msg.ip) # Load node-specific configuration - cfg = config.auto_load_local_config(msg.cfg, additional_sections=[ - # only copy the relevant bits - "workload", "builder", "client", "telemetry", - # allow metrics store to extract test_run meta-data - "test_run", - "source" - ]) + cfg = config.auto_load_local_config( + msg.cfg, + additional_sections=[ + # only copy the relevant bits + "workload", + "builder", + "client", + "telemetry", + # allow metrics store to extract test_run meta-data + "test_run", + "source", + ], + ) # set root path (normally done by the main entry point) cfg.add(config.Scope.application, "node", "benchmark.root", paths.benchmark_root()) if not msg.external: @@ -576,8 +569,7 @@ def receiveMsg_StartNodes(self, msg, sender): metrics_store.open(ctx=msg.open_metrics_context) # avoid follow-up errors in case we receive an unexpected ActorExitRequest due to an early failure in a parent actor. - self.builder = create(cfg, metrics_store, msg.ip, msg.port, msg.all_node_ips, msg.all_node_ids, - msg.sources, msg.distribution, msg.external, msg.docker) + self.builder = create(cfg, metrics_store, msg.ip, msg.port, msg.all_node_ips, msg.all_node_ids, msg.sources, msg.distribution, msg.external, msg.docker) self.builder.start_engine() self.wakeupAfter(METRIC_FLUSH_INTERVAL_SECONDS) self.send(getattr(msg, "reply_to", sender), NodesStarted()) @@ -622,6 +614,7 @@ def receiveUnrecognizedMessage(self, msg, sender): # Internal API (only used by the actor and for tests) ##################################################### + def load_cluster_config(cfg, external): # externally provisioned clusters do not support cluster_configs / plugins if external: @@ -629,20 +622,14 @@ def load_cluster_config(cfg, external): plugins = [] else: cluster_config_path = cc.cluster_config_path(cfg) - cluster_config = cc.load_cluster_config( - cluster_config_path, - cfg.opts("builder", "cluster_config.names"), - cfg.opts("builder", "cluster_config.params")) - plugins = cc.load_plugins(cluster_config_path, - cfg.opts("builder", "cluster_config.plugins", mandatory=False), - cfg.opts("builder", "plugin.params", mandatory=False)) + cluster_config = cc.load_cluster_config(cluster_config_path, cfg.opts("builder", "cluster_config.names"), cfg.opts("builder", "cluster_config.params")) + plugins = cc.load_plugins(cluster_config_path, cfg.opts("builder", "cluster_config.plugins", mandatory=False), cfg.opts("builder", "plugin.params", mandatory=False)) # Store cluster_config_instance in config for TestRun to access (for result metadata) cfg.add(config.Scope.applicationOverride, "builder", "cluster_config.instance", cluster_config) return cluster_config, plugins -def create(cfg, metrics_store, node_ip, node_http_port, all_node_ips, all_node_ids, sources=False, distribution=False, - external=False, docker=False): +def create(cfg, metrics_store, node_ip, node_http_port, all_node_ips, all_node_ids, sources=False, distribution=False, external=False, docker=False): test_run_root_path = paths.test_run_root(cfg) node_ids = cfg.opts("provisioning", "node.ids", mandatory=False) node_name_prefix = cfg.opts("provisioning", "node.name.prefix") @@ -654,16 +641,13 @@ def create(cfg, metrics_store, node_ip, node_http_port, all_node_ips, all_node_i all_node_names = ["%s-%s" % (node_name_prefix, n) for n in all_node_ids] for node_id in node_ids: node_name = "%s-%s" % (node_name_prefix, node_id) - p.append( - provisioner.local(cfg, cluster_config, node_ip, node_http_port, all_node_ips, - all_node_names, test_run_root_path, node_name)) + p.append(provisioner.local(cfg, cluster_config, node_ip, node_http_port, all_node_ips, all_node_names, test_run_root_path, node_name)) l = launcher.ProcessLauncher(cfg) elif external: raise exceptions.BenchmarkAssertionError("Externally provisioned clusters should not need to be managed by Solr Orbit's builder") elif docker: if len(plugins) > 0: - raise exceptions.SystemSetupError("You cannot specify any plugins for Docker clusters. Please remove " - "\"--plugin-params\" and try again.") + raise exceptions.SystemSetupError('You cannot specify any plugins for Docker clusters. Please remove "--plugin-params" and try again.') s = lambda: None p = [] for node_id in node_ids: @@ -724,9 +708,7 @@ def stop_engine(self): self.metrics_store.close() self.nodes = [] for node_config in self.node_configs: - provisioner.cleanup(preserve=self.preserve_install, - install_dir=node_config.binary_path, - data_paths=node_config.data_paths) + provisioner.cleanup(preserve=self.preserve_install, install_dir=node_config.binary_path, data_paths=node_config.data_paths) self.node_configs = [] def _current_test_run(self): diff --git a/solrorbit/builder/cluster.py b/solrorbit/builder/cluster.py index 22802db9..9cfdca22 100644 --- a/solrorbit/builder/cluster.py +++ b/solrorbit/builder/cluster.py @@ -16,7 +16,7 @@ # not use this file except in compliance with the License. # You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an diff --git a/solrorbit/builder/cluster_builder.py b/solrorbit/builder/cluster_builder.py index abdfdb09..cac9825e 100644 --- a/solrorbit/builder/cluster_builder.py +++ b/solrorbit/builder/cluster_builder.py @@ -2,6 +2,8 @@ The ClusterBuilder is the interface into the builder system from the Dispatcher. This class orchestrates all of the builder subcomponents used to create and delete a cluster. """ + + class ClusterBuilder: def __init__(self, provisioner, downloader, installer, launcher): self.provisioner = provisioner diff --git a/solrorbit/builder/cluster_config.py b/solrorbit/builder/cluster_config.py index 88d6bdfd..9c730638 100644 --- a/solrorbit/builder/cluster_config.py +++ b/solrorbit/builder/cluster_config.py @@ -16,7 +16,7 @@ # not use this file except in compliance with the License. # You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an @@ -56,9 +56,7 @@ def list_cluster_configs(cfg): # idiomatic way according to https://docs.python.org/3/howto/sorting.html#sort-stability-and-complex-sorts cluster_configs = sorted(sorted(cluster_configs, key=lambda c: c.name), key=lambda c: c.type) console.println("Available cluster-configs:\n") - console.println(tabulate.tabulate( - [[c.name, c.type, c.description] for c in cluster_configs], - headers=["Name", "Type", "Description"])) + console.println(tabulate.tabulate([[c.name, c.type, c.description] for c in cluster_configs], headers=["Name", "Type", "Description"])) def load_cluster_config(repo, name, cluster_config_params=None): @@ -85,8 +83,7 @@ def __init__(self, root_path, entry_point): root_path = p # multiple cluster_configs are based on the same hook elif root_path != p: - raise exceptions.SystemSetupError( - "Invalid cluster_config: {}. Multiple bootstrap hooks are forbidden.".format(name)) + raise exceptions.SystemSetupError("Invalid cluster_config: {}. Multiple bootstrap hooks are forbidden.".format(name)) all_config_base_vars.update(descriptor.config_base_variables) all_cluster_config_vars.update(descriptor.variables) @@ -136,9 +133,7 @@ def cluster_config_path(cfg): cluster_config_repositories = cfg.opts("builder", "cluster_config.repository.dir") cluster_configs_dir = os.path.join(root, cluster_config_repositories) - current_cluster_config_repo = repo.BenchmarkRepository( - default_directory, cluster_configs_dir, - repo_name, "cluster_configs", offline) + current_cluster_config_repo = repo.BenchmarkRepository(default_directory, cluster_configs_dir, repo_name, "cluster_configs", offline) current_cluster_config_repo.set_cluster_configs_dir(repo_revision, distribution_version, cfg) return current_cluster_config_repo.repo_dir @@ -157,9 +152,8 @@ def __cluster_config_name(path): def __is_cluster_config(path): _, extension = io.splitext(path) return extension == ".ini" - return map(__cluster_config_name, filter( - __is_cluster_config, - os.listdir(self.cluster_configs_dir))) + + return map(__cluster_config_name, filter(__is_cluster_config, os.listdir(self.cluster_configs_dir))) def _cluster_config_file(self, name): return os.path.join(self.cluster_configs_dir, "{}.ini".format(name)) @@ -167,9 +161,7 @@ def _cluster_config_file(self, name): def load_cluster_config(self, name, cluster_config_params=None): cluster_config_config_file = self._cluster_config_file(name) if not io.exists(cluster_config_config_file): - raise exceptions.SystemSetupError( - "Unknown cluster-config [{}]. List the available " - "cluster-configs with {} list cluster-configs.".format(name, PROGRAM_NAME)) + raise exceptions.SystemSetupError("Unknown cluster-config [{}]. List the available cluster-configs with {} list cluster-configs.".format(name, PROGRAM_NAME)) config = self._config_loader(cluster_config_config_file) root_paths = [] config_paths = [] @@ -195,9 +187,7 @@ def load_cluster_config(self, name, cluster_config_params=None): if cluster_config_params: variables.update(cluster_config_params) - return ClusterConfigInstanceDescriptor( - name, description, cluster_config_type, - root_paths, config_paths, config_base_vars, variables) + return ClusterConfigInstanceDescriptor(name, description, cluster_config_type, root_paths, config_paths, config_base_vars, variables) def _config_loader(self, file_name): config = configparser.ConfigParser(interpolation=configparser.ExtendedInterpolation()) @@ -244,8 +234,7 @@ class ClusterConfigInstance: # name of the initial Python file to load for cluster_configs. entry_point = "config" - def __init__(self, names, root_path, config_paths, provider=ClusterInfraProvider.LOCAL, - flavor=ClusterFlavor.SELF_MANAGED, variables=None): + def __init__(self, names, root_path, config_paths, provider=ClusterInfraProvider.LOCAL, flavor=ClusterFlavor.SELF_MANAGED, variables=None): """ Creates new settings for a benchmark candidate. @@ -272,7 +261,7 @@ def mandatory_var(self, name): try: return self.variables[name] except KeyError: - raise exceptions.SystemSetupError("ClusterConfigInstance \"{}\" requires config key \"{}\"".format(self.name, name)) + raise exceptions.SystemSetupError('ClusterConfigInstance "{}" requires config key "{}"'.format(self.name, name)) @property def name(self): @@ -359,18 +348,17 @@ def load_plugin(self, name, config_names, plugin_params=None): if not config_names: # maybe we only have a config folder but nothing else (e.g. if there is only an install hook) if io.exists(root_path): - return PluginDescriptor(name=name, - core_plugin=core_plugin is not None, - config=config_names, - root_path=root_path, - variables=plugin_params) + return PluginDescriptor(name=name, core_plugin=core_plugin is not None, config=config_names, root_path=root_path, variables=plugin_params) else: if core_plugin: return core_plugin # If we just have a plugin name then we assume that this is a community plugin and the user has specified a download URL else: - self.logger.info("The plugin [%s] is neither a configured nor an official plugin. Assuming that this is a community " - "plugin not requiring any configuration and you have set a proper download URL.", name) + self.logger.info( + "The plugin [%s] is neither a configured nor an official plugin. Assuming that this is a community " + "plugin not requiring any configuration and you have set a proper download URL.", + name, + ) return PluginDescriptor(name, variables=plugin_params) else: variables = {} @@ -383,12 +371,15 @@ def load_plugin(self, name, config_names, plugin_params=None): # Do we have an explicit configuration for this plugin? if not io.exists(config_file): if core_plugin: - raise exceptions.SystemSetupError("Plugin [%s] does not provide configuration [%s]. List the available plugins " - "and configurations with %s list cluster-configs " - "--distribution-version=VERSION." % (name, config_name, PROGRAM_NAME)) + raise exceptions.SystemSetupError( + "Plugin [%s] does not provide configuration [%s]. List the available plugins " + "and configurations with %s list cluster-configs " + "--distribution-version=VERSION." % (name, config_name, PROGRAM_NAME) + ) else: - raise exceptions.SystemSetupError("Unknown plugin [%s]. List the available plugins with %s list " - "cluster-configs --distribution-version=VERSION." % (name, PROGRAM_NAME)) + raise exceptions.SystemSetupError( + "Unknown plugin [%s]. List the available plugins with %s list cluster-configs --distribution-version=VERSION." % (name, PROGRAM_NAME) + ) config = configparser.ConfigParser(interpolation=configparser.ExtendedInterpolation()) # Do not modify the case of option keys but read them as is @@ -411,8 +402,7 @@ def load_plugin(self, name, config_names, plugin_params=None): # maybe one of the configs is really just for providing variables. However, we still require one config base overall. if len(config_paths) == 0: raise exceptions.SystemSetupError("At least one config base is required for plugin [%s]" % name) - return PluginDescriptor(name=name, core_plugin=core_plugin is not None, config=config_names, root_path=root_path, - config_paths=config_paths, variables=variables) + return PluginDescriptor(name=name, core_plugin=core_plugin is not None, config=config_names, root_path=root_path, config_paths=config_paths, variables=variables) class PluginDescriptor: @@ -466,6 +456,7 @@ class BootstrapHookHandler: """ Responsible for loading and executing component-specific intitialization code. """ + def __init__(self, component, loader_class=modules.ComponentLoader): """ Creates a new BootstrapHookHandler. @@ -512,5 +503,4 @@ def invoke(self, phase, **kwargs): # hooks should only take keyword arguments to be forwards compatible with OSB! hook(config_names=self.component.config, **kwargs) else: - self.logger.debug("Component [%s] in config [%s] has no hook registered for phase [%s].", - self.component.name, self.component.config, phase) + self.logger.debug("Component [%s] in config [%s] has no hook registered for phase [%s].", self.component.name, self.component.config, phase) diff --git a/solrorbit/builder/configs/listers/plugin_config_instance_lister.py b/solrorbit/builder/configs/listers/plugin_config_instance_lister.py index b11a50d6..ed6ad5b5 100644 --- a/solrorbit/builder/configs/listers/plugin_config_instance_lister.py +++ b/solrorbit/builder/configs/listers/plugin_config_instance_lister.py @@ -14,15 +14,19 @@ def __init__(self, config_path_resolver): def list_plugin_config_instances(self): plugin_config_instances = [] for config_format_version in ConfigInstanceTypes.PLUGIN.supported_config_format_versions: - plugins_root_directory = self.config_path_resolver.resolve_config_path(ConfigInstanceTypes.PLUGIN.config_type, - config_format_version) + plugins_root_directory = self.config_path_resolver.resolve_config_path(ConfigInstanceTypes.PLUGIN.config_type, config_format_version) plugin_config_instances += self._list_core_plugins(plugins_root_directory, config_format_version) plugin_config_instances += self._list_configured_plugins(plugins_root_directory, config_format_version) - return sorted(plugin_config_instances, key=lambda plugin_config_instance: ( - plugin_config_instance.format_version, plugin_config_instance.name, - plugin_config_instance.config_names[0] if plugin_config_instance.config_names else None)) + return sorted( + plugin_config_instances, + key=lambda plugin_config_instance: ( + plugin_config_instance.format_version, + plugin_config_instance.name, + plugin_config_instance.config_names[0] if plugin_config_instance.config_names else None, + ), + ) def _list_core_plugins(self, plugins_root_directory, config_format_version): core_plugins_path = os.path.join(plugins_root_directory, "core-plugins.txt") @@ -31,9 +35,11 @@ def _list_core_plugins(self, plugins_root_directory, config_format_version): def _parse_core_plugins(self, core_plugins_path, config_format_version): with open(core_plugins_path, mode="rt", encoding="utf-8") as core_plugins_file: - return [PluginConfigInstance(name=line.strip().split(",")[0], - format_version=f"v{config_format_version}", - is_core_plugin=True) for line in core_plugins_file if not line.startswith("#")] + return [ + PluginConfigInstance(name=line.strip().split(",")[0], format_version=f"v{config_format_version}", is_core_plugin=True) + for line in core_plugins_file + if not line.startswith("#") + ] def _list_configured_plugins(self, plugins_root_directory, config_format_version): configured_plugins = [] @@ -47,8 +53,11 @@ def _list_configured_plugins(self, plugins_root_directory, config_format_version return configured_plugins def _parse_plugins_in_directory(self, plugin_path, plugin_directory, config_format_version): - return [self._parse_plugin_from_config_file(plugin_config_file, plugin_directory, config_format_version) - for plugin_config_file in os.listdir(plugin_path) if self._is_config_file(plugin_path, plugin_config_file)] + return [ + self._parse_plugin_from_config_file(plugin_config_file, plugin_directory, config_format_version) + for plugin_config_file in os.listdir(plugin_path) + if self._is_config_file(plugin_path, plugin_config_file) + ] def _is_config_file(self, plugin_path, plugin_config_file): return os.path.isfile(os.path.join(plugin_path, plugin_config_file)) and io.has_extension(plugin_config_file, ".ini") @@ -58,9 +67,7 @@ def _parse_plugin_from_config_file(self, plugin_config_file, plugin_directory, c plugin_name = self._file_to_plugin_name(plugin_directory) config_name = io.basename(file_name) - return PluginConfigInstance(name=plugin_name, - format_version=f"v{config_format_version}", - config_names=[config_name]) + return PluginConfigInstance(name=plugin_name, format_version=f"v{config_format_version}", config_names=[config_name]) def _file_to_plugin_name(self, file_name): return file_name.replace("_", "-") diff --git a/solrorbit/builder/configs/utils/config_path_resolver.py b/solrorbit/builder/configs/utils/config_path_resolver.py index 608c9bd6..f77338ac 100644 --- a/solrorbit/builder/configs/utils/config_path_resolver.py +++ b/solrorbit/builder/configs/utils/config_path_resolver.py @@ -30,9 +30,7 @@ def _get_config_root_path(self): cluster_config_repositories = self.cfg.opts("builder", "cluster_config.repository.dir") cluster_configs_dir = os.path.join(root, cluster_config_repositories) - current_cluster_config_repo = BenchmarkRepository( - default_directory, cluster_configs_dir, - repo_name, "cluster_configs", offline) + current_cluster_config_repo = BenchmarkRepository(default_directory, cluster_configs_dir, repo_name, "cluster_configs", offline) current_cluster_config_repo.set_cluster_configs_dir(repo_revision, distribution_version, self.cfg) return current_cluster_config_repo.repo_dir diff --git a/solrorbit/builder/downloaders/distribution_downloader.py b/solrorbit/builder/downloaders/distribution_downloader.py index ac4d0940..25349b1d 100644 --- a/solrorbit/builder/downloaders/distribution_downloader.py +++ b/solrorbit/builder/downloaders/distribution_downloader.py @@ -44,8 +44,7 @@ def _fetch_binary(self, host): is_cache_enabled = self.distribution_repository_provider.is_cache_enabled() if is_binary_present and is_cache_enabled: - self.logger.info("Skipping download for version [%s]. Found existing binary at [%s].", version, - distribution_path) + self.logger.info("Skipping download for version [%s]. Found existing binary at [%s].", version, distribution_path) else: self._download(host, distribution_path, download_url, version) @@ -72,8 +71,7 @@ def _download(self, host, distribution_path, download_url, version): try: self.executor.execute(host, f"curl -o {distribution_path} {download_url}") except ExecutorError as e: - self.logger.exception("Exception downloading distribution for version [%s] from [%s].", - version, download_url) + self.logger.exception("Exception downloading distribution for version [%s] from [%s].", version, download_url) raise e self.logger.info("Successfully downloaded distribution [%s].", version) diff --git a/solrorbit/builder/downloaders/downloader.py b/solrorbit/builder/downloaders/downloader.py index 731b8010..225623ce 100644 --- a/solrorbit/builder/downloaders/downloader.py +++ b/solrorbit/builder/downloaders/downloader.py @@ -16,7 +16,7 @@ # not use this file except in compliance with the License. # You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an @@ -33,6 +33,7 @@ class Downloader(ABC): A downloader is used to supply the necessary components for running self-managed Solr. Implementations of this interface will download distributions or fetch from a source repository for Solr """ + def __init__(self, executor): self.executor = executor diff --git a/solrorbit/builder/downloaders/repositories/distribution_repository_provider.py b/solrorbit/builder/downloaders/repositories/distribution_repository_provider.py index 62402630..ac3ed2f3 100644 --- a/solrorbit/builder/downloaders/repositories/distribution_repository_provider.py +++ b/solrorbit/builder/downloaders/repositories/distribution_repository_provider.py @@ -33,7 +33,7 @@ def get_download_url(self, host): return self.repository_url_provider.render_url_for_key(host, self.cluster_config.variables, url_key) def get_file_name_from_download_url(self, download_url): - return download_url[download_url.rfind("/") + 1:] + return download_url[download_url.rfind("/") + 1 :] def is_cache_enabled(self): distribution_repository = self.cluster_config.variables["distribution"]["repository"] diff --git a/solrorbit/builder/downloaders/repositories/source_repository_provider.py b/solrorbit/builder/downloaders/repositories/source_repository_provider.py index 72a1fd9a..0f9f9753 100644 --- a/solrorbit/builder/downloaders/repositories/source_repository_provider.py +++ b/solrorbit/builder/downloaders/repositories/source_repository_provider.py @@ -18,28 +18,15 @@ def __init__(self, executor, repository_name): self.update_scenarios = self._generate_update_repository_scenarios() def _generate_update_repository_scenarios(self): - return OrderedDict([ - ( - lambda revision, is_remote_defined: revision == "latest" and is_remote_defined, - self._update_repository_to_latest - ), - ( - lambda revision, is_remote_defined: revision == "current", - self._update_repository_to_current - ), - ( - lambda revision, is_remote_defined: revision.startswith("@") and is_remote_defined, - self._update_repository_to_timestamp - ), - ( - lambda revision, is_remote_defined: is_remote_defined, - self._update_repository_to_commit_hash - ), - ( - lambda revision, is_remote_defined: True, - self._update_repository_to_local_revision - ), - ]) + return OrderedDict( + [ + (lambda revision, is_remote_defined: revision == "latest" and is_remote_defined, self._update_repository_to_latest), + (lambda revision, is_remote_defined: revision == "current", self._update_repository_to_current), + (lambda revision, is_remote_defined: revision.startswith("@") and is_remote_defined, self._update_repository_to_timestamp), + (lambda revision, is_remote_defined: is_remote_defined, self._update_repository_to_commit_hash), + (lambda revision, is_remote_defined: True, self._update_repository_to_local_revision), + ] + ) def fetch_repository(self, host, remote_url, revision, target_dir): if not self.path_manager.is_path_present(host, os.path.join(target_dir, ".git")): @@ -83,8 +70,7 @@ def _update_repository_to_current(self, host, revision, target_dir): def _update_repository_to_timestamp(self, host, revision, target_dir): # convert timestamp annotated for OSB to something git understands -> we strip leading and trailing " and the @. git_timestamp_revision = revision[1:] - self.logger.info("Fetching from remote and checking out revision with timestamp [%s] for " - "%s.", git_timestamp_revision, self.repository_name) + self.logger.info("Fetching from remote and checking out revision with timestamp [%s] for %s.", git_timestamp_revision, self.repository_name) self.git_manager.fetch(host, target_dir) revision_from_timestamp = self.git_manager.get_revision_from_timestamp(host, target_dir, git_timestamp_revision) self.git_manager.checkout(host, target_dir, revision_from_timestamp) @@ -101,8 +87,7 @@ def _update_repository_to_local_revision(self, host, revision, target_dir): def _get_revision(self, host, revision, target_dir): if self.path_manager.is_path_present(host, os.path.join(target_dir, ".git")): git_revision = self.git_manager.get_revision_from_local_repository(host, target_dir) - self.logger.info("User-specified revision [%s] for [%s] results in git revision [%s]", - revision, self.repository_name, git_revision) + self.logger.info("User-specified revision [%s] for [%s] results in git revision [%s]", revision, self.repository_name, git_revision) return git_revision diff --git a/solrorbit/builder/downloaders/source_downloader.py b/solrorbit/builder/downloaders/source_downloader.py index df64ef10..4a641cd0 100644 --- a/solrorbit/builder/downloaders/source_downloader.py +++ b/solrorbit/builder/downloaders/source_downloader.py @@ -23,8 +23,7 @@ class SourceDownloader(Downloader): - def __init__(self, cluster_config, executor, source_repository_provider, binary_builder, template_renderer, - artifact_variables_provider): + def __init__(self, cluster_config, executor, source_repository_provider, binary_builder, template_renderer, artifact_variables_provider): super().__init__(executor) self.logger = logging.getLogger(__name__) self.cluster_config = cluster_config @@ -58,10 +57,13 @@ def _prepare(self, host, artifact_variables): build_command_template = self.cluster_config.variables["source"]["build"]["command"] if self.binary_builder: - self.binary_builder.build(host, [ - self.template_renderer.render_template_string(clean_command_template, artifact_variables), - self.template_renderer.render_template_string(build_command_template, artifact_variables) - ]) + self.binary_builder.build( + host, + [ + self.template_renderer.render_template_string(clean_command_template, artifact_variables), + self.template_renderer.render_template_string(build_command_template, artifact_variables), + ], + ) def _get_zip_path(self, source_path, artifact_variables): artifact_path_pattern_template = self.cluster_config.variables["source"]["artifact_path_pattern"] diff --git a/solrorbit/builder/executors/exception_handling_shell_executor.py b/solrorbit/builder/executors/exception_handling_shell_executor.py index 6a9135e6..cebb122c 100644 --- a/solrorbit/builder/executors/exception_handling_shell_executor.py +++ b/solrorbit/builder/executors/exception_handling_shell_executor.py @@ -10,4 +10,4 @@ def execute(self, host, command, **kwargs): try: return self.executor.execute(host, command, kwargs) except Exception as e: - raise ExecutorError(f"Command \"{command}\" on host \"{host}\" failed to execute", e) + raise ExecutorError(f'Command "{command}" on host "{host}" failed to execute', e) diff --git a/solrorbit/builder/executors/local_shell_executor.py b/solrorbit/builder/executors/local_shell_executor.py index c5fa3585..2157320a 100644 --- a/solrorbit/builder/executors/local_shell_executor.py +++ b/solrorbit/builder/executors/local_shell_executor.py @@ -12,4 +12,4 @@ def execute(self, host, command, output=False, stdout=subprocess.PIPE, stderr=su return process.run_subprocess_with_output(command) else: if process.run_subprocess_with_logging(command, stdout=stdout, stderr=stderr, env=env, detach=detach): - raise ExecutorError(f"Command: \"{command}\" returned a non-zero exit code") + raise ExecutorError(f'Command: "{command}" returned a non-zero exit code') diff --git a/solrorbit/builder/installers/bare_installer.py b/solrorbit/builder/installers/bare_installer.py index 7ffba4ed..ee399247 100644 --- a/solrorbit/builder/installers/bare_installer.py +++ b/solrorbit/builder/installers/bare_installer.py @@ -16,7 +16,7 @@ # not use this file except in compliance with the License. # You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an diff --git a/solrorbit/builder/installers/docker_installer.py b/solrorbit/builder/installers/docker_installer.py index cb92cd26..51eea672 100644 --- a/solrorbit/builder/installers/docker_installer.py +++ b/solrorbit/builder/installers/docker_installer.py @@ -16,7 +16,7 @@ # not use this file except in compliance with the License. # You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an @@ -64,15 +64,17 @@ def _create_node(self): node_log_dir = os.path.join(node_root_dir, "logs", "server") node_heap_dump_dir = os.path.join(node_root_dir, "heapdump") - return Node(name=node_name, - port=node_port, - pid=None, - root_dir=node_root_dir, - binary_path=node_binary_path, - log_path=node_log_dir, - heap_dump_path=node_heap_dump_dir, - data_paths=node_data_paths, - telemetry=None) + return Node( + name=node_name, + port=node_port, + pid=None, + root_dir=node_root_dir, + binary_path=node_binary_path, + log_path=node_log_dir, + heap_dump_path=node_heap_dump_dir, + data_paths=node_data_paths, + telemetry=None, + ) def _prepare_node(self, host, node): directories_to_create = [node.binary_path, node.log_path, node.heap_dump_path, node.data_paths[0]] @@ -106,7 +108,7 @@ def _get_config_vars(self, node): "discovery_type": "single-node", "http_port": str(node.port), "zookeeper_port": str(node.port + 1000), - "cluster_settings": {} + "cluster_settings": {}, } config_vars = {} @@ -123,7 +125,7 @@ def _get_docker_vars(self, node, mounts): "solr_data_dir": node.data_paths[0], "solr_log_dir": node.log_path, "solr_heap_dump_dir": node.heap_dump_path, - "mounts": mounts + "mounts": mounts, } self._add_if_defined_for_cluster_config(docker_vars, "docker_mem_limit") self._add_if_defined_for_cluster_config(docker_vars, "docker_cpu_count") diff --git a/solrorbit/builder/installers/exception_handling_installer.py b/solrorbit/builder/installers/exception_handling_installer.py index e2960d40..acd7839e 100644 --- a/solrorbit/builder/installers/exception_handling_installer.py +++ b/solrorbit/builder/installers/exception_handling_installer.py @@ -11,10 +11,10 @@ def install(self, host, binaries, all_node_ips): try: return self.installer.install(host, binaries, all_node_ips) except Exception as e: - raise InstallError(f"Installing node on host \"{host}\" failed", e) + raise InstallError(f'Installing node on host "{host}" failed', e) def cleanup(self, host): try: return self.installer.cleanup(host) except Exception as e: - raise InstallError(f"Cleaning up install data on host \"{host}\" failed", e) + raise InstallError(f'Cleaning up install data on host "{host}" failed', e) diff --git a/solrorbit/builder/installers/installer.py b/solrorbit/builder/installers/installer.py index 4cb5e116..24671920 100644 --- a/solrorbit/builder/installers/installer.py +++ b/solrorbit/builder/installers/installer.py @@ -16,7 +16,7 @@ # not use this file except in compliance with the License. # You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an diff --git a/solrorbit/builder/installers/preparers/solr_preparer.py b/solrorbit/builder/installers/preparers/solr_preparer.py index 52dfa5c2..76821cc8 100644 --- a/solrorbit/builder/installers/preparers/solr_preparer.py +++ b/solrorbit/builder/installers/preparers/solr_preparer.py @@ -51,15 +51,17 @@ def _create_node(self): node_log_dir = os.path.join(node_root_dir, "logs", "server") node_heap_dump_dir = os.path.join(node_root_dir, "heapdump") - return Node(name=node_name, - port=node_port, - pid=None, - root_dir=node_root_dir, - binary_path=node_binary_path, - log_path=node_log_dir, - heap_dump_path=node_heap_dump_dir, - data_paths=None, - telemetry=None) + return Node( + name=node_name, + port=node_port, + pid=None, + root_dir=node_root_dir, + binary_path=node_binary_path, + log_path=node_log_dir, + heap_dump_path=node_heap_dump_dir, + data_paths=None, + telemetry=None, + ) def _prepare_node(self, host, node, binary): self._prepare_directories(host, node) @@ -102,10 +104,10 @@ def get_config_vars(self, host, node, all_node_ips): "network_host": host.address, "http_port": str(node.port), "zookeeper_port": str(node.port + 1000), - "all_node_ips": "[\"%s\"]" % "\",\"".join(all_node_ips), + "all_node_ips": '["%s"]' % '","'.join(all_node_ips), # at the moment we are strict and enforce that all nodes are master eligible nodes "minimum_master_nodes": len(all_node_ips), - "install_root_path": node.binary_path + "install_root_path": node.binary_path, } config_vars = {} config_vars.update(self.cluster_config.variables) diff --git a/solrorbit/builder/java_resolver.py b/solrorbit/builder/java_resolver.py index 1744ff95..3dd13258 100644 --- a/solrorbit/builder/java_resolver.py +++ b/solrorbit/builder/java_resolver.py @@ -16,7 +16,7 @@ # not use this file except in compliance with the License. # You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an @@ -49,9 +49,7 @@ def detect_jdk(jdks): allowed_runtime_jdks = [int(v) for v in cluster_config_runtime_jdks.split(",")] except ValueError: - raise exceptions.SystemSetupError( - "ClusterConfigInstance config key \"runtime.jdk\" is invalid: \"{}\" (must be int)".format( - cluster_config_runtime_jdks)) + raise exceptions.SystemSetupError('ClusterConfigInstance config key "runtime.jdk" is invalid: "{}" (must be int)'.format(cluster_config_runtime_jdks)) runtime_jdk_versions = determine_runtime_jdks() diff --git a/solrorbit/builder/launcher.py b/solrorbit/builder/launcher.py index a17e76c1..d98ba213 100644 --- a/solrorbit/builder/launcher.py +++ b/solrorbit/builder/launcher.py @@ -17,7 +17,7 @@ # not use this file except in compliance with the License. # You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an @@ -130,6 +130,7 @@ class ProcessLauncher: """ Launcher is responsible for starting and stopping the benchmark candidate. """ + PROCESS_WAIT_TIMEOUT_SECONDS = 90.0 def __init__(self, cfg, clock=time.Clock): @@ -149,8 +150,7 @@ def _start_node(self, node_configuration, node_count_on_host): data_paths = node_configuration.data_paths node_telemetry_dir = os.path.join(node_configuration.node_root_path, "telemetry") - java_major_version, java_home = java_resolver.java_home(node_configuration.cluster_config_runtime_jdks, - self.cfg.opts("builder", "runtime.jdk")) + java_major_version, java_home = java_resolver.java_home(node_configuration.cluster_config_runtime_jdks, self.cfg.opts("builder", "runtime.jdk")) self.logger.info("Java major version: %s", java_major_version) self.logger.info("Java home: %s", java_home) @@ -199,7 +199,7 @@ def _prepare_env(self, node_name, java_home, t): self.logger.debug("env for [%s]: %s", node_name, str(env)) return env - def _set_env(self, env, k, v, separator=' ', prepend=False): + def _set_env(self, env, k, v, separator=" ", prepend=False): if v is not None: if k not in env: env[k] = v @@ -212,11 +212,7 @@ def _set_env(self, env, k, v, separator=' ', prepend=False): def _run_subprocess(command_line, env): command_line_args = shlex.split(command_line) - with subprocess.Popen(command_line_args, - stdout=subprocess.DEVNULL, - stderr=subprocess.DEVNULL, - env=env, - start_new_session=True) as command_line_process: + with subprocess.Popen(command_line_args, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL, env=env, start_new_session=True) as command_line_process: # wait for it to finish command_line_process.wait() @@ -251,8 +247,7 @@ def _start_process(binary_path, env, distribution_version=None): logging.info("Solr %s uses embedded cloud mode by default", distribution_version) except (ValueError, IndexError): # If we can't parse version, assume newer Solr (no flag) - logging.warning("Could not parse Solr version from '%s', assuming 10.x+ (no --cloud flag)", - distribution_version) + logging.warning("Could not parse Solr version from '%s', assuming 10.x+ (no --cloud flag)", distribution_version) ret = ProcessLauncher._run_subprocess(command_line=" ".join(cmd), env=env) if ret != 0: diff --git a/solrorbit/builder/launchers/docker_launcher.py b/solrorbit/builder/launchers/docker_launcher.py index 77306dfa..ffeafe39 100644 --- a/solrorbit/builder/launchers/docker_launcher.py +++ b/solrorbit/builder/launchers/docker_launcher.py @@ -16,8 +16,7 @@ def __init__(self, cluster_config, shell_executor, metrics_store, clock=time.Clo super().__init__(shell_executor) self.logger = logging.getLogger(__name__) self.metrics_store = metrics_store - self.waiter = PeriodicWaiter(DockerLauncher.CONTAINER_WAIT_INTERVAL_SECONDS, - DockerLauncher.CONTAINER_WAIT_TIMEOUT_SECONDS, clock=clock) + self.waiter = PeriodicWaiter(DockerLauncher.CONTAINER_WAIT_INTERVAL_SECONDS, DockerLauncher.CONTAINER_WAIT_TIMEOUT_SECONDS, clock=clock) def start(self, host, node_configurations): nodes = [] diff --git a/solrorbit/builder/launchers/exception_handling_launcher.py b/solrorbit/builder/launchers/exception_handling_launcher.py index c1056a7d..064dce25 100644 --- a/solrorbit/builder/launchers/exception_handling_launcher.py +++ b/solrorbit/builder/launchers/exception_handling_launcher.py @@ -11,10 +11,10 @@ def start(self, host, node_configurations): try: return self.launcher.start(host, node_configurations) except Exception as e: - raise LaunchError(f"Starting node(s) on host \"{host}\" failed", e) + raise LaunchError(f'Starting node(s) on host "{host}" failed', e) def stop(self, host, nodes): try: return self.launcher.stop(host, nodes) except Exception as e: - raise LaunchError(f"Stopping node(s) on host \"{host}\" failed", e) + raise LaunchError(f'Stopping node(s) on host "{host}" failed', e) diff --git a/solrorbit/builder/launchers/launcher.py b/solrorbit/builder/launchers/launcher.py index d55b9b95..3c52f9d3 100644 --- a/solrorbit/builder/launchers/launcher.py +++ b/solrorbit/builder/launchers/launcher.py @@ -16,7 +16,7 @@ # not use this file except in compliance with the License. # You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an @@ -32,6 +32,7 @@ class Launcher(ABC): """ Launchers are used to start and stop Solr on the nodes in a self-managed cluster. """ + def __init__(self, shell_executor): self.shell_executor = shell_executor diff --git a/solrorbit/builder/launchers/local_process_launcher.py b/solrorbit/builder/launchers/local_process_launcher.py index 850cd23b..9d505fce 100644 --- a/solrorbit/builder/launchers/local_process_launcher.py +++ b/solrorbit/builder/launchers/local_process_launcher.py @@ -16,7 +16,7 @@ # not use this file except in compliance with the License. # You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an @@ -48,8 +48,7 @@ def __init__(self, cluster_config, shell_executor, metrics_store, clock=time.Clo self.logger = logging.getLogger(__name__) self.cluster_config = cluster_config self.metrics_store = metrics_store - self.waiter = PeriodicWaiter(LocalProcessLauncher.PROCESS_WAIT_INTERVAL_SECONDS, - LocalProcessLauncher.PROCESS_WAIT_TIMEOUT_SECONDS, clock=clock) + self.waiter = PeriodicWaiter(LocalProcessLauncher.PROCESS_WAIT_INTERVAL_SECONDS, LocalProcessLauncher.PROCESS_WAIT_TIMEOUT_SECONDS, clock=clock) def start(self, host, node_configurations): node_count_on_host = len(node_configurations) @@ -60,8 +59,7 @@ def _start_node(self, host, node_configuration, node_count_on_host): node_name = node_configuration.node_name binary_path = node_configuration.binary_path - java_major_version, java_home = java_resolver.java_home(node_configuration.cluster_config_runtime_jdks, - self.cluster_config.variables["system"]["runtime"]["jdk"]) + java_major_version, java_home = java_resolver.java_home(node_configuration.cluster_config_runtime_jdks, self.cluster_config.variables["system"]["runtime"]["jdk"]) self.logger.info("Java major version: %s", java_major_version) self.logger.info("Java home: %s", java_home) self.logger.info("Starting node [%s].", node_name) @@ -99,8 +97,7 @@ def _prepare_telemetry(self, node_configuration, node_count_on_host, java_major_ return telemetry.Telemetry(enabled_devices, devices=node_telemetry) def _prepare_env(self, node_name, java_home, telemetry): - env = {k: v for k, v in os.environ.items() if k in - opts.csv_to_list(self.cluster_config.variables["system"]["env"]["passenv"])} + env = {k: v for k, v in os.environ.items() if k in opts.csv_to_list(self.cluster_config.variables["system"]["env"]["passenv"])} if java_home: self._set_env(env, "PATH", os.path.join(java_home, "bin"), separator=os.pathsep, prepend=True) env["JAVA_HOME"] = java_home @@ -115,7 +112,7 @@ def _prepare_env(self, node_name, java_home, telemetry): self.logger.debug("env for [%s]: %s", node_name, str(env)) return env - def _set_env(self, env, key, value, separator=' ', prepend=False): + def _set_env(self, env, key, value, separator=" ", prepend=False): if value is not None: if key not in env: env[key] = value diff --git a/solrorbit/builder/models/bootstrap_phase.py b/solrorbit/builder/models/bootstrap_phase.py index a5eea397..764e93d9 100644 --- a/solrorbit/builder/models/bootstrap_phase.py +++ b/solrorbit/builder/models/bootstrap_phase.py @@ -6,6 +6,7 @@ class BootstrapPhase(Enum): An enum defining the valid phases of bootstrapping. A BootstrapPhase is used to define when a BootstrapHookHandler is executed during cluster creation. """ + POST_INSTALL = 10 @classmethod diff --git a/solrorbit/builder/models/plugin_config_instance.py b/solrorbit/builder/models/plugin_config_instance.py index 58590f2f..cc7002f8 100644 --- a/solrorbit/builder/models/plugin_config_instance.py +++ b/solrorbit/builder/models/plugin_config_instance.py @@ -43,5 +43,4 @@ def __hash__(self): return hash(self.name) ^ hash(self.config_names) ^ hash(self.is_core_plugin) def __eq__(self, other): - return isinstance(other, type(self)) and \ - (self.name, self.config_names, self.is_core_plugin) == (other.name, other.config_names, other.is_core_plugin) + return isinstance(other, type(self)) and (self.name, self.config_names, self.is_core_plugin) == (other.name, other.config_names, other.is_core_plugin) diff --git a/solrorbit/builder/provisioner.py b/solrorbit/builder/provisioner.py index 903be94e..2d29c2cf 100644 --- a/solrorbit/builder/provisioner.py +++ b/solrorbit/builder/provisioner.py @@ -16,7 +16,7 @@ # not use this file except in compliance with the License. # You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an @@ -48,9 +48,7 @@ def local(cfg, cluster_config, ip, http_port, all_node_ips, all_node_names, targ runtime_jdk = cluster_config.mandatory_var("runtime.jdk") _, java_home = java_resolver.java_home(runtime_jdk, cfg.opts("builder", "runtime.jdk")) - os_installer = NodeInstaller( - cluster_config, java_home, node_name, - node_root_dir, all_node_ips, all_node_names, ip, http_port) + os_installer = NodeInstaller(cluster_config, java_home, node_name, node_root_dir, all_node_ips, all_node_names, ip, http_port) return BareProvisioner(os_installer, distribution_version=distribution_version) @@ -81,14 +79,12 @@ def as_dict(self): "node-name": self.node_name, "node-root-path": self.node_root_path, "binary-path": self.binary_path, - "data-paths": self.data_paths + "data-paths": self.data_paths, } @staticmethod def from_dict(d): - return NodeConfiguration( - d["build-type"], d["cluster-config-instance-runtime-jdks"], d["ip"], - d["node-name"], d["node-root-path"], d["binary-path"], d["data-paths"]) + return NodeConfiguration(d["build-type"], d["cluster-config-instance-runtime-jdks"], d["ip"], d["node-name"], d["node-root-path"], d["binary-path"], d["data-paths"]) def save_node_configuration(path, n): @@ -148,9 +144,8 @@ def delete_path(p): def _apply_config(source_root_path, target_root_path, config_vars): logger = logging.getLogger(__name__) for root, _, files in os.walk(source_root_path): - - env = jinja2.Environment(loader=jinja2.FileSystemLoader(root), autoescape=select_autoescape(['html', 'xml'])) - relative_root = root[len(source_root_path) + 1:] + env = jinja2.Environment(loader=jinja2.FileSystemLoader(root), autoescape=select_autoescape(["html", "xml"])) + relative_root = root[len(source_root_path) + 1 :] absolute_target_root = os.path.join(target_root_path, relative_root) io.ensure_dir(absolute_target_root) @@ -192,10 +187,15 @@ def prepare(self, binary): # Never let install hooks modify our original provisioner variables and just provide a copy! self.os_installer.invoke_install_hook(cluster_config.BootstrapPhase.post_install, provisioner_vars.copy()) - return NodeConfiguration("tar", self.os_installer.cluster_config.mandatory_var("runtime.jdk"), - self.os_installer.node_ip, self.os_installer.node_name, - self.os_installer.node_root_dir, self.os_installer.os_home_path, - self.os_installer.data_paths) + return NodeConfiguration( + "tar", + self.os_installer.cluster_config.mandatory_var("runtime.jdk"), + self.os_installer.node_ip, + self.os_installer.node_name, + self.os_installer.node_root_dir, + self.os_installer.os_home_path, + self.os_installer.data_paths, + ) def _provisioner_variables(self): provisioner_vars = {} @@ -206,8 +206,7 @@ def _provisioner_variables(self): class NodeInstaller: - def __init__(self, cluster_config, java_home, node_name, node_root_dir, all_node_ips, all_node_names, ip, http_port, - hook_handler_class=cluster_config.BootstrapHookHandler): + def __init__(self, cluster_config, java_home, node_name, node_root_dir, all_node_ips, all_node_names, ip, http_port, hook_handler_class=cluster_config.BootstrapHookHandler): self.cluster_config = cluster_config self.java_home = java_home self.node_name = node_name @@ -270,11 +269,11 @@ def variables(self): "network_host": network_host, "http_port": str(self.http_port), "zookeeper_port": str(self.http_port + 1000), - "all_node_ips": "[\"%s\"]" % "\",\"".join(self.all_node_ips), - "all_node_names": "[\"%s\"]" % "\",\"".join(self.all_node_names), + "all_node_ips": '["%s"]' % '","'.join(self.all_node_ips), + "all_node_names": '["%s"]' % '","'.join(self.all_node_names), # at the moment we are strict and enforce that all nodes are master eligible nodes "minimum_master_nodes": len(self.all_node_ips), - "install_root_path": self.os_home_path + "install_root_path": self.os_home_path, } variables = {} variables.update(self.cluster_config.variables) @@ -327,7 +326,7 @@ def __init__(self, cluster_config, node_name, ip, http_port, node_root_dir, dist "discovery_type": "single-node", "http_port": str(self.http_port), "zookeeper_port": str(self.http_port + 1000), - "cluster_settings": {} + "cluster_settings": {}, } self.config_vars = {} @@ -352,9 +351,9 @@ def prepare(self, binary): for cluster_config_config_path in self.cluster_config.config_paths: for root, _, files in os.walk(cluster_config_config_path): - env = jinja2.Environment(loader=jinja2.FileSystemLoader(root), autoescape=select_autoescape(['html', 'xml'])) + env = jinja2.Environment(loader=jinja2.FileSystemLoader(root), autoescape=select_autoescape(["html", "xml"])) - relative_root = root[len(cluster_config_config_path) + 1:] + relative_root = root[len(cluster_config_config_path) + 1 :] absolute_target_root = os.path.join(self.binary_path, relative_root) io.ensure_dir(absolute_target_root) @@ -376,8 +375,7 @@ def prepare(self, binary): with open(os.path.join(self.binary_path, "docker-compose.yml"), mode="wt", encoding="utf-8") as f: f.write(docker_cfg) - return NodeConfiguration("docker", self.cluster_config.mandatory_var("runtime.jdk"), - self.node_ip, self.node_name, self.node_root_dir, self.binary_path, self.data_paths) + return NodeConfiguration("docker", self.cluster_config.mandatory_var("runtime.jdk"), self.node_ip, self.node_name, self.node_root_dir, self.binary_path, self.data_paths) def docker_vars(self, mounts): # Determine Docker image based on version type @@ -397,7 +395,7 @@ def docker_vars(self, mounts): "solr_data_dir": self.data_paths[0], "solr_log_dir": self.node_log_dir, "solr_heap_dump_dir": self.heap_dump_dir, - "mounts": mounts + "mounts": mounts, } self._add_if_defined_for_cluster_config(v, "docker_mem_limit") self._add_if_defined_for_cluster_config(v, "docker_cpu_count") @@ -409,7 +407,7 @@ def _add_if_defined_for_cluster_config(self, variables, key): def _render_template(self, loader, template_name, variables): try: - env = jinja2.Environment(loader=loader, autoescape=select_autoescape(['html', 'xml'])) + env = jinja2.Environment(loader=loader, autoescape=select_autoescape(["html", "xml"])) for k, v in variables.items(): env.globals[k] = v template = env.get_template(template_name) @@ -422,6 +420,4 @@ def _render_template(self, loader, template_name, variables): def _render_template_from_file(self, variables): compose_file = os.path.join(self.benchmark_root, "resources", "docker-compose.yml.j2") - return self._render_template(loader=jinja2.FileSystemLoader(io.dirname(compose_file)), - template_name=io.basename(compose_file), - variables=variables) + return self._render_template(loader=jinja2.FileSystemLoader(io.dirname(compose_file)), template_name=io.basename(compose_file), variables=variables) diff --git a/solrorbit/builder/provisioners/provisioner.py b/solrorbit/builder/provisioners/provisioner.py index 572c0cd0..359e3587 100644 --- a/solrorbit/builder/provisioners/provisioner.py +++ b/solrorbit/builder/provisioners/provisioner.py @@ -5,6 +5,7 @@ class Provisioner(ABC): """ Provisioners are used to create and destroy any infrastructure required to construct a cluster. """ + def __init__(self): pass diff --git a/solrorbit/builder/solr_provisioner.py b/solrorbit/builder/solr_provisioner.py index 8d132fe4..4b1c20d5 100644 --- a/solrorbit/builder/solr_provisioner.py +++ b/solrorbit/builder/solr_provisioner.py @@ -97,9 +97,7 @@ class SolrProvisioner: p.clean("/tmp/solr-node") """ - def __init__(self, cache_dir: str = None, port: int = 8983, - startup_timeout: int = 120, cluster_config=None, solr_modules: str = "", - telemetry_devices: list = None): + def __init__(self, cache_dir: str = None, port: int = 8983, startup_timeout: int = 120, cluster_config=None, solr_modules: str = "", telemetry_devices: list = None): self.cache_dir = cache_dir or os.path.join(os.path.expanduser("~"), ".solr-orbit", "cache") self.port = port self.startup_timeout = startup_timeout @@ -145,10 +143,7 @@ def download(self, version: str) -> str: if os.path.exists(dest): os.remove(dest) - raise SolrProvisionerError( - f"Could not download Solr {version} from any mirror. " - f"Please download manually to {dest}." - ) + raise SolrProvisionerError(f"Could not download Solr {version} from any mirror. Please download manually to {dest}.") def install(self, version: str, install_dir: str) -> str: """ @@ -198,9 +193,7 @@ def start(self, solr_root: str, mode: str = None) -> None: logger.info("Starting Solr with: %s", " ".join(cmd)) result = subprocess.run(cmd, capture_output=True, text=True, env=self._build_env()) if result.returncode != 0: - raise SolrProvisionerError( - f"Solr failed to start: {result.stderr or result.stdout}" - ) + raise SolrProvisionerError(f"Solr failed to start: {result.stderr or result.stdout}") self._wait_for_ready() @@ -266,17 +259,14 @@ def _build_env(self) -> dict: def _bin_solr(self, solr_root: str) -> str: script = os.path.join(solr_root, "bin", "solr") if not os.path.isfile(script): - raise SolrProvisionerError( - f"bin/solr not found in {solr_root}. " - "Ensure install() was called first." - ) + raise SolrProvisionerError(f"bin/solr not found in {solr_root}. Ensure install() was called first.") return script def _detect_version(self, solr_root: str) -> str: """Read version from the Solr installation directory name.""" name = Path(solr_root).name # e.g. "solr-9.7.0" if name.startswith("solr-"): - return name[len("solr-"):] + return name[len("solr-") :] return "" def _wait_for_ready(self) -> None: @@ -292,16 +282,14 @@ def _wait_for_ready(self) -> None: except Exception as exc: last_exc = exc time.sleep(2) - raise SolrProvisionerError( - f"Solr did not become ready within {self.startup_timeout}s. " - f"Last error: {last_exc}" - ) + raise SolrProvisionerError(f"Solr did not become ready within {self.startup_timeout}s. Last error: {last_exc}") # --------------------------------------------------------------------------- # Docker launcher (T019) # --------------------------------------------------------------------------- + class SolrDockerLauncher: """ Launch an official Solr Docker container for benchmarking. @@ -321,9 +309,7 @@ class SolrDockerLauncher: DEFAULT_CONTAINER_NAME = "solr-orbit" - def __init__(self, port: int = 8983, startup_timeout: int = 60, - container_name: str = None, cluster_config=None, solr_modules: str = "", - telemetry_devices: list = None): + def __init__(self, port: int = 8983, startup_timeout: int = 60, container_name: str = None, cluster_config=None, solr_modules: str = "", telemetry_devices: list = None): self.port = port self.startup_timeout = startup_timeout self.container_name = container_name or self.DEFAULT_CONTAINER_NAME @@ -356,10 +342,13 @@ def start(self, version_tag: str = "9", mode: str = None) -> None: # Build the docker run command cmd = [ - "docker", "run", + "docker", + "run", "--rm", - "--name", self.container_name, - "-p", f"{self.port}:8983", + "--name", + self.container_name, + "-p", + f"{self.port}:8983", "-d", ] cmd += self._cluster_config_env_flags() @@ -377,9 +366,7 @@ def start(self, version_tag: str = "9", mode: str = None) -> None: logger.info("Starting Solr Docker container: %s", " ".join(cmd)) result = subprocess.run(cmd, capture_output=True, text=True) if result.returncode != 0: - raise SolrProvisionerError( - f"Failed to start Solr Docker container: {result.stderr or result.stdout}" - ) + raise SolrProvisionerError(f"Failed to start Solr Docker container: {result.stderr or result.stdout}") self._wait_for_ready() @@ -387,7 +374,9 @@ def start(self, version_tag: str = "9", mode: str = None) -> None: try: pid_result = subprocess.run( ["docker", "inspect", self.container_name, "--format={{.State.Pid}}"], - capture_output=True, text=True, check=True, + capture_output=True, + text=True, + check=True, ) self.pid = int(pid_result.stdout.strip()) logger.info("Solr container PID = %d", self.pid) @@ -449,7 +438,4 @@ def _wait_for_ready(self) -> None: except Exception as exc: last_exc = exc time.sleep(2) - raise SolrProvisionerError( - f"Solr container did not become ready within {self.startup_timeout}s. " - f"Last error: {last_exc}" - ) + raise SolrProvisionerError(f"Solr container did not become ready within {self.startup_timeout}s. Last error: {last_exc}") diff --git a/solrorbit/builder/supplier.py b/solrorbit/builder/supplier.py index b84b54dd..c76c5549 100644 --- a/solrorbit/builder/supplier.py +++ b/solrorbit/builder/supplier.py @@ -16,7 +16,7 @@ # not use this file except in compliance with the License. # You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an @@ -86,24 +86,17 @@ def create(cfg, sources, cluster_config): if os_supplier_type == "source": os_src_dir = os.path.join(_src_dir(cfg), _config_value(src_config, "src.subdir")) - source_supplier = SourceSupplier(os_version, - os_src_dir, - remote_url=cfg.opts("source", "remote.repo.url"), - cluster_config=cluster_config, - builder=builder, - template_renderer=template_renderer) + source_supplier = SourceSupplier( + os_version, os_src_dir, remote_url=cfg.opts("source", "remote.repo.url"), cluster_config=cluster_config, builder=builder, template_renderer=template_renderer + ) if caching_enabled: os_file_resolver = FileNameResolver(dist_cfg, template_renderer) - source_supplier = CachedSourceSupplier(source_distributions_root, - source_supplier, - os_file_resolver) + source_supplier = CachedSourceSupplier(source_distributions_root, source_supplier, os_file_resolver) suppliers.append(source_supplier) else: - repo = DistributionRepository(name=cfg.opts("builder", "distribution.repository"), - distribution_config=dist_cfg, - template_renderer=template_renderer) + repo = DistributionRepository(name=cfg.opts("builder", "distribution.repository"), distribution_config=dist_cfg, template_renderer=template_renderer) suppliers.append(DistributionSupplier(repo, os_version, distributions_root)) return CompositeSupplier(suppliers) @@ -111,8 +104,7 @@ def create(cfg, sources, cluster_config): def _required_version(version): if not version or version.strip() == "": - raise exceptions.SystemSetupError("Could not determine version. Please specify the Solr distribution " - "to download with the command line parameter --distribution-version.") + raise exceptions.SystemSetupError("Could not determine version. Please specify the Solr distribution to download with the command line parameter --distribution-version.") else: return version @@ -146,8 +138,9 @@ def _src_dir(cfg, mandatory=True): try: return cfg.opts("node", "src.root.dir", mandatory=mandatory) except exceptions.ConfigError: - raise exceptions.SystemSetupError("You cannot benchmark Solr from sources. Did you install Gradle? Please install" - " all prerequisites and reconfigure with %s configure" % PROGRAM_NAME) + raise exceptions.SystemSetupError( + "You cannot benchmark Solr from sources. Did you install Gradle? Please install all prerequisites and reconfigure with %s configure" % PROGRAM_NAME + ) def _prune(root_path, max_age_days): @@ -178,6 +171,7 @@ def _prune(root_path, max_age_days): else: logger.info("Skipping [%s] (not a file).", artifact) + class TemplateRenderer: def __init__(self, version): self.version = version @@ -221,7 +215,7 @@ def file_name(self): # Solr distributions never include a JDK, so we always use release_url url_key = "release_url" url = self.template_renderer.render(self.cfg[url_key]) - return url[url.rfind("/") + 1:] + return url[url.rfind("/") + 1 :] @property def artifact_key(self): @@ -304,25 +298,24 @@ def fetch(self): def prepare(self): if self.builder: - self.builder.build([ - self.template_renderer.render(self.cluster_config.mandatory_var("clean_command")), - self.template_renderer.render(self.cluster_config.mandatory_var("system.build_command")) - ]) + self.builder.build( + [ + self.template_renderer.render(self.cluster_config.mandatory_var("clean_command")), + self.template_renderer.render(self.cluster_config.mandatory_var("system.build_command")), + ] + ) def add(self, binaries): binaries["solr"] = self.resolve_binary() def resolve_binary(self): try: - path = os.path.join(self.src_dir, - self.template_renderer.render(self.cluster_config.mandatory_var("system.artifact_path_pattern"))) + path = os.path.join(self.src_dir, self.template_renderer.render(self.cluster_config.mandatory_var("system.artifact_path_pattern"))) return glob.glob(path)[0] except IndexError: raise SystemSetupError("Couldn't find a tar.gz distribution. Please run Solr Orbit with the pipeline 'from-sources'.") - - class DistributionSupplier: def __init__(self, repo, version, distributions_root): self.repo = repo @@ -346,8 +339,9 @@ def fetch(self): self.logger.info("Successfully downloaded Solr [%s].", self.version) except urllib.error.HTTPError: self.logger.exception("Cannot download Solr distribution for version [%s] from [%s].", self.version, download_url) - raise exceptions.SystemSetupError("Cannot download Solr distribution from [%s]. Please check that the specified " - "version [%s] is correct." % (download_url, self.version)) + raise exceptions.SystemSetupError( + "Cannot download Solr distribution from [%s]. Please check that the specified version [%s] is correct." % (download_url, self.version) + ) else: self.logger.info("Skipping download for version [%s]. Found an existing binary at [%s].", self.version, distribution_path) @@ -360,13 +354,11 @@ def add(self, binaries): binaries["solr"] = self.distribution_path - def _config_value(src_config, key): try: return src_config[key] except KeyError: - raise exceptions.SystemSetupError("Mandatory config key [%s] is undefined. Please add it in the [source] section of the " - "config file." % key) + raise exceptions.SystemSetupError("Mandatory config key [%s] is undefined. Please add it in the [source] section of the config file." % key) def _extract_revisions(revision): @@ -374,18 +366,16 @@ def _extract_revisions(revision): if len(revisions) == 1: r = revisions[0] if r.startswith("solr:"): - r = r[len("solr:"):] + r = r[len("solr:") :] # may as well be just a single plugin m = re.match(REVISION_PATTERN, r) if m: - return { - m.group(1): m.group(2) - } + return {m.group(1): m.group(2)} else: return { "solr": r, # use a catch-all value - "all": r + "all": r, } else: results = {} @@ -523,7 +513,7 @@ def download_url(self): @property def file_name(self): url = self.download_url - return url[url.rfind("/") + 1:] + return url[url.rfind("/") + 1 :] def plugin_download_url(self, plugin_name): # cluster_config repo diff --git a/solrorbit/builder/utils/artifact_variables_provider.py b/solrorbit/builder/utils/artifact_variables_provider.py index a463e02a..c27df63e 100644 --- a/solrorbit/builder/utils/artifact_variables_provider.py +++ b/solrorbit/builder/utils/artifact_variables_provider.py @@ -6,11 +6,7 @@ def __init__(self, executor): self.executor = executor def get_artifact_variables(self, host, version=None): - return { - "VERSION": version, - "OSNAME": self._get_os_name(host), - "ARCH": self._get_arch(host) - } + return {"VERSION": version, "OSNAME": self._get_os_name(host), "ARCH": self._get_arch(host)} def _get_os_name(self, host): os_name = self.executor.execute(host, "uname", output=True)[0] diff --git a/solrorbit/builder/utils/binary_keys.py b/solrorbit/builder/utils/binary_keys.py index d0f71c81..67037bf9 100644 --- a/solrorbit/builder/utils/binary_keys.py +++ b/solrorbit/builder/utils/binary_keys.py @@ -16,7 +16,7 @@ # not use this file except in compliance with the License. # You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an diff --git a/solrorbit/builder/utils/config_applier.py b/solrorbit/builder/utils/config_applier.py index fdd4a09a..37f4d2de 100644 --- a/solrorbit/builder/utils/config_applier.py +++ b/solrorbit/builder/utils/config_applier.py @@ -16,7 +16,7 @@ # not use this file except in compliance with the License. # You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an @@ -49,7 +49,7 @@ def _apply_config(self, host, source_root_path, target_root_path, config_vars): mounts = {} for root, _, files in os.walk(source_root_path): - relative_root = root[len(source_root_path) + 1:] + relative_root = root[len(source_root_path) + 1 :] absolute_target_root = os.path.join(target_root_path, relative_root) self.path_manager.create_path(host, absolute_target_root) diff --git a/solrorbit/builder/utils/git_manager.py b/solrorbit/builder/utils/git_manager.py index f7a938fa..96e6e93d 100644 --- a/solrorbit/builder/utils/git_manager.py +++ b/solrorbit/builder/utils/git_manager.py @@ -15,7 +15,7 @@ def rebase(self, host, target_dir, remote="origin", branch="main"): self.executor.execute(host, f"git -C {target_dir} rebase {remote}/{branch}") def get_revision_from_timestamp(self, host, target_dir, timestamp): - get_revision_from_timestamp_command = f"git -C {target_dir} rev-list -n 1 --before=\"{timestamp}\" --date=iso8601 origin/main" + get_revision_from_timestamp_command = f'git -C {target_dir} rev-list -n 1 --before="{timestamp}" --date=iso8601 origin/main' return self.executor.execute(host, get_revision_from_timestamp_command, output=True)[0].strip() diff --git a/solrorbit/builder/utils/java_home_resolver.py b/solrorbit/builder/utils/java_home_resolver.py index c9bab5e6..9a98917b 100644 --- a/solrorbit/builder/utils/java_home_resolver.py +++ b/solrorbit/builder/utils/java_home_resolver.py @@ -16,7 +16,7 @@ # not use this file except in compliance with the License. # You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an @@ -43,7 +43,7 @@ def resolve_java_home(self, host, cluster_config): try: allowed_runtime_jdks = [int(v) for v in runtime_jdks.split(",")] except ValueError: - raise SystemSetupError(f"ClusterConfigInstance variable key \"runtime.jdk\" is invalid: \"{runtime_jdks}\" (must be int)") + raise SystemSetupError(f'ClusterConfigInstance variable key "runtime.jdk" is invalid: "{runtime_jdks}" (must be int)') self.logger.info("Allowed JDK versions are %s.", allowed_runtime_jdks) return self._detect_jdk(host, allowed_runtime_jdks) diff --git a/solrorbit/builder/utils/jdk_resolver.py b/solrorbit/builder/utils/jdk_resolver.py index 2a4759e6..c77bc353 100644 --- a/solrorbit/builder/utils/jdk_resolver.py +++ b/solrorbit/builder/utils/jdk_resolver.py @@ -43,8 +43,7 @@ def _resolve_jdk_path(self, host, majors): resolved_major_to_java_home_path = {} for java_home_env_var_name in java_home_env_var_names: if java_home_env_var_name in defined_env_vars: - major_to_java_home_path = self._resolve_major_from_java_home(host, java_home_env_var_name, - defined_env_vars[java_home_env_var_name]) + major_to_java_home_path = self._resolve_major_from_java_home(host, java_home_env_var_name, defined_env_vars[java_home_env_var_name]) if major_to_java_home_path: resolved_major_to_java_home_path.update(major_to_java_home_path) diff --git a/solrorbit/builder/utils/template_renderer.py b/solrorbit/builder/utils/template_renderer.py index 8f00af30..4b1e785b 100644 --- a/solrorbit/builder/utils/template_renderer.py +++ b/solrorbit/builder/utils/template_renderer.py @@ -11,7 +11,7 @@ def render_template_file(self, root_path, variables, file_name): return self._handle_template_rendering_exceptions(self._render_template_file, root_path, variables, file_name) def _render_template_file(self, root_path, variables, file_name): - env = jinja2.Environment(loader=jinja2.FileSystemLoader(root_path), autoescape=select_autoescape(['html', 'xml'])) + env = jinja2.Environment(loader=jinja2.FileSystemLoader(root_path), autoescape=select_autoescape(["html", "xml"])) env.filters["version_between"] = loader.version_between template = env.get_template(io.basename(file_name)) # force a new line at the end. Jinja seems to remove it. @@ -21,7 +21,7 @@ def render_template_string(self, template_string, variables): return self._handle_template_rendering_exceptions(self._render_template_string, template_string, variables) def _render_template_string(self, template_string, variables): - env = jinja2.Environment(loader=jinja2.BaseLoader, autoescape=select_autoescape(['html', 'xml'])) + env = jinja2.Environment(loader=jinja2.BaseLoader, autoescape=select_autoescape(["html", "xml"])) env.filters["version_between"] = loader.version_between template = env.from_string(template_string) diff --git a/solrorbit/client.py b/solrorbit/client.py index b872003f..6001c05b 100644 --- a/solrorbit/client.py +++ b/solrorbit/client.py @@ -16,7 +16,7 @@ # not use this file except in compliance with the License. # You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an @@ -42,6 +42,7 @@ # Exceptions # --------------------------------------------------------------------------- + class SolrClientError(Exception): """Base exception for all SolrAdminClient errors.""" @@ -58,6 +59,7 @@ class CollectionNotFoundError(SolrClientError): # SolrAdminClient # --------------------------------------------------------------------------- + class SolrAdminClient: """ Thin wrapper around requests.Session for Solr V2 API admin operations. @@ -69,9 +71,7 @@ class SolrAdminClient: Not thread-safe — each worker process creates its own instance. """ - def __init__(self, host: str, port: int = 8983, - username: str = None, password: str = None, - tls: bool = False, timeout: int = 30): + def __init__(self, host: str, port: int = 8983, username: str = None, password: str = None, tls: bool = False, timeout: int = 30): scheme = "https" if tls else "http" self.base_url = f"{scheme}://{host}:{port}" self.api_url = f"{self.base_url}/api" @@ -111,9 +111,7 @@ def get_version(self) -> str: try: return data["lucene"]["solr-spec-version"] except KeyError as exc: - raise SolrClientError( - f"Could not parse Solr version from /api/node/system response: {data}" - ) from exc + raise SolrClientError(f"Could not parse Solr version from /api/node/system response: {data}") from exc def get_major_version(self) -> int: """Return the major version integer (9 or 10).""" @@ -136,9 +134,7 @@ def wait_for_cluster_ready(self, timeout: int = 60, **kwargs) -> None: except Exception as exc: last_exc = exc time.sleep(2) - raise SolrClientError( - f"Solr cluster did not become ready within {timeout}s. Last error: {last_exc}" - ) + raise SolrClientError(f"Solr cluster did not become ready within {timeout}s. Last error: {last_exc}") # ------------------------------------------------------------------ # Configset management @@ -192,10 +188,9 @@ def delete_configset(self, name: str) -> None: # Collection management # ------------------------------------------------------------------ - def create_collection(self, name: str, configset: str, - num_shards: int = 1, replication_factor: int = 1, - tlog_replicas: int = 0, pull_replicas: int = 0, - wait_for_active_shards: int = 1) -> None: + def create_collection( + self, name: str, configset: str, num_shards: int = 1, replication_factor: int = 1, tlog_replicas: int = 0, pull_replicas: int = 0, wait_for_active_shards: int = 1 + ) -> None: """ Create a Solr collection via POST /api/collections. @@ -220,12 +215,9 @@ def create_collection(self, name: str, configset: str, if resp.status_code == 400: body = self._try_parse_json(resp) if "already exists" in str(body).lower(): - raise CollectionAlreadyExistsError( - f"Collection '{name}' already exists" - ) + raise CollectionAlreadyExistsError(f"Collection '{name}' already exists") self._raise_for_solr_error(resp, f"create collection '{name}'") - logger.info("Created collection '%s' (shards=%d, nrt=%d, tlog=%d, pull=%d)", - name, num_shards, replication_factor, tlog_replicas, pull_replicas) + logger.info("Created collection '%s' (shards=%d, nrt=%d, tlog=%d, pull=%d)", name, num_shards, replication_factor, tlog_replicas, pull_replicas) def delete_collection(self, name: str) -> None: """Delete a Solr collection via DELETE /api/collections/{name}.""" @@ -334,8 +326,7 @@ def get_node_metrics(self): # Raw request (for the raw-request workload operation) # ------------------------------------------------------------------ - def raw_request(self, method: str, path: str, - body=None, headers: dict = None) -> requests.Response: + def raw_request(self, method: str, path: str, body=None, headers: dict = None) -> requests.Response: """ Send an arbitrary HTTP request to a Solr endpoint. @@ -369,9 +360,7 @@ def _raise_for_solr_error(self, resp: requests.Response, operation: str) -> None return body = self._try_parse_json(resp) msg = body.get("error", {}).get("msg", resp.text) if isinstance(body, dict) else resp.text - raise SolrClientError( - f"Solr {operation} failed (HTTP {resp.status_code}): {msg}" - ) + raise SolrClientError(f"Solr {operation} failed (HTTP {resp.status_code}): {msg}") @staticmethod def _try_parse_json(resp: requests.Response) -> dict: @@ -400,6 +389,7 @@ def _build_configset_zip(configset_dir: str) -> bytes: # SolrClient — unified client used by runners and telemetry devices # --------------------------------------------------------------------------- + class SolrClient(RequestContextHolder): # pylint: disable=too-many-public-methods """ Single unified Solr client. Wraps SolrAdminClient (admin/HTTP) and pysolr.Solr @@ -414,16 +404,14 @@ class _NoOpTransport: async def close(self): pass - def __init__(self, host="localhost", port=8983, username=None, password=None, - tls=False, timeout=30): + def __init__(self, host="localhost", port=8983, username=None, password=None, tls=False, timeout=30): self._host = host self._port = port self._username = username self._password = password self._tls = tls self._timeout = timeout - self._admin = SolrAdminClient(host=host, port=port, username=username, - password=password, tls=tls, timeout=timeout) + self._admin = SolrAdminClient(host=host, port=port, username=username, password=password, tls=tls, timeout=timeout) self._pysolr_clients = {} # collection → pysolr.Solr (created lazily) self.transport = SolrClient._NoOpTransport() @@ -499,6 +487,7 @@ def _get(self, path: str): def _get_pysolr(self, collection: str): """Return (lazily-created, cached) pysolr.Solr for the given collection.""" import pysolr # pylint: disable=import-outside-toplevel + if collection not in self._pysolr_clients: scheme = "https" if self._tls else "http" url = f"{scheme}://{self._host}:{self._port}/solr/{collection}" @@ -506,8 +495,7 @@ def _get_pysolr(self, collection: str): session.trust_env = False # fork-safe on macOS (no CFNetwork proxy detection) if self._username and self._password: session.auth = (self._username, self._password) - self._pysolr_clients[collection] = pysolr.Solr( - url, timeout=self._timeout, always_commit=False, session=session) + self._pysolr_clients[collection] = pysolr.Solr(url, timeout=self._timeout, always_commit=False, session=session) return self._pysolr_clients[collection] def add(self, collection, docs, **kwargs): diff --git a/solrorbit/config.py b/solrorbit/config.py index 5e9069cf..7ee9b6be 100644 --- a/solrorbit/config.py +++ b/solrorbit/config.py @@ -16,7 +16,7 @@ # not use this file except in compliance with the License. # You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an @@ -296,14 +296,12 @@ def migrate(config_file, current_version, target_version, out=print, i=input): logger.info("Config file is already at version [%s]. Skipping migration.", target_version) return if current_version < Config.EARLIEST_SUPPORTED_VERSION: - raise exceptions.ConfigError(f"The config file in {config_file.location} is too old. Please delete it " - f"and reconfigure from scratch with {PROGRAM_NAME} configure.") + raise exceptions.ConfigError(f"The config file in {config_file.location} is too old. Please delete it and reconfigure from scratch with {PROGRAM_NAME} configure.") logger.info("Upgrading configuration from version [%s] to [%s].", current_version, target_version) # Something is really fishy. We don't want to downgrade the configuration. if current_version >= target_version: - raise exceptions.ConfigError(f"The existing config file is available in a later version already. " - f"Expected version <= [{target_version}] but found [{current_version}]") + raise exceptions.ConfigError(f"The existing config file is available in a later version already. Expected version <= [{target_version}] but found [{current_version}]") # but first a backup... config_file.backup() config = config_file.load() diff --git a/solrorbit/context.py b/solrorbit/context.py index 02d3c871..54ed1a3f 100644 --- a/solrorbit/context.py +++ b/solrorbit/context.py @@ -16,7 +16,7 @@ # not use this file except in compliance with the License. # You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an @@ -35,6 +35,7 @@ class RequestContextManager: This means that we can span a top-level request context, open sub-request contexts that can be used to measure individual timings and still measure the proper total time on the top-level request context. """ + def __init__(self, request_context_holder): self.ctx_holder = request_context_holder self.ctx = None @@ -86,6 +87,7 @@ class RequestContextHolder: """ Holds request context variables. This class is only meant to be used together with RequestContextManager. """ + request_context = contextvars.ContextVar("benchmark_request_context") def new_request_context(self): diff --git a/solrorbit/conversion/query.py b/solrorbit/conversion/query.py index 5b4b2766..f12bf6aa 100644 --- a/solrorbit/conversion/query.py +++ b/solrorbit/conversion/query.py @@ -128,6 +128,7 @@ def extract_sort_parameter(body: dict) -> str: # Internal helper functions # --------------------------------------------------------------------------- + def _translate_query_node(node: dict, fq_list: list = None) -> str: """Recursively translate a single OpenSearch query node to Solr syntax. @@ -172,10 +173,7 @@ def _translate_query_node(node: dict, fq_list: list = None) -> str: return f'{field}:"{_escape_solr_phrase(v)}"' return f"{field}:{_escape_solr_value(v)}" # If sub is not a dict or has no valid fields, fall back - logger.warning( - "match/match_phrase query has invalid structure: %s. Using *:*", - sub - ) + logger.warning("match/match_phrase query has invalid structure: %s. Using *:*", sub) return "*:*" if "range" in node: @@ -263,10 +261,7 @@ def _translate_terms_clause(field: str, values: list) -> str: - Large lists (>100 terms): ``field:(v1 v2 ...)`` for q context, still works but when used as an fq, callers should prefer ``{!terms f=field}v1,v2,...`` """ - escaped = " ".join( - f'"{_escape_solr_phrase(v)}"' if " " in str(v) else _escape_solr_value(v) - for v in values - ) + escaped = " ".join(f'"{_escape_solr_phrase(v)}"' if " " in str(v) else _escape_solr_value(v) for v in values) return f"{field}:({escaped})" @@ -305,10 +300,10 @@ def _escape_solr_value(value) -> str: result = [] for char in str(value): if char in special: - result.append('\\' + char) + result.append("\\" + char) else: result.append(char) - return ''.join(result) + return "".join(result) def _escape_solr_phrase(value) -> str: @@ -318,7 +313,7 @@ def _escape_solr_phrase(value) -> str: For phrases, we only need to escape quotes (and backslashes). Other special characters are OK within quotes. """ - return str(value).replace('\\', '\\\\').replace('"', '\\"') + return str(value).replace("\\", "\\\\").replace('"', '\\"') def translate_to_solr_json_dsl(body: dict) -> dict: @@ -437,11 +432,7 @@ def _convert_single_agg(agg_name: str, agg_def: dict): if not field: logger.warning("date_histogram agg '%s' has no field — skipping", agg_name) return None - interval = ( - dh_conf.get("calendar_interval") - or dh_conf.get("fixed_interval") - or dh_conf.get("interval", "month") - ) + interval = dh_conf.get("calendar_interval") or dh_conf.get("fixed_interval") or dh_conf.get("interval", "month") gap = _calendar_interval_to_solr_gap(interval) facet_def = { "type": "range", @@ -492,7 +483,8 @@ def _convert_single_agg(agg_name: str, agg_def: dict): agg_type = next(iter(agg_def), "unknown") logger.warning( "Unsupported aggregation type '%s' (name='%s') — skipping in Solr conversion.", - agg_type, agg_name, + agg_type, + agg_name, ) return None diff --git a/solrorbit/conversion/schema.py b/solrorbit/conversion/schema.py index 9cc77848..c843b5ea 100644 --- a/solrorbit/conversion/schema.py +++ b/solrorbit/conversion/schema.py @@ -68,7 +68,7 @@ # Note: This is a best-effort mapping for common cases OPENSEARCH_TO_SOLR_TYPES = { # Numeric types - "scaled_float": "pdouble", # Note: loses scaling_factor precision control + "scaled_float": "pdouble", # Note: loses scaling_factor precision control "half_float": "pfloat", "float": "pfloat", "double": "pdouble", @@ -76,18 +76,15 @@ "short": "pint", "integer": "pint", "long": "plong", - # String types - "keyword": "string", # Exact match, no analysis - "text": "text_general", # Analyzed text - + "keyword": "string", # Exact match, no analysis + "text": "text_general", # Analyzed text # Other types "boolean": "boolean", "date": "pdate", "binary": "binary", - # Spatial - "geo_point": "string", # Stored as "lat,lon" string (converted during indexing) + "geo_point": "string", # Stored as "lat,lon" string (converted during indexing) } @@ -119,10 +116,7 @@ def translate_opensearch_mapping(properties: Dict[str, Any]) -> tuple[Dict[str, # Translate main field solr_type = OPENSEARCH_TO_SOLR_TYPES.get(os_type) if not solr_type: - logger.warning( - f"Field '{field_name}' has unsupported type '{os_type}', " - f"falling back to 'string'" - ) + logger.warning(f"Field '{field_name}' has unsupported type '{os_type}', falling back to 'string'") solr_type = "string" # Build Solr field config @@ -142,10 +136,7 @@ def translate_opensearch_mapping(properties: Dict[str, Any]) -> tuple[Dict[str, # Solr: Uses ISO8601 by default, custom formats need DatePointField config os_format = field_config.get("format") if os_format and os_format != "strict_date_optional_time||epoch_millis": - logger.warning( - f"Field '{field_name}' has custom date format '{os_format}'. " - f"Solr will use ISO8601 format. Manual schema adjustment may be needed." - ) + logger.warning(f"Field '{field_name}' has custom date format '{os_format}'. Solr will use ISO8601 format. Manual schema adjustment may be needed.") solr_fields[field_name] = solr_field @@ -180,17 +171,12 @@ def translate_opensearch_mapping(properties: Dict[str, Any]) -> tuple[Dict[str, # Add copyField directive from main field to sub-field copy_fields.append((field_name, solr_sub_field_name)) - logger.info( - f"Multi-field detected: {field_name}.{sub_field_name} → " - f"{solr_sub_field_name} (type: {sub_solr_type})" - ) + logger.info(f"Multi-field detected: {field_name}.{sub_field_name} → {solr_sub_field_name} (type: {sub_solr_type})") return solr_fields, copy_fields -def generate_schema_xml(field_defs: Dict[str, Dict[str, Any]], - copy_fields: Optional[list[tuple[str, str]]] = None, - unique_key: str = "id") -> str: +def generate_schema_xml(field_defs: Dict[str, Dict[str, Any]], copy_fields: Optional[list[tuple[str, str]]] = None, unique_key: str = "id") -> str: """ Generate a Solr schema.xml from field definitions. @@ -209,13 +195,13 @@ def generate_schema_xml(field_defs: Dict[str, Dict[str, Any]], fields_xml = [] # Add required fields for SolrCloud - fields_xml.append(' ') + fields_xml.append(" ") fields_xml.append(f' ') fields_xml.append(' ') fields_xml.append(' ') fields_xml.append(' ') - fields_xml.append('') - fields_xml.append(' ') + fields_xml.append("") + fields_xml.append(" ") # Add workload fields for field_name, field_config in field_defs.items(): @@ -238,13 +224,13 @@ def generate_schema_xml(field_defs: Dict[str, Dict[str, Any]], if doc_values is not None: attrs.append(f'docValues="{str(doc_values).lower()}"') - fields_xml.append(f' ') + fields_xml.append(f" ") # Build copyField directives XML copy_fields_xml = [] if copy_fields: - copy_fields_xml.append('') - copy_fields_xml.append(' ') + copy_fields_xml.append("") + copy_fields_xml.append(" ") for source, dest in copy_fields: copy_fields_xml.append(f' ') @@ -314,8 +300,7 @@ def generate_schema_xml(field_defs: Dict[str, Dict[str, Any]], return schema_xml -def create_configset_from_schema(schema_xml: str, - configset_name: Optional[str] = None) -> str: +def create_configset_from_schema(schema_xml: str, configset_name: Optional[str] = None) -> str: """ Create a temporary Solr configset directory with the generated schema. diff --git a/solrorbit/conversion/workload_converter.py b/solrorbit/conversion/workload_converter.py index b99dcb6d..aeadb5de 100644 --- a/solrorbit/conversion/workload_converter.py +++ b/solrorbit/conversion/workload_converter.py @@ -54,10 +54,10 @@ # 2. Entire if/else/endif conditional block (may span many lines) # 3. Any remaining Jinja2 block tag or expression _JINJA_RE = re.compile( - r'"(\{\{[^}]*?\}\})"' # group 1: already-quoted {{expr}} - r'|\{%-?\s*if\b.*?\{%-?\s*endif\s*-?%\}' # full if/else/endif block - r'|\{%.*?%\}' # any other block tag - r'|\{\{.*?\}\}', # bare {{expr}} + r'"(\{\{[^}]*?\}\})"' # group 1: already-quoted {{expr}} + r"|\{%-?\s*if\b.*?\{%-?\s*endif\s*-?%\}" # full if/else/endif block + r"|\{%.*?%\}" # any other block tag + r"|\{\{.*?\}\}", # bare {{expr}} re.DOTALL, ) @@ -165,12 +165,12 @@ def _load_workload_json(workload_path: str) -> dict: # Template path: render Jinja2 then parse JSON try: from solrorbit.workload.loader import render_template_from_file + rendered = render_template_from_file(workload_path, template_vars={}) return json.loads(rendered) except Exception as exc: - raise ValueError( - f"Cannot parse workload file '{workload_path}' as JSON or Jinja2 template: {exc}" - ) from exc + raise ValueError(f"Cannot parse workload file '{workload_path}' as JSON or Jinja2 template: {exc}") from exc + # Sentinel filename written to the output directory after successful conversion CONVERTED_MARKER = "CONVERTED.md" @@ -304,11 +304,7 @@ def convert_opensearch_workload(source_dir: str, output_dir: str) -> dict: # --- Copy auxiliary files (Python param sources, templates, etc.) --- # Skip index body files (e.g. index.json) — replaced by generated configsets - index_body_files = { - index.get("body") - for index in rendered_workload.get("indices", []) - if index.get("body") - } + index_body_files = {index.get("body") for index in rendered_workload.get("indices", []) if index.get("body")} _copy_auxiliary_files(source_dir, output_dir, skip_files=index_body_files) # --- Follow external benchmark.collect() refs and make the workload self-contained --- @@ -319,7 +315,11 @@ def convert_opensearch_workload(source_dir: str, output_dir: str) -> dict: logger.info( "Workload conversion complete: %s → %s (%d ops, %d skipped, %d issues)", - source_dir, output_dir, 0, len(skipped), len(issues), + source_dir, + output_dir, + 0, + len(skipped), + len(issues), ) return { @@ -333,6 +333,7 @@ def convert_opensearch_workload(source_dir: str, output_dir: str) -> dict: # Internal helpers # --------------------------------------------------------------------------- + def _generate_configsets_from_indices(rendered_workload: dict, source_dir: str, output_dir: str, issues: list): """Generate Solr configsets from the rendered workload's indices section.""" for index in rendered_workload.get("indices", []): @@ -353,9 +354,7 @@ def _generate_configsets_from_indices(rendered_workload: dict, source_dir: str, issues.append(f"Could not generate schema for collection '{collection_name}': {exc}") -def _write_converted_workload_json( - workload_path: str, rendered_workload: dict, output_dir: str, issues: list, skipped: list -): +def _write_converted_workload_json(workload_path: str, rendered_workload: dict, output_dir: str, issues: list, skipped: list): """ Write the converted workload.json to *output_dir*, preserving Jinja2 template syntax. @@ -411,10 +410,7 @@ def _apply_inline_conversions(text: str, rendered_workload: dict, issues: list, parsed["collections"] = parsed.pop("indices") # Convert operations (filter out those skipped by _convert_operation) - parsed["operations"] = [ - op for op in parsed.get("operations", []) - if not isinstance(op, dict) or _convert_operation(op, issues, skipped, "", "") - ] + parsed["operations"] = [op for op in parsed.get("operations", []) if not isinstance(op, dict) or _convert_operation(op, issues, skipped, "", "")] # Convert challenge schedules for challenge in parsed.get("challenges", []): @@ -460,10 +456,7 @@ def _process_collected_files(source_dir: str, output_dir: str, issues: list, ski # Convert each operation in the fragment; filter out skipped ones if subdir == "operations": - ops_list = [ - op for op in ops_list - if not isinstance(op, dict) or _convert_operation(op, issues, skipped, source_dir, output_dir) - ] + ops_list = [op for op in ops_list if not isinstance(op, dict) or _convert_operation(op, issues, skipped, source_dir, output_dir)] converted_text = _serialise_jinja_fragment(ops_list, tokens, wrap_array=True) @@ -572,8 +565,7 @@ def _convert_operation(op, issues, skipped, source_dir, output_dir): aggs = body.get("aggs") or body.get("aggregations") or {} if _has_auto_date_histogram(aggs): logger.warning( - "Skipping operation '%s': auto_date_histogram is not supported in Solr " - "(Solr requires explicit gap/start/end for range facets).", + "Skipping operation '%s': auto_date_histogram is not supported in Solr (Solr requires explicit gap/start/end for range facets).", op_name, ) skipped.append(f"{op_name} (auto_date_histogram not supported in Solr)") @@ -757,7 +749,7 @@ def _convert_fragment_text(raw: str, issues: list, skipped: list) -> str: if old_op != new_op: result = re.sub( rf'(:\s*"){re.escape(old_op)}(")', - rf'\1{new_op}\2', + rf"\1{new_op}\2", result, ) return result diff --git a/solrorbit/exceptions.py b/solrorbit/exceptions.py index 80423ad6..742563d6 100644 --- a/solrorbit/exceptions.py +++ b/solrorbit/exceptions.py @@ -16,7 +16,7 @@ # not use this file except in compliance with the License. # You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an @@ -48,6 +48,7 @@ class LaunchError(BenchmarkError): Thrown whenever there was a problem launching the benchmark candidate """ + class InstallError(BenchmarkError): """ Thrown whenever there was a problem installing the benchmark candidate @@ -152,6 +153,7 @@ class MappingsError(BenchmarkError): # raise these so that worker_coordinator can record uniform error metadata. # --------------------------------------------------------------------------- + class BenchmarkTransportError(BenchmarkError): """HTTP/transport-level error from any benchmark target. diff --git a/solrorbit/log.py b/solrorbit/log.py index a8caf647..52cf2a67 100644 --- a/solrorbit/log.py +++ b/solrorbit/log.py @@ -16,7 +16,7 @@ # not use this file except in compliance with the License. # You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an diff --git a/solrorbit/metrics.py b/solrorbit/metrics.py index 070e34cd..08366e31 100644 --- a/solrorbit/metrics.py +++ b/solrorbit/metrics.py @@ -16,7 +16,7 @@ # not use this file except in compliance with the License. # You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an @@ -43,11 +43,14 @@ from solrorbit import time, exceptions, version, paths from solrorbit.utils import convert, console, io, versions from solrorbit.visualizations.benchmark_report_renderer import render_results_html + + class MetaInfoScope(Enum): """ Defines the scope of a meta-information. Meta-information provides more context for a metric, for example the concrete version of OpenSearch that has been benchmarked or environment information like CPU model or OS. """ + cluster = 1 """ Cluster level meta-information is valid for all nodes in the cluster (e.g. the benchmarked OpenSearch version) @@ -60,12 +63,8 @@ class MetaInfoScope(Enum): def calculate_results(store, test_run): calc = GlobalStatsCalculator( - store, - test_run.workload, - test_run.test_procedure, - latency_percentiles=test_run.latency_percentiles, - throughput_percentiles=test_run.throughput_percentiles - ) + store, test_run.workload, test_run.test_procedure, latency_percentiles=test_run.latency_percentiles, throughput_percentiles=test_run.throughput_percentiles + ) return calc() @@ -89,13 +88,9 @@ def metrics_store(cfg, read_only=True, workload=None, test_procedure=None, clust test_run_id = cfg.opts("system", "test_run.id") test_run_timestamp = cfg.opts("system", "time.start") - selected_cluster_config = cfg.opts("builder", "cluster_config.names") \ - if cluster_config is None else cluster_config + selected_cluster_config = cfg.opts("builder", "cluster_config.names") if cluster_config is None else cluster_config - store.open( - test_run_id, test_run_timestamp, - workload, test_procedure, selected_cluster_config, - create=not read_only) + store.open(test_run_id, test_run_timestamp, workload, test_procedure, selected_cluster_config, create=not read_only) return store @@ -171,9 +166,7 @@ def __init__(self, cfg, clock=time.Clock, meta_info=None): self._stop_watch = self._clock.stop_watch() self.logger = logging.getLogger(__name__) - def open(self, test_run_id=None, test_run_timestamp=None, workload_name=None,\ - test_procedure_name=None, cluster_config_name=None, ctx=None,\ - create=False): + def open(self, test_run_id=None, test_run_timestamp=None, workload_name=None, test_procedure_name=None, cluster_config_name=None, ctx=None, create=False): """ Opens a metrics store for a specific test_run, workload, test_procedure and cluster_config. @@ -201,13 +194,15 @@ def open(self, test_run_id=None, test_run_timestamp=None, workload_name=None,\ assert self._test_run_id is not None, "Attempting to open metrics store without a test run id" assert self._test_run_timestamp is not None, "Attempting to open metrics store without a test run timestamp" - self._cluster_config_name = "+".join(self._cluster_config) \ - if isinstance(self._cluster_config, list) \ - else self._cluster_config + self._cluster_config_name = "+".join(self._cluster_config) if isinstance(self._cluster_config, list) else self._cluster_config - self.logger.info("Opening metrics store for test run timestamp=[%s], workload=[%s]," - "test_procedure=[%s], cluster_config=[%s]", - self._test_run_timestamp, self._workload, self._test_procedure, self._cluster_config_name) + self.logger.info( + "Opening metrics store for test run timestamp=[%s], workload=[%s],test_procedure=[%s], cluster_config=[%s]", + self._test_run_timestamp, + self._workload, + self._test_procedure, + self._cluster_config_name, + ) user_tags = extract_user_tags_from_config(self._config) for k, v in user_tags.items(): @@ -264,10 +259,7 @@ def _clear_meta_info(self): """ Clears all internally stored meta-info. This is considered Solr Orbit internal API and not intended for normal client consumption. """ - self._meta_info = { - MetaInfoScope.cluster: {}, - MetaInfoScope.node: {} - } + self._meta_info = {MetaInfoScope.cluster: {}, MetaInfoScope.node: {}} @property def open_context(self): @@ -276,11 +268,12 @@ def open_context(self): "test-run-timestamp": self._test_run_timestamp, "workload": self._workload, "test_procedure": self._test_procedure, - "cluster-config-instance": self._cluster_config + "cluster-config-instance": self._cluster_config, } - def put_value_cluster_level(self, name, value, unit=None, task=None, operation=None, operation_type=None, sample_type=SampleType.Normal, - absolute_time=None, relative_time=None, meta_data=None): + def put_value_cluster_level( + self, name, value, unit=None, task=None, operation=None, operation_type=None, sample_type=SampleType.Normal, absolute_time=None, relative_time=None, meta_data=None + ): """ Adds a new cluster level value metric. @@ -297,11 +290,22 @@ def put_value_cluster_level(self, name, value, unit=None, task=None, operation=N Defaults to None. The metrics store will derive the timestamp automatically. :param meta_data: A dict, containing additional key-value pairs. Defaults to None. """ - self._put_metric(MetaInfoScope.cluster, None, name, value, unit, task, operation, operation_type, sample_type, absolute_time, - relative_time, meta_data) - - def put_value_node_level(self, node_name, name, value, unit=None, task=None, operation=None, operation_type=None, - sample_type=SampleType.Normal, absolute_time=None, relative_time=None, meta_data=None): + self._put_metric(MetaInfoScope.cluster, None, name, value, unit, task, operation, operation_type, sample_type, absolute_time, relative_time, meta_data) + + def put_value_node_level( + self, + node_name, + name, + value, + unit=None, + task=None, + operation=None, + operation_type=None, + sample_type=SampleType.Normal, + absolute_time=None, + relative_time=None, + meta_data=None, + ): """ Adds a new node level value metric. @@ -319,11 +323,9 @@ def put_value_node_level(self, node_name, name, value, unit=None, task=None, ope Defaults to None. The metrics store will derive the timestamp automatically. :param meta_data: A dict, containing additional key-value pairs. Defaults to None. """ - self._put_metric(MetaInfoScope.node, node_name, name, value, unit, task, operation, operation_type, sample_type, absolute_time, - relative_time, meta_data) + self._put_metric(MetaInfoScope.node, node_name, name, value, unit, task, operation, operation_type, sample_type, absolute_time, relative_time, meta_data) - def _put_metric(self, level, level_key, name, value, unit, task, operation, operation_type, sample_type, absolute_time=None, - relative_time=None, meta_data=None): + def _put_metric(self, level, level_key, name, value, unit, task, operation, operation_type, sample_type, absolute_time=None, relative_time=None, meta_data=None): if level == MetaInfoScope.cluster: meta = self._meta_info[MetaInfoScope.cluster].copy() elif level == MetaInfoScope.node: @@ -353,7 +355,7 @@ def _put_metric(self, level, level_key, name, value, unit, task, operation, oper "value": value, "unit": unit, "sample-type": sample_type.name.lower(), - "meta": meta + "meta": meta, } if task: doc["task"] = task @@ -397,17 +399,18 @@ def put_doc(self, doc, level=None, node_name=None, meta_data=None, absolute_time if relative_time is None: relative_time = self._stop_watch.split_time() - doc.update({ - "@timestamp": time.to_epoch_millis(absolute_time), - "relative-time-ms": convert.seconds_to_ms(relative_time), - "test-run-id": self._test_run_id, - "test-run-timestamp": self._test_run_timestamp, - "environment": self._environment_name, - "workload": self._workload, - "test_procedure": self._test_procedure, - "cluster-config-instance": self._cluster_config_name, - - }) + doc.update( + { + "@timestamp": time.to_epoch_millis(absolute_time), + "relative-time-ms": convert.seconds_to_ms(relative_time), + "test-run-id": self._test_run_id, + "test-run-timestamp": self._test_run_timestamp, + "environment": self._environment_name, + "workload": self._workload, + "test_procedure": self._test_procedure, + "cluster-config-instance": self._cluster_config_name, + } + ) if meta: doc["meta"] = meta if self._workload_params: @@ -437,8 +440,7 @@ def _add(self, doc): """ raise NotImplementedError("abstract method") - def get_one(self, name, sample_type=None, node_name=None, task=None, mapper=lambda doc: doc["value"], - sort_key=None, sort_reverse=False): + def get_one(self, name, sample_type=None, node_name=None, task=None, mapper=lambda doc: doc["value"], sort_key=None, sort_reverse=False): """ Gets one value for the given metric name (even if there should be more than one). @@ -564,6 +566,8 @@ def get_mean(self, name, task=None, operation_type=None, sample_type=None): """ stats = self.get_stats(name, task, operation_type, sample_type) return stats["avg"] if stats else None + + class InMemoryMetricsStore(MetricsStore): # Note that this implementation can run out of memory; generally, this can occur when ingesting very large corpora. @@ -608,7 +612,6 @@ def _add(self, doc): self.docs.append(doc) self.doc_count += 1 - def flush(self, refresh=True): pass @@ -619,12 +622,10 @@ def to_externalizable(self, clear=False): self.doc_count = 0 self.out_of_memory = False if len(docs) * self.DOC_SIZE_IN_BYTES > psutil.virtual_memory().available - self.memory_available_threshold: - console.warn("Memory threshold exceeded by in-memory metrics store, skipping summary generation for current operation", - logger=self.logger) + console.warn("Memory threshold exceeded by in-memory metrics store, skipping summary generation for current operation", logger=self.logger) return None compressed = zlib.compress(pickle.dumps(docs)) - self.logger.debug("Compression changed size of metric store from [%d] bytes to [%d] bytes", - sys.getsizeof(docs, -1), sys.getsizeof(compressed, -1)) + self.logger.debug("Compression changed size of metric store from [%d] bytes to [%d] bytes", sys.getsizeof(docs, -1), sys.getsizeof(compressed, -1)) return compressed def get_percentiles(self, name, task=None, operation_type=None, sample_type=None, percentiles=None): @@ -665,9 +666,12 @@ def get_error_rate(self, task, operation_type=None, sample_type=None): total_count = 0 for doc in self.docs: # we can use any request metrics record (i.e. service time or latency) - if doc["name"] == "service_time" and doc["task"] == task and \ - (operation_type is None or doc["operation-type"] == operation_type) and \ - (sample_type is None or doc["sample-type"] == sample_type.name.lower()): + if ( + doc["name"] == "service_time" + and doc["task"] == task + and (operation_type is None or doc["operation-type"] == operation_type) + and (sample_type is None or doc["sample-type"] == sample_type.name.lower()) + ): total_count += 1 if doc["meta"]["success"] is False: error += 1 @@ -680,36 +684,33 @@ def get_stats(self, name, task=None, operation_type=None, sample_type=SampleType values = self.get(name, task, operation_type, sample_type) sorted_values = sorted(values) if len(sorted_values) > 0: - return { - "count": len(sorted_values), - "min": sorted_values[0], - "max": sorted_values[-1], - "avg": statistics.mean(sorted_values), - "sum": sum(sorted_values) - } + return {"count": len(sorted_values), "min": sorted_values[0], "max": sorted_values[-1], "avg": statistics.mean(sorted_values), "sum": sum(sorted_values)} else: return None def _get(self, name, task, operation_type, sample_type, node_name, mapper): - return [mapper(doc) - for doc in self.docs - if doc["name"] == name and - (task is None or doc["task"] == task) and - (operation_type is None or doc["operation-type"] == operation_type) and - (sample_type is None or doc["sample-type"] == sample_type.name.lower()) and - (node_name is None or doc.get("meta", {}).get("node_name") == node_name) - ] - - def get_one(self, name, sample_type=None, node_name=None, task=None, mapper=lambda doc: doc["value"], - sort_key=None, sort_reverse=False): + return [ + mapper(doc) + for doc in self.docs + if doc["name"] == name + and (task is None or doc["task"] == task) + and (operation_type is None or doc["operation-type"] == operation_type) + and (sample_type is None or doc["sample-type"] == sample_type.name.lower()) + and (node_name is None or doc.get("meta", {}).get("node_name") == node_name) + ] + + def get_one(self, name, sample_type=None, node_name=None, task=None, mapper=lambda doc: doc["value"], sort_key=None, sort_reverse=False): if sort_key: docs = sorted(self.docs, key=lambda k: k[sort_key], reverse=sort_reverse) else: docs = self.docs for doc in docs: - if (doc["name"] == name and (task is None or doc["task"] == task) and - (sample_type is None or doc["sample-type"] == sample_type.name.lower()) and - (node_name is None or doc.get("meta", {}).get("node_name") == node_name)): + if ( + doc["name"] == name + and (task is None or doc["task"] == task) + and (sample_type is None or doc["sample-type"] == sample_type.name.lower()) + and (node_name is None or doc.get("meta", {}).get("node_name") == node_name) + ): return mapper(doc) return None @@ -724,10 +725,8 @@ def __init__(self, cfg, clock=time.Clock, meta_info=None): super().__init__(cfg=cfg, clock=clock, meta_info=meta_info) self._metrics_file = None - def open(self, test_run_id=None, test_run_timestamp=None, workload_name=None, - test_procedure_name=None, cluster_config_name=None, ctx=None, create=False): - super().open(test_run_id, test_run_timestamp, workload_name, - test_procedure_name, cluster_config_name, ctx, create) + def open(self, test_run_id=None, test_run_timestamp=None, workload_name=None, test_procedure_name=None, cluster_config_name=None, ctx=None, create=False): + super().open(test_run_id, test_run_timestamp, workload_name, test_procedure_name, cluster_config_name, ctx, create) if create: run_dir = paths.test_run_root(self._config, test_run_id=self._test_run_id) io.ensure_dir(run_dir) @@ -787,42 +786,51 @@ def format_dict(d): test_runs = [] for test_run in store_item: - test_runs.append([ - test_run.test_run_id, - time.to_iso8601(test_run.test_run_timestamp), - test_run.workload, - format_dict(test_run.workload_params), - test_run.test_procedure_name, - test_run.cluster_config_name, - format_dict(test_run.user_tags), - test_run.workload_revision, - test_run.cluster_config_revision]) + test_runs.append( + [ + test_run.test_run_id, + time.to_iso8601(test_run.test_run_timestamp), + test_run.workload, + format_dict(test_run.workload_params), + test_run.test_procedure_name, + test_run.cluster_config_name, + format_dict(test_run.user_tags), + test_run.workload_revision, + test_run.cluster_config_revision, + ] + ) if len(test_runs) > 0: console.println(f"\nRecent {title}:\n") - console.println(tabulate.tabulate( - test_runs, - headers=[ - "TestRun ID", - "TestRun Timestamp", - "Workload", - "Workload Parameters", - "TestProcedure", - "ClusterConfigInstance", - "User Tags", - "workload Revision", - "Cluster Config Revision" - ])) + console.println( + tabulate.tabulate( + test_runs, + headers=[ + "TestRun ID", + "TestRun Timestamp", + "Workload", + "Workload Parameters", + "TestProcedure", + "ClusterConfigInstance", + "User Tags", + "workload Revision", + "Cluster Config Revision", + ], + ) + ) else: console.println("") console.println(f"No recent {title} found.") + def list_test_runs(cfg): list_test_helper(test_run_store(cfg).list(), "test-runs") + def list_aggregated_results(cfg): list_test_helper(test_run_store(cfg).list_aggregations(), "aggregated-results") + def create_test_run(cfg, workload, test_procedure, workload_revision=None): cluster_config = cfg.opts("builder", "cluster_config.names") environment = cfg.opts("system", "env.name") @@ -835,33 +843,62 @@ def create_test_run(cfg, workload, test_procedure, workload_revision=None): plugin_params = cfg.opts("builder", "plugin.params") benchmark_version = version.version() benchmark_revision = version.revision() - latency_percentiles = cfg.opts("workload", "latency.percentiles", mandatory=False, - default_value=GlobalStatsCalculator.DEFAULT_LATENCY_PERCENTILES) - throughput_percentiles = cfg.opts("workload", "throughput.percentiles", mandatory=False, - default_value=GlobalStatsCalculator.DEFAULT_THROUGHPUT_PERCENTILES) + latency_percentiles = cfg.opts("workload", "latency.percentiles", mandatory=False, default_value=GlobalStatsCalculator.DEFAULT_LATENCY_PERCENTILES) + throughput_percentiles = cfg.opts("workload", "throughput.percentiles", mandatory=False, default_value=GlobalStatsCalculator.DEFAULT_THROUGHPUT_PERCENTILES) # In tests, we don't get the default command-line arg value for percentiles, # so supply them as defaults here as well # Get cluster_config_instance if available (stored during provisioning) cluster_config_instance = cfg.opts("builder", "cluster_config.instance", mandatory=False, default_value=None) - return TestRun(benchmark_version, benchmark_revision, - environment, test_run_id, test_run_timestamp, - pipeline, user_tags, workload, - workload_params, test_procedure, cluster_config, cluster_config_params, - plugin_params, workload_revision, latency_percentiles=latency_percentiles, - throughput_percentiles=throughput_percentiles, cluster_config_instance=cluster_config_instance) + return TestRun( + benchmark_version, + benchmark_revision, + environment, + test_run_id, + test_run_timestamp, + pipeline, + user_tags, + workload, + workload_params, + test_procedure, + cluster_config, + cluster_config_params, + plugin_params, + workload_revision, + latency_percentiles=latency_percentiles, + throughput_percentiles=throughput_percentiles, + cluster_config_instance=cluster_config_instance, + ) class TestRun: - def __init__(self, benchmark_version, benchmark_revision, environment_name, - test_run_id, test_run_timestamp, pipeline, user_tags, - workload, workload_params, test_procedure, cluster_config, - cluster_config_params, plugin_params, - workload_revision=None, cluster_config_revision=None, - distribution_version=None, distribution_flavor=None, - revision=None, results=None, meta_data=None, latency_percentiles=None, throughput_percentiles=None, - cluster_config_instance=None): + def __init__( + self, + benchmark_version, + benchmark_revision, + environment_name, + test_run_id, + test_run_timestamp, + pipeline, + user_tags, + workload, + workload_params, + test_procedure, + cluster_config, + cluster_config_params, + plugin_params, + workload_revision=None, + cluster_config_revision=None, + distribution_version=None, + distribution_flavor=None, + revision=None, + results=None, + meta_data=None, + latency_percentiles=None, + throughput_percentiles=None, + cluster_config_instance=None, + ): if results is None: results = {} # this happens when the test run is created initially @@ -900,7 +937,6 @@ def __init__(self, benchmark_version, benchmark_revision, environment_name, self.throughput_percentiles = throughput_percentiles self.cluster_config_instance = cluster_config_instance - @property def workload_name(self): return str(self.workload) @@ -911,9 +947,7 @@ def test_procedure_name(self): @property def cluster_config_name(self): - return "+".join(self.cluster_config) \ - if isinstance(self.cluster_config, list) \ - else self.cluster_config + return "+".join(self.cluster_config) if isinstance(self.cluster_config, list) else self.cluster_config def add_results(self, results): self.results = results @@ -937,7 +971,7 @@ def as_dict(self): "distribution-version": self.distribution_version, "distribution-flavor": self.distribution_flavor, "cluster-config-revision": self.cluster_config_revision, - } + }, } if self.results: # if results was loaded from JSON it’s already a dict @@ -968,6 +1002,7 @@ def as_dict(self): "flavor": str(self.cluster_config_instance.flavor) if hasattr(self.cluster_config_instance, "flavor") else None, } return d + def to_result_dicts(self): """ :return: a list of dicts, suitable for persisting the results of this test run in a format that is Kibana-friendly. @@ -985,7 +1020,7 @@ def to_result_dicts(self): "test_procedure": self.test_procedure_name, "cluster-config-instance": self.cluster_config_name, # allow to logically delete records, e.g. for UI purposes when we only want to show the latest result - "active": True + "active": True, } if self.distribution_version: result_template["distribution-major-version"] = versions.major_version(self.distribution_version) @@ -1016,16 +1051,28 @@ def from_dict(cls, d): user_tags = d.get("user-tags", {}) # TODO: cluster is optional for BWC. This can be removed after some grace period. cluster = d.get("cluster", {}) - return TestRun(d["benchmark-version"], d.get("benchmark-revision"), d["environment"], d["test-run-id"], - time.from_is8601(d["test-run-timestamp"]), - d["pipeline"], user_tags, d["workload"], d.get("workload-params"), - d.get("test_procedure"), d["cluster-config-instance"], - d.get("cluster-config-instance-params"), d.get("plugin-params"), - workload_revision=d.get("workload-revision"), - cluster_config_revision=cluster.get("cluster-config-revision"), - distribution_version=cluster.get("distribution-version"), - distribution_flavor=cluster.get("distribution-flavor"), - revision=cluster.get("revision"), results=d.get("results"), meta_data=d.get("meta", {})) + return TestRun( + d["benchmark-version"], + d.get("benchmark-revision"), + d["environment"], + d["test-run-id"], + time.from_is8601(d["test-run-timestamp"]), + d["pipeline"], + user_tags, + d["workload"], + d.get("workload-params"), + d.get("test_procedure"), + d["cluster-config-instance"], + d.get("cluster-config-instance-params"), + d.get("plugin-params"), + workload_revision=d.get("workload-revision"), + cluster_config_revision=cluster.get("cluster-config-revision"), + distribution_version=cluster.get("distribution-version"), + distribution_flavor=cluster.get("distribution-flavor"), + revision=cluster.get("revision"), + results=d.get("results"), + meta_data=d.get("meta", {}), + ) class TestRunStore: @@ -1056,6 +1103,7 @@ class CompositeTestRunStore: Not wired into any active code path. Does not inherit from TestRunStore — it is a delegator with the same API. """ + def __init__(self, external_store, file_store): self.external_store = external_store self.file_store = file_store @@ -1080,6 +1128,7 @@ class FileTestRunStore(TestRunStore): def __init__(self, cfg): super().__init__(cfg) self._max_results = lambda: int(cfg.opts("system", "list.test_runs.max_results")) + def store_test_run(self, test_run): open_browser = False doc = test_run.as_dict() @@ -1114,7 +1163,7 @@ def store_html_results(self, test_run, open_browser=True): dest = os.path.expanduser(custom_output_path) os.makedirs(os.path.dirname(dest), exist_ok=True) print("[DEBUG]: ", dest) - with open(dest, 'w', encoding='utf-8') as f: + with open(dest, "w", encoding="utf-8") as f: f.write(html_content) console.info(f"HTML report saved to: {dest}") if open_browser: @@ -1133,22 +1182,21 @@ def store_html_results(self, test_run, open_browser=True): def _test_run_file(self, test_run_id=None, is_aggregated=False): if is_aggregated: - return os.path.join(paths.aggregated_results_root(cfg=self.cfg, test_run_id=test_run_id), - "aggregated_test_run.json") + return os.path.join(paths.aggregated_results_root(cfg=self.cfg, test_run_id=test_run_id), "aggregated_test_run.json") else: return os.path.join(paths.test_run_root(cfg=self.cfg, test_run_id=test_run_id), "test_run.json") def list(self): results = glob.glob(self._test_run_file(test_run_id="*")) all_test_runs = self._to_test_runs(results) - return all_test_runs[:self._max_results()] + return all_test_runs[: self._max_results()] def list_aggregations(self): aggregated_results = glob.glob(self._test_run_file(test_run_id="*", is_aggregated=True)) return self._to_test_runs(aggregated_results) def find_by_test_run_id(self, test_run_id): - is_aggregated = test_run_id.startswith('aggregate') + is_aggregated = test_run_id.startswith("aggregate") test_run_file = self._test_run_file(test_run_id=test_run_id, is_aggregated=is_aggregated) if io.exists(test_run_file): test_runs = self._to_test_runs([test_run_file]) @@ -1166,10 +1214,13 @@ def _to_test_runs(self, results): except BaseException: logging.getLogger(__name__).exception("Could not load test_run file [%s] (incompatible format?) Skipping...", result) return sorted(test_runs, key=lambda r: r.test_run_timestamp, reverse=True) + + class NoopResultsStore: """ Does not store any results separately as these are stored as part of the test_run on the file system. """ + def store_results(self, test_run): pass @@ -1200,27 +1251,29 @@ def filter_percentiles_by_sample_size(sample_size, percentiles): if p in percentiles: filtered_percentiles.append(p) else: - effective_sample_size = 10 ** (int(math.log10(sample_size))) # round down to nearest power of ten - delta = 0.000001 # If (p / 100) * effective_sample_size is within this value of a whole number, + effective_sample_size = 10 ** (int(math.log10(sample_size))) # round down to nearest power of ten + delta = 0.000001 # If (p / 100) * effective_sample_size is within this value of a whole number, # assume the discrepancy is due to floating point and allow it for p in percentiles: fraction = p / 100 # check if fraction * effective_sample_size is close enough to a whole number - if abs((effective_sample_size * fraction) - round(effective_sample_size*fraction)) < delta or p in [25, 75]: + if abs((effective_sample_size * fraction) - round(effective_sample_size * fraction)) < delta or p in [25, 75]: filtered_percentiles.append(p) # if no percentiles are suitable, just return 100 if len(filtered_percentiles) == 0: return [100] return filtered_percentiles + def percentiles_for_sample_size(sample_size, percentiles_list=None): # If latency_percentiles is present, as a list, display those values instead (assuming there are enough samples) percentiles = [] if percentiles_list: - percentiles = percentiles_list # Defaults get overridden if a value is provided + percentiles = percentiles_list # Defaults get overridden if a value is provided percentiles.sort() return filter_percentiles_by_sample_size(sample_size, percentiles) + class GlobalStatsCalculator: DEFAULT_LATENCY_PERCENTILES = "50,90,99,99.9,99.99,100" DEFAULT_LATENCY_PERCENTILES_LIST = [float(value) for value in DEFAULT_LATENCY_PERCENTILES.split(",")] @@ -1228,7 +1281,7 @@ class GlobalStatsCalculator: DEFAULT_THROUGHPUT_PERCENTILES = "" DEFAULT_THROUGHPUT_PERCENTILES_LIST = [] - OTHER_PERCENTILES = [50,90,99,99.9,99.99,100] + OTHER_PERCENTILES = [50, 90, 99, 99.9, 99.99, 100] # Use these percentiles when the single_latency fn is called for something other than latency def __init__(self, store, workload, test_procedure, latency_percentiles=None, throughput_percentiles=None): @@ -1275,17 +1328,13 @@ def __call__(self): self.single_latency(task_name, op_type, metric_name="recall@k"), self.single_latency(task_name, op_type, metric_name="recall@1"), error_rate, - duration + duration, ) profile_metrics = task.operation.params.get("profile-metrics", None) if profile_metrics: profile_metrics.append("query_time") - result.add_profile_metrics( - task_name, - task.operation.name, - {name: self.single_latency(task_name, op_type, metric_name=name) for name in profile_metrics} - ) + result.add_profile_metrics(task_name, task.operation.name, {name: self.single_latency(task_name, op_type, metric_name=name) for name in profile_metrics}) self.logger.debug("Gathering indexing metrics.") result.total_time = self.sum("indexing_total_time") @@ -1361,31 +1410,19 @@ def summary_stats(self, metric_name, task_name, operation_type, percentiles_list result = {} if mean and median and stats: - result = { - "min": stats["min"], - "mean": mean, - "median": median, - "max": stats["max"], - "unit": unit - } + result = {"min": stats["min"], "mean": mean, "median": median, "max": stats["max"], "unit": unit} else: - result = { - "min": None, - "mean": None, - "median": None, - "max": None, - "unit": unit - } + result = {"min": None, "mean": None, "median": None, "max": None, "unit": unit} - if percentiles_list: # modified from single_latency() + if percentiles_list: # modified from single_latency() sample_size = stats["count"] - percentiles = self.store.get_percentiles(metric_name, - task=task_name, - operation_type=operation_type, - sample_type=SampleType.Normal, - percentiles=percentiles_for_sample_size( - sample_size, - percentiles_list=percentiles_list)) + percentiles = self.store.get_percentiles( + metric_name, + task=task_name, + operation_type=operation_type, + sample_type=SampleType.Normal, + percentiles=percentiles_for_sample_size(sample_size, percentiles_list=percentiles_list), + ) for k, v in percentiles.items(): # safely encode so we don't have any dots in field names result[encode_float_key(k)] = v @@ -1396,12 +1433,7 @@ def shard_stats(self, metric_name): unit = self.store.get_unit(metric_name) if values: flat_values = [w for v in values for w in v] - return { - "min": min(flat_values), - "median": statistics.median(flat_values), - "max": max(flat_values), - "unit": unit - } + return {"min": min(flat_values), "median": statistics.median(flat_values), "max": max(flat_values), "unit": unit} else: return {} @@ -1410,14 +1442,7 @@ def ml_processing_time_stats(self): result = [] if values: for v in values: - result.append({ - "job": v["job"], - "min": v["min"], - "mean": v["mean"], - "median": v["median"], - "max": v["max"], - "unit": v["unit"] - }) + result.append({"job": v["job"], "min": v["min"], "mean": v["mean"], "median": v["median"], "max": v["max"], "unit": v["unit"]}) return result def total_transform_metric(self, metric_name): @@ -1427,19 +1452,14 @@ def total_transform_metric(self, metric_name): for v in values: transform_id = v.get("meta", {}).get("transform_id") if transform_id is not None: - result.append({ - "id": transform_id, - "mean": v["value"], - "unit": v["unit"] - }) + result.append({"id": transform_id, "mean": v["value"], "unit": v["unit"]}) return result def error_rate(self, task_name, operation_type): return self.store.get_error_rate(task=task_name, operation_type=operation_type, sample_type=SampleType.Normal) def duration(self, task_name): - return self.store.get_one("service_time", task=task_name, mapper=lambda doc: doc["relative-time-ms"], - sort_key="relative-time-ms", sort_reverse=True) + return self.store.get_one("service_time", task=task_name, mapper=lambda doc: doc["relative-time-ms"], sort_key="relative-time-ms", sort_reverse=True) def median(self, metric_name, task_name=None, operation_type=None, sample_type=None): return self.store.get_median(metric_name, task=task_name, operation_type=operation_type, sample_type=sample_type) @@ -1454,18 +1474,14 @@ def single_latency(self, task, operation_type, metric_name="latency"): if sample_size > 0: # The custom latency percentiles have to be supplied here as the workload runs, # or else they aren't present when results are published - percentiles = self.store.get_percentiles(metric_name, - task=task, - operation_type=operation_type, - sample_type=sample_type, - percentiles=percentiles_for_sample_size( - sample_size, - percentiles_list=percentiles_list - )) - mean = self.store.get_mean(metric_name, - task=task, - operation_type=operation_type, - sample_type=sample_type) + percentiles = self.store.get_percentiles( + metric_name, + task=task, + operation_type=operation_type, + sample_type=sample_type, + percentiles=percentiles_for_sample_size(sample_size, percentiles_list=percentiles_list), + ) + mean = self.store.get_mean(metric_name, task=task, operation_type=operation_type, sample_type=sample_type) unit = self.store.get_unit(metric_name, task=task, operation_type=operation_type) stats = collections.OrderedDict() for k, v in percentiles.items(): @@ -1525,13 +1541,9 @@ def as_dict(self): def as_flat_list(self): def op_metrics(op_item, key, single_value=False): - doc = { - "task": op_item["task"], - "operation": op_item["operation"], - "name": key - } + doc = {"task": op_item["task"], "operation": op_item["operation"], "name": key} if single_value: - doc["value"] = {"single": op_item[key]} + doc["value"] = {"single": op_item[key]} else: doc["value"] = op_item[key] if "meta" in op_item: @@ -1558,55 +1570,27 @@ def op_metrics(op_item, key, single_value=False): all_results.append(op_metrics(item, "duration", single_value=True)) elif metric == "ml_processing_time": for item in value: - all_results.append({ - "job": item["job"], - "name": "ml_processing_time", - "value": { - "min": item["min"], - "mean": item["mean"], - "median": item["median"], - "max": item["max"] - } - }) + all_results.append( + {"job": item["job"], "name": "ml_processing_time", "value": {"min": item["min"], "mean": item["mean"], "median": item["median"], "max": item["max"]}} + ) elif metric == "correctness_metrics": for item in value: for knn_metric in ["recall@k", "recall@1"]: if knn_metric in item: - all_results.append({ - "task": item["task"], - "operation": item["operation"], - "name": knn_metric, - "value": item[knn_metric] - }) + all_results.append({"task": item["task"], "operation": item["operation"], "name": knn_metric, "value": item[knn_metric]}) elif metric == "profile_metrics": for item in value: for metric_name in item.keys(): if metric_name not in ["task", "operation", "error_rate", "duration"]: - all_results.append({ - "task": item["task"], - "operation": item["operation"], - "name": metric_name, - "value": item[metric_name] - }) + all_results.append({"task": item["task"], "operation": item["operation"], "name": metric_name, "value": item[metric_name]}) elif metric.startswith("total_transform_") and value is not None: for item in value: - all_results.append({ - "id": item["id"], - "name": metric, - "value": { - "single": item["mean"] - } - }) + all_results.append({"id": item["id"], "name": metric, "value": {"single": item["mean"]}}) elif metric.endswith("_time_per_shard"): if value: all_results.append({"name": metric, "value": value}) elif value is not None: - result = { - "name": metric, - "value": { - "single": value - } - } + result = {"name": metric, "value": {"single": value}} all_results.append(result) # sorting is just necessary to have a stable order for tests. As we just have a small number of metrics, the overhead is neglible. return sorted(all_results, key=lambda m: m["name"]) @@ -1614,8 +1598,7 @@ def op_metrics(op_item, key, single_value=False): def v(self, d, k, default=None): return d.get(k, default) if d else default - def add_op_metrics(self, task, operation, throughput, latency, service_time, client_processing_time, - processing_time, error_rate, duration, meta): + def add_op_metrics(self, task, operation, throughput, latency, service_time, client_processing_time, processing_time, error_rate, duration, meta): doc = { "task": task, "operation": operation, @@ -1625,28 +1608,26 @@ def add_op_metrics(self, task, operation, throughput, latency, service_time, cli "client_processing_time": client_processing_time, "processing_time": processing_time, "error_rate": error_rate, - "duration": duration + "duration": duration, } if meta: doc["meta"] = meta self.op_metrics.append(doc) def add_correctness_metrics(self, task, operation, recall_at_k_stats, recall_at_1_stats, error_rate, duration): - self.correctness_metrics.append({ - "task": task, - "operation": operation, - "recall@k": recall_at_k_stats, - "recall@1":recall_at_1_stats, - "error_rate": error_rate, - "duration": duration, - }) + self.correctness_metrics.append( + { + "task": task, + "operation": operation, + "recall@k": recall_at_k_stats, + "recall@1": recall_at_1_stats, + "error_rate": error_rate, + "duration": duration, + } + ) def add_profile_metrics(self, task, operation, profile_metrics): - self.profile_metrics.append({ - "task": task, - "operation": operation, - "metrics": profile_metrics - }) + self.profile_metrics.append({"task": task, "operation": operation, "metrics": profile_metrics}) def tasks(self): # ensure we can read test_run.json files before Solr Orbit 0.8.0 @@ -1694,11 +1675,7 @@ def v(self, d, k, default=None): return d.get(k, default) if d else default def add_node_metrics(self, node, name, value, unit): - metric = { - "node": node, - "name": name, - "value": value - } + metric = {"node": node, "name": name, "value": value} if unit: metric["unit"] = unit self.node_metrics.append(metric) diff --git a/solrorbit/paths.py b/solrorbit/paths.py index c476d3b9..eeb19917 100644 --- a/solrorbit/paths.py +++ b/solrorbit/paths.py @@ -16,7 +16,7 @@ # not use this file except in compliance with the License. # You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an @@ -27,12 +27,14 @@ import os from solrorbit.utils.io import ensure_dir + def benchmark_confdir(): default_home = os.path.expanduser("~") benchmark_confdir_path = os.path.join(os.getenv("BENCHMARK_HOME", default_home), ".solr-orbit") ensure_dir(benchmark_confdir_path) return benchmark_confdir_path + def benchmark_root(): return os.path.dirname(os.path.realpath(__file__)) @@ -46,11 +48,13 @@ def test_run_root(cfg, test_run_id=None): test_run_id = cfg.opts("system", "test_run.id") return os.path.join(test_runs_root(cfg), test_run_id) + def aggregated_results_root(cfg, test_run_id=None): if not test_run_id: test_run_id = cfg.opts("system", "test_run.id") return os.path.join(cfg.opts("node", "root.dir"), "aggregated_results", test_run_id) + def install_root(cfg=None): install_id = cfg.opts("system", "install.id") return os.path.join(test_runs_root(cfg), install_id) diff --git a/solrorbit/publisher.py b/solrorbit/publisher.py index af8ae5b2..2852f675 100644 --- a/solrorbit/publisher.py +++ b/solrorbit/publisher.py @@ -16,7 +16,7 @@ # not use this file except in compliance with the License. # You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an @@ -49,23 +49,23 @@ ------------------------------------------------------ """ + class Throughput(Enum): MEAN = "mean" MAX = "max" MIN = "min" MEDIAN = "median" + def summarize(results, cfg): - SummaryResultsPublisher(results, cfg).publish() # check-deprecated-terms-disable-1x + SummaryResultsPublisher(results, cfg).publish() # check-deprecated-terms-disable-1x def compare(cfg, baseline_id, contender_id): if not baseline_id or not contender_id: raise exceptions.SystemSetupError("compare needs baseline and a contender") test_run_store = metrics.test_run_store(cfg) - ComparisonPublisher(cfg).publish( - test_run_store.find_by_test_run_id(baseline_id), - test_run_store.find_by_test_run_id(contender_id)) + ComparisonPublisher(cfg).publish(test_run_store.find_by_test_run_id(baseline_id), test_run_store.find_by_test_run_id(contender_id)) def print_internal(message): @@ -106,6 +106,7 @@ def format_as_csv(headers, data): writer.writerow(metric_record) return out.getvalue() + def comma_separated_string_to_number_list(string_list): # Split a comma-separated list in a string to a list of numbers. If they are whole numbers, make them ints, # so they display without decimals. @@ -126,25 +127,21 @@ def __init__(self, results, config): self.config = config self.results_file = config.opts("reporting", "output.path") self.results_format = config.opts("reporting", "format") - self.numbers_align = config.opts("reporting", "numbers.align", - mandatory=False, default_value="right") + self.numbers_align = config.opts("reporting", "numbers.align", mandatory=False, default_value="right") reporting_values = config.opts("reporting", "values") self.publish_all_values = reporting_values == "all" self.publish_all_percentile_values = reporting_values == "all-percentiles" - self.show_processing_time = convert.to_bool(config.opts("reporting", "output.processingtime", - mandatory=False, default_value=False)) + self.show_processing_time = convert.to_bool(config.opts("reporting", "output.processingtime", mandatory=False, default_value=False)) self.cwd = config.opts("node", "benchmark.cwd") self.display_percentiles = { - "throughput":comma_separated_string_to_number_list(config.opts("workload", "throughput.percentiles", mandatory=False)), - "latency": comma_separated_string_to_number_list(config.opts("workload", "latency.percentiles", mandatory=False)) + "throughput": comma_separated_string_to_number_list(config.opts("workload", "throughput.percentiles", mandatory=False)), + "latency": comma_separated_string_to_number_list(config.opts("workload", "latency.percentiles", mandatory=False)), } self.logger = logging.getLogger(__name__) writer_name = config.opts("reporting", "results_writer", mandatory=False, default_value=None) results_path = config.opts("reporting", "results_path", mandatory=False, default_value=None) if writer_name and results_path: - self._result_writer = result_writer.create_writer( - writer_name, results_path=rio.normalize_path(results_path) - ) + self._result_writer = result_writer.create_writer(writer_name, results_path=rio.normalize_path(results_path)) else: self._result_writer = None @@ -182,7 +179,6 @@ def publish(self): throughput_pattern = r"_(\d+)_clients$" - for record in stats.op_metrics: task = record["task"] is_task_part_of_throughput_testing = re.search(throughput_pattern, task) @@ -200,8 +196,7 @@ def publish(self): # The following code is run when the clients_list parameter is specified and publishes the max throughput. if max_throughput != -1 and record_with_best_throughput is not None: - self.publish_operational_statistics(metrics_table=metrics_table, warnings=warnings, record=record_with_best_throughput, - task=record_with_best_throughput["task"]) + self.publish_operational_statistics(metrics_table=metrics_table, warnings=warnings, record=record_with_best_throughput, task=record_with_best_throughput["task"]) metrics_table.extend(self._publish_best_client_settings(record_with_best_throughput, record_with_best_throughput["task"])) for record in stats.correctness_metrics: @@ -228,21 +223,16 @@ def add_warnings(self, warnings, values, op): if values["throughput"]["median"] is None: error_rate = values["error_rate"] if error_rate: - warnings.append("No throughput metrics available for [%s]. Likely cause: Error rate is %.1f%%. Please check the logs." - % (op, error_rate * 100)) + warnings.append("No throughput metrics available for [%s]. Likely cause: Error rate is %.1f%%. Please check the logs." % (op, error_rate * 100)) else: warnings.append("No throughput metrics available for [%s]. Likely cause: The benchmark ended already during warmup." % op) def write_results(self, metrics_table): - write_single_results(self.results_file, self.results_format, self.cwd, self.numbers_align, - headers=["Metric", "Task", "Value", "Unit"], - data_plain=metrics_table, - data_rich=metrics_table) + write_single_results( + self.results_file, self.results_format, self.cwd, self.numbers_align, headers=["Metric", "Task", "Value", "Unit"], data_plain=metrics_table, data_rich=metrics_table + ) if self._result_writer is not None: - structured = [ - {"name": row[0], "task": row[1], "value": row[2], "unit": row[3] or ""} - for row in metrics_table if row - ] + structured = [{"name": row[0], "task": row[1], "value": row[2], "unit": row[3] or ""} for row in metrics_table if row] # Build run metadata with run_id from config and any other stats metadata run_metadata = self.results.as_dict() if hasattr(self.results, "as_dict") else {} run_metadata["run_id"] = self.config.opts("system", "test_run.id") @@ -265,7 +255,7 @@ def _publish_throughput(self, values, task): self._line("Mean Throughput", task, throughput["mean"], unit, lambda v: "%.2f" % v), self._line("Median Throughput", task, throughput["median"], unit, lambda v: "%.2f" % v), self._line("Max Throughput", task, throughput["max"], unit, lambda v: "%.2f" % v), - *self._publish_percentiles("throughput", task, throughput) + *self._publish_percentiles("throughput", task, throughput), ) def _publish_latency(self, values, task): @@ -281,17 +271,12 @@ def _publish_recall(self, values, task): recall_k_mean = values["recall@k"]["mean"] recall_1_mean = values["recall@1"]["mean"] - return self._join( - self._line("Mean recall@k", task, recall_k_mean, "", lambda v: "%.2f" % v), - self._line("Mean recall@1", task, recall_1_mean, "", lambda v: "%.2f" % v) - ) + return self._join(self._line("Mean recall@k", task, recall_k_mean, "", lambda v: "%.2f" % v), self._line("Mean recall@1", task, recall_1_mean, "", lambda v: "%.2f" % v)) def _publish_profile_metrics(self, metrics, task): percentiles = [self._publish_percentiles(key, task, value) for key, value in metrics.items()] - return self._join( - *[item for percentile in percentiles for item in percentile] - ) + return self._join(*[item for percentile in percentiles for item in percentile]) def _publish_best_client_settings(self, record, task): num_clients = re.search(r"_(\d+)_clients$", task).group(1) @@ -304,15 +289,12 @@ def _publish_percentiles(self, name, task, value, unit="ms"): if value: for percentile in metrics.percentiles_for_sample_size(sys.maxsize, percentiles_list=percentiles): percentile_value = value.get(metrics.encode_float_key(percentile)) - a_line = self._line("%sth percentile %s" % (percentile, name), task, percentile_value, unit, - force=self.publish_all_percentile_values) + a_line = self._line("%sth percentile %s" % (percentile, name), task, percentile_value, unit, force=self.publish_all_percentile_values) self._append_non_empty(lines, a_line) return lines def _publish_error_rate(self, values, task): - return self._join( - self._line("error rate", task, values["error_rate"], "%", lambda v: "%.2f" % (v * 100.0)) - ) + return self._join(self._line("error rate", task, values["error_rate"], "%", lambda v: "%.2f" % (v * 100.0))) def _publish_totals(self, stats): lines = [] @@ -342,12 +324,9 @@ def _publish_total_time(self, name, total_time): def _publish_total_time_per_shard(self, name, total_time_per_shard): unit = "min" return self._join( - self._line("Min cumulative {} across primary shards".format(name), "", total_time_per_shard.get("min"), unit, - convert.ms_to_minutes), - self._line("Median cumulative {} across primary shards".format(name), "", total_time_per_shard.get("median"), - unit, convert.ms_to_minutes), - self._line("Max cumulative {} across primary shards".format(name), "", total_time_per_shard.get("max"), unit, - convert.ms_to_minutes), + self._line("Min cumulative {} across primary shards".format(name), "", total_time_per_shard.get("min"), unit, convert.ms_to_minutes), + self._line("Median cumulative {} across primary shards".format(name), "", total_time_per_shard.get("median"), unit, convert.ms_to_minutes), + self._line("Max cumulative {} across primary shards".format(name), "", total_time_per_shard.get("max"), unit, convert.ms_to_minutes), ) def _publish_total_count(self, name, total_count): @@ -371,7 +350,7 @@ def _publish_gc_metrics(self, stats): self._line("Total Young Gen GC time", "", stats.young_gc_time, "s", convert.ms_to_seconds), self._line("Total Young Gen GC count", "", stats.young_gc_count, ""), self._line("Total Old Gen GC time", "", stats.old_gc_time, "s", convert.ms_to_seconds), - self._line("Total Old Gen GC count", "", stats.old_gc_count, "") + self._line("Total Old Gen GC count", "", stats.old_gc_count, ""), ) def _publish_disk_usage(self, stats): @@ -388,29 +367,22 @@ def _publish_segment_memory(self, stats): self._line("Heap used for terms", "", stats.memory_terms, unit, convert.bytes_to_mb), self._line("Heap used for norms", "", stats.memory_norms, unit, convert.bytes_to_mb), self._line("Heap used for points", "", stats.memory_points, unit, convert.bytes_to_mb), - self._line("Heap used for stored fields", "", stats.memory_stored_fields, unit, convert.bytes_to_mb) + self._line("Heap used for stored fields", "", stats.memory_stored_fields, unit, convert.bytes_to_mb), ) def _publish_segment_counts(self, stats): - return self._join( - self._line("Segment count", "", stats.segment_count, "") - ) + return self._join(self._line("Segment count", "", stats.segment_count, "")) def _publish_transform_stats(self, stats): lines = [] for processing_time in stats.total_transform_processing_times: - lines.append( - self._line("Transform processing time", processing_time["id"], processing_time["mean"], - processing_time["unit"])) + lines.append(self._line("Transform processing time", processing_time["id"], processing_time["mean"], processing_time["unit"])) for index_time in stats.total_transform_index_times: - lines.append( - self._line("Transform indexing time", index_time["id"], index_time["mean"], index_time["unit"])) + lines.append(self._line("Transform indexing time", index_time["id"], index_time["mean"], index_time["unit"])) for search_time in stats.total_transform_search_times: - lines.append( - self._line("Transform search time", search_time["id"], search_time["mean"], search_time["unit"])) + lines.append(self._line("Transform search time", search_time["id"], search_time["mean"], search_time["unit"])) for throughput in stats.total_transform_throughput: - lines.append( - self._line("Transform throughput", throughput["id"], throughput["mean"], throughput["unit"])) + lines.append(self._line("Transform throughput", throughput["id"], throughput["mean"], throughput["unit"])) return lines @@ -437,11 +409,9 @@ def __init__(self, config): self.logger = logging.getLogger(__name__) self.results_file = config.opts("reporting", "output.path") self.results_format = config.opts("reporting", "format") - self.numbers_align = config.opts("reporting", "numbers.align", - mandatory=False, default_value="right") + self.numbers_align = config.opts("reporting", "numbers.align", mandatory=False, default_value="right") self.cwd = config.opts("node", "benchmark.cwd") - self.show_processing_time = convert.to_bool(config.opts("reporting", "output.processingtime", - mandatory=False, default_value=False)) + self.show_processing_time = convert.to_bool(config.opts("reporting", "output.processingtime", mandatory=False, default_value=False)) self.percentiles = comma_separated_string_to_number_list(config.opts("reporting", "percentiles", mandatory=False)) self.plain = False @@ -500,9 +470,15 @@ def _metrics_table(self, baseline_stats, contender_stats, plain): return metrics_table def _write_results(self, metrics_table, metrics_table_console): - write_single_results(self.results_file, self.results_format, self.cwd, self.numbers_align, - headers=["Metric", "Task", "Baseline", "Contender", "Diff", "Unit"], - data_plain=metrics_table, data_rich=metrics_table_console) + write_single_results( + self.results_file, + self.results_format, + self.cwd, + self.numbers_align, + headers=["Metric", "Task", "Baseline", "Contender", "Diff", "Unit"], + data_plain=metrics_table, + data_rich=metrics_table_console, + ) def _publish_throughput(self, baseline_stats, contender_stats, task): b_min = baseline_stats.metrics(task)["throughput"].get("overall_min") or baseline_stats.metrics(task)["throughput"]["min"] @@ -520,7 +496,7 @@ def _publish_throughput(self, baseline_stats, contender_stats, task): self._line("Min Throughput", b_min, c_min, task, b_unit, treat_increase_as_improvement=True), self._line("Mean Throughput", b_mean, c_mean, task, b_unit, treat_increase_as_improvement=True), self._line("Median Throughput", b_median, c_median, task, b_unit, treat_increase_as_improvement=True), - self._line("Max Throughput", b_max, c_max, task, b_unit, treat_increase_as_improvement=True) + self._line("Max Throughput", b_max, c_max, task, b_unit, treat_increase_as_improvement=True), ) def _publish_latency(self, baseline_stats, contender_stats, task): @@ -543,18 +519,13 @@ def _publish_percentiles(self, name, task, baseline_values, contender_values): for percentile in metrics.percentiles_for_sample_size(sys.maxsize, percentiles_list=self.percentiles): baseline_value = baseline_values.get(metrics.encode_float_key(percentile)) contender_value = contender_values.get(metrics.encode_float_key(percentile)) - self._append_non_empty(lines, self._line("%sth percentile %s" % (percentile, name), - baseline_value, contender_value, task, "ms", - treat_increase_as_improvement=False)) + self._append_non_empty(lines, self._line("%sth percentile %s" % (percentile, name), baseline_value, contender_value, task, "ms", treat_increase_as_improvement=False)) return lines def _publish_error_rate(self, baseline_stats, contender_stats, task): baseline_error_rate = baseline_stats.metrics(task)["error_rate"] contender_error_rate = contender_stats.metrics(task)["error_rate"] - return self._join( - self._line("error rate", baseline_error_rate, contender_error_rate, task, "%", - treat_increase_as_improvement=False, formatter=convert.factor(100.0)) - ) + return self._join(self._line("error rate", baseline_error_rate, contender_error_rate, task, "%", treat_increase_as_improvement=False, formatter=convert.factor(100.0))) def _publish_ml_processing_times(self, baseline_stats, contender_stats): lines = [] @@ -564,14 +535,10 @@ def _publish_ml_processing_times(self, baseline_stats, contender_stats): # O(n^2) but we assume here only a *very* limited number of jobs (usually just one) for contender in contender_stats.ml_processing_time: if contender["job"] == job_name: - lines.append(self._line("Min ML processing time", baseline["min"], contender["min"], - job_name, unit, treat_increase_as_improvement=False)) - lines.append(self._line("Mean ML processing time", baseline["mean"], contender["mean"], - job_name, unit, treat_increase_as_improvement=False)) - lines.append(self._line("Median ML processing time", baseline["median"], contender["median"], - job_name, unit, treat_increase_as_improvement=False)) - lines.append(self._line("Max ML processing time", baseline["max"], contender["max"], - job_name, unit, treat_increase_as_improvement=False)) + lines.append(self._line("Min ML processing time", baseline["min"], contender["min"], job_name, unit, treat_increase_as_improvement=False)) + lines.append(self._line("Mean ML processing time", baseline["mean"], contender["mean"], job_name, unit, treat_increase_as_improvement=False)) + lines.append(self._line("Median ML processing time", baseline["median"], contender["median"], job_name, unit, treat_increase_as_improvement=False)) + lines.append(self._line("Max ML processing time", baseline["max"], contender["max"], job_name, unit, treat_increase_as_improvement=False)) return lines def _publish_transform_processing_times(self, baseline_stats, contender_stats): @@ -582,166 +549,160 @@ def _publish_transform_processing_times(self, baseline_stats, contender_stats): transform_id = baseline["id"] for contender in contender_stats.total_transform_processing_times: if contender["id"] == transform_id: - lines.append( - self._line("Transform processing time", baseline["mean"], contender["mean"], - transform_id, baseline["unit"], treat_increase_as_improvement=True)) + lines.append(self._line("Transform processing time", baseline["mean"], contender["mean"], transform_id, baseline["unit"], treat_increase_as_improvement=True)) for baseline in baseline_stats.total_transform_index_times: transform_id = baseline["id"] for contender in contender_stats.total_transform_index_times: if contender["id"] == transform_id: - lines.append( - self._line("Transform indexing time", baseline["mean"], contender["mean"], - transform_id, baseline["unit"], treat_increase_as_improvement=True)) + lines.append(self._line("Transform indexing time", baseline["mean"], contender["mean"], transform_id, baseline["unit"], treat_increase_as_improvement=True)) for baseline in baseline_stats.total_transform_search_times: transform_id = baseline["id"] for contender in contender_stats.total_transform_search_times: if contender["id"] == transform_id: - lines.append( - self._line("Transform search time", baseline["mean"], contender["mean"], - transform_id, baseline["unit"], treat_increase_as_improvement=True)) + lines.append(self._line("Transform search time", baseline["mean"], contender["mean"], transform_id, baseline["unit"], treat_increase_as_improvement=True)) for baseline in baseline_stats.total_transform_throughput: transform_id = baseline["id"] for contender in contender_stats.total_transform_throughput: if contender["id"] == transform_id: - lines.append( - self._line("Transform throughput", baseline["mean"], contender["mean"], - transform_id, baseline["unit"], treat_increase_as_improvement=True)) + lines.append(self._line("Transform throughput", baseline["mean"], contender["mean"], transform_id, baseline["unit"], treat_increase_as_improvement=True)) return lines def _publish_total_times(self, baseline_stats, contender_stats): lines = [] - lines.extend(self._publish_total_time( - "indexing time", - baseline_stats.total_time, contender_stats.total_time - )) - lines.extend(self._publish_total_time_per_shard( - "indexing time", - baseline_stats.total_time_per_shard, contender_stats.total_time_per_shard - )) - lines.extend(self._publish_total_time( - "indexing throttle time", - baseline_stats.indexing_throttle_time, contender_stats.indexing_throttle_time - )) - lines.extend(self._publish_total_time_per_shard( - "indexing throttle time", - baseline_stats.indexing_throttle_time_per_shard, - contender_stats.indexing_throttle_time_per_shard - )) - lines.extend(self._publish_total_time( - "merge time", - baseline_stats.merge_time, contender_stats.merge_time, - )) - lines.extend(self._publish_total_count( - "merge count", - baseline_stats.merge_count, contender_stats.merge_count - )) - lines.extend(self._publish_total_time_per_shard( - "merge time", - baseline_stats.merge_time_per_shard, - contender_stats.merge_time_per_shard - )) - lines.extend(self._publish_total_time( - "merge throttle time", - baseline_stats.merge_throttle_time, - contender_stats.merge_throttle_time - )) - lines.extend(self._publish_total_time_per_shard( - "merge throttle time", - baseline_stats.merge_throttle_time_per_shard, - contender_stats.merge_throttle_time_per_shard - )) - lines.extend(self._publish_total_time( - "refresh time", - baseline_stats.refresh_time, contender_stats.refresh_time - )) - lines.extend(self._publish_total_count( - "refresh count", - baseline_stats.refresh_count, contender_stats.refresh_count - )) - lines.extend(self._publish_total_time_per_shard( - "refresh time", - baseline_stats.refresh_time_per_shard, - contender_stats.refresh_time_per_shard - )) - lines.extend(self._publish_total_time( - "flush time", - baseline_stats.flush_time, contender_stats.flush_time - )) - lines.extend(self._publish_total_count( - "flush count", - baseline_stats.flush_count, contender_stats.flush_count - )) - lines.extend(self._publish_total_time_per_shard( - "flush time", - baseline_stats.flush_time_per_shard, contender_stats.flush_time_per_shard - )) + lines.extend(self._publish_total_time("indexing time", baseline_stats.total_time, contender_stats.total_time)) + lines.extend(self._publish_total_time_per_shard("indexing time", baseline_stats.total_time_per_shard, contender_stats.total_time_per_shard)) + lines.extend(self._publish_total_time("indexing throttle time", baseline_stats.indexing_throttle_time, contender_stats.indexing_throttle_time)) + lines.extend( + self._publish_total_time_per_shard("indexing throttle time", baseline_stats.indexing_throttle_time_per_shard, contender_stats.indexing_throttle_time_per_shard) + ) + lines.extend( + self._publish_total_time( + "merge time", + baseline_stats.merge_time, + contender_stats.merge_time, + ) + ) + lines.extend(self._publish_total_count("merge count", baseline_stats.merge_count, contender_stats.merge_count)) + lines.extend(self._publish_total_time_per_shard("merge time", baseline_stats.merge_time_per_shard, contender_stats.merge_time_per_shard)) + lines.extend(self._publish_total_time("merge throttle time", baseline_stats.merge_throttle_time, contender_stats.merge_throttle_time)) + lines.extend(self._publish_total_time_per_shard("merge throttle time", baseline_stats.merge_throttle_time_per_shard, contender_stats.merge_throttle_time_per_shard)) + lines.extend(self._publish_total_time("refresh time", baseline_stats.refresh_time, contender_stats.refresh_time)) + lines.extend(self._publish_total_count("refresh count", baseline_stats.refresh_count, contender_stats.refresh_count)) + lines.extend(self._publish_total_time_per_shard("refresh time", baseline_stats.refresh_time_per_shard, contender_stats.refresh_time_per_shard)) + lines.extend(self._publish_total_time("flush time", baseline_stats.flush_time, contender_stats.flush_time)) + lines.extend(self._publish_total_count("flush count", baseline_stats.flush_count, contender_stats.flush_count)) + lines.extend(self._publish_total_time_per_shard("flush time", baseline_stats.flush_time_per_shard, contender_stats.flush_time_per_shard)) return lines def _publish_total_time(self, name, baseline_total, contender_total): unit = "min" return self._join( - self._line("Cumulative {} of primary shards".format(name), baseline_total, contender_total, "", unit, - treat_increase_as_improvement=False, formatter=convert.ms_to_minutes), + self._line( + "Cumulative {} of primary shards".format(name), baseline_total, contender_total, "", unit, treat_increase_as_improvement=False, formatter=convert.ms_to_minutes + ), ) def _publish_total_time_per_shard(self, name, baseline_per_shard, contender_per_shard): unit = "min" return self._join( - self._line("Min cumulative {} across primary shard".format(name), baseline_per_shard.get("min"), - contender_per_shard.get("min"), "", unit, treat_increase_as_improvement=False, formatter=convert.ms_to_minutes), - self._line("Median cumulative {} across primary shard".format(name), baseline_per_shard.get("median"), - contender_per_shard.get("median"), "", unit, treat_increase_as_improvement=False, formatter=convert.ms_to_minutes), - self._line("Max cumulative {} across primary shard".format(name), baseline_per_shard.get("max"), contender_per_shard.get("max"), - "", unit, treat_increase_as_improvement=False, formatter=convert.ms_to_minutes), + self._line( + "Min cumulative {} across primary shard".format(name), + baseline_per_shard.get("min"), + contender_per_shard.get("min"), + "", + unit, + treat_increase_as_improvement=False, + formatter=convert.ms_to_minutes, + ), + self._line( + "Median cumulative {} across primary shard".format(name), + baseline_per_shard.get("median"), + contender_per_shard.get("median"), + "", + unit, + treat_increase_as_improvement=False, + formatter=convert.ms_to_minutes, + ), + self._line( + "Max cumulative {} across primary shard".format(name), + baseline_per_shard.get("max"), + contender_per_shard.get("max"), + "", + unit, + treat_increase_as_improvement=False, + formatter=convert.ms_to_minutes, + ), ) def _publish_total_count(self, name, baseline_total, contender_total): - return self._join( - self._line("Cumulative {} of primary shards".format(name), baseline_total, contender_total, "", "", - treat_increase_as_improvement=False) - ) + return self._join(self._line("Cumulative {} of primary shards".format(name), baseline_total, contender_total, "", "", treat_increase_as_improvement=False)) def _publish_gc_metrics(self, baseline_stats, contender_stats): return self._join( - self._line("Total Young Gen GC time", baseline_stats.young_gc_time, contender_stats.young_gc_time, "", "s", - treat_increase_as_improvement=False, formatter=convert.ms_to_seconds), - self._line("Total Young Gen GC count", baseline_stats.young_gc_count, contender_stats.young_gc_count, "", "", - treat_increase_as_improvement=False), - self._line("Total Old Gen GC time", baseline_stats.old_gc_time, contender_stats.old_gc_time, "", "s", - treat_increase_as_improvement=False, formatter=convert.ms_to_seconds), - self._line("Total Old Gen GC count", baseline_stats.old_gc_count, contender_stats.old_gc_count, "", "", - treat_increase_as_improvement=False) + self._line( + "Total Young Gen GC time", + baseline_stats.young_gc_time, + contender_stats.young_gc_time, + "", + "s", + treat_increase_as_improvement=False, + formatter=convert.ms_to_seconds, + ), + self._line("Total Young Gen GC count", baseline_stats.young_gc_count, contender_stats.young_gc_count, "", "", treat_increase_as_improvement=False), + self._line( + "Total Old Gen GC time", baseline_stats.old_gc_time, contender_stats.old_gc_time, "", "s", treat_increase_as_improvement=False, formatter=convert.ms_to_seconds + ), + self._line("Total Old Gen GC count", baseline_stats.old_gc_count, contender_stats.old_gc_count, "", "", treat_increase_as_improvement=False), ) def _publish_disk_usage(self, baseline_stats, contender_stats): return self._join( - self._line("Store size", baseline_stats.store_size, contender_stats.store_size, "", "GB", - treat_increase_as_improvement=False, formatter=convert.bytes_to_gb), - self._line("Translog size", baseline_stats.translog_size, contender_stats.translog_size, "", "GB", - treat_increase_as_improvement=False, formatter=convert.bytes_to_gb), + self._line("Store size", baseline_stats.store_size, contender_stats.store_size, "", "GB", treat_increase_as_improvement=False, formatter=convert.bytes_to_gb), + self._line("Translog size", baseline_stats.translog_size, contender_stats.translog_size, "", "GB", treat_increase_as_improvement=False, formatter=convert.bytes_to_gb), ) def _publish_segment_memory(self, baseline_stats, contender_stats): return self._join( - self._line("Heap used for segments", baseline_stats.memory_segments, contender_stats.memory_segments, "", "MB", - treat_increase_as_improvement=False, formatter=convert.bytes_to_mb), - self._line("Heap used for doc values", baseline_stats.memory_doc_values, contender_stats.memory_doc_values, "", "MB", - treat_increase_as_improvement=False, formatter=convert.bytes_to_mb), - self._line("Heap used for terms", baseline_stats.memory_terms, contender_stats.memory_terms, "", "MB", - treat_increase_as_improvement=False, formatter=convert.bytes_to_mb), - self._line("Heap used for norms", baseline_stats.memory_norms, contender_stats.memory_norms, "", "MB", - treat_increase_as_improvement=False, formatter=convert.bytes_to_mb), - self._line("Heap used for points", baseline_stats.memory_points, contender_stats.memory_points, "", "MB", - treat_increase_as_improvement=False, formatter=convert.bytes_to_mb), - self._line("Heap used for stored fields", baseline_stats.memory_stored_fields, contender_stats.memory_stored_fields, "", - "MB", treat_increase_as_improvement=False, formatter=convert.bytes_to_mb) - ) + self._line( + "Heap used for segments", + baseline_stats.memory_segments, + contender_stats.memory_segments, + "", + "MB", + treat_increase_as_improvement=False, + formatter=convert.bytes_to_mb, + ), + self._line( + "Heap used for doc values", + baseline_stats.memory_doc_values, + contender_stats.memory_doc_values, + "", + "MB", + treat_increase_as_improvement=False, + formatter=convert.bytes_to_mb, + ), + self._line( + "Heap used for terms", baseline_stats.memory_terms, contender_stats.memory_terms, "", "MB", treat_increase_as_improvement=False, formatter=convert.bytes_to_mb + ), + self._line( + "Heap used for norms", baseline_stats.memory_norms, contender_stats.memory_norms, "", "MB", treat_increase_as_improvement=False, formatter=convert.bytes_to_mb + ), + self._line( + "Heap used for points", baseline_stats.memory_points, contender_stats.memory_points, "", "MB", treat_increase_as_improvement=False, formatter=convert.bytes_to_mb + ), + self._line( + "Heap used for stored fields", + baseline_stats.memory_stored_fields, + contender_stats.memory_stored_fields, + "", + "MB", + treat_increase_as_improvement=False, + formatter=convert.bytes_to_mb, + ), + ) def _publish_segment_counts(self, baseline_stats, contender_stats): - return self._join( - self._line("Segment count", baseline_stats.segment_count, contender_stats.segment_count, - "", "", treat_increase_as_improvement=False) - ) + return self._join(self._line("Segment count", baseline_stats.segment_count, contender_stats.segment_count, "", "", treat_increase_as_improvement=False)) def _join(self, *args): lines = [] @@ -755,8 +716,7 @@ def _append_non_empty(self, lines, line): def _line(self, metric, baseline, contender, task, unit, treat_increase_as_improvement, formatter=lambda x: x): if baseline is not None and contender is not None: - return [metric, str(task), formatter(baseline), formatter(contender), - *self._diff(baseline, contender, treat_increase_as_improvement, formatter), unit] + return [metric, str(task), formatter(baseline), formatter(contender), *self._diff(baseline, contender, treat_increase_as_improvement, formatter), unit] else: return [] @@ -787,9 +747,9 @@ def identity(x): color_neutral = console.format.neutral if percentage_diff > 5.0: - return color_greater("+%.2f%%" % percentage_diff)+" :red_circle:",color_greater("+%.5f" % diff) + return color_greater("+%.2f%%" % percentage_diff) + " :red_circle:", color_greater("+%.5f" % diff) elif percentage_diff < -5.0: - return color_smaller("%.2f%%" % percentage_diff)+" :green_circle:",color_smaller("%.5f" % diff) + return color_smaller("%.2f%%" % percentage_diff) + " :green_circle:", color_smaller("%.5f" % diff) else: # tabulate needs this to align all values correctly - return color_neutral("%.2f%%" % percentage_diff),color_neutral("%.5f" % diff) + return color_neutral("%.2f%%" % percentage_diff), color_neutral("%.5f" % diff) diff --git a/solrorbit/result_writer.py b/solrorbit/result_writer.py index c397f32a..d7aa0af1 100644 --- a/solrorbit/result_writer.py +++ b/solrorbit/result_writer.py @@ -102,11 +102,13 @@ def open(self, run_metadata: dict) -> None: # Example: 20260222_143052_7a82f1ea if timestamp and run_id != "unknown": from datetime import datetime + # timestamp can be either a datetime object or Unix timestamp (float/int) if isinstance(timestamp, datetime): time_str = timestamp.strftime("%Y%m%d_%H%M%S") elif isinstance(timestamp, (int, float)): import time + time_str = time.strftime("%Y%m%d_%H%M%S", time.gmtime(timestamp)) else: # Unknown timestamp type, fall back to run_id only @@ -191,10 +193,7 @@ def _write_summary(self) -> str: return "(no metrics recorded)" normal = [m for m in self._metrics if m.get("sample_type") != "warmup"] - rows = [ - [m.get("task", ""), m.get("name", ""), m.get("value", ""), m.get("unit", "")] - for m in normal - ] + rows = [[m.get("task", ""), m.get("name", ""), m.get("value", ""), m.get("unit", "")] for m in normal] table = tabulate_lib.tabulate( rows, headers=["Task", "Metric", "Value", "Unit"], @@ -234,8 +233,5 @@ def create_writer(name: str, **kwargs) -> ResultWriter: "local_filesystem": LocalFilesystemResultWriter, } if name not in registry: - raise exceptions.SystemSetupError( - f"Unknown results_writer '{name}'. " - f"Available: {', '.join(registry)}" - ) + raise exceptions.SystemSetupError(f"Unknown results_writer '{name}'. Available: {', '.join(registry)}") return registry[name](**kwargs) diff --git a/solrorbit/synthetic_data_generator/helpers.py b/solrorbit/synthetic_data_generator/helpers.py index bd2cfec8..f92f55e1 100644 --- a/solrorbit/synthetic_data_generator/helpers.py +++ b/solrorbit/synthetic_data_generator/helpers.py @@ -21,8 +21,9 @@ from solrorbit.synthetic_data_generator.strategies.strategy import DataGenerationStrategy from solrorbit.synthetic_data_generator.models import SyntheticDataGeneratorMetadata, SDGConfig, GB_TO_BYTES + def load_user_module(file_path): - allowed_extensions = ['.py'] + allowed_extensions = [".py"] extension = os.path.splitext(file_path)[1] if extension not in allowed_extensions: raise exceptions.SystemSetupError(f"User provided module with file extension [{extension}]. Python modules must have {allowed_extensions} extension.") @@ -32,6 +33,7 @@ def load_user_module(file_path): spec.loader.exec_module(user_module) return user_module + def load_mapping(mapping_file_path): """ Loads an index mapping from a JSON file. @@ -47,18 +49,23 @@ def load_mapping(mapping_file_path): return mapping_dict + def check_for_existing_files(output_path: str, index_name: str): - VALID_OPTIONS = ['y', 'yes', 'n', 'no'] - ALTERNATIVES_FOR_YES = ['y', 'yes'] + VALID_OPTIONS = ["y", "yes", "n", "no"] + ALTERNATIVES_FOR_YES = ["y", "yes"] logger = logging.getLogger(__name__) existing_files_found = existing_files_found_in_output_dir(output_path, index_name) if existing_files_found: - user_decision = input(f"Files with the same expected names were found in the output directory {output_path}. " + \ - "Would you like to remove them (so that SDG does not append to them)? (y/n): ") + user_decision = input( + f"Files with the same expected names were found in the output directory {output_path}. " + + "Would you like to remove them (so that SDG does not append to them)? (y/n): " + ) while user_decision.lower() not in VALID_OPTIONS: - user_decision = input(f"Invalid response. Files with the same expected names were found in the output directory {output_path}. " + \ - "Would you like to remove them (so that SDG does not append to them)? (y/n): ") + user_decision = input( + f"Invalid response. Files with the same expected names were found in the output directory {output_path}. " + + "Would you like to remove them (so that SDG does not append to them)? (y/n): " + ) if user_decision.lower() in ALTERNATIVES_FOR_YES: remove_existing_files(existing_files_found) @@ -68,15 +75,17 @@ def check_for_existing_files(output_path: str, index_name: str): logger.info("Keeping files at: %s", output_path) console.println(f"Keeping files at: {output_path}\n") + def existing_files_found_in_output_dir(output_path: str, index_name: str) -> bool: existing_files = [] for file in os.listdir(output_path): - if (file.startswith(index_name) and file.endswith(".json")) or (file.startswith(index_name) and file.endswith('_record.json')): + if (file.startswith(index_name) and file.endswith(".json")) or (file.startswith(index_name) and file.endswith("_record.json")): existing_files.append(os.path.join(output_path, file)) return existing_files + def remove_existing_files(existing_files_found: list): try: for file in existing_files_found: @@ -84,6 +93,7 @@ def remove_existing_files(existing_files_found: list): except OSError as e: raise exceptions.ExecutorError("Solr Orbit could not remove existing files for SDG: ", e) + def host_has_available_disk_storage(sdg_metadata: SyntheticDataGeneratorMetadata) -> bool: logger = logging.getLogger(__name__) try: @@ -98,15 +108,16 @@ def host_has_available_disk_storage(sdg_metadata: SyntheticDataGeneratorMetadata logger.error("Error checking disk space.") return False + def load_config(config_path: str) -> SDGConfig: try: - allowed_extensions = ['.yml', '.yaml'] + allowed_extensions = [".yml", ".yaml"] extension = os.path.splitext(config_path)[1] if extension not in allowed_extensions: raise exceptions.ConfigError(f"User provided config with extension [{extension}]. Config must have a {allowed_extensions} extension.") else: - with open(config_path, 'r') as file: + with open(config_path, "r") as file: config_details = yaml.safe_load(file) return SDGConfig(**config_details) if config_details else SDGConfig() @@ -116,14 +127,16 @@ def load_config(config_path: str) -> SDGConfig: except TypeError: raise exceptions.SystemSetupError("Error when loading config. Please ensure that the proper config was provided") + def write_chunk(data, file_path): written_bytes = 0 - with open(file_path, 'a') as f: + with open(file_path, "a") as f: for item in data: - f.write(json.dumps(item) + '\n') + f.write(json.dumps(item) + "\n") written_bytes += len(pickle.dumps(item)) return len(data), written_bytes + def calculate_avg_doc_size(strategy: DataGenerationStrategy): # Didn't do pickle because this seems to be more accurate output = strategy.generate_test_document() @@ -135,13 +148,15 @@ def calculate_avg_doc_size(strategy: DataGenerationStrategy): return size + def format_size(bytes): - for unit in ['B', 'KB', 'MB', 'GB', 'TB']: + for unit in ["B", "KB", "MB", "GB", "TB"]: if bytes < 1024: return f"{bytes:.2f} {unit}" bytes /= 1024 return f"{bytes:.2f} PB" + def format_time(seconds): if seconds < 60: return f"{seconds:.1f}s" @@ -153,11 +168,13 @@ def format_time(seconds): minutes, seconds = divmod(remainder, 60) return f"{int(hours)}h {int(minutes)}m {int(seconds)}s" + def setup_custom_tqdm_formatting(progress_bar): - progress_bar.format_dict['n_fmt'] = lambda n: format_size(n) # pylint: disable=unnecessary-lambda - progress_bar.format_dict['total_fmt'] = lambda t: format_size(t) # pylint: disable=unnecessary-lambda - progress_bar.format_dict['elapsed'] = lambda e: format_time(e) # pylint: disable=unnecessary-lambda - progress_bar.format_dict['remaining'] = lambda r: format_time(r) # pylint: disable=unnecessary-lambda + progress_bar.format_dict["n_fmt"] = lambda n: format_size(n) # pylint: disable=unnecessary-lambda + progress_bar.format_dict["total_fmt"] = lambda t: format_size(t) # pylint: disable=unnecessary-lambda + progress_bar.format_dict["elapsed"] = lambda e: format_time(e) # pylint: disable=unnecessary-lambda + progress_bar.format_dict["remaining"] = lambda r: format_time(r) # pylint: disable=unnecessary-lambda + def build_record(sdg_metadata: SyntheticDataGeneratorMetadata, total_time_to_generate_dataset: int, generated_dataset_details: dict) -> dict: total_docs_written = 0 @@ -172,23 +189,25 @@ def build_record(sdg_metadata: SyntheticDataGeneratorMetadata, total_time_to_gen "total-docs-written": total_docs_written, "total-dataset-size": total_dataset_size_in_bytes, "total-time-to-generate-dataset": total_time_to_generate_dataset, - "files": generated_dataset_details + "files": generated_dataset_details, } return record + def write_record(sdg_metadata: SyntheticDataGeneratorMetadata, record): path = os.path.join(sdg_metadata.output_path, f"{sdg_metadata.index_name}_record.json") - with open(path, 'w') as file: + with open(path, "w") as file: json.dump(record, file, indent=2) + def write_record_and_publish_summary_to_console(sdg_metadata: SyntheticDataGeneratorMetadata, total_time_to_generate_dataset: int, generated_dataset_details: dict): logger = logging.getLogger(__name__) record = build_record(sdg_metadata, total_time_to_generate_dataset, generated_dataset_details) write_record(sdg_metadata, record) - summary = f"Generated {record['total-docs-written']} docs in {total_time_to_generate_dataset} seconds. Total dataset size is {record['total-dataset-size'] / (1000 ** 3)}GB." + summary = f"Generated {record['total-docs-written']} docs in {total_time_to_generate_dataset} seconds. Total dataset size is {record['total-dataset-size'] / (1000**3)}GB." console.println("") console.println(summary) diff --git a/solrorbit/synthetic_data_generator/input_processor.py b/solrorbit/synthetic_data_generator/input_processor.py index f5996c1e..ba9f1ec3 100644 --- a/solrorbit/synthetic_data_generator/input_processor.py +++ b/solrorbit/synthetic_data_generator/input_processor.py @@ -13,6 +13,7 @@ logger = logging.getLogger(__name__) + def create_sdg_metadata_from_args(cfg) -> SyntheticDataGeneratorMetadata: """ Creates a Synthetic Data Generator Config based on the user's inputs @@ -27,17 +28,18 @@ def create_sdg_metadata_from_args(cfg) -> SyntheticDataGeneratorMetadata: custom_config_path = cfg.opts("synthetic_data_generator", "custom_config") return SyntheticDataGeneratorMetadata( - index_name = cfg.opts("synthetic_data_generator", "index_name"), - index_mappings_path = index_mappings_path, - custom_module_path = custom_module_path, - custom_config_path = custom_config_path, - output_path = cfg.opts("synthetic_data_generator", "output_path"), - total_size_gb= cfg.opts("synthetic_data_generator", "total_size"), + index_name=cfg.opts("synthetic_data_generator", "index_name"), + index_mappings_path=index_mappings_path, + custom_module_path=custom_module_path, + custom_config_path=custom_config_path, + output_path=cfg.opts("synthetic_data_generator", "output_path"), + total_size_gb=cfg.opts("synthetic_data_generator", "total_size"), ) except ConfigError as e: raise ConfigError("Config error when building SyntheticDataGeneratorMetadata: ", e) + def use_custom_synthetic_data_generator(sdg_metadata: SyntheticDataGeneratorMetadata) -> bool: if sdg_metadata.custom_module_path and not sdg_metadata.index_mappings_path: logger.info("User is using custom module to generate synthetic data. Custom module is found in this path: [%s]", sdg_metadata.custom_module_path) @@ -45,6 +47,7 @@ def use_custom_synthetic_data_generator(sdg_metadata: SyntheticDataGeneratorMeta return False + def use_mappings_synthetic_data_generator(sdg_metadata: SyntheticDataGeneratorMetadata) -> bool: if sdg_metadata.index_mappings_path and not sdg_metadata.custom_module_path: logger.info("User is using index mappings to generate synthetic data. Index mappings are found in this path: [%s]", sdg_metadata.index_mappings_path) diff --git a/solrorbit/synthetic_data_generator/models.py b/solrorbit/synthetic_data_generator/models.py index 6d172a09..7ea7db4d 100644 --- a/solrorbit/synthetic_data_generator/models.py +++ b/solrorbit/synthetic_data_generator/models.py @@ -14,7 +14,8 @@ from solrorbit.synthetic_data_generator.timeseries_partitioner import TimeSeriesPartitioner -GB_TO_BYTES = 1024 ** 3 +GB_TO_BYTES = 1024**3 + class TimeSeriesConfig(BaseModel): timeseries_field: str @@ -24,61 +25,60 @@ class TimeSeriesConfig(BaseModel): timeseries_format: str # pylint: disable = no-self-argument - @field_validator('timeseries_start_date', 'timeseries_end_date', 'timeseries_frequency', 'timeseries_format') + @field_validator("timeseries_start_date", "timeseries_end_date", "timeseries_frequency", "timeseries_format") def validate_string_fields(cls, v, info): """Validate that timeseries configuration fields are strings""" if not isinstance(v, str): - field_name = info.field_name.replace('_', ' ').title() + field_name = info.field_name.replace("_", " ").title() raise ValueError(f"{field_name} requires a string value. Value {v} is not valid.") # Additional validation for frequency and format fields - if info.field_name == 'timeseries_frequency': + if info.field_name == "timeseries_frequency": if v not in TimeSeriesPartitioner.AVAILABLE_FREQUENCIES: raise ValueError(f"Timeseries frequency {v} is not a valid value. Valid values are {TimeSeriesPartitioner.AVAILABLE_FREQUENCIES}") - if info.field_name == 'timeseries_format': + if info.field_name == "timeseries_format": if v not in TimeSeriesPartitioner.VALID_DATETIMESTAMPS_FORMATS: raise ValueError(f"Timeseries format {v} is not a valid value. Valid values are {TimeSeriesPartitioner.VALID_DATETIMESTAMPS_FORMATS}") return v # pylint: disable = no-self-argument - @field_validator('timeseries_field') + @field_validator("timeseries_field") def validate_timeseries_field(cls, v): if not v or not v.strip(): raise ValueError("timeseries_field cannot be empty") # Validate field name format # OpenSearch field names must start with a letter and contain only alphanumeric, underscores, and periods - if not re.match(r'^[a-zA-Z][a-zA-Z0-9_.]*$', v): - raise ValueError( - f"Invalid timeseries_field '{v}'. Field names must start with a letter " - "and contain only alphanumeric characters, underscores, and periods." - ) + if not re.match(r"^[a-zA-Z][a-zA-Z0-9_.]*$", v): + raise ValueError(f"Invalid timeseries_field '{v}'. Field names must start with a letter and contain only alphanumeric characters, underscores, and periods.") return v + class SettingsConfig(BaseModel): - workers: Optional[int] = Field(default_factory=os.cpu_count) # Number of workers recommended to not exceed CPU count - max_file_size_gb: Optional[int] = 40 # Default because some CloudProviders limit the size of files stored - docs_per_chunk: Optional[int] = 10000 # Default based on testing - filename_suffix_begins_at: Optional[int] = 0 # Start at suffix 0 + workers: Optional[int] = Field(default_factory=os.cpu_count) # Number of workers recommended to not exceed CPU count + max_file_size_gb: Optional[int] = 40 # Default because some CloudProviders limit the size of files stored + docs_per_chunk: Optional[int] = 10000 # Default based on testing + filename_suffix_begins_at: Optional[int] = 0 # Start at suffix 0 timeseries_enabled: Optional[TimeSeriesConfig] = None # pylint: disable = no-self-argument - @field_validator('workers', 'max_file_size_gb', 'docs_per_chunk') + @field_validator("workers", "max_file_size_gb", "docs_per_chunk") def validate_values_are_positive_integers(cls, v): if v is not None and v <= 0: raise ValueError(f"Value '{v}' in Settings portion must be a positive integer.") return v + class CustomGenerationValuesConfig(BaseModel): custom_lists: Optional[Dict[str, List[Any]]] = None custom_providers: Optional[List[Any]] = None # pylint: disable = no-self-argument - @field_validator('custom_lists') + @field_validator("custom_lists") def validate_custom_lists(cls, v): if v is not None: for key, value in v.items(): @@ -88,6 +88,7 @@ def validate_custom_lists(cls, v): raise ValueError(f"Value for key '{key}' must be a list.") return v + class GeneratorParams(BaseModel): # Integer / Long Params min: Optional[Union[int, float]] = None @@ -122,57 +123,59 @@ class GeneratorParams(BaseModel): token_id_step: Optional[int] = None class Config: - extra = 'forbid' + extra = "forbid" + class FieldOverride(BaseModel): generator: str params: GeneratorParams # pylint: disable = no-self-argument - @field_validator('generator') + @field_validator("generator") def validate_generator_name(cls, v): valid_generators = [ - 'generate_text', - 'generate_keyword', - 'generate_integer', - 'generate_long', - 'generate_short', - 'generate_byte', - 'generate_float', - 'generate_double', - 'generate_boolean', - 'generate_date', - 'generate_ip', - 'generate_geo_point', - 'generate_object', - 'generate_nested', - 'generate_knn_vector', - 'generate_sparse_vector' + "generate_text", + "generate_keyword", + "generate_integer", + "generate_long", + "generate_short", + "generate_byte", + "generate_float", + "generate_double", + "generate_boolean", + "generate_date", + "generate_ip", + "generate_geo_point", + "generate_object", + "generate_nested", + "generate_knn_vector", + "generate_sparse_vector", ] if v not in valid_generators: raise ValueError(f"Generator '{v}' mentioned in FieldOverrides not among valid generators: {valid_generators}") return v + class MappingGenerationValuesConfig(BaseModel): generator_overrides: Optional[Dict[str, GeneratorParams]] = None field_overrides: Optional[Dict[str, FieldOverride]] = None # pylint: disable = no-self-argument - @field_validator('generator_overrides') + @field_validator("generator_overrides") def validate_generator_types(cls, v): # Based on this documentation from OpenSearch: https://docs.opensearch.org/latest/mappings/supported-field-types/index/ # TODO: Add more support for if v is not None: supported_mapping_field_types = { - 'core-field-types': ['boolean'], - 'string-based-field-types': ['text', 'keyword'], - 'numeric-field-types': ['byte', 'short', 'integer', 'long', 'float', 'double'], - 'date-time-field-types': ['date'], - 'ip-field-types': ['ip'], - 'geographic-field-types': ['geo_point'], - 'object-field-types': ['object', 'nested'], - 'vector-field-types': ['knn_vector', 'sparse_vector'] + "core-field-types": ["boolean"], + "string-based-field-types": ["text", "keyword"], + "numeric-field-types": ["byte", "short", "integer", "long", "float", "double"], + "date-time-field-types": ["date"], + "ip-field-types": ["ip"], + "geographic-field-types": ["geo_point"], + "object-field-types": ["object", "nested"], + "vector-field-types": ["knn_vector", "sparse_vector"], } valid_generator_types = [] @@ -186,15 +189,16 @@ def validate_generator_types(cls, v): return v # pylint: disable = no-self-argument - @field_validator('field_overrides') + @field_validator("field_overrides") def validate_field_names(cls, v): if v is not None: for field_name in v.keys(): - if not re.match(r'^[a-zA-Z][a-zA-Z0-9_.]*$', field_name): + if not re.match(r"^[a-zA-Z][a-zA-Z0-9_.]*$", field_name): raise ValueError(f"Invalid Field Name '{field_name}' in FieldOverrides. Only alphanumeric characters, underscores and periods are allowed.") return v + class SyntheticDataGeneratorMetadata(BaseModel): index_name: Optional[str] = None index_mappings_path: Optional[str] = None @@ -204,7 +208,8 @@ class SyntheticDataGeneratorMetadata(BaseModel): total_size_gb: Optional[int] = None class Config: - extra = 'forbid' + extra = "forbid" + class SDGConfig(BaseModel): # If user does not provide YAML fil or provides YAML without all settings fields, it will use default generation settings. @@ -213,4 +218,4 @@ class SDGConfig(BaseModel): MappingGenerationValues: Optional[MappingGenerationValuesConfig] = None class Config: - extra = 'forbid' + extra = "forbid" diff --git a/solrorbit/synthetic_data_generator/strategies/custom_module_strategy.py b/solrorbit/synthetic_data_generator/strategies/custom_module_strategy.py index c8cab53d..d7fb1837 100644 --- a/solrorbit/synthetic_data_generator/strategies/custom_module_strategy.py +++ b/solrorbit/synthetic_data_generator/strategies/custom_module_strategy.py @@ -22,14 +22,15 @@ from solrorbit.synthetic_data_generator.models import SyntheticDataGeneratorMetadata, SDGConfig from solrorbit.synthetic_data_generator.timeseries_partitioner import TimeSeriesPartitioner + class CustomModuleStrategy(DataGenerationStrategy): - def __init__(self, sdg_metadata: SyntheticDataGeneratorMetadata, sdg_config: SDGConfig, custom_module: ModuleType) -> None: + def __init__(self, sdg_metadata: SyntheticDataGeneratorMetadata, sdg_config: SDGConfig, custom_module: ModuleType) -> None: self.sdg_metadata = sdg_metadata self.sdg_config = sdg_config self.custom_module = custom_module self.logger = logging.getLogger(__name__) - if not hasattr(self.custom_module, 'generate_synthetic_document'): + if not hasattr(self.custom_module, "generate_synthetic_document"): msg = f"Custom module at [{self.sdg_metadata.custom_module_path}] does not define a function called generate_synthetic_document(). Ensure that this method is defined." raise exceptions.ConfigError(msg) @@ -41,16 +42,13 @@ def __init__(self, sdg_metadata: SyntheticDataGeneratorMetadata, sdg_config: SD try: self.custom_lists = self.sdg_config.CustomGenerationValues.custom_lists or {} provider_names = self.sdg_config.CustomGenerationValues.custom_providers or [] - self.custom_providers = { - name: getattr(self.custom_module, name) for name in provider_names - } + self.custom_providers = {name: getattr(self.custom_module, name) for name in provider_names} except AttributeError as e: msg = f"Error when setting up custom lists and custom providers: {e}" raise exceptions.ConfigError(msg) except TypeError: msg = "Synthetic Data Generator Config has custom_lists and custom_providers pointing to null values. Either populate or remove." - # pylint: disable=arguments-differ def generate_data_chunks_across_workers(self, dask_client: Client, docs_per_chunk: int, seeds: list, timeseries_enabled: dict = None, timeseries_windows: list = None) -> list: """ @@ -65,23 +63,18 @@ def generate_data_chunks_across_workers(self, dask_client: Client, docs_per_chun for _ in range(len(seeds)): seed = seeds[_] window = timeseries_windows[_] - future = dask_client.submit( - self.generate_data_chunk_from_worker, self.custom_module.generate_synthetic_document, - docs_per_chunk, seed, timeseries_enabled, window - ) + future = dask_client.submit(self.generate_data_chunk_from_worker, self.custom_module.generate_synthetic_document, docs_per_chunk, seed, timeseries_enabled, window) futures.append(future) return futures else: # If not using timeseries approach - return [dask_client.submit( - self.generate_data_chunk_from_worker, self.custom_module.generate_synthetic_document, - docs_per_chunk, seed) for seed in seeds] - + return [dask_client.submit(self.generate_data_chunk_from_worker, self.custom_module.generate_synthetic_document, docs_per_chunk, seed) for seed in seeds] - def generate_data_chunk_from_worker(self, generate_synthetic_document: Callable, docs_per_chunk: int, seed: Optional[int], - timeseries_enabled: dict = None, timeseries_window: set = None) -> list: + def generate_data_chunk_from_worker( + self, generate_synthetic_document: Callable, docs_per_chunk: int, seed: Optional[int], timeseries_enabled: dict = None, timeseries_window: set = None + ) -> list: """ This method is submitted to Dask worker and can be thought of as the worker performing a job, which is calling the custom module's generate_synthetic_document() function to generate documents. @@ -104,7 +97,7 @@ def generate_data_chunk_from_worker(self, generate_synthetic_document: Callable, synthetic_docs = [] datetimestamps: Generator = TimeSeriesPartitioner.generate_datetimestamps_from_window( window=timeseries_window, frequency=timeseries_enabled.timeseries_frequency, format=timeseries_enabled.timeseries_format - ) + ) for datetimestamp in datetimestamps: document = generate_synthetic_document(providers=seeded_providers, **self.custom_lists) try: @@ -128,14 +121,16 @@ def generate_test_document(self, timeseries_enabled: dict = None, timeseries_win if timeseries_enabled and timeseries_enabled.timeseries_field: datetimestamps: Generator = TimeSeriesPartitioner.generate_datetimestamps_from_window( window=timeseries_window, frequency=timeseries_enabled.timeseries_frequency, format=timeseries_enabled.timeseries_format - ) + ) for datetimestamp in datetimestamps: document[timeseries_enabled.timeseries_field] = datetimestamp except AttributeError as e: - msg = "Encountered AttributeError when setting up custom_providers and custom_lists. " + \ - "It seems that your module might be using custom_lists and custom_providers." + \ - f"Please ensure you have provided a custom config with custom_providers and custom_lists: {e}" + msg = ( + "Encountered AttributeError when setting up custom_providers and custom_lists. " + + "It seems that your module might be using custom_lists and custom_providers." + + f"Please ensure you have provided a custom config with custom_providers and custom_lists: {e}" + ) raise exceptions.ConfigError(msg) return document @@ -147,21 +142,18 @@ def _instantiate_all_providers(self, custom_providers): if custom_providers: g = self._add_custom_providers(g, custom_providers) - provider_instances = { - 'generic': g, - 'random': r - } + provider_instances = {"generic": g, "random": r} return provider_instances def _seed_providers(self, providers, seed=None): - ''' + """ Generic Mimesis uses reseed method while non-generic Mimesis (like Random) uses seed method. Both lead to the same effect. - ''' + """ for key, provider_instance in providers.items(): - if key in ['generic']: + if key in ["generic"]: provider_instance.reseed(seed) - elif key in ['random']: + elif key in ["random"]: provider_instance.seed(seed) return providers diff --git a/solrorbit/synthetic_data_generator/strategies/mapping_strategy.py b/solrorbit/synthetic_data_generator/strategies/mapping_strategy.py index 0dfb002f..905c96c5 100644 --- a/solrorbit/synthetic_data_generator/strategies/mapping_strategy.py +++ b/solrorbit/synthetic_data_generator/strategies/mapping_strategy.py @@ -22,12 +22,13 @@ from solrorbit.synthetic_data_generator.models import SyntheticDataGeneratorMetadata, SDGConfig, MappingGenerationValuesConfig from solrorbit.synthetic_data_generator.timeseries_partitioner import TimeSeriesPartitioner + class MappingStrategy(DataGenerationStrategy): - def __init__(self, sdg_metadata: SyntheticDataGeneratorMetadata, sdg_config: SDGConfig, index_mapping: dict) -> None: + def __init__(self, sdg_metadata: SyntheticDataGeneratorMetadata, sdg_config: SDGConfig, index_mapping: dict) -> None: self.sdg_metadata = sdg_metadata - self.sdg_config = sdg_config # Optional YAML-based config for value constraints - self.index_mapping = index_mapping # OpenSearch Mapping - self.mapping_generation_values = (self.sdg_config.MappingGenerationValues or {}) if self.sdg_config else {} + self.sdg_config = sdg_config # Optional YAML-based config for value constraints + self.index_mapping = index_mapping # OpenSearch Mapping + self.mapping_generation_values = (self.sdg_config.MappingGenerationValues or {}) if self.sdg_config else {} self.logger = logging.getLogger(__name__) @@ -43,10 +44,7 @@ def generate_data_chunks_across_workers(self, dask_client: Client, docs_per_chun for _ in range(len(seeds)): seed = seeds[_] window = timeseries_windows[_] - future = dask_client.submit( - self.generate_data_chunk_from_worker, - docs_per_chunk, seed, timeseries_enabled, window - ) + future = dask_client.submit(self.generate_data_chunk_from_worker, docs_per_chunk, seed, timeseries_enabled, window) futures.append(future) @@ -76,7 +74,7 @@ def generate_data_chunk_from_worker(self, docs_per_chunk: int, seed: Optional[in synthetic_docs = [] datetimestamps: Generator = TimeSeriesPartitioner.generate_datetimestamps_from_window( window=timeseries_window, frequency=timeseries_enabled.timeseries_frequency, format=timeseries_enabled.timeseries_format - ) + ) for datetimestamp in datetimestamps: document = MappingConverter.generate_synthetic_document(mappings_with_generators) try: @@ -88,7 +86,6 @@ def generate_data_chunk_from_worker(self, docs_per_chunk: int, seed: Optional[in return synthetic_docs - documents = [MappingConverter.generate_synthetic_document(mappings_with_generators) for _ in range(docs_per_chunk)] return documents @@ -101,7 +98,7 @@ def generate_test_document(self, timeseries_enabled: dict = None, timeseries_win if timeseries_enabled and timeseries_enabled.timeseries_field: datetimestamps: Generator = TimeSeriesPartitioner.generate_datetimestamps_from_window( window=timeseries_window, frequency=timeseries_enabled.timeseries_frequency, format=timeseries_enabled.timeseries_format - ) + ) for datetimestamp in datetimestamps: document[timeseries_enabled.timeseries_field] = datetimestamp @@ -158,11 +155,11 @@ def generate_synthetic_document(transformed_mapping: Dict[str, Callable]) -> Dic return document - def generate_text(self, field_def: Dict[str, Any], **params) -> str: - choices = params.get('must_include', None) + def generate_text(self, field_def: Dict[str, Any], **params) -> str: + choices = params.get("must_include", None) analyzer = field_def.get("analyzer", "standard") - #TODO: Need to support other analyzers + # TODO: Need to support other analyzers text = "" if choices: term = random.choice(choices) @@ -174,9 +171,8 @@ def generate_text(self, field_def: Dict[str, Any], **params) -> str: text += f"Sample text for {random.randint(1, 100)}" return text - def generate_keyword(self, field_def: Dict[str, Any], **params) -> str: - choices = params.get('choices', None) + choices = params.get("choices", None) if choices: keyword = random.choice(choices) return keyword @@ -184,43 +180,47 @@ def generate_keyword(self, field_def: Dict[str, Any], **params) -> str: return f"key_{uuid.uuid4().hex[:8]}" def generate_long(self, field_def: Dict[str, Any], **params) -> int: - min = params.get('min', -(2**63 - 1)) - max = params.get('max', (2**63 - 1)) + min = params.get("min", -(2**63 - 1)) + max = params.get("max", (2**63 - 1)) return random.randint(min, max) def generate_integer(self, field_def: Dict[str, Any], **params) -> int: - min = params.get('min', -2147483648) - max = params.get('max', 2147483647) + min = params.get("min", -2147483648) + max = params.get("max", 2147483647) return random.randint(min, max) def generate_short(self, field_def: Dict[str, Any], **params) -> int: - min = params.get('min', -32768) - max = params.get('max', 32767) + min = params.get("min", -32768) + max = params.get("max", 32767) return random.randint(min, max) def generate_byte(self, field_def: Dict[str, Any], **params) -> int: - min = params.get('min', -128) - max = params.get('max', 127) + min = params.get("min", -128) + max = params.get("max", 127) return random.randint(min, max) def generate_double(self, field_def: Dict[str, Any], **params) -> float: - min = params.get('min', -1e9) - max = params.get('max', 1e9) + min = params.get("min", -1e9) + max = params.get("max", 1e9) return random.uniform(min, max) def generate_float(self, field_def: Dict[str, Any], **params) -> float: - min = params.get('min', 0) - max = params.get('max', 1000) - decimal_places = params.get('round', 2) + min = params.get("min", 0) + max = params.get("max", 1000) + decimal_places = params.get("round", 2) float_value = random.uniform(min, max) - return round(float_value , decimal_places) + return round(float_value, decimal_places) def generate_boolean(self, field_def: Dict[str, Any], **params) -> bool: return random.choice([True, False]) - def generate_date(self, field_def: Dict[str, Any], **params,) -> str: + def generate_date( + self, + field_def: Dict[str, Any], + **params, + ) -> str: # TODO Need to handle actual format values # If field definition includes format, then use it. date_format = field_def.get("format", "yyyy-mm-dd") @@ -233,9 +233,7 @@ def generate_date(self, field_def: Dict[str, Any], **params,) -> str: start_dt = datetime.datetime.fromisoformat(start_date) end_dt = datetime.datetime.fromisoformat(end_date) - random_date = start_dt + datetime.timedelta( - days=random.randint(0, (end_dt - start_dt).days) - ) + random_date = start_dt + datetime.timedelta(days=random.randint(0, (end_dt - start_dt).days)) # Apply formatting if date_format == "yyyy-mm-dd": @@ -248,10 +246,7 @@ def generate_ip(self, field_def: Dict[str, Any], **params) -> str: return f"{random.randint(1, 255)}.{random.randint(0, 255)}.{random.randint(0, 255)}.{random.randint(1, 254)}" def generate_geo_point(self, field_def: Dict[str, Any], **params) -> Dict[str, float]: - return { - "lat": random.uniform(-90, 90), - "lon": random.uniform(-180, 180) - } + return {"lat": random.uniform(-90, 90), "lon": random.uniform(-180, 180)} def generate_object(self, field_def: Dict[str, Any], **params) -> Dict[str, Any]: # This will be replaced by the nested fields generator @@ -336,17 +331,17 @@ def generate_sparse_vector(self, field_def: Dict[str, Any], **params) -> Dict[st Returns: Dict of token_id -> weight pairs with positive float values """ - num_tokens = params.get('num_tokens', 10) - min_weight = params.get('min_weight', 0.01) - max_weight = params.get('max_weight', 1.0) - token_id_start = params.get('token_id_start', 1000) - token_id_step = params.get('token_id_step', 100) + num_tokens = params.get("num_tokens", 10) + min_weight = params.get("min_weight", 0.01) + max_weight = params.get("max_weight", 1.0) + token_id_start = params.get("token_id_start", 1000) + token_id_step = params.get("token_id_step", 100) sparse_vector = {} for i in range(num_tokens): token_id = str(token_id_start + (i * token_id_step)) weight = random.uniform(min_weight, max_weight) - sparse_vector[token_id] = round(weight, 4) # imitate real neural sparse search models like Splade and DeepImpact + sparse_vector[token_id] = round(weight, 4) # imitate real neural sparse search models like Splade and DeepImpact return sparse_vector @@ -423,10 +418,8 @@ def transform_mapping_to_generators(self, mapping_dict: Dict[str, Any], field_pa transformed_mapping[field_name] = lambda f=field_def, gen=generator_func, p=generator_override_params: gen(f, **p) - return transformed_mapping - def _generate_obj(self, field_def: Dict[str, Any], nested_generators: Dict[str, Callable]) -> Dict[str, Any]: """Generate an object using nested generators""" result = {} diff --git a/solrorbit/synthetic_data_generator/strategies/strategy.py b/solrorbit/synthetic_data_generator/strategies/strategy.py index b0e39573..a7e8bdea 100644 --- a/solrorbit/synthetic_data_generator/strategies/strategy.py +++ b/solrorbit/synthetic_data_generator/strategies/strategy.py @@ -11,8 +11,8 @@ from dask.distributed import Client -class DataGenerationStrategy(ABC): +class DataGenerationStrategy(ABC): @abstractmethod def generate_data_chunks_across_workers(self, dask_client: Client, docs_per_chunk: int, seeds: list, timeseries_enabled: dict, timeseries_windows: list) -> list: """ @@ -22,8 +22,9 @@ def generate_data_chunks_across_workers(self, dask_client: Client, docs_per_chun """ @abstractmethod - def generate_data_chunk_from_worker(self, generate_synthetic_document: Callable, docs_per_chunk: int, - seed: Optional[int], timeseries_enabled: dict = None, timeseries_window: set = None) -> list: + def generate_data_chunk_from_worker( + self, generate_synthetic_document: Callable, docs_per_chunk: int, seed: Optional[int], timeseries_enabled: dict = None, timeseries_window: set = None + ) -> list: """ Generate chunk of docs with data generation logic for Dask worker diff --git a/solrorbit/synthetic_data_generator/synthetic_data_generator.py b/solrorbit/synthetic_data_generator/synthetic_data_generator.py index ca1a8ef6..0b816cd7 100644 --- a/solrorbit/synthetic_data_generator/synthetic_data_generator.py +++ b/solrorbit/synthetic_data_generator/synthetic_data_generator.py @@ -21,6 +21,7 @@ from solrorbit.synthetic_data_generator.models import SyntheticDataGeneratorMetadata, SDGConfig, GB_TO_BYTES from solrorbit.synthetic_data_generator.timeseries_partitioner import TimeSeriesPartitioner + class SyntheticDataGenerator: def __init__(self, sdg_metadata: SyntheticDataGeneratorMetadata, sdg_config: SDGConfig, strategy: DataGenerationStrategy) -> None: self.sdg_metadata = sdg_metadata @@ -33,7 +34,7 @@ def generate_seeds_for_workers(self, regenerate=False): # This adds latency so might consider deprecating this seed_generation_start_time = time.time() client = get_client() - workers = client.scheduler_info()['workers'] + workers = client.scheduler_info()["workers"] seeds = [] for worker_id in workers.keys(): @@ -57,11 +58,7 @@ def setup_timeseries_window(self, timeseries_enabled_settings: dict, workers: in self.logger.info("User is using timeseries enabled settings: %s", timeseries_enabled_settings) # Generate timeseries windows timeseries_partitioner = TimeSeriesPartitioner( - timeseries_enabled=timeseries_enabled_settings, - workers=workers, - docs_per_chunk=docs_per_chunk, - avg_document_size=avg_document_size, - total_size_bytes=total_size_bytes + timeseries_enabled=timeseries_enabled_settings, workers=workers, docs_per_chunk=docs_per_chunk, avg_document_size=avg_document_size, total_size_bytes=total_size_bytes ) timeseries_window = timeseries_partitioner.create_window_generator() if timeseries_enabled_settings.timeseries_frequency != timeseries_partitioner.frequency: @@ -76,10 +73,9 @@ def generate_test_document(self): if timeseries_enabled_settings: # Use dummy values for workers, docs_per_chunk, and avg_document_size timeseries_enabled_settings, timeseries_window = self.setup_timeseries_window( - timeseries_enabled_settings=timeseries_enabled_settings, workers=1, docs_per_chunk=1, - avg_document_size=123, total_size_bytes=total_size_bytes + timeseries_enabled_settings=timeseries_enabled_settings, workers=1, docs_per_chunk=1, avg_document_size=123, total_size_bytes=total_size_bytes ) - windows_for_workers = [next(timeseries_window) for _ in range(1)][0] # Just need to get one window for test document + windows_for_workers = [next(timeseries_window) for _ in range(1)][0] # Just need to get one window for test document return self.strategy.generate_test_document(timeseries_enabled_settings, windows_for_workers) else: @@ -108,8 +104,11 @@ def generate_dataset(self): workers: int = self.sdg_config.settings.workers if timeseries_enabled_settings: timeseries_enabled_settings, timeseries_window = self.setup_timeseries_window( - timeseries_enabled_settings=timeseries_enabled_settings, workers=workers, docs_per_chunk=docs_per_chunk, - avg_document_size=avg_document_size, total_size_bytes=total_size_bytes + timeseries_enabled_settings=timeseries_enabled_settings, + workers=workers, + docs_per_chunk=docs_per_chunk, + avg_document_size=avg_document_size, + total_size_bytes=total_size_bytes, ) dask_client = Client(n_workers=workers, threads_per_worker=1) # We keep it to 1 thread because generating random data is CPU intensive @@ -126,16 +125,14 @@ def generate_dataset(self): self.logger.info("Total GB to generate: [%s]", self.sdg_metadata.total_size_gb) self.logger.info("Max file size in GB: [%s]", self.sdg_config.settings.max_file_size_gb) - console.println(f"Total GB to generate: [{self.sdg_metadata.total_size_gb}]\n" - f"Average document size in bytes: [{avg_document_size}]\n" - f"Max file size in GB: [{self.sdg_config.settings.max_file_size_gb}]\n") + console.println( + f"Total GB to generate: [{self.sdg_metadata.total_size_gb}]\n" + f"Average document size in bytes: [{avg_document_size}]\n" + f"Max file size in GB: [{self.sdg_config.settings.max_file_size_gb}]\n" + ) start_time = time.time() - with tqdm(total=total_size_bytes, - unit='B', - unit_scale=True, - bar_format="{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]") as progress_bar: - + with tqdm(total=total_size_bytes, unit="B", unit_scale=True, bar_format="{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]") as progress_bar: helpers.setup_custom_tqdm_formatting(progress_bar) while current_size < total_size_bytes: file_path = os.path.join(self.sdg_metadata.output_path, f"{self.sdg_metadata.index_name}_{file_counter}.json") @@ -152,10 +149,7 @@ def generate_dataset(self): windows_for_workers = [next(timeseries_window) for _ in range(workers)] self.logger.info("Windows for workers: %s", windows_for_workers) mp_generation_start_time = time.time() - futures = self.strategy.generate_data_chunks_across_workers( - dask_client, docs_per_chunk, seeds, - timeseries_enabled_settings, windows_for_workers - ) + futures = self.strategy.generate_data_chunks_across_workers(dask_client, docs_per_chunk, seeds, timeseries_enabled_settings, windows_for_workers) results = dask_client.gather(futures) mp_generation_end_time = time.time() mp_generation_took_time = mp_generation_end_time - mp_generation_start_time @@ -165,7 +159,7 @@ def generate_dataset(self): writing_start_time = time.time() for i, res in enumerate(ordered_results): - self.logger.info("Writing results [%s/%s]", i+1, len(ordered_results)) + self.logger.info("Writing results [%s/%s]", i + 1, len(ordered_results)) docs_written_from_chunk, written_bytes = helpers.write_chunk(res, file_path) docs_written += docs_written_from_chunk current_size += written_bytes @@ -193,21 +187,13 @@ def generate_dataset(self): # If it exceeds the max file size, then append this to keep track of record if file_size >= max_file_size_bytes: file_name = os.path.basename(file_path) - generated_dataset_details.append({ - "file_name": file_name, - "docs": docs_written, - "file_size_bytes": file_size - }) + generated_dataset_details.append({"file_name": file_name, "docs": docs_written, "file_size_bytes": file_size}) if current_size >= total_size_bytes: break if current_size >= total_size_bytes: file_name = os.path.basename(file_path) - generated_dataset_details.append({ - "file_name": file_name, - "docs": docs_written, - "file_size_bytes": file_size - }) + generated_dataset_details.append({"file_name": file_name, "docs": docs_written, "file_size_bytes": file_size}) break file_counter += 1 diff --git a/solrorbit/synthetic_data_generator/synthetic_data_generator_orchestrator.py b/solrorbit/synthetic_data_generator/synthetic_data_generator_orchestrator.py index 6e0ce711..bacc60a0 100644 --- a/solrorbit/synthetic_data_generator/synthetic_data_generator_orchestrator.py +++ b/solrorbit/synthetic_data_generator/synthetic_data_generator_orchestrator.py @@ -18,6 +18,7 @@ from solrorbit.synthetic_data_generator.strategies import CustomModuleStrategy, MappingStrategy from solrorbit.synthetic_data_generator.models import SDGConfig + def orchestrate_data_generation(cfg): logger = logging.getLogger(__name__) sdg_metadata = create_sdg_metadata_from_args(cfg) @@ -29,7 +30,7 @@ def orchestrate_data_generation(cfg): if helpers.host_has_available_disk_storage(sdg_metadata): if use_custom_synthetic_data_generator(sdg_metadata): logger.info("Generating data with Custom Module Strategy") - custom_module = helpers.load_user_module(sdg_metadata.custom_module_path) # load it as a callable + custom_module = helpers.load_user_module(sdg_metadata.custom_module_path) # load it as a callable strategy = CustomModuleStrategy(sdg_metadata, sdg_config, custom_module) elif use_mappings_synthetic_data_generator(sdg_metadata): diff --git a/solrorbit/synthetic_data_generator/timeseries_partitioner.py b/solrorbit/synthetic_data_generator/timeseries_partitioner.py index ddc609e6..34bfa9ed 100644 --- a/solrorbit/synthetic_data_generator/timeseries_partitioner.py +++ b/solrorbit/synthetic_data_generator/timeseries_partitioner.py @@ -16,45 +16,45 @@ import solrorbit.exceptions as exceptions -class TimeSeriesPartitioner: +class TimeSeriesPartitioner: # TODO: Change this into a dictionary that points to which frequencies can have which formats VALID_DATETIMESTAMPS_FORMATS = [ - "%Y-%m-%d", # 2023-05-20 - "%Y-%m-%dT%H:%M:%S", # 2023-05-20T15:30:45 - "%Y-%m-%dT%H:%M:%S.%f", # 2023-05-20T15:30:45.123456 - "%Y-%m-%d %H:%M:%S", # 2023-05-20 15:30:45 - "%Y-%m-%d %H:%M:%S.%f", # 2023-05-20 15:30:45.123456 - "%d/%m/%Y", # 20/05/2023 - "%m/%d/%Y", # 05/20/2023 - "%d-%m-%Y", # 20-05-2023 - "%m-%d-%Y", # 05-20-2023 - "%d.%m.%Y", # 20.05.2023 - "%Y%m%d", # 20230520 - "%B %d, %Y", # May 20, 2023 - "%b %d, %Y", # May 20, 2023 - "%d %B %Y", # 20 May 2023 - "%d %b %Y", # 20 May 2023 - "%Y %B %d", # 2023 May 20 - "%d/%m/%Y %H:%M", # 20/05/2023 15:30 - "%d/%m/%Y %H:%M:%S", # 20/05/2023 15:30:45 - "%Y-%m-%d %I:%M %p", # 2023-05-20 03:30 PM - "%d.%m.%Y %H:%M", # 20.05.2023 15:30 - "%H:%M", # 15:30 - "%H:%M:%S", # 15:30:45 - "%I:%M %p", # 03:30 PM - "%I:%M:%S %p", # 03:30:45 PM - "%a, %d %b %Y %H:%M:%S", # Sat, 20 May 2023 15:30:45 - "%Y/%m/%d", # 2023/05/20 - "%Y/%m/%d %H:%M:%S", # 2023/05/20 15:30:45 - "%Y%m%d%H%M%S", # 20230520153045 - "epoch_s", # Epoch time in seconds format - "epoch_ms" # Epoch time in ms format + "%Y-%m-%d", # 2023-05-20 + "%Y-%m-%dT%H:%M:%S", # 2023-05-20T15:30:45 + "%Y-%m-%dT%H:%M:%S.%f", # 2023-05-20T15:30:45.123456 + "%Y-%m-%d %H:%M:%S", # 2023-05-20 15:30:45 + "%Y-%m-%d %H:%M:%S.%f", # 2023-05-20 15:30:45.123456 + "%d/%m/%Y", # 20/05/2023 + "%m/%d/%Y", # 05/20/2023 + "%d-%m-%Y", # 20-05-2023 + "%m-%d-%Y", # 05-20-2023 + "%d.%m.%Y", # 20.05.2023 + "%Y%m%d", # 20230520 + "%B %d, %Y", # May 20, 2023 + "%b %d, %Y", # May 20, 2023 + "%d %B %Y", # 20 May 2023 + "%d %b %Y", # 20 May 2023 + "%Y %B %d", # 2023 May 20 + "%d/%m/%Y %H:%M", # 20/05/2023 15:30 + "%d/%m/%Y %H:%M:%S", # 20/05/2023 15:30:45 + "%Y-%m-%d %I:%M %p", # 2023-05-20 03:30 PM + "%d.%m.%Y %H:%M", # 20.05.2023 15:30 + "%H:%M", # 15:30 + "%H:%M:%S", # 15:30:45 + "%I:%M %p", # 03:30 PM + "%I:%M:%S %p", # 03:30:45 PM + "%a, %d %b %Y %H:%M:%S", # Sat, 20 May 2023 15:30:45 + "%Y/%m/%d", # 2023/05/20 + "%Y/%m/%d %H:%M:%S", # 2023/05/20 15:30:45 + "%Y%m%d%H%M%S", # 20230520153045 + "epoch_s", # Epoch time in seconds format + "epoch_ms", # Epoch time in ms format ] # TODO: Let's make this a hashmap so that we can ensure the invalid formats are not used (e.g. frequency is updated to ms and format is still seconds) # These frequencies are based on what is supported in the Pandas library - AVAILABLE_FREQUENCIES = ['B', 'C', 'D', 'h', 'bh', 'cbh', 'min', 's', 'ms'] + AVAILABLE_FREQUENCIES = ["B", "C", "D", "h", "bh", "cbh", "min", "s", "ms"] def __init__(self, timeseries_enabled: dict, workers: int, docs_per_chunk: int, avg_document_size: int, total_size_bytes: int): self.timeseries_enabled = timeseries_enabled @@ -88,9 +88,9 @@ def get_updated_settings(self, timeseries_settings) -> dict: return timeseries_settings def create_window_generator(self) -> Generator: - ''' + """ returns: a list of timestamp pairs where each timestamp pair is a set containing start datetime and end datetime - ''' + """ # Determine optimal time settings # Check if number of docs generated will fit in the timestamp. Adjust frequency as needed expected_number_of_docs = self.total_size_bytes // self.avg_document_size @@ -102,12 +102,12 @@ def create_window_generator(self) -> Generator: if number_of_timestamps < expected_number_of_docs_with_buffer: self.logger.info("Number of timestamps generated is less than expected docs generated. Trying to find the optimal frequency") # ms is the smallest unit of time SDG can generate - if self.frequency == 'ms': - msg = "No finer time frequencies available to try than \"ms\". Please expand dates and frequency accordingly." + if self.frequency == "ms": + msg = 'No finer time frequencies available to try than "ms". Please expand dates and frequency accordingly.' self.logger.error(msg) raise exceptions.ConfigError(msg) - #TODO: Update the timeseries enabled settings too so downstream isn't confused + # TODO: Update the timeseries enabled settings too so downstream isn't confused optimal_frequency = self._try_other_frequencies(expected_number_of_docs_with_buffer) if not self._does_user_want_optimal_frequency(user_frequency=self.frequency, optimal_frequency=optimal_frequency): self.logger.info("User does not want to use optimal frequency and will cancel generation.") @@ -123,7 +123,7 @@ def create_window_generator(self) -> Generator: def generate_datetimestamp_window(self): current = pd.Timestamp(self.start_date) end = pd.Timestamp(self.end_date) - freq = pd.Timedelta(f"{self.docs_per_chunk-1}{self.frequency}") # Need to subtract one to include current timestamp. + freq = pd.Timedelta(f"{self.docs_per_chunk - 1}{self.frequency}") # Need to subtract one to include current timestamp. while current < end: window_end = min(current + freq, end) @@ -144,7 +144,7 @@ def generate_datetimestamps_from_window(window: set, frequency: str = "min", for start_datetimestamp = window[0] end_datetimestamp = window[1] generated_datetimestamps: pd.DatetimeIndex = pd.date_range(start_datetimestamp, end_datetimestamp, freq=frequency) - #TODO: Handle formatting after generating iterator? + # TODO: Handle formatting after generating iterator? if format and format in TimeSeriesPartitioner.VALID_DATETIMESTAMPS_FORMATS: if format == "epoch_s": generated_datetimestamps = generated_datetimestamps.map(lambda x: int(x.timestamp())) @@ -166,11 +166,10 @@ def sort_results_by_datetimestamps(results: list, timeseries_field: str) -> list logger.info("Length of results: %s", len(results)) logger.info("Docs in each result: %s ", [len(result) for result in results]) - start_time = time.time() sorted_results = sorted(results, key=lambda chunk: chunk[0][timeseries_field]) end_time = time.time() - logger.info("Time it took to sort: %s secs", end_time-start_time) + logger.info("Time it took to sort: %s secs", end_time - start_time) logger.info("First timestamp from all chunks: %s ", [result[0][timeseries_field] for result in sorted_results]) return sorted_results @@ -193,9 +192,8 @@ def _count_timestamps(self, frequency: str) -> int: count = int(delta / offset) + 1 return count - def _try_other_frequencies(self, expected_number_of_docs_with_buffer: int) -> str: - frequencies_to_try = deque(TimeSeriesPartitioner.AVAILABLE_FREQUENCIES[TimeSeriesPartitioner.AVAILABLE_FREQUENCIES.index(self.frequency)+1:]) + frequencies_to_try = deque(TimeSeriesPartitioner.AVAILABLE_FREQUENCIES[TimeSeriesPartitioner.AVAILABLE_FREQUENCIES.index(self.frequency) + 1 :]) frequency = "" while frequencies_to_try: @@ -210,16 +208,20 @@ def _try_other_frequencies(self, expected_number_of_docs_with_buffer: int) -> st return frequency def _does_user_want_optimal_frequency(self, user_frequency: str, optimal_frequency: str) -> bool: - valid_responses = ['y', 'yes', 'n', 'no'] - msg = f"The frequency [{optimal_frequency}] is a better option for the number of docs you are trying to generate " + \ - "because the current frequency you've selected does not have enough timestamps to allocate to docs generated." + \ - f"If you prefer your current frequency [{user_frequency}], please extend the time frame. " + \ - f"Would you like to use [{optimal_frequency}] as the frequency? (y/n): " + valid_responses = ["y", "yes", "n", "no"] + msg = ( + f"The frequency [{optimal_frequency}] is a better option for the number of docs you are trying to generate " + + "because the current frequency you've selected does not have enough timestamps to allocate to docs generated." + + f"If you prefer your current frequency [{user_frequency}], please extend the time frame. " + + f"Would you like to use [{optimal_frequency}] as the frequency? (y/n): " + ) requested_input = input(msg) while requested_input.lower() not in valid_responses: - msg = f"Please enter y or n. The frequency [{optimal_frequency}] is a better option for the number of docs you are trying to generate. " + \ - f"If you prefer your current frequency [{user_frequency}], please extend the time frame. " + \ - f"Would you like to use [{optimal_frequency}] as the frequency? (y/n): " + msg = ( + f"Please enter y or n. The frequency [{optimal_frequency}] is a better option for the number of docs you are trying to generate. " + + f"If you prefer your current frequency [{user_frequency}], please extend the time frame. " + + f"Would you like to use [{optimal_frequency}] as the frequency? (y/n): " + ) requested_input = input(msg) - return requested_input.lower() in ['y', 'yes'] + return requested_input.lower() in ["y", "yes"] diff --git a/solrorbit/synthetic_data_generator/types.py b/solrorbit/synthetic_data_generator/types.py index 99ebbbea..fd4af7c2 100644 --- a/solrorbit/synthetic_data_generator/types.py +++ b/solrorbit/synthetic_data_generator/types.py @@ -10,15 +10,10 @@ from dataclasses import dataclass, field from typing import Optional -GB_TO_BYTES = 1024 ** 3 +GB_TO_BYTES = 1024**3 + +DEFAULT_GENERATION_SETTINGS = {"workers": os.cpu_count(), "max_file_size_gb": 40, "docs_per_chunk": 10000, "filename_suffix_begins_at": 0, "timeseries_enabled": {}} -DEFAULT_GENERATION_SETTINGS = { - "workers": os.cpu_count(), - "max_file_size_gb": 40, - "docs_per_chunk": 10000, - "filename_suffix_begins_at": 0, - "timeseries_enabled": {} -} @dataclass class SyntheticDataGeneratorMetadata: diff --git a/solrorbit/telemetry.py b/solrorbit/telemetry.py index bb0c90e8..f9406279 100644 --- a/solrorbit/telemetry.py +++ b/solrorbit/telemetry.py @@ -16,7 +16,7 @@ # not use this file except in compliance with the License. # You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an @@ -37,13 +37,15 @@ from solrorbit import metrics, time, exceptions from solrorbit.utils import io, sysstats, console, process + def list_telemetry(): console.println("Available telemetry devices:\n") # --- Solr-native devices (always enabled) --- console.println("Always-enabled Solr devices (no --telemetry flag needed):\n") solr_devices = [ - [d.command, d.human_name, d.help] for d in [ + [d.command, d.human_name, d.help] + for d in [ SolrJvmStats, SolrNodeStats, SolrCollectionStats, @@ -57,16 +59,27 @@ def list_telemetry(): # --- Optional REST devices (all pipelines) --- console.println("\n\nOptional REST devices (all pipelines — enable with --telemetry ):\n") - rest_devices = [[device.command, device.human_name, device.help] for device in [ - SegmentStats, ShardStats, ClusterEnvironmentInfo, - ]] + rest_devices = [ + [device.command, device.human_name, device.help] + for device in [ + SegmentStats, + ShardStats, + ClusterEnvironmentInfo, + ] + ] console.println(tabulate.tabulate(rest_devices, ["Command", "Name", "Description"])) # --- Optional JVM/process devices (provisioned pipelines only) --- console.println("\n\nOptional JVM/process devices (docker or from-distribution pipelines only):\n") - jvm_devices = [[device.command, device.human_name, device.help] for device in [ - FlightRecorder, Gc, JitCompiler, Heapdump, - ]] + jvm_devices = [ + [device.command, device.human_name, device.help] + for device in [ + FlightRecorder, + Gc, + JitCompiler, + Heapdump, + ] + ] console.println(tabulate.tabulate(jvm_devices, ["Command", "Name", "Description"])) console.println("\nJVM/process devices inject flags into SOLR_OPTS before Solr starts.") console.println("They are silently skipped when pipeline is benchmark-only.") @@ -131,6 +144,7 @@ def _enabled(self, device): # ######################################################################################## + class TelemetryDevice: def __init__(self): self.logger = logging.getLogger(__name__) @@ -246,8 +260,7 @@ def instrument_java_opts(self): io.ensure_dir(self.log_root) log_file = os.path.join(self.log_root, "jit.log") console.info("%s: Writing JIT compiler log to [%s]" % (self.human_name, log_file), logger=self.logger) - return ["-XX:+UnlockDiagnosticVMOptions", "-XX:+TraceClassLoading", "-XX:+LogCompilation", - "-XX:LogFile={}".format(log_file), "-XX:+PrintAssembly"] + return ["-XX:+UnlockDiagnosticVMOptions", "-XX:+TraceClassLoading", "-XX:+LogCompilation", "-XX:LogFile={}".format(log_file), "-XX:+PrintAssembly"] class Gc(TelemetryDevice): @@ -293,8 +306,7 @@ def detach_from_node(self, node, running): # noinspection PyBroadException try: if self.docker_container: - cmd = "docker exec {} jmap -dump:format=b,file={} {}".format( - self.docker_container, heap_dump_file, node.pid) + cmd = "docker exec {} jmap -dump:format=b,file={} {}".format(self.docker_container, heap_dump_file, node.pid) else: cmd = "jmap -dump:format=b,file={} {}".format(heap_dump_file, node.pid) if process.run_subprocess_with_logging(cmd): @@ -363,9 +375,7 @@ def __init__(self, telemetry_params, admin_client, metrics_store): self.metrics_store = metrics_store self.sample_interval = telemetry_params.get("shard-stats-sample-interval", 60) if self.sample_interval <= 0: - raise exceptions.SystemSetupError( - f"The telemetry parameter 'shard-stats-sample-interval' must be greater than zero but was {self.sample_interval}." - ) + raise exceptions.SystemSetupError(f"The telemetry parameter 'shard-stats-sample-interval' must be greater than zero but was {self.sample_interval}.") self.samplers = [] def on_benchmark_start(self): @@ -429,14 +439,13 @@ def record(self): idx = core_status.get("index", {}) num_docs = idx.get("numDocs", 0) size_bytes = idx.get("sizeInBytes", 0) - self.metrics_store.put_value_cluster_level( - f"shard_{shard_name}_num_docs", num_docs, "") - self.metrics_store.put_value_cluster_level( - f"shard_{shard_name}_size_bytes", size_bytes, "byte") + self.metrics_store.put_value_cluster_level(f"shard_{shard_name}_num_docs", num_docs, "") + self.metrics_store.put_value_cluster_level(f"shard_{shard_name}_size_bytes", size_bytes, "byte") except BaseException: self.logger.warning("ShardStats: could not get core STATUS for [%s].", core_name) break # only need the leader replica per shard + class StartupTime(InternalTelemetryDevice): def __init__(self, stopwatch=time.StopWatch): super().__init__() @@ -456,6 +465,7 @@ class DiskIo(InternalTelemetryDevice): """ Gathers disk I/O stats. """ + def __init__(self, node_count_on_host): super().__init__() self.node_count_on_host = node_count_on_host @@ -475,8 +485,7 @@ def attach_to_node(self, node): disk_start = sysstats.disk_io_counters() self.read_bytes = disk_start.read_bytes self.write_bytes = disk_start.write_bytes - self.logger.warning("Process I/O counters are not supported on this platform. Falling back to less " - "accurate disk I/O counters.") + self.logger.warning("Process I/O counters are not supported on this platform. Falling back to less accurate disk I/O counters.") except BaseException: self.logger.exception("Could not determine I/O stats at benchmark start.") @@ -495,9 +504,12 @@ def detach_from_node(self, node, running): else: disk_end = sysstats.disk_io_counters() if self.node_count_on_host > 1: - self.logger.info("There are [%d] nodes on this host and Solr Orbit fell back to disk I/O counters. " - "Attributing [1/%d] of total I/O to [%s].", - self.node_count_on_host, self.node_count_on_host, node.node_name) + self.logger.info( + "There are [%d] nodes on this host and Solr Orbit fell back to disk I/O counters. Attributing [1/%d] of total I/O to [%s].", + self.node_count_on_host, + self.node_count_on_host, + node.node_name, + ) self.read_bytes = (disk_end.read_bytes - self.read_bytes) // self.node_count_on_host self.write_bytes = (disk_end.write_bytes - self.write_bytes) // self.node_count_on_host @@ -568,6 +580,7 @@ class ClusterEnvironmentInfo(TelemetryDevice): Gathers static environment information on a cluster level (Solr version, JVM, CPU). Called once at benchmark start; stores results as run metadata. """ + internal = False command = "cluster-environment-info" human_name = "Cluster Environment Info" @@ -625,11 +638,11 @@ def add_metadata_for_node(metrics_store, node_name, host_name): metrics_store.add_meta_info(metrics.MetaInfoScope.node, node_name, "host_name", host_name) - class IndexSize(InternalTelemetryDevice): """ Measures the final size of the index """ + def __init__(self, data_paths): super().__init__() self.data_paths = data_paths @@ -661,6 +674,7 @@ def store_system_metrics(self, node, metrics_store): # Prometheus text format parser (shared with runner.py) # --------------------------------------------------------------------------- + def _parse_prometheus_text(text: str) -> dict: """ Parse Prometheus exposition text format into a flat dict of {metric_name: float}. @@ -693,6 +707,7 @@ def _parse_prometheus_text(text: str) -> dict: # Base class # --------------------------------------------------------------------------- + class SolrTelemetryDevice(TelemetryDevice): """ Abstract base for Solr telemetry polling devices. @@ -780,8 +795,11 @@ def _put(self, name: str, value, unit: str, task: str = "", meta: dict = None) - self._metrics_store[name] = {"value": value, "unit": unit} return self._metrics_store.put_value_cluster_level( - name=name, value=value, unit=unit, - task=task, operation_type="telemetry", + name=name, + value=value, + unit=unit, + task=task, + operation_type="telemetry", meta_data=meta or {}, ) @@ -790,6 +808,7 @@ def _put(self, name: str, value, unit: str, task: str = "", meta: dict = None) - # Device: SolrJvmStats # --------------------------------------------------------------------------- + class SolrJvmStats(SolrTelemetryDevice): """ Collect JVM heap, GC, thread, and buffer pool metrics from Solr. @@ -892,6 +911,7 @@ def _collect_prometheus(self, data: dict) -> None: # Device: SolrNodeStats # --------------------------------------------------------------------------- + class SolrNodeStats(SolrTelemetryDevice): """ Collect OS, file-descriptor, HTTP, and query-handler metrics from Solr. @@ -957,9 +977,7 @@ def _collect_metrics_json(self, data: dict) -> None: self._put("query_handler_avg_latency_ms", avg_latency, "ms") jetty = self._get_metric_json(data, "metrics", "solr.jetty") or {} - http_requests = jetty.get( - "org.eclipse.jetty.server.handler.StatisticsHandler.requests" - ) + http_requests = jetty.get("org.eclipse.jetty.server.handler.StatisticsHandler.requests") if http_requests is not None: self._put("node_http_requests_total", http_requests, "") @@ -980,6 +998,7 @@ def _collect_metrics_prometheus(self, data: dict) -> None: # Device: SolrCollectionStats # --------------------------------------------------------------------------- + class SolrCollectionStats(SolrTelemetryDevice): """ Collect per-collection document count, index size, segment count, and deleted docs. @@ -990,8 +1009,7 @@ class SolrCollectionStats(SolrTelemetryDevice): human_name = "Solr Collection Stats" help = "Per-collection: doc count, deleted docs, index size, and segment count (30 s interval)" - def __init__(self, admin_client, metrics_store, - collections: list = None, sample_interval_s: float = 30.0): + def __init__(self, admin_client, metrics_store, collections: list = None, sample_interval_s: float = 30.0): super().__init__(admin_client, metrics_store, sample_interval_s) self._collections = collections @@ -1018,8 +1036,7 @@ def _collect_collection(self, collection: str) -> None: self._put("num_docs", num_docs, "docs", meta={"collection": collection}) if index_size: - self._put("index_size_bytes", index_size, "bytes", - meta={"collection": collection}) + self._put("index_size_bytes", index_size, "bytes", meta={"collection": collection}) except Exception: pass @@ -1027,9 +1044,7 @@ def _collect_collection(self, collection: str) -> None: def _fetch_luke_stats(self, collection: str) -> None: try: - resp = self._client._get( - f"/solr/{collection}/admin/luke?numTerms=0&wt=json" - ) + resp = self._client._get(f"/solr/{collection}/admin/luke?numTerms=0&wt=json") info = resp.json().get("index", {}) num_docs = info.get("numDocs") deleted_docs = info.get("deletedDocs") or info.get("numDeletedDocs") @@ -1038,20 +1053,18 @@ def _fetch_luke_stats(self, collection: str) -> None: if num_docs is not None: self._put("num_docs", num_docs, "docs", meta={"collection": collection}) if deleted_docs is not None: - self._put("num_deleted_docs", deleted_docs, "docs", - meta={"collection": collection}) + self._put("num_deleted_docs", deleted_docs, "docs", meta={"collection": collection}) if segment_count is not None: - self._put("segment_count", segment_count, "", - meta={"collection": collection}) + self._put("segment_count", segment_count, "", meta={"collection": collection}) except Exception as exc: - logging.getLogger(__name__).debug("SolrCollectionStats: luke fallback failed for %s: %s", - collection, exc) + logging.getLogger(__name__).debug("SolrCollectionStats: luke fallback failed for %s: %s", collection, exc) # --------------------------------------------------------------------------- # Device: SolrQueryStats # --------------------------------------------------------------------------- + class SolrQueryStats(SolrTelemetryDevice): """ Collect query latency percentiles and cache hit ratio metrics from Solr. @@ -1107,6 +1120,7 @@ def _collect_prometheus(self, data: dict) -> None: # Device: SolrIndexingStats # --------------------------------------------------------------------------- + class SolrIndexingStats(SolrTelemetryDevice): """ Collect indexing throughput and merge metrics from Solr. @@ -1158,6 +1172,7 @@ def _collect_prometheus(self, data: dict) -> None: # Device: SolrCacheStats # --------------------------------------------------------------------------- + class SolrCacheStats(SolrTelemetryDevice): """ Collect Solr internal cache statistics for the three primary caches. diff --git a/solrorbit/test_run_orchestrator.py b/solrorbit/test_run_orchestrator.py index 7ff1af4c..9d2907f8 100644 --- a/solrorbit/test_run_orchestrator.py +++ b/solrorbit/test_run_orchestrator.py @@ -16,7 +16,7 @@ # not use this file except in compliance with the License. # You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an @@ -36,9 +36,7 @@ import tabulate import thespian.actors -from solrorbit import actor, config, doc_link, \ - worker_coordinator, exceptions, builder, metrics, \ - publisher, workload, version, PROGRAM_NAME +from solrorbit import actor, config, doc_link, worker_coordinator, exceptions, builder, metrics, publisher, workload, version, PROGRAM_NAME from solrorbit.builder import cluster_config as cc_module from solrorbit.builder.supplier import SourceRepository, Builder from solrorbit.builder.solr_provisioner import SolrProvisioner, SolrDockerLauncher @@ -117,21 +115,13 @@ def receiveMsg_Setup(self, msg, sender): self.coordinator.setup(sources=msg.sources) self.logger.info("Asking builder to start the engine.") self.builder = self.createActor(builder.BuilderActor, targetActorRequirements={"coordinator": True}) - self.send(self.builder, builder.StartEngine(self.cfg, - self.coordinator.metrics_store.open_context, - msg.sources, - msg.distribution, - msg.external, - msg.docker)) + self.send(self.builder, builder.StartEngine(self.cfg, self.coordinator.metrics_store.open_context, msg.sources, msg.distribution, msg.external, msg.docker)) @actor.no_retry("test run orchestrator") # pylint: disable=no-value-for-parameter def receiveMsg_EngineStarted(self, msg, sender): self.logger.info("Builder has started engine successfully.") self.coordinator.test_run.cluster_config_revision = msg.cluster_config_revision - self.main_worker_coordinator = self.createActor( - worker_coordinator.WorkerCoordinatorActor, - targetActorRequirements={"coordinator": True} - ) + self.main_worker_coordinator = self.createActor(worker_coordinator.WorkerCoordinatorActor, targetActorRequirements={"coordinator": True}) self.logger.info("Telling worker_coordinator to prepare for benchmarking.") self.send(self.main_worker_coordinator, worker_coordinator.PrepareBenchmark(self.cfg, self.coordinator.current_workload)) @@ -215,6 +205,7 @@ def _check_workload_is_solr_native(self): return from solrorbit.conversion.detector import is_opensearch_workload_path + if is_opensearch_workload_path(workload_path): msg = ( f"This workload is in OpenSearch Benchmark format and cannot be run directly.\n" @@ -225,9 +216,7 @@ def _check_workload_is_solr_native(self): f"Then re-run with --workload-path {workload_path}-solr" ) console.error(msg) - raise exceptions.SystemSetupError( - f"OSB workload detected at '{workload_path}' — convert it first with 'solr-orbit convert-workload'" - ) + raise exceptions.SystemSetupError(f"OSB workload detected at '{workload_path}' — convert it first with 'solr-orbit convert-workload'") def setup(self, sources=False): # to load the workload we need to know the correct cluster distribution version. Usually, this value should be set @@ -252,21 +241,15 @@ def setup(self, sources=False): if self.current_test_procedure is None: raise exceptions.SystemSetupError( "Workload [{}] does not provide test_procedure [{}]. List the available workloads with {} list workloads.".format( - self.current_workload.name, test_procedure_name, PROGRAM_NAME)) + self.current_workload.name, test_procedure_name, PROGRAM_NAME + ) + ) if self.current_test_procedure.user_info: console.info(self.current_test_procedure.user_info) - self.test_run = metrics.create_test_run( - self.cfg, self.current_workload, - self.current_test_procedure, - self.workload_revision) - - self.metrics_store = metrics.metrics_store( - self.cfg, - workload=self.test_run.workload_name, - test_procedure=self.test_run.test_procedure_name, - read_only=False - ) + self.test_run = metrics.create_test_run(self.cfg, self.current_workload, self.current_test_procedure, self.workload_revision) + + self.metrics_store = metrics.metrics_store(self.cfg, workload=self.test_run.workload_name, test_procedure=self.test_run.test_procedure_name, read_only=False) self.test_run_store = metrics.test_run_store(self.cfg) def on_preparation_complete(self, distribution_flavor, distribution_version, revision): @@ -286,20 +269,17 @@ def on_preparation_complete(self, distribution_flavor, distribution_version, rev # pipeline = how the cluster is provisioned (e.g., "docker", "from-sources", "benchmark-only") cluster_cfg_display = ", ".join(self.test_run.cluster_config or ["none"]) if self.test_run.test_procedure.auto_generated: - console.info("Running benchmark with pipeline [{}], workload [{}], cluster_config [{}], version [{}].\n" - .format(self.test_run.pipeline, - self.test_run.workload_name, - cluster_cfg_display, - self.test_run.distribution_version or "unknown")) + console.info( + "Running benchmark with pipeline [{}], workload [{}], cluster_config [{}], version [{}].\n".format( + self.test_run.pipeline, self.test_run.workload_name, cluster_cfg_display, self.test_run.distribution_version or "unknown" + ) + ) else: - console.info("Running benchmark with pipeline [{}], workload [{}], test_procedure [{}], cluster_config [{}], version [{}].\n" - .format( - self.test_run.pipeline, - self.test_run.workload_name, - self.test_run.test_procedure_name, - cluster_cfg_display, - self.test_run.distribution_version or "unknown" - )) + console.info( + "Running benchmark with pipeline [{}], workload [{}], test_procedure [{}], cluster_config [{}], version [{}].\n".format( + self.test_run.pipeline, self.test_run.workload_name, self.test_run.test_procedure_name, cluster_cfg_display, self.test_run.distribution_version or "unknown" + ) + ) def on_task_finished(self, new_metrics): self.logger.info("Task has finished.") @@ -373,7 +353,7 @@ def set_default_hosts(cfg, host="127.0.0.1", port=9200): logger.info("Using configured hosts %s", configured_hosts.default) else: logger.info("Setting default host to [%s:%d]", host, port) - default_host_object = opts.TargetHosts("{}:{}".format(host,port)) + default_host_object = opts.TargetHosts("{}:{}".format(host, port)) cfg.add(config.Scope.benchmark, "client", "hosts", default_host_object) @@ -383,14 +363,14 @@ def benchmark_only(cfg): return run_test(cfg, external=True) -Pipeline("benchmark-only", - "Assumes an already running search engine instance, runs a benchmark and publishes results", benchmark_only) +Pipeline("benchmark-only", "Assumes an already running search engine instance, runs a benchmark and publishes results", benchmark_only) # --------------------------------------------------------------------------- # Solr-specific pipelines # --------------------------------------------------------------------------- + def _load_cluster_config(cfg): """ Load the cluster_config instance from the configured INI repository. @@ -437,11 +417,9 @@ def solr_from_sources(cfg): revision = cfg.opts("builder", "source.revision", mandatory=False, default_value="latest") port = int(cfg.opts("solr", "port", mandatory=False, default_value=8983)) src_dir = os.path.join(base_dir, "sources", "solr") - install_dir = cfg.opts("solr", "install_dir", mandatory=False, - default_value=os.path.join(base_dir, "builds", revision or "latest")) + install_dir = cfg.opts("solr", "install_dir", mandatory=False, default_value=os.path.join(base_dir, "builds", revision or "latest")) log_dir = os.path.join(base_dir, "logs") - remote_url = cfg.opts("source", "remote.repo.url", mandatory=False, - default_value="https://github.com/apache/solr.git") + remote_url = cfg.opts("source", "remote.repo.url", mandatory=False, default_value="https://github.com/apache/solr.git") # Step 1: Clone / update source tree logger.info("Fetching Solr sources at revision [%s] from [%s].", revision, remote_url) @@ -457,10 +435,7 @@ def solr_from_sources(cfg): pattern = os.path.join(src_dir, "solr", "packaging", "build", "distributions", "solr-*.tgz") tarballs = glob.glob(pattern) if not tarballs: - raise exceptions.SystemSetupError( - f"No Solr tarball found matching {pattern}. " - f"Check the Gradle build log at {os.path.join(log_dir, 'build.log')}." - ) + raise exceptions.SystemSetupError(f"No Solr tarball found matching {pattern}. Check the Gradle build log at {os.path.join(log_dir, 'build.log')}.") tarball_path = sorted(tarballs)[-1] # pick the newest if multiple logger.info("Using built Solr tarball: %s", tarball_path) @@ -473,8 +448,7 @@ def solr_from_sources(cfg): cc_instance = _load_cluster_config(cfg) solr_modules = cfg.opts("solr", "modules", mandatory=False, default_value="") - provisioner = SolrProvisioner(cache_dir=os.path.join(base_dir, "cache"), port=port, - cluster_config=cc_instance, solr_modules=solr_modules) + provisioner = SolrProvisioner(cache_dir=os.path.join(base_dir, "cache"), port=port, cluster_config=cc_instance, solr_modules=solr_modules) try: provisioner.start(solr_root, mode="cloud") set_default_hosts(cfg, host="127.0.0.1", port=port) @@ -490,9 +464,7 @@ def solr_from_sources(cfg): logger.warning("Solr clean failed during teardown: %s", exc) -Pipeline("from-sources", - "Builds Solr from source (git clone + Gradle assemble), provisions it locally, " - "runs a benchmark, and tears down.", solr_from_sources) +Pipeline("from-sources", "Builds Solr from source (git clone + Gradle assemble), provisions it locally, runs a benchmark, and tears down.", solr_from_sources) def solr_from_distribution(cfg): @@ -511,15 +483,12 @@ def solr_from_distribution(cfg): version_str = cfg.opts("builder", "distribution.version") port = int(cfg.opts("solr", "port", mandatory=False, default_value=8983)) base_dir = os.path.expanduser("~/.solr-orbit") - install_dir = cfg.opts("solr", "install_dir", mandatory=False, - default_value=os.path.join(base_dir, "installs", version_str)) - cache_dir = cfg.opts("solr", "cache_dir", mandatory=False, - default_value=os.path.join(base_dir, "cache")) + install_dir = cfg.opts("solr", "install_dir", mandatory=False, default_value=os.path.join(base_dir, "installs", version_str)) + cache_dir = cfg.opts("solr", "cache_dir", mandatory=False, default_value=os.path.join(base_dir, "cache")) cc_instance = _load_cluster_config(cfg) solr_modules = cfg.opts("solr", "modules", mandatory=False, default_value="") - provisioner = SolrProvisioner(cache_dir=cache_dir, port=port, cluster_config=cc_instance, - solr_modules=solr_modules) + provisioner = SolrProvisioner(cache_dir=cache_dir, port=port, cluster_config=cc_instance, solr_modules=solr_modules) _tarball = provisioner.download(version_str) solr_root = provisioner.install(version_str, install_dir) @@ -576,11 +545,9 @@ def solr_docker(cfg): logging.getLogger(__name__).warning("Solr Docker teardown failed: %s", exc) -Pipeline("from-distribution", - "Downloads a Solr distribution, provisions it locally, runs a benchmark, and tears down.", solr_from_distribution) +Pipeline("from-distribution", "Downloads a Solr distribution, provisions it locally, runs a benchmark, and tears down.", solr_from_distribution) -Pipeline("docker", - "Starts Solr via Docker, runs a benchmark, and removes the container on teardown.", solr_docker) +Pipeline("docker", "Starts Solr via Docker, runs a benchmark, and removes the container on teardown.", solr_docker) def available_pipelines(): @@ -616,14 +583,13 @@ def run(cfg): raise exceptions.SystemSetupError( "Only the [benchmark-only] pipeline is supported by the Docker image.\n" "Add --pipeline=benchmark-only in your arguments and try again.\n" - "For more details read the docs for the benchmark-only pipeline in {}\n".format( - doc_link(""))) + "For more details read the docs for the benchmark-only pipeline in {}\n".format(doc_link("")) + ) try: pipeline = pipelines[name] except KeyError: - raise exceptions.SystemSetupError( - "Unknown pipeline [%s]. List the available pipelines with %s list pipelines." % (name, PROGRAM_NAME)) + raise exceptions.SystemSetupError("Unknown pipeline [%s]. List the available pipelines with %s list pipelines." % (name, PROGRAM_NAME)) try: pipeline(cfg) except exceptions.BenchmarkError as e: diff --git a/solrorbit/time.py b/solrorbit/time.py index 896aeb34..af85116f 100644 --- a/solrorbit/time.py +++ b/solrorbit/time.py @@ -16,7 +16,7 @@ # not use this file except in compliance with the License. # You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an diff --git a/solrorbit/utils/__init__.py b/solrorbit/utils/__init__.py index 5047a451..f5768141 100644 --- a/solrorbit/utils/__init__.py +++ b/solrorbit/utils/__init__.py @@ -16,7 +16,7 @@ # not use this file except in compliance with the License. # You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an diff --git a/solrorbit/utils/collections.py b/solrorbit/utils/collections.py index 60f4797c..410e97ff 100644 --- a/solrorbit/utils/collections.py +++ b/solrorbit/utils/collections.py @@ -16,7 +16,7 @@ # not use this file except in compliance with the License. # You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an diff --git a/solrorbit/utils/console.py b/solrorbit/utils/console.py index 02a97f42..e9789b11 100644 --- a/solrorbit/utils/console.py +++ b/solrorbit/utils/console.py @@ -16,7 +16,7 @@ # not use this file except in compliance with the License. # You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an @@ -124,7 +124,7 @@ def init(quiet=False, assume_tty=True): except (KeyError, ValueError): # noinspection PyBroadException try: - os.environ['COLUMNS'] = str(shutil.get_terminal_size().columns) + os.environ["COLUMNS"] = str(shutil.get_terminal_size().columns) except BaseException: # don't fail if anything goes wrong here pass @@ -142,23 +142,20 @@ def set_assume_tty(assume_tty): def info(msg, end="\n", flush=False, force=False, logger=None, overline=None, underline=None): - println(msg, console_prefix="[INFO]", end=end, flush=flush, force=force, overline=overline, underline=underline, - logger=logger.info if logger else None) + println(msg, console_prefix="[INFO]", end=end, flush=flush, force=force, overline=overline, underline=underline, logger=logger.info if logger else None) def warn(msg, end="\n", flush=False, force=False, logger=None, overline=None, underline=None): - println(msg, console_prefix="[WARNING]", end=end, flush=flush, force=force, overline=overline, underline=underline - , logger=logger.warning if logger else None) + println(msg, console_prefix="[WARNING]", end=end, flush=flush, force=force, overline=overline, underline=underline, logger=logger.warning if logger else None) def error(msg, end="\n", flush=False, force=False, logger=None, overline=None, underline=None): - println(msg, console_prefix="[ERROR]", end=end, flush=flush, force=force, overline=overline, underline=underline - , logger=logger.error if logger else None, stderr=True) + println(msg, console_prefix="[ERROR]", end=end, flush=flush, force=force, overline=overline, underline=underline, logger=logger.error if logger else None, stderr=True) def println(msg, console_prefix=None, end="\n", flush=False, force=False, logger=None, overline=None, underline=None, stderr=False): allow_print = force or (not QUIET and (BENCHMARK_RUNNING_IN_DOCKER or ASSUME_TTY or sys.stdout.isatty())) - file=sys.stderr if stderr else sys.stdout + file = sys.stderr if stderr else sys.stdout if allow_print: complete_msg = "%s %s" % (console_prefix, msg) if console_prefix else msg if overline: @@ -171,7 +168,8 @@ def println(msg, console_prefix=None, end="\n", flush=False, force=False, logger def progress(width=90): - return CmdLineProgressResultsPublisher(width, plain_output=PLAIN) # check-deprecated-terms-disable-1x + return CmdLineProgressResultsPublisher(width, plain_output=PLAIN) # check-deprecated-terms-disable-1x + # check-deprecated-terms-disable-1x class CmdLineProgressResultsPublisher: @@ -218,7 +216,7 @@ def _truncate(self, text, max_length, omission="..."): if len(text) <= max_length: return text else: - return "%s%s" % (text[0:max_length - len(omission) - 5], omission) + return "%s%s" % (text[0 : max_length - len(omission) - 5], omission) def finish(self): if QUIET or (not BENCHMARK_RUNNING_IN_DOCKER and not ASSUME_TTY and not sys.stdout.isatty()): diff --git a/solrorbit/utils/convert.py b/solrorbit/utils/convert.py index 212a3a05..67568573 100644 --- a/solrorbit/utils/convert.py +++ b/solrorbit/utils/convert.py @@ -16,7 +16,7 @@ # not use this file except in compliance with the License. # You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an @@ -25,6 +25,7 @@ # specific language governing permissions and limitations # under the License. + def bytes_to_kb(b): return b / 1024.0 if b else b diff --git a/solrorbit/utils/dataset.py b/solrorbit/utils/dataset.py index 2685f745..1e81c460 100644 --- a/solrorbit/utils/dataset.py +++ b/solrorbit/utils/dataset.py @@ -21,6 +21,7 @@ class Context(Enum): """DataSet context enum. Can be used to add additional context for how a data-set should be interpreted. """ + INDEX = 1 QUERY = 2 NEIGHBORS = 3 @@ -39,6 +40,7 @@ class DataSet(ABC): size: Gets the number of items in the data-set reset: Resets internal state of data-set to beginning """ + __metaclass__ = ABCMeta BEGINNING = 0 @@ -86,7 +88,7 @@ def get_data_set(data_set_format: str, path: str, context: Context): class HDF5DataSet(DataSet): - """ Data-set format corresponding to `ANN Benchmarks + """Data-set format corresponding to `ANN Benchmarks `_ """ @@ -113,7 +115,7 @@ def read(self, chunk_size: int): if end_offset > self.size(): end_offset = self.size() - vectors = cast(np.ndarray, self.data[self.current:end_offset]) + vectors = cast(np.ndarray, self.data[self.current : end_offset]) self.current = end_offset return vectors @@ -147,7 +149,7 @@ def parse_context(context: Context) -> str: return "test" if context == Context.PARENTS: - return "parents" # used in nested benchmarks to get the parent document id associated with each vector. + return "parents" # used in nested benchmarks to get the parent document id associated with each vector. if context == Context.MAX_DISTANCE_NEIGHBORS: return "max_distance_neighbors" @@ -161,7 +163,6 @@ def parse_context(context: Context) -> str: class BigANNDataSet(DataSet): - DATA_SET_HEADER_LENGTH = 8 FORMAT_NAME = "bigann" @@ -175,12 +176,11 @@ def __init__(self, dataset_path: str): self.row_length = 0 def _init_internal_params(self): - self.file = open(self.dataset_path, 'rb') + self.file = open(self.dataset_path, "rb") self.file.seek(DataSet.BEGINNING, os.SEEK_END) self.num_bytes = self.file.tell() if self.num_bytes < BigANNDataSet.DATA_SET_HEADER_LENGTH: - raise Exception("Invalid file: file size cannot be less than {} bytes".format( - BigANNDataSet.DATA_SET_HEADER_LENGTH)) + raise Exception("Invalid file: file size cannot be less than {} bytes".format(BigANNDataSet.DATA_SET_HEADER_LENGTH)) self.file.seek(BigANNDataSet.BEGINNING) self.rows = int.from_bytes(self.file.read(4), "little") self.row_length = int.from_bytes(self.file.read(4), "little") @@ -202,9 +202,7 @@ def read(self, chunk_size: int): if end_offset > self.size(): end_offset = self.size() - vectors = np.asarray( - [self._read_vector() for _ in range(end_offset - self.current)] - ) + vectors = np.asarray([self._read_vector() for _ in range(end_offset - self.current)]) self.current = end_offset return vectors @@ -247,12 +245,10 @@ def _get_supported_extension(self): """Return list of supported extension by this dataset""" def _get_extension(self): - ext = self.dataset_path.split('.')[-1] + ext = self.dataset_path.split(".")[-1] supported_extension = self._get_supported_extension() if ext not in supported_extension: - raise InvalidExtensionException( - "Unknown extension :{}, supported extensions are: {}".format( - ext, str(supported_extension))) + raise InvalidExtensionException("Unknown extension :{}, supported extensions are: {}".format(ext, str(supported_extension))) return ext @abstractmethod @@ -274,25 +270,21 @@ def _value_reader(self): class BigANNVectorDataSet(BigANNDataSet): - """ Data-set format for vector data-sets for `Big ANN Benchmarks + """Data-set format for vector data-sets for `Big ANN Benchmarks ` """ U8BIN_EXTENSION = "u8bin" FBIN_EXTENSION = "fbin" - SUPPORTED_EXTENSION = [ - FBIN_EXTENSION, U8BIN_EXTENSION - ] + SUPPORTED_EXTENSION = [FBIN_EXTENSION, U8BIN_EXTENSION] BYTES_PER_U8INT = 1 BYTES_PER_FLOAT = 4 def _init_internal_params(self): super()._init_internal_params() - if (self.num_bytes - BigANNDataSet.DATA_SET_HEADER_LENGTH) != ( - self.rows * self.row_length * self.bytes_per_num): - raise Exception("Invalid file. File size is not matching with expected estimated " - "value based on number of points, dimension and bytes per point") + if (self.num_bytes - BigANNDataSet.DATA_SET_HEADER_LENGTH) != (self.rows * self.row_length * self.bytes_per_num): + raise Exception("Invalid file. File size is not matching with expected estimated value based on number of points, dimension and bytes per point") def _get_supported_extension(self): return BigANNVectorDataSet.SUPPORTED_EXTENSION @@ -311,13 +303,13 @@ def _get_value_reader(self, extension): return lambda file: float(int.from_bytes(file.read(BigANNVectorDataSet.BYTES_PER_U8INT), "little")) if extension == BigANNVectorDataSet.FBIN_EXTENSION: - return lambda file: struct.unpack('`""" BIN_EXTENSION = "bin" @@ -331,10 +323,8 @@ def _init_internal_params(self): # num_queries(uint32_t) K-NN(uint32) followed by num_queries X K x sizeof(uint32_t) bytes of data # representing the IDs of the K-nearest neighbors of the queries, followed by num_queries X K x sizeof(float) # bytes of data representing the distances to the corresponding points. - if (self.num_bytes - BigANNDataSet.DATA_SET_HEADER_LENGTH) != 2 * ( - self.rows * self.row_length * self.bytes_per_num): - raise Exception("Invalid file. File size is not matching with expected estimated " - "value based on number of queries, k and bytes per query") + if (self.num_bytes - BigANNDataSet.DATA_SET_HEADER_LENGTH) != 2 * (self.rows * self.row_length * self.bytes_per_num): + raise Exception("Invalid file. File size is not matching with expected estimated value based on number of queries, k and bytes per query") def _get_supported_extension(self): return BigANNGroundTruthDataSet.SUPPORTED_EXTENSION @@ -343,14 +333,13 @@ def get_data_size(self, extension): return BigANNGroundTruthDataSet.BYTES_PER_UNSIGNED_INT32 def _get_value_reader(self, extension): - return lambda file: int.from_bytes( - file.read(BigANNGroundTruthDataSet.BYTES_PER_UNSIGNED_INT32), "little") + return lambda file: int.from_bytes(file.read(BigANNGroundTruthDataSet.BYTES_PER_UNSIGNED_INT32), "little") def create_big_ann_dataset(file_path: str): if not file_path: raise Exception("Invalid file path") - extension = file_path.split('.')[-1] + extension = file_path.split(".")[-1] if extension in BigANNGroundTruthDataSet.SUPPORTED_EXTENSION: return BigANNGroundTruthDataSet(file_path) if extension in BigANNVectorDataSet.SUPPORTED_EXTENSION: diff --git a/solrorbit/utils/git.py b/solrorbit/utils/git.py index 74082829..a54b81cb 100644 --- a/solrorbit/utils/git.py +++ b/solrorbit/utils/git.py @@ -16,7 +16,7 @@ # not use this file except in compliance with the License. # You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an @@ -32,7 +32,8 @@ from solrorbit.utils import io, process MIN_REQUIRED_MAJOR_VERSION = 2 -VERSION_REGEX = r'.* ([0-9]+)\.([0-9]+)\..*' +VERSION_REGEX = r".* ([0-9]+)\.([0-9]+)\..*" + def probed(f): def probe(src, *args, **kwargs): @@ -44,9 +45,9 @@ def probe(src, *args, **kwargs): raise exceptions.SystemSetupError("Error invoking 'git', please install (or re-install).") match = re.search(VERSION_REGEX, out) if not match or int(match.group(1)) < MIN_REQUIRED_MAJOR_VERSION: - raise exceptions.SystemSetupError("solr-orbit requires at least version 2 of git. " - f"You have {out}. Please update git.") + raise exceptions.SystemSetupError(f"solr-orbit requires at least version 2 of git. You have {out}. Please update git.") return f(src, *args, **kwargs) + return probe @@ -96,8 +97,7 @@ def pull(src_dir, remote="origin", branch="main"): def pull_ts(src_dir, ts): fetch(src_dir) clean_src = io.escape_path(src_dir) - revision = process.run_subprocess_with_output( - "git -C {0} rev-list -n 1 --before=\"{1}\" --date=iso8601 origin/main".format(clean_src, ts))[0].strip() + revision = process.run_subprocess_with_output('git -C {0} rev-list -n 1 --before="{1}" --date=iso8601 origin/main'.format(clean_src, ts))[0].strip() if process.run_subprocess_with_logging("git -C {0} checkout {1}".format(clean_src, revision)): raise exceptions.SupplyError("Could not checkout source tree for timestamped revision [%s]" % ts) @@ -111,14 +111,12 @@ def pull_revision(src_dir, revision): @probed def head_revision(src_dir): - return process.run_subprocess_with_output("git -C {0} rev-parse --short HEAD".format( - io.escape_path(src_dir)))[0].strip() + return process.run_subprocess_with_output("git -C {0} rev-parse --short HEAD".format(io.escape_path(src_dir)))[0].strip() @probed def current_branch(src_dir): - return process.run_subprocess_with_output("git -C {0} rev-parse --abbrev-ref HEAD".format( - io.escape_path(src_dir)))[0].strip() + return process.run_subprocess_with_output("git -C {0} rev-parse --abbrev-ref HEAD".format(io.escape_path(src_dir)))[0].strip() @probed @@ -126,12 +124,9 @@ def branches(src_dir, remote=True): clean_src = io.escape_path(src_dir) if remote: # Because compatability issues with Git 2.40.0+, updated --format='%(refname:short)' to --format='%(refname)' - return _cleanup_remote_branch_names(process.run_subprocess_with_output( - "git -C {src} for-each-ref refs/remotes/ --format='%(refname)'".format(src=clean_src))) + return _cleanup_remote_branch_names(process.run_subprocess_with_output("git -C {src} for-each-ref refs/remotes/ --format='%(refname)'".format(src=clean_src))) else: - return _cleanup_local_branch_names( - process.run_subprocess_with_output( - "git -C {src} for-each-ref refs/heads/ --format='%(refname:short)'".format(src=clean_src))) + return _cleanup_local_branch_names(process.run_subprocess_with_output("git -C {src} for-each-ref refs/heads/ --format='%(refname:short)'".format(src=clean_src))) @probed @@ -140,7 +135,7 @@ def tags(src_dir): def _cleanup_remote_branch_names(branch_names): - return [(b[b.rindex("/") + 1:]).strip() for b in branch_names if not b.endswith("/HEAD")] + return [(b[b.rindex("/") + 1 :]).strip() for b in branch_names if not b.endswith("/HEAD")] def _cleanup_local_branch_names(branch_names): diff --git a/solrorbit/utils/io.py b/solrorbit/utils/io.py index b7b06d00..669470f3 100644 --- a/solrorbit/utils/io.py +++ b/solrorbit/utils/io.py @@ -16,7 +16,7 @@ # not use this file except in compliance with the License. # You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an @@ -46,6 +46,7 @@ class FileSource: """ FileSource is a wrapper around a plain file which simplifies testing of file I/O calls. """ + def __init__(self, file_name, mode, encoding="utf-8"): self.file_name = file_name self.mode = mode @@ -96,6 +97,7 @@ class MmapSource: """ MmapSource is a wrapper around a memory-mapped file which simplifies testing of file I/O calls. """ + def __init__(self, file_name, mode, encoding="utf-8"): self.file_name = file_name self.mode = mode @@ -156,6 +158,7 @@ class DictStringFileSourceFactory: It is intended for scenarios where multiple files may be read by client code. """ + def __init__(self, name_to_contents): self.name_to_contents = name_to_contents @@ -168,6 +171,7 @@ class StringAsFileSource: Implementation of ``FileSource`` intended for tests. It's kept close to ``FileSource`` to simplify maintenance but it is not meant to be used in production code. """ + def __init__(self, contents, mode, encoding="utf-8"): """ :param contents: The file contents as an array of strings. Each item in the array should correspond to one line. @@ -239,6 +243,7 @@ def ensure_dir(directory, mode=0o777): if directory: os.makedirs(directory, mode, exist_ok=True) + def ensure_symlink(source, link_name): """ Ensure that a symlink exists from link_name to source. @@ -266,12 +271,11 @@ def ensure_symlink(source, link_name): os.symlink(source, link_name) logger.info("Created symlink: %s -> %s", link_name, source) + def _zipdir(source_directory, archive): for root, _, files in os.walk(source_directory): for file in files: - archive.write( - filename=os.path.join(root, file), - arcname=os.path.relpath(os.path.join(root, file), os.path.join(source_directory, ".."))) + archive.write(filename=os.path.join(root, file), arcname=os.path.relpath(os.path.join(root, file), os.path.join(source_directory, ".."))) def is_archive(name): @@ -371,8 +375,7 @@ def _do_decompress_manually(target_directory, filename, decompressor_args, decom if _do_decompress_manually_external(target_directory, filename, base_path_without_extension, decompressor_args): return else: - logging.getLogger(__name__).warning("%s not found in PATH. Using standard library, decompression will take longer.", - decompressor_bin) + logging.getLogger(__name__).warning("%s not found in PATH. Using standard library, decompression will take longer.", decompressor_bin) _do_decompress_manually_with_lib(target_directory, filename, decompressor_lib(filename)) @@ -382,8 +385,7 @@ def _do_decompress_manually_external(target_directory, filename, base_path_witho try: subprocess.run(decompressor_args + [filename], stdout=new_file, stderr=subprocess.PIPE, check=True) except subprocess.CalledProcessError as err: - logging.getLogger(__name__).warning("Failed to decompress [%s] with [%s]. Error [%s]. Falling back to standard library.", - filename, err.cmd, err.stderr) + logging.getLogger(__name__).warning("Failed to decompress [%s] with [%s]. Error [%s]. Falling back to standard library.", filename, err.cmd, err.stderr) return False return True @@ -403,7 +405,7 @@ def _do_decompress_manually_with_lib(target_directory, filename, compressed_file def _do_decompress_zstd(target_directory, filename): path_without_extension = basename(splitext(filename)[0]) try: - with open(filename, 'rb') as compressed_file: + with open(filename, "rb") as compressed_file: zstd_decompressor = zstd.ZstdDecompressor() with open(os.path.join(target_directory, path_without_extension), "wb") as new_file: for chunk in zstd_decompressor.read_to_iter(compressed_file): @@ -490,6 +492,7 @@ class FileOffsetTable: The FileOffsetTable represents a persistent mapping from lines in a data file to their offset in bytes in the data file. This helps bulk-indexing clients to advance quickly to a certain position in a large data file. """ + def __init__(self, data_file_path, offset_table_path, mode): """ Creates a new FileOffsetTable instance. The constructor should not be called directly but instead the @@ -600,7 +603,7 @@ def prepare_file_offset_table(data_file_path, base_url, source_url, downloader): if not file_offset_table.is_valid(): if not source_url: try: - downloader.download(base_url, None, data_file_path + '.offset', None) + downloader.download(base_url, None, data_file_path + ".offset", None) except exceptions.DataError as e: if isinstance(e.cause, urllib.error.HTTPError) and (e.cause.code == 403 or e.cause.code == 404): logging.getLogger(__name__).info("Pre-generated offset file not found, will generate from corpus data") diff --git a/solrorbit/utils/jvm.py b/solrorbit/utils/jvm.py index 3754c059..4394b867 100644 --- a/solrorbit/utils/jvm.py +++ b/solrorbit/utils/jvm.py @@ -16,7 +16,7 @@ # not use this file except in compliance with the License. # You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an @@ -44,8 +44,7 @@ def supports_option(java_home, option): :param option: The JVM option or combination of JVM options (separated by spaces) to check. :return: True iff the provided ``option`` is supported on this JVM. """ - return process.exit_status_as_bool( - lambda: process.run_subprocess_with_logging("{} {} -version".format(_java(java_home), option))) + return process.exit_status_as_bool(lambda: process.run_subprocess_with_logging("{} {} -version".format(_java(java_home), option))) def system_property(java_home, system_property_name): @@ -129,8 +128,7 @@ def resolve_path(majors, sysprop_reader=system_property): java_home = _resolve_single_path(major, mandatory=False, sysprop_reader=sysprop_reader) if java_home: return major, java_home - raise exceptions.SystemSetupError("Install a JDK with one of the versions {} and point to it with one of {}." - .format(majors, _checked_env_vars(majors))) + raise exceptions.SystemSetupError("Install a JDK with one of the versions {} and point to it with one of {}.".format(majors, _checked_env_vars(majors))) def _resolve_single_path(major, mandatory=True, sysprop_reader=system_property): @@ -143,6 +141,7 @@ def _resolve_single_path(major, mandatory=True, sysprop_reader=system_property): :param sysprop_reader: (Optional) only relevant for testing. :return: The resolved path to the JDK or ``None`` if ``mandatory`` is ``False`` and no appropriate JDK has been found. """ + def do_resolve(env_var, major): java_v_home = os.getenv(env_var) if java_v_home: @@ -167,8 +166,7 @@ def do_resolve(env_var, major): if java_home: return java_home elif mandatory: - raise exceptions.SystemSetupError("Neither {} nor {} point to a JDK {} installation.". - format(specific_env_var, generic_env_var, major)) + raise exceptions.SystemSetupError("Neither {} nor {} point to a JDK {} installation.".format(specific_env_var, generic_env_var, major)) else: return None diff --git a/solrorbit/utils/modules.py b/solrorbit/utils/modules.py index ac6f7098..cf19ff40 100644 --- a/solrorbit/utils/modules.py +++ b/solrorbit/utils/modules.py @@ -16,7 +16,7 @@ # not use this file except in compliance with the License. # You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an @@ -45,6 +45,7 @@ class ComponentLoader: install hooks. A component may also consist of multiple Python modules. """ + def __init__(self, root_path, component_entry_point, recurse=True): """ Creates a new component loader. @@ -64,7 +65,7 @@ def _modules(self, module_paths, component_name): for filename in os.listdir(path): name, ext = os.path.splitext(filename) if ext.endswith(".py"): - root_relative_path = os.path.join(path, name)[len(self.root_path) + len(os.path.sep):] + root_relative_path = os.path.join(path, name)[len(self.root_path) + len(os.path.sep) :] module_name = "%s.%s" % (component_name, root_relative_path.replace(os.path.sep, ".")) yield module_name diff --git a/solrorbit/utils/net.py b/solrorbit/utils/net.py index 08d08ecc..6686ae5d 100644 --- a/solrorbit/utils/net.py +++ b/solrorbit/utils/net.py @@ -16,7 +16,7 @@ # not use this file except in compliance with the License. # You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an @@ -46,16 +46,17 @@ def init(): proxy_url = os.getenv("http_proxy") if proxy_url and len(proxy_url) > 0: parsed_url = urllib3.util.parse_url(proxy_url) - logger.info("Connecting via proxy URL [%s] to the Internet (picked up from the env variable [http_proxy]).", - proxy_url) - __HTTP = urllib3.ProxyManager(proxy_url, - cert_reqs='CERT_REQUIRED', - ca_certs=certifi.where(), - # appropriate headers will only be set if there is auth info - proxy_headers=urllib3.make_headers(proxy_basic_auth=parsed_url.auth)) + logger.info("Connecting via proxy URL [%s] to the Internet (picked up from the env variable [http_proxy]).", proxy_url) + __HTTP = urllib3.ProxyManager( + proxy_url, + cert_reqs="CERT_REQUIRED", + ca_certs=certifi.where(), + # appropriate headers will only be set if there is auth info + proxy_headers=urllib3.make_headers(proxy_basic_auth=parsed_url.auth), + ) else: logger.info("Connecting directly to the Internet (no proxy support).") - __HTTP = urllib3.PoolManager(cert_reqs='CERT_REQUIRED', ca_certs=certifi.where()) + __HTTP = urllib3.PoolManager(cert_reqs="CERT_REQUIRED", ca_certs=certifi.where()) class Progress: @@ -96,7 +97,6 @@ def _download_from_s3_bucket(bucket_name, bucket_path, local_path, expected_size console.error("S3 support is optional. Install it with `python -m pip install solr-orbit[s3]`") raise - class S3ProgressAdapter: def __init__(self, size, progress): self._expected_size_in_bytes = size @@ -112,22 +112,17 @@ def __call__(self, bytes_amount): if expected_size_in_bytes is None: expected_size_in_bytes = bucket.Object(bucket_path).content_length progress_callback = S3ProgressAdapter(expected_size_in_bytes, progress_indicator) if progress_indicator else None - bucket.download_file(bucket_path, local_path, - Callback=progress_callback, - Config=boto3.s3.transfer.TransferConfig(use_threads=False)) + bucket.download_file(bucket_path, local_path, Callback=progress_callback, Config=boto3.s3.transfer.TransferConfig(use_threads=False)) def _build_gcs_object_url(bucket_name, bucket_path): # / and other special characters must be urlencoded in bucket and object names # ref: https://cloud.google.com/storage/docs/request-endpoints#encoding - return functools.reduce(urllib.parse.urljoin, [ - "https://storage.googleapis.com/storage/v1/b/", - f"{quote(bucket_name.strip('/'), safe='')}/", - "o/", - f"{quote(bucket_path.strip('/'), safe='')}", - "?alt=media" - ]) + return functools.reduce( + urllib.parse.urljoin, + ["https://storage.googleapis.com/storage/v1/b/", f"{quote(bucket_name.strip('/'), safe='')}/", "o/", f"{quote(bucket_path.strip('/'), safe='')}", "?alt=media"], + ) def _download_from_gcs_bucket(bucket_name, bucket_path, local_path, expected_size_in_bytes=None, progress_indicator=None): @@ -136,14 +131,16 @@ def _download_from_gcs_bucket(bucket_name, bucket_path, local_path, expected_siz import google.oauth2.credentials import google.auth.transport.requests as tr_requests import google.auth + # Using Google Resumable Media as the standard storage library doesn't support progress # (https://github.com/googleapis/python-storage/issues/27) from google.resumable_media.requests import ChunkedDownload + ro_scope = "https://www.googleapis.com/auth/devstorage.read_only" access_token = os.environ.get("GOOGLE_AUTH_TOKEN") if access_token: - credentials = google.oauth2.credentials.Credentials(token=access_token, scopes=(ro_scope, )) + credentials = google.oauth2.credentials.Credentials(token=access_token, scopes=(ro_scope,)) else: # https://google-auth.readthedocs.io/en/latest/user-guide.html credentials, _ = google.auth.default(scopes=(ro_scope,)) @@ -172,7 +169,7 @@ def download_from_bucket(blobstore, url, local_path, expected_size_in_bytes=None bucket_end_index = bucket_and_path.find("/") bucket = bucket_and_path[:bucket_end_index] # we need to remove the leading "/" - bucket_path = bucket_and_path[bucket_end_index + 1:] + bucket_path = bucket_and_path[bucket_end_index + 1 :] logger.info("Downloading from [%s] bucket [%s] and path [%s] to [%s].", blobstore, bucket, bucket_path, local_path) blob_downloader[blobstore](bucket, bucket_path, local_path, expected_size_in_bytes, progress_indicator) @@ -181,8 +178,7 @@ def download_from_bucket(blobstore, url, local_path, expected_size_in_bytes=None def download_http(url, local_path, expected_size_in_bytes=None, progress_indicator=None): - with __http().request("GET", url, preload_content=False, retries=10, - timeout=urllib3.Timeout(connect=45, read=240)) as r, open(local_path, "wb") as out_file: + with __http().request("GET", url, preload_content=False, retries=10, timeout=urllib3.Timeout(connect=45, read=240)) as r, open(local_path, "wb") as out_file: if r.status > 299: raise urllib.error.HTTPError(url, r.status, "", None, None) # noinspection PyBroadException @@ -193,7 +189,7 @@ def download_http(url, local_path, expected_size_in_bytes=None, progress_indicat except BaseException: size_from_content_header = None - chunk_size = 2 ** 16 + chunk_size = 2**16 bytes_read = 0 for chunk in r.stream(chunk_size): @@ -208,8 +204,7 @@ def _add_url_param(url, params): url_parsed = urlparse(url) query = parse_qs(url_parsed.query) query.update(params) - return urlunparse((url_parsed.scheme, url_parsed.netloc, url_parsed.path, url_parsed.params, - urlencode(query, doseq=True), url_parsed.fragment)) + return urlunparse((url_parsed.scheme, url_parsed.netloc, url_parsed.path, url_parsed.params, urlencode(query, doseq=True), url_parsed.fragment)) def download(url, local_path, expected_size_in_bytes=None, progress_indicator=None): @@ -239,8 +234,9 @@ def download(url, local_path, expected_size_in_bytes=None, progress_indicator=No if expected_size_in_bytes is not None and download_size != expected_size_in_bytes: if os.path.isfile(tmp_data_set_path): os.remove(tmp_data_set_path) - raise exceptions.DataError("Download of [%s] is corrupt. Downloaded [%d] bytes but [%d] bytes are expected. Please retry." % - (local_path, download_size, expected_size_in_bytes)) + raise exceptions.DataError( + "Download of [%s] is corrupt. Downloaded [%d] bytes but [%d] bytes are expected. Please retry." % (local_path, download_size, expected_size_in_bytes) + ) os.rename(tmp_data_set_path, local_path) diff --git a/solrorbit/utils/opts.py b/solrorbit/utils/opts.py index e45fe637..31dda566 100644 --- a/solrorbit/utils/opts.py +++ b/solrorbit/utils/opts.py @@ -16,7 +16,7 @@ # not use this file except in compliance with the License. # You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an @@ -42,6 +42,7 @@ def csv_to_list(csv): else: return [e.strip() for e in csv.split(",")] + def to_bool(v): if v is None: return None @@ -88,7 +89,7 @@ def convert(v): def to_dict(arg, default_parser=kv_to_map): - if io.has_extension(arg, ".json") and ',' not in arg and ':' not in arg: + if io.has_extension(arg, ".json") and "," not in arg and ":" not in arg: with open(io.normalize_path(arg), mode="rt", encoding="utf-8") as f: return json.load(f) elif arg.startswith("{"): @@ -102,7 +103,7 @@ def bulleted_list_of(src_list): def double_quoted_list_of(src_list): - return ["\"{}\"".format(param) for param in src_list] + return ['"{}"'.format(param) for param in src_list] def make_list_of_close_matches(word_list, all_possibilities): @@ -122,11 +123,13 @@ def make_list_of_close_matches(word_list, all_possibilities): return close_matches + class StoreKeyPairAsDict(argparse.Action): """ Custom Argparse action that allows users to pass in a key:value pairs after specifying a parameter. Used as action for --number-of-docs parameter for create-workload subcommand. """ + def __call__(self, parser, namespace, values, option_string=None): custom_dict = {} @@ -138,12 +141,10 @@ def __call__(self, parser, namespace, values, option_string=None): for kv in kv_pairs: try: - k,v = kv.split(":") + k, v = kv.split(":") custom_dict[k] = v except ValueError: - raise exceptions.InvalidSyntax( - "StoreKeyPairAsDict: Could not convert string to dict due to invalid syntax." - ) + raise exceptions.InvalidSyntax("StoreKeyPairAsDict: Could not convert string to dict due to invalid syntax.") setattr(namespace, self.dest, custom_dict) return custom_dict @@ -266,8 +267,7 @@ def normalize_to_dict(arg): if self.argvalue == ClientOptions.DEFAULT_CLIENT_OPTIONS and self.target_hosts is not None: # --client-options unset but multi-clusters used in --target-hosts? apply options defaults for all cluster names. - self.parsed_options = {cluster_name: kv_to_map([ClientOptions.DEFAULT_CLIENT_OPTIONS]) - for cluster_name in self.target_hosts.all_hosts.keys()} + self.parsed_options = {cluster_name: kv_to_map([ClientOptions.DEFAULT_CLIENT_OPTIONS]) for cluster_name in self.target_hosts.all_hosts.keys()} else: self.parsed_options = to_dict(self.argvalue, default_parser=normalize_to_dict) diff --git a/solrorbit/utils/parse.py b/solrorbit/utils/parse.py index 16c610cd..684e364c 100644 --- a/solrorbit/utils/parse.py +++ b/solrorbit/utils/parse.py @@ -10,9 +10,7 @@ def parse_string_parameter(key: str, params: dict, default: str = None) -> str: if key not in params or not params[key]: if default is not None: return default - raise ConfigurationError( - "Value cannot be None for param {}".format(key) - ) + raise ConfigurationError("Value cannot be None for param {}".format(key)) if isinstance(params[key], str): return params[key] @@ -24,9 +22,7 @@ def parse_int_parameter(key: str, params: dict, default: int = None) -> int: if key not in params: if default is not None: return default - raise ConfigurationError( - "Value cannot be None for param {}".format(key) - ) + raise ConfigurationError("Value cannot be None for param {}".format(key)) if isinstance(params[key], int): return params[key] @@ -38,9 +34,7 @@ def parse_float_parameter(key: str, params: dict, default: float = None) -> floa if key not in params: if default: return default - raise ConfigurationError( - "Value cannot be None for param {}".format(key) - ) + raise ConfigurationError("Value cannot be None for param {}".format(key)) if isinstance(params[key], float): return params[key] @@ -52,9 +46,7 @@ def parse_bool_parameter(key: str, params: dict, default: bool = None) -> bool: if key not in params: if default is not None: return default - raise ConfigurationError( - "Value cannot be None for param {}".format(key) - ) + raise ConfigurationError("Value cannot be None for param {}".format(key)) if isinstance(params[key], bool): return params[key] diff --git a/solrorbit/utils/process.py b/solrorbit/utils/process.py index 52869c46..7796ef59 100644 --- a/solrorbit/utils/process.py +++ b/solrorbit/utils/process.py @@ -16,7 +16,7 @@ # not use this file except in compliance with the License. # You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an @@ -63,7 +63,7 @@ def run_subprocess_with_out_and_err(command_line): sp = subprocess.Popen(command_line_args, stdout=subprocess.PIPE, stderr=subprocess.PIPE, stdin=subprocess.DEVNULL) sp.wait() out, err = sp.communicate() - return out.decode('UTF-8'), err.decode('UTF-8'), sp.returncode + return out.decode("UTF-8"), err.decode("UTF-8"), sp.returncode def run_subprocess_with_stderr(command_line): @@ -74,7 +74,7 @@ def run_subprocess_with_stderr(command_line): sp = subprocess.Popen(command_line_args, stdout=subprocess.DEVNULL, stderr=subprocess.PIPE, stdin=subprocess.DEVNULL) sp.wait() _, err = sp.communicate() - return err.decode('UTF-8'), sp.returncode + return err.decode("UTF-8"), sp.returncode def exit_status_as_bool(runnable, quiet=False): @@ -93,8 +93,9 @@ def exit_status_as_bool(runnable, quiet=False): return False -def run_subprocess_with_logging(command_line, header=None, level=logging.INFO, stdin=None, stdout=subprocess.PIPE, - stderr=subprocess.STDOUT, env=None, detach=False, capture_output=False): +def run_subprocess_with_logging( + command_line, header=None, level=logging.INFO, stdin=None, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, env=None, detach=False, capture_output=False +): """ Runs the provided command line in a subprocess. All output will be captured by a logger. @@ -119,13 +120,9 @@ def run_subprocess_with_logging(command_line, header=None, level=logging.INFO, s logger.info(header) # pylint: disable=subprocess-popen-preexec-fn - with subprocess.Popen(command_line_args, - stdout=stdout, - stderr=stderr, - universal_newlines=True, - env=env, - stdin=stdin if stdin else None, - preexec_fn=pre_exec) as command_line_process: + with subprocess.Popen( + command_line_args, stdout=stdout, stderr=stderr, universal_newlines=True, env=env, stdin=stdin if stdin else None, preexec_fn=pre_exec + ) as command_line_process: stdout, _ = command_line_process.communicate() if stdout: logger.log(level=level, msg=stdout) @@ -136,10 +133,7 @@ def run_subprocess_with_logging(command_line, header=None, level=logging.INFO, s def is_benchmark_process(p): cmdline = p.cmdline() - return p.name() == "solr-orbit" or \ - (len(cmdline) > 1 and - os.path.basename(cmdline[0].lower()).startswith("python") and - os.path.basename(cmdline[1]) == "solr-orbit") + return p.name() == "solr-orbit" or (len(cmdline) > 1 and os.path.basename(cmdline[0].lower()).startswith("python") and os.path.basename(cmdline[1]) == "solr-orbit") def find_all_other_benchmark_processes(): @@ -160,6 +154,7 @@ def kill(p): time.sleep(1) except psutil.NoSuchProcess: break + for_all_other_processes(predicate, kill) diff --git a/solrorbit/utils/repo.py b/solrorbit/utils/repo.py index aadbbbb8..13743aba 100644 --- a/solrorbit/utils/repo.py +++ b/solrorbit/utils/repo.py @@ -16,7 +16,7 @@ # not use this file except in compliance with the License. # You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an @@ -62,8 +62,7 @@ def __init__(self, default_directory, root_dir, repo_name, resource_name, offlin else: if not git.is_working_copy(self.repo_dir) and repo_name != "default": if io.exists(self.repo_dir): - raise exceptions.SystemSetupError("[{src}] must be a git repository.\n\nPlease run:\ngit -C {src} init" - .format(src=self.repo_dir)) + raise exceptions.SystemSetupError("[{src}] must be a git repository.\n\nPlease run:\ngit -C {src} init".format(src=self.repo_dir)) def update(self, distribution_version): try: @@ -71,8 +70,7 @@ def update(self, distribution_version): branch = versions.best_match(git.branches(self.repo_dir, remote=self.remote), distribution_version) if branch: # Allow uncommitted changes iff we do not have to change the branch - self.logger.info( - "Checking out [%s] in [%s] for distribution version [%s].", branch, self.repo_dir, distribution_version) + self.logger.info("Checking out [%s] in [%s] for distribution version [%s].", branch, self.repo_dir, distribution_version) git.checkout(self.repo_dir, branch=branch) self.logger.info("Rebasing on [%s] in [%s] for distribution version [%s].", branch, self.repo_dir, distribution_version) try: @@ -80,33 +78,26 @@ def update(self, distribution_version): self.revision = git.head_revision(self.repo_dir) except exceptions.SupplyError: self.logger.exception("Cannot rebase due to local changes in [%s]", self.repo_dir) - console.warn( - "Local changes in [%s] prevent %s update from remote. Please commit your changes." % - (self.repo_dir, self.resource_name)) + console.warn("Local changes in [%s] prevent %s update from remote. Please commit your changes." % (self.repo_dir, self.resource_name)) return else: - msg = "Could not find %s remotely for distribution version [%s]. Trying to find %s locally." % \ - (self.resource_name, distribution_version, self.resource_name) + msg = "Could not find %s remotely for distribution version [%s]. Trying to find %s locally." % (self.resource_name, distribution_version, self.resource_name) self.logger.warning(msg) branch = versions.best_match(git.branches(self.repo_dir, remote=False), distribution_version) if branch: if git.current_branch(self.repo_dir) != branch: - self.logger.info("Checking out [%s] in [%s] for distribution version [%s].", - branch, self.repo_dir, distribution_version) + self.logger.info("Checking out [%s] in [%s] for distribution version [%s].", branch, self.repo_dir, distribution_version) git.checkout(self.repo_dir, branch=branch) self.revision = git.head_revision(self.repo_dir) else: - self.logger.info("No local branch found for distribution version [%s] in [%s]. Checking tags.", - distribution_version, self.repo_dir) + self.logger.info("No local branch found for distribution version [%s] in [%s]. Checking tags.", distribution_version, self.repo_dir) tag = self._find_matching_tag(distribution_version) if tag: - self.logger.info("Checking out tag [%s] in [%s] for distribution version [%s].", - tag, self.repo_dir, distribution_version) + self.logger.info("Checking out tag [%s] in [%s] for distribution version [%s].", tag, self.repo_dir, distribution_version) git.checkout(self.repo_dir, branch=tag) self.revision = git.head_revision(self.repo_dir) else: - raise exceptions.SystemSetupError("Cannot find %s for distribution version %s" - % (self.resource_name, distribution_version)) + raise exceptions.SystemSetupError("Cannot find %s for distribution version %s" % (self.resource_name, distribution_version)) except exceptions.SupplyError as e: tb = sys.exc_info()[2] raise exceptions.DataError("Cannot update %s in [%s] (%s)." % (self.resource_name, self.repo_dir, e.message)).with_traceback(tb) @@ -148,9 +139,9 @@ def _use_default_cluster_configs_dir(self, distribution_version, cfg): def _select_branch_version(self, distribution_version, pc_path): # Branches have been moved into resources/cluster_configs branches = [b for b in os.listdir(pc_path) if os.path.isdir(os.path.join(pc_path, b)) and b != "main"] - branches.sort(key=lambda b: list(map(int, b.split('.'))), reverse=True) + branches.sort(key=lambda b: list(map(int, b.split("."))), reverse=True) self.logger.info("branches: %s", branches) - convert = lambda s: list(map(int, s.split('.'))) + convert = lambda s: list(map(int, s.split("."))) if distribution_version is not None: # Return a branch that is less than or equal to the distribution version for branch in branches: diff --git a/solrorbit/utils/s3_data_producer.py b/solrorbit/utils/s3_data_producer.py index 0200e679..897208ab 100644 --- a/solrorbit/utils/s3_data_producer.py +++ b/solrorbit/utils/s3_data_producer.py @@ -15,19 +15,25 @@ from boto3 import client from solrorbit import exceptions + try: from solrorbit.data_streaming.data_producer import DataProducer except ImportError: + class DataProducer: # pylint: disable=too-few-public-methods """Fallback when data_streaming package is not available.""" + + from solrorbit.workload.ingestion_manager import IngestionManager + class S3DataProducer(DataProducer): """ Generate data by downloading an object from S3. Will support downloading from multiple objects in the future. """ - def __init__(self, bucket:str, keys, client_options: dict, data_dir=None) -> None: + + def __init__(self, bucket: str, keys, client_options: dict, data_dir=None) -> None: """ Constructor. :param bucket: The S3 bucket to download from. @@ -45,29 +51,29 @@ def __init__(self, bucket:str, keys, client_options: dict, data_dir=None) -> Non self.chunk_size = IngestionManager.chunk_size * 1024**2 self.num_workers = os.cpu_count() * 2 - self.s3_client = client('s3') + self.s3_client = client("s3") except Exception as e: print(f"Error: {e}") def _get_next_key(self): - if len(self.keys) > 2 and self.keys.endswith('**'): + if len(self.keys) > 2 and self.keys.endswith("**"): processed_keys = set() while True: response = self.s3_client.list_objects(Bucket=self.bucket, Prefix=self.keys[:-2]) - for object in response['Contents']: - key = object['Key'] + for object in response["Contents"]: + key = object["Key"] if key not in processed_keys: processed_keys.add(key) - size = self.s3_client.head_object(Bucket=self.bucket, Key=key)['ContentLength'] + size = self.s3_client.head_object(Bucket=self.bucket, Key=key)["ContentLength"] if size == 0: return yield key self.logger.info("Waiting for next (or empty) S3 object to appear in target bucket") time.sleep(60) - elif len(self.keys) > 1 and self.keys[-1] == '*': + elif len(self.keys) > 1 and self.keys[-1] == "*": response = self.s3_client.list_objects(Bucket=self.bucket, Prefix=self.keys[:-1]) - for object in response['Contents']: - yield object['Key'] + for object in response["Contents"]: + yield object["Key"] else: yield self.keys @@ -77,7 +83,7 @@ def _get_next_downloader(self): # Obtain the object size. self.logger.info("Processing object %s", k) response = self.s3_client.head_object(Bucket=self.bucket, Key=k) - size = response['ContentLength'] + size = response["ContentLength"] yield self._s3_multipart_downloader(self.bucket, k, 0, size) def _gen_range_args(self, beg, end, chunk_size): @@ -93,14 +99,14 @@ def _gen_range_args(self, beg, end, chunk_size): r_end = end - 1 else: r_end = r_beg + chunk_size - 1 - ranges.append(f'bytes={r_beg}-{r_end}') + ranges.append(f"bytes={r_beg}-{r_end}") return ranges def _s3_get_object_subrange(self, args): "Download a subrange of an S3 object." bucket, key, range = args resp = self.s3_client.get_object(Bucket=bucket, Key=key, Range=range) - return resp['Body'].read() + return resp["Body"].read() def _s3_multipart_downloader(self, bucket, key, beg, end): """ @@ -113,17 +119,15 @@ def _s3_multipart_downloader(self, bucket, key, beg, end): # Ensure futures are garbage collected before more are issued, to not run out of memory. with ThreadPoolExecutor(max_workers=self.num_workers) as executor: for i in range(0, len(ranges), self.num_workers): - subranges = ranges[i:i+self.num_workers] - futures = [executor.submit(self._s3_get_object_subrange, (bucket, key, range)) - for range in subranges] + subranges = ranges[i : i + self.num_workers] + futures = [executor.submit(self._s3_get_object_subrange, (bucket, key, range)) for range in subranges] wait(futures) for future in futures: yield future.result() def _output_chunk(self, rsl, chunk_id): "Write a chunk into its file. It will be processed later by one ingestion client." - with open(os.path.join(self.data_dir, self.chunk_prefix + "{:05d}".format(chunk_id)), - "w", encoding='utf-8') as fh: + with open(os.path.join(self.data_dir, self.chunk_prefix + "{:05d}".format(chunk_id)), "w", encoding="utf-8") as fh: fh.write(rsl) def generate_chunked_data(self): @@ -133,9 +137,9 @@ def generate_chunked_data(self): downloaders = self._get_next_downloader() for downloader in downloaders: for chunk in downloader: - rsl = chunk.decode('utf-8') + rsl = chunk.decode("utf-8") i = len(rsl) - while i and rsl[i-1] != '\n': + while i and rsl[i - 1] != "\n": i -= 1 if i == 0: raise exceptions.DataStreamingError(f"could not locate document end in chunk {chunk_id}") @@ -163,6 +167,7 @@ def main(bucket: str, keys: str) -> None: producer = S3DataProducer(bucket, keys, None) producer.generate_chunked_data() -if __name__ == '__main__': + +if __name__ == "__main__": # pylint: disable = no-value-for-parameter main(*sys.argv[1:]) diff --git a/solrorbit/utils/sysstats.py b/solrorbit/utils/sysstats.py index faa987eb..90a2dc2d 100644 --- a/solrorbit/utils/sysstats.py +++ b/solrorbit/utils/sysstats.py @@ -16,7 +16,7 @@ # not use this file except in compliance with the License. # You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an @@ -33,6 +33,7 @@ # noinspection PyBroadException try: import cpuinfo + cpuinfo_available = True except Exception: cpuinfo_available = False diff --git a/solrorbit/utils/versions.py b/solrorbit/utils/versions.py index dab62c75..087ba433 100644 --- a/solrorbit/utils/versions.py +++ b/solrorbit/utils/versions.py @@ -16,7 +16,7 @@ # not use this file except in compliance with the License. # You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an @@ -84,6 +84,7 @@ class Version: Represents a version with components major, minor, patch and suffix (suffix is optional). Suffixes are not considered for version comparisons as its contents are opaque and a semantically correct order cannot be defined. """ + def __init__(self, major, minor, patch, suffix=None): self.major = major self.minor = minor @@ -133,8 +134,7 @@ def __init__(self, version): self.with_major = f"{int(self.major)}" self.with_minor = f"{int(self.major)}.{int(self.minor)}" self.with_patch = f"{int(self.major)}.{int(self.minor)}.{int(self.patch)}" - self.with_suffix = f"{int(self.major)}.{int(self.minor)}.{int(self.patch)}-{self.suffix}" if self.suffix \ - else None + self.with_suffix = f"{int(self.major)}.{int(self.minor)}.{int(self.patch)}-{self.suffix}" if self.suffix else None @property def all_versions(self): @@ -147,9 +147,7 @@ def all_versions(self): """ versions = [(self.with_suffix, "with_suffix")] if self.suffix else [] - versions.extend([(self.with_patch, "with_patch"), - (self.with_minor, "with_minor"), - (self.with_major, "with_major")]) + versions.extend([(self.with_patch, "with_patch"), (self.with_minor, "with_minor"), (self.with_major, "with_major")]) return versions diff --git a/solrorbit/version.py b/solrorbit/version.py index a8cd9828..c7d03639 100644 --- a/solrorbit/version.py +++ b/solrorbit/version.py @@ -16,7 +16,7 @@ # not use this file except in compliance with the License. # You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an diff --git a/solrorbit/visualizations/benchmark_report_renderer.py b/solrorbit/visualizations/benchmark_report_renderer.py index 7f13fdce..9bd0551d 100644 --- a/solrorbit/visualizations/benchmark_report_renderer.py +++ b/solrorbit/visualizations/benchmark_report_renderer.py @@ -16,7 +16,7 @@ # not use this file except in compliance with the License. # You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an @@ -25,6 +25,7 @@ # specific language governing permissions and limitations # under the License. + def render_results_html(test_run, cfg) -> str: """ Build an HTML benchmark report for the given TestRun. @@ -49,7 +50,7 @@ def render_results_html(test_run, cfg) -> str: "distribution-version": test_run.distribution_version, "distribution-flavor": test_run.distribution_flavor, "provision-config-revision": test_run.provision_config_revision, - } + }, } if getattr(test_run, "results", None): # results might already be a dict or an object @@ -59,36 +60,36 @@ def render_results_html(test_run, cfg) -> str: doc["results"] = test_run.results.as_dict() # 2) Pull top-level fields - test_id = doc.get("test-run-id", "") - orbit_ver = doc.get("benchmark-version", "") - orbit_rev = doc.get("benchmark-revision", "") - environment = doc.get("environment", "") - pipeline = doc.get("pipeline", "") - workload = doc.get("workload", "") - test_procedure= doc.get("test-procedure", "") + test_id = doc.get("test-run-id", "") + orbit_ver = doc.get("benchmark-version", "") + orbit_rev = doc.get("benchmark-revision", "") + environment = doc.get("environment", "") + pipeline = doc.get("pipeline", "") + workload = doc.get("workload", "") + test_procedure = doc.get("test-procedure", "") # 3) Cluster info - cluster_info = doc.get("cluster", {}) - distro_ver = cluster_info.get("distribution-version", "") - distro_flav = cluster_info.get("distribution-flavor", "") - prov_conf_rev = cluster_info.get("provision-config-revision", None) + cluster_info = doc.get("cluster", {}) + distro_ver = cluster_info.get("distribution-version", "") + distro_flav = cluster_info.get("distribution-flavor", "") + prov_conf_rev = cluster_info.get("provision-config-revision", None) # 4) Config table dict config_dict = { - "Solr Orbit Version": orbit_ver, + "Solr Orbit Version": orbit_ver, "Solr Orbit Revision (git)": orbit_rev, - "Environment": environment, - "Pipeline": pipeline, - "Workload": workload, - "Test Procedure": test_procedure, - "Distribution Version": distro_ver, - "Distribution Flavor": distro_flav, + "Environment": environment, + "Pipeline": pipeline, + "Workload": workload, + "Test Procedure": test_procedure, + "Distribution Version": distro_ver, + "Distribution Flavor": distro_flav, "Provision Config Revision": prov_conf_rev, } # 5) Extract op_metrics results_dict = doc.get("results", {}) or {} - op_metrics = results_dict.get("op_metrics", []) + op_metrics = results_dict.get("op_metrics", []) # Build rows table_rows = [] @@ -96,15 +97,17 @@ def render_results_html(test_run, cfg) -> str: th = item.get("throughput", {}) st = item.get("service_time", {}) clients = item.get("search_clients") or item.get("clients", "") or "–" - table_rows.append({ - "task": item.get("task", ""), - "operation": item.get("operation", ""), - "throughput_mean": th.get("mean"), - "throughput_unit": th.get("unit", ""), - "service_time_mean": st.get("mean"), - "service_time_unit": st.get("unit", ""), - "search_clients": clients, - }) + table_rows.append( + { + "task": item.get("task", ""), + "operation": item.get("operation", ""), + "throughput_mean": th.get("mean"), + "throughput_unit": th.get("unit", ""), + "service_time_mean": st.get("mean"), + "service_time_unit": st.get("unit", ""), + "search_clients": clients, + } + ) # 6) Render helpers def render_config_table(cfg_d): @@ -115,32 +118,17 @@ def render_config_table(cfg_d): return f"{rows}
" def render_metrics_table(rows): - header = ( - "" - "TaskOperation" - "Throughput (mean)" - "Service Time (mean)" - "Search Clients" - "" - ) + header = "TaskOperationThroughput (mean)Service Time (mean)Search Clients" body = "" for r in rows: - th_val = f"{r['throughput_mean']} {r['throughput_unit']}" if r['throughput_mean'] is not None else "–" - st_val = f"{r['service_time_mean']} {r['service_time_unit']}" if r['service_time_mean'] is not None else "–" - body += ( - "" - f"{r['task']}" - f"{r['operation']}" - f"{th_val}" - f"{st_val}" - f"{r['search_clients']}" - "" - ) + th_val = f"{r['throughput_mean']} {r['throughput_unit']}" if r["throughput_mean"] is not None else "–" + st_val = f"{r['service_time_mean']} {r['service_time_unit']}" if r["service_time_mean"] is not None else "–" + body += f"{r['task']}{r['operation']}{th_val}{st_val}{r['search_clients']}" return f"{header}{body}
" # 7) Put it all together cfg_table_html = render_config_table(config_dict) - metrics_html = render_metrics_table(table_rows) + metrics_html = render_metrics_table(table_rows) return f""" diff --git a/solrorbit/worker_coordinator/__init__.py b/solrorbit/worker_coordinator/__init__.py index 8b013d6a..dd40a82f 100644 --- a/solrorbit/worker_coordinator/__init__.py +++ b/solrorbit/worker_coordinator/__init__.py @@ -16,7 +16,7 @@ # not use this file except in compliance with the License. # You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an @@ -26,11 +26,4 @@ # under the License. # expose only the minimum API -from .worker_coordinator import ( - WorkerCoordinatorActor, - PrepareBenchmark, - PreparationComplete, - StartBenchmark, - BenchmarkComplete, - TaskFinished -) +from .worker_coordinator import WorkerCoordinatorActor, PrepareBenchmark, PreparationComplete, StartBenchmark, BenchmarkComplete, TaskFinished diff --git a/solrorbit/worker_coordinator/errors.py b/solrorbit/worker_coordinator/errors.py index ff546969..e13209f6 100644 --- a/solrorbit/worker_coordinator/errors.py +++ b/solrorbit/worker_coordinator/errors.py @@ -16,7 +16,7 @@ # not use this file except in compliance with the License. # You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an @@ -27,23 +27,24 @@ import re + def parse_error(error_metadata): - error = error_metadata['error'] + error = error_metadata["error"] status_code = None description = "error occured, check logs for details" operation = UnknownOperationError(description, None) - if 'status' in error_metadata: + if "status" in error_metadata: status_code = error_metadata["status"] - if 'reason' in error: - description = error['reason'] - matches = re.findall(r'\[([^]]*)\]', description) + if "reason" in error: + description = error["reason"] + matches = re.findall(r"\[([^]]*)\]", description) for match in matches: if match == "indices:admin/create": operation = IndexOperationError(description, "index-create", status_code) elif match == "indices:admin/delete": - operation = IndexOperationError(description, "index-delete", status_code) + operation = IndexOperationError(description, "index-delete", status_code) elif match == "indices:data/write/bulk": operation = IndexOperationError(description, "index-append", status_code) elif match == "indices:admin/refresh": @@ -51,17 +52,18 @@ def parse_error(error_metadata): elif match == "indices:admin/forcemerge": operation = IndexOperationError(description, "force-merge", status_code) elif match == "indices:data/read/search": - operation = SearchOperationError(description, "search", status_code) + operation = SearchOperationError(description, "search", status_code) return operation -class BenchmarkOperationError(): +class BenchmarkOperationError: def __init__(self, description, operation=None, status_code=None): self.description = description self.operation = operation self.status_code = status_code + class UnknownOperationError(BenchmarkOperationError): def get_error_message(self): return self.description @@ -76,6 +78,7 @@ def get_error_message(self): else: return self.description + class SearchOperationError(BenchmarkOperationError): def get_error_message(self): if self.status_code == 403: diff --git a/solrorbit/worker_coordinator/runner.py b/solrorbit/worker_coordinator/runner.py index 6bf075e8..c4b0e2e7 100644 --- a/solrorbit/worker_coordinator/runner.py +++ b/solrorbit/worker_coordinator/runner.py @@ -16,7 +16,7 @@ # not use this file except in compliance with the License. # You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an @@ -70,6 +70,7 @@ def register_default_runners(): register_runner("paginated-search", _paginated_runner, async_runner=True) register_runner("scroll-search", _paginated_runner, async_runner=True) + def runner_for(operation_type): try: return __RUNNERS[operation_type] @@ -93,8 +94,7 @@ def register_runner(operation_type, runner, **kwargs): operation_type = operation_type.to_hyphenated_string() if not async_runner: - raise exceptions.BenchmarkAssertionError( - "Runner [{}] must be implemented as async runner and registered with async_runner=True.".format(str(runner))) + raise exceptions.BenchmarkAssertionError("Runner [{}] must be implemented as async runner and registered with async_runner=True.".format(str(runner))) if getattr(runner, "multi_cluster", False): if "__aenter__" in dir(runner) and "__aexit__" in dir(runner): @@ -121,6 +121,7 @@ def register_runner(operation_type, runner, **kwargs): __RUNNERS[operation_type] = _with_completion(_with_assertions(cluster_aware_runner)) + # Only intended for unit-testing! def remove_runner(operation_type): del __RUNNERS[operation_type] @@ -163,7 +164,7 @@ def _default_kw_params(self, params): "params": "request-params", "request_timeout": "request-timeout", } - full_result = {k: params.get(v) for (k, v) in kw_dict.items()} + full_result = {k: params.get(v) for (k, v) in kw_dict.items()} # filter Nones return dict(filter(lambda kv: kv[1] is not None, full_result.items())) @@ -178,8 +179,10 @@ def _transport_request_params(self, params): headers.update({"x-opaque-id": opaque_id}) return request_params, headers + request_context_holder = RequestContextHolder() + def time_func(func): async def advised(*args, **kwargs): request_context_holder.on_client_request_start() @@ -188,6 +191,7 @@ async def advised(*args, **kwargs): return response finally: request_context_holder.on_client_request_end() + return advised @@ -195,6 +199,7 @@ class Delegator: """ Mixin to unify delegate handling """ + def __init__(self, delegate, *args, **kwargs): super().__init__(*args, **kwargs) self.delegate = delegate @@ -371,8 +376,7 @@ async def __call__(self, *args): for assertion in params["assertions"]: self.check_assertion(op_name, assertion, return_value) else: - self.logger.debug("Skipping assertion check in [%s] as [%s] does not return a dict.", - op_name, repr(self.delegate)) + self.logger.debug("Skipping assertion check in [%s] as [%s] does not return a dict.", op_name, repr(self.delegate)) return return_value def __repr__(self, *args, **kwargs): @@ -390,10 +394,7 @@ def mandatory(params, key, op): try: return params[key] except KeyError: - raise exceptions.DataError( - f"Parameter source for operation '{str(op)}' did not provide the mandatory parameter '{key}'. " - f"Add it to your parameter source and try again.") - + raise exceptions.DataError(f"Parameter source for operation '{str(op)}' did not provide the mandatory parameter '{key}'. Add it to your parameter source and try again.") def escape(v): @@ -450,11 +451,11 @@ def parse(text: BytesIO, props: List[str], lists: List[str] = None) -> dict: return parsed - class Sleep(Runner): """ Sleeps for the specified duration not issuing any request. """ + @time_func async def __call__(self, client, params): sleep_duration = mandatory(params, "duration", "sleep") @@ -476,10 +477,10 @@ class DeleteBackupRepository(Runner): """ Deletes a snapshot repository """ + async def __call__(self, client, params): raise exceptions.BenchmarkError( - f"[{repr(self)}] is not yet implemented for Apache Solr. " - "Port to Solr Backup V2 API: https://solr.apache.org/guide/solr/latest/configuration-guide/backups.html" + f"[{repr(self)}] is not yet implemented for Apache Solr. Port to Solr Backup V2 API: https://solr.apache.org/guide/solr/latest/configuration-guide/backups.html" ) def __repr__(self, *args, **kwargs): @@ -494,10 +495,10 @@ class CreateBackupRepository(Runner): """ Creates a new snapshot repository """ + async def __call__(self, client, params): raise exceptions.BenchmarkError( - f"[{repr(self)}] is not yet implemented for Apache Solr. " - "Port to Solr Backup V2 API: https://solr.apache.org/guide/solr/latest/configuration-guide/backups.html" + f"[{repr(self)}] is not yet implemented for Apache Solr. Port to Solr Backup V2 API: https://solr.apache.org/guide/solr/latest/configuration-guide/backups.html" ) def __repr__(self, *args, **kwargs): @@ -512,10 +513,10 @@ class CreateBackup(Runner): """ Creates a new snapshot repository """ + async def __call__(self, client, params): raise exceptions.BenchmarkError( - f"[{repr(self)}] is not yet implemented for Apache Solr. " - "Port to Solr Backup V2 API: https://solr.apache.org/guide/solr/latest/configuration-guide/backups.html" + f"[{repr(self)}] is not yet implemented for Apache Solr. Port to Solr Backup V2 API: https://solr.apache.org/guide/solr/latest/configuration-guide/backups.html" ) def __repr__(self, *args, **kwargs): @@ -529,8 +530,7 @@ class WaitForBackupCreate(Runner): # Current implementation is OpenSearch-specific and will fail against Solr. async def __call__(self, client, params): raise exceptions.BenchmarkError( - f"[{repr(self)}] is not yet implemented for Apache Solr. " - "Port to Solr Backup V2 API: https://solr.apache.org/guide/solr/latest/configuration-guide/backups.html" + f"[{repr(self)}] is not yet implemented for Apache Solr. Port to Solr Backup V2 API: https://solr.apache.org/guide/solr/latest/configuration-guide/backups.html" ) def __repr__(self, *args, **kwargs): @@ -545,17 +545,16 @@ class RestoreBackup(Runner): """ Restores a snapshot from an already registered repository """ + async def __call__(self, client, params): raise exceptions.BenchmarkError( - f"[{repr(self)}] is not yet implemented for Apache Solr. " - "Port to Solr Backup V2 API: https://solr.apache.org/guide/solr/latest/configuration-guide/backups.html" + f"[{repr(self)}] is not yet implemented for Apache Solr. Port to Solr Backup V2 API: https://solr.apache.org/guide/solr/latest/configuration-guide/backups.html" ) def __repr__(self, *args, **kwargs): return "restore-snapshot" - class CompositeContext: ctx = contextvars.ContextVar("composite_context") @@ -579,16 +578,14 @@ def get(key): try: return CompositeContext._ctx()[key] except KeyError: - raise KeyError(f"Unknown property [{key}]. Currently recognized " - f"properties are [{', '.join(CompositeContext._ctx().keys())}].") from None + raise KeyError(f"Unknown property [{key}]. Currently recognized properties are [{', '.join(CompositeContext._ctx().keys())}].") from None @staticmethod def remove(key): try: CompositeContext._ctx().pop(key) except KeyError: - raise KeyError(f"Unknown property [{key}]. Currently recognized " - f"properties are [{', '.join(CompositeContext._ctx().keys())}].") from None + raise KeyError(f"Unknown property [{key}]. Currently recognized properties are [{', '.join(CompositeContext._ctx().keys())}].") from None @staticmethod def _ctx(): @@ -602,6 +599,7 @@ class Composite(Runner): """ Executes a complex request structure which is measured as one composite operation. """ + def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self.supported_op_types = [ @@ -639,8 +637,7 @@ async def run_stream(self, client, stream, connection_limit): streams = [] op_type = item["operation-type"] if op_type not in self.supported_op_types: - raise exceptions.BenchmarkAssertionError( - f"Unsupported operation-type [{op_type}]. Use one of [{', '.join(self.supported_op_types)}].") + raise exceptions.BenchmarkAssertionError(f"Unsupported operation-type [{op_type}]. Use one of [{', '.join(self.supported_op_types)}].") runner = RequestTiming(runner_for(op_type)) async with connection_limit: async with runner: @@ -670,11 +667,7 @@ async def __call__(self, client, params): max_connections = params.get("max-connections", sys.maxsize) async with CompositeContext(): response = await self.run_stream(client, requests, asyncio.BoundedSemaphore(max_connections)) - return { - "weight": 1, - "unit": "ops", - "dependent_timing": response - } + return {"weight": 1, "unit": "ops", "dependent_timing": response} def __repr__(self, *args, **kwargs): return "composite" @@ -694,19 +687,11 @@ async def __call__(self, client, params): return_value = await self.delegate(client, params) if isinstance(return_value, tuple) and len(return_value) == 2: total_ops, total_ops_unit = return_value - result = { - "weight": total_ops, - "unit": total_ops_unit, - "success": True - } + result = {"weight": total_ops, "unit": total_ops_unit, "success": True} elif isinstance(return_value, dict): result = return_value else: - result = { - "weight": 1, - "unit": "ops", - "success": True - } + result = {"weight": 1, "unit": "ops", "success": True} start = request_context.request_start end = request_context.request_end @@ -716,7 +701,7 @@ async def __call__(self, client, params): "absolute_time": absolute_time, "request_start": start, "request_end": end, - "service_time": end - start + "service_time": end - start, } return result @@ -752,6 +737,7 @@ async def __aenter__(self): async def __call__(self, client, params): # pylint: disable=import-outside-toplevel import socket + retry_until_success = params.get("retry-until-success", self.retry_until_success) if retry_until_success: max_attempts = sys.maxsize @@ -774,8 +760,7 @@ async def __call__(self, client, params): self.logger.debug("%s has returned successfully", repr(self.delegate)) return return_value else: - self.logger.info("[%s] has returned with an error: %s. Retrying in [%.2f] seconds.", - repr(self.delegate), return_value, sleep_time) + self.logger.info("[%s] has returned with an error: %s. Retrying in [%.2f] seconds.", repr(self.delegate), return_value, sleep_time) await asyncio.sleep(sleep_time) else: return return_value @@ -811,6 +796,7 @@ def __repr__(self, *args, **kwargs): # Error translation helpers # --------------------------------------------------------------------------- + def _translate_solr_error(e): """Translate a pysolr or requests exception to a BenchmarkTransportError.""" if isinstance(e, requests.exceptions.ConnectionError): @@ -821,9 +807,7 @@ def _translate_solr_error(e): status_code = e.response.status_code if e.response is not None else None if status_code == 404: return exceptions.BenchmarkNotFoundError(str(e), cause=e) - return exceptions.BenchmarkTransportError( - str(e), cause=e, status_code=status_code, - error=f"HTTP {status_code}", info=str(e)) + return exceptions.BenchmarkTransportError(str(e), cause=e, status_code=status_code, error=f"HTTP {status_code}", info=str(e)) if isinstance(e, pysolr.SolrError): msg = str(e) status_code = None @@ -833,8 +817,7 @@ def _translate_solr_error(e): if 100 <= code < 600: status_code = code break - return exceptions.BenchmarkTransportError( - msg, cause=e, status_code=status_code, error="SolrError", info=msg) + return exceptions.BenchmarkTransportError(msg, cause=e, status_code=status_code, error="SolrError", info=msg) return exceptions.BenchmarkTransportError(str(e), cause=e, error=type(e).__name__, info=str(e)) @@ -850,6 +833,7 @@ async def wrapper(*args, **kwargs): raise except (pysolr.SolrError, requests.exceptions.RequestException) as e: raise _translate_solr_error(e) from e + return wrapper @@ -857,6 +841,7 @@ async def wrapper(*args, **kwargs): # Helpers # --------------------------------------------------------------------------- + def _get_collection(params): """Extract and validate the collection name from params.""" collection = params.get("collection") or params.get("index") or None @@ -919,9 +904,7 @@ def _translate_ndjson_batch(lines): logging.getLogger(__name__).warning("Skipping malformed first line: %s", first_line) return docs - has_action_keys = isinstance(first_obj, dict) and any( - k in first_obj for k in ("index", "create", "update", "delete") - ) + has_action_keys = isinstance(first_obj, dict) and any(k in first_obj for k in ("index", "create", "update", "delete")) if has_action_keys: docs = _parse_bulk_pairs(first_line, it) @@ -968,9 +951,7 @@ def _translate_ndjson_stream(lines): _logger.warning("Skipping malformed first line: %s", first_line) return - has_action_keys = isinstance(first_obj, dict) and any( - k in first_obj for k in ("index", "create", "update", "delete") - ) + has_action_keys = isinstance(first_obj, dict) and any(k in first_obj for k in ("index", "create", "update", "delete")) if has_action_keys: yield from _stream_bulk_pairs(first_line, it) @@ -1040,9 +1021,9 @@ def _stream_bulk_pairs(first_action_line, lines_iter): if isinstance(value, list) and len(value) == 2: if all(isinstance(v, (int, float)) for v in value): doc[key] = f"{value[1]},{value[0]}" - elif isinstance(value, str) and len(value) == 19 and value[10] == ' ': - if value[4] == '-' and value[7] == '-' and value[13] == ':' and value[16] == ':': - doc[key] = value.replace(' ', 'T') + 'Z' + elif isinstance(value, str) and len(value) == 19 and value[10] == " ": + if value[4] == "-" and value[7] == "-" and value[13] == ":" and value[16] == ":": + doc[key] = value.replace(" ", "T") + "Z" yield doc action_line = next(lines_iter, "").strip() @@ -1100,6 +1081,7 @@ def _parse_bulk_pairs(first_action_line, lines_iter): # Base runner with automatic error translation # --------------------------------------------------------------------------- + class SolrRunner(Runner): """Base class for all Solr runners. @@ -1118,6 +1100,7 @@ def __init_subclass__(cls, **kwargs): # Runner: bulk-index # --------------------------------------------------------------------------- + class SolrBulkIndex(SolrRunner): """ Index documents from an NDJSON corpus into Solr. @@ -1190,6 +1173,7 @@ def __str__(self): # Runner: search # --------------------------------------------------------------------------- + class SolrSearch(SolrRunner): """ Execute a Solr search query. @@ -1207,10 +1191,7 @@ async def __call__(self, client, params): body = params.get("body") if body is not None: - resp = await _run_in_executor( - sc.raw_request, "POST", f"/solr/{collection}/query", body, - {"Content-Type": "application/json"} - ) + resp = await _run_in_executor(sc.raw_request, "POST", f"/solr/{collection}/query", body, {"Content-Type": "application/json"}) resp.raise_for_status() num_hits = resp.json().get("response", {}).get("numFound", 0) else: @@ -1242,6 +1223,7 @@ def __str__(self): # Runner: paginated search (cursorMark deep pagination) # --------------------------------------------------------------------------- + class SolrPaginatedSearch(SolrRunner): """ Execute a cursor-paginated Solr search using cursorMark. @@ -1298,6 +1280,7 @@ def __str__(self): # Runner: commit # --------------------------------------------------------------------------- + class SolrCommit(SolrRunner): """ Commit pending changes in Solr. @@ -1328,6 +1311,7 @@ def __str__(self): # Runner: optimize # --------------------------------------------------------------------------- + class SolrOptimize(SolrRunner): """ Force-merge Solr segments (optimize). @@ -1355,6 +1339,7 @@ def __str__(self): # Runner: wait-for-merges # --------------------------------------------------------------------------- + class SolrWaitForMerges(SolrRunner): """ Poll Solr node metrics until no active merge operations remain across any core. @@ -1380,8 +1365,7 @@ async def __call__(self, client, params): total_running += int(val) elif isinstance(raw, dict): for core_metrics in raw.get("metrics", {}).values(): - for key in ("INDEX.merge.major.running", - "INDEX.merge.minor.running"): + for key in ("INDEX.merge.major.running", "INDEX.merge.minor.running"): val = core_metrics.get(key, 0) if isinstance(val, dict): val = val.get("value", 0) @@ -1407,6 +1391,7 @@ def __str__(self): # Runner: create-collection # --------------------------------------------------------------------------- + class SolrCreateCollection(SolrRunner): """ Collection creation — optionally with configset upload. @@ -1466,6 +1451,7 @@ def __str__(self): # Runner: delete-collection # --------------------------------------------------------------------------- + class SolrDeleteCollection(SolrRunner): """ Delete a Solr collection, optionally deleting its configset too. @@ -1508,6 +1494,7 @@ def __str__(self): # Runner: raw-request # --------------------------------------------------------------------------- + class RawRequest(Runner): """ Send an arbitrary HTTP request to any Solr endpoint. diff --git a/solrorbit/worker_coordinator/scheduler.py b/solrorbit/worker_coordinator/scheduler.py index eac62f20..7935c352 100644 --- a/solrorbit/worker_coordinator/scheduler.py +++ b/solrorbit/worker_coordinator/scheduler.py @@ -16,7 +16,7 @@ # not use this file except in compliance with the License. # You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an @@ -176,8 +176,7 @@ def remove_scheduler(name): class SimpleScheduler(ABC): @abstractmethod - def next(self, current): - ... + def next(self, current): ... class Scheduler(ABC): @@ -188,8 +187,7 @@ def after_request(self, now, weight, unit, request_meta_data): pass @abstractmethod - def next(self, current): - ... + def next(self, current): ... # Deprecated @@ -197,6 +195,7 @@ class DelegatingScheduler(SimpleScheduler): """ Delegates to a scheduler function and acts as an adapter to the rest of the system. """ + def __init__(self, delegate): super().__init__() self.delegate = delegate @@ -210,6 +209,7 @@ class LegacyWrappingScheduler(Scheduler): """ Wraps legacy implementations to stay backwards-compatible with older scheduler implementations. """ + def __init__(self, task, legacy_scheduler_class): super().__init__() # the legacy API was based on parameters so only provide these @@ -223,6 +223,7 @@ class Unthrottled(Scheduler): """ benchmark-internal scheduler to handle unthrottled tasks. """ + def next(self, current): return 0 @@ -235,6 +236,7 @@ class DeterministicScheduler(SimpleScheduler): Schedules the next execution according to a `deterministic distribution `_. """ + name = "deterministic" # pylint: disable=unused-variable @@ -258,6 +260,7 @@ class PoissonScheduler(SimpleScheduler): See also http://preshing.com/20111007/how-to-generate-random-timings-for-a-poisson-process/ """ + name = "poisson" # pylint: disable=unused-variable @@ -278,6 +281,7 @@ class UnitAwareScheduler(Scheduler): scheduling to the scheduler provided by the user in the workload. """ + def __init__(self, task, scheduler_class): super().__init__() self.task = task @@ -298,18 +302,22 @@ def after_request(self, now, weight, unit, request_meta_data): if expected_unit == "ops/s": weight = 1 if self.first_request: - logging.getLogger(__name__).warning("Task [%s] throttles based on [%s] but results [%s]. " - "Please specify the target throughput in [%s] instead.", - self.task, expected_unit, actual_unit, actual_unit) + logging.getLogger(__name__).warning( + "Task [%s] throttles based on [%s] but results [%s]. Please specify the target throughput in [%s] instead.", + self.task, + expected_unit, + actual_unit, + actual_unit, + ) else: - raise exceptions.BenchmarkAssertionError(f"Target throughput for [{self.task}] is specified in " - f"[{expected_unit}] but the task throughput is measured " - f"in [{actual_unit}].") + raise exceptions.BenchmarkAssertionError( + f"Target throughput for [{self.task}] is specified in [{expected_unit}] but the task throughput is measured in [{actual_unit}]." + ) self.first_request = False self.current_weight = weight # throughput in requests/s for this client - target_throughput = (self.task.target_throughput.value / self.task.clients / self.current_weight) + target_throughput = self.task.target_throughput.value / self.task.clients / self.current_weight self.scheduler = self.scheduler_class(self.task, target_throughput) def next(self, current): diff --git a/solrorbit/worker_coordinator/worker_coordinator.py b/solrorbit/worker_coordinator/worker_coordinator.py index 8db170fe..f0b0771a 100644 --- a/solrorbit/worker_coordinator/worker_coordinator.py +++ b/solrorbit/worker_coordinator/worker_coordinator.py @@ -53,6 +53,8 @@ from solrorbit.workload import WorkloadProcessorRegistry, load_workload, load_workload_plugins, ingestion_manager from solrorbit.utils import convert, console, net from solrorbit.worker_coordinator.errors import parse_error + + ################################## # # Messages sent between worker_coordinators @@ -81,6 +83,7 @@ class PrepareWorkload: Initiates preparation of a workload. """ + def __init__(self, cfg, workload): """ :param cfg: Solr Orbit internal configuration object. @@ -111,6 +114,7 @@ class WorkerTask: """ Unit of work that should be completed by the low-level TaskExecutionActor """ + func: Callable params: dict @@ -210,10 +214,11 @@ def __init__(self, metrics, next_task_scheduled_in): self.metrics = metrics self.next_task_scheduled_in = next_task_scheduled_in + def load_redline_config(): config = configparser.ConfigParser() - benchmark_home = os.environ.get('BENCHMARK_HOME') or os.environ['HOME'] - benchmark_ini = benchmark_home + '/.benchmark/benchmark.ini' + benchmark_home = os.environ.get("BENCHMARK_HOME") or os.environ["HOME"] + benchmark_ini = benchmark_home + "/.benchmark/benchmark.ini" if not os.path.isfile(benchmark_ini): console.println(f"WARNING: redline config file {benchmark_ini} not found. Proceeding with default values.") return {} @@ -223,20 +228,13 @@ def load_redline_config(): if "redline" in config: redline = config["redline"] - for key in [ - "scale_step", - "scaledown_percentage", - "post_scaledown_sleep", - "max_cpu_usage", - "cpu_window_seconds", - "cpu_check_interval", - "max_clients" - ]: + for key in ["scale_step", "scaledown_percentage", "post_scaledown_sleep", "max_cpu_usage", "cpu_window_seconds", "cpu_check_interval", "max_clients"]: if key in redline: config_object[key] = redline[key] return config_object + class ConfigureFeedbackScaling: DEFAULT_SLEEP_SECONDS = 30 DEFAULT_SCALE_STEP = 5 @@ -244,8 +242,19 @@ class ConfigureFeedbackScaling: DEFAULT_CPU_WINDOW_SECONDS = 30 DEFAULT_CPU_CHECK_INTERVAL = 30 - def __init__(self, scale_step=None, scale_down_pct=None, sleep_seconds=None, max_clients=None, cpu_max=None, - cpu_window_seconds=None, cpu_check_interval=None, metrics_index=None, test_run_id=None, cfg=None): + def __init__( + self, + scale_step=None, + scale_down_pct=None, + sleep_seconds=None, + max_clients=None, + cpu_max=None, + cpu_window_seconds=None, + cpu_check_interval=None, + metrics_index=None, + test_run_id=None, + cfg=None, + ): config_object = load_redline_config() @@ -256,31 +265,37 @@ def __init__(self, scale_step=None, scale_down_pct=None, sleep_seconds=None, max self.cpu_window_seconds = int(cpu_window_seconds if cpu_window_seconds is not None else config_object.get("cpu_window_seconds", self.DEFAULT_CPU_WINDOW_SECONDS)) self.cpu_check_interval = int(cpu_check_interval if cpu_check_interval is not None else config_object.get("cpu_check_interval", self.DEFAULT_CPU_CHECK_INTERVAL)) self.max_clients = max_clients - self.cpu_max=cpu_max - self.cfg=cfg - self.metrics_index=metrics_index - self.test_run_id=test_run_id + self.cpu_max = cpu_max + self.cfg = cfg + self.metrics_index = metrics_index + self.test_run_id = test_run_id + class EnableFeedbackScaling: pass + class DisableFeedbackScaling: pass + class FeedbackState(Enum): """Various states for the FeedbackActor""" + NEUTRAL = "neutral" SCALING_DOWN = "scaling_down" SLEEP = "sleep" SCALING_UP = "scaling_up" DISABLED = "disabled" + class StartFeedbackActor: def __init__(self, error_queue=None, queue_lock=None, shared_states=None): self.shared_states = shared_states self.error_queue = error_queue self.queue_lock = queue_lock + # pylint: disable=too-many-public-methods class FeedbackActor(actor.BenchmarkActor): POST_SCALEDOWN_SECONDS = 30 @@ -299,7 +314,7 @@ def __init__(self) -> None: self.sleep_start_time = time.perf_counter() self.last_error_time = time.perf_counter() - FeedbackActor.POST_SCALEDOWN_SECONDS self.last_scaleup_time = time.perf_counter() - FeedbackActor.POST_SCALEDOWN_SECONDS - self.max_stable_clients = 0 # the value we want to return at the end of the test + self.max_stable_clients = 0 # the value we want to return at the end of the test # These will be passed in via StartFeedbackActor: self.error_queue = None self.queue_lock = None @@ -314,7 +329,7 @@ def __init__(self) -> None: self.cpu_window_seconds = None self.cpu_check_interval = None self.metrics_index = None - self.test_run_id=None + self.test_run_id = None def receiveMsg_StartFeedbackActor(self, msg, sender) -> None: """ @@ -351,25 +366,28 @@ def receiveMsg_ConfigureFeedbackScaling(self, msg, sender): # CPU feedback related items self.cpu_window_seconds = msg.cpu_window_seconds self.cpu_check_interval = msg.cpu_check_interval - self.test_run_id=msg.test_run_id - self.cfg=msg.cfg + self.test_run_id = msg.test_run_id + self.cfg = msg.cfg self.metrics_index = msg.metrics_index if msg.cpu_max: self.max_cpu_threshold = msg.cpu_max self.logger.info( - "Feedback actor has received the following configuration: Max clients = %s, scale step = %d, scale down percentage = %f, sleep time = %d", - self.total_client_count, self.num_clients_to_scale_up, self.percentage_clients_to_scale_down, self.POST_SCALEDOWN_SECONDS + "Feedback actor has received the following configuration: Max clients = %s, scale step = %d, scale down percentage = %f, sleep time = %d", + self.total_client_count, + self.num_clients_to_scale_up, + self.percentage_clients_to_scale_down, + self.POST_SCALEDOWN_SECONDS, ) def receiveMsg_ActorExitRequest(self, msg, sender): console.info("Redline test finished. Maximum stable client number reached: %d" % self.total_active_client_count) self.logger.info("FeedbackActor received ActorExitRequest and will shutdown") - if hasattr(self, 'shared_client_states'): + if hasattr(self, "shared_client_states"): self.shared_client_states.clear() def receiveMsg_ResetErrorThreshold(self, msg, sender): """Reset the max error threshold to allow scaling up again.""" - self.max_error_threshold = float('inf') + self.max_error_threshold = float("inf") self.logger.info("Error threshold has been reset, allowing full scale-up") def check_for_errors(self) -> List[Dict[str, Any]]: @@ -394,16 +412,16 @@ def clear_queue(self) -> None: def handle_state(self) -> None: current_time = time.perf_counter() # check CPU usage every N seconds - if (self.max_cpu_threshold and current_time - self.last_cpu_check >= self.cpu_check_interval): + if self.max_cpu_threshold and current_time - self.last_cpu_check >= self.cpu_check_interval: self._check_cpu_usage() self.last_cpu_check = current_time errors = self.check_for_errors() - sys.stdout.write("\x1b[s") # Save cursor position - sys.stdout.write("\x1b[1B") # Move cursor down 1 line - sys.stdout.write("\r\x1b[2K") # Clear the line + sys.stdout.write("\x1b[s") # Save cursor position + sys.stdout.write("\x1b[1B") # Move cursor down 1 line + sys.stdout.write("\r\x1b[2K") # Clear the line sys.stdout.write(f"[Redline] Active clients: {self.total_active_client_count}") - sys.stdout.write("\x1b[u") # Restore cursor position + sys.stdout.write("\x1b[u") # Restore cursor position sys.stdout.flush() if self.state == FeedbackState.DISABLED: @@ -425,9 +443,8 @@ def handle_state(self) -> None: return if self.state == FeedbackState.NEUTRAL: - self.max_stable_clients = max(self.max_stable_clients, self.total_active_client_count) # update the max number of stable clients - if (current_time - self.last_error_time >= self.POST_SCALEDOWN_SECONDS and - current_time - self.last_scaleup_time >= self.WAKEUP_INTERVAL): + self.max_stable_clients = max(self.max_stable_clients, self.total_active_client_count) # update the max number of stable clients + if current_time - self.last_error_time >= self.POST_SCALEDOWN_SECONDS and current_time - self.last_scaleup_time >= self.WAKEUP_INTERVAL: self.logger.info("No errors in the last %d seconds, scaling up", self.POST_SCALEDOWN_SECONDS) self.state = FeedbackState.SCALING_UP return @@ -497,10 +514,7 @@ def scale_up(self) -> None: clients_activated = 0 inactive_clients = [ - (worker_id, client_id) - for worker_id, client_states in self.shared_client_states.items() - for client_id, active in client_states.items() - if not active + (worker_id, client_id) for worker_id, client_states in self.shared_client_states.items() for client_id, active in client_states.items() if not active ] random.shuffle(inactive_clients) @@ -527,6 +541,7 @@ def _check_cpu_usage(self): "Disable redline testing or remove 'redline.max_cpu_usage' from your config." ) + class WorkerCoordinatorActor(actor.BenchmarkActor): RESET_RELATIVE_TIME_MARKER = "reset_relative_time" @@ -602,8 +617,7 @@ def receiveMsg_StartBenchmark(self, msg, sender): @actor.no_retry("worker_coordinator") # pylint: disable=no-value-for-parameter def receiveMsg_WorkloadPrepared(self, msg, sender): - self.transition_when_all_children_responded(sender, msg, - expected_status=None, new_status=None, transition=self._after_workload_prepared) + self.transition_when_all_children_responded(sender, msg, expected_status=None, new_status=None, transition=self._after_workload_prepared) @actor.no_retry("worker_coordinator") # pylint: disable=no-value-for-parameter def receiveMsg_JoinPointReached(self, msg, sender): @@ -633,14 +647,7 @@ def start_worker(self, worker_coordinator, worker_id, cfg, workload, allocations self.send(worker_coordinator, StartWorker(worker_id, cfg, workload, allocations, self.feedback_actor, error_queue, queue_lock, shared_states)) def start_feedbackActor(self, shared_states): - self.send( - self.feedback_actor, - StartFeedbackActor( - shared_states=shared_states, - error_queue=self.coordinator.error_queue, - queue_lock=self.coordinator.queue_lock - ) - ) + self.send(self.feedback_actor, StartFeedbackActor(shared_states=shared_states, error_queue=self.coordinator.error_queue, queue_lock=self.coordinator.queue_lock)) def drive_at(self, worker_coordinator, client_start_timestamp): self.send(worker_coordinator, Drive(client_start_timestamp)) @@ -679,25 +686,33 @@ def _after_workload_prepared(self): for child in self.children: self.send(child, thespian.actors.ActorExitRequest()) self.children = [] - self.send(self.start_sender, PreparationComplete( - # older versions (pre 6.3.0) don't expose build_flavor because the only (implicit) flavor was "oss" - cluster_version.get("build_flavor", "oss"), - cluster_version.get("number"), - cluster_version.get("build_hash") - )) + self.send( + self.start_sender, + PreparationComplete( + # older versions (pre 6.3.0) don't expose build_flavor because the only (implicit) flavor was "oss" + cluster_version.get("build_flavor", "oss"), + cluster_version.get("number"), + cluster_version.get("build_hash"), + ), + ) def on_benchmark_complete(self, metrics): self.send(self.start_sender, BenchmarkComplete(metrics)) def load_local_config(coordinator_config): - cfg = config.auto_load_local_config(coordinator_config, additional_sections=[ - # only copy the relevant bits - "workload", "worker_coordinator", "client", - # due to distribution version... - "builder", - "telemetry" - ]) + cfg = config.auto_load_local_config( + coordinator_config, + additional_sections=[ + # only copy the relevant bits + "workload", + "worker_coordinator", + "client", + # due to distribution version... + "builder", + "telemetry", + ], + ) # set root path (normally done by the main entry point) cfg.add(config.Scope.application, "node", "benchmark.root", paths.benchmark_root()) return cfg @@ -707,6 +722,7 @@ class TaskExecutionActor(actor.BenchmarkActor): """ This class should be used for long-running tasks, as it ensures they do not block the actor's messaging system """ + def __init__(self): super().__init__() self.pool = concurrent.futures.ThreadPoolExecutor(max_workers=1) @@ -731,8 +747,7 @@ def receiveMsg_StartTaskLoop(self, msg, sender): def receiveMsg_DoTask(self, msg, sender): # actor can arbitrarily execute code based on these messages. if anyone besides our parent sends a task, ignore if sender != self.parent: - msg = f"TaskExecutionActor expected message from [{self.parent}] but the received the following from " \ - f"[{sender}]: {vars(msg)}" + msg = f"TaskExecutionActor expected message from [{self.parent}] but the received the following from [{sender}]: {vars(msg)}" raise exceptions.BenchmarkError(msg) task = msg.task if self.executor_future is not None: @@ -765,6 +780,7 @@ def receiveMsg_BenchmarkFailure(self, msg, sender): # sent by our no_retry infrastructure; forward to master self.send(self.parent, msg) + class WorkloadPreparationActor(actor.BenchmarkActor): class Status(Enum): INITIALIZING = "initializing" @@ -811,27 +827,23 @@ def receiveMsg_PrepareWorkload(self, msg, sender): # the workload might have been loaded on a different machine (the coordinator machine) so we force a workload # update to ensure we use the latest version of plugins. load_workload(self.cfg) - load_workload_plugins(self.cfg, self.workload.name, register_workload_processor=tpr.register_workload_processor, - force_update=True) + load_workload_plugins(self.cfg, self.workload.name, register_workload_processor=tpr.register_workload_processor, force_update=True) # we expect on_prepare_workload can take a long time. seed a queue of tasks and delegate to child workers self.children = [self._create_task_executor() for _ in range(num_cores(self.cfg))] for processor in tpr.processors: self.processors.put(processor) self._seed_tasks(self.processors.get()) - self.send_to_children_and_transition(self, StartTaskLoop(self.workload.name, self.cfg), self.Status.INITIALIZING, - self.Status.PROCESSOR_RUNNING) + self.send_to_children_and_transition(self, StartTaskLoop(self.workload.name, self.cfg), self.Status.INITIALIZING, self.Status.PROCESSOR_RUNNING) def resume(self): if not self.processors.empty(): self._seed_tasks(self.processors.get()) - self.send_to_children_and_transition(self, StartTaskLoop(self.workload.name, self.cfg), self.Status.PROCESSOR_COMPLETE, - self.Status.PROCESSOR_RUNNING) + self.send_to_children_and_transition(self, StartTaskLoop(self.workload.name, self.cfg), self.Status.PROCESSOR_COMPLETE, self.Status.PROCESSOR_RUNNING) else: self.send(self.original_sender, WorkloadPrepared()) def _seed_tasks(self, processor): - self.tasks = list(WorkerTask(func, params) for func, params in - processor.on_prepare_workload(self.workload, self.data_root_dir)) + self.tasks = list(WorkerTask(func, params) for func, params in processor.on_prepare_workload(self.workload, self.data_root_dir)) def _create_task_executor(self): return self.createActor(TaskExecutionActor) @@ -848,13 +860,11 @@ def receiveMsg_ReadyForWork(self, msg, sender): @actor.no_retry("workload preparator") # pylint: disable=no-value-for-parameter def receiveMsg_WorkerIdle(self, msg, sender): - self.transition_when_all_children_responded(sender, msg, self.Status.PROCESSOR_RUNNING, - self.Status.PROCESSOR_COMPLETE, self.resume) + self.transition_when_all_children_responded(sender, msg, self.Status.PROCESSOR_RUNNING, self.Status.PROCESSOR_COMPLETE, self.resume) def num_cores(cfg): - return int(cfg.opts("system", "available.cores", mandatory=False, - default_value=multiprocessing.cpu_count())) + return int(cfg.opts("system", "available.cores", mandatory=False, default_value=multiprocessing.cpu_count())) class WorkerCoordinator: @@ -954,18 +964,10 @@ def prepare_benchmark(self, t): self.workload = t self.test_procedure = select_test_procedure(self.config, self.workload) self.quiet = self.config.opts("system", "quiet.mode", mandatory=False, default_value=False) - downsample_factor = int(self.config.opts( - "reporting", "metrics.request.downsample.factor", - mandatory=False, default_value=1)) - self.metrics_store = metrics.metrics_store(cfg=self.config, - workload=self.workload.name, - test_procedure=self.test_procedure.name, - read_only=False) - - self.sample_post_processor = DefaultSamplePostprocessor(self.metrics_store, - downsample_factor, - self.workload.meta_data, - self.test_procedure.meta_data) + downsample_factor = int(self.config.opts("reporting", "metrics.request.downsample.factor", mandatory=False, default_value=1)) + self.metrics_store = metrics.metrics_store(cfg=self.config, workload=self.workload.name, test_procedure=self.test_procedure.name, read_only=False) + + self.sample_post_processor = DefaultSamplePostprocessor(self.metrics_store, downsample_factor, self.workload.meta_data, self.test_procedure.meta_data) clients = self.create_clients() @@ -1044,8 +1046,7 @@ def start_benchmark(self): self.number_of_steps = len(allocator.join_points) - 1 self.tasks_per_join_point = allocator.tasks_per_joinpoint - self.logger.info("Solr Orbit consists of [%d] steps executed by [%d] clients.", - self.number_of_steps, len(self.allocations)) + self.logger.info("Solr Orbit consists of [%d] steps executed by [%d] clients.", self.number_of_steps, len(self.allocations)) # avoid flooding the log if there are too many clients if allocator.clients < 128: self.logger.info("Allocation matrix:\n%s", "\n".join([str(a) for a in self.allocations])) @@ -1072,8 +1073,9 @@ def start_benchmark(self): for client_id in clients: self.shared_client_dict[worker_id][client_id] = False # and send it along with the start_worker message. This way, the worker can pass it down to its assigned clients - self.target.start_worker(worker, worker_id, self.config, self.workload, client_allocations, - self.error_queue, self.queue_lock, shared_states=self.shared_client_dict[worker_id]) + self.target.start_worker( + worker, worker_id, self.config, self.workload, client_allocations, self.error_queue, self.queue_lock, shared_states=self.shared_client_dict[worker_id] + ) else: self.target.start_worker(worker, worker_id, self.config, self.workload, client_allocations) self.workers.append(worker) @@ -1093,18 +1095,21 @@ def start_benchmark(self): cpu_window_seconds = self.config.opts("workload", "redline.cpu_window_seconds", default_value=0) cpu_check_interval = self.config.opts("workload", "redline.cpu_check_interval", default_value=0) - self.target.send(self.target.feedback_actor, ConfigureFeedbackScaling( - scale_step=scale_step, - scale_down_pct=scale_down_pct, - sleep_seconds=sleep_seconds, - max_clients=max_clients, - cpu_max=cpu_max, - cpu_window_seconds=cpu_window_seconds, - cpu_check_interval=cpu_check_interval, - cfg=self.config, - metrics_index=metrics_index, - test_run_id=test_run_id - )) + self.target.send( + self.target.feedback_actor, + ConfigureFeedbackScaling( + scale_step=scale_step, + scale_down_pct=scale_down_pct, + sleep_seconds=sleep_seconds, + max_clients=max_clients, + cpu_max=cpu_max, + cpu_window_seconds=cpu_window_seconds, + cpu_check_interval=cpu_check_interval, + cfg=self.config, + metrics_index=metrics_index, + test_run_id=test_run_id, + ), + ) self.target.start_feedbackActor(self.shared_client_dict) self.update_progress_message() @@ -1112,8 +1117,7 @@ def start_benchmark(self): def joinpoint_reached(self, worker_id, worker_local_timestamp, task_allocations): self.currently_completed += 1 self.workers_completed_current_step[worker_id] = (worker_local_timestamp, time.perf_counter()) - self.logger.info("[%d/%d] workers reached join point [%d/%d].", - self.currently_completed, len(self.workers), self.current_step + 1, self.number_of_steps) + self.logger.info("[%d/%d] workers reached join point [%d/%d].", self.currently_completed, len(self.workers), self.current_step + 1, self.number_of_steps) # if we're in redline test mode, disable the feedback actor and pause all clients when we're at a joinpoint if self.config.opts("workload", "redline.test", mandatory=False): self.target.send(self.target.feedback_actor, DisableFeedbackScaling()) @@ -1174,8 +1178,7 @@ def move_to_next_task(self, workers_curr_step): for worker_id, worker in enumerate(self.workers): worker_ended_task_at, master_received_msg_at = workers_curr_step[worker_id] worker_start_timestamp = worker_ended_task_at + (start_next_task - master_received_msg_at) - self.logger.info("Scheduling next task for worker id [%d] at their timestamp [%f] (master timestamp [%f])", - worker_id, worker_start_timestamp, start_next_task) + self.logger.info("Scheduling next task for worker id [%d] at their timestamp [%f] (master timestamp [%f])", worker_id, worker_start_timestamp, start_next_task) self.target.drive_at(worker, worker_start_timestamp) def may_complete_current_task(self, task_allocations): @@ -1185,9 +1188,11 @@ def may_complete_current_task(self, task_allocations): # while this list could contain multiple items, it should always be the same task (but multiple # different clients) so any item is sufficient. current_join_point = joinpoints_completing_parent[0].task - self.logger.info("Tasks before join point [%s] are able to complete the parent structure. Checking " - "if all [%d] clients have finished yet.", - current_join_point, len(current_join_point.clients_executing_completing_task)) + self.logger.info( + "Tasks before join point [%s] are able to complete the parent structure. Checking if all [%d] clients have finished yet.", + current_join_point, + len(current_join_point.clients_executing_completing_task), + ) pending_client_ids = [] for client_id in current_join_point.clients_executing_completing_task: @@ -1240,18 +1245,17 @@ def update_progress_message(self, task_finished=False): # we only count clients which actually contribute to progress. If clients are executing tasks eternally in a parallel # structure, we should not count them. The reason is that progress depends entirely on the client(s) that execute the # task that is completing the parallel structure. - progress_per_client = [s.task_progress - for s in self.most_recent_sample_per_client.values() if s.task_progress is not None] + progress_per_client = [s.task_progress for s in self.most_recent_sample_per_client.values() if s.task_progress is not None] if not progress_per_client: # No clients have reported. - progress_per_client = [(0.0, '%')] + progress_per_client = [(0.0, "%")] num_clients = len(progress_per_client) total_progress = sum([p[0] for p in progress_per_client]) / num_clients - units = { p[1] for p in progress_per_client } + units = {p[1] for p in progress_per_client} assert len(units) == 1, "Encountered mix of disparate units while tracking task progress" unit = units.pop() - if unit != '%': + if unit != "%": self.progress_publisher.print("Running %s" % tasks, "[%4.1f GB]" % total_progress) else: if task_finished: @@ -1270,15 +1274,15 @@ def post_process_samples(self): self.raw_profile_samples = [] if len(profile_samples) > 0: if self.profile_metrics_post_processor is None: - self.profile_metrics_post_processor = ProfileMetricsSamplePostprocessor(self.metrics_store, - self.workload.meta_data, - self.test_procedure.meta_data) + self.profile_metrics_post_processor = ProfileMetricsSamplePostprocessor(self.metrics_store, self.workload.meta_data, self.test_procedure.meta_data) self.profile_metrics_post_processor(profile_samples) -class SamplePostprocessor(): + +class SamplePostprocessor: """ Parent class used to process samples into the metrics store """ + def __init__(self, metrics_store, workload_meta_data, test_procedure_meta_data): self.logger = logging.getLogger(__name__) self.metrics_store = metrics_store @@ -1297,6 +1301,7 @@ class DefaultSamplePostprocessor(SamplePostprocessor): """ Processes operational and correctness metric samples by merging and adding to the metrics store """ + def __init__(self, metrics_store, downsample_factor, workload_meta_data, test_procedure_meta_data): super().__init__(metrics_store, workload_meta_data, test_procedure_meta_data) self.throughput_calculator = ThroughputCalculator() @@ -1349,44 +1354,73 @@ def __call__(self, raw_samples): if idx % self.downsample_factor == 0: final_sample_count += 1 - meta_data = self.merge( - self.workload_meta_data, - self.test_procedure_meta_data, - sample.operation_meta_data, - sample.task.meta_data, - sample.request_meta_data) - - self.metrics_store.put_value_cluster_level(name="latency", value=convert.seconds_to_ms(sample.latency), - unit="ms", task=sample.task.name, - operation=sample.operation_name, operation_type=sample.operation_type, - sample_type=sample.sample_type, absolute_time=sample.absolute_time, - relative_time=sample.relative_time, meta_data=meta_data) - - self.metrics_store.put_value_cluster_level(name="service_time", value=convert.seconds_to_ms(sample.service_time), - unit="ms", task=sample.task.name, - operation=sample.operation_name, operation_type=sample.operation_type, - sample_type=sample.sample_type, absolute_time=sample.absolute_time, - relative_time=sample.relative_time, meta_data=meta_data) - - self.metrics_store.put_value_cluster_level(name="client_processing_time", - value=convert.seconds_to_ms(sample.client_processing_time), - unit="ms", task=sample.task.name, - operation=sample.operation_name, operation_type=sample.operation_type, - sample_type=sample.sample_type, absolute_time=sample.absolute_time, - relative_time=sample.relative_time, meta_data=meta_data) - - self.metrics_store.put_value_cluster_level(name="processing_time", value=convert.seconds_to_ms(sample.processing_time), - unit="ms", task=sample.task.name, - operation=sample.operation_name, operation_type=sample.operation_type, - sample_type=sample.sample_type, absolute_time=sample.absolute_time, - relative_time=sample.relative_time, meta_data=meta_data) + meta_data = self.merge(self.workload_meta_data, self.test_procedure_meta_data, sample.operation_meta_data, sample.task.meta_data, sample.request_meta_data) + + self.metrics_store.put_value_cluster_level( + name="latency", + value=convert.seconds_to_ms(sample.latency), + unit="ms", + task=sample.task.name, + operation=sample.operation_name, + operation_type=sample.operation_type, + sample_type=sample.sample_type, + absolute_time=sample.absolute_time, + relative_time=sample.relative_time, + meta_data=meta_data, + ) + + self.metrics_store.put_value_cluster_level( + name="service_time", + value=convert.seconds_to_ms(sample.service_time), + unit="ms", + task=sample.task.name, + operation=sample.operation_name, + operation_type=sample.operation_type, + sample_type=sample.sample_type, + absolute_time=sample.absolute_time, + relative_time=sample.relative_time, + meta_data=meta_data, + ) + + self.metrics_store.put_value_cluster_level( + name="client_processing_time", + value=convert.seconds_to_ms(sample.client_processing_time), + unit="ms", + task=sample.task.name, + operation=sample.operation_name, + operation_type=sample.operation_type, + sample_type=sample.sample_type, + absolute_time=sample.absolute_time, + relative_time=sample.relative_time, + meta_data=meta_data, + ) + + self.metrics_store.put_value_cluster_level( + name="processing_time", + value=convert.seconds_to_ms(sample.processing_time), + unit="ms", + task=sample.task.name, + operation=sample.operation_name, + operation_type=sample.operation_type, + sample_type=sample.sample_type, + absolute_time=sample.absolute_time, + relative_time=sample.relative_time, + meta_data=meta_data, + ) for timing in sample.dependent_timings: - self.metrics_store.put_value_cluster_level(name="service_time", value=convert.seconds_to_ms(timing.service_time), - unit="ms", task=timing.task.name, - operation=timing.operation_name, operation_type=timing.operation_type, - sample_type=timing.sample_type, absolute_time=timing.absolute_time, - relative_time=timing.relative_time, meta_data=meta_data) + self.metrics_store.put_value_cluster_level( + name="service_time", + value=convert.seconds_to_ms(timing.service_time), + unit="ms", + task=timing.task.name, + operation=timing.operation_name, + operation_type=timing.operation_type, + sample_type=timing.sample_type, + absolute_time=timing.absolute_time, + relative_time=timing.relative_time, + meta_data=meta_data, + ) end = time.perf_counter() self.logger.debug("Storing latency and service time took [%f] seconds.", (end - start)) @@ -1396,17 +1430,20 @@ def __call__(self, raw_samples): self.logger.debug("Calculating throughput took [%f] seconds.", (end - start)) start = end for task, samples in aggregates.items(): - meta_data = self.merge( - self.workload_meta_data, - self.test_procedure_meta_data, - task.operation.meta_data, - task.meta_data - ) + meta_data = self.merge(self.workload_meta_data, self.test_procedure_meta_data, task.operation.meta_data, task.meta_data) for absolute_time, relative_time, sample_type, throughput, throughput_unit in samples: - self.metrics_store.put_value_cluster_level(name="throughput", value=throughput, unit=throughput_unit, task=task.name, - operation=task.operation.name, operation_type=task.operation.type, - sample_type=sample_type, absolute_time=absolute_time, - relative_time=relative_time, meta_data=meta_data) + self.metrics_store.put_value_cluster_level( + name="throughput", + value=throughput, + unit=throughput_unit, + task=task.name, + operation=task.operation.name, + operation_type=task.operation.type, + sample_type=sample_type, + absolute_time=absolute_time, + relative_time=relative_time, + meta_data=meta_data, + ) end = time.perf_counter() self.logger.debug("Storing throughput took [%f] seconds.", (end - start)) start = end @@ -1419,8 +1456,7 @@ def __call__(self, raw_samples): self.metrics_store.flush(refresh=False) end = time.perf_counter() self.logger.debug("Flushing the metrics store took [%f] seconds.", (end - start)) - self.logger.debug("Postprocessing [%d] raw samples (downsampled to [%d] samples) took [%f] seconds in total.", - len(raw_samples), final_sample_count, (end - total_start)) + self.logger.debug("Postprocessing [%d] raw samples (downsampled to [%d] samples) took [%f] seconds in total.", len(raw_samples), final_sample_count, (end - total_start)) class ProfileMetricsSamplePostprocessor(SamplePostprocessor): @@ -1482,8 +1518,7 @@ def __call__(self, raw_samples): self.metrics_store.flush(refresh=False) end = time.perf_counter() self.logger.debug("Flushing the metrics store took [%f] seconds.", (end - start)) - self.logger.debug("Postprocessing [%d] raw samples (downsampled to [%d] samples) took [%f] seconds in total.", - len(raw_samples), final_sample_count, (end - total_start)) + self.logger.debug("Postprocessing [%d] raw samples (downsampled to [%d] samples) took [%f] seconds in total.", len(raw_samples), final_sample_count, (end - total_start)) def calculate_worker_assignments(host_configs, client_count): @@ -1541,10 +1576,7 @@ def __init__(self): self.allocations = [] def add(self, client_id, tasks): - self.allocations.append({ - "client_id": client_id, - "tasks": tasks - }) + self.allocations.append({"client_id": client_id, "tasks": tasks}) def is_joinpoint(self, task_index): return all(isinstance(t.task, JoinPoint) for t in self.tasks(task_index)) @@ -1621,8 +1653,9 @@ def receiveMsg_StartWorker(self, msg, sender): @actor.no_retry("worker") # pylint: disable=no-value-for-parameter def receiveMsg_Drive(self, msg, sender): sleep_time = datetime.timedelta(seconds=msg.client_start_timestamp - time.perf_counter()) - self.logger.info("Worker[%d] is continuing its work at task index [%d] on [%f], that is in [%s].", - self.worker_id, self.current_task_index, msg.client_start_timestamp, sleep_time) + self.logger.info( + "Worker[%d] is continuing its work at task index [%d] on [%f], that is in [%s].", self.worker_id, self.current_task_index, msg.client_start_timestamp, sleep_time + ) self.start_driving = True self.wakeupAfter(sleep_time) @@ -1631,11 +1664,9 @@ def receiveMsg_CompleteCurrentTask(self, msg, sender): # finish now ASAP. Remaining samples will be sent with the next WakeupMessage. We will also need to skip to the next # JoinPoint. But if we are already at a JoinPoint at the moment, there is nothing to do. if self.at_joinpoint(): - self.logger.info("Worker[%s] has received CompleteCurrentTask but is currently at join point at index [%d]. Ignoring.", - str(self.worker_id), self.current_task_index) + self.logger.info("Worker[%s] has received CompleteCurrentTask but is currently at join point at index [%d]. Ignoring.", str(self.worker_id), self.current_task_index) else: - self.logger.info("Worker[%s] has received CompleteCurrentTask. Completing tasks at index [%d].", - str(self.worker_id), self.current_task_index) + self.logger.info("Worker[%s] has received CompleteCurrentTask. Completing tasks at index [%d].", str(self.worker_id), self.current_task_index) self.complete.set() @actor.no_retry("worker") # pylint: disable=no-value-for-parameter @@ -1647,37 +1678,25 @@ def receiveMsg_WakeupMessage(self, msg, sender): else: current_samples = self.send_samples() if self.cancel.is_set(): - self.logger.info("Worker[%s] has detected that benchmark has been cancelled. Notifying master...", - str(self.worker_id)) + self.logger.info("Worker[%s] has detected that benchmark has been cancelled. Notifying master...", str(self.worker_id)) self.send(self.master, actor.BenchmarkCancelled()) elif self.executor_future is not None and self.executor_future.done(): e = self.executor_future.exception(timeout=0) if e: currentTasks = self.client_allocations.tasks(self.current_task_index) detailed_error = ( - f"Benchmark operation failed:\n" - f"Worker ID: {self.worker_id}\n" - f"Task: {', '.join(t.task.task.name for t in currentTasks)}\n" - f"Workload: {self.workload.name if self.workload else 'Unknown'}\n" - f"Test Procedure: {self.workload.selected_test_procedure_or_default}\n" - f"Cause: {e.cause if hasattr(e, 'cause') and e.cause is not None else 'Unknown'}" + f"Benchmark operation failed:\n" + f"Worker ID: {self.worker_id}\n" + f"Task: {', '.join(t.task.task.name for t in currentTasks)}\n" + f"Workload: {self.workload.name if self.workload else 'Unknown'}\n" + f"Test Procedure: {self.workload.selected_test_procedure_or_default}\n" + f"Cause: {e.cause if hasattr(e, 'cause') and e.cause is not None else 'Unknown'}" ) detailed_error += f"\nError: {str(e)}" - self.logger.exception( - "Worker[%s] has detected a benchmark failure:\n%s", - str(self.worker_id), - detailed_error, - exc_info=e - ) + self.logger.exception("Worker[%s] has detected a benchmark failure:\n%s", str(self.worker_id), detailed_error, exc_info=e) - self.send( - self.master, - actor.BenchmarkFailure( - detailed_error, - str(e) - ) - ) + self.send(self.master, actor.BenchmarkFailure(detailed_error, str(e))) else: self.logger.info("Worker[%s] is ready for the next task.", str(self.worker_id)) self.executor_future = None @@ -1685,13 +1704,13 @@ def receiveMsg_WakeupMessage(self, msg, sender): else: if current_samples and len(current_samples) > 0: most_recent_sample = current_samples[-1] - if most_recent_sample.task_progress is not None and most_recent_sample.task_progress[1] == '%': - self.logger.debug("Worker[%s] is executing [%s] (%.2f%% complete).", - str(self.worker_id), most_recent_sample.task, most_recent_sample.task_progress[0] * 100.0) + if most_recent_sample.task_progress is not None and most_recent_sample.task_progress[1] == "%": + self.logger.debug( + "Worker[%s] is executing [%s] (%.2f%% complete).", str(self.worker_id), most_recent_sample.task, most_recent_sample.task_progress[0] * 100.0 + ) else: # TODO: This could be misleading given that one worker could execute more than one task... - self.logger.debug("Worker[%s] is executing [%s] (dependent eternal task).", - str(self.worker_id), most_recent_sample.task) + self.logger.debug("Worker[%s] is executing [%s] (dependent eternal task).", str(self.worker_id), most_recent_sample.task) else: self.logger.debug("Worker[%s] is executing (no samples).", str(self.worker_id)) self.wakeupAfter(datetime.timedelta(seconds=self.wakeup_interval)) @@ -1731,14 +1750,27 @@ def drive(self): # There may be a situation where there are more (parallel) tasks than workers. If we were asked to complete all tasks, we not # only need to complete actively running tasks but actually all scheduled tasks until we reach the next join point. if self.complete.is_set(): - self.logger.info("Worker[%d] skips tasks at index [%d] because it has been asked to complete all " - "tasks until next join point.", self.worker_id, self.current_task_index) + self.logger.info( + "Worker[%d] skips tasks at index [%d] because it has been asked to complete all tasks until next join point.", self.worker_id, self.current_task_index + ) else: self.logger.info("Worker[%d] is executing tasks at index [%d].", self.worker_id, self.current_task_index) self.sampler = DefaultSampler(start_timestamp=time.perf_counter(), buffer_size=self.sample_queue_size) self.profile_sampler = ProfileMetricsSampler(start_timestamp=time.perf_counter(), buffer_size=self.sample_queue_size) - executor = AsyncIoAdapter(self.config, self.workload, task_allocations, self.sampler, self.profile_sampler, - self.cancel, self.complete, self.on_error, self.shared_states, self.feedback_actor, self.error_queue, self.queue_lock) + executor = AsyncIoAdapter( + self.config, + self.workload, + task_allocations, + self.sampler, + self.profile_sampler, + self.cancel, + self.complete, + self.on_error, + self.shared_states, + self.feedback_actor, + self.error_queue, + self.queue_lock, + ) self.executor_future = self.pool.submit(executor) self.wakeupAfter(datetime.timedelta(seconds=self.wakeup_interval)) @@ -1782,33 +1814,67 @@ def samples(self): pass return samples + class DefaultSampler(Sampler): """ Encapsulates management of gathered default samples (operational and correctness metrics). """ - def add(self, task, client_id, sample_type, meta_data, absolute_time, request_start, latency, service_time, - client_processing_time, processing_time, throughput, ops, ops_unit, time_period, task_progress, - dependent_timing=None): + def add( + self, + task, + client_id, + sample_type, + meta_data, + absolute_time, + request_start, + latency, + service_time, + client_processing_time, + processing_time, + throughput, + ops, + ops_unit, + time_period, + task_progress, + dependent_timing=None, + ): try: self.q.put_nowait( - DefaultSample(client_id, absolute_time, request_start, self.start_timestamp, task, sample_type, meta_data, - latency, service_time, client_processing_time, processing_time, throughput, ops, ops_unit, time_period, - task_progress, dependent_timing)) + DefaultSample( + client_id, + absolute_time, + request_start, + self.start_timestamp, + task, + sample_type, + meta_data, + latency, + service_time, + client_processing_time, + processing_time, + throughput, + ops, + ops_unit, + time_period, + task_progress, + dependent_timing, + ) + ) except queue.Full: self.logger.warning("Dropping sample for [%s] due to a full sampling queue.", task.operation.name) + class ProfileMetricsSampler(Sampler): """ Encapsulates management of gathered profile metrics samples. """ - def add(self, task, client_id, sample_type, meta_data, absolute_time, request_start, time_period, task_progress, - dependent_timing=None): + def add(self, task, client_id, sample_type, meta_data, absolute_time, request_start, time_period, task_progress, dependent_timing=None): try: self.q.put_nowait( - ProfileMetricsSample(client_id, absolute_time, request_start, self.start_timestamp, task, sample_type, meta_data, - time_period, task_progress, dependent_timing)) + ProfileMetricsSample(client_id, absolute_time, request_start, self.start_timestamp, task, sample_type, meta_data, time_period, task_progress, dependent_timing) + ) except queue.Full: self.logger.warning("Dropping sample for [%s] due to a full sampling queue.", task.operation.name) @@ -1817,8 +1883,8 @@ class Sample: """ Basic information used by metrics store to keep track of samples """ - def __init__(self, client_id, absolute_time, request_start, task_start, task, sample_type, request_meta_data, - time_period, task_progress, dependent_timing=None): + + def __init__(self, client_id, absolute_time, request_start, task_start, task, sample_type, request_meta_data, time_period, task_progress, dependent_timing=None): self.client_id = client_id self.absolute_time = absolute_time self.request_start = request_start @@ -1848,16 +1914,34 @@ def relative_time(self): return self.request_start - self.task_start def __repr__(self, *args, **kwargs): - return f"[{self.absolute_time}; {self.relative_time}] [client [{self.client_id}]] [{self.task}] " \ - f"[{self.sample_type}]" + return f"[{self.absolute_time}; {self.relative_time}] [client [{self.client_id}]] [{self.task}] [{self.sample_type}]" + class DefaultSample(Sample): """ Stores the operational and correctness metrics to later put into the metrics store """ - def __init__(self, client_id, absolute_time, request_start, task_start, task, sample_type, request_meta_data, latency, - service_time, client_processing_time, processing_time, throughput, total_ops, total_ops_unit, time_period, - task_progress, dependent_timing=None): + + def __init__( + self, + client_id, + absolute_time, + request_start, + task_start, + task, + sample_type, + request_meta_data, + latency, + service_time, + client_processing_time, + processing_time, + throughput, + total_ops, + total_ops_unit, + time_period, + task_progress, + dependent_timing=None, + ): super().__init__(client_id, absolute_time, request_start, task_start, task, sample_type, request_meta_data, time_period, task_progress, dependent_timing) self.latency = latency self.service_time = service_time @@ -1871,14 +1955,33 @@ def __init__(self, client_id, absolute_time, request_start, task_start, task, sa def dependent_timings(self): if self._dependent_timing: for t in self._dependent_timing: - yield DefaultSample(self.client_id, t["absolute_time"], t["request_start"], self.task_start, self.task, - self.sample_type, self.request_meta_data, 0, t["service_time"], 0, 0, 0, self.total_ops, - self.total_ops_unit, self.time_period, self.task_progress, None) + yield DefaultSample( + self.client_id, + t["absolute_time"], + t["request_start"], + self.task_start, + self.task, + self.sample_type, + self.request_meta_data, + 0, + t["service_time"], + 0, + 0, + 0, + self.total_ops, + self.total_ops_unit, + self.time_period, + self.task_progress, + None, + ) def __repr__(self, *args, **kwargs): - return f"[{self.absolute_time}; {self.relative_time}] [client [{self.client_id}]] [{self.task}] " \ - f"[{self.sample_type}]: [{self.latency}s] request latency, [{self.service_time}s] service time, " \ - f"[{self.total_ops} {self.total_ops_unit}]" + return ( + f"[{self.absolute_time}; {self.relative_time}] [client [{self.client_id}]] [{self.task}] " + f"[{self.sample_type}]: [{self.latency}s] request latency, [{self.service_time}s] service time, " + f"[{self.total_ops} {self.total_ops_unit}]" + ) + class ProfileMetricsSample(Sample): """ @@ -1889,8 +1992,18 @@ class ProfileMetricsSample(Sample): def dependent_timings(self): if self._dependent_timing: for t in self._dependent_timing: - yield ProfileMetricsSample(self.client_id, t["absolute_time"], t["request_start"], self.task_start, self.task, - self.sample_type, self.request_meta_data, self.time_period, self.task_progress, None) + yield ProfileMetricsSample( + self.client_id, + t["absolute_time"], + t["request_start"], + self.task_start, + self.task, + self.sample_type, + self.request_meta_data, + self.time_period, + self.task_progress, + None, + ) def select_test_procedure(config, t): @@ -1898,8 +2011,10 @@ def select_test_procedure(config, t): selected_test_procedure = t.find_test_procedure_or_default(test_procedure_name) if not selected_test_procedure: - raise exceptions.SystemSetupError("Unknown test_procedure [%s] for workload [%s]. You can list the available workloads and their " - "test_procedures with %s list workloads." % (test_procedure_name, t.name, PROGRAM_NAME)) + raise exceptions.SystemSetupError( + "Unknown test_procedure [%s] for workload [%s]. You can list the available workloads and their " + "test_procedures with %s list workloads." % (test_procedure_name, t.name, PROGRAM_NAME) + ) return selected_test_procedure @@ -1908,6 +2023,7 @@ class TaskStats: """ Stores per task numbers needed for throughput calculation in between multiple calculations. """ + def __init__(self, bucket_interval, sample_type, start_time): self.unprocessed = [] self.total_count = 0 @@ -1994,9 +2110,9 @@ def calculate_task_throughput(self, task, current_samples, bucket_interval_secs) if task not in self.task_stats: first_sample = current_samples[0] - self.task_stats[task] = ThroughputCalculator.TaskStats(bucket_interval=bucket_interval_secs, - sample_type=first_sample.sample_type, - start_time=first_sample.absolute_time - first_sample.time_period) + self.task_stats[task] = ThroughputCalculator.TaskStats( + bucket_interval=bucket_interval_secs, sample_type=first_sample.sample_type, start_time=first_sample.absolute_time - first_sample.time_period + ) current = self.task_stats[task] count = current.total_count last_sample = None @@ -2017,12 +2133,16 @@ def calculate_task_throughput(self, task, current_samples, bucket_interval_secs) if current.can_calculate_throughput(): current.finish_bucket(count) - task_throughput.append((sample.absolute_time, - sample.relative_time, - current.sample_type, - current.throughput, - # we calculate throughput per second - f"{sample.total_ops_unit}/s")) + task_throughput.append( + ( + sample.absolute_time, + sample.relative_time, + current.sample_type, + current.throughput, + # we calculate throughput per second + f"{sample.total_ops_unit}/s", + ) + ) else: current.unprocessed.append(sample) @@ -2030,28 +2150,33 @@ def calculate_task_throughput(self, task, current_samples, bucket_interval_secs) # interval (mainly needed to ensure we show throughput data in test mode) if last_sample is not None and current.can_add_final_throughput_sample(): current.finish_bucket(count) - task_throughput.append((last_sample.absolute_time, - last_sample.relative_time, - current.sample_type, - current.throughput, - f"{last_sample.total_ops_unit}/s")) + task_throughput.append((last_sample.absolute_time, last_sample.relative_time, current.sample_type, current.throughput, f"{last_sample.total_ops_unit}/s")) return task_throughput def map_task_throughput(self, current_samples): throughput = [] for sample in current_samples: - throughput.append((sample.absolute_time, - sample.relative_time, - sample.sample_type, - sample.throughput, - f"{sample.total_ops_unit}/s")) + throughput.append((sample.absolute_time, sample.relative_time, sample.sample_type, sample.throughput, f"{sample.total_ops_unit}/s")) return throughput class AsyncIoAdapter: - def __init__(self, cfg, workload, task_allocations, sampler, profile_sampler, cancel, complete, abort_on_error, - shared_states=None, feedback_actor=None, error_queue=None, queue_lock=None): + def __init__( + self, + cfg, + workload, + task_allocations, + sampler, + profile_sampler, + cancel, + complete, + abort_on_error, + shared_states=None, + feedback_actor=None, + error_queue=None, + queue_lock=None, + ): self.cfg = cfg self.workload = workload self.task_allocations = task_allocations @@ -2099,8 +2224,7 @@ def build_clients(all_hosts, all_client_options): # Properly size the internal connection pool to match the number of expected clients but allow the user # to override it if needed. client_count = len(self.task_allocations) - clients = build_clients(self.cfg.opts("client", "hosts").all_hosts, - self.cfg.opts("client", "options").with_max_connections(client_count)) + clients = build_clients(self.cfg.opts("client", "hosts").all_hosts, self.cfg.opts("client", "options").with_max_connections(client_count)) self.logger.info("Task assertions enabled: %s", str(self.assertions_enabled)) runner.enable_assertions(self.assertions_enabled) @@ -2123,8 +2247,21 @@ def build_clients(all_hosts, all_client_options): # need to start from (client) index 0 in both cases instead of 0 for indexA and 4 for indexB. schedule = schedule_for(task_allocation, params_per_task[task]) async_executor = AsyncExecutor( - client_id, task, schedule, clients, self.sampler, self.profile_sampler, self.cancel, self.complete, - task.error_behavior(self.abort_on_error), self.cfg, self.shared_states, self.feedback_actor, self.error_queue, self.queue_lock) + client_id, + task, + schedule, + clients, + self.sampler, + self.profile_sampler, + self.cancel, + self.complete, + task.error_behavior(self.abort_on_error), + self.cfg, + self.shared_states, + self.feedback_actor, + self.error_queue, + self.queue_lock, + ) final_executor = AsyncProfiler(async_executor) if self.profiling_enabled else async_executor aws.append(final_executor()) run_start = time.perf_counter() @@ -2155,19 +2292,14 @@ async def __call__(self, *args, **kwargs): # pylint: disable=import-outside-toplevel import yappi import io as python_io + yappi.start() try: return await self.target(*args, **kwargs) finally: yappi.stop() s = python_io.StringIO() - yappi.get_func_stats().print_all(out=s, columns={ - 0: ("name", 140), - 1: ("ncall", 8), - 2: ("tsub", 8), - 3: ("ttot", 8), - 4: ("tavg", 8) - }) + yappi.get_func_stats().print_all(out=s, columns={0: ("name", 140), 1: ("ncall", 8), 2: ("tsub", 8), 3: ("ttot", 8), 4: ("tavg", 8)}) profile = "\n=== Profile START ===\n" profile += s.getvalue() @@ -2176,8 +2308,23 @@ async def __call__(self, *args, **kwargs): class AsyncExecutor: - def __init__(self, client_id, task, schedule, clients, sampler, profile_sampler, cancel, complete, on_error, - config=None, shared_states=None, feedback_actor=None, error_queue=None, queue_lock=None): + def __init__( + self, + client_id, + task, + schedule, + clients, + sampler, + profile_sampler, + cancel, + complete, + on_error, + config=None, + shared_states=None, + feedback_actor=None, + error_queue=None, + queue_lock=None, + ): """ Executes tasks according to the schedule for a given operation. """ @@ -2231,8 +2378,7 @@ async def _prepare_context_manager(self, params: dict): """Prepare the appropriate context manager for the request.""" return self.clients["default"].new_request_context() - async def _execute_request(self, params: dict, expected_scheduled_time: float, total_start: float, - client_state: bool) -> dict: + async def _execute_request(self, params: dict, expected_scheduled_time: float, total_start: float, client_state: bool) -> dict: """Execute a request with timing control and error handling.""" request_timeout = (params or {}).get("request-timeout", None) absolute_expected_schedule_time = total_start + expected_scheduled_time @@ -2254,11 +2400,8 @@ async def _execute_request(self, params: dict, expected_scheduled_time: float, t async with context_manager as request_context: try: total_ops, total_ops_unit, request_meta_data = await asyncio.wait_for( - execute_single( - self.runner, self.clients, params, self.on_error, - redline_enabled=self.redline_enabled, client_enabled=client_state - ), - timeout=self.base_timeout if request_timeout is None else request_timeout + execute_single(self.runner, self.clients, params, self.on_error, redline_enabled=self.redline_enabled, client_enabled=client_state), + timeout=self.base_timeout if request_timeout is None else request_timeout, ) except asyncio.TimeoutError: self.logger.error("Client %s request timed out after %s s", self.client_id, self.base_timeout) @@ -2277,8 +2420,7 @@ async def _execute_request(self, params: dict, expected_scheduled_time: float, t client_request_end = request_context.client_request_end # If request failed or timings weren't properly captured, fall back - if not request_meta_data.get("success") or None in (request_start, request_end, client_request_start, - client_request_end): + if not request_meta_data.get("success") or None in (request_start, request_end, client_request_start, client_request_end): if request_start is None: request_start = processing_start if client_request_start is None: @@ -2290,11 +2432,7 @@ async def _execute_request(self, params: dict, expected_scheduled_time: float, t client_request_end = now if not request_meta_data.get("skipped", False): - error_info = { - "client_id": self.client_id, - "task": str(self.task), - "error_details": request_meta_data - } + error_info = {"client_id": self.client_id, "task": str(self.task), "error_details": request_meta_data} self.report_error(error_info) processing_end = time.perf_counter() @@ -2310,35 +2448,25 @@ async def _execute_request(self, params: dict, expected_scheduled_time: float, t "total_ops": total_ops, "total_ops_unit": total_ops_unit, "request_meta_data": request_meta_data, - "throughput_throttled": throughput_throttled + "throughput_throttled": throughput_throttled, } - def _process_results(self, result_data: dict, total_start: float, client_state: bool, - task_progress: tuple, add_profile_metric_sample: bool = False) -> bool: + def _process_results(self, result_data: dict, total_start: float, client_state: bool, task_progress: tuple, add_profile_metric_sample: bool = False) -> bool: """Process results from a request.""" # Handle cases where the request was skipped (no-op) if result_data["request_meta_data"].get("skipped_request"): - self.schedule_handle.after_request( - result_data["processing_end"], 0, "ops", {"success": False, "skipped": True} - ) + self.schedule_handle.after_request(result_data["processing_end"], 0, "ops", {"success": False, "skipped": True}) return self.complete.is_set() service_time = result_data["request_end"] - result_data["request_start"] - client_processing_time = (result_data["client_request_end"] - result_data[ - "client_request_start"]) - service_time + client_processing_time = (result_data["client_request_end"] - result_data["client_request_start"]) - service_time processing_time = result_data["processing_end"] - result_data["processing_start"] time_period = result_data["request_end"] - total_start - self.schedule_handle.after_request( - result_data["processing_end"], - result_data["total_ops"], - result_data["total_ops_unit"], - result_data["request_meta_data"] - ) + self.schedule_handle.after_request(result_data["processing_end"], result_data["total_ops"], result_data["total_ops_unit"], result_data["request_meta_data"]) throughput = result_data["request_meta_data"].pop("throughput", None) - latency = (result_data["request_end"] - (total_start + self.expected_scheduled_time) - if result_data["throughput_throttled"] else service_time) + latency = result_data["request_end"] - (total_start + self.expected_scheduled_time) if result_data["throughput_throttled"] else service_time runner_completed = getattr(self.runner, "completed", False) runner_task_progress = getattr(self.runner, "task_progress", None) @@ -2358,23 +2486,34 @@ def _process_results(self, result_data: dict, total_start: float, client_state: if client_state: if add_profile_metric_sample: self.profile_sampler.add( - self.task, self.client_id, self.sample_type, + self.task, + self.client_id, + self.sample_type, result_data["request_meta_data"], result_data["absolute_processing_start"], result_data["request_start"], time_period, progress, - result_data["request_meta_data"].pop("dependent_timing", None)) + result_data["request_meta_data"].pop("dependent_timing", None), + ) else: self.sampler.add( - self.task, self.client_id, self.sample_type, + self.task, + self.client_id, + self.sample_type, result_data["request_meta_data"], result_data["absolute_processing_start"], result_data["request_start"], - latency, service_time, client_processing_time, processing_time, - throughput, result_data["total_ops"], result_data["total_ops_unit"], - time_period, progress, - result_data["request_meta_data"].pop("dependent_timing", None) + latency, + service_time, + client_processing_time, + processing_time, + throughput, + result_data["total_ops"], + result_data["total_ops_unit"], + time_period, + progress, + result_data["request_meta_data"].pop("dependent_timing", None), ) return completed @@ -2435,13 +2574,11 @@ async def __call__(self, *args, **kwargs): raise exceptions.BenchmarkError(f"Cannot run task [{self.task}]: {e}") from None finally: if self.task_completes_parent: - self.logger.info( - "Task [%s] completes parent. Client id [%s] is finished executing it and signals completion.", - self.task, self.client_id - ) + self.logger.info("Task [%s] completes parent. Client id [%s] is finished executing it and signals completion.", self.task, self.client_id) self.complete.set() await self._cleanup() + request_context_holder = client.RequestContextHolder() @@ -2487,10 +2624,7 @@ async def execute_single(runner, clients, params, on_error, redline_enabled=Fals total_ops = 0 total_ops_unit = "ops" - request_meta_data = { - "success": False, - "error-type": "transport" - } + request_meta_data = {"success": False, "error-type": "transport"} if isinstance(e.status_code, int): request_meta_data["http-status"] = e.status_code if isinstance(e, exceptions.BenchmarkConnectionTimeout): @@ -2514,7 +2648,7 @@ async def execute_single(runner, clients, params, on_error, redline_enabled=Fals if not redline_enabled: raise exceptions.BenchmarkAssertionError(msg) - if 'error-description' in request_meta_data: + if "error-description" in request_meta_data: try: error_metadata = json.loads(request_meta_data["error-description"]) # parse error-description metadata @@ -2532,10 +2666,7 @@ async def execute_single(runner, clients, params, on_error, redline_enabled=Fals request_context_holder.on_request_start() total_ops = 0 total_ops_unit = "ops" - request_meta_data = { - "success": True, - "skipped_request": True - } + request_meta_data = {"success": True, "skipped_request": True} request_context_holder.on_request_end() request_context_holder.on_client_request_end() return total_ops, total_ops_unit, request_meta_data @@ -2587,8 +2718,7 @@ def __eq__(self, other): return isinstance(other, type(self)) and self.task == other.task and self.global_client_index == other.global_client_index def __repr__(self, *args, **kwargs): - return f"TaskAllocation [{self.client_index_in_task}/{self.task.clients}] for {self.task} " \ - f"and [{self.global_client_index}/{self.total_clients}] in total" + return f"TaskAllocation [{self.client_index_in_task}/{self.task.clients}] for {self.task} and [{self.global_client_index}/{self.total_clients}] in total" class Allocator: @@ -2632,12 +2762,14 @@ def allocations(self): physical_client_index = client_index % max_clients if sub_task.completes_parent: clients_executing_completing_task.append(physical_client_index) - ta = TaskAllocation(task = sub_task, - client_index_in_task = client_index - start_client_index, - global_client_index=client_index, - # if task represents a parallel structure this is the total number of clients - # executing sub-tasks concurrently. - total_clients=task.clients) + ta = TaskAllocation( + task=sub_task, + client_index_in_task=client_index - start_client_index, + global_client_index=client_index, + # if task represents a parallel structure this is the total number of clients + # executing sub-tasks concurrently. + total_clients=task.clients, + ) allocations[physical_client_index].append(ta) start_client_index += sub_task.clients @@ -2746,15 +2878,17 @@ def schedule_for(task_allocation, parameter_source): warmup_time_period = task.warmup_time_period if task.warmup_time_period else 0 ramp_down_time_period = task.ramp_down_time_period if task.ramp_down_time_period else 0 if client_index == 0: - logger.info("Creating time-period based schedule with [%s] distribution for [%s] with a warmup period of [%s] " - "seconds and a time period of [%s] seconds.", task.schedule, task.name, - str(warmup_time_period), str(task.time_period)) - loop_control = TimePeriodBased(warmup_time_period, task.time_period, ramp_down_time_period, - client_index, task.clients) + logger.info( + "Creating time-period based schedule with [%s] distribution for [%s] with a warmup period of [%s] seconds and a time period of [%s] seconds.", + task.schedule, + task.name, + str(warmup_time_period), + str(task.time_period), + ) + loop_control = TimePeriodBased(warmup_time_period, task.time_period, ramp_down_time_period, client_index, task.clients) # Log individual client duration if ramp-down is enabled if ramp_down_time_period > 0 and client_index == 0: - logger.info("Ramp-down enabled: clients will stop in reverse order over [%s] seconds", - str(ramp_down_time_period)) + logger.info("Ramp-down enabled: clients will stop in reverse order over [%s] seconds", str(ramp_down_time_period)) else: warmup_iterations = task.warmup_iterations if task.warmup_iterations else 0 if task.iterations: @@ -2765,8 +2899,13 @@ def schedule_for(task_allocation, parameter_source): else: iterations = None if client_index == 0: - logger.info("Creating iteration-count based schedule with [%s] distribution for [%s] with [%s] warmup " - "iterations and [%s] iterations.", task.schedule, task.name, str(warmup_iterations), str(iterations)) + logger.info( + "Creating iteration-count based schedule with [%s] distribution for [%s] with [%s] warmup iterations and [%s] iterations.", + task.schedule, + task.name, + str(warmup_iterations), + str(iterations), + ) loop_control = IterationBased(warmup_iterations, iterations) if client_index == 0: @@ -2813,6 +2952,7 @@ def __init__(self, task_allocation, sched, task_progress_control, runner, params # import asyncio # self.io_pool_exc = ThreadPoolExecutor(max_workers=1) # self.loop = asyncio.get_event_loop() + @property def ramp_up_wait_time(self): """ @@ -2843,8 +2983,7 @@ async def __call__(self): # does not contribute at all to completion. Hence, we cannot define completion. task_progress = self.params.task_progress if param_source_knows_progress else None # current_params = await self.loop.run_in_executor(self.io_pool_exc, self.params.params) - yield (next_scheduled, self.task_progress_control.sample_type, task_progress, self.runner, - self.params.params()) + yield (next_scheduled, self.task_progress_control.sample_type, task_progress, self.runner, self.params.params()) self.task_progress_control.next() except StopIteration: return @@ -2852,20 +2991,15 @@ async def __call__(self): while not self.task_progress_control.completed: try: next_scheduled = self.sched.next(next_scheduled) - #current_params = await self.loop.run_in_executor(self.io_pool_exc, self.params.params) - yield (next_scheduled, - self.task_progress_control.sample_type, - self.task_progress_control.task_progress, - self.runner, - self.params.params()) + # current_params = await self.loop.run_in_executor(self.io_pool_exc, self.params.params) + yield (next_scheduled, self.task_progress_control.sample_type, self.task_progress_control.task_progress, self.runner, self.params.params()) self.task_progress_control.next() except StopIteration: return class TimePeriodBased: - def __init__(self, warmup_time_period, time_period, ramp_down_time_period=None, - client_index=None, total_clients=None): + def __init__(self, warmup_time_period, time_period, ramp_down_time_period=None, client_index=None, total_clients=None): self._warmup_time_period = warmup_time_period self._time_period = time_period self._ramp_down_time_period = ramp_down_time_period or 0 @@ -2882,8 +3016,14 @@ def __init__(self, warmup_time_period, time_period, ramp_down_time_period=None, reverse_index = (total_clients - 1) - client_index client_early_stop = self._ramp_down_time_period * (reverse_index / total_clients) self._duration = self._base_duration - client_early_stop - self.logger.info("Client [%d/%d] will run for %.2f seconds (base: %.2f, early stop: %.2f due to ramp-down)", - client_index, total_clients, self._duration, self._base_duration, client_early_stop) + self.logger.info( + "Client [%d/%d] will run for %.2f seconds (base: %.2f, early stop: %.2f due to ramp-down)", + client_index, + total_clients, + self._duration, + self._base_duration, + client_early_stop, + ) else: self._duration = self._base_duration else: @@ -2911,7 +3051,7 @@ def infinite(self): @property def task_progress(self): - return (self._elapsed / self._duration, '%') + return (self._elapsed / self._duration, "%") @property def completed(self): @@ -2949,7 +3089,7 @@ def infinite(self): @property def task_progress(self): - return ((self._it + 1) / self._total_iterations, '%') + return ((self._it + 1) / self._total_iterations, "%") @property def completed(self): diff --git a/solrorbit/workload/__init__.py b/solrorbit/workload/__init__.py index 93c815c4..5955b378 100644 --- a/solrorbit/workload/__init__.py +++ b/solrorbit/workload/__init__.py @@ -16,7 +16,7 @@ # not use this file except in compliance with the License. # You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an @@ -25,10 +25,7 @@ # specific language governing permissions and limitations # under the License. -from .loader import ( - list_workloads, workload_info, load_workload, load_workload_plugins, workload_repo, operation_parameters, set_absolute_data_path, - WorkloadProcessorRegistry -) +from .loader import list_workloads, workload_info, load_workload, load_workload_plugins, workload_repo, operation_parameters, set_absolute_data_path, WorkloadProcessorRegistry # expose the complete workload API from .workload import * diff --git a/solrorbit/workload/ingestion_manager.py b/solrorbit/workload/ingestion_manager.py index a0c56c06..7f159f7c 100644 --- a/solrorbit/workload/ingestion_manager.py +++ b/solrorbit/workload/ingestion_manager.py @@ -9,13 +9,14 @@ import os import multiprocessing + class IngestionManager: plimsoll = 4 * os.cpu_count() - ballast = plimsoll/2 - chunk_size = 50 # in MB + ballast = plimsoll / 2 + chunk_size = 50 # in MB lock = multiprocessing.Lock() - rd_index = multiprocessing.Value('i', 0) - wr_count = multiprocessing.Value('i', 0) - producer_started = multiprocessing.Value('i', 0) + rd_index = multiprocessing.Value("i", 0) + wr_count = multiprocessing.Value("i", 0) + producer_started = multiprocessing.Value("i", 0) load_full = multiprocessing.Condition() load_empty = multiprocessing.Condition() diff --git a/solrorbit/workload/loader.py b/solrorbit/workload/loader.py index b2326090..9ae303b2 100644 --- a/solrorbit/workload/loader.py +++ b/solrorbit/workload/loader.py @@ -16,7 +16,7 @@ # not use this file except in compliance with the License. # You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an @@ -125,9 +125,13 @@ def list_workloads(cfg): data = [] for t in available_workloads: - line = [t.name, t.description, convert.number_to_human_string(t.number_of_documents), - convert.bytes_to_human_string(t.compressed_size_in_bytes), - convert.bytes_to_human_string(t.uncompressed_size_in_bytes)] + line = [ + t.name, + t.description, + convert.number_to_human_string(t.number_of_documents), + convert.bytes_to_human_string(t.compressed_size_in_bytes), + convert.bytes_to_human_string(t.uncompressed_size_in_bytes), + ] if not only_auto_generated_test_procedures: line.append(t.default_test_procedure) line.append(",".join(map(str, t.test_procedures))) @@ -211,19 +215,13 @@ def _load_single_workload(cfg, workload_repository, workload_name): return current_workload except FileNotFoundError as e: logging.getLogger(__name__).exception("Cannot load workload [%s]", workload_name) - raise exceptions.SystemSetupError(f"Cannot load workload [{workload_name}]. " - f"List the available workloads with [{PROGRAM_NAME} list workloads].") from e + raise exceptions.SystemSetupError(f"Cannot load workload [{workload_name}]. List the available workloads with [{PROGRAM_NAME} list workloads].") from e except BaseException: logging.getLogger(__name__).exception("Cannot load workload [%s]", workload_name) raise -def load_workload_plugins(cfg, - workload_name, - register_runner=None, - register_scheduler=None, - register_workload_processor=None, - force_update=False): +def load_workload_plugins(cfg, workload_name, register_runner=None, register_scheduler=None, register_workload_processor=None, force_update=False): """ Loads plugins that are defined for the current workload (as specified by the configuration). @@ -422,8 +420,7 @@ def prepare_docs(cfg, workload, corpus, preparator): for document_set in corpus.documents: if document_set.is_supported_source_format: data_root = data_dir(cfg, workload.name, corpus.name) - logging.getLogger(__name__).info("Resolved data root directory for document corpus [%s] in workload [%s] " - "to [%s].", corpus.name, workload.name, data_root) + logging.getLogger(__name__).info("Resolved data root directory for document corpus [%s] in workload [%s] to [%s].", corpus.name, workload.name, data_root) if len(data_root) == 1: preparator.prepare_document_set(document_set, data_root[0]) # attempt to prepare everything in the current directory and fallback to the corpus directory @@ -433,12 +430,7 @@ def prepare_docs(cfg, workload, corpus, preparator): def on_prepare_workload(self, workload, data_root_dir): prep = DocumentSetPreparator(workload.name, self.downloader, self.decompressor) for corpus in used_corpora(workload): - params = { - "cfg": self.cfg, - "workload": workload, - "corpus": corpus, - "preparator": prep - } + params = {"cfg": self.cfg, "workload": workload, "corpus": corpus, "preparator": prep} yield DefaultWorkloadPreparator.prepare_docs, params @@ -448,8 +440,7 @@ def __init__(self): def decompress(self, archive_path, documents_path, uncompressed_size): if uncompressed_size: - msg = f"Decompressing workload data from [{archive_path}] to [{documents_path}] (resulting size: " \ - f"[{convert.bytes_to_gb(uncompressed_size):.2f}] GB) ... " + msg = f"Decompressing workload data from [{archive_path}] to [{documents_path}] (resulting size: [{convert.bytes_to_gb(uncompressed_size):.2f}] GB) ... " else: msg = f"Decompressing workload data from [{archive_path}] to [{documents_path}] ... " @@ -458,13 +449,12 @@ def decompress(self, archive_path, documents_path, uncompressed_size): console.println("[OK]") if not os.path.isfile(documents_path): raise exceptions.DataError( - f"Decompressing [{archive_path}] did not create [{documents_path}]. Please check with the workload " - f"author if the compressed archive has been created correctly.") + f"Decompressing [{archive_path}] did not create [{documents_path}]. Please check with the workload author if the compressed archive has been created correctly." + ) extracted_bytes = os.path.getsize(documents_path) if uncompressed_size is not None and extracted_bytes != uncompressed_size: - raise exceptions.DataError(f"[{documents_path}] is corrupt. Extracted [{extracted_bytes}] bytes " - f"but [{uncompressed_size}] bytes are expected.") + raise exceptions.DataError(f"[{documents_path}] is corrupt. Extracted [{extracted_bytes}] bytes but [{uncompressed_size}] bytes are expected.") class Downloader: @@ -499,15 +489,13 @@ def download(self, base_url, source_url, target_path, size_in_bytes): self.logger.info("Downloading data from [%s] to [%s].", data_url, target_path) # we want to have a bit more accurate download progress as these files are typically very large - progress = net.Progress("[INFO] Downloading workload data file: " + os.path.basename(target_path), - accuracy=1) + progress = net.Progress("[INFO] Downloading workload data file: " + os.path.basename(target_path), accuracy=1) net.download(data_url, target_path, size_in_bytes, progress_indicator=progress) progress.finish() self.logger.info("Downloaded data from [%s] to [%s].", data_url, target_path) except urllib.error.HTTPError as e: if e.code == 404 and self.test_mode: - raise exceptions.DataError("This workload does not support test mode. Ask the workload author to add it or" - " disable test mode and retry.") from None + raise exceptions.DataError("This workload does not support test mode. Ask the workload author to add it or disable test mode and retry.") from None else: msg = f"Could not download [{data_url}] to [{target_path}]" if e.reason: @@ -519,13 +507,11 @@ def download(self, base_url, source_url, target_path, size_in_bytes): raise exceptions.DataError(f"Could not download [{data_url}] to [{target_path}].") from e if not os.path.isfile(target_path): - raise exceptions.SystemSetupError(f"Could not download [{data_url}] to [{target_path}]. Verify data " - f"are available at [{data_url}] and check your Internet connection.") + raise exceptions.SystemSetupError(f"Could not download [{data_url}] to [{target_path}]. Verify data are available at [{data_url}] and check your Internet connection.") actual_size = os.path.getsize(target_path) if size_in_bytes is not None and actual_size != size_in_bytes: - raise exceptions.DataError(f"[{target_path}] is corrupt. Downloaded [{actual_size}] bytes " - f"but [{size_in_bytes}] bytes are expected.") + raise exceptions.DataError(f"[{target_path}] is corrupt. Downloaded [{actual_size}] bytes but [{size_in_bytes}] bytes are expected.") class DocumentSetPreparator: @@ -546,8 +532,9 @@ def create_file_offset_table(self, document_file_path, base_url, source_url, exp lines_read = io.prepare_file_offset_table(document_file_path, base_url, source_url, self.downloader) if lines_read and lines_read != expected_number_of_lines: io.remove_file_offset_table(document_file_path) - raise exceptions.DataError(f"Data in [{document_file_path}] for workload [{self.workload_name}] are invalid. " - f"Expected [{expected_number_of_lines}] lines but got [{lines_read}].") + raise exceptions.DataError( + f"Data in [{document_file_path}] for workload [{self.workload_name}] are invalid. Expected [{expected_number_of_lines}] lines but got [{lines_read}]." + ) def prepare_document_set(self, document_set, data_root): """ @@ -568,12 +555,9 @@ def prepare_document_set(self, document_set, data_root): doc_path = os.path.join(data_root, document_set.document_file) archive_path = os.path.join(data_root, document_set.document_archive) if document_set.has_compressed_corpus() else None while True: - if self.is_locally_available(doc_path) and \ - self.has_expected_size(doc_path, document_set.uncompressed_size_in_bytes): + if self.is_locally_available(doc_path) and self.has_expected_size(doc_path, document_set.uncompressed_size_in_bytes): break - if document_set.has_compressed_corpus() and \ - self.is_locally_available(archive_path) and \ - self.has_expected_size(archive_path, document_set.compressed_size_in_bytes): + if document_set.has_compressed_corpus() and self.is_locally_available(archive_path) and self.has_expected_size(archive_path, document_set.compressed_size_in_bytes): self.decompressor.decompress(archive_path, doc_path, document_set.uncompressed_size_in_bytes) else: if document_set.has_compressed_corpus(): @@ -592,8 +576,11 @@ def prepare_document_set(self, document_set, data_root): self.downloader.download(document_set.base_url, None, os.path.join(data_root, part["name"]), part["size"]) try: with open(target_path, "wb") as outfile: - console.info(f"Concatenating file parts {', '.join([p['name'] for p in document_set.document_file_parts])}" - f" into {os.path.basename(target_path)}", flush=True, logger=self.logger) + console.info( + f"Concatenating file parts {', '.join([p['name'] for p in document_set.document_file_parts])} into {os.path.basename(target_path)}", + flush=True, + logger=self.logger, + ) for part in document_set.document_file_parts: part_name = os.path.join(data_root, part["name"]) with open(part_name, "rb") as infile: @@ -604,11 +591,12 @@ def prepare_document_set(self, document_set, data_root): else: self.downloader.download(document_set.base_url, document_set.source_url, target_path, expected_size) except exceptions.DataError as e: - if e.message == "Cannot download data because no base URL is provided." and \ - self.is_locally_available(target_path): - raise exceptions.DataError(f"[{target_path}] is present but does not have the expected " - f"size of [{expected_size}] bytes and it cannot be downloaded " - f"because no base URL is provided.") from None + if e.message == "Cannot download data because no base URL is provided." and self.is_locally_available(target_path): + raise exceptions.DataError( + f"[{target_path}] is present but does not have the expected " + f"size of [{expected_size}] bytes and it cannot be downloaded " + f"because no base URL is provided." + ) from None else: raise if document_set.support_file_offset_table: @@ -642,8 +630,7 @@ def prepare_bundled_document_set(self, document_set, data_root): self.create_file_offset_table(doc_path, document_set.base_url, document_set.source_url, document_set.number_of_lines) return True else: - raise exceptions.DataError(f"[{doc_path}] is present but does not have the expected size " - f"of [{document_set.uncompressed_size_in_bytes}] bytes.") + raise exceptions.DataError(f"[{doc_path}] is present but does not have the expected size of [{document_set.uncompressed_size_in_bytes}] bytes.") if document_set.has_compressed_corpus() and self.is_locally_available(archive_path): if self.has_expected_size(archive_path, document_set.compressed_size_in_bytes): @@ -652,8 +639,7 @@ def prepare_bundled_document_set(self, document_set, data_root): # treat this is an error because if the file is present but the size does not match, something is # really fishy. It is likely that the user is currently creating a new workload and did not specify # the file size correctly. - raise exceptions.DataError(f"[{archive_path}] is present but does not have " - f"the expected size of [{document_set.compressed_size_in_bytes}] bytes.") + raise exceptions.DataError(f"[{archive_path}] is present but does not have the expected size of [{document_set.compressed_size_in_bytes}] bytes.") else: return False @@ -678,9 +664,7 @@ def __init__(self, base_path, template_file_name, source=io.FileSource, fileglob def load_template_from_file(self): loader = jinja2.FileSystemLoader(self.base_path) try: - base_workload = loader.get_source(jinja2.Environment( - autoescape=select_autoescape(['html', 'xml'])), - self.template_file_name) + base_workload = loader.get_source(jinja2.Environment(autoescape=select_autoescape(["html", "xml"])), self.template_file_name) except jinja2.TemplateNotFound: self.logger.exception("Could not load workload from [%s].", self.template_file_name) raise WorkloadSyntaxError("Could not load workload from '{}'".format(self.template_file_name)) @@ -718,8 +702,7 @@ def read_glob_files(self, pattern): # A Jinja filter that tests if a version string lies within a specified range. # For instance, "1.2.3" lies between "1.0.0" and "2.0.0". def version_between(version, frm, to): - return list(map(int, version.split('.'))) >= list(map(int, frm.split('.'))) and \ - list(map(int, version.split('.'))) <= list(map(int, to.split('.'))) + return list(map(int, version.split("."))) >= list(map(int, frm.split("."))) and list(map(int, version.split("."))) <= list(map(int, to.split("."))) def default_internal_template_vars(glob_helper=lambda f: [], clock=time.Clock): @@ -727,15 +710,7 @@ def default_internal_template_vars(glob_helper=lambda f: [], clock=time.Clock): Dict of internal global variables used by our jinja2 renderers """ - return { - "globals": { - "now": clock.now(), - "glob": glob_helper - }, - "filters": { - "days_ago": time.days_ago - } - } + return {"globals": {"now": clock.now(), "glob": glob_helper}, "filters": {"days_ago": time.days_ago}} def render_template(template_source, template_vars=None, template_internal_vars=None, loader=None): @@ -760,17 +735,12 @@ def render_template(template_source, template_vars=None, template_internal_vars= {% endif %} {% endif %} {%- endmacro %} - """ + """, ] # place helpers dict loader first to prevent users from overriding our macros. env = jinja2.Environment( - loader=jinja2.ChoiceLoader([ - jinja2.DictLoader({"benchmark.helpers": "".join(macros)}), - jinja2.BaseLoader(), - loader - ]), - autoescape=select_autoescape(['html', 'xml']) + loader=jinja2.ChoiceLoader([jinja2.DictLoader({"benchmark.helpers": "".join(macros)}), jinja2.BaseLoader(), loader]), autoescape=select_autoescape(["html", "xml"]) ) if template_vars: @@ -788,7 +758,7 @@ def render_template(template_source, template_vars=None, template_internal_vars= def register_all_params_in_workload(assembled_source, complete_workload_params=None): - j2env = jinja2.Environment(autoescape=select_autoescape(['html', 'xml'])) + j2env = jinja2.Environment(autoescape=select_autoescape(["html", "xml"])) # we don't need the following j2 filters/macros but we define them anyway to prevent parsing failures internal_template_vars = default_internal_template_vars() @@ -815,10 +785,12 @@ def relative_glob(start, f): template_source.load_template_from_file() register_all_params_in_workload(template_source.assembled_source, complete_workload_params) - return render_template(loader=jinja2.FileSystemLoader(base_path), - template_source=template_source.assembled_source, - template_vars=template_vars, - template_internal_vars=default_internal_template_vars(glob_helper=lambda f: relative_glob(base_path, f))) + return render_template( + loader=jinja2.FileSystemLoader(base_path), + template_source=template_source.assembled_source, + template_vars=template_vars, + template_internal_vars=default_internal_template_vars(glob_helper=lambda f: relative_glob(base_path, f)), + ) class TaskFilterWorkloadProcessor(WorkloadProcessor): @@ -848,8 +820,7 @@ def _filters_from_filtered_tasks(self, filtered_tasks): elif spec[0] == "tag": filters.append(workload.TaskTagFilter(spec[1])) else: - raise exceptions.SystemSetupError(f"Invalid format for filtered tasks: [{t}]. " - f"Expected [type] but got [{spec[0]}].") + raise exceptions.SystemSetupError(f"Invalid format for filtered tasks: [{t}]. Expected [type] but got [{spec[0]}].") else: raise exceptions.SystemSetupError(f"Invalid format for filtered tasks: [{t}]") return filters @@ -878,8 +849,7 @@ def on_after_load_workload(self, input_workload, **kwargs): if self._filter_out_match(leaf_task): leafs_to_remove.append(leaf_task) for leaf_task in leafs_to_remove: - self.logger.info("Removing sub-task [%s] from test_procedure [%s] due to task filter.", - leaf_task, test_procedure) + self.logger.info("Removing sub-task [%s] from test_procedure [%s] due to task filter.", leaf_task, test_procedure) task.remove_task(leaf_task) for task in tasks_to_remove: self.logger.info("Removing task [%s] from test_procedure [%s] due to task filter.", task, test_procedure) @@ -915,8 +885,7 @@ def on_after_load_workload(self, input_workload, **kwargs): path, ext = io.splitext(document_set.document_file) document_set.document_file = f"{path}-1k{ext}" else: - raise exceptions.BenchmarkAssertionError(f"Document corpus [{corpus.name}] has neither compressed " - f"nor uncompressed corpus.") + raise exceptions.BenchmarkAssertionError(f"Document corpus [{corpus.name}] has neither compressed nor uncompressed corpus.") # we don't want to check sizes document_set.compressed_size_in_bytes = None @@ -941,13 +910,11 @@ def on_after_load_workload(self, input_workload, **kwargs): if leaf_task.warmup_time_period is not None and leaf_task.warmup_time_period > 0: leaf_task.warmup_time_period = 0 if self.logger.isEnabledFor(logging.DEBUG): - self.logger.debug("Resetting warmup time period for [%s] to [%d] seconds.", - str(leaf_task), leaf_task.warmup_time_period) + self.logger.debug("Resetting warmup time period for [%s] to [%d] seconds.", str(leaf_task), leaf_task.warmup_time_period) if leaf_task.time_period is not None and leaf_task.time_period > 10: leaf_task.time_period = 10 if self.logger.isEnabledFor(logging.DEBUG): - self.logger.debug("Resetting measurement time period for [%s] to [%d] seconds.", - str(leaf_task), leaf_task.time_period) + self.logger.debug("Resetting measurement time period for [%s] to [%d] seconds.", str(leaf_task), leaf_task.time_period) # Keep throttled to expose any errors but increase the target throughput for short execution times. if leaf_task.target_throughput: @@ -958,8 +925,8 @@ def on_after_load_workload(self, input_workload, **kwargs): return input_workload -class QueryRandomizerWorkloadProcessor(WorkloadProcessor): +class QueryRandomizerWorkloadProcessor(WorkloadProcessor): class QueryRandomizationInfo: # A class containing information about which values to replace when randomizing queries. # For example, QueryRandomizationInfo("range", [["gte", "gt"], ["lte", "lt"]], ["format"]) @@ -982,13 +949,11 @@ def validate_parameter_name_options_list(self, query_name, parameter_name_option for parameter_name_options in parameter_name_options_list: for parameter_name_option in parameter_name_options: if parameter_name_option == query_name: - raise exceptions.ExecutorError( - f"Cannot have a randomized value name {query_name} which is the same as the name of its query!") + raise exceptions.ExecutorError(f"Cannot have a randomized value name {query_name} which is the same as the name of its query!") all_values.append(parameter_name_option) distinct_values.add(parameter_name_option) if len(all_values) != len(distinct_values): - raise exceptions.ExecutorError( - f"Duplicate option for value name in query_randomization_info: {parameter_name_options_list}") + raise exceptions.ExecutorError(f"Duplicate option for value name in query_randomization_info: {parameter_name_options_list}") def check_one_of_each_name_present(self, obj): # Return true if one version of the value name is present in obj for each set of value options. @@ -1008,6 +973,7 @@ def check_one_of_each_name_present(self, obj): DEFAULT_N = 5000 DEFAULT_ALPHA = 1 DEFAULT_QUERY_RANDOMIZATION_INFO = QueryRandomizationInfo("range", [["gte", "gt"], ["lte", "lt"]], ["format"]) + def __init__(self, cfg): self.randomization_enabled = cfg.opts("workload", "randomization.enabled", mandatory=False, default_value=False) self.rf = float(cfg.opts("workload", "randomization.repeat_frequency", mandatory=False, default_value=self.DEFAULT_RF)) @@ -1019,12 +985,12 @@ def __init__(self, cfg): # Helper functions for computing Zipf distribution def H(self, i, H_list): # compute the harmonic number H_n,m = sum over i from 1 to n of (1 / i^m) - return H_list[i-1] + return H_list[i - 1] def precompute_H(self, n, m): H_list = [1] - for j in range(2, n+1): - H_list.append(H_list[-1] + 1 / (j ** m)) + for j in range(2, n + 1): + H_list.append(H_list[-1] + 1 / (j**m)) return H_list def zipf_cdf_inverse(self, u, H_list): @@ -1032,9 +998,8 @@ def zipf_cdf_inverse(self, u, H_list): # as the zipf cdf is discontinuous there is no real inverse but we can use this solution: # https://math.stackexchange.com/questions/53671/how-to-calculate-the-inverse-cdf-for-the-zipf-distribution # Precompute all values H_i,alpha for a fixed alpha and pass in as H_list - if (u < 0 or u >= 1): - raise exceptions.ExecutorError( - "Input u must have 0 <= u < 1. This error shouldn't appear, please raise an issue if it does") + if u < 0 or u >= 1: + raise exceptions.ExecutorError("Input u must have 0 <= u < 1. This error shouldn't appear, please raise an issue if it does") n = len(H_list) candidate_return = 1 denominator = self.H(n, H_list) @@ -1055,7 +1020,7 @@ def get_dict_from_previous_path(self, root, current_path): def extract_fields_helper(self, root, current_path, query_randomization_info): # Recursively called to find the location of ranges in a range query. # Return the field and the current path if we're currently scanning the field name in a range query, otherwise return an empty list. - fields = [] # pairs of (field, path_to_field) + fields = [] # pairs of (field, path_to_field) curr = self.get_dict_from_previous_path(root, current_path) if isinstance(curr, dict) and curr != {}: if len(current_path) > 0 and current_path[-1] == query_randomization_info.query_name: @@ -1085,8 +1050,9 @@ def extract_fields_and_paths(self, params, query_randomization_info): root = params["body"]["query"] except KeyError: raise exceptions.SystemSetupError( - f"Cannot extract range query fields from these params: {params}\n, missing params[\"body\"][\"query\"]\n" - f"Make sure the operation in operations/default.json is well-formed") + f'Cannot extract range query fields from these params: {params}\n, missing params["body"]["query"]\n' + f"Make sure the operation in operations/default.json is well-formed" + ) fields_and_paths = self.extract_fields_helper(root, [], query_randomization_info) return fields_and_paths @@ -1110,10 +1076,15 @@ def get_repeated_value_index(self): # minus 1 for mapping [1, N] to [0, N-1] of list indices return self.zipf_cdf_inverse(random.random(), self.H_list) - 1 - def get_randomized_values(self, input_workload, input_params, query_randomization_info, - get_standard_value=params.get_standard_value, - get_standard_value_source=params.get_standard_value_source, # Made these configurable for simpler unit tests - **kwargs): + def get_randomized_values( + self, + input_workload, + input_params, + query_randomization_info, + get_standard_value=params.get_standard_value, + get_standard_value_source=params.get_standard_value_source, # Made these configurable for simpler unit tests + **kwargs, + ): # The queries as listed in operations/default.json don't have the index param, # unlike the custom ones you would specify in workload.py, so we have to add them ourselves if "index" not in input_params: @@ -1134,16 +1105,21 @@ def get_randomized_values(self, input_workload, input_params, query_randomizatio return input_params def create_param_source_lambda(self, op_name, get_standard_value, get_standard_value_source, get_query_randomization_info): - return lambda w, p, **kwargs: self.get_randomized_values(w, p, query_randomization_info=get_query_randomization_info(op_name), - get_standard_value=get_standard_value, - get_standard_value_source=get_standard_value_source, - op_name=op_name, **kwargs) + return lambda w, p, **kwargs: self.get_randomized_values( + w, + p, + query_randomization_info=get_query_randomization_info(op_name), + get_standard_value=get_standard_value, + get_standard_value_source=get_standard_value_source, + op_name=op_name, + **kwargs, + ) def on_after_load_workload(self, input_workload, **kwargs): if not self.randomization_enabled: self.logger.info("Query randomization is disabled.") return input_workload - self.logger.info("Query randomization is enabled, with repeat frequency = %d, n = %d",self.rf, self.N) + self.logger.info("Query randomization is enabled, with repeat frequency = %d, n = %d", self.rf, self.N) # By default, use params for standard values and generate new standard values the first time an op/field is seen. # In unit tests, we should be able to supply our own sources independent of params. @@ -1170,23 +1146,29 @@ def on_after_load_workload(self, input_workload, **kwargs): op_type = None self.logger.info( "Found operation %s in default schedule with type %s, which couldn't be converted to a known OperationType", - leaf_task.operation.name, leaf_task.operation.type) + leaf_task.operation.name, + leaf_task.operation.type, + ) if op_type == workload.OperationType.Search: op_name = leaf_task.operation.name param_source_name = op_name + "-randomized" params.register_param_source_for_name( param_source_name, - self.create_param_source_lambda(op_name, get_standard_value=kwargs["get_standard_value"], - get_standard_value_source=kwargs["get_standard_value_source"], - get_query_randomization_info=params.get_query_randomization_info)) + self.create_param_source_lambda( + op_name, + get_standard_value=kwargs["get_standard_value"], + get_standard_value_source=kwargs["get_standard_value_source"], + get_query_randomization_info=params.get_query_randomization_info, + ), + ) leaf_task.operation.param_source = param_source_name # Generate the right number of standard values for this field, if not already present - for field_and_path in self.extract_fields_and_paths(leaf_task.operation.params, - params.get_query_randomization_info(op_name)): + for field_and_path in self.extract_fields_and_paths(leaf_task.operation.params, params.get_query_randomization_info(op_name)): if generate_new_standard_values: params.generate_standard_values_if_absent(op_name, field_and_path[0], self.N) return input_workload + class CompleteWorkloadParams: def __init__(self, user_specified_workload_params=None): self.workload_defined_params = set() @@ -1231,7 +1213,7 @@ def __init__(self, cfg): self.read_workload = WorkloadSpecificationReader( workload_params=self.workload_params, complete_workload_params=self.complete_workload_params, - selected_test_procedure=cfg.opts("workload", "test_procedure.name", mandatory=False) + selected_test_procedure=cfg.opts("workload", "test_procedure.name", mandatory=False), ) self.logger = logging.getLogger(__name__) @@ -1250,9 +1232,7 @@ def read(self, workload_name, workload_spec_file, mapping_dir): # involving lines numbers and it also does not bloat Solr Orbit's log file so much. tmp = tempfile.NamedTemporaryFile(delete=False, suffix=".json") try: - rendered = render_template_from_file( - workload_spec_file, self.workload_params, - complete_workload_params=self.complete_workload_params) + rendered = render_template_from_file(workload_spec_file, self.workload_params, complete_workload_params=self.complete_workload_params) with open(tmp.name, "wt", encoding="utf-8") as f: f.write(rendered) self.logger.info("Final rendered workload for '%s' has been written to '%s'.", workload_spec_file, tmp.name) @@ -1264,17 +1244,15 @@ def read(self, workload_name, workload_spec_file, mapping_dir): except jinja2.exceptions.TemplateSyntaxError as e: exception_message = f"Jinja2 Exception TemplateSyntaxError: {e}\n" - if 'endif' in exception_message: - exception_message = exception_message + \ - "There is an extra Jinja2 \"endif\" somewhere in the workload's files. " + \ - "Please remove it so that the workload can be rendered and run.\n" - if 'Missing end of raw directive' in exception_message: - exception_message += \ - "In the workload files, \"{% raw -%}\" was provided but is missing it's associated \"{% endraw -%}\" tag.\n" + if "endif" in exception_message: + exception_message = ( + exception_message + 'There is an extra Jinja2 "endif" somewhere in the workload\'s files. ' + "Please remove it so that the workload can be rendered and run.\n" + ) + if "Missing end of raw directive" in exception_message: + exception_message += 'In the workload files, "{% raw -%}" was provided but is missing it\'s associated "{% endraw -%}" tag.\n' raise exceptions.SystemSetupError(exception_message) - except json.JSONDecodeError as e: self.logger.exception("Could not load [%s].", workload_spec_file) msg = "Could not load '{}': {}.".format(workload_spec_file, str(e)) @@ -1288,20 +1266,23 @@ def read(self, workload_name, workload_spec_file, mapping_dir): erroneous_lines.insert(line_idx - ctx_start + 1, "-" * (e.colno - 1) + "^ Error is here") msg += " Lines containing the error:\n\n{}\n\n".format("\n".join(erroneous_lines)) msg += "The complete workload has been written to '{}' for diagnosis. \n\n".format(tmp.name) - console_message = f"Suggestion: Verify that [{workload_name}] workload has correctly formatted JSON files and " + \ - "Jinja Templates. For Jinja2 errors, consider using a live Jinja2 parser. " + \ - f"See common workload formatting errors:{WorkloadFileReader.COMMON_WORKLOAD_FORMAT_ERRORS}" + console_message = ( + f"Suggestion: Verify that [{workload_name}] workload has correctly formatted JSON files and " + + "Jinja Templates. For Jinja2 errors, consider using a live Jinja2 parser. " + + f"See common workload formatting errors:{WorkloadFileReader.COMMON_WORKLOAD_FORMAT_ERRORS}" + ) msg += console_message raise WorkloadSyntaxError(msg) except Exception as e: # TypeErrors get logged here self.logger.exception("Could not load [%s].", workload_spec_file) - msg = "Could not load '{}'. The complete workload has been written to '{}' for diagnosis. \n\n".format( - workload_spec_file, tmp.name) - console_message = f"Suggestion: Verify that [{workload_name}] workload has correctly formatted JSON files and " + \ - "Jinja Templates. For Jinja2 errors, consider using a live Jinja2 parser. " + \ - f"See common workload formatting errors:{WorkloadFileReader.COMMON_WORKLOAD_FORMAT_ERRORS}" + msg = "Could not load '{}'. The complete workload has been written to '{}' for diagnosis. \n\n".format(workload_spec_file, tmp.name) + console_message = ( + f"Suggestion: Verify that [{workload_name}] workload has correctly formatted JSON files and " + + "Jinja Templates. For Jinja2 errors, consider using a live Jinja2 parser. " + + f"See common workload formatting errors:{WorkloadFileReader.COMMON_WORKLOAD_FORMAT_ERRORS}" + ) msg += console_message # Convert to string early on to avoid serialization errors with Jinja exceptions. raise WorkloadSyntaxError(msg, str(e)) @@ -1310,27 +1291,27 @@ def read(self, workload_name, workload_spec_file, mapping_dir): try: workload_version = int(raw_version) except ValueError: - raise exceptions.InvalidSyntax("version identifier for workload %s must be numeric but was [%s]" % ( - workload_name, str(raw_version))) + raise exceptions.InvalidSyntax("version identifier for workload %s must be numeric but was [%s]" % (workload_name, str(raw_version))) if WorkloadFileReader.MINIMUM_SUPPORTED_TRACK_VERSION > workload_version: - raise exceptions.BenchmarkError("Workload {} is on version {} but needs to be updated at least to version {} to work with the " - "current version of Solr Orbit.".format(workload_name, workload_version, - WorkloadFileReader.MINIMUM_SUPPORTED_TRACK_VERSION)) + raise exceptions.BenchmarkError( + "Workload {} is on version {} but needs to be updated at least to version {} to work with the current version of Solr Orbit.".format( + workload_name, workload_version, WorkloadFileReader.MINIMUM_SUPPORTED_TRACK_VERSION + ) + ) if WorkloadFileReader.MAXIMUM_SUPPORTED_TRACK_VERSION < workload_version: - raise exceptions.BenchmarkError("Workload {} requires a newer version of Solr Orbit. " - "Please upgrade Solr Orbit (supported workload version: {}, " - "required workload version: {}).".format( - workload_name, - WorkloadFileReader.MAXIMUM_SUPPORTED_TRACK_VERSION, - workload_version)) + raise exceptions.BenchmarkError( + "Workload {} requires a newer version of Solr Orbit. Please upgrade Solr Orbit (supported workload version: {}, required workload version: {}).".format( + workload_name, WorkloadFileReader.MAXIMUM_SUPPORTED_TRACK_VERSION, workload_version + ) + ) try: jsonschema.validate(workload_spec, self.workload_schema) except jsonschema.exceptions.ValidationError as ve: raise WorkloadSyntaxError( "Workload '{}' is invalid.\n\nError details: {}\nInstance: {}\nPath: {}\nSchema path: {}".format( - workload_name, ve.message, json.dumps( - ve.instance, indent=4, sort_keys=True), - ve.absolute_path, ve.absolute_schema_path)) + workload_name, ve.message, json.dumps(ve.instance, indent=4, sort_keys=True), ve.absolute_path, ve.absolute_schema_path + ) + ) try: current_workload = self.read_workload(workload_name, workload_spec, mapping_dir) @@ -1347,19 +1328,20 @@ def read(self, workload_name, workload_spec_file, mapping_dir): "All parameters exposed by this workload:\n" "{}".format( ",".join(opts.double_quoted_list_of(sorted(unused_user_defined_workload_params))), - ",".join(opts.double_quoted_list_of(sorted(opts.make_list_of_close_matches( - unused_user_defined_workload_params, - self.complete_workload_params.workload_defined_params - )))), + ",".join( + opts.double_quoted_list_of( + sorted(opts.make_list_of_close_matches(unused_user_defined_workload_params, self.complete_workload_params.workload_defined_params)) + ) + ), "\n".join(opts.bulleted_list_of(sorted(list(self.workload_params.keys())))), - "\n".join(opts.bulleted_list_of(self.complete_workload_params.sorted_workload_defined_params)))) + "\n".join(opts.bulleted_list_of(self.complete_workload_params.sorted_workload_defined_params)), + ) + ) self.logger.critical(err_msg) # also dump the message on the console console.println(err_msg) - raise exceptions.WorkloadConfigError( - "Unused workload parameters {}.".format(sorted(unused_user_defined_workload_params)) - ) + raise exceptions.WorkloadConfigError("Unused workload parameters {}.".format(sorted(unused_user_defined_workload_params))) return current_workload @@ -1411,10 +1393,7 @@ def register_query_randomization_info(self, op_name, query_name, parameter_name_ @property def meta_data(self): - return { - "benchmark_version": version.release_version(), - "async_runner": True - } + return {"benchmark_version": version.release_version(), "async_runner": True} class WorkloadSpecificationReader: @@ -1435,15 +1414,11 @@ def __call__(self, workload_name, workload_specification, mapping_dir): description = self._r(workload_specification, "description", mandatory=False, default_value="") meta_data = self._r(workload_specification, "meta", mandatory=False) - collections = [self._create_collection(col, mapping_dir) - for col in self._r(workload_specification, "collections", mandatory=False, default_value=[])] - corpora = self._create_corpora(self._r(workload_specification, "corpora", mandatory=False, default_value=[]), - collections=collections) + collections = [self._create_collection(col, mapping_dir) for col in self._r(workload_specification, "collections", mandatory=False, default_value=[])] + corpora = self._create_corpora(self._r(workload_specification, "corpora", mandatory=False, default_value=[]), collections=collections) test_procedures = self._create_test_procedures(workload_specification) # at this point, *all* workload params must have been referenced in the templates - return workload.Workload(name=self.name, meta_data=meta_data, - description=description, test_procedures=test_procedures, - corpora=corpora, collections=collections) + return workload.Workload(name=self.name, meta_data=meta_data, description=description, test_procedures=test_procedures, corpora=corpora, collections=collections) def _error(self, msg): raise WorkloadSyntaxError("Workload '%s' is invalid. %s" % (self.name, msg)) @@ -1492,8 +1467,7 @@ def _load_template(self, contents, description): self.logger.info("Loading template [%s].", description) register_all_params_in_workload(contents, self.complete_workload_params) try: - rendered = render_template(template_source=contents, - template_vars=self.workload_params) + rendered = render_template(template_source=contents, template_vars=self.workload_params) return json.loads(rendered) except Exception as e: self.logger.exception("Could not load file template for %s.", description) @@ -1511,15 +1485,12 @@ def _create_corpora(self, corpora_specs, collections=None): known_corpora_names.add(name) meta_data = self._r(corpus_spec, "meta", error_ctx=name, mandatory=False) - streaming_ingestion = self._r(corpus_spec, "streaming-ingestion", mandatory=False, - default_value="") + streaming_ingestion = self._r(corpus_spec, "streaming-ingestion", mandatory=False, default_value="") corpus = workload.DocumentCorpus(name=name, streaming_ingestion=streaming_ingestion, meta_data=meta_data) # defaults on corpus level default_base_url = self._r(corpus_spec, "base-url", mandatory=False, default_value=None) - default_source_format = self._r(corpus_spec, "source-format", mandatory=False, - default_value=workload.Documents.SOURCE_FORMAT_BULK) - default_action_and_meta_data = self._r(corpus_spec, "includes-action-and-meta-data", mandatory=False, - default_value=False) + default_source_format = self._r(corpus_spec, "source-format", mandatory=False, default_value=workload.Documents.SOURCE_FORMAT_BULK) + default_action_and_meta_data = self._r(corpus_spec, "includes-action-and-meta-data", mandatory=False, default_value=False) corpus_target_idx = None if len(collections) == 1: @@ -1548,30 +1519,31 @@ def _create_corpora(self, corpora_specs, collections=None): uncompressed_bytes = self._r(doc_spec, "uncompressed-bytes", mandatory=False) doc_meta_data = self._r(doc_spec, "meta", error_ctx=name, mandatory=False) - includes_action_and_meta_data = self._r(doc_spec, "includes-action-and-meta-data", mandatory=False, - default_value=default_action_and_meta_data) + includes_action_and_meta_data = self._r(doc_spec, "includes-action-and-meta-data", mandatory=False, default_value=default_action_and_meta_data) if includes_action_and_meta_data: target_idx = None target_type = None else: target_type = None - target_idx = self._r(doc_spec, "target-collection", - mandatory=len(collections) > 0 and corpus_target_idx is None, - default_value=corpus_target_idx, - error_ctx=docs) - - docs = workload.Documents(source_format=source_format, - document_file=document_file, - document_file_parts=document_file_parts, - document_archive=document_archive, - base_url=base_url, - source_url=source_url, - includes_action_and_meta_data=includes_action_and_meta_data, - number_of_documents=num_docs, - compressed_size_in_bytes=compressed_bytes, - uncompressed_size_in_bytes=uncompressed_bytes, - target_collection=target_idx, target_type=target_type, - meta_data=doc_meta_data) + target_idx = self._r( + doc_spec, "target-collection", mandatory=len(collections) > 0 and corpus_target_idx is None, default_value=corpus_target_idx, error_ctx=docs + ) + + docs = workload.Documents( + source_format=source_format, + document_file=document_file, + document_file_parts=document_file_parts, + document_archive=document_archive, + base_url=base_url, + source_url=source_url, + includes_action_and_meta_data=includes_action_and_meta_data, + number_of_documents=num_docs, + compressed_size_in_bytes=compressed_bytes, + uncompressed_size_in_bytes=uncompressed_bytes, + target_collection=target_idx, + target_type=target_type, + meta_data=doc_meta_data, + ) corpus.documents.append(docs) else: self._error("Unknown source-format [%s] in document corpus [%s]." % (source_format, name)) @@ -1596,8 +1568,7 @@ def _create_test_procedures(self, workload_spec): default = number_of_test_procedures == 1 or self._r(test_procedure_spec, "default", error_ctx=name, mandatory=False) selected = number_of_test_procedures == 1 or self.selected_test_procedure == name if default and default_test_procedure is not None: - self._error("Both '%s' and '%s' are defined as default test_procedures. Please define only one of them as default." - % (default_test_procedure.name, name)) + self._error("Both '%s' and '%s' are defined as default test_procedures. Please define only one of them as default." % (default_test_procedure.name, name)) if name in known_test_procedure_names: self._error("Duplicate test_procedure with name '%s'." % name) known_test_procedure_names.add(name) @@ -1606,8 +1577,9 @@ def _create_test_procedures(self, workload_spec): for op in self._r(test_procedure_spec, "schedule", error_ctx=name): if "clients_list" in op: - self.logger.info("Clients list specified: %s. Running multiple search tasks, "\ - "each scheduled with the corresponding number of clients from the list.", op["clients_list"]) + self.logger.info( + "Clients list specified: %s. Running multiple search tasks, each scheduled with the corresponding number of clients from the list.", op["clients_list"] + ) for num_clients in op["clients_list"]: op["clients"] = num_clients @@ -1630,23 +1602,27 @@ def _create_test_procedures(self, workload_spec): for task in schedule: for sub_task in task: if sub_task.name in known_task_names: - self._error("TestProcedure '%s' contains multiple tasks with the name '%s'. Please use the task's name property to " - "assign a unique name for each task." % (name, sub_task.name)) + self._error( + "TestProcedure '%s' contains multiple tasks with the name '%s'. Please use the task's name property to " + "assign a unique name for each task." % (name, sub_task.name) + ) else: known_task_names.add(sub_task.name) # merge params final_test_procedure_params = dict(collections.merge_dicts(workload_params, test_procedure_params)) - test_procedure = workload.TestProcedure(name=name, - parameters=final_test_procedure_params, - meta_data=meta_data, - description=description, - user_info=user_info, - default=default, - selected=selected, - auto_generated=auto_generated, - schedule=schedule) + test_procedure = workload.TestProcedure( + name=name, + parameters=final_test_procedure_params, + meta_data=meta_data, + description=description, + user_info=user_info, + default=default, + selected=selected, + auto_generated=auto_generated, + schedule=schedule, + ) if default: default_test_procedure = test_procedure @@ -1654,16 +1630,16 @@ def _create_test_procedures(self, workload_spec): if test_procedures and default_test_procedure is None: self._error( - "No default test_procedure specified. Please edit the workload and add \"default\": true to one of the test_procedures %s." - % ", ".join([c.name for c in test_procedures])) + 'No default test_procedure specified. Please edit the workload and add "default": true to one of the test_procedures %s.' + % ", ".join([c.name for c in test_procedures]) + ) return test_procedures def _rename_task_based_on_num_clients(self, name: str, num_clients: int) -> str: has_underscore = "_" in name has_hyphen = "-" in name if has_underscore and has_hyphen: - self.logger.warning("The test procedure name %s contains a mix of _ and -. "\ - "Consider changing the name to avoid frustrating bugs in the future.", name) + self.logger.warning("The test procedure name %s contains a mix of _ and -. Consider changing the name to avoid frustrating bugs in the future.", name) return name + "_" + str(num_clients) + "_clients" elif has_hyphen: return name + "-" + str(num_clients) + "-clients" @@ -1686,14 +1662,9 @@ def _get_test_procedure_specs(self, workload_spec): elif test_procedures is not None: return test_procedures, False elif schedule is not None: - return [{ - "name": "default", - "schedule": schedule - }], True + return [{"name": "default", "schedule": schedule}], True else: - raise AssertionError( - "Unexpected: schedule=[{}], test_procedure=[{}], test_procedures=[{}]".format( - schedule, test_procedure, test_procedures)) + raise AssertionError("Unexpected: schedule=[{}], test_procedure=[{}], test_procedures=[{}]".format(schedule, test_procedure, test_procedures)) def parse_parallel(self, ops_spec, ops, test_procedure_name): # use same default values as #parseTask() in case the 'parallel' element did not specify anything @@ -1709,25 +1680,38 @@ def parse_parallel(self, ops_spec, ops, test_procedure_name): # now descent to each operation tasks = [] for task in self._r(ops_spec, "tasks", error_ctx="parallel"): - tasks.append(self.parse_task(task, ops, test_procedure_name, default_warmup_iterations, default_iterations, - default_warmup_time_period, default_time_period, default_ramp_up_time_period, - default_ramp_down_time_period, completed_by)) + tasks.append( + self.parse_task( + task, + ops, + test_procedure_name, + default_warmup_iterations, + default_iterations, + default_warmup_time_period, + default_time_period, + default_ramp_up_time_period, + default_ramp_down_time_period, + completed_by, + ) + ) for task in tasks: if task.ramp_up_time_period != default_ramp_up_time_period: if default_ramp_up_time_period is None: - self._error(f"task '{task.name}' in 'parallel' element of test-procedure '{test_procedure_name}' specifies " - f"a ramp-up-time-period but it is only allowed on the 'parallel' element.") + self._error( + f"task '{task.name}' in 'parallel' element of test-procedure '{test_procedure_name}' specifies " + f"a ramp-up-time-period but it is only allowed on the 'parallel' element." + ) else: - self._error(f"task '{task.name}' specifies a different ramp-up-time-period than its enclosing " - f"'parallel' element in test-procedure '{test_procedure_name}'.") + self._error(f"task '{task.name}' specifies a different ramp-up-time-period than its enclosing 'parallel' element in test-procedure '{test_procedure_name}'.") if task.ramp_down_time_period != default_ramp_down_time_period: if default_ramp_down_time_period is None: - self._error(f"task '{task.name}' in 'parallel' element of test-procedure '{test_procedure_name}' specifies " - f"a ramp-down-time-period but it is only allowed on the 'parallel' element.") + self._error( + f"task '{task.name}' in 'parallel' element of test-procedure '{test_procedure_name}' specifies " + f"a ramp-down-time-period but it is only allowed on the 'parallel' element." + ) else: - self._error(f"task '{task.name}' specifies a different ramp-down-time-period than its enclosing " - f"'parallel' element in test-procedure '{test_procedure_name}'.") + self._error(f"task '{task.name}' specifies a different ramp-down-time-period than its enclosing 'parallel' element in test-procedure '{test_procedure_name}'.") if completed_by: completion_task = None for task in tasks: @@ -1736,15 +1720,28 @@ def parse_parallel(self, ops_spec, ops, test_procedure_name): elif task.completes_parent: self._error( "'parallel' element for test_procedure '%s' contains multiple tasks with the name '%s' which are marked with " - "'completed-by' but only task is allowed to match." % (test_procedure_name, completed_by)) + "'completed-by' but only task is allowed to match." % (test_procedure_name, completed_by) + ) if not completion_task: - self._error("'parallel' element for test_procedure '%s' is marked with 'completed-by' with task name '%s' but no task with " - "this name exists." % (test_procedure_name, completed_by)) + self._error( + "'parallel' element for test_procedure '%s' is marked with 'completed-by' with task name '%s' but no task with " + "this name exists." % (test_procedure_name, completed_by) + ) return workload.Parallel(tasks, clients) - def parse_task(self, task_spec, ops, test_procedure_name, default_warmup_iterations=None, default_iterations=None, - default_warmup_time_period=None, default_time_period=None, default_ramp_up_time_period=None, default_ramp_down_time_period=None, - completed_by_name=None): + def parse_task( + self, + task_spec, + ops, + test_procedure_name, + default_warmup_iterations=None, + default_iterations=None, + default_warmup_time_period=None, + default_time_period=None, + default_ramp_up_time_period=None, + default_ramp_down_time_period=None, + completed_by_name=None, + ): op_spec = task_spec["operation"] if isinstance(op_spec, str) and op_spec in ops: @@ -1755,62 +1752,71 @@ def parse_task(self, task_spec, ops, test_procedure_name, default_warmup_iterati schedule = self._r(task_spec, "schedule", error_ctx=op.name, mandatory=False) task_name = self._r(task_spec, "name", error_ctx=op.name, mandatory=False, default_value=op.name) - task = workload.Task(name=task_name, - operation=op, - tags=self._r(task_spec, "tags", error_ctx=op.name, mandatory=False), - meta_data=self._r(task_spec, "meta", error_ctx=op.name, mandatory=False), - warmup_iterations=self._r(task_spec, "warmup-iterations", error_ctx=op.name, mandatory=False, - default_value=default_warmup_iterations), - iterations=self._r(task_spec, "iterations", error_ctx=op.name, mandatory=False, default_value=default_iterations), - warmup_time_period=self._r(task_spec, "warmup-time-period", error_ctx=op.name, - mandatory=False, - default_value=default_warmup_time_period), - time_period=self._r(task_spec, "time-period", error_ctx=op.name, mandatory=False, - default_value=default_time_period), - ramp_up_time_period=self._r(task_spec, "ramp-up-time-period", error_ctx=op.name, - mandatory=False, default_value=default_ramp_up_time_period), - ramp_down_time_period=self._r(task_spec, "ramp-down-time-period", error_ctx=op.name, - mandatory=False, default_value=default_ramp_down_time_period), - clients=self._r(task_spec, "clients", error_ctx=op.name, mandatory=False, default_value=1), - completes_parent=(task_name == completed_by_name), - schedule=schedule, - # this is to provide scheduler-specific parameters for custom schedulers. - params=task_spec) + task = workload.Task( + name=task_name, + operation=op, + tags=self._r(task_spec, "tags", error_ctx=op.name, mandatory=False), + meta_data=self._r(task_spec, "meta", error_ctx=op.name, mandatory=False), + warmup_iterations=self._r(task_spec, "warmup-iterations", error_ctx=op.name, mandatory=False, default_value=default_warmup_iterations), + iterations=self._r(task_spec, "iterations", error_ctx=op.name, mandatory=False, default_value=default_iterations), + warmup_time_period=self._r(task_spec, "warmup-time-period", error_ctx=op.name, mandatory=False, default_value=default_warmup_time_period), + time_period=self._r(task_spec, "time-period", error_ctx=op.name, mandatory=False, default_value=default_time_period), + ramp_up_time_period=self._r(task_spec, "ramp-up-time-period", error_ctx=op.name, mandatory=False, default_value=default_ramp_up_time_period), + ramp_down_time_period=self._r(task_spec, "ramp-down-time-period", error_ctx=op.name, mandatory=False, default_value=default_ramp_down_time_period), + clients=self._r(task_spec, "clients", error_ctx=op.name, mandatory=False, default_value=1), + completes_parent=(task_name == completed_by_name), + schedule=schedule, + # this is to provide scheduler-specific parameters for custom schedulers. + params=task_spec, + ) if task.warmup_iterations is not None and task.time_period is not None: self._error( "Operation '%s' in test_procedure '%s' defines '%d' warmup iterations and a time period of '%d' seconds. Please do not " - "mix time periods and iterations." % (op.name, test_procedure_name, task.warmup_iterations, task.time_period)) + "mix time periods and iterations." % (op.name, test_procedure_name, task.warmup_iterations, task.time_period) + ) elif task.warmup_time_period is not None and task.iterations is not None: self._error( "Operation '%s' in test_procedure '%s' defines a warmup time period of '%d' seconds and '%d' iterations. Please do not " - "mix time periods and iterations." % (op.name, test_procedure_name, task.warmup_time_period, task.iterations)) + "mix time periods and iterations." % (op.name, test_procedure_name, task.warmup_time_period, task.iterations) + ) if (task.warmup_iterations is not None or task.iterations is not None) and task.ramp_up_time_period is not None: - self._error(f"Operation '{op.name}' in test_procedure '{test_procedure_name}' defines a ramp-up time period of " - f"{task.ramp_up_time_period} seconds as well as {task.warmup_iterations} warmup iterations and " - f"{task.iterations} iterations but mixing time periods and iterations is not allowed.") + self._error( + f"Operation '{op.name}' in test_procedure '{test_procedure_name}' defines a ramp-up time period of " + f"{task.ramp_up_time_period} seconds as well as {task.warmup_iterations} warmup iterations and " + f"{task.iterations} iterations but mixing time periods and iterations is not allowed." + ) if task.ramp_up_time_period is not None: if task.warmup_time_period is None: - self._error(f"Operation '{op.name}' in test_procedure '{test_procedure_name}' defines a ramp-up time period of " - f"{task.ramp_up_time_period} seconds but no warmup-time-period.") + self._error( + f"Operation '{op.name}' in test_procedure '{test_procedure_name}' defines a ramp-up time period of " + f"{task.ramp_up_time_period} seconds but no warmup-time-period." + ) elif task.warmup_time_period < task.ramp_up_time_period: - self._error(f"The warmup-time-period of operation '{op.name}' in test_procedure '{test_procedure_name}' is " - f"{task.warmup_time_period} seconds but must be greater than or equal to the " - f"ramp-up-time-period of {task.ramp_up_time_period} seconds.") + self._error( + f"The warmup-time-period of operation '{op.name}' in test_procedure '{test_procedure_name}' is " + f"{task.warmup_time_period} seconds but must be greater than or equal to the " + f"ramp-up-time-period of {task.ramp_up_time_period} seconds." + ) if task.ramp_down_time_period is not None: if task.time_period is None: - self._error(f"Operation '{op.name}' in test_procedure '{test_procedure_name}' defines a ramp-down time period of " - f"{task.ramp_down_time_period} seconds but no time-period.") + self._error( + f"Operation '{op.name}' in test_procedure '{test_procedure_name}' defines a ramp-down time period of {task.ramp_down_time_period} seconds but no time-period." + ) elif task.time_period < task.ramp_down_time_period: - self._error(f"The time-period of operation '{op.name}' in test_procedure '{test_procedure_name}' is " - f"{task.time_period} seconds but must be greater than or equal to the " - f"ramp-down-time-period of {task.ramp_down_time_period} seconds.") + self._error( + f"The time-period of operation '{op.name}' in test_procedure '{test_procedure_name}' is " + f"{task.time_period} seconds but must be greater than or equal to the " + f"ramp-down-time-period of {task.ramp_down_time_period} seconds." + ) if (task.warmup_iterations is not None or task.iterations is not None) and task.ramp_down_time_period is not None: - self._error(f"Operation '{op.name}' in test_procedure '{test_procedure_name}' defines a ramp-down time period of " - f"{task.ramp_down_time_period} seconds as well as {task.warmup_iterations} warmup iterations and " - f"{task.iterations} iterations but mixing time periods and iterations is not allowed.") + self._error( + f"Operation '{op.name}' in test_procedure '{test_procedure_name}' defines a ramp-down time period of " + f"{task.ramp_down_time_period} seconds as well as {task.warmup_iterations} warmup iterations and " + f"{task.iterations} iterations but mixing time periods and iterations is not allowed." + ) return task @@ -1853,8 +1859,6 @@ def parse_operation(self, op_spec, error_ctx="operations"): self.logger.info("Using user-provided operation type [%s] for operation [%s].", op_type_name, op_name) try: - return workload.Operation(name=op_name, meta_data=meta_data, - operation_type=op_type_name, params=params, - param_source=param_source) + return workload.Operation(name=op_name, meta_data=meta_data, operation_type=op_type_name, params=params, param_source=param_source) except exceptions.InvalidSyntax as e: raise WorkloadSyntaxError("Invalid operation [%s]: %s" % (op_name, str(e))) diff --git a/solrorbit/workload/params.py b/solrorbit/workload/params.py index 99a1afd2..10876306 100644 --- a/solrorbit/workload/params.py +++ b/solrorbit/workload/params.py @@ -16,7 +16,7 @@ # not use this file except in compliance with the License. # You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an @@ -51,6 +51,7 @@ __STANDARD_VALUES = {} __QUERY_RANDOMIZATION_INFOS = {} + def param_source_for_operation(op_type, workload, params, task_name): try: # we know that this can only be a Solr Orbit core parameter source @@ -71,13 +72,14 @@ def param_source_for_name(name, workload, params): else: return param_source(workload, params) + def get_standard_value_source(op_name, field_name): try: return __STANDARD_VALUE_SOURCES[op_name][field_name] except KeyError: raise exceptions.SystemSetupError( - "Could not find standard value source for operation {}, field {}! Make sure this is registered in workload.py" - .format(op_name, field_name)) + "Could not find standard value source for operation {}, field {}! Make sure this is registered in workload.py".format(op_name, field_name) + ) def ensure_valid_param_source(param_source): @@ -94,12 +96,14 @@ def register_param_source_for_name(name, param_source_class): ensure_valid_param_source(param_source_class) __PARAM_SOURCES_BY_NAME[name] = param_source_class + def register_standard_value_source(op_name, field_name, standard_value_source): if op_name in __STANDARD_VALUE_SOURCES: __STANDARD_VALUE_SOURCES[op_name][field_name] = standard_value_source # We have to allow re-registration for the same op/field, since plugins are loaded many times when a workload is run else: - __STANDARD_VALUE_SOURCES[op_name] = {field_name:standard_value_source} + __STANDARD_VALUE_SOURCES[op_name] = {field_name: standard_value_source} + def generate_standard_values_if_absent(op_name, field_name, n): if op_name not in __STANDARD_VALUES: @@ -109,12 +113,11 @@ def generate_standard_values_if_absent(op_name, field_name, n): try: standard_value_source = __STANDARD_VALUE_SOURCES[op_name][field_name] except KeyError: - raise exceptions.SystemSetupError( - "Cannot generate standard values for operation {}, field {}. Standard value source is missing" - .format(op_name, field_name)) + raise exceptions.SystemSetupError("Cannot generate standard values for operation {}, field {}. Standard value source is missing".format(op_name, field_name)) for _i in range(n): __STANDARD_VALUES[op_name][field_name].append(standard_value_source()) + def get_standard_value(op_name, field_name, i): try: return __STANDARD_VALUES[op_name][field_name][i] @@ -122,22 +125,22 @@ def get_standard_value(op_name, field_name, i): raise exceptions.SystemSetupError("No standard values generated for operation {}, field {}".format(op_name, field_name)) except IndexError: raise exceptions.SystemSetupError( - "Standard value index {} out of range for operation {}, field name {} ({} values total)" - .format(i, op_name, field_name, len(__STANDARD_VALUES[op_name][field_name]))) + "Standard value index {} out of range for operation {}, field name {} ({} values total)".format(i, op_name, field_name, len(__STANDARD_VALUES[op_name][field_name])) + ) + def register_query_randomization_info(op_name, query_name, parameter_name_options_list, optional_parameters): # query_randomization_info is registered at the operation level - query_randomization_info = loader.QueryRandomizerWorkloadProcessor.QueryRandomizationInfo(query_name, - parameter_name_options_list, - optional_parameters - ) + query_randomization_info = loader.QueryRandomizerWorkloadProcessor.QueryRandomizationInfo(query_name, parameter_name_options_list, optional_parameters) __QUERY_RANDOMIZATION_INFOS[op_name] = query_randomization_info + def get_query_randomization_info(op_name): try: - return __QUERY_RANDOMIZATION_INFOS[op_name] + return __QUERY_RANDOMIZATION_INFOS[op_name] except KeyError: - return loader.QueryRandomizerWorkloadProcessor.DEFAULT_QUERY_RANDOMIZATION_INFO # If nothing is registered, return the default. + return loader.QueryRandomizerWorkloadProcessor.DEFAULT_QUERY_RANDOMIZATION_INFO # If nothing is registered, return the default. + # only intended for tests def _unregister_param_source_for_name(name): @@ -145,14 +148,17 @@ def _unregister_param_source_for_name(name): # something is fishy with the test and we'd rather know early. __PARAM_SOURCES_BY_NAME.pop(name) + # only intended for tests def _clear_standard_values(): __STANDARD_VALUES = {} __STANDARD_VALUE_SOURCES = {} + def _clear_query_randomization_infos(): __QUERY_RANDOMIZATION_INFOS = {} + # Default class ParamSource: """ @@ -222,11 +228,7 @@ def _client_params(self): :return: all applicable parameters that are global to Solr Orbit and apply to the cluster client """ - return { - "request-timeout": self._params.get("request-timeout"), - "headers": self._params.get("headers"), - "opaque-id": self._params.get("opaque-id") - } + return {"request-timeout": self._params.get("request-timeout"), "headers": self._params.get("headers"), "opaque-id": self._params.get("opaque-id")} class DelegatingParamSource(ParamSource): @@ -348,8 +350,7 @@ def __init__(self, workload, params, **kwargs): target_name = get_target(workload, params) type_name = params.get("type") if params.get("data-stream") and type_name: - raise exceptions.InvalidSyntax( - f"'type' not supported with 'data-stream' for operation '{kwargs.get('operation_name')}'") + raise exceptions.InvalidSyntax(f"'type' not supported with 'data-stream' for operation '{kwargs.get('operation_name')}'") request_cache = params.get("cache", None) detailed_results = params.get("detailed-results", False) calculate_recall = params.get("calculate-recall", True) @@ -370,12 +371,11 @@ def __init__(self, workload, params, **kwargs): "calculate-recall": calculate_recall, "request-params": request_params, "response-compression-enabled": response_compression_enabled, - "body": query_body + "body": query_body, } if not target_name: - raise exceptions.InvalidSyntax( - f"'index' or 'data-stream' is mandatory and is missing for operation '{kwargs.get('operation_name')}'") + raise exceptions.InvalidSyntax(f"'index' or 'data-stream' is mandatory and is missing for operation '{kwargs.get('operation_name')}'") if pages: self.query_params["pages"] = pages @@ -411,6 +411,7 @@ class IndexIdConflict(Enum): Note that this assumes that each document in the benchmark corpus has an id between [1, size_of(corpus)] """ + NoConflicts = 0 SequentialConflicts = 1 RandomConflicts = 2 @@ -433,8 +434,7 @@ def __init__(self, workload, params, **kwargs): raise exceptions.InvalidSyntax("'conflicts' cannot be used with 'data-streams'") if self.id_conflicts != IndexIdConflict.NoConflicts: - self.conflict_probability = self.float_param(params, name="conflict-probability", default_value=25, min_value=0, max_value=100, - min_operator=operator.lt) + self.conflict_probability = self.float_param(params, name="conflict-probability", default_value=25, min_value=0, max_value=100, min_operator=operator.lt) self.on_conflict = params.get("on-conflict", "index") if self.on_conflict not in ["index", "update"]: raise exceptions.InvalidSyntax("Unknown 'on-conflict' setting [{}]".format(self.on_conflict)) @@ -448,16 +448,18 @@ def __init__(self, workload, params, **kwargs): self.corpora = self.used_corpora(workload, params) if len(self.corpora) == 0: - raise exceptions.InvalidSyntax(f"There is no document corpus definition for workload {workload}. You must add at " - f"least one before making bulk requests to the target cluster.") + raise exceptions.InvalidSyntax( + f"There is no document corpus definition for workload {workload}. You must add at least one before making bulk requests to the target cluster." + ) for corpus in self.corpora: for document_set in corpus.documents: if document_set.includes_action_and_meta_data and self.id_conflicts != IndexIdConflict.NoConflicts: file_name = document_set.document_archive if document_set.has_compressed_corpus() else document_set.document_file - raise exceptions.InvalidSyntax("Cannot generate id conflicts [%s] as [%s] in document corpus [%s] already contains an " - "action and meta-data line." % (id_conflicts, file_name, corpus)) + raise exceptions.InvalidSyntax( + "Cannot generate id conflicts [%s] as [%s] in document corpus [%s] already contains an action and meta-data line." % (id_conflicts, file_name, corpus) + ) self.pipeline = params.get("pipeline", None) try: @@ -482,18 +484,26 @@ def __init__(self, workload, params, **kwargs): self.ingest_percentage = self.float_param(params, name="ingest-percentage", default_value=100, min_value=0, max_value=100) self.looped = params.get("looped", False) - self.param_source = PartitionBulkIndexParamSource(self.corpora, self.batch_size, self.bulk_size, - self.ingest_percentage, self.id_conflicts, - self.conflict_probability, self.on_conflict, - self.recency, self.pipeline, self.looped, self._params) + self.param_source = PartitionBulkIndexParamSource( + self.corpora, + self.batch_size, + self.bulk_size, + self.ingest_percentage, + self.id_conflicts, + self.conflict_probability, + self.on_conflict, + self.recency, + self.pipeline, + self.looped, + self._params, + ) def float_param(self, params, name, default_value, min_value, max_value, min_operator=operator.le): try: value = float(params.get(name, default_value)) if min_operator(value, min_value) or value > max_value: interval_min = "(" if min_operator is operator.le else "[" - raise exceptions.InvalidSyntax( - "'{}' must be in the range {}{:.1f}, {:.1f}] but was {:.1f}".format(name, interval_min, min_value, max_value, value)) + raise exceptions.InvalidSyntax("'{}' must be in the range {}{:.1f}, {:.1f}] but was {:.1f}".format(name, interval_min, min_value, max_value, value)) return value except ValueError: raise exceptions.InvalidSyntax("'{}' must be numeric".format(name)) @@ -507,16 +517,13 @@ def used_corpora(self, t, params): for corpus in t.corpora: if corpus.name in corpora_names: - filtered_corpus = corpus.filter(source_format=workload.Documents.SOURCE_FORMAT_BULK, - target_collections=params.get("indices")) - if filtered_corpus.streaming_ingestion or \ - filtered_corpus.number_of_documents(source_format=workload.Documents.SOURCE_FORMAT_BULK) > 0: + filtered_corpus = corpus.filter(source_format=workload.Documents.SOURCE_FORMAT_BULK, target_collections=params.get("indices")) + if filtered_corpus.streaming_ingestion or filtered_corpus.number_of_documents(source_format=workload.Documents.SOURCE_FORMAT_BULK) > 0: corpora.append(filtered_corpus) # the workload has corpora but none of them match if t.corpora and not corpora: - raise exceptions.BenchmarkAssertionError("The provided corpus %s does not match any of the corpora %s." % - (corpora_names, workload_corpora_names)) + raise exceptions.BenchmarkAssertionError("The provided corpus %s does not match any of the corpora %s." % (corpora_names, workload_corpora_names)) return corpora @@ -530,8 +537,9 @@ def params(self): class PartitionBulkIndexParamSource: - def __init__(self, corpora, batch_size, bulk_size, ingest_percentage, id_conflicts, conflict_probability, - on_conflict, recency, pipeline=None, looped = False, original_params=None): + def __init__( + self, corpora, batch_size, bulk_size, ingest_percentage, id_conflicts, conflict_probability, on_conflict, recency, pipeline=None, looped=False, original_params=None + ): """ :param corpora: Specification of affected document corpora. @@ -572,8 +580,7 @@ def partition(self, partition_index, total_partitions): if self.total_partitions is None: self.total_partitions = total_partitions elif self.total_partitions != total_partitions: - raise exceptions.BenchmarkAssertionError( - f"Total partitions is expected to be [{self.total_partitions}] but was [{total_partitions}]") + raise exceptions.BenchmarkAssertionError(f"Total partitions is expected to be [{self.total_partitions}] but was [{total_partitions}]") self.partitions.append(partition_index) def params(self): @@ -596,10 +603,21 @@ def _init_internal_params(self): start_index = self.partitions[0] end_index = self.partitions[-1] - self.internal_params = bulk_data_based(self.total_partitions, start_index, end_index, self.corpora, - self.batch_size, self.bulk_size, self.id_conflicts, - self.conflict_probability, self.on_conflict, self.recency, - self.pipeline, self.original_params, self.create_reader) + self.internal_params = bulk_data_based( + self.total_partitions, + start_index, + end_index, + self.corpora, + self.batch_size, + self.bulk_size, + self.id_conflicts, + self.conflict_probability, + self.on_conflict, + self.recency, + self.pipeline, + self.original_params, + self.create_reader, + ) if not self.streaming_ingestion: all_bulks = number_of_bulks(self.corpora, start_index, end_index, self.total_partitions, self.bulk_size) @@ -607,8 +625,7 @@ def _init_internal_params(self): @property def task_progress(self): - return (IngestionManager.rd_index.value * IngestionManager.chunk_size/1000, 'GB') if self.streaming_ingestion else (self.current_bulk / self.total_bulks, '%') - + return (IngestionManager.rd_index.value * IngestionManager.chunk_size / 1000, "GB") if self.streaming_ingestion else (self.current_bulk / self.total_bulks, "%") def get_target(workload, params): @@ -621,6 +638,7 @@ def get_target(workload, params): target_name = params.get("data-stream", default_target) return target_name + def number_of_bulks(corpora, start_partition_index, end_partition_index, total_partitions, bulk_size): """ :return: The number of bulk operations that the given client will issue. @@ -628,8 +646,7 @@ def number_of_bulks(corpora, start_partition_index, end_partition_index, total_p bulks = 0 for corpus in corpora: for docs in corpus.documents: - _, num_docs, _ = bounds(docs.number_of_documents, start_partition_index, end_partition_index, - total_partitions, docs.includes_action_and_meta_data) + _, num_docs, _ = bounds(docs.number_of_documents, start_partition_index, end_partition_index, total_partitions, docs.includes_action_and_meta_data) complete_bulks, rest = (num_docs // bulk_size, num_docs % bulk_size) bulks += complete_bulks if rest > 0: @@ -663,8 +680,7 @@ def chain(*iterables): yield element -def create_default_reader(corpus, docs, offset, num_lines, num_docs, batch_size, bulk_size, id_conflicts, conflict_probability, - on_conflict, recency): +def create_default_reader(corpus, docs, offset, num_lines, num_docs, batch_size, bulk_size, id_conflicts, conflict_probability, on_conflict, recency): source = Slice(io.MmapSource, offset, num_lines, corpus, docs) target = None use_create = False @@ -674,35 +690,36 @@ def create_default_reader(corpus, docs, offset, num_lines, num_docs, batch_size, if docs.includes_action_and_meta_data: return SourceOnlyIndexDataReader(docs.document_file, batch_size, bulk_size, source, target, docs.target_type) else: - am_handler = GenerateActionMetaData(target, docs.target_type, - build_conflicting_ids(id_conflicts, num_docs, offset), conflict_probability, - on_conflict, recency, use_create=use_create) + am_handler = GenerateActionMetaData( + target, docs.target_type, build_conflicting_ids(id_conflicts, num_docs, offset), conflict_probability, on_conflict, recency, use_create=use_create + ) return MetadataIndexDataReader(docs.document_file, batch_size, bulk_size, source, am_handler, target, docs.target_type) -def create_readers(num_clients, start_client_index, end_client_index, corpora, batch_size, bulk_size, id_conflicts, - conflict_probability, on_conflict, recency, create_reader): +def create_readers(num_clients, start_client_index, end_client_index, corpora, batch_size, bulk_size, id_conflicts, conflict_probability, on_conflict, recency, create_reader): logger = logging.getLogger(__name__) readers = [] for corpus in corpora: for docs in corpus.documents: if corpus.streaming_ingestion: offset = num_lines = num_docs = 0 - readers.append(create_reader(corpus, docs, offset, num_lines, num_docs, batch_size, bulk_size, id_conflicts, - conflict_probability, on_conflict, recency)) + readers.append(create_reader(corpus, docs, offset, num_lines, num_docs, batch_size, bulk_size, id_conflicts, conflict_probability, on_conflict, recency)) else: - offset, num_docs, num_lines = bounds(docs.number_of_documents, start_client_index, end_client_index, - num_clients, docs.includes_action_and_meta_data) + offset, num_docs, num_lines = bounds(docs.number_of_documents, start_client_index, end_client_index, num_clients, docs.includes_action_and_meta_data) if num_docs > 0: target = f"{docs.target_collection}/{docs.target_type}" if docs.target_collection else "/" - logger.info("Task-relative clients at index [%d-%d] will bulk index [%d] docs starting from line offset [%d] for [%s] " - "from corpus [%s].", start_client_index, end_client_index, num_docs, offset, - target, corpus.name) - readers.append(create_reader(corpus, docs, offset, num_lines, num_docs, batch_size, bulk_size, id_conflicts, - conflict_probability, on_conflict, recency)) + logger.info( + "Task-relative clients at index [%d-%d] will bulk index [%d] docs starting from line offset [%d] for [%s] from corpus [%s].", + start_client_index, + end_client_index, + num_docs, + offset, + target, + corpus.name, + ) + readers.append(create_reader(corpus, docs, offset, num_lines, num_docs, batch_size, bulk_size, id_conflicts, conflict_probability, on_conflict, recency)) else: - logger.info("Task-relative clients at index [%d-%d] skip [%s] (no documents to read).", - start_client_index, end_client_index, corpus.name) + logger.info("Task-relative clients at index [%d-%d] skip [%s] (no documents to read).", start_client_index, end_client_index, corpus.name) return readers @@ -748,7 +765,7 @@ def bulk_generator(readers, pipeline, original_params): "body": bulk, # This is not always equal to the bulk_size we get as parameter. The last bulk may be less than the bulk size. "bulk-size": docs_in_bulk, - "unit": "docs" + "unit": "docs", } if pipeline: bulk_params["pipeline"] = pipeline @@ -758,8 +775,21 @@ def bulk_generator(readers, pipeline, original_params): yield params -def bulk_data_based(num_clients, start_client_index, end_client_index, corpora, batch_size, bulk_size, id_conflicts, - conflict_probability, on_conflict, recency, pipeline, original_params, create_reader=create_default_reader): +def bulk_data_based( + num_clients, + start_client_index, + end_client_index, + corpora, + batch_size, + bulk_size, + id_conflicts, + conflict_probability, + on_conflict, + recency, + pipeline, + original_params, + create_reader=create_default_reader, +): """ Calculates the necessary schedule for bulk operations. @@ -782,21 +812,31 @@ def bulk_data_based(num_clients, start_client_index, end_client_index, corpora, intended for testing only. :return: A generator for the bulk operations of the given client. """ - readers = create_readers(num_clients, start_client_index, end_client_index, corpora, batch_size, bulk_size, - id_conflicts, conflict_probability, on_conflict, recency, create_reader) + readers = create_readers( + num_clients, start_client_index, end_client_index, corpora, batch_size, bulk_size, id_conflicts, conflict_probability, on_conflict, recency, create_reader + ) return bulk_generator(chain(*readers), pipeline, original_params) class GenerateActionMetaData: RECENCY_SLOPE = 30 - def __init__(self, index_name, type_name, conflicting_ids=None, conflict_probability=None, on_conflict=None, recency=None, - rand=random.random, randint=random.randint, randexp=random.expovariate, use_create=False): + def __init__( + self, + index_name, + type_name, + conflicting_ids=None, + conflict_probability=None, + on_conflict=None, + recency=None, + rand=random.random, + randint=random.randint, + randexp=random.expovariate, + use_create=False, + ): if type_name: - self.meta_data_index_with_id = '{"index": {"_index": "%s", "_type": "%s", "_id": "%s"}}\n' % \ - (index_name, type_name, "%s") - self.meta_data_update_with_id = '{"update": {"_index": "%s", "_type": "%s", "_id": "%s"}}\n' % \ - (index_name, type_name, "%s") + self.meta_data_index_with_id = '{"index": {"_index": "%s", "_type": "%s", "_id": "%s"}}\n' % (index_name, type_name, "%s") + self.meta_data_update_with_id = '{"update": {"_index": "%s", "_type": "%s", "_id": "%s"}}\n' % (index_name, type_name, "%s") self.meta_data_index_no_id = '{"index": {"_index": "%s", "_type": "%s"}}\n' % (index_name, type_name) else: self.meta_data_index_with_id = '{"index": {"_index": "%s", "_id": "%s"}}\n' % (index_name, "%s") @@ -892,7 +932,8 @@ def _start_producer(): client_options = getattr(client_options_obj, "all_client_options", {}) # pylint: disable = import-outside-toplevel from solrorbit.utils.s3_data_producer import S3DataProducer - bucket = re.sub('^s3://', "", Slice.base_url) + + bucket = re.sub("^s3://", "", Slice.base_url) keys = Slice.document_file producer = S3DataProducer(bucket, keys, client_options, Slice.data_dir) p = multiprocessing.Process(target=producer.generate_chunked_data) @@ -908,8 +949,7 @@ def open(self, file_name, mode, bulk_size): self._open_next() else: self.source = self.source_class(file_name, mode).open() - self.logger.info("Will read [%d] lines from [%s] starting from line [%d] with bulk size [%d].", - self.number_of_lines, file_name, self.offset, self.bulk_size) + self.logger.info("Will read [%d] lines from [%s] starting from line [%d] with bulk size [%d].", self.number_of_lines, file_name, self.offset, self.bulk_size) start = time.perf_counter() io.skip_lines(file_name, self.source, self.offset) end = time.perf_counter() @@ -1073,7 +1113,7 @@ def _read_bulk_regular(self): if action_type == "update": # remove the trailing "\n" as the doc needs to fit on one line doc = doc.strip() - current_bulk.append(b"{\"doc\":%s}\n" % doc) + current_bulk.append(b'{"doc":%s}\n' % doc) else: current_bulk.append(doc) else: @@ -1109,6 +1149,7 @@ def read_bulk(self): # Solr-specific param sources # --------------------------------------------------------------------------- + class SolrSearchParamSource(ParamSource): """ Param source for Solr search operations. @@ -1132,9 +1173,7 @@ def __init__(self, workload, params, **kwargs): super().__init__(workload, params, **kwargs) collection = params.get("collection") or get_target(workload, params) if not collection: - raise exceptions.InvalidSyntax( - f"'collection' is mandatory and is missing for operation '{kwargs.get('operation_name')}'" - ) + raise exceptions.InvalidSyntax(f"'collection' is mandatory and is missing for operation '{kwargs.get('operation_name')}'") self.query_params = { "collection": collection, diff --git a/solrorbit/workload/workload.py b/solrorbit/workload/workload.py index 05fab3ff..f7c970e1 100644 --- a/solrorbit/workload/workload.py +++ b/solrorbit/workload/workload.py @@ -16,7 +16,7 @@ # not use this file except in compliance with the License. # You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an @@ -33,7 +33,6 @@ from solrorbit import exceptions - class Collection: """ Defines a Solr collection (Solr-native equivalent of Index). @@ -48,10 +47,9 @@ class Collection: tlog_replicas: TLOG replicas per shard (default: 0). """ - def __init__(self, name: str, configset: str = None, - configset_path: str = None, - num_shards: int = 1, replication_factor: int = 1, - pull_replicas: int = 0, tlog_replicas: int = 0): + def __init__( + self, name: str, configset: str = None, configset_path: str = None, num_shards: int = 1, replication_factor: int = 1, pull_replicas: int = 0, tlog_replicas: int = 0 + ): self.name = name self.configset = configset or name self.configset_path = configset_path @@ -86,17 +84,28 @@ def __eq__(self, other): return self.name == other.name - class Documents: SOURCE_FORMAT_BULK = "bulk" SOURCE_FORMAT_HDF5 = "hdf5" SOURCE_FORMAT_BIG_ANN = "big-ann" SUPPORTED_SOURCE_FORMAT = [SOURCE_FORMAT_BULK, SOURCE_FORMAT_HDF5, SOURCE_FORMAT_BIG_ANN] - def __init__(self, source_format, document_file=None, document_file_parts=None, document_archive=None, base_url=None, source_url=None, - includes_action_and_meta_data=False, - number_of_documents=0, compressed_size_in_bytes=0, uncompressed_size_in_bytes=0, target_collection=None, - target_type=None, meta_data=None): + def __init__( + self, + source_format, + document_file=None, + document_file_parts=None, + document_archive=None, + base_url=None, + source_url=None, + includes_action_and_meta_data=False, + number_of_documents=0, + compressed_size_in_bytes=0, + uncompressed_size_in_bytes=0, + target_collection=None, + target_type=None, + meta_data=None, + ): """ :param source_format: The format of these documents. Mandatory. @@ -199,19 +208,49 @@ def __repr__(self): return ", ".join(r) def __hash__(self): - return hash(self.source_format) ^ hash(self.document_file) ^ hash(self.document_archive) ^ hash(self.base_url) ^ \ - hash(self.source_url) ^ hash(self.includes_action_and_meta_data) ^ hash(self.number_of_documents) ^ \ - hash(self.compressed_size_in_bytes) ^ hash(self.uncompressed_size_in_bytes) ^ hash(self.target_collection) ^ \ - hash(self.target_type) ^ hash(frozenset(self.meta_data.items())) + return ( + hash(self.source_format) + ^ hash(self.document_file) + ^ hash(self.document_archive) + ^ hash(self.base_url) + ^ hash(self.source_url) + ^ hash(self.includes_action_and_meta_data) + ^ hash(self.number_of_documents) + ^ hash(self.compressed_size_in_bytes) + ^ hash(self.uncompressed_size_in_bytes) + ^ hash(self.target_collection) + ^ hash(self.target_type) + ^ hash(frozenset(self.meta_data.items())) + ) def __eq__(self, othr): - return (isinstance(othr, type(self)) and - (self.source_format, self.document_file, self.document_archive, self.base_url, self.source_url, - self.includes_action_and_meta_data, self.number_of_documents, self.compressed_size_in_bytes, - self.uncompressed_size_in_bytes, self.target_collection, self.target_type, self.meta_data) == - (othr.source_format, othr.document_file, othr.document_archive, othr.base_url, othr.source_url, - othr.includes_action_and_meta_data, othr.number_of_documents, othr.compressed_size_in_bytes, - othr.uncompressed_size_in_bytes, othr.target_collection, othr.target_type, othr.meta_data)) + return isinstance(othr, type(self)) and ( + self.source_format, + self.document_file, + self.document_archive, + self.base_url, + self.source_url, + self.includes_action_and_meta_data, + self.number_of_documents, + self.compressed_size_in_bytes, + self.uncompressed_size_in_bytes, + self.target_collection, + self.target_type, + self.meta_data, + ) == ( + othr.source_format, + othr.document_file, + othr.document_archive, + othr.base_url, + othr.source_url, + othr.includes_action_and_meta_data, + othr.number_of_documents, + othr.compressed_size_in_bytes, + othr.uncompressed_size_in_bytes, + othr.target_collection, + othr.target_type, + othr.meta_data, + ) class DocumentCorpus: @@ -262,8 +301,7 @@ def filter(self, source_format=None, target_collections=None): continue filtered.append(d) - return DocumentCorpus(self.name, filtered, streaming_ingestion=self.streaming_ingestion, - meta_data=dict(self.meta_data)) + return DocumentCorpus(self.name, filtered, streaming_ingestion=self.streaming_ingestion, meta_data=dict(self.meta_data)) def union(self, other): """ @@ -282,10 +320,9 @@ def union(self, other): if self is other: return self else: - return DocumentCorpus(name=self.name, - documents=list(set(self.documents).union(other.documents)), - streaming_ingestion=self.streaming_ingestion, - meta_data=dict(self.meta_data)) + return DocumentCorpus( + name=self.name, documents=list(set(self.documents).union(other.documents)), streaming_ingestion=self.streaming_ingestion, meta_data=dict(self.meta_data) + ) def __str__(self): return self.name @@ -300,9 +337,7 @@ def __hash__(self): return hash(self.name) ^ hash(self.documents) ^ hash(frozenset(self.meta_data.items())) def __eq__(self, othr): - return (isinstance(othr, type(self)) and - (self.name, self.documents, self.meta_data) == - (othr.name, othr.documents, othr.meta_data)) + return isinstance(othr, type(self)) and (self.name, self.documents, self.meta_data) == (othr.name, othr.documents, othr.meta_data) class Workload: @@ -310,8 +345,7 @@ class Workload: A workload defines the data set that is used. It corresponds loosely to a use case (e.g. logging, event processing, analytics, ...) """ - def __init__(self, name, description=None, meta_data=None, test_procedures=None, - corpora=None, has_plugins=False, collections=None): + def __init__(self, name, description=None, meta_data=None, test_procedures=None, corpora=None, has_plugins=False, collections=None): """ Creates a new workload. @@ -412,31 +446,28 @@ def __repr__(self): return ", ".join(r) def __hash__(self): - return hash(self.name) ^ hash(self.meta_data) ^ hash(self.description) ^ hash(self.test_procedures) ^ \ - hash(self.corpora) + return hash(self.name) ^ hash(self.meta_data) ^ hash(self.description) ^ hash(self.test_procedures) ^ hash(self.corpora) def __eq__(self, othr): - return (isinstance(othr, type(self)) and - (self.name, self.meta_data, self.description, self.test_procedures, self.collections, self.corpora) == - (othr.name, othr.meta_data, othr.description, othr.test_procedures, othr.collections, othr.corpora)) + return isinstance(othr, type(self)) and (self.name, self.meta_data, self.description, self.test_procedures, self.collections, self.corpora) == ( + othr.name, + othr.meta_data, + othr.description, + othr.test_procedures, + othr.collections, + othr.corpora, + ) class TestProcedure: """ A test procedure defines the concrete operations that will be done. """ - #Pytest throws a collection warning if the following line is removed + + # Pytest throws a collection warning if the following line is removed __test__ = False - def __init__(self, - name, - description=None, - user_info=None, - default=False, - selected=False, - auto_generated=False, - parameters=None, - meta_data=None, - schedule=None): + + def __init__(self, name, description=None, user_info=None, default=False, selected=False, auto_generated=False, parameters=None, meta_data=None, schedule=None): self.name = name self.parameters = parameters if parameters else {} self.meta_data = meta_data if meta_data else {} @@ -463,16 +494,29 @@ def __repr__(self): return ", ".join(r) def __hash__(self): - return hash(self.name) ^ hash(self.description) ^ hash(self.default) ^ \ - hash(self.selected) ^ hash(self.auto_generated) ^ hash(self.parameters) ^ hash(self.meta_data) ^ \ - hash(self.schedule) + return ( + hash(self.name) + ^ hash(self.description) + ^ hash(self.default) + ^ hash(self.selected) + ^ hash(self.auto_generated) + ^ hash(self.parameters) + ^ hash(self.meta_data) + ^ hash(self.schedule) + ) def __eq__(self, othr): - return (isinstance(othr, type(self)) and - (self.name, self.description, self.default, self.selected, self.auto_generated, - self.parameters, self.meta_data, self.schedule) == - (othr.name, othr.description, othr.default, othr.selected, othr.auto_generated, - othr.parameters, othr.meta_data, othr.schedule)) + return isinstance(othr, type(self)) and (self.name, self.description, self.default, self.selected, self.auto_generated, self.parameters, self.meta_data, self.schedule) == ( + othr.name, + othr.description, + othr.default, + othr.selected, + othr.auto_generated, + othr.parameters, + othr.meta_data, + othr.schedule, + ) + @unique class AdminStatus(Enum): @@ -540,6 +584,7 @@ def from_hyphenated_string(cls, v): else: raise KeyError(f"No enum value for [{v}]") + class TaskNameFilter: def __init__(self, name): self.name = name @@ -653,10 +698,23 @@ class Task: THROUGHPUT_PATTERN = re.compile(r"(?P(\d*\.)?\d+)\s(?P\w+/s)") IGNORE_RESPONSE_ERROR_LEVEL_WHITELIST = ["non-fatal"] - def __init__(self, name, operation, tags=None, meta_data=None, warmup_iterations=None, iterations=None, - warmup_time_period=None, time_period=None, ramp_up_time_period=None, ramp_down_time_period=None, - clients=1, completes_parent=False, - schedule=None, params=None): + def __init__( + self, + name, + operation, + tags=None, + meta_data=None, + warmup_iterations=None, + iterations=None, + warmup_time_period=None, + time_period=None, + ramp_up_time_period=None, + ramp_down_time_period=None, + clients=1, + completes_parent=False, + schedule=None, + params=None, + ): self.name = name self.operation = operation if isinstance(tags, str): @@ -691,8 +749,9 @@ def numeric(v): target_interval = self.params.get("target-interval") if target_interval is not None and target_throughput is not None: - raise exceptions.InvalidSyntax(f"Task [{self}] specifies target-interval [{target_interval}] and " - f"target-throughput [{target_throughput}] but only one of them is allowed.") + raise exceptions.InvalidSyntax( + f"Task [{self}] specifies target-interval [{target_interval}] and target-throughput [{target_throughput}] but only one of them is allowed." + ) value = None unit = "ops/s" @@ -712,8 +771,7 @@ def numeric(v): elif numeric(target_throughput): value = float(target_throughput) else: - raise exceptions.InvalidSyntax(f"Target throughput [{target_throughput}] for task [{self}] " - f"must be string or numeric.") + raise exceptions.InvalidSyntax(f"Target throughput [{target_throughput}] for task [{self}] must be string or numeric.") if value: return Throughput(value, unit) @@ -724,11 +782,11 @@ def numeric(v): def ignore_response_error_level(self): ignore_response_error_level = self.params.get("ignore-response-error-level") - if ignore_response_error_level and \ - ignore_response_error_level not in Task.IGNORE_RESPONSE_ERROR_LEVEL_WHITELIST: + if ignore_response_error_level and ignore_response_error_level not in Task.IGNORE_RESPONSE_ERROR_LEVEL_WHITELIST: raise exceptions.InvalidSyntax( f"Task [{self}] specifies ignore-response-error-level to [{ignore_response_error_level}] but " - f"the only allowed values are [{','.join(Task.IGNORE_RESPONSE_ERROR_LEVEL_WHITELIST)}].") + f"the only allowed values are [{','.join(Task.IGNORE_RESPONSE_ERROR_LEVEL_WHITELIST)}]." + ) return ignore_response_error_level @@ -751,20 +809,47 @@ def error_behavior(self, default_error_behavior): def __hash__(self): # Note that we do not include `params` in __hash__ and __eq__ (the other attributes suffice to uniquely define a task) - return hash(self.name) ^ hash(self.operation) ^ hash(self.warmup_iterations) ^ hash(self.iterations) ^ \ - hash(self.warmup_time_period) ^ hash(self.time_period) ^ hash(self.ramp_up_time_period) ^ \ - hash(self.ramp_down_time_period) ^ hash(self.clients) ^ hash(self.schedule) ^ hash(self.completes_parent) + return ( + hash(self.name) + ^ hash(self.operation) + ^ hash(self.warmup_iterations) + ^ hash(self.iterations) + ^ hash(self.warmup_time_period) + ^ hash(self.time_period) + ^ hash(self.ramp_up_time_period) + ^ hash(self.ramp_down_time_period) + ^ hash(self.clients) + ^ hash(self.schedule) + ^ hash(self.completes_parent) + ) def __eq__(self, other): # Note that we do not include `params` in __hash__ and __eq__ (the other attributes suffice to uniquely define a task) - return isinstance(other, type(self)) and (self.name, self.operation, self.warmup_iterations, self.iterations, - self.warmup_time_period, self.time_period, self.ramp_up_time_period, - self.ramp_down_time_period, self.clients, self.schedule, - self.completes_parent) == (other.name, other.operation, other.warmup_iterations, - other.iterations, other.warmup_time_period, other.time_period, - other.ramp_up_time_period, other.ramp_down_time_period, - other.clients, other.schedule, - other.completes_parent) + return isinstance(other, type(self)) and ( + self.name, + self.operation, + self.warmup_iterations, + self.iterations, + self.warmup_time_period, + self.time_period, + self.ramp_up_time_period, + self.ramp_down_time_period, + self.clients, + self.schedule, + self.completes_parent, + ) == ( + other.name, + other.operation, + other.warmup_iterations, + other.iterations, + other.warmup_time_period, + other.time_period, + other.ramp_up_time_period, + other.ramp_down_time_period, + other.clients, + other.schedule, + other.completes_parent, + ) def __iter__(self): return iter([self]) diff --git a/solrorbit/workload_generator/__init__.py b/solrorbit/workload_generator/__init__.py index 5047a451..f5768141 100644 --- a/solrorbit/workload_generator/__init__.py +++ b/solrorbit/workload_generator/__init__.py @@ -16,7 +16,7 @@ # not use this file except in compliance with the License. # You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an diff --git a/solrorbit/workload_generator/config.py b/solrorbit/workload_generator/config.py index c1a2a5cc..083db889 100644 --- a/solrorbit/workload_generator/config.py +++ b/solrorbit/workload_generator/config.py @@ -9,6 +9,7 @@ from dataclasses import dataclass, field from typing import List + @dataclass class Index: name: str = None @@ -16,6 +17,7 @@ class Index: number_of_docs: int = None settings_and_mappings: dict = field(default_factory=dict) + @dataclass class CustomWorkload: workload_name: str = None diff --git a/solrorbit/workload_generator/extractors.py b/solrorbit/workload_generator/extractors.py index 46db822b..1af1831f 100644 --- a/solrorbit/workload_generator/extractors.py +++ b/solrorbit/workload_generator/extractors.py @@ -36,8 +36,7 @@ def _cursor_scan(client, collection, batch_size=1000): while True: resp = session.get( url, - params={"q": "*:*", "rows": batch_size, "sort": "id asc", - "cursorMark": cursor, "wt": "json"}, + params={"q": "*:*", "rows": batch_size, "sort": "id asc", "cursorMark": cursor, "wt": "json"}, timeout=client.timeout, ) resp.raise_for_status() @@ -52,6 +51,7 @@ def _cursor_scan(client, collection, batch_size=1000): break cursor = next_cursor + class IndexExtractor: def __init__(self, custom_workload, client): self.custom_workload: CustomWorkload = custom_workload @@ -117,7 +117,6 @@ def is_valid_collection(self, name): class CorpusExtractor(ABC): - @abstractmethod def extract_documents(self, index, documents_limit=None, sample_frequency=None): pass @@ -132,7 +131,7 @@ def __init__(self, custom_workload, client): self.client = client self.logger = logging.getLogger(__name__) - def template_vars(self,index_name, docs_path, doc_count): + def template_vars(self, index_name, docs_path, doc_count): comp_outpath = docs_path + COMP_EXT return { "index_name": index_name, @@ -140,13 +139,12 @@ def template_vars(self,index_name, docs_path, doc_count): "path": comp_outpath, "doc_count": doc_count, "uncompressed_bytes": os.path.getsize(docs_path), - "compressed_bytes": os.path.getsize(comp_outpath) + "compressed_bytes": os.path.getsize(comp_outpath), } def _get_doc_outpath(self, outdir, name, suffix=""): return os.path.join(outdir, f"{name}-documents{suffix}.json") - def extract_documents(self, index, documents_limit=None, sample_frequency=None): """ Scan a Solr collection with CursorMark pagination, dumping documents to @@ -168,15 +166,16 @@ def extract_documents(self, index, documents_limit=None, sample_frequency=None): # Only time when documents-1k.json will be less than 1K documents is # when the documents_limit is < 1k documents or source index has less than 1k documents if documents_limit < self.DEFAULT_TEST_MODE_DOC_COUNT: - test_mode_warning_msg = "Due to --number-of-docs set by user, " + \ - f"test-mode docs will be less than the default {self.DEFAULT_TEST_MODE_DOC_COUNT} documents." + test_mode_warning_msg = "Due to --number-of-docs set by user, " + f"test-mode docs will be less than the default {self.DEFAULT_TEST_MODE_DOC_COUNT} documents." console.warn(test_mode_warning_msg) # Notify users when they specified more documents than available in index if documents_limit > total_documents: - documents_to_extract_warning_msg = f"User requested extraction of {documents_limit} documents " + \ - f"but there are only {total_documents} documents in {index}. " + \ - f"Will only extract {total_documents} documents from {index}." + documents_to_extract_warning_msg = ( + f"User requested extraction of {documents_limit} documents " + + f"but there are only {total_documents} documents in {index}. " + + f"Will only extract {total_documents} documents from {index}." + ) console.warn(documents_to_extract_warning_msg) if sample_frequency and sample_frequency > 1: @@ -185,7 +184,6 @@ def extract_documents(self, index, documents_limit=None, sample_frequency=None): else: return self.standard_extraction(total_documents, documents_to_extract, index) - def sample_frequency_extraction(self, total_documents, sample_frequency, index): if total_documents > 0: self.logger.info("[%d] total docs in index [%s]. Extracting [%s] docs with sample frequency [%s]", total_documents, index, total_documents, sample_frequency) @@ -195,12 +193,13 @@ def sample_frequency_extraction(self, total_documents, sample_frequency, index): index, self._get_doc_outpath(self.custom_workload.workload_path, index, self.DEFAULT_TEST_MODE_SUFFIX), min(total_documents, self.DEFAULT_TEST_MODE_DOC_COUNT), - " for test mode") + " for test mode", + ) docs_path = self._get_doc_outpath(self.custom_workload.workload_path, index) self.dump_documents_with_sample_frequency(total_documents, sample_frequency, docs_path, index) - amount_of_docs_to_extract = (total_documents // sample_frequency) + amount_of_docs_to_extract = total_documents // sample_frequency return self.template_vars(index, docs_path, amount_of_docs_to_extract) else: self.logger.info("Skipping corpus extraction for index [%s] as it contains no documents.", index) @@ -217,7 +216,8 @@ def standard_extraction(self, total_documents, documents_to_extract, index): index, self._get_doc_outpath(self.custom_workload.workload_path, index, self.DEFAULT_TEST_MODE_SUFFIX), min(documents_to_extract, self.DEFAULT_TEST_MODE_DOC_COUNT), - " for test mode") + " for test mode", + ) # Create full corpora self.dump_documents(self.client, index, docs_path, documents_to_extract) @@ -242,7 +242,7 @@ def dump_documents_with_sample_frequency(self, number_of_docs_in_index, sample_f with open(docs_path, "wb") as outfile: with open(comp_outpath, "wb") as comp_outfile: self.logger.info("Dumping corpus for index [%s] to [%s].", index, docs_path) - progress_bar = tqdm(range(number_of_docs_to_fetch), desc=progress_message, ascii=' >=', bar_format='{l_bar}{bar:10}{r_bar}{bar:-10b}') + progress_bar = tqdm(range(number_of_docs_to_fetch), desc=progress_message, ascii=" >=", bar_format="{l_bar}{bar:10}{r_bar}{bar:-10b}") for n, doc in enumerate(_cursor_scan(self.client, collection=index), start=1): if (n % sample_frequency) != 0: @@ -284,7 +284,6 @@ def dump_documents(self, client, index, docs_path, number_of_docs, progress_mess comp_outfile.write(compressor.flush()) progress.finish() - def render_progress(self, progress, progress_message_suffix, index, cur, total, freq): if cur % freq == 0 or total - cur < freq: msg = f"Extracting documents for index [{index}]{progress_message_suffix}..." diff --git a/solrorbit/workload_generator/helpers.py b/solrorbit/workload_generator/helpers.py index 2886f0a7..5946fc5d 100644 --- a/solrorbit/workload_generator/helpers.py +++ b/solrorbit/workload_generator/helpers.py @@ -26,38 +26,35 @@ DEFAULT_TEST_PROCEDURES = "default-test-procedures" TEMPLATE_EXT = ".json.j2" -class CustomWorkloadWriter: +class CustomWorkloadWriter: def __init__(self, custom_workload: CustomWorkload, templates_path: str): self.custom_workload = custom_workload self.templates_path = templates_path - self.custom_workload.workload_path = os.path.abspath( - os.path.join(io.normalize_path(self.custom_workload.output_path), - self.custom_workload.workload_name)) + self.custom_workload.workload_path = os.path.abspath(os.path.join(io.normalize_path(self.custom_workload.output_path), self.custom_workload.workload_name)) self.custom_workload.operations_path = os.path.join(self.custom_workload.workload_path, "operations") self.custom_workload.test_procedures_path = os.path.join(self.custom_workload.workload_path, "test_procedures") self.logger = logging.getLogger(__name__) def make_workload_directory(self): if not self._has_write_permission(self.custom_workload.workload_path): - error_suggestion = "Workload output path does not have write permissions. " \ - + "Please update the permissions for the specified output path or choose a different output path." + error_suggestion = ( + "Workload output path does not have write permissions. " + "Please update the permissions for the specified output path or choose a different output path." + ) self.logger.error(error_suggestion) console.error(error_suggestion) # Check if a workload of the same name already exists in output path if os.path.exists(self.custom_workload.workload_path): try: - input_text = f"A workload already exists at {self.custom_workload.workload_path}. " \ - + "Would you like to remove it? (y/n): " + input_text = f"A workload already exists at {self.custom_workload.workload_path}. " + "Would you like to remove it? (y/n): " user_decision = input(input_text) - while user_decision not in ('y', 'n'): + while user_decision not in ("y", "n"): user_decision = input("Provide y for yes or n for no. " + input_text) if user_decision == "y": - self.logger.info("Removing existing workload [%s] in path [%s]", - self.custom_workload.workload_name, self.custom_workload.workload_path) + self.logger.info("Removing existing workload [%s] in path [%s]", self.custom_workload.workload_name, self.custom_workload.workload_path) console.info("Removing workload of the same name.") shutil.rmtree(self.custom_workload.workload_path) elif user_decision == "n": @@ -68,8 +65,7 @@ def make_workload_directory(self): sys.exit(0) except OSError: - self.logger.error("Had issues removing existing workload [%s] in path [%s]", - self.custom_workload.workload_name, self.custom_workload.workload_path) + self.logger.error("Had issues removing existing workload [%s] in path [%s]", self.custom_workload.workload_name, self.custom_workload.workload_path) io.ensure_dir(self.custom_workload.workload_path) io.ensure_dir(self.custom_workload.operations_path) @@ -79,7 +75,7 @@ def write_custom_workload_record(self, template_vars): filename = f"{self.custom_workload.workload_path}/{self.custom_workload.workload_name}_record.json" try: self.logger.info("Writing custom workload record to filepath [%s]", filename) - with open(filename, 'w') as file: + with open(filename, "w") as file: json.dump(template_vars, file) except Exception as e: self.logger.error("Could not write to file as CustomWorkloadWriter encountered an error: [%s]", e) @@ -110,12 +106,13 @@ def _write_template(self, template_vars: dict, template_file: str, output_path: f.write(template.render(template_vars)) def _get_default_template(self, template_file: str): - template_file_name = template_file + TEMPLATE_EXT + template_file_name = template_file + TEMPLATE_EXT - env = Environment(loader=FileSystemLoader(self.templates_path), autoescape=select_autoescape(['html', 'xml'])) + env = Environment(loader=FileSystemLoader(self.templates_path), autoescape=select_autoescape(["html", "xml"])) return env.get_template(template_file_name) + class QueryProcessor: def __init__(self, queries: str): self.queries = queries @@ -134,6 +131,7 @@ def process_queries(self): return processed_queries + def process_indices(indices, sample_frequency_mapping, indices_docs_mapping): processed_indices = [] for index_name in indices: @@ -143,22 +141,16 @@ def process_indices(indices, sample_frequency_mapping, indices_docs_mapping): if indices_docs_mapping and index_name in indices_docs_mapping: number_of_docs_for_index = int(indices_docs_mapping[index_name]) if number_of_docs_for_index <= 0: - raise exceptions.SystemSetupError( - "Values specified with --number-of-docs must be greater than 0") + raise exceptions.SystemSetupError("Values specified with --number-of-docs must be greater than 0") # Do this if sample frequency is specified sample_frequency_for_index = None if sample_frequency_mapping and index_name in sample_frequency_mapping: sample_frequency_for_index = int(sample_frequency_mapping[index_name]) if sample_frequency_for_index <= 1: - raise exceptions.SystemSetupError( - "Values specified with --sample-frequency must be greater than 1") + raise exceptions.SystemSetupError("Values specified with --sample-frequency must be greater than 1") - index = Index( - name=index_name, - sample_frequency=sample_frequency_for_index, - number_of_docs=number_of_docs_for_index - ) + index = Index(name=index_name, sample_frequency=sample_frequency_for_index, number_of_docs=number_of_docs_for_index) processed_indices.append(index) except ValueError as e: @@ -166,6 +158,7 @@ def process_indices(indices, sample_frequency_mapping, indices_docs_mapping): return processed_indices + def validate_index_documents_map(indices, indices_docs_map): logger = logging.getLogger(__name__) logger.info("Indices Docs Map: [%s]", indices_docs_map) @@ -175,17 +168,17 @@ def validate_index_documents_map(indices, indices_docs_map): if len(indices) < len(indices_docs_map): raise exceptions.SystemSetupError( - "Number of : pairs in --number-of-docs exceeds number of indices in --indices. " + - "Ensure number of : pairs is less than or equal to number of indices." + "Number of : pairs in --number-of-docs exceeds number of indices in --indices. " + + "Ensure number of : pairs is less than or equal to number of indices." ) for index_name in indices_docs_map: if index_name not in indices: raise exceptions.SystemSetupError( - f"Index {index_name} provided in --number-of-docs was not found in --indices. " + - "Ensure that all indices in --number-of-docs are present in --indices." + f"Index {index_name} provided in --number-of-docs was not found in --indices. " + "Ensure that all indices in --number-of-docs are present in --indices." ) + def validate_sample_frequency_mapping(indices, sample_frequency_mapping): sample_frequency_enabled = sample_frequency_mapping is not None and len(sample_frequency_mapping) > 0 @@ -194,13 +187,12 @@ def validate_sample_frequency_mapping(indices, sample_frequency_mapping): if len(indices) < len(sample_frequency_mapping): raise exceptions.SystemSetupError( - "Number of : pairs exceeds number of indices in --indices. " + - "Ensure number of : pairs is less than or equal to number of indices in --indices." + "Number of : pairs exceeds number of indices in --indices. " + + "Ensure number of : pairs is less than or equal to number of indices in --indices." ) for index_name in sample_frequency_mapping: if index_name not in indices: raise exceptions.SystemSetupError( - "Index from : pair was not found in --indices. " + - "Ensure that indices from all : pairs exist in --indices." + "Index from : pair was not found in --indices. " + "Ensure that indices from all : pairs exist in --indices." ) diff --git a/solrorbit/workload_generator/workload_generator.py b/solrorbit/workload_generator/workload_generator.py index 8bed1c5c..79cf7685 100644 --- a/solrorbit/workload_generator/workload_generator.py +++ b/solrorbit/workload_generator/workload_generator.py @@ -16,6 +16,7 @@ from solrorbit.workload_generator.extractors import IndexExtractor, SequentialCorpusExtractor from solrorbit.utils import io, opts, console + def create_workload(cfg): logger = logging.getLogger(__name__) @@ -36,8 +37,7 @@ def create_workload(cfg): validate_index_documents_map(indices, number_of_docs) validate_sample_frequency_mapping(indices, sample_frequency_mapping) - client = ClientFactory(hosts=target_hosts.all_hosts[opts.TargetHosts.DEFAULT], - client_options=client_options.all_client_options[opts.TargetHosts.DEFAULT]).create() + client = ClientFactory(hosts=target_hosts.all_hosts[opts.TargetHosts.DEFAULT], client_options=client_options.all_client_options[opts.TargetHosts.DEFAULT]).create() info = client.info() console.info(f"Connected to Solr cluster [{info['name']}] version [{info['version']['number']}].\n", logger=logger) @@ -84,7 +84,7 @@ def create_workload(cfg): "workload_name": custom_workload.workload_name, "indices": custom_workload.extracted_indices, "corpora": custom_workload.corpora, - "custom_queries": custom_workload.queries + "custom_queries": custom_workload.queries, } logger.info("Template vars [%s]", template_vars) diff --git a/tests/__init__.py b/tests/__init__.py index feca5bd9..e9038183 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -16,7 +16,7 @@ # not use this file except in compliance with the License. # You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an @@ -34,8 +34,10 @@ def run_async(t): :param t: The test case to wrap. """ + def async_wrapper(*args, **kwargs): asyncio.run(t(*args, **kwargs), debug=True) + return async_wrapper diff --git a/tests/aggregator_test.py b/tests/aggregator_test.py index 7bc046ad..29394c79 100644 --- a/tests/aggregator_test.py +++ b/tests/aggregator_test.py @@ -3,49 +3,48 @@ from solrorbit import config from solrorbit.aggregator import Aggregator, AggregatedResults + @pytest.fixture def mock_config(): mock_cfg = Mock(spec=config.Config) mock_cfg.opts.side_effect = lambda *args: "test_procedure_name" if args == ("workload", "test_procedure.name") else "/path/to/root" return mock_cfg + @pytest.fixture def mock_test_runs(): - return { - "test1": Mock(), - "test2": Mock() - } + return {"test1": Mock(), "test2": Mock()} + @pytest.fixture def mock_args(): - return Mock( - results_file="", - test_run_id="", - workload_repository="default" - ) + return Mock(results_file="", test_run_id="", workload_repository="default") + @pytest.fixture def mock_test_store(): mock_store = Mock() mock_store.find_by_test_run_id.side_effect = [ Mock(results={"key1": {"nested": 10}}, workload="workload1", test_procedure="test_proc1"), - Mock(results={"key1": {"nested": 20}}, workload="workload1", test_procedure="test_proc1") + Mock(results={"key1": {"nested": 20}}, workload="workload1", test_procedure="test_proc1"), ] return mock_store + @pytest.fixture def aggregator(mock_config, mock_test_runs, mock_args, mock_test_store): aggregator = Aggregator(mock_config, mock_test_runs, mock_args) aggregator.test_store = mock_test_store return aggregator + def test_count_iterations_for_each_op(aggregator): mock_workload = Mock() - mock_task = Mock(spec=['name', 'iterations']) + mock_task = Mock(spec=["name", "iterations"]) mock_task.name = "op1" mock_task.iterations = 5 mock_schedule = [mock_task] - mock_test_procedure = Mock(spec=['name', 'schedule']) + mock_test_procedure = Mock(spec=["name", "schedule"]) mock_test_procedure.name = "test_procedure_name" mock_test_procedure.schedule = mock_schedule mock_workload.test_procedures = [mock_test_procedure] @@ -62,6 +61,7 @@ def test_count_iterations_for_each_op(aggregator): assert "op1" in aggregator.accumulated_iterations["test1"], "op1 not found in accumulated_iterations for test1" assert aggregator.accumulated_iterations["test1"]["op1"] == 5 + def test_accumulate_results(aggregator): mock_test_run = Mock() mock_test_run.results = { @@ -74,7 +74,7 @@ def test_accumulate_results(aggregator): "client_processing_time": 2, "processing_time": 3, "error_rate": 0.1, - "duration": 60 + "duration": 60, } ] } @@ -84,6 +84,7 @@ def test_accumulate_results(aggregator): assert "task1" in aggregator.accumulated_results assert all(metric in aggregator.accumulated_results["task1"] for metric in aggregator.metrics) + def test_test_run_compatibility_check(aggregator): mock_test_store = Mock() mock_test_store.find_by_test_run_id.side_effect = [ @@ -96,22 +97,18 @@ def test_test_run_compatibility_check(aggregator): assert aggregator.test_run_compatibility_check() + def test_aggregate_json_by_key(aggregator): result = aggregator.aggregate_json_by_key("key1.nested") assert result == 15 + def test_calculate_weighted_average(aggregator): - task_metrics = { - "throughput": [100, 200], - "latency": [{"avg": 10, "unit": "ms"}, {"avg": 20, "unit": "ms"}] - } + task_metrics = {"throughput": [100, 200], "latency": [{"avg": 10, "unit": "ms"}, {"avg": 20, "unit": "ms"}]} task_name = "op1" # set up accumulated_iterations - aggregator.accumulated_iterations = { - "test1": {"op1": 2}, - "test2": {"op1": 3} - } + aggregator.accumulated_iterations = {"test1": {"op1": 2}, "test2": {"op1": 3}} aggregator.test_runs = {"test1": Mock(), "test2": Mock()} result = aggregator.calculate_weighted_average(task_metrics, task_name) @@ -120,11 +117,13 @@ def test_calculate_weighted_average(aggregator): assert result["latency"]["avg"] == 16 # (10*2 + 20*3) / (2+3) assert result["latency"]["unit"] == "ms" + def test_calculate_rsd(aggregator): values = [1, 2, 3, 4, 5] rsd = aggregator.calculate_rsd(values, "test_metric") assert isinstance(rsd, float) + def test_test_run_compatibility_check_incompatible(aggregator): mock_test_store = Mock() mock_test_store.find_by_test_run_id.side_effect = [ @@ -136,6 +135,7 @@ def test_test_run_compatibility_check_incompatible(aggregator): with pytest.raises(ValueError): aggregator.test_run_compatibility_check() + def test_aggregated_results(): results = {"key": "value"} agg_results = AggregatedResults(results) diff --git a/tests/builder/__init__.py b/tests/builder/__init__.py index 5047a451..f5768141 100644 --- a/tests/builder/__init__.py +++ b/tests/builder/__init__.py @@ -16,7 +16,7 @@ # not use this file except in compliance with the License. # You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an diff --git a/tests/builder/cluster_config_test.py b/tests/builder/cluster_config_test.py index 19d72281..e37dbbdf 100644 --- a/tests/builder/cluster_config_test.py +++ b/tests/builder/cluster_config_test.py @@ -16,7 +16,7 @@ # not use this file except in compliance with the License. # You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an @@ -47,91 +47,66 @@ def setUp(self): def test_lists_cluster_config_names(self): # contrary to the name this assertion compares contents but does not care about order. self.assertCountEqual( - ["default", "with_hook", "32gheap", "missing_cfg_base", "empty_cfg_base", "ea", "verbose", "multi_hook", "another_with_hook"], - self.loader.cluster_config_names() + ["default", "with_hook", "32gheap", "missing_cfg_base", "empty_cfg_base", "ea", "verbose", "multi_hook", "another_with_hook"], self.loader.cluster_config_names() ) def test_load_known_cluster_config(self): - loaded_cluster_config = cluster_config.load_cluster_config( - self.cluster_config_dir, ["default"], - cluster_config_params={"data_paths": ["/mnt/disk0", "/mnt/disk1"]}) + loaded_cluster_config = cluster_config.load_cluster_config(self.cluster_config_dir, ["default"], cluster_config_params={"data_paths": ["/mnt/disk0", "/mnt/disk1"]}) self.assertEqual("default", loaded_cluster_config.name) - self.assertEqual( - [os.path.join(current_dir, "data", "cluster_configs", "v1", "vanilla", "templates")], - loaded_cluster_config.config_paths) + self.assertEqual([os.path.join(current_dir, "data", "cluster_configs", "v1", "vanilla", "templates")], loaded_cluster_config.config_paths) self.assertIsNone(loaded_cluster_config.root_path) - self.assertDictEqual({ - "heap_size": "1g", - "clean_command": "./gradlew clean", - "data_paths": ["/mnt/disk0", "/mnt/disk1"] - }, loaded_cluster_config.variables) + self.assertDictEqual({"heap_size": "1g", "clean_command": "./gradlew clean", "data_paths": ["/mnt/disk0", "/mnt/disk1"]}, loaded_cluster_config.variables) self.assertIsNone(loaded_cluster_config.root_path) def test_load_cluster_config_with_mixin_single_config_base(self): loaded_cluster_config = cluster_config.load_cluster_config(self.cluster_config_dir, ["32gheap", "ea"]) self.assertEqual("32gheap+ea", loaded_cluster_config.name) - self.assertEqual( - [os.path.join(current_dir, "data", "cluster_configs", "v1", "vanilla", "templates")], - loaded_cluster_config.config_paths) + self.assertEqual([os.path.join(current_dir, "data", "cluster_configs", "v1", "vanilla", "templates")], loaded_cluster_config.config_paths) self.assertIsNone(loaded_cluster_config.root_path) - self.assertEqual({ - "heap_size": "32g", - "clean_command": "./gradlew clean", - "assertions": "true" - }, loaded_cluster_config.variables) + self.assertEqual({"heap_size": "32g", "clean_command": "./gradlew clean", "assertions": "true"}, loaded_cluster_config.variables) self.assertIsNone(loaded_cluster_config.root_path) def test_load_cluster_config_with_mixin_multiple_config_bases(self): loaded_cluster_config = cluster_config.load_cluster_config(self.cluster_config_dir, ["32gheap", "ea", "verbose"]) self.assertEqual("32gheap+ea+verbose", loaded_cluster_config.name) - self.assertEqual([ - os.path.join(current_dir, "data", "cluster_configs", "v1", "vanilla", "templates"), - os.path.join(current_dir, "data", "cluster_configs", "v1", "verbose_logging", "templates"), - ], loaded_cluster_config.config_paths) + self.assertEqual( + [ + os.path.join(current_dir, "data", "cluster_configs", "v1", "vanilla", "templates"), + os.path.join(current_dir, "data", "cluster_configs", "v1", "verbose_logging", "templates"), + ], + loaded_cluster_config.config_paths, + ) self.assertIsNone(loaded_cluster_config.root_path) - self.assertEqual({ - "heap_size": "32g", - "clean_command": "./gradlew clean", - "verbose_logging": "true", - "assertions": "true" - }, loaded_cluster_config.variables) + self.assertEqual({"heap_size": "32g", "clean_command": "./gradlew clean", "verbose_logging": "true", "assertions": "true"}, loaded_cluster_config.variables) def test_load_cluster_config_with_install_hook(self): loaded_cluster_config = cluster_config.load_cluster_config( - self.cluster_config_dir, - ["default", "with_hook"], - cluster_config_params={"data_paths": ["/mnt/disk0", "/mnt/disk1"]}) + self.cluster_config_dir, ["default", "with_hook"], cluster_config_params={"data_paths": ["/mnt/disk0", "/mnt/disk1"]} + ) self.assertEqual("default+with_hook", loaded_cluster_config.name) - self.assertEqual([ - os.path.join(current_dir, "data", "cluster_configs", "v1", "vanilla", "templates"), - os.path.join(current_dir, "data", "cluster_configs", "v1", "with_hook", "templates"), - ], loaded_cluster_config.config_paths) self.assertEqual( - os.path.join(current_dir, "data", "cluster_configs", "v1", "with_hook"), - loaded_cluster_config.root_path) - self.assertDictEqual({ - "heap_size": "1g", - "clean_command": "./gradlew clean", - "data_paths": ["/mnt/disk0", "/mnt/disk1"] - }, loaded_cluster_config.variables) + [ + os.path.join(current_dir, "data", "cluster_configs", "v1", "vanilla", "templates"), + os.path.join(current_dir, "data", "cluster_configs", "v1", "with_hook", "templates"), + ], + loaded_cluster_config.config_paths, + ) + self.assertEqual(os.path.join(current_dir, "data", "cluster_configs", "v1", "with_hook"), loaded_cluster_config.root_path) + self.assertDictEqual({"heap_size": "1g", "clean_command": "./gradlew clean", "data_paths": ["/mnt/disk0", "/mnt/disk1"]}, loaded_cluster_config.variables) def test_load_cluster_config_with_multiple_bases_referring_same_install_hook(self): - loaded_cluster_config = cluster_config.load_cluster_config( - self.cluster_config_dir, ["with_hook", "another_with_hook"]) + loaded_cluster_config = cluster_config.load_cluster_config(self.cluster_config_dir, ["with_hook", "another_with_hook"]) self.assertEqual("with_hook+another_with_hook", loaded_cluster_config.name) - self.assertEqual([ - os.path.join(current_dir, "data", "cluster_configs", "v1", "vanilla", "templates"), - os.path.join(current_dir, "data", "cluster_configs", "v1", "with_hook", "templates"), - os.path.join(current_dir, "data", "cluster_configs", "v1", "verbose_logging", "templates") - ], loaded_cluster_config.config_paths) self.assertEqual( - os.path.join(current_dir, "data", "cluster_configs", "v1", "with_hook"), - loaded_cluster_config.root_path) - self.assertDictEqual({ - "heap_size": "16g", - "clean_command": "./gradlew clean", - "verbose_logging": "true" - }, loaded_cluster_config.variables) + [ + os.path.join(current_dir, "data", "cluster_configs", "v1", "vanilla", "templates"), + os.path.join(current_dir, "data", "cluster_configs", "v1", "with_hook", "templates"), + os.path.join(current_dir, "data", "cluster_configs", "v1", "verbose_logging", "templates"), + ], + loaded_cluster_config.config_paths, + ) + self.assertEqual(os.path.join(current_dir, "data", "cluster_configs", "v1", "with_hook"), loaded_cluster_config.root_path) + self.assertDictEqual({"heap_size": "16g", "clean_command": "./gradlew clean", "verbose_logging": "true"}, loaded_cluster_config.variables) def test_raises_error_on_unknown_cluster_config(self): with self.assertRaises(exceptions.SystemSetupError) as ctx: @@ -139,7 +114,8 @@ def test_raises_error_on_unknown_cluster_config(self): self.assertRegex( ctx.exception.args[0], r"Unknown cluster-config \[don_t-know-you\]. " - r"List the available cluster-configs with [^\s]+ list cluster-configs.") + r"List the available cluster-configs with [^\s]+ list cluster-configs.", + ) def test_raises_error_on_empty_config_base(self): with self.assertRaises(exceptions.SystemSetupError) as ctx: @@ -154,9 +130,7 @@ def test_raises_error_on_missing_config_base(self): def test_raises_error_if_more_than_one_different_install_hook(self): with self.assertRaises(exceptions.SystemSetupError) as ctx: cluster_config.load_cluster_config(self.cluster_config_dir, ["multi_hook"]) - self.assertEqual( - "Invalid cluster_config: ['multi_hook']. Multiple bootstrap hooks are forbidden.", - ctx.exception.args[0]) + self.assertEqual("Invalid cluster_config: ['multi_hook']. Multiple bootstrap hooks are forbidden.", ctx.exception.args[0]) class PluginLoaderTests(TestCase): @@ -174,12 +148,16 @@ def test_lists_plugins(self): cluster_config.PluginDescriptor(name="complex-plugin", config="config-b"), cluster_config.PluginDescriptor(name="my-analysis-plugin", core_plugin=True), cluster_config.PluginDescriptor(name="my-ingest-plugin", core_plugin=True), - cluster_config.PluginDescriptor(name="my-core-plugin-with-config", core_plugin=True) - ], self.loader.plugins()) + cluster_config.PluginDescriptor(name="my-core-plugin-with-config", core_plugin=True), + ], + self.loader.plugins(), + ) def test_loads_core_plugin(self): - self.assertEqual(cluster_config.PluginDescriptor(name="my-analysis-plugin", core_plugin=True, variables={"dbg": True}), - self.loader.load_plugin("my-analysis-plugin", config_names=None, plugin_params={"dbg": True})) + self.assertEqual( + cluster_config.PluginDescriptor(name="my-analysis-plugin", core_plugin=True, variables={"dbg": True}), + self.loader.load_plugin("my-analysis-plugin", config_names=None, plugin_params={"dbg": True}), + ) def test_loads_core_plugin_with_config(self): plugin = self.loader.load_plugin("my-core-plugin-with-config", config_names=None, plugin_params={"dbg": True}) @@ -191,17 +169,23 @@ def test_loads_core_plugin_with_config(self): self.assertEqual(expected_root_path, plugin.root_path) self.assertEqual(0, len(plugin.config_paths)) - self.assertEqual({ - # from plugin params - "dbg": True - }, plugin.variables) + self.assertEqual( + { + # from plugin params + "dbg": True + }, + plugin.variables, + ) def test_cannot_load_plugin_with_missing_config(self): with self.assertRaises(exceptions.SystemSetupError) as ctx: self.loader.load_plugin("my-analysis-plugin", ["missing-config"]) - self.assertRegex(ctx.exception.args[0], r"Plugin \[my-analysis-plugin\] does not provide configuration \[missing-config\]. List the" - r" available plugins and configurations with [^\s]+ list cluster-configs " - r"--distribution-version=VERSION.") + self.assertRegex( + ctx.exception.args[0], + r"Plugin \[my-analysis-plugin\] does not provide configuration \[missing-config\]. List the" + r" available plugins and configurations with [^\s]+ list cluster-configs " + r"--distribution-version=VERSION.", + ) def test_loads_community_plugin_without_configuration(self): self.assertEqual(cluster_config.PluginDescriptor("my-community-plugin"), self.loader.load_plugin("my-community-plugin", None)) @@ -209,8 +193,11 @@ def test_loads_community_plugin_without_configuration(self): def test_cannot_load_community_plugin_with_missing_config(self): with self.assertRaises(exceptions.SystemSetupError) as ctx: self.loader.load_plugin("my-community-plugin", "some-configuration") - self.assertRegex(ctx.exception.args[0], r"Unknown plugin \[my-community-plugin\]. List the available plugins with [^\s]+ list " - r"cluster-configs --distribution-version=VERSION.") + self.assertRegex( + ctx.exception.args[0], + r"Unknown plugin \[my-community-plugin\]. List the available plugins with [^\s]+ list " + r"cluster-configs --distribution-version=VERSION.", + ) def test_loads_configured_plugin(self): plugin = self.loader.load_plugin("complex-plugin", ["config-a", "config-b"], plugin_params={"dbg": True}) @@ -222,19 +209,25 @@ def test_loads_configured_plugin(self): self.assertEqual(expected_root_path, plugin.root_path) # order does matter here! We should not swap it - self.assertListEqual([ - os.path.join(expected_root_path, "default", "templates"), - os.path.join(expected_root_path, "special", "templates"), - ], plugin.config_paths) + self.assertListEqual( + [ + os.path.join(expected_root_path, "default", "templates"), + os.path.join(expected_root_path, "special", "templates"), + ], + plugin.config_paths, + ) - self.assertEqual({ - "foo": "bar", - "baz": "foo", - "var": "0", - "hello": "true", - # from plugin params - "dbg": True - }, plugin.variables) + self.assertEqual( + { + "foo": "bar", + "baz": "foo", + "var": "0", + "hello": "true", + # from plugin params + "dbg": True, + }, + plugin.variables, + ) class BootstrapHookHandlerTests(TestCase): @@ -282,5 +275,4 @@ def test_cannot_register_for_unknown_phase(self): handler.loader.registration_function = hook with self.assertRaises(exceptions.SystemSetupError) as ctx: handler.load() - self.assertEqual("Unknown bootstrap phase [this_is_an_unknown_install_phase]. Valid phases are: ['post_install'].", - ctx.exception.args[0]) + self.assertEqual("Unknown bootstrap phase [this_is_an_unknown_install_phase]. Valid phases are: ['post_install'].", ctx.exception.args[0]) diff --git a/tests/builder/configs/listers/plugin_config_instance_lister_test.py b/tests/builder/configs/listers/plugin_config_instance_lister_test.py index 9ac48922..a464777a 100644 --- a/tests/builder/configs/listers/plugin_config_instance_lister_test.py +++ b/tests/builder/configs/listers/plugin_config_instance_lister_test.py @@ -18,10 +18,13 @@ def test_list_plugin_config_instances(self): plugin_config_instances = self.plugin_config_instance_lister.list_plugin_config_instances() print(plugin_config_instances) - self.assertEqual(plugin_config_instances, [ - PluginConfigInstance(name="complex-plugin", format_version="v1", config_names=["config-a"]), - PluginConfigInstance(name="complex-plugin", format_version="v1", config_names=["config-b"]), - PluginConfigInstance(name="my-analysis-plugin", format_version="v1", is_core_plugin=True), - PluginConfigInstance(name="my-core-plugin-with-config", format_version="v1", is_core_plugin=True), - PluginConfigInstance(name="my-ingest-plugin", format_version="v1", is_core_plugin=True) - ]) + self.assertEqual( + plugin_config_instances, + [ + PluginConfigInstance(name="complex-plugin", format_version="v1", config_names=["config-a"]), + PluginConfigInstance(name="complex-plugin", format_version="v1", config_names=["config-b"]), + PluginConfigInstance(name="my-analysis-plugin", format_version="v1", is_core_plugin=True), + PluginConfigInstance(name="my-core-plugin-with-config", format_version="v1", is_core_plugin=True), + PluginConfigInstance(name="my-ingest-plugin", format_version="v1", is_core_plugin=True), + ], + ) diff --git a/tests/builder/configs/utils/config_path_resolver_test.py b/tests/builder/configs/utils/config_path_resolver_test.py index 2a73ad24..5caa0fad 100644 --- a/tests/builder/configs/utils/config_path_resolver_test.py +++ b/tests/builder/configs/utils/config_path_resolver_test.py @@ -13,7 +13,7 @@ def setUp(self): self.cfg = Mock() self.config_path_resolver = ConfigPathResolver(self.cfg) - @mock.patch('os.path.exists') + @mock.patch("os.path.exists") def test_cluster_config_path_defined(self, path_exists): path_exists.return_value = True # opts("builder", "cluster_config.path") @@ -22,10 +22,10 @@ def test_cluster_config_path_defined(self, path_exists): config_path = self.config_path_resolver.resolve_config_path(self.config_type, self.config_format_version) self.assertEqual(config_path, "/path/to/configs/red/v36") - @mock.patch('solrorbit.utils.git.fetch') - @mock.patch('solrorbit.utils.repo.BenchmarkRepository') - @mock.patch('solrorbit.utils.repo.BenchmarkRepository.set_cluster_configs_dir') - @mock.patch('os.path.exists') + @mock.patch("solrorbit.utils.git.fetch") + @mock.patch("solrorbit.utils.repo.BenchmarkRepository") + @mock.patch("solrorbit.utils.repo.BenchmarkRepository.set_cluster_configs_dir") + @mock.patch("os.path.exists") def test_cluster_config_path_not_defined(self, path_exists, set_repo, benchmark_repo, git_fetch): path_exists.return_value = True @@ -37,7 +37,7 @@ def test_cluster_config_path_not_defined(self, path_exists, set_repo, benchmark_ config_path = self.config_path_resolver.resolve_config_path(self.config_type, self.config_format_version) self.assertEqual(config_path, "/root_dir/repo_dir/fake-repo/red/v36") - @mock.patch('os.path.exists') + @mock.patch("os.path.exists") def test_cluster_config_path_does_not_exist(self, path_exists): path_exists.return_value = False # opts("builder", "cluster_config.path") diff --git a/tests/builder/data/cluster_configs/v1/hook2/config.py b/tests/builder/data/cluster_configs/v1/hook2/config.py index 09dec81a..bf854c3f 100644 --- a/tests/builder/data/cluster_configs/v1/hook2/config.py +++ b/tests/builder/data/cluster_configs/v1/hook2/config.py @@ -16,7 +16,7 @@ # not use this file except in compliance with the License. # You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an diff --git a/tests/builder/data/cluster_configs/v1/with_hook/config.py b/tests/builder/data/cluster_configs/v1/with_hook/config.py index 09dec81a..bf854c3f 100644 --- a/tests/builder/data/cluster_configs/v1/with_hook/config.py +++ b/tests/builder/data/cluster_configs/v1/with_hook/config.py @@ -16,7 +16,7 @@ # not use this file except in compliance with the License. # You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an diff --git a/tests/builder/downloaders/builders/source_binary_builder_test.py b/tests/builder/downloaders/builders/source_binary_builder_test.py index c8011ef1..59acf55a 100644 --- a/tests/builder/downloaders/builders/source_binary_builder_test.py +++ b/tests/builder/downloaders/builders/source_binary_builder_test.py @@ -18,26 +18,23 @@ def setUp(self): self.build_jdk_version = 13 self.log_dir = "/benchmark/logs" - self.source_binary_builder = SourceBinaryBuilder(self.executor, self.path_manager, self.jdk_resolver, - self.os_src_dir, self.build_jdk_version, self.log_dir) + self.source_binary_builder = SourceBinaryBuilder(self.executor, self.path_manager, self.jdk_resolver, self.os_src_dir, self.build_jdk_version, self.log_dir) self.jdk_resolver.resolve_jdk_path.return_value = (13, "/path/to/jdk") def test_build(self): self.source_binary_builder.build(self.host, self.build_commands) - self.executor.execute.assert_has_calls([ - mock.call(self.host, "export JAVA_HOME=/path/to/jdk"), - mock.call(self.host, "/fake/src/dir/gradle build > /benchmark/logs/build.log 2>&1") - ]) + self.executor.execute.assert_has_calls( + [mock.call(self.host, "export JAVA_HOME=/path/to/jdk"), mock.call(self.host, "/fake/src/dir/gradle build > /benchmark/logs/build.log 2>&1")] + ) def test_build_with_src_dir_override(self): self.source_binary_builder.build(self.host, self.build_commands, "/override/src") - self.executor.execute.assert_has_calls([ - mock.call(self.host, "export JAVA_HOME=/path/to/jdk"), - mock.call(self.host, "/override/src/gradle build > /benchmark/logs/build.log 2>&1") - ]) + self.executor.execute.assert_has_calls( + [mock.call(self.host, "export JAVA_HOME=/path/to/jdk"), mock.call(self.host, "/override/src/gradle build > /benchmark/logs/build.log 2>&1")] + ) def test_build_failure(self): # Set JAVA_HOME, execute build command diff --git a/tests/builder/downloaders/distribution_downloader_test.py b/tests/builder/downloaders/distribution_downloader_test.py index 8364e1d5..308ee47a 100644 --- a/tests/builder/downloaders/distribution_downloader_test.py +++ b/tests/builder/downloaders/distribution_downloader_test.py @@ -11,22 +11,13 @@ def setUp(self): self.host = None self.executor = Mock() - self.cluster_config = ClusterConfigInstance(names="fake", root_path="also fake", config_paths="fake2", variables={ - "node": { - "root": { - "dir": "/fake/dir/for/download" - } - }, - "distribution": { - "version": "1.2.3" - } - }) + self.cluster_config = ClusterConfigInstance( + names="fake", root_path="also fake", config_paths="fake2", variables={"node": {"root": {"dir": "/fake/dir/for/download"}}, "distribution": {"version": "1.2.3"}} + ) self.path_manager = Mock() self.distribution_repository_provider = Mock() - self.os_distro_downloader = DistributionDownloader(self.cluster_config, self.executor, self.path_manager, - self.distribution_repository_provider) - + self.os_distro_downloader = DistributionDownloader(self.cluster_config, self.executor, self.path_manager, self.distribution_repository_provider) self.os_distro_downloader.distribution_repository_provider.get_download_url.return_value = "https://fake/download.tar.gz" self.os_distro_downloader.distribution_repository_provider.get_file_name_from_download_url.return_value = "my-distro" @@ -39,10 +30,12 @@ def test_download_distro(self): binary_map = self.os_distro_downloader.download(self.host) self.assertEqual(binary_map, {"solr": "/fake/dir/for/download/distributions/my-distro"}) - self.executor.execute.assert_has_calls([ - mock.call(self.host, "test -f /fake/dir/for/download/distributions/my-distro"), - mock.call(self.host, "curl -o /fake/dir/for/download/distributions/my-distro https://fake/download.tar.gz") - ]) + self.executor.execute.assert_has_calls( + [ + mock.call(self.host, "test -f /fake/dir/for/download/distributions/my-distro"), + mock.call(self.host, "curl -o /fake/dir/for/download/distributions/my-distro https://fake/download.tar.gz"), + ] + ) def test_download_distro_exists_and_cache_enabled(self): # Check if file exists, download via curl @@ -51,9 +44,7 @@ def test_download_distro_exists_and_cache_enabled(self): binary_map = self.os_distro_downloader.download(self.host) self.assertEqual(binary_map, {"solr": "/fake/dir/for/download/distributions/my-distro"}) - self.executor.execute.assert_has_calls([ - mock.call(self.host, "test -f /fake/dir/for/download/distributions/my-distro") - ]) + self.executor.execute.assert_has_calls([mock.call(self.host, "test -f /fake/dir/for/download/distributions/my-distro")]) def test_download_distro_exists_and_cache_disabled(self): self.os_distro_downloader.distribution_repository_provider.is_cache_enabled.return_value = False @@ -63,7 +54,9 @@ def test_download_distro_exists_and_cache_disabled(self): binary_map = self.os_distro_downloader.download(self.host) self.assertEqual(binary_map, {"solr": "/fake/dir/for/download/distributions/my-distro"}) - self.executor.execute.assert_has_calls([ - mock.call(self.host, "test -f /fake/dir/for/download/distributions/my-distro"), - mock.call(self.host, "curl -o /fake/dir/for/download/distributions/my-distro https://fake/download.tar.gz") - ]) + self.executor.execute.assert_has_calls( + [ + mock.call(self.host, "test -f /fake/dir/for/download/distributions/my-distro"), + mock.call(self.host, "curl -o /fake/dir/for/download/distributions/my-distro https://fake/download.tar.gz"), + ] + ) diff --git a/tests/builder/downloaders/repositories/distribution_repository_provider_test.py b/tests/builder/downloaders/repositories/distribution_repository_provider_test.py index 1b16db33..10393b56 100644 --- a/tests/builder/downloaders/repositories/distribution_repository_provider_test.py +++ b/tests/builder/downloaders/repositories/distribution_repository_provider_test.py @@ -1,35 +1,25 @@ from unittest import TestCase, mock from unittest.mock import Mock -from solrorbit.builder.downloaders.repositories.distribution_repository_provider import \ - DistributionRepositoryProvider +from solrorbit.builder.downloaders.repositories.distribution_repository_provider import DistributionRepositoryProvider from solrorbit.builder.cluster_config import ClusterConfigInstance class DistributionRepositoryProviderTest(TestCase): def setUp(self): self.host = None - self.cluster_config = ClusterConfigInstance(names=None, config_paths=None, root_path=None, variables={ - "distribution": { - "repository": "release", - "release": { - "cache": True - } - } - }) + self.cluster_config = ClusterConfigInstance( + names=None, config_paths=None, root_path=None, variables={"distribution": {"repository": "release", "release": {"cache": True}}} + ) self.repository_url_provider = Mock() - self.os_distro_repo_provider = DistributionRepositoryProvider(self.cluster_config, - self.repository_url_provider) + self.os_distro_repo_provider = DistributionRepositoryProvider(self.cluster_config, self.repository_url_provider) def test_get_download_url(self): self.os_distro_repo_provider.get_download_url(self.host) - self.os_distro_repo_provider.repository_url_provider.render_url_for_key.assert_has_calls([ - mock.call(None, self.cluster_config.variables, "distribution.release_url") - ]) + self.os_distro_repo_provider.repository_url_provider.render_url_for_key.assert_has_calls([mock.call(None, self.cluster_config.variables, "distribution.release_url")]) def test_get_file_name(self): - file_name = self.os_distro_repo_provider.get_file_name_from_download_url( - "https://archive.apache.org/dist/solr/solr/9.10.1/solr-9.10.1.tgz") + file_name = self.os_distro_repo_provider.get_file_name_from_download_url("https://archive.apache.org/dist/solr/solr/9.10.1/solr-9.10.1.tgz") self.assertEqual(file_name, "solr-9.10.1.tgz") diff --git a/tests/builder/downloaders/repositories/plugin_distribution_repository_provider_test.py b/tests/builder/downloaders/repositories/plugin_distribution_repository_provider_test.py index 7c753fd0..b98a2788 100644 --- a/tests/builder/downloaders/repositories/plugin_distribution_repository_provider_test.py +++ b/tests/builder/downloaders/repositories/plugin_distribution_repository_provider_test.py @@ -1,8 +1,7 @@ from unittest import TestCase, mock from unittest.mock import Mock -from solrorbit.builder.downloaders.repositories.plugin_distribution_repository_provider import \ - PluginDistributionRepositoryProvider +from solrorbit.builder.downloaders.repositories.plugin_distribution_repository_provider import PluginDistributionRepositoryProvider from solrorbit.builder.cluster_config import PluginDescriptor @@ -13,9 +12,8 @@ def setUp(self): self.repository_url_provider = Mock() self.plugin_distro_repo_provider = PluginDistributionRepositoryProvider(self.plugin, self.repository_url_provider) - def test_get_plugin_url(self): self.plugin_distro_repo_provider.get_download_url(self.host) - self.plugin_distro_repo_provider.repository_url_provider.render_url_for_key.assert_has_calls([ - mock.call(None, {"distribution": {"repository": "release"}}, "distribution.release.remote.repo.url", mandatory=False) - ]) + self.plugin_distro_repo_provider.repository_url_provider.render_url_for_key.assert_has_calls( + [mock.call(None, {"distribution": {"repository": "release"}}, "distribution.release.remote.repo.url", mandatory=False)] + ) diff --git a/tests/builder/downloaders/repositories/repository_url_provider_test.py b/tests/builder/downloaders/repositories/repository_url_provider_test.py index 77a48699..e3f11a06 100644 --- a/tests/builder/downloaders/repositories/repository_url_provider_test.py +++ b/tests/builder/downloaders/repositories/repository_url_provider_test.py @@ -11,14 +11,7 @@ def setUp(self): self.artifact_variables_provider = Mock() self.host = None - self.variables = { - "distribution": { - "version": "1.2.3" - }, - "fake": { - "url": "opensearch/{{VERSION}}/opensearch-{{VERSION}}-{{OSNAME}}-{{ARCH}}.tar.gz" - } - } + self.variables = {"distribution": {"version": "1.2.3"}, "fake": {"url": "opensearch/{{VERSION}}/opensearch-{{VERSION}}-{{OSNAME}}-{{ARCH}}.tar.gz"}} self.url_key = "fake.url" self.repo_url_provider = RepositoryUrlProvider(self.template_renderer, self.artifact_variables_provider) @@ -27,12 +20,8 @@ def test_get_url(self): self.artifact_variables_provider.get_artifact_variables.return_value = {"fake": "vars"} self.repo_url_provider.render_url_for_key(self.host, self.variables, self.url_key) - self.artifact_variables_provider.get_artifact_variables.assert_has_calls([ - mock.call(self.host, "1.2.3") - ]) - self.template_renderer.render_template_string.assert_has_calls([ - mock.call("opensearch/{{VERSION}}/opensearch-{{VERSION}}-{{OSNAME}}-{{ARCH}}.tar.gz", {"fake": "vars"}) - ]) + self.artifact_variables_provider.get_artifact_variables.assert_has_calls([mock.call(self.host, "1.2.3")]) + self.template_renderer.render_template_string.assert_has_calls([mock.call("opensearch/{{VERSION}}/opensearch-{{VERSION}}-{{OSNAME}}-{{ARCH}}.tar.gz", {"fake": "vars"})]) def test_no_url_template_found(self): with self.assertRaises(SystemSetupError): diff --git a/tests/builder/downloaders/repositories/source_repository_provider_test.py b/tests/builder/downloaders/repositories/source_repository_provider_test.py index 734a30d9..f8d92ed2 100644 --- a/tests/builder/downloaders/repositories/source_repository_provider_test.py +++ b/tests/builder/downloaders/repositories/source_repository_provider_test.py @@ -25,12 +25,8 @@ def test_initialize_repo_with_remote(self): self.source_repo_provider.fetch_repository(self.host, self.remote_url, self.revision, self.target_dir) - self.source_repo_provider.path_manager.create_path.assert_has_calls([ - mock.call(self.host, self.target_dir, create_locally=False) - ]) - self.source_repo_provider.git_manager.clone.assert_has_calls([ - mock.call(self.host, self.remote_url, self.target_dir) - ]) + self.source_repo_provider.path_manager.create_path.assert_has_calls([mock.call(self.host, self.target_dir, create_locally=False)]) + self.source_repo_provider.git_manager.clone.assert_has_calls([mock.call(self.host, self.remote_url, self.target_dir)]) def test_initialize_repo_skippable(self): # Check repo/.git, check repo, check repo/.git @@ -50,48 +46,51 @@ def test_initialize_repo_no_remote_not_skippable(self): def test_update_repo_to_latest(self): self.source_repo_provider.fetch_repository(self.host, self.remote_url, "latest", self.target_dir) - self.source_repo_provider.git_manager.assert_has_calls([ - mock.call.fetch(self.host, self.target_dir), - mock.call.checkout(self.host, self.target_dir), - mock.call.rebase(self.host, self.target_dir), - mock.call.get_revision_from_local_repository(self.host, self.target_dir) - ]) + self.source_repo_provider.git_manager.assert_has_calls( + [ + mock.call.fetch(self.host, self.target_dir), + mock.call.checkout(self.host, self.target_dir), + mock.call.rebase(self.host, self.target_dir), + mock.call.get_revision_from_local_repository(self.host, self.target_dir), + ] + ) def test_update_repo_to_current(self): self.source_repo_provider.fetch_repository(self.host, self.remote_url, self.revision, self.target_dir) - self.source_repo_provider.git_manager.assert_has_calls([ - mock.call.get_revision_from_local_repository(self.host, self.target_dir) - ]) + self.source_repo_provider.git_manager.assert_has_calls([mock.call.get_revision_from_local_repository(self.host, self.target_dir)]) def test_update_repo_to_timestamp(self): self.source_repo_provider.git_manager.get_revision_from_timestamp.return_value = "fake rev" self.source_repo_provider.fetch_repository(self.host, self.remote_url, "@fake-timestamp", self.target_dir) - self.source_repo_provider.git_manager.assert_has_calls([ - mock.call.fetch(self.host, self.target_dir), - mock.call.get_revision_from_timestamp(self.host, self.target_dir, "fake-timestamp"), - mock.call.checkout(self.host, self.target_dir, "fake rev"), - mock.call.get_revision_from_local_repository(self.host, self.target_dir) - ]) + self.source_repo_provider.git_manager.assert_has_calls( + [ + mock.call.fetch(self.host, self.target_dir), + mock.call.get_revision_from_timestamp(self.host, self.target_dir, "fake-timestamp"), + mock.call.checkout(self.host, self.target_dir, "fake rev"), + mock.call.get_revision_from_local_repository(self.host, self.target_dir), + ] + ) def test_update_repo_to_commit_hash(self): self.source_repo_provider.fetch_repository(self.host, self.remote_url, "uuid", self.target_dir) - self.source_repo_provider.git_manager.assert_has_calls([ - mock.call.fetch(self.host, self.target_dir), - mock.call.checkout(self.host, self.target_dir, "uuid"), - mock.call.get_revision_from_local_repository(self.host, self.target_dir) - ]) + self.source_repo_provider.git_manager.assert_has_calls( + [ + mock.call.fetch(self.host, self.target_dir), + mock.call.checkout(self.host, self.target_dir, "uuid"), + mock.call.get_revision_from_local_repository(self.host, self.target_dir), + ] + ) def test_update_repo_to_local_revision(self): self.source_repo_provider.fetch_repository(self.host, None, "fake rev", self.target_dir) - self.source_repo_provider.git_manager.assert_has_calls([ - mock.call.checkout(self.host, self.target_dir, "fake rev"), - mock.call.get_revision_from_local_repository(self.host, self.target_dir) - ]) + self.source_repo_provider.git_manager.assert_has_calls( + [mock.call.checkout(self.host, self.target_dir, "fake rev"), mock.call.get_revision_from_local_repository(self.host, self.target_dir)] + ) def test_get_revision_repo_exists(self): self.source_repo_provider.git_manager.get_revision_from_local_repository.return_value = "my rev" diff --git a/tests/builder/downloaders/source_downloader_test.py b/tests/builder/downloaders/source_downloader_test.py index 42180b7c..4a6a6a09 100644 --- a/tests/builder/downloaders/source_downloader_test.py +++ b/tests/builder/downloaders/source_downloader_test.py @@ -16,33 +16,26 @@ def setUp(self): self.template_renderer = Mock() self.artifact_variables_provider = Mock() - self.cluster_config = ClusterConfigInstance(names="fake", root_path="also fake", config_paths="fake2", variables={ - "source": { - "root": { - "dir": "/fake/dir/for/source" - }, - "solr": { - "subdir": "solr_sub-dir" - }, - "remote": { - "repo": { - "url": "https://git.remote.fake" - } - }, - "revision": "current", - "artifact_path_pattern": "{{OSNAME}}.tar.gz", - "build": { - "command": "gradle build" - }, - "clean": { - "command": "gradle clean" + self.cluster_config = ClusterConfigInstance( + names="fake", + root_path="also fake", + config_paths="fake2", + variables={ + "source": { + "root": {"dir": "/fake/dir/for/source"}, + "solr": {"subdir": "solr_sub-dir"}, + "remote": {"repo": {"url": "https://git.remote.fake"}}, + "revision": "current", + "artifact_path_pattern": "{{OSNAME}}.tar.gz", + "build": {"command": "gradle build"}, + "clean": {"command": "gradle clean"}, } - } - }) + }, + ) - self.solr_source_downloader = SourceDownloader(self.cluster_config, self.executor, - self.source_repository_provider, self.binary_builder, - self.template_renderer, self.artifact_variables_provider) + self.solr_source_downloader = SourceDownloader( + self.cluster_config, self.executor, self.source_repository_provider, self.binary_builder, self.template_renderer, self.artifact_variables_provider + ) def test_download(self): self.artifact_variables_provider.get_artifact_variables.return_value = {"OSNAME": "fake_OS"} @@ -50,17 +43,9 @@ def test_download(self): solr_binary = self.solr_source_downloader.download(self.host) self.assertEqual(solr_binary, {BinaryKeys.SOLR: "/fake/dir/for/source/solr_sub-dir/fake artifact path"}) - self.source_repository_provider.fetch_repository.assert_has_calls([ - mock.call(self.host, "https://git.remote.fake", "current", "/fake/dir/for/source/solr_sub-dir") - ]) - self.binary_builder.build.assert_has_calls([ - mock.call(self.host, ["fake clean", "fake build"]) - ]) - self.artifact_variables_provider.get_artifact_variables.assert_has_calls([ - mock.call(self.host) - ]) - self.template_renderer.render_template_string.assert_has_calls([ - mock.call("gradle clean", {"OSNAME": "fake_OS"}), - mock.call("gradle build", {"OSNAME": "fake_OS"}), - mock.call("{{OSNAME}}.tar.gz", {"OSNAME": "fake_OS"}) - ]) + self.source_repository_provider.fetch_repository.assert_has_calls([mock.call(self.host, "https://git.remote.fake", "current", "/fake/dir/for/source/solr_sub-dir")]) + self.binary_builder.build.assert_has_calls([mock.call(self.host, ["fake clean", "fake build"])]) + self.artifact_variables_provider.get_artifact_variables.assert_has_calls([mock.call(self.host)]) + self.template_renderer.render_template_string.assert_has_calls( + [mock.call("gradle clean", {"OSNAME": "fake_OS"}), mock.call("gradle build", {"OSNAME": "fake_OS"}), mock.call("{{OSNAME}}.tar.gz", {"OSNAME": "fake_OS"})] + ) diff --git a/tests/builder/installers/bare_installer_test.py b/tests/builder/installers/bare_installer_test.py index 7b86d690..2c21f331 100644 --- a/tests/builder/installers/bare_installer_test.py +++ b/tests/builder/installers/bare_installer_test.py @@ -24,14 +24,7 @@ def setUp(self): names="defaults", root_path="fake", config_paths=["/tmp"], - variables={ - "test_run_root": self.test_run_root, - "cluster_name": self.cluster_name, - "node": { - "port": "8983" - }, - "preserve_install": False - } + variables={"test_run_root": self.test_run_root, "cluster_name": self.cluster_name, "node": {"port": "8983"}, "preserve_install": False}, ) self.installer = BareInstaller(self.cluster_config, self.executor, self.preparer) self.installer.config_applier = Mock() @@ -49,30 +42,18 @@ def test_install_node(self): node = self.installer.install(self.host, self.binaries, self.all_node_ips) self.assertEqual(node, "fake node") - self.preparer.prepare.assert_has_calls([ - mock.call(self.host, self.binaries) - ]) - self.preparer.get_config_vars.assert_has_calls([ - mock.call(self.host, "fake node", self.all_node_ips) - ]) - self.installer.config_applier.apply_configs.assert_has_calls([ - mock.call(self.host, "fake node", ["/tmp"], {"fake": "config"}) - ]) - self.installer.java_home_resolver.resolve_java_home.assert_has_calls([ - mock.call(self.host, self.cluster_config) - ]) - self.preparer.invoke_install_hook.assert_has_calls([ - mock.call(self.host, BootstrapPhase.post_install, {"fake": "config"}, {"JAVA_HOME": "/path/to/java/home"}) - ]) + self.preparer.prepare.assert_has_calls([mock.call(self.host, self.binaries)]) + self.preparer.get_config_vars.assert_has_calls([mock.call(self.host, "fake node", self.all_node_ips)]) + self.installer.config_applier.apply_configs.assert_has_calls([mock.call(self.host, "fake node", ["/tmp"], {"fake": "config"})]) + self.installer.java_home_resolver.resolve_java_home.assert_has_calls([mock.call(self.host, self.cluster_config)]) + self.preparer.invoke_install_hook.assert_has_calls([mock.call(self.host, BootstrapPhase.post_install, {"fake": "config"}, {"JAVA_HOME": "/path/to/java/home"})]) def test_install_no_java_home(self): self.installer.java_home_resolver.resolve_java_home.return_value = (None, None) self.installer.install(self.host, self.binaries, self.all_node_ips) - self.preparer.invoke_install_hook.assert_has_calls([ - mock.call(self.host, BootstrapPhase.post_install, {"fake": "config"}, {}) - ]) + self.preparer.invoke_install_hook.assert_has_calls([mock.call(self.host, BootstrapPhase.post_install, {"fake": "config"}, {})]) def test_multiple_nodes_installed(self): self.installer.preparers = [self.preparer, self.preparer2] diff --git a/tests/builder/installers/docker_installer_test.py b/tests/builder/installers/docker_installer_test.py index 52c3cf74..6a51624b 100644 --- a/tests/builder/installers/docker_installer_test.py +++ b/tests/builder/installers/docker_installer_test.py @@ -30,63 +30,59 @@ def setUp(self): variables={ "cluster_name": self.cluster_name, "test_run_root": self.test_run_root, - "node": { - "port": self.port - }, - "origin": { - "distribution": { - "version": "1.1.0" - }, - "docker": { - "docker_image": "solr" - } - } - } + "node": {"port": self.port}, + "origin": {"distribution": {"version": "1.1.0"}, "docker": {"docker_image": "solr"}}, + }, ) self.installer = DockerInstaller(self.cluster_config, self.executor) maxDiff = None + @mock.patch("uuid.uuid4") @mock.patch("solrorbit.paths.benchmark_root") def test_provisioning_with_defaults(self, benchmark_root, uuid4): uuid4.return_value = self.node_name - benchmark_root.return_value = os.path.normpath(os.path.join(os.path.dirname(os.path.realpath(__file__)), - os.pardir, os.pardir, os.pardir, "solrorbit")) + benchmark_root.return_value = os.path.normpath(os.path.join(os.path.dirname(os.path.realpath(__file__)), os.pardir, os.pardir, os.pardir, "solrorbit")) node = self.installer._create_node() - self.assertDictEqual({ - "cluster_name": self.cluster_name, - "node_name": self.node_name, - "install_root_path": "/var/solr", - "data_paths": ["/var/solr/data"], - "log_path": "/var/solr/logs", - "heap_dump_path": "/var/solr/heapdump", - "discovery_type": "single-node", - "network_host": "0.0.0.0", - "http_port": self.port, - "zookeeper_port": str(int(self.port) + 1000), - "cluster_settings": { + self.assertDictEqual( + { + "cluster_name": self.cluster_name, + "node_name": self.node_name, + "install_root_path": "/var/solr", + "data_paths": ["/var/solr/data"], + "log_path": "/var/solr/logs", + "heap_dump_path": "/var/solr/heapdump", + "discovery_type": "single-node", + "network_host": "0.0.0.0", + "http_port": self.port, + "zookeeper_port": str(int(self.port) + 1000), + "cluster_settings": {}, + "docker_image": "solr", }, - "docker_image": "solr" - }, self.installer._get_config_vars(node)) + self.installer._get_config_vars(node), + ) docker_vars = self.installer._get_docker_vars(node, mounts={}) - self.assertDictEqual({ - "solr_data_dir": self.node_data_dir, - "solr_log_dir": self.node_log_dir, - "solr_heap_dump_dir": self.node_heap_dump_dir, - "solr_version": "1.1.0", - "docker_image": "solr", - "http_port": 38983, - "mounts": {} - }, docker_vars) + self.assertDictEqual( + { + "solr_data_dir": self.node_data_dir, + "solr_log_dir": self.node_log_dir, + "solr_heap_dump_dir": self.node_heap_dump_dir, + "solr_version": "1.1.0", + "docker_image": "solr", + "http_port": 38983, + "mounts": {}, + }, + docker_vars, + ) docker_cfg = self.installer._render_template_from_docker_file(docker_vars) self.assertEqual( -"""version: '3' + """version: '3' services: solr-node1: image: solr:1.1.0 @@ -117,14 +113,16 @@ def test_provisioning_with_defaults(self, benchmark_root, uuid4): solr-data1: networks: solr-net: -""" % (self.node_data_dir, self.node_log_dir, self.node_heap_dump_dir), docker_cfg) +""" + % (self.node_data_dir, self.node_log_dir, self.node_heap_dump_dir), + docker_cfg, + ) @mock.patch("uuid.uuid4") @mock.patch("solrorbit.paths.benchmark_root") def test_provisioning_with_variables(self, benchmark_root, uuid4): uuid4.return_value = self.node_name - benchmark_root.return_value = os.path.normpath(os.path.join(os.path.dirname(os.path.realpath(__file__)), - os.pardir, os.pardir, os.pardir, "solrorbit")) + benchmark_root.return_value = os.path.normpath(os.path.join(os.path.dirname(os.path.realpath(__file__)), os.pardir, os.pardir, os.pardir, "solrorbit")) self.cluster_config.variables["origin"]["docker"]["docker_mem_limit"] = "256m" self.cluster_config.variables["origin"]["docker"]["docker_cpu_count"] = 2 @@ -135,7 +133,7 @@ def test_provisioning_with_variables(self, benchmark_root, uuid4): docker_cfg = self.installer._render_template_from_docker_file(docker_vars) self.assertEqual( -"""version: '3' + """version: '3' services: solr-node1: image: solr:1.1.0 @@ -168,4 +166,7 @@ def test_provisioning_with_variables(self, benchmark_root, uuid4): solr-data1: networks: solr-net: -""" % (self.node_data_dir, self.node_log_dir, self.node_heap_dump_dir), docker_cfg) +""" + % (self.node_data_dir, self.node_log_dir, self.node_heap_dump_dir), + docker_cfg, + ) diff --git a/tests/builder/installers/preparers/solr_preparer_test.py b/tests/builder/installers/preparers/solr_preparer_test.py index cc1ed7d3..37ccf81f 100644 --- a/tests/builder/installers/preparers/solr_preparer_test.py +++ b/tests/builder/installers/preparers/solr_preparer_test.py @@ -12,9 +12,17 @@ class NodePreparerTests(TestCase): def setUp(self): self.node_id = "abdefg" - self.node = Node(binary_path="/fake_binary_path", data_paths=["/fake1", "/fake2"], - name=self.node_id, pid=None, telemetry=None, port=8983, root_dir=None, - log_path="/fake/logpath", heap_dump_path="/fake/heap") + self.node = Node( + binary_path="/fake_binary_path", + data_paths=["/fake1", "/fake2"], + name=self.node_id, + pid=None, + telemetry=None, + port=8983, + root_dir=None, + log_path="/fake/logpath", + heap_dump_path="/fake/heap", + ) self.host = Host(name="fake", address="10.17.22.23", metadata={}, node=None) self.binaries = {BinaryKeys.SOLR: "/data/builds/distributions"} self.all_node_ips = ["10.17.22.22", "10.17.22.23"] @@ -26,16 +34,7 @@ def setUp(self): self.hook_handler_class = Mock() self.cluster_config = ClusterConfigInstance( - names="defaults", - root_path="fake", - config_paths=["/tmp"], - variables={ - "test_run_root": self.test_run_root, - "cluster_name": self.cluster_name, - "node": { - "port": "8983" - } - } + names="defaults", root_path="fake", config_paths=["/tmp"], variables={"test_run_root": self.test_run_root, "cluster_name": self.cluster_name, "node": {"port": "8983"}} ) self.preparer = SolrPreparer(self.cluster_config, self.executor, self.hook_handler_class) self.preparer.path_manager = Mock() @@ -56,19 +55,22 @@ def test_prepare(self, uuid): def test_config_vars(self): config_vars = self.preparer.get_config_vars(self.host, self.node, self.all_node_ips) - self.assertEqual({ - "cluster_name": self.cluster_name, - "node_name": self.node_id, - "data_paths": "/fake1", - "log_path": "/fake/logpath", - "heap_dump_path": "/fake/heap", - "node_ip": "10.17.22.23", - "network_host": "10.17.22.23", - "http_port": "8983", - "zookeeper_port": "9983", - "all_node_ips": "[\"10.17.22.22\",\"10.17.22.23\"]", - "minimum_master_nodes": 2, - "install_root_path": "/fake_binary_path", - "node": {"port": "8983"}, - "test_run_root": self.test_run_root - }, config_vars) + self.assertEqual( + { + "cluster_name": self.cluster_name, + "node_name": self.node_id, + "data_paths": "/fake1", + "log_path": "/fake/logpath", + "heap_dump_path": "/fake/heap", + "node_ip": "10.17.22.23", + "network_host": "10.17.22.23", + "http_port": "8983", + "zookeeper_port": "9983", + "all_node_ips": '["10.17.22.22","10.17.22.23"]', + "minimum_master_nodes": 2, + "install_root_path": "/fake_binary_path", + "node": {"port": "8983"}, + "test_run_root": self.test_run_root, + }, + config_vars, + ) diff --git a/tests/builder/java_resolver_test.py b/tests/builder/java_resolver_test.py index 163c2280..e3eb4b82 100644 --- a/tests/builder/java_resolver_test.py +++ b/tests/builder/java_resolver_test.py @@ -16,7 +16,7 @@ # not use this file except in compliance with the License. # You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an @@ -35,9 +35,10 @@ class JavaResolverTests(TestCase): @mock.patch("solrorbit.utils.jvm.resolve_path") def test_resolves_java_home_for_default_runtime_jdk(self, resolve_jvm_path): resolve_jvm_path.return_value = (12, "/opt/jdk12") - major, java_home = java_resolver.java_home("12,11,10,9,8", - specified_runtime_jdk=None, - ) + major, java_home = java_resolver.java_home( + "12,11,10,9,8", + specified_runtime_jdk=None, + ) self.assertEqual(major, 12) self.assertEqual(java_home, "/opt/jdk12") @@ -45,9 +46,10 @@ def test_resolves_java_home_for_default_runtime_jdk(self, resolve_jvm_path): @mock.patch("solrorbit.utils.jvm.resolve_path") def test_resolves_java_home_for_specific_runtime_jdk(self, resolve_jvm_path): resolve_jvm_path.return_value = (8, "/opt/jdk8") - major, java_home = java_resolver.java_home("12,11,10,9,8", - specified_runtime_jdk=8, - ) + major, java_home = java_resolver.java_home( + "12,11,10,9,8", + specified_runtime_jdk=8, + ) self.assertEqual(major, 8) self.assertEqual(java_home, "/opt/jdk8") diff --git a/tests/builder/launcher_test.py b/tests/builder/launcher_test.py index ff9444b2..346d8215 100644 --- a/tests/builder/launcher_test.py +++ b/tests/builder/launcher_test.py @@ -16,7 +16,7 @@ # not use this file except in compliance with the License. # You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an @@ -53,39 +53,20 @@ def create(self): class MockClient: def __init__(self, client_options): self.client_options = client_options - self.cluster = SubClient({ - "cluster_name": "benchmark-provisioned-cluster-cluster", - "nodes": { - "FCFjozkeTiOpN-SI88YEcg": { - "name": "Nefarius", - "host": "127.0.0.1" - } - } - }) - self.nodes = SubClient({ - "nodes": { - "FCFjozkeTiOpN-SI88YEcg": { - "name": "Nefarius", - "host": "127.0.0.1", - "os": { - "name": "Mac OS X", - "version": "10.11.4", - "available_processors": 8 - }, - "jvm": { - "version": "1.8.0_74", - "vm_vendor": "Oracle Corporation" + self.cluster = SubClient({"cluster_name": "benchmark-provisioned-cluster-cluster", "nodes": {"FCFjozkeTiOpN-SI88YEcg": {"name": "Nefarius", "host": "127.0.0.1"}}}) + self.nodes = SubClient( + { + "nodes": { + "FCFjozkeTiOpN-SI88YEcg": { + "name": "Nefarius", + "host": "127.0.0.1", + "os": {"name": "Mac OS X", "version": "10.11.4", "available_processors": 8}, + "jvm": {"version": "1.8.0_74", "vm_vendor": "Oracle Corporation"}, } } } - }) - self._info = { - "version": - { - "number": "5.0.0", - "build_hash": "abc123" - } - } + ) + self._info = {"version": {"number": "5.0.0", "build_hash": "abc123"}} def info(self): if self.client_options.get("raise-error-on-info", False): @@ -156,11 +137,7 @@ def __init__(self, pid): def get_metrics_store(cfg): ms = InMemoryMetricsStore(cfg) - ms.open(test_run_id=str(uuid.uuid4()), - test_run_timestamp=datetime.now(), - workload_name="test", - test_procedure_name="test", - cluster_config_name="test") + ms.open(test_run_id=str(uuid.uuid4()), test_run_timestamp=datetime.now(), workload_name="test", test_procedure_name="test", cluster_config_name="test") return ms @@ -168,13 +145,13 @@ def get_metrics_store(cfg): class ProcessLauncherTests(TestCase): - @mock.patch('subprocess.Popen', new=MockPopen) - @mock.patch('solrorbit.builder.java_resolver.java_home', return_value=(12, "/java_home/")) - @mock.patch('solrorbit.utils.jvm.supports_option', return_value=True) - @mock.patch('solrorbit.utils.io.get_size') - @mock.patch('os.chdir') - @mock.patch('solrorbit.builder.launcher.wait_for_pidfile', return_value=MOCK_PID_VALUE) - @mock.patch('psutil.Process', new=MockProcess) + @mock.patch("subprocess.Popen", new=MockPopen) + @mock.patch("solrorbit.builder.java_resolver.java_home", return_value=(12, "/java_home/")) + @mock.patch("solrorbit.utils.jvm.supports_option", return_value=True) + @mock.patch("solrorbit.utils.io.get_size") + @mock.patch("os.chdir") + @mock.patch("solrorbit.builder.launcher.wait_for_pidfile", return_value=MOCK_PID_VALUE) + @mock.patch("psutil.Process", new=MockProcess) def test_daemon_start_stop(self, wait_for_pidfile, chdir, get_size, supports, java_home): cfg = config.Config() cfg.add(config.Scope.application, "node", "root.dir", "test") @@ -188,13 +165,17 @@ def test_daemon_start_stop(self, wait_for_pidfile, chdir, get_size, supports, ja node_configs = [] for node in range(2): - node_configs.append(NodeConfiguration(build_type="tar", - cluster_config_runtime_jdks="12,11", - ip="127.0.0.1", - node_name="testnode-{}".format(node), - node_root_path="/tmp", - binary_path="/tmp", - data_paths="/tmp")) + node_configs.append( + NodeConfiguration( + build_type="tar", + cluster_config_runtime_jdks="12,11", + ip="127.0.0.1", + node_name="testnode-{}".format(node), + node_root_path="/tmp", + binary_path="/tmp", + data_paths="/tmp", + ) + ) nodes = proc_launcher.start(node_configs) self.assertEqual(len(nodes), 2) @@ -204,7 +185,7 @@ def test_daemon_start_stop(self, wait_for_pidfile, chdir, get_size, supports, ja # all nodes should be stopped self.assertEqual(nodes, stopped_nodes) - @mock.patch('psutil.Process', new=TerminatedProcess) + @mock.patch("psutil.Process", new=TerminatedProcess) def test_daemon_stop_with_already_terminated_process(self): cfg = config.Config() cfg.add(config.Scope.application, "node", "root.dir", "test") @@ -215,13 +196,7 @@ def test_daemon_stop_with_already_terminated_process(self): ms = get_metrics_store(cfg) proc_launcher = launcher.ProcessLauncher(cfg) - nodes = [ - cluster.Node(pid=-1, - binary_path="/bin", - host_name="localhost", - node_name="benchmark-0", - telemetry=telemetry.Telemetry()) - ] + nodes = [cluster.Node(pid=-1, binary_path="/bin", host_name="localhost", node_name="benchmark-0", telemetry=telemetry.Telemetry())] stopped_nodes = proc_launcher.stop(nodes, ms) # no nodes should have been stopped (they were already stopped) @@ -235,16 +210,16 @@ def test_env_options_order(self, sleep): proc_launcher = launcher.ProcessLauncher(cfg) - node_telemetry = [ - telemetry.FlightRecorder(telemetry_params={}, log_root="/tmp/telemetry", java_major_version=8) - ] + node_telemetry = [telemetry.FlightRecorder(telemetry_params={}, log_root="/tmp/telemetry", java_major_version=8)] t = telemetry.Telemetry(["jfr"], devices=node_telemetry) env = proc_launcher._prepare_env(node_name="node0", java_home="/java_home", t=t) self.assertEqual("/java_home/bin" + os.pathsep + os.environ["PATH"], env["PATH"]) - self.assertEqual("-XX:+ExitOnOutOfMemoryError -XX:+UnlockDiagnosticVMOptions -XX:+DebugNonSafepoints " - "-XX:StartFlightRecording=maxsize=0,maxage=0s,disk=true,dumponexit=true,filename=/tmp/telemetry/profile.jfr", - env["SOLR_JAVA_OPTS"]) + self.assertEqual( + "-XX:+ExitOnOutOfMemoryError -XX:+UnlockDiagnosticVMOptions -XX:+DebugNonSafepoints " + "-XX:StartFlightRecording=maxsize=0,maxage=0s,disk=true,dumponexit=true,filename=/tmp/telemetry/profile.jfr", + env["SOLR_JAVA_OPTS"], + ) def test_bundled_jdk_not_in_path(self): cfg = config.Config() @@ -325,6 +300,7 @@ def _stub_first_read(*args, **kwargs): return "" else: return old_read_se(*args, *kwargs) + handle.read.side_effect = _stub_first_read return mo @@ -372,11 +348,9 @@ def test_starts_container_successfully(self, run_subprocess_with_output, run_sub cfg = config.Config() docker = launcher.DockerLauncher(cfg) - node_config = NodeConfiguration(build_type="docker", - cluster_config_runtime_jdks="12,11", - ip="127.0.0.1", node_name="testnode", - node_root_path="/tmp", binary_path="/bin", - data_paths="/tmp") + node_config = NodeConfiguration( + build_type="docker", cluster_config_runtime_jdks="12,11", ip="127.0.0.1", node_name="testnode", node_root_path="/tmp", binary_path="/bin", data_paths="/tmp" + ) nodes = docker.start([node_config]) self.assertEqual(1, len(nodes)) @@ -389,10 +363,9 @@ def test_starts_container_successfully(self, run_subprocess_with_output, run_sub self.assertIsNotNone(node.telemetry) run_subprocess_with_logging.assert_called_once_with("docker-compose -f /bin/docker-compose.yml up -d") - run_subprocess_with_output.assert_has_calls([ - mock.call("docker-compose -f /bin/docker-compose.yml ps -q"), - mock.call('docker ps -a --filter "id=de604d0d" --filter "status=running" --filter "health=healthy" -q') - ]) + run_subprocess_with_output.assert_has_calls( + [mock.call("docker-compose -f /bin/docker-compose.yml ps -q"), mock.call('docker ps -a --filter "id=de604d0d" --filter "status=running" --filter "health=healthy" -q')] + ) @mock.patch("solrorbit.time.sleep") @mock.patch("solrorbit.utils.process.run_subprocess_with_logging") @@ -407,10 +380,8 @@ def test_container_not_started(self, run_subprocess_with_output, run_subprocess_ docker = launcher.DockerLauncher(cfg, clock=TestClock(stop_watch=stop_watch)) node_config = NodeConfiguration( - build_type="docker", cluster_config_runtime_jdks="12,11", - ip="127.0.0.1", node_name="testnode", - node_root_path="/tmp", binary_path="/bin", - data_paths="/tmp") + build_type="docker", cluster_config_runtime_jdks="12,11", ip="127.0.0.1", node_name="testnode", node_root_path="/tmp", binary_path="/bin", data_paths="/tmp" + ) with self.assertRaisesRegex(exceptions.LaunchError, "No healthy running container after 600 seconds!"): docker.start([node_config]) diff --git a/tests/builder/launchers/docker_launcher_test.py b/tests/builder/launchers/docker_launcher_test.py index 02fd2386..ba9a65b8 100644 --- a/tests/builder/launchers/docker_launcher_test.py +++ b/tests/builder/launchers/docker_launcher_test.py @@ -18,11 +18,9 @@ def setUp(self): self.launcher.waiter = Mock() self.host = None - self.node_config = NodeConfiguration(build_type="docker", - cluster_config_runtime_jdks="12,11", - ip="127.0.0.1", node_name="testnode", - node_root_path="/tmp", binary_path="/bin", - data_paths="/tmp") + self.node_config = NodeConfiguration( + build_type="docker", cluster_config_runtime_jdks="12,11", ip="127.0.0.1", node_name="testnode", node_root_path="/tmp", binary_path="/bin", data_paths="/tmp" + ) def test_starts_container_successfully(self): # [Start container (from docker-compose up), Docker container id (from docker-compose ps), @@ -39,10 +37,12 @@ def test_starts_container_successfully(self): self.assertEqual("testnode", node.node_name) self.assertIsNotNone(node.telemetry) - self.shell_executor.execute.assert_has_calls([ - mock.call(self.host, "docker-compose -f /bin/docker-compose.yml up -d"), - mock.call(self.host, "docker-compose -f /bin/docker-compose.yml ps -q", output=True), - ]) + self.shell_executor.execute.assert_has_calls( + [ + mock.call(self.host, "docker-compose -f /bin/docker-compose.yml up -d"), + mock.call(self.host, "docker-compose -f /bin/docker-compose.yml ps -q", output=True), + ] + ) def test_container_not_started(self): # [Start container (from docker-compose up), Docker container id (from docker-compose ps), @@ -75,15 +75,15 @@ def test_container_not_healthy(self): output = self.launcher._is_container_healthy(self.host, "de604d0d") self.assertEqual(output, False) - self.shell_executor.execute.assert_has_calls([ - mock.call(self.host, 'docker ps -a --filter "id=de604d0d" --filter "status=running" --filter "health=healthy" -q', output=True) - ]) + self.shell_executor.execute.assert_has_calls( + [mock.call(self.host, 'docker ps -a --filter "id=de604d0d" --filter "status=running" --filter "health=healthy" -q', output=True)] + ) def test_container_healthy(self): self.shell_executor.execute.return_value = ["We have a container"] output = self.launcher._is_container_healthy(self.host, "de604d0d") self.assertEqual(output, True) - self.shell_executor.execute.assert_has_calls([ - mock.call(self.host, 'docker ps -a --filter "id=de604d0d" --filter "status=running" --filter "health=healthy" -q', output=True) - ]) + self.shell_executor.execute.assert_has_calls( + [mock.call(self.host, 'docker ps -a --filter "id=de604d0d" --filter "status=running" --filter "health=healthy" -q', output=True)] + ) diff --git a/tests/builder/launchers/local_process_launcher_test.py b/tests/builder/launchers/local_process_launcher_test.py index 61c76faa..f862a4c4 100644 --- a/tests/builder/launchers/local_process_launcher_test.py +++ b/tests/builder/launchers/local_process_launcher_test.py @@ -18,45 +18,35 @@ def setUp(self): self.shell_executor = Mock() self.metrics_store = Mock() - self.variables = { - "system": { - "runtime": { - "jdk": None - }, - "env": { - "passenv": "PATH" - } - }, - "telemetry": { - "devices": [], - "params": None - } - } - self.cluster_config = ClusterConfigInstance("fake_cluster_config", "/path/to/root", - ["/path/to/config"], variables=self.variables) + self.variables = {"system": {"runtime": {"jdk": None}, "env": {"passenv": "PATH"}}, "telemetry": {"devices": [], "params": None}} + self.cluster_config = ClusterConfigInstance("fake_cluster_config", "/path/to/root", ["/path/to/config"], variables=self.variables) self.launcher = LocalProcessLauncher(self.cluster_config, self.shell_executor, self.metrics_store) self.launcher.waiter = Mock() self.host = None self.path = "fake" - @mock.patch('solrorbit.builder.java_resolver.java_home', return_value=(12, "/java_home/")) - @mock.patch('solrorbit.utils.jvm.supports_option', return_value=True) - @mock.patch('solrorbit.utils.io.get_size') - @mock.patch('solrorbit.telemetry') - @mock.patch('psutil.Process') + @mock.patch("solrorbit.builder.java_resolver.java_home", return_value=(12, "/java_home/")) + @mock.patch("solrorbit.utils.jvm.supports_option", return_value=True) + @mock.patch("solrorbit.utils.io.get_size") + @mock.patch("solrorbit.telemetry") + @mock.patch("psutil.Process") def test_daemon_start_stop(self, process, telemetry, get_size, supports, java_home): mo = mock_open(read_data="1234") node_configs = [] for node in range(2): - node_configs.append(NodeConfiguration(build_type="tar", - cluster_config_runtime_jdks="12,11", - ip="127.0.0.1", - node_name=f"testnode-{node}", - node_root_path="/tmp", - binary_path="/tmp", - data_paths="/tmp")) + node_configs.append( + NodeConfiguration( + build_type="tar", + cluster_config_runtime_jdks="12,11", + ip="127.0.0.1", + node_name=f"testnode-{node}", + node_root_path="/tmp", + binary_path="/tmp", + data_paths="/tmp", + ) + ) with mock.patch("builtins.open", mo): nodes = self.launcher.start(self.host, node_configs) @@ -68,17 +58,11 @@ def test_daemon_start_stop(self, process, telemetry, get_size, supports, java_ho # all nodes should be stopped self.assertEqual(nodes, stopped_nodes) - @mock.patch('psutil.Process') + @mock.patch("psutil.Process") def test_daemon_stop_with_already_terminated_process(self, process): process.side_effect = NoSuchProcess(123) - nodes = [ - cluster.Node(pid=-1, - binary_path="/bin", - host_name="localhost", - node_name="benchmark-0", - telemetry=telemetry.Telemetry()) - ] + nodes = [cluster.Node(pid=-1, binary_path="/bin", host_name="localhost", node_name="benchmark-0", telemetry=telemetry.Telemetry())] stopped_nodes = self.launcher.stop(self.host, nodes) # no nodes should have been stopped (they were already stopped) @@ -87,16 +71,16 @@ def test_daemon_stop_with_already_terminated_process(self, process): # flight recorder shows a warning for several seconds before continuing @mock.patch("solrorbit.time.sleep") def test_env_options_order(self, sleep): - node_telemetry = [ - telemetry.FlightRecorder(telemetry_params={}, log_root="/tmp/telemetry", java_major_version=8) - ] + node_telemetry = [telemetry.FlightRecorder(telemetry_params={}, log_root="/tmp/telemetry", java_major_version=8)] telem = telemetry.Telemetry(["jfr"], devices=node_telemetry) env = self.launcher._prepare_env(node_name="node0", java_home="/java_home", telemetry=telem) self.assertEqual("/java_home/bin" + os.pathsep + os.environ["PATH"], env["PATH"]) - self.assertEqual("-XX:+ExitOnOutOfMemoryError -XX:+UnlockDiagnosticVMOptions -XX:+DebugNonSafepoints " - "-XX:StartFlightRecording=maxsize=0,maxage=0s,disk=true,dumponexit=true,filename=/tmp/telemetry/profile.jfr", - env["SOLR_JAVA_OPTS"]) + self.assertEqual( + "-XX:+ExitOnOutOfMemoryError -XX:+UnlockDiagnosticVMOptions -XX:+DebugNonSafepoints " + "-XX:StartFlightRecording=maxsize=0,maxage=0s,disk=true,dumponexit=true,filename=/tmp/telemetry/profile.jfr", + env["SOLR_JAVA_OPTS"], + ) def test_bundled_jdk_not_in_path(self): os.environ["JAVA_HOME"] = "/path/to/java" diff --git a/tests/builder/mechanic_test.py b/tests/builder/mechanic_test.py index 069122b3..a7fbffa0 100644 --- a/tests/builder/mechanic_test.py +++ b/tests/builder/mechanic_test.py @@ -16,7 +16,7 @@ # not use this file except in compliance with the License. # You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an @@ -44,11 +44,14 @@ def test_converts_valid_hosts(self, resolver): {"host": "site.example.com", "port": 8983}, ] - self.assertEqual([ - ("127.0.0.1", 8983), - ("10.16.23.5", 8983), - ("11.22.33.44", 8983), - ], builder.to_ip_port(hosts)) + self.assertEqual( + [ + ("127.0.0.1", 8983), + ("10.16.23.5", 8983), + ("11.22.33.44", 8983), + ], + builder.to_ip_port(hosts), + ) @mock.patch("solrorbit.utils.net.resolve") def test_rejects_hosts_with_unexpected_properties(self, resolver): @@ -62,9 +65,10 @@ def test_rejects_hosts_with_unexpected_properties(self, resolver): with self.assertRaises(exceptions.SystemSetupError) as ctx: builder.to_ip_port(hosts) - self.assertEqual("When specifying nodes to be managed by " - "solr-orbit you can only supply hostname:port pairs (e.g. 'localhost:8983'), " - "any additional options cannot be supported.", ctx.exception.args[0]) + self.assertEqual( + "When specifying nodes to be managed by solr-orbit you can only supply hostname:port pairs (e.g. 'localhost:8983'), any additional options cannot be supported.", + ctx.exception.args[0], + ) def test_groups_nodes_by_host(self): ip_port = [ @@ -81,8 +85,8 @@ def test_groups_nodes_by_host(self): ("127.0.0.1", 9200): [0, 1, 2], ("10.16.23.5", 9200): [3], ("11.22.33.44", 9200): [4, 5], - - }, builder.nodes_by_host(ip_port) + }, + builder.nodes_by_host(ip_port), ) def test_extract_all_node_ips(self): @@ -94,8 +98,7 @@ def test_extract_all_node_ips(self): ("11.22.33.44", 9200), ("11.22.33.44", 9200), ] - self.assertSetEqual({"127.0.0.1", "10.16.23.5", "11.22.33.44"}, - builder.extract_all_node_ips(ip_port)) + self.assertSetEqual({"127.0.0.1", "10.16.23.5", "11.22.33.44"}, builder.extract_all_node_ips(ip_port)) class BuilderTests(TestCase): diff --git a/tests/builder/provisioner_test.py b/tests/builder/provisioner_test.py index f0014430..1b9b6ef6 100644 --- a/tests/builder/provisioner_test.py +++ b/tests/builder/provisioner_test.py @@ -16,7 +16,7 @@ # not use this file except in compliance with the License. # You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an @@ -47,22 +47,23 @@ def test_prepare_without_plugins(self, mock_rm, mock_ensure_dir, mock_decompress def null_apply_config(source_root_path, target_root_path, config_vars): apply_config_calls.append((source_root_path, target_root_path, config_vars)) - installer = provisioner.NodeInstaller(cluster_config= - cluster_config.ClusterConfigInstance( - names="unit-test-cluster-config-instance", - root_path=None, - config_paths=[HOME_DIR + "/.benchmark/benchmarks/cluster_configs/default/my-cluster-config-instance"], - variables={"heap": "4g", "runtime.jdk": "8", "runtime.jdk.bundled": "true"}), + installer = provisioner.NodeInstaller( + cluster_config=cluster_config.ClusterConfigInstance( + names="unit-test-cluster-config-instance", + root_path=None, + config_paths=[HOME_DIR + "/.benchmark/benchmarks/cluster_configs/default/my-cluster-config-instance"], + variables={"heap": "4g", "runtime.jdk": "8", "runtime.jdk.bundled": "true"}, + ), java_home="/usr/local/javas/java8", node_name="benchmark-node-0", node_root_dir=HOME_DIR + "/.benchmark/benchmarks/test_runs/unittest", all_node_ips=["10.17.22.22", "10.17.22.23"], all_node_names=["benchmark-node-0", "benchmark-node-1"], ip="10.17.22.23", - http_port=8983) + http_port=8983, + ) - p = provisioner.BareProvisioner(os_installer=installer, - apply_config=null_apply_config) + p = provisioner.BareProvisioner(os_installer=installer, apply_config=null_apply_config) node_config = p.prepare({"solr": "/opt/solr-9.0.0.tar.gz"}) self.assertEqual("8", node_config.cluster_config_runtime_jdks) @@ -74,26 +75,28 @@ def null_apply_config(source_root_path, target_root_path, config_vars): self.assertEqual(HOME_DIR + "/.benchmark/benchmarks/cluster_configs/default/my-cluster-config-instance", source_root_path) self.assertEqual("/opt/solr-9.0.0", target_root_path) - self.assertEqual({ - "cluster_settings": { + self.assertEqual( + { + "cluster_settings": {}, + "heap": "4g", + "runtime.jdk": "8", + "runtime.jdk.bundled": "true", + "cluster_name": "benchmark-provisioned-cluster", + "node_name": "benchmark-node-0", + "data_paths": ["/opt/solr-9.0.0/data"], + "log_path": HOME_DIR + "/.benchmark/benchmarks/test_runs/unittest/logs/server", + "heap_dump_path": HOME_DIR + "/.benchmark/benchmarks/test_runs/unittest/heapdump", + "node_ip": "10.17.22.23", + "network_host": "10.17.22.23", + "http_port": "8983", + "zookeeper_port": "9983", + "all_node_ips": '["10.17.22.22","10.17.22.23"]', + "all_node_names": '["benchmark-node-0","benchmark-node-1"]', + "minimum_master_nodes": 2, + "install_root_path": "/opt/solr-9.0.0", }, - "heap": "4g", - "runtime.jdk": "8", - "runtime.jdk.bundled": "true", - "cluster_name": "benchmark-provisioned-cluster", - "node_name": "benchmark-node-0", - "data_paths": ["/opt/solr-9.0.0/data"], - "log_path": HOME_DIR + "/.benchmark/benchmarks/test_runs/unittest/logs/server", - "heap_dump_path": HOME_DIR + "/.benchmark/benchmarks/test_runs/unittest/heapdump", - "node_ip": "10.17.22.23", - "network_host": "10.17.22.23", - "http_port": "8983", - "zookeeper_port": "9983", - "all_node_ips": "[\"10.17.22.22\",\"10.17.22.23\"]", - "all_node_names": "[\"benchmark-node-0\",\"benchmark-node-1\"]", - "minimum_master_nodes": 2, - "install_root_path": "/opt/solr-9.0.0" - }, config_vars) + config_vars, + ) class NoopHookHandler: def __init__(self, plugin): @@ -103,10 +106,7 @@ def can_load(self): return False def invoke(self, phase, variables, **kwargs): - self.hook_calls[phase] = { - "variables": variables, - "kwargs": kwargs - } + self.hook_calls[phase] = {"variables": variables, "kwargs": kwargs} class NoopHookHandler: @@ -129,35 +129,38 @@ class NodeInstallerTests(TestCase): @mock.patch("solrorbit.utils.io.ensure_dir") @mock.patch("shutil.rmtree") def test_prepare_default_data_paths(self, mock_rm, mock_ensure_dir, mock_decompress): - installer = provisioner.NodeInstaller(cluster_config=cluster_config.ClusterConfigInstance(names="defaults", - root_path=None, - config_paths="/tmp"), - java_home="/usr/local/javas/java8", - node_name="benchmark-node-0", - all_node_ips=["10.17.22.22", "10.17.22.23"], - all_node_names=["benchmark-node-0", "benchmark-node-1"], - ip="10.17.22.23", - http_port=9200, - node_root_dir=HOME_DIR + "/.benchmark/benchmarks/test_runs/unittest") + installer = provisioner.NodeInstaller( + cluster_config=cluster_config.ClusterConfigInstance(names="defaults", root_path=None, config_paths="/tmp"), + java_home="/usr/local/javas/java8", + node_name="benchmark-node-0", + all_node_ips=["10.17.22.22", "10.17.22.23"], + all_node_names=["benchmark-node-0", "benchmark-node-1"], + ip="10.17.22.23", + http_port=9200, + node_root_dir=HOME_DIR + "/.benchmark/benchmarks/test_runs/unittest", + ) installer.install("/data/builds/distributions") self.assertEqual(installer.os_home_path, "/install/solr-9.0.0") - self.assertEqual({ - "cluster_name": "benchmark-provisioned-cluster", - "node_name": "benchmark-node-0", - "data_paths": ["/install/solr-9.0.0/data"], - "log_path": HOME_DIR + "/.benchmark/benchmarks/test_runs/unittest/logs/server", - "heap_dump_path": HOME_DIR + "/.benchmark/benchmarks/test_runs/unittest/heapdump", - "node_ip": "10.17.22.23", - "network_host": "10.17.22.23", - "http_port": "9200", - "zookeeper_port": "10200", - "all_node_ips": "[\"10.17.22.22\",\"10.17.22.23\"]", - "all_node_names": "[\"benchmark-node-0\",\"benchmark-node-1\"]", - "minimum_master_nodes": 2, - "install_root_path": "/install/solr-9.0.0" - }, installer.variables) + self.assertEqual( + { + "cluster_name": "benchmark-provisioned-cluster", + "node_name": "benchmark-node-0", + "data_paths": ["/install/solr-9.0.0/data"], + "log_path": HOME_DIR + "/.benchmark/benchmarks/test_runs/unittest/logs/server", + "heap_dump_path": HOME_DIR + "/.benchmark/benchmarks/test_runs/unittest/heapdump", + "node_ip": "10.17.22.23", + "network_host": "10.17.22.23", + "http_port": "9200", + "zookeeper_port": "10200", + "all_node_ips": '["10.17.22.22","10.17.22.23"]', + "all_node_names": '["benchmark-node-0","benchmark-node-1"]', + "minimum_master_nodes": 2, + "install_root_path": "/install/solr-9.0.0", + }, + installer.variables, + ) self.assertEqual(installer.data_paths, ["/install/solr-9.0.0/data"]) @@ -166,73 +169,76 @@ def test_prepare_default_data_paths(self, mock_rm, mock_ensure_dir, mock_decompr @mock.patch("solrorbit.utils.io.ensure_dir") @mock.patch("shutil.rmtree") def test_prepare_user_provided_data_path(self, mock_rm, mock_ensure_dir, mock_decompress): - installer = provisioner.NodeInstaller(cluster_config=cluster_config.ClusterConfigInstance(names="defaults", - root_path=None, - config_paths="/tmp", - variables={"data_paths": "/tmp/some/data-path-dir"}), - java_home="/usr/local/javas/java8", - node_name="benchmark-node-0", - all_node_ips=["10.17.22.22", "10.17.22.23"], - all_node_names=["benchmark-node-0", "benchmark-node-1"], - ip="10.17.22.23", - http_port=9200, - node_root_dir="~/.benchmark/benchmarks/test_runs/unittest") + installer = provisioner.NodeInstaller( + cluster_config=cluster_config.ClusterConfigInstance(names="defaults", root_path=None, config_paths="/tmp", variables={"data_paths": "/tmp/some/data-path-dir"}), + java_home="/usr/local/javas/java8", + node_name="benchmark-node-0", + all_node_ips=["10.17.22.22", "10.17.22.23"], + all_node_names=["benchmark-node-0", "benchmark-node-1"], + ip="10.17.22.23", + http_port=9200, + node_root_dir="~/.benchmark/benchmarks/test_runs/unittest", + ) installer.install("/data/builds/distributions") self.assertEqual(installer.os_home_path, "/install/solr-9.0.0") - self.assertEqual({ - "cluster_name": "benchmark-provisioned-cluster", - "node_name": "benchmark-node-0", - "data_paths": ["/tmp/some/data-path-dir"], - "log_path": "~/.benchmark/benchmarks/test_runs/unittest/logs/server", - "heap_dump_path": "~/.benchmark/benchmarks/test_runs/unittest/heapdump", - "node_ip": "10.17.22.23", - "network_host": "10.17.22.23", - "http_port": "9200", - "zookeeper_port": "10200", - "all_node_ips": "[\"10.17.22.22\",\"10.17.22.23\"]", - "all_node_names": "[\"benchmark-node-0\",\"benchmark-node-1\"]", - "minimum_master_nodes": 2, - "install_root_path": "/install/solr-9.0.0" - }, installer.variables) + self.assertEqual( + { + "cluster_name": "benchmark-provisioned-cluster", + "node_name": "benchmark-node-0", + "data_paths": ["/tmp/some/data-path-dir"], + "log_path": "~/.benchmark/benchmarks/test_runs/unittest/logs/server", + "heap_dump_path": "~/.benchmark/benchmarks/test_runs/unittest/heapdump", + "node_ip": "10.17.22.23", + "network_host": "10.17.22.23", + "http_port": "9200", + "zookeeper_port": "10200", + "all_node_ips": '["10.17.22.22","10.17.22.23"]', + "all_node_names": '["benchmark-node-0","benchmark-node-1"]', + "minimum_master_nodes": 2, + "install_root_path": "/install/solr-9.0.0", + }, + installer.variables, + ) self.assertEqual(installer.data_paths, ["/tmp/some/data-path-dir"]) def test_invokes_hook_with_java_home(self): - installer = provisioner.NodeInstaller(cluster_config=cluster_config.ClusterConfigInstance(names="defaults", - root_path="/tmp", - config_paths="/tmp/templates", - variables={"data_paths": "/tmp/some/data-path-dir"}), - java_home="/usr/local/javas/java8", - node_name="benchmark-node-0", - all_node_ips=["10.17.22.22", "10.17.22.23"], - all_node_names=["benchmark-node-0", "benchmark-node-1"], - ip="10.17.22.23", - http_port=9200, - node_root_dir="~/.benchmark/benchmarks/test_runs/unittest", - hook_handler_class=NoopHookHandler) + installer = provisioner.NodeInstaller( + cluster_config=cluster_config.ClusterConfigInstance( + names="defaults", root_path="/tmp", config_paths="/tmp/templates", variables={"data_paths": "/tmp/some/data-path-dir"} + ), + java_home="/usr/local/javas/java8", + node_name="benchmark-node-0", + all_node_ips=["10.17.22.22", "10.17.22.23"], + all_node_names=["benchmark-node-0", "benchmark-node-1"], + ip="10.17.22.23", + http_port=9200, + node_root_dir="~/.benchmark/benchmarks/test_runs/unittest", + hook_handler_class=NoopHookHandler, + ) self.assertEqual(0, len(installer.hook_handler.hook_calls)) installer.invoke_install_hook(cluster_config.BootstrapPhase.post_install, {"foo": "bar"}) self.assertEqual(1, len(installer.hook_handler.hook_calls)) self.assertEqual({"foo": "bar"}, installer.hook_handler.hook_calls["post_install"]["variables"]) - self.assertEqual({"env": {"JAVA_HOME": "/usr/local/javas/java8"}}, - installer.hook_handler.hook_calls["post_install"]["kwargs"]) + self.assertEqual({"env": {"JAVA_HOME": "/usr/local/javas/java8"}}, installer.hook_handler.hook_calls["post_install"]["kwargs"]) def test_invokes_hook_no_java_home(self): - installer = provisioner.NodeInstaller(cluster_config=cluster_config.ClusterConfigInstance(names="defaults", - root_path="/tmp", - config_paths="/tmp/templates", - variables={"data_paths": "/tmp/some/data-path-dir"}), - java_home=None, - node_name="benchmark-node-0", - all_node_ips=["10.17.22.22", "10.17.22.23"], - all_node_names=["benchmark-node-0", "benchmark-node-1"], - ip="10.17.22.23", - http_port=9200, - node_root_dir="~/.benchmark/benchmarks/test_runs/unittest", - hook_handler_class=NoopHookHandler) + installer = provisioner.NodeInstaller( + cluster_config=cluster_config.ClusterConfigInstance( + names="defaults", root_path="/tmp", config_paths="/tmp/templates", variables={"data_paths": "/tmp/some/data-path-dir"} + ), + java_home=None, + node_name="benchmark-node-0", + all_node_ips=["10.17.22.22", "10.17.22.23"], + all_node_names=["benchmark-node-0", "benchmark-node-1"], + ip="10.17.22.23", + http_port=9200, + node_root_dir="~/.benchmark/benchmarks/test_runs/unittest", + hook_handler_class=NoopHookHandler, + ) self.assertEqual(0, len(installer.hook_handler.hook_calls)) installer.invoke_install_hook(cluster_config.BootstrapPhase.post_install, {"foo": "bar"}) @@ -243,6 +249,7 @@ def test_invokes_hook_no_java_home(self): class DockerProvisionerTests(TestCase): maxDiff = None + @mock.patch("uuid.uuid4") def test_provisioning_with_defaults(self, uuid4): uuid4.return_value = "9dbc682e-d32a-4669-8fbe-56fb77120dd4" @@ -253,48 +260,53 @@ def test_provisioning_with_defaults(self, uuid4): benchmark_root = os.path.normpath(os.path.join(os.path.dirname(os.path.realpath(__file__)), os.pardir, os.pardir, "solrorbit")) - c = cluster_config.ClusterConfigInstance("unit-test-cluster-config-instance", None, "/tmp", variables={ - "docker_image": "solr" - }) - - docker = provisioner.DockerProvisioner(cluster_config=c, - node_name="benchmark-node-0", - ip="10.17.22.33", - http_port=38983, - node_root_dir=node_root_dir, - distribution_version="1.1.0", - benchmark_root=benchmark_root) - - self.assertDictEqual({ - "cluster_name": "benchmark-provisioned-cluster", - "node_name": "benchmark-node-0", - "install_root_path": "/var/solr", - "data_paths": ["/var/solr/data"], - "log_path": "/var/solr/logs", - "heap_dump_path": "/var/solr/heapdump", - "discovery_type": "single-node", - "network_host": "0.0.0.0", - "http_port": "38983", - "zookeeper_port": "39983", - "cluster_settings": { + c = cluster_config.ClusterConfigInstance("unit-test-cluster-config-instance", None, "/tmp", variables={"docker_image": "solr"}) + + docker = provisioner.DockerProvisioner( + cluster_config=c, + node_name="benchmark-node-0", + ip="10.17.22.33", + http_port=38983, + node_root_dir=node_root_dir, + distribution_version="1.1.0", + benchmark_root=benchmark_root, + ) + + self.assertDictEqual( + { + "cluster_name": "benchmark-provisioned-cluster", + "node_name": "benchmark-node-0", + "install_root_path": "/var/solr", + "data_paths": ["/var/solr/data"], + "log_path": "/var/solr/logs", + "heap_dump_path": "/var/solr/heapdump", + "discovery_type": "single-node", + "network_host": "0.0.0.0", + "http_port": "38983", + "zookeeper_port": "39983", + "cluster_settings": {}, + "docker_image": "solr", + }, + docker.config_vars, + ) + + self.assertDictEqual( + { + "solr_data_dir": data_dir, + "solr_log_dir": log_dir, + "solr_heap_dump_dir": heap_dump_dir, + "solr_version": "1.1.0", + "docker_image": "solr", + "http_port": 38983, + "mounts": {}, }, - "docker_image": "solr" - }, docker.config_vars) - - self.assertDictEqual({ - "solr_data_dir": data_dir, - "solr_log_dir": log_dir, - "solr_heap_dump_dir": heap_dump_dir, - "solr_version": "1.1.0", - "docker_image": "solr", - "http_port": 38983, - "mounts": {} - }, docker.docker_vars(mounts={})) + docker.docker_vars(mounts={}), + ) docker_cfg = docker._render_template_from_file(docker.docker_vars(mounts={})) self.assertEqual( -"""version: '3' + """version: '3' services: solr-node1: image: solr:1.1.0 @@ -324,7 +336,10 @@ def test_provisioning_with_defaults(self, uuid4): volumes: solr-data1: networks: - solr-net:""" % (data_dir, log_dir, heap_dump_dir), docker_cfg) + solr-net:""" + % (data_dir, log_dir, heap_dump_dir), + docker_cfg, + ) @mock.patch("uuid.uuid4") def test_provisioning_with_variables(self, uuid4): @@ -336,24 +351,24 @@ def test_provisioning_with_variables(self, uuid4): benchmark_root = os.path.normpath(os.path.join(os.path.dirname(os.path.realpath(__file__)), os.pardir, os.pardir, "solrorbit")) - c = cluster_config.ClusterConfigInstance("unit-test-cluster-config-instance", None, "/tmp", variables={ - "docker_image": "solr", - "docker_mem_limit": "256m", - "docker_cpu_count": 2 - }) + c = cluster_config.ClusterConfigInstance( + "unit-test-cluster-config-instance", None, "/tmp", variables={"docker_image": "solr", "docker_mem_limit": "256m", "docker_cpu_count": 2} + ) - docker = provisioner.DockerProvisioner(cluster_config=c, - node_name="benchmark-node-0", - ip="10.17.22.33", - http_port=38983, - node_root_dir=node_root_dir, - distribution_version="1.1.0", - benchmark_root=benchmark_root) + docker = provisioner.DockerProvisioner( + cluster_config=c, + node_name="benchmark-node-0", + ip="10.17.22.33", + http_port=38983, + node_root_dir=node_root_dir, + distribution_version="1.1.0", + benchmark_root=benchmark_root, + ) docker_cfg = docker._render_template_from_file(docker.docker_vars(mounts={})) self.assertEqual( -"""version: '3' + """version: '3' services: solr-node1: image: solr:1.1.0 @@ -385,7 +400,10 @@ def test_provisioning_with_variables(self, uuid4): volumes: solr-data1: networks: - solr-net:""" % (data_dir, log_dir, heap_dump_dir), docker_cfg) + solr-net:""" + % (data_dir, log_dir, heap_dump_dir), + docker_cfg, + ) class CleanupTests(TestCase): @@ -394,10 +412,7 @@ class CleanupTests(TestCase): def test_preserves(self, mock_path_exists, mock_rm): mock_path_exists.return_value = True - provisioner.cleanup( - preserve=True, - install_dir="./benchmark/test_runs/install", - data_paths=["./benchmark/test_runs/data"]) + provisioner.cleanup(preserve=True, install_dir="./benchmark/test_runs/install", data_paths=["./benchmark/test_runs/data"]) self.assertEqual(mock_path_exists.call_count, 0) self.assertEqual(mock_rm.call_count, 0) @@ -407,10 +422,7 @@ def test_preserves(self, mock_path_exists, mock_rm): def test_cleanup(self, mock_path_exists, mock_rm): mock_path_exists.return_value = True - provisioner.cleanup( - preserve=False, - install_dir="./benchmark/test_runs/install", - data_paths=["./benchmark/test_runs/data"]) + provisioner.cleanup(preserve=False, install_dir="./benchmark/test_runs/install", data_paths=["./benchmark/test_runs/data"]) expected_dir_calls = [mock.call("/tmp/some/data-path-dir"), mock.call("/benchmark-root/workload/test_procedure/es-bin")] mock_path_exists.mock_calls = expected_dir_calls diff --git a/tests/builder/supplier_test.py b/tests/builder/supplier_test.py index 3224f72f..ef35fbf4 100644 --- a/tests/builder/supplier_test.py +++ b/tests/builder/supplier_test.py @@ -16,7 +16,7 @@ # not use this file except in compliance with the License. # You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an @@ -39,12 +39,10 @@ class RevisionExtractorTests(TestCase): def test_single_revision(self): self.assertDictEqual({"solr": "67c2f42", "all": "67c2f42"}, supplier._extract_revisions("67c2f42")) self.assertDictEqual({"solr": "current", "all": "current"}, supplier._extract_revisions("current")) - self.assertDictEqual({"solr": "@2015-01-01-01:00:00", "all": "@2015-01-01-01:00:00"}, - supplier._extract_revisions("@2015-01-01-01:00:00")) + self.assertDictEqual({"solr": "@2015-01-01-01:00:00", "all": "@2015-01-01-01:00:00"}, supplier._extract_revisions("@2015-01-01-01:00:00")) def test_multiple_revisions(self): - self.assertDictEqual({"solr": "67c2f42", "some-plugin": "current"}, - supplier._extract_revisions("solr:67c2f42,some-plugin:current")) + self.assertDictEqual({"solr": "67c2f42", "some-plugin": "current"}, supplier._extract_revisions("solr:67c2f42,some-plugin:current")) def test_invalid_revisions(self): with self.assertRaises(exceptions.SystemSetupError) as ctx: @@ -84,8 +82,7 @@ def test_checkout_current(self, mock_is_working_copy, mock_clone, mock_pull, moc mock_is_working_copy.assert_called_with("/src") self.assertEqual(0, mock_clone.call_count) self.assertEqual(0, mock_pull.call_count) - mock_head_revision.assert_called_with("/src")\ - + mock_head_revision.assert_called_with("/src") @mock.patch("solrorbit.utils.git.head_revision", autospec=True) @mock.patch("solrorbit.utils.git.checkout") @@ -208,16 +205,9 @@ def add_os_artifact(binaries): # no version / revision provided renderer = supplier.TemplateRenderer(version=None) - dist_cfg = { - "release_url": "https://downloads.apache.org/solr/solr/{{VERSION}}/solr-{{VERSION}}.tgz" - } - file_resolver = supplier.FileNameResolver( - distribution_config=dist_cfg, - template_renderer=renderer - ) - cached_supplier = supplier.CachedSourceSupplier(distributions_root="/tmp", - source_supplier=opensearch, - file_resolver=file_resolver) + dist_cfg = {"release_url": "https://downloads.apache.org/solr/solr/{{VERSION}}/solr-{{VERSION}}.tgz"} + file_resolver = supplier.FileNameResolver(distribution_config=dist_cfg, template_renderer=renderer) + cached_supplier = supplier.CachedSourceSupplier(distributions_root="/tmp", source_supplier=opensearch, file_resolver=file_resolver) cached_supplier.fetch() cached_supplier.prepare() @@ -238,16 +228,9 @@ def test_uses_already_cached_artifact(self, opensearch, path_exists): path_exists.return_value = True renderer = supplier.TemplateRenderer(version="abc123") - dist_cfg = { - "release_url": "https://downloads.apache.org/solr/solr/{{VERSION}}/solr-{{VERSION}}.tgz" - } - file_resolver = supplier.FileNameResolver( - distribution_config=dist_cfg, - template_renderer=renderer - ) - cached_supplier = supplier.CachedSourceSupplier(distributions_root="/tmp", - source_supplier=opensearch, - file_resolver=file_resolver) + dist_cfg = {"release_url": "https://downloads.apache.org/solr/solr/{{VERSION}}/solr-{{VERSION}}.tgz"} + file_resolver = supplier.FileNameResolver(distribution_config=dist_cfg, template_renderer=renderer) + cached_supplier = supplier.CachedSourceSupplier(distributions_root="/tmp", source_supplier=opensearch, file_resolver=file_resolver) cached_supplier.fetch() cached_supplier.prepare() @@ -278,16 +261,11 @@ def add_os_artifact(binaries): renderer = supplier.TemplateRenderer(version="abc123") - dist_cfg = { - "release_url": "https://downloads.apache.org/solr/solr/{{VERSION}}/solr-{{VERSION}}.tgz" - } + dist_cfg = {"release_url": "https://downloads.apache.org/solr/solr/{{VERSION}}/solr-{{VERSION}}.tgz"} - cached_supplier = supplier.CachedSourceSupplier(distributions_root="/tmp", - source_supplier=opensearch, - file_resolver=supplier.FileNameResolver( - distribution_config=dist_cfg, - template_renderer=renderer - )) + cached_supplier = supplier.CachedSourceSupplier( + distributions_root="/tmp", source_supplier=opensearch, file_resolver=supplier.FileNameResolver(distribution_config=dist_cfg, template_renderer=renderer) + ) cached_supplier.fetch() cached_supplier.prepare() @@ -330,16 +308,11 @@ def add_os_artifact(binaries): renderer = supplier.TemplateRenderer(version="abc123") - dist_cfg = { - "release_url": "https://downloads.apache.org/solr/solr/{{VERSION}}/solr-{{VERSION}}.tgz" - } + dist_cfg = {"release_url": "https://downloads.apache.org/solr/solr/{{VERSION}}/solr-{{VERSION}}.tgz"} - cached_supplier = supplier.CachedSourceSupplier(distributions_root="/tmp", - source_supplier=opensearch, - file_resolver=supplier.FileNameResolver( - distribution_config=dist_cfg, - template_renderer=renderer - )) + cached_supplier = supplier.CachedSourceSupplier( + distributions_root="/tmp", source_supplier=opensearch, file_resolver=supplier.FileNameResolver(distribution_config=dist_cfg, template_renderer=renderer) + ) cached_supplier.fetch() cached_supplier.prepare() @@ -360,14 +333,9 @@ def setUp(self): super().setUp() renderer = supplier.TemplateRenderer(version="9.10.1") - dist_cfg = { - "release_url": "https://downloads.apache.org/solr/solr/{{VERSION}}/solr-{{VERSION}}.tgz" - } + dist_cfg = {"release_url": "https://downloads.apache.org/solr/solr/{{VERSION}}/solr-{{VERSION}}.tgz"} - self.resolver = supplier.FileNameResolver( - distribution_config=dist_cfg, - template_renderer=renderer - ) + self.resolver = supplier.FileNameResolver(distribution_config=dist_cfg, template_renderer=renderer) def test_resolve(self): self.resolver.revision = "9.10.1" @@ -418,7 +386,7 @@ def test_prunes_old_files(self, rm, lstat, isfile, listdir, exists): # opensearch-1.0.0.tar.gz PruneTests.LStat(st_ctime=int(ten_days_ago.timestamp())), # opensearch-1.0.1-x64.tar.gz - PruneTests.LStat(st_ctime=int(one_day_ago.timestamp())) + PruneTests.LStat(st_ctime=int(one_day_ago.timestamp())), ] supplier._prune(root_path="/tmp/test", max_age_days=7) @@ -428,70 +396,57 @@ def test_prunes_old_files(self, rm, lstat, isfile, listdir, exists): class SolrSourceSupplierTests(TestCase): def test_no_build(self): - cluster_config_instance = cluster_config.ClusterConfigInstance("default", root_path=None, config_paths=[], variables={ - "clean_command": "./gradlew clean", - "system.build_command": "./gradlew assemble" - }) + cluster_config_instance = cluster_config.ClusterConfigInstance( + "default", root_path=None, config_paths=[], variables={"clean_command": "./gradlew clean", "system.build_command": "./gradlew assemble"} + ) renderer = supplier.TemplateRenderer(version=None) - opensearch = supplier.SourceSupplier(revision="abc", - os_src_dir="/src", - remote_url="", - cluster_config=cluster_config_instance, - builder=None, - template_renderer=renderer) + opensearch = supplier.SourceSupplier(revision="abc", os_src_dir="/src", remote_url="", cluster_config=cluster_config_instance, builder=None, template_renderer=renderer) opensearch.prepare() # nothing has happened (intentionally) because there is no builder def test_build(self): - cluster_config_instance = cluster_config.ClusterConfigInstance("default", root_path=None, config_paths=[], variables={ - "clean_command": "./gradlew clean", - "system.build_command": "./gradlew assemble" - }) + cluster_config_instance = cluster_config.ClusterConfigInstance( + "default", root_path=None, config_paths=[], variables={"clean_command": "./gradlew clean", "system.build_command": "./gradlew assemble"} + ) builder = mock.create_autospec(supplier.Builder) renderer = supplier.TemplateRenderer(version="abc") - opensearch = supplier.SourceSupplier(revision="abc", - os_src_dir="/src", - remote_url="", - cluster_config=cluster_config_instance, - builder=builder, - template_renderer=renderer) + opensearch = supplier.SourceSupplier(revision="abc", os_src_dir="/src", remote_url="", cluster_config=cluster_config_instance, builder=builder, template_renderer=renderer) opensearch.prepare() builder.build.assert_called_once_with(["./gradlew clean", "./gradlew assemble"]) def test_raises_error_on_missing_cluster_config_variable(self): - cluster_config_instance = cluster_config.ClusterConfigInstance("default", root_path=None, config_paths=[], variables={ - "clean_command": "./gradlew clean", - # system.build_command is not defined - }) + cluster_config_instance = cluster_config.ClusterConfigInstance( + "default", + root_path=None, + config_paths=[], + variables={ + "clean_command": "./gradlew clean", + # system.build_command is not defined + }, + ) renderer = supplier.TemplateRenderer(version="abc") builder = mock.create_autospec(supplier.Builder) - opensearch = supplier.SourceSupplier(revision="abc", - os_src_dir="/src", - remote_url="", - cluster_config=cluster_config_instance, - builder=builder, - template_renderer=renderer) - with self.assertRaisesRegex(exceptions.SystemSetupError, - "ClusterConfigInstance \"default\" requires config key \"system.build_command\""): + opensearch = supplier.SourceSupplier(revision="abc", os_src_dir="/src", remote_url="", cluster_config=cluster_config_instance, builder=builder, template_renderer=renderer) + with self.assertRaisesRegex(exceptions.SystemSetupError, 'ClusterConfigInstance "default" requires config key "system.build_command"'): opensearch.prepare() self.assertEqual(0, builder.build.call_count) @mock.patch("glob.glob", lambda p: ["opensearch.tar.gz"]) def test_add_opensearch_binary(self): - cluster_config_instance = cluster_config.ClusterConfigInstance("default", root_path=None, config_paths=[], variables={ - "clean_command": "./gradlew clean", - "system.build_command": "./gradlew assemble", - "system.artifact_path_pattern": "distribution/archives/tar/build/distributions/*.tar.gz" - }) + cluster_config_instance = cluster_config.ClusterConfigInstance( + "default", + root_path=None, + config_paths=[], + variables={ + "clean_command": "./gradlew clean", + "system.build_command": "./gradlew assemble", + "system.artifact_path_pattern": "distribution/archives/tar/build/distributions/*.tar.gz", + }, + ) renderer = supplier.TemplateRenderer(version="abc") - opensearch = supplier.SourceSupplier(revision="abc", - os_src_dir="/src", - remote_url="", - cluster_config=cluster_config_instance, - builder=None, - template_renderer=renderer) + opensearch = supplier.SourceSupplier(revision="abc", os_src_dir="/src", remote_url="", cluster_config=cluster_config_instance, builder=None, template_renderer=renderer) binaries = {} opensearch.add(binaries=binaries) self.assertEqual(binaries, {"solr": "opensearch.tar.gz"}) @@ -500,14 +455,12 @@ def test_add_opensearch_binary(self): class CreateSupplierTests(TestCase): def test_derive_supply_requirements_source_build(self): # corresponds to --revision="abc" - requirements = supplier._supply_requirements( - sources=True, revisions={"solr": "abc"}, distribution_version=None) + requirements = supplier._supply_requirements(sources=True, revisions={"solr": "abc"}, distribution_version=None) self.assertDictEqual({"solr": ("source", "abc", True)}, requirements) def test_derive_supply_requirements_distribution(self): # corresponds to --distribution-version=1.0.0 - requirements = supplier._supply_requirements( - sources=False, revisions={}, distribution_version="1.0.0") + requirements = supplier._supply_requirements(sources=False, revisions={}, distribution_version="1.0.0") self.assertDictEqual({"solr": ("distribution", "1.0.0", False)}, requirements) def test_create_suppliers_for_os_only_config(self): @@ -516,8 +469,12 @@ def test_create_suppliers_for_os_only_config(self): # default value from command line cfg.add(config.Scope.application, "builder", "source.revision", "current") cfg.add(config.Scope.application, "builder", "distribution.repository", "release") - cfg.add(config.Scope.application, "distributions", "release.url", - "https://artifacts.opensearch.org/releases/bundle/opensearch/{{VERSION}}/opensearch-{{VERSION}}-{{OSNAME}}-{{ARCH}}.tar.gz") + cfg.add( + config.Scope.application, + "distributions", + "release.url", + "https://artifacts.opensearch.org/releases/bundle/opensearch/{{VERSION}}/opensearch-{{VERSION}}-{{OSNAME}}-{{ARCH}}.tar.gz", + ) cfg.add(config.Scope.application, "distributions", "release.cache", True) cfg.add(config.Scope.application, "node", "root.dir", "/opt/benchmark") @@ -529,38 +486,41 @@ def test_create_suppliers_for_os_only_config(self): self.assertIsInstance(composite_supplier.suppliers[0], supplier.DistributionSupplier) - class DistributionRepositoryTests(TestCase): def test_release_repo_config_with_default_url(self): renderer = supplier.TemplateRenderer(version="9.10.1") - repo = supplier.DistributionRepository(name="release", distribution_config={ - "release_url": "https://downloads.apache.org/solr/solr/{{VERSION}}/solr-{{VERSION}}.tgz", - "release.cache": "true" - }, template_renderer=renderer) - self.assertEqual("https://downloads.apache.org/solr/solr/9.10.1/solr-9.10.1.tgz", - repo.download_url) + repo = supplier.DistributionRepository( + name="release", + distribution_config={"release_url": "https://downloads.apache.org/solr/solr/{{VERSION}}/solr-{{VERSION}}.tgz", "release.cache": "true"}, + template_renderer=renderer, + ) + self.assertEqual("https://downloads.apache.org/solr/solr/9.10.1/solr-9.10.1.tgz", repo.download_url) self.assertEqual("solr-9.10.1.tgz", repo.file_name) self.assertTrue(repo.cache) def test_release_repo_config_with_user_url(self): renderer = supplier.TemplateRenderer(version="9.10.1") - repo = supplier.DistributionRepository(name="release", distribution_config={ - "release_url": "https://downloads.apache.org/solr/solr/{{VERSION}}/solr-{{VERSION}}.tgz", - # user override - "release.url": "https://downloads.apache.org/solr/solr/{{VERSION}}/solr-{{VERSION}}.tgz", - "release.cache": "false" - }, template_renderer=renderer) - self.assertEqual("https://downloads.apache.org/solr/solr/9.10.1/solr-9.10.1.tgz", - repo.download_url) + repo = supplier.DistributionRepository( + name="release", + distribution_config={ + "release_url": "https://downloads.apache.org/solr/solr/{{VERSION}}/solr-{{VERSION}}.tgz", + # user override + "release.url": "https://downloads.apache.org/solr/solr/{{VERSION}}/solr-{{VERSION}}.tgz", + "release.cache": "false", + }, + template_renderer=renderer, + ) + self.assertEqual("https://downloads.apache.org/solr/solr/9.10.1/solr-9.10.1.tgz", repo.download_url) self.assertEqual("solr-9.10.1.tgz", repo.file_name) self.assertFalse(repo.cache) def test_missing_url(self): renderer = supplier.TemplateRenderer(version="9.10.1") - repo = supplier.DistributionRepository(name="miss", distribution_config={ - "release_url": "https://downloads.apache.org/solr/solr/{{VERSION}}/solr-{{VERSION}}.tgz", - "release.cache": "true" - }, template_renderer=renderer) + repo = supplier.DistributionRepository( + name="miss", + distribution_config={"release_url": "https://downloads.apache.org/solr/solr/{{VERSION}}/solr-{{VERSION}}.tgz", "release.cache": "true"}, + template_renderer=renderer, + ) with self.assertRaises(exceptions.SystemSetupError) as ctx: # pylint: disable=pointless-statement # noinspection PyStatementEffect @@ -569,11 +529,15 @@ def test_missing_url(self): def test_missing_cache(self): renderer = supplier.TemplateRenderer(version="1.0.0") - repo = supplier.DistributionRepository(name="release", distribution_config={ - "jdk.unbundled.release.url": "https://artifacts.opensearch\ + repo = supplier.DistributionRepository( + name="release", + distribution_config={ + "jdk.unbundled.release.url": "https://artifacts.opensearch\ .org/releases/bundle/opensearch/{{VERSION}}/opensearch-{{VERSION}}-{{OSNAME}}-{{ARCH}}.tar.gz", - "runtime.jdk.bundled": "false" - }, template_renderer=renderer) + "runtime.jdk.bundled": "false", + }, + template_renderer=renderer, + ) with self.assertRaises(exceptions.SystemSetupError) as ctx: # pylint: disable=pointless-statement # noinspection PyStatementEffect @@ -582,12 +546,16 @@ def test_missing_cache(self): def test_invalid_cache_value(self): renderer = supplier.TemplateRenderer(version="1.0.0") - repo = supplier.DistributionRepository(name="release", distribution_config={ - "jdk.unbundled.release.url": "https://artifacts.opensearch\ + repo = supplier.DistributionRepository( + name="release", + distribution_config={ + "jdk.unbundled.release.url": "https://artifacts.opensearch\ .org/releases/bundle/opensearch/{{VERSION}}/opensearch-{{VERSION}}-{{OSNAME}}-{{ARCH}}.tar.gz", - "runtime.jdk.bundled": "false", - "release.cache": "Invalid" - }, template_renderer=renderer) + "runtime.jdk.bundled": "false", + "release.cache": "Invalid", + }, + template_renderer=renderer, + ) with self.assertRaises(exceptions.SystemSetupError) as ctx: # pylint: disable=pointless-statement # noinspection PyStatementEffect @@ -596,25 +564,34 @@ def test_invalid_cache_value(self): def test_plugin_config_with_default_url(self): renderer = supplier.TemplateRenderer(version="5.5.0") - repo = supplier.DistributionRepository(name="release", distribution_config={ - "runtime.jdk.bundled": "false", - "plugin_example_release_url": "https://artifacts.example.org/downloads/plugins/example-{{VERSION}}.zip" - }, template_renderer=renderer) + repo = supplier.DistributionRepository( + name="release", + distribution_config={"runtime.jdk.bundled": "false", "plugin_example_release_url": "https://artifacts.example.org/downloads/plugins/example-{{VERSION}}.zip"}, + template_renderer=renderer, + ) self.assertEqual("https://artifacts.example.org/downloads/plugins/example-5.5.0.zip", repo.plugin_download_url("example")) def test_plugin_config_with_user_url(self): renderer = supplier.TemplateRenderer(version="5.5.0") - repo = supplier.DistributionRepository(name="release", distribution_config={ - "runtime.jdk.bundled": "false", - "plugin_example_release_url": "https://artifacts.example.org/downloads/plugins/example-{{VERSION}}.zip", - # user override - "plugin.example.release.url": "https://mirror.example.org/downloads/plugins/example-{{VERSION}}.zip" - }, template_renderer=renderer) + repo = supplier.DistributionRepository( + name="release", + distribution_config={ + "runtime.jdk.bundled": "false", + "plugin_example_release_url": "https://artifacts.example.org/downloads/plugins/example-{{VERSION}}.zip", + # user override + "plugin.example.release.url": "https://mirror.example.org/downloads/plugins/example-{{VERSION}}.zip", + }, + template_renderer=renderer, + ) self.assertEqual("https://mirror.example.org/downloads/plugins/example-5.5.0.zip", repo.plugin_download_url("example")) def test_missing_plugin_config(self): renderer = supplier.TemplateRenderer(version="5.5.0") - repo = supplier.DistributionRepository(name="release", distribution_config={ - "runtime.jdk.bundled": "false", - }, template_renderer=renderer) + repo = supplier.DistributionRepository( + name="release", + distribution_config={ + "runtime.jdk.bundled": "false", + }, + template_renderer=renderer, + ) self.assertIsNone(repo.plugin_download_url("not existing")) diff --git a/tests/builder/utils/artifact_variables_provider_test.py b/tests/builder/utils/artifact_variables_provider_test.py index 2dd8fb76..3cae9993 100644 --- a/tests/builder/utils/artifact_variables_provider_test.py +++ b/tests/builder/utils/artifact_variables_provider_test.py @@ -15,28 +15,16 @@ def test_x86(self): self.executor.execute.side_effect = [["Linux"], ["x86_64"]] variables = self.artifact_variables_provider.get_artifact_variables(self.host) - self.assertEqual(variables, { - "VERSION": None, - "OSNAME": "linux", - "ARCH": "x64" - }) + self.assertEqual(variables, {"VERSION": None, "OSNAME": "linux", "ARCH": "x64"}) def test_arm(self): self.executor.execute.side_effect = [["Linux"], ["aarch64"]] variables = self.artifact_variables_provider.get_artifact_variables(self.host) - self.assertEqual(variables, { - "VERSION": None, - "OSNAME": "linux", - "ARCH": "arm64" - }) + self.assertEqual(variables, {"VERSION": None, "OSNAME": "linux", "ARCH": "arm64"}) def test_version_supplied(self): self.executor.execute.side_effect = [["Linux"], ["aarch64"]] variables = self.artifact_variables_provider.get_artifact_variables(self.host, "1.23") - self.assertEqual(variables, { - "VERSION": "1.23", - "OSNAME": "linux", - "ARCH": "arm64" - }) + self.assertEqual(variables, {"VERSION": "1.23", "OSNAME": "linux", "ARCH": "arm64"}) diff --git a/tests/builder/utils/config_applier_test.py b/tests/builder/utils/config_applier_test.py index 48f05e82..979dcd4a 100644 --- a/tests/builder/utils/config_applier_test.py +++ b/tests/builder/utils/config_applier_test.py @@ -7,8 +7,9 @@ class ConfigApplierTest(TestCase): def setUp(self): - self.node = Node(binary_path="/fake_binary_path", data_paths=["/fake1", "/fake2"], - name=None, pid=None, telemetry=None, port=None, root_dir=None, log_path=None, heap_dump_path=None) + self.node = Node( + binary_path="/fake_binary_path", data_paths=["/fake1", "/fake2"], name=None, pid=None, telemetry=None, port=None, root_dir=None, log_path=None, heap_dump_path=None + ) self.host = None self.config_paths = ["/fake_config_path"] self.config_vars = {} @@ -26,16 +27,10 @@ def test_apply_config_binary_file(self, is_plain_text, os_walk): mounts = self.config_applier.apply_configs(self.host, self.node, self.config_paths, self.config_vars) - self.assertEqual(mounts, { - "/fake_binary_path/sub_fake_config_path/fake_file": "/var/solr/sub_fake_config_path/fake_file" - }) - self.path_manager.create_path.assert_has_calls([ - mock.call(self.host, "/fake_binary_path/sub_fake_config_path") - ]) + self.assertEqual(mounts, {"/fake_binary_path/sub_fake_config_path/fake_file": "/var/solr/sub_fake_config_path/fake_file"}) + self.path_manager.create_path.assert_has_calls([mock.call(self.host, "/fake_binary_path/sub_fake_config_path")]) self.template_renderer.render_template_file.assert_has_calls([]) - self.executor.execute.assert_has_calls([ - mock.call(self.host, "cp /fake_config_path/sub_fake_config_path/fake_file /fake_binary_path/sub_fake_config_path/fake_file") - ]) + self.executor.execute.assert_has_calls([mock.call(self.host, "cp /fake_config_path/sub_fake_config_path/fake_file /fake_binary_path/sub_fake_config_path/fake_file")]) @mock.patch("os.walk") @mock.patch("solrorbit.utils.io.is_plain_text") @@ -46,17 +41,10 @@ def test_apply_config_plaintext_file(self, is_plain_text, os_walk): with patch("builtins.open", mock_open(read_data="fake_data")) as mock_file: mounts = self.config_applier.apply_configs(self.host, self.node, self.config_paths, self.config_vars) - self.assertEqual(mounts, { - "/fake_binary_path/sub_fake_config_path/fake_file": "/var/solr/sub_fake_config_path/fake_file" - }) - self.path_manager.create_path.assert_has_calls([ - mock.call(self.host, "/fake_binary_path/sub_fake_config_path") - ]) - self.template_renderer.render_template_file.assert_has_calls([ - mock.call("/fake_config_path/sub_fake_config_path", self.config_vars, "/fake_config_path/sub_fake_config_path/fake_file") - ]) - self.executor.execute.assert_has_calls([ - mock.call(self.host, - "cp /fake_binary_path/sub_fake_config_path/fake_file /fake_binary_path/sub_fake_config_path/fake_file") - ]) - mock_file.assert_called_with("/fake_binary_path/sub_fake_config_path/fake_file", mode='a', encoding='utf-8') + self.assertEqual(mounts, {"/fake_binary_path/sub_fake_config_path/fake_file": "/var/solr/sub_fake_config_path/fake_file"}) + self.path_manager.create_path.assert_has_calls([mock.call(self.host, "/fake_binary_path/sub_fake_config_path")]) + self.template_renderer.render_template_file.assert_has_calls( + [mock.call("/fake_config_path/sub_fake_config_path", self.config_vars, "/fake_config_path/sub_fake_config_path/fake_file")] + ) + self.executor.execute.assert_has_calls([mock.call(self.host, "cp /fake_binary_path/sub_fake_config_path/fake_file /fake_binary_path/sub_fake_config_path/fake_file")]) + mock_file.assert_called_with("/fake_binary_path/sub_fake_config_path/fake_file", mode="a", encoding="utf-8") diff --git a/tests/builder/utils/host_cleaner_test.py b/tests/builder/utils/host_cleaner_test.py index be5cea6d..ee320b2f 100644 --- a/tests/builder/utils/host_cleaner_test.py +++ b/tests/builder/utils/host_cleaner_test.py @@ -8,8 +8,7 @@ class HostCleanerTest(TestCase): def setUp(self): - self.node = Node(binary_path="/fake", data_paths=["/fake1", "/fake2"], - name=None, pid=None, telemetry=None, port=None, root_dir=None, log_path=None, heap_dump_path=None) + self.node = Node(binary_path="/fake", data_paths=["/fake1", "/fake2"], name=None, pid=None, telemetry=None, port=None, root_dir=None, log_path=None, heap_dump_path=None) self.host = Host(address="fake", name="fake", metadata={}, node=self.node) self.path_manager = Mock() @@ -18,11 +17,7 @@ def setUp(self): def test_cleanup(self): self.host_cleaner.cleanup(self.host, False) - self.path_manager.delete_path.assert_has_calls([ - mock.call(self.host, "/fake1"), - mock.call(self.host, "/fake2"), - mock.call(self.host, "/fake") - ]) + self.path_manager.delete_path.assert_has_calls([mock.call(self.host, "/fake1"), mock.call(self.host, "/fake2"), mock.call(self.host, "/fake")]) def test_cleanup_preserve_install(self): self.host_cleaner.cleanup(self.host, True) diff --git a/tests/builder/utils/java_home_resolver_test.py b/tests/builder/utils/java_home_resolver_test.py index b9045882..0f244666 100644 --- a/tests/builder/utils/java_home_resolver_test.py +++ b/tests/builder/utils/java_home_resolver_test.py @@ -12,17 +12,8 @@ def setUp(self): self.java_home_resolver = JavaHomeResolver(self.executor) self.java_home_resolver.jdk_resolver = Mock() - self.variables = { - "system": { - "runtime": { - "jdk": { - "version": "12,11,10,9,8" - } - } - } - } - self.cluster_config = ClusterConfigInstance("fake_cluster_config", "/path/to/root", - ["/path/to/config"], variables=self.variables) + self.variables = {"system": {"runtime": {"jdk": {"version": "12,11,10,9,8"}}}} + self.cluster_config = ClusterConfigInstance("fake_cluster_config", "/path/to/root", ["/path/to/config"], variables=self.variables) def test_resolves_java_home_for_default_runtime_jdk(self): self.java_home_resolver.jdk_resolver.resolve_jdk_path.return_value = (12, "/opt/jdk12") diff --git a/tests/builder/utils/jdk_resolver_test.py b/tests/builder/utils/jdk_resolver_test.py index e9845887..d45d05ff 100644 --- a/tests/builder/utils/jdk_resolver_test.py +++ b/tests/builder/utils/jdk_resolver_test.py @@ -35,8 +35,9 @@ def test_generic_java_home_matches(self): def test_multiple_majors(self): # printenv, $JAVA_HOME -XshowSettings:properties -version x 2 self.executor.execute.side_effect = [ - ["JAVA_HOME=/fake/path", "JAVA14_HOME=/another/fake/path"], ["java.vm.specification.version = 14"], - ["java.vm.specification.version = 9"] + ["JAVA_HOME=/fake/path", "JAVA14_HOME=/another/fake/path"], + ["java.vm.specification.version = 14"], + ["java.vm.specification.version = 9"], ] _, jdk_path = self.jdk_resolver.resolve_jdk_path(self.host, [8, 14, 16]) diff --git a/tests/builder/utils/path_manager_test.py b/tests/builder/utils/path_manager_test.py index a305bba1..fc2cfbb8 100644 --- a/tests/builder/utils/path_manager_test.py +++ b/tests/builder/utils/path_manager_test.py @@ -13,32 +13,24 @@ def setUp(self): self.executor = Mock() self.path_manager = PathManager(self.executor) - @mock.patch('solrorbit.utils.io.ensure_dir') + @mock.patch("solrorbit.utils.io.ensure_dir") def test_create_path(self, ensure_dir): self.path_manager.create_path(self.host, self.path) - ensure_dir.assert_has_calls([ - mock.call(self.path) - ]) - self.executor.execute.assert_has_calls([ - mock.call(self.host, f"mkdir -m 0777 -p {self.path}") - ]) + ensure_dir.assert_has_calls([mock.call(self.path)]) + self.executor.execute.assert_has_calls([mock.call(self.host, f"mkdir -m 0777 -p {self.path}")]) - @mock.patch('solrorbit.utils.io.ensure_dir') + @mock.patch("solrorbit.utils.io.ensure_dir") def test_create_path_no_local_copy(self, ensure_dir): self.path_manager.create_path(self.host, self.path) ensure_dir.assert_has_calls([]) - self.executor.execute.assert_has_calls([ - mock.call(self.host, f"mkdir -m 0777 -p {self.path}") - ]) + self.executor.execute.assert_has_calls([mock.call(self.host, f"mkdir -m 0777 -p {self.path}")]) def test_delete_valid_path(self): self.path_manager.delete_path(self.host, self.path) - self.executor.execute.assert_has_calls([ - mock.call(self.host, f"rm -r {self.path}") - ]) + self.executor.execute.assert_has_calls([mock.call(self.host, f"rm -r {self.path}")]) def test_delete_invalid_path(self): self.path_manager.delete_path(self.host, "/") diff --git a/tests/builder/utils/template_renderer_test.py b/tests/builder/utils/template_renderer_test.py index bf825888..1b7a2e63 100644 --- a/tests/builder/utils/template_renderer_test.py +++ b/tests/builder/utils/template_renderer_test.py @@ -14,7 +14,7 @@ def setUp(self): self.file_name = "non-existent.txt" self.template_renderer = TemplateRenderer() - @mock.patch('jinja2.Environment.get_template') + @mock.patch("jinja2.Environment.get_template") def test_successful_render(self, get_template): template = Mock() get_template.return_value = template @@ -23,25 +23,20 @@ def test_successful_render(self, get_template): self.template_renderer.render_template_file(self.root_path, self.variables, self.file_name) def test_version_between_filter(self): - self.assertEqual(self.template_renderer.render_template_string('{{ "2.0.0" | version_between("2.0.0", "3.0.0")}}', - self.variables), "True") - self.assertEqual(self.template_renderer.render_template_string('{{ "2.2.3" | version_between("2.0.0", "3.0.0")}}', - self.variables), "True") - self.assertEqual(self.template_renderer.render_template_string('{{ "3.0.0" | version_between("2.0.0", "3.0.0")}}', - self.variables), "True") - self.assertEqual(self.template_renderer.render_template_string('{{ "1.9.0" | version_between("2.0.0", "3.0.0")}}', - self.variables), "False") - self.assertEqual(self.template_renderer.render_template_string('{{ "3.0.1" | version_between("2.0.0", "3.0.0")}}', - self.variables), "False") - - @mock.patch('jinja2.Environment.get_template') + self.assertEqual(self.template_renderer.render_template_string('{{ "2.0.0" | version_between("2.0.0", "3.0.0")}}', self.variables), "True") + self.assertEqual(self.template_renderer.render_template_string('{{ "2.2.3" | version_between("2.0.0", "3.0.0")}}', self.variables), "True") + self.assertEqual(self.template_renderer.render_template_string('{{ "3.0.0" | version_between("2.0.0", "3.0.0")}}', self.variables), "True") + self.assertEqual(self.template_renderer.render_template_string('{{ "1.9.0" | version_between("2.0.0", "3.0.0")}}', self.variables), "False") + self.assertEqual(self.template_renderer.render_template_string('{{ "3.0.1" | version_between("2.0.0", "3.0.0")}}', self.variables), "False") + + @mock.patch("jinja2.Environment.get_template") def test_template_syntax_error(self, get_template): get_template.side_effect = TemplateSyntaxError("fake", 12) with self.assertRaises(InvalidSyntax): self.template_renderer.render_template_file(self.root_path, self.variables, self.file_name) - @mock.patch('jinja2.Environment.get_template') + @mock.patch("jinja2.Environment.get_template") def test_unknown_error(self, get_template): template = Mock() get_template.return_value = template diff --git a/tests/client_test.py b/tests/client_test.py index e99fb8b3..428d6df5 100644 --- a/tests/client_test.py +++ b/tests/client_test.py @@ -16,7 +16,7 @@ # not use this file except in compliance with the License. # You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an diff --git a/tests/config_test.py b/tests/config_test.py index 784678ee..4a92eec7 100644 --- a/tests/config_test.py +++ b/tests/config_test.py @@ -16,7 +16,7 @@ # not use this file except in compliance with the License. # You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an @@ -65,21 +65,17 @@ def backup(self): self.backup_created = True def store_default_config(self): - self.store({ - "distributions": { - "release.url": "https://acme.com/releases", - "release.cache": "true", - }, - "system": { - "env.name": "existing-unit-test-config" - }, - "meta": { - "config.version": config.Config.CURRENT_CONFIG_VERSION - }, - "benchmarks": { - "local.dataset.cache": "/tmp/benchmark/data" + self.store( + { + "distributions": { + "release.url": "https://acme.com/releases", + "release.cache": "true", + }, + "system": {"env.name": "existing-unit-test-config"}, + "meta": {"config.version": config.Config.CURRENT_CONFIG_VERSION}, + "benchmarks": {"local.dataset.cache": "/tmp/benchmark/data"}, } - }) + ) def store(self, c): self.present = True @@ -102,14 +98,7 @@ def test_load_existing_config(self): cfg = config.Config(config_file_class=InMemoryConfigStore) self.assertFalse(cfg.config_present()) - sample_config = { - "tests": { - "sample.key": "value" - }, - "meta": { - "config.version": config.Config.CURRENT_CONFIG_VERSION - } - } + sample_config = {"tests": {"sample.key": "value"}, "meta": {"config.version": config.Config.CURRENT_CONFIG_VERSION}} cfg.config_file.store(sample_config) self.assertTrue(cfg.config_present()) @@ -126,18 +115,9 @@ def test_load_all_opts_in_section(self): self.assertFalse(cfg.config_present()) sample_config = { - "distributions": { - "release.url": "https://acme.com/releases", - "release.cache": "true", - "snapshot.url": "https://acme.com/snapshots", - "snapshot.cache": "false" - }, - "system": { - "env.name": "local" - }, - "meta": { - "config.version": config.Config.CURRENT_CONFIG_VERSION - } + "distributions": {"release.url": "https://acme.com/releases", "release.cache": "true", "snapshot.url": "https://acme.com/snapshots", "snapshot.cache": "false"}, + "system": {"env.name": "local"}, + "meta": {"config.version": config.Config.CURRENT_CONFIG_VERSION}, } cfg.config_file.store(sample_config) @@ -146,27 +126,23 @@ def test_load_all_opts_in_section(self): # override a value so we can see that the scoping logic still works. Default is scope "application" cfg.add(config.Scope.applicationOverride, "distributions", "snapshot.cache", "true") - self.assertEqual({ - "release.url": "https://acme.com/releases", - "release.cache": "true", - "snapshot.url": "https://acme.com/snapshots", - # overridden! - "snapshot.cache": "true" - }, cfg.all_opts("distributions")) + self.assertEqual( + { + "release.url": "https://acme.com/releases", + "release.cache": "true", + "snapshot.url": "https://acme.com/snapshots", + # overridden! + "snapshot.cache": "true", + }, + cfg.all_opts("distributions"), + ) def test_add_all_in_section(self): source_cfg = config.Config(config_file_class=InMemoryConfigStore) sample_config = { - "tests": { - "sample.key": "value", - "sample.key2": "value" - }, - "no_copy": { - "other.key": "value" - }, - "meta": { - "config.version": config.Config.CURRENT_CONFIG_VERSION - } + "tests": {"sample.key": "value", "sample.key2": "value"}, + "no_copy": {"other.key": "value"}, + "meta": {"config.version": config.Config.CURRENT_CONFIG_VERSION}, } source_cfg.config_file.store(sample_config) source_cfg.load_config() @@ -211,22 +187,21 @@ def test_can_load_and_amend_existing_config(self): base_cfg.add(config.Scope.application, "benchmarks", "local.dataset.cache", "/base-config/data-set-cache") base_cfg.add(config.Scope.application, "unit-test", "sample.property", "let me copy you") - cfg = config.auto_load_local_config(base_cfg, additional_sections=["unit-test"], - config_file_class=InMemoryConfigStore, present=True, config={ - "distributions": { - "release.url": "https://acme.com/releases", - "release.cache": "true", - }, - "system": { - "env.name": "existing-unit-test-config" - }, - "meta": { - "config.version": config.Config.CURRENT_CONFIG_VERSION + cfg = config.auto_load_local_config( + base_cfg, + additional_sections=["unit-test"], + config_file_class=InMemoryConfigStore, + present=True, + config={ + "distributions": { + "release.url": "https://acme.com/releases", + "release.cache": "true", + }, + "system": {"env.name": "existing-unit-test-config"}, + "meta": {"config.version": config.Config.CURRENT_CONFIG_VERSION}, + "benchmarks": {"local.dataset.cache": "/tmp/benchmark/data"}, }, - "benchmarks": { - "local.dataset.cache": "/tmp/benchmark/data" - } - }) + ) self.assertTrue(cfg.config_file.present) # did not just copy base config self.assertNotEqual(base_cfg.opts("benchmarks", "local.dataset.cache"), cfg.opts("benchmarks", "local.dataset.cache")) @@ -241,27 +216,26 @@ def test_can_migrate_outdated_config(self): base_cfg.add(config.Scope.application, "benchmarks", "local.dataset.cache", "/base-config/data-set-cache") base_cfg.add(config.Scope.application, "unit-test", "sample.property", "let me copy you") - cfg = config.auto_load_local_config(base_cfg, additional_sections=["unit-test"], - config_file_class=InMemoryConfigStore, present=True, config={ + cfg = config.auto_load_local_config( + base_cfg, + additional_sections=["unit-test"], + config_file_class=InMemoryConfigStore, + present=True, + config={ "distributions": { "release.url": "https://acme.com/releases", "release.cache": "true", }, - "system": { - "env.name": "existing-unit-test-config" - }, + "system": {"env.name": "existing-unit-test-config"}, # outdated "meta": { # ensure we don't attempt to migrate if that version is unsupported "config.version": max(config.Config.CURRENT_CONFIG_VERSION - 1, config.Config.EARLIEST_SUPPORTED_VERSION) }, - "benchmarks": { - "local.dataset.cache": "/tmp/benchmark/data" - }, - "runtime": { - "java8.home": "/opt/jdk8" - } - }) + "benchmarks": {"local.dataset.cache": "/tmp/benchmark/data"}, + "runtime": {"java8.home": "/opt/jdk8"}, + }, + ) self.assertTrue(cfg.config_file.present) # did not just copy base config self.assertNotEqual(base_cfg.opts("benchmarks", "local.dataset.cache"), cfg.opts("benchmarks", "local.dataset.cache")) @@ -276,62 +250,34 @@ class ConfigMigrationTests(TestCase): def test_does_not_migrate_outdated_config(self): config_file = InMemoryConfigStore("test") sample_config = { - "system": { - "root.dir": "in-memory" - }, - "provisioning": { - - }, - "build": { - "maven.bin": "/usr/local/mvn" - }, - "benchmarks": { - "metrics.stats.disk.device": "/dev/hdd1" - }, - "reporting": { - "results.base.dir": "/tests/benchmark/reporting", - "output.html.results.filename": "index.html" - }, + "system": {"root.dir": "in-memory"}, + "provisioning": {}, + "build": {"maven.bin": "/usr/local/mvn"}, + "benchmarks": {"metrics.stats.disk.device": "/dev/hdd1"}, + "reporting": {"results.base.dir": "/tests/benchmark/reporting", "output.html.results.filename": "index.html"}, "runtime": { "java8.home": "/opt/jdk/8", - } + }, } config_file.store(sample_config) - with self.assertRaisesRegex(exceptions.ConfigError, - "The config file.*is too old. Please delete it and reconfigure from scratch"): + with self.assertRaisesRegex(exceptions.ConfigError, "The config file.*is too old. Please delete it and reconfigure from scratch"): config.migrate(config_file, config.Config.EARLIEST_SUPPORTED_VERSION - 1, config.Config.CURRENT_CONFIG_VERSION, out=null_output) # catch all test, migrations are checked in more detail in the other tests def test_migrate_from_earliest_supported_to_latest(self): config_file = InMemoryConfigStore("test") sample_config = { - "meta": { - "config.version": config.Config.EARLIEST_SUPPORTED_VERSION - }, - "system": { - "root.dir": "in-memory" - }, - "provisioning": { - - }, - "build": { - "maven.bin": "/usr/local/mvn" - }, - "benchmarks": { - "metrics.stats.disk.device": "/dev/hdd1" - }, - "reporting": { - "results.base.dir": "/tests/benchmark/reporting", - "output.html.results.filename": "index.html" - }, + "meta": {"config.version": config.Config.EARLIEST_SUPPORTED_VERSION}, + "system": {"root.dir": "in-memory"}, + "provisioning": {}, + "build": {"maven.bin": "/usr/local/mvn"}, + "benchmarks": {"metrics.stats.disk.device": "/dev/hdd1"}, + "reporting": {"results.base.dir": "/tests/benchmark/reporting", "output.html.results.filename": "index.html"}, "runtime": { "java8.home": "/opt/jdk/8", }, - "distributions": { - "release.url": "https://artifacts.opensearch.org/releases/bundle/opensearch/{{VERSION}}/opensearch-" - "{{VERSION}}-linux-x64.tar.gz" - } + "distributions": {"release.url": "https://artifacts.opensearch.org/releases/bundle/opensearch/{{VERSION}}/opensearch-{{VERSION}}-linux-x64.tar.gz"}, } config_file.store(sample_config) diff --git a/tests/metrics_test.py b/tests/metrics_test.py index 7cee51b8..2f0d9b6e 100644 --- a/tests/metrics_test.py +++ b/tests/metrics_test.py @@ -16,7 +16,7 @@ # not use this file except in compliance with the License. # You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an @@ -43,6 +43,7 @@ AWS_SECRET_ACCESS_KEY_LENGTH = 40 AWS_SESSION_TOKEN_LENGTH = 752 + class StaticClock: NOW = 1453362707 @@ -110,50 +111,42 @@ def tearDown(self): def test_get_one(self): duration = StaticClock.NOW - self.metrics_store.open(InMemoryMetricsStoreTests.TEST_RUN_ID, InMemoryMetricsStoreTests.TEST_RUN_TIMESTAMP, - "test", "append-no-conflicts", "defaults", create=True) - self.metrics_store.put_value_cluster_level("service_time", 500, "ms", relative_time=duration-400, task="task1") + self.metrics_store.open(InMemoryMetricsStoreTests.TEST_RUN_ID, InMemoryMetricsStoreTests.TEST_RUN_TIMESTAMP, "test", "append-no-conflicts", "defaults", create=True) + self.metrics_store.put_value_cluster_level("service_time", 500, "ms", relative_time=duration - 400, task="task1") self.metrics_store.put_value_cluster_level("service_time", 600, "ms", relative_time=duration, task="task1") - self.metrics_store.put_value_cluster_level("final_index_size", 1000, "GB", relative_time=duration-300) + self.metrics_store.put_value_cluster_level("final_index_size", 1000, "GB", relative_time=duration - 300) self.metrics_store.close() - self.metrics_store.open(InMemoryMetricsStoreTests.TEST_RUN_ID, InMemoryMetricsStoreTests.TEST_RUN_TIMESTAMP, - "test", "append-no-conflicts", "defaults") + self.metrics_store.open(InMemoryMetricsStoreTests.TEST_RUN_ID, InMemoryMetricsStoreTests.TEST_RUN_TIMESTAMP, "test", "append-no-conflicts", "defaults") - actual_duration = self.metrics_store.get_one("service_time", task="task1", mapper=lambda doc: doc["relative-time-ms"], - sort_key="relative-time-ms", sort_reverse=True) + actual_duration = self.metrics_store.get_one("service_time", task="task1", mapper=lambda doc: doc["relative-time-ms"], sort_key="relative-time-ms", sort_reverse=True) self.assertEqual(duration * 1000, actual_duration) def test_get_one_no_hits(self): duration = StaticClock.NOW - self.metrics_store.open(InMemoryMetricsStoreTests.TEST_RUN_ID, InMemoryMetricsStoreTests.TEST_RUN_TIMESTAMP, - "test", "append-no-conflicts", "defaults", create=True) - self.metrics_store.put_value_cluster_level("final_index_size", 1000, "GB", relative_time=duration-300) + self.metrics_store.open(InMemoryMetricsStoreTests.TEST_RUN_ID, InMemoryMetricsStoreTests.TEST_RUN_TIMESTAMP, "test", "append-no-conflicts", "defaults", create=True) + self.metrics_store.put_value_cluster_level("final_index_size", 1000, "GB", relative_time=duration - 300) self.metrics_store.close() - self.metrics_store.open(InMemoryMetricsStoreTests.TEST_RUN_ID, InMemoryMetricsStoreTests.TEST_RUN_TIMESTAMP, - "test", "append-no-conflicts", "defaults") + self.metrics_store.open(InMemoryMetricsStoreTests.TEST_RUN_ID, InMemoryMetricsStoreTests.TEST_RUN_TIMESTAMP, "test", "append-no-conflicts", "defaults") - actual_duration = self.metrics_store.get_one("service_time", task="task1", mapper=lambda doc: doc["relative-time-ms"], - sort_key="relative-time-ms", sort_reverse=True) + actual_duration = self.metrics_store.get_one("service_time", task="task1", mapper=lambda doc: doc["relative-time-ms"], sort_key="relative-time-ms", sort_reverse=True) self.assertIsNone(actual_duration) def test_get_value(self): throughput = 5000 - self.metrics_store.open(InMemoryMetricsStoreTests.TEST_RUN_ID, InMemoryMetricsStoreTests.TEST_RUN_TIMESTAMP, - "test", "append-no-conflicts", "defaults", create=True) + self.metrics_store.open(InMemoryMetricsStoreTests.TEST_RUN_ID, InMemoryMetricsStoreTests.TEST_RUN_TIMESTAMP, "test", "append-no-conflicts", "defaults", create=True) self.metrics_store.put_value_cluster_level("indexing_throughput", 1, "docs/s", sample_type=metrics.SampleType.Warmup) self.metrics_store.put_value_cluster_level("indexing_throughput", throughput, "docs/s") self.metrics_store.put_value_cluster_level("final_index_size", 1000, "GB") self.metrics_store.close() - self.metrics_store.open(InMemoryMetricsStoreTests.TEST_RUN_ID, InMemoryMetricsStoreTests.TEST_RUN_TIMESTAMP, - "test", "append-no-conflicts", "defaults") + self.metrics_store.open(InMemoryMetricsStoreTests.TEST_RUN_ID, InMemoryMetricsStoreTests.TEST_RUN_TIMESTAMP, "test", "append-no-conflicts", "defaults") self.assertEqual(1, self.metrics_store.get_one("indexing_throughput", sample_type=metrics.SampleType.Warmup)) self.assertEqual(throughput, self.metrics_store.get_one("indexing_throughput", sample_type=metrics.SampleType.Normal)) @@ -161,33 +154,26 @@ def test_get_value(self): @mock.patch("solrorbit.utils.console.warn") @mock.patch("psutil.virtual_memory") def test_out_of_memory(self, virt_mem, console_warn): - vmem = namedtuple('vmem', ("available", "total")) + vmem = namedtuple("vmem", ("available", "total")) virt_mem.return_value = vmem(250, 1000) throughput = 5000 - self.metrics_store.open(InMemoryMetricsStoreTests.TEST_RUN_ID, InMemoryMetricsStoreTests.TEST_RUN_TIMESTAMP, - "test", "append-no-conflicts", "defaults", create=True) + self.metrics_store.open(InMemoryMetricsStoreTests.TEST_RUN_ID, InMemoryMetricsStoreTests.TEST_RUN_TIMESTAMP, "test", "append-no-conflicts", "defaults", create=True) self.metrics_store.put_value_cluster_level("indexing_throughput", 1, "docs/s", sample_type=metrics.SampleType.Warmup) self.metrics_store.put_value_cluster_level("indexing_throughput", throughput, "docs/s") self.metrics_store.put_value_cluster_level("final_index_size", 1000, "GB") - console_warn.assert_has_calls([ mock.call( - "Memory threshold exceeded by in-memory metrics store, not adding additional entries", - logger=mock.ANY) ]) + console_warn.assert_has_calls([mock.call("Memory threshold exceeded by in-memory metrics store, not adding additional entries", logger=mock.ANY)]) self.metrics_store.to_externalizable(clear=True) - console_warn.assert_has_calls([ mock.call( - "Memory threshold exceeded by in-memory metrics store, skipping summary generation for current operation", - logger=mock.ANY) ]) + console_warn.assert_has_calls([mock.call("Memory threshold exceeded by in-memory metrics store, skipping summary generation for current operation", logger=mock.ANY)]) self.metrics_store.close() def test_get_percentile(self): - self.metrics_store.open(InMemoryMetricsStoreTests.TEST_RUN_ID, InMemoryMetricsStoreTests.TEST_RUN_TIMESTAMP, - "test", "append-no-conflicts", "defaults", create=True) + self.metrics_store.open(InMemoryMetricsStoreTests.TEST_RUN_ID, InMemoryMetricsStoreTests.TEST_RUN_TIMESTAMP, "test", "append-no-conflicts", "defaults", create=True) for i in range(1, 1001): self.metrics_store.put_value_cluster_level("query_latency", float(i), "ms") self.metrics_store.close() - self.metrics_store.open(InMemoryMetricsStoreTests.TEST_RUN_ID, InMemoryMetricsStoreTests.TEST_RUN_TIMESTAMP, - "test", "append-no-conflicts", "defaults") + self.metrics_store.open(InMemoryMetricsStoreTests.TEST_RUN_ID, InMemoryMetricsStoreTests.TEST_RUN_TIMESTAMP, "test", "append-no-conflicts", "defaults") self.assert_equal_percentiles("query_latency", [100.0], {100.0: 1000.0}) self.assert_equal_percentiles("query_latency", [99.0], {99.0: 990.0}) @@ -197,28 +183,24 @@ def test_get_percentile(self): self.assert_equal_percentiles("query_latency", [99, 99.9, 100], {99: 990.0, 99.9: 999.0, 100: 1000.0}) def test_get_mean(self): - self.metrics_store.open(InMemoryMetricsStoreTests.TEST_RUN_ID, InMemoryMetricsStoreTests.TEST_RUN_TIMESTAMP, - "test", "append-no-conflicts", "defaults", create=True) + self.metrics_store.open(InMemoryMetricsStoreTests.TEST_RUN_ID, InMemoryMetricsStoreTests.TEST_RUN_TIMESTAMP, "test", "append-no-conflicts", "defaults", create=True) for i in range(1, 100): self.metrics_store.put_value_cluster_level("query_latency", float(i), "ms") self.metrics_store.close() - self.metrics_store.open(InMemoryMetricsStoreTests.TEST_RUN_ID, InMemoryMetricsStoreTests.TEST_RUN_TIMESTAMP, - "test", "append-no-conflicts", "defaults") + self.metrics_store.open(InMemoryMetricsStoreTests.TEST_RUN_ID, InMemoryMetricsStoreTests.TEST_RUN_TIMESTAMP, "test", "append-no-conflicts", "defaults") self.assertAlmostEqual(50, self.metrics_store.get_mean("query_latency")) def test_get_median(self): - self.metrics_store.open(InMemoryMetricsStoreTests.TEST_RUN_ID, InMemoryMetricsStoreTests.TEST_RUN_TIMESTAMP, - "test", "append-no-conflicts", "defaults", create=True) + self.metrics_store.open(InMemoryMetricsStoreTests.TEST_RUN_ID, InMemoryMetricsStoreTests.TEST_RUN_TIMESTAMP, "test", "append-no-conflicts", "defaults", create=True) for i in range(1, 1001): self.metrics_store.put_value_cluster_level("query_latency", float(i), "ms") self.metrics_store.close() - self.metrics_store.open(InMemoryMetricsStoreTests.TEST_RUN_ID, InMemoryMetricsStoreTests.TEST_RUN_TIMESTAMP, - "test", "append-no-conflicts", "defaults") + self.metrics_store.open(InMemoryMetricsStoreTests.TEST_RUN_ID, InMemoryMetricsStoreTests.TEST_RUN_TIMESTAMP, "test", "append-no-conflicts", "defaults") self.assertAlmostEqual(500.5, self.metrics_store.get_median("query_latency")) @@ -226,35 +208,10 @@ def assert_equal_percentiles(self, name, percentiles, expected_percentiles): actual_percentiles = self.metrics_store.get_percentiles(name, percentiles=percentiles) self.assertEqual(len(expected_percentiles), len(actual_percentiles)) for percentile, actual_percentile_value in actual_percentiles.items(): - self.assertAlmostEqual(expected_percentiles[percentile], actual_percentile_value, places=1, - msg=str(percentile) + "th percentile differs") + self.assertAlmostEqual(expected_percentiles[percentile], actual_percentile_value, places=1, msg=str(percentile) + "th percentile differs") def test_filter_percentiles_by_sample_size(self): - test_percentiles = [ - 0, - 0.0001, - 0.001, - 0.01, - 0.1, - 4, - 10, - 10.01, - 25, - 45, - 46.001, - 50, - 75, - 80.1, - 90, - 98.9, - 98.91, - 98.999, - 99, - 99.9, - 99.99, - 99.999, - 99.9999, - 100] + test_percentiles = [0, 0.0001, 0.001, 0.01, 0.1, 4, 10, 10.01, 25, 45, 46.001, 50, 75, 80.1, 90, 98.9, 98.91, 98.999, 99, 99.9, 99.99, 99.999, 99.9999, 100] sample_size_to_result_map = { 1: [100], 2: [50, 100], @@ -264,11 +221,9 @@ def test_filter_percentiles_by_sample_size(self): 100: [0, 4, 10, 25, 45, 50, 75, 90, 99, 100], 1000: [0, 0.1, 4, 10, 25, 45, 50, 75, 80.1, 90, 98.9, 99, 99.9, 100], 10000: [0, 0.01, 0.1, 4, 10, 10.01, 25, 45, 50, 75, 80.1, 90, 98.9, 98.91, 99, 99.9, 99.99, 100], - 100000: [0, 0.001, 0.01, 0.1, 4, 10, 10.01, 25, 45, 46.001, 50, 75, - 80.1, 90, 98.9, 98.91, 98.999, 99, 99.9, 99.99, 99.999, 100], - 1000000: [0, 0.0001, 0.001, 0.01, 0.1, 4, 10, 10.01, 25, 45, 46.001, 50, 75, - 80.1, 90, 98.9, 98.91, 98.999, 99, 99.9, 99.99, 99.999, 99.9999, 100] - } # 100,000 corresponds to 0.001% which is the order of magnitude we round to, + 100000: [0, 0.001, 0.01, 0.1, 4, 10, 10.01, 25, 45, 46.001, 50, 75, 80.1, 90, 98.9, 98.91, 98.999, 99, 99.9, 99.99, 99.999, 100], + 1000000: [0, 0.0001, 0.001, 0.01, 0.1, 4, 10, 10.01, 25, 45, 46.001, 50, 75, 80.1, 90, 98.9, 98.91, 98.999, 99, 99.9, 99.99, 99.999, 99.9999, 100], + } # 100,000 corresponds to 0.001% which is the order of magnitude we round to, # so at higher orders (>=1M samples) all values are permitted for sample_size, expected_results in sample_size_to_result_map.items(): filtered = metrics.filter_percentiles_by_sample_size(sample_size, test_percentiles) @@ -277,8 +232,7 @@ def test_filter_percentiles_by_sample_size(self): self.assertEqual(res, exp) def test_externalize_and_bulk_add(self): - self.metrics_store.open(InMemoryMetricsStoreTests.TEST_RUN_ID, InMemoryMetricsStoreTests.TEST_RUN_TIMESTAMP, - "test", "append-no-conflicts", "defaults", create=True) + self.metrics_store.open(InMemoryMetricsStoreTests.TEST_RUN_ID, InMemoryMetricsStoreTests.TEST_RUN_TIMESTAMP, "test", "append-no-conflicts", "defaults", create=True) self.metrics_store.put_value_cluster_level("final_index_size", 1000, "GB") self.assertEqual(1, len(self.metrics_store.docs)) @@ -295,72 +249,48 @@ def test_externalize_and_bulk_add(self): self.assertEqual(1000, self.metrics_store.get_one("final_index_size")) def test_meta_data_per_document(self): - self.metrics_store.open(InMemoryMetricsStoreTests.TEST_RUN_ID, InMemoryMetricsStoreTests.TEST_RUN_TIMESTAMP, - "test", "append-no-conflicts", "defaults", create=True) + self.metrics_store.open(InMemoryMetricsStoreTests.TEST_RUN_ID, InMemoryMetricsStoreTests.TEST_RUN_TIMESTAMP, "test", "append-no-conflicts", "defaults", create=True) self.metrics_store.add_meta_info(metrics.MetaInfoScope.cluster, None, "cluster-name", "test") - self.metrics_store.put_value_cluster_level("final_index_size", 1000, "GB", meta_data={ - "fs-block-size-bytes": 512 - }) - self.metrics_store.put_value_cluster_level("final_bytes_written", 1, "TB", meta_data={ - "io-batch-size-kb": 4 - }) + self.metrics_store.put_value_cluster_level("final_index_size", 1000, "GB", meta_data={"fs-block-size-bytes": 512}) + self.metrics_store.put_value_cluster_level("final_bytes_written", 1, "TB", meta_data={"io-batch-size-kb": 4}) self.assertEqual(2, len(self.metrics_store.docs)) - self.assertEqual({ - "cluster-name": "test", - "fs-block-size-bytes": 512 - }, self.metrics_store.docs[0]["meta"]) + self.assertEqual({"cluster-name": "test", "fs-block-size-bytes": 512}, self.metrics_store.docs[0]["meta"]) - self.assertEqual({ - "cluster-name": "test", - "io-batch-size-kb": 4 - }, self.metrics_store.docs[1]["meta"]) + self.assertEqual({"cluster-name": "test", "io-batch-size-kb": 4}, self.metrics_store.docs[1]["meta"]) def test_get_error_rate_zero_without_samples(self): - self.metrics_store.open(InMemoryMetricsStoreTests.TEST_RUN_ID, InMemoryMetricsStoreTests.TEST_RUN_TIMESTAMP, - "test", "append-no-conflicts", "defaults", create=True) + self.metrics_store.open(InMemoryMetricsStoreTests.TEST_RUN_ID, InMemoryMetricsStoreTests.TEST_RUN_TIMESTAMP, "test", "append-no-conflicts", "defaults", create=True) self.metrics_store.close() - self.metrics_store.open(InMemoryMetricsStoreTests.TEST_RUN_ID, InMemoryMetricsStoreTests.TEST_RUN_TIMESTAMP, - "test", "append-no-conflicts", "defaults") + self.metrics_store.open(InMemoryMetricsStoreTests.TEST_RUN_ID, InMemoryMetricsStoreTests.TEST_RUN_TIMESTAMP, "test", "append-no-conflicts", "defaults") self.assertEqual(0.0, self.metrics_store.get_error_rate("term-query", sample_type=metrics.SampleType.Normal)) def test_get_error_rate_by_sample_type(self): - self.metrics_store.open(InMemoryMetricsStoreTests.TEST_RUN_ID, InMemoryMetricsStoreTests.TEST_RUN_TIMESTAMP, - "test", "append-no-conflicts", "defaults", create=True) - self.metrics_store.put_value_cluster_level("service_time", 3.0, "ms", task="term-query", sample_type=metrics.SampleType.Warmup, - meta_data={"success": False}) - self.metrics_store.put_value_cluster_level("service_time", 3.0, "ms", task="term-query", sample_type=metrics.SampleType.Normal, - meta_data={"success": True}) + self.metrics_store.open(InMemoryMetricsStoreTests.TEST_RUN_ID, InMemoryMetricsStoreTests.TEST_RUN_TIMESTAMP, "test", "append-no-conflicts", "defaults", create=True) + self.metrics_store.put_value_cluster_level("service_time", 3.0, "ms", task="term-query", sample_type=metrics.SampleType.Warmup, meta_data={"success": False}) + self.metrics_store.put_value_cluster_level("service_time", 3.0, "ms", task="term-query", sample_type=metrics.SampleType.Normal, meta_data={"success": True}) self.metrics_store.close() - self.metrics_store.open(InMemoryMetricsStoreTests.TEST_RUN_ID, InMemoryMetricsStoreTests.TEST_RUN_TIMESTAMP, - "test", "append-no-conflicts", "defaults") + self.metrics_store.open(InMemoryMetricsStoreTests.TEST_RUN_ID, InMemoryMetricsStoreTests.TEST_RUN_TIMESTAMP, "test", "append-no-conflicts", "defaults") self.assertEqual(1.0, self.metrics_store.get_error_rate("term-query", sample_type=metrics.SampleType.Warmup)) self.assertEqual(0.0, self.metrics_store.get_error_rate("term-query", sample_type=metrics.SampleType.Normal)) def test_get_error_rate_mixed(self): - self.metrics_store.open(InMemoryMetricsStoreTests.TEST_RUN_ID, InMemoryMetricsStoreTests.TEST_RUN_TIMESTAMP, - "test", "append-no-conflicts", "defaults", create=True) - self.metrics_store.put_value_cluster_level("service_time", 3.0, "ms", task="term-query", sample_type=metrics.SampleType.Normal, - meta_data={"success": True}) - self.metrics_store.put_value_cluster_level("service_time", 3.0, "ms", task="term-query", sample_type=metrics.SampleType.Normal, - meta_data={"success": True}) - self.metrics_store.put_value_cluster_level("service_time", 3.0, "ms", task="term-query", sample_type=metrics.SampleType.Normal, - meta_data={"success": False}) - self.metrics_store.put_value_cluster_level("service_time", 3.0, "ms", task="term-query", sample_type=metrics.SampleType.Normal, - meta_data={"success": True}) - self.metrics_store.put_value_cluster_level("service_time", 3.0, "ms", task="term-query", sample_type=metrics.SampleType.Normal, - meta_data={"success": True}) + self.metrics_store.open(InMemoryMetricsStoreTests.TEST_RUN_ID, InMemoryMetricsStoreTests.TEST_RUN_TIMESTAMP, "test", "append-no-conflicts", "defaults", create=True) + self.metrics_store.put_value_cluster_level("service_time", 3.0, "ms", task="term-query", sample_type=metrics.SampleType.Normal, meta_data={"success": True}) + self.metrics_store.put_value_cluster_level("service_time", 3.0, "ms", task="term-query", sample_type=metrics.SampleType.Normal, meta_data={"success": True}) + self.metrics_store.put_value_cluster_level("service_time", 3.0, "ms", task="term-query", sample_type=metrics.SampleType.Normal, meta_data={"success": False}) + self.metrics_store.put_value_cluster_level("service_time", 3.0, "ms", task="term-query", sample_type=metrics.SampleType.Normal, meta_data={"success": True}) + self.metrics_store.put_value_cluster_level("service_time", 3.0, "ms", task="term-query", sample_type=metrics.SampleType.Normal, meta_data={"success": True}) self.metrics_store.close() - self.metrics_store.open(InMemoryMetricsStoreTests.TEST_RUN_ID, InMemoryMetricsStoreTests.TEST_RUN_TIMESTAMP, - "test", "append-no-conflicts", "defaults") + self.metrics_store.open(InMemoryMetricsStoreTests.TEST_RUN_ID, InMemoryMetricsStoreTests.TEST_RUN_TIMESTAMP, "test", "append-no-conflicts", "defaults") self.assertEqual(0.0, self.metrics_store.get_error_rate("term-query", sample_type=metrics.SampleType.Warmup)) self.assertEqual(0.2, self.metrics_store.get_error_rate("term-query", sample_type=metrics.SampleType.Normal)) @@ -383,9 +313,7 @@ def setUp(self): self.cfg.add(config.Scope.application, "system", "env.name", "unittest-env") self.cfg.add(config.Scope.application, "system", "list.test_runs.max_results", 100) self.cfg.add(config.Scope.application, "system", "time.start", FileTestRunStoreTests.TEST_RUN_TIMESTAMP) - self.cfg.add( - config.Scope.application, "system", "test_run.id", - FileTestRunStoreTests.TEST_RUN_ID) + self.cfg.add(config.Scope.application, "system", "test_run.id", FileTestRunStoreTests.TEST_RUN_ID) self.test_run_store = metrics.FileTestRunStore(self.cfg) def test_test_run_not_found(self): @@ -394,50 +322,43 @@ def test_test_run_not_found(self): self.test_run_store.find_by_test_run_id(FileTestRunStoreTests.TEST_RUN_ID) def test_store_test_run(self): - schedule = [ - workload.Task("index #1", workload.Operation("index", workload.OperationType.Bulk)) - ] + schedule = [workload.Task("index #1", workload.Operation("index", workload.OperationType.Bulk))] - t = workload.Workload(name="unittest", - collections=[workload.Collection(name="tests")], - test_procedures=[workload.TestProcedure(name="index", default=True, schedule=schedule)]) + t = workload.Workload( + name="unittest", collections=[workload.Collection(name="tests")], test_procedures=[workload.TestProcedure(name="index", default=True, schedule=schedule)] + ) test_run = metrics.TestRun( - benchmark_version="0.4.4", benchmark_revision="123abc", environment_name="unittest", - test_run_id=FileTestRunStoreTests.TEST_RUN_ID, - test_run_timestamp=FileTestRunStoreTests.TEST_RUN_TIMESTAMP, - pipeline="from-sources", user_tags={"os": "Linux"}, workload=t, workload_params={"clients": 12}, - test_procedure=t.default_test_procedure, - cluster_config="4gheap", - cluster_config_params=None, - plugin_params=None, - workload_revision="abc1", - cluster_config_revision="abc12333", - distribution_version="5.0.0", - distribution_flavor="default", revision="aaaeeef", - results=FileTestRunStoreTests.DictHolder( - { - "young_gc_time": 100, - "old_gc_time": 5, - "op_metrics": [ - { - "task": "index #1", - "operation": "index", - "throughput": { - "min": 1000, - "median": 1250, - "max": 1500, - "unit": "docs/s" - } - } - ] - }) - ) + benchmark_version="0.4.4", + benchmark_revision="123abc", + environment_name="unittest", + test_run_id=FileTestRunStoreTests.TEST_RUN_ID, + test_run_timestamp=FileTestRunStoreTests.TEST_RUN_TIMESTAMP, + pipeline="from-sources", + user_tags={"os": "Linux"}, + workload=t, + workload_params={"clients": 12}, + test_procedure=t.default_test_procedure, + cluster_config="4gheap", + cluster_config_params=None, + plugin_params=None, + workload_revision="abc1", + cluster_config_revision="abc12333", + distribution_version="5.0.0", + distribution_flavor="default", + revision="aaaeeef", + results=FileTestRunStoreTests.DictHolder( + { + "young_gc_time": 100, + "old_gc_time": 5, + "op_metrics": [{"task": "index #1", "operation": "index", "throughput": {"min": 1000, "median": 1250, "max": 1500, "unit": "docs/s"}}], + } + ), + ) self.test_run_store.store_test_run(test_run) - retrieved_test_run = self.test_run_store.find_by_test_run_id( - test_run_id=FileTestRunStoreTests.TEST_RUN_ID) + retrieved_test_run = self.test_run_store.find_by_test_run_id(test_run_id=FileTestRunStoreTests.TEST_RUN_ID) self.assertEqual(test_run.test_run_id, retrieved_test_run.test_run_id) self.assertEqual(test_run.test_run_timestamp, retrieved_test_run.test_run_timestamp) self.assertEqual(1, len(self.test_run_store.list())) @@ -457,14 +378,8 @@ def test_calculate_global_stats(self): cfg.add(config.Scope.application, "test_run", "pipeline", "from-sources") cfg.add(config.Scope.application, "workload", "params", {}) - index1 = workload.Task(name="index #1", operation=workload.Operation( - name="index", - operation_type=workload.OperationType.Bulk, - params=None)) - index2 = workload.Task(name="index #2", operation=workload.Operation( - name="index", - operation_type=workload.OperationType.Bulk, - params=None)) + index1 = workload.Task(name="index #1", operation=workload.Operation(name="index", operation_type=workload.OperationType.Bulk, params=None)) + index2 = workload.Task(name="index #2", operation=workload.Operation(name="index", operation_type=workload.OperationType.Bulk, params=None)) test_procedure = workload.TestProcedure(name="unittest", schedule=[index1, index2], default=True) t = workload.Workload("unittest", "unittest-workload", test_procedures=[test_procedure]) @@ -475,56 +390,50 @@ def test_calculate_global_stats(self): store.put_value_cluster_level("throughput", 1000, unit="docs/s", task="index #1", operation_type=workload.OperationType.Bulk) store.put_value_cluster_level("throughput", 2000, unit="docs/s", task="index #1", operation_type=workload.OperationType.Bulk) - store.put_value_cluster_level("latency", 2800, unit="ms", task="index #1", operation_type=workload.OperationType.Bulk, - sample_type=metrics.SampleType.Warmup) + store.put_value_cluster_level("latency", 2800, unit="ms", task="index #1", operation_type=workload.OperationType.Bulk, sample_type=metrics.SampleType.Warmup) store.put_value_cluster_level("latency", 200, unit="ms", task="index #1", operation_type=workload.OperationType.Bulk) store.put_value_cluster_level("latency", 220, unit="ms", task="index #1", operation_type=workload.OperationType.Bulk) store.put_value_cluster_level("latency", 225, unit="ms", task="index #1", operation_type=workload.OperationType.Bulk) - store.put_value_cluster_level("service_time", 250, unit="ms", task="index #1", operation_type=workload.OperationType.Bulk, - sample_type=metrics.SampleType.Warmup, meta_data={"success": False}, relative_time=536) - store.put_value_cluster_level("service_time", 190, unit="ms", task="index #1", operation_type=workload.OperationType.Bulk, - meta_data={"success": True}, relative_time=595) - store.put_value_cluster_level("service_time", 200, unit="ms", task="index #1", operation_type=workload.OperationType.Bulk, - meta_data={"success": False}, relative_time=709) - store.put_value_cluster_level("service_time", 210, unit="ms", task="index #1", operation_type=workload.OperationType.Bulk, - meta_data={"success": True}, relative_time=653) + store.put_value_cluster_level( + "service_time", + 250, + unit="ms", + task="index #1", + operation_type=workload.OperationType.Bulk, + sample_type=metrics.SampleType.Warmup, + meta_data={"success": False}, + relative_time=536, + ) + store.put_value_cluster_level("service_time", 190, unit="ms", task="index #1", operation_type=workload.OperationType.Bulk, meta_data={"success": True}, relative_time=595) + store.put_value_cluster_level("service_time", 200, unit="ms", task="index #1", operation_type=workload.OperationType.Bulk, meta_data={"success": False}, relative_time=709) + store.put_value_cluster_level("service_time", 210, unit="ms", task="index #1", operation_type=workload.OperationType.Bulk, meta_data={"success": True}, relative_time=653) # only warmup samples - store.put_value_cluster_level("throughput", 500, unit="docs/s", task="index #2", - sample_type=metrics.SampleType.Warmup, operation_type=workload.OperationType.Bulk) - store.put_value_cluster_level("latency", 2800, unit="ms", task="index #2", operation_type=workload.OperationType.Bulk, - sample_type=metrics.SampleType.Warmup) - store.put_value_cluster_level("service_time", 250, unit="ms", task="index #2", operation_type=workload.OperationType.Bulk, - sample_type=metrics.SampleType.Warmup, relative_time=600) - - store.put_doc(doc={ - "name": "ml_processing_time", - "job": "benchmark_ml_job_1", - "min": 2.2, - "mean": 12.3, - "median": 17.2, - "max": 36.0, - "unit": "ms" - }, level=metrics.MetaInfoScope.cluster) + store.put_value_cluster_level("throughput", 500, unit="docs/s", task="index #2", sample_type=metrics.SampleType.Warmup, operation_type=workload.OperationType.Bulk) + store.put_value_cluster_level("latency", 2800, unit="ms", task="index #2", operation_type=workload.OperationType.Bulk, sample_type=metrics.SampleType.Warmup) + store.put_value_cluster_level( + "service_time", 250, unit="ms", task="index #2", operation_type=workload.OperationType.Bulk, sample_type=metrics.SampleType.Warmup, relative_time=600 + ) + + store.put_doc( + doc={"name": "ml_processing_time", "job": "benchmark_ml_job_1", "min": 2.2, "mean": 12.3, "median": 17.2, "max": 36.0, "unit": "ms"}, + level=metrics.MetaInfoScope.cluster, + ) stats = metrics.calculate_results(store, metrics.create_test_run(cfg, t, test_procedure)) del store opm = stats.metrics("index #1") - self.assertEqual(collections.OrderedDict( - [("min", 500), ("mean", 1125), ("median", 1000), ("max", 2000), ("unit", "docs/s")]), opm["throughput"]) - self.assertEqual(collections.OrderedDict( - [("50_0", 220), ("100_0", 225), ("mean", 215), ("unit", "ms")]), opm["latency"]) - self.assertEqual(collections.OrderedDict( - [("50_0", 200), ("100_0", 210), ("mean", 200), ("unit", "ms")]), opm["service_time"]) + self.assertEqual(collections.OrderedDict([("min", 500), ("mean", 1125), ("median", 1000), ("max", 2000), ("unit", "docs/s")]), opm["throughput"]) + self.assertEqual(collections.OrderedDict([("50_0", 220), ("100_0", 225), ("mean", 215), ("unit", "ms")]), opm["latency"]) + self.assertEqual(collections.OrderedDict([("50_0", 200), ("100_0", 210), ("mean", 200), ("unit", "ms")]), opm["service_time"]) self.assertAlmostEqual(0.3333333333333333, opm["error_rate"]) - self.assertEqual(709*1000, opm["duration"]) + self.assertEqual(709 * 1000, opm["duration"]) opm2 = stats.metrics("index #2") - self.assertEqual(collections.OrderedDict( - [("min", None), ("mean", None), ("median", None), ("max", None), ("unit", "docs/s")]), opm2["throughput"]) + self.assertEqual(collections.OrderedDict([("min", None), ("mean", None), ("median", None), ("max", None), ("unit", "docs/s")]), opm2["throughput"]) self.assertEqual(1, len(stats.ml_processing_time)) self.assertEqual("benchmark_ml_job_1", stats.ml_processing_time[0]["job"]) @@ -533,7 +442,7 @@ def test_calculate_global_stats(self): self.assertEqual(17.2, stats.ml_processing_time[0]["median"]) self.assertEqual(36.0, stats.ml_processing_time[0]["max"]) self.assertEqual("ms", stats.ml_processing_time[0]["unit"]) - self.assertEqual(600*1000, opm2["duration"]) + self.assertEqual(600 * 1000, opm2["duration"]) def test_calculate_system_stats(self): cfg = config.Config() @@ -548,10 +457,7 @@ def test_calculate_system_stats(self): cfg.add(config.Scope.application, "test_run", "pipeline", "from-sources") cfg.add(config.Scope.application, "workload", "params", {}) - index = workload.Task(name="index #1", operation=workload.Operation( - name="index", - operation_type=workload.OperationType.Bulk, - params=None)) + index = workload.Task(name="index #1", operation=workload.Operation(name="index", operation_type=workload.OperationType.Bulk, params=None)) test_procedure = workload.TestProcedure(name="unittest", schedule=[index], default=True) t = workload.Workload("unittest", "unittest-workload", test_procedures=[test_procedure]) @@ -566,14 +472,7 @@ def test_calculate_system_stats(self): del store - self.assertEqual([ - { - "node": "benchmark-node-0", - "name": "index_size", - "value": 2048, - "unit": "bytes" - } - ], stats.node_metrics) + self.assertEqual([{"node": "benchmark-node-0", "name": "index_size", "value": 2048, "unit": "bytes"}], stats.node_metrics) def select(l, name, operation=None, job=None, node=None): @@ -598,27 +497,34 @@ def tearDown(self): del self.cfg def test_add_administrative_task_with_error_rate_in_results(self): - op = Operation(name='delete-index', operation_type='DeleteIndex', params={'include-in-reporting': False}) - task = Task('delete-index', operation=op, schedule='deterministic') - test_procedure = TestProcedure(name='append-fast-with-conflicts', schedule=[task], meta_data={}) - - self.metrics_store.open(InMemoryMetricsStoreTests.TEST_RUN_ID, InMemoryMetricsStoreTests.TEST_RUN_TIMESTAMP, - "test", "append-fast-with-conflicts", "defaults", create=True) - self.metrics_store.put_doc(doc={"@timestamp": 1595896761994, - "relative-time-ms": 283.382, - "test-run-id": "fb26018b-428d-4528-b36b-cf8c54a303ec", - "test-run-timestamp": "20200728T003905Z", "environment": "local", - "workload": "geonames", "test_procedure": "append-fast-with-conflicts", - "cluster-config-instance": "defaults", "name": "service_time", "value": 72.67997100007051, - "unit": "ms", "sample-type": "normal", - "meta": {"source_revision": "7f634e9f44834fbc12724506cc1da681b0c3b1e3", - "distribution_version": "7.6.0", "distribution_flavor": "oss", - "success": False}, "task": "delete-index", "operation": "delete-index", - "operation-type": "DeleteIndex"}) - - result = GlobalStatsCalculator(store=self.metrics_store, workload=Workload(name='geonames', meta_data={}), - test_procedure=test_procedure)() - assert "delete-index" in [op_metric.get('task') for op_metric in result.op_metrics] + op = Operation(name="delete-index", operation_type="DeleteIndex", params={"include-in-reporting": False}) + task = Task("delete-index", operation=op, schedule="deterministic") + test_procedure = TestProcedure(name="append-fast-with-conflicts", schedule=[task], meta_data={}) + + self.metrics_store.open(InMemoryMetricsStoreTests.TEST_RUN_ID, InMemoryMetricsStoreTests.TEST_RUN_TIMESTAMP, "test", "append-fast-with-conflicts", "defaults", create=True) + self.metrics_store.put_doc( + doc={ + "@timestamp": 1595896761994, + "relative-time-ms": 283.382, + "test-run-id": "fb26018b-428d-4528-b36b-cf8c54a303ec", + "test-run-timestamp": "20200728T003905Z", + "environment": "local", + "workload": "geonames", + "test_procedure": "append-fast-with-conflicts", + "cluster-config-instance": "defaults", + "name": "service_time", + "value": 72.67997100007051, + "unit": "ms", + "sample-type": "normal", + "meta": {"source_revision": "7f634e9f44834fbc12724506cc1da681b0c3b1e3", "distribution_version": "7.6.0", "distribution_flavor": "oss", "success": False}, + "task": "delete-index", + "operation": "delete-index", + "operation-type": "DeleteIndex", + } + ) + + result = GlobalStatsCalculator(store=self.metrics_store, workload=Workload(name="geonames", meta_data={}), test_procedure=test_procedure)() + assert "delete-index" in [op_metric.get("task") for op_metric in result.op_metrics] class GlobalStatsTests(TestCase): @@ -628,381 +534,146 @@ def test_as_flat_list(self): { "task": "index #1", "operation": "index", - "throughput": { - "min": 450, - "mean": 450, - "median": 450, - "max": 452, - "unit": "docs/s" - }, + "throughput": {"min": 450, "mean": 450, "median": 450, "max": 452, "unit": "docs/s"}, "latency": { "50": 340, "100": 376, }, - "service_time": { - "50": 341, - "100": 376 - }, + "service_time": {"50": 341, "100": 376}, "error_rate": 0.0, - "meta": { - "clients": 8, - "phase": "idx" - } + "meta": {"clients": 8, "phase": "idx"}, }, { "task": "search #2", "operation": "search", - "throughput": { - "min": 9, - "mean": 10, - "median": 10, - "max": 12, - "unit": "ops/s" - }, + "throughput": {"min": 9, "mean": 10, "median": 10, "max": 12, "unit": "ops/s"}, "latency": { "50": 99, "100": 111, }, - "service_time": { - "50": 98, - "100": 110 - }, - "error_rate": 0.1 - } + "service_time": {"50": 98, "100": 110}, + "error_rate": 0.1, + }, ], "ml_processing_time": [ - { - "job": "job_1", - "min": 3.3, - "mean": 5.2, - "median": 5.8, - "max": 12.34 - }, - { - "job": "job_2", - "min": 3.55, - "mean": 4.2, - "median": 4.9, - "max": 9.4 - }, + {"job": "job_1", "min": 3.3, "mean": 5.2, "median": 5.8, "max": 12.34}, + {"job": "job_2", "min": 3.55, "mean": 4.2, "median": 4.9, "max": 9.4}, ], "young_gc_time": 68, "young_gc_count": 7, "old_gc_time": 0, "old_gc_count": 0, "merge_time": 3702, - "merge_time_per_shard": { - "min": 40, - "median": 3702, - "max": 3900, - "unit": "ms" - }, + "merge_time_per_shard": {"min": 40, "median": 3702, "max": 3900, "unit": "ms"}, "merge_count": 2, "refresh_time": 596, - "refresh_time_per_shard": { - "min": 48, - "median": 89, - "max": 204, - "unit": "ms" - }, + "refresh_time_per_shard": {"min": 48, "median": 89, "max": 204, "unit": "ms"}, "refresh_count": 10, "flush_time": None, "flush_time_per_shard": {}, - "flush_count": 0 + "flush_count": 0, } s = metrics.GlobalStats(d) metric_list = s.as_flat_list() - self.assertEqual({ - "name": "throughput", - "task": "index #1", - "operation": "index", - "value": { - "min": 450, - "mean": 450, - "median": 450, - "max": 452, - "unit": "docs/s" - }, - "meta": { - "clients": 8, - "phase": "idx" - } - }, select(metric_list, "throughput", operation="index")) - - self.assertEqual({ - "name": "service_time", - "task": "index #1", - "operation": "index", - "value": { - "50": 341, - "100": 376 - }, - "meta": { - "clients": 8, - "phase": "idx" - } - }, select(metric_list, "service_time", operation="index")) - - self.assertEqual({ - "name": "latency", - "task": "index #1", - "operation": "index", - "value": { - "50": 340, - "100": 376 - }, - "meta": { - "clients": 8, - "phase": "idx" - } - }, select(metric_list, "latency", operation="index")) - - self.assertEqual({ - "name": "error_rate", - "task": "index #1", - "operation": "index", - "value": { - "single": 0.0 + self.assertEqual( + { + "name": "throughput", + "task": "index #1", + "operation": "index", + "value": {"min": 450, "mean": 450, "median": 450, "max": 452, "unit": "docs/s"}, + "meta": {"clients": 8, "phase": "idx"}, }, - "meta": { - "clients": 8, - "phase": "idx" - } - }, select(metric_list, "error_rate", operation="index")) - - self.assertEqual({ - "name": "throughput", - "task": "search #2", - "operation": "search", - "value": { - "min": 9, - "mean": 10, - "median": 10, - "max": 12, - "unit": "ops/s" - } - }, select(metric_list, "throughput", operation="search")) - - self.assertEqual({ - "name": "service_time", - "task": "search #2", - "operation": "search", - "value": { - "50": 98, - "100": 110 - } - }, select(metric_list, "service_time", operation="search")) - - self.assertEqual({ - "name": "latency", - "task": "search #2", - "operation": "search", - "value": { - "50": 99, - "100": 111 - } - }, select(metric_list, "latency", operation="search")) - - self.assertEqual({ - "name": "error_rate", - "task": "search #2", - "operation": "search", - "value": { - "single": 0.1 - } - }, select(metric_list, "error_rate", operation="search")) - - self.assertEqual({ - "name": "ml_processing_time", - "job": "job_1", - "value": { - "min": 3.3, - "mean": 5.2, - "median": 5.8, - "max": 12.34 - } - }, select(metric_list, "ml_processing_time", job="job_1")) - - self.assertEqual({ - "name": "ml_processing_time", - "job": "job_2", - "value": { - "min": 3.55, - "mean": 4.2, - "median": 4.9, - "max": 9.4 - } - }, select(metric_list, "ml_processing_time", job="job_2")) + select(metric_list, "throughput", operation="index"), + ) - self.assertEqual({ - "name": "young_gc_time", - "value": { - "single": 68 - } - }, select(metric_list, "young_gc_time")) - self.assertEqual({ - "name": "young_gc_count", - "value": { - "single": 7 - } - }, select(metric_list, "young_gc_count")) + self.assertEqual( + {"name": "service_time", "task": "index #1", "operation": "index", "value": {"50": 341, "100": 376}, "meta": {"clients": 8, "phase": "idx"}}, + select(metric_list, "service_time", operation="index"), + ) - self.assertEqual({ - "name": "old_gc_time", - "value": { - "single": 0 - } - }, select(metric_list, "old_gc_time")) - self.assertEqual({ - "name": "old_gc_count", - "value": { - "single": 0 - } - }, select(metric_list, "old_gc_count")) + self.assertEqual( + {"name": "latency", "task": "index #1", "operation": "index", "value": {"50": 340, "100": 376}, "meta": {"clients": 8, "phase": "idx"}}, + select(metric_list, "latency", operation="index"), + ) - self.assertEqual({ - "name": "merge_time", - "value": { - "single": 3702 - } - }, select(metric_list, "merge_time")) - - self.assertEqual({ - "name": "merge_time_per_shard", - "value": { - "min": 40, - "median": 3702, - "max": 3900, - "unit": "ms" - } - }, select(metric_list, "merge_time_per_shard")) + self.assertEqual( + {"name": "error_rate", "task": "index #1", "operation": "index", "value": {"single": 0.0}, "meta": {"clients": 8, "phase": "idx"}}, + select(metric_list, "error_rate", operation="index"), + ) - self.assertEqual({ - "name": "merge_count", - "value": { - "single": 2 - } - }, select(metric_list, "merge_count")) + self.assertEqual( + {"name": "throughput", "task": "search #2", "operation": "search", "value": {"min": 9, "mean": 10, "median": 10, "max": 12, "unit": "ops/s"}}, + select(metric_list, "throughput", operation="search"), + ) - self.assertEqual({ - "name": "refresh_time", - "value": { - "single": 596 - } - }, select(metric_list, "refresh_time")) - - self.assertEqual({ - "name": "refresh_time_per_shard", - "value": { - "min": 48, - "median": 89, - "max": 204, - "unit": "ms" - } - }, select(metric_list, "refresh_time_per_shard")) + self.assertEqual( + {"name": "service_time", "task": "search #2", "operation": "search", "value": {"50": 98, "100": 110}}, select(metric_list, "service_time", operation="search") + ) - self.assertEqual({ - "name": "refresh_count", - "value": { - "single": 10 - } - }, select(metric_list, "refresh_count")) + self.assertEqual({"name": "latency", "task": "search #2", "operation": "search", "value": {"50": 99, "100": 111}}, select(metric_list, "latency", operation="search")) + + self.assertEqual({"name": "error_rate", "task": "search #2", "operation": "search", "value": {"single": 0.1}}, select(metric_list, "error_rate", operation="search")) + + self.assertEqual( + {"name": "ml_processing_time", "job": "job_1", "value": {"min": 3.3, "mean": 5.2, "median": 5.8, "max": 12.34}}, select(metric_list, "ml_processing_time", job="job_1") + ) + + self.assertEqual( + {"name": "ml_processing_time", "job": "job_2", "value": {"min": 3.55, "mean": 4.2, "median": 4.9, "max": 9.4}}, select(metric_list, "ml_processing_time", job="job_2") + ) + + self.assertEqual({"name": "young_gc_time", "value": {"single": 68}}, select(metric_list, "young_gc_time")) + self.assertEqual({"name": "young_gc_count", "value": {"single": 7}}, select(metric_list, "young_gc_count")) + + self.assertEqual({"name": "old_gc_time", "value": {"single": 0}}, select(metric_list, "old_gc_time")) + self.assertEqual({"name": "old_gc_count", "value": {"single": 0}}, select(metric_list, "old_gc_count")) + + self.assertEqual({"name": "merge_time", "value": {"single": 3702}}, select(metric_list, "merge_time")) + + self.assertEqual({"name": "merge_time_per_shard", "value": {"min": 40, "median": 3702, "max": 3900, "unit": "ms"}}, select(metric_list, "merge_time_per_shard")) + + self.assertEqual({"name": "merge_count", "value": {"single": 2}}, select(metric_list, "merge_count")) + + self.assertEqual({"name": "refresh_time", "value": {"single": 596}}, select(metric_list, "refresh_time")) + + self.assertEqual({"name": "refresh_time_per_shard", "value": {"min": 48, "median": 89, "max": 204, "unit": "ms"}}, select(metric_list, "refresh_time_per_shard")) + + self.assertEqual({"name": "refresh_count", "value": {"single": 10}}, select(metric_list, "refresh_count")) self.assertIsNone(select(metric_list, "flush_time")) self.assertIsNone(select(metric_list, "flush_time_per_shard")) - self.assertEqual({ - "name": "flush_count", - "value": { - "single": 0 - } - }, select(metric_list, "flush_count")) + self.assertEqual({"name": "flush_count", "value": {"single": 0}}, select(metric_list, "flush_count")) class SystemStatsTests(TestCase): def test_as_flat_list(self): d = { "node_metrics": [ - { - "node": "benchmark-node-0", - "name": "startup_time", - "value": 3.4 - }, - { - "node": "benchmark-node-1", - "name": "startup_time", - "value": 4.2 - }, - { - "node": "benchmark-node-0", - "name": "index_size", - "value": 300 * 1024 * 1024 - }, - { - "node": "benchmark-node-1", - "name": "index_size", - "value": 302 * 1024 * 1024 - }, - { - "node": "benchmark-node-0", - "name": "bytes_written", - "value": 817 * 1024 * 1024 - }, - { - "node": "benchmark-node-1", - "name": "bytes_written", - "value": 833 * 1024 * 1024 - }, + {"node": "benchmark-node-0", "name": "startup_time", "value": 3.4}, + {"node": "benchmark-node-1", "name": "startup_time", "value": 4.2}, + {"node": "benchmark-node-0", "name": "index_size", "value": 300 * 1024 * 1024}, + {"node": "benchmark-node-1", "name": "index_size", "value": 302 * 1024 * 1024}, + {"node": "benchmark-node-0", "name": "bytes_written", "value": 817 * 1024 * 1024}, + {"node": "benchmark-node-1", "name": "bytes_written", "value": 833 * 1024 * 1024}, ], } s = metrics.SystemStats(d) metric_list = s.as_flat_list() - self.assertEqual({ - "node": "benchmark-node-0", - "name": "startup_time", - "value": { - "single": 3.4 - } - }, select(metric_list, "startup_time", node="benchmark-node-0")) + self.assertEqual({"node": "benchmark-node-0", "name": "startup_time", "value": {"single": 3.4}}, select(metric_list, "startup_time", node="benchmark-node-0")) - self.assertEqual({ - "node": "benchmark-node-1", - "name": "startup_time", - "value": { - "single": 4.2 - } - }, select(metric_list, "startup_time", node="benchmark-node-1")) + self.assertEqual({"node": "benchmark-node-1", "name": "startup_time", "value": {"single": 4.2}}, select(metric_list, "startup_time", node="benchmark-node-1")) - self.assertEqual({ - "node": "benchmark-node-0", - "name": "index_size", - "value": { - "single": 300 * 1024 * 1024 - } - }, select(metric_list, "index_size", node="benchmark-node-0")) + self.assertEqual({"node": "benchmark-node-0", "name": "index_size", "value": {"single": 300 * 1024 * 1024}}, select(metric_list, "index_size", node="benchmark-node-0")) - self.assertEqual({ - "node": "benchmark-node-1", - "name": "index_size", - "value": { - "single": 302 * 1024 * 1024 - } - }, select(metric_list, "index_size", node="benchmark-node-1")) + self.assertEqual({"node": "benchmark-node-1", "name": "index_size", "value": {"single": 302 * 1024 * 1024}}, select(metric_list, "index_size", node="benchmark-node-1")) - self.assertEqual({ - "node": "benchmark-node-0", - "name": "bytes_written", - "value": { - "single": 817 * 1024 * 1024 - } - }, select(metric_list, "bytes_written", node="benchmark-node-0")) + self.assertEqual( + {"node": "benchmark-node-0", "name": "bytes_written", "value": {"single": 817 * 1024 * 1024}}, select(metric_list, "bytes_written", node="benchmark-node-0") + ) - self.assertEqual({ - "node": "benchmark-node-1", - "name": "bytes_written", - "value": { - "single": 833 * 1024 * 1024 - } - }, select(metric_list, "bytes_written", node="benchmark-node-1")) + self.assertEqual( + {"node": "benchmark-node-1", "name": "bytes_written", "value": {"single": 833 * 1024 * 1024}}, select(metric_list, "bytes_written", node="benchmark-node-1") + ) diff --git a/tests/publisher_test.py b/tests/publisher_test.py index 052d2138..d71cd2c8 100644 --- a/tests/publisher_test.py +++ b/tests/publisher_test.py @@ -16,7 +16,7 @@ # not use this file except in compliance with the License. # You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an @@ -30,6 +30,7 @@ from solrorbit import publisher + # pylint: disable=protected-access class FormatterTests(TestCase): def setUp(self): @@ -40,7 +41,7 @@ def setUp(self): self.metrics_data = [ ["Min Throughput", "index", "17300", "18000", "700", "ops/s"], ["Median Throughput", "index", "17500", "18500", "1000", "ops/s"], - ["Max Throughput", "index", "17700", "19000", "1300", "ops/s"] + ["Max Throughput", "index", "17700", "19000", "1300", "ops/s"], ] self.numbers_align = "right" @@ -62,7 +63,7 @@ def test_formats_as_csv(self): # 1 header line, no separation line + 3 data lines self.assertEqual(1 + 3, len(formatted.splitlines())) - @patch('solrorbit.publisher.convert.to_bool') + @patch("solrorbit.publisher.convert.to_bool") def test_publish_throughput_handles_different_metrics(self, mock_to_bool): config = Mock() @@ -81,29 +82,11 @@ def config_opts_side_effect(*args, **kwargs): # Mock for regular test run regular_stats = Mock() - regular_stats.metrics.return_value = { - "throughput": { - "min": 100, - "max": 200, - "mean": 150, - "median": 160, - "unit": "ops/s" - } - } + regular_stats.metrics.return_value = {"throughput": {"min": 100, "max": 200, "mean": 150, "median": 160, "unit": "ops/s"}} # Mock for aggregated test run aggregated_stats = Mock() - aggregated_stats.metrics.return_value = { - "throughput": { - "overall_min": 95, - "overall_max": 205, - "min": 100, - "max": 200, - "mean": 150, - "median": 160, - "unit": "ops/s" - } - } + aggregated_stats.metrics.return_value = {"throughput": {"overall_min": 95, "overall_max": 205, "min": 100, "max": 200, "mean": 150, "median": 160, "unit": "ops/s"}} # Test with regular stats result_regular = comparison_publisher._publish_throughput(regular_stats, regular_stats, "test_task") diff --git a/tests/scripts_test.py b/tests/scripts_test.py index 2ed73290..fb54562e 100644 --- a/tests/scripts_test.py +++ b/tests/scripts_test.py @@ -16,7 +16,7 @@ # not use this file except in compliance with the License. # You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an @@ -30,12 +30,11 @@ import subprocess from unittest import TestCase -class ScriptsTests(TestCase): +class ScriptsTests(TestCase): def test_scr(self): os.environ["BENCHMARK_HOME"] = "/tmp" script = pathlib.Path(__file__).parent.parent / "scripts" / "expand-data-corpus.py" - p = subprocess.Popen([str(script), "-c", "10"], - stdout=subprocess.PIPE, stderr=subprocess.PIPE) - stderr = p.communicate()[1].decode('UTF-8') + p = subprocess.Popen([str(script), "-c", "10"], stdout=subprocess.PIPE, stderr=subprocess.PIPE) + stderr = p.communicate()[1].decode("UTF-8") self.assertTrue("could not find benchmark config file" in stderr) diff --git a/tests/synthetic_data_generator/incorrect_sample_custom_module.py b/tests/synthetic_data_generator/incorrect_sample_custom_module.py index da636505..4b65fccb 100644 --- a/tests/synthetic_data_generator/incorrect_sample_custom_module.py +++ b/tests/synthetic_data_generator/incorrect_sample_custom_module.py @@ -12,27 +12,26 @@ from mimesis.providers.base import BaseProvider GEOGRAPHIC_CLUSTERS = { - 'Manhattan': { - 'center': {'lat': 40.7831, 'lon': -73.9712}, - 'radius': 0.05 # degrees + "Manhattan": { + "center": {"lat": 40.7831, "lon": -73.9712}, + "radius": 0.05, # degrees }, - 'Brooklyn': { - 'center': {'lat': 40.6782, 'lon': -73.9442}, - 'radius': 0.05 + "Brooklyn": {"center": {"lat": 40.6782, "lon": -73.9442}, "radius": 0.05}, + "Austin": { + "center": {"lat": 30.2672, "lon": -97.7431}, + "radius": 0.1, # Increased radius to cover more of Austin }, - 'Austin': { - 'center': {'lat': 30.2672, 'lon': -97.7431}, - 'radius': 0.1 # Increased radius to cover more of Austin - } } + def generate_location(cluster): """Generate a random location within a cluster""" - center = GEOGRAPHIC_CLUSTERS[cluster]['center'] - radius = GEOGRAPHIC_CLUSTERS[cluster]['radius'] - lat = center['lat'] + random.uniform(-radius, radius) - lon = center['lon'] + random.uniform(-radius, radius) - return {'lat': lat, 'lon': lon} + center = GEOGRAPHIC_CLUSTERS[cluster]["center"] + radius = GEOGRAPHIC_CLUSTERS[cluster]["radius"] + lat = center["lat"] + random.uniform(-radius, radius) + lon = center["lon"] + random.uniform(-radius, radius) + return {"lat": lat, "lon": lon} + class NumericString(BaseProvider): class Meta: @@ -40,7 +39,8 @@ class Meta: @staticmethod def generate(length=5) -> str: - return ''.join([str(random.randint(0, 9)) for _ in range(length)]) + return "".join([str(random.randint(0, 9)) for _ in range(length)]) + class MultipleChoices(BaseProvider): class Meta: diff --git a/tests/synthetic_data_generator/sample_custom_module.py b/tests/synthetic_data_generator/sample_custom_module.py index 8ca67531..0d33b865 100644 --- a/tests/synthetic_data_generator/sample_custom_module.py +++ b/tests/synthetic_data_generator/sample_custom_module.py @@ -12,27 +12,26 @@ from mimesis.providers.base import BaseProvider GEOGRAPHIC_CLUSTERS = { - 'Manhattan': { - 'center': {'lat': 40.7831, 'lon': -73.9712}, - 'radius': 0.05 # degrees + "Manhattan": { + "center": {"lat": 40.7831, "lon": -73.9712}, + "radius": 0.05, # degrees }, - 'Brooklyn': { - 'center': {'lat': 40.6782, 'lon': -73.9442}, - 'radius': 0.05 + "Brooklyn": {"center": {"lat": 40.6782, "lon": -73.9442}, "radius": 0.05}, + "Austin": { + "center": {"lat": 30.2672, "lon": -97.7431}, + "radius": 0.1, # Increased radius to cover more of Austin }, - 'Austin': { - 'center': {'lat': 30.2672, 'lon': -97.7431}, - 'radius': 0.1 # Increased radius to cover more of Austin - } } + def generate_location(cluster): """Generate a random location within a cluster""" - center = GEOGRAPHIC_CLUSTERS[cluster]['center'] - radius = GEOGRAPHIC_CLUSTERS[cluster]['radius'] - lat = center['lat'] + random.uniform(-radius, radius) - lon = center['lon'] + random.uniform(-radius, radius) - return {'lat': lat, 'lon': lon} + center = GEOGRAPHIC_CLUSTERS[cluster]["center"] + radius = GEOGRAPHIC_CLUSTERS[cluster]["radius"] + lat = center["lat"] + random.uniform(-radius, radius) + lon = center["lon"] + random.uniform(-radius, radius) + return {"lat": lat, "lon": lon} + class NumericString(BaseProvider): class Meta: @@ -40,7 +39,8 @@ class Meta: @staticmethod def generate(length=5) -> str: - return ''.join([str(random.randint(0, 9)) for _ in range(length)]) + return "".join([str(random.randint(0, 9)) for _ in range(length)]) + class MultipleChoices(BaseProvider): class Meta: @@ -55,9 +55,10 @@ def generate(choices, num_of_choices=5) -> str: return [choices[random.randint(0, total_choices_available)] for _ in range(num_of_choices)] + def generate_synthetic_document(providers, **custom_lists): - generic = providers['generic'] - random_mimesis = providers['random'] + generic = providers["generic"] + random_mimesis = providers["random"] first_name = generic.person.first_name() last_name = generic.person.last_name() @@ -66,47 +67,34 @@ def generate_synthetic_document(providers, **custom_lists): # Driver Document document = { "dog_driver_id": f"DD{generic.numeric_string.generate(length=4)}", - "dog_name": random_mimesis.choice(custom_lists['dog_names']), - "dog_breed": random_mimesis.choice(custom_lists['dog_breeds']), + "dog_name": random_mimesis.choice(custom_lists["dog_names"]), + "dog_breed": random_mimesis.choice(custom_lists["dog_breeds"]), "license_number": f"{random_mimesis.choice(custom_lists['license_plates'])}{generic.numeric_string.generate(length=4)}", - "favorite_treats": random_mimesis.choice(custom_lists['treats']), - "preferred_tip": random_mimesis.choice(custom_lists['tips']), - "vehicle_type": random_mimesis.choice(custom_lists['vehicle_types']), - "vehicle_make": random_mimesis.choice(custom_lists['vehicle_makes']), - "vehicle_model": random_mimesis.choice(custom_lists['vehicle_models']), - "vehicle_year": random_mimesis.choice(custom_lists['vehicle_years']), - "vehicle_color": random_mimesis.choice(custom_lists['vehicle_colors']), - "license_plate": random_mimesis.choice(custom_lists['license_plates']), + "favorite_treats": random_mimesis.choice(custom_lists["treats"]), + "preferred_tip": random_mimesis.choice(custom_lists["tips"]), + "vehicle_type": random_mimesis.choice(custom_lists["vehicle_types"]), + "vehicle_make": random_mimesis.choice(custom_lists["vehicle_makes"]), + "vehicle_model": random_mimesis.choice(custom_lists["vehicle_models"]), + "vehicle_year": random_mimesis.choice(custom_lists["vehicle_years"]), + "vehicle_color": random_mimesis.choice(custom_lists["vehicle_colors"]), + "license_plate": random_mimesis.choice(custom_lists["license_plates"]), "current_location": generate_location(city), - "status": random.choice(['available', 'busy', 'offline']), + "status": random.choice(["available", "busy", "offline"]), "current_ride": f"R{generic.numeric_string.generate(length=6)}", - "account_status": random_mimesis.choice(custom_lists['account_status']), + "account_status": random_mimesis.choice(custom_lists["account_status"]), "join_date": generic.datetime.formatted_date(), "total_rides": generic.numeric.integer_number(start=1, end=200), "rating": generic.numeric.float_number(start=1.0, end=5.0, precision=2), "earnings": { - "today": { - "amount": generic.numeric.float_number(start=1.0, end=5.0, precision=2), - "currency": "USD" - }, - "this_week": { - "amount": generic.numeric.float_number(start=1.0, end=5.0, precision=2), - "currency": "USD" - }, - "this_month": { - "amount": generic.numeric.float_number(start=1.0, end=5.0, precision=2), - "currency": "USD" - } + "today": {"amount": generic.numeric.float_number(start=1.0, end=5.0, precision=2), "currency": "USD"}, + "this_week": {"amount": generic.numeric.float_number(start=1.0, end=5.0, precision=2), "currency": "USD"}, + "this_month": {"amount": generic.numeric.float_number(start=1.0, end=5.0, precision=2), "currency": "USD"}, }, "last_grooming_check": "2023-12-01", - "owner": { - "first_name": first_name, - "last_name": last_name, - "email": f"{first_name}{last_name}@gmail.com" - }, - "special_skills": generic.multiple_choices.generate(custom_lists['skills'], num_of_choices=3), + "owner": {"first_name": first_name, "last_name": last_name, "email": f"{first_name}{last_name}@gmail.com"}, + "special_skills": generic.multiple_choices.generate(custom_lists["skills"], num_of_choices=3), "bark_volume": generic.numeric.float_number(start=1.0, end=10.0, precision=2), - "tail_wag_speed": generic.numeric.float_number(start=1.0, end=10.0, precision=1) + "tail_wag_speed": generic.numeric.float_number(start=1.0, end=10.0, precision=1), } return document diff --git a/tests/synthetic_data_generator/strategies_test.py b/tests/synthetic_data_generator/strategies_test.py index bde368de..2b96f8be 100644 --- a/tests/synthetic_data_generator/strategies_test.py +++ b/tests/synthetic_data_generator/strategies_test.py @@ -19,7 +19,6 @@ class TestCustomStrategy: - @pytest.fixture def setup_sdg_metadata(self): return SyntheticDataGeneratorMetadata( @@ -28,28 +27,32 @@ def setup_sdg_metadata(self): custom_config_path="/path/to/config", custom_module_path="/path/to/module", output_path="/path/to/output", - total_size_gb=10 + total_size_gb=10, ) @pytest.fixture def mock_sdg_config(self): # This is what yaml.safe_load() would return loaded_sdg_config = { - 'settings': {'workers': 8, 'max_file_size_gb': 1, 'chunk_size': 10000}, - 'CustomGenerationValues': { - 'custom_lists': {'dog_names': ['Hana', 'Youpie', 'Charlie', 'Lucy', 'Cooper', 'Luna', 'Rocky', 'Daisy', 'Buddy', 'Molly'], - 'dog_breeds': ['Jindo', 'Labrador', 'German Shepherd', 'Golden Retriever', 'Bulldog', - 'Poodle', 'Beagle', 'Rottweiler', 'Boxer', 'Dachshund', 'Chihuahua'], - 'treats': ['cookies', 'pup_cup', 'jerky'], 'license_plates': ['WOOF101', 'BARKATAMZN'], - 'tips': ['biscuits', 'cash'], - 'skills': ['sniffing', 'squirrel_chasing', 'bite_tail', 'smile'], - 'vehicle_types': ['sedan', 'suv', 'truck'], 'vehicle_makes': ['toyta', 'honda', 'nissan'], - 'vehicle_models': ['rav4', 'accord', 'murano'], 'vehicle_years': [2012, 2015, 2019], - 'vehicle_colors': ['white', 'red', 'blue', 'black', 'silver'], 'account_status': ['active', 'inactive']}, - 'custom_providers': ['NumericString', 'MultipleChoices'] - } - } - + "settings": {"workers": 8, "max_file_size_gb": 1, "chunk_size": 10000}, + "CustomGenerationValues": { + "custom_lists": { + "dog_names": ["Hana", "Youpie", "Charlie", "Lucy", "Cooper", "Luna", "Rocky", "Daisy", "Buddy", "Molly"], + "dog_breeds": ["Jindo", "Labrador", "German Shepherd", "Golden Retriever", "Bulldog", "Poodle", "Beagle", "Rottweiler", "Boxer", "Dachshund", "Chihuahua"], + "treats": ["cookies", "pup_cup", "jerky"], + "license_plates": ["WOOF101", "BARKATAMZN"], + "tips": ["biscuits", "cash"], + "skills": ["sniffing", "squirrel_chasing", "bite_tail", "smile"], + "vehicle_types": ["sedan", "suv", "truck"], + "vehicle_makes": ["toyta", "honda", "nissan"], + "vehicle_models": ["rav4", "accord", "murano"], + "vehicle_years": [2012, 2015, 2019], + "vehicle_colors": ["white", "red", "blue", "black", "silver"], + "account_status": ["active", "inactive"], + }, + "custom_providers": ["NumericString", "MultipleChoices"], + }, + } sdg_config = SDGConfig(**loaded_sdg_config) return sdg_config @@ -69,7 +72,7 @@ def dask_client(self, sample_docs_generated): mock_futures = [] for doc in sample_docs_generated: mock_future = MagicMock() - mock_future.result.return_value = [doc] # Each worker produces a list of docs. In this test, each worker is returning a list of one doc + mock_future.result.return_value = [doc] # Each worker produces a list of docs. In this test, each worker is returning a list of one doc mock_futures.append(mock_future) dask_client.submit.side_effect = mock_futures @@ -79,29 +82,87 @@ def dask_client(self, sample_docs_generated): @pytest.fixture def sample_docs_generated(self): return [ - {"dog_driver_id": "DD4444", "dog_name": "Hana", "dog_breed": "Korean Jindo", - "license_number": "BARKATAMZN6176", "favorite_treats": "pup_cup", "preferred_tip": "cash", "vehicle_type": "truck", - "vehicle_make": "honda", "vehicle_model": "murano", "vehicle_year": 2019, "vehicle_color": "black", "license_plate": "WOOF101", - "current_location": {"lat": 40.73389901726337, "lon": -73.95726095278667}, "status": "available", "current_ride": "R922315", - "account_status": "inactive", "join_date": "05/01/2017", "total_rides": 166, "rating": 1.91, - "earnings": {"today": {"amount": 2.26, "currency": "USD"}, "this_week": {"amount": 1.44, "currency": "USD"}, "this_month": {"amount": 1.31, "currency": "USD"}}, - "last_grooming_check": "2023-12-01", "owner": {"first_name": "Elfrieda", "last_name": "Huffman", "email": "ElfriedaHuffman@gmail.com"}, - "special_skills": ["bite_tail", "bite_tail", "smile"], "bark_volume": 9.33, "tail_wag_speed": 4.5}, - {"dog_driver_id": "DD2495", "dog_name": "Luna", "dog_breed": "Chihuahua", "license_number": "WOOF1014472", "favorite_treats": "jerky", - "preferred_tip": "cash", "vehicle_type": "sedan", "vehicle_make": "nissan", "vehicle_model": "murano", "vehicle_year": 2019, "vehicle_color": "silver", - "license_plate": "WOOF101", "current_location": {"lat": 40.75654230013213, "lon": -73.98178219702368}, "status": "busy", "current_ride": "R690202", - "account_status": "active", "join_date": "03/06/2018", "total_rides": 24, "rating": 2.13, - "earnings": {"today": {"amount": 1.4, "currency": "USD"}, "this_week": {"amount": 3.89, "currency": "USD"}, "this_month": {"amount": 4.88, "currency": "USD"}}, - "last_grooming_check": "2023-12-01", "owner": {"first_name": "Avery", "last_name": "Moran", "email": "AveryMoran@gmail.com"}, - "special_skills": ["sniffing", "bite_tail", "smile"], - "bark_volume": 1.73, "tail_wag_speed": 9.1}, - {"dog_driver_id": "DD2223", "dog_name": "Youpie", "dog_breed": "Boxer", "license_number": "BARKATAMZN7147", "favorite_treats": "jerky", - "preferred_tip": "cash", "vehicle_type": "suv", "vehicle_make": "nissan", "vehicle_model": "murano", "vehicle_year": 2015, "vehicle_color": "white", - "license_plate": "BARKATAMZN", "current_location": {"lat": 30.212385699598567, "lon": -97.76458615057449}, "status": "available", - "current_ride": "R303297", "account_status": "inactive", "join_date": "03/25/2025", "total_rides": 110, "rating": 1.58, - "earnings": {"today": {"amount": 2.98, "currency": "USD"}, "this_week": {"amount": 2.71, "currency": "USD"}, "this_month": {"amount": 4.89, "currency": "USD"}}, - "last_grooming_check": "2023-12-01", "owner": {"first_name": "Otto", "last_name": "Stephens", "email": "OttoStephens@gmail.com"}, - "special_skills": ["sniffing", "squirrel_chasing", "bite_tail"], "bark_volume": 7.09, "tail_wag_speed": 7.2} + { + "dog_driver_id": "DD4444", + "dog_name": "Hana", + "dog_breed": "Korean Jindo", + "license_number": "BARKATAMZN6176", + "favorite_treats": "pup_cup", + "preferred_tip": "cash", + "vehicle_type": "truck", + "vehicle_make": "honda", + "vehicle_model": "murano", + "vehicle_year": 2019, + "vehicle_color": "black", + "license_plate": "WOOF101", + "current_location": {"lat": 40.73389901726337, "lon": -73.95726095278667}, + "status": "available", + "current_ride": "R922315", + "account_status": "inactive", + "join_date": "05/01/2017", + "total_rides": 166, + "rating": 1.91, + "earnings": {"today": {"amount": 2.26, "currency": "USD"}, "this_week": {"amount": 1.44, "currency": "USD"}, "this_month": {"amount": 1.31, "currency": "USD"}}, + "last_grooming_check": "2023-12-01", + "owner": {"first_name": "Elfrieda", "last_name": "Huffman", "email": "ElfriedaHuffman@gmail.com"}, + "special_skills": ["bite_tail", "bite_tail", "smile"], + "bark_volume": 9.33, + "tail_wag_speed": 4.5, + }, + { + "dog_driver_id": "DD2495", + "dog_name": "Luna", + "dog_breed": "Chihuahua", + "license_number": "WOOF1014472", + "favorite_treats": "jerky", + "preferred_tip": "cash", + "vehicle_type": "sedan", + "vehicle_make": "nissan", + "vehicle_model": "murano", + "vehicle_year": 2019, + "vehicle_color": "silver", + "license_plate": "WOOF101", + "current_location": {"lat": 40.75654230013213, "lon": -73.98178219702368}, + "status": "busy", + "current_ride": "R690202", + "account_status": "active", + "join_date": "03/06/2018", + "total_rides": 24, + "rating": 2.13, + "earnings": {"today": {"amount": 1.4, "currency": "USD"}, "this_week": {"amount": 3.89, "currency": "USD"}, "this_month": {"amount": 4.88, "currency": "USD"}}, + "last_grooming_check": "2023-12-01", + "owner": {"first_name": "Avery", "last_name": "Moran", "email": "AveryMoran@gmail.com"}, + "special_skills": ["sniffing", "bite_tail", "smile"], + "bark_volume": 1.73, + "tail_wag_speed": 9.1, + }, + { + "dog_driver_id": "DD2223", + "dog_name": "Youpie", + "dog_breed": "Boxer", + "license_number": "BARKATAMZN7147", + "favorite_treats": "jerky", + "preferred_tip": "cash", + "vehicle_type": "suv", + "vehicle_make": "nissan", + "vehicle_model": "murano", + "vehicle_year": 2015, + "vehicle_color": "white", + "license_plate": "BARKATAMZN", + "current_location": {"lat": 30.212385699598567, "lon": -97.76458615057449}, + "status": "available", + "current_ride": "R303297", + "account_status": "inactive", + "join_date": "03/25/2025", + "total_rides": 110, + "rating": 1.58, + "earnings": {"today": {"amount": 2.98, "currency": "USD"}, "this_week": {"amount": 2.71, "currency": "USD"}, "this_month": {"amount": 4.89, "currency": "USD"}}, + "last_grooming_check": "2023-12-01", + "owner": {"first_name": "Otto", "last_name": "Stephens", "email": "OttoStephens@gmail.com"}, + "special_skills": ["sniffing", "squirrel_chasing", "bite_tail"], + "bark_volume": 7.09, + "tail_wag_speed": 7.2, + }, ] @pytest.fixture @@ -131,7 +192,7 @@ def sample_field_names(self): "owner", "special_skills", "bark_volume", - "tail_wag_speed" + "tail_wag_speed", ] def test_generate_test_document(self, sample_field_names, custom_module_strategy): @@ -145,13 +206,13 @@ def test_avg_doc_size(self, custom_module_strategy): assert isinstance(avg_doc_size, int) def test_generate_data_chunks_across_workers(self, dask_client, custom_module_strategy): - futures_across_workers = custom_module_strategy.generate_data_chunks_across_workers(dask_client, 3, [1,2,3], None, None) + futures_across_workers = custom_module_strategy.generate_data_chunks_across_workers(dask_client, 3, [1, 2, 3], None, None) docs = [future.result() for future in futures_across_workers] assert len(docs) == 3 - assert docs[0][0]['dog_name'] == 'Hana' - assert docs[1][0]['dog_name'] == 'Luna' - assert docs[2][0]['dog_name'] == 'Youpie' + assert docs[0][0]["dog_name"] == "Hana" + assert docs[1][0]["dog_name"] == "Luna" + assert docs[2][0]["dog_name"] == "Youpie" def test_generate_data_chunk_from_worker(self, sample_field_names, custom_module_strategy): user_defined_function = custom_module_strategy.custom_module.generate_synthetic_document @@ -173,7 +234,6 @@ def test_custom_module_missing_generate_synthetic_document_function(self, setup_ class TestMappingStrategy: - @pytest.fixture def setup_sdg_metadata(self): return SyntheticDataGeneratorMetadata( @@ -182,35 +242,34 @@ def setup_sdg_metadata(self): custom_config_path="/path/to/config", custom_module_path="/path/to/module", output_path="/path/to/output", - total_size_gb=10 + total_size_gb=10, ) @pytest.fixture def mock_sdg_config(self): # This is what yaml.safe_load() would return loaded_sdg_config = { - 'settings': {'workers': 8, 'max_file_size_gb': 1, 'chunk_size': 10000}, - 'MappingGenerationValues': { - 'generator_overrides': { - 'integer': {'min': 0, 'max': 20}, - 'long': {'min': 0, 'max': 1000}, - 'float': {'min': 0.0, 'max': 1.0}, - 'double': {'min': 0.0, 'max': 2000.0}, - 'date': {'start_date': '2020-01-01', 'end_date': '2023-01-01', 'format': 'yyyy-mm-dd'}, - 'text': {'must_include': ['lorem', 'ipsum']}, - 'keyword': {'choices': ['naruto', 'sakura', 'sasuke']} - }, - 'field_overrides': { - 'id': {'generator': 'generate_keyword', - 'params': {'choices': ['Helly R', 'Mark S', 'Irving B']}}, - 'promo_codes': {'generator': 'generate_keyword', 'params': {'choices': ['HOT_SUMMER', 'TREATSYUM!']}}, - 'preferences.language': {'generator': 'generate_keyword', 'params': {'choices': ['Python', 'English']}}, - 'payment_methods.type': {'generator': 'generate_keyword', 'params': {'choices': ['Visa', 'Mastercard', 'Cash', 'Venmo']}}, - 'preferences.allergies': {'generator': 'generate_keyword', 'params': {'choices': ['Squirrels', 'Cats']}}, - 'favorite_locations.name': {'generator': 'generate_keyword', 'params': {'choices': ['Austin', 'NYC', 'Miami']}} - } - } - } + "settings": {"workers": 8, "max_file_size_gb": 1, "chunk_size": 10000}, + "MappingGenerationValues": { + "generator_overrides": { + "integer": {"min": 0, "max": 20}, + "long": {"min": 0, "max": 1000}, + "float": {"min": 0.0, "max": 1.0}, + "double": {"min": 0.0, "max": 2000.0}, + "date": {"start_date": "2020-01-01", "end_date": "2023-01-01", "format": "yyyy-mm-dd"}, + "text": {"must_include": ["lorem", "ipsum"]}, + "keyword": {"choices": ["naruto", "sakura", "sasuke"]}, + }, + "field_overrides": { + "id": {"generator": "generate_keyword", "params": {"choices": ["Helly R", "Mark S", "Irving B"]}}, + "promo_codes": {"generator": "generate_keyword", "params": {"choices": ["HOT_SUMMER", "TREATSYUM!"]}}, + "preferences.language": {"generator": "generate_keyword", "params": {"choices": ["Python", "English"]}}, + "payment_methods.type": {"generator": "generate_keyword", "params": {"choices": ["Visa", "Mastercard", "Cash", "Venmo"]}}, + "preferences.allergies": {"generator": "generate_keyword", "params": {"choices": ["Squirrels", "Cats"]}}, + "favorite_locations.name": {"generator": "generate_keyword", "params": {"choices": ["Austin", "NYC", "Miami"]}}, + }, + }, + } sdg_config = SDGConfig(**loaded_sdg_config) return sdg_config @@ -220,257 +279,101 @@ def basic_opensearch_index_mappings(self): return { "mappings": { "properties": { - "title": { - "type": "text", - "analyzer": "standard", - "fields": { - "keyword": { - "type": "keyword", - "ignore_above": 256 - } - } - }, - "description": { - "type": "text" - }, - "price": { - "type": "float" - }, - "created_at": { - "type": "date", - "format": "strict_date_optional_time||epoch_millis" - }, - "is_available": { - "type": "boolean" - }, - "category_id": { - "type": "integer" - }, - "tags": { - "type": "keyword" - } + "title": {"type": "text", "analyzer": "standard", "fields": {"keyword": {"type": "keyword", "ignore_above": 256}}}, + "description": {"type": "text"}, + "price": {"type": "float"}, + "created_at": {"type": "date", "format": "strict_date_optional_time||epoch_millis"}, + "is_available": {"type": "boolean"}, + "category_id": {"type": "integer"}, + "tags": {"type": "keyword"}, } }, - "settings": { - "number_of_shards": 1, - "number_of_replicas": 1 - } + "settings": {"number_of_shards": 1, "number_of_replicas": 1}, } @pytest.fixture def complex_opensearch_index_mapping(self): return { - "mappings": { - "dynamic": "strict", - "properties": { - "user": { + "mappings": { + "dynamic": "strict", + "properties": { + "user": { "type": "object", "properties": { - "id": { - "type": "keyword" - }, - "email": { - "type": "keyword" - }, - "name": { - "type": "text", - "fields": { - "keyword": { - "type": "keyword", - "ignore_above": 256 - }, - "completion": { - "type": "completion" - } - }, - "analyzer": "standard" - }, + "id": {"type": "keyword"}, + "email": {"type": "keyword"}, + "name": {"type": "text", "fields": {"keyword": {"type": "keyword", "ignore_above": 256}, "completion": {"type": "completion"}}, "analyzer": "standard"}, "address": { - "type": "object", - "properties": { - "street": { - "type": "text" + "type": "object", + "properties": { + "street": {"type": "text"}, + "city": {"type": "keyword"}, + "state": {"type": "keyword"}, + "zip": {"type": "keyword"}, + "location": {"type": "geo_point"}, }, - "city": { - "type": "keyword" - }, - "state": { - "type": "keyword" - }, - "zip": { - "type": "keyword" - }, - "location": { - "type": "geo_point" - } - } }, - "preferences": { - "type": "object", - "dynamic": True - } - } + "preferences": {"type": "object", "dynamic": True}, }, - "orders": { + }, + "orders": { "type": "nested", "properties": { - "id": { - "type": "keyword" - }, - "date": { - "type": "date", - "format": "strict_date_optional_time||epoch_millis" - }, - "amount": { - "type": "scaled_float", - "scaling_factor": 100 - }, - "status": { - "type": "keyword" - }, + "id": {"type": "keyword"}, + "date": {"type": "date", "format": "strict_date_optional_time||epoch_millis"}, + "amount": {"type": "scaled_float", "scaling_factor": 100}, + "status": {"type": "keyword"}, "items": { - "type": "nested", - "properties": { - "product_id": { - "type": "keyword" - }, - "name": { - "type": "text", - "fields": { - "keyword": { - "type": "keyword" - } - } - }, - "quantity": { - "type": "short" - }, - "price": { - "type": "float" + "type": "nested", + "properties": { + "product_id": {"type": "keyword"}, + "name": {"type": "text", "fields": {"keyword": {"type": "keyword"}}}, + "quantity": {"type": "short"}, + "price": {"type": "float"}, + "categories": {"type": "keyword"}, }, - "categories": { - "type": "keyword" - } - } }, "shipping_address": { - "type": "object", - "properties": { - "street": { - "type": "text" - }, - "city": { - "type": "keyword" - }, - "state": { - "type": "keyword" - }, - "zip": { - "type": "keyword" + "type": "object", + "properties": { + "street": {"type": "text"}, + "city": {"type": "keyword"}, + "state": {"type": "keyword"}, + "zip": {"type": "keyword"}, + "location": {"type": "geo_point"}, }, - "location": { - "type": "geo_point" - } - } - } - } - }, - "activity_log": { - "type": "nested", - "properties": { - "timestamp": { - "type": "date" }, - "action": { - "type": "keyword" - }, - "ip_address": { - "type": "ip" - }, - "details": { - "type": "object", - "enabled": False - } - } }, - "metadata": { + }, + "activity_log": { + "type": "nested", + "properties": {"timestamp": {"type": "date"}, "action": {"type": "keyword"}, "ip_address": {"type": "ip"}, "details": {"type": "object", "enabled": False}}, + }, + "metadata": { "type": "object", "properties": { - "created_at": { - "type": "date" - }, - "updated_at": { - "type": "date" - }, - "tags": { - "type": "keyword" - }, - "source": { - "type": "keyword" - }, - "version": { - "type": "integer" - } - } + "created_at": {"type": "date"}, + "updated_at": {"type": "date"}, + "tags": {"type": "keyword"}, + "source": {"type": "keyword"}, + "version": {"type": "integer"}, }, - "description": { + }, + "description": { "type": "text", "analyzer": "english", - "fields": { - "keyword": { - "type": "keyword", - "ignore_above": 256 - }, - "standard": { - "type": "text", - "analyzer": "standard" - } - } - }, - "ranking_scores": { - "type": "object", - "properties": { - "popularity": { - "type": "float" - }, - "relevance": { - "type": "float" - }, - "quality": { - "type": "float" - } - } - }, - "permissions": { - "type": "nested", - "properties": { - "user_id": { - "type": "keyword" - }, - "role": { - "type": "keyword" - }, - "granted_at": { - "type": "date" - } - } - } - } + "fields": {"keyword": {"type": "keyword", "ignore_above": 256}, "standard": {"type": "text", "analyzer": "standard"}}, }, - "settings": { - "number_of_shards": 3, - "number_of_replicas": 2, - "analysis": { - "analyzer": { - "email_analyzer": { - "type": "custom", - "tokenizer": "uax_url_email", - "filter": ["lowercase", "stop"] - } - } - } - } - } + "ranking_scores": {"type": "object", "properties": {"popularity": {"type": "float"}, "relevance": {"type": "float"}, "quality": {"type": "float"}}}, + "permissions": {"type": "nested", "properties": {"user_id": {"type": "keyword"}, "role": {"type": "keyword"}, "granted_at": {"type": "date"}}}, + }, + }, + "settings": { + "number_of_shards": 3, + "number_of_replicas": 2, + "analysis": {"analyzer": {"email_analyzer": {"type": "custom", "tokenizer": "uax_url_email", "filter": ["lowercase", "stop"]}}}, + }, + } @pytest.fixture def sample_docs_for_basic_mappings(self): @@ -478,11 +381,12 @@ def sample_docs_for_basic_mappings(self): { "title": "ipsum Sample text for 70", "description": "lorem Sample text for 94", - "price": 0.0, "created_at": "2020-04-22", + "price": 0.0, + "created_at": "2020-04-22", "is_available": False, "category_id": 13, - "tags": "Mark S" - }, + "tags": "Mark S", + }, { "title": "lorem Sample text for 7", "description": "lorem Sample text for 68", @@ -490,7 +394,7 @@ def sample_docs_for_basic_mappings(self): "created_at": "2022-07-04", "is_available": True, "category_id": 10, - "tags": "Helly R" + "tags": "Helly R", }, { "title": "lorem Sample text for 87", @@ -499,8 +403,8 @@ def sample_docs_for_basic_mappings(self): "created_at": "2020-02-21", "is_available": True, "category_id": 16, - "tags": "Irving B" - } + "tags": "Irving B", + }, ] @pytest.fixture @@ -516,126 +420,69 @@ def sample_docs_for_complex_mappings(self): "city": "naruto", "state": "sasuke", "zip": "naruto", - "location": { - "lat": 26.99576279125438, - "lon": -106.55835561335948 - } + "location": {"lat": 26.99576279125438, "lon": -106.55835561335948}, }, - "preferences": {} - }, - "orders": [{ - "id": "sakura", - "date": "2021-11-17", - "amount": "unknown_type", - "status": "naruto", - "items": [{ - "product_id": "sasuke", - "name": "ipsum Sample text for 5", - "quantity": 19395, - "price": 0.31, - "categories": "sakura" - }, { - "product_id": "sakura", - "name": "lorem Sample text for 94", - "quantity": -5488, - "price": 0.76, - "categories": "sasuke" - }], - "shipping_address": { - "street": "lorem Sample text for 63", - "city": "naruto", - "state": "sasuke", - "zip": "sakura", - "location": { - "lat": 40.62216376082151, - "lon": -65.29355206583621 - } - } - }, { - "id": "sakura", - "date": "2021-03-05", - "amount": "unknown_type", - "status": "sasuke", - "items": [{ - "product_id": "sakura", - "name": "lorem Sample text for 100", - "quantity": -1063, - "price": 0.11, - "categories": "sakura" - }], - "shipping_address": { - "street": "lorem Sample text for 72", - "city": "sasuke", - "state": "naruto", - "zip": "sasuke", - "location": { - "lat": 64.92051504356562, - "lon": -64.90676398234942 - } - } - }, { - "id": "naruto", - "date": "2020-04-23", - "amount": "unknown_type", - "status": "sasuke", - "items": [{ - "product_id": "sasuke", - "name": "lorem Sample text for 5", - "quantity": -27595, - "price": 0.73, - "categories": "sasuke" - }, { - "product_id": "sasuke", - "name": "ipsum Sample text for 8", - "quantity": 3581, - "price": 0.65, - "categories": "naruto" - }], - "shipping_address": { - "street": "lorem Sample text for 30", - "city": "sasuke", - "state": "sakura", - "zip": "naruto", - "location": { - "lat": -48.34559264417752, - "lon": -178.36558923535966 - } - } - }], - "activity_log": [{ - "timestamp": "2021-09-22", - "action": "sakura", - "ip_address": "101.52.247.55", - "details": {} - }, { - "timestamp": "2022-01-22", - "action": "sasuke", - "ip_address": "44.189.12.245", - "details": {} - }, { - "timestamp": "2022-07-03", - "action": "naruto", - "ip_address": "131.232.186.58", - "details": {} - }], - "metadata": { - "created_at": "2022-02-09", - "updated_at": "2022-04-18", - "tags": "sasuke", - "source": "sasuke", - "version": 6 + "preferences": {}, }, + "orders": [ + { + "id": "sakura", + "date": "2021-11-17", + "amount": "unknown_type", + "status": "naruto", + "items": [ + {"product_id": "sasuke", "name": "ipsum Sample text for 5", "quantity": 19395, "price": 0.31, "categories": "sakura"}, + {"product_id": "sakura", "name": "lorem Sample text for 94", "quantity": -5488, "price": 0.76, "categories": "sasuke"}, + ], + "shipping_address": { + "street": "lorem Sample text for 63", + "city": "naruto", + "state": "sasuke", + "zip": "sakura", + "location": {"lat": 40.62216376082151, "lon": -65.29355206583621}, + }, + }, + { + "id": "sakura", + "date": "2021-03-05", + "amount": "unknown_type", + "status": "sasuke", + "items": [{"product_id": "sakura", "name": "lorem Sample text for 100", "quantity": -1063, "price": 0.11, "categories": "sakura"}], + "shipping_address": { + "street": "lorem Sample text for 72", + "city": "sasuke", + "state": "naruto", + "zip": "sasuke", + "location": {"lat": 64.92051504356562, "lon": -64.90676398234942}, + }, + }, + { + "id": "naruto", + "date": "2020-04-23", + "amount": "unknown_type", + "status": "sasuke", + "items": [ + {"product_id": "sasuke", "name": "lorem Sample text for 5", "quantity": -27595, "price": 0.73, "categories": "sasuke"}, + {"product_id": "sasuke", "name": "ipsum Sample text for 8", "quantity": 3581, "price": 0.65, "categories": "naruto"}, + ], + "shipping_address": { + "street": "lorem Sample text for 30", + "city": "sasuke", + "state": "sakura", + "zip": "naruto", + "location": {"lat": -48.34559264417752, "lon": -178.36558923535966}, + }, + }, + ], + "activity_log": [ + {"timestamp": "2021-09-22", "action": "sakura", "ip_address": "101.52.247.55", "details": {}}, + {"timestamp": "2022-01-22", "action": "sasuke", "ip_address": "44.189.12.245", "details": {}}, + {"timestamp": "2022-07-03", "action": "naruto", "ip_address": "131.232.186.58", "details": {}}, + ], + "metadata": {"created_at": "2022-02-09", "updated_at": "2022-04-18", "tags": "sasuke", "source": "sasuke", "version": 6}, "description": "lorem Sample text for 46", - "ranking_scores": { - "popularity": 0.89, - "relevance": 0.79, - "quality": 0.95 - }, - "permissions": [{ - "user_id": "sasuke", - "role": "naruto", - "granted_at": "2022-09-18" - }] + "ranking_scores": {"popularity": 0.89, "relevance": 0.79, "quality": 0.95}, + "permissions": [{"user_id": "sasuke", "role": "naruto", "granted_at": "2022-09-18"}], }, { "user": { @@ -647,130 +494,70 @@ def sample_docs_for_complex_mappings(self): "city": "sasuke", "state": "naruto", "zip": "naruto", - "location": { - "lat": 18.716170289552792, - "lon": 26.676759590182087 - } + "location": {"lat": 18.716170289552792, "lon": 26.676759590182087}, }, - "preferences": {} - }, - "orders": [{ - "id": "sakura", - "date": "2020-03-14", - "amount": "unknown_type", - "status": "sasuke", - "items": [{ - "product_id": "naruto", - "name": "ipsum Sample text for 72", - "quantity": 8096, - "price": 0.98, - "categories": "sasuke" - }], - "shipping_address": { - "street": "ipsum Sample text for 70", - "city": "sasuke", - "state": "naruto", - "zip": "sasuke", - "location": { - "lat": 35.160379130894796, - "lon": 142.70658997708557 - } - } - }, { - "id": "sasuke", - "date": "2021-04-10", - "amount": "unknown_type", - "status": "sasuke", - "items": [{ - "product_id": "sakura", - "name": "ipsum Sample text for 95", - "quantity": -26888, - "price": 0.9, - "categories": "sasuke" - }, { - "product_id": "sakura", - "name": "lorem Sample text for 77", - "quantity": -4878, - "price": 0.68, - "categories": "sakura" - }, { - "product_id": "sasuke", - "name": "lorem Sample text for 60", - "quantity": 3465, - "price": 0.92, - "categories": "sakura" - }], - "shipping_address": { - "street": "ipsum Sample text for 69", - "city": "sakura", - "state": "sasuke", - "zip": "sakura", - "location": { - "lat": -10.880093403565638, - "lon": -167.79823045612983 - } - } - }, { - "id": "sasuke", - "date": "2021-11-06", - "amount": "unknown_type", - "status": "sakura", - "items": [{ - "product_id": "sakura", - "name": "lorem Sample text for 55", - "quantity": -22916, - "price": 0.29, - "categories": "naruto" - }, { - "product_id": "sasuke", - "name": "lorem Sample text for 72", - "quantity": -5256, - "price": 0.45, - "categories": "naruto" - }], - "shipping_address": { - "street": "ipsum Sample text for 76", - "city": "naruto", - "state": "naruto", - "zip": "sasuke", - "location": { - "lat": 89.19863809212737, - "lon": 114.09913348615743 - } - } - }], - "activity_log": [{ - "timestamp": "2020-06-26", - "action": "naruto", - "ip_address": "139.179.72.223", - "details": {} - }], - "metadata": { - "created_at": "2020-02-11", - "updated_at": "2021-05-25", - "tags": "sasuke", - "source": "sakura", - "version": 16 + "preferences": {}, }, + "orders": [ + { + "id": "sakura", + "date": "2020-03-14", + "amount": "unknown_type", + "status": "sasuke", + "items": [{"product_id": "naruto", "name": "ipsum Sample text for 72", "quantity": 8096, "price": 0.98, "categories": "sasuke"}], + "shipping_address": { + "street": "ipsum Sample text for 70", + "city": "sasuke", + "state": "naruto", + "zip": "sasuke", + "location": {"lat": 35.160379130894796, "lon": 142.70658997708557}, + }, + }, + { + "id": "sasuke", + "date": "2021-04-10", + "amount": "unknown_type", + "status": "sasuke", + "items": [ + {"product_id": "sakura", "name": "ipsum Sample text for 95", "quantity": -26888, "price": 0.9, "categories": "sasuke"}, + {"product_id": "sakura", "name": "lorem Sample text for 77", "quantity": -4878, "price": 0.68, "categories": "sakura"}, + {"product_id": "sasuke", "name": "lorem Sample text for 60", "quantity": 3465, "price": 0.92, "categories": "sakura"}, + ], + "shipping_address": { + "street": "ipsum Sample text for 69", + "city": "sakura", + "state": "sasuke", + "zip": "sakura", + "location": {"lat": -10.880093403565638, "lon": -167.79823045612983}, + }, + }, + { + "id": "sasuke", + "date": "2021-11-06", + "amount": "unknown_type", + "status": "sakura", + "items": [ + {"product_id": "sakura", "name": "lorem Sample text for 55", "quantity": -22916, "price": 0.29, "categories": "naruto"}, + {"product_id": "sasuke", "name": "lorem Sample text for 72", "quantity": -5256, "price": 0.45, "categories": "naruto"}, + ], + "shipping_address": { + "street": "ipsum Sample text for 76", + "city": "naruto", + "state": "naruto", + "zip": "sasuke", + "location": {"lat": 89.19863809212737, "lon": 114.09913348615743}, + }, + }, + ], + "activity_log": [{"timestamp": "2020-06-26", "action": "naruto", "ip_address": "139.179.72.223", "details": {}}], + "metadata": {"created_at": "2020-02-11", "updated_at": "2021-05-25", "tags": "sasuke", "source": "sakura", "version": 16}, "description": "ipsum Sample text for 8", - "ranking_scores": { - "popularity": 0.16, - "relevance": 0.06, - "quality": 0.72 - }, - "permissions": [{ - "user_id": "sakura", - "role": "sasuke", - "granted_at": "2020-01-23" - }, { - "user_id": "sasuke", - "role": "naruto", - "granted_at": "2022-11-19" - }, { - "user_id": "sakura", - "role": "naruto", - "granted_at": "2021-07-06" - }] + "ranking_scores": {"popularity": 0.16, "relevance": 0.06, "quality": 0.72}, + "permissions": [ + {"user_id": "sakura", "role": "sasuke", "granted_at": "2020-01-23"}, + {"user_id": "sasuke", "role": "naruto", "granted_at": "2022-11-19"}, + {"user_id": "sakura", "role": "naruto", "granted_at": "2021-07-06"}, + ], }, { "user": { @@ -782,109 +569,58 @@ def sample_docs_for_complex_mappings(self): "city": "sasuke", "state": "sasuke", "zip": "sakura", - "location": { - "lat": 32.588564110838846, - "lon": -23.052179676898845 - } + "location": {"lat": 32.588564110838846, "lon": -23.052179676898845}, }, - "preferences": {} - }, - "orders": [{ - "id": "sakura", - "date": "2021-03-21", - "amount": "unknown_type", - "status": "sasuke", - "items": [{ - "product_id": "sasuke", - "name": "ipsum Sample text for 63", - "quantity": 7348, - "price": 0.37, - "categories": "sasuke" - }], - "shipping_address": { - "street": "ipsum Sample text for 34", - "city": "naruto", - "state": "sakura", - "zip": "sasuke", - "location": { - "lat": -51.05656520056002, - "lon": -17.084501461214813 - } - } - }, { - "id": "sakura", - "date": "2021-01-21", - "amount": "unknown_type", - "status": "sasuke", - "items": [{ - "product_id": "naruto", - "name": "ipsum Sample text for 43", - "quantity": 28814, - "price": 0.91, - "categories": "sakura" - }, { - "product_id": "sasuke", - "name": "ipsum Sample text for 80", - "quantity": -293, - "price": 0.52, - "categories": "naruto" - }, { - "product_id": "sakura", - "name": "ipsum Sample text for 72", - "quantity": -31854, - "price": 0.54, - "categories": "sakura" - }], - "shipping_address": { - "street": "lorem Sample text for 19", - "city": "sakura", - "state": "sasuke", - "zip": "sasuke", - "location": { - "lat": 40.539565425667945, - "lon": 59.77455111107835 - } - } - }], - "activity_log": [{ - "timestamp": "2021-04-14", - "action": "sakura", - "ip_address": "155.140.52.23", - "details": {} - }], - "metadata": { - "created_at": "2021-06-22", - "updated_at": "2020-09-06", - "tags": "naruto", - "source": "naruto", - "version": 9 + "preferences": {}, }, + "orders": [ + { + "id": "sakura", + "date": "2021-03-21", + "amount": "unknown_type", + "status": "sasuke", + "items": [{"product_id": "sasuke", "name": "ipsum Sample text for 63", "quantity": 7348, "price": 0.37, "categories": "sasuke"}], + "shipping_address": { + "street": "ipsum Sample text for 34", + "city": "naruto", + "state": "sakura", + "zip": "sasuke", + "location": {"lat": -51.05656520056002, "lon": -17.084501461214813}, + }, + }, + { + "id": "sakura", + "date": "2021-01-21", + "amount": "unknown_type", + "status": "sasuke", + "items": [ + {"product_id": "naruto", "name": "ipsum Sample text for 43", "quantity": 28814, "price": 0.91, "categories": "sakura"}, + {"product_id": "sasuke", "name": "ipsum Sample text for 80", "quantity": -293, "price": 0.52, "categories": "naruto"}, + {"product_id": "sakura", "name": "ipsum Sample text for 72", "quantity": -31854, "price": 0.54, "categories": "sakura"}, + ], + "shipping_address": { + "street": "lorem Sample text for 19", + "city": "sakura", + "state": "sasuke", + "zip": "sasuke", + "location": {"lat": 40.539565425667945, "lon": 59.77455111107835}, + }, + }, + ], + "activity_log": [{"timestamp": "2021-04-14", "action": "sakura", "ip_address": "155.140.52.23", "details": {}}], + "metadata": {"created_at": "2021-06-22", "updated_at": "2020-09-06", "tags": "naruto", "source": "naruto", "version": 9}, "description": "ipsum Sample text for 55", - "ranking_scores": { - "popularity": 0.49, - "relevance": 0.23, - "quality": 0.32 - }, - "permissions": [{ - "user_id": "sakura", - "role": "sasuke", - "granted_at": "2020-07-29" - }, { - "user_id": "naruto", - "role": "sasuke", - "granted_at": "2021-01-19" - }] - } + "ranking_scores": {"popularity": 0.49, "relevance": 0.23, "quality": 0.32}, + "permissions": [{"user_id": "sakura", "role": "sasuke", "granted_at": "2020-07-29"}, {"user_id": "naruto", "role": "sasuke", "granted_at": "2021-01-19"}], + }, ] - @pytest.fixture def mapping_strategy_with_basic_mappings(self, setup_sdg_metadata, mock_sdg_config, basic_opensearch_index_mappings): strategy = MappingStrategy(setup_sdg_metadata, mock_sdg_config, basic_opensearch_index_mappings) return strategy - @pytest.fixture def mapping_strategy_with_complex_mappings(self, setup_sdg_metadata, mock_sdg_config, complex_opensearch_index_mappings): strategy = MappingStrategy(setup_sdg_metadata, mock_sdg_config, complex_opensearch_index_mappings) @@ -898,14 +634,13 @@ def dask_client(self, sample_docs_for_basic_mappings): mock_futures = [] for doc in sample_docs_for_basic_mappings: mock_future = MagicMock() - mock_future.result.return_value = [doc] # Each worker produces a list of docs. In this test, each worker is returning a list of one doc + mock_future.result.return_value = [doc] # Each worker produces a list of docs. In this test, each worker is returning a list of one doc mock_futures.append(mock_future) dask_client.submit.side_effect = mock_futures return dask_client - def test_generate_test_document(self, mapping_strategy_with_basic_mappings): field_names = ["title", "description", "price", "created_at", "is_available", "category_id", "tags"] document = mapping_strategy_with_basic_mappings.generate_test_document() @@ -919,13 +654,13 @@ def test_avg_doc_size(self, mapping_strategy_with_basic_mappings): def test_generate_data_chunks_across_workers(self, dask_client, mapping_strategy_with_basic_mappings): fields = ["title", "description", "price", "created_at", "is_available", "category_id", "tags"] - futures_across_workers = mapping_strategy_with_basic_mappings.generate_data_chunks_across_workers(dask_client, 3, [1,2,3], None, None) + futures_across_workers = mapping_strategy_with_basic_mappings.generate_data_chunks_across_workers(dask_client, 3, [1, 2, 3], None, None) docs = [future.result() for future in futures_across_workers] assert len(docs) == 3 - assert docs[0][0]['tags'] == 'Mark S' - assert docs[1][0]['tags'] == 'Helly R' - assert docs[2][0]['tags'] == 'Irving B' + assert docs[0][0]["tags"] == "Mark S" + assert docs[1][0]["tags"] == "Helly R" + assert docs[2][0]["tags"] == "Irving B" for doc in docs: for field in fields: @@ -943,34 +678,33 @@ def test_generate_data_chunk_from_worker(self, mapping_strategy_with_basic_mappi for field in fields: assert field in doc -class TestMappingConverter: +class TestMappingConverter: @pytest.fixture def mock_sdg_config(self): # This is what yaml.safe_load() would return loaded_sdg_config = { - 'settings': {'workers': 8, 'max_file_size_gb': 1, 'chunk_size': 10000}, - 'MappingGenerationValues': { - 'generator_overrides': { - 'integer': {'min': 0, 'max': 20}, - 'long': {'min': 0, 'max': 1000}, - 'float': {'min': 0.0, 'max': 1.0}, - 'double': {'min': 0.0, 'max': 2000.0}, - 'date': {'start_date': '2020-01-01', 'end_date': '2023-01-01', 'format': 'yyyy-mm-dd'}, - 'text': {'must_include': ['lorem', 'ipsum']}, - 'keyword': {'choices': ['naruto', 'sakura', 'sasuke']} - }, - 'field_overrides': { - 'id': {'generator': 'generate_keyword', - 'params': {'choices': ['Helly R', 'Mark S', 'Irving B']}}, - 'promo_codes': {'generator': 'generate_keyword', 'params': {'choices': ['HOT_SUMMER', 'TREATSYUM!']}}, - 'preferences.language': {'generator': 'generate_keyword', 'params': {'choices': ['Python', 'English']}}, - 'payment_methods.type': {'generator': 'generate_keyword', 'params': {'choices': ['Visa', 'Mastercard', 'Cash', 'Venmo']}}, - 'preferences.allergies': {'generator': 'generate_keyword', 'params': {'choices': ['Squirrels', 'Cats']}}, - 'favorite_locations.name': {'generator': 'generate_keyword', 'params': {'choices': ['Austin', 'NYC', 'Miami']}} - } - } - } + "settings": {"workers": 8, "max_file_size_gb": 1, "chunk_size": 10000}, + "MappingGenerationValues": { + "generator_overrides": { + "integer": {"min": 0, "max": 20}, + "long": {"min": 0, "max": 1000}, + "float": {"min": 0.0, "max": 1.0}, + "double": {"min": 0.0, "max": 2000.0}, + "date": {"start_date": "2020-01-01", "end_date": "2023-01-01", "format": "yyyy-mm-dd"}, + "text": {"must_include": ["lorem", "ipsum"]}, + "keyword": {"choices": ["naruto", "sakura", "sasuke"]}, + }, + "field_overrides": { + "id": {"generator": "generate_keyword", "params": {"choices": ["Helly R", "Mark S", "Irving B"]}}, + "promo_codes": {"generator": "generate_keyword", "params": {"choices": ["HOT_SUMMER", "TREATSYUM!"]}}, + "preferences.language": {"generator": "generate_keyword", "params": {"choices": ["Python", "English"]}}, + "payment_methods.type": {"generator": "generate_keyword", "params": {"choices": ["Visa", "Mastercard", "Cash", "Venmo"]}}, + "preferences.allergies": {"generator": "generate_keyword", "params": {"choices": ["Squirrels", "Cats"]}}, + "favorite_locations.name": {"generator": "generate_keyword", "params": {"choices": ["Austin", "NYC", "Miami"]}}, + }, + }, + } sdg_config = SDGConfig(**loaded_sdg_config) return sdg_config @@ -980,257 +714,101 @@ def basic_opensearch_index_mappings(self): return { "mappings": { "properties": { - "title": { - "type": "text", - "analyzer": "standard", - "fields": { - "keyword": { - "type": "keyword", - "ignore_above": 256 - } - } - }, - "description": { - "type": "text" - }, - "price": { - "type": "float" - }, - "created_at": { - "type": "date", - "format": "strict_date_optional_time||epoch_millis" - }, - "is_available": { - "type": "boolean" - }, - "category_id": { - "type": "integer" - }, - "tags": { - "type": "keyword" - } + "title": {"type": "text", "analyzer": "standard", "fields": {"keyword": {"type": "keyword", "ignore_above": 256}}}, + "description": {"type": "text"}, + "price": {"type": "float"}, + "created_at": {"type": "date", "format": "strict_date_optional_time||epoch_millis"}, + "is_available": {"type": "boolean"}, + "category_id": {"type": "integer"}, + "tags": {"type": "keyword"}, } }, - "settings": { - "number_of_shards": 1, - "number_of_replicas": 1 - } + "settings": {"number_of_shards": 1, "number_of_replicas": 1}, } @pytest.fixture def complex_opensearch_index_mappings(self): return { - "mappings": { - "dynamic": "strict", - "properties": { - "user": { + "mappings": { + "dynamic": "strict", + "properties": { + "user": { "type": "object", "properties": { - "id": { - "type": "keyword" - }, - "email": { - "type": "keyword" - }, - "name": { - "type": "text", - "fields": { - "keyword": { - "type": "keyword", - "ignore_above": 256 - }, - "completion": { - "type": "completion" - } - }, - "analyzer": "standard" - }, + "id": {"type": "keyword"}, + "email": {"type": "keyword"}, + "name": {"type": "text", "fields": {"keyword": {"type": "keyword", "ignore_above": 256}, "completion": {"type": "completion"}}, "analyzer": "standard"}, "address": { - "type": "object", - "properties": { - "street": { - "type": "text" - }, - "city": { - "type": "keyword" + "type": "object", + "properties": { + "street": {"type": "text"}, + "city": {"type": "keyword"}, + "state": {"type": "keyword"}, + "zip": {"type": "keyword"}, + "location": {"type": "geo_point"}, }, - "state": { - "type": "keyword" - }, - "zip": { - "type": "keyword" - }, - "location": { - "type": "geo_point" - } - } }, - "preferences": { - "type": "object", - "dynamic": True - } - } + "preferences": {"type": "object", "dynamic": True}, }, - "orders": { + }, + "orders": { "type": "nested", "properties": { - "id": { - "type": "keyword" - }, - "date": { - "type": "date", - "format": "strict_date_optional_time||epoch_millis" - }, - "amount": { - "type": "scaled_float", - "scaling_factor": 100 - }, - "status": { - "type": "keyword" - }, + "id": {"type": "keyword"}, + "date": {"type": "date", "format": "strict_date_optional_time||epoch_millis"}, + "amount": {"type": "scaled_float", "scaling_factor": 100}, + "status": {"type": "keyword"}, "items": { - "type": "nested", - "properties": { - "product_id": { - "type": "keyword" - }, - "name": { - "type": "text", - "fields": { - "keyword": { - "type": "keyword" - } - } - }, - "quantity": { - "type": "short" + "type": "nested", + "properties": { + "product_id": {"type": "keyword"}, + "name": {"type": "text", "fields": {"keyword": {"type": "keyword"}}}, + "quantity": {"type": "short"}, + "price": {"type": "float"}, + "categories": {"type": "keyword"}, }, - "price": { - "type": "float" - }, - "categories": { - "type": "keyword" - } - } }, "shipping_address": { - "type": "object", - "properties": { - "street": { - "type": "text" - }, - "city": { - "type": "keyword" - }, - "state": { - "type": "keyword" - }, - "zip": { - "type": "keyword" + "type": "object", + "properties": { + "street": {"type": "text"}, + "city": {"type": "keyword"}, + "state": {"type": "keyword"}, + "zip": {"type": "keyword"}, + "location": {"type": "geo_point"}, }, - "location": { - "type": "geo_point" - } - } - } - } - }, - "activity_log": { - "type": "nested", - "properties": { - "timestamp": { - "type": "date" }, - "action": { - "type": "keyword" - }, - "ip_address": { - "type": "ip" - }, - "details": { - "type": "object", - "enabled": False - } - } }, - "metadata": { + }, + "activity_log": { + "type": "nested", + "properties": {"timestamp": {"type": "date"}, "action": {"type": "keyword"}, "ip_address": {"type": "ip"}, "details": {"type": "object", "enabled": False}}, + }, + "metadata": { "type": "object", "properties": { - "created_at": { - "type": "date" - }, - "updated_at": { - "type": "date" - }, - "tags": { - "type": "keyword" - }, - "source": { - "type": "keyword" - }, - "version": { - "type": "integer" - } - } + "created_at": {"type": "date"}, + "updated_at": {"type": "date"}, + "tags": {"type": "keyword"}, + "source": {"type": "keyword"}, + "version": {"type": "integer"}, }, - "description": { + }, + "description": { "type": "text", "analyzer": "english", - "fields": { - "keyword": { - "type": "keyword", - "ignore_above": 256 - }, - "standard": { - "type": "text", - "analyzer": "standard" - } - } - }, - "ranking_scores": { - "type": "object", - "properties": { - "popularity": { - "type": "float" - }, - "relevance": { - "type": "float" - }, - "quality": { - "type": "float" - } - } - }, - "permissions": { - "type": "nested", - "properties": { - "user_id": { - "type": "keyword" - }, - "role": { - "type": "keyword" - }, - "granted_at": { - "type": "date" - } - } - } - } + "fields": {"keyword": {"type": "keyword", "ignore_above": 256}, "standard": {"type": "text", "analyzer": "standard"}}, }, - "settings": { - "number_of_shards": 3, - "number_of_replicas": 2, - "analysis": { - "analyzer": { - "email_analyzer": { - "type": "custom", - "tokenizer": "uax_url_email", - "filter": ["lowercase", "stop"] - } - } - } - } - } + "ranking_scores": {"type": "object", "properties": {"popularity": {"type": "float"}, "relevance": {"type": "float"}, "quality": {"type": "float"}}}, + "permissions": {"type": "nested", "properties": {"user_id": {"type": "keyword"}, "role": {"type": "keyword"}, "granted_at": {"type": "date"}}}, + }, + }, + "settings": { + "number_of_shards": 3, + "number_of_replicas": 2, + "analysis": {"analyzer": {"email_analyzer": {"type": "custom", "tokenizer": "uax_url_email", "filter": ["lowercase", "stop"]}}}, + }, + } @pytest.fixture def mapping_converter(self, mock_sdg_config): @@ -1248,7 +826,6 @@ def test_generating_documents_from_basic_mappings(self, mapping_converter, basic for field in fields: assert field in document - def test_generating_documents_for_complex_mappings(self, mapping_converter, complex_opensearch_index_mappings): mappings_with_generators = mapping_converter.transform_mapping_to_generators(complex_opensearch_index_mappings) @@ -1259,22 +836,7 @@ def test_generating_documents_for_complex_mappings(self, mapping_converter, comp assert field in document def test_generating_documents_for_with_overrides(self, mapping_converter): - basic_mappings = { - "properties": { - "id": { - "type": "keyword" - }, - "amount": { - "type": "float" - }, - "created_at": { - "type": "date" - }, - "status": { - "type": "keyword" - } - } - } + basic_mappings = {"properties": {"id": {"type": "keyword"}, "amount": {"type": "float"}, "created_at": {"type": "date"}, "status": {"type": "keyword"}}} mappings_with_generators_and_overrides = mapping_converter.transform_mapping_to_generators(basic_mappings) document = MappingConverter.generate_synthetic_document(transformed_mapping=mappings_with_generators_and_overrides) @@ -1288,13 +850,7 @@ def test_generating_documents_for_with_overrides(self, mapping_converter): def test_generate_sparse_vector(self, mapping_converter): """Test basic sparse_vector generation""" - mapping = { - "properties": { - "sparse_embedding": { - "type": "sparse_vector" - } - } - } + mapping = {"properties": {"sparse_embedding": {"type": "sparse_vector"}}} generators = mapping_converter.transform_mapping_to_generators(mapping) document = MappingConverter.generate_synthetic_document(generators) @@ -1316,25 +872,11 @@ def test_generate_sparse_vector_with_params(self, mapping_converter): """Test sparse_vector generation with custom parameters""" # Override to use custom params mapping_converter.mapping_config = { - "generator_overrides": { - "sparse_vector": { - "num_tokens": 5, - "min_weight": 0.1, - "max_weight": 0.9, - "token_id_start": 5000, - "token_id_step": 50 - } - }, - "field_overrides": {} + "generator_overrides": {"sparse_vector": {"num_tokens": 5, "min_weight": 0.1, "max_weight": 0.9, "token_id_start": 5000, "token_id_step": 50}}, + "field_overrides": {}, } - mapping = { - "properties": { - "embedding": { - "type": "sparse_vector" - } - } - } + mapping = {"properties": {"embedding": {"type": "sparse_vector"}}} generators = mapping_converter.transform_mapping_to_generators(mapping) document = MappingConverter.generate_synthetic_document(generators) @@ -1357,12 +899,7 @@ def test_generate_sparse_vector_in_complex_mapping(self, mapping_converter): "text": {"type": "text"}, "dense_vector": {"type": "knn_vector", "dimension": 3}, "sparse_vector": {"type": "sparse_vector"}, - "metadata": { - "type": "object", - "properties": { - "id": {"type": "keyword"} - } - } + "metadata": {"type": "object", "properties": {"id": {"type": "keyword"}}}, } } diff --git a/tests/synthetic_data_generator/synthetic_data_generator_test.py b/tests/synthetic_data_generator/synthetic_data_generator_test.py index acfca148..1be53690 100644 --- a/tests/synthetic_data_generator/synthetic_data_generator_test.py +++ b/tests/synthetic_data_generator/synthetic_data_generator_test.py @@ -12,8 +12,8 @@ from solrorbit.synthetic_data_generator.synthetic_data_generator import SyntheticDataGenerator from solrorbit.synthetic_data_generator.models import SyntheticDataGeneratorMetadata, SDGConfig -class TestSyntheticDataGeneratorWithCustomStrategy: +class TestSyntheticDataGeneratorWithCustomStrategy: @pytest.fixture def setup_sdg_metadata(self): return SyntheticDataGeneratorMetadata( @@ -22,27 +22,31 @@ def setup_sdg_metadata(self): custom_config_path="/path/to/config", custom_module_path="/path/to/module", output_path="/path/to/output", - total_size_gb=10 + total_size_gb=10, ) @pytest.fixture def mock_sdg_config(self): loaded_sdg_config = { - 'settings': {'workers': 8, 'max_file_size_gb': 1, 'chunk_size': 10000}, - 'CustomGenerationValues': { - 'custom_lists': {'dog_names': ['Hana', 'Youpie', 'Charlie', 'Lucy', 'Cooper', 'Luna', 'Rocky', 'Daisy', 'Buddy', 'Molly'], - 'dog_breeds': ['Jindo', 'Labrador', 'German Shepherd', 'Golden Retriever', 'Bulldog', - 'Poodle', 'Beagle', 'Rottweiler', 'Boxer', 'Dachshund', 'Chihuahua'], - 'treats': ['cookies', 'pup_cup', 'jerky'], 'license_plates': ['WOOF101', 'BARKATAMZN'], - 'tips': ['biscuits', 'cash'], - 'skills': ['sniffing', 'squirrel_chasing', 'bite_tail', 'smile'], - 'vehicle_types': ['sedan', 'suv', 'truck'], 'vehicle_makes': ['toyta', 'honda', 'nissan'], - 'vehicle_models': ['rav4', 'accord', 'murano'], 'vehicle_years': [2012, 2015, 2019], - 'vehicle_colors': ['white', 'red', 'blue', 'black', 'silver'], 'account_status': ['active', 'inactive']}, - 'custom_providers': ['NumericString', 'MultipleChoices'] - } - } - + "settings": {"workers": 8, "max_file_size_gb": 1, "chunk_size": 10000}, + "CustomGenerationValues": { + "custom_lists": { + "dog_names": ["Hana", "Youpie", "Charlie", "Lucy", "Cooper", "Luna", "Rocky", "Daisy", "Buddy", "Molly"], + "dog_breeds": ["Jindo", "Labrador", "German Shepherd", "Golden Retriever", "Bulldog", "Poodle", "Beagle", "Rottweiler", "Boxer", "Dachshund", "Chihuahua"], + "treats": ["cookies", "pup_cup", "jerky"], + "license_plates": ["WOOF101", "BARKATAMZN"], + "tips": ["biscuits", "cash"], + "skills": ["sniffing", "squirrel_chasing", "bite_tail", "smile"], + "vehicle_types": ["sedan", "suv", "truck"], + "vehicle_makes": ["toyta", "honda", "nissan"], + "vehicle_models": ["rav4", "accord", "murano"], + "vehicle_years": [2012, 2015, 2019], + "vehicle_colors": ["white", "red", "blue", "black", "silver"], + "account_status": ["active", "inactive"], + }, + "custom_providers": ["NumericString", "MultipleChoices"], + }, + } sdg_config = SDGConfig(**loaded_sdg_config) return sdg_config @@ -50,21 +54,18 @@ def mock_sdg_config(self): @pytest.fixture def mock_custom_module(self): mock_module = MagicMock - mock_module.generate_synthetic_document = MagicMock(return_value={'synthetic_field': 'synthetic_value'}) + mock_module.generate_synthetic_document = MagicMock(return_value={"synthetic_field": "synthetic_value"}) return mock_module @pytest.fixture def mock_dask_client(self): - mock_scheduler_info = {'id': '123456789', - 'services': {}, - 'type': 'Scheduler', - 'workers': {'127.0.0.1:12345': {'active': 0, - 'last-seen': 123412415.1234124, - 'name': '127.0.0.1:12345', - 'services': {}, - 'stored': 0, - 'time-delay': 0.12390819587}}} + mock_scheduler_info = { + "id": "123456789", + "services": {}, + "type": "Scheduler", + "workers": {"127.0.0.1:12345": {"active": 0, "last-seen": 123412415.1234124, "name": "127.0.0.1:12345", "services": {}, "stored": 0, "time-delay": 0.12390819587}}, + } mock_dask_client = MagicMock() mock_dask_client.scheduler_info = MagicMock(return_value=mock_scheduler_info) @@ -74,7 +75,7 @@ def mock_dask_client(self): @pytest.fixture def setup_custom_strategy(self, setup_sdg_metadata, mock_sdg_config, mock_custom_module): custom_strategy = MagicMock() - custom_strategy.generate_test_document.return_value = {'name': 'Shanks'} + custom_strategy.generate_test_document.return_value = {"name": "Shanks"} return custom_strategy @@ -83,7 +84,7 @@ def setup_custom_sdg(self, setup_sdg_metadata, mock_sdg_config, setup_custom_str return SyntheticDataGenerator(setup_sdg_metadata, mock_sdg_config, setup_custom_strategy) # Patch how it's used in SDG - @patch('solrorbit.synthetic_data_generator.synthetic_data_generator.get_client') + @patch("solrorbit.synthetic_data_generator.synthetic_data_generator.get_client") def test_generate_seeds_for_workers(self, mock_get_client, setup_custom_sdg, mock_dask_client): mock_get_client.return_value = mock_dask_client sdg = setup_custom_sdg @@ -99,13 +100,13 @@ def test_generate_test_document(self, setup_custom_sdg): result = sdg.generate_test_document() sdg.strategy.generate_test_document.assert_called_once() - assert result == {'name': 'Shanks'} + assert result == {"name": "Shanks"} def test_generate_dataset(self): pass -class TestSyntheticDataGeneratorWithMappingStrategy: +class TestSyntheticDataGeneratorWithMappingStrategy: @pytest.fixture def setup_sdg_metadata(self): return SyntheticDataGeneratorMetadata( @@ -114,34 +115,33 @@ def setup_sdg_metadata(self): custom_config_path="/path/to/config", custom_module_path="/path/to/module", output_path="/path/to/output", - total_size_gb=10 + total_size_gb=10, ) @pytest.fixture def mock_sdg_config(self): loaded_sdg_config = { - 'settings': {'workers': 8, 'max_file_size_gb': 1, 'chunk_size': 10000}, - 'MappingGenerationValues': { - 'generator_overrides': { - 'integer': {'min': 0, 'max': 20}, - 'long': {'min': 0, 'max': 1000}, - 'float': {'min': 0.0, 'max': 1.0}, - 'double': {'min': 0.0, 'max': 2000.0}, - 'date': {'start_date': '2020-01-01', 'end_date': '2023-01-01', 'format': 'yyyy-mm-dd'}, - 'text': {'must_include': ['lorem', 'ipsum']}, - 'keyword': {'choices': ['naruto', 'sakura', 'sasuke']} - }, - 'field_overrides': { - 'id': {'generator': 'generate_keyword', - 'params': {'choices': ['Helly R', 'Mark S', 'Irving B']}}, - 'promo_codes': {'generator': 'generate_keyword', 'params': {'choices': ['HOT_SUMMER', 'TREATSYUM!']}}, - 'preferences.language': {'generator': 'generate_keyword', 'params': {'choices': ['Python', 'English']}}, - 'payment_methods.type': {'generator': 'generate_keyword', 'params': {'choices': ['Visa', 'Mastercard', 'Cash', 'Venmo']}}, - 'preferences.allergies': {'generator': 'generate_keyword', 'params': {'choices': ['Squirrels', 'Cats']}}, - 'favorite_locations.name': {'generator': 'generate_keyword', 'params': {'choices': ['Austin', 'NYC', 'Miami']}} - } - } - } + "settings": {"workers": 8, "max_file_size_gb": 1, "chunk_size": 10000}, + "MappingGenerationValues": { + "generator_overrides": { + "integer": {"min": 0, "max": 20}, + "long": {"min": 0, "max": 1000}, + "float": {"min": 0.0, "max": 1.0}, + "double": {"min": 0.0, "max": 2000.0}, + "date": {"start_date": "2020-01-01", "end_date": "2023-01-01", "format": "yyyy-mm-dd"}, + "text": {"must_include": ["lorem", "ipsum"]}, + "keyword": {"choices": ["naruto", "sakura", "sasuke"]}, + }, + "field_overrides": { + "id": {"generator": "generate_keyword", "params": {"choices": ["Helly R", "Mark S", "Irving B"]}}, + "promo_codes": {"generator": "generate_keyword", "params": {"choices": ["HOT_SUMMER", "TREATSYUM!"]}}, + "preferences.language": {"generator": "generate_keyword", "params": {"choices": ["Python", "English"]}}, + "payment_methods.type": {"generator": "generate_keyword", "params": {"choices": ["Visa", "Mastercard", "Cash", "Venmo"]}}, + "preferences.allergies": {"generator": "generate_keyword", "params": {"choices": ["Squirrels", "Cats"]}}, + "favorite_locations.name": {"generator": "generate_keyword", "params": {"choices": ["Austin", "NYC", "Miami"]}}, + }, + }, + } sdg_config = SDGConfig(**loaded_sdg_config) return sdg_config @@ -149,21 +149,18 @@ def mock_sdg_config(self): @pytest.fixture def mock_custom_module(self): mock_module = MagicMock - mock_module.generate_synthetic_document = MagicMock(return_value={'synthetic_field': 'synthetic_value'}) + mock_module.generate_synthetic_document = MagicMock(return_value={"synthetic_field": "synthetic_value"}) return mock_module @pytest.fixture def mock_dask_client(self): - mock_scheduler_info = {'id': '123456789', - 'services': {}, - 'type': 'Scheduler', - 'workers': {'127.0.0.1:12345': {'active': 0, - 'last-seen': 123412415.1234124, - 'name': '127.0.0.1:12345', - 'services': {}, - 'stored': 0, - 'time-delay': 0.12390819587}}} + mock_scheduler_info = { + "id": "123456789", + "services": {}, + "type": "Scheduler", + "workers": {"127.0.0.1:12345": {"active": 0, "last-seen": 123412415.1234124, "name": "127.0.0.1:12345", "services": {}, "stored": 0, "time-delay": 0.12390819587}}, + } mock_dask_client = MagicMock() mock_dask_client.scheduler_info = MagicMock(return_value=mock_scheduler_info) @@ -173,7 +170,7 @@ def mock_dask_client(self): @pytest.fixture def setup_mapping_strategy(self, setup_sdg_metadata, mock_sdg_config, mock_custom_module): mapping_strategy = MagicMock() - mapping_strategy.generate_test_document.return_value = {'name': 'Shanks'} + mapping_strategy.generate_test_document.return_value = {"name": "Shanks"} return mapping_strategy @@ -182,7 +179,7 @@ def setup_custom_sdg(self, setup_sdg_metadata, mock_sdg_config, setup_mapping_st return SyntheticDataGenerator(setup_sdg_metadata, mock_sdg_config, setup_mapping_strategy) # Patch how it's used in SDG - @patch('solrorbit.synthetic_data_generator.synthetic_data_generator.get_client') + @patch("solrorbit.synthetic_data_generator.synthetic_data_generator.get_client") def test_generate_seeds_for_workers(self, mock_get_client, setup_custom_sdg, mock_dask_client): mock_get_client.return_value = mock_dask_client sdg = setup_custom_sdg @@ -198,4 +195,4 @@ def test_generate_test_document(self, setup_custom_sdg): result = sdg.generate_test_document() sdg.strategy.generate_test_document.assert_called_once() - assert result == {'name': 'Shanks'} + assert result == {"name": "Shanks"} diff --git a/tests/test_execution_orchestrator_test.py b/tests/test_execution_orchestrator_test.py index 81360ef3..805e3dae 100644 --- a/tests/test_execution_orchestrator_test.py +++ b/tests/test_execution_orchestrator_test.py @@ -16,7 +16,7 @@ # not use this file except in compliance with the License. # You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an @@ -62,8 +62,7 @@ def unittest_pipeline(): def test_finds_available_pipelines(): expected = [ ["benchmark-only", "Assumes an already running search engine instance, runs a benchmark and publishes results"], - ["from-sources", "Builds Solr from source (git clone + Gradle assemble), provisions it locally, " - "runs a benchmark, and tears down."], + ["from-sources", "Builds Solr from source (git clone + Gradle assemble), provisions it locally, runs a benchmark, and tears down."], ["from-distribution", "Downloads a Solr distribution, provisions it locally, runs a benchmark, and tears down."], ["docker", "Starts Solr via Docker, runs a benchmark, and removes the container on teardown."], ] @@ -77,9 +76,7 @@ def test_prevents_running_an_unknown_pipeline(): cfg.add(config.Scope.benchmark, "test_run", "pipeline", "invalid") cfg.add(config.Scope.benchmark, "builder", "distribution.version", "5.0.0") - with pytest.raises( - exceptions.SystemSetupError, - match=r"Unknown pipeline \[invalid]. List the available pipelines with [\S]+? list pipelines."): + with pytest.raises(exceptions.SystemSetupError, match=r"Unknown pipeline \[invalid]. List the available pipelines with [\S]+? list pipelines."): test_run_orchestrator.run(cfg) @@ -99,13 +96,14 @@ def test_fails_without_benchmark_only_pipeline_in_docker(running_in_docker, unit cfg.add(config.Scope.benchmark, "test_run", "pipeline", "unit-test-pipeline") with pytest.raises( - exceptions.SystemSetupError, - match=re.escape( - "Only the [benchmark-only] pipeline is supported by the Docker image.\n" - "Add --pipeline=benchmark-only in your arguments and try again.\n" - "For more details read the docs for the benchmark-only pipeline in " - "https://solr.apache.org/guide/\n" - )): + exceptions.SystemSetupError, + match=re.escape( + "Only the [benchmark-only] pipeline is supported by the Docker image.\n" + "Add --pipeline=benchmark-only in your arguments and try again.\n" + "For more details read the docs for the benchmark-only pipeline in " + "https://solr.apache.org/guide/\n" + ), + ): test_run_orchestrator.run(cfg) @@ -119,6 +117,7 @@ def test_runs_a_known_pipeline(unittest_pipeline): unittest_pipeline.target.assert_called_once_with(cfg) + def test_runs_a_default_pipeline(benchmark_only_pipeline): # with no pipeline specified, should default to benchmark-only cfg = config.Config() diff --git a/tests/time_test.py b/tests/time_test.py index a74ff74a..abc3531f 100644 --- a/tests/time_test.py +++ b/tests/time_test.py @@ -16,7 +16,7 @@ # not use this file except in compliance with the License. # You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an diff --git a/tests/unit/solr/conversion/test_detector.py b/tests/unit/solr/conversion/test_detector.py index cb529f9b..45e706d8 100644 --- a/tests/unit/solr/conversion/test_detector.py +++ b/tests/unit/solr/conversion/test_detector.py @@ -30,22 +30,12 @@ class TestWorkloadDetection(unittest.TestCase): def test_detect_solr_workload_with_collections_key(self): """Test detection of Solr workload by collections key.""" - workload = { - "name": "test-workload", - "collections": [ - {"name": "my-collection", "configset": "my-configset"} - ] - } + workload = {"name": "test-workload", "collections": [{"name": "my-collection", "configset": "my-configset"}]} self.assertFalse(is_opensearch_workload(workload)) def test_detect_opensearch_workload_with_indices_key(self): """Test detection of OpenSearch workload by indices key.""" - workload = { - "name": "test-workload", - "indices": [ - {"name": "my-index", "body": "index.json"} - ] - } + workload = {"name": "test-workload", "indices": [{"name": "my-index", "body": "index.json"}]} self.assertTrue(is_opensearch_workload(workload)) def test_detect_solr_by_operation_types(self): @@ -59,9 +49,9 @@ def test_detect_solr_by_operation_types(self): {"operation": {"operation-type": "create-collection"}}, {"operation": {"operation-type": "bulk-index"}}, {"operation": {"operation-type": "commit"}}, - ] + ], } - ] + ], } self.assertFalse(is_opensearch_workload(workload)) @@ -76,9 +66,9 @@ def test_detect_opensearch_by_operation_types(self): {"operation": {"operation-type": "create-index"}}, {"operation": {"operation-type": "index"}}, {"operation": {"operation-type": "force-merge"}}, - ] + ], } - ] + ], } self.assertTrue(is_opensearch_workload(workload)) @@ -86,37 +76,13 @@ def test_detect_by_param_source(self): """Test detection based on param-source values.""" opensearch_workload = { "name": "test-workload", - "challenges": [ - { - "name": "default", - "schedule": [ - { - "operation": { - "operation-type": "search", - "param-source": "opensearch-search-source" - } - } - ] - } - ] + "challenges": [{"name": "default", "schedule": [{"operation": {"operation-type": "search", "param-source": "opensearch-search-source"}}]}], } self.assertTrue(is_opensearch_workload(opensearch_workload)) solr_workload = { "name": "test-workload", - "challenges": [ - { - "name": "default", - "schedule": [ - { - "operation": { - "operation-type": "search", - "param-source": "solr-search-source" - } - } - ] - } - ] + "challenges": [{"name": "default", "schedule": [{"operation": {"operation-type": "search", "param-source": "solr-search-source"}}]}], } self.assertFalse(is_opensearch_workload(solr_workload)) @@ -135,12 +101,12 @@ def test_mixed_signals_scores_correctly(self): "name": "default", "schedule": [ {"operation": {"operation-type": "create-collection"}}, # Solr +2 - {"operation": {"operation-type": "bulk-index"}}, # Solr +2 - {"operation": {"operation-type": "commit"}}, # Solr +2 - {"operation": {"operation-type": "search"}}, # Neutral - ] + {"operation": {"operation-type": "bulk-index"}}, # Solr +2 + {"operation": {"operation-type": "commit"}}, # Solr +2 + {"operation": {"operation-type": "search"}}, # Neutral + ], } - ] + ], } self.assertFalse(is_opensearch_workload(workload)) diff --git a/tests/unit/solr/test_client.py b/tests/unit/solr/test_client.py index c89ddf88..c7d88819 100644 --- a/tests/unit/solr/test_client.py +++ b/tests/unit/solr/test_client.py @@ -78,6 +78,7 @@ def _make_client_with_mock_session(self): def test_upload_configset_success(self): import tempfile import os + client = self._make_client_with_mock_session() resp = _make_response(status_code=200) client._session.put.return_value = resp @@ -100,6 +101,7 @@ def test_upload_configset_success(self): def test_upload_configset_failure_raises(self): import tempfile import os + client = self._make_client_with_mock_session() resp = _make_response(status_code=500, text="Server Error") client._session.post.return_value = resp @@ -152,8 +154,7 @@ def test_create_collection_tlog_and_pull_replicas(self): client = self._make_client_with_mock_session() resp = _make_response(status_code=200, json_data={"responseHeader": {"status": 0}}) client._session.post.return_value = resp - client.create_collection("my-coll", "my-config", - tlog_replicas=2, pull_replicas=1) + client.create_collection("my-coll", "my-config", tlog_replicas=2, pull_replicas=1) _, kwargs = client._session.post.call_args payload = kwargs["json"] self.assertEqual(2, payload["tlogReplicas"]) @@ -225,6 +226,7 @@ def test_zip_contains_files(self): import tempfile import os import zipfile + with tempfile.TemporaryDirectory() as tmpdir: conf = os.path.join(tmpdir, "conf") os.makedirs(conf) diff --git a/tests/unit/solr/test_provisioner.py b/tests/unit/solr/test_provisioner.py index 49801367..cc2ef046 100644 --- a/tests/unit/solr/test_provisioner.py +++ b/tests/unit/solr/test_provisioner.py @@ -66,11 +66,13 @@ def test_solr_opts_sets_solr_opts_env(self): def test_multiple_variables_all_applied(self): """All known variables should be applied in a single call.""" - cc = _make_cluster_config({ - "heap_size": "8g", - "gc_tune": "-XX:+UseParallelGC", - "solr_opts": "-verbose:gc", - }) + cc = _make_cluster_config( + { + "heap_size": "8g", + "gc_tune": "-XX:+UseParallelGC", + "solr_opts": "-verbose:gc", + } + ) provisioner = SolrProvisioner(cluster_config=cc) env = provisioner._build_env() self.assertEqual("8g", env["SOLR_HEAP"]) diff --git a/tests/unit/solr/test_runner.py b/tests/unit/solr/test_runner.py index 7800b592..7b16a900 100644 --- a/tests/unit/solr/test_runner.py +++ b/tests/unit/solr/test_runner.py @@ -189,6 +189,7 @@ def test_bulk_index_returns_weight(self): def test_bulk_index_reports_errors(self): import pysolr + mock_sc = MagicMock() mock_sc.add.side_effect = pysolr.SolrError("Indexing error") @@ -279,6 +280,7 @@ def test_dict_query_body_posted_to_solr(self): class TestSolrCreateCollection(unittest.TestCase): def test_two_step_sequence(self): import tempfile + mock_sc = MagicMock() with tempfile.TemporaryDirectory() as tmpdir: @@ -309,9 +311,7 @@ def test_create_collection_passes_tlog_pull_replicas(self): runner = SolrCreateCollection() _run(runner(mock_sc, params)) - mock_sc.create_collection.assert_called_once_with( - "my-coll", "my-config", 2, 1, 2, 1 - ) + mock_sc.create_collection.assert_called_once_with("my-coll", "my-config", 2, 1, 2, 1) def test_create_collection_defaults_tlog_pull_to_zero(self): """Runner defaults tlog-replicas and pull-replicas to 0 when omitted.""" @@ -324,14 +324,13 @@ def test_create_collection_defaults_tlog_pull_to_zero(self): runner = SolrCreateCollection() _run(runner(mock_sc, params)) - mock_sc.create_collection.assert_called_once_with( - "my-coll", "my-config", 1, 1, 0, 0 - ) + mock_sc.create_collection.assert_called_once_with("my-coll", "my-config", 1, 1, 0, 0) class TestSolrDeleteCollection(unittest.TestCase): def test_delete_ignores_missing_by_default(self): from solrorbit.client import CollectionNotFoundError + mock_sc = MagicMock() mock_sc.delete_collection.side_effect = CollectionNotFoundError("not found") @@ -356,11 +355,13 @@ class TestRunnerRegistrationSmoke(unittest.TestCase): def setUp(self): from solrorbit.worker_coordinator.runner import register_default_runners + register_default_runners() def _run_via_framework(self, op_type, clients_dict, params): """Look up a registered runner and invoke it the same way execute_single does.""" from solrorbit.worker_coordinator.runner import runner_for + wrapped = runner_for(op_type) return _run(wrapped(clients_dict, params)) diff --git a/tests/unit/solr/test_schema_generator.py b/tests/unit/solr/test_schema_generator.py index 3fc288fa..44884553 100644 --- a/tests/unit/solr/test_schema_generator.py +++ b/tests/unit/solr/test_schema_generator.py @@ -59,14 +59,7 @@ def test_keyword_field_has_docvalues(self): def test_multi_field_with_raw_suffix(self): """Test multi-field with .raw sub-field creates separate field and copyField.""" - properties = { - "country_code": { - "type": "text", - "fields": { - "raw": {"type": "keyword"} - } - } - } + properties = {"country_code": {"type": "text", "fields": {"raw": {"type": "keyword"}}}} field_defs, copy_fields = translate_opensearch_mapping(properties) @@ -84,14 +77,7 @@ def test_multi_field_with_raw_suffix(self): def test_multi_field_with_keyword_suffix(self): """Test multi-field with .keyword sub-field.""" - properties = { - "name": { - "type": "text", - "fields": { - "keyword": {"type": "keyword"} - } - } - } + properties = {"name": {"type": "text", "fields": {"keyword": {"type": "keyword"}}}} field_defs, copy_fields = translate_opensearch_mapping(properties) @@ -106,15 +92,7 @@ def test_multi_field_with_keyword_suffix(self): def test_multi_field_with_multiple_subfields(self): """Test field with multiple sub-fields.""" - properties = { - "title": { - "type": "text", - "fields": { - "raw": {"type": "keyword"}, - "sort": {"type": "keyword"} - } - } - } + properties = {"title": {"type": "text", "fields": {"raw": {"type": "keyword"}, "sort": {"type": "keyword"}}}} field_defs, copy_fields = translate_opensearch_mapping(properties) diff --git a/tests/unit/solr/test_telemetry.py b/tests/unit/solr/test_telemetry.py index 4f0df25a..2a675762 100644 --- a/tests/unit/solr/test_telemetry.py +++ b/tests/unit/solr/test_telemetry.py @@ -37,13 +37,12 @@ # Helper: captured metrics store # --------------------------------------------------------------------------- + def _make_metrics_store(): """Return a MagicMock metrics store and a dict that captures stored values.""" stored = {} store = MagicMock() - store.put_value_cluster_level = MagicMock( - side_effect=lambda name, value, **kw: stored.update({name: value}) - ) + store.put_value_cluster_level = MagicMock(side_effect=lambda name, value, **kw: stored.update({name: value})) return store, stored @@ -51,6 +50,7 @@ def _make_metrics_store(): # _parse_prometheus_text # --------------------------------------------------------------------------- + class TestParsePrometheusText(unittest.TestCase): def test_basic_metric(self): text = "jvm_heap_used_bytes 1234567\n" @@ -70,10 +70,7 @@ def test_labels_stripped(self): self.assertAlmostEqual(42.0, result["http_requests_total"]) def test_labels_accumulated(self): - text = ( - 'requests_total{status="200"} 100\n' - 'requests_total{status="404"} 10\n' - ) + text = 'requests_total{status="200"} 100\nrequests_total{status="404"} 10\n' result = _parse_prometheus_text(text) self.assertAlmostEqual(110.0, result["requests_total"]) @@ -98,6 +95,7 @@ def test_invalid_value_skipped(self): # SolrTelemetryDevice base class helpers # --------------------------------------------------------------------------- + class TestBaseClassHelpers(unittest.TestCase): def _make_device(self, raw_metrics): client = MagicMock() @@ -150,6 +148,7 @@ def test_get_metric_prometheus_missing(self): # SolrJvmStats # --------------------------------------------------------------------------- + class TestSolrJvmStatsJson(unittest.TestCase): def _device(self, json_data): store, stored = _make_metrics_store() @@ -158,36 +157,48 @@ def _device(self, json_data): return SolrJvmStats(client, store), stored def test_heap_metrics_extracted(self): - data = {"metrics": {"solr.jvm": { - "memory.heap.used": 512_000_000, - "memory.heap.max": 2_000_000_000, - }}} + data = { + "metrics": { + "solr.jvm": { + "memory.heap.used": 512_000_000, + "memory.heap.max": 2_000_000_000, + } + } + } device, stored = self._device(data) device._collect() self.assertEqual(512_000_000, stored["jvm_heap_used_bytes"]) self.assertEqual(2_000_000_000, stored["jvm_heap_max_bytes"]) def test_gc_metrics_summed(self): - data = {"metrics": {"solr.jvm": { - "memory.heap.used": 1, - "memory.heap.max": 2, - "gc.G1-Young-Generation.count": 10, - "gc.G1-Old-Generation.count": 2, - "gc.G1-Young-Generation.time": 150, - "gc.G1-Old-Generation.time": 30, - }}} + data = { + "metrics": { + "solr.jvm": { + "memory.heap.used": 1, + "memory.heap.max": 2, + "gc.G1-Young-Generation.count": 10, + "gc.G1-Old-Generation.count": 2, + "gc.G1-Young-Generation.time": 150, + "gc.G1-Old-Generation.time": 30, + } + } + } device, stored = self._device(data) device._collect() self.assertEqual(12, stored["jvm_gc_count"]) self.assertEqual(180, stored["jvm_gc_time_ms"]) def test_gc_young_old_split(self): - data = {"metrics": {"solr.jvm": { - "gc.G1 Young Generation.count": 10, - "gc.G1 Young Generation.time": 100, - "gc.G1 Old Generation.count": 2, - "gc.G1 Old Generation.time": 50, - }}} + data = { + "metrics": { + "solr.jvm": { + "gc.G1 Young Generation.count": 10, + "gc.G1 Young Generation.time": 100, + "gc.G1 Old Generation.count": 2, + "gc.G1 Old Generation.time": 50, + } + } + } device, stored = self._device(data) device._collect() self.assertEqual(10, stored.get("jvm_gc_young_count")) @@ -196,20 +207,28 @@ def test_gc_young_old_split(self): self.assertEqual(50, stored.get("jvm_gc_old_time_ms")) def test_thread_metrics_extracted(self): - data = {"metrics": {"solr.jvm": { - "threads.count": 42, - "threads.peak.count": 50, - }}} + data = { + "metrics": { + "solr.jvm": { + "threads.count": 42, + "threads.peak.count": 50, + } + } + } device, stored = self._device(data) device._collect() self.assertEqual(42, stored["jvm_thread_count"]) self.assertEqual(50, stored["jvm_thread_peak_count"]) def test_buffer_pool_metrics_extracted(self): - data = {"metrics": {"solr.jvm": { - "buffers.direct.MemoryUsed": 1_048_576, - "buffers.mapped.MemoryUsed": 0, - }}} + data = { + "metrics": { + "solr.jvm": { + "buffers.direct.MemoryUsed": 1_048_576, + "buffers.mapped.MemoryUsed": 0, + } + } + } device, stored = self._device(data) device._collect() self.assertEqual(1_048_576, stored["jvm_buffer_pool_direct_bytes"]) @@ -247,6 +266,7 @@ def test_prometheus_heap_extracted(self): # SolrNodeStats # --------------------------------------------------------------------------- + class TestSolrNodeStats(unittest.TestCase): def _make_client(self, system_data=None, metrics_raw=None): client = MagicMock() @@ -269,10 +289,7 @@ def _get_side_effect(path): def test_cpu_extracted_from_system(self): store, stored = _make_metrics_store() - client = self._make_client( - system_data={"processCpuLoad": 0.45, "freePhysicalMemorySize": 4_000_000_000}, - metrics_raw={} - ) + client = self._make_client(system_data={"processCpuLoad": 0.45, "freePhysicalMemorySize": 4_000_000_000}, metrics_raw={}) device = SolrNodeStats(client, store) device._collect() @@ -281,10 +298,7 @@ def test_cpu_extracted_from_system(self): def test_file_descriptors_extracted(self): store, stored = _make_metrics_store() - client = self._make_client( - system_data={"openFileDescriptorCount": 128, "maxFileDescriptorCount": 65536}, - metrics_raw={} - ) + client = self._make_client(system_data={"openFileDescriptorCount": 128, "maxFileDescriptorCount": 65536}, metrics_raw={}) device = SolrNodeStats(client, store) device._collect() @@ -293,13 +307,18 @@ def test_file_descriptors_extracted(self): def test_query_handler_avg_latency_json(self): store, stored = _make_metrics_store() - metrics_data = {"metrics": {"solr.core": { - "QUERY./select.requests": 500, - "QUERY./select.errors": 3, - "QUERY./select.requestTimes.mean": 12.5, - }, "solr.jetty": { - "org.eclipse.jetty.server.handler.StatisticsHandler.requests": 1000, - }}} + metrics_data = { + "metrics": { + "solr.core": { + "QUERY./select.requests": 500, + "QUERY./select.errors": 3, + "QUERY./select.requestTimes.mean": 12.5, + }, + "solr.jetty": { + "org.eclipse.jetty.server.handler.StatisticsHandler.requests": 1000, + }, + } + } client = self._make_client(system_data={}, metrics_raw=metrics_data) device = SolrNodeStats(client, store) device._collect() @@ -311,10 +330,7 @@ def test_query_handler_avg_latency_json(self): def test_prometheus_format_node_stats(self): store, stored = _make_metrics_store() - prom_text = ( - "solr_metrics_core_query_requests_total 200\n" - "solr_metrics_core_query_errors_total 1\n" - ) + prom_text = "solr_metrics_core_query_requests_total 200\nsolr_metrics_core_query_errors_total 1\n" client = self._make_client(system_data={}, metrics_raw=prom_text) client.get_node_metrics.return_value = prom_text device = SolrNodeStats(client, store) @@ -327,22 +343,17 @@ def test_prometheus_format_node_stats(self): # SolrCollectionStats # --------------------------------------------------------------------------- + class TestSolrCollectionStats(unittest.TestCase): def test_num_docs_extracted_from_properties(self): store, stored = _make_metrics_store() props_resp = MagicMock() props_resp.ok = True - props_resp.json.return_value = { - "core-properties": { - "my-coll_shard1_replica1": {"numDocs": 5000} - } - } + props_resp.json.return_value = {"core-properties": {"my-coll_shard1_replica1": {"numDocs": 5000}}} luke_resp = MagicMock() luke_resp.ok = True - luke_resp.json.return_value = { - "index": {"numDocs": 5000, "deletedDocs": 50, "segmentCount": 3} - } + luke_resp.json.return_value = {"index": {"numDocs": 5000, "deletedDocs": 50, "segmentCount": 3}} def _get_side(path): if "luke" in path: @@ -367,9 +378,7 @@ def test_deleted_docs_extracted_from_luke(self): luke_resp = MagicMock() luke_resp.ok = True - luke_resp.json.return_value = { - "index": {"numDocs": 100, "deletedDocs": 10, "segmentCount": 2} - } + luke_resp.json.return_value = {"index": {"numDocs": 100, "deletedDocs": 10, "segmentCount": 2}} def _get_side(path): if "luke" in path: @@ -399,6 +408,7 @@ def test_luke_stats_on_error_no_crash(self): # SolrQueryStats # --------------------------------------------------------------------------- + class TestSolrQueryStats(unittest.TestCase): def _device(self, raw_metrics): store, stored = _make_metrics_store() @@ -407,14 +417,18 @@ def _device(self, raw_metrics): return SolrQueryStats(client, store), stored def test_latency_percentiles_json(self): - data = {"metrics": {"solr.core": { - "QUERY./select.requestTimes.p_50": 8.0, - "QUERY./select.requestTimes.p_99": 45.0, - "QUERY./select.requestTimes.p_99_9": 120.0, - "QUERY./select.requests": 1000, - "QUERY./select.errors": 5, - "CACHE.searcher.filterCache.hitratio": 0.94, - }}} + data = { + "metrics": { + "solr.core": { + "QUERY./select.requestTimes.p_50": 8.0, + "QUERY./select.requestTimes.p_99": 45.0, + "QUERY./select.requestTimes.p_99_9": 120.0, + "QUERY./select.requests": 1000, + "QUERY./select.errors": 5, + "CACHE.searcher.filterCache.hitratio": 0.94, + } + } + } device, stored = self._device(data) device._collect() @@ -426,19 +440,19 @@ def test_latency_percentiles_json(self): self.assertAlmostEqual(0.94, stored["query_cache_hit_ratio"]) def test_alternate_p999_key(self): - data = {"metrics": {"solr.core": { - "QUERY./select.requestTimes.p_999": 200.0, - }}} + data = { + "metrics": { + "solr.core": { + "QUERY./select.requestTimes.p_999": 200.0, + } + } + } device, stored = self._device(data) device._collect() self.assertAlmostEqual(200.0, stored["query_latency_p999_ms"]) def test_prometheus_latency(self): - prom_text = ( - "solr_metrics_core_query_request_times_p50_ms 10.0\n" - "solr_metrics_core_query_request_times_p99_ms 55.0\n" - "solr_metrics_core_query_requests_total 2000\n" - ) + prom_text = "solr_metrics_core_query_request_times_p50_ms 10.0\nsolr_metrics_core_query_request_times_p99_ms 55.0\nsolr_metrics_core_query_requests_total 2000\n" device, stored = self._device(prom_text) device._collect() @@ -456,6 +470,7 @@ def test_missing_metrics_no_error(self): # SolrIndexingStats # --------------------------------------------------------------------------- + class TestSolrIndexingStats(unittest.TestCase): def _device(self, raw_metrics): store, stored = _make_metrics_store() @@ -464,13 +479,17 @@ def _device(self, raw_metrics): return SolrIndexingStats(client, store), stored def test_indexing_stats_json(self): - data = {"metrics": {"solr.core": { - "UPDATE./update.requests": 500, - "UPDATE./update.errors": 2, - "UPDATE./update.requestTimes.mean": 3.5, - "INDEX.merge.major.running": 0, - "INDEX.merge.minor.running": 1, - }}} + data = { + "metrics": { + "solr.core": { + "UPDATE./update.requests": 500, + "UPDATE./update.errors": 2, + "UPDATE./update.requestTimes.mean": 3.5, + "INDEX.merge.major.running": 0, + "INDEX.merge.minor.running": 1, + } + } + } device, stored = self._device(data) device._collect() @@ -481,11 +500,7 @@ def test_indexing_stats_json(self): self.assertEqual(1, stored["index_merge_minor_running"]) def test_prometheus_indexing_stats(self): - prom_text = ( - "solr_metrics_core_update_requests_total 300\n" - "solr_metrics_core_update_errors_total 0\n" - "solr_metrics_core_index_merge_major_running 2\n" - ) + prom_text = "solr_metrics_core_update_requests_total 300\nsolr_metrics_core_update_errors_total 0\nsolr_metrics_core_index_merge_major_running 2\n" device, stored = self._device(prom_text) device._collect() @@ -503,6 +518,7 @@ def test_missing_core_section_no_error(self): # SolrCacheStats # --------------------------------------------------------------------------- + class TestSolrCacheStats(unittest.TestCase): def _device(self, raw_metrics): store = MagicMock() @@ -518,18 +534,22 @@ def _capture(name, value, unit, task="", operation_type="", meta_data=None): return SolrCacheStats(client, store), stored def test_cache_stats_json(self): - data = {"metrics": {"solr.core": { - "CACHE.searcher.queryResultCache.hits": 8000, - "CACHE.searcher.queryResultCache.inserts": 500, - "CACHE.searcher.queryResultCache.evictions": 100, - "CACHE.searcher.queryResultCache.ramBytesUsed": 1_000_000, - "CACHE.searcher.queryResultCache.hitratio": 0.94, - "CACHE.searcher.filterCache.hits": 5000, - "CACHE.searcher.filterCache.inserts": 200, - "CACHE.searcher.filterCache.evictions": 10, - "CACHE.searcher.filterCache.ramBytesUsed": 500_000, - "CACHE.searcher.filterCache.hitratio": 0.96, - }}} + data = { + "metrics": { + "solr.core": { + "CACHE.searcher.queryResultCache.hits": 8000, + "CACHE.searcher.queryResultCache.inserts": 500, + "CACHE.searcher.queryResultCache.evictions": 100, + "CACHE.searcher.queryResultCache.ramBytesUsed": 1_000_000, + "CACHE.searcher.queryResultCache.hitratio": 0.94, + "CACHE.searcher.filterCache.hits": 5000, + "CACHE.searcher.filterCache.inserts": 200, + "CACHE.searcher.filterCache.evictions": 10, + "CACHE.searcher.filterCache.ramBytesUsed": 500_000, + "CACHE.searcher.filterCache.hitratio": 0.96, + } + } + } device, stored = self._device(data) device._collect() @@ -542,10 +562,7 @@ def test_cache_stats_json(self): self.assertAlmostEqual(0.96, stored["cache_hit_ratio:filterCache"]) def test_prometheus_cache_stats(self): - prom_text = ( - "solr_metrics_core_cache_hits_total 13000\n" - "solr_metrics_core_cache_evictions_total 110\n" - ) + prom_text = "solr_metrics_core_cache_hits_total 13000\nsolr_metrics_core_cache_evictions_total 110\n" device, stored = self._device(prom_text) device._collect() @@ -562,6 +579,7 @@ def test_empty_metrics_no_error(self): # Polling thread lifecycle # --------------------------------------------------------------------------- + class TestTelemetryPollingThread(unittest.TestCase): def test_start_and_stop(self): """Verify that the polling thread starts and stops cleanly.""" @@ -571,9 +589,7 @@ def test_start_and_stop(self): collected = [] metrics_store = MagicMock() - metrics_store.put_value_cluster_level = MagicMock( - side_effect=lambda **kw: collected.append(kw) - ) + metrics_store.put_value_cluster_level = MagicMock(side_effect=lambda **kw: collected.append(kw)) device = SolrJvmStats(client, metrics_store, sample_interval_s=0.05) device.on_benchmark_start() diff --git a/tests/unit/solr/test_workload_converter.py b/tests/unit/solr/test_workload_converter.py index 1927a4cf..c00743bb 100644 --- a/tests/unit/solr/test_workload_converter.py +++ b/tests/unit/solr/test_workload_converter.py @@ -81,10 +81,13 @@ def _make_source_workload(self, tmpdir, workload_dict): def test_renames_indices_to_collections(self): with tempfile.TemporaryDirectory() as src, tempfile.TemporaryDirectory() as dst: - self._make_source_workload(src, { - "indices": [{"name": "my-index"}], - "challenges": [], - }) + self._make_source_workload( + src, + { + "indices": [{"name": "my-index"}], + "challenges": [], + }, + ) result = convert_opensearch_workload(src, dst) self.assertEqual(0, len(result["issues"])) @@ -96,28 +99,31 @@ def test_renames_indices_to_collections(self): def test_renames_operation_types(self): with tempfile.TemporaryDirectory() as src, tempfile.TemporaryDirectory() as dst: - self._make_source_workload(src, { - "indices": [], - "challenges": [ - { - "name": "default", - "schedule": [ - { - "operation": { - "name": "index-docs", - "operation-type": "bulk", + self._make_source_workload( + src, + { + "indices": [], + "challenges": [ + { + "name": "default", + "schedule": [ + { + "operation": { + "name": "index-docs", + "operation-type": "bulk", + }, }, - }, - { - "operation": { - "name": "run-search", - "operation-type": "search", + { + "operation": { + "name": "run-search", + "operation-type": "search", + }, }, - }, - ], - } - ], - }) + ], + } + ], + }, + ) convert_opensearch_workload(src, dst) with open(os.path.join(dst, "workload.json")) as f: out = json.load(f) @@ -127,23 +133,26 @@ def test_renames_operation_types(self): def test_translates_search_body_to_solr_json_dsl(self): with tempfile.TemporaryDirectory() as src, tempfile.TemporaryDirectory() as dst: - self._make_source_workload(src, { - "indices": [], - "challenges": [ - { - "name": "default", - "schedule": [ - { - "operation": { - "name": "search-all", - "operation-type": "search", - "body": {"query": {"match_all": {}}, "size": 10}, + self._make_source_workload( + src, + { + "indices": [], + "challenges": [ + { + "name": "default", + "schedule": [ + { + "operation": { + "name": "search-all", + "operation-type": "search", + "body": {"query": {"match_all": {}}, "size": 10}, + } } - } - ], - } - ], - }) + ], + } + ], + }, + ) convert_opensearch_workload(src, dst) with open(os.path.join(dst, "workload.json")) as f: out = json.load(f) @@ -155,22 +164,25 @@ def test_translates_search_body_to_solr_json_dsl(self): def test_unsupported_ops_are_skipped(self): with tempfile.TemporaryDirectory() as src, tempfile.TemporaryDirectory() as dst: - self._make_source_workload(src, { - "indices": [], - "challenges": [ - { - "name": "default", - "schedule": [ - { - "operation": { - "name": "snap", - "operation-type": "create-snapshot", + self._make_source_workload( + src, + { + "indices": [], + "challenges": [ + { + "name": "default", + "schedule": [ + { + "operation": { + "name": "snap", + "operation-type": "create-snapshot", + } } - } - ], - } - ], - }) + ], + } + ], + }, + ) result = convert_opensearch_workload(src, dst) self.assertIn("snap", result["skipped"]) @@ -235,14 +247,7 @@ def test_sort_is_extracted(self): self.assertIn("desc", result["sort"]) def test_terms_aggregation_converted_to_facet(self): - body = { - "query": {"match_all": {}}, - "aggs": { - "vendors": { - "terms": {"field": "vendor_id", "size": 5} - } - } - } + body = {"query": {"match_all": {}}, "aggs": {"vendors": {"terms": {"field": "vendor_id", "size": 5}}}} result = translate_to_solr_json_dsl(body) self.assertIn("facet", result) facet = result["facet"]["vendors"] @@ -260,7 +265,7 @@ def test_date_histogram_converted_to_range_facet(self): "calendar_interval": "month", } } - } + }, } result = translate_to_solr_json_dsl(body) facet = result["facet"]["pickup_by_month"] @@ -269,10 +274,7 @@ def test_date_histogram_converted_to_range_facet(self): self.assertEqual("+1MONTH", facet["gap"]) def test_avg_metric_aggregation(self): - body = { - "query": {"match_all": {}}, - "aggs": {"avg_fare": {"avg": {"field": "fare_amount"}}} - } + body = {"query": {"match_all": {}}, "aggs": {"avg_fare": {"avg": {"field": "fare_amount"}}}} result = translate_to_solr_json_dsl(body) self.assertEqual("avg(fare_amount)", result["facet"]["avg_fare"]) @@ -294,14 +296,7 @@ def test_empty_returns_empty(self): self.assertEqual({}, _convert_aggregations_to_facets(None)) def test_nested_agg_within_terms(self): - aggs = { - "by_vendor": { - "terms": {"field": "vendor_id", "size": 10}, - "aggs": { - "avg_fare": {"avg": {"field": "fare_amount"}} - } - } - } + aggs = {"by_vendor": {"terms": {"field": "vendor_id", "size": 10}, "aggs": {"avg_fare": {"avg": {"field": "fare_amount"}}}}} result = _convert_aggregations_to_facets(aggs) self.assertIn("by_vendor", result) self.assertIn("facet", result["by_vendor"]) diff --git a/tests/unit/test_telemetry.py b/tests/unit/test_telemetry.py index 618b98b9..5db1e0f9 100644 --- a/tests/unit/test_telemetry.py +++ b/tests/unit/test_telemetry.py @@ -38,6 +38,7 @@ # Helpers # --------------------------------------------------------------------------- + def _make_admin_client(base_url="http://localhost:8983"): """Return a mock SolrAdminClient with a shared session.""" client = MagicMock() @@ -51,9 +52,7 @@ def _make_metrics_store(): """Return a MagicMock metrics store and a dict capturing stored values.""" stored = {} store = MagicMock() - store.put_value_cluster_level = MagicMock( - side_effect=lambda name, value, unit="": stored.update({name: value}) - ) + store.put_value_cluster_level = MagicMock(side_effect=lambda name, value, unit="": stored.update({name: value})) return store, stored @@ -61,8 +60,8 @@ def _make_metrics_store(): # T139: SegmentStats (Luke API) # --------------------------------------------------------------------------- -class TestSegmentStats(unittest.TestCase): +class TestSegmentStats(unittest.TestCase): def _make_response(self, json_data, status_code=200): resp = MagicMock() resp.status_code = status_code @@ -140,17 +139,10 @@ def test_segment_stats_connection_error_graceful(self): } } -CORE_STATUS_RESPONSE = { - "status": { - "my_coll_shard1_replica_n1": { - "index": {"numDocs": 500, "sizeInBytes": 10240} - } - } -} +CORE_STATUS_RESPONSE = {"status": {"my_coll_shard1_replica_n1": {"index": {"numDocs": 500, "sizeInBytes": 10240}}}} class TestShardStats(unittest.TestCase): - def _make_session_resp(self, json_data, status_code=200): resp = MagicMock() resp.status_code = status_code @@ -188,9 +180,7 @@ def test_shard_stats_recorder_emits_metrics(self): metrics_store, stored = _make_metrics_store() admin_client.get_clusterstatus.return_value = CLUSTERSTATUS_RESPONSE - admin_client.get_core_status.return_value = { - "index": {"numDocs": 500, "sizeInBytes": 10240} - } + admin_client.get_core_status.return_value = {"index": {"numDocs": 500, "sizeInBytes": 10240}} recorder = ShardStatsRecorder(admin_client=admin_client, metrics_store=metrics_store, sample_interval=60) recorder.record() @@ -213,6 +203,7 @@ def test_shard_stats_connection_error_graceful(self): def test_shard_stats_invalid_interval(self): """ShardStats raises SystemSetupError for non-positive sample interval.""" from solrorbit.exceptions import SystemSetupError + admin_client, _ = _make_admin_client() metrics_store, _ = _make_metrics_store() with self.assertRaises(SystemSetupError): @@ -235,7 +226,6 @@ def test_shard_stats_invalid_interval(self): class TestClusterEnvironmentInfo(unittest.TestCase): - def _make_session_resp(self, json_data, status_code=200): resp = MagicMock() resp.status_code = status_code @@ -249,9 +239,7 @@ def test_cluster_env_info_stores_version_and_jvm(self): meta_store = {} metrics_store = MagicMock() - metrics_store.add_meta_info = MagicMock( - side_effect=lambda scope, node, key, value: meta_store.update({key: value}) - ) + metrics_store.add_meta_info = MagicMock(side_effect=lambda scope, node, key, value: meta_store.update({key: value})) system_resp = self._make_session_resp(SYSTEM_INFO_RESPONSE) admin_client.raw_request.return_value = system_resp @@ -279,8 +267,8 @@ def test_cluster_env_info_failure_graceful(self): # T142: JVM device pipeline-skip behavior # --------------------------------------------------------------------------- -class TestFlightRecorderPipelineSkip(unittest.TestCase): +class TestFlightRecorderPipelineSkip(unittest.TestCase): def test_jfr_benchmark_only_returns_empty(self): """FlightRecorder returns [] when pipeline is benchmark-only.""" with tempfile.TemporaryDirectory() as log_root: @@ -318,7 +306,6 @@ def test_jfr_no_pipeline_key_returns_flags(self): class TestGcPipelineSkip(unittest.TestCase): - def test_gc_benchmark_only_returns_empty(self): """Gc returns [] when pipeline is benchmark-only.""" with tempfile.TemporaryDirectory() as log_root: @@ -344,7 +331,6 @@ def test_gc_docker_returns_xlog_flag(self): class TestJitCompilerPipelineSkip(unittest.TestCase): - def test_jit_benchmark_only_returns_empty(self): """JitCompiler returns [] when pipeline is benchmark-only.""" with tempfile.TemporaryDirectory() as log_root: @@ -369,7 +355,6 @@ def test_jit_no_telemetry_params(self): class TestHeapdumpDockerSupport(unittest.TestCase): - def test_heapdump_local_calls_jmap(self): """Heapdump calls jmap directly for non-Docker nodes.""" with tempfile.TemporaryDirectory() as log_root: diff --git a/tests/utils/__init__.py b/tests/utils/__init__.py index 5047a451..f5768141 100644 --- a/tests/utils/__init__.py +++ b/tests/utils/__init__.py @@ -16,7 +16,7 @@ # not use this file except in compliance with the License. # You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an diff --git a/tests/utils/collections_test.py b/tests/utils/collections_test.py index 347dd34d..63281420 100644 --- a/tests/utils/collections_test.py +++ b/tests/utils/collections_test.py @@ -16,7 +16,7 @@ # not use this file except in compliance with the License. # You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an @@ -54,13 +54,8 @@ def test_can_merge_nested_dicts(self): d1 = { "params": { "cluster-config-instance": "4gheap", - "cluster-config-instance-params": { - "additional_cluster_settings": { - "indices.queries.cache.size": "5%", - "transport.tcp.compress": True - } - }, - "unique-param": "foobar" + "cluster-config-instance-params": {"additional_cluster_settings": {"indices.queries.cache.size": "5%", "transport.tcp.compress": True}}, + "unique-param": "foobar", } } @@ -69,53 +64,30 @@ def test_can_merge_nested_dicts(self): assert dict(collections.merge_dicts(d1, d2)) == { "params": { "cluster-config-instance-params": { - "additional_cluster_settings": { - "indices.queries.cache.size": "5%", - "transport.tcp.compress": True - }, - "data_paths": "/mnt/local_ssd"}, + "additional_cluster_settings": {"indices.queries.cache.size": "5%", "transport.tcp.compress": True}, + "data_paths": "/mnt/local_ssd", + }, "cluster-config-instance": "4gheap", - "unique-param": "foobar" + "unique-param": "foobar", } } def test_can_merge_nested_lists_in_dicts(self): - d1 = { - "params": { - "foo": [1, 2, 3] - } - } + d1 = {"params": {"foo": [1, 2, 3]}} - d2 = { - "params": { - "foo": [3, 4, 5] - } - } + d2 = {"params": {"foo": [3, 4, 5]}} - assert dict(collections.merge_dicts(d1, d2)) == { - "params": { - "foo": [1, 2, 3, 4, 5] - } - } + assert dict(collections.merge_dicts(d1, d2)) == {"params": {"foo": [1, 2, 3, 4, 5]}} def test_can_merge_nested_booleans_in_dicts(self): - d1 = { - "params": { - "foo": True, - "other": [1, 2, 3] - } - } + d1 = {"params": {"foo": True, "other": [1, 2, 3]}} - d2 = { - "params": { - "foo": False - } - } + d2 = {"params": {"foo": False}} assert dict(collections.merge_dicts(d1, d2)) == { "params": { # d2 wins "foo": False, - "other": [1, 2, 3] + "other": [1, 2, 3], } } diff --git a/tests/utils/console_test.py b/tests/utils/console_test.py index ee15eedc..34cb63c8 100644 --- a/tests/utils/console_test.py +++ b/tests/utils/console_test.py @@ -16,7 +16,7 @@ # not use this file except in compliance with the License. # You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an @@ -77,9 +77,7 @@ def test_println_randomized_dockertrue_or_istty_and_isnotquiet(self, patched_pri console.BENCHMARK_RUNNING_IN_DOCKER = not random_boolean console.println(msg="Unittest message") - patched_print.assert_called_once_with( - "Unittest message", end="\n", flush=False, file=sys.stdout - ) + patched_print.assert_called_once_with("Unittest message", end="\n", flush=False, file=sys.stdout) @mock.patch("sys.stdout.isatty") @mock.patch("builtins.print") @@ -88,9 +86,7 @@ def test_println_randomized_assume_tty_or_istty_and_isnotquiet(self, patched_pri console.init(quiet=False, assume_tty=not random_boolean) patched_isatty.return_value = random_boolean console.println(msg="Unittest message") - patched_print.assert_called_once_with( - "Unittest message", end="\n", flush=False, file=sys.stdout - ) + patched_print.assert_called_once_with("Unittest message", end="\n", flush=False, file=sys.stdout) @mock.patch("sys.stdout.isatty") @mock.patch("builtins.print") @@ -109,9 +105,7 @@ def test_println_force_prints_even_when_quiet(self, patched_print, patched_isatt patched_isatty.return_value = random.choice([True, False]) console.println(msg="Unittest message", force=True) - patched_print.assert_called_once_with( - "Unittest message", end="\n", flush=False, file=sys.stdout - ) + patched_print.assert_called_once_with("Unittest message", end="\n", flush=False, file=sys.stdout) # pytest style class names need to start with Test and don't need to subclass @@ -180,10 +174,7 @@ def test_prints_when_isnotquiet_and_randomized_docker_or_istty(self, patched_isa mock_printer = mock.Mock() progress_publisher = console.CmdLineProgressResultsPublisher(width=width, printer=mock_printer) progress_publisher.print(message=message, progress=".") - mock_printer.assert_has_calls([ - mock.call(" " * width, end=""), - mock.call("\x1b[{}D{}{}.".format(width, message, " "*(width-len(message)-1)), end="") - ]) + mock_printer.assert_has_calls([mock.call(" " * width, end=""), mock.call("\x1b[{}D{}{}.".format(width, message, " " * (width - len(message) - 1)), end="")]) patched_flush.assert_called_once_with() @mock.patch("sys.stdout.flush") @@ -201,10 +192,7 @@ def test_prints_when_isnotquiet_and_randomized_assume_tty_or_istty(self, patched mock_printer = mock.Mock() progress_publisher = console.CmdLineProgressResultsPublisher(width=width, printer=mock_printer) progress_publisher.print(message=message, progress=".") - mock_printer.assert_has_calls([ - mock.call(" " * width, end=""), - mock.call("\x1b[{}D{}{}.".format(width, message, " "*(width-len(message)-1)), end="") - ]) + mock_printer.assert_has_calls([mock.call(" " * width, end=""), mock.call("\x1b[{}D{}{}.".format(width, message, " " * (width - len(message) - 1)), end="")]) patched_flush.assert_called_once_with() @mock.patch("sys.stdout.flush") diff --git a/tests/utils/convert_test.py b/tests/utils/convert_test.py index abe16de5..34fbcd65 100644 --- a/tests/utils/convert_test.py +++ b/tests/utils/convert_test.py @@ -16,7 +16,7 @@ # not use this file except in compliance with the License. # You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an diff --git a/tests/utils/dataset_helper.py b/tests/utils/dataset_helper.py index 03144342..d7b75a86 100644 --- a/tests/utils/dataset_helper.py +++ b/tests/utils/dataset_helper.py @@ -17,7 +17,7 @@ class DataSetBuildContext: - """ Data class capturing information needed to build a particular data set + """Data class capturing information needed to build a particular data set Attributes: data_set_context: Indicator of what the data set is used for, @@ -41,7 +41,7 @@ def get_type(self) -> np.dtype: class DataSetBuilder(ABC): - """ Abstract builder used to create a build a collection of data sets + """Abstract builder used to create a build a collection of data sets Attributes: data_set_build_contexts: list of data set build contexts that builder @@ -52,7 +52,7 @@ def __init__(self): self.data_set_build_contexts = list() def add_data_set_build_context(self, data_set_build_context: DataSetBuildContext): - """ Adds a data set build context to list of contexts to be built. + """Adds a data set build context to list of contexts to be built. Args: data_set_build_context: DataSetBuildContext to be added to list @@ -65,14 +65,13 @@ def add_data_set_build_context(self, data_set_build_context: DataSetBuildContext return self def build(self): - """ Builds and serializes all data sets build contexts - """ + """Builds and serializes all data sets build contexts""" for data_set_build_context in self.data_set_build_contexts: self._build_data_set(data_set_build_context) @abstractmethod def _build_data_set(self, context: DataSetBuildContext): - """ Builds an individual data set + """Builds an individual data set Args: context: DataSetBuildContext of data set to be built @@ -80,7 +79,7 @@ def _build_data_set(self, context: DataSetBuildContext): @abstractmethod def _validate_data_set_context(self, context: DataSetBuildContext): - """ Validates that data set context can be added to this builder + """Validates that data set context can be added to this builder Args: context: DataSetBuildContext to be validated @@ -88,29 +87,23 @@ def _validate_data_set_context(self, context: DataSetBuildContext): class HDF5Builder(DataSetBuilder): - def __init__(self): super().__init__() self.data_set_meta_data = dict() def _validate_data_set_context(self, context: DataSetBuildContext): if context.path not in self.data_set_meta_data.keys(): - self.data_set_meta_data[context.path] = { - context.data_set_context: context - } + self.data_set_meta_data[context.path] = {context.data_set_context: context} return - if context.data_set_context in \ - self.data_set_meta_data[context.path].keys(): - raise IllegalDataSetBuildContext("Path and context for data set " - "are already present in builder.") + if context.data_set_context in self.data_set_meta_data[context.path].keys(): + raise IllegalDataSetBuildContext("Path and context for data set are already present in builder.") - self.data_set_meta_data[context.path][context.data_set_context] = \ - context + self.data_set_meta_data[context.path][context.data_set_context] = context @staticmethod def _validate_extension(context: DataSetBuildContext): - ext = context.path.split('.')[-1] + ext = context.path.split(".")[-1] if ext != HDF5DataSet.FORMAT_NAME: raise IllegalDataSetBuildContext("Invalid file extension") @@ -119,70 +112,56 @@ def _build_data_set(self, context: DataSetBuildContext): # For HDF5, because multiple data sets can be grouped in the same file, # we will build data sets in memory and not write to disk until # _flush_data_sets_to_disk is called - with h5py.File(context.path, 'a') as hf: - hf.create_dataset( - HDF5DataSet.parse_context(context.data_set_context), - data=context.vectors - ) + with h5py.File(context.path, "a") as hf: + hf.create_dataset(HDF5DataSet.parse_context(context.data_set_context), data=context.vectors) class BigANNVectorBuilder(DataSetBuilder): - def _validate_data_set_context(self, context: DataSetBuildContext): self._validate_extension(context) # prevent the duplication of paths for data sets data_set_paths = [c.path for c in self.data_set_build_contexts] if any(data_set_paths.count(x) > 1 for x in data_set_paths): - raise IllegalDataSetBuildContext("Build context paths have to be " - "unique.") + raise IllegalDataSetBuildContext("Build context paths have to be unique.") @staticmethod def _validate_extension(context: DataSetBuildContext): - ext = context.path.split('.')[-1] + ext = context.path.split(".")[-1] if ext not in [BigANNVectorDataSet.U8BIN_EXTENSION, BigANNVectorDataSet.FBIN_EXTENSION]: raise IllegalDataSetBuildContext("Invalid file extension: {}".format(ext)) - if ext == BigANNVectorDataSet.U8BIN_EXTENSION and context.get_type() != \ - np.uint8: - raise IllegalDataSetBuildContext("Invalid data type for {} ext." - .format(BigANNVectorDataSet - .U8BIN_EXTENSION)) + if ext == BigANNVectorDataSet.U8BIN_EXTENSION and context.get_type() != np.uint8: + raise IllegalDataSetBuildContext("Invalid data type for {} ext.".format(BigANNVectorDataSet.U8BIN_EXTENSION)) - if ext == BigANNVectorDataSet.FBIN_EXTENSION and context.get_type() != \ - np.float32: - raise IllegalDataSetBuildContext("Invalid data type for {} ext." - .format(BigANNVectorDataSet - .FBIN_EXTENSION)) + if ext == BigANNVectorDataSet.FBIN_EXTENSION and context.get_type() != np.float32: + raise IllegalDataSetBuildContext("Invalid data type for {} ext.".format(BigANNVectorDataSet.FBIN_EXTENSION)) def _build_data_set(self, context: DataSetBuildContext): num_vectors = context.get_num_rows() dimension = context.get_row_length() - with open(context.path, 'wb') as f: + with open(context.path, "wb") as f: f.write(int.to_bytes(num_vectors, 4, "little")) f.write(int.to_bytes(dimension, 4, "little")) context.vectors.tofile(f) class BigANNGroundTruthBuilder(BigANNVectorBuilder): - @staticmethod def _validate_extension(context: DataSetBuildContext): - ext = context.path.split('.')[-1] + ext = context.path.split(".")[-1] if ext not in [BigANNGroundTruthDataSet.BIN_EXTENSION]: raise IllegalDataSetBuildContext("Invalid file extension: {}".format(ext)) if context.get_type() != np.float32: - raise IllegalDataSetBuildContext("Invalid data type for {} ext." - .format(BigANNGroundTruthDataSet - .BIN_EXTENSION)) + raise IllegalDataSetBuildContext("Invalid data type for {} ext.".format(BigANNGroundTruthDataSet.BIN_EXTENSION)) def _build_data_set(self, context: DataSetBuildContext): num_queries = context.get_num_rows() k = context.get_row_length() - with open(context.path, 'wb') as f: + with open(context.path, "wb") as f: # Writing number of queries f.write(int.to_bytes(num_queries, 4, "little")) # Writing number of neighbors in a query @@ -214,6 +193,7 @@ def create_attributes(num_vectors: int) -> np.ndarray: return random_vector + def create_parent_ids(num_vectors: int, group_size: int = 10) -> np.ndarray: num_ids = (num_vectors + group_size - 1) // group_size # Calculate total number of different IDs needed ids = np.arange(1, num_ids + 1) # Create an array of IDs starting from 1 @@ -234,29 +214,18 @@ class IllegalDataSetBuildContext(Exception): """ def __init__(self, message: str): - self.message = f'{message}' + self.message = f"{message}" super().__init__(self.message) -def create_data_set( - num_vectors: int, - dimension: int, - extension: str, - data_set_context: Context, - data_set_dir, - file_path: str = None -) -> str: +def create_data_set(num_vectors: int, dimension: int, extension: str, data_set_context: Context, data_set_dir, file_path: str = None) -> str: if file_path: data_set_path = file_path else: - file_name_base = ''.join(random.choice(string.ascii_letters) for _ in - range(DEFAULT_RANDOM_STRING_LENGTH)) + file_name_base = "".join(random.choice(string.ascii_letters) for _ in range(DEFAULT_RANDOM_STRING_LENGTH)) data_set_file_name = "{}.{}".format(file_name_base, extension) data_set_path = os.path.join(data_set_dir, data_set_file_name) - context = DataSetBuildContext( - data_set_context, - create_random_2d_array(num_vectors, dimension), - data_set_path) + context = DataSetBuildContext(data_set_context, create_random_2d_array(num_vectors, dimension), data_set_path) if extension == HDF5DataSet.FORMAT_NAME: HDF5Builder().add_data_set_build_context(context).build() @@ -266,25 +235,14 @@ def create_data_set( return data_set_path -def create_attributes_data_set( - num_vectors: int, - dimension: int, - extension: str, - data_set_context: Context, - data_set_dir, - file_path: str = None -) -> str: +def create_attributes_data_set(num_vectors: int, dimension: int, extension: str, data_set_context: Context, data_set_dir, file_path: str = None) -> str: if file_path: data_set_path = file_path else: - file_name_base = ''.join(random.choice(string.ascii_letters) for _ in - range(DEFAULT_RANDOM_STRING_LENGTH)) + file_name_base = "".join(random.choice(string.ascii_letters) for _ in range(DEFAULT_RANDOM_STRING_LENGTH)) data_set_file_name = "{}.{}".format(file_name_base, extension) data_set_path = os.path.join(data_set_dir, data_set_file_name) - context = DataSetBuildContext( - data_set_context, - create_attributes(num_vectors), - data_set_path) + context = DataSetBuildContext(data_set_context, create_attributes(num_vectors), data_set_path) if extension == HDF5DataSet.FORMAT_NAME: HDF5Builder().add_data_set_build_context(context).build() @@ -294,25 +252,14 @@ def create_attributes_data_set( return data_set_path -def create_parent_data_set( - num_vectors: int, - dimension: int, - extension: str, - data_set_context: Context, - data_set_dir, - file_path: str = None -) -> str: +def create_parent_data_set(num_vectors: int, dimension: int, extension: str, data_set_context: Context, data_set_dir, file_path: str = None) -> str: if file_path: data_set_path = file_path else: - file_name_base = ''.join(random.choice(string.ascii_letters) for _ in - range(DEFAULT_RANDOM_STRING_LENGTH)) + file_name_base = "".join(random.choice(string.ascii_letters) for _ in range(DEFAULT_RANDOM_STRING_LENGTH)) data_set_file_name = "{}.{}".format(file_name_base, extension) data_set_path = os.path.join(data_set_dir, data_set_file_name) - context = DataSetBuildContext( - data_set_context, - create_parent_ids(num_vectors), - data_set_path) + context = DataSetBuildContext(data_set_context, create_parent_ids(num_vectors), data_set_path) if extension == HDF5DataSet.FORMAT_NAME: HDF5Builder().add_data_set_build_context(context).build() @@ -322,26 +269,14 @@ def create_parent_data_set( return data_set_path - -def create_ground_truth( - num_queries: int, - k: int, - extension: str, - data_set_context: Context, - data_set_dir, - file_path: str = None -) -> str: +def create_ground_truth(num_queries: int, k: int, extension: str, data_set_context: Context, data_set_dir, file_path: str = None) -> str: if file_path: data_set_path = file_path else: - file_name_base = ''.join(random.choice(string.ascii_letters) for _ in - range(DEFAULT_RANDOM_STRING_LENGTH)) + file_name_base = "".join(random.choice(string.ascii_letters) for _ in range(DEFAULT_RANDOM_STRING_LENGTH)) data_set_file_name = "{}.{}".format(file_name_base, extension) data_set_path = os.path.join(data_set_dir, data_set_file_name) - context = DataSetBuildContext( - data_set_context, - create_random_2d_array(num_queries, k), - data_set_path) + context = DataSetBuildContext(data_set_context, create_random_2d_array(num_queries, k), data_set_path) BigANNGroundTruthBuilder().add_data_set_build_context(context).build() return data_set_path diff --git a/tests/utils/dataset_test.py b/tests/utils/dataset_test.py index f595c52b..b2d9fe24 100644 --- a/tests/utils/dataset_test.py +++ b/tests/utils/dataset_test.py @@ -19,16 +19,9 @@ class DataSetTestCase(TestCase): - def testHDF5AsAcceptableDataSetFormat(self): with tempfile.TemporaryDirectory() as data_set_dir: - valid_data_set_path = create_data_set( - DEFAULT_NUM_VECTORS, - DEFAULT_DIMENSION, - HDF5DataSet.FORMAT_NAME, - DEFAULT_CONTEXT, - data_set_dir - ) + valid_data_set_path = create_data_set(DEFAULT_NUM_VECTORS, DEFAULT_DIMENSION, HDF5DataSet.FORMAT_NAME, DEFAULT_CONTEXT, data_set_dir) data_set_instance = get_data_set("hdf5", valid_data_set_path, Context.INDEX) self.assertEqual(data_set_instance.FORMAT_NAME, HDF5DataSet.FORMAT_NAME) self.assertEqual(data_set_instance.size(), DEFAULT_NUM_VECTORS) @@ -37,13 +30,7 @@ def testBigANNAsAcceptableDataSetFormatWithFloatExtension(self): float_extension = "fbin" data_set_dir = tempfile.mkdtemp() - valid_data_set_path = create_data_set( - DEFAULT_NUM_VECTORS, - DEFAULT_DIMENSION, - float_extension, - DEFAULT_CONTEXT, - data_set_dir - ) + valid_data_set_path = create_data_set(DEFAULT_NUM_VECTORS, DEFAULT_DIMENSION, float_extension, DEFAULT_CONTEXT, data_set_dir) data_set_instance = get_data_set("bigann", valid_data_set_path, Context.INDEX) self.assertEqual(data_set_instance.FORMAT_NAME, BigANNVectorDataSet.FORMAT_NAME) self.assertEqual(data_set_instance.size(), DEFAULT_NUM_VECTORS) @@ -52,13 +39,7 @@ def testBigANNGroundTruthAsAcceptableDataSetFormat(self): bin_extension = "bin" data_set_dir = tempfile.mkdtemp() - valid_data_set_path = create_ground_truth( - 100, - 10, - bin_extension, - Context.NEIGHBORS, - data_set_dir - ) + valid_data_set_path = create_ground_truth(100, 10, bin_extension, Context.NEIGHBORS, data_set_dir) data_set_instance = get_data_set("bigann", valid_data_set_path, Context.NEIGHBORS) self.assertEqual(data_set_instance.FORMAT_NAME, BigANNVectorDataSet.FORMAT_NAME) self.assertEqual(data_set_instance.size(), 100) diff --git a/tests/utils/git_test.py b/tests/utils/git_test.py index 8ffca822..f7e81018 100644 --- a/tests/utils/git_test.py +++ b/tests/utils/git_test.py @@ -16,7 +16,7 @@ # not use this file except in compliance with the License. # You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an @@ -45,8 +45,7 @@ def test_version_too_old(self, run_subprocess_with_out_and_err): run_subprocess_with_out_and_err.return_value = ("git version 1.4.0", None, 0) with self.assertRaises(exceptions.SystemSetupError) as ctx: git.head_revision("/src") - self.assertEqual("solr-orbit requires at least version 2 of git. You have git version 1.4.0. Please update git.", - ctx.exception.args[0]) + self.assertEqual("solr-orbit requires at least version 2 of git. You have git version 1.4.0. Please update git.", ctx.exception.args[0]) run_subprocess_with_out_and_err.assert_called_with("git --version") @mock.patch("solrorbit.utils.io.ensure_dir") @@ -121,10 +120,7 @@ def test_rebase(self, run_subprocess_with_out_and_err, run_subprocess_with_loggi run_subprocess_with_out_and_err.return_value = ("git version 2.4.0", None, 0) run_subprocess_with_logging.return_value = 0 git.rebase("/src", remote="my-origin", branch="feature-branch") - calls = [ - mock.call("git -C /src checkout feature-branch"), - mock.call("git -C /src rebase my-origin/feature-branch") - ] + calls = [mock.call("git -C /src checkout feature-branch"), mock.call("git -C /src rebase my-origin/feature-branch")] run_subprocess_with_logging.assert_has_calls(calls) @mock.patch("solrorbit.utils.process.run_subprocess_with_logging") @@ -133,33 +129,31 @@ def test_pull(self, run_subprocess_with_out_and_err, run_subprocess_with_logging run_subprocess_with_out_and_err.return_value = ("git version 2.4.0", None, 0) run_subprocess_with_logging.return_value = 0 git.pull("/src", remote="my-origin", branch="feature-branch") - run_subprocess_with_out_and_err.assert_has_calls([ - # pull, fetch, rebase, checkout - mock.call("git --version") - ] * 4) + run_subprocess_with_out_and_err.assert_has_calls( + [ + # pull, fetch, rebase, checkout + mock.call("git --version") + ] + * 4 + ) calls = [ mock.call("git -C /src fetch --prune --tags my-origin"), mock.call("git -C /src checkout feature-branch"), - mock.call("git -C /src rebase my-origin/feature-branch") + mock.call("git -C /src rebase my-origin/feature-branch"), ] run_subprocess_with_logging.assert_has_calls(calls) @mock.patch("solrorbit.utils.process.run_subprocess_with_output") @mock.patch("solrorbit.utils.process.run_subprocess_with_logging") @mock.patch("solrorbit.utils.process.run_subprocess_with_out_and_err") - def test_pull_ts(self, run_subprocess_with_out_and_err, run_subprocess_with_logging, - run_subprocess_with_output): + def test_pull_ts(self, run_subprocess_with_out_and_err, run_subprocess_with_logging, run_subprocess_with_output): run_subprocess_with_out_and_err.return_value = ("git version 2.4.0", None, 0) run_subprocess_with_logging.return_value = 0 run_subprocess_with_output.return_value = ["3694a07"] git.pull_ts("/src", "20160101T110000Z") - run_subprocess_with_output.assert_called_with( - "git -C /src rev-list -n 1 --before=\"20160101T110000Z\" --date=iso8601 origin/main") - run_subprocess_with_logging.assert_has_calls([ - mock.call("git -C /src fetch --prune --tags origin"), - mock.call("git -C /src checkout 3694a07") - ]) + run_subprocess_with_output.assert_called_with('git -C /src rev-list -n 1 --before="20160101T110000Z" --date=iso8601 origin/main') + run_subprocess_with_logging.assert_has_calls([mock.call("git -C /src fetch --prune --tags origin"), mock.call("git -C /src checkout 3694a07")]) @mock.patch("solrorbit.utils.process.run_subprocess_with_logging") @mock.patch("solrorbit.utils.process.run_subprocess_with_out_and_err") @@ -167,10 +161,12 @@ def test_pull_revision(self, run_subprocess_with_out_and_err, run_subprocess_wit run_subprocess_with_out_and_err.return_value = ("git version 2.4.0", None, 0) run_subprocess_with_logging.return_value = 0 git.pull_revision("/src", "3694a07") - run_subprocess_with_logging.assert_has_calls([ - mock.call("git -C /src fetch --prune --tags origin"), - mock.call("git -C /src checkout 3694a07"), - ]) + run_subprocess_with_logging.assert_has_calls( + [ + mock.call("git -C /src fetch --prune --tags origin"), + mock.call("git -C /src checkout 3694a07"), + ] + ) @mock.patch("solrorbit.utils.process.run_subprocess_with_output") @mock.patch("solrorbit.utils.process.run_subprocess_with_out_and_err") @@ -184,10 +180,7 @@ def test_head_revision(self, run_subprocess_with_out_and_err, run_subprocess_wit @mock.patch("solrorbit.utils.process.run_subprocess_with_out_and_err") def test_list_remote_branches(self, run_subprocess_with_out_and_err, run_subprocess): run_subprocess_with_out_and_err.return_value = ("git version 2.4.0", None, 0) - run_subprocess.return_value = [" origin/HEAD", - " origin/main", - " origin/5.0.0-alpha1", - " origin/5"] + run_subprocess.return_value = [" origin/HEAD", " origin/main", " origin/5.0.0-alpha1", " origin/5"] self.assertEqual(["main", "5.0.0-alpha1", "5"], git.branches("/src", remote=True)) run_subprocess.assert_called_with("git -C /src for-each-ref refs/remotes/ --format='%(refname)'") @@ -195,10 +188,7 @@ def test_list_remote_branches(self, run_subprocess_with_out_and_err, run_subproc @mock.patch("solrorbit.utils.process.run_subprocess_with_out_and_err") def test_list_local_branches(self, run_subprocess_with_out_and_err, run_subprocess): run_subprocess_with_out_and_err.return_value = ("git version 2.4.0", None, 0) - run_subprocess.return_value = [" HEAD", - " main", - " 5.0.0-alpha1", - " 5"] + run_subprocess.return_value = [" HEAD", " main", " 5.0.0-alpha1", " 5"] self.assertEqual(["main", "5.0.0-alpha1", "5"], git.branches("/src", remote=False)) run_subprocess.assert_called_with("git -C /src for-each-ref refs/heads/ --format='%(refname:short)'") @@ -206,8 +196,7 @@ def test_list_local_branches(self, run_subprocess_with_out_and_err, run_subproce @mock.patch("solrorbit.utils.process.run_subprocess_with_out_and_err") def test_list_tags_with_tags_present(self, run_subprocess_with_out_and_err, run_subprocess): run_subprocess_with_out_and_err.return_value = ("git version 2.4.0", None, 0) - run_subprocess.return_value = [" v1", - " v2"] + run_subprocess.return_value = [" v1", " v2"] self.assertEqual(["v1", "v2"], git.tags("/src")) run_subprocess.assert_called_with("git -C /src tag") diff --git a/tests/utils/io_test.py b/tests/utils/io_test.py index 9cfea258..3474def2 100644 --- a/tests/utils/io_test.py +++ b/tests/utils/io_test.py @@ -16,7 +16,7 @@ # not use this file except in compliance with the License. # You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an @@ -42,7 +42,7 @@ def mock_debian(args, fallback=None): "/usr/lib/jvm/java-7-openjdk-amd64/jre/bin/java", "/usr/lib/jvm/java-7-oracle/jre/bin/java", "/usr/lib/jvm/java-8-oracle/jre/bin/java", - "/usr/lib/jvm/java-9-openjdk-amd64/bin/java" + "/usr/lib/jvm/java-9-openjdk-amd64/bin/java", ] else: return fallback @@ -86,10 +86,10 @@ def test_decompresses_supported_file_formats(self): io.decompress(archive_path, target_directory=tmp_dir) - assert os.path.exists(decompressed_path) is True,\ - f"Could not decompress [{archive_path}] to [{decompressed_path}] (target file does not exist)" - assert self.read(decompressed_path) == "Sample text for DecompressionTests\n",\ + assert os.path.exists(decompressed_path) is True, f"Could not decompress [{archive_path}] to [{decompressed_path}] (target file does not exist)" + assert self.read(decompressed_path) == "Sample text for DecompressionTests\n", ( f"Could not decompress [{archive_path}] to [{decompressed_path}] (target file is corrupt)" + ) @mock.patch.object(io, "is_executable", return_value=False) def test_decompresses_supported_file_formats_with_lib_as_failover(self, mocked_is_executable): @@ -102,10 +102,10 @@ def test_decompresses_supported_file_formats_with_lib_as_failover(self, mocked_i with mock.patch.object(logger, "warning") as mocked_console_warn: io.decompress(archive_path, target_directory=tmp_dir) - assert os.path.exists(decompressed_path) is True,\ - f"Could not decompress [{archive_path}] to [{decompressed_path}] (target file does not exist)" - assert self.read(decompressed_path) == "Sample text for DecompressionTests\n",\ + assert os.path.exists(decompressed_path) is True, f"Could not decompress [{archive_path}] to [{decompressed_path}] (target file does not exist)" + assert self.read(decompressed_path) == "Sample text for DecompressionTests\n", ( f"Could not decompress [{archive_path}] to [{decompressed_path}] (target file is corrupt)" + ) if ext in ["bz2", "gz"]: assert "not found in PATH. Using standard library, decompression will take longer." in mocked_console_warn.call_args[0][0] @@ -129,5 +129,5 @@ def test_decompress_manually_external_fails_if_tool_missing(self, mocked_run): assert result is False def read(self, f): - with open(f, 'r') as content_file: + with open(f, "r") as content_file: return content_file.read() diff --git a/tests/utils/jvm_test.py b/tests/utils/jvm_test.py index 2a0087b7..b4c6fffa 100644 --- a/tests/utils/jvm_test.py +++ b/tests/utils/jvm_test.py @@ -16,7 +16,7 @@ # not use this file except in compliance with the License. # You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an @@ -83,16 +83,14 @@ def test_resolve_path_for_one_version_no_matching_version(self, getenv): # JAVA8_HOME, JAVA_HOME getenv.side_effect = [None, "/opt/jdks/jdk/1.7"] - with self.assertRaisesRegex(expected_exception=exceptions.SystemSetupError, - expected_regex="JAVA_HOME points to JDK 7 but it should point to JDK 8."): + with self.assertRaisesRegex(expected_exception=exceptions.SystemSetupError, expected_regex="JAVA_HOME points to JDK 7 but it should point to JDK 8."): jvm.resolve_path(majors=8, sysprop_reader=self.path_based_prop_version_reader) @mock.patch("os.getenv") def test_resolve_path_for_one_version_no_env_vars_defined(self, getenv): getenv.return_value = None - with self.assertRaisesRegex(expected_exception=exceptions.SystemSetupError, - expected_regex="Neither JAVA8_HOME nor JAVA_HOME point to a JDK 8 installation."): + with self.assertRaisesRegex(expected_exception=exceptions.SystemSetupError, expected_regex="Neither JAVA8_HOME nor JAVA_HOME point to a JDK 8 installation."): jvm.resolve_path(majors=8, sysprop_reader=self.path_based_prop_version_reader) @mock.patch("os.getenv") diff --git a/tests/utils/net_test.py b/tests/utils/net_test.py index 53585e1d..59bc6897 100644 --- a/tests/utils/net_test.py +++ b/tests/utils/net_test.py @@ -16,7 +16,7 @@ # not use this file except in compliance with the License. # You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an @@ -41,10 +41,8 @@ def test_download_from_s3_bucket(self, download, seed): expected_size = random.choice([None, random.randint(0, 1000)]) progress_indicator = random.choice([None, "some progress indicator"]) - net.download_from_bucket("s3", "s3://mybucket.opensearch.org/data/documents.json.bz2", "/tmp/documents.json.bz2", - expected_size, progress_indicator) - download.assert_called_once_with("mybucket.opensearch.org", "data/documents.json.bz2", - "/tmp/documents.json.bz2", expected_size, progress_indicator) + net.download_from_bucket("s3", "s3://mybucket.opensearch.org/data/documents.json.bz2", "/tmp/documents.json.bz2", expected_size, progress_indicator) + download.assert_called_once_with("mybucket.opensearch.org", "data/documents.json.bz2", "/tmp/documents.json.bz2", expected_size, progress_indicator) @mock.patch("solrorbit.utils.console.error") @mock.patch("solrorbit.utils.net._fake_import_boto3") @@ -52,9 +50,7 @@ def test_missing_boto3(self, import_boto3, console_error): import_boto3.side_effect = ImportError("no module named 'boto3'") with pytest.raises(ImportError, match="no module named 'boto3'"): net.download_from_bucket("s3", "s3://mybucket/data", "/tmp/data", None, None) - console_error.assert_called_once_with( - "S3 support is optional. Install it with `python -m pip install solr-orbit[s3]`" - ) + console_error.assert_called_once_with("S3 support is optional. Install it with `python -m pip install solr-orbit[s3]`") @pytest.mark.parametrize("seed", range(1)) @mock.patch("solrorbit.utils.net._download_from_gcs_bucket") @@ -63,30 +59,25 @@ def test_download_from_gs_bucket(self, download, seed): expected_size = random.choice([None, random.randint(0, 1000)]) progress_indicator = random.choice([None, "some progress indicator"]) - net.download_from_bucket("gs", "gs://unittest-gcp-bucket.test.org/data/documents.json.bz2", "/tmp/documents.json.bz2", - expected_size, progress_indicator) - download.assert_called_once_with("unittest-gcp-bucket.test.org", "data/documents.json.bz2", - "/tmp/documents.json.bz2", expected_size, progress_indicator) + net.download_from_bucket("gs", "gs://unittest-gcp-bucket.test.org/data/documents.json.bz2", "/tmp/documents.json.bz2", expected_size, progress_indicator) + download.assert_called_once_with("unittest-gcp-bucket.test.org", "data/documents.json.bz2", "/tmp/documents.json.bz2", expected_size, progress_indicator) @pytest.mark.parametrize("seed", range(40)) def test_gcs_object_url(self, seed): random.seed(seed) - bucket_name = random.choice(["unittest-bucket.test.me", "/unittest-bucket.test.me", - "/unittest-bucket.test.me/", "unittest-bucket.test.me/"]) - bucket_path = random.choice(["path/to/object", "/path/to/object", - "/path/to/object/", "path/to/object/"]) + bucket_name = random.choice(["unittest-bucket.test.me", "/unittest-bucket.test.me", "/unittest-bucket.test.me/", "unittest-bucket.test.me/"]) + bucket_path = random.choice(["path/to/object", "/path/to/object", "/path/to/object/", "path/to/object/"]) # pylint: disable=protected-access - assert net._build_gcs_object_url(bucket_name, bucket_path) == \ - "https://storage.googleapis.com/storage/v1/b/unittest-bucket.test.me/o/path%2Fto%2Fobject?alt=media" + assert net._build_gcs_object_url(bucket_name, bucket_path) == "https://storage.googleapis.com/storage/v1/b/unittest-bucket.test.me/o/path%2Fto%2Fobject?alt=media" def test_add_url_param_encoding_and_update(self): url = "https://artifacts.opensearch.org/releases/bundle/opensearch/1.0.0/opensearch-1.0.0-darwin-x64.tar.gz?flag1=true" params = {"flag1": "test me", "flag2": "test@me"} # pylint: disable=protected-access - assert net._add_url_param(url, params) == \ - ("https://artifacts.opensearch.org/releases/bundle/opensearch/"\ - "1.0.0/opensearch-1.0.0-darwin-x64.tar.gz?flag1=test+me&flag2=test%40me") + assert net._add_url_param(url, params) == ( + "https://artifacts.opensearch.org/releases/bundle/opensearch/1.0.0/opensearch-1.0.0-darwin-x64.tar.gz?flag1=test+me&flag2=test%40me" + ) def test_progress(self): progress = net.Progress("test") diff --git a/tests/utils/opts_test.py b/tests/utils/opts_test.py index 7e3466da..ff7a5225 100644 --- a/tests/utils/opts_test.py +++ b/tests/utils/opts_test.py @@ -16,7 +16,7 @@ # not use this file except in compliance with the License. # You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an @@ -45,36 +45,22 @@ def test_kv_to_map(self): self.assertEqual({"k": 3}, opts.kv_to_map(["k:3"])) # implicit treatment as string self.assertEqual({"k": "v"}, opts.kv_to_map(["k:v"])) - self.assertEqual({"k": "v", "size": 4, "empty": False, "temperature": 0.5}, - opts.kv_to_map(["k:'v'", "size:4", "empty:false", "temperature:0.5"])) + self.assertEqual({"k": "v", "size": 4, "empty": False, "temperature": 0.5}, opts.kv_to_map(["k:'v'", "size:4", "empty:false", "temperature:0.5"])) class GenericHelperFunctionTests(TestCase): def test_list_as_bulleted_list(self): src_list = ["param-1", "param-2", "a_longer-parameter"] - self.assertEqual( - ["- param-1", "- param-2", "- a_longer-parameter"], - opts.bulleted_list_of(src_list) - ) + self.assertEqual(["- param-1", "- param-2", "- a_longer-parameter"], opts.bulleted_list_of(src_list)) def test_list_as_double_quoted_list(self): src_list = ["oneitem", "_another-weird_item", "param-3"] - self.assertEqual( - opts.double_quoted_list_of(src_list), - ['"oneitem"', '"_another-weird_item"', '"param-3"'] - ) + self.assertEqual(opts.double_quoted_list_of(src_list), ['"oneitem"', '"_another-weird_item"', '"param-3"']) def test_make_list_of_close_matches(self): - word_list = [ - "bulk_indexing_clients", - "bulk_indexing_iterations", - "target_throughput", - "bulk_size", - "number_of-shards", - "number_of_replicas", - "index_refresh_interval"] + word_list = ["bulk_indexing_clients", "bulk_indexing_iterations", "target_throughput", "bulk_size", "number_of-shards", "number_of_replicas", "index_refresh_interval"] available_word_list = [ "bulk_indexing_clients", @@ -117,206 +103,161 @@ def test_make_list_of_close_matches(self): "shard_sizing_queries", "source_enabled", "target_throughput", - "translog_sync"] + "translog_sync", + ] self.assertEqual( - ['bulk_indexing_clients', - 'bulk_indexing_iterations', - 'target_throughput', - 'bulk_size', - # number_of-shards had a typo - 'number_of_shards', - 'number_of_replicas', - 'index_refresh_interval'], - opts.make_list_of_close_matches(word_list, available_word_list) + [ + "bulk_indexing_clients", + "bulk_indexing_iterations", + "target_throughput", + "bulk_size", + # number_of-shards had a typo + "number_of_shards", + "number_of_replicas", + "index_refresh_interval", + ], + opts.make_list_of_close_matches(word_list, available_word_list), ) def test_make_list_of_close_matches_returns_with_empty_word_list(self): - self.assertEqual( - [], - opts.make_list_of_close_matches([], ["number_of_shards"]) - ) + self.assertEqual([], opts.make_list_of_close_matches([], ["number_of_shards"])) def test_make_list_of_close_matches_returns_empty_list_with_no_close_matches(self): - self.assertEqual( - [], - opts.make_list_of_close_matches( - ["number_of_shards", "number_of-replicas"], - []) - ) + self.assertEqual([], opts.make_list_of_close_matches(["number_of_shards", "number_of-replicas"], [])) def test_to_dict_str(self): json = '{ "field": 2 }' rsl = opts.to_dict(json) self.assertEqual(rsl, {"field": 2}) - rsl = opts.to_dict('a:1,b:2') - self.assertEqual(rsl, {'a': 1, 'b': 2}) + rsl = opts.to_dict("a:1,b:2") + self.assertEqual(rsl, {"a": 1, "b": 2}) @mock.patch("json.loads") def test_to_dict_file(self, json_loads): mo = mock.mock_open() with mock.patch("builtins.open", mo): - opts.to_dict('params.json') + opts.to_dict("params.json") mo.assert_called() mo = mock.mock_open() with mock.patch("builtins.open", mo): - opts.to_dict('index_body:idx.json') + opts.to_dict("index_body:idx.json") mo.assert_not_called() class TestTargetHosts(TestCase): def test_empty_arg_parses_as_empty_list(self): - self.assertEqual([], opts.TargetHosts('').default) - self.assertEqual({'default': []}, opts.TargetHosts('').all_hosts) + self.assertEqual([], opts.TargetHosts("").default) + self.assertEqual({"default": []}, opts.TargetHosts("").all_hosts) def test_csv_hosts_parses(self): - target_hosts = '127.0.0.1:9200,10.17.0.5:19200' + target_hosts = "127.0.0.1:9200,10.17.0.5:19200" - self.assertEqual( - {'default': [{'host': '127.0.0.1', 'port': 9200},{'host': '10.17.0.5', 'port': 19200}]}, - opts.TargetHosts(target_hosts).all_hosts - ) + self.assertEqual({"default": [{"host": "127.0.0.1", "port": 9200}, {"host": "10.17.0.5", "port": 19200}]}, opts.TargetHosts(target_hosts).all_hosts) - self.assertEqual( - [{'host': '127.0.0.1', 'port': 9200},{'host': '10.17.0.5', 'port': 19200}], - opts.TargetHosts(target_hosts).default - ) + self.assertEqual([{"host": "127.0.0.1", "port": 9200}, {"host": "10.17.0.5", "port": 19200}], opts.TargetHosts(target_hosts).default) - self.assertEqual( - [{'host': '127.0.0.1', 'port': 9200},{'host': '10.17.0.5', 'port': 19200}], - opts.TargetHosts(target_hosts).default) + self.assertEqual([{"host": "127.0.0.1", "port": 9200}, {"host": "10.17.0.5", "port": 19200}], opts.TargetHosts(target_hosts).default) def test_jsonstring_parses_as_dict_of_clusters(self): - target_hosts = ('{"default": ["127.0.0.1:9200","10.17.0.5:19200"],' - ' "remote_1": ["88.33.22.15:19200"],' - ' "remote_2": ["10.18.0.6:19200","10.18.0.7:19201"]}') + target_hosts = '{"default": ["127.0.0.1:9200","10.17.0.5:19200"], "remote_1": ["88.33.22.15:19200"], "remote_2": ["10.18.0.6:19200","10.18.0.7:19201"]}' self.assertEqual( - {'default': ['127.0.0.1:9200','10.17.0.5:19200'], - 'remote_1': ['88.33.22.15:19200'], - 'remote_2': ['10.18.0.6:19200','10.18.0.7:19201']}, - opts.TargetHosts(target_hosts).all_hosts) + {"default": ["127.0.0.1:9200", "10.17.0.5:19200"], "remote_1": ["88.33.22.15:19200"], "remote_2": ["10.18.0.6:19200", "10.18.0.7:19201"]}, + opts.TargetHosts(target_hosts).all_hosts, + ) def test_json_file_parameter_parses(self): self.assertEqual( - {"default": ["127.0.0.1:9200","10.127.0.3:19200"] }, - opts.TargetHosts(os.path.join(os.path.dirname(__file__), "resources", "target_hosts_1.json")).all_hosts) + {"default": ["127.0.0.1:9200", "10.127.0.3:19200"]}, opts.TargetHosts(os.path.join(os.path.dirname(__file__), "resources", "target_hosts_1.json")).all_hosts + ) self.assertEqual( { - "default": [ - {"host": "127.0.0.1", "port": 9200}, - {"host": "127.0.0.1", "port": 19200} - ], - "remote_1":[ - {"host": "10.127.0.3", "port": 9200}, - {"host": "10.127.0.8", "port": 9201} - ], - "remote_2":[ - {"host": "88.33.27.15", "port": 39200} - ] + "default": [{"host": "127.0.0.1", "port": 9200}, {"host": "127.0.0.1", "port": 19200}], + "remote_1": [{"host": "10.127.0.3", "port": 9200}, {"host": "10.127.0.8", "port": 9201}], + "remote_2": [{"host": "88.33.27.15", "port": 39200}], }, - opts.TargetHosts(os.path.join(os.path.dirname(__file__), "resources", "target_hosts_2.json")).all_hosts) + opts.TargetHosts(os.path.join(os.path.dirname(__file__), "resources", "target_hosts_2.json")).all_hosts, + ) class TestClientOptions(TestCase): def test_csv_client_options_parses(self): client_options_string = "use_ssl:true,verify_certs:true,ca_certs:'/path/to/cacert.pem'" - self.assertEqual( - {'use_ssl': True, 'verify_certs': True, 'ca_certs': '/path/to/cacert.pem'}, - opts.ClientOptions(client_options_string).default) + self.assertEqual({"use_ssl": True, "verify_certs": True, "ca_certs": "/path/to/cacert.pem"}, opts.ClientOptions(client_options_string).default) - self.assertEqual( - {'use_ssl': True, 'verify_certs': True, 'ca_certs': '/path/to/cacert.pem'}, - opts.ClientOptions(client_options_string).default - ) + self.assertEqual({"use_ssl": True, "verify_certs": True, "ca_certs": "/path/to/cacert.pem"}, opts.ClientOptions(client_options_string).default) - self.assertEqual( - {'default': {'use_ssl': True, 'verify_certs': True, 'ca_certs': '/path/to/cacert.pem'}}, - opts.ClientOptions(client_options_string).all_client_options - ) + self.assertEqual({"default": {"use_ssl": True, "verify_certs": True, "ca_certs": "/path/to/cacert.pem"}}, opts.ClientOptions(client_options_string).all_client_options) def test_jsonstring_client_options_parses(self): - client_options_string = '{"default": {"timeout": 60},' \ - '"remote_1": {"use_ssl":true,"verify_certs":true,"basic_auth_user": "solr", "basic_auth_password": "changeme"},'\ + client_options_string = ( + '{"default": {"timeout": 60},' + '"remote_1": {"use_ssl":true,"verify_certs":true,"basic_auth_user": "solr", "basic_auth_password": "changeme"},' '"remote_2": {"use_ssl":true,"verify_certs":true,"ca_certs":"/path/to/cacert.pem"}}' + ) - self.assertEqual( - {'timeout': 60}, - opts.ClientOptions(client_options_string).default) + self.assertEqual({"timeout": 60}, opts.ClientOptions(client_options_string).default) - self.assertEqual( - {'timeout': 60}, - opts.ClientOptions(client_options_string).default) + self.assertEqual({"timeout": 60}, opts.ClientOptions(client_options_string).default) self.assertEqual( - {'default': {'timeout':60}, - 'remote_1': {'use_ssl': True,'verify_certs': True,'basic_auth_user':'solr','basic_auth_password':'changeme'}, - 'remote_2': {'use_ssl': True,'verify_certs': True, 'ca_certs':'/path/to/cacert.pem'}}, - opts.ClientOptions(client_options_string).all_client_options) + { + "default": {"timeout": 60}, + "remote_1": {"use_ssl": True, "verify_certs": True, "basic_auth_user": "solr", "basic_auth_password": "changeme"}, + "remote_2": {"use_ssl": True, "verify_certs": True, "ca_certs": "/path/to/cacert.pem"}, + }, + opts.ClientOptions(client_options_string).all_client_options, + ) def test_json_file_parameter_parses(self): self.assertEqual( - {'default': {'timeout':60}, - 'remote_1': {'use_ssl': True,'verify_certs': True,'basic_auth_user':'solr','basic_auth_password':'changeme'}, - 'remote_2': {'use_ssl': True,'verify_certs': True, 'ca_certs':'/path/to/cacert.pem'}}, - opts.ClientOptions(os.path.join(os.path.dirname(__file__), "resources", "client_options_1.json")).all_client_options) + { + "default": {"timeout": 60}, + "remote_1": {"use_ssl": True, "verify_certs": True, "basic_auth_user": "solr", "basic_auth_password": "changeme"}, + "remote_2": {"use_ssl": True, "verify_certs": True, "ca_certs": "/path/to/cacert.pem"}, + }, + opts.ClientOptions(os.path.join(os.path.dirname(__file__), "resources", "client_options_1.json")).all_client_options, + ) - self.assertEqual( - {'default': {'timeout':60}}, - opts.ClientOptions(os.path.join(os.path.dirname(__file__), "resources", "client_options_2.json")).all_client_options) + self.assertEqual({"default": {"timeout": 60}}, opts.ClientOptions(os.path.join(os.path.dirname(__file__), "resources", "client_options_2.json")).all_client_options) def test_no_client_option_parses_to_default(self): client_options_string = opts.ClientOptions.DEFAULT_CLIENT_OPTIONS target_hosts = None - self.assertEqual( - {"timeout": 60}, - opts.ClientOptions(client_options_string, - target_hosts=target_hosts).default) + self.assertEqual({"timeout": 60}, opts.ClientOptions(client_options_string, target_hosts=target_hosts).default) - self.assertEqual( - {"default": {"timeout": 60}}, - opts.ClientOptions(client_options_string, - target_hosts=target_hosts).all_client_options) + self.assertEqual({"default": {"timeout": 60}}, opts.ClientOptions(client_options_string, target_hosts=target_hosts).all_client_options) - self.assertEqual( - {"timeout": 60}, - opts.ClientOptions(client_options_string, - target_hosts=target_hosts).default) + self.assertEqual({"timeout": 60}, opts.ClientOptions(client_options_string, target_hosts=target_hosts).default) def test_no_client_option_parses_to_default_with_multicluster(self): client_options_string = opts.ClientOptions.DEFAULT_CLIENT_OPTIONS target_hosts = opts.TargetHosts('{"default": ["127.0.0.1:9200,10.17.0.5:19200"], "remote": ["88.33.22.15:19200"]}') - self.assertEqual( - {"timeout": 60}, - opts.ClientOptions(client_options_string, - target_hosts=target_hosts).default) + self.assertEqual({"timeout": 60}, opts.ClientOptions(client_options_string, target_hosts=target_hosts).default) - self.assertEqual( - {"default": {"timeout": 60}, "remote": {"timeout": 60}}, - opts.ClientOptions(client_options_string, - target_hosts=target_hosts).all_client_options) + self.assertEqual({"default": {"timeout": 60}, "remote": {"timeout": 60}}, opts.ClientOptions(client_options_string, target_hosts=target_hosts).all_client_options) - self.assertEqual( - {"timeout": 60}, - opts.ClientOptions(client_options_string, - target_hosts=target_hosts).default) + self.assertEqual({"timeout": 60}, opts.ClientOptions(client_options_string, target_hosts=target_hosts).default) def test_amends_with_max_connections(self): client_options_string = opts.ClientOptions.DEFAULT_CLIENT_OPTIONS target_hosts = opts.TargetHosts('{"default": ["10.17.0.5:9200"], "remote": ["88.33.22.15:9200"]}') self.assertEqual( {"default": {"timeout": 60, "max_connections": 128}, "remote": {"timeout": 60, "max_connections": 128}}, - opts.ClientOptions(client_options_string, target_hosts=target_hosts).with_max_connections(128)) + opts.ClientOptions(client_options_string, target_hosts=target_hosts).with_max_connections(128), + ) def test_keeps_already_specified_max_connections(self): client_options_string = '{"default": {"timeout":60,"max_connections":5}, "remote": {"timeout":60}}' target_hosts = opts.TargetHosts('{"default": ["10.17.0.5:9200"], "remote": ["88.33.22.15:9200"]}') self.assertEqual( {"default": {"timeout": 60, "max_connections": 5}, "remote": {"timeout": 60, "max_connections": 32}}, - opts.ClientOptions(client_options_string, target_hosts=target_hosts).with_max_connections(32)) + opts.ClientOptions(client_options_string, target_hosts=target_hosts).with_max_connections(32), + ) diff --git a/tests/utils/process_test.py b/tests/utils/process_test.py index 49ff8af1..a1854c81 100644 --- a/tests/utils/process_test.py +++ b/tests/utils/process_test.py @@ -16,7 +16,7 @@ # not use this file except in compliance with the License. # You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an @@ -59,28 +59,22 @@ def status(self): @mock.patch("psutil.process_iter") def test_find_other_benchmark_processes(self, process_iter): - benchmark_os_5_process = ProcessTests.Process(100, "java", - ["/usr/lib/jvm/java-8-oracle/bin/java", "-Xms2g", "-Xmx2g", "-Enode.name=benchmark-node0", - "org.elasticsearch.bootstrap.Elasticsearch"]) - benchmark_os_1_process = ProcessTests.Process(101, "java", - ["/usr/lib/jvm/java-8-oracle/bin/java", - "-Xms2g", "-Xmx2g", - "-Des.node.name=benchmark-node0", - "org.elasticsearch.bootstrap.Elasticsearch"]) - metrics_store_process = ProcessTests.Process(102, "java", ["/usr/lib/jvm/java-8-oracle/bin/java", "-Xms2g", "-Xmx2g", - "-Des.path.home=~/benchmark/metrics/", - "org.elasticsearch.bootstrap.Elasticsearch"]) + benchmark_os_5_process = ProcessTests.Process( + 100, "java", ["/usr/lib/jvm/java-8-oracle/bin/java", "-Xms2g", "-Xmx2g", "-Enode.name=benchmark-node0", "org.elasticsearch.bootstrap.Elasticsearch"] + ) + benchmark_os_1_process = ProcessTests.Process( + 101, "java", ["/usr/lib/jvm/java-8-oracle/bin/java", "-Xms2g", "-Xmx2g", "-Des.node.name=benchmark-node0", "org.elasticsearch.bootstrap.Elasticsearch"] + ) + metrics_store_process = ProcessTests.Process( + 102, "java", ["/usr/lib/jvm/java-8-oracle/bin/java", "-Xms2g", "-Xmx2g", "-Des.path.home=~/benchmark/metrics/", "org.elasticsearch.bootstrap.Elasticsearch"] + ) random_python = ProcessTests.Process(103, "python3", ["/some/django/app"]) other_process = ProcessTests.Process(104, "init", ["/usr/sbin/init"]) benchmark_process_p = ProcessTests.Process(105, "python3", ["/usr/bin/python3", "~/.local/bin/solr-orbit"]) benchmark_process_e = ProcessTests.Process(107, "solr-orbit", ["/usr/bin/python3", "~/.local/bin/solr-orbit"]) - benchmark_process_mac = ProcessTests.Process(108, "Python", ["/Python.app/Contents/MacOS/Python", - "~/.local/bin/solr-orbit"]) + benchmark_process_mac = ProcessTests.Process(108, "Python", ["/Python.app/Contents/MacOS/Python", "~/.local/bin/solr-orbit"]) # fake own process by determining our pid - own_benchmark_process = ProcessTests.Process( - os.getpid(), "Python", - ["/Python.app/Contents/MacOS/Python", - "~/.local/bin/solr-orbit"]) + own_benchmark_process = ProcessTests.Process(os.getpid(), "Python", ["/Python.app/Contents/MacOS/Python", "~/.local/bin/solr-orbit"]) night_benchmark_process = ProcessTests.Process(110, "Python", ["/Python.app/Contents/MacOS/Python", "~/.local/bin/night_benchmark"]) process_iter.return_value = [ @@ -96,45 +90,39 @@ def test_find_other_benchmark_processes(self, process_iter): night_benchmark_process, ] - self.assertEqual([benchmark_process_p, benchmark_process_e, benchmark_process_mac], - process.find_all_other_benchmark_processes()) + self.assertEqual([benchmark_process_p, benchmark_process_e, benchmark_process_mac], process.find_all_other_benchmark_processes()) @mock.patch("psutil.process_iter") def test_find_no_other_benchmark_process_running(self, process_iter): - metrics_store_process = ProcessTests.Process(102, "java", ["/usr/lib/jvm/java-8-oracle/bin/java", "-Xms2g", "-Xmx2g", - "-Des.path.home=~/benchmark/metrics/", - "org.elasticsearch.bootstrap.Elasticsearch"]) + metrics_store_process = ProcessTests.Process( + 102, "java", ["/usr/lib/jvm/java-8-oracle/bin/java", "-Xms2g", "-Xmx2g", "-Des.path.home=~/benchmark/metrics/", "org.elasticsearch.bootstrap.Elasticsearch"] + ) random_python = ProcessTests.Process(103, "python3", ["/some/django/app"]) - process_iter.return_value = [ metrics_store_process, random_python] + process_iter.return_value = [metrics_store_process, random_python] self.assertEqual(0, len(process.find_all_other_benchmark_processes())) @mock.patch("psutil.process_iter") def test_kills_only_benchmark_processes(self, process_iter): - benchmark_os_5_process = ProcessTests.Process(100, "java", - ["/usr/lib/jvm/java-8-oracle/bin/java", "-Xms2g", "-Xmx2g", "-Enode.name=benchmark-node0", - "org.elasticsearch.bootstrap.Elasticsearch"]) - benchmark_os_1_process = ProcessTests.Process(101, "java", - ["/usr/lib/jvm/java-8-oracle/bin/java", - "-Xms2g", "-Xmx2g", - "-Des.node.name=benchmark-node0", - "org.elasticsearch.bootstrap.Elasticsearch"]) - metrics_store_process = ProcessTests.Process(102, "java", ["/usr/lib/jvm/java-8-oracle/bin/java", "-Xms2g", "-Xmx2g", - "-Des.path.home=~/benchmark/metrics/", - "org.elasticsearch.bootstrap.Elasticsearch"]) + benchmark_os_5_process = ProcessTests.Process( + 100, "java", ["/usr/lib/jvm/java-8-oracle/bin/java", "-Xms2g", "-Xmx2g", "-Enode.name=benchmark-node0", "org.elasticsearch.bootstrap.Elasticsearch"] + ) + benchmark_os_1_process = ProcessTests.Process( + 101, "java", ["/usr/lib/jvm/java-8-oracle/bin/java", "-Xms2g", "-Xmx2g", "-Des.node.name=benchmark-node0", "org.elasticsearch.bootstrap.Elasticsearch"] + ) + metrics_store_process = ProcessTests.Process( + 102, "java", ["/usr/lib/jvm/java-8-oracle/bin/java", "-Xms2g", "-Xmx2g", "-Des.path.home=~/benchmark/metrics/", "org.elasticsearch.bootstrap.Elasticsearch"] + ) random_python = ProcessTests.Process(103, "python3", ["/some/django/app"]) other_process = ProcessTests.Process(104, "init", ["/usr/sbin/init"]) benchmark_process_p = ProcessTests.Process(105, "python3", ["/usr/bin/python3", "~/.local/bin/solr-orbit"]) # On Linux, process names are truncated to 15 chars; "solr-orbit" (9 chars) fits within that limit. benchmark_process_l = ProcessTests.Process(106, "solr-orbit", ["/usr/bin/python3", "~/.local/bin/solr-orbit"]) benchmark_process_e = ProcessTests.Process(107, "solr-orbit", ["/usr/bin/python3", "~/.local/bin/solr-orbit"]) - benchmark_process_mac = ProcessTests.Process(108, "Python", ["/Python.app/Contents/MacOS/Python", - "~/.local/bin/solr-orbit"]) + benchmark_process_mac = ProcessTests.Process(108, "Python", ["/Python.app/Contents/MacOS/Python", "~/.local/bin/solr-orbit"]) # fake own process by determining our pid - own_benchmark_process = ProcessTests.Process( - os.getpid(), "Python", - ["/Python.app/Contents/MacOS/Python", "~/.local/bin/solr-orbit"]) + own_benchmark_process = ProcessTests.Process(os.getpid(), "Python", ["/Python.app/Contents/MacOS/Python", "~/.local/bin/solr-orbit"]) night_benchmark_process = ProcessTests.Process(110, "Python", ["/Python.app/Contents/MacOS/Python", "~/.local/bin/night_benchmark"]) process_iter.return_value = [ diff --git a/tests/utils/repo_test.py b/tests/utils/repo_test.py index 84b2fdbe..55c981cf 100644 --- a/tests/utils/repo_test.py +++ b/tests/utils/repo_test.py @@ -16,7 +16,7 @@ # not use this file except in compliance with the License. # You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an @@ -41,15 +41,9 @@ def test_fails_in_offline_mode_if_not_a_git_repo(self, is_working_copy, exists): exists.return_value = True with self.assertRaises(exceptions.SystemSetupError) as ctx: - repo.BenchmarkRepository( - default_directory=None, - root_dir="/benchmark-resources", - repo_name="unit-test", - resource_name="unittest-resources", - offline=True) + repo.BenchmarkRepository(default_directory=None, root_dir="/benchmark-resources", repo_name="unit-test", resource_name="unittest-resources", offline=True) - self.assertEqual("[/benchmark-resources/unit-test] must be a git repository.\n\n" - "Please run:\ngit -C /benchmark-resources/unit-test init", ctx.exception.args[0]) + self.assertEqual("[/benchmark-resources/unit-test] must be a git repository.\n\nPlease run:\ngit -C /benchmark-resources/unit-test init", ctx.exception.args[0]) @mock.patch("solrorbit.utils.io.exists", autospec=True) @mock.patch("solrorbit.utils.git.is_working_copy", autospec=True) @@ -57,12 +51,7 @@ def test_does_nothing_in_offline_mode_if_not_existing(self, is_working_copy, exi is_working_copy.return_value = False exists.return_value = False - r = repo.BenchmarkRepository( - default_directory=None, - root_dir="/benchmark-resources", - repo_name="unit-test", - resource_name="unittest-resources", - offline=True) + r = repo.BenchmarkRepository(default_directory=None, root_dir="/benchmark-resources", repo_name="unit-test", resource_name="unittest-resources", offline=True) self.assertFalse(r.remote) @@ -70,12 +59,7 @@ def test_does_nothing_in_offline_mode_if_not_existing(self, is_working_copy, exi def test_does_nothing_if_working_copy_present(self, is_working_copy): is_working_copy.return_value = True - r = repo.BenchmarkRepository( - default_directory=None, - root_dir="/benchmark-resources", - repo_name="unit-test", - resource_name="unittest-resources", - offline=True) + r = repo.BenchmarkRepository(default_directory=None, root_dir="/benchmark-resources", repo_name="unit-test", resource_name="unittest-resources", offline=True) self.assertFalse(r.remote) @@ -89,7 +73,8 @@ def test_clones_initially(self, clone, is_working_copy): root_dir="/benchmark-resources", repo_name="unit-test", resource_name="unittest-resources", - offline=False) + offline=False, + ) self.assertTrue(r.remote) @@ -105,7 +90,8 @@ def test_fetches_if_already_cloned(self, fetch, is_working_copy): root_dir="/benchmark-resources", repo_name="unit-test", resource_name="unittest-resources", - offline=False) + offline=False, + ) fetch.assert_called_with(src="/benchmark-resources/unit-test") @@ -120,7 +106,8 @@ def test_does_not_fetch_if_suppressed(self, fetch, is_working_copy): repo_name="unit-test", resource_name="unittest-resources", offline=False, - fetch=False) + fetch=False, + ) self.assertTrue(r.remote) @@ -137,7 +124,8 @@ def test_ignores_fetch_errors(self, fetch, is_working_copy): root_dir="/benchmark-resources", repo_name="unit-test", resource_name="unittest-resources", - offline=False) + offline=False, + ) # no exception during the call - we reach this here self.assertTrue(r.remote) @@ -159,7 +147,8 @@ def test_updates_from_remote(self, rebase, checkout, branches, fetch, is_working root_dir="/benchmark-resources", repo_name="unit-test", resource_name="unittest-resources", - offline=random.choice([True, False])) + offline=random.choice([True, False]), + ) r.update(distribution_version="1.7.3") @@ -180,12 +169,7 @@ def test_updates_locally(self, curr_branch, rebase, checkout, branches, fetch, i is_working_copy.return_value = True head_revision.return_value = "123a" - r = repo.BenchmarkRepository( - default_directory=None, - root_dir="/benchmark-resources", - repo_name="unit-test", - resource_name="unittest-resources", - offline=False) + r = repo.BenchmarkRepository(default_directory=None, root_dir="/benchmark-resources", repo_name="unit-test", resource_name="unittest-resources", offline=False) r.update(distribution_version="6.0.0") @@ -208,12 +192,7 @@ def test_fallback_to_tags(self, curr_branch, rebase, checkout, branches, tags, f is_working_copy.return_value = True head_revision.return_value = "123a" - r = repo.BenchmarkRepository( - default_directory=None, - root_dir="/benchmark-resources", - repo_name="unit-test", - resource_name="unittest-resources", - offline=False) + r = repo.BenchmarkRepository(default_directory=None, root_dir="/benchmark-resources", repo_name="unit-test", resource_name="unittest-resources", offline=False) r.update(distribution_version="1.7.4") @@ -238,7 +217,8 @@ def test_does_not_update_unknown_branch_remotely(self, rebase, checkout, branche root_dir="/benchmark-resources", repo_name="unit-test", resource_name="unittest-resources", - offline=False) + offline=False, + ) self.assertTrue(r.remote) @@ -267,8 +247,7 @@ def test_does_not_update_unknown_branch_remotely(self, rebase, checkout, branche @mock.patch("solrorbit.utils.git.checkout", autospec=True) @mock.patch("solrorbit.utils.git.rebase") @mock.patch("solrorbit.utils.git.current_branch") - def test_does_not_update_unknown_branch_remotely_local_fallback(self, curr_branch, rebase, checkout, branches, tags, - fetch, is_working_copy, head_revision): + def test_does_not_update_unknown_branch_remotely_local_fallback(self, curr_branch, rebase, checkout, branches, tags, fetch, is_working_copy, head_revision): curr_branch.return_value = "main" # we have only "main" remotely but a few more branches locally branches.side_effect = ["5", ["1", "2", "5", "main"]] @@ -281,7 +260,8 @@ def test_does_not_update_unknown_branch_remotely_local_fallback(self, curr_branc root_dir="/benchmark-resources", repo_name="unit-test", resource_name="unittest-resources", - offline=False) + offline=False, + ) r.update(distribution_version="1.7.3") @@ -308,12 +288,7 @@ def test_does_not_update_unknown_branch_locally(self, rebase, checkout, branches tags.return_value = [] is_working_copy.return_value = True - r = repo.BenchmarkRepository( - default_directory=None, - root_dir="/benchmark-resources", - repo_name="unit-test", - resource_name="unittest-resources", - offline=False) + r = repo.BenchmarkRepository(default_directory=None, root_dir="/benchmark-resources", repo_name="unit-test", resource_name="unittest-resources", offline=False) with self.assertRaises(exceptions.SystemSetupError) as ctx: r.update(distribution_version="4.0.0") @@ -330,12 +305,7 @@ def test_does_not_update_unknown_branch_locally(self, rebase, checkout, branches def test_checkout_revision(self, checkout, fetch, is_working_copy): is_working_copy.return_value = True - r = repo.BenchmarkRepository( - default_directory=None, - root_dir="/benchmark-resources", - repo_name="unit-test", - resource_name="unittest-resources", - offline=False) + r = repo.BenchmarkRepository(default_directory=None, root_dir="/benchmark-resources", repo_name="unit-test", resource_name="unittest-resources", offline=False) r.checkout("abcdef123") diff --git a/tests/utils/versions_test.py b/tests/utils/versions_test.py index c9b94c75..df98be84 100644 --- a/tests/utils/versions_test.py +++ b/tests/utils/versions_test.py @@ -16,7 +16,7 @@ # not use this file except in compliance with the License. # You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an @@ -74,80 +74,68 @@ def test_latest_bounded_minor(self, seed): random.shuffle(alternatives) assert versions.latest_bounded_minor(alternatives, versions.VersionVariants("7.6.3")) == 2 - assert versions.latest_bounded_minor(alternatives, versions.VersionVariants("7.12.3")) == 10,\ + assert versions.latest_bounded_minor(alternatives, versions.VersionVariants("7.12.3")) == 10, ( "Nearest alternative with major.minor, skip alternatives with major.minor.patch" - assert versions.latest_bounded_minor(alternatives, versions.VersionVariants("7.11.2")) == 10,\ - "Skips all alternatives with major.minor.patch, even if exact match" - assert versions.latest_bounded_minor(alternatives, versions.VersionVariants("7.1.0")) is None,\ - "No matching alternative with minor version" + ) + assert versions.latest_bounded_minor(alternatives, versions.VersionVariants("7.11.2")) == 10, "Skips all alternatives with major.minor.patch, even if exact match" + assert versions.latest_bounded_minor(alternatives, versions.VersionVariants("7.1.0")) is None, "No matching alternative with minor version" def test_components_ignores_invalid_versions(self): with pytest.raises( - exceptions.InvalidSyntax, - match=re.escape( - r"version string '5.0.0a' does not conform to pattern " - r"'^(\d+)\.(\d+)\.(\d+)(?:-(.+))?$'")): + exceptions.InvalidSyntax, + match=re.escape( + r"version string '5.0.0a' does not conform to pattern " + r"'^(\d+)\.(\d+)\.(\d+)(?:-(.+))?$'" + ), + ): versions.components("5.0.0a") def test_versionvariants_parses_correct_version_string(self): - assert versions.VersionVariants("5.0.3").all_versions == [ - ("5.0.3", "with_patch"), - ("5.0", "with_minor"), - ("5", "with_major")] + assert versions.VersionVariants("5.0.3").all_versions == [("5.0.3", "with_patch"), ("5.0", "with_minor"), ("5", "with_major")] assert versions.VersionVariants("7.12.1-SNAPSHOT").all_versions == [ ("7.12.1-SNAPSHOT", "with_suffix"), - ("7.12.1", "with_patch"), - ("7.12", "with_minor"), - ("7", "with_major")] - assert versions.VersionVariants("10.3.63").all_versions == [ - ("10.3.63", "with_patch"), - ("10.3", "with_minor"), - ("10", "with_major")] + ("7.12.1", "with_patch"), + ("7.12", "with_minor"), + ("7", "with_major"), + ] + assert versions.VersionVariants("10.3.63").all_versions == [("10.3.63", "with_patch"), ("10.3", "with_minor"), ("10", "with_major")] def test_versions_rejects_invalid_version_strings(self): with pytest.raises( - exceptions.InvalidSyntax, - match=re.escape(r"version string '5.0.0a-SNAPSHOT' does not conform to pattern " - r"'^(\d+)\.(\d+)\.(\d+)(?:-(.+))?$'") + exceptions.InvalidSyntax, + match=re.escape( + r"version string '5.0.0a-SNAPSHOT' does not conform to pattern " + r"'^(\d+)\.(\d+)\.(\d+)(?:-(.+))?$'" + ), ): versions.VersionVariants("5.0.0a-SNAPSHOT") def test_find_best_match(self): - assert versions.best_match(["1.7", "2", "5.0.0-alpha1", "5", "main"], "6.0.0-alpha1") == "main",\ - "Assume main for versions newer than latest alternative available" + assert versions.best_match(["1.7", "2", "5.0.0-alpha1", "5", "main"], "6.0.0-alpha1") == "main", "Assume main for versions newer than latest alternative available" - assert versions.best_match(["1.7", "2", "5.0.0-alpha1", "5", "main"], "5.1.0-SNAPSHOT") == "5",\ - "Best match for specific version" + assert versions.best_match(["1.7", "2", "5.0.0-alpha1", "5", "main"], "5.1.0-SNAPSHOT") == "5", "Best match for specific version" - assert versions.best_match(["1.7", "2", "5.0.0-alpha1", "5", "main"], None) == "main",\ - "Assume main on unknown version" + assert versions.best_match(["1.7", "2", "5.0.0-alpha1", "5", "main"], None) == "main", "Assume main on unknown version" - assert versions.best_match(["1.7", "2", "5.0.0-alpha1", "5", "main"], "0.4") is None,\ - "Reject versions that are too old" + assert versions.best_match(["1.7", "2", "5.0.0-alpha1", "5", "main"], "0.4") is None, "Reject versions that are too old" - assert versions.best_match(["7", "7.10.2", "7.11", "7.2", "5", "6", "main"], "7.10.2") == "7.10.2", \ - "Exact match" + assert versions.best_match(["7", "7.10.2", "7.11", "7.2", "5", "6", "main"], "7.10.2") == "7.10.2", "Exact match" - assert versions.best_match(["7", "7.10", "main"], "7.1.0") == "7", \ - "Best match is major version" + assert versions.best_match(["7", "7.10", "main"], "7.1.0") == "7", "Best match is major version" - assert versions.best_match(["7", "7.11", "7.2", "5", "6", "main"], "7.11.0") == "7.11",\ - "Best match for specific minor version" + assert versions.best_match(["7", "7.11", "7.2", "5", "6", "main"], "7.11.0") == "7.11", "Best match for specific minor version" - assert versions.best_match(["7", "7.11", "7.2", "5", "6", "main"], "7.12.0") == "7.11",\ - "If no exact match, best match is the nearest prior minor" + assert versions.best_match(["7", "7.11", "7.2", "5", "6", "main"], "7.12.0") == "7.11", "If no exact match, best match is the nearest prior minor" - assert versions.best_match(["7", "7.11", "7.2", "5", "6", "main"], "7.3.0") == "7.2",\ - "If no exact match, best match is the nearest prior minor" + assert versions.best_match(["7", "7.11", "7.2", "5", "6", "main"], "7.3.0") == "7.2", "If no exact match, best match is the nearest prior minor" - assert versions.best_match(["7", "7.11", "7.2", "5", "6", "main"], "7.10.0") == "7.2", \ - "If no exact match, best match is the nearest prior minor" + assert versions.best_match(["7", "7.11", "7.2", "5", "6", "main"], "7.10.0") == "7.2", "If no exact match, best match is the nearest prior minor" - assert versions.best_match(["7", "7.1", "7.11.1", "7.11.0", "7.2", "5", "6", "main"], "7.12.0") == "7.2",\ + assert versions.best_match(["7", "7.1", "7.11.1", "7.11.0", "7.2", "5", "6", "main"], "7.12.0") == "7.2", ( "Patch or patch-suffix branches are not supported and ignored, best match is nearest prior minor" + ) - assert versions.best_match(["7", "7.11", "7.2", "5", "6", "main"], "7.1.0") == "7",\ - "If no exact match and no minor match, next best match is major version" + assert versions.best_match(["7", "7.11", "7.2", "5", "6", "main"], "7.1.0") == "7", "If no exact match and no minor match, next best match is major version" def test_version_comparison(self): assert versions.Version.from_string("7.10.2") < versions.Version.from_string("7.11.0") diff --git a/tests/worker_coordinator/__init__.py b/tests/worker_coordinator/__init__.py index 5047a451..f5768141 100644 --- a/tests/worker_coordinator/__init__.py +++ b/tests/worker_coordinator/__init__.py @@ -16,7 +16,7 @@ # not use this file except in compliance with the License. # You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an diff --git a/tests/worker_coordinator/runner_test.py b/tests/worker_coordinator/runner_test.py index 1b6ba900..e80e4417 100644 --- a/tests/worker_coordinator/runner_test.py +++ b/tests/worker_coordinator/runner_test.py @@ -16,7 +16,7 @@ # not use this file except in compliance with the License. # You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an @@ -64,8 +64,7 @@ async def runner_function(*args): returned_runner = runner.runner_for("unit_test") self.assertIsInstance(returned_runner, runner.NoCompletion) self.assertEqual("user-defined runner for [runner_function]", repr(returned_runner)) - self.assertEqual(("default_client", "param"), - await returned_runner({"default": "default_client", "other": "other_client"}, "param")) + self.assertEqual(("default_client", "param"), await returned_runner({"default": "default_client", "other": "other_client"}, "param")) @run_async async def test_single_cluster_runner_class_with_context_manager_should_be_wrapped_with_context_manager_enabled(self): @@ -80,12 +79,10 @@ def __str__(self): runner.register_runner(operation_type="unit_test", runner=test_runner, async_runner=True) returned_runner = runner.runner_for("unit_test") self.assertIsInstance(returned_runner, runner.NoCompletion) - self.assertEqual("user-defined context-manager enabled runner for [UnitTestSingleClusterContextManagerRunner]", - repr(returned_runner)) + self.assertEqual("user-defined context-manager enabled runner for [UnitTestSingleClusterContextManagerRunner]", repr(returned_runner)) # test that context_manager functionality gets preserved after wrapping async with returned_runner: - self.assertEqual(("default_client", "param"), - await returned_runner({"default": "default_client", "other": "other_client"}, "param")) + self.assertEqual(("default_client", "param"), await returned_runner({"default": "default_client", "other": "other_client"}, "param")) # check that the context manager interface of our inner runner has been respected. self.assertTrue(test_runner.fp.closed) @@ -104,8 +101,7 @@ def __str__(self): runner.register_runner(operation_type="unit_test", runner=test_runner, async_runner=True) returned_runner = runner.runner_for("unit_test") self.assertIsInstance(returned_runner, runner.NoCompletion) - self.assertEqual("user-defined context-manager enabled runner for [UnitTestMultiClusterContextManagerRunner]", - repr(returned_runner)) + self.assertEqual("user-defined context-manager enabled runner for [UnitTestMultiClusterContextManagerRunner]", repr(returned_runner)) # test that context_manager functionality gets preserved after wrapping all_clients = {"default": "default_client", "other": "other_client"} @@ -128,8 +124,7 @@ def __str__(self): returned_runner = runner.runner_for("unit_test") self.assertIsInstance(returned_runner, runner.NoCompletion) self.assertEqual("user-defined runner for [UnitTestSingleClusterRunner]", repr(returned_runner)) - self.assertEqual(("default_client", "param"), - await returned_runner({"default": "default_client", "other": "other_client"}, "param")) + self.assertEqual(("default_client", "param"), await returned_runner({"default": "default_client", "other": "other_client"}, "param")) @run_async async def test_multi_cluster_runner_class_should_be_wrapped(self): @@ -161,68 +156,37 @@ def tearDown(self): @run_async async def test_asserts_equal_succeeds(self): opensearch = None - response = { - "hits": { - "hits": { - "value": 5, - "relation": "eq" - } - } - } + response = {"hits": {"hits": {"value": 5, "relation": "eq"}}} delegate = mock.MagicMock() delegate.return_value = as_future(response) r = runner.AssertingRunner(delegate) async with r: - final_response = await r(opensearch, { - "name": "test-task", - "assertions": [ - { - "property": "hits.hits.value", - "condition": "==", - "value": 5 - }, - { - "property": "hits.hits.relation", - "condition": "==", - "value": "eq" - } - ] - }) + final_response = await r( + opensearch, + { + "name": "test-task", + "assertions": [{"property": "hits.hits.value", "condition": "==", "value": 5}, {"property": "hits.hits.relation", "condition": "==", "value": "eq"}], + }, + ) self.assertEqual(response, final_response) @run_async async def test_asserts_equal_fails(self): - opensearch = None - response = { - "hits": { - "hits": { - "value": 10000, - "relation": "gte" - } - } - } + opensearch = None + response = {"hits": {"hits": {"value": 10000, "relation": "gte"}}} delegate = mock.MagicMock() delegate.return_value = as_future(response) r = runner.AssertingRunner(delegate) - with self.assertRaisesRegex(exceptions.BenchmarkTaskAssertionError, - r"Expected \[hits.hits.relation\] in \[test-task\] to be == \[eq\] but was \[gte\]."): + with self.assertRaisesRegex(exceptions.BenchmarkTaskAssertionError, r"Expected \[hits.hits.relation\] in \[test-task\] to be == \[eq\] but was \[gte\]."): async with r: - await r(opensearch, { - "name": "test-task", - "assertions": [ - { - "property": "hits.hits.value", - "condition": "==", - "value": 10000 - }, - { - "property": "hits.hits.relation", - "condition": "==", - "value": "eq" - } - ] - }) + await r( + opensearch, + { + "name": "test-task", + "assertions": [{"property": "hits.hits.value", "condition": "==", "value": 10000}, {"property": "hits.hits.relation", "condition": "==", "value": "eq"}], + }, + ) @run_async async def test_skips_asserts_for_non_dicts(self): @@ -232,16 +196,7 @@ async def test_skips_asserts_for_non_dicts(self): delegate.return_value = as_future(response) r = runner.AssertingRunner(delegate) async with r: - final_response = await r(opensearch, { - "name": "test-task", - "assertions": [ - { - "property": "hits.hits.value", - "condition": "==", - "value": 5 - } - ] - }) + final_response = await r(opensearch, {"name": "test-task", "assertions": [{"property": "hits.hits.value", "condition": "==", "value": 5}]}) # still passes response as is self.assertEqual(response, final_response) @@ -260,8 +215,7 @@ def test_predicates(self): for predicate, vals in predicate_success.items(): expected, actual = vals - self.assertTrue(r.predicates[predicate](expected, actual), - f"Expected [{expected} {predicate} {actual}] to succeed.") + self.assertTrue(r.predicates[predicate](expected, actual), f"Expected [{expected} {predicate} {actual}] to succeed.") predicate_fail = { # predicate: (expected, actual) @@ -274,8 +228,7 @@ def test_predicates(self): for predicate, vals in predicate_fail.items(): expected, actual = vals - self.assertFalse(r.predicates[predicate](expected, actual), - f"Expected [{expected} {predicate} {actual}] to fail.") + self.assertFalse(r.predicates[predicate](expected, actual), f"Expected [{expected} {predicate} {actual}] to fail.") class RawRequestRunnerTests(TestCase): @@ -332,16 +285,16 @@ async def test_custom_headers(self): class SleepTests(TestCase): @mock.patch("tests.worker_coordinator.runner_test._FakeOSClient") - @mock.patch('solrorbit.client.RequestContextHolder.on_client_request_end') - @mock.patch('solrorbit.client.RequestContextHolder.on_client_request_start') + @mock.patch("solrorbit.client.RequestContextHolder.on_client_request_end") + @mock.patch("solrorbit.client.RequestContextHolder.on_client_request_start") # To avoid real sleeps in unit tests @mock.patch("asyncio.sleep", return_value=as_future()) @run_async async def test_missing_parameter(self, sleep, on_client_request_start, on_client_request_end, opensearch): r = runner.Sleep() - with self.assertRaisesRegex(exceptions.DataError, - "Parameter source for operation 'sleep' did not provide the mandatory parameter " - "'duration'. Add it to your parameter source and try again."): + with self.assertRaisesRegex( + exceptions.DataError, "Parameter source for operation 'sleep' did not provide the mandatory parameter 'duration'. Add it to your parameter source and try again." + ): await r(opensearch, params={}) self.assertEqual(0, opensearch.call_count) @@ -352,8 +305,8 @@ async def test_missing_parameter(self, sleep, on_client_request_start, on_client self.assertEqual(0, sleep.call_count) @mock.patch("tests.worker_coordinator.runner_test._FakeOSClient") - @mock.patch('solrorbit.client.RequestContextHolder.on_client_request_end') - @mock.patch('solrorbit.client.RequestContextHolder.on_client_request_start') + @mock.patch("solrorbit.client.RequestContextHolder.on_client_request_end") + @mock.patch("solrorbit.client.RequestContextHolder.on_client_request_start") # To avoid real sleeps in unit tests @mock.patch("asyncio.sleep", return_value=as_future()) @run_async @@ -369,16 +322,6 @@ async def test_sleep(self, sleep, on_client_request_start, on_client_request_end sleep.assert_called_once_with(4.3) - - - - - - - - - - class CompositeContextTests(TestCase): def test_cannot_be_used_outside_of_composite(self): with self.assertRaises(exceptions.BenchmarkAssertionError) as ctx: @@ -398,8 +341,7 @@ async def test_put_get_and_remove(self): async with runner.CompositeContext(): with self.assertRaises(KeyError) as ctx: runner.CompositeContext.get("don't clear this key") - self.assertEqual("Unknown property [don't clear this key]. Currently recognized properties are [].", - ctx.exception.args[0]) + self.assertEqual("Unknown property [don't clear this key]. Currently recognized properties are [].", ctx.exception.args[0]) @run_async async def test_fails_to_read_unknown_key(self): @@ -407,8 +349,7 @@ async def test_fails_to_read_unknown_key(self): with self.assertRaises(KeyError) as ctx: runner.CompositeContext.put("test", 1) runner.CompositeContext.get("unknown") - self.assertEqual("Unknown property [unknown]. Currently recognized properties are [test].", - ctx.exception.args[0]) + self.assertEqual("Unknown property [unknown]. Currently recognized properties are [test].", ctx.exception.args[0]) @run_async async def test_fails_to_remove_unknown_key(self): @@ -416,8 +357,7 @@ async def test_fails_to_remove_unknown_key(self): with self.assertRaises(KeyError) as ctx: runner.CompositeContext.put("test", 1) runner.CompositeContext.remove("unknown") - self.assertEqual("Unknown property [unknown]. Currently recognized properties are [test].", - ctx.exception.args[0]) + self.assertEqual("Unknown property [unknown]. Currently recognized properties are [test].", ctx.exception.args[0]) class CompositeTests(TestCase): @@ -461,10 +401,10 @@ def tearDown(self): runner.remove_runner("counter") runner.remove_runner("call-recorder") - @mock.patch('solrorbit.client.RequestContextHolder.on_client_request_end') - @mock.patch('solrorbit.client.RequestContextHolder.on_client_request_start') + @mock.patch("solrorbit.client.RequestContextHolder.on_client_request_end") + @mock.patch("solrorbit.client.RequestContextHolder.on_client_request_start") @mock.patch("tests.worker_coordinator.runner_test._FakeOSClient") - @mock.patch('solrorbit.client.RequestContextHolder.new_request_context') + @mock.patch("solrorbit.client.RequestContextHolder.new_request_context") @run_async async def test_runs_tasks_in_specified_order(self, opensearch, on_client_request_start, on_client_request_end, new_request_context): opensearch.transport.perform_request.return_value = as_future() @@ -515,7 +455,6 @@ async def test_runs_tasks_in_specified_order(self, opensearch, on_client_request "name": "call-after-stream-cd", "operation-type": "call-recorder", }, - ] } @@ -523,15 +462,20 @@ async def test_runs_tasks_in_specified_order(self, opensearch, on_client_request r.supported_op_types = ["call-recorder"] await r(opensearch, params) - self.assertEqual([ - "initial-call", - # concurrent - "stream-a", "stream-b", - "call-after-stream-ab", - # concurrent - "stream-c", "stream-d", - "call-after-stream-cd" - ], self.call_recorder_runner.calls) + self.assertEqual( + [ + "initial-call", + # concurrent + "stream-a", + "stream-b", + "call-after-stream-ab", + # concurrent + "stream-c", + "stream-d", + "call-after-stream-cd", + ], + self.call_recorder_runner.calls, + ) @pytest.mark.skip(reason="latency is system-dependent") @run_async @@ -542,29 +486,9 @@ async def test_adds_request_timings(self): params = { "requests": [ - { - "name": "initial-call", - "operation-type": "sleep", - "duration": 0.1 - }, - { - "stream": [ - { - "name": "stream-a", - "operation-type": "sleep", - "duration": 0.2 - } - ] - }, - { - "stream": [ - { - "name": "stream-b", - "operation-type": "sleep", - "duration": 0.1 - } - ] - } + {"name": "initial-call", "operation-type": "sleep", "duration": 0.1}, + {"stream": [{"name": "stream-a", "operation-type": "sleep", "duration": 0.2}]}, + {"stream": [{"name": "stream-b", "operation-type": "sleep", "duration": 0.1}]}, ] } @@ -593,37 +517,14 @@ async def test_adds_request_timings(self): self.assertIn("request_end", timing) self.assertGreater(timing["request_end"], timing["request_start"]) - @mock.patch('solrorbit.client.RequestContextHolder.on_client_request_end') - @mock.patch('solrorbit.client.RequestContextHolder.on_client_request_start') + @mock.patch("solrorbit.client.RequestContextHolder.on_client_request_end") + @mock.patch("solrorbit.client.RequestContextHolder.on_client_request_start") @mock.patch("tests.worker_coordinator.runner_test._FakeOSClient") @run_async async def test_limits_connections(self, opensearch, on_client_request_start, on_client_request_end): params = { "max-connections": 2, - "requests": [ - { - "stream": [ - { - "operation-type": "counter" - } - ] - }, - { - "stream": [ - { - "operation-type": "counter" - } - - ] - }, - { - "stream": [ - { - "operation-type": "counter" - } - ] - } - ] + "requests": [{"stream": [{"operation-type": "counter"}]}, {"stream": [{"operation-type": "counter"}]}, {"stream": [{"operation-type": "counter"}]}], } r = runner.Composite() @@ -633,32 +534,13 @@ async def test_limits_connections(self, opensearch, on_client_request_start, on_ # composite runner should limit to two concurrent connections self.assertEqual(2, self.counter_runner.max_value) - @mock.patch('solrorbit.client.RequestContextHolder.on_client_request_end') - @mock.patch('solrorbit.client.RequestContextHolder.on_client_request_start') + @mock.patch("solrorbit.client.RequestContextHolder.on_client_request_end") + @mock.patch("solrorbit.client.RequestContextHolder.on_client_request_start") @mock.patch("tests.worker_coordinator.runner_test._FakeOSClient") @run_async async def test_rejects_invalid_stream(self, opensearch, on_client_request_start, on_client_request_end): # params contains a "streams" property (plural) but it should be "stream" (singular) - params = { - "max-connections": 2, - "requests": [ - { - "stream": [ - { - "operation-type": "counter" - } - ] - }, - { - "streams": [ - { - "operation-type": "counter" - } - - ] - } - ] - } + params = {"max-connections": 2, "requests": [{"stream": [{"operation-type": "counter"}]}, {"streams": [{"operation-type": "counter"}]}]} r = runner.Composite() with self.assertRaises(exceptions.BenchmarkAssertionError) as ctx: @@ -666,22 +548,12 @@ async def test_rejects_invalid_stream(self, opensearch, on_client_request_start, self.assertEqual("Requests structure must contain [stream] or [operation-type].", ctx.exception.args[0]) - @mock.patch('solrorbit.client.RequestContextHolder.on_client_request_end') - @mock.patch('solrorbit.client.RequestContextHolder.on_client_request_start') + @mock.patch("solrorbit.client.RequestContextHolder.on_client_request_end") + @mock.patch("solrorbit.client.RequestContextHolder.on_client_request_start") @mock.patch("tests.worker_coordinator.runner_test._FakeOSClient") @run_async async def test_rejects_unsupported_operations(self, opensearch, on_client_request_start, on_client_request_end): - params = { - "requests": [ - { - "stream": [ - { - "operation-type": "bulk" - } - ] - } - ] - } + params = {"requests": [{"stream": [{"operation-type": "bulk"}]}]} r = runner.Composite() with self.assertRaises(exceptions.BenchmarkAssertionError) as ctx: @@ -712,23 +584,16 @@ def request_end(self): async def __aexit__(self, exc_type, exc_val, exc_tb): return False - @mock.patch('solrorbit.client.RequestContextHolder.on_client_request_end') - @mock.patch('solrorbit.client.RequestContextHolder.on_client_request_start') + @mock.patch("solrorbit.client.RequestContextHolder.on_client_request_end") + @mock.patch("solrorbit.client.RequestContextHolder.on_client_request_start") @mock.patch("tests.worker_coordinator.runner_test._FakeOSClient") @run_async async def test_merges_timing_info(self, opensearch, on_client_request_start, on_client_request_end): multi_cluster_client = {"default": opensearch} opensearch.new_request_context.return_value = RequestTimingTests.StaticRequestTiming(task_start=2) - delegate = mock.Mock(return_value=as_future({ - "weight": 5, - "unit": "ops", - "success": True - })) - params = { - "name": "unit-test-operation", - "operation-type": "test-op" - } + delegate = mock.Mock(return_value=as_future({"weight": 5, "unit": "ops", "success": True})) + params = {"name": "unit-test-operation", "operation-type": "test-op"} timer = runner.RequestTiming(delegate) response = await timer(multi_cluster_client, params) @@ -747,8 +612,8 @@ async def test_merges_timing_info(self, opensearch, on_client_request_start, on_ delegate.assert_called_once_with(multi_cluster_client, params) - @mock.patch('solrorbit.client.RequestContextHolder.on_client_request_end') - @mock.patch('solrorbit.client.RequestContextHolder.on_client_request_start') + @mock.patch("solrorbit.client.RequestContextHolder.on_client_request_end") + @mock.patch("solrorbit.client.RequestContextHolder.on_client_request_start") @mock.patch("tests.worker_coordinator.runner_test._FakeOSClient") @run_async async def test_creates_new_timing_info(self, opensearch, on_client_request_start, on_client_request_end): @@ -757,10 +622,7 @@ async def test_creates_new_timing_info(self, opensearch, on_client_request_start # a simple runner without a return value delegate = mock.Mock(return_value=as_future()) - params = { - "name": "unit-test-operation", - "operation-type": "test-op" - } + params = {"name": "unit-test-operation", "operation-type": "test-op"} timer = runner.RequestTiming(delegate) response = await timer(multi_cluster_client, params) @@ -830,12 +692,7 @@ async def test_is_transparent_on_application_error_when_no_retries(self): async def test_is_does_not_retry_on_success(self): delegate = mock.Mock(return_value=as_future()) opensearch = None - params = { - "retries": 3, - "retry-wait-period": 0.1, - "retry-on-timeout": True, - "retry-on-error": True - } + params = {"retries": 3, "retry-wait-period": 0.1, "retry-on-timeout": True, "retry-on-error": True} retrier = runner.Retry(delegate) await retrier(opensearch, params) @@ -844,56 +701,43 @@ async def test_is_does_not_retry_on_success(self): @run_async async def test_retries_on_timeout_if_wanted_and_raises_if_no_recovery(self): - delegate = mock.Mock(side_effect=[ - as_future(exception=exceptions.BenchmarkConnectionError("no route to host")), - as_future(exception=exceptions.BenchmarkConnectionError("no route to host")), - as_future(exception=exceptions.BenchmarkConnectionError("no route to host")), - as_future(exception=exceptions.BenchmarkConnectionError("no route to host")) - ]) + delegate = mock.Mock( + side_effect=[ + as_future(exception=exceptions.BenchmarkConnectionError("no route to host")), + as_future(exception=exceptions.BenchmarkConnectionError("no route to host")), + as_future(exception=exceptions.BenchmarkConnectionError("no route to host")), + as_future(exception=exceptions.BenchmarkConnectionError("no route to host")), + ] + ) opensearch = None - params = { - "retries": 3, - "retry-wait-period": 0.01, - "retry-on-timeout": True, - "retry-on-error": True - } + params = {"retries": 3, "retry-wait-period": 0.01, "retry-on-timeout": True, "retry-on-error": True} retrier = runner.Retry(delegate) with self.assertRaises(exceptions.BenchmarkConnectionError): await retrier(opensearch, params) - delegate.assert_has_calls([ - mock.call(opensearch, params), - mock.call(opensearch, params), - mock.call(opensearch, params) - ]) + delegate.assert_has_calls([mock.call(opensearch, params), mock.call(opensearch, params), mock.call(opensearch, params)]) @run_async async def test_retries_on_timeout_if_wanted_and_returns_first_call(self): failed_return_value = {"weight": 1, "unit": "ops", "success": False} - delegate = mock.Mock(side_effect=[ - as_future(exception=exceptions.BenchmarkConnectionError("no route to host")), - as_future(failed_return_value) - ]) + delegate = mock.Mock(side_effect=[as_future(exception=exceptions.BenchmarkConnectionError("no route to host")), as_future(failed_return_value)]) opensearch = None - params = { - "retries": 3, - "retry-wait-period": 0.01, - "retry-on-timeout": True, - "retry-on-error": False - } + params = {"retries": 3, "retry-wait-period": 0.01, "retry-on-timeout": True, "retry-on-error": False} retrier = runner.Retry(delegate) result = await retrier(opensearch, params) self.assertEqual(failed_return_value, result) - delegate.assert_has_calls([ - # has returned a connection error - mock.call(opensearch, params), - # has returned normally - mock.call(opensearch, params) - ]) + delegate.assert_has_calls( + [ + # has returned a connection error + mock.call(opensearch, params), + # has returned normally + mock.call(opensearch, params), + ] + ) @run_async async def test_retries_mixed_timeout_and_application_errors(self): @@ -901,52 +745,51 @@ async def test_retries_mixed_timeout_and_application_errors(self): failed_return_value = {"weight": 1, "unit": "ops", "success": False} success_return_value = {"weight": 1, "unit": "ops", "success": False} - delegate = mock.Mock(side_effect=[ - as_future(exception=connection_error), - as_future(failed_return_value), - as_future(exception=connection_error), - as_future(exception=connection_error), - as_future(failed_return_value), - as_future(success_return_value) - ]) + delegate = mock.Mock( + side_effect=[ + as_future(exception=connection_error), + as_future(failed_return_value), + as_future(exception=connection_error), + as_future(exception=connection_error), + as_future(failed_return_value), + as_future(success_return_value), + ] + ) opensearch = None params = { # we try exactly as often as there are errors to also test the semantics of "retry". "retries": 5, "retry-wait-period": 0.01, "retry-on-timeout": True, - "retry-on-error": True + "retry-on-error": True, } retrier = runner.Retry(delegate) result = await retrier(opensearch, params) self.assertEqual(success_return_value, result) - delegate.assert_has_calls([ - # connection error - mock.call(opensearch, params), - # application error - mock.call(opensearch, params), - # connection error - mock.call(opensearch, params), - # connection error - mock.call(opensearch, params), - # application error - mock.call(opensearch, params), - # success - mock.call(opensearch, params) - ]) + delegate.assert_has_calls( + [ + # connection error + mock.call(opensearch, params), + # application error + mock.call(opensearch, params), + # connection error + mock.call(opensearch, params), + # connection error + mock.call(opensearch, params), + # application error + mock.call(opensearch, params), + # success + mock.call(opensearch, params), + ] + ) @run_async async def test_does_not_retry_on_timeout_if_not_wanted(self): delegate = mock.Mock(side_effect=as_future(exception=exceptions.BenchmarkConnectionTimeout("timed out"))) opensearch = None - params = { - "retries": 3, - "retry-wait-period": 0.01, - "retry-on-timeout": False, - "retry-on-error": True - } + params = {"retries": 3, "retry-wait-period": 0.01, "retry-on-timeout": False, "retry-on-error": True} retrier = runner.Retry(delegate) with self.assertRaises(exceptions.BenchmarkConnectionTimeout): @@ -959,28 +802,22 @@ async def test_retries_on_application_error_if_wanted(self): failed_return_value = {"weight": 1, "unit": "ops", "success": False} success_return_value = {"weight": 1, "unit": "ops", "success": True} - delegate = mock.Mock(side_effect=[ - as_future(failed_return_value), - as_future(success_return_value) - ]) + delegate = mock.Mock(side_effect=[as_future(failed_return_value), as_future(success_return_value)]) opensearch = None - params = { - "retries": 3, - "retry-wait-period": 0.01, - "retry-on-timeout": False, - "retry-on-error": True - } + params = {"retries": 3, "retry-wait-period": 0.01, "retry-on-timeout": False, "retry-on-error": True} retrier = runner.Retry(delegate) result = await retrier(opensearch, params) self.assertEqual(success_return_value, result) - delegate.assert_has_calls([ - mock.call(opensearch, params), - # one retry - mock.call(opensearch, params) - ]) + delegate.assert_has_calls( + [ + mock.call(opensearch, params), + # one retry + mock.call(opensearch, params), + ] + ) @run_async async def test_does_not_retry_on_application_error_if_not_wanted(self): @@ -988,12 +825,7 @@ async def test_does_not_retry_on_application_error_if_not_wanted(self): delegate = mock.Mock(return_value=as_future(failed_return_value)) opensearch = None - params = { - "retries": 3, - "retry-wait-period": 0.01, - "retry-on-timeout": True, - "retry-on-error": False - } + params = {"retries": 3, "retry-wait-period": 0.01, "retry-on-timeout": True, "retry-on-error": False} retrier = runner.Retry(delegate) result = await retrier(opensearch, params) @@ -1006,12 +838,7 @@ async def test_does_not_retry_on_application_error_if_not_wanted(self): async def test_assumes_success_if_runner_returns_non_dict(self): delegate = mock.Mock(return_value=as_future(result=(1, "ops"))) opensearch = None - params = { - "retries": 3, - "retry-wait-period": 0.01, - "retry-on-timeout": True, - "retry-on-error": True - } + params = {"retries": 3, "retry-wait-period": 0.01, "retry-on-timeout": True, "retry-on-error": True} retrier = runner.Retry(delegate) result = await retrier(opensearch, params) @@ -1033,10 +860,7 @@ async def test_retries_until_success(self): delegate = mock.Mock(side_effect=responses) opensearch = None - params = { - "retry-until-success": True, - "retry-wait-period": 0.01 - } + params = {"retry-until-success": True, "retry-wait-period": 0.01} retrier = runner.Retry(delegate) result = await retrier(opensearch, params) diff --git a/tests/worker_coordinator/scheduler_test.py b/tests/worker_coordinator/scheduler_test.py index 038463e8..5ed8c151 100644 --- a/tests/worker_coordinator/scheduler_test.py +++ b/tests/worker_coordinator/scheduler_test.py @@ -16,7 +16,7 @@ # not use this file except in compliance with the License. # You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an @@ -50,10 +50,8 @@ def assertThroughputEquals(self, sched, expected_average_throughput, msg="", rel expected_lower_bound = (1.0 - relative_delta) * expected_average_rate expected_upper_bound = (1.0 + relative_delta) * expected_average_rate - self.assertGreaterEqual(actual_average_rate, expected_lower_bound, - f"{msg}: expected target rate to be >= [{expected_lower_bound}] but was [{actual_average_rate}].") - self.assertLessEqual(actual_average_rate, expected_upper_bound, - f"{msg}: expected target rate to be <= [{expected_upper_bound}] but was [{actual_average_rate}].") + self.assertGreaterEqual(actual_average_rate, expected_lower_bound, f"{msg}: expected target rate to be >= [{expected_lower_bound}] but was [{actual_average_rate}].") + self.assertLessEqual(actual_average_rate, expected_upper_bound, f"{msg}: expected target rate to be <= [{expected_upper_bound}] but was [{actual_average_rate}].") class DeterministicSchedulerTests(SchedulerTestCase): @@ -75,30 +73,25 @@ def test_schedule_matches_expected_target_throughput(self): class UnitAwareSchedulerTests(TestCase): def test_scheduler_rejects_differing_throughput_units(self): - task = workload.Task(name="bulk-index", - operation=workload.Operation( - name="bulk-index", - operation_type=workload.OperationType.Bulk.to_hyphenated_string()), - clients=4, - params={ - "target-throughput": "5000 MB/s" - }) + task = workload.Task( + name="bulk-index", + operation=workload.Operation(name="bulk-index", operation_type=workload.OperationType.Bulk.to_hyphenated_string()), + clients=4, + params={"target-throughput": "5000 MB/s"}, + ) s = scheduler.UnitAwareScheduler(task=task, scheduler_class=scheduler.DeterministicScheduler) with self.assertRaises(exceptions.BenchmarkAssertionError) as ex: s.after_request(now=None, weight=1000, unit="docs", request_meta_data=None) - self.assertEqual("Target throughput for [bulk-index] is specified in [MB/s] but the task throughput " - "is measured in [docs/s].", ex.exception.args[0]) + self.assertEqual("Target throughput for [bulk-index] is specified in [MB/s] but the task throughput is measured in [docs/s].", ex.exception.args[0]) def test_scheduler_adapts_to_changed_weights(self): - task = workload.Task(name="bulk-index", - operation=workload.Operation( - name="bulk-index", - operation_type=workload.OperationType.Bulk.to_hyphenated_string()), - clients=4, - params={ - "target-throughput": "5000 docs/s" - }) + task = workload.Task( + name="bulk-index", + operation=workload.Operation(name="bulk-index", operation_type=workload.OperationType.Bulk.to_hyphenated_string()), + clients=4, + params={"target-throughput": "5000 docs/s"}, + ) s = scheduler.UnitAwareScheduler(task=task, scheduler_class=scheduler.DeterministicScheduler) # first request is unthrottled @@ -118,15 +111,15 @@ def test_scheduler_adapts_to_changed_weights(self): self.assertEqual(2 * task.clients, s.next(0)) def test_scheduler_accepts_differing_units_pages_and_ops(self): - task = workload.Task(name="scroll-query", - operation=workload.Operation( - name="scroll-query", - operation_type=workload.OperationType.Search.to_hyphenated_string()), - clients=1, - params={ - # implicitly: ops/s - "target-throughput": 10 - }) + task = workload.Task( + name="scroll-query", + operation=workload.Operation(name="scroll-query", operation_type=workload.OperationType.Search.to_hyphenated_string()), + clients=1, + params={ + # implicitly: ops/s + "target-throughput": 10 + }, + ) s = scheduler.UnitAwareScheduler(task=task, scheduler_class=scheduler.DeterministicScheduler) # first request is unthrottled @@ -141,15 +134,15 @@ def test_scheduler_accepts_differing_units_pages_and_ops(self): self.assertEqual(0.1 * task.clients, s.next(0)) def test_scheduler_does_not_change_throughput_for_empty_requests(self): - task = workload.Task(name="match-all-query", - operation=workload.Operation( - name="query", - operation_type=workload.OperationType.Search.to_hyphenated_string()), - clients=1, - params={ - # implicitly: ops/s - "target-throughput": 10 - }) + task = workload.Task( + name="match-all-query", + operation=workload.Operation(name="query", operation_type=workload.OperationType.Search.to_hyphenated_string()), + clients=1, + params={ + # implicitly: ops/s + "target-throughput": 10 + }, + ) s = scheduler.UnitAwareScheduler(task=task, scheduler_class=scheduler.DeterministicScheduler) # first request is unthrottled... @@ -240,12 +233,9 @@ def tearDown(self): scheduler.remove_scheduler("simple") def test_legacy_scheduler(self): - task = workload.Task(name="raw-request", - operation=workload.Operation( - name="raw", - operation_type=workload.OperationType.RawRequest.to_hyphenated_string()), - clients=1, - schedule="simple") + task = workload.Task( + name="raw-request", operation=workload.Operation(name="raw", operation_type=workload.OperationType.RawRequest.to_hyphenated_string()), clients=1, schedule="simple" + ) s = scheduler.scheduler_for(task) diff --git a/tests/worker_coordinator/worker_coordinator_test.py b/tests/worker_coordinator/worker_coordinator_test.py index b5108899..e4b4c35c 100644 --- a/tests/worker_coordinator/worker_coordinator_test.py +++ b/tests/worker_coordinator/worker_coordinator_test.py @@ -16,7 +16,7 @@ # not use this file except in compliance with the License. # You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an @@ -71,7 +71,7 @@ def partition(self, partition_index, total_partitions): def task_progress(self): if self.infinite: return None - return (self._current / self._total, '%') + return (self._current / self._total, "%") def params(self): if not self.infinite and self._current > self._total: @@ -86,9 +86,9 @@ def __init__(self, all_hosts=None, all_client_options=None): self.all_hosts = all_hosts self.all_client_options = all_client_options self.uses_static_responses = False - self.default = all_client_options.get('default', {}) if all_client_options else {} + self.default = all_client_options.get("default", {}) if all_client_options else {} - def __init__(self, methodName='runTest'): + def __init__(self, methodName="runTest"): super().__init__(methodName) self.cfg = None self.workload = None @@ -120,30 +120,24 @@ def setUp(self): self.cfg.add(config.Scope.application, "telemetry", "params", {"ccr-stats-indices": {"default": ["leader_index"]}}) self.cfg.add(config.Scope.application, "builder", "cluster_config.names", ["default"]) self.cfg.add(config.Scope.application, "builder", "skip.rest.api.check", True) - self.cfg.add(config.Scope.application, "client", "hosts", - WorkerCoordinatorTests.Holder(all_hosts={"default": ["localhost:9200"]})) + self.cfg.add(config.Scope.application, "client", "hosts", WorkerCoordinatorTests.Holder(all_hosts={"default": ["localhost:9200"]})) self.cfg.add(config.Scope.application, "client", "options", WorkerCoordinatorTests.Holder(all_client_options={"default": {}})) self.cfg.add(config.Scope.application, "worker_coordinator", "worker_ips", ["localhost"]) self.cfg.add(config.Scope.application, "reporting", "datastore.type", "in-memory") self.cfg.add(config.Scope.applicationOverride, "workload", "redline.max_cpu_usage", None) - default_test_procedure = workload.TestProcedure("default", default=True, schedule=[ - workload.Task(name="index", operation=workload.Operation("index", operation_type=workload.OperationType.Bulk), clients=4) - ]) + default_test_procedure = workload.TestProcedure( + "default", default=True, schedule=[workload.Task(name="index", operation=workload.Operation("index", operation_type=workload.OperationType.Bulk), clients=4)] + ) another_test_procedure = workload.TestProcedure("other", default=False) - self.workload = workload.Workload( - name="unittest", - description="unittest workload", - test_procedures=[another_test_procedure, default_test_procedure]) + self.workload = workload.Workload(name="unittest", description="unittest workload", test_procedures=[another_test_procedure, default_test_procedure]) def tearDown(self): WorkerCoordinatorTests.StaticClientFactory.close() def create_test_worker_coordinator_target(self): client = "client_marker" - attrs = { - "create_client.return_value": client - } + attrs = {"create_client.return_value": client} return mock.Mock(**attrs) @mock.patch("solrorbit.utils.net.resolve") @@ -159,12 +153,14 @@ def test_start_benchmark_and_prepare_workload(self, resolve): target.prepare_workload.assert_called_once_with(["10.5.5.1", "10.5.5.2"], self.cfg, self.workload) d.start_benchmark() - target.create_client.assert_has_calls(calls=[ - mock.call("10.5.5.1"), - mock.call("10.5.5.1"), - mock.call("10.5.5.2"), - mock.call("10.5.5.2"), - ]) + target.create_client.assert_has_calls( + calls=[ + mock.call("10.5.5.1"), + mock.call("10.5.5.1"), + mock.call("10.5.5.2"), + mock.call("10.5.5.2"), + ] + ) # Did we start all load generators? There is no specific mock assert for this... self.assertEqual(4, target.start_worker.call_count) @@ -179,12 +175,14 @@ def test_assign_worker_coordinators_round_robin(self): d.start_benchmark() - target.create_client.assert_has_calls(calls=[ - mock.call("localhost"), - mock.call("localhost"), - mock.call("localhost"), - mock.call("localhost"), - ]) + target.create_client.assert_has_calls( + calls=[ + mock.call("localhost"), + mock.call("localhost"), + mock.call("localhost"), + mock.call("localhost"), + ] + ) # Did we start all load generators? There is no specific mock assert for this... self.assertEqual(4, target.start_worker.call_count) @@ -198,9 +196,7 @@ def test_client_reaches_join_point_others_still_executing(self): self.assertEqual(0, len(d.workers_completed_current_step)) - d.joinpoint_reached(worker_id=0, - worker_local_timestamp=10, - task_allocations=[worker_coordinator.ClientAllocation(client_id=0, task=worker_coordinator.JoinPoint(id=0))]) + d.joinpoint_reached(worker_id=0, worker_local_timestamp=10, task_allocations=[worker_coordinator.ClientAllocation(client_id=0, task=worker_coordinator.JoinPoint(id=0))]) self.assertEqual(1, len(d.workers_completed_current_step)) @@ -216,12 +212,11 @@ def test_client_reaches_join_point_which_completes_parent(self): self.assertEqual(0, len(d.workers_completed_current_step)) - d.joinpoint_reached(worker_id=0, - worker_local_timestamp=10, - task_allocations=[ - worker_coordinator.ClientAllocation(client_id=0, - task=worker_coordinator.JoinPoint(id=0, - clients_executing_completing_task=[0]))]) + d.joinpoint_reached( + worker_id=0, + worker_local_timestamp=10, + task_allocations=[worker_coordinator.ClientAllocation(client_id=0, task=worker_coordinator.JoinPoint(id=0, clients_executing_completing_task=[0]))], + ) self.assertEqual(-1, d.current_step) self.assertEqual(1, len(d.workers_completed_current_step)) @@ -229,31 +224,28 @@ def test_client_reaches_join_point_which_completes_parent(self): self.assertEqual(4, target.complete_current_task.call_count) # awaiting responses of other clients - d.joinpoint_reached(worker_id=1, - worker_local_timestamp=11, - task_allocations=[ - worker_coordinator.ClientAllocation(client_id=1, - task=worker_coordinator.JoinPoint(id=0, - clients_executing_completing_task=[0]))]) + d.joinpoint_reached( + worker_id=1, + worker_local_timestamp=11, + task_allocations=[worker_coordinator.ClientAllocation(client_id=1, task=worker_coordinator.JoinPoint(id=0, clients_executing_completing_task=[0]))], + ) self.assertEqual(-1, d.current_step) self.assertEqual(2, len(d.workers_completed_current_step)) - d.joinpoint_reached(worker_id=2, - worker_local_timestamp=12, - task_allocations=[ - worker_coordinator.ClientAllocation(client_id=2, - task=worker_coordinator.JoinPoint(id=0, - clients_executing_completing_task=[0]))]) + d.joinpoint_reached( + worker_id=2, + worker_local_timestamp=12, + task_allocations=[worker_coordinator.ClientAllocation(client_id=2, task=worker_coordinator.JoinPoint(id=0, clients_executing_completing_task=[0]))], + ) self.assertEqual(-1, d.current_step) self.assertEqual(3, len(d.workers_completed_current_step)) - d.joinpoint_reached(worker_id=3, - worker_local_timestamp=13, - task_allocations=[ - worker_coordinator.ClientAllocation(client_id=3, - task=worker_coordinator.JoinPoint(id=0, - clients_executing_completing_task=[0]))]) + d.joinpoint_reached( + worker_id=3, + worker_local_timestamp=13, + task_allocations=[worker_coordinator.ClientAllocation(client_id=3, task=worker_coordinator.JoinPoint(id=0, clients_executing_completing_task=[0]))], + ) # by now the previous step should be considered completed and we are at the next one self.assertEqual(0, d.current_step) @@ -271,8 +263,7 @@ async def test_load_test_clients_override(self): task = self.workload.find_test_procedure_or_default("default").schedule[0] original_clients = task.clients - d = worker_coordinator.WorkerCoordinator(self.create_test_worker_coordinator_target(), self.cfg, - client_factory_class=WorkerCoordinatorTests.StaticClientFactory) + d = worker_coordinator.WorkerCoordinator(self.create_test_worker_coordinator_target(), self.cfg, client_factory_class=WorkerCoordinatorTests.StaticClientFactory) d.prepare_benchmark(t=self.workload) d.start_benchmark() @@ -289,16 +280,18 @@ def op(name, operation_type): class SamplePostprocessorTests(TestCase): def throughput(self, absolute_time, relative_time, value): - return mock.call(name="throughput", - value=value, - unit="docs/s", - task="index", - operation="index-op", - operation_type="bulk", - sample_type=metrics.SampleType.Normal, - absolute_time=absolute_time, - relative_time=relative_time, - meta_data={}) + return mock.call( + name="throughput", + value=value, + unit="docs/s", + task="index", + operation="index-op", + operation_type="bulk", + sample_type=metrics.SampleType.Normal, + absolute_time=absolute_time, + relative_time=relative_time, + meta_data={}, + ) def service_time(self, absolute_time, relative_time, value): return self.request_metric(absolute_time, relative_time, "service_time", value) @@ -313,41 +306,40 @@ def latency(self, absolute_time, relative_time, value): return self.request_metric(absolute_time, relative_time, "latency", value) def request_metric(self, absolute_time, relative_time, name, value): - return mock.call(name=name, - value=value, - unit="ms", - task="index", - operation="index-op", - operation_type="bulk", - sample_type=metrics.SampleType.Normal, - absolute_time=absolute_time, - relative_time=relative_time, - meta_data={}) + return mock.call( + name=name, + value=value, + unit="ms", + task="index", + operation="index-op", + operation_type="bulk", + sample_type=metrics.SampleType.Normal, + absolute_time=absolute_time, + relative_time=relative_time, + meta_data={}, + ) @mock.patch("solrorbit.metrics.MetricsStore") def test_all_samples(self, metrics_store): - post_process = worker_coordinator.DefaultSamplePostprocessor(metrics_store, - downsample_factor=1, - workload_meta_data={}, - test_procedure_meta_data={}) + post_process = worker_coordinator.DefaultSamplePostprocessor(metrics_store, downsample_factor=1, workload_meta_data={}, test_procedure_meta_data={}) task = workload.Task("index", workload.Operation("index-op", "bulk", param_source="worker-coordinator-test-param-source")) samples = [ - worker_coordinator.DefaultSample( - 0, 38598, 24, 0, task, metrics.SampleType.Normal, - None, 0.01, 0.007, 0.0007, 0.009, None, 5000, "docs", 1, 1 / 2), - worker_coordinator.DefaultSample( - 0, 38599, 25, 0, task, metrics.SampleType.Normal, - None, 0.01, 0.007, 0.0007, 0.009, None, 5000, "docs", 2, 2 / 2), + worker_coordinator.DefaultSample(0, 38598, 24, 0, task, metrics.SampleType.Normal, None, 0.01, 0.007, 0.0007, 0.009, None, 5000, "docs", 1, 1 / 2), + worker_coordinator.DefaultSample(0, 38599, 25, 0, task, metrics.SampleType.Normal, None, 0.01, 0.007, 0.0007, 0.009, None, 5000, "docs", 2, 2 / 2), ] post_process(samples) calls = [ - self.latency(38598, 24, 10.0), self.service_time(38598, 24, 7.0), - self.client_processing_time(38598, 24, 0.7), self.processing_time(38598, 24, 9.0), - self.latency(38599, 25, 10.0), self.service_time(38599, 25, 7.0), - self.client_processing_time(38599, 25, 0.7), self.processing_time(38599, 25, 9.0), + self.latency(38598, 24, 10.0), + self.service_time(38598, 24, 7.0), + self.client_processing_time(38598, 24, 0.7), + self.processing_time(38598, 24, 9.0), + self.latency(38599, 25, 10.0), + self.service_time(38599, 25, 7.0), + self.client_processing_time(38599, 25, 0.7), + self.processing_time(38599, 25, 9.0), self.throughput(38598, 24, 5000), self.throughput(38599, 25, 5000), ] @@ -355,28 +347,23 @@ def test_all_samples(self, metrics_store): @mock.patch("solrorbit.metrics.MetricsStore") def test_downsamples(self, metrics_store): - post_process = worker_coordinator.DefaultSamplePostprocessor(metrics_store, - downsample_factor=2, - workload_meta_data={}, - test_procedure_meta_data={}) + post_process = worker_coordinator.DefaultSamplePostprocessor(metrics_store, downsample_factor=2, workload_meta_data={}, test_procedure_meta_data={}) task = workload.Task("index", workload.Operation("index-op", "bulk", param_source="worker-coordinator-test-param-source")) samples = [ - worker_coordinator.DefaultSample( - 0, 38598, 24, 0, task, metrics.SampleType.Normal, - None, 0.01, 0.007, 0.0007, 0.009, None, 5000, "docs", 1, 1 / 2), - worker_coordinator.DefaultSample( - 0, 38599, 25, 0, task, metrics.SampleType.Normal, - None, 0.01, 0.007, 0.0007, 0.009, None, 5000, "docs", 2, 2 / 2), + worker_coordinator.DefaultSample(0, 38598, 24, 0, task, metrics.SampleType.Normal, None, 0.01, 0.007, 0.0007, 0.009, None, 5000, "docs", 1, 1 / 2), + worker_coordinator.DefaultSample(0, 38599, 25, 0, task, metrics.SampleType.Normal, None, 0.01, 0.007, 0.0007, 0.009, None, 5000, "docs", 2, 2 / 2), ] post_process(samples) calls = [ # only the first out of two request samples is included, throughput metrics are still complete - self.latency(38598, 24, 10.0), self.service_time(38598, 24, 7.0), - self.client_processing_time(38598, 24, 0.7), self.processing_time(38598, 24, 9.0), + self.latency(38598, 24, 10.0), + self.service_time(38598, 24, 7.0), + self.client_processing_time(38598, 24, 0.7), + self.processing_time(38598, 24, 9.0), self.throughput(38598, 24, 5000), self.throughput(38599, 25, 5000), ] @@ -384,39 +371,41 @@ def test_downsamples(self, metrics_store): @mock.patch("solrorbit.metrics.MetricsStore") def test_dependent_samples(self, metrics_store): - post_process = worker_coordinator.DefaultSamplePostprocessor(metrics_store, - downsample_factor=1, - workload_meta_data={}, - test_procedure_meta_data={}) + post_process = worker_coordinator.DefaultSamplePostprocessor(metrics_store, downsample_factor=1, workload_meta_data={}, test_procedure_meta_data={}) task = workload.Task("index", workload.Operation("index-op", "bulk", param_source="worker-coordinator-test-param-source")) samples = [ worker_coordinator.DefaultSample( - 0, 38598, 24, 0, task, metrics.SampleType.Normal, - None, 0.01, 0.007, 0.0007, 0.009, None, 5000, "docs", 1, 1 / 2, - dependent_timing=[ - { - "absolute_time": 38601, - "request_start": 25, - "service_time": 0.05, - "operation": "index-op", - "operation-type": "bulk" - }, - { - "absolute_time": 38602, - "request_start": 26, - "service_time": 0.08, - "operation": "index-op", - "operation-type": "bulk" - } - ]), + 0, + 38598, + 24, + 0, + task, + metrics.SampleType.Normal, + None, + 0.01, + 0.007, + 0.0007, + 0.009, + None, + 5000, + "docs", + 1, + 1 / 2, + dependent_timing=[ + {"absolute_time": 38601, "request_start": 25, "service_time": 0.05, "operation": "index-op", "operation-type": "bulk"}, + {"absolute_time": 38602, "request_start": 26, "service_time": 0.08, "operation": "index-op", "operation-type": "bulk"}, + ], + ), ] post_process(samples) calls = [ - self.latency(38598, 24, 10.0), self.service_time(38598, 24, 7.0), - self.client_processing_time(38598, 24, 0.7), self.processing_time(38598, 24, 9.0), + self.latency(38598, 24, 10.0), + self.service_time(38598, 24, 7.0), + self.client_processing_time(38598, 24, 0.7), + self.processing_time(38598, 24, 9.0), # dependent timings self.service_time(38601, 25, 50.0), self.service_time(38602, 26, 80.0), @@ -427,182 +416,53 @@ def test_dependent_samples(self, metrics_store): class WorkerAssignmentTests(TestCase): def test_single_host_assignment_clients_matches_cores(self): - host_configs = [{ - "host": "localhost", - "cores": 4 - }] + host_configs = [{"host": "localhost", "cores": 4}] assignments = worker_coordinator.calculate_worker_assignments(host_configs, client_count=4) - self.assertEqual([ - { - "host": "localhost", - "workers": [ - [0], - [1], - [2], - [3] - ] - } - ], assignments) + self.assertEqual([{"host": "localhost", "workers": [[0], [1], [2], [3]]}], assignments) def test_single_host_assignment_more_clients_than_cores(self): - host_configs = [{ - "host": "localhost", - "cores": 4 - }] + host_configs = [{"host": "localhost", "cores": 4}] assignments = worker_coordinator.calculate_worker_assignments(host_configs, client_count=6) - self.assertEqual([ - { - "host": "localhost", - "workers": [ - [0, 1], - [2, 3], - [4], - [5] - ] - } - ], assignments) + self.assertEqual([{"host": "localhost", "workers": [[0, 1], [2, 3], [4], [5]]}], assignments) def test_single_host_assignment_less_clients_than_cores(self): - host_configs = [{ - "host": "localhost", - "cores": 4 - }] + host_configs = [{"host": "localhost", "cores": 4}] assignments = worker_coordinator.calculate_worker_assignments(host_configs, client_count=2) - self.assertEqual([ - { - "host": "localhost", - "workers": [ - [0], - [1], - [], - [] - ] - } - ], assignments) + self.assertEqual([{"host": "localhost", "workers": [[0], [1], [], []]}], assignments) def test_multiple_host_assignment_more_clients_than_cores(self): - host_configs = [ - { - "host": "host-a", - "cores": 4 - }, - { - "host": "host-b", - "cores": 4 - } - ] + host_configs = [{"host": "host-a", "cores": 4}, {"host": "host-b", "cores": 4}] assignments = worker_coordinator.calculate_worker_assignments(host_configs, client_count=16) - self.assertEqual([ - { - "host": "host-a", - "workers": [ - [0, 1], - [2, 3], - [4, 5], - [6, 7] - ] - }, - { - "host": "host-b", - "workers": [ - [8, 9], - [10, 11], - [12, 13], - [14, 15] - ] - } - ], assignments) + self.assertEqual([{"host": "host-a", "workers": [[0, 1], [2, 3], [4, 5], [6, 7]]}, {"host": "host-b", "workers": [[8, 9], [10, 11], [12, 13], [14, 15]]}], assignments) def test_multiple_host_assignment_less_clients_than_cores(self): - host_configs = [ - { - "host": "host-a", - "cores": 4 - }, - { - "host": "host-b", - "cores": 4 - } - ] + host_configs = [{"host": "host-a", "cores": 4}, {"host": "host-b", "cores": 4}] assignments = worker_coordinator.calculate_worker_assignments(host_configs, client_count=4) - self.assertEqual([ - { - "host": "host-a", - "workers": [ - [0], - [1], - [], - [] - ] - }, - { - "host": "host-b", - "workers": [ - [2], - [3], - [], - [] - ] - } - ], assignments) + self.assertEqual([{"host": "host-a", "workers": [[0], [1], [], []]}, {"host": "host-b", "workers": [[2], [3], [], []]}], assignments) def test_uneven_assignment_across_hosts(self): - host_configs = [ - { - "host": "host-a", - "cores": 4 - }, - { - "host": "host-b", - "cores": 4 - }, - { - "host": "host-c", - "cores": 4 - } - ] + host_configs = [{"host": "host-a", "cores": 4}, {"host": "host-b", "cores": 4}, {"host": "host-c", "cores": 4}] assignments = worker_coordinator.calculate_worker_assignments(host_configs, client_count=17) - self.assertEqual([ - { - "host": "host-a", - "workers": [ - [0, 1], - [2, 3], - [4], - [5] - ] - }, - { - "host": "host-b", - "workers": [ - [6, 7], - [8, 9], - [10], - [11] - ] - }, - { - "host": "host-c", - "workers": [ - [12, 13], - [14], - [15], - [16] - ] - } - ], assignments) + self.assertEqual( + [ + {"host": "host-a", "workers": [[0, 1], [2, 3], [4], [5]]}, + {"host": "host-b", "workers": [[6, 7], [8, 9], [10], [11]]}, + {"host": "host-c", "workers": [[12, 13], [14], [15], [16]]}, + ], + assignments, + ) class AllocatorTests(TestCase): @@ -610,9 +470,9 @@ def setUp(self): params.register_param_source_for_name("worker-coordinator-test-param-source", WorkerCoordinatorTestParamSource) def ta(self, task, client_index_in_task, global_client_index=None, total_clients=None): - return worker_coordinator.TaskAllocation(task, client_index_in_task, - client_index_in_task if global_client_index is None else global_client_index, - task.clients if total_clients is None else total_clients) + return worker_coordinator.TaskAllocation( + task, client_index_in_task, client_index_in_task if global_client_index is None else global_client_index, task.clients if total_clients is None else total_clients + ) def test_allocates_one_task(self): task = workload.Task("index", op("index", workload.OperationType.Bulk)) @@ -670,11 +530,7 @@ def test_allocates_mixed_tasks(self): stats = workload.Task("stats", op("stats", workload.OperationType.Sleep)) search = workload.Task("search", op("search", workload.OperationType.Search)) - allocator = worker_coordinator.Allocator([index, - workload.Parallel([index, stats, stats]), - index, - index, - workload.Parallel([search, search, search])]) + allocator = worker_coordinator.Allocator([index, workload.Parallel([index, stats, stats]), index, index, workload.Parallel([search, search, search])]) self.assertEqual(3, allocator.clients) @@ -706,24 +562,28 @@ def test_allocates_more_tasks_than_clients(self): # join_point, index_a, index_c, index_e, join_point self.assertEqual(5, len(allocations[0])) # we really have no chance to extract the join point so we just take what is there... - self.assertEqual([allocations[0][0], - self.ta(index_a, client_index_in_task=0, - global_client_index=0, total_clients=2), - self.ta(index_c, client_index_in_task=0, - global_client_index=2, total_clients=2), - self.ta(index_e, client_index_in_task=0, - global_client_index=4, total_clients=2), - allocations[0][4]], - allocations[0]) + self.assertEqual( + [ + allocations[0][0], + self.ta(index_a, client_index_in_task=0, global_client_index=0, total_clients=2), + self.ta(index_c, client_index_in_task=0, global_client_index=2, total_clients=2), + self.ta(index_e, client_index_in_task=0, global_client_index=4, total_clients=2), + allocations[0][4], + ], + allocations[0], + ) # join_point, index_a, index_c, None, join_point self.assertEqual(5, len(allocator.allocations[1])) - self.assertEqual([allocations[1][0], - self.ta(index_b, client_index_in_task=0, - global_client_index=1, total_clients=2), - self.ta(index_d, client_index_in_task=0, - global_client_index=3, total_clients=2), - None, allocations[1][4]], - allocations[1]) + self.assertEqual( + [ + allocations[1][0], + self.ta(index_b, client_index_in_task=0, global_client_index=1, total_clients=2), + self.ta(index_d, client_index_in_task=0, global_client_index=3, total_clients=2), + None, + allocations[1][4], + ], + allocations[1], + ) self.assertEqual([{index_a, index_b, index_c, index_d, index_e}], allocator.tasks_per_joinpoint) self.assertEqual(2, len(allocator.join_points)) @@ -750,32 +610,24 @@ def test_considers_number_of_clients_per_subtask(self): # join_point, index_a, index_c, join_point self.assertEqual(4, len(allocations[0])) # we really have no chance to extract the join point so we just take what is there... - self.assertEqual([allocations[0][0], - self.ta(index_a, client_index_in_task=0, - global_client_index=0, total_clients=3), - self.ta(index_c, client_index_in_task=1, - global_client_index=3, total_clients=3), - allocations[0][3]], - allocations[0]) + self.assertEqual( + [ + allocations[0][0], + self.ta(index_a, client_index_in_task=0, global_client_index=0, total_clients=3), + self.ta(index_c, client_index_in_task=1, global_client_index=3, total_clients=3), + allocations[0][3], + ], + allocations[0], + ) # task that client 1 will execute: # join_point, index_b, None, join_point self.assertEqual(4, len(allocator.allocations[1])) - self.assertEqual([allocations[1][0], - self.ta(index_b, client_index_in_task=0, - global_client_index=1, total_clients=3), - None, - allocations[1][3]], - allocations[1]) + self.assertEqual([allocations[1][0], self.ta(index_b, client_index_in_task=0, global_client_index=1, total_clients=3), None, allocations[1][3]], allocations[1]) # tasks that client 2 will execute: self.assertEqual(4, len(allocator.allocations[2])) - self.assertEqual([allocations[2][0], - self.ta(index_c, client_index_in_task=0, - global_client_index=2, total_clients=3), - None, - allocations[2][3]], - allocations[2]) + self.assertEqual([allocations[2][0], self.ta(index_c, client_index_in_task=0, global_client_index=2, total_clients=3), None, allocations[2][3]], allocations[2]) self.assertEqual([{index_a, index_b, index_c}], allocator.tasks_per_joinpoint) @@ -795,10 +647,8 @@ def test_different_sample_types(self): op = workload.Operation("index", workload.OperationType.Bulk, param_source="worker-coordinator-test-param-source") samples = [ - worker_coordinator.DefaultSample(0, 1470838595, 21, 0, op, metrics.SampleType.Warmup, - None, -1, -1, -1, -1, None, 3000, "docs", 1, 1), - worker_coordinator.DefaultSample(0, 1470838595.5, 21.5, 0, op, metrics.SampleType.Normal, - None, -1, -1, -1, -1, None, 2500, "docs", 1, 1), + worker_coordinator.DefaultSample(0, 1470838595, 21, 0, op, metrics.SampleType.Warmup, None, -1, -1, -1, -1, None, 3000, "docs", 1, 1), + worker_coordinator.DefaultSample(0, 1470838595.5, 21.5, 0, op, metrics.SampleType.Normal, None, -1, -1, -1, -1, None, 2500, "docs", 1, 1), ] aggregated = self.calculate_global_throughput(samples) @@ -821,12 +671,9 @@ def test_single_metrics_aggregation(self): worker_coordinator.DefaultSample(0, 38598, 24, 0, op, metrics.SampleType.Normal, None, -1, -1, -1, -1, None, 5000, "docs", 4, 4 / 9), worker_coordinator.DefaultSample(0, 38599, 25, 0, op, metrics.SampleType.Normal, None, -1, -1, -1, -1, None, 5000, "docs", 5, 5 / 9), worker_coordinator.DefaultSample(0, 38600, 26, 0, op, metrics.SampleType.Normal, None, -1, -1, -1, -1, None, 5000, "docs", 6, 6 / 9), - worker_coordinator.DefaultSample(1, 38598.5, 24.5, 0, op, metrics.SampleType.Normal, - None, -1, -1, -1, -1, None, 5000, "docs", 4.5, 7 / 9), - worker_coordinator.DefaultSample(1, 38599.5, 25.5, 0, op, metrics.SampleType.Normal, - None, -1, -1, -1, -1, None, 5000, "docs", 5.5, 8 / 9), - worker_coordinator.DefaultSample(1, 38600.5, 26.5, 0, op, metrics.SampleType.Normal, - None, -1, -1, -1, -1, None, 5000, "docs", 6.5, 9 / 9) + worker_coordinator.DefaultSample(1, 38598.5, 24.5, 0, op, metrics.SampleType.Normal, None, -1, -1, -1, -1, None, 5000, "docs", 4.5, 7 / 9), + worker_coordinator.DefaultSample(1, 38599.5, 25.5, 0, op, metrics.SampleType.Normal, None, -1, -1, -1, -1, None, 5000, "docs", 5.5, 8 / 9), + worker_coordinator.DefaultSample(1, 38600.5, 26.5, 0, op, metrics.SampleType.Normal, None, -1, -1, -1, -1, None, 5000, "docs", 6.5, 9 / 9), ] aggregated = self.calculate_global_throughput(samples) @@ -845,8 +692,7 @@ def test_single_metrics_aggregation(self): # self.assertEqual((1470838600.5, 26.5, metrics.SampleType.Normal, 10000), throughput[6]) def test_use_provided_throughput(self): - op = workload.Operation("index-recovery", workload.OperationType.Sleep, - param_source="worker-coordinator-test-param-source") + op = workload.Operation("index-recovery", workload.OperationType.Sleep, param_source="worker-coordinator-test-param-source") samples = [ worker_coordinator.DefaultSample(0, 38595, 21, 0, op, metrics.SampleType.Normal, None, -1, -1, -1, -1, 8000, 5000, "byte", 1, 1 / 3), @@ -873,17 +719,17 @@ class SchedulerTests(TestCase): class RunnerWithProgress: def __init__(self, complete_after=3): self.completed = False - self.task_progress = (0.0, '%') + self.task_progress = (0.0, "%") self.calls = 0 self.complete_after = complete_after async def __call__(self, *args, **kwargs): self.calls += 1 if not self.completed: - self.task_progress = (self.calls / self.complete_after, '%') + self.task_progress = (self.calls / self.complete_after, "%") self.completed = self.calls == self.complete_after else: - self.task_progress = (1.0, '%') + self.task_progress = (1.0, "%") class CustomComplexScheduler: def __init__(self, task): @@ -936,45 +782,34 @@ def tearDown(self): runner.remove_runner("bulk") def test_injects_parameter_source_into_scheduler(self): - task = workload.Task(name="search", - schedule="custom-complex-scheduler", - operation=workload.Operation( - name="search", - operation_type=workload.OperationType.Search.to_hyphenated_string(), - param_source="worker-coordinator-test-param-source" - ), - clients=4, - params={ - "target-throughput": "5000 ops/s" - }) + task = workload.Task( + name="search", + schedule="custom-complex-scheduler", + operation=workload.Operation(name="search", operation_type=workload.OperationType.Search.to_hyphenated_string(), param_source="worker-coordinator-test-param-source"), + clients=4, + params={"target-throughput": "5000 ops/s"}, + ) param_source = workload.operation_parameters(self.test_workload, task) - task_allocation = worker_coordinator.TaskAllocation( - task=task, - client_index_in_task=0, - global_client_index=0, - total_clients=task.clients - ) - schedule = worker_coordinator.schedule_for( - task_allocation, param_source) + task_allocation = worker_coordinator.TaskAllocation(task=task, client_index_in_task=0, global_client_index=0, total_clients=task.clients) + schedule = worker_coordinator.schedule_for(task_allocation, param_source) self.assertIsNotNone(schedule.sched.parameter_source, "Parameter source has not been injected into scheduler") self.assertEqual(param_source, schedule.sched.parameter_source) @run_async async def test_search_task_one_client(self): - task = workload.Task("search", workload.Operation("search", workload.OperationType.Search.to_hyphenated_string(), - param_source="worker-coordinator-test-param-source"), - warmup_iterations=3, iterations=5, clients=1, params={"target-throughput": 10, "clients": 1}) - param_source = workload.operation_parameters(self.test_workload, task) - task_allocation = worker_coordinator.TaskAllocation( - task=task, - client_index_in_task=0, - global_client_index=0, - total_clients=task.clients + task = workload.Task( + "search", + workload.Operation("search", workload.OperationType.Search.to_hyphenated_string(), param_source="worker-coordinator-test-param-source"), + warmup_iterations=3, + iterations=5, + clients=1, + params={"target-throughput": 10, "clients": 1}, ) - schedule = worker_coordinator.schedule_for( - task_allocation, param_source) + param_source = workload.operation_parameters(self.test_workload, task) + task_allocation = worker_coordinator.TaskAllocation(task=task, client_index_in_task=0, global_client_index=0, total_clients=task.clients) + schedule = worker_coordinator.schedule_for(task_allocation, param_source) expected_schedule = [ (0, metrics.SampleType.Warmup, 1 / 8, {}), @@ -990,18 +825,17 @@ async def test_search_task_one_client(self): @run_async async def test_search_task_two_clients(self): - task = workload.Task("search", workload.Operation("search", workload.OperationType.Search.to_hyphenated_string(), - param_source="worker-coordinator-test-param-source"), - warmup_iterations=1, iterations=5, clients=2, params={"target-throughput": 10, "clients": 2}) - param_source = workload.operation_parameters(self.test_workload, task) - task_allocation = worker_coordinator.TaskAllocation( - task=task, - client_index_in_task=0, - global_client_index=0, - total_clients=task.clients + task = workload.Task( + "search", + workload.Operation("search", workload.OperationType.Search.to_hyphenated_string(), param_source="worker-coordinator-test-param-source"), + warmup_iterations=1, + iterations=5, + clients=2, + params={"target-throughput": 10, "clients": 2}, ) - schedule = worker_coordinator.schedule_for( - task_allocation, param_source) + param_source = workload.operation_parameters(self.test_workload, task) + task_allocation = worker_coordinator.TaskAllocation(task=task, client_index_in_task=0, global_client_index=0, total_clients=task.clients) + schedule = worker_coordinator.schedule_for(task_allocation, param_source) expected_schedule = [ (0, metrics.SampleType.Warmup, 1 / 6, {}), @@ -1016,205 +850,203 @@ async def test_search_task_two_clients(self): @run_async async def test_schedule_param_source_determines_iterations_no_warmup(self): # we neither define any time-period nor any iteration count on the task. - task = workload.Task("bulk-index", workload.Operation("bulk-index", workload.OperationType.Bulk.to_hyphenated_string(), - params={"body": ["a"], "size": 3}, - param_source="worker-coordinator-test-param-source"), - clients=4, params={"target-throughput": 4}) + task = workload.Task( + "bulk-index", + workload.Operation( + "bulk-index", workload.OperationType.Bulk.to_hyphenated_string(), params={"body": ["a"], "size": 3}, param_source="worker-coordinator-test-param-source" + ), + clients=4, + params={"target-throughput": 4}, + ) param_source = workload.operation_parameters(self.test_workload, task) - task_allocation = worker_coordinator.TaskAllocation( - task=task, - client_index_in_task=0, - global_client_index=0, - total_clients=task.clients + task_allocation = worker_coordinator.TaskAllocation(task=task, client_index_in_task=0, global_client_index=0, total_clients=task.clients) + schedule = worker_coordinator.schedule_for(task_allocation, param_source) + + await self.assert_schedule( + [ + (0.0, metrics.SampleType.Normal, 1 / 3, {"body": ["a"], "size": 3}), + (1.0, metrics.SampleType.Normal, 2 / 3, {"body": ["a"], "size": 3}), + (2.0, metrics.SampleType.Normal, 3 / 3, {"body": ["a"], "size": 3}), + ], + schedule, ) - schedule = worker_coordinator.schedule_for( - task_allocation, param_source) - - await self.assert_schedule([ - (0.0, metrics.SampleType.Normal, 1 / 3, {"body": ["a"], "size": 3}), - (1.0, metrics.SampleType.Normal, 2 / 3, {"body": ["a"], "size": 3}), - (2.0, metrics.SampleType.Normal, 3 / 3, {"body": ["a"], "size": 3}), - ], schedule) @run_async async def test_schedule_param_source_determines_iterations_including_warmup(self): - task = workload.Task("bulk-index", workload.Operation("bulk-index", workload.OperationType.Bulk.to_hyphenated_string(), - params={"body": ["a"], "size": 5}, - param_source="worker-coordinator-test-param-source"), - warmup_iterations=2, clients=4, params={"target-throughput": 4}) + task = workload.Task( + "bulk-index", + workload.Operation( + "bulk-index", workload.OperationType.Bulk.to_hyphenated_string(), params={"body": ["a"], "size": 5}, param_source="worker-coordinator-test-param-source" + ), + warmup_iterations=2, + clients=4, + params={"target-throughput": 4}, + ) param_source = workload.operation_parameters(self.test_workload, task) - task_allocation = worker_coordinator.TaskAllocation( - task=task, - client_index_in_task=0, - global_client_index=0, - total_clients=task.clients + task_allocation = worker_coordinator.TaskAllocation(task=task, client_index_in_task=0, global_client_index=0, total_clients=task.clients) + schedule = worker_coordinator.schedule_for(task_allocation, param_source) + + await self.assert_schedule( + [ + (0.0, metrics.SampleType.Warmup, 1 / 5, {"body": ["a"], "size": 5}), + (1.0, metrics.SampleType.Warmup, 2 / 5, {"body": ["a"], "size": 5}), + (2.0, metrics.SampleType.Normal, 3 / 5, {"body": ["a"], "size": 5}), + (3.0, metrics.SampleType.Normal, 4 / 5, {"body": ["a"], "size": 5}), + (4.0, metrics.SampleType.Normal, 5 / 5, {"body": ["a"], "size": 5}), + ], + schedule, ) - schedule = worker_coordinator.schedule_for( - task_allocation, param_source) - - await self.assert_schedule([ - (0.0, metrics.SampleType.Warmup, 1 / 5, {"body": ["a"], "size": 5}), - (1.0, metrics.SampleType.Warmup, 2 / 5, {"body": ["a"], "size": 5}), - (2.0, metrics.SampleType.Normal, 3 / 5, {"body": ["a"], "size": 5}), - (3.0, metrics.SampleType.Normal, 4 / 5, {"body": ["a"], "size": 5}), - (4.0, metrics.SampleType.Normal, 5 / 5, {"body": ["a"], "size": 5}), - ], schedule) @run_async async def test_schedule_defaults_to_iteration_based(self): # no time-period and no iterations specified on the task. Also, the parameter source does not define a size. - task = workload.Task("bulk-index", workload.Operation("bulk-index", workload.OperationType.Bulk.to_hyphenated_string(), - params={"body": ["a"]}, - param_source="worker-coordinator-test-param-source"), - clients=1, params={"target-throughput": 4, "clients": 4}) + task = workload.Task( + "bulk-index", + workload.Operation("bulk-index", workload.OperationType.Bulk.to_hyphenated_string(), params={"body": ["a"]}, param_source="worker-coordinator-test-param-source"), + clients=1, + params={"target-throughput": 4, "clients": 4}, + ) param_source = workload.operation_parameters(self.test_workload, task) - task_allocation = worker_coordinator.TaskAllocation( - task=task, - client_index_in_task=0, - global_client_index=0, - total_clients=task.clients + task_allocation = worker_coordinator.TaskAllocation(task=task, client_index_in_task=0, global_client_index=0, total_clients=task.clients) + schedule = worker_coordinator.schedule_for(task_allocation, param_source) + + await self.assert_schedule( + [ + (0.0, metrics.SampleType.Normal, 1 / 1, {"body": ["a"]}), + ], + schedule, ) - schedule = worker_coordinator.schedule_for( - task_allocation, param_source) - - await self.assert_schedule([ - (0.0, metrics.SampleType.Normal, 1 / 1, {"body": ["a"]}), - ], schedule) @run_async async def test_schedule_for_warmup_time_based(self): - task = workload.Task("time-based", workload.Operation("time-based", workload.OperationType.Bulk.to_hyphenated_string(), - params={"body": ["a"], "size": 11}, - param_source="worker-coordinator-test-param-source"), - warmup_time_period=0, clients=4, params={"target-throughput": 4, "clients": 4}) + task = workload.Task( + "time-based", + workload.Operation( + "time-based", workload.OperationType.Bulk.to_hyphenated_string(), params={"body": ["a"], "size": 11}, param_source="worker-coordinator-test-param-source" + ), + warmup_time_period=0, + clients=4, + params={"target-throughput": 4, "clients": 4}, + ) param_source = workload.operation_parameters(self.test_workload, task) - task_allocation = worker_coordinator.TaskAllocation( - task=task, - client_index_in_task=0, - global_client_index=0, - total_clients=task.clients + task_allocation = worker_coordinator.TaskAllocation(task=task, client_index_in_task=0, global_client_index=0, total_clients=task.clients) + schedule = worker_coordinator.schedule_for(task_allocation, param_source) + + await self.assert_schedule( + [ + (0.0, metrics.SampleType.Normal, 1 / 11, {"body": ["a"], "size": 11}), + (1.0, metrics.SampleType.Normal, 2 / 11, {"body": ["a"], "size": 11}), + (2.0, metrics.SampleType.Normal, 3 / 11, {"body": ["a"], "size": 11}), + (3.0, metrics.SampleType.Normal, 4 / 11, {"body": ["a"], "size": 11}), + (4.0, metrics.SampleType.Normal, 5 / 11, {"body": ["a"], "size": 11}), + (5.0, metrics.SampleType.Normal, 6 / 11, {"body": ["a"], "size": 11}), + (6.0, metrics.SampleType.Normal, 7 / 11, {"body": ["a"], "size": 11}), + (7.0, metrics.SampleType.Normal, 8 / 11, {"body": ["a"], "size": 11}), + (8.0, metrics.SampleType.Normal, 9 / 11, {"body": ["a"], "size": 11}), + (9.0, metrics.SampleType.Normal, 10 / 11, {"body": ["a"], "size": 11}), + (10.0, metrics.SampleType.Normal, 11 / 11, {"body": ["a"], "size": 11}), + ], + schedule, ) - schedule = worker_coordinator.schedule_for( - task_allocation, param_source) - - await self.assert_schedule([ - (0.0, metrics.SampleType.Normal, 1 / 11, {"body": ["a"], "size": 11}), - (1.0, metrics.SampleType.Normal, 2 / 11, {"body": ["a"], "size": 11}), - (2.0, metrics.SampleType.Normal, 3 / 11, {"body": ["a"], "size": 11}), - (3.0, metrics.SampleType.Normal, 4 / 11, {"body": ["a"], "size": 11}), - (4.0, metrics.SampleType.Normal, 5 / 11, {"body": ["a"], "size": 11}), - (5.0, metrics.SampleType.Normal, 6 / 11, {"body": ["a"], "size": 11}), - (6.0, metrics.SampleType.Normal, 7 / 11, {"body": ["a"], "size": 11}), - (7.0, metrics.SampleType.Normal, 8 / 11, {"body": ["a"], "size": 11}), - (8.0, metrics.SampleType.Normal, 9 / 11, {"body": ["a"], "size": 11}), - (9.0, metrics.SampleType.Normal, 10 / 11, {"body": ["a"], "size": 11}), - (10.0, metrics.SampleType.Normal, 11 / 11, {"body": ["a"], "size": 11}), - ], schedule) @run_async async def test_infinite_schedule_without_progress_indication(self): - task = workload.Task("time-based", workload.Operation("time-based", workload.OperationType.Bulk.to_hyphenated_string(), - params={"body": ["a"]}, - param_source="worker-coordinator-test-param-source"), - warmup_time_period=0, clients=4, params={"target-throughput": 4, "clients": 4}) + task = workload.Task( + "time-based", + workload.Operation("time-based", workload.OperationType.Bulk.to_hyphenated_string(), params={"body": ["a"]}, param_source="worker-coordinator-test-param-source"), + warmup_time_period=0, + clients=4, + params={"target-throughput": 4, "clients": 4}, + ) param_source = workload.operation_parameters(self.test_workload, task) - task_allocation = worker_coordinator.TaskAllocation( - task=task, - client_index_in_task=0, - global_client_index=0, - total_clients=task.clients + task_allocation = worker_coordinator.TaskAllocation(task=task, client_index_in_task=0, global_client_index=0, total_clients=task.clients) + schedule = worker_coordinator.schedule_for(task_allocation, param_source) + + await self.assert_schedule( + [ + (0.0, metrics.SampleType.Normal, None, {"body": ["a"]}), + (1.0, metrics.SampleType.Normal, None, {"body": ["a"]}), + (2.0, metrics.SampleType.Normal, None, {"body": ["a"]}), + (3.0, metrics.SampleType.Normal, None, {"body": ["a"]}), + (4.0, metrics.SampleType.Normal, None, {"body": ["a"]}), + ], + schedule, + infinite_schedule=True, ) - schedule = worker_coordinator.schedule_for( - task_allocation, param_source) - - await self.assert_schedule([ - (0.0, metrics.SampleType.Normal, None, {"body": ["a"]}), - (1.0, metrics.SampleType.Normal, None, {"body": ["a"]}), - (2.0, metrics.SampleType.Normal, None, {"body": ["a"]}), - (3.0, metrics.SampleType.Normal, None, {"body": ["a"]}), - (4.0, metrics.SampleType.Normal, None, {"body": ["a"]}), - ], schedule, infinite_schedule=True) @run_async async def test_finite_schedule_with_progress_indication(self): - task = workload.Task("time-based", workload.Operation("time-based", workload.OperationType.Bulk.to_hyphenated_string(), - params={ - "body": ["a"], "size": 5}, - param_source="worker-coordinator-test-param-source"), - warmup_time_period=0, clients=4, params={"target-throughput": 4, "clients": 4}) + task = workload.Task( + "time-based", + workload.Operation( + "time-based", workload.OperationType.Bulk.to_hyphenated_string(), params={"body": ["a"], "size": 5}, param_source="worker-coordinator-test-param-source" + ), + warmup_time_period=0, + clients=4, + params={"target-throughput": 4, "clients": 4}, + ) param_source = workload.operation_parameters(self.test_workload, task) - task_allocation = worker_coordinator.TaskAllocation( - task=task, - client_index_in_task=0, - global_client_index=0, - total_clients=task.clients + task_allocation = worker_coordinator.TaskAllocation(task=task, client_index_in_task=0, global_client_index=0, total_clients=task.clients) + schedule = worker_coordinator.schedule_for(task_allocation, param_source) + + await self.assert_schedule( + [ + (0.0, metrics.SampleType.Normal, 1 / 5, {"body": ["a"], "size": 5}), + (1.0, metrics.SampleType.Normal, 2 / 5, {"body": ["a"], "size": 5}), + (2.0, metrics.SampleType.Normal, 3 / 5, {"body": ["a"], "size": 5}), + (3.0, metrics.SampleType.Normal, 4 / 5, {"body": ["a"], "size": 5}), + (4.0, metrics.SampleType.Normal, 5 / 5, {"body": ["a"], "size": 5}), + ], + schedule, + infinite_schedule=False, ) - schedule = worker_coordinator.schedule_for( - task_allocation, param_source) - - await self.assert_schedule([ - (0.0, metrics.SampleType.Normal, - 1 / 5, {"body": ["a"], "size": 5}), - (1.0, metrics.SampleType.Normal, - 2 / 5, {"body": ["a"], "size": 5}), - (2.0, metrics.SampleType.Normal, - 3 / 5, {"body": ["a"], "size": 5}), - (3.0, metrics.SampleType.Normal, - 4 / 5, {"body": ["a"], "size": 5}), - (4.0, metrics.SampleType.Normal, - 5 / 5, {"body": ["a"], "size": 5}), - ], schedule, infinite_schedule=False) @run_async async def test_schedule_with_progress_determined_by_runner(self): - task = workload.Task("time-based", workload.Operation("time-based", "worker-coordinator-test-runner-with-completion", - params={"body": ["a"]}, - param_source="worker-coordinator-test-param-source"), - clients=1, - params={"target-throughput": 1, "clients": 1}) + task = workload.Task( + "time-based", + workload.Operation("time-based", "worker-coordinator-test-runner-with-completion", params={"body": ["a"]}, param_source="worker-coordinator-test-param-source"), + clients=1, + params={"target-throughput": 1, "clients": 1}, + ) param_source = workload.operation_parameters(self.test_workload, task) - task_allocation = worker_coordinator.TaskAllocation( - task=task, - client_index_in_task=0, - global_client_index=0, - total_clients=task.clients + task_allocation = worker_coordinator.TaskAllocation(task=task, client_index_in_task=0, global_client_index=0, total_clients=task.clients) + schedule = worker_coordinator.schedule_for(task_allocation, param_source) + + await self.assert_schedule( + [ + (0.0, metrics.SampleType.Normal, None, {"body": ["a"]}), + (1.0, metrics.SampleType.Normal, None, {"body": ["a"]}), + (2.0, metrics.SampleType.Normal, None, {"body": ["a"]}), + (3.0, metrics.SampleType.Normal, None, {"body": ["a"]}), + (4.0, metrics.SampleType.Normal, None, {"body": ["a"]}), + ], + schedule, + infinite_schedule=True, ) - schedule = worker_coordinator.schedule_for( - task_allocation, param_source) - - await self.assert_schedule([ - (0.0, metrics.SampleType.Normal, None, {"body": ["a"]}), - (1.0, metrics.SampleType.Normal, None, {"body": ["a"]}), - (2.0, metrics.SampleType.Normal, None, {"body": ["a"]}), - (3.0, metrics.SampleType.Normal, None, {"body": ["a"]}), - (4.0, metrics.SampleType.Normal, None, {"body": ["a"]}), - ], schedule, infinite_schedule=True) @run_async async def test_schedule_for_time_based(self): - task = workload.Task("time-based", workload.Operation("time-based", workload.OperationType.Bulk.to_hyphenated_string(), - params={"body": ["a"], "size": 11}, - param_source="worker-coordinator-test-param-source"), - warmup_time_period=0.1, - time_period=0.1, - clients=1) + task = workload.Task( + "time-based", + workload.Operation( + "time-based", workload.OperationType.Bulk.to_hyphenated_string(), params={"body": ["a"], "size": 11}, param_source="worker-coordinator-test-param-source" + ), + warmup_time_period=0.1, + time_period=0.1, + clients=1, + ) param_source = workload.operation_parameters(self.test_workload, task) - task_allocation = worker_coordinator.TaskAllocation( - task=task, - client_index_in_task=0, - global_client_index=0, - total_clients=task.clients - ) - schedule_handle = worker_coordinator.schedule_for( - task_allocation, param_source) + task_allocation = worker_coordinator.TaskAllocation(task=task, client_index_in_task=0, global_client_index=0, total_clients=task.clients) + schedule_handle = worker_coordinator.schedule_for(task_allocation, param_source) schedule_handle.start() self.assertEqual(0.0, schedule_handle.ramp_up_wait_time) schedule = schedule_handle() @@ -1294,18 +1126,14 @@ def completed(self): @property def task_progress(self): - return ((self.iterations - self.iterations_left) / self.iterations, '%') + return ((self.iterations - self.iterations_left) / self.iterations, "%") async def __call__(self, opensearch, params): self.iterations_left -= 1 class RunnerOverridingThroughput: async def __call__(self, opensearch, params): - return { - "weight": 1, - "unit": "ops", - "throughput": 1.23 - } + return {"weight": 1, "unit": "ops", "throughput": 1.23} def __init__(self, methodName): super().__init__(methodName) @@ -1328,8 +1156,8 @@ def tearDown(self): runner.remove_runner("override-throughput") runner.remove_runner("bulk") - @mock.patch('solrorbit.client.RequestContextHolder.on_client_request_end') - @mock.patch('solrorbit.client.RequestContextHolder.on_client_request_start') + @mock.patch("solrorbit.client.RequestContextHolder.on_client_request_end") + @mock.patch("solrorbit.client.RequestContextHolder.on_client_request_start") @mock.patch("tests.worker_coordinator.worker_coordinator_test._FakeOSClient") @run_async async def test_run_schedule_in_throughput_mode(self, opensearch, on_client_request_start, on_client_request_end): @@ -1339,45 +1167,47 @@ async def test_run_schedule_in_throughput_mode(self, opensearch, on_client_reque opensearch.bulk.return_value = as_future(io.StringIO('{"errors": false, "took": 8}')) params.register_param_source_for_name("worker-coordinator-test-param-source", WorkerCoordinatorTestParamSource) - test_workload = workload.Workload(name="unittest", description="unittest workload", - test_procedures=None) - - task = workload.Task("time-based", workload.Operation("time-based", workload.OperationType.Bulk.to_hyphenated_string(), - params={ - "body": ["action_metadata_line", "index_line"], - "action-metadata-present": True, - "bulk-size": 1, - "unit": "docs", - # we need this because WorkerCoordinatorTestParamSource does not know - # that we only have one bulk and hence size() returns - # incorrect results - "size": 1 - }, - param_source="worker-coordinator-test-param-source"), - warmup_time_period=0, clients=4) + test_workload = workload.Workload(name="unittest", description="unittest workload", test_procedures=None) + + task = workload.Task( + "time-based", + workload.Operation( + "time-based", + workload.OperationType.Bulk.to_hyphenated_string(), + params={ + "body": ["action_metadata_line", "index_line"], + "action-metadata-present": True, + "bulk-size": 1, + "unit": "docs", + # we need this because WorkerCoordinatorTestParamSource does not know + # that we only have one bulk and hence size() returns + # incorrect results + "size": 1, + }, + param_source="worker-coordinator-test-param-source", + ), + warmup_time_period=0, + clients=4, + ) param_source = workload.operation_parameters(test_workload, task) - task_allocation = worker_coordinator.TaskAllocation(task=task, - client_index_in_task=0, - global_client_index=0, - total_clients=task.clients) - schedule = worker_coordinator.schedule_for( - task_allocation, param_source) + task_allocation = worker_coordinator.TaskAllocation(task=task, client_index_in_task=0, global_client_index=0, total_clients=task.clients) + schedule = worker_coordinator.schedule_for(task_allocation, param_source) sampler = worker_coordinator.DefaultSampler(start_timestamp=task_start) profile_sampler = worker_coordinator.ProfileMetricsSampler(start_timestamp=task_start) cancel = threading.Event() complete = threading.Event() - execute_schedule = worker_coordinator.AsyncExecutor(client_id=2, - task=task, - schedule=schedule, - clients={ - "default": opensearch - }, - sampler=sampler, - profile_sampler=profile_sampler, - cancel=cancel, - complete=complete, - on_error="continue") + execute_schedule = worker_coordinator.AsyncExecutor( + client_id=2, + task=task, + schedule=schedule, + clients={"default": opensearch}, + sampler=sampler, + profile_sampler=profile_sampler, + cancel=cancel, + complete=complete, + on_error="continue", + ) await execute_schedule() samples = sampler.samples @@ -1407,45 +1237,50 @@ async def test_run_schedule_with_progress_determined_by_runner(self, opensearch) opensearch.new_request_context.return_value = AsyncExecutorTests.StaticRequestTiming(task_start=task_start) params.register_param_source_for_name("worker-coordinator-test-param-source", WorkerCoordinatorTestParamSource) - test_workload = workload.Workload(name="unittest", description="unittest workload", - test_procedures=None) - - task = workload.Task("time-based", workload.Operation("time-based", operation_type="unit-test-recovery", params={ - "indices-to-restore": "*", - # The runner will determine progress - "size": None - }, param_source="worker-coordinator-test-param-source"), warmup_time_period=0, clients=4) + test_workload = workload.Workload(name="unittest", description="unittest workload", test_procedures=None) + + task = workload.Task( + "time-based", + workload.Operation( + "time-based", + operation_type="unit-test-recovery", + params={ + "indices-to-restore": "*", + # The runner will determine progress + "size": None, + }, + param_source="worker-coordinator-test-param-source", + ), + warmup_time_period=0, + clients=4, + ) param_source = workload.operation_parameters(test_workload, task) - task_allocation = worker_coordinator.TaskAllocation(task=task, - client_index_in_task=0, - global_client_index=0, - total_clients=task.clients) - schedule = worker_coordinator.schedule_for( - task_allocation, param_source) + task_allocation = worker_coordinator.TaskAllocation(task=task, client_index_in_task=0, global_client_index=0, total_clients=task.clients) + schedule = worker_coordinator.schedule_for(task_allocation, param_source) sampler = worker_coordinator.DefaultSampler(start_timestamp=task_start) profile_sampler = worker_coordinator.ProfileMetricsSampler(start_timestamp=task_start) cancel = threading.Event() complete = threading.Event() - execute_schedule = worker_coordinator.AsyncExecutor(client_id=2, - task=task, - schedule=schedule, - clients={ - "default": opensearch - }, - sampler=sampler, - profile_sampler=profile_sampler, - cancel=cancel, - complete=complete, - on_error="continue") + execute_schedule = worker_coordinator.AsyncExecutor( + client_id=2, + task=task, + schedule=schedule, + clients={"default": opensearch}, + sampler=sampler, + profile_sampler=profile_sampler, + cancel=cancel, + complete=complete, + on_error="continue", + ) await execute_schedule() samples = sampler.samples self.assertEqual(5, len(samples)) self.assertTrue(self.runner_with_progress.completed) - self.assertEqual((1.0, '%'), self.runner_with_progress.task_progress) + self.assertEqual((1.0, "%"), self.runner_with_progress.task_progress) self.assertFalse(complete.is_set(), "Executor should not auto-complete a normal task") previous_absolute_time = -1.0 previous_relative_time = -1.0 @@ -1472,41 +1307,44 @@ async def test_run_schedule_runner_overrides_times(self, opensearch): opensearch.new_request_context.return_value = AsyncExecutorTests.StaticRequestTiming(task_start=task_start) params.register_param_source_for_name("worker-coordinator-test-param-source", WorkerCoordinatorTestParamSource) - test_workload = workload.Workload(name="unittest", description="unittest workload", - test_procedures=None) - - task = workload.Task("override-throughput", workload.Operation("override-throughput", - operation_type="override-throughput", params={ - # we need this because WorkerCoordinatorTestParamSource does not know that we only have one iteration and hence - # size() returns incorrect results - "size": 1 - }, - param_source="worker-coordinator-test-param-source"), - warmup_iterations=0, iterations=1, clients=1) + test_workload = workload.Workload(name="unittest", description="unittest workload", test_procedures=None) + + task = workload.Task( + "override-throughput", + workload.Operation( + "override-throughput", + operation_type="override-throughput", + params={ + # we need this because WorkerCoordinatorTestParamSource does not know that we only have one iteration and hence + # size() returns incorrect results + "size": 1 + }, + param_source="worker-coordinator-test-param-source", + ), + warmup_iterations=0, + iterations=1, + clients=1, + ) param_source = workload.operation_parameters(test_workload, task) - task_allocation = worker_coordinator.TaskAllocation(task=task, - client_index_in_task=0, - global_client_index=0, - total_clients=task.clients) - schedule = worker_coordinator.schedule_for( - task_allocation, param_source) + task_allocation = worker_coordinator.TaskAllocation(task=task, client_index_in_task=0, global_client_index=0, total_clients=task.clients) + schedule = worker_coordinator.schedule_for(task_allocation, param_source) sampler = worker_coordinator.DefaultSampler(start_timestamp=task_start) profile_sampler = worker_coordinator.ProfileMetricsSampler(start_timestamp=task_start) cancel = threading.Event() complete = threading.Event() - execute_schedule = worker_coordinator.AsyncExecutor(client_id=0, - task=task, - schedule=schedule, - clients={ - "default": opensearch - }, - sampler=sampler, - profile_sampler=profile_sampler, - cancel=cancel, - complete=complete, - on_error="continue") + execute_schedule = worker_coordinator.AsyncExecutor( + client_id=0, + task=task, + schedule=schedule, + clients={"default": opensearch}, + sampler=sampler, + profile_sampler=profile_sampler, + cancel=cancel, + complete=complete, + on_error="continue", + ) await execute_schedule() samples = sampler.samples @@ -1528,53 +1366,46 @@ async def test_run_schedule_runner_overrides_times(self, opensearch): @mock.patch("tests.worker_coordinator.worker_coordinator_test._FakeOSClient") @run_async async def test_cancel_execute_schedule(self, opensearch): - opensearch.init_request_context.return_value = { - "client_request_start": 0, - "request_start": 1, - "request_end": 11, - "client_request_end": 12 - } + opensearch.init_request_context.return_value = {"client_request_start": 0, "request_start": 1, "request_end": 11, "client_request_end": 12} opensearch.bulk.return_value = as_future(io.StringIO('{"errors": false, "took": 8}')) params.register_param_source_for_name("worker-coordinator-test-param-source", WorkerCoordinatorTestParamSource) - test_workload = workload.Workload(name="unittest", description="unittest workload", - test_procedures=None) + test_workload = workload.Workload(name="unittest", description="unittest workload", test_procedures=None) # in one second (0.5 warmup + 0.5 measurement) we should get 1000 [ops/s] / 4 [clients] = 250 samples for target_throughput in [10, 100, 1000]: - task = workload.Task("time-based", workload.Operation("time-based", - workload.OperationType.Bulk.to_hyphenated_string(), - params={ - "body": ["action_metadata_line", "index_line"], - "action-metadata-present": True, - "bulk-size": 1 - }, - param_source="worker-coordinator-test-param-source"), - warmup_time_period=0.5, time_period=0.5, clients=4, - params={"target-throughput": target_throughput, "clients": 4}) + task = workload.Task( + "time-based", + workload.Operation( + "time-based", + workload.OperationType.Bulk.to_hyphenated_string(), + params={"body": ["action_metadata_line", "index_line"], "action-metadata-present": True, "bulk-size": 1}, + param_source="worker-coordinator-test-param-source", + ), + warmup_time_period=0.5, + time_period=0.5, + clients=4, + params={"target-throughput": target_throughput, "clients": 4}, + ) param_source = workload.operation_parameters(test_workload, task) - task_allocation = worker_coordinator.TaskAllocation(task=task, - client_index_in_task=0, - global_client_index=0, - total_clients=task.clients) - schedule = worker_coordinator.schedule_for( - task_allocation, param_source) + task_allocation = worker_coordinator.TaskAllocation(task=task, client_index_in_task=0, global_client_index=0, total_clients=task.clients) + schedule = worker_coordinator.schedule_for(task_allocation, param_source) sampler = worker_coordinator.DefaultSampler(start_timestamp=0) profile_sampler = worker_coordinator.ProfileMetricsSampler(start_timestamp=0) cancel = threading.Event() complete = threading.Event() - execute_schedule = worker_coordinator.AsyncExecutor(client_id=0, - task=task, - schedule=schedule, - clients={ - "default": opensearch - }, - sampler=sampler, - profile_sampler=profile_sampler, - cancel=cancel, - complete=complete, - on_error="continue") + execute_schedule = worker_coordinator.AsyncExecutor( + client_id=0, + task=task, + schedule=schedule, + clients={"default": opensearch}, + sampler=sampler, + profile_sampler=profile_sampler, + cancel=cancel, + complete=complete, + on_error="continue", + ) cancel.set() await execute_schedule() @@ -1588,7 +1419,6 @@ async def test_cancel_execute_schedule(self, opensearch): @run_async async def test_run_schedule_aborts_on_error(self, opensearch): class ExpectedUnitTestException(Exception): - def __str__(self): return "expected unit test exception" @@ -1596,7 +1426,6 @@ def run(*args, **kwargs): raise ExpectedUnitTestException() class ScheduleHandle: - def __init__(self): self.ramp_up_wait_time = 0 @@ -1614,27 +1443,30 @@ async def __call__(self): for invocation in invocations: yield invocation - task = workload.Task("no-op", workload.Operation("no-op", workload.OperationType.Bulk.to_hyphenated_string(), - params={}, - param_source="worker-coordinator-test-param-source"), - warmup_time_period=0.5, time_period=0.5, clients=4, - params={"clients": 4}) + task = workload.Task( + "no-op", + workload.Operation("no-op", workload.OperationType.Bulk.to_hyphenated_string(), params={}, param_source="worker-coordinator-test-param-source"), + warmup_time_period=0.5, + time_period=0.5, + clients=4, + params={"clients": 4}, + ) sampler = worker_coordinator.DefaultSampler(start_timestamp=0) profile_sampler = worker_coordinator.ProfileMetricsSampler(start_timestamp=0) cancel = threading.Event() complete = threading.Event() - execute_schedule = worker_coordinator.AsyncExecutor(client_id=2, - task=task, - schedule=ScheduleHandle(), - clients={ - "default": opensearch - }, - sampler=sampler, - profile_sampler=profile_sampler, - cancel=cancel, - complete=complete, - on_error="continue") + execute_schedule = worker_coordinator.AsyncExecutor( + client_id=2, + task=task, + schedule=ScheduleHandle(), + clients={"default": opensearch}, + sampler=sampler, + profile_sampler=profile_sampler, + cancel=cancel, + complete=complete, + on_error="continue", + ) with self.assertRaisesRegex(exceptions.BenchmarkError, r"Cannot run task \[no-op\]: expected unit test exception"): await execute_schedule() @@ -1648,11 +1480,7 @@ async def test_run_single_no_return_value(self): runner = mock.Mock() runner.return_value = as_future() - ops, unit, request_meta_data = await worker_coordinator.execute_single( - self.context_managed(runner), - opensearch, - params, - on_error="continue") + ops, unit, request_meta_data = await worker_coordinator.execute_single(self.context_managed(runner), opensearch, params, on_error="continue") self.assertEqual(1, ops) self.assertEqual("ops", unit) @@ -1665,11 +1493,7 @@ async def test_run_single_tuple(self): runner = mock.Mock() runner.return_value = as_future(result=(500, "MB")) - ops, unit, request_meta_data = await worker_coordinator.execute_single( - self.context_managed(runner), - opensearch, - params, - on_error="continue") + ops, unit, request_meta_data = await worker_coordinator.execute_single(self.context_managed(runner), opensearch, params, on_error="continue") self.assertEqual(500, ops) self.assertEqual("MB", unit) @@ -1680,28 +1504,15 @@ async def test_run_single_dict(self): opensearch = None params = None runner = mock.Mock() - runner.return_value = as_future({ - "weight": 50, - "unit": "docs", - "some-custom-meta-data": "valid", - "http-status": 200 - }) - - ops, unit, request_meta_data = await worker_coordinator.execute_single( - self.context_managed(runner), - opensearch, - params, - on_error="continue") + runner.return_value = as_future({"weight": 50, "unit": "docs", "some-custom-meta-data": "valid", "http-status": 200}) + + ops, unit, request_meta_data = await worker_coordinator.execute_single(self.context_managed(runner), opensearch, params, on_error="continue") self.assertEqual(50, ops) self.assertEqual("docs", unit) - self.assertEqual({ - "some-custom-meta-data": "valid", - "http-status": 200, - "success": True - }, request_meta_data) + self.assertEqual({"some-custom-meta-data": "valid", "http-status": 200, "success": True}, request_meta_data) - @mock.patch('solrorbit.client.RequestContextHolder.on_client_request_end') + @mock.patch("solrorbit.client.RequestContextHolder.on_client_request_end") @run_async async def test_run_single_with_connection_error_always_aborts(self, on_client_request_end): for on_error in ["abort", "continue"]: @@ -1712,65 +1523,52 @@ async def test_run_single_with_connection_error_always_aborts(self, on_client_re with self.assertRaises(exceptions.BenchmarkAssertionError) as ctx: await worker_coordinator.execute_single(self.context_managed(runner), opensearch, params, on_error=on_error) - self.assertEqual( - "Request returned an error. Error type: transport, Description: no route to host", - ctx.exception.args[0]) + self.assertEqual("Request returned an error. Error type: transport, Description: no route to host", ctx.exception.args[0]) - @mock.patch('solrorbit.client.RequestContextHolder.on_client_request_end') + @mock.patch("solrorbit.client.RequestContextHolder.on_client_request_end") @run_async async def test_run_single_with_http_400_aborts_when_specified(self, on_client_request_end): opensearch = None params = None - runner = mock.Mock(side_effect= - as_future(exception=exceptions.BenchmarkTransportError(status_code=404, error="not found", info="the requested document could not be found"))) + runner = mock.Mock( + side_effect=as_future(exception=exceptions.BenchmarkTransportError(status_code=404, error="not found", info="the requested document could not be found")) + ) with self.assertRaises(exceptions.BenchmarkAssertionError) as ctx: await worker_coordinator.execute_single(self.context_managed(runner), opensearch, params, on_error="abort") - self.assertEqual( - "Request returned an error. Error type: transport, Description: not found (the requested document could not be found)", - ctx.exception.args[0]) + self.assertEqual("Request returned an error. Error type: transport, Description: not found (the requested document could not be found)", ctx.exception.args[0]) - @mock.patch('solrorbit.client.RequestContextHolder.on_client_request_end') + @mock.patch("solrorbit.client.RequestContextHolder.on_client_request_end") @run_async async def test_run_single_with_http_400(self, on_client_request_end): opensearch = None params = None - runner = mock.Mock(side_effect= - as_future(exception=exceptions.BenchmarkTransportError(status_code=404, error="not found", info="the requested document could not be found"))) + runner = mock.Mock( + side_effect=as_future(exception=exceptions.BenchmarkTransportError(status_code=404, error="not found", info="the requested document could not be found")) + ) - ops, unit, request_meta_data = await worker_coordinator.execute_single( - self.context_managed(runner), opensearch, params, on_error="continue") + ops, unit, request_meta_data = await worker_coordinator.execute_single(self.context_managed(runner), opensearch, params, on_error="continue") self.assertEqual(0, ops) self.assertEqual("ops", unit) - self.assertEqual({ - "http-status": 404, - "error-type": "transport", - "error-description": "not found (the requested document could not be found)", - "success": False - }, request_meta_data) - - @mock.patch('solrorbit.client.RequestContextHolder.on_client_request_end') + self.assertEqual( + {"http-status": 404, "error-type": "transport", "error-description": "not found (the requested document could not be found)", "success": False}, request_meta_data + ) + + @mock.patch("solrorbit.client.RequestContextHolder.on_client_request_end") @run_async async def test_run_single_with_http_413(self, on_client_request_end): opensearch = None params = None - runner = mock.Mock(side_effect= - as_future(exception=exceptions.BenchmarkTransportError(status_code=413))) + runner = mock.Mock(side_effect=as_future(exception=exceptions.BenchmarkTransportError(status_code=413))) - ops, unit, request_meta_data = await worker_coordinator.execute_single( - self.context_managed(runner), opensearch, params, on_error="continue") + ops, unit, request_meta_data = await worker_coordinator.execute_single(self.context_managed(runner), opensearch, params, on_error="continue") self.assertEqual(0, ops) self.assertEqual("ops", unit) - self.assertEqual({ - "http-status": 413, - "error-type": "transport", - "error-description": "", - "success": False - }, request_meta_data) - - @mock.patch('solrorbit.client.RequestContextHolder.on_client_request_end') + self.assertEqual({"http-status": 413, "error-type": "transport", "error-description": "", "success": False}, request_meta_data) + + @mock.patch("solrorbit.client.RequestContextHolder.on_client_request_end") @run_async async def test_run_single_with_key_error(self, on_client_request_end): class FailingRunner: @@ -1789,9 +1587,7 @@ def __str__(self): with self.assertRaises(exceptions.SystemSetupError) as ctx: await worker_coordinator.execute_single(self.context_managed(runner), opensearch, params, on_error="continue") - self.assertEqual( - "Cannot execute [failing_mock_runner]. Provided parameters are: ['bulk', 'mode']. Error: ['bulk-size missing'].", - ctx.exception.args[0]) + self.assertEqual("Cannot execute [failing_mock_runner]. Provided parameters are: ['bulk', 'mode']. Error: ['bulk-size missing'].", ctx.exception.args[0]) class AsyncExecutorHelperMethodsTests(TestCase): @@ -1811,12 +1607,9 @@ def setUp(self): self.cfg.add(config.Scope.application, "workload", "test.mode.enabled", True) all_client_options = {"default": {"base_timeout": 10}} - self.cfg.add(config.Scope.application, "client", "options", - WorkerCoordinatorTests.Holder(all_client_options=all_client_options)) + self.cfg.add(config.Scope.application, "client", "options", WorkerCoordinatorTests.Holder(all_client_options=all_client_options)) - self.task = workload.Task("test-task", - workload.Operation("test-op", workload.OperationType.Bulk), - clients=2) + self.task = workload.Task("test-task", workload.Operation("test-op", workload.OperationType.Bulk), clients=2) self.sampler = mock.Mock() self.profile_sampler = mock.Mock() self.cancel = threading.Event() @@ -1844,7 +1637,7 @@ def setUp(self): config=self.cfg, shared_states=self.shared_states, error_queue=self.error_queue, - queue_lock=self.queue_lock + queue_lock=self.queue_lock, ) def test_get_client_options_with_valid_config(self): @@ -1894,11 +1687,7 @@ def test_report_error_with_queue(self): """Test report_error when an error queue is present.""" error_q = queue.Queue() self.executor.error_queue = error_q - error_info = { - "client_id": self.executor.client_id, - "task": str(self.executor.task), - "error_details": {"success": False, "error-type": "test-error"} - } + error_info = {"client_id": self.executor.client_id, "task": str(self.executor.task), "error_details": {"success": False, "error-type": "test-error"}} self.executor.report_error(error_info) queued_error = error_q.get_nowait() self.assertEqual(queued_error, error_info) @@ -1906,31 +1695,30 @@ def test_report_error_with_queue(self): def test_report_error_without_queue(self): """Test report_error runs without error when the error queue is None.""" self.executor.error_queue = None - error_info = { - "client_id": self.executor.client_id, - "task": str(self.executor.task), - "error_details": {"success": False, "error-type": "test-error"} - } + error_info = {"client_id": self.executor.client_id, "task": str(self.executor.task), "error_details": {"success": False, "error-type": "test-error"}} self.executor.report_error(error_info) def test_process_results_with_active_client(self): """Test _process_results adds a sample for an active client.""" result_data = { "absolute_processing_start": time.time(), - "request_start": 10.0, "request_end": 11.0, - "client_request_start": 9.9, "client_request_end": 11.1, - "processing_start": 9.8, "processing_end": 11.2, - "total_ops": 100, "total_ops_unit": "docs", - "request_meta_data": {"success": True}, "throughput_throttled": True + "request_start": 10.0, + "request_end": 11.0, + "client_request_start": 9.9, + "client_request_end": 11.1, + "processing_start": 9.8, + "processing_end": 11.2, + "total_ops": 100, + "total_ops_unit": "docs", + "request_meta_data": {"success": True}, + "throughput_throttled": True, } - self.executor.runner = mock.Mock(completed=False, task_progress=(0.5, '%')) + self.executor.runner = mock.Mock(completed=False, task_progress=(0.5, "%")) self.executor.task_completes_parent = False self.executor.sample_type = metrics.SampleType.Normal self.executor.expected_scheduled_time = 0 - completed = self.executor._process_results( - result_data, total_start=5.0, client_state=True, task_progress=(0.8, '%') - ) + completed = self.executor._process_results(result_data, total_start=5.0, client_state=True, task_progress=(0.8, "%")) self.assertFalse(completed) self.schedule_handle.after_request.assert_called_once() self.sampler.add.assert_called_once() @@ -1940,20 +1728,23 @@ def test_process_results_with_inactive_client(self): """Test _process_results does not add a sample for an inactive client.""" result_data = { "absolute_processing_start": time.time(), - "request_start": 10.0, "request_end": 11.0, - "client_request_start": 9.9, "client_request_end": 11.1, - "processing_start": 9.8, "processing_end": 11.2, - "total_ops": 100, "total_ops_unit": "docs", - "request_meta_data": {"success": True}, "throughput_throttled": True + "request_start": 10.0, + "request_end": 11.0, + "client_request_start": 9.9, + "client_request_end": 11.1, + "processing_start": 9.8, + "processing_end": 11.2, + "total_ops": 100, + "total_ops_unit": "docs", + "request_meta_data": {"success": True}, + "throughput_throttled": True, } - self.executor.runner = mock.Mock(completed=False, task_progress=(0.5, '%')) + self.executor.runner = mock.Mock(completed=False, task_progress=(0.5, "%")) self.executor.task_completes_parent = False self.executor.sample_type = metrics.SampleType.Normal self.executor.expected_scheduled_time = 0 - completed = self.executor._process_results( - result_data, total_start=5.0, client_state=False, task_progress=(0.8, '%') - ) + completed = self.executor._process_results(result_data, total_start=5.0, client_state=False, task_progress=(0.8, "%")) self.assertFalse(completed) self.schedule_handle.after_request.assert_called_once() self.sampler.add.assert_not_called() @@ -1978,13 +1769,11 @@ async def test_execute_request_success(self): self.executor._prepare_context_manager = mock.AsyncMock(return_value=context_manager) self.executor.runner = mock.Mock() - with mock.patch('solrorbit.worker_coordinator.worker_coordinator.execute_single'): - with mock.patch('asyncio.wait_for') as wait_for_mock: + with mock.patch("solrorbit.worker_coordinator.worker_coordinator.execute_single"): + with mock.patch("asyncio.wait_for") as wait_for_mock: wait_for_mock.return_value = (100, "docs", {"success": True}) - result = await self.executor._execute_request( - params, expected_scheduled_time, total_start, client_state - ) + result = await self.executor._execute_request(params, expected_scheduled_time, total_start, client_state) self.assertEqual(result["total_ops"], 100) self.assertEqual(result["total_ops_unit"], "docs") @@ -2010,14 +1799,12 @@ async def test_execute_request_with_throttling(self): self.executor._prepare_context_manager = mock.AsyncMock(return_value=context_manager) self.executor.runner = mock.Mock() - with mock.patch('time.perf_counter', side_effect=[10.0, 10.5, 13.0]): - with mock.patch('asyncio.sleep') as sleep_mock: - with mock.patch('asyncio.wait_for') as wait_for_mock: + with mock.patch("time.perf_counter", side_effect=[10.0, 10.5, 13.0]): + with mock.patch("asyncio.sleep") as sleep_mock: + with mock.patch("asyncio.wait_for") as wait_for_mock: wait_for_mock.return_value = (50, "docs", {"success": True}) - result = await self.executor._execute_request( - params, expected_scheduled_time, total_start, client_state - ) + result = await self.executor._execute_request(params, expected_scheduled_time, total_start, client_state) sleep_mock.assert_called_once_with(1.0) self.assertTrue(result["throughput_throttled"]) @@ -2040,6 +1827,7 @@ async def f(x): duration = end - start self.assertTrue(0.9 <= duration <= 1.2, "Should sleep for roughly 1 second but took [%.2f] seconds." % duration) + class FeedbackActorTests(TestCase): @pytest.fixture(autouse=True) def setup_actor(self): @@ -2051,15 +1839,8 @@ def setup_actor(self): def test_receive_shared_client_state_sets_total_client_count(self): self.actor.wakeupAfter = mock.MagicMock() - shared_states = { - 0: {0: False, 1: False}, - 1: {2: False, 3: False, 4: False} - } - message = worker_coordinator.StartFeedbackActor( - shared_states=shared_states, - error_queue=queue.Queue(), - queue_lock=mock.MagicMock() - ) + shared_states = {0: {0: False, 1: False}, 1: {2: False, 3: False, 4: False}} + message = worker_coordinator.StartFeedbackActor(shared_states=shared_states, error_queue=queue.Queue(), queue_lock=mock.MagicMock()) self.actor.receiveMsg_StartFeedbackActor(message, sender=None) assert self.actor.total_client_count == 5 @@ -2072,11 +1853,7 @@ def test_receive_start_feedback_actor_sets_queue_refs(self): dummy_queue_lock = mock.MagicMock() dummy_states = {0: {0: False}} - message = worker_coordinator.StartFeedbackActor( - shared_states=dummy_states, - error_queue=dummy_error_queue, - queue_lock=dummy_queue_lock - ) + message = worker_coordinator.StartFeedbackActor(shared_states=dummy_states, error_queue=dummy_error_queue, queue_lock=dummy_queue_lock) self.actor.receiveMsg_StartFeedbackActor(message, sender=None) assert self.actor.error_queue == dummy_error_queue @@ -2085,10 +1862,7 @@ def test_receive_start_feedback_actor_sets_queue_refs(self): self.actor.wakeupAfter.assert_called_once() def test_scale_up_only_activates_n_clients(self): - self.actor.shared_client_states = { - 0: {0: False, 1: False}, - 1: {2: False, 3: False} - } + self.actor.shared_client_states = {0: {0: False, 1: False}, 1: {2: False, 3: False}} self.actor.total_active_client_count = 0 self.actor.num_clients_to_scale_up = 2 @@ -2096,10 +1870,7 @@ def test_scale_up_only_activates_n_clients(self): assert self.actor.total_active_client_count == 2 def test_scale_down_pauses_percentage(self): - self.actor.shared_client_states = { - 0: {0: True, 1: True}, - 1: {2: True, 3: True, 4: False} - } + self.actor.shared_client_states = {0: {0: True, 1: True}, 1: {2: True, 3: True, 4: False}} self.actor.total_active_client_count = 4 # 4 active clients self.actor.percentage_clients_to_scale_down = 0.5 @@ -2112,29 +1883,22 @@ def test_handle_state_scales_up_only_when_conditions_met(self): self.actor.total_active_client_count = 0 self.actor.last_error_time = time.perf_counter() - 31 self.actor.last_scaleup_time = time.perf_counter() - 2 - self.actor.shared_client_states = { - 0: {0: False, 1: False}, - 1: {2: False, 3: False} - } + self.actor.shared_client_states = {0: {0: False, 1: False}, 1: {2: False, 3: False}} self.monkeypatch.setattr(self.actor, "check_for_errors", lambda: []) - self.actor.handle_state() # once to set to SCALING_UP - self.actor.handle_state() # once to scale up + self.actor.handle_state() # once to set to SCALING_UP + self.actor.handle_state() # once to scale up assert self.actor.state == worker_coordinator.FeedbackState.NEUTRAL assert self.actor.total_active_client_count > 0 self.monkeypatch.undo() - def test_handle_state_enters_sleep_on_error(self): self.actor.state = worker_coordinator.FeedbackState.NEUTRAL self.actor.total_active_client_count = 2 - self.actor.shared_client_states = { - 0: {0: True, 1: True}, - 1: {2: True, 3: True, 4: False} - } + self.actor.shared_client_states = {0: {0: True, 1: True}, 1: {2: True, 3: True, 4: False}} self.actor.error_queue.put({"error": "foo"}) self.actor.handle_state() @@ -2146,20 +1910,16 @@ def test_check_cpu_usage_raises_system_setup_error(self): # CPU-based redline feedback is not supported in Solr Orbit; # _check_cpu_usage must raise SystemSetupError immediately. from solrorbit.exceptions import SystemSetupError + with self.assertRaises(SystemSetupError): self.actor._check_cpu_usage() # pylint: disable=protected-access + class TimePeriodBasedTests(TestCase): # pylint: disable=protected-access def test_time_period_based_without_ramp_down(self): # Test basic time-period based schedule without ramp-down - loop_control = worker_coordinator.TimePeriodBased( - warmup_time_period=10, - time_period=100, - ramp_down_time_period=None, - client_index=0, - total_clients=4 - ) + loop_control = worker_coordinator.TimePeriodBased(warmup_time_period=10, time_period=100, ramp_down_time_period=None, client_index=0, total_clients=4) # Verify duration calculation self.assertEqual(110, loop_control._duration) @@ -2170,13 +1930,7 @@ def test_time_period_based_without_ramp_down(self): def test_time_period_based_with_ramp_down_client_0(self): # Test ramp-down for client 0 (first client, stops earliest) - loop_control = worker_coordinator.TimePeriodBased( - warmup_time_period=10, - time_period=100, - ramp_down_time_period=20, - client_index=0, - total_clients=4 - ) + loop_control = worker_coordinator.TimePeriodBased(warmup_time_period=10, time_period=100, ramp_down_time_period=20, client_index=0, total_clients=4) # Client 0: reverse_index = 3, early_stop = 20 * (3/4) = 15 # duration = 110 - 15 = 95 @@ -2186,13 +1940,7 @@ def test_time_period_based_with_ramp_down_client_0(self): def test_time_period_based_with_ramp_down_client_3(self): # Test ramp-down for client 3 (last client, runs full duration) - loop_control = worker_coordinator.TimePeriodBased( - warmup_time_period=10, - time_period=100, - ramp_down_time_period=20, - client_index=3, - total_clients=4 - ) + loop_control = worker_coordinator.TimePeriodBased(warmup_time_period=10, time_period=100, ramp_down_time_period=20, client_index=3, total_clients=4) # Client 3: reverse_index = 0, early_stop = 20 * (0/4) = 0 # duration = 110 - 0 = 110 @@ -2209,11 +1957,7 @@ def test_time_period_based_with_ramp_down_all_clients(self): durations = [] for client_index in range(total_clients): loop_control = worker_coordinator.TimePeriodBased( - warmup_time_period=warmup, - time_period=time_period, - ramp_down_time_period=ramp_down, - client_index=client_index, - total_clients=total_clients + warmup_time_period=warmup, time_period=time_period, ramp_down_time_period=ramp_down, client_index=client_index, total_clients=total_clients ) durations.append(loop_control._duration) @@ -2222,23 +1966,16 @@ def test_time_period_based_with_ramp_down_all_clients(self): # Verify spacing is correct for i in range(1, len(durations)): - spacing = durations[i] - durations[i-1] + spacing = durations[i] - durations[i - 1] expected_spacing = ramp_down / total_clients self.assertAlmostEqual(expected_spacing, spacing, places=2) + class TaskRampDownTests(TestCase): def test_task_with_ramp_down_time_period(self): # Test that Task accepts ramp_down_time_period parameter op = workload.Operation("test-op", workload.OperationType.Bulk.to_hyphenated_string(), {}) - task = workload.Task( - name="test-task", - operation=op, - warmup_time_period=10, - time_period=100, - ramp_up_time_period=20, - ramp_down_time_period=30, - clients=4 - ) + task = workload.Task(name="test-task", operation=op, warmup_time_period=10, time_period=100, ramp_up_time_period=20, ramp_down_time_period=30, clients=4) self.assertEqual(10, task.warmup_time_period) self.assertEqual(100, task.time_period) @@ -2249,12 +1986,7 @@ def test_task_with_ramp_down_time_period(self): def test_task_without_ramp_down_defaults_to_none(self): # Test that ramp_down_time_period defaults to None op = workload.Operation("test-op", workload.OperationType.Bulk.to_hyphenated_string(), {}) - task = workload.Task( - name="test-task", - operation=op, - time_period=100, - clients=4 - ) + task = workload.Task(name="test-task", operation=op, time_period=100, clients=4) self.assertIsNone(task.ramp_down_time_period) @@ -2262,29 +1994,11 @@ def test_task_equality_with_ramp_down(self): # Test that tasks with different ramp_down_time_period are not equal op = workload.Operation("test-op", workload.OperationType.Bulk.to_hyphenated_string(), {}) - task1 = workload.Task( - name="test-task", - operation=op, - time_period=100, - ramp_down_time_period=20, - clients=4 - ) + task1 = workload.Task(name="test-task", operation=op, time_period=100, ramp_down_time_period=20, clients=4) - task2 = workload.Task( - name="test-task", - operation=op, - time_period=100, - ramp_down_time_period=30, - clients=4 - ) + task2 = workload.Task(name="test-task", operation=op, time_period=100, ramp_down_time_period=30, clients=4) - task3 = workload.Task( - name="test-task", - operation=op, - time_period=100, - ramp_down_time_period=20, - clients=4 - ) + task3 = workload.Task(name="test-task", operation=op, time_period=100, ramp_down_time_period=20, clients=4) self.assertNotEqual(task1, task2) self.assertEqual(task1, task3) @@ -2293,21 +2007,9 @@ def test_task_hash_includes_ramp_down(self): # Test that hash includes ramp_down_time_period op = workload.Operation("test-op", workload.OperationType.Bulk.to_hyphenated_string(), {}) - task1 = workload.Task( - name="test-task", - operation=op, - time_period=100, - ramp_down_time_period=20, - clients=4 - ) + task1 = workload.Task(name="test-task", operation=op, time_period=100, ramp_down_time_period=20, clients=4) - task2 = workload.Task( - name="test-task", - operation=op, - time_period=100, - ramp_down_time_period=30, - clients=4 - ) + task2 = workload.Task(name="test-task", operation=op, time_period=100, ramp_down_time_period=30, clients=4) # Different ramp_down should produce different hashes self.assertNotEqual(hash(task1), hash(task2)) diff --git a/tests/workload/__init__.py b/tests/workload/__init__.py index 5047a451..f5768141 100644 --- a/tests/workload/__init__.py +++ b/tests/workload/__init__.py @@ -16,7 +16,7 @@ # not use this file except in compliance with the License. # You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an diff --git a/tests/workload/loader_test.py b/tests/workload/loader_test.py index eccebb79..f0599100 100644 --- a/tests/workload/loader_test.py +++ b/tests/workload/loader_test.py @@ -79,10 +79,8 @@ def test_workload_from_directory(self, is_dir, path_exists): repo = loader.SimpleWorkloadRepository("/path/to/workload/unit-test") self.assertEqual("unit-test", repo.workload_name) self.assertEqual(["unit-test"], repo.workload_names) - self.assertEqual("/path/to/workload/unit-test", - repo.workload_dir("unit-test")) - self.assertEqual("/path/to/workload/unit-test/workload.json", - repo.workload_file("unit-test")) + self.assertEqual("/path/to/workload/unit-test", repo.workload_dir("unit-test")) + self.assertEqual("/path/to/workload/unit-test/workload.json", repo.workload_file("unit-test")) @mock.patch("os.path.exists") @mock.patch("os.path.isdir") @@ -92,14 +90,11 @@ def test_workload_from_file(self, is_file, is_dir, path_exists): is_dir.return_value = False path_exists.return_value = True - repo = loader.SimpleWorkloadRepository( - "/path/to/workload/unit-test/my-workload.json") + repo = loader.SimpleWorkloadRepository("/path/to/workload/unit-test/my-workload.json") self.assertEqual("my-workload", repo.workload_name) self.assertEqual(["my-workload"], repo.workload_names) - self.assertEqual("/path/to/workload/unit-test", - repo.workload_dir("my-workload")) - self.assertEqual("/path/to/workload/unit-test/my-workload.json", - repo.workload_file("my-workload")) + self.assertEqual("/path/to/workload/unit-test", repo.workload_dir("my-workload")) + self.assertEqual("/path/to/workload/unit-test/my-workload.json", repo.workload_file("my-workload")) @mock.patch("os.path.exists") @mock.patch("os.path.isdir") @@ -110,18 +105,15 @@ def test_workload_from_named_pipe(self, is_file, is_dir, path_exists): path_exists.return_value = True with self.assertRaises(exceptions.SystemSetupError) as ctx: - loader.SimpleWorkloadRepository( - "a named pipe cannot point to a workload") - self.assertEqual( - "a named pipe cannot point to a workload is neither a file nor a directory", ctx.exception.args[0]) + loader.SimpleWorkloadRepository("a named pipe cannot point to a workload") + self.assertEqual("a named pipe cannot point to a workload is neither a file nor a directory", ctx.exception.args[0]) @mock.patch("os.path.exists") def test_workload_from_non_existing_path(self, path_exists): path_exists.return_value = False with self.assertRaises(exceptions.SystemSetupError) as ctx: loader.SimpleWorkloadRepository("/path/does/not/exist") - self.assertEqual( - "Workload path /path/does/not/exist does not exist", ctx.exception.args[0]) + self.assertEqual("Workload path /path/does/not/exist does not exist", ctx.exception.args[0]) @mock.patch("os.path.isdir") @mock.patch("os.path.exists") @@ -131,8 +123,7 @@ def test_workload_from_directory_without_workload(self, path_exists, is_dir): is_dir.return_value = True with self.assertRaises(exceptions.SystemSetupError) as ctx: loader.SimpleWorkloadRepository("/path/to/not/a/workload") - self.assertEqual( - "Could not find workload.json in /path/to/not/a/workload", ctx.exception.args[0]) + self.assertEqual("Could not find workload.json in /path/to/not/a/workload", ctx.exception.args[0]) @mock.patch("os.path.exists") @mock.patch("os.path.isdir") @@ -143,10 +134,8 @@ def test_workload_from_file_but_not_json(self, is_file, is_dir, path_exists): path_exists.return_value = True with self.assertRaises(exceptions.SystemSetupError) as ctx: - loader.SimpleWorkloadRepository( - "/path/to/workload/unit-test/my-workload.xml") - self.assertEqual( - "/path/to/workload/unit-test/my-workload.xml has to be a JSON file", ctx.exception.args[0]) + loader.SimpleWorkloadRepository("/path/to/workload/unit-test/my-workload.xml") + self.assertEqual("/path/to/workload/unit-test/my-workload.xml has to be a JSON file", ctx.exception.args[0]) class GitRepositoryTests(TestCase): @@ -157,29 +146,21 @@ def __init__(self, remote_url, root_dir, repo_name, resource_name, offline, fetc @mock.patch("os.path.exists") @mock.patch("os.walk") def test_workload_from_existing_repo(self, walk, exists): - walk.return_value = iter( - [(".", ["unittest", "unittest2", "unittest3"], [])]) + walk.return_value = iter([(".", ["unittest", "unittest2", "unittest3"], [])]) exists.return_value = True cfg = config.Config() - cfg.add(config.Scope.application, "workload", - "workload.name", "unittest") - cfg.add(config.Scope.application, "workload", - "repository.name", "default") + cfg.add(config.Scope.application, "workload", "workload.name", "unittest") + cfg.add(config.Scope.application, "workload", "repository.name", "default") cfg.add(config.Scope.application, "system", "offline.mode", False) cfg.add(config.Scope.application, "node", "root.dir", "/tmp") - cfg.add(config.Scope.application, "benchmarks", - "workload.repository.dir", "workloads") + cfg.add(config.Scope.application, "benchmarks", "workload.repository.dir", "workloads") - repo = loader.GitWorkloadRepository( - cfg, fetch=False, update=False, repo_class=GitRepositoryTests.MockGitRepo) + repo = loader.GitWorkloadRepository(cfg, fetch=False, update=False, repo_class=GitRepositoryTests.MockGitRepo) self.assertEqual("unittest", repo.workload_name) - self.assertEqual(["unittest", "unittest2", "unittest3"], - list(repo.workload_names)) - self.assertEqual("/tmp/workloads/default/unittest", - repo.workload_dir("unittest")) - self.assertEqual( - "/tmp/workloads/default/unittest/workload.json", repo.workload_file("unittest")) + self.assertEqual(["unittest", "unittest2", "unittest3"], list(repo.workload_names)) + self.assertEqual("/tmp/workloads/default/unittest", repo.workload_dir("unittest")) + self.assertEqual("/tmp/workloads/default/unittest/workload.json", repo.workload_file("unittest")) class WorkloadPreparationTests(TestCase): @@ -191,21 +172,21 @@ def test_does_nothing_if_document_file_available(self, is_file, get_size, prepar get_size.return_value = 2000 prepare_file_offset_table.return_value = 5 - p = loader.DocumentSetPreparator(workload_name="unit-test", - downloader=loader.Downloader( - offline=False, test_mode=False), - decompressor=loader.Decompressor()) - - p.prepare_document_set(document_set=workload.Documents(source_format=workload.Documents.SOURCE_FORMAT_BULK, - document_file="docs.json", - document_archive="docs.json.bz2", - number_of_documents=5, - compressed_size_in_bytes=200, - uncompressed_size_in_bytes=2000), - data_root="/tmp") + p = loader.DocumentSetPreparator(workload_name="unit-test", downloader=loader.Downloader(offline=False, test_mode=False), decompressor=loader.Decompressor()) + + p.prepare_document_set( + document_set=workload.Documents( + source_format=workload.Documents.SOURCE_FORMAT_BULK, + document_file="docs.json", + document_archive="docs.json.bz2", + number_of_documents=5, + compressed_size_in_bytes=200, + uncompressed_size_in_bytes=2000, + ), + data_root="/tmp", + ) - prepare_file_offset_table.assert_called_with( - "/tmp/docs.json", None, None, InstanceOf(loader.Downloader)) + prepare_file_offset_table.assert_called_with("/tmp/docs.json", None, None, InstanceOf(loader.Downloader)) @mock.patch("solrorbit.utils.io.prepare_file_offset_table") @mock.patch("os.path.getsize") @@ -215,21 +196,21 @@ def test_decompresses_if_archive_available(self, is_file, get_size, prepare_file get_size.return_value = 2000 prepare_file_offset_table.return_value = 5 - p = loader.DocumentSetPreparator(workload_name="unit-test", - downloader=loader.Downloader( - offline=False, test_mode=False), - decompressor=loader.Decompressor()) - - p.prepare_document_set(document_set=workload.Documents(source_format=workload.Documents.SOURCE_FORMAT_BULK, - document_file="docs.json", - document_archive="docs.json.bz2", - number_of_documents=5, - compressed_size_in_bytes=200, - uncompressed_size_in_bytes=2000), - data_root="/tmp") + p = loader.DocumentSetPreparator(workload_name="unit-test", downloader=loader.Downloader(offline=False, test_mode=False), decompressor=loader.Decompressor()) + + p.prepare_document_set( + document_set=workload.Documents( + source_format=workload.Documents.SOURCE_FORMAT_BULK, + document_file="docs.json", + document_archive="docs.json.bz2", + number_of_documents=5, + compressed_size_in_bytes=200, + uncompressed_size_in_bytes=2000, + ), + data_root="/tmp", + ) - prepare_file_offset_table.assert_called_with( - "/tmp/docs.json", None, None, InstanceOf(loader.Downloader)) + prepare_file_offset_table.assert_called_with("/tmp/docs.json", None, None, InstanceOf(loader.Downloader)) @mock.patch("solrorbit.utils.io.decompress") @mock.patch("os.path.getsize") @@ -243,21 +224,21 @@ def test_raise_error_on_wrong_uncompressed_file_size(self, is_file, get_size, de # uncompressed is corrupt, only 1 byte available get_size.side_effect = [200, 1] - p = loader.DocumentSetPreparator(workload_name="unit-test", - downloader=loader.Downloader( - offline=False, test_mode=False), - decompressor=loader.Decompressor()) + p = loader.DocumentSetPreparator(workload_name="unit-test", downloader=loader.Downloader(offline=False, test_mode=False), decompressor=loader.Decompressor()) with self.assertRaises(exceptions.DataError) as ctx: - p.prepare_document_set(document_set=workload.Documents(source_format=workload.Documents.SOURCE_FORMAT_BULK, - document_file="docs.json", - document_archive="docs.json.bz2", - number_of_documents=5, - compressed_size_in_bytes=200, - uncompressed_size_in_bytes=2000), - data_root="/tmp") - self.assertEqual( - "[/tmp/docs.json] is corrupt. Extracted [1] bytes but [2000] bytes are expected.", ctx.exception.args[0]) + p.prepare_document_set( + document_set=workload.Documents( + source_format=workload.Documents.SOURCE_FORMAT_BULK, + document_file="docs.json", + document_archive="docs.json.bz2", + number_of_documents=5, + compressed_size_in_bytes=200, + uncompressed_size_in_bytes=2000, + ), + data_root="/tmp", + ) + self.assertEqual("[/tmp/docs.json] is corrupt. Extracted [1] bytes but [2000] bytes are expected.", ctx.exception.args[0]) decompress.assert_called_with("/tmp/docs.json.bz2", "/tmp") @@ -272,22 +253,25 @@ def test_raise_error_if_compressed_does_not_contain_expected_document_file(self, # compressed file size is 200 get_size.return_value = 200 - p = loader.DocumentSetPreparator(workload_name="unit-test", - downloader=loader.Downloader( - offline=False, test_mode=False), - decompressor=loader.Decompressor()) + p = loader.DocumentSetPreparator(workload_name="unit-test", downloader=loader.Downloader(offline=False, test_mode=False), decompressor=loader.Decompressor()) with self.assertRaises(exceptions.DataError) as ctx: - p.prepare_document_set(document_set=workload.Documents(source_format=workload.Documents.SOURCE_FORMAT_BULK, - base_url="http://benchmarks.opensearch.org/corpora/unit-test", - document_file="docs.json", - document_archive="docs.json.bz2", - number_of_documents=5, - compressed_size_in_bytes=200, - uncompressed_size_in_bytes=2000), - data_root="/tmp") - self.assertEqual("Decompressing [/tmp/docs.json.bz2] did not create [/tmp/docs.json]. Please check with the workload author if the " - "compressed archive has been created correctly.", ctx.exception.args[0]) + p.prepare_document_set( + document_set=workload.Documents( + source_format=workload.Documents.SOURCE_FORMAT_BULK, + base_url="http://benchmarks.opensearch.org/corpora/unit-test", + document_file="docs.json", + document_archive="docs.json.bz2", + number_of_documents=5, + compressed_size_in_bytes=200, + uncompressed_size_in_bytes=2000, + ), + data_root="/tmp", + ) + self.assertEqual( + "Decompressing [/tmp/docs.json.bz2] did not create [/tmp/docs.json]. Please check with the workload author if the compressed archive has been created correctly.", + ctx.exception.args[0], + ) decompress.assert_called_with("/tmp/docs.json.bz2", "/tmp") @@ -297,16 +281,14 @@ def test_raise_error_if_compressed_does_not_contain_expected_document_file(self, @mock.patch("solrorbit.utils.io.ensure_dir") @mock.patch("os.path.getsize") @mock.patch("os.path.isfile") - def test_download_document_archive_if_no_file_available(self, is_file, get_size, ensure_dir, download, decompress, - prepare_file_offset_table): + def test_download_document_archive_if_no_file_available(self, is_file, get_size, ensure_dir, download, decompress, prepare_file_offset_table): # uncompressed file does not exist # compressed file does not exist # after download compressed file exists # after download uncompressed file still does not exist (in main loop) # after download compressed file exists (in main loop) # after decompression, uncompressed file exists - is_file.side_effect = [False, False, - True, False, True, True, True, True] + is_file.side_effect = [False, False, True, False, True, True, True, True] # compressed file size is 200 after download # compressed file size is 200 after download (in main loop) # uncompressed file size is 2000 after decompression @@ -315,27 +297,26 @@ def test_download_document_archive_if_no_file_available(self, is_file, get_size, prepare_file_offset_table.return_value = 5 - p = loader.DocumentSetPreparator(workload_name="unit-test", - downloader=loader.Downloader( - offline=False, test_mode=False), - decompressor=loader.Decompressor()) - - p.prepare_document_set(document_set=workload.Documents(source_format=workload.Documents.SOURCE_FORMAT_BULK, - base_url="http://benchmarks.opensearch.org/corpora/unit-test", - document_file="docs.json", - document_archive="docs.json.bz2", - number_of_documents=5, - compressed_size_in_bytes=200, - uncompressed_size_in_bytes=2000), - data_root="/tmp") + p = loader.DocumentSetPreparator(workload_name="unit-test", downloader=loader.Downloader(offline=False, test_mode=False), decompressor=loader.Decompressor()) + + p.prepare_document_set( + document_set=workload.Documents( + source_format=workload.Documents.SOURCE_FORMAT_BULK, + base_url="http://benchmarks.opensearch.org/corpora/unit-test", + document_file="docs.json", + document_archive="docs.json.bz2", + number_of_documents=5, + compressed_size_in_bytes=200, + uncompressed_size_in_bytes=2000, + ), + data_root="/tmp", + ) ensure_dir.assert_called_with("/tmp") decompress.assert_called_with("/tmp/docs.json.bz2", "/tmp") - calls = [mock.call("http://benchmarks.opensearch.org/corpora/unit-test/docs.json.bz2", - "/tmp/docs.json.bz2", 200, progress_indicator=mock.ANY)] + calls = [mock.call("http://benchmarks.opensearch.org/corpora/unit-test/docs.json.bz2", "/tmp/docs.json.bz2", 200, progress_indicator=mock.ANY)] download.assert_has_calls(calls) - prepare_file_offset_table.assert_called_with("/tmp/docs.json", 'http://benchmarks.opensearch.org/corpora/unit-test', - None, InstanceOf(loader.Downloader)) + prepare_file_offset_table.assert_called_with("/tmp/docs.json", "http://benchmarks.opensearch.org/corpora/unit-test", None, InstanceOf(loader.Downloader)) @mock.patch("solrorbit.utils.io.prepare_file_offset_table") @mock.patch("solrorbit.utils.io.decompress") @@ -343,16 +324,14 @@ def test_download_document_archive_if_no_file_available(self, is_file, get_size, @mock.patch("solrorbit.utils.io.ensure_dir") @mock.patch("os.path.getsize") @mock.patch("os.path.isfile") - def test_download_document_archive_with_source_url_compressed(self, is_file, get_size, ensure_dir, download, decompress, - prepare_file_offset_table): + def test_download_document_archive_with_source_url_compressed(self, is_file, get_size, ensure_dir, download, decompress, prepare_file_offset_table): # uncompressed file does not exist # compressed file does not exist # after download compressed file exists # after download uncompressed file still does not exist (in main loop) # after download compressed file exists (in main loop) # after decompression, uncompressed file exists - is_file.side_effect = [False, False, - True, False, True, True, True, True] + is_file.side_effect = [False, False, True, False, True, True, True, True] # compressed file size is 200 after download # compressed file size is 200 after download (in main loop) # uncompressed file size is 2000 after decompression @@ -361,28 +340,28 @@ def test_download_document_archive_with_source_url_compressed(self, is_file, get prepare_file_offset_table.return_value = 5 - p = loader.DocumentSetPreparator(workload_name="unit-test", - downloader=loader.Downloader( - offline=False, test_mode=False), - decompressor=loader.Decompressor()) - - p.prepare_document_set(document_set=workload.Documents(source_format=workload.Documents.SOURCE_FORMAT_BULK, - base_url="http://benchmarks.opensearch.org/corpora", - source_url="http://benchmarks.opensearch.org/corpora/unit-test/docs.json.bz2", - document_file="docs.json", - document_archive="docs.json.bz2", - number_of_documents=5, - compressed_size_in_bytes=200, - uncompressed_size_in_bytes=2000), - data_root="/tmp") + p = loader.DocumentSetPreparator(workload_name="unit-test", downloader=loader.Downloader(offline=False, test_mode=False), decompressor=loader.Decompressor()) + + p.prepare_document_set( + document_set=workload.Documents( + source_format=workload.Documents.SOURCE_FORMAT_BULK, + base_url="http://benchmarks.opensearch.org/corpora", + source_url="http://benchmarks.opensearch.org/corpora/unit-test/docs.json.bz2", + document_file="docs.json", + document_archive="docs.json.bz2", + number_of_documents=5, + compressed_size_in_bytes=200, + uncompressed_size_in_bytes=2000, + ), + data_root="/tmp", + ) ensure_dir.assert_called_with("/tmp") decompress.assert_called_with("/tmp/docs.json.bz2", "/tmp") - download.assert_called_with("http://benchmarks.opensearch.org/corpora/unit-test/docs.json.bz2", - "/tmp/docs.json.bz2", 200, progress_indicator=mock.ANY) - prepare_file_offset_table.assert_called_with("/tmp/docs.json", 'http://benchmarks.opensearch.org/corpora', - 'http://benchmarks.opensearch.org/corpora/unit-test/docs.json.bz2', - InstanceOf(loader.Downloader)) + download.assert_called_with("http://benchmarks.opensearch.org/corpora/unit-test/docs.json.bz2", "/tmp/docs.json.bz2", 200, progress_indicator=mock.ANY) + prepare_file_offset_table.assert_called_with( + "/tmp/docs.json", "http://benchmarks.opensearch.org/corpora", "http://benchmarks.opensearch.org/corpora/unit-test/docs.json.bz2", InstanceOf(loader.Downloader) + ) @mock.patch("solrorbit.utils.io.prepare_file_offset_table") @mock.patch("solrorbit.utils.io.decompress") @@ -390,8 +369,7 @@ def test_download_document_archive_with_source_url_compressed(self, is_file, get @mock.patch("solrorbit.utils.io.ensure_dir") @mock.patch("os.path.getsize") @mock.patch("os.path.isfile") - def test_download_document_with_source_url_uncompressed(self, is_file, get_size, ensure_dir, download, decompress, - prepare_file_offset_table): + def test_download_document_with_source_url_uncompressed(self, is_file, get_size, ensure_dir, download, decompress, prepare_file_offset_table): # uncompressed file does not exist # after download uncompressed file exists # after download uncompressed file exists (main loop) @@ -402,28 +380,28 @@ def test_download_document_with_source_url_uncompressed(self, is_file, get_size, prepare_file_offset_table.return_value = 5 - p = loader.DocumentSetPreparator(workload_name="unit-test", - downloader=loader.Downloader( - offline=False, test_mode=False), - decompressor=loader.Decompressor()) - - p.prepare_document_set(document_set=workload.Documents(source_format=workload.Documents.SOURCE_FORMAT_BULK, - source_url=f"{scheme}://benchmarks.opensearch.org/corpora/unit-test/docs.json", - base_url=f"{scheme}://benchmarks.opensearch.org/corpora/", - document_file="docs.json", - # --> We don't provide a document archive here <-- - document_archive=None, - number_of_documents=5, - compressed_size_in_bytes=200, - uncompressed_size_in_bytes=2000), - data_root="/tmp") + p = loader.DocumentSetPreparator(workload_name="unit-test", downloader=loader.Downloader(offline=False, test_mode=False), decompressor=loader.Decompressor()) + + p.prepare_document_set( + document_set=workload.Documents( + source_format=workload.Documents.SOURCE_FORMAT_BULK, + source_url=f"{scheme}://benchmarks.opensearch.org/corpora/unit-test/docs.json", + base_url=f"{scheme}://benchmarks.opensearch.org/corpora/", + document_file="docs.json", + # --> We don't provide a document archive here <-- + document_archive=None, + number_of_documents=5, + compressed_size_in_bytes=200, + uncompressed_size_in_bytes=2000, + ), + data_root="/tmp", + ) ensure_dir.assert_called_with("/tmp") - download.assert_called_with(f"{scheme}://benchmarks.opensearch.org/corpora/unit-test/docs.json", - "/tmp/docs.json", 2000, progress_indicator=mock.ANY) - prepare_file_offset_table.assert_called_with("/tmp/docs.json", f"{scheme}://benchmarks.opensearch.org/corpora/", - f"{scheme}://benchmarks.opensearch.org/corpora/unit-test/docs.json", - InstanceOf(loader.Downloader)) + download.assert_called_with(f"{scheme}://benchmarks.opensearch.org/corpora/unit-test/docs.json", "/tmp/docs.json", 2000, progress_indicator=mock.ANY) + prepare_file_offset_table.assert_called_with( + "/tmp/docs.json", f"{scheme}://benchmarks.opensearch.org/corpora/", f"{scheme}://benchmarks.opensearch.org/corpora/unit-test/docs.json", InstanceOf(loader.Downloader) + ) @mock.patch("solrorbit.utils.io.prepare_file_offset_table") @mock.patch("solrorbit.utils.io.decompress") @@ -431,8 +409,7 @@ def test_download_document_with_source_url_uncompressed(self, is_file, get_size, @mock.patch("solrorbit.utils.io.ensure_dir") @mock.patch("os.path.getsize") @mock.patch("os.path.isfile") - def test_download_document_with_trailing_baseurl_slash(self, is_file, get_size, ensure_dir, download, decompress, - prepare_file_offset_table): + def test_download_document_with_trailing_baseurl_slash(self, is_file, get_size, ensure_dir, download, decompress, prepare_file_offset_table): # uncompressed file does not exist # after download uncompressed file exists # after download uncompressed file exists (main loop) @@ -443,27 +420,26 @@ def test_download_document_with_trailing_baseurl_slash(self, is_file, get_size, prepare_file_offset_table.return_value = 5 - p = loader.DocumentSetPreparator(workload_name="unit-test", - downloader=loader.Downloader( - offline=False, test_mode=False), - decompressor=loader.Decompressor()) - - p.prepare_document_set(document_set=workload.Documents(source_format=workload.Documents.SOURCE_FORMAT_BULK, - base_url=f"{scheme}://benchmarks.opensearch.org/corpora/unit-test/", - document_file="docs.json", - # --> We don't provide a document archive here <-- - document_archive=None, - number_of_documents=5, - compressed_size_in_bytes=200, - uncompressed_size_in_bytes=2000), - data_root="/tmp") + p = loader.DocumentSetPreparator(workload_name="unit-test", downloader=loader.Downloader(offline=False, test_mode=False), decompressor=loader.Decompressor()) + + p.prepare_document_set( + document_set=workload.Documents( + source_format=workload.Documents.SOURCE_FORMAT_BULK, + base_url=f"{scheme}://benchmarks.opensearch.org/corpora/unit-test/", + document_file="docs.json", + # --> We don't provide a document archive here <-- + document_archive=None, + number_of_documents=5, + compressed_size_in_bytes=200, + uncompressed_size_in_bytes=2000, + ), + data_root="/tmp", + ) ensure_dir.assert_called_with("/tmp") - calls = [mock.call(f"{scheme}://benchmarks.opensearch.org/corpora/unit-test/docs.json", - "/tmp/docs.json", 2000, progress_indicator=mock.ANY)] + calls = [mock.call(f"{scheme}://benchmarks.opensearch.org/corpora/unit-test/docs.json", "/tmp/docs.json", 2000, progress_indicator=mock.ANY)] download.assert_has_calls(calls) - prepare_file_offset_table.assert_called_with("/tmp/docs.json", f"{scheme}://benchmarks.opensearch.org/corpora/unit-test/", - None, InstanceOf(loader.Downloader)) + prepare_file_offset_table.assert_called_with("/tmp/docs.json", f"{scheme}://benchmarks.opensearch.org/corpora/unit-test/", None, InstanceOf(loader.Downloader)) @mock.patch("solrorbit.utils.io.prepare_file_offset_table") @mock.patch("solrorbit.utils.net.download") @@ -480,27 +456,26 @@ def test_download_document_file_if_no_file_available(self, is_file, get_size, en prepare_file_offset_table.return_value = 5 - p = loader.DocumentSetPreparator(workload_name="unit-test", - downloader=loader.Downloader( - offline=False, test_mode=False), - decompressor=loader.Decompressor()) - - p.prepare_document_set(document_set=workload.Documents(source_format=workload.Documents.SOURCE_FORMAT_BULK, - base_url="http://benchmarks.opensearch.org/corpora/unit-test", - document_file="docs.json", - # --> We don't provide a document archive here <-- - document_archive=None, - number_of_documents=5, - compressed_size_in_bytes=200, - uncompressed_size_in_bytes=2000), - data_root="/tmp") + p = loader.DocumentSetPreparator(workload_name="unit-test", downloader=loader.Downloader(offline=False, test_mode=False), decompressor=loader.Decompressor()) + + p.prepare_document_set( + document_set=workload.Documents( + source_format=workload.Documents.SOURCE_FORMAT_BULK, + base_url="http://benchmarks.opensearch.org/corpora/unit-test", + document_file="docs.json", + # --> We don't provide a document archive here <-- + document_archive=None, + number_of_documents=5, + compressed_size_in_bytes=200, + uncompressed_size_in_bytes=2000, + ), + data_root="/tmp", + ) ensure_dir.assert_called_with("/tmp") - calls = [mock.call("http://benchmarks.opensearch.org/corpora/unit-test/docs.json", - "/tmp/docs.json", 2000, progress_indicator=mock.ANY)] + calls = [mock.call("http://benchmarks.opensearch.org/corpora/unit-test/docs.json", "/tmp/docs.json", 2000, progress_indicator=mock.ANY)] download.assert_has_calls(calls) - prepare_file_offset_table.assert_called_with("/tmp/docs.json", 'http://benchmarks.opensearch.org/corpora/unit-test', - None, InstanceOf(loader.Downloader)) + prepare_file_offset_table.assert_called_with("/tmp/docs.json", "http://benchmarks.opensearch.org/corpora/unit-test", None, InstanceOf(loader.Downloader)) @mock.patch("solrorbit.utils.net.download") @mock.patch("solrorbit.utils.io.ensure_dir") @@ -509,21 +484,21 @@ def test_raise_download_error_if_offline(self, is_file, ensure_dir, download): # uncompressed file does not exist is_file.return_value = False - p = loader.DocumentSetPreparator(workload_name="unit-test", - downloader=loader.Downloader( - offline=True, test_mode=False), - decompressor=loader.Decompressor()) + p = loader.DocumentSetPreparator(workload_name="unit-test", downloader=loader.Downloader(offline=True, test_mode=False), decompressor=loader.Decompressor()) with self.assertRaises(exceptions.SystemSetupError) as ctx: - p.prepare_document_set(document_set=workload.Documents(source_format=workload.Documents.SOURCE_FORMAT_BULK, - base_url="http://benchmarks.opensearch.org/corpora/unit-test", - document_file="docs.json", - number_of_documents=5, - uncompressed_size_in_bytes=2000), - data_root="/tmp") + p.prepare_document_set( + document_set=workload.Documents( + source_format=workload.Documents.SOURCE_FORMAT_BULK, + base_url="http://benchmarks.opensearch.org/corpora/unit-test", + document_file="docs.json", + number_of_documents=5, + uncompressed_size_in_bytes=2000, + ), + data_root="/tmp", + ) - self.assertEqual( - "Cannot find [/tmp/docs.json]. Please disable offline mode and retry.", ctx.exception.args[0]) + self.assertEqual("Cannot find [/tmp/docs.json]. Please disable offline mode and retry.", ctx.exception.args[0]) self.assertEqual(0, ensure_dir.call_count) self.assertEqual(0, download.call_count) @@ -535,22 +510,22 @@ def test_raise_download_error_if_no_url_provided_and_file_missing(self, is_file, # uncompressed file does not exist is_file.return_value = False - p = loader.DocumentSetPreparator(workload_name="unit-test", - downloader=loader.Downloader( - offline=False, test_mode=False), - decompressor=loader.Decompressor()) + p = loader.DocumentSetPreparator(workload_name="unit-test", downloader=loader.Downloader(offline=False, test_mode=False), decompressor=loader.Decompressor()) with self.assertRaises(exceptions.DataError) as ctx: - p.prepare_document_set(document_set=workload.Documents(source_format=workload.Documents.SOURCE_FORMAT_BULK, - base_url=None, - document_file="docs.json", - document_archive=None, - number_of_documents=5, - uncompressed_size_in_bytes=2000), - data_root="/tmp") + p.prepare_document_set( + document_set=workload.Documents( + source_format=workload.Documents.SOURCE_FORMAT_BULK, + base_url=None, + document_file="docs.json", + document_archive=None, + number_of_documents=5, + uncompressed_size_in_bytes=2000, + ), + data_root="/tmp", + ) - self.assertEqual( - "Cannot download data because no base URL is provided.", ctx.exception.args[0]) + self.assertEqual("Cannot download data because no base URL is provided.", ctx.exception.args[0]) self.assertEqual(0, ensure_dir.call_count) self.assertEqual(0, download.call_count) @@ -565,20 +540,19 @@ def test_raise_download_error_if_no_url_provided_and_wrong_file_size(self, is_fi # but it's size is wrong get_size.return_value = 100 - p = loader.DocumentSetPreparator(workload_name="unit-test", - downloader=loader.Downloader( - offline=False, test_mode=False), - decompressor=loader.Decompressor()) + p = loader.DocumentSetPreparator(workload_name="unit-test", downloader=loader.Downloader(offline=False, test_mode=False), decompressor=loader.Decompressor()) with self.assertRaises(exceptions.DataError) as ctx: - p.prepare_document_set(document_set=workload.Documents(source_format=workload.Documents.SOURCE_FORMAT_BULK, - document_file="docs.json", - number_of_documents=5, - uncompressed_size_in_bytes=2000), - data_root="/tmp") + p.prepare_document_set( + document_set=workload.Documents( + source_format=workload.Documents.SOURCE_FORMAT_BULK, document_file="docs.json", number_of_documents=5, uncompressed_size_in_bytes=2000 + ), + data_root="/tmp", + ) - self.assertEqual("[/tmp/docs.json] is present but does not have the expected size of [2000] bytes and it " - "cannot be downloaded because no base URL is provided.", ctx.exception.args[0]) + self.assertEqual( + "[/tmp/docs.json] is present but does not have the expected size of [2000] bytes and it cannot be downloaded because no base URL is provided.", ctx.exception.args[0] + ) self.assertEqual(0, ensure_dir.call_count) self.assertEqual(0, download.call_count) @@ -590,28 +564,26 @@ def test_raise_download_error_no_test_mode_file(self, is_file, ensure_dir, downl # uncompressed file does not exist is_file.return_value = False - download.side_effect = urllib.error.HTTPError("http://benchmarks.opensearch.org.s3.amazonaws.com/corpora/unit-test/docs-1k.json", - 404, "", None, None) + download.side_effect = urllib.error.HTTPError("http://benchmarks.opensearch.org.s3.amazonaws.com/corpora/unit-test/docs-1k.json", 404, "", None, None) - p = loader.DocumentSetPreparator(workload_name="unit-test", - downloader=loader.Downloader( - offline=False, test_mode=True), - decompressor=loader.Decompressor()) + p = loader.DocumentSetPreparator(workload_name="unit-test", downloader=loader.Downloader(offline=False, test_mode=True), decompressor=loader.Decompressor()) with self.assertRaises(exceptions.DataError) as ctx: - p.prepare_document_set(document_set=workload.Documents(source_format=workload.Documents.SOURCE_FORMAT_BULK, - base_url="http://benchmarks.opensearch.org/corpora/unit-test", - document_file="docs-1k.json", - number_of_documents=5, - uncompressed_size_in_bytes=None), - data_root="/tmp") + p.prepare_document_set( + document_set=workload.Documents( + source_format=workload.Documents.SOURCE_FORMAT_BULK, + base_url="http://benchmarks.opensearch.org/corpora/unit-test", + document_file="docs-1k.json", + number_of_documents=5, + uncompressed_size_in_bytes=None, + ), + data_root="/tmp", + ) - self.assertEqual("This workload does not support test mode. Ask the workload author to add it or disable " - "test mode and retry.", ctx.exception.args[0]) + self.assertEqual("This workload does not support test mode. Ask the workload author to add it or disable test mode and retry.", ctx.exception.args[0]) ensure_dir.assert_called_with("/tmp") - download.assert_called_with("http://benchmarks.opensearch.org/corpora/unit-test/docs-1k.json", - "/tmp/docs-1k.json", None, progress_indicator=mock.ANY) + download.assert_called_with("http://benchmarks.opensearch.org/corpora/unit-test/docs-1k.json", "/tmp/docs-1k.json", None, progress_indicator=mock.ANY) @mock.patch("solrorbit.utils.net.download") @mock.patch("solrorbit.utils.io.ensure_dir") @@ -620,28 +592,29 @@ def test_raise_download_error_on_connection_problems(self, is_file, ensure_dir, # uncompressed file does not exist is_file.return_value = False - download.side_effect = urllib.error.HTTPError("http://benchmarks.opensearch.org/corpora/unit-test/docs.json", - 500, "Internal Server Error", None, None) + download.side_effect = urllib.error.HTTPError("http://benchmarks.opensearch.org/corpora/unit-test/docs.json", 500, "Internal Server Error", None, None) - p = loader.DocumentSetPreparator(workload_name="unit-test", - downloader=loader.Downloader( - offline=False, test_mode=False), - decompressor=loader.Decompressor()) + p = loader.DocumentSetPreparator(workload_name="unit-test", downloader=loader.Downloader(offline=False, test_mode=False), decompressor=loader.Decompressor()) with self.assertRaises(exceptions.DataError) as ctx: - p.prepare_document_set(document_set=workload.Documents(source_format=workload.Documents.SOURCE_FORMAT_BULK, - base_url="http://benchmarks.opensearch.org/corpora/unit-test", - document_file="docs.json", - number_of_documents=5, - uncompressed_size_in_bytes=2000), - data_root="/tmp") + p.prepare_document_set( + document_set=workload.Documents( + source_format=workload.Documents.SOURCE_FORMAT_BULK, + base_url="http://benchmarks.opensearch.org/corpora/unit-test", + document_file="docs.json", + number_of_documents=5, + uncompressed_size_in_bytes=2000, + ), + data_root="/tmp", + ) - self.assertEqual("Could not download [http://benchmarks.opensearch.org/corpora/unit-test/docs.json] " - "to [/tmp/docs.json] (HTTP status: 500, reason: Internal Server Error)", ctx.exception.args[0]) + self.assertEqual( + "Could not download [http://benchmarks.opensearch.org/corpora/unit-test/docs.json] to [/tmp/docs.json] (HTTP status: 500, reason: Internal Server Error)", + ctx.exception.args[0], + ) ensure_dir.assert_called_with("/tmp") - download.assert_called_with("http://benchmarks.opensearch.org/corpora/unit-test/docs.json", - "/tmp/docs.json", 2000, progress_indicator=mock.ANY) + download.assert_called_with("http://benchmarks.opensearch.org/corpora/unit-test/docs.json", "/tmp/docs.json", 2000, progress_indicator=mock.ANY) @mock.patch("solrorbit.utils.io.prepare_file_offset_table") @mock.patch("solrorbit.utils.io.decompress") @@ -653,21 +626,23 @@ def test_prepare_bundled_document_set_if_document_file_available(self, is_file, get_size.side_effect = [2000] prepare_file_offset_table.return_value = 5 - p = loader.DocumentSetPreparator(workload_name="unit-test", - downloader=loader.Downloader( - offline=False, test_mode=False), - decompressor=loader.Decompressor()) - - self.assertTrue(p.prepare_bundled_document_set(document_set=workload.Documents(source_format=workload.Documents.SOURCE_FORMAT_BULK, - document_file="docs.json", - document_archive="docs.json.bz2", - number_of_documents=5, - compressed_size_in_bytes=200, - uncompressed_size_in_bytes=2000), - data_root=".")) + p = loader.DocumentSetPreparator(workload_name="unit-test", downloader=loader.Downloader(offline=False, test_mode=False), decompressor=loader.Decompressor()) + + self.assertTrue( + p.prepare_bundled_document_set( + document_set=workload.Documents( + source_format=workload.Documents.SOURCE_FORMAT_BULK, + document_file="docs.json", + document_archive="docs.json.bz2", + number_of_documents=5, + compressed_size_in_bytes=200, + uncompressed_size_in_bytes=2000, + ), + data_root=".", + ) + ) - prepare_file_offset_table.assert_called_with( - "./docs.json", None, None, InstanceOf(loader.Downloader)) + prepare_file_offset_table.assert_called_with("./docs.json", None, None, InstanceOf(loader.Downloader)) @mock.patch("solrorbit.utils.io.prepare_file_offset_table") @mock.patch("solrorbit.utils.io.decompress") @@ -677,18 +652,21 @@ def test_prepare_bundled_document_set_does_nothing_if_no_document_files(self, is # no files present is_file.return_value = False - p = loader.DocumentSetPreparator(workload_name="unit-test", - downloader=loader.Downloader( - offline=False, test_mode=False), - decompressor=loader.Decompressor()) - - self.assertFalse(p.prepare_bundled_document_set(document_set=workload.Documents(source_format=workload.Documents.SOURCE_FORMAT_BULK, - document_file="docs.json", - document_archive="docs.json.bz2", - number_of_documents=5, - compressed_size_in_bytes=200, - uncompressed_size_in_bytes=2000), - data_root=".")) + p = loader.DocumentSetPreparator(workload_name="unit-test", downloader=loader.Downloader(offline=False, test_mode=False), decompressor=loader.Decompressor()) + + self.assertFalse( + p.prepare_bundled_document_set( + document_set=workload.Documents( + source_format=workload.Documents.SOURCE_FORMAT_BULK, + document_file="docs.json", + document_archive="docs.json.bz2", + number_of_documents=5, + compressed_size_in_bytes=200, + uncompressed_size_in_bytes=2000, + ), + data_root=".", + ) + ) self.assertEqual(0, decompress.call_count) self.assertEqual(0, prepare_file_offset_table.call_count) @@ -711,23 +689,23 @@ def test_used_corpora(self): "source-file": "documents-181998.unparsed.json.bz2", "document-count": 2708746, "compressed-bytes": 13064317, - "uncompressed-bytes": 303920342 + "uncompressed-bytes": 303920342, }, { "target-collection": "logs-191998", "source-file": "documents-191998.unparsed.json.bz2", "document-count": 9697882, "compressed-bytes": 47211781, - "uncompressed-bytes": 1088378738 + "uncompressed-bytes": 1088378738, }, { "target-collection": "logs-201998", "source-file": "documents-201998.unparsed.json.bz2", "document-count": 13053463, "compressed-bytes": 63174979, - "uncompressed-bytes": 1456836090 - } - ] + "uncompressed-bytes": 1456836090, + }, + ], }, { "name": "http_logs", @@ -738,51 +716,30 @@ def test_used_corpora(self): "source-file": "documents-181998.json.bz2", "document-count": 2708746, "compressed-bytes": 13815456, - "uncompressed-bytes": 363512754 + "uncompressed-bytes": 363512754, }, { "target-collection": "logs-191998", "source-file": "documents-191998.json.bz2", "document-count": 9697882, "compressed-bytes": 49439633, - "uncompressed-bytes": 1301732149 + "uncompressed-bytes": 1301732149, }, { "target-collection": "logs-201998", "source-file": "documents-201998.json.bz2", "document-count": 13053463, "compressed-bytes": 65623436, - "uncompressed-bytes": 1744012279 - } - ] - } + "uncompressed-bytes": 1744012279, + }, + ], + }, ], "operations": [ - { - "name": "bulk-index-1", - "operation-type": "bulk", - "corpora": ["http_logs"], - "indices": ["logs-181998"], - "bulk-size": 500 - }, - { - "name": "bulk-index-2", - "operation-type": "bulk", - "corpora": ["http_logs"], - "indices": ["logs-191998"], - "bulk-size": 500 - }, - { - "name": "bulk-index-3", - "operation-type": "bulk", - "corpora": ["http_logs_unparsed"], - "indices": ["logs-201998"], - "bulk-size": 500 - }, - { - "name": "node-stats", - "operation-type": "node-stats" - }, + {"name": "bulk-index-1", "operation-type": "bulk", "corpora": ["http_logs"], "indices": ["logs-181998"], "bulk-size": 500}, + {"name": "bulk-index-2", "operation-type": "bulk", "corpora": ["http_logs"], "indices": ["logs-191998"], "bulk-size": 500}, + {"name": "bulk-index-3", "operation-type": "bulk", "corpora": ["http_logs_unparsed"], "indices": ["logs-201998"], "bulk-size": 500}, + {"name": "node-stats", "operation-type": "node-stats"}, ], "test_procedures": [ { @@ -806,27 +763,21 @@ def test_used_corpora(self): ] } }, - { - "operation": "node-stats" - } - ] + {"operation": "node-stats"}, + ], } - ] + ], } - reader = loader.WorkloadSpecificationReader( - selected_test_procedure="default-test_procedure") + reader = loader.WorkloadSpecificationReader(selected_test_procedure="default-test_procedure") full_workload = reader("unittest", workload_specification, "/mappings") - used_corpora = sorted(loader.used_corpora( - full_workload), key=lambda c: c.name) + used_corpora = sorted(loader.used_corpora(full_workload), key=lambda c: c.name) self.assertEqual(2, len(used_corpora)) self.assertEqual("http_logs", used_corpora[0].name) # each bulk operation requires a different data file but they should have been merged properly. - self.assertEqual({"documents-181998.json.bz2", "documents-191998.json.bz2"}, - {d.document_archive for d in used_corpora[0].documents}) + self.assertEqual({"documents-181998.json.bz2", "documents-191998.json.bz2"}, {d.document_archive for d in used_corpora[0].documents}) self.assertEqual("http_logs_unparsed", used_corpora[1].name) - self.assertEqual({"documents-201998.unparsed.json.bz2"}, - {d.document_archive for d in used_corpora[1].documents}) + self.assertEqual({"documents-201998.unparsed.json.bz2"}, {d.document_archive for d in used_corpora[1].documents}) @mock.patch("solrorbit.utils.io.prepare_file_offset_table") @mock.patch("solrorbit.utils.io.decompress") @@ -844,21 +795,23 @@ def test_prepare_bundled_document_set_decompresses_compressed_docs(self, is_file get_size.side_effect = [200, 2000, 2000] prepare_file_offset_table.return_value = 5 - p = loader.DocumentSetPreparator(workload_name="unit-test", - downloader=loader.Downloader( - offline=False, test_mode=False), - decompressor=loader.Decompressor()) - - self.assertTrue(p.prepare_bundled_document_set(document_set=workload.Documents(source_format=workload.Documents.SOURCE_FORMAT_BULK, - document_file="docs.json", - document_archive="docs.json.bz2", - number_of_documents=5, - compressed_size_in_bytes=200, - uncompressed_size_in_bytes=2000), - data_root=".")) + p = loader.DocumentSetPreparator(workload_name="unit-test", downloader=loader.Downloader(offline=False, test_mode=False), decompressor=loader.Decompressor()) + + self.assertTrue( + p.prepare_bundled_document_set( + document_set=workload.Documents( + source_format=workload.Documents.SOURCE_FORMAT_BULK, + document_file="docs.json", + document_archive="docs.json.bz2", + number_of_documents=5, + compressed_size_in_bytes=200, + uncompressed_size_in_bytes=2000, + ), + data_root=".", + ) + ) - prepare_file_offset_table.assert_called_with( - "./docs.json", None, None, InstanceOf(loader.Downloader)) + prepare_file_offset_table.assert_called_with("./docs.json", None, None, InstanceOf(loader.Downloader)) @mock.patch("os.path.getsize") @mock.patch("os.path.isfile") @@ -869,22 +822,22 @@ def test_prepare_bundled_document_set_error_compressed_docs_wrong_size(self, is_ # compressed has wrong size get_size.side_effect = [150] - p = loader.DocumentSetPreparator(workload_name="unit-test", - downloader=loader.Downloader( - offline=False, test_mode=False), - decompressor=loader.Decompressor()) + p = loader.DocumentSetPreparator(workload_name="unit-test", downloader=loader.Downloader(offline=False, test_mode=False), decompressor=loader.Decompressor()) with self.assertRaises(exceptions.DataError) as ctx: - p.prepare_bundled_document_set(document_set=workload.Documents(source_format=workload.Documents.SOURCE_FORMAT_BULK, - document_file="docs.json", - document_archive="docs.json.bz2", - number_of_documents=5, - compressed_size_in_bytes=200, - uncompressed_size_in_bytes=2000), - data_root=".") + p.prepare_bundled_document_set( + document_set=workload.Documents( + source_format=workload.Documents.SOURCE_FORMAT_BULK, + document_file="docs.json", + document_archive="docs.json.bz2", + number_of_documents=5, + compressed_size_in_bytes=200, + uncompressed_size_in_bytes=2000, + ), + data_root=".", + ) - self.assertEqual("[./docs.json.bz2] is present but does not have the expected size of [200] bytes.", - ctx.exception.args[0]) + self.assertEqual("[./docs.json.bz2] is present but does not have the expected size of [200] bytes.", ctx.exception.args[0]) @mock.patch("solrorbit.utils.io.prepare_file_offset_table") @mock.patch("solrorbit.utils.io.decompress") @@ -896,21 +849,21 @@ def test_prepare_bundled_document_set_uncompressed_docs_wrong_size(self, is_file # uncompressed get_size.side_effect = [1500] - p = loader.DocumentSetPreparator(workload_name="unit-test", - downloader=loader.Downloader( - offline=False, test_mode=False), - decompressor=loader.Decompressor()) + p = loader.DocumentSetPreparator(workload_name="unit-test", downloader=loader.Downloader(offline=False, test_mode=False), decompressor=loader.Decompressor()) with self.assertRaises(exceptions.DataError) as ctx: - p.prepare_bundled_document_set(document_set=workload.Documents(source_format=workload.Documents.SOURCE_FORMAT_BULK, - document_file="docs.json", - document_archive="docs.json.bz2", - number_of_documents=5, - compressed_size_in_bytes=200, - uncompressed_size_in_bytes=2000), - data_root=".") - self.assertEqual("[./docs.json] is present but does not have the expected size of [2000] bytes.", - ctx.exception.args[0]) + p.prepare_bundled_document_set( + document_set=workload.Documents( + source_format=workload.Documents.SOURCE_FORMAT_BULK, + document_file="docs.json", + document_archive="docs.json.bz2", + number_of_documents=5, + compressed_size_in_bytes=200, + uncompressed_size_in_bytes=2000, + ), + data_root=".", + ) + self.assertEqual("[./docs.json] is present but does not have the expected size of [2000] bytes.", ctx.exception.args[0]) self.assertEqual(0, prepare_file_offset_table.call_count) @@ -932,36 +885,35 @@ def test_download_document_file_from_part_files(self, rm_file, is_file, get_size prepare_file_offset_table.return_value = 5 - p = loader.DocumentSetPreparator(workload_name="unit-test", - downloader=loader.Downloader( - offline=False, test_mode=False), - decompressor=loader.Decompressor()) + p = loader.DocumentSetPreparator(workload_name="unit-test", downloader=loader.Downloader(offline=False, test_mode=False), decompressor=loader.Decompressor()) mo = mock.mock_open() with mock.patch("builtins.open", mo): - p.prepare_document_set(document_set=workload.Documents(source_format=workload.Documents.SOURCE_FORMAT_BULK, - base_url="http://benchmarks.opensearch.org/corpora/unit-test", - document_file="docs.json", - document_file_parts=[{"name": "xaa", "size": 1000}, - {"name": "xab", - "size": 600}, - {"name": "xac", "size": 400}], - # --> We don't provide a document archive here <-- - document_archive=None, - number_of_documents=5, - compressed_size_in_bytes=200, - uncompressed_size_in_bytes=2000), - data_root="/tmp") + p.prepare_document_set( + document_set=workload.Documents( + source_format=workload.Documents.SOURCE_FORMAT_BULK, + base_url="http://benchmarks.opensearch.org/corpora/unit-test", + document_file="docs.json", + document_file_parts=[{"name": "xaa", "size": 1000}, {"name": "xab", "size": 600}, {"name": "xac", "size": 400}], + # --> We don't provide a document archive here <-- + document_archive=None, + number_of_documents=5, + compressed_size_in_bytes=200, + uncompressed_size_in_bytes=2000, + ), + data_root="/tmp", + ) ensure_dir.assert_called_with("/tmp") - calls = [mock.call('http://benchmarks.opensearch.org/corpora/unit-test/xaa', '/tmp/xaa', 1000, progress_indicator=mock.ANY), - mock.call('http://benchmarks.opensearch.org/corpora/unit-test/xab', - '/tmp/xab', 600, progress_indicator=mock.ANY), - mock.call('http://benchmarks.opensearch.org/corpora/unit-test/xac', '/tmp/xac', 400, progress_indicator=mock.ANY)] + calls = [ + mock.call("http://benchmarks.opensearch.org/corpora/unit-test/xaa", "/tmp/xaa", 1000, progress_indicator=mock.ANY), + mock.call("http://benchmarks.opensearch.org/corpora/unit-test/xab", "/tmp/xab", 600, progress_indicator=mock.ANY), + mock.call("http://benchmarks.opensearch.org/corpora/unit-test/xac", "/tmp/xac", 400, progress_indicator=mock.ANY), + ] download.assert_has_calls(calls) - prepare_file_offset_table.assert_called_with("/tmp/docs.json", 'http://benchmarks.opensearch.org/corpora/unit-test', - None, InstanceOf(loader.Downloader)) + prepare_file_offset_table.assert_called_with("/tmp/docs.json", "http://benchmarks.opensearch.org/corpora/unit-test", None, InstanceOf(loader.Downloader)) + class TemplateSource(TestCase): @mock.patch("solrorbit.utils.io.dirname") @@ -1003,7 +955,7 @@ def test_entrypoint_of_replace_includes(self, patched_read_glob, patched_dirname """) def dummy_read_glob(c): - return "{{\"replaced {}\": \"true\"}}".format(c) + return '{{"replaced {}": "true"}}'.format(c) patched_read_glob.side_effect = dummy_read_glob @@ -1046,21 +998,16 @@ def dummy_read_glob(c): } """) - self.assertEqual( - expected_response, - tmpl_src.replace_includes(base_path, workload) - ) + self.assertEqual(expected_response, tmpl_src.replace_includes(base_path, workload)) def test_read_glob_files(self): tmpl_obj = loader.TemplateSource( base_path="/some/path/to/a/benchmark/workload", template_file_name="workload.json", fileglobber=lambda pat: [ - os.path.join(os.path.dirname(__file__), - "resources", "workload_fragment_1.json"), - os.path.join(os.path.dirname(__file__), - "resources", "workload_fragment_2.json") - ] + os.path.join(os.path.dirname(__file__), "resources", "workload_fragment_1.json"), + os.path.join(os.path.dirname(__file__), "resources", "workload_fragment_2.json"), + ], ) response = tmpl_obj.read_glob_files("*workload_fragment_*.json") expected_response = '{\n "item1": "value1"\n}\n,\n{\n "item2": "value2"\n}\n' @@ -1069,8 +1016,7 @@ def test_read_glob_files(self): class TemplateRenderTests(TestCase): - unittest_template_internal_vars = loader.default_internal_template_vars( - clock=StaticClock) + unittest_template_internal_vars = loader.default_internal_template_vars(clock=StaticClock) def test_render_simple_template(self): template = """ @@ -1080,8 +1026,7 @@ def test_render_simple_template(self): } """ - rendered = loader.render_template( - template, template_internal_vars=TemplateRenderTests.unittest_template_internal_vars) + rendered = loader.render_template(template, template_internal_vars=TemplateRenderTests.unittest_template_internal_vars) expected = """ { @@ -1099,8 +1044,7 @@ def test_render_template_with_external_variables(self): } """ - rendered = loader.render_template(template, template_vars={"greeting": "Hi"}, - template_internal_vars=TemplateRenderTests.unittest_template_internal_vars) + rendered = loader.render_template(template, template_vars={"greeting": "Hi"}, template_internal_vars=TemplateRenderTests.unittest_template_internal_vars) expected = """ { @@ -1130,25 +1074,18 @@ def key_globber(e): } """ - source = io.DictStringFileSourceFactory({ - "dynamic-key-1": [ - textwrap.dedent('"dkey1": "value1"') - ], - "dynamic-key-2": [ - textwrap.dedent('"dkey2": "value2"') - ], - "dynamic-key-3": [ - textwrap.dedent('"dkey3": "value3"') - ] - }) + source = io.DictStringFileSourceFactory( + { + "dynamic-key-1": [textwrap.dedent('"dkey1": "value1"')], + "dynamic-key-2": [textwrap.dedent('"dkey2": "value2"')], + "dynamic-key-3": [textwrap.dedent('"dkey3": "value3"')], + } + ) - template_source = loader.TemplateSource( - "", "workload.json", source=source, fileglobber=key_globber) + template_source = loader.TemplateSource("", "workload.json", source=source, fileglobber=key_globber) template_source.load_template_from_string(template) - rendered = loader.render_template( - template_source.assembled_source, - template_internal_vars=TemplateRenderTests.unittest_template_internal_vars) + rendered = loader.render_template(template_source.assembled_source, template_internal_vars=TemplateRenderTests.unittest_template_internal_vars) expected = """ { @@ -1172,10 +1109,7 @@ def test_render_template_with_variables(self): "dkey2": {{ _bulk_size }} } """ - rendered = loader.render_template( - template, - template_vars={"clients": 8}, - template_internal_vars=TemplateRenderTests.unittest_template_internal_vars) + rendered = loader.render_template(template, template_vars={"clients": 8}, template_internal_vars=TemplateRenderTests.unittest_template_internal_vars) expected = """ { @@ -1200,61 +1134,37 @@ class CompleteWorkloadParamsTests(TestCase): def test_check_complete_workload_params_contains_all_workload_params(self): complete_workload_params = loader.CompleteWorkloadParams() - loader.register_all_params_in_workload( - CompleteWorkloadParamsTests.assembled_source, complete_workload_params) + loader.register_all_params_in_workload(CompleteWorkloadParamsTests.assembled_source, complete_workload_params) - self.assertEqual( - ["value2", "value3"], - complete_workload_params.sorted_workload_defined_params - ) + self.assertEqual(["value2", "value3"], complete_workload_params.sorted_workload_defined_params) def test_check_complete_workload_params_does_not_fail_with_no_workload_params(self): complete_workload_params = loader.CompleteWorkloadParams() - loader.register_all_params_in_workload('{}', complete_workload_params) + loader.register_all_params_in_workload("{}", complete_workload_params) - self.assertEqual( - [], - complete_workload_params.sorted_workload_defined_params - ) + self.assertEqual([], complete_workload_params.sorted_workload_defined_params) def test_unused_user_defined_workload_params(self): workload_params = { "number_of_repliacs": 1, # deliberate typo "enable_source": True, # unknown parameter - "number_of_shards": 5 + "number_of_shards": 5, } - complete_workload_params = loader.CompleteWorkloadParams( - user_specified_workload_params=workload_params) - complete_workload_params.populate_workload_defined_params(list_of_workload_params=[ - "bulk_indexing_clients", - "bulk_indexing_iterations", - "bulk_size", - "cluster_health", - "number_of_replicas", - "number_of_shards"] + complete_workload_params = loader.CompleteWorkloadParams(user_specified_workload_params=workload_params) + complete_workload_params.populate_workload_defined_params( + list_of_workload_params=["bulk_indexing_clients", "bulk_indexing_iterations", "bulk_size", "cluster_health", "number_of_replicas", "number_of_shards"] ) - self.assertEqual( - ["enable_source", "number_of_repliacs"], - sorted(complete_workload_params.unused_user_defined_workload_params()) - ) + self.assertEqual(["enable_source", "number_of_repliacs"], sorted(complete_workload_params.unused_user_defined_workload_params())) def test_unused_user_defined_workload_params_doesnt_fail_with_detaults(self): complete_workload_params = loader.CompleteWorkloadParams() - complete_workload_params.populate_workload_defined_params(list_of_workload_params=[ - "bulk_indexing_clients", - "bulk_indexing_iterations", - "bulk_size", - "cluster_health", - "number_of_replicas", - "number_of_shards"] + complete_workload_params.populate_workload_defined_params( + list_of_workload_params=["bulk_indexing_clients", "bulk_indexing_iterations", "bulk_size", "cluster_health", "number_of_replicas", "number_of_shards"] ) - self.assertEqual( - [], - sorted(complete_workload_params.unused_user_defined_workload_params()) - ) + self.assertEqual([], sorted(complete_workload_params.unused_user_defined_workload_params())) class WorkloadPostProcessingTests(TestCase): @@ -1336,37 +1246,9 @@ class WorkloadPostProcessingTests(TestCase): def test_post_processes_workload_spec(self): workload_specification = { - "indices": [ - { - "name": "test-index", - "body": "test-index-body.json", - "types": ["test-type"] - } - ], - "corpora": [ - { - "name": "unittest", - "documents": [ - { - "source-file": "documents.json.bz2", - "document-count": 10, - "compressed-bytes": 100, - "uncompressed-bytes": 10000 - } - ] - } - ], - "operations": [ - { - "name": "index-append", - "operation-type": "bulk", - "bulk-size": 5000 - }, - { - "name": "search", - "operation-type": "search" - } - ], + "indices": [{"name": "test-index", "body": "test-index-body.json", "types": ["test-type"]}], + "corpora": [{"name": "unittest", "documents": [{"source-file": "documents.json.bz2", "document-count": 10, "compressed-bytes": 100, "uncompressed-bytes": 10000}]}], + "operations": [{"name": "index-append", "operation-type": "bulk", "bulk-size": 5000}, {"name": "search", "operation-type": "search"}], "test_procedures": [ { "name": "default-test_procedure", @@ -1381,66 +1263,21 @@ def test_post_processes_workload_spec(self): { "parallel": { "tasks": [ - { - "name": "search #1", - "clients": 4, - "operation": "search", - "warmup-iterations": 1000, - "iterations": 2000, - "target-interval": 30 - }, - { - "name": "search #2", - "clients": 1, - "operation": "search", - "warmup-iterations": 1000, - "iterations": 2000, - "target-throughput": 200 - }, - { - "name": "search #3", - "clients": 1, - "operation": "search", - "iterations": 1 - } + {"name": "search #1", "clients": 4, "operation": "search", "warmup-iterations": 1000, "iterations": 2000, "target-interval": 30}, + {"name": "search #2", "clients": 1, "operation": "search", "warmup-iterations": 1000, "iterations": 2000, "target-throughput": 200}, + {"name": "search #3", "clients": 1, "operation": "search", "iterations": 1}, ] } - } - ] + }, + ], } - ] + ], } expected_post_processed = { - "indices": [ - { - "name": "test-index", - "body": "test-index-body.json", - "types": ["test-type"] - } - ], - "corpora": [ - { - "name": "unittest", - "documents": [ - { - "source-file": "documents-1k.json.bz2", - "document-count": 1000 - } - ] - } - ], - "operations": [ - { - "name": "index-append", - "operation-type": "bulk", - "bulk-size": 5000 - }, - { - "name": "search", - "operation-type": "search" - } - ], + "indices": [{"name": "test-index", "body": "test-index-body.json", "types": ["test-type"]}], + "corpora": [{"name": "unittest", "documents": [{"source-file": "documents-1k.json.bz2", "document-count": 1000}]}], + "operations": [{"name": "index-append", "operation-type": "bulk", "bulk-size": 5000}, {"name": "search", "operation-type": "search"}], "test_procedures": [ { "name": "default-test_procedure", @@ -1455,63 +1292,37 @@ def test_post_processes_workload_spec(self): { "parallel": { "tasks": [ - { - "name": "search #1", - "clients": 4, - "operation": "search", - "warmup-iterations": 4, - "iterations": 4 - }, - { - "name": "search #2", - "clients": 1, - "operation": "search", - "warmup-iterations": 1, - "iterations": 1 - }, - { - "name": "search #3", - "clients": 1, - "operation": "search", - "iterations": 1 - } + {"name": "search #1", "clients": 4, "operation": "search", "warmup-iterations": 4, "iterations": 4}, + {"name": "search #2", "clients": 1, "operation": "search", "warmup-iterations": 1, "iterations": 1}, + {"name": "search #3", "clients": 1, "operation": "search", "iterations": 1}, ] } - } - ] + }, + ], } - ] + ], } complete_workload_params = loader.CompleteWorkloadParams() - index_body = '{"settings": {"index.number_of_shards": {{ number_of_shards | default(5) }}, '\ - '"index.number_of_replicas": {{ number_of_replicas | default(0)}} }}' + index_body = '{"settings": {"index.number_of_shards": {{ number_of_shards | default(5) }}, "index.number_of_replicas": {{ number_of_replicas | default(0)}} }}' cfg = config.Config() - cfg.add(config.Scope.application, "workload", - "test.mode.enabled", True) + cfg.add(config.Scope.application, "workload", "test.mode.enabled", True) self.assertEqual( - self.as_workload( - expected_post_processed, complete_workload_params=complete_workload_params, index_body=index_body), + self.as_workload(expected_post_processed, complete_workload_params=complete_workload_params, index_body=index_body), loader.TestModeWorkloadProcessor(cfg).on_after_load_workload( - self.as_workload( - workload_specification, complete_workload_params=complete_workload_params, index_body=index_body) - ) + self.as_workload(workload_specification, complete_workload_params=complete_workload_params, index_body=index_body) + ), ) - self.assertEqual( - [], - complete_workload_params.sorted_workload_defined_params - ) + self.assertEqual([], complete_workload_params.sorted_workload_defined_params) def as_workload(self, workload_specification, workload_params=None, complete_workload_params=None, index_body=None): reader = loader.WorkloadSpecificationReader( workload_params=workload_params, complete_workload_params=complete_workload_params, - source=io.DictStringFileSourceFactory({ - "/mappings/test-index-body.json": [index_body] - }) + source=io.DictStringFileSourceFactory({"/mappings/test-index-body.json": [index_body]}), ) return reader("unittest", workload_specification, "/mappings") @@ -1522,95 +1333,66 @@ def test_sets_absolute_path(self, path_exists): path_exists.return_value = True cfg = config.Config() - cfg.add(config.Scope.application, "benchmarks", - "local.dataset.cache", "/data") + cfg.add(config.Scope.application, "benchmarks", "local.dataset.cache", "/data") - default_test_procedure = workload.TestProcedure("default", default=True, schedule=[ - workload.Task(name="index", operation=workload.Operation( - "index", operation_type=workload.OperationType.Bulk), clients=4) - ]) + default_test_procedure = workload.TestProcedure( + "default", default=True, schedule=[workload.Task(name="index", operation=workload.Operation("index", operation_type=workload.OperationType.Bulk), clients=4)] + ) another_test_procedure = workload.TestProcedure("other", default=False) - t = workload.Workload(name="u", test_procedures=[another_test_procedure, default_test_procedure], - corpora=[ - workload.DocumentCorpus("unittest", documents=[ - workload.Documents(source_format=workload.Documents.SOURCE_FORMAT_BULK, - document_file="docs/documents.json", - document_archive="docs/documents.json.bz2") - ]) - ], - ) + t = workload.Workload( + name="u", + test_procedures=[another_test_procedure, default_test_procedure], + corpora=[ + workload.DocumentCorpus( + "unittest", + documents=[ + workload.Documents(source_format=workload.Documents.SOURCE_FORMAT_BULK, document_file="docs/documents.json", document_archive="docs/documents.json.bz2") + ], + ) + ], + ) loader.set_absolute_data_path(cfg, t) - self.assertEqual("/data/unittest/docs/documents.json", - t.corpora[0].documents[0].document_file) - self.assertEqual("/data/unittest/docs/documents.json.bz2", - t.corpora[0].documents[0].document_archive) + self.assertEqual("/data/unittest/docs/documents.json", t.corpora[0].documents[0].document_file) + self.assertEqual("/data/unittest/docs/documents.json.bz2", t.corpora[0].documents[0].document_archive) class WorkloadFilterTests(TestCase): def filter(self, workload_specification, include_tasks=None, exclude_tasks=None): cfg = config.Config() - cfg.add(config.Scope.application, "workload", - "include.tasks", include_tasks) - cfg.add(config.Scope.application, "workload", - "exclude.tasks", exclude_tasks) + cfg.add(config.Scope.application, "workload", "include.tasks", include_tasks) + cfg.add(config.Scope.application, "workload", "exclude.tasks", exclude_tasks) processor = loader.TaskFilterWorkloadProcessor(cfg) return processor.on_after_load_workload(workload_specification) def test_rejects_invalid_syntax(self): with self.assertRaises(exceptions.SystemSetupError) as ctx: - self.filter(workload_specification=None, - include_tasks=["valid", "a:b:c"]) - self.assertEqual( - "Invalid format for filtered tasks: [a:b:c]", ctx.exception.args[0]) + self.filter(workload_specification=None, include_tasks=["valid", "a:b:c"]) + self.assertEqual("Invalid format for filtered tasks: [a:b:c]", ctx.exception.args[0]) def test_rejects_unknown_filter_type(self): with self.assertRaises(exceptions.SystemSetupError) as ctx: - self.filter(workload_specification=None, - include_tasks=["valid", "op-type:index"]) - self.assertEqual("Invalid format for filtered tasks: [op-type:index]. Expected [type] but got [op-type].", - ctx.exception.args[0]) + self.filter(workload_specification=None, include_tasks=["valid", "op-type:index"]) + self.assertEqual("Invalid format for filtered tasks: [op-type:index]. Expected [type] but got [op-type].", ctx.exception.args[0]) def test_filters_tasks(self): workload_specification = { "description": "description for unit test", "indices": [{"name": "test-index", "auto-managed": False}], "operations": [ - { - "name": "create-index", - "operation-type": "create-index" - }, - { - "name": "bulk-index", - "operation-type": "bulk" - }, - { - "name": "node-stats", - "operation-type": "node-stats" - }, - { - "name": "cluster-stats", - "operation-type": "custom-operation-type" - }, - { - "name": "match-all", - "operation-type": "search", - "body": { - "query": { - "match_all": {} - } - } - }, + {"name": "create-index", "operation-type": "create-index"}, + {"name": "bulk-index", "operation-type": "bulk"}, + {"name": "node-stats", "operation-type": "node-stats"}, + {"name": "cluster-stats", "operation-type": "custom-operation-type"}, + {"name": "match-all", "operation-type": "search", "body": {"query": {"match_all": {}}}}, ], "test_procedures": [ { "name": "default-test_procedure", "schedule": [ - { - "operation": "create-index" - }, + {"operation": "create-index"}, { "parallel": { "tasks": [ @@ -1633,16 +1415,9 @@ def test_filters_tasks(self): ] } }, - { - "operation": "node-stats" - }, - { - "name": "match-all-serial", - "operation": "match-all" - }, - { - "operation": "cluster-stats" - }, + {"operation": "node-stats"}, + {"name": "match-all-serial", "operation": "match-all"}, + {"operation": "cluster-stats"}, { "parallel": { "tasks": [ @@ -1659,37 +1434,36 @@ def test_filters_tasks(self): { "name": "index-5", "operation": "bulk-index", - } + }, ] } }, - { - "name": "final-cluster-stats", - "operation": "cluster-stats", - "tags": "include-me" - } - ] + {"name": "final-cluster-stats", "operation": "cluster-stats", "tags": "include-me"}, + ], } - ] + ], } reader = loader.WorkloadSpecificationReader() full_workload = reader("unittest", workload_specification, "/mappings") self.assertEqual(7, len(full_workload.test_procedures[0].schedule)) - filtered = self.filter(full_workload, include_tasks=["index-3", - "type:search", - # Filtering should also work for non-core operation types. - "type:custom-operation-type", - "tag:include-me"]) + filtered = self.filter( + full_workload, + include_tasks=[ + "index-3", + "type:search", + # Filtering should also work for non-core operation types. + "type:custom-operation-type", + "tag:include-me", + ], + ) schedule = filtered.test_procedures[0].schedule self.assertEqual(5, len(schedule)) - self.assertEqual(["index-3", "match-all-parallel"], - [t.name for t in schedule[0].tasks]) + self.assertEqual(["index-3", "match-all-parallel"], [t.name for t in schedule[0].tasks]) self.assertEqual("match-all-serial", schedule[1].name) self.assertEqual("cluster-stats", schedule[2].name) - self.assertEqual(["query-filtered", "index-4"], - [t.name for t in schedule[3].tasks]) + self.assertEqual(["query-filtered", "index-4"], [t.name for t in schedule[3].tasks]) self.assertEqual("final-cluster-stats", schedule[4].name) def test_filters_exclude_tasks(self): @@ -1697,39 +1471,17 @@ def test_filters_exclude_tasks(self): "description": "description for unit test", "indices": [{"name": "test-index", "auto-managed": False}], "operations": [ - { - "name": "create-index", - "operation-type": "create-index" - }, - { - "name": "bulk-index", - "operation-type": "bulk" - }, - { - "name": "node-stats", - "operation-type": "node-stats" - }, - { - "name": "cluster-stats", - "operation-type": "custom-operation-type" - }, - { - "name": "match-all", - "operation-type": "search", - "body": { - "query": { - "match_all": {} - } - } - }, + {"name": "create-index", "operation-type": "create-index"}, + {"name": "bulk-index", "operation-type": "bulk"}, + {"name": "node-stats", "operation-type": "node-stats"}, + {"name": "cluster-stats", "operation-type": "custom-operation-type"}, + {"name": "match-all", "operation-type": "search", "body": {"query": {"match_all": {}}}}, ], "test_procedures": [ { "name": "default-test_procedure", "schedule": [ - { - "operation": "create-index" - }, + {"operation": "create-index"}, { "parallel": { "tasks": [ @@ -1752,31 +1504,22 @@ def test_filters_exclude_tasks(self): ] } }, - { - "operation": "node-stats" - }, - { - "name": "match-all-serial", - "operation": "match-all" - }, - { - "operation": "cluster-stats" - } - ] + {"operation": "node-stats"}, + {"name": "match-all-serial", "operation": "match-all"}, + {"operation": "cluster-stats"}, + ], } - ] + ], } reader = loader.WorkloadSpecificationReader() full_workload = reader("unittest", workload_specification, "/mappings") self.assertEqual(5, len(full_workload.test_procedures[0].schedule)) - filtered = self.filter(full_workload, exclude_tasks=[ - "index-3", "type:search", "create-index"]) + filtered = self.filter(full_workload, exclude_tasks=["index-3", "type:search", "create-index"]) schedule = filtered.test_procedures[0].schedule self.assertEqual(3, len(schedule)) - self.assertEqual(["index-1", "index-2"], - [t.name for t in schedule[0].tasks]) + self.assertEqual(["index-1", "index-2"], [t.name for t in schedule[0].tasks]) self.assertEqual("node-stats", schedule[1].name) self.assertEqual("cluster-stats", schedule[2].name) @@ -1785,55 +1528,24 @@ def test_unmatched_exclude_runs_everything(self): "description": "description for unit test", "indices": [{"name": "test-index", "auto-managed": False}], "operations": [ - { - "name": "create-index", - "operation-type": "create-index" - }, - { - "name": "bulk-index", - "operation-type": "bulk" - }, - { - "name": "node-stats", - "operation-type": "node-stats" - }, - { - "name": "cluster-stats", - "operation-type": "custom-operation-type" - }, - { - "name": "match-all", - "operation-type": "search", - "body": { - "query": { - "match_all": {} - } - } - }, + {"name": "create-index", "operation-type": "create-index"}, + {"name": "bulk-index", "operation-type": "bulk"}, + {"name": "node-stats", "operation-type": "node-stats"}, + {"name": "cluster-stats", "operation-type": "custom-operation-type"}, + {"name": "match-all", "operation-type": "search", "body": {"query": {"match_all": {}}}}, ], "test_procedures": [ { "name": "default-test_procedure", "schedule": [ - { - "operation": "create-index" - }, - { - "operation": "bulk-index" - }, - { - "operation": "node-stats" - }, - { - "name": "match-all-serial", - "operation": "match-all" - }, - { - "operation": "cluster-stats" - } - ] + {"operation": "create-index"}, + {"operation": "bulk-index"}, + {"operation": "node-stats"}, + {"name": "match-all-serial", "operation": "match-all"}, + {"operation": "cluster-stats"}, + ], } - ] + ], } reader = loader.WorkloadSpecificationReader() @@ -1851,55 +1563,24 @@ def test_unmatched_include_runs_nothing(self): "description": "description for unit test", "indices": [{"name": "test-index", "auto-managed": False}], "operations": [ - { - "name": "create-index", - "operation-type": "create-index" - }, - { - "name": "bulk-index", - "operation-type": "bulk" - }, - { - "name": "node-stats", - "operation-type": "node-stats" - }, - { - "name": "cluster-stats", - "operation-type": "custom-operation-type" - }, - { - "name": "match-all", - "operation-type": "search", - "body": { - "query": { - "match_all": {} - } - } - }, + {"name": "create-index", "operation-type": "create-index"}, + {"name": "bulk-index", "operation-type": "bulk"}, + {"name": "node-stats", "operation-type": "node-stats"}, + {"name": "cluster-stats", "operation-type": "custom-operation-type"}, + {"name": "match-all", "operation-type": "search", "body": {"query": {"match_all": {}}}}, ], "test_procedures": [ { "name": "default-test_procedure", "schedule": [ - { - "operation": "create-index" - }, - { - "operation": "bulk-index" - }, - { - "operation": "node-stats" - }, - { - "name": "match-all-serial", - "operation": "match-all" - }, - { - "operation": "cluster-stats" - } - ] + {"operation": "create-index"}, + {"operation": "bulk-index"}, + {"operation": "node-stats"}, + {"name": "match-all-serial", "operation": "match-all"}, + {"operation": "cluster-stats"}, + ], } - ] + ], } reader = loader.WorkloadSpecificationReader() @@ -1912,8 +1593,8 @@ def test_unmatched_include_runs_nothing(self): schedule = filtered.test_procedures[0].schedule self.assertEqual(expected_schedule, schedule) -class WorkloadRandomizationTests(TestCase): +class WorkloadRandomizationTests(TestCase): # Helper class used to set up queries with mock standard values for testing # We want >1 op to ensure logic for giving different ops their own lambdas is working properly class StandardValueHelper: @@ -1929,24 +1610,16 @@ def __init__(self): # to be able to distinguish when we generate a new value vs draw an "existing" one. # in actual usage, these would come from the same function with some randomness in it self.saved_values = { - self.op_name_1:{ - self.field_name_1:{"lte":40, "gte":30}, - self.field_name_2:{"lte":"06/06/2016", "gte":"05/05/2016", "format":"dd/MM/yyyy"} - }, - self.op_name_2:{ - self.field_name_3:{"top_left":[-9, 9], "bottom_right":[0, 0]} - } + self.op_name_1: {self.field_name_1: {"lte": 40, "gte": 30}, self.field_name_2: {"lte": "06/06/2016", "gte": "05/05/2016", "format": "dd/MM/yyyy"}}, + self.op_name_2: {self.field_name_3: {"top_left": [-9, 9], "bottom_right": [0, 0]}}, } # Used to generate new values, in the source function self.new_values = { - self.op_name_1:{ - self.field_name_1:{"lte":41, "gte":31}, - self.field_name_2:{"lte":"04/04/2016", "gte":"03/03/2016", "format":"dd/MM/yyyy"} + self.op_name_1: {self.field_name_1: {"lte": 41, "gte": 31}, self.field_name_2: {"lte": "04/04/2016", "gte": "03/03/2016", "format": "dd/MM/yyyy"}}, + self.op_name_2: { + self.field_name_3: {"top_left": [-10, 10], "bottom_right": [0, 0]}, }, - self.op_name_2:{ - self.field_name_3:{"top_left":[-10, 10], "bottom_right":[0, 0]}, - } } self.op_1_query = { @@ -1957,43 +1630,18 @@ def __init__(self): "query": { "bool": { "filter": { - "range": { - self.field_name_1: { - "lt": 50, - "gte": 0 - } - }, - "must": [ - { - "range": { - self.field_name_2: { - "gte": "01/01/2015", - "lte": "21/01/2015", - "format": "dd/MM/yyyy" - } - } - } - ] + "range": {self.field_name_1: {"lt": 50, "gte": 0}}, + "must": [{"range": {self.field_name_2: {"gte": "01/01/2015", "lte": "21/01/2015", "format": "dd/MM/yyyy"}}}], } } - } - } + }, + }, } self.op_2_query = { "name": self.op_name_2, "operation-type": "search", - "body": { - "size": 0, - "query": { - "geo_bounding_box": { - self.field_name_3: { - "top_left": [-0.1, 61.0], - "bottom_right": [15.0, 48.0] - } - } - } - } + "body": {"size": 0, "query": {"geo_bounding_box": {self.field_name_3: {"top_left": [-0.1, 61.0], "bottom_right": [15.0, 48.0]}}}}, } def get_simple_workload(self): @@ -2001,21 +1649,12 @@ def get_simple_workload(self): workload_specification = { "description": "description for unit test", "collections": [{"name": self.index_name}], - "operations": [ - { - "name": "create-index", - "operation-type": "create-index" - }, - self.op_1_query, - self.op_2_query - ], + "operations": [{"name": "create-index", "operation-type": "create-index"}, self.op_1_query, self.op_2_query], "test_procedures": [ { "name": "default-test_procedure", "schedule": [ - { - "operation": "create-index" - }, + {"operation": "create-index"}, { "name": "dummy-task-name-1", "operation": self.op_name_1, @@ -2024,9 +1663,9 @@ def get_simple_workload(self): "name": "dummy-task-name-2", "operation": self.op_name_2, }, - ] + ], } - ] + ], } reader = loader.WorkloadSpecificationReader() full_workload = reader("unittest", workload_specification, "/mappings") @@ -2048,24 +1687,9 @@ def test_range_finding_function(self): single_range_query = { "name": "distance_amount_agg", "operation-type": "search", - "body": { - "size": 0, - "query": { - "bool": { - "filter": { - "range": { - "trip_distance": { - "lt": 50, - "gte": 0 - } - } - } - } - } - } + "body": {"size": 0, "query": {"bool": {"filter": {"range": {"trip_distance": {"lt": 50, "gte": 0}}}}}}, } - single_range_query_result = processor.extract_fields_and_paths( - single_range_query, loader.QueryRandomizerWorkloadProcessor.DEFAULT_QUERY_RANDOMIZATION_INFO) + single_range_query_result = processor.extract_fields_and_paths(single_range_query, loader.QueryRandomizerWorkloadProcessor.DEFAULT_QUERY_RANDOMIZATION_INFO) single_range_query_expected = [("trip_distance", ["bool", "filter", "range"])] self.assertEqual(single_range_query_result, single_range_query_expected) @@ -2075,91 +1699,50 @@ def test_range_finding_function(self): "body": { "size": 0, "query": { - "range": { - "dropoff_datetime": { - "gte": "01/01/2015", - "lte": "21/01/2015", - "format": "dd/MM/yyyy" - } - }, - "bool": { - "filter": { - "range": { - "dummy_field": { - "lte": 50, - "gt": 0 - } - } + "range": {"dropoff_datetime": {"gte": "01/01/2015", "lte": "21/01/2015", "format": "dd/MM/yyyy"}}, + "bool": { + "filter": {"range": {"dummy_field": {"lte": 50, "gt": 0}}}, + "must": [ + {"range": {"dummy_field_2": {"gte": "1998-05-01T00:00:00Z", "lt": "1998-05-02T00:00:00Z"}}}, + {"match": {"status": "400"}}, + {"range": {"dummy_field_3": {"gt": 10, "lt": 11}}}, + ], }, - "must": [ - { - "range": { - "dummy_field_2": { - "gte": "1998-05-01T00:00:00Z", - "lt": "1998-05-02T00:00:00Z" - } - } - }, - { - "match": { - "status": "400" - } - }, - { - "range": { - "dummy_field_3": { - "gt": 10, - "lt": 11 - } - } - } - ] - } - } - } + }, + }, } multiple_nested_range_query_result = processor.extract_fields_and_paths( - multiple_nested_range_query, loader.QueryRandomizerWorkloadProcessor.DEFAULT_QUERY_RANDOMIZATION_INFO) + multiple_nested_range_query, loader.QueryRandomizerWorkloadProcessor.DEFAULT_QUERY_RANDOMIZATION_INFO + ) print("Multi result: ", multiple_nested_range_query_result) multiple_nested_range_query_expected = [ ("dropoff_datetime", ["range"]), ("dummy_field", ["bool", "filter", "range"]), ("dummy_field_2", ["bool", "must", 0, "range"]), - ("dummy_field_3", ["bool", "must", 2, "range"]) - ] + ("dummy_field_3", ["bool", "must", 2, "range"]), + ] self.assertEqual(multiple_nested_range_query_result, multiple_nested_range_query_expected) with self.assertRaises(exceptions.SystemSetupError) as ctx: - params = {"body":{"contents":["not_a_valid_query"]}} + params = {"body": {"contents": ["not_a_valid_query"]}} _ = processor.extract_fields_and_paths(params, loader.QueryRandomizerWorkloadProcessor.DEFAULT_QUERY_RANDOMIZATION_INFO) self.assertEqual( - f"Cannot extract range query fields from these params: {params}\n, missing params[\"body\"][\"query\"]\n" + f'Cannot extract range query fields from these params: {params}\n, missing params["body"]["query"]\n' f"Make sure the operation in operations/default.json is well-formed", - ctx.exception.args[0]) + ctx.exception.args[0], + ) # Test a non-default value for query_randomization_info geo_point_query = { "name": "bbox", "operation-type": "search", - "body": { - "size": 0, - "query": { - "geo_bounding_box": { - "location": { - "top_left": [-0.1, 61.0], - "bottom_right": [15.0, 48.0] - } - } - } - } + "body": {"size": 0, "query": {"geo_bounding_box": {"location": {"top_left": [-0.1, 61.0], "bottom_right": [15.0, 48.0]}}}}, } - geo_point_query_randomization_info = loader.QueryRandomizerWorkloadProcessor.QueryRandomizationInfo( - "geo_bounding_box", [["top_left"], ["bottom_right"]], []) + geo_point_query_randomization_info = loader.QueryRandomizerWorkloadProcessor.QueryRandomizationInfo("geo_bounding_box", [["top_left"], ["bottom_right"]], []) geo_point_result = processor.extract_fields_and_paths(geo_point_query, geo_point_query_randomization_info) geo_point_expected = [("location", ["geo_bounding_box"])] self.assertEqual(geo_point_result, geo_point_expected) - def test_get_randomized_values(self): helper = self.StandardValueHelper() @@ -2173,12 +1756,14 @@ def test_get_randomized_values(self): # Test resulting params for operation 1 workload = helper.get_simple_workload() - modified_params = processor.get_randomized_values(workload, - helper.op_1_query, - loader.QueryRandomizerWorkloadProcessor.DEFAULT_QUERY_RANDOMIZATION_INFO, - op_name=helper.op_name_1, - get_standard_value=helper.get_standard_value, - get_standard_value_source=helper.get_standard_value_source) + modified_params = processor.get_randomized_values( + workload, + helper.op_1_query, + loader.QueryRandomizerWorkloadProcessor.DEFAULT_QUERY_RANDOMIZATION_INFO, + op_name=helper.op_name_1, + get_standard_value=helper.get_standard_value, + get_standard_value_source=helper.get_standard_value_source, + ) modified_range_1 = modified_params["body"]["query"]["bool"]["filter"]["range"][helper.field_name_1] modified_range_2 = modified_params["body"]["query"]["bool"]["filter"]["must"][0]["range"][helper.field_name_2] self.assertEqual(modified_range_1["lt"], expected_values_dict[helper.op_name_1][helper.field_name_1]["lte"]) @@ -2193,12 +1778,15 @@ def test_get_randomized_values(self): # Test resulting params for operation 2, which uses a non-default query_randomization_info workload = helper.get_simple_workload() - geo_point_query_randomization_info = loader.QueryRandomizerWorkloadProcessor.QueryRandomizationInfo( - "geo_bounding_box", [["top_left"], ["bottom_right"]], []) - modified_params = processor.get_randomized_values(workload, helper.op_2_query, geo_point_query_randomization_info, - op_name=helper.op_name_2, - get_standard_value=helper.get_standard_value, - get_standard_value_source=helper.get_standard_value_source) + geo_point_query_randomization_info = loader.QueryRandomizerWorkloadProcessor.QueryRandomizationInfo("geo_bounding_box", [["top_left"], ["bottom_right"]], []) + modified_params = processor.get_randomized_values( + workload, + helper.op_2_query, + geo_point_query_randomization_info, + op_name=helper.op_name_2, + get_standard_value=helper.get_standard_value, + get_standard_value_source=helper.get_standard_value_source, + ) modified_range_3 = modified_params["body"]["query"]["geo_bounding_box"][helper.field_name_3] self.assertEqual(modified_range_3["top_left"], expected_values_dict[helper.op_name_2][helper.field_name_3]["top_left"]) @@ -2213,8 +1801,8 @@ def test_on_after_load_workload(self): input_workload = helper.get_simple_workload() self.assertEqual( repr(input_workload), - repr(processor.on_after_load_workload(input_workload, get_standard_value=helper.get_standard_value, - get_standard_value_source=helper.get_standard_value_source))) + repr(processor.on_after_load_workload(input_workload, get_standard_value=helper.get_standard_value, get_standard_value_source=helper.get_standard_value_source)), + ) # It seems that comparing the workloads directly will incorrectly call them equal, even if they have differences, # so compare their string representations instead @@ -2229,9 +1817,12 @@ def test_on_after_load_workload(self): input_workload = helper.get_simple_workload() self.assertNotEqual( repr(input_workload), - repr(processor.on_after_load_workload(input_workload, get_standard_value=helper.get_standard_value, - get_standard_value_source=helper.get_standard_value_source, - query_randomization_info=None))) + repr( + processor.on_after_load_workload( + input_workload, get_standard_value=helper.get_standard_value, get_standard_value_source=helper.get_standard_value_source, query_randomization_info=None + ) + ), + ) for test_procedure in input_workload.test_procedures: for task in test_procedure.schedule: for leaf_task in task: @@ -2252,8 +1843,7 @@ def test_description_is_optional(self): } reader = loader.WorkloadSpecificationReader() - resulting_workload = reader( - "unittest", workload_specification, "/mappings") + resulting_workload = reader("unittest", workload_specification, "/mappings") self.assertEqual("unittest", resulting_workload.name) self.assertEqual("", resulting_workload.description) @@ -2264,58 +1854,31 @@ def test_can_read_workload_info(self): "data-streams": [], "corpora": [], "operations": [], - "test_procedures": [] + "test_procedures": [], } reader = loader.WorkloadSpecificationReader() - resulting_workload = reader( - "unittest", workload_specification, "/mappings") + resulting_workload = reader("unittest", workload_specification, "/mappings") self.assertEqual("unittest", resulting_workload.name) - self.assertEqual("description for unit test", - resulting_workload.description) + self.assertEqual("description for unit test", resulting_workload.description) def test_document_count_mandatory_if_file_present(self): workload_specification = { "description": "description for unit test", "indices": [{"name": "test-index", "types": ["docs"]}], - "corpora": [ - { - "name": "test", - "base-url": "https://localhost/data", - "documents": [{"source-file": "documents-main.json.bz2"}] - } - ], - "test_procedures": [] + "corpora": [{"name": "test", "base-url": "https://localhost/data", "documents": [{"source-file": "documents-main.json.bz2"}]}], + "test_procedures": [], } reader = loader.WorkloadSpecificationReader() with self.assertRaises(loader.WorkloadSyntaxError) as ctx: reader("unittest", workload_specification, "/mappings") - self.assertEqual( - "Workload 'unittest' is invalid. Mandatory element 'document-count' is missing.", ctx.exception.args[0]) + self.assertEqual("Workload 'unittest' is invalid. Mandatory element 'document-count' is missing.", ctx.exception.args[0]) @mock.patch("solrorbit.workload.loader.register_all_params_in_workload") def test_parse_with_mixed_warmup_iterations_and_measurement(self, mocked_params_checker): workload_specification = { "description": "description for unit test", - "indices": [ - { - "name": "test-index", - "body": "index.json", - "types": ["docs"] - } - ], - "corpora": [ - { - "name": "test", - "documents": [ - { - "source-file": "documents-main.json.bz2", - "document-count": 10, - "compressed-bytes": 100, - "uncompressed-bytes": 10000 - } - ] - } - ], + "indices": [{"name": "test-index", "body": "index.json", "types": ["docs"]}], + "corpora": [{"name": "test", "documents": [{"source-file": "documents-main.json.bz2", "document-count": 10, "compressed-bytes": 100, "uncompressed-bytes": 10000}]}], "operations": [ { "name": "index-append", @@ -2323,90 +1886,49 @@ def test_parse_with_mixed_warmup_iterations_and_measurement(self, mocked_params_ "bulk-size": 5000, } ], - "test_procedures": [ - { - "name": "default-test_procedure", - "schedule": [ - { - "clients": 8, - "operation": "index-append", - "warmup-iterations": 3, - "time-period": 60 - } - ] - } - - ] + "test_procedures": [{"name": "default-test_procedure", "schedule": [{"clients": 8, "operation": "index-append", "warmup-iterations": 3, "time-period": 60}]}], } - reader = loader.WorkloadSpecificationReader(source=io.DictStringFileSourceFactory({ - "/mappings/index.json": ['{"mappings": {"docs": "empty-for-test"}}'], - })) + reader = loader.WorkloadSpecificationReader( + source=io.DictStringFileSourceFactory( + { + "/mappings/index.json": ['{"mappings": {"docs": "empty-for-test"}}'], + } + ) + ) with self.assertRaises(loader.WorkloadSyntaxError) as ctx: reader("unittest", workload_specification, "/mappings") - self.assertEqual("Workload 'unittest' is invalid. Operation 'index-append' in test_procedure 'default-test_procedure' " - "defines '3' warmup iterations and a time period of '60' seconds. Please do not mix time periods and iterations.", - ctx.exception.args[0]) + self.assertEqual( + "Workload 'unittest' is invalid. Operation 'index-append' in test_procedure 'default-test_procedure' " + "defines '3' warmup iterations and a time period of '60' seconds. Please do not mix time periods and iterations.", + ctx.exception.args[0], + ) @mock.patch("solrorbit.workload.loader.register_all_params_in_workload") def test_parse_missing_test_procedure_or_test_procedures(self, mocked_params_checker): workload_specification = { "description": "description for unit test", - "indices": [ - { - "name": "test-index", - "body": "index.json", - "types": ["docs"] - } - ], - "corpora": [ - { - "name": "test", - "documents": [ - { - "source-file": "documents-main.json.bz2", - "document-count": 10, - "compressed-bytes": 100, - "uncompressed-bytes": 10000 - } - ] - } - ], + "indices": [{"name": "test-index", "body": "index.json", "types": ["docs"]}], + "corpora": [{"name": "test", "documents": [{"source-file": "documents-main.json.bz2", "document-count": 10, "compressed-bytes": 100, "uncompressed-bytes": 10000}]}], # no test_procedure or test_procedures element } - reader = loader.WorkloadSpecificationReader(source=io.DictStringFileSourceFactory({ - "/mappings/index.json": ['{"mappings": {"docs": "empty-for-test"}}'], - })) + reader = loader.WorkloadSpecificationReader( + source=io.DictStringFileSourceFactory( + { + "/mappings/index.json": ['{"mappings": {"docs": "empty-for-test"}}'], + } + ) + ) with self.assertRaises(loader.WorkloadSyntaxError) as ctx: reader("unittest", workload_specification, "/mappings") - self.assertEqual("Workload 'unittest' is invalid. You must define 'test_procedure', 'test_procedures' or " - "'schedule' but none is specified.", - ctx.exception.args[0]) + self.assertEqual("Workload 'unittest' is invalid. You must define 'test_procedure', 'test_procedures' or 'schedule' but none is specified.", ctx.exception.args[0]) @mock.patch("solrorbit.workload.loader.register_all_params_in_workload") def test_parse_iteration_and_ramp_up_period(self, mocked_params_checker): workload_specification = { "description": "description for unit test", - "indices": [ - { - "name": "test-index", - "body": "index.json", - "types": ["docs"] - } - ], - "corpora": [ - { - "name": "test", - "documents": [ - { - "source-file": "documents-main.json.bz2", - "document-count": 10, - "compressed-bytes": 100, - "uncompressed-bytes": 10000 - } - ] - } - ], + "indices": [{"name": "test-index", "body": "index.json", "types": ["docs"]}], + "corpora": [{"name": "test", "documents": [{"source-file": "documents-main.json.bz2", "document-count": 10, "compressed-bytes": 100, "uncompressed-bytes": 10000}]}], "operations": [ { "name": "index-append", @@ -2415,61 +1937,47 @@ def test_parse_iteration_and_ramp_up_period(self, mocked_params_checker): } ], "test_procedures": [ + {"name": "default-challenge", "schedule": [{"clients": 8, "operation": "index-append", "ramp-up-time-period": 120, "warmup-iterations": 3, "iterations": 5}]} + ], + } + reader = loader.WorkloadSpecificationReader( + source=io.DictStringFileSourceFactory( { - "name": "default-challenge", - "schedule": [ - { - "clients": 8, - "operation": "index-append", - "ramp-up-time-period": 120, - "warmup-iterations": 3, - "iterations": 5 - } - ] + "/mappings/index.json": ['{"mappings": {"docs": "empty-for-test"}}'], } - - ] - } - reader = loader.WorkloadSpecificationReader(source=io.DictStringFileSourceFactory({ - "/mappings/index.json": ['{"mappings": {"docs": "empty-for-test"}}'], - })) + ) + ) with self.assertRaises(loader.WorkloadSyntaxError) as ctx: reader("unittest", workload_specification, "/mappings") - self.assertEqual("Workload 'unittest' is invalid. Operation 'index-append' in test_procedure 'default-challenge' defines a ramp-up time period of " - "120 seconds as well as 3 warmup iterations and 5 iterations but mixing time periods and iterations is not allowed.", ctx.exception.args[0]) + self.assertEqual( + "Workload 'unittest' is invalid. Operation 'index-append' in test_procedure 'default-challenge' defines a ramp-up time period of " + "120 seconds as well as 3 warmup iterations and 5 iterations but mixing time periods and iterations is not allowed.", + ctx.exception.args[0], + ) @mock.patch("solrorbit.workload.loader.register_all_params_in_workload") def test_parse_valid_ramp_down_period(self, mocked_params_checker): workload_specification = { "description": "description for unit test", "indices": [{"name": "test-index", "body": "index.json", "types": ["docs"]}], - "corpora": [{ - "name": "test", - "documents": [{ - "source-file": "documents-main.json.bz2", - "document-count": 10, - "compressed-bytes": 100, - "uncompressed-bytes": 10000 - }] - }], + "corpora": [{"name": "test", "documents": [{"source-file": "documents-main.json.bz2", "document-count": 10, "compressed-bytes": 100, "uncompressed-bytes": 10000}]}], "operations": [{"name": "index-append", "operation-type": "bulk", "bulk-size": 5000}], - "test_procedures": [{ - "name": "default-challenge", - "schedule": [{ - "clients": 8, - "operation": "index-append", - "warmup-time-period": 60, - "time-period": 300, - "ramp-up-time-period": 60, - "ramp-down-time-period": 60 - }] - }] + "test_procedures": [ + { + "name": "default-challenge", + "schedule": [{"clients": 8, "operation": "index-append", "warmup-time-period": 60, "time-period": 300, "ramp-up-time-period": 60, "ramp-down-time-period": 60}], + } + ], } - reader = loader.WorkloadSpecificationReader(source=io.DictStringFileSourceFactory({ - "/mappings/index.json": ['{"mappings": {"docs": "empty-for-test"}}'], - })) + reader = loader.WorkloadSpecificationReader( + source=io.DictStringFileSourceFactory( + { + "/mappings/index.json": ['{"mappings": {"docs": "empty-for-test"}}'], + } + ) + ) resulting_workload = reader("unittest", workload_specification, "/mappings") task = resulting_workload.test_procedures[0].schedule[0] self.assertEqual(60, task.warmup_time_period) @@ -2482,167 +1990,110 @@ def test_parse_iteration_and_ramp_down_period_error(self, mocked_params_checker) workload_specification = { "description": "description for unit test", "indices": [{"name": "test-index", "body": "index.json", "types": ["docs"]}], - "corpora": [{ - "name": "test", - "documents": [{ - "source-file": "documents-main.json.bz2", - "document-count": 10, - "compressed-bytes": 100, - "uncompressed-bytes": 10000 - }] - }], + "corpora": [{"name": "test", "documents": [{"source-file": "documents-main.json.bz2", "document-count": 10, "compressed-bytes": 100, "uncompressed-bytes": 10000}]}], "operations": [{"name": "index-append", "operation-type": "bulk", "bulk-size": 5000}], - "test_procedures": [{ - "name": "default-challenge", - "schedule": [{ - "clients": 8, - "operation": "index-append", - "time-period": 300, - "ramp-down-time-period": 60, - "warmup-iterations": 3, - "iterations": 5 - }] - }] + "test_procedures": [ + { + "name": "default-challenge", + "schedule": [{"clients": 8, "operation": "index-append", "time-period": 300, "ramp-down-time-period": 60, "warmup-iterations": 3, "iterations": 5}], + } + ], } - reader = loader.WorkloadSpecificationReader(source=io.DictStringFileSourceFactory({ - "/mappings/index.json": ['{"mappings": {"docs": "empty-for-test"}}'], - })) + reader = loader.WorkloadSpecificationReader( + source=io.DictStringFileSourceFactory( + { + "/mappings/index.json": ['{"mappings": {"docs": "empty-for-test"}}'], + } + ) + ) with self.assertRaises(loader.WorkloadSyntaxError) as ctx: reader("unittest", workload_specification, "/mappings") # The warmup-iterations + time-period check triggers before the ramp-down-specific check - self.assertEqual("Workload 'unittest' is invalid. Operation 'index-append' in test_procedure 'default-challenge' defines '3' warmup iterations and " - "a time period of '300' seconds. Please do not mix time periods and iterations.", ctx.exception.args[0]) + self.assertEqual( + "Workload 'unittest' is invalid. Operation 'index-append' in test_procedure 'default-challenge' defines '3' warmup iterations and " + "a time period of '300' seconds. Please do not mix time periods and iterations.", + ctx.exception.args[0], + ) @mock.patch("solrorbit.workload.loader.register_all_params_in_workload") def test_parse_ramp_down_without_time_period_error(self, mocked_params_checker): workload_specification = { "description": "description for unit test", "indices": [{"name": "test-index", "body": "index.json", "types": ["docs"]}], - "corpora": [{ - "name": "test", - "documents": [{ - "source-file": "documents-main.json.bz2", - "document-count": 10, - "compressed-bytes": 100, - "uncompressed-bytes": 10000 - }] - }], + "corpora": [{"name": "test", "documents": [{"source-file": "documents-main.json.bz2", "document-count": 10, "compressed-bytes": 100, "uncompressed-bytes": 10000}]}], "operations": [{"name": "index-append", "operation-type": "bulk", "bulk-size": 5000}], - "test_procedures": [{ - "name": "default-challenge", - "schedule": [{ - "clients": 8, - "operation": "index-append", - "warmup-time-period": 30, - "ramp-down-time-period": 60 - }] - }] + "test_procedures": [{"name": "default-challenge", "schedule": [{"clients": 8, "operation": "index-append", "warmup-time-period": 30, "ramp-down-time-period": 60}]}], } - reader = loader.WorkloadSpecificationReader(source=io.DictStringFileSourceFactory({ - "/mappings/index.json": ['{"mappings": {"docs": "empty-for-test"}}'], - })) + reader = loader.WorkloadSpecificationReader( + source=io.DictStringFileSourceFactory( + { + "/mappings/index.json": ['{"mappings": {"docs": "empty-for-test"}}'], + } + ) + ) with self.assertRaises(loader.WorkloadSyntaxError) as ctx: reader("unittest", workload_specification, "/mappings") - self.assertEqual("Workload 'unittest' is invalid. Operation 'index-append' in test_procedure 'default-challenge' defines a ramp-down time period of " - "60 seconds but no time-period.", ctx.exception.args[0]) + self.assertEqual( + "Workload 'unittest' is invalid. Operation 'index-append' in test_procedure 'default-challenge' defines a ramp-down time period of 60 seconds but no time-period.", + ctx.exception.args[0], + ) @mock.patch("solrorbit.workload.loader.register_all_params_in_workload") def test_parse_ramp_down_exceeds_time_period_error(self, mocked_params_checker): workload_specification = { "description": "description for unit test", "indices": [{"name": "test-index", "body": "index.json", "types": ["docs"]}], - "corpora": [{ - "name": "test", - "documents": [{ - "source-file": "documents-main.json.bz2", - "document-count": 10, - "compressed-bytes": 100, - "uncompressed-bytes": 10000 - }] - }], + "corpora": [{"name": "test", "documents": [{"source-file": "documents-main.json.bz2", "document-count": 10, "compressed-bytes": 100, "uncompressed-bytes": 10000}]}], "operations": [{"name": "index-append", "operation-type": "bulk", "bulk-size": 5000}], - "test_procedures": [{ - "name": "default-challenge", - "schedule": [{ - "clients": 8, - "operation": "index-append", - "warmup-time-period": 30, - "time-period": 60, - "ramp-down-time-period": 120 - }] - }] + "test_procedures": [ + {"name": "default-challenge", "schedule": [{"clients": 8, "operation": "index-append", "warmup-time-period": 30, "time-period": 60, "ramp-down-time-period": 120}]} + ], } - reader = loader.WorkloadSpecificationReader(source=io.DictStringFileSourceFactory({ - "/mappings/index.json": ['{"mappings": {"docs": "empty-for-test"}}'], - })) + reader = loader.WorkloadSpecificationReader( + source=io.DictStringFileSourceFactory( + { + "/mappings/index.json": ['{"mappings": {"docs": "empty-for-test"}}'], + } + ) + ) with self.assertRaises(loader.WorkloadSyntaxError) as ctx: reader("unittest", workload_specification, "/mappings") - self.assertEqual("Workload 'unittest' is invalid. The time-period of operation 'index-append' in test_procedure 'default-challenge' is " - "60 seconds but must be greater than or equal to the ramp-down-time-period of 120 seconds.", ctx.exception.args[0]) - + self.assertEqual( + "Workload 'unittest' is invalid. The time-period of operation 'index-append' in test_procedure 'default-challenge' is " + "60 seconds but must be greater than or equal to the ramp-down-time-period of 120 seconds.", + ctx.exception.args[0], + ) @mock.patch("solrorbit.workload.loader.register_all_params_in_workload") def test_parse_test_procedure_and_test_procedures_are_defined(self, mocked_params_checker): workload_specification = { "description": "description for unit test", - "indices": [ - { - "name": "test-index", - "body": "index.json", - "types": ["docs"] - } - ], - "corpora": [ - { - "name": "test", - "documents": [ - { - "source-file": "documents-main.json.bz2", - "document-count": 10, - "compressed-bytes": 100, - "uncompressed-bytes": 10000 - } - ] - } - ], + "indices": [{"name": "test-index", "body": "index.json", "types": ["docs"]}], + "corpora": [{"name": "test", "documents": [{"source-file": "documents-main.json.bz2", "document-count": 10, "compressed-bytes": 100, "uncompressed-bytes": 10000}]}], # We define both. Note that test_procedures without any properties # would not pass JSON schema validation but we don't test this here. "test_procedure": {}, - "test_procedures": [] + "test_procedures": [], } - reader = loader.WorkloadSpecificationReader(source=io.DictStringFileSourceFactory({ - "/mappings/index.json": ['{"mappings": {"docs": "empty-for-test"}}'], - })) + reader = loader.WorkloadSpecificationReader( + source=io.DictStringFileSourceFactory( + { + "/mappings/index.json": ['{"mappings": {"docs": "empty-for-test"}}'], + } + ) + ) with self.assertRaises(loader.WorkloadSyntaxError) as ctx: reader("unittest", workload_specification, "/mappings") - self.assertEqual("Workload 'unittest' is invalid. Multiple out of 'test_procedure', 'test_procedures' or 'schedule' " - "are defined but only " - "one of them is allowed.", ctx.exception.args[0]) + self.assertEqual( + "Workload 'unittest' is invalid. Multiple out of 'test_procedure', 'test_procedures' or 'schedule' are defined but only one of them is allowed.", ctx.exception.args[0] + ) @mock.patch("solrorbit.workload.loader.register_all_params_in_workload") def test_parse_with_mixed_warmup_time_period_and_iterations(self, mocked_params_checker): workload_specification = { "description": "description for unit test", - "indices": [ - { - "name": "test-index", - "body": "index.json", - "types": ["docs"] - } - ], - "corpora": [ - { - "name": "test", - "documents": [ - { - "source-file": "documents-main.json.bz2", - "document-count": 10, - "compressed-bytes": 100, - "uncompressed-bytes": 10000 - } - ] - } - ], + "indices": [{"name": "test-index", "body": "index.json", "types": ["docs"]}], + "corpora": [{"name": "test", "documents": [{"source-file": "documents-main.json.bz2", "document-count": 10, "compressed-bytes": 100, "uncompressed-bytes": 10000}]}], "operations": [ { "name": "index-append", @@ -2650,128 +2101,71 @@ def test_parse_with_mixed_warmup_time_period_and_iterations(self, mocked_params_ "bulk-size": 5000, } ], - "test_procedures": [ - { - "name": "default-test_procedure", - "schedule": [ - { - "clients": 8, - "operation": "index-append", - "warmup-time-period": 20, - "iterations": 1000 - } - ] - } - - ] + "test_procedures": [{"name": "default-test_procedure", "schedule": [{"clients": 8, "operation": "index-append", "warmup-time-period": 20, "iterations": 1000}]}], } - reader = loader.WorkloadSpecificationReader(source=io.DictStringFileSourceFactory({ - "/mappings/index.json": ['{"mappings": {"docs": "empty-for-test"}}'], - })) + reader = loader.WorkloadSpecificationReader( + source=io.DictStringFileSourceFactory( + { + "/mappings/index.json": ['{"mappings": {"docs": "empty-for-test"}}'], + } + ) + ) with self.assertRaises(loader.WorkloadSyntaxError) as ctx: reader("unittest", workload_specification, "/mappings") - self.assertEqual("Workload 'unittest' is invalid. Operation 'index-append' in test_procedure 'default-test_procedure' " - "defines a warmup time " - "period of '20' seconds and '1000' iterations. " - "Please do not mix time periods and iterations.", - ctx.exception.args[0]) + self.assertEqual( + "Workload 'unittest' is invalid. Operation 'index-append' in test_procedure 'default-test_procedure' " + "defines a warmup time " + "period of '20' seconds and '1000' iterations. " + "Please do not mix time periods and iterations.", + ctx.exception.args[0], + ) def test_parse_duplicate_implicit_task_names(self): workload_specification = { "description": "description for unit test", - "operations": [ - { - "name": "search", - "operation-type": "search", - "index": "_all" - } - ], - "test_procedure": { - "name": "default-test_procedure", - "schedule": [ - { - "operation": "search", - "clients": 1 - }, - { - "operation": "search", - "clients": 2 - } - ] - } + "operations": [{"name": "search", "operation-type": "search", "index": "_all"}], + "test_procedure": {"name": "default-test_procedure", "schedule": [{"operation": "search", "clients": 1}, {"operation": "search", "clients": 2}]}, } reader = loader.WorkloadSpecificationReader() with self.assertRaises(loader.WorkloadSyntaxError) as ctx: reader("unittest", workload_specification, "/mappings") - self.assertEqual("Workload 'unittest' is invalid. TestProcedure 'default-test_procedure' contains multiple tasks" - " with the name 'search'. Please" - " use the task's name property to assign a unique name for each task.", - ctx.exception.args[0]) + self.assertEqual( + "Workload 'unittest' is invalid. TestProcedure 'default-test_procedure' contains multiple tasks" + " with the name 'search'. Please" + " use the task's name property to assign a unique name for each task.", + ctx.exception.args[0], + ) def test_parse_duplicate_explicit_task_names(self): workload_specification = { "description": "description for unit test", - "operations": [ - { - "name": "search", - "operation-type": "search", - "index": "_all" - } - ], + "operations": [{"name": "search", "operation-type": "search", "index": "_all"}], "test_procedure": { "name": "default-test_procedure", - "schedule": [ - { - "name": "duplicate-task-name", - "operation": "search", - "clients": 1 - }, - { - "name": "duplicate-task-name", - "operation": "search", - "clients": 2 - } - ] - } + "schedule": [{"name": "duplicate-task-name", "operation": "search", "clients": 1}, {"name": "duplicate-task-name", "operation": "search", "clients": 2}], + }, } reader = loader.WorkloadSpecificationReader() with self.assertRaises(loader.WorkloadSyntaxError) as ctx: reader("unittest", workload_specification, "/mappings") - self.assertEqual("Workload 'unittest' is invalid. TestProcedure 'default-test_procedure' contains multiple tasks with the name " - "'duplicate-task-name'. Please use the task's name property to assign a unique name for each task.", - ctx.exception.args[0]) + self.assertEqual( + "Workload 'unittest' is invalid. TestProcedure 'default-test_procedure' contains multiple tasks with the name " + "'duplicate-task-name'. Please use the task's name property to assign a unique name for each task.", + ctx.exception.args[0], + ) def test_parse_unique_task_names(self): workload_specification = { "description": "description for unit test", - "operations": [ - { - "name": "search", - "operation-type": "search", - "index": "_all" - } - ], + "operations": [{"name": "search", "operation-type": "search", "index": "_all"}], "test_procedure": { "name": "default-test_procedure", - "schedule": [ - { - "name": "search-one-client", - "operation": "search", - "clients": 1 - }, - { - "name": "search-two-clients", - "operation": "search", - "clients": 2 - } - ] - } + "schedule": [{"name": "search-one-client", "operation": "search", "clients": 1}, {"name": "search-two-clients", "operation": "search", "clients": 2}], + }, } - reader = loader.WorkloadSpecificationReader( - selected_test_procedure="default-test_procedure") - resulting_workload = reader( - "unittest", workload_specification, "/mappings") + reader = loader.WorkloadSpecificationReader(selected_test_procedure="default-test_procedure") + resulting_workload = reader("unittest", workload_specification, "/mappings") self.assertEqual("unittest", resulting_workload.name) test_procedure = resulting_workload.test_procedures[0] self.assertTrue(test_procedure.selected) @@ -2785,35 +2179,18 @@ def test_parse_unique_task_names(self): def test_parse_clients_list(self): workload_specification = { "description": "description for unit test", - "operations": [ - { - "name": "search", - "operation-type": "search", - "index": "_all" - } - ], + "operations": [{"name": "search", "operation-type": "search", "index": "_all"}], "test_procedure": { "name": "default-test-procedure", "schedule": [ - { - "name": "search-one-client", - "operation": "search", - "clients": 1, - "clients_list": [1, 2, 3] - }, - { - "name": "search-two-clients", - "operation": "search", - "clients": 2 - } - ] - } + {"name": "search-one-client", "operation": "search", "clients": 1, "clients_list": [1, 2, 3]}, + {"name": "search-two-clients", "operation": "search", "clients": 2}, + ], + }, } - reader = loader.WorkloadSpecificationReader( - selected_test_procedure="default-test-procedure") - resulting_workload = reader( - "unittest", workload_specification, "/mappings") + reader = loader.WorkloadSpecificationReader(selected_test_procedure="default-test-procedure") + resulting_workload = reader("unittest", workload_specification, "/mappings") self.assertEqual("unittest", resulting_workload.name) test_procedure = resulting_workload.test_procedures[0] self.assertTrue(test_procedure.selected) @@ -2829,11 +2206,11 @@ def test_parse_clients_list(self): self.assertEqual("search-two-clients", schedule[3].name) self.assertEqual("search", schedule[3].operation.name) + # pylint: disable=W0212 def test_naming_with_clients_list(self): - reader = loader.WorkloadSpecificationReader( - selected_test_procedure="default-test_procedure") + reader = loader.WorkloadSpecificationReader(selected_test_procedure="default-test_procedure") # Test case 1: name contains both "_" and "-" result = reader._rename_task_based_on_num_clients("test_name-task", 5) self.assertEqual(result, "test_name-task_5_clients") @@ -2854,157 +2231,71 @@ def test_unique_test_procedure_names(self): workload_specification = { "description": "description for unit test", "indices": [{"name": "test-index"}], - "operations": [ - { - "name": "index-append", - "operation-type": "bulk" - } - ], + "operations": [{"name": "index-append", "operation-type": "bulk"}], "test_procedures": [ - { - "name": "test-test_procedure", - "description": "Some test_procedure", - "default": True, - "schedule": [ - { - "operation": "index-append" - } - ] - }, - { - "name": "test-test_procedure", - "description": "Another test_procedure with the same name", - "schedule": [ - { - "operation": "index-append" - } - ] - } - - ] + {"name": "test-test_procedure", "description": "Some test_procedure", "default": True, "schedule": [{"operation": "index-append"}]}, + {"name": "test-test_procedure", "description": "Another test_procedure with the same name", "schedule": [{"operation": "index-append"}]}, + ], } reader = loader.WorkloadSpecificationReader() with self.assertRaises(loader.WorkloadSyntaxError) as ctx: reader("unittest", workload_specification, "/mappings") - self.assertEqual( - "Workload 'unittest' is invalid. Duplicate test_procedure with name 'test-test_procedure'.", ctx.exception.args[0]) + self.assertEqual("Workload 'unittest' is invalid. Duplicate test_procedure with name 'test-test_procedure'.", ctx.exception.args[0]) def test_not_more_than_one_default_test_procedure_possible(self): workload_specification = { "description": "description for unit test", "indices": [{"name": "test-index"}], - "operations": [ - { - "name": "index-append", - "operation-type": "bulk" - } - ], + "operations": [{"name": "index-append", "operation-type": "bulk"}], "test_procedures": [ - { - "name": "default-test_procedure", - "description": "Default test_procedure", - "default": True, - "schedule": [ - { - "operation": "index-append" - } - ] - }, - { - "name": "another-test_procedure", - "description": "See if we can sneek it in as another default", - "default": True, - "schedule": [ - { - "operation": "index-append" - } - ] - } - - ] + {"name": "default-test_procedure", "description": "Default test_procedure", "default": True, "schedule": [{"operation": "index-append"}]}, + {"name": "another-test_procedure", "description": "See if we can sneek it in as another default", "default": True, "schedule": [{"operation": "index-append"}]}, + ], } reader = loader.WorkloadSpecificationReader() with self.assertRaises(loader.WorkloadSyntaxError) as ctx: reader("unittest", workload_specification, "/mappings") - self.assertEqual("Workload 'unittest' is invalid. Both 'default-test_procedure' and 'another-test_procedure' " - "are defined as default test_procedures. " - "Please define only one of them as default.", ctx.exception.args[0]) + self.assertEqual( + "Workload 'unittest' is invalid. Both 'default-test_procedure' and 'another-test_procedure' " + "are defined as default test_procedures. " + "Please define only one of them as default.", + ctx.exception.args[0], + ) def test_at_least_one_default_test_procedure(self): workload_specification = { "description": "description for unit test", "indices": [{"name": "test-index"}], - "operations": [ - { - "name": "index-append", - "operation-type": "bulk" - } - ], + "operations": [{"name": "index-append", "operation-type": "bulk"}], "test_procedures": [ - { - "name": "test_procedure", - "schedule": [ - { - "operation": "index-append" - } - ] - }, - { - "name": "another-test_procedure", - "schedule": [ - { - "operation": "index-append" - } - ] - } - - ] + {"name": "test_procedure", "schedule": [{"operation": "index-append"}]}, + {"name": "another-test_procedure", "schedule": [{"operation": "index-append"}]}, + ], } reader = loader.WorkloadSpecificationReader() with self.assertRaises(loader.WorkloadSyntaxError) as ctx: reader("unittest", workload_specification, "/mappings") - self.assertEqual("Workload 'unittest' is invalid. No default test_procedure specified. " - "Please edit the workload and add \"default\": true " - "to one of the test_procedures test_procedure, another-test_procedure.", ctx.exception.args[0]) + self.assertEqual( + "Workload 'unittest' is invalid. No default test_procedure specified. " + 'Please edit the workload and add "default": true ' + "to one of the test_procedures test_procedure, another-test_procedure.", + ctx.exception.args[0], + ) def test_exactly_one_default_test_procedure(self): workload_specification = { "description": "description for unit test", "indices": [{"name": "test-index"}], - "operations": [ - { - "name": "index-append", - "operation-type": "bulk" - } - ], + "operations": [{"name": "index-append", "operation-type": "bulk"}], "test_procedures": [ - { - "name": "test_procedure", - "default": True, - "schedule": [ - { - "operation": "index-append" - } - ] - }, - { - "name": "another-test_procedure", - "schedule": [ - { - "operation": "index-append" - } - ] - } - - ] + {"name": "test_procedure", "default": True, "schedule": [{"operation": "index-append"}]}, + {"name": "another-test_procedure", "schedule": [{"operation": "index-append"}]}, + ], } - reader = loader.WorkloadSpecificationReader( - selected_test_procedure="another-test_procedure") - resulting_workload = reader( - "unittest", workload_specification, "/mappings") + reader = loader.WorkloadSpecificationReader(selected_test_procedure="another-test_procedure") + resulting_workload = reader("unittest", workload_specification, "/mappings") self.assertEqual(2, len(resulting_workload.test_procedures)) - self.assertEqual("test_procedure", - resulting_workload.test_procedures[0].name) + self.assertEqual("test_procedure", resulting_workload.test_procedures[0].name) self.assertTrue(resulting_workload.test_procedures[0].default) self.assertFalse(resulting_workload.test_procedures[1].default) self.assertTrue(resulting_workload.test_procedures[1].selected) @@ -3013,27 +2304,13 @@ def test_selects_sole_test_procedure_implicitly_as_default(self): workload_specification = { "description": "description for unit test", "indices": [{"name": "test-index"}], - "operations": [ - { - "name": "index-append", - "operation-type": "bulk" - } - ], - "test_procedure": { - "name": "test_procedure", - "schedule": [ - { - "operation": "index-append" - } - ] - } + "operations": [{"name": "index-append", "operation-type": "bulk"}], + "test_procedure": {"name": "test_procedure", "schedule": [{"operation": "index-append"}]}, } reader = loader.WorkloadSpecificationReader() - resulting_workload = reader( - "unittest", workload_specification, "/mappings") + resulting_workload = reader("unittest", workload_specification, "/mappings") self.assertEqual(1, len(resulting_workload.test_procedures)) - self.assertEqual("test_procedure", - resulting_workload.test_procedures[0].name) + self.assertEqual("test_procedure", resulting_workload.test_procedures[0].name) self.assertTrue(resulting_workload.test_procedures[0].default) self.assertTrue(resulting_workload.test_procedures[0].selected) @@ -3041,21 +2318,11 @@ def test_auto_generates_test_procedure_from_schedule(self): workload_specification = { "description": "description for unit test", "indices": [{"name": "test-index"}], - "operations": [ - { - "name": "index-append", - "operation-type": "bulk" - } - ], - "schedule": [ - { - "operation": "index-append" - } - ] + "operations": [{"name": "index-append", "operation-type": "bulk"}], + "schedule": [{"operation": "index-append"}], } reader = loader.WorkloadSpecificationReader() - resulting_workload = reader( - "unittest", workload_specification, "/mappings") + resulting_workload = reader("unittest", workload_specification, "/mappings") self.assertEqual(1, len(resulting_workload.test_procedures)) self.assertTrue(resulting_workload.test_procedures[0].auto_generated) self.assertTrue(resulting_workload.test_procedures[0].default) @@ -3069,55 +2336,32 @@ def test_inline_operations(self): "name": "test_procedure", "schedule": [ # an operation with parameters still needs to define a type - { - "operation": { - "operation-type": "bulk", - "bulk-size": 5000 - } - }, + {"operation": {"operation-type": "bulk", "bulk-size": 5000}}, # a parameterless operation can just use the operation type as implicit reference to the operation - { - "operation": "sleep" - } - ] - } + {"operation": "sleep"}, + ], + }, } reader = loader.WorkloadSpecificationReader() - resulting_workload = reader( - "unittest", workload_specification, "/mappings") + resulting_workload = reader("unittest", workload_specification, "/mappings") test_procedure = resulting_workload.test_procedures[0] self.assertEqual(2, len(test_procedure.schedule)) - self.assertEqual(workload.OperationType.Bulk.to_hyphenated_string( - ), test_procedure.schedule[0].operation.type) - self.assertEqual(workload.OperationType.Sleep.to_hyphenated_string( - ), test_procedure.schedule[1].operation.type) + self.assertEqual(workload.OperationType.Bulk.to_hyphenated_string(), test_procedure.schedule[0].operation.type) + self.assertEqual(workload.OperationType.Sleep.to_hyphenated_string(), test_procedure.schedule[1].operation.type) def test_supports_target_throughput(self): workload_specification = { "description": "description for unit test", "indices": [{"name": "test-index"}], - "operations": [ - { - "name": "index-append", - "operation-type": "bulk" - } - ], + "operations": [{"name": "index-append", "operation-type": "bulk"}], "test_procedure": { "name": "default-test_procedure", - "schedule": [ - { - "operation": "index-append", - "target-throughput": 10, - "warmup-time-period": 120, - "ramp-up-time-period": 60 - } - ] - } + "schedule": [{"operation": "index-append", "target-throughput": 10, "warmup-time-period": 120, "ramp-up-time-period": 60}], + }, } reader = loader.WorkloadSpecificationReader() - resulting_workload = reader( - "unittest", workload_specification, "/mappings") + resulting_workload = reader("unittest", workload_specification, "/mappings") indexing_task = resulting_workload.test_procedures[0].schedule[0] self.assertEqual(10, indexing_task.params["target-throughput"]) self.assertEqual(120, indexing_task.warmup_time_period) @@ -3127,12 +2371,7 @@ def test_supports_target_interval(self): workload_specification = { "description": "description for unit test", "indices": [{"name": "test-index"}], - "operations": [ - { - "name": "index-append", - "operation-type": "bulk" - } - ], + "operations": [{"name": "index-append", "operation-type": "bulk"}], "test_procedures": [ { "name": "default-test_procedure", @@ -3141,33 +2380,22 @@ def test_supports_target_interval(self): "operation": "index-append", "target-interval": 5, } - ] + ], } - ] + ], } reader = loader.WorkloadSpecificationReader() - resulting_workload = reader( - "unittest", workload_specification, "/mappings") - self.assertEqual( - 5, resulting_workload.test_procedures[0].schedule[0].params["target-interval"]) + resulting_workload = reader("unittest", workload_specification, "/mappings") + self.assertEqual(5, resulting_workload.test_procedures[0].schedule[0].params["target-interval"]) def test_parallel_tasks_with_default_values(self): workload_specification = { "description": "description for unit test", "indices": [{"name": "test-index"}], "operations": [ - { - "name": "index-1", - "operation-type": "bulk" - }, - { - "name": "index-2", - "operation-type": "bulk" - }, - { - "name": "index-3", - "operation-type": "bulk" - }, + {"name": "index-1", "operation-type": "bulk"}, + {"name": "index-2", "operation-type": "bulk"}, + {"name": "index-3", "operation-type": "bulk"}, ], "test_procedures": [ { @@ -3178,31 +2406,18 @@ def test_parallel_tasks_with_default_values(self): "warmup-time-period": 2400, "time-period": 36000, "tasks": [ - { - "operation": "index-1", - "warmup-time-period": 300, - "clients": 2 - }, - { - "operation": "index-2", - "time-period": 3600, - "clients": 4 - }, - { - "operation": "index-3", - "target-throughput": 10, - "clients": 16 - }, - ] + {"operation": "index-1", "warmup-time-period": 300, "clients": 2}, + {"operation": "index-2", "time-period": 3600, "clients": 4}, + {"operation": "index-3", "target-throughput": 10, "clients": 16}, + ], } } - ] + ], } - ] + ], } reader = loader.WorkloadSpecificationReader() - resulting_workload = reader( - "unittest", workload_specification, "/mappings") + resulting_workload = reader("unittest", workload_specification, "/mappings") parallel_element = resulting_workload.test_procedures[0].schedule[0] parallel_tasks = parallel_element.tasks @@ -3231,12 +2446,7 @@ def test_parallel_tasks_with_default_clients_does_not_propagate(self): workload_specification = { "description": "description for unit test", "indices": [{"name": "test-index"}], - "operations": [ - { - "name": "index-1", - "operation-type": "bulk" - } - ], + "operations": [{"name": "index-1", "operation-type": "bulk"}], "test_procedures": [ { "name": "default-test_procedure", @@ -3247,32 +2457,19 @@ def test_parallel_tasks_with_default_clients_does_not_propagate(self): "time-period": 36000, "clients": 2, "tasks": [ - { - "name": "index-1-1", - "operation": "index-1" - }, - { - "name": "index-1-2", - "operation": "index-1" - }, - { - "name": "index-1-3", - "operation": "index-1" - }, - { - "name": "index-1-4", - "operation": "index-1" - } - ] + {"name": "index-1-1", "operation": "index-1"}, + {"name": "index-1-2", "operation": "index-1"}, + {"name": "index-1-3", "operation": "index-1"}, + {"name": "index-1-4", "operation": "index-1"}, + ], } } - ] + ], } - ] + ], } reader = loader.WorkloadSpecificationReader() - resulting_workload = reader( - "unittest", workload_specification, "/mappings") + resulting_workload = reader("unittest", workload_specification, "/mappings") parallel_element = resulting_workload.test_procedures[0].schedule[0] parallel_tasks = parallel_element.tasks @@ -3286,42 +2483,18 @@ def test_parallel_tasks_with_completed_by_set(self): workload_specification = { "description": "description for unit test", "indices": [{"name": "test-index"}], - "operations": [ - { - "name": "index-1", - "operation-type": "bulk" - }, - { - "name": "index-2", - "operation-type": "bulk" - } - ], + "operations": [{"name": "index-1", "operation-type": "bulk"}, {"name": "index-2", "operation-type": "bulk"}], "test_procedures": [ { "name": "default-test_procedure", "schedule": [ - { - "parallel": { - "warmup-time-period": 2400, - "time-period": 36000, - "completed-by": "index-2", - "tasks": [ - { - "operation": "index-1" - }, - { - "operation": "index-2" - } - ] - } - } - ] + {"parallel": {"warmup-time-period": 2400, "time-period": 36000, "completed-by": "index-2", "tasks": [{"operation": "index-1"}, {"operation": "index-2"}]}} + ], } - ] + ], } reader = loader.WorkloadSpecificationReader() - resulting_workload = reader( - "unittest", workload_specification, "/mappings") + resulting_workload = reader("unittest", workload_specification, "/mappings") parallel_element = resulting_workload.test_procedures[0].schedule[0] parallel_tasks = parallel_element.tasks @@ -3339,16 +2512,7 @@ def test_parallel_tasks_with_named_task_completed_by_set(self): workload_specification = { "description": "description for unit test", "indices": [{"name": "test-index"}], - "operations": [ - { - "name": "index-1", - "operation-type": "bulk" - }, - { - "name": "index-2", - "operation-type": "bulk" - } - ], + "operations": [{"name": "index-1", "operation-type": "bulk"}, {"name": "index-2", "operation-type": "bulk"}], "test_procedures": [ { "name": "default-test_procedure", @@ -3358,25 +2522,15 @@ def test_parallel_tasks_with_named_task_completed_by_set(self): "warmup-time-period": 2400, "time-period": 36000, "completed-by": "name-index-2", - "tasks": [ - { - "name": "name-index-1", - "operation": "index-1" - }, - { - "name": "name-index-2", - "operation": "index-2" - } - ] + "tasks": [{"name": "name-index-1", "operation": "index-1"}, {"name": "name-index-2", "operation": "index-2"}], } } - ] + ], } - ] + ], } reader = loader.WorkloadSpecificationReader() - resulting_workload = reader( - "unittest", workload_specification, "/mappings") + resulting_workload = reader("unittest", workload_specification, "/mappings") parallel_element = resulting_workload.test_procedures[0].schedule[0] parallel_tasks = parallel_element.tasks @@ -3394,144 +2548,63 @@ def test_parallel_tasks_with_completed_by_set_no_task_matches(self): workload_specification = { "description": "description for unit test", "indices": [{"name": "test-index"}], - "operations": [ - { - "name": "index-1", - "operation-type": "bulk" - }, - { - "name": "index-2", - "operation-type": "bulk" - } - ], + "operations": [{"name": "index-1", "operation-type": "bulk"}, {"name": "index-2", "operation-type": "bulk"}], "test_procedures": [ - { - "name": "default-test_procedure", - "schedule": [ - { - "parallel": { - "completed-by": "non-existing-task", - "tasks": [ - { - "operation": "index-1" - }, - { - "operation": "index-2" - } - ] - } - } - ] - } - ] + {"name": "default-test_procedure", "schedule": [{"parallel": {"completed-by": "non-existing-task", "tasks": [{"operation": "index-1"}, {"operation": "index-2"}]}}]} + ], } reader = loader.WorkloadSpecificationReader() with self.assertRaises(loader.WorkloadSyntaxError) as ctx: reader("unittest", workload_specification, "/mappings") - self.assertEqual("Workload 'unittest' is invalid. 'parallel' element for " - "test_procedure 'default-test_procedure' is marked with 'completed-by' " - "with task name 'non-existing-task' but no task with this name exists.", ctx.exception.args[0]) + self.assertEqual( + "Workload 'unittest' is invalid. 'parallel' element for " + "test_procedure 'default-test_procedure' is marked with 'completed-by' " + "with task name 'non-existing-task' but no task with this name exists.", + ctx.exception.args[0], + ) def test_parallel_tasks_with_completed_by_set_multiple_tasks_match(self): workload_specification = { "description": "description for unit test", "indices": [{"name": "test-index"}], - "operations": [ - { - "name": "index-1", - "operation-type": "bulk" - } - ], + "operations": [{"name": "index-1", "operation-type": "bulk"}], "test_procedures": [ - { - "name": "default-test_procedure", - "schedule": [ - { - "parallel": { - "completed-by": "index-1", - "tasks": [ - { - "operation": "index-1" - }, - { - "operation": "index-1" - } - ] - } - } - ] - } - ] + {"name": "default-test_procedure", "schedule": [{"parallel": {"completed-by": "index-1", "tasks": [{"operation": "index-1"}, {"operation": "index-1"}]}}]} + ], } reader = loader.WorkloadSpecificationReader() with self.assertRaises(loader.WorkloadSyntaxError) as ctx: reader("unittest", workload_specification, "/mappings") - self.assertEqual("Workload 'unittest' is invalid. 'parallel' element for test_procedure " - "'default-test_procedure' contains multiple tasks with " - "the name 'index-1' which are marked with 'completed-by' but only task is allowed to match.", - ctx.exception.args[0]) + self.assertEqual( + "Workload 'unittest' is invalid. 'parallel' element for test_procedure " + "'default-test_procedure' contains multiple tasks with " + "the name 'index-1' which are marked with 'completed-by' but only task is allowed to match.", + ctx.exception.args[0], + ) def test_propagate_parameters_to_test_procedure_level(self): workload_specification = { "description": "description for unit test", - "parameters": { - "level": "workload", - "value": 7 - }, + "parameters": {"level": "workload", "value": 7}, "indices": [{"name": "test-index"}], - "operations": [ - { - "name": "index-append", - "operation-type": "bulk" - } - ], + "operations": [{"name": "index-append", "operation-type": "bulk"}], "test_procedures": [ - { - "name": "test_procedure", - "default": True, - "parameters": { - "level": "test_procedure", - "another-value": 17 - }, - "schedule": [ - { - "operation": "index-append" - } - ] - }, - { - "name": "another-test_procedure", - "schedule": [ - { - "operation": "index-append" - } - ] - } - - ] + {"name": "test_procedure", "default": True, "parameters": {"level": "test_procedure", "another-value": 17}, "schedule": [{"operation": "index-append"}]}, + {"name": "another-test_procedure", "schedule": [{"operation": "index-append"}]}, + ], } - reader = loader.WorkloadSpecificationReader( - selected_test_procedure="another-test_procedure") - resulting_workload = reader( - "unittest", workload_specification, "/mappings") + reader = loader.WorkloadSpecificationReader(selected_test_procedure="another-test_procedure") + resulting_workload = reader("unittest", workload_specification, "/mappings") self.assertEqual(2, len(resulting_workload.test_procedures)) - self.assertEqual("test_procedure", - resulting_workload.test_procedures[0].name) + self.assertEqual("test_procedure", resulting_workload.test_procedures[0].name) self.assertTrue(resulting_workload.test_procedures[0].default) - self.assertDictEqual({ - "level": "test_procedure", - "value": 7, - "another-value": 17 - }, resulting_workload.test_procedures[0].parameters) + self.assertDictEqual({"level": "test_procedure", "value": 7, "another-value": 17}, resulting_workload.test_procedures[0].parameters) self.assertFalse(resulting_workload.test_procedures[1].default) self.assertTrue(resulting_workload.test_procedures[1].selected) - self.assertDictEqual({ - "level": "workload", - "value": 7 - }, resulting_workload.test_procedures[1].parameters) + self.assertDictEqual({"level": "workload", "value": 7}, resulting_workload.test_procedures[1].parameters) class MyMockWorkloadProcessor(loader.WorkloadProcessor): @@ -3543,12 +2616,7 @@ def test_default_workload_processors(self): cfg = config.Config() cfg.add(config.Scope.application, "system", "offline.mode", False) tpr = loader.WorkloadProcessorRegistry(cfg) - expected_defaults = [ - loader.TaskFilterWorkloadProcessor, - loader.TestModeWorkloadProcessor, - loader.QueryRandomizerWorkloadProcessor, - loader.DefaultWorkloadPreparator - ] + expected_defaults = [loader.TaskFilterWorkloadProcessor, loader.TestModeWorkloadProcessor, loader.QueryRandomizerWorkloadProcessor, loader.DefaultWorkloadPreparator] actual_defaults = [proc.__class__ for proc in tpr.processors] self.assertCountEqual(expected_defaults, actual_defaults) @@ -3559,12 +2627,7 @@ def test_override_default_preparator(self): # call this once beforehand to make sure we don't "harden" the default in case calls are made out of order tpr.processors # pylint: disable=pointless-statement tpr.register_workload_processor(MyMockWorkloadProcessor()) - expected_processors = [ - loader.TaskFilterWorkloadProcessor, - loader.TestModeWorkloadProcessor, - loader.QueryRandomizerWorkloadProcessor, - MyMockWorkloadProcessor - ] + expected_processors = [loader.TaskFilterWorkloadProcessor, loader.TestModeWorkloadProcessor, loader.QueryRandomizerWorkloadProcessor, MyMockWorkloadProcessor] actual_processors = [proc.__class__ for proc in tpr.processors] self.assertCountEqual(expected_processors, actual_processors) diff --git a/tests/workload/params_test.py b/tests/workload/params_test.py index ce4da776..eb216a48 100644 --- a/tests/workload/params_test.py +++ b/tests/workload/params_test.py @@ -69,7 +69,7 @@ def test_slice_with_source_larger_than_slice(self): '{"key": "value7"}', '{"key": "value8"}', '{"key": "value9"}', - '{"key": "value10"}' + '{"key": "value10"}', ] source.open(data, "r", 5) @@ -106,75 +106,37 @@ def test_no_id_conflicts(self): def test_sequential_conflicts(self): self.assertEqual( - [ - '0000000000', - '0000000001', - '0000000002', - '0000000003', - '0000000004', - '0000000005', - '0000000006', - '0000000007', - '0000000008', - '0000000009', - '0000000010' - ], - params.build_conflicting_ids(params.IndexIdConflict.SequentialConflicts, 11, 0) + ["0000000000", "0000000001", "0000000002", "0000000003", "0000000004", "0000000005", "0000000006", "0000000007", "0000000008", "0000000009", "0000000010"], + params.build_conflicting_ids(params.IndexIdConflict.SequentialConflicts, 11, 0), ) self.assertEqual( - [ - '0000000005', - '0000000006', - '0000000007', - '0000000008', - '0000000009', - '0000000010', - '0000000011', - '0000000012', - '0000000013', - '0000000014', - '0000000015' - ], - params.build_conflicting_ids(params.IndexIdConflict.SequentialConflicts, 11, 5) + ["0000000005", "0000000006", "0000000007", "0000000008", "0000000009", "0000000010", "0000000011", "0000000012", "0000000013", "0000000014", "0000000015"], + params.build_conflicting_ids(params.IndexIdConflict.SequentialConflicts, 11, 5), ) def test_random_conflicts(self): predictable_shuffle = list.reverse - self.assertEqual( - [ - '0000000002', '0000000001', '0000000000' - ], - params.build_conflicting_ids(params.IndexIdConflict.RandomConflicts, 3, 0, shuffle=predictable_shuffle) - ) + self.assertEqual(["0000000002", "0000000001", "0000000000"], params.build_conflicting_ids(params.IndexIdConflict.RandomConflicts, 3, 0, shuffle=predictable_shuffle)) - self.assertEqual( - [ - '0000000007', '0000000006', '0000000005' - ], - params.build_conflicting_ids(params.IndexIdConflict.RandomConflicts, 3, 5, shuffle=predictable_shuffle) - ) + self.assertEqual(["0000000007", "0000000006", "0000000005"], params.build_conflicting_ids(params.IndexIdConflict.RandomConflicts, 3, 5, shuffle=predictable_shuffle)) class ActionMetaDataTests(TestCase): def test_generate_action_meta_data_without_id_conflicts(self): - self.assertEqual(("index", '{"index": {"_index": "test_index", "_type": "test_type"}}\n'), - next(params.GenerateActionMetaData("test_index", "test_type"))) + self.assertEqual(("index", '{"index": {"_index": "test_index", "_type": "test_type"}}\n'), next(params.GenerateActionMetaData("test_index", "test_type"))) def test_generate_action_meta_data_create(self): - self.assertEqual(("create", '{"create": {"_index": "test_index"}}\n'), - next(params.GenerateActionMetaData("test_index", None, use_create=True))) + self.assertEqual(("create", '{"create": {"_index": "test_index"}}\n'), next(params.GenerateActionMetaData("test_index", None, use_create=True))) def test_generate_action_meta_data_create_with_conflicts(self): with self.assertRaises(exceptions.BenchmarkError) as ctx: params.GenerateActionMetaData("test_index", None, conflicting_ids=[100, 200, 300, 400], use_create=True) - self.assertEqual("Index mode '_create' cannot be used with conflicting ids", - ctx.exception.args[0]) + self.assertEqual("Index mode '_create' cannot be used with conflicting ids", ctx.exception.args[0]) def test_generate_action_meta_data_typeless(self): - self.assertEqual(("index", '{"index": {"_index": "test_index"}}\n'), - next(params.GenerateActionMetaData("test_index", type_name=None))) + self.assertEqual(("index", '{"index": {"_index": "test_index"}}\n'), next(params.GenerateActionMetaData("test_index", type_name=None))) def test_generate_action_meta_data_with_id_conflicts(self): def idx(id): @@ -183,32 +145,40 @@ def idx(id): def conflict(action, id): return action, '{"%s": {"_index": "test_index", "_type": "test_type", "_id": "%s"}}\n' % (action, id) - pseudo_random_conflicts = iter([ - # if this value is <= our chosen threshold of 0.25 (see conflict_probability) we produce a conflict. - 0.2, - 0.25, - 0.2, - # no conflict - 0.3, - # conflict again - 0.0 - ]) - - chosen_index_of_conflicting_ids = iter([ - # the "random" index of the id in the array `conflicting_ids` that will produce a conflict - 1, - 3, - 2, - 0]) + pseudo_random_conflicts = iter( + [ + # if this value is <= our chosen threshold of 0.25 (see conflict_probability) we produce a conflict. + 0.2, + 0.25, + 0.2, + # no conflict + 0.3, + # conflict again + 0.0, + ] + ) + + chosen_index_of_conflicting_ids = iter( + [ + # the "random" index of the id in the array `conflicting_ids` that will produce a conflict + 1, + 3, + 2, + 0, + ] + ) conflict_action = random.choice(["index", "update"]) - generator = params.GenerateActionMetaData("test_index", "test_type", - conflicting_ids=[100, 200, 300, 400], - conflict_probability=25, - on_conflict=conflict_action, - rand=lambda: next(pseudo_random_conflicts), - randint=lambda x, y: next(chosen_index_of_conflicting_ids)) + generator = params.GenerateActionMetaData( + "test_index", + "test_type", + conflicting_ids=[100, 200, 300, 400], + conflict_probability=25, + on_conflict=conflict_action, + rand=lambda: next(pseudo_random_conflicts), + randint=lambda x, y: next(chosen_index_of_conflicting_ids), + ) # first one is always *not* drawn from a random index self.assertEqual(idx("100"), next(generator)) @@ -234,52 +204,58 @@ def conflict(action, type_name, id): else: return action, '{"%s": {"_index": "test_index", "_id": "%s"}}\n' % (action, id) - pseudo_random_conflicts = iter([ - # if this value is <= our chosen threshold of 0.25 (see conflict_probability) we produce a conflict. - 0.2, - 0.25, - 0.2, - # no conflict - 0.3, - 0.4, - 0.35, - # conflict again - 0.0, - 0.2, - 0.15 - ]) + pseudo_random_conflicts = iter( + [ + # if this value is <= our chosen threshold of 0.25 (see conflict_probability) we produce a conflict. + 0.2, + 0.25, + 0.2, + # no conflict + 0.3, + 0.4, + 0.35, + # conflict again + 0.0, + 0.2, + 0.15, + ] + ) # we use this value as `idx_range` in the calculation: idx = round((self.id_up_to - 1) * (1 - idx_range)) - pseudo_exponential_distribution = iter([ - # id_up_to = 1 -> idx = 0 - 0.013375248172714948, - # id_up_to = 1 -> idx = 0 - 0.042495604491024914, - # id_up_to = 1 -> idx = 0 - 0.005491072642023834, - # no conflict: id_up_to = 2 - # no conflict: id_up_to = 3 - # no conflict: id_up_to = 4 - # id_up_to = 4 -> idx = round((4 - 1) * (1 - 0.028557879547255083)) = 3 - 0.028557879547255083, - # id_up_to = 4 -> idx = round((4 - 1) * (1 - 0.209771474243926352)) = 2 - 0.209771474243926352 - ]) + pseudo_exponential_distribution = iter( + [ + # id_up_to = 1 -> idx = 0 + 0.013375248172714948, + # id_up_to = 1 -> idx = 0 + 0.042495604491024914, + # id_up_to = 1 -> idx = 0 + 0.005491072642023834, + # no conflict: id_up_to = 2 + # no conflict: id_up_to = 3 + # no conflict: id_up_to = 4 + # id_up_to = 4 -> idx = round((4 - 1) * (1 - 0.028557879547255083)) = 3 + 0.028557879547255083, + # id_up_to = 4 -> idx = round((4 - 1) * (1 - 0.209771474243926352)) = 2 + 0.209771474243926352, + ] + ) conflict_action = random.choice(["index", "update"]) type_name = random.choice([None, "test_type"]) - generator = params.GenerateActionMetaData("test_index", type_name=type_name, - conflicting_ids=[100, 200, 300, 400, 500, 600], - conflict_probability=25, - # heavily biased towards recent ids - recency=1.0, - on_conflict=conflict_action, - rand=lambda: next(pseudo_random_conflicts), - # we don't use this one here because recency is > 0. - # randint=lambda x, y: next(chosen_index_of_conflicting_ids), - randexp=lambda lmbda: next(pseudo_exponential_distribution) - ) + generator = params.GenerateActionMetaData( + "test_index", + type_name=type_name, + conflicting_ids=[100, 200, 300, 400, 500, 600], + conflict_probability=25, + # heavily biased towards recent ids + recency=1.0, + on_conflict=conflict_action, + rand=lambda: next(pseudo_random_conflicts), + # we don't use this one here because recency is > 0. + # randint=lambda x, y: next(chosen_index_of_conflicting_ids), + randexp=lambda lmbda: next(pseudo_exponential_distribution), + ) # first one is always *not* drawn from a random index self.assertEqual(idx(type_name, "100"), next(generator)) @@ -301,34 +277,22 @@ def idx(id): test_ids = [100, 200, 300, 400] - generator = params.GenerateActionMetaData("test_index", "test_type", - conflicting_ids=test_ids, - conflict_probability=0) + generator = params.GenerateActionMetaData("test_index", "test_type", conflicting_ids=test_ids, conflict_probability=0) self.assertListEqual([idx(id) for id in test_ids], list(generator)) class IndexDataReaderTests(TestCase): def test_read_bulk_larger_than_number_of_docs(self): - data = [ - b'{"key": "value1"}\n', - b'{"key": "value2"}\n', - b'{"key": "value3"}\n', - b'{"key": "value4"}\n', - b'{"key": "value5"}\n' - ] + data = [b'{"key": "value1"}\n', b'{"key": "value2"}\n', b'{"key": "value3"}\n', b'{"key": "value4"}\n', b'{"key": "value5"}\n'] bulk_size = 50 source = params.Slice(io.StringAsFileSource, 0, len(data), self.corpus("a", [self.docs(80)]), None) am_handler = params.GenerateActionMetaData("test_index", "test_type") - reader = params.MetadataIndexDataReader(data, - batch_size=bulk_size, - bulk_size=bulk_size, - file_source=source, - action_metadata=am_handler, - index_name="test_index", - type_name="test_type") + reader = params.MetadataIndexDataReader( + data, batch_size=bulk_size, bulk_size=bulk_size, file_source=source, action_metadata=am_handler, index_name="test_index", type_name="test_type" + ) expected_bulk_sizes = [len(data)] # lines should include meta-data @@ -336,25 +300,15 @@ def test_read_bulk_larger_than_number_of_docs(self): self.assert_bulks_sized(reader, expected_bulk_sizes, expected_line_sizes) def test_read_bulk_with_offset(self): - data = [ - b'{"key": "value1"}\n', - b'{"key": "value2"}\n', - b'{"key": "value3"}\n', - b'{"key": "value4"}\n', - b'{"key": "value5"}\n' - ] + data = [b'{"key": "value1"}\n', b'{"key": "value2"}\n', b'{"key": "value3"}\n', b'{"key": "value4"}\n', b'{"key": "value5"}\n'] bulk_size = 50 source = params.Slice(io.StringAsFileSource, 3, len(data), self.corpus("a", [self.docs(80)]), None) am_handler = params.GenerateActionMetaData("test_index", "test_type") - reader = params.MetadataIndexDataReader(data, - batch_size=bulk_size, - bulk_size=bulk_size, - file_source=source, - action_metadata=am_handler, - index_name="test_index", - type_name="test_type") + reader = params.MetadataIndexDataReader( + data, batch_size=bulk_size, bulk_size=bulk_size, file_source=source, action_metadata=am_handler, index_name="test_index", type_name="test_type" + ) expected_bulk_sizes = [(len(data) - 3)] # lines should include meta-data @@ -376,13 +330,9 @@ def test_read_bulk_smaller_than_number_of_docs(self): source = params.Slice(io.StringAsFileSource, 0, len(data), self.corpus("a", [self.docs(80)]), None) am_handler = params.GenerateActionMetaData("test_index", "test_type") - reader = params.MetadataIndexDataReader(data, - batch_size=bulk_size, - bulk_size=bulk_size, - file_source=source, - action_metadata=am_handler, - index_name="test_index", - type_name="test_type") + reader = params.MetadataIndexDataReader( + data, batch_size=bulk_size, bulk_size=bulk_size, file_source=source, action_metadata=am_handler, index_name="test_index", type_name="test_type" + ) expected_bulk_sizes = [3, 3, 1] # lines should include meta-data @@ -405,13 +355,9 @@ def test_read_bulk_smaller_than_number_of_docs_and_multiple_clients(self): source = params.Slice(io.StringAsFileSource, 0, 5, self.corpus("a", [self.docs(80)]), None) am_handler = params.GenerateActionMetaData("test_index", "test_type") - reader = params.MetadataIndexDataReader(data, - batch_size=bulk_size, - bulk_size=bulk_size, - file_source=source, - action_metadata=am_handler, - index_name="test_index", - type_name="test_type") + reader = params.MetadataIndexDataReader( + data, batch_size=bulk_size, bulk_size=bulk_size, file_source=source, action_metadata=am_handler, index_name="test_index", type_name="test_type" + ) expected_bulk_sizes = [3, 2] # lines should include meta-data @@ -433,18 +379,13 @@ def test_read_bulks_and_assume_metadata_line_in_source_file(self): b'{"index": {"_index": "test_index", "_type": "test_type"}\n', b'{"key": "value6"}\n', b'{"index": {"_index": "test_index", "_type": "test_type"}\n', - b'{"key": "value7"}\n' + b'{"key": "value7"}\n', ] bulk_size = 3 source = params.Slice(io.StringAsFileSource, 0, len(data), self.corpus("a", [self.docs(80)]), None) - reader = params.SourceOnlyIndexDataReader(data, - batch_size=bulk_size, - bulk_size=bulk_size, - file_source=source, - index_name="test_index", - type_name="test_type") + reader = params.SourceOnlyIndexDataReader(data, batch_size=bulk_size, bulk_size=bulk_size, file_source=source, index_name="test_index", type_name="test_type") expected_bulk_sizes = [3, 3, 1] # lines should include meta-data @@ -452,45 +393,43 @@ def test_read_bulks_and_assume_metadata_line_in_source_file(self): self.assert_bulks_sized(reader, expected_bulk_sizes, expected_line_sizes) def test_read_bulk_with_id_conflicts(self): - pseudo_random_conflicts = iter([ - # if this value is <= our chosen threshold of 0.25 (see conflict_probability) we produce a conflict. - 0.2, - 0.25, - 0.2, - # no conflict - 0.3 - ]) - - chosen_index_of_conflicting_ids = iter([ - # the "random" index of the id in the array `conflicting_ids` that will produce a conflict - 1, - 3, - 2]) + pseudo_random_conflicts = iter( + [ + # if this value is <= our chosen threshold of 0.25 (see conflict_probability) we produce a conflict. + 0.2, + 0.25, + 0.2, + # no conflict + 0.3, + ] + ) - data = [ - b'{"key": "value1"}\n', - b'{"key": "value2"}\n', - b'{"key": "value3"}\n', - b'{"key": "value4"}\n', - b'{"key": "value5"}\n' - ] + chosen_index_of_conflicting_ids = iter( + [ + # the "random" index of the id in the array `conflicting_ids` that will produce a conflict + 1, + 3, + 2, + ] + ) + + data = [b'{"key": "value1"}\n', b'{"key": "value2"}\n', b'{"key": "value3"}\n', b'{"key": "value4"}\n', b'{"key": "value5"}\n'] bulk_size = 2 source = params.Slice(io.StringAsFileSource, 0, len(data), self.corpus("a", [self.docs(80)]), None) - am_handler = params.GenerateActionMetaData("test_index", "test_type", - conflicting_ids=[100, 200, 300, 400], - conflict_probability=25, - on_conflict="update", - rand=lambda: next(pseudo_random_conflicts), - randint=lambda x, y: next(chosen_index_of_conflicting_ids)) - - reader = params.MetadataIndexDataReader(data, - batch_size=bulk_size, - bulk_size=bulk_size, - file_source=source, - action_metadata=am_handler, - index_name="test_index", - type_name="test_type") + am_handler = params.GenerateActionMetaData( + "test_index", + "test_type", + conflicting_ids=[100, 200, 300, 400], + conflict_probability=25, + on_conflict="update", + rand=lambda: next(pseudo_random_conflicts), + randint=lambda x, y: next(chosen_index_of_conflicting_ids), + ) + + reader = params.MetadataIndexDataReader( + data, batch_size=bulk_size, bulk_size=bulk_size, file_source=source, action_metadata=am_handler, index_name="test_index", type_name="test_type" + ) # consume all bulks bulks = [] @@ -499,40 +438,31 @@ def test_read_bulk_with_id_conflicts(self): for bulk_size, bulk in batch: bulks.append(bulk) - self.assertEqual([ - b'{"index": {"_index": "test_index", "_type": "test_type", "_id": "100"}}\n' + - b'{"key": "value1"}\n' + - b'{"update": {"_index": "test_index", "_type": "test_type", "_id": "200"}}\n' + - b'{"doc":{"key": "value2"}}\n', - b'{"update": {"_index": "test_index", "_type": "test_type", "_id": "400"}}\n' + - b'{"doc":{"key": "value3"}}\n' + - b'{"update": {"_index": "test_index", "_type": "test_type", "_id": "300"}}\n' + - b'{"doc":{"key": "value4"}}\n', - b'{"index": {"_index": "test_index", "_type": "test_type", "_id": "200"}}\n' + - b'{"key": "value5"}\n' - ], bulks) + self.assertEqual( + [ + b'{"index": {"_index": "test_index", "_type": "test_type", "_id": "100"}}\n' + + b'{"key": "value1"}\n' + + b'{"update": {"_index": "test_index", "_type": "test_type", "_id": "200"}}\n' + + b'{"doc":{"key": "value2"}}\n', + b'{"update": {"_index": "test_index", "_type": "test_type", "_id": "400"}}\n' + + b'{"doc":{"key": "value3"}}\n' + + b'{"update": {"_index": "test_index", "_type": "test_type", "_id": "300"}}\n' + + b'{"doc":{"key": "value4"}}\n', + b'{"index": {"_index": "test_index", "_type": "test_type", "_id": "200"}}\n' + b'{"key": "value5"}\n', + ], + bulks, + ) def test_read_bulk_with_external_id_and_zero_conflict_probability(self): - data = [ - b'{"key": "value1"}\n', - b'{"key": "value2"}\n', - b'{"key": "value3"}\n', - b'{"key": "value4"}\n' - ] + data = [b'{"key": "value1"}\n', b'{"key": "value2"}\n', b'{"key": "value3"}\n', b'{"key": "value4"}\n'] bulk_size = 2 source = params.Slice(io.StringAsFileSource, 0, len(data), self.corpus("a", [self.docs(80)]), None) - am_handler = params.GenerateActionMetaData("test_index", "test_type", - conflicting_ids=[100, 200, 300, 400], - conflict_probability=0) - - reader = params.MetadataIndexDataReader(data, - batch_size=bulk_size, - bulk_size=bulk_size, - file_source=source, - action_metadata=am_handler, - index_name="test_index", - type_name="test_type") + am_handler = params.GenerateActionMetaData("test_index", "test_type", conflicting_ids=[100, 200, 300, 400], conflict_probability=0) + + reader = params.MetadataIndexDataReader( + data, batch_size=bulk_size, bulk_size=bulk_size, file_source=source, action_metadata=am_handler, index_name="test_index", type_name="test_type" + ) # consume all bulks bulks = [] @@ -541,17 +471,19 @@ def test_read_bulk_with_external_id_and_zero_conflict_probability(self): for bulk_size, bulk in batch: bulks.append(bulk) - self.assertEqual([ - b'{"index": {"_index": "test_index", "_type": "test_type", "_id": "100"}}\n' + - b'{"key": "value1"}\n' + - b'{"index": {"_index": "test_index", "_type": "test_type", "_id": "200"}}\n' + - b'{"key": "value2"}\n', - - b'{"index": {"_index": "test_index", "_type": "test_type", "_id": "300"}}\n' + - b'{"key": "value3"}\n' + - b'{"index": {"_index": "test_index", "_type": "test_type", "_id": "400"}}\n' + - b'{"key": "value4"}\n' - ], bulks) + self.assertEqual( + [ + b'{"index": {"_index": "test_index", "_type": "test_type", "_id": "100"}}\n' + + b'{"key": "value1"}\n' + + b'{"index": {"_index": "test_index", "_type": "test_type", "_id": "200"}}\n' + + b'{"key": "value2"}\n', + b'{"index": {"_index": "test_index", "_type": "test_type", "_id": "300"}}\n' + + b'{"key": "value3"}\n' + + b'{"index": {"_index": "test_index", "_type": "test_type", "_id": "400"}}\n' + + b'{"key": "value4"}\n', + ], + bulks, + ) def assert_bulks_sized(self, reader, expected_bulk_sizes, expected_line_sizes): self.assertEqual(len(expected_bulk_sizes), len(expected_line_sizes), "Bulk sizes and line sizes must be equal") @@ -693,18 +625,18 @@ def test_calculate_number_of_bulks(self): self.assertEqual(1, self.number_of_bulks([self.corpus("a", [docs1])], 0, 0, 1, 1)) self.assertEqual(1, self.number_of_bulks([self.corpus("a", [docs1])], 0, 0, 1, 2)) - self.assertEqual(20, self.number_of_bulks( - [self.corpus("a", [docs2, docs2, docs2, docs2, docs1]), - self.corpus("b", [docs2, docs2, docs2, docs2, docs2, docs1])], 0, 0, 1, 1)) - self.assertEqual(11, self.number_of_bulks( - [self.corpus("a", [docs2, docs2, docs2, docs2, docs1]), - self.corpus("b", [docs2, docs2, docs2, docs2, docs2, docs1])], 0, 0, 1, 2)) - self.assertEqual(11, self.number_of_bulks( - [self.corpus("a", [docs2, docs2, docs2, docs2, docs1]), - self.corpus("b", [docs2, docs2, docs2, docs2, docs2, docs1])], 0, 0, 1, 3)) - self.assertEqual(11, self.number_of_bulks( - [self.corpus("a", [docs2, docs2, docs2, docs2, docs1]), - self.corpus("b", [docs2, docs2, docs2, docs2, docs2, docs1])], 0, 0, 1, 100)) + self.assertEqual( + 20, self.number_of_bulks([self.corpus("a", [docs2, docs2, docs2, docs2, docs1]), self.corpus("b", [docs2, docs2, docs2, docs2, docs2, docs1])], 0, 0, 1, 1) + ) + self.assertEqual( + 11, self.number_of_bulks([self.corpus("a", [docs2, docs2, docs2, docs2, docs1]), self.corpus("b", [docs2, docs2, docs2, docs2, docs2, docs1])], 0, 0, 1, 2) + ) + self.assertEqual( + 11, self.number_of_bulks([self.corpus("a", [docs2, docs2, docs2, docs2, docs1]), self.corpus("b", [docs2, docs2, docs2, docs2, docs2, docs1])], 0, 0, 1, 3) + ) + self.assertEqual( + 11, self.number_of_bulks([self.corpus("a", [docs2, docs2, docs2, docs2, docs1]), self.corpus("b", [docs2, docs2, docs2, docs2, docs2, docs1])], 0, 0, 1, 100) + ) self.assertEqual(2, self.number_of_bulks([self.corpus("a", [self.docs(800)])], 0, 0, 3, 250)) self.assertEqual(1, self.number_of_bulks([self.corpus("a", [self.docs(800)])], 0, 0, 3, 267)) @@ -727,8 +659,7 @@ def number_of_bulks(corpora, first_partition_index, last_partition_index, total_ def test_build_conflicting_ids(self): self.assertIsNone(params.build_conflicting_ids(params.IndexIdConflict.NoConflicts, 3, 0)) - self.assertEqual(["0000000000", "0000000001", "0000000002"], - params.build_conflicting_ids(params.IndexIdConflict.SequentialConflicts, 3, 0)) + self.assertEqual(["0000000000", "0000000001", "0000000002"], params.build_conflicting_ids(params.IndexIdConflict.SequentialConflicts, 3, 0)) # we cannot tell anything specific about the contents... self.assertEqual(3, len(params.build_conflicting_ids(params.IndexIdConflict.RandomConflicts, 3, 0))) @@ -736,12 +667,10 @@ def test_build_conflicting_ids(self): # pylint: disable=too-many-public-methods class BulkIndexParamSourceTests(TestCase): def test_create_without_params(self): - corpus = workload.DocumentCorpus(name="default", documents=[ - workload.Documents(source_format=workload.Documents.SOURCE_FORMAT_BULK, - number_of_documents=10, - target_collection="test-idx", - target_type="test-type" - )]) + corpus = workload.DocumentCorpus( + name="default", + documents=[workload.Documents(source_format=workload.Documents.SOURCE_FORMAT_BULK, number_of_documents=10, target_collection="test-idx", target_type="test-type")], + ) with self.assertRaises(exceptions.InvalidSyntax) as ctx: params.BulkIndexParamSource(workload=workload.Workload(name="unit-test", corpora=[corpus]), params={}) @@ -752,283 +681,209 @@ def test_create_without_corpora_definition(self): with self.assertRaises(exceptions.InvalidSyntax) as ctx: params.BulkIndexParamSource(workload=workload.Workload(name="unit-test"), params={}) - self.assertEqual("There is no document corpus definition for workload unit-test. " - "You must add at least one before making bulk requests to the target cluster.", ctx.exception.args[0]) + self.assertEqual( + "There is no document corpus definition for workload unit-test. You must add at least one before making bulk requests to the target cluster.", ctx.exception.args[0] + ) def test_create_with_non_numeric_bulk_size(self): - corpus = workload.DocumentCorpus(name="default", documents=[ - workload.Documents(source_format=workload.Documents.SOURCE_FORMAT_BULK, - number_of_documents=10, - target_collection="test-idx", - target_type="test-type" - )]) + corpus = workload.DocumentCorpus( + name="default", + documents=[workload.Documents(source_format=workload.Documents.SOURCE_FORMAT_BULK, number_of_documents=10, target_collection="test-idx", target_type="test-type")], + ) with self.assertRaises(exceptions.InvalidSyntax) as ctx: - params.BulkIndexParamSource(workload=workload.Workload(name="unit-test", corpora=[corpus]), params={ - "bulk-size": "Three" - }) + params.BulkIndexParamSource(workload=workload.Workload(name="unit-test", corpora=[corpus]), params={"bulk-size": "Three"}) self.assertEqual("'bulk-size' must be numeric", ctx.exception.args[0]) def test_create_with_negative_bulk_size(self): - corpus = workload.DocumentCorpus(name="default", documents=[ - workload.Documents(source_format=workload.Documents.SOURCE_FORMAT_BULK, - number_of_documents=10, - target_collection="test-idx", - target_type="test-type" - )]) + corpus = workload.DocumentCorpus( + name="default", + documents=[workload.Documents(source_format=workload.Documents.SOURCE_FORMAT_BULK, number_of_documents=10, target_collection="test-idx", target_type="test-type")], + ) with self.assertRaises(exceptions.InvalidSyntax) as ctx: - params.BulkIndexParamSource(workload=workload.Workload(name="unit-test", corpora=[corpus]), params={ - "bulk-size": -5 - }) + params.BulkIndexParamSource(workload=workload.Workload(name="unit-test", corpora=[corpus]), params={"bulk-size": -5}) self.assertEqual("'bulk-size' must be positive but was -5", ctx.exception.args[0]) def test_create_with_fraction_smaller_batch_size(self): - corpus = workload.DocumentCorpus(name="default", documents=[ - workload.Documents(source_format=workload.Documents.SOURCE_FORMAT_BULK, - number_of_documents=10, - target_collection="test-idx", - target_type="test-type" - )]) + corpus = workload.DocumentCorpus( + name="default", + documents=[workload.Documents(source_format=workload.Documents.SOURCE_FORMAT_BULK, number_of_documents=10, target_collection="test-idx", target_type="test-type")], + ) with self.assertRaises(exceptions.InvalidSyntax) as ctx: - params.BulkIndexParamSource(workload=workload.Workload(name="unit-test", corpora=[corpus]), params={ - "bulk-size": 5, - "batch-size": 3 - }) + params.BulkIndexParamSource(workload=workload.Workload(name="unit-test", corpora=[corpus]), params={"bulk-size": 5, "batch-size": 3}) self.assertEqual("'batch-size' must be greater than or equal to 'bulk-size'", ctx.exception.args[0]) def test_create_with_fraction_larger_batch_size(self): - corpus = workload.DocumentCorpus(name="default", documents=[ - workload.Documents(source_format=workload.Documents.SOURCE_FORMAT_BULK, - number_of_documents=10, - target_collection="test-idx", - target_type="test-type" - )]) + corpus = workload.DocumentCorpus( + name="default", + documents=[workload.Documents(source_format=workload.Documents.SOURCE_FORMAT_BULK, number_of_documents=10, target_collection="test-idx", target_type="test-type")], + ) with self.assertRaises(exceptions.InvalidSyntax) as ctx: - params.BulkIndexParamSource(workload=workload.Workload(name="unit-test", corpora=[corpus]), params={ - "bulk-size": 5, - "batch-size": 8 - }) + params.BulkIndexParamSource(workload=workload.Workload(name="unit-test", corpora=[corpus]), params={"bulk-size": 5, "batch-size": 8}) self.assertEqual("'batch-size' must be a multiple of 'bulk-size'", ctx.exception.args[0]) def test_create_with_metadata_in_source_file_but_conflicts(self): - corpus = workload.DocumentCorpus(name="default", documents=[ - workload.Documents(source_format=workload.Documents.SOURCE_FORMAT_BULK, - document_archive="docs.json.bz2", - document_file="docs.json", - number_of_documents=10, - includes_action_and_meta_data=True) - ]) + corpus = workload.DocumentCorpus( + name="default", + documents=[ + workload.Documents( + source_format=workload.Documents.SOURCE_FORMAT_BULK, + document_archive="docs.json.bz2", + document_file="docs.json", + number_of_documents=10, + includes_action_and_meta_data=True, + ) + ], + ) with self.assertRaises(exceptions.InvalidSyntax) as ctx: - params.BulkIndexParamSource(workload=workload.Workload(name="unit-test", corpora=[corpus]), params={ - "conflicts": "random" - }) + params.BulkIndexParamSource(workload=workload.Workload(name="unit-test", corpora=[corpus]), params={"conflicts": "random"}) - self.assertEqual("Cannot generate id conflicts [random] as [docs.json.bz2] in document corpus [default] already contains " - "an action and meta-data line.", ctx.exception.args[0]) + self.assertEqual( + "Cannot generate id conflicts [random] as [docs.json.bz2] in document corpus [default] already contains an action and meta-data line.", ctx.exception.args[0] + ) def test_create_with_unknown_id_conflicts(self): with self.assertRaises(exceptions.InvalidSyntax) as ctx: - params.BulkIndexParamSource(workload=workload.Workload(name="unit-test"), params={ - "conflicts": "crazy" - }) + params.BulkIndexParamSource(workload=workload.Workload(name="unit-test"), params={"conflicts": "crazy"}) self.assertEqual("Unknown 'conflicts' setting [crazy]", ctx.exception.args[0]) def test_create_with_unknown_on_conflict_setting(self): with self.assertRaises(exceptions.InvalidSyntax) as ctx: - params.BulkIndexParamSource(workload=workload.Workload(name="unit-test"), params={ - "conflicts": "sequential", - "on-conflict": "delete" - }) + params.BulkIndexParamSource(workload=workload.Workload(name="unit-test"), params={"conflicts": "sequential", "on-conflict": "delete"}) self.assertEqual("Unknown 'on-conflict' setting [delete]", ctx.exception.args[0]) def test_create_with_conflicts_and_data_streams(self): with self.assertRaises(exceptions.InvalidSyntax) as ctx: - params.BulkIndexParamSource(workload=workload.Workload(name="unit-test"), params={ - "data-streams": ["test-data-stream-1", "test-data-stream-2"], - "conflicts": "sequential" - }) + params.BulkIndexParamSource( + workload=workload.Workload(name="unit-test"), params={"data-streams": ["test-data-stream-1", "test-data-stream-2"], "conflicts": "sequential"} + ) self.assertEqual("'conflicts' cannot be used with 'data-streams'", ctx.exception.args[0]) def test_create_with_ingest_percentage_too_low(self): - corpus = workload.DocumentCorpus(name="default", documents=[ - workload.Documents(source_format=workload.Documents.SOURCE_FORMAT_BULK, - number_of_documents=10, - target_collection="test-idx", - target_type="test-type" - )]) + corpus = workload.DocumentCorpus( + name="default", + documents=[workload.Documents(source_format=workload.Documents.SOURCE_FORMAT_BULK, number_of_documents=10, target_collection="test-idx", target_type="test-type")], + ) with self.assertRaises(exceptions.InvalidSyntax) as ctx: - params.BulkIndexParamSource(workload=workload.Workload(name="unit-test", corpora=[corpus]), params={ - "bulk-size": 5000, - "ingest-percentage": 0.0 - }) + params.BulkIndexParamSource(workload=workload.Workload(name="unit-test", corpora=[corpus]), params={"bulk-size": 5000, "ingest-percentage": 0.0}) self.assertEqual("'ingest-percentage' must be in the range (0.0, 100.0] but was 0.0", ctx.exception.args[0]) def test_create_with_ingest_percentage_too_high(self): - corpus = workload.DocumentCorpus(name="default", documents=[ - workload.Documents(source_format=workload.Documents.SOURCE_FORMAT_BULK, - number_of_documents=10, - target_collection="test-idx", - target_type="test-type" - )]) + corpus = workload.DocumentCorpus( + name="default", + documents=[workload.Documents(source_format=workload.Documents.SOURCE_FORMAT_BULK, number_of_documents=10, target_collection="test-idx", target_type="test-type")], + ) with self.assertRaises(exceptions.InvalidSyntax) as ctx: - params.BulkIndexParamSource(workload=workload.Workload(name="unit-test", corpora=[corpus]), params={ - "bulk-size": 5000, - "ingest-percentage": 100.1 - }) + params.BulkIndexParamSource(workload=workload.Workload(name="unit-test", corpora=[corpus]), params={"bulk-size": 5000, "ingest-percentage": 100.1}) self.assertEqual("'ingest-percentage' must be in the range (0.0, 100.0] but was 100.1", ctx.exception.args[0]) def test_create_with_ingest_percentage_not_numeric(self): - corpus = workload.DocumentCorpus(name="default", documents=[ - workload.Documents(source_format=workload.Documents.SOURCE_FORMAT_BULK, - number_of_documents=10, - target_collection="test-idx", - target_type="test-type" - )]) + corpus = workload.DocumentCorpus( + name="default", + documents=[workload.Documents(source_format=workload.Documents.SOURCE_FORMAT_BULK, number_of_documents=10, target_collection="test-idx", target_type="test-type")], + ) with self.assertRaises(exceptions.InvalidSyntax) as ctx: - params.BulkIndexParamSource(workload=workload.Workload(name="unit-test", corpora=[corpus]), params={ - "bulk-size": 5000, - "ingest-percentage": "100 percent" - }) + params.BulkIndexParamSource(workload=workload.Workload(name="unit-test", corpora=[corpus]), params={"bulk-size": 5000, "ingest-percentage": "100 percent"}) self.assertEqual("'ingest-percentage' must be numeric", ctx.exception.args[0]) def test_create_valid_param_source(self): - corpus = workload.DocumentCorpus(name="default", documents=[ - workload.Documents(source_format=workload.Documents.SOURCE_FORMAT_BULK, - number_of_documents=10, - target_collection="test-idx", - target_type="test-type" - )]) - - self.assertIsNotNone(params.BulkIndexParamSource(workload.Workload(name="unit-test", corpora=[corpus]), params={ - "conflicts": "random", - "bulk-size": 5000, - "batch-size": 20000, - "ingest-percentage": 20.5, - "pipeline": "test-pipeline" - })) + corpus = workload.DocumentCorpus( + name="default", + documents=[workload.Documents(source_format=workload.Documents.SOURCE_FORMAT_BULK, number_of_documents=10, target_collection="test-idx", target_type="test-type")], + ) + + self.assertIsNotNone( + params.BulkIndexParamSource( + workload.Workload(name="unit-test", corpora=[corpus]), + params={"conflicts": "random", "bulk-size": 5000, "batch-size": 20000, "ingest-percentage": 20.5, "pipeline": "test-pipeline"}, + ) + ) def test_passes_all_corpora_by_default(self): corpora = [ - workload.DocumentCorpus(name="default", documents=[ - workload.Documents(source_format=workload.Documents.SOURCE_FORMAT_BULK, - number_of_documents=10, - target_collection="test-idx", - target_type="test-type" - ) - ]), - workload.DocumentCorpus(name="special", documents=[ - workload.Documents(source_format=workload.Documents.SOURCE_FORMAT_BULK, - number_of_documents=100, - target_collection="test-idx2", - target_type="type" - ) - ]), + workload.DocumentCorpus( + name="default", + documents=[workload.Documents(source_format=workload.Documents.SOURCE_FORMAT_BULK, number_of_documents=10, target_collection="test-idx", target_type="test-type")], + ), + workload.DocumentCorpus( + name="special", + documents=[workload.Documents(source_format=workload.Documents.SOURCE_FORMAT_BULK, number_of_documents=100, target_collection="test-idx2", target_type="type")], + ), ] source = params.BulkIndexParamSource( - workload=workload.Workload(name="unit-test", corpora=corpora), - params={ - "conflicts": "random", - "bulk-size": 5000, - "batch-size": 20000, - "pipeline": "test-pipeline" - }) + workload=workload.Workload(name="unit-test", corpora=corpora), params={"conflicts": "random", "bulk-size": 5000, "batch-size": 20000, "pipeline": "test-pipeline"} + ) partition = source.partition(0, 1) self.assertEqual(partition.corpora, corpora) def test_filters_corpora(self): corpora = [ - workload.DocumentCorpus(name="default", documents=[ - workload.Documents(source_format=workload.Documents.SOURCE_FORMAT_BULK, - number_of_documents=10, - target_collection="test-idx", - target_type="test-type" - ) - ]), - workload.DocumentCorpus(name="special", documents=[ - workload.Documents(source_format=workload.Documents.SOURCE_FORMAT_BULK, - number_of_documents=100, - target_collection="test-idx2", - target_type="type" - ) - ]), + workload.DocumentCorpus( + name="default", + documents=[workload.Documents(source_format=workload.Documents.SOURCE_FORMAT_BULK, number_of_documents=10, target_collection="test-idx", target_type="test-type")], + ), + workload.DocumentCorpus( + name="special", + documents=[workload.Documents(source_format=workload.Documents.SOURCE_FORMAT_BULK, number_of_documents=100, target_collection="test-idx2", target_type="type")], + ), ] source = params.BulkIndexParamSource( workload=workload.Workload(name="unit-test", corpora=corpora), - params={ - "corpora": ["special"], - "conflicts": "random", - "bulk-size": 5000, - "batch-size": 20000, - "pipeline": "test-pipeline" - }) + params={"corpora": ["special"], "conflicts": "random", "bulk-size": 5000, "batch-size": 20000, "pipeline": "test-pipeline"}, + ) partition = source.partition(0, 1) self.assertEqual(partition.corpora, [corpora[1]]) def test_raises_exception_if_no_corpus_matches(self): - corpus = workload.DocumentCorpus(name="default", documents=[ - workload.Documents(source_format=workload.Documents.SOURCE_FORMAT_BULK, - number_of_documents=10, - target_collection="test-idx", - target_type="test-type" - )]) + corpus = workload.DocumentCorpus( + name="default", + documents=[workload.Documents(source_format=workload.Documents.SOURCE_FORMAT_BULK, number_of_documents=10, target_collection="test-idx", target_type="test-type")], + ) with self.assertRaises(exceptions.BenchmarkAssertionError) as ctx: params.BulkIndexParamSource( workload=workload.Workload(name="unit-test", corpora=[corpus]), - params={ - "corpora": "does_not_exist", - "conflicts": "random", - "bulk-size": 5000, - "batch-size": 20000, - "pipeline": "test-pipeline" - }) + params={"corpora": "does_not_exist", "conflicts": "random", "bulk-size": 5000, "batch-size": 20000, "pipeline": "test-pipeline"}, + ) self.assertEqual("The provided corpus ['does_not_exist'] does not match any of the corpora ['default'].", ctx.exception.args[0]) def test_ingests_all_documents_by_default(self): corpora = [ - workload.DocumentCorpus(name="default", documents=[ - workload.Documents(source_format=workload.Documents.SOURCE_FORMAT_BULK, - number_of_documents=300000, - target_collection="test-idx", - target_type="test-type" - ) - ]), - workload.DocumentCorpus(name="special", documents=[ - workload.Documents(source_format=workload.Documents.SOURCE_FORMAT_BULK, - number_of_documents=700000, - target_collection="test-idx2", - target_type="type" - ) - ]), + workload.DocumentCorpus( + name="default", + documents=[ + workload.Documents(source_format=workload.Documents.SOURCE_FORMAT_BULK, number_of_documents=300000, target_collection="test-idx", target_type="test-type") + ], + ), + workload.DocumentCorpus( + name="special", + documents=[workload.Documents(source_format=workload.Documents.SOURCE_FORMAT_BULK, number_of_documents=700000, target_collection="test-idx2", target_type="type")], + ), ] - source = params.BulkIndexParamSource( - workload=workload.Workload(name="unit-test", corpora=corpora), - params={ - "bulk-size": 10000 - }) + source = params.BulkIndexParamSource(workload=workload.Workload(name="unit-test", corpora=corpora), params={"bulk-size": 10000}) partition = source.partition(0, 1) partition._init_internal_params() @@ -1037,18 +892,22 @@ def test_ingests_all_documents_by_default(self): def test_restricts_number_of_bulks_if_required(self): def create_unit_test_reader(*args): - return StaticBulkReader("idx", "doc", bulks=[ - ['{"location" : [-0.1485188, 51.5250666]}'], - ['{"location" : [-0.1479949, 51.5252071]}'], - ['{"location" : [-0.1458559, 51.5289059]}'], - ['{"location" : [-0.1498551, 51.5282564]}'], - ['{"location" : [-0.1487043, 51.5254843]}'], - ['{"location" : [-0.1533367, 51.5261779]}'], - ['{"location" : [-0.1543018, 51.5262398]}'], - ['{"location" : [-0.1522118, 51.5266564]}'], - ['{"location" : [-0.1529092, 51.5263360]}'], - ['{"location" : [-0.1537008, 51.5265365]}'], - ]) + return StaticBulkReader( + "idx", + "doc", + bulks=[ + ['{"location" : [-0.1485188, 51.5250666]}'], + ['{"location" : [-0.1479949, 51.5252071]}'], + ['{"location" : [-0.1458559, 51.5289059]}'], + ['{"location" : [-0.1498551, 51.5282564]}'], + ['{"location" : [-0.1487043, 51.5254843]}'], + ['{"location" : [-0.1533367, 51.5261779]}'], + ['{"location" : [-0.1543018, 51.5262398]}'], + ['{"location" : [-0.1522118, 51.5266564]}'], + ['{"location" : [-0.1529092, 51.5263360]}'], + ['{"location" : [-0.1537008, 51.5265365]}'], + ], + ) def schedule(param_source): while True: @@ -1058,29 +917,21 @@ def schedule(param_source): return corpora = [ - workload.DocumentCorpus(name="default", documents=[ - workload.Documents(source_format=workload.Documents.SOURCE_FORMAT_BULK, - number_of_documents=300000, - target_collection="test-idx", - target_type="test-type" - ) - ]), - workload.DocumentCorpus(name="special", documents=[ - workload.Documents(source_format=workload.Documents.SOURCE_FORMAT_BULK, - number_of_documents=700000, - target_collection="test-idx2", - target_type="type" - ) - ]), + workload.DocumentCorpus( + name="default", + documents=[ + workload.Documents(source_format=workload.Documents.SOURCE_FORMAT_BULK, number_of_documents=300000, target_collection="test-idx", target_type="test-type") + ], + ), + workload.DocumentCorpus( + name="special", + documents=[workload.Documents(source_format=workload.Documents.SOURCE_FORMAT_BULK, number_of_documents=700000, target_collection="test-idx2", target_type="type")], + ), ] source = params.BulkIndexParamSource( - workload=workload.Workload(name="unit-test", corpora=corpora), - params={ - "bulk-size": 10000, - "ingest-percentage": 2.5, - "__create_reader": create_unit_test_reader - }) + workload=workload.Workload(name="unit-test", corpora=corpora), params={"bulk-size": 10000, "ingest-percentage": 2.5, "__create_reader": create_unit_test_reader} + ) partition = source.partition(0, 1) partition._init_internal_params() @@ -1089,46 +940,30 @@ def schedule(param_source): self.assertEqual(3, len(list(schedule(partition)))) def test_create_with_conflict_probability_zero(self): - corpus = workload.DocumentCorpus(name="default", documents=[ - workload.Documents(source_format=workload.Documents.SOURCE_FORMAT_BULK, - number_of_documents=10, - target_collection="test-idx", - target_type="test-type" - )]) - - params.BulkIndexParamSource(workload=workload.Workload(name="unit-test", corpora=[corpus]), params={ - "bulk-size": 5000, - "conflicts": "sequential", - "conflict-probability": 0 - }) + corpus = workload.DocumentCorpus( + name="default", + documents=[workload.Documents(source_format=workload.Documents.SOURCE_FORMAT_BULK, number_of_documents=10, target_collection="test-idx", target_type="test-type")], + ) + + params.BulkIndexParamSource( + workload=workload.Workload(name="unit-test", corpora=[corpus]), params={"bulk-size": 5000, "conflicts": "sequential", "conflict-probability": 0} + ) def test_create_with_conflict_probability_too_low(self): with self.assertRaises(exceptions.InvalidSyntax) as ctx: - params.BulkIndexParamSource(workload=workload.Workload(name="unit-test"), params={ - "bulk-size": 5000, - "conflicts": "sequential", - "conflict-probability": -0.1 - }) + params.BulkIndexParamSource(workload=workload.Workload(name="unit-test"), params={"bulk-size": 5000, "conflicts": "sequential", "conflict-probability": -0.1}) self.assertEqual("'conflict-probability' must be in the range [0.0, 100.0] but was -0.1", ctx.exception.args[0]) def test_create_with_conflict_probability_too_high(self): with self.assertRaises(exceptions.InvalidSyntax) as ctx: - params.BulkIndexParamSource(workload=workload.Workload(name="unit-test"), params={ - "bulk-size": 5000, - "conflicts": "sequential", - "conflict-probability": 100.1 - }) + params.BulkIndexParamSource(workload=workload.Workload(name="unit-test"), params={"bulk-size": 5000, "conflicts": "sequential", "conflict-probability": 100.1}) self.assertEqual("'conflict-probability' must be in the range [0.0, 100.0] but was 100.1", ctx.exception.args[0]) def test_create_with_conflict_probability_not_numeric(self): with self.assertRaises(exceptions.InvalidSyntax) as ctx: - params.BulkIndexParamSource(workload=workload.Workload(name="unit-test"), params={ - "bulk-size": 5000, - "conflicts": "sequential", - "conflict-probability": "100 percent" - }) + params.BulkIndexParamSource(workload=workload.Workload(name="unit-test"), params={"bulk-size": 5000, "conflicts": "sequential", "conflict-probability": "100 percent"}) self.assertEqual("'conflict-probability' must be numeric", ctx.exception.args[0]) @@ -1142,6 +977,7 @@ def create_unit_test_reader(*args): ['{"location" : [-0.1479949, 51.5252071]}'], ], ) + corpora = [ workload.DocumentCorpus( name="default", @@ -1176,7 +1012,6 @@ def create_unit_test_reader(*args): class BulkDataGeneratorTests(TestCase): - @classmethod def create_test_reader(cls, batches): def inner_create_test_reader(corpus, docs, *args): @@ -1185,160 +1020,168 @@ def inner_create_test_reader(corpus, docs, *args): return inner_create_test_reader def test_generate_two_bulks(self): - corpus = workload.DocumentCorpus(name="default", documents=[ - workload.Documents(source_format=workload.Documents.SOURCE_FORMAT_BULK, - number_of_documents=10, - target_collection="test-idx", - target_type="test-type" - ) - ]) - - bulks = params.bulk_data_based(num_clients=1, start_client_index=0, end_client_index=0, corpora=[corpus], - batch_size=5, bulk_size=5, - id_conflicts=params.IndexIdConflict.NoConflicts, conflict_probability=None, on_conflict=None, - recency=None, pipeline=None, - original_params={ - "my-custom-parameter": "foo", - "my-custom-parameter-2": True - }, create_reader=BulkDataGeneratorTests. - create_test_reader([["1", "2", "3", "4", "5"], ["6", "7", "8"]])) + corpus = workload.DocumentCorpus( + name="default", + documents=[workload.Documents(source_format=workload.Documents.SOURCE_FORMAT_BULK, number_of_documents=10, target_collection="test-idx", target_type="test-type")], + ) + + bulks = params.bulk_data_based( + num_clients=1, + start_client_index=0, + end_client_index=0, + corpora=[corpus], + batch_size=5, + bulk_size=5, + id_conflicts=params.IndexIdConflict.NoConflicts, + conflict_probability=None, + on_conflict=None, + recency=None, + pipeline=None, + original_params={"my-custom-parameter": "foo", "my-custom-parameter-2": True}, + create_reader=BulkDataGeneratorTests.create_test_reader([["1", "2", "3", "4", "5"], ["6", "7", "8"]]), + ) all_bulks = list(bulks) self.assertEqual(2, len(all_bulks)) - self.assertEqual({ - "action-metadata-present": True, - "body": ["1", "2", "3", "4", "5"], - "bulk-size": 5, - "unit": "docs", - "index": "test-idx", - "type": "test-type", - "my-custom-parameter": "foo", - "my-custom-parameter-2": True - }, all_bulks[0]) - - self.assertEqual({ - "action-metadata-present": True, - "body": ["6", "7", "8"], - "bulk-size": 3, - "unit": "docs", - "index": "test-idx", - "type": "test-type", - "my-custom-parameter": "foo", - "my-custom-parameter-2": True - }, all_bulks[1]) + self.assertEqual( + { + "action-metadata-present": True, + "body": ["1", "2", "3", "4", "5"], + "bulk-size": 5, + "unit": "docs", + "index": "test-idx", + "type": "test-type", + "my-custom-parameter": "foo", + "my-custom-parameter-2": True, + }, + all_bulks[0], + ) + + self.assertEqual( + { + "action-metadata-present": True, + "body": ["6", "7", "8"], + "bulk-size": 3, + "unit": "docs", + "index": "test-idx", + "type": "test-type", + "my-custom-parameter": "foo", + "my-custom-parameter-2": True, + }, + all_bulks[1], + ) def test_generate_bulks_from_multiple_corpora(self): corpora = [ - workload.DocumentCorpus(name="default", documents=[ - workload.Documents(source_format=workload.Documents.SOURCE_FORMAT_BULK, - number_of_documents=5, - target_collection="logs-2018-01", - target_type="docs" - ), - workload.Documents(source_format=workload.Documents.SOURCE_FORMAT_BULK, - number_of_documents=5, - target_collection="logs-2018-02", - target_type="docs" - ), - - ]), - workload.DocumentCorpus(name="special", documents=[ - workload.Documents(source_format=workload.Documents.SOURCE_FORMAT_BULK, - number_of_documents=5, - target_collection="logs-2017-01", - target_type="docs" - ) - ]) - - ] + workload.DocumentCorpus( + name="default", + documents=[ + workload.Documents(source_format=workload.Documents.SOURCE_FORMAT_BULK, number_of_documents=5, target_collection="logs-2018-01", target_type="docs"), + workload.Documents(source_format=workload.Documents.SOURCE_FORMAT_BULK, number_of_documents=5, target_collection="logs-2018-02", target_type="docs"), + ], + ), + workload.DocumentCorpus( + name="special", + documents=[workload.Documents(source_format=workload.Documents.SOURCE_FORMAT_BULK, number_of_documents=5, target_collection="logs-2017-01", target_type="docs")], + ), + ] - bulks = params.bulk_data_based(num_clients=1, start_client_index=0, end_client_index=0, corpora=corpora, - batch_size=5, bulk_size=5, - id_conflicts=params.IndexIdConflict.NoConflicts, conflict_probability=None, on_conflict=None, - recency=None, pipeline=None, - original_params={ - "my-custom-parameter": "foo", - "my-custom-parameter-2": True - }, create_reader=BulkDataGeneratorTests. - create_test_reader([["1", "2", "3", "4", "5"]])) + bulks = params.bulk_data_based( + num_clients=1, + start_client_index=0, + end_client_index=0, + corpora=corpora, + batch_size=5, + bulk_size=5, + id_conflicts=params.IndexIdConflict.NoConflicts, + conflict_probability=None, + on_conflict=None, + recency=None, + pipeline=None, + original_params={"my-custom-parameter": "foo", "my-custom-parameter-2": True}, + create_reader=BulkDataGeneratorTests.create_test_reader([["1", "2", "3", "4", "5"]]), + ) all_bulks = list(bulks) self.assertEqual(3, len(all_bulks)) - self.assertEqual({ - "action-metadata-present": True, - "body": ["1", "2", "3", "4", "5"], - "bulk-size": 5, - "unit": "docs", - "index": "logs-2018-01", - "type": "docs", - "my-custom-parameter": "foo", - "my-custom-parameter-2": True - }, all_bulks[0]) - - self.assertEqual({ - "action-metadata-present": True, - "body": ["1", "2", "3", "4", "5"], - "bulk-size": 5, - "unit": "docs", - "index": "logs-2018-02", - "type": "docs", - "my-custom-parameter": "foo", - "my-custom-parameter-2": True - }, all_bulks[1]) - - self.assertEqual({ - "action-metadata-present": True, - "body": ["1", "2", "3", "4", "5"], - "bulk-size": 5, - "unit": "docs", - "index": "logs-2017-01", - "type": "docs", - "my-custom-parameter": "foo", - "my-custom-parameter-2": True - }, all_bulks[2]) + self.assertEqual( + { + "action-metadata-present": True, + "body": ["1", "2", "3", "4", "5"], + "bulk-size": 5, + "unit": "docs", + "index": "logs-2018-01", + "type": "docs", + "my-custom-parameter": "foo", + "my-custom-parameter-2": True, + }, + all_bulks[0], + ) + + self.assertEqual( + { + "action-metadata-present": True, + "body": ["1", "2", "3", "4", "5"], + "bulk-size": 5, + "unit": "docs", + "index": "logs-2018-02", + "type": "docs", + "my-custom-parameter": "foo", + "my-custom-parameter-2": True, + }, + all_bulks[1], + ) + + self.assertEqual( + { + "action-metadata-present": True, + "body": ["1", "2", "3", "4", "5"], + "bulk-size": 5, + "unit": "docs", + "index": "logs-2017-01", + "type": "docs", + "my-custom-parameter": "foo", + "my-custom-parameter-2": True, + }, + all_bulks[2], + ) def test_internal_params_take_precedence(self): - corpus = workload.DocumentCorpus(name="default", documents=[ - workload.Documents(source_format=workload.Documents.SOURCE_FORMAT_BULK, - number_of_documents=3, - target_collection="test-idx", - target_type="test-type" - ) - ]) - - bulks = params.bulk_data_based(num_clients=1, start_client_index=0, end_client_index=0, corpora=[corpus], - batch_size=3, bulk_size=3, id_conflicts=params.IndexIdConflict.NoConflicts, - conflict_probability=None, on_conflict=None, - recency=None, pipeline=None, - original_params={ - "body": "foo", - "custom-param": "bar" - }, create_reader=BulkDataGeneratorTests. - create_test_reader([["1", "2", "3"]])) + corpus = workload.DocumentCorpus( + name="default", + documents=[workload.Documents(source_format=workload.Documents.SOURCE_FORMAT_BULK, number_of_documents=3, target_collection="test-idx", target_type="test-type")], + ) + + bulks = params.bulk_data_based( + num_clients=1, + start_client_index=0, + end_client_index=0, + corpora=[corpus], + batch_size=3, + bulk_size=3, + id_conflicts=params.IndexIdConflict.NoConflicts, + conflict_probability=None, + on_conflict=None, + recency=None, + pipeline=None, + original_params={"body": "foo", "custom-param": "bar"}, + create_reader=BulkDataGeneratorTests.create_test_reader([["1", "2", "3"]]), + ) all_bulks = list(bulks) self.assertEqual(1, len(all_bulks)) # body must not contain 'foo'! - self.assertEqual({ - "action-metadata-present": True, - "body": ["1", "2", "3"], - "bulk-size": 3, - "unit": "docs", - "index": "test-idx", - "type": "test-type", - "custom-param": "bar" - }, all_bulks[0]) + self.assertEqual( + {"action-metadata-present": True, "body": ["1", "2", "3"], "bulk-size": 3, "unit": "docs", "index": "test-idx", "type": "test-type", "custom-param": "bar"}, + all_bulks[0], + ) class ParamsRegistrationTests(TestCase): @staticmethod def param_source_legacy_function(indices, params): - return { - "key": params["parameter"] - } + return {"key": params["parameter"]} @staticmethod def param_source_function(workload, params, **kwargs): - return { - "key": params["parameter"] - } + return {"key": params["parameter"]} class ParamSourceLegacyClass: def __init__(self, indices=None, params=None): @@ -1352,9 +1195,7 @@ def size(self): return 1 def params(self): - return { - "class-key": self._params["parameter"] - } + return {"class-key": self._params["parameter"]} class ParamSourceClass: def __init__(self, workload=None, params=None, **kwargs): @@ -1368,9 +1209,7 @@ def size(self): return 1 def params(self): - return { - "class-key": self._params["parameter"] - } + return {"class-key": self._params["parameter"]} def __str__(self): return "test param source" @@ -1414,13 +1253,13 @@ def test_can_register_class_as_param_source(self): def test_cannot_register_an_instance_as_param_source(self): source_name = "params-test-class-param-source" # we create an instance, instead of passing the class - with self.assertRaisesRegex(exceptions.BenchmarkAssertionError, - "Parameter source \\[test param source\\] must be either a function or a class\\."): + with self.assertRaisesRegex(exceptions.BenchmarkAssertionError, "Parameter source \\[test param source\\] must be either a function or a class\\."): params.register_param_source_for_name(source_name, ParamsRegistrationTests.ParamSourceClass()) + class StandardValueSourceRegistrationTests(TestCase): def get_mock_standard_value_source(self, gte, lte): - return lambda : {"gte":gte, "lte":lte} + return lambda: {"gte": gte, "lte": lte} def test_register_standard_value_source(self): # Test the sequence: register standard value source -> generate saved standard values @@ -1439,49 +1278,47 @@ def test_register_standard_value_source(self): params.register_standard_value_source(op_name, field_name_1, self.get_mock_standard_value_source(gte_field_1, lte_field_1)) - self.assertEqual(params.get_standard_value_source(op_name, field_name_1)(), {"gte":gte_field_1, "lte":lte_field_1}) + self.assertEqual(params.get_standard_value_source(op_name, field_name_1)(), {"gte": gte_field_1, "lte": lte_field_1}) with self.assertRaises(exceptions.SystemSetupError) as ctx: _ = params.get_standard_value_source(op_name, field_name_2) self.assertEqual( - "Could not find standard value source for operation {}, field {}! Make sure this is registered in workload.py" - .format(op_name, field_name_2), ctx.exception.args[0]) + "Could not find standard value source for operation {}, field {}! Make sure this is registered in workload.py".format(op_name, field_name_2), ctx.exception.args[0] + ) with self.assertRaises(exceptions.SystemSetupError) as ctx: _ = params.get_standard_value(op_name, field_name_1, 0) self.assertEqual("No standard values generated for operation {}, field {}".format(op_name, field_name_1), ctx.exception.args[0]) params.generate_standard_values_if_absent(op_name, field_name_1, n) - self.assertEqual(params.get_standard_value(op_name, field_name_1, 0), {"gte":gte_field_1, "lte":lte_field_1}) + self.assertEqual(params.get_standard_value(op_name, field_name_1, 0), {"gte": gte_field_1, "lte": lte_field_1}) # check that running generate_standard_values_if_absent on the same inputs does nothing # we can do this by telling it to generate 2*n, but it won't because values are already present - params.generate_standard_values_if_absent(op_name, field_name_1, 2*n) + params.generate_standard_values_if_absent(op_name, field_name_1, 2 * n) with self.assertRaises(exceptions.SystemSetupError) as ctx: _ = params.get_standard_value(op_name, field_name_1, n + 1) self.assertEqual( - "Standard value index {} out of range for operation {}, field name {} ({} values total)" - .format(n+1, op_name, field_name_1, n), ctx.exception.args[0]) + "Standard value index {} out of range for operation {}, field name {} ({} values total)".format(n + 1, op_name, field_name_1, n), ctx.exception.args[0] + ) with self.assertRaises(exceptions.SystemSetupError) as ctx: params.generate_standard_values_if_absent(op_name, field_name_2, n) - self.assertEqual( - "Cannot generate standard values for operation {}, field {}. Standard value source is missing" - .format(op_name, field_name_2), ctx.exception.args[0]) + self.assertEqual("Cannot generate standard values for operation {}, field {}. Standard value source is missing".format(op_name, field_name_2), ctx.exception.args[0]) params.register_standard_value_source(op_name, field_name_2, self.get_mock_standard_value_source(gte_field_2, lte_field_2)) - self.assertEqual(params.get_standard_value_source(op_name, field_name_2)(), {"gte":gte_field_2, "lte":lte_field_2}) - self.assertEqual(params.get_standard_value_source(op_name, field_name_1)(), {"gte":gte_field_1, "lte":lte_field_1}) + self.assertEqual(params.get_standard_value_source(op_name, field_name_2)(), {"gte": gte_field_2, "lte": lte_field_2}) + self.assertEqual(params.get_standard_value_source(op_name, field_name_1)(), {"gte": gte_field_1, "lte": lte_field_1}) params._clear_standard_values() + class QueryRandomizationInfoRegistrationTests(TestCase): def check_result_equality(self, result, expected): self.assertEqual(result.query_name, expected.query_name) self.assertEqual(result.parameter_name_options_list, expected.parameter_name_options_list) self.assertEqual(result.optional_parameters, expected.optional_parameters) - def test_register_query_randomization_info(self): params._clear_query_randomization_infos() @@ -1502,19 +1339,18 @@ def test_register_query_randomization_info(self): params._clear_query_randomization_infos() + class SleepParamSourceTests(TestCase): def test_missing_duration_parameter(self): with self.assertRaisesRegex(exceptions.InvalidSyntax, "parameter 'duration' is mandatory for sleep operation"): params.SleepParamSource(workload.Workload(name="unit-test"), params={}) def test_duration_parameter_wrong_type(self): - with self.assertRaisesRegex(exceptions.InvalidSyntax, - "parameter 'duration' for sleep operation must be a number"): + with self.assertRaisesRegex(exceptions.InvalidSyntax, "parameter 'duration' for sleep operation must be a number"): params.SleepParamSource(workload.Workload(name="unit-test"), params={"duration": "this is a string"}) def test_duration_parameter_negative_number(self): - with self.assertRaisesRegex(exceptions.InvalidSyntax, - "parameter 'duration' must be non-negative but was -1.0"): + with self.assertRaisesRegex(exceptions.InvalidSyntax, "parameter 'duration' must be non-negative but was -1.0"): params.SleepParamSource(workload.Workload(name="unit-test"), params={"duration": -1.0}) def test_param_source_passes_all_parameters(self): @@ -1526,18 +1362,10 @@ class SearchParamSourceTests(TestCase): def test_passes_cache(self): col1 = workload.Collection(name="index1") - source = params.SearchParamSource(workload=workload.Workload(name="unit-test", collections=[col1]), params={ - "index": "index1", - "body": { - "query": { - "match_all": {} - } - }, - "headers": { - "header1": "value1" - }, - "cache": True - }) + source = params.SearchParamSource( + workload=workload.Workload(name="unit-test", collections=[col1]), + params={"index": "index1", "body": {"query": {"match_all": {}}}, "headers": {"header1": "value1"}, "cache": True}, + ) p = source.params() self.assertEqual(11, len(p)) @@ -1552,28 +1380,14 @@ def test_passes_cache(self): self.assertEqual(True, p["cache"]) self.assertEqual(True, p["response-compression-enabled"]) self.assertEqual(False, p["detailed-results"]) - self.assertEqual({ - "query": { - "match_all": {} - } - }, p["body"]) + self.assertEqual({"query": {"match_all": {}}}, p["body"]) def test_uses_collection_default(self): col1 = workload.Collection(name="collection-1") - source = params.SearchParamSource(workload=workload.Workload(name="unit-test", collections=[col1]), params={ - "body": { - "query": { - "match_all": {} - } - }, - "request-timeout": 1.0, - "headers": { - "header1": "value1", - "header2": "value2" - }, - "opaque-id": "12345abcde", - "cache": True - }) + source = params.SearchParamSource( + workload=workload.Workload(name="unit-test", collections=[col1]), + params={"body": {"query": {"match_all": {}}}, "request-timeout": 1.0, "headers": {"header1": "value1", "header2": "value2"}, "opaque-id": "12345abcde", "cache": True}, + ) p = source.params() self.assertEqual(11, len(p)) @@ -1581,48 +1395,27 @@ def test_uses_collection_default(self): self.assertEqual("collection-1", p["index"]) self.assertIsNone(p["type"]) self.assertEqual(1.0, p["request-timeout"]) - self.assertDictEqual({ - "header1": "value1", - "header2": "value2" - }, p["headers"]) + self.assertDictEqual({"header1": "value1", "header2": "value2"}, p["headers"]) self.assertEqual("12345abcde", p["opaque-id"]) self.assertEqual({}, p["request-params"]) self.assertEqual(True, p["cache"]) self.assertEqual(True, p["response-compression-enabled"]) self.assertEqual(False, p["detailed-results"]) - self.assertEqual({ - "query": { - "match_all": {} - } - }, p["body"]) + self.assertEqual({"query": {"match_all": {}}}, p["body"]) def test_create_without_index(self): with self.assertRaises(exceptions.InvalidSyntax) as ctx: - params.SearchParamSource(workload=workload.Workload(name="unit-test"), params={ - "type": "type1", - "body": { - "query": { - "match_all": {} - } - } - }, operation_name="test_operation") + params.SearchParamSource(workload=workload.Workload(name="unit-test"), params={"type": "type1", "body": {"query": {"match_all": {}}}}, operation_name="test_operation") self.assertEqual("'index' or 'data-stream' is mandatory and is missing for operation 'test_operation'", ctx.exception.args[0]) def test_passes_request_parameters(self): col1 = workload.Collection(name="index1") - source = params.SearchParamSource(workload=workload.Workload(name="unit-test", collections=[col1]), params={ - "index": "index1", - "request-params": { - "_source_include": "some_field" - }, - "body": { - "query": { - "match_all": {} - } - } - }) + source = params.SearchParamSource( + workload=workload.Workload(name="unit-test", collections=[col1]), + params={"index": "index1", "request-params": {"_source_include": "some_field"}, "body": {"query": {"match_all": {}}}}, + ) p = source.params() self.assertEqual(11, len(p)) @@ -1632,34 +1425,27 @@ def test_passes_request_parameters(self): self.assertIsNone(p["request-timeout"]) self.assertIsNone(p["headers"]) self.assertIsNone(p["opaque-id"]) - self.assertEqual({ - "_source_include": "some_field" - }, p["request-params"]) + self.assertEqual({"_source_include": "some_field"}, p["request-params"]) self.assertIsNone(p["cache"]) self.assertEqual(True, p["response-compression-enabled"]) self.assertEqual(False, p["detailed-results"]) - self.assertEqual({ - "query": { - "match_all": {} - } - }, p["body"]) + self.assertEqual({"query": {"match_all": {}}}, p["body"]) def test_user_specified_overrides_defaults(self): col1 = workload.Collection(name="index1") - source = params.SearchParamSource(workload=workload.Workload(name="unit-test", collections=[col1]), params={ - "index": "_all", - "type": "type1", - "cache": False, - "response-compression-enabled": False, - "detailed-results": True, - "opaque-id": "12345abcde", - "body": { - "query": { - "match_all": {} - } - } - }) + source = params.SearchParamSource( + workload=workload.Workload(name="unit-test", collections=[col1]), + params={ + "index": "_all", + "type": "type1", + "cache": False, + "response-compression-enabled": False, + "detailed-results": True, + "opaque-id": "12345abcde", + "body": {"query": {"match_all": {}}}, + }, + ) p = source.params() self.assertEqual(11, len(p)) @@ -1674,26 +1460,15 @@ def test_user_specified_overrides_defaults(self): self.assertEqual(False, p["cache"]) self.assertEqual(False, p["response-compression-enabled"]) self.assertEqual(True, p["detailed-results"]) - self.assertEqual({ - "query": { - "match_all": {} - } - }, p["body"]) + self.assertEqual({"query": {"match_all": {}}}, p["body"]) def test_user_specified_collection_overrides_defaults(self): col1 = workload.Collection(name="collection-1") - source = params.SearchParamSource(workload=workload.Workload(name="unit-test", collections=[col1]), params={ - "index": "collection-2", - "cache": False, - "response-compression-enabled": False, - "request-timeout": 1.0, - "body": { - "query": { - "match_all": {} - } - } - }) + source = params.SearchParamSource( + workload=workload.Workload(name="unit-test", collections=[col1]), + params={"index": "collection-2", "cache": False, "response-compression-enabled": False, "request-timeout": 1.0, "body": {"query": {"match_all": {}}}}, + ) p = source.params() self.assertEqual(11, len(p)) @@ -1708,37 +1483,26 @@ def test_user_specified_collection_overrides_defaults(self): self.assertEqual(False, p["cache"]) self.assertEqual(False, p["response-compression-enabled"]) self.assertEqual(False, p["detailed-results"]) - self.assertEqual({ - "query": { - "match_all": {} - } - }, p["body"]) + self.assertEqual({"query": {"match_all": {}}}, p["body"]) def test_assertions_without_detailed_results_are_invalid(self): col1 = workload.Collection(name="index1") - with self.assertRaisesRegex(exceptions.InvalidSyntax, - r"The property \[detailed-results\] must be \[true\] if assertions are defined"): - params.SearchParamSource(workload=workload.Workload(name="unit-test", collections=[col1]), params={ - "index": "_all", - # unset! - #"detailed-results": True, - "assertions": [{ - "property": "hits", - "condition": ">", - "value": 0 - }], - "body": { - "query": { - "match_all": {} - } - } - }) + with self.assertRaisesRegex(exceptions.InvalidSyntax, r"The property \[detailed-results\] must be \[true\] if assertions are defined"): + params.SearchParamSource( + workload=workload.Workload(name="unit-test", collections=[col1]), + params={ + "index": "_all", + # unset! + # "detailed-results": True, + "assertions": [{"property": "hits", "condition": ">", "value": 0}], + "body": {"query": {"match_all": {}}}, + }, + ) class CreateCollectionParamSourceTests(TestCase): def test_uses_first_collection_when_no_target_specified(self): - col = workload.Collection(name="my-col", configset="my-cfg", configset_path="/path/conf", - num_shards=2, replication_factor=1) + col = workload.Collection(name="my-col", configset="my-cfg", configset_path="/path/conf", num_shards=2, replication_factor=1) wl = workload.Workload(name="unit-test", collections=[col]) ps = params.CreateCollectionParamSource(workload=wl, params={}) p = ps.params() @@ -1765,10 +1529,7 @@ def test_raises_when_no_collections(self): def test_operation_params_override_collection_defaults(self): col = workload.Collection(name="my-col", num_shards=1, replication_factor=1) wl = workload.Workload(name="unit-test", collections=[col]) - ps = params.CreateCollectionParamSource( - workload=wl, - params={"collection": "my-col", "num-shards": 4, "replication-factor": 2, - "tlog-replicas": 1, "pull-replicas": 3}) + ps = params.CreateCollectionParamSource(workload=wl, params={"collection": "my-col", "num-shards": 4, "replication-factor": 2, "tlog-replicas": 1, "pull-replicas": 3}) p = ps.params() self.assertEqual(4, p["num-shards"]) self.assertEqual(2, p["replication-factor"]) @@ -1779,12 +1540,9 @@ def test_configset_path_from_collection_wins_over_operation(self): # The loader resolves configset-path to an absolute path against the # workload directory; the operation template only carries the relative # form. The loader-resolved path must survive. - col = workload.Collection(name="my-col", configset="my-col", - configset_path="/abs/workload/configsets/my-col") + col = workload.Collection(name="my-col", configset="my-col", configset_path="/abs/workload/configsets/my-col") wl = workload.Workload(name="unit-test", collections=[col]) - ps = params.CreateCollectionParamSource( - workload=wl, - params={"collection": "my-col", "configset-path": "configsets/my-col"}) + ps = params.CreateCollectionParamSource(workload=wl, params={"collection": "my-col", "configset-path": "configsets/my-col"}) p = ps.params() self.assertEqual("/abs/workload/configsets/my-col", p["configset-path"]) diff --git a/tests/workload/workload_test.py b/tests/workload/workload_test.py index 1104ec2b..515b47f3 100644 --- a/tests/workload/workload_test.py +++ b/tests/workload/workload_test.py @@ -16,7 +16,7 @@ # not use this file except in compliance with the License. # You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an @@ -36,60 +36,60 @@ def test_finds_default_test_procedure(self): default_test_procedure = workload.TestProcedure("default", description="default test_procedure", default=True) another_test_procedure = workload.TestProcedure("other", description="non-default test_procedure", default=False) - self.assertEqual(default_test_procedure, - workload.Workload(name="unittest", - description="unittest workload", - test_procedures=[another_test_procedure, default_test_procedure]) - .default_test_procedure) + self.assertEqual( + default_test_procedure, + workload.Workload(name="unittest", description="unittest workload", test_procedures=[another_test_procedure, default_test_procedure]).default_test_procedure, + ) def test_default_test_procedure_none_if_no_test_procedures(self): - self.assertIsNone(workload.Workload(name="unittest", - description="unittest workload", - test_procedures=[]) - .default_test_procedure) + self.assertIsNone(workload.Workload(name="unittest", description="unittest workload", test_procedures=[]).default_test_procedure) def test_finds_test_procedure_by_name(self): default_test_procedure = workload.TestProcedure("default", description="default test_procedure", default=True) another_test_procedure = workload.TestProcedure("other", description="non-default test_procedure", default=False) - self.assertEqual(another_test_procedure, - workload.Workload(name="unittest", - description="unittest workload", - test_procedures=[another_test_procedure, default_test_procedure]) - .find_test_procedure_or_default("other")) + self.assertEqual( + another_test_procedure, + workload.Workload(name="unittest", description="unittest workload", test_procedures=[another_test_procedure, default_test_procedure]).find_test_procedure_or_default( + "other" + ), + ) def test_uses_default_test_procedure_if_no_name_given(self): default_test_procedure = workload.TestProcedure("default", description="default test_procedure", default=True) another_test_procedure = workload.TestProcedure("other", description="non-default test_procedure", default=False) - self.assertEqual(default_test_procedure, - workload.Workload(name="unittest", - description="unittest workload", - test_procedures=[another_test_procedure, default_test_procedure]) - .find_test_procedure_or_default("")) + self.assertEqual( + default_test_procedure, + workload.Workload(name="unittest", description="unittest workload", test_procedures=[another_test_procedure, default_test_procedure]).find_test_procedure_or_default( + "" + ), + ) def test_does_not_find_unknown_test_procedure(self): default_test_procedure = workload.TestProcedure("default", description="default test_procedure", default=True) another_test_procedure = workload.TestProcedure("other", description="non-default test_procedure", default=False) with self.assertRaises(exceptions.InvalidName) as ctx: - workload.Workload(name="unittest", - description="unittest workload", - test_procedures=[another_test_procedure, default_test_procedure]).find_test_procedure_or_default("unknown-name") + workload.Workload(name="unittest", description="unittest workload", test_procedures=[another_test_procedure, default_test_procedure]).find_test_procedure_or_default( + "unknown-name" + ) self.assertEqual("Unknown test_procedure [unknown-name] for workload [unittest]", ctx.exception.args[0]) class DocumentCorpusTests(TestCase): def test_do_not_filter(self): - corpus = workload.DocumentCorpus("test", documents=[ - workload.Documents(source_format=workload.Documents.SOURCE_FORMAT_BULK, number_of_documents=5, target_collection="logs-01"), - workload.Documents(source_format="other", number_of_documents=6, target_collection="logs-02"), - workload.Documents(source_format=workload.Documents.SOURCE_FORMAT_BULK, number_of_documents=7, target_collection="logs-03"), - workload.Documents(source_format=None, number_of_documents=8, target_collection=None) - ], meta_data={ - "average-document-size-in-bytes": 12 - }) + corpus = workload.DocumentCorpus( + "test", + documents=[ + workload.Documents(source_format=workload.Documents.SOURCE_FORMAT_BULK, number_of_documents=5, target_collection="logs-01"), + workload.Documents(source_format="other", number_of_documents=6, target_collection="logs-02"), + workload.Documents(source_format=workload.Documents.SOURCE_FORMAT_BULK, number_of_documents=7, target_collection="logs-03"), + workload.Documents(source_format=None, number_of_documents=8, target_collection=None), + ], + meta_data={"average-document-size-in-bytes": 12}, + ) filtered_corpus = corpus.filter() @@ -98,12 +98,15 @@ def test_do_not_filter(self): self.assertDictEqual(corpus.meta_data, filtered_corpus.meta_data) def test_filter_documents_by_format(self): - corpus = workload.DocumentCorpus("test", documents=[ - workload.Documents(source_format=workload.Documents.SOURCE_FORMAT_BULK, number_of_documents=5, target_collection="logs-01"), - workload.Documents(source_format="other", number_of_documents=6, target_collection="logs-02"), - workload.Documents(source_format=workload.Documents.SOURCE_FORMAT_BULK, number_of_documents=7, target_collection="logs-03"), - workload.Documents(source_format=None, number_of_documents=8, target_collection=None) - ]) + corpus = workload.DocumentCorpus( + "test", + documents=[ + workload.Documents(source_format=workload.Documents.SOURCE_FORMAT_BULK, number_of_documents=5, target_collection="logs-01"), + workload.Documents(source_format="other", number_of_documents=6, target_collection="logs-02"), + workload.Documents(source_format=workload.Documents.SOURCE_FORMAT_BULK, number_of_documents=7, target_collection="logs-03"), + workload.Documents(source_format=None, number_of_documents=8, target_collection=None), + ], + ) filtered_corpus = corpus.filter(source_format=workload.Documents.SOURCE_FORMAT_BULK) @@ -113,12 +116,15 @@ def test_filter_documents_by_format(self): self.assertEqual("logs-03", filtered_corpus.documents[1].target_collection) def test_filter_documents_by_indices(self): - corpus = workload.DocumentCorpus("test", documents=[ - workload.Documents(source_format=workload.Documents.SOURCE_FORMAT_BULK, number_of_documents=5, target_collection="logs-01"), - workload.Documents(source_format="other", number_of_documents=6, target_collection="logs-02"), - workload.Documents(source_format=workload.Documents.SOURCE_FORMAT_BULK, number_of_documents=7, target_collection="logs-03"), - workload.Documents(source_format=None, number_of_documents=8, target_collection=None) - ]) + corpus = workload.DocumentCorpus( + "test", + documents=[ + workload.Documents(source_format=workload.Documents.SOURCE_FORMAT_BULK, number_of_documents=5, target_collection="logs-01"), + workload.Documents(source_format="other", number_of_documents=6, target_collection="logs-02"), + workload.Documents(source_format=workload.Documents.SOURCE_FORMAT_BULK, number_of_documents=7, target_collection="logs-03"), + workload.Documents(source_format=None, number_of_documents=8, target_collection=None), + ], + ) filtered_corpus = corpus.filter(target_collections=["logs-02"]) @@ -127,12 +133,15 @@ def test_filter_documents_by_indices(self): self.assertEqual("logs-02", filtered_corpus.documents[0].target_collection) def test_filter_documents_by_format_and_indices(self): - corpus = workload.DocumentCorpus("test", documents=[ - workload.Documents(source_format=workload.Documents.SOURCE_FORMAT_BULK, number_of_documents=5, target_collection="logs-01"), - workload.Documents(source_format=workload.Documents.SOURCE_FORMAT_BULK, number_of_documents=6, target_collection="logs-02"), - workload.Documents(source_format=workload.Documents.SOURCE_FORMAT_BULK, number_of_documents=7, target_collection="logs-03"), - workload.Documents(source_format=workload.Documents.SOURCE_FORMAT_BULK, number_of_documents=8, target_collection=None) - ]) + corpus = workload.DocumentCorpus( + "test", + documents=[ + workload.Documents(source_format=workload.Documents.SOURCE_FORMAT_BULK, number_of_documents=5, target_collection="logs-01"), + workload.Documents(source_format=workload.Documents.SOURCE_FORMAT_BULK, number_of_documents=6, target_collection="logs-02"), + workload.Documents(source_format=workload.Documents.SOURCE_FORMAT_BULK, number_of_documents=7, target_collection="logs-03"), + workload.Documents(source_format=workload.Documents.SOURCE_FORMAT_BULK, number_of_documents=8, target_collection=None), + ], + ) filtered_corpus = corpus.filter(source_format=workload.Documents.SOURCE_FORMAT_BULK, target_collections=["logs-01", "logs-02"]) @@ -142,50 +151,68 @@ def test_filter_documents_by_format_and_indices(self): self.assertEqual("logs-02", filtered_corpus.documents[1].target_collection) def test_union_document_corpus_is_reflexive(self): - corpus = workload.DocumentCorpus("test", documents=[ - workload.Documents(source_format=workload.Documents.SOURCE_FORMAT_BULK, number_of_documents=5, target_collection="logs-01"), - workload.Documents(source_format=workload.Documents.SOURCE_FORMAT_BULK, number_of_documents=6, target_collection="logs-02"), - workload.Documents(source_format=workload.Documents.SOURCE_FORMAT_BULK, number_of_documents=7, target_collection="logs-03"), - workload.Documents(source_format=workload.Documents.SOURCE_FORMAT_BULK, number_of_documents=8, target_collection=None) - ]) + corpus = workload.DocumentCorpus( + "test", + documents=[ + workload.Documents(source_format=workload.Documents.SOURCE_FORMAT_BULK, number_of_documents=5, target_collection="logs-01"), + workload.Documents(source_format=workload.Documents.SOURCE_FORMAT_BULK, number_of_documents=6, target_collection="logs-02"), + workload.Documents(source_format=workload.Documents.SOURCE_FORMAT_BULK, number_of_documents=7, target_collection="logs-03"), + workload.Documents(source_format=workload.Documents.SOURCE_FORMAT_BULK, number_of_documents=8, target_collection=None), + ], + ) self.assertTrue(corpus.union(corpus) is corpus) def test_union_document_corpora_is_symmetric(self): - a = workload.DocumentCorpus("test", documents=[ - workload.Documents(source_format=workload.Documents.SOURCE_FORMAT_BULK, number_of_documents=5, target_collection="logs-01"), - ]) - b = workload.DocumentCorpus("test", documents=[ - workload.Documents(source_format=workload.Documents.SOURCE_FORMAT_BULK, number_of_documents=5, target_collection="logs-02"), - ]) + a = workload.DocumentCorpus( + "test", + documents=[ + workload.Documents(source_format=workload.Documents.SOURCE_FORMAT_BULK, number_of_documents=5, target_collection="logs-01"), + ], + ) + b = workload.DocumentCorpus( + "test", + documents=[ + workload.Documents(source_format=workload.Documents.SOURCE_FORMAT_BULK, number_of_documents=5, target_collection="logs-02"), + ], + ) self.assertEqual(b.union(a), a.union(b)) self.assertEqual(2, len(a.union(b).documents)) def test_cannot_union_mixed_document_corpora_by_name(self): - a = workload.DocumentCorpus("test", documents=[ - workload.Documents(source_format=workload.Documents.SOURCE_FORMAT_BULK, number_of_documents=5, target_collection="logs-01"), - ]) - b = workload.DocumentCorpus("other", documents=[ - workload.Documents(source_format=workload.Documents.SOURCE_FORMAT_BULK, number_of_documents=5, target_collection="logs-02"), - ]) + a = workload.DocumentCorpus( + "test", + documents=[ + workload.Documents(source_format=workload.Documents.SOURCE_FORMAT_BULK, number_of_documents=5, target_collection="logs-01"), + ], + ) + b = workload.DocumentCorpus( + "other", + documents=[ + workload.Documents(source_format=workload.Documents.SOURCE_FORMAT_BULK, number_of_documents=5, target_collection="logs-02"), + ], + ) with self.assertRaises(exceptions.BenchmarkAssertionError) as ae: a.union(b) self.assertEqual(ae.exception.message, "Corpora names differ: [test] and [other].") def test_cannot_union_mixed_document_corpora_by_meta_data(self): - a = workload.DocumentCorpus("test", documents=[ - workload.Documents(source_format=workload.Documents.SOURCE_FORMAT_BULK, number_of_documents=5, target_collection="logs-01"), - ], meta_data={ - "with-metadata": False - }) - b = workload.DocumentCorpus("test", documents=[ - workload.Documents(source_format=workload.Documents.SOURCE_FORMAT_BULK, number_of_documents=5, target_collection="logs-02"), - ], meta_data={ - "with-metadata": True - }) + a = workload.DocumentCorpus( + "test", + documents=[ + workload.Documents(source_format=workload.Documents.SOURCE_FORMAT_BULK, number_of_documents=5, target_collection="logs-01"), + ], + meta_data={"with-metadata": False}, + ) + b = workload.DocumentCorpus( + "test", + documents=[ + workload.Documents(source_format=workload.Documents.SOURCE_FORMAT_BULK, number_of_documents=5, target_collection="logs-02"), + ], + meta_data={"with-metadata": True}, + ) with self.assertRaises(exceptions.BenchmarkAssertionError) as ae: a.union(b) - self.assertEqual(ae.exception.message, - "Corpora meta-data differ: [{'with-metadata': False}] and [{'with-metadata': True}].") + self.assertEqual(ae.exception.message, "Corpora meta-data differ: [{'with-metadata': False}] and [{'with-metadata': True}].") class OperationTypeTests(TestCase): @@ -203,16 +230,12 @@ def test_attributes(self): class TaskFilterTests(TestCase): def create_index_task(self): - return workload.Task("create-index-task", - workload.Operation("create-index-op", - operation_type=workload.OperationType.CreateBackup.to_hyphenated_string()), - tags=["write-op", "admin-op"]) + return workload.Task( + "create-index-task", workload.Operation("create-index-op", operation_type=workload.OperationType.CreateBackup.to_hyphenated_string()), tags=["write-op", "admin-op"] + ) def search_task(self): - return workload.Task("search-task", - workload.Operation("search-op", - operation_type=workload.OperationType.Search.to_hyphenated_string()), - tags="read-op") + return workload.Task("search-task", workload.Operation("search-op", operation_type=workload.OperationType.Search.to_hyphenated_string()), tags="read-op") def test_task_name_filter(self): f = workload.TaskNameFilter("create-index-task") @@ -277,16 +300,14 @@ def test_interval_and_throughput_is_rejected(self): with self.assertRaises(exceptions.InvalidSyntax) as e: # pylint: disable=pointless-statement task.target_throughput - self.assertEqual("Task [test] specifies target-interval [1] and target-throughput [1] but only one " - "of them is allowed.", e.exception.args[0]) + self.assertEqual("Task [test] specifies target-interval [1] and target-throughput [1] but only one of them is allowed.", e.exception.args[0]) def test_invalid_ignore_response_error_level_is_rejected(self): task = self.task(ignore_response_error_level="invalid-value") with self.assertRaises(exceptions.InvalidSyntax) as e: # pylint: disable=pointless-statement task.ignore_response_error_level - self.assertEqual("Task [test] specifies ignore-response-error-level to [invalid-value] but " - "the only allowed values are [non-fatal].", e.exception.args[0]) + self.assertEqual("Task [test] specifies ignore-response-error-level to [invalid-value] but the only allowed values are [non-fatal].", e.exception.args[0]) def test_task_continues_with_global_continue(self): task = self.task() diff --git a/tests/workload_generator/__init__.py b/tests/workload_generator/__init__.py index 5047a451..f5768141 100644 --- a/tests/workload_generator/__init__.py +++ b/tests/workload_generator/__init__.py @@ -16,7 +16,7 @@ # not use this file except in compliance with the License. # You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an