Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 12 additions & 13 deletions flamesdk/flame_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def __init__(self, aggregator_requires_data: bool = False, silent: bool = False)
self.flame_log(f"Nginx connection failure (error_msg='{repr(e)}')", log_type='error')

# Set up the connection to all the services needed
## Connect to message broker
## Connect to MessageBroker
self.flame_log("\tConnecting to MessageBroker...", end='', suppress_tail=True)
try:
self._message_broker_api = MessageBrokerAPI(self.config, self._flame_logger)
Expand All @@ -43,13 +43,13 @@ def __init__(self, aggregator_requires_data: bool = False, silent: bool = False)
self._message_broker_api = None
self.flame_log(f"failed (error_msg='{repr(e)}')", log_type='error', suppress_head=True)
try:
### Update config with self_config from Messagebroker
### Update config with self_config from MessageBroker
self.config = self._message_broker_api.config
except Exception as e:
self.flame_log(f"Unable to retrieve node config from message broker (error_msg='{repr(e)}')",
log_type='error')

## Connect to po service
## Connect to POService
self.flame_log("\tConnecting to PO service...", end='', suppress_tail=True)
try:
self._po_api = POAPI(self.config, self._flame_logger)
Expand All @@ -59,7 +59,7 @@ def __init__(self, aggregator_requires_data: bool = False, silent: bool = False)
self._po_api = None
self.flame_log(f"failed (error_msg='{repr(e)}')", log_type='error', suppress_head=True)

## Connect to result service
## Connect to ResultService
self.flame_log("\tConnecting to ResultService...", end='', suppress_tail=True)
try:
self._storage_api = StorageAPI(self.config, self._flame_logger)
Expand All @@ -69,7 +69,7 @@ def __init__(self, aggregator_requires_data: bool = False, silent: bool = False)
self.flame_log(f"failed (error_msg='{repr(e)}')", log_type='error', suppress_head=True)

if (self.config.node_role == 'default') or aggregator_requires_data:
## Connection to data service
## Connection to DataService
self.flame_log("\tConnecting to DataApi...", end='', suppress_tail=True)
try:
self._data_api = DataAPI(self.config, self._flame_logger)
Expand All @@ -80,7 +80,7 @@ def __init__(self, aggregator_requires_data: bool = False, silent: bool = False)
else:
self._data_api = True

# Start the flame api thread used for incoming messages and health checks
# Start the FlameAPI thread used for incoming messages and health checks
self.flame_log("\tStarting FlameApi thread...", end='', suppress_tail=True)
try:
self._flame_api_thread = Thread(target=self._start_flame_api)
Expand Down Expand Up @@ -120,7 +120,7 @@ def get_participant_ids(self) -> list[str]:
Returns a list of all participant ids in the analysis
:return: the list of participants
"""
return [participant['nodeId'] for participant in self._message_broker_api.participants]
return [p['nodeId'] for p in self.get_participants()]

def get_node_status(self,
timeout: Optional[int] = None) -> dict[str, Literal["online", "offline", "not_connected"]]:
Expand Down Expand Up @@ -186,7 +186,6 @@ def ready_check(self,
specified interval until all nodes respond or the timeout is reached.

Parameters:
flame (FlameCoreSDK): The SDK instance used to communicate with the nodes.
nodes (list[str]): A list of node identifiers to check for readiness.
attempt_interval (int, optional): The interval (in seconds) between successive attempts.
Defaults to 30 seconds.
Expand Down Expand Up @@ -586,7 +585,7 @@ def get_data_client(self, data_id: str) -> Optional[AsyncClient]:
:param data_id: the id of the data source
:return: the data client
"""
if type(self._data_api) == DataAPI:
if isinstance(self._data_api, DataAPI):
return self._data_api.get_data_client(data_id)
else:
self.flame_log("Data API is not available, cannot retrieve data client",
Expand All @@ -598,7 +597,7 @@ def get_data_sources(self) -> Optional[list[str]]:
Returns a list of all data sources available for this project.
:return: the list of data sources
"""
if type(self._data_api) == DataAPI:
if isinstance(self._data_api, DataAPI):
return self._data_api.get_data_sources()
else:
self.flame_log("Data API is not available, cannot retrieve data sources",
Expand All @@ -611,7 +610,7 @@ def get_fhir_data(self, fhir_queries: Optional[list[str]] = None) -> Optional[li
:param fhir_queries: list of queries to get the data
:return:
"""
if type(self._data_api) == DataAPI:
if isinstance(self._data_api, DataAPI):
return self._data_api.get_fhir_data(fhir_queries)
else:
self.flame_log("Data API is not available, cannot retrieve FHIR data",
Expand All @@ -624,7 +623,7 @@ def get_s3_data(self, s3_keys: Optional[list[str]] = None) -> Optional[list[Unio
:param s3_keys:f
:return:
"""
if type(self._data_api) == DataAPI:
if isinstance(self._data_api, DataAPI):
return self._data_api.get_s3_data(s3_keys)
else:
self.flame_log("Data API is not available, cannot retrieve S3 data",
Expand All @@ -639,7 +638,7 @@ def _start_flame_api(self) -> None:
:return:
"""
self.flame_api = FlameAPI(self._message_broker_api.message_broker_client,
self._data_api.data_client if hasattr(self._data_api, 'data_client') else 'ignore',
self._data_api.data_client if isinstance(self._data_api, DataAPI) else self._data_api,
self._storage_api.result_client,
self._po_api.po_client,
self._flame_logger,
Expand Down
2 changes: 0 additions & 2 deletions flamesdk/resources/client_apis/clients/data_api_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,5 +95,3 @@ def get_data_source_client(self, data_id: str) -> AsyncClient:
self.flame_logger.raise_error(f"Data source with id {data_id} not found")
client = AsyncClient(base_url=f"{path}")
return client


11 changes: 6 additions & 5 deletions flamesdk/resources/client_apis/clients/message_broker_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,8 +137,7 @@ async def get_partner_nodes(self, self_node_id: str, analysis_id: str) -> list[d
return response

async def test_connection(self) -> bool:
response = await self._message_broker.get("/healthz",
headers=[('Connection', 'close')])
response = await self._message_broker.get("/healthz", headers=[('Connection', 'close')])
try:
response.raise_for_status()
return True
Expand Down Expand Up @@ -181,12 +180,14 @@ async def send_message(self, message: Message) -> None:

def receive_message(self, body: dict) -> None:
needs_acknowledgment = body["meta"]["akn_id"] is None
message = Message(message=body, config=self.nodeConfig, outgoing=False, flame_logger=self.flame_logger )
message = Message(message=body, config=self.nodeConfig, flame_logger=self.flame_logger, outgoing=False)
self.list_of_incoming_messages.append(message)

if needs_acknowledgment:
self.flame_logger.new_log("acknowledging ready check" if body["meta"]["category"] == "ready_check" else "incoming message",
log_type='info')
self.flame_logger.new_log(
"acknowledging ready check" if body["meta"]["category"] == "ready_check" else "incoming message",
log_type='info'
)
asyncio.run(self.acknowledge_message(message))

def delete_message_by_id(self, message_id: str, type: Literal["outgoing", "incoming"]) -> int:
Expand Down
10 changes: 3 additions & 7 deletions flamesdk/resources/client_apis/clients/po_client.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
from typing import Optional, Union
import asyncio
from httpx import Client, HTTPError

from flamesdk.resources.utils.logging import FlameLogger


class POClient:
def __init__(self, nginx_name: str, keycloak_token: str, flame_logger: FlameLogger) -> None:
self.nginx_name = nginx_name
Expand All @@ -13,7 +12,7 @@ def __init__(self, nginx_name: str, keycloak_token: str, flame_logger: FlameLogg
follow_redirects=True)
self.flame_logger = flame_logger

def refresh_token(self, keycloak_token: str):
def refresh_token(self, keycloak_token: str) -> None:
self.client = Client(base_url=f"http://{self.nginx_name}/po",
headers={"Authorization": f"Bearer {keycloak_token}",
"accept": "application/json"},
Expand All @@ -26,15 +25,12 @@ def stream_logs(self, log: str, log_type: str, analysis_id: str, status: str) ->
"analysis_id": analysis_id,
"status": status
}
print("Sending logs to PO:", log_dict)
response = self.client.post("/stream_logs",
json=log_dict,
headers={"Content-Type": "application/json"})
try:
response.raise_for_status()
print("Successfully streamed logs to PO")
except HTTPError as e:
#self.flame_logger.new_log(f"Failed to stream logs to PO: {repr(e)}", log_type='error')
print("HTTP Error in po api:", repr(e))
except Exception as e:
print("Unforeseen Error:", repr(e))
print("Unforeseen Error in po api:", repr(e))
7 changes: 4 additions & 3 deletions flamesdk/resources/client_apis/clients/result_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ class LocalDifferentialPrivacyParams(TypedDict, total=True):


class ResultClient:

def __init__(self, nginx_name, keycloak_token, flame_logger: FlameLogger) -> None:
self.nginx_name = nginx_name
self.client = Client(base_url=f"http://{nginx_name}/storage",
Expand Down Expand Up @@ -99,8 +98,10 @@ def push_result(self,
file_body = pickle.dumps(result)
except pickle.PicklingError as e:
self.flame_logger.raise_error(f"Failed to pickle result data: {repr(e)}")
file_body = None
else:
self.flame_logger.raise_error(f"Failed to pickle result data: {repr(e)}")
file_body = None

if remote_node_id:
data = {"remote_node_id": remote_node_id}
Expand Down Expand Up @@ -157,7 +158,7 @@ def get_intermediate_data(self,
type = "intermediate" if type == "global" else type

if tag:
urls = self._get_location_url_for_tag(tag)
urls = self._get_location_urls_for_tag(tag)
if tag_option == "last":
urls = urls[-1:]
elif tag_option == "first":
Expand All @@ -172,7 +173,7 @@ def get_intermediate_data(self,
else:
return self._get_file(f"/{type}/{id}?node_id={sender_node_id}")

def _get_location_url_for_tag(self, tag: str) -> str:
def _get_location_urls_for_tag(self, tag: str) -> list[str]:
"""
Retrieves the URL associated with the specified tag.
:param tag:
Expand Down
11 changes: 6 additions & 5 deletions flamesdk/resources/client_apis/data_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,14 @@
from flamesdk.resources.node_config import NodeConfig
from flamesdk.resources.utils.logging import FlameLogger


class DataAPI:
def __init__(self, config: NodeConfig, flame_logger: FlameLogger) -> None:
self.data_client = DataApiClient(config.project_id,
config.nginx_name,
config.data_source_token,
config.keycloak_token,
flame_logger= flame_logger)
self.data_client = DataApiClient(project_id=config.project_id,
nginx_name=config.nginx_name,
data_source_token=config.data_source_token,
keycloak_token=config.keycloak_token,
flame_logger=flame_logger)

def get_data_client(self, data_id: str) -> AsyncClient:
"""
Expand Down
16 changes: 7 additions & 9 deletions flamesdk/resources/client_apis/message_broker_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@

class MessageBrokerAPI:
def __init__(self, config: NodeConfig, flame_logger: FlameLogger) -> None:
self.flame_logger = flame_logger
self.message_broker_client = MessageBrokerClient(config, flame_logger)
self.config = self.message_broker_client.nodeConfig
self.participants = asyncio.run(self.message_broker_client.get_partner_nodes(self.config.node_id,
Expand Down Expand Up @@ -42,7 +41,7 @@ async def send_message(self,
message = Message(message=message,
config=self.config,
outgoing=True,
flame_logger=self.flame_logger,
flame_logger=self.message_broker_client.flame_logger,
message_number=self.message_broker_client.message_number,
category=message_category,
recipients=receivers)
Expand Down Expand Up @@ -180,13 +179,12 @@ def send_message_and_wait_for_responses(self,
"""
time_start = datetime.now()
# Send the message
asyncio.run(self.send_message(receivers,
message_category,
message,
max_attempts,
timeout,
attempt_timeout,
))
asyncio.run(self.send_message(receivers=receivers,
message_category=message_category,
message=message,
max_attempts=max_attempts,
timeout=timeout,
attempt_timeout=attempt_timeout))
timeout = timeout - (datetime.now() - time_start).seconds
if timeout < 0:
timeout = 1
Expand Down
3 changes: 1 addition & 2 deletions flamesdk/resources/client_apis/po_api.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
import asyncio

from flamesdk.resources.client_apis.clients.po_client import POClient
from flamesdk.resources.node_config import NodeConfig
from flamesdk.resources.utils.logging import FlameLogger


class POAPI:
def __init__(self, config: NodeConfig, flame_logger: FlameLogger) -> None:
self.po_client = POClient(config.nginx_name, config.keycloak_token, flame_logger)
Expand Down
1 change: 1 addition & 0 deletions flamesdk/resources/client_apis/storage_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from flamesdk.resources.node_config import NodeConfig
from flamesdk.resources.utils.logging import FlameLogger


class StorageAPI:
def __init__(self, config: NodeConfig, flame_logger: FlameLogger) -> None:
self.result_client = ResultClient(config.nginx_name, config.keycloak_token, flame_logger)
Expand Down
9 changes: 4 additions & 5 deletions flamesdk/resources/node_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,7 @@


class NodeConfig:

def __init__(self):
def __init__(self) -> None:
# init analysis status
self.finished = False

Expand All @@ -18,11 +17,11 @@ def __init__(self):
self.node_role = None
self.node_id = None

def set_role(self, role):
def set_role(self, role) -> None:
self.node_role = role

def set_node_id(self, node_id):
def set_node_id(self, node_id) -> None:
self.node_id = node_id

def finish_analysis(self):
def finish_analysis(self) -> None:
self.finished = True
8 changes: 4 additions & 4 deletions flamesdk/resources/rest_api.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import sys
import threading
import uvicorn
from typing import Any, Callable, Union
from typing import Any, Callable, Union, Optional

from fastapi import FastAPI, APIRouter, Request, Depends
from fastapi.responses import JSONResponse
Expand All @@ -18,7 +18,7 @@
class FlameAPI:
def __init__(self,
message_broker: MessageBrokerClient,
data_client: Union[DataApiClient, str],
data_client: Union[DataApiClient, Optional[bool]],
result_client: ResultClient,
po_client: POClient,
flame_logger: FlameLogger,
Expand Down Expand Up @@ -49,7 +49,7 @@ def __init__(self,
self.finishing_call = finishing_call

@router.post("/token_refresh", response_class=JSONResponse)
async def token_refresh(request: Request) -> JSONResponse:
async def token_refresh(request: Request) -> Optional[JSONResponse]:
try:
# get body
body = await request.json()
Expand All @@ -64,7 +64,7 @@ async def token_refresh(request: Request) -> JSONResponse:
po_client.refresh_token(new_token)
# refresh token in message-broker
message_broker.refresh_token(new_token)
if type(data_client) is DataApiClient:
if isinstance(data_client, DataApiClient):
# refresh token in data client
data_client.refresh_token(new_token)
# refresh token in result client
Expand Down
Loading