Skip to content

Commit

Permalink
Merge pull request #74 from CederGroupHub/user_input_request_with_note
Browse files Browse the repository at this point in the history
User input request with note
  • Loading branch information
odartsi committed Jun 18, 2024
2 parents 8ae443b + 3ade681 commit 1cadb71
Show file tree
Hide file tree
Showing 13 changed files with 228 additions and 102 deletions.
21 changes: 12 additions & 9 deletions alab_management/experiment_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
tasks and samples and mark the finished tasks in the database when it is
done.
"""

import time
from typing import Any

Expand All @@ -30,7 +31,7 @@ def __init__(self):

config = AlabOSConfig()
self.__copy_to_completed_db = (
"mongodb_completed" in config
"mongodb_completed" in config
) # if this is not defined in the config, assume this feature is not being used.
if self.__copy_to_completed_db:
self.completed_experiment_view = CompletedExperimentView()
Expand Down Expand Up @@ -91,7 +92,9 @@ def _handle_pending_experiment(self, experiment: dict[str, Any]):
},
)
if task_graph.has_cycle():
self.experiment_view.update_experiment_status(experiment["_id"], ExperimentStatus.ERROR)
self.experiment_view.update_experiment_status(
experiment["_id"], ExperimentStatus.ERROR
)
print(f"Experiment ({experiment['_id']}) has a cycle in the graph.")
return

Expand Down Expand Up @@ -156,13 +159,13 @@ def mark_completed_experiments(self):

# if all the tasks of an experiment have been finished
if all(
self.task_view.get_status(task_id=task_id)
in {
TaskStatus.COMPLETED,
TaskStatus.ERROR,
TaskStatus.CANCELLED,
}
for task_id in task_ids
self.task_view.get_status(task_id=task_id)
in {
TaskStatus.COMPLETED,
TaskStatus.ERROR,
TaskStatus.CANCELLED,
}
for task_id in task_ids
):
self.experiment_view.update_experiment_status(
exp_id=experiment["_id"], status=ExperimentStatus.COMPLETED
Expand Down
16 changes: 12 additions & 4 deletions alab_management/lab_view.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from alab_management.task_view.task import BaseTask
from alab_management.task_view.task_enums import TaskPriority, TaskStatus
from alab_management.task_view.task_view import TaskView
from alab_management.user_input import request_user_input
from alab_management.user_input import request_user_input, request_user_input_with_note


class DeviceRunningException(Exception):
Expand Down Expand Up @@ -126,9 +126,7 @@ def request_resources(
device_type: self._device_client.create_device_wrapper(device_name)
for device_type, device_name in devices.items()
} # type: ignore
self._task_view.update_status(
task_id=self.task_id, status=TaskStatus.RUNNING
)
self._task_view.update_status(task_id=self.task_id, status=TaskStatus.RUNNING)
yield devices, sample_positions

self._resource_requester.release_resources(request_id=request_id)
Expand Down Expand Up @@ -335,6 +333,16 @@ def request_user_input(self, prompt: str, options: list[str]) -> str:
"""
return request_user_input(task_id=self.task_id, prompt=prompt, options=options)

def request_user_input_with_note(
    self, prompt: str, options: list[str]
) -> tuple[str, str]:
    """
    Ask the user to pick one of ``options`` and to leave a note.

    The request is forwarded, together with this view's ``task_id``, to the
    module-level ``request_user_input_with_note`` helper. The call blocks
    until the user inputs something.

    Returns the value chosen by the user and the accompanying note.
    """
    response_and_note = request_user_input_with_note(
        prompt=prompt,
        options=options,
        task_id=self.task_id,
    )
    return response_and_note

@property
def priority(self) -> int:
"""Get the priority of the task."""
Expand Down
17 changes: 11 additions & 6 deletions alab_management/resource_manager/resource_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,9 @@ def handle_released_resources(self):
self._release_devices(devices)
self._release_sample_positions(sample_positions)
self.update_request_status(
request_id=request["_id"], status=RequestStatus.RELEASED, original_status=RequestStatus.NEED_RELEASE
request_id=request["_id"],
status=RequestStatus.RELEASED,
original_status=RequestStatus.NEED_RELEASE,
)

def handle_requested_resources(self):
Expand All @@ -84,9 +86,12 @@ def _handle_requested_resources(self, request_entry: dict[str, Any]):
task_id = request_entry["task_id"]

task_status = self.task_view.get_status(task_id=task_id)
if (task_status != TaskStatus.REQUESTING_RESOURCES or
task_id in {task["task_id"] for task in self.task_view.get_tasks_to_be_canceled(
canceling_progress=CancelingProgress.WORKER_NOTIFIED)}):
if task_status != TaskStatus.REQUESTING_RESOURCES or task_id in {
task["task_id"]
for task in self.task_view.get_tasks_to_be_canceled(
canceling_progress=CancelingProgress.WORKER_NOTIFIED
)
}:
# this implies the Task has been cancelled or errored somewhere else in the chain -- we should
# not allocate any resources to the broken Task.
self.update_request_status(
Expand Down Expand Up @@ -197,7 +202,7 @@ def _occupy_devices(self, devices: dict[str, dict[str, Any]], task_id: ObjectId)
)

def _occupy_sample_positions(
self, sample_positions: dict[str, list[dict[str, Any]]], task_id: ObjectId
self, sample_positions: dict[str, list[dict[str, Any]]], task_id: ObjectId
):
for sample_positions_ in sample_positions.values():
for sample_position_ in sample_positions_:
Expand All @@ -211,7 +216,7 @@ def _release_devices(self, devices: dict[str, dict[str, Any]]):
self.device_view.release_device(device["name"])

def _release_sample_positions(
self, sample_positions: dict[str, list[dict[str, Any]]]
self, sample_positions: dict[str, list[dict[str, Any]]]
):
for sample_positions_ in sample_positions.values():
for sample_position in sample_positions_:
Expand Down
91 changes: 52 additions & 39 deletions alab_management/resource_manager/resource_requester.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
TaskLauncher is the core module of the system,
which actually executes the tasks.
"""

import concurrent
import time
from concurrent.futures import Future
Expand Down Expand Up @@ -102,13 +103,20 @@ class RequestMixin:
def __init__(self):
self._request_collection = get_collection("requests")

def update_request_status(self, request_id: ObjectId, status: RequestStatus,
original_status: RequestStatus | list[RequestStatus] = None):
def update_request_status(
self,
request_id: ObjectId,
status: RequestStatus,
original_status: RequestStatus | list[RequestStatus] = None,
):
"""Update the status of a request by request_id."""
if original_status is not None:
if isinstance(original_status, list):
value_returned = self._request_collection.update_one(
{"_id": request_id, "status": {"$in": [status.name for status in original_status]}},
{
"_id": request_id,
"status": {"$in": [status.name for status in original_status]},
},
{"$set": {"status": status.name}},
)
else:
Expand Down Expand Up @@ -153,8 +161,8 @@ class ResourceRequester(RequestMixin):
"""

def __init__(
self,
task_id: ObjectId,
self,
task_id: ObjectId,
):
self._request_collection = get_collection("requests")
self._waiting: dict[ObjectId, dict[str, Any]] = {}
Expand All @@ -180,10 +188,10 @@ def __close__(self):
__del__ = __close__

def request_resources(
self,
resource_request: _ResourceRequestDict,
timeout: float | None = None,
priority: TaskPriority | int | None = None,
self,
resource_request: _ResourceRequestDict,
timeout: float | None = None,
priority: TaskPriority | int | None = None,
) -> dict[str, Any]:
"""
Request lab resources.
Expand Down Expand Up @@ -247,14 +255,10 @@ def request_resources(
result = self.get_concurrent_result(f, timeout=timeout)
except concurrent.futures.TimeoutError as e:
# if the request is not fulfilled, cancel it to make sure the resources are released
request = self._request_collection.find_one_and_update({
"_id": _id,
"status": {"$ne": RequestStatus.FULFILLED.name}
}, {
"$set": {
"status": RequestStatus.CANCELED.name
}
})
request = self._request_collection.find_one_and_update(
{"_id": _id, "status": {"$ne": RequestStatus.FULFILLED.name}},
{"$set": {"status": RequestStatus.CANCELED.name}},
)
if request is not None:
raise CombinedTimeoutError(
f"Request {result.inserted_id} timed out after {timeout} seconds."
Expand Down Expand Up @@ -292,17 +296,23 @@ def release_resources(self, request_id: ObjectId):
request = self.get_request(request_id)
if request["status"] in [RequestStatus.CANCELED.name, RequestStatus.ERROR.name]:
if ("assigned_devices" in request) or (
"assigned_sample_positions" in request
"assigned_sample_positions" in request
):
self.update_request_status(request_id, RequestStatus.NEED_RELEASE, original_status=[
RequestStatus.CANCELED, RequestStatus.ERROR
])
self.update_request_status(
request_id,
RequestStatus.NEED_RELEASE,
original_status=[RequestStatus.CANCELED, RequestStatus.ERROR],
)
else:
# If it doesn't have assigned resources, just leave it as CANCELED or ERROR
return
# For the requests that were fulfilled, definitely have assigned resources, release them
elif request["status"] == RequestStatus.FULFILLED.name:
self.update_request_status(request_id, RequestStatus.NEED_RELEASE, original_status=RequestStatus.FULFILLED)
self.update_request_status(
request_id,
RequestStatus.NEED_RELEASE,
original_status=RequestStatus.FULFILLED,
)

# wait for the request to be released or canceled or errored during the release
while self.get_request(request_id, projection=["status"])["status"] not in [
Expand Down Expand Up @@ -342,24 +352,27 @@ def release_all_resources(self):
RequestStatus.CANCELED.name,
RequestStatus.ERROR.name,
] and (
("assigned_devices" in request)
or ("assigned_sample_positions" in request)
("assigned_devices" in request)
or ("assigned_sample_positions" in request)
):
self.update_request_status(request["_id"], RequestStatus.NEED_RELEASE,
original_status=[RequestStatus.CANCELED, RequestStatus.ERROR])
self.update_request_status(
request["_id"],
RequestStatus.NEED_RELEASE,
original_status=[RequestStatus.CANCELED, RequestStatus.ERROR],
)
assigned_cancel_error_requests_id.append(request["_id"])

# wait for all the requests to be released or canceled or errored during the release
while any(
(
request["status"]
not in [
RequestStatus.RELEASED.name,
RequestStatus.CANCELED.name,
RequestStatus.ERROR.name,
]
)
for request in self.get_requests_by_task_id(self.task_id)
(
request["status"]
not in [
RequestStatus.RELEASED.name,
RequestStatus.CANCELED.name,
RequestStatus.ERROR.name,
]
)
for request in self.get_requests_by_task_id(self.task_id)
):
time.sleep(0.5)

Expand Down Expand Up @@ -435,9 +448,9 @@ def _handle_canceled_request(self, request_id: ObjectId):

@staticmethod
def _post_process_requested_resource(
devices: dict[type[BaseDevice] | str, str],
sample_positions: dict[str, list[str]],
resource_request: dict[str | type[BaseDevice] | None, dict[str, int]],
devices: dict[type[BaseDevice] | str, str],
sample_positions: dict[str, list[str]],
resource_request: dict[str | type[BaseDevice] | None, dict[str, int]],
):
processed_sample_positions: dict[
type[BaseDevice] | str | None, dict[str, list[str]]
Expand All @@ -456,7 +469,7 @@ def _post_process_requested_resource(
f"{devices[device_request]}{SamplePosition.SEPARATOR}"
)
if not reply_prefix.startswith(
device_prefix
device_prefix
): # dont extra prepend for nested requests
reply_prefix = device_prefix + reply_prefix
processed_sample_positions[device_request][prefix] = sample_positions[
Expand Down
1 change: 1 addition & 0 deletions alab_management/scripts/launch_worker.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Launch Dramatiq worker to submit tasks."""

from alab_management.task_manager.task_manager import TaskManager


Expand Down
4 changes: 3 additions & 1 deletion alab_management/task_manager/task_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,5 +152,7 @@ def handle_tasks_to_be_canceled(self):
self.task_view.update_canceling_progress(
task_id=task_entry["task_id"],
canceling_progress=CancelingProgress.WORKER_NOTIFIED,
original_progress=CancelingProgress[task_entry["canceling_progress"]],
original_progress=CancelingProgress[
task_entry["canceling_progress"]
],
)
51 changes: 51 additions & 0 deletions alab_management/user_input.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,29 @@ def get_all_pending_requests(self) -> list:
self._input_collection.find({"status": UserRequestStatus.PENDING.value}),
)

def retrieve_user_input_with_note(self, request_id: ObjectId) -> tuple[str, str]:
    """
    Retrieve the user's response and note for a given request.

    Polls the input collection every 0.5 seconds until the request is no
    longer marked as pending.

    Args:
        request_id: id of the user input request to wait for.

    Returns:
        A ``(response, note)`` tuple read from the request document.

    Raises:
        ValueError: if no request with ``request_id`` exists.
    """
    try:
        while True:
            request = self._input_collection.find_one({"_id": request_id})
            if request is None:
                raise ValueError(
                    f"User input request id {request_id} does not exist!"
                )
            # Stop as soon as the request leaves PENDING; sleeping once more
            # after that would only delay the caller.
            if UserRequestStatus(request["status"]) != UserRequestStatus.PENDING:
                break
            time.sleep(0.5)
    except:  # noqa: E722
        # Deliberately bare: any interruption (including KeyboardInterrupt)
        # marks the request as errored before the exception is re-raised.
        # The status is written with the enum *value*, because statuses are
        # queried by ``.value`` and parsed back via ``UserRequestStatus(...)``.
        self._input_collection.update_one(
            {"_id": request_id},
            {"$set": {"status": UserRequestStatus.ERROR.value}},
        )
        raise
    return request["response"], request["note"]


def request_user_input(
task_id: ObjectId | None,
Expand Down Expand Up @@ -185,3 +208,31 @@ def request_maintenance_input(prompt: str, options: list[str]):
maintenance=True,
category="Maintenance",
)


def request_user_input_with_note(
    task_id: ObjectId | None,
    prompt: str,
    options: list[str],
    maintenance: bool = False,
    category: str = "Unknown Category",
) -> tuple[str, str]:
    """
    Request user input (with a note) through the dashboard.

    Inserts a new request via ``UserInputView`` and then blocks until a
    response is given.

    Args:
        task_id: id of the task requesting user input (may be ``None``).
        prompt: prompt to give the user.
        options: response options to give the user.
        maintenance: if true, mark this as a request for overall system
            maintenance.
        category: category stored with the request.

    Returns:
        The user response and the note, as returned by
        ``UserInputView.retrieve_user_input_with_note``.
    """
    view = UserInputView()
    request_fields = {
        "task_id": task_id,
        "prompt": prompt,
        "options": options,
        "maintenance": maintenance,
        "category": category,
    }
    new_request_id = view.insert_request(**request_fields)
    return view.retrieve_user_input_with_note(request_id=new_request_id)
6 changes: 4 additions & 2 deletions alab_management/utils/data_objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,9 @@ def init(cls):
)
sim_mode_flag = AlabOSConfig().is_sim_mode()
# force to enable sim mode, just in case
cls.db = cls.client[AlabOSConfig()["general"]["name"] + ("_sim" * sim_mode_flag)]
cls.db = cls.client[
AlabOSConfig()["general"]["name"] + ("_sim" * sim_mode_flag)
]


class _GetCompletedMongoCollection(_BaseGetMongoCollection):
Expand All @@ -76,7 +78,7 @@ def init(cls):
if sim_mode_flag:
cls.db = cls.client[
AlabOSConfig()["general"]["name"] + "(completed)" + "_sim"
]
]
else:
cls.db = cls.client[AlabOSConfig()["general"]["name"] + "(completed)"]
# type: ignore # pylint: disable=unsubscriptable-object
Expand Down
Loading

0 comments on commit 1cadb71

Please sign in to comment.