From 66ce99f46e5dd2df1b1ad67ddf76e6d874f328c6 Mon Sep 17 00:00:00 2001 From: Jamie Diprose <5715104+jdddog@users.noreply.github.com> Date: Mon, 13 Nov 2023 16:25:03 +1300 Subject: [PATCH] Refactor to support Astro --- .coveragerc | 10 +- .gitattributes | 3 +- .github/workflows/publish-pypi.yml | 25 +- .github/workflows/unit-tests.yml | 10 +- .gitignore | 3 + LICENSES_THIRD_PARTY | 160 +- README.md | 23 +- bin/generate-api-client.sh | 34 - docs/api/BigQueryBytesProcessed.md | 54 - docs/api/DatasetRelease.md | 114 - docs/api/ObservatoryApi.md | 661 ------ docs/api/api_client.md | 146 -- docs/api/api_rest.rst | 5 - docs/api/api_server.md | 18 - docs/api/index.rst | 8 - docs/api/openapi.yaml | 229 -- docs/conf.py | 6 +- docs/graphics/workflow_flow.png | Bin 4313 -> 0 bytes docs/index.rst | 49 +- docs/installation.md | 24 - docs/requirements.txt | 11 +- docs/tutorials/deploy_terraform.md | 480 ---- docs/tutorials/index.rst | 33 - docs/tutorials/observatory_dev.md | 208 -- docs/tutorials/workflow/intro.md | 60 - docs/tutorials/workflow/step_by_step.md | 709 ------ docs/tutorials/workflow/style.md | 17 - docs/tutorials/workflow/workflow_class.md | 404 ---- install.sh | 474 ---- observatory-api/.dockerignore | 1 - observatory-api/.gitignore | 1 - observatory-api/.openapi-generator-ignore | 28 - observatory-api/.openapi-generator/FILES | 17 - observatory-api/.openapi-generator/VERSION | 1 - observatory-api/README.md | 1 - observatory-api/api-config.yaml | 11 - observatory-api/observatory/api/cli.py | 39 - .../observatory/api/client/__init__.py | 28 - .../observatory/api/client/api/__init__.py | 3 - .../api/client/api/observatory_api.py | 709 ------ .../observatory/api/client/api_client.py | 897 ------- .../observatory/api/client/configuration.py | 474 ---- .../observatory/api/client/exceptions.py | 159 -- .../observatory/api/client/model/__init__.py | 5 - .../api/client/model/dataset_release.py | 316 --- .../observatory/api/client/model_utils.py | 2059 ----------------- .../observatory/api/client/rest.py | 353 --- observatory-api/observatory/api/server/api.py | 227 -- observatory-api/observatory/api/server/app.py | 43 - .../api/server/openapi.yaml.jinja2 | 244 -- .../api/server/openapi_renderer.py | 73 - observatory-api/observatory/api/server/orm.py | 287 --- observatory-api/observatory/api/testing.py | 69 - observatory-api/observatory/api/utils.py | 44 - observatory-api/requirements.sh | 15 - observatory-api/requirements.txt | 14 - observatory-api/setup.cfg | 58 - observatory-api/setup.py | 3 - .../templates/README_common.mustache | 116 - .../templates/README_onlypackage.mustache | 37 - observatory-api/templates/api_doc.mustache | 148 -- .../templates/configuration.mustache | 627 ----- observatory-api/templates/model_doc.mustache | 87 - .../model_templates/classvars.mustache | 138 -- observatory-platform/README.md | 1 - .../observatory/platform/api.py | 150 -- .../observatory/platform/cli/cli.py | 448 ---- .../observatory/platform/cli/cli_utils.py | 41 - .../platform/cli/generate_command.py | 612 ----- .../platform/cli/platform_command.py | 46 - .../platform/cli/terraform_command.py | 211 -- .../platform/config-terraform.yaml.jinja2 | 62 - .../observatory/platform/config.yaml.jinja2 | 40 - .../platform/dags/load_dags_modules.py | 37 - .../platform/dags/load_workflows.py | 34 - .../docker/Dockerfile.apiserver.jinja2 | 61 - .../docker/Dockerfile.observatory.jinja2 | 66 - .../docker/Dockerfile.package_install.jinja2 | 66 - .../observatory/platform/docker/builder.py | 152 -- 
.../platform/docker/compose_runner.py | 132 -- .../docker-compose.observatory.yml.jinja2 | 305 --- .../docker/entrypoint-airflow.sh.jinja2 | 39 - .../platform/docker/entrypoint-api.sh.jinja2 | 40 - .../platform/docker/entrypoint-root.sh | 32 - .../platform/docker/platform_runner.py | 222 -- .../platform/observatory_config.py | 1702 -------------- .../platform/observatory_environment.py | 1339 ----------- .../observatory/platform/terraform/build.sh | 66 - .../observatory/platform/terraform/main.tf | 575 ----- .../terraform/observatory-image.json.pkr.hcl | 61 - .../observatory/platform/terraform/outputs.tf | 47 - .../platform/terraform/secret/main.tf | 21 - .../platform/terraform/secret/outputs.tf | 0 .../platform/terraform/secret/variables.tf | 14 - .../platform/terraform/startup.tpl.jinja2 | 69 - .../platform/terraform/terraform_api.py | 503 ---- .../platform/terraform/terraform_builder.py | 252 -- .../platform/terraform/variables.tf | 99 - .../platform/terraform/versions.tf | 17 - .../observatory/platform/terraform/vm/main.tf | 34 - .../platform/terraform/vm/outputs.tf | 3 - .../platform/terraform/vm/variables.tf | 69 - .../platform/workflows/vm_workflow.py | 641 ----- .../platform/workflows/workflow.py | 596 ----- observatory-platform/requirements.sh | 18 - observatory-platform/requirements.txt | 67 - observatory-platform/setup.cfg | 65 - observatory-platform/setup.py | 3 - .../api => observatory_platform}/__init__.py | 0 .../airflow}/__init__.py | 0 .../airflow}/airflow.py | 322 +-- observatory_platform/airflow/release.py | 191 ++ .../airflow/sensors.py | 65 +- observatory_platform/airflow/tasks.py | 115 + .../airflow/tests}/__init__.py | 0 .../airflow/tests/fixtures}/__init__.py | 0 .../airflow/tests/fixtures}/bad_dag.py | 0 .../airflow/tests/fixtures}/good_dag.py | 0 .../airflow/tests}/test_airflow.py | 157 +- .../airflow/tests/test_release.py | 37 + .../airflow/tests/test_sensors.py | 103 +- .../airflow/tests/test_workflow.py | 121 + observatory_platform/airflow/workflow.py | 372 +++ .../config.py | 19 +- observatory_platform/dataset_api.py | 349 +++ .../files.py | 7 +- .../google}/__init__.py | 0 .../google}/bigquery.py | 183 +- observatory_platform/google/gcp.py | 120 + .../google}/gcs.py | 175 +- observatory_platform/google/gke.py | 103 + .../google/tests}/__init__.py | 0 .../google/tests/fixtures}/__init__.py | 0 .../google/tests/fixtures/bad_dag.py | 3 + .../google/tests/fixtures}/people.csv | 0 .../google/tests/fixtures}/people.jsonl | 0 .../google/tests/fixtures}/people_extra.jsonl | 0 .../google/tests/fixtures}/people_schema.json | 0 .../tests/fixtures/schema}/db_merge0.json | 0 .../tests/fixtures/schema}/db_merge1.json | 0 .../schema}/stream_telescope_file1.json | 0 .../schema}/stream_telescope_file2.json | 0 .../schema}/stream_telescope_schema.json | 0 .../tests/fixtures/schema}/table_a.json | 0 .../fixtures/schema}/table_b_1900-01-01.json | 0 .../fixtures/schema}/table_b_2000-01-01.json | 0 .../schema}/test_schema_2021-01-01.json | 0 .../fixtures/schema}/wos-2020-10-01.json | 0 .../google/tests}/test_bigquery.py | 144 +- .../google/tests}/test_gcs.py | 14 +- .../http_download.py | 6 +- .../jinja2_utils.py | 3 +- .../proc_utils.py | 26 - .../sandbox}/__init__.py | 0 observatory_platform/sandbox/ftp_server.py | 90 + observatory_platform/sandbox/http_server.py | 87 + .../sandbox/sandbox_environment.py | 429 ++++ observatory_platform/sandbox/sftp_server.py | 165 ++ observatory_platform/sandbox/test_utils.py | 613 +++++ .../sandbox/tests}/__init__.py | 0 
.../sandbox/tests/fixtures}/__init__.py | 0 .../sandbox/tests/fixtures/bad_dag.py | 3 + .../sandbox/tests/fixtures}/http_testfile.txt | 0 .../sandbox/tests/fixtures/people.csv | 3 + .../sandbox/tests/fixtures}/people.csv.gz | 0 .../sandbox/tests/fixtures/people.jsonl | 3 + .../sandbox/tests/fixtures/people_schema.json | 3 + .../sandbox/tests/test_ftp_server.py | 93 + .../sandbox/tests/test_http_server.py | 96 + .../sandbox/tests/test_sandbox_environment.py | 333 +++ .../sandbox/tests/test_sftp_server.py | 53 + .../sandbox/tests/test_test_utils.py | 346 +++ .../schema}/__init__.py | 0 .../schema/dataset_release.json | 86 + .../platform => observatory_platform}/sftp.py | 0 .../sql}/__init__.py | 0 .../sql/delete_records.sql.jinja2 | 0 .../sql/select_columns.sql.jinja2 | 0 .../sql/select_table_shard_dates.sql.jinja2 | 0 .../sql/upsert_records.sql.jinja2 | 0 .../tests}/__init__.py | 0 .../tests/fixtures}/__init__.py | 0 .../tests/fixtures}/find_replace.txt | 0 .../fixtures}/get_http_response_json.json | 0 .../get_http_response_xml_to_dict.xml | 0 .../tests/fixtures/http_testfile.txt | 3 + .../tests/fixtures}/http_testfile2.txt | 0 .../tests/fixtures}/load_csv.csv | 0 .../tests/fixtures}/load_csv_gz.csv.gz | 0 .../tests/fixtures}/load_jsonl.jsonl | 0 .../tests/fixtures}/load_jsonl_gz.jsonl.gz | 0 .../tests/fixtures}/test_hasher.txt | 0 .../tests/fixtures}/testzip.txt | 0 .../tests/fixtures}/testzip.txt.gz | 0 observatory_platform/tests/test_config.py | 58 + .../tests/test_dataset_api.py | 189 ++ .../tests}/test_files.py | 32 +- .../tests}/test_http_download.py | 24 +- .../tests}/test_jinja2_utils.py | 2 +- .../tests/test_proc_utils.py | 21 +- .../tests}/test_sftp.py | 2 +- .../tests}/test_url_utils.py | 36 +- .../url_utils.py | 4 +- pyproject.toml | 146 ++ setup.cfg | 67 +- strategy.ini | 70 - .../dags/hello_world_dag.py | 3 - .../my_workflows_project/dags/my_dag.py | 3 - .../workflows/__init__.py | 0 .../workflows/my_workflow.py | 3 - .../cli/my-workflows-project/requirements.sh | 3 - .../cli/my-workflows-project/requirements.txt | 3 - .../cli/my-workflows-project/setup.cfg | 3 - .../cli/my-workflows-project/setup.py | 3 - .../elastic/the-expanse-mappings.json | 3 - .../fixtures/schemas/ao-author-mappings.json | 3 - .../schemas/ao_author_2021-01-01.json | 3 - tests/fixtures/utils/main.tf | 3 - tests/fixtures/utils/test.csv | 3 - tests/observatory/__init__.py | 0 tests/observatory/api/__init__.py | 0 tests/observatory/api/client/__init__.py | 0 .../api/client/test_dataset_release.py | 78 - .../api/client/test_observatory_api.py | 295 --- tests/observatory/api/server/__init__.py | 0 tests/observatory/api/server/test_openapi.py | 53 - tests/observatory/api/server/test_orm.py | 203 -- tests/observatory/api/test_utils.py | 59 - tests/observatory/platform/__init__.py | 0 tests/observatory/platform/cli/__init__.py | 0 tests/observatory/platform/cli/test_cli.py | 727 ------ .../platform/cli/test_cli_functional.py | 474 ---- .../platform/cli/test_cli_utils.py | 63 - .../platform/cli/test_generate_command.py | 663 ------ .../platform/cli/test_platform_command.py | 114 - tests/observatory/platform/docker/__init__.py | 0 .../platform/docker/test_builder.py | 117 - .../platform/docker/test_compose_runner.py | 88 - .../platform/docker/test_platform_runner.py | 286 --- .../platform/terraform/__init__.py | 0 .../platform/terraform/test_terraform_api.py | 521 ----- .../terraform/test_terraform_builder.py | 219 -- tests/observatory/platform/test_api.py | 112 - tests/observatory/platform/test_config.py | 140 -- 
.../platform/test_observatory_config.py | 1131 --------- .../platform/test_observatory_environment.py | 857 ------- tests/observatory/platform/utils/__init__.py | 0 .../platform/utils/test_proc_utils.py | 55 - .../platform/workflows/__init__.py | 0 .../platform/workflows/test_vm_create.py | 259 --- .../platform/workflows/test_vm_destroy.py | 1207 ---------- .../platform/workflows/test_workflow.py | 429 ---- 252 files changed, 5031 insertions(+), 29215 deletions(-) delete mode 100755 bin/generate-api-client.sh delete mode 100644 docs/api/BigQueryBytesProcessed.md delete mode 100644 docs/api/DatasetRelease.md delete mode 100644 docs/api/ObservatoryApi.md delete mode 100644 docs/api/api_client.md delete mode 100644 docs/api/api_rest.rst delete mode 100644 docs/api/api_server.md delete mode 100644 docs/api/index.rst delete mode 100644 docs/api/openapi.yaml delete mode 100644 docs/graphics/workflow_flow.png delete mode 100644 docs/installation.md delete mode 100644 docs/tutorials/deploy_terraform.md delete mode 100644 docs/tutorials/index.rst delete mode 100644 docs/tutorials/observatory_dev.md delete mode 100644 docs/tutorials/workflow/intro.md delete mode 100644 docs/tutorials/workflow/step_by_step.md delete mode 100644 docs/tutorials/workflow/style.md delete mode 100644 docs/tutorials/workflow/workflow_class.md delete mode 100755 install.sh delete mode 100644 observatory-api/.dockerignore delete mode 100644 observatory-api/.gitignore delete mode 100644 observatory-api/.openapi-generator-ignore delete mode 100644 observatory-api/.openapi-generator/FILES delete mode 100644 observatory-api/.openapi-generator/VERSION delete mode 100644 observatory-api/README.md delete mode 100644 observatory-api/api-config.yaml delete mode 100644 observatory-api/observatory/api/cli.py delete mode 100644 observatory-api/observatory/api/client/__init__.py delete mode 100644 observatory-api/observatory/api/client/api/__init__.py delete mode 100644 observatory-api/observatory/api/client/api/observatory_api.py delete mode 100644 observatory-api/observatory/api/client/api_client.py delete mode 100644 observatory-api/observatory/api/client/configuration.py delete mode 100644 observatory-api/observatory/api/client/exceptions.py delete mode 100644 observatory-api/observatory/api/client/model/__init__.py delete mode 100644 observatory-api/observatory/api/client/model/dataset_release.py delete mode 100644 observatory-api/observatory/api/client/model_utils.py delete mode 100644 observatory-api/observatory/api/client/rest.py delete mode 100644 observatory-api/observatory/api/server/api.py delete mode 100644 observatory-api/observatory/api/server/app.py delete mode 100644 observatory-api/observatory/api/server/openapi.yaml.jinja2 delete mode 100644 observatory-api/observatory/api/server/openapi_renderer.py delete mode 100644 observatory-api/observatory/api/server/orm.py delete mode 100644 observatory-api/observatory/api/testing.py delete mode 100644 observatory-api/observatory/api/utils.py delete mode 100644 observatory-api/requirements.sh delete mode 100644 observatory-api/requirements.txt delete mode 100644 observatory-api/setup.cfg delete mode 100644 observatory-api/setup.py delete mode 100644 observatory-api/templates/README_common.mustache delete mode 100644 observatory-api/templates/README_onlypackage.mustache delete mode 100644 observatory-api/templates/api_doc.mustache delete mode 100644 observatory-api/templates/configuration.mustache delete mode 100644 observatory-api/templates/model_doc.mustache delete mode 100644 
observatory-api/templates/model_templates/classvars.mustache delete mode 100644 observatory-platform/README.md delete mode 100644 observatory-platform/observatory/platform/api.py delete mode 100644 observatory-platform/observatory/platform/cli/cli.py delete mode 100644 observatory-platform/observatory/platform/cli/cli_utils.py delete mode 100644 observatory-platform/observatory/platform/cli/generate_command.py delete mode 100644 observatory-platform/observatory/platform/cli/platform_command.py delete mode 100644 observatory-platform/observatory/platform/cli/terraform_command.py delete mode 100644 observatory-platform/observatory/platform/config-terraform.yaml.jinja2 delete mode 100644 observatory-platform/observatory/platform/config.yaml.jinja2 delete mode 100644 observatory-platform/observatory/platform/dags/load_dags_modules.py delete mode 100644 observatory-platform/observatory/platform/dags/load_workflows.py delete mode 100644 observatory-platform/observatory/platform/docker/Dockerfile.apiserver.jinja2 delete mode 100644 observatory-platform/observatory/platform/docker/Dockerfile.observatory.jinja2 delete mode 100644 observatory-platform/observatory/platform/docker/Dockerfile.package_install.jinja2 delete mode 100644 observatory-platform/observatory/platform/docker/builder.py delete mode 100644 observatory-platform/observatory/platform/docker/compose_runner.py delete mode 100644 observatory-platform/observatory/platform/docker/docker-compose.observatory.yml.jinja2 delete mode 100755 observatory-platform/observatory/platform/docker/entrypoint-airflow.sh.jinja2 delete mode 100644 observatory-platform/observatory/platform/docker/entrypoint-api.sh.jinja2 delete mode 100755 observatory-platform/observatory/platform/docker/entrypoint-root.sh delete mode 100644 observatory-platform/observatory/platform/docker/platform_runner.py delete mode 100644 observatory-platform/observatory/platform/observatory_config.py delete mode 100644 observatory-platform/observatory/platform/observatory_environment.py delete mode 100644 observatory-platform/observatory/platform/terraform/build.sh delete mode 100644 observatory-platform/observatory/platform/terraform/main.tf delete mode 100644 observatory-platform/observatory/platform/terraform/observatory-image.json.pkr.hcl delete mode 100644 observatory-platform/observatory/platform/terraform/outputs.tf delete mode 100644 observatory-platform/observatory/platform/terraform/secret/main.tf delete mode 100644 observatory-platform/observatory/platform/terraform/secret/outputs.tf delete mode 100644 observatory-platform/observatory/platform/terraform/secret/variables.tf delete mode 100755 observatory-platform/observatory/platform/terraform/startup.tpl.jinja2 delete mode 100644 observatory-platform/observatory/platform/terraform/terraform_api.py delete mode 100644 observatory-platform/observatory/platform/terraform/terraform_builder.py delete mode 100644 observatory-platform/observatory/platform/terraform/variables.tf delete mode 100644 observatory-platform/observatory/platform/terraform/versions.tf delete mode 100644 observatory-platform/observatory/platform/terraform/vm/main.tf delete mode 100644 observatory-platform/observatory/platform/terraform/vm/outputs.tf delete mode 100644 observatory-platform/observatory/platform/terraform/vm/variables.tf delete mode 100644 observatory-platform/observatory/platform/workflows/vm_workflow.py delete mode 100644 observatory-platform/observatory/platform/workflows/workflow.py delete mode 100644 observatory-platform/requirements.sh 
delete mode 100644 observatory-platform/requirements.txt delete mode 100644 observatory-platform/setup.cfg delete mode 100644 observatory-platform/setup.py rename {observatory-api/observatory/api => observatory_platform}/__init__.py (100%) rename {observatory-api/observatory/api/server => observatory_platform/airflow}/__init__.py (100%) rename {observatory-platform/observatory/platform => observatory_platform/airflow}/airflow.py (58%) create mode 100644 observatory_platform/airflow/release.py rename observatory-platform/observatory/platform/utils/dag_run_sensor.py => observatory_platform/airflow/sensors.py (73%) create mode 100644 observatory_platform/airflow/tasks.py rename {observatory-platform/observatory/platform => observatory_platform/airflow/tests}/__init__.py (100%) rename {observatory-platform/observatory/platform/cli => observatory_platform/airflow/tests/fixtures}/__init__.py (100%) rename {tests/fixtures/utils => observatory_platform/airflow/tests/fixtures}/bad_dag.py (100%) rename {tests/fixtures/utils => observatory_platform/airflow/tests/fixtures}/good_dag.py (100%) rename {tests/observatory/platform => observatory_platform/airflow/tests}/test_airflow.py (78%) create mode 100644 observatory_platform/airflow/tests/test_release.py rename tests/observatory/platform/utils/test_dag_run_sensor.py => observatory_platform/airflow/tests/test_sensors.py (72%) create mode 100644 observatory_platform/airflow/tests/test_workflow.py create mode 100644 observatory_platform/airflow/workflow.py rename {observatory-platform/observatory/platform => observatory_platform}/config.py (83%) create mode 100644 observatory_platform/dataset_api.py rename {observatory-platform/observatory/platform => observatory_platform}/files.py (99%) rename {observatory-platform/observatory/platform/dags => observatory_platform/google}/__init__.py (100%) rename {observatory-platform/observatory/platform => observatory_platform/google}/bigquery.py (87%) create mode 100644 observatory_platform/google/gcp.py rename {observatory-platform/observatory/platform => observatory_platform/google}/gcs.py (83%) create mode 100644 observatory_platform/google/gke.py rename {observatory-platform/observatory/platform/docker => observatory_platform/google/tests}/__init__.py (100%) rename {observatory-platform/observatory/platform/sql => observatory_platform/google/tests/fixtures}/__init__.py (100%) create mode 100644 observatory_platform/google/tests/fixtures/bad_dag.py rename {tests/fixtures/utils => observatory_platform/google/tests/fixtures}/people.csv (100%) rename {tests/fixtures/utils => observatory_platform/google/tests/fixtures}/people.jsonl (100%) rename {tests/fixtures/utils => observatory_platform/google/tests/fixtures}/people_extra.jsonl (100%) rename {tests/fixtures/utils => observatory_platform/google/tests/fixtures}/people_schema.json (100%) rename {tests/fixtures/schemas => observatory_platform/google/tests/fixtures/schema}/db_merge0.json (100%) rename {tests/fixtures/schemas => observatory_platform/google/tests/fixtures/schema}/db_merge1.json (100%) rename {tests/fixtures/schemas => observatory_platform/google/tests/fixtures/schema}/stream_telescope_file1.json (100%) rename {tests/fixtures/schemas => observatory_platform/google/tests/fixtures/schema}/stream_telescope_file2.json (100%) rename {tests/fixtures/schemas => observatory_platform/google/tests/fixtures/schema}/stream_telescope_schema.json (100%) rename {tests/fixtures/schemas => observatory_platform/google/tests/fixtures/schema}/table_a.json (100%) rename 
{tests/fixtures/schemas => observatory_platform/google/tests/fixtures/schema}/table_b_1900-01-01.json (100%) rename {tests/fixtures/schemas => observatory_platform/google/tests/fixtures/schema}/table_b_2000-01-01.json (100%) rename {tests/fixtures/schemas => observatory_platform/google/tests/fixtures/schema}/test_schema_2021-01-01.json (100%) rename {tests/fixtures/schemas => observatory_platform/google/tests/fixtures/schema}/wos-2020-10-01.json (100%) rename {tests/observatory/platform => observatory_platform/google/tests}/test_bigquery.py (90%) rename {tests/observatory/platform => observatory_platform/google/tests}/test_gcs.py (98%) rename {observatory-platform/observatory/platform/utils => observatory_platform}/http_download.py (98%) rename {observatory-platform/observatory/platform/utils => observatory_platform}/jinja2_utils.py (99%) rename {observatory-platform/observatory/platform/utils => observatory_platform}/proc_utils.py (54%) rename {observatory-platform/observatory/platform/terraform => observatory_platform/sandbox}/__init__.py (100%) create mode 100644 observatory_platform/sandbox/ftp_server.py create mode 100644 observatory_platform/sandbox/http_server.py create mode 100644 observatory_platform/sandbox/sandbox_environment.py create mode 100644 observatory_platform/sandbox/sftp_server.py create mode 100644 observatory_platform/sandbox/test_utils.py rename {observatory-platform/observatory/platform/utils => observatory_platform/sandbox/tests}/__init__.py (100%) rename {observatory-platform/observatory/platform/workflows => observatory_platform/sandbox/tests/fixtures}/__init__.py (100%) create mode 100644 observatory_platform/sandbox/tests/fixtures/bad_dag.py rename {tests/fixtures/utils => observatory_platform/sandbox/tests/fixtures}/http_testfile.txt (100%) create mode 100644 observatory_platform/sandbox/tests/fixtures/people.csv rename {tests/fixtures/utils => observatory_platform/sandbox/tests/fixtures}/people.csv.gz (100%) create mode 100644 observatory_platform/sandbox/tests/fixtures/people.jsonl create mode 100644 observatory_platform/sandbox/tests/fixtures/people_schema.json create mode 100644 observatory_platform/sandbox/tests/test_ftp_server.py create mode 100644 observatory_platform/sandbox/tests/test_http_server.py create mode 100644 observatory_platform/sandbox/tests/test_sandbox_environment.py create mode 100644 observatory_platform/sandbox/tests/test_sftp_server.py create mode 100644 observatory_platform/sandbox/tests/test_test_utils.py rename {tests => observatory_platform/schema}/__init__.py (100%) create mode 100644 observatory_platform/schema/dataset_release.json rename {observatory-platform/observatory/platform => observatory_platform}/sftp.py (100%) rename {tests/fixtures => observatory_platform/sql}/__init__.py (100%) rename {observatory-platform/observatory/platform => observatory_platform}/sql/delete_records.sql.jinja2 (100%) rename {observatory-platform/observatory/platform => observatory_platform}/sql/select_columns.sql.jinja2 (100%) rename {observatory-platform/observatory/platform => observatory_platform}/sql/select_table_shard_dates.sql.jinja2 (100%) rename {observatory-platform/observatory/platform => observatory_platform}/sql/upsert_records.sql.jinja2 (100%) rename {tests/fixtures/cli/my-workflows-project/my_workflows_project => observatory_platform/tests}/__init__.py (100%) rename {tests/fixtures/cli/my-workflows-project/my_workflows_project/dags => observatory_platform/tests/fixtures}/__init__.py (100%) rename {tests/fixtures/utils => 
observatory_platform/tests/fixtures}/find_replace.txt (100%) rename {tests/fixtures/utils => observatory_platform/tests/fixtures}/get_http_response_json.json (100%) rename {tests/fixtures/utils => observatory_platform/tests/fixtures}/get_http_response_xml_to_dict.xml (100%) create mode 100644 observatory_platform/tests/fixtures/http_testfile.txt rename {tests/fixtures/utils => observatory_platform/tests/fixtures}/http_testfile2.txt (100%) rename {tests/fixtures/elastic => observatory_platform/tests/fixtures}/load_csv.csv (100%) rename {tests/fixtures/elastic => observatory_platform/tests/fixtures}/load_csv_gz.csv.gz (100%) rename {tests/fixtures/elastic => observatory_platform/tests/fixtures}/load_jsonl.jsonl (100%) rename {tests/fixtures/elastic => observatory_platform/tests/fixtures}/load_jsonl_gz.jsonl.gz (100%) rename {tests/fixtures/utils => observatory_platform/tests/fixtures}/test_hasher.txt (100%) rename {tests/fixtures/utils => observatory_platform/tests/fixtures}/testzip.txt (100%) rename {tests/fixtures/utils => observatory_platform/tests/fixtures}/testzip.txt.gz (100%) create mode 100644 observatory_platform/tests/test_config.py create mode 100644 observatory_platform/tests/test_dataset_api.py rename {tests/observatory/platform => observatory_platform/tests}/test_files.py (92%) rename {tests/observatory/platform/utils => observatory_platform/tests}/test_http_download.py (90%) rename {tests/observatory/platform/utils => observatory_platform/tests}/test_jinja2_utils.py (94%) rename observatory-platform/observatory/platform/dags/dummy_telescope.py => observatory_platform/tests/test_proc_utils.py (52%) rename {tests/observatory/platform => observatory_platform/tests}/test_sftp.py (97%) rename {tests/observatory/platform/utils => observatory_platform/tests}/test_url_utils.py (91%) rename {observatory-platform/observatory/platform/utils => observatory_platform}/url_utils.py (100%) create mode 100644 pyproject.toml delete mode 100644 strategy.ini delete mode 100644 tests/fixtures/cli/my-workflows-project/my_workflows_project/dags/hello_world_dag.py delete mode 100644 tests/fixtures/cli/my-workflows-project/my_workflows_project/dags/my_dag.py delete mode 100644 tests/fixtures/cli/my-workflows-project/my_workflows_project/workflows/__init__.py delete mode 100644 tests/fixtures/cli/my-workflows-project/my_workflows_project/workflows/my_workflow.py delete mode 100644 tests/fixtures/cli/my-workflows-project/requirements.sh delete mode 100644 tests/fixtures/cli/my-workflows-project/requirements.txt delete mode 100644 tests/fixtures/cli/my-workflows-project/setup.cfg delete mode 100644 tests/fixtures/cli/my-workflows-project/setup.py delete mode 100644 tests/fixtures/elastic/the-expanse-mappings.json delete mode 100644 tests/fixtures/schemas/ao-author-mappings.json delete mode 100644 tests/fixtures/schemas/ao_author_2021-01-01.json delete mode 100644 tests/fixtures/utils/main.tf delete mode 100644 tests/fixtures/utils/test.csv delete mode 100644 tests/observatory/__init__.py delete mode 100644 tests/observatory/api/__init__.py delete mode 100644 tests/observatory/api/client/__init__.py delete mode 100644 tests/observatory/api/client/test_dataset_release.py delete mode 100644 tests/observatory/api/client/test_observatory_api.py delete mode 100644 tests/observatory/api/server/__init__.py delete mode 100644 tests/observatory/api/server/test_openapi.py delete mode 100644 tests/observatory/api/server/test_orm.py delete mode 100644 tests/observatory/api/test_utils.py delete mode 100644 
tests/observatory/platform/__init__.py delete mode 100644 tests/observatory/platform/cli/__init__.py delete mode 100644 tests/observatory/platform/cli/test_cli.py delete mode 100644 tests/observatory/platform/cli/test_cli_functional.py delete mode 100644 tests/observatory/platform/cli/test_cli_utils.py delete mode 100644 tests/observatory/platform/cli/test_generate_command.py delete mode 100644 tests/observatory/platform/cli/test_platform_command.py delete mode 100644 tests/observatory/platform/docker/__init__.py delete mode 100644 tests/observatory/platform/docker/test_builder.py delete mode 100644 tests/observatory/platform/docker/test_compose_runner.py delete mode 100644 tests/observatory/platform/docker/test_platform_runner.py delete mode 100644 tests/observatory/platform/terraform/__init__.py delete mode 100644 tests/observatory/platform/terraform/test_terraform_api.py delete mode 100644 tests/observatory/platform/terraform/test_terraform_builder.py delete mode 100644 tests/observatory/platform/test_api.py delete mode 100644 tests/observatory/platform/test_config.py delete mode 100644 tests/observatory/platform/test_observatory_config.py delete mode 100644 tests/observatory/platform/test_observatory_environment.py delete mode 100644 tests/observatory/platform/utils/__init__.py delete mode 100644 tests/observatory/platform/utils/test_proc_utils.py delete mode 100644 tests/observatory/platform/workflows/__init__.py delete mode 100644 tests/observatory/platform/workflows/test_vm_create.py delete mode 100644 tests/observatory/platform/workflows/test_vm_destroy.py delete mode 100644 tests/observatory/platform/workflows/test_workflow.py diff --git a/.coveragerc b/.coveragerc index cbe812ea2..580f23ff2 100644 --- a/.coveragerc +++ b/.coveragerc @@ -1,6 +1,6 @@ [run] branch=True -source=observatory +source=observatory_platform [report] exclude_lines = @@ -12,10 +12,4 @@ exclude_lines = @abstract ignore_errors = True omit = - tests/* - observatory-api/observatory/api/client/api_client.py - observatory-api/observatory/api/client/configuration.py - observatory-api/observatory/api/client/exceptions.py - observatory-api/observatory/api/client/model_utils.py - observatory-api/observatory/api/client/rest.py - observatory-platform/observatory/platform/airflow/* + */tests* diff --git a/.gitattributes b/.gitattributes index cac44f22c..e38728087 100644 --- a/.gitattributes +++ b/.gitattributes @@ -1,4 +1,5 @@ -tests/fixtures/** filter=lfs diff=lfs merge=lfs -text +**/fixtures/** filter=lfs diff=lfs merge=lfs -text docs/* linguist-documentation notebooks/* linguist-documentation *.csv filter=lfs diff=lfs merge=lfs -text +**/fixtures/**/* filter=lfs diff=lfs merge=lfs -text diff --git a/.github/workflows/publish-pypi.yml b/.github/workflows/publish-pypi.yml index 88bc90957..45d664cb9 100644 --- a/.github/workflows/publish-pypi.yml +++ b/.github/workflows/publish-pypi.yml @@ -10,31 +10,18 @@ jobs: runs-on: ubuntu-latest steps: - name: Checkout - uses: actions/checkout@v2 + uses: actions/checkout@v4 - name: Set up Python - uses: actions/setup-python@v2 + uses: actions/setup-python@v5 with: python-version: 3.10 - name: Build packages run: | - cd observatory-api - python3 setup.py sdist - - cd ../observatory-platform - cp ../README.md . 
- python3 setup.py sdist - - cd ../ - - - name: Publish observatory-api - if: github.event_name == 'push' && startsWith(github.ref, 'refs/tags') - uses: pypa/gh-action-pypi-publish@master - with: - user: __token__ - password: ${{ secrets.PYPI_PASSWORD }} - packages_dir: observatory-api/dist/ + python3 -m pip install --upgrade build + apt install python3.10-venv + python3 -m build --sdist - name: Publish observatory-platform if: github.event_name == 'push' && startsWith(github.ref, 'refs/tags') @@ -42,4 +29,4 @@ jobs: with: user: __token__ password: ${{ secrets.PYPI_PASSWORD }} - packages_dir: observatory-platform/dist/ + packages_dir: ./dist/ diff --git a/.github/workflows/unit-tests.yml b/.github/workflows/unit-tests.yml index a6ce6d8b7..95624ab55 100644 --- a/.github/workflows/unit-tests.yml +++ b/.github/workflows/unit-tests.yml @@ -14,26 +14,24 @@ jobs: steps: - name: Checkout ${{ matrix.python-version }} - uses: actions/checkout@v2 + uses: actions/checkout@v4 with: lfs: true - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v2 + uses: actions/setup-python@v5 with: python-version: ${{ matrix.python-version }} - name: Install dependencies run: | python -m pip install --upgrade pip - pip install -e observatory-api[tests] --constraint https://raw.githubusercontent.com/apache/airflow/constraints-2.6.3/constraints-no-providers-${{ matrix.python-version }}.txt - pip install -e observatory-platform[tests] --constraint https://raw.githubusercontent.com/apache/airflow/constraints-2.6.3/constraints-no-providers-${{ matrix.python-version }}.txt + pip install -e .[tests] --constraint https://raw.githubusercontent.com/apache/airflow/constraints-2.7.3/constraints-no-providers-${{ matrix.python-version }}.txt - name: Check licenses run: | # stop the build if there are licensing issues - liccheck --sfile strategy.ini --rfile observatory-api/requirements.txt --level CAUTIOUS --reporting liccheck-output.txt --no-deps - liccheck --sfile strategy.ini --rfile observatory-platform/requirements.txt --level CAUTIOUS --reporting liccheck-output.txt --no-deps + liccheck - name: Lint with flake8 run: | diff --git a/.gitignore b/.gitignore index a6ff141c6..f7711a028 100644 --- a/.gitignore +++ b/.gitignore @@ -22,3 +22,6 @@ ChangeLog /observatory-dags/observatory/dags/workflows/oapen_cloud_function.zip docs/schemas .env +/observatory-platform/build +/observatory-api/build +liccheck-output.txt \ No newline at end of file diff --git a/LICENSES_THIRD_PARTY b/LICENSES_THIRD_PARTY index 74f92841e..c7ce950dd 100644 --- a/LICENSES_THIRD_PARTY +++ b/LICENSES_THIRD_PARTY @@ -1,140 +1,42 @@ -Licenses of third party code included in this project are documented in this file as well as in individual source files. +# Code in observatory_platform/sandbox/sandbox_environment.py for running Airflow tasks comes from: +# * https://github.com/apache/airflow/blob/ffb472cf9e630bd70f51b74b0d0ea4ab98635572/airflow/cli/commands/task_command.py +# * https://github.com/apache/airflow/blob/master/docs/apache-airflow/best-practices.rst -The below files are included from https://github.com/keras-team/keras and have the following license: +# With the license: -- observatory_platform/utils/data_utils.py -- observatory_platform/utils/progbar_utils.py - -# COPYRIGHT -# -# All contributions by François Chollet: -# Copyright (c) 2015 - 2019, François Chollet. -# All rights reserved. -# -# All contributions by Google: -# Copyright (c) 2015 - 2019, Google, Inc. -# All rights reserved. 
-# -# All contributions by Microsoft: -# Copyright (c) 2017 - 2019, Microsoft, Inc. -# All rights reserved. -# -# All other contributions: -# Copyright (c) 2015 - 2019, the respective contributors. -# All rights reserved. -# -# Each contributor holds copyright over their respective contributions. -# The project versioning (Git) records all such contribution source information. -# -# LICENSE -# -# The MIT License (MIT) -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in all -# copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -# SOFTWARE. - -The below files are included from https://github.com/commoncrawl/cc-pyspark and have the following license: - -- observatory_platform/common_crawl/cc_fetcher.py - -# The MIT License (MIT) -# -# Copyright (c) 2017 Common Crawl -# Copyright (c) 2019 Curtin University +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at # -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: +# http://www.apache.org/licenses/LICENSE-2.0 # -# The above copyright notice and this permission notice shall be included in all -# copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -# SOFTWARE. 
- -The below file are included from https://github.com/fengwangPhysics/matplotlib-chord-diagram and have the following -license: - -- observatory_platform/analysis/charts/chord.py +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. -# MIT License -# -# Copyright (c) 2017 Feng Wang -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in all -# copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -# SOFTWARE. -The below files are included from https://github.com/sphinx-doc/sphinx/tree/master/sphinx/templates/quickstart and have -the following license: +# Code in observatory_platform/sandbox/sftp_server is based on: +# * https://github.com/rspivak/sftpserver/blob/master/src/sftpserver/__init__.py -- observatory-platform/observatory/platform/cli/templates/generate_project/sphinx-quickstart/conf.py_t -- observatory-platform/observatory/platform/cli/templates/generate_project/sphinx-quickstart/make.bat.new_t -- observatory-platform/observatory/platform/cli/templates/generate_project/sphinx-quickstart/Makefile.new_t -- observatory-platform/observatory/platform/cli/templates/generate_project/sphinx-quickstart/root_doc.rst_t +# With the license: -# License for Sphinx -# ================== -# -# Copyright (c) 2007-2021 by the Sphinx team (see AUTHORS file). -# All rights reserved. -# -# Redistribution and use in source and binary forms, with or without -# modification, are permitted provided that the following conditions are -# met: +# Copyright 2021-2024 Curtin University # -# * Redistributions of source code must retain the above copyright -# notice, this list of conditions and the following disclaimer. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at # -# * Redistributions in binary form must reproduce the above copyright -# notice, this list of conditions and the following disclaimer in the -# documentation and/or other materials provided with the distribution. +# http://www.apache.org/licenses/LICENSE-2.0 # -# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS -# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT -# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR -# A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT -# HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, -# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT -# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, -# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY -# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT -# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. \ No newline at end of file +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. \ No newline at end of file diff --git a/README.md b/README.md index cd5110999..05fa911f4 100644 --- a/README.md +++ b/README.md @@ -1,11 +1,11 @@ ![Observatory Platform](https://raw.githubusercontent.com/The-Academic-Observatory/observatory-platform/develop/logo.jpg) -The Observatory Platform is an environment for fetching, processing and analysing data to understand how well -universities operate as Open Knowledge Institutions. +The Observatory Platform is a set of dependencies used by the Curtin Open Knowledge Initiative (COKI) for running its +Airflow based workflows to fetch, process and analyse bibliometric datasets. -The Observatory Platform is built with Apache Airflow. The workflows for the project can be seen in the -[Academic Observatory Workflows](https://github.com/The-Academic-Observatory/academic-observatory-workflows) -and [OAeBU Workflows](https://github.com/The-Academic-Observatory/oaebu-workflows) projects. +The workflows for the project can be seen in at: +* [Academic Observatory Workflows](https://github.com/The-Academic-Observatory/academic-observatory-workflows) +* [OAeBU Workflows](https://github.com/The-Academic-Observatory/oaebu-workflows) [![License](https://img.shields.io/badge/License-Apache%202.0-blue.svg)](https://opensource.org/licenses/Apache-2.0) [![Python Version](https://img.shields.io/badge/python-3.10-blue)](https://img.shields.io/badge/python-3.10-blue) @@ -15,8 +15,17 @@ and [OAeBU Workflows](https://github.com/The-Academic-Observatory/oaebu-workflow [![codecov](https://codecov.io/gh/The-Academic-Observatory/observatory-platform/branch/develop/graph/badge.svg)](https://codecov.io/gh/The-Academic-Observatory/observatory-platform) [![DOI](https://zenodo.org/badge/227744539.svg)](https://zenodo.org/badge/latestdoi/227744539) -## Documentation -For more detailed documentation about the Observatory Platform see the Read the Docs website [https://observatory-platform.readthedocs.io](https://observatory-platform.readthedocs.io) +## Dependencies +Observatory Platform supports Python 3.10, Ubuntu Linux 22.04 and MacOS 10.14, on x86 architecture. 
+ +System dependencies: +* Python 3.10 +* Pip +* virtualenv +* Google Cloud SDK (optional): https://cloud.google.com/sdk/docs/install-sdk + +## Python Package Reference +See the Read the Docs website for documentation on the Python package [https://observatory-platform.readthedocs.io](https://observatory-platform.readthedocs.io) ## Dependent Repositories The Observatory Platform is a dependency for other repositories developed and maintained by [The Academic Observatory](https://github.com/The-Academic-Observatory): diff --git a/bin/generate-api-client.sh b/bin/generate-api-client.sh deleted file mode 100755 index 5a1c65a25..000000000 --- a/bin/generate-api-client.sh +++ /dev/null @@ -1,34 +0,0 @@ -#!/bin/bash - -export PATH=$PATH:~/bin/openapitools/ -export OPENAPI_GENERATOR_VERSION=6.1.0 - -if ! command -v openapi-generator-cli &>/dev/null; then - mkdir -p ~/bin/openapitools - curl https://raw.githubusercontent.com/OpenAPITools/openapi-generator/master/bin/utils/openapi-generator-cli.sh >~/bin/openapitools/openapi-generator-cli - chmod u+x ~/bin/openapitools/openapi-generator-cli - openapi-generator-cli version -fi - -if ! command -v observatory-api &>/dev/null; then - pip install -e observatory-api --constraint https://raw.githubusercontent.com/apache/airflow/constraints-2.6.3/constraints-no-providers-3.10.txt -fi - -# Directories -api_dir=observatory-api/observatory/api -server_dir=${api_dir}/server -client_dir=${api_dir}/client -docs_dir=docs/api - -# Generate OpenAPI specification -observatory-api generate-openapi-spec ${server_dir}/openapi.yaml.jinja2 observatory-api/openapi.yaml --api-client -cp observatory-api/openapi.yaml docs/api/openapi.yaml - -# Generate OpenAPI Python client -openapi-generator-cli generate -i observatory-api/openapi.yaml -g python -c observatory-api/api-config.yaml -t observatory-api/templates/ -o observatory-api - -# Massage files into correct directory -mv ${api_dir}/client_README.md ${docs_dir}/api_client.md -cp -rn ${client_dir}/test/* tests/observatory/api/client/ -mv ${client_dir}/docs/* ${docs_dir} -rm -r ${client_dir}/test/ ${client_dir}/docs/ ${client_dir}/apis/ ${client_dir}/models/ diff --git a/docs/api/BigQueryBytesProcessed.md b/docs/api/BigQueryBytesProcessed.md deleted file mode 100644 index 098bc34e3..000000000 --- a/docs/api/BigQueryBytesProcessed.md +++ /dev/null @@ -1,54 +0,0 @@ -# BigQueryBytesProcessed - -## Properties -
-Name | Type | Description | Notes
------ | ----- | ----------- | -----
-id | int |  | [optional]
-project | str |  | [optional]
-total | int |  | [optional]
-created | datetime |  | [optional] [readonly]
-modified | datetime |  | [optional] [readonly]
- diff --git a/docs/api/DatasetRelease.md b/docs/api/DatasetRelease.md deleted file mode 100644 index 9bfa7fce0..000000000 --- a/docs/api/DatasetRelease.md +++ /dev/null @@ -1,114 +0,0 @@ -# DatasetRelease - -## Properties -
-Name | Type | Description | Notes
------ | ----- | ----------- | -----
-id | int |  | [optional]
-dag_id | str |  | [optional]
-dataset_id | str |  | [optional]
-dag_run_id | str, none_type |  | [optional]
-data_interval_start | datetime, none_type |  | [optional]
-data_interval_end | datetime, none_type |  | [optional]
-snapshot_date | datetime, none_type |  | [optional]
-partition_date | datetime, none_type |  | [optional]
-changefile_start_date | datetime, none_type |  | [optional]
-changefile_end_date | datetime, none_type |  | [optional]
-sequence_start | int, none_type |  | [optional]
-sequence_end | int, none_type |  | [optional]
-created | datetime |  | [optional] [readonly]
-modified | datetime |  | [optional] [readonly]
-extra | bool, date, datetime, dict, float, int, list, str, none_type |  | [optional]
- diff --git a/docs/api/ObservatoryApi.md b/docs/api/ObservatoryApi.md deleted file mode 100644 index 422a9339c..000000000 --- a/docs/api/ObservatoryApi.md +++ /dev/null @@ -1,661 +0,0 @@ -# ObservatoryApi - -All URIs are relative to *https://localhost:5002* - -
-Method | HTTP request | Description
------- | ------------ | -----------
-delete_dataset_release | DELETE /v1/dataset_release | delete a DatasetRelease
-get_dataset_release | GET /v1/dataset_release | get a DatasetRelease
-get_dataset_releases | GET /v1/dataset_releases | Get a list of DatasetRelease objects
-post_dataset_release | POST /v1/dataset_release | create a DatasetRelease
-put_dataset_release | PUT /v1/dataset_release | create or update a DatasetRelease
- -## **delete_dataset_release** -> delete_dataset_release(id) - -delete a DatasetRelease - -Delete a DatasetRelease by passing it's id. - -### Example - -* Api Key Authentication (api_key): -```python -import time -import observatory.api.client -from observatory.api.client.api import observatory_api -from pprint import pprint -# Defining the host is optional and defaults to https://localhost:5002 -# See configuration.py for a list of all supported configuration parameters. -configuration = observatory.api.client.Configuration( - host = "https://localhost:5002" -) - -# The client must configure the authentication and authorization parameters -# in accordance with the API server security policy. -# Examples for each auth method are provided below, use the example that -# satisfies your auth use case. - -# Configure API key authorization: api_key -configuration.api_key['api_key'] = 'YOUR_API_KEY' - -# Uncomment below to setup prefix (e.g. Bearer) for API key, if needed -# configuration.api_key_prefix['api_key'] = 'Bearer' - -# Enter a context with an instance of the API client -with observatory.api.client.ApiClient(configuration) as api_client: - # Create an instance of the API class - api_instance = observatory_api.ObservatoryApi(api_client) - id = 1 # int | DatasetRelease id - - # example passing only required values which don't have defaults set - try: - # delete a DatasetRelease - api_instance.delete_dataset_release(id) - except observatory.api.client.ApiException as e: - print("Exception when calling ObservatoryApi->delete_dataset_release: %s\n" % e) -``` - - -### Parameters - - -
-Name | Type | Description | Notes
------ | ----- | ----------- | -----
-id | int | DatasetRelease id |
- - -### Return type - -void (empty response body) - -### Authorization - -[api_key](ObservatoryApi.html#api_key) - -### HTTP request headers - - - **Content-Type**: Not defined - - **Accept**: Not defined - - -### HTTP response details -
-Status code | Description | Response headers
------------ | ----------- | ----------------
-200 | DatasetRelease deleted | -
- -## **get_dataset_release** -> DatasetRelease get_dataset_release(id) - -get a DatasetRelease - -Get the details of a DatasetRelease by passing it's id. - -### Example - -* Api Key Authentication (api_key): -```python -import time -import observatory.api.client -from observatory.api.client.api import observatory_api -from observatory.api.client.model.dataset_release import DatasetRelease -from pprint import pprint -# Defining the host is optional and defaults to https://localhost:5002 -# See configuration.py for a list of all supported configuration parameters. -configuration = observatory.api.client.Configuration( - host = "https://localhost:5002" -) - -# The client must configure the authentication and authorization parameters -# in accordance with the API server security policy. -# Examples for each auth method are provided below, use the example that -# satisfies your auth use case. - -# Configure API key authorization: api_key -configuration.api_key['api_key'] = 'YOUR_API_KEY' - -# Uncomment below to setup prefix (e.g. Bearer) for API key, if needed -# configuration.api_key_prefix['api_key'] = 'Bearer' - -# Enter a context with an instance of the API client -with observatory.api.client.ApiClient(configuration) as api_client: - # Create an instance of the API class - api_instance = observatory_api.ObservatoryApi(api_client) - id = 1 # int | DatasetRelease id - - # example passing only required values which don't have defaults set - try: - # get a DatasetRelease - api_response = api_instance.get_dataset_release(id) - pprint(api_response) - except observatory.api.client.ApiException as e: - print("Exception when calling ObservatoryApi->get_dataset_release: %s\n" % e) -``` - - -### Parameters - - -
-Name | Type | Description | Notes
------ | ----- | ----------- | -----
-id | int | DatasetRelease id |
- - -### Return type - -[**DatasetRelease**](DatasetRelease.html) - -### Authorization - -[api_key](ObservatoryApi.html#api_key) - -### HTTP request headers - - - **Content-Type**: Not defined - - **Accept**: application/json - - -### HTTP response details -
-Status code | Description | Response headers
------------ | ----------- | ----------------
-200 | the fetched DatasetRelease | -
-400 | bad input parameter | -
- -## **get_dataset_releases** -> [DatasetRelease] get_dataset_releases() - -Get a list of DatasetRelease objects - -Get a list of DatasetRelease objects - -### Example - -* Api Key Authentication (api_key): -```python -import time -import observatory.api.client -from observatory.api.client.api import observatory_api -from observatory.api.client.model.dataset_release import DatasetRelease -from pprint import pprint -# Defining the host is optional and defaults to https://localhost:5002 -# See configuration.py for a list of all supported configuration parameters. -configuration = observatory.api.client.Configuration( - host = "https://localhost:5002" -) - -# The client must configure the authentication and authorization parameters -# in accordance with the API server security policy. -# Examples for each auth method are provided below, use the example that -# satisfies your auth use case. - -# Configure API key authorization: api_key -configuration.api_key['api_key'] = 'YOUR_API_KEY' - -# Uncomment below to setup prefix (e.g. Bearer) for API key, if needed -# configuration.api_key_prefix['api_key'] = 'Bearer' - -# Enter a context with an instance of the API client -with observatory.api.client.ApiClient(configuration) as api_client: - # Create an instance of the API class - api_instance = observatory_api.ObservatoryApi(api_client) - dag_id = "dag_id_example" # str | the dag_id to fetch release info for (optional) - dataset_id = "dataset_id_example" # str | the dataset_id to fetch release info for (optional) - - # example passing only required values which don't have defaults set - # and optional values - try: - # Get a list of DatasetRelease objects - api_response = api_instance.get_dataset_releases(dag_id=dag_id, dataset_id=dataset_id) - pprint(api_response) - except observatory.api.client.ApiException as e: - print("Exception when calling ObservatoryApi->get_dataset_releases: %s\n" % e) -``` - - -### Parameters - - -
-Name | Type | Description | Notes
------ | ----- | ----------- | -----
-dag_id | str | the dag_id to fetch release info for | [optional]
-dataset_id | str | the dataset_id to fetch release info for | [optional]
- - -### Return type - -[**[DatasetRelease]**](DatasetRelease.html) - -### Authorization - -[api_key](ObservatoryApi.html#api_key) - -### HTTP request headers - - - **Content-Type**: Not defined - - **Accept**: application/json - - -### HTTP response details -
-Status code | Description | Response headers
------------ | ----------- | ----------------
-200 | a list of DatasetRelease objects | -
-400 | bad input parameter | -
- -## **post_dataset_release** -> DatasetRelease post_dataset_release(body) - -create a DatasetRelease - -Create a DatasetRelease by passing a DatasetRelease object, without an id. - -### Example - -* Api Key Authentication (api_key): -```python -import time -import observatory.api.client -from observatory.api.client.api import observatory_api -from observatory.api.client.model.dataset_release import DatasetRelease -from pprint import pprint -# Defining the host is optional and defaults to https://localhost:5002 -# See configuration.py for a list of all supported configuration parameters. -configuration = observatory.api.client.Configuration( - host = "https://localhost:5002" -) - -# The client must configure the authentication and authorization parameters -# in accordance with the API server security policy. -# Examples for each auth method are provided below, use the example that -# satisfies your auth use case. - -# Configure API key authorization: api_key -configuration.api_key['api_key'] = 'YOUR_API_KEY' - -# Uncomment below to setup prefix (e.g. Bearer) for API key, if needed -# configuration.api_key_prefix['api_key'] = 'Bearer' - -# Enter a context with an instance of the API client -with observatory.api.client.ApiClient(configuration) as api_client: - # Create an instance of the API class - api_instance = observatory_api.ObservatoryApi(api_client) - body = DatasetRelease( - id=1, - dag_id="doi_workflow", - dataset_id="doi", - dag_run_id="YYYY-MM-DDTHH:mm:ss.ssssss", - data_interval_start=dateutil_parser('2020-01-02T20:01:05Z'), - data_interval_end=dateutil_parser('2020-01-02T20:01:05Z'), - snapshot_date=dateutil_parser('2020-01-02T20:01:05Z'), - partition_date=dateutil_parser('2020-01-02T20:01:05Z'), - changefile_start_date=dateutil_parser('2020-01-02T20:01:05Z'), - changefile_end_date=dateutil_parser('2020-01-02T20:01:05Z'), - sequence_start=1, - sequence_end=3, - extra={}, - ) # DatasetRelease | DatasetRelease to create - - # example passing only required values which don't have defaults set - try: - # create a DatasetRelease - api_response = api_instance.post_dataset_release(body) - pprint(api_response) - except observatory.api.client.ApiException as e: - print("Exception when calling ObservatoryApi->post_dataset_release: %s\n" % e) -``` - - -### Parameters - - -
-| Name | Type | Description | Notes |
-| ------------- | ------------- | ------------- | ------------- |
-| **body** | [**DatasetRelease**](DatasetRelease.html) | DatasetRelease to create | |
- - -### Return type - -[**DatasetRelease**](DatasetRelease.html) - -### Authorization - -[api_key](ObservatoryApi.html#api_key) - -### HTTP request headers - - - **Content-Type**: application/json - - **Accept**: application/json - - -### HTTP response details -
-| Status code | Description | Response headers |
-| ------------- | ------------- | ------------- |
-| **201** | DatasetRelease created, returning the created object with an id | - |
- -## **put_dataset_release** -> DatasetRelease put_dataset_release(body) - -create or update a DatasetRelease - -Create a DatasetRelease by passing a DatasetRelease object, without an id. Update an existing DatasetRelease by passing a DatasetRelease object with an id. - -### Example - -* Api Key Authentication (api_key): -```python -import time -import observatory.api.client -from observatory.api.client.api import observatory_api -from observatory.api.client.model.dataset_release import DatasetRelease -from pprint import pprint -# Defining the host is optional and defaults to https://localhost:5002 -# See configuration.py for a list of all supported configuration parameters. -configuration = observatory.api.client.Configuration( - host = "https://localhost:5002" -) - -# The client must configure the authentication and authorization parameters -# in accordance with the API server security policy. -# Examples for each auth method are provided below, use the example that -# satisfies your auth use case. - -# Configure API key authorization: api_key -configuration.api_key['api_key'] = 'YOUR_API_KEY' - -# Uncomment below to setup prefix (e.g. Bearer) for API key, if needed -# configuration.api_key_prefix['api_key'] = 'Bearer' - -# Enter a context with an instance of the API client -with observatory.api.client.ApiClient(configuration) as api_client: - # Create an instance of the API class - api_instance = observatory_api.ObservatoryApi(api_client) - body = DatasetRelease( - id=1, - dag_id="doi_workflow", - dataset_id="doi", - dag_run_id="YYYY-MM-DDTHH:mm:ss.ssssss", - data_interval_start=dateutil_parser('2020-01-02T20:01:05Z'), - data_interval_end=dateutil_parser('2020-01-02T20:01:05Z'), - snapshot_date=dateutil_parser('2020-01-02T20:01:05Z'), - partition_date=dateutil_parser('2020-01-02T20:01:05Z'), - changefile_start_date=dateutil_parser('2020-01-02T20:01:05Z'), - changefile_end_date=dateutil_parser('2020-01-02T20:01:05Z'), - sequence_start=1, - sequence_end=3, - extra={}, - ) # DatasetRelease | DatasetRelease to create or update - - # example passing only required values which don't have defaults set - try: - # create or update a DatasetRelease - api_response = api_instance.put_dataset_release(body) - pprint(api_response) - except observatory.api.client.ApiException as e: - print("Exception when calling ObservatoryApi->put_dataset_release: %s\n" % e) -``` - - -### Parameters - - -
-| Name | Type | Description | Notes |
-| ------------- | ------------- | ------------- | ------------- |
-| **body** | [**DatasetRelease**](DatasetRelease.html) | DatasetRelease to create or update | |
- - -### Return type - -[**DatasetRelease**](DatasetRelease.html) - -### Authorization - -[api_key](ObservatoryApi.html#api_key) - -### HTTP request headers - - - **Content-Type**: application/json - - **Accept**: application/json - - -### HTTP response details -
-| Status code | Description | Response headers |
-| ------------- | ------------- | ------------- |
-| **200** | DatasetRelease updated | - |
-| **201** | DatasetRelease created, returning the created object with an id | - |
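
The endpoint examples above each make a single call. As a combined illustration (not part of the generated client docs), the sketch below lists the releases for a workflow and picks the most recent snapshot. The host, API key, `dag_id` and `dataset_id` values are placeholders, and it assumes the generated `DatasetRelease` model exposes its fields (such as `snapshot_date`) as attributes, as the examples above suggest.

```python
import observatory.api.client
from observatory.api.client.api import observatory_api

# Placeholder host and API key: substitute your own values.
configuration = observatory.api.client.Configuration(host="https://localhost:5002")
configuration.api_key["api_key"] = "YOUR_API_KEY"

with observatory.api.client.ApiClient(configuration) as api_client:
    api_instance = observatory_api.ObservatoryApi(api_client)
    try:
        # List the releases for a given workflow and dataset (placeholder ids).
        releases = api_instance.get_dataset_releases(dag_id="doi_workflow", dataset_id="doi")
        # Pick the most recent snapshot, ignoring releases without a snapshot_date.
        latest = max(
            (r for r in releases if r.snapshot_date is not None),
            key=lambda r: r.snapshot_date,
            default=None,
        )
        print(latest)
    except observatory.api.client.ApiException as e:
        print("Exception when calling ObservatoryApi->get_dataset_releases: %s\n" % e)
```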
- diff --git a/docs/api/api_client.md b/docs/api/api_client.md deleted file mode 100644 index 72cab128d..000000000 --- a/docs/api/api_client.md +++ /dev/null @@ -1,146 +0,0 @@ -# Python API Client -The REST API for managing and accessing data from the Observatory Platform. - -The `observatory.api.client` package is automatically generated by the [OpenAPI Generator](https://openapi-generator.tech) project: - -- API version: 1.0.0 -- Package version: 1.0.0 - -## Requirements -Python >= 3.10 - -## Installation & Usage -To install the package with PyPI: -```bash -pip install observatory-api -``` - -To install the package from source: -```bash -git clone https://github.com/The-Academic-Observatory/observatory-platform.git -cd observatory-platform -pip install -e observatory-api --constraint https://raw.githubusercontent.com/apache/airflow/constraints-2.6.3/constraints-no-providers-3.10.txt -``` - -## Getting Started -In your own code, to use this library to connect and interact with client, -you can run the following: - -```python - -import time -import observatory.api.client -from pprint import pprint -from observatory.api.client.api import observatory_api -from observatory.api.client.model.dataset_release import DatasetRelease -# Defining the host is optional and defaults to https://localhost:5002 -# See configuration.py for a list of all supported configuration parameters. -configuration = observatory.api.client.Configuration( - host = "https://localhost:5002" -) - -# The client must configure the authentication and authorization parameters -# in accordance with the API server security policy. -# Examples for each auth method are provided below, use the example that -# satisfies your auth use case. - -# Configure API key authorization: api_key -configuration.api_key['api_key'] = 'YOUR_API_KEY' - -# Uncomment below to setup prefix (e.g. Bearer) for API key, if needed -# configuration.api_key_prefix['api_key'] = 'Bearer' - - -# Enter a context with an instance of the API client -with observatory.api.client.ApiClient(configuration) as api_client: - # Create an instance of the API class - api_instance = observatory_api.ObservatoryApi(api_client) - id = 1 # int | DatasetRelease id - - try: - # delete a DatasetRelease - api_instance.delete_dataset_release(id) - except observatory.api.client.ApiException as e: - print("Exception when calling ObservatoryApi->delete_dataset_release: %s\n" % e) -``` - -## Documentation for API Endpoints - -All URIs are relative to *https://localhost:5002* - -```eval_rst -.. toctree:: - :maxdepth: 1 - - ObservatoryApi -``` - -
-| Class | Method | HTTP request | Description |
-| ------------- | ------------- | ------------- | ------------- |
-| ObservatoryApi | delete_dataset_release | DELETE /v1/dataset_release | delete a DatasetRelease |
-| ObservatoryApi | get_dataset_release | GET /v1/dataset_release | get a DatasetRelease |
-| ObservatoryApi | get_dataset_releases | GET /v1/dataset_releases | Get a list of DatasetRelease objects |
-| ObservatoryApi | post_dataset_release | POST /v1/dataset_release | create a DatasetRelease |
-| ObservatoryApi | put_dataset_release | PUT /v1/dataset_release | create or update a DatasetRelease |
- -## Documentation For Models -```eval_rst -.. toctree:: - :maxdepth: 1 - - DatasetRelease - -``` - -## Documentation For Authorization - - -## api_key - -- **Type**: API key -- **API key parameter name**: key -- **Location**: URL query string - diff --git a/docs/api/api_rest.rst b/docs/api/api_rest.rst deleted file mode 100644 index 5466fa4d2..000000000 --- a/docs/api/api_rest.rst +++ /dev/null @@ -1,5 +0,0 @@ -REST API Specification ----------------------------- -The REST API specification for managing and accessing data from the Observatory Platform. - -.. openapi:httpdomain:: openapi.yaml \ No newline at end of file diff --git a/docs/api/api_server.md b/docs/api/api_server.md deleted file mode 100644 index df4039944..000000000 --- a/docs/api/api_server.md +++ /dev/null @@ -1,18 +0,0 @@ -# API Server -This is a tutorial for deploying the Observatory Platform to Google Cloud with Terraform. - -You should have [installed the Observatory Platform](installation.html) before following this tutorial. - -## How to deploy -The API is part of the Terraform configuration and is deployed when the Terraform configuration is applied. -See the 'Observatory Terraform Environment' section for more information on how to do this. -The URL of the API corresponds to the URL in the 'endpoints' Cloud Run service. - -## Generating an API key -To generate an API key, in the Google Cloud Console go to 'APIs & Services' -> 'Credentials' and click 'Create Credentials' in the top bar. -From the dropdown menu select, 'API key'. The API key is then generated for you. -For security reasons, restrict the API key to only the 'Observatory API' service. - -## Adding a user to the API -From the Google Cloud Endpoints page, click on "Observatory API" and then click "ADD MEMBER" at the far right of the -screen. Add the email address of the user that you want to add and give them the Role "Service Consumer". \ No newline at end of file diff --git a/docs/api/index.rst b/docs/api/index.rst deleted file mode 100644 index f99815d46..000000000 --- a/docs/api/index.rst +++ /dev/null @@ -1,8 +0,0 @@ -Observatory Platform API ----------------------------- -.. toctree:: - :maxdepth: 1 - - api_client - api_rest - api_server diff --git a/docs/api/openapi.yaml b/docs/api/openapi.yaml deleted file mode 100644 index 40171a282..000000000 --- a/docs/api/openapi.yaml +++ /dev/null @@ -1,229 +0,0 @@ -swagger: '2.0' -info: - title: Observatory API - description: | - The REST API for managing and accessing data from the Observatory Platform. - version: 1.0.0 - contact: - email: agent@observatory.academy - license: - name: Apache 2.0 - url: http://www.apache.org/licenses/LICENSE-2.0.html - - -host: localhost:5002 -schemes: - - https -produces: - - application/json -securityDefinitions: - # This section configures basic authentication with an API key. - api_key: - type: "apiKey" - name: "key" - in: "query" -security: - - api_key: [] - - -tags: -- name: Observatory - description: the Observatory API - - - -paths: - /v1/dataset_release: - get: - tags: - - Observatory - summary: get a DatasetRelease - operationId: get_dataset_release - description: | - Get the details of a DatasetRelease by passing it's id. 
- produces: - - application/json - parameters: - - in: query - name: id - description: DatasetRelease id - required: true - type: integer - responses: - 200: - description: the fetched DatasetRelease - schema: - $ref: '#/definitions/DatasetRelease' - 400: - description: bad input parameter - post: - tags: - - Observatory - summary: create a DatasetRelease - operationId: post_dataset_release - description: | - Create a DatasetRelease by passing a DatasetRelease object, without an id. - consumes: - - application/json - produces: - - application/json - parameters: - - in: body - name: body - description: DatasetRelease to create - required: true - schema: - $ref: '#/definitions/DatasetRelease' - responses: - 201: - description: DatasetRelease created, returning the created object with an id - schema: - $ref: '#/definitions/DatasetRelease' - put: - tags: - - Observatory - summary: create or update a DatasetRelease - operationId: put_dataset_release - description: | - Create a DatasetRelease by passing a DatasetRelease object, without an id. Update an existing DatasetRelease by - passing a DatasetRelease object with an id. - consumes: - - application/json - produces: - - application/json - parameters: - - in: body - name: body - description: DatasetRelease to create or update - required: true - schema: - $ref: '#/definitions/DatasetRelease' - responses: - 200: - description: DatasetRelease updated - schema: - $ref: '#/definitions/DatasetRelease' - 201: - description: DatasetRelease created, returning the created object with an id - schema: - $ref: '#/definitions/DatasetRelease' - delete: - tags: - - Observatory - summary: delete a DatasetRelease - operationId: delete_dataset_release - description: | - Delete a DatasetRelease by passing it's id. - consumes: - - application/json - produces: - - application/json - parameters: - - in: query - name: id - description: DatasetRelease id - required: true - type: integer - responses: - 200: - description: DatasetRelease deleted - - /v1/dataset_releases: - get: - tags: - - Observatory - summary: Get a list of DatasetRelease objects - operationId: get_dataset_releases - description: | - Get a list of DatasetRelease objects - produces: - - application/json - parameters: - - in: query - name: dag_id - description: the dag_id to fetch release info for - required: false - type: string - - in: query - name: dataset_id - description: the dataset_id to fetch release info for - required: false - type: string - responses: - 200: - description: a list of DatasetRelease objects - schema: - type: array - items: - $ref: '#/definitions/DatasetRelease' - 400: - description: bad input parameter - -definitions: - DatasetRelease: - type: object - properties: - id: - type: integer - dag_id: - type: string - example: "doi_workflow" - dataset_id: - type: string - example: "doi" - dag_run_id: - type: string - example: "YYYY-MM-DDTHH:mm:ss.ssssss" - x-nullable: true - data_interval_start: - type: string - format: date-time - example: 2020-01-02T20:01:05Z - x-nullable: true - data_interval_end: - type: string - format: date-time - example: 2020-01-02T20:01:05Z - x-nullable: true - snapshot_date: - type: string - format: date-time - example: 2020-01-02T20:01:05Z - x-nullable: true - partition_date: - type: string - format: date-time - example: 2020-01-02T20:01:05Z - x-nullable: true - changefile_start_date: - type: string - format: date-time - example: 2020-01-02T20:01:05Z - x-nullable: true - changefile_end_date: - type: string - format: date-time - example: 2020-01-02T20:01:05Z 
- x-nullable: true - sequence_start: - type: integer - example: 1 - x-nullable: true - sequence_end: - type: integer - example: 3 - x-nullable: true - created: - type: string - format: date-time - readOnly: true - modified: - type: string - format: date-time - readOnly: true - extra: - type: object - example: {'view-id': '830'} - minLength: 1 - maxLength: 512 - x-nullable: true \ No newline at end of file diff --git a/docs/conf.py b/docs/conf.py index d5fe78b7b..6c6269174 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -33,19 +33,17 @@ # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom # ones. extensions = [ - "pbr.sphinxext", "sphinx_rtd_theme", "sphinx.ext.viewcode", "sphinx.ext.intersphinx", "autoapi.extension", "recommonmark", - "sphinxcontrib.openapi", ] # Auto API settings: https://github.com/readthedocs/sphinx-autoapi autoapi_type = "python" -autoapi_dirs = ["../observatory-api", "../observatory-platform"] -autoapi_ignore = ["*.eggs*"] +autoapi_dirs = ["../observatory_platform"] +autoapi_ignore = ["*.eggs*", "**/tests/**"] autoapi_add_toctree_entry = True autoapi_python_use_implicit_namespaces = False diff --git a/docs/graphics/workflow_flow.png b/docs/graphics/workflow_flow.png deleted file mode 100644 index 9a1a755c9524d1e68b6d0c7bebd3f6cb86fe2a79..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 4313 zcmaKw=Q|r-*nnfjXw1fly=u2ctr{UAW{udYHCk1*VwTv{CL%>uQG}w}U(+hBowO)g zMU85WidB21G_harAMn2Cx}Ni#`+PX(%X6;lxlg98wJAHBAR7PxU^h22wgUj@PS4j$ zEWqO3{|&wFx=V3+1Ys}J~EC=^&*SI&Afe{guUTOZA1aBzm$hc^zDy&21^Mqxy!| zrD)RLTqqh4lPF~$C;9nI9(c&GP0wr}qi`Q)S;4CUdXhgMrYiiLr27JzsXl)<+$z^% zKT{wbnX&cl6}a3~=Pe%uZgOawv)CAtH6(Sbnc)_;@&%OJ$VGcf1=OX|z+!)#<|UAZ zLp+(+I_DgvPhjFXdwFx$(Ac|FKvLN%J*aaQa7~eqKj*GK$?Lm9=WN2qT{AyKccW@B zav!eu+RGSrt3uz8t)OJmUNwj9KBF+RFtV#WMkg4~o=}5iNO?|k1E3h*C7C8_kRpwA z-g~83MS5ORfjmi>1eCk<$4pj5Tn1IutgszlZ~MrUjw^PPD37)KI64cY+O;5)U$P-LB5p2teT znv~bSL(i!A7${fNz~(m7bgt8gE(DD#NhD047D~ivN&WRr1h(jqBpS^MO;KQ`D1P~~ zev%Q%OqzHx37c1WL^u4ogTQs!&@RRM0NL0#dT;B1i9}iHe1!8V#VpD~zo=i@p~ZTt zzX5AAHDACqEz@byx8((t1NWGqFpACvnzCBJ1cO0n?-a`bBCRXm(12`29_mXh%fz=d zpAauBWDCBK0AbIxt9`YkKDJ5kOzOp2Dl@y}q%^V$hPSm$02v6)y%&=dztT!I1aF~Z9;N;Mah$U2hdru53`8X zQDJAtBet3dc}puIUG~%O@-FF$u)1AQ13#@Q<(E@#<4!FL?mdvK-oB`bkhA-qzINiN zhJlf~s-Y;FjGAPcB;n*f*$(`9J4dg0)i*D|_%@UEHl77ciY2?5mxX)_m3%bO>0Q*$ zA+5i7(Np=Y*S{cAkC}`OLd_MBW7MES2dVmR=8U+ROPBR3`i{R?Wa(xiE%k-&+Nu$^hj^- zavBa=xlf-?AXDm@v#d2RQ+n957o?cia3DKLp+JVcvsaN1RMRg0V_`N$m|)B@wSY>& ze`U7SS){9G&Qaib+@#Jpt5}a^3~P#k3N39#RG>FFG0whU9wu91~c< zE=KTMATWks!n2Qv!s^3~y~?zyOyIMvtq4@&-b$t&8(@OxR-FNxW~C0q#ytVE7y&81 zOmK%z7jaHOV!!1U7Ze479-PxJr` z7%C#Y9jI_hHT%VM6ALP=Aw|^1`&>EVgv56=3C11ryUiF>MmE0y41LJYt6VER3df@~*p0SbUKxQ>RpGP8< zF9%oBg%U*GOWf6AZa}l)k;MZaO!{8x#XfAQ5?8<_EbKUE@+l zY$gEqGpC&-IQWVOho0#MC`lz#q1VIk7gc_AXpt~H9Rd-|mf&PaG1e4btTb2Vaa406LZt)${c>a*P+`y@}#X7j) zD0idVUt68g;6E;4SzeVbq~-kW*J$R#S6f2r{zw#F1>2>i9^QE&CtO~?S#i{slU3`i z0cTrYA$5jjSPAXH?W5~I?cd*-Sr`R8H+Gir5OTHrb!{~~Z(_;em0o_~#QF%i?eV4Z zl4*|%``Nu_CmpOo>C;Z|0RHZ%P+TlfcaWo-aGBvb7hVM_RSUL1hUxXpf?V^xf(%U1 zwe@Vv4+@APrb``*uvGKsa=jaGxrw#(bwzN;jyI(&uVJ^yP=J&HIH-y17YACzap zjx{oi7O?YBjJPUV#AbZ4DL-brI4MpjW0OUzkWzIM*}vQ&EmKaFBs2(|-_|aC#*`9M zK8T<=Ii?atW?uHmfkDo?;94N{ZMjM}c2)GT4Y? 
zj+Yu?eASp|3FI*D4hcjyokfi$Yq0(AL6+l75u$<_FY}>FmhMBn+kaU7ghvhDj@mq_ zH$F+G{KEb;yQ$i*cnYN5vbcB`QaIExWR$fTUtLipUO(Sig4=KmFMm*_Q<&PfQ#V*J za8tVcpd8syJPo z7)}24=FD|zETyAw#<%__4_Qr}#w0H-+>;!=&%etHH zgiG%hHTcnKj<_L8mQ}X7`);K4`$|{IS6!%8H&N3{9|LT~{b(dt;e0|_!}ciDJMT4Y z{ZqFxplCk@IXxkOp8JNxo;c(76PVeko1hJ5*SAFYIBDIM`aUXt{E2;ej0vzD+J)W==OcD)SQ$cd3+34CsF}y8a77bD<3MR zDL##ISpH63en%!PLxzxc`XN|YbA4>5}(|rbuBdAxR7$(xTjy{-AG) zV;&M9^pQ(nBHK9L}HOm4mJV-#`ZCtA>~KBYd1BS zt;&_?eVU0aE28X14MGO|%}}*TO|sDY;!gtJ8&MP;&Rf~H)tue+50pqnL-%Ql$v>pu z*?y&br`anQ7}?8);YMb2Kws1Is4+rl7KdeBC=l z8FBF2)j{t}-5!#ww9mAnGf;Kh2}Hi-?_UUtMbD&t8aF~pe{s`jV1)R~N-Cq`$=X)K z2L{+#W_|mWY|7fR_CXVdfSNG*`V*H#_-xn<^uZoYIN=)rn<_^?&Vyb%!Oa^ag2tJI zJNN#w`?ibZEb3icp^p`VtS#U|d=wp=9@YucN^^nTI%dHi=BW>%LaA;=;;?^|hX;My zQMduSLvd8)<$CF|>kky-Y>%e+d@Tm9RtMQ)mTx<(wlVjwUs>D_H2Ul>RxtD|iRfz) zR-oV_knjd1?(Y5$w)zkcy^M<;l!;rssBYbvIf&9%1GUG~Ng;K}tOXMseC6yW$fwtl zhvKzw8Pxh@Z(9&2&+UTB~2SAQK9{YGHra=`*By~gs{+CKse*ma#_NaM$Yf4DMAD?@_( zlEb^ApV$8A<37c3VO|_D5CUQ7lhHiMcA1^YRJln3$DViv?Tj2TDcGmAX{shqUwR&M zqW~+}y|E(o9#|dbRqXtoG*$L<`d6>nDlrj$Hq;Ym?N-pM!+PH3N*c|b#dMEXO@HTV!-Q_Vv$tO&gv z&?TiFA|D!l%XZg+RgkMAX|#meD#Gf{>D^RD7n^y*BR9DBkdc@E0b(EU@y(f&_vRqW z#+1NFq)J2j?Z#wRS!q-@Zs5I9kG*N@$yA=|6qEe=5e`Hkp|bWwvw{o9&te$TzxvCh z-o&hEXTbmIBRLW-hT?(Z2XO%6*~d*et7t~MW3ko~=L)kmXzY+dd6&IbSl$(Z?Pn!g zWlSlmScCkGP*EaIIMrpjGYVhYi;^}`{n4Uq=J1z>VSGD&a4xFrW1& zUxWJ?E<2PbBIOxW(EGj*XYUW2Df4+L_P{nm7fqh3&}lfo`8HeQ`8_XOnsa}|rYW|1 xN>iaY+b?A>cq~wPJQQa=#dd7{&?M_XEGMUYVA1b&#rZb`z#L(1{2K0={6CB378(Em diff --git a/docs/index.rst b/docs/index.rst index b9ba09255..4ac1b5fe3 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -1,42 +1,17 @@ ======================= Observatory Platform ======================= -The Observatory Platform provides an environment and workflows for fetching, processing and analysing data to -understand how well universities operate as Open Knowledge Institutions. +The Observatory Platform is a set of dependencies used by the Curtin Open Knowledge Initiative (COKI) for running its +Airflow based workflows to fetch, process and analyse bibliometric datasets. -Installation -======================= -.. toctree:: - :maxdepth: 1 - - installation - -Tutorials -======================= -.. toctree:: - :maxdepth: 2 +The workflows for the project can be seen at: - tutorials/index - -REST API -========================= -.. toctree:: - :maxdepth: 2 - - api/index - -License & Contributing Guidelines -================================= -Information about licenses, contributing guidelines etc. +* `Academic Observatory Workflows `_ +* `OAeBU Workflows `_ -.. toctree:: - :maxdepth: 1 - - license - -Python API Reference +Python Package Reference ===================== -Documentation for the observatory.platform Python API. +Documentation for the observatory_platform Python package. .. toctree:: :maxdepth: 3 @@ -48,3 +23,13 @@ Indices and tables * :ref:`genindex` * :ref:`modindex` * :ref:`search` + + +License +================================= +Information about licenses, contributing guidelines etc. + +.. toctree:: + :maxdepth: 1 + + license \ No newline at end of file diff --git a/docs/installation.md b/docs/installation.md deleted file mode 100644 index a7a1f12aa..000000000 --- a/docs/installation.md +++ /dev/null @@ -1,24 +0,0 @@ -# Installation -Observatory Platform supports Python 3.10, Ubuntu Linux 20.04 and MacOS 10.14, on x86 architecture. - -## System dependencies -* Python 3.10 -* Pip -* Docker -* Docker Compose V2 -* virtualenv -* curl - -Make sure you first have curl and bash installed on your system. 
MacOS comes with curl and bash. If you need to install curl on Ubuntu, run `sudo apt install -y curl`. Then run the following in a terminal: -``` -/bin/bash -c "$(curl -fsSL https://raw.githubusercontent.com/The-Academic-Observatory/observatory-platform/main/install.sh)" -``` - -The installer script will prompt you with a series of questions to customise your installation, and optionally configure the observatory. At some point you might be asked to enter your user password in order to install system dependencies. If you only want to run the observatory platform, then select the `pypi` installation type. If you want to modify or develop the platform, select the `source` installation type. - -The script will create a Python virtual environment in the `observatory_venv` directory. - -There are two types of observatory platform deployments. `local` and `terraform`. The `local` installation allows you to run the observatory platform on the locally installed machine. The `terraform` installation deploys the platform to the cloud. See the documentation section on Terraform deployment for more details. - -You will also have the option of installing additional workflows. See the the GitHub pages for the [academic-observatory-workflows](https://github.com/The-Academic-Observatory/academic-observatory-workflows) and the [oaebu-workflows](https://github.com/The-Academic-Observatory/oaebu-workflows) for more information. - diff --git a/docs/requirements.txt b/docs/requirements.txt index 5ae12c3a5..264fd322e 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -1,9 +1,4 @@ Sphinx>6, <7 -sphinx-autoapi>=3 -sphinx-rtd-theme==1.3.0 -sphinxcontrib.openapi==0.7.* -recommonmark==0.7.* -pbr==5.5.* -pandas -myst-parser==2.0.0 -mistune<2 +sphinx-autoapi>=3,<4 +sphinx-rtd-theme>=2,<3 +recommonmark>=0.7.1,<1 diff --git a/docs/tutorials/deploy_terraform.md b/docs/tutorials/deploy_terraform.md deleted file mode 100644 index 98873d0f1..000000000 --- a/docs/tutorials/deploy_terraform.md +++ /dev/null @@ -1,480 +0,0 @@ -# Observatory Terraform Environment -This is a tutorial for deploying the Observatory Platform to Google Cloud with Terraform. - -You should have [installed the Observatory Platform](installation.html) before following this tutorial. - -## Install dependencies -The dependencies that are required include: -* [Packer](https://www.packer.io/): for automating the creation of the Google Cloud VM images. -* [Terraform](https://www.terraform.io/): to automate the deployment of the various Google Cloud services. -* [Google Cloud SDK](https://cloud.google.com/sdk/docs#install_the_latest_cloud_tools_version_cloudsdk_current_version): the Google -Cloud SDK including the gcloud command line tool. - -If you installed the observatory platform through the installer script, and selected the Terraform configuration, the dependencies were installed for you. - -If you wish to manually install the dependencies yourself, see the details below. 
- -### Linux -Install Packer: -```bash -sudo curl -L "https://releases.hashicorp.com/packer/1.9.2/packer_1.9.2_linux_amd64.zip" -o /usr/local/bin/packer -# When asked to replace, answer 'y' -unzip /usr/local/bin/packer -d /usr/local/bin/ -sudo chmod +x /usr/local/bin/packer -``` - -Install Google Cloud SDK: -```bash -sudo curl -L "https://dl.google.com/dl/cloudsdk/channels/rapid/downloads/google-cloud-sdk-330.0.0-linux-x86_64.tar.gz" -o /usr/local/bin/google-cloud-sdk.tar.gz -sudo tar -xzvf /usr/local/bin/google-cloud-sdk.tar.gz -C /usr/local/bin -rm /usr/local/bin/google-cloud-sdk.tar.gz -sudo chmod +x /usr/local/bin/google-cloud-sdk -/usr/local/bin/google-cloud-sdk/install.sh -``` - -Install Terraform: -```bash -sudo curl -L "https://releases.hashicorp.com/terraform/1.5.5/terraform_1.5.5_linux_amd64.zip" -o /usr/local/bin/terraform -# When asked to replace, answer 'y' -sudo unzip /usr/local/bin/terraform -d /usr/local/bin/ -sudo chmod +x /usr/local/bin/terraform -``` - -### Mac -Install Packer: -```bash -sudo curl -L "https://releases.hashicorp.com/packer/1.9.2/packer_1.9.2_darwin_amd64.zip" -o /usr/local/bin/packer -# When asked to replace, answer 'y' -unzip /usr/local/bin/packer -d /usr/local/bin/ -sudo chmod +x /usr/local/bin/packer -``` - -Install Google Cloud SDK: -```bash -sudo curl -L "https://dl.google.com/dl/cloudsdk/channels/rapid/downloads/google-cloud-sdk-330.0.0-darwin-x86_64.tar.gz" -o /usr/local/bin/google-cloud-sdk.tar.gz -mkdir /usr/local/bin/google-cloud-sdk -sudo tar -xzvf /usr/local/bin/google-cloud-sdk.tar.gz -C /usr/local/bin -rm /usr/local/bin/google-cloud-sdk.tar.gz -sudo chmod +x /usr/local/bin/google-cloud-sdk -/usr/local/bin/google-cloud-sdk/install.sh -``` - -Install Terraform: -```bash -sudo curl -L "https://releases.hashicorp.com/terraform/1.5.5/terraform_1.5.5_darwin_amd64.zip" -o /usr/local/bin/terraform -# When asked to replace, answer 'y' -unzip /usr/local/bin/terraform -d /usr/local/bin/ -sudo chmod +x /usr/local/bin/terraform -``` - -## Prepare Google Cloud project -Each environment (develop, staging, production) requires its own project. -See [Creating and managing projects](https://cloud.google.com/resource-manager/docs/creating-managing-projects) for more -details on creating a project. The following instructions are for one project only, repeat these steps for each -environment you would like to use. - -## Prepare permissions for Google Cloud service account -A Google Cloud service account will need to be created and it's service account key will need to be downloaded to -your workstation. See the article [Getting Started with Authentication](https://cloud.google.com/docs/authentication/getting-started) for -more details. 
- -### Development/test project -For the development and staging environments, the following permissions will need to be assigned to the service account -so that Terraform and Packer are able to provision the appropriate services: -```bash -BigQuery Admin -Cloud Build Service Account (API) -Cloud Run Admin (API) -Cloud SQL Admin -Compute Admin -Compute Image User -Compute Network Admin -Create Service Accounts -Delete Service Accounts -Project IAM Admin -Service Account Key Admin -Service Account User -Service Management Administrator (API) -Secret Manager Admin -Service Usage Admin -Storage Admin -Storage Transfer Admin -Serverless VPC Access Admin -``` - -### Production project -For the production environment, two custom roles with limited permissions need to be created to prevent storage buckets -as well as the Cloud SQL database instance from accidentally being destroyed. - -When running `terraform destroy` with these roles, Terraform will produce an error, because the service account -doesn't have the required permissions to destroy these resources (buckets and sql database instance). New roles can be -created in the Google Cloud Console, under 'IAM & Roles' and then 'Roles'. - -The two custom roles are: -* Custom Cloud SQL editor -Filter the Roles table on 'Cloud SQL Editor', select the role and click on 'create role from selection'. -Click on 'ADD PERMISSIONS' and add `cloudsql.users.create` and `cloudsql.instances.create`. -This new role replaces the 'Cloud SQL Admin' role compared to the development environment above. - -* Custom Storage Admin -Filter the Roles table on 'Storage Admin', select the role and click on 'create role from selection'. -At the 'assigned permissions' section filter for and remove `storage.buckets.delete` and `storage.objects.delete`. -This new role replaces the 'Storage Admin' role compared to the development environment above. - -```bash -Custom Cloud SQL Editor -Custom Storage Admin -BigQuery Admin -Cloud Build Service Account (API) -Cloud Run Admin (API) -Compute Admin -Compute Image User -Compute Network Admin -Create Service Accounts -Delete Service Accounts -Project IAM Admin -Service Account Key Admin -Service Account User -Service Management Administrator (API) -Secret Manager Admin -Service Usage Admin -Storage Transfer Admin -Serverless VPC Access Admin -``` - -## Prepare Google Cloud services -Enable the [Compute Engine API](https://console.developers.google.com/apis/api/compute.googleapis.com/overview) for the -google project. This is required for Packer to create the image. Other Google Cloud services are enabled by Terraform -itself. - -## Add user as verified domain owner -The terraform service account needs to be added as a verified domain owner in order to map the Cloud Run domain that is created -to a custom domain. The custom domain is used for the API service. See the [Google documentation](https://cloud.google.com/run/docs/mapping-custom-domains#adding_verified_domain_owners_to_other_users_or_service_accounts) -for more info on how to add a verified owner. 
- -## Switch to the branch that you would like to deploy -Enter the observatory-platform project folder: -```bash -cd observatory-platform -``` - -Switch to the branch that you would like to deploy, for example: -``` -git checkout develop -``` - -## Prepare configuration files -The Observatory Terraform configuration file needs to be created, to generate a default file run the following command: -```bash -observatory generate config terraform -``` - -The file is saved to `~/.observatory/config-terraform.yaml`. Customise the generated file, parameters with '<--' need -to be customised and parameters commented out are optional. - -See below for an example generated file: -```yaml -# The backend type: terraform -# The environment type: develop, staging or production -backend: - type: terraform - environment: develop - -# Apache Airflow settings -airflow: - fernet_key: 4yfYXnxjUZSsh1CefVigTuUGcH-AUnuKC9jJ2sUq-xA= # the fernet key which is used to encrypt the secrets in the airflow database - ui_user_email: my-email@example.com <-- # the email for the Apache Airflow UI's airflow user - ui_user_password: my-password <-- # the password for the Apache Airflow UI's airflow user - -# Terraform settings -terraform: - organization: my-terraform-org-name <-- # the terraform cloud organization - -# Google Cloud settings -google_cloud: - project_id: my-gcp-id <-- # the Google Cloud project identifier - credentials: /path/to/google_application_credentials.json <-- # the path to the Google Cloud service account credentials - region: us-west1 <-- # the Google Cloud region where the resources will be deployed - zone: us-west1-a <-- # the Google Cloud zone where the resources will be deployed - data_location: us <-- # the location for storing data, including Google Cloud Storage buckets and Cloud SQL backups - -# Google Cloud CloudSQL database settings -cloud_sql_database: - tier: db-custom-2-7680 # the machine tier to use for the Observatory Platform Cloud SQL database - backup_start_time: '23:00' # the time for Cloud SQL database backups to start in HH:MM format - postgres_password: my-password <-- # the password for the airflow postgres database user - -# Settings for the main VM that runs the Apache Airflow scheduler and webserver -airflow_main_vm: - machine_type: n2-standard-2 # the machine type for the virtual machine - disk_size: 50 # the disk size for the virtual machine in GB - disk_type: pd-ssd # the disk type for the virtual machine - create: true # determines whether virtual machine is created or destroyed - -# Settings for the weekly on-demand VM that runs large tasks -airflow_worker_vm: - machine_type: n1-standard-8 # the machine type for the virtual machine - disk_size: 3000 # the disk size for the virtual machine in GB - disk_type: pd-standard # the disk type for the virtual machine - create: false # determines whether virtual machine is created or destroyed - -# API settings -api: - domain_name: api.observatory.academy <-- # the custom domain name for the API, used for the google cloud endpoints service - subdomain: project_id # can be either 'project_id' or 'environment', used to determine a prefix for the domain_name - -# User defined Apache Airflow variables: -# airflow_variables: -# my_variable_name: my-variable-value - -# User defined Apache Airflow Connections: -# airflow_connections: -# my_connection: http://my-username:my-password@ - -# User defined Observatory DAGs projects: -# workflows_projects: -# - package_name: observatory-dags -# path: 
/home/user/observatory-platform/observatory-dags -# dags_module: observatory.dags.dags -``` - -The config file will be read when running `observatory terraform create-workspace` and -`observatory terraform update-workspace` and the variables are stored inside the Terraform Cloud workspace. - -### Fernet key -One of the required variables is a Fernet key, the generated default file includes a newly generated Fernet key that -can be used right away. Alternatively, generate a Fernet key yourself, with the following command: -```bash -observatory generate fernet-key -``` - -### Encoding airflow connections -Note that the login and passwords in the 'airflow_connections' variables need to be URL encoded, otherwise they will -not be parsed correctly. - -## Building the Google Compute VM image with Packer -First, build and deploy the Observatory Platform Google Compute VM image with Packer: -```bash -observatory terraform build-image ~/.observatory/config-terraform.yaml -``` - -Use this command if you have: -* Created, removed or updated user defined Observatory DAGs projects via the field `workflows_projects`, in the Observatory -Terraform config file. -* Updated any code in the Observatory Platform. -* Update the `backend.environment` variable in the Observatory Terraform config file: you need to make sure that an -image is built for the other environment. - -You will need to taint the VMs and update them so that they use the new image. - -You do not need to run this command if: -* You have created, removed or updated user defined Apache Airflow connections or variables in the Observatory -Terraform config file: in this case you will need to update the Terraform workspace. -* You have changed any other settings in the Observatory Terraform config file (apart from `backend.environment`): -in this case you will need to update the Terraform workspace variables and run `terraform apply`. - -Use this command if: - * This is the first time you are deploying the Terraform resources - * You have updated any files in the API directory (`/home/user/workspace/observatory-platform/observatory-platform/observatory/platform/api`) - -## Building the Terraform files -To refresh the files that are built into the `~/.observatory/build/terraform` directory, without rebuilding the entire -Google Compute VM image again, run the following command: -```bash -observatory terraform build-terraform ~/.observatory/config-terraform.yaml -``` - -Use this command if you have: - * Updated the Terraform deployment scripts, but nothing else. - -## Setting up Terraform -Enter the terraform directory: -```bash -cd ~/.observatory/build/terraform/terraform -``` - -Create token and login on Terraform Cloud: -```bash -terraform login -``` - -This should automatically store the token in `/home/user/.terraform.d/credentials.tfrc.json`, this file is used during -the next commands to retrieve the token. - -It's also possible to explicitly set the path to the credentials file using the option '--terraform-credentials-file'. - -## Creating and updating Terraform workspaces -See below for instructions on how to run observatory terraform create-workspace and update-workspace. - -### Create a workspace -Create a new workspace (this will use the created token file): -See [Observatory Terraform Environment](./observatory_dev.html#observatory-terraform-environment) for more info on the -usage of `observatory terraform`. 
-```bash -observatory terraform create-workspace ~/.observatory/config-terraform.yaml -``` - -You should see the following output: -```bash -Observatory Terraform: all dependencies found - Config: - - path: /home/user/.observatory/config-terraform.yaml - - file valid - Terraform credentials file: - - path: /home/user/.terraform.d/credentials.tfrc.json - -Terraform Cloud Workspace: - Organization: jamie-test - - Name: observatory-develop (prefix: 'observatory-' + suffix: 'develop') - - Settings: - - Auto apply: True - - Terraform Variables: - * environment: develop - * airflow: sensitive - * google_cloud: sensitive - * cloud_sql_database: sensitive - * airflow_main_vm: {"machine_type"="n2-standard-2","disk_size"=20,"disk_type"="pd-standard","create"=true} - * airflow_worker_vm: {"machine_type"="n2-standard-2","disk_size"=20,"disk_type"="pd-standard","create"=false} - * airflow_variables: {} - * airflow_connections: sensitive -Would you like to create a new workspace with these settings? [y/N]: -Creating workspace... -Successfully created workspace -``` - -### Update a workspace -To update variables in an existing workspace in Terraform Cloud: -```bash -observatory terraform update-workspace ~/.observatory/config-terraform.yaml -``` - -Depending on which variables are updated, you should see output similar to this: -```bash - Config: - - path: /home/user/.observatory/config-terraform.yaml - - file valid - Terraform credentials file: - - path: /home/user/.terraform.d/credentials.tfrc.json - -Terraform Cloud Workspace: - Organization: jamie-test - - Name: observatory-develop (prefix: 'observatory-' + suffix: 'develop') - - Settings: - - Auto apply: True - - Terraform Variables: - UPDATE - * airflow: sensitive -> sensitive - * google_cloud: sensitive -> sensitive - * cloud_sql_database: sensitive -> sensitive - * airflow_connections: sensitive -> sensitive - UNCHANGED - * api: {"domain_name"="api.observatory.academy","subdomain"="project_id"} - * environment: develop - * airflow_main_vm: {"machine_type"="n2-standard-2","disk_size"=20,"disk_type"="pd-standard","create"=true} - * airflow_worker_vm: {"machine_type"="n2-standard-2","disk_size"=20,"disk_type"="pd-standard","create"=false} - * airflow_variables: {} -Would you like to update the workspace with these settings? [y/N]: y -Updating workspace... -Successfully updated workspace -``` - -## Deploy -Once you have created your Terraform workspace, you can deploy the system with Terraform Cloud. - -Initialize Terraform using key/value pairs: -```bash -terraform init -backend-config="hostname="app.terraform.io"" -backend-config="organization="coki"" -``` - -Or using a backend file: -```bash -terraform init -backend-config=backend.hcl -``` - -With backend.hcl: -```hcl -hostname = "app.terraform.io" -organization = "coki" -``` - -If Terraform prompts to migrate all workspaces to "remote", answer "yes". - -Select the correct workspace in case multiple workspaces exist: -```bash -terraform workspace list -terraform workspace select -``` - -To preview the plan that will be executed with apply (optional): -```bash -terraform plan -``` - -To deploy the system with Terraform: -```bash -terraform apply -``` - -To destroy the system with Terraform: -```bash -terraform destroy -``` - -## Troubleshooting -See below for instructions on troubleshooting. 
- -### Undeleting Cloud Endpoints service -If your Cloud Endpoints service is deleted by Terraform and you try to recreate it again, you will get the following -error: -```bash -Error: googleapi: Error 400: Service has been deleted and will be purged after 30 days. To reuse this service, please undelete the service following https://cloud.google.com/service-infrastructure/docs/create-services#undeleting., failedPrecondition -``` - -To restore the Cloud Endpoints service, run the following: -```bash -gcloud endpoints services undelete -``` - -### Rebuild the VMs with a new Google Cloud VM image -If you have re-built the Google Cloud VM image, then you will need to manually taint the VMs and rebuild them: -```bash -terraform taint module.airflow_main_vm.google_compute_instance.vm_instance -terraform taint module.airflow_worker_vm.google_compute_instance.vm_instance -terraform apply -``` - -### Manually destroy the VMs -Run the following commands to manually destroy the VMs: -``` -terraform destroy -target module.airflow_main_vm.google_compute_instance.vm_instance -terraform destroy -target module.airflow_worker_vm.google_compute_instance.vm_instance -``` - -### Logging into the VMs -To ssh into airflow-main-vm: -```bash -gcloud compute ssh airflow-main-vm --project your-project-id --zone your-compute-zone -``` - -To ssh into airflow-worker-vm (this is off by default, turn on using airflow DAG): -```bash -gcloud compute ssh airflow-worker-vm --project your-project-id --zone your-compute-zone -``` - -### Viewing the Apache Airflow and Flower UIs -To view the Apache Airflow and Flower web user interfaces you must forward ports 8080 and 5555 from the airflow-main-vm -into your local workstation. - -To port forward with the gcloud command line tool: -```bash -gcloud compute ssh airflow-main-vm --project your-project-id --zone us-west1-c -- -L 5555:localhost:5555 -L 8080:localhost:8080 -``` - -### Syncing files with a VM -To sync your local Observatory Platform project with a VM run the following commands, making sure to customise -the username and vm-hostname for the machine: -```bash -rsync --rsync-path 'sudo -u airflow rsync' -av -e ssh --chown=airflow:airflow --exclude='docs' --exclude='*.pyc' \ - --exclude='*.tfvars' --exclude='*.tfstate*' --exclude='venv' --exclude='.terraform' --exclude='.git' \ - --exclude='*.egg-info' /path/to/observatory-platform username@vm-hostname:/opt/observatory -``` diff --git a/docs/tutorials/index.rst b/docs/tutorials/index.rst deleted file mode 100644 index 74ece4858..000000000 --- a/docs/tutorials/index.rst +++ /dev/null @@ -1,33 +0,0 @@ -Tutorials ----------------------------- - -Developing a telescope -######################################### - -.. toctree:: - :maxdepth: 2 - - workflow/intro - workflow/workflow_class - workflow/snapshot_class - workflow/stream_class - workflow/organisation_class - workflow/step_by_step - workflow/cli - workflow/style - -Starting an Observatory Platform development environment -########################################################## - -.. toctree:: - :maxdepth: 2 - - observatory_dev - -Deploying an Observatory Platform production environment with Terraform -######################################################################### - -.. 
toctree:: - :maxdepth: 2 - - deploy_terraform.md \ No newline at end of file diff --git a/docs/tutorials/observatory_dev.md b/docs/tutorials/observatory_dev.md deleted file mode 100644 index a55d214b3..000000000 --- a/docs/tutorials/observatory_dev.md +++ /dev/null @@ -1,208 +0,0 @@ -# Observatory Platform Development Environment -The following is a tutorial for running the local Observatory Platform development environment. - -Make sure that you have followed the installation instructions to install the Observatory Platform on your system. - -## Prepare configuration files -Generate a config.yaml file: -```bash -observatory generate config local -``` - -You should see the following output: -```bash -The file "/home/user/.observatory/config.yaml" exists, do you want to overwrite it? [y/N]: y -config.yaml saved to: "/home/user/.observatory/config.yaml" -``` - -The generated file should look like (with inline comments removed): -```yaml -# The backend type: local -# The environment type: develop, staging or production -backend: - type: local - environment: develop - -# Apache Airflow settings -airflow: - fernet_key: wbK_uYJ-x0tnUcy_WMwee6QYzI-7Ywbf-isKCvR1sZs= - -# Terraform settings: customise to use the vm_create and vm_destroy DAGs: -# terraform: -# organization: my-terraform-org-name - -# Google Cloud settings: customise to use Google Cloud services -# google_cloud: -# project_id: my-gcp-id # the Google Cloud project identifier -# credentials: /path/to/google_application_credentials.json # the path to the Google Cloud service account credentials -# data_location: us # the Google Cloud region where the resources will be deployed -# buckets: -# download_bucket: my-download-bucket-name # the bucket where downloads are stored -# transform_bucket: my-transform-bucket-name # the bucket where transformed files are stored - -# User defined Apache Airflow variables: -# airflow_variables: -# my_variable_name: my-variable-value - -# User defined Apache Airflow Connections: -# airflow_connections: -# my_connection: http://my-username:my-password@ - -# User defined Observatory DAGs projects: -# workflows_projects: -# - package_name: observatory-dags -# path: /home/user/observatory-platform/observatory-dags -# dags_module: observatory.dags.dags -``` - -See [Creating and managing projects](https://cloud.google.com/resource-manager/docs/creating-managing-projects) for more -details on creating a project and [Getting Started with Authentication](https://cloud.google.com/docs/authentication/getting-started) for -instructions on how to create a service account key. - -Make sure that service account has roles `roles/bigquery.admin` and `roles/storagetransfer.admin` as well as -access to the download and transform buckets. 
- -The table below lists connections that are required for telescopes bundled with the observatory: - -```eval_rst -+-------------------------+------------------------------------------------------------------------------------------------------+----------------------------------------------------------------------------------------+ -| Connection Key | Example | Description | -+=========================+======================================================================================================+========================================================================================+ -| crossref | http://myname:mypassword@myhost.com | Stores the token for the crossref API as a password | -+-------------------------+------------------------------------------------------------------------------------------------------+----------------------------------------------------------------------------------------+ -| mag_releases_table | http://myname:mypassword@myhost.com | Stores the azure-storage-account-name as a login and url-encoded-sas-token as password | -+-------------------------+------------------------------------------------------------------------------------------------------+----------------------------------------------------------------------------------------+ -| mag_snapshots_container | http://myname:mypassword@myhost.com | Stores the azure-storage-account-name as a login and url-encoded-sas-token as password | -+-------------------------+------------------------------------------------------------------------------------------------------+----------------------------------------------------------------------------------------+ -| terraform | mysql://:terraform-token@ | Stores the terraform user token as a password (used to create/destroy VMs) | -+-------------------------+------------------------------------------------------------------------------------------------------+----------------------------------------------------------------------------------------+ -| slack | https://:T00000000%2FB00000000%2FXXXXXXXXXXXXXXXXXXXXXXXX@https%3A%2F%2Fhooks.slack.com%2Fservices | Stores the URL for the Slack webhook as a host and the token as a password | -+-------------------------+------------------------------------------------------------------------------------------------------+----------------------------------------------------------------------------------------+ -``` - -## Running the local development environment -See below for instructions on how to start the Observatory Platform, view Observatory Platform UIs, stop the -Observatory Platform and customise other settings. 
- -### Start the Observatory Platform -To start the local Observatory Platform development environment: -```bash -observatory platform start -``` - -You should see the following output: -```bash -Observatory Platform: all dependencies found - Docker: - - path: /usr/bin/docker - - running - Host machine settings: - - observatory home: /home/user/.observatory - - data-path: /home/user/.observatory/data - - dags-path: /home/user/workspace/observatory-platform/observatory_platform/dags - - logs-path: /home/user/.observatory/logs - - postgres-path: /home/user/.observatory/postgres - - host-uid: 1000 - Docker Compose: - - path: /home/user/workspace/observatory-platform/venv/bin/docker-compose - config.yaml: - - path: /home/user/.observatory/config.yaml - - file valid -Observatory Platform: built -Observatory Platform: started -View the Apache Airflow UI at http://localhost:8080 -``` - -### Viewing the Apache Airflow and Flower UIs -Once the Observatory Platform has started, the following UIs can be accessed: -* Apache Airflow UI at [http://localhost:8080](http://localhost:8080) -* Flower UI at [http://localhost:5555](http://localhost:5555) - -### Stop the Observatory Platform -To stop the Observatory Platform: -```bash -observatory platform stop -``` - -You should see the following output: -```bash -Observatory Platform: all dependencies found -... -Observatory Platform: stopped -``` - -### Specify a config.yaml file -To specify a different config.yaml file use the `--config-path` parameter when starting the Observatory Platform: -```bash -observatory platform start --config-path /your/path/to/config.yaml -``` - -### Specify a DAGs folder -To specify a different dags folder use the `--dags-path` parameter when starting the Observatory Platform: -```bash -observatory platform start --dags-path /your/path/to/dags -``` - -### Specify a data folder -To specify a different folder to mount as the host machine's data folder, use the `--data-path` parameter when -starting the Observatory Platform: -```bash -observatory platform start --data-path /your/path/to/data -``` - -### Specify a logs folder -To specify a different folder to mount as the host machine's logs folder, use the `--logs-path` parameter when -starting the Observatory Platform: -```bash -observatory platform start --logs-path /your/path/to/logs -``` - -### Specify a PostgreSQL folder -To specify a different folder to mount as the host machine's PostgreSQL data folder, use the `--postgres-path` parameter -when starting the Observatory Platform: -```bash -observatory platform start --postgres-path /your/path/to/postgres -``` - -### Specify a user id -To specify different user id, which is used to set the ownership of the volume mounts, use the following command -when starting the Observatory Platform: -```bash -observatory platform start --host-uid 5000 -``` - -### Specify a group id -To specify different group id, which is used to set the ownership of the volume mounts, use the following command -when starting the Observatory Platform: -```bash -observatory platform start --host-gid 5000 -``` - -### Override default ports -You may override the host ports for Redis, Flower UI, Airflow UI, Elasticsearch and Kibana. 
An example is given -below: -```bash -observatory platform start --redis-port 6380 --flower-ui-port 5556 --airflow-ui-port 8081 -``` - -### Specify an existing Docker network -You may use an existing Docker network by supplying the network name: -```bash -observatory platform start --docker-network-name observatory-network -``` - -## Getting help -To get help with the Observatory Platform commands: -```bash -observatory --help -``` - -To get help with the Observatory Platform platform command: -```bash -observatory platform --help -``` - -To get help with the Observatory Platform generate command: -```bash -observatory generate --help -``` \ No newline at end of file diff --git a/docs/tutorials/workflow/intro.md b/docs/tutorials/workflow/intro.md deleted file mode 100644 index 6dc8846b2..000000000 --- a/docs/tutorials/workflow/intro.md +++ /dev/null @@ -1,60 +0,0 @@ -# Background -## Description of a workflow and telescope -The observatory platform collects data from many different sources and aggregates the different data sources as well. -This is all done with a variety of different workflows. -Some are aggregation or analytical workflows and some are workflows collecting the original data from a single data - source. -The last type of workflow is referred to as a telescope. -A telescope collects data from a single data source and should try to capture the data in its original state as much - as possible. -A telescope can generally be described with these tasks: - - Extract the raw data from an external source - - Store the raw data in a bucket - - Transform the data, so it is ready to be loaded into the data warehouse - - Store the transformed data in a bucket - - Load the data into the data warehouse - -## Managing workflows with Airflow -The workflows are all managed using Airflow. -This workflow management system helps to schedule and monitor the many - different workflows. -Airflow works with DAG (Directed Acyclic Graph) objects that are defined in a Python script. -The definition of a DAG according to Airflow is as follows: - > A dag (directed acyclic graph) is a collection of tasks with directional dependencies. A dag also has a schedule, a start date and an end date (optional). For each schedule, (say daily or hourly), the DAG needs to run each individual tasks as their dependencies are met. - -Generally speaking, a workflow is conveyed in a single DAG and there is a 1 on 1 mapping between DAGs and workflows. - -## Workflows and Google Cloud Platform -The Observatory Platform currently uses Google Cloud Platform as a platform for data storage and a data warehouse. -This means that the data is stored in Google Cloud Storage buckets and loaded into Google Cloud BigQuery, which - functions as the data warehouse. -To be able to access the different Google Cloud resources (such as Storage Buckets and BigQuery), the - GOOGLE_APPLICATION_CREDENTIALS environment variable is set on the Compute Engine that hosts Airflow. -This is all done when installing either the Observatory Platform Development Environment or Terraform Environment. -Many workflows make use of Google Cloud utility functions and these functions assume that the default credentials - are already set. - -## The workflow templates -Initially the workflows in the observatory platform were each developed individually. -There would be a workflow and release class that was unique for each workflow. 
-After developing a few workflows it became clear that there are many similarities between the workflows - and the classes that were developed. -For example, many tasks such as uploading data to a bucket or loading data into BigQuery were the same for different - workflows and only variables like filenames and schemas would be different. -The same properties were also often implemented, for example a download folder, release date and the many Airflow - related properties such as the DAG id, schedule interval, start date etc. - -These similarities prompted the development of a workflow template that can be used as a basis for a new workflow. -Additionally, the template abstracts away the code to create an Airflow DAG object, making it possible to use - the template and develop workflows without previous Airflow knowledge. -Having said that, some basic Airflow knowledge could come in handy, as it might help to understand the possibilities - and limitations of the template. -The template also implements properties that are often used and common tasks such as cleaning up local files at the - end of the workflow. -The initial template is referred to as the 'base' template and is used as a base for three other templates that - implement more specific tasks for loading data into BigQuery and have some properties set to specific values (such - as whether previous DAG runs should be run using the airflow 'catchup' setting). -The base template and the other three templates (snapshot, stream and organisation) are all explained in more detail - below. -Each of the templates also have their own corresponding release class, this class contains properties and methods - that are related to the specific release of a data source, these are also explained in more detail below. \ No newline at end of file diff --git a/docs/tutorials/workflow/step_by_step.md b/docs/tutorials/workflow/step_by_step.md deleted file mode 100644 index fb4893176..000000000 --- a/docs/tutorials/workflow/step_by_step.md +++ /dev/null @@ -1,709 +0,0 @@ -# Step by step tutorial - -A typical workflow pipeline will: -1. Create a DAG file that calls code to construct the workflow in `my-dags/my_dags/dags` -2. Create a workflow file containing code for the workflow itself in `my-dags/my_dags/workflows` -3. Create one or multiple schema files for the workflow data loaded into BigQuery in `my-dags/my_dags/database/schema` -4. Create a file with tests for the workflow in `my-dags/tests/workflows` -5. Create a documentation file about the workflow in `my-dags/docs` and update the `index.rst` file - -In these filepaths, `my-dags` is the workflows project folder and `my_dags` is the package name. - -## 1. Creating a DAG file -For Airflow to pickup new DAGs, it is required to create a DAG file that contains the DAG object as well as the keywords - 'airflow' and 'DAG'. -Any code in this file is executed every time the file is loaded into the Airflow dagbag, which is once per minute by - default. -This means that the code in this file should be as minimal as possible, preferably limited to just creating the DAG - object. -The filename is usually similar to the DAG id and the file should be inside the `my-dags/my_dags/dags` directory, - where `my-dags` is the workflows project folder and `my_dags` is the package name. 
- -An example of the DAG file: -```python -# The keywords airflow and DAG are required to load the DAGs from this file, see bullet 2 in the Apache Airflow FAQ: -# https://airflow.apache.org/docs/stable/faq.html - -from my_dags.workflows.my_workflow import MyWorkflow - -workflow = MyWorkflow() -globals()[workflow.dag_id] = workflow.make_dag() -``` - -## 2. Creating a workflow file -The workflow file contains the release class at the top, then the workflow class and at the bottom any functions that - are used within these classes. -This filename is also usually similar to the DAG id and should be inside the `my-dags/my_dags/workflows` directory. - -An example of the workflow file: -```python -import pendulum - -from observatory.platform.workflows.workflow import Release, Workflow -from observatory.platform.utils.airflow_utils import AirflowVars, AirflowConns - - -class MyWorkflowRelease(Release): - def __init__(self, dag_id: str, snapshot_date: pendulum.DateTime): - """Construct a Release instance - - :param dag_id: the id of the DAG. - :param snapshot_date: the release date (used to construct release_id). - """ - - self.snapshot_date = snapshot_date - release_id = f"{dag_id}_{self.snapshot_date.strftime('%Y_%m_%d')}" - super().__init__(dag_id, release_id) - - -class MyWorkflow(Workflow): - """ MyWorkflow Workflow.""" - - DAG_ID = "my_workflow" - - def __init__( - self, - dag_id: str = DAG_ID, - start_date: pendulum.DateTime = pendulum.datetime(2020, 1, 1), - schedule_interval: str = "@weekly", - catchup: bool = True, - queue: str = "default", - max_retries: int = 3, - max_active_runs: int = 1, - airflow_vars: list = None, - airflow_conns: list = None, - ): - """Construct a Workflow instance. - - :param dag_id: the id of the DAG. - :param start_date: the start date of the DAG. - :param schedule_interval: the schedule interval of the DAG. - :param catchup: whether to catchup the DAG or not. - :param queue: the Airflow queue name. - :param max_retries: the number of times to retry each task. - :param max_active_runs: the maximum number of DAG runs that can be run at once. - :param airflow_vars: list of airflow variable keys, for each variable it is checked if it exists in airflow - :param airflow_conns: list of airflow connection keys, for each connection it is checked if it exists in airflow - """ - - if airflow_vars is None: - airflow_vars = [ - AirflowVars.DATA_PATH, - AirflowVars.PROJECT_ID, - AirflowVars.DATA_LOCATION, - AirflowVars.DOWNLOAD_BUCKET, - AirflowVars.TRANSFORM_BUCKET, - ] - - # if airflow_conns is None: - # airflow_conns = [AirflowConns.SOMEDEFAULT_CONNECTION] - - super().__init__( - dag_id, - start_date, - schedule_interval, - catchup=catchup, - queue=queue, - max_retries=max_retries, - max_active_runs=max_active_runs, - airflow_vars=airflow_vars, - airflow_conns=airflow_conns, - ) - - # Add sensor tasks - # self.add_operator(some_airflow_sensor) - - # Add setup tasks - self.add_setup_task(self.check_dependencies) - - # Add generic tasks - self.add_task(self.task1) - self.add_task(self.cleanup) - - def make_release(self, **kwargs) -> MyWorkflowRelease: - """Make a release instance. - - :param kwargs: the context passed from the PythonOperator. - :return: A release instance - """ - snapshot_date = kwargs["execution_date"] - release = MyWorkflowRelease(dag_id=self.dag_id, snapshot_date=snapshot_date) - return release - - def task1(self, release: MyWorkflowRelease, **kwargs): - """Add your own comments. 
- - :param release: A MyWorkflowRelease instance - :param kwargs: The context passed from the PythonOperator. - :return: None. - """ - pass - - def cleanup(self, release: MyWorkflowRelease, **kwargs): - """Delete downloaded, extracted and transformed files of the release. - - :param release: A MyWorkflowRelease instance - :param kwargs: The context passed from the PythonOperator. - :return: None. - """ - release.cleanup() -``` - -### Using airflow Xcoms -Xcoms are an Airflow concept and are used with the workflows to pass on information between tasks. -The description of Xcoms by Airflow can be read - [here](https://airflow.apache.org/docs/apache-airflow/stable/concepts/xcoms.html#xcoms) and is as follows: - ->XComs (short for “cross-communications”) are a mechanism that let Tasks talk to each other, as by default Tasks are - entirely isolated and may be running on entirely different machines. -An XCom is identified by a key (essentially its name), as well as the task_id and dag_id it came from. -They can have any (serializable) value, but they are only designed for small amounts of data; do not use them to pass - around large values, like dataframes. -XComs are explicitly “pushed” and “pulled” to/from their storage using the xcom_push and xcom_pull methods on Task - Instances. -Many operators will auto-push their results into an XCom key called return_value if the do_xcom_push argument is set - to True (as it is by default), and @task functions do this as well. - -They are commonly used to pass on release information in workflows. -One task at the beginning of the workflow will retrieve release information such as the release date or possibly a - relevant release url. -The release information is then pushed during this task using Xcoms and it is pulled in the subsequent tasks, so a - release instance can be made with the given information. -An example of this can be seen in the implemented method `get_release_info` of the StreamTelescope class. - -The `get_release_info` method: -```python -def get_release_info(self, **kwargs) -> bool: - """Push the release info (start date, end date, first release) using Xcoms. - - :param kwargs: The context passed from the PythonOperator. - :return: None. - """ - ti: TaskInstance = kwargs["ti"] - - first_release = False - release_info = ti.xcom_pull(key=self.RELEASE_INFO, include_prior_dates=True) - if not release_info: - first_release = True - # set start date to the start of the DAG - start_date = pendulum.instance(kwargs["dag"].default_args["start_date"]).start_of("day") - else: - # set start date to end date of previous DAG run, add 1 day, because end date was processed in prev run. - start_date = pendulum.parse(release_info[1]) + timedelta(days=1) - # set start date to current day, subtract 1 day, because data from same day might not be available yet. - end_date = pendulum.today("UTC") - timedelta(days=1) - logging.info(f"Start date: {start_date}, end date: {end_date}, first release: {first_release}") - - # Turn dates into strings. Prefer JSON'able data over pickling in Airflow 2. - start_date = start_date.format("YYYYMMDD") - end_date = end_date.format("YYYYMMDD") - - ti.xcom_push(self.RELEASE_INFO, (start_date, end_date, first_release)) - return True -``` - -The start date, end date and first_release boolean are pushed using Xcoms with the `RELEASE_INFO` property as a key. -The info is then used within the `make_release` method. - -See for example the `make_release` method of the OrcidTelescope, which uses the StreamTelescope as a template. 
-```python -def make_release(self, **kwargs) -> OrcidRelease: - """Make a release instance. The release is passed as an argument to the function (TelescopeFunction) that is - called in 'task_callable'. - - :param kwargs: the context passed from the PythonOperator. See - https://airflow.apache.org/docs/stable/macros-ref.html for a list of the keyword arguments that are - passed to this argument. - :return: an OrcidRelease instance. - """ - ti: TaskInstance = kwargs["ti"] - start_date, end_date, first_release = ti.xcom_pull(key=OrcidTelescope.RELEASE_INFO, include_prior_dates=True) - - release = OrcidRelease( - self.dag_id, pendulum.parse(start_date), pendulum.parse(end_date), first_release, self.max_processes - ) - return release -``` - -### Using Airflow variables and connections -Any information that should not be hardcoded inside the workflow, but is still required for the workflow to function - can be passed on using Airflow variables and connections. -Both variables and connections can be added by defining them in the relevant config file (`config.yaml` in local - develop environment and `config-terraform.yaml` in deployed terraform environment). -Each variable or connection that is defined in the config file is made into an Airflow variable or connection when - starting the observatory environment. -The way these variables and connections are created is dependent on the type of observatory environment. -In the local develop environment, environment variables are created for Airflow variables and connections. -These environment variables are made up of the `AIRLFOW_VAR_` or `AIRFLOW_CONN_` prefix and the name that is used for - the variable or connection in the config file. -The prefixes are determined by Airflow and any environment variables with these prefixes will automatically be - picked up, see the Airflow documentation for more info on managing [variables](https://airflow.apache.org/docs/apache-airflow/stable/howto/variable.html#storing-variables-in-environment-variables) - and [connections](https://airflow.apache.org/docs/apache-airflow/stable/howto/connection.html#storing-a-connection-in-environment-variables) - with environment variables. -In the deployed terraform environment, the Google Cloud Secret Manager is used as a backend to store both Airflow - variables and connections, because this is more secure than using environment variables. -A secret is created for each individual Airflow variable or connection, see the Airflow documentation for more info - on the [secrets backend](https://airflow.apache.org/docs/apache-airflow/stable/security/secrets/secrets-backend/index.html#secrets-backend). - -#### Variables -Airflow variables should never contain any sensitive information. Example uses include the project_id, bucket names - or data location. - -#### Connections -Airflow connections can contain sensitive information and are often used to store credentials like API keys or - usernames and passwords. -In the local development environment, the Airflow connections are stored in the metastore database. -There, the passwords inside the connection configurations are encrypted using Fernet. -The value for the Airflow connection should always be a connection URI, see the [Airflow documentation](https://airflow.apache.org/docs/apache-airflow/stable/howto/connection.html#generating-a-connection-uri) - for more detailed information on how to construct this URI. 
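As a minimal illustration (standard library only, and not part of the platform code), a connection URI can be assembled like this; the scheme, host and credentials are placeholders, and the credential values are percent-encoded because characters such as `@` are not allowed verbatim in a URI:

```python
from urllib.parse import quote

# Placeholder credentials for illustration only; real values belong in config.yaml,
# config-terraform.yaml or the secrets backend, never hard-coded in a workflow.
login = quote("my-username", safe="")
password = quote("my-p@ssw0rd!", safe="")  # "@" and "!" become %40 and %21
host = "my-sftp-host.example.org"  # hypothetical host

conn_uri = f"ssh://{login}:{password}@{host}"
print(conn_uri)  # ssh://my-username:my-p%40ssw0rd%21@my-sftp-host.example.org
```

The resulting string is what would be pasted as the value of the connection entry in the config file (see the `airflow_connections` example below).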
- -#### Using a new variable or connection -Any new Airflow variables or connections have to be added to either the AirflowVars or AirflowConns class in the - airflow_utils file. -This file can be found at: -`observatory-platform/observatory/platform/utils/airflow_utils.py` - -These two classes act as a registry that make it easy to access the variables and connections in different DAGs -For each class attribute, the attribute name is used inside the workflow and the value is used inside the -`config.yaml` or `config-terraform.yaml` file. - -For example, to add the airflow variable 'new_variable' and connection 'new_connection', the relevant classes are - updated like this: -```python -# Inside observatory-platform/observatory/platform/utils/airflow_utils.py -class AirflowVars: - """Common Airflow Variable names used with the Observatory Platform""" - - # add to existing variables - NEW_VARIABLE = "new_variable" - - -class AirflowConns: - """Common Airflow Connection names used with the Observatory Platform""" - - # add to existing connections - NEW_CONNECTION = "new_connection" -``` - -The variable or connection can then be used inside the workflow like this: -```python -# Inside my-dags/my_dags/workflows/my_workflow.py -from observatory.platform.utils.airflow_utils import AirflowVars, AirflowConns - -airflow_conn = AirflowConns.NEW_CONNECTION -airflow_var = AirflowVars.NEW_VARIABLE -``` - -The relevant section of both the `config.yaml` and `config-terraform.yaml` files will look like this: -```yaml -# User defined Apache Airflow variables: -airflow_variables: - new_variable: my-variable-value - -# User defined Apache Airflow Connections: -airflow_connections: - new_connection: http://my-username:my-password@ -``` - -## 3. Creating a BigQuery schema file -BigQuery database schema json files are stored in `my-dag/my_dags/database/schema`. -They follow the format: `_YYYY-MM-DD.json`. -An additional custom version can be provided together with the date, in this case the files should follow the format: - `__YYYY-MM-DD.json`. - -The BigQuery table loading utility functions in the Observatory Platform will try to find the correct schema to use - for loading table data, based on the release date and custom version. -If no version is specified, the most recent schema with a date less than or equal to the release date of the data is - returned. -If a version string is specified, the most current (date) schema in that series is returned. -The utility functions are used by the BigQuery load tasks of the sub templates (Snapshot, Stream, Organisation) and - it is required to set the `schema_version` parameter to automatically pick up the schema version when using these - templates. - -## 4. Creating a test file -The Observatory Platform uses the `unittest` Python framework as a base and provides additional methods to run tasks - and test DAG structure. -It also uses the Python `coverage` package to analyse test coverage. - -To ensure that the workflow works as expected and to pick up any changes in the code base that would break the - workflow, it is required to add unit tests that cover the code in the developed workflow. - -The test files for workflows are stored in `my-dags/tests/workflows`. -The `ObservatoryTestCase` class in the `observatory-platform/observatory/platform/utils/test_utils.py` file contains - common test methods and should be used as a parent class for the unit tests instead of `unittest.TestCase`. 
-Additionally, the `ObservatoryEnvironment` class in the `test_utils.py` can be used to simulate the Airflow - environment and the different workflow tasks can be run and tested inside this environment. - -### Testing DAG structure -The workflow's DAG structure can be tested through the `assert_dag_structure` method of `ObservatoryTestCase`. -The DAG object is compared against a dictionary, where the key is the source node, and the value is a list of sink - nodes. -This expresses the relationship that the source node task is a dependency of all of the sink node tasks. - -Example: -```python -import pendulum - -from observatory.platform.utils.test_utils import ObservatoryTestCase -from observatory.platform.workflows.workflow import Release, Workflow - - -class MyWorkflow(Workflow): - def __init__( - self, - dag_id: str = "my_workflow", - start_date: pendulum.DateTime = pendulum.datetime(2017, 3, 20), - schedule_interval: str = "@weekly", - ): - super().__init__(dag_id, start_date, schedule_interval) - - self.add_task(self.task1) - self.add_task(self.task2) - - def make_release(self, **kwargs) -> Release: - snapshot_date = kwargs["execution_date"] - return Release(self.dag_id, snapshot_date) - - def task1(self, release, **kwargs): - pass - - def task2(self, release, **kwargs): - pass - - -class MyTestClass(ObservatoryTestCase): - """Tests for the workflow""" - - def __init__(self, *args, **kwargs): - """Constructor which sets up variables used by tests. - - :param args: arguments. - :param kwargs: keyword arguments. - """ - super(MyTestClass, self).__init__(*args, **kwargs) - - def test_dag_structure(self): - """Test that the DAG has the correct structure. - - :return: None - """ - expected = {"task1": ["task2"], "task2": []} - workflow = MyWorkflow() - dag = workflow.make_dag() - self.assert_dag_structure(expected, dag) -``` - -### Testing DAG loading -To test if a DAG loads from a DagBag, the `assert_dag_load` method can be used within an `ObservatoryEnvironment`. - -Example: -```python -import os -import pendulum - -from observatory.platform.utils.config_utils import module_file_path -from observatory.platform.utils.test_utils import ObservatoryTestCase, ObservatoryEnvironment -from observatory.platform.workflows.workflow import Release, Workflow - - -class MyWorkflow(Workflow): - def __init__( - self, - dag_id: str = "my_workflow", - start_date: pendulum.DateTime = pendulum.datetime(2017, 3, 20), - schedule_interval: str = "@weekly", - ): - super().__init__(dag_id, start_date, schedule_interval) - - self.add_task(self.task1) - self.add_task(self.task2) - - def make_release(self, **kwargs) -> Release: - snapshot_date = kwargs["execution_date"] - return Release(self.dag_id, snapshot_date) - - def task1(self, release, **kwargs): - pass - - def task2(self, release, **kwargs): - pass - - -class MyTestClass(ObservatoryTestCase): - """Tests for the workflow""" - - def __init__(self, *args, **kwargs): - """Constructor which sets up variables used by tests. - - :param args: arguments. - :param kwargs: keyword arguments. - """ - super(MyTestClass, self).__init__(*args, **kwargs) - - def test_dag_load(self): - """Test that the DAG can be loaded from a DAG bag. - - :return: None - """ - with ObservatoryEnvironment().create(): - dag_file = os.path.join(module_file_path("my_dags.dags"), "my_workflow.py") - self.assert_dag_load("my_workflow", dag_file) -``` - -### Testing workflow tasks -To run and test a workflow task, the `run_task` method can be used within an `ObservatoryEnvironment`. 
- -The ObservatoryEnvironment is used to simulate the Airflow environment. - -To ensure that a workflow can be run from end to end the Observatory Environment creates additional resources, it will: -* Create a temporary local directory. -* Set the OBSERVATORY_HOME environment variable. -* Initialise a temporary Airflow database. -* Create download and transform Google Cloud Storage buckets. -* Create BigQuery dataset(s). -* Create default Airflow Variables: - * AirflowVars.DATA_PATH - * AirflowVars.PROJECT_ID - * AirflowVars.DATA_LOCATION - * AirflowVars.DOWNLOAD_BUCKET - * AirflowVars.TRANSFORM_BUCKET. -* Create an ObservatoryApiEnvironment. -* Start an Elastic environment. -* Clean up all resources when the environment is closed. - -Note that if the unit test is stopped with a forced interrupt, the code block to clean up the created storage buckets - and datasets will not be executed and those resources will have to be manually removed. - -The run dependencies that are imposed on each task by the DAG structure are preserved in the test environment. -This means that to run a specific task, all the previous tasks in the DAG have to run successfully before that task - within the same `create_dag_run` environment. - -Example: -```python -import pendulum - -from observatory.platform.utils.test_utils import ObservatoryTestCase, ObservatoryEnvironment -from observatory.platform.workflows.workflow import Release, Workflow - - -class MyWorkflow(Workflow): - def __init__( - self, - dag_id: str = "my_workflow", - start_date: pendulum.DateTime = pendulum.datetime(2017, 3, 20), - schedule_interval: str = "@weekly", - ): - super().__init__(dag_id, start_date, schedule_interval) - - self.add_task(self.task1) - self.add_task(self.task2) - - def make_release(self, **kwargs) -> Release: - snapshot_date = kwargs["execution_date"] - return Release(self.dag_id, snapshot_date) - - def task1(self, release, **kwargs): - pass - - def task2(self, release, **kwargs): - pass - - -class MyTestClass(ObservatoryTestCase): - """Tests for the workflow""" - - def __init__(self, *args, **kwargs): - """Constructor which sets up variables used by tests. - :param args: arguments. - :param kwargs: keyword arguments. - """ - super(MyTestClass, self).__init__(*args, **kwargs) - self.execution_date = pendulum.datetime(2020, 1, 1) - - def test_workflow(self): - """Test the workflow end to end. - :return: None. - """ - # Setup Observatory environment - env = ObservatoryEnvironment() - - # Setup Workflow - workflow = MyWorkflow() - dag = workflow.make_dag() - - # Create the Observatory environment and run tests - with env.create(): - with env.create_dag_run(dag, self.execution_date): - # Run task1 - env.run_task(workflow.task1.__name__) -``` - -### Temporary GCP datasets -Unit testing frameworks often run tests in parallel, so there is no guarantee of execution order. -When running code that modifies datasets or tables in the Google Cloud, it is recommended to create temporary - datasets for each task to prevent any bugs caused by race conditions. -The `ObservatoryEnvironment` has a method called `add_dataset` that can be used to create a new dataset in the linked - project for the duration of the environment. - -### Observatory Platform API -Some workflows make use of the Observatory Platform API in order to fetch necessary metadata. -When writing unit tests for workflows that use the platform API, it is necessary to use an isolated API environment - where the relevant WorkflowType, Organisations and Telescope exist. 
-The ObservatoryEnvironment that is mentioned above can be used to achieve this. -An API session is started when creating the ObservatoryEnvironment and the WorkflowType, Organisations and Telescope - can all be added to this session. - -Example: -```python -import pendulum -from airflow.models.connection import Connection - -from my_dags.utils.identifiers import WorkflowTypes -from observatory.api.server import orm -from observatory.platform.utils.airflow_utils import AirflowConns -from observatory.platform.utils.test_utils import ObservatoryEnvironment - -dt = pendulum.now("UTC") - -# Create observatory environment -env = ObservatoryEnvironment() - -# Add the Observatory API connection, used from make_observatory_api() in DAG file -conn = Connection(conn_id=AirflowConns.OBSERVATORY_API, uri=f"http://:password@host:port") -env.add_connection(conn) - - -# Create telescope type with API -workflow_type = orm.WorkflowType(name="ONIX Telescope", type_id=WorkflowTypes.onix, created=dt, modified=dt) -env.api_session.add(workflow_type) - -# Create organisation with API -organisation = orm.Organisation(name="Curtin Press", created=dt, modified=dt) -env.api_session.add(organisation) - -# Create workflow with API -workflow = orm.Telescope( - name="Curtin Press ONIX Telescope", - workflow_type=workflow_type, - organisation=organisation, - modified=dt, - created=dt, -) -env.api_session.add(workflow) - -# Commit changes -env.api_session.commit() -``` - -## 5. Creating a documentation file -The Observatory Platform builds documentation using [Sphinx](https://www.sphinx-doc.org). -Documentation is contained in the `docs` directory. -Currently index pages are written in [RST format (Restructured Text)](https://www.sphinx-doc.org/en/master/usage/restructuredtext/basics.html), - and content pages are written with [Markdown](https://www.sphinx-doc.org/en/master/usage/markdown.html) for simplicity. - -It is possible to build the documentation by using the command: -``` -cd docs -make html -``` -This will output html documentation in the `docs/_build/html` directory and the file `docs_/build/index.html` can be - opened in a browser to preview what the documentation will look like. - -A documentation file with info on the workflow should be added in the `my-dags/docs` directory. -This documentation should at least include: - * A short summary on the data source. - * A summary table, see example below. - * Any details on set-up steps that are required to run this workflow. - * Info on any Airflow connections and variables that are used (see further below). - * The latest schema. 
- - Example of a summary table using `eval_rst` to format the RST table: - - ```eval_rst - +------------------------------+---------+ - | Summary | | - +==============================+=========+ - | Average runtime | 10 min | - +------------------------------+---------+ - | Average download size | 500 MB | - +------------------------------+---------+ - | Harvest Type | API | - +------------------------------+---------+ - | Harvest Frequency | Monthly | - +------------------------------+---------+ - | Runs on remote worker | True | - +------------------------------+---------+ - | Catchup missed runs | True | - +------------------------------+---------+ - | Table Write Disposition | Truncate| - +------------------------------+---------+ - | Update Frequency | Monthly | - +------------------------------+---------+ - | Credentials Required | No | - +------------------------------+---------+ - | Uses Workflow Template | Snapshot| - +------------------------------+---------+ - | Each shard includes all data | Yes | - +------------------------------+---------+ - ``` - -### Including Airflow variable/connection info in documentation -If a newly developed workflow uses an Airflow connection or variable, this should be explained in the workflow - documentation. -An example of the variable/connection is required as well as an explanation on how the value for this - variable/connection can be obtained. - -See for example this info section on the Airflow connection required with the google_books workflow: - ---- -## Airflow connections -Note that all values need to be urlencoded. -In the config.yaml file, the following airflow connection is required: - -### sftp_service -```yaml -sftp_service: ssh://:@?host_key= -``` -The sftp_service airflow connection is used to connect to the sftp_service and download the reports. -The username and password are created by the sftp service and the host is e.g. `oaebu.exavault.com`. -The host key is optional, you can get it by running ssh-keyscan, e.g.: -``` -ssh-keyscan oaebu.exavault.com -``` ---- - -### Including schemas in documentation -The documentation build system automatically converts all the schema files from `my-dags/my_dags/database/schemas` - into CSV files. -This is temporarily stored in the `docs/schemas` folder. -The csv files have the same filename as the original schema files, except for the suffix, which is changed to csv. -If there are multiple schemas for the same workflow, the `_latest` suffix can be used to always get the latest - version of the schema. -The schemas folder is cleaned up as part of the build process so this directory is not visible, but can be made - visable by disabling the cleanup code in the `Makefile`. - -To include a schema in the documentation markdown file, it is necessary to embed some RST that loads a table from a - csv file. -Since the recommonmark package is used, this can be done with an `eval_rst` codeblock that contains RST: - - ``` eval_rst - .. csv-table:: - :file: /path/to/schema_latest.csv - :width: 100% - :header-rows: 1 - ``` - -To determine the correct file path, it is recommended to construct a relative path to the `docs/schemas` directory - from the directory of the markdown file. 
- -For example, if the markdown file resides in -`my-dags/docs/my_workflow.md` - -And the schema file path is -`my-dags/my_dags/database/schema/my_workflow_2021-01-01.json` - -then the correct file path that should be used in the RST code block is -``` -:file: ../schemas/my_workflow_latest.csv -``` -The `..` follows the parent directory, this is needed once to reach `docs` from `my-dags/docs/workflows/my_workflow.md`. \ No newline at end of file diff --git a/docs/tutorials/workflow/style.md b/docs/tutorials/workflow/style.md deleted file mode 100644 index eafd431ea..000000000 --- a/docs/tutorials/workflow/style.md +++ /dev/null @@ -1,17 +0,0 @@ -# Style -All code should try to conform to the Python PEP-8 standard, and the default format style of the `Black` formatter. -This is done with the [autopep8 package](https://pypi.org/project/autopep8), and the - [black formatter](https://pypi.org/project/black/), using a line length of 120. - -It is recommended to use those format tools as part of the coding workflow. - -## Type hinting -Type hints should be provided for all of the function arguments that are used, and for return types. -Because Python is a weakly typed language, it can be confusing to those unacquainted with the codebase what type of - objects are being manipulated in a particular function. -Type hints help reduce this ambiguity. - -## Docstring -Docstring comments should also be provided for all classes, methods, and functions. -This includes descriptions of arguments, and returned objects. -These comments will be automatically compiled into the Observatory Platform API reference documentation section. \ No newline at end of file diff --git a/docs/tutorials/workflow/workflow_class.md b/docs/tutorials/workflow/workflow_class.md deleted file mode 100644 index c39a13571..000000000 --- a/docs/tutorials/workflow/workflow_class.md +++ /dev/null @@ -1,404 +0,0 @@ -# Workflow template -## Workflow -```eval_rst -See :meth:`platform.workflows.workflow.Workflow` for the API reference. -``` - -The workflow class is the most basic template that can be used. -It implements methods from the AbstractWorkflow class and it is not recommended that the AbstractWorkflow class is - used directly itself. - -### Make DAG -The `make_dag` method of the workflow class is used to create an Airflow DAG object. -When the object is defined in the global namespace, it is picked up by the Airflow scheduler and ensures that all tasks - are scheduled. - -### Adding tasks to DAG -It is possible to add one of the three types of tasks to this DAG object: - * Sensor - * Set-up task - * Task - -All three types of tasks can be added individually per task using the `add_` method or a list of tasks - can be added using the `add__chain` method. -To better understand the difference between these type of tasks, it is helpful to know how tasks are created in - Airflow. -Within a DAG, each task that is part of the DAG is created by instantiating an Operator class. -There are many different types of Airflow Operators available, but in the case of the template the usage is limited to - the BaseSensorOperator, PythonOperator and the ShortCircuitOperator. - -* The BaseSensorOperator keeps executing at a regular time interval and succeeds when a criteria is met and fails if and - when they time out. -* The PythonOperator simply calls an executable Python function. -* The ShortCircuitOperator is derived from the PythonOperator and additionally evaluates a condition. 
When the - conditions is False it short-circuits the workflow by skipping all downstream tasks. - -The **sensor** type instantiates the BaseSensorOperator (or a child class of this operator). -All sensor tasks are always chained to the beginning of the DAG. -Tasks of this type are useful for example to probe whether another task has finished successfully using the - ExternalTaskSensor. - -The **set-up task** type instantiates the ShortCircuitOperator. -Because the ShortCircuitOperator is used, the executable Python function that is called with this operator has to - return a boolean. -The returned value is then evaluated to determine whether the workflow continues. -Additionally, the set-up task does not require a release instance as an argument passed to the Python function, in - contrast to a 'general' task. -The set-up tasks are chained after any sensors and before any remaining 'general' tasks. -Tasks of this type are useful for example to check whether all dependencies for a workflow are met or to list which - releases are available. - -The general **task** type instantiates the PythonOperator. -The executable Python function that is called with this operator requires a release instance to be passed on as an - argument. -These tasks are always chained after any sensors and set-up tasks. -Tasks of this type are the most common in the workflows and are useful for any functionality that requires release - information such as downloading, transforming, loading into BigQuery, etc. - -Order of the different task types within a workflow: -

-*(Figure: order of workflow tasks, with sensors first, then set-up tasks, then general tasks.)*
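As a condensed, purely illustrative sketch of this ordering (the workflow, sensor and task names are made up, and the full example at the end of this page shows the complete pattern):

```python
import pendulum
from airflow.sensors.external_task import ExternalTaskSensor

from observatory.platform.workflows.workflow import Workflow


class OrderingSketch(Workflow):
    def __init__(self):
        super().__init__("ordering_sketch", pendulum.datetime(2020, 1, 1), "@weekly")

        # Sensor: always chained to the beginning of the DAG
        self.add_operator(
            ExternalTaskSensor(external_dag_id="my_other_workflow", task_id="important_task", mode="reschedule")
        )
        # Set-up task (ShortCircuitOperator): chained after the sensors
        self.add_setup_task(self.check_dependencies)
        # General task (PythonOperator): chained after sensors and set-up tasks
        self.add_task(self.task1)

    def make_release(self, **kwargs):
        return None  # a real workflow returns a Release instance here

    def task1(self, release, **kwargs):
        pass


# Resulting task order: important_task >> check_dependencies >> task1
```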

- -By default all tasks within the same type (sensor, setup task, task) are chained linearly in the order they are - inserted. -There is a context manager `parallel_tasks` which can be used to parallelise tasks. -All tasks that are added within that context are added in parallel. -Currently this is only supported for setup tasks. - -### The 'make_release' method -The `make_release` method is used to create a (list of) release instance(s). -A general task always requires a release instance as a parameter, so the `make_release` method is called when the - PythonOperator for a general task is made. -The release (or list of releases) that is made with this method is then passed on as a parameter to any general task - of that workflow. -Inside the general task the release properties can then be used for things such as local download paths. -Because the method is used for any general task, this method always has to be implemented. - -### Checking dependencies -The workflow class also has a method `check_dependencies` implemented that can be added as a set-up task. -All workflows require that at least some Airflow Variables and Connections are set, so these dependencies should be - checked at the start of each workflow and this can be done with this task. - -## Release -```eval_rst -See :meth:`platform.workflows.workflow.Release` for the API reference. -``` - -The Release class is a basic implementation of the AbstractRelease class. -An instance of the release class is passed on as an argument to any general tasks that are added to the workflow. -Similarly in set-up to the workflow class, it implements methods from the AbstractRelease class and it is not - recommended that the AbstractRelease class is used directly by itself. -The properties and methods that are added to the Release class should all be relevant to the release instance. -If they are always the same, independent of the release instance, they are better placed in the Workflow class. - -### The release id -The Release class always needs a release id. -This release id is usually based on the release date, so it is unique for each release and relates to the date when - the data became available or was processed. -It is used for the folder paths described below. - -### Folder paths -The Release class has properties for the paths of 3 different folders: - * `download_folder` - * `extract_folder` - * `transform_folder` - - It is convenient to use these when downloading/extract/transforming data and writing the data to a file in the - matching folder. -The paths for these folders always include the release id and the format is as follows: -`/path/to/workflows/{download|extract|transform}/{dag_id}/{release_id}/` - -The `path/to/workflows` is determined by a separate function. -Having these folder paths as properties of the release class makes it easy to have the same file structure for each - workflow. - -### List files in folders -The folder paths are also used for the 3 corresponding properties: - * `download_files` - * `extract_files` - * `transform_files` - -These properties will each return a list of files in their corresponding folder that match a given regex pattern. -This is useful when e.g. iterating through all download files to transform them, or passing on the list of transform - files to a function that uploads all files to a storage bucket. -The regex patterns for each of the 3 folders is passed on separately when instantiating the release class. 
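To make the folder and file properties described above more concrete, below is a rough, self-contained sketch; it is not the platform's actual `Release` class, and the base path `/path/to/workflows` and the regex pattern are placeholders:

```python
import os
import re
from typing import List


class ReleaseFoldersSketch:
    """Illustrative only: folder paths of the form
    /path/to/workflows/{download|extract|transform}/{dag_id}/{release_id}/
    and a regex-filtered file listing, as described above."""

    def __init__(self, dag_id: str, release_id: str, download_files_regex: str = r".*\.jsonl$"):
        self.dag_id = dag_id
        self.release_id = release_id
        self.download_files_regex = download_files_regex

    @property
    def download_folder(self) -> str:
        # The extract and transform folders follow the same pattern, with
        # "extract" or "transform" in place of "download".
        return os.path.join("/path/to/workflows", "download", self.dag_id, self.release_id)

    @property
    def download_files(self) -> List[str]:
        # List the files in the download folder whose names match the regex pattern.
        folder = self.download_folder
        if not os.path.isdir(folder):
            return []
        return [
            os.path.join(folder, name)
            for name in sorted(os.listdir(folder))
            if re.match(self.download_files_regex, name)
        ]
```

For a dag_id of `my_workflow` and release id `my_workflow_2021_01_01`, `download_folder` evaluates to `/path/to/workflows/download/my_workflow/my_workflow_2021_01_01`.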
- -### Bucket names -There are 2 storage buckets used to store the data processed with the workflow, a download bucket and a transform - bucket. -The bucket names are retrieved from Airflow Variables and there are 2 corresponding properties in the release class, -`download_bucket` and `transform_bucket`. -These properties are convenient to use when uploading data to either one of these buckets. - -### Clean up -The Release class has a `cleanup` method which can be called inside a task that will 'clean up' by deleting the 3 - folders mentioned above. -This method is part of the release class, because a clean up task is part of each workflow and it uses those - folder paths described above that are properties of the release class. - - -## Example -Below is an example of a simple workflow using the Workflow template. - -Workflow file: -```python -import pendulum -from airflow.sensors.external_task import ExternalTaskSensor - -from observatory.platform.workflows.workflow import Release, Workflow -from observatory.platform.utils.airflow_utils import AirflowVars, AirflowConns - - -class MyRelease(Release): - def __init__(self, dag_id: str, snapshot_date: pendulum.DateTime): - """Construct a Release instance - - :param dag_id: the id of the DAG. - :param snapshot_date: the release date (used to construct release_id). - """ - - self.snapshot_date = snapshot_date - release_id = f'{dag_id}_{self.snapshot_date.strftime("%Y_%m_%d")}' - super().__init__(dag_id, release_id) - - -class MyWorkflow(Workflow): - """MyWorkflow Workflow.""" - - DAG_ID = "my_workflow" - - def __init__( - self, - dag_id: str = DAG_ID, - start_date: pendulum.DateTime = pendulum.datetime(2020, 1, 1), - schedule_interval: str = "@weekly", - catchup: bool = True, - queue: str = "default", - max_retries: int = 3, - max_active_runs: int = 1, - airflow_vars: list = None, - airflow_conns: list = None, - ): - """Construct a Workflow instance. - - :param dag_id: the id of the DAG. - :param start_date: the start date of the DAG. - :param schedule_interval: the schedule interval of the DAG. - :param catchup: whether to catchup the DAG or not. - :param queue: the Airflow queue name. - :param max_retries: the number of times to retry each task. - :param max_active_runs: the maximum number of DAG runs that can be run at once. - :param airflow_vars: list of airflow variable keys, for each variable it is checked if it exists in airflow - :param airflow_conns: list of airflow connection keys, for each connection it is checked if it exists in airflow - """ - - if airflow_vars is None: - airflow_vars = [ - AirflowVars.DATA_PATH, - AirflowVars.PROJECT_ID, - AirflowVars.DATA_LOCATION, - AirflowVars.DOWNLOAD_BUCKET, - AirflowVars.TRANSFORM_BUCKET, - ] - - if airflow_conns is None: - airflow_conns = [AirflowConns.SOMEDEFAULT_CONNECTION] - - super().__init__( - dag_id, - start_date, - schedule_interval, - catchup=catchup, - queue=queue, - max_retries=max_retries, - max_active_runs=max_active_runs, - airflow_vars=airflow_vars, - airflow_conns=airflow_conns, - ) - - # Add sensor tasks - sensor = ExternalTaskSensor(external_dag_id="my_other_workflow", task_id="important_task", mode="reschedule") - self.add_operator(sensor) - - # Add setup tasks - self.add_setup_task(self.check_dependencies) - - # Add generic tasks - self.add_task(self.task1) - self.add_task(self.cleanup) - - def make_release(self, **kwargs) -> MyRelease: - """Make a release instance. - - :param kwargs: the context passed from the PythonOperator. 
- :return: A release instance - """ - snapshot_date = kwargs["execution_date"] - release = MyRelease(dag_id=self.dag_id, snapshot_date=snapshot_date) - return release - - def task1(self, release: MyRelease, **kwargs): - """Add your own comments. - - :param release: A MyRelease instance - :param kwargs: The context passed from the PythonOperator. - :return: None. - """ - pass - - def cleanup(self, release: MyRelease, **kwargs): - """Delete downloaded, extracted and transformed files of the release. - - :param release: A MyRelease instance - :param kwargs: The context passed from the PythonOperator. - :return: None. - """ - release.cleanup() -``` - -DAG file: -```python -# The keywords airflow and DAG are required to load the DAGs from this file, see bullet 2 in the Apache Airflow FAQ: -# https://airflow.apache.org/docs/stable/faq.html - -from observatory.dags.workflows.my_workflow import MyWorkflow - -workflow = MyWorkflow() -globals()[workflow.dag_id] = workflow.make_dag() -``` - -In case you are familiar with creating DAGs in Airflow, below is the equivalent workflow without using the template. -This might help to understand how the template works behind the scenes. - -Workflow and DAG in one file: -```python -# Copyright 2021 Curtin University -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# Author: Aniek Roelofs -import shutil -import logging -from pendulum import datetime -from airflow import DAG -from airflow.exceptions import AirflowException -from airflow.operators.python import PythonOperator -from airflow.operators.python import ShortCircuitOperator -from airflow.sensors.external_task import ExternalTaskSensor - -from observatory.platform.utils.airflow_utils import AirflowConns, AirflowVars, check_connections, check_variables -from observatory.platform.utils.workflow_utils import ( - SubFolder, - on_failure_callback, - workflow_path, -) - - -def check_dependencies() -> bool: - """Checks the 'workflow' attributes, airflow variables & connections and possibly additional custom checks. - - :return: Whether variables and connections are available. - """ - # check that vars and connections are available - airflow_vars = [ - AirflowVars.DATA_PATH, - AirflowVars.PROJECT_ID, - AirflowVars.DATA_LOCATION, - AirflowVars.DOWNLOAD_BUCKET, - AirflowVars.TRANSFORM_BUCKET, - ] - vars_valid = check_variables(*airflow_vars) - - airflow_conns = [AirflowConns.SOMEDEFAULT_CONNECTION] - conns_valid = check_connections(*airflow_conns) - - if not vars_valid or not conns_valid: - raise AirflowException("Required variables or connections are missing") - - return True - - -def task1(**kwargs): - """Add your own comments. - - :param kwargs: The context passed from the PythonOperator. - :return: None. - """ - pass - - -def cleanup(**kwargs): - """Delete downloaded, extracted and transformed files of the release. - - :param kwargs: The context passed from the PythonOperator. - :return: None. 
- """ - dag_id = "my_workflow" - snapshot_date = kwargs["execution_date"] - release_id = f'{dag_id}_{snapshot_date.strftime("%Y_%m_%d")}' - download_folder = workflow_path(SubFolder.downloaded.value, dag_id, release_id) - extract_folder = workflow_path(SubFolder.extracted.value, dag_id, release_id) - transform_folder = workflow_path(SubFolder.transformed.value, dag_id, release_id) - - for path in [download_folder, extract_folder, transform_folder]: - try: - shutil.rmtree(path) - except FileNotFoundError as e: - logging.warning(f"No such file or directory {path}: {e}") - - -default_args = { - "owner": "airflow", - "start_date": datetime(2020, 1, 1), - "on_failure_callback": on_failure_callback, - "retries": 3, -} - -with DAG( - dag_id="my_workflow", - start_date=datetime(2020, 1, 1), - schedule_interval="@weekly", - default_args=default_args, - catchup=True, - max_active_runs=1, - doc_md="MyWorkflow Workflow", -) as dag: - - sensor_task = ExternalTaskSensor( - external_dag_id="my_other_workflow", - task_id="important_task", - mode="reschedule", - queue="default", - default_args=default_args, - provide_context=True, - ) - - check_dependencies_task = ShortCircuitOperator( - task_id="check_dependencies", - python_callable=check_dependencies, - queue="default", - default_args=default_args, - provide_context=True, - ) - - task_1 = PythonOperator( - task_id="task1", python_callable=task1, queue="default", default_args=default_args, provide_context=True - ) - - cleanup_task = PythonOperator( - task_id="cleanup", python_callable=cleanup, queue="default", default_args=default_args, provide_context=True - ) - -sensor_task >> check_dependencies_task >> task_1 >> cleanup_task -``` \ No newline at end of file diff --git a/install.sh b/install.sh deleted file mode 100755 index 2e2969b67..000000000 --- a/install.sh +++ /dev/null @@ -1,474 +0,0 @@ -#!/bin/bash - -venv_observatory_platform="observatory_venv" -airflow_version="2.6.3" -python_version="3.10" - -function set_os_arch() { - os=$(uname -s) - os_human=$os - os=$(lower_case $os) - - if [ "$os" == "Darwin" ]; then - os_human="Mac OS $(uname -r)" - fi - - arch=$(uname -m) - arch=$(lower_case $arch) -} - -## Functions - -function lower_case() { - echo $(echo $1 | tr '[:upper:]' '[:lower:]') -} - -function ask_question() { - local question="$1" - shift - local default_option="$1" - shift - local options=("$@") - - while :; do - local response - read -p "$question" response - response=${response:-$default_option} - response=$(lower_case $response) - - for option in "${options[@]}"; do - if [ "$response" = "$option" ]; then - break 2 - fi - done - done - - echo "$response" -} - -function check_system() { - if [ "$os" != "linux" ] && [ "$os" != "darwin" ]; then - echo "Incompatible operating system detected: $os_human" - exit 1 - fi - - if [ "$arch" != "x86_64" ] && [ "$arch" != "arm64" ]; then - echo "Incompatible architecture detected: $arch" - exit 1 - fi -} - -function ask_github_https_or_ssh() { - options=("https" "ssh") - options_str=$(echo ${options[@]} | sed "s/ /, /") - default_option="https" - question="Do you wish to use https or ssh to clone the source repository? 
($options_str) [${default_option}]: " - clone_mode=$(ask_question "$question" "$default_option" "${options[@]}") - - if [ "$clone_mode" = "https" ]; then - clone_prefix="https://github.com/" - else - clone_prefix="git@github.com:" - fi - -} - -function ask_install_mode() { - options=("pypi" "source") - options_str=$(echo ${options[@]} | sed "s/ /, /") - default_option="pypi" - question="Do you wish to install with pypi (pip) or from source? If you want to just run the observatory platform, pypi is recommended. If you intend to modify or develop the platform, source is recommended. ($options_str) [${default_option}]: " - mode=$(ask_question "$question" "$default_option" "${options[@]}") - - if [ "$mode" = "source" ]; then - pip_install_env_flag="-e" - ask_github_https_or_ssh - fi -} - -function ask_install_observatory_tests() { - options=("y" "n") - default_option="y" - options_str=$(echo ${options[@]} | sed "s/ /, /") - question="Do you wish to install extra developer testing packages? ($options_str) [${default_option}]: " - install_test_extras=$(ask_question "$question" "$default_option" "${options[@]}") - - if [ "$install_test_extras" = "y" ]; then - test_suffix="[tests]" - fi -} - -function ask_install_academic_observatory_workflows() { - options=("y" "n") - default_option="y" - options_str=$(echo ${options[@]} | sed "s/ /, /") - question="Do you wish to install the academic-observatory-workflows? ($options_str) [${default_option}]: " - export install_ao_workflows=$(ask_question "$question" "$default_option" "${options[@]}") -} - -function ask_install_oaebu_workflows() { - options=("y" "n") - default_option="y" - options_str=$(echo ${options[@]} | sed "s/ /, /") - question="Do you wish to install the oaebu-workflows? ($options_str) [${default_option}]: " - export install_oaebu_workflows=$(ask_question "$question" "$default_option" "${options[@]}") -} - -function ask_config_type() { - options=("local" "terraform") - default_option="local" - options_str=$(echo ${options[@]} | sed "s/ /, /") - question="Do you want to use a local observatory platform configuration or use Terraform? ($options_str) [${default_option}]: " - config_type=$(ask_question "$question" "$default_option" "${options[@]}") -} - -function ask_config_observatory_base() { - options=("y" "n") - default_option="y" - options_str=$(echo ${options[@]} | sed "s/ /, /") - question="Do you want to configure Observatory platform basic settings during config file generation? Note that if you do not configure it now, you need to configure all the sections tagged [Required] later on by editing the config.yaml or config-terraform.yaml file. 
($options_str) [${default_option}]: " - config_observatory_base=$(ask_question "$question" "$default_option" "${options[@]}") -} - -function ask_config_path_custom() { - question="Path to save generated config file to [leave blank for default config path]: " - default_option="" - read -p "$question" config_path - config_path=${config_path:-$default_option} -} - -function configure_install_options() { - echo "================================" - echo "Configuring installation options" - echo "================================" - - while :; do - # Configure options - install_oapi="y" - - ask_install_mode - ask_install_observatory_tests - ask_install_academic_observatory_workflows - ask_install_oaebu_workflows - ask_config_type - ask_config_observatory_base - - echo -e "\n" - - echo "==========================================================" - echo -e "Installation configuration summary:" - echo "----------------------------------------------------------" - echo "Operating system: $os_human, architecture: $arch" - echo "Install Observatory Platform: y" - echo "Installation type: $mode" - echo "Install extra developer testing packages: $install_test_extras" - echo "Install Observatory API: $install_oapi" - echo "Install Academic Observatory Workflows: $install_ao_workflows" - echo "Install OAEBU Workflows: $install_oaebu_workflows" - echo "Observatory type: $config_type" - echo "Configure settings during config file generation step: $config_observatory_base" - echo "" - echo "==========================================================" - echo -e "\n" - - local correct="" - while [ "$correct" != "y" ] && [ "$correct" != "n" ]; do - read -p "Are these options correct? (y, n) [y]: " correct - correct=${correct:-Y} - correct=$(lower_case $correct) - done - - if [ "$correct" = "y" ]; then - break - fi - - echo "Asking configuration questions again. 
If you wish to exit the installation script, press Ctrl+C" - done -} - -function install_google_cloud_sdk() { - echo "---------------------------" - echo "Installing Google Cloud SDK" - echo "---------------------------" - - local url="https://dl.google.com/dl/cloudsdk/channels/rapid/downloads/google-cloud-sdk-${google_cloud_sdk_version}-${os}-${gcloud_sdk_arch}.tar.gz" - sudo curl -L $url -o /usr/local/bin/google-cloud-sdk.tar.gz - sudo rm -rf /usr/local/bin/google-cloud-sdk - sudo tar -xzvf /usr/local/bin/google-cloud-sdk.tar.gz -C /usr/local/bin - sudo rm /usr/local/bin/google-cloud-sdk.tar.gz - sudo chmod +x /usr/local/bin/google-cloud-sdk - sudo /usr/local/bin/google-cloud-sdk/install.sh -} - -function install_terraform_deps() { - ## Package versions to install - - terraform_version="1.5.5" - packer_version="1.9.2" - google_cloud_sdk_version="330.0.0" - - terraform_arch=$arch - if [ "$arch" = "x86_64" ]; then - terraform_arch="amd64" - fi - - gcloud_sdk_arch=$arch - if [ "$arch" = "arm64" ]; then - gcloud_sdk_arch="arm" - fi - - echo "=================================" - echo "Installing Terraform dependencies" - echo "=================================" - - echo "-----------------" - echo "Installing Packer" - echo "-----------------" - - local url="https://releases.hashicorp.com/packer/${packer_version}/packer_${packer_version}_${os}_${terraform_arch}.zip" - sudo rm -f /usr/local/bin/packer - sudo curl -L $url -o /usr/local/bin/packer.zip - sudo unzip /usr/local/bin/packer.zip -d /usr/local/bin/ - sudo chmod +x /usr/local/bin/packer - sudo rm /usr/local/bin/packer.zip - - # Install Google Cloud SDK - install_google_cloud_sdk - - echo "--------------------" - echo "Installing Terraform" - echo "--------------------" - - local url="https://releases.hashicorp.com/terraform/${terraform_version}/terraform_${terraform_version}_${os}_${terraform_arch}.zip" - sudo curl -L $url -o /usr/local/bin/terraform.zip - # When asked to replace, answer 'y' - sudo rm -f /usr/local/bin/terraform - sudo unzip /usr/local/bin/terraform.zip -d /usr/local/bin/ - sudo chmod +x /usr/local/bin/terraform - sudo rm /usr/local/bin/terraform.zip -} - -function install_ubuntu_system_deps() { - echo "=====================================" - echo "Installing Ubuntu system dependencies" - echo "=====================================" - - sudo apt update - sudo apt-get install -y software-properties-common curl git python${python_version} python${python_version}-dev python3-pip python3-virtualenv - - echo "--------------------------" - echo "Creating Python virtualenv" - echo "--------------------------" - - virtualenv -p python${python_version} $venv_observatory_platform - - echo "-----------------" - echo "Installing docker" - echo "-----------------" - - sudo apt-get install -y docker.io docker-compose-plugin - - echo "Adding $(id -nu) to the docker group" - sudo usermod -aG docker $(id -nu) - - if [ "$config_type" = "terraform" ]; then - install_terraform_deps - fi - - echo -e "\n" -} - -function install_macos_system_deps() { - echo "=====================================" - echo "Installing Mac OS system dependencies" - echo "=====================================" - - if [ "$(command -v brew)" = "" ]; then - echo "Installing Homebrew" - /bin/bash -c "$(curl -fsSL https://raw.githubusercontent.com/Homebrew/install/master/install.sh)" - fi - - echo "---------------------" - echo "Installing Python ${python_version}" - echo "---------------------" - - brew install python@${python_version} - echo 'export 
PATH="/usr/local/opt/python@${python_version}/bin:$PATH"' >>~/.bash_profile - - echo "------------------------------" - echo "Installing Python dependencies" - echo "------------------------------" - - pip3 install -U virtualenv - - echo "--------------------------" - echo "Creating Python virtualenv" - echo "--------------------------" - - virtualenv -p /usr/local/opt/python@${python_version}/Frameworks/Python.framework/Versions/${python_version}/bin/python3 $venv_observatory_platform - - echo "-----------------" - echo "Installing Docker" - echo "-----------------" - - brew install --cask docker - - if [ "$config_type" = "terraform" ]; then - install_terraform_deps - fi -} - -function install_system_deps() { - if [ "$os" = "linux" ]; then - install_ubuntu_system_deps - elif [ "$os" = "darwin" ]; then - install_macos_system_deps - fi -} - -function install_observatory_platform() { - echo "--------------------------------------------------" - echo "Updating the virtual environment's Python packages" - echo "--------------------------------------------------" - - pip3 install -U pip virtualenv wheel - - echo "===========================================" - echo "Installing the observatory-platform package" - echo "===========================================" - - if [ "$mode" = "source" ]; then - git clone ${clone_prefix}The-Academic-Observatory/observatory-platform.git - cd observatory-platform - fi - - pip3 install ${pip_install_env_flag} observatory-api${test_suffix} --constraint https://raw.githubusercontent.com/apache/airflow/constraints-${airflow_version}/constraints-no-providers-${python_version}.txt - pip3 install ${pip_install_env_flag} observatory-platform${test_suffix} --constraint https://raw.githubusercontent.com/apache/airflow/constraints-${airflow_version}/constraints-no-providers-${python_version}.txt -} - -function install_observatory_api() { - if [ "$install_oapi" = "n" ]; then - return 0 - fi - - echo "==========================" - echo "Installing Observatory API" - echo "==========================" - -} - -function install_academic_observatory_workflows() { - if [ "$install_ao_workflows" = "n" ]; then - return 0 - fi - - echo "=========================================" - echo "Installing academic observatory workflows" - echo "=========================================" - - local prefix="" - - if [ "$mode" = "source" ]; then - mkdir -p workflows - cd workflows - git clone ${clone_prefix}The-Academic-Observatory/academic-observatory-workflows.git - cd .. - prefix="workflows/" - fi - - pip3 install ${pip_install_env_flag} ${prefix}academic-observatory-workflows${test_suffix} --constraint https://raw.githubusercontent.com/apache/airflow/constraints-${airflow_version}/constraints-no-providers-${python_version}.txt -} - -function install_oaebu_workflows() { - if [ "$install_oaebu_workflows" = "n" ]; then - return 0 - fi - - echo "==========================" - echo "Installing oaebu workflows" - echo "==========================" - - local prefix="" - - if [ "$mode" = "source" ]; then - mkdir -p workflows - cd workflows - git clone ${clone_prefix}The-Academic-Observatory/oaebu-workflows.git - cd .. 
- prefix="workflows/" - fi - - pip3 install ${pip_install_env_flag} ${prefix}oaebu-workflows${test_suffix} --constraint https://raw.githubusercontent.com/apache/airflow/constraints-${airflow_version}/constraints-no-providers-${python_version}.txt -} - -function generate_observatory_config() { - ask_config_path_custom - - local interactive="" - local ao_wf="" - local oaebu_wf="" - local oapi="" - local config_path_arg="" - local editable="" - - if [ "$config_observatory_base" = "y" ]; then - interactive="--interactive" - fi - - if [ "$install_ao_workflows" = "y" ]; then - ao_wf="--ao-wf" - fi - - if [ "$install_oaebu_workflows" = "y" ]; then - oaebu_wf="--oaebu-wf" - fi - - if [ "$install_oapi" = "y" ]; then - oapi="--oapi" - fi - - if [ "$config_path" != "" ]; then - config_path_arg="--config-path $config_path" - fi - - if [ "$mode" = "source" ]; then - editable="--editable" - fi - - echo "=============================" - echo "Generating Observatory config" - echo "=============================" - - observatory generate config $config_path_arg $editable $interactive $ao_wf $oaebu_wf $oapi $config_type -} - -#### ENTRY POINT #### - -echo "==================================================================================================================================" -echo "Installing Academic Observatory Platform. You may be prompted at some stages to enter in a password to install system dependencies" -echo "==================================================================================================================================" - -set_os_arch -check_system -configure_install_options -install_system_deps - -source $venv_observatory_platform/bin/activate - -install_observatory_platform -install_academic_observatory_workflows -install_oaebu_workflows - -if [ "$mode" = "source" ]; then - cd .. -fi - -generate_observatory_config - -deactivate - -echo "==================================================================================================================================" -echo "Installation complete." -echo "Please restart your computer for the Docker installation to take effect." -echo -e "You can start the observatory platform after the restart by first activating the Python virtual environment with:\n source ${PWD}/${venv_observatory_platform}/bin/activate" -echo -e "Once activated, you can start the observatory with: observatory platform start" diff --git a/observatory-api/.dockerignore b/observatory-api/.dockerignore deleted file mode 100644 index a66c47ad1..000000000 --- a/observatory-api/.dockerignore +++ /dev/null @@ -1 +0,0 @@ -*.egg-info \ No newline at end of file diff --git a/observatory-api/.gitignore b/observatory-api/.gitignore deleted file mode 100644 index 24a1d9c09..000000000 --- a/observatory-api/.gitignore +++ /dev/null @@ -1 +0,0 @@ -openapi.yaml \ No newline at end of file diff --git a/observatory-api/.openapi-generator-ignore b/observatory-api/.openapi-generator-ignore deleted file mode 100644 index 461c5e750..000000000 --- a/observatory-api/.openapi-generator-ignore +++ /dev/null @@ -1,28 +0,0 @@ -# OpenAPI Generator Ignore -# Generated by openapi-generator https://github.com/openapitools/openapi-generator - -# Use this file to prevent files from being overwritten by the generator. -# The patterns follow closely to .gitignore or .dockerignore. - -# Stop init file being output -observatory/__init__.py -observatory/api/client/apis/ -observatory/api/client/models/ - -# As an example, the C# client generator defines ApiClient.cs. 
-# You can make changes and tell OpenAPI Generator to ignore just this file by uncommenting the following line: -#ApiClient.cs - -# You can match any string of characters against a directory, file or extension with a single asterisk (*): -#foo/*/qux -# The above matches foo/bar/qux and foo/baz/qux, but not foo/bar/baz/qux - -# You can recursively match patterns against a directory, file or extension with a double asterisk (**): -#foo/**/qux -# This matches foo/bar/qux, foo/baz/qux, and foo/bar/baz/qux - -# You can also negate patterns with an exclamation (!). -# For example, you can ignore all files in a docs folder with the file extension .md: -#docs/*.md -# Then explicitly reverse the ignore rule for a single file: -#!docs/README.md diff --git a/observatory-api/.openapi-generator/FILES b/observatory-api/.openapi-generator/FILES deleted file mode 100644 index bae3c2d89..000000000 --- a/observatory-api/.openapi-generator/FILES +++ /dev/null @@ -1,17 +0,0 @@ -observatory/api/__init__.py -observatory/api/client/__init__.py -observatory/api/client/api/__init__.py -observatory/api/client/api/observatory_api.py -observatory/api/client/api_client.py -observatory/api/client/configuration.py -observatory/api/client/docs/DatasetRelease.md -observatory/api/client/docs/ObservatoryApi.md -observatory/api/client/exceptions.py -observatory/api/client/model/__init__.py -observatory/api/client/model/dataset_release.py -observatory/api/client/model_utils.py -observatory/api/client/rest.py -observatory/api/client/test/__init__.py -observatory/api/client/test/test_dataset_release.py -observatory/api/client/test/test_observatory_api.py -observatory/api/client_README.md diff --git a/observatory-api/.openapi-generator/VERSION b/observatory-api/.openapi-generator/VERSION deleted file mode 100644 index 358e78e60..000000000 --- a/observatory-api/.openapi-generator/VERSION +++ /dev/null @@ -1 +0,0 @@ -6.1.0 \ No newline at end of file diff --git a/observatory-api/README.md b/observatory-api/README.md deleted file mode 100644 index 5835b0b21..000000000 --- a/observatory-api/README.md +++ /dev/null @@ -1 +0,0 @@ -## Observatory API \ No newline at end of file diff --git a/observatory-api/api-config.yaml b/observatory-api/api-config.yaml deleted file mode 100644 index 198e09ce5..000000000 --- a/observatory-api/api-config.yaml +++ /dev/null @@ -1,11 +0,0 @@ -globalProperties: - apiTests: "true" - modelTests: "true" -additionalProperties: - generateSourceCodeOnly: "true" - hideGenerationTimestamp: "true" - library: urllib3 - packageName: observatory.api.client - projectName: client - pythonAttrNoneIfUnset: "true" - useNose: "false" \ No newline at end of file diff --git a/observatory-api/observatory/api/cli.py b/observatory-api/observatory/api/cli.py deleted file mode 100644 index 140ade733..000000000 --- a/observatory-api/observatory/api/cli.py +++ /dev/null @@ -1,39 +0,0 @@ -import click -from observatory.api.server.openapi_renderer import OpenApiRenderer - - -@click.group() -def cli(): - """The Observatory API command line tool. - - COMMAND: the commands to run include:\n - - generate-openapi-spec: generate an OpenAPI specification for the Observatory API.\n - """ - - pass - - -@cli.command() -@click.argument("template-file", type=click.Path(exists=True, file_okay=True, dir_okay=False)) -@click.argument("output-file", type=click.Path(exists=False, file_okay=True, dir_okay=False)) -@click.option( - "--api-client", is_flag=True, default=False, help="Generate OpenAPI config for OpenAPI client generation." 
-) -def generate_openapi_spec(template_file, output_file, api_client): - """Generate an OpenAPI specification for the Observatory API.\n - - TEMPLATE_FILE: the type of config file to generate. - OUTPUT_FILE: the type of config file to generate. - """ - - # Render file - renderer = OpenApiRenderer(template_file, api_client=api_client) - render = renderer.render() - - # Save file - with open(output_file, mode="w") as f: - f.write(render) - - -if __name__ == "__main__": - cli() diff --git a/observatory-api/observatory/api/client/__init__.py b/observatory-api/observatory/api/client/__init__.py deleted file mode 100644 index cc6bd4e13..000000000 --- a/observatory-api/observatory/api/client/__init__.py +++ /dev/null @@ -1,28 +0,0 @@ -# flake8: noqa - -""" - Observatory API - - The REST API for managing and accessing data from the Observatory Platform. # noqa: E501 - - The version of the OpenAPI document: 1.0.0 - Contact: agent@observatory.academy - Generated by: https://openapi-generator.tech -""" - - -__version__ = "1.0.0" - -# import ApiClient -from observatory.api.client.api_client import ApiClient - -# import Configuration -from observatory.api.client.configuration import Configuration - -# import exceptions -from observatory.api.client.exceptions import OpenApiException -from observatory.api.client.exceptions import ApiAttributeError -from observatory.api.client.exceptions import ApiTypeError -from observatory.api.client.exceptions import ApiValueError -from observatory.api.client.exceptions import ApiKeyError -from observatory.api.client.exceptions import ApiException diff --git a/observatory-api/observatory/api/client/api/__init__.py b/observatory-api/observatory/api/client/api/__init__.py deleted file mode 100644 index 6abd54036..000000000 --- a/observatory-api/observatory/api/client/api/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -# do not import all apis into this module because that uses a lot of memory and stack frames -# if you need the ability to import all apis from one package, import them with -# from observatory.api.client.apis import ObservatoryApi diff --git a/observatory-api/observatory/api/client/api/observatory_api.py b/observatory-api/observatory/api/client/api/observatory_api.py deleted file mode 100644 index bce4186c6..000000000 --- a/observatory-api/observatory/api/client/api/observatory_api.py +++ /dev/null @@ -1,709 +0,0 @@ -""" - Observatory API - - The REST API for managing and accessing data from the Observatory Platform. # noqa: E501 - - The version of the OpenAPI document: 1.0.0 - Contact: agent@observatory.academy - Generated by: https://openapi-generator.tech -""" - - -import re # noqa: F401 -import sys # noqa: F401 - -from observatory.api.client.api_client import ApiClient, Endpoint as _Endpoint -from observatory.api.client.model_utils import ( # noqa: F401 - check_allowed_values, - check_validations, - date, - datetime, - file_type, - none_type, - validate_and_convert_types -) -from observatory.api.client.model.dataset_release import DatasetRelease - - -class ObservatoryApi(object): - """NOTE: This class is auto generated by OpenAPI Generator - Ref: https://openapi-generator.tech - - Do not edit the class manually. 
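A minimal sketch of how the generate-openapi-spec command above could be driven in-process, for example from a test, using Click's CliRunner; the template path is illustrative and assumes the server's openapi.yaml.jinja2 template is used as input:

from click.testing import CliRunner

from observatory.api.cli import cli

runner = CliRunner()
# Render the Jinja2 template into an OpenAPI spec configured for client generation.
result = runner.invoke(
    cli,
    ["generate-openapi-spec", "observatory/api/server/openapi.yaml.jinja2", "openapi.yaml", "--api-client"],
)
assert result.exit_code == 0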
- """ - - def __init__(self, api_client=None): - if api_client is None: - api_client = ApiClient() - self.api_client = api_client - self.delete_dataset_release_endpoint = _Endpoint( - settings={ - 'response_type': None, - 'auth': [ - 'api_key' - ], - 'endpoint_path': '/v1/dataset_release', - 'operation_id': 'delete_dataset_release', - 'http_method': 'DELETE', - 'servers': None, - }, - params_map={ - 'all': [ - 'id', - ], - 'required': [ - 'id', - ], - 'nullable': [ - ], - 'enum': [ - ], - 'validation': [ - ] - }, - root_map={ - 'validations': { - }, - 'allowed_values': { - }, - 'openapi_types': { - 'id': - (int,), - }, - 'attribute_map': { - 'id': 'id', - }, - 'location_map': { - 'id': 'query', - }, - 'collection_format_map': { - } - }, - headers_map={ - 'accept': [], - 'content_type': [], - }, - api_client=api_client - ) - self.get_dataset_release_endpoint = _Endpoint( - settings={ - 'response_type': (DatasetRelease,), - 'auth': [ - 'api_key' - ], - 'endpoint_path': '/v1/dataset_release', - 'operation_id': 'get_dataset_release', - 'http_method': 'GET', - 'servers': None, - }, - params_map={ - 'all': [ - 'id', - ], - 'required': [ - 'id', - ], - 'nullable': [ - ], - 'enum': [ - ], - 'validation': [ - ] - }, - root_map={ - 'validations': { - }, - 'allowed_values': { - }, - 'openapi_types': { - 'id': - (int,), - }, - 'attribute_map': { - 'id': 'id', - }, - 'location_map': { - 'id': 'query', - }, - 'collection_format_map': { - } - }, - headers_map={ - 'accept': [ - 'application/json' - ], - 'content_type': [], - }, - api_client=api_client - ) - self.get_dataset_releases_endpoint = _Endpoint( - settings={ - 'response_type': ([DatasetRelease],), - 'auth': [ - 'api_key' - ], - 'endpoint_path': '/v1/dataset_releases', - 'operation_id': 'get_dataset_releases', - 'http_method': 'GET', - 'servers': None, - }, - params_map={ - 'all': [ - 'dag_id', - 'dataset_id', - ], - 'required': [], - 'nullable': [ - ], - 'enum': [ - ], - 'validation': [ - ] - }, - root_map={ - 'validations': { - }, - 'allowed_values': { - }, - 'openapi_types': { - 'dag_id': - (str,), - 'dataset_id': - (str,), - }, - 'attribute_map': { - 'dag_id': 'dag_id', - 'dataset_id': 'dataset_id', - }, - 'location_map': { - 'dag_id': 'query', - 'dataset_id': 'query', - }, - 'collection_format_map': { - } - }, - headers_map={ - 'accept': [ - 'application/json' - ], - 'content_type': [], - }, - api_client=api_client - ) - self.post_dataset_release_endpoint = _Endpoint( - settings={ - 'response_type': (DatasetRelease,), - 'auth': [ - 'api_key' - ], - 'endpoint_path': '/v1/dataset_release', - 'operation_id': 'post_dataset_release', - 'http_method': 'POST', - 'servers': None, - }, - params_map={ - 'all': [ - 'body', - ], - 'required': [ - 'body', - ], - 'nullable': [ - ], - 'enum': [ - ], - 'validation': [ - ] - }, - root_map={ - 'validations': { - }, - 'allowed_values': { - }, - 'openapi_types': { - 'body': - (DatasetRelease,), - }, - 'attribute_map': { - }, - 'location_map': { - 'body': 'body', - }, - 'collection_format_map': { - } - }, - headers_map={ - 'accept': [ - 'application/json' - ], - 'content_type': [ - 'application/json' - ] - }, - api_client=api_client - ) - self.put_dataset_release_endpoint = _Endpoint( - settings={ - 'response_type': (DatasetRelease,), - 'auth': [ - 'api_key' - ], - 'endpoint_path': '/v1/dataset_release', - 'operation_id': 'put_dataset_release', - 'http_method': 'PUT', - 'servers': None, - }, - params_map={ - 'all': [ - 'body', - ], - 'required': [ - 'body', - ], - 'nullable': [ - ], - 'enum': [ - ], - 
'validation': [ - ] - }, - root_map={ - 'validations': { - }, - 'allowed_values': { - }, - 'openapi_types': { - 'body': - (DatasetRelease,), - }, - 'attribute_map': { - }, - 'location_map': { - 'body': 'body', - }, - 'collection_format_map': { - } - }, - headers_map={ - 'accept': [ - 'application/json' - ], - 'content_type': [ - 'application/json' - ] - }, - api_client=api_client - ) - - def delete_dataset_release( - self, - id, - **kwargs - ): - """delete a DatasetRelease # noqa: E501 - - Delete a DatasetRelease by passing it's id. # noqa: E501 - This method makes a synchronous HTTP request by default. To make an - asynchronous HTTP request, please pass async_req=True - - >>> thread = api.delete_dataset_release(id, async_req=True) - >>> result = thread.get() - - Args: - id (int): DatasetRelease id - - Keyword Args: - _return_http_data_only (bool): response data without head status - code and headers. Default is True. - _preload_content (bool): if False, the urllib3.HTTPResponse object - will be returned without reading/decoding response data. - Default is True. - _request_timeout (int/float/tuple): timeout setting for this request. If - one number provided, it will be total request timeout. It can also - be a pair (tuple) of (connection, read) timeouts. - Default is None. - _check_input_type (bool): specifies if type checking - should be done one the data sent to the server. - Default is True. - _check_return_type (bool): specifies if type checking - should be done one the data received from the server. - Default is True. - _spec_property_naming (bool): True if the variable names in the input data - are serialized names, as specified in the OpenAPI document. - False if the variable names in the input data - are pythonic names, e.g. snake case (default) - _content_type (str/None): force body content-type. - Default is None and content-type will be predicted by allowed - content-types and body. - _host_index (int/None): specifies the index of the server - that we want to use. - Default is read from the configuration. - _request_auths (list): set to override the auth_settings for an a single - request; this effectively ignores the authentication - in the spec for a single request. - Default is None - async_req (bool): execute request asynchronously - - Returns: - None - If the method is called asynchronously, returns the request - thread. - """ - kwargs['async_req'] = kwargs.get( - 'async_req', False - ) - kwargs['_return_http_data_only'] = kwargs.get( - '_return_http_data_only', True - ) - kwargs['_preload_content'] = kwargs.get( - '_preload_content', True - ) - kwargs['_request_timeout'] = kwargs.get( - '_request_timeout', None - ) - kwargs['_check_input_type'] = kwargs.get( - '_check_input_type', True - ) - kwargs['_check_return_type'] = kwargs.get( - '_check_return_type', True - ) - kwargs['_spec_property_naming'] = kwargs.get( - '_spec_property_naming', False - ) - kwargs['_content_type'] = kwargs.get( - '_content_type') - kwargs['_host_index'] = kwargs.get('_host_index') - kwargs['_request_auths'] = kwargs.get('_request_auths', None) - kwargs['id'] = \ - id - return self.delete_dataset_release_endpoint.call_with_http_info(**kwargs) - - def get_dataset_release( - self, - id, - **kwargs - ): - """get a DatasetRelease # noqa: E501 - - Get the details of a DatasetRelease by passing it's id. # noqa: E501 - This method makes a synchronous HTTP request by default. 
To make an - asynchronous HTTP request, please pass async_req=True - - >>> thread = api.get_dataset_release(id, async_req=True) - >>> result = thread.get() - - Args: - id (int): DatasetRelease id - - Keyword Args: - _return_http_data_only (bool): response data without head status - code and headers. Default is True. - _preload_content (bool): if False, the urllib3.HTTPResponse object - will be returned without reading/decoding response data. - Default is True. - _request_timeout (int/float/tuple): timeout setting for this request. If - one number provided, it will be total request timeout. It can also - be a pair (tuple) of (connection, read) timeouts. - Default is None. - _check_input_type (bool): specifies if type checking - should be done one the data sent to the server. - Default is True. - _check_return_type (bool): specifies if type checking - should be done one the data received from the server. - Default is True. - _spec_property_naming (bool): True if the variable names in the input data - are serialized names, as specified in the OpenAPI document. - False if the variable names in the input data - are pythonic names, e.g. snake case (default) - _content_type (str/None): force body content-type. - Default is None and content-type will be predicted by allowed - content-types and body. - _host_index (int/None): specifies the index of the server - that we want to use. - Default is read from the configuration. - _request_auths (list): set to override the auth_settings for an a single - request; this effectively ignores the authentication - in the spec for a single request. - Default is None - async_req (bool): execute request asynchronously - - Returns: - DatasetRelease - If the method is called asynchronously, returns the request - thread. - """ - kwargs['async_req'] = kwargs.get( - 'async_req', False - ) - kwargs['_return_http_data_only'] = kwargs.get( - '_return_http_data_only', True - ) - kwargs['_preload_content'] = kwargs.get( - '_preload_content', True - ) - kwargs['_request_timeout'] = kwargs.get( - '_request_timeout', None - ) - kwargs['_check_input_type'] = kwargs.get( - '_check_input_type', True - ) - kwargs['_check_return_type'] = kwargs.get( - '_check_return_type', True - ) - kwargs['_spec_property_naming'] = kwargs.get( - '_spec_property_naming', False - ) - kwargs['_content_type'] = kwargs.get( - '_content_type') - kwargs['_host_index'] = kwargs.get('_host_index') - kwargs['_request_auths'] = kwargs.get('_request_auths', None) - kwargs['id'] = \ - id - return self.get_dataset_release_endpoint.call_with_http_info(**kwargs) - - def get_dataset_releases( - self, - **kwargs - ): - """Get a list of DatasetRelease objects # noqa: E501 - - Get a list of DatasetRelease objects # noqa: E501 - This method makes a synchronous HTTP request by default. To make an - asynchronous HTTP request, please pass async_req=True - - >>> thread = api.get_dataset_releases(async_req=True) - >>> result = thread.get() - - - Keyword Args: - dag_id (str): the dag_id to fetch release info for. [optional] - dataset_id (str): the dataset_id to fetch release info for. [optional] - _return_http_data_only (bool): response data without head status - code and headers. Default is True. - _preload_content (bool): if False, the urllib3.HTTPResponse object - will be returned without reading/decoding response data. - Default is True. - _request_timeout (int/float/tuple): timeout setting for this request. If - one number provided, it will be total request timeout. 
It can also - be a pair (tuple) of (connection, read) timeouts. - Default is None. - _check_input_type (bool): specifies if type checking - should be done one the data sent to the server. - Default is True. - _check_return_type (bool): specifies if type checking - should be done one the data received from the server. - Default is True. - _spec_property_naming (bool): True if the variable names in the input data - are serialized names, as specified in the OpenAPI document. - False if the variable names in the input data - are pythonic names, e.g. snake case (default) - _content_type (str/None): force body content-type. - Default is None and content-type will be predicted by allowed - content-types and body. - _host_index (int/None): specifies the index of the server - that we want to use. - Default is read from the configuration. - _request_auths (list): set to override the auth_settings for an a single - request; this effectively ignores the authentication - in the spec for a single request. - Default is None - async_req (bool): execute request asynchronously - - Returns: - [DatasetRelease] - If the method is called asynchronously, returns the request - thread. - """ - kwargs['async_req'] = kwargs.get( - 'async_req', False - ) - kwargs['_return_http_data_only'] = kwargs.get( - '_return_http_data_only', True - ) - kwargs['_preload_content'] = kwargs.get( - '_preload_content', True - ) - kwargs['_request_timeout'] = kwargs.get( - '_request_timeout', None - ) - kwargs['_check_input_type'] = kwargs.get( - '_check_input_type', True - ) - kwargs['_check_return_type'] = kwargs.get( - '_check_return_type', True - ) - kwargs['_spec_property_naming'] = kwargs.get( - '_spec_property_naming', False - ) - kwargs['_content_type'] = kwargs.get( - '_content_type') - kwargs['_host_index'] = kwargs.get('_host_index') - kwargs['_request_auths'] = kwargs.get('_request_auths', None) - return self.get_dataset_releases_endpoint.call_with_http_info(**kwargs) - - def post_dataset_release( - self, - body, - **kwargs - ): - """create a DatasetRelease # noqa: E501 - - Create a DatasetRelease by passing a DatasetRelease object, without an id. # noqa: E501 - This method makes a synchronous HTTP request by default. To make an - asynchronous HTTP request, please pass async_req=True - - >>> thread = api.post_dataset_release(body, async_req=True) - >>> result = thread.get() - - Args: - body (DatasetRelease): DatasetRelease to create - - Keyword Args: - _return_http_data_only (bool): response data without head status - code and headers. Default is True. - _preload_content (bool): if False, the urllib3.HTTPResponse object - will be returned without reading/decoding response data. - Default is True. - _request_timeout (int/float/tuple): timeout setting for this request. If - one number provided, it will be total request timeout. It can also - be a pair (tuple) of (connection, read) timeouts. - Default is None. - _check_input_type (bool): specifies if type checking - should be done one the data sent to the server. - Default is True. - _check_return_type (bool): specifies if type checking - should be done one the data received from the server. - Default is True. - _spec_property_naming (bool): True if the variable names in the input data - are serialized names, as specified in the OpenAPI document. - False if the variable names in the input data - are pythonic names, e.g. snake case (default) - _content_type (str/None): force body content-type. 
- Default is None and content-type will be predicted by allowed - content-types and body. - _host_index (int/None): specifies the index of the server - that we want to use. - Default is read from the configuration. - _request_auths (list): set to override the auth_settings for an a single - request; this effectively ignores the authentication - in the spec for a single request. - Default is None - async_req (bool): execute request asynchronously - - Returns: - DatasetRelease - If the method is called asynchronously, returns the request - thread. - """ - kwargs['async_req'] = kwargs.get( - 'async_req', False - ) - kwargs['_return_http_data_only'] = kwargs.get( - '_return_http_data_only', True - ) - kwargs['_preload_content'] = kwargs.get( - '_preload_content', True - ) - kwargs['_request_timeout'] = kwargs.get( - '_request_timeout', None - ) - kwargs['_check_input_type'] = kwargs.get( - '_check_input_type', True - ) - kwargs['_check_return_type'] = kwargs.get( - '_check_return_type', True - ) - kwargs['_spec_property_naming'] = kwargs.get( - '_spec_property_naming', False - ) - kwargs['_content_type'] = kwargs.get( - '_content_type') - kwargs['_host_index'] = kwargs.get('_host_index') - kwargs['_request_auths'] = kwargs.get('_request_auths', None) - kwargs['body'] = \ - body - return self.post_dataset_release_endpoint.call_with_http_info(**kwargs) - - def put_dataset_release( - self, - body, - **kwargs - ): - """create or update a DatasetRelease # noqa: E501 - - Create a DatasetRelease by passing a DatasetRelease object, without an id. Update an existing DatasetRelease by passing a DatasetRelease object with an id. # noqa: E501 - This method makes a synchronous HTTP request by default. To make an - asynchronous HTTP request, please pass async_req=True - - >>> thread = api.put_dataset_release(body, async_req=True) - >>> result = thread.get() - - Args: - body (DatasetRelease): DatasetRelease to create or update - - Keyword Args: - _return_http_data_only (bool): response data without head status - code and headers. Default is True. - _preload_content (bool): if False, the urllib3.HTTPResponse object - will be returned without reading/decoding response data. - Default is True. - _request_timeout (int/float/tuple): timeout setting for this request. If - one number provided, it will be total request timeout. It can also - be a pair (tuple) of (connection, read) timeouts. - Default is None. - _check_input_type (bool): specifies if type checking - should be done one the data sent to the server. - Default is True. - _check_return_type (bool): specifies if type checking - should be done one the data received from the server. - Default is True. - _spec_property_naming (bool): True if the variable names in the input data - are serialized names, as specified in the OpenAPI document. - False if the variable names in the input data - are pythonic names, e.g. snake case (default) - _content_type (str/None): force body content-type. - Default is None and content-type will be predicted by allowed - content-types and body. - _host_index (int/None): specifies the index of the server - that we want to use. - Default is read from the configuration. - _request_auths (list): set to override the auth_settings for an a single - request; this effectively ignores the authentication - in the spec for a single request. - Default is None - async_req (bool): execute request asynchronously - - Returns: - DatasetRelease - If the method is called asynchronously, returns the request - thread. 
- """ - kwargs['async_req'] = kwargs.get( - 'async_req', False - ) - kwargs['_return_http_data_only'] = kwargs.get( - '_return_http_data_only', True - ) - kwargs['_preload_content'] = kwargs.get( - '_preload_content', True - ) - kwargs['_request_timeout'] = kwargs.get( - '_request_timeout', None - ) - kwargs['_check_input_type'] = kwargs.get( - '_check_input_type', True - ) - kwargs['_check_return_type'] = kwargs.get( - '_check_return_type', True - ) - kwargs['_spec_property_naming'] = kwargs.get( - '_spec_property_naming', False - ) - kwargs['_content_type'] = kwargs.get( - '_content_type') - kwargs['_host_index'] = kwargs.get('_host_index') - kwargs['_request_auths'] = kwargs.get('_request_auths', None) - kwargs['body'] = \ - body - return self.put_dataset_release_endpoint.call_with_http_info(**kwargs) - diff --git a/observatory-api/observatory/api/client/api_client.py b/observatory-api/observatory/api/client/api_client.py deleted file mode 100644 index 1f78b48be..000000000 --- a/observatory-api/observatory/api/client/api_client.py +++ /dev/null @@ -1,897 +0,0 @@ -""" - Observatory API - - The REST API for managing and accessing data from the Observatory Platform. # noqa: E501 - - The version of the OpenAPI document: 1.0.0 - Contact: agent@observatory.academy - Generated by: https://openapi-generator.tech -""" - - -import json -import atexit -import mimetypes -from multiprocessing.pool import ThreadPool -import io -import os -import re -import typing -from urllib.parse import quote -from urllib3.fields import RequestField - - -from observatory.api.client import rest -from observatory.api.client.configuration import Configuration -from observatory.api.client.exceptions import ApiTypeError, ApiValueError, ApiException -from observatory.api.client.model_utils import ( - ModelNormal, - ModelSimple, - ModelComposed, - check_allowed_values, - check_validations, - date, - datetime, - deserialize_file, - file_type, - model_to_dict, - none_type, - validate_and_convert_types -) - - -class ApiClient(object): - """Generic API client for OpenAPI client library builds. - - OpenAPI generic API client. This client handles the client- - server communication, and is invariant across implementations. Specifics of - the methods and models for each application are generated from the OpenAPI - templates. - - NOTE: This class is auto generated by OpenAPI Generator. - Ref: https://openapi-generator.tech - Do not edit the class manually. - - :param configuration: .Configuration object for this client - :param header_name: a header to pass when making calls to the API. - :param header_value: a header value to pass when making calls to - the API. - :param cookie: a cookie to include in the header when making calls - to the API - :param pool_threads: The number of threads to use for async requests - to the API. More threads means more concurrent API requests. - """ - - _pool = None - - def __init__(self, configuration=None, header_name=None, header_value=None, - cookie=None, pool_threads=1): - if configuration is None: - configuration = Configuration.get_default_copy() - self.configuration = configuration - self.pool_threads = pool_threads - - self.rest_client = rest.RESTClientObject(configuration) - self.default_headers = {} - if header_name is not None: - self.default_headers[header_name] = header_value - self.cookie = cookie - # Set default User-Agent. 
- self.user_agent = 'OpenAPI-Generator/1.0.0/python' - - def __enter__(self): - return self - - def __exit__(self, exc_type, exc_value, traceback): - self.close() - - def close(self): - if self._pool: - self._pool.close() - self._pool.join() - self._pool = None - if hasattr(atexit, 'unregister'): - atexit.unregister(self.close) - - @property - def pool(self): - """Create thread pool on first request - avoids instantiating unused threadpool for blocking clients. - """ - if self._pool is None: - atexit.register(self.close) - self._pool = ThreadPool(self.pool_threads) - return self._pool - - @property - def user_agent(self): - """User agent for this API client""" - return self.default_headers['User-Agent'] - - @user_agent.setter - def user_agent(self, value): - self.default_headers['User-Agent'] = value - - def set_default_header(self, header_name, header_value): - self.default_headers[header_name] = header_value - - def __call_api( - self, - resource_path: str, - method: str, - path_params: typing.Optional[typing.Dict[str, typing.Any]] = None, - query_params: typing.Optional[typing.List[typing.Tuple[str, typing.Any]]] = None, - header_params: typing.Optional[typing.Dict[str, typing.Any]] = None, - body: typing.Optional[typing.Any] = None, - post_params: typing.Optional[typing.List[typing.Tuple[str, typing.Any]]] = None, - files: typing.Optional[typing.Dict[str, typing.List[io.IOBase]]] = None, - response_type: typing.Optional[typing.Tuple[typing.Any]] = None, - auth_settings: typing.Optional[typing.List[str]] = None, - _return_http_data_only: typing.Optional[bool] = None, - collection_formats: typing.Optional[typing.Dict[str, str]] = None, - _preload_content: bool = True, - _request_timeout: typing.Optional[typing.Union[int, float, typing.Tuple]] = None, - _host: typing.Optional[str] = None, - _check_type: typing.Optional[bool] = None, - _content_type: typing.Optional[str] = None, - _request_auths: typing.Optional[typing.List[typing.Dict[str, typing.Any]]] = None - ): - - config = self.configuration - - # header parameters - header_params = header_params or {} - header_params.update(self.default_headers) - if self.cookie: - header_params['Cookie'] = self.cookie - if header_params: - header_params = self.sanitize_for_serialization(header_params) - header_params = dict(self.parameters_to_tuples(header_params, - collection_formats)) - - # path parameters - if path_params: - path_params = self.sanitize_for_serialization(path_params) - path_params = self.parameters_to_tuples(path_params, - collection_formats) - for k, v in path_params: - # specified safe chars, encode everything - resource_path = resource_path.replace( - '{%s}' % k, - quote(str(v), safe=config.safe_chars_for_path_param) - ) - - # query parameters - if query_params: - query_params = self.sanitize_for_serialization(query_params) - query_params = self.parameters_to_tuples(query_params, - collection_formats) - - # post parameters - if post_params or files: - post_params = post_params if post_params else [] - post_params = self.sanitize_for_serialization(post_params) - post_params = self.parameters_to_tuples(post_params, - collection_formats) - post_params.extend(self.files_parameters(files)) - if header_params['Content-Type'].startswith("multipart"): - post_params = self.parameters_to_multipart(post_params, - (dict)) - - # body - if body: - body = self.sanitize_for_serialization(body) - - # auth setting - self.update_params_for_auth(header_params, query_params, - auth_settings, resource_path, method, body, - 
request_auths=_request_auths) - - # request url - if _host is None: - url = self.configuration.host + resource_path - else: - # use server/host defined in path or operation instead - url = _host + resource_path - - try: - # perform request and return response - response_data = self.request( - method, url, query_params=query_params, headers=header_params, - post_params=post_params, body=body, - _preload_content=_preload_content, - _request_timeout=_request_timeout) - except ApiException as e: - e.body = e.body.decode('utf-8') - raise e - - self.last_response = response_data - - return_data = response_data - - if not _preload_content: - return (return_data) - return return_data - - # deserialize response data - if response_type: - if response_type != (file_type,): - encoding = "utf-8" - content_type = response_data.getheader('content-type') - if content_type is not None: - match = re.search(r"charset=([a-zA-Z\-\d]+)[\s\;]?", content_type) - if match: - encoding = match.group(1) - response_data.data = response_data.data.decode(encoding) - - return_data = self.deserialize( - response_data, - response_type, - _check_type - ) - else: - return_data = None - - if _return_http_data_only: - return (return_data) - else: - return (return_data, response_data.status, - response_data.getheaders()) - - def parameters_to_multipart(self, params, collection_types): - """Get parameters as list of tuples, formatting as json if value is collection_types - - :param params: Parameters as list of two-tuples - :param dict collection_types: Parameter collection types - :return: Parameters as list of tuple or urllib3.fields.RequestField - """ - new_params = [] - if collection_types is None: - collection_types = (dict) - for k, v in params.items() if isinstance(params, dict) else params: # noqa: E501 - if isinstance( - v, collection_types): # v is instance of collection_type, formatting as application/json - v = json.dumps(v, ensure_ascii=False).encode("utf-8") - field = RequestField(k, v) - field.make_multipart(content_type="application/json; charset=utf-8") - new_params.append(field) - else: - new_params.append((k, v)) - return new_params - - @classmethod - def sanitize_for_serialization(cls, obj): - """Prepares data for transmission before it is sent with the rest client - If obj is None, return None. - If obj is str, int, long, float, bool, return directly. - If obj is datetime.datetime, datetime.date - convert to string in iso8601 format. - If obj is list, sanitize each element in the list. - If obj is dict, return the dict. - If obj is OpenAPI model, return the properties dict. - If obj is io.IOBase, return the bytes - :param obj: The data to serialize. - :return: The serialized form of data. 
- """ - if isinstance(obj, (ModelNormal, ModelComposed)): - return { - key: cls.sanitize_for_serialization(val) for key, - val in model_to_dict( - obj, - serialize=True).items()} - elif isinstance(obj, io.IOBase): - return cls.get_file_data_and_close_file(obj) - elif isinstance(obj, (str, int, float, none_type, bool)): - return obj - elif isinstance(obj, (datetime, date)): - return obj.isoformat() - elif isinstance(obj, ModelSimple): - return cls.sanitize_for_serialization(obj.value) - elif isinstance(obj, (list, tuple)): - return [cls.sanitize_for_serialization(item) for item in obj] - if isinstance(obj, dict): - return {key: cls.sanitize_for_serialization(val) for key, val in obj.items()} - raise ApiValueError( - 'Unable to prepare type {} for serialization'.format( - obj.__class__.__name__)) - - def deserialize(self, response, response_type, _check_type): - """Deserializes response into an object. - - :param response: RESTResponse object to be deserialized. - :param response_type: For the response, a tuple containing: - valid classes - a list containing valid classes (for list schemas) - a dict containing a tuple of valid classes as the value - Example values: - (str,) - (Pet,) - (float, none_type) - ([int, none_type],) - ({str: (bool, str, int, float, date, datetime, str, none_type)},) - :param _check_type: boolean, whether to check the types of the data - received from the server - :type _check_type: bool - - :return: deserialized object. - """ - # handle file downloading - # save response body into a tmp file and return the instance - if response_type == (file_type,): - content_disposition = response.getheader("Content-Disposition") - return deserialize_file(response.data, self.configuration, - content_disposition=content_disposition) - - # fetch data from response object - try: - received_data = json.loads(response.data) - except ValueError: - received_data = response.data - - # store our data under the key of 'received_data' so users have some - # context if they are deserializing a string and the data type is wrong - deserialized_data = validate_and_convert_types( - received_data, - response_type, - ['received_data'], - True, - _check_type, - configuration=self.configuration - ) - return deserialized_data - - def call_api( - self, - resource_path: str, - method: str, - path_params: typing.Optional[typing.Dict[str, typing.Any]] = None, - query_params: typing.Optional[typing.List[typing.Tuple[str, typing.Any]]] = None, - header_params: typing.Optional[typing.Dict[str, typing.Any]] = None, - body: typing.Optional[typing.Any] = None, - post_params: typing.Optional[typing.List[typing.Tuple[str, typing.Any]]] = None, - files: typing.Optional[typing.Dict[str, typing.List[io.IOBase]]] = None, - response_type: typing.Optional[typing.Tuple[typing.Any]] = None, - auth_settings: typing.Optional[typing.List[str]] = None, - async_req: typing.Optional[bool] = None, - _return_http_data_only: typing.Optional[bool] = None, - collection_formats: typing.Optional[typing.Dict[str, str]] = None, - _preload_content: bool = True, - _request_timeout: typing.Optional[typing.Union[int, float, typing.Tuple]] = None, - _host: typing.Optional[str] = None, - _check_type: typing.Optional[bool] = None, - _request_auths: typing.Optional[typing.List[typing.Dict[str, typing.Any]]] = None - ): - """Makes the HTTP request (synchronous) and returns deserialized data. - - To make an async_req request, set the async_req parameter. - - :param resource_path: Path to method endpoint. - :param method: Method to call. 
- :param path_params: Path parameters in the url. - :param query_params: Query parameters in the url. - :param header_params: Header parameters to be - placed in the request header. - :param body: Request body. - :param post_params dict: Request post form parameters, - for `application/x-www-form-urlencoded`, `multipart/form-data`. - :param auth_settings list: Auth Settings names for the request. - :param response_type: For the response, a tuple containing: - valid classes - a list containing valid classes (for list schemas) - a dict containing a tuple of valid classes as the value - Example values: - (str,) - (Pet,) - (float, none_type) - ([int, none_type],) - ({str: (bool, str, int, float, date, datetime, str, none_type)},) - :param files: key -> field name, value -> a list of open file - objects for `multipart/form-data`. - :type files: dict - :param async_req bool: execute request asynchronously - :type async_req: bool, optional - :param _return_http_data_only: response data without head status code - and headers - :type _return_http_data_only: bool, optional - :param collection_formats: dict of collection formats for path, query, - header, and post parameters. - :type collection_formats: dict, optional - :param _preload_content: if False, the urllib3.HTTPResponse object will - be returned without reading/decoding response - data. Default is True. - :type _preload_content: bool, optional - :param _request_timeout: timeout setting for this request. If one - number provided, it will be total request - timeout. It can also be a pair (tuple) of - (connection, read) timeouts. - :param _check_type: boolean describing if the data back from the server - should have its type checked. - :type _check_type: bool, optional - :param _request_auths: set to override the auth_settings for an a single - request; this effectively ignores the authentication - in the spec for a single request. - :type _request_auths: list, optional - :return: - If async_req parameter is True, - the request will be called asynchronously. - The method will return the request thread. - If parameter async_req is False or missing, - then the method will return the response directly. 
- """ - if not async_req: - return self.__call_api(resource_path, method, - path_params, query_params, header_params, - body, post_params, files, - response_type, auth_settings, - _return_http_data_only, collection_formats, - _preload_content, _request_timeout, _host, - _check_type, _request_auths=_request_auths) - - return self.pool.apply_async(self.__call_api, (resource_path, - method, path_params, - query_params, - header_params, body, - post_params, files, - response_type, - auth_settings, - _return_http_data_only, - collection_formats, - _preload_content, - _request_timeout, - _host, _check_type, None, _request_auths)) - - def request(self, method, url, query_params=None, headers=None, - post_params=None, body=None, _preload_content=True, - _request_timeout=None): - """Makes the HTTP request using RESTClient.""" - if method == "GET": - return self.rest_client.GET(url, - query_params=query_params, - _preload_content=_preload_content, - _request_timeout=_request_timeout, - headers=headers) - elif method == "HEAD": - return self.rest_client.HEAD(url, - query_params=query_params, - _preload_content=_preload_content, - _request_timeout=_request_timeout, - headers=headers) - elif method == "OPTIONS": - return self.rest_client.OPTIONS(url, - query_params=query_params, - headers=headers, - post_params=post_params, - _preload_content=_preload_content, - _request_timeout=_request_timeout, - body=body) - elif method == "POST": - return self.rest_client.POST(url, - query_params=query_params, - headers=headers, - post_params=post_params, - _preload_content=_preload_content, - _request_timeout=_request_timeout, - body=body) - elif method == "PUT": - return self.rest_client.PUT(url, - query_params=query_params, - headers=headers, - post_params=post_params, - _preload_content=_preload_content, - _request_timeout=_request_timeout, - body=body) - elif method == "PATCH": - return self.rest_client.PATCH(url, - query_params=query_params, - headers=headers, - post_params=post_params, - _preload_content=_preload_content, - _request_timeout=_request_timeout, - body=body) - elif method == "DELETE": - return self.rest_client.DELETE(url, - query_params=query_params, - headers=headers, - _preload_content=_preload_content, - _request_timeout=_request_timeout, - body=body) - else: - raise ApiValueError( - "http method must be `GET`, `HEAD`, `OPTIONS`," - " `POST`, `PATCH`, `PUT` or `DELETE`." - ) - - def parameters_to_tuples(self, params, collection_formats): - """Get parameters as list of tuples, formatting collections. 
- - :param params: Parameters as dict or list of two-tuples - :param dict collection_formats: Parameter collection formats - :return: Parameters as list of tuples, collections formatted - """ - new_params = [] - if collection_formats is None: - collection_formats = {} - for k, v in params.items() if isinstance(params, dict) else params: # noqa: E501 - if k in collection_formats: - collection_format = collection_formats[k] - if collection_format == 'multi': - new_params.extend((k, value) for value in v) - else: - if collection_format == 'ssv': - delimiter = ' ' - elif collection_format == 'tsv': - delimiter = '\t' - elif collection_format == 'pipes': - delimiter = '|' - else: # csv is the default - delimiter = ',' - new_params.append( - (k, delimiter.join(str(value) for value in v))) - else: - new_params.append((k, v)) - return new_params - - @staticmethod - def get_file_data_and_close_file(file_instance: io.IOBase) -> bytes: - file_data = file_instance.read() - file_instance.close() - return file_data - - def files_parameters(self, - files: typing.Optional[typing.Dict[str, - typing.List[io.IOBase]]] = None): - """Builds form parameters. - - :param files: None or a dict with key=param_name and - value is a list of open file objects - :return: List of tuples of form parameters with file data - """ - if files is None: - return [] - - params = [] - for param_name, file_instances in files.items(): - if file_instances is None: - # if the file field is nullable, skip None values - continue - for file_instance in file_instances: - if file_instance is None: - # if the file field is nullable, skip None values - continue - if file_instance.closed is True: - raise ApiValueError( - "Cannot read a closed file. The passed in file_type " - "for %s must be open." % param_name - ) - filename = os.path.basename(file_instance.name) - filedata = self.get_file_data_and_close_file(file_instance) - mimetype = (mimetypes.guess_type(filename)[0] or - 'application/octet-stream') - params.append( - tuple([param_name, tuple([filename, filedata, mimetype])])) - - return params - - def select_header_accept(self, accepts): - """Returns `Accept` based on an array of accepts provided. - - :param accepts: List of headers. - :return: Accept (e.g. application/json). - """ - if not accepts: - return - - accepts = [x.lower() for x in accepts] - - if 'application/json' in accepts: - return 'application/json' - else: - return ', '.join(accepts) - - def select_header_content_type(self, content_types, method=None, body=None): - """Returns `Content-Type` based on an array of content_types provided. - - :param content_types: List of content-types. - :param method: http method (e.g. POST, PATCH). - :param body: http body to send. - :return: Content-Type (e.g. application/json). - """ - if not content_types: - return None - - content_types = [x.lower() for x in content_types] - - if (method == 'PATCH' and - 'application/json-patch+json' in content_types and - isinstance(body, list)): - return 'application/json-patch+json' - - if 'application/json' in content_types or '*/*' in content_types: - return 'application/json' - else: - return content_types[0] - - def update_params_for_auth(self, headers, queries, auth_settings, - resource_path, method, body, request_auths=None): - """Updates header and query params based on authentication setting. - - :param headers: Header parameters dict to be updated. - :param queries: Query parameters tuple list to be updated. - :param auth_settings: Authentication setting identifiers list. 
- :param resource_path: A string representation of the HTTP request resource path. - :param method: A string representation of the HTTP request method. - :param body: A object representing the body of the HTTP request. - The object type is the return value of _encoder.default(). - :param request_auths: if set, the provided settings will - override the token in the configuration. - """ - if not auth_settings: - return - - if request_auths: - for auth_setting in request_auths: - self._apply_auth_params( - headers, queries, resource_path, method, body, auth_setting) - return - - for auth in auth_settings: - auth_setting = self.configuration.auth_settings().get(auth) - if auth_setting: - self._apply_auth_params( - headers, queries, resource_path, method, body, auth_setting) - - def _apply_auth_params(self, headers, queries, resource_path, method, body, auth_setting): - if auth_setting['in'] == 'cookie': - headers['Cookie'] = auth_setting['key'] + "=" + auth_setting['value'] - elif auth_setting['in'] == 'header': - if auth_setting['type'] != 'http-signature': - headers[auth_setting['key']] = auth_setting['value'] - elif auth_setting['in'] == 'query': - queries.append((auth_setting['key'], auth_setting['value'])) - else: - raise ApiValueError( - 'Authentication token must be in `query` or `header`' - ) - - -class Endpoint(object): - def __init__(self, settings=None, params_map=None, root_map=None, - headers_map=None, api_client=None, callable=None): - """Creates an endpoint - - Args: - settings (dict): see below key value pairs - 'response_type' (tuple/None): response type - 'auth' (list): a list of auth type keys - 'endpoint_path' (str): the endpoint path - 'operation_id' (str): endpoint string identifier - 'http_method' (str): POST/PUT/PATCH/GET etc - 'servers' (list): list of str servers that this endpoint is at - params_map (dict): see below key value pairs - 'all' (list): list of str endpoint parameter names - 'required' (list): list of required parameter names - 'nullable' (list): list of nullable parameter names - 'enum' (list): list of parameters with enum values - 'validation' (list): list of parameters with validations - root_map - 'validations' (dict): the dict mapping endpoint parameter tuple - paths to their validation dictionaries - 'allowed_values' (dict): the dict mapping endpoint parameter - tuple paths to their allowed_values (enum) dictionaries - 'openapi_types' (dict): param_name to openapi type - 'attribute_map' (dict): param_name to camelCase name - 'location_map' (dict): param_name to 'body', 'file', 'form', - 'header', 'path', 'query' - collection_format_map (dict): param_name to `csv` etc. 
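A minimal sketch of the per-request authentication override accepted by the endpoint methods, reusing the api instance from the client sketch further above; the dict shape follows what _apply_auth_params consumes, but the query parameter name ("key") is an assumption:

# Overrides the spec-level api_key auth for this single call only.
releases = api.get_dataset_releases(
    dag_id="doi_workflow",
    _request_auths=[{"in": "query", "type": "api_key", "key": "key", "value": "my-secret-key"}],
)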
- headers_map (dict): see below key value pairs - 'accept' (list): list of Accept header strings - 'content_type' (list): list of Content-Type header strings - api_client (ApiClient) api client instance - callable (function): the function which is invoked when the - Endpoint is called - """ - self.settings = settings - self.params_map = params_map - self.params_map['all'].extend([ - 'async_req', - '_host_index', - '_preload_content', - '_request_timeout', - '_return_http_data_only', - '_check_input_type', - '_check_return_type', - '_content_type', - '_spec_property_naming', - '_request_auths' - ]) - self.params_map['nullable'].extend(['_request_timeout']) - self.validations = root_map['validations'] - self.allowed_values = root_map['allowed_values'] - self.openapi_types = root_map['openapi_types'] - extra_types = { - 'async_req': (bool,), - '_host_index': (none_type, int), - '_preload_content': (bool,), - '_request_timeout': (none_type, float, (float,), [float], int, (int,), [int]), - '_return_http_data_only': (bool,), - '_check_input_type': (bool,), - '_check_return_type': (bool,), - '_spec_property_naming': (bool,), - '_content_type': (none_type, str), - '_request_auths': (none_type, list) - } - self.openapi_types.update(extra_types) - self.attribute_map = root_map['attribute_map'] - self.location_map = root_map['location_map'] - self.collection_format_map = root_map['collection_format_map'] - self.headers_map = headers_map - self.api_client = api_client - self.callable = callable - - def __validate_inputs(self, kwargs): - for param in self.params_map['enum']: - if param in kwargs: - check_allowed_values( - self.allowed_values, - (param,), - kwargs[param] - ) - - for param in self.params_map['validation']: - if param in kwargs: - check_validations( - self.validations, - (param,), - kwargs[param], - configuration=self.api_client.configuration - ) - - if kwargs['_check_input_type'] is False: - return - - for key, value in kwargs.items(): - fixed_val = validate_and_convert_types( - value, - self.openapi_types[key], - [key], - kwargs['_spec_property_naming'], - kwargs['_check_input_type'], - configuration=self.api_client.configuration - ) - kwargs[key] = fixed_val - - def __gather_params(self, kwargs): - params = { - 'body': None, - 'collection_format': {}, - 'file': {}, - 'form': [], - 'header': {}, - 'path': {}, - 'query': [] - } - - for param_name, param_value in kwargs.items(): - param_location = self.location_map.get(param_name) - if param_location is None: - continue - if param_location: - if param_location == 'body': - params['body'] = param_value - continue - base_name = self.attribute_map[param_name] - if (param_location == 'form' and - self.openapi_types[param_name] == (file_type,)): - params['file'][base_name] = [param_value] - elif (param_location == 'form' and - self.openapi_types[param_name] == ([file_type],)): - # param_value is already a list - params['file'][base_name] = param_value - elif param_location in {'form', 'query'}: - param_value_full = (base_name, param_value) - params[param_location].append(param_value_full) - if param_location not in {'form', 'query'}: - params[param_location][base_name] = param_value - collection_format = self.collection_format_map.get(param_name) - if collection_format: - params['collection_format'][base_name] = collection_format - - return params - - def __call__(self, *args, **kwargs): - """ This method is invoked when endpoints are called - Example: - - api_instance = ObservatoryApi() - api_instance.delete_dataset_release # this is an 
instance of the class Endpoint - api_instance.delete_dataset_release() # this invokes api_instance.delete_dataset_release.__call__() - which then invokes the callable functions stored in that endpoint at - api_instance.delete_dataset_release.callable or self.callable in this class - - """ - return self.callable(self, *args, **kwargs) - - def call_with_http_info(self, **kwargs): - - try: - index = self.api_client.configuration.server_operation_index.get( - self.settings['operation_id'], self.api_client.configuration.server_index - ) if kwargs['_host_index'] is None else kwargs['_host_index'] - server_variables = self.api_client.configuration.server_operation_variables.get( - self.settings['operation_id'], self.api_client.configuration.server_variables - ) - _host = self.api_client.configuration.get_host_from_settings( - index, variables=server_variables, servers=self.settings['servers'] - ) - except IndexError: - if self.settings['servers']: - raise ApiValueError( - "Invalid host index. Must be 0 <= index < %s" % - len(self.settings['servers']) - ) - _host = None - - for key, value in kwargs.items(): - if key not in self.params_map['all']: - raise ApiTypeError( - "Got an unexpected parameter '%s'" - " to method `%s`" % - (key, self.settings['operation_id']) - ) - # only throw this nullable ApiValueError if _check_input_type - # is False, if _check_input_type==True we catch this case - # in self.__validate_inputs - if (key not in self.params_map['nullable'] and value is None - and kwargs['_check_input_type'] is False): - raise ApiValueError( - "Value may not be None for non-nullable parameter `%s`" - " when calling `%s`" % - (key, self.settings['operation_id']) - ) - - for key in self.params_map['required']: - if key not in kwargs.keys(): - raise ApiValueError( - "Missing the required parameter `%s` when calling " - "`%s`" % (key, self.settings['operation_id']) - ) - - self.__validate_inputs(kwargs) - - params = self.__gather_params(kwargs) - - accept_headers_list = self.headers_map['accept'] - if accept_headers_list: - params['header']['Accept'] = self.api_client.select_header_accept( - accept_headers_list) - - if kwargs.get('_content_type'): - params['header']['Content-Type'] = kwargs['_content_type'] - else: - content_type_headers_list = self.headers_map['content_type'] - if content_type_headers_list: - if params['body'] != "": - content_types_list = self.api_client.select_header_content_type( - content_type_headers_list, self.settings['http_method'], - params['body']) - if content_types_list: - params['header']['Content-Type'] = content_types_list - - return self.api_client.call_api( - self.settings['endpoint_path'], self.settings['http_method'], - params['path'], - params['query'], - params['header'], - body=params['body'], - post_params=params['form'], - files=params['file'], - response_type=self.settings['response_type'], - auth_settings=self.settings['auth'], - async_req=kwargs['async_req'], - _check_type=kwargs['_check_return_type'], - _return_http_data_only=kwargs['_return_http_data_only'], - _preload_content=kwargs['_preload_content'], - _request_timeout=kwargs['_request_timeout'], - _host=_host, - _request_auths=kwargs['_request_auths'], - collection_formats=params['collection_format']) diff --git a/observatory-api/observatory/api/client/configuration.py b/observatory-api/observatory/api/client/configuration.py deleted file mode 100644 index 2b86f0d10..000000000 --- a/observatory-api/observatory/api/client/configuration.py +++ /dev/null @@ -1,474 +0,0 @@ -""" - Observatory API - 
- The REST API for managing and accessing data from the Observatory Platform. # noqa: E501 - - The version of the OpenAPI document: 1.0.0 - Contact: agent@observatory.academy - Generated by: https://openapi-generator.tech -""" - - -import copy -import logging -import multiprocessing -import sys -import certifi -import urllib3 - -from http import client as http_client -from observatory.api.client.exceptions import ApiValueError - - -JSON_SCHEMA_VALIDATION_KEYWORDS = { - 'multipleOf', 'maximum', 'exclusiveMaximum', - 'minimum', 'exclusiveMinimum', 'maxLength', - 'minLength', 'pattern', 'maxItems', 'minItems' -} - -class Configuration(object): - """NOTE: This class is auto generated by OpenAPI Generator - - Ref: https://openapi-generator.tech - Do not edit the class manually. - - :param host: Base url - :param api_key: Dict to store API key(s). - Each entry in the dict specifies an API key. - The dict key is the name of the security scheme in the OAS specification. - The dict value is the API key secret. - :param api_key_prefix: Dict to store API prefix (e.g. Bearer) - The dict key is the name of the security scheme in the OAS specification. - The dict value is an API key prefix when generating the auth data. - :param username: Username for HTTP basic authentication - :param password: Password for HTTP basic authentication - :param discard_unknown_keys: Boolean value indicating whether to discard - unknown properties. A server may send a response that includes additional - properties that are not known by the client in the following scenarios: - 1. The OpenAPI document is incomplete, i.e. it does not match the server - implementation. - 2. The client was generated using an older version of the OpenAPI document - and the server has been upgraded since then. - If a schema in the OpenAPI document defines the additionalProperties attribute, - then all undeclared properties received by the server are injected into the - additional properties map. In that case, there are undeclared properties, and - nothing to discard. - :param disabled_client_side_validations (string): Comma-separated list of - JSON schema validation keywords to disable JSON schema structural validation - rules. The following keywords may be specified: multipleOf, maximum, - exclusiveMaximum, minimum, exclusiveMinimum, maxLength, minLength, pattern, - maxItems, minItems. - By default, the validation is performed for data generated locally by the client - and data received from the server, independent of any validation performed by - the server side. If the input data does not satisfy the JSON schema validation - rules specified in the OpenAPI document, an exception is raised. - If disabled_client_side_validations is set, structural validation is - disabled. This can be useful to troubleshoot data validation problem, such as - when the OpenAPI document validation rules do not match the actual API data - received by the server. - :param server_index: Index to servers configuration. - :param server_variables: Mapping with string values to replace variables in - templated server configuration. The validation of enums is performed for - variables with defined enum values before. - :param server_operation_index: Mapping from operation ID to an index to server - configuration. - :param server_operation_variables: Mapping from operation ID to a mapping with - string values to replace variables in templated server configuration. - The validation of enums is performed for variables with defined enum values before. 
- :param ssl_ca_cert: str - the path to a file of concatenated CA certificates - in PEM format - - :Example: - - API Key Authentication Example. - Given the following security scheme in the OpenAPI specification: - components: - securitySchemes: - cookieAuth: # name for the security scheme - type: apiKey - in: cookie - name: JSESSIONID # cookie name - - You can programmatically set the cookie: - -conf = observatory.api.client.Configuration( - api_key={'cookieAuth': 'abc123'} - api_key_prefix={'cookieAuth': 'JSESSIONID'} -) - - The following cookie will be added to the HTTP request: - Cookie: JSESSIONID abc123 - """ - - _default = None - - def __init__(self, host=None, - api_key=None, api_key_prefix=None, - access_token=None, - username=None, password=None, - discard_unknown_keys=False, - disabled_client_side_validations="", - server_index=None, server_variables=None, - server_operation_index=None, server_operation_variables=None, - ssl_ca_cert=certifi.where(), - ): - """Constructor - """ - self._base_path = "https://localhost:5002" if host is None else host - """Default Base url - """ - self.server_index = 0 if server_index is None and host is None else server_index - self.server_operation_index = server_operation_index or {} - """Default server index - """ - self.server_variables = server_variables or {} - self.server_operation_variables = server_operation_variables or {} - """Default server variables - """ - self.temp_folder_path = None - """Temp file folder for downloading files - """ - # Authentication Settings - self.access_token = access_token - self.api_key = {} - if api_key: - self.api_key = api_key - """dict to store API key(s) - """ - self.api_key_prefix = {} - if api_key_prefix: - self.api_key_prefix = api_key_prefix - """dict to store API prefix (e.g. Bearer) - """ - self.refresh_api_key_hook = None - """function hook to refresh API key if expired - """ - self.username = username - """Username for HTTP basic authentication - """ - self.password = password - """Password for HTTP basic authentication - """ - self.discard_unknown_keys = discard_unknown_keys - self.disabled_client_side_validations = disabled_client_side_validations - self.logger = {} - """Logging Settings - """ - self.logger["package_logger"] = logging.getLogger("observatory.api.client") - self.logger["urllib3_logger"] = logging.getLogger("urllib3") - self.logger_format = '%(asctime)s %(levelname)s %(message)s' - """Log format - """ - self.logger_stream_handler = None - """Log stream handler - """ - self.logger_file_handler = None - """Log file handler - """ - self.logger_file = None - """Debug file location - """ - self.debug = False - """Debug switch - """ - - self.verify_ssl = True - """SSL/TLS verification - Set this to false to skip verifying SSL certificate when calling API - from https server. - """ - self.ssl_ca_cert = ssl_ca_cert - """Set this to customize the certificate file to verify the peer. - """ - self.cert_file = None - """client certificate file - """ - self.key_file = None - """client key file - """ - self.assert_hostname = None - """Set this to True/False to enable/disable SSL hostname verification. - """ - - self.connection_pool_maxsize = multiprocessing.cpu_count() * 5 - """urllib3 connection pool's maximum number of connections saved - per pool. urllib3 uses 1 connection as default value, but this is - not the best value when you are making a lot of possibly parallel - requests to the same host, which is often the case here. - cpu_count * 5 is used as default value to increase performance. 
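# Hedged usage sketch (editor's addition, not part of the original patch): the docstring
# example above is missing commas between its keyword arguments. A corrected sketch of
# API-key configuration; the key values are placeholders and "api_key" is the security
# scheme name used by auth_settings() further down in this file.
import observatory.api.client

conf = observatory.api.client.Configuration(
    host="https://localhost:5002",
    api_key={"api_key": "abc123"},
    api_key_prefix={"api_key": "Bearer"},
)
print(conf.get_api_key_with_prefix("api_key"))  # -> "Bearer abc123"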
- """ - - self.proxy = None - """Proxy URL - """ - self.proxy_headers = None - """Proxy headers - """ - self.safe_chars_for_path_param = '' - """Safe chars for path_param - """ - self.retries = None - """Adding retries to override urllib3 default value 3 - """ - # Enable client side validation - self.client_side_validation = True - - # Options to pass down to the underlying urllib3 socket - self.socket_options = None - - def __deepcopy__(self, memo): - cls = self.__class__ - result = cls.__new__(cls) - memo[id(self)] = result - for k, v in self.__dict__.items(): - if k not in ('logger', 'logger_file_handler'): - setattr(result, k, copy.deepcopy(v, memo)) - # shallow copy of loggers - result.logger = copy.copy(self.logger) - # use setters to configure loggers - result.logger_file = self.logger_file - result.debug = self.debug - return result - - def __setattr__(self, name, value): - object.__setattr__(self, name, value) - if name == 'disabled_client_side_validations': - s = set(filter(None, value.split(','))) - for v in s: - if v not in JSON_SCHEMA_VALIDATION_KEYWORDS: - raise ApiValueError( - "Invalid keyword: '{0}''".format(v)) - self._disabled_client_side_validations = s - - @classmethod - def set_default(cls, default): - """Set default instance of configuration. - - It stores default configuration, which can be - returned by get_default_copy method. - - :param default: object of Configuration - """ - cls._default = copy.deepcopy(default) - - @classmethod - def get_default_copy(cls): - """Return new instance of configuration. - - This method returns newly created, based on default constructor, - object of Configuration class or returns a copy of default - configuration passed by the set_default method. - - :return: The configuration object. - """ - if cls._default is not None: - return copy.deepcopy(cls._default) - return Configuration() - - @property - def logger_file(self): - """The logger file. - - If the logger_file is None, then add stream handler and remove file - handler. Otherwise, add file handler and remove stream handler. - - :param value: The logger_file path. - :type: str - """ - return self.__logger_file - - @logger_file.setter - def logger_file(self, value): - """The logger file. - - If the logger_file is None, then add stream handler and remove file - handler. Otherwise, add file handler and remove stream handler. - - :param value: The logger_file path. - :type: str - """ - self.__logger_file = value - if self.__logger_file: - # If set logging file, - # then add file handler and remove stream handler. - self.logger_file_handler = logging.FileHandler(self.__logger_file) - self.logger_file_handler.setFormatter(self.logger_formatter) - for _, logger in self.logger.items(): - logger.addHandler(self.logger_file_handler) - - @property - def debug(self): - """Debug status - - :param value: The debug status, True or False. - :type: bool - """ - return self.__debug - - @debug.setter - def debug(self, value): - """Debug status - - :param value: The debug status, True or False. 
- :type: bool - """ - self.__debug = value - if self.__debug: - # if debug status is True, turn on debug logging - for _, logger in self.logger.items(): - logger.setLevel(logging.DEBUG) - # turn on http_client debug - http_client.HTTPConnection.debuglevel = 1 - else: - # if debug status is False, turn off debug logging, - # setting log level to default `logging.WARNING` - for _, logger in self.logger.items(): - logger.setLevel(logging.WARNING) - # turn off http_client debug - http_client.HTTPConnection.debuglevel = 0 - - @property - def logger_format(self): - """The logger format. - - The logger_formatter will be updated when sets logger_format. - - :param value: The format string. - :type: str - """ - return self.__logger_format - - @logger_format.setter - def logger_format(self, value): - """The logger format. - - The logger_formatter will be updated when sets logger_format. - - :param value: The format string. - :type: str - """ - self.__logger_format = value - self.logger_formatter = logging.Formatter(self.__logger_format) - - def get_api_key_with_prefix(self, identifier, alias=None): - """Gets API key (with prefix if set). - - :param identifier: The identifier of apiKey. - :param alias: The alternative identifier of apiKey. - :return: The token for api key authentication. - """ - if self.refresh_api_key_hook is not None: - self.refresh_api_key_hook(self) - key = self.api_key.get(identifier, self.api_key.get(alias) if alias is not None else None) - if key: - prefix = self.api_key_prefix.get(identifier) - if prefix: - return "%s %s" % (prefix, key) - else: - return key - - def get_basic_auth_token(self): - """Gets HTTP basic authentication header (string). - - :return: The token for basic HTTP authentication. - """ - username = "" - if self.username is not None: - username = self.username - password = "" - if self.password is not None: - password = self.password - return urllib3.util.make_headers( - basic_auth=username + ':' + password - ).get('authorization') - - def auth_settings(self): - """Gets Auth Settings dict for api client. - - :return: The Auth Settings information dict. - """ - auth = {} - if 'api_key' in self.api_key: - auth['api_key'] = { - 'type': 'api_key', - 'in': 'query', - 'key': 'key', - 'value': self.get_api_key_with_prefix( - 'api_key', - ), - } - return auth - - def to_debug_report(self): - """Gets the essential information for debugging. - - :return: The report for debugging. - """ - return "Python SDK Debug Report:\n"\ - "OS: {env}\n"\ - "Python Version: {pyversion}\n"\ - "Version of the API: 1.0.0\n"\ - "SDK Package Version: 1.0.0".\ - format(env=sys.platform, pyversion=sys.version) - - def get_host_settings(self): - """Gets an array of host settings - - :return: An array of host settings - """ - return [ - { - 'url': "https://localhost:5002", - 'description': "No description provided", - } - ] - - def get_host_from_settings(self, index, variables=None, servers=None): - """Gets host URL based on the index and variables - :param index: array index of the host settings - :param variables: hash of variable and the corresponding value - :param servers: an array of host settings or None - :return: URL based on host settings - """ - if index is None: - return self._base_path - - variables = {} if variables is None else variables - servers = self.get_host_settings() if servers is None else servers - - try: - server = servers[index] - except IndexError: - raise ValueError( - "Invalid index {0} when selecting the host settings. 
" - "Must be less than {1}".format(index, len(servers))) - - url = server['url'] - - # go through variables and replace placeholders - for variable_name, variable in server.get('variables', {}).items(): - used_value = variables.get( - variable_name, variable['default_value']) - - if 'enum_values' in variable \ - and used_value not in variable['enum_values']: - raise ValueError( - "The variable `{0}` in the host URL has invalid value " - "{1}. Must be {2}.".format( - variable_name, variables[variable_name], - variable['enum_values'])) - - url = url.replace("{" + variable_name + "}", used_value) - - return url - - @property - def host(self): - """Return generated host.""" - return self.get_host_from_settings(self.server_index, variables=self.server_variables) - - @host.setter - def host(self, value): - """Fix base path.""" - self._base_path = value - self.server_index = None \ No newline at end of file diff --git a/observatory-api/observatory/api/client/exceptions.py b/observatory-api/observatory/api/client/exceptions.py deleted file mode 100644 index c37b69343..000000000 --- a/observatory-api/observatory/api/client/exceptions.py +++ /dev/null @@ -1,159 +0,0 @@ -""" - Observatory API - - The REST API for managing and accessing data from the Observatory Platform. # noqa: E501 - - The version of the OpenAPI document: 1.0.0 - Contact: agent@observatory.academy - Generated by: https://openapi-generator.tech -""" - - -class OpenApiException(Exception): - """The base exception class for all OpenAPIExceptions""" - - -class ApiTypeError(OpenApiException, TypeError): - def __init__(self, msg, path_to_item=None, valid_classes=None, - key_type=None): - """ Raises an exception for TypeErrors - - Args: - msg (str): the exception message - - Keyword Args: - path_to_item (list): a list of keys an indices to get to the - current_item - None if unset - valid_classes (tuple): the primitive classes that current item - should be an instance of - None if unset - key_type (bool): False if our value is a value in a dict - True if it is a key in a dict - False if our item is an item in a list - None if unset - """ - self.path_to_item = path_to_item - self.valid_classes = valid_classes - self.key_type = key_type - full_msg = msg - if path_to_item: - full_msg = "{0} at {1}".format(msg, render_path(path_to_item)) - super(ApiTypeError, self).__init__(full_msg) - - -class ApiValueError(OpenApiException, ValueError): - def __init__(self, msg, path_to_item=None): - """ - Args: - msg (str): the exception message - - Keyword Args: - path_to_item (list) the path to the exception in the - received_data dict. None if unset - """ - - self.path_to_item = path_to_item - full_msg = msg - if path_to_item: - full_msg = "{0} at {1}".format(msg, render_path(path_to_item)) - super(ApiValueError, self).__init__(full_msg) - - -class ApiAttributeError(OpenApiException, AttributeError): - def __init__(self, msg, path_to_item=None): - """ - Raised when an attribute reference or assignment fails. 
- - Args: - msg (str): the exception message - - Keyword Args: - path_to_item (None/list) the path to the exception in the - received_data dict - """ - self.path_to_item = path_to_item - full_msg = msg - if path_to_item: - full_msg = "{0} at {1}".format(msg, render_path(path_to_item)) - super(ApiAttributeError, self).__init__(full_msg) - - -class ApiKeyError(OpenApiException, KeyError): - def __init__(self, msg, path_to_item=None): - """ - Args: - msg (str): the exception message - - Keyword Args: - path_to_item (None/list) the path to the exception in the - received_data dict - """ - self.path_to_item = path_to_item - full_msg = msg - if path_to_item: - full_msg = "{0} at {1}".format(msg, render_path(path_to_item)) - super(ApiKeyError, self).__init__(full_msg) - - -class ApiException(OpenApiException): - - def __init__(self, status=None, reason=None, http_resp=None): - if http_resp: - self.status = http_resp.status - self.reason = http_resp.reason - self.body = http_resp.data - self.headers = http_resp.getheaders() - else: - self.status = status - self.reason = reason - self.body = None - self.headers = None - - def __str__(self): - """Custom error messages for exception""" - error_message = "Status Code: {0}\n"\ - "Reason: {1}\n".format(self.status, self.reason) - if self.headers: - error_message += "HTTP response headers: {0}\n".format( - self.headers) - - if self.body: - error_message += "HTTP response body: {0}\n".format(self.body) - - return error_message - - -class NotFoundException(ApiException): - - def __init__(self, status=None, reason=None, http_resp=None): - super(NotFoundException, self).__init__(status, reason, http_resp) - - -class UnauthorizedException(ApiException): - - def __init__(self, status=None, reason=None, http_resp=None): - super(UnauthorizedException, self).__init__(status, reason, http_resp) - - -class ForbiddenException(ApiException): - - def __init__(self, status=None, reason=None, http_resp=None): - super(ForbiddenException, self).__init__(status, reason, http_resp) - - -class ServiceException(ApiException): - - def __init__(self, status=None, reason=None, http_resp=None): - super(ServiceException, self).__init__(status, reason, http_resp) - - -def render_path(path_to_item): - """Returns a string representation of a path""" - result = "" - for pth in path_to_item: - if isinstance(pth, int): - result += "[{0}]".format(pth) - else: - result += "['{0}']".format(pth) - return result diff --git a/observatory-api/observatory/api/client/model/__init__.py b/observatory-api/observatory/api/client/model/__init__.py deleted file mode 100644 index 5542cf9f3..000000000 --- a/observatory-api/observatory/api/client/model/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -# we can not import model classes here because that would create a circular -# reference which would not work in python2 -# do not import all models into this module because that uses a lot of memory and stack frames -# if you need the ability to import all models from one package, import them with -# from observatory.api.client.models import ModelA, ModelB diff --git a/observatory-api/observatory/api/client/model/dataset_release.py b/observatory-api/observatory/api/client/model/dataset_release.py deleted file mode 100644 index 9dd269c41..000000000 --- a/observatory-api/observatory/api/client/model/dataset_release.py +++ /dev/null @@ -1,316 +0,0 @@ -""" - Observatory API - - The REST API for managing and accessing data from the Observatory Platform. 
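# Hedged usage sketch (editor's addition, not part of the original patch): how callers
# typically handled the deleted exception hierarchy above. NotFoundException,
# UnauthorizedException, ForbiddenException and ServiceException all subclass ApiException,
# whose __str__ formats status, reason, headers and body; render_path formats the
# path_to_item lists used throughout these exceptions.
from observatory.api.client.exceptions import (
    ApiException,
    NotFoundException,
    render_path,
)

try:
    raise NotFoundException(status=404, reason="Not Found")
except ApiException as e:
    print(e)  # Status Code: 404 / Reason: Not Found

print(render_path(["releases", 0, "dag_id"]))  # ['releases'][0]['dag_id']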
# noqa: E501 - - The version of the OpenAPI document: 1.0.0 - Contact: agent@observatory.academy - Generated by: https://openapi-generator.tech -""" - - -import re # noqa: F401 -import sys # noqa: F401 - -from observatory.api.client.model_utils import ( # noqa: F401 - ApiTypeError, - ModelComposed, - ModelNormal, - ModelSimple, - cached_property, - change_keys_js_to_python, - convert_js_args_to_python_args, - date, - datetime, - file_type, - none_type, - validate_get_composed_info, - OpenApiModel -) -from observatory.api.client.exceptions import ApiAttributeError - - - -class DatasetRelease(ModelNormal): - """NOTE: This class is auto generated by OpenAPI Generator. - Ref: https://openapi-generator.tech - - Do not edit the class manually. - - Attributes: - allowed_values (dict): The key is the tuple path to the attribute - and the for var_name this is (var_name,). The value is a dict - with a capitalized key describing the allowed value and an allowed - value. These dicts store the allowed enum values. - attribute_map (dict): The key is attribute name - and the value is json key in definition. - discriminator_value_class_map (dict): A dict to go from the discriminator - variable value to the discriminator class name. - validations (dict): The key is the tuple path to the attribute - and the for var_name this is (var_name,). The value is a dict - that stores validations for max_length, min_length, max_items, - min_items, exclusive_maximum, inclusive_maximum, exclusive_minimum, - inclusive_minimum, and regex. - additional_properties_type (tuple): A tuple of classes accepted - as additional properties values. - """ - - allowed_values = { - } - - validations = { - } - - additional_properties_type = None - - _nullable = False - - @cached_property - def openapi_types(): - """ - This must be a method because a model may have properties that are - of type self, this must run after the class is loaded - - Returns - openapi_types (dict): The key is attribute name - and the value is attribute type. 
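# Hedged usage sketch (editor's addition, not part of the original patch): the deleted
# model's type metadata can be inspected on the class itself; openapi_types and
# attribute_map are listed just below. additional_properties_type is None, so undeclared
# keys are rejected unless a Configuration with discard_unknown_keys is supplied.
from observatory.api.client.model.dataset_release import DatasetRelease

print(DatasetRelease.openapi_types["snapshot_date"])  # (datetime, NoneType)
print(DatasetRelease.attribute_map["dag_id"])         # 'dag_id'
print(DatasetRelease.discriminator)                   # None, no composed schemas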
- """ - return { - 'id': (int,), # noqa: E501, F821 - 'dag_id': (str,), # noqa: E501, F821 - 'dataset_id': (str,), # noqa: E501, F821 - 'dag_run_id': (str, none_type,), # noqa: E501, F821 - 'data_interval_start': (datetime, none_type,), # noqa: E501, F821 - 'data_interval_end': (datetime, none_type,), # noqa: E501, F821 - 'snapshot_date': (datetime, none_type,), # noqa: E501, F821 - 'partition_date': (datetime, none_type,), # noqa: E501, F821 - 'changefile_start_date': (datetime, none_type,), # noqa: E501, F821 - 'changefile_end_date': (datetime, none_type,), # noqa: E501, F821 - 'sequence_start': (int, none_type,), # noqa: E501, F821 - 'sequence_end': (int, none_type,), # noqa: E501, F821 - 'created': (datetime,), # noqa: E501, F821 - 'modified': (datetime,), # noqa: E501, F821 - 'extra': (bool, date, datetime, dict, float, int, list, str, none_type,), # noqa: E501, F821 - } - - @cached_property - def discriminator(): - return None - - - attribute_map = { - 'id': 'id', # noqa: E501 - 'dag_id': 'dag_id', # noqa: E501 - 'dataset_id': 'dataset_id', # noqa: E501 - 'dag_run_id': 'dag_run_id', # noqa: E501 - 'data_interval_start': 'data_interval_start', # noqa: E501 - 'data_interval_end': 'data_interval_end', # noqa: E501 - 'snapshot_date': 'snapshot_date', # noqa: E501 - 'partition_date': 'partition_date', # noqa: E501 - 'changefile_start_date': 'changefile_start_date', # noqa: E501 - 'changefile_end_date': 'changefile_end_date', # noqa: E501 - 'sequence_start': 'sequence_start', # noqa: E501 - 'sequence_end': 'sequence_end', # noqa: E501 - 'created': 'created', # noqa: E501 - 'modified': 'modified', # noqa: E501 - 'extra': 'extra', # noqa: E501 - } - - read_only_vars = { - 'created', # noqa: E501 - 'modified', # noqa: E501 - } - - _composed_schemas = {} - - @classmethod - @convert_js_args_to_python_args - def _from_openapi_data(cls, *args, **kwargs): # noqa: E501 - """DatasetRelease - a model defined in OpenAPI - - Keyword Args: - _check_type (bool): if True, values for parameters in openapi_types - will be type checked and a TypeError will be - raised if the wrong type is input. - Defaults to True - _path_to_item (tuple/list): This is a list of keys or values to - drill down to the model in received_data - when deserializing a response - _spec_property_naming (bool): True if the variable names in the input data - are serialized names, as specified in the OpenAPI document. - False if the variable names in the input data - are pythonic names, e.g. snake case (default) - _configuration (Configuration): the instance to use when - deserializing a file_type parameter. - If passed, type conversion is attempted - If omitted no type conversion is done. - _visited_composed_classes (tuple): This stores a tuple of - classes that we have traveled through so that - if we see that class again we will not use its - discriminator again. - When traveling through a discriminator, the - composed schema that is - is traveled through is added to this set. - For example if Animal has a discriminator - petType and we pass in "Dog", and the class Dog - allOf includes Animal, we move through Animal - once using the discriminator, and pick Dog. 
- Then in Dog, we will make an instance of the - Animal class but this time we won't travel - through its discriminator because we passed in - _visited_composed_classes = (Animal,) - id (int): [optional] # noqa: E501 - dag_id (str): [optional] # noqa: E501 - dataset_id (str): [optional] # noqa: E501 - dag_run_id (str, none_type): [optional] # noqa: E501 - data_interval_start (datetime, none_type): [optional] # noqa: E501 - data_interval_end (datetime, none_type): [optional] # noqa: E501 - snapshot_date (datetime, none_type): [optional] # noqa: E501 - partition_date (datetime, none_type): [optional] # noqa: E501 - changefile_start_date (datetime, none_type): [optional] # noqa: E501 - changefile_end_date (datetime, none_type): [optional] # noqa: E501 - sequence_start (int, none_type): [optional] # noqa: E501 - sequence_end (int, none_type): [optional] # noqa: E501 - created (datetime): [optional] # noqa: E501 - modified (datetime): [optional] # noqa: E501 - extra (bool, date, datetime, dict, float, int, list, str, none_type): [optional] # noqa: E501 - """ - - _check_type = kwargs.pop('_check_type', True) - _spec_property_naming = kwargs.pop('_spec_property_naming', True) - _path_to_item = kwargs.pop('_path_to_item', ()) - _configuration = kwargs.pop('_configuration', None) - _visited_composed_classes = kwargs.pop('_visited_composed_classes', ()) - - self = super(OpenApiModel, cls).__new__(cls) - - if args: - for arg in args: - if isinstance(arg, dict): - kwargs.update(arg) - else: - raise ApiTypeError( - "Invalid positional arguments=%s passed to %s. Remove those invalid positional arguments." % ( - args, - self.__class__.__name__, - ), - path_to_item=_path_to_item, - valid_classes=(self.__class__,), - ) - - self._data_store = {} - self._check_type = _check_type - self._spec_property_naming = _spec_property_naming - self._path_to_item = _path_to_item - self._configuration = _configuration - self._visited_composed_classes = _visited_composed_classes + (self.__class__,) - - for var_name, var_value in kwargs.items(): - if var_name not in self.attribute_map and \ - self._configuration is not None and \ - self._configuration.discard_unknown_keys and \ - self.additional_properties_type is None: - # discard variable. - continue - setattr(self, var_name, var_value) - return self - - required_properties = set([ - '_data_store', - '_check_type', - '_spec_property_naming', - '_path_to_item', - '_configuration', - '_visited_composed_classes', - ]) - - @convert_js_args_to_python_args - def __init__(self, *args, **kwargs): # noqa: E501 - """DatasetRelease - a model defined in OpenAPI - - Keyword Args: - _check_type (bool): if True, values for parameters in openapi_types - will be type checked and a TypeError will be - raised if the wrong type is input. - Defaults to True - _path_to_item (tuple/list): This is a list of keys or values to - drill down to the model in received_data - when deserializing a response - _spec_property_naming (bool): True if the variable names in the input data - are serialized names, as specified in the OpenAPI document. - False if the variable names in the input data - are pythonic names, e.g. snake case (default) - _configuration (Configuration): the instance to use when - deserializing a file_type parameter. - If passed, type conversion is attempted - If omitted no type conversion is done. - _visited_composed_classes (tuple): This stores a tuple of - classes that we have traveled through so that - if we see that class again we will not use its - discriminator again. 
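# Hedged usage sketch (editor's addition, not part of the original patch): constructing
# the deleted DatasetRelease model with the pythonic keyword arguments documented below;
# the field values are placeholders. `created` and `modified` are read-only, so the second
# call raises; servers populate them via _from_openapi_data instead.
from datetime import datetime, timezone

from observatory.api.client.exceptions import ApiAttributeError
from observatory.api.client.model.dataset_release import DatasetRelease

release = DatasetRelease(
    dag_id="doi_workflow",
    dataset_id="doi",
    snapshot_date=datetime(2023, 11, 13, tzinfo=timezone.utc),
)

try:
    DatasetRelease(dag_id="doi_workflow", created=datetime.now(timezone.utc))
except ApiAttributeError as e:
    print(e)  # `created` is a read-only attribute ...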
- When traveling through a discriminator, the - composed schema that is - is traveled through is added to this set. - For example if Animal has a discriminator - petType and we pass in "Dog", and the class Dog - allOf includes Animal, we move through Animal - once using the discriminator, and pick Dog. - Then in Dog, we will make an instance of the - Animal class but this time we won't travel - through its discriminator because we passed in - _visited_composed_classes = (Animal,) - id (int): [optional] # noqa: E501 - dag_id (str): [optional] # noqa: E501 - dataset_id (str): [optional] # noqa: E501 - dag_run_id (str, none_type): [optional] # noqa: E501 - data_interval_start (datetime, none_type): [optional] # noqa: E501 - data_interval_end (datetime, none_type): [optional] # noqa: E501 - snapshot_date (datetime, none_type): [optional] # noqa: E501 - partition_date (datetime, none_type): [optional] # noqa: E501 - changefile_start_date (datetime, none_type): [optional] # noqa: E501 - changefile_end_date (datetime, none_type): [optional] # noqa: E501 - sequence_start (int, none_type): [optional] # noqa: E501 - sequence_end (int, none_type): [optional] # noqa: E501 - created (datetime): [optional] # noqa: E501 - modified (datetime): [optional] # noqa: E501 - extra (bool, date, datetime, dict, float, int, list, str, none_type): [optional] # noqa: E501 - """ - - _check_type = kwargs.pop('_check_type', True) - _spec_property_naming = kwargs.pop('_spec_property_naming', False) - _path_to_item = kwargs.pop('_path_to_item', ()) - _configuration = kwargs.pop('_configuration', None) - _visited_composed_classes = kwargs.pop('_visited_composed_classes', ()) - - if args: - for arg in args: - if isinstance(arg, dict): - kwargs.update(arg) - else: - raise ApiTypeError( - "Invalid positional arguments=%s passed to %s. Remove those invalid positional arguments." % ( - args, - self.__class__.__name__, - ), - path_to_item=_path_to_item, - valid_classes=(self.__class__,), - ) - - self._data_store = {} - self._check_type = _check_type - self._spec_property_naming = _spec_property_naming - self._path_to_item = _path_to_item - self._configuration = _configuration - self._visited_composed_classes = _visited_composed_classes + (self.__class__,) - - for var_name, var_value in kwargs.items(): - if var_name not in self.attribute_map and \ - self._configuration is not None and \ - self._configuration.discard_unknown_keys and \ - self.additional_properties_type is None: - # discard variable. - continue - setattr(self, var_name, var_value) - if var_name in self.read_only_vars: - raise ApiAttributeError(f"`{var_name}` is a read-only attribute. Use `from_openapi_data` to instantiate " - f"class with read only attributes.") diff --git a/observatory-api/observatory/api/client/model_utils.py b/observatory-api/observatory/api/client/model_utils.py deleted file mode 100644 index 1feaa70e1..000000000 --- a/observatory-api/observatory/api/client/model_utils.py +++ /dev/null @@ -1,2059 +0,0 @@ -""" - Observatory API - - The REST API for managing and accessing data from the Observatory Platform. 
# noqa: E501 - - The version of the OpenAPI document: 1.0.0 - Contact: agent@observatory.academy - Generated by: https://openapi-generator.tech -""" - - -from datetime import date, datetime # noqa: F401 -from copy import deepcopy -import inspect -import io -import os -import pprint -import re -import tempfile -import uuid - -from dateutil.parser import parse - -from observatory.api.client.exceptions import ( - ApiKeyError, - ApiAttributeError, - ApiTypeError, - ApiValueError, -) - -none_type = type(None) -file_type = io.IOBase - - -def convert_js_args_to_python_args(fn): - from functools import wraps - @wraps(fn) - def wrapped_init(_self, *args, **kwargs): - """ - An attribute named `self` received from the api will conflicts with the reserved `self` - parameter of a class method. During generation, `self` attributes are mapped - to `_self` in models. Here, we name `_self` instead of `self` to avoid conflicts. - """ - spec_property_naming = kwargs.get('_spec_property_naming', False) - if spec_property_naming: - kwargs = change_keys_js_to_python( - kwargs, _self if isinstance( - _self, type) else _self.__class__) - return fn(_self, *args, **kwargs) - return wrapped_init - - -class cached_property(object): - # this caches the result of the function call for fn with no inputs - # use this as a decorator on function methods that you want converted - # into cached properties - result_key = '_results' - - def __init__(self, fn): - self._fn = fn - - def __get__(self, instance, cls=None): - if self.result_key in vars(self): - return vars(self)[self.result_key] - else: - result = self._fn() - setattr(self, self.result_key, result) - return result - - -PRIMITIVE_TYPES = (list, float, int, bool, datetime, date, str, file_type) - - -def allows_single_value_input(cls): - """ - This function returns True if the input composed schema model or any - descendant model allows a value only input - This is true for cases where oneOf contains items like: - oneOf: - - float - - NumberWithValidation - - StringEnum - - ArrayModel - - null - TODO: lru_cache this - """ - if ( - issubclass(cls, ModelSimple) or - cls in PRIMITIVE_TYPES - ): - return True - elif issubclass(cls, ModelComposed): - if not cls._composed_schemas['oneOf']: - return False - return any(allows_single_value_input(c) for c in cls._composed_schemas['oneOf']) - return False - - -def composed_model_input_classes(cls): - """ - This function returns a list of the possible models that can be accepted as - inputs. 
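# Hedged usage sketch (editor's addition, not part of the original patch): the
# cached_property helper above memoises a zero-argument function on first access, which is
# why the generated models define openapi_types()/discriminator() without `self`. The
# Example class and counter are placeholders.
from observatory.api.client.model_utils import cached_property

calls = {"n": 0}

class Example:
    @cached_property
    def openapi_types():
        calls["n"] += 1
        return {"id": (int,)}

print(Example.openapi_types)  # computed on first access
print(Example.openapi_types)  # served from the cache
print(calls["n"])             # -> 1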
- TODO: lru_cache this - """ - if issubclass(cls, ModelSimple) or cls in PRIMITIVE_TYPES: - return [cls] - elif issubclass(cls, ModelNormal): - if cls.discriminator is None: - return [cls] - else: - return get_discriminated_classes(cls) - elif issubclass(cls, ModelComposed): - if not cls._composed_schemas['oneOf']: - return [] - if cls.discriminator is None: - input_classes = [] - for c in cls._composed_schemas['oneOf']: - input_classes.extend(composed_model_input_classes(c)) - return input_classes - else: - return get_discriminated_classes(cls) - return [] - - -class OpenApiModel(object): - """The base class for all OpenAPIModels""" - - def set_attribute(self, name, value): - # this is only used to set properties on self - - path_to_item = [] - if self._path_to_item: - path_to_item.extend(self._path_to_item) - path_to_item.append(name) - - if name in self.openapi_types: - required_types_mixed = self.openapi_types[name] - elif self.additional_properties_type is None: - raise ApiAttributeError( - "{0} has no attribute '{1}'".format( - type(self).__name__, name), - path_to_item - ) - elif self.additional_properties_type is not None: - required_types_mixed = self.additional_properties_type - - if get_simple_class(name) != str: - error_msg = type_error_message( - var_name=name, - var_value=name, - valid_classes=(str,), - key_type=True - ) - raise ApiTypeError( - error_msg, - path_to_item=path_to_item, - valid_classes=(str,), - key_type=True - ) - - if self._check_type: - value = validate_and_convert_types( - value, required_types_mixed, path_to_item, self._spec_property_naming, - self._check_type, configuration=self._configuration) - if (name,) in self.allowed_values: - check_allowed_values( - self.allowed_values, - (name,), - value - ) - if (name,) in self.validations: - check_validations( - self.validations, - (name,), - value, - self._configuration - ) - self.__dict__['_data_store'][name] = value - - def __repr__(self): - """For `print` and `pprint`""" - return self.to_str() - - def __ne__(self, other): - """Returns true if both objects are not equal""" - return not self == other - - def __setattr__(self, attr, value): - """set the value of an attribute using dot notation: `instance.attr = val`""" - self[attr] = value - - def __getattr__(self, attr): - """get the value of an attribute using dot notation: `instance.attr`""" - return self.get(attr) - - def __copy__(self): - cls = self.__class__ - if self.get("_spec_property_naming", False): - return cls._new_from_openapi_data(**self.__dict__) - else: - return cls.__new__(cls, **self.__dict__) - - def __deepcopy__(self, memo): - cls = self.__class__ - - if self.get("_spec_property_naming", False): - new_inst = cls._new_from_openapi_data() - else: - new_inst = cls.__new__(cls, **self.__dict__) - - for k, v in self.__dict__.items(): - setattr(new_inst, k, deepcopy(v, memo)) - return new_inst - - - def __new__(cls, *args, **kwargs): - # this function uses the discriminator to - # pick a new schema/class to instantiate because a discriminator - # propertyName value was passed in - - if len(args) == 1: - arg = args[0] - if arg is None and is_type_nullable(cls): - # The input data is the 'null' value and the type is nullable. 
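# Hedged usage sketch (editor's addition, not part of the original patch):
# OpenApiModel.__setattr__ and __getattr__ above route dot access through the bracket/get
# API, so both spellings read and write the model's _data_store. Values are placeholders.
from observatory.api.client.model.dataset_release import DatasetRelease

release = DatasetRelease(dag_id="doi_workflow")
release["dataset_id"] = "doi"                 # __setitem__ -> set_attribute
print(release.dag_id, release["dataset_id"])  # doi_workflow doi
print(release.get("sequence_start", -1))      # unset field -> default, -1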
- return None - - if issubclass(cls, ModelComposed) and allows_single_value_input(cls): - model_kwargs = {} - oneof_instance = get_oneof_instance(cls, model_kwargs, kwargs, model_arg=arg) - return oneof_instance - - visited_composed_classes = kwargs.get('_visited_composed_classes', ()) - if ( - cls.discriminator is None or - cls in visited_composed_classes - ): - # Use case 1: this openapi schema (cls) does not have a discriminator - # Use case 2: we have already visited this class before and are sure that we - # want to instantiate it this time. We have visited this class deserializing - # a payload with a discriminator. During that process we traveled through - # this class but did not make an instance of it. Now we are making an - # instance of a composed class which contains cls in it, so this time make an instance of cls. - # - # Here's an example of use case 2: If Animal has a discriminator - # petType and we pass in "Dog", and the class Dog - # allOf includes Animal, we move through Animal - # once using the discriminator, and pick Dog. - # Then in the composed schema dog Dog, we will make an instance of the - # Animal class (because Dal has allOf: Animal) but this time we won't travel - # through Animal's discriminator because we passed in - # _visited_composed_classes = (Animal,) - - return super(OpenApiModel, cls).__new__(cls) - - # Get the name and value of the discriminator property. - # The discriminator name is obtained from the discriminator meta-data - # and the discriminator value is obtained from the input data. - discr_propertyname_py = list(cls.discriminator.keys())[0] - discr_propertyname_js = cls.attribute_map[discr_propertyname_py] - if discr_propertyname_js in kwargs: - discr_value = kwargs[discr_propertyname_js] - elif discr_propertyname_py in kwargs: - discr_value = kwargs[discr_propertyname_py] - else: - # The input data does not contain the discriminator property. - path_to_item = kwargs.get('_path_to_item', ()) - raise ApiValueError( - "Cannot deserialize input data due to missing discriminator. " - "The discriminator property '%s' is missing at path: %s" % - (discr_propertyname_js, path_to_item) - ) - - # Implementation note: the last argument to get_discriminator_class - # is a list of visited classes. get_discriminator_class may recursively - # call itself and update the list of visited classes, and the initial - # value must be an empty list. Hence not using 'visited_composed_classes' - new_cls = get_discriminator_class( - cls, discr_propertyname_py, discr_value, []) - if new_cls is None: - path_to_item = kwargs.get('_path_to_item', ()) - disc_prop_value = kwargs.get( - discr_propertyname_js, kwargs.get(discr_propertyname_py)) - raise ApiValueError( - "Cannot deserialize input data due to invalid discriminator " - "value. The OpenAPI document has no mapping for discriminator " - "property '%s'='%s' at path: %s" % - (discr_propertyname_js, disc_prop_value, path_to_item) - ) - - if new_cls in visited_composed_classes: - # if we are making an instance of a composed schema Descendent - # which allOf includes Ancestor, then Ancestor contains - # a discriminator that includes Descendent. - # So if we make an instance of Descendent, we have to make an - # instance of Ancestor to hold the allOf properties. 
- # This code detects that use case and makes the instance of Ancestor - # For example: - # When making an instance of Dog, _visited_composed_classes = (Dog,) - # then we make an instance of Animal to include in dog._composed_instances - # so when we are here, cls is Animal - # cls.discriminator != None - # cls not in _visited_composed_classes - # new_cls = Dog - # but we know we know that we already have Dog - # because it is in visited_composed_classes - # so make Animal here - return super(OpenApiModel, cls).__new__(cls) - - # Build a list containing all oneOf and anyOf descendants. - oneof_anyof_classes = None - if cls._composed_schemas is not None: - oneof_anyof_classes = ( - cls._composed_schemas.get('oneOf', ()) + - cls._composed_schemas.get('anyOf', ())) - oneof_anyof_child = new_cls in oneof_anyof_classes - kwargs['_visited_composed_classes'] = visited_composed_classes + (cls,) - - if cls._composed_schemas.get('allOf') and oneof_anyof_child: - # Validate that we can make self because when we make the - # new_cls it will not include the allOf validations in self - self_inst = super(OpenApiModel, cls).__new__(cls) - self_inst.__init__(*args, **kwargs) - - if kwargs.get("_spec_property_naming", False): - # when true, implies new is from deserialization - new_inst = new_cls._new_from_openapi_data(*args, **kwargs) - else: - new_inst = new_cls.__new__(new_cls, *args, **kwargs) - new_inst.__init__(*args, **kwargs) - - return new_inst - - @classmethod - @convert_js_args_to_python_args - def _new_from_openapi_data(cls, *args, **kwargs): - # this function uses the discriminator to - # pick a new schema/class to instantiate because a discriminator - # propertyName value was passed in - - if len(args) == 1: - arg = args[0] - if arg is None and is_type_nullable(cls): - # The input data is the 'null' value and the type is nullable. - return None - - if issubclass(cls, ModelComposed) and allows_single_value_input(cls): - model_kwargs = {} - oneof_instance = get_oneof_instance(cls, model_kwargs, kwargs, model_arg=arg) - return oneof_instance - - visited_composed_classes = kwargs.get('_visited_composed_classes', ()) - if ( - cls.discriminator is None or - cls in visited_composed_classes - ): - # Use case 1: this openapi schema (cls) does not have a discriminator - # Use case 2: we have already visited this class before and are sure that we - # want to instantiate it this time. We have visited this class deserializing - # a payload with a discriminator. During that process we traveled through - # this class but did not make an instance of it. Now we are making an - # instance of a composed class which contains cls in it, so this time make an instance of cls. - # - # Here's an example of use case 2: If Animal has a discriminator - # petType and we pass in "Dog", and the class Dog - # allOf includes Animal, we move through Animal - # once using the discriminator, and pick Dog. - # Then in the composed schema dog Dog, we will make an instance of the - # Animal class (because Dal has allOf: Animal) but this time we won't travel - # through Animal's discriminator because we passed in - # _visited_composed_classes = (Animal,) - - return cls._from_openapi_data(*args, **kwargs) - - # Get the name and value of the discriminator property. - # The discriminator name is obtained from the discriminator meta-data - # and the discriminator value is obtained from the input data. 
- discr_propertyname_py = list(cls.discriminator.keys())[0] - discr_propertyname_js = cls.attribute_map[discr_propertyname_py] - if discr_propertyname_js in kwargs: - discr_value = kwargs[discr_propertyname_js] - elif discr_propertyname_py in kwargs: - discr_value = kwargs[discr_propertyname_py] - else: - # The input data does not contain the discriminator property. - path_to_item = kwargs.get('_path_to_item', ()) - raise ApiValueError( - "Cannot deserialize input data due to missing discriminator. " - "The discriminator property '%s' is missing at path: %s" % - (discr_propertyname_js, path_to_item) - ) - - # Implementation note: the last argument to get_discriminator_class - # is a list of visited classes. get_discriminator_class may recursively - # call itself and update the list of visited classes, and the initial - # value must be an empty list. Hence not using 'visited_composed_classes' - new_cls = get_discriminator_class( - cls, discr_propertyname_py, discr_value, []) - if new_cls is None: - path_to_item = kwargs.get('_path_to_item', ()) - disc_prop_value = kwargs.get( - discr_propertyname_js, kwargs.get(discr_propertyname_py)) - raise ApiValueError( - "Cannot deserialize input data due to invalid discriminator " - "value. The OpenAPI document has no mapping for discriminator " - "property '%s'='%s' at path: %s" % - (discr_propertyname_js, disc_prop_value, path_to_item) - ) - - if new_cls in visited_composed_classes: - # if we are making an instance of a composed schema Descendent - # which allOf includes Ancestor, then Ancestor contains - # a discriminator that includes Descendent. - # So if we make an instance of Descendent, we have to make an - # instance of Ancestor to hold the allOf properties. - # This code detects that use case and makes the instance of Ancestor - # For example: - # When making an instance of Dog, _visited_composed_classes = (Dog,) - # then we make an instance of Animal to include in dog._composed_instances - # so when we are here, cls is Animal - # cls.discriminator != None - # cls not in _visited_composed_classes - # new_cls = Dog - # but we know we know that we already have Dog - # because it is in visited_composed_classes - # so make Animal here - return cls._from_openapi_data(*args, **kwargs) - - # Build a list containing all oneOf and anyOf descendants. 
- oneof_anyof_classes = None - if cls._composed_schemas is not None: - oneof_anyof_classes = ( - cls._composed_schemas.get('oneOf', ()) + - cls._composed_schemas.get('anyOf', ())) - oneof_anyof_child = new_cls in oneof_anyof_classes - kwargs['_visited_composed_classes'] = visited_composed_classes + (cls,) - - if cls._composed_schemas.get('allOf') and oneof_anyof_child: - # Validate that we can make self because when we make the - # new_cls it will not include the allOf validations in self - self_inst = cls._from_openapi_data(*args, **kwargs) - - new_inst = new_cls._new_from_openapi_data(*args, **kwargs) - return new_inst - - -class ModelSimple(OpenApiModel): - """the parent class of models whose type != object in their - swagger/openapi""" - - def __setitem__(self, name, value): - """set the value of an attribute using square-bracket notation: `instance[attr] = val`""" - if name in self.required_properties: - self.__dict__[name] = value - return - - self.set_attribute(name, value) - - def get(self, name, default=None): - """returns the value of an attribute or some default value if the attribute was not set""" - if name in self.required_properties: - return self.__dict__[name] - - return self.__dict__['_data_store'].get(name, default) - - def __getitem__(self, name): - """get the value of an attribute using square-bracket notation: `instance[attr]`""" - if name in self: - return self.get(name) - - raise ApiAttributeError( - "{0} has no attribute '{1}'".format( - type(self).__name__, name), - [e for e in [self._path_to_item, name] if e] - ) - - def __contains__(self, name): - """used by `in` operator to check if an attribute value was set in an instance: `'attr' in instance`""" - if name in self.required_properties: - return name in self.__dict__ - - return name in self.__dict__['_data_store'] - - def to_str(self): - """Returns the string representation of the model""" - return str(self.value) - - def __eq__(self, other): - """Returns true if both objects are equal""" - if not isinstance(other, self.__class__): - return False - - this_val = self._data_store['value'] - that_val = other._data_store['value'] - types = set() - types.add(this_val.__class__) - types.add(that_val.__class__) - vals_equal = this_val == that_val - return vals_equal - - -class ModelNormal(OpenApiModel): - """the parent class of models whose type == object in their - swagger/openapi""" - - def __setitem__(self, name, value): - """set the value of an attribute using square-bracket notation: `instance[attr] = val`""" - if name in self.required_properties: - self.__dict__[name] = value - return - - self.set_attribute(name, value) - - def get(self, name, default=None): - """returns the value of an attribute or some default value if the attribute was not set""" - if name in self.required_properties: - return self.__dict__[name] - - return self.__dict__['_data_store'].get(name, default) - - def __getitem__(self, name): - """get the value of an attribute using square-bracket notation: `instance[attr]`""" - if name in self: - return self.get(name) - - raise ApiAttributeError( - "{0} has no attribute '{1}'".format( - type(self).__name__, name), - [e for e in [self._path_to_item, name] if e] - ) - - def __contains__(self, name): - """used by `in` operator to check if an attribute value was set in an instance: `'attr' in instance`""" - if name in self.required_properties: - return name in self.__dict__ - - return name in self.__dict__['_data_store'] - - def to_dict(self): - """Returns the model properties as a dict""" - return 
model_to_dict(self, serialize=False) - - def to_str(self): - """Returns the string representation of the model""" - return pprint.pformat(self.to_dict()) - - def __eq__(self, other): - """Returns true if both objects are equal""" - if not isinstance(other, self.__class__): - return False - - if not set(self._data_store.keys()) == set(other._data_store.keys()): - return False - for _var_name, this_val in self._data_store.items(): - that_val = other._data_store[_var_name] - types = set() - types.add(this_val.__class__) - types.add(that_val.__class__) - vals_equal = this_val == that_val - if not vals_equal: - return False - return True - - -class ModelComposed(OpenApiModel): - """the parent class of models whose type == object in their - swagger/openapi and have oneOf/allOf/anyOf - - When one sets a property we use var_name_to_model_instances to store the value in - the correct class instances + run any type checking + validation code. - When one gets a property we use var_name_to_model_instances to get the value - from the correct class instances. - This allows multiple composed schemas to contain the same property with additive - constraints on the value. - - _composed_schemas (dict) stores the anyOf/allOf/oneOf classes - key (str): allOf/oneOf/anyOf - value (list): the classes in the XOf definition. - Note: none_type can be included when the openapi document version >= 3.1.0 - _composed_instances (list): stores a list of instances of the composed schemas - defined in _composed_schemas. When properties are accessed in the self instance, - they are returned from the self._data_store or the data stores in the instances - in self._composed_schemas - _var_name_to_model_instances (dict): maps between a variable name on self and - the composed instances (self included) which contain that data - key (str): property name - value (list): list of class instances, self or instances in _composed_instances - which contain the value that the key is referring to. - """ - - def __setitem__(self, name, value): - """set the value of an attribute using square-bracket notation: `instance[attr] = val`""" - if name in self.required_properties: - self.__dict__[name] = value - return - - """ - Use cases: - 1. additional_properties_type is None (additionalProperties == False in spec) - Check for property presence in self.openapi_types - if not present then throw an error - if present set in self, set attribute - always set on composed schemas - 2. 
additional_properties_type exists - set attribute on self - always set on composed schemas - """ - if self.additional_properties_type is None: - """ - For an attribute to exist on a composed schema it must: - - fulfill schema_requirements in the self composed schema not considering oneOf/anyOf/allOf schemas AND - - fulfill schema_requirements in each oneOf/anyOf/allOf schemas - - schema_requirements: - For an attribute to exist on a schema it must: - - be present in properties at the schema OR - - have additionalProperties unset (defaults additionalProperties = any type) OR - - have additionalProperties set - """ - if name not in self.openapi_types: - raise ApiAttributeError( - "{0} has no attribute '{1}'".format( - type(self).__name__, name), - [e for e in [self._path_to_item, name] if e] - ) - # attribute must be set on self and composed instances - self.set_attribute(name, value) - for model_instance in self._composed_instances: - setattr(model_instance, name, value) - if name not in self._var_name_to_model_instances: - # we assigned an additional property - self.__dict__['_var_name_to_model_instances'][name] = self._composed_instances + [self] - return None - - __unset_attribute_value__ = object() - - def get(self, name, default=None): - """returns the value of an attribute or some default value if the attribute was not set""" - if name in self.required_properties: - return self.__dict__[name] - - # get the attribute from the correct instance - model_instances = self._var_name_to_model_instances.get(name) - values = [] - # A composed model stores self and child (oneof/anyOf/allOf) models under - # self._var_name_to_model_instances. - # Any property must exist in self and all model instances - # The value stored in all model instances must be the same - if model_instances: - for model_instance in model_instances: - if name in model_instance._data_store: - v = model_instance._data_store[name] - if v not in values: - values.append(v) - len_values = len(values) - if len_values == 0: - return default - elif len_values == 1: - return values[0] - elif len_values > 1: - raise ApiValueError( - "Values stored for property {0} in {1} differ when looking " - "at self and self's composed instances. 
All values must be " - "the same".format(name, type(self).__name__), - [e for e in [self._path_to_item, name] if e] - ) - - def __getitem__(self, name): - """get the value of an attribute using square-bracket notation: `instance[attr]`""" - value = self.get(name, self.__unset_attribute_value__) - if value is self.__unset_attribute_value__: - raise ApiAttributeError( - "{0} has no attribute '{1}'".format( - type(self).__name__, name), - [e for e in [self._path_to_item, name] if e] - ) - return value - - def __contains__(self, name): - """used by `in` operator to check if an attribute value was set in an instance: `'attr' in instance`""" - - if name in self.required_properties: - return name in self.__dict__ - - model_instances = self._var_name_to_model_instances.get( - name, self._additional_properties_model_instances) - - if model_instances: - for model_instance in model_instances: - if name in model_instance._data_store: - return True - - return False - - def to_dict(self): - """Returns the model properties as a dict""" - return model_to_dict(self, serialize=False) - - def to_str(self): - """Returns the string representation of the model""" - return pprint.pformat(self.to_dict()) - - def __eq__(self, other): - """Returns true if both objects are equal""" - if not isinstance(other, self.__class__): - return False - - if not set(self._data_store.keys()) == set(other._data_store.keys()): - return False - for _var_name, this_val in self._data_store.items(): - that_val = other._data_store[_var_name] - types = set() - types.add(this_val.__class__) - types.add(that_val.__class__) - vals_equal = this_val == that_val - if not vals_equal: - return False - return True - - -COERCION_INDEX_BY_TYPE = { - ModelComposed: 0, - ModelNormal: 1, - ModelSimple: 2, - none_type: 3, # The type of 'None'. - list: 4, - dict: 5, - float: 6, - int: 7, - bool: 8, - datetime: 9, - date: 10, - str: 11, - file_type: 12, # 'file_type' is an alias for the built-in 'file' or 'io.IOBase' type. -} - -# these are used to limit what type conversions we try to do -# when we have a valid type already and we want to try converting -# to another type -UPCONVERSION_TYPE_PAIRS = ( - (str, datetime), - (str, date), - # A float may be serialized as an integer, e.g. '3' is a valid serialized float. 
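# Hedged usage sketch (editor's addition, not part of the original patch): get_simple_class,
# defined just below, normalises a value (or class) to the single type used for type
# checking; bool is tested before int and datetime before date because of the isinstance
# overlap noted in its comments.
from datetime import datetime

from observatory.api.client.model_utils import get_simple_class

print(get_simple_class(True))                    # bool, not int
print(get_simple_class(3))                       # int
print(get_simple_class(datetime(2023, 11, 13)))  # datetime, not date
print(get_simple_class("dag_id"))                # str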
- (int, float), - (list, ModelComposed), - (dict, ModelComposed), - (str, ModelComposed), - (int, ModelComposed), - (float, ModelComposed), - (list, ModelComposed), - (list, ModelNormal), - (dict, ModelNormal), - (str, ModelSimple), - (int, ModelSimple), - (float, ModelSimple), - (list, ModelSimple), -) - -COERCIBLE_TYPE_PAIRS = { - False: ( # client instantiation of a model with client data - # (dict, ModelComposed), - # (list, ModelComposed), - # (dict, ModelNormal), - # (list, ModelNormal), - # (str, ModelSimple), - # (int, ModelSimple), - # (float, ModelSimple), - # (list, ModelSimple), - # (str, int), - # (str, float), - # (str, datetime), - # (str, date), - # (int, str), - # (float, str), - ), - True: ( # server -> client data - (dict, ModelComposed), - (list, ModelComposed), - (dict, ModelNormal), - (list, ModelNormal), - (str, ModelSimple), - (int, ModelSimple), - (float, ModelSimple), - (list, ModelSimple), - # (str, int), - # (str, float), - (str, datetime), - (str, date), - # (int, str), - # (float, str), - (str, file_type) - ), -} - - -def get_simple_class(input_value): - """Returns an input_value's simple class that we will use for type checking - Python2: - float and int will return int, where int is the python3 int backport - str and unicode will return str, where str is the python3 str backport - Note: float and int ARE both instances of int backport - Note: str_py2 and unicode_py2 are NOT both instances of str backport - - Args: - input_value (class/class_instance): the item for which we will return - the simple class - """ - if isinstance(input_value, type): - # input_value is a class - return input_value - elif isinstance(input_value, tuple): - return tuple - elif isinstance(input_value, list): - return list - elif isinstance(input_value, dict): - return dict - elif isinstance(input_value, none_type): - return none_type - elif isinstance(input_value, file_type): - return file_type - elif isinstance(input_value, bool): - # this must be higher than the int check because - # isinstance(True, int) == True - return bool - elif isinstance(input_value, int): - return int - elif isinstance(input_value, datetime): - # this must be higher than the date check because - # isinstance(datetime_instance, date) == True - return datetime - elif isinstance(input_value, date): - return date - elif isinstance(input_value, str): - return str - return type(input_value) - - -def check_allowed_values(allowed_values, input_variable_path, input_values): - """Raises an exception if the input_values are not allowed - - Args: - allowed_values (dict): the allowed_values dict - input_variable_path (tuple): the path to the input variable - input_values (list/str/int/float/date/datetime): the values that we - are checking to see if they are in allowed_values - """ - these_allowed_values = list(allowed_values[input_variable_path].values()) - if (isinstance(input_values, list) - and not set(input_values).issubset( - set(these_allowed_values))): - invalid_values = ", ".join( - map(str, set(input_values) - set(these_allowed_values))), - raise ApiValueError( - "Invalid values for `%s` [%s], must be a subset of [%s]" % - ( - input_variable_path[0], - invalid_values, - ", ".join(map(str, these_allowed_values)) - ) - ) - elif (isinstance(input_values, dict) - and not set( - input_values.keys()).issubset(set(these_allowed_values))): - invalid_values = ", ".join( - map(str, set(input_values.keys()) - set(these_allowed_values))) - raise ApiValueError( - "Invalid keys in `%s` [%s], must be a subset of [%s]" % - ( - 
input_variable_path[0], - invalid_values, - ", ".join(map(str, these_allowed_values)) - ) - ) - elif (not isinstance(input_values, (list, dict)) - and input_values not in these_allowed_values): - raise ApiValueError( - "Invalid value for `%s` (%s), must be one of %s" % - ( - input_variable_path[0], - input_values, - these_allowed_values - ) - ) - - -def is_json_validation_enabled(schema_keyword, configuration=None): - """Returns true if JSON schema validation is enabled for the specified - validation keyword. This can be used to skip JSON schema structural validation - as requested in the configuration. - - Args: - schema_keyword (string): the name of a JSON schema validation keyword. - configuration (Configuration): the configuration class. - """ - - return (configuration is None or - not hasattr(configuration, '_disabled_client_side_validations') or - schema_keyword not in configuration._disabled_client_side_validations) - - -def check_validations( - validations, input_variable_path, input_values, - configuration=None): - """Raises an exception if the input_values are invalid - - Args: - validations (dict): the validation dictionary. - input_variable_path (tuple): the path to the input variable. - input_values (list/str/int/float/date/datetime): the values that we - are checking. - configuration (Configuration): the configuration class. - """ - - if input_values is None: - return - - current_validations = validations[input_variable_path] - if (is_json_validation_enabled('multipleOf', configuration) and - 'multiple_of' in current_validations and - isinstance(input_values, (int, float)) and - not (float(input_values) / current_validations['multiple_of']).is_integer()): - # Note 'multipleOf' will be as good as the floating point arithmetic. - raise ApiValueError( - "Invalid value for `%s`, value must be a multiple of " - "`%s`" % ( - input_variable_path[0], - current_validations['multiple_of'] - ) - ) - - if (is_json_validation_enabled('maxLength', configuration) and - 'max_length' in current_validations and - len(input_values) > current_validations['max_length']): - raise ApiValueError( - "Invalid value for `%s`, length must be less than or equal to " - "`%s`" % ( - input_variable_path[0], - current_validations['max_length'] - ) - ) - - if (is_json_validation_enabled('minLength', configuration) and - 'min_length' in current_validations and - len(input_values) < current_validations['min_length']): - raise ApiValueError( - "Invalid value for `%s`, length must be greater than or equal to " - "`%s`" % ( - input_variable_path[0], - current_validations['min_length'] - ) - ) - - if (is_json_validation_enabled('maxItems', configuration) and - 'max_items' in current_validations and - len(input_values) > current_validations['max_items']): - raise ApiValueError( - "Invalid value for `%s`, number of items must be less than or " - "equal to `%s`" % ( - input_variable_path[0], - current_validations['max_items'] - ) - ) - - if (is_json_validation_enabled('minItems', configuration) and - 'min_items' in current_validations and - len(input_values) < current_validations['min_items']): - raise ValueError( - "Invalid value for `%s`, number of items must be greater than or " - "equal to `%s`" % ( - input_variable_path[0], - current_validations['min_items'] - ) - ) - - items = ('exclusive_maximum', 'inclusive_maximum', 'exclusive_minimum', - 'inclusive_minimum') - if (any(item in current_validations for item in items)): - if isinstance(input_values, list): - max_val = max(input_values) - min_val = 
min(input_values) - elif isinstance(input_values, dict): - max_val = max(input_values.values()) - min_val = min(input_values.values()) - else: - max_val = input_values - min_val = input_values - - if (is_json_validation_enabled('exclusiveMaximum', configuration) and - 'exclusive_maximum' in current_validations and - max_val >= current_validations['exclusive_maximum']): - raise ApiValueError( - "Invalid value for `%s`, must be a value less than `%s`" % ( - input_variable_path[0], - current_validations['exclusive_maximum'] - ) - ) - - if (is_json_validation_enabled('maximum', configuration) and - 'inclusive_maximum' in current_validations and - max_val > current_validations['inclusive_maximum']): - raise ApiValueError( - "Invalid value for `%s`, must be a value less than or equal to " - "`%s`" % ( - input_variable_path[0], - current_validations['inclusive_maximum'] - ) - ) - - if (is_json_validation_enabled('exclusiveMinimum', configuration) and - 'exclusive_minimum' in current_validations and - min_val <= current_validations['exclusive_minimum']): - raise ApiValueError( - "Invalid value for `%s`, must be a value greater than `%s`" % - ( - input_variable_path[0], - current_validations['exclusive_minimum'] - ) - ) - - if (is_json_validation_enabled('minimum', configuration) and - 'inclusive_minimum' in current_validations and - min_val < current_validations['inclusive_minimum']): - raise ApiValueError( - "Invalid value for `%s`, must be a value greater than or equal " - "to `%s`" % ( - input_variable_path[0], - current_validations['inclusive_minimum'] - ) - ) - flags = current_validations.get('regex', {}).get('flags', 0) - if (is_json_validation_enabled('pattern', configuration) and - 'regex' in current_validations and - not re.search(current_validations['regex']['pattern'], - input_values, flags=flags)): - err_msg = r"Invalid value for `%s`, must match regular expression `%s`" % ( - input_variable_path[0], - current_validations['regex']['pattern'] - ) - if flags != 0: - # Don't print the regex flags if the flags are not - # specified in the OAS document. - err_msg = r"%s with flags=`%s`" % (err_msg, flags) - raise ApiValueError(err_msg) - - -def order_response_types(required_types): - """Returns the required types sorted in coercion order - - Args: - required_types (list/tuple): collection of classes or instance of - list or dict with class information inside it. - - Returns: - (list): coercion order sorted collection of classes or instance - of list or dict with class information inside it.
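To make the numeric bound checks above concrete: inclusive limits accept the boundary value itself, while exclusive limits reject it. A minimal standalone sketch of that distinction (the rules dict and sample values are invented for illustration, not taken from the generated module):

def check_bounds(value, rules):
    """Reject value if it violates any of the four OpenAPI bound keywords."""
    if "exclusive_maximum" in rules and value >= rules["exclusive_maximum"]:
        raise ValueError("must be less than %s" % rules["exclusive_maximum"])
    if "inclusive_maximum" in rules and value > rules["inclusive_maximum"]:
        raise ValueError("must be less than or equal to %s" % rules["inclusive_maximum"])
    if "exclusive_minimum" in rules and value <= rules["exclusive_minimum"]:
        raise ValueError("must be greater than %s" % rules["exclusive_minimum"])
    if "inclusive_minimum" in rules and value < rules["inclusive_minimum"]:
        raise ValueError("must be greater than or equal to %s" % rules["inclusive_minimum"])
    return value

check_bounds(3, {"inclusive_minimum": 1, "inclusive_maximum": 3})  # accepted: 3 <= 3
try:
    check_bounds(3, {"exclusive_maximum": 3})  # rejected: an exclusive maximum of 3 excludes 3 itself
except ValueError as err:
    print(err)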
- """ - - def index_getter(class_or_instance): - if isinstance(class_or_instance, list): - return COERCION_INDEX_BY_TYPE[list] - elif isinstance(class_or_instance, dict): - return COERCION_INDEX_BY_TYPE[dict] - elif (inspect.isclass(class_or_instance) - and issubclass(class_or_instance, ModelComposed)): - return COERCION_INDEX_BY_TYPE[ModelComposed] - elif (inspect.isclass(class_or_instance) - and issubclass(class_or_instance, ModelNormal)): - return COERCION_INDEX_BY_TYPE[ModelNormal] - elif (inspect.isclass(class_or_instance) - and issubclass(class_or_instance, ModelSimple)): - return COERCION_INDEX_BY_TYPE[ModelSimple] - elif class_or_instance in COERCION_INDEX_BY_TYPE: - return COERCION_INDEX_BY_TYPE[class_or_instance] - raise ApiValueError("Unsupported type: %s" % class_or_instance) - - sorted_types = sorted( - required_types, - key=lambda class_or_instance: index_getter(class_or_instance) - ) - return sorted_types - - -def remove_uncoercible(required_types_classes, current_item, spec_property_naming, - must_convert=True): - """Only keeps the type conversions that are possible - - Args: - required_types_classes (tuple): tuple of classes that are required - these should be ordered by COERCION_INDEX_BY_TYPE - spec_property_naming (bool): True if the variable names in the input - data are serialized names as specified in the OpenAPI document. - False if the variables names in the input data are python - variable names in PEP-8 snake case. - current_item (any): the current item (input data) to be converted - - Keyword Args: - must_convert (bool): if True the item to convert is of the wrong - type and we want a big list of coercibles - if False, we want a limited list of coercibles - - Returns: - (list): the remaining coercible required types, classes only - """ - current_type_simple = get_simple_class(current_item) - - results_classes = [] - for required_type_class in required_types_classes: - # convert our models to OpenApiModel - required_type_class_simplified = required_type_class - if isinstance(required_type_class_simplified, type): - if issubclass(required_type_class_simplified, ModelComposed): - required_type_class_simplified = ModelComposed - elif issubclass(required_type_class_simplified, ModelNormal): - required_type_class_simplified = ModelNormal - elif issubclass(required_type_class_simplified, ModelSimple): - required_type_class_simplified = ModelSimple - - if required_type_class_simplified == current_type_simple: - # don't consider converting to one's own class - continue - - class_pair = (current_type_simple, required_type_class_simplified) - if must_convert and class_pair in COERCIBLE_TYPE_PAIRS[spec_property_naming]: - results_classes.append(required_type_class) - elif class_pair in UPCONVERSION_TYPE_PAIRS: - results_classes.append(required_type_class) - return results_classes - - -def get_discriminated_classes(cls): - """ - Returns all the classes that a discriminator converts to - TODO: lru_cache this - """ - possible_classes = [] - key = list(cls.discriminator.keys())[0] - if is_type_nullable(cls): - possible_classes.append(cls) - for discr_cls in cls.discriminator[key].values(): - if hasattr(discr_cls, 'discriminator') and discr_cls.discriminator is not None: - possible_classes.extend(get_discriminated_classes(discr_cls)) - else: - possible_classes.append(discr_cls) - return possible_classes - - -def get_possible_classes(cls, from_server_context): - # TODO: lru_cache this - possible_classes = [cls] - if from_server_context: - return possible_classes - if hasattr(cls, 
'discriminator') and cls.discriminator is not None: - possible_classes = [] - possible_classes.extend(get_discriminated_classes(cls)) - elif issubclass(cls, ModelComposed): - possible_classes.extend(composed_model_input_classes(cls)) - return possible_classes - - -def get_required_type_classes(required_types_mixed, spec_property_naming): - """Converts the tuple required_types into a tuple and a dict described - below - - Args: - required_types_mixed (tuple/list): will contain either classes or - instance of list or dict - spec_property_naming (bool): if True these values came from the - server, and we use the data types in our endpoints. - If False, we are client side and we need to include - oneOf and discriminator classes inside the data types in our endpoints - - Returns: - (valid_classes, dict_valid_class_to_child_types_mixed): - valid_classes (tuple): the valid classes that the current item - should be - dict_valid_class_to_child_types_mixed (dict): - valid_class (class): this is the key - child_types_mixed (list/dict/tuple): describes the valid child - types - """ - valid_classes = [] - child_req_types_by_current_type = {} - for required_type in required_types_mixed: - if isinstance(required_type, list): - valid_classes.append(list) - child_req_types_by_current_type[list] = required_type - elif isinstance(required_type, tuple): - valid_classes.append(tuple) - child_req_types_by_current_type[tuple] = required_type - elif isinstance(required_type, dict): - valid_classes.append(dict) - child_req_types_by_current_type[dict] = required_type[str] - else: - valid_classes.extend(get_possible_classes(required_type, spec_property_naming)) - return tuple(valid_classes), child_req_types_by_current_type - - -def change_keys_js_to_python(input_dict, model_class): - """ - Converts from javascript_key keys in the input_dict to python_keys in - the output dict using the mapping in model_class. - If the input_dict contains a key which does not declared in the model_class, - the key is added to the output dict as is. The assumption is the model_class - may have undeclared properties (additionalProperties attribute in the OAS - document). - """ - - if getattr(model_class, 'attribute_map', None) is None: - return input_dict - output_dict = {} - reversed_attr_map = {value: key for key, value in - model_class.attribute_map.items()} - for javascript_key, value in input_dict.items(): - python_key = reversed_attr_map.get(javascript_key) - if python_key is None: - # if the key is unknown, it is in error or it is an - # additionalProperties variable - python_key = javascript_key - output_dict[python_key] = value - return output_dict - - -def get_type_error(var_value, path_to_item, valid_classes, key_type=False): - error_msg = type_error_message( - var_name=path_to_item[-1], - var_value=var_value, - valid_classes=valid_classes, - key_type=key_type - ) - return ApiTypeError( - error_msg, - path_to_item=path_to_item, - valid_classes=valid_classes, - key_type=key_type - ) - - -def deserialize_primitive(data, klass, path_to_item): - """Deserializes string to primitive type. - - :param data: str/int/float - :param klass: str/class the class to convert to - - :return: int, float, str, bool, date, datetime - """ - additional_message = "" - try: - if klass in {datetime, date}: - additional_message = ( - "If you need your parameter to have a fallback " - "string value, please set its type as `type: {}` in your " - "spec. That allows the value to be any type. 
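The date-versus-datetime disambiguation in the branches that follow can be illustrated on its own: a parsed value with no time component, no timezone, and an 8-10 character input is treated as a date rather than a full datetime. This sketch assumes the parse() used below is dateutil.parser.parse, as in stock openapi-generator clients.

from dateutil.parser import parse

def looks_like_date_only(data: str) -> bool:
    # Mirrors the date_only test below: midnight, no tzinfo, short input string.
    parsed = parse(data)
    return (
        parsed.hour == 0 and parsed.minute == 0 and parsed.second == 0
        and parsed.tzinfo is None and 8 <= len(data) <= 10
    )

print(looks_like_date_only("2020-01-02"))            # True  -> deserialized as a date
print(looks_like_date_only("2020-01-02T20:01:05Z"))  # False -> deserialized as a datetime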
" - ) - if klass == datetime: - if len(data) < 8: - raise ValueError("This is not a datetime") - # The string should be in iso8601 datetime format. - parsed_datetime = parse(data) - date_only = ( - parsed_datetime.hour == 0 and - parsed_datetime.minute == 0 and - parsed_datetime.second == 0 and - parsed_datetime.tzinfo is None and - 8 <= len(data) <= 10 - ) - if date_only: - raise ValueError("This is a date, not a datetime") - return parsed_datetime - elif klass == date: - if len(data) < 8: - raise ValueError("This is not a date") - return parse(data).date() - else: - converted_value = klass(data) - if isinstance(data, str) and klass == float: - if str(converted_value) != data: - # '7' -> 7.0 -> '7.0' != '7' - raise ValueError('This is not a float') - return converted_value - except (OverflowError, ValueError) as ex: - # parse can raise OverflowError - raise ApiValueError( - "{0}Failed to parse {1} as {2}".format( - additional_message, repr(data), klass.__name__ - ), - path_to_item=path_to_item - ) from ex - - -def get_discriminator_class(model_class, - discr_name, - discr_value, cls_visited): - """Returns the child class specified by the discriminator. - - Args: - model_class (OpenApiModel): the model class. - discr_name (string): the name of the discriminator property. - discr_value (any): the discriminator value. - cls_visited (list): list of model classes that have been visited. - Used to determine the discriminator class without - visiting circular references indefinitely. - - Returns: - used_model_class (class/None): the chosen child class that will be used - to deserialize the data, for example dog.Dog. - If a class is not found, None is returned. - """ - - if model_class in cls_visited: - # The class has already been visited and no suitable class was found. - return None - cls_visited.append(model_class) - used_model_class = None - if discr_name in model_class.discriminator: - class_name_to_discr_class = model_class.discriminator[discr_name] - used_model_class = class_name_to_discr_class.get(discr_value) - if used_model_class is None: - # We didn't find a discriminated class in class_name_to_discr_class. - # So look in the ancestor or descendant discriminators - # The discriminator mapping may exist in a descendant (anyOf, oneOf) - # or ancestor (allOf). - # Ancestor example: in the GrandparentAnimal -> ParentPet -> ChildCat - # hierarchy, the discriminator mappings may be defined at any level - # in the hierarchy. - # Descendant example: mammal -> whale/zebra/Pig -> BasquePig/DanishPig - # if we try to make BasquePig from mammal, we need to travel through - # the oneOf descendant discriminators to find BasquePig - descendant_classes = model_class._composed_schemas.get('oneOf', ()) + \ - model_class._composed_schemas.get('anyOf', ()) - ancestor_classes = model_class._composed_schemas.get('allOf', ()) - possible_classes = descendant_classes + ancestor_classes - for cls in possible_classes: - # Check if the schema has inherited discriminators. - if hasattr(cls, 'discriminator') and cls.discriminator is not None: - used_model_class = get_discriminator_class( - cls, discr_name, discr_value, cls_visited) - if used_model_class is not None: - return used_model_class - return used_model_class - - -def deserialize_model(model_data, model_class, path_to_item, check_type, - configuration, spec_property_naming): - """Deserializes model_data to model instance. 
- - Args: - model_data (int/str/float/bool/none_type/list/dict): data to instantiate the model - model_class (OpenApiModel): the model class - path_to_item (list): path to the model in the received data - check_type (bool): whether to check the data tupe for the values in - the model - configuration (Configuration): the instance to use to convert files - spec_property_naming (bool): True if the variable names in the input - data are serialized names as specified in the OpenAPI document. - False if the variables names in the input data are python - variable names in PEP-8 snake case. - - Returns: - model instance - - Raise: - ApiTypeError - ApiValueError - ApiKeyError - """ - - kw_args = dict(_check_type=check_type, - _path_to_item=path_to_item, - _configuration=configuration, - _spec_property_naming=spec_property_naming) - - if issubclass(model_class, ModelSimple): - return model_class._new_from_openapi_data(model_data, **kw_args) - elif isinstance(model_data, list): - return model_class._new_from_openapi_data(*model_data, **kw_args) - if isinstance(model_data, dict): - kw_args.update(model_data) - return model_class._new_from_openapi_data(**kw_args) - elif isinstance(model_data, PRIMITIVE_TYPES): - return model_class._new_from_openapi_data(model_data, **kw_args) - - -def deserialize_file(response_data, configuration, content_disposition=None): - """Deserializes body to file - - Saves response body into a file in a temporary folder, - using the filename from the `Content-Disposition` header if provided. - - Args: - param response_data (str): the file data to write - configuration (Configuration): the instance to use to convert files - - Keyword Args: - content_disposition (str): the value of the Content-Disposition - header - - Returns: - (file_type): the deserialized file which is open - The user is responsible for closing and reading the file - """ - fd, path = tempfile.mkstemp(dir=configuration.temp_folder_path) - os.close(fd) - os.remove(path) - - if content_disposition: - filename = re.search(r'filename=[\'"]?([^\'"\s]+)[\'"]?', - content_disposition, - flags=re.I) - if filename is not None: - filename = filename.group(1) - else: - filename = "default_" + str(uuid.uuid4()) - - path = os.path.join(os.path.dirname(path), filename) - - with open(path, "wb") as f: - if isinstance(response_data, str): - # change str to bytes so we can write it - response_data = response_data.encode('utf-8') - f.write(response_data) - - f = open(path, "rb") - return f - - -def attempt_convert_item(input_value, valid_classes, path_to_item, - configuration, spec_property_naming, key_type=False, - must_convert=False, check_type=True): - """ - Args: - input_value (any): the data to convert - valid_classes (any): the classes that are valid - path_to_item (list): the path to the item to convert - configuration (Configuration): the instance to use to convert files - spec_property_naming (bool): True if the variable names in the input - data are serialized names as specified in the OpenAPI document. - False if the variables names in the input data are python - variable names in PEP-8 snake case. 
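The filename handling in deserialize_file above can be exercised on its own; the header value below is a made-up example, and when no filename can be extracted the code above falls back to a uuid-based name.

import re

content_disposition = 'attachment; filename="people.jsonl"'
match = re.search(r'filename=[\'"]?([^\'"\s]+)[\'"]?', content_disposition, flags=re.I)
filename = match.group(1) if match else None
print(filename)  # people.jsonl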
- key_type (bool): if True we need to convert a key type (not supported) - must_convert (bool): if True we must convert - check_type (bool): if True we check the type or the returned data in - ModelComposed/ModelNormal/ModelSimple instances - - Returns: - instance (any) the fixed item - - Raises: - ApiTypeError - ApiValueError - ApiKeyError - """ - valid_classes_ordered = order_response_types(valid_classes) - valid_classes_coercible = remove_uncoercible( - valid_classes_ordered, input_value, spec_property_naming) - if not valid_classes_coercible or key_type: - # we do not handle keytype errors, json will take care - # of this for us - if configuration is None or not configuration.discard_unknown_keys: - raise get_type_error(input_value, path_to_item, valid_classes, - key_type=key_type) - for valid_class in valid_classes_coercible: - try: - if issubclass(valid_class, OpenApiModel): - return deserialize_model(input_value, valid_class, - path_to_item, check_type, - configuration, spec_property_naming) - elif valid_class == file_type: - return deserialize_file(input_value, configuration) - return deserialize_primitive(input_value, valid_class, - path_to_item) - except (ApiTypeError, ApiValueError, ApiKeyError) as conversion_exc: - if must_convert: - raise conversion_exc - # if we have conversion errors when must_convert == False - # we ignore the exception and move on to the next class - continue - # we were unable to convert, must_convert == False - return input_value - - -def is_type_nullable(input_type): - """ - Returns true if None is an allowed value for the specified input_type. - - A type is nullable if at least one of the following conditions is true: - 1. The OAS 'nullable' attribute has been specified, - 1. The type is the 'null' type, - 1. The type is a anyOf/oneOf composed schema, and a child schema is - the 'null' type. - Args: - input_type (type): the class of the input_value that we are - checking - Returns: - bool - """ - if input_type is none_type: - return True - if issubclass(input_type, OpenApiModel) and input_type._nullable: - return True - if issubclass(input_type, ModelComposed): - # If oneOf/anyOf, check if the 'null' type is one of the allowed types. - for t in input_type._composed_schemas.get('oneOf', ()): - if is_type_nullable(t): - return True - for t in input_type._composed_schemas.get('anyOf', ()): - if is_type_nullable(t): - return True - return False - - -def is_valid_type(input_class_simple, valid_classes): - """ - Args: - input_class_simple (class): the class of the input_value that we are - checking - valid_classes (tuple): the valid classes that the current item - should be - Returns: - bool - """ - if issubclass(input_class_simple, OpenApiModel) and \ - valid_classes == (bool, date, datetime, dict, float, int, list, str, none_type,): - return True - valid_type = input_class_simple in valid_classes - if not valid_type and ( - issubclass(input_class_simple, OpenApiModel) or - input_class_simple is none_type): - for valid_class in valid_classes: - if input_class_simple is none_type and is_type_nullable(valid_class): - # Schema is oneOf/anyOf and the 'null' type is one of the allowed types. 
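A cut-down sketch of the nullability rules encoded by is_type_nullable above: None is accepted when the type is the null type itself, when the schema was declared nullable, or when any oneOf/anyOf child schema is nullable. The tiny classes below are invented stand-ins, not the generated model base classes.

none_type = type(None)

class Plain:
    _nullable = False
    _composed = {"oneOf": (), "anyOf": ()}

class NullableString(Plain):
    _nullable = True

class Composed(Plain):
    # a oneOf composed schema whose child is nullable
    _composed = {"oneOf": (NullableString,), "anyOf": ()}

def nullable(cls) -> bool:
    if cls is none_type or cls._nullable:
        return True
    return any(nullable(child) for key in ("oneOf", "anyOf") for child in cls._composed[key])

print(nullable(Plain), nullable(NullableString), nullable(Composed))  # False True True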
- return True - if not (issubclass(valid_class, OpenApiModel) and valid_class.discriminator): - continue - discr_propertyname_py = list(valid_class.discriminator.keys())[0] - discriminator_classes = ( - valid_class.discriminator[discr_propertyname_py].values() - ) - valid_type = is_valid_type(input_class_simple, discriminator_classes) - if valid_type: - return True - return valid_type - - -def validate_and_convert_types(input_value, required_types_mixed, path_to_item, - spec_property_naming, _check_type, configuration=None): - """Raises a TypeError is there is a problem, otherwise returns value - - Args: - input_value (any): the data to validate/convert - required_types_mixed (list/dict/tuple): A list of - valid classes, or a list tuples of valid classes, or a dict where - the value is a tuple of value classes - path_to_item: (list) the path to the data being validated - this stores a list of keys or indices to get to the data being - validated - spec_property_naming (bool): True if the variable names in the input - data are serialized names as specified in the OpenAPI document. - False if the variables names in the input data are python - variable names in PEP-8 snake case. - _check_type: (boolean) if true, type will be checked and conversion - will be attempted. - configuration: (Configuration): the configuration class to use - when converting file_type items. - If passed, conversion will be attempted when possible - If not passed, no conversions will be attempted and - exceptions will be raised - - Returns: - the correctly typed value - - Raises: - ApiTypeError - """ - results = get_required_type_classes(required_types_mixed, spec_property_naming) - valid_classes, child_req_types_by_current_type = results - - input_class_simple = get_simple_class(input_value) - valid_type = is_valid_type(input_class_simple, valid_classes) - if not valid_type: - if (configuration - or (input_class_simple == dict - and dict not in valid_classes)): - # if input_value is not valid_type try to convert it - converted_instance = attempt_convert_item( - input_value, - valid_classes, - path_to_item, - configuration, - spec_property_naming, - key_type=False, - must_convert=True, - check_type=_check_type - ) - return converted_instance - else: - raise get_type_error(input_value, path_to_item, valid_classes, - key_type=False) - - # input_value's type is in valid_classes - if len(valid_classes) > 1 and configuration: - # there are valid classes which are not the current class - valid_classes_coercible = remove_uncoercible( - valid_classes, input_value, spec_property_naming, must_convert=False) - if valid_classes_coercible: - converted_instance = attempt_convert_item( - input_value, - valid_classes_coercible, - path_to_item, - configuration, - spec_property_naming, - key_type=False, - must_convert=False, - check_type=_check_type - ) - return converted_instance - - if child_req_types_by_current_type == {}: - # all types are of the required types and there are no more inner - # variables left to look at - return input_value - inner_required_types = child_req_types_by_current_type.get( - type(input_value) - ) - if inner_required_types is None: - # for this type, there are not more inner variables left to look at - return input_value - if isinstance(input_value, list): - if input_value == []: - # allow an empty list - return input_value - for index, inner_value in enumerate(input_value): - inner_path = list(path_to_item) - inner_path.append(index) - input_value[index] = validate_and_convert_types( - inner_value, - 
inner_required_types, - inner_path, - spec_property_naming, - _check_type, - configuration=configuration - ) - elif isinstance(input_value, dict): - if input_value == {}: - # allow an empty dict - return input_value - for inner_key, inner_val in input_value.items(): - inner_path = list(path_to_item) - inner_path.append(inner_key) - if get_simple_class(inner_key) != str: - raise get_type_error(inner_key, inner_path, valid_classes, - key_type=True) - input_value[inner_key] = validate_and_convert_types( - inner_val, - inner_required_types, - inner_path, - spec_property_naming, - _check_type, - configuration=configuration - ) - return input_value - - -def model_to_dict(model_instance, serialize=True): - """Returns the model properties as a dict - - Args: - model_instance (one of your model instances): the model instance that - will be converted to a dict. - - Keyword Args: - serialize (bool): if True, the keys in the dict will be values from - attribute_map - """ - result = {} - - def extract_item(item): return ( - item[0], model_to_dict( - item[1], serialize=serialize)) if hasattr( - item[1], '_data_store') else item - - model_instances = [model_instance] - if model_instance._composed_schemas: - model_instances.extend(model_instance._composed_instances) - seen_json_attribute_names = set() - used_fallback_python_attribute_names = set() - py_to_json_map = {} - for model_instance in model_instances: - for attr, value in model_instance._data_store.items(): - if serialize: - # we use get here because additional property key names do not - # exist in attribute_map - try: - attr = model_instance.attribute_map[attr] - py_to_json_map.update(model_instance.attribute_map) - seen_json_attribute_names.add(attr) - except KeyError: - used_fallback_python_attribute_names.add(attr) - if isinstance(value, list): - if not value: - # empty list or None - result[attr] = value - else: - res = [] - for v in value: - if isinstance(v, PRIMITIVE_TYPES) or v is None: - res.append(v) - elif isinstance(v, ModelSimple): - res.append(v.value) - elif isinstance(v, dict): - res.append(dict(map( - extract_item, - v.items() - ))) - else: - res.append(model_to_dict(v, serialize=serialize)) - result[attr] = res - elif isinstance(value, dict): - result[attr] = dict(map( - extract_item, - value.items() - )) - elif isinstance(value, ModelSimple): - result[attr] = value.value - elif hasattr(value, '_data_store'): - result[attr] = model_to_dict(value, serialize=serialize) - else: - result[attr] = value - if serialize: - for python_key in used_fallback_python_attribute_names: - json_key = py_to_json_map.get(python_key) - if json_key is None: - continue - if python_key == json_key: - continue - json_key_assigned_no_need_for_python_key = json_key in seen_json_attribute_names - if json_key_assigned_no_need_for_python_key: - del result[python_key] - - return result - - -def type_error_message(var_value=None, var_name=None, valid_classes=None, - key_type=None): - """ - Keyword Args: - var_value (any): the variable which has the type_error - var_name (str): the name of the variable which has the typ error - valid_classes (tuple): the accepted classes for current_item's - value - key_type (bool): False if our value is a value in a dict - True if it is a key in a dict - False if our item is an item in a list - """ - key_or_value = 'value' - if key_type: - key_or_value = 'key' - valid_classes_phrase = get_valid_classes_phrase(valid_classes) - msg = ( - "Invalid type for variable '{0}'. 
Required {1} type {2} and " - "passed type was {3}".format( - var_name, - key_or_value, - valid_classes_phrase, - type(var_value).__name__, - ) - ) - return msg - - -def get_valid_classes_phrase(input_classes): - """Returns a string phrase describing what types are allowed - """ - all_classes = list(input_classes) - all_classes = sorted(all_classes, key=lambda cls: cls.__name__) - all_class_names = [cls.__name__ for cls in all_classes] - if len(all_class_names) == 1: - return 'is {0}'.format(all_class_names[0]) - return "is one of [{0}]".format(", ".join(all_class_names)) - - -def get_allof_instances(self, model_args, constant_args): - """ - Args: - self: the class we are handling - model_args (dict): var_name to var_value - used to make instances - constant_args (dict): - metadata arguments: - _check_type - _path_to_item - _spec_property_naming - _configuration - _visited_composed_classes - - Returns - composed_instances (list) - """ - composed_instances = [] - for allof_class in self._composed_schemas['allOf']: - - try: - if constant_args.get('_spec_property_naming'): - allof_instance = allof_class._from_openapi_data(**model_args, **constant_args) - else: - allof_instance = allof_class(**model_args, **constant_args) - composed_instances.append(allof_instance) - except Exception as ex: - raise ApiValueError( - "Invalid inputs given to generate an instance of '%s'. The " - "input data was invalid for the allOf schema '%s' in the composed " - "schema '%s'. Error=%s" % ( - allof_class.__name__, - allof_class.__name__, - self.__class__.__name__, - str(ex) - ) - ) from ex - return composed_instances - - -def get_oneof_instance(cls, model_kwargs, constant_kwargs, model_arg=None): - """ - Find the oneOf schema that matches the input data (e.g. payload). - If exactly one schema matches the input data, an instance of that schema - is returned. - If zero or more than one schema match the input data, an exception is raised. - In OAS 3.x, the payload MUST, by validation, match exactly one of the - schemas described by oneOf. - - Args: - cls: the class we are handling - model_kwargs (dict): var_name to var_value - The input data, e.g. the payload that must match a oneOf schema - in the OpenAPI document. - constant_kwargs (dict): var_name to var_value - args that every model requires, including configuration, server - and path to item. - - Kwargs: - model_arg: (int, float, bool, str, date, datetime, ModelSimple, None): - the value to assign to a primitive class or ModelSimple class - Notes: - - this is only passed in when oneOf includes types which are not object - - None is used to suppress handling of model_arg, nullable models are handled in __new__ - - Returns - oneof_instance (instance) - """ - if len(cls._composed_schemas['oneOf']) == 0: - return None - - oneof_instances = [] - # Iterate over each oneOf schema and determine if the input data - # matches the oneOf schemas. - for oneof_class in cls._composed_schemas['oneOf']: - # The composed oneOf schema allows the 'null' type and the input data - # is the null value. This is a OAS >= 3.1 feature. - if oneof_class is none_type: - # skip none_types because we are deserializing dict data. 
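The oneOf rule that get_oneof_instance enforces can be seen in miniature below: every candidate is tried, failures are skipped, and the payload is valid only when exactly one candidate accepts it. The candidate "schemas" here are plain callables invented for the example.

def match_oneof(payload, candidates):
    matched = []
    for candidate in candidates:
        try:
            matched.append(candidate(payload))
        except Exception:
            continue  # a failing candidate is simply skipped, as in the loop above
    if len(matched) != 1:
        raise ValueError("expected exactly one oneOf match, got %d" % len(matched))
    return matched[0]

def date_parts(value):
    return value.split("-")

print(match_oneof("2020-01-02", [int, date_parts]))  # int fails -> ['2020', '01', '02']
try:
    match_oneof("3", [int, float])  # both succeed -> ambiguous, so the payload is rejected
except ValueError as err:
    print(err)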
- # none_type deserialization is handled in the __new__ method - continue - - single_value_input = allows_single_value_input(oneof_class) - - try: - if not single_value_input: - if constant_kwargs.get('_spec_property_naming'): - oneof_instance = oneof_class._from_openapi_data( - **model_kwargs, **constant_kwargs) - else: - oneof_instance = oneof_class(**model_kwargs, **constant_kwargs) - else: - if issubclass(oneof_class, ModelSimple): - if constant_kwargs.get('_spec_property_naming'): - oneof_instance = oneof_class._from_openapi_data( - model_arg, **constant_kwargs) - else: - oneof_instance = oneof_class(model_arg, **constant_kwargs) - elif oneof_class in PRIMITIVE_TYPES: - oneof_instance = validate_and_convert_types( - model_arg, - (oneof_class,), - constant_kwargs['_path_to_item'], - constant_kwargs['_spec_property_naming'], - constant_kwargs['_check_type'], - configuration=constant_kwargs['_configuration'] - ) - oneof_instances.append(oneof_instance) - except Exception: - pass - if len(oneof_instances) == 0: - raise ApiValueError( - "Invalid inputs given to generate an instance of %s. None " - "of the oneOf schemas matched the input data." % - cls.__name__ - ) - elif len(oneof_instances) > 1: - raise ApiValueError( - "Invalid inputs given to generate an instance of %s. Multiple " - "oneOf schemas matched the inputs, but a max of one is allowed." % - cls.__name__ - ) - return oneof_instances[0] - - -def get_anyof_instances(self, model_args, constant_args): - """ - Args: - self: the class we are handling - model_args (dict): var_name to var_value - The input data, e.g. the payload that must match at least one - anyOf child schema in the OpenAPI document. - constant_args (dict): var_name to var_value - args that every model requires, including configuration, server - and path to item. - - Returns - anyof_instances (list) - """ - anyof_instances = [] - if len(self._composed_schemas['anyOf']) == 0: - return anyof_instances - - for anyof_class in self._composed_schemas['anyOf']: - # The composed oneOf schema allows the 'null' type and the input data - # is the null value. This is a OAS >= 3.1 feature. - if anyof_class is none_type: - # skip none_types because we are deserializing dict data. - # none_type deserialization is handled in the __new__ method - continue - - try: - if constant_args.get('_spec_property_naming'): - anyof_instance = anyof_class._from_openapi_data(**model_args, **constant_args) - else: - anyof_instance = anyof_class(**model_args, **constant_args) - anyof_instances.append(anyof_instance) - except Exception: - pass - if len(anyof_instances) == 0: - raise ApiValueError( - "Invalid inputs given to generate an instance of %s. None of the " - "anyOf schemas matched the inputs." 
% - self.__class__.__name__ - ) - return anyof_instances - - -def get_discarded_args(self, composed_instances, model_args): - """ - Gathers the args that were discarded by configuration.discard_unknown_keys - """ - model_arg_keys = model_args.keys() - discarded_args = set() - # arguments passed to self were already converted to python names - # before __init__ was called - for instance in composed_instances: - if instance.__class__ in self._composed_schemas['allOf']: - try: - keys = instance.to_dict().keys() - discarded_keys = model_args - keys - discarded_args.update(discarded_keys) - except Exception: - # allOf integer schema will throw exception - pass - else: - try: - all_keys = set(model_to_dict(instance, serialize=False).keys()) - js_keys = model_to_dict(instance, serialize=True).keys() - all_keys.update(js_keys) - discarded_keys = model_arg_keys - all_keys - discarded_args.update(discarded_keys) - except Exception: - # allOf integer schema will throw exception - pass - return discarded_args - - -def validate_get_composed_info(constant_args, model_args, self): - """ - For composed schemas, generate schema instances for - all schemas in the oneOf/anyOf/allOf definition. If additional - properties are allowed, also assign those properties on - all matched schemas that contain additionalProperties. - Openapi schemas are python classes. - - Exceptions are raised if: - - 0 or > 1 oneOf schema matches the model_args input data - - no anyOf schema matches the model_args input data - - any of the allOf schemas do not match the model_args input data - - Args: - constant_args (dict): these are the args that every model requires - model_args (dict): these are the required and optional spec args that - were passed in to make this model - self (class): the class that we are instantiating - This class contains self._composed_schemas - - Returns: - composed_info (list): length three - composed_instances (list): the composed instances which are not - self - var_name_to_model_instances (dict): a dict going from var_name - to the model_instance which holds that var_name - the model_instance may be self or an instance of one of the - classes in self.composed_instances() - additional_properties_model_instances (list): a list of the - model instances which have the property - additional_properties_type. 
This list can include self - """ - # create composed_instances - composed_instances = [] - allof_instances = get_allof_instances(self, model_args, constant_args) - composed_instances.extend(allof_instances) - oneof_instance = get_oneof_instance(self.__class__, model_args, constant_args) - if oneof_instance is not None: - composed_instances.append(oneof_instance) - anyof_instances = get_anyof_instances(self, model_args, constant_args) - composed_instances.extend(anyof_instances) - """ - set additional_properties_model_instances - additional properties must be evaluated at the schema level - so self's additional properties are most important - If self is a composed schema with: - - no properties defined in self - - additionalProperties: False - Then for object payloads every property is an additional property - and they are not allowed, so only empty dict is allowed - - Properties must be set on all matching schemas - so when a property is assigned toa composed instance, it must be set on all - composed instances regardless of additionalProperties presence - keeping it to prevent breaking changes in v5.0.1 - TODO remove cls._additional_properties_model_instances in 6.0.0 - """ - additional_properties_model_instances = [] - if self.additional_properties_type is not None: - additional_properties_model_instances = [self] - - """ - no need to set properties on self in here, they will be set in __init__ - By here all composed schema oneOf/anyOf/allOf instances have their properties set using - model_args - """ - discarded_args = get_discarded_args(self, composed_instances, model_args) - - # map variable names to composed_instances - var_name_to_model_instances = {} - for prop_name in model_args: - if prop_name not in discarded_args: - var_name_to_model_instances[prop_name] = [self] + list( - filter( - lambda x: prop_name in x.openapi_types, composed_instances)) - - return [ - composed_instances, - var_name_to_model_instances, - additional_properties_model_instances, - discarded_args - ] diff --git a/observatory-api/observatory/api/client/rest.py b/observatory-api/observatory/api/client/rest.py deleted file mode 100644 index c9eed8b97..000000000 --- a/observatory-api/observatory/api/client/rest.py +++ /dev/null @@ -1,353 +0,0 @@ -""" - Observatory API - - The REST API for managing and accessing data from the Observatory Platform. 
# noqa: E501 - - The version of the OpenAPI document: 1.0.0 - Contact: agent@observatory.academy - Generated by: https://openapi-generator.tech -""" - - -import io -import json -import logging -import re -import ssl -from urllib.parse import urlencode -from urllib.parse import urlparse -from urllib.request import proxy_bypass_environment -import urllib3 -import ipaddress - -from observatory.api.client.exceptions import ApiException, UnauthorizedException, ForbiddenException, NotFoundException, ServiceException, ApiValueError - - -logger = logging.getLogger(__name__) - - -class RESTResponse(io.IOBase): - - def __init__(self, resp): - self.urllib3_response = resp - self.status = resp.status - self.reason = resp.reason - self.data = resp.data - - def getheaders(self): - """Returns a dictionary of the response headers.""" - return self.urllib3_response.getheaders() - - def getheader(self, name, default=None): - """Returns a given response header.""" - return self.urllib3_response.getheader(name, default) - - -class RESTClientObject(object): - - def __init__(self, configuration, pools_size=4, maxsize=None): - # urllib3.PoolManager will pass all kw parameters to connectionpool - # https://github.com/shazow/urllib3/blob/f9409436f83aeb79fbaf090181cd81b784f1b8ce/urllib3/poolmanager.py#L75 # noqa: E501 - # https://github.com/shazow/urllib3/blob/f9409436f83aeb79fbaf090181cd81b784f1b8ce/urllib3/connectionpool.py#L680 # noqa: E501 - # maxsize is the number of requests to host that are allowed in parallel # noqa: E501 - # Custom SSL certificates and client certificates: http://urllib3.readthedocs.io/en/latest/advanced-usage.html # noqa: E501 - - # cert_reqs - if configuration.verify_ssl: - cert_reqs = ssl.CERT_REQUIRED - else: - cert_reqs = ssl.CERT_NONE - - addition_pool_args = {} - if configuration.assert_hostname is not None: - addition_pool_args['assert_hostname'] = configuration.assert_hostname # noqa: E501 - - if configuration.retries is not None: - addition_pool_args['retries'] = configuration.retries - - if configuration.socket_options is not None: - addition_pool_args['socket_options'] = configuration.socket_options - - if maxsize is None: - if configuration.connection_pool_maxsize is not None: - maxsize = configuration.connection_pool_maxsize - else: - maxsize = 4 - - # https pool manager - if configuration.proxy and not should_bypass_proxies( - configuration.host, no_proxy=configuration.no_proxy or ''): - self.pool_manager = urllib3.ProxyManager( - num_pools=pools_size, - maxsize=maxsize, - cert_reqs=cert_reqs, - ca_certs=configuration.ssl_ca_cert, - cert_file=configuration.cert_file, - key_file=configuration.key_file, - proxy_url=configuration.proxy, - proxy_headers=configuration.proxy_headers, - **addition_pool_args - ) - else: - self.pool_manager = urllib3.PoolManager( - num_pools=pools_size, - maxsize=maxsize, - cert_reqs=cert_reqs, - ca_certs=configuration.ssl_ca_cert, - cert_file=configuration.cert_file, - key_file=configuration.key_file, - **addition_pool_args - ) - - def request(self, method, url, query_params=None, headers=None, - body=None, post_params=None, _preload_content=True, - _request_timeout=None): - """Perform requests. 
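The _request_timeout convention documented below (a single number is a total timeout, a 2-tuple is split into connect and read timeouts) maps onto urllib3 like this; the URL is only a placeholder.

import urllib3

def to_timeout(_request_timeout):
    # Same conversion as performed in the request() body below.
    if isinstance(_request_timeout, (int, float)):
        return urllib3.Timeout(total=_request_timeout)
    if isinstance(_request_timeout, tuple) and len(_request_timeout) == 2:
        return urllib3.Timeout(connect=_request_timeout[0], read=_request_timeout[1])
    return None

http = urllib3.PoolManager()
resp = http.request("GET", "https://example.org", timeout=to_timeout((3.05, 27)))
print(resp.status)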
- - :param method: http request method - :param url: http request url - :param query_params: query parameters in the url - :param headers: http request headers - :param body: request json body, for `application/json` - :param post_params: request post parameters, - `application/x-www-form-urlencoded` - and `multipart/form-data` - :param _preload_content: if False, the urllib3.HTTPResponse object will - be returned without reading/decoding response - data. Default is True. - :param _request_timeout: timeout setting for this request. If one - number provided, it will be total request - timeout. It can also be a pair (tuple) of - (connection, read) timeouts. - """ - method = method.upper() - assert method in ['GET', 'HEAD', 'DELETE', 'POST', 'PUT', - 'PATCH', 'OPTIONS'] - - if post_params and body: - raise ApiValueError( - "body parameter cannot be used with post_params parameter." - ) - - post_params = post_params or {} - headers = headers or {} - - timeout = None - if _request_timeout: - if isinstance(_request_timeout, (int, float)): # noqa: E501,F821 - timeout = urllib3.Timeout(total=_request_timeout) - elif (isinstance(_request_timeout, tuple) and - len(_request_timeout) == 2): - timeout = urllib3.Timeout( - connect=_request_timeout[0], read=_request_timeout[1]) - - try: - # For `POST`, `PUT`, `PATCH`, `OPTIONS`, `DELETE` - if method in ['POST', 'PUT', 'PATCH', 'OPTIONS', 'DELETE']: - # Only set a default Content-Type for POST, PUT, PATCH and OPTIONS requests - if (method != 'DELETE') and ('Content-Type' not in headers): - headers['Content-Type'] = 'application/json' - if query_params: - url += '?' + urlencode(query_params) - if ('Content-Type' not in headers) or (re.search('json', - headers['Content-Type'], re.IGNORECASE)): - request_body = None - if body is not None: - request_body = json.dumps(body) - r = self.pool_manager.request( - method, url, - body=request_body, - preload_content=_preload_content, - timeout=timeout, - headers=headers) - elif headers['Content-Type'] == 'application/x-www-form-urlencoded': # noqa: E501 - r = self.pool_manager.request( - method, url, - fields=post_params, - encode_multipart=False, - preload_content=_preload_content, - timeout=timeout, - headers=headers) - elif headers['Content-Type'] == 'multipart/form-data': - # must del headers['Content-Type'], or the correct - # Content-Type which generated by urllib3 will be - # overwritten. - del headers['Content-Type'] - r = self.pool_manager.request( - method, url, - fields=post_params, - encode_multipart=True, - preload_content=_preload_content, - timeout=timeout, - headers=headers) - # Pass a `string` parameter directly in the body to support - # other content types than Json when `body` argument is - # provided in serialized form - elif isinstance(body, str) or isinstance(body, bytes): - request_body = body - r = self.pool_manager.request( - method, url, - body=request_body, - preload_content=_preload_content, - timeout=timeout, - headers=headers) - else: - # Cannot generate the request from given parameters - msg = """Cannot prepare a request message for provided - arguments. 
Please check that your arguments match - declared content type.""" - raise ApiException(status=0, reason=msg) - # For `GET`, `HEAD` - else: - r = self.pool_manager.request(method, url, - fields=query_params, - preload_content=_preload_content, - timeout=timeout, - headers=headers) - except urllib3.exceptions.SSLError as e: - msg = "{0}\n{1}".format(type(e).__name__, str(e)) - raise ApiException(status=0, reason=msg) - - if _preload_content: - r = RESTResponse(r) - - # log response body - logger.debug("response body: %s", r.data) - - if not 200 <= r.status <= 299: - if r.status == 401: - raise UnauthorizedException(http_resp=r) - - if r.status == 403: - raise ForbiddenException(http_resp=r) - - if r.status == 404: - raise NotFoundException(http_resp=r) - - if 500 <= r.status <= 599: - raise ServiceException(http_resp=r) - - raise ApiException(http_resp=r) - - return r - - def GET(self, url, headers=None, query_params=None, _preload_content=True, - _request_timeout=None): - return self.request("GET", url, - headers=headers, - _preload_content=_preload_content, - _request_timeout=_request_timeout, - query_params=query_params) - - def HEAD(self, url, headers=None, query_params=None, _preload_content=True, - _request_timeout=None): - return self.request("HEAD", url, - headers=headers, - _preload_content=_preload_content, - _request_timeout=_request_timeout, - query_params=query_params) - - def OPTIONS(self, url, headers=None, query_params=None, post_params=None, - body=None, _preload_content=True, _request_timeout=None): - return self.request("OPTIONS", url, - headers=headers, - query_params=query_params, - post_params=post_params, - _preload_content=_preload_content, - _request_timeout=_request_timeout, - body=body) - - def DELETE(self, url, headers=None, query_params=None, body=None, - _preload_content=True, _request_timeout=None): - return self.request("DELETE", url, - headers=headers, - query_params=query_params, - _preload_content=_preload_content, - _request_timeout=_request_timeout, - body=body) - - def POST(self, url, headers=None, query_params=None, post_params=None, - body=None, _preload_content=True, _request_timeout=None): - return self.request("POST", url, - headers=headers, - query_params=query_params, - post_params=post_params, - _preload_content=_preload_content, - _request_timeout=_request_timeout, - body=body) - - def PUT(self, url, headers=None, query_params=None, post_params=None, - body=None, _preload_content=True, _request_timeout=None): - return self.request("PUT", url, - headers=headers, - query_params=query_params, - post_params=post_params, - _preload_content=_preload_content, - _request_timeout=_request_timeout, - body=body) - - def PATCH(self, url, headers=None, query_params=None, post_params=None, - body=None, _preload_content=True, _request_timeout=None): - return self.request("PATCH", url, - headers=headers, - query_params=query_params, - post_params=post_params, - _preload_content=_preload_content, - _request_timeout=_request_timeout, - body=body) - -# end of class RESTClientObject - - -def is_ipv4(target): - """ Test if IPv4 address or not - """ - try: - chk = ipaddress.IPv4Address(target) - return True - except ipaddress.AddressValueError: - return False - - -def in_ipv4net(target, net): - """ Test if target belongs to given IPv4 network - """ - try: - nw = ipaddress.IPv4Network(net) - ip = ipaddress.IPv4Address(target) - if ip in nw: - return True - return False - except ipaddress.AddressValueError: - return False - except ipaddress.NetmaskValueError: - return 
False - - -def should_bypass_proxies(url, no_proxy=None): - """ Yet another requests.should_bypass_proxies - Test if proxies should not be used for a particular url. - """ - - parsed = urlparse(url) - - # special cases - if parsed.hostname in [None, '']: - return True - - # special cases - if no_proxy in [None, '']: - return False - if no_proxy == '*': - return True - - no_proxy = no_proxy.lower().replace(' ', ''); - entries = ( - host for host in no_proxy.split(',') if host - ) - - if is_ipv4(parsed.hostname): - for item in entries: - if in_ipv4net(parsed.hostname, item): - return True - return proxy_bypass_environment(parsed.hostname, {'no': no_proxy}) diff --git a/observatory-api/observatory/api/server/api.py b/observatory-api/observatory/api/server/api.py deleted file mode 100644 index 58c6cee1e..000000000 --- a/observatory-api/observatory/api/server/api.py +++ /dev/null @@ -1,227 +0,0 @@ -# Copyright 2021 Curtin University -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# Author: Aniek Roelofs, James Diprose - -from __future__ import annotations - -import logging -import os -from typing import Any, ClassVar, Dict, Tuple - -import connexion -import pendulum -from connexion import NoContent -from flask import jsonify -from sqlalchemy import and_ - -from observatory.api.server.openapi_renderer import OpenApiRenderer -from observatory.api.server.orm import ( - DatasetRelease, -) - -Response = Tuple[Any, int] -session_ = None # Global session - - -def get_item(cls: ClassVar, item_id: int): - """Get an item. - - :param cls: the SQLAlchemy Table metadata class. - :param item_id: the id of the item. - :return: a Response object. - """ - - item = session_.query(cls).filter(cls.id == item_id).one_or_none() - if item is not None: - logging.info(f"Found: {cls.__name__} with id {item_id}") - return jsonify(item) - - body = f"Not found: {cls.__name__} with id {item_id}" - logging.info(body) - return body, 404 - - -def post_item(cls: ClassVar, body: Dict) -> Response: - """Create an item. - - :param cls: the SQLAlchemy Table metadata class. - :param body: the item in the form of a dictionary. - :return: a Response object. - """ - - logging.info(f"Creating item: {cls.__name__}") - - # Automatically set created and modified datetime - now = pendulum.now("UTC") - body["created"] = now - body["modified"] = now - - create_item = cls(**body) - session_.add(create_item) - session_.flush() - session_.commit() - - logging.info(f"Created: {cls.__name__} with id {create_item.id}") - return jsonify(create_item), 201 - - -def put_item(cls: ClassVar, body: Dict) -> Response: - """Create or update an item. If the item has an id it will be updated, else it will be created. - - :param cls: the SQLAlchemy Table metadata class. - :param body: the item in the form of a dictionary. - :return: a Response object. 
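Before continuing with the server module, the is_ipv4/in_ipv4net pair behind should_bypass_proxies in rest.py above boils down to stdlib ipaddress membership tests; the addresses and no_proxy string below are invented for illustration.

import ipaddress

def bypass_for_ip(target: str, no_proxy: str) -> bool:
    for entry in (e for e in no_proxy.replace(" ", "").split(",") if e):
        try:
            if ipaddress.IPv4Address(target) in ipaddress.IPv4Network(entry):
                return True
        except ValueError:  # entry is a hostname or malformed network -> skip it
            continue
    return False

print(bypass_for_ip("10.0.0.5", "10.0.0.0/8, localhost"))  # True
print(bypass_for_ip("8.8.8.8", "10.0.0.0/8"))              # False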
- """ - - item_id = body.get("id") - if item_id is not None: - item = session_.query(cls).filter(cls.id == item_id).one_or_none() - - if item is not None: - logging.info(f"Updating {cls.__name__} {item_id}") - # Remove id and automatically set modified time - body.pop("id") - body["modified"] = pendulum.now("UTC") - item.update(**body) - session_.commit() - - logging.info(f"Updated: {cls.__name__} with id {item_id}") - return jsonify(item), 200 - else: - body = f"Not found: {cls.__name__} with id {item_id}" - logging.info(body) - return body, 404 - else: - return post_item(cls, body) - - -def delete_item(cls: ClassVar, item_id: int) -> Response: - """Delete an item. - - :param cls: the SQLAlchemy Table metadata class. - :param id: the id of the item. - :return: a Response object. - """ - - org = session_.query(cls).filter(cls.id == item_id).one_or_none() - if org is not None: - logging.info(f"Deleting {cls.__name__} {item_id}") - session_.query(cls).filter(cls.id == item_id).delete() - session_.commit() - - logging.info(f"Deleted: {cls.__name__} with id {item_id}") - return NoContent, 200 - else: - body = f"Not found: {cls.__name__} with id {item_id}" - logging.info(body) - return body, 404 - - -def get_items(cls: ClassVar, limit: int) -> Response: - """Get a list of items. - - :param cls: the SQLAlchemy Table metadata class. - :param limit: the maximum number of items to return. - :return: a Response object. - """ - - items = session_.query(cls).limit(limit).all() - - logging.info(f"Found items: {cls.__name__} {items}") - return jsonify(items) - - -def get_dataset_release(id: int) -> Response: - """Get a DatasetRelease. - - :param id: the DatasetRelease id. - :return: a Response object. - """ - - return get_item(DatasetRelease, id) - - -def post_dataset_release(body: Dict) -> Response: - """Create a DatasetRelease. - - :param body: the DatasetRelease in the form of a dictionary. - :return: a Response object. - """ - - return post_item(DatasetRelease, body) - - -def put_dataset_release(body: Dict) -> Response: - """Create or update a DatasetRelease. - - :param body: the DatasetRelease in the form of a dictionary. - :return: a Response object. - """ - - return put_item(DatasetRelease, body) - - -def delete_dataset_release(id: int) -> Response: - """Delete a DatasetRelease. - - :param id: the DatasetRelease id. - :return: a Response object. - """ - - return delete_item(DatasetRelease, id) - - -def get_dataset_releases(dag_id: str = None, dataset_id: str = None) -> Response: - """Get a list of DatasetRelease objects. - - :param dag_id: the dag_id to query - :param dataset_id: the dataset_id to query - :return: a Response object. - """ - - q = session_.query(DatasetRelease) - - # Create filters based on parameters - filters = [] - if dag_id is not None: - filters.append(DatasetRelease.dag_id == dag_id) - if dataset_id is not None: - filters.append(DatasetRelease.dataset_id == dataset_id) - if len(filters): - q = q.filter(and_(*filters)) - - # Return items that match - return q.all() - - -def create_app() -> connexion.App: - """Create a Connexion App. - - :return: the Connexion App. 
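The query patterns used by the CRUD helpers above (one_or_none lookups in get_item/put_item and the and_() filter list in get_dataset_releases) look like this against a throwaway in-memory database; the toy Release model is a stand-in for the real ORM classes and assumes SQLAlchemy 1.4+ style imports.

from sqlalchemy import Column, Integer, String, and_, create_engine
from sqlalchemy.orm import declarative_base, sessionmaker

Base = declarative_base()

class Release(Base):
    __tablename__ = "release"
    id = Column(Integer, primary_key=True)
    dag_id = Column(String)
    dataset_id = Column(String)

engine = create_engine("sqlite://")  # in-memory database
Base.metadata.create_all(engine)
session = sessionmaker(bind=engine)()

session.add(Release(dag_id="doi_workflow", dataset_id="doi"))
session.commit()

# get_item-style lookup: returns None instead of raising when the id is missing
print(session.query(Release).filter(Release.id == 1).one_or_none())

# get_dataset_releases-style listing: only AND together the filters that were supplied
filters = [Release.dag_id == "doi_workflow"]
print(session.query(Release).filter(and_(*filters)).all())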
- """ - - logging.info("Creating app") - - # Create the application instance and don't sort JSON output alphabetically - conn_app = connexion.App(__name__) - conn_app.app.config["JSON_SORT_KEYS"] = False - - # Add the OpenAPI specification - specification_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "openapi.yaml.jinja2") - builder = OpenApiRenderer(specification_path) - specification = builder.to_dict() - conn_app.add_api(specification) - - return conn_app diff --git a/observatory-api/observatory/api/server/app.py b/observatory-api/observatory/api/server/app.py deleted file mode 100644 index e7cc69009..000000000 --- a/observatory-api/observatory/api/server/app.py +++ /dev/null @@ -1,43 +0,0 @@ -# Copyright 2020-2021 Curtin University -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# Author: Aniek Roelofs, James Diprose - -from __future__ import annotations - -from observatory.api.server.api import create_app -from observatory.api.server.orm import create_session, set_session - -# Setup Observatory DB Session -session_ = create_session() -set_session(session_) - -# Create the Connexion App -app = create_app() - - -@app.app.teardown_appcontext -def remove_session(exception=None): - """Remove the SQLAlchemy session. - - :param exception: - :return: None. - """ - - if session_ is not None: - session_.remove() - - -if __name__ == "__main__": - app.run(debug=True) diff --git a/observatory-api/observatory/api/server/openapi.yaml.jinja2 b/observatory-api/observatory/api/server/openapi.yaml.jinja2 deleted file mode 100644 index bd896378b..000000000 --- a/observatory-api/observatory/api/server/openapi.yaml.jinja2 +++ /dev/null @@ -1,244 +0,0 @@ -swagger: '2.0' -info: - title: Observatory API - description: | - The REST API for managing and accessing data from the Observatory Platform. - version: 1.0.0 - contact: - email: agent@observatory.academy - license: - name: Apache 2.0 - url: http://www.apache.org/licenses/LICENSE-2.0.html - -{# API Client specific settings #} -{%- if api_client %} -host: localhost:5002 -schemes: - - https -produces: - - application/json -securityDefinitions: - # This section configures basic authentication with an API key. - api_key: - type: "apiKey" - name: "key" - in: "query" -security: - - api_key: [] -{%- endif %} - -{# Tag settings #} -{%- if api_client %} -tags: -- name: Observatory - description: the Observatory API -{%- set dataset_release_tag="Observatory" -%} -{% else %} -tags: -- name: DatasetRelease - description: Create, read, update, delete and list information about the dataset release info. -{%- set dataset_release_tag="DatasetRelease" -%} -{%- endif %} - -{# Operation id namespace #} -{%- if api_client %} -{%- set operation_id_namespace="" -%} -{% else %} -{%- set operation_id_namespace="observatory.api.server.api." 
-%} -{%- endif %} - -paths: - /v1/dataset_release: - get: - tags: - - {{ dataset_release_tag }} - summary: get a DatasetRelease - operationId: {{ operation_id_namespace }}get_dataset_release - description: | - Get the details of a DatasetRelease by passing it's id. - produces: - - application/json - parameters: - - in: query - name: id - description: DatasetRelease id - required: true - type: integer - responses: - 200: - description: the fetched DatasetRelease - schema: - $ref: '#/definitions/DatasetRelease' - 400: - description: bad input parameter - post: - tags: - - {{ dataset_release_tag }} - summary: create a DatasetRelease - operationId: {{ operation_id_namespace }}post_dataset_release - description: | - Create a DatasetRelease by passing a DatasetRelease object, without an id. - consumes: - - application/json - produces: - - application/json - parameters: - - in: body - name: body - description: DatasetRelease to create - required: true - schema: - $ref: '#/definitions/DatasetRelease' - responses: - 201: - description: DatasetRelease created, returning the created object with an id - schema: - $ref: '#/definitions/DatasetRelease' - put: - tags: - - {{ dataset_release_tag }} - summary: create or update a DatasetRelease - operationId: {{ operation_id_namespace }}put_dataset_release - description: | - Create a DatasetRelease by passing a DatasetRelease object, without an id. Update an existing DatasetRelease by - passing a DatasetRelease object with an id. - consumes: - - application/json - produces: - - application/json - parameters: - - in: body - name: body - description: DatasetRelease to create or update - required: true - schema: - $ref: '#/definitions/DatasetRelease' - responses: - 200: - description: DatasetRelease updated - schema: - $ref: '#/definitions/DatasetRelease' - 201: - description: DatasetRelease created, returning the created object with an id - schema: - $ref: '#/definitions/DatasetRelease' - delete: - tags: - - {{ dataset_release_tag }} - summary: delete a DatasetRelease - operationId: {{ operation_id_namespace }}delete_dataset_release - description: | - Delete a DatasetRelease by passing it's id. 
- consumes: - - application/json - produces: - - application/json - parameters: - - in: query - name: id - description: DatasetRelease id - required: true - type: integer - responses: - 200: - description: DatasetRelease deleted - - /v1/dataset_releases: - get: - tags: - - {{ dataset_release_tag }} - summary: Get a list of DatasetRelease objects - operationId: {{ operation_id_namespace }}get_dataset_releases - description: | - Get a list of DatasetRelease objects - produces: - - application/json - parameters: - - in: query - name: dag_id - description: the dag_id to fetch release info for - required: false - type: string - - in: query - name: dataset_id - description: the dataset_id to fetch release info for - required: false - type: string - responses: - 200: - description: a list of DatasetRelease objects - schema: - type: array - items: - $ref: '#/definitions/DatasetRelease' - 400: - description: bad input parameter - -definitions: - DatasetRelease: - type: object - properties: - id: - type: integer - dag_id: - type: string - example: "doi_workflow" - dataset_id: - type: string - example: "doi" - dag_run_id: - type: string - example: "YYYY-MM-DDTHH:mm:ss.ssssss" - x-nullable: true - data_interval_start: - type: string - format: date-time - example: 2020-01-02T20:01:05Z - x-nullable: true - data_interval_end: - type: string - format: date-time - example: 2020-01-02T20:01:05Z - x-nullable: true - snapshot_date: - type: string - format: date-time - example: 2020-01-02T20:01:05Z - x-nullable: true - partition_date: - type: string - format: date-time - example: 2020-01-02T20:01:05Z - x-nullable: true - changefile_start_date: - type: string - format: date-time - example: 2020-01-02T20:01:05Z - x-nullable: true - changefile_end_date: - type: string - format: date-time - example: 2020-01-02T20:01:05Z - x-nullable: true - sequence_start: - type: integer - example: 1 - x-nullable: true - sequence_end: - type: integer - example: 3 - x-nullable: true - created: - type: string - format: date-time - readOnly: true - modified: - type: string - format: date-time - readOnly: true - extra: - type: object - example: {'view-id': '830'} - minLength: 1 - maxLength: 512 - x-nullable: true diff --git a/observatory-api/observatory/api/server/openapi_renderer.py b/observatory-api/observatory/api/server/openapi_renderer.py deleted file mode 100644 index 2a284f778..000000000 --- a/observatory-api/observatory/api/server/openapi_renderer.py +++ /dev/null @@ -1,73 +0,0 @@ -# Copyright 2020 Artificial Dimensions Ltd -# Copyright 2020-2021 Curtin University -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# Author: James Diprose - -from typing import Dict - -import yaml -from jinja2 import Template - - -def render_template(template_path: str, **kwargs) -> str: - """Render a Jinja2 template. - - :param template_path: the path to the template. - :param kwargs: the keyword variables to populate the template with. - :return: the rendered template as a string. 
- """ - - # Read file contents - with open(template_path, "r") as file: - contents = file.read() - - # Fill template with text - template = Template(contents) - - # Render template - rendered = template.render(**kwargs) - - return rendered - - -class OpenApiRenderer: - def __init__(self, openapi_template_path: str, api_client: bool = False): - """Construct an object that renders an OpenAPI 2 Jinja2 file. - - :param openapi_template_path: the path to the OpenAPI 2 Jinja2 template. - :param api_client: whether to render the file for the Server (default) or the Client. - """ - - self.openapi_template_path = openapi_template_path - self.api_client = api_client - - def render(self) -> str: - """Render the OpenAPI file. - - :return: the rendered output. - """ - - return render_template( - self.openapi_template_path, - api_client=self.api_client, - ) - - def to_dict(self) -> Dict: - """Render and output the OpenAPI file as a dictionary. - - :return: the dictionary. - """ - - return yaml.safe_load(self.render()) diff --git a/observatory-api/observatory/api/server/orm.py b/observatory-api/observatory/api/server/orm.py deleted file mode 100644 index fb7944c54..000000000 --- a/observatory-api/observatory/api/server/orm.py +++ /dev/null @@ -1,287 +0,0 @@ -# Copyright 2021 Curtin University -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# Author: James Diprose - - -from __future__ import annotations - -import os -from dataclasses import dataclass -from typing import Any, ClassVar, Dict, Union - -import pendulum -from sqlalchemy import ( - Column, - DateTime, - JSON, - Integer, - String, - create_engine, -) -from sqlalchemy.ext.declarative import declarative_base -from sqlalchemy.orm import scoped_session, sessionmaker - -Base = declarative_base() -session_ = None # Global session - - -def create_session(uri: str = os.environ.get("OBSERVATORY_DB_URI"), connect_args=None, poolclass=None): - """Create an SQLAlchemy session. - - :param uri: the database URI. - :param connect_args: connect arguments for SQLAlchemy. - :param poolclass: what SQLAlchemy poolclass to use. - :return: the SQLAlchemy session. - """ - - if uri is None: - raise ValueError( - "observatory.api.orm.create_session: please set the create_session `uri` parameter " - "or the environment variable OBSERVATORY_DB_URI with a valid PostgreSQL workflow string" - ) - - if connect_args is None: - connect_args = dict() - - engine = create_engine(uri, convert_unicode=True, connect_args=connect_args, poolclass=poolclass) - s = scoped_session(sessionmaker(autocommit=False, autoflush=False, bind=engine)) - Base.query = s.query_property() - Base.metadata.create_all(bind=engine) # create all tables. - - return s - - -def set_session(session): - """Set the SQLAlchemy session globally, within the orm module and the api module. - - :param session: the session to use. - :return: None. 
- """ - - global session_ - session_ = session - import observatory.api.server.api as api - - api.session_ = session - - -def fetch_db_object(cls: ClassVar, body: Any): - """Fetch a database object via SQLAlchemy. - - :param cls: the class of object to fetch. - :param body: the body of the object. If the body is None then None is returned (for the case where no object - exists), if the body is already of type cls then the body is returned as the object and if the body is a dictionary - with the key 'id' a query is made to fetch the given object. - :return: the object. - """ - - if body is None: - item = None - elif isinstance(body, cls): - item = body - elif isinstance(body, Dict): - if "id" not in body: - raise AttributeError(f"id not found in {body}") - - id = body["id"] - item = session_.query(cls).filter(cls.id == id).one_or_none() - if item is None: - raise ValueError(f"{item} with id {id} not found") - else: - raise ValueError(f"Unknown item type {body}") - - return item - - -def to_datetime_utc(obj: Union[None, pendulum.DateTime, str]) -> Union[pendulum.DateTime, None]: - """Converts pendulum.DateTime into UTC object. - - :param obj: a pendulum.DateTime object (which will just be converted to UTC) or None which will be returned. - :return: a DateTime object. - """ - - if isinstance(obj, pendulum.DateTime): - return obj.in_tz(tz="UTC") - elif isinstance(obj, str): - dt = pendulum.parse(obj) - return dt.in_tz(tz="UTC") - elif obj is None: - return None - - raise ValueError("body should be None or pendulum.DateTime") - - -@dataclass -class DatasetRelease(Base): - __tablename__ = "dataset_release" - - id: int - dag_id: str - dataset_id: str - dag_run_id: str - data_interval_start: pendulum.DateTime - data_interval_end: pendulum.DateTime - snapshot_date: pendulum.DateTime - partition_date: pendulum.DateTime - changefile_start_date: pendulum.DateTime - changefile_end_date: pendulum.DateTime - sequence_start: int - sequence_end: int - extra: Dict - created: pendulum.DateTime - modified: pendulum.DateTime - - id = Column(Integer, primary_key=True) - dag_id = Column(String(250), nullable=False) - dataset_id = Column(String(250), nullable=False) - dag_run_id = Column(String(250), nullable=True) - data_interval_start = Column(DateTime(), nullable=True) - data_interval_end = Column(DateTime(), nullable=True) - snapshot_date = Column(DateTime(), nullable=True) - partition_date = Column(DateTime(), nullable=True) - changefile_start_date = Column(DateTime(), nullable=True) - changefile_end_date = Column(DateTime(), nullable=True) - sequence_start = Column(Integer, nullable=True) - sequence_end = Column(Integer, nullable=True) - extra = Column(JSON(), nullable=True) - created = Column(DateTime()) - modified = Column(DateTime()) - - def __init__( - self, - id: int = None, - dag_id: str = None, - dataset_id: str = None, - dag_run_id: str = None, - data_interval_start: Union[pendulum.DateTime, str] = None, - data_interval_end: Union[pendulum.DateTime, str] = None, - snapshot_date: Union[pendulum.DateTime, str] = None, - partition_date: Union[pendulum.DateTime, str] = None, - changefile_start_date: Union[pendulum.DateTime, str] = None, - changefile_end_date: Union[pendulum.DateTime, str] = None, - sequence_start: int = None, - sequence_end: int = None, - extra: dict = None, - created: pendulum.DateTime = None, - modified: pendulum.DateTime = None, - ): - """Construct a DatasetRelease object. - - :param id: unique id. - :param dag_id: the DAG ID. - :param dataset_id: the dataset ID. 
- :param dag_run_id: the DAG's run ID. - :param data_interval_start: the DAGs data interval start. Date is inclusive. - :param data_interval_end: the DAGs data interval end. Date is exclusive. - :param snapshot_date: the release date of the snapshot. - :param partition_date: the partition date. - :param changefile_start_date: the date of the first changefile processed in this release. - :param changefile_end_date: the date of the last changefile processed in this release. - :param sequence_start: the starting sequence number of files that make up this release. - :param sequence_end: the end sequence number of files that make up this release. - :param extra: optional extra field for storing any data. - :param created: datetime created in UTC. - :param modified: datetime modified in UTC. - """ - - self.id = id - self.dag_id = dag_id - self.dataset_id = dataset_id - self.dag_run_id = dag_run_id - self.data_interval_start = to_datetime_utc(data_interval_start) - self.data_interval_end = to_datetime_utc(data_interval_end) - self.snapshot_date = to_datetime_utc(snapshot_date) - self.partition_date = to_datetime_utc(partition_date) - self.changefile_start_date = to_datetime_utc(changefile_start_date) - self.changefile_end_date = to_datetime_utc(changefile_end_date) - self.sequence_start = sequence_start - self.sequence_end = sequence_end - self.extra = extra - self.created = to_datetime_utc(created) - self.modified = to_datetime_utc(modified) - - def update( - self, - dag_id: str = None, - dataset_id: str = None, - dag_run_id: str = None, - data_interval_start: Union[pendulum.DateTime, str] = None, - data_interval_end: Union[pendulum.DateTime, str] = None, - snapshot_date: Union[pendulum.DateTime, str] = None, - partition_date: Union[pendulum.DateTime, str] = None, - changefile_start_date: Union[pendulum.DateTime, str] = None, - changefile_end_date: Union[pendulum.DateTime, str] = None, - sequence_start: int = None, - sequence_end: int = None, - extra: dict = None, - modified: pendulum.DateTime = None, - ): - """Update the properties of an existing DatasetRelease object. This method is handy when you want to update - the DatasetRelease from a dictionary, e.g. obj.update(**{'service': 'hello world'}). - - :param dag_id: the DAG ID. - :param dataset_id: the dataset ID. - :param dag_run_id: the DAG's run ID. - :param data_interval_start: the DAGs data interval start. Date is inclusive. - :param data_interval_end: the DAGs data interval end. Date is exclusive. - :param snapshot_date: the snapshot date. - :param partition_date: the partition date. - :param changefile_start_date: the date of the first changefile processed in this release. - :param changefile_end_date: the date of the last changefile processed in this release. - :param sequence_start: the starting sequence number of files that make up this release. - :param sequence_end: the end sequence number of files that make up this release. - :param extra: optional extra field for storing any data. - :param modified: datetime modified in UTC. - :return: None. 
- """ - - if dag_id is not None: - self.dag_id = dag_id - - if dataset_id is not None: - self.dataset_id = dataset_id - - if dag_run_id is not None: - self.dag_run_id = dag_run_id - - if data_interval_start is not None: - self.data_interval_start = to_datetime_utc(data_interval_start) - - if data_interval_end is not None: - self.data_interval_end = to_datetime_utc(data_interval_end) - - if snapshot_date is not None: - self.snapshot_date = to_datetime_utc(snapshot_date) - - if partition_date is not None: - self.partition_date = to_datetime_utc(partition_date) - - if changefile_start_date is not None: - self.changefile_start_date = to_datetime_utc(changefile_start_date) - - if changefile_end_date is not None: - self.changefile_end_date = to_datetime_utc(changefile_end_date) - - if sequence_start is not None: - self.sequence_start = sequence_start - - if sequence_end is not None: - self.sequence_end = sequence_end - - if extra is not None: - self.extra = extra - - if modified is not None: - self.modified = to_datetime_utc(modified) diff --git a/observatory-api/observatory/api/testing.py b/observatory-api/observatory/api/testing.py deleted file mode 100644 index 24f64099c..000000000 --- a/observatory-api/observatory/api/testing.py +++ /dev/null @@ -1,69 +0,0 @@ -# Copyright 2021 Curtin University -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# Author: James Diprose - -import contextlib -import threading - -from observatory.api.server.api import create_app -from observatory.api.server.orm import create_session, set_session -from sqlalchemy.pool import StaticPool -from werkzeug.serving import make_server - - -class ObservatoryApiEnvironment: - def __init__(self, host: str = "localhost", port: int = 5000, seed_db: bool = False): - """Create an ObservatoryApiEnvironment instance. - - :param host: the host name. - :param port: the port. - :param seed_db: whether to seed the database or not. - """ - - self.host = host - self.port = port - self.seed_db = seed_db - self.db_uri = "sqlite://" - self.session = None - self.server = None - self.server_thread = None - - @contextlib.contextmanager - def create(self): - """Make and destroy an Observatory API isolated environment, which involves: - - * Creating an in memory SQLite database for the API backend to connect to - * Start the Connexion / Flask app - - :yield: None. 
- """ - - try: - # Connect to in memory SQLite database with SQLAlchemy - self.session = create_session( - uri=self.db_uri, connect_args={"check_same_thread": False}, poolclass=StaticPool - ) - set_session(self.session) - - # Create the Connexion App and start the server - app = create_app() - self.server = make_server(self.host, self.port, app) - self.server_thread = threading.Thread(target=self.server.serve_forever) - self.server_thread.start() - yield - finally: - # Stop server and wait for server thread to join - self.server.shutdown() - self.server_thread.join() diff --git a/observatory-api/observatory/api/utils.py b/observatory-api/observatory/api/utils.py deleted file mode 100644 index e45744087..000000000 --- a/observatory-api/observatory/api/utils.py +++ /dev/null @@ -1,44 +0,0 @@ -# Copyright 2020 Curtin University -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -# -# Author: Tuan Chien - - -import os -from urllib.parse import urlparse - -from observatory.api.client import ApiClient, Configuration -from observatory.api.client.api.observatory_api import ObservatoryApi - - -def get_api_client(host: str = "localhost", port: int = 5002, api_key: dict = None) -> ObservatoryApi: - """Get an API client. - :param host: URI for api server. - :param port: Server port. - :param api_key: API key. - :return: ObservatoryApi object. - """ - - if "API_URI" in os.environ: - fields = urlparse(os.environ["API_URI"]) - uri = f"{fields.scheme}://{fields.hostname}:{fields.port}" - api_key = {"api_key": fields.password} - else: - uri = f"http://{host}:{port}" - - configuration = Configuration(host=uri, api_key=api_key) - api_client = ApiClient(configuration) - api = ObservatoryApi(api_client=api_client) - return api diff --git a/observatory-api/requirements.sh b/observatory-api/requirements.sh deleted file mode 100644 index 800cd0fd6..000000000 --- a/observatory-api/requirements.sh +++ /dev/null @@ -1,15 +0,0 @@ -#!/usr/bin/env bash -# Copyright 2021 Curtin University -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
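The two helpers deleted above were designed to be used together in tests: `ObservatoryApiEnvironment.create()` starts the Connexion server against an in-memory SQLite database, and `get_api_client` returns a generated `ObservatoryApi` client pointed at it. A minimal sketch of that pattern (the host, port and query values are illustrative and not taken from the deleted files):

```python
from observatory.api.testing import ObservatoryApiEnvironment
from observatory.api.utils import get_api_client

# Start an isolated API server backed by an in-memory SQLite database,
# then query it with the generated ObservatoryApi client.
env = ObservatoryApiEnvironment(host="localhost", port=5002)
with env.create():
    api = get_api_client(host="localhost", port=5002)
    releases = api.get_dataset_releases(dag_id="doi_workflow", dataset_id="doi")
    print(f"Found {len(releases)} releases")
```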
- diff --git a/observatory-api/requirements.txt b/observatory-api/requirements.txt deleted file mode 100644 index 2eed07357..000000000 --- a/observatory-api/requirements.txt +++ /dev/null @@ -1,14 +0,0 @@ -gunicorn>=20.1.0,<21 -Flask>=2,<3 -connexion[swagger-ui]>=2.7.0,<3 -pyyaml>=6,<7 -SQLAlchemy>=1.4,<2 -psycopg2-binary>=2.9.1,<3 -urllib3>=1.25.11,<2 -python-dateutil>=2.8.2,<3 -certifi -pendulum>=2.1.2,<3 -openapi-spec-validator>=0.3.1,<1 -Jinja2>=3,<4 -markupsafe>=2,<3 - diff --git a/observatory-api/setup.cfg b/observatory-api/setup.cfg deleted file mode 100644 index 89d9d20d1..000000000 --- a/observatory-api/setup.cfg +++ /dev/null @@ -1,58 +0,0 @@ -[metadata] -name = observatory-api -author = Curtin University -author_email = agent@observatory.academy -summary = -description_file = README.md -description_content_type = text/markdown; charset=UTF-8 -home_page = https://github.com/The-Academic-Observatory/observatory-platform -project_urls = - Bug Tracker = https://github.com/The-Academic-Observatory/observatory-platform/issues - Documentation = https://observatory-platform.readthedocs.io/en/latest/ - Source Code = https://github.com/The-Academic-Observatory/observatory-platform -python_requires = >=3.10 -license = Apache License Version 2.0 -classifier = - Development Status :: 2 - Pre-Alpha - Environment :: Console - Environment :: Web Environment - Intended Audience :: Developers - Intended Audience :: Science/Research - License :: OSI Approved :: Apache Software License - Operating System :: OS Independent - Programming Language :: Python :: 3 :: Only - Programming Language :: Python :: 3 - Programming Language :: Python :: 3.10 - Topic :: Scientific/Engineering - Topic :: Software Development :: Libraries - Topic :: Software Development :: Libraries :: Python Modules - Topic :: Utilities -keywords = - science - data - workflows - academic institutes - observatory-api - -[files] -packages = - observatory - observatory.api -data_files = - requirements.txt = requirements.txt - requirements.sh = requirements.sh - observatory/api = - observatory/api/server/openapi.yaml.jinja2 - -[entry_points] -console_scripts = - observatory-api = observatory.api.cli:cli - -[extras] -tests = - liccheck>=0.4.9,<1 - flake8>=3.8.0,<4 - coverage>=5.2,<6 - -[pbr] -skip_authors = true diff --git a/observatory-api/setup.py b/observatory-api/setup.py deleted file mode 100644 index 2c3a057c4..000000000 --- a/observatory-api/setup.py +++ /dev/null @@ -1,3 +0,0 @@ -from setuptools import setup - -setup(setup_requires=["pbr"], pbr=True, python_requires=">=3.10") diff --git a/observatory-api/templates/README_common.mustache b/observatory-api/templates/README_common.mustache deleted file mode 100644 index 6da05b337..000000000 --- a/observatory-api/templates/README_common.mustache +++ /dev/null @@ -1,116 +0,0 @@ -```python -{{#apiInfo}}{{#apis}}{{#-last}}{{#hasHttpSignatureMethods}}import datetime{{/hasHttpSignatureMethods}}{{/-last}}{{/apis}}{{/apiInfo}} -import time -import {{{packageName}}} -from pprint import pprint -{{#apiInfo}} -{{#apis}} -{{#-first}} -from {{apiPackage}} import {{classVarName}} -{{#imports}} -{{{import}}} -{{/imports}} -{{#operations}} -{{#operation}} -{{#-first}} -{{> python_doc_auth_partial}} - -# Enter a context with an instance of the API client -with {{{packageName}}}.ApiClient(configuration) as api_client: - # Create an instance of the API class - api_instance = {{classVarName}}.{{{classname}}}(api_client) - {{#allParams}}{{paramName}} = {{{example}}} # {{{dataType}}} | 
{{{description}}}{{^required}} (optional){{/required}}{{#defaultValue}} (default to {{{.}}}){{/defaultValue}} - {{/allParams}} - - try: - {{#summary}} # {{{.}}} - {{/summary}} {{#returnType}}api_response = {{/returnType}}api_instance.{{{operationId}}}({{#allParams}}{{#required}}{{paramName}}{{/required}}{{^required}}{{paramName}}={{paramName}}{{/required}}{{^-last}}, {{/-last}}{{/allParams}}){{#returnType}} - pprint(api_response){{/returnType}} - except {{{packageName}}}.ApiException as e: - print("Exception when calling {{classname}}->{{operationId}}: %s\n" % e) -{{/-first}} -{{/operation}} -{{/operations}} -{{/-first}} -{{/apis}} -{{/apiInfo}} -``` - -## Documentation for API Endpoints - -All URIs are relative to *{{basePath}}* - -```eval_rst -.. toctree:: - :maxdepth: 1 - - ObservatoryApi -``` - -
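For concreteness, the templated usage snippet above renders to something like the following for this package (a sketch based on the generated-client module layout referenced elsewhere in this patch, not actual generator output; the host and API key values are placeholders):

```python
import observatory.api.client
from observatory.api.client.api import observatory_api
from observatory.api.client.model.dataset_release import DatasetRelease

# Configure the client, including the `api_key` security scheme defined in the OpenAPI spec.
configuration = observatory.api.client.Configuration(host="https://localhost:5002")
configuration.api_key["api_key"] = "YOUR_API_KEY"

# Enter a context with an instance of the API client
with observatory.api.client.ApiClient(configuration) as api_client:
    api_instance = observatory_api.ObservatoryApi(api_client)
    body = DatasetRelease(dag_id="doi_workflow", dataset_id="doi")
    try:
        # Create a DatasetRelease
        api_response = api_instance.post_dataset_release(body)
        print(api_response)
    except observatory.api.client.ApiException as e:
        print("Exception when calling ObservatoryApi->post_dataset_release: %s" % e)
```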
-<table>
-<tr>
-<th>Class</th>
-<th>Method</th>
-<th>HTTP request</th>
-<th>Description</th>
-</tr>
-{{#apiInfo}}{{#apis}}{{#operations}}{{#operation}}
-<tr>
-<td>{{classname}}</td>
-<td>{{operationId}}</td>
-<td>{{httpMethod}} {{path}}</td>
-<td>{{#summary}}{{summary}}{{/summary}}</td>
-</tr>
-{{/operation}}{{/operations}}{{/apis}}{{/apiInfo}}
-</table>
- -## Documentation For Models -```eval_rst -.. toctree:: - :maxdepth: 1 - - {{#models}}{{#model}}{{{classname}}} - {{/model}}{{/models}} -``` - -## Documentation For Authorization - -{{^authMethods}} - All endpoints do not require authorization. -{{/authMethods}} -{{#authMethods}} -{{#last}} Authentication schemes defined for the API:{{/last}} -## {{{name}}} - -{{#isApiKey}} -- **Type**: API key -- **API key parameter name**: {{{keyParamName}}} -- **Location**: {{#isKeyInQuery}}URL query string{{/isKeyInQuery}}{{#isKeyInHeader}}HTTP header{{/isKeyInHeader}} -{{/isApiKey}} -{{#isBasic}} -{{#isBasicBasic}} -- **Type**: HTTP basic authentication -{{/isBasicBasic}} -{{#isBasicBearer}} -- **Type**: Bearer authentication{{#bearerFormat}} ({{{.}}}){{/bearerFormat}} -{{/isBasicBearer}} -{{#isHttpSignature}} -- **Type**: HTTP signature authentication -{{/isHttpSignature}} -{{/isBasic}} -{{#isOAuth}} -- **Type**: OAuth -- **Flow**: {{{flow}}} -- **Authorization URL**: {{{authorizationUrl}}} -- **Scopes**: {{^scopes}}N/A{{/scopes}} -{{#scopes}} - **{{{scope}}}**: {{{description}}} -{{/scopes}} -{{/isOAuth}} - -{{/authMethods}} \ No newline at end of file diff --git a/observatory-api/templates/README_onlypackage.mustache b/observatory-api/templates/README_onlypackage.mustache deleted file mode 100644 index ea1ef9524..000000000 --- a/observatory-api/templates/README_onlypackage.mustache +++ /dev/null @@ -1,37 +0,0 @@ -# Python API Client -{{#appDescription}} -{{{appDescription}}} -{{/appDescription}} - -The `{{packageName}}` package is automatically generated by the [OpenAPI Generator](https://openapi-generator.tech) project: - -- API version: {{appVersion}} -- Package version: {{packageVersion}} -{{^hideGenerationTimestamp}} -- Build date: {{generatedDate}} -{{/hideGenerationTimestamp}} -{{#infoUrl}} -For more information, please visit [{{{infoUrl}}}]({{{infoUrl}}}) -{{/infoUrl}} - -## Requirements -Python >= 3.10 - -## Installation & Usage -To install the package with PyPI: -```bash -pip install observatory-api -``` - -To install the package from source: -```bash -git clone https://github.com/The-Academic-Observatory/observatory-platform.git -cd observatory-platform -pip install -e observatory-api -``` - -## Getting Started -In your own code, to use this library to connect and interact with {{{projectName}}}, -you can run the following: - -{{> README_common }} \ No newline at end of file diff --git a/observatory-api/templates/api_doc.mustache b/observatory-api/templates/api_doc.mustache deleted file mode 100644 index 9369a71ae..000000000 --- a/observatory-api/templates/api_doc.mustache +++ /dev/null @@ -1,148 +0,0 @@ -# {{classname}}{{#description}} -{{description}}{{/description}} - -All URIs are relative to *{{basePath}}* - -
-<table>
-<tr>
-<th>Method</th>
-<th>HTTP request</th>
-<th>Description</th>
-</tr>
-{{#operations}}{{#operation}}
-<tr>
-<td>{{operationId}}</td>
-<td>{{httpMethod}} {{path}}</td>
-<td>{{#summary}}{{summary}}{{/summary}}</td>
-</tr>
-{{/operation}}{{/operations}}
-</table>
- -{{#operations}} -{{#operation}} -## **{{{operationId}}}** -> {{#returnType}}{{{returnType}}} {{/returnType}}{{{operationId}}}({{#requiredParams}}{{^defaultValue}}{{paramName}}{{^-last}}, {{/-last}}{{/defaultValue}}{{/requiredParams}}) - -{{{summary}}}{{#notes}} - -{{{notes}}}{{/notes}} - -### Example - -{{#hasAuthMethods}} -{{#authMethods}} -{{#isBasic}} -{{#isBasicBasic}} -* Basic Authentication ({{name}}): -{{/isBasicBasic}} -{{#isBasicBearer}} -* Bearer{{#bearerFormat}} ({{{.}}}){{/bearerFormat}} Authentication ({{name}}): -{{/isBasicBearer}} -{{/isBasic}} -{{#isApiKey}} -* Api Key Authentication ({{name}}): -{{/isApiKey }} -{{#isOAuth}} -* OAuth Authentication ({{name}}): -{{/isOAuth }} -{{/authMethods}} -{{/hasAuthMethods}} -{{> api_doc_example }} - -### Parameters -{{^allParams}}This endpoint does not need any parameter.{{/allParams}}{{#allParams}}{{#-last}} - -
-<table>
-<tr>
-<th>Name</th>
-<th>Type</th>
-<th>Description</th>
-<th>Notes</th>
-</tr>
-{{/-last}}{{/allParams}}
-{{#requiredParams}}{{^defaultValue}}
-<tr>
-<td>{{paramName}}</td>
-{{^baseType}}<td>{{dataType}}</td>{{/baseType}}{{#baseType}}<td>{{dataType}}</td>{{/baseType}}
-<td>{{description}}</td>
-<td></td>
-</tr>
-{{/defaultValue}}{{/requiredParams}}
-
-{{#requiredParams}}{{#defaultValue}}
-<tr>
-<td>{{paramName}}</td>
-{{^baseType}}<td>{{dataType}}</td>{{/baseType}}{{#baseType}}<td>{{dataType}}</td>{{/baseType}}
-<td>{{description}}</td>
-<td>defaults to {{{.}}}</td>
-</tr>
-{{/defaultValue}}
-
-{{/requiredParams}}{{#optionalParams}}
-<tr>
-<td>{{paramName}}</td>
-{{^baseType}}<td>{{dataType}}</td>{{/baseType}}{{#baseType}}<td>{{dataType}}</td>{{/baseType}}
-<td>{{description}}</td>
-<td>[optional]{{#defaultValue}} if omitted the server will use the default value of {{{.}}}{{/defaultValue}}</td>
-</tr>
-{{/optionalParams}}
-</table>
- - -### Return type - -{{#returnType}}{{#returnTypeIsPrimitive}}**{{{returnType}}}**{{/returnTypeIsPrimitive}}{{^returnTypeIsPrimitive}}[**{{{returnType}}}**]({{returnBaseType}}.html){{/returnTypeIsPrimitive}}{{/returnType}}{{^returnType}}void (empty response body){{/returnType}} - -### Authorization - -{{^authMethods}}No authorization required{{/authMethods}}{{#authMethods}}[{{{name}}}](ObservatoryApi.html#{{{name}}}){{^-last}}, {{/-last}}{{/authMethods}} - -### HTTP request headers - - - **Content-Type**: {{#consumes}}{{{mediaType}}}{{^-last}}, {{/-last}}{{/consumes}}{{^consumes}}Not defined{{/consumes}} - - **Accept**: {{#produces}}{{{mediaType}}}{{^-last}}, {{/-last}}{{/produces}}{{^produces}}Not defined{{/produces}} - - -### HTTP response details -
-<table>
-<tr>
-<th>Status code</th>
-<th>Description</th>
-<th>Response headers</th>
-</tr>
-{{#responses.0}}
-{{#responses}}
-<tr>
-<td>{{code}}</td>
-<td>{{message}}</td>
-<td>{{#headers}} * {{baseName}} - {{description}}
-{{/headers}}{{^headers.0}} - {{/headers.0}}</td>
-</tr>
-{{/responses}}
-{{/responses.0}}
-</table>
- -{{/operation}} -{{/operations}} \ No newline at end of file diff --git a/observatory-api/templates/configuration.mustache b/observatory-api/templates/configuration.mustache deleted file mode 100644 index 6a200c698..000000000 --- a/observatory-api/templates/configuration.mustache +++ /dev/null @@ -1,627 +0,0 @@ -{{>partial_header}} - -import copy -import logging -{{^asyncio}} -import multiprocessing -{{/asyncio}} -import sys -import certifi -import urllib3 - -from http import client as http_client -from {{packageName}}.exceptions import ApiValueError - - -JSON_SCHEMA_VALIDATION_KEYWORDS = { - 'multipleOf', 'maximum', 'exclusiveMaximum', - 'minimum', 'exclusiveMinimum', 'maxLength', - 'minLength', 'pattern', 'maxItems', 'minItems' -} - -class Configuration(object): - """NOTE: This class is auto generated by OpenAPI Generator - - Ref: https://openapi-generator.tech - Do not edit the class manually. - - :param host: Base url - :param api_key: Dict to store API key(s). - Each entry in the dict specifies an API key. - The dict key is the name of the security scheme in the OAS specification. - The dict value is the API key secret. - :param api_key_prefix: Dict to store API prefix (e.g. Bearer) - The dict key is the name of the security scheme in the OAS specification. - The dict value is an API key prefix when generating the auth data. - :param username: Username for HTTP basic authentication - :param password: Password for HTTP basic authentication - :param discard_unknown_keys: Boolean value indicating whether to discard - unknown properties. A server may send a response that includes additional - properties that are not known by the client in the following scenarios: - 1. The OpenAPI document is incomplete, i.e. it does not match the server - implementation. - 2. The client was generated using an older version of the OpenAPI document - and the server has been upgraded since then. - If a schema in the OpenAPI document defines the additionalProperties attribute, - then all undeclared properties received by the server are injected into the - additional properties map. In that case, there are undeclared properties, and - nothing to discard. - :param disabled_client_side_validations (string): Comma-separated list of - JSON schema validation keywords to disable JSON schema structural validation - rules. The following keywords may be specified: multipleOf, maximum, - exclusiveMaximum, minimum, exclusiveMinimum, maxLength, minLength, pattern, - maxItems, minItems. - By default, the validation is performed for data generated locally by the client - and data received from the server, independent of any validation performed by - the server side. If the input data does not satisfy the JSON schema validation - rules specified in the OpenAPI document, an exception is raised. - If disabled_client_side_validations is set, structural validation is - disabled. This can be useful to troubleshoot data validation problem, such as - when the OpenAPI document validation rules do not match the actual API data - received by the server. -{{#hasHttpSignatureMethods}} - :param signing_info: Configuration parameters for the HTTP signature security scheme. - Must be an instance of {{{packageName}}}.signing.HttpSigningConfiguration -{{/hasHttpSignatureMethods}} - :param server_index: Index to servers configuration. - :param server_variables: Mapping with string values to replace variables in - templated server configuration. The validation of enums is performed for - variables with defined enum values before. 
- :param server_operation_index: Mapping from operation ID to an index to server - configuration. - :param server_operation_variables: Mapping from operation ID to a mapping with - string values to replace variables in templated server configuration. - The validation of enums is performed for variables with defined enum values before. - :param ssl_ca_cert: str - the path to a file of concatenated CA certificates - in PEM format - -{{#hasAuthMethods}} - :Example: -{{#hasApiKeyMethods}} - - API Key Authentication Example. - Given the following security scheme in the OpenAPI specification: - components: - securitySchemes: - cookieAuth: # name for the security scheme - type: apiKey - in: cookie - name: JSESSIONID # cookie name - - You can programmatically set the cookie: - -conf = {{{packageName}}}.Configuration( - api_key={'cookieAuth': 'abc123'} - api_key_prefix={'cookieAuth': 'JSESSIONID'} -) - - The following cookie will be added to the HTTP request: - Cookie: JSESSIONID abc123 -{{/hasApiKeyMethods}} -{{#hasHttpBasicMethods}} - - HTTP Basic Authentication Example. - Given the following security scheme in the OpenAPI specification: - components: - securitySchemes: - http_basic_auth: - type: http - scheme: basic - - Configure API client with HTTP basic authentication: - -conf = {{{packageName}}}.Configuration( - username='the-user', - password='the-password', -) - -{{/hasHttpBasicMethods}} -{{#hasHttpSignatureMethods}} - - HTTP Signature Authentication Example. - Given the following security scheme in the OpenAPI specification: - components: - securitySchemes: - http_basic_auth: - type: http - scheme: signature - - Configure API client with HTTP signature authentication. Use the 'hs2019' signature scheme, - sign the HTTP requests with the RSA-SSA-PSS signature algorithm, and set the expiration time - of the signature to 5 minutes after the signature has been created. - Note you can use the constants defined in the {{{packageName}}}.signing module, and you can - also specify arbitrary HTTP headers to be included in the HTTP signature, except for the - 'Authorization' header, which is used to carry the signature. - - One may be tempted to sign all headers by default, but in practice it rarely works. - This is beccause explicit proxies, transparent proxies, TLS termination endpoints or - load balancers may add/modify/remove headers. Include the HTTP headers that you know - are not going to be modified in transit. 
- -conf = {{{packageName}}}.Configuration( - signing_info = {{{packageName}}}.signing.HttpSigningConfiguration( - key_id = 'my-key-id', - private_key_path = 'rsa.pem', - signing_scheme = {{{packageName}}}.signing.SCHEME_HS2019, - signing_algorithm = {{{packageName}}}.signing.ALGORITHM_RSASSA_PSS, - signed_headers = [{{{packageName}}}.signing.HEADER_REQUEST_TARGET, - {{{packageName}}}.signing.HEADER_CREATED, - {{{packageName}}}.signing.HEADER_EXPIRES, - {{{packageName}}}.signing.HEADER_HOST, - {{{packageName}}}.signing.HEADER_DATE, - {{{packageName}}}.signing.HEADER_DIGEST, - 'Content-Type', - 'User-Agent' - ], - signature_max_validity = datetime.timedelta(minutes=5) - ) -) -{{/hasHttpSignatureMethods}} -{{/hasAuthMethods}} - """ - - _default = None - - def __init__(self, host=None, - api_key=None, api_key_prefix=None, - access_token=None, - username=None, password=None, - discard_unknown_keys=False, - disabled_client_side_validations="", -{{#hasHttpSignatureMethods}} - signing_info=None, -{{/hasHttpSignatureMethods}} - server_index=None, server_variables=None, - server_operation_index=None, server_operation_variables=None, - ssl_ca_cert=certifi.where(), - ): - """Constructor - """ - self._base_path = "{{{basePath}}}" if host is None else host - """Default Base url - """ - self.server_index = 0 if server_index is None and host is None else server_index - self.server_operation_index = server_operation_index or {} - """Default server index - """ - self.server_variables = server_variables or {} - self.server_operation_variables = server_operation_variables or {} - """Default server variables - """ - self.temp_folder_path = None - """Temp file folder for downloading files - """ - # Authentication Settings - self.access_token = access_token - self.api_key = {} - if api_key: - self.api_key = api_key - """dict to store API key(s) - """ - self.api_key_prefix = {} - if api_key_prefix: - self.api_key_prefix = api_key_prefix - """dict to store API prefix (e.g. Bearer) - """ - self.refresh_api_key_hook = None - """function hook to refresh API key if expired - """ - self.username = username - """Username for HTTP basic authentication - """ - self.password = password - """Password for HTTP basic authentication - """ - self.discard_unknown_keys = discard_unknown_keys - self.disabled_client_side_validations = disabled_client_side_validations -{{#hasHttpSignatureMethods}} - if signing_info is not None: - signing_info.host = host - self.signing_info = signing_info - """The HTTP signing configuration - """ -{{/hasHttpSignatureMethods}} - self.logger = {} - """Logging Settings - """ - self.logger["package_logger"] = logging.getLogger("{{packageName}}") - self.logger["urllib3_logger"] = logging.getLogger("urllib3") - self.logger_format = '%(asctime)s %(levelname)s %(message)s' - """Log format - """ - self.logger_stream_handler = None - """Log stream handler - """ - self.logger_file_handler = None - """Log file handler - """ - self.logger_file = None - """Debug file location - """ - self.debug = False - """Debug switch - """ - - self.verify_ssl = True - """SSL/TLS verification - Set this to false to skip verifying SSL certificate when calling API - from https server. - """ - self.ssl_ca_cert = ssl_ca_cert - """Set this to customize the certificate file to verify the peer. - """ - self.cert_file = None - """client certificate file - """ - self.key_file = None - """client key file - """ - self.assert_hostname = None - """Set this to True/False to enable/disable SSL hostname verification. 
- """ - - {{#asyncio}} - self.connection_pool_maxsize = 100 - """This value is passed to the aiohttp to limit simultaneous connections. - Default values is 100, None means no-limit. - """ - {{/asyncio}} - {{^asyncio}} - self.connection_pool_maxsize = multiprocessing.cpu_count() * 5 - """urllib3 connection pool's maximum number of connections saved - per pool. urllib3 uses 1 connection as default value, but this is - not the best value when you are making a lot of possibly parallel - requests to the same host, which is often the case here. - cpu_count * 5 is used as default value to increase performance. - """ - {{/asyncio}} - - self.proxy = None - """Proxy URL - """ - self.proxy_headers = None - """Proxy headers - """ - self.safe_chars_for_path_param = '' - """Safe chars for path_param - """ - self.retries = None - """Adding retries to override urllib3 default value 3 - """ - # Enable client side validation - self.client_side_validation = True - - # Options to pass down to the underlying urllib3 socket - self.socket_options = None - - def __deepcopy__(self, memo): - cls = self.__class__ - result = cls.__new__(cls) - memo[id(self)] = result - for k, v in self.__dict__.items(): - if k not in ('logger', 'logger_file_handler'): - setattr(result, k, copy.deepcopy(v, memo)) - # shallow copy of loggers - result.logger = copy.copy(self.logger) - # use setters to configure loggers - result.logger_file = self.logger_file - result.debug = self.debug - return result - - def __setattr__(self, name, value): - object.__setattr__(self, name, value) - if name == 'disabled_client_side_validations': - s = set(filter(None, value.split(','))) - for v in s: - if v not in JSON_SCHEMA_VALIDATION_KEYWORDS: - raise ApiValueError( - "Invalid keyword: '{0}''".format(v)) - self._disabled_client_side_validations = s -{{#hasHttpSignatureMethods}} - if name == "signing_info" and value is not None: - # Ensure the host paramater from signing info is the same as - # Configuration.host. - value.host = self.host -{{/hasHttpSignatureMethods}} - - @classmethod - def set_default(cls, default): - """Set default instance of configuration. - - It stores default configuration, which can be - returned by get_default_copy method. - - :param default: object of Configuration - """ - cls._default = copy.deepcopy(default) - - @classmethod - def get_default_copy(cls): - """Return new instance of configuration. - - This method returns newly created, based on default constructor, - object of Configuration class or returns a copy of default - configuration passed by the set_default method. - - :return: The configuration object. - """ - if cls._default is not None: - return copy.deepcopy(cls._default) - return Configuration() - - @property - def logger_file(self): - """The logger file. - - If the logger_file is None, then add stream handler and remove file - handler. Otherwise, add file handler and remove stream handler. - - :param value: The logger_file path. - :type: str - """ - return self.__logger_file - - @logger_file.setter - def logger_file(self, value): - """The logger file. - - If the logger_file is None, then add stream handler and remove file - handler. Otherwise, add file handler and remove stream handler. - - :param value: The logger_file path. - :type: str - """ - self.__logger_file = value - if self.__logger_file: - # If set logging file, - # then add file handler and remove stream handler. 
- self.logger_file_handler = logging.FileHandler(self.__logger_file) - self.logger_file_handler.setFormatter(self.logger_formatter) - for _, logger in self.logger.items(): - logger.addHandler(self.logger_file_handler) - - @property - def debug(self): - """Debug status - - :param value: The debug status, True or False. - :type: bool - """ - return self.__debug - - @debug.setter - def debug(self, value): - """Debug status - - :param value: The debug status, True or False. - :type: bool - """ - self.__debug = value - if self.__debug: - # if debug status is True, turn on debug logging - for _, logger in self.logger.items(): - logger.setLevel(logging.DEBUG) - # turn on http_client debug - http_client.HTTPConnection.debuglevel = 1 - else: - # if debug status is False, turn off debug logging, - # setting log level to default `logging.WARNING` - for _, logger in self.logger.items(): - logger.setLevel(logging.WARNING) - # turn off http_client debug - http_client.HTTPConnection.debuglevel = 0 - - @property - def logger_format(self): - """The logger format. - - The logger_formatter will be updated when sets logger_format. - - :param value: The format string. - :type: str - """ - return self.__logger_format - - @logger_format.setter - def logger_format(self, value): - """The logger format. - - The logger_formatter will be updated when sets logger_format. - - :param value: The format string. - :type: str - """ - self.__logger_format = value - self.logger_formatter = logging.Formatter(self.__logger_format) - - def get_api_key_with_prefix(self, identifier, alias=None): - """Gets API key (with prefix if set). - - :param identifier: The identifier of apiKey. - :param alias: The alternative identifier of apiKey. - :return: The token for api key authentication. - """ - if self.refresh_api_key_hook is not None: - self.refresh_api_key_hook(self) - key = self.api_key.get(identifier, self.api_key.get(alias) if alias is not None else None) - if key: - prefix = self.api_key_prefix.get(identifier) - if prefix: - return "%s %s" % (prefix, key) - else: - return key - - def get_basic_auth_token(self): - """Gets HTTP basic authentication header (string). - - :return: The token for basic HTTP authentication. - """ - username = "" - if self.username is not None: - username = self.username - password = "" - if self.password is not None: - password = self.password - return urllib3.util.make_headers( - basic_auth=username + ':' + password - ).get('authorization') - - def auth_settings(self): - """Gets Auth Settings dict for api client. - - :return: The Auth Settings information dict. 
- """ - auth = {} -{{#authMethods}} -{{#isApiKey}} - if '{{name}}' in self.api_key{{#vendorExtensions.x-auth-id-alias}} or '{{.}}' in self.api_key{{/vendorExtensions.x-auth-id-alias}}: - auth['{{name}}'] = { - 'type': 'api_key', - 'in': {{#isKeyInCookie}}'cookie'{{/isKeyInCookie}}{{#isKeyInHeader}}'header'{{/isKeyInHeader}}{{#isKeyInQuery}}'query'{{/isKeyInQuery}}, - 'key': '{{keyParamName}}', - 'value': self.get_api_key_with_prefix( - '{{name}}',{{#vendorExtensions.x-auth-id-alias}} - alias='{{.}}',{{/vendorExtensions.x-auth-id-alias}} - ), - } -{{/isApiKey}} -{{#isBasic}} - {{#isBasicBasic}} - if self.username is not None and self.password is not None: - auth['{{name}}'] = { - 'type': 'basic', - 'in': 'header', - 'key': 'Authorization', - 'value': self.get_basic_auth_token() - } - {{/isBasicBasic}} - {{#isBasicBearer}} - if self.access_token is not None: - auth['{{name}}'] = { - 'type': 'bearer', - 'in': 'header', - {{#bearerFormat}} - 'format': '{{{.}}}', - {{/bearerFormat}} - 'key': 'Authorization', - 'value': 'Bearer ' + self.access_token - } - {{/isBasicBearer}} - {{#isHttpSignature}} - if self.signing_info is not None: - auth['{{name}}'] = { - 'type': 'http-signature', - 'in': 'header', - 'key': 'Authorization', - 'value': None # Signature headers are calculated for every HTTP request - } - {{/isHttpSignature}} -{{/isBasic}} -{{#isOAuth}} - if self.access_token is not None: - auth['{{name}}'] = { - 'type': 'oauth2', - 'in': 'header', - 'key': 'Authorization', - 'value': 'Bearer ' + self.access_token - } -{{/isOAuth}} -{{/authMethods}} - return auth - - def to_debug_report(self): - """Gets the essential information for debugging. - - :return: The report for debugging. - """ - return "Python SDK Debug Report:\n"\ - "OS: {env}\n"\ - "Python Version: {pyversion}\n"\ - "Version of the API: {{version}}\n"\ - "SDK Package Version: {{packageVersion}}".\ - format(env=sys.platform, pyversion=sys.version) - - def get_host_settings(self): - """Gets an array of host settings - - :return: An array of host settings - """ - return [ - {{#servers}} - { - 'url': "{{{url}}}", - 'description': "{{{description}}}{{^description}}No description provided{{/description}}", - {{#variables}} - {{#-first}} - 'variables': { - {{/-first}} - '{{{name}}}': { - 'description': "{{{description}}}{{^description}}No description provided{{/description}}", - 'default_value': "{{{defaultValue}}}", - {{#enumValues}} - {{#-first}} - 'enum_values': [ - {{/-first}} - "{{{.}}}"{{^-last}},{{/-last}} - {{#-last}} - ] - {{/-last}} - {{/enumValues}} - }{{^-last}},{{/-last}} - {{#-last}} - } - {{/-last}} - {{/variables}} - }{{^-last}},{{/-last}} - {{/servers}} - ] - - def get_host_from_settings(self, index, variables=None, servers=None): - """Gets host URL based on the index and variables - :param index: array index of the host settings - :param variables: hash of variable and the corresponding value - :param servers: an array of host settings or None - :return: URL based on host settings - """ - if index is None: - return self._base_path - - variables = {} if variables is None else variables - servers = self.get_host_settings() if servers is None else servers - - try: - server = servers[index] - except IndexError: - raise ValueError( - "Invalid index {0} when selecting the host settings. 
" - "Must be less than {1}".format(index, len(servers))) - - url = server['url'] - - # go through variables and replace placeholders - for variable_name, variable in server.get('variables', {}).items(): - used_value = variables.get( - variable_name, variable['default_value']) - - if 'enum_values' in variable \ - and used_value not in variable['enum_values']: - raise ValueError( - "The variable `{0}` in the host URL has invalid value " - "{1}. Must be {2}.".format( - variable_name, variables[variable_name], - variable['enum_values'])) - - url = url.replace("{" + variable_name + "}", used_value) - - return url - - @property - def host(self): - """Return generated host.""" - return self.get_host_from_settings(self.server_index, variables=self.server_variables) - - @host.setter - def host(self, value): - """Fix base path.""" - self._base_path = value - self.server_index = None \ No newline at end of file diff --git a/observatory-api/templates/model_doc.mustache b/observatory-api/templates/model_doc.mustache deleted file mode 100644 index 37291aabc..000000000 --- a/observatory-api/templates/model_doc.mustache +++ /dev/null @@ -1,87 +0,0 @@ -{{#models}}{{#model}}# {{classname}} - -{{#description}}{{&description}} -{{/description}} -## Properties -
-<table>
-<tr>
-<th>Name</th>
-<th>Type</th>
-<th>Description</th>
-<th>Notes</th>
-</tr>
-{{#isEnum}}
-<tr>
-<td>value</td>
-<td>{{^arrayModelType}}{{dataType}}{{/arrayModelType}}</td>
-<td>{{description}}</td>
-<td>{{#defaultValue}}{{#hasRequired}} if omitted the server will use the default value of {{/hasRequired}}{{^hasRequired}}defaults to {{/hasRequired}}{{{.}}}{{/defaultValue}}{{#allowableValues}}{{#defaultValue}}, {{/defaultValue}} must be one of [{{#enumVars}}{{{value}}}, {{/enumVars}}]{{/allowableValues}}</td>
-</tr>
-{{/isEnum}}
-
-{{#isAlias}}
-<tr>
-<td>value</td>
-<td>{{^arrayModelType}}{{dataType}}{{/arrayModelType}}</td>
-<td>{{description}}</td>
-<td>{{#defaultValue}}{{#hasRequired}} if omitted the server will use the default value of {{/hasRequired}}{{^hasRequired}}defaults to {{/hasRequired}}{{{.}}}{{/defaultValue}}</td>
-</tr>
-{{/isAlias}}
-
-{{#isArray}}
-<tr>
-<td>value</td>
-<td>{{^arrayModelType}}{{dataType}}{{/arrayModelType}}{{#arrayModelType}}{{dataType}}{{/arrayModelType}}</td>
-<td>{{description}}</td>
-<td>{{#defaultValue}}{{#hasRequired}} if omitted the server will use the default value of {{/hasRequired}}{{^hasRequired}}defaults to {{/hasRequired}}{{{.}}}{{/defaultValue}}</td>
-</tr>
-{{/isArray}}
-
-{{#requiredVars}}
-{{^defaultValue}}
-<tr>
-<td>{{name}}</td>
-<td>{{^complexType}}{{dataType}}{{/complexType}}{{#complexType}}{{dataType}}{{/complexType}}</td>
-<td>{{description}}</td>
-<td>{{#isReadOnly}}[readonly] {{/isReadOnly}}</td>
-</tr>
-{{/defaultValue}}
-{{/requiredVars}}
-
-{{#requiredVars}}
-{{#defaultValue}}
-<tr>
-<td>{{name}}</td>
-<td>{{^complexType}}{{dataType}}{{/complexType}}{{#complexType}}{{dataType}}{{/complexType}}</td>
-<td>{{description}}</td>
-<td>{{^required}}[optional] {{/required}}{{#isReadOnly}}[readonly] {{/isReadOnly}}{{#defaultValue}}defaults to {{{.}}}{{/defaultValue}}</td>
-</tr>
-{{/defaultValue}}
-{{/requiredVars}}
-
-{{#optionalVars}}
-<tr>
-<td>{{name}}</td>
-<td>{{^complexType}}{{dataType}}{{/complexType}}{{#complexType}}{{dataType}}{{/complexType}}</td>
-<td>{{description}}</td>
-<td>[optional] {{#isReadOnly}}[readonly] {{/isReadOnly}}{{#defaultValue}} if omitted the server will use the default value of {{{.}}}{{/defaultValue}}</td>
-</tr>
-{{/optionalVars}}
-
-{{#additionalPropertiesType}}
-<tr>
-<td>any string name</td>
-<td>{{additionalPropertiesType}}</td>
-<td>any string name can be used but the value must be the correct type</td>
-<td>[optional]</td>
-</tr>
-{{/additionalPropertiesType}}
-</table>
- -{{/model}}{{/models}} \ No newline at end of file diff --git a/observatory-api/templates/model_templates/classvars.mustache b/observatory-api/templates/model_templates/classvars.mustache deleted file mode 100644 index 4dcf18afc..000000000 --- a/observatory-api/templates/model_templates/classvars.mustache +++ /dev/null @@ -1,138 +0,0 @@ - allowed_values = { -{{#isEnum}} - ('value',): { -{{#isNullable}} - 'None': None, -{{/isNullable}} -{{#allowableValues}} -{{#enumVars}} - '{{name}}': {{{value}}}, -{{/enumVars}} -{{/allowableValues}} - }, -{{/isEnum}} -{{#requiredVars}} -{{#isEnum}} - ('{{name}}',): { -{{#isNullable}} - 'None': None, -{{/isNullable}} -{{#allowableValues}} -{{#enumVars}} - '{{name}}': {{{value}}}, -{{/enumVars}} -{{/allowableValues}} - }, -{{/isEnum}} -{{/requiredVars}} -{{#optionalVars}} -{{#isEnum}} - ('{{name}}',): { -{{#isNullable}} - 'None': None, -{{/isNullable}} -{{#allowableValues}} -{{#enumVars}} - '{{name}}': {{{value}}}, -{{/enumVars}} -{{/allowableValues}} - }, -{{/isEnum}} -{{/optionalVars}} - } - - validations = { -{{#hasValidation}} - ('value',): { -{{> model_templates/validations }} -{{/hasValidation}} -{{#requiredVars}} -{{#hasValidation}} - ('{{name}}',): { -{{> model_templates/validations }} -{{/hasValidation}} -{{/requiredVars}} -{{#optionalVars}} -{{#hasValidation}} - ('{{name}}',): { -{{> model_templates/validations }} -{{/hasValidation}} -{{/optionalVars}} - } - -{{#additionalPropertiesType}} - @cached_property - def additional_properties_type(): - """ - This must be a method because a model may have properties that are - of type self, this must run after the class is loaded - """ -{{#imports}} -{{#-first}} - lazy_import() -{{/-first}} -{{/imports}} - return ({{{additionalPropertiesType}}},) # noqa: E501 -{{/additionalPropertiesType}} -{{^additionalPropertiesType}} - additional_properties_type = None -{{/additionalPropertiesType}} - - _nullable = {{#isNullable}}True{{/isNullable}}{{^isNullable}}False{{/isNullable}} - - @cached_property - def openapi_types(): - """ - This must be a method because a model may have properties that are - of type self, this must run after the class is loaded - - Returns - openapi_types (dict): The key is attribute name - and the value is attribute type. 
- """ -{{#imports}} -{{#-first}} - lazy_import() -{{/-first}} -{{/imports}} - return { -{{#isAlias}} - 'value': ({{{dataType}}},), -{{/isAlias}} -{{#isEnum}} - 'value': ({{{dataType}}},), -{{/isEnum}} -{{#isArray}} - 'value': ({{{dataType}}},), -{{/isArray}} -{{#requiredVars}} - '{{name}}': ({{{dataType}}},), # noqa: E501, F821 -{{/requiredVars}} -{{#optionalVars}} - '{{name}}': ({{{dataType}}},), # noqa: E501, F821 -{{/optionalVars}} - } - - @cached_property - def discriminator(): -{{^discriminator}} - return None -{{/discriminator}} -{{#discriminator}} -{{#mappedModels}} -{{#-first}} -{{#imports}} -{{#-first}} - lazy_import() -{{/-first}} -{{/imports}} -{{/-first}} -{{/mappedModels}} - val = { -{{#mappedModels}} - '{{mappingName}}': {{{modelName}}}, -{{/mappedModels}} - } - if not val: - return None - return {'{{{discriminatorName}}}': val}{{/discriminator}} \ No newline at end of file diff --git a/observatory-platform/README.md b/observatory-platform/README.md deleted file mode 100644 index fcde0065a..000000000 --- a/observatory-platform/README.md +++ /dev/null @@ -1 +0,0 @@ -## Observatory Platform \ No newline at end of file diff --git a/observatory-platform/observatory/platform/api.py b/observatory-platform/observatory/platform/api.py deleted file mode 100644 index 5fcd8853e..000000000 --- a/observatory-platform/observatory/platform/api.py +++ /dev/null @@ -1,150 +0,0 @@ -# Copyright 2020 Curtin University -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# Author: Author: Tuan Chien, James Diprose - -import datetime -import logging -from typing import List, Optional - -import pendulum -from airflow.hooks.base import BaseHook - -from observatory.api.client.model.dataset_release import DatasetRelease -from observatory.platform.config import AirflowConns - - -def make_observatory_api(observatory_api_conn_id: str = AirflowConns.OBSERVATORY_API) -> "ObservatoryApi": # noqa: F821 - """Make the ObservatoryApi object, configuring it with a host and api_key. - - :param observatory_api_conn_id: the Observatory API Airflow Connection ID. - :return: the ObservatoryApi. 
- """ - - try: - from observatory.api.client.api.observatory_api import ObservatoryApi - from observatory.api.client.api_client import ApiClient - from observatory.api.client.configuration import Configuration - except ImportError as e: - logging.error("Please install the observatory-api Python package to use the make_observatory_api function") - raise e - - # Get connection - conn = BaseHook.get_connection(observatory_api_conn_id) - - # Assert connection has required fields - assert ( - conn.conn_type != "" and conn.conn_type is not None - ), f"Airflow Connection {observatory_api_conn_id} conn_type must not be None" - assert ( - conn.host != "" and conn.host is not None - ), f"Airflow Connection {observatory_api_conn_id} host must not be None" - - # Make host - host = f'{str(conn.conn_type).replace("_", "-").lower()}://{conn.host}' - if conn.port: - host += f":{conn.port}" - - # Only api_key when password present in connection - api_key = None - if conn.password != "" and conn.password is not None: - api_key = {"api_key": conn.password} - - # Return ObservatoryApi - config = Configuration(host=host, api_key=api_key) - api_client = ApiClient(config) - return ObservatoryApi(api_client=api_client) - - -def get_dataset_releases(*, dag_id: str, dataset_id: str) -> List[DatasetRelease]: - """Get a list of dataset releases for a given dataset. - - :param dag_id: dag id. - :param dataset_id: Dataset id. - :return: List of dataset releases. - """ - - api = make_observatory_api() - dataset_releases = api.get_dataset_releases(dag_id=dag_id, dataset_id=dataset_id) - return dataset_releases - - -def get_latest_dataset_release(releases: List[DatasetRelease], date_key: str) -> Optional[DatasetRelease]: - """Get the dataset release from the list with the most recent end date. - - :param releases: List of releases. - :param date_key: the key for accessing dates. - :return: Latest release (by end_date) - """ - - if len(releases) == 0: - return None - - latest = releases[0] - for release in releases: - if getattr(release, date_key) > getattr(latest, date_key): - latest = release - - return latest - - -def get_new_release_dates(*, dag_id: str, dataset_id: str, releases: List[datetime.datetime]) -> List[str]: - """Get a list of new release dates, i.e., releases in the user supplied list that are not in the API db. - - :param dag_id: The DAG id to check. - :param dataset_id: The dataset id to check. - :param releases: List of release dates to check. - :return: List of new releases in YYYYMMDD string format. - """ - - api_releases = get_dataset_releases(dag_id=dag_id, dataset_id=dataset_id) - api_releases_set = set([release.end_date.strftime("%Y%m%d") for release in api_releases]) - releases_set = set([release.strftime("%Y%m%d") for release in releases]) - new_releases = list(releases_set.difference(api_releases_set)) - return new_releases - - -def is_first_release(dag_id: str, dataset_id: str) -> bool: - """Use the API to check whether this is the first release of a dataset, i.e., are there no dataset release records. - - :param dag_id: DAG ID. - :param dataset_id: dataset id. - :return: Whether this is the first release. - """ - - releases = get_dataset_releases(dag_id=dag_id, dataset_id=dataset_id) - return len(releases) == 0 - - -def build_schedule(sched_start_date: pendulum.DateTime, sched_end_date: pendulum.DateTime): - """Useful for API based data sources. - - Create a fetch schedule to specify what date ranges to use for each API call. 
Will default to once a month - for now, but in the future if we are minimising API calls, this can be a more complicated scheme. - - :param sched_start_date: the schedule start date. - :param sched_end_date: the end date of the schedule. - :return: list of (section_start_date, section_end_date) pairs from start_date to current Airflow DAG start date. - """ - - schedule = [] - - for start_date in pendulum.Period(start=sched_start_date, end=sched_end_date).range("months"): - if start_date >= sched_end_date: - break - end_date = start_date.add(months=1).subtract(days=1).end_of("day") - end_date = min(sched_end_date, end_date) - schedule.append(pendulum.Period(start_date.date(), end_date.date())) - - return schedule diff --git a/observatory-platform/observatory/platform/cli/cli.py b/observatory-platform/observatory/platform/cli/cli.py deleted file mode 100644 index 5f69e2290..000000000 --- a/observatory-platform/observatory/platform/cli/cli.py +++ /dev/null @@ -1,448 +0,0 @@ -# Copyright 2020 Curtin University -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# Author: James Diprose, Aniek Roelofs, Tuan Chien - -import json -import os -from typing import ClassVar - -import click - -from observatory.platform.cli.cli_utils import INDENT1, INDENT2, INDENT3, indent -from observatory.platform.cli.generate_command import GenerateCommand -from observatory.platform.cli.platform_command import PlatformCommand -from observatory.platform.cli.terraform_command import TerraformCommand -from observatory.platform.config import observatory_home -from observatory.platform.config import ( - terraform_credentials_path as default_terraform_credentials_path, -) -from observatory.platform.docker.platform_runner import DEBUG, HOST_UID -from observatory.platform.observatory_config import Config, TerraformConfig, ObservatoryConfig -from observatory.platform.observatory_config import ( - generate_fernet_key, - generate_secret_key, -) - -PLATFORM_NAME = "Observatory Platform" -TERRAFORM_NAME = "Observatory Terraform" - -LOCAL_CONFIG_PATH = os.path.join(observatory_home(), "config.yaml") -TERRAFORM_CONFIG_PATH = os.path.join(observatory_home(), "config-terraform.yaml") - - -def load_config(cls: ClassVar, config_path: str) -> Config: - """Load a config file. - :param cls: the config file class. - :param config_path: the path to the config file. - :return: the config file or exit with an OS.EX_CONFIG error or FileExistsError. 
- """ - - print(indent("config.yaml:", INDENT1)) - config_exists = os.path.exists(config_path) - - if not config_exists: - msg = indent(f"- file not found, generating a default file on path: {config_path}", INDENT2) - generate_cmd = GenerateCommand() - if cls == ObservatoryConfig: - print(msg) - generate_cmd.generate_local_config(config_path, editable=False, workflows=[], oapi=False) - elif cls == TerraformConfig: - print(msg) - generate_cmd.generate_terraform_config(config_path, editable=False, workflows=[], oapi=False) - else: - print(indent(f"- file not found, exiting: {config_path}", INDENT2)) - exit(os.EX_CONFIG) - else: - cfg = cls.load(config_path) - if cfg.is_valid: - print(indent("- file valid", INDENT2)) - else: - print(indent("- file invalid", INDENT2)) - for error in cfg.errors: - print(indent(f"- {error.key}: {error.value}", INDENT3)) - exit(os.EX_CONFIG) - - return cfg - - -@click.group() -def cli(): - """The Observatory Platform command line tool. - - COMMAND: the commands to run include:\n - - platform: start and stop the local Observatory Platform platform.\n - - generate: generate a variety of outputs.\n - - terraform: manage Terraform Cloud workspaces.\n - """ - - pass - - -@cli.command() -@click.argument("command", type=click.Choice(["start", "stop"])) -@click.option( - "--config-path", - type=click.Path(exists=False, file_okay=True, dir_okay=False), - default=LOCAL_CONFIG_PATH, - help="The path to the config.yaml configuration file.", - show_default=True, -) -@click.option( - "--host-uid", - type=click.INT, - default=HOST_UID, - help="The user id of the host system. Used to set the user id in the Docker containers.", - show_default=True, -) -@click.option("--debug", is_flag=True, default=DEBUG, help="Print debugging information.") -def platform(command: str, config_path: str, host_uid: int, debug): - """Run the local Observatory Platform platform.\n - - COMMAND: the command to give the platform:\n - - start: start the platform.\n - - stop: stop the platform.\n - """ - - min_line_chars = 80 - print(f"{PLATFORM_NAME}: checking dependencies...".ljust(min_line_chars), end="\r") - cfg = load_config(ObservatoryConfig, config_path) - - # Make the platform command, which encapsulates functionality for running the observatory - platform_cmd = PlatformCommand(cfg, host_uid=host_uid, debug=debug) - - # Check dependencies - platform_check_dependencies(platform_cmd, min_line_chars=min_line_chars) - - # Start the appropriate process - if command == "start": - platform_start(platform_cmd) - elif command == "stop": - platform_stop(platform_cmd) - - exit(os.EX_OK) - - -def platform_check_dependencies(platform_cmd: PlatformCommand, min_line_chars: int = 80): - """Check Platform dependencies. - - :param platform_cmd: the platform command instance. - :param min_line_chars: the minimum number of lines when printing to the command line interface. - :return: None. 
- """ - - if not platform_cmd.is_environment_valid: - print(f"{PLATFORM_NAME}: dependencies missing".ljust(min_line_chars)) - else: - print(f"{PLATFORM_NAME}: all dependencies found".ljust(min_line_chars)) - - print(indent("Docker:", INDENT1)) - if platform_cmd.docker_exe_path is not None: - print(indent(f"- path: {platform_cmd.docker_exe_path}", INDENT2)) - - if platform_cmd.is_docker_running: - print(indent(f"- running", INDENT2)) - else: - print(indent("- not running, please start", INDENT2)) - else: - print(indent("- not installed, please install https://docs.docker.com/get-docker/", INDENT2)) - - print(indent("Docker Compose V2:", INDENT1)) - if platform_cmd.docker_compose: - print(indent(f"- installed", INDENT2)) - else: - print(indent("- not installed, please install https://docs.docker.com/compose/install/", INDENT2)) - - if not platform_cmd.is_environment_valid: - exit(os.EX_CONFIG) - - print(indent("Host machine settings:", INDENT1)) - print(indent(f"- observatory home: {platform_cmd.config.observatory.observatory_home}", INDENT2)) - print(indent(f"- host-uid: {platform_cmd.host_uid}", INDENT2)) - - -def platform_start(platform_cmd: PlatformCommand, min_line_chars: int = 80): - """Start the Observatory Platform. - - :param platform_cmd: the platform command instance. - :param min_line_chars: the minimum number of lines when printing to the command line interface. - :return: None. - """ - - print(f"{PLATFORM_NAME}: building...".ljust(min_line_chars), end="\r") - response = platform_cmd.build() - - if response.return_code == 0: - print(f"{PLATFORM_NAME}: built".ljust(min_line_chars)) - else: - print(f"{PLATFORM_NAME}: build error".ljust(min_line_chars)) - print(response.output) - exit(os.EX_CONFIG) - - # Start the built containers - print(f"{PLATFORM_NAME}: starting...".ljust(min_line_chars), end="\r") - response = platform_cmd.start() - - if response.return_code == 0: - ui_started = platform_cmd.wait_for_airflow_ui(timeout=120) - - if ui_started: - print(f"{PLATFORM_NAME}: started".ljust(min_line_chars)) - print(f"View the Apache Airflow UI at {platform_cmd.ui_url}") - else: - print(f"{PLATFORM_NAME}: error starting".ljust(min_line_chars)) - print(f"Could not find the Airflow UI at {platform_cmd.ui_url}") - else: - print("Error starting the Observatory Platform") - print(response.output) - exit(os.EX_CONFIG) - - -def platform_stop(platform_cmd: PlatformCommand, min_line_chars: int = 80): - """Stop the Observatory Platform. - - :param platform_cmd: the platform command instance. - :param min_line_chars: the minimum number of lines when printing to the command line interface. - :return: None. - """ - - print(f"{PLATFORM_NAME}: stopping...".ljust(min_line_chars), end="\r") - response = platform_cmd.stop() - - if platform_cmd.debug: - print(response.output) - - if response.return_code == 0: - print(f"{PLATFORM_NAME}: stopped".ljust(min_line_chars)) - else: - print(f"{PLATFORM_NAME}: error stopping".ljust(min_line_chars)) - print(response.error) - exit(os.EX_CONFIG) - - -@cli.group() -def generate(): - """The Observatory Platform generate command. - - COMMAND: the commands to run include:\n - - secrets: generate secrets.\n - - config: generate configuration files for the Observatory Platform.\n - - project: generate a new project directory and required files.\n - - workflow: generate all files for a new workflow.
- """ - - pass - - -@generate.command() -@click.argument("command", type=click.Choice(["fernet-key", "secret-key"])) -def secrets(command: str): - """Generate secrets for the Observatory Platform.\n - - COMMAND: the type of secret to generate:\n - - fernet-key: generate a random Fernet Key.\n - - secret-key: generate a random secret key.\n - """ - - if command == "fernet-key": - print(generate_fernet_key()) - else: - print(generate_secret_key()) - - -@generate.command() -@click.argument("command", type=click.Choice(["local", "terraform"])) -@click.option( - "--config-path", - type=click.Path(exists=False, file_okay=True, dir_okay=False), - default=None, - help="The path to the config file to generate.", - show_default=True, -) -@click.option("--interactive", flag_value=True, help="Configuration through an interactive Q&A mode") -@click.option( - "--ao-wf", - flag_value=True, - help="Indicates that the academic-observatory-workflows was installed through the cli installer script", -) -@click.option( - "--oaebu-wf", - flag_value=True, - help="Indicates that the oaebu-workflows was installed through the cli installer script", -) -@click.option( - "--oapi", flag_value=True, help="Indicates that the observatory api was installed through the cli installer script" -) -@click.option("--editable", flag_value=True, help="Indicates the observatory platform is editable") -def config(command: str, config_path: str, interactive: bool, ao_wf: bool, oaebu_wf: bool, oapi: bool, editable: bool): - """Generate config files for the Observatory Platform.\n - - COMMAND: the type of config file to generate:\n - - local: generate a config file for running the Observatory Platform locally.\n - - terraform: generate a config file for running the Observatory Platform with Terraform.\n - - :param interactive: whether to interactively ask for configuration options. - :param ao_wf: Whether academic_observatory_workflows was installed using the installer script. - :param oaebu_wf: Whether oaebu_workflows was installed using the installer script. - :param oapi: Whether the Observatory API was installed using the installer script. - :param editable: Whether the observatory platform is editable. - """ - - # Make the generate command, which encapsulates functionality for generating data - cmd = GenerateCommand() - - if config_path is None: - config_path = LOCAL_CONFIG_PATH if command == "local" else TERRAFORM_CONFIG_PATH - - config_name = "Observatory config" if command == "local" else "Terraform config" - - workflows = [] - if ao_wf: - workflows.append("academic-observatory-workflows") - if oaebu_wf: - workflows.append("oaebu-workflows") - - if command == "local" and not interactive: - cmd_func = cmd.generate_local_config - elif command == "terraform" and not interactive: - cmd_func = cmd.generate_terraform_config - elif command == "local" and interactive: - cmd_func = cmd.generate_local_config_interactive - else: - cmd_func = cmd.generate_terraform_config_interactive - - if not os.path.exists(config_path) or click.confirm( - f'The file "{config_path}" exists, do you want to overwrite it?' - ): - cmd_func(config_path, workflows=workflows, oapi=oapi, editable=editable) - else: - click.echo(f"Not generating {config_name}") - - -# increase content width for cleaner help output -@cli.command(context_settings=dict(max_content_width=120)) -@click.argument( - "command", - type=click.Choice(["build-terraform", "build-image", "create-workspace", "update-workspace"]), -) -# The path to the config-terraform.yaml configuration file. 
-@click.argument("config-path", type=click.Path(exists=True, file_okay=True, dir_okay=False)) -@click.option( - "--terraform-credentials-path", - type=click.Path(exists=False, file_okay=True, dir_okay=False), - default=default_terraform_credentials_path(), - help="The path to the Terraform credentials file.", - show_default=True, -) -@click.option("--debug", is_flag=True, default=DEBUG, help="Print debugging information.") -def terraform(command, config_path, terraform_credentials_path, debug): - """Commands to manage the deployment of the Observatory Platform with Terraform Cloud.\n - - COMMAND: the command to run:\n - - create-workspace: create a Terraform Cloud workspace.\n - - update-workspace: update a Terraform Cloud workspace.\n - - build-image: build a Google Compute image for the Terraform deployment with Packer.\n - - build-terraform: build the Terraform files.\n - """ - - cfg = load_config(TerraformConfig, config_path) - terraform_cmd = TerraformCommand(cfg, terraform_credentials_path, debug=debug) - generate_cmd = GenerateCommand() - - # Check dependencies - terraform_check_dependencies(terraform_cmd, generate_cmd) - - # Run commands - if command == "build-terraform": - # Build the Terraform files - terraform_cmd.build_terraform() - elif command == "build-image": - # Build image with packer - terraform_cmd.build_image() - else: - # Create a new workspace - if command == "create-workspace": - terraform_cmd.create_workspace() - - # Update an existing workspace - elif command == "update-workspace": - terraform_cmd.update_workspace() - - -@cli.command("sort-schema") -@click.argument("input-file", type=click.Path(exists=True, file_okay=True, dir_okay=False)) -def sort_schema_cmd(input_file): - def sort_schema(schema): - sorted_schema = sorted(schema, key=lambda x: x["name"]) - - for field in sorted_schema: - if field["type"] == "RECORD" and "fields" in field: - field["fields"] = sort_schema(field["fields"]) - - return sorted_schema - - # Load the JSON schema from the input file - with open(input_file, mode="r") as f: - data = json.load(f) - - # Sort the schema - sorted_json_schema = sort_schema(data) - - # Save the schema - with open(input_file, mode="w") as f: - json.dump(sorted_json_schema, f, indent=2) - - -def terraform_check_dependencies( - terraform_cmd: TerraformCommand, generate_cmd: GenerateCommand, min_line_chars: int = 80 -): - """Check Terraform dependencies. - - :param terraform_cmd: the Terraform command instance. - :param generate_cmd: the generate command instance. - :param min_line_chars: the minimum number of lines when printing to the command line interface. - :return: None.
- """ - - print(f"{TERRAFORM_NAME}: checking dependencies...".ljust(min_line_chars), end="\r") - - if not terraform_cmd.is_environment_valid: - print(f"{TERRAFORM_NAME}: dependencies missing".ljust(min_line_chars)) - else: - print(f"{TERRAFORM_NAME}: all dependencies found".ljust(min_line_chars)) - - print(indent("Terraform credentials file:", INDENT1)) - if terraform_cmd.terraform_credentials_exists: - print(indent(f"- path: {terraform_cmd.terraform_credentials_path}", INDENT2)) - else: - print(indent("- file not found, create one by running 'terraform login'", INDENT2)) - - print(indent("Packer", INDENT1)) - if terraform_cmd.terraform_builder.packer_exe_path is not None: - print(indent(f"- path: {terraform_cmd.terraform_builder.packer_exe_path}", INDENT2)) - else: - print(indent("- not installed, please install https://www.packer.io/docs/install", INDENT2)) - - print(indent("Google Cloud SDK", INDENT1)) - if terraform_cmd.terraform_builder.gcloud_exe_path is not None: - print(indent(f"- path: {terraform_cmd.terraform_builder.gcloud_exe_path}", INDENT2)) - else: - print(indent("- not installed, please install https://cloud.google.com/sdk/docs/install", INDENT2)) - - if not terraform_cmd.is_environment_valid: - exit(os.EX_CONFIG) - - -if __name__ == "__main__": - cli() diff --git a/observatory-platform/observatory/platform/cli/cli_utils.py b/observatory-platform/observatory/platform/cli/cli_utils.py deleted file mode 100644 index c8de64638..000000000 --- a/observatory-platform/observatory/platform/cli/cli_utils.py +++ /dev/null @@ -1,41 +0,0 @@ -# Copyright 2020 Curtin University -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# Author: James Diprose, Aniek Roelofs - -INDENT1 = 2 -INDENT2 = 3 -INDENT3 = 4 -INDENT4 = 5 - - -def indent(string: str, num_spaces: int) -> str: - """Left indent a string. - - :param string: the string to indent. - :param num_spaces: the number of spaces to indent the string with. - :return: the indented string. - """ - - assert num_spaces > 0, "indent: num_spaces must be > 0" - return string.rjust(len(string) + num_spaces) - - -def comment(string: str) -> str: - """Add a Python comment character in front of a string. - :param string: String to comment out. - :return: Commented string. - """ - - return f"# {string}" diff --git a/observatory-platform/observatory/platform/cli/generate_command.py b/observatory-platform/observatory/platform/cli/generate_command.py deleted file mode 100644 index c0a839edc..000000000 --- a/observatory-platform/observatory/platform/cli/generate_command.py +++ /dev/null @@ -1,612 +0,0 @@ -# Copyright 2020 Curtin University -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# Author: James Diprose, Aniek Roelofs, Tuan Chien - -from pathlib import Path -from typing import Any, List, Optional - -import click - -from observatory.platform.config import module_file_path -from observatory.platform.observatory_config import ( - BackendType, - CloudSqlDatabase, - Environment, - GoogleCloud, - Observatory, - ObservatoryConfig, - Terraform, - TerraformConfig, - VirtualMachine, - WorkflowsProject, - is_fernet_key, - is_secret_key, - Config, -) - -# Terminal formatting -BOLD = "\033[1m" -END = "\033[0m" - - -class DefaultWorkflowsProject: - """Get Workflows Project configuration for when it was selected via the installer script (editable type).""" - - @staticmethod - def academic_observatory_workflows(): - package_name = "academic-observatory-workflows" - package = module_file_path("academic_observatory_workflows.dags", nav_back_steps=-3) - package_type = "editable" - dags_module = "academic_observatory_workflows.dags" - - return WorkflowsProject( - package_name=package_name, package=package, package_type=package_type, dags_module=dags_module - ) - - @staticmethod - def oaebu_workflows(): - package_name = "oaebu-workflows" - package = module_file_path("oaebu_workflows.dags", nav_back_steps=-3) - package_type = "editable" - dags_module = "oaebu_workflows.dags" - - return WorkflowsProject( - package_name=package_name, package=package, package_type=package_type, dags_module=dags_module - ) - - -class GenerateCommand: - def generate_local_config(self, config_path: str, *, editable: bool, workflows: List[str], oapi: bool): - """Command line user interface for generating an Observatory Config config.yaml. - - :param config_path: the path where the config file should be saved. - :param editable: Whether the observatory platform is editable. - :param workflows: List of installer script installed workflows. - :return: None - """ - - file_type = "Observatory Config" - click.echo(f"Generating {file_type}...") - - workflows = InteractiveConfigBuilder.get_installed_workflows(workflows) - config = ObservatoryConfig(workflows_projects=workflows) - - if editable: - InteractiveConfigBuilder.set_editable_observatory_platform(config.observatory) - - if oapi: - InteractiveConfigBuilder.set_editable_observatory_api(config.observatory) - - config.save(path=config_path) - - click.echo(f'{file_type} saved to: "{config_path}"') - - def generate_terraform_config(self, config_path: str, *, editable: bool, workflows: List[str], oapi: bool): - """Command line user interface for generating a Terraform Config config-terraform.yaml. - - :param config_path: the path where the config file should be saved. - :param editable: Whether the observatory platform is editable. - :param workflows: List of installer script installed workflows. 
- :return: None - """ - - file_type = "Terraform Config" - click.echo(f"Generating {file_type}...") - workflows = InteractiveConfigBuilder.get_installed_workflows(workflows) - config = TerraformConfig(workflows_projects=workflows) - - if editable: - InteractiveConfigBuilder.set_editable_observatory_platform(config.observatory) - - if oapi: - InteractiveConfigBuilder.set_editable_observatory_api(config.observatory) - - config.save(path=config_path) - - click.echo(f'{file_type} saved to: "{config_path}"') - - def generate_local_config_interactive(self, config_path: str, *, workflows: List[str], oapi: bool, editable: bool): - """Construct an Observatory local config file through user assisted configuration. - - :param config_path: Configuration file path. - :param workflows: List of installer script installed workflows projects. - :param oapi: Whether installer script installed the Observatory API. - :param editable: Whether the observatory platform is editable. - """ - - file_type = "Observatory Config" - click.echo(f"Generating {file_type}...") - - config = InteractiveConfigBuilder.build( - backend_type=BackendType.local, workflows=workflows, oapi=oapi, editable=editable - ) - - if editable: - InteractiveConfigBuilder.set_editable_observatory_platform(config.observatory) - - config.save(config_path) - click.echo(f'{file_type} saved to: "{config_path}"') - - def generate_terraform_config_interactive( - self, config_path: str, *, workflows: List[str], oapi: bool, editable: bool - ): - """Construct an Observatory Terraform config file through user assisted configuration. - - :param config_path: Configuration file path. - :param workflows: List of workflows projects installed by installer script. - :param oapi: Whether installer script installed the Observatory API. - :param editable: Whether the observatory platform is editable. - """ - - file_type = "Terraform Config" - click.echo(f"Generating {file_type}...") - - config = InteractiveConfigBuilder.build( - backend_type=BackendType.terraform, workflows=workflows, oapi=oapi, editable=editable - ) - - if editable: - InteractiveConfigBuilder.set_editable_observatory_platform(config.observatory) - - config.save(config_path) - click.echo(f'{file_type} saved to: "{config_path}"') - - -class FernetKeyType(click.ParamType): - """Fernet key type for click prompt. Will validate the input against the is_fernet_key method.""" - - name = "FernetKeyType" - - def convert( - self, value: Any, param: Optional[click.core.Parameter] = None, ctx: Optional[click.core.Context] = None - ) -> Any: - valid, msg = is_fernet_key(value) - if not valid: - self.fail(f"Input is not a valid Fernet key. Reason: {msg}", param=param, ctx=ctx) - - return value - - -class FlaskSecretKeyType(click.ParamType): - """Secret key type for click prompt. Will validate the input against the is_secret_key method.""" - - name = "SecretKeyType" - - def convert( - self, value: Any, param: Optional[click.core.Parameter] = None, ctx: Optional[click.core.Context] = None - ) -> Any: - valid, msg = is_secret_key(value) - if not valid: - self.fail(f"Input is not a valid secret key. Reason: {msg}", param=param, ctx=ctx) - - return value - - -class InteractiveConfigBuilder: - """Helper class for configuring the ObservatoryConfig class parameters through interactive user input.""" - - @staticmethod - def set_editable_observatory_platform(observatory: Observatory): - """Set observatory package settings to editable. - - :param observatory: Observatory object to change. 
- """ - - observatory.package = module_file_path("observatory.platform", nav_back_steps=-3) - observatory.package_type = "editable" - - @staticmethod - def set_editable_observatory_api(observatory: Observatory): - """Set observatory api package settings to editable. - - :param observatory: Observatory object to change. - """ - - observatory.api_package = module_file_path("observatory.api", nav_back_steps=-3) - observatory.api_package_type = "editable" - - @staticmethod - def get_installed_workflows(workflows: List[str]) -> List[WorkflowsProject]: - """Add the workflows projects installed by the installer script. - - :param workflows: List of installed workflows (via installer script). - :return: List of WorkflowsProjects installed by the installer. - """ - - workflows_projects = [] - if "academic-observatory-workflows" in workflows: - workflows_projects.append(DefaultWorkflowsProject.academic_observatory_workflows()) - - if "oaebu-workflows" in workflows: - workflows_projects.append(DefaultWorkflowsProject.oaebu_workflows()) - - return workflows_projects - - @staticmethod - def build(*, backend_type: BackendType, workflows: List[str], oapi: bool, editable: bool) -> Config: - """Build the correct observatory configuration object through user assisted parameters. - - :param backend_type: The type of Observatory backend being configured. - :param workflows: List of workflows installed by installer script. - :param oapi: Whether installer script installed the Observatory API. - :param editable: Whether the observatory platform is editable. - :return: An observatory configuration object. - """ - - workflows_projects = InteractiveConfigBuilder.get_installed_workflows(workflows) - - if backend_type == BackendType.local: - config = ObservatoryConfig(workflows_projects=workflows_projects) - else: - config = TerraformConfig(workflows_projects=workflows_projects) - - # Common sections for all backends - InteractiveConfigBuilder.config_backend(config=config, backend_type=backend_type) - InteractiveConfigBuilder.config_observatory(config=config, oapi=oapi, editable=editable) - InteractiveConfigBuilder.config_terraform(config) - InteractiveConfigBuilder.config_google_cloud(config) - InteractiveConfigBuilder.config_workflows_projects(config) - - # Extra sections for Terraform - if backend_type == BackendType.terraform: - InteractiveConfigBuilder.config_cloud_sql_database(config) - InteractiveConfigBuilder.config_airflow_main_vm(config) - InteractiveConfigBuilder.config_airflow_worker_vm(config) - - return config - - @staticmethod - def config_backend(*, config: Config, backend_type: BackendType): - """Configure the backend section. - - :param config: Configuration object to edit. - :param backend_type: The backend type being used. - """ - - click.echo("Configuring backend settings") - config.backend.type = backend_type - - text = "What kind of environment is this?" - default = Environment.develop.name - choices = click.Choice( - choices=[Environment.develop.name, Environment.staging.name, Environment.production.name], - case_sensitive=False, - ) - - config.backend.environment = Environment[ - click.prompt(text=text, type=choices, default=default, show_default=True, show_choices=True) - ] - - @staticmethod - def config_observatory(*, config: Config, oapi: bool, editable: bool): - """Configure the observatory section. - - :param config: Configuration object to edit. 
- """ - - click.echo("Configuring Observatory settings") - - if editable: - InteractiveConfigBuilder.set_editable_observatory_platform(config.observatory) - # else: - # # Fill in if used installer script - # text = "What type of observatory platform installation did you perform? A git clone is an editable type, and a pip install is a pypi type." - # choices = click.Choice(choices=["editable", "sdist", "pypi"], case_sensitive=False) - # default = "pypi" - # package_type = click.prompt(text=text, type=choices, default=default, show_default=True, show_choices=True) - # config.observatory.package_type = package_type - - text = "Enter an Airflow Fernet key (leave blank to autogenerate)" - default = "" - fernet_key = click.prompt(text=text, type=FernetKeyType(), default=default) - - if fernet_key != "": - config.observatory.airflow_fernet_key = fernet_key - - text = "Enter an Airflow secret key (leave blank to autogenerate)" - default = "" - secret_key = click.prompt(text=text, type=FlaskSecretKeyType(), default=default) - - if secret_key != "": - config.observatory.airflow_secret_key = secret_key - - text = "Enter an email address to use for logging into the Airflow web interface" - default = config.observatory.airflow_ui_user_email - user_email = click.prompt(text=text, type=str, default=default, show_default=True) - config.observatory.airflow_ui_user_email = user_email - - text = f"Password for logging in with {user_email}" - default = config.observatory.airflow_ui_user_password - user_pass = click.prompt(text=text, type=str, default=default, show_default=True) - config.observatory.airflow_ui_user_password = user_pass - - text = "Enter observatory config directory. If it does not exist, it will be created." - default = config.observatory.observatory_home - observatory_home = click.prompt( - text=text, type=click.Path(exists=False, readable=True), default=default, show_default=True - ) - config.observatory.observatory_home = observatory_home - Path(observatory_home).mkdir(exist_ok=True, parents=True) - - text = "Enter postgres password" - default = config.observatory.postgres_password - postgres_password = click.prompt(text=text, type=str, default=default, show_default=True) - config.observatory.postgres_password = postgres_password - - text = "Redis port" - default = config.observatory.redis_port - redis_port = click.prompt(text=text, type=int, default=default, show_default=True) - config.observatory.redis_port = redis_port - - text = "Flower UI port" - default = config.observatory.flower_ui_port - flower_ui_port = click.prompt(text=text, type=int, default=default, show_default=True) - config.observatory.flower_ui_port = flower_ui_port - - text = "Airflow UI port" - default = config.observatory.airflow_ui_port - airflow_ui_port = click.prompt(text=text, type=int, default=default, show_default=True) - config.observatory.airflow_ui_port = airflow_ui_port - - text = "API port" - default = config.observatory.api_port - api_port = click.prompt(text=text, type=int, default=default, show_default=True) - config.observatory.api_port = api_port - - text = "Docker network name" - default = config.observatory.docker_network_name - docker_network_name = click.prompt(text=text, type=str, default=default, show_default=True) - config.observatory.docker_network_name = docker_network_name - - text = "Is the docker network external?" 
- default = config.observatory.docker_network_is_external - docker_network_is_external = click.prompt(text=text, type=bool, default=default, show_default=True) - config.observatory.docker_network_is_external = docker_network_is_external - - text = "Docker compose project name" - default = config.observatory.docker_compose_project_name - docker_compose_project_name = click.prompt(text=text, type=str, default=default, show_default=True) - config.observatory.docker_compose_project_name = docker_compose_project_name - - text = "Do you wish to enable ElasticSearch and Kibana?" - choices = click.Choice(choices=["y", "n"], case_sensitive=False) - default = "y" - enable_elk = click.prompt(text=text, default=default, type=choices, show_default=True, show_choices=True) - - config.observatory.enable_elk = True if enable_elk == "y" else False - - # If installed by installer script, we can fill in details - if oapi and editable: - InteractiveConfigBuilder.set_editable_observatory_api(config.observatory) - elif not oapi: - text = "Observatory API package name" - default = config.observatory.api_package - api_package = click.prompt(text=text, type=str, default=default, show_default=True) - config.observatory.api_package = api_package - - text = "Observatory API package type" - default = config.observatory.api_package_type - api_package_type = click.prompt(text=text, type=str, default=default, show_default=True) - config.observatory.api_package_type = api_package_type - - @staticmethod - def config_google_cloud(config: Config): - """Configure the Google Cloud section. - - :param config: Configuration object to edit. - """ - - zone = None - region = None - buckets = list() - - if not config.schema["google_cloud"]["required"]: - text = "Do you want to configure Google Cloud settings?" - proceed = click.confirm(text=text, default=False, abort=False, show_default=True) - if not proceed: - return - - text = "Google Cloud Project ID" - project_id = click.prompt(text=text, type=str) - - text = "Path to Google Service Account key file (json)" - credentials = click.prompt(text=text, type=click.Path(exists=True, readable=True)) - - text = "Data location" - default = "us" - data_location = click.prompt(text=text, type=str, default=default, show_default=True) - - if config.backend.type == BackendType.terraform: - text = "Region" - default = "us-west1" - region = click.prompt(text=text, type=str, default=default, show_default=True) - - text = "Zone" - default = "us-west1-a" - zone = click.prompt(text=text, type=str, default=default, show_default=True) - - config.google_cloud = GoogleCloud( - project_id=project_id, - credentials=credentials, - region=region, - zone=zone, - data_location=data_location, - ) - - @staticmethod - def config_terraform(config: Config): - """Configure the Terraform section. - - :param config: Configuration object to edit. - """ - - if not config.schema["terraform"]["required"]: - text = "Do you want to configure Terraform settings?" 
- proceed = click.confirm(text=text, default=False, abort=False, show_default=True) - if not proceed: - return - - if config.backend.type == BackendType.local: - suffix = " (leave blank to disable)" - default = "" - else: - suffix = "" - default = None - - text = f"Terraform organization name{suffix}" - organization = click.prompt(text=text, type=str, default=default) - - if organization == "": - return - - config.terraform = Terraform(organization=organization) - - @staticmethod - def config_workflows_projects(config: Config): - """Configure the DAGs projects section. - - :param config: Configuration object to edit. - """ - - click.echo( - "Configuring workflows projects. If you opted to install some workflows projects through the installer script then they will be automatically added to the config file for you. If not, e.g., if you installed via pip, you will need to add those projects manually now (or later)." - ) - - text = "Do you want to add workflows projects?" - add_workflows_projects = click.confirm(text=text, default=False, abort=False, show_default=True) - - if not add_workflows_projects: - return - - projects = list() - while True: - text = "Workflows package name" - package_name = click.prompt(text=text, type=str) - - text = "Workflows package, either a local path to a Python source (editable), sdist, or PyPI package name and version" - package = click.prompt(text=text, type=click.Path(exists=True, readable=True)) - - text = "Package type" - choices = click.Choice(choices=["editable", "sdist", "pypi"], case_sensitive=False) - default = "editable" - package_type = click.prompt(text=text, default=default, type=choices, show_default=True, show_choices=True) - - text = "Python import path to the module that contains the Apache Airflow DAGs to load" - dags_module = click.prompt(text=text, type=str) - - projects.append( - WorkflowsProject( - package_name=package_name, - package=package, - package_type=package_type, - dags_module=dags_module, - ) - ) - - text = "Do you wish to add another DAGs project?" - add_another = click.confirm(text=text, default=False, abort=False, show_default=True) - - if not add_another: - break - - config.workflows_projects.extend(projects) - - @staticmethod - def config_cloud_sql_database(config: TerraformConfig): - """Configure the cloud SQL database section. - - :param config: Configuration object to edit. - """ - - click.echo("Configuring the Google Cloud SQL Database") - - text = "Google CloudSQL db tier" - default = "db-custom-2-7680" - tier = click.prompt(text=text, type=str, default=default, show_default=True) - - text = "Google CloudSQL backup start time" - default = "23:00" - backup_start_time = click.prompt(text=text, type=str, default=default, show_default=True) - - config.cloud_sql_database = CloudSqlDatabase( - tier=tier, - backup_start_time=backup_start_time, - ) - - @staticmethod - def config_airflow_main_vm(config: TerraformConfig): - """Configure the Airflow main virtual machine section. - - :param config: Configuration object to edit. 
- """ - - click.echo(BOLD + "Configuring settings for the main VM that runs the Airflow scheduler and webserver" + END) - - text = "Machine type" - default = "n2-standard-2" - machine_type = click.prompt(text=text, type=str, default=default, show_default=True) - - text = "Disk size (GB)" - default = 50 - disk_size = click.prompt(text=text, type=int, default=default, show_default=True) - - text = "Disk type" - schema = config.schema["airflow_main_vm"]["schema"] - default = "pd-ssd" - choices = click.Choice(choices=schema["disk_type"]["allowed"], case_sensitive=False) - disk_type = click.prompt(text=text, type=choices, show_choices=True, default=default, show_default=True) - - text = "Create VM? If yes, and you run Terraform apply, the vm will be created. Otherwise if false, and you run Terraform apply, the vm will be destroyed." - create = click.confirm(text=text, default=True, abort=False, show_default=True) - - config.airflow_main_vm = VirtualMachine( - machine_type=machine_type, - disk_size=disk_size, - disk_type=disk_type, - create=create, - ) - - @staticmethod - def config_airflow_worker_vm(config: TerraformConfig): - """Configure the Airflow worker virtual machine section. - - :param config: Configuration object to edit. - """ - - click.echo(BOLD + "Configuring settings for the worker VM" + END) - - text = "Machine type" - default = "n1-standard-8" - machine_type = click.prompt(text=text, type=str, default=default, show_default=True) - - text = "Disk size (GB)" - default = 3000 - disk_size = click.prompt(text=text, type=int, default=default, show_default=True) - - text = "Disk type" - schema = config.schema["airflow_worker_vm"]["schema"] - default = "pd-standard" - choices = click.Choice(choices=schema["disk_type"]["allowed"], case_sensitive=False) - disk_type = click.prompt(text=text, type=choices, show_choices=True, default=default, show_default=True) - - text = "Create VM? If yes, and you run Terraform apply, the vm will be created. Otherwise if false, and you run Terraform apply, the vm will be destroyed." - create = click.confirm(text=text, default=False, abort=False, show_default=True) - - config.airflow_worker_vm = VirtualMachine( - machine_type=machine_type, - disk_size=disk_size, - disk_type=disk_type, - create=create, - ) diff --git a/observatory-platform/observatory/platform/cli/platform_command.py b/observatory-platform/observatory/platform/cli/platform_command.py deleted file mode 100644 index 8aee76b53..000000000 --- a/observatory-platform/observatory/platform/cli/platform_command.py +++ /dev/null @@ -1,46 +0,0 @@ -# Copyright 2020 Curtin University -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
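For context, a minimal usage sketch of the interactive config builder above, assuming the deleted observatory.platform modules are still importable; the workflow names and save path are placeholder values, not values taken from this patch:

from observatory.platform.cli.generate_command import InteractiveConfigBuilder
from observatory.platform.observatory_config import BackendType

# Build a TerraformConfig by answering the prompts defined in the methods above
config = InteractiveConfigBuilder.build(
    backend_type=BackendType.terraform,  # could also be BackendType.local
    workflows=["academic-observatory-workflows"],  # placeholder installer-script workflow names
    oapi=True,  # Observatory API installed via the installer script
    editable=False,  # observatory-platform installed from PyPI rather than a git clone
)
config.save(path="/path/to/config-terraform.yaml")  # placeholder path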
- -# Author: James Diprose - -from observatory.platform.docker.platform_runner import PlatformRunner, HOST_UID, DEBUG -from observatory.platform.observatory_config import ObservatoryConfig -from observatory.platform.utils.url_utils import wait_for_url - - -class PlatformCommand(PlatformRunner): - def __init__(self, config: ObservatoryConfig, host_uid: int = HOST_UID, debug: bool = DEBUG): - """Create a PlatformCommand, which can be used to start and stop Observatory Platform instances. - :param config: The Observatory config object. - :param host_uid: The user id of the host system. Used to set the user id in the Docker containers. - :param debug: Print debugging information. - """ - - super().__init__(config=config, host_uid=host_uid, debug=debug) - - @property - def ui_url(self) -> str: - """Return the URL to Apache Airflow UI. - :return: Apache Airflow UI URL. - """ - - return f"http://localhost:{self.config.observatory.airflow_ui_port}" - - def wait_for_airflow_ui(self, timeout: int = 60) -> bool: - """Wait for the Apache Airflow UI to start. - :param timeout: the number of seconds to wait before timing out. - :return: whether connecting to the Apache Airflow UI was successful or not. - """ - - return wait_for_url(self.ui_url, timeout=timeout) diff --git a/observatory-platform/observatory/platform/cli/terraform_command.py b/observatory-platform/observatory/platform/cli/terraform_command.py deleted file mode 100644 index 2b7d2b35d..000000000 --- a/observatory-platform/observatory/platform/cli/terraform_command.py +++ /dev/null @@ -1,211 +0,0 @@ -# Copyright 2020 Curtin University -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# Author: James Diprose, Aniek Roelofs - -import os - -import click - -from observatory.platform.cli.cli_utils import indent, INDENT1, INDENT2 -from observatory.platform.observatory_config import TerraformConfig -from observatory.platform.terraform.terraform_api import TerraformApi -from observatory.platform.terraform.terraform_api import TerraformVariable -from observatory.platform.terraform.terraform_builder import TerraformBuilder - - -class TerraformCommand: - def __init__(self, config: TerraformConfig, terraform_credentials_path: str, debug: bool = False): - """Create a TerraformCommand, which can be used to create and update terraform workspaces. - - :param config: the Terraform Config file. - :param terraform_credentials_path: the path to the Terraform credentials file. - :param debug: whether to print debugging information. - """ - - self.config = config - self.terraform_builder = TerraformBuilder(config, debug=debug) - self.terraform_credentials_path = terraform_credentials_path - self.debug = debug - self.terraform_credentials_exists = os.path.exists(terraform_credentials_path) - - @property - def is_environment_valid(self): - """Whether the parameters passed to the TerraformCommand are valid. - - :return: whether the parameters passed to the TerraformCommand are valid.
- """ - - return all( - [ - self.terraform_credentials_exists, - self.terraform_builder.is_environment_valid, - ] - ) - - def print_variable(self, var: TerraformVariable): - """Print the output for the CLI for a single TerraformVariable instance. - - :param var: the TerraformVariable instance. - :return: None. - """ - - if var.sensitive: - print(indent(f"* \x1B[3m{var.key}\x1B[23m: sensitive", INDENT2)) - else: - print(indent(f"* \x1B[3m{var.key}\x1B[23m: {var.value}", INDENT2)) - - def print_variable_update(self, old_var: TerraformVariable, new_var: TerraformVariable): - """Print the output for the CLI for a terraform variable that is being updated. - - :param old_var: the old TerraformVariable instance. - :param new_var: the new TerraformVariable instance. - :return: None. - """ - - if old_var.sensitive: - print(indent(f"* \x1B[3m{old_var.key}\x1B[23m: sensitive -> sensitive", INDENT2)) - else: - print(indent(f"* \x1B[3m{old_var.key}\x1B[23m: {old_var.value} -> {new_var.value}", INDENT2)) - - def build_terraform(self): - """Build the Terraform files for the Observatory Platform. - - :return: None. - """ - - self.terraform_builder.build_terraform() - - def build_image(self): - """Build a Google Compute image for the Terraform deployment with Packer. - - :return: None. - """ - - self.terraform_builder.build_image() - - @property - def verbosity(self): - """Convert debug switch into Terraform API verbosity. - :return: - """ - - if self.debug: - return TerraformApi.VERBOSITY_DEBUG - return TerraformApi.VERBOSITY_WARNING - - def print_summary(self): - # Get organization, environment and prefix - organization = self.config.terraform.organization - environment = self.config.backend.environment.value - workspace = self.config.terraform_workspace_id - - # Display settings for workspace - print("\nTerraform Cloud Workspace: ") - print(indent(f"Organization: {organization}", INDENT1)) - print( - indent( - f"- Name: {workspace} (prefix: '{TerraformConfig.WORKSPACE_PREFIX}' + suffix: '{environment}')", INDENT1 - ) - ) - print(indent(f"- Settings: ", INDENT1)) - print(indent(f"- Auto apply: True", INDENT2)) - print(indent(f"- Terraform Variables:", INDENT1)) - - def create_workspace(self): - """Create a Terraform workspace. - - :return: None. - """ - - self.print_summary() - - # Get terraform token - token = TerraformApi.token_from_file(self.terraform_credentials_path) - terraform_api = TerraformApi(token, self.verbosity) - - # Get variables - terraform_variables = self.config.terraform_variables() - - # Get organization, environment and prefix - organization = self.config.terraform.organization - workspace = self.config.terraform_workspace_id - - for variable in terraform_variables: - self.print_variable(variable) - - # confirm creating workspace - if click.confirm("Would you like to create a new workspace with these settings?"): - print("Creating workspace...") - - # Create new workspace - terraform_api.create_workspace(organization, workspace, auto_apply=True, description="") - - # Get workspace ID - workspace_id = terraform_api.workspace_id(organization, workspace) - - # Add variables to workspace - for var in terraform_variables: - terraform_api.add_workspace_variable(var, workspace_id) - - print("Successfully created workspace") - - def update_workspace(self): - """Update a Terraform workspace. - - :return: None. 
- """ - - self.print_summary() - - # Get terraform token - token = TerraformApi.token_from_file(self.terraform_credentials_path) - terraform_api = TerraformApi(token, self.verbosity) - - # Get variables - terraform_variables = self.config.terraform_variables() - - # Get organization, environment and prefix - organization = self.config.terraform.organization - workspace = self.config.terraform_workspace_id - - # Get workspace ID - workspace_id = terraform_api.workspace_id(organization, workspace) - add, edit, unchanged, delete = terraform_api.plan_variable_changes(terraform_variables, workspace_id) - - if add: - print(indent("NEW", INDENT1)) - for var in add: - self.print_variable(var) - if edit: - print(indent("UPDATE", INDENT1)) - for old_var, new_var in edit: - self.print_variable_update(old_var, new_var) - if delete: - print(indent("DELETE", INDENT1)) - for var in delete: - self.print_variable(var) - if unchanged: - print(indent("UNCHANGED", INDENT1)) - for var in unchanged: - self.print_variable(var) - - # confirm creating workspace - if click.confirm("Would you like to update the workspace with these settings?"): - print("Updating workspace...") - - # Update variables in workspace - terraform_api.update_workspace_variables(add, edit, delete, workspace_id) - - print("Successfully updated workspace") diff --git a/observatory-platform/observatory/platform/config-terraform.yaml.jinja2 b/observatory-platform/observatory/platform/config-terraform.yaml.jinja2 deleted file mode 100644 index ab30c99fb..000000000 --- a/observatory-platform/observatory/platform/config-terraform.yaml.jinja2 +++ /dev/null @@ -1,62 +0,0 @@ -# The backend type: terraform -# The environment type: develop, staging or production -backend: - type: terraform - environment: develop - -# Observatory settings -observatory: - package: observatory-api - package_type: pypi - airflow_fernet_key: {{ airflow_fernet_key }} - airflow_secret_key: {{ airflow_secret_key }} - airflow_ui_user_email: my-email@example.com <-- - airflow_ui_user_password: my-password <-- - postgres_password: my-password <-- - -# Terraform settings -terraform: - organization: my-terraform-org-name <-- - -# Google Cloud settings -google_cloud: - project_id: my-gcp-id <-- - credentials: /path/to/google_application_credentials.json <-- - region: us-west1 <-- - zone: us-west1-a <-- - data_location: us <-- - -# Google Cloud CloudSQL database settings -cloud_sql_database: - tier: db-custom-2-7680 - backup_start_time: '23:00' - - -# Settings for the main VM that runs the Apache Airflow scheduler and webserver -airflow_main_vm: - machine_type: n2-standard-2 - disk_size: 50 - disk_type: pd-ssd - create: true - -# Settings for the weekly on-demand VM that runs large tasks -airflow_worker_vm: - machine_type: n1-standard-8 - disk_size: 3000 - disk_type: pd-standard - create: false - -# User defined Apache Airflow variables: -# airflow_variables: -# my_variable_name: my-variable-value - -# User defined Apache Airflow Connections: -# airflow_connections: -# my_connection: http://my-username:my-password@ - -# User defined Observatory DAGs projects: -# workflows_projects: -# - package_name: academic-observatory-workflows -# package: /path/to/academic-observatory-workflows -# package_type: editable -# dags_module: academic_observatory_workflows.dags diff --git a/observatory-platform/observatory/platform/config.yaml.jinja2 b/observatory-platform/observatory/platform/config.yaml.jinja2 deleted file mode 100644 index ee00961af..000000000 --- 
a/observatory-platform/observatory/platform/config.yaml.jinja2 +++ /dev/null @@ -1,40 +0,0 @@ -# The backend type: local -# The environment type: develop, staging or production -backend: - type: local - environment: develop - -# Observatory settings -observatory: - package: observatory-platform - package_type: pypi - airflow_fernet_key: {{ airflow_fernet_key }} - airflow_secret_key: {{ airflow_secret_key }} - -# Terraform settings: customise to use the vm_create and vm_destroy DAGs: -# terraform: -# organization: my-terraform-org-name - -# Google Cloud settings: customise to use Google Cloud services -# google_cloud: -# project_id: my-gcp-id -# credentials: /path/to/google_application_credentials.json -# data_location: us -# buckets: -# download_bucket: my-download-bucket-name -# transform_bucket: my-transform-bucket-name - -# User defined Apache Airflow variables: -# airflow_variables: -# my_variable_name: my-variable-value - -# User defined Apache Airflow Connections: -# airflow_connections: -# my_connection: http://my-username:my-password@ - -# User defined workflows projects: -# workflows_projects: -# - package_name: academic-observatory-workflows -# package: /path/to/academic-observatory-workflows -# package_type: editable -# dags_module: academic_observatory_workflows.dags diff --git a/observatory-platform/observatory/platform/dags/load_dags_modules.py b/observatory-platform/observatory/platform/dags/load_dags_modules.py deleted file mode 100644 index 34bd2973d..000000000 --- a/observatory-platform/observatory/platform/dags/load_dags_modules.py +++ /dev/null @@ -1,37 +0,0 @@ -# Copyright 2020 Curtin University -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# The keywords airflow and DAG are required to load the DAGs from this file, see bullet 2 in the Apache Airflow FAQ: -# https://airflow.apache.org/docs/stable/faq.html - -# Author: James Diprose - -import logging - -from airflow.models import DagBag - -from observatory.platform.airflow import fetch_dags_modules, fetch_dag_bag -from observatory.platform.config import module_file_path - -# Load DAGs for each DAG path -dags_modules = fetch_dags_modules() -for module_name in dags_modules: - dags_path = module_file_path(module_name) - logging.info(f"{module_name} DAGs path: {dags_path}") - dag_bag: DagBag = fetch_dag_bag(dags_path) - - # Load dags - for dag_id, dag in dag_bag.dags.items(): - logging.info(f"Adding DAG: dag_id={dag_id}, dag={dag}") - globals()[dag_id] = dag diff --git a/observatory-platform/observatory/platform/dags/load_workflows.py b/observatory-platform/observatory/platform/dags/load_workflows.py deleted file mode 100644 index ea0fde168..000000000 --- a/observatory-platform/observatory/platform/dags/load_workflows.py +++ /dev/null @@ -1,34 +0,0 @@ -# Copyright 2023 Curtin University -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# The keywords airflow and DAG are required to load the DAGs from this file, see bullet 2 in the Apache Airflow FAQ: -# https://airflow.apache.org/docs/stable/faq.html - -# Author: James Diprose - -import logging -from typing import List - -from observatory.platform.airflow import fetch_workflows, make_workflow -from observatory.platform.observatory_config import Workflow - -# Load DAGs -workflows: List[Workflow] = fetch_workflows() -for config in workflows: - logging.info(f"Making Workflow: {config.name}, dag_id={config.dag_id}") - workflow = make_workflow(config) - dag = workflow.make_dag() - - logging.info(f"Adding DAG: dag_id={workflow.dag_id}, dag={dag}") - globals()[workflow.dag_id] = dag diff --git a/observatory-platform/observatory/platform/docker/Dockerfile.apiserver.jinja2 b/observatory-platform/observatory/platform/docker/Dockerfile.apiserver.jinja2 deleted file mode 100644 index 43fecd5ec..000000000 --- a/observatory-platform/observatory/platform/docker/Dockerfile.apiserver.jinja2 +++ /dev/null @@ -1,61 +0,0 @@ -# Copyright 2020-2023 Curtin University -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -# Author: Tuan Chien - -FROM python:3.10-bullseye - -ARG HOST_USER_ID -ARG OBSERVATORY_HOME=/opt/observatory -ARG INSTALL_USER=apiserver - -RUN adduser ${INSTALL_USER} -USER root -RUN apt-get update -yqq -RUN apt-get install -y git python3-pip postgresql-client-13 gunicorn procps netcat -RUN usermod -u ${HOST_USER_ID} ${INSTALL_USER} - -USER ${INSTALL_USER} - -# Install dependencies for all projects -{% for package in config.python_packages %} -{% if package.name == 'observatory-api' %} -# Set working directory for {{ package.name }} -ARG WORKING_DIR=/opt/{{ package.name }} -WORKDIR ${WORKING_DIR} - -# Change owner of directory to airflow -USER root -RUN chown -R ${INSTALL_USER} ${WORKING_DIR} -USER ${INSTALL_USER} - -{% with install_deps=true %} -{% include 'Dockerfile.package_install.jinja2' %} -{% endwith %} - -# Set working directory back to airflow home -WORKDIR ${OBSERVATORY_HOME} -{% endif %} -{% endfor %} - -# Copy entry point scripts which install new dependencies at runtime and the Observatory Platform Python package -USER root - -COPY entrypoint-api.sh /entrypoint-api.sh -RUN chmod +x /entrypoint-api.sh - -RUN chown -R ${INSTALL_USER} /opt/observatory - -USER ${INSTALL_USER} -ENTRYPOINT ["/entrypoint-api.sh"] diff --git a/observatory-platform/observatory/platform/docker/Dockerfile.observatory.jinja2 b/observatory-platform/observatory/platform/docker/Dockerfile.observatory.jinja2 deleted file mode 100644 index 1bc684a3d..000000000 --- a/observatory-platform/observatory/platform/docker/Dockerfile.observatory.jinja2 +++ /dev/null @@ -1,66 +0,0 @@ -# Copyright 2020 Curtin University -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -# Author: James Diprose - -FROM apache/airflow:slim-{{ airflow_version }}-python{{ python_version }} - -ARG HOST_USER_ID -ARG OBSERVATORY_HOME=/opt/observatory -ARG INSTALL_USER=airflow - -# Install git which is required when installing dependencies with pip -USER root -RUN apt-get update -yqq -RUN apt-get install git -yqq - -# Change airflow user's user id to the hosts user id -RUN usermod -u ${HOST_USER_ID} ${INSTALL_USER} - -# Install Python dependencies for Observatory Platform as airflow user -USER ${INSTALL_USER} - -# Install dependencies for all projects -{% for package in config.python_packages %} -# Set working directory for {{ package.name }} -ARG WORKING_DIR=/opt/{{ package.name }} -WORKDIR ${WORKING_DIR} - -# Change owner of directory to airflow -USER root -RUN chown -R ${INSTALL_USER} ${WORKING_DIR} -USER ${INSTALL_USER} - -# Install apache-airflow-providers-google: required for cloud logging -# Install with no dependencies -RUN pip install gcloud-aio-storage==8.3.0 gcloud-aio-auth==4.2.3 google-cloud-secret-manager==2.16.3 -RUN pip install apache-airflow-providers-google==10.5.0 gcloud-aio-storage==8.3.0 gcloud-aio-auth==4.2.3 --no-deps - -{% with install_deps=true %} - {% include 'Dockerfile.package_install.jinja2' %} -{% endwith %} - -# Set working directory back to airflow home -WORKDIR ${AIRFLOW_HOME} -{% endfor %} - -# Copy entry point scripts which install new dependencies at runtime and the Observatory Platform Python package -USER root - -COPY entrypoint-root.sh /entrypoint-root.sh -COPY entrypoint-airflow.sh /entrypoint-airflow.sh -RUN chmod +x /entrypoint-root.sh -RUN chmod +x /entrypoint-airflow.sh - -ENTRYPOINT ["/entrypoint-root.sh"] diff --git a/observatory-platform/observatory/platform/docker/Dockerfile.package_install.jinja2 b/observatory-platform/observatory/platform/docker/Dockerfile.package_install.jinja2 deleted file mode 100644 index 0cf1c051b..000000000 --- a/observatory-platform/observatory/platform/docker/Dockerfile.package_install.jinja2 +++ /dev/null @@ -1,66 +0,0 @@ -{# Copyright 2020 Curtin University -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-# See the License for the specific language governing permissions and -# limitations under the License.#} - -# Author: Tuan Chien - -USER ${INSTALL_USER} -{% if package.type == 'editable' %} - -{% if install_deps %} -# Install editable package system dependencies: {{ package.name }} -COPY ./requirements.{{ package.name }}.sh requirements.sh -USER root -RUN chmod +x ./requirements.sh -RUN ./requirements.sh -USER ${INSTALL_USER} -# The Python dependencies for editable packages are installed when the container starts -# as the packages are volume mounted -{% endif %} - -{% elif package.type == 'sdist' %} - -# Install sdist package: {{ package.name }} -{% if install_deps %} -COPY ./{{ package.docker_package }} {{ package.docker_package }} - -# Extract sdist and install requirements.sh -RUN tar -xf *.tar.gz --one-top-level=sdist --strip-components 1 -RUN cp sdist/requirements.sh ./requirements.sh -USER root -RUN chmod +x ./requirements.sh -RUN ./requirements.sh -USER ${INSTALL_USER} - -# Install Python package -RUN pip3 install {{ package.docker_package }} --user --constraint https://raw.githubusercontent.com/apache/airflow/constraints-{{ airflow_version }}/constraints-no-providers-{{ python_version }}.txt -{% endif %} - -{% elif package.type == 'pypi' %} - -# Install PyPI package: {{ package.name }} -# Extract sdist and install requirements.sh -{% if install_deps %} -RUN pip3 download {{ package.docker_package }} --no-binary :all: --no-deps -RUN tar -xf *.tar.gz --one-top-level=sdist --strip-components 1 -RUN cp sdist/requirements.sh ./requirements.sh -USER root -RUN chmod +x ./requirements.sh -RUN ./requirements.sh -USER ${INSTALL_USER} -{% endif %} - -# Install package with PyPI -RUN pip3 install {{ package.docker_package }} --user {{ "--no-deps" if not install_deps }} --constraint https://raw.githubusercontent.com/apache/airflow/constraints-{{ airflow_version }}/constraints-no-providers-{{ python_version }}.txt - -{% endif %} \ No newline at end of file diff --git a/observatory-platform/observatory/platform/docker/builder.py b/observatory-platform/observatory/platform/docker/builder.py deleted file mode 100644 index ca092ed3d..000000000 --- a/observatory-platform/observatory/platform/docker/builder.py +++ /dev/null @@ -1,152 +0,0 @@ -# Copyright 2021 Curtin University -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# Author: James Diprose - -import dataclasses -import os -import shutil -from abc import ABC, abstractmethod -from typing import Dict, List - -from observatory.platform.utils.jinja2_utils import render_template - - -class BuilderInterface(ABC): - """An interface for building Docker and DockerCompose systems.""" - - @abstractmethod - def add_template(self, *, path: str, **kwargs): - """Add a Jinja template that will be rendered to the build directory before running Docker Compose commands. - :param path: the path to the Jinja2 template. - :param kwargs: the kwargs to use when rendering the template. 
- :return: None - """ - - pass - - @abstractmethod - def add_file(self, *, path: str, output_file_name: str): - """Add a file that will be copied to the build directory before running the Docker Compose commands. - :param path: a path to the file. - :param output_file_name: the output file name. - :return: None - """ - - pass - - @abstractmethod - def make_files(self) -> None: - """Render all Jinja templates and copy all files into the build directory. - :return: None. - """ - - -def rendered_file_name(template_file_path: str): - """Make the rendered file name from the Jinja template name. - :param template_file_path: the file path to the Jinja2 template. - :return: the output file name. - """ - - return os.path.basename(template_file_path).replace(".jinja2", "") - - -@dataclasses.dataclass -class File: - """A file to copy. - :param path: the path to the file. - :param output_file_name: the output file. - """ - - path: str - output_file_name: str - - -@dataclasses.dataclass -class Template: - """A Jinja2 Template to render. - :param path: the path to the Jinja2 template. - :param kwargs: the kwargs to use when rendering the template. - """ - - path: str - kwargs: Dict - - @property - def output_file_name(self) -> str: - """Make the output file name. - :return: the output file name. - """ - - return rendered_file_name(self.path) - - -class Builder(BuilderInterface): - def __init__(self, *, build_path: str): - """BuilderInterface implementation. - :param build_path: the path where the system will be built. - """ - - self.build_path = build_path - self.templates: List[Template] = [] - self.files: List[File] = [] - - def add_template(self, *, path: str, **kwargs): - """Add a Jinja template that will be rendered to the build directory before running Docker Compose commands. - :param path: the path to the Jinja2 template. - :param kwargs: the kwargs to use when rendering the template. - :return: None - """ - - self.templates.append(Template(path=path, kwargs=kwargs)) - - def add_file(self, *, path: str, output_file_name: str): - """Add a file that will be copied to the build directory before running the Docker Compose commands. - :param path: a path to the file. - :param output_file_name: the output file name. - :return: None - """ - - self.files.append(File(path=path, output_file_name=output_file_name)) - - def render_template(self, template: Template, output_file_path: str): - """Render a file using a Jinja template and save. - :param template: the template. - :param output_file_path: the output path. - :return: None. - """ - - render = render_template(template.path, **template.kwargs) - with open(output_file_path, "w") as f: - f.write(render) - - def make_files(self): - """Render all Jinja templates and copy all files into the build directory. - :return: None. 
- """ - - # Clear Docker directory and make build path - if os.path.exists(self.build_path): - shutil.rmtree(self.build_path) - os.makedirs(self.build_path) - - # Render templates - for template in self.templates: - output_path = os.path.join(self.build_path, template.output_file_name) - self.render_template(template, output_path) - - # Copy files - for file in self.files: - output_path = os.path.join(self.build_path, file.output_file_name) - shutil.copy(file.path, output_path) \ No newline at end of file diff --git a/observatory-platform/observatory/platform/docker/compose_runner.py b/observatory-platform/observatory/platform/docker/compose_runner.py deleted file mode 100644 index e086670bb..000000000 --- a/observatory-platform/observatory/platform/docker/compose_runner.py +++ /dev/null @@ -1,132 +0,0 @@ -# Copyright 2021 Curtin University -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# Author: James Diprose - -import dataclasses -import os -import subprocess -from subprocess import Popen -from typing import Dict, List - -from observatory.platform.docker.builder import Builder, rendered_file_name -from observatory.platform.utils.proc_utils import stream_process - - -@dataclasses.dataclass -class ProcessOutput: - """Output from a process - - :param output: the process std out. - :param error: the process std error. - :param return_code: the process return code. - """ - - output: str - error: str - return_code: int - - -class ComposeRunner(Builder): - COMPOSE_ARGS_PREFIX = ["docker", "compose", "-f"] - COMPOSE_BUILD_ARGS = ["build"] - COMPOSE_START_ARGS = ["up", "-d"] - COMPOSE_STOP_ARGS = ["down"] - - def __init__( - self, *, compose_template_path: str, build_path: str, compose_template_kwargs: Dict = None, debug: bool = False - ): - """Docker Compose runner constructor. - - :param compose_template_path: the path to the Docker Compose Jinja2 template file. - :param build_path: the path where Docker will build. - :param compose_template_kwargs: the kwargs to use when rendering the Docker Compose Jinja2 template file. - :param debug: whether to run in debug mode or not. - """ - - super().__init__(build_path=build_path) - if compose_template_kwargs is None: - compose_template_kwargs = dict() - - self.debug = debug - self.compose_template_path = compose_template_path - self.add_template(path=compose_template_path, **compose_template_kwargs) - - def make_environment(self) -> Dict: - """Make the environment variables. - - :return: environment dictionary. - """ - - return os.environ.copy() - - @property - def compose_file_name(self): - """Return the file name for the rendered Docker Compose template. - - :return: Docker Compose file name. - """ - - return rendered_file_name(self.compose_template_path) - - def build(self) -> ProcessOutput: - """Build the Docker containers. - - :return: output and error stream results and proc return code. - """ - - return self.__run_docker_compose_cmd(self.COMPOSE_BUILD_ARGS) - - def start(self) -> ProcessOutput: - """Start the Docker containers. 
- - :return: output and error stream results and proc return code. - """ - - return self.__run_docker_compose_cmd(self.COMPOSE_START_ARGS) - - def stop(self) -> ProcessOutput: - """Stop the Docker containers. - - :return: output and error stream results and proc return code. - """ - - return self.__run_docker_compose_cmd(self.COMPOSE_STOP_ARGS) - - def __run_docker_compose_cmd(self, args: List) -> ProcessOutput: - """Run a set of Docker Compose arguments. - - :param args: the list of arguments. - :return: output and error stream results and proc return code. - """ - - # Make environment - env = self.make_environment() - - # Make files - self.make_files() - - # Build the containers first - proc: Popen = subprocess.Popen( - self.COMPOSE_ARGS_PREFIX + [self.compose_file_name] + args, - stdout=subprocess.PIPE, - stderr=subprocess.STDOUT, - env=env, - cwd=self.build_path, - ) - - # Wait for results - output, error = stream_process(proc, self.debug) - - return ProcessOutput(output, error, proc.returncode) diff --git a/observatory-platform/observatory/platform/docker/docker-compose.observatory.yml.jinja2 b/observatory-platform/observatory/platform/docker/docker-compose.observatory.yml.jinja2 deleted file mode 100644 index 06ddd5c40..000000000 --- a/observatory-platform/observatory/platform/docker/docker-compose.observatory.yml.jinja2 +++ /dev/null @@ -1,305 +0,0 @@ -{# Copyright 2020-2023 Curtin University -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -# Author: James Diprose -#} -version: '3.8' - -x-environment: &environment - # Airflow settings - AIRFLOW__CORE__EXECUTOR: CeleryExecutor - AIRFLOW__DATABASE__SQL_ALCHEMY_CONN: "postgresql+psycopg2://${POSTGRES_USER}:${POSTGRES_PASSWORD}@${POSTGRES_HOSTNAME}:5432/airflow" - AIRFLOW__CORE__FERNET_KEY: ${AIRFLOW_FERNET_KEY} - AIRFLOW__CORE__LOAD_EXAMPLES: "False" - AIRFLOW__CORE__LOAD_DEFAULT_CONNECTIONS: "False" - AIRFLOW__CORE__EXECUTE_TASKS_NEW_PYTHON_INTERPRETER: "True" - AIRFLOW__WEBSERVER__RBAC: "True" - AIRFLOW__WEBSERVER__SECRET_KEY: ${AIRFLOW_SECRET_KEY} - AIRFLOW__CELERY__BROKER_URL: "redis://:@${REDIS_HOSTNAME}:6379/0" - AIRFLOW__CELERY__RESULT_BACKEND: "db+postgresql://${POSTGRES_USER}:${POSTGRES_PASSWORD}@${POSTGRES_HOSTNAME}:5432/airflow" - AIRFLOW__CELERY_BROKER_TRANSPORT_OPTIONS__VISIBILITY_TIMEOUT: 259200 - AIRFLOW__API__AUTH_BACKENDS: airflow.api.auth.backend.basic_auth - {%- if config.backend.type.value == 'terraform' %} - AIRFLOW__LOGGING__REMOTE_LOGGING: "True" - AIRFLOW__LOGGING__REMOTE_BASE_LOG_FOLDER: "gs://${AIRFLOW_LOGGING_BUCKET}/logs" - AIRFLOW__LOGGING__REMOTE_LOG_CONN_ID: "google_cloud_observatory" - {%- endif %} - # Set Airflow DAGs folder to installed observatory platform dags - {%- if config.observatory.package_type != 'editable' %} - AIRFLOW__CORE__DAGS_FOLDER: "/home/airflow/.local/lib/python3.10/site-packages/observatory/platform/dags" - {%- endif %} - - # Paths to google credentials and UI user settings - GOOGLE_APPLICATION_CREDENTIALS: "/run/secrets/google_application_credentials.json" - - # Variables - AIRFLOW_VAR_DATA_PATH: "/opt/observatory/data" - AIRFLOW_VAR_WORKFLOWS: ${AIRFLOW_VAR_WORKFLOWS} - AIRFLOW_VAR_DAGS_MODULE_NAMES: ${AIRFLOW_VAR_DAGS_MODULE_NAMES} - -x-volumes: &volumes - - "${HOST_LOGS_PATH}:/opt/airflow/logs" - - "${HOST_DATA_PATH}:/opt/observatory/data" - - # Volume mapping when observatory installed in editable mode - {%- if config.backend.type.value == 'local' and config.observatory.package_type == 'editable' %} - - "{{ config.observatory.host_package }}/observatory/platform/dags:/opt/airflow/dags" - {%- elif config.backend.type.value == 'terraform' and config.observatory.package_type == 'editable' %} - - "/opt/observatory-platform/observatory/platform/dags:/opt/airflow/dags" - {%- endif -%} - - # Volume mappings for Python packages, incl observatory-platform and dags projects: - {%- for package in config.python_packages %} - {%- if config.backend.type.value == 'local' and package.type == 'editable' %} - - "{{ package.host_package }}:/opt/{{ package.name }}" - {%- elif config.backend.type.value == 'terraform' and package.type == 'editable' %} - - "/opt/{{ package.name }}:/opt/{{ package.name }}" - {%- endif -%} - {%- endfor %} - -x-depends-on: &depends-on - redis: - condition: service_healthy - {% if config.backend.type.value == 'local' -%} - postgres: - condition: service_healthy - {%- endif %} - -{% if config.backend.type.value == 'local' %} -{# Local network #} -x-network-mode: &networks - networks: - - {{ config.observatory.docker_network_name }} - -{%- else %} - -{#- Cloud network -#} -x-network-mode: &networks - network_mode: "host" - -{%- endif -%} - -{%- if config.backend.type.value == 'terraform' or not config.observatory.docker_network_is_external %} -{# Cloud network or not external network #} -networks: - {{ config.observatory.docker_network_name }}: - name: {{ config.observatory.docker_network_name }} - driver: bridge -{%- else %} -networks: - {{ config.observatory.docker_network_name }}: - external: true -{%- endif 
%} - -x-build: &build - context: . - dockerfile: Dockerfile.observatory - args: - - HOST_USER_ID=${HOST_USER_ID} - -services: - redis: - container_name: redis - hostname: redis - image: redis:latest - restart: always - ports: - - ${HOST_REDIS_PORT}:6379 - networks: - - {{ config.observatory.docker_network_name }} - healthcheck: - test: ["CMD", "redis-cli", "ping"] - interval: 5s - timeout: 30s - retries: 50 - - flower: - container_name: flower - hostname: flower - image: apache/airflow:{{ airflow_version }}-python{{ python_version }} - environment: *environment - restart: always - networks: - - {{ config.observatory.docker_network_name }} - ports: - - ${HOST_FLOWER_UI_PORT}:5555 - command: celery flower - healthcheck: - test: ["CMD", "curl", "--fail", "http://localhost:5555/"] - interval: 30s - timeout: 10s - retries: 5 - {% if (config.google_cloud is not none) and (config.google_cloud.credentials is not none) -%} - secrets: - - google_application_credentials.json - {%- endif %} - depends_on: - <<: *depends-on - airflow_init: - condition: service_completed_successfully - - webserver: - container_name: webserver - hostname: webserver - image: apache/airflow:{{ airflow_version }}-python{{ python_version }} - volumes: *volumes - environment: *environment - restart: always - networks: - - {{ config.observatory.docker_network_name }} - {% if (config.google_cloud is not none) and (config.google_cloud.credentials is not none) -%} - secrets: - - google_application_credentials.json - {%- endif %} - ports: - - ${HOST_AIRFLOW_UI_PORT}:8080 - command: webserver - healthcheck: - test: ["CMD-SHELL", "[ -f /opt/airflow/airflow-webserver.pid ]"] - interval: 30s - timeout: 30s - retries: 30 - depends_on: - <<: *depends-on - airflow_init: - condition: service_completed_successfully - - airflow_init: - container_name: airflow_init - hostname: airflow_init - image: apache/airflow:{{ airflow_version }}-python{{ python_version }} - environment: - <<: *environment - _AIRFLOW_DB_UPGRADE: "true" - _AIRFLOW_WWW_USER_CREATE: "true" - _AIRFLOW_WWW_USER_USERNAME: ${AIRFLOW_UI_USER_EMAIL} - _AIRFLOW_WWW_USER_PASSWORD: ${AIRFLOW_UI_USER_PASSWORD} - networks: - - {{ config.observatory.docker_network_name }} - command: version - depends_on: *depends-on - - scheduler: - container_name: scheduler - hostname: scheduler - build: *build - environment: *environment - volumes: *volumes - restart: always - networks: - - {{ config.observatory.docker_network_name }} - {% if (config.google_cloud is not none) and (config.google_cloud.credentials is not none) -%} - secrets: - - google_application_credentials.json - {%- endif %} - command: scheduler - depends_on: - <<: *depends-on - airflow_init: - condition: service_completed_successfully - - worker_local: - container_name: worker_local - hostname: worker_local - build: *build - environment: *environment - volumes: *volumes - restart: always - <<: *networks - {% if (config.google_cloud is not none) and (config.google_cloud.credentials is not none) -%} - secrets: - - google_application_credentials.json - {%- endif %} - command: celery worker -q default - depends_on: - <<: *depends-on - airflow_init: - condition: service_completed_successfully - - worker_remote: - container_name: worker_remote - hostname: worker_remote - build: *build - environment: *environment - volumes: *volumes - restart: always - <<: *networks - {% if (config.google_cloud is not none) and (config.google_cloud.credentials is not none) -%} - secrets: - - google_application_credentials.json - {%- endif %} - command: 
celery worker -q remote_queue - {% if config.backend.type.value == 'local' %} - depends_on: - <<: *depends-on - airflow_init: - condition: service_completed_successfully - {% endif %} - - apiserver: - container_name: apiserver - hostname: apiserver - build: - context: . - dockerfile: Dockerfile.apiserver - args: - - HOST_USER_ID=${HOST_USER_ID} - environment: - - API_SERVER_DB=observatory - - BASE_DB_URI=postgresql://${POSTGRES_USER}:${POSTGRES_PASSWORD}@${POSTGRES_HOSTNAME}/postgres - - OBSERVATORY_DB_URI=postgresql://${POSTGRES_USER}:${POSTGRES_PASSWORD}@${POSTGRES_HOSTNAME}/observatory - - OBSERVATORY_API_PORT=5002 - volumes: *volumes - restart: always - ports: - - ${HOST_API_SERVER_PORT}:5002 - networks: - - {{ config.observatory.docker_network_name }} - healthcheck: - test: ["CMD", "nc", "-z", "-v", "apiserver", "5002"] - interval: 30s - retries: 20 - command: apiserver - entrypoint: /entrypoint-api.sh - {% if config.backend.type.value == 'local' %} - depends_on: - postgres: - condition: service_healthy - {% endif %} - -{% if config.backend.type.value == 'local' %} - postgres: - container_name: postgres - hostname: postgres - image: postgres:12.2 - environment: - - POSTGRES_DB=airflow - - POSTGRES_USER=${POSTGRES_USER} - - POSTGRES_PASSWORD=${POSTGRES_PASSWORD} - volumes: - - ${HOST_POSTGRES_PATH}:/var/lib/postgresql/data - restart: always - networks: - - {{ config.observatory.docker_network_name }} - healthcheck: - test: ["CMD", "pg_isready", "-U", "${POSTGRES_USER}", "-d", "airflow"] - interval: 5s - retries: 5 -{%- endif %} - - -{% if (config.google_cloud is not none) and (config.google_cloud.credentials is not none) -%} -secrets: - google_application_credentials.json: - file: ${HOST_GOOGLE_APPLICATION_CREDENTIALS} -{%- endif %} diff --git a/observatory-platform/observatory/platform/docker/entrypoint-airflow.sh.jinja2 b/observatory-platform/observatory/platform/docker/entrypoint-airflow.sh.jinja2 deleted file mode 100755 index a90e57bce..000000000 --- a/observatory-platform/observatory/platform/docker/entrypoint-airflow.sh.jinja2 +++ /dev/null @@ -1,39 +0,0 @@ -#!/usr/bin/env bash -# Copyright 2020 Curtin University -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# Author: James Diprose - -# This script is run as the airflow user - -######################################## -# Install editable mode Python packages -######################################## - -{% for package in config.python_packages %} -{% if package.type == 'editable' -%} -cd /opt/{{ package.name }} -export PBR_VERSION=0.0.1 -pip3 install -e . --user --constraint https://raw.githubusercontent.com/apache/airflow/constraints-{{ airflow_version }}/constraints-no-providers-{{ python_version }}.txt -unset PBR_VERSION -{% endif %} -{% endfor %} - -# Enter airflow home folder. Must be in the AIRFLOW_HOME folder (i.e. 
/opt/airflow) before running the next command -# otherwise the system will start but the workers and scheduler will not find the DAGs and other files because -# they look for them based on the current working directory. -cd ${AIRFLOW_HOME} - -# Run entrypoint given by airflow docker file -/usr/bin/dumb-init -- /entrypoint "$@" diff --git a/observatory-platform/observatory/platform/docker/entrypoint-api.sh.jinja2 b/observatory-platform/observatory/platform/docker/entrypoint-api.sh.jinja2 deleted file mode 100644 index 8d6dc9276..000000000 --- a/observatory-platform/observatory/platform/docker/entrypoint-api.sh.jinja2 +++ /dev/null @@ -1,40 +0,0 @@ -#!/usr/bin/env bash -# Copyright 2020-2023 Curtin University -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# Author: Tuan Chien - - -######################################## -# Install editable mode Python packages -######################################## - -export PATH=/home/apiserver/.local/bin:$PATH - -{% for package in config.python_packages %} -{% if package.name == 'observatory-api' %} -{% if package.type == 'editable' -%} -cd /opt/{{ package.name }} -export PBR_VERSION=0.0.1 -pip3 install -e . --user -unset PBR_VERSION -{% endif %} -{% endif %} -{% endfor %} - -# Create database if it does not exist -psql ${BASE_DB_URI} -tc "SELECT 1 FROM pg_database WHERE datname = '${API_SERVER_DB}'" | grep -q 1 || psql ${BASE_DB_URI} -c "CREATE DATABASE ${API_SERVER_DB}" - -# Launch api server -gunicorn -b 0.0.0.0:${OBSERVATORY_API_PORT} --timeout 0 observatory.api.server.app:app \ No newline at end of file diff --git a/observatory-platform/observatory/platform/docker/entrypoint-root.sh b/observatory-platform/observatory/platform/docker/entrypoint-root.sh deleted file mode 100755 index 9c69d9cff..000000000 --- a/observatory-platform/observatory/platform/docker/entrypoint-root.sh +++ /dev/null @@ -1,32 +0,0 @@ -#!/usr/bin/env bash -# Copyright 2020 Curtin University -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -# Author: James Diprose - -# This script runs commands as the root user - -# Make airflow user own everything in home directory -chown -R airflow ${AIRFLOW_HOME} - -# Make airflow user own everything in observatory directory -chown -R airflow /opt/observatory - -# Hardcoded list of environment variables that need to be preserved -STANDARD_ENV_PRESERVE="AIRFLOW_HOME,GOOGLE_APPLICATION_CREDENTIALS,AO_HOME,AIRFLOW_UI_USER_EMAIL,AIRFLOW_UI_USER_PASSWORD,API_SERVER_DB,BASE_DB_URI,OBSERVATORY_DB_URI,OBSERVATORY_API_PORT" - -# Preserve all environment variables that begin with AIRFLOW__, AIRFLOW_VAR or AIRFLOW_CONN -ALL_ENV_PRESERVE=$(printenv | awk -v env_preserve="$STANDARD_ENV_PRESERVE" -F'=' '$0 ~ /AIRFLOW__|AIRFLOW_VAR|AIRFLOW_CONN/ {printf "%s,", $1} END {print env_preserve}') - -sudo --preserve-env=$ALL_ENV_PRESERVE --user airflow --set-home --login /entrypoint-airflow.sh "$@" \ No newline at end of file diff --git a/observatory-platform/observatory/platform/docker/platform_runner.py b/observatory-platform/observatory/platform/docker/platform_runner.py deleted file mode 100644 index c73f85632..000000000 --- a/observatory-platform/observatory/platform/docker/platform_runner.py +++ /dev/null @@ -1,222 +0,0 @@ -# Copyright 2020-2023 Curtin University -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# Author: James Diprose, Aniek Roelofs - - -import os -import shutil - -import docker -import requests - -from observatory.platform.config import module_file_path -from observatory.platform.docker.compose_runner import ComposeRunner -from observatory.platform.observatory_config import Config - -HOST_UID = os.getuid() -DEBUG = False -PYTHON_VERSION = "3.10" -AIRFLOW_VERSION = "2.6.3" - - -class PlatformRunner(ComposeRunner): - def __init__( - self, - *, - config: Config, - host_uid: int = HOST_UID, - docker_build_path: str = None, - debug: bool = DEBUG, - python_version: str = PYTHON_VERSION, - airflow_version: str = AIRFLOW_VERSION, - ): - """Create a PlatformRunner instance, which is used to build, start and stop an Observatory Platform instance. - - :param config: the config. - :param host_uid: The user id of the host system. Used to set the user id in the Docker containers. - :param docker_build_path: the Docker build path. - :param debug: Print debugging information. - :param python_version: Python version. - :param airflow_version: Airflow version. 
- """ - - self.config = config - self.host_uid = host_uid - - # Set default values when config is invalid - observatory_home = self.config.observatory.observatory_home - - if docker_build_path is None: - docker_build_path = os.path.join(observatory_home, "build", "docker") - - super().__init__( - compose_template_path=os.path.join(self.docker_module_path, "docker-compose.observatory.yml.jinja2"), - build_path=docker_build_path, - compose_template_kwargs={ - "config": self.config, - "airflow_version": airflow_version, - "python_version": python_version, - }, - debug=debug, - ) - - # Add files - self.add_template( - path=os.path.join(self.docker_module_path, "Dockerfile.observatory.jinja2"), - config=self.config, - airflow_version=airflow_version, - python_version=python_version, - ) - self.add_template( - path=os.path.join(self.docker_module_path, "entrypoint-airflow.sh.jinja2"), - config=self.config, - airflow_version=airflow_version, - python_version=python_version, - ) - self.add_template( - path=os.path.join(self.docker_module_path, "Dockerfile.apiserver.jinja2"), - config=self.config, - airflow_version=airflow_version, - python_version=python_version, - ) - self.add_template(path=os.path.join(self.docker_module_path, "entrypoint-api.sh.jinja2"), config=self.config) - self.add_file( - path=os.path.join(self.docker_module_path, "entrypoint-root.sh"), output_file_name="entrypoint-root.sh" - ) - - # Add all project requirements files for local projects - for package in self.config.python_packages: - if package.type == "editable": - # Add requirements.sh - self.add_file( - path=os.path.join(package.host_package, "requirements.sh"), - output_file_name=f"requirements.{package.name}.sh", - ) - elif package.type == "sdist": - # Add sdist package file - self.add_file(path=package.host_package, output_file_name=package.docker_package) - - @property - def is_environment_valid(self) -> bool: - """Return whether the environment for building the Observatory Platform is valid. - - :return: whether the environment for building the Observatory Platform is valid. - """ - - return all([self.docker_exe_path is not None, self.docker_compose, self.is_docker_running]) - - @property - def docker_module_path(self) -> str: - """The path to the Observatory Platform docker module. - - :return: the path. - """ - - return module_file_path("observatory.platform.docker") - - @property - def docker_exe_path(self) -> str: - """The path to the Docker executable. - - :return: the path or None. - """ - - return shutil.which("docker") - - @property - def docker_compose(self) -> bool: - """Whether Docker Compose is installed. - - :return: true or false. - """ - - stream = os.popen("docker info") - return "compose: Docker Compose" in stream.read() - - @property - def is_docker_running(self) -> bool: - """Checks whether Docker is running or not. - - :return: whether Docker is running or not. - """ - - client = docker.from_env() - try: - is_running = client.ping() - except requests.exceptions.ConnectionError: - is_running = False - return is_running - - def make_environment(self): - """Make an environment containing the environment variables that are required to build and start the - Observatory docker environment. - - :return: None. 
- """ - - env = os.environ.copy() - - # Settings for Docker Compose - env["COMPOSE_PROJECT_NAME"] = self.config.observatory.docker_compose_project_name - env["POSTGRES_USER"] = "observatory" - env["POSTGRES_HOSTNAME"] = "postgres" - env["REDIS_HOSTNAME"] = "redis" - - # Host settings - env["HOST_USER_ID"] = str(self.host_uid) - observatory_home = os.path.normpath(self.config.observatory.observatory_home) - env["HOST_DATA_PATH"] = os.path.join(observatory_home, "data") - env["HOST_LOGS_PATH"] = os.path.join(observatory_home, "logs") - env["HOST_POSTGRES_PATH"] = os.path.join(observatory_home, "postgres") - env["HOST_REDIS_PORT"] = str(self.config.observatory.redis_port) - env["HOST_FLOWER_UI_PORT"] = str(self.config.observatory.flower_ui_port) - env["HOST_AIRFLOW_UI_PORT"] = str(self.config.observatory.airflow_ui_port) - env["HOST_API_SERVER_PORT"] = str(self.config.observatory.api_port) - - # Secrets - if self.config.google_cloud is not None and self.config.google_cloud.credentials is not None: - env["HOST_GOOGLE_APPLICATION_CREDENTIALS"] = self.config.google_cloud.credentials - env["AIRFLOW_FERNET_KEY"] = self.config.observatory.airflow_fernet_key - env["AIRFLOW_SECRET_KEY"] = self.config.observatory.airflow_secret_key - env["AIRFLOW_UI_USER_EMAIL"] = self.config.observatory.airflow_ui_user_email - env["AIRFLOW_UI_USER_PASSWORD"] = self.config.observatory.airflow_ui_user_password - env["POSTGRES_PASSWORD"] = self.config.observatory.postgres_password - - # AIRFLOW_VAR_WORKFLOWS is used to decide what workflows to run and what their settings are - env["AIRFLOW_VAR_WORKFLOWS"] = self.config.airflow_var_workflows - - # AIRFLOW_VAR_DAGS_MODULE_NAMES is used to decide what dags modules to load DAGS from - env["AIRFLOW_VAR_DAGS_MODULE_NAMES"] = self.config.airflow_var_dags_module_names - - return env - - def make_files(self): - """Create directories that are mounted as volumes as defined in the docker-compose file. - - :return: None. - """ - super(PlatformRunner, self).make_files() - observatory_home = os.path.normpath(self.config.observatory.observatory_home) - # Create data directory - data_dir = os.path.join(observatory_home, "data") - os.makedirs(data_dir, exist_ok=True) - - # Create logs directory - logs_dir = os.path.join(observatory_home, "logs") - os.makedirs(logs_dir, exist_ok=True) - - # Create postgres directory - postgres_dir = os.path.join(observatory_home, "postgres") - os.makedirs(postgres_dir, exist_ok=True) diff --git a/observatory-platform/observatory/platform/observatory_config.py b/observatory-platform/observatory/platform/observatory_config.py deleted file mode 100644 index 9633f100f..000000000 --- a/observatory-platform/observatory/platform/observatory_config.py +++ /dev/null @@ -1,1702 +0,0 @@ -# Copyright 2019, 2020 Curtin University -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -# Author: James Diprose, Aniek Roelofs, Tuan Chien - - -from __future__ import annotations - -import base64 -import datetime -import json -import os -from dataclasses import dataclass, field -from enum import Enum -from typing import Any, Callable, ClassVar, Dict, List, TextIO, Tuple, Union, Optional - -import binascii -import cerberus.validator -import pendulum -import yaml -from cerberus import Validator -from cryptography.fernet import Fernet -from yaml.constructor import SafeConstructor - -from observatory.platform.cli.cli_utils import ( - INDENT1, - INDENT3, - comment, - indent, -) -from observatory.platform.config import ( - observatory_home as default_observatory_home, -) -from observatory.platform.terraform.terraform_api import TerraformVariable - - -def generate_fernet_key() -> str: - """Generate a Fernet key. - - :return: A newly generated Fernet key. - """ - - return Fernet.generate_key().decode("utf8") - - -def generate_secret_key(length: int = 30) -> str: - """Generate a secret key for the Flask Airflow Webserver. - - :param length: the length of the key to generate. - :return: the random key. - """ - - return binascii.b2a_hex(os.urandom(length)).decode("utf-8") - - -def save_yaml(file_path: str, dict_: Dict): - """Save a yaml file from a dictionary. - - :param file_path: the path to the file to save. - :param dict_: the dictionary. - :return: None. - """ - - with open(file_path, "w") as yaml_file: - yaml.dump(dict_, yaml_file, default_flow_style=False) - - -def to_hcl(value: Dict) -> str: - """Convert a Python dictionary into HCL. - - :param value: the dictionary. - :return: the HCL string. - """ - - return json.dumps(value, separators=(",", "=")) - - -def from_hcl(string: str) -> Dict: - """Convert an HCL string into a Dict. - - :param string: the HCL string. - :return: the dict. - """ - - return json.loads(string.replace('"=', '":')) - - -class BackendType(Enum): - """The type of backend""" - - local = "local" - terraform = "terraform" - - -class Environment(Enum): - """The environment being used""" - - develop = "develop" - staging = "staging" - production = "production" - - -@dataclass -class Backend: - """The backend settings for the Observatory Platform. - - Attributes: - type: the type of backend being used (local environment or Terraform). - environment: what type of environment is being deployed (develop, staging or production). - """ - - type: BackendType = BackendType.local - environment: Environment = Environment.develop - - @staticmethod - def from_dict(dict_: Dict) -> Backend: - """Constructs a Backend instance from a dictionary. - - :param dict_: the dictionary. - :return: the Backend instance. - """ - - backend_type = BackendType(dict_.get("type")) - environment = Environment(dict_.get("environment")) - - return Backend( - backend_type, - environment, - ) - - -@dataclass -class PythonPackage: - name: str - host_package: str - docker_package: str - type: str - - -@dataclass -class Observatory: - """The Observatory settings for the Observatory Platform. - - Attributes: - :param package: the observatory platform package, either a local path to a Python source package (editable type), - path to a sdist (sdist) or a PyPI package name and version (pypi). - :param package_type: the package type, editable, sdist, pypi. - :param airflow_fernet_key: the Fernet key. - :param airflow_secret_key: the secret key used to run the flask app. - :param airflow_ui_user_password: the password for the Apache Airflow UI admin user. 
- :param airflow_ui_user_email: the email address for the Apache Airflow UI admin user. - :param observatory_home: The observatory home folder. - :param postgres_password: the Postgres SQL password. - :param redis_port: The host Redis port number. - :param flower_ui_port: The host's Flower UI port number. - :param airflow_ui_port: The host's Apache Airflow UI port number. - :param docker_network_name: The Docker Network name, used to specify a custom Docker Network. - :param docker_network_is_external: whether the docker network is external or not. - :param docker_compose_project_name: The namespace for the Docker Compose containers: https://docs.docker.com/compose/reference/envvars/#compose_project_name. - :param api_port: Port for accessing the API locally (exposed by docker container). - """ - - package: str = "observatory-platform" - package_type: str = "pypi" - airflow_fernet_key: str = field(default_factory=generate_fernet_key) - airflow_secret_key: str = field(default_factory=generate_secret_key) - airflow_ui_user_email: str = "airflow@airflow.com" - airflow_ui_user_password: str = "airflow" - observatory_home: str = default_observatory_home() - postgres_password: str = "postgres" - redis_port: int = 6379 - flower_ui_port: int = 5555 - airflow_ui_port: int = 8080 - docker_network_name: str = "observatory-network" - docker_network_is_external: bool = False - docker_compose_project_name: str = "observatory" - api_package: str = "observatory-api" - api_package_type: str = "pypi" - api_port: int = 5002 - - def to_hcl(self): - return to_hcl( - { - "airflow_fernet_key": self.airflow_fernet_key, - "airflow_secret_key": self.airflow_secret_key, - "airflow_ui_user_password": self.airflow_ui_user_password, - "airflow_ui_user_email": self.airflow_ui_user_email, - "postgres_password": self.postgres_password, - } - ) - - @property - def host_package(self): - return os.path.normpath(self.package) - - @staticmethod - def from_dict(dict_: Dict) -> Observatory: - """Constructs an Airflow instance from a dictionary. - - :param dict_: the dictionary. - :return: the Airflow instance. 
- """ - - package = dict_.get("package") - package_type = dict_.get("package_type") - airflow_fernet_key = dict_.get("airflow_fernet_key") - airflow_secret_key = dict_.get("airflow_secret_key") - airflow_ui_user_email = dict_.get("airflow_ui_user_email", Observatory.airflow_ui_user_email) - airflow_ui_user_password = dict_.get("airflow_ui_user_password", Observatory.airflow_ui_user_password) - observatory_home = dict_.get("observatory_home", Observatory.observatory_home) - postgres_password = dict_.get("postgres_password", Observatory.postgres_password) - redis_port = dict_.get("redis_port", Observatory.redis_port) - flower_ui_port = dict_.get("flower_ui_port", Observatory.flower_ui_port) - airflow_ui_port = dict_.get("airflow_ui_port", Observatory.airflow_ui_port) - docker_network_name = dict_.get("docker_network_name", Observatory.docker_network_name) - docker_network_is_external = dict_.get("docker_network_is_external", Observatory.docker_network_is_external) - docker_compose_project_name = dict_.get("docker_compose_project_name", Observatory.docker_compose_project_name) - api_package = dict_.get("api_package", Observatory.api_package) - api_package_type = dict_.get("api_package_type", Observatory.api_package_type) - api_port = dict_.get("api_port", Observatory.api_port) - - return Observatory( - package, - package_type, - airflow_fernet_key, - airflow_secret_key, - airflow_ui_user_password=airflow_ui_user_password, - airflow_ui_user_email=airflow_ui_user_email, - observatory_home=observatory_home, - postgres_password=postgres_password, - redis_port=redis_port, - flower_ui_port=flower_ui_port, - airflow_ui_port=airflow_ui_port, - docker_network_name=docker_network_name, - docker_network_is_external=docker_network_is_external, - docker_compose_project_name=docker_compose_project_name, - api_package=api_package, - api_package_type=api_package_type, - api_port=api_port, - ) - - -@dataclass -class CloudStorageBucket: - """Represents a Google Cloud Storage Bucket. - - Attributes: - id: the id of the bucket (which gets set as an Airflow variable). - name: the name of the Google Cloud storage bucket. - """ - - id: str - name: str - - @staticmethod - def parse_buckets(buckets: Dict) -> List[CloudStorageBucket]: - return parse_dict_to_list(buckets, CloudStorageBucket) - - -@dataclass -class GoogleCloud: - """The Google Cloud settings for the Observatory Platform. - - Attributes: - project_id: the Google Cloud project id. - credentials: the path to the Google Cloud credentials. - region: the Google Cloud region. - zone: the Google Cloud zone. - data_location: the data location for storing buckets. - """ - - project_id: str = None - credentials: str = None - region: str = None - zone: str = None - data_location: str = None - - def read_credentials(self) -> str: - with open(self.credentials, "r") as f: - data = f.read() - return data - - def to_hcl(self): - return to_hcl( - { - "project_id": self.project_id, - "credentials": self.read_credentials(), - "region": self.region, - "zone": self.zone, - "data_location": self.data_location, - } - ) - - @staticmethod - def from_dict(dict_: Dict) -> GoogleCloud: - """Constructs a GoogleCloud instance from a dictionary. - - :param dict_: the dictionary. - :return: the GoogleCloud instance. 
- """ - - project_id = dict_.get("project_id") - credentials = dict_.get("credentials") - region = dict_.get("region") - zone = dict_.get("zone") - data_location = dict_.get("data_location") - - return GoogleCloud( - project_id=project_id, credentials=credentials, region=region, zone=zone, data_location=data_location - ) - - -class CloudWorkspace: - def __init__( - self, - *, - project_id: str, - download_bucket: str, - transform_bucket: str, - data_location: str, - output_project_id: Optional[str] = None, - ): - """The CloudWorkspace settings used by workflows. - - project_id: the Google Cloud project id. input_project_id is an alias for project_id. - download_bucket: the Google Cloud Storage bucket where downloads will be stored. - transform_bucket: the Google Cloud Storage bucket where transformed data will be stored. - data_location: the data location for storing information, e.g. where BigQuery datasets should be located. - output_project_id: an optional Google Cloud project id when the outputs of a workflow should be stored in a - different project to the inputs. If an output_project_id is not supplied, the project_id will be used. - """ - - self._project_id = project_id - self._download_bucket = download_bucket - self._transform_bucket = transform_bucket - self._data_location = data_location - self._output_project_id = output_project_id - - @property - def project_id(self) -> str: - return self._project_id - - @project_id.setter - def project_id(self, project_id: str): - self._project_id = project_id - - @property - def download_bucket(self) -> str: - return self._download_bucket - - @download_bucket.setter - def download_bucket(self, download_bucket: str): - self._download_bucket = download_bucket - - @property - def transform_bucket(self) -> str: - return self._transform_bucket - - @transform_bucket.setter - def transform_bucket(self, transform_bucket: str): - self._transform_bucket = transform_bucket - - @property - def data_location(self) -> str: - return self._data_location - - @data_location.setter - def data_location(self, data_location: str): - self._data_location = data_location - - @property - def input_project_id(self) -> str: - return self._project_id - - @input_project_id.setter - def input_project_id(self, project_id: str): - self._project_id = project_id - - @property - def output_project_id(self) -> Optional[str]: - if self._output_project_id is None: - return self._project_id - return self._output_project_id - - @output_project_id.setter - def output_project_id(self, output_project_id: Optional[str]): - self._output_project_id = output_project_id - - @staticmethod - def from_dict(dict_: Dict) -> CloudWorkspace: - """Constructs a CloudWorkspace instance from a dictionary. - - :param dict_: the dictionary. - :return: the Workflow instance. - """ - - project_id = dict_.get("project_id") - download_bucket = dict_.get("download_bucket") - transform_bucket = dict_.get("transform_bucket") - data_location = dict_.get("data_location") - output_project_id = dict_.get("output_project_id") - - return CloudWorkspace( - project_id=project_id, - download_bucket=download_bucket, - transform_bucket=transform_bucket, - data_location=data_location, - output_project_id=output_project_id, - ) - - def to_dict(self) -> Dict: - """CloudWorkspace instance to dictionary. - - :return: the dictionary. 
- """ - - return dict( - project_id=self._project_id, - download_bucket=self._download_bucket, - transform_bucket=self._transform_bucket, - data_location=self._data_location, - output_project_id=self.output_project_id, - ) - - @staticmethod - def parse_cloud_workspaces(list: List) -> List[CloudWorkspace]: - """Parse the cloud workspaces list object into a list of CloudWorkspace instances. - - :param list: the list. - :return: a list of CloudWorkspace instances. - """ - - return [CloudWorkspace.from_dict(dict_) for dict_ in list] - - -@dataclass -class Workflow: - """A Workflow configuration. - - Attributes: - dag_id: the Airflow DAG identifier for the workflow. - name: a user-friendly name for the workflow. - class_name: the fully qualified class name for the workflow class. - cloud_workspace: the Cloud Workspace to use when running the workflow. - kwargs: a dictionary containing optional keyword arguments that are injected into the workflow constructor. - """ - - dag_id: str = None - name: str = None - class_name: str = None - cloud_workspace: CloudWorkspace = None - kwargs: Optional[Dict] = field(default_factory=lambda: dict()) - - def to_dict(self) -> Dict: - """Workflow instance to dictionary. - - :return: the dictionary. - """ - - cloud_workspace = self.cloud_workspace - if self.cloud_workspace is not None: - cloud_workspace = self.cloud_workspace.to_dict() - - return dict( - dag_id=self.dag_id, - name=self.name, - class_name=self.class_name, - cloud_workspace=cloud_workspace, - kwargs=self.kwargs, - ) - - @staticmethod - def from_dict(dict_: Dict) -> Workflow: - """Constructs a Workflow instance from a dictionary. - - :param dict_: the dictionary. - :return: the Workflow instance. - """ - - dag_id = dict_.get("dag_id") - name = dict_.get("name") - class_name = dict_.get("class_name") - - cloud_workspace = dict_.get("cloud_workspace") - if cloud_workspace is not None: - cloud_workspace = CloudWorkspace.from_dict(cloud_workspace) - - kwargs = dict_.get("kwargs", dict()) - - return Workflow(dag_id, name, class_name, cloud_workspace, kwargs) - - @staticmethod - def parse_workflows(list: List) -> List[Workflow]: - """Parse the workflows list object into a list of Workflow instances. - - :param list: the list. - :return: a list of Workflow instances. - """ - - return [Workflow.from_dict(dict_) for dict_ in list] - - -class PendulumDateTimeEncoder(json.JSONEncoder): - def default(self, obj: Any) -> Any: - if isinstance(obj, pendulum.DateTime): - return obj.isoformat() - return super().default(obj) - - -def workflows_to_json_string(workflows: List[Workflow]) -> str: - """Covnert a list of Workflow instances to a JSON string. - - :param workflows: the Workflow instances. - :return: a JSON string. - """ - - data = [workflow.to_dict() for workflow in workflows] - return json.dumps(data, cls=PendulumDateTimeEncoder) - - -def json_string_to_workflows(json_string: str) -> List[Workflow]: - """Convert a JSON string into a list of Workflow instances. - - :param json_string: a JSON string version of a list of Workflow instances. - :return: a list of Workflow instances. - """ - - def parse_datetime(obj): - for key, value in obj.items(): - try: - obj[key] = pendulum.parse(value) - except (ValueError, TypeError): - pass - return obj - - data = json.loads(json_string, object_hook=parse_datetime) - return Workflow.parse_workflows(data) - - -def parse_dict_to_list(dict_: Dict, cls: ClassVar) -> List[Any]: - """Parse the key, value pairs in a dictionary into a list of class instances. 
- - :param dict_: the dictionary. - :param cls: the type of class to construct. - :return: a list of class instances. - """ - - parsed_items = [] - for key, val in dict_.items(): - parsed_items.append(cls(key, val)) - return parsed_items - - -@dataclass -class WorkflowsProject: - """Represents a project that contains DAGs to load. - - Attributes: - :param package_name: the package name. - :param package: the observatory platform package, either a local path to a Python source package (editable type), - path to a sdist (sdist) or a PyPI package name and version (pypi). - :param package_type: the package type, editable, sdist, pypi. - dags_module: the Python import path to the module that contains the Apache Airflow DAGs to load. - """ - - package_name: str - package: str - package_type: str - dags_module: str - - @staticmethod - def parse_workflows_projects(list: List) -> List[WorkflowsProject]: - """Parse the workflows_projects list object into a list of WorkflowsProject instances. - - :param list: the list. - :return: a list of WorkflowsProject instances. - """ - - parsed_items = [] - for item in list: - package_name = item["package_name"] - package = item["package"] - package_type = item["package_type"] - dags_module = item["dags_module"] - parsed_items.append(WorkflowsProject(package_name, package, package_type, dags_module)) - return parsed_items - - -@dataclass -class Terraform: - """The Terraform settings for the Observatory Platform. - - Attributes: - organization: the Terraform Organisation name. - """ - - organization: str - - @staticmethod - def from_dict(dict_: Dict) -> Terraform: - """Constructs a Terraform instance from a dictionary. - - :param dict_: the dictionary. - """ - - organization = dict_.get("organization") - return Terraform(organization) - - -@dataclass -class CloudSqlDatabase: - """The Google Cloud SQL database settings for the Observatory Platform. - - Attributes: - tier: the database machine tier. - backup_start_time: the start time for backups in HH:MM format. - """ - - tier: str - backup_start_time: str - - def to_hcl(self): - return to_hcl({"tier": self.tier, "backup_start_time": self.backup_start_time}) - - @staticmethod - def from_dict(dict_: Dict) -> CloudSqlDatabase: - """Constructs a CloudSqlDatabase instance from a dictionary. - - :param dict_: the dictionary. - :return: the CloudSqlDatabase instance. - """ - - tier = dict_.get("tier") - backup_start_time = dict_.get("backup_start_time") - return CloudSqlDatabase(tier, backup_start_time) - - -@dataclass -class VirtualMachine: - """A Google Cloud virtual machine. - - Attributes: - machine_type: the type of Google Cloud virtual machine. - disk_size: the size of the disk in GB. - disk_type: the disk type; pd-standard or pd-ssd. - create: whether to create the VM or not. - """ - - machine_type: str - disk_size: int - disk_type: str - create: bool - - def to_hcl(self): - return to_hcl( - { - "machine_type": self.machine_type, - "disk_size": self.disk_size, - "disk_type": self.disk_type, - "create": self.create, - } - ) - - @staticmethod - def from_hcl(string: str) -> VirtualMachine: - return VirtualMachine.from_dict(from_hcl(string)) - - @staticmethod - def from_dict(dict_: Dict) -> VirtualMachine: - """Constructs a VirtualMachine instance from a dictionary. - - :param dict_: the dictionary. - :return: the VirtualMachine instance. 
- """ - - machine_type = dict_.get("machine_type") - disk_size = dict_.get("disk_size") - disk_type = dict_.get("disk_type") - create = str(dict_.get("create")).lower() == "true" - return VirtualMachine(machine_type, disk_size, disk_type, create) - - -def is_base64(text: bytes) -> bool: - """Check if the string is base64. - :param text: Text to check. - :return: Whether it is base64. - """ - - try: - base64.decodebytes(text) - except: - return False - - return True - - -def is_secret_key(key: str) -> Tuple[bool, Union[str, None]]: - """Check if the Airflow Flask webserver secret key is valid. - :param key: Key to check. - :return: Validity, and an error message if not valid. - """ - - key_bytes = bytes(key, "utf-8") - message = None - - key_length = len(key_bytes) - if key_length < 16: - message = f"Secret key should be length >=16, but is length {key_length}." - return False, message - - return True, message - - -def is_fernet_key(key: str) -> Tuple[bool, Union[str, None]]: - """Check if the Fernet key is valid. - :param key: Key to check. - :return: Validity, and an error message if not valid. - """ - - key_bytes = bytes(key, "utf-8") - - try: - decoded_key = base64.urlsafe_b64decode(key_bytes) - except: - message = f"Key {key} could not be urlsafe b64decoded." - return False, message - - key_length = len(decoded_key) - if key_length != 32: - message = f"Decoded Fernet key should be length 32, but is length {key_length}." - return False, message - - message = None - return True, message - - -def check_schema_field_fernet_key(field: str, value: str, error: Callable): - """ - :param field: Field name. - :param value: Field value. - :param error: Error handler passed in by Cerberus. - """ - - valid, message = is_fernet_key(value) - - if not valid: - error(field, f"is not a valid Fernet key. Reason: {message}") - - -def check_schema_field_secret_key(field: str, value: str, error: Callable): - """ - :param field: Field name. - :param value: Field value. - :param error: Error handler passed in by Cerberus. - """ - - valid, message = is_secret_key(value) - - if not valid: - error(field, f"is not a valid secret key. Reason: {message}") - - -def customise_pointer(field, value, error): - """Throw an error when a field contains the value ' <--' which means that the user should customise the - value in the config file. - - :param field: the field. - :param value: the value. - :param error: ? - :return: None. - """ - - if isinstance(value, str) and value.endswith(" <--"): - error(field, "Customise value ending with ' <--'") - - -class ObservatoryConfigValidator(Validator): - """Custom config Validator""" - - def _validate_google_application_credentials(self, google_application_credentials, field, value): - """Validate that the Google Application Credentials file exists. - The rule's arguments are validated against this schema: {'type': 'boolean'} - """ - if ( - google_application_credentials - and value is not None - and isinstance(value, str) - and not os.path.isfile(value) - ): - self._error( - field, - f"the file {value} does not exist. See " - f"https://cloud.google.com/docs/authentication/getting-started for instructions on " - f"how to create a service account and save the JSON key to your workstation.", - ) - - -@dataclass -class ValidationError: - """A validation error found when parsing a config file. - - Attributes: - key: the key in the config file. - value: the error. 
- """ - - key: str - value: Any - - -class ObservatoryConfig: - def __init__( - self, - backend: Backend = None, - observatory: Observatory = None, - google_cloud: GoogleCloud = None, - terraform: Terraform = None, - cloud_workspaces: List[CloudWorkspace] = None, - workflows: List[Workflow] = None, - workflows_projects: List[WorkflowsProject] = None, - validator: ObservatoryConfigValidator = None, - ): - """Create an ObservatoryConfig instance. - - :param backend: the backend config. - :param observatory: the Observatory config. - :param google_cloud: the Google Cloud config. - :param terraform: the Terraform config. - :param cloud_workspaces: the CloudWorkspaces. - :param workflows: the workflows to create in Airflow. - :param workflows_projects: a list of DAGs projects. - :param validator: an ObservatoryConfigValidator instance. - """ - - self.backend = backend if backend is not None else Backend() - self.observatory = observatory if observatory is not None else Observatory() - self.google_cloud = google_cloud - self.terraform = terraform - - self.cloud_workspaces = cloud_workspaces - if cloud_workspaces is None: - self.cloud_workspaces = [] - - self.workflows = workflows - if workflows is None: - self.workflows = [] - - self.workflows_projects = workflows_projects - if workflows_projects is None: - self.workflows_projects = [] - - self.validator = validator - - self.schema = make_schema(self.backend.type) - - @property - def is_valid(self) -> bool: - """Checks whether the config is valid or not. - - :return: whether the config is valid or not. - """ - - return self.validator is None or not len(self.validator._errors) - - @property - def errors(self) -> List[ValidationError]: - """Returns a list of ValidationError instances that were created when parsing the config file. - - :return: the list of ValidationError instances. - """ - - errors = [] - for key, values in self.validator.errors.items(): - for value in values: - if type(value) is dict: - for nested_key, nested_value in value.items(): - errors.append(ValidationError(f"{key}.{nested_key}", *nested_value)) - else: - errors.append(ValidationError(key, *values)) - - return errors - - @property - def python_packages(self) -> List[PythonPackage]: - """Make a list of Python Packages to build or include in the observatory. - :return: the list of Python packages. - """ - - # observatory-api should be installed first so that observatory-platform install doesn't try to install - # observatory-api from PyPI - packages = [ - PythonPackage( - name="observatory-api", - type=self.observatory.api_package_type, - host_package=self.observatory.api_package, - docker_package=os.path.basename(self.observatory.api_package), - ), - PythonPackage( - name="observatory-platform", - type=self.observatory.package_type, - host_package=self.observatory.package, - docker_package=os.path.basename(self.observatory.package), - ), - ] - - for project in self.workflows_projects: - packages.append( - PythonPackage( - name=project.package_name, - type=project.package_type, - host_package=project.package, - docker_package=os.path.basename(project.package), - ) - ) - - return packages - - @property - def airflow_var_workflows(self) -> str: - """Make the workflows Airflow variable. - :return: the workflows Airflow variable. - """ - - return workflows_to_json_string(self.workflows) - - @property - def airflow_var_dags_module_names(self): - """Returns a list of DAG project module names. - :return: the list of DAG project module names. 
- """ - - return json.dumps([project.dags_module for project in self.workflows_projects]) - - @staticmethod - def _parse_fields( - dict_: Dict, - ) -> Tuple[ - Backend, - Observatory, - GoogleCloud, - Terraform, - List[CloudWorkspace], - List[Workflow], - List[WorkflowsProject], - ]: - backend = Backend.from_dict(dict_.get("backend", dict())) - observatory = Observatory.from_dict(dict_.get("observatory", dict())) - google_cloud = GoogleCloud.from_dict(dict_.get("google_cloud", dict())) - terraform = Terraform.from_dict(dict_.get("terraform", dict())) - cloud_workspaces = CloudWorkspace.parse_cloud_workspaces(dict_.get("cloud_workspaces", list())) - workflows = Workflow.parse_workflows(dict_.get("workflows", list())) - workflows_projects = WorkflowsProject.parse_workflows_projects(dict_.get("workflows_projects", list())) - - return backend, observatory, google_cloud, terraform, cloud_workspaces, workflows, workflows_projects - - @classmethod - def from_dict(cls, dict_: Dict) -> ObservatoryConfig: - """Constructs an ObservatoryConfig instance from a dictionary. - - If the dictionary is invalid, then an ObservatoryConfig instance will be returned with no properties set, - except for the validator, which contains validation errors. - - :param dict_: the input dictionary. - :return: the ObservatoryConfig instance. - """ - - schema = make_schema(BackendType.local) - validator = ObservatoryConfigValidator() - is_valid = validator.validate(dict_, schema) - - if is_valid: - ( - backend, - observatory, - google_cloud, - terraform, - cloud_workspaces, - workflows, - workflows_projects, - ) = ObservatoryConfig._parse_fields(dict_) - - return ObservatoryConfig( - backend, - observatory, - google_cloud=google_cloud, - terraform=terraform, - cloud_workspaces=cloud_workspaces, - workflows=workflows, - workflows_projects=workflows_projects, - validator=validator, - ) - else: - return ObservatoryConfig(validator=validator) - - @classmethod - def load(cls, path: str): - """Load a configuration file. - - :return: the ObservatoryConfig instance (or a subclass of ObservatoryConfig) - """ - - # Make sure that dates and datetimes are returned as pendulum.DateTime instances - # We let yaml parse the date / datetime and then convert to pendulum - def date_constructor(loader, node): - value = SafeConstructor.construct_yaml_timestamp(loader, node) - if isinstance(value, datetime.date) and not isinstance(value, datetime.datetime): - value = datetime.datetime(value.year, value.month, value.day) - return pendulum.instance(value) - - yaml.SafeLoader.add_constructor("tag:yaml.org,2002:timestamp", date_constructor) - - dict_ = dict() - try: - with open(path, "r") as f: - dict_ = yaml.safe_load(f) - except yaml.YAMLError: - print(f"Error parsing {path}") - except FileNotFoundError: - print(f"No such file or directory: {path}") - except cerberus.validator.DocumentError as e: - print(f"cerberus.validator.DocumentError: {e}") - - return cls.from_dict(dict_) - - def get_requirement_string(self, section: str) -> str: - """Query the schema to see whether a section is required. - - :param section: Section to query. - :return: String indicating whether the section is required or optional. - """ - - if self.schema[section]["required"]: - return "Required" - - return "Optional" - - def save(self, path: str): - """Save the observatory configuration parameters to a config file. - - :param path: Configuration file path. 
- """ - - with open(path, "w") as f: - self.save_backend(f) - self.save_observatory(f) - self.save_terraform(f) - self.save_google_cloud(f) - self.save_workflows_projects(f) - - def save_backend(self, f: TextIO): - """Write the backend configuration section to the config file. - - :param f: File object for the config file. - """ - - requirement = self.get_requirement_string("backend") - f.write( - ( - f"# [{requirement}] Backend settings.\n" - "# Backend options are: local, terraform.\n" - "# Environment options are: develop, staging, production.\n" - ) - ) - lines = ObserveratoryConfigString.backend(self.backend) - f.writelines(lines) - f.write("\n") - - def save_observatory(self, f: TextIO): - """Write the Observatory configuration section to the config file. - - :param f: File object for the config file. - """ - - requirement = self.get_requirement_string("observatory") - f.write( - ( - f"# [{requirement}] Observatory settings\n" - "# If you did not supply your own Fernet and secret keys, then those fields are autogenerated.\n" - "# Passwords are in plaintext.\n" - "# If the package type is editable, the 'package' should be the path location of your package\n" - "# If the package type is PyPi, the 'package' should be its name on PyPi\n" - "# observatory_home is where the observatory metadata is stored.\n" - ) - ) - - lines = ObserveratoryConfigString.observatory(self.observatory) - f.writelines(lines) - f.write("\n") - - def save_terraform(self, f: TextIO): - """Write the Terraform configuration section to the config file. - - :param f: File object for the config file. - """ - - requirement = self.get_requirement_string("terraform") - f.write(f"# [{requirement}] Terraform settings\n") - - lines = ObserveratoryConfigString.terraform(self.terraform) - output = map(comment, lines) if self.terraform is None and requirement == "Optional" else lines - - f.writelines(output) - f.write("\n") - - def save_google_cloud(self, f: TextIO): - """Write the Google Cloud configuration section to the config file. - - :param f: File object for the config file. - """ - - requirement = self.get_requirement_string("google_cloud") - f.write( - ( - f"# [{requirement}] Google Cloud settings\n" - "# If you use any Google Cloud service functions, you will need to configure this.\n" - ) - ) - - lines = ObserveratoryConfigString.google_cloud(google_cloud=self.google_cloud, backend=self.backend) - output = map(comment, lines) if self.google_cloud is None and requirement == "Optional" else lines - - f.writelines(output) - f.write("\n") - - def save_workflows_projects(self, f: TextIO): - """Write the DAGs projects configuration section to the config file. - - :param f: File object for the config file. 
- """ - - requirement = self.get_requirement_string("workflows_projects") - f.write(f"# [{requirement}] User defined Observatory DAGs projects:\n") - - lines = ObserveratoryConfigString.workflows_projects(workflows_projects=self.workflows_projects) - output = map(comment, lines) if len(self.workflows_projects) == 0 and requirement == "Optional" else lines - - f.writelines(output) - f.write("\n") - - -class TerraformConfig(ObservatoryConfig): - WORKSPACE_PREFIX = "observatory-" - - def __init__( - self, - backend: Backend = None, - observatory: Observatory = None, - google_cloud: GoogleCloud = None, - terraform: Terraform = None, - cloud_workspaces: List[CloudWorkspace] = None, - workflows: List[Workflow] = None, - workflows_projects: List[WorkflowsProject] = None, - cloud_sql_database: CloudSqlDatabase = None, - airflow_main_vm: VirtualMachine = None, - airflow_worker_vm: VirtualMachine = None, - validator: ObservatoryConfigValidator = None, - ): - """Create a TerraformConfig instance. - - :param backend: the backend config. - :param observatory: the Observatory config. - :param google_cloud: the Google Cloud config. - :param terraform: the Terraform config. - :param cloud_workspaces: the CloudWorkspaces. - :param workflows: the workflows to create in Airflow. - :param workflows_projects: a list of DAGs projects. - :param cloud_sql_database: a Google Cloud SQL database config. - :param airflow_main_vm: the Airflow Main VM config. - :param airflow_worker_vm: the Airflow Worker VM config. - :param validator: an ObservatoryConfigValidator instance. - """ - - if backend is None: - backend = Backend(type=BackendType.terraform) - - super().__init__( - backend=backend, - observatory=observatory, - google_cloud=google_cloud, - terraform=terraform, - cloud_workspaces=cloud_workspaces, - workflows=workflows, - workflows_projects=workflows_projects, - validator=validator, - ) - self.cloud_sql_database = cloud_sql_database - self.airflow_main_vm = airflow_main_vm - self.airflow_worker_vm = airflow_worker_vm - - @property - def terraform_workspace_id(self): - """The Terraform workspace id. - - :return: the terraform workspace id. - """ - - return TerraformConfig.WORKSPACE_PREFIX + self.backend.environment.value - - def terraform_variables(self) -> List[TerraformVariable]: - """Create a list of TerraformVariable instances from the Terraform Config. - - :return: a list of TerraformVariable instances. - """ - - return [ - TerraformVariable("environment", self.backend.environment.value), - TerraformVariable("observatory", self.observatory.to_hcl(), sensitive=True, hcl=True), - TerraformVariable("google_cloud", self.google_cloud.to_hcl(), sensitive=True, hcl=True), - TerraformVariable("cloud_sql_database", self.cloud_sql_database.to_hcl(), hcl=True), - TerraformVariable("airflow_main_vm", self.airflow_main_vm.to_hcl(), hcl=True), - TerraformVariable("airflow_worker_vm", self.airflow_worker_vm.to_hcl(), hcl=True), - TerraformVariable("airflow_var_workflows", self.airflow_var_workflows), - TerraformVariable("airflow_var_dags_module_names", self.airflow_var_dags_module_names), - ] - - @classmethod - def from_dict(cls, dict_: Dict) -> TerraformConfig: - """Make an TerraformConfig instance from a dictionary. - - If the dictionary is invalid, then an ObservatoryConfig instance will be returned with no properties set, - except for the validator, which contains validation errors. - - :param dict_: the input dictionary that has been read via yaml.safe_load. - :return: the TerraformConfig instance. 
- """ - - schema = make_schema(BackendType.terraform) - validator = ObservatoryConfigValidator() - is_valid = validator.validate(dict_, schema) - - if is_valid: - ( - backend, - observatory, - google_cloud, - terraform, - cloud_workspaces, - workflows, - workflows_projects, - ) = ObservatoryConfig._parse_fields(dict_) - - cloud_sql_database = CloudSqlDatabase.from_dict(dict_.get("cloud_sql_database", dict())) - airflow_main_vm = VirtualMachine.from_dict(dict_.get("airflow_main_vm", dict())) - airflow_worker_vm = VirtualMachine.from_dict(dict_.get("airflow_worker_vm", dict())) - - return TerraformConfig( - backend, - observatory, - google_cloud=google_cloud, - terraform=terraform, - cloud_workspaces=cloud_workspaces, - workflows=workflows, - workflows_projects=workflows_projects, - cloud_sql_database=cloud_sql_database, - airflow_main_vm=airflow_main_vm, - airflow_worker_vm=airflow_worker_vm, - validator=validator, - ) - else: - return TerraformConfig(validator=validator) - - def save(self, path: str): - """Save the configuration to a config file in YAML format. - - :param path: Config file path. - """ - - # Save common config - super().save(path) - - # Save Terraform specific sections - with open(path, "a") as f: - self.save_cloud_sql_database(f) - self.save_airflow_main_vm(f) - self.save_airflow_worker_vm(f) - - def save_cloud_sql_database(self, f: TextIO): - """Write the cloud SQL database configuration section to the config file. - - :param f: File object for the config file. - """ - - requirement = self.get_requirement_string("cloud_sql_database") - f.write(f"# [{requirement}] Google Cloud CloudSQL database settings\n") - lines = ObserveratoryConfigString.cloud_sql_database(self.cloud_sql_database) - f.writelines(lines) - f.write("\n") - - def save_airflow_main_vm(self, f: TextIO): - """Write the Airflow main VM configuration section to the config file. - - :param f: File object for the config file. - """ - - requirement = self.get_requirement_string("airflow_main_vm") - f.write(f"# [{requirement}] Settings for the main VM that runs the Airflow cheduler and webserver\n") - lines = ObserveratoryConfigString.airflow_main_vm(self.airflow_main_vm) - f.writelines(lines) - f.write("\n") - - def save_airflow_worker_vm(self, f: TextIO): - """Write the Airflow worker VM configuration section to the config file. - - :param f: File object for the config file. - """ - - requirement = self.get_requirement_string("airflow_worker_vm") - f.write(f"# [{requirement}] Settings for the weekly on-demand VM that runs arge tasks\n") - lines = ObserveratoryConfigString.airflow_worker_vm(self.airflow_worker_vm) - f.writelines(lines) - f.write("\n") - - -Config = Union[ObservatoryConfig, TerraformConfig] - - -def make_schema(backend_type: BackendType) -> Dict: - """Make a schema for an Observatory or Terraform config file. - - :param backend_type: the type of backend, local or terraform. - :return: a dictionary containing the schema. 
- """ - - schema = dict() - is_backend_terraform = backend_type == BackendType.terraform - - # Backend settings - schema["backend"] = { - "required": True, - "type": "dict", - "schema": { - "type": {"required": True, "type": "string", "allowed": [backend_type.value]}, - "environment": {"required": True, "type": "string", "allowed": ["develop", "staging", "production"]}, - }, - } - - # Terraform settings - schema["terraform"] = { - "required": is_backend_terraform, - "type": "dict", - "schema": {"organization": {"required": True, "type": "string", "check_with": customise_pointer}}, - } - - # Google Cloud settings - schema["google_cloud"] = { - "required": is_backend_terraform, - "type": "dict", - "schema": { - "project_id": {"required": is_backend_terraform, "type": "string", "check_with": customise_pointer}, - "credentials": { - "required": is_backend_terraform, - "type": "string", - "check_with": customise_pointer, - "google_application_credentials": True, - }, - "region": { - "required": is_backend_terraform, - "type": "string", - "regex": r"^\w+\-\w+\d+$", - "check_with": customise_pointer, - }, - "zone": { - "required": is_backend_terraform, - "type": "string", - "regex": r"^\w+\-\w+\d+\-[a-z]{1}$", - "check_with": customise_pointer, - }, - "data_location": {"required": is_backend_terraform, "type": "string", "check_with": customise_pointer}, - }, - } - - # Observatory settings - package_types = ["editable", "sdist", "pypi"] - schema["observatory"] = { - "required": True, - "type": "dict", - "schema": { - "package": {"required": True, "type": "string"}, - "package_type": {"required": True, "type": "string", "allowed": package_types}, - "airflow_fernet_key": {"required": True, "type": "string", "check_with": check_schema_field_fernet_key}, - "airflow_secret_key": {"required": True, "type": "string", "check_with": check_schema_field_secret_key}, - "airflow_ui_user_password": {"required": is_backend_terraform, "type": "string"}, - "airflow_ui_user_email": {"required": is_backend_terraform, "type": "string"}, - "observatory_home": {"required": False, "type": "string"}, - "postgres_password": {"required": is_backend_terraform, "type": "string"}, - "redis_port": {"required": False, "type": "integer"}, - "flower_ui_port": {"required": False, "type": "integer"}, - "airflow_ui_port": {"required": False, "type": "integer"}, - "docker_network_name": {"required": False, "type": "string"}, - "docker_network_is_external": {"required": False, "type": "boolean"}, - "docker_compose_project_name": {"required": False, "type": "string"}, - "api_package": {"required": False, "type": "string"}, - "api_package_type": {"required": False, "type": "string", "allowed": package_types}, - "api_port": {"required": False, "type": "integer"}, - }, - } - - # Database settings - if is_backend_terraform: - schema["cloud_sql_database"] = { - "required": True, - "type": "dict", - "schema": { - "tier": {"required": True, "type": "string"}, - "backup_start_time": {"required": True, "type": "string", "regex": r"^\d{2}:\d{2}$"}, - }, - } - - # VM schema - vm_schema = { - "required": True, - "type": "dict", - "schema": { - "machine_type": { - "required": True, - "type": "string", - }, - "disk_size": {"required": True, "type": "integer", "min": 1}, - "disk_type": {"required": True, "type": "string", "allowed": ["pd-standard", "pd-ssd"]}, - "create": {"required": True, "type": "boolean"}, - }, - } - - # Airflow main and worker VM - if is_backend_terraform: - schema["airflow_main_vm"] = vm_schema - schema["airflow_worker_vm"] = 
vm_schema - - # Workflow configuration - cloud_workspace_schema = { - "project_id": {"required": True, "type": "string"}, - "download_bucket": {"required": True, "type": "string"}, - "transform_bucket": {"required": True, "type": "string"}, - "data_location": {"required": True, "type": "string"}, - "output_project_id": {"required": False, "type": "string"}, - } - - schema["cloud_workspaces"] = { - "required": False, - "type": "list", - "schema": { - "type": "dict", - "schema": {"workspace": {"required": True, "type": "dict", "schema": cloud_workspace_schema}}, - }, - } - - schema["workflows"] = { - "required": False, - "dependencies": "cloud_workspaces", # cloud_workspaces must be specified when workflows are defined - "type": "list", - "schema": { - "type": "dict", - "schema": { - "dag_id": {"required": True, "type": "string"}, - "name": {"required": True, "type": "string"}, - "class_name": {"required": True, "type": "string"}, - "cloud_workspace": {"required": False, "type": "dict", "schema": cloud_workspace_schema}, - "kwargs": {"required": False, "type": "dict"}, - }, - }, - } - - schema["workflows_projects"] = { - "required": False, - "type": "list", - "schema": { - "type": "dict", - "schema": { - "package_name": { - "required": True, - "type": "string", - }, - "package": {"required": True, "type": "string"}, - "package_type": {"required": True, "type": "string", "allowed": package_types}, - "dags_module": { - "required": True, - "type": "string", - }, - }, - }, - } - - return schema - - -class ObserveratoryConfigString: - """This class contains methods to construct config file sections.""" - - @staticmethod - def backend(backend: Backend) -> List[str]: - """Constructs the backend section string. - - :param backend: Backend configuration object. - :return: List of strings for the section, including the section heading." - """ - - lines = [ - "backend:\n", - indent(f"type: {backend.type.name}\n", INDENT1), - indent(f"environment: {backend.environment.name}\n", INDENT1), - ] - - return lines - - @staticmethod - def observatory(observatory: Observatory) -> List[str]: - """Constructs the observatory section string. - - :param observatory: Observatory configuration object. - :return: List of strings for the section, including the section heading." 
- """ - - lines = [ - "observatory:\n", - indent(f"package: {observatory.package}\n", INDENT1), - indent(f"package_type: {observatory.package_type}\n", INDENT1), - indent(f"airflow_fernet_key: {observatory.airflow_fernet_key}\n", INDENT1), - indent(f"airflow_secret_key: {observatory.airflow_secret_key}\n", INDENT1), - indent(f"airflow_ui_user_email: {observatory.airflow_ui_user_email}\n", INDENT1), - indent(f"airflow_ui_user_password: {observatory.airflow_ui_user_password}\n", INDENT1), - indent(f"observatory_home: {observatory.observatory_home}\n", INDENT1), - indent(f"postgres_password: {observatory.postgres_password}\n", INDENT1), - indent(f"redis_port: {observatory.redis_port}\n", INDENT1), - indent(f"flower_ui_port: {observatory.flower_ui_port}\n", INDENT1), - indent(f"airflow_ui_port: {observatory.airflow_ui_port}\n", INDENT1), - indent(f"docker_network_name: {observatory.docker_network_name}\n", INDENT1), - indent(f"docker_network_is_external: {observatory.docker_network_is_external}\n", INDENT1), - indent(f"docker_compose_project_name: {observatory.docker_compose_project_name}\n", INDENT1), - indent(f"api_package: {observatory.api_package}\n", INDENT1), - indent(f"api_package_type: {observatory.api_package_type}\n", INDENT1), - indent(f"api_port: {observatory.api_port}\n", INDENT1), - ] - - return lines - - @staticmethod - def terraform(terraform: Terraform) -> List[str]: - """Constructs the terraform section string. - - :param observatory: Terraform configuration object. - :return: List of strings for the section, including the section heading." - """ - - if terraform is None: - terraform = Terraform(organization="my-terraform-org-name") - - lines = [ - "terraform:\n", - indent(f"organization: {terraform.organization}\n", INDENT1), - ] - - return lines - - @staticmethod - def google_cloud(*, google_cloud: GoogleCloud, backend: Backend) -> List[str]: - """Constructs the Google Cloud section string. - - :param google_cloud: Google Cloud configuration object. - :param backend: Backend configuration object. - :return: List of strings for the section, including the section heading." - """ - - if google_cloud is None: - google_cloud = GoogleCloud( - project_id="my-gcp-id", - credentials="/path/to/credentials.json", - data_location="us", - region="us-west1", - zone="us-west1-a", - ) - - lines = [ - "google_cloud:\n", - indent(f"project_id: {google_cloud.project_id}\n", INDENT1), - indent(f"credentials: {google_cloud.credentials}\n", INDENT1), - indent(f"data_location: {google_cloud.data_location}\n", INDENT1), - ] - - # Is region and zone something we should be putting in the local config too? - if backend.type == BackendType.terraform: - lines.append(indent(f"region: {google_cloud.region}\n", INDENT1)) - lines.append(indent(f"zone: {google_cloud.zone}\n", INDENT1)) - - return lines - - @staticmethod - def workflows_projects(*, workflows_projects: List[WorkflowsProject] = None) -> List[str]: - """Constructs the DAGs projects section string. - - :param workflows_projects: List of DAGs project configuration objects. - :return: List of strings for the section, including the section heading." 
- """ - - projects = workflows_projects.copy() - - if len(projects) == 0: - projects.append( - WorkflowsProject( - package_name="observatory-dags", - package="/path/to/dags_project", - package_type="editable", - dags_module="observatory.dags.dags", - ) - ) - - lines = ["workflows_projects:\n"] - for project in projects: - lines.append(indent(f"- package_name: {project.package_name}\n", INDENT1)) - lines.append(indent(f"package: {project.package}\n", INDENT3)) - lines.append(indent(f"package_type: {project.package_type}\n", INDENT3)) - lines.append(indent(f"dags_module: {project.dags_module}\n", INDENT3)) - - return lines - - @staticmethod - def cloud_sql_database(cloud_sql_database: CloudSqlDatabase) -> List[str]: - """Constructs the cloud SQL database section string. - - :param cloud_sql_database: Cloud SQL configuration object. - :return: List of strings for the section, including the section heading." - """ - - if cloud_sql_database is None: - cloud_sql_database = CloudSqlDatabase( - tier="db-custom-2-7680", - backup_start_time="23:00", - ) - - lines = [ - "cloud_sql_database:\n", - indent(f"tier: {cloud_sql_database.tier}\n", INDENT1), - indent(f"backup_start_time: '{cloud_sql_database.backup_start_time}'\n", INDENT1), - ] - - return lines - - @staticmethod - def airflow_vm_lines_(*, vm: VirtualMachine, vm_type) -> List[str]: - """Constructs the virtual machine section string. - - :param vm: Virtual machine configuration object. - :param vm_type: Type of vm being configured. - :return: List of strings for the section, including the section heading." - """ - lines = [ - f"{vm_type}:\n", - indent(f"machine_type: {vm.machine_type}\n", INDENT1), - indent(f"disk_size: {vm.disk_size}\n", INDENT1), - indent(f"disk_type: {vm.disk_type}\n", INDENT1), - indent(f"create: {vm.create}\n", INDENT1), - ] - - return lines - - @staticmethod - def airflow_main_vm(vm: VirtualMachine) -> List[str]: - """Constructs the main virtual machine section string. - - :param vm: Virtual machine configuration object. - :return: List of strings for the section, including the section heading." - """ - - if vm is None: - vm = VirtualMachine( - machine_type="n2-standard-2", - disk_size=50, - disk_type="pd-ssd", - create=True, - ) - - lines = ObserveratoryConfigString.airflow_vm_lines_(vm=vm, vm_type="airflow_main_vm") - return lines - - @staticmethod - def airflow_worker_vm(vm: VirtualMachine) -> List[str]: - """Constructs the worker virtual machine section string. - - :param vm: Virtual machine configuration object. - :return: List of strings for the section, including the section heading." - """ - - if vm is None: - vm = VirtualMachine( - machine_type="n1-standard-8", - disk_size=3000, - disk_type="pd-standard", - create=False, - ) - - lines = ObserveratoryConfigString.airflow_vm_lines_(vm=vm, vm_type="airflow_worker_vm") - return lines diff --git a/observatory-platform/observatory/platform/observatory_environment.py b/observatory-platform/observatory/platform/observatory_environment.py deleted file mode 100644 index 2477b4e61..000000000 --- a/observatory-platform/observatory/platform/observatory_environment.py +++ /dev/null @@ -1,1339 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. 
You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# -# Sources: -# * https://github.com/apache/airflow/blob/ffb472cf9e630bd70f51b74b0d0ea4ab98635572/airflow/cli/commands/task_command.py -# * https://github.com/apache/airflow/blob/master/docs/apache-airflow/best-practices.rst - -# Copyright (c) 2011-2017 Ruslan Spivak -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in -# all copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -# THE SOFTWARE. -# -# Sources: -# * https://github.com/rspivak/sftpserver/blob/master/src/sftpserver/__init__.py - -# Copyright 2021 Curtin University -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-
-# Author: James Diprose
-
-import contextlib
-import datetime
-import json
-import logging
-import os
-import shutil
-import socket
-import socketserver
-import threading
-import time
-import unittest
-import uuid
-from dataclasses import dataclass
-from datetime import datetime, timedelta
-from http.server import SimpleHTTPRequestHandler, ThreadingHTTPServer
-from multiprocessing import Process
-from typing import Dict, List, Optional, Set, Union
-
-import boto3
-import croniter
-import google
-import httpretty
-import paramiko
-import pendulum
-import requests
-from airflow import DAG, settings
-from airflow.exceptions import AirflowException
-from airflow.models import DagBag
-from airflow.models.connection import Connection
-from airflow.models.dagrun import DagRun
-from airflow.models.taskinstance import TaskInstance
-from airflow.models.variable import Variable
-from airflow.operators.empty import EmptyOperator
-from airflow.utils import db
-from airflow.utils.state import State
-from airflow.utils.types import DagRunType
-from click.testing import CliRunner
-from cryptography.hazmat.backends import default_backend
-from cryptography.hazmat.primitives import serialization
-from cryptography.hazmat.primitives.asymmetric import rsa
-from dateutil.relativedelta import relativedelta
-from deepdiff import DeepDiff
-from google.cloud import bigquery, storage
-from google.cloud.exceptions import NotFound
-from pendulum import DateTime
-from pyftpdlib.authorizers import DummyAuthorizer
-from pyftpdlib.handlers import FTPHandler
-from pyftpdlib.servers import ThreadedFTPServer
-from sftpserver.stub_sftp import StubServer, StubSFTPServer
-
-from observatory.api.testing import ObservatoryApiEnvironment
-from observatory.platform.bigquery import bq_create_dataset
-from observatory.platform.bigquery import (
- bq_sharded_table_id,
- bq_load_table,
- bq_table_id,
- SourceFormat,
- bq_delete_old_datasets_with_prefix,
-)
-from observatory.platform.config import module_file_path, AirflowVars, AirflowConns
-from observatory.platform.files import crc32c_base64_hash, get_file_hash, gzip_file_crc, save_jsonl_gz
-from observatory.platform.gcs import (
- gcs_blob_uri,
- gcs_upload_files,
- gcs_delete_old_buckets_with_prefix,
-)
-from observatory.platform.observatory_config import Workflow, CloudWorkspace, workflows_to_json_string
-
-
-def random_id():
- """Generate a random id for a bucket name.
-
- When code is pushed to a branch and a pull request is open, GitHub Actions runs the unit tests workflow
- twice: once for the push and once for the pull request. However, the uuid4 function, which calls os.urandom(16),
- generates the same sequence of values for each workflow run. We have also used the hostname of the machine
- in the construction of the random id to ensure that the ids are different on both workflow runs.
-
- :return: a random string id.
- """
- return str(uuid.uuid5(uuid.uuid4(), socket.gethostname())).replace("-", "")
-
-
-def test_fixtures_path(*subdirs) -> str:
- """Get the path to the Observatory Platform test data directory.
-
- :return: the Observatory Platform test data directory.
- """
-
- base_path = module_file_path("tests.fixtures")
- return os.path.join(base_path, *subdirs)
-
-
-def find_free_port(host: str = "localhost") -> int:
- """Find a free port.
-
- :param host: the host.
- :return: the free port number
- """
-
- with socketserver.TCPServer((host, 0), None) as tcp_server:
- return tcp_server.server_address[1]
-
-
-def save_empty_file(path: str, file_name: str) -> str:
- """Save an empty file and return its path.
-
- :param path: the file directory.
- :param file_name: the file name.
- :return: the full file path.
- """
-
- file_path = os.path.join(path, file_name)
- open(file_path, "a").close()
-
- return file_path
-
-
-@contextlib.contextmanager
-def bq_dataset_test_env(*, project_id: str, location: str, prefix: str):
- client = bigquery.Client()
- dataset_id = prefix + "_" + random_id()
- try:
- bq_create_dataset(project_id=project_id, dataset_id=dataset_id, location=location)
- yield dataset_id
- finally:
- client.delete_dataset(dataset_id, delete_contents=True, not_found_ok=True)
-
-
-@contextlib.contextmanager
-def aws_bucket_test_env(*, prefix: str, region_name: str, expiration_days=1) -> str:
- # Create an S3 client
- s3 = boto3.Session().client("s3", region_name=region_name)
- bucket_name = f"obs-test-{prefix}-{random_id()}"
- try:
- s3.create_bucket(Bucket=bucket_name) # CreateBucketConfiguration={"LocationConstraint": region_name}
- # Set up the lifecycle configuration
- lifecycle_configuration = {
- "Rules": [
- {"ID": "ExpireObjects", "Status": "Enabled", "Filter": {}, "Expiration": {"Days": expiration_days}}
- ]
- }
- # Apply the lifecycle configuration to the bucket
- s3.put_bucket_lifecycle_configuration(Bucket=bucket_name, LifecycleConfiguration=lifecycle_configuration)
- yield bucket_name
- except Exception as e:
- raise e
- finally:
- # Get a reference to the bucket
- s3_resource = boto3.Session().resource("s3")
- bucket = s3_resource.Bucket(bucket_name)
-
- # Delete all objects and versions in the bucket
- bucket.objects.all().delete()
- bucket.object_versions.all().delete()
-
- # Delete the bucket
- bucket.delete()
-
- print(f"Bucket {bucket_name} deleted")
-
-
-class ObservatoryEnvironment:
- OBSERVATORY_HOME_KEY = "OBSERVATORY_HOME"
-
- def __init__(
- self,
- project_id: str = None,
- data_location: str = None,
- api_host: str = "localhost",
- api_port: int = 5000,
- enable_api: bool = True,
- prefix: Optional[str] = "obsenv_tests",
- age_to_delete: int = 12,
- workflows: List[Workflow] = None,
- gcs_bucket_roles: Union[Set[str], str] = None,
- ):
- """Constructor for an Observatory environment.
-
- To create an Observatory environment:
- env = ObservatoryEnvironment()
- with env.create():
- pass
-
- :param project_id: the Google Cloud project id.
- :param data_location: the Google Cloud data location.
- :param api_host: the Observatory API host.
- :param api_port: the Observatory API port.
- :param enable_api: whether to enable the Observatory API or not.
- :param prefix: prefix for buckets and datasets created for the testing environment.
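The helpers above (find_free_port, random_id, save_empty_file) are the building blocks for throwaway test resources; a small sketch, assuming the pre-refactor import path `observatory.platform.observatory_environment` and an illustrative file and bucket name:

```python
import os
import tempfile

from observatory.platform.observatory_environment import find_free_port, random_id, save_empty_file

port = find_free_port()  # a port that was free at the time of the call
bucket_name = f"obsenv_tests_{random_id()}"  # unique, prefixed name for a throwaway bucket

with tempfile.TemporaryDirectory() as tmp_dir:
    file_path = save_empty_file(tmp_dir, "placeholder.txt")
    assert os.path.isfile(file_path)

print(port, bucket_name)
```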
- :param age_to_delete: age of buckets and datasets to delete that share the same prefix, in hours - """ - - self.project_id = project_id - self.data_location = data_location - self.api_host = api_host - self.api_port = api_port - self.buckets = {} - self.datasets = [] - self.data_path = None - self.session = None - self.temp_dir = None - self.api_env = None - self.api_session = None - self.enable_api = enable_api - self.dag_run: DagRun = None - self.prefix = prefix - self.age_to_delete = age_to_delete - self.workflows = workflows - - if self.create_gcp_env: - self.download_bucket = self.add_bucket(roles=gcs_bucket_roles) - self.transform_bucket = self.add_bucket(roles=gcs_bucket_roles) - self.storage_client = storage.Client() - self.bigquery_client = bigquery.Client() - else: - self.download_bucket = None - self.transform_bucket = None - self.storage_client = None - self.bigquery_client = None - - @property - def cloud_workspace(self) -> CloudWorkspace: - return CloudWorkspace( - project_id=self.project_id, - download_bucket=self.download_bucket, - transform_bucket=self.transform_bucket, - data_location=self.data_location, - ) - - @property - def create_gcp_env(self) -> bool: - """Whether to create the Google Cloud project environment. - - :return: whether to create Google Cloud project environ,ent - """ - - return self.project_id is not None and self.data_location is not None - - def assert_gcp_dependencies(self): - """Assert that the Google Cloud project dependencies are met. - - :return: None. - """ - - assert self.create_gcp_env, "Please specify the Google Cloud project_id and data_location" - - def add_bucket(self, prefix: Optional[str] = None, roles: Optional[Union[Set[str], str]] = None) -> str: - """Add a Google Cloud Storage Bucket to the Observatory environment. - - The bucket will be created when create() is called and deleted when the Observatory - environment is closed. - - :param prefix: an optional additional prefix for the bucket. - :return: returns the bucket name. - """ - - self.assert_gcp_dependencies() - parts = [] - if self.prefix: - parts.append(self.prefix) - if prefix: - parts.append(prefix) - parts.append(random_id()) - bucket_name = "_".join(parts) - - if len(bucket_name) > 63: - raise Exception(f"Bucket name cannot be longer than 63 characters: {bucket_name}") - else: - self.buckets[bucket_name] = roles - - return bucket_name - - def _create_bucket(self, bucket_id: str, roles: Optional[Union[str, Set[str]]] = None) -> None: - """Create a Google Cloud Storage Bucket. - - :param bucket_id: the bucket identifier. - :param roles: Create bucket with custom roles if required. - :return: None. - """ - - self.assert_gcp_dependencies() - bucket = self.storage_client.create_bucket(bucket_id, location=self.data_location) - logging.info(f"Created bucket with name: {bucket_id}") - - if roles: - roles = set(roles) if isinstance(roles, str) else roles - - # Get policy of bucket and add roles. - policy = bucket.get_iam_policy() - for role in roles: - policy.bindings.append({"role": role, "members": {"allUsers"}}) - bucket.set_iam_policy(policy) - logging.info(f"Added permission {role} to bucket {bucket_id} for allUsers.") - - def _create_dataset(self, dataset_id: str) -> None: - """Create a BigQuery dataset. - - :param dataset_id: the dataset identifier. - :return: None. 
- """ - - self.assert_gcp_dependencies() - dataset = bigquery.Dataset(f"{self.project_id}.{dataset_id}") - dataset.location = self.data_location - self.bigquery_client.create_dataset(dataset, exists_ok=True) - logging.info(f"Created dataset with name: {dataset_id}") - - def _delete_bucket(self, bucket_id: str) -> None: - """Delete a Google Cloud Storage Bucket. - - :param bucket_id: the bucket identifier. - :return: None. - """ - - self.assert_gcp_dependencies() - try: - bucket = self.storage_client.get_bucket(bucket_id) - bucket.delete(force=True) - except requests.exceptions.ReadTimeout: - pass - except google.api_core.exceptions.NotFound: - logging.warning( - f"Bucket {bucket_id} not found. Did you mean to call _delete_bucket on the same bucket twice?" - ) - - def add_dataset(self, prefix: Optional[str] = None) -> str: - """Add a BigQuery dataset to the Observatory environment. - - The BigQuery dataset will be deleted when the Observatory environment is closed. - - :param prefix: an optional additional prefix for the dataset. - :return: the BigQuery dataset identifier. - """ - - self.assert_gcp_dependencies() - parts = [] - if self.prefix: - parts.append(self.prefix) - if prefix: - parts.append(prefix) - parts.append(random_id()) - dataset_id = "_".join(parts) - self.datasets.append(dataset_id) - return dataset_id - - def _delete_dataset(self, dataset_id: str) -> None: - """Delete a BigQuery dataset. - - :param dataset_id: the BigQuery dataset identifier. - :return: None. - """ - - self.assert_gcp_dependencies() - try: - self.bigquery_client.delete_dataset(dataset_id, not_found_ok=True, delete_contents=True) - except requests.exceptions.ReadTimeout: - pass - - def add_variable(self, var: Variable) -> None: - """Add an Airflow variable to the Observatory environment. - - :param var: the Airflow variable. - :return: None. - """ - - self.session.add(var) - self.session.commit() - - def add_connection(self, conn: Connection): - """Add an Airflow connection to the Observatory environment. - - :param conn: the Airflow connection. - :return: None. - """ - - self.session.add(conn) - self.session.commit() - - def run_task(self, task_id: str) -> TaskInstance: - """Run an Airflow task. - - :param task_id: the Airflow task identifier. - :return: None. - """ - - assert self.dag_run is not None, "with create_dag_run must be called before run_task" - - dag = self.dag_run.dag - run_id = self.dag_run.run_id - task = dag.get_task(task_id=task_id) - ti = TaskInstance(task, run_id=run_id) - ti.refresh_from_db() - ti.run(ignore_ti_state=True) - - return ti - - def get_task_instance(self, task_id: str) -> TaskInstance: - """Get an up-to-date TaskInstance. - - :param task_id: the task id. - :return: up-to-date TaskInstance instance. - """ - - assert self.dag_run is not None, "with create_dag_run must be called before get_task_instance" - - run_id = self.dag_run.run_id - task = self.dag_run.dag.get_task(task_id=task_id) - ti = TaskInstance(task, run_id=run_id) - ti.refresh_from_db() - return ti - - @contextlib.contextmanager - def create_dag_run( - self, - dag: DAG, - execution_date: pendulum.DateTime, - run_type: DagRunType = DagRunType.SCHEDULED, - ): - """Create a DagRun that can be used when running tasks. - During cleanup the DAG run state is updated. - - :param dag: the Airflow DAG instance. - :param execution_date: the execution date of the DAG. - :param run_type: what run_type to use when running the DAG run. - :return: None. 
- """ - - # Get start date, which is one schedule interval after execution date - if isinstance(dag.normalized_schedule_interval, (timedelta, relativedelta)): - start_date = ( - datetime.fromtimestamp(execution_date.timestamp(), pendulum.tz.UTC) + dag.normalized_schedule_interval - ) - else: - start_date = croniter.croniter(dag.normalized_schedule_interval, execution_date).get_next(pendulum.DateTime) - - try: - self.dag_run = dag.create_dagrun( - state=State.RUNNING, - execution_date=execution_date, - start_date=start_date, - run_type=run_type, - ) - yield self.dag_run - finally: - self.dag_run.update_state() - - @contextlib.contextmanager - def create(self, task_logging: bool = False): - """Make and destroy an Observatory isolated environment, which involves: - - * Creating a temporary directory. - * Setting the OBSERVATORY_HOME environment variable. - * Initialising a temporary Airflow database. - * Creating download and transform Google Cloud Storage buckets. - * Creating default Airflow Variables: AirflowVars.DATA_PATH, - AirflowVars.DOWNLOAD_BUCKET and AirflowVars.TRANSFORM_BUCKET. - * Cleaning up all resources when the environment is closed. - - :param task_logging: display airflow task logging - :yield: Observatory environment temporary directory. - """ - - with CliRunner().isolated_filesystem() as temp_dir: - # Set temporary directory - self.temp_dir = temp_dir - - # Prepare environment - self.new_env = {self.OBSERVATORY_HOME_KEY: os.path.join(self.temp_dir, ".observatory")} - prev_env = dict(os.environ) - - try: - # Update environment - os.environ.update(self.new_env) - - # Create Airflow SQLite database - settings.DAGS_FOLDER = os.path.join(self.temp_dir, "airflow", "dags") - os.makedirs(settings.DAGS_FOLDER, exist_ok=True) - airflow_db_path = os.path.join(self.temp_dir, "airflow.db") - settings.SQL_ALCHEMY_CONN = f"sqlite:///{airflow_db_path}" - logging.info(f"SQL_ALCHEMY_CONN: {settings.SQL_ALCHEMY_CONN}") - settings.configure_orm(disable_connection_pool=True) - self.session = settings.Session - db.initdb() - - # Setup Airflow task logging - original_log_level = logging.getLogger().getEffectiveLevel() - if task_logging: - # Set root logger to INFO level, it seems that custom 'logging.info()' statements inside a task - # come from root - logging.getLogger().setLevel(20) - # Propagate logging so it is displayed - logging.getLogger("airflow.task").propagate = True - - # Create buckets and datasets - if self.create_gcp_env: - for bucket_id, roles in self.buckets.items(): - self._create_bucket(bucket_id, roles) - - for dataset_id in self.datasets: - self._create_dataset(dataset_id) - - # Deletes old test buckets and datasets from the project thats older than 2 hours. 
- gcs_delete_old_buckets_with_prefix(prefix=self.prefix, age_to_delete=self.age_to_delete) - bq_delete_old_datasets_with_prefix(prefix=self.prefix, age_to_delete=self.age_to_delete) - - # Add default Airflow variables - self.data_path = os.path.join(self.temp_dir, "data") - self.add_variable(Variable(key=AirflowVars.DATA_PATH, val=self.data_path)) - - if self.workflows is not None: - var = workflows_to_json_string(self.workflows) - self.add_variable(Variable(key=AirflowVars.WORKFLOWS, val=var)) - - # Reset dag run - self.dag_run: DagRun = None - - # Create ObservatoryApiEnvironment - if self.enable_api: - # Add Observatory API connection - conn = Connection( - conn_id=AirflowConns.OBSERVATORY_API, uri=f"http://:@{self.api_host}:{self.api_port}" - ) - self.add_connection(conn) - - # Create API environment - self.api_env = ObservatoryApiEnvironment(host=self.api_host, port=self.api_port) - with self.api_env.create(): - self.api_session = self.api_env.session - yield self.temp_dir - else: - yield self.temp_dir - finally: - # Set logger settings back to original settings - logging.getLogger().setLevel(original_log_level) - logging.getLogger("airflow.task").propagate = False - - # Revert environment - os.environ.clear() - os.environ.update(prev_env) - - if self.create_gcp_env: - # Remove Google Cloud Storage buckets - for bucket_id, roles in self.buckets.items(): - self._delete_bucket(bucket_id) - - # Remove BigQuery datasets - for dataset_id in self.datasets: - self._delete_dataset(dataset_id) - - -def load_and_parse_json( - file_path: str, - date_fields: Set[str] = None, - timestamp_fields: Set[str] = None, - date_formats: Set[str] = None, - timestamp_formats: str = None, -): - """Load a JSON file for testing purposes. It parses string dates and datetimes into date and datetime instances. - - :param file_path: the path to the JSON file. - :param date_fields: The fields to parse as a date. - :param timestamp_fields: The fields to parse as a timestamp. - :param date_formats: The date formats to use. If none, will use [%Y-%m-%d, %Y%m%d]. - :param timestamp_formats: The timestamp formats to use. If none, will use [%Y-%m-%d %H:%M:%S.%f %Z]. 
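Tying the ObservatoryEnvironment methods above together, a typical test builds the sandboxed Airflow database, registers a DAG, and runs one of its tasks inside a DAG run. The sketch below uses a hypothetical DAG with a single EmptyOperator, disables the API server, omits project_id/data_location so no Google Cloud resources are touched, and assumes Airflow 2.4+ for the `schedule` argument:

```python
import pendulum
from airflow import DAG
from airflow.operators.empty import EmptyOperator

from observatory.platform.observatory_environment import ObservatoryEnvironment

env = ObservatoryEnvironment(enable_api=False)  # no project_id/data_location: GCP setup is skipped
with env.create(task_logging=True):
    with DAG(dag_id="hello_dag", start_date=pendulum.datetime(2023, 1, 1), schedule="@weekly") as dag:
        EmptyOperator(task_id="hello_task")

    with env.create_dag_run(dag, execution_date=pendulum.datetime(2023, 1, 1)):
        ti = env.run_task("hello_task")
        print(ti.state)  # expected to be "success"
```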
- """ - - if date_fields is None: - date_fields = set() - - if timestamp_fields is None: - timestamp_fields = set() - - if date_formats is None: - date_formats = {"%Y-%m-%d", "%Y%m%d"} - - if timestamp_formats is None: - timestamp_formats = {"%Y-%m-%d %H:%M:%S.%f %Z"} - - def parse_datetime(obj): - for key, value in obj.items(): - # Try to parse into a date or datetime - if key in date_fields: - if isinstance(value, str): - format_found = False - for format in date_formats: - try: - obj[key] = datetime.strptime(value, format).date() - format_found = True - break - except (ValueError, TypeError): - pass - if not format_found: - try: - dt = pendulum.parse(value) - dt = datetime( - dt.year, - dt.month, - dt.day, - dt.hour, - dt.minute, - dt.second, - dt.microsecond, - tzinfo=dt.tzinfo, - ).date() - obj[key] = dt - except (ValueError, TypeError): - pass - - if key in timestamp_fields: - if isinstance(value, str): - format_found = False - for format in timestamp_formats: - try: - obj[key] = datetime.strptime(value, format) - format_found = True - break - except (ValueError, TypeError): - pass - if not format_found: - try: - dt = pendulum.parse(value) - dt = datetime( - dt.year, - dt.month, - dt.day, - dt.hour, - dt.minute, - dt.second, - dt.microsecond, - tzinfo=dt.tzinfo, - ) - obj[key] = dt - except (ValueError, TypeError): - pass - - return obj - - with open(file_path, mode="r") as f: - rows = json.load(f, object_hook=parse_datetime) - return rows - - -def compare_lists_of_dicts(expected: List[Dict], actual: List[Dict], primary_key: str) -> bool: - """Compare two lists of dictionaries, using a primary_key as the basis for the top level comparisons. - - :param expected: the expected data. - :param actual: the actual data. - :param primary_key: the primary key. - :return: whether the expected and actual match. - """ - - expected_dict = {item[primary_key]: item for item in expected} - actual_dict = {item[primary_key]: item for item in actual} - - if set(expected_dict.keys()) != set(actual_dict.keys()): - logging.error("Primary keys don't match:") - logging.error(f"Only in expected: {set(expected_dict.keys()) - set(actual_dict.keys())}") - logging.error(f"Only in actual: {set(actual_dict.keys()) - set(expected_dict.keys())}") - return False - - all_matched = True - for key in expected_dict: - diff = DeepDiff(expected_dict[key], actual_dict[key], ignore_order=True) - logging.info(f"primary_key: {key}") - for diff_type, changes in diff.items(): - all_matched = False - log_diff(diff_type, changes) - - return all_matched - - -def log_diff(diff_type, changes): - """Log the DeepDiff changes. - - :param diff_type: the diff type. - :param changes: the changes. - :return: None. 
- """ - - if diff_type == "values_changed": - for key_path, change in changes.items(): - logging.error( - f"(expected) != (actual) {key_path}: {change['old_value']} (expected) != (actual) {change['new_value']}" - ) - elif diff_type == "dictionary_item_added": - for change in changes: - logging.error(f"dictionary_item_added: {change}") - elif diff_type == "dictionary_item_removed": - for change in changes: - logging.error(f"dictionary_item_removed: {change}") - elif diff_type == "type_changes": - for key_path, change in changes.items(): - logging.error( - f"(expected) != (actual) {key_path}: {change['old_type']} (expected) != (actual) {change['new_type']}" - ) - - -class ObservatoryTestCase(unittest.TestCase): - """Common test functions for testing Observatory Platform DAGs""" - - def __init__(self, *args, **kwargs): - """Constructor which sets up variables used by tests. - - :param args: arguments. - :param kwargs: keyword arguments. - """ - - super(ObservatoryTestCase, self).__init__(*args, **kwargs) - self.storage_client = storage.Client() - self.bigquery_client = bigquery.Client() - - # Turn logging to warning because vcr prints too much at info level - logging.basicConfig() - vcr_log = logging.getLogger("vcr") - vcr_log.setLevel(logging.WARNING) - - @property - def fake_cloud_workspace(self): - return CloudWorkspace( - project_id="project-id", - download_bucket="download_bucket", - transform_bucket="transform_bucket", - data_location="us", - ) - - def assert_dag_structure(self, expected: Dict, dag: DAG): - """Assert the DAG structure. - - :param expected: a dictionary of DAG task ids as keys and values which should be a list of downstream task ids. - :param dag: the DAG. - :return: None. - """ - - expected_keys = expected.keys() - actual_keys = dag.task_dict.keys() - self.assertEqual(expected_keys, actual_keys) - - for task_id, downstream_list in expected.items(): - self.assertTrue(dag.has_task(task_id)) - task = dag.get_task(task_id) - self.assertEqual(set(downstream_list), task.downstream_task_ids) - - def assert_dag_load(self, dag_id: str, dag_file: str): - """Assert that the given DAG loads from a DagBag. - - :param dag_id: the DAG id. - :param dag_file: the path to the DAG file. - :return: None. - """ - - with CliRunner().isolated_filesystem() as dag_folder: - if not os.path.exists(dag_file): - raise Exception(f"{dag_file} does not exist.") - - shutil.copy(dag_file, os.path.join(dag_folder, os.path.basename(dag_file))) - - dag_bag = DagBag(dag_folder=dag_folder) - - if dag_bag.import_errors != {}: - logging.error(f"DagBag errors: {dag_bag.import_errors}") - self.assertEqual({}, dag_bag.import_errors, dag_bag.import_errors) - - dag = dag_bag.get_dag(dag_id=dag_id) - - if dag is None: - logging.error( - f"DAG not found in the database. Make sure the DAG ID is correct, and the dag file contains the words 'airflow' and 'DAG'." - ) - self.assertIsNotNone(dag) - - self.assertGreaterEqual(len(dag.tasks), 1) - - def assert_dag_load_from_config(self, dag_id: str): - """Assert that the given DAG loads from a config file. - - :param dag_id: the DAG id. - :return: None. - """ - - self.assert_dag_load(dag_id, os.path.join(module_file_path("observatory.platform.dags"), "load_workflows.py")) - - def assert_blob_exists(self, bucket_id: str, blob_name: str): - """Assert whether a blob exists or not. - - :param bucket_id: the Google Cloud storage bucket id. - :param blob_name: the blob name (full path except for bucket) - :return: None. 
- """ - - # Get blob - bucket = self.storage_client.get_bucket(bucket_id) - blob = bucket.blob(blob_name) - self.assertTrue(blob.exists()) - - def assert_blob_integrity(self, bucket_id: str, blob_name: str, local_file_path: str): - """Assert whether the blob uploaded and that it has the expected hash. - - :param blob_name: the Google Cloud Blob name, i.e. the entire path to the blob on the Cloud Storage bucket. - :param bucket_id: the Google Cloud Storage bucket id. - :param local_file_path: the path to the local file. - :return: whether the blob uploaded and that it has the expected hash. - """ - - # Get blob - bucket = self.storage_client.get_bucket(bucket_id) - blob = bucket.blob(blob_name) - result = blob.exists() - - # Check that blob hash matches if it exists - if result: - # Get blob hash - blob.reload() - expected_hash = blob.crc32c - - # Check actual file - actual_hash = crc32c_base64_hash(local_file_path) - result = expected_hash == actual_hash - - self.assertTrue(result) - - def assert_table_integrity(self, table_id: str, expected_rows: int = None): - """Assert whether a BigQuery table exists and has the expected number of rows. - - :param table_id: the BigQuery table id. - :param expected_rows: the expected number of rows. - :return: whether the table exists and has the expected number of rows. - """ - - table = None - actual_rows = None - try: - table = self.bigquery_client.get_table(table_id) - actual_rows = table.num_rows - except NotFound: - pass - - self.assertIsNotNone(table) - if expected_rows is not None: - self.assertEqual(expected_rows, actual_rows) - - def assert_table_content(self, table_id: str, expected_content: List[dict], primary_key: str): - """Assert whether a BigQuery table has any content and if expected content is given whether it matches the - actual content. The order of the rows is not checked, only whether all rows in the expected content match - the rows in the actual content. - The expected content should be a list of dictionaries, where each dictionary represents one row of the table, - the keys are fieldnames and values are values. - - :param table_id: the BigQuery table id. - :param expected_content: the expected content. - :param primary_key: the primary key to use to compare. - :return: whether the table has content and the expected content is correct - """ - - logging.info( - f"assert_table_content: {table_id}, len(expected_content)={len(expected_content), }, primary_key={primary_key}" - ) - rows = None - actual_content = None - try: - rows = list(self.bigquery_client.list_rows(table_id)) - actual_content = [dict(row) for row in rows] - except NotFound: - pass - self.assertIsNotNone(rows) - self.assertIsNotNone(actual_content) - results = compare_lists_of_dicts(expected_content, actual_content, primary_key) - assert results, "Rows in actual content do not match expected content" - - def assert_table_bytes(self, table_id: str, expected_bytes: int): - """Assert whether the given bytes from a BigQuery table matches the expected bytes. - - :param table_id: the BigQuery table id. - :param expected_bytes: the expected number of bytes. - :return: whether the table exists and the expected bytes match - """ - - table = None - try: - table = self.bigquery_client.get_table(table_id) - except NotFound: - pass - - self.assertIsNotNone(table) - self.assertEqual(expected_bytes, table.num_bytes) - - def assert_file_integrity(self, file_path: str, expected_hash: str, algorithm: str): - """Assert that a file exists and it has the correct hash. 
- - :param file_path: the path to the file. - :param expected_hash: the expected hash. - :param algorithm: the algorithm to use when hashing, either md5 or gzip crc - :return: None. - """ - - self.assertTrue(os.path.isfile(file_path)) - - if algorithm == "gzip_crc": - actual_hash = gzip_file_crc(file_path) - else: - actual_hash = get_file_hash(file_path=file_path, algorithm=algorithm) - - self.assertEqual(expected_hash, actual_hash) - - def assert_cleanup(self, workflow_folder: str): - """Assert that the download, extracted and transformed folders were cleaned up. - - :param workflow_folder: the path to the DAGs download folder. - :return: None. - """ - - self.assertFalse(os.path.exists(workflow_folder)) - - def setup_mock_file_download( - self, uri: str, file_path: str, headers: Dict = None, method: str = httpretty.GET - ) -> None: - """Use httpretty to mock a file download. - - This function must be called from within an httpretty.enabled() block, for instance: - - with httpretty.enabled(): - self.setup_mock_file_download('https://example.com/file.zip', path_to_file) - - :param uri: the URI of the file download to mock. - :param file_path: the path to the file on the local system. - :param headers: the response headers. - :return: None. - """ - - if headers is None: - headers = {} - - with open(file_path, "rb") as f: - body = f.read() - - httpretty.register_uri(method, uri, adding_headers=headers, body=body) - - -class SftpServer: - """A Mock SFTP server for testing purposes""" - - def __init__( - self, - host: str = "localhost", - port: int = 3373, - level: str = "INFO", - backlog: int = 10, - startup_wait_secs: int = 1, - socket_timeout: int = 10, - ): - """Create a Mock SftpServer instance. - - :param host: the host name. - :param port: the port. - :param level: the log level. - :param backlog: ? - :param startup_wait_secs: time in seconds to wait before returning from create to give the server enough - time to start before connecting to it. - """ - - self.host = host - self.port = port - self.level = level - self.backlog = backlog - self.startup_wait_secs = startup_wait_secs - self.is_shutdown = True - self.tmp_dir = None - self.root_dir = None - self.private_key_path = None - self.server_thread = None - self.socket_timeout = socket_timeout - - def _generate_key(self): - """Generate a private key. - - :return: the filepath to the private key. 
- """ - - key = rsa.generate_private_key(public_exponent=65537, key_size=2048, backend=default_backend()) - - private_key_path = os.path.join(self.tmp_dir, "test_rsa.key") - with open(private_key_path, "wb") as f: - f.write( - key.private_bytes( - encoding=serialization.Encoding.PEM, - format=serialization.PrivateFormat.TraditionalOpenSSL, - encryption_algorithm=serialization.NoEncryption(), - ) - ) - - return private_key_path - - def _start_server(self): - paramiko_level = getattr(paramiko.common, self.level) - paramiko.common.logging.basicConfig(level=paramiko_level) - - server_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - server_socket.settimeout(self.socket_timeout) - server_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, True) - server_socket.bind((self.host, self.port)) - server_socket.listen(self.backlog) - - while not self.is_shutdown: - try: - conn, addr = server_socket.accept() - transport = paramiko.Transport(conn) - transport.add_server_key(paramiko.RSAKey.from_private_key_file(self.private_key_path)) - transport.set_subsystem_handler("sftp", paramiko.SFTPServer, StubSFTPServer) - - server = StubServer() - transport.start_server(server=server) - - channel = transport.accept() - while transport.is_active() and not self.is_shutdown: - time.sleep(1) - - except socket.timeout: - # Timeout must be set for socket otherwise it will wait for a connection forever and block - # the thread from exiting. At: conn, addr = server_socket.accept() - pass - - @contextlib.contextmanager - def create(self): - """Make and destroy a test SFTP server. - - :yield: None. - """ - - with CliRunner().isolated_filesystem() as tmp_dir: - # Override the root directory of the SFTP server, which is set as the cwd at import time - self.tmp_dir = tmp_dir - self.root_dir = os.path.join(tmp_dir, "home") - os.makedirs(self.root_dir, exist_ok=True) - StubSFTPServer.ROOT = self.root_dir - - # Generate private key - self.private_key_path = self._generate_key() - - try: - self.is_shutdown = False - self.server_thread = threading.Thread(target=self._start_server) - self.server_thread.start() - - # Wait a little bit to give the server time to grab the socket - time.sleep(self.startup_wait_secs) - - yield self.root_dir - finally: - # Stop server and wait for server thread to join - self.is_shutdown = True - if self.server_thread is not None: - self.server_thread.join() - - -class FtpServer: - """ - Create a Mock FTPServer instance. - - :param directory: The directory that is hosted on FTP server. - :param host: Hostname of the server. - :param port: The port number. - :param startup_wait_secs: time in seconds to wait before returning from create to give the server enough - time to start before connecting to it. - """ - - def __init__( - self, - directory: str = "/", - host: str = "localhost", - port: int = 21, - startup_wait_secs: int = 1, - root_username: str = "root", - root_password: str = "pass", - ): - self.host = host - self.port = port - self.directory = directory - self.startup_wait_secs = startup_wait_secs - - self.root_username = root_username - self.root_password = root_password - - self.is_shutdown = True - self.server_thread = None - - @contextlib.contextmanager - def create(self): - """Make and destroy a test FTP server. - - :yield: self.directory. - """ - - # Set up the FTP server with root and anonymous users. 
- authorizer = DummyAuthorizer() - authorizer.add_user( - username=self.root_username, password=self.root_password, homedir=self.directory, perm="elradfmwMT" - ) - authorizer.add_anonymous(self.directory) - handler = FTPHandler - handler.authorizer = authorizer - - try: - # Start server in separate thread. - self.server = ThreadedFTPServer((self.host, self.port), handler) - self.server_thread = threading.Thread(target=self.server.serve_forever) - self.server_thread.daemon = True - self.server_thread.start() - - # Wait a little bit to give the server time to grab the socket - time.sleep(self.startup_wait_secs) - - yield self.directory - - finally: - # Stop server and wait for server thread to join - self.is_shutdown = True - if self.server_thread is not None: - self.server.close_all() - self.server_thread.join() - - -def make_dummy_dag(dag_id: str, execution_date: pendulum.DateTime) -> DAG: - """A Dummy DAG for testing purposes. - - :param dag_id: the DAG id. - :param execution_date: the DAGs execution date. - :return: the DAG. - """ - - with DAG( - dag_id=dag_id, - schedule="@weekly", - default_args={"owner": "airflow", "start_date": execution_date}, - catchup=False, - ) as dag: - task1 = EmptyOperator(task_id="dummy_task") - - return dag - - -@dataclass -class Table: - """A table to be loaded into Elasticsearch. - - :param table_name: the table name. - :param is_sharded: whether the table is sharded or not. - :param dataset_id: the dataset id. - :param records: the records to load. - :param schema_file_path: the schema file path. - """ - - table_name: str - is_sharded: bool - dataset_id: str - records: List[Dict] - schema_file_path: str - - -def bq_load_tables( - *, - project_id: str, - tables: List[Table], - bucket_name: str, - snapshot_date: DateTime, -): - """Load the fake Observatory Dataset in BigQuery. - - :param project_id: GCP project id. - :param tables: the list of tables and records to load. - :param bucket_name: the Google Cloud Storage bucket name. - :param snapshot_date: the release date for the observatory dataset. - :return: None. 
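A sketch of a test that stands up the mock `SftpServer` and `FtpServer` defined above; the ports, credentials, paths and client-side code are assumptions, not part of the removed module.

```python
# Illustrative only: ports, credentials, paths and client code are assumed.
import os

from observatory.platform.observatory_environment import FtpServer, SftpServer


def test_with_mock_file_servers():
    # SFTP: create() yields the root directory served by the mock server.
    sftp_server = SftpServer(host="localhost", port=3373)
    with sftp_server.create() as sftp_root:
        with open(os.path.join(sftp_root, "upload.csv"), "w") as f:
            f.write("id,name\n1,foo\n")
        # ... connect with an SFTP client (e.g. paramiko) and fetch upload.csv ...

    # FTP: create() yields the hosted directory; root credentials are configurable.
    ftp_dir_path = "/tmp/ftp-fixtures"
    os.makedirs(ftp_dir_path, exist_ok=True)
    ftp_server = FtpServer(
        directory=ftp_dir_path,
        host="localhost",
        port=2121,
        root_username="root",
        root_password="pass",
    )
    with ftp_server.create() as ftp_dir:
        pass  # ... connect with ftplib.FTP and list or download files from ftp_dir ...
```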
- """ - - with CliRunner().isolated_filesystem() as t: - files_list = [] - blob_names = [] - - # Save to JSONL - for table in tables: - blob_name = f"{table.dataset_id}-{table.table_name}.jsonl.gz" - file_path = os.path.join(t, blob_name) - save_jsonl_gz(file_path, table.records) - files_list.append(file_path) - blob_names.append(blob_name) - - # Upload to Google Cloud Storage - success = gcs_upload_files(bucket_name=bucket_name, file_paths=files_list, blob_names=blob_names) - assert success, "Data did not load into BigQuery" - - # Save to BigQuery tables - for blob_name, table in zip(blob_names, tables): - if table.schema_file_path is None: - logging.error( - f"No schema found with search parameters: analysis_schema_path={table.schema_file_path}, " - f"table_name={table.table_name}, snapshot_date={snapshot_date}" - ) - exit(os.EX_CONFIG) - - if table.is_sharded: - table_id = bq_sharded_table_id(project_id, table.dataset_id, table.table_name, snapshot_date) - else: - table_id = bq_table_id(project_id, table.dataset_id, table.table_name) - - # Load BigQuery table - uri = gcs_blob_uri(bucket_name, blob_name) - logging.info(f"URI: {uri}") - success = bq_load_table( - uri=uri, - table_id=table_id, - schema_file_path=table.schema_file_path, - source_format=SourceFormat.NEWLINE_DELIMITED_JSON, - ) - if not success: - raise AirflowException("bq_load task: data failed to load data into BigQuery") - - -class HttpServer: - """Simple HTTP server for testing. Serves files from a directory to http://locahost:port/filename""" - - def __init__(self, directory: str, host: str = "localhost", port: int = None): - """Initialises the server. - - :param directory: Directory to serve. - """ - - self.directory = directory - self.process = None - - self.host = host - if port is None: - port = find_free_port(host=self.host) - self.port = port - self.address = (self.host, self.port) - self.url = f"http://{self.host}:{self.port}/" - - @staticmethod - def serve_(address, directory): - """Entry point for a new process to run HTTP server. - - :param address: Address (host, port) to bind server to. - :param directory: Directory to serve. - """ - - os.chdir(directory) - server = ThreadingHTTPServer(address, SimpleHTTPRequestHandler) - server.serve_forever() - - def start(self): - """Spin the server up in a new process.""" - - # Don't try to start it twice. 
- if self.process is not None and self.process.is_alive(): - return - - self.process = Process( - target=HttpServer.serve_, - args=( - self.address, - self.directory, - ), - ) - self.process.start() - - def stop(self): - """Shutdown the server.""" - - if self.process is not None and self.process.is_alive(): - self.process.kill() - self.process.join() - - @contextlib.contextmanager - def create(self): - """Spin up a server for the duration of the session.""" - self.start() - - try: - yield self.process - finally: - self.stop() diff --git a/observatory-platform/observatory/platform/terraform/build.sh b/observatory-platform/observatory/platform/terraform/build.sh deleted file mode 100644 index 4abb3875f..000000000 --- a/observatory-platform/observatory/platform/terraform/build.sh +++ /dev/null @@ -1,66 +0,0 @@ -#!/usr/bin/env bash - -echo " ----- Sleeping for 30 seconds as per Packer documentation ----- " -sleep 30 - -echo " ----- Install Docker and Docker Compose V2 (using apt-get) ----- " -sudo apt-get update -sudo apt-get -y install ca-certificates curl gnupg lsb-release -sudo mkdir -p /etc/apt/keyrings -curl -fsSL https://download.docker.com/linux/ubuntu/gpg | sudo gpg --dearmor -o /etc/apt/keyrings/docker.gpg -echo "deb [arch=$(dpkg --print-architecture) signed-by=/etc/apt/keyrings/docker.gpg] https://download.docker.com/linux/ubuntu $(lsb_release -cs) stable" | sudo tee /etc/apt/sources.list.d/docker.list > /dev/null -sudo apt-get update -sudo apt-get -y install docker-ce docker-ce-cli containerd.io docker-compose-plugin -sudo service docker restart - -echo " ----- Make the airflow user and add it to the docker group ----- " -sudo useradd --home-dir /home/airflow --shell /bin/bash --create-home airflow -sudo usermod -aG docker airflow -sudo newgrp docker - -echo " ----- Install Berglas v1.0.1 ----- " -# See here for a list of releases: https://github.com/GoogleCloudPlatform/berglas/releases -curl -L "https://github.com/GoogleCloudPlatform/berglas/releases/download/v1.0.1/berglas_1.0.1_linux_amd64.tar.gz" | sudo tar -xz -C /usr/local/bin berglas -sudo chmod +x /usr/local/bin/berglas - -echo " ----- Install Google Compute Ops Agent ----- " -curl -sSO https://dl.google.com/cloudagents/add-google-cloud-ops-agent-repo.sh -sudo bash add-google-cloud-ops-agent-repo.sh --also-install --version=2.*.* -sudo systemctl status google-cloud-ops-agent"*" - -echo " ----- Make airflow and docker directories, move packages, and clean up files ----- " -sudo mkdir -p /opt/airflow/logs -sudo mkdir /opt/airflow/dags -sudo mkdir -p /opt/observatory/data -sudo mkdir -p /opt/observatory/build/docker - -# Move all packages into /opt directory -sudo cp -r /tmp/opt/packages/* /opt - -# Move Docker files into /opt/observatory/build/docker directory -sudo cp -r /tmp/opt/observatory/build/docker/* /opt/observatory/build/docker - -# Remove tmp -sudo rm -r /tmp - -# Own all /opt directories and packer home folder -sudo chown -R airflow /opt/ -sudo chown -R airflow /home/packer/ - -# Set working directory and environment variables for building docker containers -cd /opt/observatory/build/docker -export HOST_USER_ID=$(id -u airflow) -export HOST_REDIS_PORT=6379 -export HOST_FLOWER_UI_PORT=5555 -export HOST_AIRFLOW_UI_PORT=8080 -export HOST_ELASTIC_PORT=9200 -export HOST_KIBANA_PORT=5601 -export HOST_DATA_PATH=/opt/observatory/data -export HOST_LOGS_PATH=/opt/airflow/logs -export HOST_GOOGLE_APPLICATION_CREDENTIALS=/opt/observatory/google_application_credentials.json -export HOST_API_SERVER_PORT=5002 - -echo " ----- 
Building docker containers with docker-compose, running as airflow user ----- " -PRESERVE_ENV="HOST_USER_ID,HOST_REDIS_PORT,HOST_FLOWER_UI_PORT,HOST_AIRFLOW_UI_PORT,HOST_ELASTIC_PORT,HOST_KIBANA_PORT,HOST_API_SERVER_PORT,HOST_DATA_PATH,HOST_LOGS_PATH,HOST_GOOGLE_APPLICATION_CREDENTIALS" -sudo -u airflow --preserve-env=${PRESERVE_ENV} bash -c "docker compose -f docker-compose.observatory.yml pull" -sudo -u airflow --preserve-env=${PRESERVE_ENV} bash -c "docker compose -f docker-compose.observatory.yml build" \ No newline at end of file diff --git a/observatory-platform/observatory/platform/terraform/main.tf b/observatory-platform/observatory/platform/terraform/main.tf deleted file mode 100644 index 4fc6115b1..000000000 --- a/observatory-platform/observatory/platform/terraform/main.tf +++ /dev/null @@ -1,575 +0,0 @@ -######################################################################################################################## -# Configure Google Cloud Provider -######################################################################################################################## - -terraform { - backend "remote" { - workspaces { - prefix = "observatory-" - } - } -} - -provider "google" { - credentials = var.google_cloud.credentials - project = var.google_cloud.project_id - region = var.google_cloud.region - zone = var.google_cloud.zone -} - -provider "google-beta" { - credentials = var.google_cloud.credentials - project = var.google_cloud.project_id - region = var.google_cloud.region - zone = var.google_cloud.zone -} - -data "google_project" "project" { - project_id = var.google_cloud.project_id - depends_on = [google_project_service.cloud_resource_manager] -} - -data "google_compute_default_service_account" "default" { - depends_on = [google_project_service.compute_engine, google_project_service.services] -} - -data "google_storage_transfer_project_service_account" "default" { - depends_on = [google_project_service.services] -} -locals { - compute_service_account_email = data.google_compute_default_service_account.default.email - transfer_service_account_email = data.google_storage_transfer_project_service_account.default.email -} - -######################################################################################################################## -# Terraform Cloud Environment Variable (https://www.terraform.io/docs/cloud/run/run-environment.html) -######################################################################################################################## - -variable "TFC_WORKSPACE_SLUG" { - type = string - default = "" # An error occurs when you are running TF backend other than Terraform Cloud -} - -locals { - organization = split("/", var.TFC_WORKSPACE_SLUG)[0] - workspace_name = split("/", var.TFC_WORKSPACE_SLUG)[1] -} - -######################################################################################################################## -# Enable google cloud APIs -######################################################################################################################## - -resource "google_project_service" "cloud_resource_manager" { - project = var.google_cloud.project_id - service = "cloudresourcemanager.googleapis.com" - disable_dependent_services = true -} - -# Can't disable dependent services, because of existing observatory-image -resource "google_project_service" "compute_engine" { - project = var.google_cloud.project_id - service = "compute.googleapis.com" - disable_on_destroy = false - depends_on = 
[google_project_service.cloud_resource_manager] -} - -resource "google_project_service" "services" { - for_each = toset([ - "storagetransfer.googleapis.com", "iam.googleapis.com", "servicenetworking.googleapis.com", - "sqladmin.googleapis.com", "secretmanager.googleapis.com" - ]) - project = var.google_cloud.project_id - service = each.key - disable_dependent_services = true - depends_on = [google_project_service.cloud_resource_manager] -} - -######################################################################################################################## -# Create a service account and add permissions -######################################################################################################################## - -resource "google_service_account" "observatory_service_account" { - account_id = var.google_cloud.project_id - display_name = "Apache Airflow Service Account" - description = "The Google Service Account used by Apache Airflow" - depends_on = [google_project_service.services] -} - -# Create service account key, save to Google Secrets Manager and give compute service account access to the secret -resource "google_service_account_key" "observatory_service_account_key" { - service_account_id = google_service_account.observatory_service_account.name -} - -# BigQuery admin -resource "google_project_iam_member" "observatory_service_account_bigquery_iam" { - project = var.google_cloud.project_id - role = "roles/bigquery.admin" - member = "serviceAccount:${google_service_account.observatory_service_account.email}" -} - -# Storage Transfer Admin -resource "google_project_iam_member" "observatory_service_account_storage_iam" { - project = var.google_cloud.project_id - role = "roles/storagetransfer.admin" - member = "serviceAccount:${google_service_account.observatory_service_account.email}" -} - -# Add BQ limit for per user per day -resource "google_service_usage_consumer_quota_override" "bq_usage_per_user_per_day" { - provider = google-beta - project = var.google_cloud.project_id - service = "bigquery.googleapis.com" - metric = urlencode("bigquery.googleapis.com/quota/query/usage") - limit = urlencode("/d/project/user") - override_value = 10485760 # in megabytes, so 10 TiB - force = true -} - -# Add BQ limit for per day for entire project -resource "google_service_usage_consumer_quota_override" "bq_usage_per_day" { - provider = google-beta - project = var.google_cloud.project_id - service = "bigquery.googleapis.com" - metric = urlencode("bigquery.googleapis.com/quota/query/usage") - limit = urlencode("/d/project") - override_value = 15728640 # in megabytes, so 15 TiB - force = true -} - -######################################################################################################################## -# Storage Buckets -######################################################################################################################## - -# Random id to prevent destroy of resources in keepers -resource "random_id" "buckets_protector" { - count = var.environment == "production" ? 
1 : 0 - byte_length = 8 - keepers = { - download_bucket = google_storage_bucket.observatory_download_bucket.id - transform_bucket = google_storage_bucket.observatory_transform_bucket.id - } - lifecycle { - prevent_destroy = true - } -} - -# Bucket for storing downloaded files -resource "google_storage_bucket" "observatory_download_bucket" { - name = "${var.google_cloud.project_id}-download" - force_destroy = true - location = var.google_cloud.data_location - project = var.google_cloud.project_id - lifecycle_rule { - condition { - age = "31" - matches_storage_class = ["STANDARD"] - } - action { - type = "SetStorageClass" - storage_class = "NEARLINE" - } - } - lifecycle_rule { - condition { - age = "365" - matches_storage_class = ["NEARLINE"] - } - action { - type = "SetStorageClass" - storage_class = "COLDLINE" - } - } -} - -# Permissions so that the Transfer Service Account can read / write files to bucket -resource "google_storage_bucket_iam_member" "observatory_download_bucket_transfer_service_account_legacy_bucket_reader" { - bucket = google_storage_bucket.observatory_download_bucket.name - role = "roles/storage.legacyBucketReader" - member = "serviceAccount:${data.google_storage_transfer_project_service_account.default.email}" -} - -# Must have object admin so that files can be overwritten -resource "google_storage_bucket_iam_member" "observatory_download_bucket_transfer_service_account_object_admin" { - bucket = google_storage_bucket.observatory_download_bucket.name - role = "roles/storage.objectAdmin" - member = "serviceAccount:${data.google_storage_transfer_project_service_account.default.email}" -} - -# Permissions so that Observatory Platform service account can read and write -resource "google_storage_bucket_iam_member" "observatory_download_bucket_observatory_service_account_legacy_bucket_reader" { - bucket = google_storage_bucket.observatory_download_bucket.name - role = "roles/storage.legacyBucketReader" - member = "serviceAccount:${google_service_account.observatory_service_account.email}" -} - -resource "google_storage_bucket_iam_member" "observatory_download_bucket_observatory_service_account_object_creator" { - bucket = google_storage_bucket.observatory_download_bucket.name - role = "roles/storage.objectCreator" - member = "serviceAccount:${google_service_account.observatory_service_account.email}" -} - -resource "google_storage_bucket_iam_member" "observatory_download_bucket_observatory_service_account_object_viewer" { - bucket = google_storage_bucket.observatory_download_bucket.name - role = "roles/storage.objectViewer" - member = "serviceAccount:${google_service_account.observatory_service_account.email}" -} - - -# Bucket for storing transformed files -resource "google_storage_bucket" "observatory_transform_bucket" { - name = "${var.google_cloud.project_id}-transform" - force_destroy = true - location = var.google_cloud.data_location - project = var.google_cloud.project_id - lifecycle_rule { - condition { - age = "31" - matches_storage_class = ["STANDARD"] - } - action { - type = "SetStorageClass" - storage_class = "NEARLINE" - } - } - lifecycle_rule { - condition { - age = "62" - matches_storage_class = ["NEARLINE"] - } - action { - type = "Delete" - } - } -} - -# Permissions so that Observatory Platform service account can read, create and delete -resource "google_storage_bucket_iam_member" "observatory_transform_bucket_observatory_service_account_legacy_bucket_reader" { - bucket = google_storage_bucket.observatory_transform_bucket.name - role = 
"roles/storage.legacyBucketReader" - member = "serviceAccount:${google_service_account.observatory_service_account.email}" -} - -# Must have object admin so that files can be overwritten, e.g. if a file was transformed incorrectly and has to be -# uploaded again -resource "google_storage_bucket_iam_member" "observatory_transform_bucket_observatory_service_account_object_admin" { - bucket = google_storage_bucket.observatory_transform_bucket.name - role = "roles/storage.objectAdmin" - member = "serviceAccount:${google_service_account.observatory_service_account.email}" -} - -# Bucket for airflow related files, e.g. airflow logs -resource "random_id" "airflow_bucket_protector" { - count = var.environment == "production" ? 1 : 0 - byte_length = 8 - keepers = { - airflow_bucket = google_storage_bucket.observatory_airflow_bucket.id - } - lifecycle { - prevent_destroy = true - } -} - -resource "google_storage_bucket" "observatory_airflow_bucket" { - name = "${var.google_cloud.project_id}-airflow" - force_destroy = true - location = var.google_cloud.data_location - project = var.google_cloud.project_id - lifecycle_rule { - condition { - age = "31" - matches_storage_class = ["STANDARD"] - } - action { - type = "SetStorageClass" - storage_class = "NEARLINE" - } - } - lifecycle_rule { - condition { - age = "365" - matches_storage_class = ["NEARLINE"] - } - action { - type = "SetStorageClass" - storage_class = "COLDLINE" - } - } -} - -# Permissions so that Observatory Platform service account can read and write -resource "google_storage_bucket_iam_member" "observatory_airflow_bucket_observatory_service_account_legacy_bucket_reader" { - bucket = google_storage_bucket.observatory_airflow_bucket.name - role = "roles/storage.legacyBucketReader" - member = "serviceAccount:${google_service_account.observatory_service_account.email}" -} - -resource "google_storage_bucket_iam_member" "observatory_airflow_bucket_observatory_service_account_object_creator" { - bucket = google_storage_bucket.observatory_airflow_bucket.name - role = "roles/storage.objectCreator" - member = "serviceAccount:${google_service_account.observatory_service_account.email}" -} - -resource "google_storage_bucket_iam_member" "observatory_airflow_bucket_observatory_service_account_object_viewer" { - bucket = google_storage_bucket.observatory_airflow_bucket.name - role = "roles/storage.objectViewer" - member = "serviceAccount:${google_service_account.observatory_service_account.email}" -} - -######################################################################################################################## -# Observatory Platform VPC Network -######################################################################################################################## - -# Necessary to define the network so that the VMs can talk to the Cloud SQL database. 
-locals { - network_name = "ao-network" -} - -resource "google_compute_network" "observatory_network" { - name = local.network_name - depends_on = [google_project_service.compute_engine] -} - -data "google_compute_subnetwork" "observatory_subnetwork" { - name = local.network_name - depends_on = [google_compute_network.observatory_network] # necessary to force reading of data -} - -resource "google_compute_firewall" "allow_internal_airflow" { - name = "allow-internal-airflow" - description = "Allow internal Airflow connections" - network = google_compute_network.observatory_network.name - source_ranges = ["10.128.0.0/9"] - target_tags = ["allow-internal-airflow"] - - allow { - protocol = "tcp" - ports = ["5002", "6379", "8793"] # Open apiserver, redis and Airflow worker ports to the internal network - } - priority = 65534 -} - -resource "google_compute_firewall" "allow_ssh" { - name = "allow-ssh" - description = "Allow SSH from anywhere" - network = google_compute_network.observatory_network.name - source_ranges = ["0.0.0.0/0"] - target_tags = ["allow-ssh"] - - allow { - protocol = "tcp" - ports = ["22"] - } - priority = 65534 -} - -######################################################################################################################## -# Observatory Platform Cloud SQL database -######################################################################################################################## - -resource "google_compute_global_address" "airflow_db_private_ip" { - name = "airflow-db-private-ip" - purpose = "VPC_PEERING" - address_type = "INTERNAL" - prefix_length = 16 - network = google_compute_network.observatory_network.id -} - -resource "google_service_networking_connection" "private_vpc_connection" { - network = google_compute_network.observatory_network.id - service = "servicenetworking.googleapis.com" - reserved_peering_ranges = [google_compute_global_address.airflow_db_private_ip.name] - depends_on = [google_project_service.services] -} - -resource "random_id" "airflow_db_name_suffix" { - byte_length = 4 -} - -resource "google_sql_database_instance" "observatory_db_instance" { - name = var.environment == "production" ? 
"observatory-db-instance" : "observatory-db-instance-${random_id.airflow_db_name_suffix.hex}" - database_version = "POSTGRES_12" - region = var.google_cloud.region - deletion_protection = var.environment == "production" - - depends_on = [google_service_networking_connection.private_vpc_connection, google_project_service.services] - settings { - tier = var.cloud_sql_database.tier - ip_configuration { - ipv4_enabled = false - private_network = google_compute_network.observatory_network.id - } - backup_configuration { - binary_log_enabled = false - enabled = true - location = var.google_cloud.data_location - start_time = var.cloud_sql_database.backup_start_time - } - deletion_protection_enabled = true # Stops the machine being deleted at the GCP platform level - } -} - -// Airflow Database -resource "google_sql_database" "airflow_db" { - name = "airflow" - depends_on = [google_sql_database_instance.observatory_db_instance] - instance = google_sql_database_instance.observatory_db_instance.name -} - -// New database user -resource "google_sql_user" "observatory_user" { - name = "observatory" - instance = google_sql_database_instance.observatory_db_instance.name - password = var.observatory.postgres_password -} - -// Observatory Platform Database -resource "google_sql_database" "observatory_db" { - name = "observatory" - depends_on = [google_sql_database_instance.observatory_db_instance] - instance = google_sql_database_instance.observatory_db_instance.name -} - -######################################################################################################################## -# Google Cloud Secrets required for the VMs -######################################################################################################################## - -locals { - google_cloud_secrets = { - airflow_ui_user_email = var.observatory.airflow_ui_user_email, - airflow_ui_user_password = var.observatory.airflow_ui_user_password, - airflow_fernet_key = var.observatory.airflow_fernet_key, - airflow_secret_key = var.observatory.airflow_secret_key, - postgres_password = var.observatory.postgres_password, - airflow_logging_bucket = google_storage_bucket.observatory_airflow_bucket.name, - airflow_var_workflows = var.airflow_var_workflows, - airflow_var_dags_module_names = var.airflow_var_dags_module_names, - - # Important: this must be the generated service account, not the developer's service account used to deploy the system - google_application_credentials = base64decode(google_service_account_key.observatory_service_account_key.private_key) - } -} - -module "google_cloud_secrets" { - for_each = local.google_cloud_secrets - source = "./secret" - secret_id = each.key - secret_data = contains([ - "postgres_password", "redis_password" - ], each.key) ? 
urlencode(each.value) : each.value - service_account_email = data.google_compute_default_service_account.default.email - depends_on = [google_project_service.services] -} - -######################################################################################################################## -# Airflow variables required for the VMs that will be exported as environment variables -######################################################################################################################## - -locals { - main_vm_name = "airflow-main-vm" - worker_vm_name = "airflow-worker-vm" - - main_vm_internal_ip = try(google_compute_address.airflow_main_vm_private_ip.address, null) - main_vm_external_ip = try(google_compute_address.airflow_main_vm_static_external_ip[0].address, null) - worker_vm_internal_ip = try(google_compute_address.airflow_worker_vm_private_ip.address, null) - worker_vm_external_ip = try(google_compute_address.airflow_worker_vm_static_external_ip[0].address, null) - - metadata_variables = { - project_id = var.google_cloud.project_id - postgres_hostname = google_sql_database_instance.observatory_db_instance.private_ip_address - redis_hostname = local.main_vm_name # this becomes the hostname of the main vm - } -} - -######################################################################################################################## -# Observatory Platform Main VM -######################################################################################################################## - -# Compute Image shared by both VMs -data "google_compute_image" "observatory_image" { - name = "observatory-image-${var.environment}" - depends_on = [google_project_service.compute_engine] -} - -resource "google_compute_address" "airflow_main_vm_private_ip" { - name = "${local.main_vm_name}-private-ip" - address_type = "INTERNAL" - subnetwork = data.google_compute_subnetwork.observatory_subnetwork.self_link - region = var.google_cloud.region - lifecycle { - prevent_destroy = true - } -} - -resource "google_compute_address" "airflow_main_vm_static_external_ip" { - count = var.environment == "production" ? 
1 : 0 - name = "${local.main_vm_name}-static-external-ip" - address_type = "EXTERNAL" - region = var.google_cloud.region - lifecycle { - prevent_destroy = true - } -} - -module "airflow_main_vm" { - source = "./vm" - name = local.main_vm_name - depends_on = [ - google_sql_database_instance.observatory_db_instance, - module.google_cloud_secrets, - ] - network = google_compute_network.observatory_network - subnetwork = data.google_compute_subnetwork.observatory_subnetwork - image = data.google_compute_image.observatory_image - machine_type = var.airflow_main_vm.machine_type - disk_size = var.airflow_main_vm.disk_size - disk_type = var.airflow_main_vm.disk_type - region = var.google_cloud.region - service_account_email = local.compute_service_account_email - startup_script_path = "./startup-main.tpl" - metadata_variables = local.metadata_variables - internal_ip = local.main_vm_internal_ip - external_ip = local.main_vm_external_ip -} - -######################################################################################################################## -# Observatory Platform Worker VM -######################################################################################################################## - -resource "google_compute_address" "airflow_worker_vm_private_ip" { - name = "${local.worker_vm_name}-private-ip" - address_type = "INTERNAL" - subnetwork = data.google_compute_subnetwork.observatory_subnetwork.self_link - region = var.google_cloud.region - lifecycle { - prevent_destroy = true - } -} - -resource "google_compute_address" "airflow_worker_vm_static_external_ip" { - count = var.environment == "production" ? 1 : 0 - name = "${local.worker_vm_name}-static-external-ip" - address_type = "EXTERNAL" - region = var.google_cloud.region - lifecycle { - prevent_destroy = true - } -} - -module "airflow_worker_vm" { - count = var.airflow_worker_vm.create == true ? 
1 : 0 - source = "./vm" - name = local.worker_vm_name - depends_on = [module.airflow_main_vm] - network = google_compute_network.observatory_network - subnetwork = data.google_compute_subnetwork.observatory_subnetwork - image = data.google_compute_image.observatory_image - machine_type = var.airflow_worker_vm.machine_type - disk_size = var.airflow_worker_vm.disk_size - disk_type = var.airflow_worker_vm.disk_type - region = var.google_cloud.region - service_account_email = local.compute_service_account_email - startup_script_path = "./startup-worker.tpl" - metadata_variables = local.metadata_variables - internal_ip = local.worker_vm_internal_ip - external_ip = local.worker_vm_external_ip -} \ No newline at end of file diff --git a/observatory-platform/observatory/platform/terraform/observatory-image.json.pkr.hcl b/observatory-platform/observatory/platform/terraform/observatory-image.json.pkr.hcl deleted file mode 100644 index 5e14eacde..000000000 --- a/observatory-platform/observatory/platform/terraform/observatory-image.json.pkr.hcl +++ /dev/null @@ -1,61 +0,0 @@ -packer { - required_plugins { - googlecompute = { - source = "github.com/hashicorp/googlecompute" - version = "~> 1" - } - } -} - -variable "credentials_file" { - type = string - default = "${env("credentials_file")}" -} - -variable "environment" { - type = string - default = "${env("environment")}" -} - -variable "project_id" { - type = string - default = "${env("project_id")}" -} - -variable "zone" { - type = string - default = "${env("zone")}" -} - -source "googlecompute" "autogenerated_1" { - account_file = "${var.credentials_file}" - image_name = "observatory-image-${var.environment}" - machine_type = "n1-standard-1" - project_id = "${var.project_id}" - source_image_family = "ubuntu-2204-lts" - ssh_username = "packer" - zone = "${var.zone}" -} - -build { - sources = ["source.googlecompute.autogenerated_1"] - - provisioner "shell" { - inline = ["mkdir -p /tmp/opt/observatory/build"] - } - - provisioner "file" { - destination = "/tmp/opt" - source = "../packages" - } - - provisioner "file" { - destination = "/tmp/opt/observatory/build" - source = "../docker" - } - - provisioner "shell" { - script = "build.sh" - } - -} diff --git a/observatory-platform/observatory/platform/terraform/outputs.tf b/observatory-platform/observatory/platform/terraform/outputs.tf deleted file mode 100644 index 6d45202a5..000000000 --- a/observatory-platform/observatory/platform/terraform/outputs.tf +++ /dev/null @@ -1,47 +0,0 @@ -output "airflow_db_ip_address" { - value = google_sql_database_instance.observatory_db_instance.private_ip_address - description = "The private IP address of the Airflow Cloud SQL database." -} - -output "airflow_main_vm_ip_address" { - value = try(google_compute_address.airflow_main_vm_private_ip, null) - description = "The private IP address of the Airflow Main VM." - sensitive = true -} - -output "airflow_worker_vm_ip_address" { - value = try(google_compute_address.airflow_worker_vm_private_ip, null) - description = "The private IP address of the Airflow Worker VM." - sensitive = true -} - -output "airflow_main_vm_external_ip" { - value = try(google_compute_address.airflow_main_vm_static_external_ip, null) - description = "The external IP address of the Airflow Main VM." - sensitive = true -} - -output "airflow_worker_vm_external_ip" { - value = try(google_compute_address.airflow_worker_vm_static_external_ip, null) - description = "The external IP address of the Airflow Worker VM." 
- sensitive = true -} - -output "airflow_main_vm_script" { - value = module.airflow_main_vm.vm_rendered - description = "Rendered template file" - sensitive = true # explicitly mark as sensitive so it can be exported -} - -output "airflow_worker_vm_script" { - value = try(module.airflow_worker_vm.vm_rendered, null) - description = "Rendered template file" -} - -output "project_number" { - value = data.google_project.project.number -} - -output "default_transfer_service_account" { - value = data.google_storage_transfer_project_service_account.default.email -} \ No newline at end of file diff --git a/observatory-platform/observatory/platform/terraform/secret/main.tf b/observatory-platform/observatory/platform/terraform/secret/main.tf deleted file mode 100644 index cd23f9fd7..000000000 --- a/observatory-platform/observatory/platform/terraform/secret/main.tf +++ /dev/null @@ -1,21 +0,0 @@ -resource "google_secret_manager_secret" "secret" { - secret_id = var.secret_id - - replication { - automatic = true - } -} - -resource "google_secret_manager_secret_version" "secret_version" { - depends_on = [google_secret_manager_secret.secret] - secret = google_secret_manager_secret.secret.id - secret_data = var.secret_data -} - -resource "google_secret_manager_secret_iam_member" "secret_member" { - depends_on = [google_secret_manager_secret_version.secret_version] - project = google_secret_manager_secret.secret.project - secret_id = google_secret_manager_secret.secret.secret_id - role = "roles/secretmanager.secretAccessor" - member = "serviceAccount:${var.service_account_email}" -} \ No newline at end of file diff --git a/observatory-platform/observatory/platform/terraform/secret/outputs.tf b/observatory-platform/observatory/platform/terraform/secret/outputs.tf deleted file mode 100644 index e69de29bb..000000000 diff --git a/observatory-platform/observatory/platform/terraform/secret/variables.tf b/observatory-platform/observatory/platform/terraform/secret/variables.tf deleted file mode 100644 index 74eeeebdb..000000000 --- a/observatory-platform/observatory/platform/terraform/secret/variables.tf +++ /dev/null @@ -1,14 +0,0 @@ -variable "secret_id" { - type = string - description = "The id of the Google Secret Manager secret." -} - -variable "secret_data" { - type = string - description = "The data for the Google Secret Manager secret." -} - -variable "service_account_email" { - type = string - description = "The email of the service account which secret accessor rights will be granted to." -} \ No newline at end of file diff --git a/observatory-platform/observatory/platform/terraform/startup.tpl.jinja2 b/observatory-platform/observatory/platform/terraform/startup.tpl.jinja2 deleted file mode 100755 index 8d73220c0..000000000 --- a/observatory-platform/observatory/platform/terraform/startup.tpl.jinja2 +++ /dev/null @@ -1,69 +0,0 @@ -#!/usr/bin/env bash -# Copyright 2020 Curtin University -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -# Author: James Diprose - -# Set environment variables for Docker -export COMPOSE_PROJECT_NAME=observatory -export HOST_USER_ID=$(id -u airflow) -export POSTGRES_USER="observatory" -export POSTGRES_HOSTNAME="${postgres_hostname}" -export POSTGRES_PASSWORD="sm://${project_id}/postgres_password" -export AIRFLOW_FERNET_KEY="sm://${project_id}/airflow_fernet_key" -export AIRFLOW_SECRET_KEY="sm://${project_id}/airflow_secret_key" -export REDIS_HOSTNAME="${redis_hostname}" -export HOST_FLOWER_UI_PORT=5555 -export HOST_REDIS_PORT=6379 -export HOST_AIRFLOW_UI_PORT=8080 -export HOST_API_SERVER_PORT=5002 -export HOST_DATA_PATH=/opt/observatory/data -export HOST_LOGS_PATH=/opt/airflow/logs -export HOST_GOOGLE_APPLICATION_CREDENTIALS=/opt/observatory/google_application_credentials.json -export AIRFLOW_LOGGING_BUCKET="sm://${project_id}/airflow_logging_bucket" - -# Set environment variables and docker container names based on whether this is the main or worker vm -{%- if is_airflow_main_vm %} -export AIRFLOW_UI_USER_PASSWORD="sm://${project_id}/airflow_ui_user_password" -export AIRFLOW_UI_USER_EMAIL="sm://${project_id}/airflow_ui_user_email" -{% set docker_containers="redis flower webserver scheduler worker_local airflow_init apiserver"%} -{%- else %} -{% set docker_containers="worker_remote"%} -{%- endif %} - -# Airflow Variables -export AIRFLOW_VAR_DATA_PATH=/opt/observatory/data -export AIRFLOW_VAR_WORKFLOWS="sm://${project_id}/airflow_var_workflows" -export AIRFLOW_VAR_DAGS_MODULE_NAMES="sm://${project_id}/airflow_var_dags_module_names" - -# Airflow Connections -export AIRFLOW_CONN_GOOGLE_CLOUD_OBSERVATORY="google-cloud-platform://?extra__google_cloud_platform__key_path=%2Frun%2Fsecrets%2Fgoogle_application_credentials.json&extra__google_cloud_platform__scope=https%3A%2F%2Fwww.googleapis.com%2Fauth%2Fcloud-platform&extra__google_cloud_platform__project=${project_id}" - -# Hardcoded list of environment variables that need to be preserved -PRESERVE_ENV="HOST_USER_ID,POSTGRES_USER,POSTGRES_HOSTNAME,POSTGRES_PASSWORD,AIRFLOW_FERNET_KEY,AIRFLOW_SECRET_KEY,\ -REDIS_HOSTNAME,AIRFLOW_UI_USER_PASSWORD,AIRFLOW_UI_USER_EMAIL,HOST_FLOWER_UI_PORT,HOST_REDIS_PORT,HOST_AIRFLOW_UI_PORT,\ -HOST_API_SERVER_PORT,HOST_DATA_PATH,HOST_LOGS_PATH,HOST_GOOGLE_APPLICATION_CREDENTIALS,AIRFLOW_LOGGING_BUCKET,AIRFLOW_VAR_DATA_PATH,\ -AIRFLOW_VAR_WORKFLOWS,AIRFLOW_VAR_DAGS_MODULE_NAMES,AIRFLOW_CONN_GOOGLE_CLOUD_OBSERVATORY" - -# Save google application credentials to file -sudo -u airflow bash -c "berglas access sm://${project_id}/google_application_credentials > /opt/observatory/google_application_credentials.json" - -# Navigate to docker directory which contains all Docker and Docker Compose files -cd /opt/observatory/build/docker - -# Pull, build and start Docker containers -{% set docker_compose_cmd="docker compose -f docker-compose.observatory.yml"%} -sudo -u airflow --preserve-env=$PRESERVE_ENV bash -c "{{ docker_compose_cmd }} pull {{ docker_containers }}" -sudo -u airflow --preserve-env=$PRESERVE_ENV bash -c "{{ docker_compose_cmd }} build {{ docker_containers }}" -sudo -u airflow -H --preserve-env=$PRESERVE_ENV bash -c "berglas exec -- {{ docker_compose_cmd }} up -d {{ docker_containers }}" \ No newline at end of file diff --git a/observatory-platform/observatory/platform/terraform/terraform_api.py b/observatory-platform/observatory/platform/terraform/terraform_api.py deleted file mode 100644 index 19d596b12..000000000 --- a/observatory-platform/observatory/platform/terraform/terraform_api.py +++ /dev/null @@ -1,503 
+0,0 @@ -# Copyright 2020 Curtin University -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# Author: Aniek Roelofs, James Diprose - -from __future__ import annotations - -import json -import logging -import os -from dataclasses import dataclass -from enum import Enum -from http import HTTPStatus -from typing import Tuple, List - -import requests - - -class TerraformVariableCategory(Enum): - """The type of Terraform variable.""" - - terraform = "terraform" - env = "env" - - -@dataclass -class TerraformVariable: - """A TerraformVariable class. - - Attributes: - key: variable key. - value: variable value. - var_id: variable id, uniquely identifies the variable in the cloud. - description: variable description. - category: variable category. - hcl: whether the variable is HCL or not. - sensitive: whether the variable is sensitive or not. - """ - - key: str - value: str - var_id: str = None - description: str = "" - category: TerraformVariableCategory = TerraformVariableCategory.terraform - hcl: bool = False - sensitive: bool = False - - def __str__(self): - return self.key - - def __hash__(self): - return hash(str(self)) - - def __eq__(self, other): - return self.key == other.key - - @staticmethod - def from_dict(dict_) -> TerraformVariable: - """Parse a dictionary into a TerraformVariable instance. - - :param dict_: the dictionary. - :return: the TerraformVariable instance. - """ - - var_id = dict_.get("id") - attributes = dict_["attributes"] - key = attributes.get("key") - value = attributes.get("value") - sensitive = attributes.get("sensitive") - category = attributes.get("category") - hcl = attributes.get("hcl") - description = attributes.get("description") - - return TerraformVariable( - key, - value, - sensitive=sensitive, - category=TerraformVariableCategory(category), - hcl=hcl, - description=description, - var_id=var_id, - ) - - def to_dict(self): - """Convert a TerraformVariable instance into a dictionary. - - :return: the dictionary. - """ - - var = { - "type": "vars", - "attributes": { - "key": self.key, - "value": self.value, - "description": self.description, - "category": self.category.value, - "hcl": self.hcl, - "sensitive": self.sensitive, - }, - } - - if self.var_id is not None: - var["id"] = self.var_id - - return var - - -class TerraformApi: - TERRAFORM_WORKSPACE_VERSION = "0.13.5" - VERBOSITY_WARNING = 0 - VERBOSITY_INFO = 1 - VERBOSITY_DEBUG = 2 - - def __init__(self, token: str, verbosity: int = VERBOSITY_WARNING): - """Create a TerraformApi instance. - - :param token: the Terraform API token. - :param verbosity: the verbosity for the Terraform API. 
- """ - - self.token = token - if verbosity == TerraformApi.VERBOSITY_WARNING: - logging.getLogger().setLevel(logging.WARNING) - elif verbosity == TerraformApi.VERBOSITY_INFO: - logging.getLogger().setLevel(logging.INFO) - elif verbosity >= TerraformApi.VERBOSITY_DEBUG: - logging.getLogger().setLevel(logging.DEBUG) - self.api_url = "https://app.terraform.io/api/v2" - self.headers = {"Content-Type": "application/vnd.api+json", "Authorization": f"Bearer {token}"} - - @staticmethod - def token_from_file(file_path: str) -> str: - """Get the Terraform token from a credentials file. - - :param file_path: path to credentials file - :return: token - """ - - with open(file_path, "r") as file: - token = json.load(file)["credentials"]["app.terraform.io"]["token"] - return token - - def create_workspace( - self, - organisation: str, - workspace: str, - auto_apply: bool, - description: str, - version: str = TERRAFORM_WORKSPACE_VERSION, - ) -> int: - """Create a new workspace in Terraform Cloud. - - :param organisation: Name of terraform organisation - :param workspace: Name of terraform workspace - :param auto_apply: Whether the new workspace should be set to auto_apply - :param description: Description of the workspace - :param version: The terraform version - :return: The response status code - """ - attributes = { - "name": workspace, - "auto-apply": str(auto_apply).lower(), - "description": description, - "terraform_version": version, - } - data = {"data": {"type": "workspaces", "attributes": attributes}} - - response = requests.post( - f"{self.api_url}/organizations/{organisation}/workspaces", headers=self.headers, json=data - ) - - if response.status_code == HTTPStatus.CREATED: - logging.info(f"Created workspace {workspace}") - logging.debug(f"response: {response.text}") - elif response.status_code == HTTPStatus.UNPROCESSABLE_ENTITY: - logging.warning(f"Workspace with name {workspace} already exists") - logging.debug(f"response: {response.text}") - else: - logging.error(f"Response status: {response.status_code}") - logging.error(f"Unsuccessful creating workspace, response: {response.text}") - exit(os.EX_CONFIG) - return response.status_code - - def delete_workspace(self, organisation: str, workspace: str) -> int: - """Delete a workspace in terraform cloud. - - :param organisation: Name of terraform organisation - :param workspace: Name of terraform workspace - :return: The response status code - """ - - response = requests.delete( - f"{self.api_url}/organizations/{organisation}/workspaces/{workspace}", headers=self.headers - ) - return response.status_code - - def workspace_id(self, organisation: str, workspace: str) -> str: - """Returns the workspace id. - - :param organisation: Name of terraform organisation - :param workspace: Name of terraform workspace - :return: workspace id - """ - response = requests.get( - f"{self.api_url}/organizations/{organisation}/workspaces/{workspace}", headers=self.headers - ) - if response.status_code == HTTPStatus.OK: - logging.info(f"Retrieved workspace id for workspace '{workspace}'.") - logging.debug(f"response: {response.text}") - else: - logging.error(f"Response status: {response.status_code}") - logging.error( - f"Unsuccessful retrieving workspace id for workspace '{workspace}', response: {response.text}" - ) - exit(os.EX_CONFIG) - - workspace_id = json.loads(response.text)["data"]["id"] - return workspace_id - - def add_workspace_variable(self, variable: TerraformVariable, workspace_id: str) -> str: - """Add a new variable to the workspace. 
Will return an error if the variable already exists. - - :param variable: the TerraformVariable instance. - :param workspace_id: the workspace id - :return: The var id - """ - - response = requests.post( - f"{self.api_url}/workspaces/{workspace_id}/vars", headers=self.headers, json={"data": variable.to_dict()} - ) - - key = variable.key - if response.status_code == HTTPStatus.CREATED: - logging.info(f"Added variable {key}") - else: - msg = f"Unsuccessful adding variable {key}, response: {response.text}, status_code: {response.status_code}" - logging.error(msg) - raise ValueError(msg) - - var_id = json.loads(response.text)["data"]["id"] - return var_id - - def update_workspace_variable(self, variable: TerraformVariable, workspace_id: str) -> int: - """Update a workspace variable that is identified by its id. - - :param variable: attributes of the variable - :param var_id: the variable id - :param workspace_id: the workspace id - :return: the response status code - """ - response = requests.patch( - f"{self.api_url}/workspaces/{workspace_id}/vars/{variable.var_id}", - headers=self.headers, - json={"data": variable.to_dict()}, - ) - try: - key = json.loads(response.text)["data"]["attributes"]["key"] - except KeyError: - try: - key = variable.key - except KeyError: - key = None - if response.status_code == HTTPStatus.OK: - logging.info(f"Updated variable {key}") - else: - msg = f"Unsuccessful updating variable with id {variable.var_id} and key {key}, response: {response.text}, status_code: {response.status_code}" - logging.error(msg) - raise ValueError(msg) - - return response.status_code - - def delete_workspace_variable(self, var: TerraformVariable, workspace_id: str) -> int: - """Delete a workspace variable identified by its id. Should not return any content in response. - - :param var: the variable - :param workspace_id: the workspace id - :return: The response code - """ - response = requests.delete(f"{self.api_url}/workspaces/{workspace_id}/vars/{var.var_id}", headers=self.headers) - - if response.status_code == HTTPStatus.NO_CONTENT: - logging.info(f"Deleted variable with id {var.var_id}") - else: - msg = f"Unsuccessful deleting variable with id {var.var_id}, response: {response.text}, status_code: {response.status_code}" - logging.error(msg) - raise ValueError(msg) - - return response.status_code - - def list_workspace_variables(self, workspace_id: str) -> List[TerraformVariable]: - """Returns a list of variables in the workspace. Each variable is a dict. - - :param workspace_id: The workspace id - :return: Variables in the workspace - """ - response = requests.get(f"{self.api_url}/workspaces/{workspace_id}/vars", headers=self.headers) - - if response.status_code == HTTPStatus.OK: - logging.info(f"Retrieved workspace variables.") - logging.debug(f"response: {response.text}") - else: - logging.error(f"Response status: {response.status_code}") - logging.error(f"Unsuccessful retrieving workspace variables, response: {response.text}") - exit(os.EX_CONFIG) - - workspace_vars = json.loads(response.text)["data"] - return [TerraformVariable.from_dict(dict_) for dict_ in workspace_vars] - - def create_configuration_version(self, workspace_id: str) -> Tuple[str, str]: - """Create a configuration version. A configuration version is a resource used to reference the uploaded - configuration files. It is associated with the run to use the uploaded configuration files for performing the - plan and apply. 
- - :param workspace_id: the workspace id - :return: the upload url - """ - data = {"data": {"type": "configuration-versions", "attributes": {"auto-queue-runs": "false"}}} - response = requests.post( - f"{self.api_url}/workspaces/{workspace_id}/configuration-versions", headers=self.headers, json=data - ) - if response.status_code == 201: - logging.info(f"Created configuration version.") - logging.debug(f"response: {response.text}") - else: - logging.error(f"Response status: {response.status_code}") - logging.error(f"Unsuccessful creating configuration version, response: {response.text}") - exit(os.EX_CONFIG) - - upload_url = json.loads(response.text)["data"]["attributes"]["upload-url"] - configuration_id = json.loads(response.text)["data"]["id"] - return upload_url, configuration_id - - def get_configuration_version_status(self, configuration_id: str) -> str: - """Show the configuration version and return it's status. The status will be pending when the - configuration version is initially created and will remain pending until configuration files are supplied via - upload, and while they are processed. The status will then be changed to 'uploaded'. Runs cannot be created - using pending or errored configuration versions. - - :param configuration_id: the configuration version id - :return: configuration version status - """ - response = requests.get(f"{self.api_url}/configuration-versions/{configuration_id}", headers=self.headers) - if response.status_code == HTTPStatus.OK: - logging.info(f"Retrieved configuration version info.") - logging.debug(f"response: {response.text}") - else: - logging.error(f"Response status: {response.status_code}") - logging.error(f"Unsuccessful retrieving configuration version info, response: {response.text}") - exit(os.EX_CONFIG) - - status = json.loads(response.text)["data"]["attributes"]["status"] - return status - - @staticmethod - def upload_configuration_files(upload_url: str, configuration_path: str) -> int: - """Uploads the configuration files. Auto-queue-runs is set to false when creating configuration version, - so conf will not be queued automatically. - - :param upload_url: upload url, returned when creating configuration version - :param configuration_path: path to tar.gz file containing config files (main.tf) - :return: the response code - """ - - headers = {"Content-Type": "application/octet-stream"} - with open(configuration_path, "rb") as configuration: - response = requests.put(upload_url, headers=headers, data=configuration.read()) - if response.status_code == HTTPStatus.OK: - logging.info(f"Uploaded configuration.") - logging.debug(f"response: {response.text}") - else: - logging.error(f"Response status: {response.status_code}") - logging.error(f"Unsuccessful uploading configuration, response: {response.text}") - exit(os.EX_CONFIG) - - return response.status_code - - def create_run(self, workspace_id: str, target_addrs: str = None, message: str = "") -> str: - """Creates a run, optionally targeted at a target address. If auto-apply is set to true the run will be applied - afterwards as well. 
- - :param workspace_id: the workspace id - :param target_addrs: the target address (id of the module/resource) - :param message: additional message that will be displayed at the terraform cloud run - :return: the run id - """ - - data = { - "data": { - "attributes": {"message": message}, - "type": "runs", - "relationships": {"workspace": {"data": {"type": "workspaces", "id": workspace_id}}}, - } - } - if target_addrs: - data["data"]["attributes"]["target-addrs"] = [target_addrs] - - response = requests.post(f"{self.api_url}/runs", headers=self.headers, json=data) - if response.status_code == HTTPStatus.CREATED: - logging.info(f"Created run.") - logging.debug(f"response: {response.text}") - else: - logging.error(f"Response status: {response.status_code}") - logging.error(f"Unsuccessful creating run, response: {response.text}") - exit(os.EX_CONFIG) - - run_id = json.loads(response.text)["data"]["id"] - return run_id - - def get_run_details(self, run_id: str) -> dict: - """Get details on a run identified by its id. - :param run_id: the run id - :return: the response text - """ - - response = requests.get(f"{self.api_url}/runs/{run_id}", headers=self.headers) - if not response.status_code == HTTPStatus.OK: - logging.error(f"Response status: {response.status_code}") - logging.error(f"Unsuccessful retrieving run details, response: {response.text}") - exit(os.EX_CONFIG) - - return json.loads(response.text) - - def plan_variable_changes( - self, new_vars: List[TerraformVariable], workspace_id: str - ) -> Tuple[ - List[TerraformVariable], - List[Tuple[TerraformVariable, TerraformVariable]], - List[TerraformVariable], - List[TerraformVariable], - ]: - """Compares the current variables in the workspace with a list of new variables. It sorts the new variables in - one of 4 different categories and adds them to the corresponding list. Sensitive variables can never be - 'unchanged'. - - :param new_vars: list of potential new variables where each variable is a variable attributes dict - :param workspace_id: the workspace id - :return: lists of variables in different categories (add, edit, unchanged, delete). - add: list of attributes dicts - edit: list of tuples with (attributes dict, var id, old value) - unchanged: list of attributes dicts - delete: list of tuples with (var key, var id, old value) - """ - - add: List[TerraformVariable] = [] - edit: List[Tuple[TerraformVariable, TerraformVariable]] = [] - unchanged: List[TerraformVariable] = [] - - # Get existing variables - old_vars = self.list_workspace_variables(workspace_id) - - # Make dict with old variable keys as keys - old_var_ids = {} - for old_var in old_vars: - old_var_ids[old_var.key] = old_var - - # Check which new variables need to be updated and which need to be added - for new_var in new_vars: - # check if variable already exists - if new_var.key in old_var_ids.keys(): - old_var = old_var_ids[new_var.key] - - # check if values of old and new are the same - if new_var.sensitive or new_var.value != old_var.value: - # Assign id of old variable to new variable - new_var.var_id = old_var.var_id - edit.append((old_var, new_var)) - else: - unchanged.append(new_var) - else: - add.append(new_var) - - # Compare old variables with new variables and check which are 'extra' in the old vars - delete = list(set(old_vars) - set(new_vars)) - - return add, edit, unchanged, delete - - def update_workspace_variables(self, add: list, edit: list, delete: list, workspace_id: str): - """Update workspace accordingly to the planned changes. 
- - :param add: list of attributes dicts of new variables - :param edit: list of tuples with (attributes dict, var id, old value) of existing variables that will be updated - :param delete: list of tuples with (var key, var id, old value) of variables that will be deleted - :param workspace_id: the workspace id - :return: None - """ - - for var in add: - self.add_workspace_variable(var, workspace_id) - for old_var, new_var in edit: - self.update_workspace_variable(new_var, workspace_id) - for var in delete: - self.delete_workspace_variable(var, workspace_id) diff --git a/observatory-platform/observatory/platform/terraform/terraform_builder.py b/observatory-platform/observatory/platform/terraform/terraform_builder.py deleted file mode 100644 index 01f833eba..000000000 --- a/observatory-platform/observatory/platform/terraform/terraform_builder.py +++ /dev/null @@ -1,252 +0,0 @@ -# Copyright 2020 Curtin University -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# Author: James Diprose - - -import os -import re -import shutil -import subprocess -from subprocess import Popen -from typing import Tuple, Optional, List - -from observatory.platform.cli.cli_utils import indent, INDENT1 -from observatory.platform.config import module_file_path -from observatory.platform.docker.platform_runner import PlatformRunner -from observatory.platform.observatory_config import TerraformConfig -from observatory.platform.utils.jinja2_utils import render_template -from observatory.platform.utils.proc_utils import stream_process - - -def should_ignore(name, patterns): - """Check if a name matches any of the provided regex patterns.""" - return any(re.search(pattern, name) for pattern in patterns) - - -def copy_dir(src: str, dst: str, ignore_patterns: Optional[List] = None, verbose: bool = False): - if ignore_patterns is None: - ignore_patterns = set() - - if not os.path.exists(dst): - os.makedirs(dst) - for item in os.listdir(src): - s = os.path.join(src, item) - d = os.path.join(dst, item) - if os.path.isdir(s): - if not should_ignore(item, ignore_patterns): - copy_dir(s, d, ignore_patterns, verbose) - else: - if verbose: - print(f"Ignored: {s}") - else: - if verbose: - print(f"Copy {s} -> {d}") - shutil.copy2(s, d) - - -class TerraformBuilder: - def __init__(self, config: TerraformConfig, debug: bool = False): - """Create a TerraformBuilder instance, which is used to build, start and stop an Observatory Platform instance. - - :param config: the TerraformConfig. - :param debug: whether to print debug statements. 
- """ - - self.config = config - self.api_package_path = module_file_path("observatory.api", nav_back_steps=-3) - self.terraform_path = module_file_path("observatory.platform.terraform") - self.api_path = module_file_path("observatory.api.server") - self.debug = debug - - self.build_path = os.path.join(config.observatory.observatory_home, "build", "terraform") - self.platform_runner = PlatformRunner(config=config, docker_build_path=os.path.join(self.build_path, "docker")) - self.packages_build_path = os.path.join(self.build_path, "packages") - self.terraform_build_path = os.path.join(self.build_path, "terraform") - os.makedirs(self.packages_build_path, exist_ok=True) - os.makedirs(self.terraform_build_path, exist_ok=True) - - @property - def is_environment_valid(self) -> bool: - """Return whether the environment for building the Packer image is valid. - - :return: whether the environment for building the Packer image is valid. - """ - - return all([self.packer_exe_path is not None]) - - @property - def packer_exe_path(self) -> str: - """The path to the Packer executable. - - :return: the path or None. - """ - - return shutil.which("packer") - - @property - def gcloud_exe_path(self) -> str: - """The path to the Google Cloud SDK executable. - - :return: the path or None. - """ - - return shutil.which("gcloud") - - def make_files(self): - # Clear terraform/packages path - if os.path.exists(self.packages_build_path): - shutil.rmtree(self.packages_build_path) - os.makedirs(self.packages_build_path) - - # Copy local packages - for package in self.config.python_packages: - if package.type == "editable": - destination_path = os.path.join(self.packages_build_path, package.name) - copy_dir( - package.host_package, - destination_path, - ignore_patterns=[ - r"^\.git$", - r"^\.eggs$", - r"^__pycache__$", - r"^venv$", - r".*\.egg-info$", - r"^fixtures$", - r"^\.idea$", - ], - verbose=self.debug, - ) - - # Clear terraform/terraform path, but keep the .terraform folder and other hidden files necessary for terraform. - if os.path.exists(self.terraform_build_path): - terraform_files = os.listdir(self.terraform_build_path) - terraform_files_to_delete = [file for file in terraform_files if not file.startswith(".")] - for file in terraform_files_to_delete: - if os.path.isfile(os.path.join(self.terraform_build_path, file)): - os.remove(os.path.join(self.terraform_build_path, file)) - else: - shutil.rmtree(os.path.join(self.terraform_build_path, file)) - else: - os.makedirs(self.terraform_build_path) - - # Copy terraform files into build/terraform - copy_dir( - self.terraform_path, - self.terraform_build_path, - ignore_patterns=[r"^\.git$", r"^\.eggs$", r"^__pycache__$", r"^venv$", r".*\.egg-info$", r"^\.idea$"], - verbose=self.debug, - ) - - # Make startup scripts - self.make_startup_script(True, "startup-main.tpl") - self.make_startup_script(False, "startup-worker.tpl") - - def make_startup_script(self, is_airflow_main_vm: bool, file_name: str): - # Load and render template - template_path = os.path.join(self.terraform_path, "startup.tpl.jinja2") - render = render_template(template_path, is_airflow_main_vm=is_airflow_main_vm) - - # Save file - output_path = os.path.join(self.terraform_build_path, file_name) - with open(output_path, "w") as f: - f.write(render) - - def build_terraform(self): - """Build the Observatory Platform Terraform files. - - :return: None. 
- """ - - self.make_files() - self.platform_runner.make_files() - - def install_packer_plugins(self) -> Tuple[str, str, int]: - """Install the necessary plugins for packer.""" - - # Install the Packer plugins by doing an init on the observatory image config file - args = ["packer", "init", "observatory-image.json.pkr.hcl"] - - if self.debug: - print("Executing subprocess:") - print(indent(f"Command: {subprocess.list2cmdline(args)}", INDENT1)) - print(indent(f"Cwd: {self.terraform_build_path}", INDENT1)) - - proc: Popen = subprocess.Popen( - args, stdout=subprocess.PIPE, stderr=subprocess.PIPE, cwd=self.terraform_build_path - ) - - # Wait for results - # Debug always true here because otherwise nothing gets printed and you don't know what the state of the - # image building is - output, error = stream_process(proc, True) - return output, error, proc.returncode - - def build_image(self) -> Tuple[str, str, int]: - """Build the Observatory Platform Google Compute image with Packer. - - :return: output and error stream results and proc return code. - """ - - # Make Terraform files - self.build_terraform() - - # Load template - template_vars = { - "credentials_file": self.config.google_cloud.credentials, - "project_id": self.config.google_cloud.project_id, - "zone": self.config.google_cloud.zone, - "environment": self.config.backend.environment.value, - } - variables = [] - for key, val in template_vars.items(): - variables.append("-var") - variables.append(f"{key}={val}") - - # Install the necessary Packer plugins - self.install_packer_plugins() - - # Build the containers first - args = ["packer", "build"] + variables + ["-force", "observatory-image.json.pkr.hcl"] - - if self.debug: - print("Executing subprocess:") - print(indent(f"Command: {subprocess.list2cmdline(args)}", INDENT1)) - print(indent(f"Cwd: {self.terraform_build_path}", INDENT1)) - - proc: Popen = subprocess.Popen( - args, stdout=subprocess.PIPE, stderr=subprocess.PIPE, cwd=self.terraform_build_path - ) - - # Wait for results - # Debug always true here because otherwise nothing gets printed and you don't know what the state of the - # image building is - output, error = stream_process(proc, True) - return output, error, proc.returncode - - def gcloud_activate_service_account(self) -> Tuple[str, str, int]: - args = ["gcloud", "auth", "activate-service-account", "--key-file", self.config.google_cloud.credentials] - - if self.debug: - print("Executing subprocess:") - print(indent(f"Command: {subprocess.list2cmdline(args)}", INDENT1)) - print(indent(f"Cwd: {self.api_package_path}", INDENT1)) - - proc: Popen = subprocess.Popen(args, stdout=subprocess.PIPE, stderr=subprocess.PIPE, cwd=self.api_package_path) - - # Wait for results - # Debug always true here because otherwise nothing gets printed and you don't know what the state of the - # image building is - output, error = stream_process(proc, True) - return output, error, proc.returncode diff --git a/observatory-platform/observatory/platform/terraform/variables.tf b/observatory-platform/observatory/platform/terraform/variables.tf deleted file mode 100644 index 36d391539..000000000 --- a/observatory-platform/observatory/platform/terraform/variables.tf +++ /dev/null @@ -1,99 +0,0 @@ -variable "environment" { - description = "The environment type: develop, staging or production." - type = string -} - -variable "observatory" { - description = < TerraformApi: - """Construct a TerraformApi object from the Airflow connection. - - :return: TerraformApi object. 
- """ - - token = get_airflow_connection_password(self.terraform_conn_id) - return TerraformApi(token) - - @property - def workspace_id(self) -> str: - """Uses terraform API and workspace name to get the id of this workspace. - - :return: workspace id - """ - - return self.terraform_api.workspace_id(self.organisation, self.workspace) - - def get_vm_info(self) -> Tuple[Optional[VirtualMachine], Optional[TerraformVariable]]: - """Get the VirtualMachine data object, and TerraformVariable object for airflow_worker_vm. - - :return VirtualMachine and TerraformVariable objects. - """ - - variables = self.terraform_api.list_workspace_variables(self.workspace_id) - - for var in variables: - if var.key == TERRAFORM_CREATE_VM_KEY: - return VirtualMachine.from_hcl(var.value), var - - return None, None - - def update_terraform_vm_create_variable(self, value: bool): - """Update the Terraform VM create flag. - - :param value: New value to set. - """ - - vm, vm_var = self.get_vm_info() - vm.create = value - logging.info(f"vm.create: {vm.create}") - vm_var.value = vm.to_hcl() - - self.terraform_api.update_workspace_variable(vm_var, self.workspace_id) - - def create_terraform_run(self, *, dag_id: str, start_date: pendulum.DateTime) -> str: - """Create a Terraform run and return the run ID. - - :param dag_id: DAG ID. - :param start_date: Task instance start date. - :return Terraform run ID. - """ - - message = f'Triggered from airflow DAG "{dag_id}" at {start_date}' - run_id = self.terraform_api.create_run(self.workspace_id, TARGET_ADDRS, message) - logging.info(f"Terraform run_id: {run_id}") - - return run_id - - def check_terraform_run_status(self, *, ti: TaskInstance, execution_date: pendulum.DateTime, run_id: str): - """Retrieve the terraform run status until it is in a finished state, either successful or errored. See - https://www.terraform.io/docs/cloud/api/run.html for possible run_status values. - If the run status is not successful and the environment isn't develop a warning message will be sent to a Slack - channel. - - :param ti: Task instance. - :param execution_date: DagRun execution date. - :param run_id: The run id of the Terraform run - :return: None - """ - - run_status = None - while run_status not in [ - "planned_and_finished", - "applied", - "errored", - "discarded", - "canceled", - "force_canceled", - ]: - run_details = self.terraform_api.get_run_details(run_id) - run_status = run_details["data"]["attributes"]["status"] - - logging.info(f"Run status: {run_status}") - comments = f"Terraform run status: {run_status}" - logging.info(f'Sending slack notification: "{comments}"') - send_slack_msg(ti=ti, execution_date=execution_date, comments=comments, slack_conn_id=self.slack_conn_id) - - -def parse_datetime(dt: str) -> Optional[pendulum.DateTime]: - """Try to parse datetime using pendulum.parse. Do not try to parse None. - - :param dt: Datetime string. - :return: Datetime object, or None if failed. - """ - - if dt is None: - return None - - return pendulum.parse(dt) - - -class VmCreateWorkflow(Workflow): - """Workflow to spin up an Airflow worker VM (with Terraform).""" - - def __init__( - self, - *, - dag_id: str, - terraform_organisation: str, - terraform_workspace: str, - terraform_conn_id=AirflowConns.TERRAFORM, - slack_conn_id=AirflowConns.SLACK, - start_date: pendulum.DateTime = pendulum.datetime(2020, 7, 1), - schedule: str = "@weekly", - **kwargs, - ): - """Construct the workflow. - - :param dag_id: the DAG id. - :param terraform_organisation: the Terraform Organisation ID. 
- :param terraform_workspace: the full Terraform Workspace name. - :param terraform_conn_id: the Airflow Connection ID for the Terraform credentials. - :param slack_conn_id: the Airflow Connection ID for the Slack credentials. - :param start_date: Start date for the DAG. - :param schedule: Schedule interval for the DAG. - :param kwargs: to catch any extra kwargs passed during DAG creation. - """ - - super().__init__( - dag_id=dag_id, - start_date=start_date, - schedule=schedule, - catchup=False, - max_active_runs=1, - airflow_conns=[terraform_conn_id, slack_conn_id], - tags=[Tag.observatory_platform], - ) - - self.terraform_organisation = terraform_organisation - self.terraform_workspace = terraform_workspace - self.vm_api = TerraformVirtualMachineAPI( - organisation=self.terraform_organisation, - workspace=self.terraform_workspace, - terraform_conn_id=terraform_conn_id, - slack_conn_id=slack_conn_id, - ) - self.add_setup_task(self.check_dependencies) - self.add_setup_task(self.check_vm_state) - self.add_task(self.update_terraform_variable) - self.add_task(self.run_terraform) - self.add_task(self.check_run_status) - self.add_task(self.cleanup, trigger_rule="none_failed") - - def make_release(self, **kwargs) -> None: - """Required for Workflow class. - - :param kwargs: Unused. - :return: None. - """ - - return None - - def check_vm_state(self, **kwargs) -> bool: - """Checks if VM is running. Proceed only if VM is not already running. - - :param kwargs: Unused. - :return: Whether to continue. - """ - - vm, _ = self.vm_api.get_vm_info() - logging.info(f"VM is on: {vm.create}") - return not vm.create - - def update_terraform_variable(self, _, **kwargs): - """Update Terraform variable for VM to running state. - - :param kwargs: Unused. - """ - - self.vm_api.update_terraform_vm_create_variable(True) - - def run_terraform(self, _, **kwargs): - """Runs terraform configuration. The current task start time, previous task start time, and Terraform run ID will be pushed to XComs. - - :param kwargs: the context passed from the PythonOperator. See - https://airflow.apache.org/docs/stable/macros-ref.html for a list of the keyword arguments that are passed to - this argument. - """ - - ti: TaskInstance = kwargs["ti"] - - prev_start_time_vm = ti.xcom_pull(key=XCOM_START_TIME_VM, include_prior_dates=True) - ti.xcom_push(XCOM_PREV_START_TIME_VM, prev_start_time_vm) - ti.xcom_push(XCOM_START_TIME_VM, ti.start_date.isoformat()) - - run_id = self.vm_api.create_terraform_run(dag_id=self.dag_id, start_date=ti.start_date) - ti.xcom_push(XCOM_TERRAFORM_RUN_ID, run_id) - - def check_run_status(self, _, **kwargs): - """Retrieve the terraform run status until it is in a finished state, either successful or errored. See - https://www.terraform.io/docs/cloud/api/run.html for possible run_status values. - If the run status is not successful and the environment isn't develop a warning message will be sent to a slack - channel. - - :param kwargs: the context passed from the PythonOperator. See - https://airflow.apache.org/docs/stable/macros-ref.html for a list of the keyword arguments that are passed to - this argument. - """ - - ti: TaskInstance = kwargs["ti"] - execution_date = kwargs["execution_date"] - - run_id = ti.xcom_pull(key=XCOM_TERRAFORM_RUN_ID, task_ids=self.run_terraform.__name__) - self.vm_api.check_terraform_run_status(ti=ti, execution_date=execution_date, run_id=run_id) - - def cleanup(self, _, **kwargs): - """Delete stale XCom messages. - - :param kwargs: the context passed from the PythonOperator. 
See - https://airflow.apache.org/docs/stable/macros-ref.html for a list of the keyword arguments that are passed to - this argument. - """ - - execution_date = kwargs["execution_date"] - delete_old_xcoms(dag_id=self.dag_id, execution_date=execution_date, retention_days=15) - - -class VmDestroyWorkflow(Workflow): - """Workflow to teardown an Airflow worker VM (with Terraform).""" - - def __init__( - self, - *, - dag_id: str, - terraform_organisation: str, - terraform_workspace: str, - dags_watch_list: List[str], - vm_create_dag_id: str = "vm_create", - terraform_conn_id=AirflowConns.TERRAFORM, - slack_conn_id=AirflowConns.SLACK, - start_date: pendulum.DateTime = pendulum.datetime(2020, 1, 1), - schedule: str = "*/10 * * * *", - **kwargs, - ): - """Construct the workflow. - - :param dag_id: the DAG id. - :param terraform_organisation: the Terraform Organisation ID. - :param terraform_workspace: the full Terraform Workspace name. - :param dags_watch_list: the list of DAGs to watch for before destroying the VM. - :param terraform_conn_id: the Airflow Connection ID for the Terraform credentials. - :param slack_conn_id: the Airflow Connection ID for the Slack credentials. - :param start_date: Start date for the DAG. - :param schedule: Schedule interval for the DAG. - :param kwargs: to catch any extra kwargs passed during DAG creation. - """ - - super().__init__( - dag_id=dag_id, - start_date=start_date, - schedule=schedule, - catchup=False, - max_active_runs=1, - airflow_conns=[terraform_conn_id, slack_conn_id], - tags=[Tag.observatory_platform], - ) - - self.terraform_organisation = terraform_organisation - self.terraform_workspace = terraform_workspace - self.dags_watch_list = dags_watch_list - self.vm_create_dag_id = vm_create_dag_id - self.slack_conn_id = slack_conn_id - self.vm_api = TerraformVirtualMachineAPI( - organisation=self.terraform_organisation, - workspace=self.terraform_workspace, - terraform_conn_id=terraform_conn_id, - slack_conn_id=slack_conn_id, - ) - - self.add_setup_task(self.check_dependencies) - self.add_setup_task(self.check_vm_state) - self.add_setup_task(self.check_dags_status) - self.add_task(self.update_terraform_variable) - self.add_task(self.run_terraform) - self.add_task(self.check_run_status) - self.add_task(self.cleanup, trigger_rule="none_failed") - - def make_release(self, **kwargs) -> None: - """Required for Workflow class. - - :param kwargs: Unused. - :return: None. - """ - - return None - - def check_vm_state(self, **kwargs) -> bool: - """Checks if VM is running. Proceed only if VM is running. - - :param kwargs: Unused. - :return: Whether to continue. - """ - - vm, _ = self.vm_api.get_vm_info() - logging.info(f"VM is on: {vm.create}") - return vm.create - - def check_dags_status(self, **kwargs): - """Check if all expected runs for the DAGs in the watchlist are successful. If they are the task, then proceed, otherwise check how long the VM has run for, and skip the rest of the workflow. - - :param kwargs: the context passed from the PythonOperator. See - https://airflow.apache.org/docs/stable/macros-ref.html for a list of the keyword arguments that are passed to - this argument. 
- :return: id of task which should be executed next - """ - - ti: TaskInstance = kwargs["ti"] - destroy_worker_vm = True - - prev_start_time_vm = ti.xcom_pull( - key=XCOM_PREV_START_TIME_VM, - task_ids=VmCreateWorkflow.run_terraform.__name__, - dag_id=self.vm_create_dag_id, - include_prior_dates=True, - ) - prev_start_time_vm = parse_datetime(prev_start_time_vm) - - start_time_vm = ti.xcom_pull( - key=XCOM_START_TIME_VM, - task_ids=VmCreateWorkflow.run_terraform.__name__, - dag_id=self.vm_create_dag_id, - include_prior_dates=True, - ) - start_time_vm = parse_datetime(start_time_vm) - - destroy_time_vm = ti.xcom_pull( - key=XCOM_DESTROY_TIME_VM, - task_ids=self.run_terraform.__name__, - dag_id=self.dag_id, - include_prior_dates=True, - ) - destroy_time_vm = parse_datetime(destroy_time_vm) - - logging.info( - f"prev_start_time_vm: {prev_start_time_vm}, start_time_vm: {start_time_vm}, " - f"destroy_time_vm: {destroy_time_vm}\n" - ) - - # Load VM DAGs watch list - for dag_id in self.dags_watch_list: - dagbag = DagBag() - dag = dagbag.get_dag(dag_id) - logging.info(f"Dag id: {dag_id}") - - # vm turned on manually and never turned on before - if not start_time_vm and not prev_start_time_vm: - logging.warning("Both start_time_vm and prev_start_time_vm are None. Unsure whether to turn off DAG.") - destroy_worker_vm = False - break - - # returns last execution date of previous vm cycle or None if a DAG is running - last_execution_prev = self._get_last_execution_prev(dag, dag_id, prev_start_time_vm) - if not last_execution_prev: - destroy_worker_vm = False - break - - logging.info(f"Execution date of last DAG before prev_start_time_vm: {last_execution_prev}\n") - - if destroy_time_vm: - if start_time_vm < destroy_time_vm: - # If the vm is on, but there's no start_time_vm it must have been turned on manually. - logging.warning( - "start_time_vm is before destroy_time_vm. Perhaps the vm was turned on " - "manually in between. This task will continue with the given start_time_vm." - ) - - # get a backfill of all expected runs between last execution date of prev cycle and the time the vm was - # started. create iterator starting at the latest planned schedule before start_time_vm - cron_iter = croniter(dag.normalized_schedule_interval, dag.previous_schedule(start_time_vm)) - - # the last_execution_current is expected 1 schedule interval before the DAGs 'previous_schedule', - # because airflow won't trigger a DAG until 1 schedule interval after the 'execution_date'. - last_execution_current = cron_iter.get_prev(datetime) - - # if DAG is not set to catchup any backfill, the only run date is the last one. - if dag.catchup: - execution_dates = dag.get_run_dates(last_execution_prev, last_execution_current) - else: - execution_dates = [last_execution_current] - - # for each execution date check if state is success. This can't be done in all_dag_runs above, because the - # dag_run might not be in all_dag_runs yet, because it is not scheduled yet. - destroy_worker_vm = self._check_success_runs(dag_id, execution_dates) - if destroy_worker_vm is False: - break - - logging.info(f"Destroying worker VM: {destroy_worker_vm}") - - # If not destroying vm, check VM runtime. - if not destroy_worker_vm: - self.check_runtime_vm(start_time_vm, **kwargs) - - return destroy_worker_vm - - def check_runtime_vm(self, start_time_vm: Optional[datetime], **kwargs): - """Checks how long the VM has been turned on based on the xcom value from the terraform run task. 
- A warning message will be sent in a slack channel if it has been on longer than the warning limit, - the environment isn't develop and a message hasn't been sent already in the last x hours. - - :param start_time_vm: Start time of the vm - :param kwargs: the context passed from the PythonOperator. See - https://airflow.apache.org/docs/stable/macros-ref.html for a list of the keyword arguments that are passed to - this argument. - :return: None - """ - - ti: TaskInstance = kwargs["ti"] - last_warning_time = ti.xcom_pull( - key=XCOM_WARNING_TIME, - task_ids=ti.task_id, - dag_id=self.dag_id, - include_prior_dates=True, - ) - last_warning_time = parse_datetime(last_warning_time) - - if start_time_vm: - # calculate number of hours passed since start time vm and now - hours_on = (ti.start_date - start_time_vm).total_seconds() / 3600 - logging.info( - f"Start time VM: {start_time_vm}, hours passed since start time: {hours_on}, warning limit: " - f"{VM_RUNTIME_H_WARNING}" - ) - - # check if a warning has been sent previously and if so, how many hours ago - if last_warning_time: - hours_since_warning = (ti.start_date - last_warning_time).total_seconds() / 3600 - else: - hours_since_warning = None - - # check if the VM has been on longer than the limit - if hours_on > VM_RUNTIME_H_WARNING: - # check if no warning was sent before or last time was longer ago than warning frequency - if not hours_since_warning or hours_since_warning > WARNING_FREQUENCY_H: - comments = ( - f"Worker VM has been on since {start_time_vm}. No. hours passed since then: " - f"{hours_on}." - f" Warning limit: {VM_RUNTIME_H_WARNING}H" - ) - execution_date = kwargs["execution_date"] - send_slack_msg( - ti=ti, execution_date=execution_date, comments=comments, slack_conn_id=self.slack_conn_id - ) - - ti.xcom_push(XCOM_WARNING_TIME, ti.start_date.isoformat()) - else: - logging.info(f"Start time VM unknown.") - - def update_terraform_variable(self, _, **kwargs): - """Update Terraform variable for VM to running state. - - :param kwargs: Unused. - """ - - self.vm_api.update_terraform_vm_create_variable(False) - - def run_terraform(self, _, **kwargs): - """Runs terraform configuration. The current task start time, previous task start time, and Terraform run ID will be pushed to XComs. - - :param _: None. - :param kwargs: the context passed from the PythonOperator. See - https://airflow.apache.org/docs/stable/macros-ref.html for a list of the keyword arguments that are passed to - this argument. - """ - - ti: TaskInstance = kwargs["ti"] - ti.xcom_push(XCOM_DESTROY_TIME_VM, ti.start_date.isoformat()) - run_id = self.vm_api.create_terraform_run(dag_id=self.dag_id, start_date=ti.start_date) - ti.xcom_push(XCOM_TERRAFORM_RUN_ID, run_id) - - def check_run_status(self, _, **kwargs): - """Retrieve the terraform run status until it is in a finished state, either successful or errored. See - https://www.terraform.io/docs/cloud/api/run.html for possible run_status values. - If the run status is not successful and the environment isn't develop a warning message will be sent to a slack - channel. - - :param _: None. - :param kwargs: the context passed from the PythonOperator. See - https://airflow.apache.org/docs/stable/macros-ref.html for a list of the keyword arguments that are passed to - this argument. 
- """ - - ti: TaskInstance = kwargs["ti"] - execution_date = kwargs["execution_date"] - - run_id = ti.xcom_pull(key=XCOM_TERRAFORM_RUN_ID, task_ids=self.run_terraform.__name__) - self.vm_api.check_terraform_run_status(ti=ti, execution_date=execution_date, run_id=run_id) - - def cleanup(self, _, **kwargs): - """Delete stale XCom messages. - - :param kwargs: the context passed from the PythonOperator. See - https://airflow.apache.org/docs/stable/macros-ref.html for a list of the keyword arguments that are passed to - this argument. - """ - - execution_date = kwargs["execution_date"] - delete_old_xcoms(dag_id=self.dag_id, execution_date=execution_date, retention_days=15) - - def _get_last_execution_prev( - self, dag: DAG, dag_id: str, prev_start_time_vm: Union[datetime, None] - ) -> Union[datetime, None]: - """Find the execution date of the last DAG run before the previous time the VM was turned on. - If there aren't any DAG runs before this time or the time is None (first/second time turning off VM) the - execution date is set to the start_date of the DAG instead. - - If a DAG is currently running it will return None and the remaining tasks are skipped. - - :param dag: DAG object - :param dag_id: the dag id - :param prev_start_time_vm: previous time the VM was turned on - :return: execution date or None - """ - - # Get execution date of the last run before previous start date - all_dag_runs = DagRun.find(dag_id=dag_id) - # sort dag runs by start datetime, newest first - for dag_run in sorted(all_dag_runs, key=lambda x: x.start_date, reverse=True): - if dag_run.state == "running": - logging.info("DAG is currently running.") - return None - # None if first time running destroy - if prev_start_time_vm: - if pendulum.instance(dag_run.start_date) < prev_start_time_vm: - # get execution date of last run from when the VM was previously on - return dag_run.execution_date - - # No runs executed previously - if prev_start_time_vm: - logging.info("No DAG runs that started before prev_start_time_vm.") - else: - # First time running destroy_vm, no previous start date available - logging.info("No prev_start_time_vm.") - logging.info("Setting last execution date to start_date of DAG.") - last_execution_prev = dag.default_args["start_date"] - - return last_execution_prev - - def _check_success_runs(self, dag_id: str, execution_dates: list) -> bool: - """For each date in the execution dates it checks if a DAG run exists and if so if the state is set to success. - - Only if both of these are true for all dates it will return True. 
- - :param dag_id: the dag id - :param execution_dates: list of execution dates - :return: True or False - """ - for date in execution_dates: - dag_runs = DagRun.find(dag_id=dag_id, execution_date=date) - if not dag_runs: - logging.info(f"Expected dag run on {date} has not been scheduled yet") - return False - - for dag_run in dag_runs: - logging.info( - f"id: {dag_run.dag_id}, start date: {dag_run.start_date}, execution date: " - f"{dag_run.execution_date}, state: {dag_run.state}" - ) - if dag_run.state != DagRunState.SUCCESS: - return False - return True diff --git a/observatory-platform/observatory/platform/workflows/workflow.py b/observatory-platform/observatory/platform/workflows/workflow.py deleted file mode 100644 index e7e3510aa..000000000 --- a/observatory-platform/observatory/platform/workflows/workflow.py +++ /dev/null @@ -1,596 +0,0 @@ -# Copyright 2020 Curtin University -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# Author: Aniek Roelofs, James Diprose, Tuan Chien - -from __future__ import annotations - -import contextlib -import copy -import logging -import shutil -from abc import ABC, abstractmethod -from functools import partial -from typing import Any, Callable, Dict, List, Union, Optional - -try: - from typing import Protocol -except ImportError: - from typing_extensions import Protocol -import os -import pendulum -from airflow import DAG - -from airflow.exceptions import AirflowException -from airflow.models.baseoperator import chain -from airflow.operators.python import PythonOperator, ShortCircuitOperator -from airflow.operators.bash import BaseOperator, BashOperator -from observatory.platform.airflow import ( - check_connections, - check_variables, - on_failure_callback, - delete_old_xcoms, -) -from observatory.platform.airflow import get_data_path -from observatory.platform.observatory_config import CloudWorkspace - - -DATE_TIME_FORMAT = "YYYY-MM-DD_HH:mm:ss" - - -def make_snapshot_date(**kwargs) -> pendulum.DateTime: - """Make a snapshot date""" - - return kwargs["data_interval_end"] - - -def make_workflow_folder(dag_id: str, run_id: str, *subdirs: str) -> str: - """Return the path to this dag release's workflow folder. Will also create it if it doesn't exist - - :param dag_id: The ID of the dag. This is used to find/create the workflow folder - :param run_id: The Airflow DAGs run ID. Examples: "scheduled__2023-03-26T00:00:00+00:00" or "manual__2023-03-26T00:00:00+00:00". - :param subdirs: The folder path structure (if any) to create inside the workspace. e.g. 
'download' or 'transform' - :return: the path of the workflow folder - """ - - path = os.path.join(get_data_path(), dag_id, run_id, *subdirs) - os.makedirs(path, exist_ok=True) - return path - - -def check_workflow_inputs(workflow: Workflow, check_cloud_workspace=True) -> None: - """Checks a Workflow object for validity - - :param workflow: The Workflow object - :param check_cloud_workspace: Whether to check the CloudWorkspace field, defaults to True - :raises AirflowException: Raised if there are invalid fields - """ - invalid_fields = [] - if not workflow.dag_id or not isinstance(workflow.dag_id, str): - invalid_fields.append("dag_id") - - if check_cloud_workspace: - cloud_workspace = workflow.cloud_workspace - if not isinstance(cloud_workspace, CloudWorkspace): - invalid_fields.append("cloud_workspace") - else: - required_fields = {"project_id": str, "data_location": str, "download_bucket": str, "transform_bucket": str} - for field_name, field_type in required_fields.items(): - field_value = getattr(cloud_workspace, field_name, None) - if not isinstance(field_value, field_type) or not field_value: - invalid_fields.append(f"cloud_workspace.{field_name}") - - if cloud_workspace.output_project_id is not None: - if not isinstance(cloud_workspace.output_project_id, str) or not cloud_workspace.output_project_id: - invalid_fields.append("cloud_workspace.output_project_id") - - if invalid_fields: - raise AirflowException(f"Workflow input fields invalid: {invalid_fields}") - - -def cleanup(dag_id: str, execution_date: str, workflow_folder: str = None, retention_days=31) -> None: - """Delete all files, folders and XComs associated from a release. - - :param dag_id: The ID of the DAG to remove XComs - :param execution_date: The execution date of the DAG run - :param workflow_folder: The top-level workflow folder to clean up - :param retention_days: How many days of Xcom messages to retain - """ - if workflow_folder: - try: - shutil.rmtree(workflow_folder) - except FileNotFoundError as e: - logging.warning(f"No such file or directory {workflow_folder}: {e}") - - delete_old_xcoms(dag_id=dag_id, execution_date=execution_date, retention_days=retention_days) - - -class WorkflowBashOperator(BashOperator): - def __init__(self, workflow: Workflow, *args, **kwargs): - super().__init__(*args, **kwargs) - self.workflow = workflow - - def render_template(self, content, context, *args, **kwargs): - # Make release and set in context - obj = self.workflow.make_release(**context) - if isinstance(obj, list): - context["releases"] = obj - elif isinstance(obj, Release): - context["release"] = obj - - # Add workflow to context - if self.workflow is not None: - context["workflow"] = self.workflow - - else: - raise AirflowException( - f"WorkflowBashOperator.render_template: self.make_release returned an object of an invalid type (should be a list of Releases, or a single Release object): {type(obj)}" - ) - - return super().render_template(content, context, *args, **kwargs) - - -class Release: - def __init__(self, *, dag_id: str, run_id: str): - """Construct a Release instance - - :param dag_id: the DAG ID. - :param run_id: the DAG's run ID. 
- """ - - self.dag_id = dag_id - self.run_id = run_id - self.workflow_folder = make_workflow_folder(self.dag_id, run_id) - - def __str__(self): - return f"Release(dag_id={self.dag_id}, run_id={self.run_id})" - - -class SnapshotRelease(Release): - def __init__( - self, - *, - dag_id: str, - run_id: str, - snapshot_date: pendulum.DateTime, - ): - """Construct a SnapshotRelease instance - - :param dag_id: the DAG ID. - :param run_id: the DAG's run ID. - :param snapshot_date: the release date of the snapshot. - """ - - super().__init__(dag_id=dag_id, run_id=run_id) - self.snapshot_date = snapshot_date - - snapshot = f"snapshot_{snapshot_date.format(DATE_TIME_FORMAT)}" - self.download_folder = make_workflow_folder(self.dag_id, run_id, snapshot, "download") - self.extract_folder = make_workflow_folder(self.dag_id, run_id, snapshot, "extract") - self.transform_folder = make_workflow_folder(self.dag_id, run_id, snapshot, "transform") - - def __str__(self): - return f"SnapshotRelease(dag_id={self.dag_id}, run_id={self.run_id}, snapshot_date={self.snapshot_date})" - - -class PartitionRelease(Release): - def __init__( - self, - *, - dag_id: str, - run_id: str, - partition_date: pendulum.DateTime, - ): - """Construct a PartitionRelease instance - - :param dag_id: the DAG ID. - :param run_id: the DAG's run ID. - :param partition_date: the release date of the partition. - """ - - super().__init__(dag_id=dag_id, run_id=run_id) - self.partition_date = partition_date - - partition = f"partition_{partition_date.format(DATE_TIME_FORMAT)}" - self.download_folder = make_workflow_folder(self.dag_id, run_id, partition, "download") - self.extract_folder = make_workflow_folder(self.dag_id, run_id, partition, "extract") - self.transform_folder = make_workflow_folder(self.dag_id, run_id, partition, "transform") - - def __str__(self): - return f"PartitionRelease(dag_id={self.dag_id}, run_id={self.run_id}, partition_date={self.partition_date})" - - -class ChangefileRelease(Release): - def __init__( - self, - *, - dag_id: str, - run_id: str, - start_date: pendulum.DateTime = None, - end_date: pendulum.DateTime = None, - sequence_start: int = None, - sequence_end: int = None, - ): - """Construct a ChangefileRelease instance - - :param dag_id: the DAG ID. - :param run_id: the DAG's run ID. - :param start_date: the date of the first changefile processed in this release. - :param end_date: the date of the last changefile processed in this release. - :param sequence_start: the starting sequence number of files that make up this release. - :param sequence_end: the end sequence number of files that make up this release. - """ - - super().__init__(dag_id=dag_id, run_id=run_id) - self.start_date = start_date - self.end_date = end_date - self.sequence_start = sequence_start - self.sequence_end = sequence_end - - changefile = f"changefile_{start_date.format(DATE_TIME_FORMAT)}_to_{end_date.format(DATE_TIME_FORMAT)}" - self.download_folder = make_workflow_folder(self.dag_id, run_id, changefile, "download") - self.extract_folder = make_workflow_folder(self.dag_id, run_id, changefile, "extract") - self.transform_folder = make_workflow_folder(self.dag_id, run_id, changefile, "transform") - - def __str__(self): - return ( - f"Release(dag_id={self.dag_id}, run_id={self.run_id}, start_date={self.start_date}, " - f"end_date={self.end_date}, sequence_start={self.sequence_start}, sequence_end={self.sequence_end})" - ) - - -def set_task_state(success: bool, task_id: str, release: Release = None): - """Update the state of the Airflow task. 
- :param success: whether the task was successful or not. - :param task_id: the task id. - :param release: the release being processed. Optional. - :return: None. - """ - - if success: - msg = f"{task_id}: success" - if release is not None: - msg += f" {release}" - logging.info(msg) - else: - msg_failed = f"{task_id}: failed" - if release is not None: - msg_failed += f" {release}" - logging.error(msg_failed) - raise AirflowException(msg_failed) - - -class ReleaseFunction(Protocol): - def __call__(self, release: Release, **kwargs: Any) -> Any: - ... - - """ - :param release: A single instance of an AbstractRelease - :param kwargs: the context passed from the PythonOperator. See - https://airflow.apache.org/docs/stable/macros-ref.html for a list of the keyword arguments that are passed to - this argument. - :return: Any. - """ - - -class ListReleaseFunction(Protocol): - def __call__(self, releases: List[Release], **kwargs: Any) -> Any: - ... - - """ - :param releases: A list of AbstractRelease instances - :param kwargs: the context passed from the PythonOperator. See - https://airflow.apache.org/docs/stable/macros-ref.html for a list of the keyword arguments that are passed to - this argument. - :return: Any. - """ - - -WorkflowFunction = Union[ReleaseFunction, ListReleaseFunction] - - -class AbstractWorkflow(ABC): - @abstractmethod - def add_setup_task(self, func: Callable): - """Add a setup task, which is used to run tasks before 'Release' objects are created, e.g. checking - dependencies, fetching available releases etc. - - A setup task has the following properties: - - Has the signature 'def func(self, **kwargs) -> bool', where - kwargs is the context passed from the PythonOperator. See https://airflow.apache.org/docs/stable/macros-ref.html - for a list of the keyword arguments that are passed to this argument. - - Run by a ShortCircuitOperator, meaning that a setup task can stop a DAG prematurely, e.g. if there is - nothing to process. - - func Needs to return a boolean - - :param func: the function that will be called by the ShortCircuitOperator task. - :return: None. - """ - pass - - @abstractmethod - def add_operator(self, operator: BaseOperator): - """Add an Apache Airflow operator. - - :param operator: the Apache Airflow operator. - """ - pass - - @abstractmethod - def add_task(self, func: Callable): - """Add a task, which is used to process releases. A task has the following properties: - - - Has one of the following signatures 'def func(self, release: Release, **kwargs)' or 'def func(self, releases: List[Release], **kwargs)' - - kwargs is the context passed from the PythonOperator. See https://airflow.apache.org/docs/stable/macros-ref.html for a list of the keyword arguments that are passed to this argument. - - Run by a PythonOperator. - - :param func: the function that will be called by the PythonOperator task. - :return: None. - """ - pass - - @abstractmethod - def task_callable(self, func: WorkflowFunction, **kwargs) -> Any: - """Invoke a task callable. Creates a Release instance or Release instances and calls the given task method. - - :param func: the task method. - :param kwargs: the context passed from the PythonOperator. See - https://airflow.apache.org/docs/stable/macros-ref.html for a list of the keyword arguments that are passed - to this argument. - :return: Any. - """ - pass - - @abstractmethod - def make_release(self, **kwargs) -> Union[Release, List[Release]]: - """Make a release instance. 
The release is passed as an argument to the function (WorkflowFunction) that is - called in 'task_callable'. - - :param kwargs: the context passed from the PythonOperator. See - https://airflow.apache.org/docs/stable/macros-ref.html for a list of the keyword arguments that are passed - to this argument. - :return: A release instance or list of release instances - """ - pass - - @abstractmethod - def make_dag(self) -> DAG: - """Make an Airflow DAG for a workflow. - - :return: the DAG object. - """ - pass - - -def make_task_id(func: Callable, kwargs: Dict) -> str: - """Set a task_id from a func or kwargs. - - :param func: the task function. - :param kwargs: the task kwargs parameter. - :return: the task id. - """ - - task_id_key = "task_id" - - if task_id_key in kwargs: - task_id = kwargs["task_id"] - else: - task_id = func.__name__ - - return task_id - - -class Workflow(AbstractWorkflow): - RELEASE_INFO = "releases" - - def __init__( - self, - dag_id: str, - start_date: pendulum.DateTime, - schedule: str, - catchup: bool = False, - queue: str = "default", - max_retries: int = 3, - max_active_runs: int = 1, - airflow_vars: list = None, - airflow_conns: list = None, - tags: Optional[List[str]] = None, - ): - """Construct a Workflow instance. - - :param dag_id: the id of the DAG. - :param start_date: the start date of the DAG. - :param schedule: the schedule interval of the DAG. - :param catchup: whether to catchup the DAG or not. - :param queue: the Airflow queue name. - :param max_retries: the number of times to retry each task. - :param max_active_runs: the maximum number of DAG runs that can be run at once. - :param airflow_vars: list of airflow variable keys, for each variable it is checked if it exists in airflow - :param airflow_conns: list of airflow connection keys, for each connection it is checked if it exists in airflow - :param tags: Optional Airflow DAG tags to add. - """ - - self.dag_id = dag_id - self.start_date = start_date - self.schedule = schedule - self.catchup = catchup - self.queue = queue - self.max_retries = max_retries - self.max_active_runs = max_active_runs - self.airflow_vars = airflow_vars - self.airflow_conns = airflow_conns - self._parallel_tasks = False - - self.operators = [] - self.default_args = { - "owner": "airflow", - "start_date": self.start_date, - "on_failure_callback": on_failure_callback, - "retries": self.max_retries, - "queue": self.queue, - } - self.description = self.__doc__ - self.dag = DAG( - dag_id=self.dag_id, - schedule=self.schedule, - default_args=self.default_args, - catchup=self.catchup, - max_active_runs=self.max_active_runs, - doc_md=self.__doc__, - tags=tags, - ) - - def add_operator(self, operator: BaseOperator): - """Add an Apache Airflow operator. - - :param operator: the Apache Airflow operator. - :return: None. - """ - - # Update operator settings - operator.start_date = self.start_date - operator.dag = self.dag - operator.queue = self.queue - operator.__dict__.update(self.default_args) - - # Add list of tasks while parallel_tasks is set - if self._parallel_tasks: - if len(self.operators) == 0 or not isinstance(self.operators[-1], List): - self.operators.append([operator]) - else: - self.operators[-1].append(operator) - # Add single task to the end of the list - else: - self.operators.append(operator) - - def add_setup_task(self, func: Callable, **kwargs): - """Add a setup task, which is used to run tasks before 'Release' objects are created, e.g. checking - dependencies, fetching available releases etc. 
- - A setup task has the following properties: - - Has the signature 'def func(self, **kwargs) -> bool', where - kwargs is the context passed from the PythonOperator. See https://airflow.apache.org/docs/stable/macros-ref.html - for a list of the keyword arguments that are passed to this argument. - - Run by a ShortCircuitOperator, meaning that a setup task can stop a DAG prematurely, e.g. if there is - nothing to process. - - func Needs to return a boolean - - :param func: the function that will be called by the ShortCircuitOperator task. - :return: None. - """ - - kwargs_ = copy.copy(kwargs) - kwargs_["task_id"] = make_task_id(func, kwargs) - op = ShortCircuitOperator(python_callable=func, **kwargs_) - self.add_operator(op) - - def add_task( - self, - func: Callable, - **kwargs, - ): - """Add a task, which is used to process releases. A task has the following properties: - - - Has one of the following signatures 'def func(self, release: Release, **kwargs)' or 'def func(self, - releases: List[Release], **kwargs)' - - kwargs is the context passed from the PythonOperator. See https://airflow.apache.org/docs/stable/macros-ref.html - for a list of the keyword arguments that are passed to this argument. - - Run by a PythonOperator. - - :param func: the function that will be called by the PythonOperator task. - :return: None. - """ - - kwargs_ = copy.copy(kwargs) - kwargs_["task_id"] = make_task_id(func, kwargs) - op = PythonOperator(python_callable=partial(self.task_callable, func), **kwargs_) - self.add_operator(op) - - def make_python_operator( - self, - func: Callable, - task_id: str, - **kwargs, - ): - """Make a PythonOperator which is used to process releases. - - :param func: the function that will be called by the PythonOperator task. - :param task_id: the task id. - :param kwargs: the context passed from the PythonOperator. See https://airflow.apache.org/docs/stable/macros-ref.html - for a list of the keyword arguments that are passed to this argument. - :return: PythonOperator instance. - """ - - kwargs_ = copy.copy(kwargs) - kwargs_["task_id"] = task_id - return PythonOperator(python_callable=partial(self.task_callable, func), **kwargs_) - - @contextlib.contextmanager - def parallel_tasks(self): - """When called, all tasks added to the workflow within the `with` block will run in parallel. - add_task can be used with this function. - - :return: None. - """ - - try: - self._parallel_tasks = True - yield - finally: - self._parallel_tasks = False - - def task_callable(self, func: WorkflowFunction, **kwargs) -> Any: - """Invoke a task callable. Creates a Release instance and calls the given task method. The result can be - pulled as an xcom in Airflow. - - :param func: the task method. - :param kwargs: the context passed from the PythonOperator. - See https://airflow.apache.org/docs/stable/macros-ref.html for a list of the keyword arguments that are passed - to this argument. - :return: Any. - """ - - release = self.make_release(**kwargs) - result = func(release, **kwargs) - return result - - def make_dag(self) -> DAG: - """Make an Airflow DAG for a workflow. - - :return: the DAG object. - """ - - with self.dag: - chain(*self.operators) - - return self.dag - - def check_dependencies(self, **kwargs) -> bool: - """Checks the 'workflow' attributes, airflow variables & connections and possibly additional custom checks. - - :param kwargs: The context passed from the PythonOperator. - :return: None. 
- """ - # check that vars and connections are available - vars_valid = True - conns_valid = True - if self.airflow_vars: - vars_valid = check_variables(*self.airflow_vars) - if self.airflow_conns: - conns_valid = check_connections(*self.airflow_conns) - - if not vars_valid or not conns_valid: - raise AirflowException("Required variables or connections are missing") - - return True diff --git a/observatory-platform/requirements.sh b/observatory-platform/requirements.sh deleted file mode 100644 index 595315cb1..000000000 --- a/observatory-platform/requirements.sh +++ /dev/null @@ -1,18 +0,0 @@ -#!/usr/bin/env bash -# Copyright 2021 Curtin University -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# Basic dependencies -apt-get update -apt-get install git build-essential -y \ No newline at end of file diff --git a/observatory-platform/requirements.txt b/observatory-platform/requirements.txt deleted file mode 100644 index c5ac5ac5f..000000000 --- a/observatory-platform/requirements.txt +++ /dev/null @@ -1,67 +0,0 @@ -# Observatory API -observatory-api - -# Airflow -apache-airflow[redis,celery,slack,http]==2.6.3 - -# Google Cloud -google-crc32c>=1.1.0,<2 -google-cloud-bigquery>=3,<4 -google-api-python-client>=2,<3 -google-cloud-storage>=2.7.0,<3 -google-auth-oauthlib>=0.4.5,<1 - -# AWS -boto3>=1.15.0,<2 - -# Docker -docker>6,<7 - -# Command line interface, config files and templates -click>=8,<9 -pyyaml>=6,<7 -cerberus>=1.3.4,<2 -Jinja2>=3,<4 -stringcase>=1.2.0,<2 -pyhcl>=0.4.4,<1 - -# Reading and writing jsonlines files -jsonlines>=2.0.0,<3 # Writing -json_lines>=0.5.0,<1 # Reading, including jsonl.gz - -# HTTP requests and URL cleaning -requests>=2.25.0,<3 -tldextract>=3.1.1 -aiohttp>=3.7.0,<4 -responses>=0.23.1,<1 - -# SFTP -pysftp>=0.2.9,<1 -paramiko>=2.7.2,<3 - -# FTP -pyftpdlib>=1.5.7,<2 - -# Utils -natsort>=7.1.1,<8 -backoff>=2,<3 -timeout-decorator -validators<=0.20.0 -xmltodict -pandas -tenacity - -# Test utils -time_machine>=2.0.0,<3 -httpretty>=1.0.0,<2 -sftpserver>=0.3,<1 - -# Backports -typing-extensions>=3.10.0.1,<4; python_version<'3.10' - -oauth2client>=4.1.0,<5 - -deprecation>2,<3 - -virtualenv -deepdiff \ No newline at end of file diff --git a/observatory-platform/setup.cfg b/observatory-platform/setup.cfg deleted file mode 100644 index 3778da624..000000000 --- a/observatory-platform/setup.cfg +++ /dev/null @@ -1,65 +0,0 @@ -[metadata] -name = observatory-platform -author = Curtin University -author_email = agent@observatory.academy -summary = The Observatory Platform is an environment for fetching, processing and analysing data to understand how well universities operate as Open Knowledge Institutions. 
-description_file = README.md -description_content_type = text/markdown; charset=UTF-8 -home_page = https://github.com/The-Academic-Observatory/observatory-platform -project_urls = - Bug Tracker = https://github.com/The-Academic-Observatory/observatory-platform/issues - Documentation = https://observatory-platform.readthedocs.io/en/latest/ - Source Code = https://github.com/The-Academic-Observatory/observatory-platform -python_requires = >=3.10 -license = Apache License Version 2.0 -classifier = - Development Status :: 2 - Pre-Alpha - Environment :: Console - Environment :: Web Environment - Intended Audience :: Developers - Intended Audience :: Science/Research - License :: OSI Approved :: Apache Software License - Operating System :: OS Independent - Programming Language :: Python :: 3 :: Only - Programming Language :: Python :: 3 - Programming Language :: Python :: 3.10 - Topic :: Scientific/Engineering - Topic :: Software Development :: Libraries - Topic :: Software Development :: Libraries :: Python Modules - Topic :: Utilities -keywords = - science - data - workflows - academic institutes - observatory-platform - -[files] -packages = - observatory - observatory.platform -data_files = - requirements.txt = requirements.txt - requirements.sh = requirements.sh - observatory/platform/docker = observatory/platform/docker/* - observatory/platform/terraform = observatory/platform/terraform/* - observatory/platform = - observatory/platform/config.yaml.jinja2 - observatory/platform/config-terraform.yaml.jinja2 - -[entry_points] -console_scripts = - observatory = observatory.platform.cli.cli:cli - -[extras] -tests = - liccheck>=0.4.9,<1 - flake8>=3.8.0,<4 - coverage>=5.2,<6 - faker>=8.12.1,<9 - redis>=3.5.3,<4 - boto3>=1.15.0,<2 - azure-storage-blob>=12.8.1,<13 - -[pbr] -skip_authors = true diff --git a/observatory-platform/setup.py b/observatory-platform/setup.py deleted file mode 100644 index 2c3a057c4..000000000 --- a/observatory-platform/setup.py +++ /dev/null @@ -1,3 +0,0 @@ -from setuptools import setup - -setup(setup_requires=["pbr"], pbr=True, python_requires=">=3.10") diff --git a/observatory-api/observatory/api/__init__.py b/observatory_platform/__init__.py similarity index 100% rename from observatory-api/observatory/api/__init__.py rename to observatory_platform/__init__.py diff --git a/observatory-api/observatory/api/server/__init__.py b/observatory_platform/airflow/__init__.py similarity index 100% rename from observatory-api/observatory/api/server/__init__.py rename to observatory_platform/airflow/__init__.py diff --git a/observatory-platform/observatory/platform/airflow.py b/observatory_platform/airflow/airflow.py similarity index 58% rename from observatory-platform/observatory/platform/airflow.py rename to observatory_platform/airflow/airflow.py index 6082b8c53..8cca6c8d1 100644 --- a/observatory-platform/observatory/platform/airflow.py +++ b/observatory_platform/airflow/airflow.py @@ -16,12 +16,10 @@ from __future__ import annotations -import json import logging import textwrap import traceback from datetime import timedelta -from pydoc import locate from typing import List, Union from typing import Optional @@ -30,110 +28,21 @@ import validators from airflow import AirflowException from airflow.hooks.base import BaseHook -from airflow.models import TaskInstance, DagBag, Variable, XCom, DagRun +from airflow.models import TaskInstance, XCom, DagRun from airflow.providers.slack.hooks.slack_webhook import SlackWebhookHook -from airflow.sensors.external_task import 
ExternalTaskSensor from airflow.utils.db import provide_session -from airflow.utils.state import State from dateutil.relativedelta import relativedelta from sqlalchemy import and_ from sqlalchemy.orm import Session -from observatory.platform.config import AirflowConns, AirflowVars -from observatory.platform.observatory_config import Workflow, json_string_to_workflows +from observatory_platform.config import AirflowConns ScheduleInterval = Union[str, timedelta, relativedelta] -def change_task_log_level(new_levels: Union[List, int]) -> list: - """Change the logging levels of all handlers for an airflow task. - - :param new_levels: New logging levels that all handlers will be set to - :return: List of the old logging levels, can be used to restore logging levels. - """ - logger = logging.getLogger("airflow.task") - # stores logging levels - old_levels = [] - for count, handler in enumerate(logger.handlers): - old_levels.append(handler.level) - if isinstance(new_levels, int): - handler.setLevel(new_levels) - else: - handler.setLevel(new_levels[count]) - return old_levels - - -def check_variables(*variables): - """Checks whether all given airflow variables exist. - - :param variables: name of airflow variable - :return: True if all variables are valid - """ - is_valid = True - for name in variables: - try: - Variable.get(name) - except KeyError: - logging.error(f"Airflow variable '{name}' not set.") - is_valid = False - return is_valid - - -def check_connections(*connections): - """Checks whether all given airflow connections exist. - - :param connections: name of airflow connection - :return: True if all connections are valid - """ - is_valid = True - for name in connections: - try: - BaseHook.get_connection(name) - except KeyError: - logging.error(f"Airflow connection '{name}' not set.") - is_valid = False - return is_valid - - -def send_slack_msg( - *, ti: TaskInstance, execution_date: pendulum.DateTime, comments: str = "", slack_conn_id: str = AirflowConns.SLACK -): - """ - Send a Slack message using the token in the slack airflow connection. - :param ti: Task instance. - :param execution_date: DagRun execution date. - :param comments: Additional comments in slack message - :param slack_conn_id: the Airflow connection id for the Slack connection. - """ - - message = textwrap.dedent( - """ - :red_circle: Task Alert. - *Task*: {task} - *Dag*: {dag} - *Execution Time*: {exec_date} - *Log Url*: {log_url} - *Comments*: {comments} - """ - ).format( - task=ti.task_id, - dag=ti.dag_id, - exec_date=execution_date, - log_url=ti.log_url, - comments=comments, - ) - hook = SlackWebhookHook(slack_webhook_conn_id=slack_conn_id) - - # http_hook outputs the secret token, suppressing logging 'info' by setting level to 'warning' - old_levels = change_task_log_level(logging.WARNING) - hook.send_text(message) - # change back to previous levels - change_task_log_level(old_levels) - - def get_airflow_connection_url(conn_id: str) -> str: """Get the Airflow connection host, validate it is a valid url, and return it if it is, with a trailing /, - otherwise throw an exception. Assumes the connection_id exists. + otherwise throw an exception. Assumes the connection_id exists. :param conn_id: Airflow connection id. :return: Connection URL with a trailing / added if necessary, or raise an exception if it is not a valid URL. 
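The body of get_airflow_connection_url is not shown in this hunk, so as a point of reference here is a minimal sketch of a helper with the same contract (fetch the connection, validate the URL, normalise the trailing slash). The use of Connection.get_uri() and validators.url() is an assumption, not the actual implementation:

import validators
from airflow.exceptions import AirflowException
from airflow.hooks.base import BaseHook


def get_connection_url_sketch(conn_id: str) -> str:
    # Assumes the connection stores its endpoint as the connection URI.
    url = BaseHook.get_connection(conn_id).get_uri()
    if not url or not validators.url(url):
        raise AirflowException(f"Connection '{conn_id}' does not contain a valid URL: {url}")
    return url if url.endswith("/") else url + "/"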
@@ -167,6 +76,16 @@ def get_airflow_connection_login(conn_id: str) -> str: return login +def is_first_dag_run(dag_run: DagRun) -> bool: + """Whether the DAG Run is the first run or not + + :param dag_run: A Dag Run instance + :return: Whether the DAG run is the first run or not + """ + + return dag_run.get_previous_dagrun() is None + + def get_airflow_connection_password(conn_id: str) -> str: """Get the Airflow connection password. Assumes the connection_id exists. @@ -186,8 +105,7 @@ def get_airflow_connection_password(conn_id: str) -> str: def on_failure_callback(context): - """ - Function that is called on failure of an airflow task. Will create a slack webhook and send a notification. + """Function that is called on failure of an airflow task. Will create a slack webhook and send a notification. :param context: the context passed from the PythonOperator. See https://airflow.apache.org/docs/stable/macros-ref.html for a list of the keyword arguments that are passed to @@ -209,9 +127,63 @@ def on_failure_callback(context): send_slack_msg(ti=ti, execution_date=execution_date, comments=comments, slack_conn_id=AirflowConns.SLACK) -def normalized_schedule_interval(schedule_interval: Optional[str]) -> Optional[ScheduleInterval]: +def change_task_log_level(new_levels: Union[List, int]) -> list: + """Change the logging levels of all handlers for an airflow task. + + :param new_levels: New logging levels that all handlers will be set to + :return: List of the old logging levels, can be used to restore logging levels. """ - Copied from https://github.com/apache/airflow/blob/main/airflow/models/dag.py#L851-L866 + logger = logging.getLogger("airflow.task") + # stores logging levels + old_levels = [] + for count, handler in enumerate(logger.handlers): + old_levels.append(handler.level) + if isinstance(new_levels, int): + handler.setLevel(new_levels) + else: + handler.setLevel(new_levels[count]) + return old_levels + + +def send_slack_msg( + *, ti: TaskInstance, execution_date: pendulum.DateTime, comments: str = "", slack_conn_id: str = AirflowConns.SLACK +): + """ + Send a Slack message using the token in the slack airflow connection. + + :param ti: Task instance. + :param execution_date: DagRun execution date. + :param comments: Additional comments in slack message + :param slack_conn_id: the Airflow connection id for the Slack connection. + """ + + message = textwrap.dedent( + """ + :red_circle: Task Alert. + *Task*: {task} + *Dag*: {dag} + *Execution Time*: {exec_date} + *Log Url*: {log_url} + *Comments*: {comments} + """ + ).format( + task=ti.task_id, + dag=ti.dag_id, + exec_date=execution_date, + log_url=ti.log_url, + comments=comments, + ) + hook = SlackWebhookHook(slack_webhook_conn_id=slack_conn_id) + + # http_hook outputs the secret token, suppressing logging 'info' by setting level to 'warning' + old_levels = change_task_log_level(logging.WARNING) + hook.send_text(message) + # change back to previous levels + change_task_log_level(old_levels) + + +def normalized_schedule_interval(schedule_interval: Optional[str]) -> Optional[ScheduleInterval]: + """Copied from https://github.com/apache/airflow/blob/main/airflow/models/dag.py#L851-L866 Licensed to the Apache Software Foundation (ASF) under one or more contributor license agreements. 
See the NOTICE file @@ -253,93 +225,6 @@ def normalized_schedule_interval(schedule_interval: Optional[str]) -> Optional[S return _schedule_interval -def get_data_path() -> str: - """Grabs the DATA_PATH airflow vairable - - :raises AirflowException: Raised if the variable does not exist - :return: DATA_PATH variable contents - """ - # Try to get value from env variable first, saving costs from GC secret usage - data_path = Variable.get(AirflowVars.DATA_PATH) - if not data_path: - raise AirflowException("DATA_PATH variable could not be found.") - return data_path - - -def fetch_workflows() -> List[Workflow]: - """Get the workflows from the Airflow Variable - - :return: the workflows to create. - """ - - workflows_str = Variable.get(AirflowVars.WORKFLOWS) - logging.info(f"workflows_str: {workflows_str}") - - try: - workflows = json_string_to_workflows(workflows_str) - logging.info(f"workflows: {workflows}") - except json.decoder.JSONDecodeError as e: - e.msg = f"workflows_str: {workflows_str}\n\n{e.msg}" - raise e - - return workflows - - -def fetch_dags_modules() -> dict: - """Get the dags modules from the Airflow Variable - - :return: Dags modules - """ - - dags_modules_str = Variable.get(AirflowVars.DAGS_MODULE_NAMES) - logging.info(f"dags_modules_str: {dags_modules_str}") - - try: - dags_modules_ = json.loads(dags_modules_str) - logging.info(f"dags_modules: {dags_modules_}") - except json.decoder.JSONDecodeError as e: - e.msg = f"dags_modules_str: {dags_modules_str}\n\n{e.msg}" - raise e - - return dags_modules_ - - -def make_workflow(workflow: Workflow): - """Make a workflow instance. - :param workflow: the workflow configuration. - :return: the workflow instance. - """ - - cls = locate(workflow.class_name) - if cls is None: - raise ModuleNotFoundError(f"dag_id={workflow.dag_id}: could not locate class_name={workflow.class_name}") - - return cls(dag_id=workflow.dag_id, cloud_workspace=workflow.cloud_workspace, **workflow.kwargs) - - -def fetch_dag_bag(path: str, include_examples: bool = False) -> DagBag: - """Load a DAG Bag from a given path. - - :param path: the path to the DAG bag. - :param include_examples: whether to include example DAGs or not. - :return: None. - """ - logging.info(f"Loading DAG bag from path: {path}") - dag_bag = DagBag(path, include_examples=include_examples) - - if dag_bag is None: - raise Exception(f"DagBag could not be loaded from path: {path}") - - if len(dag_bag.import_errors): - # Collate loading errors as single string and raise it as exception - results = [] - for path, exception in dag_bag.import_errors.items(): - results.append(f"DAG import exception: {path}\n{exception}\n\n") - raise Exception("\n".join(results)) - - return dag_bag - - @provide_session def delete_old_xcoms( session: Session = None, @@ -365,70 +250,3 @@ def delete_old_xcoms( ) # set synchronize_session="fetch" to prevent the following error: sqlalchemy.exc.InvalidRequestError: Could not evaluate current criteria in Python: "Cannot evaluate SelectStatementGrouping". Specify 'fetch' or False for the synchronize_session execution option. 
results.delete(synchronize_session="fetch") - - -def is_first_dag_run(dag_run: DagRun) -> bool: - """Whether the DAG Run is the first run or not - - :param dag_run: A Dag Run instance - :return: Whether the DAG run is the first run or not - """ - - return dag_run.get_previous_dagrun() is None - - -class PreviousDagRunSensor(ExternalTaskSensor): - def __init__( - self, - dag_id: str, - task_id: str = "wait_for_prev_dag_run", - external_task_id: str = "dag_run_complete", - allowed_states: List[str] = None, - *args, - **kwargs, - ): - """Custom ExternalTaskSensor designed to wait for a previous DAG run of the same DAG. This sensor also - skips on the first DAG run, as the DAG hasn't run before. - - Add the following as the last tag of your DAG: - DummyOperator( - task_id=external_task_id, - ) - - :param dag_id: the DAG id of the DAG to wait for. - :param task_id: the task id for this sensor. - :param external_task_id: the task id to wait for. - :param allowed_states: to override allowed_states. - :param args: args for ExternalTaskSensor. - :param kwargs: kwargs for ExternalTaskSensor. - """ - - if allowed_states is None: - # sensor can skip a run if previous dag run skipped for some reason - allowed_states = [ - State.SUCCESS, - State.SKIPPED, - ] - - super().__init__( - task_id=task_id, - external_dag_id=dag_id, - external_task_id=external_task_id, - allowed_states=allowed_states, - *args, - **kwargs, - ) - - @provide_session - def poke(self, context, session=None): - # Custom poke to allow the sensor to skip on the first DAG run - - ti = context["task_instance"] - dag_run = context["dag_run"] - - if is_first_dag_run(dag_run): - self.log.info("Skipping the sensor check on the first DAG run") - ti.set_state(State.SKIPPED) - return True - - return super().poke(context, session=session) diff --git a/observatory_platform/airflow/release.py b/observatory_platform/airflow/release.py new file mode 100644 index 000000000..032d84de9 --- /dev/null +++ b/observatory_platform/airflow/release.py @@ -0,0 +1,191 @@ +# Copyright 2020 Curtin University +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Author: Aniek Roelofs, James Diprose, Tuan Chien + +from __future__ import annotations + +import logging +import os + +import pendulum +from airflow.exceptions import AirflowException + +from observatory_platform.airflow.workflow import make_workflow_folder + +DATE_TIME_FORMAT = "YYYY-MM-DD_HH:mm:ss" + + +def make_snapshot_date(**kwargs) -> pendulum.DateTime: + """Make a snapshot date""" + + return kwargs["data_interval_end"] + + +def set_task_state(success: bool, task_id: str, release: Release = None): + """Update the state of the Airflow task. + :param success: whether the task was successful or not. + :param task_id: the task id. + :param release: the release being processed. Optional. + :return: None. 
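A usage sketch for set_task_state, not taken from the patch: a task computes a boolean outcome and passes it in, and set_task_state logs on success or raises AirflowException on failure so that the Airflow task itself fails. The upload_transformed task and its success criterion are illustrative only:

import os

from observatory_platform.airflow.release import set_task_state


def upload_transformed(release, **context):
    # Illustrative success criterion: the transform step produced at least one file.
    success = len(os.listdir(release.transform_folder)) > 0
    set_task_state(success, context["ti"].task_id, release=release)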
+ """ + + if success: + msg = f"{task_id}: success" + if release is not None: + msg += f" {release}" + logging.info(msg) + else: + msg_failed = f"{task_id}: failed" + if release is not None: + msg_failed += f" {release}" + logging.error(msg_failed) + raise AirflowException(msg_failed) + + +class Release: + def __init__(self, *, dag_id: str, run_id: str): + """Construct a Release instance + + :param dag_id: the DAG ID. + :param run_id: the DAG's run ID. + """ + + self.dag_id = dag_id + self.run_id = run_id + + @property + def workflow_folder(self): + return make_workflow_folder(self.dag_id, self.run_id) + + @property + def release_folder(self): + raise NotImplementedError("self.release_folder should be implemented by subclasses") + + @property + def download_folder(self): + path = os.path.join(self.release_folder, "download") + os.makedirs(path, exist_ok=True) + return path + + @property + def extract_folder(self): + path = os.path.join(self.release_folder, "extract") + os.makedirs(path, exist_ok=True) + return path + + @property + def transform_folder(self): + path = os.path.join(self.release_folder, "transform") + os.makedirs(path, exist_ok=True) + return path + + def __str__(self): + return f"Release(dag_id={self.dag_id}, run_id={self.run_id})" + + +class SnapshotRelease(Release): + def __init__( + self, + *, + dag_id: str, + run_id: str, + snapshot_date: pendulum.DateTime, + ): + """Construct a SnapshotRelease instance + + :param dag_id: the DAG ID. + :param run_id: the DAG's run ID. + :param snapshot_date: the release date of the snapshot. + """ + + super().__init__(dag_id=dag_id, run_id=run_id) + self.snapshot_date = snapshot_date + + @property + def release_folder(self): + return make_workflow_folder(self.dag_id, self.run_id, f"snapshot_{self.snapshot_date.format(DATE_TIME_FORMAT)}") + + def __str__(self): + return f"SnapshotRelease(dag_id={self.dag_id}, run_id={self.run_id}, snapshot_date={self.snapshot_date})" + + +class PartitionRelease(Release): + def __init__( + self, + *, + dag_id: str, + run_id: str, + partition_date: pendulum.DateTime, + ): + """Construct a PartitionRelease instance + + :param dag_id: the DAG ID. + :param run_id: the DAG's run ID. + :param partition_date: the release date of the partition. + """ + + super().__init__(dag_id=dag_id, run_id=run_id) + self.partition_date = partition_date + + @property + def release_folder(self): + return make_workflow_folder( + self.dag_id, self.run_id, f"partition_{self.partition_date.format(DATE_TIME_FORMAT)}" + ) + + def __str__(self): + return f"PartitionRelease(dag_id={self.dag_id}, run_id={self.run_id}, partition_date={self.partition_date})" + + +class ChangefileRelease(Release): + def __init__( + self, + *, + dag_id: str, + run_id: str, + start_date: pendulum.DateTime = None, + end_date: pendulum.DateTime = None, + sequence_start: int = None, + sequence_end: int = None, + ): + """Construct a ChangefileRelease instance + + :param dag_id: the DAG ID. + :param run_id: the DAG's run ID. + :param start_date: the date of the first changefile processed in this release. + :param end_date: the date of the last changefile processed in this release. + :param sequence_start: the starting sequence number of files that make up this release. + :param sequence_end: the end sequence number of files that make up this release. 
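An illustrative use of the release classes defined above, assuming the DATA_PATH Airflow Variable points at a writable directory (the folder properties create their directories on first access):

import pendulum

from observatory_platform.airflow.release import SnapshotRelease

release = SnapshotRelease(
    dag_id="my_dag",
    run_id="scheduled__2024-01-01T00:00:00+00:00",
    snapshot_date=pendulum.datetime(2024, 1, 1),
)
# Resolves to <DATA_PATH>/my_dag/scheduled__2024-01-01T00:00:00+00:00/snapshot_2024-01-01_00:00:00/download
print(release.download_folder)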
+ """ + + super().__init__(dag_id=dag_id, run_id=run_id) + self.start_date = start_date + self.end_date = end_date + self.sequence_start = sequence_start + self.sequence_end = sequence_end + + @property + def release_folder(self): + return make_workflow_folder( + self.dag_id, + self.run_id, + f"changefile_{self.start_date.format(DATE_TIME_FORMAT)}_to_{self.end_date.format(DATE_TIME_FORMAT)}", + ) + + def __str__(self): + return ( + f"Release(dag_id={self.dag_id}, run_id={self.run_id}, start_date={self.start_date}, " + f"end_date={self.end_date}, sequence_start={self.sequence_start}, sequence_end={self.sequence_end})" + ) diff --git a/observatory-platform/observatory/platform/utils/dag_run_sensor.py b/observatory_platform/airflow/sensors.py similarity index 73% rename from observatory-platform/observatory/platform/utils/dag_run_sensor.py rename to observatory_platform/airflow/sensors.py index d5ad6ee69..36629bdbb 100644 --- a/observatory-platform/observatory/platform/utils/dag_run_sensor.py +++ b/observatory_platform/airflow/sensors.py @@ -14,19 +14,80 @@ # Author: Tuan Chien +from __future__ import annotations import datetime import os from time import sleep -from typing import Dict, Union +from typing import Dict, Union, List from airflow.exceptions import AirflowException from airflow.models import DagModel, DagRun from airflow.sensors.base import BaseSensorOperator -from airflow.utils.session import provide_session +from airflow.sensors.external_task import ExternalTaskSensor +from airflow.utils.db import provide_session from airflow.utils.state import State from sqlalchemy.orm.scoping import scoped_session +from observatory_platform.airflow.airflow import is_first_dag_run + + +class PreviousDagRunSensor(ExternalTaskSensor): + def __init__( + self, + dag_id: str, + task_id: str = "wait_for_prev_dag_run", + external_task_id: str = "dag_run_complete", + allowed_states: List[str] = None, + *args, + **kwargs, + ): + """Custom ExternalTaskSensor designed to wait for a previous DAG run of the same DAG. This sensor also + skips on the first DAG run, as the DAG hasn't run before. + + Add the following as the last tag of your DAG: + DummyOperator( + task_id=external_task_id, + ) + + :param dag_id: the DAG id of the DAG to wait for. + :param task_id: the task id for this sensor. + :param external_task_id: the task id to wait for. + :param allowed_states: to override allowed_states. + :param args: args for ExternalTaskSensor. + :param kwargs: kwargs for ExternalTaskSensor. 
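Given the contract described in the docstring above, a DAG that serialises its runs can be wired roughly as follows. This is a sketch rather than patch content; EmptyOperator stands in for the DummyOperator mentioned in the docstring (its older name), and a real DAG's work would sit between the sensor and the final task:

import pendulum
from airflow.decorators import dag
from airflow.operators.empty import EmptyOperator

from observatory_platform.airflow.sensors import PreviousDagRunSensor

DAG_ID = "my_dag"


@dag(dag_id=DAG_ID, start_date=pendulum.datetime(2024, 1, 1), schedule="@weekly", catchup=False)
def my_dag():
    wait = PreviousDagRunSensor(dag_id=DAG_ID)        # waits on the previous run's dag_run_complete task
    done = EmptyOperator(task_id="dag_run_complete")  # must be the last task in the DAG
    wait >> done


my_dag()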
+ """ + + if allowed_states is None: + # sensor can skip a run if previous dag run skipped for some reason + allowed_states = [ + State.SUCCESS, + State.SKIPPED, + ] + + super().__init__( + task_id=task_id, + external_dag_id=dag_id, + external_task_id=external_task_id, + allowed_states=allowed_states, + *args, + **kwargs, + ) + + @provide_session + def poke(self, context, session=None): + # Custom poke to allow the sensor to skip on the first DAG run + + ti = context["task_instance"] + dag_run = context["dag_run"] + + if is_first_dag_run(dag_run): + self.log.info("Skipping the sensor check on the first DAG run") + ti.set_state(State.SKIPPED) + return True + + return super().poke(context, session=session) + class DagRunSensor(BaseSensorOperator): """ diff --git a/observatory_platform/airflow/tasks.py b/observatory_platform/airflow/tasks.py new file mode 100644 index 000000000..4ffae35a9 --- /dev/null +++ b/observatory_platform/airflow/tasks.py @@ -0,0 +1,115 @@ +# Copyright 2020-2023 Curtin University +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from __future__ import annotations + +import logging +from typing import Optional, List + +import airflow +from airflow.decorators import task +from airflow.exceptions import AirflowNotFoundException +from airflow.hooks.base import BaseHook +from airflow.models import Variable + +from observatory_platform.google.gcp import gcp_delete_disk, gcp_create_disk +from observatory_platform.google.gke import gke_create_volume, gke_delete_volume + + +@task +def check_dependencies(airflow_vars: Optional[List[str]] = None, airflow_conns: Optional[List[str]] = None, **context): + """Checks if the given Airflow Variables and Connections exist. + + :param airflow_vars: the Airflow Variables to check exist. + :param airflow_conns: the Airflow Connections to check exist. + :return: None. + """ + + vars_valid = True + conns_valid = True + if airflow_vars: + vars_valid = check_variables(*airflow_vars) + if airflow_conns: + conns_valid = check_connections(*airflow_conns) + + if not vars_valid or not conns_valid: + raise AirflowNotFoundException("Required variables or connections are missing") + + +def check_variables(*variables): + """Checks whether all given airflow variables exist. + + :param variables: name of airflow variable + :return: True if all variables are valid + """ + is_valid = True + for name in variables: + try: + Variable.get(name) + except AirflowNotFoundException: + logging.error(f"Airflow variable '{name}' not set.") + is_valid = False + return is_valid + + +def check_connections(*connections): + """Checks whether all given airflow connections exist. 
+ + :param connections: name of airflow connection + :return: True if all connections are valid + """ + is_valid = True + for name in connections: + try: + BaseHook.get_connection(name) + except airflow.exceptions.AirflowNotFoundException: + logging.error(f"Airflow connection '{name}' not set.") + is_valid = False + return is_valid + + +@task +def gke_create_storage( + project_id: str, zone: str, volume_name: str, volume_size: int, kubernetes_conn_id: str, **context +): + """Create storage on a GKE cluster. + + :param project_id: the Google Cloud project ID. + :param zone: the Google Cloud zone. + :param volume_name: the name of the volume. + :param volume_size: the volume size. + :param kubernetes_conn_id: the Kubernetes Airflow Connection ID. + :param context: the Airflow context. + :return: None. + """ + + gcp_create_disk(project_id=project_id, zone=zone, disk_name=volume_name, disk_size_gb=volume_size) + gke_create_volume(kubernetes_conn_id=kubernetes_conn_id, volume_name=volume_name, size_gi=volume_size) + + +@task +def gke_delete_storage(project_id: str, zone: str, volume_name: str, kubernetes_conn_id: str, **context): + """Delete storage on a GKE cluster. + + :param project_id: the Google Cloud project ID. + :param zone: the Google Cloud zone. + :param volume_name: the name of the volume. + :param kubernetes_conn_id: the Kubernetes Airflow Connection ID. + :param context: the Airflow context. + :return: None. + """ + + gke_delete_volume(kubernetes_conn_id=kubernetes_conn_id, volume_name=volume_name) + gcp_delete_disk(project_id=project_id, zone=zone, disk_name=volume_name) diff --git a/observatory-platform/observatory/platform/__init__.py b/observatory_platform/airflow/tests/__init__.py similarity index 100% rename from observatory-platform/observatory/platform/__init__.py rename to observatory_platform/airflow/tests/__init__.py diff --git a/observatory-platform/observatory/platform/cli/__init__.py b/observatory_platform/airflow/tests/fixtures/__init__.py similarity index 100% rename from observatory-platform/observatory/platform/cli/__init__.py rename to observatory_platform/airflow/tests/fixtures/__init__.py diff --git a/tests/fixtures/utils/bad_dag.py b/observatory_platform/airflow/tests/fixtures/bad_dag.py similarity index 100% rename from tests/fixtures/utils/bad_dag.py rename to observatory_platform/airflow/tests/fixtures/bad_dag.py diff --git a/tests/fixtures/utils/good_dag.py b/observatory_platform/airflow/tests/fixtures/good_dag.py similarity index 100% rename from tests/fixtures/utils/good_dag.py rename to observatory_platform/airflow/tests/fixtures/good_dag.py diff --git a/tests/observatory/platform/test_airflow.py b/observatory_platform/airflow/tests/test_airflow.py similarity index 78% rename from tests/observatory/platform/test_airflow.py rename to observatory_platform/airflow/tests/test_airflow.py index ab7930ed6..2f24f3dbf 100644 --- a/tests/observatory/platform/test_airflow.py +++ b/observatory_platform/airflow/tests/test_airflow.py @@ -16,39 +16,33 @@ import datetime import os -import shutil import textwrap import unittest from unittest.mock import MagicMock, patch import pendulum -from airflow.exceptions import AirflowException +from airflow.decorators import dag +from airflow.exceptions import AirflowException, AirflowNotFoundException from airflow.models.connection import Connection from airflow.models.dag import DAG -from airflow.models.variable import Variable from airflow.models.xcom import XCom from airflow.operators.bash import BashOperator from 
airflow.operators.python import PythonOperator from airflow.utils.session import provide_session from airflow.utils.state import State -from observatory.platform.airflow import ( +from observatory_platform.airflow.airflow import ( get_airflow_connection_login, get_airflow_connection_password, get_airflow_connection_url, send_slack_msg, - fetch_dags_modules, - fetch_dag_bag, delete_old_xcoms, on_failure_callback, - get_data_path, normalized_schedule_interval, is_first_dag_run, ) -from observatory.platform.observatory_environment import ( - ObservatoryEnvironment, - test_fixtures_path, -) +from observatory_platform.airflow.tasks import check_dependencies +from observatory_platform.sandbox.sandbox_environment import SandboxEnvironment class MockConnection: @@ -64,72 +58,7 @@ def get_password(self): class TestAirflow(unittest.TestCase): - @patch("observatory.platform.airflow.Variable.get") - def test_get_data_path(self, mock_variable_get): - """Tests the function that retrieves the data_path airflow variable""" - # 1 - no variable available - mock_variable_get.return_value = None - self.assertRaises(AirflowException, get_data_path) - - # 2 - available in Airflow variable - mock_variable_get.return_value = "env_return" - self.assertEqual("env_return", get_data_path()) - - def test_fetch_dags_modules(self): - """Test fetch_dags_modules""" - - dags_module_names_val = '["academic_observatory_workflows.dags", "oaebu_workflows.dags"]' - expected = ["academic_observatory_workflows.dags", "oaebu_workflows.dags"] - env = ObservatoryEnvironment(enable_api=False) - with env.create(): - # Test when no variable set - with self.assertRaises(KeyError): - fetch_dags_modules() - - # Test when using an Airflow Variable exists - env.add_variable(Variable(key="dags_module_names", val=dags_module_names_val)) - actual = fetch_dags_modules() - self.assertEqual(expected, actual) - - with ObservatoryEnvironment(enable_api=False).create(): - # Set environment variable - new_env = env.new_env - new_env["AIRFLOW_VAR_DAGS_MODULE_NAMES"] = dags_module_names_val - os.environ.update(new_env) - - # Test when using an Airflow Variable set with an environment variable - actual = fetch_dags_modules() - self.assertEqual(expected, actual) - - def test_fetch_dag_bag(self): - """Test fetch_dag_bag""" - - env = ObservatoryEnvironment(enable_api=False) - with env.create() as t: - # No DAGs found - dag_bag = fetch_dag_bag(t) - print(f"DAGS found on path: {t}") - for dag_id in dag_bag.dag_ids: - print(f" {dag_id}") - self.assertEqual(0, len(dag_bag.dag_ids)) - - # Bad DAG - src = test_fixtures_path("utils", "bad_dag.py") - shutil.copy(src, os.path.join(t, "dags.py")) - with self.assertRaises(Exception): - fetch_dag_bag(t) - - # Copy Good DAGs to folder - src = test_fixtures_path("utils", "good_dag.py") - shutil.copy(src, os.path.join(t, "dags.py")) - - # DAGs found - expected_dag_ids = {"hello", "world"} - dag_bag = fetch_dag_bag(t) - actual_dag_ids = set(dag_bag.dag_ids) - self.assertSetEqual(expected_dag_ids, actual_dag_ids) - - @patch("observatory.platform.airflow.SlackWebhookHook") + @patch("observatory_platform.airflow.airflow.SlackWebhookHook") @patch("airflow.hooks.base.BaseHook.get_connection") def test_send_slack_msg(self, mock_get_connection, m_slack): slack_webhook_conn_id = "slack_conn" @@ -168,7 +97,7 @@ def __init__(self): m_slack.return_value.send_text.assert_called_once_with(message) def test_get_airflow_connection_url_invalid(self): - with patch("observatory.platform.airflow.BaseHook") as m_basehook: + with 
patch("observatory_platform.airflow.airflow.BaseHook") as m_basehook: m_basehook.get_connection = MagicMock(return_value=MockConnection("")) self.assertRaises(AirflowException, get_airflow_connection_url, "some_connection") @@ -179,7 +108,7 @@ def test_get_airflow_connection_url_valid(self): expected_url = "http://localhost/" fake_conn = "some_connection" - with patch("observatory.platform.airflow.BaseHook") as m_basehook: + with patch("observatory_platform.airflow.airflow.BaseHook") as m_basehook: # With trailing / input_url = "http://localhost/" m_basehook.get_connection = MagicMock(return_value=MockConnection(input_url)) @@ -193,7 +122,7 @@ def test_get_airflow_connection_url_valid(self): self.assertEqual(url, expected_url) def test_get_airflow_connection_password(self): - env = ObservatoryEnvironment(enable_api=False) + env = SandboxEnvironment() with env.create(): # Assert that we can get a connection password conn_id = "conn_1" @@ -208,7 +137,7 @@ def test_get_airflow_connection_password(self): get_airflow_connection_password(conn_id) def test_get_airflow_connection_login(self): - env = ObservatoryEnvironment(enable_api=False) + env = SandboxEnvironment() with env.create(): # Assert that we can get a connection login conn_id = "conn_1" @@ -249,7 +178,7 @@ def create_xcom(**kwargs): execution_date = kwargs["execution_date"] ti.xcom_push("topic", {"snapshot_date": execution_date.format("YYYYMMDD"), "something": "info"}) - env = ObservatoryEnvironment(enable_api=False) + env = SandboxEnvironment() with env.create(): execution_date = pendulum.datetime(2021, 9, 5) with DAG( @@ -291,7 +220,7 @@ def get_xcom(session=None, dag_id=None, task_id=None, key=None, execution_date=N ).with_entities(XCom.value) return msgs.all() - env = ObservatoryEnvironment(enable_api=False) + env = SandboxEnvironment() with env.create(): first_execution_date = pendulum.datetime(2021, 9, 5) with DAG( @@ -335,7 +264,7 @@ def get_xcom(session=None, dag_id=None, task_id=None, key=None, execution_date=N msg = XCom.deserialize_value(xcoms[0]) self.assertEqual(msg["snapshot_date"], second_execution_date.format("YYYYMMDD")) - @patch("observatory.platform.airflow.send_slack_msg") + @patch("observatory_platform.airflow.airflow.send_slack_msg") def test_on_failure_callback(self, mock_send_slack_msg): # Fake Airflow ti instance class MockTI: @@ -361,10 +290,68 @@ def __init__(self): slack_conn_id="slack", ) + @patch("observatory_platform.airflow.airflow.send_slack_msg") + def test_callback(self, mock_send_slack_msg): + """Test that the on_failure_callback function is successfully called in a production environment when a task + fails + + :param mock_send_slack_msg: Mock send_slack_msg function + :return: None. 
+ """ + + def create_dag(dag_id: str, start_date: pendulum.DateTime, schedule: str, retries: int, airflow_conns: list): + @dag( + dag_id=dag_id, + start_date=start_date, + schedule=schedule, + default_args=dict(retries=retries, on_failure_callback=on_failure_callback), + ) + def callback_test_dag(): + check_dependencies(airflow_conns=airflow_conns) + return callback_test_dag() + + # Setup Observatory environment + project_id = os.getenv("TEST_GCP_PROJECT_ID") + data_location = os.getenv("TEST_GCP_DATA_LOCATION") + env = SandboxEnvironment(project_id, data_location) + + # Setup Workflow with 0 retries and missing airflow variable, so it will fail the task + execution_date = pendulum.datetime(2020, 1, 1) + conn_id = "orcid_bucket" + my_dag = create_dag( + "test_callback", + execution_date, + "@weekly", + retries=0, + airflow_conns=[conn_id], + ) + + # Create the Observatory environment and run task, expecting slack webhook call in production environment + with env.create(task_logging=True): + with env.create_dag_run(my_dag, execution_date): + with self.assertRaises(AirflowNotFoundException): + env.run_task("check_dependencies") + + _, callkwargs = mock_send_slack_msg.call_args + self.assertTrue( + "airflow.exceptions.AirflowNotFoundException: Required variables or connections are missing" + in callkwargs["comments"] + ) + + # Reset mock + mock_send_slack_msg.reset_mock() + + # Add orcid_bucket connection and test that Slack Web Hook did not get triggered + with env.create(task_logging=True): + with env.create_dag_run(my_dag, execution_date): + env.add_connection(Connection(conn_id=conn_id, uri="https://orcid.org/")) + env.run_task("check_dependencies") + mock_send_slack_msg.assert_not_called() + def test_is_first_dag_run(self): """Test is_first_dag_run""" - env = ObservatoryEnvironment(enable_api=False) + env = SandboxEnvironment() with env.create(): first_execution_date = pendulum.datetime(2021, 9, 5) with DAG( diff --git a/observatory_platform/airflow/tests/test_release.py b/observatory_platform/airflow/tests/test_release.py new file mode 100644 index 000000000..17c998a6d --- /dev/null +++ b/observatory_platform/airflow/tests/test_release.py @@ -0,0 +1,37 @@ +# Copyright 2019-2024 Curtin University +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import pendulum +from airflow.exceptions import AirflowException + +from observatory_platform.airflow.release import set_task_state, make_snapshot_date +from observatory_platform.sandbox.test_utils import SandboxTestCase + + +class TestWorkflow(SandboxTestCase): + def test_make_snapshot_date(self): + """Test make_table_name""" + + data_interval_end = pendulum.datetime(2021, 11, 11) + expected_date = pendulum.datetime(2021, 11, 11) + actual_date = make_snapshot_date(**{"data_interval_end": data_interval_end}) + self.assertEqual(expected_date, actual_date) + + def test_set_task_state(self): + """Test set_task_state""" + + task_id = "test_task" + set_task_state(True, task_id) + with self.assertRaises(AirflowException): + set_task_state(False, task_id) diff --git a/tests/observatory/platform/utils/test_dag_run_sensor.py b/observatory_platform/airflow/tests/test_sensors.py similarity index 72% rename from tests/observatory/platform/utils/test_dag_run_sensor.py rename to observatory_platform/airflow/tests/test_sensors.py index de9ca77f6..37752da9d 100644 --- a/tests/observatory/platform/utils/test_dag_run_sensor.py +++ b/observatory_platform/airflow/tests/test_sensors.py @@ -14,36 +14,39 @@ # # Author: Tuan Chien +from __future__ import annotations + import datetime +from airflow.models.dag import DAG import os.path from unittest.mock import patch import pendulum from airflow.exceptions import AirflowException, AirflowSensorTimeout from airflow.models import DagRun, DagModel +from airflow.operators.python import PythonOperator from airflow.utils.session import provide_session from airflow.utils.state import DagRunState, State -from observatory.platform.observatory_environment import ObservatoryEnvironment, ObservatoryTestCase, make_dummy_dag -from observatory.platform.utils.dag_run_sensor import DagRunSensor -from observatory.platform.workflows.workflow import Workflow +from observatory_platform.airflow.sensors import DagRunSensor +from observatory_platform.sandbox.sandbox_environment import SandboxEnvironment +from observatory_platform.sandbox.test_utils import SandboxTestCase, make_dummy_dag -class MonitoringWorkflow(Workflow): - DAG_ID = "test_workflow" +def create_dag( + *, + start_date: pendulum.DateTime, + ext_dag_id: str, + dag_id: str = "test_workflow", + schedule: str = "@monthly", + mode: str = "reschedule", + check_exists: bool = True, + catchup: bool = False, +): + with DAG(dag_id=dag_id, schedule=schedule, start_date=start_date, catchup=catchup) as dag: - def __init__( - self, - *, - start_date: pendulum.DateTime, - ext_dag_id: str, - schedule: str = "@monthly", - mode: str = "reschedule", - check_exists: bool = True, - ): - super().__init__( - dag_id=MonitoringWorkflow.DAG_ID, start_date=start_date, schedule=schedule, catchup=False - ) + def dummy_task(): + print("Hello world") sensor = DagRunSensor( task_id="sensor_task", @@ -56,18 +59,19 @@ def __init__( grace_period=datetime.timedelta(seconds=1), ) - self.add_operator(sensor) - self.add_task(self.dummy_task) + # Use the PythonOperator to run the Python functions + dummy_task_instance = PythonOperator( + task_id="dummy_task", + python_callable=dummy_task, + ) - def make_release(self, **kwargs): - return None + # Define the task sequence + sensor >> dummy_task_instance - def dummy_task(self, release, **kwargs): - if not self.succeed: - raise ValueError("Problem") + return dag -class TestDagRunSensor(ObservatoryTestCase): +class TestDagRunSensor(SandboxTestCase): """Test the Task Window Sensor. 
We use one of the stock example dags""" def __init__(self, *args, **kwargs): @@ -87,46 +91,42 @@ def update_db(self, *, session, object): session.commit() def test_no_dag_exists(self): - env = ObservatoryEnvironment() + env = SandboxEnvironment() with env.create(): execution_date = pendulum.datetime(2021, 9, 1) - wf = MonitoringWorkflow(start_date=self.start_date, ext_dag_id="nodag", check_exists=True) - dag = wf.make_dag() + dag = create_dag(start_date=self.start_date, ext_dag_id="nodag", check_exists=True) with env.create_dag_run(dag=dag, execution_date=execution_date): self.assertRaises(AirflowException, env.run_task, self.sensor_task_id) def test_no_dag_exists_no_check(self): - env = ObservatoryEnvironment() + env = SandboxEnvironment() with env.create(): execution_date = pendulum.datetime(2021, 9, 1) - wf = MonitoringWorkflow(start_date=self.start_date, ext_dag_id="nodag", check_exists=False) - dag = wf.make_dag() + dag = create_dag(start_date=self.start_date, ext_dag_id="nodag", check_exists=False) with env.create_dag_run(dag=dag, execution_date=execution_date): ti = env.run_task(self.sensor_task_id) self.assertEqual(ti.state, State.SUCCESS) def test_no_execution_date_in_range(self): - env = ObservatoryEnvironment() + env = SandboxEnvironment() with env.create() as t: self.add_dummy_dag_model(t, self.ext_dag_id, "@weekly") execution_date = pendulum.datetime(2021, 9, 1) - wf = MonitoringWorkflow(start_date=self.start_date, ext_dag_id=self.ext_dag_id) - dag = wf.make_dag() + dag = create_dag(start_date=self.start_date, ext_dag_id=self.ext_dag_id) with env.create_dag_run(dag=dag, execution_date=execution_date): ti = env.run_task(self.sensor_task_id) self.assertEqual(ti.state, State.SUCCESS) - @patch("observatory.platform.utils.dag_run_sensor.DagRunSensor.get_latest_execution_date") + @patch("observatory_platform.airflow.sensors.DagRunSensor.get_latest_execution_date") def test_grace_period(self, m_get_execdate): m_get_execdate.return_value = None - env = ObservatoryEnvironment() + env = SandboxEnvironment() with env.create() as t: self.add_dummy_dag_model(t, self.ext_dag_id, "@weekly") execution_date = pendulum.datetime(2021, 9, 1) - wf = MonitoringWorkflow(start_date=self.start_date, ext_dag_id=self.ext_dag_id) - dag = wf.make_dag() + dag = create_dag(start_date=self.start_date, ext_dag_id=self.ext_dag_id) with env.create_dag_run(dag=dag, execution_date=execution_date): ti = env.run_task(self.sensor_task_id) self.assertEqual(ti.state, State.SUCCESS) @@ -142,7 +142,7 @@ def add_dummy_dag_model(self, t: str, dag_id: str, schedule: str): self.update_db(object=model) def run_dummy_dag( - self, env: ObservatoryEnvironment, execution_date: pendulum.DateTime, task_id: str = "dummy_task" + self, env: SandboxEnvironment, execution_date: pendulum.DateTime, task_id: str = "dummy_task" ): dag = make_dummy_dag(self.ext_dag_id, execution_date) @@ -158,33 +158,31 @@ def run_dummy_dag( self.assertEqual(dagruns[-1].execution_date, execution_date) def test_execution_on_oldest_boundary(self): - env = ObservatoryEnvironment() + env = SandboxEnvironment() with env.create(): execution_date = pendulum.datetime(2021, 8, 25) self.run_dummy_dag(env, execution_date) execution_date = pendulum.datetime(2021, 9, 1) - wf = MonitoringWorkflow(start_date=self.start_date, ext_dag_id=self.ext_dag_id) - dag = wf.make_dag() + dag = create_dag(start_date=self.start_date, ext_dag_id=self.ext_dag_id) with env.create_dag_run(dag=dag, execution_date=execution_date): ti = env.run_task(self.sensor_task_id) 
self.assertEqual(ti.state, State.SUCCESS) def test_execution_on_newest_boundary(self): - env = ObservatoryEnvironment() + env = SandboxEnvironment() with env.create(): execution_date = pendulum.datetime(2021, 9, 1) self.run_dummy_dag(env, execution_date) execution_date = pendulum.datetime(2021, 9, 1) - wf = MonitoringWorkflow(start_date=self.start_date, ext_dag_id=self.ext_dag_id) - dag = wf.make_dag() + dag = create_dag(start_date=self.start_date, ext_dag_id=self.ext_dag_id) with env.create_dag_run(dag=dag, execution_date=execution_date): ti = env.run_task(self.sensor_task_id) self.assertEqual(ti.state, State.SUCCESS) def test_execution_multiple_dagruns_last_success(self): - env = ObservatoryEnvironment() + env = SandboxEnvironment() with env.create(): execution_date = pendulum.datetime(2021, 8, 25) self.run_dummy_dag(env, execution_date) @@ -193,8 +191,7 @@ def test_execution_multiple_dagruns_last_success(self): self.run_dummy_dag(env, execution_date) execution_date = pendulum.datetime(2021, 9, 1) - wf = MonitoringWorkflow(start_date=self.start_date, ext_dag_id=self.ext_dag_id) - dag = wf.make_dag() + dag = create_dag(start_date=self.start_date, ext_dag_id=self.ext_dag_id) with env.create_dag_run(dag=dag, execution_date=execution_date): ti = env.run_task(self.sensor_task_id) self.assertEqual(ti.state, State.SUCCESS) @@ -206,7 +203,7 @@ def fail_last_dag_run(self): self.update_db(object=last_dag_run) def test_execution_multiple_dagruns_last_fail_reschedule_mode(self): - env = ObservatoryEnvironment() + env = SandboxEnvironment() with env.create(): execution_date = pendulum.datetime(2021, 8, 25) self.run_dummy_dag(env, execution_date) @@ -216,14 +213,13 @@ def test_execution_multiple_dagruns_last_fail_reschedule_mode(self): self.fail_last_dag_run() execution_date = pendulum.datetime(2021, 9, 1) - wf = MonitoringWorkflow(start_date=self.start_date, ext_dag_id=self.ext_dag_id) - dag = wf.make_dag() + dag = create_dag(start_date=self.start_date, ext_dag_id=self.ext_dag_id) with env.create_dag_run(dag=dag, execution_date=execution_date): ti = env.run_task(self.sensor_task_id) self.assertEqual(ti.state, "up_for_reschedule") def test_execution_multiple_dagruns_last_fail_poke_mode(self): - env = ObservatoryEnvironment() + env = SandboxEnvironment() with env.create(): execution_date = pendulum.datetime(2021, 8, 25) self.run_dummy_dag(env, execution_date) @@ -233,8 +229,7 @@ def test_execution_multiple_dagruns_last_fail_poke_mode(self): self.fail_last_dag_run() execution_date = pendulum.datetime(2021, 9, 1) - wf = MonitoringWorkflow(start_date=self.start_date, ext_dag_id=self.ext_dag_id, mode="poke") - dag = wf.make_dag() + dag = create_dag(start_date=self.start_date, ext_dag_id=self.ext_dag_id, mode="poke") with env.create_dag_run(dag=dag, execution_date=execution_date): # ti = env.run_task(self.sensor_task_id) self.assertRaises(AirflowSensorTimeout, env.run_task, self.sensor_task_id) diff --git a/observatory_platform/airflow/tests/test_workflow.py b/observatory_platform/airflow/tests/test_workflow.py new file mode 100644 index 000000000..dd3481e96 --- /dev/null +++ b/observatory_platform/airflow/tests/test_workflow.py @@ -0,0 +1,121 @@ +# Copyright 2019-2024 Curtin University +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Author: James Diprose, Aniek Roelofs, Tuan-Chien + +import os +import shutil +from tempfile import TemporaryDirectory +from unittest.mock import patch + +import pendulum +from airflow.exceptions import AirflowException + +from observatory_platform.airflow.workflow import ( + Workflow, + workflows_to_json_string, + json_string_to_workflows, +) +from observatory_platform.airflow.workflow import get_data_path, fetch_dag_bag, make_workflow_folder +from observatory_platform.config import module_file_path +from observatory_platform.sandbox.sandbox_environment import SandboxEnvironment +from observatory_platform.sandbox.test_utils import SandboxTestCase + + +class TestWorkflow(SandboxTestCase): + + def __init__(self, *args, **kwargs): + super(TestWorkflow, self).__init__(*args, **kwargs) + self.fixtures_path = module_file_path("observatory_platform.airflow.tests.fixtures") + + @patch("observatory_platform.airflow.workflow.Variable.get") + def test_get_data_path(self, mock_variable_get): + """Tests the function that retrieves the data_path airflow variable""" + # 1 - no variable available + mock_variable_get.return_value = None + self.assertRaises(AirflowException, get_data_path) + + # 2 - available in Airflow variable + mock_variable_get.return_value = "env_return" + self.assertEqual("env_return", get_data_path()) + + def test_fetch_dag_bag(self): + """Test fetch_dag_bag""" + + env = SandboxEnvironment() + with env.create() as t: + # No DAGs found + dag_bag = fetch_dag_bag(t) + print(f"DAGS found on path: {t}") + for dag_id in dag_bag.dag_ids: + print(f" {dag_id}") + self.assertEqual(0, len(dag_bag.dag_ids)) + + # Bad DAG + src = os.path.join(self.fixtures_path, "bad_dag.py") + shutil.copy(src, os.path.join(t, "dags.py")) + with self.assertRaises(Exception): + fetch_dag_bag(t) + + # Copy Good DAGs to folder + src = os.path.join(self.fixtures_path, "good_dag.py") + shutil.copy(src, os.path.join(t, "dags.py")) + + # DAGs found + expected_dag_ids = {"hello", "world"} + dag_bag = fetch_dag_bag(t) + actual_dag_ids = set(dag_bag.dag_ids) + self.assertSetEqual(expected_dag_ids, actual_dag_ids) + + @patch("observatory_platform.airflow.workflow.Variable.get") + def test_make_workflow_folder(self, mock_get_variable): + """Tests the make_workflow_folder function""" + with TemporaryDirectory() as tempdir: + mock_get_variable.return_value = tempdir + run_id = "scheduled__2023-03-26T00:00:00+00:00" # Also can look like: "manual__2023-03-26T00:00:00+00:00" + path = make_workflow_folder("test_dag", run_id, "sub_folder", "subsub_folder") + self.assertEqual( + path, + os.path.join(tempdir, f"test_dag/scheduled__2023-03-26T00:00:00+00:00/sub_folder/subsub_folder"), + ) + + def test_workflows_to_json_string(self): + workflows = [ + Workflow( + dag_id="my_dag", + name="My DAG", + class_name="observatory_platform.workflows.vm_workflow.VmCreateWorkflow", + kwargs=dict(dt=pendulum.datetime(2021, 1, 1)), + ) + ] + json_string = workflows_to_json_string(workflows) + self.assertEqual( + '[{"dag_id": "my_dag", "name": "My DAG", "class_name": "observatory_platform.workflows.vm_workflow.VmCreateWorkflow", 
"cloud_workspace": null, "kwargs": {"dt": "2021-01-01T00:00:00+00:00"}}]', + json_string, + ) + + def test_json_string_to_workflows(self): + json_string = '[{"dag_id": "my_dag", "name": "My DAG", "class_name": "observatory_platform.workflows.vm_workflow.VmCreateWorkflow", "cloud_workspace": null, "kwargs": {"dt": "2021-01-01T00:00:00+00:00"}}]' + actual_workflows = json_string_to_workflows(json_string) + self.assertEqual( + [ + Workflow( + dag_id="my_dag", + name="My DAG", + class_name="observatory_platform.workflows.vm_workflow.VmCreateWorkflow", + kwargs=dict(dt=pendulum.datetime(2021, 1, 1)), + ) + ], + actual_workflows, + ) diff --git a/observatory_platform/airflow/workflow.py b/observatory_platform/airflow/workflow.py new file mode 100644 index 000000000..1373848fd --- /dev/null +++ b/observatory_platform/airflow/workflow.py @@ -0,0 +1,372 @@ +# Copyright 2019-2024 Curtin University +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Author: James Diprose, Aniek Roelofs, Tuan Chien + + +from __future__ import annotations + +import json +import logging +import os +import shutil +from dataclasses import dataclass, field +from pydoc import locate +from typing import Any, Dict, List, Optional + +import pendulum +from airflow import AirflowException +from airflow.models import DagBag, Variable + +from observatory_platform.airflow.airflow import delete_old_xcoms +from observatory_platform.config import AirflowVars + + +def get_data_path() -> str: + """Grabs the DATA_PATH airflow vairable + + :raises AirflowException: Raised if the variable does not exist + :return: DATA_PATH variable contents + """ + + # Try to get environment variable from environment variable first + data_path = os.environ.get(AirflowVars.DATA_PATH) + if data_path is not None: + return data_path + + # Try to get from Airflow Variable + data_path = Variable.get(AirflowVars.DATA_PATH) + if data_path is not None: + return data_path + + raise AirflowException("DATA_PATH variable could not be found.") + + +def fetch_workflows() -> List[Workflow]: + workflows = [] + workflows_str = Variable.get(AirflowVars.WORKFLOWS) + logging.info(f"workflows_str: {workflows_str}") + + if workflows_str is not None and workflows_str.strip() != "": + try: + workflows = json_string_to_workflows(workflows_str) + logging.info(f"workflows: {workflows}") + except json.decoder.JSONDecodeError as e: + e.msg = f"workflows_str: {workflows_str}\n\n{e.msg}" + + return workflows + + +def load_dags_from_config(): + for workflow in fetch_workflows(): + dag_id = workflow.dag_id + logging.info(f"Making Workflow: {workflow.name}, dag_id={dag_id}") + dag = make_dag(workflow) + + logging.info(f"Adding DAG: dag_id={dag_id}, dag={dag}") + globals()[dag_id] = dag + + +def make_dag(workflow: Workflow): + """Make a DAG instance from a Workflow config. + :param workflow: the workflow configuration. + :return: the workflow instance. 
+    """
+
+    cls = locate(workflow.class_name)
+    if cls is None:
+        raise ModuleNotFoundError(f"dag_id={workflow.dag_id}: could not locate class_name={workflow.class_name}")
+
+    return cls(dag_id=workflow.dag_id, cloud_workspace=workflow.cloud_workspace, **workflow.kwargs)
+
+
+def make_workflow_folder(dag_id: str, run_id: str, *subdirs: str) -> str:
+    """Return the path to this DAG run's workflow folder, creating it if it does not exist.
+
+    :param dag_id: The ID of the DAG. This is used to find/create the workflow folder.
+    :param run_id: The Airflow DAG's run ID. Examples: "scheduled__2023-03-26T00:00:00+00:00" or "manual__2023-03-26T00:00:00+00:00".
+    :param subdirs: The folder path structure (if any) to create inside the workspace, e.g. 'download' or 'transform'.
+    :return: the path of the workflow folder.
+    """
+
+    path = os.path.join(get_data_path(), dag_id, run_id, *subdirs)
+    os.makedirs(path, exist_ok=True)
+    return path
+
+
+def fetch_dag_bag(path: str, include_examples: bool = False) -> DagBag:
+    """Load a DAG Bag from a given path.
+
+    :param path: the path to the DAG bag.
+    :param include_examples: whether to include example DAGs or not.
+    :return: the DagBag.
+    """
+    logging.info(f"Loading DAG bag from path: {path}")
+    dag_bag = DagBag(path, include_examples=include_examples)
+
+    if dag_bag is None:
+        raise Exception(f"DagBag could not be loaded from path: {path}")
+
+    if len(dag_bag.import_errors):
+        # Collate loading errors into a single string and raise it as an exception
+        results = []
+        for path, exception in dag_bag.import_errors.items():
+            results.append(f"DAG import exception: {path}\n{exception}\n\n")
+        raise Exception("\n".join(results))
+
+    return dag_bag
+
+
+def cleanup(dag_id: str, execution_date: str, workflow_folder: str = None, retention_days=31) -> None:
+    """Delete all files, folders and XComs associated with a release.
+
+    :param dag_id: The ID of the DAG whose XComs will be removed.
+    :param execution_date: The execution date of the DAG run.
+    :param workflow_folder: The top-level workflow folder to clean up.
+    :param retention_days: How many days of XCom messages to retain.
+    """
+    if workflow_folder:
+        try:
+            shutil.rmtree(workflow_folder)
+        except FileNotFoundError as e:
+            logging.warning(f"No such file or directory {workflow_folder}: {e}")
+
+    delete_old_xcoms(dag_id=dag_id, execution_date=execution_date, retention_days=retention_days)
+
+
+class CloudWorkspace:
+    def __init__(
+        self,
+        *,
+        project_id: str,
+        download_bucket: str,
+        transform_bucket: str,
+        data_location: str,
+        output_project_id: Optional[str] = None,
+    ):
+        """The CloudWorkspace settings used by workflows.
+
+        project_id: the Google Cloud project id. input_project_id is an alias for project_id.
+        download_bucket: the Google Cloud Storage bucket where downloads will be stored.
+        transform_bucket: the Google Cloud Storage bucket where transformed data will be stored.
+        data_location: the data location for storing information, e.g. where BigQuery datasets should be located.
+        output_project_id: an optional Google Cloud project id when the outputs of a workflow should be stored in a
+        different project to the inputs. If an output_project_id is not supplied, the project_id will be used.
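For illustration only (not part of this patch), a task might combine make_workflow_folder and cleanup like this; the dag_id, run_id and date below are hypothetical, and get_data_path() requires the DATA_PATH environment variable or Airflow Variable to be set:

run_id = "scheduled__2023-03-26T00:00:00+00:00"
download_folder = make_workflow_folder("my_dag", run_id, "download")  # <data_path>/my_dag/<run_id>/download

# ... write downloaded files into download_folder ...

# At the end of the DAG run, remove the whole run folder and prune old XComs
cleanup(
    dag_id="my_dag",
    execution_date="2023-03-26",  # illustrative value; passed through to delete_old_xcoms
    workflow_folder=make_workflow_folder("my_dag", run_id),
    retention_days=31,
)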
+ """ + + self._project_id = project_id + self._download_bucket = download_bucket + self._transform_bucket = transform_bucket + self._data_location = data_location + self._output_project_id = output_project_id + + @property + def project_id(self) -> str: + return self._project_id + + @project_id.setter + def project_id(self, project_id: str): + self._project_id = project_id + + @property + def download_bucket(self) -> str: + return self._download_bucket + + @download_bucket.setter + def download_bucket(self, download_bucket: str): + self._download_bucket = download_bucket + + @property + def transform_bucket(self) -> str: + return self._transform_bucket + + @transform_bucket.setter + def transform_bucket(self, transform_bucket: str): + self._transform_bucket = transform_bucket + + @property + def data_location(self) -> str: + return self._data_location + + @data_location.setter + def data_location(self, data_location: str): + self._data_location = data_location + + @property + def input_project_id(self) -> str: + return self._project_id + + @input_project_id.setter + def input_project_id(self, project_id: str): + self._project_id = project_id + + @property + def output_project_id(self) -> Optional[str]: + if self._output_project_id is None: + return self._project_id + return self._output_project_id + + @output_project_id.setter + def output_project_id(self, output_project_id: Optional[str]): + self._output_project_id = output_project_id + + @staticmethod + def from_dict(dict_: Dict) -> CloudWorkspace: + """Constructs a CloudWorkspace instance from a dictionary. + + :param dict_: the dictionary. + :return: the Workflow instance. + """ + + project_id = dict_.get("project_id") + download_bucket = dict_.get("download_bucket") + transform_bucket = dict_.get("transform_bucket") + data_location = dict_.get("data_location") + output_project_id = dict_.get("output_project_id") + + return CloudWorkspace( + project_id=project_id, + download_bucket=download_bucket, + transform_bucket=transform_bucket, + data_location=data_location, + output_project_id=output_project_id, + ) + + def to_dict(self) -> Dict: + """CloudWorkspace instance to dictionary. + + :return: the dictionary. + """ + + return dict( + project_id=self._project_id, + download_bucket=self._download_bucket, + transform_bucket=self._transform_bucket, + data_location=self._data_location, + output_project_id=self.output_project_id, + ) + + @staticmethod + def parse_cloud_workspaces(list: List) -> List[CloudWorkspace]: + """Parse the cloud workspaces list object into a list of CloudWorkspace instances. + + :param list: the list. + :return: a list of CloudWorkspace instances. + """ + + return [CloudWorkspace.from_dict(dict_) for dict_ in list] + + +@dataclass +class Workflow: + """A Workflow configuration. + + Attributes: + dag_id: the Airflow DAG identifier for the workflow. + name: a user-friendly name for the workflow. + class_name: the fully qualified class name for the workflow class. + cloud_workspace: the Cloud Workspace to use when running the workflow. + kwargs: a dictionary containing optional keyword arguments that are injected into the workflow constructor. + """ + + dag_id: str = None + name: str = None + class_name: str = None + cloud_workspace: CloudWorkspace = None + kwargs: Optional[Dict] = field(default_factory=lambda: dict()) + + def to_dict(self) -> Dict: + """Workflow instance to dictionary. + + :return: the dictionary. 
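As a quick sketch (the project and bucket names are made up), CloudWorkspace can be round-tripped through a plain dict, and output_project_id falls back to project_id when it is not supplied:

workspace = CloudWorkspace.from_dict(
    {
        "project_id": "my-project",
        "download_bucket": "my-download-bucket",
        "transform_bucket": "my-transform-bucket",
        "data_location": "us",
    }
)
assert workspace.output_project_id == "my-project"  # no output_project_id given, so project_id is returned
assert workspace.to_dict()["output_project_id"] == "my-project"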
+ """ + + cloud_workspace = self.cloud_workspace + if self.cloud_workspace is not None: + cloud_workspace = self.cloud_workspace.to_dict() + + return dict( + dag_id=self.dag_id, + name=self.name, + class_name=self.class_name, + cloud_workspace=cloud_workspace, + kwargs=self.kwargs, + ) + + @staticmethod + def from_dict(dict_: Dict) -> Workflow: + """Constructs a Workflow instance from a dictionary. + + :param dict_: the dictionary. + :return: the Workflow instance. + """ + + dag_id = dict_.get("dag_id") + name = dict_.get("name") + class_name = dict_.get("class_name") + + cloud_workspace = dict_.get("cloud_workspace") + if cloud_workspace is not None: + cloud_workspace = CloudWorkspace.from_dict(cloud_workspace) + + kwargs = dict_.get("kwargs", dict()) + + return Workflow(dag_id, name, class_name, cloud_workspace, kwargs) + + @staticmethod + def parse_workflows(list: List) -> List[Workflow]: + """Parse the workflows list object into a list of Workflow instances. + + :param list: the list. + :return: a list of Workflow instances. + """ + + return [Workflow.from_dict(dict_) for dict_ in list] + + +class PendulumDateTimeEncoder(json.JSONEncoder): + def default(self, obj: Any) -> Any: + if isinstance(obj, pendulum.DateTime): + return obj.isoformat() + return super().default(obj) + + +def workflows_to_json_string(workflows: List[Workflow]) -> str: + """Covnert a list of Workflow instances to a JSON string. + + :param workflows: the Workflow instances. + :return: a JSON string. + """ + + data = [workflow.to_dict() for workflow in workflows] + return json.dumps(data, cls=PendulumDateTimeEncoder) + + +def json_string_to_workflows(json_string: str) -> List[Workflow]: + """Convert a JSON string into a list of Workflow instances. + + :param json_string: a JSON string version of a list of Workflow instances. + :return: a list of Workflow instances. + """ + + def parse_datetime(obj): + for key, value in obj.items(): + try: + obj[key] = pendulum.parse(value) + except (ValueError, TypeError): + pass + return obj + + data = json.loads(json_string, object_hook=parse_datetime) + return Workflow.parse_workflows(data) diff --git a/observatory-platform/observatory/platform/config.py b/observatory_platform/config.py similarity index 83% rename from observatory-platform/observatory/platform/config.py rename to observatory_platform/config.py index 0ce702c76..4ca9105e9 100644 --- a/observatory-platform/observatory/platform/config.py +++ b/observatory_platform/config.py @@ -21,15 +21,15 @@ class AirflowVars: - DATA_PATH = "data_path" - WORKFLOWS = "workflows" - DAGS_MODULE_NAMES = "dags_module_names" + DATA_PATH = "DATA_PATH" + WORKFLOWS = "WORKFLOWS" class AirflowConns: SLACK = "slack" TERRAFORM = "terraform" OBSERVATORY_API = "observatory_api" + GCP_CONN_ID = "google_cloud_default" class Tag: @@ -41,7 +41,7 @@ class Tag: def module_file_path(module_path: str, nav_back_steps: int = -1) -> str: """Get the file path of a module, given the Python import path to the module. - :param module_path: the Python import path to the module, e.g. observatory.platform.dags + :param module_path: the Python import path to the module, e.g. observatory_platform.dags :param nav_back_steps: the number of steps on the path to step back. :return: the file path to the module. """ @@ -74,19 +74,10 @@ def observatory_home(*subdirs) -> str: return path -def terraform_credentials_path() -> str: - """Get the path to the terraform credentials file that is created with 'terraform login'. 
- - :return: the path to the terraform credentials file - """ - - return os.path.join(pathlib.Path.home(), ".terraform.d/credentials.tfrc.json") - - def sql_templates_path() -> str: """Return the path to the SQL templates. :return: the path. """ - return module_file_path("observatory.platform.sql") + return module_file_path("observatory_platform.sql") diff --git a/observatory_platform/dataset_api.py b/observatory_platform/dataset_api.py new file mode 100644 index 000000000..25e2429c7 --- /dev/null +++ b/observatory_platform/dataset_api.py @@ -0,0 +1,349 @@ +# Copyright 2020-2024 Curtin University +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Author: Author: Tuan Chien, James Diprose + +from __future__ import annotations + +import dataclasses +import datetime +import os.path +from typing import List, Optional, Dict, Union + +import pendulum +from google.cloud import bigquery + +from observatory_platform.config import module_file_path +from observatory_platform.google.bigquery import ( + bq_load_from_memory, + bq_create_dataset, + bq_create_empty_table, + bq_run_query, +) + + +@dataclasses.dataclass +class DatasetRelease: + id: int + dag_id: str + dataset_id: str + dag_run_id: str + data_interval_start: pendulum.DateTime + data_interval_end: pendulum.DateTime + snapshot_date: pendulum.DateTime + partition_date: pendulum.DateTime + changefile_start_date: pendulum.DateTime + changefile_end_date: pendulum.DateTime + sequence_start: int + sequence_end: int + extra: Dict + created: pendulum.DateTime + modified: pendulum.DateTime + + def __init__( + self, + *, + dag_id: str, + dataset_id: str, + dag_run_id: str, + created: pendulum.DateTime, + modified: pendulum.DateTime, + data_interval_start: Union[pendulum.DateTime, str] = None, + data_interval_end: Union[pendulum.DateTime, str] = None, + snapshot_date: Union[pendulum.DateTime, str] = None, + partition_date: Union[pendulum.DateTime, str] = None, + changefile_start_date: Union[pendulum.DateTime, str] = None, + changefile_end_date: Union[pendulum.DateTime, str] = None, + sequence_start: int = None, + sequence_end: int = None, + extra: dict = None, + ): + """Construct a DatasetRelease object. + + :param dag_id: the DAG ID. + :param dataset_id: the dataset ID. + :param dag_run_id: the DAG's run ID. + :param created: datetime created in UTC. + :param modified: datetime modified in UTC. + :param data_interval_start: the DAGs data interval start. Date is inclusive. + :param data_interval_end: the DAGs data interval end. Date is exclusive. + :param snapshot_date: the release date of the snapshot. + :param partition_date: the partition date. + :param changefile_start_date: the date of the first changefile processed in this release. + :param changefile_end_date: the date of the last changefile processed in this release. + :param sequence_start: the starting sequence number of files that make up this release. + :param sequence_end: the end sequence number of files that make up this release. + :param extra: optional extra field for storing any data. 
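For example (IDs and dates are illustrative), a snapshot-style telescope might record a release like this:

import pendulum

release = DatasetRelease(
    dag_id="doi_workflow",
    dataset_id="doi",
    dag_run_id="scheduled__2023-03-26T00:00:00+00:00",
    created=pendulum.now("UTC"),
    modified=pendulum.now("UTC"),
    snapshot_date=pendulum.datetime(2023, 3, 26),
)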
+ + """ + + self.dag_id = dag_id + self.dataset_id = dataset_id + self.dag_run_id = dag_run_id + self.data_interval_start = data_interval_start + self.data_interval_end = data_interval_end + self.snapshot_date = snapshot_date + self.partition_date = partition_date + self.changefile_start_date = changefile_start_date + self.changefile_end_date = changefile_end_date + self.sequence_start = sequence_start + self.sequence_end = sequence_end + self.extra = extra + self.created = created + self.modified = modified + + @staticmethod + def from_dict(_dict: Dict) -> DatasetRelease: + return DatasetRelease( + dag_id=_dict["dag_id"], + dataset_id=_dict["dataset_id"], + dag_run_id=_dict["dag_run_id"], + created=bq_timestamp_to_pendulum(_dict["created"]), + modified=bq_timestamp_to_pendulum(_dict["modified"]), + data_interval_start=bq_timestamp_to_pendulum(_dict.get("data_interval_start")), + data_interval_end=bq_timestamp_to_pendulum(_dict.get("data_interval_end")), + snapshot_date=bq_timestamp_to_pendulum(_dict.get("snapshot_date")), + partition_date=bq_timestamp_to_pendulum(_dict.get("partition_date")), + changefile_start_date=bq_timestamp_to_pendulum(_dict.get("changefile_start_date")), + changefile_end_date=bq_timestamp_to_pendulum(_dict.get("changefile_end_date")), + sequence_start=_dict.get("sequence_start"), + sequence_end=_dict.get("sequence_end"), + extra=_dict.get("extra"), + ) + + def to_dict(self) -> Dict: + return dict( + dag_id=self.dag_id, + dataset_id=self.dataset_id, + dag_run_id=self.dag_run_id, + created=self.created.to_iso8601_string(), + modified=self.modified.to_iso8601_string(), + data_interval_start=pendulum_to_bq_timestamp(self.data_interval_start), + data_interval_end=pendulum_to_bq_timestamp(self.data_interval_end), + snapshot_date=pendulum_to_bq_timestamp(self.snapshot_date), + partition_date=pendulum_to_bq_timestamp(self.partition_date), + changefile_start_date=pendulum_to_bq_timestamp(self.changefile_start_date), + changefile_end_date=pendulum_to_bq_timestamp(self.changefile_end_date), + sequence_start=self.sequence_start, + sequence_end=self.sequence_end, + extra=self.extra, + ) + + def __eq__(self, other): + if isinstance(other, DatasetRelease): + return self.__dict__ == other.__dict__ + return False + + +class DatasetAPI: + def __init__( + self, + project_id: str = None, + dataset_id: str = "dataset_api", + table_id: str = "dataset_releases", + location: str = "us", + client: Optional[bigquery.Client] = None, + ): + """Create a DatasetAPI instance. + + :param project_id: the BigQuery project ID. + :param dataset_id: the BigQuery dataset ID. + :param table_id: the BigQuery table ID. + :param location: the BigQuery dataset location. + :param client: Optional BigQuery client. + """ + + parts = [] + if project_id is None: + project_id = get_bigquery_default_project() + parts.append(project_id) + parts.append(dataset_id) + parts.append(table_id) + + self.project_id = project_id + self.dataset_id = dataset_id + self.table_id = table_id + self.location = location + self.client = client + self.full_table_id = ".".join(parts) + self.schema_file_path = os.path.join(module_file_path("observatory_platform.schema"), "dataset_release.json") + + def seed_db(self): + """Seed the BigQuery dataset and dataset release table. + + :return: None. 
+ """ + + # Create BigQuery dataset if it does not exist + bq_create_dataset( + project_id=self.project_id, + dataset_id=self.dataset_id, + location=self.location, + description="Observatory Platform Dataset Release API", + client=self.client, + ) + + # Load empty table + bq_create_empty_table( + table_id=self.full_table_id, + schema_file_path=self.schema_file_path, + exists_ok=True, + client=self.client, + ) + + def add_dataset_release(self, release: DatasetRelease): + """Adds a DatasetRelease. + + :param release: the release. + :return: None. + """ + + # Load data + success = bq_load_from_memory( + table_id=self.full_table_id, + records=[release.to_dict()], + write_disposition=bigquery.WriteDisposition.WRITE_APPEND, + schema_file_path=self.schema_file_path, + client=self.client, + ) + if not success: + raise Exception("Failed to add dataset release") + + def get_dataset_releases( + self, *, dag_id: str, dataset_id: str, date_key: str = "created", limit: int | None = None + ) -> List[DatasetRelease]: + """Get a list of dataset releases for a given dataset. + + :param dag_id: dag id. + :param dataset_id: Dataset id. + :param date_key: the date key to use when sorting by date. One of: "created", "modified", "data_interval_start", + "data_interval_end", "snapshot_date", "partition_date", "changefile_start_date" or "changefile_end_date". + :param limit: the maximum number of rows to return. + :return: List of dataset releases. + """ + + valid_date_keys = { + "created", + "modified", + "data_interval_start", + "data_interval_end", + "snapshot_date", + "partition_date", + "changefile_start_date", + "changefile_end_date", + } + if date_key not in valid_date_keys: + raise ValueError(f"get_dataset_releases: invalid date_key: {date_key}, should be one of: {valid_date_keys}") + + sql = [f"SELECT * FROM `{self.full_table_id}` WHERE dag_id = '{dag_id}' AND dataset_id = '{dataset_id}'"] + sql.append(f"ORDER BY {date_key} DESC") + if limit is not None: + sql.append(f"LIMIT {limit}") + + # Fetch results + results = bq_run_query("\n".join(sql), client=self.client) + + # Convert to DatasetRelease objects + results = [DatasetRelease.from_dict(dict(result)) for result in results] + + return results + + def get_latest_dataset_release(self, *, dag_id: str, dataset_id: str, date_key: str) -> Optional[DatasetRelease]: + """Get the latest dataset release. + + :param dag_id: the Airflow DAG id. + :param dataset_id: the dataset id. + :param date_key: the date key. One of: "created", "modified", "data_interval_start", "data_interval_end", + "snapshot_date", "partition_date", "changefile_start_date" or "changefile_end_date". + :return: the latest release or None if there is no release. + """ + + releases = self.get_dataset_releases(dag_id=dag_id, dataset_id=dataset_id, date_key=date_key, limit=1) + if len(releases) == 0: + return None + return releases[0] + + def is_first_release(self, *, dag_id: str, dataset_id: str) -> bool: + """Use the API to check whether this is the first release of a dataset, i.e., are there no dataset release records. + + :param dag_id: DAG ID. + :param dataset_id: dataset id. + :return: Whether this is the first release. + """ + + results = bq_run_query( + f"SELECT COUNT(*) as count FROM `{self.full_table_id}` WHERE dag_id = '{dag_id}' AND dataset_id = '{dataset_id}'", + client=self.client, + ) + count = results[0]["count"] + return count == 0 + + +def build_schedule(sched_start_date: pendulum.DateTime, sched_end_date: pendulum.DateTime): + """Useful for API based data sources. 
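A rough usage sketch, assuming a hypothetical project and the `release` object from the example above:

api = DatasetAPI(project_id="my-project")  # defaults: dataset_id="dataset_api", table_id="dataset_releases"
api.seed_db()  # create the dataset and empty release table if they do not exist

if api.is_first_release(dag_id="doi_workflow", dataset_id="doi"):
    print("No previous releases recorded")

api.add_dataset_release(release)
latest = api.get_latest_dataset_release(dag_id="doi_workflow", dataset_id="doi", date_key="snapshot_date")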
+ + Create a fetch schedule to specify what date ranges to use for each API call. Will default to once a month + for now, but in the future if we are minimising API calls, this can be a more complicated scheme. + + :param sched_start_date: the schedule start date. + :param sched_end_date: the end date of the schedule. + :return: list of (section_start_date, section_end_date) pairs from start_date to current Airflow DAG start date. + """ + + schedule = [] + + for start_date in pendulum.Period(start=sched_start_date, end=sched_end_date).range("months"): + if start_date >= sched_end_date: + break + end_date = start_date.add(months=1).subtract(days=1).end_of("day") + end_date = min(sched_end_date, end_date) + schedule.append(pendulum.Period(start_date.date(), end_date.date())) + + return schedule + + +def bq_timestamp_to_pendulum(obj: any) -> pendulum.DateTime | None: + """Convert a BigQuery timestamp to a pendulum DateTime object. + + :param obj: None, a string or datetime.datetime instance. + :return: pendulum.DateTime or None. + """ + + if obj is None: + return obj + elif isinstance(obj, datetime.datetime): + return pendulum.instance(obj) + elif isinstance(obj, str): + return pendulum.parse(obj) + raise NotImplementedError("Unsupported type") + + +def pendulum_to_bq_timestamp(dt: pendulum.DateTime | None) -> str: + """Convert a pendulum instance to a BigQuery timestamp string. + + :param dt: the pendulum DateTime instance or None. + :return: the string. + """ + + return None if dt is None else dt.to_iso8601_string() + + +def get_bigquery_default_project() -> str: + """Get the default BigQuery project ID. + + :return: BigQuery project ID. + """ + + client = bigquery.Client() + return client.project diff --git a/observatory-platform/observatory/platform/files.py b/observatory_platform/files.py similarity index 99% rename from observatory-platform/observatory/platform/files.py rename to observatory_platform/files.py index 9d853c9ca..c6a8263b8 100644 --- a/observatory-platform/observatory/platform/files.py +++ b/observatory_platform/files.py @@ -26,23 +26,22 @@ import re import shutil import subprocess +import zlib from _hashlib import HASH from datetime import datetime from functools import partial from pathlib import Path from subprocess import Popen -from typing import Any, List -from typing import BinaryIO, Dict +from typing import BinaryIO, Dict, Any, List import json_lines import jsonlines import numpy as np import pandas as pd -import zlib from google.cloud import bigquery from google_crc32c import Checksum as Crc32cChecksum -from observatory.platform.utils.proc_utils import wait_for_process +from observatory_platform.proc_utils import wait_for_process def list_files(path: str, regex: str = None) -> List[str]: diff --git a/observatory-platform/observatory/platform/dags/__init__.py b/observatory_platform/google/__init__.py similarity index 100% rename from observatory-platform/observatory/platform/dags/__init__.py rename to observatory_platform/google/__init__.py diff --git a/observatory-platform/observatory/platform/bigquery.py b/observatory_platform/google/bigquery.py similarity index 87% rename from observatory-platform/observatory/platform/bigquery.py rename to observatory_platform/google/bigquery.py index 3787f44c0..16434a70e 100644 --- a/observatory-platform/observatory/platform/bigquery.py +++ b/observatory_platform/google/bigquery.py @@ -1,4 +1,4 @@ -# Copyright 2020-2023 Curtin University +# Copyright 2020-2024 Curtin University # # Licensed under the Apache License, Version 2.0 
(the "License"); # you may not use this file except in compliance with the License. @@ -26,7 +26,7 @@ import jsonlines import pendulum -from google.api_core.exceptions import BadRequest, Conflict +from google.api_core.exceptions import BadRequest from google.cloud import bigquery from google.cloud.bigquery import ( LoadJob, @@ -43,11 +43,8 @@ from google.cloud.exceptions import Conflict, NotFound from natsort import natsorted -from observatory.platform.config import sql_templates_path -from observatory.platform.utils.jinja2_utils import ( - make_sql_jinja2_filename, - render_template, -) +from observatory_platform.config import sql_templates_path +from observatory_platform.jinja2_utils import make_sql_jinja2_filename, render_template # BigQuery single query byte limit. # Daily limit is set in Terraform @@ -127,15 +124,17 @@ def bq_table_shard_info(table_id: str) -> Tuple[str, Optional[pendulum.Date]]: return table_id[:-8], pendulum.parse(results.group(0)) -def bq_table_exists(table_id: str) -> bool: +def bq_table_exists(table_id: str, client: Optional[bigquery.Client] = None) -> bool: """Checks whether a BigQuery table exists or not. - :param table_id: the fully qualified BigQuery table identifier + :param table_id: the fully qualified BigQuery table identifier. + :param client: BigQuery client. If None default Client is created. :return: whether the table exists or not. """ assert_table_id(table_id) - client = bigquery.Client() + if client is None: + client = bigquery.Client() table_exists = True try: @@ -147,7 +146,11 @@ def bq_table_exists(table_id: str) -> bool: def bq_select_table_shard_dates( - *, table_id: str, end_date: Union[pendulum.DateTime, pendulum.Date], limit: int = 1 + *, + table_id: str, + end_date: Union[pendulum.DateTime, pendulum.Date], + limit: int = 1, + client: Optional[bigquery.Client] = None, ) -> List[pendulum.Date]: """Returns a list of table shard dates, sorted from the most recent to the oldest date. By default it returns the first result. @@ -155,6 +158,7 @@ def bq_select_table_shard_dates( :param table_id: the fully qualified BigQuery table identifier, excluding any shard date. :param end_date: the end date of the table suffixes to search for (most recent date). :param limit: the number of results to return. + :param client: BigQuery client. If None default Client is created. :return: """ @@ -166,7 +170,9 @@ def bq_select_table_shard_dates( end_date=end_date.strftime("%Y-%m-%d"), limit=limit, ) - rows = bq_run_query(query) + if client is None: + client = bigquery.Client() + rows = bq_run_query(query, client=client) dates = [] for row in rows: py_date = row["suffix"] @@ -176,23 +182,24 @@ def bq_select_table_shard_dates( def bq_select_latest_table( - *, - table_id: str, - end_date: Union[pendulum.DateTime, pendulum.Date], - sharded: bool, + *, table_id: str, end_date: Union[pendulum.DateTime, pendulum.Date], sharded: bool, client: bigquery.Client = None ): """Select the latest fully qualified BigQuery table identifier. :param table_id: the fully qualified BigQuery table identifier, excluding a shard date. :param end_date: latest date considered. :param sharded: whether the table is sharded or not. + :param client: BigQuery client. If None default Client is created. 
""" assert_table_id(table_id) if sharded: + if client is None: + client = bigquery.Client() table_date = bq_select_table_shard_dates( table_id=table_id, end_date=end_date, + client=client, )[0] table_id = f"{table_id}{table_date.strftime('%Y%m%d')}" @@ -293,16 +300,18 @@ def bq_find_schema( return None -def bq_update_table_description(*, table_id: str, description: str): +def bq_update_table_description(*, table_id: str, description: str, client: Optional[bigquery.Client] = None): """Update a BigQuery table's description. :param table_id: the fully qualified BigQuery table identifier. :param description: the description. + :param client: BigQuery client. If None default Client is created. :return: None. """ # Construct a BigQuery client object. - client = bigquery.Client() + if client is None: + client = bigquery.Client() # Set description on table table = bigquery.Table(table_id) @@ -331,6 +340,7 @@ def bq_load_table( cluster: bool = False, clustering_fields=None, ignore_unknown_values: bool = False, + client: Optional[bigquery.Client] = None, ) -> bool: """Load a BigQuery table from an object on Google Cloud Storage. @@ -352,6 +362,7 @@ def bq_load_table( :param clustering_fields: what fields to cluster on. Default is to overwrite. :param ignore_unknown_values: whether to ignore unknown values or not. + :param client: BigQuery client. If None default Client is created. :return: True if the load job was successful, False otherwise. """ @@ -372,7 +383,8 @@ def bq_load_table( clustering_fields = [] # Create load job - client = bigquery.Client() + if client is None: + client = bigquery.Client() job_config = LoadJobConfig() # Set global options @@ -424,10 +436,11 @@ def bq_load_from_memory( partition_type: bigquery.TimePartitioningType = bigquery.TimePartitioningType.DAY, require_partition_filter=False, write_disposition: str = bigquery.WriteDisposition.WRITE_TRUNCATE, - table_description: str = "", + table_description: str | None = None, cluster: bool = False, clustering_fields=None, ignore_unknown_values: bool = False, + client: Optional[bigquery.Client] = None, ) -> bool: """Load data into BigQuery from memory. @@ -444,6 +457,7 @@ def bq_load_from_memory( :param clustering_fields: what fields to cluster on. Default is to overwrite. :param ignore_unknown_values: whether to ignore unknown values or not. + :param client: BigQuery client. If None default Client is created. :return: True if the load job was successful, False otherwise. """ @@ -456,7 +470,8 @@ def bq_load_from_memory( clustering_fields = [] # Create load job - client = bigquery.Client() + if client is None: + client = bigquery.Client() job_config = LoadJobConfig() if schema_file_path is not None: @@ -502,11 +517,12 @@ def bq_load_from_memory( return state -def bq_query_bytes_estimate(query: str, *args, **kwargs) -> int: +def bq_query_bytes_estimate(query: str, *args, client: bigquery.Client = None, **kwargs) -> int: """Do a dry run of a BigQuery query to estimate the bytes processed. :param query: the query string. :param args: Positional arguments to pass onto the bigquery.Client().query function. + :param client: BigQuery client. If None default Client is created. :param kwargs: Named arguments to pass onto the bigquery.Client().query function. :return: Query bytes estimate. 
""" @@ -518,7 +534,9 @@ def bq_query_bytes_estimate(query: str, *args, **kwargs) -> int: config.dry_run = True kwargs["job_config"] = config - bytes_estimate = bigquery.Client().query(query, *args, **kwargs).total_bytes_processed + if client is None: + client = bigquery.Client() + bytes_estimate = client.query(query, *args, **kwargs).total_bytes_processed return bytes_estimate @@ -543,18 +561,23 @@ def bq_query_bytes_budget_check(*, bytes_budget: int, bytes_estimate: int): raise Exception(f"Bytes estimate {bytes_estimate} exceeds the budget {bytes_budget}.") -def bq_run_query(query: str, bytes_budget: int = BIGQUERY_SINGLE_QUERY_BYTE_LIMIT) -> list: +def bq_run_query( + query: str, bytes_budget: int = BIGQUERY_SINGLE_QUERY_BYTE_LIMIT, client: Optional[bigquery.Client] = None +) -> list: """Run a BigQuery query. Defaults to 1 TiB query budget. :param query: the query to run. :param bytes_budget: Maximum bytes allowed to be processed by the query. + :param client: BigQuery client. If None default Client is created. :return: the results. """ - bytes_estimate = bq_query_bytes_estimate(query) + if client is None: + client = bigquery.Client() + + bytes_estimate = bq_query_bytes_estimate(query, client=client) bq_query_bytes_budget_check(bytes_budget=bytes_budget, bytes_estimate=bytes_estimate) - client = bigquery.Client() query_job = client.query(query) rows = query_job.result() success = query_job.errors is None # throws error when query didn't work @@ -567,12 +590,14 @@ def bq_copy_table( src_table_id: Union[str, list], dst_table_id: str, write_disposition: bigquery.WriteDisposition = bigquery.WriteDisposition.WRITE_TRUNCATE, + client: Optional[bigquery.Client] = None, ) -> bool: """Copy a BigQuery table. :param src_table_id: the fully qualified BigQuery table identifier the source table. :param dst_table_id: the fully qualified BigQuery table identifier of the destination table. :param write_disposition: whether to append, overwrite or throw an error when data already exists in the table. + :param client: BigQuery client. If None default Client is created. :return: whether the table was copied successfully or not. """ @@ -583,9 +608,9 @@ def bq_copy_table( assert_table_id(src_table_id) assert_table_id(dst_table_id) - client = bigquery.Client() + if client is None: + client = bigquery.Client() job_config = bigquery.CopyJobConfig() - job_config.write_disposition = write_disposition job = client.copy_table(src_table_id, dst_table_id, job_config=job_config) @@ -593,18 +618,22 @@ def bq_copy_table( return result.done() -def bq_create_view(*, view_id: str, query: str, update_if_exists: bool = True) -> Table: +def bq_create_view( + *, view_id: str, query: str, update_if_exists: bool = True, client: Optional[bigquery.Client] = None +) -> Table: """Create a BigQuery view. :param view_id: the fully qualified BigQuery table identifier for the view. :param query: the query for the view. - :param update_if_exists: whether to update the view with the input query if it already exists + :param update_if_exists: whether to update the view with the input query if it already exists. + :param client: BigQuery client. If None default Client is created. 
:return: The bigquery table object of the view created/updated """ assert_table_id(view_id) - client = bigquery.Client() + if client is None: + client = bigquery.Client() view = bigquery.Table(view_id) view.view_query = query try: @@ -626,6 +655,7 @@ def bq_create_table_from_query( clustering_fields=None, bytes_budget: int = BIGQUERY_SINGLE_QUERY_BYTE_LIMIT, schema_file_path: str = None, + client: Optional[bigquery.Client] = None, ) -> bool: """Create a BigQuery dataset from a provided query. Defaults to 0.5 TiB query budget. If a schema file path is given and the table does not exist yet, then an empty table will be created with this @@ -638,6 +668,7 @@ def bq_create_table_from_query( :param clustering_fields: what fields to cluster on. :param bytes_budget: Maximum bytes allowed to be processed by query. :param schema_file_path: path on local file system to BigQuery table schema. + :param client: BigQuery client. If None default Client is created. :return: whether successful or not. """ @@ -654,7 +685,8 @@ def bq_create_table_from_query( logging.info(f"{func_name}: create bigquery table from query, {msg}") # Create empty table with schema. Delete the original table if it exists. - client = bigquery.Client() + if client is None: + client = bigquery.Client() write_disposition = bigquery.WriteDisposition.WRITE_TRUNCATE table = bigquery.Table(table_id) if schema_file_path: @@ -685,7 +717,7 @@ def bq_create_table_from_query( if clustering_fields: job_config.clustering_fields = clustering_fields - bytes_estimate = bq_query_bytes_estimate(sql, job_config=job_config) + bytes_estimate = bq_query_bytes_estimate(sql, job_config=job_config, client=client) bq_query_bytes_budget_check(bytes_budget=bytes_budget, bytes_estimate=bytes_estimate) query_job: QueryJob = client.query(sql, job_config=job_config) @@ -695,7 +727,9 @@ def bq_create_table_from_query( return success -def bq_create_dataset(*, project_id: str, dataset_id: str, location: str, description: str = "") -> bigquery.Dataset: +def bq_create_dataset( + *, project_id: str, dataset_id: str, location: str, description: str = "", client: Optional[bigquery.Client] = None +) -> bigquery.Dataset: """Create a BigQuery dataset. :param project_id: the Google Cloud project id. @@ -703,6 +737,7 @@ def bq_create_dataset(*, project_id: str, dataset_id: str, location: str, descri :param location: the location where the dataset will be stored: https://cloud.google.com/compute/docs/regions-zones/#locations :param description: a description for the dataset + :param client: BigQuery client. If None default Client is created. :return: None """ @@ -712,7 +747,8 @@ def bq_create_dataset(*, project_id: str, dataset_id: str, location: str, descri dataset_ref = f"{project_id}.{dataset_id}" # Make dataset handle - client = bigquery.Client() + if client is None: + client = bigquery.Client() ds = bigquery.Dataset(dataset_ref) # Set properties @@ -733,6 +769,8 @@ def bq_create_empty_table( table_id: str, schema_file_path: str = None, clustering_fields: List = None, + client: Optional[bigquery.Client] = None, + exists_ok: bool = False, ): """Creates an empty BigQuery table. If a path to a schema file is given the table will be created using this schema. @@ -740,6 +778,8 @@ def bq_create_empty_table( :param table_id: the fully qualified BigQuery table identifier of the table we will create. :param schema_file_path: path on local file system to BigQuery table schema. :param clustering_fields: what fields to cluster on. + :param client: BigQuery client. 
If None default Client is created. + :param exists_ok: whether it is OK for the table to exist already. :return: The table instance if the request was successful. """ @@ -747,7 +787,8 @@ def bq_create_empty_table( msg = f"table_id={table_id}, schema_file_path={schema_file_path}" logging.info(f"{func_name}: creating empty bigquery table {msg}") - client = bigquery.Client() + if client is None: + client = bigquery.Client() if schema_file_path: schema = client.schema_from_json(schema_file_path) table = bigquery.Table(table_id, schema=schema) @@ -758,23 +799,25 @@ def bq_create_empty_table( if clustering_fields: table.clustering_fields = clustering_fields - table = client.create_table(table) + table = client.create_table(table, exists_ok=exists_ok) return table -def bq_list_tables(project_id: str, dataset_id: str) -> List[str]: +def bq_list_tables(project_id: str, dataset_id: str, client: Optional[bigquery.Client] = None) -> List[str]: """List all the tables within a BigQuery dataset. :param project_id: the Google Cloud project id. :param dataset_id: the BigQuery dataset id. + :param client: BigQuery client. If None default Client is created. :return: the fully qualified BigQuery table ids. """ - src_client = bigquery.Client() + if client is None: + client = bigquery.Client() table_ids = [] ds = bigquery.Dataset(f"{project_id}.{dataset_id}") - tables = src_client.list_tables(ds, max_results=10000) + tables = client.list_tables(ds, max_results=10000) for table in tables: table_id = str(table.reference) table_ids.append(table_id) @@ -797,12 +840,15 @@ def bq_get_table(table_id: str) -> Optional[BQTable]: return None -def bq_export_table(*, table_id: str, file_type: str, destination_uri: str) -> bool: +def bq_export_table( + *, table_id: str, file_type: str, destination_uri: str, client: Optional[bigquery.Client] = None +) -> bool: """Export a BigQuery table. :param table_id: the fully qualified BigQuery table identifier. :param file_type: the type of file to save the exported data as; csv or jsonl. :param destination_uri: the Google Cloud storage bucket destination URI. + :param client: BigQuery client. If None default Client is created. :return: whether the dataset was exported successfully or not. """ @@ -817,7 +863,8 @@ def bq_export_table(*, table_id: str, file_type: str, destination_uri: str) -> b raise ValueError(f"export_bigquery_table: file type '{file_type}' is not supported") # Create and run extraction job - client = bigquery.Client() + if client is None: + client = bigquery.Client() extract_job_config = bigquery.ExtractJobConfig() # Set gz compression if file type ends in .gz @@ -831,7 +878,9 @@ def bq_export_table(*, table_id: str, file_type: str, destination_uri: str) -> b return extract_job.state == "DONE" -def bq_list_datasets_with_prefix(*, prefix: str = "") -> List[dataset.Dataset]: +def bq_list_datasets_with_prefix( + *, prefix: str = "", client: Optional[bigquery.Client] = None +) -> List[dataset.Dataset]: """List all BigQuery datasets with prefix. Due to multiple unit tests being run at once, need to include @@ -839,10 +888,12 @@ def bq_list_datasets_with_prefix(*, prefix: str = "") -> List[dataset.Dataset]: that it is listed and then that grabbed by the API. :param prefix: Prefix of datasets to list. + :param client: BigQuery client. If None default Client is created. :return: A list of dataset objects that are under the project. 
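For instance (identifiers and schema path are hypothetical), the new exists_ok flag makes table creation idempotent, which is how DatasetAPI.seed_db() above can call it safely on every run:

bq_create_empty_table(
    table_id="my-project.dataset_api.dataset_releases",
    schema_file_path="/path/to/dataset_release.json",
    exists_ok=True,  # no error if the table already exists
    client=bq_client,
)
table_ids = bq_list_tables("my-project", "dataset_api", client=bq_client)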
""" - client = bigquery.Client() + if client is None: + client = bigquery.Client() datasets = list(client.list_datasets()) dataset_list = [] for dataset in datasets: @@ -859,7 +910,7 @@ def bq_list_datasets_with_prefix(*, prefix: str = "") -> List[dataset.Dataset]: return dataset_list -def bq_delete_old_datasets_with_prefix(*, prefix: str, age_to_delete: int): +def bq_delete_old_datasets_with_prefix(*, prefix: str, age_to_delete: int, client: Optional[bigquery.Client] = None): """Deletes datasets that share the same prefix and if it is older than "age_to_delete" hours. Due to multiple unit tests being run at once, need to include a try and except as @@ -868,9 +919,11 @@ def bq_delete_old_datasets_with_prefix(*, prefix: str, age_to_delete: int): :param prefix: The identifying prefix of the datasets to delete. :param age_to_delete: Delete if the age of the bucket is older than this amount. + :param client: BigQuery client. If None default Client is created. """ - client = bigquery.Client() + if client is None: + client = bigquery.Client() # List all datsets in the project with prefix dataset_list = bq_list_datasets_with_prefix(prefix=prefix) @@ -910,12 +963,14 @@ def bq_snapshot( src_table_id: str, dst_table_id: str, expiry_date: pendulum.DateTime = None, + client: Optional[bigquery.Client] = None, ): """Create a BigQuery snapshot of a table. :param src_table_id: the BigQuery table name of the table to snapshot. :param dst_table_id: the date to give the snapshot table. :param expiry_date: the datetime for when the table should expire, e.g. datetime.datetime.now() + datetime.timedelta(minutes=60). If None then table will be permanent. + :param client: BigQuery client. If None default Client is created. :return: if the request was successful. """ @@ -925,7 +980,8 @@ def bq_snapshot( assert_table_id(src_table_id) assert_table_id(dst_table_id) - client = bigquery.Client() + if client is None: + client = bigquery.Client() job_config = CopyJobConfig( operation_type="SNAPSHOT", write_disposition="WRITE_EMPTY", destination_expiration_time=expiry_date.isoformat() ) @@ -939,14 +995,13 @@ def bq_snapshot( def bq_select_columns( - *, - table_id: str, - bytes_budget: Optional[int] = BIGQUERY_SINGLE_QUERY_BYTE_LIMIT, + *, table_id: str, bytes_budget: Optional[int] = BIGQUERY_SINGLE_QUERY_BYTE_LIMIT, client: bigquery.Client = None ) -> List[Dict]: """Select columns from a BigQuery table. :param table_id: the fully qualified BigQuery table identifier. :param bytes_budget: the BigQuery bytes budget. + :param client: BigQuery client. If None default Client is created. :return: the columns, which includes column_name and data_type. """ @@ -962,7 +1017,9 @@ def bq_select_columns( dataset_id=dataset_id, table_id=table_id, ) - rows = bq_run_query(query, bytes_budget=bytes_budget) + if client is None: + client = bigquery.Client() + rows = bq_run_query(query, bytes_budget=bytes_budget, client=client) return [dict(row) for row in rows] @@ -972,6 +1029,7 @@ def bq_upsert_records( upsert_table_id: str, primary_key: Union[str, List[str]], bytes_budget: Optional[int] = BIGQUERY_SINGLE_QUERY_BYTE_LIMIT, + client: bigquery.Client = None, ): """Upserts records (updates and inserts) from an upsert_table into a main_table based on a primary_key. @@ -979,16 +1037,20 @@ def bq_upsert_records( :param upsert_table_id: the fully qualified table identifier for the BigQuery table containing the upserts. :param primary_key: A single key or a list of keys to use to determine which records to upsert. 
:param bytes_budget: the BigQuery bytes budget. + :param client: BigQuery client. If None default Client is created. :return: whether the upsert was successful or not. """ assert_table_id(main_table_id) assert_table_id(upsert_table_id) + if client is None: + client = bigquery.Client() + # Fetch column names in main and upsert table which are used for the update part of the merge # and to check that the columns match - main_columns = bq_select_columns(table_id=main_table_id) - upsert_columns = bq_select_columns(table_id=upsert_table_id) + main_columns = bq_select_columns(table_id=main_table_id, client=client) + upsert_columns = bq_select_columns(table_id=upsert_table_id, client=client) # Assert that the column names and data types in main_table and upsert_table are the same and in the same order # Must be in same order for upsert to work @@ -1019,7 +1081,7 @@ def bq_upsert_records( keys=keys, columns=main_top_level_cols, ) - bq_run_query(query, bytes_budget=bytes_budget) + bq_run_query(query, bytes_budget=bytes_budget, client=client) def bq_delete_records( @@ -1031,6 +1093,7 @@ def bq_delete_records( main_table_primary_key_prefix: str = "", delete_table_primary_key_prefix: str = "", bytes_budget: Optional[int] = BIGQUERY_SINGLE_QUERY_BYTE_LIMIT, + client: bigquery.Client = None, ): """Deletes records from a main_table based on records in a delete_table. @@ -1041,16 +1104,22 @@ def bq_delete_records( :param main_table_primary_key_prefix: an optional prefix to add to the primary key main table cells. :param delete_table_primary_key_prefix: an optional prefix to add to the primary key delete table cells. :param bytes_budget: the bytes budget. + :param client: BigQuery client. If None default Client is created. :return: """ assert_table_id(main_table_id) assert_table_id(delete_table_id) + if client is None: + client = bigquery.Client() + # Fetch column names in main and delete table to check if primary keys match - main_column_index = {item["column_name"]: item["data_type"] for item in bq_select_columns(table_id=main_table_id)} + main_column_index = { + item["column_name"]: item["data_type"] for item in bq_select_columns(table_id=main_table_id, client=client) + } delete_column_index = { - item["column_name"]: item["data_type"] for item in bq_select_columns(table_id=delete_table_id) + item["column_name"]: item["data_type"] for item in bq_select_columns(table_id=delete_table_id, client=client) } # Check that primary_keys are in tables and that data types match @@ -1088,4 +1157,4 @@ def bq_delete_records( delete_table_primary_key_prefix=delete_table_primary_key_prefix, zip=zip, ) - bq_run_query(query, bytes_budget=bytes_budget) + bq_run_query(query, bytes_budget=bytes_budget, client=client) diff --git a/observatory_platform/google/gcp.py b/observatory_platform/google/gcp.py new file mode 100644 index 000000000..4cd3c4db5 --- /dev/null +++ b/observatory_platform/google/gcp.py @@ -0,0 +1,120 @@ +# Copyright 2022 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# This is an ingredient file. 
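A short sketch of the upsert helper (table names and key are illustrative); the upsert table is expected to share the main table's column layout, and the shared client is passed through to the column checks and merge query:

bq_upsert_records(
    main_table_id="my-project.my_dataset.main",
    upsert_table_id="my-project.my_dataset.upserts",
    primary_key="id",
    client=bq_client,
)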
It is not meant to be run directly. Check the samples/snippets +# folder for complete code samples that are ready to be used. + +# Sources +# https://github.com/GoogleCloudPlatform/python-docs-samples + +from __future__ import annotations + +import logging +import sys +from typing import Any + +from google.api_core.exceptions import NotFound +from google.api_core.extended_operation import ExtendedOperation +from google.cloud import compute_v1 + + +def gcp_create_disk( + *, project_id: str, zone: str, disk_name: str, disk_size_gb: int = 10, disk_type: str = "pd-standard" +) -> compute_v1.Disk: + """Creates a new empty disk in a project in given zone. + + Source: https://github.com/GoogleCloudPlatform/python-docs-samples/blob/main/compute/client_library/ingredients/disks/create_empty_disk.py + + :param project_id: project ID or project number of the Cloud project you want to use. + :param zone: name of the zone in which you want to create the disk. + :param disk_name: name of the disk you want to create. + :param disk_size_gb: size of the new disk in gigabytes. + :param disk_type: the type of disk you want to create. This value uses the following format: + "zones/{zone}/diskTypes/(pd-standard|pd-ssd|pd-balanced|pd-extreme)". + For example: "zones/us-west3-b/diskTypes/pd-ssd" + :return: An unattached Disk instance. + """ + + disk = compute_v1.Disk() + disk.name = disk_name + disk.size_gb = disk_size_gb + disk.zone = zone + disk.type_ = f"zones/{zone}/diskTypes/{disk_type}" + + disk_client = compute_v1.DisksClient() + operation = disk_client.insert(project=project_id, zone=zone, disk_resource=disk) + + wait_for_extended_operation(operation, "disk creation") + + return disk_client.get(project=project_id, zone=zone, disk=disk.name) + + +def gcp_delete_disk(*, project_id: str, zone: str, disk_name: str) -> None: + """Deletes a disk from a project. + + Source: https://github.com/GoogleCloudPlatform/python-docs-samples/blob/main/compute/client_library/ingredients/disks/delete.py + + :param project_id: project ID or project number of the Cloud project you want to use. + :param zone: name of the zone in which is the disk you want to delete. + :param disk_name: name of the disk you want to delete. + :return: None. + """ + + disk_client = compute_v1.DisksClient() + try: + operation = disk_client.delete(project=project_id, zone=zone, disk=disk_name) + wait_for_extended_operation(operation, "disk deletion") + except NotFound: + logging.info(f"gcp_delete_disk: disk with name={disk_name} does not exist") + + +def wait_for_extended_operation( + operation: ExtendedOperation, verbose_name: str = "operation", timeout: int = 300 +) -> Any: + """Waits for the extended (long-running) operation to complete. + + If the operation is successful, it will return its result. + If the operation ends with an error, an exception will be raised. + If there were any warnings during the execution of the operation + they will be printed to sys.stderr. + + Source: https://github.com/GoogleCloudPlatform/python-docs-samples/blob/main/compute/client_library/snippets/operations/wait_for_extended_operation.py + + :param operation: a long-running operation you want to wait on. + :param verbose_name: (optional) a more verbose name of the operation, + used only during error and warning reporting. + :param timeout: how long (in seconds) to wait for operation to finish. If None, wait indefinitely. + :return: Whatever the operation.result() returns. 
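A usage sketch for the new disk helpers (project, zone and disk name are placeholders, and Google Cloud credentials must be available in the environment):

disk = gcp_create_disk(
    project_id="my-project",
    zone="us-central1-a",
    disk_name="my-dag-scratch-disk",
    disk_size_gb=100,
)
# ... attach the disk to a VM and run the job ...
gcp_delete_disk(project_id="my-project", zone="us-central1-a", disk_name="my-dag-scratch-disk")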
+ :raises: concurrent.futures.TimeoutError: in the case of an operation taking longer than `timeout` seconds to complete. + :raises: RuntimeError: if there is no exception set, but there is an `error_code` set for the `operation`. + :raises: operation.exception(): will raise the exception received from operation.exception() + """ + + result = operation.result(timeout=timeout) + + if operation.error_code: + print( + f"Error during {verbose_name}: [Code: {operation.error_code}]: {operation.error_message}", + file=sys.stderr, + flush=True, + ) + print(f"Operation ID: {operation.name}", file=sys.stderr, flush=True) + raise operation.exception() or RuntimeError(operation.error_message) + + if operation.warnings: + print(f"Warnings during {verbose_name}:\n", file=sys.stderr, flush=True) + for warning in operation.warnings: + print(f" - {warning.code}: {warning.message}", file=sys.stderr, flush=True) + + return result diff --git a/observatory-platform/observatory/platform/gcs.py b/observatory_platform/google/gcs.py similarity index 83% rename from observatory-platform/observatory/platform/gcs.py rename to observatory_platform/google/gcs.py index 34bcb9509..7ccd483ae 100644 --- a/observatory-platform/observatory/platform/gcs.py +++ b/observatory_platform/google/gcs.py @@ -14,32 +14,33 @@ # Author: James Diprose, Aniek Roelofs +import contextlib import csv import datetime import json import logging -import multiprocessing import os import pathlib import tempfile -from concurrent.futures import ProcessPoolExecutor, as_completed +import threading +import time +from concurrent.futures import ThreadPoolExecutor, as_completed from enum import Enum -from multiprocessing import BoundedSemaphore, cpu_count -from typing import List, Tuple -import contextlib +from multiprocessing import cpu_count +from typing import List, Tuple, Optional import pendulum -import time from airflow import AirflowException from google.api_core.exceptions import Conflict +from google.auth.credentials import Credentials from google.cloud import storage from google.cloud.exceptions import Conflict from google.cloud.storage import Blob, bucket from googleapiclient import discovery as gcp_api from requests.exceptions import ChunkedEncodingError -from observatory.platform.airflow import get_data_path -from observatory.platform.files import crc32c_base64_hash +from observatory_platform.airflow.workflow import get_data_path +from observatory_platform.files import crc32c_base64_hash # The chunk size to use when uploading / downloading a blob in multiple parts, must be a multiple of 256 KB. DEFAULT_CHUNK_SIZE = 256 * 1024 * 4 @@ -102,13 +103,15 @@ def gcs_blob_name_from_path(local_filepath: str) -> str: return blob_path -def gcs_bucket_exists(bucket_name: str): +def gcs_bucket_exists(bucket_name: str, client: Optional[storage.Client] = None): """Check whether the Google Cloud Storage bucket exists :param bucket_name: Bucket name (without gs:// prefix) + :param client: Storage client. If None default Client is created. 
:return: Whether the bucket exists or not """ - client = storage.Client() + if client is None: + client = storage.Client() bucket = client.bucket(bucket_name) exists = bucket.exists() @@ -117,7 +120,12 @@ def gcs_bucket_exists(bucket_name: str): def gcs_create_bucket( - *, bucket_name: str, location: str = None, project_id: str = None, lifecycle_delete_age: int = None + *, + bucket_name: str, + location: str = None, + project_id: str = None, + lifecycle_delete_age: int = None, + client: storage.Client = None, ) -> bool: """Create a cloud storage bucket @@ -126,6 +134,7 @@ def gcs_create_bucket( :param project_id: The project which the client acts on behalf of. Will be passed when creating a topic. If not passed, falls back to the default inferred from the environment. :param lifecycle_delete_age: Days until files in bucket are deleted + :param client: Storage client. If None default Client is created. :return: Whether creating bucket was successful or not. """ func_name = gcs_create_bucket.__name__ @@ -133,7 +142,9 @@ def gcs_create_bucket( success = False - client = storage.Client(project=project_id) + if client is None: + client = storage.Client() + client.project = project_id bucket = storage.Bucket(client, name=bucket_name) if lifecycle_delete_age: bucket.add_lifecycle_delete_rule(age=lifecycle_delete_age) @@ -148,19 +159,23 @@ def gcs_create_bucket( return success -def gcs_copy_blob(*, blob_name: str, src_bucket: str, dst_bucket: str, new_name: str = None) -> bool: +def gcs_copy_blob( + *, blob_name: str, src_bucket: str, dst_bucket: str, new_name: str = None, client: storage.Client = None +) -> bool: """Copy a blob from one bucket to another :param blob_name: The name of the blob. This corresponds to the unique path of the object in the bucket. :param src_bucket: The bucket to which the blob belongs. :param dst_bucket: The bucket into which the blob should be copied. :param new_name: (Optional) The new name for the copied file. + :param client: Storage client. If None default Client is created. :return: Whether copy was successful. """ func_name = gcs_copy_blob.__name__ logging.info(f"{func_name}: {os.path.join(src_bucket, blob_name)}") - client = storage.Client() + if client is None: + client = storage.Client() # source blob and bucket bucket = storage.Bucket(client, name=src_bucket) @@ -182,8 +197,9 @@ def gcs_download_blob( blob_name: str, file_path: str, retries: int = 3, - connection_sem: BoundedSemaphore = None, + connection_sem: threading.BoundedSemaphore = None, chunk_size: int = DEFAULT_CHUNK_SIZE, + client: storage.Client = None, ) -> bool: """Download a blob to a file. @@ -193,6 +209,7 @@ def gcs_download_blob( :param retries: the number of times to retry downloading the blob. :param connection_sem: a BoundedSemaphore to limit the number of download connections that can run at once. :param chunk_size: the chunk size to use when downloading a blob in multiple parts, must be a multiple of 256 KB. + :param client: Storage client. If None default Client is created. :return: whether the download was successful or not. 
""" @@ -200,9 +217,11 @@ def gcs_download_blob( logging.info(f"{func_name}: {file_path}") # Get blob - client = storage.Client() + if client is None: + client = storage.Client() bucket = client.bucket(bucket_name) blob: Blob = bucket.blob(blob_name) + uri = gcs_blob_uri(bucket_name, blob_name) # State download = True @@ -243,7 +262,9 @@ def gcs_download_blob( success = True break except ChunkedEncodingError as e: - logging.error(f"{func_name}: exception downloading file: try={i}, file_path={file_path}, exception={e}") + logging.error( + f"{func_name}: exception downloading file: try={i}, file_path={file_path}, uri={uri}, exception={e}" + ) # Release connection semaphore if connection_sem is not None: @@ -261,6 +282,7 @@ def gcs_download_blobs( max_connections: int = cpu_count(), retries: int = 3, chunk_size: int = DEFAULT_CHUNK_SIZE, + client: storage.Client = None, ) -> bool: """Download all blobs on a Google Cloud Storage bucket that are within a prefixed path, to a destination on the local file system. @@ -272,23 +294,24 @@ def gcs_download_blobs( :param max_connections: the maximum number of download connections at once. :param retries: the number of times to retry downloading the blob. :param chunk_size: the chunk size to use when downloading a blob in multiple parts, must be a multiple of 256 KB. + :param client: Storage client. If None default Client is created. :return: whether the files were downloaded successfully or not. """ func_name = gcs_download_blobs.__name__ # Get bucket - storage_client = storage.Client() - bucket = storage_client.get_bucket(bucket_name) + if client is None: + client = storage.Client() + bucket = client.get_bucket(bucket_name) # List blobs blobs: List[Blob] = list(bucket.list_blobs(prefix=prefix)) logging.info(f"{func_name}: {blobs}") # Download each blob in parallel - manager = multiprocessing.Manager() - connection_sem = manager.BoundedSemaphore(value=max_connections) - with ProcessPoolExecutor(max_workers=max_processes) as executor: + connection_sem = threading.BoundedSemaphore(value=max_connections) + with ThreadPoolExecutor(max_workers=max_processes) as executor: # Create tasks futures = [] futures_msgs = {} @@ -338,6 +361,7 @@ def gcs_upload_files( max_connections: int = cpu_count(), retries: int = 3, chunk_size: int = DEFAULT_CHUNK_SIZE, + credentials: Optional[Credentials] = None, ) -> bool: """Upload a list of files to Google Cloud storage. @@ -349,6 +373,7 @@ def gcs_upload_files( :param max_connections: the maximum number of upload connections at once. :param retries: the number of times to retry uploading a file if an error occurs. :param chunk_size: the chunk size to use when uploading a blob in multiple parts, must be a multiple of 256 KB. + :param credentials: the credentials to use with the Google Cloud Storage client. :return: whether the files were uploaded successfully or not. 
""" @@ -372,9 +397,8 @@ def gcs_upload_files( assert len(file_paths) == len(blob_names), f"{func_name}: file_paths and blob_names have different lengths" # Upload each file in parallel - manager = multiprocessing.Manager() - connection_sem = manager.BoundedSemaphore(value=max_connections) - with ProcessPoolExecutor(max_workers=max_processes) as executor: + connection_sem = threading.BoundedSemaphore(value=max_connections) + with ThreadPoolExecutor(max_workers=max_processes) as executor: # Create tasks futures = [] futures_msgs = {} @@ -384,11 +408,12 @@ def gcs_upload_files( future = executor.submit( gcs_upload_file, bucket_name=bucket_name, - blob_name=blob_name, + blob_name=str(blob_name), file_path=str(file_path), retries=retries, connection_sem=connection_sem, chunk_size=chunk_size, + credentials=credentials, ) futures.append(future) futures_msgs[future] = msg @@ -414,10 +439,11 @@ def gcs_upload_file( blob_name: str, file_path: str, retries: int = 3, - connection_sem: BoundedSemaphore = None, + connection_sem: Optional[threading.BoundedSemaphore] = None, chunk_size: int = DEFAULT_CHUNK_SIZE, - project_id: str = None, check_blob_hash: bool = True, + client: storage.Client = None, + credentials: Optional[Credentials] = None, ) -> Tuple[bool, bool]: """Upload a file to Google Cloud Storage. @@ -427,8 +453,9 @@ def gcs_upload_file( :param retries: the number of times to retry uploading a file if an error occurs. :param connection_sem: a BoundedSemaphore to limit the number of upload connections that can run at once. :param chunk_size: the chunk size to use when uploading a blob in multiple parts, must be a multiple of 256 KB. - :param project_id: the project in which the bucket is located, defaults to inferred from the environment. :param check_blob_hash: check whether the blob exists and if the crc32c hashes match, in which case skip uploading. + :param client: Storage client. If None default Client is created. + :param credentials: the credentials to use with the Google Cloud Storage client. :return: whether the task was successful or not and whether the file was uploaded. """ func_name = gcs_upload_file.__name__ @@ -439,9 +466,12 @@ def gcs_upload_file( success = False # Get blob - storage_client = storage.Client(project=project_id) - bucket = storage_client.get_bucket(bucket_name) + if client is None: + client = storage.Client(credentials=credentials) + + bucket = client.get_bucket(bucket_name) blob = bucket.blob(blob_name) + uri = gcs_blob_uri(bucket_name, blob_name) # Check if blob exists already and matches the file we are uploading if check_blob_hash and blob.exists(): @@ -458,10 +488,7 @@ def gcs_upload_file( f"{func_name}: files_match={files_match}, expected_hash={expected_hash}, " f"actual_hash={actual_hash}" ) if files_match: - logging.info( - f"{func_name}: skipping upload as files match. bucket_name={bucket_name}, blob_name={blob_name}, " - f"file_path={file_path}" - ) + logging.info(f"{func_name}: skipping upload as files match. 
uri={uri}, file_path={file_path}") upload = False success = True @@ -478,7 +505,9 @@ def gcs_upload_file( success = True break except ChunkedEncodingError as e: - logging.error(f"{func_name}: exception uploading file: try={i}, exception={e}") + logging.error( + f"{func_name}: exception uploading file: try={i}, file_path={file_path}, uri={uri}, exception={e}" + ) # Release connection semaphore if connection_sem is not None: @@ -487,15 +516,22 @@ def gcs_upload_file( return success, upload -def gcs_create_transfer_job(*, job: dict, func_name: str, gc_project_id: str) -> Tuple[bool, int]: +def gcs_create_transfer_job( + *, + job: dict, + func_name: str, + gc_project_id: str, + credentials: Optional[Credentials] = None, +) -> Tuple[bool, int]: """Start a google cloud storage transfer job :param job: contains the details of the transfer job :param func_name: function name used for detailed logging info :param gc_project_id: the Google Cloud project id that holds the Google Cloud Storage bucket. + :param credentials: the credentials to use with the Google Cloud Storage client. :return: whether the transfer was a success or not and the number of objects transferred. """ - client = gcp_api.build("storagetransfer", "v1") + client = gcp_api.build("storagetransfer", "v1", credentials=credentials) create_result = client.transferJobs().create(body=job).execute() transfer_job_name = create_result["name"] @@ -557,6 +593,7 @@ def gcs_create_azure_transfer( description: str, gc_bucket_path: str = None, start_date: pendulum.DateTime = pendulum.now("UTC"), + credentials: Optional[Credentials] = None, ) -> bool: """Transfer files from an Azure blob container to a Google Cloud Storage bucket. @@ -569,6 +606,7 @@ def gcs_create_azure_transfer( :param description: a description for the transfer job. :param gc_bucket_path: the path in the Google Cloud bucket to save the objects. :param start_date: the date that the transfer job will start. + :param credentials: the credentials to use with the Google Cloud Storage client. :return: whether the transfer was a success or not. """ @@ -602,7 +640,9 @@ def gcs_create_azure_transfer( job["transferSpec"]["gcsDataSink"]["path"] = gc_bucket_path - success, objects_count = gcs_create_transfer_job(job=job, func_name=func_name, gc_project_id=gc_project_id) + success, objects_count = gcs_create_transfer_job( + job=job, func_name=func_name, gc_project_id=gc_project_id, credentials=credentials + ) return success @@ -618,6 +658,7 @@ def gcs_create_aws_transfer( last_modified_before: pendulum.DateTime = None, transfer_manifest: str = None, start_date: pendulum.DateTime = pendulum.now("UTC"), + credentials: Optional[Credentials] = None, ) -> Tuple[bool, int]: """Transfer files from an AWS bucket to a Google Cloud Storage bucket. @@ -631,6 +672,7 @@ def gcs_create_aws_transfer( :param last_modified_before: :param transfer_manifest: Path to manifest file in Google Cloud bucket (incl gs://). :param start_date: the date that the transfer job will start. + :param credentials: the credentials to use with the Google Cloud Storage client. :return: whether the transfer was a success or not. 
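The transfer helpers above now accept explicit google.auth credentials instead of always relying on the environment. A minimal sketch of obtaining credentials to pass through the credentials argument; the key file path is purely illustrative:

    from google.auth import default
    from google.oauth2 import service_account

    # Option 1: Application Default Credentials from the environment.
    adc_credentials, project_id = default()

    # Option 2: a dedicated service account key file (placeholder path).
    sa_credentials = service_account.Credentials.from_service_account_file("/path/to/key.json")

    # Either object can then be supplied to the helpers above, for example:
    # gcs_create_transfer_job(job=job, func_name=func_name, gc_project_id=project_id, credentials=sa_credentials)
    # gcs_upload_files(bucket_name=bucket_name, file_paths=file_paths, credentials=sa_credentials)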
""" @@ -672,7 +714,9 @@ def gcs_create_aws_transfer( if transfer_manifest: job["transferSpec"]["transferManifest"] = {"location": transfer_manifest} - success, objects_count = gcs_create_transfer_job(job=job, func_name=func_name, gc_project_id=gc_project_id) + success, objects_count = gcs_create_transfer_job( + job=job, func_name=func_name, gc_project_id=gc_project_id, credentials=credentials + ) return success, objects_count @@ -684,14 +728,18 @@ def _is_utf8_str(s: str) -> bool: return True -def gcs_upload_transfer_manifest(object_paths: List[str], blob_uri: str): +def gcs_upload_transfer_manifest(object_paths: List[str], blob_uri: str, client: storage.Client = None): """Save a GCS transfer manifest CSV file. :param object_paths: the object paths excluding bucket name. :param blob_uri: the full URI on GCS where the manifest should be uploaded to. + :param client: Storage client. If None default Client is created. :return: None. """ + if client is None: + client = storage.Client() + # Write temp file with tempfile.NamedTemporaryFile(mode="w", delete=True) as file: writer = csv.writer(file, quotechar='"', quoting=csv.QUOTE_MINIMAL) @@ -710,18 +758,20 @@ def gcs_upload_transfer_manifest(object_paths: List[str], blob_uri: str): # Upload to cloud storage bucket_name, blob_path = gcs_uri_parts(blob_uri) - success = gcs_upload_file(bucket_name=bucket_name, blob_name=blob_path, file_path=file_path) + success = gcs_upload_file(bucket_name=bucket_name, blob_name=blob_path, file_path=file_path, client=client) assert success, f"gcs_upload_transfer_manifest: error uploading manifest to {blob_uri}" -def gcs_delete_bucket_dir(*, bucket_name: str, prefix: str): +def gcs_delete_bucket_dir(*, bucket_name: str, prefix: str, client: storage.Client = None): """Recursively delete blobs from a GCS bucket with a folder prefix. :param bucket_name: Bucket name. :param prefix: Directory prefix. + :param client: Storage client. If None default Client is created. """ - client = storage.Client() + if client is None: + client = storage.Client() bucket = client.get_bucket(bucket_name) blobs = bucket.list_blobs(prefix=prefix) @@ -729,15 +779,17 @@ def gcs_delete_bucket_dir(*, bucket_name: str, prefix: str): blob.delete() -def gcs_list_buckets_with_prefix(*, prefix: str = "") -> List[bucket.Bucket]: +def gcs_list_buckets_with_prefix(*, prefix: str = "", client: storage.Client = None) -> List[bucket.Bucket]: """List all Google Cloud buckets with prefix. :param prefix: Prefix of the buckets to list + :param client: Storage client. If None default Client is created. :return: A list of bucket objects that are under the project. """ - storage_client = storage.Client() - buckets = list(storage_client.list_buckets()) + if client is None: + client = storage.Client() + buckets = list(client.list_buckets()) bucket_list = [] for bucket in buckets: if bucket.name.startswith(prefix): @@ -746,19 +798,23 @@ def gcs_list_buckets_with_prefix(*, prefix: str = "") -> List[bucket.Bucket]: return bucket_list -def gcs_list_blobs(bucket_name: str, prefix: str = None, match_glob: str = None) -> List[storage.Blob]: +def gcs_list_blobs( + bucket_name: str, prefix: str = None, match_glob: str = None, client: storage.Client = None +) -> List[storage.Blob]: """List blobs in a bucket using a gcs_uri. :param bucket_name: The name of the bucket :param prefix: The prefix to filter by :param match_glob: The glob pattern to filter by + :param client: Storage client. If None default Client is created. 
:return: A list of blob objects in the bucket """ - storage_client = storage.Client() - return list(storage_client.list_blobs(bucket_name, prefix=prefix, match_glob=match_glob)) + if client is None: + client = storage.Client() + return list(client.list_blobs(bucket_name, prefix=prefix, match_glob=match_glob)) -def gcs_delete_old_buckets_with_prefix(*, prefix: str, age_to_delete: int): +def gcs_delete_old_buckets_with_prefix(*, prefix: str, age_to_delete: int, client: storage.Client = None): """Deletes buckets that share the same prefix and if it is older than "age_to_delete" hours. Due to multiple unit tests being run at once, need to include a try and except as @@ -767,10 +823,14 @@ def gcs_delete_old_buckets_with_prefix(*, prefix: str, age_to_delete: int): :param prefix: The identifying prefix of the buckets to delete. :param age_to_delete: Delete if the age of the bucket is older than this amount. + :param client: Storage client. If None default Client is created. """ + if client is None: + client = storage.Client() + # List all buckets in the project. - bucket_list = gcs_list_buckets_with_prefix(prefix=prefix) + bucket_list = gcs_list_buckets_with_prefix(prefix=prefix, client=client) buckets_deleted = [] for bucket in bucket_list: @@ -794,16 +854,21 @@ def gcs_delete_old_buckets_with_prefix(*, prefix: str, age_to_delete: int): f"Deleted the following buckets with prefix '{prefix}' older than {age_to_delete} hours: {buckets_deleted}" ) + @contextlib.contextmanager -def gcs_hmac_key(project_id, service_account_email): +def gcs_hmac_key(project_id, service_account_email, client: storage.Client = None): """Generates a new HMAC key using the given project and service account. Deletes it when context closes. :param project_id: The Google Cloud project ID :param service_account_email: The service account used to generate the HMAC key + :param client: Storage client. If None default Client is created. """ - storage_client = storage.Client(project=project_id) - key, secret = storage_client.create_hmac_key(service_account_email=service_account_email, project_id=project_id) + + if client is None: + client = storage.Client(project=project_id) + + key, secret = client.create_hmac_key(service_account_email=service_account_email, project_id=project_id) try: yield key, secret finally: diff --git a/observatory_platform/google/gke.py b/observatory_platform/google/gke.py new file mode 100644 index 000000000..f7ae3aa1c --- /dev/null +++ b/observatory_platform/google/gke.py @@ -0,0 +1,103 @@ +# Copyright 2023 Curtin University +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging + +import kubernetes +from airflow.providers.cncf.kubernetes.hooks.kubernetes import KubernetesHook +from kubernetes import client + + +def gke_create_volume(*, kubernetes_conn_id: str, volume_name: str, size_gi: int): + """ + + :param kubernetes_conn_id: + :param volume_name: + :param size_gi: + :return: None. 
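For the gcs_hmac_key context manager defined above, a brief usage sketch; the project id and service account email are placeholders, and the generated key is deleted when the block exits:

    from observatory_platform.google.gcs import gcs_hmac_key

    project_id = "my-project-id"
    service_account_email = "transfer-sa@my-project-id.iam.gserviceaccount.com"

    with gcs_hmac_key(project_id, service_account_email) as (key, secret):
        # key is an HMACKeyMetadata instance; key.access_id plus secret act as
        # S3-style credentials, e.g. for Storage Transfer Service jobs.
        print(key.access_id, secret[:4] + "...")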
+ """ + + # Make Kubernetes API Client from Airflow Connection + hook = KubernetesHook(conn_id=kubernetes_conn_id) + api_client = hook.get_conn() + v1 = client.CoreV1Api(api_client=api_client) + + # Create the PersistentVolume + capacity = {"storage": f"{size_gi}Gi"} + pv = client.V1PersistentVolume( + api_version="v1", + kind="PersistentVolume", + metadata=client.V1ObjectMeta( + name=volume_name, + # TODO: supposed to use this user for the persistent volume but doesn't seem to do anything + # annotations={"pv.beta.kubernetes.io/uid": f"{uid}", "pv.beta.kubernetes.io/gid": f"{uid}"} + ), + spec=client.V1PersistentVolumeSpec( + capacity=capacity, + access_modes=["ReadWriteOnce"], + persistent_volume_reclaim_policy="Retain", + storage_class_name="standard", + gce_persistent_disk=client.V1GCEPersistentDiskVolumeSource(pd_name=volume_name), + ), + ) + v1.create_persistent_volume(body=pv) + + # Create PersistentVolumeClaim + namespace = hook.get_namespace() + pvc = client.V1PersistentVolumeClaim( + api_version="v1", + kind="PersistentVolumeClaim", + metadata=client.V1ObjectMeta(name=volume_name), + spec=client.V1PersistentVolumeClaimSpec( + access_modes=["ReadWriteOnce"], + resources=client.V1ResourceRequirements(requests=capacity), + storage_class_name="standard", + ), + ) + v1.create_namespaced_persistent_volume_claim(namespace=namespace, body=pvc) + + +def gke_delete_volume(*, kubernetes_conn_id: str, volume_name: str): + """ + + :param kubernetes_conn_id: + :param namespace: + :param volume_name: + :return: None. + """ + + # Make Kubernetes API Client from Airflow Connection + hook = KubernetesHook(conn_id=kubernetes_conn_id) + api_client = hook.get_conn() + v1 = client.CoreV1Api(api_client=api_client) + + # Delete VolumeClaim and Volume + namespace = hook.get_namespace() + try: + v1.delete_namespaced_persistent_volume_claim(name=volume_name, namespace=namespace) + except kubernetes.client.exceptions.ApiException as e: + if e.status == 404: + logging.info( + f"gke_delete_volume: PersistentVolumeClaim with name={volume_name}, namespace={namespace} does not exist" + ) + else: + raise e + + try: + v1.delete_persistent_volume(name=volume_name) + except kubernetes.client.exceptions.ApiException as e: + if e.status == 404: + logging.info(f"gke_delete_volume: PersistentVolume with name={volume_name} does not exist") + else: + raise e diff --git a/observatory-platform/observatory/platform/docker/__init__.py b/observatory_platform/google/tests/__init__.py similarity index 100% rename from observatory-platform/observatory/platform/docker/__init__.py rename to observatory_platform/google/tests/__init__.py diff --git a/observatory-platform/observatory/platform/sql/__init__.py b/observatory_platform/google/tests/fixtures/__init__.py similarity index 100% rename from observatory-platform/observatory/platform/sql/__init__.py rename to observatory_platform/google/tests/fixtures/__init__.py diff --git a/observatory_platform/google/tests/fixtures/bad_dag.py b/observatory_platform/google/tests/fixtures/bad_dag.py new file mode 100644 index 000000000..e4ce1af0c --- /dev/null +++ b/observatory_platform/google/tests/fixtures/bad_dag.py @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ecff8cd6bee7e5a6dd1a4b40358484eef5c52f0efe1b3c31a00f62a1d7987f81 +size 978 diff --git a/tests/fixtures/utils/people.csv b/observatory_platform/google/tests/fixtures/people.csv similarity index 100% rename from tests/fixtures/utils/people.csv rename to observatory_platform/google/tests/fixtures/people.csv diff 
--git a/tests/fixtures/utils/people.jsonl b/observatory_platform/google/tests/fixtures/people.jsonl similarity index 100% rename from tests/fixtures/utils/people.jsonl rename to observatory_platform/google/tests/fixtures/people.jsonl diff --git a/tests/fixtures/utils/people_extra.jsonl b/observatory_platform/google/tests/fixtures/people_extra.jsonl similarity index 100% rename from tests/fixtures/utils/people_extra.jsonl rename to observatory_platform/google/tests/fixtures/people_extra.jsonl diff --git a/tests/fixtures/utils/people_schema.json b/observatory_platform/google/tests/fixtures/people_schema.json similarity index 100% rename from tests/fixtures/utils/people_schema.json rename to observatory_platform/google/tests/fixtures/people_schema.json diff --git a/tests/fixtures/schemas/db_merge0.json b/observatory_platform/google/tests/fixtures/schema/db_merge0.json similarity index 100% rename from tests/fixtures/schemas/db_merge0.json rename to observatory_platform/google/tests/fixtures/schema/db_merge0.json diff --git a/tests/fixtures/schemas/db_merge1.json b/observatory_platform/google/tests/fixtures/schema/db_merge1.json similarity index 100% rename from tests/fixtures/schemas/db_merge1.json rename to observatory_platform/google/tests/fixtures/schema/db_merge1.json diff --git a/tests/fixtures/schemas/stream_telescope_file1.json b/observatory_platform/google/tests/fixtures/schema/stream_telescope_file1.json similarity index 100% rename from tests/fixtures/schemas/stream_telescope_file1.json rename to observatory_platform/google/tests/fixtures/schema/stream_telescope_file1.json diff --git a/tests/fixtures/schemas/stream_telescope_file2.json b/observatory_platform/google/tests/fixtures/schema/stream_telescope_file2.json similarity index 100% rename from tests/fixtures/schemas/stream_telescope_file2.json rename to observatory_platform/google/tests/fixtures/schema/stream_telescope_file2.json diff --git a/tests/fixtures/schemas/stream_telescope_schema.json b/observatory_platform/google/tests/fixtures/schema/stream_telescope_schema.json similarity index 100% rename from tests/fixtures/schemas/stream_telescope_schema.json rename to observatory_platform/google/tests/fixtures/schema/stream_telescope_schema.json diff --git a/tests/fixtures/schemas/table_a.json b/observatory_platform/google/tests/fixtures/schema/table_a.json similarity index 100% rename from tests/fixtures/schemas/table_a.json rename to observatory_platform/google/tests/fixtures/schema/table_a.json diff --git a/tests/fixtures/schemas/table_b_1900-01-01.json b/observatory_platform/google/tests/fixtures/schema/table_b_1900-01-01.json similarity index 100% rename from tests/fixtures/schemas/table_b_1900-01-01.json rename to observatory_platform/google/tests/fixtures/schema/table_b_1900-01-01.json diff --git a/tests/fixtures/schemas/table_b_2000-01-01.json b/observatory_platform/google/tests/fixtures/schema/table_b_2000-01-01.json similarity index 100% rename from tests/fixtures/schemas/table_b_2000-01-01.json rename to observatory_platform/google/tests/fixtures/schema/table_b_2000-01-01.json diff --git a/tests/fixtures/schemas/test_schema_2021-01-01.json b/observatory_platform/google/tests/fixtures/schema/test_schema_2021-01-01.json similarity index 100% rename from tests/fixtures/schemas/test_schema_2021-01-01.json rename to observatory_platform/google/tests/fixtures/schema/test_schema_2021-01-01.json diff --git a/tests/fixtures/schemas/wos-2020-10-01.json b/observatory_platform/google/tests/fixtures/schema/wos-2020-10-01.json 
similarity index 100% rename from tests/fixtures/schemas/wos-2020-10-01.json rename to observatory_platform/google/tests/fixtures/schema/wos-2020-10-01.json diff --git a/tests/observatory/platform/test_bigquery.py b/observatory_platform/google/tests/test_bigquery.py similarity index 90% rename from tests/observatory/platform/test_bigquery.py rename to observatory_platform/google/tests/test_bigquery.py index 1c68347a9..2a1fcdb80 100644 --- a/tests/observatory/platform/test_bigquery.py +++ b/observatory_platform/google/tests/test_bigquery.py @@ -1,4 +1,4 @@ -# Copyright 2020-2023 Curtin University. All Rights Reserved. +# Copyright 2020-2024 Curtin University. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -14,21 +14,23 @@ # Author: James Diprose, Aniek Roelofs, Alex Massen-Hane -import re import datetime import json import os +import re +import time import unittest from unittest.mock import patch import pendulum -import time from click.testing import CliRunner from google.api_core.exceptions import Conflict from google.cloud import bigquery, storage from google.cloud.bigquery import SourceFormat, Table as BQTable -from observatory.platform.bigquery import ( +from observatory_platform.config import module_file_path +from observatory_platform.files import load_jsonl +from observatory_platform.google.bigquery import ( assert_table_id, bq_table_id, bq_table_id_parts, @@ -59,9 +61,8 @@ bq_export_table, bq_query_bytes_budget_check, ) -from observatory.platform.files import load_jsonl -from observatory.platform.gcs import gcs_delete_old_buckets_with_prefix, gcs_upload_file -from observatory.platform.observatory_environment import random_id, test_fixtures_path, bq_dataset_test_env +from observatory_platform.google.gcs import gcs_delete_old_buckets_with_prefix, gcs_upload_file +from observatory_platform.sandbox.test_utils import random_id, bq_dataset_test_env class TestBigQuery(unittest.TestCase): @@ -75,8 +76,8 @@ def __init__(self, *args, **kwargs): self.data = "hello world" self.expected_crc32c = "yZRlqg==" self.prefix = "bq_tests" - self.patents_table_id = f"bigquery-public-data.labeled_patents.figures" + self.test_data_path = module_file_path("observatory_platform.google.tests.fixtures") # Save time and only have this run once. 
if not __class__.__init__already: @@ -164,7 +165,7 @@ def test_bq_table_exists(self): table_id = bq_table_id(self.gc_project_id, dataset_id, "not_exists") self.assertFalse(bq_table_exists(table_id=table_id)) - @patch("observatory.platform.bigquery.bq_select_table_shard_dates") + @patch("observatory_platform.google.bigquery.bq_select_table_shard_dates") def test_bq_select_latest_table(self, mock_sel_table_suffixes): """Test make_table_name""" dt = pendulum.datetime(2021, 1, 1) @@ -193,8 +194,7 @@ def test_bq_create_dataset(self): client.delete_dataset(dataset_id, not_found_ok=True) def test_bq_create_empty_table(self): - test_data_path = test_fixtures_path("utils") - schema_file_path = os.path.join(test_data_path, "people_schema.json") + schema_file_path = os.path.join(self.test_data_path, "people_schema.json") with bq_dataset_test_env( project_id=self.gc_project_id, location=self.gc_location, prefix=self.prefix @@ -537,8 +537,7 @@ def test_bq_list_tables(self): def test_bq_get_table(self): """Test if a table can be reliably grabbed from the Bogquery API.""" - test_data_path = test_fixtures_path("utils") - json_file_path = os.path.join(test_data_path, "people.jsonl") + json_file_path = os.path.join(self.test_data_path, "people.jsonl") test_data = load_jsonl(json_file_path) with bq_dataset_test_env( @@ -590,46 +589,80 @@ def test_bq_export_table(self): blob.delete() def test_bq_find_schema(self): - """Test that the schema of a table can be found locally.""" - - test_schemas_path = test_fixtures_path("schemas") - - # No date - table_name = "table_a" - expected_schema_path = os.path.join(test_schemas_path, "table_a.json") - result = bq_find_schema(path=test_schemas_path, table_name=table_name) - self.assertEqual(expected_schema_path, result) - - # With date - table_name = "table_b" - date = pendulum.datetime(year=1900, month=1, day=1) - expected_schema_path = os.path.join(test_schemas_path, "table_b_1900-01-01.json") - result = bq_find_schema(path=test_schemas_path, table_name=table_name, release_date=date) - self.assertEqual(expected_schema_path, result) - - # With date - prior to the release date - table_name = "table_b" - date = pendulum.datetime(year=2020, month=1, day=1) - expected_schema_path = os.path.join(test_schemas_path, "table_b_2000-01-01.json") - result = bq_find_schema(path=test_schemas_path, table_name=table_name, release_date=date) - self.assertEqual(expected_schema_path, result) - - # No schema found with matching date - table_name = "table_b" - date = pendulum.datetime(year=1800, month=1, day=1) - result = bq_find_schema(path=test_schemas_path, table_name=table_name, release_date=date) + schemas_path = os.path.join(self.test_data_path, "schema") + + test_release_date = pendulum.datetime(2022, 11, 11) + previous_release_date = pendulum.datetime(1950, 11, 11) + + # Nonexistent tables test case + result = bq_find_schema(path=schemas_path, table_name="this_table_does_not_exist") + self.assertIsNone(result) + + result = bq_find_schema(path=schemas_path, table_name="does_not_exist", prefix="this_table") + self.assertIsNone(result) + + result = bq_find_schema( + path=schemas_path, table_name="this_table_does_not_exist", release_date=test_release_date + ) + self.assertIsNone(result) + + result = bq_find_schema( + path=schemas_path, table_name="does_not_exist", release_date=test_release_date, prefix="this_table" + ) self.assertIsNone(result) - # No schema found with matching table name - table_name = "table_c" - result = bq_find_schema(path=test_schemas_path, table_name=table_name) + 
# Release date on table name that doesn't end in date + result = bq_find_schema(path=schemas_path, table_name="table_a", release_date=test_release_date) self.assertIsNone(result) + result = bq_find_schema(path=schemas_path, table_name="a", release_date=test_release_date, prefix="table_") + self.assertIsNone(result) + + # Release date before table date + snapshot_date = pendulum.datetime(year=1000, month=1, day=1) + result = bq_find_schema(path=schemas_path, table_name="table_b", release_date=snapshot_date) + self.assertIsNone(result) + + # Basic test case - no date + expected_schema = "table_a.json" + result = bq_find_schema(path=schemas_path, table_name="table_a") + self.assertIsNotNone(result) + self.assertTrue(result.endswith(expected_schema)) + + # Prefix with no date + expected_schema = "table_a.json" + result = bq_find_schema(path=schemas_path, table_name="a", prefix="table_") + self.assertIsNotNone(result) + self.assertTrue(result.endswith(expected_schema)) + + # Table with date + expected_schema = "table_b_2000-01-01.json" + result = bq_find_schema(path=schemas_path, table_name="table_b", release_date=test_release_date) + self.assertIsNotNone(result) + self.assertTrue(result.endswith(expected_schema)) + + # Table with date and prefix + expected_schema = "table_b_2000-01-01.json" + result = bq_find_schema(path=schemas_path, table_name="b", release_date=test_release_date, prefix="table_") + self.assertIsNotNone(result) + self.assertTrue(result.endswith(expected_schema)) + + # Table with old date + expected_schema = "table_b_1900-01-01.json" + result = bq_find_schema(path=schemas_path, table_name="table_b", release_date=previous_release_date) + self.assertIsNotNone(result) + self.assertTrue(result.endswith(expected_schema)) + + # Table with old date and prefix + expected_schema = "table_b_1900-01-01.json" + result = bq_find_schema(path=schemas_path, table_name="b", release_date=previous_release_date, prefix="table_") + self.assertIsNotNone(result) + self.assertTrue(result.endswith(expected_schema)) + def test_bq_update_table_description(self): """Test that the description of a table can be updated.""" - test_data_path = test_fixtures_path("utils") - json_file_path = os.path.join(test_data_path, "people.jsonl") + json_file_path = os.path.join(self.test_data_path, "people.jsonl") test_data = load_jsonl(json_file_path) table_id = random_id() @@ -651,17 +684,16 @@ def test_bq_update_table_description(self): self.assertEqual(table.description, updated_table_description) def test_bq_load_table(self): - test_data_path = test_fixtures_path("utils") - schema_file_path = os.path.join(test_data_path, "people_schema.json") + schema_file_path = os.path.join(self.test_data_path, "people_schema.json") # CSV file - csv_file_path = os.path.join(test_data_path, "people.csv") + csv_file_path = os.path.join(self.test_data_path, "people.csv") csv_blob_name = f"people_{random_id()}.csv" # JSON files - json_file_path = os.path.join(test_data_path, "people.jsonl") + json_file_path = os.path.join(self.test_data_path, "people.jsonl") json_blob_name = f"people_{random_id()}.jsonl" - json_extra_file_path = os.path.join(test_data_path, "people_extra.jsonl") + json_extra_file_path = os.path.join(self.test_data_path, "people_extra.jsonl") json_extra_blob_name = f"people_{random_id()}.jsonl" with bq_dataset_test_env( @@ -763,11 +795,10 @@ def test_bq_load_table(self): blob.delete() def test_bq_load_from_memory(self): - test_data_path = test_fixtures_path("utils") - json_file_path = os.path.join(test_data_path, 
"people.jsonl") + json_file_path = os.path.join(self.test_data_path, "people.jsonl") test_data = load_jsonl(json_file_path) - schema_file_path = os.path.join(test_data_path, "people_schema.json") + schema_file_path = os.path.join(self.test_data_path, "people_schema.json") with bq_dataset_test_env( project_id=self.gc_project_id, location=self.gc_location, prefix=self.prefix @@ -819,10 +850,9 @@ def test_bq_load_from_memory(self): self.assertTrue(bq_table_exists(table_id=table_id)) def test_bq_select_columns(self): - test_data_path = test_fixtures_path("utils") - json_file_path = os.path.join(test_data_path, "people.jsonl") + json_file_path = os.path.join(self.test_data_path, "people.jsonl") test_data = load_jsonl(json_file_path) - schema_file_path = os.path.join(test_data_path, "people_schema.json") + schema_file_path = os.path.join(self.test_data_path, "people_schema.json") expected_columns = [ dict(column_name="first_name", data_type="STRING"), diff --git a/tests/observatory/platform/test_gcs.py b/observatory_platform/google/tests/test_gcs.py similarity index 98% rename from tests/observatory/platform/test_gcs.py rename to observatory_platform/google/tests/test_gcs.py index e2f1225da..e3a806a21 100644 --- a/tests/observatory/platform/test_gcs.py +++ b/observatory_platform/google/tests/test_gcs.py @@ -24,12 +24,12 @@ import pendulum from azure.storage.blob import BlobClient, BlobServiceClient from click.testing import CliRunner -from google.cloud import storage from google.auth import default +from google.cloud import storage -from observatory.platform.bigquery import bq_delete_old_datasets_with_prefix -from observatory.platform.files import crc32c_base64_hash, hex_to_base64_str -from observatory.platform.gcs import ( +from observatory_platform.files import crc32c_base64_hash, hex_to_base64_str +from observatory_platform.google.bigquery import bq_delete_old_datasets_with_prefix +from observatory_platform.google.gcs import ( gcs_blob_uri, gcs_create_aws_transfer, gcs_create_azure_transfer, @@ -46,7 +46,7 @@ gcs_blob_name_from_path, gcs_hmac_key, ) -from observatory.platform.observatory_environment import random_id, aws_bucket_test_env +from observatory_platform.sandbox.test_utils import random_id, aws_bucket_test_env def make_account_url(account_name: str) -> str: @@ -111,7 +111,7 @@ def __init__(self, *args, **kwargs): gcs_delete_old_buckets_with_prefix(prefix=self.prefix, age_to_delete=12) __class__.__init__already = True - @patch("observatory.platform.airflow.Variable.get") + @patch("observatory_platform.airflow.workflow.Variable.get") def test_gcs_blob_name_from_path(self, mock_get_variable): """Tests the blob_name from_path function""" @@ -188,7 +188,7 @@ def test_gcs_copy_blob(self): if blob.exists(): blob.delete() - @patch("observatory.platform.airflow.Variable.get") + @patch("observatory_platform.airflow.workflow.Variable.get") def test_upload_download_blobs_from_cloud_storage(self, mock_get_variable): runner = CliRunner() with runner.isolated_filesystem() as t: diff --git a/observatory-platform/observatory/platform/utils/http_download.py b/observatory_platform/http_download.py similarity index 98% rename from observatory-platform/observatory/platform/utils/http_download.py rename to observatory_platform/http_download.py index 5f2733a0e..9e6e05bcb 100644 --- a/observatory-platform/observatory/platform/utils/http_download.py +++ b/observatory_platform/http_download.py @@ -16,7 +16,7 @@ # Asynchronous HTTP GET file downloader. 
Use the download_file, and download_files interfaces to download. # Creates a worker pool to asynchronously (single thread) download file(s). -# Valid hash algorithms: see observatory.platform.utils.file_utils.get_hasher_ for valid options. +# Valid hash algorithms: see observatory_platform.utils.file_utils.get_hasher_ for valid options. # # Usage examples: # custom_headers = {"User-Agent" : "Something" } @@ -46,8 +46,8 @@ import aiohttp import backoff -from observatory.platform.files import validate_file_hash -from observatory.platform.utils.url_utils import get_filename_from_url +from observatory_platform.files import validate_file_hash +from observatory_platform.url_utils import get_filename_from_url @dataclass diff --git a/observatory-platform/observatory/platform/utils/jinja2_utils.py b/observatory_platform/jinja2_utils.py similarity index 99% rename from observatory-platform/observatory/platform/utils/jinja2_utils.py rename to observatory_platform/jinja2_utils.py index 0a8974bf2..80331302b 100644 --- a/observatory-platform/observatory/platform/utils/jinja2_utils.py +++ b/observatory_platform/jinja2_utils.py @@ -15,9 +15,10 @@ # Author: James Diprose -from jinja2 import Environment, FileSystemLoader import os +from jinja2 import Environment, FileSystemLoader + def render_template(template_path: str, **kwargs) -> str: """Render a Jinja2 template. diff --git a/observatory-platform/observatory/platform/utils/proc_utils.py b/observatory_platform/proc_utils.py similarity index 54% rename from observatory-platform/observatory/platform/utils/proc_utils.py rename to observatory_platform/proc_utils.py index 1549d854a..56002a473 100644 --- a/observatory-platform/observatory/platform/utils/proc_utils.py +++ b/observatory_platform/proc_utils.py @@ -28,29 +28,3 @@ def wait_for_process(proc: Popen) -> Tuple[str, str]: output = output.decode("utf-8") error = error.decode("utf-8") return output, error - - -def stream_process(proc: Popen, debug: bool) -> Tuple[str, str]: - """Print output while a process is running, returning the std output and std error streams as strings. - - :param proc: the process object. - :param debug: whether debug info should be displayed. - :return: std output and std error streams as strings. - """ - output_concat = "" - error_concat = "" - while True: - if proc.stdout: - for line in proc.stdout: - output = line.decode("utf-8") - if debug: - print(output, end="") - output_concat += output - if proc.stderr: - for line in proc.stderr: - error = line.decode("utf-8") - print(error, end="") - error_concat += error - if proc.poll() is not None: - break - return output_concat, error_concat diff --git a/observatory-platform/observatory/platform/terraform/__init__.py b/observatory_platform/sandbox/__init__.py similarity index 100% rename from observatory-platform/observatory/platform/terraform/__init__.py rename to observatory_platform/sandbox/__init__.py diff --git a/observatory_platform/sandbox/ftp_server.py b/observatory_platform/sandbox/ftp_server.py new file mode 100644 index 000000000..4ef048aec --- /dev/null +++ b/observatory_platform/sandbox/ftp_server.py @@ -0,0 +1,90 @@ +# Copyright 2021-2024 Curtin University +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Author: Alex Massen-Hane + +import contextlib +import threading +import time + +from pyftpdlib.authorizers import DummyAuthorizer +from pyftpdlib.handlers import FTPHandler +from pyftpdlib.servers import ThreadedFTPServer + + +class FtpServer: + """ + Create a Mock FTPServer instance. + + :param directory: The directory that is hosted on FTP server. + :param host: Hostname of the server. + :param port: The port number. + :param startup_wait_secs: time in seconds to wait before returning from create to give the server enough + time to start before connecting to it. + """ + + def __init__( + self, + directory: str = "/", + host: str = "localhost", + port: int = 21, + startup_wait_secs: int = 1, + root_username: str = "root", + root_password: str = "pass", + ): + self.host = host + self.port = port + self.directory = directory + self.startup_wait_secs = startup_wait_secs + + self.root_username = root_username + self.root_password = root_password + + self.is_shutdown = True + self.server_thread = None + + @contextlib.contextmanager + def create(self): + """Make and destroy a test FTP server. + + :yield: self.directory. + """ + + # Set up the FTP server with root and anonymous users. + authorizer = DummyAuthorizer() + authorizer.add_user( + username=self.root_username, password=self.root_password, homedir=self.directory, perm="elradfmwMT" + ) + authorizer.add_anonymous(self.directory) + handler = FTPHandler + handler.authorizer = authorizer + + try: + # Start server in separate thread. + self.server = ThreadedFTPServer((self.host, self.port), handler) + self.server_thread = threading.Thread(target=self.server.serve_forever) + self.server_thread.daemon = True + self.server_thread.start() + + # Wait a little bit to give the server time to grab the socket + time.sleep(self.startup_wait_secs) + + yield self.directory + + finally: + # Stop server and wait for server thread to join + self.is_shutdown = True + if self.server_thread is not None: + self.server.close_all() + self.server_thread.join() diff --git a/observatory_platform/sandbox/http_server.py b/observatory_platform/sandbox/http_server.py new file mode 100644 index 000000000..c9cc41f00 --- /dev/null +++ b/observatory_platform/sandbox/http_server.py @@ -0,0 +1,87 @@ +# Copyright 2021-2024 Curtin University +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
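A usage sketch for the FtpServer fixture above, assuming ftplib on the client side; the directory and port are placeholders (the default port 21 usually requires elevated privileges, so tests would choose a high port):

    import ftplib
    import tempfile

    from observatory_platform.sandbox.ftp_server import FtpServer

    with tempfile.TemporaryDirectory() as tmp_dir:
        server = FtpServer(directory=tmp_dir, host="localhost", port=2121)
        with server.create() as directory:
            # Connect as the root test user configured by the fixture.
            ftp = ftplib.FTP()
            ftp.connect(host=server.host, port=server.port)
            ftp.login(user=server.root_username, passwd=server.root_password)
            print(directory, ftp.nlst())
            ftp.quit()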
+ +# Author: Tuan Chien + +import contextlib +import os +from http.server import SimpleHTTPRequestHandler, ThreadingHTTPServer +from multiprocessing import Process + +from observatory_platform.sandbox.test_utils import find_free_port + + +class HttpServer: + """Simple HTTP server for testing. Serves files from a directory to http://locahost:port/filename""" + + def __init__(self, directory: str, host: str = "localhost", port: int = None): + """Initialises the server. + + :param directory: Directory to serve. + """ + + self.directory = directory + self.process = None + + self.host = host + if port is None: + port = find_free_port(host=self.host) + self.port = port + self.address = (self.host, self.port) + self.url = f"http://{self.host}:{self.port}/" + + @staticmethod + def serve_(address, directory): + """Entry point for a new process to run HTTP server. + + :param address: Address (host, port) to bind server to. + :param directory: Directory to serve. + """ + + os.chdir(directory) + server = ThreadingHTTPServer(address, SimpleHTTPRequestHandler) + server.serve_forever() + + def start(self): + """Spin the server up in a new process.""" + + # Don't try to start it twice. + if self.process is not None and self.process.is_alive(): + return + + self.process = Process( + target=HttpServer.serve_, + args=( + self.address, + self.directory, + ), + ) + self.process.start() + + def stop(self): + """Shutdown the server.""" + + if self.process is not None and self.process.is_alive(): + self.process.kill() + self.process.join() + + @contextlib.contextmanager + def create(self): + """Spin up a server for the duration of the session.""" + self.start() + + try: + yield self.process + finally: + self.stop() diff --git a/observatory_platform/sandbox/sandbox_environment.py b/observatory_platform/sandbox/sandbox_environment.py new file mode 100644 index 000000000..92158d7eb --- /dev/null +++ b/observatory_platform/sandbox/sandbox_environment.py @@ -0,0 +1,429 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +# Sources: +# * https://github.com/apache/airflow/blob/ffb472cf9e630bd70f51b74b0d0ea4ab98635572/airflow/cli/commands/task_command.py +# * https://github.com/apache/airflow/blob/master/docs/apache-airflow/best-practices.rst + +# Copyright 2021-2024 Curtin University +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +import contextlib +import datetime +import logging +import os +from datetime import datetime, timedelta +from typing import List, Optional, Set, Union + +import croniter +import google +import pendulum +import requests +from airflow import DAG, settings +from airflow.models.connection import Connection +from airflow.models.dagrun import DagRun +from airflow.models.taskinstance import TaskInstance +from airflow.models.variable import Variable +from airflow.utils import db +from airflow.utils.state import State +from airflow.utils.types import DagRunType +from click.testing import CliRunner +from dateutil.relativedelta import relativedelta +from google.cloud import bigquery, storage + +from observatory_platform.airflow.workflow import Workflow, CloudWorkspace, workflows_to_json_string +from observatory_platform.config import AirflowVars +from observatory_platform.google.bigquery import bq_delete_old_datasets_with_prefix +from observatory_platform.google.gcs import gcs_delete_old_buckets_with_prefix +from observatory_platform.sandbox.test_utils import random_id + + +class SandboxEnvironment: + OBSERVATORY_HOME_KEY = "OBSERVATORY_HOME" + + def __init__( + self, + project_id: str = None, + data_location: str = None, + prefix: Optional[str] = "obsenv_tests", + age_to_delete: int = 12, + workflows: List[Workflow] = None, + gcs_bucket_roles: Union[Set[str], str] = None, + ): + """Constructor for an Observatory environment. + + To create an Observatory environment: + env = SandboxEnvironment() + with env.create(): + pass + + :param project_id: the Google Cloud project id. + :param data_location: the Google Cloud data location. + :param prefix: prefix for buckets and datsets created for the testing environment. + :param age_to_delete: age of buckets and datasets to delete that share the same prefix, in hours + """ + + self.project_id = project_id + self.data_location = data_location + self.buckets = {} + self.datasets = [] + self.data_path = None + self.session = None + self.temp_dir = None + self.api_env = None + self.api_session = None + self.dag_run: DagRun = None + self.prefix = prefix + self.age_to_delete = age_to_delete + self.workflows = workflows + + if self.create_gcp_env: + self.download_bucket = self.add_bucket(roles=gcs_bucket_roles) + self.transform_bucket = self.add_bucket(roles=gcs_bucket_roles) + self.storage_client = storage.Client() + self.bigquery_client = bigquery.Client() + else: + self.download_bucket = None + self.transform_bucket = None + self.storage_client = None + self.bigquery_client = None + + @property + def cloud_workspace(self) -> CloudWorkspace: + return CloudWorkspace( + project_id=self.project_id, + download_bucket=self.download_bucket, + transform_bucket=self.transform_bucket, + data_location=self.data_location, + ) + + @property + def create_gcp_env(self) -> bool: + """Whether to create the Google Cloud project environment. + + :return: whether to create Google Cloud project environ,ent + """ + + return self.project_id is not None and self.data_location is not None + + def assert_gcp_dependencies(self): + """Assert that the Google Cloud project dependencies are met. + + :return: None. 
+ """ + + assert self.create_gcp_env, "Please specify the Google Cloud project_id and data_location" + + def add_bucket(self, prefix: Optional[str] = None, roles: Optional[Union[Set[str], str]] = None) -> str: + """Add a Google Cloud Storage Bucket to the Observatory environment. + + The bucket will be created when create() is called and deleted when the Observatory + environment is closed. + + :param prefix: an optional additional prefix for the bucket. + :return: returns the bucket name. + """ + + self.assert_gcp_dependencies() + parts = [] + if self.prefix: + parts.append(self.prefix) + if prefix: + parts.append(prefix) + parts.append(random_id()) + bucket_name = "_".join(parts) + + if len(bucket_name) > 63: + raise Exception(f"Bucket name cannot be longer than 63 characters: {bucket_name}") + else: + self.buckets[bucket_name] = roles + + return bucket_name + + def _create_bucket(self, bucket_id: str, roles: Optional[Union[str, Set[str]]] = None) -> None: + """Create a Google Cloud Storage Bucket. + + :param bucket_id: the bucket identifier. + :param roles: Create bucket with custom roles if required. + :return: None. + """ + + self.assert_gcp_dependencies() + bucket = self.storage_client.create_bucket(bucket_id, location=self.data_location) + logging.info(f"Created bucket with name: {bucket_id}") + + if roles: + roles = set(roles) if isinstance(roles, str) else roles + + # Get policy of bucket and add roles. + policy = bucket.get_iam_policy() + for role in roles: + policy.bindings.append({"role": role, "members": {"allUsers"}}) + bucket.set_iam_policy(policy) + logging.info(f"Added permission {role} to bucket {bucket_id} for allUsers.") + + def _create_dataset(self, dataset_id: str) -> None: + """Create a BigQuery dataset. + + :param dataset_id: the dataset identifier. + :return: None. + """ + + self.assert_gcp_dependencies() + dataset = bigquery.Dataset(f"{self.project_id}.{dataset_id}") + dataset.location = self.data_location + self.bigquery_client.create_dataset(dataset, exists_ok=True) + logging.info(f"Created dataset with name: {dataset_id}") + + def _delete_bucket(self, bucket_id: str) -> None: + """Delete a Google Cloud Storage Bucket. + + :param bucket_id: the bucket identifier. + :return: None. + """ + + self.assert_gcp_dependencies() + try: + bucket = self.storage_client.get_bucket(bucket_id) + bucket.delete(force=True) + except requests.exceptions.ReadTimeout: + pass + except google.api_core.exceptions.NotFound: + logging.warning( + f"Bucket {bucket_id} not found. Did you mean to call _delete_bucket on the same bucket twice?" + ) + + def add_dataset(self, prefix: Optional[str] = None) -> str: + """Add a BigQuery dataset to the Observatory environment. + + The BigQuery dataset will be deleted when the Observatory environment is closed. + + :param prefix: an optional additional prefix for the dataset. + :return: the BigQuery dataset identifier. + """ + + self.assert_gcp_dependencies() + parts = [] + if self.prefix: + parts.append(self.prefix) + if prefix: + parts.append(prefix) + parts.append(random_id()) + dataset_id = "_".join(parts) + self.datasets.append(dataset_id) + return dataset_id + + def _delete_dataset(self, dataset_id: str) -> None: + """Delete a BigQuery dataset. + + :param dataset_id: the BigQuery dataset identifier. + :return: None. 
+ """ + + self.assert_gcp_dependencies() + try: + self.bigquery_client.delete_dataset(dataset_id, not_found_ok=True, delete_contents=True) + except requests.exceptions.ReadTimeout: + pass + + def add_variable(self, var: Variable) -> None: + """Add an Airflow variable to the Observatory environment. + + :param var: the Airflow variable. + :return: None. + """ + + self.session.add(var) + self.session.commit() + + def add_connection(self, conn: Connection): + """Add an Airflow connection to the Observatory environment. + + :param conn: the Airflow connection. + :return: None. + """ + + self.session.add(conn) + self.session.commit() + + def run_task(self, task_id: str) -> TaskInstance: + """Run an Airflow task. + + :param task_id: the Airflow task identifier. + :return: None. + """ + + assert self.dag_run is not None, "with create_dag_run must be called before run_task" + + dag = self.dag_run.dag + run_id = self.dag_run.run_id + task = dag.get_task(task_id=task_id) + ti = TaskInstance(task, run_id=run_id) + ti.refresh_from_db() + ti.run(ignore_ti_state=True) + + return ti + + def get_task_instance(self, task_id: str) -> TaskInstance: + """Get an up-to-date TaskInstance. + + :param task_id: the task id. + :return: up-to-date TaskInstance instance. + """ + + assert self.dag_run is not None, "with create_dag_run must be called before get_task_instance" + + run_id = self.dag_run.run_id + task = self.dag_run.dag.get_task(task_id=task_id) + ti = TaskInstance(task, run_id=run_id) + ti.refresh_from_db() + return ti + + @contextlib.contextmanager + def create_dag_run( + self, + dag: DAG, + execution_date: pendulum.DateTime, + run_type: DagRunType = DagRunType.SCHEDULED, + ): + """Create a DagRun that can be used when running tasks. + During cleanup the DAG run state is updated. + + :param dag: the Airflow DAG instance. + :param execution_date: the execution date of the DAG. + :param run_type: what run_type to use when running the DAG run. + :return: None. + """ + + # Get start date, which is one schedule interval after execution date + if isinstance(dag.normalized_schedule_interval, (timedelta, relativedelta)): + start_date = ( + datetime.fromtimestamp(execution_date.timestamp(), pendulum.tz.UTC) + dag.normalized_schedule_interval + ) + else: + start_date = croniter.croniter(dag.normalized_schedule_interval, execution_date).get_next(pendulum.DateTime) + + try: + self.dag_run = dag.create_dagrun( + state=State.RUNNING, + execution_date=execution_date, + start_date=start_date, + run_type=run_type, + ) + yield self.dag_run + finally: + self.dag_run.update_state() + + @contextlib.contextmanager + def create(self, task_logging: bool = False): + """Make and destroy an Observatory isolated environment, which involves: + + * Creating a temporary directory. + * Setting the OBSERVATORY_HOME environment variable. + * Initialising a temporary Airflow database. + * Creating download and transform Google Cloud Storage buckets. + * Creating default Airflow Variables: AirflowVars.DATA_PATH, + AirflowVars.DOWNLOAD_BUCKET and AirflowVars.TRANSFORM_BUCKET. + * Cleaning up all resources when the environment is closed. + + :param task_logging: display airflow task logging + :yield: Observatory environment temporary directory. 
+ """ + + with CliRunner().isolated_filesystem() as temp_dir: + # Set temporary directory + self.temp_dir = temp_dir + + # Prepare environment + self.new_env = {self.OBSERVATORY_HOME_KEY: os.path.join(self.temp_dir, ".observatory")} + prev_env = dict(os.environ) + + try: + # Update environment + os.environ.update(self.new_env) + + # Create Airflow SQLite database + settings.DAGS_FOLDER = os.path.join(self.temp_dir, "airflow", "dags") + os.makedirs(settings.DAGS_FOLDER, exist_ok=True) + airflow_db_path = os.path.join(self.temp_dir, "airflow.db") + settings.SQL_ALCHEMY_CONN = f"sqlite:///{airflow_db_path}" + logging.info(f"SQL_ALCHEMY_CONN: {settings.SQL_ALCHEMY_CONN}") + settings.configure_orm(disable_connection_pool=True) + self.session = settings.Session + db.initdb() + + # Setup Airflow task logging + original_log_level = logging.getLogger().getEffectiveLevel() + if task_logging: + # Set root logger to INFO level, it seems that custom 'logging.info()' statements inside a task + # come from root + logging.getLogger().setLevel(20) + # Propagate logging so it is displayed + logging.getLogger("airflow.task").propagate = True + + # Create buckets and datasets + if self.create_gcp_env: + for bucket_id, roles in self.buckets.items(): + self._create_bucket(bucket_id, roles) + + for dataset_id in self.datasets: + self._create_dataset(dataset_id) + + # Deletes old test buckets and datasets from the project thats older than 2 hours. + gcs_delete_old_buckets_with_prefix(prefix=self.prefix, age_to_delete=self.age_to_delete) + bq_delete_old_datasets_with_prefix(prefix=self.prefix, age_to_delete=self.age_to_delete) + + # Add default Airflow variables + self.data_path = os.path.join(self.temp_dir, "data") + self.add_variable(Variable(key=AirflowVars.DATA_PATH, val=self.data_path)) + + if self.workflows is not None: + var = workflows_to_json_string(self.workflows) + self.add_variable(Variable(key=AirflowVars.WORKFLOWS, val=var)) + + # Reset dag run + self.dag_run: DagRun = None + + yield self.temp_dir + finally: + # Set logger settings back to original settings + logging.getLogger().setLevel(original_log_level) + logging.getLogger("airflow.task").propagate = False + + # Revert environment + os.environ.clear() + os.environ.update(prev_env) + + if self.create_gcp_env: + # Remove Google Cloud Storage buckets + for bucket_id, roles in self.buckets.items(): + self._delete_bucket(bucket_id) + + # Remove BigQuery datasets + for dataset_id in self.datasets: + self._delete_dataset(dataset_id) diff --git a/observatory_platform/sandbox/sftp_server.py b/observatory_platform/sandbox/sftp_server.py new file mode 100644 index 000000000..44efae0b1 --- /dev/null +++ b/observatory_platform/sandbox/sftp_server.py @@ -0,0 +1,165 @@ +# Copyright (c) 2011-2017 Ruslan Spivak +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. 
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# Sources:
+# * https://github.com/rspivak/sftpserver/blob/master/src/sftpserver/__init__.py
+
+# Copyright 2021-2024 Curtin University
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import contextlib
+import os
+import socket
+import threading
+import time
+
+import paramiko
+from click.testing import CliRunner
+from cryptography.hazmat.backends import default_backend
+from cryptography.hazmat.primitives import serialization
+from cryptography.hazmat.primitives.asymmetric import rsa
+from sftpserver.stub_sftp import StubServer, StubSFTPServer
+
+
+class SftpServer:
+    """A Mock SFTP server for testing purposes"""
+
+    def __init__(
+        self,
+        host: str = "localhost",
+        port: int = 3373,
+        level: str = "INFO",
+        backlog: int = 10,
+        startup_wait_secs: int = 1,
+        socket_timeout: int = 10,
+    ):
+        """Create a Mock SftpServer instance.
+
+        :param host: the host name.
+        :param port: the port.
+        :param level: the log level.
+        :param backlog: the maximum number of queued connections, passed to socket.listen.
+        :param startup_wait_secs: time in seconds to wait before returning from create to give the server enough
+        time to start before connecting to it.
+        :param socket_timeout: the socket timeout in seconds, so that the accept loop can periodically check for
+        shutdown.
+        """
+
+        self.host = host
+        self.port = port
+        self.level = level
+        self.backlog = backlog
+        self.startup_wait_secs = startup_wait_secs
+        self.is_shutdown = True
+        self.tmp_dir = None
+        self.root_dir = None
+        self.private_key_path = None
+        self.server_thread = None
+        self.socket_timeout = socket_timeout
+
+    def _generate_key(self):
+        """Generate a private key.
+
+        :return: the filepath to the private key.
+ """ + + key = rsa.generate_private_key(public_exponent=65537, key_size=2048, backend=default_backend()) + + private_key_path = os.path.join(self.tmp_dir, "test_rsa.key") + with open(private_key_path, "wb") as f: + f.write( + key.private_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PrivateFormat.TraditionalOpenSSL, + encryption_algorithm=serialization.NoEncryption(), + ) + ) + + return private_key_path + + def _start_server(self): + paramiko_level = getattr(paramiko.common, self.level) + paramiko.common.logging.basicConfig(level=paramiko_level) + + server_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + server_socket.settimeout(self.socket_timeout) + server_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, True) + server_socket.bind((self.host, self.port)) + server_socket.listen(self.backlog) + + while not self.is_shutdown: + try: + conn, addr = server_socket.accept() + transport = paramiko.Transport(conn) + transport.add_server_key(paramiko.RSAKey.from_private_key_file(self.private_key_path)) + transport.set_subsystem_handler("sftp", paramiko.SFTPServer, StubSFTPServer) + + server = StubServer() + transport.start_server(server=server) + + channel = transport.accept() + while transport.is_active() and not self.is_shutdown: + time.sleep(1) + + except socket.timeout: + # Timeout must be set for socket otherwise it will wait for a connection forever and block + # the thread from exiting. At: conn, addr = server_socket.accept() + pass + + @contextlib.contextmanager + def create(self): + """Make and destroy a test SFTP server. + + :yield: None. + """ + + with CliRunner().isolated_filesystem() as tmp_dir: + # Override the root directory of the SFTP server, which is set as the cwd at import time + self.tmp_dir = tmp_dir + self.root_dir = os.path.join(tmp_dir, "home") + os.makedirs(self.root_dir, exist_ok=True) + StubSFTPServer.ROOT = self.root_dir + + # Generate private key + self.private_key_path = self._generate_key() + + try: + self.is_shutdown = False + self.server_thread = threading.Thread(target=self._start_server) + self.server_thread.start() + + # Wait a little bit to give the server time to grab the socket + time.sleep(self.startup_wait_secs) + + yield self.root_dir + finally: + # Stop server and wait for server thread to join + self.is_shutdown = True + if self.server_thread is not None: + self.server_thread.join() diff --git a/observatory_platform/sandbox/test_utils.py b/observatory_platform/sandbox/test_utils.py new file mode 100644 index 000000000..12b26aca4 --- /dev/null +++ b/observatory_platform/sandbox/test_utils.py @@ -0,0 +1,613 @@ +# Copyright 2021-2024 Curtin University +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
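The SftpServer above is driven through its create() context manager, which yields the server's root directory. A minimal usage sketch, not part of the patch, assuming pysftp is installed and using the blank username and password accepted by the stub server; the host, port and printed listing are illustrative:

    import pysftp

    from observatory_platform.sandbox.sftp_server import SftpServer
    from observatory_platform.sandbox.test_utils import find_free_port

    port = find_free_port()
    server = SftpServer(host="localhost", port=port)
    with server.create() as root_dir:
        # Disable host key checking: the sandbox server generates a throwaway RSA key
        cnopts = pysftp.CnOpts()
        cnopts.hostkeys = None
        sftp = pysftp.Connection("localhost", port=port, username="", password="", cnopts=cnopts)
        print(sftp.listdir("."))  # root_dir starts out empty; files written there appear in the listing
        sftp.close()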
+ + +import contextlib +import datetime +import json +import logging +import os +import shutil +import socket +import socketserver +import unittest +import uuid +from dataclasses import dataclass +from datetime import datetime +from typing import Dict, List, Set + +import boto3 +import httpretty +import pendulum +from airflow import DAG +from airflow.exceptions import AirflowException +from airflow.models import DagBag +from airflow.operators.empty import EmptyOperator +from click.testing import CliRunner +from deepdiff import DeepDiff +from google.cloud import bigquery, storage +from google.cloud.bigquery import SourceFormat +from google.cloud.exceptions import NotFound +from pendulum import DateTime + +from observatory_platform.airflow.workflow import CloudWorkspace +from observatory_platform.files import crc32c_base64_hash, get_file_hash, gzip_file_crc, save_jsonl_gz +from observatory_platform.google.bigquery import bq_sharded_table_id, bq_load_table, bq_table_id, bq_create_dataset +from observatory_platform.google.gcs import gcs_blob_uri, gcs_upload_files + + +class SandboxTestCase(unittest.TestCase): + """Common test functions for testing Observatory Platform DAGs""" + + def __init__(self, *args, **kwargs): + """Constructor which sets up variables used by tests. + + :param args: arguments. + :param kwargs: keyword arguments. + """ + + super(SandboxTestCase, self).__init__(*args, **kwargs) + self.storage_client = storage.Client() + self.bigquery_client = bigquery.Client() + + # Turn logging to warning because vcr prints too much at info level + logging.basicConfig() + vcr_log = logging.getLogger("vcr") + vcr_log.setLevel(logging.WARNING) + + @property + def fake_cloud_workspace(self): + return CloudWorkspace( + project_id="project-id", + download_bucket="download_bucket", + transform_bucket="transform_bucket", + data_location="us", + ) + + def assert_dag_structure(self, expected: Dict, dag: DAG): + """Assert the DAG structure. + + :param expected: a dictionary of DAG task ids as keys and values which should be a list of downstream task ids. + :param dag: the DAG. + :return: None. + """ + + expected_keys = expected.keys() + actual_keys = dag.task_dict.keys() + self.assertEqual(expected_keys, actual_keys) + + for task_id, downstream_list in expected.items(): + self.assertTrue(dag.has_task(task_id)) + task = dag.get_task(task_id) + self.assertEqual(set(downstream_list), task.downstream_task_ids) + + def assert_dag_load(self, dag_id: str, dag_file: str): + """Assert that the given DAG loads from a DagBag. + + :param dag_id: the DAG id. + :param dag_file: the path to the DAG file. + :return: None. + """ + + with CliRunner().isolated_filesystem() as dag_folder: + if not os.path.exists(dag_file): + raise Exception(f"{dag_file} does not exist.") + + shutil.copy(dag_file, os.path.join(dag_folder, os.path.basename(dag_file))) + + dag_bag = DagBag(dag_folder=dag_folder) + + if dag_bag.import_errors != {}: + logging.error(f"DagBag errors: {dag_bag.import_errors}") + self.assertEqual({}, dag_bag.import_errors, dag_bag.import_errors) + + dag = dag_bag.get_dag(dag_id=dag_id) + + if dag is None: + logging.error( + f"DAG not found in the database. Make sure the DAG ID is correct, and the dag file contains the words 'airflow' and 'DAG'." + ) + self.assertIsNotNone(dag) + + self.assertGreaterEqual(len(dag.tasks), 1) + + def assert_dag_load_from_config(self, dag_id: str, dag_file: str): + """Assert that the given DAG loads from a config file. + + :param dag_id: the DAG id. 
+ :param dag_file: the path to the dag loader + :return: None. + """ + + self.assert_dag_load(dag_id, dag_file) + + def assert_blob_exists(self, bucket_id: str, blob_name: str): + """Assert whether a blob exists or not. + + :param bucket_id: the Google Cloud storage bucket id. + :param blob_name: the blob name (full path except for bucket) + :return: None. + """ + + # Get blob + bucket = self.storage_client.get_bucket(bucket_id) + blob = bucket.blob(blob_name) + self.assertTrue(blob.exists()) + + def assert_blob_integrity(self, bucket_id: str, blob_name: str, local_file_path: str): + """Assert whether the blob uploaded and that it has the expected hash. + + :param blob_name: the Google Cloud Blob name, i.e. the entire path to the blob on the Cloud Storage bucket. + :param bucket_id: the Google Cloud Storage bucket id. + :param local_file_path: the path to the local file. + :return: whether the blob uploaded and that it has the expected hash. + """ + + # Get blob + bucket = self.storage_client.get_bucket(bucket_id) + blob = bucket.blob(blob_name) + result = blob.exists() + + # Check that blob hash matches if it exists + if result: + # Get blob hash + blob.reload() + expected_hash = blob.crc32c + + # Check actual file + actual_hash = crc32c_base64_hash(local_file_path) + result = expected_hash == actual_hash + + self.assertTrue(result) + + def assert_table_integrity(self, table_id: str, expected_rows: int = None): + """Assert whether a BigQuery table exists and has the expected number of rows. + + :param table_id: the BigQuery table id. + :param expected_rows: the expected number of rows. + :return: whether the table exists and has the expected number of rows. + """ + + table = None + actual_rows = None + try: + table = self.bigquery_client.get_table(table_id) + actual_rows = table.num_rows + except NotFound: + pass + + self.assertIsNotNone(table) + if expected_rows is not None: + self.assertEqual(expected_rows, actual_rows) + + def assert_table_content(self, table_id: str, expected_content: List[dict], primary_key: str): + """Assert whether a BigQuery table has any content and if expected content is given whether it matches the + actual content. The order of the rows is not checked, only whether all rows in the expected content match + the rows in the actual content. + The expected content should be a list of dictionaries, where each dictionary represents one row of the table, + the keys are fieldnames and values are values. + + :param table_id: the BigQuery table id. + :param expected_content: the expected content. + :param primary_key: the primary key to use to compare. + :return: whether the table has content and the expected content is correct + """ + + logging.info( + f"assert_table_content: {table_id}, len(expected_content)={len(expected_content), }, primary_key={primary_key}" + ) + rows = None + actual_content = None + try: + rows = list(self.bigquery_client.list_rows(table_id)) + actual_content = [dict(row) for row in rows] + except NotFound: + pass + self.assertIsNotNone(rows) + self.assertIsNotNone(actual_content) + results = compare_lists_of_dicts(expected_content, actual_content, primary_key) + assert results, "Rows in actual content do not match expected content" + + def assert_table_bytes(self, table_id: str, expected_bytes: int): + """Assert whether the given bytes from a BigQuery table matches the expected bytes. + + :param table_id: the BigQuery table id. + :param expected_bytes: the expected number of bytes. 
+ :return: whether the table exists and the expected bytes match + """ + + table = None + try: + table = self.bigquery_client.get_table(table_id) + except NotFound: + pass + + self.assertIsNotNone(table) + self.assertEqual(expected_bytes, table.num_bytes) + + def assert_file_integrity(self, file_path: str, expected_hash: str, algorithm: str): + """Assert that a file exists and it has the correct hash. + + :param file_path: the path to the file. + :param expected_hash: the expected hash. + :param algorithm: the algorithm to use when hashing, either md5 or gzip crc + :return: None. + """ + + self.assertTrue(os.path.isfile(file_path)) + + if algorithm == "gzip_crc": + actual_hash = gzip_file_crc(file_path) + else: + actual_hash = get_file_hash(file_path=file_path, algorithm=algorithm) + + self.assertEqual(expected_hash, actual_hash) + + def assert_cleanup(self, workflow_folder: str): + """Assert that the download, extracted and transformed folders were cleaned up. + + :param workflow_folder: the path to the DAGs download folder. + :return: None. + """ + + self.assertFalse(os.path.exists(workflow_folder)) + + def setup_mock_file_download( + self, uri: str, file_path: str, headers: Dict = None, method: str = httpretty.GET + ) -> None: + """Use httpretty to mock a file download. + + This function must be called from within an httpretty.enabled() block, for instance: + + with httpretty.enabled(): + self.setup_mock_file_download('https://example.com/file.zip', path_to_file) + + :param uri: the URI of the file download to mock. + :param file_path: the path to the file on the local system. + :param headers: the response headers. + :return: None. + """ + + if headers is None: + headers = {} + + with open(file_path, "rb") as f: + body = f.read() + + httpretty.register_uri(method, uri, adding_headers=headers, body=body) + + +def random_id(): + """Generate a random id for bucket name. + + When code is pushed to a branch and a pull request is open, Github Actions runs the unit tests workflow + twice, one for the push and one for the pull request. However, the uuid4 function, which calls os.urandom(16), + generates the same sequence of values for each workflow run. We have also used the hostname of the machine + in the construction of the random id to ensure sure that the ids are different on both workflow runs. + + :return: a random string id. + """ + return str(uuid.uuid5(uuid.uuid4(), socket.gethostname())).replace("-", "") + + +def find_free_port(host: str = "localhost") -> int: + """Find a free port. + + :param host: the host. + :return: the free port number + """ + + with socketserver.TCPServer((host, 0), None) as tcp_server: + return tcp_server.server_address[1] + + +def save_empty_file(path: str, file_name: str) -> str: + """Save empty file and return path. + + :param path: the file directory. + :param file_name: the file name. + :return: the full file path. 
+ """ + + file_path = os.path.join(path, file_name) + open(file_path, "a").close() + + return file_path + + +@contextlib.contextmanager +def bq_dataset_test_env(*, project_id: str, location: str, prefix: str): + client = bigquery.Client() + dataset_id = prefix + "_" + random_id() + try: + bq_create_dataset(project_id=project_id, dataset_id=dataset_id, location=location) + yield dataset_id + finally: + client.delete_dataset(dataset_id, delete_contents=True, not_found_ok=True) + + +@contextlib.contextmanager +def aws_bucket_test_env(*, prefix: str, region_name: str, expiration_days=1) -> str: + # Create an S3 client + s3 = boto3.Session().client("s3", region_name=region_name) + bucket_name = f"obs-test-{prefix}-{random_id()}" + try: + s3.create_bucket(Bucket=bucket_name) # CreateBucketConfiguration={"LocationConstraint": region_name} + # Set up the lifecycle configuration + lifecycle_configuration = { + "Rules": [ + {"ID": "ExpireObjects", "Status": "Enabled", "Filter": {}, "Expiration": {"Days": expiration_days}} + ] + } + # Apply the lifecycle configuration to the bucket + s3.put_bucket_lifecycle_configuration(Bucket=bucket_name, LifecycleConfiguration=lifecycle_configuration) + yield bucket_name + except Exception as e: + raise e + finally: + # Get a reference to the bucket + s3_resource = boto3.Session().resource("s3") + bucket = s3_resource.Bucket(bucket_name) + + # Delete all objects and versions in the bucket + bucket.objects.all().delete() + bucket.object_versions.all().delete() + + # Delete the bucket + bucket.delete() + + print(f"Bucket {bucket_name} deleted") + + +def load_and_parse_json( + file_path: str, + date_fields: Set[str] = None, + timestamp_fields: Set[str] = None, + date_formats: Set[str] = None, + timestamp_formats: str = None, +): + """Load a JSON file for testing purposes. It parses string dates and datetimes into date and datetime instances. + + :param file_path: the path to the JSON file. + :param date_fields: The fields to parse as a date. + :param timestamp_fields: The fields to parse as a timestamp. + :param date_formats: The date formats to use. If none, will use [%Y-%m-%d, %Y%m%d]. + :param timestamp_formats: The timestamp formats to use. If none, will use [%Y-%m-%d %H:%M:%S.%f %Z]. 
+ """ + + if date_fields is None: + date_fields = set() + + if timestamp_fields is None: + timestamp_fields = set() + + if date_formats is None: + date_formats = {"%Y-%m-%d", "%Y%m%d"} + + if timestamp_formats is None: + timestamp_formats = {"%Y-%m-%d %H:%M:%S.%f %Z"} + + def parse_datetime(obj): + for key, value in obj.items(): + # Try to parse into a date or datetime + if key in date_fields: + if isinstance(value, str): + format_found = False + for format in date_formats: + try: + obj[key] = datetime.strptime(value, format).date() + format_found = True + break + except (ValueError, TypeError): + pass + if not format_found: + try: + dt = pendulum.parse(value) + dt = datetime( + dt.year, + dt.month, + dt.day, + dt.hour, + dt.minute, + dt.second, + dt.microsecond, + tzinfo=dt.tzinfo, + ).date() + obj[key] = dt + except (ValueError, TypeError): + pass + + if key in timestamp_fields: + if isinstance(value, str): + format_found = False + for format in timestamp_formats: + try: + obj[key] = datetime.strptime(value, format) + format_found = True + break + except (ValueError, TypeError): + pass + if not format_found: + try: + dt = pendulum.parse(value) + dt = datetime( + dt.year, + dt.month, + dt.day, + dt.hour, + dt.minute, + dt.second, + dt.microsecond, + tzinfo=dt.tzinfo, + ) + obj[key] = dt + except (ValueError, TypeError): + pass + + return obj + + with open(file_path, mode="r") as f: + rows = json.load(f, object_hook=parse_datetime) + return rows + + +def compare_lists_of_dicts(expected: List[Dict], actual: List[Dict], primary_key: str) -> bool: + """Compare two lists of dictionaries, using a primary_key as the basis for the top level comparisons. + + :param expected: the expected data. + :param actual: the actual data. + :param primary_key: the primary key. + :return: whether the expected and actual match. + """ + + expected_dict = {item[primary_key]: item for item in expected} + actual_dict = {item[primary_key]: item for item in actual} + + if set(expected_dict.keys()) != set(actual_dict.keys()): + logging.error("Primary keys don't match:") + logging.error(f"Only in expected: {set(expected_dict.keys()) - set(actual_dict.keys())}") + logging.error(f"Only in actual: {set(actual_dict.keys()) - set(expected_dict.keys())}") + return False + + all_matched = True + for key in expected_dict: + diff = DeepDiff(expected_dict[key], actual_dict[key], ignore_order=True) + logging.info(f"primary_key: {key}") + for diff_type, changes in diff.items(): + all_matched = False + log_diff(diff_type, changes) + + return all_matched + + +def log_diff(diff_type, changes): + """Log the DeepDiff changes. + + :param diff_type: the diff type. + :param changes: the changes. + :return: None. + """ + + if diff_type == "values_changed": + for key_path, change in changes.items(): + logging.error( + f"(expected) != (actual) {key_path}: {change['old_value']} (expected) != (actual) {change['new_value']}" + ) + elif diff_type == "dictionary_item_added": + for change in changes: + logging.error(f"dictionary_item_added: {change}") + elif diff_type == "dictionary_item_removed": + for change in changes: + logging.error(f"dictionary_item_removed: {change}") + elif diff_type == "type_changes": + for key_path, change in changes.items(): + logging.error( + f"(expected) != (actual) {key_path}: {change['old_type']} (expected) != (actual) {change['new_type']}" + ) + + +def make_dummy_dag(dag_id: str, execution_date: pendulum.DateTime) -> DAG: + """A Dummy DAG for testing purposes. + + :param dag_id: the DAG id. 
+ :param execution_date: the DAGs execution date. + :return: the DAG. + """ + + with DAG( + dag_id=dag_id, + schedule="@weekly", + default_args={"owner": "airflow", "start_date": execution_date}, + catchup=False, + ) as dag: + task1 = EmptyOperator(task_id="dummy_task") + + return dag + + +@dataclass +class Table: + """A table to be loaded into Elasticsearch. + + :param table_name: the table name. + :param is_sharded: whether the table is sharded or not. + :param dataset_id: the dataset id. + :param records: the records to load. + :param schema_file_path: the schema file path. + """ + + table_name: str + is_sharded: bool + dataset_id: str + records: List[Dict] + schema_file_path: str + + +def bq_load_tables( + *, + project_id: str, + tables: List[Table], + bucket_name: str, + snapshot_date: DateTime, +): + """Load the fake Observatory Dataset in BigQuery. + + :param project_id: GCP project id. + :param tables: the list of tables and records to load. + :param bucket_name: the Google Cloud Storage bucket name. + :param snapshot_date: the release date for the observatory dataset. + :return: None. + """ + + with CliRunner().isolated_filesystem() as t: + files_list = [] + blob_names = [] + + # Save to JSONL + for table in tables: + blob_name = f"{table.dataset_id}-{table.table_name}.jsonl.gz" + file_path = os.path.join(t, blob_name) + save_jsonl_gz(file_path, table.records) + files_list.append(file_path) + blob_names.append(blob_name) + + # Upload to Google Cloud Storage + success = gcs_upload_files(bucket_name=bucket_name, file_paths=files_list, blob_names=blob_names) + assert success, "Data did not load into BigQuery" + + # Save to BigQuery tables + for blob_name, table in zip(blob_names, tables): + if table.schema_file_path is None: + logging.error( + f"No schema found with search parameters: analysis_schema_path={table.schema_file_path}, " + f"table_name={table.table_name}, snapshot_date={snapshot_date}" + ) + exit(os.EX_CONFIG) + + if table.is_sharded: + table_id = bq_sharded_table_id(project_id, table.dataset_id, table.table_name, snapshot_date) + else: + table_id = bq_table_id(project_id, table.dataset_id, table.table_name) + + # Load BigQuery table + uri = gcs_blob_uri(bucket_name, blob_name) + logging.info(f"URI: {uri}") + success = bq_load_table( + uri=uri, + table_id=table_id, + schema_file_path=table.schema_file_path, + source_format=SourceFormat.NEWLINE_DELIMITED_JSON, + ) + if not success: + raise AirflowException("bq_load task: data failed to load data into BigQuery") diff --git a/observatory-platform/observatory/platform/utils/__init__.py b/observatory_platform/sandbox/tests/__init__.py similarity index 100% rename from observatory-platform/observatory/platform/utils/__init__.py rename to observatory_platform/sandbox/tests/__init__.py diff --git a/observatory-platform/observatory/platform/workflows/__init__.py b/observatory_platform/sandbox/tests/fixtures/__init__.py similarity index 100% rename from observatory-platform/observatory/platform/workflows/__init__.py rename to observatory_platform/sandbox/tests/fixtures/__init__.py diff --git a/observatory_platform/sandbox/tests/fixtures/bad_dag.py b/observatory_platform/sandbox/tests/fixtures/bad_dag.py new file mode 100644 index 000000000..e4ce1af0c --- /dev/null +++ b/observatory_platform/sandbox/tests/fixtures/bad_dag.py @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ecff8cd6bee7e5a6dd1a4b40358484eef5c52f0efe1b3c31a00f62a1d7987f81 +size 978 diff --git a/tests/fixtures/utils/http_testfile.txt 
b/observatory_platform/sandbox/tests/fixtures/http_testfile.txt similarity index 100% rename from tests/fixtures/utils/http_testfile.txt rename to observatory_platform/sandbox/tests/fixtures/http_testfile.txt diff --git a/observatory_platform/sandbox/tests/fixtures/people.csv b/observatory_platform/sandbox/tests/fixtures/people.csv new file mode 100644 index 000000000..d53605417 --- /dev/null +++ b/observatory_platform/sandbox/tests/fixtures/people.csv @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:c42995743365e9fac984892f73786afd3e1502e0c906c6f72a54b03735a293ae +size 132 diff --git a/tests/fixtures/utils/people.csv.gz b/observatory_platform/sandbox/tests/fixtures/people.csv.gz similarity index 100% rename from tests/fixtures/utils/people.csv.gz rename to observatory_platform/sandbox/tests/fixtures/people.csv.gz diff --git a/observatory_platform/sandbox/tests/fixtures/people.jsonl b/observatory_platform/sandbox/tests/fixtures/people.jsonl new file mode 100644 index 000000000..ac6a8b757 --- /dev/null +++ b/observatory_platform/sandbox/tests/fixtures/people.jsonl @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:f9f32d69651911b75d8e1727f60fd0a523a0fb4fd5771934e90b2734dad2d4c6 +size 333 diff --git a/observatory_platform/sandbox/tests/fixtures/people_schema.json b/observatory_platform/sandbox/tests/fixtures/people_schema.json new file mode 100644 index 000000000..be39b540c --- /dev/null +++ b/observatory_platform/sandbox/tests/fixtures/people_schema.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:a8cd777b358e09862d0048779c50db25e2d9b09e0f755193478e0ba72efcd866 +size 232 diff --git a/observatory_platform/sandbox/tests/test_ftp_server.py b/observatory_platform/sandbox/tests/test_ftp_server.py new file mode 100644 index 000000000..5f8339a32 --- /dev/null +++ b/observatory_platform/sandbox/tests/test_ftp_server.py @@ -0,0 +1,93 @@ +# Copyright 2021-2024 Curtin University +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
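The Table dataclass and bq_load_tables helper defined in test_utils.py above load small fixture datasets into BigQuery via a Cloud Storage bucket. A rough sketch of how a test might call them, assuming Google Cloud credentials are configured and that the project id, bucket, dataset and schema path shown below (all illustrative) already exist:

    import pendulum

    from observatory_platform.sandbox.test_utils import Table, bq_load_tables

    tables = [
        Table(
            table_name="people",
            is_sharded=False,
            dataset_id="my_test_dataset",  # assumed to already exist
            records=[{"first_name": "Ada", "last_name": "Lovelace"}],
            schema_file_path="/path/to/people_schema.json",  # illustrative path
        )
    ]
    bq_load_tables(
        project_id="my-project-id",  # illustrative project
        tables=tables,
        bucket_name="my-test-bucket",  # illustrative bucket
        snapshot_date=pendulum.datetime(2024, 1, 1),
    )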
+
+from __future__ import annotations
+
+import contextlib
+import os
+import unittest
+from ftplib import FTP
+
+from click.testing import CliRunner
+
+from observatory_platform.sandbox.ftp_server import FtpServer
+from observatory_platform.sandbox.test_utils import find_free_port
+
+
+class TestFtpServer(unittest.TestCase):
+    def setUp(self) -> None:
+        self.host = "localhost"
+        self.port = find_free_port()
+
+    @contextlib.contextmanager
+    def test_server(self):
+        """Test that the FTP server can be connected to"""
+
+        with CliRunner().isolated_filesystem() as tmp_dir:
+            server = FtpServer(directory=tmp_dir, host=self.host, port=self.port)
+            with server.create() as root_dir:
+                # Connect to FTP server anonymously
+                ftp_conn = FTP()
+                ftp_conn.connect(host=self.host, port=self.port)
+                ftp_conn.login()
+
+                # Check that there are no files
+                files = ftp_conn.nlst()
+                self.assertFalse(len(files))
+
+                # Add a file and check that it exists
+                expected_file_name = "textfile.txt"
+                file_path = os.path.join(root_dir, expected_file_name)
+                with open(file_path, mode="w") as f:
+                    f.write("hello world")
+                files = ftp_conn.nlst()
+                self.assertEqual(1, len(files))
+                self.assertEqual(expected_file_name, files[0])
+
+    @contextlib.contextmanager
+    def test_user_permissions(self):
+        """Test the level of permissions of the root and anonymous users."""
+
+        with CliRunner().isolated_filesystem() as tmp_dir:
+            server = FtpServer(
+                directory=tmp_dir, host=self.host, port=self.port, root_username="root", root_password="pass"
+            )
+            with server.create() as root_dir:
+                # Add a file onto locally hosted server.
+                expected_file_name = "textfile.txt"
+                file_path = os.path.join(root_dir, expected_file_name)
+                with open(file_path, mode="w") as f:
+                    f.write("hello world")
+
+                # Connect to FTP server anonymously.
+                ftp_conn = FTP()
+                ftp_conn.connect(host=self.host, port=self.port)
+                ftp_conn.login()
+
+                # Make sure that the anonymous user has read-only permissions
+                ftp_response = ftp_conn.sendcmd(f"MLST {expected_file_name}")
+                self.assertTrue(";perm=r;size=11;type=file;" in ftp_response)
+
+                ftp_conn.close()
+
+                # Connect to FTP server as root user.
+                ftp_conn = FTP()
+                ftp_conn.connect(host=self.host, port=self.port)
+                ftp_conn.login(user="root", passwd="pass")
+
+                # Make sure that the root user has all available read/write/modification permissions.
+                ftp_response = ftp_conn.sendcmd(f"MLST {expected_file_name}")
+                self.assertTrue(";perm=radfwMT;size=11;type=file;" in ftp_response)
+
+                ftp_conn.close()
diff --git a/observatory_platform/sandbox/tests/test_http_server.py b/observatory_platform/sandbox/tests/test_http_server.py
new file mode 100644
index 000000000..108f58c1c
--- /dev/null
+++ b/observatory_platform/sandbox/tests/test_http_server.py
@@ -0,0 +1,96 @@
+# Copyright 2019-2024 Curtin University
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
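For tests that only need to fake a single HTTP download, SandboxTestCase.setup_mock_file_download (defined earlier in test_utils.py) can stand in for a real server. A minimal sketch, not part of the patch; the URL and fixture path are illustrative, and because the SandboxTestCase constructor builds Google Cloud clients, application default credentials may still be required:

    import httpretty

    from observatory_platform.sandbox.test_utils import SandboxTestCase
    from observatory_platform.url_utils import retry_session


    class TestMyDownload(SandboxTestCase):
        def test_download(self):
            url = "https://example.com/data.txt"  # illustrative URL
            fixture = "/path/to/fixture.txt"  # illustrative fixture file
            with httpretty.enabled():
                # Register the URL so requests for it return the fixture file's contents
                self.setup_mock_file_download(url, fixture)
                response = retry_session().get(url)
                self.assertEqual(200, response.status_code)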
+ + +from __future__ import annotations + +import os +from unittest.mock import patch + +import timeout_decorator +from click.testing import CliRunner + +from observatory_platform.config import module_file_path +from observatory_platform.http_download import ( + DownloadInfo, + download_file, + download_files, +) +from observatory_platform.sandbox.http_server import HttpServer +from observatory_platform.sandbox.test_utils import SandboxTestCase + + +class TestHttpserver(SandboxTestCase): + def __init__(self, *args, **kwargs): + super(TestHttpserver, self).__init__(*args, **kwargs) + self.fixtures_path = module_file_path("observatory_platform.sandbox.tests.fixtures") + + def test_serve(self): + """Make sure the server can be constructed.""" + with patch("observatory_platform.sandbox.http_server.ThreadingHTTPServer.serve_forever") as m_serve: + server = HttpServer(directory=".") + server.serve_(("localhost", 10000), ".") + self.assertEqual(m_serve.call_count, 1) + + @timeout_decorator.timeout(1) + def test_stop_before_start(self): + """Make sure there's no deadlock if we try to stop before a start.""" + + server = HttpServer(directory=".") + server.stop() + + @timeout_decorator.timeout(1) + def test_start_twice(self): + """Make sure there's no funny business if we try to stop before a start.""" + + server = HttpServer(directory=".") + server.start() + server.start() + server.stop() + + def test_server(self): + """Test the webserver can serve a directory""" + + server = HttpServer(directory=self.fixtures_path) + server.start() + + test_file = "http_testfile.txt" + expected_hash = "d8e8fca2dc0f896fd7cb4cb0031ba249" + algorithm = "md5" + + url = f"{server.url}{test_file}" + + with CliRunner().isolated_filesystem() as tmpdir: + dst_file = os.path.join(tmpdir, "testfile.txt") + + download_files(download_list=[DownloadInfo(url=url, filename=dst_file)]) + + self.assert_file_integrity(dst_file, expected_hash, algorithm) + + server.stop() + + def test_context_manager(self): + server = HttpServer(directory=self.fixtures_path) + + with server.create(): + test_file = "http_testfile.txt" + expected_hash = "d8e8fca2dc0f896fd7cb4cb0031ba249" + algorithm = "md5" + + url = f"{server.url}{test_file}" + + with CliRunner().isolated_filesystem() as tmpdir: + dst_file = os.path.join(tmpdir, "testfile.txt") + download_file(url=url, filename=dst_file) + self.assert_file_integrity(dst_file, expected_hash, algorithm) diff --git a/observatory_platform/sandbox/tests/test_sandbox_environment.py b/observatory_platform/sandbox/tests/test_sandbox_environment.py new file mode 100644 index 000000000..88ba0ce20 --- /dev/null +++ b/observatory_platform/sandbox/tests/test_sandbox_environment.py @@ -0,0 +1,333 @@ +# Copyright 2021-2024 Curtin University +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
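The tests below exercise SandboxEnvironment end to end. The basic pattern is: create the sandbox, open a DAG run, then run tasks one at a time in dependency order. A condensed sketch under the same assumptions as the tests (an illustrative project id and data location, and the create_dag factory defined in the test module that follows):

    import pendulum
    from airflow.models.connection import Connection
    from airflow.models.variable import Variable
    from airflow.utils.state import TaskInstanceState

    from observatory_platform.sandbox.sandbox_environment import SandboxEnvironment
    from observatory_platform.sandbox.tests.test_sandbox_environment import create_dag

    env = SandboxEnvironment("my-project-id", "us")  # illustrative project id and data location
    my_dag = create_dag()
    with env.create(task_logging=True):
        with env.create_dag_run(my_dag, pendulum.datetime(2020, 11, 1, tz="UTC")):
            # check_dependencies expects this variable and connection to exist
            env.add_variable(Variable(key="my-variable", val="hello"))
            env.add_connection(Connection(conn_id="my-connection", uri="mysql://login:password@host:8080/schema"))

            # Run tasks in dependency order; task ids come from the DAG under test
            ti = env.run_task("check_dependencies")
            assert ti.state == TaskInstanceState.SUCCESS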
+ +# Author: James Diprose, Aniek Roelofs + +from __future__ import annotations + +import logging +import os +import unittest +from datetime import timedelta + +import croniter +import pendulum +from airflow.decorators import dag, task +from airflow.models.connection import Connection +from airflow.models.dag import ScheduleArg +from airflow.models.variable import Variable +from airflow.utils.state import TaskInstanceState +from google.cloud.exceptions import NotFound + +from observatory_platform.airflow.tasks import check_dependencies +from observatory_platform.config import AirflowVars +from observatory_platform.google.bigquery import bq_create_dataset +from observatory_platform.sandbox.sandbox_environment import SandboxEnvironment +from observatory_platform.sandbox.test_utils import random_id + +DAG_ID = "dag-test" +MY_VAR_ID = "my-variable" +MY_CONN_ID = "my-connection" + + +def create_dag( + dag_id: str = DAG_ID, + start_date: pendulum.DateTime = pendulum.datetime(2020, 9, 1, tz="UTC"), + schedule: ScheduleArg = "@weekly", +): + # Define the DAG (workflow) + @dag( + dag_id=dag_id, + schedule=schedule, + start_date=start_date, + ) + def my_dag(): + @task() + def task2(): + logging.info("task 2!") + + @task() + def task3(): + logging.info("task 3!") + + t1 = check_dependencies( + airflow_vars=[ + AirflowVars.DATA_PATH, + MY_VAR_ID, + ], + airflow_conns=[MY_CONN_ID], + ) + t2 = task2() + t3 = task3() + t1 >> t2 >> t3 + + return my_dag() + + +class TestSandboxEnvironment(unittest.TestCase): + """Test the SandboxEnvironment""" + + def __init__(self, *args, **kwargs): + super(TestSandboxEnvironment, self).__init__(*args, **kwargs) + self.project_id = os.getenv("TEST_GCP_PROJECT_ID") + self.data_location = os.getenv("TEST_GCP_DATA_LOCATION") + + def test_add_bucket(self): + """Test the add_bucket method""" + + env = SandboxEnvironment(self.project_id, self.data_location) + + # The download and transform buckets are added in the constructor + buckets = list(env.buckets.keys()) + self.assertEqual(2, len(buckets)) + self.assertEqual(env.download_bucket, buckets[0]) + self.assertEqual(env.transform_bucket, buckets[1]) + + # Test that calling add bucket adds a new bucket to the buckets list + name = env.add_bucket() + buckets = list(env.buckets.keys()) + self.assertEqual(name, buckets[-1]) + + # No Google Cloud variables raises error + with self.assertRaises(AssertionError): + SandboxEnvironment().add_bucket() + + def test_create_delete_bucket(self): + """Test _create_bucket and _delete_bucket""" + + env = SandboxEnvironment(self.project_id, self.data_location) + + bucket_id = "obsenv_tests_" + random_id() + + # Create bucket + env._create_bucket(bucket_id) + bucket = env.storage_client.bucket(bucket_id) + self.assertTrue(bucket.exists()) + + # Delete bucket + env._delete_bucket(bucket_id) + self.assertFalse(bucket.exists()) + + # Test double delete is handled gracefully + env._delete_bucket(bucket_id) + + # Test create a bucket with a set of roles + roles = {"roles/storage.objectViewer", "roles/storage.legacyBucketWriter"} + env._create_bucket(bucket_id, roles=roles) + bucket = env.storage_client.bucket(bucket_id) + bucket_policy = bucket.get_iam_policy() + for role in roles: + self.assertTrue({"role": role, "members": {"allUsers"}} in bucket_policy) + + # No Google Cloud variables raises error + bucket_id = "obsenv_tests_" + random_id() + with self.assertRaises(AssertionError): + SandboxEnvironment()._create_bucket(bucket_id) + with self.assertRaises(AssertionError): + 
SandboxEnvironment()._delete_bucket(bucket_id) + + def test_add_delete_dataset(self): + """Test add_dataset and _delete_dataset""" + + # Create dataset + env = SandboxEnvironment(self.project_id, self.data_location) + + dataset_id = env.add_dataset() + bq_create_dataset(project_id=self.project_id, dataset_id=dataset_id, location=self.data_location) + + # Check that dataset exists: should not raise NotFound exception + dataset_id = f"{self.project_id}.{dataset_id}" + env.bigquery_client.get_dataset(dataset_id) + + # Delete dataset + env._delete_dataset(dataset_id) + + # Check that dataset doesn't exist + with self.assertRaises(NotFound): + env.bigquery_client.get_dataset(dataset_id) + + # No Google Cloud variables raises error + with self.assertRaises(AssertionError): + SandboxEnvironment().add_dataset() + with self.assertRaises(AssertionError): + SandboxEnvironment()._delete_dataset(random_id()) + + def test_create(self): + """Tests create, add_variable, add_connection and run_task""" + + # Setup Telescope + execution_date = pendulum.datetime(year=2020, month=11, day=1) + my_dag = create_dag() + + # Test that previous tasks have to be finished to run next task + env = SandboxEnvironment(self.project_id, self.data_location) + + with env.create(task_logging=True): + with env.create_dag_run(my_dag, execution_date): + # Add_variable + env.add_variable(Variable(key=MY_VAR_ID, val="hello")) + + # Add connection + conn = Connection( + conn_id=MY_CONN_ID, uri="mysql://login:password@host:8080/schema?param1=val1¶m2=val2" + ) + env.add_connection(conn) + + # Test run task when dependencies are not met + ti = env.run_task("task2") + self.assertIsNone(ti.state) + + # Try again when dependencies are met + ti = env.run_task("check_dependencies") + self.assertEqual(TaskInstanceState.SUCCESS, ti.state) + + ti = env.run_task("task2") + self.assertEqual(TaskInstanceState.SUCCESS, ti.state) + + ti = env.run_task("task3") + self.assertEqual(TaskInstanceState.SUCCESS, ti.state) + + def test_task_logging(self): + """Test task logging""" + + env = SandboxEnvironment(self.project_id, self.data_location) + + # Setup Telescope + execution_date = pendulum.datetime(year=2020, month=11, day=1) + my_dag = create_dag() + + # Test environment without logging enabled + with env.create(): + with env.create_dag_run(my_dag, execution_date): + # Test add_variable + env.add_variable(Variable(key=MY_VAR_ID, val="hello")) + + # Test add_connection + conn = Connection( + conn_id=MY_CONN_ID, uri="mysql://login:password@host:8080/schema?param1=val1¶m2=val2" + ) + env.add_connection(conn) + + # Test run task + ti = env.run_task("check_dependencies") + self.assertFalse(ti.log.propagate) + self.assertEqual(TaskInstanceState.SUCCESS, ti.state) + + # Test environment with logging enabled + env = SandboxEnvironment(self.project_id, self.data_location) + with env.create(task_logging=True): + with env.create_dag_run(my_dag, execution_date): + # Test add_variable + env.add_variable(Variable(key=MY_VAR_ID, val="hello")) + + # Test add_connection + conn = Connection( + conn_id=MY_CONN_ID, uri="mysql://login:password@host:8080/schema?param1=val1¶m2=val2" + ) + env.add_connection(conn) + + # Test run task + ti = env.run_task("check_dependencies") + self.assertTrue(ti.log.propagate) + self.assertEqual(TaskInstanceState.SUCCESS, ti.state) + + def test_create_dagrun(self): + """Tests create_dag_run""" + + env = SandboxEnvironment(self.project_id, self.data_location) + + # Setup Telescope + first_execution_date = pendulum.datetime(year=2020, 
month=11, day=1, tz="UTC") + second_execution_date = pendulum.datetime(year=2020, month=12, day=1, tz="UTC") + my_dag = create_dag() + + # Get start dates outside of + first_start_date = croniter.croniter(my_dag.normalized_schedule_interval, first_execution_date).get_next( + pendulum.DateTime + ) + second_start_date = croniter.croniter(my_dag.normalized_schedule_interval, second_execution_date).get_next( + pendulum.DateTime + ) + + # Use DAG run with freezing time + with env.create(): + # Test add_variable + env.add_variable(Variable(key=MY_VAR_ID, val="hello")) + + # Test add_connection + conn = Connection(conn_id=MY_CONN_ID, uri="mysql://login:password@host:8080/schema?param1=val1¶m2=val2") + env.add_connection(conn) + + self.assertIsNone(env.dag_run) + # First DAG Run + with env.create_dag_run(my_dag, first_execution_date): + # Test DAG Run is set and has frozen start date + self.assertIsNotNone(env.dag_run) + self.assertEqual(first_start_date.date(), env.dag_run.start_date.date()) + + ti1 = env.run_task("check_dependencies") + self.assertEqual(TaskInstanceState.SUCCESS, ti1.state) + self.assertIsNone(ti1.previous_ti) + + with env.create_dag_run(my_dag, second_execution_date): + # Test DAG Run is set and has frozen start date + self.assertIsNotNone(env.dag_run) + self.assertEqual(second_start_date, env.dag_run.start_date) + + ti2 = env.run_task("check_dependencies") + self.assertEqual(TaskInstanceState.SUCCESS, ti2.state) + # Test previous ti is set + self.assertEqual(ti1.job_id, ti2.previous_ti.job_id) + + # Use DAG run without freezing time + env = SandboxEnvironment(self.project_id, self.data_location) + with env.create(): + # Test add_variable + env.add_variable(Variable(key=MY_VAR_ID, val="hello")) + + # Test add_connection + conn = Connection(conn_id=MY_CONN_ID, uri="mysql://login:password@host:8080/schema?param1=val1¶m2=val2") + env.add_connection(conn) + + # First DAG Run + with env.create_dag_run(my_dag, first_execution_date): + # Test DAG Run is set and has today as start date + self.assertIsNotNone(env.dag_run) + self.assertEqual(first_start_date, env.dag_run.start_date) + + ti1 = env.run_task("check_dependencies") + self.assertEqual(TaskInstanceState.SUCCESS, ti1.state) + self.assertIsNone(ti1.previous_ti) + + # Second DAG Run + with env.create_dag_run(my_dag, second_execution_date): + # Test DAG Run is set and has today as start date + self.assertIsNotNone(env.dag_run) + self.assertEqual(second_start_date, env.dag_run.start_date) + + ti2 = env.run_task("check_dependencies") + self.assertEqual(TaskInstanceState.SUCCESS, ti2.state) + # Test previous ti is set + self.assertEqual(ti1.job_id, ti2.previous_ti.job_id) + + def test_create_dag_run_timedelta(self): + env = SandboxEnvironment(self.project_id, self.data_location) + + my_dag = create_dag(schedule=timedelta(days=1)) + execution_date = pendulum.datetime(2021, 1, 1) + expected_dag_date = pendulum.datetime(2021, 1, 2) + with env.create(): + with env.create_dag_run(my_dag, execution_date): + self.assertIsNotNone(env.dag_run) + self.assertEqual(expected_dag_date, env.dag_run.start_date) diff --git a/observatory_platform/sandbox/tests/test_sftp_server.py b/observatory_platform/sandbox/tests/test_sftp_server.py new file mode 100644 index 000000000..9a0223e58 --- /dev/null +++ b/observatory_platform/sandbox/tests/test_sftp_server.py @@ -0,0 +1,53 @@ +# Copyright 2019-2024 Curtin University +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from __future__ import annotations + +import os +import unittest + +import pysftp + +from observatory_platform.sandbox.sftp_server import SftpServer +from observatory_platform.sandbox.test_utils import find_free_port + + +class TestSftpServer(unittest.TestCase): + def setUp(self) -> None: + self.host = "localhost" + self.port = find_free_port() + + def test_server(self): + """Test that the SFTP server can be connected to""" + + server = SftpServer(host=self.host, port=self.port) + with server.create() as root_dir: + # Connect to SFTP server and disable host key checking + cnopts = pysftp.CnOpts() + cnopts.hostkeys = None + sftp = pysftp.Connection(self.host, port=self.port, username="", password="", cnopts=cnopts) + + # Check that there are no files + files = sftp.listdir(".") + self.assertFalse(len(files)) + + # Add a file and check that it exists + expected_file_name = "onix.xml" + file_path = os.path.join(root_dir, expected_file_name) + with open(file_path, mode="w") as f: + f.write("hello world") + files = sftp.listdir(".") + self.assertEqual(1, len(files)) + self.assertEqual(expected_file_name, files[0]) diff --git a/observatory_platform/sandbox/tests/test_test_utils.py b/observatory_platform/sandbox/tests/test_test_utils.py new file mode 100644 index 000000000..df293aa9e --- /dev/null +++ b/observatory_platform/sandbox/tests/test_test_utils.py @@ -0,0 +1,346 @@ +# Copyright 2019-2024 Curtin University +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
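load_and_parse_json from test_utils.py is typically used to read an expected-output JSON fixture whose date and timestamp columns arrive as strings, so that the parsed rows can be compared against a BigQuery table. A small sketch, not part of the patch; the fixture path and field names are illustrative:

    from observatory_platform.sandbox.test_utils import load_and_parse_json

    expected = load_and_parse_json(
        "/path/to/expected_output.json",  # illustrative fixture path
        date_fields={"snapshot_date"},  # parsed with %Y-%m-%d or %Y%m%d by default
        timestamp_fields={"created", "modified"},  # parsed with %Y-%m-%d %H:%M:%S.%f %Z by default
    )
    # Inside a SandboxTestCase, the result can then be checked with, e.g.:
    # self.assert_table_content(table_id, expected, primary_key="dag_id")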
+ + +from __future__ import annotations + +import json +import os +import tempfile +import unittest +from datetime import datetime +from pathlib import Path + +import httpretty +import pendulum +from click.testing import CliRunner +from google.cloud.bigquery import SourceFormat + +from observatory_platform.config import module_file_path +from observatory_platform.google.bigquery import bq_create_dataset, bq_load_table, bq_table_id +from observatory_platform.google.gcs import gcs_upload_file, gcs_blob_uri +from observatory_platform.sandbox.sandbox_environment import SandboxEnvironment +from observatory_platform.sandbox.test_utils import SandboxTestCase, load_and_parse_json, random_id +from observatory_platform.sandbox.tests.test_sandbox_environment import create_dag +from observatory_platform.url_utils import retry_session + +DAG_ID = "my-dag" +DAG_FILE_CONTENT = """ +# The keywords airflow and DAG are required to load the DAGs from this file, see bullet 2 in the Apache Airflow FAQ: +# https://airflow.apache.org/docs/stable/faq.html + +from observatory_platform.sandbox.tests.test_sandbox_environment import create_dag + +dag_id = "my-dag" +globals()[dag_id] = create_dag(dag_id=dag_id) +""" + + +class TestSandboxTestCase(unittest.TestCase): + """Test the SandboxTestCase class""" + + def __init__(self, *args, **kwargs): + super(TestSandboxTestCase, self).__init__(*args, **kwargs) + self.project_id = os.getenv("TEST_GCP_PROJECT_ID") + self.data_location = os.getenv("TEST_GCP_DATA_LOCATION") + self.test_fixtures_path = module_file_path("observatory_platform.sandbox.tests.fixtures") + + def test_assert_dag_structure(self): + """Test assert_dag_structure""" + + test_case = SandboxTestCase() + dag = create_dag() + + # No assertion error + expected = {"check_dependencies": ["task2"], "task2": ["task3"], "task3": []} + test_case.assert_dag_structure(expected, dag) + + # Raise assertion error + with self.assertRaises(AssertionError): + expected = {"check_dependencies": ["list_releases"], "list_releases": []} + test_case.assert_dag_structure(expected, dag) + + def test_assert_dag_load(self): + """Test assert_dag_load""" + + test_case = SandboxTestCase() + env = SandboxEnvironment() + with env.create() as temp_dir: + # Write DAG into temp_dir + file_path = os.path.join(temp_dir, f"telescope_test.py") + with open(file_path, mode="w") as f: + f.write(DAG_FILE_CONTENT) + + # DAG loaded successfully: should be no errors + test_case.assert_dag_load(DAG_ID, file_path) + + # Remove DAG from temp_dir + os.unlink(file_path) + + # DAG not loaded + with self.assertRaises(Exception): + test_case.assert_dag_load(DAG_ID, file_path) + + # DAG not found + with self.assertRaises(Exception): + test_case.assert_dag_load("dag not found", file_path) + + # Import errors + with self.assertRaises(AssertionError): + test_case.assert_dag_load("no dag found", os.path.join(self.test_fixtures_path, "bad_dag.py")) + + # No dag + with self.assertRaises(AssertionError): + empty_filename = os.path.join(temp_dir, "empty_dag.py") + Path(empty_filename).touch() + test_case.assert_dag_load("invalid_dag_id", empty_filename) + + def test_assert_blob_integrity(self): + """Test assert_blob_integrity""" + + env = SandboxEnvironment(self.project_id, self.data_location) + + with env.create(): + # Upload file to download bucket and check gzip-crc + blob_name = "people.csv" + file_path = os.path.join(self.test_fixtures_path, blob_name) + result, upload = gcs_upload_file(bucket_name=env.download_bucket, blob_name=blob_name, file_path=file_path) + 
self.assertTrue(result) + + # Check that blob exists + test_case = SandboxTestCase() + test_case.assert_blob_integrity(env.download_bucket, blob_name, file_path) + + # Check that blob doesn't exist + with self.assertRaises(AssertionError): + test_case.assert_blob_integrity(env.transform_bucket, blob_name, file_path) + + def test_assert_table_integrity(self): + """Test assert_table_integrity""" + + env = SandboxEnvironment(self.project_id, self.data_location) + + with env.create(): + # Upload file to download bucket and check gzip-crc + blob_name = "people.jsonl" + file_path = os.path.join(self.test_fixtures_path, blob_name) + result, upload = gcs_upload_file(bucket_name=env.download_bucket, blob_name=blob_name, file_path=file_path) + self.assertTrue(result) + + # Create dataset + dataset_id = env.add_dataset() + bq_create_dataset(project_id=self.project_id, dataset_id=dataset_id, location=self.data_location) + + # Test loading JSON newline table + table_name = random_id() + schema_path = os.path.join(self.test_fixtures_path, "people_schema.json") + uri = gcs_blob_uri(env.download_bucket, blob_name) + table_id = bq_table_id(self.project_id, dataset_id, table_name) + result = bq_load_table( + uri=uri, + table_id=table_id, + schema_file_path=schema_path, + source_format=SourceFormat.NEWLINE_DELIMITED_JSON, + ) + self.assertTrue(result) + + # Check BigQuery table exists and has expected rows + test_case = SandboxTestCase() + table_id = f"{self.project_id}.{dataset_id}.{table_name}" + expected_rows = 5 + test_case.assert_table_integrity(table_id, expected_rows) + + # Check that BigQuery table doesn't exist + with self.assertRaises(AssertionError): + table_id = f"{dataset_id}.{random_id()}" + test_case.assert_table_integrity(table_id, expected_rows) + + # Check that BigQuery table has incorrect rows + with self.assertRaises(AssertionError): + table_id = f"{dataset_id}.{table_name}" + expected_rows = 20 + test_case.assert_table_integrity(table_id, expected_rows) + + def test_assert_table_content(self): + """Test assert table content + + :return: None. 
+ """ + + env = SandboxEnvironment(self.project_id, self.data_location) + + with env.create(): + # Upload file to download bucket and check gzip-crc + blob_name = "people.jsonl" + file_path = os.path.join(self.test_fixtures_path, blob_name) + result, upload = gcs_upload_file(bucket_name=env.download_bucket, blob_name=blob_name, file_path=file_path) + self.assertTrue(result) + + # Create dataset + dataset_id = env.add_dataset() + bq_create_dataset(project_id=self.project_id, dataset_id=dataset_id, location=self.data_location) + + # Test loading JSON newline table + table_name = random_id() + schema_path = os.path.join(self.test_fixtures_path, "people_schema.json") + uri = gcs_blob_uri(env.download_bucket, blob_name) + table_id = bq_table_id(self.project_id, dataset_id, table_name) + result = bq_load_table( + uri=uri, + table_id=table_id, + schema_file_path=schema_path, + source_format=SourceFormat.NEWLINE_DELIMITED_JSON, + ) + self.assertTrue(result) + + # Check BigQuery table exists and has expected rows + test_case = SandboxTestCase() + table_id = f"{self.project_id}.{dataset_id}.{table_name}" + expected_content = [ + {"first_name": "Gisella", "last_name": "Derya", "dob": datetime(1997, 7, 1).date()}, + {"first_name": "Adelaida", "last_name": "Melis", "dob": datetime(1980, 9, 3).date()}, + {"first_name": "Melanie", "last_name": "Magomedkhan", "dob": datetime(1990, 3, 1).date()}, + {"first_name": "Octavia", "last_name": "Tomasa", "dob": datetime(1970, 1, 8).date()}, + {"first_name": "Ansgar", "last_name": "Zorion", "dob": datetime(2001, 2, 1).date()}, + ] + test_case.assert_table_content(table_id, expected_content, "first_name") + + # Check that BigQuery table doesn't exist + with self.assertRaises(AssertionError): + table_id = f"{self.project_id}.{dataset_id}.{random_id()}" + test_case.assert_table_content(table_id, expected_content, "first_name") + + # Check that BigQuery table has extra rows + with self.assertRaises(AssertionError): + table_id = f"{dataset_id}.{table_name}" + expected_content = [ + {"first_name": "Gisella", "last_name": "Derya", "dob": datetime(1997, 7, 1).date()}, + {"first_name": "Adelaida", "last_name": "Melis", "dob": datetime(1980, 9, 3).date()}, + {"first_name": "Octavia", "last_name": "Tomasa", "dob": datetime(1970, 1, 8).date()}, + {"first_name": "Ansgar", "last_name": "Zorion", "dob": datetime(2001, 2, 1).date()}, + ] + test_case.assert_table_content(table_id, expected_content, "first_name") + + # Check that BigQuery table has missing rows + with self.assertRaises(AssertionError): + table_id = f"{self.project_id}.{dataset_id}.{table_name}" + expected_content = [ + {"first_name": "Gisella", "last_name": "Derya", "dob": datetime(1997, 7, 1).date()}, + {"first_name": "Adelaida", "last_name": "Melis", "dob": datetime(1980, 9, 3).date()}, + {"first_name": "Melanie", "last_name": "Magomedkhan", "dob": datetime(1990, 3, 1).date()}, + {"first_name": "Octavia", "last_name": "Tomasa", "dob": datetime(1970, 1, 8).date()}, + {"first_name": "Ansgar", "last_name": "Zorion", "dob": datetime(2001, 2, 1).date()}, + {"first_name": "Extra", "last_name": "Row", "dob": datetime(2001, 2, 1).date()}, + ] + test_case.assert_table_content(table_id, expected_content, "first_name") + + def test_assert_file_integrity(self): + """Test assert_file_integrity""" + + test_case = SandboxTestCase() + + # Test md5 + file_path = os.path.join(self.test_fixtures_path, "people.csv") + expected_hash = "ad0d7ad3dc3434337cebd5fb543420e7" + algorithm = "md5" + test_case.assert_file_integrity(file_path, 
expected_hash, algorithm) + + # Test gzip-crc + file_path = os.path.join(self.test_fixtures_path, "people.csv.gz") + expected_hash = "3beea5ac" + algorithm = "gzip_crc" + test_case.assert_file_integrity(file_path, expected_hash, algorithm) + + def test_assert_cleanup(self): + """Test assert_cleanup""" + + with CliRunner().isolated_filesystem() as temp_dir: + workflow = os.path.join(temp_dir, "workflow") + + # Make download, extract and transform folders + os.makedirs(workflow) + + # Check that assertion is raised when folders exist + test_case = SandboxTestCase() + with self.assertRaises(AssertionError): + test_case.assert_cleanup(workflow) + + # Delete folders + os.rmdir(workflow) + + # No error when folders deleted + test_case.assert_cleanup(workflow) + + def test_setup_mock_file_download(self): + """Test mocking a file download""" + + with CliRunner().isolated_filesystem() as temp_dir: + # Write data into temp_dir + expected_data = "Hello World!" + file_path = os.path.join(temp_dir, f"content.txt") + with open(file_path, mode="w") as f: + f.write(expected_data) + + # Check that content was downloaded from test file + test_case = SandboxTestCase() + url = "https://example.com" + with httpretty.enabled(): + test_case.setup_mock_file_download(url, file_path) + response = retry_session().get(url) + self.assertEqual(expected_data, response.content.decode("utf-8")) + + +class TestLoadAndParseJson(unittest.TestCase): + def test_load_and_parse_json(self): + # Create a temporary JSON file + with tempfile.NamedTemporaryFile() as temp_file: + # Create the data dictionary and write to temp file + data = { + "date1": "2022-01-01", + "timestamp1": "2022-01-01 12:00:00.100000 UTC", + "date2": "20230101", + "timestamp2": "2023-01-01 12:00:00", + } + with open(temp_file.name, "w") as f: + json.dump(data, f) + + # Test case 1: Parsing date fields with default date formats. 
Not specifying timestamp fields + expected_result = data.copy() + expected_result["date1"] = datetime(2022, 1, 1).date() + expected_result["date2"] = datetime(2023, 1, 1).date() # Should be converted by pendulum + result = load_and_parse_json(temp_file.name, date_fields=["date1", "date2"], date_formats=["%Y-%m-%d"]) + self.assertEqual(result, expected_result) + + # Test case 2: Parsing timestamp fields with custom timestamp format, not specifying date field + expected_result = data.copy() + expected_result["timestamp1"] = datetime(2022, 1, 1, 12, 0, 0, 100000) + expected_result["timestamp2"] = datetime( + 2023, 1, 1, 12, 0, 0, tzinfo=pendulum.tz.timezone("UTC") + ) # Converted by pendulum + result = load_and_parse_json( + temp_file.name, + timestamp_fields=["timestamp1", "timestamp2"], + timestamp_formats=["%Y-%m-%d %H:%M:%S.%f %Z"], + ) + self.assertEqual(result, expected_result) + + # Test case 3: Default date and timestamp formats + expected_result = { + "date1": datetime(2022, 1, 1).date(), + "date2": "20230101", + "timestamp1": datetime(2022, 1, 1, 12, 0, 0, 100000), + "timestamp2": "2023-01-01 12:00:00", + } + result = load_and_parse_json(temp_file.name, date_fields=["date1"], timestamp_fields=["timestamp1"]) + self.assertEqual(result, expected_result) diff --git a/tests/__init__.py b/observatory_platform/schema/__init__.py similarity index 100% rename from tests/__init__.py rename to observatory_platform/schema/__init__.py diff --git a/observatory_platform/schema/dataset_release.json b/observatory_platform/schema/dataset_release.json new file mode 100644 index 000000000..019d0fafc --- /dev/null +++ b/observatory_platform/schema/dataset_release.json @@ -0,0 +1,86 @@ +[ + { + "name": "dag_id", + "mode": "REQUIRED", + "type": "STRING", + "description": "The Airflow DAG ID, e.g. doi_workflow" + }, + { + "name": "dataset_id", + "mode": "REQUIRED", + "type": "STRING", + "description": "A unique identifier to represent the dataset being processed." + }, + { + "name": "dag_run_id", + "mode": "REQUIRED", + "type": "STRING", + "description": "The Airflow DAG run ID." + }, + { + "name": "data_interval_start", + "mode": "NULLABLE", + "type": "TIMESTAMP", + "description": "The Airflow data interval start datetime." + }, + { + "name": "data_interval_end", + "mode": "NULLABLE", + "type": "TIMESTAMP", + "description": "The Airflow data interval end datetime." + }, + { + "name": "snapshot_date", + "mode": "NULLABLE", + "type": "TIMESTAMP", + "description": "For datasets that release entire snapshots, the date that the snapshot was released." + }, + { + "name": "partition_date", + "mode": "NULLABLE", + "type": "TIMESTAMP", + "description": "For datasets that are loaded into date-partitioned tables, the date of the partition processed in this release." + }, + { + "name": "changefile_start_date", + "mode": "NULLABLE", + "type": "TIMESTAMP", + "description": "The date of the first changefile in this release. For datasets that provide changefiles with inserts, updates and deletes." + }, + { + "name": "changefile_end_date", + "mode": "NULLABLE", + "type": "TIMESTAMP", + "description": "The date of the last changefile in this release. For datasets that provide changefiles with inserts, updates and deletes." + }, + { + "name": "sequence_start", + "mode": "NULLABLE", + "type": "INT64", + "description": "The first sequence number of changefiles in this release. For datasets that provide changefiles with inserts, updates and deletes." + }, + { + "name": "sequence_end", + "mode": "NULLABLE", + "type": "INT64", + "description": "The last sequence number of changefiles in this release. 
For datasets that provide changefiles with inserts, updates and deletes." + }, + { + "name": "extra", + "mode": "NULLABLE", + "type": "JSON", + "description": "An optional JSON field for extra information." + }, + { + "name": "created", + "mode": "REQUIRED", + "type": "TIMESTAMP", + "description": "The date that this record was created." + }, + { + "name": "modified", + "mode": "REQUIRED", + "type": "TIMESTAMP", + "description": "The date that this record was modified." + } +] \ No newline at end of file diff --git a/observatory-platform/observatory/platform/sftp.py b/observatory_platform/sftp.py similarity index 100% rename from observatory-platform/observatory/platform/sftp.py rename to observatory_platform/sftp.py diff --git a/tests/fixtures/__init__.py b/observatory_platform/sql/__init__.py similarity index 100% rename from tests/fixtures/__init__.py rename to observatory_platform/sql/__init__.py diff --git a/observatory-platform/observatory/platform/sql/delete_records.sql.jinja2 b/observatory_platform/sql/delete_records.sql.jinja2 similarity index 100% rename from observatory-platform/observatory/platform/sql/delete_records.sql.jinja2 rename to observatory_platform/sql/delete_records.sql.jinja2 diff --git a/observatory-platform/observatory/platform/sql/select_columns.sql.jinja2 b/observatory_platform/sql/select_columns.sql.jinja2 similarity index 100% rename from observatory-platform/observatory/platform/sql/select_columns.sql.jinja2 rename to observatory_platform/sql/select_columns.sql.jinja2 diff --git a/observatory-platform/observatory/platform/sql/select_table_shard_dates.sql.jinja2 b/observatory_platform/sql/select_table_shard_dates.sql.jinja2 similarity index 100% rename from observatory-platform/observatory/platform/sql/select_table_shard_dates.sql.jinja2 rename to observatory_platform/sql/select_table_shard_dates.sql.jinja2 diff --git a/observatory-platform/observatory/platform/sql/upsert_records.sql.jinja2 b/observatory_platform/sql/upsert_records.sql.jinja2 similarity index 100% rename from observatory-platform/observatory/platform/sql/upsert_records.sql.jinja2 rename to observatory_platform/sql/upsert_records.sql.jinja2 diff --git a/tests/fixtures/cli/my-workflows-project/my_workflows_project/__init__.py b/observatory_platform/tests/__init__.py similarity index 100% rename from tests/fixtures/cli/my-workflows-project/my_workflows_project/__init__.py rename to observatory_platform/tests/__init__.py diff --git a/tests/fixtures/cli/my-workflows-project/my_workflows_project/dags/__init__.py b/observatory_platform/tests/fixtures/__init__.py similarity index 100% rename from tests/fixtures/cli/my-workflows-project/my_workflows_project/dags/__init__.py rename to observatory_platform/tests/fixtures/__init__.py diff --git a/tests/fixtures/utils/find_replace.txt b/observatory_platform/tests/fixtures/find_replace.txt similarity index 100% rename from tests/fixtures/utils/find_replace.txt rename to observatory_platform/tests/fixtures/find_replace.txt diff --git a/tests/fixtures/utils/get_http_response_json.json b/observatory_platform/tests/fixtures/get_http_response_json.json similarity index 100% rename from tests/fixtures/utils/get_http_response_json.json rename to observatory_platform/tests/fixtures/get_http_response_json.json diff --git a/tests/fixtures/utils/get_http_response_xml_to_dict.xml b/observatory_platform/tests/fixtures/get_http_response_xml_to_dict.xml similarity index 100% rename from tests/fixtures/utils/get_http_response_xml_to_dict.xml rename to 
observatory_platform/tests/fixtures/get_http_response_xml_to_dict.xml diff --git a/observatory_platform/tests/fixtures/http_testfile.txt b/observatory_platform/tests/fixtures/http_testfile.txt new file mode 100644 index 000000000..1f6600da6 --- /dev/null +++ b/observatory_platform/tests/fixtures/http_testfile.txt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:f2ca1bb6c7e907d06dafe4687e579fce76b37e4e93b7605022da52e6ccc26fd2 +size 5 diff --git a/tests/fixtures/utils/http_testfile2.txt b/observatory_platform/tests/fixtures/http_testfile2.txt similarity index 100% rename from tests/fixtures/utils/http_testfile2.txt rename to observatory_platform/tests/fixtures/http_testfile2.txt diff --git a/tests/fixtures/elastic/load_csv.csv b/observatory_platform/tests/fixtures/load_csv.csv similarity index 100% rename from tests/fixtures/elastic/load_csv.csv rename to observatory_platform/tests/fixtures/load_csv.csv diff --git a/tests/fixtures/elastic/load_csv_gz.csv.gz b/observatory_platform/tests/fixtures/load_csv_gz.csv.gz similarity index 100% rename from tests/fixtures/elastic/load_csv_gz.csv.gz rename to observatory_platform/tests/fixtures/load_csv_gz.csv.gz diff --git a/tests/fixtures/elastic/load_jsonl.jsonl b/observatory_platform/tests/fixtures/load_jsonl.jsonl similarity index 100% rename from tests/fixtures/elastic/load_jsonl.jsonl rename to observatory_platform/tests/fixtures/load_jsonl.jsonl diff --git a/tests/fixtures/elastic/load_jsonl_gz.jsonl.gz b/observatory_platform/tests/fixtures/load_jsonl_gz.jsonl.gz similarity index 100% rename from tests/fixtures/elastic/load_jsonl_gz.jsonl.gz rename to observatory_platform/tests/fixtures/load_jsonl_gz.jsonl.gz diff --git a/tests/fixtures/utils/test_hasher.txt b/observatory_platform/tests/fixtures/test_hasher.txt similarity index 100% rename from tests/fixtures/utils/test_hasher.txt rename to observatory_platform/tests/fixtures/test_hasher.txt diff --git a/tests/fixtures/utils/testzip.txt b/observatory_platform/tests/fixtures/testzip.txt similarity index 100% rename from tests/fixtures/utils/testzip.txt rename to observatory_platform/tests/fixtures/testzip.txt diff --git a/tests/fixtures/utils/testzip.txt.gz b/observatory_platform/tests/fixtures/testzip.txt.gz similarity index 100% rename from tests/fixtures/utils/testzip.txt.gz rename to observatory_platform/tests/fixtures/testzip.txt.gz diff --git a/observatory_platform/tests/test_config.py b/observatory_platform/tests/test_config.py new file mode 100644 index 000000000..df05747d9 --- /dev/null +++ b/observatory_platform/tests/test_config.py @@ -0,0 +1,58 @@ +# Copyright 2019 Curtin University. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +# Author: James Diprose, Aniek Roelofs + +import os +import pathlib +import unittest +from unittest.mock import patch + +from click.testing import CliRunner + +import observatory_platform.tests as platform_utils_tests +from observatory_platform.config import module_file_path, observatory_home + + +class TestConfig(unittest.TestCase): + def test_module_file_path(self): + # Go back one step (the default) + expected_path = str(pathlib.Path(*pathlib.Path(platform_utils_tests.__file__).resolve().parts[:-1]).resolve()) + actual_path = module_file_path("observatory_platform.tests", nav_back_steps=-1) + self.assertEqual(expected_path, actual_path) + + # Go back two steps + expected_path = str(pathlib.Path(*pathlib.Path(platform_utils_tests.__file__).resolve().parts[:-1]).resolve()) + actual_path = module_file_path("observatory_platform.tests.fixtures", nav_back_steps=-2) + self.assertEqual(expected_path, actual_path) + + @patch("observatory_platform.config.pathlib.Path.home") + def test_observatory_home(self, home_mock): + runner = CliRunner() + with runner.isolated_filesystem(): + # Create home path and mock getting home path + home_path = "user-home" + os.makedirs(home_path, exist_ok=True) + home_mock.return_value = home_path + + with runner.isolated_filesystem(): + # Test that observatory home works + path = observatory_home() + self.assertTrue(os.path.exists(path)) + self.assertEqual(f"{home_path}/.observatory", path) + + # Test that subdirectories are created + path = observatory_home("subfolder") + self.assertTrue(os.path.exists(path)) + self.assertEqual(f"{home_path}/.observatory/subfolder", path) diff --git a/observatory_platform/tests/test_dataset_api.py b/observatory_platform/tests/test_dataset_api.py new file mode 100644 index 000000000..89c2bdcd7 --- /dev/null +++ b/observatory_platform/tests/test_dataset_api.py @@ -0,0 +1,189 @@ +# Copyright 2020-2024 Curtin University +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
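+ +# Note: these tests exercise DatasetAPI against real BigQuery resources. They assume that +# the TEST_GCP_PROJECT_ID and TEST_GCP_DATA_LOCATION environment variables point to a +# Google Cloud project and BigQuery location in which the test runner may create datasets; +# temporary datasets are created via SandboxEnvironment.add_dataset().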
+ + +import os + +import pendulum + +from observatory_platform.dataset_api import build_schedule, DatasetAPI, DatasetRelease +from observatory_platform.google.bigquery import bq_run_query +from observatory_platform.sandbox.sandbox_environment import SandboxEnvironment +from observatory_platform.sandbox.test_utils import SandboxTestCase + + +class TestDatasetAPI(SandboxTestCase): + def __init__(self, *args, **kwargs): + super(TestDatasetAPI, self).__init__(*args, **kwargs) + self.project_id = os.getenv("TEST_GCP_PROJECT_ID") + self.data_location = os.getenv("TEST_GCP_DATA_LOCATION") + + def test_add_dataset_release(self): + env = SandboxEnvironment(project_id=self.project_id, data_location=self.data_location) + bq_dataset_id = env.add_dataset(prefix="dataset_api") + api = DatasetAPI(project_id=self.project_id, dataset_id=bq_dataset_id, location=self.data_location) + with env.create(): + api.seed_db() + + # Add dataset release + dag_id = "doi_workflow" + dataset_id = "doi" + + dt = pendulum.now() + expected = DatasetRelease( + dag_id=dag_id, + dataset_id=dataset_id, + dag_run_id="test", + created=dt, + modified=dt, + data_interval_start=dt, + data_interval_end=dt, + snapshot_date=dt, + partition_date=dt, + changefile_start_date=dt, + changefile_end_date=dt, + sequence_start=1, + sequence_end=20, + extra={"hello": "world"}, + ) + api.add_dataset_release(expected) + + # Check if dataset release added + rows = bq_run_query(f"SELECT * FROM `{api.full_table_id}`") + self.assertEqual(1, len(rows)) + actual = DatasetRelease.from_dict(dict(rows[0])) + self.assertEqual(expected, actual) + + def test_get_dataset_releases(self): + env = SandboxEnvironment(project_id=self.project_id, data_location=self.data_location) + bq_dataset_id = env.add_dataset(prefix="dataset_api") + api = DatasetAPI(project_id=self.project_id, dataset_id=bq_dataset_id, location=self.data_location) + expected = [] + with env.create(): + api.seed_db() + + # Create dataset releases + dag_id = "doi_workflow" + dataset_id = "doi" + for i in range(10): + dt = pendulum.now() + release = DatasetRelease( + dag_id=dag_id, + dataset_id=dataset_id, + dag_run_id="test", + created=dt, + modified=dt, + ) + api.add_dataset_release(release) + expected.append(release) + + # Sort descending order + expected.sort(key=lambda r: r.created, reverse=True) + + # Get releases + actual = api.get_dataset_releases(dag_id=dag_id, dataset_id=dataset_id) + self.assertListEqual(expected, actual) + + def test_is_first_release(self): + env = SandboxEnvironment(project_id=self.project_id, data_location=self.data_location) + bq_dataset_id = env.add_dataset(prefix="dataset_api") + api = DatasetAPI(project_id=self.project_id, dataset_id=bq_dataset_id, location=self.data_location) + with env.create(): + api.seed_db() + + dag_id = "doi_workflow" + dataset_id = "doi" + + # Is first release + is_first = api.is_first_release(dag_id=dag_id, dataset_id=dataset_id) + self.assertTrue(is_first) + + # Not first release + dt = pendulum.now() + release = DatasetRelease( + dag_id=dag_id, + dataset_id=dataset_id, + dag_run_id="test", + created=dt, + modified=dt, + ) + api.add_dataset_release(release) + is_first = api.is_first_release(dag_id=dag_id, dataset_id=dataset_id) + self.assertFalse(is_first) + + def test_get_latest_dataset_release(self): + dag_id = "doi_workflow" + dataset_id = "doi" + dt = pendulum.now() + releases = [ + DatasetRelease( + dag_id=dag_id, + dataset_id=dataset_id, + dag_run_id="test", + created=dt, + modified=dt, + snapshot_date=pendulum.datetime(2022, 1, 1), 
+ ), + DatasetRelease( + dag_id=dag_id, + dataset_id=dataset_id, + dag_run_id="test", + created=dt, + modified=dt, + snapshot_date=pendulum.datetime(2023, 1, 1), + ), + DatasetRelease( + dag_id=dag_id, + dataset_id=dataset_id, + dag_run_id="test", + created=dt, + modified=dt, + snapshot_date=pendulum.datetime(2024, 1, 1), + ), + ] + env = SandboxEnvironment(project_id=self.project_id, data_location=self.data_location) + bq_dataset_id = env.add_dataset(prefix="dataset_api") + api = DatasetAPI(project_id=self.project_id, dataset_id=bq_dataset_id, location=self.data_location) + + with env.create(): + api.seed_db() + + # Create dataset releases + for release in releases: + api.add_dataset_release(release) + + latest = api.get_latest_dataset_release(dag_id=dag_id, dataset_id=dataset_id, date_key="snapshot_date") + self.assertEqual(releases[-1], latest) + + def test_build_schedule(self): + start_date = pendulum.datetime(2021, 1, 1) + end_date = pendulum.datetime(2021, 2, 1) + schedule = build_schedule(start_date, end_date) + self.assertEqual([pendulum.Period(pendulum.date(2021, 1, 1), pendulum.date(2021, 1, 31))], schedule) + + start_date = pendulum.datetime(2021, 1, 1) + end_date = pendulum.datetime(2021, 3, 1) + schedule = build_schedule(start_date, end_date) + self.assertEqual( + [ + pendulum.Period(pendulum.date(2021, 1, 1), pendulum.date(2021, 1, 31)), + pendulum.Period(pendulum.date(2021, 2, 1), pendulum.date(2021, 2, 28)), + ], + schedule, + ) + + start_date = pendulum.datetime(2021, 1, 7) + end_date = pendulum.datetime(2021, 2, 7) + schedule = build_schedule(start_date, end_date) + self.assertEqual([pendulum.Period(pendulum.date(2021, 1, 7), pendulum.date(2021, 2, 6))], schedule) diff --git a/tests/observatory/platform/test_files.py b/observatory_platform/tests/test_files.py similarity index 92% rename from tests/observatory/platform/test_files.py rename to observatory_platform/tests/test_files.py index df319241f..487390714 100644 --- a/tests/observatory/platform/test_files.py +++ b/observatory_platform/tests/test_files.py @@ -30,8 +30,9 @@ from click.testing import CliRunner from google.cloud import bigquery -from observatory.platform.files import add_partition_date, find_replace_file, get_chunks -from observatory.platform.files import ( +from observatory_platform.config import module_file_path +from observatory_platform.files import add_partition_date, find_replace_file, get_chunks +from observatory_platform.files import ( get_file_hash, get_hasher_, gunzip_files, @@ -44,19 +45,18 @@ split_file, split_file_and_compress, ) -from observatory.platform.files import validate_file_hash, load_jsonl, list_files -from observatory.platform.observatory_environment import test_fixtures_path +from observatory_platform.files import validate_file_hash, load_jsonl, list_files class TestFileUtils(unittest.TestCase): def __init__(self, *args, **kwargs): super(TestFileUtils, self).__init__(*args, **kwargs) - fixtures_path = test_fixtures_path("elastic") - self.csv_gz_file_path = os.path.join(fixtures_path, "load_csv_gz.csv.gz") - self.jsonl_gz_file_path = os.path.join(fixtures_path, "load_jsonl_gz.jsonl.gz") - self.csv_file_path = os.path.join(fixtures_path, "load_csv.csv") - self.jsonl_file_path = os.path.join(fixtures_path, "load_jsonl.jsonl") + self.fixtures_path = module_file_path("observatory_platform.tests.fixtures") + self.csv_gz_file_path = os.path.join(self.fixtures_path, "load_csv_gz.csv.gz") + self.jsonl_gz_file_path = os.path.join(self.fixtures_path, "load_jsonl_gz.jsonl.gz") + 
self.csv_file_path = os.path.join(self.fixtures_path, "load_csv.csv") + self.jsonl_file_path = os.path.join(self.fixtures_path, "load_jsonl.jsonl") self.expected_records = [ {"first_name": "Jim", "last_name": "Holden"}, {"first_name": "Alex", "last_name": "Kamal"}, @@ -151,24 +151,21 @@ def test_get_hasher_(self): def test_get_file_hash(self): expected_hash = "f299060e0383392ebeac64b714eca7e3" - fixtures_dir = test_fixtures_path("utils") - file_path = os.path.join(fixtures_dir, "test_hasher.txt") + file_path = os.path.join(self.fixtures_path, "test_hasher.txt") computed_hash = get_file_hash(file_path=file_path) self.assertEqual(expected_hash, computed_hash) def test_validate_file_hash(self): expected_hash = "f299060e0383392ebeac64b714eca7e3" - fixtures_dir = test_fixtures_path("utils") - file_path = os.path.join(fixtures_dir, "test_hasher.txt") + file_path = os.path.join(self.fixtures_path, "test_hasher.txt") self.assertTrue(validate_file_hash(file_path=file_path, expected_hash=expected_hash)) def test_gunzip_files(self): - fixture_dir = test_fixtures_path("utils") filename = "testzip.txt.gz" expected_hash = "62d83685cff9cd962ac5abb563c61f38" output_file = "testzip.txt" - src = os.path.join(fixture_dir, filename) + src = os.path.join(self.fixtures_path, filename) # Save in same dir with CliRunner().isolated_filesystem() as tmpdir: @@ -187,7 +184,7 @@ def test_gunzip_files(self): # Skip non gz files with CliRunner().isolated_filesystem() as tmpdir: dst = os.path.join(tmpdir, filename) - src_path = os.path.join(fixture_dir, output_file) + src_path = os.path.join(self.fixtures_path, output_file) gunzip_files(file_list=[src_path], output_dir=tmpdir) self.assertFalse(os.path.exists(dst)) @@ -329,8 +326,7 @@ def test_split_file(self): self.assertEqual(expected_data, data) def test_find_replace_file(self): - fixture_dir = test_fixtures_path("utils") - src = os.path.join(fixture_dir, "find_replace.txt") + src = os.path.join(self.fixtures_path, "find_replace.txt") expected_hash = "ffa623201cb9538bd3c030cd0b9f6b66" with CliRunner().isolated_filesystem(): diff --git a/tests/observatory/platform/utils/test_http_download.py b/observatory_platform/tests/test_http_download.py similarity index 90% rename from tests/observatory/platform/utils/test_http_download.py rename to observatory_platform/tests/test_http_download.py index 493ac1faa..c0c7058b0 100644 --- a/tests/observatory/platform/utils/test_http_download.py +++ b/observatory_platform/tests/test_http_download.py @@ -20,17 +20,15 @@ from click.testing import CliRunner -from observatory.platform.observatory_environment import ( - HttpServer, - ObservatoryTestCase, - test_fixtures_path, -) -from observatory.platform.utils.http_download import ( +from observatory_platform.config import module_file_path +from observatory_platform.http_download import ( DownloadInfo, download_file, download_files, ) -from observatory.platform.utils.url_utils import get_observatory_http_header +from observatory_platform.sandbox.http_server import HttpServer +from observatory_platform.sandbox.test_utils import SandboxTestCase +from observatory_platform.url_utils import get_observatory_http_header class MockVersionData: @@ -44,10 +42,10 @@ def get(self, attribute): return "test@test" -class TestAsyncHttpFileDownloader(ObservatoryTestCase): +class TestAsyncHttpFileDownloader(SandboxTestCase): def test_download_files(self): # Spin up http server - directory = test_fixtures_path("utils") + directory = module_file_path("observatory_platform.tests.fixtures") http_server = 
HttpServer(directory=directory) with http_server.create(): file1 = "http_testfile.txt" @@ -74,7 +72,7 @@ def test_download_files(self): # URL only with observatory user agent with CliRunner().isolated_filesystem() as tmpdir: - with patch("observatory.platform.utils.url_utils.metadata", return_value=MockVersionData): + with patch("observatory_platform.url_utils.metadata", return_value=MockVersionData): headers = get_observatory_http_header(package_name="observatory-platform") download_list = [url1, url2] download_files(download_list=download_list, headers=headers) @@ -169,14 +167,14 @@ def test_download_files(self): self.assert_file_integrity(file1, hash1, algorithm) # Skip download because exists - with patch("observatory.platform.utils.http_download.download_http_file_") as m_down: + with patch("observatory_platform.http_download.download_http_file_") as m_down: success = download_file(url=url1, filename=file1, hash=hash1, hash_algorithm="md5") self.assertTrue(success) self.assert_file_integrity(file1, hash1, algorithm) self.assertEqual(m_down.call_count, 0) # Skip download because exists (with prefix dir) - with patch("observatory.platform.utils.http_download.download_http_file_") as m_down: + with patch("observatory_platform.http_download.download_http_file_") as m_down: success, download_info = download_file( url=url1, filename=file1, hash=hash1, hash_algorithm="md5", prefix_dir=tmpdir ) @@ -186,7 +184,7 @@ def test_download_files(self): # Get filename from Content-Disposition with CliRunner().isolated_filesystem() as tmpdir: - with patch("observatory.platform.utils.http_download.parse_header") as m_header: + with patch("observatory_platform.http_download.parse_header") as m_header: m_header.return_value = (None, {"filename": "testfile"}) success, download_info = download_file(url=url1, hash=hash1, hash_algorithm="md5") self.assertTrue(success) diff --git a/tests/observatory/platform/utils/test_jinja2_utils.py b/observatory_platform/tests/test_jinja2_utils.py similarity index 94% rename from tests/observatory/platform/utils/test_jinja2_utils.py rename to observatory_platform/tests/test_jinja2_utils.py index fc1a54402..09cfdb245 100644 --- a/tests/observatory/platform/utils/test_jinja2_utils.py +++ b/observatory_platform/tests/test_jinja2_utils.py @@ -19,7 +19,7 @@ import pendulum from click.testing import CliRunner -from observatory.platform.utils.jinja2_utils import render_template, make_jinja2_filename, make_sql_jinja2_filename +from observatory_platform.jinja2_utils import render_template, make_jinja2_filename, make_sql_jinja2_filename class TestJinja2Utils(unittest.TestCase): diff --git a/observatory-platform/observatory/platform/dags/dummy_telescope.py b/observatory_platform/tests/test_proc_utils.py similarity index 52% rename from observatory-platform/observatory/platform/dags/dummy_telescope.py rename to observatory_platform/tests/test_proc_utils.py index 64dccbfc1..63a291e95 100644 --- a/observatory-platform/observatory/platform/dags/dummy_telescope.py +++ b/observatory_platform/tests/test_proc_utils.py @@ -1,4 +1,4 @@ -# Copyright 2020 Curtin University +# Copyright 2022 Curtin University # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
@@ -14,15 +14,18 @@ # Author: Aniek Roelofs -import pendulum -from airflow import DAG -from airflow.operators.bash import BashOperator +import unittest +from unittest.mock import patch -default_args = {"owner": "airflow", "start_date": pendulum.datetime(2020, 8, 10)} +from observatory_platform.proc_utils import wait_for_process -with DAG(dag_id="dummy_telescope", schedule="@daily", default_args=default_args, catchup=True) as dag: - task1 = BashOperator(task_id="task1", bash_command="echo 'hello'", queue="remote_queue") - task2 = BashOperator(task_id="task2", bash_command="echo 'world'", queue="remote_queue") +class TestProcUtils(unittest.TestCase): + @patch("subprocess.Popen") + def test_wait_for_process(self, mock_popen): + proc = mock_popen() + proc.communicate.return_value = "out".encode(), "err".encode() - task1 >> task2 + out, err = wait_for_process(proc) + self.assertEqual("out", out) + self.assertEqual("err", err) diff --git a/tests/observatory/platform/test_sftp.py b/observatory_platform/tests/test_sftp.py similarity index 97% rename from tests/observatory/platform/test_sftp.py rename to observatory_platform/tests/test_sftp.py index fc482f9ef..aff19ddf8 100644 --- a/tests/observatory/platform/test_sftp.py +++ b/observatory_platform/tests/test_sftp.py @@ -20,7 +20,7 @@ import pysftp from airflow.models.connection import Connection -from observatory.platform.sftp import make_sftp_connection +from observatory_platform.sftp import make_sftp_connection class TestSFTP(unittest.TestCase): diff --git a/tests/observatory/platform/utils/test_url_utils.py b/observatory_platform/tests/test_url_utils.py similarity index 91% rename from tests/observatory/platform/utils/test_url_utils.py rename to observatory_platform/tests/test_url_utils.py index 6b055ff0d..ad5a65ce6 100644 --- a/tests/observatory/platform/utils/test_url_utils.py +++ b/observatory_platform/tests/test_url_utils.py @@ -14,22 +14,23 @@ # Author: James Diprose, Keegan Smith +import time import unittest from datetime import datetime -from typing import List -from unittest.mock import patch +from typing import List, Any +from unittest.mock import patch, Mock import httpretty import pendulum import requests import responses -import time from airflow.exceptions import AirflowException from click.testing import CliRunner from tenacity import wait_fixed -from observatory.platform.observatory_environment import HttpServer, test_fixtures_path -from observatory.platform.utils.url_utils import ( +from observatory_platform.config import module_file_path +from observatory_platform.sandbox.http_server import HttpServer +from observatory_platform.url_utils import ( get_filename_from_url, get_http_response_json, get_http_response_xml_to_dict, @@ -41,7 +42,15 @@ wait_for_url, get_filename_from_http_header, ) -from tests.observatory.platform.cli.test_platform_command import MockUrlOpen + + +class MockUrlOpen(Mock): + def __init__(self, status: int, **kwargs: Any): + super().__init__(**kwargs) + self.status = status + + def getcode(self): + return self.status class TestUrlUtils(unittest.TestCase): @@ -57,6 +66,7 @@ def get(self, attribute): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) + self.fixtures_path = module_file_path("observatory_platform.tests.fixtures") def __create_mock_request_sequence(self, url: str, status_codes: List[int], bodies: List[str], sleep: float = 0): self.sequence = 0 @@ -212,7 +222,7 @@ def test_retry_get_url_read_timeout(self): httpretty.disable() httpretty.reset() - 
@patch("observatory.platform.utils.url_utils.urllib.request.urlopen") + @patch("observatory_platform.url_utils.urllib.request.urlopen") def test_wait_for_url_success(self, mock_url_open): # Mock the status code return value: 200 should succeed mock_url_open.return_value = MockUrlOpen(200) @@ -225,7 +235,7 @@ def test_wait_for_url_success(self, mock_url_open): self.assertTrue(state) self.assertAlmostEqual(0, duration, delta=0.5) - @patch("observatory.platform.utils.url_utils.urllib.request.urlopen") + @patch("observatory_platform.url_utils.urllib.request.urlopen") def test_wait_for_url_failed(self, mock_url_open): # Mock the status code return value: 500 should fail mock_url_open.return_value = MockUrlOpen(500) @@ -239,7 +249,7 @@ def test_wait_for_url_failed(self, mock_url_open): self.assertFalse(state) self.assertAlmostEqual(expected_timeout, duration, delta=1) - @patch("observatory.platform.utils.url_utils.metadata", return_value=MockMetadata) + @patch("observatory_platform.url_utils.metadata", return_value=MockMetadata) def test_user_agent(self, mock_cfg): """Test user agent generation""" @@ -247,7 +257,7 @@ def test_user_agent(self, mock_cfg): ua = get_user_agent(package_name="observatory-platform") self.assertEqual(ua, gt) - @patch("observatory.platform.utils.url_utils.metadata", return_value=MockMetadata) + @patch("observatory_platform.url_utils.metadata", return_value=MockMetadata) def test_get_observatory_http_header(self, mock_cfg): expected_header = {"User-Agent": "observatory-platform v1 (+http://test.test; mailto: test@test)"} header = get_observatory_http_header(package_name="observatory-platform") @@ -279,7 +289,7 @@ def test_get_http_text_response(self): def test_get_http_response_json(self): with CliRunner().isolated_filesystem(): - httpserver = HttpServer(test_fixtures_path("utils")) + httpserver = HttpServer(self.fixtures_path) with httpserver.create(): url = f"http://{httpserver.host}:{httpserver.port}/get_http_response_json.json" @@ -290,7 +300,7 @@ def test_get_http_response_json(self): def test_get_http_response_xml_to_dict(self): with CliRunner().isolated_filesystem(): - httpserver = HttpServer(test_fixtures_path("utils")) + httpserver = HttpServer(self.fixtures_path) with httpserver.create(): url = f"http://{httpserver.host}:{httpserver.port}/get_http_response_xml_to_dict.xml" @@ -302,7 +312,7 @@ def test_get_http_response_xml_to_dict(self): self.assertEqual(response["note"]["heading"], "Test heading") self.assertEqual(response["note"]["body"], "Test text") - @patch("observatory.platform.utils.url_utils.requests.head") + @patch("observatory_platform.url_utils.requests.head") def test_get_filename_from_http_header(self, m_head): url = "http://someurl" diff --git a/observatory-platform/observatory/platform/utils/url_utils.py b/observatory_platform/url_utils.py similarity index 100% rename from observatory-platform/observatory/platform/utils/url_utils.py rename to observatory_platform/url_utils.py index e4c424ebc..e2acc24f4 100644 --- a/observatory-platform/observatory/platform/utils/url_utils.py +++ b/observatory_platform/url_utils.py @@ -18,18 +18,18 @@ import json import logging import os +import time import urllib.error import urllib.request from datetime import datetime from email.utils import parsedate_to_datetime +from importlib.metadata import metadata from typing import Dict, List, Tuple, Union, Optional import pytz import requests -import time import xmltodict from airflow import AirflowException -from importlib.metadata import metadata from requests.adapters 
import HTTPAdapter from tenacity import Retrying, stop_after_attempt, before_sleep_log, wait_exponential_jitter from tenacity.wait import wait_base, wait_fixed diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 000000000..07df75c0b --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,146 @@ +[build-system] +requires = ["setuptools>=44", "wheel"] +build-backend = "setuptools.build_meta" + +[project] +name = "observatory-platform" +version = "1.0.0" +description = "The Observatory Platform is an environment for fetching, processing and analysing data to understand how well universities operate as Open Knowledge Institutions." +requires-python = ">=3.10" +license = { text = "Apache-2.0" } +keywords = ["science", "data", "workflows", "academic institutes", "academic-observatory-workflows"] +authors = [{ name = "Curtin University", email = "agent@observatory.academy" }] +classifiers = [ + "Development Status :: 2 - Pre-Alpha", + "Environment :: Console", + "Environment :: Web Environment", + "Intended Audience :: Developers", + "Intended Audience :: Science/Research", + "License :: OSI Approved :: Apache Software License", + "Operating System :: OS Independent", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.10", + "Topic :: Scientific/Engineering", + "Topic :: Software Development :: Libraries", + "Topic :: Software Development :: Libraries :: Python Modules", + "Topic :: Utilities" +] +dependencies = [ + # Airflow + "apache-airflow[slack]==2.7.3", + "apache-airflow-providers-cncf-kubernetes>=7.4.0", + # Google Cloud + "google-crc32c>=1.1.0,<2", + "google-cloud-bigquery>=3,<4", + "google-api-python-client>=2,<3", + "google-cloud-storage>=2.7.0,<3", + #"google-auth-oauthlib>=0.4.5,<1", + "google-cloud-compute >=1.16.0,<2.0", + # File manipulation, reading, writing + "Jinja2>=3,<4", + "jsonlines>=2.0.0,<3", # Writing + "json_lines>=0.5.0,<1", # Reading, including jsonl.gz + "pandas>=2,<3", + # HTTP requests and URL cleaning + "requests>=2.25.0,<3", + "aiohttp>=3.7.0,<4", + # SFTP + "paramiko>=2.7.2,<3", + "pysftp>=0.2.9,<1", + # Utils + "natsort>=7.1.1,<8", + "backoff>=2,<3", + "validators<=0.20.0", + "xmltodict", + "tenacity", +] + +[project.optional-dependencies] +tests = [ + "liccheck>=0.4.9,<1", + "flake8>=3.8.0,<4", + "coverage>=5.2,<6", + "azure-storage-blob>=12.8.1,<13", + "click>=8,<9", + "httpretty>=1.0.0,<2", + "deepdiff>=6,<7", + "responses>=0.23.1,<1", + "boto3>=1.15.0,<2", + "timeout-decorator>=0,<1", + # SFTP & FTP + "sftpserver>=0.3,<1", + "pyftpdlib>=1.5.7,<2", +] + +[project.urls] +"Homepage" = "https://github.com/The-Academic-Observatory/observatory-platform" +"Bug Tracker" = "https://github.com/The-Academic-Observatory/observatory-platform/issues" +"Documentation" = "https://observatory-platform.readthedocs.io/en/latest/" +"Source" = "https://github.com/The-Academic-Observatory/observatory-platform" +"Repository" = "https://github.com/The-Academic-Observatory/observatory-platform" + +[tool.liccheck] +authorized_licenses = [ + # Unencumbered licenses: https://opensource.google/docs/thirdparty/licenses/#unencumbered + "public domain ", + "cc0 1.0 universal (cc0 1.0) public domain dedication", + # Notice licenses: https://opensource.google/docs/thirdparty/licenses/#notice + "artistic", + "apache software license 2.0", + "apache license version 2.0", + "apache license, version 2.0", + "apache license 2.0", + "apache 2.0", + "apache-2.0", + "apache software", + "apache 2", + "apache-2", + "bsd", + "bsd-2-clause", + "bsd-3-clause", + 
"3-clause bsd", + "new bsd", + "bsd or apache license, version 2.0", + "bsd-2-clause or apache-2.0", + "isc license (iscl)", + "isc", + "mit", + "python software foundation", + "psfl", + "zpl 2.1", + # Reciprocal licenses: https://opensource.google/docs/thirdparty/licenses/#reciprocal + "mozilla public license 1.1 (mpl 1.1)", + "mozilla public license 2.0 (mpl 2.0)", + "mpl-2.0", + # LGPL + "lgpl", + "gnu library or lesser general public license (lgpl)", + "gnu lesser general public license v3 or later (lgplv3+)", + "lgplv3+" +] +unauthorized_licenses = [ + # Restricted licenses: https://opensource.google/docs/thirdparty/licenses/#restricted + "gpl v1", + "gpl v2", + "gpl v3", + # Banned licenses: https://opensource.google/docs/thirdparty/licenses/#banned + "agpl", + "affero gpl", + "agpl (affero gpl)", + "sspl" +] +level = "CAUTIOUS" +reporting_txt_file = "liccheck-output.txt" # by default is None +no_deps = false +dependencies = true # to load [project.dependencies] +optional_dependencies = ["tests"] # to load extras from [project.optional-dependencies] + +[tool.liccheck.authorized_packages] +# MIT license: https://pypi.org/project/ordereddict/1.1/ +ordereddict = "1.1" + +# MIT license: https://pypi.org/project/pendulum/1.4.4/ +pendulum = ">=1.4.4" + +# Python Imaging Library (PIL) License: https://github.com/python-pillow/Pillow/blob/master/LICENSE +Pillow = ">=7.2.0" \ No newline at end of file diff --git a/setup.cfg b/setup.cfg index e51fcd6f2..9ee53195c 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,4 +1,65 @@ [metadata] -name = observatory -; This file is required because Sphinx needs it to determine the project name. -; The following error will be thrown otherwise: 'Could not find a setup.cfg to extract project name from' \ No newline at end of file +name = observatory-platform +author = Curtin University +author_email = agent@observatory.academy +summary = The Observatory Platform is an environment for fetching, processing and analysing data to understand how well universities operate as Open Knowledge Institutions. 
+description_file = README.md +description_content_type = text/markdown; charset=UTF-8 +home_page = https://github.com/The-Academic-Observatory/observatory-platform +project_urls = + Bug Tracker = + Documentation = + Source Code = +python_requires = >=3.10 +license = Apache License Version 2.0 +classifier = + Development Status :: 2 - Pre-Alpha + Environment :: Console + Environment :: Web Environment + Intended Audience :: Developers + Intended Audience :: Science/Research + License :: OSI Approved :: Apache Software License + Operating System :: OS Independent + Programming Language :: Python :: 3 :: Only + Programming Language :: Python :: 3 + Programming Language :: Python :: 3.10 + Topic :: Scientific/Engineering + Topic :: Software Development :: Libraries + Topic :: Software Development :: Libraries :: Python Modules + Topic :: Utilities +keywords = + science + data + workflows + academic institutes + observatory-platform + +[files] +packages = + observatory + observatory.platform +data_files = + requirements.txt = requirements.txt + requirements.sh = requirements.sh + observatory/platform/docker = observatory/platform/docker/* + observatory/platform/terraform = observatory/platform/terraform/* + observatory/platform = + observatory/platform/config.yaml.jinja2 + observatory/platform/config-terraform.yaml.jinja2 + +[entry_points] +console_scripts = + observatory = observatory_platform.cli.cli:cli + +[extras] +tests = + + + + + + + + +[pbr] +skip_authors = true diff --git a/strategy.ini b/strategy.ini deleted file mode 100644 index bf6addef7..000000000 --- a/strategy.ini +++ /dev/null @@ -1,70 +0,0 @@ -# List authorised and unauthorised licenses in lower case -[Licenses] -authorized_licenses : - # Unencumbered licenses: https://opensource.google/docs/thirdparty/licenses/#unencumbered - public domain - - # Notice licenses: https://opensource.google/docs/thirdparty/licenses/#notice - artistic - apache software license 2.0 - apache license version 2.0 - apache license, version 2.0 - apache license 2.0 - apache 2.0 - apache-2.0 - apache software - apache 2 - apache-2 - bsd - bsd-2-clause - bsd-3-clause - 3-clause bsd - new bsd - bsd or apache license, version 2.0 - bsd-2-clause or apache-2.0 - isc license (iscl) - isc - mit - python software foundation - psfl - zpl 2.1 - - # Reciprocal licenses: https://opensource.google/docs/thirdparty/licenses/#reciprocal - mozilla public license 1.1 (mpl 1.1) - mozilla public license 2.0 (mpl 2.0) - mpl-2.0 - - # LGPL - lgpl - gnu library or lesser general public license (lgpl) - gnu lesser general public license v3 or later (lgplv3+) - lgplv3+ - -unauthorized_licenses : - # Restricted licenses: https://opensource.google/docs/thirdparty/licenses/#restricted - gpl v1 - gpl v2 - gpl v3 - - # Banned licenses: https://opensource.google/docs/thirdparty/licenses/#banned - agpl - affero gpl - agpl (affero gpl) - sspl - -[Authorized Packages] -# MIT license: https://pypi.org/project/ordereddict/1.1/ -ordereddict: 1.1 - -# MIT license: https://pypi.org/project/pendulum/1.4.4/ -pendulum: >=1.4.4 - -# Python Imaging Library (PIL) License: https://github.com/python-pillow/Pillow/blob/master/LICENSE -Pillow: >=7.2.0 - -# Precipy was developed for the project, a license will be added to the repo shortly -precipy: >=0.2.2 - -# Apache 2.0 license, but this is not classified properly, see https://github.com/p1c2u/openapi-spec-validator/issues/138 -openapi-spec-validator: >=0.3.1,<1 - diff --git 
a/tests/fixtures/cli/my-workflows-project/my_workflows_project/dags/hello_world_dag.py b/tests/fixtures/cli/my-workflows-project/my_workflows_project/dags/hello_world_dag.py deleted file mode 100644 index 1d1580123..000000000 --- a/tests/fixtures/cli/my-workflows-project/my_workflows_project/dags/hello_world_dag.py +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:38993de5b31fa8138ef26129c3fd59f95c3aaf4427cce9b92109fad480168c7a -size 1379 diff --git a/tests/fixtures/cli/my-workflows-project/my_workflows_project/dags/my_dag.py b/tests/fixtures/cli/my-workflows-project/my_workflows_project/dags/my_dag.py deleted file mode 100644 index 0ff0b9197..000000000 --- a/tests/fixtures/cli/my-workflows-project/my_workflows_project/dags/my_dag.py +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:2fab4a2574c851fb4edb30fab83a6590f0f52789ab43672004348f2640706450 -size 1602 diff --git a/tests/fixtures/cli/my-workflows-project/my_workflows_project/workflows/__init__.py b/tests/fixtures/cli/my-workflows-project/my_workflows_project/workflows/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/tests/fixtures/cli/my-workflows-project/my_workflows_project/workflows/my_workflow.py b/tests/fixtures/cli/my-workflows-project/my_workflows_project/workflows/my_workflow.py deleted file mode 100644 index ecb2b3180..000000000 --- a/tests/fixtures/cli/my-workflows-project/my_workflows_project/workflows/my_workflow.py +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:f059e1729f91ab5d25731f5bb3f54e5d5f4060aa563c78a85c30a8ff996bbd3a -size 2015 diff --git a/tests/fixtures/cli/my-workflows-project/requirements.sh b/tests/fixtures/cli/my-workflows-project/requirements.sh deleted file mode 100644 index 5d7ece972..000000000 --- a/tests/fixtures/cli/my-workflows-project/requirements.sh +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:0a8090aa477268a3d0dfbd780a2115718805d08682d4c7385272910a22b62b31 -size 625 diff --git a/tests/fixtures/cli/my-workflows-project/requirements.txt b/tests/fixtures/cli/my-workflows-project/requirements.txt deleted file mode 100644 index 8f15056e6..000000000 --- a/tests/fixtures/cli/my-workflows-project/requirements.txt +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:197b0ff11cd452f01099e7386c8e45bc2d614da5a7bb8bc8daef5fdf8bf1fe83 -size 12 diff --git a/tests/fixtures/cli/my-workflows-project/setup.cfg b/tests/fixtures/cli/my-workflows-project/setup.cfg deleted file mode 100644 index 8a9328ddc..000000000 --- a/tests/fixtures/cli/my-workflows-project/setup.cfg +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:e26b071c2387715afc45911547e5e9737016fdec07ff3cefae69db7e72e29b8d -size 263 diff --git a/tests/fixtures/cli/my-workflows-project/setup.py b/tests/fixtures/cli/my-workflows-project/setup.py deleted file mode 100644 index c7e370a13..000000000 --- a/tests/fixtures/cli/my-workflows-project/setup.py +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:bd9f29ac52d3a23dab7932d6b4883d90ae3a710633c1d3f2f2f83652af9bfc1f -size 96 diff --git a/tests/fixtures/elastic/the-expanse-mappings.json b/tests/fixtures/elastic/the-expanse-mappings.json deleted file mode 100644 index 6ff9250ad..000000000 --- a/tests/fixtures/elastic/the-expanse-mappings.json +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid 
sha256:f513017ce70a4ca65e15f0878f2dd7a409d3b1aac39bf34edc05a5c92aa7d7aa -size 413 diff --git a/tests/fixtures/schemas/ao-author-mappings.json b/tests/fixtures/schemas/ao-author-mappings.json deleted file mode 100644 index 189570d74..000000000 --- a/tests/fixtures/schemas/ao-author-mappings.json +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:1ee125954bf07c3d6af7801e61246465acc647793459f30aceb17d80c1d607e4 -size 481 diff --git a/tests/fixtures/schemas/ao_author_2021-01-01.json b/tests/fixtures/schemas/ao_author_2021-01-01.json deleted file mode 100644 index 589df926c..000000000 --- a/tests/fixtures/schemas/ao_author_2021-01-01.json +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:dabe83093ff944c24b0f830ad8c68f4488d41a32925b99a60f69ec4302ee925d -size 334 diff --git a/tests/fixtures/utils/main.tf b/tests/fixtures/utils/main.tf deleted file mode 100644 index c8d27a0e4..000000000 --- a/tests/fixtures/utils/main.tf +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:9844632918eed3f18cc25c5dfae363f324d2967674e7f4c5504a69f66f997d73 -size 254 diff --git a/tests/fixtures/utils/test.csv b/tests/fixtures/utils/test.csv deleted file mode 100644 index 7caf5b9e9..000000000 --- a/tests/fixtures/utils/test.csv +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:75c322d4b384f2b59b288566d4cd6623eb14b72bfc7b7fbafe8f7fafa8e91dfe -size 29 diff --git a/tests/observatory/__init__.py b/tests/observatory/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/tests/observatory/api/__init__.py b/tests/observatory/api/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/tests/observatory/api/client/__init__.py b/tests/observatory/api/client/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/tests/observatory/api/client/test_dataset_release.py b/tests/observatory/api/client/test_dataset_release.py deleted file mode 100644 index b0f9564ea..000000000 --- a/tests/observatory/api/client/test_dataset_release.py +++ /dev/null @@ -1,78 +0,0 @@ -""" - Observatory API - - The REST API for managing and accessing data from the Observatory Platform. 
# noqa: E501 - - The version of the OpenAPI document: 1.0.0 - Contact: agent@observatory.academy - Generated by: https://openapi-generator.tech -""" - - -import datetime -import unittest - -from observatory.api.client.exceptions import ApiAttributeError, ApiTypeError -from observatory.api.client.model.dataset_release import DatasetRelease - - -class TestDatasetRelease(unittest.TestCase): - """DatasetRelease unit test stubs""" - - def testDatasetRelease(self): - """Test DatasetRelease""" - - class Configuration: - def __init__(self): - self.discard_unknown_keys = True - - dt = datetime.datetime.utcnow() - - # Successfully create - DatasetRelease( - id=1, - dag_id="doi_workflow", - dataset_id="doi", - dag_run_id="scheduled__2023-03-26T00:00:00+00:00", - data_interval_start=dt, - data_interval_end=dt, - partition_date=dt, - snapshot_date=dt, - changefile_start_date=dt, - changefile_end_date=dt, - sequence_start=1, - sequence_end=10, - extra={}, - ) - - # Created and modified are read only - with self.assertRaises(ApiAttributeError): - DatasetRelease( - id=1, - dag_id="doi_workflow", - dataset_id="doi", - partition_date=dt, - snapshot_date=dt, - start_date=dt, - end_date=dt, - sequence_num=1, - extra={}, - created=dt, - modified=dt, - ) - - # Invalid argument - with self.assertRaises(ApiTypeError): - DatasetRelease("hello") - - # Invalid keyword argument - with self.assertRaises(ApiAttributeError): - DatasetRelease(hello="world") - - self.assertRaises(ApiTypeError, DatasetRelease._from_openapi_data, "hello") - - DatasetRelease._from_openapi_data(hello="world", _configuration=Configuration()) - - -if __name__ == "__main__": - unittest.main() diff --git a/tests/observatory/api/client/test_observatory_api.py b/tests/observatory/api/client/test_observatory_api.py deleted file mode 100644 index d32d57030..000000000 --- a/tests/observatory/api/client/test_observatory_api.py +++ /dev/null @@ -1,295 +0,0 @@ -# Copyright 2021 Curtin University -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# Author: Aniek Roelofs, James Diprose - -""" - Observatory API - - The REST API for managing and accessing data from the Observatory Platform. 
# noqa: E501 - - The version of the OpenAPI document: 1.0.0 - Contact: agent@observatory.academy - Generated by: https://openapi-generator.tech -""" - -import unittest - -import pendulum - -import observatory.api.server.orm as orm -from observatory.api.client import ApiClient, Configuration -from observatory.api.client.api.observatory_api import ObservatoryApi # noqa: E501 -from observatory.api.client.exceptions import ( - NotFoundException, -) -from observatory.api.client.model.dataset_release import DatasetRelease -from observatory.api.testing import ObservatoryApiEnvironment -from observatory.platform.observatory_environment import find_free_port - - -class TestObservatoryApi(unittest.TestCase): - """ObservatoryApi unit test stubs""" - - def setUp(self): - self.timezone = "Pacific/Auckland" - self.host = "localhost" - self.port = find_free_port() - configuration = Configuration(host=f"http://{self.host}:{self.port}") - api_client = ApiClient(configuration) - self.api = ObservatoryApi(api_client=api_client) # noqa: E501 - self.env = ObservatoryApiEnvironment(host=self.host, port=self.port) - - def test_ctor(self): - api = ObservatoryApi() - self.assertTrue(api.api_client is not None) - - def test_get_dataset_release(self): - """Test case for get_dataset_release""" - - with self.env.create(): - # Not found - expected_id = 1 - with self.assertRaises(NotFoundException) as e: - self.api.get_dataset_release(id=expected_id) - self.assertEqual(404, e.exception.status) - self.assertEqual(f'"Not found: DatasetRelease with id {expected_id}"\n', e.exception.body) - - # Add DatasetRelease to database - dt = pendulum.now(self.timezone) - dt_utc = dt.in_tz(tz="UTC") - self.env.session.add( - orm.DatasetRelease( - dag_id="doi_workflow", - dataset_id="doi", - dag_run_id="scheduled__2023-03-26T00:00:00+00:00", - data_interval_start=dt, - data_interval_end=dt, - partition_date=dt, - snapshot_date=dt, - changefile_start_date=dt, - changefile_end_date=dt, - sequence_start=1, - sequence_end=10, - extra={"hello": "world"}, - created=dt, - modified=dt, - ) - ) - self.env.session.commit() - - # Assert that DatasetRelease with given id exists - obj = self.api.get_dataset_release(id=expected_id) - self.assertIsInstance(obj, DatasetRelease) - self.assertEqual(expected_id, obj.id) - self.assertEqual("doi_workflow", obj.dag_id) - self.assertEqual("doi", obj.dataset_id) - self.assertEqual("scheduled__2023-03-26T00:00:00+00:00", obj.dag_run_id) - self.assertEqual(dt_utc, obj.data_interval_start) - self.assertEqual(dt_utc, obj.data_interval_end) - self.assertEqual(dt_utc, obj.snapshot_date) - self.assertEqual(dt_utc, obj.partition_date) - self.assertEqual(dt_utc, obj.changefile_start_date) - self.assertEqual(dt_utc, obj.changefile_end_date) - self.assertEqual(1, obj.sequence_start) - self.assertEqual(10, obj.sequence_end) - self.assertEqual({"hello": "world"}, obj.extra) - self.assertEqual(dt_utc, obj.created) - self.assertEqual(dt_utc, obj.modified) - - # Search by dataset_id - obj = self.api.get_dataset_release(expected_id) - self.assertIsInstance(obj, DatasetRelease) - - # DatasetRelease not found - id = 2 - self.assertRaises(NotFoundException, self.api.get_dataset_release, id) - - def test_post_dataset_release(self): - """Test case for post_dataset_release""" - - with self.env.create(): - # Post DatasetRelease - expected_id = 1 - dt = pendulum.now(self.timezone) - dt_utc = dt.in_tz(tz="UTC") - obj = DatasetRelease( - dag_id="doi_workflow", - dataset_id="doi", - dag_run_id="scheduled__2023-03-26T00:00:00+00:00", - 
data_interval_start=dt, - data_interval_end=dt, - snapshot_date=dt, - partition_date=dt, - changefile_start_date=dt, - changefile_end_date=dt, - sequence_start=1, - sequence_end=10, - extra={"hello": "world"}, - ) - result = self.api.post_dataset_release(obj) - self.assertIsInstance(result, DatasetRelease) - self.assertEqual(expected_id, result.id) - self.assertEqual("doi_workflow", result.dag_id) - self.assertEqual("doi", result.dataset_id) - self.assertEqual("scheduled__2023-03-26T00:00:00+00:00", result.dag_run_id) - self.assertEqual(dt_utc, result.data_interval_start) - self.assertEqual(dt_utc, result.data_interval_end) - self.assertEqual(dt_utc, result.snapshot_date) - self.assertEqual(dt_utc, result.partition_date) - self.assertEqual(dt_utc, result.changefile_start_date) - self.assertEqual(dt_utc, result.changefile_end_date) - self.assertEqual(1, result.sequence_start) - self.assertEqual(10, result.sequence_end) - self.assertEqual({"hello": "world"}, result.extra) - - def test_put_dataset_release(self): - """Test case for put_dataset_release""" - - with self.env.create(): - # Put create - expected_id = 1 - dt = pendulum.now(self.timezone) - dt_utc = dt.in_tz(tz="UTC") - obj = DatasetRelease( - dag_id="doi_workflow", - dataset_id="doi", - dag_run_id="scheduled__2023-03-26T00:00:00+00:00", - data_interval_start=dt, - data_interval_end=dt, - snapshot_date=dt, - partition_date=dt, - changefile_start_date=dt, - changefile_end_date=dt, - sequence_start=1, - sequence_end=10, - extra={"hello": "world"}, - ) - result = self.api.put_dataset_release(obj) - self.assertIsInstance(result, DatasetRelease) - self.assertEqual(expected_id, result.id) - self.assertEqual("doi_workflow", result.dag_id) - self.assertEqual("doi", result.dataset_id) - self.assertEqual("scheduled__2023-03-26T00:00:00+00:00", result.dag_run_id) - self.assertEqual(dt_utc, result.data_interval_start) - self.assertEqual(dt_utc, result.data_interval_end) - self.assertEqual(dt_utc, result.snapshot_date) - self.assertEqual(dt_utc, result.partition_date) - self.assertEqual(dt_utc, result.changefile_start_date) - self.assertEqual(dt_utc, result.changefile_end_date) - self.assertEqual(1, result.sequence_start) - self.assertEqual(10, result.sequence_end) - self.assertEqual({"hello": "world"}, result.extra) - - # Put update - obj = DatasetRelease( - id=expected_id, - sequence_start=2, - ) - result = self.api.put_dataset_release(obj) - self.assertIsInstance(result, DatasetRelease) - self.assertEqual(expected_id, result.id) - self.assertEqual(2, result.sequence_start) - - # Put not found - expected_id = 2 - with self.assertRaises(NotFoundException) as e: - self.api.put_dataset_release( - DatasetRelease( - id=expected_id, - dag_id="doi_workflow", - dataset_id="doi", - dag_run_id="scheduled__2023-03-26T00:00:00+00:00", - data_interval_start=dt, - data_interval_end=dt, - snapshot_date=dt, - partition_date=dt, - changefile_start_date=dt, - changefile_end_date=dt, - sequence_start=1, - sequence_end=10, - ) - ) - self.assertEqual(404, e.exception.status) - self.assertEqual(f'"Not found: DatasetRelease with id {expected_id}"\n', e.exception.body) - - def test_get_dataset_releases(self): - """Test case for get_dataset_releases""" - - with self.env.create(): - # Post DatasetRelease - obj = DatasetRelease( - dag_id="doi_workflow", - dataset_id="doi", - dag_run_id="scheduled__2023-03-26T00:00:00+00:00", - snapshot_date=pendulum.datetime(2023, 1, 1), - ) - self.api.post_dataset_release(obj) - obj = DatasetRelease( - dag_id="doi_workflow", - 
dataset_id="doi", - dag_run_id="scheduled__2023-03-26T00:00:00+00:00", - snapshot_date=pendulum.datetime(2023, 1, 7), - ) - self.api.post_dataset_release(obj) - obj = DatasetRelease( - dag_id="doi_workflow", - dataset_id="author", - dag_run_id="scheduled__2023-03-26T00:00:00+00:00", - snapshot_date=pendulum.datetime(2023, 1, 7), - ) - self.api.post_dataset_release(obj) - - # Assert that all DatasetRelease objects returned - objects = self.api.get_dataset_releases(dag_id="doi_workflow", dataset_id="doi") - self.assertEqual(2, len(objects)) - self.assertIsInstance(objects[0], DatasetRelease) - self.assertIsInstance(objects[1], DatasetRelease) - - self.assertEqual(objects[0].id, 1) - self.assertEqual(objects[0].dag_id, "doi_workflow") - self.assertEqual(objects[0].dataset_id, "doi") - self.assertEqual(objects[1].id, 2) - self.assertEqual(objects[1].dag_id, "doi_workflow") - self.assertEqual(objects[1].dataset_id, "doi") - - objects = self.api.get_dataset_releases(dag_id="doi_workflow", dataset_id="author") - self.assertEqual(1, len(objects)) - self.assertIsInstance(objects[0], DatasetRelease) - self.assertEqual(objects[0].dag_id, "doi_workflow") - self.assertEqual(objects[0].dataset_id, "author") - - def test_delete_dataset_release(self): - """Test case for delete_dataset_release""" - - with self.env.create(): - # Post DatasetRelease - expected_id = 1 - obj = DatasetRelease( - dag_id="doi_workflow", - dataset_id="doi", - dag_run_id="scheduled__2023-03-26T00:00:00+00:00", - snapshot_date=pendulum.datetime(2023, 1, 1), - ) - self.api.post_dataset_release(obj) - self.api.delete_dataset_release(expected_id) - - with self.assertRaises(NotFoundException) as e: - self.api.delete_dataset_release(expected_id) - self.assertEqual(404, e.exception.status) - self.assertEqual(f'"Not found: DatasetRelease with id {expected_id}"\n', e.exception.body) - - -if __name__ == "__main__": - unittest.main() diff --git a/tests/observatory/api/server/__init__.py b/tests/observatory/api/server/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/tests/observatory/api/server/test_openapi.py b/tests/observatory/api/server/test_openapi.py deleted file mode 100644 index 7e21a899c..000000000 --- a/tests/observatory/api/server/test_openapi.py +++ /dev/null @@ -1,53 +0,0 @@ -# Copyright 2021 Curtin University -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -# Author: James Diprose - -import os -import unittest - -from click.testing import CliRunner -from openapi_spec_validator import validate_spec -from openapi_spec_validator.readers import read_from_filename - -from observatory.api.server.openapi_renderer import OpenApiRenderer -from observatory.platform.config import module_file_path - - -class TestOpenApiSchema(unittest.TestCase): - def setUp(self) -> None: - self.template_file = os.path.join(module_file_path("observatory.api.server"), "openapi.yaml.jinja2") - - def test_validate_backend(self): - """Test that the backend OpenAPI spec is valid""" - - renderer = OpenApiRenderer(self.template_file, api_client=False) - render = renderer.render() - self.validate_spec(render) - - def test_validate_api_client(self): - """Test that the API Client OpenAPI spec is valid""" - - renderer = OpenApiRenderer(self.template_file, api_client=True) - render = renderer.render() - self.validate_spec(render) - - def validate_spec(self, render: str): - with CliRunner().isolated_filesystem(): - file_name = "openapi.yaml" - with open(file_name, mode="w") as f: - f.write(render) - - spec_dict, spec_url = read_from_filename(file_name) - validate_spec(spec_dict) diff --git a/tests/observatory/api/server/test_orm.py b/tests/observatory/api/server/test_orm.py deleted file mode 100644 index cca507aaf..000000000 --- a/tests/observatory/api/server/test_orm.py +++ /dev/null @@ -1,203 +0,0 @@ -# Copyright 2021 Curtin University -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -# Author: James Diprose - -import unittest - -import pendulum -import sqlalchemy -from sqlalchemy.orm import scoped_session -from sqlalchemy.pool import StaticPool - -from observatory.api.server.orm import ( - DatasetRelease, - create_session, - fetch_db_object, - set_session, - to_datetime_utc, -) - - -class TestSession(unittest.TestCase): - def __init__(self, *args, **kwargs): - super(TestSession, self).__init__(*args, **kwargs) - self.uri = "sqlite://" - - def test_create_session(self): - """Test create_session and init_db""" - - # Create session with seed_db set to True - self.session = create_session(uri=self.uri, connect_args={"check_same_thread": False}, poolclass=StaticPool) - self.assertTrue(self.session.connection()) - - -class TestOrm(unittest.TestCase): - def __init__(self, *args, **kwargs): - super(TestOrm, self).__init__(*args, **kwargs) - self.uri = "sqlite://" - - def setUp(self) -> None: - """Create the SQLAlchemy Session""" - - self.session = create_session(uri=self.uri, connect_args={"check_same_thread": False}, poolclass=StaticPool) - set_session(self.session) - - def test_fetch_db_item(self): - """Test fetch_db_object""" - - # Body is None - self.assertEqual(None, fetch_db_object(DatasetRelease, None)) - - # Body is instance of cls - dt = pendulum.now("UTC") - dict_ = { - "dag_id": "doi_workflow", - "dataset_id": "doi", - "dag_run_id": "scheduled__2023-03-26T00:00:00+00:00", - "snapshot_date": dt, - "extra": {"hello": "world"}, - "modified": dt, - } - obj = DatasetRelease(**dict_) - self.assertEqual(obj, fetch_db_object(DatasetRelease, obj)) - - # Body is Dict and no id - with self.assertRaises(AttributeError): - fetch_db_object(DatasetRelease, {"dag_id": "doi_workflow"}) - - # Body is Dict, has id and is not found - with self.assertRaises(ValueError): - fetch_db_object(DatasetRelease, {"id": 1}) - - # Body is Dict, has id and is found - self.session.add(obj) - self.session.commit() - obj_fetched = fetch_db_object(DatasetRelease, {"id": 1}) - self.assertEqual(obj, obj_fetched) - - # Body is wrong type - with self.assertRaises(ValueError): - fetch_db_object(DatasetRelease, "hello") - - def test_to_datetime_utc(self): - """Test to_datetime""" - - # From datetime - dt_nz = pendulum.datetime(year=2020, month=12, day=31, tz=pendulum.timezone("Pacific/Auckland")) - dt_utc = pendulum.datetime(year=2020, month=12, day=30, hour=11, tz=pendulum.timezone("UTC")) - self.assertEqual(dt_utc, to_datetime_utc(dt_nz)) - self.assertEqual(dt_utc, to_datetime_utc(dt_utc)) - - # From None - self.assertIsNone(to_datetime_utc(None)) - - # From another type - with self.assertRaises(ValueError): - to_datetime_utc(dt_nz.date()) - - def test_create_session(self): - """Test that session is created""" - - self.assertIsInstance(self.session, scoped_session) - - # Assert ValueError because uri=None - with self.assertRaises(ValueError): - create_session(uri=None) - - # connect_args=None - create_session(uri=self.uri, connect_args=None, poolclass=StaticPool) - - def test_dataset_release(self): - """Test that DatasetRelease can be created, fetched, updates and deleted""" - - # Create DatasetRelease - created = pendulum.now("UTC") - dt = pendulum.now("UTC") - release = DatasetRelease( - dag_id="doi_workflow", - dataset_id="doi", - dag_run_id="scheduled__2023-03-26T00:00:00+00:00", - data_interval_start=dt, - data_interval_end=dt, - snapshot_date=dt, - changefile_start_date=dt, - changefile_end_date=dt, - sequence_start=1, - sequence_end=10, - extra={"hello": "world"}, - created=created, - ) - 
self.session.add(release) - self.session.commit() - - # Assert created object - expected_id = 1 - self.assertIsNotNone(release.id) - self.assertEqual(release.id, expected_id) - - # Update DatasetRelease - release = self.session.query(DatasetRelease).filter(DatasetRelease.id == expected_id).one() - release.snapshot_date = pendulum.datetime(1900, 1, 1) - self.session.commit() - - # Assert update - release = self.session.query(DatasetRelease).filter(DatasetRelease.id == expected_id).one() - self.assertEqual(pendulum.instance(release.snapshot_date), pendulum.datetime(1900, 1, 1)) - - # Delete items - self.session.query(DatasetRelease).filter(DatasetRelease.id == expected_id).delete() - with self.assertRaises(sqlalchemy.orm.exc.NoResultFound): - self.session.query(DatasetRelease).filter(DatasetRelease.id == expected_id).one() - - def test_dataset_release_from_dict(self): - """Test that DatasetRelease can be created from a dictionary""" - - # Create - expected_id = 1 - dt = pendulum.now("UTC") - dict_ = { - "dag_id": "doi_workflow", - "dataset_id": "doi", - "dag_run_id": "scheduled__2023-03-26T00:00:00+00:00", - "data_interval_start": dt, - "data_interval_end": dt, - "snapshot_date": dt, - "changefile_start_date": dt, - "changefile_end_date": dt, - "sequence_start": 1, - "sequence_end": 10, - "extra": {"hello": "world"}, - "modified": dt, - "created": dt, - } - obj = DatasetRelease(**dict_) - self.session.add(obj) - self.session.commit() - self.assertIsNotNone(obj.id) - self.assertEqual(obj.id, expected_id) - - # Update with no new values - obj.update(**{}) - self.session.commit() - self.assertEqual("doi_workflow", obj.dag_id) - self.assertEqual(dt, pendulum.instance(obj.snapshot_date)) - - # Update - dt = pendulum.now("UTC") - dict_ = {"dag_id": "doi_workflow_2", "snapshot_date": dt} - obj.update(**dict_) - self.session.commit() - self.assertEqual("doi_workflow_2", obj.dag_id) - self.assertEqual(dt, pendulum.instance(obj.snapshot_date)) diff --git a/tests/observatory/api/test_utils.py b/tests/observatory/api/test_utils.py deleted file mode 100644 index fd173fd40..000000000 --- a/tests/observatory/api/test_utils.py +++ /dev/null @@ -1,59 +0,0 @@ -# Copyright 2020 Curtin University -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-# -# -# Author: Tuan Chien - - -import os - -from observatory.api.client import ApiClient, Configuration -from observatory.api.client.api.observatory_api import ObservatoryApi # noqa: E501 -from observatory.api.testing import ObservatoryApiEnvironment -from observatory.api.utils import ( - get_api_client, -) -from observatory.platform.observatory_environment import ObservatoryTestCase -from observatory.platform.observatory_environment import find_free_port - - -class TestApiUtils(ObservatoryTestCase): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - - # API environment - self.host = "localhost" - self.port = find_free_port() - configuration = Configuration(host=f"http://{self.host}:{self.port}") - api_client = ApiClient(configuration) - self.api = ObservatoryApi(api_client=api_client) # noqa: E501 - self.env = ObservatoryApiEnvironment(host=self.host, port=self.port) - - def test_get_api_client(self): - # No env var set - api = get_api_client() - self.assertEqual(api.api_client.configuration.host, "http://localhost:5002") - self.assertEqual(api.api_client.configuration.api_key, {}) - - # Environment variable set - os.environ["API_URI"] = "http://testhost:5002" - api = get_api_client() - self.assertEqual(api.api_client.configuration.host, "http://testhost:5002") - self.assertEqual(api.api_client.configuration.api_key, {"api_key": None}) - del os.environ["API_URI"] - - # Pass in arguments - api = get_api_client(host="host1") - self.assertEqual(api.api_client.configuration.host, "http://host1:5002") - self.assertEqual(api.api_client.configuration.api_key, {}) diff --git a/tests/observatory/platform/__init__.py b/tests/observatory/platform/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/tests/observatory/platform/cli/__init__.py b/tests/observatory/platform/cli/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/tests/observatory/platform/cli/test_cli.py b/tests/observatory/platform/cli/test_cli.py deleted file mode 100644 index 5cd4fe6f7..000000000 --- a/tests/observatory/platform/cli/test_cli.py +++ /dev/null @@ -1,727 +0,0 @@ -# Copyright 2020 Curtin University -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -# Author: James Diprose, Aniek Roelofs - -import json -import os -import unittest -from typing import Any, List -from unittest.mock import Mock, patch - -from click.testing import CliRunner - -from observatory.platform.cli.cli import ( - LOCAL_CONFIG_PATH, - TERRAFORM_CONFIG_PATH, - cli, -) -from observatory.platform.docker.compose_runner import ProcessOutput -from observatory.platform.docker.platform_runner import DEBUG, HOST_UID -from observatory.platform.observatory_config import TerraformConfig, ValidationError -from observatory.platform.observatory_environment import random_id -from observatory.platform.terraform.terraform_api import TerraformApi - - -class TestObservatoryGenerate(unittest.TestCase): - @patch("click.confirm") - @patch("os.path.exists") - def test_generate_secrets(self, mock_path_exists, mock_click_confirm): - """Test that the fernet key and default config files are generated""" - - # Test generate fernet key - runner = CliRunner() - result = runner.invoke(cli, ["generate", "secrets", "fernet-key"]) - self.assertEqual(result.exit_code, os.EX_OK) - - result = runner.invoke(cli, ["generate", "secrets", "secret-key"]) - self.assertEqual(result.exit_code, os.EX_OK) - - # Test that files are generated - with runner.isolated_filesystem(): - mock_click_confirm.return_value = True - mock_path_exists.return_value = False - - # Test generate local config - config_path = os.path.abspath("config.yaml") - result = runner.invoke(cli, ["generate", "config", "local", "--config-path", config_path]) - self.assertEqual(result.exit_code, os.EX_OK) - self.assertTrue(os.path.isfile(config_path)) - self.assertIn("Observatory Config saved to:", result.output) - - # Test generate terraform config - config_path = os.path.abspath("config-terraform.yaml") - result = runner.invoke(cli, ["generate", "config", "terraform", "--config-path", config_path]) - self.assertEqual(result.exit_code, os.EX_OK) - self.assertTrue(os.path.isfile(config_path)) - self.assertIn("Terraform Config saved to:", result.output) - - # Test that files are not generated when confirm is set to n - runner = CliRunner() - with runner.isolated_filesystem(): - mock_click_confirm.return_value = False - mock_path_exists.return_value = True - - # Test generate local config - config_path = os.path.abspath("config.yaml") - result = runner.invoke(cli, ["generate", "config", "local", "--config-path", config_path]) - self.assertEqual(result.exit_code, os.EX_OK) - self.assertFalse(os.path.isfile(config_path)) - self.assertIn("Not generating Observatory config\n", result.output) - - # Test generate terraform config - config_path = os.path.abspath("config-terraform.yaml") - result = runner.invoke(cli, ["generate", "config", "terraform", "--config-path", config_path]) - self.assertEqual(result.exit_code, os.EX_OK) - self.assertFalse(os.path.isfile(config_path)) - self.assertIn("Not generating Terraform config\n", result.output) - - @patch("observatory.platform.cli.cli.click.confirm") - @patch("observatory.platform.cli.cli.GenerateCommand.generate_terraform_config_interactive") - @patch("observatory.platform.cli.cli.GenerateCommand.generate_local_config_interactive") - def test_generate_default_configs(self, m_gen_config, m_gen_terra, m_click): - m_click.return_value = True - runner = CliRunner() - - # Default local - result = runner.invoke(cli, ["generate", "config", "local", "--interactive"]) - self.assertEqual(result.exit_code, os.EX_OK) - self.assertEqual(m_gen_config.call_count, 1) - 
self.assertEqual(m_gen_config.call_args.kwargs["workflows"], []) - self.assertEqual(m_gen_config.call_args.args[0], LOCAL_CONFIG_PATH) - - # Default terraform - result = runner.invoke(cli, ["generate", "config", "terraform", "--interactive"]) - self.assertEqual(result.exit_code, os.EX_OK) - self.assertEqual(m_gen_terra.call_count, 1) - self.assertEqual(m_gen_terra.call_args.kwargs["workflows"], []) - self.assertEqual(m_gen_terra.call_args.args[0], TERRAFORM_CONFIG_PATH) - - @patch("observatory.platform.cli.cli.GenerateCommand.generate_local_config_interactive") - def test_generate_local_interactive(self, m_gen_config): - runner = CliRunner() - with runner.isolated_filesystem(): - config_path = os.path.abspath("config.yaml") - result = runner.invoke(cli, ["generate", "config", "local", "--config-path", config_path, "--interactive"]) - self.assertEqual(result.exit_code, os.EX_OK) - self.assertEqual(m_gen_config.call_count, 1) - self.assertEqual(m_gen_config.call_args.kwargs["workflows"], []) - self.assertEqual(m_gen_config.call_args.args[0], config_path) - - @patch("observatory.platform.cli.cli.GenerateCommand.generate_local_config_interactive") - def test_generate_local_interactive_install_oworkflows(self, m_gen_config): - runner = CliRunner() - with runner.isolated_filesystem(): - config_path = os.path.abspath("config.yaml") - result = runner.invoke( - cli, - args=[ - "generate", - "config", - "local", - "--config-path", - config_path, - "--interactive", - "--ao-wf", - "--oaebu-wf", - ], - ) - self.assertEqual(result.exit_code, os.EX_OK) - self.assertEqual(m_gen_config.call_count, 1) - self.assertEqual( - m_gen_config.call_args.kwargs["workflows"], ["academic-observatory-workflows", "oaebu-workflows"] - ) - self.assertEqual(m_gen_config.call_args.args[0], config_path) - - @patch("observatory.platform.cli.cli.GenerateCommand.generate_terraform_config_interactive") - def test_generate_terraform_interactive(self, m_gen_config): - runner = CliRunner() - with runner.isolated_filesystem(): - config_path = os.path.abspath("config.yaml") - result = runner.invoke( - cli, ["generate", "config", "terraform", "--config-path", config_path, "--interactive"] - ) - self.assertEqual(result.exit_code, os.EX_OK) - self.assertEqual(m_gen_config.call_count, 1) - self.assertEqual(m_gen_config.call_args.kwargs["workflows"], []) - self.assertEqual(m_gen_config.call_args.args[0], config_path) - - @patch("observatory.platform.cli.cli.GenerateCommand.generate_terraform_config_interactive") - def test_generate_terraform_interactive_install_oworkflows(self, m_gen_config): - runner = CliRunner() - with runner.isolated_filesystem(): - config_path = os.path.abspath("config.yaml") - result = runner.invoke( - cli, - [ - "generate", - "config", - "terraform", - "--config-path", - config_path, - "--interactive", - "--ao-wf", - "--oaebu-wf", - ], - ) - self.assertEqual(result.exit_code, os.EX_OK) - self.assertEqual(m_gen_config.call_count, 1) - self.assertEqual( - m_gen_config.call_args.kwargs["workflows"], ["academic-observatory-workflows", "oaebu-workflows"] - ) - self.assertEqual(m_gen_config.call_args.args[0], config_path) - - -class MockConfig(Mock): - def __init__(self, is_valid: bool, errors: List[ValidationError] = None, **kwargs: Any): - super().__init__(**kwargs) - self._is_valid = is_valid - self._errors = errors - - @property - def observatory(self): - mock = Mock() - mock.observatory_home = "/path/to/home" - return mock - - @property - def is_valid(self): - return self._is_valid - - @property - def errors(self): 
- return self._errors - - -class MockPlatformCommand(Mock): - def __init__( - self, - config: MockConfig, - *, - is_environment_valid: bool, - docker_exe_path: str, - is_docker_running: bool, - docker_compose: bool, - build_return_code: int, - start_return_code: int, - stop_return_code: int, - wait_for_airflow_ui: bool, - dags_path: str, - **kwargs, - ): - super().__init__(**kwargs) - self.config = config - self.is_environment_valid = is_environment_valid - self.docker_exe_path = docker_exe_path - self.is_docker_running = is_docker_running - self.docker_compose = docker_compose - self.host_uid = HOST_UID - self.debug = DEBUG - self.dags_path = dags_path - self._build_return_code = build_return_code - self._start_return_code = start_return_code - self._stop_return_code = stop_return_code - self._wait_for_airflow_ui = wait_for_airflow_ui - - def make_files(self): - pass - - @property - def ui_url(self): - return "http://localhost:8080" - - def build(self): - return ProcessOutput("output", "error", self._build_return_code) - - def start(self): - return ProcessOutput("output", "error", self._start_return_code) - - def stop(self): - return ProcessOutput("output", "error", self._stop_return_code) - - def wait_for_airflow_ui(self, timeout: int = 60): - return self._wait_for_airflow_ui - - -class TestObservatoryPlatform(unittest.TestCase): - @patch("observatory.platform.cli.cli.ObservatoryConfig.load") - @patch("observatory.platform.cli.cli.PlatformCommand") - def test_platform_start_stop_success(self, mock_cmd, mock_config): - """Test that the start and stop command are successful""" - - runner = CliRunner() - with runner.isolated_filesystem() as t: - # Make empty config - config_path = os.path.join(t, "config.yaml") - - # Mock platform command - is_environment_valid = True - docker_exe_path = "/path/to/docker" - is_docker_running = True - docker_compose = True - config = MockConfig(is_valid=True) - mock_config.return_value = config - build_return_code = 0 - start_return_code = 0 - stop_return_code = 0 - wait_for_airflow_ui = True - dags_path = "/path/to/dags" - mock_cmd.return_value = MockPlatformCommand( - config, - is_environment_valid=is_environment_valid, - docker_exe_path=docker_exe_path, - is_docker_running=is_docker_running, - docker_compose=docker_compose, - build_return_code=build_return_code, - start_return_code=start_return_code, - stop_return_code=stop_return_code, - wait_for_airflow_ui=wait_for_airflow_ui, - dags_path=dags_path, - ) - - # Make empty config - open(config_path, "a").close() - - # Test that start command works - result = runner.invoke(cli, ["platform", "start", "--config-path", config_path]) - self.assertEqual(result.exit_code, os.EX_OK) - - # Test that stop command works - result = runner.invoke(cli, ["platform", "stop", "--config-path", config_path]) - self.assertEqual(result.exit_code, os.EX_OK) - - @patch("observatory.platform.cli.cli.PlatformCommand") - def test_platform_start_fail_generate(self, mock_cmd): - """Check that no config file generates an error""" - - runner = CliRunner() - with runner.isolated_filesystem() as t: - # Environment invalid, no Docker, Docker not running, no Docker Compose, no config file - default_config_path = os.path.join(t, "config.yaml") - is_environment_valid = False - docker_exe_path = None - is_docker_running = False - docker_compose = False - config = None - build_return_code = 0 - start_return_code = 0 - stop_return_code = 0 - wait_for_airflow_ui = True - dags_path = "/path/to/dags" - mock_cmd.return_value = MockPlatformCommand( 
- config, - is_environment_valid=is_environment_valid, - docker_exe_path=docker_exe_path, - is_docker_running=is_docker_running, - docker_compose=docker_compose, - build_return_code=build_return_code, - start_return_code=start_return_code, - stop_return_code=stop_return_code, - wait_for_airflow_ui=wait_for_airflow_ui, - dags_path=dags_path, - ) - - # Test that start command fails - result = runner.invoke(cli, ["platform", "start", "--config-path", default_config_path]) - self.assertEqual(result.exit_code, os.EX_CONFIG) - - # config.yaml - self.assertIn("- file not found, generating a default file", result.output) - self.assertTrue(os.path.isfile(default_config_path)) - - # Check return code - self.assertEqual(result.exit_code, os.EX_CONFIG) - - @patch("observatory.platform.cli.cli.ObservatoryConfig.load") - @patch("observatory.platform.cli.cli.PlatformCommand") - def test_platform_start_fail_docker_install_errors(self, mock_cmd, mock_config): - """Test that docker and docker compose not installed errors show up""" - - runner = CliRunner() - with runner.isolated_filesystem() as t: - # Environment invalid, no Docker, Docker not running, no Docker Compose, no config file - default_config_path = os.path.join(t, "config.yaml") - is_environment_valid = False - docker_exe_path = None - is_docker_running = False - docker_compose = False - config = MockConfig(is_valid=True) - mock_config.return_value = config - build_return_code = 0 - start_return_code = 0 - stop_return_code = 0 - wait_for_airflow_ui = True - dags_path = "/path/to/dags" - mock_cmd.return_value = MockPlatformCommand( - config, - is_environment_valid=is_environment_valid, - docker_exe_path=docker_exe_path, - is_docker_running=is_docker_running, - docker_compose=docker_compose, - build_return_code=build_return_code, - start_return_code=start_return_code, - stop_return_code=stop_return_code, - wait_for_airflow_ui=wait_for_airflow_ui, - dags_path=dags_path, - ) - - # Make empty config - open(default_config_path, "a").close() - - # Test that start command fails - result = runner.invoke(cli, ["platform", "start", "--config-path", default_config_path]) - self.assertEqual(result.exit_code, os.EX_CONFIG) - - # Docker not installed - self.assertIn("https://docs.docker.com/get-docker/", result.output) - - # Docker Compose not installed - self.assertIn("https://docs.docker.com/compose/install/", result.output) - - # Check return code - self.assertEqual(result.exit_code, os.EX_CONFIG) - - @patch("observatory.platform.cli.cli.ObservatoryConfig.load") - @patch("observatory.platform.cli.cli.PlatformCommand") - def test_platform_start_fail_docker_run_errors(self, mock_cmd, mock_config): - """Test that error message is printed when Docker is installed but not running""" - - runner = CliRunner() - with runner.isolated_filesystem() as t: - # Environment invalid, Docker installed but not running - default_config_path = os.path.join(t, "config.yaml") - is_environment_valid = False - docker_exe_path = "/path/to/docker" - is_docker_running = False - docker_compose = True - config = MockConfig(is_valid=True) - mock_config.return_value = config - build_return_code = 0 - start_return_code = 0 - stop_return_code = 0 - wait_for_airflow_ui = True - dags_path = "/path/to/dags" - mock_cmd.return_value = MockPlatformCommand( - config, - is_environment_valid=is_environment_valid, - docker_exe_path=docker_exe_path, - is_docker_running=is_docker_running, - docker_compose=docker_compose, - build_return_code=build_return_code, - start_return_code=start_return_code, - 
stop_return_code=stop_return_code, - wait_for_airflow_ui=wait_for_airflow_ui, - dags_path=dags_path, - ) - - # Make empty config - open(default_config_path, "a").close() - - # Test that start command fails - result = runner.invoke(cli, ["platform", "start", "--config-path", default_config_path]) - self.assertEqual(result.exit_code, os.EX_CONFIG) - - # Check that Docker is not running message printed - self.assertIn("not running, please start", result.output) - - # Check return code - self.assertEqual(result.exit_code, os.EX_CONFIG) - - @patch("observatory.platform.cli.cli.ObservatoryConfig.load") - @patch("observatory.platform.cli.cli.PlatformCommand") - def test_platform_start_fail_invalid_config_errors(self, mock_cmd, mock_config): - """Test that invalid config errors show up""" - - runner = CliRunner() - with runner.isolated_filesystem() as t: - # Environment invalid, Docker installed but not running - default_config_path = os.path.join(t, "config.yaml") - is_environment_valid = False - docker_exe_path = "/path/to/docker" - is_docker_running = False - docker_compose = True - validation_error = ValidationError("google_cloud.credentials", "required field") - config = MockConfig(is_valid=False, errors=[validation_error]) - mock_config.return_value = config - build_return_code = 0 - start_return_code = 0 - stop_return_code = 0 - wait_for_airflow_ui = True - dags_path = "/path/to/dags" - mock_cmd.return_value = MockPlatformCommand( - config, - is_environment_valid=is_environment_valid, - docker_exe_path=docker_exe_path, - is_docker_running=is_docker_running, - docker_compose=docker_compose, - build_return_code=build_return_code, - start_return_code=start_return_code, - stop_return_code=stop_return_code, - wait_for_airflow_ui=wait_for_airflow_ui, - dags_path=dags_path, - ) - - # Make empty config - open(default_config_path, "a").close() - - # Test that start command fails - result = runner.invoke(cli, ["platform", "start", "--config-path", default_config_path]) - self.assertEqual(result.exit_code, os.EX_CONFIG) - - # Check that google credentials file does not exist is printed - self.assertIn(f"google_cloud.credentials: required field", result.output) - - # Check return code - self.assertEqual(result.exit_code, os.EX_CONFIG) - - -class TestObservatoryTerraform(unittest.TestCase): - organisation = os.getenv("TEST_TERRAFORM_ORGANISATION") - token = os.getenv("TEST_TERRAFORM_TOKEN") - terraform_api = TerraformApi(token) - version = TerraformApi.TERRAFORM_WORKSPACE_VERSION - description = "test" - - @patch("click.confirm") - @patch("observatory.platform.observatory_config.TerraformConfig.load") - def test_terraform_create_update(self, mock_load_config, mock_click_confirm): - """Test creating and updating a terraform cloud workspace""" - - # Create token json - token_json = {"credentials": {"app.terraform.io": {"token": self.token}}} - runner = CliRunner() - with runner.isolated_filesystem() as working_dir: - # File paths - terraform_credentials_path = os.path.join(working_dir, "token.json") - config_file_path = os.path.join(working_dir, "config-terraform.yaml") - credentials_file_path = os.path.join(working_dir, "google_application_credentials.json") - TerraformConfig.WORKSPACE_PREFIX = random_id() + "-" - - # Create token file - with open(terraform_credentials_path, "w") as f: - json.dump(token_json, f) - - # Make a fake google application credentials as it is required schema validation - with open(credentials_file_path, "w") as f: - f.write("") - - # Make a fake config-terraform.yaml file - 
with open(config_file_path, "w") as f: - f.write("") - - # Create config instance - config = TerraformConfig.from_dict( - { - "backend": {"type": "terraform", "environment": "develop"}, - "observatory": { - "package": "observatory-platform", - "package_type": "pypi", - "airflow_fernet_key": "IWt5jFGSw2MD1shTdwzLPTFO16G8iEAU3A6mGo_vJTY=", - "airflow_secret_key": "a" * 16, - "airflow_ui_user_password": "password", - "airflow_ui_user_email": "password", - "postgres_password": "my-password", - }, - "terraform": {"organization": self.organisation}, - "google_cloud": { - "project_id": "my-project", - "credentials": credentials_file_path, - "region": "us-west1", - "zone": "us-west1-c", - "data_location": "us", - }, - "cloud_sql_database": {"tier": "db-custom-2-7680", "backup_start_time": "23:00"}, - "airflow_main_vm": { - "machine_type": "n2-standard-2", - "disk_size": 1, - "disk_type": "pd-ssd", - "create": True, - }, - "airflow_worker_vm": { - "machine_type": "n2-standard-2", - "disk_size": 1, - "disk_type": "pd-standard", - "create": False, - }, - } - ) - - self.assertTrue(config.is_valid) - mock_load_config.return_value = config - - # Create terraform api instance - terraform_api = TerraformApi(self.token) - workspace = TerraformConfig.WORKSPACE_PREFIX + config.backend.environment.value - - # As a safety measure, delete workspace even though it shouldn't exist yet - terraform_api.delete_workspace(self.organisation, workspace) - - # Create workspace, confirm yes - mock_click_confirm.return_value = "y" - result = runner.invoke( - cli, - [ - "terraform", - "create-workspace", - config_file_path, - "--terraform-credentials-path", - terraform_credentials_path, - ], - ) - self.assertIn("Successfully created workspace", result.output) - - # Create workspace, confirm no - mock_click_confirm.return_value = False - result = runner.invoke( - cli, - [ - "terraform", - "create-workspace", - config_file_path, - "--terraform-credentials-path", - terraform_credentials_path, - ], - ) - self.assertNotIn("Creating workspace...", result.output) - - # Update workspace, same config file but sensitive values will be replaced - mock_click_confirm.return_value = "y" - result = runner.invoke( - cli, - [ - "terraform", - "update-workspace", - config_file_path, - "--terraform-credentials-path", - terraform_credentials_path, - ], - ) - self.assertIn("Successfully updated workspace", result.output) - - # Update workspace, confirm no - mock_click_confirm.return_value = False - result = runner.invoke( - cli, - [ - "terraform", - "update-workspace", - config_file_path, - "--terraform-credentials-path", - terraform_credentials_path, - ], - ) - self.assertNotIn("Updating workspace...", result.output) - - # Delete workspace - terraform_api.delete_workspace(self.organisation, workspace) - - @patch("observatory.platform.observatory_config.TerraformConfig.load") - def test_terraform_check_dependencies(self, mock_load_config): - """Test that checking for dependencies prints the correct output when files are missing""" - runner = CliRunner() - with runner.isolated_filesystem() as working_dir: - credentials_file_path = os.path.join(working_dir, "google_application_credentials.json") - TerraformConfig.WORKSPACE_PREFIX = random_id() + "-" - - # No config file should exist because we are in a new isolated filesystem - config_file_path = os.path.join(working_dir, "config-terraform.yaml") - terraform_credentials_path = os.path.join(working_dir, "terraform-creds.yaml") - - # Check that correct exit code and output are returned - result = 
runner.invoke( - cli, - [ - "terraform", - "create-workspace", - config_file_path, - "--terraform-credentials-path", - terraform_credentials_path, - ], - ) - - # No config file - self.assertIn( - f"Error: Invalid value for 'CONFIG_PATH': File '{config_file_path}' does not exist.", result.output - ) - - # Check return code, exit from click invalid option - self.assertEqual(result.exit_code, 2) - - # Make a fake config-terraform.yaml file - with open(config_file_path, "w") as f: - f.write("") - - # Make a fake google credentials file - with open(credentials_file_path, "w") as f: - f.write("") - - # Create config instance - config = TerraformConfig.from_dict( - { - "backend": {"type": "terraform", "environment": "develop"}, - "observatory": { - "package": "observatory-platform", - "package_type": "pypi", - "airflow_fernet_key": "IWt5jFGSw2MD1shTdwzLPTFO16G8iEAU3A6mGo_vJTY=", - "airflow_secret_key": "a" * 16, - "airflow_ui_user_password": "password", - "airflow_ui_user_email": "password", - "postgres_password": "my-password", - }, - "terraform": {"organization": self.organisation}, - "google_cloud": { - "project_id": "my-project", - "credentials": credentials_file_path, - "region": "us-west1", - "zone": "us-west1-c", - "data_location": "us", - }, - "cloud_sql_database": {"tier": "db-custom-2-7680", "backup_start_time": "23:00"}, - "airflow_main_vm": { - "machine_type": "n2-standard-2", - "disk_size": 1, - "disk_type": "pd-ssd", - "create": True, - }, - "airflow_worker_vm": { - "machine_type": "n2-standard-2", - "disk_size": 1, - "disk_type": "pd-standard", - "create": False, - }, - } - ) - mock_load_config.return_value = config - - # Run again with existing config, specifying terraform files that don't exist. Check that correct exit - # code and output are returned - result = runner.invoke( - cli, - [ - "terraform", - "create-workspace", - config_file_path, - "--terraform-credentials-path", - terraform_credentials_path, - ], - ) - - # No terraform credentials file - self.assertIn( - "Terraform credentials file:\n - file not found, create one by running 'terraform login'", - result.output, - ) - - # Check return code - self.assertEqual(result.exit_code, os.EX_CONFIG) diff --git a/tests/observatory/platform/cli/test_cli_functional.py b/tests/observatory/platform/cli/test_cli_functional.py deleted file mode 100644 index 0c19ead26..000000000 --- a/tests/observatory/platform/cli/test_cli_functional.py +++ /dev/null @@ -1,474 +0,0 @@ -# Copyright 2021-2023 Curtin University -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -# Author: James Diprose - -import glob -import json -import logging -import os -import shutil -import subprocess -import time -import unittest -import uuid -from subprocess import Popen -from typing import Set -from unittest.mock import patch - -import requests -import stringcase -from click.testing import CliRunner -from cryptography.fernet import Fernet -from redis import Redis - -from observatory.platform.cli.cli import cli -from observatory.platform.observatory_config import ( - ObservatoryConfig, - Backend, - Observatory, - BackendType, - Environment, - WorkflowsProject, - Workflow, -) -from observatory.platform.observatory_environment import ( - test_fixtures_path, - find_free_port, - save_empty_file, - module_file_path, -) -from observatory.platform.utils.proc_utils import stream_process -from observatory.platform.utils.url_utils import wait_for_url - - -def list_dag_ids( - host: str = "http://localhost", port: int = None, user: str = "airflow@airflow.com", pwd: str = "airflow" -) -> Set: - """List the DAG ids that have been loaded in an Airflow instance. - - :param host: the hostname. - :param port: the port. - :param user: the username. - :param pwd: the password. - :return: the set of DAG ids. - """ - - parts = [host] - if port is not None: - parts.append(f":{port}") - parts.append("/api/v1/dags") - url = "".join(parts) - - dag_ids = [] - response = requests.get(url, headers={"Content-Type": "application/json"}, auth=(user, pwd)) - if response.status_code == 200: - dags = json.loads(response.text)["dags"] - dag_ids = [dag["dag_id"] for dag in dags] - - return set(dag_ids) - - -def build_sdist(package_path: str) -> str: - """Build a Python source distribution and return the path to the tar file. - - :param package_path: - :return: - """ - - # Remove dist directory - build_dir = os.path.join(package_path, "dist") - shutil.rmtree(build_dir, ignore_errors=True) - - # Set PBR version - env = os.environ.copy() - env["PBR_VERSION"] = "0.0.1" - - proc: Popen = Popen( - ["python3", "setup.py", "sdist"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, cwd=package_path, env=env - ) - output, error = stream_process(proc, True) - assert proc.returncode == 0, f"build_sdist failed: {package_path}" - - # Get path to sdist - results = glob.glob(os.path.join(build_dir, "*.tar.gz")) - return results[0] - - -class TestCliFunctional(unittest.TestCase): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self.observatory_api_path = module_file_path("observatory.api", nav_back_steps=-3) - self.observatory_api_package_name = "observatory-api" - self.observatory_platform_path = module_file_path("observatory.platform", nav_back_steps=-3) - self.observatory_platform_package_name = "observatory-platform" - self.airflow_ui_user_email = "airflow@airflow.com" - self.airflow_ui_user_password = "airflow" - self.airflow_fernet_key = Fernet.generate_key() - self.airflow_secret_key = uuid.uuid4().hex - self.docker_network_name = "observatory-unit-test-network" - self.docker_compose_project_name = "observatory_unit_test" - self.expected_platform_dag_ids = {"dummy_telescope", "vm_create", "vm_destroy"} - self.expected_workflows_dag_ids = { - "dummy_telescope", - "vm_create", - "vm_destroy", - "my_dag", - "hello_world_dag", - "my_workflow", - } - self.start_cmd = ["platform", "start", "--debug", "--config-path"] - self.stop_cmd = ["platform", "stop", "--debug", "--config-path"] - self.workflows_package_name = "my-workflows-project" - self.dag_check_timeout = 360 - self.port_wait_timeout = 
360 - self.config_file_name = "config.yaml" - - self.platform_workflows = [ - Workflow( - dag_id="vm_create", - name="VM Create Workflow", - class_name="observatory.platform.workflows.vm_workflow.VmCreateWorkflow", - kwargs=dict(terraform_organisation="terraform_organisation", terraform_workspace="terraform_workspace"), - ), - Workflow( - dag_id="vm_destroy", - name="VM Destroy Workflow", - class_name="observatory.platform.workflows.vm_workflow.VmDestroyWorkflow", - kwargs=dict( - terraform_organisation="terraform_organisation", - terraform_workspace="terraform_workspace", - dags_watch_list=["dag1", "dag2"], - ), - ), - ] - self.all_workflows = self.platform_workflows + [ - Workflow( - dag_id="my_workflow", - name="My Workflow", - class_name="my_workflows_project.workflows.my_workflow.MyWorkflow", - ), - ] - - def make_editable_observatory_config(self, temp_dir: str) -> ObservatoryConfig: - """Make an editable observatory config. - - :param temp_dir: the temp dir. - :return: ObservatoryConfig. - """ - - return ObservatoryConfig( - backend=Backend(type=BackendType.local, environment=Environment.develop), - observatory=Observatory( - package=os.path.join(temp_dir, self.observatory_platform_package_name), - package_type="editable", - airflow_fernet_key=self.airflow_fernet_key, - airflow_secret_key=self.airflow_secret_key, - airflow_ui_user_email=self.airflow_ui_user_email, - airflow_ui_user_password=self.airflow_ui_user_password, - observatory_home=temp_dir, - redis_port=find_free_port(), - flower_ui_port=find_free_port(), - airflow_ui_port=find_free_port(), - api_port=find_free_port(), - docker_network_name=self.docker_network_name, - docker_compose_project_name=self.docker_compose_project_name, - api_package_type="editable", - api_package=os.path.join(temp_dir, self.observatory_api_package_name), - ), - workflows=self.platform_workflows, - ) - - def copy_observatory_api(self, temp_dir: str): - """Copy the workflows project to the test dir""" - - shutil.copytree(self.observatory_api_path, os.path.join(temp_dir, self.observatory_api_package_name)) - - def copy_observatory_platform(self, temp_dir: str): - """Copy the workflows project to the test dir""" - - shutil.copytree(self.observatory_platform_path, os.path.join(temp_dir, self.observatory_platform_package_name)) - - def copy_workflows_project(self, temp_dir: str): - """Copy the workflows project to the test dir""" - - shutil.copytree( - test_fixtures_path("cli", self.workflows_package_name), os.path.join(temp_dir, self.workflows_package_name) - ) - - def assert_dags_loaded(self, expected_dag_ids: Set, config: ObservatoryConfig, dag_check_timeout: int = 180): - """Assert that DAGs loaded into Airflow. - - :param expected_dag_ids: the expected DAG ids. - :param config: the Observatory Config. - :param dag_check_timeout: how long to check for DAGs. - :return: None. - """ - - start = time.time() - while True: - duration = time.time() - start - actual_dag_ids = list_dag_ids( - port=config.observatory.airflow_ui_port, - user=config.observatory.airflow_ui_user_email, - pwd=config.observatory.airflow_ui_user_password, - ) - if expected_dag_ids == actual_dag_ids or duration > dag_check_timeout: - break - self.assertSetEqual(expected_dag_ids, actual_dag_ids) - - def assert_ports_open(self, observatory: Observatory, timeout: int = 180): - """Check that the ports given in the observatory object are accepting connections. - - :param observatory: the observatory object. - :param timeout: the length of time to wait until timing out. - :return: None. 
- """ - - # Expected values - expected_ports = [ - observatory.airflow_ui_port, - observatory.flower_ui_port, - f"{observatory.api_port}/swagger.json", - ] - - # Verify that ports are active - urls = [] - states = [] - for port in expected_ports: - url = f"http://localhost:{port}" - urls.append(url) - logging.info(f"Waiting for URL: {url}") - state = wait_for_url(url, timeout=timeout) - logging.info(f"URL {url} state: {state}") - states.append(state) - - # Assert states - for state in states: - self.assertTrue(state) - - # Check if Redis is active - redis = Redis(port=observatory.redis_port, socket_connect_timeout=1) - self.assertTrue(redis.ping()) - - @patch("observatory.platform.cli.cli.ObservatoryConfig.load") - def test_run_platform_editable(self, mock_config_load): - """Test that the platform runs when built from an editable project.""" - - runner = CliRunner() - with runner.isolated_filesystem() as t: - # Save empty config - config_path = save_empty_file(t, self.config_file_name) - - # Copy platform project - self.copy_observatory_api(t) - self.copy_observatory_platform(t) - - # Make config object - config = self.make_editable_observatory_config(t) - mock_config_load.return_value = config - - try: - # Test that start command works - result = runner.invoke(cli, self.start_cmd + [config_path], catch_exceptions=False) - self.assertEqual(os.EX_OK, result.exit_code) - - # Assert that ports are open - self.assert_ports_open(config.observatory, timeout=self.port_wait_timeout) - - # Test that default DAGs are loaded - self.assert_dags_loaded( - self.expected_platform_dag_ids, config, dag_check_timeout=self.dag_check_timeout - ) - - # Test that stop command works - result = runner.invoke(cli, self.stop_cmd + [config_path], catch_exceptions=False) - self.assertEqual(os.EX_OK, result.exit_code) - finally: - runner.invoke(cli, self.stop_cmd + [config_path]) - - @patch("observatory.platform.cli.cli.ObservatoryConfig.load") - def test_dag_load_workflows_project_editable(self, mock_config_load): - """Test that the DAGs load when build from an editable workflows project.""" - - runner = CliRunner() - with runner.isolated_filesystem() as t: - # Save empty config - config_path = save_empty_file(t, self.config_file_name) - - # Copy projects - self.copy_observatory_api(t) - self.copy_observatory_platform(t) - self.copy_workflows_project(t) - - # Make config object - config = self.make_editable_observatory_config(t) - config.workflows_projects = [ - WorkflowsProject( - package_name=self.workflows_package_name, - package=os.path.join(t, self.workflows_package_name), - package_type="editable", - dags_module=f"{stringcase.snakecase(self.workflows_package_name)}.dags", - ) - ] - config.workflows = self.all_workflows - mock_config_load.return_value = config - - try: - # Test that start command works - result = runner.invoke(cli, self.start_cmd + [config_path], catch_exceptions=False) - print("test_dag_load_workflows_project_editable errors") - print(f"Output: {result.output}") - self.assertEqual(os.EX_OK, result.exit_code) - - # Assert that ports are open - self.assert_ports_open(config.observatory, timeout=self.port_wait_timeout) - - # Test that default DAGs are loaded - self.assert_dags_loaded( - self.expected_workflows_dag_ids, config, dag_check_timeout=self.dag_check_timeout - ) - - # Test that stop command works - result = runner.invoke(cli, self.stop_cmd + [config_path], catch_exceptions=False) - self.assertEqual(os.EX_OK, result.exit_code) - finally: - runner.invoke(cli, self.stop_cmd + [config_path]) 
- - def make_sdist_observatory_config( - self, - temp_dir: str, - observatory_api_sdist_path: str, - observatory_sdist_path: str, - ) -> ObservatoryConfig: - """Make an sdist observatory config. - - :param temp_dir: the temp dir. - :param observatory_api_sdist_path: the observatory-api sdist path. - :param observatory_sdist_path: the observatory-platform sdist path. - :return: ObservatoryConfig. - """ - - return ObservatoryConfig( - backend=Backend(type=BackendType.local, environment=Environment.develop), - observatory=Observatory( - package=observatory_sdist_path, - package_type="sdist", - airflow_fernet_key=self.airflow_fernet_key, - airflow_secret_key=self.airflow_secret_key, - airflow_ui_user_email=self.airflow_ui_user_email, - airflow_ui_user_password=self.airflow_ui_user_password, - observatory_home=temp_dir, - redis_port=find_free_port(), - flower_ui_port=find_free_port(), - airflow_ui_port=find_free_port(), - api_port=find_free_port(), - docker_network_name=self.docker_network_name, - docker_compose_project_name=self.docker_compose_project_name, - api_package=observatory_api_sdist_path, - api_package_type="sdist", - ), - workflows=self.platform_workflows, - ) - - @patch("observatory.platform.cli.cli.ObservatoryConfig.load") - def test_run_platform_sdist(self, mock_config_load): - """Test that the platform runs when built from a source distribution.""" - - runner = CliRunner() - with runner.isolated_filesystem() as t: - # Save empty config - config_path = save_empty_file(t, self.config_file_name) - - # Copy platform project - self.copy_observatory_api(t) - self.copy_observatory_platform(t) - - # Build sdist - observatory_api_sdist_path = build_sdist(os.path.join(t, self.observatory_api_package_name)) - observatory_platform_sdist_path = build_sdist(os.path.join(t, self.observatory_platform_package_name)) - - # Make config object - config = self.make_sdist_observatory_config(t, observatory_api_sdist_path, observatory_platform_sdist_path) - mock_config_load.return_value = config - - try: - # Test that start command works - result = runner.invoke(cli, self.start_cmd + [config_path], catch_exceptions=False) - self.assertEqual(os.EX_OK, result.exit_code) - - # Assert that ports are open - self.assert_ports_open(config.observatory, timeout=self.port_wait_timeout) - - # Test that default DAGs are loaded - self.assert_dags_loaded( - self.expected_platform_dag_ids, config, dag_check_timeout=self.dag_check_timeout - ) - - # Test that stop command works - result = runner.invoke(cli, self.stop_cmd + [config_path], catch_exceptions=False) - self.assertEqual(os.EX_OK, result.exit_code) - finally: - runner.invoke(cli, self.stop_cmd + [config_path]) - - @patch("observatory.platform.cli.cli.ObservatoryConfig.load") - def test_dag_load_workflows_project_sdist(self, mock_config_load): - """Test that DAGs load from an sdist workflows project.""" - - runner = CliRunner() - with runner.isolated_filesystem() as t: - # Save empty config - config_path = save_empty_file(t, self.config_file_name) - - # Copy projects - self.copy_observatory_api(t) - self.copy_observatory_platform(t) - self.copy_workflows_project(t) - - # Build sdists - observatory_api_sdist_path = build_sdist(os.path.join(t, self.observatory_api_package_name)) - observatory_sdist_path = build_sdist(os.path.join(t, self.observatory_platform_package_name)) - workflows_sdist_path = build_sdist(os.path.join(t, self.workflows_package_name)) - - # Make config object - config = self.make_sdist_observatory_config(t, observatory_api_sdist_path, 
observatory_sdist_path) - config.workflows_projects = [ - WorkflowsProject( - package_name=self.workflows_package_name, - package=workflows_sdist_path, - package_type="sdist", - dags_module=f"{stringcase.snakecase(self.workflows_package_name)}.dags", - ) - ] - config.workflows = self.all_workflows - mock_config_load.return_value = config - - try: - # Test that start command works - result = runner.invoke(cli, self.start_cmd + [config_path], catch_exceptions=False) - self.assertEqual(os.EX_OK, result.exit_code) - - # Assert that ports are open - self.assert_ports_open(config.observatory, timeout=self.port_wait_timeout) - - # Test that default DAGs are loaded - self.assert_dags_loaded( - self.expected_workflows_dag_ids, config, dag_check_timeout=self.dag_check_timeout - ) - - # Test that stop command works - result = runner.invoke(cli, self.stop_cmd + [config_path], catch_exceptions=False) - self.assertEqual(os.EX_OK, result.exit_code) - finally: - runner.invoke(cli, self.stop_cmd + [config_path]) diff --git a/tests/observatory/platform/cli/test_cli_utils.py b/tests/observatory/platform/cli/test_cli_utils.py deleted file mode 100644 index b48069a17..000000000 --- a/tests/observatory/platform/cli/test_cli_utils.py +++ /dev/null @@ -1,63 +0,0 @@ -# Copyright 2020 Curtin University -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# Author: James Diprose - -import unittest - -from observatory.platform.cli.cli_utils import ( - INDENT1, - INDENT2, - INDENT3, - INDENT4, - comment, - indent, -) - - -class TestClick(unittest.TestCase): - def test_indent(self): - original_str = "hello world" - - # 2 spaces - output = indent(original_str, INDENT1) - self.assertEqual(f" {original_str}", output) - - # 3 spaces - output = indent(original_str, INDENT2) - self.assertEqual(f" {original_str}", output) - - # 4 spaces - output = indent(original_str, INDENT3) - self.assertEqual(f" {original_str}", output) - - # 5 spaces - output = indent(original_str, INDENT4) - self.assertEqual(f" {original_str}", output) - - # Check that values below 0 raise assertion error - with self.assertRaises(AssertionError): - indent(original_str, 0) - - with self.assertRaises(AssertionError): - indent(original_str, -1) - - def test_comment(self): - input_str = "" - output = comment(input_str) - self.assertEqual(output, "# ") - - input_str = "Hello world" - output = comment(input_str) - self.assertEqual(output, "# Hello world") diff --git a/tests/observatory/platform/cli/test_generate_command.py b/tests/observatory/platform/cli/test_generate_command.py deleted file mode 100644 index 5869a8e2e..000000000 --- a/tests/observatory/platform/cli/test_generate_command.py +++ /dev/null @@ -1,663 +0,0 @@ -# Copyright 2020 Curtin University -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# Author: James Diprose, Tuan Chien, Aniek Roelofs - -import os -import unittest -from unittest.mock import patch - -import click -from click.testing import CliRunner - -from observatory.platform.cli.generate_command import ( - DefaultWorkflowsProject, - FernetKeyType, - FlaskSecretKeyType, - GenerateCommand, - InteractiveConfigBuilder, -) -from observatory.platform.config import module_file_path -from observatory.platform.observatory_config import ( - Backend, - BackendType, - CloudSqlDatabase, - Environment, - GoogleCloud, - Observatory, - ObservatoryConfig, - Terraform, - TerraformConfig, - VirtualMachine, - WorkflowsProject, -) -from observatory.platform.observatory_environment import ObservatoryTestCase - - -class TestGenerateCommand(ObservatoryTestCase): - def test_generate_local_config(self): - cmd = GenerateCommand() - config_path = "config.yaml" - - with CliRunner().isolated_filesystem(): - cmd.generate_local_config(config_path, editable=False, workflows=[], oapi=False) - self.assertTrue(os.path.exists(config_path)) - - with CliRunner().isolated_filesystem(): - cmd.generate_local_config(config_path, editable=True, workflows=[], oapi=False) - self.assertTrue(os.path.exists(config_path)) - - with CliRunner().isolated_filesystem(): - cmd.generate_local_config(config_path, editable=False, workflows=[], oapi=True) - self.assertTrue(os.path.exists(config_path)) - - with CliRunner().isolated_filesystem(): - cmd.generate_local_config(config_path, editable=True, workflows=[], oapi=True) - self.assertTrue(os.path.exists(config_path)) - - def test_generate_terraform_config(self): - cmd = GenerateCommand() - config_path = "config-terraform.yaml" - - with CliRunner().isolated_filesystem(): - cmd.generate_terraform_config(config_path, editable=False, workflows=[], oapi=False) - self.assertTrue(os.path.exists(config_path)) - - with CliRunner().isolated_filesystem(): - cmd.generate_terraform_config(config_path, editable=True, workflows=[], oapi=False) - self.assertTrue(os.path.exists(config_path)) - - with CliRunner().isolated_filesystem(): - cmd.generate_terraform_config(config_path, editable=False, workflows=[], oapi=True) - self.assertTrue(os.path.exists(config_path)) - - with CliRunner().isolated_filesystem(): - cmd.generate_terraform_config(config_path, editable=True, workflows=[], oapi=True) - self.assertTrue(os.path.exists(config_path)) - - -class TestInteractiveConfigBuilder(unittest.TestCase): - @patch("observatory.platform.cli.generate_command.InteractiveConfigBuilder.set_editable_observatory_platform") - @patch("observatory.platform.cli.generate_command.InteractiveConfigBuilder.build") - def test_generate_local_config_interactive(self, mock_build, m_set_edit): - cmd = GenerateCommand() - cmd.generate_local_config_interactive( - config_path="path", workflows=["academic-observatory-workflows"], oapi=False, editable=False - ) - self.assertEqual(mock_build.call_args.kwargs["backend_type"], BackendType.local) - self.assertEqual(mock_build.call_args.kwargs["workflows"], ["academic-observatory-workflows"]) - self.assertEqual(m_set_edit.call_count, 0) - - 
cmd.generate_local_config_interactive( - config_path="path", workflows=["academic-observatory-workflows"], oapi=False, editable=True - ) - self.assertEqual(mock_build.call_args.kwargs["backend_type"], BackendType.local) - self.assertEqual(mock_build.call_args.kwargs["workflows"], ["academic-observatory-workflows"]) - self.assertEqual(m_set_edit.call_count, 1) - - @patch("observatory.platform.cli.generate_command.InteractiveConfigBuilder.set_editable_observatory_platform") - @patch("observatory.platform.cli.generate_command.InteractiveConfigBuilder.build") - def test_generate_terraform_config_interactive(self, mock_build, m_set_edit): - cmd = GenerateCommand() - cmd.generate_terraform_config_interactive( - config_path="path", workflows=["academic-observatory-workflows"], oapi=False, editable=False - ) - self.assertEqual(mock_build.call_args.kwargs["backend_type"], BackendType.terraform) - self.assertEqual(mock_build.call_args.kwargs["workflows"], ["academic-observatory-workflows"]) - self.assertEqual(m_set_edit.call_count, 0) - - cmd.generate_terraform_config_interactive( - config_path="path", workflows=["academic-observatory-workflows"], oapi=False, editable=True - ) - self.assertEqual(mock_build.call_args.kwargs["backend_type"], BackendType.terraform) - self.assertEqual(mock_build.call_args.kwargs["workflows"], ["academic-observatory-workflows"]) - self.assertEqual(m_set_edit.call_count, 1) - - @patch("observatory.platform.cli.generate_command.module_file_path") - @patch("observatory.platform.cli.generate_command.InteractiveConfigBuilder.config_airflow_worker_vm") - @patch("observatory.platform.cli.generate_command.InteractiveConfigBuilder.config_airflow_main_vm") - @patch("observatory.platform.cli.generate_command.InteractiveConfigBuilder.config_cloud_sql_database") - @patch("observatory.platform.cli.generate_command.InteractiveConfigBuilder.config_workflows_projects") - @patch("observatory.platform.cli.generate_command.InteractiveConfigBuilder.config_google_cloud") - @patch("observatory.platform.cli.generate_command.InteractiveConfigBuilder.config_terraform") - @patch("observatory.platform.cli.generate_command.InteractiveConfigBuilder.config_observatory") - @patch("observatory.platform.cli.generate_command.InteractiveConfigBuilder.config_backend") - def test_build( - self, - m_backend, - m_observatory, - m_terraform, - m_google_cloud, - m_workflows_projects, - m_cloud_sql_database, - m_airflow_main_vm, - m_airflow_worker_m, - m_mfp, - ): - def mock_mfp(*arg, **kwargs): - if arg[0] == "academic_observatory_workflows.dags": - return "/ao_workflows/path" - else: - return "oaebu_workflows/path" - - m_mfp.side_effect = mock_mfp - workflows = [] - local_nodags = InteractiveConfigBuilder.build( - backend_type=BackendType.local, workflows=workflows, oapi=False, editable=False - ) - self.assertTrue(isinstance(local_nodags, ObservatoryConfig)) - self.assertEqual(len(local_nodags.workflows_projects), 0) - - workflows = ["academic-observatory-workflows", "oaebu-workflows"] - local_dags = InteractiveConfigBuilder.build( - backend_type=BackendType.local, workflows=workflows, oapi=False, editable=False - ) - self.assertTrue(isinstance(local_dags, ObservatoryConfig)) - self.assertEqual(len(local_dags.workflows_projects), 2) - self.assertEqual(local_dags.workflows_projects[0], DefaultWorkflowsProject.academic_observatory_workflows()) - self.assertEqual(local_dags.workflows_projects[1], DefaultWorkflowsProject.oaebu_workflows()) - print(local_dags.workflows_projects) - - terraform_nodags = 
InteractiveConfigBuilder.build( - backend_type=BackendType.terraform, workflows=[], oapi=False, editable=False - ) - self.assertTrue(isinstance(terraform_nodags, TerraformConfig)) - self.assertTrue(isinstance(terraform_nodags, ObservatoryConfig)) - self.assertEqual(len(terraform_nodags.workflows_projects), 0) - - terraform_dags = InteractiveConfigBuilder.build( - backend_type=BackendType.terraform, - workflows=["academic-observatory-workflows", "oaebu-workflows"], - oapi=False, - editable=False, - ) - self.assertTrue(isinstance(terraform_dags, TerraformConfig)) - self.assertEqual(len(terraform_dags.workflows_projects), 2) - self.assertEqual(terraform_dags.workflows_projects[0], DefaultWorkflowsProject.academic_observatory_workflows()) - self.assertEqual(terraform_dags.workflows_projects[1], DefaultWorkflowsProject.oaebu_workflows()) - - self.assertTrue(m_backend.called) - self.assertTrue(m_observatory.called) - self.assertTrue(m_terraform.called) - self.assertTrue(m_google_cloud.called) - self.assertTrue(m_workflows_projects.called) - self.assertTrue(m_cloud_sql_database.called) - self.assertTrue(m_airflow_main_vm.called) - self.assertTrue(m_airflow_worker_m.called) - - @patch("observatory.platform.cli.generate_command.click.prompt") - def test_config_backend(self, m_prompt): - m_prompt.return_value = "staging" - - config = ObservatoryConfig() - expected = Backend(type=BackendType.local, environment=Environment.staging) - InteractiveConfigBuilder.config_backend(config=config, backend_type=expected.type) - self.assertEqual(config.backend.type, expected.type) - self.assertEqual(config.backend.environment, expected.environment) - - config = TerraformConfig() - expected = Backend(type=BackendType.terraform, environment=Environment.staging) - InteractiveConfigBuilder.config_backend(config=config, backend_type=expected.type) - self.assertEqual(config.backend.type, expected.type) - self.assertEqual(config.backend.environment, expected.environment) - - @patch("observatory.platform.cli.generate_command.click.prompt") - def test_config_observatory_filled_keys(self, m_prompt): - observatory = Observatory( - airflow_fernet_key="IWt5jFGSw2MD1shTdwzLPTFO16G8iEAU3A6mGo_vJTY=", - airflow_secret_key=("a" * 16), - airflow_ui_user_email="email@email", - airflow_ui_user_password="pass", - observatory_home="/", - postgres_password="pass", - redis_port=111, - flower_ui_port=53, - airflow_ui_port=64, - api_port=123, - docker_network_name="raefd", - docker_compose_project_name="proj", - ) - - # Answer to questions - m_prompt.side_effect = [ - # observatory.package_type, - observatory.airflow_fernet_key, - observatory.airflow_secret_key, - observatory.airflow_ui_user_email, - observatory.airflow_ui_user_password, - observatory.observatory_home, - observatory.postgres_password, - observatory.redis_port, - observatory.flower_ui_port, - observatory.airflow_ui_port, - observatory.api_port, - observatory.docker_network_name, - observatory.docker_network_is_external, - observatory.docker_compose_project_name, - "y", - observatory.api_package, - observatory.api_package_type, - ] - - config = ObservatoryConfig() - InteractiveConfigBuilder.config_observatory(config=config, oapi=False, editable=False) - self.assertEqual(config.observatory, observatory) - - @patch("observatory.platform.cli.generate_command.click.prompt") - def test_config_observatory_empty_keys(self, m_prompt): - observatory = Observatory( - airflow_fernet_key="", - airflow_secret_key="", - airflow_ui_user_email="email@email", - 
airflow_ui_user_password="pass", - observatory_home="/", - postgres_password="pass", - redis_port=111, - flower_ui_port=53, - airflow_ui_port=64, - api_port=123, - docker_network_name="raefd", - docker_compose_project_name="proj", - ) - - # Answer to questions - m_prompt.side_effect = [ - # observatory.package_type, - observatory.airflow_fernet_key, - observatory.airflow_secret_key, - observatory.airflow_ui_user_email, - observatory.airflow_ui_user_password, - observatory.observatory_home, - observatory.postgres_password, - observatory.redis_port, - observatory.flower_ui_port, - observatory.airflow_ui_port, - observatory.api_port, - observatory.docker_network_name, - observatory.docker_network_is_external, - observatory.docker_compose_project_name, - "y", - observatory.api_package, - observatory.api_package_type, - ] - - config = ObservatoryConfig() - InteractiveConfigBuilder.config_observatory(config=config, oapi=False, editable=False) - self.assertTrue(len(config.observatory.airflow_fernet_key) > 0) - self.assertTrue(len(config.observatory.airflow_secret_key) > 0) - - observatory.airflow_fernet_key = config.observatory.airflow_fernet_key - observatory.airflow_secret_key = config.observatory.airflow_secret_key - self.assertEqual(config.observatory, observatory) - - @patch("observatory.platform.cli.generate_command.click.prompt") - def test_config_observatory_editable(self, m_prompt): - observatory = Observatory( - airflow_fernet_key="", - airflow_secret_key="", - airflow_ui_user_email="email@email", - airflow_ui_user_password="pass", - observatory_home="/", - postgres_password="pass", - redis_port=111, - flower_ui_port=53, - airflow_ui_port=64, - api_port=123, - docker_network_name="raefd", - docker_compose_project_name="proj", - package=module_file_path("observatory.platform", nav_back_steps=-3), - package_type="editable", - ) - - # Answer to questions - m_prompt.side_effect = [ - observatory.airflow_fernet_key, - observatory.airflow_secret_key, - observatory.airflow_ui_user_email, - observatory.airflow_ui_user_password, - observatory.observatory_home, - observatory.postgres_password, - observatory.redis_port, - observatory.flower_ui_port, - observatory.airflow_ui_port, - observatory.api_port, - observatory.docker_network_name, - observatory.docker_network_is_external, - observatory.docker_compose_project_name, - "y", - observatory.api_package, - observatory.api_package_type, - ] - - config = ObservatoryConfig() - InteractiveConfigBuilder.config_observatory(config=config, oapi=False, editable=True) - observatory.airflow_fernet_key = config.observatory.airflow_fernet_key - observatory.airflow_secret_key = config.observatory.airflow_secret_key - self.assertTrue(config.observatory.package_type, "editable") - self.assertEqual(config.observatory, observatory) - - @patch("observatory.platform.cli.generate_command.click.prompt") - def test_config_observatory_oapi(self, m_prompt): - observatory = Observatory( - airflow_fernet_key="", - airflow_secret_key="", - airflow_ui_user_email="email@email", - airflow_ui_user_password="pass", - observatory_home="/", - postgres_password="pass", - redis_port=111, - flower_ui_port=53, - airflow_ui_port=64, - api_port=123, - docker_network_name="raefd", - docker_compose_project_name="proj", - package_type="pypi", - api_package_type="pypi", - ) - - # Answer to questions - m_prompt.side_effect = [ - # "pypi", - observatory.airflow_fernet_key, - observatory.airflow_secret_key, - observatory.airflow_ui_user_email, - observatory.airflow_ui_user_password, - 
observatory.observatory_home, - observatory.postgres_password, - observatory.redis_port, - observatory.flower_ui_port, - observatory.airflow_ui_port, - observatory.api_port, - observatory.docker_network_name, - observatory.docker_network_is_external, - observatory.docker_compose_project_name, - "y", - ] - - config = ObservatoryConfig() - InteractiveConfigBuilder.config_observatory(config=config, oapi=True, editable=False) - self.assertTrue(len(config.observatory.airflow_fernet_key) > 0) - self.assertTrue(len(config.observatory.airflow_secret_key) > 0) - - observatory.airflow_fernet_key = config.observatory.airflow_fernet_key - observatory.airflow_secret_key = config.observatory.airflow_secret_key - self.assertEqual(config.observatory, observatory) - - @patch("observatory.platform.cli.generate_command.click.prompt") - def test_config_observatory_oapi_editable(self, m_prompt): - observatory = Observatory( - airflow_fernet_key="", - airflow_secret_key="", - airflow_ui_user_email="email@email", - airflow_ui_user_password="pass", - observatory_home="/", - postgres_password="pass", - redis_port=111, - flower_ui_port=53, - airflow_ui_port=64, - api_port=123, - docker_network_name="raefd", - docker_compose_project_name="proj", - package_type="editable", - api_package_type="editable", - package=module_file_path("observatory.platform", nav_back_steps=-3), - api_package=module_file_path("observatory.api", nav_back_steps=-3), - ) - - # Answer to questions - m_prompt.side_effect = [ - # "pypi", - observatory.airflow_fernet_key, - observatory.airflow_secret_key, - observatory.airflow_ui_user_email, - observatory.airflow_ui_user_password, - observatory.observatory_home, - observatory.postgres_password, - observatory.redis_port, - observatory.flower_ui_port, - observatory.airflow_ui_port, - observatory.api_port, - observatory.docker_network_name, - observatory.docker_network_is_external, - observatory.docker_compose_project_name, - "y", - ] - - config = ObservatoryConfig() - InteractiveConfigBuilder.config_observatory(config=config, oapi=True, editable=True) - self.assertTrue(len(config.observatory.airflow_fernet_key) > 0) - self.assertTrue(len(config.observatory.airflow_secret_key) > 0) - - observatory.airflow_fernet_key = config.observatory.airflow_fernet_key - observatory.airflow_secret_key = config.observatory.airflow_secret_key - self.assertEqual(config.observatory, observatory) - - @patch("observatory.platform.cli.generate_command.click.confirm") - def test_config_google_cloud_local_no_config(self, m_confirm): - config = ObservatoryConfig() - m_confirm.return_value = False - InteractiveConfigBuilder.config_google_cloud(config) - self.assertEqual(config.google_cloud, None) - - @patch("observatory.platform.cli.generate_command.click.confirm") - @patch("observatory.platform.cli.generate_command.click.prompt") - def test_config_google_cloud_local_config(self, m_prompt, m_confirm): - m_confirm.return_value = True - - google_cloud = GoogleCloud( - project_id="proj", - credentials="/tmp", - data_location="us", - ) - - # Answer to questions - m_prompt.side_effect = [ - google_cloud.project_id, - google_cloud.credentials, - google_cloud.data_location, - ] - - config = ObservatoryConfig() - InteractiveConfigBuilder.config_google_cloud(config) - - self.assertEqual(config.google_cloud, google_cloud) - - @patch("observatory.platform.cli.generate_command.click.prompt") - def test_config_google_cloud_terraform_config(self, m_prompt): - google_cloud = GoogleCloud( - project_id="proj", - credentials="/tmp", - 
data_location="us", - region="us-west2", - zone="us-west1-b", - ) - - # Answer to questions - m_prompt.side_effect = [ - google_cloud.project_id, - google_cloud.credentials, - google_cloud.data_location, - google_cloud.region, - google_cloud.zone, - ] - - config = TerraformConfig() - InteractiveConfigBuilder.config_google_cloud(config) - - self.assertEqual(config.google_cloud, google_cloud) - - @patch("observatory.platform.cli.generate_command.click.confirm") - def test_config_terraform_local_no_config(self, m_confirm): - m_confirm.return_value = False - - config = ObservatoryConfig() - InteractiveConfigBuilder.config_terraform(config) - self.assertEqual(config.terraform, None) - - @patch("observatory.platform.cli.generate_command.click.confirm") - @patch("observatory.platform.cli.generate_command.click.prompt") - def test_config_terraform_local_config(self, m_prompt, m_confirm): - m_confirm.return_value = True - m_prompt.side_effect = ["myorg", ""] - terraform = Terraform(organization="myorg") - - config = ObservatoryConfig() - InteractiveConfigBuilder.config_terraform(config) - self.assertEqual(config.terraform, terraform) - - config = ObservatoryConfig() - InteractiveConfigBuilder.config_terraform(config) - self.assertEqual(config.terraform, None) - - @patch("observatory.platform.cli.generate_command.click.prompt") - def test_config_terraform_terraform_config(self, m_prompt): - m_prompt.return_value = "myorg" - terraform = Terraform(organization="myorg") - - config = TerraformConfig() - InteractiveConfigBuilder.config_terraform(config) - self.assertEqual(config.terraform, terraform) - - @patch("observatory.platform.cli.generate_command.click.confirm") - def test_config_workflows_projects_none(self, m_confirm): - m_confirm.return_value = False - config = ObservatoryConfig() - expected_dags = list() - InteractiveConfigBuilder.config_workflows_projects(config) - self.assertEqual(config.workflows_projects, expected_dags) - - @patch("observatory.platform.cli.generate_command.click.prompt") - @patch("observatory.platform.cli.generate_command.click.confirm") - def test_config_workflows_projects_add(self, m_confirm, m_prompt): - m_confirm.side_effect = [True, True, False] - - config = ObservatoryConfig() - expected_dags = [ - WorkflowsProject( - package_name="pack1", - package="/tmp", - package_type="editable", - dags_module="something", - ), - WorkflowsProject( - package_name="pack2", - package="/tmp", - package_type="editable", - dags_module="else", - ), - ] - - m_prompt.side_effect = [ - expected_dags[0].package_name, - expected_dags[0].package, - expected_dags[0].package_type, - expected_dags[0].dags_module, - expected_dags[1].package_name, - expected_dags[1].package, - expected_dags[1].package_type, - expected_dags[1].dags_module, - ] - - InteractiveConfigBuilder.config_workflows_projects(config) - self.assertEqual(config.workflows_projects, expected_dags) - - @patch("observatory.platform.cli.generate_command.click.prompt") - def test_config_cloud_sql_database(self, m_prompt): - setting = CloudSqlDatabase(tier="something", backup_start_time="12:00") - - m_prompt.side_effect = [ - setting.tier, - setting.backup_start_time, - ] - - config = TerraformConfig() - InteractiveConfigBuilder.config_cloud_sql_database(config) - - self.assertEqual(config.cloud_sql_database, setting) - - @patch("observatory.platform.cli.generate_command.click.prompt") - @patch("observatory.platform.cli.generate_command.click.confirm") - def test_config_airflow_main_vm(self, m_confirm, m_prompt): - create = True - 
m_confirm.return_value = create - - vm = VirtualMachine( - machine_type="n2-standard-2", - disk_size=1, - disk_type="pd-ssd", - create=create, - ) - - m_prompt.side_effect = [ - vm.machine_type, - vm.disk_size, - vm.disk_type, - vm.create, - ] - - config = TerraformConfig() - InteractiveConfigBuilder.config_airflow_main_vm(config) - self.assertEqual(config.airflow_main_vm, vm) - - @patch("observatory.platform.cli.generate_command.click.prompt") - @patch("observatory.platform.cli.generate_command.click.confirm") - def test_config_airflow_worker_vm(self, m_confirm, m_prompt): - create = False - m_confirm.return_value = create - - vm = VirtualMachine( - machine_type="n2-standard-2", - disk_size=1, - disk_type="pd-ssd", - create=create, - ) - - m_prompt.side_effect = [ - vm.machine_type, - vm.disk_size, - vm.disk_type, - vm.create, - ] - - config = TerraformConfig() - InteractiveConfigBuilder.config_airflow_worker_vm(config) - self.assertEqual(config.airflow_worker_vm, vm) - - -class TestFernetKeyParamType(unittest.TestCase): - def test_fernet_key_convert_fail(self): - ctype = FernetKeyType() - self.assertTrue(hasattr(ctype, "name")) - self.assertRaises(click.exceptions.BadParameter, ctype.convert, "badkey") - - def test_fernet_key_convert_succeed(self): - ctype = FernetKeyType() - self.assertTrue(hasattr(ctype, "name")) - key = "2a-Wxx5CZdb7wm_T6OailRtUilT7gajYTmPxoUvhVfM=" - result = ctype.convert(key) - self.assertEqual(result, key) - - -class TestSecretKeyParamType(unittest.TestCase): - def test_secret_key_convert_fail(self): - ctype = FlaskSecretKeyType() - self.assertTrue(hasattr(ctype, "name")) - self.assertRaises(click.exceptions.BadParameter, ctype.convert, "badkey") - - def test_secret_key_convert_succeed(self): - ctype = FlaskSecretKeyType() - self.assertTrue(hasattr(ctype, "name")) - key = "a" * 16 - result = ctype.convert(key) - self.assertEqual(result, key) diff --git a/tests/observatory/platform/cli/test_platform_command.py b/tests/observatory/platform/cli/test_platform_command.py deleted file mode 100644 index 0d3a3f6ac..000000000 --- a/tests/observatory/platform/cli/test_platform_command.py +++ /dev/null @@ -1,114 +0,0 @@ -# Copyright 2020 Curtin University -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -# Author: James Diprose - -import os -import unittest -from datetime import datetime -from typing import Any -from unittest.mock import patch, Mock - -from click.testing import CliRunner - -from observatory.platform.cli.platform_command import PlatformCommand -from observatory.platform.observatory_config import Observatory, ObservatoryConfig, Backend, Environment, BackendType -from observatory.platform.observatory_environment import find_free_port - - -class MockUrlOpen(Mock): - def __init__(self, status: int, **kwargs: Any): - super().__init__(**kwargs) - self.status = status - - def getcode(self): - return self.status - - -class TestPlatformCommand(unittest.TestCase): - def setUp(self) -> None: - self.config_file_name = "config.yaml" - self.observatory_platform_package_name = "observatory-platform" - self.airflow_fernet_key = "fernet-key" - self.airflow_secret_key = "secret-key" - self.airflow_ui_port = find_free_port() - - def make_observatory_config(self, t): - """Make an ObservatoryConfig instance. - - :param t: the temporary path. - :return: the ObservatoryConfig. - """ - - return ObservatoryConfig( - backend=Backend(type=BackendType.local, environment=Environment.develop), - observatory=Observatory( - package=os.path.join(t, self.observatory_platform_package_name), - package_type="editable", - airflow_fernet_key=self.airflow_fernet_key, - airflow_secret_key=self.airflow_secret_key, - observatory_home=t, - airflow_ui_port=find_free_port(), - ), - ) - - def test_ui_url(self): - with CliRunner().isolated_filesystem() as t: - # Make config - config = self.make_observatory_config(t) - - # Test that ui URL is correct - cmd = PlatformCommand(config) - cmd.config.observatory.airflow_ui_port = self.airflow_ui_port - - self.assertEqual(f"http://localhost:{self.airflow_ui_port}", cmd.ui_url) - - @patch("urllib.request.urlopen") - def test_wait_for_airflow_ui_success(self, mock_url_open): - # Mock the status code return value: 200 should succeed - mock_url_open.return_value = MockUrlOpen(200) - - with CliRunner().isolated_filesystem() as t: - # Make config - config = self.make_observatory_config(t) - - # Test that ui connects - cmd = PlatformCommand(config) - start = datetime.now() - state = cmd.wait_for_airflow_ui() - end = datetime.now() - duration = (end - start).total_seconds() - - self.assertTrue(state) - self.assertAlmostEquals(0, duration, delta=0.5) - - @patch("urllib.request.urlopen") - def test_wait_for_airflow_ui_failed(self, mock_url_open): - # Mock the status code return value: 500 should fail - mock_url_open.return_value = MockUrlOpen(500) - - with CliRunner().isolated_filesystem() as t: - # Make config - config = self.make_observatory_config(t) - - # Test that ui error - cmd = PlatformCommand(config) - expected_timeout = 10 - start = datetime.now() - state = cmd.wait_for_airflow_ui(expected_timeout) - end = datetime.now() - duration = (end - start).total_seconds() - - self.assertFalse(state) - self.assertAlmostEquals(expected_timeout, duration, delta=1) diff --git a/tests/observatory/platform/docker/__init__.py b/tests/observatory/platform/docker/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/tests/observatory/platform/docker/test_builder.py b/tests/observatory/platform/docker/test_builder.py deleted file mode 100644 index dbf0919c5..000000000 --- a/tests/observatory/platform/docker/test_builder.py +++ /dev/null @@ -1,117 +0,0 @@ -# Copyright 2021 Curtin University -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may 
not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# Author: James Diprose - -import os -import unittest - -from click.testing import CliRunner - -from observatory.platform.docker.builder import Builder, Template, rendered_file_name - - -class TestTemplate(unittest.TestCase): - def test_rendered_file_name(self): - template_file_path = os.path.join("path", "to", "template.yaml.jinja2") - expected_output_filename = "template.yaml" - actual_file_name = rendered_file_name(template_file_path) - self.assertEqual(expected_output_filename, actual_file_name) - - def test_output_file_name(self): - template_file_path = os.path.join("path", "to", "template.yaml.jinja2") - t = Template(path=template_file_path, kwargs={}) - expected_output_filename = "template.yaml" - self.assertEqual(expected_output_filename, t.output_file_name) - - -class TestBuilder(unittest.TestCase): - def test_add_template(self): - expected_path = os.path.join("path", "to", "template.yaml.jinja2") - expected_kwargs = {"key1": 1, "key2": 2} - - builder = Builder(build_path="/") - builder.add_template(path=expected_path, **expected_kwargs) - - self.assertEqual(1, len(builder.templates)) - template = builder.templates[0] - self.assertEqual(expected_path, template.path) - self.assertEqual(expected_kwargs, template.kwargs) - - def test_add_file(self): - expected_path = os.path.join("path", "to", "my-file.txt") - expected_output_file_name = "my-file-renamed.txt" - - builder = Builder(build_path="/") - builder.add_file(path=expected_path, output_file_name=expected_output_file_name) - - self.assertEqual(1, len(builder.files)) - file = builder.files[0] - self.assertEqual(expected_path, file.path) - self.assertEqual(expected_output_file_name, file.output_file_name) - - def test_render_template(self): - with CliRunner().isolated_filesystem() as t: - # Make a template - template_path = os.path.join(t, "template.txt.jinja2") - with open(template_path, mode="w") as f: - f.write("{{ key }}") - expected_content = "hello" - template = Template(path=template_path, kwargs={"key": expected_content}) - - # Render the template - builder = Builder(build_path=t) - output_file_path = os.path.join(t, "template.txt") - builder.render_template(template, output_file_path) - - # Check output - with open(output_file_path, mode="r") as f: - content = f.read() - self.assertEqual(expected_content, content) - - def test_make_files(self): - with CliRunner().isolated_filesystem() as t: - build_path = os.path.join(t, "build") - - builder = Builder(build_path=build_path) - file_path = os.path.join(t, "file.txt") - output_file_name = "my-file.txt" - template_path = os.path.join(t, "template.txt.jinja2") - builder.add_file(path=file_path, output_file_name=output_file_name) - expected_content = "hello" - builder.add_template(path=template_path, key=expected_content) - - # Add files - with open(file_path, mode="w") as f: - f.write(expected_content) - with open(template_path, mode="w") as f: - f.write("{{ key }}") - - # Make files - builder.make_files() - - # Check files and folders exist - self.assertTrue(os.path.exists(build_path)) - - expected_file_path 
= os.path.join(build_path, output_file_name) - self.assertTrue(os.path.exists(expected_file_path)) - with open(expected_file_path, mode="r") as f: - content = f.read() - self.assertEqual(expected_content, content) - - expected_template_path = os.path.join(build_path, "template.txt") - self.assertTrue(os.path.exists(expected_template_path)) - with open(expected_template_path, mode="r") as f: - content = f.read() - self.assertEqual(expected_content, content) diff --git a/tests/observatory/platform/docker/test_compose_runner.py b/tests/observatory/platform/docker/test_compose_runner.py deleted file mode 100644 index d0a70b573..000000000 --- a/tests/observatory/platform/docker/test_compose_runner.py +++ /dev/null @@ -1,88 +0,0 @@ -# Copyright 2021 Curtin University -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# Author: James Diprose - -import os.path -import unittest - -from click.testing import CliRunner - -from observatory.platform.docker.compose_runner import ComposeRunner -from observatory.platform.observatory_environment import find_free_port -from observatory.platform.utils.url_utils import retry_session - -DOCKERFILE = """ -FROM nginx -RUN echo "hello world" -""" - -DOCKER_COMPOSE_TEMPLATE = """ -version: '3' -services: - web: - build: . - ports: - - {{ port }}:80 -""" - - -class TestComposeRunner(unittest.TestCase): - def test_compose_file_name(self): - compose_template_path = "docker-compose.yml.jinja2" - runner = ComposeRunner(compose_template_path=compose_template_path, build_path="/") - self.assertEqual("docker-compose.yml", runner.compose_file_name) - - def test_build_start_stop(self): - with CliRunner().isolated_filesystem() as t: - build = os.path.join(t, "build") - port = find_free_port() - - # Save docker compose template - compose_template_path = os.path.join(t, "docker-compose.yml.jinja2") - with open(compose_template_path, mode="w") as f: - f.write(DOCKER_COMPOSE_TEMPLATE) - - # Save Dockerfile - dockerfile_name = "Dockerfile" - dockerfile_path = os.path.join(t, dockerfile_name) - with open(dockerfile_path, mode="w") as f: - f.write(DOCKERFILE) - - # Create compose runner - runner = ComposeRunner( - compose_template_path=compose_template_path, - compose_template_kwargs={"port": port}, - build_path=build, - debug=True, - ) - runner.add_file(path=dockerfile_path, output_file_name=dockerfile_name) - - # Build - p = runner.build() - self.assertEqual(0, p.return_code) - - # Start and stop - try: - p = runner.start() - self.assertEqual(0, p.return_code) - - # Test that the port is up - response = retry_session().get(f"http://localhost:{port}") - self.assertEqual(200, response.status_code) - - p = runner.stop() - self.assertEqual(0, p.return_code) - finally: - runner.stop() diff --git a/tests/observatory/platform/docker/test_platform_runner.py b/tests/observatory/platform/docker/test_platform_runner.py deleted file mode 100644 index 2380db0ea..000000000 --- a/tests/observatory/platform/docker/test_platform_runner.py +++ /dev/null @@ -1,286 +0,0 @@ -# Copyright 2020 Curtin 
University -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# Author: James Diprose - -import logging -import os -import pathlib -import unittest -from typing import Any, Dict -from unittest.mock import Mock, patch - -import requests -from click.testing import CliRunner - -import observatory.platform.docker as docker_module -from observatory.platform.cli.cli import HOST_UID -from observatory.platform.docker.platform_runner import PlatformRunner -from observatory.platform.observatory_config import ( - Backend, - BackendType, - Workflow, - CloudWorkspace, - Environment, - GoogleCloud, - Observatory, - ObservatoryConfig, - Terraform, - WorkflowsProject, -) -from observatory.platform.observatory_environment import module_file_path - - -class MockFromEnv(Mock): - def __init__(self, is_running: bool, **kwargs: Any): - super().__init__(**kwargs) - self.is_running = is_running - - def ping(self): - if self.is_running: - return True - raise requests.exceptions.ConnectionError() - - -def make_expected_env(cmd: PlatformRunner) -> Dict: - """Make an expected environment. - - :param cmd: the PlatformRunner. - :return: the environment. - """ - - observatory_home = os.path.normpath(cmd.config.observatory.observatory_home) - return { - "COMPOSE_PROJECT_NAME": cmd.config.observatory.docker_compose_project_name, - "POSTGRES_USER": "observatory", - "POSTGRES_HOSTNAME": "postgres", - "REDIS_HOSTNAME": "redis", - "HOST_USER_ID": str(HOST_UID), - "HOST_DATA_PATH": os.path.join(observatory_home, "data"), - "HOST_LOGS_PATH": os.path.join(observatory_home, "logs"), - "HOST_POSTGRES_PATH": os.path.join(observatory_home, "postgres"), - "HOST_REDIS_PORT": str(cmd.config.observatory.redis_port), - "HOST_FLOWER_UI_PORT": str(cmd.config.observatory.flower_ui_port), - "HOST_AIRFLOW_UI_PORT": str(cmd.config.observatory.airflow_ui_port), - "HOST_API_SERVER_PORT": str(cmd.config.observatory.api_port), - "AIRFLOW_FERNET_KEY": cmd.config.observatory.airflow_fernet_key, - "AIRFLOW_SECRET_KEY": cmd.config.observatory.airflow_secret_key, - "AIRFLOW_UI_USER_EMAIL": cmd.config.observatory.airflow_ui_user_email, - "AIRFLOW_UI_USER_PASSWORD": cmd.config.observatory.airflow_ui_user_password, - "POSTGRES_PASSWORD": cmd.config.observatory.postgres_password, - } - - -class TestPlatformRunner(unittest.TestCase): - def setUp(self) -> None: - self.is_env_local = True - self.observatory_platform_path = module_file_path("observatory.platform", nav_back_steps=-3) - self.observatory_api_path = module_file_path("observatory.api", nav_back_steps=-3) - - def get_config(self, t: str): - return ObservatoryConfig( - backend=Backend(type=BackendType.local, environment=Environment.develop), - observatory=Observatory( - package=self.observatory_platform_path, - package_type="editable", - airflow_fernet_key="ez2TjBjFXmWhLyVZoZHQRTvBcX2xY7L4A7Wjwgr6SJU=", - airflow_secret_key="a" * 16, - observatory_home=t, - api_package=self.observatory_api_path, - api_package_type="editable" - ), - ) - - def test_is_environment_valid(self): - with 
CliRunner().isolated_filesystem() as t: - # Assumes that Docker is setup on the system where the tests are run - cfg = self.get_config(t) - cmd = PlatformRunner(config=cfg) - self.assertTrue(cmd.is_environment_valid) - - def test_docker_module_path(self): - """Test that the path to the Docker module is found""" - - with CliRunner().isolated_filesystem() as t: - cfg = self.get_config(t) - cmd = PlatformRunner(config=cfg) - expected_path = str(pathlib.Path(*pathlib.Path(docker_module.__file__).resolve().parts[:-1]).resolve()) - self.assertEqual(expected_path, cmd.docker_module_path) - - def test_docker_exe_path(self): - """Test that the path to the Docker executable is found""" - - with CliRunner().isolated_filesystem() as t: - cfg = self.get_config(t) - cmd = PlatformRunner(config=cfg) - result = cmd.docker_exe_path - self.assertIsNotNone(result) - self.assertTrue(result.endswith("docker")) - - def test_docker_compose(self): - """Test that the path to the Docker Compose executable is found""" - - with CliRunner().isolated_filesystem() as t: - cfg = self.get_config(t) - cmd = PlatformRunner(config=cfg) - self.assertTrue(cmd.docker_compose) - - @patch("observatory.platform.docker.platform_runner.docker.from_env") - def test_is_docker_running_true(self, mock_from_env): - """Test the property is_docker_running returns True when Docker is running""" - - mock_from_env.return_value = MockFromEnv(True) - - with CliRunner().isolated_filesystem() as t: - cfg = self.get_config(t) - cmd = PlatformRunner(config=cfg) - self.assertTrue(cmd.is_docker_running) - - @patch("observatory.platform.docker.platform_runner.docker.from_env") - def test_is_docker_running_false(self, mock_from_env): - """Test the property is_docker_running returns False when Docker is not running""" - - mock_from_env.return_value = MockFromEnv(False) - - with CliRunner().isolated_filesystem() as t: - cfg = self.get_config(t) - cmd = PlatformRunner(config=cfg) - self.assertFalse(cmd.is_docker_running) - - def test_make_observatory_files(self): - """Test building of the observatory files""" - - with CliRunner().isolated_filesystem() as t: - cfg = self.get_config(t) - cmd = PlatformRunner(config=cfg) - cmd.build() - - # Test that the expected files have been written - build_file_names = [ - "docker-compose.observatory.yml", - "Dockerfile.observatory", - "entrypoint-airflow.sh", - "entrypoint-root.sh", - ] - for file_name in build_file_names: - path = os.path.join(cmd.build_path, file_name) - logging.info(f"Expected file: {path}") - self.assertTrue(os.path.isfile(path)) - self.assertTrue(os.stat(path).st_size > 0) - - def test_make_environment_minimal(self): - """Test making of the minimal observatory platform files""" - - # Check that the environment variables are set properly for the default config - with CliRunner().isolated_filesystem() as t: - cfg = self.get_config(t) - cmd = PlatformRunner(config=cfg) - - # Make the environment - expected_env = make_expected_env(cmd) - env = cmd.make_environment() - - # Check that expected keys and values exist - for key, value in expected_env.items(): - self.assertTrue(key in env) - self.assertEqual(value, env[key]) - - # Check that Google Application credentials not in default config - self.assertTrue("HOST_GOOGLE_APPLICATION_CREDENTIALS" not in env) - - def test_make_environment_all_settings(self): - """Test making of the observatory platform files with all settings""" - - # Check that the environment variables are set properly for a complete config file - with CliRunner().isolated_filesystem() 
as t: - cfg = self.get_config(t) - cmd = PlatformRunner(config=cfg) - - # Manually override the platform command with a more fleshed out config file - dags_project = WorkflowsProject( - package_name="academic-observatory-workflows", - package="/path/to/academic-observatory-workflows", - package_type="editable", - dags_module="academic_observatory_workflows.dags", - ) - - backend = Backend(type=BackendType.local, environment=Environment.develop) - observatory = Observatory( - package="/path/to/observatory-platform", - package_type="editable", - airflow_fernet_key="ez2TjBjFXmWhLyVZoZHQRTvBcX2xY7L4A7Wjwgr6SJU=", - airflow_secret_key="a" * 16, - observatory_home=t, - ) - google_cloud = GoogleCloud(project_id="my-project-id", credentials="/path/to/creds.json") - terraform = Terraform(organization="my-terraform-org-name") - cloud_workspace = CloudWorkspace( - project_id="my-project-id", - download_bucket="my-download-bucket", - transform_bucket="my-transform-bucket", - data_location="us", - ) - config = ObservatoryConfig( - backend=backend, - observatory=observatory, - google_cloud=google_cloud, - terraform=terraform, - workflows_projects=[dags_project], - cloud_workspaces=[], - workflows=[ - Workflow( - dag_id="my_workflow", - name="My Workflow", - class_name="path.to.my_workflow.Workflow", - cloud_workspace=cloud_workspace, - kwargs=dict(hello="world"), - ) - ], - ) - cmd.config = config - cmd.config_exists = True - - # Make environment and test - env = cmd.make_environment() - - # Set FERNET_KEY, HOST_GOOGLE_APPLICATION_CREDENTIALS, AIRFLOW_VAR_DAGS_MODULE_NAMES - # and airflow variables and connections - expected_env = make_expected_env(cmd) - expected_env["AIRFLOW_FERNET_KEY"] = cmd.config.observatory.airflow_fernet_key - expected_env["AIRFLOW_SECRET_KEY"] = cmd.config.observatory.airflow_secret_key - expected_env["HOST_GOOGLE_APPLICATION_CREDENTIALS"] = cmd.config.google_cloud.credentials - expected_env["AIRFLOW_VAR_WORKFLOWS"] = cmd.config.airflow_var_workflows - expected_env["AIRFLOW_VAR_DAGS_MODULE_NAMES"] = cmd.config.airflow_var_dags_module_names - - # Check that expected keys and values exist - for key, value in expected_env.items(): - logging.info(f"Expected key: {key}") - self.assertTrue(key in env) - self.assertEqual(value, env[key]) - - def test_build(self): - """Test building of the observatory platform""" - - # Check that the environment variables are set properly for the default config - with CliRunner().isolated_filesystem() as t: - cfg = self.get_config(t) - cmd = PlatformRunner(config=cfg) - cmd.debug = True - - # Build the platform - response = cmd.build() - - # Assert that the platform builds - expected_return_code = 0 - self.assertEqual(expected_return_code, response.return_code) diff --git a/tests/observatory/platform/terraform/__init__.py b/tests/observatory/platform/terraform/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/tests/observatory/platform/terraform/test_terraform_api.py b/tests/observatory/platform/terraform/test_terraform_api.py deleted file mode 100644 index 900b95cc5..000000000 --- a/tests/observatory/platform/terraform/test_terraform_api.py +++ /dev/null @@ -1,521 +0,0 @@ -# Copyright 2020 Curtin University. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# Author: Aniek Roelofs, James Diprose - -import json -import os -import tarfile -import unittest - -from click.testing import CliRunner - -from observatory.platform.observatory_environment import random_id, test_fixtures_path -from observatory.platform.terraform.terraform_api import TerraformApi, TerraformVariable - -VALID_VARIABLE_DICTS = [ - { - "type": "vars", - "attributes": { - "key": "literal_variable", - "value": "some_text", - "sensitive": False, - "category": "terraform", - "hcl": False, - "description": "a description", - }, - }, - { - "type": "vars", - "attributes": { - "key": "literal_variable_sensitive", - "value": "some_text", - "sensitive": True, - "category": "terraform", - "hcl": False, - "description": "a description", - }, - }, - { - "type": "vars", - "attributes": { - "key": "hcl_variable", - "value": '{"test1"="value1", "test2"="value2"}', - "category": "terraform", - "description": "a description", - "hcl": True, - "sensitive": False, - }, - }, - { - "type": "vars", - "attributes": { - "key": "hcl_variable_sensitive", - "value": '{"test1"="value1", "test2"="value2"}', - "category": "terraform", - "description": "a description", - "hcl": True, - "sensitive": True, - }, - }, -] - -INVALID_VARIABLE_DICT = { - "type": "vars", - "attributes": { - "key": "invalid", - "value": "some_text", - "category": "invalid", - "description": "a description", - "hcl": False, - "sensitive": False, - }, -} - - -class TestTerraformVariable(unittest.TestCase): - def test_from_dict_to_dict(self): - # Check that from and to dict work - for expected_dict in VALID_VARIABLE_DICTS: - var = TerraformVariable.from_dict(expected_dict) - self.assertIsInstance(var, TerraformVariable) - actual_dict = var.to_dict() - self.assertDictEqual(expected_dict, actual_dict) - - # Test that value error is raised because of incorrect category - with self.assertRaises(ValueError): - TerraformVariable.from_dict(INVALID_VARIABLE_DICT) - - -class TestTerraformApi(unittest.TestCase): - organisation = os.getenv("TEST_TERRAFORM_ORGANISATION") - workspace = random_id() - token = os.getenv("TEST_TERRAFORM_TOKEN") - terraform_api = TerraformApi(token) - version = TerraformApi.TERRAFORM_WORKSPACE_VERSION - description = "test" - - def __init__(self, *args, **kwargs): - super(TestTerraformApi, self).__init__(*args, **kwargs) - - @classmethod - def setUpClass(cls) -> None: - cls.terraform_api.create_workspace( - cls.organisation, cls.workspace, auto_apply=True, description=cls.description, version=cls.version - ) - - @classmethod - def tearDownClass(cls) -> None: - cls.terraform_api.delete_workspace(cls.organisation, cls.workspace) - - def test_token_from_file(self): - """Test if token from json file is correctly retrieved""" - - token = "24asdfAAD.atlasv1.AD890asdnlqADn6daQdf" - token_json = {"credentials": {"app.terraform.io": {"token": token}}} - - with CliRunner().isolated_filesystem(): - with open("token.json", "w") as f: - json.dump(token_json, f) - self.assertEqual(token, TerraformApi.token_from_file("token.json")) - - def test_create_delete_workspace(self): - """Test response codes of successfully creating 
a workspace and when trying to create a workspace that - already exists.""" - - # First time, successful - workspace_id = self.workspace + "unittest" - response_code = self.terraform_api.create_workspace( - self.organisation, workspace_id, auto_apply=True, description=self.description, version=self.version - ) - self.assertEqual(response_code, 201) - - # Second time, workspace already exists - response_code = self.terraform_api.create_workspace( - self.organisation, workspace_id, auto_apply=True, description=self.description, version=self.version - ) - self.assertEqual(response_code, 422) - - # Delete workspace - response_code = self.terraform_api.delete_workspace(self.organisation, workspace_id) - self.assertEqual(response_code, 204) # must have changed from 200 to 204 in Feb / March 2022 - - # Try to delete non-existing workspace - response_code = self.terraform_api.delete_workspace(self.organisation, workspace_id) - self.assertEqual(response_code, 404) - - def test_workspace_id(self): - """Test that workspace id returns string or raises SystemExit for invalid workspace""" - - workspace_id = self.terraform_api.workspace_id(self.organisation, self.workspace) - self.assertIsInstance(workspace_id, str) - - with self.assertRaises(SystemExit): - self.terraform_api.workspace_id(self.organisation, "non-existing-workspace") - - def test_add_delete_workspace_variable(self): - """Test whether workspace variable is successfully added and deleted""" - - # Get workspace id - workspace_id = self.terraform_api.workspace_id(self.organisation, self.workspace) - - # Test the vars - for var_dict in VALID_VARIABLE_DICTS: - # Add variable - var = TerraformVariable.from_dict(var_dict) - var.var_id = self.terraform_api.add_workspace_variable(var, workspace_id) - self.assertIsInstance(var.var_id, str) - - # Raise error trying to add variable with key that already exists - with self.assertRaises(ValueError): - self.terraform_api.add_workspace_variable(var, workspace_id) - - # Delete variable - response_code = self.terraform_api.delete_workspace_variable(var, workspace_id) - self.assertEqual(response_code, 204) - - # Raise error trying to delete variable that doesn't exist - with self.assertRaises(ValueError): - self.terraform_api.delete_workspace_variable(var, workspace_id) - - def test_update_workspace_variable(self): - """Test updating variable both with empty attributes (meaning the var won't change) and updated attributes.""" - # get workspace id - workspace_id = self.terraform_api.workspace_id(self.organisation, self.workspace) - - for var_dict in VALID_VARIABLE_DICTS: - # Make variable - var = TerraformVariable.from_dict(var_dict) - - # Add variable to workspace and set var_id - var.var_id = self.terraform_api.add_workspace_variable(var, workspace_id) - - # Change key and value - var.key = var.key + "_updated" - var.value = "updated" - - # Key can not be changed for sensitive variables - if var.sensitive: - with self.assertRaises(ValueError): - self.terraform_api.update_workspace_variable(var, workspace_id) - else: - response_code = self.terraform_api.update_workspace_variable(var, workspace_id) - self.assertEqual(response_code, 200) - - # Delete variable - self.terraform_api.delete_workspace_variable(var, workspace_id) - - def test_list_workspace_variables(self): - """Test listing workspace variables and the returned response.""" - # get workspace id - workspace_id = self.terraform_api.workspace_id(self.organisation, self.workspace) - expected_vars_index = dict() - for var_dict in VALID_VARIABLE_DICTS: - 
# Make variable - var = TerraformVariable.from_dict(var_dict) - - # Add variable - self.terraform_api.add_workspace_variable(var, workspace_id) - - # Add to index - expected_vars_index[var.key] = var - - # Fetch workspace variables - workspace_vars = self.terraform_api.list_workspace_variables(workspace_id) - - # Check that the number of variables created is correct - self.assertTrue(len(VALID_VARIABLE_DICTS)) - - # Check expected and actual - for actual_var in workspace_vars: - # Make variable - expected_var = expected_vars_index[actual_var.key] - expected_var.var_id = actual_var.var_id - if expected_var.sensitive: - expected_var.value = None - - # Check that variable is TerraformVariable instance - self.assertIsInstance(actual_var, TerraformVariable) - - # Check that expected and actual variables match - self.assertDictEqual(expected_var.to_dict(), actual_var.to_dict()) - - # Delete variable - self.terraform_api.delete_workspace_variable(actual_var, workspace_id) - - def test_create_configuration_version(self): - """Test that configuration version is uploaded successfully""" - # get workspace id - workspace_id = self.terraform_api.workspace_id(self.organisation, self.workspace) - - # create configuration version - upload_url, _ = self.terraform_api.create_configuration_version(workspace_id) - self.assertIsInstance(upload_url, str) - - def test_check_configuration_version_status(self): - # get workspace id - workspace_id = self.terraform_api.workspace_id(self.organisation, self.workspace) - - # create configuration version - _, configuration_id = self.terraform_api.create_configuration_version(workspace_id) - - # get status - configuration_status = self.terraform_api.get_configuration_version_status(configuration_id) - self.assertIn(configuration_status, ["pending", "uploaded", "errored"]) - - def test_upload_configuration_files(self): - """Test that configuration files are uploaded successfully""" - # get workspace id - workspace_id = self.terraform_api.workspace_id(self.organisation, self.workspace) - - # create configuration version - upload_url, _ = self.terraform_api.create_configuration_version(workspace_id) - - configuration_path = test_fixtures_path("utils", "main.tf") - configuration_tar = "conf.tar.gz" - - with CliRunner().isolated_filesystem(): - # create tar.gz file of main.tf - with tarfile.open(configuration_tar, "w:gz") as tar: - tar.add(configuration_path, arcname=os.path.basename(configuration_path)) - - # test that configuration was created successfully - response_code = self.terraform_api.upload_configuration_files(upload_url, configuration_tar) - self.assertEqual(response_code, 200) - - def test_create_run(self): - """Test creating a run (with auto-apply)""" - # get workspace id - workspace_id = self.terraform_api.workspace_id(self.organisation, self.workspace) - # create configuration version - upload_url, configuration_id = self.terraform_api.create_configuration_version(workspace_id) - - configuration_path = test_fixtures_path("utils", "main.tf") - configuration_tar = "conf.tar.gz" - - with CliRunner().isolated_filesystem(): - # create tar.gz file of main.tf - with tarfile.open(configuration_tar, "w:gz") as tar: - tar.add(configuration_path, arcname=os.path.basename(configuration_path)) - # upload configuration files - self.terraform_api.upload_configuration_files(upload_url, configuration_tar) - - # wait until configuration files are processed and uploaded - configuration_status = None - while configuration_status != "uploaded": - configuration_status = 
self.terraform_api.get_configuration_version_status(configuration_id) - - # create run without target - run_id = self.terraform_api.create_run(workspace_id, target_addrs=None, message="No target") - self.assertIsInstance(run_id, str) - - # create run with target - run_id = self.terraform_api.create_run( - workspace_id, target_addrs="random_id.random", message="Targeting " "random_id" - ) - self.assertIsInstance(run_id, str) - - def test_get_run_details(self): - """Test retrieval of run details""" - # get workspace id - workspace_id = self.terraform_api.workspace_id(self.organisation, self.workspace) - # create configuration version - upload_url, configuration_id = self.terraform_api.create_configuration_version(workspace_id) - - configuration_path = test_fixtures_path("utils", "main.tf") - configuration_tar = "conf.tar.gz" - - with CliRunner().isolated_filesystem(): - # create tar.gz file of main.tf - with tarfile.open(configuration_tar, "w:gz") as tar: - tar.add(configuration_path, arcname=os.path.basename(configuration_path)) - # upload configuration files - self.terraform_api.upload_configuration_files(upload_url, configuration_tar) - - # wait until configuration files are processed and uploaded - configuration_status = None - while configuration_status != "uploaded": - configuration_status = self.terraform_api.get_configuration_version_status(configuration_id) - - # possible states - run_states = [ - "pending", - "plan_queued", - "planning", - "planned", - "cost_estimating", - "cost_estimated", - "policy_checking", - "policy_override", - "policy_soft_failed", - "policy_checked", - "confirmed", - "planned_and_finished", - "apply_queued", - "applying", - "applied", - "discarded", - "errored", - "canceled", - "force_canceled", - ] - - # create run with target - run_id = self.terraform_api.create_run( - workspace_id, target_addrs="random_id.random", message="Targeting " "random_id" - ) - run_details = self.terraform_api.get_run_details(run_id) - self.assertIsInstance(run_details, dict) - run_status = run_details["data"]["attributes"]["status"] - self.assertIn(run_status, run_states) - - # create run without target - run_id = self.terraform_api.create_run(workspace_id, target_addrs=None, message="No target") - run_details = self.terraform_api.get_run_details(run_id) - self.assertIsInstance(run_details, dict) - run_status = run_details["data"]["attributes"]["status"] - self.assertIn(run_status, run_states) - - def test_plan_variable_changes(self): - """Test the lists that are returned by plan_variable_changes.""" - - # Get workspace id - workspace_id = self.terraform_api.workspace_id(self.organisation, self.workspace) - - # Make variables to create and update - create_vars = [TerraformVariable.from_dict(var_dict) for var_dict in VALID_VARIABLE_DICTS[1:4]] - update_vars = [TerraformVariable.from_dict(var_dict) for var_dict in VALID_VARIABLE_DICTS[0:3]] - - # Create variables - for var in create_vars: - self.terraform_api.add_workspace_variable(var, workspace_id) - - # Leave variable 1 and 2 unchanged. 
Variable 1 is sensitive, so should be in - # the 'edit' list even when unchanged - add, edit, unchanged, delete = self.terraform_api.plan_variable_changes(update_vars, workspace_id) - - # Check lengths - expected_length = 1 - self.assertEqual(len(add), expected_length) - self.assertEqual(len(edit), expected_length) - self.assertEqual(len(unchanged), expected_length) - self.assertEqual(len(delete), expected_length) - - # Add: should contain variable with key literal_variable (index 0 update_vars) - self.assertDictEqual(update_vars[0].to_dict(), add[0].to_dict()) - - # Edit: should contain a tuple of variables with key literal_variable_sensitive (index 0 create_vars), - # the old variable and the new variable. var_id should not be None - edit_old_var = edit[0][0] - edit_new_var = edit[0][1] - expected_key = "literal_variable_sensitive" - self.assertEqual(edit_old_var.key, expected_key) - self.assertEqual(edit_new_var.key, expected_key) - self.assertIsNotNone(edit_old_var.var_id) - self.assertIsNotNone(edit_new_var.var_id) - - # Unchanged: should contain variable with key hcl_variable - self.assertDictEqual(create_vars[1].to_dict(), unchanged[0].to_dict()) - - # Delete: - expected_key = "hcl_variable_sensitive" - delete_var = delete[0] - self.assertEqual(delete_var.key, expected_key) - self.assertIsNotNone(delete_var.var_id) - - # Change variables 1 and 2 - update_vars[1].value = "updated" - update_vars[2].value = "updated" - add, edit, unchanged, delete = self.terraform_api.plan_variable_changes(update_vars, workspace_id) - - # Check lengths of results - self.assertEqual(len(add), 1) - self.assertEqual(len(edit), 2) - self.assertEqual(len(unchanged), 0) - self.assertEqual(len(delete), 1) - - # Add: should contain variable with key literal_variable (index 0 update_vars) - self.assertDictEqual(update_vars[0].to_dict(), add[0].to_dict()) - - # Edit: literal_variable_sensitive and hcl_variable changed - edit1_old_var = edit[0][0] - edit1_new_var = edit[0][1] - expected_key = "literal_variable_sensitive" - self.assertEqual(edit1_old_var.key, expected_key) - self.assertEqual(edit1_new_var.key, expected_key) - self.assertIsNotNone(edit1_old_var.var_id) - self.assertIsNotNone(edit1_new_var.var_id) - - edit2_old_var = edit[1][0] - edit2_new_var = edit[1][1] - expected_key = "hcl_variable" - self.assertEqual(edit2_old_var.key, expected_key) - self.assertEqual(edit2_new_var.key, expected_key) - self.assertIsNotNone(edit2_old_var.var_id) - self.assertIsNotNone(edit2_new_var.var_id) - - # Delete: key hcl_variable_sensitive - expected_key = "hcl_variable_sensitive" - delete_var = delete[0] - self.assertEqual(delete_var.key, expected_key) - self.assertIsNotNone(delete_var.var_id) - - # Delete variables - workspace_vars = self.terraform_api.list_workspace_variables(workspace_id) - for var in workspace_vars: - self.terraform_api.delete_workspace_variable(var, workspace_id) - - def test_update_workspace_variables(self): - """Test whether variables in workspace are updated correctly.""" - - # Get workspace id - workspace_id = self.terraform_api.workspace_id(self.organisation, self.workspace) - - # Create variables - for var_dict in VALID_VARIABLE_DICTS[1:4]: - var = TerraformVariable.from_dict(var_dict) - self.terraform_api.add_workspace_variable(var, workspace_id) - - # Variable 0 will be added and variable 3 will be deleted compared to workspace_vars - var_index = dict() - new_vars = [] - for var_dict in VALID_VARIABLE_DICTS[0:3]: - var = TerraformVariable.from_dict(var_dict) - var_index[var.key] = var - 
new_vars.append(var) - - # Leave variables 1 and 2 unchanged. Variable 1 is sensitive, so should be in - # the 'edit' list even when unchanged - add, edit, unchanged, delete = self.terraform_api.plan_variable_changes(new_vars, workspace_id) - - # Update variables - self.terraform_api.update_workspace_variables(add, edit, delete, workspace_id) - - # Check that variables have been updated as expected - workspace_vars = self.terraform_api.list_workspace_variables(workspace_id) - - # Should be the same number of variables listed as created - self.assertEqual(len(workspace_vars), len(new_vars)) - - # Check that attributes match - for actual_var in workspace_vars: - self.assertIn(actual_var.key, var_index) - - # Make expected variable - expected_var = var_index[actual_var.key] - if expected_var.sensitive: - expected_var.value = None - expected_var.var_id = actual_var.var_id - - self.assertEqual(expected_var.to_dict(), actual_var.to_dict()) diff --git a/tests/observatory/platform/terraform/test_terraform_builder.py b/tests/observatory/platform/terraform/test_terraform_builder.py deleted file mode 100644 index e81b0138a..000000000 --- a/tests/observatory/platform/terraform/test_terraform_builder.py +++ /dev/null @@ -1,219 +0,0 @@ -# Copyright 2020 Curtin University -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License.
- -# Author: James Diprose, Aniek Roelofs - -import os -import unittest -from unittest.mock import Mock, patch - -from click.testing import CliRunner - -from observatory.platform.observatory_config import ( - TerraformConfig, - Backend, - BackendType, - Environment, - Observatory, - Terraform, - GoogleCloud, - CloudSqlDatabase, - VirtualMachine, -) -from observatory.platform.observatory_environment import module_file_path -from observatory.platform.terraform.terraform_builder import TerraformBuilder - - -class Popen(Mock): - def __init__(self, **kwargs): - super().__init__(**kwargs) - - @property - def returncode(self): - return 0 - - -class TestTerraformBuilder(unittest.TestCase): - def setUp(self) -> None: - self.is_env_local = True - self.observatory_platform_path = module_file_path("observatory.platform", nav_back_steps=-3) - self.observatory_api_path = module_file_path("observatory.api", nav_back_steps=-3) - - def get_terraform_config(self, t: str) -> TerraformConfig: - credentials_path = os.path.abspath("creds.json") - - return TerraformConfig( - backend=Backend(type=BackendType.terraform, environment=Environment.develop), - observatory=Observatory( - package=self.observatory_platform_path, - package_type="editable", - airflow_fernet_key="ez2TjBjFXmWhLyVZoZHQRTvBcX2xY7L4A7Wjwgr6SJU=", - airflow_secret_key="a" * 16, - airflow_ui_user_password="password", - airflow_ui_user_email="password", - postgres_password="my-password", - observatory_home=t, - api_package=self.observatory_api_path, - api_package_type="editable", - ), - terraform=Terraform(organization="hello world"), - google_cloud=GoogleCloud( - project_id="my-project", - credentials=credentials_path, - region="us-west1", - zone="us-west1-c", - data_location="us", - ), - cloud_sql_database=CloudSqlDatabase(tier="db-custom-2-7680", backup_start_time="23:00"), - airflow_main_vm=VirtualMachine(machine_type="n2-standard-2", disk_size=1, disk_type="pd-ssd", create=True), - airflow_worker_vm=VirtualMachine( - machine_type="n2-standard-2", - disk_size=1, - disk_type="pd-standard", - create=False, - ), - ) - - def test_is_environment_valid(self): - with CliRunner().isolated_filesystem() as t: - config_path = os.path.join(t, "config.yaml") - - # Environment should be valid because there is a config.yaml - # Assumes that Docker is setup on the system where the tests are run - cfg = self.get_terraform_config(t) - cmd = TerraformBuilder(config=cfg) - self.assertTrue(cmd.is_environment_valid) - - @unittest.skip - def test_packer_exe_path(self): - """Test that the path to the Packer executable is found""" - - with CliRunner().isolated_filesystem() as t: - cfg = self.get_terraform_config(t) - cmd = TerraformBuilder(config=cfg) - result = cmd.packer_exe_path - self.assertIsNotNone(result) - self.assertTrue(result.endswith("packer")) - - def test_build_terraform(self): - """Test building of the terraform files""" - - with CliRunner().isolated_filesystem() as t: - cfg = self.get_terraform_config(t) - cmd = TerraformBuilder(config=cfg) - cmd.build_terraform() - - # Test that the expected Terraform files have been written - secret_files = [os.path.join("secret", n) for n in ["main.tf", "outputs.tf", "variables.tf"]] - vm_files = [os.path.join("vm", n) for n in ["main.tf", "outputs.tf", "variables.tf"]] - root_files = [ - "build.sh", - "main.tf", - "observatory-image.json.pkr.hcl", - "outputs.tf", - "startup-main.tpl", - "startup-worker.tpl", - "variables.tf", - "versions.tf", - ] - all_files = secret_files + vm_files + root_files - - for file_name 
in all_files: - path = os.path.join(cmd.build_path, "terraform", file_name) - self.assertTrue(os.path.isfile(path)) - - # Test that expected packages exists - packages = ["observatory-api", "observatory-platform"] - for package in packages: - path = os.path.join(cmd.build_path, "packages", package) - print(f"Check exists: {path}") - self.assertTrue(os.path.exists(path)) - - # Test that the expected Docker files have been written - build_file_names = [ - "docker-compose.observatory.yml", - "Dockerfile.observatory", - "entrypoint-airflow.sh", - "entrypoint-root.sh", - ] - for file_name in build_file_names: - path = os.path.join(cmd.build_path, "docker", file_name) - print(f"Checking that file exists: {path}") - self.assertTrue(os.path.isfile(path)) - self.assertTrue(os.stat(path).st_size > 0) - - @patch("subprocess.Popen") - @patch("observatory.platform.terraform.terraform_builder.stream_process") - def install_packer_plugins(self, mock_stream_process, mock_subprocess): - """Test installing the necessary plugins for Packer.""" - - # Check that the environment variables are set properly for the default config - with CliRunner().isolated_filesystem() as t: - mock_subprocess.return_value = Popen() - mock_stream_process.return_value = ("", "") - - # Save default config file - cfg = self.get_terraform_config(t) - cmd = TerraformBuilder(config=cfg) - - # Install the packer plugins - output, error, return_code = cmd.install_packer_plugins() - - print(output, error, return_code) - - # Assert the install was successful - expected_return_code = 0 - self.assertEqual(expected_return_code, return_code) - - @patch("subprocess.Popen") - @patch("observatory.platform.terraform.terraform_builder.stream_process") - def test_build_image(self, mock_stream_process, mock_subprocess): - """Test building of the observatory platform""" - - # Check that the environment variables are set properly for the default config - with CliRunner().isolated_filesystem() as t: - mock_subprocess.return_value = Popen() - mock_stream_process.return_value = ("", "") - - # Save default config file - cfg = self.get_terraform_config(t) - cmd = TerraformBuilder(config=cfg) - - # Build the image - output, error, return_code = cmd.build_image() - - # Assert that the image built - expected_return_code = 0 - self.assertEqual(expected_return_code, return_code) - - @patch("subprocess.Popen") - @patch("observatory.platform.terraform.terraform_builder.stream_process") - def test_gcloud_activate_service_account(self, mock_stream_process, mock_subprocess): - """Test activating the gcloud service account""" - - # Check that the environment variables are set properly for the default config - with CliRunner().isolated_filesystem() as t: - mock_subprocess.return_value = Popen() - mock_stream_process.return_value = ("", "") - - # Make observatory files - cfg = self.get_terraform_config(t) - cmd = TerraformBuilder(config=cfg) - - # Activate the service account - output, error, return_code = cmd.gcloud_activate_service_account() - - # Assert that account was activated - expected_return_code = 0 - self.assertEqual(expected_return_code, return_code) diff --git a/tests/observatory/platform/test_api.py b/tests/observatory/platform/test_api.py deleted file mode 100644 index 8b8886aa9..000000000 --- a/tests/observatory/platform/test_api.py +++ /dev/null @@ -1,112 +0,0 @@ -# Copyright 2020 Curtin University -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -from unittest.mock import patch - -import pendulum -from airflow.models.connection import Connection - -from observatory.api.client import ApiClient, Configuration -from observatory.api.client.api.observatory_api import ObservatoryApi # noqa: E501 -from observatory.api.testing import ObservatoryApiEnvironment -from observatory.platform.api import make_observatory_api, build_schedule -from observatory.platform.observatory_environment import ObservatoryTestCase, find_free_port - - -class TestObservatoryAPI(ObservatoryTestCase): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - - def setUp(self): - self.timezone = "Pacific/Auckland" - self.host = "localhost" - self.port = 5001 - configuration = Configuration(host=f"http://{self.host}:{self.port}") - api_client = ApiClient(configuration) - self.api = ObservatoryApi(api_client=api_client) # noqa: E501 - self.env = ObservatoryApiEnvironment(host=self.host, port=self.port) - - @patch("airflow.hooks.base.BaseHook.get_connection") - def test_make_observatory_api(self, mock_get_connection): - """Test make_observatory_api""" - - conn_type = "http" - host = "api.observatory.academy" - api_key = "my_api_key" - - # No port - mock_get_connection.return_value = Connection(uri=f"{conn_type}://:{api_key}@{host}") - api = make_observatory_api() - self.assertEqual(f"http://{host}", api.api_client.configuration.host) - self.assertEqual(api_key, api.api_client.configuration.api_key["api_key"]) - - # Port - port = find_free_port() - mock_get_connection.return_value = Connection(uri=f"{conn_type}://:{api_key}@{host}:{port}") - api = make_observatory_api() - self.assertEqual(f"http://{host}:{port}", api.api_client.configuration.host) - self.assertEqual(api_key, api.api_client.configuration.api_key["api_key"]) - - # Assertion error: missing conn_type empty string - with self.assertRaises(AssertionError): - mock_get_connection.return_value = Connection(uri=f"://:{api_key}@{host}") - make_observatory_api() - - # Assertion error: missing host empty string - with self.assertRaises(AssertionError): - mock_get_connection.return_value = Connection(uri=f"{conn_type}://:{api_key}@") - make_observatory_api() - - # Assertion error: missing password empty string - with self.assertRaises(AssertionError): - mock_get_connection.return_value = Connection(uri=f"://:{api_key}@{host}") - make_observatory_api() - - # Assertion error: missing conn_type None - with self.assertRaises(AssertionError): - mock_get_connection.return_value = Connection(password=api_key, host=host) - make_observatory_api() - - # Assertion error: missing host None - with self.assertRaises(AssertionError): - mock_get_connection.return_value = Connection(conn_type=conn_type, password=api_key) - make_observatory_api() - - # Assertion error: missing password None - with self.assertRaises(AssertionError): - mock_get_connection.return_value = Connection(host=host, password=api_key) - make_observatory_api() - - def test_build_schedule(self): - start_date = pendulum.datetime(2021, 1, 1) - end_date = pendulum.datetime(2021, 2, 1) - schedule = build_schedule(start_date, end_date) - 
self.assertEqual([pendulum.Period(pendulum.date(2021, 1, 1), pendulum.date(2021, 1, 31))], schedule) - - start_date = pendulum.datetime(2021, 1, 1) - end_date = pendulum.datetime(2021, 3, 1) - schedule = build_schedule(start_date, end_date) - self.assertEqual( - [ - pendulum.Period(pendulum.date(2021, 1, 1), pendulum.date(2021, 1, 31)), - pendulum.Period(pendulum.date(2021, 2, 1), pendulum.date(2021, 2, 28)), - ], - schedule, - ) - - start_date = pendulum.datetime(2021, 1, 7) - end_date = pendulum.datetime(2021, 2, 7) - schedule = build_schedule(start_date, end_date) - self.assertEqual([pendulum.Period(pendulum.date(2021, 1, 7), pendulum.date(2021, 2, 6))], schedule) diff --git a/tests/observatory/platform/test_config.py b/tests/observatory/platform/test_config.py deleted file mode 100644 index 1dcbad420..000000000 --- a/tests/observatory/platform/test_config.py +++ /dev/null @@ -1,140 +0,0 @@ -# Copyright 2019 Curtin University. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# Author: James Diprose, Aniek Roelofs - -import os -import pathlib -import unittest -from unittest.mock import patch - -import pendulum -from click.testing import CliRunner - -import tests.observatory.platform.utils as platform_utils_tests -from observatory.platform.bigquery import bq_find_schema -from observatory.platform.config import ( - module_file_path, - observatory_home, - terraform_credentials_path, -) -from observatory.platform.observatory_environment import test_fixtures_path - - -class TestConfig(unittest.TestCase): - def test_module_file_path(self): - # Go back one step (the default) - expected_path = str(pathlib.Path(*pathlib.Path(platform_utils_tests.__file__).resolve().parts[:-1]).resolve()) - actual_path = module_file_path("tests.observatory.platform.utils", nav_back_steps=-1) - self.assertEqual(expected_path, actual_path) - - # Go back two steps - expected_path = str(pathlib.Path(*pathlib.Path(platform_utils_tests.__file__).resolve().parts[:-2]).resolve()) - actual_path = module_file_path("tests.observatory.platform.utils", nav_back_steps=-2) - self.assertEqual(expected_path, actual_path) - - @patch("observatory.platform.config.pathlib.Path.home") - def test_observatory_home(self, home_mock): - runner = CliRunner() - with runner.isolated_filesystem(): - # Create home path and mock getting home path - home_path = "user-home" - os.makedirs(home_path, exist_ok=True) - home_mock.return_value = home_path - - with runner.isolated_filesystem(): - # Test that observatory home works - path = observatory_home() - self.assertTrue(os.path.exists(path)) - self.assertEqual(f"{home_path}/.observatory", path) - - # Test that subdirectories are created - path = observatory_home("subfolder") - self.assertTrue(os.path.exists(path)) - self.assertEqual(f"{home_path}/.observatory/subfolder", path) - - def test_terraform_credentials_path(self): - expected_path = os.path.expanduser("~/.terraform.d/credentials.tfrc.json") - actual_path = terraform_credentials_path() - self.assertEqual(expected_path, 
actual_path) - - def test_find_schema(self): - schemas_path = test_fixtures_path("schemas") - test_release_date = pendulum.datetime(2022, 11, 11) - previous_release_date = pendulum.datetime(1950, 11, 11) - - # Nonexistent tables test case - result = bq_find_schema(path=schemas_path, table_name="this_table_does_not_exist") - self.assertIsNone(result) - - result = bq_find_schema(path=schemas_path, table_name="does_not_exist", prefix="this_table") - self.assertIsNone(result) - - result = bq_find_schema( - path=schemas_path, table_name="this_table_does_not_exist", release_date=test_release_date - ) - self.assertIsNone(result) - - result = bq_find_schema( - path=schemas_path, table_name="does_not_exist", release_date=test_release_date, prefix="this_table" - ) - self.assertIsNone(result) - - # Release date on table name that doesn't end in date - result = bq_find_schema(path=schemas_path, table_name="table_a", release_date=test_release_date) - self.assertIsNone(result) - - result = bq_find_schema(path=schemas_path, table_name="a", release_date=test_release_date, prefix="table_") - self.assertIsNone(result) - - # Release date before table date - snapshot_date = pendulum.datetime(year=1000, month=1, day=1) - result = bq_find_schema(path=schemas_path, table_name="table_b", release_date=snapshot_date) - self.assertIsNone(result) - - # Basic test case - no date - expected_schema = "table_a.json" - result = bq_find_schema(path=schemas_path, table_name="table_a") - self.assertIsNotNone(result) - self.assertTrue(result.endswith(expected_schema)) - - # Prefix with no date - expected_schema = "table_a.json" - result = bq_find_schema(path=schemas_path, table_name="a", prefix="table_") - self.assertIsNotNone(result) - self.assertTrue(result.endswith(expected_schema)) - - # Table with date - expected_schema = "table_b_2000-01-01.json" - result = bq_find_schema(path=schemas_path, table_name="table_b", release_date=test_release_date) - self.assertIsNotNone(result) - self.assertTrue(result.endswith(expected_schema)) - - # Table with date and prefix - expected_schema = "table_b_2000-01-01.json" - result = bq_find_schema(path=schemas_path, table_name="b", release_date=test_release_date, prefix="table_") - self.assertIsNotNone(result) - self.assertTrue(result.endswith(expected_schema)) - - # Table with old date - expected_schema = "table_b_1900-01-01.json" - result = bq_find_schema(path=schemas_path, table_name="table_b", release_date=previous_release_date) - self.assertIsNotNone(result) - self.assertTrue(result.endswith(expected_schema)) - - # Table with old date and prefix - expected_schema = "table_b_1900-01-01.json" - result = bq_find_schema(path=schemas_path, table_name="b", release_date=previous_release_date, prefix="table_") - self.assertIsNotNone(result) - self.assertTrue(result.endswith(expected_schema)) diff --git a/tests/observatory/platform/test_observatory_config.py b/tests/observatory/platform/test_observatory_config.py deleted file mode 100644 index ec8cf067a..000000000 --- a/tests/observatory/platform/test_observatory_config.py +++ /dev/null @@ -1,1131 +0,0 @@ -# Copyright 2019 Curtin University. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# Author: James Diprose, Aniek Roelofs - -import datetime -import os -import pathlib -import random -import string -import unittest -from typing import Dict, List - -import pendulum -import yaml -from click.testing import CliRunner - -from observatory.platform.observatory_config import ( - Backend, - BackendType, - CloudSqlDatabase, - Environment, - GoogleCloud, - Observatory, - ObservatoryConfig, - ObservatoryConfigValidator, - Terraform, - TerraformConfig, - VirtualMachine, - WorkflowsProject, - is_base64, - is_fernet_key, - is_secret_key, - make_schema, - save_yaml, - Workflow, - workflows_to_json_string, - json_string_to_workflows, -) - - -class TestObservatoryConfigValidator(unittest.TestCase): - def setUp(self) -> None: - self.schema = dict() - self.schema["google_cloud"] = { - "required": True, - "type": "dict", - "schema": {"credentials": {"required": True, "type": "string", "google_application_credentials": True}}, - } - - def test_workflows_to_json_string(self): - workflows = [ - Workflow( - dag_id="my_dag", - name="My DAG", - class_name="observatory.platform.workflows.vm_workflow.VmCreateWorkflow", - kwargs=dict(dt=pendulum.datetime(2021, 1, 1)), - ) - ] - json_string = workflows_to_json_string(workflows) - self.assertEqual( - '[{"dag_id": "my_dag", "name": "My DAG", "class_name": "observatory.platform.workflows.vm_workflow.VmCreateWorkflow", "cloud_workspace": null, "kwargs": {"dt": "2021-01-01T00:00:00+00:00"}}]', - json_string, - ) - - def test_json_string_to_workflows(self): - json_string = '[{"dag_id": "my_dag", "name": "My DAG", "class_name": "observatory.platform.workflows.vm_workflow.VmCreateWorkflow", "cloud_workspace": null, "kwargs": {"dt": "2021-01-01T00:00:00+00:00"}}]' - actual_workflows = json_string_to_workflows(json_string) - self.assertEqual( - [ - Workflow( - dag_id="my_dag", - name="My DAG", - class_name="observatory.platform.workflows.vm_workflow.VmCreateWorkflow", - kwargs=dict(dt=pendulum.datetime(2021, 1, 1)), - ) - ], - actual_workflows, - ) - - def test_validate_google_application_credentials(self): - """Check if an error occurs for pointing to a file that does not exist when the - 'google_application_credentials' tag is present in the schema.""" - - with CliRunner().isolated_filesystem(): - # Make google application credentials - credentials_file_path = os.path.join(pathlib.Path().absolute(), "google_application_credentials.json") - with open(credentials_file_path, "w") as f: - f.write("") - validator = ObservatoryConfigValidator(self.schema) - - # google_application_credentials tag and existing file - validator.validate({"google_cloud": {"credentials": credentials_file_path}}) - self.assertEqual(len(validator.errors), 0) - - # google_application_credentials tag and non-existing file - validator.validate({"google_cloud": {"credentials": "missing_file.json"}}) - self.assertEqual(len(validator.errors), 1) - - -class TestObservatoryConfig(unittest.TestCase): - def test_load(self): - # Test that a minimal configuration works - dict_ = { - "backend": {"type": "local", "environment": "develop"}, - "observatory": { - "package": 
"observatory-platform", - "package_type": "pypi", - "airflow_fernet_key": "IWt5jFGSw2MD1shTdwzLPTFO16G8iEAU3A6mGo_vJTY=", - "airflow_secret_key": "a" * 16, - }, - } - - file_path = "config-valid-minimal.yaml" - with CliRunner().isolated_filesystem(): - save_yaml(file_path, dict_) - - config = ObservatoryConfig.load(file_path) - self.assertIsInstance(config, ObservatoryConfig) - self.assertTrue(config.is_valid) - - file_path = "config-valid-typical.yaml" - with CliRunner().isolated_filesystem(): - credentials_path = os.path.abspath("creds.json") - open(credentials_path, "a").close() - - # Test that a typical configuration works - dict_ = { - "backend": {"type": "local", "environment": "develop"}, - "google_cloud": { - "project_id": "my-project-id", - "credentials": credentials_path, - "data_location": "us", - }, - "observatory": { - "package": "observatory-platform", - "package_type": "pypi", - "airflow_fernet_key": "IWt5jFGSw2MD1shTdwzLPTFO16G8iEAU3A6mGo_vJTY=", - "airflow_secret_key": "a" * 16, - }, - "workflows_projects": [ - { - "package_name": "academic-observatory-workflows", - "package": "/path/to/academic-observatory-workflows", - "package_type": "editable", - "dags_module": "academic_observatory_workflows.dags", - }, - { - "package_name": "oaebu-workflows", - "package": "/path/to/oaebu-workflows/dist/oaebu-workflows.tar.gz", - "package_type": "sdist", - "dags_module": "oaebu_workflows.dags", - }, - ], - "cloud_workspaces": [ - { - "workspace": { - "project_id": "my-project-id", - "download_bucket": "my-download-bucket", - "transform_bucket": "my-transform-bucket", - "data_location": "us", - } - }, - ], - "workflows": [ - { - "dag_id": "my_dag", - "name": "My DAG", - "cloud_workspace": { - "project_id": "my-project-id", - "download_bucket": "my-download-bucket", - "transform_bucket": "my-transform-bucket", - "data_location": "us", - }, - "class_name": "path.to.my_workflow.Workflow", - "kwargs": { - "hello": "world", - "hello_date": datetime.date(2021, 1, 1), - "hello_datetime": datetime.datetime(2021, 1, 1), - }, - # datetime.date gets converted into 2021-01-01 in yaml, which can be read as a date - # same for datetime.datetime - }, - ], - } - - save_yaml(file_path, dict_) - - config = ObservatoryConfig.load(file_path) - self.assertIsInstance(config, ObservatoryConfig) - self.assertTrue(config.is_valid) - - # Test that date value are parsed into pendulums - workflow: Workflow = config.workflows[0] - hello_date = workflow.kwargs["hello_date"] - self.assertIsInstance(hello_date, pendulum.DateTime) - self.assertEqual(pendulum.datetime(2021, 1, 1), hello_date) - - hello_datetime = workflow.kwargs["hello_datetime"] - self.assertIsInstance(hello_datetime, pendulum.DateTime) - self.assertEqual(pendulum.datetime(2021, 1, 1), hello_datetime) - - # Test that an invalid minimal config works - dict_ = {"backend": {"type": "terraform", "environment": "my-env"}, "airflow": {"fernet_key": False}} - - file_path = "config-invalid-minimal.yaml" - with CliRunner().isolated_filesystem(): - save_yaml(file_path, dict_) - - config = ObservatoryConfig.load(file_path) - self.assertIsInstance(config, ObservatoryConfig) - self.assertFalse(config.is_valid) - - # Test that an invalid typical config is loaded by invalid - dict_ = { - "backend": {"type": "terraform", "environment": "my-env"}, - "google_cloud": { - "project_id": "my-project-id", - "credentials": "/path/to/creds.json", - "data_location": 1, - }, - "observatory": {"airflow_fernet_key": "bad", "airflow_secret_key": "bad"}, - "workflows_projects": [ - 
{ - "package_name": "academic-observatory-workflows", - "package_type": "editable", - "dags_module": "academic_observatory_workflows.dags", - }, - { - "package_name": "oaebu-workflows", - "package": "/path/to/oaebu-workflows/dist/oaebu-workflows.tar.gz", - "package_type": "sdist", - "dags_module": False, - }, - ], - "cloud_workspaces": [ - { - "workspace": { - "download_bucket": "my-download-bucket", - "transform_bucket": "my-transform-bucket", - "data_location": "us", - } - }, - ], - "workflows": [ - { - "name": "My DAG", - "cloud_workspace": { - "project_id": "my-project-id", - "download_bucket": "my-download-bucket", - "transform_bucket": "my-transform-bucket", - "data_location": "us", - }, - "class_name": "path.to.my_workflow.Workflow", - "kwargs": {"hello": "world"}, - }, - ], - } - - file_path = "config-invalid-typical.yaml" - with CliRunner().isolated_filesystem(): - save_yaml(file_path, dict_) - - config = ObservatoryConfig.load(file_path) - self.assertIsInstance(config, ObservatoryConfig) - self.assertFalse(config.is_valid) - self.assertEqual(12, len(config.errors)) - - -class TestTerraformConfig(unittest.TestCase): - def test_load(self): - # Test that a minimal configuration works - - file_path = "config-valid-typical.yaml" - with CliRunner().isolated_filesystem(): - credentials_path = os.path.abspath("creds.json") - open(credentials_path, "a").close() - - dict_ = { - "backend": {"type": "terraform", "environment": "develop"}, - "observatory": { - "package": "observatory-platform", - "package_type": "pypi", - "airflow_fernet_key": "IWt5jFGSw2MD1shTdwzLPTFO16G8iEAU3A6mGo_vJTY=", - "airflow_secret_key": "a" * 16, - "airflow_ui_user_password": "password", - "airflow_ui_user_email": "password", - "postgres_password": "my-password", - }, - "terraform": {"organization": "hello world"}, - "google_cloud": { - "project_id": "my-project", - "credentials": credentials_path, - "region": "us-west1", - "zone": "us-west1-c", - "data_location": "us", - }, - "cloud_sql_database": {"tier": "db-custom-2-7680", "backup_start_time": "23:00"}, - "airflow_main_vm": { - "machine_type": "n2-standard-2", - "disk_size": 1, - "disk_type": "pd-ssd", - "create": True, - }, - "airflow_worker_vm": { - "machine_type": "n2-standard-2", - "disk_size": 1, - "disk_type": "pd-standard", - "create": False, - }, - "cloud_workspaces": [ - { - "workspace": { - "project_id": "my-project-id", - "download_bucket": "my-download-bucket", - "transform_bucket": "my-transform-bucket", - "data_location": "us", - } - }, - ], - "workflows": [ - { - "dag_id": "my_dag", - "name": "My DAG", - "cloud_workspace": { - "project_id": "my-project-id", - "download_bucket": "my-download-bucket", - "transform_bucket": "my-transform-bucket", - "data_location": "us", - }, - "class_name": "path.to.my_workflow.Workflow", - "kwargs": {"hello": "world"}, - }, - ], - } - - save_yaml(file_path, dict_) - - config = TerraformConfig.load(file_path) - self.assertIsInstance(config, TerraformConfig) - self.assertTrue(config.is_valid) - - file_path = "config-valid-typical.yaml" - with CliRunner().isolated_filesystem(): - credentials_path = os.path.abspath("creds.json") - open(credentials_path, "a").close() - - # Test that a typical configuration is loaded - dict_ = { - "backend": {"type": "terraform", "environment": "develop"}, - "observatory": { - "package": "observatory-platform", - "package_type": "pypi", - "airflow_fernet_key": "IWt5jFGSw2MD1shTdwzLPTFO16G8iEAU3A6mGo_vJTY=", - "airflow_secret_key": "a" * 16, - "airflow_ui_user_password": "password", - 
"airflow_ui_user_email": "password", - "postgres_password": "my-password", - }, - "terraform": {"organization": "hello world"}, - "google_cloud": { - "project_id": "my-project", - "credentials": credentials_path, - "region": "us-west1", - "zone": "us-west1-c", - "data_location": "us", - }, - "cloud_sql_database": {"tier": "db-custom-2-7680", "backup_start_time": "23:00"}, - "airflow_main_vm": { - "machine_type": "n2-standard-2", - "disk_size": 1, - "disk_type": "pd-ssd", - "create": True, - }, - "airflow_worker_vm": { - "machine_type": "n2-standard-2", - "disk_size": 1, - "disk_type": "pd-standard", - "create": False, - }, - "workflows_projects": [ - { - "package_name": "academic-observatory-workflows", - "package": "/path/to/academic-observatory-workflows", - "package_type": "editable", - "dags_module": "academic_observatory_workflows.dags", - }, - { - "package_name": "oaebu-workflows", - "package": "/path/to/oaebu-workflows/dist/oaebu-workflows.tar.gz", - "package_type": "sdist", - "dags_module": "oaebu_workflows.dags", - }, - ], - } - - save_yaml(file_path, dict_) - - config = TerraformConfig.load(file_path) - self.assertIsInstance(config, TerraformConfig) - self.assertTrue(config.is_valid) - - # Test that an invalid minimal config is loaded and invalid - dict_ = { - "backend": {"type": "local", "environment": "develop"}, - "airflow": { - "package": "observatory-platform", - "package_type": "pypi", - "fernet_key": "random-fernet-key", - "secret_key": "random-secret-key", - "ui_user_password": "password", - "ui_user_email": "password", - }, - "terraform": {"organization": "hello world"}, - "google_cloud": { - "project_id": "my-project", - "credentials": "/path/to/creds.json", - "region": "us-west", - "zone": "us-west1", - "data_location": "us", - "buckets": { - "download_bucket": "my-download-bucket-1234", - "transform_bucket": "my-transform-bucket-1234", - }, - }, - "cloud_sql_database": {"tier": "db-custom-2-7680", "backup_start_time": "2300"}, - "airflow_main_vm": {"machine_type": "n2-standard-2", "disk_size": 0, "disk_type": "disk", "create": True}, - "airflow_worker_vm": { - "machine_type": "n2-standard-2", - "disk_size": 0, - "disk_type": "disk", - "create": False, - }, - } - - file_path = "config-invalid-minimal.yaml" - with CliRunner().isolated_filesystem(): - save_yaml(file_path, dict_) - - config = TerraformConfig.load(file_path) - self.assertIsInstance(config, TerraformConfig) - self.assertFalse(config.is_valid) - - # Test that an invalid typical config is loaded and invalid - dict_ = { - "backend": {"type": "terraform", "environment": "develop"}, - "airflow": { - "package": "observatory-platform", - "package_type": "pypi", - "fernet_key": "random-fernet-key", - "secret_key": "random-secret-key", - "ui_user_password": "password", - "ui_user_email": "password", - }, - "terraform": {"organization": "hello world"}, - "google_cloud": { - "project_id": "my-project", - "credentials": "/path/to/creds.json", - "region": "us-west1", - "zone": "us-west1-c", - "data_location": "us", - "buckets": { - "download_bucket": "my-download-bucket-1234", - "transform_bucket": "my-transform-bucket-1234", - }, - }, - "cloud_sql_database": {"tier": "db-custom-2-7680", "backup_start_time": "23:00"}, - "airflow_main_vm": {"machine_type": "n2-standard-2", "disk_size": 1, "disk_type": "pd-ssd", "create": True}, - "airflow_worker_vm": { - "machine_type": "n2-standard-2", - "disk_size": 1, - "disk_type": "pd-standard", - "create": False, - }, - "airflow_variables": {"my-variable-name": 1}, - 
"airflow_connections": {"my-connection": "my-token"}, - "workflows_projects": [ - { - "package_name": "academic-observatory-workflows", - "package_type": "editable", - "dags_module": "academic_observatory_workflows.dags", - }, - { - "package_name": "oaebu-workflows", - "package": "/path/to/oaebu-workflows/dist/oaebu-workflows.tar.gz", - "package_type": "sdist", - "dags_module": False, - }, - ], - "cloud_workspaces": [ - { - "workspace": { - "download_bucket": "my-download-bucket", - "transform_bucket": "my-transform-bucket", - "data_location": "us", - } - }, - ], - "workflows": [ - { - "name": "My DAG", - "cloud_workspace": { - "project_id": "my-project-id", - "download_bucket": "my-download-bucket", - "transform_bucket": "my-transform-bucket", - "data_location": "us", - }, - "class_name": "path.to.my_workflow.Workflow", - "kwargs": {"hello": "world"}, - }, - ], - } - - file_path = "config-invalid-typical.yaml" - with CliRunner().isolated_filesystem(): - save_yaml(file_path, dict_) - - config = TerraformConfig.load(file_path) - self.assertIsInstance(config, TerraformConfig) - self.assertFalse(config.is_valid) - self.assertEqual(10, len(config.errors)) - - -class TestSchema(unittest.TestCase): - def assert_sub_schema_valid( - self, valid_docs: List[Dict], invalid_docs: List[Dict], schema, sub_schema_key, expected_errors - ): - validator = ObservatoryConfigValidator() - sub_schema = dict() - sub_schema[sub_schema_key] = schema[sub_schema_key] - - # Assert that docs expected to be valid are valid - for doc in valid_docs: - is_valid = validator.validate(doc, sub_schema) - self.assertTrue(is_valid) - - # Assert that docs that are expected to be invalid are invalid - for doc, error in zip(invalid_docs, expected_errors): - is_valid = validator.validate(doc, sub_schema) - self.assertFalse(is_valid) - self.assertDictEqual(validator.errors, error) - - def assert_schema_keys(self, schema: Dict, contains: List, not_contains: List): - # Assert that keys are in schema - for key in contains: - self.assertTrue(key in schema) - - # Assert that keys aren't in schema - for key in not_contains: - self.assertTrue(key not in schema) - - def test_local_schema_keys(self): - # Test that local schema keys exist and that terraform only keys don't exist - schema = make_schema(BackendType.local) - contains = [ - "backend", - "terraform", - "google_cloud", - "observatory", - "workflows_projects", - "cloud_workspaces", - "workflows", - ] - not_contains = ["cloud_sql_database", "airflow_main_vm", "airflow_worker_vm"] - self.assert_schema_keys(schema, contains, not_contains) - - def test_terraform_schema_keys(self): - # Test that terraform schema keys exist - schema = make_schema(BackendType.terraform) - contains = [ - "backend", - "terraform", - "google_cloud", - "observatory", - "cloud_sql_database", - "airflow_main_vm", - "airflow_worker_vm", - "workflows_projects", - "cloud_workspaces", - "workflows", - ] - not_contains = [] - self.assert_schema_keys(schema, contains, not_contains) - - def test_local_schema_backend(self): - schema = make_schema(BackendType.local) - schema_key = "backend" - - valid_docs = [ - {"backend": {"type": "local", "environment": "develop"}}, - {"backend": {"type": "local", "environment": "staging"}}, - {"backend": {"type": "local", "environment": "production"}}, - ] - invalid_docs = [{"backend": {"type": "terraform", "environment": "hello"}}] - expected_errors = [ - {"backend": [{"environment": ["unallowed value hello"], "type": ["unallowed value terraform"]}]} - ] - 
self.assert_sub_schema_valid(valid_docs, invalid_docs, schema, schema_key, expected_errors) - - def test_local_schema_terraform(self): - schema = make_schema(BackendType.local) - schema_key = "terraform" - - valid_docs = [{}, {"terraform": {"organization": "hello world"}}] - invalid_docs = [{"terraform": {"organization": 0}}, {"terraform": {"organization": dict()}}, {"terraform": {}}] - expected_errors = [ - {"terraform": [{"organization": ["must be of string type"]}]}, - {"terraform": [{"organization": ["must be of string type"]}]}, - {"terraform": [{"organization": ["required field"]}]}, - ] - self.assert_sub_schema_valid(valid_docs, invalid_docs, schema, schema_key, expected_errors) - - def test_local_schema_google_cloud(self): - schema = make_schema(BackendType.local) - schema_key = "google_cloud" - - with CliRunner().isolated_filesystem(): - credentials_path = os.path.abspath("creds.json") - open(credentials_path, "a").close() - - valid_docs = [ - {}, - { - "google_cloud": { - "project_id": "my-project", - "credentials": credentials_path, - "region": "us-west1", - "zone": "us-west1-c", - "data_location": "us", - } - }, - ] - invalid_docs = [ - { - "google_cloud": { - "project_id": 1, - "credentials": "/path/to/creds.json", - "region": "us-west", - "zone": "us-west1", - "data_location": list(), - } - } - ] - - expected_errors = [ - { - "google_cloud": [ - { - "credentials": [ - "the file /path/to/creds.json does not exist. See https://cloud.google.com/docs/authentication/getting-started for instructions on how to create a service account and save the JSON key to your workstation." - ], - "project_id": ["must be of string type"], - "data_location": ["must be of string type"], - "region": ["value does not match regex '^\\w+\\-\\w+\\d+$'"], - "zone": ["value does not match regex '^\\w+\\-\\w+\\d+\\-[a-z]{1}$'"], - } - ] - } - ] - self.assert_sub_schema_valid(valid_docs, invalid_docs, schema, schema_key, expected_errors) - - def test_local_schema_observatory(self): - schema = make_schema(BackendType.local) - schema_key = "observatory" - - valid_docs = [ - { - "observatory": { - "package": "observatory-platform", - "package_type": "pypi", - "airflow_fernet_key": "IWt5jFGSw2MD1shTdwzLPTFO16G8iEAU3A6mGo_vJTY=", - "airflow_secret_key": "a" * 16, - } - }, - { - "observatory": { - "package": "/path/to/observatory-platform", - "package_type": "editable", - "airflow_fernet_key": "IWt5jFGSw2MD1shTdwzLPTFO16G8iEAU3A6mGo_vJTY=", - "airflow_secret_key": "a" * 16, - "airflow_ui_user_password": "password", - "airflow_ui_user_email": "password", - } - }, - ] - invalid_docs = [ - {}, - {"observatory": {"airflow_ui_user_password": "password", "airflow_ui_user_email": "password"}}, - ] - - expected_errors = [ - {"observatory": ["required field"]}, - { - "observatory": [ - { - "package": ["required field"], - "package_type": ["required field"], - "airflow_fernet_key": ["required field"], - "airflow_secret_key": ["required field"], - } - ] - }, - ] - self.assert_sub_schema_valid(valid_docs, invalid_docs, schema, schema_key, expected_errors) - - def test_local_schema_workflows_projects(self): - schema = make_schema(BackendType.local) - schema_key = "workflows_projects" - - valid_docs = [ - {}, - { - "workflows_projects": [ - { - "package_name": "academic-observatory-workflows", - "package": "/path/to/academic-observatory-workflows", - "package_type": "editable", - "dags_module": "academic_observatory_workflows.dags", - }, - { - "package_name": "oaebu-workflows", - "package": 
"/path/to/oaebu-workflows/dist/oaebu-workflows.tar.gz", - "package_type": "sdist", - "dags_module": "oaebu_workflows.dags", - }, - ], - }, - ] - invalid_docs = [ - { - "workflows_projects": [ - { - "package_name": "academic-observatory-workflows", - "package": "/path/to/academic-observatory-workflows", - } - ] - } - ] - - expected_errors = [ - {"workflows_projects": [{0: [{"package_type": ["required field"], "dags_module": ["required field"]}]}]} - ] - self.assert_sub_schema_valid(valid_docs, invalid_docs, schema, schema_key, expected_errors) - - def test_terraform_schema_backend(self): - schema = make_schema(BackendType.terraform) - schema_key = "backend" - - valid_docs = [{"backend": {"type": "terraform", "environment": "develop"}}] - invalid_docs = [{"backend": {"type": "local", "environment": "develop"}}] - expected_errors = [{"backend": [{"type": ["unallowed value local"]}]}] - self.assert_sub_schema_valid(valid_docs, invalid_docs, schema, schema_key, expected_errors) - - def test_terraform_schema_terraform(self): - # Test that terraform is required - schema = make_schema(BackendType.terraform) - schema_key = "terraform" - - valid_docs = [{"terraform": {"organization": "hello world"}}] - invalid_docs = [{}] - expected_errors = [{"terraform": ["required field"]}] - self.assert_sub_schema_valid(valid_docs, invalid_docs, schema, schema_key, expected_errors) - - def test_terraform_schema_google_cloud(self): - # Test that all google cloud fields are required - schema = make_schema(BackendType.terraform) - schema_key = "google_cloud" - - with CliRunner().isolated_filesystem(): - credentials_path = os.path.abspath("creds.json") - open(credentials_path, "a").close() - - valid_docs = [ - { - "google_cloud": { - "project_id": "my-project", - "credentials": credentials_path, - "region": "us-west1", - "zone": "us-west1-c", - "data_location": "us", - } - } - ] - invalid_docs = [{}, {"google_cloud": {}}] - - expected_errors = [ - {"google_cloud": ["required field"]}, - { - "google_cloud": [ - { - "credentials": ["required field"], - "data_location": ["required field"], - "project_id": ["required field"], - "region": ["required field"], - "zone": ["required field"], - } - ] - }, - ] - self.assert_sub_schema_valid(valid_docs, invalid_docs, schema, schema_key, expected_errors) - - def test_terraform_schema_observatory(self): - # Test that airflow ui password and email required - schema = make_schema(BackendType.terraform) - schema_key = "observatory" - - valid_docs = [ - { - "observatory": { - "package": "/path/to/observatory-platform/observatory-platform.tar.gz", - "package_type": "sdist", - "airflow_fernet_key": "IWt5jFGSw2MD1shTdwzLPTFO16G8iEAU3A6mGo_vJTY=", - "airflow_secret_key": "a" * 16, - "airflow_ui_user_password": "password", - "airflow_ui_user_email": "password", - "postgres_password": "password", - } - } - ] - invalid_docs = [{}, {"observatory": {}}] - - expected_errors = [ - {"observatory": ["required field"]}, - { - "observatory": [ - { - "package": ["required field"], - "package_type": ["required field"], - "airflow_fernet_key": ["required field"], - "airflow_secret_key": ["required field"], - "airflow_ui_user_email": ["required field"], - "airflow_ui_user_password": ["required field"], - "postgres_password": ["required field"], - } - ] - }, - ] - self.assert_sub_schema_valid(valid_docs, invalid_docs, schema, schema_key, expected_errors) - - def test_terraform_schema_database(self): - # Test database schema - schema = make_schema(BackendType.terraform) - schema_key = 
"cloud_sql_database" - - valid_docs = [{"cloud_sql_database": {"tier": "db-custom-2-7680", "backup_start_time": "23:00"}}] - invalid_docs = [ - {}, - {"cloud_sql_database": {}}, - {"cloud_sql_database": {"tier": 1, "backup_start_time": "2300"}}, - ] - - expected_errors = [ - {"cloud_sql_database": ["required field"]}, - {"cloud_sql_database": [{"backup_start_time": ["required field"], "tier": ["required field"]}]}, - { - "cloud_sql_database": [ - { - "backup_start_time": ["value does not match regex '^\\d{2}:\\d{2}$'"], - "tier": ["must be of string type"], - } - ] - }, - ] - self.assert_sub_schema_valid(valid_docs, invalid_docs, schema, schema_key, expected_errors) - - def assert_vm_schema(self, schema_key: str): - schema = make_schema(BackendType.terraform) - valid_docs = [ - { - schema_key: { - "machine_type": "n2-standard-2", - "disk_size": 1, - "disk_type": "pd-standard", - "create": False, - } - }, - {schema_key: {"machine_type": "n2-standard-2", "disk_size": 100, "disk_type": "pd-ssd", "create": True}}, - ] - invalid_docs = [ - {}, - {schema_key: {}}, - {schema_key: {"machine_type": 1, "disk_size": 0, "disk_type": "develop", "create": "True"}}, - ] - - expected_errors = [ - {schema_key: ["required field"]}, - { - schema_key: [ - { - "create": ["required field"], - "disk_size": ["required field"], - "disk_type": ["required field"], - "machine_type": ["required field"], - } - ] - }, - { - schema_key: [ - { - "create": ["must be of boolean type"], - "disk_size": ["min value is 1"], - "disk_type": ["unallowed value develop"], - "machine_type": ["must be of string type"], - } - ] - }, - ] - self.assert_sub_schema_valid(valid_docs, invalid_docs, schema, schema_key, expected_errors) - - def test_terraform_schema_vms(self): - # Test VM schema - self.assert_vm_schema("airflow_main_vm") - self.assert_vm_schema("airflow_worker_vm") - - -def tmp_config_file(dict_: dict) -> str: - """ - Dumps dict into a yaml file that is saved in a randomly named file. Used to as config file to create - ObservatoryConfig instance. 
- :param dict_: config dict - :return: path of temporary file - """ - content = yaml.safe_dump(dict_).replace("'!", "!").replace("':", ":") - file_name = "".join(random.choices(string.ascii_lowercase, k=10)) - with open(file_name, "w") as f: - f.write(content) - return file_name - - -class TestObservatoryConfigGeneration(unittest.TestCase): - def test_get_requirement_string(self): - with CliRunner().isolated_filesystem(): - config = ObservatoryConfig() - requirement = config.get_requirement_string("backend") - self.assertEqual(requirement, "Required") - - requirement = config.get_requirement_string("google_cloud") - self.assertEqual(requirement, "Optional") - - def test_save_observatory_config(self): - config = ObservatoryConfig( - terraform=Terraform(organization="myorg"), - backend=Backend(type=BackendType.local, environment=Environment.staging), - observatory=Observatory( - package="observatory-platform", - package_type="editable", - observatory_home="home", - postgres_password="pass", - redis_port=111, - airflow_ui_user_password="pass", - airflow_ui_user_email="email@email", - flower_ui_port=10, - airflow_ui_port=23, - docker_network_name="name", - docker_compose_project_name="proj", - docker_network_is_external=True, - api_package="api", - api_package_type="sdist", - api_port=123, - ), - google_cloud=GoogleCloud( - project_id="myproject", - credentials="config.yaml", - data_location="us", - ), - workflows_projects=[ - WorkflowsProject(package_name="myname", package="path", package_type="editable", dags_module="module") - ], - ) - - with CliRunner().isolated_filesystem(): - file = "config.yaml" - config.save(path=file) - self.assertTrue(os.path.exists(file)) - - loaded = ObservatoryConfig.load(file) - - self.assertEqual(loaded.backend, config.backend) - self.assertEqual(loaded.observatory, config.observatory) - self.assertEqual(loaded.terraform, config.terraform) - self.assertEqual(loaded.google_cloud, config.google_cloud) - self.assertEqual(loaded.workflows_projects, config.workflows_projects) - - def test_save_terraform_config(self): - config = TerraformConfig( - backend=Backend(type=BackendType.terraform, environment=Environment.staging), - observatory=Observatory(package="observatory-platform", package_type="editable"), - google_cloud=GoogleCloud( - project_id="myproject", - credentials="config.yaml", - data_location="us", - region="us-west1", - zone="us-west1-a", - ), - terraform=Terraform(organization="myorg"), - cloud_sql_database=CloudSqlDatabase(tier="test", backup_start_time="12:00"), - airflow_main_vm=VirtualMachine(machine_type="aa", disk_size=1, disk_type="pd-standard", create=False), - airflow_worker_vm=VirtualMachine(machine_type="bb", disk_size=1, disk_type="pd-ssd", create=True), - ) - - file = "config.yaml" - - with CliRunner().isolated_filesystem(): - config.save(path=file) - self.assertTrue(os.path.exists(file)) - loaded = TerraformConfig.load(file) - - self.assertEqual(loaded.backend, config.backend) - self.assertEqual(loaded.terraform, config.terraform) - self.assertEqual(loaded.google_cloud, config.google_cloud) - self.assertEqual(loaded.observatory, config.observatory) - self.assertEqual(loaded.cloud_sql_database, config.cloud_sql_database) - self.assertEqual(loaded.airflow_main_vm, config.airflow_main_vm) - self.assertEqual(loaded.airflow_worker_vm, config.airflow_worker_vm) - - def test_save_observatory_config_defaults(self): - config = ObservatoryConfig( - backend=Backend(type=BackendType.local, environment=Environment.staging), - ) - - with 
CliRunner().isolated_filesystem(): - file = "config.yaml" - config.save(path=file) - self.assertTrue(os.path.exists(file)) - - loaded = ObservatoryConfig.load(file) - self.assertEqual(loaded.backend, config.backend) - self.assertEqual(loaded.terraform, Terraform(organization=None)) - self.assertEqual(loaded.google_cloud.project_id, None) - self.assertEqual(loaded.observatory, config.observatory) - - def test_save_terraform_config_defaults(self): - config = TerraformConfig( - backend=Backend(type=BackendType.terraform, environment=Environment.staging), - observatory=Observatory(), - google_cloud=GoogleCloud( - project_id="myproject", - credentials="config.yaml", - data_location="us", - region="us-west1", - zone="us-west1-a", - ), - terraform=Terraform(organization="myorg"), - ) - - file = "config.yaml" - - with CliRunner().isolated_filesystem(): - config.save(path=file) - self.assertTrue(os.path.exists(file)) - loaded = TerraformConfig.load(file) - - self.assertEqual(loaded.backend, config.backend) - self.assertEqual(loaded.terraform, config.terraform) - self.assertEqual(loaded.google_cloud, config.google_cloud) - self.assertEqual(loaded.observatory, config.observatory) - - self.assertEqual( - loaded.cloud_sql_database, - CloudSqlDatabase( - tier="db-custom-2-7680", - backup_start_time="23:00", - ), - ) - - self.assertEqual( - loaded.airflow_main_vm, - VirtualMachine( - machine_type="n2-standard-2", - disk_size=50, - disk_type="pd-ssd", - create=True, - ), - ) - - self.assertEqual( - loaded.airflow_worker_vm, - VirtualMachine( - machine_type="n1-standard-8", - disk_size=3000, - disk_type="pd-standard", - create=False, - ), - ) - - -class TestKeyCheckers(unittest.TestCase): - def test_is_base64(self): - text = b"bWFrZSB0aGlzIHZhbGlk" - self.assertTrue(is_base64(text)) - - text = b"This is invalid base64" - self.assertFalse(is_base64(text)) - - def test_is_secret_key(self): - text = "invalid length" - valid, message = is_secret_key(text) - self.assertFalse(valid) - self.assertEqual(message, "Secret key should be length >=16, but is length 14.") - - text = "a" * 16 - valid, message = is_secret_key(text) - self.assertTrue(valid) - self.assertEqual(message, None) - - def test_is_fernet_key(self): - text = "invalid key" - valid, message = is_fernet_key(text) - self.assertFalse(valid) - self.assertEqual(message, f"Key {text} could not be urlsafe b64decoded.") - - text = "IWt5jFGSw2MD1shTdwzLPTFO16G8iEAU3A6mGo_vJTY=" - valid, message = is_fernet_key(text) - self.assertTrue(valid) - self.assertEqual(message, None) - - text = "[]}{!*/~inv" * 4 - valid, message = is_fernet_key(text) - self.assertFalse(valid) - self.assertEqual(message, "Decoded Fernet key should be length 32, but is length 12.") diff --git a/tests/observatory/platform/test_observatory_environment.py b/tests/observatory/platform/test_observatory_environment.py deleted file mode 100644 index 274568ddf..000000000 --- a/tests/observatory/platform/test_observatory_environment.py +++ /dev/null @@ -1,857 +0,0 @@ -# Copyright 2021 Curtin University -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-# See the License for the specific language governing permissions and -# limitations under the License. - -# Author: James Diprose, Aniek Roelofs - -from __future__ import annotations - -import contextlib -import logging -import os -import unittest -from datetime import datetime, timedelta, timezone -from pathlib import Path -from typing import List, Union -from unittest.mock import patch -from ftplib import FTP -import tempfile -import json - -import croniter -import httpretty -import pendulum -import pysftp -import timeout_decorator -from airflow.models.connection import Connection -from airflow.models.variable import Variable -from click.testing import CliRunner -from google.cloud.bigquery import SourceFormat -from google.cloud.exceptions import NotFound - -from observatory.platform.bigquery import bq_create_dataset, bq_load_table, bq_table_id -from observatory.platform.config import AirflowVars -from observatory.platform.gcs import gcs_upload_file, gcs_blob_uri -from observatory.platform.observatory_environment import ( - HttpServer, - ObservatoryEnvironment, - ObservatoryTestCase, - SftpServer, - FtpServer, - random_id, - test_fixtures_path, - find_free_port, - load_and_parse_json, -) -from observatory.platform.utils.http_download import ( - DownloadInfo, - download_file, - download_files, -) -from observatory.platform.utils.url_utils import retry_session -from observatory.platform.workflows.workflow import Workflow, Release - -DAG_ID = "telescope-test" -MY_VAR_ID = "my-variable" -MY_CONN_ID = "my-connection" -DAG_FILE_CONTENT = """ -# The keywords airflow and DAG are required to load the DAGs from this file, see bullet 2 in the Apache Airflow FAQ: -# https://airflow.apache.org/docs/stable/faq.html - -from tests.observatory.platform.test_observatory_environment import TelescopeTest - -telescope = TelescopeTest() -globals()['test-telescope'] = telescope.make_dag() -""" - - -class TelescopeTest(Workflow): - """A telescope for testing purposes""" - - def __init__( - self, - dag_id: str = DAG_ID, - start_date: pendulum.DateTime = pendulum.datetime(2020, 9, 1, tz="UTC"), - schedule: str = "@weekly", - setup_task_result: bool = True, - ): - airflow_vars = [ - AirflowVars.DATA_PATH, - MY_VAR_ID, - ] - airflow_conns = [MY_CONN_ID] - super().__init__(dag_id, start_date, schedule, airflow_vars=airflow_vars, airflow_conns=airflow_conns) - self.setup_task_result = setup_task_result - self.add_setup_task(self.check_dependencies) - self.add_setup_task(self.setup_task) - self.add_task(self.my_task) - - def make_release(self, **kwargs) -> Union[Release, List[Release]]: - return None - - def setup_task(self, **kwargs): - logging.info("setup_task success") - return self.setup_task_result - - def my_task(self, release, **kwargs): - logging.info("my_task success") - - -class TestObservatoryEnvironment(unittest.TestCase): - """Test the ObservatoryEnvironment""" - - def __init__(self, *args, **kwargs): - super(TestObservatoryEnvironment, self).__init__(*args, **kwargs) - self.project_id = os.getenv("TEST_GCP_PROJECT_ID") - self.data_location = os.getenv("TEST_GCP_DATA_LOCATION") - - def test_add_bucket(self): - """Test the add_bucket method""" - - env = ObservatoryEnvironment(self.project_id, self.data_location) - - # The download and transform buckets are added in the constructor - buckets = list(env.buckets.keys()) - self.assertEqual(2, len(buckets)) - self.assertEqual(env.download_bucket, buckets[0]) - self.assertEqual(env.transform_bucket, buckets[1]) - - # Test that calling add bucket adds a new 
bucket to the buckets list - name = env.add_bucket() - buckets = list(env.buckets.keys()) - self.assertEqual(name, buckets[-1]) - - # No Google Cloud variables raises error - with self.assertRaises(AssertionError): - ObservatoryEnvironment().add_bucket() - - def test_create_delete_bucket(self): - """Test _create_bucket and _delete_bucket""" - - env = ObservatoryEnvironment(self.project_id, self.data_location) - - bucket_id = "obsenv_tests_" + random_id() - - # Create bucket - env._create_bucket(bucket_id) - bucket = env.storage_client.bucket(bucket_id) - self.assertTrue(bucket.exists()) - - # Delete bucket - env._delete_bucket(bucket_id) - self.assertFalse(bucket.exists()) - - # Test double delete is handled gracefully - env._delete_bucket(bucket_id) - - # Test create a bucket with a set of roles - roles = {"roles/storage.objectViewer", "roles/storage.legacyBucketWriter"} - env._create_bucket(bucket_id, roles=roles) - bucket = env.storage_client.bucket(bucket_id) - bucket_policy = bucket.get_iam_policy() - for role in roles: - self.assertTrue({"role": role, "members": {"allUsers"}} in bucket_policy) - - # No Google Cloud variables raises error - bucket_id = "obsenv_tests_" + random_id() - with self.assertRaises(AssertionError): - ObservatoryEnvironment()._create_bucket(bucket_id) - with self.assertRaises(AssertionError): - ObservatoryEnvironment()._delete_bucket(bucket_id) - - def test_add_delete_dataset(self): - """Test add_dataset and _delete_dataset""" - - # Create dataset - env = ObservatoryEnvironment(self.project_id, self.data_location) - - dataset_id = env.add_dataset() - bq_create_dataset(project_id=self.project_id, dataset_id=dataset_id, location=self.data_location) - - # Check that dataset exists: should not raise NotFound exception - dataset_id = f"{self.project_id}.{dataset_id}" - env.bigquery_client.get_dataset(dataset_id) - - # Delete dataset - env._delete_dataset(dataset_id) - - # Check that dataset doesn't exist - with self.assertRaises(NotFound): - env.bigquery_client.get_dataset(dataset_id) - - # No Google Cloud variables raises error - with self.assertRaises(AssertionError): - ObservatoryEnvironment().add_dataset() - with self.assertRaises(AssertionError): - ObservatoryEnvironment()._delete_dataset(random_id()) - - def test_create(self): - """Tests create, add_variable, add_connection and run_task""" - - expected_state = "success" - - # Setup Telescope - execution_date = pendulum.datetime(year=2020, month=11, day=1) - telescope = TelescopeTest() - dag = telescope.make_dag() - - # Test that previous tasks have to be finished to run next task - env = ObservatoryEnvironment(self.project_id, self.data_location) - - with env.create(task_logging=True): - with env.create_dag_run(dag, execution_date): - # Add_variable - env.add_variable(Variable(key=MY_VAR_ID, val="hello")) - - # Add connection - conn = Connection( - conn_id=MY_CONN_ID, uri="mysql://login:password@host:8080/schema?param1=val1&param2=val2" - ) - env.add_connection(conn) - - # Test run task when dependencies are not met - ti = env.run_task(telescope.setup_task.__name__) - self.assertIsNone(ti.state) - - # Try again when dependencies are met - ti = env.run_task(telescope.check_dependencies.__name__) - self.assertEqual(expected_state, ti.state) - - ti = env.run_task(telescope.setup_task.__name__) - self.assertEqual(expected_state, ti.state) - - ti = env.run_task(telescope.my_task.__name__) - self.assertEqual(expected_state, ti.state) - - # Test that tasks are skipped when setup task returns False - telescope = 
TelescopeTest(setup_task_result=False) - dag = telescope.make_dag() - env = ObservatoryEnvironment(self.project_id, self.data_location) - with env.create(task_logging=True): - with env.create_dag_run(dag, execution_date): - # Add_variable - env.add_variable(Variable(key=MY_VAR_ID, val="hello")) - - # Add connection - conn = Connection( - conn_id=MY_CONN_ID, uri="mysql://login:password@host:8080/schema?param1=val1&param2=val2" - ) - env.add_connection(conn) - - ti = env.run_task(telescope.check_dependencies.__name__) - self.assertEqual(expected_state, ti.state) - - ti = env.run_task(telescope.setup_task.__name__) - self.assertEqual(expected_state, ti.state) - - expected_state = "skipped" - ti = env.run_task(telescope.my_task.__name__) - self.assertEqual(expected_state, ti.state) - - def test_task_logging(self): - """Test task logging""" - - expected_state = "success" - env = ObservatoryEnvironment(self.project_id, self.data_location) - - # Setup Telescope - execution_date = pendulum.datetime(year=2020, month=11, day=1) - telescope = TelescopeTest() - dag = telescope.make_dag() - - # Test environment without logging enabled - with env.create(): - with env.create_dag_run(dag, execution_date): - # Test add_variable - env.add_variable(Variable(key=MY_VAR_ID, val="hello")) - - # Test add_connection - conn = Connection( - conn_id=MY_CONN_ID, uri="mysql://login:password@host:8080/schema?param1=val1&param2=val2" - ) - env.add_connection(conn) - - # Test run task - ti = env.run_task(telescope.check_dependencies.__name__) - self.assertFalse(ti.log.propagate) - self.assertEqual(expected_state, ti.state) - - # Test environment with logging enabled - env = ObservatoryEnvironment(self.project_id, self.data_location) - with env.create(task_logging=True): - with env.create_dag_run(dag, execution_date): - # Test add_variable - env.add_variable(Variable(key=MY_VAR_ID, val="hello")) - - # Test add_connection - conn = Connection( - conn_id=MY_CONN_ID, uri="mysql://login:password@host:8080/schema?param1=val1&param2=val2" - ) - env.add_connection(conn) - - # Test run task - ti = env.run_task(telescope.check_dependencies.__name__) - self.assertTrue(ti.log.propagate) - self.assertEqual(expected_state, ti.state) - - def test_create_dagrun(self): - """Tests create_dag_run""" - - env = ObservatoryEnvironment(self.project_id, self.data_location) - - # Setup Telescope - first_execution_date = pendulum.datetime(year=2020, month=11, day=1, tz="UTC") - second_execution_date = pendulum.datetime(year=2020, month=12, day=1, tz="UTC") - telescope = TelescopeTest() - dag = telescope.make_dag() - - # Get start dates outside of - first_start_date = croniter.croniter(dag.normalized_schedule_interval, first_execution_date).get_next( - pendulum.DateTime - ) - second_start_date = croniter.croniter(dag.normalized_schedule_interval, second_execution_date).get_next( - pendulum.DateTime - ) - - # Use DAG run with freezing time - with env.create(): - # Test add_variable - env.add_variable(Variable(key=MY_VAR_ID, val="hello")) - - # Test add_connection - conn = Connection(conn_id=MY_CONN_ID, uri="mysql://login:password@host:8080/schema?param1=val1&param2=val2") - env.add_connection(conn) - - self.assertIsNone(env.dag_run) - # First DAG Run - with env.create_dag_run(dag, first_execution_date): - # Test DAG Run is set and has frozen start date - self.assertIsNotNone(env.dag_run) - self.assertEqual(first_start_date.date(), env.dag_run.start_date.date()) - - ti1 = env.run_task(telescope.check_dependencies.__name__) - self.assertEqual("success", ti1.state) - 
self.assertIsNone(ti1.previous_ti) - - with env.create_dag_run(dag, second_execution_date): - # Test DAG Run is set and has frozen start date - self.assertIsNotNone(env.dag_run) - self.assertEqual(second_start_date, env.dag_run.start_date) - - ti2 = env.run_task(telescope.check_dependencies.__name__) - self.assertEqual("success", ti2.state) - # Test previous ti is set - self.assertEqual(ti1.job_id, ti2.previous_ti.job_id) - - # Use DAG run without freezing time - env = ObservatoryEnvironment(self.project_id, self.data_location) - with env.create(): - # Test add_variable - env.add_variable(Variable(key=MY_VAR_ID, val="hello")) - - # Test add_connection - conn = Connection(conn_id=MY_CONN_ID, uri="mysql://login:password@host:8080/schema?param1=val1&param2=val2") - env.add_connection(conn) - - # First DAG Run - with env.create_dag_run(dag, first_execution_date): - # Test DAG Run is set and has today as start date - self.assertIsNotNone(env.dag_run) - self.assertEqual(first_start_date, env.dag_run.start_date) - - ti1 = env.run_task(telescope.check_dependencies.__name__) - self.assertEqual("success", ti1.state) - self.assertIsNone(ti1.previous_ti) - - # Second DAG Run - with env.create_dag_run(dag, second_execution_date): - # Test DAG Run is set and has today as start date - self.assertIsNotNone(env.dag_run) - self.assertEqual(second_start_date, env.dag_run.start_date) - - ti2 = env.run_task(telescope.check_dependencies.__name__) - self.assertEqual("success", ti2.state) - # Test previous ti is set - self.assertEqual(ti1.job_id, ti2.previous_ti.job_id) - - def test_create_dag_run_timedelta(self): - env = ObservatoryEnvironment(self.project_id, self.data_location) - - telescope = TelescopeTest(schedule=timedelta(days=1)) - dag = telescope.make_dag() - execution_date = pendulum.datetime(2021, 1, 1) - expected_dag_date = pendulum.datetime(2021, 1, 2) - with env.create(): - with env.create_dag_run(dag, execution_date): - self.assertIsNotNone(env.dag_run) - self.assertEqual(expected_dag_date, env.dag_run.start_date) - - -class TestObservatoryTestCase(unittest.TestCase): - """Test the ObservatoryTestCase class""" - - def __init__(self, *args, **kwargs): - super(TestObservatoryTestCase, self).__init__(*args, **kwargs) - self.project_id = os.getenv("TEST_GCP_PROJECT_ID") - self.data_location = os.getenv("TEST_GCP_DATA_LOCATION") - - def test_assert_dag_structure(self): - """Test assert_dag_structure""" - - test_case = ObservatoryTestCase() - telescope = TelescopeTest() - dag = telescope.make_dag() - - # No assertion error - expected = {"check_dependencies": ["setup_task"], "setup_task": ["my_task"], "my_task": []} - test_case.assert_dag_structure(expected, dag) - - # Raise assertion error - with self.assertRaises(AssertionError): - expected = {"check_dependencies": ["list_releases"], "list_releases": []} - test_case.assert_dag_structure(expected, dag) - - def test_assert_dag_load(self): - """Test assert_dag_load""" - - test_case = ObservatoryTestCase() - env = ObservatoryEnvironment() - with env.create() as temp_dir: - # Write DAG into temp_dir - file_path = os.path.join(temp_dir, f"telescope_test.py") - with open(file_path, mode="w") as f: - f.write(DAG_FILE_CONTENT) - - # DAG loaded successfully: should be no errors - test_case.assert_dag_load(DAG_ID, file_path) - - # Remove DAG from temp_dir - os.unlink(file_path) - - # DAG not loaded - with self.assertRaises(Exception): - test_case.assert_dag_load(DAG_ID, file_path) - - # DAG not found - with self.assertRaises(Exception): - 
test_case.assert_dag_load("dag not found", file_path) - - # Import errors - with self.assertRaises(AssertionError): - test_case.assert_dag_load("no dag found", test_fixtures_path("utils", "bad_dag.py")) - - # No dag - with self.assertRaises(AssertionError): - empty_filename = os.path.join(temp_dir, "empty_dag.py") - Path(empty_filename).touch() - test_case.assert_dag_load("invalid_dag_id", empty_filename) - - def test_assert_blob_integrity(self): - """Test assert_blob_integrity""" - - env = ObservatoryEnvironment(self.project_id, self.data_location) - - with env.create(): - # Upload file to download bucket and check gzip-crc - blob_name = "people.csv" - file_path = test_fixtures_path("utils", blob_name) - result, upload = gcs_upload_file(bucket_name=env.download_bucket, blob_name=blob_name, file_path=file_path) - self.assertTrue(result) - - # Check that blob exists - test_case = ObservatoryTestCase() - test_case.assert_blob_integrity(env.download_bucket, blob_name, file_path) - - # Check that blob doesn't exist - with self.assertRaises(AssertionError): - test_case.assert_blob_integrity(env.transform_bucket, blob_name, file_path) - - def test_assert_table_integrity(self): - """Test assert_table_integrity""" - - env = ObservatoryEnvironment(self.project_id, self.data_location) - - with env.create(): - # Upload file to download bucket and check gzip-crc - blob_name = "people.jsonl" - file_path = test_fixtures_path("utils", blob_name) - result, upload = gcs_upload_file(bucket_name=env.download_bucket, blob_name=blob_name, file_path=file_path) - self.assertTrue(result) - - # Create dataset - dataset_id = env.add_dataset() - bq_create_dataset(project_id=self.project_id, dataset_id=dataset_id, location=self.data_location) - - # Test loading JSON newline table - table_name = random_id() - schema_path = test_fixtures_path("utils", "people_schema.json") - uri = gcs_blob_uri(env.download_bucket, blob_name) - table_id = bq_table_id(self.project_id, dataset_id, table_name) - result = bq_load_table( - uri=uri, - table_id=table_id, - schema_file_path=schema_path, - source_format=SourceFormat.NEWLINE_DELIMITED_JSON, - ) - self.assertTrue(result) - - # Check BigQuery table exists and has expected rows - test_case = ObservatoryTestCase() - table_id = f"{self.project_id}.{dataset_id}.{table_name}" - expected_rows = 5 - test_case.assert_table_integrity(table_id, expected_rows) - - # Check that BigQuery table doesn't exist - with self.assertRaises(AssertionError): - table_id = f"{dataset_id}.{random_id()}" - test_case.assert_table_integrity(table_id, expected_rows) - - # Check that BigQuery table has incorrect rows - with self.assertRaises(AssertionError): - table_id = f"{dataset_id}.{table_name}" - expected_rows = 20 - test_case.assert_table_integrity(table_id, expected_rows) - - def test_assert_table_content(self): - """Test assert table content - - :return: None. 
- """ - - env = ObservatoryEnvironment(self.project_id, self.data_location) - - with env.create(): - # Upload file to download bucket and check gzip-crc - blob_name = "people.jsonl" - file_path = test_fixtures_path("utils", blob_name) - result, upload = gcs_upload_file(bucket_name=env.download_bucket, blob_name=blob_name, file_path=file_path) - self.assertTrue(result) - - # Create dataset - dataset_id = env.add_dataset() - bq_create_dataset(project_id=self.project_id, dataset_id=dataset_id, location=self.data_location) - - # Test loading JSON newline table - table_name = random_id() - schema_path = test_fixtures_path("utils", "people_schema.json") - uri = gcs_blob_uri(env.download_bucket, blob_name) - table_id = bq_table_id(self.project_id, dataset_id, table_name) - result = bq_load_table( - uri=uri, - table_id=table_id, - schema_file_path=schema_path, - source_format=SourceFormat.NEWLINE_DELIMITED_JSON, - ) - self.assertTrue(result) - - # Check BigQuery table exists and has expected rows - test_case = ObservatoryTestCase() - table_id = f"{self.project_id}.{dataset_id}.{table_name}" - expected_content = [ - {"first_name": "Gisella", "last_name": "Derya", "dob": datetime(1997, 7, 1).date()}, - {"first_name": "Adelaida", "last_name": "Melis", "dob": datetime(1980, 9, 3).date()}, - {"first_name": "Melanie", "last_name": "Magomedkhan", "dob": datetime(1990, 3, 1).date()}, - {"first_name": "Octavia", "last_name": "Tomasa", "dob": datetime(1970, 1, 8).date()}, - {"first_name": "Ansgar", "last_name": "Zorion", "dob": datetime(2001, 2, 1).date()}, - ] - test_case.assert_table_content(table_id, expected_content, "first_name") - - # Check that BigQuery table doesn't exist - with self.assertRaises(AssertionError): - table_id = f"{self.project_id}.{dataset_id}.{random_id()}" - test_case.assert_table_content(table_id, expected_content, "first_name") - - # Check that BigQuery table has extra rows - with self.assertRaises(AssertionError): - table_id = f"{dataset_id}.{table_name}" - expected_content = [ - {"first_name": "Gisella", "last_name": "Derya", "dob": datetime(1997, 7, 1).date()}, - {"first_name": "Adelaida", "last_name": "Melis", "dob": datetime(1980, 9, 3).date()}, - {"first_name": "Octavia", "last_name": "Tomasa", "dob": datetime(1970, 1, 8).date()}, - {"first_name": "Ansgar", "last_name": "Zorion", "dob": datetime(2001, 2, 1).date()}, - ] - test_case.assert_table_content(table_id, expected_content, "first_name") - - # Check that BigQuery table has missing rows - with self.assertRaises(AssertionError): - table_id = f"{self.project_id}.{dataset_id}.{table_name}" - expected_content = [ - {"first_name": "Gisella", "last_name": "Derya", "dob": datetime(1997, 7, 1).date()}, - {"first_name": "Adelaida", "last_name": "Melis", "dob": datetime(1980, 9, 3).date()}, - {"first_name": "Melanie", "last_name": "Magomedkhan", "dob": datetime(1990, 3, 1).date()}, - {"first_name": "Octavia", "last_name": "Tomasa", "dob": datetime(1970, 1, 8).date()}, - {"first_name": "Ansgar", "last_name": "Zorion", "dob": datetime(2001, 2, 1).date()}, - {"first_name": "Extra", "last_name": "Row", "dob": datetime(2001, 2, 1).date()}, - ] - test_case.assert_table_content(table_id, expected_content, "first_name") - - def test_assert_file_integrity(self): - """Test assert_file_integrity""" - - test_case = ObservatoryTestCase() - tests_path = test_fixtures_path("utils") - - # Test md5 - file_path = os.path.join(tests_path, "people.csv") - expected_hash = "ad0d7ad3dc3434337cebd5fb543420e7" - algorithm = "md5" - 
test_case.assert_file_integrity(file_path, expected_hash, algorithm) - - # Test gzip-crc - file_path = os.path.join(tests_path, "people.csv.gz") - expected_hash = "3beea5ac" - algorithm = "gzip_crc" - test_case.assert_file_integrity(file_path, expected_hash, algorithm) - - def test_assert_cleanup(self): - """Test assert_cleanup""" - - with CliRunner().isolated_filesystem() as temp_dir: - workflow = os.path.join(temp_dir, "workflow") - - # Make download, extract and transform folders - os.makedirs(workflow) - - # Check that assertion is raised when folders exist - test_case = ObservatoryTestCase() - with self.assertRaises(AssertionError): - test_case.assert_cleanup(workflow) - - # Delete folders - os.rmdir(workflow) - - # No error when folders deleted - test_case.assert_cleanup(workflow) - - def test_setup_mock_file_download(self): - """Test mocking a file download""" - - with CliRunner().isolated_filesystem() as temp_dir: - # Write data into temp_dir - expected_data = "Hello World!" - file_path = os.path.join(temp_dir, f"content.txt") - with open(file_path, mode="w") as f: - f.write(expected_data) - - # Check that content was downloaded from test file - test_case = ObservatoryTestCase() - url = "https://example.com" - with httpretty.enabled(): - test_case.setup_mock_file_download(url, file_path) - response = retry_session().get(url) - self.assertEqual(expected_data, response.content.decode("utf-8")) - - -class TestSftpServer(unittest.TestCase): - def setUp(self) -> None: - self.host = "localhost" - self.port = find_free_port() - - def test_server(self): - """Test that the SFTP server can be connected to""" - - server = SftpServer(host=self.host, port=self.port) - with server.create() as root_dir: - # Connect to SFTP server and disable host key checking - cnopts = pysftp.CnOpts() - cnopts.hostkeys = None - sftp = pysftp.Connection(self.host, port=self.port, username="", password="", cnopts=cnopts) - - # Check that there are no files - files = sftp.listdir(".") - self.assertFalse(len(files)) - - # Add a file and check that it exists - expected_file_name = "onix.xml" - file_path = os.path.join(root_dir, expected_file_name) - with open(file_path, mode="w") as f: - f.write("hello world") - files = sftp.listdir(".") - self.assertEqual(1, len(files)) - self.assertEqual(expected_file_name, files[0]) - - -class TestFtpServer(unittest.TestCase): - def setUp(self) -> None: - self.host = "localhost" - self.port = find_free_port() - - @contextlib.contextmanager - def test_server(self): - """Test that the FTP server can be connected to""" - - with CliRunner().isolated_filesystem() as tmp_dir: - server = FtpServer(directory=tmp_dir, host=self.host, port=self.port) - with server.create() as root_dir: - # Connect to FTP server anonymously - ftp_conn = FTP() - ftp_conn.connect(host=self.host, port=self.port) - ftp_conn.login() - - # Check that there are no files - files = ftp_conn.nlst() - self.assertFalse(len(files)) - - # Add a file and check that it exists - expected_file_name = "textfile.txt" - file_path = os.path.join(root_dir, expected_file_name) - with open(file_path, mode="w") as f: - f.write("hello world") - files = ftp_conn.nlst() - self.assertEqual(1, len(files)) - self.assertEqual(expected_file_name, files[0]) - - @contextlib.contextmanager - def test_user_permissions(self): - "Test the level of permissions of the root and anonymous users." 
- - with CliRunner().isolated_filesystem() as tmp_dir: - server = FtpServer( - directory=tmp_dir, host=self.host, port=self.port, root_username="root", root_password="pass" - ) - with server.create() as root_dir: - # Add a file onto locally hosted server. - expected_file_name = "textfile.txt" - file_path = os.path.join(root_dir, expected_file_name) - with open(file_path, mode="w") as f: - f.write("hello world") - - # Connect to FTP server anonymously. - ftp_conn = FTP() - ftp_conn.connect(host=self.host, port=self.port) - ftp_conn.login() - - # Make sure that anonymoous user has read-only permissions - ftp_repsonse = ftp_conn.sendcmd(f"MLST {expected_file_name}") - self.assertTrue(";perm=r;size=11;type=file;" in ftp_repsonse) - - ftp_conn.close() - - # Connect to FTP server as root user. - ftp_conn = FTP() - ftp_conn.connect(host=self.host, port=self.port) - ftp_conn.login(user="root", passwd="pass") - - # Make sure that root user has all available read/write/modification permissions. - ftp_repsonse = ftp_conn.sendcmd(f"MLST {expected_file_name}") - self.assertTrue(";perm=radfwMT;size=11;type=file;" in ftp_repsonse) - - ftp_conn.close() - - -class TestHttpserver(ObservatoryTestCase): - def test_serve(self): - """Make sure the server can be constructed.""" - with patch("observatory.platform.observatory_environment.ThreadingHTTPServer.serve_forever") as m_serve: - server = HttpServer(directory=".") - server.serve_(("localhost", 10000), ".") - self.assertEqual(m_serve.call_count, 1) - - @timeout_decorator.timeout(1) - def test_stop_before_start(self): - """Make sure there's no deadlock if we try to stop before a start.""" - - server = HttpServer(directory=".") - server.stop() - - @timeout_decorator.timeout(1) - def test_start_twice(self): - """Make sure there's no funny business if we try to stop before a start.""" - - server = HttpServer(directory=".") - server.start() - server.start() - server.stop() - - def test_server(self): - """Test the webserver can serve a directory""" - - directory = test_fixtures_path("utils") - server = HttpServer(directory=directory) - server.start() - - test_file = "http_testfile.txt" - expected_hash = "d8e8fca2dc0f896fd7cb4cb0031ba249" - algorithm = "md5" - - url = f"{server.url}{test_file}" - - with CliRunner().isolated_filesystem() as tmpdir: - dst_file = os.path.join(tmpdir, "testfile.txt") - - download_files(download_list=[DownloadInfo(url=url, filename=dst_file)]) - - self.assert_file_integrity(dst_file, expected_hash, algorithm) - - server.stop() - - def test_context_manager(self): - directory = test_fixtures_path("utils") - server = HttpServer(directory=directory) - - with server.create(): - test_file = "http_testfile.txt" - expected_hash = "d8e8fca2dc0f896fd7cb4cb0031ba249" - algorithm = "md5" - - url = f"{server.url}{test_file}" - - with CliRunner().isolated_filesystem() as tmpdir: - dst_file = os.path.join(tmpdir, "testfile.txt") - download_file(url=url, filename=dst_file) - self.assert_file_integrity(dst_file, expected_hash, algorithm) - - -class TestLoadAndParseJson(unittest.TestCase): - def test_load_and_parse_json(self): - # Create a temporary JSON file - with tempfile.NamedTemporaryFile() as temp_file: - # Create the data dictionary and write to temp file - data = { - "date1": "2022-01-01", - "timestamp1": "2022-01-01 12:00:00.100000 UTC", - "date2": "20230101", - "timestamp2": "2023-01-01 12:00:00", - } - with open(temp_file.name, "w") as f: - json.dump(data, f) - - # Test case 1: Parsing date fields with default date formats. 
Not specifying timestamp fields - expected_result = data.copy() - expected_result["date1"] = datetime(2022, 1, 1).date() - expected_result["date2"] = datetime(2023, 1, 1).date() # Should be converted by pendulum - result = load_and_parse_json(temp_file.name, date_fields=["date1", "date2"], date_formats=["%Y-%m-%d"]) - self.assertEqual(result, expected_result) - - # Test case 2: Parsing timestamp fields with custom timestamp format, not specifying date field - expected_result = data.copy() - expected_result["timestamp1"] = datetime(2022, 1, 1, 12, 0, 0, 100000) - expected_result["timestamp2"] = datetime( - 2023, 1, 1, 12, 0, 0, tzinfo=pendulum.tz.timezone("UTC") - ) # Converted by pendulum - result = load_and_parse_json( - temp_file.name, - timestamp_fields=["timestamp1", "timestamp2"], - timestamp_formats=["%Y-%m-%d %H:%M:%S.%f %Z"], - ) - self.assertEqual(result, expected_result) - - # Test case 3: Default date and timestamp formats - expected_result = { - "date1": datetime(2022, 1, 1).date(), - "date2": "20230101", - "timestamp1": datetime(2022, 1, 1, 12, 0, 0, 100000), - "timestamp2": "2023-01-01 12:00:00", - } - result = load_and_parse_json(temp_file.name, date_fields=["date1"], timestamp_fields=["timestamp1"]) - self.assertEqual(result, expected_result) diff --git a/tests/observatory/platform/utils/__init__.py b/tests/observatory/platform/utils/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/tests/observatory/platform/utils/test_proc_utils.py b/tests/observatory/platform/utils/test_proc_utils.py deleted file mode 100644 index 106cda4a0..000000000 --- a/tests/observatory/platform/utils/test_proc_utils.py +++ /dev/null @@ -1,55 +0,0 @@ -# Copyright 2022 Curtin University -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -# Author: Aniek Roelofs - -import unittest -from unittest.mock import call, patch - -from observatory.platform.utils.proc_utils import stream_process, wait_for_process - - -class TestProcUtils(unittest.TestCase): - @patch("subprocess.Popen") - def test_wait_for_process(self, mock_popen): - proc = mock_popen() - proc.communicate.return_value = "out".encode(), "err".encode() - - out, err = wait_for_process(proc) - self.assertEqual("out", out) - self.assertEqual("err", err) - - @patch("subprocess.Popen") - @patch("builtins.print") - def test_stream_process(self, mock_print, mock_popen): - proc = mock_popen() - proc.stdout = ["out1".encode(), "out2".encode()] - proc.stderr = ["err1".encode(), "err2".encode()] - proc.poll.side_effect = [None, 0, None, 0] - - # Test with debug=False - out, err = stream_process(proc, False) - self.assertEqual("out1out2out1out2", out) - self.assertEqual("err1err2err1err2", err) - self.assertListEqual([call(err, end="") for err in ["err1", "err2", "err1", "err2"]], mock_print.call_args_list) - mock_print.reset_mock() - - # Test with debug=True - out, err = stream_process(proc, True) - self.assertEqual("out1out2out1out2", out) - self.assertEqual("err1err2err1err2", err) - self.assertListEqual( - [call(err, end="") for err in ["out1", "out2", "err1", "err2", "out1", "out2", "err1", "err2"]], - mock_print.call_args_list, - ) diff --git a/tests/observatory/platform/workflows/__init__.py b/tests/observatory/platform/workflows/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/tests/observatory/platform/workflows/test_vm_create.py b/tests/observatory/platform/workflows/test_vm_create.py deleted file mode 100644 index 0512d2938..000000000 --- a/tests/observatory/platform/workflows/test_vm_create.py +++ /dev/null @@ -1,259 +0,0 @@ -from unittest.mock import PropertyMock, patch - -import pendulum -import time_machine -from airflow.models import XCom -from airflow.models.connection import Connection -from airflow.utils.session import provide_session -from airflow.utils.state import State - -from observatory.platform.config import AirflowConns -from observatory.platform.observatory_config import Workflow, VirtualMachine -from observatory.platform.observatory_environment import ( - ObservatoryEnvironment, - ObservatoryTestCase, -) -from observatory.platform.terraform.terraform_api import TerraformVariable -from observatory.platform.workflows.vm_workflow import VmCreateWorkflow, TerraformVirtualMachineAPI - - -@provide_session -def xcom_count(*, execution_date, dag_ids, session=None): - return XCom.get_many( - execution_date=execution_date, - dag_ids=dag_ids, - include_prior_dates=False, - session=session, - ).count() - - -class TestVmCreateWorkflow(ObservatoryTestCase): - """Test the vm_create dag.""" - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self.dag_id = "vm_create" - self.terraform_organisation = "terraform-org" - self.terraform_workspace = "my-terraform-workspace-develop" - - @patch("observatory.platform.workflows.vm_workflow.TerraformApi.list_workspace_variables") - @patch( - "observatory.platform.workflows.vm_workflow.TerraformVirtualMachineAPI.workspace_id", new_callable=PropertyMock - ) - @patch( - "observatory.platform.workflows.vm_workflow.TerraformVirtualMachineAPI.terraform_api", new_callable=PropertyMock - ) - def test_get_vm_info_no_vars(self, m_tapi, m_wid, m_list_vars): - """Test get_vm_info""" - - m_list_vars.return_value = [] - m_wid.return_value = "wid" - api = 
TerraformVirtualMachineAPI(organisation=self.terraform_organisation, workspace=self.terraform_workspace) - vm, vm_var = api.get_vm_info() - self.assertIsNone(vm) - self.assertIsNone(vm_var) - - @patch( - "observatory.platform.workflows.vm_workflow.TerraformVirtualMachineAPI.workspace_id", new_callable=PropertyMock - ) - @patch( - "observatory.platform.workflows.vm_workflow.TerraformVirtualMachineAPI.terraform_api", new_callable=PropertyMock - ) - def test_get_vm_info_no_target_vars(self, m_tapi, m_wid): - """Test get_vm_info""" - - vm = VirtualMachine(machine_type="vm_type", disk_size=10, disk_type="ssd", create=False) - vm_tf = TerraformVariable( - key="not_target", - value=vm.to_hcl(), - hcl=True, - ) - - class MockApi: - def list_workspace_variables(self, *args): - return [vm_tf] - - m_tapi.return_value = MockApi() - api = TerraformVirtualMachineAPI(organisation=self.terraform_organisation, workspace=self.terraform_workspace) - - vm, vm_var = api.get_vm_info() - self.assertIsNone(vm) - self.assertIsNone(vm_var) - - def test_dag_structure(self): - """Test that vm_create has the correct structure. - :return: None - """ - - dag = VmCreateWorkflow( - dag_id=self.dag_id, - terraform_organisation=self.terraform_organisation, - terraform_workspace=self.terraform_workspace, - ).make_dag() - self.assert_dag_structure( - { - "check_dependencies": ["check_vm_state"], - "check_vm_state": ["update_terraform_variable"], - "update_terraform_variable": ["run_terraform"], - "run_terraform": ["check_run_status"], - "check_run_status": ["cleanup"], - "cleanup": [], - }, - dag, - ) - - def test_dag_load(self): - """Test that vm_create can be loaded from a DAG bag. - :return: None - """ - - env = ObservatoryEnvironment( - workflows=[ - Workflow( - dag_id="vm_create", - name="VM Create Workflow", - class_name="observatory.platform.workflows.vm_workflow.VmCreateWorkflow", - kwargs=dict( - terraform_organisation="terraform_organisation", terraform_workspace="terraform_workspace" - ), - ) - ] - ) - - with env.create(): - self.assert_dag_load_from_config("vm_create") - - def setup_env(self, env): - conn = Connection( - conn_id=AirflowConns.SLACK, uri="https://:my-slack-token@https%3A%2F%2Fhooks.slack.com%2Fservices" - ) - env.add_connection(conn) - - conn = Connection(conn_id=AirflowConns.TERRAFORM, uri="http://:apikey@") - env.add_connection(conn) - - @patch("observatory.platform.workflows.vm_workflow.TerraformApi.list_workspace_variables") - @patch("observatory.platform.workflows.vm_workflow.TerraformApi.workspace_id") - def test_workflow_vm_already_on(self, m_tapi, m_list_workspace_vars): - """Test the vm_create workflow""" - - m_tapi.return_value = "workspace" - - vm = VirtualMachine(machine_type="vm_type", disk_size=10, disk_type="ssd", create=True) - m_list_workspace_vars.return_value = [ - TerraformVariable( - key="airflow_worker_vm", - value=vm.to_hcl(), - hcl=True, - ) - ] - - env = ObservatoryEnvironment() - with env.create(): - workflow = VmCreateWorkflow( - dag_id=self.dag_id, - terraform_organisation=self.terraform_organisation, - terraform_workspace=self.terraform_workspace, - ) - dag = workflow.make_dag() - execution_date = pendulum.datetime(2021, 1, 1) - self.setup_env(env) - - with env.create_dag_run(dag, execution_date) as dag_run: - with time_machine.travel(dag_run.start_date, tick=True): - # check dependencies - ti = env.run_task(workflow.check_dependencies.__name__) - self.assertEqual(ti.state, State.SUCCESS) - - # check vm state - ti = env.run_task(workflow.check_vm_state.__name__) - 
self.assertEqual(ti.state, State.SUCCESS) - - # update terraform variable - ti = env.run_task(workflow.update_terraform_variable.__name__) - self.assertEqual(ti.state, State.SKIPPED) - - # run terraform - ti = env.run_task(workflow.run_terraform.__name__) - self.assertEqual(ti.state, State.SKIPPED) - - # check run status - ti = env.run_task(workflow.check_run_status.__name__) - self.assertEqual(ti.state, State.SKIPPED) - - # cleanup - ti = env.run_task(workflow.cleanup.__name__) - self.assertEqual(ti.state, State.SUCCESS) - - @patch("observatory.platform.workflows.vm_workflow.send_slack_msg") - @patch("observatory.platform.workflows.vm_workflow.TerraformApi.get_run_details") - @patch("observatory.platform.workflows.vm_workflow.TerraformApi.create_run") - @patch("observatory.platform.workflows.vm_workflow.TerraformApi.update_workspace_variable") - @patch("observatory.platform.workflows.vm_workflow.TerraformApi.list_workspace_variables") - @patch("observatory.platform.workflows.vm_workflow.TerraformApi.workspace_id") - def test_workflow_vm_create( - self, m_tapi, m_list_workspace_vars, m_update, m_create_run, m_run_details, m_send_slack_msg - ): - "Test the vm_create workflow" - - m_tapi.return_value = "workspace" - - vm = VirtualMachine(machine_type="vm_type", disk_size=10, disk_type="ssd", create=False) - vm_tf = TerraformVariable( - key="airflow_worker_vm", - value=vm.to_hcl(), - hcl=True, - ) - m_list_workspace_vars.return_value = [vm_tf] - m_create_run.return_value = 1 - m_run_details.return_value = {"data": {"attributes": {"status": "planned_and_finished"}}} - - env = ObservatoryEnvironment() - with env.create(): - workflow = VmCreateWorkflow( - dag_id=self.dag_id, - terraform_organisation=self.terraform_organisation, - terraform_workspace=self.terraform_workspace, - ) - dag = workflow.make_dag() - execution_date = pendulum.datetime(2021, 1, 1) - self.setup_env(env) - - with env.create_dag_run(dag, execution_date) as dag_run: - with time_machine.travel(dag_run.start_date, tick=True): - # check dependencies - ti = env.run_task(workflow.check_dependencies.__name__) - self.assertEqual(ti.state, State.SUCCESS) - - # check vm state - ti = env.run_task(workflow.check_vm_state.__name__) - self.assertEqual(ti.state, State.SUCCESS) - - # update terraform variable - ti = env.run_task(workflow.update_terraform_variable.__name__) - self.assertEqual(m_update.call_count, 1) - call_args, _ = m_update.call_args - self.assertEqual(call_args[0], vm_tf) - self.assertEqual(call_args[1], "workspace") - self.assertEqual(ti.state, State.SUCCESS) - - # run terraform - ti = env.run_task(workflow.run_terraform.__name__) - self.assertEqual(ti.state, State.SUCCESS) - self.assertEqual(m_create_run.call_count, 1) - - # check run status - ti = env.run_task(workflow.check_run_status.__name__) - self.assertEqual(ti.state, State.SUCCESS) - self.assertEqual(m_send_slack_msg.call_count, 1) - - # cleanup - ti = env.run_task(workflow.cleanup.__name__) - self.assertEqual(ti.state, State.SUCCESS) - self.assertEqual( - xcom_count( - execution_date=execution_date, - dag_ids=workflow.dag_id, - ), - 5, - ) diff --git a/tests/observatory/platform/workflows/test_vm_destroy.py b/tests/observatory/platform/workflows/test_vm_destroy.py deleted file mode 100644 index da0fbcf01..000000000 --- a/tests/observatory/platform/workflows/test_vm_destroy.py +++ /dev/null @@ -1,1207 +0,0 @@ -# Copyright 2021 Curtin University -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance 
with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# Author: Tuan Chien - -import datetime -import unittest -from unittest.mock import patch - -import pendulum -import time_machine -from airflow.models import XCom -from airflow.models.connection import Connection -from airflow.utils.session import provide_session -from airflow.utils.state import DagRunState, State - -from observatory.platform.config import AirflowConns -from observatory.platform.observatory_config import Workflow, VirtualMachine -from observatory.platform.observatory_environment import ( - ObservatoryEnvironment, - ObservatoryTestCase, -) -from observatory.platform.terraform.terraform_api import TerraformVariable -from observatory.platform.workflows.vm_workflow import ( - VmCreateWorkflow, - VmDestroyWorkflow, - parse_datetime, - XCOM_START_TIME_VM, - XCOM_PREV_START_TIME_VM, - XCOM_DESTROY_TIME_VM, - XCOM_WARNING_TIME, -) - - -@provide_session -def xcom_count(*, execution_date, dag_ids, session=None): - return XCom.get_many( - execution_date=execution_date, - dag_ids=dag_ids, - include_prior_dates=False, - session=session, - ).count() - - -@provide_session -def xcom_push(*, key, value, dag_id, task_id, execution_date, session=None): - XCom.set( - key=key, - value=value, - task_id=task_id, - dag_id=dag_id, - execution_date=execution_date, - session=session, - ) - - -class TestVmDestroyWorkflow(ObservatoryTestCase): - """Test the vm_destroy dag.""" - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self.dag_id = "vm_destroy" - self.terraform_organisation = "terraform-org" - self.terraform_workspace = "my-terraform-workspace-develop" - self.vm_create_dag_id = "vm_create" - - def setup_env(self, env): - conn = Connection( - conn_id=AirflowConns.SLACK, uri="https://:my-slack-token@https%3A%2F%2Fhooks.slack.com%2Fservices" - ) - env.add_connection(conn) - - conn = Connection(conn_id=AirflowConns.TERRAFORM, uri="http://:apikey@") - env.add_connection(conn) - - def test_dag_structure(self): - """Test that vm_create has the correct structure. - :return: None - """ - - dag = VmDestroyWorkflow( - dag_id=self.dag_id, - terraform_organisation=self.terraform_organisation, - terraform_workspace=self.terraform_workspace, - dags_watch_list=[], - ).make_dag() - self.assert_dag_structure( - { - "check_dependencies": ["check_vm_state"], - "check_vm_state": ["check_dags_status"], - "check_dags_status": ["update_terraform_variable"], - "update_terraform_variable": ["run_terraform"], - "run_terraform": ["check_run_status"], - "check_run_status": ["cleanup"], - "cleanup": [], - }, - dag, - ) - - def test_dag_load(self): - """Test that vm_create can be loaded from a DAG bag. 
- :return: None - """ - - env = ObservatoryEnvironment( - workflows=[ - Workflow( - dag_id="vm_destroy", - name="VM Destroy Workflow", - class_name="observatory.platform.workflows.vm_workflow.VmDestroyWorkflow", - kwargs=dict( - terraform_organisation="terraform_organisation", - terraform_workspace="terraform_workspace", - dags_watch_list=[ - "crossref_events", - "crossref_metadata", - "open_citations", - "unpaywall", - "openalex", - ], - ), - ) - ] - ) - - with env.create(): - self.assert_dag_load_from_config("vm_destroy") - - def test_parse_datetime(self): - expected = pendulum.datetime(2021, 1, 1) - actual = parse_datetime("2021-01-01") - self.assertEqual(expected, actual) - - @patch("observatory.platform.workflows.vm_workflow.DagRun.find") - def test_get_last_execution_prev(self, m_drfind): - workflow = VmDestroyWorkflow( - dag_id=self.dag_id, - terraform_organisation=self.terraform_organisation, - terraform_workspace=self.terraform_workspace, - dags_watch_list=[], - ) - - class MockDagRun: - def __init__(self, *, start_date, execution_date, state): - self.start_date = start_date - self.execution_date = execution_date - self.state = state - - class MockDag: - def __init__(self): - self.default_args = {"start_date": datetime.datetime(2000, 1, 1)} - - # No dag runs, no prev start - m_drfind.return_value = [] - ts = workflow._get_last_execution_prev(dag=MockDag(), dag_id="dagid", prev_start_time_vm=None) - self.assertEqual(ts, datetime.datetime(2000, 1, 1)) - - # No dag runs, prev start - m_drfind.return_value = [] - ts = workflow._get_last_execution_prev( - dag=MockDag(), dag_id="dagid", prev_start_time_vm=pendulum.datetime(2001, 1, 1) - ) - self.assertEqual(ts, datetime.datetime(2000, 1, 1)) - - # Dag run running - m_drfind.return_value = [ - MockDagRun( - start_date=datetime.datetime(2000, 2, 1), - execution_date=datetime.datetime(2000, 2, 1), - state=DagRunState.RUNNING, - ) - ] - ts = workflow._get_last_execution_prev( - dag=MockDag(), dag_id="dagid", prev_start_time_vm=pendulum.datetime(2001, 1, 1) - ) - self.assertIsNone(ts) - - # Dag run success, no prev start time vm - m_drfind.return_value = [ - MockDagRun( - start_date=datetime.datetime(2000, 2, 1), - execution_date=datetime.datetime(2000, 3, 1), - state=DagRunState.SUCCESS, - ) - ] - ts = workflow._get_last_execution_prev( - dag=MockDag(), - dag_id="dagid", - prev_start_time_vm=None, - ) - self.assertEqual(ts, datetime.datetime(2000, 1, 1)) - - # Dag run success, dag run start date before prev start time vm - m_drfind.return_value = [ - MockDagRun( - start_date=datetime.datetime(2000, 2, 1), - execution_date=datetime.datetime(2000, 3, 1), - state=DagRunState.SUCCESS, - ) - ] - ts = workflow._get_last_execution_prev( - dag=MockDag(), dag_id="dagid", prev_start_time_vm=pendulum.datetime(2001, 1, 1) - ) - self.assertEqual(ts, datetime.datetime(2000, 3, 1)) - - # Dag run success, dag run start date after prev start time vm - m_drfind.return_value = [ - MockDagRun( - start_date=datetime.datetime(2002, 2, 1), - execution_date=datetime.datetime(2000, 3, 1), - state=DagRunState.SUCCESS, - ) - ] - ts = workflow._get_last_execution_prev( - dag=MockDag(), dag_id="dagid", prev_start_time_vm=pendulum.datetime(2001, 1, 1) - ) - self.assertEqual(ts, datetime.datetime(2000, 1, 1)) - - @patch("observatory.platform.workflows.vm_workflow.DagRun.find") - def test_check_success_run(self, m_drfind): - workflow = VmDestroyWorkflow( - dag_id=self.dag_id, - terraform_organisation=self.terraform_organisation, - 
terraform_workspace=self.terraform_workspace, - dags_watch_list=[], - ) - - class MockDagRun: - def __init__(self, *, start_date, execution_date, state): - self.start_date = start_date - self.execution_date = execution_date - self.state = state - self.dag_id = "dagid" - - # No dates - m_drfind.return_value = [] - execution_dates = [] - status = workflow._check_success_runs(dag_id="dagid", execution_dates=execution_dates) - self.assertTrue(status) - - # Date, no runs - m_drfind.return_value = None - execution_dates = [datetime.datetime(2000, 1, 1)] - status = workflow._check_success_runs(dag_id="dagid", execution_dates=execution_dates) - self.assertFalse(status) - - # Date, run not success - m_drfind.return_value = [ - MockDagRun( - start_date=datetime.datetime(2000, 2, 1), - execution_date=datetime.datetime(2000, 3, 1), - state=DagRunState.FAILED, - ) - ] - execution_dates = [datetime.datetime(2000, 1, 1)] - status = workflow._check_success_runs(dag_id="dagid", execution_dates=execution_dates) - self.assertFalse(status) - - # Date, run success - m_drfind.return_value = [ - MockDagRun( - start_date=datetime.datetime(2000, 2, 1), - execution_date=datetime.datetime(2000, 3, 1), - state=DagRunState.SUCCESS, - ) - ] - execution_dates = [datetime.datetime(2000, 1, 1)] - status = workflow._check_success_runs(dag_id="dagid", execution_dates=execution_dates) - self.assertTrue(status) - - @patch("observatory.platform.workflows.vm_workflow.TerraformApi.list_workspace_variables") - @patch("observatory.platform.workflows.vm_workflow.TerraformApi.workspace_id") - def test_vm_destroy_vm_already_off(self, m_tapi, m_list_workspace_vars): - """Test the vm_destroy workflow""" - - m_tapi.return_value = "workspace" - - vm = VirtualMachine(machine_type="vm_type", disk_size=10, disk_type="ssd", create=False) - m_list_workspace_vars.return_value = [ - TerraformVariable( - key="airflow_worker_vm", - value=vm.to_hcl(), - hcl=True, - ) - ] - - env = ObservatoryEnvironment() - with env.create(): - workflow = VmDestroyWorkflow( - dag_id=self.dag_id, - terraform_organisation=self.terraform_organisation, - terraform_workspace=self.terraform_workspace, - dags_watch_list=[], - ) - dag = workflow.make_dag() - execution_date = pendulum.datetime(2021, 1, 1) - self.setup_env(env) - - with env.create_dag_run(dag, execution_date) as dag_run: - with time_machine.travel(dag_run.start_date, tick=True): - # check dependencies - ti = env.run_task(workflow.check_dependencies.__name__) - self.assertEqual(ti.state, State.SUCCESS) - - # check vm state - ti = env.run_task(workflow.check_vm_state.__name__) - self.assertEqual(ti.state, State.SUCCESS) - - # check dags status - ti = env.run_task(workflow.check_dags_status.__name__) - self.assertEqual(ti.state, State.SKIPPED) - - @unittest.skip("Broken: removing workflow in near future") - @patch("observatory.platform.workflows.vm_workflow.TerraformApi.list_workspace_variables") - @patch("observatory.platform.workflows.vm_workflow.TerraformApi.workspace_id") - def test_vm_destroy_empty_watchlist(self, m_tapi, m_list_workspace_vars): - """Test the vm_destroy workflow""" - - m_tapi.return_value = "workspace" - - vm = VirtualMachine(machine_type="vm_type", disk_size=10, disk_type="ssd", create=True) - m_list_workspace_vars.return_value = [ - TerraformVariable( - key="airflow_worker_vm", - value=vm.to_hcl(), - hcl=True, - ) - ] - - env = ObservatoryEnvironment() - with env.create(): - workflow = VmDestroyWorkflow( - dag_id=self.dag_id, - terraform_organisation=self.terraform_organisation, - 
terraform_workspace=self.terraform_workspace, - dags_watch_list=[], - ) - dag = workflow.make_dag() - execution_date = pendulum.datetime(2021, 1, 1) - self.setup_env(env) - - with env.create_dag_run(dag, execution_date) as dag_run: - with time_machine.travel(dag_run.start_date, tick=True): - # check dependencies - ti = env.run_task(workflow.check_dependencies.__name__) - self.assertEqual(ti.state, State.SUCCESS) - - # check vm state - ti = env.run_task(workflow.check_vm_state.__name__) - self.assertEqual(ti.state, State.SUCCESS) - - # check dags status - ti = env.run_task(workflow.check_dags_status.__name__) - self.assertEqual(ti.state, State.SUCCESS) - - # update terraform variable - with patch( - "observatory.platform.workflows.vm_workflow.TerraformVirtualMachineAPI.update_terraform_vm_create_variable" - ) as m_update: - ti = env.run_task(workflow.update_terraform_variable.__name__) - m_update.assert_called_once_with(False) - self.assertEqual(ti.state, State.SUCCESS) - - # run terraform - with patch("observatory.platform.workflows.vm_workflow.TerraformApi.create_run") as m_create_run: - m_create_run.return_value = "run_id" - ti = env.run_task(workflow.run_terraform.__name__) - call_args, _ = m_create_run.call_args - self.assertEqual(call_args[0], "workspace") - self.assertEqual(call_args[1], "module.airflow_worker_vm") - self.assertEqual(ti.state, State.SUCCESS) - - # check run status - with patch( - "observatory.platform.workflows.vm_workflow.TerraformApi.get_run_details" - ) as m_run_details: - with patch("observatory.platform.workflows.vm_workflow.send_slack_msg") as m_slack: - m_run_details.return_value = {"data": {"attributes": {"status": "planned_and_finished"}}} - - ti = env.run_task(workflow.check_run_status.__name__) - self.assertEqual(ti.state, State.SUCCESS) - _, kwargs = m_slack.call_args - self.assertEqual(kwargs["comments"], "Terraform run status: planned_and_finished") - - # cleanup - ti = env.run_task(workflow.cleanup.__name__) - self.assertEqual(ti.state, State.SUCCESS) - self.assertEqual( - xcom_count( - execution_date=execution_date, - dag_ids=workflow.dag_id, - ), - 2, - ) - - @patch("observatory.platform.workflows.vm_workflow.TerraformApi.list_workspace_variables") - @patch("observatory.platform.workflows.vm_workflow.TerraformApi.workspace_id") - def test_vm_destroy_manual_create(self, m_tapi, m_list_workspace_vars): - """Test the vm_destroy workflow""" - - m_tapi.return_value = "workspace" - - vm = VirtualMachine(machine_type="vm_type", disk_size=10, disk_type="ssd", create=True) - m_list_workspace_vars.return_value = [ - TerraformVariable( - key="airflow_worker_vm", - value=vm.to_hcl(), - hcl=True, - ) - ] - - env = ObservatoryEnvironment() - with env.create(): - workflow = VmDestroyWorkflow( - dag_id=self.dag_id, - terraform_organisation=self.terraform_organisation, - terraform_workspace=self.terraform_workspace, - dags_watch_list=["vm_destroy"], - ) - dag = workflow.make_dag() - execution_date = pendulum.datetime(2021, 1, 1) - self.setup_env(env) - - with env.create_dag_run(dag, execution_date) as dag_run: - with time_machine.travel(dag_run.start_date, tick=True): - # check dependencies - ti = env.run_task(workflow.check_dependencies.__name__) - self.assertEqual(ti.state, State.SUCCESS) - - # check vm state - ti = env.run_task(workflow.check_vm_state.__name__) - self.assertEqual(ti.state, State.SUCCESS) - - # check dags status - ti = env.run_task(workflow.check_dags_status.__name__) - self.assertEqual(ti.state, State.SUCCESS) - - # update terraform variable - ti 
= env.run_task(workflow.update_terraform_variable.__name__) - self.assertEqual(ti.state, State.SKIPPED) - - @unittest.skip("Broken: removing workflow in near future") - @patch("observatory.platform.workflows.vm_workflow.TerraformApi.list_workspace_variables") - @patch("observatory.platform.workflows.vm_workflow.TerraformApi.workspace_id") - def test_vm_destroy_prev_execution_and_start_time(self, m_tapi, m_list_workspace_vars): - """Test the vm_destroy workflow""" - - m_tapi.return_value = "workspace" - - vm = VirtualMachine(machine_type="vm_type", disk_size=10, disk_type="ssd", create=True) - m_list_workspace_vars.return_value = [ - TerraformVariable( - key="airflow_worker_vm", - value=vm.to_hcl(), - hcl=True, - ) - ] - - env = ObservatoryEnvironment() - with env.create(): - workflow = VmDestroyWorkflow( - dag_id=self.dag_id, - terraform_organisation=self.terraform_organisation, - terraform_workspace=self.terraform_workspace, - dags_watch_list=["vm_destroy"], - ) - dag = workflow.make_dag() - execution_date = pendulum.datetime(2021, 1, 1) - self.setup_env(env) - - with env.create_dag_run(dag, execution_date) as dag_run: - with time_machine.travel(dag_run.start_date, tick=True): - # check dependencies - ti = env.run_task(workflow.check_dependencies.__name__) - self.assertEqual(ti.state, State.SUCCESS) - - # check vm state - ti = env.run_task(workflow.check_vm_state.__name__) - self.assertEqual(ti.state, State.SUCCESS) - - # check dags status - xcom_push( - dag_id=self.vm_create_dag_id, - task_id=VmCreateWorkflow.run_terraform.__name__, - execution_date=execution_date, - key=XCOM_START_TIME_VM, - value="2021-01-01", - ) - - xcom_push( - dag_id=self.vm_create_dag_id, - task_id=VmCreateWorkflow.run_terraform.__name__, - execution_date=execution_date, - key=XCOM_PREV_START_TIME_VM, - value="2021-01-01", - ) - - class MockDR: - def __init__(self): - self.start_date = datetime.datetime(2000, 1, 1) - self.execution_date = datetime.datetime(2020, 1, 1) - self.state = DagRunState.SUCCESS - self.dag_id = "dagid" - - class MockDag: - def __init__(self): - self.normalized_schedule_interval = "@weekly" - self.catchup = False - - def previous_schedule(self, *args): - return datetime.datetime(2000, 1, 1) - - def get_run_dates(self, *args): - return [datetime.datetime(2000, 1, 1)] - - with patch("observatory.platform.workflows.vm_workflow.DagRun.find") as m_drfind: - with patch("observatory.platform.workflows.vm_workflow.DagBag.get_dag") as m_getdag: - m_drfind.return_value = [MockDR()] - m_getdag.return_value = MockDag() - - ti = env.run_task(workflow.check_dags_status.__name__) - self.assertEqual(ti.state, State.SUCCESS) - - # update terraform variable - with patch( - "observatory.platform.workflows.vm_workflow.TerraformVirtualMachineAPI.update_terraform_vm_create_variable" - ) as m_update: - ti = env.run_task(workflow.update_terraform_variable.__name__) - self.assertEqual(ti.state, State.SUCCESS) - m_update.assert_called_once_with(False) - - # update terraform variable - with patch( - "observatory.platform.workflows.vm_workflow.TerraformVirtualMachineAPI.create_terraform_run" - ) as m_runterraform: - m_runterraform.return_value = "run_id" - ti = env.run_task(workflow.run_terraform.__name__) - self.assertEqual(ti.state, State.SUCCESS) - m_runterraform.assert_called_once() - - # check run status - with patch( - "observatory.platform.workflows.vm_workflow.TerraformVirtualMachineAPI.check_terraform_run_status" - ) as m_checkrun: - ti = env.run_task(workflow.check_run_status.__name__) - 
self.assertEqual(ti.state, State.SUCCESS) - m_checkrun.assert_called_once() - - # cleanup - ti = env.run_task(workflow.cleanup.__name__) - self.assertEqual(ti.state, State.SUCCESS) - self.assertEqual( - xcom_count( - execution_date=execution_date, - dag_ids=workflow.dag_id, - ), - 2, - ) - - @unittest.skip("Broken: removing workflow in near future") - @patch("observatory.platform.workflows.vm_workflow.TerraformApi.list_workspace_variables") - @patch("observatory.platform.workflows.vm_workflow.TerraformApi.workspace_id") - def test_vm_destroy_prev_execution_and_start_time_ge_destroy_time(self, m_tapi, m_list_workspace_vars): - """Test the vm_destroy workflow""" - - m_tapi.return_value = "workspace" - - vm = VirtualMachine(machine_type="vm_type", disk_size=10, disk_type="ssd", create=True) - m_list_workspace_vars.return_value = [ - TerraformVariable( - key="airflow_worker_vm", - value=vm.to_hcl(), - hcl=True, - ) - ] - - env = ObservatoryEnvironment() - with env.create(): - workflow = VmDestroyWorkflow( - dag_id=self.dag_id, - terraform_organisation=self.terraform_organisation, - terraform_workspace=self.terraform_workspace, - dags_watch_list=["vm_destroy"], - ) - dag = workflow.make_dag() - execution_date = pendulum.datetime(2021, 1, 1) - self.setup_env(env) - - with env.create_dag_run(dag, execution_date) as dag_run: - with time_machine.travel(dag_run.start_date, tick=True): - # check dependencies - ti = env.run_task(workflow.check_dependencies.__name__) - self.assertEqual(ti.state, State.SUCCESS) - - # check vm state - ti = env.run_task(workflow.check_vm_state.__name__) - self.assertEqual(ti.state, State.SUCCESS) - - # check dags status - xcom_push( - dag_id=self.vm_create_dag_id, - task_id=VmCreateWorkflow.run_terraform.__name__, - execution_date=execution_date, - key=XCOM_START_TIME_VM, - value="2021-01-01", - ) - - xcom_push( - dag_id=self.vm_create_dag_id, - task_id=VmCreateWorkflow.run_terraform.__name__, - execution_date=execution_date, - key=XCOM_PREV_START_TIME_VM, - value="2021-01-01", - ) - - xcom_push( - dag_id=self.dag_id, - task_id=VmCreateWorkflow.run_terraform.__name__, - execution_date=execution_date, - key=XCOM_DESTROY_TIME_VM, - value="2021-01-01", - ) - - class MockDR: - def __init__(self): - self.start_date = datetime.datetime(2000, 1, 1) - self.execution_date = datetime.datetime(2020, 1, 1) - self.state = DagRunState.SUCCESS - self.dag_id = "dagid" - - class MockDag: - def __init__(self): - self.normalized_schedule_interval = "@weekly" - self.catchup = False - - def previous_schedule(self, *args): - return datetime.datetime(2000, 1, 1) - - def get_run_dates(self, *args): - return [datetime.datetime(2000, 1, 1)] - - with patch("observatory.platform.workflows.vm_workflow.DagRun.find") as m_drfind: - with patch("observatory.platform.workflows.vm_workflow.DagBag.get_dag") as m_getdag: - m_drfind.return_value = [MockDR()] - m_getdag.return_value = MockDag() - - ti = env.run_task(workflow.check_dags_status.__name__) - self.assertEqual(ti.state, State.SUCCESS) - - # update terraform variable - with patch( - "observatory.platform.workflows.vm_workflow.TerraformVirtualMachineAPI.update_terraform_vm_create_variable" - ) as m_update: - ti = env.run_task(workflow.update_terraform_variable.__name__) - self.assertEqual(ti.state, State.SUCCESS) - m_update.assert_called_once_with(False) - - # update terraform variable - with patch( - "observatory.platform.workflows.vm_workflow.TerraformVirtualMachineAPI.create_terraform_run" - ) as m_runterraform: - m_runterraform.return_value = 
"run_id" - ti = env.run_task(workflow.run_terraform.__name__) - self.assertEqual(ti.state, State.SUCCESS) - m_runterraform.assert_called_once() - - # check run status - with patch( - "observatory.platform.workflows.vm_workflow.TerraformVirtualMachineAPI.check_terraform_run_status" - ) as m_checkrun: - ti = env.run_task(workflow.check_run_status.__name__) - self.assertEqual(ti.state, State.SUCCESS) - m_checkrun.assert_called_once() - - # cleanup - ti = env.run_task(workflow.cleanup.__name__) - self.assertEqual(ti.state, State.SUCCESS) - self.assertEqual( - xcom_count( - execution_date=execution_date, - dag_ids=workflow.dag_id, - ), - 2, - ) - - @unittest.skip("Broken: removing workflow in near future") - @patch("observatory.platform.workflows.vm_workflow.TerraformApi.list_workspace_variables") - @patch("observatory.platform.workflows.vm_workflow.TerraformApi.workspace_id") - def test_vm_destroy_prev_execution_and_start_time_lt_destroy_time(self, m_tapi, m_list_workspace_vars): - """Test the vm_destroy workflow""" - - m_tapi.return_value = "workspace" - - vm = VirtualMachine(machine_type="vm_type", disk_size=10, disk_type="ssd", create=True) - m_list_workspace_vars.return_value = [ - TerraformVariable( - key="airflow_worker_vm", - value=vm.to_hcl(), - hcl=True, - ) - ] - - env = ObservatoryEnvironment() - with env.create(): - workflow = VmDestroyWorkflow( - dag_id=self.dag_id, - terraform_organisation=self.terraform_organisation, - terraform_workspace=self.terraform_workspace, - dags_watch_list=["vm_destroy"], - ) - dag = workflow.make_dag() - execution_date = pendulum.datetime(2021, 1, 1) - self.setup_env(env) - - with env.create_dag_run(dag, execution_date) as dag_run: - with time_machine.travel(dag_run.start_date, tick=True): - # check dependencies - ti = env.run_task(workflow.check_dependencies.__name__) - self.assertEqual(ti.state, State.SUCCESS) - - # check vm state - ti = env.run_task(workflow.check_vm_state.__name__) - self.assertEqual(ti.state, State.SUCCESS) - - # check dags status - xcom_push( - dag_id=self.vm_create_dag_id, - task_id=VmCreateWorkflow.run_terraform.__name__, - execution_date=execution_date, - key=XCOM_START_TIME_VM, - value="2020-01-01", - ) - - xcom_push( - dag_id=self.vm_create_dag_id, - task_id=VmCreateWorkflow.run_terraform.__name__, - execution_date=execution_date, - key=XCOM_PREV_START_TIME_VM, - value="2021-01-01", - ) - - xcom_push( - dag_id=self.dag_id, - task_id=VmCreateWorkflow.run_terraform.__name__, - execution_date=execution_date, - key=XCOM_DESTROY_TIME_VM, - value="2021-01-01", - ) - - class MockDR: - def __init__(self): - self.start_date = datetime.datetime(2000, 1, 1) - self.execution_date = datetime.datetime(2020, 1, 1) - self.state = DagRunState.SUCCESS - self.dag_id = "dagid" - - class MockDag: - def __init__(self): - self.normalized_schedule_interval = "@weekly" - self.catchup = False - - def previous_schedule(self, *args): - return datetime.datetime(2000, 1, 1) - - def get_run_dates(self, *args): - return [datetime.datetime(2000, 1, 1)] - - with patch("observatory.platform.workflows.vm_workflow.DagRun.find") as m_drfind: - with patch("observatory.platform.workflows.vm_workflow.DagBag.get_dag") as m_getdag: - m_drfind.return_value = [MockDR()] - m_getdag.return_value = MockDag() - - ti = env.run_task(workflow.check_dags_status.__name__) - self.assertEqual(ti.state, State.SUCCESS) - - # update terraform variable - with patch( - "observatory.platform.workflows.vm_workflow.TerraformVirtualMachineAPI.update_terraform_vm_create_variable" - ) as 
m_update: - ti = env.run_task(workflow.update_terraform_variable.__name__) - self.assertEqual(ti.state, State.SUCCESS) - m_update.assert_called_once_with(False) - - # update terraform variable - with patch( - "observatory.platform.workflows.vm_workflow.TerraformVirtualMachineAPI.create_terraform_run" - ) as m_runterraform: - m_runterraform.return_value = "run_id" - ti = env.run_task(workflow.run_terraform.__name__) - self.assertEqual(ti.state, State.SUCCESS) - m_runterraform.assert_called_once() - - # check run status - with patch( - "observatory.platform.workflows.vm_workflow.TerraformVirtualMachineAPI.check_terraform_run_status" - ) as m_checkrun: - ti = env.run_task(workflow.check_run_status.__name__) - self.assertEqual(ti.state, State.SUCCESS) - m_checkrun.assert_called_once() - - # cleanup - ti = env.run_task(workflow.cleanup.__name__) - self.assertEqual(ti.state, State.SUCCESS) - self.assertEqual( - xcom_count( - execution_date=execution_date, - dag_ids=workflow.dag_id, - ), - 2, - ) - - @unittest.skip("Broken: removing workflow in near future") - @patch("observatory.platform.workflows.vm_workflow.TerraformApi.list_workspace_variables") - @patch("observatory.platform.workflows.vm_workflow.TerraformApi.workspace_id") - def test_vm_destroy_prev_execution_and_start_time_lt_destroy_time_catchup(self, m_tapi, m_list_workspace_vars): - """Test the vm_destroy workflow""" - - m_tapi.return_value = "workspace" - - vm = VirtualMachine(machine_type="vm_type", disk_size=10, disk_type="ssd", create=True) - m_list_workspace_vars.return_value = [ - TerraformVariable( - key="airflow_worker_vm", - value=vm.to_hcl(), - hcl=True, - ) - ] - - env = ObservatoryEnvironment() - with env.create(): - workflow = VmDestroyWorkflow( - dag_id=self.dag_id, - terraform_organisation=self.terraform_organisation, - terraform_workspace=self.terraform_workspace, - dags_watch_list=["vm_destroy"], - ) - dag = workflow.make_dag() - execution_date = pendulum.datetime(2021, 1, 1) - self.setup_env(env) - - with env.create_dag_run(dag, execution_date) as dag_run: - with time_machine.travel(dag_run.start_date, tick=True): - # check dependencies - ti = env.run_task(workflow.check_dependencies.__name__) - self.assertEqual(ti.state, State.SUCCESS) - - # check vm state - ti = env.run_task(workflow.check_vm_state.__name__) - self.assertEqual(ti.state, State.SUCCESS) - - # check dags status - xcom_push( - dag_id=self.vm_create_dag_id, - task_id=VmCreateWorkflow.run_terraform.__name__, - execution_date=execution_date, - key=XCOM_START_TIME_VM, - value="2020-01-01", - ) - - xcom_push( - dag_id=self.vm_create_dag_id, - task_id=VmCreateWorkflow.run_terraform.__name__, - execution_date=execution_date, - key=XCOM_PREV_START_TIME_VM, - value="2021-01-01", - ) - - xcom_push( - dag_id=self.dag_id, - task_id=VmCreateWorkflow.run_terraform.__name__, - execution_date=execution_date, - key=XCOM_DESTROY_TIME_VM, - value="2021-01-01", - ) - - class MockDR: - def __init__(self): - self.start_date = datetime.datetime(2000, 1, 1) - self.execution_date = datetime.datetime(2020, 1, 1) - self.state = DagRunState.SUCCESS - self.dag_id = "dagid" - - class MockDag: - def __init__(self): - self.normalized_schedule_interval = "@weekly" - self.catchup = True - - def previous_schedule(self, *args): - return datetime.datetime(2000, 1, 1) - - def get_run_dates(self, *args): - return [datetime.datetime(2000, 1, 1)] - - with patch("observatory.platform.workflows.vm_workflow.DagRun.find") as m_drfind: - with 
patch("observatory.platform.workflows.vm_workflow.DagBag.get_dag") as m_getdag: - m_drfind.return_value = [MockDR()] - m_getdag.return_value = MockDag() - - ti = env.run_task(workflow.check_dags_status.__name__) - self.assertEqual(ti.state, State.SUCCESS) - - # update terraform variable - with patch( - "observatory.platform.workflows.vm_workflow.TerraformVirtualMachineAPI.update_terraform_vm_create_variable" - ) as m_update: - ti = env.run_task(workflow.update_terraform_variable.__name__) - self.assertEqual(ti.state, State.SUCCESS) - m_update.assert_called_once_with(False) - - # update terraform variable - with patch( - "observatory.platform.workflows.vm_workflow.TerraformVirtualMachineAPI.create_terraform_run" - ) as m_runterraform: - m_runterraform.return_value = "run_id" - ti = env.run_task(workflow.run_terraform.__name__) - self.assertEqual(ti.state, State.SUCCESS) - m_runterraform.assert_called_once() - - # check run status - with patch( - "observatory.platform.workflows.vm_workflow.TerraformVirtualMachineAPI.check_terraform_run_status" - ) as m_checkrun: - ti = env.run_task(workflow.check_run_status.__name__) - self.assertEqual(ti.state, State.SUCCESS) - m_checkrun.assert_called_once() - - # cleanup - ti = env.run_task(workflow.cleanup.__name__) - self.assertEqual(ti.state, State.SUCCESS) - self.assertEqual( - xcom_count( - execution_date=execution_date, - dag_ids=workflow.dag_id, - ), - 2, - ) - - @unittest.skip("Broken: removing workflow in near future") - @patch("observatory.platform.workflows.vm_workflow.TerraformApi.list_workspace_variables") - @patch("observatory.platform.workflows.vm_workflow.TerraformApi.workspace_id") - def test_vm_destroy_start_time_no_prev_execution(self, m_tapi, m_list_workspace_vars): - """Test the vm_destroy workflow""" - - m_tapi.return_value = "workspace" - - vm = VirtualMachine(machine_type="vm_type", disk_size=10, disk_type="ssd", create=True) - m_list_workspace_vars.return_value = [ - TerraformVariable( - key="airflow_worker_vm", - value=vm.to_hcl(), - hcl=True, - ) - ] - - env = ObservatoryEnvironment() - with env.create(): - workflow = VmDestroyWorkflow( - dag_id=self.dag_id, - terraform_organisation=self.terraform_organisation, - terraform_workspace=self.terraform_workspace, - dags_watch_list=["vm_destroy"], - ) - dag = workflow.make_dag() - execution_date = pendulum.datetime(2021, 1, 1) - self.setup_env(env) - - with env.create_dag_run(dag, execution_date) as dag_run: - with time_machine.travel(dag_run.start_date, tick=True): - # check dependencies - ti = env.run_task(workflow.check_dependencies.__name__) - self.assertEqual(ti.state, State.SUCCESS) - - # check vm state - ti = env.run_task(workflow.check_vm_state.__name__) - self.assertEqual(ti.state, State.SUCCESS) - - # check dags status - xcom_push( - dag_id=self.vm_create_dag_id, - task_id=VmCreateWorkflow.run_terraform.__name__, - execution_date=execution_date, - key=XCOM_START_TIME_VM, - value="2021-01-01", - ) - - class MockDR: - def __init__(self): - self.start_date = datetime.datetime(2000, 1, 1) - self.execution_date = datetime.datetime(2020, 1, 1) - self.state = DagRunState.RUNNING - self.dag_id = "dagid" - - class MockDag: - def __init__(self): - self.normalized_schedule_interval = "@weekly" - self.catchup = False - self.default_args = {"start_date": datetime.datetime(2000, 1, 1)} - - def previous_schedule(self, *args): - return datetime.datetime(2000, 1, 1) - - def get_run_dates(self, *args): - return [datetime.datetime(2000, 1, 1)] - - with 
patch("observatory.platform.workflows.vm_workflow.DagRun.find") as m_drfind: - with patch("observatory.platform.workflows.vm_workflow.DagBag.get_dag") as m_getdag: - m_drfind.return_value = [MockDR()] - m_getdag.return_value = MockDag() - - ti = env.run_task(workflow.check_dags_status.__name__) - self.assertEqual(ti.state, State.SUCCESS) - - # update terraform variable - with patch( - "observatory.platform.workflows.vm_workflow.TerraformVirtualMachineAPI.update_terraform_vm_create_variable" - ) as m_update: - ti = env.run_task(workflow.update_terraform_variable.__name__) - self.assertEqual(ti.state, State.SKIPPED) - - @unittest.skip("Broken: removing workflow in near future") - @patch("observatory.platform.workflows.vm_workflow.TerraformApi.list_workspace_variables") - @patch("observatory.platform.workflows.vm_workflow.TerraformApi.workspace_id") - def test_vm_destroy_dont_destroy_worker_slack_warning(self, m_tapi, m_list_workspace_vars): - """Test the vm_destroy workflow""" - - m_tapi.return_value = "workspace" - - vm = VirtualMachine(machine_type="vm_type", disk_size=10, disk_type="ssd", create=True) - m_list_workspace_vars.return_value = [ - TerraformVariable( - key="airflow_worker_vm", - value=vm.to_hcl(), - hcl=True, - ) - ] - - env = ObservatoryEnvironment() - with env.create(): - workflow = VmDestroyWorkflow( - dag_id=self.dag_id, - terraform_organisation=self.terraform_organisation, - terraform_workspace=self.terraform_workspace, - dags_watch_list=["vm_destroy"], - ) - dag = workflow.make_dag() - execution_date = pendulum.datetime(2021, 1, 1) - self.setup_env(env) - - with env.create_dag_run(dag, execution_date) as dag_run: - with time_machine.travel(dag_run.start_date, tick=True): - # check dependencies - ti = env.run_task(workflow.check_dependencies.__name__) - self.assertEqual(ti.state, State.SUCCESS) - - # check vm state - ti = env.run_task(workflow.check_vm_state.__name__) - self.assertEqual(ti.state, State.SUCCESS) - - # check dags status - xcom_push( - dag_id=self.vm_create_dag_id, - task_id=VmCreateWorkflow.run_terraform.__name__, - execution_date=execution_date, - key=XCOM_START_TIME_VM, - value="2020-12-31", - ) - - xcom_push( - dag_id=self.vm_create_dag_id, - task_id=VmCreateWorkflow.run_terraform.__name__, - execution_date=execution_date, - key=XCOM_PREV_START_TIME_VM, - value="2021-01-01", - ) - - xcom_push( - dag_id=self.dag_id, - task_id=VmDestroyWorkflow.check_runtime_vm.__name__, - execution_date=execution_date, - key=XCOM_WARNING_TIME, - value="2020-01-01", - ) - - class MockDR: - def __init__(self): - self.start_date = datetime.datetime(2000, 1, 1) - self.execution_date = datetime.datetime(2020, 1, 1) - self.state = DagRunState.SUCCESS - self.dag_id = "dagid" - - class MockDag: - def __init__(self): - self.normalized_schedule_interval = "@weekly" - self.catchup = False - - def previous_schedule(self, *args): - return datetime.datetime(2000, 1, 1) - - def get_run_dates(self, *args): - return [datetime.datetime(2000, 1, 1)] - - with patch("observatory.platform.workflows.vm_workflow.DagRun.find") as m_drfind: - with patch("observatory.platform.workflows.vm_workflow.DagBag.get_dag") as m_getdag: - with patch( - "observatory.platform.workflows.vm_workflow.VmDestroyWorkflow._check_success_runs" - ) as m_check_success_runs: - with patch("observatory.platform.workflows.vm_workflow.send_slack_msg") as m_slack: - m_drfind.return_value = [MockDR()] - m_getdag.return_value = MockDag() - m_check_success_runs.return_value = False - - ti = 
env.run_task(workflow.check_dags_status.__name__) - self.assertEqual(ti.state, State.SUCCESS) - - self.assertEqual(m_slack.call_count, 1) - - @patch("observatory.platform.workflows.vm_workflow.TerraformApi.list_workspace_variables") - @patch("observatory.platform.workflows.vm_workflow.TerraformApi.workspace_id") - def test_vm_destroy_dont_destroy_worker_no_slack_warning(self, m_tapi, m_list_workspace_vars): - """Test the vm_destroy workflow""" - - m_tapi.return_value = "workspace" - - vm = VirtualMachine(machine_type="vm_type", disk_size=10, disk_type="ssd", create=True) - m_list_workspace_vars.return_value = [ - TerraformVariable( - key="airflow_worker_vm", - value=vm.to_hcl(), - hcl=True, - ) - ] - - env = ObservatoryEnvironment() - with env.create(): - workflow = VmDestroyWorkflow( - dag_id=self.dag_id, - terraform_organisation=self.terraform_organisation, - terraform_workspace=self.terraform_workspace, - dags_watch_list=["vm_destroy"], - ) - dag = workflow.make_dag() - data_interval_start = pendulum.datetime(2021, 1, 1) - self.setup_env(env) - - with env.create_dag_run(dag, data_interval_start) as dag_run: - with time_machine.travel(dag_run.start_date, tick=True): - # check dependencies - ti = env.run_task(workflow.check_dependencies.__name__) - self.assertEqual(ti.state, State.SUCCESS) - - # check vm state - ti = env.run_task(workflow.check_vm_state.__name__) - self.assertEqual(ti.state, State.SUCCESS) - - # check dags status - class MockDR: - def __init__(self): - self.start_date = datetime.datetime(2000, 1, 1) - self.execution_date = datetime.datetime(2020, 1, 1) - self.state = DagRunState.SUCCESS - self.dag_id = "dagid" - - class MockDag: - def __init__(self): - self.normalized_schedule_interval = "@weekly" - self.catchup = False - - def previous_schedule(self, *args): - return datetime.datetime(2000, 1, 1) - - def get_run_dates(self, *args): - return [datetime.datetime(2000, 1, 1)] - - with patch("observatory.platform.workflows.vm_workflow.DagRun.find") as m_drfind, patch( - "observatory.platform.workflows.vm_workflow.DagBag.get_dag" - ) as m_getdag, patch( - "observatory.platform.workflows.vm_workflow.VmDestroyWorkflow._check_success_runs" - ) as m_check_success_runs, patch( - "observatory.platform.workflows.vm_workflow.send_slack_msg" - ) as m_slack, patch( - "observatory.platform.workflows.vm_workflow.TaskInstance.xcom_pull" - ) as mock_xcom_pull: - mock_xcom_pull.reset_mock() - prev_start_time_vm = "2021-01-01" - start_time_vm = "2020-12-31" - warning_time = "2021-01-01" - # First 2 None xcom usages are from 'check_dependencies' task - mock_xcom_pull.side_effect = [None, None, prev_start_time_vm, start_time_vm, None, warning_time] - - m_drfind.return_value = [MockDR()] - m_getdag.return_value = MockDag() - m_check_success_runs.return_value = False - - ti = env.run_task(workflow.check_dags_status.__name__) - self.assertEqual(ti.state, State.SUCCESS) - self.assertEqual(m_slack.call_count, 0) diff --git a/tests/observatory/platform/workflows/test_workflow.py b/tests/observatory/platform/workflows/test_workflow.py deleted file mode 100644 index e53697b42..000000000 --- a/tests/observatory/platform/workflows/test_workflow.py +++ /dev/null @@ -1,429 +0,0 @@ -# Copyright 2021 Curtin University -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# Author: Tuan Chien, Aniek Roelofs - -import os -from datetime import datetime, timezone -from functools import partial -from tempfile import TemporaryDirectory -from unittest.mock import patch, MagicMock -from copy import deepcopy - -import pendulum -from airflow import DAG -from airflow.exceptions import AirflowNotFoundException, AirflowException -from airflow.models.baseoperator import BaseOperator -from airflow.models.connection import Connection -from airflow.operators.bash import BashOperator -from airflow.operators.python import PythonOperator -from airflow.sensors.external_task import ExternalTaskSensor - -from observatory.platform.observatory_config import CloudWorkspace -from observatory.platform.observatory_environment import ( - ObservatoryEnvironment, - ObservatoryTestCase, - find_free_port, -) -from observatory.platform.workflows.workflow import ( - Release, - Workflow, - make_task_id, - make_workflow_folder, - make_snapshot_date, - cleanup, - set_task_state, - check_workflow_inputs, -) - - -class MockWorkflow(Workflow): - """ - Generic Workflow telescope for running tasks. - """ - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - - def make_release(self, **kwargs) -> Release: - return Release(dag_id=self.dag_id, run_id=kwargs["run_id"]) - - def setup_task(self, **kwargs) -> bool: - return True - - def task(self, release: Release, **kwargs): - pass - - def task5(self, release: Release, **kwargs): - pass - - def task6(self, release: Release, **kwargs): - pass - - -class TestCallbackWorkflow(Workflow): - def __init__( - self, dag_id: str, start_date: pendulum.DateTime, schedule: str, max_retries: int, airflow_conns: list - ): - super().__init__(dag_id, start_date, schedule, max_retries=max_retries, airflow_conns=airflow_conns) - self.add_setup_task(self.check_dependencies) - - def make_release(self, **kwargs): - return - - -class TestWorkflowFunctions(ObservatoryTestCase): - def test_set_task_state(self): - """Test set_task_state""" - - task_id = "test_task" - set_task_state(True, task_id) - with self.assertRaises(AirflowException): - set_task_state(False, task_id) - - @patch("observatory.platform.airflow.Variable.get") - def test_make_workflow_folder(self, mock_get_variable): - """Tests the make_workflow_folder function""" - with TemporaryDirectory() as tempdir: - mock_get_variable.return_value = tempdir - run_id = "scheduled__2023-03-26T00:00:00+00:00" # Also can look like: "manual__2023-03-26T00:00:00+00:00" - path = make_workflow_folder("test_dag", run_id, "sub_folder", "subsub_folder") - self.assertEqual( - path, - os.path.join(tempdir, f"test_dag/scheduled__2023-03-26T00:00:00+00:00/sub_folder/subsub_folder"), - ) - - def test_check_workflow_inputs(self): - """Test check_workflow_inputs""" - # Test Dag ID validity - wf = MagicMock(dag_id="valid") - check_workflow_inputs(wf, check_cloud_workspace=False) # Should pass - for dag_id in ["", None, 42]: # Should all fail - wf.dag_id = dag_id - with self.assertRaises(AirflowException) as cm: - check_workflow_inputs(wf, check_cloud_workspace=False) - msg = cm.exception.args[0] - 
self.assertIn("dag_id", msg) - - # Test when cloud workspace is of wrong type - wf = MagicMock(dag_id="valid", cloud_workspace="invalid") - with self.assertRaisesRegex(AirflowException, "cloud_workspace"): - check_workflow_inputs(wf) - - # Test validity of each part of the cloud workspace - valid_cloud_workspace = CloudWorkspace( - project_id="project_id", - download_bucket="download_bucket", - transform_bucket="transform_bucket", - data_location="data_location", - output_project_id="output_project_id", - ) - wf = MagicMock(dag_id="valid", cloud_workspace=deepcopy(valid_cloud_workspace)) - check_workflow_inputs(wf) # Should pass - for attr, invalid_val in [ - ("project_id", ""), - ("download_bucket", None), - ("transform_bucket", 42), - ("data_location", MagicMock()), - ("output_project_id", ""), - ]: - wf = MagicMock(dag_id="valid", cloud_workspace=deepcopy(valid_cloud_workspace)) - setattr(wf.cloud_workspace, attr, invalid_val) - with self.assertRaisesRegex(AirflowException, f"cloud_workspace.{attr}"): - check_workflow_inputs(wf) - wf = MagicMock(dag_id="valid", cloud_workspace=deepcopy(valid_cloud_workspace)) - wf.cloud_workspace.output_project_id = None - check_workflow_inputs(wf) # This one should pass - - def test_make_snapshot_date(self): - """Test make_table_name""" - - data_interval_end = pendulum.datetime(2021, 11, 11) - expected_date = pendulum.datetime(2021, 11, 11) - actual_date = make_snapshot_date(**{"data_interval_end": data_interval_end}) - self.assertEqual(expected_date, actual_date) - - def test_cleanup(self): - """ - Tests the cleanup function. - Creates a task and pushes and Xcom. Also creates a fake workflow directory. - Both the Xcom and the directory should be deleted by the cleanup() function - """ - - def create_xcom(**kwargs): - ti = kwargs["ti"] - execution_date = kwargs["execution_date"] - ti.xcom_push("topic", {"snapshot_date": execution_date.format("YYYYMMDD"), "something": "info"}) - - env = ObservatoryEnvironment(enable_api=False) - with env.create(): - execution_date = pendulum.datetime(2023, 1, 1) - with DAG( - dag_id="test_dag", - schedule="@daily", - default_args={"owner": "airflow", "start_date": execution_date}, - catchup=True, - ) as dag: - kwargs = {"task_id": "create_xcom"} - op = PythonOperator(python_callable=create_xcom, **kwargs) - - with TemporaryDirectory() as workflow_dir: - # Create some files in the workflow folder - subdir = os.path.join(workflow_dir, "test_directory") - os.mkdir(subdir) - - # DAG Run - with env.create_dag_run(dag=dag, execution_date=execution_date): - ti = env.run_task("create_xcom") - self.assertEqual("success", ti.state) - msgs = ti.xcom_pull(key="topic", task_ids="create_xcom", include_prior_dates=True) - self.assertIsInstance(msgs, dict) - cleanup("test_dag", execution_date, workflow_folder=workflow_dir, retention_days=0) - msgs = ti.xcom_pull(key="topic", task_ids="create_xcom", include_prior_dates=True) - self.assertEqual(msgs, None) - self.assertEqual(os.path.isdir(subdir), False) - self.assertEqual(os.path.isdir(workflow_dir), False) - - -class TestWorkflow(ObservatoryTestCase): - """Tests the Telescope.""" - - def __init__(self, *args, **kwargs): - """Constructor which sets up variables used by tests. - - :param args: arguments. - :param kwargs: keyword arguments. 
- """ - - super().__init__(*args, **kwargs) - self.dag_id = "dag_id" - self.start_date = pendulum.datetime(2020, 1, 1) - self.schedule = "@weekly" - self.project_id = os.getenv("TEST_GCP_PROJECT_ID") - self.data_location = os.getenv("TEST_GCP_DATA_LOCATION") - - self.host = "localhost" - self.port = find_free_port() - - def test_make_task_id(self): - """Test make_task_id""" - - def test_func(): - pass - - # task_id is specified as kwargs - expected_task_id = "hello" - actual_task_id = make_task_id(test_func, {"task_id": expected_task_id}) - self.assertEqual(expected_task_id, actual_task_id) - - # task_id not specified in kwargs - expected_task_id = "test_func" - actual_task_id = make_task_id(test_func, {}) - self.assertEqual(expected_task_id, actual_task_id) - - def dummy_func(self): - pass - - def test_add_operator(self): - workflow = MockWorkflow( - dag_id="1", start_date=datetime(1970, 1, 1, 0, 0, tzinfo=timezone.utc), schedule="@daily" - ) - op1 = ExternalTaskSensor( - external_dag_id="1", task_id="test", start_date=datetime(1970, 1, 1, 0, 0, tzinfo=timezone.utc) - ) - op2 = ExternalTaskSensor( - external_dag_id="1", task_id="test2", start_date=datetime(1970, 1, 1, 0, 0, tzinfo=timezone.utc) - ) - - with workflow.parallel_tasks(): - workflow.add_operator(op1) - workflow.add_operator(op2) - workflow.add_task(self.dummy_func) - dag = workflow.make_dag() - - self.assert_dag_structure({"dummy_func": [], "test": ["dummy_func"], "test2": ["dummy_func"]}, dag) - - def test_workflow_tags(self): - workflow = MockWorkflow( - dag_id="1", - start_date=datetime(1970, 1, 1, 0, 0, tzinfo=timezone.utc), - schedule="@daily", - tags=["oaebu"], - ) - - self.assertEqual(workflow.dag.tags, ["oaebu"]) - - def test_make_dag(self): - """Test making DAG""" - # Test adding tasks from Telescope methods - telescope = MockWorkflow(self.dag_id, self.start_date, self.schedule) - telescope.add_setup_task(telescope.setup_task) - telescope.add_task(telescope.task) - dag = telescope.make_dag() - self.assertIsInstance(dag, DAG) - self.assertEqual(2, len(dag.tasks)) - for task in dag.tasks: - self.assertIsInstance(task, BaseOperator) - - # Test adding tasks from partial Telescope methods - telescope = MockWorkflow(self.dag_id, self.start_date, self.schedule) - for i in range(2): - setup_task = partial(telescope.task, somearg="test") - setup_task.__name__ = f"setup_task_{i}" - telescope.add_setup_task(setup_task) - for i in range(2): - task = partial(telescope.task, somearg="test") - task.__name__ = f"task_{i}" - telescope.add_task(task) - dag = telescope.make_dag() - self.assertIsInstance(dag, DAG) - self.assertEqual(4, len(dag.tasks)) - for task in dag.tasks: - self.assertIsInstance(task, BaseOperator) - - # Test adding tasks with custom kwargs - telescope = MockWorkflow(self.dag_id, self.start_date, self.schedule) - telescope.add_setup_task(telescope.setup_task, trigger_rule="none_failed") - telescope.add_task(telescope.task, trigger_rule="none_failed") - dag = telescope.make_dag() - self.assertIsInstance(dag, DAG) - self.assertEqual(2, len(dag.tasks)) - for task in dag.tasks: - self.assertIsInstance(task, BaseOperator) - self.assertEqual("none_failed", task.trigger_rule) - - # Test adding tasks with custom operator - telescope = MockWorkflow(self.dag_id, self.start_date, self.schedule) - task_id = "bash_task" - telescope.add_operator(BashOperator(task_id=task_id, bash_command="echo 'hello'")) - dag = telescope.make_dag() - self.assertIsInstance(dag, DAG) - self.assertEqual(1, len(dag.tasks)) - for task in dag.tasks: - 
self.assertEqual(task_id, task.task_id) - self.assertIsInstance(task, BashOperator) - - # Test adding parallel tasks - telescope = MockWorkflow(self.dag_id, self.start_date, self.schedule) - telescope.add_setup_task(telescope.setup_task) - with telescope.parallel_tasks(): - telescope.add_task(telescope.task, task_id="task1") - telescope.add_task(telescope.task, task_id="task2") - telescope.add_task(telescope.task, task_id="join1") - with telescope.parallel_tasks(): - telescope.add_task(telescope.task, task_id="task3") - telescope.add_task(telescope.task, task_id="task4") - telescope.add_task(telescope.task, task_id="join2") - with telescope.parallel_tasks(): - telescope.add_task(telescope.task5) - telescope.add_task(telescope.task6) - - dag = telescope.make_dag() - self.assertIsInstance(dag, DAG) - self.assertEqual(9, len(dag.tasks)) - self.assert_dag_structure( - { - "setup_task": ["task1", "task2"], - "task1": ["join1"], - "task2": ["join1"], - "join1": ["task3", "task4"], - "task3": ["join2"], - "task4": ["join2"], - "join2": ["task5", "task6"], - "task5": [], - "task6": [], - }, - dag, - ) - - # Test parallel tasks function - telescope = MockWorkflow(self.dag_id, self.start_date, self.schedule) - self.assertFalse(telescope._parallel_tasks) - with telescope.parallel_tasks(): - self.assertTrue(telescope._parallel_tasks) - self.assertFalse(telescope._parallel_tasks) - - def test_telescope(self): - """Basic test to make sure that the Workflow class can execute in an Airflow environment.""" - # Setup Observatory environment - env = ObservatoryEnvironment(self.project_id, self.data_location, api_host=self.host, api_port=self.port) - - # Create the Observatory environment and run tests - with env.create(task_logging=True): - task1 = "task1" - task2 = "task2" - expected_date = "success" - - workflow = MockWorkflow(self.dag_id, self.start_date, self.schedule) - workflow.add_setup_task(workflow.setup_task) - workflow.add_task(workflow.task, task_id=task1) - workflow.add_operator(BashOperator(task_id=task2, bash_command="echo 'hello'")) - - dag = workflow.make_dag() - with env.create_dag_run(dag, self.start_date): - ti = env.run_task(workflow.setup_task.__name__) - self.assertEqual(expected_date, ti.state) - - ti = env.run_task(task1) - self.assertEqual(expected_date, ti.state) - - ti = env.run_task(task2) - self.assertEqual(expected_date, ti.state) - - @patch("observatory.platform.airflow.send_slack_msg") - def test_callback(self, mock_send_slack_msg): - """Test that the on_failure_callback function is successfully called in a production environment when a task - fails - - :param mock_send_slack_msg: Mock send_slack_msg function - :return: None. 
- """ - # mock_send_slack_msg.return_value = Mock(spec=SlackWebhookHook) - - # Setup Observatory environment - env = ObservatoryEnvironment(self.project_id, self.data_location) - - # Setup Workflow with 0 retries and missing airflow variable, so it will fail the task - execution_date = pendulum.datetime(2020, 1, 1) - conn_id = "orcid_bucket" - workflow = TestCallbackWorkflow( - "test_callback", - execution_date, - self.schedule, - max_retries=0, - airflow_conns=[conn_id], - ) - dag = workflow.make_dag() - - # Create the Observatory environment and run task, expecting slack webhook call in production environment - with env.create(task_logging=True): - with env.create_dag_run(dag, execution_date): - with self.assertRaises(AirflowNotFoundException): - env.run_task(workflow.check_dependencies.__name__) - - _, callkwargs = mock_send_slack_msg.call_args - self.assertTrue( - "airflow.exceptions.AirflowNotFoundException: The conn_id `orcid_bucket` isn't defined" - in callkwargs["comments"] - ) - - # Reset mock - mock_send_slack_msg.reset_mock() - - # Add orcid_bucket connection and test that Slack Web Hook did not get triggered - with env.create(task_logging=True): - with env.create_dag_run(dag, execution_date): - env.add_connection(Connection(conn_id=conn_id, uri="https://orcid.org/")) - env.run_task(workflow.check_dependencies.__name__) - mock_send_slack_msg.assert_not_called()