Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 0 additions & 12 deletions src/murfey/server/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,6 @@
get_security_config,
)
from murfey.util.processing_params import default_spa_parameters
from murfey.util.state import global_state
from murfey.util.tomo import midpoint

try:
Expand Down Expand Up @@ -2258,17 +2257,6 @@ def feedback_callback(header: dict, message: dict) -> None:
time.sleep(2)
_transport_object.transport.nack(header, requeue=True)
return None
if global_state.get("data_collection_group_ids") and isinstance(
global_state["data_collection_group_ids"], dict
):
global_state["data_collection_group_ids"] = {
**global_state["data_collection_group_ids"],
message.get("tag"): dcgid,
}
else:
global_state["data_collection_group_ids"] = {
message.get("tag"): dcgid
}
_transport_object.transport.ack(header)
if dcg_hooks := entry_points().select(
group="murfey.hooks", name="data_collection_group"
Expand Down
58 changes: 0 additions & 58 deletions src/murfey/server/api/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,13 +76,10 @@
BLSampleImageParameters,
BLSampleParameters,
BLSubSampleParameters,
ClearanceKeys,
ClientInfo,
ContextInfo,
CurrentGainRef,
DCGroupParameters,
DCParameters,
File,
FoilHoleParameters,
FractionationParameters,
GainReference,
Expand All @@ -107,7 +104,6 @@
Visit,
)
from murfey.util.processing_params import default_spa_parameters
from murfey.util.state import global_state
from murfey.util.tomo import midpoint
from murfey.workflows.spa.flush_spa_preprocess import (
register_foil_hole,
Expand Down Expand Up @@ -1023,23 +1019,6 @@ def visit_info(
return None


@router.post("/visits/{visit_name}/context")
async def register_context(context_info: ContextInfo):
await ws.manager.broadcast(f"Context registered: {context_info}")
await ws.manager.set_state("experiment_type", context_info.experiment_type)
await ws.manager.set_state(
"acquisition_software", context_info.acquisition_software
)


@router.post("/visits/{visit_name}/files")
async def add_file(file: File):
message = f"File {file} transferred"
log.info(message)
await ws.manager.broadcast(f"File {file} transferred")
return file


@router.post("/instruments/{instrument_name}/feedback")
async def send_murfey_message(instrument_name: str, msg: RegistrationMessage):
if _transport_object:
Expand Down Expand Up @@ -1322,7 +1301,6 @@ async def request_tomography_preprocessing(
db.add(for_stash)
db.commit()
db.close()
# await ws.manager.broadcast(f"Pre-processing requested for {ppath.name}")
return proc_file


Expand Down Expand Up @@ -1795,42 +1773,6 @@ async def make_gif(
return {"output_gif": str(output_path)}


@router.post("/visits/{visit_name}/clean_state")
async def clean_state(visit_name: str, for_clearance: ClearanceKeys):
if global_state.get("data_collection_group_ids") and isinstance(
global_state["data_collection_group_ids"], dict
):
global_state["data_collection_group_ids"] = {
k: v
for k, v in global_state["data_collection_group_ids"].items()
if k not in for_clearance.data_collection_group
}
if global_state.get("data_collection_ids") and isinstance(
global_state["data_collection_ids"], dict
):
global_state["data_collection_ids"] = {
k: v
for k, v in global_state["data_collection_ids"].items()
if k not in for_clearance.data_collection
}
if global_state.get("processing_job_ids") and isinstance(
global_state["processing_job_ids"], dict
):
global_state["processing_job_ids"] = {
k: v
for k, v in global_state["processing_job_ids"].items()
if k not in for_clearance.processing_job
}
if global_state.get("autoproc_program_ids") and isinstance(
global_state["autoproc_program_ids"], dict
):
global_state["autoproc_program_ids"] = {
k: v
for k, v in global_state["autoproc_program_ids"].items()
if k not in for_clearance.autoproc_program
}


@router.get("/new_client_id/")
async def new_client_id(db=murfey_db):
clients = db.exec(select(ClientEnvironment)).all()
Expand Down
70 changes: 1 addition & 69 deletions src/murfey/server/demo_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,11 +62,9 @@
)
from murfey.util.models import (
ClientInfo,
ContextInfo,
CurrentGainRef,
DCGroupParameters,
DCParameters,
File,
FoilHoleParameters,
FractionationParameters,
GainReference,
Expand All @@ -89,7 +87,6 @@
Visit,
)
from murfey.util.processing_params import default_spa_parameters
from murfey.util.state import global_state
from murfey.workflows.spa.picking import _register_picked_particles_use_diameter

log = logging.getLogger("murfey.server.demo_api")
Expand Down Expand Up @@ -900,26 +897,6 @@ def visit_info(request: Request, visit_name: str):
)


@router.post("/visits/{visit_name}/context")
async def register_context(context_info: ContextInfo):
log.info(
f"Context {context_info.experiment_type}:{context_info.acquisition_software} registered"
)
await ws.manager.broadcast(f"Context registered: {context_info}")
await ws.manager.set_state("experiment_type", context_info.experiment_type)
await ws.manager.set_state(
"acquisition_software", context_info.acquisition_software
)


@router.post("/visits/{visit_name}/files")
async def add_file(file: File):
message = f"File {file} transferred"
log.info(message)
await ws.manager.broadcast(f"File {file} transferred")
return file


@router.post("/feedback")
async def send_murfey_message(msg: RegistrationMessage):
pass
Expand Down Expand Up @@ -1447,15 +1424,6 @@ def register_dc_group(
db.add(murfey_app_3d)
db.commit()

if global_state.get("data_collection_group_ids") and isinstance(
global_state["data_collection_group_ids"], dict
):
global_state["data_collection_group_ids"] = {
**global_state["data_collection_group_ids"],
dcg_params.tag: dcgid,
}
else:
global_state["data_collection_group_ids"] = {dcg_params.tag: dcgid}
if dcg_params.atlas:
_flush_grid_square_records(
{"session_id": session_id, "tag": dcg_params.tag}, demo=True
Expand Down Expand Up @@ -1511,15 +1479,6 @@ def start_dc(
db.add(murfey_app)
db.commit()
db.close()
if global_state.get("data_collection_ids") and isinstance(
global_state["data_collection_ids"], dict
):
global_state["data_collection_ids"] = {
**global_state["data_collection_ids"],
dc_params.tag: 1,
}
else:
global_state["data_collection_ids"] = {dc_params.tag: 1}
if dc_params.exposure_time:
prom.exposure_time.set(dc_params.exposure_time)
return dc_params
Expand All @@ -1529,35 +1488,8 @@ def start_dc(
def register_proc(
visit_name, session_id: MurfeySessionID, proc_params: ProcessingJobParameters
):
# This should probably do something
log.info("Registering processing job")
if global_state.get("processing_job_ids"):
assert isinstance(global_state["processing_job_ids"], dict)
global_state["processing_job_ids"] = {
**{
k: v
for k, v in global_state["processing_job_ids"].items()
if k != proc_params.tag
},
proc_params.tag: {
**global_state["processing_job_ids"].get(proc_params.tag, {}),
proc_params.recipe: 1,
},
}
else:
global_state["processing_job_ids"] = {proc_params.tag: {proc_params.recipe: 1}}
if global_state.get("autoproc_program_ids"):
assert isinstance(global_state["autoproc_program_ids"], dict)
global_state["autoproc_program_ids"] = {
**global_state["autoproc_program_ids"],
proc_params.tag: {
**global_state["autoproc_program_ids"].get(proc_params.tag, {}),
proc_params.recipe: 1,
},
}
else:
global_state["autoproc_program_ids"] = {
proc_params.tag: {proc_params.recipe: 1}
}
log.info("Processing job registered")
return proc_params

Expand Down
46 changes: 9 additions & 37 deletions src/murfey/server/websocket.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import json
import logging
from datetime import datetime
from typing import Any, Dict, Generic, TypeVar, Union
from typing import Any, Dict, TypeVar, Union

from fastapi import APIRouter, WebSocket, WebSocketDisconnect
from sqlmodel import select
Expand All @@ -13,19 +13,16 @@
from murfey.server.murfey_db import get_murfey_db_session
from murfey.util import sanitise
from murfey.util.db import ClientEnvironment
from murfey.util.state import State, global_state

T = TypeVar("T")

ws = APIRouter(prefix="/ws", tags=["websocket"])
log = logging.getLogger("murfey.server.websocket")


class ConnectionManager(Generic[T]):
def __init__(self, state: State[T]):
class ConnectionManager:
def __init__(self):
self.active_connections: Dict[int | str, WebSocket] = {}
self._state = state
self._state.subscribe(self._broadcast_state_update)

async def connect(
self, websocket: WebSocket, client_id: int | str, register_client: bool = True
Expand All @@ -38,7 +35,6 @@
"To register a client the client ID must be an integer"
)
self._register_new_client(client_id)
await websocket.send_json({"message": "state-full", "state": self._state.data})

@staticmethod
def _register_new_client(client_id: int):
Expand All @@ -48,9 +44,7 @@
murfey_db.commit()
murfey_db.close()

def disconnect(
self, websocket: WebSocket, client_id: int | str, unregister_client: bool = True
):
def disconnect(self, client_id: int | str, unregister_client: bool = True):
self.active_connections.pop(client_id)
if unregister_client:
murfey_db = next(get_murfey_db_session())
Expand All @@ -67,33 +61,14 @@
for connection in self.active_connections:
await self.active_connections[connection].send_text(message)

async def _broadcast_state_update(
self, attribute: str, value: T | None, message: str = "state-update"
):
for connection in self.active_connections:
await self.active_connections[connection].send_json(
{"message": message, "attribute": attribute, "value": value}
)

async def set_state(self, attribute: str, value: T):
log.info(
f"State attribute {sanitise(attribute)!r} set to {sanitise(str(value))!r}"
)
await self._state.set(attribute, value)

async def delete_state(self, attribute: str):
log.info(f"State attribute {sanitise(attribute)!r} removed")
await self._state.delete(attribute)


manager = ConnectionManager(global_state)
manager = ConnectionManager()


@ws.websocket("/test/{client_id}")
async def websocket_endpoint(websocket: WebSocket, client_id: int):
await manager.connect(websocket, client_id)
await manager.broadcast(f"Client {client_id} joined")
await manager.set_state(f"Client {client_id}", "joined")
try:
while True:
data = await websocket.receive_text()
Expand All @@ -111,9 +86,8 @@
select(ClientEnvironment).where(ClientEnvironment.client_id == client_id)
).one()
prom.monitoring_switch.labels(visit=client_env.visit).set(0)
manager.disconnect(websocket, client_id)
manager.disconnect(client_id)

Check warning on line 89 in src/murfey/server/websocket.py

View check run for this annotation

Codecov / codecov/patch

src/murfey/server/websocket.py#L89

Added line #L89 was not covered by tests
await manager.broadcast(f"Client #{client_id} disconnected")
await manager.delete_state(f"Client {client_id}")


@ws.websocket("/connect/{client_id}")
Expand All @@ -122,7 +96,6 @@
):
await manager.connect(websocket, client_id, register_client=False)
await manager.broadcast(f"Client {client_id} joined")
await manager.set_state(f"Client {client_id}", "joined")
try:
while True:
data = await websocket.receive_text()
Expand All @@ -138,9 +111,8 @@
await manager.broadcast(f"Client #{client_id} sent message {data}")
except WebSocketDisconnect:
log.info(f"Disconnecting Client {sanitise(str(client_id))}")
manager.disconnect(websocket, client_id, unregister_client=False)
manager.disconnect(client_id, unregister_client=False)

Check warning on line 114 in src/murfey/server/websocket.py

View check run for this annotation

Codecov / codecov/patch

src/murfey/server/websocket.py#L114

Added line #L114 was not covered by tests
await manager.broadcast(f"Client #{client_id} disconnected")
await manager.delete_state(f"Client {client_id}")


async def check_connections(active_connections):
Expand Down Expand Up @@ -178,7 +150,7 @@
murfey_db.close()
client_id_str = str(client_id).replace("\r\n", "").replace("\n", "")
log.info(f"Disconnecting {client_id_str}")
manager.disconnect(manager.active_connections[client_id], client_id)
manager.disconnect(client_id)

Check warning on line 153 in src/murfey/server/websocket.py

View check run for this annotation

Codecov / codecov/patch

src/murfey/server/websocket.py#L153

Added line #L153 was not covered by tests
prom.monitoring_switch.labels(visit=visit_name).set(0)
await manager.broadcast(f"Client #{client_id} disconnected")

Expand All @@ -187,5 +159,5 @@
async def close_unrecorded_ws_connection(client_id: Union[int, str]):
client_id_str = str(client_id).replace("\r\n", "").replace("\n", "")
log.info(f"Disconnecting {client_id_str}")
manager.disconnect(manager.active_connections[client_id], client_id)
manager.disconnect(client_id)

Check warning on line 162 in src/murfey/server/websocket.py

View check run for this annotation

Codecov / codecov/patch

src/murfey/server/websocket.py#L162

Added line #L162 was not covered by tests
await manager.broadcast(f"Client #{client_id} disconnected")
12 changes: 0 additions & 12 deletions src/murfey/util/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,11 +102,6 @@ class SessionInfo(BaseModel):
rescale: bool = True


class ContextInfo(BaseModel):
experiment_type: str
acquisition_software: str


class ClientInfo(BaseModel):
id: int

Expand All @@ -127,13 +122,6 @@ class RsyncerInfo(BaseModel):
tag: str = ""


class ClearanceKeys(BaseModel):
data_collection_group: List[str]
data_collection: List[str]
processing_job: List[str]
autoproc_program: List[str]


class GainReference(BaseModel):
gain_ref: Path
rescale: bool = True
Expand Down
Loading