diff --git a/src/defib/tui/app.py b/src/defib/tui/app.py index 30cf716..430fdcb 100644 --- a/src/defib/tui/app.py +++ b/src/defib/tui/app.py @@ -37,9 +37,9 @@ def start_recovery( screen = ProgressScreen(chip, firmware_path, port, send_break) self.push_screen(screen) - def start_flash_doctor(self, port: str = "") -> None: + def start_flash_doctor(self, chip: str = "", port: str = "") -> None: """Switch to Flash Doctor screen.""" - screen = FlashDoctorScreen(port=port) + screen = FlashDoctorScreen(chip=chip, port=port) self.push_screen(screen) diff --git a/src/defib/tui/screens/flash_doctor.py b/src/defib/tui/screens/flash_doctor.py index 648ea30..e1cb5a8 100644 --- a/src/defib/tui/screens/flash_doctor.py +++ b/src/defib/tui/screens/flash_doctor.py @@ -11,7 +11,7 @@ from datetime import datetime from textual.app import ComposeResult -from textual.containers import Vertical, Horizontal, Center +from textual.containers import Vertical, Horizontal from textual.screen import Screen from textual.widget import Widget from textual.widgets import ( @@ -19,7 +19,6 @@ Footer, Static, Button, - Select, RichLog, ) from textual.reactive import reactive @@ -75,20 +74,6 @@ def _build_banner(subtitle: str) -> str: ) -def _get_serial_ports() -> list[tuple[str, str]]: - """Get available serial ports as (label, value) tuples.""" - try: - from serial.tools.list_ports import comports - ports = sorted( - [p for p in comports() if p.vid is not None], - key=lambda p: p.device, - ) - if ports: - return [(f"{p.device} - {p.description}", p.device) for p in ports] - except Exception: - pass - return [("No ports found", "")] - class SectorGrid(Widget): """Custom widget: Norton Disk Doctor-style sector block grid. @@ -127,6 +112,8 @@ def set_geometry(self, num_sectors: int, sector_size: int) -> None: self._sector_size = sector_size self._statuses = [None] * num_sectors self._scanning_idx = None + rows = (num_sectors + self._cols - 1) // self._cols if num_sectors else 1 + self.styles.height = rows self.grid_version += 1 def set_sector(self, index: int, status: int) -> None: @@ -273,22 +260,11 @@ class FlashDoctorScreen(Screen[None]): /* ── Setup panel (pre-scan) ── */ #setup-panel { - width: 80; + width: 100%; height: auto; - border: thick $accent; - padding: 1 2; - background: $panel; - margin: 1 0; - } - - #setup-panel Label { - margin-top: 1; - } - - #setup-buttons { - margin-top: 1; - height: 3; - align: center middle; + content-align: center middle; + text-align: center; + padding: 1 0; } /* ── Scan view ── */ @@ -301,6 +277,7 @@ class FlashDoctorScreen(Screen[None]): #grid-frame { width: 100%; height: auto; + min-height: 16; border: thick $accent; margin: 0 2; padding: 1 0; @@ -323,7 +300,8 @@ class FlashDoctorScreen(Screen[None]): #stats-panel { height: auto; - margin: 1 2; + max-height: 10; + margin: 0 2; padding: 1 0; border: tall $accent; background: $panel; @@ -339,8 +317,8 @@ class FlashDoctorScreen(Screen[None]): /* ── Results log ── */ #results-log { - height: 1fr; - margin: 0 2 1 2; + height: 8; + margin: 0 2; border: tall $accent; } @@ -371,9 +349,10 @@ class FlashDoctorScreen(Screen[None]): class ScanComplete(Message): """Emitted when a scan finishes.""" - def __init__(self, port: str = "") -> None: + def __init__(self, port: str = "", chip: str = "") -> None: super().__init__() self._port = port + self._chip = chip self._transport: object | None = None self._client: object | None = None self._scanning = False @@ -391,36 +370,22 @@ def __init__(self, port: str = "") -> None: def compose(self) -> ComposeResult: yield Header() - port_options = _get_serial_ports() - with Vertical(id="doctor-container"): yield Static( _build_banner("SPI NOR Health Scanner — sector by sector"), id="doctor-banner", ) - with Center(): - with Vertical(id="setup-panel"): - yield Static( - "[bold]Connect to a running flash agent[/]\n" - "[dim]Upload the agent first with: defib agent upload[/]", - ) - yield Static("") - yield Static("[bold]Serial Port:[/]") - yield Select( - port_options, - prompt="Select port...", - id="port-select-doctor", - allow_blank=True, - value=self._port or (port_options[0][1] if port_options else ""), - ) - - with Horizontal(id="setup-buttons"): - yield Button( - "Connect & Scan", - variant="success", - id="connect-scan-btn", - ) + yield Static( + f" [bold]Chip:[/] {self._chip} [bold]Port:[/] {self._port}\n" + " [bold yellow]Power-cycle the camera, then press Start[/]", + id="setup-panel", + ) + yield Button( + "Start", + variant="success", + id="connect-scan-btn", + ) with Vertical(id="scan-view"): with Vertical(id="grid-frame"): @@ -453,21 +418,11 @@ def compose(self) -> ComposeResult: yield Footer() def on_mount(self) -> None: - # Start in setup mode self.query_one("#scan-view").display = False - # If port was pre-set, auto-connect - if self._port: - self.call_later(self._connect_and_scan) - def on_button_pressed(self, event: Button.Pressed) -> None: if event.button.id == "connect-scan-btn": - sel = self.query_one("#port-select-doctor", Select) - self._port = str(sel.value) if sel.value != Select.BLANK else "" - if self._port: - self._connect_and_scan() - else: - self.notify("Select a serial port first", severity="warning") + self._upload_and_scan() elif event.button.id == "scan-btn": self._start_rescan() elif event.button.id == "save-btn": @@ -493,40 +448,117 @@ def action_save_dump(self) -> None: # ── Connection ─────────────────────────────────────────────────────── - def _connect_and_scan(self) -> None: - """Connect on main thread (quick), then launch scan in background thread.""" - self.run_worker(self._do_connect(), exclusive=True) + def _upload_and_scan(self) -> None: + """Upload agent, connect, and scan — all in one flow.""" + btn = self.query_one("#connect-scan-btn", Button) + btn.disabled = True + btn.label = "Waiting for boot..." - async def _do_connect(self) -> None: - """Connect to agent (runs on main event loop — brief blocking is OK).""" - from defib.agent.client import FlashAgentClient + self.query_one("#setup-panel").display = False + self.query_one("#connect-scan-btn").display = False + self.query_one("#doctor-banner").display = False + self.query_one("#scan-view").display = True + self._log("Power-cycle the camera now!") + + self.run_worker(self._do_upload_and_connect(), exclusive=True) + + async def _do_upload_and_connect(self) -> None: + """Upload agent via boot protocol, then connect and scan.""" + import asyncio as aio + + from defib.agent.client import FlashAgentClient, get_agent_binary + from defib.firmware import get_cached_path, download_firmware, has_firmware + from defib.profiles.loader import load_profile + from defib.protocol.hisilicon_standard import HiSiliconStandard + from defib.recovery.events import ProgressEvent from defib.transport.serial import SerialTransport - log = self.query_one("#results-log", RichLog) - log.clear() - self._log("Connecting to flash agent...") + chip = self._chip + port = self._port + + # Find agent binary + agent_path = get_agent_binary(chip) + if not agent_path: + self._log(f"[red]No agent binary for '{chip}'[/]") + return + + agent_data = agent_path.read_bytes() + + # Get SPL from cached U-Boot (download if needed) + try: + profile = load_profile(chip) + except Exception as e: + self._log(f"[red]Unknown chip '{chip}':[/] {e}") + return + + cached_fw = get_cached_path(chip) + if not cached_fw: + if has_firmware(chip): + self._log("Downloading U-Boot for SPL...") + try: + cached_fw = download_firmware(chip) + except Exception as e: + self._log(f"[red]Download failed:[/] {e}") + return + else: + self._log(f"[red]No firmware available for '{chip}'[/]") + return + + spl_data = cached_fw.read_bytes()[:profile.spl_max_size] + self._log( + f"Agent: [cyan]{agent_path.name}[/] ({len(agent_data)} bytes) " + f"SPL: {len(spl_data)} bytes" + ) + # Open serial port try: - self._transport = await SerialTransport.create(self._port) + transport = await SerialTransport.create(port) except Exception as e: - self._log(f"[red]Failed to open port:[/] {e}") - self.notify(f"Port error: {e}", severity="error") + self._log(f"[red]Port error:[/] {e}") + return + + # Handshake + upload via boot protocol + protocol = HiSiliconStandard() + protocol.set_profile(profile) + + def on_progress(e: ProgressEvent) -> None: + if e.message: + self._log(f" {e.message}") + + hs = await protocol.handshake(transport, on_progress) + if not hs.success: + self._log("[red]Handshake failed[/]") + await transport.close() + return + + result = await protocol.send_firmware( + transport, agent_data, on_progress, spl_override=spl_data, + ) + if not result.success: + self._log(f"[red]Upload failed:[/] {result.error}") + await transport.close() return - client = FlashAgentClient(self._transport) + self._log("[green]Agent uploaded![/] Waiting for READY...") + + # Reconnect and wait for agent + await transport.close() + await aio.sleep(2) + transport = await SerialTransport.create(port) + + client = FlashAgentClient(transport, chip) if not await client.connect(timeout=10.0): - self._log("[red]Agent not responding[/] — is it uploaded?") - self.notify("Agent not responding", severity="error") - await self._transport.close() - self._transport = None + self._log("[red]Agent not responding after upload[/]") + await transport.close() return info = await client.get_info() + self._transport = transport self._client = client self._flash_size = int(info.get("flash_size", 0)) self._sector_size = int(info.get("sector_size", 0x10000)) self._num_sectors = self._flash_size // self._sector_size if self._sector_size else 0 - jedec = info.get("jedec_id", "??????") + jedec = str(info.get("jedec_id", "??????")) self._log( f"[green]Connected![/] JEDEC: [bold]{jedec}[/] " @@ -534,20 +566,15 @@ async def _do_connect(self) -> None: f"Sectors: [bold]{self._num_sectors}[/] × {self._sector_size // 1024}KB" ) - # Switch to scan view - self.query_one("#setup-panel").display = False subtitle = f"{jedec} — {self._flash_size // 1024}KB — {self._num_sectors} sectors" self.query_one("#doctor-banner", Static).update(_build_banner(subtitle)) - self.query_one("#scan-view").display = True - # Configure grid grid = self.query_one("#sector-grid", SectorGrid) grid.set_geometry(self._num_sectors, self._sector_size) self.query_one("#grid-title", Static).update( f"[bold]Flash Map[/] — {self._num_sectors} sectors" ) - # Now launch the scan in a background thread self._launch_scan() def _start_rescan(self) -> None: @@ -592,10 +619,16 @@ def _scan_thread_fn(self) -> None: """Run scan_flash in a plain thread (blocking serial I/O happens here).""" import asyncio as _aio - from defib.agent.client import FlashAgentClient, SectorResult + from defib.agent.client import FlashAgentClient, FALLBACK_BAUD, SectorResult client: FlashAgentClient = self._client # type: ignore[assignment] + # Reset baud to 115200 — the agent may have reverted after idle timeout + port = getattr(client._transport, '_port', None) + if port is not None: + port.baudrate = FALLBACK_BAUD + client._current_baud = FALLBACK_BAUD + def on_sector(result: SectorResult) -> None: self._pending_sectors.append(result) diff --git a/src/defib/tui/screens/main.py b/src/defib/tui/screens/main.py index 0248466..4cdfec0 100644 --- a/src/defib/tui/screens/main.py +++ b/src/defib/tui/screens/main.py @@ -163,7 +163,7 @@ def compose(self) -> ComposeResult: def _get_chip(self) -> str: sel = self.query_one("#chip-select", Select) - return str(sel.value) if sel.value != Select.BLANK else "" + return str(sel.value) if isinstance(sel.value, str) else "" def on_select_changed(self, event: Select.Changed) -> None: if event.select.id == "chip-select": @@ -298,10 +298,20 @@ def _start_recovery(self) -> None: app.start_recovery(chip, firmware_path, port, send_break) def _start_flash_doctor(self) -> None: + chip = self._get_chip() port_sel = self.query_one("#port-select", Select) - port = str(port_sel.value) if port_sel.value != Select.BLANK else "" + port = str(port_sel.value) if isinstance(port_sel.value, str) else "" + + errors: list[str] = [] + if not chip: + errors.append("Select a chip model") + if not port: + errors.append("Select a serial port") + if errors: + self.notify("\n".join(errors), severity="error", title="Flash Doctor") + return from defib.tui.app import DefibApp app = self.app if isinstance(app, DefibApp): - app.start_flash_doctor(port) + app.start_flash_doctor(chip=chip, port=port) diff --git a/tests/test_tui.py b/tests/test_tui.py index a5794ec..8560b3f 100644 --- a/tests/test_tui.py +++ b/tests/test_tui.py @@ -79,9 +79,10 @@ async def test_flash_doctor_screen_renders(self): app = DefibApp() async with app.run_test(size=(120, 40)) as pilot: - screen = FlashDoctorScreen() - app.push_screen(screen) + app.start_flash_doctor(chip="hi3516ev300", port="/dev/ttyUSB0") await pilot.pause() + screen = app.screen + assert isinstance(screen, FlashDoctorScreen) assert screen.query_one("#doctor-banner") is not None assert screen.query_one("#sector-grid") is not None assert screen.query_one("#scan-stats") is not None @@ -89,18 +90,26 @@ async def test_flash_doctor_screen_renders(self): assert screen.query_one("#connect-scan-btn") is not None @pytest.mark.asyncio - async def test_flash_doctor_button_on_main_screen(self): - """Main screen has Flash Doctor button that opens the screen.""" - from defib.tui.screens.flash_doctor import FlashDoctorScreen + async def test_flash_doctor_blocked_without_chip(self): + """Flash Doctor button requires chip and port selection.""" from defib.tui.screens.main import MainScreen app = DefibApp() async with app.run_test(size=(120, 40)) as pilot: + await pilot.click("#doctor-btn") + await pilot.pause() + # Should stay on MainScreen — no chip selected assert isinstance(app.screen, MainScreen) - btn = app.screen.query_one("#doctor-btn") - assert btn is not None - await pilot.click("#doctor-btn") + @pytest.mark.asyncio + async def test_flash_doctor_opens_with_chip_and_port(self): + """Flash Doctor opens when chip and port are selected.""" + from defib.tui.screens.flash_doctor import FlashDoctorScreen + + app = DefibApp() + async with app.run_test(size=(120, 40)) as pilot: + # Bypass UI selection — push screen directly + app.start_flash_doctor(chip="hi3516ev300", port="/dev/ttyUSB0") await pilot.pause() assert isinstance(app.screen, FlashDoctorScreen) @@ -112,7 +121,7 @@ async def test_flash_doctor_escape_goes_back(self): app = DefibApp() async with app.run_test(size=(120, 40)) as pilot: - await pilot.click("#doctor-btn") + app.start_flash_doctor(chip="hi3516ev300", port="/dev/ttyUSB0") await pilot.pause() assert isinstance(app.screen, FlashDoctorScreen)