Skip to content

Commit

Permalink
Add /processes endpoint (#30)
Browse files Browse the repository at this point in the history
Add /processes endpoint with methods POST (create new processes) and GET (list active processes).
  • Loading branch information
NinadBhat committed Aug 30, 2021
1 parent ac6e812 commit ee6a656
Show file tree
Hide file tree
Showing 5 changed files with 305 additions and 1 deletion.
3 changes: 2 additions & 1 deletion aiida_restapi/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,13 @@
from fastapi import FastAPI

from aiida_restapi.graphql import main
from aiida_restapi.routers import auth, computers, groups, nodes, users
from aiida_restapi.routers import auth, computers, groups, nodes, process, users

app = FastAPI()
app.include_router(auth.router)
app.include_router(computers.router)
app.include_router(nodes.router)
app.include_router(groups.router)
app.include_router(users.router)
app.include_router(process.router)
app.add_route("/graphql", main.app, name="graphql")
31 changes: 31 additions & 0 deletions aiida_restapi/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -297,3 +297,34 @@ class Group_Post(AiidaModel):
label: str = Field(description="Used to access the group. Must be unique.")
type_string: Optional[str] = Field(description="Type of the group")
description: Optional[str] = Field(description="Short description of the group.")


class Process(AiidaModel):
"""AiiDA Process Model"""

_orm_entity = orm.ProcessNode

id: Optional[int] = Field(description="Unique id (pk)")
uuid: Optional[UUID] = Field(description="Universally unique identifier")
node_type: Optional[str] = Field(description="Node type")
process_type: Optional[str] = Field(description="Process type")
label: str = Field(description="Label of node")
description: Optional[str] = Field(description="Description of node")
ctime: Optional[datetime] = Field(description="Creation time")
mtime: Optional[datetime] = Field(description="Last modification time")
user_id: Optional[int] = Field(description="Created by user id (pk)")
dbcomputer_id: Optional[int] = Field(description="Associated computer id (pk)")
attributes: Optional[Dict] = Field(
description="Variable attributes of the node",
)
extras: Optional[Dict] = Field(
description="Variable extras (unsealed) of the node",
)


class Process_Post(AiidaModel):
"""AiiDA Process Post Model"""

label: str = Field(description="Label of node")
inputs: dict = Field(description="Input parmeters")
process_entry_point: str = Field(description="Entry Point for process")
90 changes: 90 additions & 0 deletions aiida_restapi/routers/process.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
# -*- coding: utf-8 -*-
"""Declaration of FastAPI router for processes."""

from typing import List, Optional

from aiida import orm
from aiida.cmdline.utils.decorators import with_dbenv
from aiida.common.exceptions import NotExistent
from aiida.engine import submit
from aiida.orm.querybuilder import QueryBuilder
from aiida.plugins import load_entry_point_from_string
from fastapi import APIRouter, Depends, HTTPException

from aiida_restapi.models import Process, Process_Post, User

from .auth import get_current_active_user

router = APIRouter()


def substitute_node(input_dict: dict) -> dict:
"""Substitutes node ids with nodes"""
node_ids = {
key: node_id for key, node_id in input_dict.items() if not key.endswith(".uuid")
}

for key, value in input_dict.items():
if key not in node_ids.keys():
try:
node_ids[key[:-5]] = orm.Node.get(uuid=value)
except NotExistent as exc:
raise HTTPException(
status_code=404,
detail="Node ID: {} does not exist.".format(value),
) from exc

return node_ids


@router.get("/processes", response_model=List[Process])
@with_dbenv()
async def read_processes() -> List[Process]:
"""Get list of all processes"""

return Process.get_entities()


@router.get("/processes/projectable_properties", response_model=List[str])
async def get_processes_projectable_properties() -> List[str]:
"""Get projectable properties for processes endpoint"""

return Process.get_projectable_properties()


@router.get("/processes/{proc_id}", response_model=Process)
@with_dbenv()
async def read_process(proc_id: int) -> Optional[Process]:
"""Get process by id."""
qbobj = QueryBuilder()
qbobj.append(
orm.ProcessNode, filters={"id": proc_id}, project=["**"], tag="process"
).limit(1)

return qbobj.dict()[0]["process"]


@router.post("/processes", response_model=Process)
@with_dbenv()
async def post_process(
process: Process_Post,
current_user: User = Depends(
get_current_active_user
), # pylint: disable=unused-argument
) -> Optional[Process]:
"""Create new process."""
process_dict = process.dict(exclude_unset=True, exclude_none=True)
inputs = substitute_node(process_dict["inputs"])
entry_point = process_dict.get("process_entry_point")

try:
entry_point_process = load_entry_point_from_string(entry_point)
except ValueError as exc:
raise HTTPException(
status_code=404,
detail="Entry point '{}' not recognized.".format(entry_point),
) from exc

process_node = submit(entry_point_process, **inputs)

return process_node
70 changes: 70 additions & 0 deletions conftest.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
# -*- coding: utf-8 -*-
"""pytest fixtures for simplified testing."""
import tempfile

import pytest
from aiida import orm
from aiida.engine import ProcessState
from aiida.orm import WorkChainNode, WorkFunctionNode
from fastapi.testclient import TestClient

from aiida_restapi import app, config
Expand Down Expand Up @@ -53,6 +57,72 @@ def default_computers():
return [comp_1.id, comp_2.id]


@pytest.fixture(scope="function")
def example_processes():
"""Populate database with some processes"""
calcs = []
process_label = "SomeDummyWorkFunctionNode"

# Create 6 WorkFunctionNodes and WorkChainNodes (one for each ProcessState)
for state in ProcessState:

calc = WorkFunctionNode()
calc.set_process_state(state)

# Set the WorkFunctionNode as successful
if state == ProcessState.FINISHED:
calc.set_exit_status(0)

# Give a `process_label` to the `WorkFunctionNodes` so the `--process-label` option can be tested
calc.set_attribute("process_label", process_label)

calc.store()
calcs.append(calc.id)

calc = WorkChainNode()
calc.set_process_state(state)

# Set the WorkChainNode as failed
if state == ProcessState.FINISHED:
calc.set_exit_status(1)

# Set the waiting work chain as paused as well
if state == ProcessState.WAITING:
calc.pause()

calc.store()
calcs.append(calc.id)
return calcs


@pytest.fixture(scope="function")
def default_test_add_process():
"""Populate database with some node to test adding process"""

workdir = tempfile.mkdtemp()

computer = orm.Computer(
label="localhost",
hostname="localhost",
workdir=workdir,
transport_type="local",
scheduler_type="direct",
)
computer.store()
computer.set_minimum_job_poll_interval(0.0)
computer.configure()

code = orm.Code(
input_plugin_name="arithmetic.add", remote_computer_exec=(computer, "/bin/true")
).store()

x = orm.Int(1).store()

y = orm.Int(2).store()

return [code.uuid, x.uuid, y.uuid]


@pytest.fixture(scope="function")
def default_groups():
"""Populate database with some groups."""
Expand Down
112 changes: 112 additions & 0 deletions tests/test_processes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
# -*- coding: utf-8 -*-
"""Test the /processes endpoint"""


def test_get_processes(example_processes, client): # pylint: disable=unused-argument
"""Test listing existing processes."""
response = client.get("/processes/")

assert response.status_code == 200
assert len(response.json()) == 12


def test_get_processes_projectable(client):
"""Test get projectable properites for processes."""
response = client.get("/processes/projectable_properties")

assert response.status_code == 200
assert response.json() == [
"id",
"uuid",
"node_type",
"process_type",
"label",
"description",
"ctime",
"mtime",
"user_id",
"dbcomputer_id",
"attributes",
"extras",
]


def test_get_single_processes(
example_processes, client
): # pylint: disable=unused-argument
"""Test retrieving a single processes."""
for proc_id in example_processes:
response = client.get("/processes/{}".format(proc_id))
assert response.status_code == 200


def test_add_process(
default_test_add_process, client, authenticate
): # pylint: disable=unused-argument
"""Test adding new process"""
code_id, x_id, y_id = default_test_add_process
response = client.post(
"/processes",
json={
"label": "test_new_process",
"process_entry_point": "aiida.calculations:arithmetic.add",
"inputs": {
"code.uuid": code_id,
"x.uuid": x_id,
"y.uuid": y_id,
"metadata": {
"description": "Test job submission with the add plugin",
},
},
},
)
assert response.status_code == 200


def test_add_process_invalid_entry_point(
default_test_add_process, client, authenticate
): # pylint: disable=unused-argument
"""Test adding new process with invalid entry point"""
code_id, x_id, y_id = default_test_add_process
response = client.post(
"/processes",
json={
"label": "test_new_process",
"process_entry_point": "wrong_entry_point",
"inputs": {
"code.uuid": code_id,
"x.uuid": x_id,
"y.uuid": y_id,
"metadata": {
"description": "Test job submission with the add plugin",
},
},
},
)
assert response.status_code == 404


def test_add_process_invalid_node_id(
default_test_add_process, client, authenticate
): # pylint: disable=unused-argument
"""Test adding new process with invalid Node ID"""
code_id, x_id, _ = default_test_add_process
response = client.post(
"/processes",
json={
"label": "test_new_process",
"process_entry_point": "aiida.calculations:arithmetic.add",
"inputs": {
"code.uuid": code_id,
"x.uuid": x_id,
"y.uuid": "891a9efa-f90e-11eb-9a03-0242ac130003",
"metadata": {
"description": "Test job submission with the add plugin",
},
},
},
)
assert response.status_code == 404
assert response.json() == {
"detail": "Node ID: 891a9efa-f90e-11eb-9a03-0242ac130003 does not exist."
}

0 comments on commit ee6a656

Please sign in to comment.