diff --git a/src/mldebug/backend/factory.py b/src/mldebug/backend/factory.py index 44def34..ed639c7 100644 --- a/src/mldebug/backend/factory.py +++ b/src/mldebug/backend/factory.py @@ -9,11 +9,12 @@ """ import importlib -import sys from dataclasses import dataclass, field from typing import Any +from mldebug.utils import cleanup_and_exit + @dataclass class BackendConfig: @@ -55,10 +56,10 @@ def create_backend(backend_type, config): xrt_mod = importlib.import_module("mldebug.backend.xrt_impl") except ModuleNotFoundError: print("Unable to import Backend. Python 3.10 is required on Win/Linux and 3.12 on Embedded Linux.") - sys.exit(1) + cleanup_and_exit(config.args, 1) except ImportError: print("Unable to import XRT. Please check install.") - sys.exit(1) + cleanup_and_exit(config.args, 1) return xrt_mod.XRTImpl(config.tiles, config.ctx_id, config.pid, config.device) if backend_type == "test": diff --git a/src/mldebug/batch_runner.py b/src/mldebug/batch_runner.py index 5a777d0..07dde77 100644 --- a/src/mldebug/batch_runner.py +++ b/src/mldebug/batch_runner.py @@ -17,7 +17,7 @@ from concurrent.futures import ThreadPoolExecutor, as_completed -from mldebug.utils import LOGGER, timeit +from mldebug.utils import LOGGER, cleanup_and_exit, timeit # 16 byte pm, we assume 2 clock cycle delay COMBO_EVENT_MAX_DELAY_CYCLES = 32 @@ -311,7 +311,7 @@ def _process_err(self): else: self.status_handle.get("aie_status_error.txt") self._write_run_summary("FAIL") - sys.exit(1) + cleanup_and_exit(self.args, 1) def _process_end_breakpoint(self, layer, it, sid): """ @@ -511,6 +511,7 @@ def _write_run_summary(self, status): summary = {"status": status, "run_flags": flags_dict} try: + pathlib.Path(self.args.top_output_dir).mkdir(parents=True, exist_ok=True) with open(rsf, "w", encoding="utf-8") as fh: json.dump(summary, fh, indent=2, default=str) except (IOError, OSError) as e: diff --git a/src/mldebug/client_debug.py b/src/mldebug/client_debug.py index dc19ded..cbee01b 100644 --- 
a/src/mldebug/client_debug.py +++ b/src/mldebug/client_debug.py @@ -21,7 +21,7 @@ from mldebug.interactive_controller import InteractiveController from mldebug.layer_info import LayerInfo from mldebug.memory_dumper import MemoryDumper -from mldebug.utils import LOGGER +from mldebug.utils import LOGGER, register_debug_server class ClientDebug: @@ -51,7 +51,12 @@ def __init__(self, args, ctx_id, pid, output_dir): # Create this first so that connection will be aborted in case of crash if self.args.automated_debug or self.args.l3: - debug_server = DebugServer(self.args.subgraph_name, self.output_dir, self.args.backend == "test") + debug_server = DebugServer( + self.output_dir, self.args.backend == "test", subgraph_name=self.args.subgraph_name, + ) + # Track the live server so cleanup_and_exit() at unplanned exit points + # can send TERMINATE_CONNECTION to flexmlrt. + register_debug_server(debug_server) try: self.design_info = LayerInfo(args) diff --git a/src/mldebug/debug_server.py b/src/mldebug/debug_server.py index f67af5a..856f238 100644 --- a/src/mldebug/debug_server.py +++ b/src/mldebug/debug_server.py @@ -19,7 +19,10 @@ class DebugServer: and communication with flexmlrt for buffer dump and termination requests. """ - def __init__(self, subgraph_name, output_dir, is_testmode, bind_addr=("127.0.0.1", 9000)) -> None: + def __init__( + self, output_dir, is_testmode, subgraph_name="subgraph", + bind_addr=("127.0.0.1", 9000), connect_timeout=None, + ) -> None: """ Initialize the DebugServer instance. @@ -28,11 +31,14 @@ def __init__(self, subgraph_name, output_dir, is_testmode, bind_addr=("127.0.0.1 output_dir (str): Directory where buffer dumps will be stored. is_testmode (bool): Enables test mode, which disables socket operations for CI/testing. bind_addr (tuple): Address and port to bind the debug server socket. + connect_timeout (float, optional): If set, accept() gives up after this + many seconds; used by cleanup paths to avoid hanging forever. 
""" self.bind_addr = bind_addr self.subgraph_name = subgraph_name self.output_dir = output_dir self.is_testmode = is_testmode + self.connect_timeout = connect_timeout self.server_socket = None self.client_socket = None self.start() @@ -64,9 +70,20 @@ def start(self): self.server_socket.listen(1) LOGGER.verbose_print(f"Listening on {self.bind_addr}...") + if self.connect_timeout is not None: + self.server_socket.settimeout(self.connect_timeout) self.client_socket, client_address = self.server_socket.accept() + # Reset to blocking mode for subsequent send/recv. + if self.connect_timeout is not None: + self.server_socket.settimeout(None) + self.client_socket.settimeout(None) LOGGER.log(f"[INFO] Connected to FlexmlRT on {client_address}") return True + except socket.timeout: + LOGGER.verbose_print( + f"Timed out after {self.connect_timeout}s waiting for flexmlrt to connect." + ) + return False except socket.error as e: LOGGER.verbose_print(f"Socket error during setup or connection: {e}") return False diff --git a/src/mldebug/input_parser.py b/src/mldebug/input_parser.py index e397f74..93e64d3 100644 --- a/src/mldebug/input_parser.py +++ b/src/mldebug/input_parser.py @@ -13,12 +13,14 @@ import importlib import os import subprocess -import sys import re from mldebug.arch import load_aie_arch, AIE_DEV_PHX, AIE_DEV_STX, AIE_DEV_TEL from mldebug.backend.core_dump_impl import CoreDumpFallbackReader -from mldebug.utils import LOGGER, is_aarch64, is_windows +from mldebug.utils import LOGGER, cleanup_and_exit, input_with_timeout, is_aarch64, is_windows + +# Seconds to wait at interactive prompts before giving up and exiting. +HW_CONTEXT_INPUT_TIMEOUT_S = 60 @dataclass class RunFlags: @@ -126,13 +128,15 @@ def get_flag(s, default=False): ) -def check_registry_keys(npu3=False) -> None: +def check_registry_keys(args, npu3=False) -> None: """ Checks if specific registry keys are correctly configured on Windows, and sets values if necessary for MLDebug operation. 
Exits on failure or after making modifications. Args: + args: Argument namespace. Used to drive flexmlrt cleanup on exit + (only when ``args.l3`` is set). npu3 (bool): Whether to check npu3-specific registry keys. Returns: @@ -174,16 +178,16 @@ def check_registry_keys(npu3=False) -> None: f"Error: Unable to access or create registry key:" f" HKEY_LOCAL_MACHINE\\{key_path}. Please run tool with admin privileges." ) - sys.exit(1) + cleanup_and_exit(args, 1) except ValueError: LOGGER.log(f"Error: Invalid registry key format: {key_path}") - sys.exit(1) + cleanup_and_exit(args, 1) if modified: LOGGER.log( "\nRegistry settings to enable MlDebug were modified. Please restart your machine for the changes to take effect." ) - sys.exit(1) + cleanup_and_exit(args, 1) else: LOGGER.log("\nRegistry settings check passed. No modifications were necessary.") @@ -252,18 +256,13 @@ def print_hw_context_table(current_contexts: dict[str, dict[str, str]]) -> None: LOGGER.log(f"{context:<12} {columns_str:<30} {context_data['pid']:<12} {context_data['status']:<12}") -def check_hw_context(device: str) -> tuple[int, int]: +def check_hw_context(args) -> tuple[int, int]: """ - Finds and returns the hardware context and process ID from the xrt-smi command output. - - If xrt-smi fails or no application is running, prompts the user to input ctx and pid manually. - - Args: - device (str): Device identifier. - - Returns: - Tuple[int, int]: Selected context ID and PID. + Returns (ctx_id, pid) from xrt-smi, prompting the user as a fallback. + Manual prompts time out after ``HW_CONTEXT_INPUT_TIMEOUT_S`` seconds and + call ``cleanup_and_exit(args, 1)`` on failure / timeout. """ + device = args.device filename = "xrt-smi_output.json" use_shell = is_windows() @@ -297,17 +296,35 @@ def check_hw_context(device: str) -> tuple[int, int]: else: print_hw_context_table(current_contexts) # Ask user - selected_context_id = input("Multiple Contexts Found. 
Please enter the Context ID you want to select: ") + selected_context_id = input_with_timeout( + "Multiple Contexts Found. Please enter the Context ID you want to select: ", + HW_CONTEXT_INPUT_TIMEOUT_S, + ) if selected_context_id in current_contexts: ctx = int(selected_context_id) pid = int(current_contexts[selected_context_id]["pid"]) else: LOGGER.log("Could not find the provided context, Exiting now.") - sys.exit(1) + cleanup_and_exit(args, 1) except (FileNotFoundError, subprocess.CalledProcessError, json.JSONDecodeError): - LOGGER.log("Error with xrt-smi. Please enter ctx, pid manually.") - pid = int(input("Enter PID > ")) - ctx = int(input("Enter CTX ID > ")) + LOGGER.log( + f"Error with xrt-smi. Please enter ctx, pid manually " + f"(waiting up to {HW_CONTEXT_INPUT_TIMEOUT_S}s for each value)." + ) + pid_str = input_with_timeout("Enter PID > ", HW_CONTEXT_INPUT_TIMEOUT_S) + if pid_str is None: + LOGGER.log("\nTimed out waiting for PID input. Exiting.") + cleanup_and_exit(args, 1) + ctx_str = input_with_timeout("Enter CTX ID > ", HW_CONTEXT_INPUT_TIMEOUT_S) + if ctx_str is None: + LOGGER.log("\nTimed out waiting for CTX ID input. Exiting.") + cleanup_and_exit(args, 1) + try: + pid = int(pid_str) + ctx = int(ctx_str) + except ValueError: + LOGGER.log("Invalid PID/CTX ID input. Exiting.") + cleanup_and_exit(args, 1) return ctx, pid diff --git a/src/mldebug/memory_dumper.py b/src/mldebug/memory_dumper.py index ff4f489..c2da679 100644 --- a/src/mldebug/memory_dumper.py +++ b/src/mldebug/memory_dumper.py @@ -236,7 +236,7 @@ def _ensure_debug_server(self): """ if not self.debug_server: LOGGER.log("[INFO] Starting L3 debug server...") - self.debug_server = DebugServer(None, self.output_dir, self.args.backend == "test") + self.debug_server = DebugServer(self.output_dir, self.args.backend == "test") if not self.debug_server.client_socket and self.args.backend != "test": LOGGER.log( "[ERROR] Failed to connect to FlexML runtime. 
def input_with_timeout(prompt, timeout):
    """
    Read one line from stdin, giving up after ``timeout`` seconds.

    The blocking ``input()`` call runs on a daemon thread so the wait can be
    bounded on every platform (Windows has no ``signal.alarm``).

    Args:
        prompt (str): Text shown to the user before reading.
        timeout (float | None): Seconds to wait for a line. ``None`` waits
            indefinitely (same semantics as ``Thread.join``).

    Returns:
        str | None: The line entered (without the trailing newline), or
        ``None`` when the wait timed out, stdin hit EOF, or stdin could not
        be read at all.

    Note:
        On timeout the reader thread is *not* cancelled; it stays blocked on
        stdin until the process exits. Callers should therefore exit after a
        timeout instead of prompting again, otherwise the stale reader may
        consume the next line typed.
    """
    answer = []

    def _reader():
        try:
            answer.append(input(prompt))
        except (EOFError, OSError, ValueError):
            # EOFError: stdin closed / exhausted.
            # OSError / ValueError: stdin is unreadable (detached, closed file
            # object, captured by a test harness). Treat all of these as
            # "no input" instead of letting the thread die with a traceback.
            pass

    reader = threading.Thread(target=_reader, name="input_with_timeout", daemon=True)
    reader.start()
    reader.join(timeout)
    if reader.is_alive():
        # Timed out: the daemon thread is abandoned and will not block exit.
        return None
    return answer[0] if answer else None


# Tracks the live DebugServer so cleanup_and_exit can close it on exit.
_active_debug_server = None


def register_debug_server(server):
    """
    Remember the live DebugServer so ``cleanup_and_exit`` can close it.

    Args:
        server: The connected DebugServer instance, or ``None`` to clear the
            registration.
    """
    global _active_debug_server  # pylint: disable=global-statement
    _active_debug_server = server


def terminate_flexml_connection(timeout=5, is_testmode=False):
    """
    Spin up a brief DebugServer, then close it again.

    Best-effort cleanup used on unplanned exit so a flexmlrt process waiting
    for the debugger is released. Every error (including a failed import) is
    logged and swallowed so the caller can still exit with its own code.

    Args:
        timeout (float): Seconds the temporary server waits for flexmlrt to
            connect before giving up.
        is_testmode (bool): Forwarded to DebugServer; callers pass True for
            the "test" backend so no real socket is opened.
    """
    try:
        # Imported lazily to avoid a circular import (debug_server imports
        # LOGGER); kept inside the try so an import failure cannot pre-empt
        # the caller's sys.exit().
        from mldebug.debug_server import DebugServer  # pylint: disable=import-outside-toplevel

        server = DebugServer(
            output_dir="",
            is_testmode=is_testmode,
            connect_timeout=timeout,
        )
        server.close()
    except Exception as e:  # pylint: disable=broad-except
        LOGGER.log(f"[WARN] flexmlrt cleanup failed: {e}")


def cleanup_and_exit(args, code=1):
    """
    Tear down the flexmlrt connection, then exit the process.

    A registered DebugServer is always closed, whichever flag caused it to be
    created. When none is registered and ``args.l3`` is set, a short-lived
    server is started instead (covers exits that happen before ClientDebug
    runs).

    Args:
        args: Argument namespace, or ``None``. Only ``l3`` and ``backend`` are
            read, both via ``getattr`` so partial namespaces are accepted.
        code (int): Process exit status.

    Raises:
        SystemExit: Always, carrying ``code``.
    """
    global _active_debug_server  # pylint: disable=global-statement
    # Clear the registration first so a re-entrant call cannot close twice.
    server, _active_debug_server = _active_debug_server, None
    if server is not None:
        try:
            server.close()
        except Exception as e:  # pylint: disable=broad-except
            LOGGER.log(f"[WARN] Failed to close active debug server: {e}")
    elif args is not None and getattr(args, "l3", False):
        # Mirror the other DebugServer call sites: the "test" backend must
        # not perform real socket operations.
        terminate_flexml_connection(is_testmode=getattr(args, "backend", None) == "test")
    sys.exit(code)