Skip to content

Commit

Permalink
tests: added fixtures for task plugin archives
Browse files Browse the repository at this point in the history
Used fixtures to pass a tmp dir containing a .tar file to represent a task plugin archive.

Closes #316
  • Loading branch information
andrewhand committed Dec 1, 2023
1 parent 3259a40 commit 41fdcde
Show file tree
Hide file tree
Showing 2 changed files with 66 additions and 168 deletions.
92 changes: 0 additions & 92 deletions tests/unit/restapi/test_job.py

This file was deleted.

142 changes: 66 additions & 76 deletions tests/unit/restapi/test_task_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,18 +23,43 @@
from __future__ import annotations

import tarfile
from typing import Any
from typing import Any, BinaryIO, Dict

import pytest
from flask.testing import FlaskClient
from flask_sqlalchemy import SQLAlchemy
from werkzeug.test import TestResponse

from dioptra.restapi.task_plugin.routes import BASE_ROUTE as TASK_PLUGIN_BASE_ROUTE


# -- Fixtures --------------------------------------------------------------------------


@pytest.fixture
def task_plugin_builtin_request_form(task_plugin_name: str, task_plugin_archive: BinaryIO) -> Dict[str, Any]:
return {
"task_plugin_name": "builtin_plugin",
"task_plugin_file": task_plugin_archive,
"collection": "dioptra_builtins"
}

@pytest.fixture
def task_plugin_custom_request_form(task_plugin_name: str, task_plugin_archive: BinaryIO) -> Dict[str, Any]:
return {
"task_plugin_name": "custom_plugin",
"task_plugin_file": task_plugin_archive,
"collection": "dioptra_custom"
}


# -- Actions ---------------------------------------------------------------------------


def register_task_plugin(client: FlaskClient, task_plugin_name: str, task_plugin_file: Any, collection: str) -> TestResponse:
def register_task_plugin(
client: FlaskClient,
task_plugin_request_form: Dict[str, Any],
) -> TestResponse:
"""Register a task plugin package using the API.
Args:
Expand All @@ -49,11 +74,7 @@ def register_task_plugin(client: FlaskClient, task_plugin_name: str, task_plugin
"""
return client.post(
f"/api/{TASK_PLUGIN_BASE_ROUTE}/",
json={
"task_plugin_name": task_plugin_name,
"task_plugin_file": task_plugin_file,
"collection": collection
},
json=task_plugin_request_form,
follow_redirects=True,
)

Expand Down Expand Up @@ -180,9 +201,7 @@ def assert_retrieving_all_task_plugins_works(

def assert_registering_existing_task_plugin_name_fails(
client: FlaskClient,
task_plugin_name: str,
task_plugin_file: Any,
collection: str,
task_plugin_request_form: Dict[str, Any],
) -> None:
"""Assert that registering a task plugin with an existing name in the same collection fails.
Expand All @@ -193,15 +212,15 @@ def assert_registering_existing_task_plugin_name_fails(
Raises:
AssertionError: If the response status code is not 400.
"""
response = register_task_plugin(client, task_plugin_name=task_plugin_name, task_plugin_file=task_plugin_file, collection=collection)
response = register_task_plugin(client, task_plugin_request_form)
assert response.status_code == 400


def assert_custom_task_plugin_not_found(
client: FlaskClient,
task_plugin_name: str,
) -> None:
"""Assert that a queue is not found.
"""Assert that a task plugin package is not found.
Args:
client: The Flask test client.
Expand All @@ -220,12 +239,17 @@ def assert_custom_task_plugin_not_found(
# -- Tests -----------------------------------------------------------------------------


def test_task_plugin_registration(client: FlaskClient, db: SQLAlchemy) -> None:
def test_task_plugin_registration(
client: FlaskClient,
db: SQLAlchemy,
task_plugin_builtin_request_form: Dict[str, Any],
task_plugin_custom_request_form: Dict[str, Any],
) -> None:
"""Test that task plugin packages can be registered and retrieved using the API.
This test validates the following sequence of actions:
- A user registers four task plugin packages, 2 builtins "artifacts" and "attacks", 2 custom "custom_fgm_plugin" and "custom_patch_plugin".
- A user registers a builtin task plugin package and a custom task plugins package.
- The user is able to retrieve information about each task plugin package using its unique name.
- The user is able to retrieve a list of all registered builtin task plugin packages.
- The user is able to retrieve a list of all registered custom task plugin packages.
Expand All @@ -234,97 +258,63 @@ def test_task_plugin_registration(client: FlaskClient, db: SQLAlchemy) -> None:
during registration.
"""

tar1 = tarfile.open("artifacts.tar", "w")
tar1.add("file1.py")
tar1.add("file2.py")
tar1.close()
tar2 = tarfile.open("attacks.tar", "w")
tar2.add("file1.py")
tar2.add("file2.py")
tar2.close()

tar3 = tarfile.open("custom_fgm_plugin.tar", "w")
tar3.add("file1.py")
tar3.add("file2.py")
tar3.close()
tar4 = tarfile.open("custom_patch_plugin.tar", "w")
tar4.add("file1.py")
tar4.add("file2.py")
tar4.close()

plugin1_response = register_task_plugin(client, task_plugin_name="artifacts", task_plugin_file=tar1, collection="dioptra_builtins")
plugin2_response = register_task_plugin(client, task_plugin_name="attacks", task_plugin_file=tar2, collection="dioptra_builtins")
plugin3_response = register_task_plugin(client, task_plugin_name="custom_fgm_plugin", task_plugin_file=tar3, collection="dioptra_custom")
plugin4_response = register_task_plugin(client, task_plugin_name="custom_patch_plugin", task_plugin_file=tar4, collection="dioptra_custom")
plugin1_response = register_task_plugin(client, task_plugin_builtin_request_form)
plugin2_response = register_task_plugin(client, task_plugin_custom_request_form)
plugin1_expected = plugin1_response.get_json()
plugin2_expected = plugin2_response.get_json()
plugin3_expected = plugin3_response.get_json()
plugin4_expected = plugin4_response.get_json()
builtins_expected_list = [plugin1_expected, plugin2_expected]
custom_expected_list = [plugin3_expected, plugin4_expected]
all_expected_list = [plugin1_expected, plugin2_expected, plugin3_expected, plugin4_expected]
builtins_expected_list = [plugin1_expected]
custom_expected_list = [plugin2_expected]
all_expected_list = [plugin1_expected, plugin2_expected]
assert_retrieving_builtins_task_plugin_by_name_works(
client, task_plugin_name=plugin1_expected["taskPluginName"], expected=plugin1_expected
)
assert_retrieving_builtins_task_plugin_by_name_works(
client, task_plugin_name=plugin2_expected["taskPluginName"], expected=plugin2_expected
)
assert_retrieving_custom_task_plugin_by_name_works(
client, task_plugin_name=plugin3_expected["taskPluginName"], expected=plugin3_expected
)
assert_retrieving_custom_task_plugin_by_name_works(
client, task_plugin_name=plugin4_expected["taskPluginName"], expected=plugin4_expected
)
assert_retrieving_all_builtins_task_plugins_works(client, expected=builtins_expected_list)
assert_retrieving_all_custom_task_plugins_works(client, expected=custom_expected_list)
assert_retrieving_all_task_plugins_works(client, expected=all_expected_list)


def test_cannot_register_existing_queue_name(
client: FlaskClient, db: SQLAlchemy
def test_cannot_register_existing_task_plugin_name(
client: FlaskClient,
db: SQLAlchemy,
task_plugin_builtin_request_form: Dict[str, Any],
task_plugin_custom_request_form: Dict[str, Any],
) -> None:
"""Test that registering a task plugin package with an existing name fails.
This test validates the following sequence of actions:
- A user registers a builtin task plugin named "attacks".
- A user registers a custom task plugin named "custom_fgm_plugin".
- A user registers a builtin task plugin.
- A user registers a custom task plugin.
- The user attempts to register a second builtin task plugin with the same name, which fails.
- The user attempts to register a second custom task plugin with the same name, which fails.
"""
tar1 = tarfile.open("artifacts.tar", "w")
tar1.add("file1.py")
tar1.add("file2.py")
tar1.close()
tar2 = tarfile.open("custom_fgm_plugin.tar", "w")
tar2.add("file1.py")
tar2.add("file2.py")
tar2.close()

register_task_plugin(client, task_plugin_name="artifacts", task_plugin_file=tar1, collection="dioptra_builtins")
assert_registering_existing_task_plugin_name_fails(client, task_plugin_name="artifacts", task_plugin_file=tar1, collection="dioptra_builtins")
register_task_plugin(client, task_plugin_name="custom_fgm_plugin", task_plugin_file=tar2, collection="dioptra_custom")
assert_registering_existing_task_plugin_name_fails(client, task_plugin_name="custom_fgm_plugin", task_plugin_file=tar1, collection="dioptra_custom")

register_task_plugin(client, task_plugin_builtin_request_form)
register_task_plugin(client, task_plugin_custom_request_form)
assert_registering_existing_task_plugin_name_fails(client, task_plugin_builtin_request_form)
assert_registering_existing_task_plugin_name_fails(client, task_plugin_custom_request_form)

def test_delete_custom_task_plugin_by_name(client: FlaskClient, db: SQLAlchemy) -> None:
"""Test that a queue can be deleted by referencing its name.
def test_delete_custom_task_plugin_by_name(
client: FlaskClient,
db: SQLAlchemy,
task_plugin_custom_request_form: Dict[str, Any],
) -> None:
"""Test that a task plugin can be deleted by referencing its name.
This test validates the following sequence of actions:
- A user registers a queue named "tensorflow_cpu".
- The user is able to retrieve information about the "tensorflow_cpu" queue that
matches the information that was provided during registration.
- The user deletes the "tensorflow_cpu" queue by referencing its name.
- The user attempts to retrieve information about the "tensorflow_cpu" queue, which
- A user registers a task plugin.
- The user is able to retrieve information about the task plugin that
matches the name that was provided during registration.
- The user deletes the task plugin by referencing its name.
- The user attempts to retrieve information about the task plugin by name, which
is no longer found.
"""
tar = tarfile.open("artifacts.tar", "w")
tar.add("file1.py")
tar.add("file2.py")
tar.close()

registration_response = register_task_plugin(client, task_plugin_name="custom_fgm_plugin", task_plugin_file=tar, collection="dioptra_custom")
registration_response = register_task_plugin(client, task_plugin_custom_request_form)
task_plugin_json = registration_response.get_json()
assert_retrieving_custom_task_plugin_by_name_works(
client, task_plugin_name=task_plugin_json["taskPluginName"], expected=task_plugin_json
Expand Down

0 comments on commit 41fdcde

Please sign in to comment.