Skip to content

Commit

Permalink
black, ruff, lint
Browse files Browse the repository at this point in the history
  • Loading branch information
bernardusrendy committed Jun 18, 2024
1 parent afec48d commit aff8a52
Show file tree
Hide file tree
Showing 13 changed files with 168 additions and 101 deletions.
21 changes: 12 additions & 9 deletions alab_management/experiment_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
tasks and samples and mark the finished tasks in the database when it is
done.
"""

import time
from typing import Any

Expand All @@ -30,7 +31,7 @@ def __init__(self):

config = AlabOSConfig()
self.__copy_to_completed_db = (
"mongodb_completed" in config
"mongodb_completed" in config
) # if this is not defined in the config, assume this feature is not being used.
if self.__copy_to_completed_db:
self.completed_experiment_view = CompletedExperimentView()
Expand Down Expand Up @@ -91,7 +92,9 @@ def _handle_pending_experiment(self, experiment: dict[str, Any]):
},
)
if task_graph.has_cycle():
self.experiment_view.update_experiment_status(experiment["_id"], ExperimentStatus.ERROR)
self.experiment_view.update_experiment_status(
experiment["_id"], ExperimentStatus.ERROR
)
print(f"Experiment ({experiment['_id']}) has a cycle in the graph.")
return

Expand Down Expand Up @@ -156,13 +159,13 @@ def mark_completed_experiments(self):

# if all the tasks of an experiment have been finished
if all(
self.task_view.get_status(task_id=task_id)
in {
TaskStatus.COMPLETED,
TaskStatus.ERROR,
TaskStatus.CANCELLED,
}
for task_id in task_ids
self.task_view.get_status(task_id=task_id)
in {
TaskStatus.COMPLETED,
TaskStatus.ERROR,
TaskStatus.CANCELLED,
}
for task_id in task_ids
):
self.experiment_view.update_experiment_status(
exp_id=experiment["_id"], status=ExperimentStatus.COMPLETED
Expand Down
4 changes: 1 addition & 3 deletions alab_management/lab_view.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,9 +126,7 @@ def request_resources(
device_type: self._device_client.create_device_wrapper(device_name)
for device_type, device_name in devices.items()
} # type: ignore
self._task_view.update_status(
task_id=self.task_id, status=TaskStatus.RUNNING
)
self._task_view.update_status(task_id=self.task_id, status=TaskStatus.RUNNING)
yield devices, sample_positions

self._resource_requester.release_resources(request_id=request_id)
Expand Down
17 changes: 11 additions & 6 deletions alab_management/resource_manager/resource_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,9 @@ def handle_released_resources(self):
self._release_devices(devices)
self._release_sample_positions(sample_positions)
self.update_request_status(
request_id=request["_id"], status=RequestStatus.RELEASED, original_status=RequestStatus.NEED_RELEASE
request_id=request["_id"],
status=RequestStatus.RELEASED,
original_status=RequestStatus.NEED_RELEASE,
)

def handle_requested_resources(self):
Expand All @@ -84,9 +86,12 @@ def _handle_requested_resources(self, request_entry: dict[str, Any]):
task_id = request_entry["task_id"]

task_status = self.task_view.get_status(task_id=task_id)
if (task_status != TaskStatus.REQUESTING_RESOURCES or
task_id in {task["task_id"] for task in self.task_view.get_tasks_to_be_canceled(
canceling_progress=CancelingProgress.WORKER_NOTIFIED)}):
if task_status != TaskStatus.REQUESTING_RESOURCES or task_id in {
task["task_id"]
for task in self.task_view.get_tasks_to_be_canceled(
canceling_progress=CancelingProgress.WORKER_NOTIFIED
)
}:
# this implies the Task has been cancelled or errored somewhere else in the chain -- we should
# not allocate any resources to the broken Task.
self.update_request_status(
Expand Down Expand Up @@ -197,7 +202,7 @@ def _occupy_devices(self, devices: dict[str, dict[str, Any]], task_id: ObjectId)
)

def _occupy_sample_positions(
self, sample_positions: dict[str, list[dict[str, Any]]], task_id: ObjectId
self, sample_positions: dict[str, list[dict[str, Any]]], task_id: ObjectId
):
for sample_positions_ in sample_positions.values():
for sample_position_ in sample_positions_:
Expand All @@ -211,7 +216,7 @@ def _release_devices(self, devices: dict[str, dict[str, Any]]):
self.device_view.release_device(device["name"])

def _release_sample_positions(
self, sample_positions: dict[str, list[dict[str, Any]]]
self, sample_positions: dict[str, list[dict[str, Any]]]
):
for sample_positions_ in sample_positions.values():
for sample_position in sample_positions_:
Expand Down
91 changes: 52 additions & 39 deletions alab_management/resource_manager/resource_requester.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
TaskLauncher is the core module of the system,
which actually executes the tasks.
"""

import concurrent
import time
from concurrent.futures import Future
Expand Down Expand Up @@ -102,13 +103,20 @@ class RequestMixin:
def __init__(self):
self._request_collection = get_collection("requests")

def update_request_status(self, request_id: ObjectId, status: RequestStatus,
original_status: RequestStatus | list[RequestStatus] = None):
def update_request_status(
self,
request_id: ObjectId,
status: RequestStatus,
original_status: RequestStatus | list[RequestStatus] = None,
):
"""Update the status of a request by request_id."""
if original_status is not None:
if isinstance(original_status, list):
value_returned = self._request_collection.update_one(
{"_id": request_id, "status": {"$in": [status.name for status in original_status]}},
{
"_id": request_id,
"status": {"$in": [status.name for status in original_status]},
},
{"$set": {"status": status.name}},
)
else:
Expand Down Expand Up @@ -153,8 +161,8 @@ class ResourceRequester(RequestMixin):
"""

def __init__(
self,
task_id: ObjectId,
self,
task_id: ObjectId,
):
self._request_collection = get_collection("requests")
self._waiting: dict[ObjectId, dict[str, Any]] = {}
Expand All @@ -180,10 +188,10 @@ def __close__(self):
__del__ = __close__

def request_resources(
self,
resource_request: _ResourceRequestDict,
timeout: float | None = None,
priority: TaskPriority | int | None = None,
self,
resource_request: _ResourceRequestDict,
timeout: float | None = None,
priority: TaskPriority | int | None = None,
) -> dict[str, Any]:
"""
Request lab resources.
Expand Down Expand Up @@ -247,14 +255,10 @@ def request_resources(
result = self.get_concurrent_result(f, timeout=timeout)
except concurrent.futures.TimeoutError as e:
# if the request is not fulfilled, cancel it to make sure the resources are released
request = self._request_collection.find_one_and_update({
"_id": _id,
"status": {"$ne": RequestStatus.FULFILLED.name}
}, {
"$set": {
"status": RequestStatus.CANCELED.name
}
})
request = self._request_collection.find_one_and_update(
{"_id": _id, "status": {"$ne": RequestStatus.FULFILLED.name}},
{"$set": {"status": RequestStatus.CANCELED.name}},
)
if request is not None:
raise CombinedTimeoutError(
f"Request {result.inserted_id} timed out after {timeout} seconds."
Expand Down Expand Up @@ -292,17 +296,23 @@ def release_resources(self, request_id: ObjectId):
request = self.get_request(request_id)
if request["status"] in [RequestStatus.CANCELED.name, RequestStatus.ERROR.name]:
if ("assigned_devices" in request) or (
"assigned_sample_positions" in request
"assigned_sample_positions" in request
):
self.update_request_status(request_id, RequestStatus.NEED_RELEASE, original_status=[
RequestStatus.CANCELED, RequestStatus.ERROR
])
self.update_request_status(
request_id,
RequestStatus.NEED_RELEASE,
original_status=[RequestStatus.CANCELED, RequestStatus.ERROR],
)
else:
# If it doesn't have assigned resources, just leave it as CANCELED or ERROR
return
# Requests that were fulfilled definitely have assigned resources, so release them
elif request["status"] == RequestStatus.FULFILLED.name:
self.update_request_status(request_id, RequestStatus.NEED_RELEASE, original_status=RequestStatus.FULFILLED)
self.update_request_status(
request_id,
RequestStatus.NEED_RELEASE,
original_status=RequestStatus.FULFILLED,
)

# wait for the request to be released or canceled or errored during the release
while self.get_request(request_id, projection=["status"])["status"] not in [
Expand Down Expand Up @@ -342,24 +352,27 @@ def release_all_resources(self):
RequestStatus.CANCELED.name,
RequestStatus.ERROR.name,
] and (
("assigned_devices" in request)
or ("assigned_sample_positions" in request)
("assigned_devices" in request)
or ("assigned_sample_positions" in request)
):
self.update_request_status(request["_id"], RequestStatus.NEED_RELEASE,
original_status=[RequestStatus.CANCELED, RequestStatus.ERROR])
self.update_request_status(
request["_id"],
RequestStatus.NEED_RELEASE,
original_status=[RequestStatus.CANCELED, RequestStatus.ERROR],
)
assigned_cancel_error_requests_id.append(request["_id"])

# wait for all the requests to be released or canceled or errored during the release
while any(
(
request["status"]
not in [
RequestStatus.RELEASED.name,
RequestStatus.CANCELED.name,
RequestStatus.ERROR.name,
]
)
for request in self.get_requests_by_task_id(self.task_id)
(
request["status"]
not in [
RequestStatus.RELEASED.name,
RequestStatus.CANCELED.name,
RequestStatus.ERROR.name,
]
)
for request in self.get_requests_by_task_id(self.task_id)
):
time.sleep(0.5)

Expand Down Expand Up @@ -435,9 +448,9 @@ def _handle_canceled_request(self, request_id: ObjectId):

@staticmethod
def _post_process_requested_resource(
devices: dict[type[BaseDevice] | str, str],
sample_positions: dict[str, list[str]],
resource_request: dict[str | type[BaseDevice] | None, dict[str, int]],
devices: dict[type[BaseDevice] | str, str],
sample_positions: dict[str, list[str]],
resource_request: dict[str | type[BaseDevice] | None, dict[str, int]],
):
processed_sample_positions: dict[
type[BaseDevice] | str | None, dict[str, list[str]]
Expand All @@ -456,7 +469,7 @@ def _post_process_requested_resource(
f"{devices[device_request]}{SamplePosition.SEPARATOR}"
)
if not reply_prefix.startswith(
device_prefix
device_prefix
): # don't prepend the device prefix again for nested requests
reply_prefix = device_prefix + reply_prefix
processed_sample_positions[device_request][prefix] = sample_positions[
Expand Down
1 change: 1 addition & 0 deletions alab_management/scripts/launch_worker.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Launch Dramatiq worker to submit tasks."""

from alab_management.task_manager.task_manager import TaskManager


Expand Down
4 changes: 3 additions & 1 deletion alab_management/task_manager/task_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,5 +152,7 @@ def handle_tasks_to_be_canceled(self):
self.task_view.update_canceling_progress(
task_id=task_entry["task_id"],
canceling_progress=CancelingProgress.WORKER_NOTIFIED,
original_progress=CancelingProgress[task_entry["canceling_progress"]],
original_progress=CancelingProgress[
task_entry["canceling_progress"]
],
)
2 changes: 2 additions & 0 deletions alab_management/user_input.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,7 @@ def retrieve_user_input_with_note(self, request_id: ObjectId) -> tuple[str, str]
raise
return request["response"], request["note"]


def request_user_input(
task_id: ObjectId | None,
prompt: str,
Expand Down Expand Up @@ -208,6 +209,7 @@ def request_maintenance_input(prompt: str, options: list[str]):
category="Maintenance",
)


def request_user_input_with_note(
task_id: ObjectId | None,
prompt: str,
Expand Down
6 changes: 4 additions & 2 deletions alab_management/utils/data_objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,9 @@ def init(cls):
)
sim_mode_flag = AlabOSConfig().is_sim_mode()
# force to enable sim mode, just in case
cls.db = cls.client[AlabOSConfig()["general"]["name"] + ("_sim" * sim_mode_flag)]
cls.db = cls.client[
AlabOSConfig()["general"]["name"] + ("_sim" * sim_mode_flag)
]


class _GetCompletedMongoCollection(_BaseGetMongoCollection):
Expand All @@ -76,7 +78,7 @@ def init(cls):
if sim_mode_flag:
cls.db = cls.client[
AlabOSConfig()["general"]["name"] + "(completed)" + "_sim"
]
]
else:
cls.db = cls.client[AlabOSConfig()["general"]["name"] + "(completed)"]
# type: ignore # pylint: disable=unsubscriptable-object
Expand Down
5 changes: 4 additions & 1 deletion tests/fake_lab/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,10 @@
from .devices.furnace import Furnace
from .devices.robot_arm import RobotArm
from .tasks.ending import Ending
from .tasks.error_handling_task import ErrorHandlingRecoverable, ErrorHandlingUnrecoverable
from .tasks.error_handling_task import (
ErrorHandlingRecoverable,
ErrorHandlingUnrecoverable,
)
from .tasks.heating import Heating
from .tasks.infinite_task import InfiniteTask
from .tasks.moving import Moving
Expand Down
14 changes: 11 additions & 3 deletions tests/fake_lab/tasks/error_handling_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,10 @@ def __init__(self, samples: list[ObjectId], *args, **kwargs):
self.sample = samples[0]

def run(self):
with self.lab_view.request_resources({DeviceThatFails: {"failures": 1}}) as (devices, _):
with self.lab_view.request_resources({DeviceThatFails: {"failures": 1}}) as (
devices,
_,
):
device_that_fails = devices[DeviceThatFails]
device_that_fails.fail()

Expand All @@ -22,12 +25,17 @@ def __init__(self, samples: list[ObjectId], *args, **kwargs):
self.sample = samples[0]

def run(self):
with self.lab_view.request_resources({DeviceThatFails: {"failures": 1}}) as (devices, _):
with self.lab_view.request_resources({DeviceThatFails: {"failures": 1}}) as (
devices,
_,
):
device_that_fails_ = devices[DeviceThatFails]
try:
device_that_fails_.fail()
except Exception as e:
response = self.lab_view.request_user_input("What should I do?", options=["OK", "Abort"])
response = self.lab_view.request_user_input(
"What should I do?", options=["OK", "Abort"]
)
if response == "OK":
pass
else:
Expand Down
4 changes: 3 additions & 1 deletion tests/test_lab_view.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,9 @@ def test_request_resources(self):

self.assertEqual(
"LOCKED",
self.sample_view.get_sample_position_status("furnace_1/inside/1")[0].name,
self.sample_view.get_sample_position_status("furnace_1/inside/1")[
0
].name,
)
self.assertEqual(
"LOCKED",
Expand Down
Loading

0 comments on commit aff8a52

Please sign in to comment.