Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/defib/tui/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,9 @@ def start_recovery(
screen = ProgressScreen(chip, firmware_path, port, send_break)
self.push_screen(screen)

def start_flash_doctor(self, port: str = "") -> None:
def start_flash_doctor(self, chip: str = "", port: str = "") -> None:
"""Switch to Flash Doctor screen."""
screen = FlashDoctorScreen(port=port)
screen = FlashDoctorScreen(chip=chip, port=port)
self.push_screen(screen)


Expand Down
221 changes: 127 additions & 94 deletions src/defib/tui/screens/flash_doctor.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,14 @@
from datetime import datetime

from textual.app import ComposeResult
from textual.containers import Vertical, Horizontal, Center
from textual.containers import Vertical, Horizontal
from textual.screen import Screen
from textual.widget import Widget
from textual.widgets import (
Header,
Footer,
Static,
Button,
Select,
RichLog,
)
from textual.reactive import reactive
Expand Down Expand Up @@ -75,20 +74,6 @@ def _build_banner(subtitle: str) -> str:
)


def _get_serial_ports() -> list[tuple[str, str]]:
"""Get available serial ports as (label, value) tuples."""
try:
from serial.tools.list_ports import comports
ports = sorted(
[p for p in comports() if p.vid is not None],
key=lambda p: p.device,
)
if ports:
return [(f"{p.device} - {p.description}", p.device) for p in ports]
except Exception:
pass
return [("No ports found", "")]


class SectorGrid(Widget):
"""Custom widget: Norton Disk Doctor-style sector block grid.
Expand Down Expand Up @@ -127,6 +112,8 @@ def set_geometry(self, num_sectors: int, sector_size: int) -> None:
self._sector_size = sector_size
self._statuses = [None] * num_sectors
self._scanning_idx = None
rows = (num_sectors + self._cols - 1) // self._cols if num_sectors else 1
self.styles.height = rows
self.grid_version += 1

def set_sector(self, index: int, status: int) -> None:
Expand Down Expand Up @@ -273,22 +260,11 @@ class FlashDoctorScreen(Screen[None]):

/* ── Setup panel (pre-scan) ── */
#setup-panel {
width: 80;
width: 100%;
height: auto;
border: thick $accent;
padding: 1 2;
background: $panel;
margin: 1 0;
}

#setup-panel Label {
margin-top: 1;
}

#setup-buttons {
margin-top: 1;
height: 3;
align: center middle;
content-align: center middle;
text-align: center;
padding: 1 0;
}

/* ── Scan view ── */
Expand All @@ -301,6 +277,7 @@ class FlashDoctorScreen(Screen[None]):
#grid-frame {
width: 100%;
height: auto;
min-height: 16;
border: thick $accent;
margin: 0 2;
padding: 1 0;
Expand All @@ -323,7 +300,8 @@ class FlashDoctorScreen(Screen[None]):

#stats-panel {
height: auto;
margin: 1 2;
max-height: 10;
margin: 0 2;
padding: 1 0;
border: tall $accent;
background: $panel;
Expand All @@ -339,8 +317,8 @@ class FlashDoctorScreen(Screen[None]):

/* ── Results log ── */
#results-log {
height: 1fr;
margin: 0 2 1 2;
height: 8;
margin: 0 2;
border: tall $accent;
}

Expand Down Expand Up @@ -371,9 +349,10 @@ class FlashDoctorScreen(Screen[None]):
class ScanComplete(Message):
"""Emitted when a scan finishes."""

def __init__(self, port: str = "") -> None:
def __init__(self, port: str = "", chip: str = "") -> None:
super().__init__()
self._port = port
self._chip = chip
self._transport: object | None = None
self._client: object | None = None
self._scanning = False
Expand All @@ -391,36 +370,22 @@ def __init__(self, port: str = "") -> None:
def compose(self) -> ComposeResult:
yield Header()

port_options = _get_serial_ports()

with Vertical(id="doctor-container"):
yield Static(
_build_banner("SPI NOR Health Scanner — sector by sector"),
id="doctor-banner",
)

with Center():
with Vertical(id="setup-panel"):
yield Static(
"[bold]Connect to a running flash agent[/]\n"
"[dim]Upload the agent first with: defib agent upload[/]",
)
yield Static("")
yield Static("[bold]Serial Port:[/]")
yield Select(
port_options,
prompt="Select port...",
id="port-select-doctor",
allow_blank=True,
value=self._port or (port_options[0][1] if port_options else ""),
)

with Horizontal(id="setup-buttons"):
yield Button(
"Connect & Scan",
variant="success",
id="connect-scan-btn",
)
yield Static(
f" [bold]Chip:[/] {self._chip} [bold]Port:[/] {self._port}\n"
" [bold yellow]Power-cycle the camera, then press Start[/]",
id="setup-panel",
)
yield Button(
"Start",
variant="success",
id="connect-scan-btn",
)

with Vertical(id="scan-view"):
with Vertical(id="grid-frame"):
Expand Down Expand Up @@ -453,21 +418,11 @@ def compose(self) -> ComposeResult:
yield Footer()

def on_mount(self) -> None:
# Start in setup mode
self.query_one("#scan-view").display = False

# If port was pre-set, auto-connect
if self._port:
self.call_later(self._connect_and_scan)

def on_button_pressed(self, event: Button.Pressed) -> None:
if event.button.id == "connect-scan-btn":
sel = self.query_one("#port-select-doctor", Select)
self._port = str(sel.value) if sel.value != Select.BLANK else ""
if self._port:
self._connect_and_scan()
else:
self.notify("Select a serial port first", severity="warning")
self._upload_and_scan()
elif event.button.id == "scan-btn":
self._start_rescan()
elif event.button.id == "save-btn":
Expand All @@ -493,61 +448,133 @@ def action_save_dump(self) -> None:

# ── Connection ───────────────────────────────────────────────────────

def _connect_and_scan(self) -> None:
"""Connect on main thread (quick), then launch scan in background thread."""
self.run_worker(self._do_connect(), exclusive=True)
def _upload_and_scan(self) -> None:
"""Upload agent, connect, and scan — all in one flow."""
btn = self.query_one("#connect-scan-btn", Button)
btn.disabled = True
btn.label = "Waiting for boot..."

async def _do_connect(self) -> None:
"""Connect to agent (runs on main event loop — brief blocking is OK)."""
from defib.agent.client import FlashAgentClient
self.query_one("#setup-panel").display = False
self.query_one("#connect-scan-btn").display = False
self.query_one("#doctor-banner").display = False
self.query_one("#scan-view").display = True
self._log("Power-cycle the camera now!")

self.run_worker(self._do_upload_and_connect(), exclusive=True)

async def _do_upload_and_connect(self) -> None:
"""Upload agent via boot protocol, then connect and scan."""
import asyncio as aio

from defib.agent.client import FlashAgentClient, get_agent_binary
from defib.firmware import get_cached_path, download_firmware, has_firmware
from defib.profiles.loader import load_profile
from defib.protocol.hisilicon_standard import HiSiliconStandard
from defib.recovery.events import ProgressEvent
from defib.transport.serial import SerialTransport

log = self.query_one("#results-log", RichLog)
log.clear()
self._log("Connecting to flash agent...")
chip = self._chip
port = self._port

# Find agent binary
agent_path = get_agent_binary(chip)
if not agent_path:
self._log(f"[red]No agent binary for '{chip}'[/]")
return

agent_data = agent_path.read_bytes()

# Get SPL from cached U-Boot (download if needed)
try:
profile = load_profile(chip)
except Exception as e:
self._log(f"[red]Unknown chip '{chip}':[/] {e}")
return

cached_fw = get_cached_path(chip)
if not cached_fw:
if has_firmware(chip):
self._log("Downloading U-Boot for SPL...")
try:
cached_fw = download_firmware(chip)
except Exception as e:
self._log(f"[red]Download failed:[/] {e}")
return
else:
self._log(f"[red]No firmware available for '{chip}'[/]")
return

spl_data = cached_fw.read_bytes()[:profile.spl_max_size]
self._log(
f"Agent: [cyan]{agent_path.name}[/] ({len(agent_data)} bytes) "
f"SPL: {len(spl_data)} bytes"
)

# Open serial port
try:
self._transport = await SerialTransport.create(self._port)
transport = await SerialTransport.create(port)
except Exception as e:
self._log(f"[red]Failed to open port:[/] {e}")
self.notify(f"Port error: {e}", severity="error")
self._log(f"[red]Port error:[/] {e}")
return

# Handshake + upload via boot protocol
protocol = HiSiliconStandard()
protocol.set_profile(profile)

def on_progress(e: ProgressEvent) -> None:
if e.message:
self._log(f" {e.message}")

hs = await protocol.handshake(transport, on_progress)
if not hs.success:
self._log("[red]Handshake failed[/]")
await transport.close()
return

result = await protocol.send_firmware(
transport, agent_data, on_progress, spl_override=spl_data,
)
if not result.success:
self._log(f"[red]Upload failed:[/] {result.error}")
await transport.close()
return

client = FlashAgentClient(self._transport)
self._log("[green]Agent uploaded![/] Waiting for READY...")

# Reconnect and wait for agent
await transport.close()
await aio.sleep(2)
transport = await SerialTransport.create(port)

client = FlashAgentClient(transport, chip)
if not await client.connect(timeout=10.0):
self._log("[red]Agent not responding[/] — is it uploaded?")
self.notify("Agent not responding", severity="error")
await self._transport.close()
self._transport = None
self._log("[red]Agent not responding after upload[/]")
await transport.close()
return

info = await client.get_info()
self._transport = transport
self._client = client
self._flash_size = int(info.get("flash_size", 0))
self._sector_size = int(info.get("sector_size", 0x10000))
self._num_sectors = self._flash_size // self._sector_size if self._sector_size else 0
jedec = info.get("jedec_id", "??????")
jedec = str(info.get("jedec_id", "??????"))

self._log(
f"[green]Connected![/] JEDEC: [bold]{jedec}[/] "
f"Flash: [bold]{self._flash_size // 1024}KB[/] "
f"Sectors: [bold]{self._num_sectors}[/] × {self._sector_size // 1024}KB"
)

# Switch to scan view
self.query_one("#setup-panel").display = False
subtitle = f"{jedec} — {self._flash_size // 1024}KB — {self._num_sectors} sectors"
self.query_one("#doctor-banner", Static).update(_build_banner(subtitle))
self.query_one("#scan-view").display = True

# Configure grid
grid = self.query_one("#sector-grid", SectorGrid)
grid.set_geometry(self._num_sectors, self._sector_size)
self.query_one("#grid-title", Static).update(
f"[bold]Flash Map[/] — {self._num_sectors} sectors"
)

# Now launch the scan in a background thread
self._launch_scan()

def _start_rescan(self) -> None:
Expand Down Expand Up @@ -592,10 +619,16 @@ def _scan_thread_fn(self) -> None:
"""Run scan_flash in a plain thread (blocking serial I/O happens here)."""
import asyncio as _aio

from defib.agent.client import FlashAgentClient, SectorResult
from defib.agent.client import FlashAgentClient, FALLBACK_BAUD, SectorResult

client: FlashAgentClient = self._client # type: ignore[assignment]

# Reset baud to 115200 — the agent may have reverted after idle timeout
port = getattr(client._transport, '_port', None)
if port is not None:
port.baudrate = FALLBACK_BAUD
client._current_baud = FALLBACK_BAUD

def on_sector(result: SectorResult) -> None:
self._pending_sectors.append(result)

Expand Down
Loading
Loading