Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
107 changes: 102 additions & 5 deletions src/mldebug/input_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

from mldebug.arch import load_aie_arch, AIE_DEV_PHX, AIE_DEV_STX, AIE_DEV_TEL
from mldebug.backend.core_dump_impl import CoreDumpFallbackReader
from mldebug.backend.factory import BackendConfig, create_backend
from mldebug.utils import LOGGER, cleanup_and_exit, input_with_timeout, is_aarch64, is_windows

# Seconds to wait at interactive prompts before giving up and exiting.
Expand Down Expand Up @@ -256,13 +257,73 @@ def print_hw_context_table(current_contexts: dict[str, dict[str, str]]) -> None:
LOGGER.log(f"{context:<12} {columns_str:<30} {context_data['pid']:<12} {context_data['status']:<12}")


def _validate_contexts_with_read(contexts: dict, device: str, aie_iface) -> list[tuple[int, int]] | None:
"""
Validate ALL contexts by reading CORE_STATUS register (verifies register access)

Args:
contexts: All hardware contexts from xrt-smi (context_id -> info incl. status)
device: Device name (for backend initialization)
aie_iface: Already-loaded AIE interface, or None to load it

Returns:
List of (context_id, pid) tuples that passed validation, or None if none passed.
"""
# Use first AIE core tile for test read
# Tile layout: Row 0=Shim, Rows 1 to (OFFSET-1)=Memory, Rows OFFSET+=AIE cores
# For Telluride: (0, 3), For PHX/STX: (0, 2)
test_col = 0
test_row = aie_iface.AIE_TILE_ROW_OFFSET

# CORE_STATUS register - safe read-only register
# Device-specific addresses: Telluride=0x38004, PHX/STX=0x32004
test_reg = aie_iface.Core_registers["CORE_STATUS"]
test_tiles = [(test_col, test_row)]

valid_contexts = []
for ctx_id, ctx_info in contexts.items():
backend = None
try:
pid = int(ctx_info["pid"])
ctx = int(ctx_id)

config = BackendConfig(
tiles=test_tiles,
ctx_id=ctx,
pid=pid,
device=device,
)
backend = create_backend("xrt", config)

backend.read_register(test_col, test_row, test_reg)
valid_contexts.append((ctx, pid))

# TODO: catch device-specific errors (e.g. EBUSY from XRT) instead of Exception
except Exception as e:
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

in future, it would be good to catch the specific exception

print(f"[DEBUG] Context {ctx_id} failed validation: {type(e).__name__}: {e}")
continue

# Clean up the test backend to avoid resource leaks
finally:
del backend

if not valid_contexts:
print("[WARNING] No contexts passed validation")
return None
return valid_contexts


def check_hw_context(args) -> tuple[int, int]:
"""
Returns (ctx_id, pid) from xrt-smi, prompting the user as a fallback.
Manual prompts time out after ``HW_CONTEXT_INPUT_TIMEOUT_S`` seconds and
call ``cleanup_and_exit(args, 1)`` on failure / timeout.
Returns (ctx_id, pid) from xrt-smi.

1. If only one context exists, auto-select it.
2. If multiple exist, validate all (Active and Idle) with a CORE_STATUS register read.
3. If no context passes validation, prompt the user (60s timeout; invalid input or timeout
calls ``cleanup_and_exit(args, 1)``).
"""
device = args.device
aie_iface = args.aie_iface
filename = "xrt-smi_output.json"
use_shell = is_windows()

Expand Down Expand Up @@ -290,14 +351,23 @@ def check_hw_context(args) -> tuple[int, int]:
if not current_contexts:
print("Warning: xrt-smi could find no applications running. Please launch an application to use MLDebugger.")
raise FileNotFoundError

# Path 1: Single context found -> auto-select it
if len(current_contexts) == 1:
ctx = int(list(current_contexts.keys())[0])
pid = int(list(current_contexts.values())[0]["pid"])
else:
return ctx, pid

# Path 2: Multiple contexts found -> validate all with register read test
print(f"[INFO] Found {len(current_contexts)} hardware context(s). Validating with register read test...")
valid_contexts = _validate_contexts_with_read(current_contexts, device, aie_iface)

# Path 2a: No contexts passed validation -> prompt user for input
if valid_contexts is None:
print_hw_context_table(current_contexts)
# Ask user
selected_context_id = input_with_timeout(
"Multiple Contexts Found. Please enter the Context ID you want to select: ",
"No Contexts passed validation. Please enter the Context ID you want to select: ",
HW_CONTEXT_INPUT_TIMEOUT_S,
)
if selected_context_id in current_contexts:
Expand All @@ -306,6 +376,33 @@ def check_hw_context(args) -> tuple[int, int]:
else:
LOGGER.log("Could not find the provided context, Exiting now.")
cleanup_and_exit(args, 1)
return ctx, pid

# Path 2b: Single valid context found -> auto-select it
elif len(valid_contexts) == 1:
ctx, pid = valid_contexts[0]
return ctx, pid

# Path 2c: Multiple valid contexts found -> prompt user for input
else:
lookup = {str(ctx): (ctx, pid) for ctx, pid in valid_contexts}
valid_ids = set(lookup.keys())
valid_only = {k: v for k, v in current_contexts.items() if str(k) in valid_ids}
print_hw_context_table(valid_only)
# Ask user
selected_context_id = input_with_timeout(
f"{len(valid_contexts)} Contexts passed validation. "
"Please enter the Context ID you want to select: ",
HW_CONTEXT_INPUT_TIMEOUT_S,
)
if selected_context_id in valid_only:
ctx = int(selected_context_id)
pid = int(valid_only[selected_context_id]["pid"])
else:
LOGGER.log(f"Context ID {selected_context_id} not found. Valid options: {', '.join(valid_only.keys())}")
cleanup_and_exit(args, 1)
return ctx, pid

except (FileNotFoundError, subprocess.CalledProcessError, json.JSONDecodeError):
LOGGER.log(
f"Error with xrt-smi. Please enter ctx, pid manually "
Expand Down
Loading