From ac6e8127f50df90b8cdadd2e4d8d95bc8168b97b Mon Sep 17 00:00:00 2001 From: Ninad Date: Thu, 19 Aug 2021 23:56:28 +1000 Subject: [PATCH] NEW: Add `/nodes` endpoint (#25) Add `/nodes` endpoint supporting GET and POST methods for creating and querying nodes of the AiiDA graph. Some node types, such as `SinglefileData` require a file to be passed to the node constructor. In order to avoid complications with multiple requests and be able to send both the JSON metadata and the file in the same request, a `/nodes/singlefile` endpoint is added that accepts content type `multipart/form-data`. This introduces a bit of inconsistency in the use of the API (which is otherwise `application/json`) but is is a practical workaround for the time being until a better solution is identified. --- .pre-commit-config.yaml | 1 + aiida_restapi/main.py | 3 +- aiida_restapi/models.py | 137 ++++++++ aiida_restapi/routers/groups.py | 1 + aiida_restapi/routers/nodes.py | 97 ++++++ aiida_restapi/routers/users.py | 1 + conftest.py | 11 + setup.json | 15 +- tests/test_computers.py | 1 + tests/test_graphql/test_full/test_full.yml | 2 +- .../test_computer_get_entities.yml | 3 +- tests/test_nodes.py | 312 ++++++++++++++++++ 12 files changed, 580 insertions(+), 4 deletions(-) create mode 100644 aiida_restapi/routers/nodes.py create mode 100644 tests/test_nodes.py diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index bc77f5f..0719a74 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -58,6 +58,7 @@ repos: - passlib - pytest~=3.6,<5.0.0 - sphinx<4 + - importlib_metadata~=4.3 exclude: > (?x)^( docs/.*| diff --git a/aiida_restapi/main.py b/aiida_restapi/main.py index abeec6c..5a3c295 100644 --- a/aiida_restapi/main.py +++ b/aiida_restapi/main.py @@ -3,11 +3,12 @@ from fastapi import FastAPI from aiida_restapi.graphql import main -from aiida_restapi.routers import auth, computers, groups, users +from aiida_restapi.routers import auth, computers, groups, nodes, users app = FastAPI() app.include_router(auth.router) app.include_router(computers.router) +app.include_router(nodes.router) app.include_router(groups.router) app.include_router(users.router) app.add_route("/graphql", main.app, name="graphql") diff --git a/aiida_restapi/models.py b/aiida_restapi/models.py index 5c011af..2994ca8 100644 --- a/aiida_restapi/models.py +++ b/aiida_restapi/models.py @@ -6,17 +6,47 @@ """ # pylint: disable=too-few-public-methods +import inspect +import io from datetime import datetime from typing import ClassVar, Dict, List, Optional, Type, TypeVar from uuid import UUID from aiida import orm +from aiida.restapi.common.identifiers import load_entry_point_from_full_type +from fastapi import Form from pydantic import BaseModel, Field # Template type for subclasses of `AiidaModel` ModelType = TypeVar("ModelType", bound="AiidaModel") +def as_form(cls: Type[BaseModel]) -> Type[BaseModel]: + """ + Adds an as_form class method to decorated models. The as_form class method + can be used with FastAPI endpoints + + Note: Taken from https://github.com/tiangolo/fastapi/issues/2387 + """ + new_params = [ + inspect.Parameter( + field.alias, + inspect.Parameter.POSITIONAL_ONLY, + default=(Form(field.default) if not field.required else Form(...)), + ) + for field in cls.__fields__.values() + ] + + async def _as_form(**data: Dict) -> BaseModel: + return cls(**data) + + sig = inspect.signature(_as_form) + sig = sig.replace(parameters=new_params) + _as_form.__signature__ = sig # type: ignore + setattr(cls, "as_form", _as_form) + return cls + + class AiidaModel(BaseModel): """A mapping of an AiiDA entity to a pydantic model.""" @@ -26,6 +56,7 @@ class Config: """The models configuration.""" orm_mode = True + extra = "forbid" @classmethod def get_projectable_properties(cls) -> List[str]: @@ -88,6 +119,11 @@ class User(AiidaModel): _orm_entity = orm.User + class Config: + """The models configuration.""" + + extra = "allow" + id: Optional[int] = Field(description="Unique user id (pk)") email: str = Field(description="Email address of the user") first_name: Optional[str] = Field(description="First name of the user") @@ -119,6 +155,107 @@ class Computer(AiidaModel): description="General settings for these communication and management protocols" ) + description: Optional[str] = Field(description="Description of node") + + +class Node(AiidaModel): + """AiiDA Node Model.""" + + _orm_entity = orm.Node + + id: Optional[int] = Field(description="Unique id (pk)") + uuid: Optional[UUID] = Field(description="Unique uuid") + node_type: Optional[str] = Field(description="Node type") + process_type: Optional[str] = Field(description="Process type") + label: str = Field(description="Label of node") + description: Optional[str] = Field(description="Description of node") + ctime: Optional[datetime] = Field(description="Creation time") + mtime: Optional[datetime] = Field(description="Last modification time") + user_id: Optional[int] = Field(description="Created by user id (pk)") + dbcomputer_id: Optional[int] = Field(description="Associated computer id (pk)") + attributes: Optional[Dict] = Field( + description="Variable attributes of the node", + ) + extras: Optional[Dict] = Field( + description="Variable extras (unsealed) of the node", + ) + + +@as_form +class Node_Post(AiidaModel): + """AiiDA model for posting Nodes.""" + + node_type: Optional[str] = Field(description="Node type") + process_type: Optional[str] = Field(description="Process type") + label: str = Field(description="Label of node") + description: Optional[str] = Field(description="Description of node") + user_id: Optional[int] = Field(description="Created by user id (pk)") + dbcomputer_id: Optional[int] = Field(description="Associated computer id (pk)") + attributes: Optional[Dict] = Field( + description="Variable attributes of the node", + ) + extras: Optional[Dict] = Field( + description="Variable extras (unsealed) of the node", + ) + + @classmethod + def create_new_node( + cls: Type[ModelType], + node_type: str, + node_dict: dict, + ) -> orm.Node: + "Create and Store new Node" + + orm_class = load_entry_point_from_full_type(node_type) + attributes = node_dict.pop("attributes", {}) + extras = node_dict.pop("extras", {}) + + if issubclass(orm_class, orm.BaseType): + orm_object = orm_class( + attributes["value"], + **node_dict, + ) + elif issubclass(orm_class, orm.Dict): + orm_object = orm_class( + dict=attributes, + **node_dict, + ) + elif issubclass(orm_class, orm.Code): + orm_object = orm_class() + orm_object.set_remote_computer_exec( + ( + orm.Computer.get(id=node_dict.get("dbcomputer_id")), + attributes["remote_exec_path"], + ) + ) + orm_object.label = node_dict.get("label") + else: + orm_object = load_entry_point_from_full_type(node_type)(**node_dict) + orm_object.set_attribute_many(attributes) + + orm_object.set_extra_many(extras) + orm_object.store() + return orm_object + + @classmethod + def create_new_node_with_file( + cls: Type[ModelType], + node_type: str, + node_dict: dict, + file: bytes, + ) -> orm.Node: + "Create and Store new Node with file" + attributes = node_dict.pop("attributes", {}) + extras = node_dict.pop("extras", {}) + + orm_object = load_entry_point_from_full_type(node_type)( + file=io.BytesIO(file), **node_dict, **attributes + ) + + orm_object.set_extra_many(extras) + orm_object.store() + return orm_object + class Group(AiidaModel): """AiiDA Group model.""" diff --git a/aiida_restapi/routers/groups.py b/aiida_restapi/routers/groups.py index 6fcab8d..2bfd41b 100644 --- a/aiida_restapi/routers/groups.py +++ b/aiida_restapi/routers/groups.py @@ -41,6 +41,7 @@ async def read_group(group_id: int) -> Optional[Group]: @router.post("/groups", response_model=Group) +@with_dbenv() async def create_group( group: Group_Post, current_user: User = Depends( diff --git a/aiida_restapi/routers/nodes.py b/aiida_restapi/routers/nodes.py new file mode 100644 index 0000000..82d96f6 --- /dev/null +++ b/aiida_restapi/routers/nodes.py @@ -0,0 +1,97 @@ +# -*- coding: utf-8 -*- +"""Declaration of FastAPI application.""" +from typing import List, Optional + +from aiida import orm +from aiida.cmdline.utils.decorators import with_dbenv +from fastapi import APIRouter, Depends, File, HTTPException +from importlib_metadata import entry_points + +from aiida_restapi import models + +from .auth import get_current_active_user + +router = APIRouter() + +ENTRY_POINTS = entry_points() + + +@router.get("/nodes", response_model=List[models.Node]) +@with_dbenv() +async def read_nodes() -> List[models.Node]: + """Get list of all nodes""" + return models.Node.get_entities() + + +@router.get("/nodes/projectable_properties", response_model=List[str]) +async def get_nodes_projectable_properties() -> List[str]: + """Get projectable properties for nodes endpoint""" + + return models.Node.get_projectable_properties() + + +@router.get("/nodes/{nodes_id}", response_model=models.Node) +@with_dbenv() +async def read_node(nodes_id: int) -> Optional[models.Node]: + """Get nodes by id.""" + qbobj = orm.QueryBuilder() + + qbobj.append(orm.Node, filters={"id": nodes_id}, project=["**"], tag="node").limit( + 1 + ) + return qbobj.dict()[0]["node"] + + +@router.post("/nodes", response_model=models.Node) +@with_dbenv() +async def create_node( + node: models.Node_Post, + current_user: models.User = Depends( + get_current_active_user + ), # pylint: disable=unused-argument +) -> models.Node: + """Create new AiiDA node.""" + + node_dict = node.dict(exclude_unset=True) + node_type = node_dict.pop("node_type", None) + + try: + (entry_point_node,) = ENTRY_POINTS.select( + group="aiida.rest.post", name=node_type + ) + except ValueError as exc: + raise HTTPException( + status_code=404, detail="Entry point '{}' not recognized.".format(node_type) + ) from exc + + try: + orm_object = entry_point_node.load().create_new_node(node_type, node_dict) + except (TypeError, ValueError, KeyError) as err: + raise HTTPException(status_code=400, detail="Error: {0}".format(err)) from err + + return models.Node.from_orm(orm_object) + + +@router.post("/nodes/singlefile", response_model=models.Node) +@with_dbenv() +async def create_upload_file( + upload_file: bytes = File(...), + params: models.Node_Post = Depends(models.Node_Post.as_form), # type: ignore # pylint: disable=maybe-no-member + current_user: models.User = Depends( + get_current_active_user + ), # pylint: disable=unused-argument +) -> models.Node: + """Endpoint for uploading file data""" + node_dict = params.dict(exclude_unset=True, exclude_none=True) + node_type = node_dict.pop("node_type", None) + + try: + (entry_point_node,) = entry_points(group="aiida.rest.post", name=node_type) + except KeyError as exc: + raise KeyError("Entry point '{}' not recognized.".format(node_type)) from exc + + orm_object = entry_point_node.load().create_new_node_with_file( + node_type, node_dict, upload_file + ) + + return models.Node.from_orm(orm_object) diff --git a/aiida_restapi/routers/users.py b/aiida_restapi/routers/users.py index 6128f65..9fe08d5 100644 --- a/aiida_restapi/routers/users.py +++ b/aiida_restapi/routers/users.py @@ -39,6 +39,7 @@ async def read_user(user_id: int) -> Optional[User]: @router.post("/users", response_model=User) +@with_dbenv() async def create_user( user: User, current_user: User = Depends( diff --git a/conftest.py b/conftest.py index a59ddb9..e057b09 100644 --- a/conftest.py +++ b/conftest.py @@ -67,6 +67,17 @@ def default_groups(): return [group_1.id, group_2.id] +@pytest.fixture(scope="function") +def default_nodes(): + """Populate database with some nodes.""" + node_1 = orm.Int(1).store() + node_2 = orm.Float(1.1).store() + node_3 = orm.Str("test_string").store() + node_4 = orm.Bool(False).store() + + return [node_1.id, node_2.id, node_3.id, node_4.id] + + @pytest.fixture(scope="function") def authenticate(): """Authenticate user. diff --git a/setup.json b/setup.json index bc17fc7..2917591 100644 --- a/setup.json +++ b/setup.json @@ -19,6 +19,18 @@ ], "aiida.cmdline.data": [ "restapi = aiida_restapi.cli:data_cli" + ], + "aiida.rest.post": [ + "data.str.Str.| = aiida_restapi.models:Node_Post", + "data.float.Float.| = aiida_restapi.models:Node_Post", + "data.int.Int.| = aiida_restapi.models:Node_Post", + "data.bool.Bool.| = aiida_restapi.models:Node_Post", + "data.structure.StructureData.| = aiida_restapi.models:Node_Post", + "data.orbital.OrbitalData.| = aiida_restapi.models:Node_Post", + "data.list.List.| = aiida_restapi.models:Node_Post", + "data.dict.Dict.| = aiida_restapi.models:Node_Post", + "data.singlefile.SingleFileData.| = aiida_restapi.models:Node_Post", + "data.code.Code.| = aiida_restapi.models:Node_Post" ] }, "include_package_data": true, @@ -36,7 +48,8 @@ "pydantic~=1.8.2", "graphene~=2.0", "python-dateutil~=2.0", - "lark~=0.11.0" + "lark~=0.11.0", + "importlib_metadata~=4.3" ], "extras_require": { "testing": [ diff --git a/tests/test_computers.py b/tests/test_computers.py index 1e4ee66..fda7795 100644 --- a/tests/test_computers.py +++ b/tests/test_computers.py @@ -23,6 +23,7 @@ def test_get_computers_projectable(client): "scheduler_type", "transport_type", "metadata", + "description", ] diff --git a/tests/test_graphql/test_full/test_full.yml b/tests/test_graphql/test_full/test_full.yml index f5dd59e..5155d41 100644 --- a/tests/test_graphql/test_full/test_full.yml +++ b/tests/test_graphql/test_full/test_full.yml @@ -1,4 +1,4 @@ data: - aiidaVersion: 1.6.4 + aiidaVersion: 1.6.5 node: label: node 1 diff --git a/tests/test_models/test_computer_get_entities.yml b/tests/test_models/test_computer_get_entities.yml index 4fe4a19..259bc2e 100644 --- a/tests/test_models/test_computer_get_entities.yml +++ b/tests/test_models/test_computer_get_entities.yml @@ -1,4 +1,5 @@ -- hostname: localhost_1 +- description: '' + hostname: localhost_1 id: int metadata: {} name: test_comp_1 diff --git a/tests/test_nodes.py b/tests/test_nodes.py new file mode 100644 index 0000000..eb6da56 --- /dev/null +++ b/tests/test_nodes.py @@ -0,0 +1,312 @@ +# -*- coding: utf-8 -*- +"""Test the /nodes endpoint""" + + +import io + + +def test_get_nodes_projectable(client): + """Test get projectable properites for nodes.""" + response = client.get("/nodes/projectable_properties") + + assert response.status_code == 200 + assert response.json() == [ + "id", + "uuid", + "node_type", + "process_type", + "label", + "description", + "ctime", + "mtime", + "user_id", + "dbcomputer_id", + "attributes", + "extras", + ] + + +def test_get_single_nodes(default_nodes, client): # pylint: disable=unused-argument + """Test retrieving a single nodes.""" + + for nodes_id in default_nodes: + response = client.get("/nodes/{}".format(nodes_id)) + assert response.status_code == 200 + + +def test_get_nodes(default_nodes, client): # pylint: disable=unused-argument + """Test listing existing nodes.""" + response = client.get("/nodes") + assert response.status_code == 200 + assert len(response.json()) == 4 + + +def test_create_dict(client, authenticate): # pylint: disable=unused-argument + """Test creating a new dict.""" + response = client.post( + "/nodes", + json={ + "node_type": "data.dict.Dict.|", + "attributes": {"x": 1, "y": 2}, + "label": "test_dict", + }, + ) + assert response.status_code == 200, response.content + + +def test_create_code( + default_computers, client, authenticate +): # pylint: disable=unused-argument + """Test creating a new Code.""" + + for comp_id in default_computers: + response = client.post( + "/nodes", + json={ + "node_type": "data.code.Code.|", + "dbcomputer_id": comp_id, + "attributes": {"is_local": False, "remote_exec_path": "/bin/true"}, + "label": "test_code", + }, + ) + assert response.status_code == 200, response.content + + +def test_create_list(client, authenticate): # pylint: disable=unused-argument + """Test creating a new list.""" + response = client.post( + "/nodes", + json={ + "node_type": "data.list.List.|", + "attributes": {"list": [2, 3]}, + "label": "test_list", + }, + ) + + assert response.status_code == 200, response.content + + +def test_create_int(client, authenticate): # pylint: disable=unused-argument + """Test creating a new Int.""" + response = client.post( + "/nodes", + json={ + "node_type": "data.int.Int.|", + "attributes": {"value": 6}, + "label": "test_Int", + }, + ) + assert response.status_code == 200, response.content + + +def test_create_float(client, authenticate): # pylint: disable=unused-argument + """Test creating a new Float.""" + response = client.post( + "/nodes", + json={ + "node_type": "data.float.Float.|", + "attributes": {"value": 6.6}, + "label": "test_Float", + }, + ) + assert response.status_code == 200, response.content + + +def test_create_string(client, authenticate): # pylint: disable=unused-argument + """Test creating a new string.""" + response = client.post( + "/nodes", + json={ + "node_type": "data.str.Str.|", + "attributes": {"value": "test_string"}, + "label": "test_string", + }, + ) + assert response.status_code == 200, response.content + + +def test_create_bool(client, authenticate): # pylint: disable=unused-argument + """Test creating a new Bool.""" + response = client.post( + "/nodes", + json={ + "node_type": "data.bool.Bool.|", + "attributes": {"value": "True"}, + "label": "test_bool", + }, + ) + assert response.status_code == 200, response.content + + +def test_create_structure_data(client, authenticate): # pylint: disable=unused-argument + """Test creating a new StructureData.""" + response = client.post( + "/nodes", + json={ + "node_type": "data.structure.StructureData.|", + "process_type": None, + "label": "test_StructureData", + "description": "", + "attributes": { + "cell": [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]], + "pbc": [True, True, True], + "ase": None, + "pymatgen": None, + "pymatgen_structure": None, + "pymatgen_molecule": None, + }, + }, + ) + + assert response.status_code == 200, response.content + + +def test_create_orbital_data(client, authenticate): # pylint: disable=unused-argument + """Test creating a new OrbitalData.""" + response = client.post( + "/nodes", + json={ + "node_type": "data.orbital.OrbitalData.|", + "process_type": None, + "label": "test_OrbitalData", + "description": "", + "attributes": { + "orbital_dicts": [ + { + "spin": 0, + "position": [ + -1, + 1, + 1, + ], + "kind_name": "As", + "diffusivity": None, + "radial_nodes": 0, + "_orbital_type": "realhydrogen", + "x_orientation": None, + "z_orientation": None, + "angular_momentum": -3, + } + ] + }, + }, + ) + + assert response.status_code == 200, response.content + + +def test_create_single_file_upload( + client, authenticate +): # pylint: disable=unused-argument + """Testing file upload""" + test_file = { + "upload_file": ( + "test_file.txt", + io.BytesIO(b"Some test strings"), + "multipart/form-data", + ) + } + params = { + "node_type": "data.singlefile.SingleFileData.|", + "process_type": None, + "label": "test_upload_file", + "description": "Testing single upload file", + "attributes": {}, + } + + response = client.post("/nodes/singlefile", files=test_file, data=params) + + assert response.status_code == 200 + + +def test_create_node_wrond_value( + client, authenticate +): # pylint: disable=unused-argument + """Test creating a new node with wrong value.""" + response = client.post( + "/nodes", + json={ + "node_type": "data.float.Float.|", + "attributes": {"value": "tests"}, + "label": "test_Float", + }, + ) + assert response.status_code == 400, response.content + + response = client.post( + "/nodes", + json={ + "node_type": "data.int.Int.|", + "attributes": {"value": "tests"}, + "label": "test_int", + }, + ) + assert response.status_code == 400, response.content + + +def test_create_node_wrong_attribute( + client, authenticate +): # pylint: disable=unused-argument + """Test adding node with wrong attributes.""" + response = client.post( + "/nodes", + json={ + "node_type": "data.str.Str.|", + "attributes": {"value1": 5}, + "label": "test_int", + }, + ) + assert response.status_code == 400, response.content + + +def test_wrong_entry_point(client, authenticate): # pylint: disable=unused-argument + """Test adding node with wrong entry point.""" + response = client.post( + "/nodes", + json={ + "node_type": "data.float.wrong.|", + "attributes": {"value": 3}, + "label": "test_Float", + }, + ) + assert response.status_code == 404, response.content + + +def test_create_additional_attribute( + default_computers, client, authenticate +): # pylint: disable=unused-argument + """Test adding additional properties returns errors.""" + + for comp_id in default_computers: + response = client.post( + "/nodes", + json={ + "uuid": "3", + "node_type": "data.code.Code.|", + "dbcomputer_id": comp_id, + "attributes": {"is_local": False, "remote_exec_path": "/bin/true"}, + "label": "test_code", + }, + ) + assert response.status_code == 422, response.content + + +def test_create_bool_with_extra( + client, authenticate +): # pylint: disable=unused-argument + """Test creating a new Bool with extra.""" + response = client.post( + "/nodes", + json={ + "node_type": "data.bool.Bool.|", + "attributes": {"value": "True"}, + "label": "test_bool", + "extras": {"extra_one": "value_1", "extra_two": "value_2"}, + }, + ) + + check_response = client.get("/nodes/{}".format(response.json()["id"])) + + assert check_response.status_code == 200, response.content + assert check_response.json()["extras"]["extra_one"] == "value_1" + assert check_response.json()["extras"]["extra_two"] == "value_2"