diff --git a/aiida_restapi/main.py b/aiida_restapi/main.py index 0d5504e..390c711 100644 --- a/aiida_restapi/main.py +++ b/aiida_restapi/main.py @@ -3,11 +3,12 @@ from fastapi import FastAPI from aiida_restapi.graphql import main -from aiida_restapi.routers import auth, computers, groups, process, users +from aiida_restapi.routers import auth, computers, groups, nodes, process, users app = FastAPI() app.include_router(auth.router) app.include_router(computers.router) +app.include_router(nodes.router) app.include_router(groups.router) app.include_router(users.router) app.include_router(process.router) diff --git a/aiida_restapi/models.py b/aiida_restapi/models.py index f8707e9..0fa1e71 100644 --- a/aiida_restapi/models.py +++ b/aiida_restapi/models.py @@ -6,17 +6,47 @@ """ # pylint: disable=too-few-public-methods +import inspect +import io from datetime import datetime from typing import ClassVar, Dict, List, Optional, Type, TypeVar from uuid import UUID from aiida import orm +from aiida.restapi.common.identifiers import load_entry_point_from_full_type +from fastapi import Form from pydantic import BaseModel, Field # Template type for subclasses of `AiidaModel` ModelType = TypeVar("ModelType", bound="AiidaModel") +def as_form(cls: Type[BaseModel]) -> Type[BaseModel]: + """ + Adds an as_form class method to decorated models. The as_form class method + can be used with FastAPI endpoints + + Note: Taken from https://github.com/tiangolo/fastapi/issues/2387 + """ + new_params = [ + inspect.Parameter( + field.alias, + inspect.Parameter.POSITIONAL_ONLY, + default=(Form(field.default) if not field.required else Form(...)), + ) + for field in cls.__fields__.values() + ] + + async def _as_form(**data: Dict) -> BaseModel: + return cls(**data) + + sig = inspect.signature(_as_form) + sig = sig.replace(parameters=new_params) + _as_form.__signature__ = sig # type: ignore + setattr(cls, "as_form", _as_form) + return cls + + class AiidaModel(BaseModel): """A mapping of an AiiDA entity to a pydantic model.""" @@ -26,6 +56,7 @@ class Config: """The models configuration.""" orm_mode = True + extra = "forbid" @classmethod def get_projectable_properties(cls) -> List[str]: @@ -88,6 +119,11 @@ class User(AiidaModel): _orm_entity = orm.User + class Config: + """The models configuration.""" + + extra = "allow" + id: Optional[int] = Field(description="Unique user id (pk)") email: str = Field(description="Email address of the user") first_name: Optional[str] = Field(description="First name of the user") @@ -119,6 +155,107 @@ class Computer(AiidaModel): description="General settings for these communication and management protocols" ) + description: Optional[str] = Field(description="Description of node") + + +class Node(AiidaModel): + """AiiDA Node Model.""" + + _orm_entity = orm.Node + + id: Optional[int] = Field(description="Unique id (pk)") + uuid: Optional[UUID] = Field(description="Unique uuid") + node_type: Optional[str] = Field(description="Node type") + process_type: Optional[str] = Field(description="Process type") + label: str = Field(description="Label of node") + description: Optional[str] = Field(description="Description of node") + ctime: Optional[datetime] = Field(description="Creation time") + mtime: Optional[datetime] = Field(description="Last modification time") + user_id: Optional[int] = Field(description="Created by user id (pk)") + dbcomputer_id: Optional[int] = Field(description="Associated computer id (pk)") + attributes: Optional[Dict] = Field( + description="Variable attributes of the node", + ) + extras: Optional[Dict] = Field( + description="Variable extras (unsealed) of the node", + ) + + +@as_form +class Node_Post(AiidaModel): + """AiiDA model for posting Nodes.""" + + node_type: Optional[str] = Field(description="Node type") + process_type: Optional[str] = Field(description="Process type") + label: str = Field(description="Label of node") + description: Optional[str] = Field(description="Description of node") + user_id: Optional[int] = Field(description="Created by user id (pk)") + dbcomputer_id: Optional[int] = Field(description="Associated computer id (pk)") + attributes: Optional[Dict] = Field( + description="Variable attributes of the node", + ) + extras: Optional[Dict] = Field( + description="Variable extras (unsealed) of the node", + ) + + @classmethod + def create_new_node( + cls: Type[ModelType], + node_type: str, + node_dict: dict, + ) -> orm.Node: + "Create and Store new Node" + + orm_class = load_entry_point_from_full_type(node_type) + attributes = node_dict.pop("attributes", {}) + extras = node_dict.pop("extras", {}) + + if issubclass(orm_class, orm.BaseType): + orm_object = orm_class( + attributes["value"], + **node_dict, + ) + elif issubclass(orm_class, orm.Dict): + orm_object = orm_class( + dict=attributes, + **node_dict, + ) + elif issubclass(orm_class, orm.Code): + orm_object = orm_class() + orm_object.set_remote_computer_exec( + ( + orm.Computer.get(id=node_dict.get("dbcomputer_id")), + attributes["remote_exec_path"], + ) + ) + orm_object.label = node_dict.get("label") + else: + orm_object = load_entry_point_from_full_type(node_type)(**node_dict) + orm_object.set_attribute_many(attributes) + + orm_object.set_extra_many(extras) + orm_object.store() + return orm_object + + @classmethod + def create_new_node_with_file( + cls: Type[ModelType], + node_type: str, + node_dict: dict, + file: bytes, + ) -> orm.Node: + "Create and Store new Node with file" + attributes = node_dict.pop("attributes", {}) + extras = node_dict.pop("extras", {}) + + orm_object = load_entry_point_from_full_type(node_type)( + file=io.BytesIO(file), **node_dict, **attributes + ) + + orm_object.set_extra_many(extras) + orm_object.store() + return orm_object + class Group(AiidaModel): """AiiDA Group model.""" diff --git a/aiida_restapi/routers/groups.py b/aiida_restapi/routers/groups.py index 6fcab8d..2bfd41b 100644 --- a/aiida_restapi/routers/groups.py +++ b/aiida_restapi/routers/groups.py @@ -41,6 +41,7 @@ async def read_group(group_id: int) -> Optional[Group]: @router.post("/groups", response_model=Group) +@with_dbenv() async def create_group( group: Group_Post, current_user: User = Depends( diff --git a/aiida_restapi/routers/nodes.py b/aiida_restapi/routers/nodes.py new file mode 100644 index 0000000..82d96f6 --- /dev/null +++ b/aiida_restapi/routers/nodes.py @@ -0,0 +1,97 @@ +# -*- coding: utf-8 -*- +"""Declaration of FastAPI application.""" +from typing import List, Optional + +from aiida import orm +from aiida.cmdline.utils.decorators import with_dbenv +from fastapi import APIRouter, Depends, File, HTTPException +from importlib_metadata import entry_points + +from aiida_restapi import models + +from .auth import get_current_active_user + +router = APIRouter() + +ENTRY_POINTS = entry_points() + + +@router.get("/nodes", response_model=List[models.Node]) +@with_dbenv() +async def read_nodes() -> List[models.Node]: + """Get list of all nodes""" + return models.Node.get_entities() + + +@router.get("/nodes/projectable_properties", response_model=List[str]) +async def get_nodes_projectable_properties() -> List[str]: + """Get projectable properties for nodes endpoint""" + + return models.Node.get_projectable_properties() + + +@router.get("/nodes/{nodes_id}", response_model=models.Node) +@with_dbenv() +async def read_node(nodes_id: int) -> Optional[models.Node]: + """Get nodes by id.""" + qbobj = orm.QueryBuilder() + + qbobj.append(orm.Node, filters={"id": nodes_id}, project=["**"], tag="node").limit( + 1 + ) + return qbobj.dict()[0]["node"] + + +@router.post("/nodes", response_model=models.Node) +@with_dbenv() +async def create_node( + node: models.Node_Post, + current_user: models.User = Depends( + get_current_active_user + ), # pylint: disable=unused-argument +) -> models.Node: + """Create new AiiDA node.""" + + node_dict = node.dict(exclude_unset=True) + node_type = node_dict.pop("node_type", None) + + try: + (entry_point_node,) = ENTRY_POINTS.select( + group="aiida.rest.post", name=node_type + ) + except ValueError as exc: + raise HTTPException( + status_code=404, detail="Entry point '{}' not recognized.".format(node_type) + ) from exc + + try: + orm_object = entry_point_node.load().create_new_node(node_type, node_dict) + except (TypeError, ValueError, KeyError) as err: + raise HTTPException(status_code=400, detail="Error: {0}".format(err)) from err + + return models.Node.from_orm(orm_object) + + +@router.post("/nodes/singlefile", response_model=models.Node) +@with_dbenv() +async def create_upload_file( + upload_file: bytes = File(...), + params: models.Node_Post = Depends(models.Node_Post.as_form), # type: ignore # pylint: disable=maybe-no-member + current_user: models.User = Depends( + get_current_active_user + ), # pylint: disable=unused-argument +) -> models.Node: + """Endpoint for uploading file data""" + node_dict = params.dict(exclude_unset=True, exclude_none=True) + node_type = node_dict.pop("node_type", None) + + try: + (entry_point_node,) = entry_points(group="aiida.rest.post", name=node_type) + except KeyError as exc: + raise KeyError("Entry point '{}' not recognized.".format(node_type)) from exc + + orm_object = entry_point_node.load().create_new_node_with_file( + node_type, node_dict, upload_file + ) + + return models.Node.from_orm(orm_object) diff --git a/aiida_restapi/routers/users.py b/aiida_restapi/routers/users.py index 6128f65..9fe08d5 100644 --- a/aiida_restapi/routers/users.py +++ b/aiida_restapi/routers/users.py @@ -39,6 +39,7 @@ async def read_user(user_id: int) -> Optional[User]: @router.post("/users", response_model=User) +@with_dbenv() async def create_user( user: User, current_user: User = Depends( diff --git a/conftest.py b/conftest.py index e879728..a4c4623 100644 --- a/conftest.py +++ b/conftest.py @@ -137,6 +137,17 @@ def default_groups(): return [group_1.id, group_2.id] +@pytest.fixture(scope="function") +def default_nodes(): + """Populate database with some nodes.""" + node_1 = orm.Int(1).store() + node_2 = orm.Float(1.1).store() + node_3 = orm.Str("test_string").store() + node_4 = orm.Bool(False).store() + + return [node_1.id, node_2.id, node_3.id, node_4.id] + + @pytest.fixture(scope="function") def authenticate(): """Authenticate user. diff --git a/setup.json b/setup.json index 3bbf744..2917591 100644 --- a/setup.json +++ b/setup.json @@ -19,6 +19,18 @@ ], "aiida.cmdline.data": [ "restapi = aiida_restapi.cli:data_cli" + ], + "aiida.rest.post": [ + "data.str.Str.| = aiida_restapi.models:Node_Post", + "data.float.Float.| = aiida_restapi.models:Node_Post", + "data.int.Int.| = aiida_restapi.models:Node_Post", + "data.bool.Bool.| = aiida_restapi.models:Node_Post", + "data.structure.StructureData.| = aiida_restapi.models:Node_Post", + "data.orbital.OrbitalData.| = aiida_restapi.models:Node_Post", + "data.list.List.| = aiida_restapi.models:Node_Post", + "data.dict.Dict.| = aiida_restapi.models:Node_Post", + "data.singlefile.SingleFileData.| = aiida_restapi.models:Node_Post", + "data.code.Code.| = aiida_restapi.models:Node_Post" ] }, "include_package_data": true, diff --git a/tests/test_computers.py b/tests/test_computers.py index 1e4ee66..fda7795 100644 --- a/tests/test_computers.py +++ b/tests/test_computers.py @@ -23,6 +23,7 @@ def test_get_computers_projectable(client): "scheduler_type", "transport_type", "metadata", + "description", ] diff --git a/tests/test_models/test_computer_get_entities.yml b/tests/test_models/test_computer_get_entities.yml index 4fe4a19..259bc2e 100644 --- a/tests/test_models/test_computer_get_entities.yml +++ b/tests/test_models/test_computer_get_entities.yml @@ -1,4 +1,5 @@ -- hostname: localhost_1 +- description: '' + hostname: localhost_1 id: int metadata: {} name: test_comp_1 diff --git a/tests/test_nodes.py b/tests/test_nodes.py new file mode 100644 index 0000000..eb6da56 --- /dev/null +++ b/tests/test_nodes.py @@ -0,0 +1,312 @@ +# -*- coding: utf-8 -*- +"""Test the /nodes endpoint""" + + +import io + + +def test_get_nodes_projectable(client): + """Test get projectable properites for nodes.""" + response = client.get("/nodes/projectable_properties") + + assert response.status_code == 200 + assert response.json() == [ + "id", + "uuid", + "node_type", + "process_type", + "label", + "description", + "ctime", + "mtime", + "user_id", + "dbcomputer_id", + "attributes", + "extras", + ] + + +def test_get_single_nodes(default_nodes, client): # pylint: disable=unused-argument + """Test retrieving a single nodes.""" + + for nodes_id in default_nodes: + response = client.get("/nodes/{}".format(nodes_id)) + assert response.status_code == 200 + + +def test_get_nodes(default_nodes, client): # pylint: disable=unused-argument + """Test listing existing nodes.""" + response = client.get("/nodes") + assert response.status_code == 200 + assert len(response.json()) == 4 + + +def test_create_dict(client, authenticate): # pylint: disable=unused-argument + """Test creating a new dict.""" + response = client.post( + "/nodes", + json={ + "node_type": "data.dict.Dict.|", + "attributes": {"x": 1, "y": 2}, + "label": "test_dict", + }, + ) + assert response.status_code == 200, response.content + + +def test_create_code( + default_computers, client, authenticate +): # pylint: disable=unused-argument + """Test creating a new Code.""" + + for comp_id in default_computers: + response = client.post( + "/nodes", + json={ + "node_type": "data.code.Code.|", + "dbcomputer_id": comp_id, + "attributes": {"is_local": False, "remote_exec_path": "/bin/true"}, + "label": "test_code", + }, + ) + assert response.status_code == 200, response.content + + +def test_create_list(client, authenticate): # pylint: disable=unused-argument + """Test creating a new list.""" + response = client.post( + "/nodes", + json={ + "node_type": "data.list.List.|", + "attributes": {"list": [2, 3]}, + "label": "test_list", + }, + ) + + assert response.status_code == 200, response.content + + +def test_create_int(client, authenticate): # pylint: disable=unused-argument + """Test creating a new Int.""" + response = client.post( + "/nodes", + json={ + "node_type": "data.int.Int.|", + "attributes": {"value": 6}, + "label": "test_Int", + }, + ) + assert response.status_code == 200, response.content + + +def test_create_float(client, authenticate): # pylint: disable=unused-argument + """Test creating a new Float.""" + response = client.post( + "/nodes", + json={ + "node_type": "data.float.Float.|", + "attributes": {"value": 6.6}, + "label": "test_Float", + }, + ) + assert response.status_code == 200, response.content + + +def test_create_string(client, authenticate): # pylint: disable=unused-argument + """Test creating a new string.""" + response = client.post( + "/nodes", + json={ + "node_type": "data.str.Str.|", + "attributes": {"value": "test_string"}, + "label": "test_string", + }, + ) + assert response.status_code == 200, response.content + + +def test_create_bool(client, authenticate): # pylint: disable=unused-argument + """Test creating a new Bool.""" + response = client.post( + "/nodes", + json={ + "node_type": "data.bool.Bool.|", + "attributes": {"value": "True"}, + "label": "test_bool", + }, + ) + assert response.status_code == 200, response.content + + +def test_create_structure_data(client, authenticate): # pylint: disable=unused-argument + """Test creating a new StructureData.""" + response = client.post( + "/nodes", + json={ + "node_type": "data.structure.StructureData.|", + "process_type": None, + "label": "test_StructureData", + "description": "", + "attributes": { + "cell": [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]], + "pbc": [True, True, True], + "ase": None, + "pymatgen": None, + "pymatgen_structure": None, + "pymatgen_molecule": None, + }, + }, + ) + + assert response.status_code == 200, response.content + + +def test_create_orbital_data(client, authenticate): # pylint: disable=unused-argument + """Test creating a new OrbitalData.""" + response = client.post( + "/nodes", + json={ + "node_type": "data.orbital.OrbitalData.|", + "process_type": None, + "label": "test_OrbitalData", + "description": "", + "attributes": { + "orbital_dicts": [ + { + "spin": 0, + "position": [ + -1, + 1, + 1, + ], + "kind_name": "As", + "diffusivity": None, + "radial_nodes": 0, + "_orbital_type": "realhydrogen", + "x_orientation": None, + "z_orientation": None, + "angular_momentum": -3, + } + ] + }, + }, + ) + + assert response.status_code == 200, response.content + + +def test_create_single_file_upload( + client, authenticate +): # pylint: disable=unused-argument + """Testing file upload""" + test_file = { + "upload_file": ( + "test_file.txt", + io.BytesIO(b"Some test strings"), + "multipart/form-data", + ) + } + params = { + "node_type": "data.singlefile.SingleFileData.|", + "process_type": None, + "label": "test_upload_file", + "description": "Testing single upload file", + "attributes": {}, + } + + response = client.post("/nodes/singlefile", files=test_file, data=params) + + assert response.status_code == 200 + + +def test_create_node_wrond_value( + client, authenticate +): # pylint: disable=unused-argument + """Test creating a new node with wrong value.""" + response = client.post( + "/nodes", + json={ + "node_type": "data.float.Float.|", + "attributes": {"value": "tests"}, + "label": "test_Float", + }, + ) + assert response.status_code == 400, response.content + + response = client.post( + "/nodes", + json={ + "node_type": "data.int.Int.|", + "attributes": {"value": "tests"}, + "label": "test_int", + }, + ) + assert response.status_code == 400, response.content + + +def test_create_node_wrong_attribute( + client, authenticate +): # pylint: disable=unused-argument + """Test adding node with wrong attributes.""" + response = client.post( + "/nodes", + json={ + "node_type": "data.str.Str.|", + "attributes": {"value1": 5}, + "label": "test_int", + }, + ) + assert response.status_code == 400, response.content + + +def test_wrong_entry_point(client, authenticate): # pylint: disable=unused-argument + """Test adding node with wrong entry point.""" + response = client.post( + "/nodes", + json={ + "node_type": "data.float.wrong.|", + "attributes": {"value": 3}, + "label": "test_Float", + }, + ) + assert response.status_code == 404, response.content + + +def test_create_additional_attribute( + default_computers, client, authenticate +): # pylint: disable=unused-argument + """Test adding additional properties returns errors.""" + + for comp_id in default_computers: + response = client.post( + "/nodes", + json={ + "uuid": "3", + "node_type": "data.code.Code.|", + "dbcomputer_id": comp_id, + "attributes": {"is_local": False, "remote_exec_path": "/bin/true"}, + "label": "test_code", + }, + ) + assert response.status_code == 422, response.content + + +def test_create_bool_with_extra( + client, authenticate +): # pylint: disable=unused-argument + """Test creating a new Bool with extra.""" + response = client.post( + "/nodes", + json={ + "node_type": "data.bool.Bool.|", + "attributes": {"value": "True"}, + "label": "test_bool", + "extras": {"extra_one": "value_1", "extra_two": "value_2"}, + }, + ) + + check_response = client.get("/nodes/{}".format(response.json()["id"])) + + assert check_response.status_code == 200, response.content + assert check_response.json()["extras"]["extra_one"] == "value_1" + assert check_response.json()["extras"]["extra_two"] == "value_2"