From c98778864068bed63abd1d3f6f5d7da264bafff0 Mon Sep 17 00:00:00 2001 From: yxd92326 Date: Tue, 8 Apr 2025 17:24:39 +0100 Subject: [PATCH 1/4] Remove updates of global state --- src/murfey/server/__init__.py | 12 -------- src/murfey/server/api/__init__.py | 38 ------------------------ src/murfey/server/demo_api.py | 48 +------------------------------ src/murfey/util/models.py | 7 ----- 4 files changed, 1 insertion(+), 104 deletions(-) diff --git a/src/murfey/server/__init__.py b/src/murfey/server/__init__.py index 41e2bf1d7..94a5431bb 100644 --- a/src/murfey/server/__init__.py +++ b/src/murfey/server/__init__.py @@ -57,7 +57,6 @@ get_security_config, ) from murfey.util.processing_params import default_spa_parameters -from murfey.util.state import global_state from murfey.util.tomo import midpoint try: @@ -2258,17 +2257,6 @@ def feedback_callback(header: dict, message: dict) -> None: time.sleep(2) _transport_object.transport.nack(header, requeue=True) return None - if global_state.get("data_collection_group_ids") and isinstance( - global_state["data_collection_group_ids"], dict - ): - global_state["data_collection_group_ids"] = { - **global_state["data_collection_group_ids"], - message.get("tag"): dcgid, - } - else: - global_state["data_collection_group_ids"] = { - message.get("tag"): dcgid - } _transport_object.transport.ack(header) if dcg_hooks := entry_points().select( group="murfey.hooks", name="data_collection_group" diff --git a/src/murfey/server/api/__init__.py b/src/murfey/server/api/__init__.py index 0e786b551..9ca80c63f 100644 --- a/src/murfey/server/api/__init__.py +++ b/src/murfey/server/api/__init__.py @@ -76,7 +76,6 @@ BLSampleImageParameters, BLSampleParameters, BLSubSampleParameters, - ClearanceKeys, ClientInfo, ContextInfo, CurrentGainRef, @@ -107,7 +106,6 @@ Visit, ) from murfey.util.processing_params import default_spa_parameters -from murfey.util.state import global_state from murfey.util.tomo import midpoint from murfey.workflows.spa.flush_spa_preprocess import ( register_foil_hole, @@ -1795,42 +1793,6 @@ async def make_gif( return {"output_gif": str(output_path)} -@router.post("/visits/{visit_name}/clean_state") -async def clean_state(visit_name: str, for_clearance: ClearanceKeys): - if global_state.get("data_collection_group_ids") and isinstance( - global_state["data_collection_group_ids"], dict - ): - global_state["data_collection_group_ids"] = { - k: v - for k, v in global_state["data_collection_group_ids"].items() - if k not in for_clearance.data_collection_group - } - if global_state.get("data_collection_ids") and isinstance( - global_state["data_collection_ids"], dict - ): - global_state["data_collection_ids"] = { - k: v - for k, v in global_state["data_collection_ids"].items() - if k not in for_clearance.data_collection - } - if global_state.get("processing_job_ids") and isinstance( - global_state["processing_job_ids"], dict - ): - global_state["processing_job_ids"] = { - k: v - for k, v in global_state["processing_job_ids"].items() - if k not in for_clearance.processing_job - } - if global_state.get("autoproc_program_ids") and isinstance( - global_state["autoproc_program_ids"], dict - ): - global_state["autoproc_program_ids"] = { - k: v - for k, v in global_state["autoproc_program_ids"].items() - if k not in for_clearance.autoproc_program - } - - @router.get("/new_client_id/") async def new_client_id(db=murfey_db): clients = db.exec(select(ClientEnvironment)).all() diff --git a/src/murfey/server/demo_api.py b/src/murfey/server/demo_api.py index d953b87ab..3bb6c9c04 100644 --- a/src/murfey/server/demo_api.py +++ b/src/murfey/server/demo_api.py @@ -89,7 +89,6 @@ Visit, ) from murfey.util.processing_params import default_spa_parameters -from murfey.util.state import global_state from murfey.workflows.spa.picking import _register_picked_particles_use_diameter log = logging.getLogger("murfey.server.demo_api") @@ -1447,15 +1446,6 @@ def register_dc_group( db.add(murfey_app_3d) db.commit() - if global_state.get("data_collection_group_ids") and isinstance( - global_state["data_collection_group_ids"], dict - ): - global_state["data_collection_group_ids"] = { - **global_state["data_collection_group_ids"], - dcg_params.tag: dcgid, - } - else: - global_state["data_collection_group_ids"] = {dcg_params.tag: dcgid} if dcg_params.atlas: _flush_grid_square_records( {"session_id": session_id, "tag": dcg_params.tag}, demo=True @@ -1511,15 +1501,6 @@ def start_dc( db.add(murfey_app) db.commit() db.close() - if global_state.get("data_collection_ids") and isinstance( - global_state["data_collection_ids"], dict - ): - global_state["data_collection_ids"] = { - **global_state["data_collection_ids"], - dc_params.tag: 1, - } - else: - global_state["data_collection_ids"] = {dc_params.tag: 1} if dc_params.exposure_time: prom.exposure_time.set(dc_params.exposure_time) return dc_params @@ -1529,35 +1510,8 @@ def start_dc( def register_proc( visit_name, session_id: MurfeySessionID, proc_params: ProcessingJobParameters ): + # This should probably do something log.info("Registering processing job") - if global_state.get("processing_job_ids"): - assert isinstance(global_state["processing_job_ids"], dict) - global_state["processing_job_ids"] = { - **{ - k: v - for k, v in global_state["processing_job_ids"].items() - if k != proc_params.tag - }, - proc_params.tag: { - **global_state["processing_job_ids"].get(proc_params.tag, {}), - proc_params.recipe: 1, - }, - } - else: - global_state["processing_job_ids"] = {proc_params.tag: {proc_params.recipe: 1}} - if global_state.get("autoproc_program_ids"): - assert isinstance(global_state["autoproc_program_ids"], dict) - global_state["autoproc_program_ids"] = { - **global_state["autoproc_program_ids"], - proc_params.tag: { - **global_state["autoproc_program_ids"].get(proc_params.tag, {}), - proc_params.recipe: 1, - }, - } - else: - global_state["autoproc_program_ids"] = { - proc_params.tag: {proc_params.recipe: 1} - } log.info("Processing job registered") return proc_params diff --git a/src/murfey/util/models.py b/src/murfey/util/models.py index d72eff927..c825c8b36 100644 --- a/src/murfey/util/models.py +++ b/src/murfey/util/models.py @@ -127,13 +127,6 @@ class RsyncerInfo(BaseModel): tag: str = "" -class ClearanceKeys(BaseModel): - data_collection_group: List[str] - data_collection: List[str] - processing_job: List[str] - autoproc_program: List[str] - - class GainReference(BaseModel): gain_ref: Path rescale: bool = True From 67f142dd1bfca2f4b026ac32c66f54bda8968366 Mon Sep 17 00:00:00 2001 From: yxd92326 Date: Tue, 8 Apr 2025 17:34:35 +0100 Subject: [PATCH 2/4] Completely delete the global state --- src/murfey/server/api/__init__.py | 20 ------ src/murfey/server/demo_api.py | 22 ------- src/murfey/server/websocket.py | 46 +++---------- src/murfey/util/models.py | 13 ---- src/murfey/util/state.py | 106 ------------------------------ 5 files changed, 9 insertions(+), 198 deletions(-) delete mode 100644 src/murfey/util/state.py diff --git a/src/murfey/server/api/__init__.py b/src/murfey/server/api/__init__.py index 9ca80c63f..270892e73 100644 --- a/src/murfey/server/api/__init__.py +++ b/src/murfey/server/api/__init__.py @@ -77,11 +77,9 @@ BLSampleParameters, BLSubSampleParameters, ClientInfo, - ContextInfo, CurrentGainRef, DCGroupParameters, DCParameters, - File, FoilHoleParameters, FractionationParameters, GainReference, @@ -1021,23 +1019,6 @@ def visit_info( return None -@router.post("/visits/{visit_name}/context") -async def register_context(context_info: ContextInfo): - await ws.manager.broadcast(f"Context registered: {context_info}") - await ws.manager.set_state("experiment_type", context_info.experiment_type) - await ws.manager.set_state( - "acquisition_software", context_info.acquisition_software - ) - - -@router.post("/visits/{visit_name}/files") -async def add_file(file: File): - message = f"File {file} transferred" - log.info(message) - await ws.manager.broadcast(f"File {file} transferred") - return file - - @router.post("/instruments/{instrument_name}/feedback") async def send_murfey_message(instrument_name: str, msg: RegistrationMessage): if _transport_object: @@ -1320,7 +1301,6 @@ async def request_tomography_preprocessing( db.add(for_stash) db.commit() db.close() - # await ws.manager.broadcast(f"Pre-processing requested for {ppath.name}") return proc_file diff --git a/src/murfey/server/demo_api.py b/src/murfey/server/demo_api.py index 3bb6c9c04..694c9f6d9 100644 --- a/src/murfey/server/demo_api.py +++ b/src/murfey/server/demo_api.py @@ -62,11 +62,9 @@ ) from murfey.util.models import ( ClientInfo, - ContextInfo, CurrentGainRef, DCGroupParameters, DCParameters, - File, FoilHoleParameters, FractionationParameters, GainReference, @@ -899,26 +897,6 @@ def visit_info(request: Request, visit_name: str): ) -@router.post("/visits/{visit_name}/context") -async def register_context(context_info: ContextInfo): - log.info( - f"Context {context_info.experiment_type}:{context_info.acquisition_software} registered" - ) - await ws.manager.broadcast(f"Context registered: {context_info}") - await ws.manager.set_state("experiment_type", context_info.experiment_type) - await ws.manager.set_state( - "acquisition_software", context_info.acquisition_software - ) - - -@router.post("/visits/{visit_name}/files") -async def add_file(file: File): - message = f"File {file} transferred" - log.info(message) - await ws.manager.broadcast(f"File {file} transferred") - return file - - @router.post("/feedback") async def send_murfey_message(msg: RegistrationMessage): pass diff --git a/src/murfey/server/websocket.py b/src/murfey/server/websocket.py index f6f5302a4..ba04ef309 100644 --- a/src/murfey/server/websocket.py +++ b/src/murfey/server/websocket.py @@ -4,7 +4,7 @@ import json import logging from datetime import datetime -from typing import Any, Dict, Generic, TypeVar, Union +from typing import Any, Dict, TypeVar, Union from fastapi import APIRouter, WebSocket, WebSocketDisconnect from sqlmodel import select @@ -13,7 +13,6 @@ from murfey.server.murfey_db import get_murfey_db_session from murfey.util import sanitise from murfey.util.db import ClientEnvironment -from murfey.util.state import State, global_state T = TypeVar("T") @@ -21,11 +20,9 @@ log = logging.getLogger("murfey.server.websocket") -class ConnectionManager(Generic[T]): - def __init__(self, state: State[T]): +class ConnectionManager: + def __init__(self): self.active_connections: Dict[int | str, WebSocket] = {} - self._state = state - self._state.subscribe(self._broadcast_state_update) async def connect( self, websocket: WebSocket, client_id: int | str, register_client: bool = True @@ -38,7 +35,6 @@ async def connect( "To register a client the client ID must be an integer" ) self._register_new_client(client_id) - await websocket.send_json({"message": "state-full", "state": self._state.data}) @staticmethod def _register_new_client(client_id: int): @@ -48,9 +44,7 @@ def _register_new_client(client_id: int): murfey_db.commit() murfey_db.close() - def disconnect( - self, websocket: WebSocket, client_id: int | str, unregister_client: bool = True - ): + def disconnect(self, client_id: int | str, unregister_client: bool = True): self.active_connections.pop(client_id) if unregister_client: murfey_db = next(get_murfey_db_session()) @@ -67,33 +61,14 @@ async def broadcast(self, message: str): for connection in self.active_connections: await self.active_connections[connection].send_text(message) - async def _broadcast_state_update( - self, attribute: str, value: T | None, message: str = "state-update" - ): - for connection in self.active_connections: - await self.active_connections[connection].send_json( - {"message": message, "attribute": attribute, "value": value} - ) - - async def set_state(self, attribute: str, value: T): - log.info( - f"State attribute {sanitise(attribute)!r} set to {sanitise(str(value))!r}" - ) - await self._state.set(attribute, value) - - async def delete_state(self, attribute: str): - log.info(f"State attribute {sanitise(attribute)!r} removed") - await self._state.delete(attribute) - -manager = ConnectionManager(global_state) +manager = ConnectionManager() @ws.websocket("/test/{client_id}") async def websocket_endpoint(websocket: WebSocket, client_id: int): await manager.connect(websocket, client_id) await manager.broadcast(f"Client {client_id} joined") - await manager.set_state(f"Client {client_id}", "joined") try: while True: data = await websocket.receive_text() @@ -111,9 +86,8 @@ async def websocket_endpoint(websocket: WebSocket, client_id: int): select(ClientEnvironment).where(ClientEnvironment.client_id == client_id) ).one() prom.monitoring_switch.labels(visit=client_env.visit).set(0) - manager.disconnect(websocket, client_id) + manager.disconnect(client_id) await manager.broadcast(f"Client #{client_id} disconnected") - await manager.delete_state(f"Client {client_id}") @ws.websocket("/connect/{client_id}") @@ -122,7 +96,6 @@ async def websocket_connection_endpoint( ): await manager.connect(websocket, client_id, register_client=False) await manager.broadcast(f"Client {client_id} joined") - await manager.set_state(f"Client {client_id}", "joined") try: while True: data = await websocket.receive_text() @@ -138,9 +111,8 @@ async def websocket_connection_endpoint( await manager.broadcast(f"Client #{client_id} sent message {data}") except WebSocketDisconnect: log.info(f"Disconnecting Client {sanitise(str(client_id))}") - manager.disconnect(websocket, client_id, unregister_client=False) + manager.disconnect(client_id, unregister_client=False) await manager.broadcast(f"Client #{client_id} disconnected") - await manager.delete_state(f"Client {client_id}") async def check_connections(active_connections): @@ -178,7 +150,7 @@ async def close_ws_connection(client_id: int): murfey_db.close() client_id_str = str(client_id).replace("\r\n", "").replace("\n", "") log.info(f"Disconnecting {client_id_str}") - manager.disconnect(manager.active_connections[client_id], client_id) + manager.disconnect(client_id) prom.monitoring_switch.labels(visit=visit_name).set(0) await manager.broadcast(f"Client #{client_id} disconnected") @@ -187,5 +159,5 @@ async def close_ws_connection(client_id: int): async def close_unrecorded_ws_connection(client_id: Union[int, str]): client_id_str = str(client_id).replace("\r\n", "").replace("\n", "") log.info(f"Disconnecting {client_id_str}") - manager.disconnect(manager.active_connections[client_id], client_id) + manager.disconnect(client_id) await manager.broadcast(f"Client #{client_id} disconnected") diff --git a/src/murfey/util/models.py b/src/murfey/util/models.py index c825c8b36..c8d3faaea 100644 --- a/src/murfey/util/models.py +++ b/src/murfey/util/models.py @@ -83,14 +83,6 @@ class RegistrationMessage(BaseModel): params: Optional[Dict[str, Any]] = None -class File(BaseModel): - name: str - description: str - size: int - timestamp: datetime - full_path: str - - class ConnectionFileParameters(BaseModel): filename: str destinations: List[str] @@ -102,11 +94,6 @@ class SessionInfo(BaseModel): rescale: bool = True -class ContextInfo(BaseModel): - experiment_type: str - acquisition_software: str - - class ClientInfo(BaseModel): id: int diff --git a/src/murfey/util/state.py b/src/murfey/util/state.py deleted file mode 100644 index 10f02299c..000000000 --- a/src/murfey/util/state.py +++ /dev/null @@ -1,106 +0,0 @@ -from __future__ import annotations - -import asyncio -from typing import Awaitable, Callable, Mapping, TypeVar, Union - -T = TypeVar("T") -GlobalStateValues = Union[str, int, list, dict, None] - -from murfey.util import Observer - - -class State(Mapping[str, T], Observer): - """A helper class to coordinate shared state across server instances. - This is a Mapping with added (synchronous) set and delete functionality, - as well as asynchronous .update/.delete calls. It implements the Observer - pattern notifying synchronous and asynchronous callback functions. - """ - - def __init__(self): - self.data: dict[str, T] = {} - self._listeners: list[Callable[[str, T | None], Awaitable[None] | None]] = [] - super().__init__() - - def __repr__(self): - return f"{type(self).__name__}({self.data}; {len(self._listeners)} subscribers)" - - def __len__(self) -> int: - return len(self.data) - - def __iter__(self): - return iter(self.data) - - def __contains__(self, key) -> bool: - return key in self.data - - def __getitem__(self, key: str) -> T: - if key in self.data: - return self.data[key] - raise KeyError(key) - - async def delete(self, key: str): - del self.data[key] - await self.anotify(key, None) - - async def aupdate(self, key: str, value: T): - if self.data.get(key): - if isinstance(self.data[key], dict): - self.data[key].update(value) # type: ignore - await self.anotify(key, value, message="state-update-partial") - else: - self.data[key] = value - await self.anotify(key, value, message="state-update") - - async def set(self, key: str, value: T): - self.data[key] = value - await self.anotify(key, value) - - def update(self, key: str, value: T, perform_state_update: bool = False): - if self.data.get(key): - if isinstance(self.data[key], dict): - if perform_state_update: - self.data[key].update(value) # type: ignore - self.notify(key, value, message="state-update-partial") - else: - self.data[key] = value - self.notify(key, value, message="state-update") - - def subscribe( - self, - fn: Callable[[str, T | None], Awaitable[None] | None], - secondary: bool = False, - final: bool = False, - ): - if secondary: - self._secondary_listeners.append(fn) - elif final: - self._final_listeners.append(fn) - else: - self._listeners.append(fn) - - def __setitem__(self, key: str, item: T): - try: - asyncio.get_running_loop() - except RuntimeError: - # This is synchronous code, we're not running in an event loop - self.data[key] = item - self.notify(key, item) - return - raise RuntimeError( - "__setitem__() called from async code. Use async .update() instead" - ) - - def __delitem__(self, key: str): - try: - asyncio.get_running_loop() - except RuntimeError: - # This is synchronous code, we're not running in an event loop - del self.data[key] - self.notify(key, None) - return - raise RuntimeError( - "__delitem__() called from async code. Use async .delete() instead" - ) - - -global_state: State[GlobalStateValues] = State() From 29d0119de75fb9408af2701e26e373f110ae9883 Mon Sep 17 00:00:00 2001 From: yxd92326 Date: Wed, 9 Apr 2025 09:04:48 +0100 Subject: [PATCH 3/4] This is still used --- src/murfey/util/models.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/src/murfey/util/models.py b/src/murfey/util/models.py index c8d3faaea..f1855f0dc 100644 --- a/src/murfey/util/models.py +++ b/src/murfey/util/models.py @@ -83,6 +83,14 @@ class RegistrationMessage(BaseModel): params: Optional[Dict[str, Any]] = None +class File(BaseModel): + name: str + description: str + size: int + timestamp: datetime + full_path: str + + class ConnectionFileParameters(BaseModel): filename: str destinations: List[str] From 4b1ca0e7ec423e7938334bce52ca99f195714357 Mon Sep 17 00:00:00 2001 From: yxd92326 Date: Wed, 9 Apr 2025 09:05:10 +0100 Subject: [PATCH 4/4] Don't test state as that's gone --- tests/util/test_state.py | 183 --------------------------------------- 1 file changed, 183 deletions(-) delete mode 100644 tests/util/test_state.py diff --git a/tests/util/test_state.py b/tests/util/test_state.py deleted file mode 100644 index 9a19c3158..000000000 --- a/tests/util/test_state.py +++ /dev/null @@ -1,183 +0,0 @@ -from __future__ import annotations - -import asyncio -import inspect -from unittest import mock - -import pytest - -from murfey.util.state import State - - -def test_default_state_behaves_like_empty_dictionary(): - s = State() - assert s == {} - assert dict(s) == {} - assert len(s) == 0 - assert not s - - -def test_state_object_behaves_like_a_dictionary(): - s = State() - - assert s.get("key") is None - assert "key" not in s - assert len(s) == 0 - assert not s - - s["key"] = "value" - - assert s["key"] == "value" - assert "key" in s - assert "notkey" not in s - assert len(s) == 1 - assert s - - s["key"] = "newvalue" - - assert s["key"] == "newvalue" - assert "key" in s - assert len(s) == 1 - assert s - - -def test_calling_async_methods_synchronously(): - s = State() - return_value = s.aupdate("key", "value") - assert inspect.isawaitable(return_value) - assert len(s) == 0 - asyncio.run(return_value) - assert len(s) == 1 - - return_value = s.delete("key") - assert inspect.isawaitable(return_value) - assert len(s) == 1 - asyncio.run(return_value) - assert len(s) == 0 - - -def test_calling_sync_methods_asynchronously(): - s = State() - - async def set_value(): - s["key"] = "value" - - async def delete_value(): - del s["key"] - - with pytest.raises(RuntimeError, match="async.*update.*instead"): - asyncio.run(set_value()) - assert not s - - s["key"] = "value" - assert s - with pytest.raises(RuntimeError, match="async.*delete.*instead"): - asyncio.run(delete_value()) - assert s - - -def test_state_object_supports_multiple_non_async_listeners(): - s = State() - listener = mock.Mock() - s.subscribe(listener) - s.subscribe(listener) - assert "2" in repr(s) - - s["attribute"] = mock.sentinel.value - - assert listener.call_count == 2 - listener.assert_has_calls([mock.call("attribute", mock.sentinel.value)] * 2) - - -def test_state_object_notifies_listeners_on_synchronous_change(): - # Test with both sync and async subscribers - s = State() - assert "0" in repr(s) - - sync_listener = mock.Mock() - s.subscribe(sync_listener) - sync_listener.assert_not_called() - async_listener = mock.AsyncMock() - s.subscribe(async_listener) - async_listener.assert_not_called() - assert "2" in repr(s) - assert "key" not in repr(s) - - s["key"] = mock.sentinel.value - sync_listener.assert_called_once_with("key", mock.sentinel.value) - async_listener.assert_called_once_with("key", mock.sentinel.value) - async_listener.assert_awaited() - assert "key" in repr(s) - - sync_listener.reset_mock() - async_listener.reset_mock() - sync_listener.assert_not_called() - async_listener.assert_not_called() - s["key"] = mock.sentinel.value2 - sync_listener.assert_called_once_with("key", mock.sentinel.value2) - async_listener.assert_called_once_with("key", mock.sentinel.value2) - async_listener.assert_awaited() - - sync_listener.reset_mock() - async_listener.reset_mock() - assert s["key"] == mock.sentinel.value2 - # Dictionary access should not notify - sync_listener.assert_not_called() - async_listener.assert_not_called() - - sync_listener.reset_mock() - async_listener.reset_mock() - del s["key"] - sync_listener.assert_called_once_with("key", None) - async_listener.assert_called_once_with("key", None) - async_listener.assert_awaited() - - -def test_state_object_notifies_listeners_on_asynchronous_change(): - # Test with both sync and async subscribers - s = State() - assert "0" in repr(s) - - sync_listener = mock.Mock() - s.subscribe(sync_listener) - sync_listener.assert_not_called() - async_listener = mock.AsyncMock() - s.subscribe(async_listener) - async_listener.assert_not_called() - assert "2" in repr(s) - assert "key" not in repr(s) - - async def set_value(): - await s.aupdate("key", mock.sentinel.value) - - async def set_value2(): - await s.aupdate("key", mock.sentinel.value2) - - async def delete_value(): - await s.delete("key") - - asyncio.run(set_value()) - sync_listener.assert_called_once_with( - "key", mock.sentinel.value, message="state-update" - ) - async_listener.assert_called_once_with( - "key", mock.sentinel.value, message="state-update" - ) - async_listener.assert_awaited() - assert "key" in repr(s) - - sync_listener.reset_mock() - async_listener.reset_mock() - sync_listener.assert_not_called() - async_listener.assert_not_called() - asyncio.run(set_value2()) - # sync_listener.assert_called_once_with("key", mock.sentinel.value2, message="state-update-partial") - # async_listener.assert_called_once_with("key", mock.sentinel.value2) - # async_listener.assert_awaited() - - sync_listener.reset_mock() - async_listener.reset_mock() - asyncio.run(delete_value()) - sync_listener.assert_called_once_with("key", None) - async_listener.assert_called_once_with("key", None) - async_listener.assert_awaited()