Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 13 additions & 3 deletions src/ansys/hps/data_transfer/client/binary.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,8 @@ class BinaryConfig:
Whether to ignore SSL certificate verification.
debug: bool, default: False
Whether to enable debug logging.
max_restarts: int, default: 5
Maximum number of times to restart the worker if it crashes.
"""

def __init__(
Expand All @@ -169,6 +171,7 @@ def __init__(
debug: bool = False,
auth_type: str = None,
env: dict | None = None,
max_restarts: int = 5,
):
"""Initialize the BinaryConfig class object."""
self.data_transfer_url = data_transfer_url
Expand All @@ -191,6 +194,7 @@ def __init__(
self._env = env or {}
self.insecure = insecure
self.auth_type = auth_type
self.max_restarts = max_restarts

self._on_token_update = None
self._on_process_died = None
Expand Down Expand Up @@ -388,6 +392,7 @@ def _log_output(self):
# log.debug("Worker log output stopped")

def _monitor(self):
restart_count = 0 # Initialize a counter for restarts
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would it make sense to reset this when a worker starts up successfully?

while not self._stop.is_set():
if self._process is None:
self._prepare()
Expand All @@ -414,18 +419,23 @@ def _monitor(self):
else:
ret_code = self._process.poll()
if ret_code is not None and ret_code != 0:
restart_count += 1 # Increment the restart counter
if restart_count > self.config.max_restarts:
log.error(f"Worker exceeded maximum restart attempts ({self.config.max_restarts}). Stopping...")
break # Exit the loop after exceeding the restart limit

log.warning(f"Worker exited with code {ret_code}, restarting ...")
self._process = None
self._prepared.clear()
if self.config._on_process_died is not None:
self.config._on_process_died(ret_code)
time.sleep(1.0)
continue
# elif self._config.debug:
# log.debug(f"Worker running ...")
# Reset restart_count if the worker is running successfully
restart_count = 0

time.sleep(self._config.monitor_interval)
# log.debug("Worker monitor stopped")
log.debug("Worker monitor stopped")

def _prepare(self):
if self._config._selected_port is None:
Expand Down
42 changes: 42 additions & 0 deletions src/ansys/hps/data_transfer/client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,6 +220,7 @@ def __init__(

self._session = None
self.binary = None
self.panic_file = None

self._features = None
self._api_key = None
Expand Down Expand Up @@ -538,6 +539,37 @@ def _adjust_config(self):
}
self._bin_config.env.update({k: v for k, v in env.items() if k not in os.environ})

def _fetch_panic_file(self, resp):
"""Extract and log the panic file location from the response."""
if resp.status_code == 200:
self.panic_file = resp.json().get("debug", {}).get("panic_file", None)
log.debug(f"Worker panic file: {self.panic_file}")

def _panic_file_contents(self):
"""Read and log the contents of the panic file if it exists."""
# if the file exists and the size of the file is > 0,
# read and log its content
if self.panic_file and os.path.exists(self.panic_file):
try:
if os.path.getsize(self.panic_file) > 0:
with open(self.panic_file) as f:
# Read the file line by line
lines = f.readlines()
message = []
for line in lines:
# Check for empty lines to split the message
if line.strip() == "":
log.error(f"Worker panic file content:\n{''.join(message)}")
message = [] # Reset the message buffer
else:
message.append(line)

# Log any remaining content after the last empty line
if message:
log.error(f"Worker panic file content:\n{''.join(message)}")
except Exception as panic_ex:
log.debug(f"Failed to read panic file: {panic_ex}")


class AsyncClient(ClientBase):
"""Provides an async interface to the Python client to the HPS data transfer APIs."""
Expand Down Expand Up @@ -568,6 +600,9 @@ def __setstate__(self, state):
async def start(self):
"""Start the async binary worker."""
super().start()
# grab location of panic file
resp = await self.session.get("/")
self._fetch_panic_file(resp)
self._monitor_task = asyncio.create_task(self._monitor())

async def stop(self, wait=5.0):
Expand Down Expand Up @@ -632,6 +667,8 @@ async def _monitor(self):
if self.binary_config.debug:
log.debug("URL: %s", self.base_api_url)
log.debug(traceback.format_exc())
# Before marking it as failed check if there is a panic file
self._panic_file_contents()
self._monitor_state.mark_failed(exc=ex, binary=self.binary)
continue

Expand Down Expand Up @@ -675,6 +712,9 @@ def start(self):
self._monitor_thread = threading.Thread(
target=self._monitor, args=(), daemon=True, name="worker_status_monitor"
)
# grab location of panic file
resp = self.session.get("/")
self._fetch_panic_file(resp)
self._monitor_thread.start()

def stop(self, wait=5.0):
Expand Down Expand Up @@ -733,6 +773,8 @@ def _monitor(self):
if self.binary_config.debug:
log.debug("URL: %s", self.base_api_url)
log.debug(traceback.format_exc())
# Before marking it as failed check if there is a panic file
self._panic_file_contents()
self._monitor_state.mark_failed(exc=ex, binary=self.binary)
continue

Expand Down
78 changes: 78 additions & 0 deletions tests/test_client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
# Copyright (C) 2024 - 2025 ANSYS, Inc. and/or its affiliates.
# SPDX-License-Identifier: MIT
#
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

"""This module contains tests for verifying the functionality of the Client class"""

import unittest
from unittest.mock import MagicMock, mock_open, patch

from ansys.hps.data_transfer.client.client import ClientBase


class TestClientBase(unittest.TestCase):
"""Test suite for the ClientBase class."""

def setUp(self):
"""Set up the ClientBase instance for testing."""
self.client = ClientBase()
self.client.panic_file = None

@patch("ansys.hps.data_transfer.client.client.log")
def test_fetch_panic_file(self, mock_log):
"""Test the _fetch_panic_file method."""
# Mock the response object
mock_resp = MagicMock()
mock_resp.status_code = 200
mock_resp.json.return_value = {"debug": {"panic_file": "/path/to/panic_file.log"}}

# Call the method
self.client._fetch_panic_file(mock_resp)

# Assertions
assert self.client.panic_file == "/path/to/panic_file.log"
mock_log.debug.assert_called_with("Worker panic file: /path/to/panic_file.log")

@patch("os.path.exists", return_value=True)
@patch("os.path.getsize", return_value=100)
@patch(
"builtins.open",
new_callable=mock_open,
read_data="Error: Something went wrong\n\nDetails: Invalid configuration\n\n",
)
@patch("ansys.hps.data_transfer.client.client.log")
def test_panic_file_contents(self, mock_log, mock_open_file, mock_getsize, mock_exists):
"""Test the _panic_file_contents method."""
# Set the panic file path
self.client.panic_file = "/path/to/panic_file.log"

# Call the method
self.client._panic_file_contents()

# Assertions
mock_exists.assert_called_once_with("/path/to/panic_file.log")
mock_getsize.assert_called_once_with("/path/to/panic_file.log")
mock_log.error.assert_any_call("Worker panic file content:\nError: Something went wrong\n")
mock_log.error.assert_any_call("Worker panic file content:\nDetails: Invalid configuration\n")


if __name__ == "__main__":
unittest.main()