Skip to content

Commit

Permalink
Merge branch 'master' into processes
Browse files Browse the repository at this point in the history
  • Loading branch information
NinadBhat committed Aug 19, 2021
2 parents 1dfa072 + ac6e812 commit e2f7c35
Show file tree
Hide file tree
Showing 10 changed files with 576 additions and 2 deletions.
3 changes: 2 additions & 1 deletion aiida_restapi/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,12 @@
from fastapi import FastAPI

from aiida_restapi.graphql import main
from aiida_restapi.routers import auth, computers, groups, process, users
from aiida_restapi.routers import auth, computers, groups, nodes, process, users

app = FastAPI()
app.include_router(auth.router)
app.include_router(computers.router)
app.include_router(nodes.router)
app.include_router(groups.router)
app.include_router(users.router)
app.include_router(process.router)
Expand Down
137 changes: 137 additions & 0 deletions aiida_restapi/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,17 +6,47 @@
"""
# pylint: disable=too-few-public-methods

import inspect
import io
from datetime import datetime
from typing import ClassVar, Dict, List, Optional, Type, TypeVar
from uuid import UUID

from aiida import orm
from aiida.restapi.common.identifiers import load_entry_point_from_full_type
from fastapi import Form
from pydantic import BaseModel, Field

# Template type for subclasses of `AiidaModel`
ModelType = TypeVar("ModelType", bound="AiidaModel")


def as_form(cls: Type[BaseModel]) -> Type[BaseModel]:
"""
Adds an as_form class method to decorated models. The as_form class method
can be used with FastAPI endpoints
Note: Taken from https://github.com/tiangolo/fastapi/issues/2387
"""
new_params = [
inspect.Parameter(
field.alias,
inspect.Parameter.POSITIONAL_ONLY,
default=(Form(field.default) if not field.required else Form(...)),
)
for field in cls.__fields__.values()
]

async def _as_form(**data: Dict) -> BaseModel:
return cls(**data)

sig = inspect.signature(_as_form)
sig = sig.replace(parameters=new_params)
_as_form.__signature__ = sig # type: ignore
setattr(cls, "as_form", _as_form)
return cls


class AiidaModel(BaseModel):
"""A mapping of an AiiDA entity to a pydantic model."""

Expand All @@ -26,6 +56,7 @@ class Config:
"""The models configuration."""

orm_mode = True
extra = "forbid"

@classmethod
def get_projectable_properties(cls) -> List[str]:
Expand Down Expand Up @@ -88,6 +119,11 @@ class User(AiidaModel):

_orm_entity = orm.User

class Config:
"""The models configuration."""

extra = "allow"

id: Optional[int] = Field(description="Unique user id (pk)")
email: str = Field(description="Email address of the user")
first_name: Optional[str] = Field(description="First name of the user")
Expand Down Expand Up @@ -119,6 +155,107 @@ class Computer(AiidaModel):
description="General settings for these communication and management protocols"
)

description: Optional[str] = Field(description="Description of node")


class Node(AiidaModel):
"""AiiDA Node Model."""

_orm_entity = orm.Node

id: Optional[int] = Field(description="Unique id (pk)")
uuid: Optional[UUID] = Field(description="Unique uuid")
node_type: Optional[str] = Field(description="Node type")
process_type: Optional[str] = Field(description="Process type")
label: str = Field(description="Label of node")
description: Optional[str] = Field(description="Description of node")
ctime: Optional[datetime] = Field(description="Creation time")
mtime: Optional[datetime] = Field(description="Last modification time")
user_id: Optional[int] = Field(description="Created by user id (pk)")
dbcomputer_id: Optional[int] = Field(description="Associated computer id (pk)")
attributes: Optional[Dict] = Field(
description="Variable attributes of the node",
)
extras: Optional[Dict] = Field(
description="Variable extras (unsealed) of the node",
)


@as_form
class Node_Post(AiidaModel):
"""AiiDA model for posting Nodes."""

node_type: Optional[str] = Field(description="Node type")
process_type: Optional[str] = Field(description="Process type")
label: str = Field(description="Label of node")
description: Optional[str] = Field(description="Description of node")
user_id: Optional[int] = Field(description="Created by user id (pk)")
dbcomputer_id: Optional[int] = Field(description="Associated computer id (pk)")
attributes: Optional[Dict] = Field(
description="Variable attributes of the node",
)
extras: Optional[Dict] = Field(
description="Variable extras (unsealed) of the node",
)

@classmethod
def create_new_node(
cls: Type[ModelType],
node_type: str,
node_dict: dict,
) -> orm.Node:
"Create and Store new Node"

orm_class = load_entry_point_from_full_type(node_type)
attributes = node_dict.pop("attributes", {})
extras = node_dict.pop("extras", {})

if issubclass(orm_class, orm.BaseType):
orm_object = orm_class(
attributes["value"],
**node_dict,
)
elif issubclass(orm_class, orm.Dict):
orm_object = orm_class(
dict=attributes,
**node_dict,
)
elif issubclass(orm_class, orm.Code):
orm_object = orm_class()
orm_object.set_remote_computer_exec(
(
orm.Computer.get(id=node_dict.get("dbcomputer_id")),
attributes["remote_exec_path"],
)
)
orm_object.label = node_dict.get("label")
else:
orm_object = load_entry_point_from_full_type(node_type)(**node_dict)
orm_object.set_attribute_many(attributes)

orm_object.set_extra_many(extras)
orm_object.store()
return orm_object

@classmethod
def create_new_node_with_file(
cls: Type[ModelType],
node_type: str,
node_dict: dict,
file: bytes,
) -> orm.Node:
"Create and Store new Node with file"
attributes = node_dict.pop("attributes", {})
extras = node_dict.pop("extras", {})

orm_object = load_entry_point_from_full_type(node_type)(
file=io.BytesIO(file), **node_dict, **attributes
)

orm_object.set_extra_many(extras)
orm_object.store()
return orm_object


class Group(AiidaModel):
"""AiiDA Group model."""
Expand Down
1 change: 1 addition & 0 deletions aiida_restapi/routers/groups.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ async def read_group(group_id: int) -> Optional[Group]:


@router.post("/groups", response_model=Group)
@with_dbenv()
async def create_group(
group: Group_Post,
current_user: User = Depends(
Expand Down
97 changes: 97 additions & 0 deletions aiida_restapi/routers/nodes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
# -*- coding: utf-8 -*-
"""Declaration of FastAPI application."""
from typing import List, Optional

from aiida import orm
from aiida.cmdline.utils.decorators import with_dbenv
from fastapi import APIRouter, Depends, File, HTTPException
from importlib_metadata import entry_points

from aiida_restapi import models

from .auth import get_current_active_user

router = APIRouter()

ENTRY_POINTS = entry_points()


@router.get("/nodes", response_model=List[models.Node])
@with_dbenv()
async def read_nodes() -> List[models.Node]:
"""Get list of all nodes"""
return models.Node.get_entities()


@router.get("/nodes/projectable_properties", response_model=List[str])
async def get_nodes_projectable_properties() -> List[str]:
"""Get projectable properties for nodes endpoint"""

return models.Node.get_projectable_properties()


@router.get("/nodes/{nodes_id}", response_model=models.Node)
@with_dbenv()
async def read_node(nodes_id: int) -> Optional[models.Node]:
"""Get nodes by id."""
qbobj = orm.QueryBuilder()

qbobj.append(orm.Node, filters={"id": nodes_id}, project=["**"], tag="node").limit(
1
)
return qbobj.dict()[0]["node"]


@router.post("/nodes", response_model=models.Node)
@with_dbenv()
async def create_node(
node: models.Node_Post,
current_user: models.User = Depends(
get_current_active_user
), # pylint: disable=unused-argument
) -> models.Node:
"""Create new AiiDA node."""

node_dict = node.dict(exclude_unset=True)
node_type = node_dict.pop("node_type", None)

try:
(entry_point_node,) = ENTRY_POINTS.select(
group="aiida.rest.post", name=node_type
)
except ValueError as exc:
raise HTTPException(
status_code=404, detail="Entry point '{}' not recognized.".format(node_type)
) from exc

try:
orm_object = entry_point_node.load().create_new_node(node_type, node_dict)
except (TypeError, ValueError, KeyError) as err:
raise HTTPException(status_code=400, detail="Error: {0}".format(err)) from err

return models.Node.from_orm(orm_object)


@router.post("/nodes/singlefile", response_model=models.Node)
@with_dbenv()
async def create_upload_file(
upload_file: bytes = File(...),
params: models.Node_Post = Depends(models.Node_Post.as_form), # type: ignore # pylint: disable=maybe-no-member
current_user: models.User = Depends(
get_current_active_user
), # pylint: disable=unused-argument
) -> models.Node:
"""Endpoint for uploading file data"""
node_dict = params.dict(exclude_unset=True, exclude_none=True)
node_type = node_dict.pop("node_type", None)

try:
(entry_point_node,) = entry_points(group="aiida.rest.post", name=node_type)
except KeyError as exc:
raise KeyError("Entry point '{}' not recognized.".format(node_type)) from exc

orm_object = entry_point_node.load().create_new_node_with_file(
node_type, node_dict, upload_file
)

return models.Node.from_orm(orm_object)
1 change: 1 addition & 0 deletions aiida_restapi/routers/users.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ async def read_user(user_id: int) -> Optional[User]:


@router.post("/users", response_model=User)
@with_dbenv()
async def create_user(
user: User,
current_user: User = Depends(
Expand Down
11 changes: 11 additions & 0 deletions conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,17 @@ def default_groups():
return [group_1.id, group_2.id]


@pytest.fixture(scope="function")
def default_nodes():
"""Populate database with some nodes."""
node_1 = orm.Int(1).store()
node_2 = orm.Float(1.1).store()
node_3 = orm.Str("test_string").store()
node_4 = orm.Bool(False).store()

return [node_1.id, node_2.id, node_3.id, node_4.id]


@pytest.fixture(scope="function")
def authenticate():
"""Authenticate user.
Expand Down
12 changes: 12 additions & 0 deletions setup.json
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,18 @@
],
"aiida.cmdline.data": [
"restapi = aiida_restapi.cli:data_cli"
],
"aiida.rest.post": [
"data.str.Str.| = aiida_restapi.models:Node_Post",
"data.float.Float.| = aiida_restapi.models:Node_Post",
"data.int.Int.| = aiida_restapi.models:Node_Post",
"data.bool.Bool.| = aiida_restapi.models:Node_Post",
"data.structure.StructureData.| = aiida_restapi.models:Node_Post",
"data.orbital.OrbitalData.| = aiida_restapi.models:Node_Post",
"data.list.List.| = aiida_restapi.models:Node_Post",
"data.dict.Dict.| = aiida_restapi.models:Node_Post",
"data.singlefile.SingleFileData.| = aiida_restapi.models:Node_Post",
"data.code.Code.| = aiida_restapi.models:Node_Post"
]
},
"include_package_data": true,
Expand Down
1 change: 1 addition & 0 deletions tests/test_computers.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ def test_get_computers_projectable(client):
"scheduler_type",
"transport_type",
"metadata",
"description",
]


Expand Down
3 changes: 2 additions & 1 deletion tests/test_models/test_computer_get_entities.yml
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
- hostname: localhost_1
- description: ''
hostname: localhost_1
id: int
metadata: {}
name: test_comp_1
Expand Down

0 comments on commit e2f7c35

Please sign in to comment.