Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add processes endpoint #30

Merged
merged 12 commits into from
Aug 30, 2021
3 changes: 2 additions & 1 deletion aiida_restapi/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,13 @@
from fastapi import FastAPI

from aiida_restapi.graphql import main
from aiida_restapi.routers import auth, computers, groups, nodes, users
from aiida_restapi.routers import auth, computers, groups, nodes, process, users

app = FastAPI()
app.include_router(auth.router)
app.include_router(computers.router)
app.include_router(nodes.router)
app.include_router(groups.router)
app.include_router(users.router)
app.include_router(process.router)
app.add_route("/graphql", main.app, name="graphql")
31 changes: 31 additions & 0 deletions aiida_restapi/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -297,3 +297,34 @@ class Group_Post(AiidaModel):
label: str = Field(description="Used to access the group. Must be unique.")
type_string: Optional[str] = Field(description="Type of the group")
description: Optional[str] = Field(description="Short description of the group.")


class Process(AiidaModel):
"""AiiDA Process Model"""

_orm_entity = orm.ProcessNode

id: Optional[int] = Field(description="Unique id (pk)")
uuid: Optional[UUID] = Field(description="Universally unique identifier")
node_type: Optional[str] = Field(description="Node type")
process_type: Optional[str] = Field(description="Process type")
label: str = Field(description="Label of node")
description: Optional[str] = Field(description="Description of node")
ctime: Optional[datetime] = Field(description="Creation time")
mtime: Optional[datetime] = Field(description="Last modification time")
user_id: Optional[int] = Field(description="Created by user id (pk)")
dbcomputer_id: Optional[int] = Field(description="Associated computer id (pk)")
attributes: Optional[Dict] = Field(
description="Variable attributes of the node",
)
extras: Optional[Dict] = Field(
description="Variable extras (unsealed) of the node",
)


class Process_Post(AiidaModel):
"""AiiDA Process Post Model"""

label: str = Field(description="Label of node")
inputs: dict = Field(description="Input parmeters")
process_entry_point: str = Field(description="Entry Point for process")
90 changes: 90 additions & 0 deletions aiida_restapi/routers/process.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
# -*- coding: utf-8 -*-
"""Declaration of FastAPI router for processes."""

from typing import List, Optional

from aiida import orm
from aiida.cmdline.utils.decorators import with_dbenv
from aiida.common.exceptions import NotExistent
from aiida.engine import submit
from aiida.orm.querybuilder import QueryBuilder
from aiida.plugins import load_entry_point_from_string
from fastapi import APIRouter, Depends, HTTPException

from aiida_restapi.models import Process, Process_Post, User

from .auth import get_current_active_user

router = APIRouter()


def substitute_node(input_dict: dict) -> dict:
"""Substitutes node ids with nodes"""
node_ids = {
key: node_id for key, node_id in input_dict.items() if not key.endswith(".uuid")
}

for key, value in input_dict.items():
if key not in node_ids.keys():
try:
node_ids[key[:-5]] = orm.Node.get(uuid=value)
except NotExistent as exc:
raise HTTPException(
status_code=404,
detail="Node ID: {} does not exist.".format(value),
) from exc

return node_ids


@router.get("/processes", response_model=List[Process])
@with_dbenv()
async def read_processes() -> List[Process]:
"""Get list of all processes"""

return Process.get_entities()


@router.get("/processes/projectable_properties", response_model=List[str])
async def get_processes_projectable_properties() -> List[str]:
"""Get projectable properties for processes endpoint"""

return Process.get_projectable_properties()


@router.get("/processes/{proc_id}", response_model=Process)
@with_dbenv()
async def read_process(proc_id: int) -> Optional[Process]:
"""Get process by id."""
qbobj = QueryBuilder()
qbobj.append(
orm.ProcessNode, filters={"id": proc_id}, project=["**"], tag="process"
).limit(1)

return qbobj.dict()[0]["process"]


@router.post("/processes", response_model=Process)
@with_dbenv()
async def post_process(
process: Process_Post,
current_user: User = Depends(
get_current_active_user
), # pylint: disable=unused-argument
) -> Optional[Process]:
"""Create new process."""
process_dict = process.dict(exclude_unset=True, exclude_none=True)
inputs = substitute_node(process_dict["inputs"])
entry_point = process_dict.get("process_entry_point")

try:
entry_point_process = load_entry_point_from_string(entry_point)
except ValueError as exc:
raise HTTPException(
status_code=404,
detail="Entry point '{}' not recognized.".format(entry_point),
) from exc

process_node = submit(entry_point_process, **inputs)

return process_node
70 changes: 70 additions & 0 deletions conftest.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
# -*- coding: utf-8 -*-
"""pytest fixtures for simplified testing."""
import tempfile

import pytest
from aiida import orm
from aiida.engine import ProcessState
from aiida.orm import WorkChainNode, WorkFunctionNode
from fastapi.testclient import TestClient

from aiida_restapi import app, config
Expand Down Expand Up @@ -53,6 +57,72 @@ def default_computers():
return [comp_1.id, comp_2.id]


@pytest.fixture(scope="function")
def example_processes():
"""Populate database with some processes"""
calcs = []
process_label = "SomeDummyWorkFunctionNode"

# Create 6 WorkFunctionNodes and WorkChainNodes (one for each ProcessState)
for state in ProcessState:

calc = WorkFunctionNode()
calc.set_process_state(state)

# Set the WorkFunctionNode as successful
if state == ProcessState.FINISHED:
calc.set_exit_status(0)

# Give a `process_label` to the `WorkFunctionNodes` so the `--process-label` option can be tested
calc.set_attribute("process_label", process_label)

calc.store()
calcs.append(calc.id)

calc = WorkChainNode()
calc.set_process_state(state)

# Set the WorkChainNode as failed
if state == ProcessState.FINISHED:
calc.set_exit_status(1)

# Set the waiting work chain as paused as well
if state == ProcessState.WAITING:
calc.pause()

calc.store()
calcs.append(calc.id)
return calcs


@pytest.fixture(scope="function")
def default_test_add_process():
"""Populate database with some node to test adding process"""

workdir = tempfile.mkdtemp()

computer = orm.Computer(
label="localhost",
hostname="localhost",
workdir=workdir,
transport_type="local",
scheduler_type="direct",
)
computer.store()
computer.set_minimum_job_poll_interval(0.0)
computer.configure()

code = orm.Code(
input_plugin_name="arithmetic.add", remote_computer_exec=(computer, "/bin/true")
).store()

x = orm.Int(1).store()

y = orm.Int(2).store()

return [code.uuid, x.uuid, y.uuid]


@pytest.fixture(scope="function")
def default_groups():
"""Populate database with some groups."""
Expand Down
112 changes: 112 additions & 0 deletions tests/test_processes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
# -*- coding: utf-8 -*-
"""Test the /processes endpoint"""


def test_get_processes(example_processes, client): # pylint: disable=unused-argument
"""Test listing existing processes."""
response = client.get("/processes/")

assert response.status_code == 200
assert len(response.json()) == 12


def test_get_processes_projectable(client):
"""Test get projectable properites for processes."""
response = client.get("/processes/projectable_properties")

assert response.status_code == 200
assert response.json() == [
"id",
"uuid",
"node_type",
"process_type",
"label",
"description",
"ctime",
"mtime",
"user_id",
"dbcomputer_id",
"attributes",
"extras",
]


def test_get_single_processes(
example_processes, client
): # pylint: disable=unused-argument
"""Test retrieving a single processes."""
for proc_id in example_processes:
response = client.get("/processes/{}".format(proc_id))
assert response.status_code == 200


def test_add_process(
default_test_add_process, client, authenticate
): # pylint: disable=unused-argument
"""Test adding new process"""
code_id, x_id, y_id = default_test_add_process
response = client.post(
"/processes",
json={
"label": "test_new_process",
"process_entry_point": "aiida.calculations:arithmetic.add",
"inputs": {
"code.uuid": code_id,
"x.uuid": x_id,
"y.uuid": y_id,
"metadata": {
"description": "Test job submission with the add plugin",
},
},
},
)
assert response.status_code == 200


def test_add_process_invalid_entry_point(
default_test_add_process, client, authenticate
): # pylint: disable=unused-argument
"""Test adding new process with invalid entry point"""
code_id, x_id, y_id = default_test_add_process
response = client.post(
"/processes",
json={
"label": "test_new_process",
"process_entry_point": "wrong_entry_point",
"inputs": {
"code.uuid": code_id,
"x.uuid": x_id,
"y.uuid": y_id,
"metadata": {
"description": "Test job submission with the add plugin",
},
},
},
)
assert response.status_code == 404


def test_add_process_invalid_node_id(
default_test_add_process, client, authenticate
): # pylint: disable=unused-argument
"""Test adding new process with invalid Node ID"""
code_id, x_id, _ = default_test_add_process
response = client.post(
"/processes",
json={
"label": "test_new_process",
"process_entry_point": "aiida.calculations:arithmetic.add",
"inputs": {
"code.uuid": code_id,
"x.uuid": x_id,
"y.uuid": "891a9efa-f90e-11eb-9a03-0242ac130003",
"metadata": {
"description": "Test job submission with the add plugin",
},
},
},
)
assert response.status_code == 404
assert response.json() == {
"detail": "Node ID: 891a9efa-f90e-11eb-9a03-0242ac130003 does not exist."
}