diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 3fa0fb1..c22899c 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -132,7 +132,7 @@ jobs: run: | python -m pip install --upgrade pip pip install -e . - pip install pytest pytest-cov opencensus-ext-azure + pip install pytest pytest-cov pytest-asyncio opencensus-ext-azure - name: Run tests run: pytest tests/ -v --tb=short --cov=azext_prototype --cov-report=xml --cov-report=term-missing diff --git a/.github/workflows/pr.yml b/.github/workflows/pr.yml index c63fa7f..0282a55 100644 --- a/.github/workflows/pr.yml +++ b/.github/workflows/pr.yml @@ -92,7 +92,7 @@ jobs: run: | python -m pip install --upgrade pip pip install -e . - pip install pytest pytest-cov opencensus-ext-azure + pip install pytest pytest-cov pytest-asyncio opencensus-ext-azure - name: Run tests run: pytest tests/ -v --tb=short --cov=azext_prototype --cov-report=xml --cov-report=term-missing diff --git a/HISTORY.rst b/HISTORY.rst index 75039b3..f7aae0f 100644 --- a/HISTORY.rst +++ b/HISTORY.rst @@ -3,6 +3,212 @@ Release History =============== +0.2.1b4 ++++++++ + +Discovery section gating and architecture task tracking +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +* **Reliable section completion gate** — replaced heuristic phrase + matching (``_is_section_done()``) with an explicit AI confirmation + step. Sections only advance when the AI responds with "Yes", + eliminating false-positive checkmarks from transitional language. +* **"All topics covered" accuracy** — the message now only appears + when every section received explicit AI confirmation. Otherwise a + softer prompt is shown. +* **"continue" keyword** — users can type ``continue`` (in addition + to ``done``) to proceed from discovery to architecture generation. +* **Architecture sections in task tree** — ``_generate_architecture_sections()`` + now reports each section to the TUI task tree with ``in_progress`` / + ``completed`` status updates. Dynamically discovered sections + (``[NEW_SECTION]`` markers) are appended in real time. +* **Timer format** — elapsed times >= 60 s now display as ``1m04s`` + instead of ``64s`` in the TUI info bar and per-section console + output. + +TUI console color, wrapping, and section pagination +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +* **Color consolidation** — all color constants now live in ``theme.py`` + as the single source of truth. Duplicate theme dicts in ``console.py`` + and hardcoded hex colors in ``task_tree.py``, ``tui_adapter.py``, and + ``console.py`` toolbar functions replaced with ``COLORS`` imports. +* **Rich markup preservation** — ``TUIAdapter.print_fn()`` no longer + strips Rich markup tags. Messages containing ``[success]``, + ``[info]``, etc. are routed to the new ``ConsoleView.write_markup()`` + method so status messages retain their colors in the TUI. +* **Horizontal wrapping** — ``ConsoleView`` (``RichLog``) now passes + ``wrap=True``, eliminating the horizontal scrollbar for long lines. +* **Agent response rendering** — new ``TUIAdapter.response_fn()`` + renders agent responses as colored Markdown via + ``ConsoleView.write_agent_response()``. Wired through + ``DiscoverySession.run()`` → ``DesignStage.execute()`` → + ``StageOrchestrator._run_design()``. +* **Section pagination** — multi-section agent responses (split on + ``## `` headings) are shown one section at a time with an "Enter to + continue" prompt between them. Single-section responses render all + at once. +* **Empty submit support** — ``PromptInput.enable(allow_empty=True)`` + allows submitting with no text, used by the pagination "Enter to + continue" prompt. Empty submissions are not echoed to the console. +* **Clean Ctrl+C exit** — ``_run_tui()`` helper in ``custom.py`` + suppresses ``SIGINT`` during the Textual run so Ctrl+C is handled + exclusively as a key event. Prevents ``KeyboardInterrupt`` from + propagating to the Azure CLI framework and eliminates the Windows + "Terminate batch job (Y/N)?" prompt from ``az.cmd``. + +Build-deploy stage decoupling +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +* **Stable stage IDs** — build stages now carry a persistent ``id`` + field (slug derived from name, e.g. ``"data-layer"``). IDs survive + renumbering, stage insertion/removal, and design iteration. Legacy + state files are backfilled on load. +* **Build-deploy correspondence** — deploy stages link back to build + stages via ``build_stage_id`` instead of fragile stage numbers. + ``sync_from_build_state()`` performs smart reconciliation: matching + stages are updated while preserving deploy progress, new build stages + create new deploy stages, and removed build stages are marked + ``"removed"``. +* **Stage splitting** — ``split_stage(N, substages)`` replaces one + deploy stage with N substages (``5a``, ``5b``, ``5c``) sharing the + same ``build_stage_id``. Supports code splits (Type A), deploy-only + splits (Type B), and manual step insertion (Type C). +* **Manual deployment steps** — stages with ``deploy_mode: "manual"`` + display instructions and pause for user confirmation (Done / Skip / + Need help) instead of executing IaC commands. Manual steps can + originate from the architect during plan derivation or from + remediation. +* **New deploy statuses** — ``"removed"`` (build stage deleted), + ``"destroyed"`` (resources torn down), ``"awaiting_manual"`` (waiting + for user confirmation). +* **New slash commands** — ``/split N`` (interactive stage splitting), + ``/destroy N`` (resource destruction with confirmation), + ``/manual N "instructions"`` (add/view manual step instructions). +* **Compound stage references** — all stage-referencing commands + (``/deploy``, ``/rollback``, ``/redeploy``, ``/plan``, ``/describe``) + accept substage labels: ``/deploy 5a``, ``/rollback 5`` (all + substages in reverse order). +* **Re-entry sync** — when the deploy session re-enters with an + existing deploy state, it syncs with the latest build state and + reports changes (new stages, removed stages, updated code). +* **Display improvements** — removed stages show with strikethrough and + ``(Removed)`` suffix, manual steps show ``[Manual]`` badge, substages + display compound IDs (``2a``, ``2b``). + +Deploy auto-remediation +~~~~~~~~~~~~~~~~~~~~~~~~ +* **Automatic deploy failure remediation** — when a deployment stage + fails, the system now automatically diagnoses (QA engineer), + determines a fix strategy (cloud architect), regenerates the code + (IaC/app agent), and retries deployment — up to 2 remediation + attempts before falling through to the interactive loop. +* **Downstream impact tracking** — after fixing a stage, the + architect checks whether downstream stages need regeneration + due to changed outputs or dependencies. Affected stages are + automatically regenerated before their deploy. +* **Consistent QA routing** — ``/deploy N`` and ``/redeploy N`` + slash commands now route through the remediation loop on failure, + not just print the error. +* **Deploy state enhancements** — new ``remediating`` status, + per-stage ``remediation_attempts`` counter, ``add_patch_stages()``, + and ``renumber_stages()`` methods. + +Incremental build stage +~~~~~~~~~~~~~~~~~~~~~~~~ +* **Design change detection** — ``BuildState`` now stores a design + snapshot (architecture hash + full text) after each build. On + re-entry, the build session compares the current design against the + snapshot to determine whether regeneration is needed. +* **Three-branch Phase 2** — the deployment plan derivation phase now + has three paths: + + - **Branch A** (first build): derive a fresh plan and save the + design snapshot. + - **Branch B** (design changed): ask the architect agent to diff the + old and new architectures, classify each stage as unchanged / + modified / removed, identify new services, and apply targeted + updates (``mark_stages_stale``, ``remove_stages``, ``add_stages``). + When ``plan_restructured`` is flagged, the user is offered a full + plan re-derive. + - **Branch C** (no changes): report "Build is up to date" and skip + directly to the review loop. + +* **Incremental stage operations** on ``BuildState``: + ``set_design_snapshot()``, ``design_has_changed()``, + ``get_previous_architecture()``, ``mark_stages_stale()``, + ``remove_stages()``, ``add_stages()``, ``renumber_stages()``. +* **Architecture diff via architect agent** — + ``_diff_architectures()`` sends old/new architecture + existing + stages to the architect, parses JSON classification, and falls back + to marking all stages as modified when the architect is unavailable. +* **Legacy build compatibility** — builds without a design snapshot + (pre-incremental) are treated as "design changed" with all stages + marked for rebuild, preserving conversation history. + +TUI dashboard +~~~~~~~~~~~~~~ +* **Added Textual TUI dashboard** — ``az prototype launch`` opens a full + terminal UI with four panels: scrollable console output (RichLog), + collapsible task tree with async status updates, growable multi-line + prompt (Enter to submit, Shift+Enter for newline), and an info bar + showing assist text and token usage. +* **Stage orchestrator** — the TUI auto-detects the current project stage + from ``.prototype/state/`` files and launches the appropriate session. + Users can navigate between design, build, and deploy without exiting. +* **Session-TUI bridge** — ``TUIAdapter`` connects synchronous sessions to + the async Textual event loop using ``call_from_thread`` and + ``threading.Event``. Sessions run on worker threads with ``input_fn`` + and ``print_fn`` routed through TUI widgets. +* **Spinner → task tree** — ``_maybe_spinner`` on all four sessions + (discovery, build, deploy, backlog) now accepts a ``status_fn`` callback + so the TUI can show progress via the info bar instead of Rich spinners. +* **Guarded console calls** — discovery slash commands (``/open``, + ``/status``, ``/why``, ``/summary``, ``/restart``, ``/help``) and + design stage header now route through ``_print`` when ``input_fn`` / + ``print_fn`` are injected, preventing Rich output conflicts in TUI mode. +* **New dependency** — ``textual>=8.0.0``. +* **Design command launches TUI** — ``az prototype design`` now opens the + TUI dashboard and auto-starts the design session, instead of running + synchronously in the terminal. ``--status`` remains CLI-only. + Artifact paths are resolved to absolute before the TUI takes over. +* **Section headers as tree branches** — during discovery, the + biz-analyst's AI responses are scanned for ``##`` / ``###`` headings + (e.g. "Project Context & Scope", "Data & Content") which appear as + collapsible sub-nodes under the Design branch in the task tree. + Duplicate headings are deduplicated by slug. + +Natural language intent detection +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +* **AI-powered command classification** — all four interactive sessions + (discovery, build, deploy, backlog) now accept natural language + instead of slash commands. When an AI provider is available, a + lightweight classification call maps user input to the appropriate + command. Falls back to keyword/regex scoring when AI is unavailable. +* **Mid-session file reading** — ``"read artifacts from "`` + reads files (PDF, DOCX, PPTX, images, text) during any session and + injects the content into the conversation context. +* **Deploy session natural language** — the deploy session no longer + requires slash commands. ``"deploy stage 3"``, ``"rollback all"``, + ``"deploy stages 3 and 4"`` are interpreted and executed directly. +* **Stage description command** — new ``/describe N`` command in both + build and deploy sessions. Natural language variants like + ``"describe stage 3"`` or ``"what's being deployed in stage 2"`` + show detailed resource, file, and status information for a stage. +* **Project summary in TUI** — the welcome banner now shows a + one-line project summary extracted from discovery state or the + design architecture. + +Packaging +~~~~~~~~~~ +* **Added ``__init__.py`` to data-only directories** — 15 data directories + (policies, standards, templates, knowledge, agent definitions) lacked + ``__init__.py``, causing setuptools "Package would be ignored" warnings + during wheel builds. The ``templates/`` directory also contained Python + modules (``registry.py``, ``validate.py``) that were not included in the + wheel. All data directories now have ``__init__.py`` so ``find_packages()`` + discovers them correctly. +* **Excluded ``__pycache__`` from package discovery** — ``setup.py`` now + filters ``__pycache__`` directories from ``find_packages()`` results to + prevent spurious build warnings. + 0.2.1b3 +++++++ diff --git a/azext_prototype/_params.py b/azext_prototype/_params.py index 951c6b9..3c136cd 100644 --- a/azext_prototype/_params.py +++ b/azext_prototype/_params.py @@ -48,6 +48,15 @@ def load_arguments(self, _): help="AI model to use (default: claude-sonnet-4.5 for copilot, gpt-4o for others).", ) + # --- az prototype launch --- + with self.argument_context("prototype launch") as c: + c.argument( + "stage", + arg_type=get_enum_type(["design", "build", "deploy"]), + help="Start the TUI at a specific stage instead of auto-detecting.", + default=None, + ) + # --- az prototype design --- with self.argument_context("prototype design") as c: c.argument( diff --git a/azext_prototype/agents/builtin/biz_analyst.py b/azext_prototype/agents/builtin/biz_analyst.py index 046dcda..373f17e 100644 --- a/azext_prototype/agents/builtin/biz_analyst.py +++ b/azext_prototype/agents/builtin/biz_analyst.py @@ -62,35 +62,34 @@ def __init__(self): with a user to prepare requirements for an Azure prototype. You're \ having a conversation — not running a questionnaire. -Talk to the user the way an experienced consultant would: listen \ -carefully, pick up on what they said (and what they didn't), and ask \ -the questions that matter most right now. Let the conversation flow \ -naturally from one topic to the next rather than dumping a list of \ -questions all at once. +## Response structure + +When analyzing the user's input, be COMPREHENSIVE — cover all relevant \ +topic areas in a single response. Use `## Heading` for each topic area \ +so the system can present them to the user one at a time. Ask 2–4 \ +focused questions per topic. + +When responding to follow-up answers about a SPECIFIC topic, stay \ +focused on that topic only. When you have no more questions about it, \ +respond ONLY with the word "Yes" (meaning yes, this section is complete). \ +Do not add any other text — just "Yes". ## How to behave -- **Never assume.** If the user hasn't told you something, ask. Don't \ - fill gaps with your own guesses. -- **Be conversational.** Respond to what they just said before asking \ - your next question. Acknowledge their answers. Build on them. -- **Ask open-ended questions.** Prefer "how", "what", "tell me about", \ - and "walk me through" over yes/no questions. Instead of "Do you need \ - authentication?", ask "How do you expect users to sign in?" Instead \ - of "Will there be multiple regions?", ask "What does your availability \ - story look like?" Open questions draw out richer detail and surface \ - requirements the user might not have thought to mention. -- **Go where the gaps are.** If they gave you a lot of detail on one \ - area, don't re-ask about it — move to something they haven't covered. -- **Explain briefly why you're asking** when it isn't obvious, so the \ - user understands the relevance. -- **Be comprehensive.** Ask 10–15 questions at a time, grouped by topic, \ - as long as they are relevant. This reduces round trips. Let the user \ - respond, then follow up on gaps. -- **Be pragmatic.** This is a prototype — but prototypes still need \ - solid requirements. Don't demand production-grade answers, but DO \ - explore each topic area thoroughly enough that the architect can make \ - informed decisions. +- **Never assume.** If the user hasn't told you something, ask. +- **Be warm and human.** You're a friendly colleague, not an interrogator. \ + Use natural language — "I'd love to hear more about...", "That gives me a \ + much clearer picture, thanks!", "Interesting — tell me more about..." +- **Acknowledge before asking.** Respond to what they just said before \ + asking your next questions. A brief "Got it — so [quick summary]" is fine. +- **Ask open-ended questions.** Prefer "how", "what", "tell me about" \ + over yes/no. Open questions draw out richer detail. +- **Go where the gaps are.** If they gave detail on one area, move on. +- **Don't restate everything mid-conversation.** Save comprehensive \ + summaries for the /summary command or final wrap-up. No "What I've \ + Understood So Far" sections. +- **Keep it real.** This is a prototype. If the user isn't sure, suggest \ + a reasonable default and move on. - **Be thorough before signalling readiness.** Ensure you have explored \ at least 8 of the topics listed below before deciding you have enough. \ When you feel the critical requirements are clear, say so naturally \ diff --git a/azext_prototype/agents/builtin/definitions/__init__.py b/azext_prototype/agents/builtin/definitions/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/azext_prototype/ai/token_tracker.py b/azext_prototype/ai/token_tracker.py index fbf8955..b9f7caf 100644 --- a/azext_prototype/ai/token_tracker.py +++ b/azext_prototype/ai/token_tracker.py @@ -12,6 +12,7 @@ # Used for budget-percentage display. Values are the *input* context # window (not total output limit). _CONTEXT_WINDOWS: dict[str, int] = { + # GPT models "gpt-4o": 128_000, "gpt-4o-mini": 128_000, "gpt-4-turbo": 128_000, @@ -19,10 +20,19 @@ "gpt-4-32k": 32_768, "gpt-35-turbo": 16_385, "gpt-3.5-turbo": 16_385, + # O-series "o1": 200_000, "o1-mini": 128_000, "o1-preview": 128_000, "o3-mini": 200_000, + # Claude models (Copilot) + "claude-sonnet-4": 200_000, + "claude-sonnet-4.5": 200_000, + "claude-haiku-4.5": 200_000, + "claude-opus-4": 200_000, + # Gemini models (Copilot) + "gemini-2.0-flash": 1_048_576, + "gemini-2.5-pro": 1_048_576, } diff --git a/azext_prototype/azext_metadata.json b/azext_prototype/azext_metadata.json index 149c1d9..0553af0 100644 --- a/azext_prototype/azext_metadata.json +++ b/azext_prototype/azext_metadata.json @@ -2,7 +2,7 @@ "azext.isPreview": true, "azext.minCliCoreVersion": "2.50.0", "name": "prototype", - "version": "0.2.1b3", + "version": "0.2.1b4", "azext.summary": "Azure CLI extension for building rapid prototypes with GitHub Copilot.", "license": "MIT", "classifiers": [ diff --git a/azext_prototype/commands.py b/azext_prototype/commands.py index 784e00c..fe3445c 100644 --- a/azext_prototype/commands.py +++ b/azext_prototype/commands.py @@ -6,6 +6,7 @@ def load_command_table(self, _): with self.command_group("prototype", is_preview=True) as g: g.custom_command("init", "prototype_init") + g.custom_command("launch", "prototype_launch") g.custom_command("design", "prototype_design") g.custom_command("build", "prototype_build") g.custom_command("deploy", "prototype_deploy") diff --git a/azext_prototype/custom.py b/azext_prototype/custom.py index 279edde..07a2959 100644 --- a/azext_prototype/custom.py +++ b/azext_prototype/custom.py @@ -9,6 +9,7 @@ import json import logging import os +import signal from datetime import datetime, timezone from pathlib import Path @@ -366,6 +367,51 @@ def prototype_init( return result +# ====================================================================== +# TUI Launch +# ====================================================================== + + +def _run_tui(app) -> None: + """Run a Textual app with clean Ctrl+C handling. + + Suppresses SIGINT during the Textual run so that Ctrl+C is handled + exclusively as a key event by the Textual binding (``ctrl+c`` → + ``action_quit``). This prevents ``KeyboardInterrupt`` from + propagating to the Azure CLI framework and, on Windows, eliminates + the "Terminate batch job (Y/N)?" prompt from ``az.cmd``. + """ + prev = signal.getsignal(signal.SIGINT) + try: + signal.signal(signal.SIGINT, lambda *_: None) + app.run() + except KeyboardInterrupt: + pass # clean exit + finally: + signal.signal(signal.SIGINT, prev) + + +@_quiet_output +@track("prototype launch") +def prototype_launch(cmd, stage=None): + """Launch the interactive TUI dashboard. + + Auto-detects the current project stage and launches the appropriate + session inside a Textual terminal application. + """ + from azext_prototype.ui.app import PrototypeApp + + project_dir = _get_project_dir() + + # Verify project is initialized + if not (Path(project_dir) / "prototype.yaml").is_file(): + raise CLIError("Run 'az prototype init' first.") + + app = PrototypeApp(start_stage=stage, project_dir=project_dir) + _run_tui(app) + return {"status": "completed"} + + @_quiet_output @track("prototype design") def prototype_design( @@ -399,14 +445,15 @@ def prototype_design( user, who may accept the compliant recommendation or override the policy. Overrides are tracked in the design state. """ - from azext_prototype.stages.design_stage import DesignStage - from azext_prototype.stages.discovery_state import DiscoveryState - from azext_prototype.ui.console import console + # --status: keep existing CLI behavior (quick check, no TUI) + if status: + from azext_prototype.stages.discovery_state import DiscoveryState + from azext_prototype.ui.console import console - project_dir, config, registry, agent_context = _prepare_command() + project_dir = _get_project_dir() + if not (Path(project_dir) / "prototype.yaml").is_file(): + raise CLIError("Run 'az prototype init' first.") - # Handle --status flag: just display current state and exit - if status: discovery_state = DiscoveryState(project_dir) if discovery_state.exists: discovery_state.load() @@ -432,21 +479,34 @@ def prototype_design( return {"status": "displayed"} - stage = DesignStage() - _check_guards(stage) - - try: - return stage.execute( - agent_context, - registry, - artifacts=artifacts, - context=context, - reset=reset, - interactive=interactive, - skip_discovery=skip_discovery, - ) - finally: - _shutdown_mcp(agent_context) + project_dir = _get_project_dir() + if not (Path(project_dir) / "prototype.yaml").is_file(): + raise CLIError("Run 'az prototype init' first.") + + # Resolve artifacts path to absolute before TUI takes over + resolved_artifacts = str(Path(artifacts).resolve()) if artifacts else None + + stage_kwargs = {} + if resolved_artifacts: + stage_kwargs["artifacts"] = resolved_artifacts + if context: + stage_kwargs["context"] = context + if reset: + stage_kwargs["reset"] = True + if interactive: + stage_kwargs["interactive"] = True + if skip_discovery: + stage_kwargs["skip_discovery"] = True + + from azext_prototype.ui.app import PrototypeApp + + app = PrototypeApp( + start_stage="design", + project_dir=project_dir, + stage_kwargs=stage_kwargs, + ) + _run_tui(app) + return {"status": "completed"} @_quiet_output diff --git a/azext_prototype/governance/policies/azure/__init__.py b/azext_prototype/governance/policies/azure/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/azext_prototype/governance/policies/integration/__init__.py b/azext_prototype/governance/policies/integration/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/azext_prototype/governance/policies/security/__init__.py b/azext_prototype/governance/policies/security/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/azext_prototype/governance/standards/application/__init__.py b/azext_prototype/governance/standards/application/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/azext_prototype/governance/standards/bicep/__init__.py b/azext_prototype/governance/standards/bicep/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/azext_prototype/governance/standards/principles/__init__.py b/azext_prototype/governance/standards/principles/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/azext_prototype/governance/standards/terraform/__init__.py b/azext_prototype/governance/standards/terraform/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/azext_prototype/knowledge/languages/__init__.py b/azext_prototype/knowledge/languages/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/azext_prototype/knowledge/roles/__init__.py b/azext_prototype/knowledge/roles/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/azext_prototype/knowledge/services/__init__.py b/azext_prototype/knowledge/services/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/azext_prototype/knowledge/tools/__init__.py b/azext_prototype/knowledge/tools/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/azext_prototype/stages/backlog_session.py b/azext_prototype/stages/backlog_session.py index 6d02504..82c55c2 100644 --- a/azext_prototype/stages/backlog_session.py +++ b/azext_prototype/stages/backlog_session.py @@ -34,6 +34,7 @@ ) from azext_prototype.stages.backlog_state import BacklogState from azext_prototype.stages.escalation import EscalationTracker +from azext_prototype.stages.intent import IntentKind, build_backlog_classifier from azext_prototype.stages.qa_router import route_error_to_qa from azext_prototype.ui.console import Console, DiscoveryPrompt from azext_prototype.ui.console import console as default_console @@ -146,6 +147,12 @@ def __init__( if self._escalation_tracker.exists: self._escalation_tracker.load() + # Intent classifier for natural language command detection + self._intent_classifier = build_backlog_classifier( + ai_provider=agent_context.ai_provider, + token_tracker=self._token_tracker, + ) + # ------------------------------------------------------------------ # # Public API # ------------------------------------------------------------------ # @@ -314,6 +321,23 @@ def run( break continue + # Natural language intent detection (commands only — not /add) + intent = self._intent_classifier.classify(user_input) + if intent.kind == IntentKind.COMMAND: + cmd_line = f"{intent.command} {intent.args}".strip() + handled = self._handle_slash_command( + cmd_line, + provider, + org, + project, + _input, + _print, + use_styled, + ) + if handled == "pushed": + break + continue + # Natural language — send to AI for item mutation exchange += 1 with self._maybe_spinner("Updating backlog...", use_styled): @@ -792,6 +816,11 @@ def _handle_slash_command( _print(" 'Update story 3 to include MFA'") _print(" 'Remove story 7'") _print("") + _print(" You can also use natural language for commands:") + _print(" 'show all items' instead of /list") + _print(" 'push item 3' instead of /push 3") + _print(" 'show me item 2' instead of /show 2") + _print("") return None @@ -952,10 +981,16 @@ def _get_production_items(self) -> str: return "" @contextmanager - def _maybe_spinner(self, message: str, use_styled: bool) -> Iterator[None]: + def _maybe_spinner(self, message: str, use_styled: bool, *, status_fn: Callable | None = None) -> Iterator[None]: """Show a spinner when using styled output, otherwise no-op.""" if use_styled: with self._console.spinner(message): yield + elif status_fn: + status_fn(message, "start") + try: + yield + finally: + status_fn(message, "end") else: yield diff --git a/azext_prototype/stages/build_session.py b/azext_prototype/stages/build_session.py index cfaaec6..ab8a7f7 100644 --- a/azext_prototype/stages/build_session.py +++ b/azext_prototype/stages/build_session.py @@ -22,6 +22,7 @@ import json import logging import re +import shutil from contextlib import contextmanager from pathlib import Path from typing import Any, Callable, Iterator @@ -36,6 +37,11 @@ from azext_prototype.parsers.file_extractor import parse_file_blocks, write_parsed_files from azext_prototype.stages.build_state import BuildState from azext_prototype.stages.escalation import EscalationTracker +from azext_prototype.stages.intent import ( + IntentKind, + build_build_classifier, + read_files_for_session, +) from azext_prototype.stages.policy_resolver import PolicyResolver from azext_prototype.stages.qa_router import route_error_to_qa from azext_prototype.ui.console import Console, DiscoveryPrompt @@ -49,7 +55,7 @@ _QUIT_WORDS = frozenset({"q", "quit", "exit"}) _DONE_WORDS = frozenset({"done", "finish", "accept", "lgtm"}) -_SLASH_COMMANDS = frozenset({"/status", "/stages", "/files", "/policy", "/help"}) +_SLASH_COMMANDS = frozenset({"/status", "/stages", "/files", "/policy", "/describe", "/help"}) # Maximum remediation cycles per stage before proceeding _MAX_STAGE_REMEDIATION_ATTEMPTS = 2 @@ -167,6 +173,12 @@ def __init__( # Token tracker self._token_tracker = TokenTracker() + # Intent classifier for natural language command detection + self._intent_classifier = build_build_classifier( + ai_provider=agent_context.ai_provider, + token_tracker=self._token_tracker, + ) + # Project config config = ProjectConfig(agent_context.project_dir) config.load() @@ -240,9 +252,12 @@ def run( _print(f"IaC Tool: {self._iac_tool}") _print("") - # ---- Phase 2: Derive deployment plan ---- + # ---- Phase 2: Derive deployment plan (three-branch) ---- existing_stages = self._build_state._state.get("deployment_stages", []) + skip_generation = False + if not existing_stages: + # Branch A: First build — derive fresh plan and save design snapshot _print("Deriving deployment plan...") _print("") @@ -254,41 +269,126 @@ def run( return BuildResult(cancelled=True) self._build_state.set_deployment_plan(stages) - else: - _print("Resuming from existing deployment plan.") + self._build_state.set_design_snapshot(design) + + elif self._build_state.design_has_changed(design): + # Branch B: Design changed — incremental rebuild + _print("Design changes detected since last build.") _print("") - # Present the plan - _print(self._build_state.format_stage_status()) - _print("") - _print("Review the deployment plan above.") - _print("Press Enter to start building, or provide feedback to adjust.") - _print("") + old_arch = self._build_state.get_previous_architecture() - try: - if use_styled: - confirmation = self._prompt.simple_prompt("> ") + if old_arch: + with self._maybe_spinner("Analyzing design changes...", use_styled): + diff_result = self._diff_architectures(old_arch, architecture, existing_stages) else: - confirmation = _input("> ").strip() - except (EOFError, KeyboardInterrupt): - return BuildResult(cancelled=True) - - if confirmation.lower() in _QUIT_WORDS: - return BuildResult(cancelled=True) - - # If user provides feedback, adjust the plan - if confirmation and confirmation.lower() not in _DONE_WORDS and confirmation.strip(): - with self._maybe_spinner("Adjusting deployment plan...", use_styled): - adjusted = self._adjust_plan(confirmation, architecture, templates) - if adjusted: - self._build_state.set_deployment_plan(adjusted) + # Legacy build with no snapshot text — treat all as modified + diff_result = { + "unchanged": [], + "modified": [s["stage"] for s in existing_stages], + "removed": [], + "added": [], + "plan_restructured": False, + "summary": "No previous architecture snapshot — marking all stages for rebuild.", + } + + _print(f" {diff_result.get('summary', 'Changes analyzed.')}") + _print("") + + if diff_result.get("plan_restructured"): + _print("The design changes are significant enough to require a full plan re-derive.") + _print("Press Enter to re-derive the full plan, or type 'quit' to cancel.") _print("") - _print(self._build_state.format_stage_status()) + try: + if use_styled: + confirm = self._prompt.simple_prompt("> ") + else: + confirm = _input("> ").strip() + except (EOFError, KeyboardInterrupt): + return BuildResult(cancelled=True) + if confirm.lower() in _QUIT_WORDS: + return BuildResult(cancelled=True) + + with self._maybe_spinner("Re-deriving deployment plan...", use_styled): + stages = self._derive_deployment_plan(architecture, templates) + if not stages: + _print("Could not derive deployment plan from architecture.") + return BuildResult(cancelled=True) + self._build_state.set_deployment_plan(stages) + else: + # Apply targeted updates + removed = diff_result.get("removed", []) + added = diff_result.get("added", []) + modified = diff_result.get("modified", []) + + if removed: + self._clean_removed_stage_files(removed, existing_stages) + self._build_state.remove_stages(removed) + _print(f" Removed {len(removed)} stage(s).") + + if added: + self._build_state.add_stages(added) + _print(f" Added {len(added)} new stage(s).") + + if modified: + self._build_state.mark_stages_stale(modified) + _print(f" Marked {len(modified)} stage(s) for regeneration.") + + if removed or added: + self._fix_stage_dirs() + + # Update the design snapshot + self._build_state.set_design_snapshot(design) + + else: + # Branch C: No design changes + pending_check = self._build_state.get_pending_stages() + if pending_check: + _print("Resuming from existing deployment plan.") + _print("") + else: + _print("Build is up to date — no design changes detected.") _print("") + skip_generation = True + + # Present the plan + _print(self._build_state.format_stage_status()) + _print("") + + if not skip_generation: + _print("Review the deployment plan above.") + _print("Press Enter to start building, or provide feedback to adjust.") + _print("") + + try: + if use_styled: + confirmation = self._prompt.simple_prompt("> ") + else: + confirmation = _input("> ").strip() + except (EOFError, KeyboardInterrupt): + return BuildResult(cancelled=True) + + if confirmation.lower() in _QUIT_WORDS: + return BuildResult(cancelled=True) + + # If user provides feedback, adjust the plan + if confirmation and confirmation.lower() not in _DONE_WORDS and confirmation.strip(): + with self._maybe_spinner("Adjusting deployment plan...", use_styled): + adjusted = self._adjust_plan(confirmation, architecture, templates) + if adjusted: + self._build_state.set_deployment_plan(adjusted) + _print("") + _print(self._build_state.format_stage_status()) + _print("") # ---- Phase 3: Staged generation ---- - pending = self._build_state.get_pending_stages() - total_stages = len(self._build_state._state["deployment_stages"]) + if skip_generation: + pending = [] + total_stages = len(self._build_state._state["deployment_stages"]) + generated_count = total_stages + else: + pending = self._build_state.get_pending_stages() + total_stages = len(self._build_state._state["deployment_stages"]) generated_count = len(self._build_state.get_generated_stages()) for stage in pending: @@ -413,7 +513,7 @@ def run( _print("") # ---- Phase 4: Advisory QA review ---- - if scope == "all" and self._qa_agent: + if not skip_generation and scope == "all" and self._qa_agent: _print("Running advisory review...") file_content = self._collect_generated_file_content() @@ -487,9 +587,10 @@ def run( _print("") # ---- Phase 5: Build report ---- - _print("") - _print(self._build_state.format_build_report()) - _print("") + if not skip_generation: + _print("") + _print(self._build_state.format_build_report()) + _print("") # ---- Phase 6: Review loop ---- _print("Review the build output above.") @@ -516,10 +617,24 @@ def run( lower = user_input.lower() # Slash commands - if lower in _SLASH_COMMANDS: + if lower.startswith("/"): self._handle_slash_command(lower, _print) continue + # Natural language intent detection + intent = self._intent_classifier.classify(user_input) + if intent.kind == IntentKind.COMMAND: + if intent.command == "/describe" and intent.args: + self._handle_describe(intent.args, _print) + else: + self._handle_slash_command(intent.command, _print) + continue + if intent.kind == IntentKind.READ_FILES: + text, _ = read_files_for_session(intent.args, self._context.project_dir, _print) + if text: + user_input = f"{user_input}\n\n## File Content\n{text}" + # Fall through to feedback handler with enriched input + if lower in _QUIT_WORDS: return BuildResult( files_generated=self._build_state._state.get("files_generated", []), @@ -652,6 +767,11 @@ def _derive_deployment_plan( "Each stage must have: stage (number), name, category " "(infra|data|app|schema|integration|docs|cicd|external), dir (output " f"directory path), services (array), status ('pending'), files (empty array).\n\n" + "Optional per-stage fields:\n" + "- deploy_mode: 'auto' (default, deploy via IaC/scripts) or 'manual' " + "(step that cannot be scripted, e.g., portal configuration)\n" + "- manual_instructions: when deploy_mode is 'manual', provide clear " + "step-by-step instructions for the user\n\n" f"Use '{self._iac_tool}' for IaC directories. Infrastructure stage dirs " f"should be like: concept/infra/{self._iac_tool}/stage-N-name/\n" "App stage dirs: concept/apps/stage-N-name/\n" @@ -714,17 +834,18 @@ def _normalise_stages(self, stages: list[dict]) -> list[dict]: for s in stages: if not isinstance(s, dict): continue - normalised.append( - { - "stage": s.get("stage", len(normalised) + 1), - "name": s.get("name", f"Stage {len(normalised) + 1}"), - "category": s.get("category", "infra"), - "dir": s.get("dir", ""), - "services": s.get("services", []), - "status": "pending", - "files": [], - } - ) + entry = { + "stage": s.get("stage", len(normalised) + 1), + "name": s.get("name", f"Stage {len(normalised) + 1}"), + "category": s.get("category", "infra"), + "dir": s.get("dir", ""), + "services": s.get("services", []), + "status": "pending", + "files": [], + "deploy_mode": s.get("deploy_mode", "auto"), + "manual_instructions": s.get("manual_instructions"), + } + normalised.append(entry) return normalised def _fallback_deployment_plan(self, templates: list) -> list[dict]: @@ -940,6 +1061,180 @@ def _adjust_plan( return self._parse_deployment_plan(response.content) return None + # ------------------------------------------------------------------ # + # Internal — incremental rebuild helpers + # ------------------------------------------------------------------ # + + def _diff_architectures( + self, + old_arch: str, + new_arch: str, + existing_stages: list[dict], + ) -> dict: + """Ask the architect to compare old and new architectures. + + Returns a dict classifying each existing stage as unchanged, + modified, or removed, plus any new stages to add. + + Falls back to marking all stages as modified when the architect + is unavailable or the response cannot be parsed. + """ + all_modified_fallback: dict = { + "unchanged": [], + "modified": [s["stage"] for s in existing_stages], + "removed": [], + "added": [], + "plan_restructured": False, + "summary": "Could not analyze changes — marking all stages for rebuild.", + } + + if not self._architect_agent or not self._context.ai_provider: + return all_modified_fallback + + stage_info = json.dumps( + [ + { + "stage": s["stage"], + "name": s["name"], + "category": s.get("category", "infra"), + "services": [svc.get("name", "") for svc in s.get("services", [])], + } + for s in existing_stages + ], + indent=2, + ) + + task = ( + "Compare the OLD and NEW architecture designs and determine how each " + "existing deployment stage is affected.\n\n" + f"## Old Architecture\n{old_arch}\n\n" + f"## New Architecture\n{new_arch}\n\n" + f"## Existing Deployment Stages\n```json\n{stage_info}\n```\n\n" + "## Instructions\n" + "Classify each stage number as:\n" + "- **unchanged**: no impact from the design changes\n" + "- **modified**: services or configuration in this stage changed\n" + "- **removed**: the services in this stage no longer exist in the new design\n\n" + "Also identify any NEW services that need new stages.\n\n" + "Set `plan_restructured: true` ONLY if the fundamental deployment " + "order or stage boundaries need to change (e.g., services moved between " + "stages, major dependency changes). Minor additions/removals should NOT " + "set this flag.\n\n" + "Return ONLY valid JSON:\n" + "```json\n" + "{\n" + ' "unchanged": [1, 2],\n' + ' "modified": [3],\n' + ' "removed": [4],\n' + ' "added": [{"name": "Redis Cache", "category": "data", "services": ' + '[{"name": "redis-cache", "computed_name": "", "resource_type": ' + '"Microsoft.Cache/redis", "sku": "Basic"}]}],\n' + ' "plan_restructured": false,\n' + ' "summary": "Added Redis cache; modified API to use Redis"\n' + "}\n" + "```\n" + ) + + try: + response = self._architect_agent.execute(self._context, task) + if response: + self._token_tracker.record(response) + if response and response.content: + result = self._parse_diff_result(response.content, existing_stages) + if result: + return result + except Exception: + logger.debug("Architecture diff failed", exc_info=True) + + return all_modified_fallback + + def _parse_diff_result(self, content: str, existing_stages: list[dict]) -> dict | None: + """Parse the architect's diff response into a structured result. + + Validates that referenced stage numbers actually exist. Stages + not mentioned by the architect default to ``unchanged``. + """ + # Try fenced JSON block first + json_match = re.search(r"```(?:json)?\s*\n(.*?)\n```", content, re.DOTALL) + raw = json_match.group(1) if json_match else content.strip() + + try: + data = json.loads(raw) + except (json.JSONDecodeError, TypeError): + return None + + if not isinstance(data, dict): + return None + + existing_nums = {s["stage"] for s in existing_stages} + + unchanged = [n for n in data.get("unchanged", []) if isinstance(n, int) and n in existing_nums] + modified = [n for n in data.get("modified", []) if isinstance(n, int) and n in existing_nums] + removed = [n for n in data.get("removed", []) if isinstance(n, int) and n in existing_nums] + + # Stages not mentioned default to unchanged + mentioned = set(unchanged) | set(modified) | set(removed) + for num in existing_nums: + if num not in mentioned: + unchanged.append(num) + + added = data.get("added", []) + if not isinstance(added, list): + added = [] + # Normalise added stages + normalised_added = [] + for item in added: + if isinstance(item, dict) and item.get("name"): + normalised_added.append( + { + "name": item["name"], + "category": item.get("category", "infra"), + "services": item.get("services", []), + "dir": item.get("dir", ""), + } + ) + + return { + "unchanged": sorted(unchanged), + "modified": sorted(modified), + "removed": sorted(removed), + "added": normalised_added, + "plan_restructured": bool(data.get("plan_restructured", False)), + "summary": data.get("summary", "Design changes analyzed."), + } + + def _clean_removed_stage_files(self, removed_nums: list[int], stages: list[dict]) -> None: + """Delete generated directories from disk for removed stages.""" + project_root = Path(self._context.project_dir) + for stage in stages: + if stage["stage"] in removed_nums: + stage_dir = stage.get("dir", "") + if stage_dir: + full_path = project_root / stage_dir + if full_path.exists() and full_path.is_dir(): + shutil.rmtree(full_path, ignore_errors=True) + logger.info("Removed stage directory: %s", full_path) + + def _fix_stage_dirs(self) -> None: + """Update stage directory paths to match current stage numbers. + + After renumbering, stage dirs like ``stage-4-redis`` may need to + become ``stage-3-redis`` if a prior stage was removed. + """ + for stage in self._build_state._state.get("deployment_stages", []): + old_dir = stage.get("dir", "") + if not old_dir: + continue + # Match pattern: .../stage-N-name + match = re.match(r"^(.*?/?)stage-\d+(-.*)?$", old_dir) + if match: + prefix = match.group(1) + suffix = match.group(2) or "" + new_dir = f"{prefix}stage-{stage['stage']}{suffix}" + if new_dir != old_dir: + stage["dir"] = new_dir + self._build_state.save() + # ------------------------------------------------------------------ # # Internal — stage generation # ------------------------------------------------------------------ # @@ -1332,34 +1627,86 @@ def _identify_stages_regex(self, feedback: str) -> list[int]: def _handle_slash_command(self, command: str, _print: Callable) -> None: """Handle build-session slash commands.""" - if command == "/status": - _print("") - _print(self._build_state.format_stage_status()) - _print("") - elif command == "/stages": + parts = command.strip().split(maxsplit=1) + cmd = parts[0] + arg = parts[1].strip() if len(parts) > 1 else "" + + if cmd in ("/status", "/stages"): _print("") _print(self._build_state.format_stage_status()) _print("") - elif command == "/files": + elif cmd == "/files": _print("") _print(self._build_state.format_files_list()) _print("") - elif command == "/policy": + elif cmd == "/policy": _print("") _print(self._build_state.format_policy_summary()) _print("") - elif command == "/help": + elif cmd == "/describe": + self._handle_describe(arg, _print) + elif cmd == "/help": _print("") _print("Available commands:") - _print(" /status - Show stage completion summary") - _print(" /stages - Show full deployment plan") - _print(" /files - List all generated files") - _print(" /policy - Show policy check summary") - _print(" /help - Show this help") - _print(" done - Accept build and exit") - _print(" quit - Cancel and exit") + _print(" /status - Show stage completion summary") + _print(" /stages - Show full deployment plan") + _print(" /files - List all generated files") + _print(" /policy - Show policy check summary") + _print(" /describe N - Show details for stage N") + _print(" /help - Show this help") + _print(" done - Accept build and exit") + _print(" quit - Cancel and exit") + _print("") + _print(" You can also use natural language:") + _print(" 'what's the build status' instead of /status") + _print(" 'show the generated files' instead of /files") + _print(" 'describe stage 2' instead of /describe 2") _print("") + def _handle_describe(self, arg: str, _print: Callable) -> None: + """Show detailed description of a build stage.""" + if not arg or not arg.strip(): + _print(" Usage: /describe N (stage number)") + return + + numbers = re.findall(r"\d+", arg) + if not numbers: + _print(" Usage: /describe N (stage number)") + return + + stage_num = int(numbers[0]) + stage = self._build_state.get_stage(stage_num) + if not stage: + _print(f" Stage {stage_num} not found.") + return + + _print("") + _print(f" Stage {stage_num}: {stage.get('name', '?')}") + _print(f" Category: {stage.get('category', '?')}") + _print(f" Status: {stage.get('status', 'pending')}") + _print(f" Dir: {stage.get('dir', '?')}") + + services = stage.get("services", []) + if services: + _print(f" Resources ({len(services)}):") + for svc in services: + name = svc.get("computed_name") or svc.get("name", "?") + rtype = svc.get("resource_type", "") + sku = svc.get("sku", "") + line = f" - {name}" + if rtype: + line += f" ({rtype})" + if sku: + line += f" [{sku}]" + _print(line) + + files = stage.get("files", []) + if files: + _print(f" Files ({len(files)}):") + for f in files: + _print(f" - {f}") + _print("") + # ------------------------------------------------------------------ # # Internal — utilities # ------------------------------------------------------------------ # @@ -1533,10 +1880,16 @@ def _collect_generated_file_content(self, max_bytes: int = 50_000) -> str: return "\n\n".join(parts) @contextmanager - def _maybe_spinner(self, message: str, use_styled: bool) -> Iterator[None]: + def _maybe_spinner(self, message: str, use_styled: bool, *, status_fn: Callable | None = None) -> Iterator[None]: """Show a spinner when using styled output, otherwise no-op.""" if use_styled: with self._console.spinner(message): yield + elif status_fn: + status_fn(message, "start") + try: + yield + finally: + status_fn(message, "end") else: yield diff --git a/azext_prototype/stages/build_state.py b/azext_prototype/stages/build_state.py index 168f270..71b6aa0 100644 --- a/azext_prototype/stages/build_state.py +++ b/azext_prototype/stages/build_state.py @@ -18,7 +18,9 @@ from __future__ import annotations +import hashlib import logging +import re from datetime import datetime, timezone from pathlib import Path from typing import Any @@ -27,6 +29,29 @@ logger = logging.getLogger(__name__) + +def _slugify(name: str) -> str: + """Convert a stage name to a URL-safe slug for use as a stable ID. + + Example: "Data Layer" → "data-layer" + """ + slug = name.lower().strip() + slug = re.sub(r"[^a-z0-9]+", "-", slug) + slug = slug.strip("-") + return slug or "stage" + + +def _ensure_unique_id(slug: str, existing: set[str]) -> str: + """Append a numeric suffix if *slug* already exists in *existing*.""" + if slug not in existing: + return slug + for i in range(2, 1000): + candidate = f"{slug}-{i}" + if candidate not in existing: + return candidate + return f"{slug}-{len(existing)}" + + BUILD_STATE_FILE = ".prototype/state/build.yaml" @@ -44,6 +69,11 @@ def _default_build_state() -> dict[str, Any]: "review_decisions": [], "conversation_history": [], "resources": [], + "design_snapshot": { + "iteration": None, + "architecture_hash": None, + "architecture_text": None, + }, "_metadata": { "created": None, "last_updated": None, @@ -91,6 +121,7 @@ def load(self) -> dict[str, Any]: loaded = yaml.safe_load(f) or {} self._state = _default_build_state() self._deep_merge(self._state, loaded) + self._backfill_ids() self._loaded = True logger.info("Loaded build state from %s", self._path) except (yaml.YAMLError, IOError) as e: @@ -154,6 +185,7 @@ def set_deployment_plan(self, stages: list[dict]) -> None: } """ self._state["deployment_stages"] = stages + self._assign_stable_ids() # Rebuild the aggregated resources list self._rebuild_resources() self.save() @@ -210,6 +242,111 @@ def get_stage(self, stage_num: int) -> dict | None: return stage return None + def get_stage_by_id(self, stage_id: str) -> dict | None: + """Return a specific stage by its stable ``id``.""" + for stage in self._state["deployment_stages"]: + if stage.get("id") == stage_id: + return stage + return None + + # ------------------------------------------------------------------ # + # Design snapshot — change detection for incremental rebuilds + # ------------------------------------------------------------------ # + + def set_design_snapshot(self, design: dict) -> None: + """Store a snapshot of the current design for future change detection. + + Captures the design iteration number, a content hash of the + architecture text, and the full architecture text for diffing. + """ + architecture = design.get("architecture", "") + self._state["design_snapshot"] = { + "iteration": design.get("_metadata", {}).get("iteration"), + "architecture_hash": hashlib.sha256(architecture.encode("utf-8")).hexdigest()[:16], + "architecture_text": architecture, + } + self.save() + + def design_has_changed(self, design: dict) -> bool: + """Check whether the design has changed since the last build. + + Returns ``True`` when the architecture content hash differs from + the stored snapshot, or when no snapshot exists (legacy builds). + """ + snapshot = self._state.get("design_snapshot", {}) + stored_hash = snapshot.get("architecture_hash") + if not stored_hash: + return True + + architecture = design.get("architecture", "") + current_hash = hashlib.sha256(architecture.encode("utf-8")).hexdigest()[:16] + return current_hash != stored_hash + + def get_previous_architecture(self) -> str | None: + """Return the stored architecture text from the last build, if any.""" + snapshot = self._state.get("design_snapshot", {}) + return snapshot.get("architecture_text") + + def mark_stages_stale(self, stage_nums: list[int]) -> None: + """Reset specific stages to ``pending`` without clearing their files. + + This allows the generation phase to re-generate only these stages + while preserving previously generated work on unaffected stages. + """ + for stage in self._state["deployment_stages"]: + if stage["stage"] in stage_nums: + stage["status"] = "pending" + self.save() + + def remove_stages(self, stage_nums: list[int]) -> None: + """Remove stages by number and clean up file references.""" + nums_set = set(stage_nums) + removed_files: list[str] = [] + for stage in self._state["deployment_stages"]: + if stage["stage"] in nums_set: + removed_files.extend(stage.get("files", [])) + + self._state["deployment_stages"] = [s for s in self._state["deployment_stages"] if s["stage"] not in nums_set] + + # Remove from files_generated + if removed_files: + removed_set = set(removed_files) + self._state["files_generated"] = [f for f in self._state["files_generated"] if f not in removed_set] + + self._rebuild_resources() + self.save() + + def add_stages(self, new_stages: list[dict]) -> None: + """Insert new stages before the docs stage and assign sequential numbers. + + New stages are inserted just before the last documentation stage + (if one exists), otherwise appended at the end. + """ + existing = self._state["deployment_stages"] + + # Find insertion point — before the docs stage + insert_idx = len(existing) + for i, s in enumerate(existing): + if s.get("category") == "docs": + insert_idx = i + break + + for ns in new_stages: + ns.setdefault("status", "pending") + ns.setdefault("files", []) + ns.setdefault("dir", "") + existing.insert(insert_idx, ns) + insert_idx += 1 + + self._assign_stable_ids() + self.renumber_stages() + + def renumber_stages(self) -> None: + """Renumber all stages sequentially starting from 1.""" + for idx, stage in enumerate(self._state["deployment_stages"], start=1): + stage["stage"] = idx + self.save() + # ------------------------------------------------------------------ # # Policy tracking # ------------------------------------------------------------------ # @@ -460,6 +597,31 @@ def format_policy_summary(self) -> str: # Internals # ------------------------------------------------------------------ # + def _assign_stable_ids(self) -> None: + """Ensure every deployment stage has a unique ``id`` field. + + Stages that already have an ``id`` keep it. Stages without one + get an ID derived from :func:`_slugify` on their name. + """ + existing_ids: set[str] = set() + for stage in self._state["deployment_stages"]: + sid = stage.get("id") + if sid: + existing_ids.add(sid) + + for stage in self._state["deployment_stages"]: + if not stage.get("id"): + slug = _slugify(stage.get("name", "stage")) + stage["id"] = _ensure_unique_id(slug, existing_ids) + existing_ids.add(stage["id"]) + # Ensure deploy_mode defaults + stage.setdefault("deploy_mode", "auto") + stage.setdefault("manual_instructions", None) + + def _backfill_ids(self) -> None: + """Backfill ``id``, ``deploy_mode``, and ``manual_instructions`` on legacy state files.""" + self._assign_stable_ids() + def _deep_merge(self, base: dict, updates: dict) -> None: """Deep merge updates into base dict.""" for key, value in updates.items(): diff --git a/azext_prototype/stages/deploy_helpers.py b/azext_prototype/stages/deploy_helpers.py index a3f682e..8eca1a6 100644 --- a/azext_prototype/stages/deploy_helpers.py +++ b/azext_prototype/stages/deploy_helpers.py @@ -1182,12 +1182,19 @@ def get_rollback_instructions(self, scope: str = "all") -> list[str]: return instructions - def snapshot_stage(self, stage_num: int, scope: str, iac_tool: str) -> dict: + def snapshot_stage( + self, + stage_num: int, + scope: str, + iac_tool: str, + build_stage_id: str | None = None, + ) -> dict: """Record per-stage pre-deployment snapshot.""" snapshot = { "stage": stage_num, "scope": scope, "iac_tool": iac_tool, + "build_stage_id": build_stage_id, "timestamp": datetime.now(timezone.utc).isoformat(), } if "stage_snapshots" not in self._state: diff --git a/azext_prototype/stages/deploy_session.py b/azext_prototype/stages/deploy_session.py index 1eeb110..5d76471 100644 --- a/azext_prototype/stages/deploy_session.py +++ b/azext_prototype/stages/deploy_session.py @@ -19,6 +19,7 @@ from __future__ import annotations import logging +import re import subprocess from contextlib import contextmanager from pathlib import Path @@ -28,6 +29,7 @@ from azext_prototype.agents.registry import AgentRegistry from azext_prototype.ai.token_tracker import TokenTracker from azext_prototype.config import ProjectConfig +from azext_prototype.parsers.file_extractor import parse_file_blocks, write_parsed_files from azext_prototype.stages.deploy_helpers import ( DeploymentOutputCapture, RollbackManager, @@ -48,6 +50,7 @@ ) from azext_prototype.stages.deploy_state import DeployState from azext_prototype.stages.escalation import EscalationTracker +from azext_prototype.stages.intent import IntentKind, build_deploy_classifier from azext_prototype.stages.qa_router import route_error_to_qa from azext_prototype.tracking import ChangeTracker from azext_prototype.ui.console import Console, DiscoveryPrompt @@ -55,6 +58,14 @@ logger = logging.getLogger(__name__) +# Maximum auto-remediation cycles per stage before falling through to interactive +_MAX_DEPLOY_REMEDIATION_ATTEMPTS = 2 + +# Files that should never be written for each IaC tool. +_BLOCKED_FILES: dict[str, set[str]] = { + "terraform": {"versions.tf"}, +} + def _lookup_deployer_object_id(client_id: str | None = None) -> str | None: """Resolve the AAD object ID of the deployer. @@ -95,6 +106,7 @@ def _lookup_deployer_object_id(client_id: str | None = None) -> str | None: "/rollback", "/redeploy", "/plan", + "/describe", "/outputs", "/preflight", "/login", @@ -176,6 +188,19 @@ def __init__( qa_agents = registry.find_by_capability(AgentCapability.QA) self._qa_agent = qa_agents[0] if qa_agents else None + # Resolve IaC, dev, and architect agents for remediation + self._iac_agents: dict[str, Any] = {} + for cap, key in [(AgentCapability.TERRAFORM, "terraform"), (AgentCapability.BICEP, "bicep")]: + agents = registry.find_by_capability(cap) + if agents: + self._iac_agents[key] = agents[0] + + dev_agents = registry.find_by_capability(AgentCapability.DEVELOP) + self._dev_agent = dev_agents[0] if dev_agents else None + + architect_agents = registry.find_by_capability(AgentCapability.ARCHITECT) + self._architect_agent = architect_agents[0] if architect_agents else None + # Escalation tracker self._escalation_tracker = EscalationTracker(agent_context.project_dir) if self._escalation_tracker.exists: @@ -190,6 +215,12 @@ def __init__( # Token tracker self._token_tracker = TokenTracker() + # Intent classifier for natural language command detection + self._intent_classifier = build_deploy_classifier( + ai_provider=agent_context.ai_provider, + token_tracker=self._token_tracker, + ) + # Deployment helpers self._output_capture = DeploymentOutputCapture(agent_context.project_dir) self._rollback_mgr = RollbackManager(agent_context.project_dir) @@ -280,11 +311,22 @@ def run( _print = print_fn or self._console.print # ---- Phase 1: Load build state ---- + build_path = Path(self._context.project_dir) / ".prototype" / "state" / "build.yaml" if not self._deploy_state._state["deployment_stages"]: - build_path = Path(self._context.project_dir) / ".prototype" / "state" / "build.yaml" if not self._deploy_state.load_from_build_state(build_path): _print(" No build state found. Run 'az prototype build' first.") return DeployResult(cancelled=True) + else: + # Re-entry: sync with latest build state + sync = self._deploy_state.sync_from_build_state(build_path) + if sync.created or sync.orphaned or sync.updated_code: + _print("") + _print(" Build state changed since last deploy:") + for detail in sync.details: + _print(f" - {detail}") + if sync.updated_code: + _print(f" - {sync.updated_code} deployed stage(s) have updated code") + _print("") # Resolve subscription / resource group / tenant / SP creds self._resolve_context(subscription, tenant, client_id, client_secret) @@ -387,7 +429,34 @@ def run( ) continue - _print(" Use slash commands to manage deployment. Type /help for a list.") + # Natural language intent detection + intent = self._intent_classifier.classify(user_input) + if intent.kind == IntentKind.COMMAND: + # Multi-stage support: "deploy stages 3 and 4" + import re as _re + + numbers = _re.findall(r"\d+", intent.args) if intent.args else [] + if len(numbers) > 1 and intent.command in ("/deploy", "/rollback"): + for num in numbers: + self._handle_slash_command( + f"{intent.command} {num}", + force, + use_styled, + _print, + _input, + ) + else: + cmd_line = f"{intent.command} {intent.args}".strip() + self._handle_slash_command( + cmd_line, + force, + use_styled, + _print, + _input, + ) + continue + + _print(" Type /help for commands, or use natural language (e.g. 'deploy stage 3').") return self._build_result() @@ -508,6 +577,17 @@ def run_single_stage( else: _print(f" Stage {stage_num} failed: {result.get('error', 'unknown error')}") + # Attempt non-interactive remediation + remediated = self._remediate_deploy_failure( + stage, + result, + False, + _print, + lambda p: "", + ) + if remediated and remediated.get("status") == "deployed": + _print(f" Stage {stage_num} deployed after remediation.") + return self._build_result() # ------------------------------------------------------------------ # @@ -799,10 +879,30 @@ def _deploy_pending_stages( # Capture outputs after infra stages if category in ("infra", "data", "integration"): self._capture_stage_outputs(stage) + elif result.get("status") == "awaiting_manual": + instructions = result.get("instructions", "No instructions provided.") + _print(" Manual step required:") + _print(f" {instructions}") + _print("") + _print(" When complete, enter: Done / Skip / Need help") + try: + answer = _input(" > ").strip().lower() + except (EOFError, KeyboardInterrupt): + _print(" Skipped.") + continue + if answer in ("done", "d", "yes", "y"): + self._deploy_state.mark_stage_deployed(stage_num) + _print(" Marked as deployed.") + elif answer in ("skip", "s"): + _print(" Skipped. Use /deploy to continue later.") + else: + _print(" Pausing deployment. Use /deploy to continue.") + break elif result.get("status") == "failed": _print(f" Failed: {result.get('error', 'unknown error')[:120]}") - self._handle_deploy_failure(stage, result, use_styled, _print, _input) - break # Stop sequential deployment — user decides via interactive loop + remediated = self._handle_deploy_failure(stage, result, use_styled, _print, _input) + if remediated.get("status") != "deployed": + break # Stop sequential deployment — user decides via interactive loop else: _print(f" Skipped: {result.get('reason', 'no action needed')}") @@ -812,13 +912,24 @@ def _deploy_single_stage(self, stage: dict[str, Any]) -> dict[str, Any]: """Deploy one stage and update state.""" stage_num = stage["stage"] category = stage.get("category", "infra") + deploy_mode = stage.get("deploy_mode", "auto") + + # Manual steps don't execute — they return a special status + if deploy_mode == "manual": + self._deploy_state.mark_stage_awaiting_manual(stage_num) + return { + "status": "awaiting_manual", + "instructions": stage.get("manual_instructions", "No instructions provided."), + } + stage_dir = Path(self._context.project_dir) / stage.get("dir", "") if not stage_dir.is_dir(): return {"status": "skipped", "reason": f"Directory not found: {stage.get('dir', '?')}"} # Snapshot before deploy - self._rollback_mgr.snapshot_stage(stage_num, category, self._iac_tool) + build_stage_id = stage.get("build_stage_id") + self._rollback_mgr.snapshot_stage(stage_num, category, self._iac_tool, build_stage_id=build_stage_id) self._deploy_state.mark_stage_deploying(stage_num) # Resolve generated secrets for Terraform stages (TF_VAR_* env vars) @@ -890,36 +1001,550 @@ def _handle_deploy_failure( use_styled: bool, _print: Callable[[str], None], _input: Callable[[str], str], - ) -> None: - """Route deployment failure to QA agent for diagnosis.""" - error_text = result.get("error", "Unknown error") - stage_info = f"Stage {stage['stage']}: {stage['name']}" + ) -> dict[str, Any]: + """Attempt auto-remediation of a deploy failure, falling through to interactive options. + Returns the final deploy result for the stage (may be ``"deployed"`` if + remediation succeeds, or the original failure if remediation is exhausted + or unavailable). + """ + # Attempt auto-remediation when agents are available + remediated = self._remediate_deploy_failure(stage, result, use_styled, _print, _input) + if remediated and remediated.get("status") == "deployed": + return remediated + + # Remediation not attempted — fall back to QA-only diagnosis + if remediated is None: + error_text = result.get("error", "Unknown error") + stage_info = f"Stage {stage['stage']}: {stage['name']}" + services = stage.get("services", []) + svc_names = [s.get("name", "") for s in services if s.get("name")] + + qa_result = route_error_to_qa( + error_text, + f"Deploy {stage_info}", + self._qa_agent, + self._context, + self._token_tracker, + _print, + services=svc_names, + escalation_tracker=self._escalation_tracker, + source_agent="deploy-session", + source_stage="deploy", + ) + + if not qa_result["diagnosed"]: + _print("") + _print(f" Error: {error_text[:500]}") + + if use_styled and qa_result.get("response"): + self._console.print_token_status(self._token_tracker.format_status()) + + # Show interactive options + _print("") + _print(" Options: /deploy (retry) | /rollback (undo) | /help | quit") + return remediated or result + + # ------------------------------------------------------------------ # + # Internal — Deploy failure remediation + # ------------------------------------------------------------------ # + + def _remediate_deploy_failure( + self, + stage: dict[str, Any], + result: dict[str, Any], + use_styled: bool, + _print: Callable[[str], None], + _input: Callable[[str], str], + ) -> dict[str, Any] | None: + """Closed-loop remediation: QA diagnoses -> architect guides -> IaC/dev fixes -> redeploy. + + Returns the final deploy result, or ``None`` if remediation cannot be + attempted (no agents / no AI provider). + """ + # Guard: need at minimum QA + one fix agent + AI provider + has_fix_agent = bool(self._iac_agents.get(self._iac_tool) or self._dev_agent) + if not self._qa_agent or not has_fix_agent or not self._context.ai_provider: + return None + + error_text = result.get("error", "Unknown error") + stage_num = stage["stage"] + stage_info = f"Stage {stage_num}: {stage['name']}" services = stage.get("services", []) svc_names = [s.get("name", "") for s in services if s.get("name")] + final_result = result + + for attempt in range(1, _MAX_DEPLOY_REMEDIATION_ATTEMPTS + 1): + current_attempts = stage.get("remediation_attempts", 0) + if current_attempts >= _MAX_DEPLOY_REMEDIATION_ATTEMPTS: + _print(f" Auto-remediation exhausted ({current_attempts} attempts) for {stage_info}.") + break + + # 1. QA diagnosis + qa_result = route_error_to_qa( + error_text, + f"Deploy {stage_info}", + self._qa_agent, + self._context, + self._token_tracker, + _print, + services=svc_names, + escalation_tracker=self._escalation_tracker, + source_agent="deploy-session", + source_stage="deploy", + ) + + if not qa_result["diagnosed"]: + _print("") + _print(f" Error: {error_text[:500]}") + break + + qa_diagnosis = qa_result.get("content", "") + + # 2. Architect fix guidance + architect_guidance = self._get_architect_fix_guidance(stage, error_text, qa_diagnosis) + + # 3. Mark stage as remediating + self._deploy_state.mark_stage_remediating(stage_num) + _print(f" Remediating {stage_info} (attempt {attempt})...") + + # 4. Build and execute fix task + agent, task = self._build_fix_task(stage, error_text, qa_diagnosis, architect_guidance) + if not agent: + _print(" No suitable agent available for remediation.") + break + + with self._maybe_spinner(f"Fixing {stage_info}...", use_styled): + try: + fix_response = agent.execute(self._context, task) + except Exception: + logger.debug("Fix agent failed during remediation", exc_info=True) + _print(" Fix agent encountered an error.") + break + + if fix_response: + self._token_tracker.record(fix_response) + + fix_content = fix_response.content if fix_response else "" + if not fix_content: + _print(" Fix agent returned no content.") + break + + # 5. Write fixed files + written = self._write_stage_files(stage, fix_content) + if written: + _print(f" Wrote {len(written)} file(s): {', '.join(Path(f).name for f in written[:5])}") + else: + _print(" No file blocks found in fix response.") + break - qa_result = route_error_to_qa( - error_text, - f"Deploy {stage_info}", - self._qa_agent, - self._context, - self._token_tracker, - _print, - services=svc_names, - escalation_tracker=self._escalation_tracker, - source_agent="deploy-session", - source_stage="deploy", + # 6. Check downstream impact + downstream = self._check_downstream_impact(stage, architect_guidance) + + # 7. Reset and re-deploy + self._deploy_state.reset_stage_to_pending(stage_num) + + with self._maybe_spinner(f"Re-deploying {stage_info}...", use_styled): + final_result = self._deploy_single_stage(stage) + + if final_result.get("status") == "deployed": + _print(f" {stage_info} deployed successfully after remediation.") + + # Capture outputs for infra stages + if stage.get("category") in ("infra", "data", "integration"): + self._capture_stage_outputs(stage) + + # Regenerate downstream stages if needed + if downstream: + self._regenerate_downstream_stages(downstream, use_styled, _print) + + return final_result + + # Failed again — loop for next attempt + error_text = final_result.get("error", "unknown error") + _print(f" Re-deploy failed: {error_text[:120]}") + + return final_result + + def _collect_stage_file_content(self, stage: dict, max_bytes: int = 20_000) -> str: + """Collect content of generated files for a single deploy stage. + + Falls back to globbing the stage directory when the ``files`` list is + empty (deploy stages may not always have files tracked). + """ + project_root = Path(self._context.project_dir) + parts: list[str] = [] + total = 0 + + files = stage.get("files", []) + + # Fallback: glob the stage directory for common IaC/app file types + if not files: + stage_dir = project_root / stage.get("dir", "") + if stage_dir.is_dir(): + for pattern in ("*.tf", "*.bicep", "*.sh", "*.py", "*.cs", "*.json", "*.yaml"): + for f in stage_dir.glob(pattern): + try: + rel = str(f.relative_to(project_root)) + files.append(rel) + except ValueError: + files.append(str(f)) + + if not files: + return "" + + for filepath in files: + if total >= max_bytes: + parts.append("\n(remaining files omitted — size cap reached)") + break + + full_path = project_root / filepath + try: + content = full_path.read_text(encoding="utf-8") + except (OSError, UnicodeDecodeError): + parts.append(f"```{filepath}\n(could not read file)\n```") + continue + + per_file_cap = 8_000 + if len(content) > per_file_cap: + content = content[:per_file_cap] + "\n... (truncated)" + + block = f"```{filepath}\n{content}\n```" + total += len(block) + parts.append(block) + + return "\n\n".join(parts) + + def _build_fix_task( + self, + stage: dict, + deploy_error: str, + qa_diagnosis: str, + architect_guidance: str, + ) -> tuple[Any | None, str]: + """Build a fix prompt for the IaC/dev agent and select the appropriate agent. + + Returns ``(agent, task_prompt)`` or ``(None, "")`` when no suitable + agent is available. + """ + category = stage.get("category", "infra") + + # Select agent based on category (mirrors BuildSession._build_stage_task) + if category in ("infra", "data", "integration"): + agent = self._iac_agents.get(self._iac_tool) + elif category in ("app", "schema", "cicd", "external"): + agent = self._dev_agent + else: + agent = self._iac_agents.get(self._iac_tool) or self._dev_agent + + if not agent: + return None, "" + + # Collect current stage files + file_content = self._collect_stage_file_content(stage, max_bytes=20_000) + + # Service list + services = stage.get("services", []) + svc_lines = "\n".join( + f"- {s.get('computed_name') or s.get('name', '?')}: " + f"{s.get('resource_type', 'N/A')} (SKU: {s.get('sku') or 'n/a'})" + for s in services ) - if not qa_result["diagnosed"]: - _print("") - _print(f" Error: {error_text[:500]}") + stage_dir = stage.get("dir", "concept") - if use_styled and qa_result.get("response"): - self._console.print_token_status(self._token_tracker.format_status()) + task = ( + f"Fix deployment Stage {stage['stage']}: {stage['name']}.\n\n" + f"The deployment FAILED with the following error. You MUST fix ALL issues.\n\n" + f"## Deploy Error\n```\n{deploy_error[:3000]}\n```\n\n" + f"## QA Diagnosis\n{qa_diagnosis[:2000]}\n\n" + f"## Architect Guidance\n{architect_guidance[:2000]}\n\n" + ) - _print("") - _print(" Options: /deploy (retry) | /rollback (undo) | /help | quit") + if file_content: + task += f"## Current Stage Files\n{file_content}\n\n" + + if svc_lines: + task += f"## Services in This Stage\n{svc_lines}\n\n" + + task += ( + f"## Requirements\n" + f"- Fix ALL issues identified in the error and diagnosis above\n" + f"- Preserve all working functionality — only change what's broken\n" + f"- All files should be relative to {stage_dir}/\n" + f"- Output COMPLETE file contents in fenced code blocks with filenames\n" + ) + + return agent, task + + def _write_stage_files(self, stage: dict, content: str) -> list[str]: + """Extract file blocks from AI response and write to disk. + + Returns a list of written file paths relative to the project dir. + """ + if not content: + return [] + + files = parse_file_blocks(content) + if not files: + return [] + + stage_dir = stage.get("dir", "concept") + output_dir = Path(self._context.project_dir) / stage_dir + blocked = _BLOCKED_FILES.get(self._iac_tool, set()) + + # Strip stage_dir prefix from filenames to avoid path duplication + cleaned: dict[str, str] = {} + for filename, file_content in files.items(): + normalized = filename.replace("\\", "/") + stage_prefix = stage_dir.replace("\\", "/") + if normalized.startswith(stage_prefix + "/"): + normalized = normalized[len(stage_prefix) + 1 :] + elif normalized.startswith(stage_prefix): + normalized = normalized[len(stage_prefix) :] + normalized = normalized or filename + + if normalized in blocked: + logger.info("Dropped blocked file: %s (IaC tool: %s)", normalized, self._iac_tool) + continue + + cleaned[normalized] = file_content + + written = write_parsed_files(cleaned, output_dir, verbose=False) + + project_root = Path(self._context.project_dir) + written_relative = [str(p.relative_to(project_root)) for p in written] + + # Sync build state with updated file list + self._sync_build_state(stage, written_relative) + + return written_relative + + def _sync_build_state(self, stage: dict, written_paths: list[str]) -> None: + """Best-effort sync of build.yaml after remediation writes. + + Updates the matching stage's ``files`` list and marks it as + ``generated`` so subsequent builds stay consistent. Uses + ``build_stage_id`` for matching when available, falling back + to stage number for legacy state files. + """ + try: + from azext_prototype.stages.build_state import BuildState + + bs = BuildState(self._context.project_dir) + if not bs.exists: + return + bs.load() + + build_stage_id = stage.get("build_stage_id") + matched = False + + if build_stage_id: + # Match by stable ID + target = bs.get_stage_by_id(build_stage_id) + if target: + target["files"] = written_paths + target["status"] = "generated" + matched = True + + if not matched: + # Fallback: match by stage number + stage_num = stage["stage"] + for build_stage in bs.state.get("deployment_stages", []): + if build_stage["stage"] == stage_num: + build_stage["files"] = written_paths + build_stage["status"] = "generated" + break + + bs.save() + except Exception: + logger.debug("Could not sync build state after remediation", exc_info=True) + + def _get_architect_fix_guidance(self, stage: dict, deploy_error: str, qa_diagnosis: str) -> str: + """Ask the architect agent for specific fix guidance. + + Returns guidance text, or a generic fallback if no architect is available. + """ + if not self._architect_agent or not self._context.ai_provider: + return "Fix the issues identified in the QA diagnosis. Ensure all resource references are correct." + + file_content = self._collect_stage_file_content(stage, max_bytes=10_000) + + task = ( + f"A deployment stage failed. Analyse the error and QA diagnosis, " + f"then provide SPECIFIC code changes needed to fix it.\n\n" + f"## Stage {stage['stage']}: {stage['name']}\n" + f"Category: {stage.get('category', 'infra')}\n\n" + f"## Deploy Error\n```\n{deploy_error[:2000]}\n```\n\n" + f"## QA Diagnosis\n{qa_diagnosis[:1500]}\n\n" + ) + + if file_content: + task += f"## Current Stage Files\n{file_content}\n\n" + + task += ( + "Provide:\n" + "1. Root cause of the failure\n" + "2. Specific code changes needed (which files, what to change)\n" + "3. Whether downstream stages might be affected by this fix " + "(e.g. changed outputs, renamed resources)\n" + ) + + try: + response = self._architect_agent.execute(self._context, task) + if response: + self._token_tracker.record(response) + if response and response.content: + return response.content + except Exception: + logger.debug("Architect fix guidance failed", exc_info=True) + + return "Fix the issues identified in the QA diagnosis. Ensure all resource references are correct." + + def _check_downstream_impact(self, fixed_stage: dict, architect_guidance: str) -> list[int]: + """Ask the architect whether downstream stages need regeneration. + + Returns a list of stage numbers that should be regenerated, or + empty list if none are affected. + """ + if not self._architect_agent or not self._context.ai_provider: + return [] + + stages = self._deploy_state._state.get("deployment_stages", []) + fixed_num = fixed_stage["stage"] + + # Only consider downstream stages that are pending or failed + downstream = [s for s in stages if s["stage"] > fixed_num and s.get("deploy_status") in ("pending", "failed")] + if not downstream: + return [] + + import json + + stage_info = json.dumps( + [ + { + "stage": s["stage"], + "name": s["name"], + "category": s.get("category", ""), + "build_stage_id": s.get("build_stage_id", ""), + "services": [svc.get("name", "") for svc in s.get("services", [])], + } + for s in downstream + ], + indent=2, + ) + + task = ( + f"Stage {fixed_num} ({fixed_stage['name']}) was just fixed during deployment.\n\n" + f"## Fix Context\n{architect_guidance[:1500]}\n\n" + f"## Downstream Stages\n```json\n{stage_info}\n```\n\n" + "Which downstream stages need their code regenerated because of " + "changed outputs, renamed resources, or modified dependencies from " + "the fix above?\n\n" + "Return ONLY a JSON array of affected stage numbers. " + "Return [] if no stages are affected.\n" + "Example: [3, 4]\n" + ) + + try: + response = self._architect_agent.execute(self._context, task) + if response: + self._token_tracker.record(response) + if response and response.content: + return self._parse_stage_numbers(response.content, downstream) + except Exception: + logger.debug("Downstream impact check failed", exc_info=True) + + return [] + + @staticmethod + def _parse_stage_numbers(content: str, valid_stages: list[dict]) -> list[int]: + """Parse a JSON array of stage numbers from AI response.""" + import json + + valid_nums = {s["stage"] for s in valid_stages} + + # Try to find a JSON array in the response + match = re.search(r"\[[\d\s,]*\]", content) + if match: + try: + numbers = json.loads(match.group()) + return [n for n in numbers if isinstance(n, int) and n in valid_nums] + except (json.JSONDecodeError, TypeError): + pass + + # Fallback: extract individual numbers + numbers = [int(n) for n in re.findall(r"\d+", content)] + return [n for n in numbers if n in valid_nums] + + def _regenerate_downstream_stages( + self, + stage_nums: list[int], + use_styled: bool, + _print: Callable[[str], None], + ) -> None: + """Regenerate code for downstream stages affected by an upstream fix. + + Only regenerates the code (writes files) — does NOT deploy. The + normal deploy loop handles deployment of these stages. + """ + if not stage_nums: + return + + _print(f" Regenerating {len(stage_nums)} downstream stage(s) affected by fix...") + + for num in stage_nums: + stage = self._deploy_state.get_stage(num) + if not stage: + continue + + category = stage.get("category", "infra") + if category in ("infra", "data", "integration"): + agent = self._iac_agents.get(self._iac_tool) + elif category in ("app", "schema", "cicd", "external"): + agent = self._dev_agent + else: + agent = self._iac_agents.get(self._iac_tool) or self._dev_agent + + if not agent: + continue + + file_content = self._collect_stage_file_content(stage, max_bytes=15_000) + stage_dir = stage.get("dir", "concept") + + task = ( + f"Regenerate code for Stage {num}: {stage['name']}.\n\n" + f"An upstream stage was fixed during deployment. This stage may need " + f"updates to match changed outputs or resource references.\n\n" + ) + if file_content: + task += f"## Current Stage Files\n{file_content}\n\n" + + task += ( + f"## Requirements\n" + f"- Update any references to upstream resources if they changed\n" + f"- Preserve all existing functionality\n" + f"- All files should be relative to {stage_dir}/\n" + f"- Output COMPLETE file contents in fenced code blocks with filenames\n" + ) + + with self._maybe_spinner(f"Regenerating Stage {num}...", use_styled): + try: + response = agent.execute(self._context, task) + except Exception: + logger.debug("Downstream regeneration failed for Stage %d", num, exc_info=True) + _print(f" Could not regenerate Stage {num}.") + continue + + if response: + self._token_tracker.record(response) + + if response and response.content: + written = self._write_stage_files(stage, response.content) + if written: + _print(f" Stage {num}: regenerated {len(written)} file(s).") + else: + _print(f" Stage {num}: no file blocks in regeneration response.") # ------------------------------------------------------------------ # # Internal — Rollback @@ -1009,6 +1634,100 @@ def _rollback_all( # Internal — Slash commands # ------------------------------------------------------------------ # + @staticmethod + def _parse_stage_ref(arg: str) -> tuple[int | None, str | None]: + """Parse a stage reference like ``"5"`` or ``"5a"``. + + Returns ``(stage_num, substage_label)`` or ``(None, None)`` on failure. + """ + from azext_prototype.stages.deploy_state import parse_stage_ref + + return parse_stage_ref(arg) + + def _resolve_stage_from_arg( + self, arg: str, _print: Callable[[str], None] + ) -> tuple[dict | None, int | None, str | None]: + """Parse a stage ref from an arg string and look it up. + + Returns ``(stage_dict, stage_num, substage_label)`` or ``(None, None, None)``. + """ + stage_num, label = self._parse_stage_ref(arg) + if stage_num is None: + _print(f" Invalid stage reference: {arg}") + return None, None, None + + if label: + stage = self._deploy_state.get_stage_by_display_id(arg) + else: + stage = self._deploy_state.get_stage(stage_num) + + if not stage: + _print(f" Stage {arg} not found.") + return None, stage_num, label + + return stage, stage_num, label + + def _deploy_stage_with_substages( + self, + stage_num: int, + use_styled: bool, + _print: Callable[[str], None], + _input: Callable[[str], str], + ) -> None: + """Deploy all substages of a stage number in order.""" + all_stages = self._deploy_state.get_all_stages_for_num(stage_num) + for stage in all_stages: + if stage.get("deploy_status") == "deployed": + from azext_prototype.stages.deploy_state import _format_display_id + + _print(f" Stage {_format_display_id(stage)} already deployed.") + continue + self._deploy_one_stage_cmd(stage, use_styled, _print, _input) + + def _deploy_one_stage_cmd( + self, + stage: dict, + use_styled: bool, + _print: Callable[[str], None], + _input: Callable[[str], str], + ) -> None: + """Deploy a single stage, handling manual mode and failures.""" + from azext_prototype.stages.deploy_state import _format_display_id + + display_id = _format_display_id(stage) + stage_num = stage["stage"] + + with self._maybe_spinner(f"Deploying Stage {display_id}...", use_styled): + result = self._deploy_single_stage(stage) + + if result.get("status") == "deployed": + _print(f" Stage {display_id} deployed successfully.") + if stage.get("category") in ("infra", "data", "integration"): + self._capture_stage_outputs(stage) + elif result.get("status") == "awaiting_manual": + instructions = result.get("instructions", "No instructions provided.") + _print(f" Stage {display_id} requires manual action:") + _print(f" {instructions}") + _print("") + _print(" When complete, enter: Done / Skip / Need help") + try: + answer = _input(" > ").strip().lower() + except (EOFError, KeyboardInterrupt): + _print(" Skipped.") + return + if answer in ("done", "d", "yes", "y"): + self._deploy_state.mark_stage_deployed(stage_num) + _print(f" Stage {display_id} marked as deployed.") + elif answer in ("need help", "help", "h"): + _print(" Ask the QA engineer for guidance by describing your issue.") + else: + _print(f" Stage {display_id} skipped. Use /deploy {display_id} when ready.") + elif result.get("status") == "failed": + _print(f" Stage {display_id} failed: {result.get('error', '?')[:120]}") + self._handle_deploy_failure(stage, result, use_styled, _print, _input) + else: + _print(f" Stage {display_id} skipped: {result.get('reason', 'no action needed')}") + def _handle_slash_command( self, command_line: str, @@ -1032,24 +1751,21 @@ def _handle_slash_command( if arg == "all" or not arg: self._deploy_pending_stages(force, use_styled, _print, _input) else: - try: - stage_num = int(arg) - stage = self._deploy_state.get_stage(stage_num) - if not stage: - _print(f" Stage {stage_num} not found.") - elif stage.get("deploy_status") == "deployed": - _print(f" Stage {stage_num} already deployed. Use /redeploy {stage_num}.") + stage, stage_num, label = self._resolve_stage_from_arg(arg, _print) + if stage: + assert stage_num is not None # guaranteed when stage is resolved + if stage.get("deploy_status") == "deployed" and not label: + _print(f" Stage {arg} already deployed. Use /redeploy {arg}.") + elif label: + # Deploy specific substage + self._deploy_one_stage_cmd(stage, use_styled, _print, _input) else: - with self._maybe_spinner(f"Deploying Stage {stage_num}...", use_styled): - result = self._deploy_single_stage(stage) - if result.get("status") == "deployed": - _print(f" Stage {stage_num} deployed successfully.") - if stage.get("category") in ("infra", "data", "integration"): - self._capture_stage_outputs(stage) + # Deploy all substages for this number + all_for_num = self._deploy_state.get_all_stages_for_num(stage_num) + if len(all_for_num) > 1: + self._deploy_stage_with_substages(stage_num, use_styled, _print, _input) else: - _print(f" Stage {stage_num} failed: {result.get('error', '?')[:120]}") - except ValueError: - _print(f" Invalid stage number: {arg}") + self._deploy_one_stage_cmd(stage, use_styled, _print, _input) _print("") elif cmd == "/rollback": @@ -1057,11 +1773,24 @@ def _handle_slash_command( if arg == "all" or not arg: self._rollback_all(_print, _input) else: - try: - stage_num = int(arg) - self._rollback_stage(stage_num, _print) - except ValueError: - _print(f" Invalid stage number: {arg}") + stage_num, label = self._parse_stage_ref(arg) + if stage_num is None: + _print(f" Invalid stage reference: {arg}") + elif label: + # Rollback specific substage + if self._deploy_state.can_rollback(stage_num, label): + self._rollback_stage(stage_num, _print) + else: + _print(f" Cannot rollback {arg}: later substages still deployed.") + else: + # Rollback all substages in reverse + all_for_num = self._deploy_state.get_all_stages_for_num(stage_num) + if len(all_for_num) > 1: + for s in reversed(all_for_num): + if s.get("deploy_status") == "deployed": + self._rollback_stage(s["stage"], _print) + else: + self._rollback_stage(stage_num, _print) _print("") elif cmd == "/redeploy": @@ -1069,74 +1798,148 @@ def _handle_slash_command( if not arg: _print(" Usage: /redeploy N") else: - try: - stage_num = int(arg) - stage = self._deploy_state.get_stage(stage_num) - if not stage: - _print(f" Stage {stage_num} not found.") + stage, stage_num, label = self._resolve_stage_from_arg(arg, _print) + if stage: + assert stage_num is not None # guaranteed when stage is resolved + # Rollback first if deployed + if stage.get("deploy_status") == "deployed": + success = self._rollback_stage(stage_num, _print) + if not success: + _print(" Rollback failed. Cannot redeploy.") + return + + stage["deploy_status"] = "pending" + self._deploy_state.save() + + from azext_prototype.stages.deploy_state import _format_display_id + + display_id = _format_display_id(stage) + with self._maybe_spinner(f"Redeploying Stage {display_id}...", use_styled): + result = self._deploy_single_stage(stage) + + if result.get("status") == "deployed": + _print(f" Stage {display_id} redeployed successfully.") + if stage.get("category") in ("infra", "data", "integration"): + self._capture_stage_outputs(stage) + elif result.get("status") == "awaiting_manual": + _print(f" Stage {display_id} requires manual action:") + _print(f" {result.get('instructions', '')}") else: - # Rollback first if deployed - if stage.get("deploy_status") == "deployed": - success = self._rollback_stage(stage_num, _print) - if not success: - _print(" Rollback failed. Cannot redeploy.") - return - - # Reset status to pending and redeploy - stage["deploy_status"] = "pending" - self._deploy_state.save() + _print(f" Stage {display_id} failed: {result.get('error', '?')[:120]}") + self._handle_deploy_failure(stage, result, use_styled, _print, _input) + _print("") - with self._maybe_spinner(f"Redeploying Stage {stage_num}...", use_styled): - result = self._deploy_single_stage(stage) + elif cmd == "/plan": + _print("") + if not arg: + _print(" Usage: /plan N") + else: + stage, stage_num, _label = self._resolve_stage_from_arg(arg, _print) + if stage: + stage_dir = Path(self._context.project_dir) / stage.get("dir", "") + if stage.get("deploy_mode") == "manual": + _print(f" Stage {arg} is a manual step — no plan preview.") + elif not stage_dir.is_dir(): + _print(f" Directory not found: {stage.get('dir', '?')}") + elif stage.get("category") in ("infra", "data", "integration"): + with self._maybe_spinner(f"Running plan for Stage {arg}...", use_styled): + if self._iac_tool == "terraform": + plan_env = self._deploy_env + generated = resolve_stage_secrets(stage_dir, self._config) + if generated: + plan_env = dict(self._deploy_env) if self._deploy_env else {} + plan_env.update(generated) + result = plan_terraform(stage_dir, self._subscription, env=plan_env) + else: + result = whatif_bicep( + stage_dir, + self._subscription, + self._resource_group, + env=self._deploy_env, + ) + if result.get("output"): + _print(result["output"]) + if result.get("error"): + _print(f" Error: {result['error']}") + else: + _print(f" Stage {arg} is an app stage — no plan preview.") + _print("") + + elif cmd == "/split": + _print("") + if not arg: + _print(" Usage: /split N") + else: + stage, stage_num, _label = self._resolve_stage_from_arg(arg, _print) + if stage and stage_num is not None: + if stage.get("_is_substage"): + _print(f" Stage {arg} is already a substage. Cannot split further.") + else: + _print(f" Splitting Stage {stage_num}: {stage['name']}") + _print(" Enter names for the substages (one per line, blank line to finish):") + substages: list[dict] = [] + while True: + try: + name = _input(" Name: ").strip() + except (EOFError, KeyboardInterrupt): + break + if not name: + break + substages.append({"name": name, "dir": stage.get("dir", "")}) + if len(substages) >= 2: + self._deploy_state.split_stage(stage_num, substages) + _print(f" Split into {len(substages)} substages.") + else: + _print(" Split requires at least 2 substages. Cancelled.") + _print("") - if result.get("status") == "deployed": - _print(f" Stage {stage_num} redeployed successfully.") - if stage.get("category") in ("infra", "data", "integration"): - self._capture_stage_outputs(stage) + elif cmd == "/destroy": + _print("") + if not arg: + _print(" Usage: /destroy N") + else: + stage, stage_num, _label = self._resolve_stage_from_arg(arg, _print) + if stage and stage_num is not None: + _print(f" Are you sure you want to destroy Stage {arg}: {stage['name']}? (y/N)") + try: + answer = _input(" > ").strip().lower() + except (EOFError, KeyboardInterrupt): + _print(" Cancelled.") + return + if answer in ("y", "yes"): + success = self._rollback_stage(stage_num, _print) + if success: + self._deploy_state.mark_stage_destroyed(stage_num) + _print(f" Stage {arg} destroyed.") else: - _print(f" Stage {stage_num} failed: {result.get('error', '?')[:120]}") - except ValueError: - _print(f" Invalid stage number: {arg}") + _print(f" Could not destroy Stage {arg}.") + else: + _print(" Cancelled.") _print("") - elif cmd == "/plan": + elif cmd == "/manual": _print("") if not arg: - _print(" Usage: /plan N") + _print(' Usage: /manual N "instructions"') else: - try: - stage_num = int(arg) - stage = self._deploy_state.get_stage(stage_num) - if not stage: - _print(f" Stage {stage_num} not found.") + # Parse: /manual 5 "instructions text" + manual_parts = arg.split(maxsplit=1) + ref = manual_parts[0] + instructions = manual_parts[1].strip('"').strip("'") if len(manual_parts) > 1 else "" + stage, stage_num, _label = self._resolve_stage_from_arg(ref, _print) + if stage: + if instructions: + stage["deploy_mode"] = "manual" + stage["manual_instructions"] = instructions + self._deploy_state.save() + _print(f" Stage {ref} set to manual mode with instructions.") else: - stage_dir = Path(self._context.project_dir) / stage.get("dir", "") - if not stage_dir.is_dir(): - _print(f" Directory not found: {stage.get('dir', '?')}") - elif stage.get("category") in ("infra", "data", "integration"): - with self._maybe_spinner(f"Running plan for Stage {stage_num}...", use_styled): - if self._iac_tool == "terraform": - plan_env = self._deploy_env - generated = resolve_stage_secrets(stage_dir, self._config) - if generated: - plan_env = dict(self._deploy_env) if self._deploy_env else {} - plan_env.update(generated) - result = plan_terraform(stage_dir, self._subscription, env=plan_env) - else: - result = whatif_bicep( - stage_dir, - self._subscription, - self._resource_group, - env=self._deploy_env, - ) - if result.get("output"): - _print(result["output"]) - if result.get("error"): - _print(f" Error: {result['error']}") + current = stage.get("manual_instructions", "") + if current: + _print(f" Current instructions: {current}") else: - _print(f" Stage {stage_num} is an app stage — no plan preview.") - except ValueError: - _print(f" Invalid stage number: {arg}") + _print(f" No manual instructions set for Stage {ref}.") + _print(' Use: /manual N "your instructions here"') _print("") elif cmd == "/outputs": @@ -1172,15 +1975,22 @@ def _handle_slash_command( _print(" az CLI not found on PATH.") _print("") + elif cmd == "/describe": + self._handle_describe(arg, _print) + elif cmd == "/help": _print("") _print(" Available commands:") _print(" /status - Show deployment progress per stage") _print(" /stages - List all stages with status (alias)") - _print(" /deploy [N] - Deploy stage N or all pending stages") + _print(" /deploy [N] - Deploy stage N (or 5a for substage) or all") _print(" /rollback [N] - Roll back stage N or all (reverse order)") _print(" /redeploy N - Rollback + redeploy stage N") _print(" /plan N - Show what-if/terraform plan for stage N") + _print(" /split N - Split a stage into substages") + _print(" /destroy N - Destroy resources for a removed stage") + _print(" /manual N - Add/view manual step instructions") + _print(" /describe N - Show details for stage N") _print(" /outputs - Show captured deployment outputs") _print(" /preflight - Re-run preflight checks") _print(" /login - Run az login interactively") @@ -1188,10 +1998,80 @@ def _handle_slash_command( _print(" done - Accept deployment and exit") _print(" quit - Exit deploy session") _print("") + _print(" Stage references accept substage labels: /deploy 5a") + _print("") + _print(" You can also use natural language:") + _print(" 'deploy stage 3' instead of /deploy 3") + _print(" 'rollback all' instead of /rollback all") + _print(" 'deploy stages 3 and 4' deploys multiple stages") + _print(" 'describe stage 2' instead of /describe 2") + _print("") else: _print(f" Unknown command: {cmd}. Type /help for a list.") + def _handle_describe(self, arg: str, _print: Callable[[str], None]) -> None: + """Show detailed description of a deploy stage.""" + if not arg or not arg.strip(): + _print(" Usage: /describe N (stage number)") + return + + numbers = re.findall(r"\d+", arg) + if not numbers: + _print(" Usage: /describe N (stage number)") + return + + stage_num = int(numbers[0]) + stage = self._deploy_state.get_stage(stage_num) + if not stage: + _print(f" Stage {stage_num} not found.") + return + + _print("") + _print(f" Stage {stage_num}: {stage.get('name', '?')}") + _print(f" Category: {stage.get('category', '?')}") + _print(f" Deploy status: {stage.get('deploy_status', 'pending')}") + _print(f" Dir: {stage.get('dir', '?')}") + + timestamp = stage.get("deployed_at", "") + if timestamp: + _print(f" Deployed at: {timestamp}") + + services = stage.get("services", []) + if services: + _print(f" Resources ({len(services)}):") + for svc in services: + name = svc.get("computed_name") or svc.get("name", "?") + rtype = svc.get("resource_type", "") + sku = svc.get("sku", "") + line = f" - {name}" + if rtype: + line += f" ({rtype})" + if sku: + line += f" [{sku}]" + _print(line) + + files = stage.get("files", []) + if files: + _print(f" Files ({len(files)}):") + for f in files: + _print(f" - {f}") + + deploy_output = stage.get("deploy_output", "") + if deploy_output: + _print(" Deploy output:") + for line in deploy_output.split("\n")[:10]: + _print(f" {line}") + if deploy_output.count("\n") > 10: + _print(" ... (truncated)") + + deploy_error = stage.get("deploy_error", "") + if deploy_error: + _print(" Deploy error:") + _print(f" {deploy_error[:200]}") + + _print("") + # ------------------------------------------------------------------ # # Internal — utilities # ------------------------------------------------------------------ # @@ -1208,10 +2088,16 @@ def _build_result(self) -> DeployResult: ) @contextmanager - def _maybe_spinner(self, message: str, use_styled: bool) -> Iterator[None]: + def _maybe_spinner(self, message: str, use_styled: bool, *, status_fn: Callable | None = None) -> Iterator[None]: """Show a spinner when using styled output, otherwise no-op.""" if use_styled: with self._console.spinner(message): yield + elif status_fn: + status_fn(message, "start") + try: + yield + finally: + status_fn(message, "end") else: yield diff --git a/azext_prototype/stages/deploy_state.py b/azext_prototype/stages/deploy_state.py index 80fd1da..ef5c086 100644 --- a/azext_prototype/stages/deploy_state.py +++ b/azext_prototype/stages/deploy_state.py @@ -13,22 +13,39 @@ - Preflight check results - Per-stage deploy/rollback audit trail - Captured Terraform/Bicep outputs +- Build-deploy correspondence via stable ``build_stage_id`` +- Substage splitting for 1:N divergence """ from __future__ import annotations import logging +import re +from dataclasses import dataclass, field from datetime import datetime, timezone from pathlib import Path from typing import Any import yaml +from azext_prototype.stages.build_state import _slugify + logger = logging.getLogger(__name__) DEPLOY_STATE_FILE = ".prototype/state/deploy.yaml" +@dataclass +class SyncResult: + """Result of syncing deploy state from build state.""" + + matched: int = 0 + created: int = 0 + orphaned: int = 0 + updated_code: int = 0 + details: list[str] = field(default_factory=list) + + def _default_deploy_state() -> dict[str, Any]: """Return the default empty deploy state structure.""" return { @@ -50,12 +67,31 @@ def _default_deploy_state() -> dict[str, Any]: } +def _enrich_deploy_fields(stage: dict) -> dict: + """Ensure a stage dict has all deploy-specific fields.""" + stage.setdefault("deploy_status", "pending") + stage.setdefault("deploy_timestamp", None) + stage.setdefault("deploy_output", "") + stage.setdefault("deploy_error", "") + stage.setdefault("rollback_timestamp", None) + stage.setdefault("remediation_attempts", 0) + stage.setdefault("build_stage_id", None) + stage.setdefault("deploy_mode", "auto") + stage.setdefault("manual_instructions", None) + stage.setdefault("substage_label", None) + stage.setdefault("_is_substage", False) + stage.setdefault("_destruction_declined", False) + return stage + + class DeployState: """Manages persistent deploy state in YAML format. Provides: - Loading existing state on startup (re-entrant deploys) - Importing deployment stages from build state + - Smart sync with build state (preserves deploy progress) + - Stage splitting for 1:N build-deploy divergence - Per-stage deploy status transitions with ordering enforcement - Preflight result tracking - Deploy and rollback audit logging @@ -93,6 +129,7 @@ def load(self) -> dict[str, Any]: loaded = yaml.safe_load(f) or {} self._state = _default_deploy_state() self._deep_merge(self._state, loaded) + self._backfill_build_stage_ids() self._loaded = True logger.info("Loaded deploy state from %s", self._path) except (yaml.YAMLError, IOError) as e: @@ -138,7 +175,7 @@ def load_from_build_state(self, build_state_path: str | Path) -> bool: For each stage from the build state, adds deploy-specific fields: ``deploy_status``, ``deploy_timestamp``, ``deploy_output``, - ``deploy_error``, ``rollback_timestamp``. + ``deploy_error``, ``rollback_timestamp``, ``build_stage_id``. Returns True if stages were imported, False if build.yaml not found or contained no deployment stages. @@ -163,11 +200,9 @@ def load_from_build_state(self, build_state_path: str | Path) -> bool: enriched: list[dict] = [] for stage in build_stages: enriched_stage = dict(stage) - enriched_stage.setdefault("deploy_status", "pending") - enriched_stage.setdefault("deploy_timestamp", None) - enriched_stage.setdefault("deploy_output", "") - enriched_stage.setdefault("deploy_error", "") - enriched_stage.setdefault("rollback_timestamp", None) + # Set build_stage_id from the build stage's id field + enriched_stage["build_stage_id"] = stage.get("id") or _slugify(stage.get("name", "stage")) + _enrich_deploy_fields(enriched_stage) enriched.append(enriched_stage) self._state["deployment_stages"] = enriched @@ -177,6 +212,241 @@ def load_from_build_state(self, build_state_path: str | Path) -> bool: logger.info("Imported %d stages from build state.", len(enriched)) return True + def sync_from_build_state(self, build_state_path: str | Path) -> SyncResult: + """Smart reconciliation of deploy stages with current build state. + + Unlike :meth:`load_from_build_state` (which overwrites), this method: + + - **Matches** existing deploy stages to build stages by ``build_stage_id`` + - **Updates** build-sourced fields (name, category, services, deploy_mode) + while preserving deploy state (status, timestamps, substage structure) + - **Creates** new deploy stages for new build stages + - **Orphans** deploy stages whose build stage was removed (sets ``removed``) + - Falls back to name+category matching for legacy stages + + Returns a :class:`SyncResult` summarising the changes. + """ + result = SyncResult() + path = Path(build_state_path) + if not path.exists(): + result.details.append("Build state not found.") + return result + + try: + with open(path, "r", encoding="utf-8") as f: + build_data = yaml.safe_load(f) or {} + except (yaml.YAMLError, IOError) as e: + result.details.append(f"Could not read build state: {e}") + return result + + build_stages = build_data.get("deployment_stages", []) + if not build_stages: + result.details.append("Build state has no deployment_stages.") + return result + + existing = self._state["deployment_stages"] + + # Index existing deploy stages by build_stage_id + deploy_by_bid: dict[str, list[dict]] = {} + for ds in existing: + bid = ds.get("build_stage_id") + if bid: + deploy_by_bid.setdefault(bid, []).append(ds) + + # Track which build_stage_ids we've matched + matched_bids: set[str] = set() + new_stages: list[dict] = [] + + for bs in build_stages: + bid = bs.get("id") or _slugify(bs.get("name", "stage")) + matched_bids.add(bid) + + if bid in deploy_by_bid: + # Update matched deploy stages with build-sourced fields + for ds in deploy_by_bid[bid]: + # Check if code changed + old_dir = ds.get("dir", "") + new_dir = bs.get("dir", "") + old_files = ds.get("files", []) + new_files = bs.get("files", []) + code_changed = (old_dir != new_dir) or (sorted(old_files) != sorted(new_files)) + + # Update build-sourced fields + ds["name"] = bs.get("name", ds["name"]) + ds["category"] = bs.get("category", ds.get("category", "infra")) + ds["services"] = bs.get("services", ds.get("services", [])) + ds["deploy_mode"] = bs.get("deploy_mode", ds.get("deploy_mode", "auto")) + ds["manual_instructions"] = bs.get("manual_instructions", ds.get("manual_instructions")) + if not ds.get("_is_substage"): + ds["dir"] = new_dir + ds["files"] = new_files + + if code_changed and ds.get("deploy_status") == "deployed": + ds["_code_updated"] = True + result.updated_code += 1 + + result.matched += 1 + else: + # Legacy fallback: match by name+category + legacy_match = None + for ds in existing: + if ( + not ds.get("build_stage_id") + and ds.get("name") == bs.get("name") + and ds.get("category") == bs.get("category") + ): + legacy_match = ds + break + + if legacy_match: + legacy_match["build_stage_id"] = bid + legacy_match["deploy_mode"] = bs.get("deploy_mode", "auto") + legacy_match["manual_instructions"] = bs.get("manual_instructions") + result.matched += 1 + matched_bids.add(bid) + else: + # Create new deploy stage + new_ds = dict(bs) + new_ds["build_stage_id"] = bid + _enrich_deploy_fields(new_ds) + new_stages.append(new_ds) + result.created += 1 + result.details.append(f"New stage: {bs.get('name', '?')}") + + # Rebuild ordered list + ordered: list[dict] = [] + processed_bids: set[str] = set() + + for bs in build_stages: + bid = bs.get("id") or _slugify(bs.get("name", "stage")) + if bid in deploy_by_bid and bid not in processed_bids: + # Add existing deploy stages for this build stage (in existing order) + ordered.extend(deploy_by_bid[bid]) + processed_bids.add(bid) + elif bid not in processed_bids: + # Add newly created stage + for ns in new_stages: + if ns.get("build_stage_id") == bid: + ordered.append(ns) + processed_bids.add(bid) + + # Also add legacy-matched stages that weren't in deploy_by_bid + for ds in existing: + if ds not in ordered and ds.get("build_stage_id") in matched_bids: + ordered.append(ds) + + # Orphaned stages (build stage removed) + for bid, deploy_stages in deploy_by_bid.items(): + if bid not in matched_bids: + for ds in deploy_stages: + if ds.get("deploy_status") not in ("removed", "destroyed"): + ds["deploy_status"] = "removed" + result.orphaned += 1 + result.details.append(f"Removed: {ds.get('name', '?')}") + ordered.append(ds) + + # Also catch any existing stages not yet in ordered + for ds in existing: + if ds not in ordered: + if ds.get("build_stage_id") not in matched_bids: + if ds.get("deploy_status") not in ("removed", "destroyed"): + ds["deploy_status"] = "removed" + result.orphaned += 1 + ordered.append(ds) + + self._state["deployment_stages"] = ordered + self._state["iac_tool"] = build_data.get("iac_tool", self._state.get("iac_tool", "terraform")) + self.renumber_stages() + + return result + + # ------------------------------------------------------------------ # + # Stage splitting + # ------------------------------------------------------------------ # + + def split_stage(self, stage_num: int, substages: list[dict]) -> None: + """Replace one deploy stage with N substages sharing the same ``build_stage_id``. + + Each substage gets a letter suffix: ``"a"``, ``"b"``, ``"c"``, etc. + The original stage is removed from the list. + + Args: + stage_num: The stage number to split. + substages: List of substage dicts with at minimum ``name``, ``dir``. + """ + stages = self._state["deployment_stages"] + parent_idx = None + parent = None + + for i, s in enumerate(stages): + if s["stage"] == stage_num and not s.get("substage_label"): + parent_idx = i + parent = s + break + + if parent is None or parent_idx is None: + logger.warning("Stage %d not found for splitting.", stage_num) + return + + build_stage_id = parent.get("build_stage_id") + labels = [chr(ord("a") + i) for i in range(len(substages))] + + new_entries: list[dict] = [] + for label, sub in zip(labels, substages): + entry = dict(parent) + entry.update(sub) + entry["stage"] = stage_num + entry["substage_label"] = label + entry["_is_substage"] = True + entry["build_stage_id"] = build_stage_id + _enrich_deploy_fields(entry) + # Reset deploy state for new substages + entry["deploy_status"] = "pending" + entry["deploy_timestamp"] = None + entry["deploy_output"] = "" + entry["deploy_error"] = "" + new_entries.append(entry) + + # Replace parent with substages + stages[parent_idx : parent_idx + 1] = new_entries + self.save() + + def get_stage_groups(self) -> dict[str | None, list[dict]]: + """Group deploy stages by ``build_stage_id`` for tree rendering. + + Returns a dict mapping ``build_stage_id`` → list of deploy stages. + Stages without a ``build_stage_id`` are grouped under ``None``. + """ + groups: dict[str | None, list[dict]] = {} + for s in self._state["deployment_stages"]: + bid = s.get("build_stage_id") + groups.setdefault(bid, []).append(s) + return groups + + def get_stages_for_build_stage(self, build_stage_id: str) -> list[dict]: + """Return all deploy stages linked to a given build stage.""" + return [s for s in self._state["deployment_stages"] if s.get("build_stage_id") == build_stage_id] + + def get_stage_by_display_id(self, display_id: str) -> dict | None: + """Parse a display ID like ``"5"`` or ``"5a"`` and return the matching stage. + + Returns None if no match found. + """ + stage_num, label = parse_stage_ref(display_id) + if stage_num is None: + return None + + for s in self._state["deployment_stages"]: + if s["stage"] == stage_num: + if label is None and not s.get("substage_label"): + return s + if label is not None and s.get("substage_label") == label: + return s + # If asking for bare number and stage has substages, return first + if label is None and s.get("substage_label"): + return s + return None + # ------------------------------------------------------------------ # # Stage status transitions # ------------------------------------------------------------------ # @@ -219,17 +489,118 @@ def mark_stage_rolled_back(self, stage_num: int) -> None: self.add_rollback_log_entry(stage_num) self.save() + def mark_stage_remediating(self, stage_num: int) -> None: + """Mark a stage as undergoing remediation and bump attempt counter.""" + stage = self.get_stage(stage_num) + if stage: + stage["deploy_status"] = "remediating" + stage["remediation_attempts"] = stage.get("remediation_attempts", 0) + 1 + self.add_deploy_log_entry(stage_num, "remediating", f"attempt {stage['remediation_attempts']}") + self.save() + + def reset_stage_to_pending(self, stage_num: int) -> None: + """Reset a failed/remediating stage back to pending for re-deploy.""" + stage = self.get_stage(stage_num) + if stage: + stage["deploy_status"] = "pending" + stage["deploy_error"] = "" + self.save() + + def mark_stage_removed(self, stage_num: int) -> None: + """Mark a stage as removed (build stage was deleted).""" + stage = self.get_stage(stage_num) + if stage: + stage["deploy_status"] = "removed" + self.add_deploy_log_entry(stage_num, "removed") + self.save() + + def mark_stage_destroyed(self, stage_num: int) -> None: + """Mark a removed stage as destroyed (resources torn down).""" + stage = self.get_stage(stage_num) + if stage: + stage["deploy_status"] = "destroyed" + self.add_deploy_log_entry(stage_num, "destroyed") + self.save() + + def mark_stage_awaiting_manual(self, stage_num: int) -> None: + """Mark a manual stage as awaiting user confirmation.""" + stage = self.get_stage(stage_num) + if stage: + stage["deploy_status"] = "awaiting_manual" + self.add_deploy_log_entry(stage_num, "awaiting_manual") + self.save() + + def add_patch_stages(self, new_stages: list[dict]) -> None: + """Insert new stages before the docs stage, enriched with deploy fields. + + Follows the same insertion pattern as + :meth:`~.build_state.BuildState.add_stages`. + """ + existing = self._state["deployment_stages"] + + # Find insertion point — before the docs stage + insert_idx = len(existing) + for i, s in enumerate(existing): + if s.get("category") == "docs": + insert_idx = i + break + + for ns in new_stages: + _enrich_deploy_fields(ns) + ns.setdefault("services", []) + ns.setdefault("files", []) + ns.setdefault("dir", "") + existing.insert(insert_idx, ns) + insert_idx += 1 + + self.renumber_stages() + + def renumber_stages(self) -> None: + """Renumber stages sequentially. + + Top-level stages get sequential integers starting from 1. + Substages inherit their parent's number (their labels are unchanged). + A group of substages with the same ``build_stage_id`` counts as + one logical stage for numbering purposes. + """ + stages = self._state["deployment_stages"] + current_num = 0 + seen_substage_bids: set[str | None] = set() + + for stage in stages: + if not stage.get("substage_label"): + # Top-level stage + current_num += 1 + stage["stage"] = current_num + else: + # Substage — check if this is the first substage in its group + bid = stage.get("build_stage_id") + if bid not in seen_substage_bids: + current_num += 1 + seen_substage_bids.add(bid) + stage["stage"] = current_num + + self.save() + # ------------------------------------------------------------------ # # Stage queries # ------------------------------------------------------------------ # def get_stage(self, stage_num: int) -> dict | None: - """Return a specific stage by number.""" + """Return a specific stage by number. + + For stages with substages, returns the first matching stage + (the first substage or the top-level stage). + """ for stage in self._state["deployment_stages"]: if stage["stage"] == stage_num: return stage return None + def get_all_stages_for_num(self, stage_num: int) -> list[dict]: + """Return all stages/substages with the given stage number.""" + return [s for s in self._state["deployment_stages"] if s["stage"] == stage_num] + def get_pending_stages(self) -> list[dict]: """Return stages not yet deployed.""" return [s for s in self._state["deployment_stages"] if s.get("deploy_status") == "pending"] @@ -248,18 +619,27 @@ def get_rollback_candidates(self) -> list[dict]: Only stages that can be safely rolled back are included. """ deployed = self.get_deployed_stages() - return sorted(deployed, key=lambda s: s["stage"], reverse=True) + return sorted(deployed, key=lambda s: (s["stage"], s.get("substage_label") or ""), reverse=True) - def can_rollback(self, stage_num: int) -> bool: + def can_rollback(self, stage_num: int, substage_label: str | None = None) -> bool: """Check if a stage can be rolled back. A stage can only be rolled back if no higher-numbered stage has - ``deploy_status == 'deployed'``. This enforces the invariant: - cannot roll back stage N before rolling back stage N+1. + ``deploy_status == 'deployed'``. For substages, checks within + the same stage number that no later substage is still deployed. """ for stage in self._state["deployment_stages"]: - if stage["stage"] > stage_num and stage.get("deploy_status") == "deployed": + s_status = stage.get("deploy_status") + if s_status != "deployed": + continue + s_num = stage["stage"] + s_label = stage.get("substage_label") + + if s_num > stage_num: return False + if s_num == stage_num and substage_label is not None and s_label is not None: + if s_label > substage_label: + return False return True # ------------------------------------------------------------------ # @@ -352,33 +732,49 @@ def format_deploy_report(self) -> str: lines.append("") stages = self._state.get("deployment_stages", []) + active_stages = [s for s in stages if s.get("deploy_status") not in ("removed", "destroyed")] deployed = len([s for s in stages if s.get("deploy_status") == "deployed"]) failed = len([s for s in stages if s.get("deploy_status") == "failed"]) rolled = len([s for s in stages if s.get("deploy_status") == "rolled_back"]) + removed = len([s for s in stages if s.get("deploy_status") in ("removed", "destroyed")]) lines.append( - f" Stages: {len(stages)} total, {deployed} deployed" + f" Stages: {len(active_stages)} active, {deployed} deployed" f"{f', {failed} failed' if failed else ''}" f"{f', {rolled} rolled back' if rolled else ''}" + f"{f', {removed} removed' if removed else ''}" ) lines.append("") for stage in stages: - icon = _status_icon(stage.get("deploy_status", "pending")) - line = f" {icon} Stage {stage['stage']}: {stage['name']}" + status = stage.get("deploy_status", "pending") + icon = _status_icon(status) + display_id = _format_display_id(stage) + deploy_mode = stage.get("deploy_mode", "auto") + + if status in ("removed", "destroyed"): + line = f" {icon} Stage {display_id}: ~~{stage['name']}~~ (Removed)" + else: + line = f" {icon} Stage {display_id}: {stage['name']}" + if deploy_mode == "manual": + line += " [Manual]" + ts = stage.get("deploy_timestamp") if ts: line += f" ({ts[:19]})" lines.append(line) services = stage.get("services", []) - if services: + if services and status not in ("removed", "destroyed"): svc_names = [s.get("computed_name") or s.get("name", "?") for s in services] lines.append(f" Resources: {', '.join(svc_names)}") + if deploy_mode == "manual" and stage.get("manual_instructions"): + preview = stage["manual_instructions"][:80] + lines.append(f" Instructions: {preview}...") + error = stage.get("deploy_error", "") if error: - # Truncate long errors short = error[:120] + "..." if len(error) > 120 else error lines.append(f" Error: {short}") @@ -402,14 +798,23 @@ def format_stage_status(self) -> str: status = stage.get("deploy_status", "pending") icon = _status_icon(status) svc_count = len(stage.get("services", [])) - line = f" {icon} Stage {stage['stage']}: {stage['name']} ({stage.get('category', '?')})" - if svc_count: - line += f" - {svc_count} service(s)" + display_id = _format_display_id(stage) + deploy_mode = stage.get("deploy_mode", "auto") + + if status in ("removed", "destroyed"): + line = f" {icon} Stage {display_id}: ~~{stage['name']}~~ ({stage.get('category', '?')}) (Removed)" + else: + line = f" {icon} Stage {display_id}: {stage['name']} ({stage.get('category', '?')})" + if deploy_mode == "manual": + line += " [Manual]" + if svc_count: + line += f" - {svc_count} service(s)" lines.append(line) + active = [s for s in stages if s.get("deploy_status") not in ("removed", "destroyed")] deployed = len([s for s in stages if s.get("deploy_status") == "deployed"]) lines.append("") - lines.append(f" Progress: {deployed}/{len(stages)} stages deployed") + lines.append(f" Progress: {deployed}/{len(active)} stages deployed") metadata = self._state.get("_metadata", {}) if metadata.get("last_updated"): @@ -472,6 +877,13 @@ def format_outputs(self) -> str: # Internals # ------------------------------------------------------------------ # + def _backfill_build_stage_ids(self) -> None: + """Backfill ``build_stage_id`` and deploy fields on legacy state files.""" + for stage in self._state["deployment_stages"]: + if not stage.get("build_stage_id"): + stage["build_stage_id"] = _slugify(stage.get("name", "stage")) + _enrich_deploy_fields(stage) + def _deep_merge(self, base: dict, updates: dict) -> None: """Deep merge updates into base dict.""" for key, value in updates.items(): @@ -481,6 +893,30 @@ def _deep_merge(self, base: dict, updates: dict) -> None: base[key] = value +# ================================================================== # +# Module-level helpers +# ================================================================== # + + +def parse_stage_ref(arg: str) -> tuple[int | None, str | None]: + """Parse a stage reference like ``"5"`` or ``"5a"`` into (stage_num, substage_label). + + Returns ``(None, None)`` if the string cannot be parsed. + """ + m = re.match(r"^(\d+)([a-z]?)$", arg.strip()) + if not m: + return None, None + stage_num = int(m.group(1)) + label = m.group(2) or None + return stage_num, label + + +def _format_display_id(stage: dict) -> str: + """Format a stage's display identifier, e.g. ``"5"`` or ``"5a"``.""" + label = stage.get("substage_label") or "" + return f"{stage['stage']}{label}" + + def _status_icon(status: str) -> str: """Return a compact status icon for display.""" return { @@ -489,4 +925,8 @@ def _status_icon(status: str) -> str: "deployed": " v", "failed": " x", "rolled_back": " ~", + "remediating": "<>", + "removed": "~~", + "destroyed": "xx", + "awaiting_manual": "!!", }.get(status, " ") diff --git a/azext_prototype/stages/design_stage.py b/azext_prototype/stages/design_stage.py index 05ca9f8..76272e7 100644 --- a/azext_prototype/stages/design_stage.py +++ b/azext_prototype/stages/design_stage.py @@ -15,6 +15,7 @@ import json import logging import re +import time from datetime import datetime, timezone from pathlib import Path @@ -32,6 +33,34 @@ logger = logging.getLogger(__name__) +_NEW_SECTION_RE = re.compile( + r"\[NEW_SECTION:\s*(\{.*?\})\]", + re.DOTALL, +) + + +def _format_section_elapsed(seconds: float) -> str: + """Format elapsed seconds as ``12s`` or ``1m04s`` when >= 60.""" + if seconds < 60: + return f"{seconds:.0f}s" + minutes = int(seconds) // 60 + secs = int(seconds) % 60 + return f"{minutes}m{secs:02d}s" + + +def _extract_new_sections(content: str) -> list[dict]: + """Parse ``[NEW_SECTION: {...}]`` markers from AI response content.""" + results = [] + for m in _NEW_SECTION_RE.finditer(content): + try: + obj = json.loads(m.group(1)) + if isinstance(obj, dict) and "name" in obj: + obj.setdefault("context", "") + results.append(obj) + except (json.JSONDecodeError, TypeError): + pass + return results + class DesignStage(BaseStage): """Analyze requirements and generate architecture design. @@ -87,9 +116,13 @@ def execute( reset = kwargs.get("reset", False) interactive = kwargs.get("interactive", False) skip_discovery = kwargs.get("skip_discovery", False) - # Accept injected I/O callables (for tests) + # Accept injected I/O callables (for tests / TUI) input_fn = kwargs.get("input_fn") print_fn = kwargs.get("print_fn") + status_fn = kwargs.get("status_fn") + section_fn = kwargs.get("section_fn") + response_fn = kwargs.get("response_fn") + update_task_fn = kwargs.get("update_task_fn") self.state = StageState.IN_PROGRESS config = ProjectConfig(agent_context.project_dir) @@ -103,7 +136,10 @@ def execute( ui = default_console if use_styled else None _print = print_fn or default_console.print - default_console.print_header("Starting design session") + if use_styled: + default_console.print_header("Starting design session") + else: + _print("\n[bold bright_magenta]Starting design session[/bold bright_magenta]\n") # Load existing discovery state discovery_state = DiscoveryState(agent_context.project_dir) @@ -111,6 +147,8 @@ def execute( discovery_state.load() if ui: ui.print_info("Loaded existing discovery context from previous session.") + else: + _print("[bright_cyan]\u2192[/bright_cyan] Loaded existing discovery context from previous session.") # Determine if this is a context-only invocation # (--context provided but no --artifacts) @@ -129,10 +167,12 @@ def execute( ui.print_file_list(result["read"], success=True) else: for name in result["read"]: - _print(f" \u2713 {name}") + _print(f" [bright_green]\u2713[/bright_green] [bright_cyan]{name}[/bright_cyan]") if artifact_images: - _print(f" Extracted {len(artifact_images)} image(s) for vision analysis") + _print( + f" [bright_cyan]\u2192[/bright_cyan] Extracted {len(artifact_images)} image(s) for vision analysis" + ) if result["failed"]: _print(f" Could not read {len(result['failed'])} file(s):") @@ -140,10 +180,10 @@ def execute( ui.print_file_list([f"{n} ({r})" for n, r in result["failed"]], success=False) else: for name, reason in result["failed"]: - _print(f" \u2717 {name} ({reason})") + _print(f" [bright_red]\u2717[/bright_red] {name} ({reason})") if not result["read"] and not result["failed"]: - _print(" (no files found)") + _print(" [dim](no files found)[/dim]") _print("") design_state["artifacts"].append( @@ -221,6 +261,10 @@ def execute( input_fn=input_fn, print_fn=print_fn, context_only=context_only, + status_fn=status_fn, + section_fn=section_fn, + response_fn=response_fn, + update_task_fn=update_task_fn, ) if discovery_result.cancelled: @@ -251,7 +295,17 @@ def execute( config, additional_context, _print, + status_fn=status_fn, ) + + # Add architecture parent node and section children to the task tree + if section_fn: + section_fn([("Generate Architecture", 2)]) + if update_task_fn: + update_task_fn("design-section-generate-architecture", "in_progress") + if section_fn: + section_fn([(s["name"], 3) for s in sections]) + design_output, _usage = self._generate_architecture_sections( ui, agent_context, @@ -260,8 +314,14 @@ def execute( sections, additional_context, _print, + section_fn=section_fn, + update_task_fn=update_task_fn, + status_fn=status_fn, ) + if update_task_fn: + update_task_fn("design-section-generate-architecture", "completed") + # 5. Run supporting IaC review iac_tool = config.get("project.iac_tool", "terraform") if ui: @@ -274,7 +334,7 @@ def execute( design_output, ) else: - _print(f"\nReviewing {iac_tool} feasibility...") + _print(f"\n[bright_cyan]\u2192[/bright_cyan] Reviewing {iac_tool} feasibility...") self._run_iac_review( agent_context, registry, @@ -329,11 +389,17 @@ def execute( ui.print_dim(" az prototype analyze costs # Cost estimate") ui.print_dim(" az prototype build # Generate code") else: - _print(f"Design iteration {design_state['iteration']} complete.") - _print("Architecture docs: concept/docs/ARCHITECTURE.md") - _print("\nTo refine: az prototype design --context 'your changes'") - _print("To estimate costs: az prototype analyze costs") - _print("To proceed: az prototype build") + _print( + f"[bold bright_green]\u2714[/bold bright_green] Design iteration {design_state['iteration']} complete." + ) + _print( + "[bright_cyan]\u2192[/bright_cyan] Architecture docs:" + " [bright_cyan]concept/docs/ARCHITECTURE.md[/bright_cyan]" + ) + _print("\n[dim]Next steps:[/dim]") + _print("[dim] az prototype design --context 'your changes' # Refine[/dim]") + _print("[dim] az prototype analyze costs # Cost estimate[/dim]") + _print("[dim] az prototype build # Generate code[/dim]") return { "status": "success", @@ -493,6 +559,7 @@ def _plan_architecture( config: ProjectConfig, additional_context: str, _print, + status_fn=None, ) -> list[dict]: """Ask the architect for a section plan, return list of section dicts. @@ -560,6 +627,9 @@ def _generate_architecture_sections( sections: list[dict], additional_context: str, _print, + section_fn=None, + update_task_fn=None, + status_fn=None, ) -> tuple[str, dict]: """Generate each architecture section iteratively. @@ -576,9 +646,20 @@ def _generate_architecture_sections( accumulated: list[str] = [] merged_usage: dict[str, int] = {} - for idx, section in enumerate(sections, 1): + # Start cumulative timer for the entire architecture generation + if status_fn: + status_fn("Generating architecture...", "start") + + idx = 0 + while idx < len(sections): + section = sections[idx] section_name = section["name"] section_context = section.get("context", "") + slug = re.sub(r"[^a-z0-9]+", "-", section_name.lower()).strip("-") + task_id = f"design-section-{slug}" + + if update_task_fn: + update_task_fn(task_id, "in_progress") prompt = ( f"## Task\n" @@ -613,11 +694,16 @@ def _generate_architecture_sections( f"## Instructions\n" f'Generate ONLY the "{section_name}" section. Use markdown with a ## heading.\n' f"Ensure consistency with the sections already generated above.\n" - f"Do not repeat content from prior sections." + f"Do not repeat content from prior sections.\n" + f"If while writing this section you determine an additional section is needed " + f"that is not in the architecture plan, include a line at the very end:\n" + f'[NEW_SECTION: {{"name": "Section Name", "context": "Brief description"}}]' ) + section_start = time.monotonic() + spinner_msg = f"Generating architecture ({section_name})..." - if ui: + if ui and not status_fn: with ui.spinner(spinner_msg): response = architect.execute(agent_context, prompt) else: @@ -645,12 +731,33 @@ def _generate_architecture_sections( finish_reason=cont.finish_reason, ) + section_elapsed = time.monotonic() - section_start + elapsed_str = _format_section_elapsed(section_elapsed) + _print(f" {section_name}...Done. ({elapsed_str})") + accumulated.append(response.content) # Merge usage for k, v in response.usage.items(): merged_usage[k] = merged_usage.get(k, 0) + v + if update_task_fn: + update_task_fn(task_id, "completed") + + # Check for dynamically discovered sections + new_sections = _extract_new_sections(response.content) + for ns in new_sections: + if not any(s["name"].lower() == ns["name"].lower() for s in sections): + sections.append(ns) + plan_summary += f"\n- {ns['name']}: {ns.get('context', '')}" + if section_fn: + section_fn([(ns["name"], 3)]) + + idx += 1 + + if status_fn: + status_fn("Generating architecture...", "end") + return "\n\n".join(accumulated), merged_usage # ------------------------------------------------------------------ diff --git a/azext_prototype/stages/discovery.py b/azext_prototype/stages/discovery.py index cb4884b..7e2e6bc 100644 --- a/azext_prototype/stages/discovery.py +++ b/azext_prototype/stages/discovery.py @@ -22,27 +22,152 @@ import logging import re +from contextlib import contextmanager +from dataclasses import dataclass from datetime import date -from typing import Any, Callable +from typing import Any, Callable, Iterator from azext_prototype.agents.base import AgentCapability, AgentContext from azext_prototype.agents.registry import AgentRegistry from azext_prototype.ai.provider import AIMessage from azext_prototype.ai.token_tracker import TokenTracker from azext_prototype.stages.discovery_state import DiscoveryState +from azext_prototype.stages.intent import ( + IntentKind, + build_discovery_classifier, + read_files_for_session, +) from azext_prototype.stages.qa_router import route_error_to_qa from azext_prototype.ui.console import Console, DiscoveryPrompt from azext_prototype.ui.console import console as default_console logger = logging.getLogger(__name__) +# -------------------------------------------------------------------- # +# Section header extraction +# -------------------------------------------------------------------- # + +_SECTION_HEADING_RE = re.compile(r"^#{2,3}\s+(.+?)\s*$", re.MULTILINE) + +# Matches **Bold Heading** on its own line (common in conversational responses) +_BOLD_HEADING_RE = re.compile(r"^\*\*([^*\n]{3,60})\*\*\s*$", re.MULTILINE) + +_SKIP_HEADINGS = frozenset( + { + "summary", + "policy overrides", + "policy override", + "next steps", + "what i've understood so far", + "what we've covered", + "what i've understood", + "what we've established", + } +) + + +def extract_section_headers(response: str) -> list[tuple[str, int]]: + """Extract ## / ### headings and **bold headings** from an AI response. + + Returns a list of ``(heading_text, level)`` tuples sorted by position. + Level 2 = top-level section (``##`` or ``**bold**``), level 3 = subsection (``###``). + + Filters out structural headings (Summary, Policy Overrides, Next Steps, + "What I've Understood So Far", etc.) and very short matches. + """ + matches: list[tuple[int, str, int]] = [] # (position, text, level) + for m in _SECTION_HEADING_RE.finditer(response): + text = m.group(1).strip() + hashes = len(m.group(0)) - len(m.group(0).lstrip("#")) + level = min(hashes, 3) # ## = 2, ### = 3 + matches.append((m.start(), text, level)) + for m in _BOLD_HEADING_RE.finditer(response): + text = m.group(1).strip() + matches.append((m.start(), text, 2)) + matches.sort(key=lambda x: x[0]) + + seen: set[str] = set() + headers: list[tuple[str, int]] = [] + for _, text, level in matches: + lower = text.lower() + if lower in _SKIP_HEADINGS or len(text) < 3 or lower in seen: + continue + seen.add(lower) + headers.append((text, level)) + return headers + + +# -------------------------------------------------------------------- # +# Section parsing — code-level gating for one-at-a-time display +# -------------------------------------------------------------------- # + + +@dataclass +class Section: + """A parsed section from an AI response.""" + + heading: str + level: int # 2=##, 3=### + content: str # text from heading to next heading (includes heading line) + task_id: str # "design-section-{slug}" + + +def parse_sections(response: str) -> tuple[str, list[Section]]: + """Split *response* into ``(preamble, sections)``. + + Preamble = text before the first heading. Sections are filtered by + ``_SKIP_HEADINGS`` (same filter as :func:`extract_section_headers`). + """ + # Collect heading positions + matches: list[tuple[int, str, int]] = [] # (position, text, level) + for m in _SECTION_HEADING_RE.finditer(response): + text = m.group(1).strip() + hashes = len(m.group(0)) - len(m.group(0).lstrip("#")) + level = min(hashes, 3) + matches.append((m.start(), text, level)) + for m in _BOLD_HEADING_RE.finditer(response): + text = m.group(1).strip() + matches.append((m.start(), text, 2)) + matches.sort(key=lambda x: x[0]) + + if not matches: + return response, [] + + preamble = response[: matches[0][0]].strip() + + seen: set[str] = set() + sections: list[Section] = [] + for idx, (pos, text, level) in enumerate(matches): + lower = text.lower() + if lower in _SKIP_HEADINGS or len(text) < 3 or lower in seen: + continue + seen.add(lower) + + # Content runs from this heading to the next heading (or end) + end = matches[idx + 1][0] if idx + 1 < len(matches) else len(response) + content = response[pos:end].strip() + + slug = re.sub(r"[^a-z0-9]+", "-", lower).strip("-") + task_id = f"design-section-{slug}" + sections.append(Section(heading=text, level=level, content=content, task_id=task_id)) + + return preamble, sections + + +# -------------------------------------------------------------------- # +# Section follow-up detection +# -------------------------------------------------------------------- # + +_SECTION_COMPLETE_MARKER = "[SECTION_COMPLETE]" + + # -------------------------------------------------------------------- # # Sentinels # -------------------------------------------------------------------- # # User inputs that end the session _QUIT_WORDS = frozenset({"q", "quit", "exit"}) -_DONE_WORDS = frozenset({"done", "finish", "accept", "lgtm"}) +_DONE_WORDS = frozenset({"done", "end", "finish", "accept", "lgtm", "continue"}) # Slash commands _SLASH_COMMANDS = frozenset({"/open", "/status", "/confirmed", "/help", "/summary", "/restart"}) @@ -139,6 +264,184 @@ def __init__( qa_agents = registry.find_by_capability(AgentCapability.QA) self._qa_agent = qa_agents[0] if qa_agents else None + # Intent classifier for natural language command detection + self._intent_classifier = build_discovery_classifier( + ai_provider=agent_context.ai_provider, + token_tracker=self._token_tracker, + ) + + # ------------------------------------------------------------------ # + # Spinner helper (mirrors build/deploy pattern) + # ------------------------------------------------------------------ # + + @contextmanager + def _maybe_spinner(self, message: str, use_styled: bool, *, status_fn: Callable | None = None) -> Iterator[None]: + """Show a spinner when using styled output, otherwise no-op.""" + if use_styled: + with self._console.spinner(message): + yield + elif status_fn: + status_fn(message, "start") + try: + yield + finally: + status_fn(message, "end") + else: + yield + + # ------------------------------------------------------------------ # + # Display helpers + # ------------------------------------------------------------------ # + + def _show_content(self, content: str, use_styled: bool, _print: Callable) -> None: + """Display content using the appropriate output channel.""" + if use_styled: + self._console.print_agent_response(content) + self._console.print_token_status(self._token_tracker.format_status()) + elif self._response_fn: + self._response_fn(content) + else: + _print(content) + + def _handle_read_files( + self, + args: str, + _print: Callable, + use_styled: bool, + ) -> None: + """Read files into the session and display the AI's analysis.""" + text, images = read_files_for_session(args, self._context.project_dir, _print) + if not (text or images): + return + content: str | list = text + if images: + parts: list[dict] = [] + if text: + parts.append({"type": "text", "text": f"Here are the files I'd like you to review:\n\n{text}"}) + for img in images: + parts.append( + { + "type": "image_url", + "image_url": {"url": f"data:{img['mime']};base64,{img['data']}", "detail": "high"}, + } + ) + content = parts if parts else text + elif text: + content = f"Here are the files I'd like you to review:\n\n{text}" + self._exchange_count += 1 + with self._maybe_spinner("Analyzing files...", use_styled, status_fn=self._status_fn): + response = self._chat(content) + self._discovery_state.update_from_exchange(f"[Read files from {args}]", response, self._exchange_count) + self._extract_items_from_response(response) + clean = self._clean(response) + self._show_content(clean, use_styled, _print) + self._update_token_status() + self._emit_sections(clean) + + # ------------------------------------------------------------------ # + # Section-at-a-time gating + # ------------------------------------------------------------------ # + + _SKIP_WORDS = frozenset({"skip", "next", "move on"}) + + def _run_section_loop( + self, + sections: list[Section], + preamble: str, + _input: Callable[[str], str], + _print: Callable[[str], None], + use_styled: bool, + ) -> str | None: + """Walk sections one at a time. + + Returns ``"cancelled"``, ``"done"``, or ``None`` (all sections covered, + fall through to free-form loop). + """ + if preamble: + self._show_content(preamble, use_styled, _print) + + all_confirmed = True + + for i, section in enumerate(sections): + if self._update_task_fn: + self._update_task_fn(section.task_id, "in_progress") + + self._show_content(section.content, use_styled, _print) + self._update_token_status() + + # Inner follow-up loop (max 5 per section) + section_confirmed = False + for _ in range(5): + try: + user_input = _input("> ").strip() + except (EOFError, KeyboardInterrupt): + if self._update_task_fn: + self._update_task_fn(section.task_id, "completed") + return "done" + + if not user_input: + continue + + lower = user_input.lower() + if lower in _QUIT_WORDS: + return "cancelled" + if lower in _DONE_WORDS: + # Mark remaining sections as completed + for s in sections[i:]: + if self._update_task_fn: + self._update_task_fn(s.task_id, "completed") + return "done" + if lower in self._SKIP_WORDS: + break # Advance to next section + + # Handle slash commands + if lower in _SLASH_COMMANDS: + self._handle_slash_command(lower) + continue + if lower.startswith("/why"): + self._handle_why_command(user_input) + continue + + # Normal answer — send focused follow-up with explicit gate + self._exchange_count += 1 + topic = section.heading + prompt = ( + f"The user answered about **{topic}**: {user_input}\n" + f"Do you have follow-up questions about **{topic}**? " + f'If fully covered, respond ONLY with the word "Yes" ' + f"(meaning yes, this section is complete). " + f"Otherwise, ask your follow-up questions." + ) + with self._maybe_spinner("Thinking...", use_styled, status_fn=self._status_fn): + response = self._chat(prompt) + + self._discovery_state.update_from_exchange(user_input, response, self._exchange_count) + self._extract_items_from_response(response) + + # Check if the AI confirmed the section is complete + stripped = response.strip().rstrip(".").lower() + if stripped == "yes": + section_confirmed = True + break # Section complete — advance + + clean = self._clean(response) + self._show_content(clean, use_styled, _print) + self._update_token_status() + + if not section_confirmed: + all_confirmed = False + + if self._update_task_fn: + self._update_task_fn(section.task_id, "completed") + + # All sections walked + _print("") + if all_confirmed: + _print("All topics covered! Type anything to keep discussing, or 'continue' to proceed.") + else: + _print("Type anything to keep discussing, or 'continue' to proceed.") + return None + # ------------------------------------------------------------------ # # Public API # ------------------------------------------------------------------ # @@ -151,6 +454,10 @@ def run( input_fn: Callable[[str], str] | None = None, print_fn: Callable[[str], None] | None = None, context_only: bool = False, + status_fn: Callable | None = None, + section_fn: Callable[[list[tuple[str, int]]], None] | None = None, + response_fn: Callable[[str], None] | None = None, + update_task_fn: Callable[[str, str], None] | None = None, ) -> DiscoveryResult: """Run the discovery conversation. @@ -180,6 +487,14 @@ def run( # Use injected I/O for tests, otherwise use styled console use_styled = input_fn is None and print_fn is None _input = input_fn or (lambda p: self._prompt.prompt(p)) + _print = print_fn or self._console.print + # Store for use by slash command handlers + self._use_styled = use_styled + self._print = _print + self._status_fn = status_fn + self._section_fn = section_fn + self._response_fn = response_fn + self._update_task_fn = update_task_fn # Load existing discovery state for context existing_context = "" @@ -193,7 +508,10 @@ def run( # ---- Fallback when no agent is available ---- if not self._biz_agent: - self._console.print_warning("No biz-analyst agent available. Enter your requirements:") + if use_styled: + self._console.print_warning("No biz-analyst agent available. Enter your requirements:") + else: + _print("No biz-analyst agent available. Enter your requirements:") try: text = _input("> ") except (EOFError, KeyboardInterrupt): @@ -208,7 +526,7 @@ def run( # ---- Kick off the conversation ---- opening = self._build_opening(seed_context, artifacts, existing_context, images=artifact_images) - with self._console.spinner("Analyzing your input..."): + with self._maybe_spinner("Analyzing your input...", use_styled, status_fn=status_fn): response = self._chat(opening) # Update discovery state with the initial exchange @@ -216,14 +534,48 @@ def run( self._discovery_state.update_from_exchange(opening, response, self._exchange_count) clean_response = self._clean(response) - self._console.print_agent_response(clean_response) - if use_styled: - self._console.print_token_status(self._token_tracker.format_status()) + preamble, sections = parse_sections(clean_response) + + if sections: + # Populate tree with ALL sections upfront + if self._section_fn: + self._section_fn([(s.heading, s.level) for s in sections]) + + # Section-at-a-time loop + outcome = self._run_section_loop(sections, preamble, _input, _print, use_styled) + if outcome == "cancelled": + return DiscoveryResult( + requirements="", + conversation=list(self._messages), + policy_overrides=[], + exchange_count=self._exchange_count, + cancelled=True, + ) + if outcome == "done": + # Jump to summary production + with self._maybe_spinner("Generating requirements summary...", use_styled, status_fn=status_fn): + summary = self._produce_summary() + overrides = self._extract_overrides(summary) + return DiscoveryResult( + requirements=summary, + conversation=list(self._messages), + policy_overrides=overrides, + exchange_count=self._exchange_count, + ) + else: + # No sections → show full response (backward compat / conversational response) + self._show_content(clean_response, use_styled, _print) + self._update_token_status() + if self._section_fn and not extract_section_headers(clean_response): + self._section_fn([("Discovery conversation", 2)]) # ---- Check if agent needs more information ---- # If context_only mode and agent signals READY, skip interactive loop if context_only and _READY_MARKER in response: - self._console.print_info("Context is sufficient. Proceeding with design.") + if use_styled: + self._console.print_info("Context is sufficient. Proceeding with design.") + else: + _print("Context is sufficient. Proceeding with design.") summary = self._produce_summary() overrides = self._extract_overrides(summary) return DiscoveryResult( @@ -247,6 +599,9 @@ def run( ) first_prompt = False else: + if first_prompt: + _print("[dim]Type 'continue' when finished, or 'quit' to cancel.[/dim]") + first_prompt = False user_input = _input("> ").strip() except (EOFError, KeyboardInterrupt): break @@ -254,16 +609,10 @@ def run( if not user_input: continue - # Handle slash commands + # Check quit/done FIRST — before intent classifier to avoid + # a wasteful AI call and ensure reliable exit behavior lower_input = user_input.lower() - if lower_input in _SLASH_COMMANDS: - self._handle_slash_command(lower_input) - continue - if lower_input.startswith("/why"): - self._handle_why_command(user_input) - continue - - if user_input.lower() in _QUIT_WORDS: + if lower_input in _QUIT_WORDS: return DiscoveryResult( requirements="", conversation=list(self._messages), @@ -272,12 +621,32 @@ def run( cancelled=True, ) - if user_input.lower() in _DONE_WORDS: + if lower_input in _DONE_WORDS: break + # Handle slash commands + if lower_input in _SLASH_COMMANDS: + self._handle_slash_command(lower_input) + continue + if lower_input.startswith("/why"): + self._handle_why_command(user_input) + continue + + # Natural language intent detection + intent = self._intent_classifier.classify(user_input) + if intent.kind == IntentKind.COMMAND: + if intent.command == "/why": + self._handle_why_command(f"/why {intent.args}") + else: + self._handle_slash_command(intent.command) + continue + if intent.kind == IntentKind.READ_FILES: + self._handle_read_files(intent.args, _print, use_styled) + continue + self._exchange_count += 1 - with self._console.spinner("Thinking..."): + with self._maybe_spinner("Thinking...", use_styled, status_fn=status_fn): response = self._chat(user_input) # Update discovery state after each exchange @@ -287,13 +656,16 @@ def run( self._extract_items_from_response(response) clean = self._clean(response) - self._console.print_agent_response(clean) - if use_styled: - self._console.print_token_status(self._token_tracker.format_status()) + self._show_content(clean, use_styled, _print) + self._update_token_status() + self._emit_sections(clean) # Agent signalled convergence if _READY_MARKER in response: - self._console.print_info("Discovery complete. Press Enter to proceed, or keep typing.") + if use_styled: + self._console.print_info("Discovery complete. Press Enter to proceed, or keep typing.") + else: + _print("Discovery complete. Press Enter to proceed, or keep typing.") try: if use_styled: more = self._prompt.simple_prompt("> ") @@ -307,16 +679,17 @@ def run( break # User wants to continue self._exchange_count += 1 - with self._console.spinner("Thinking..."): + with self._maybe_spinner("Thinking...", use_styled, status_fn=status_fn): response = self._chat(more) self._discovery_state.update_from_exchange(more, response, self._exchange_count) self._extract_items_from_response(response) - self._console.print_agent_response(self._clean(response)) - if use_styled: - self._console.print_token_status(self._token_tracker.format_status()) + clean_more = self._clean(response) + self._show_content(clean_more, use_styled, _print) + self._update_token_status() + self._emit_sections(clean_more) # ---- Produce the final summary ---- - with self._console.spinner("Generating requirements summary..."): + with self._maybe_spinner("Generating requirements summary...", use_styled, status_fn=status_fn): summary = self._produce_summary() overrides = self._extract_overrides(summary) @@ -576,9 +949,10 @@ def _build_architect_context(self) -> str: "compliance requirements or security policies you need to " "follow.\n" "\n" - "Don't dump all of these at once. Weave them naturally into " - "the conversation as each topic area comes up. But make sure " - "you cover the relevant technical areas before signalling " + "In your initial response, cover ALL relevant technical areas " + "using separate ## headings for each. Ask 2–4 focused questions " + "per area. The system will present them to the user one at a " + "time. Make sure you cover the relevant areas before signalling " "readiness — the architect cannot design without these details." ) @@ -622,88 +996,114 @@ def _produce_summary(self) -> str: def _handle_slash_command(self, command: str) -> None: """Handle slash commands like /open, /status, /confirmed.""" + _p = self._print + styled = self._use_styled if command == "/open": - self._console.print() - self._console.print(self._discovery_state.format_open_items()) - self._console.print() + _p("") + _p(self._discovery_state.format_open_items()) + _p("") elif command == "/confirmed": - self._console.print() - self._console.print(self._discovery_state.format_confirmed_items()) - self._console.print() + _p("") + _p(self._discovery_state.format_confirmed_items()) + _p("") elif command == "/status": - self._console.print() - self._console.print(f"Discovery Status: {self._discovery_state.format_status_summary()}") - self._console.print() + _p("") + _p(f"Discovery Status: {self._discovery_state.format_status_summary()}") + _p("") if self._discovery_state.open_count > 0: - self._console.print(self._discovery_state.format_open_items()) - self._console.print() + _p(self._discovery_state.format_open_items()) + _p("") elif command == "/summary": if not self._biz_agent or not self._context.ai_provider: - self._console.print_warning("No AI agent available for summary.") + if styled: + self._console.print_warning("No AI agent available for summary.") + else: + _p("No AI agent available for summary.") return - self._console.print() - with self._console.spinner("Generating summary..."): + _p("") + with self._maybe_spinner("Generating summary...", styled, status_fn=self._status_fn): summary = self._chat( "Please provide a concise summary of everything we've " "established so far — confirmed requirements, open questions, " "constraints, and key decisions. This is a mid-session " "checkpoint, not the final summary." ) - self._console.print_agent_response(self._clean(summary)) + if styled: + self._console.print_agent_response(self._clean(summary)) + elif self._response_fn: + self._response_fn(self._clean(summary)) + else: + _p(self._clean(summary)) elif command == "/restart": - self._console.print() - self._console.print_warning("Restarting discovery session...") + _p("") + if styled: + self._console.print_warning("Restarting discovery session...") + else: + _p("Restarting discovery session...") self._discovery_state.reset() self._messages.clear() self._exchange_count = 0 if self._biz_agent and self._context.ai_provider: opening = "I'd like to design a new Azure prototype." - with self._console.spinner("Starting fresh..."): + with self._maybe_spinner("Starting fresh...", styled, status_fn=self._status_fn): response = self._chat(opening) self._exchange_count += 1 self._discovery_state.update_from_exchange(opening, response, self._exchange_count) - self._console.print_agent_response(self._clean(response)) + if styled: + self._console.print_agent_response(self._clean(response)) + elif self._response_fn: + self._response_fn(self._clean(response)) + else: + _p(self._clean(response)) elif command == "/help": - self._console.print() - self._console.print_dim("Available commands:") - self._console.print_dim(" /open - List open items needing resolution") - self._console.print_dim(" /confirmed - List confirmed requirements") - self._console.print_dim(" /status - Show overall discovery status") - self._console.print_dim(" /summary - Show a narrative summary of progress so far") - self._console.print_dim(" /why - Find the exchange where a topic was discussed") - self._console.print_dim(" /restart - Clear state and restart discovery from scratch") - self._console.print_dim(" /help - Show this help message") - self._console.print_dim(" done - Complete discovery and proceed to design") - self._console.print_dim(" quit - Cancel and exit") - self._console.print() + _p("") + _p("Available commands:") + _p(" /open - List open items needing resolution") + _p(" /confirmed - List confirmed requirements") + _p(" /status - Show overall discovery status") + _p(" /summary - Show a narrative summary of progress so far") + _p(" /why - Find the exchange where a topic was discussed") + _p(" /restart - Clear state and restart discovery from scratch") + _p(" /help - Show this help message") + _p(" done - Complete discovery and proceed to design") + _p(" quit - Cancel and exit") + _p("") + _p(" You can also use natural language:") + _p(" 'what are the open items' instead of /open") + _p(" 'where do we stand' instead of /status") + _p(" 'give me a summary' instead of /summary") + _p(" 'why did we choose Cosmos DB' instead of /why Cosmos DB") + _p(" 'read artifacts from ./specs' reads files into the session") + _p("") def _handle_why_command(self, raw_input: str) -> None: """Handle ``/why `` — find the exchange where a topic was discussed.""" + _p = self._print query = raw_input[4:].strip() if not query: - self._console.print() - self._console.print_dim("Usage: /why ") - self._console.print_dim(" Example: /why managed identity") - self._console.print() + _p("") + _p("Usage: /why ") + _p(" Example: /why managed identity") + _p("") return matches = self._discovery_state.search_history(query) - self._console.print() + _p("") if not matches: - self._console.print_dim(f"No exchanges found mentioning '{query}'.") + _p(f"No exchanges found mentioning '{query}'.") else: - self._console.print_dim(f"Found {len(matches)} exchange(s) mentioning '{query}':") - self._console.print() + _p(f"Found {len(matches)} exchange(s) mentioning '{query}':") + _p("") for m in matches: - self._console.print_dim(f" Exchange {m['exchange']}:") + _p(f" Exchange {m['exchange']}:") user_text = m.get("user", "") asst_text = m.get("assistant", "") user_snippet = user_text[:150] + ("..." if len(user_text) > 150 else "") asst_snippet = asst_text[:150] + ("..." if len(asst_text) > 150 else "") - self._console.print_dim(f" You: {user_snippet}") - self._console.print_dim(f" Agent: {asst_snippet}") - self._console.print() - self._console.print() + _p(f" You: {user_snippet}") + _p(f" Agent: {asst_snippet}") + _p("") + _p("") def _extract_items_from_response(self, response: str) -> None: """Extract open questions and confirmed items from agent response. @@ -743,10 +1143,33 @@ def _extract_items_from_response(self, response: str) -> None: # Internal — helpers # ------------------------------------------------------------------ # + def _emit_sections(self, response: str) -> None: + """Notify section_fn callback with any headings found in *response*.""" + if not self._section_fn: + return + headers = extract_section_headers(response) + if headers: + self._section_fn(headers) + + def _update_token_status(self) -> None: + """Push token usage to the TUI status bar via ``status_fn("tokens")``. + + Always pushes an update after an AI call — if the provider didn't + return usage data, shows a turn counter instead of leaving the + elapsed timer stuck. + """ + if self._status_fn: + token_text = self._token_tracker.format_status() + if not token_text: + turns = self._token_tracker.turn_count + token_text = f"Turn {turns}" if turns > 0 else "" + if token_text: + self._status_fn(token_text, "tokens") + @staticmethod def _clean(text: str) -> str: - """Strip the ``[READY]`` marker so the user sees natural text.""" - return text.replace(_READY_MARKER, "").strip() + """Strip invisible markers so the user sees natural text.""" + return text.replace(_READY_MARKER, "").replace(_SECTION_COMPLETE_MARKER, "").strip() @staticmethod def _extract_overrides(summary: str) -> list[dict[str, str]]: diff --git a/azext_prototype/stages/intent.py b/azext_prototype/stages/intent.py new file mode 100644 index 0000000..c152ae3 --- /dev/null +++ b/azext_prototype/stages/intent.py @@ -0,0 +1,780 @@ +"""Natural language intent classification for interactive sessions. + +Provides a two-tier classifier: + +1. **AI-powered** (primary) — when an AI provider is available, sends a + short classification prompt listing the session's available commands. + Uses low temperature (0.0) and low max_tokens (150) for fast, + deterministic responses. +2. **Keyword/regex fallback** — when no AI provider is available (or the + AI call fails), keyword/phrase/regex scoring runs as a zero-latency + fallback. + +Each session registers its own command definitions via factory functions. +The classifier picks AI or fallback automatically. +""" + +from __future__ import annotations + +import json +import logging +import re +from dataclasses import dataclass, field +from enum import Enum +from pathlib import Path +from typing import Any, Callable + +logger = logging.getLogger(__name__) + + +# -------------------------------------------------------------------- # +# Types +# -------------------------------------------------------------------- # + + +class IntentKind(Enum): + """Classification of user input.""" + + COMMAND = "command" + READ_FILES = "read_files" + CONVERSATIONAL = "conversational" + + +@dataclass +class IntentResult: + """Result of classifying user input.""" + + kind: IntentKind + command: str = "" + args: str = "" + original_input: str = "" + confidence: float = 0.0 + + +@dataclass +class CommandDef: + """Definition of a command for AI classification prompt.""" + + command: str + description: str + has_args: bool = False + arg_description: str = "" + + +@dataclass +class IntentPattern: + """Keyword/regex pattern for fallback classification.""" + + command: str + keywords: list[str] = field(default_factory=list) + phrases: list[str] = field(default_factory=list) + regex_patterns: list[str] = field(default_factory=list) + arg_extractor: Callable[[str], str] | None = None + min_confidence: float = 0.5 + + +# -------------------------------------------------------------------- # +# File-read regex — cross-session +# -------------------------------------------------------------------- # + +_READ_FILE_RE = re.compile( + r"(?:read|load|import)\s+(?:artifacts?|files?|documents?)\s+from\s+(.+)", + re.IGNORECASE, +) + + +# -------------------------------------------------------------------- # +# IntentClassifier +# -------------------------------------------------------------------- # + + +class IntentClassifier: + """Two-tier intent classifier: AI-first with keyword fallback. + + Parameters + ---------- + ai_provider: + Optional AI provider for AI-powered classification. + token_tracker: + Optional token tracker for recording classification costs. + """ + + def __init__( + self, + ai_provider: Any = None, + token_tracker: Any = None, + ) -> None: + self._ai_provider = ai_provider + self._token_tracker = token_tracker + self._patterns: list[IntentPattern] = [] + self._command_defs: list[CommandDef] = [] + + def register(self, pattern: IntentPattern) -> None: + """Register a keyword/regex fallback pattern.""" + self._patterns.append(pattern) + + def register_many(self, patterns: list[IntentPattern]) -> None: + """Register multiple keyword/regex fallback patterns.""" + self._patterns.extend(patterns) + + def add_command_def(self, cmd_def: CommandDef) -> None: + """Add a command definition for the AI classification prompt.""" + self._command_defs.append(cmd_def) + + def add_command_defs(self, defs: list[CommandDef]) -> None: + """Add multiple command definitions.""" + self._command_defs.extend(defs) + + # ------------------------------------------------------------------ # + # Public — classify + # ------------------------------------------------------------------ # + + def classify(self, user_input: str) -> IntentResult: + """Classify user input as COMMAND, READ_FILES, or CONVERSATIONAL. + + 1. Explicit slash commands (``/...``) → CONVERSATIONAL (pass-through) + 2. File-read regex → READ_FILES + 3. Keyword/regex scoring (fast, no network) → COMMAND if confident + 4. AI classification (if available, keywords uncertain) → COMMAND or CONVERSATIONAL + """ + if not user_input or not user_input.strip(): + return IntentResult(kind=IntentKind.CONVERSATIONAL, original_input=user_input) + + stripped = user_input.strip() + + # 1. Explicit slash commands — let sessions handle directly + if stripped.startswith("/"): + return IntentResult(kind=IntentKind.CONVERSATIONAL, original_input=user_input) + + # 2. File-read detection + m = _READ_FILE_RE.search(stripped) + if m: + path_str = m.group(1).strip().strip("'\"") + return IntentResult( + kind=IntentKind.READ_FILES, + command="__read_files", + args=path_str, + original_input=user_input, + confidence=0.9, + ) + + # 3. Keyword/regex scoring (fast path — no API call) + keyword_result = self._classify_with_keywords(stripped) + if keyword_result.kind == IntentKind.COMMAND: + return keyword_result + + # 4. AI classification — only when keywords had SOME signal + # (confidence > 0 means some keywords matched but not enough). + # When confidence is 0.0, no keywords matched at all, so the + # input is almost certainly conversational — skip the AI call. + if self._ai_provider and self._command_defs and keyword_result.confidence > 0: + ai_result = self._classify_with_ai(stripped) + if ai_result is not None: + return ai_result + + return keyword_result + + # ------------------------------------------------------------------ # + # Internal — AI classification + # ------------------------------------------------------------------ # + + def _classify_with_ai(self, user_input: str) -> IntentResult | None: + """Use the AI provider to classify the input. + + Returns None on any error, allowing fallback to keyword scoring. + """ + from azext_prototype.ai.provider import AIMessage + + system_prompt = self._build_classification_prompt() + messages = [ + AIMessage(role="system", content=system_prompt), + AIMessage(role="user", content=user_input), + ] + + try: + response = self._ai_provider.chat( + messages, + temperature=0.0, + max_tokens=150, + ) + if self._token_tracker: + self._token_tracker.record(response) + + return self._parse_ai_response(response.content, user_input) + except Exception: + logger.debug("AI classification failed, falling back to keywords", exc_info=True) + return None + + def _build_classification_prompt(self) -> str: + """Build the system prompt listing available commands.""" + lines = [ + "You are a command classifier. Given user input, determine if it " + "maps to one of these commands or is conversational input for the " + "AI assistant.", + "", + "Available commands:", + ] + + for cmd_def in self._command_defs: + if cmd_def.has_args: + lines.append(f"- {cmd_def.command} <{cmd_def.arg_description}> — {cmd_def.description}") + else: + lines.append(f"- {cmd_def.command} — {cmd_def.description}") + + lines.extend( + [ + "(Plus special: __prompt_context — user wants to provide new context/files)", + "(Plus special: __read_files — user wants to read files from a path)", + "", + 'Respond with JSON only: {"command": "/open", "args": "", "is_command": true}', + "If the input is conversational (design feedback, questions, etc.), " + 'respond: {"command": "", "args": "", "is_command": false}', + ] + ) + + return "\n".join(lines) + + def _parse_ai_response(self, content: str, original_input: str) -> IntentResult | None: + """Parse the AI's JSON response into an IntentResult.""" + text = content.strip() + # Strip markdown fences if present + if text.startswith("```"): + lines = text.split("\n") + lines = [ln for ln in lines if not ln.strip().startswith("```")] + text = "\n".join(lines).strip() + + try: + data = json.loads(text) + except json.JSONDecodeError: + logger.debug("Could not parse AI classification response: %s", text[:100]) + return None + + if not isinstance(data, dict): + return None + + is_command = data.get("is_command", False) + if not is_command: + return IntentResult( + kind=IntentKind.CONVERSATIONAL, + original_input=original_input, + confidence=0.8, + ) + + command = data.get("command", "") + args = data.get("args", "") + + if not command: + return IntentResult( + kind=IntentKind.CONVERSATIONAL, + original_input=original_input, + confidence=0.5, + ) + + return IntentResult( + kind=IntentKind.COMMAND, + command=command, + args=str(args), + original_input=original_input, + confidence=0.9, + ) + + # ------------------------------------------------------------------ # + # Internal — keyword/regex fallback + # ------------------------------------------------------------------ # + + def _classify_with_keywords(self, user_input: str) -> IntentResult: + """Score registered patterns against user input.""" + lower = user_input.lower() + best_score = 0.0 + best_pattern: IntentPattern | None = None + + for pattern in self._patterns: + score = 0.0 + + # Keyword scoring: +0.2 each + for kw in pattern.keywords: + if kw.lower() in lower: + score += 0.2 + + # Phrase scoring: +0.4 each + for phrase in pattern.phrases: + if phrase.lower() in lower: + score += 0.4 + + # Regex scoring: +0.6 each + for rx in pattern.regex_patterns: + if re.search(rx, user_input, re.IGNORECASE): + score += 0.6 + + score = min(score, 1.0) + + if score > best_score: + best_score = score + best_pattern = pattern + + if best_pattern and best_score >= best_pattern.min_confidence: + args = "" + if best_pattern.arg_extractor: + args = best_pattern.arg_extractor(user_input) + + return IntentResult( + kind=IntentKind.COMMAND, + command=best_pattern.command, + args=args, + original_input=user_input, + confidence=best_score, + ) + + # Return the actual best_score even when below threshold — this + # allows the caller to detect partial keyword signal and decide + # whether to try AI classification. + return IntentResult( + kind=IntentKind.CONVERSATIONAL, + original_input=user_input, + confidence=best_score, + ) + + +# -------------------------------------------------------------------- # +# Arg extractors +# -------------------------------------------------------------------- # + + +def _extract_stage_numbers(text: str) -> str: + """Extract stage numbers from text like 'stage 3' or 'stages 3 and 4'.""" + numbers = re.findall(r"\d+", text) + return " ".join(numbers) + + +def _extract_why_args(text: str) -> str: + """Extract the topic from 'why did we choose X' style input.""" + # Remove common prefixes + cleaned = re.sub( + r"^(?:why\s+(?:did\s+we\s+)?(?:choose|pick|select|use|go\s+with)?)\s*", + "", + text, + flags=re.IGNORECASE, + ).strip(" ?") + return cleaned + + +def _extract_show_number(text: str) -> str: + """Extract item number from 'show item 3' style input.""" + numbers = re.findall(r"\d+", text) + return numbers[0] if numbers else "" + + +# -------------------------------------------------------------------- # +# Factory functions — per-session classifiers +# -------------------------------------------------------------------- # + + +def build_discovery_classifier( + ai_provider: Any = None, + token_tracker: Any = None, +) -> IntentClassifier: + """Build an intent classifier for the discovery session.""" + c = IntentClassifier(ai_provider=ai_provider, token_tracker=token_tracker) + + # Command definitions (for AI prompt) + c.add_command_defs( + [ + CommandDef("/open", "Show open items needing resolution"), + CommandDef("/confirmed", "Show confirmed requirements"), + CommandDef("/status", "Show discovery progress"), + CommandDef("/summary", "Generate a narrative summary"), + CommandDef("/why", "Search for when a topic was discussed", has_args=True, arg_description="topic"), + CommandDef("/restart", "Clear state and restart discovery"), + ] + ) + + # Keyword/regex fallback patterns + c.register_many( + [ + IntentPattern( + command="/open", + keywords=["open"], + phrases=["open items", "open questions", "what's open", "unresolved"], + regex_patterns=[r"what(?:'s| are| is) (?:the )?open", r"what(?:'s| is) (?:still )?unresolved"], + ), + IntentPattern( + command="/confirmed", + keywords=["confirmed"], + phrases=["confirmed requirements", "what's confirmed", "confirmed items"], + regex_patterns=[r"what(?:'s| are| is) confirmed"], + ), + IntentPattern( + command="/status", + keywords=[], + phrases=["where do we stand", "what's the status", "discovery status", "how far along"], + regex_patterns=[r"(?:where|how)\s+(?:do\s+we|are\s+we)\s+stand"], + ), + IntentPattern( + command="/summary", + keywords=[], + phrases=["give me a summary", "summarize", "show summary"], + regex_patterns=[r"(?:give|show|generate)\s+(?:me\s+)?a?\s*summary"], + ), + IntentPattern( + command="/why", + keywords=[], + phrases=[], + regex_patterns=[r"why\s+did\s+we\s+(?:choose|pick|select|use|go\s+with)"], + arg_extractor=_extract_why_args, + ), + IntentPattern( + command="/restart", + keywords=[], + phrases=["start over", "restart", "start from scratch", "begin again"], + regex_patterns=[r"(?:start|begin)\s+(?:over|from\s+scratch|again)"], + ), + IntentPattern( + command="__prompt_context", + keywords=[], + phrases=["i have new context", "i have some context", "let me provide context"], + regex_patterns=[r"i\s+have\s+(?:new|some|additional)\s+context"], + ), + ] + ) + + return c + + +def build_build_classifier( + ai_provider: Any = None, + token_tracker: Any = None, +) -> IntentClassifier: + """Build an intent classifier for the build session.""" + c = IntentClassifier(ai_provider=ai_provider, token_tracker=token_tracker) + + c.add_command_defs( + [ + CommandDef("/status", "Show stage completion summary"), + CommandDef("/stages", "Show full deployment plan"), + CommandDef("/files", "List all generated files"), + CommandDef("/policy", "Show policy check summary"), + CommandDef("/describe", "Show detailed description of a stage", has_args=True, arg_description="N"), + ] + ) + + c.register_many( + [ + IntentPattern( + command="/status", + keywords=[], + phrases=["build status", "what's the status", "how's the build"], + regex_patterns=[r"what(?:'s| is)\s+the\s+(?:build\s+)?status"], + ), + IntentPattern( + command="/stages", + keywords=[], + phrases=["show stages", "list stages", "deployment plan"], + regex_patterns=[r"(?:show|list|display)\s+(?:the\s+)?stages"], + ), + IntentPattern( + command="/files", + keywords=[], + phrases=["generated files", "show files", "list files", "what files"], + regex_patterns=[ + r"(?:show|list|display)\s+(?:me\s+)?(?:the\s+)?(?:generated\s+)?files", + r"what(?:'s| are)\s+(?:the\s+)?(?:generated\s+)?files", + ], + ), + IntentPattern( + command="/policy", + keywords=[], + phrases=["policy status", "policy check", "policy summary"], + regex_patterns=[r"(?:show|check)\s+(?:the\s+)?polic(?:y|ies)"], + ), + IntentPattern( + command="/describe", + keywords=[], + phrases=[], + regex_patterns=[ + r"describe\s+stage\s+\d+", + r"what(?:'s| is)\s+(?:in|being\s+built\s+in)\s+stage\s+\d+", + r"show\s+(?:me\s+)?stage\s+\d+\s+details?", + ], + arg_extractor=_extract_stage_numbers, + ), + ] + ) + + return c + + +def build_deploy_classifier( + ai_provider: Any = None, + token_tracker: Any = None, +) -> IntentClassifier: + """Build an intent classifier for the deploy session.""" + c = IntentClassifier(ai_provider=ai_provider, token_tracker=token_tracker) + + c.add_command_defs( + [ + CommandDef("/deploy", "Deploy stage N or all pending stages", has_args=True, arg_description="N|all"), + CommandDef("/rollback", "Roll back stage N or all", has_args=True, arg_description="N|all"), + CommandDef("/redeploy", "Rollback + redeploy stage N", has_args=True, arg_description="N"), + CommandDef("/plan", "Show what-if/terraform plan for stage N", has_args=True, arg_description="N"), + CommandDef("/outputs", "Show captured deployment outputs"), + CommandDef("/preflight", "Re-run preflight checks"), + CommandDef("/login", "Run az login interactively"), + CommandDef("/status", "Show deployment progress per stage"), + CommandDef("/describe", "Show detailed description of a stage", has_args=True, arg_description="N"), + ] + ) + + c.register_many( + [ + IntentPattern( + command="/deploy", + keywords=[], + phrases=[], + regex_patterns=[ + r"deploy\s+(?:stage\s+)?\d+", + r"deploy\s+(?:all\s+)?(?:pending\s+)?stages", + r"deploy\s+stages?\s+\d+(?:\s+and\s+\d+)*", + r"deploy\s+all", + ], + arg_extractor=lambda t: ( + _extract_stage_numbers(t) or "all" + if re.search(r"\ball\b", t, re.IGNORECASE) + else _extract_stage_numbers(t) + ), + ), + IntentPattern( + command="/rollback", + keywords=[], + phrases=[], + regex_patterns=[ + r"rollback\s+(?:stage\s+)?\d+", + r"roll\s+back\s+(?:stage\s+)?\d+", + r"rollback\s+all", + r"roll\s+back\s+all", + r"undo\s+(?:stage\s+)?\d+", + r"undo\s+(?:the\s+)?deploy", + ], + arg_extractor=lambda t: "all" if re.search(r"\ball\b", t, re.IGNORECASE) else _extract_stage_numbers(t), + ), + IntentPattern( + command="/redeploy", + keywords=[], + phrases=[], + regex_patterns=[ + r"redeploy\s+(?:stage\s+)?\d+", + r"re-deploy\s+(?:stage\s+)?\d+", + ], + arg_extractor=_extract_stage_numbers, + ), + IntentPattern( + command="/plan", + keywords=[], + phrases=["show plan", "what-if", "terraform plan"], + regex_patterns=[ + r"(?:show\s+)?plan\s+(?:for\s+)?stage\s+\d+", + r"what.?if\s+(?:for\s+)?stage\s+\d+", + ], + arg_extractor=_extract_stage_numbers, + ), + IntentPattern( + command="/outputs", + keywords=[], + phrases=["deployment outputs", "show outputs", "captured outputs"], + regex_patterns=[r"(?:show|display|list)\s+(?:the\s+)?(?:deployment\s+)?outputs"], + ), + IntentPattern( + command="/preflight", + keywords=[], + phrases=["run preflight", "preflight checks", "check prerequisites"], + regex_patterns=[r"(?:run|re-?run)\s+preflight"], + ), + IntentPattern( + command="/login", + keywords=[], + phrases=["az login", "azure login", "log in"], + regex_patterns=[r"(?:az|azure)\s+login"], + ), + IntentPattern( + command="/status", + keywords=[], + phrases=["deployment status", "deploy status", "what's deployed"], + regex_patterns=[ + r"what(?:'s| is)\s+(?:the\s+)?deploy(?:ment)?\s+status", + r"what(?:'s| is)\s+deployed", + ], + ), + IntentPattern( + command="/describe", + keywords=[], + phrases=[], + regex_patterns=[ + r"describe\s+stage\s+\d+", + r"what(?:'s| is)\s+(?:in|being\s+deployed\s+in)\s+stage\s+\d+", + r"show\s+(?:me\s+)?stage\s+\d+\s+details?", + ], + arg_extractor=_extract_stage_numbers, + ), + ] + ) + + return c + + +def build_backlog_classifier( + ai_provider: Any = None, + token_tracker: Any = None, +) -> IntentClassifier: + """Build an intent classifier for the backlog session. + + Intentionally omits ``/add`` — "add a story about X" should fall + through to the AI mutation path. + """ + c = IntentClassifier(ai_provider=ai_provider, token_tracker=token_tracker) + + c.add_command_defs( + [ + CommandDef("/list", "Show all items grouped by epic"), + CommandDef("/show", "Show item N with full details", has_args=True, arg_description="N"), + CommandDef("/remove", "Remove item N", has_args=True, arg_description="N"), + CommandDef("/preview", "Show what will be pushed"), + CommandDef("/save", "Save to concept/docs/BACKLOG.md"), + CommandDef("/push", "Push all pending items or item N", has_args=True, arg_description="N"), + CommandDef("/status", "Show push status per item"), + ] + ) + + c.register_many( + [ + IntentPattern( + command="/list", + keywords=[], + phrases=["show all items", "list items", "list all", "show backlog", "show items"], + regex_patterns=[r"(?:show|list|display)\s+(?:all\s+)?(?:the\s+)?(?:backlog\s+)?items"], + ), + IntentPattern( + command="/show", + keywords=[], + phrases=[], + regex_patterns=[ + r"show\s+(?:me\s+)?item\s+\d+", + r"show\s+(?:me\s+)?(?:story|feature)\s+\d+", + r"details?\s+(?:for|of|on)\s+(?:item|story)\s+\d+", + ], + arg_extractor=_extract_show_number, + ), + IntentPattern( + command="/remove", + keywords=[], + phrases=[], + regex_patterns=[ + r"remove\s+item\s+\d+", + r"delete\s+item\s+\d+", + r"remove\s+(?:story|feature)\s+\d+", + ], + arg_extractor=_extract_show_number, + ), + IntentPattern( + command="/preview", + keywords=[], + phrases=["preview", "show preview", "what will be pushed"], + regex_patterns=[r"(?:show\s+)?(?:the\s+)?preview"], + ), + IntentPattern( + command="/save", + keywords=[], + phrases=["save backlog", "save to file", "save to markdown"], + regex_patterns=[r"save\s+(?:the\s+)?(?:backlog|items)"], + ), + IntentPattern( + command="/push", + keywords=[], + phrases=[], + regex_patterns=[ + r"push\s+(?:item\s+)?\d+", + r"push\s+(?:all|items|everything)", + r"create\s+(?:the\s+)?(?:issues?|work\s+items?)", + ], + arg_extractor=_extract_show_number, + ), + IntentPattern( + command="/status", + keywords=[], + phrases=["push status", "item status", "what's been pushed"], + regex_patterns=[r"what(?:'s| is)\s+(?:the\s+)?(?:push\s+)?status"], + ), + ] + ) + + return c + + +# -------------------------------------------------------------------- # +# File reading helper +# -------------------------------------------------------------------- # + + +def read_files_for_session( + path_str: str, + project_dir: str, + print_fn: Callable[[str], None], +) -> tuple[str, list[dict]]: + """Read files from a path for mid-session injection. + + Returns ``(text_content, images_list)`` where images_list contains + dicts with ``filename``, ``data``, ``mime`` keys for vision API use. + """ + from azext_prototype.parsers.binary_reader import read_file + + # Expand ~ and resolve relative paths + path = Path(path_str).expanduser() + if not path.is_absolute(): + path = Path(project_dir) / path + + if not path.exists(): + print_fn(f" Path not found: {path}") + return "", [] + + text_parts: list[str] = [] + images: list[dict] = [] + + if path.is_file(): + files = [path] + else: + # Read all files in directory (non-recursive, skip hidden) + files = sorted(f for f in path.iterdir() if f.is_file() and not f.name.startswith(".")) + + for file_path in files: + result = read_file(file_path) + if result.error: + print_fn(f" Could not read {file_path.name}: {result.error}") + continue + + if result.text: + text_parts.append(f"## {file_path.name}\n\n{result.text}") + + if result.image_data and result.mime_type: + images.append( + { + "filename": file_path.name, + "data": result.image_data, + "mime": result.mime_type, + } + ) + + for emb in result.embedded_images: + images.append( + { + "filename": f"{file_path.name}:{emb.source}", + "data": emb.data, + "mime": emb.mime_type, + } + ) + + if text_parts: + print_fn(f" Read {len(text_parts)} file(s) from {path_str}") + elif images: + print_fn(f" Read {len(images)} image(s) from {path_str}") + else: + print_fn(f" No readable files found in {path_str}") + + return "\n\n---\n\n".join(text_parts), images diff --git a/azext_prototype/templates/__init__.py b/azext_prototype/templates/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/azext_prototype/templates/docs/__init__.py b/azext_prototype/templates/docs/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/azext_prototype/templates/workloads/__init__.py b/azext_prototype/templates/workloads/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/azext_prototype/ui/__init__.py b/azext_prototype/ui/__init__.py index 93949dc..21683e2 100644 --- a/azext_prototype/ui/__init__.py +++ b/azext_prototype/ui/__init__.py @@ -4,6 +4,8 @@ - Progress indicators for file operations and API calls - Claude Code-inspired styling with borders and colors - Styled prompts with instructions + +And a Textual-based TUI dashboard for interactive sessions. """ from azext_prototype.ui.console import ( diff --git a/azext_prototype/ui/app.py b/azext_prototype/ui/app.py new file mode 100644 index 0000000..a8231aa --- /dev/null +++ b/azext_prototype/ui/app.py @@ -0,0 +1,136 @@ +"""Textual TUI application — main dashboard for interactive sessions. + +Layout:: + + ┌─────────────────────────────┬──────────────────┐ + │ Console (RichLog) │ Tasks (Tree) │ + │ scrollable, new at bottom │ collapsible │ + ├─────────────────────────────┴──────────────────┤ + │ Prompt (TextArea) │ + ├───────────────────────────┬────────────────────┤ + │ Assist (left 50%) │ Status (right 50%) │ + └───────────────────────────┴────────────────────┘ + +Sessions run on ``@work(thread=True)`` workers. ``call_from_thread`` +is used to schedule widget updates on the main event loop. +""" + +from __future__ import annotations + +from textual import on +from textual.app import App, ComposeResult +from textual.containers import Horizontal + +from azext_prototype.ui.task_model import TaskStore +from azext_prototype.ui.theme import APP_CSS +from azext_prototype.ui.tui_adapter import TUIAdapter +from azext_prototype.ui.widgets.console_view import ConsoleView +from azext_prototype.ui.widgets.info_bar import InfoBar +from azext_prototype.ui.widgets.prompt_input import PromptInput +from azext_prototype.ui.widgets.task_tree import TaskTree + + +class PrototypeApp(App): + """Main TUI application for ``az prototype`` interactive sessions.""" + + CSS = APP_CSS + + BINDINGS = [ + ("ctrl+c", "quit", "Quit"), + ] + + def __init__( + self, + store: TaskStore | None = None, + start_stage: str | None = None, + project_dir: str | None = None, + stage_kwargs: dict | None = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self._store = store or TaskStore() + self._start_stage = start_stage + self._project_dir = project_dir + self._stage_kwargs = stage_kwargs or {} + self.adapter = TUIAdapter(self) + + # ------------------------------------------------------------------ # + # Compose + # ------------------------------------------------------------------ # + + def compose(self) -> ComposeResult: + with Horizontal(id="body"): + yield ConsoleView(id="console-view") + yield TaskTree(store=self._store, id="task-tree") + yield PromptInput(id="prompt-input") + yield InfoBar(id="info-bar") + + # ------------------------------------------------------------------ # + # Widget accessors + # ------------------------------------------------------------------ # + + @property + def console_view(self) -> ConsoleView: + return self.query_one("#console-view", ConsoleView) + + @property + def task_tree(self) -> TaskTree: + return self.query_one("#task-tree", TaskTree) + + @property + def prompt_input(self) -> PromptInput: + return self.query_one("#prompt-input", PromptInput) + + @property + def info_bar(self) -> InfoBar: + return self.query_one("#info-bar", InfoBar) + + # ------------------------------------------------------------------ # + # Lifecycle + # ------------------------------------------------------------------ # + + def on_mount(self) -> None: + """Set up the initial state after widgets are mounted.""" + self.title = "az prototype" + self.info_bar.update_assist("Enter = submit | Ctrl+J = newline | Ctrl+C = quit") + self.prompt_input.disable() + + # Write a welcome banner + self.console_view.write_info("Welcome to az prototype") + self.console_view.write_dim("") + + # Auto-start the orchestrator if project_dir is set + if self._project_dir: + self.start_orchestrator() + + def start_orchestrator(self) -> None: + """Launch the stage orchestrator on a worker thread.""" + from azext_prototype.ui.stage_orchestrator import StageOrchestrator + + def _run() -> None: + orchestrator = StageOrchestrator( + self, + self.adapter, + self._project_dir or ".", + stage_kwargs=self._stage_kwargs, + ) + orchestrator.run(start_stage=self._start_stage) + + self.run_worker(_run, thread=True) + + # ------------------------------------------------------------------ # + # Shutdown + # ------------------------------------------------------------------ # + + def on_unmount(self) -> None: + """Signal the adapter so worker threads unblock and exit.""" + self.adapter.shutdown() + + # ------------------------------------------------------------------ # + # Event handlers + # ------------------------------------------------------------------ # + + @on(PromptInput.Submitted) + def _on_prompt_submitted(self, event: PromptInput.Submitted) -> None: + """Route prompt submissions to the adapter.""" + self.adapter.on_prompt_submitted(event.value) diff --git a/azext_prototype/ui/console.py b/azext_prototype/ui/console.py index b04f2b4..0141f06 100644 --- a/azext_prototype/ui/console.py +++ b/azext_prototype/ui/console.py @@ -33,62 +33,10 @@ TextColumn, TimeElapsedColumn, ) -from rich.theme import Theme -# -------------------------------------------------------------------- # -# Color scheme (Claude Code inspired) -# -------------------------------------------------------------------- # +from azext_prototype.ui.theme import COLORS, PT_STYLE_DICT, RICH_THEME -THEME = Theme( - { - # Background/secondary text - "dim": "#888888", - "muted": "#666666", - # Primary content — bright for readability on dark terminals - "content": "bright_white", - # Callouts and highlights - "success": "bright_green", - "error": "bright_red", - "warning": "bright_yellow", - "info": "bright_cyan", - "accent": "bright_magenta", # Purple-ish for emphasis - # Prompt styling — #555555 is visible on dark terminals without being bright - "prompt.border": "#555555", - "prompt.instruction": "bright_cyan", - "prompt.input": "bright_white", - # Progress indicators - "progress.description": "bright_white", - "progress.percentage": "bright_cyan", - "progress.bar.complete": "bright_green", - "progress.bar.finished": "bright_green", - # Agent/stage names - "agent": "bright_magenta bold", - "stage": "bright_cyan bold", - # File paths - "path": "bright_cyan", - # Markdown rendering — colour hierarchy for contrast on dark terminals: - # Headers: coloured + bold (stand out from body text) - # Body: white (readable but not blinding) - # Bold: bright_white bold (pops against regular white body) - # Code: green (distinct from prose) - # Markers: cyan (numbered/bullet prefixes contrast with white text) - "markdown.paragraph": "white", - "markdown.h1": "bright_magenta bold underline", - "markdown.h2": "bright_magenta bold", - "markdown.h3": "bright_cyan bold", - "markdown.h4": "bright_cyan italic", - "markdown.bold": "bright_white bold", - "markdown.italic": "white italic", - "markdown.code": "bright_green", - "markdown.code_block": "bright_green", - "markdown.block_quote": "bright_yellow italic", - "markdown.item.bullet": "bright_cyan", - "markdown.item.number": "bright_cyan", - "markdown.link": "bright_cyan underline", - "markdown.link_url": "bright_cyan", - "markdown.hr": "#555555", - } -) +THEME = RICH_THEME # -------------------------------------------------------------------- # # Markdown preprocessing @@ -113,16 +61,7 @@ def _preprocess_markdown(content: str) -> str: # prompt_toolkit style for the input area -PT_STYLE = PTStyle.from_dict( - { - "prompt": "#888888", - "": "#ffffff", # Default text color - # Toolbar below the prompt — noreverse prevents the default white-on-white - # inversion that prompt_toolkit applies. Keep text dim like Claude Code's - # status bar. - "bottom-toolbar": "noreverse #888888", - } -) +PT_STYLE = PTStyle.from_dict(PT_STYLE_DICT) class Console: @@ -481,9 +420,9 @@ def prompt( def _toolbar(): cols = shutil.get_terminal_size().columns return [ - ("#555555", "─" * cols), + (COLORS["border"], "─" * cols), ("", "\n"), - ("#888888", hint), + (COLORS["dim"], hint), ] toolbar = _toolbar @@ -491,7 +430,7 @@ def _toolbar(): def _toolbar_border_only(): cols = shutil.get_terminal_size().columns - return [("#555555", "─" * cols)] + return [(COLORS["border"], "─" * cols)] toolbar = _toolbar_border_only diff --git a/azext_prototype/ui/stage_orchestrator.py b/azext_prototype/ui/stage_orchestrator.py new file mode 100644 index 0000000..66f56c9 --- /dev/null +++ b/azext_prototype/ui/stage_orchestrator.py @@ -0,0 +1,449 @@ +"""Stage orchestrator — detects project state and manages stage transitions. + +Runs on a Textual worker thread. Reads ``.prototype/state/`` files to +determine the current position in the pipeline, populates the task tree +with sub-tasks from completed stages, and waits for user commands. +""" + +from __future__ import annotations + +import logging +import textwrap +from pathlib import Path +from typing import TYPE_CHECKING, Any + +from azext_prototype.ui.task_model import TaskStatus +from azext_prototype.ui.tui_adapter import ShutdownRequested + +if TYPE_CHECKING: + from azext_prototype.ui.app import PrototypeApp + from azext_prototype.ui.tui_adapter import TUIAdapter + +logger = logging.getLogger(__name__) + + +# -------------------------------------------------------------------- # +# State detection +# -------------------------------------------------------------------- # + + +def detect_stage(project_dir: str) -> str: + """Detect the furthest completed stage from state files. + + Returns one of: ``"init"``, ``"design"``, ``"build"``, ``"deploy"``. + """ + state_dir = Path(project_dir) / ".prototype" / "state" + if (state_dir / "deploy.yaml").exists(): + return "deploy" + if (state_dir / "build.yaml").exists(): + return "build" + if (state_dir / "discovery.yaml").exists() or (state_dir / "design.json").exists(): + return "design" + return "init" + + +# -------------------------------------------------------------------- # +# Orchestrator +# -------------------------------------------------------------------- # + + +class StageOrchestrator: + """Manages stage lifecycle within the dashboard. + + Call :meth:`run` from a Textual worker thread. It will detect the + current project state, populate the task tree, show project metadata, + and wait for user commands. + """ + + def __init__( + self, + app: PrototypeApp, + adapter: TUIAdapter, + project_dir: str, + stage_kwargs: dict | None = None, + ) -> None: + self._app = app + self._adapter = adapter + self._project_dir = project_dir + self._stage_kwargs = stage_kwargs or {} + + # ------------------------------------------------------------------ # + # Public + # ------------------------------------------------------------------ # + + def run(self, start_stage: str | None = None) -> None: + """Main orchestration loop — detect state, populate tree, prompt.""" + try: + current = start_stage or detect_stage(self._project_dir) + + # Always mark init as completed + self._adapter.update_task("init", TaskStatus.COMPLETED) + + # Populate tree and show welcome based on detected state + self._populate_from_state(current) + self._show_welcome(current) + + # Auto-run a stage when launched with stage_kwargs + if self._stage_kwargs and start_stage: + if start_stage == "design": + self._run_design(**self._stage_kwargs) + + # Enter the command loop + self._command_loop(current) + except ShutdownRequested: + logger.debug("Orchestrator received shutdown signal") + + # ------------------------------------------------------------------ # + # Welcome + metadata + # ------------------------------------------------------------------ # + + def _show_welcome(self, current_stage: str) -> None: + """Display project metadata in the console view.""" + pf = self._adapter.print_fn + + # Load config for project metadata + try: + from azext_prototype.config import ProjectConfig + + config = ProjectConfig(self._project_dir) + config.load() + + name = config.get("project.name", "") + location = config.get("project.location", "") + iac = config.get("project.iac_tool", "") + ai_provider = config.get("ai.provider", "") + model = config.get("ai.model", "") + + if name: + pf(f" Project: {name}") + summary = self._get_project_summary() + if summary: + prefix = " Summary: " + indent = " " * len(prefix) + wrapped = textwrap.fill( + summary, + width=80, + initial_indent=prefix, + subsequent_indent=indent, + ) + for line in wrapped.splitlines(): + pf(line) + if location: + pf(f" Location: {location}") + if iac: + pf(f" IaC tool: {iac}") + if ai_provider: + provider_str = ai_provider + if model: + provider_str += f" ({model})" + pf(f" AI: {provider_str}") + + pf(f" Stage: {current_stage}") + pf("") + except Exception: + pf(f" Stage: {current_stage}") + pf("") + + def _get_project_summary(self) -> str: + """Extract a project summary from discovery state or design output. + + Returns empty string if no summary is available. + """ + import re + + def _normalize(text: str) -> str: + """Collapse multiple spaces into one.""" + return re.sub(r" +", " ", text).strip() + + # Try discovery state first + try: + from azext_prototype.stages.discovery_state import DiscoveryState + + ds = DiscoveryState(self._project_dir) + if ds.exists: + ds.load() + summary = ds.state.get("project", {}).get("summary", "") + if summary: + return _normalize(summary) + except Exception: + pass + + # Fall back to design.json architecture text + try: + import json + + design_json = Path(self._project_dir) / ".prototype" / "state" / "design.json" + if design_json.exists(): + data = json.loads(design_json.read_text(encoding="utf-8")) + arch = data.get("architecture", "") + if arch: + # First sentence + first_sentence = arch.split(".")[0].strip() + if first_sentence: + return _normalize(first_sentence + ".") + except Exception: + pass + + return "" + + # ------------------------------------------------------------------ # + # State → task tree population + # ------------------------------------------------------------------ # + + def _populate_from_state(self, current_stage: str) -> None: + """Read state files and populate the task tree with sub-tasks.""" + stage_order = ["init", "design", "build", "deploy"] + current_idx = stage_order.index(current_stage) if current_stage in stage_order else 0 + + # Mark all stages up to current as completed + for i, stage_name in enumerate(stage_order): + if i == 0: + continue # init already marked + if i <= current_idx: + self._adapter.update_task(stage_name, TaskStatus.COMPLETED) + + # Populate sub-tasks from state files + self._populate_design_subtasks() + self._populate_build_subtasks() + self._populate_deploy_subtasks() + + def _populate_design_subtasks(self) -> None: + """Populate design sub-tasks from discovery state.""" + try: + from azext_prototype.stages.discovery_state import DiscoveryState + + ds = DiscoveryState(self._project_dir) + if not ds.exists: + return + ds.load() + + confirmed = ds.confirmed_count + open_count = ds.open_count + + if confirmed > 0: + self._adapter.add_task( + "design", + "design-confirmed", + f"Confirmed requirements ({confirmed})", + ) + self._adapter.update_task("design-confirmed", TaskStatus.COMPLETED) + + if open_count > 0: + self._adapter.add_task( + "design", + "design-open", + f"Open items ({open_count})", + ) + # Open items are pending resolution + self._adapter.update_task("design-open", TaskStatus.PENDING) + + # Check for architecture output + design_json = Path(self._project_dir) / ".prototype" / "state" / "design.json" + if design_json.exists(): + self._adapter.add_task( + "design", + "design-arch", + "Architecture document", + ) + self._adapter.update_task("design-arch", TaskStatus.COMPLETED) + except Exception: + logger.debug("Could not populate design subtasks", exc_info=True) + + def _populate_build_subtasks(self) -> None: + """Populate build sub-tasks from build state.""" + try: + from azext_prototype.stages.build_state import BuildState + + bs = BuildState(self._project_dir) + if not bs.exists: + return + bs.load() + + stages = bs.state.get("deployment_stages", []) + for s in stages: + stage_num = s.get("stage", 0) + name = s.get("name", f"Stage {stage_num}") + status = s.get("status", "pending") + task_id = f"build-stage-{stage_num}" + + self._adapter.add_task("build", task_id, f"Stage {stage_num}: {name}") + + if status in ("generated", "accepted"): + self._adapter.update_task(task_id, TaskStatus.COMPLETED) + elif status == "in_progress": + self._adapter.update_task(task_id, TaskStatus.IN_PROGRESS) + # else: stays PENDING + except Exception: + logger.debug("Could not populate build subtasks", exc_info=True) + + def _populate_deploy_subtasks(self) -> None: + """Populate deploy sub-tasks from deploy state.""" + try: + from azext_prototype.stages.deploy_state import DeployState + + ds = DeployState(self._project_dir) + if not ds.exists: + return + ds.load() + + stages = ds.state.get("deployment_stages", []) + for s in stages: + stage_num = s.get("stage", 0) + name = s.get("name", f"Stage {stage_num}") + deploy_status = s.get("deploy_status", "pending") + task_id = f"deploy-stage-{stage_num}" + + self._adapter.add_task("deploy", task_id, f"Stage {stage_num}: {name}") + + if deploy_status == "deployed": + self._adapter.update_task(task_id, TaskStatus.COMPLETED) + elif deploy_status in ("deploying", "in_progress", "remediating"): + self._adapter.update_task(task_id, TaskStatus.IN_PROGRESS) + elif deploy_status in ("failed", "rolled_back"): + self._adapter.update_task(task_id, TaskStatus.FAILED) + # else: stays PENDING + except Exception: + logger.debug("Could not populate deploy subtasks", exc_info=True) + + # ------------------------------------------------------------------ # + # Command loop + # ------------------------------------------------------------------ # + + def _command_loop(self, current_stage: str) -> None: + """Wait for user commands: design, build, deploy, quit. + + Raises :class:`ShutdownRequested` if the app is shutting down, + which is caught by :meth:`run`. + """ + pf = self._adapter.print_fn + + while True: + user_input = self._adapter.input_fn("> ").strip().lower() + + if not user_input: + continue + + if user_input in ("q", "quit", "exit", "end"): + self._app.call_from_thread(self._app.exit) + break + elif user_input in ("design", "redesign"): + self._run_design() + elif user_input == "build": + self._run_build() + elif user_input in ("deploy", "redeploy"): + self._run_deploy() + elif user_input == "help": + pf("") + pf("Commands:") + pf(" design - Run or re-run the design stage") + pf(" build - Run or re-run the build stage") + pf(" deploy - Run or re-run the deploy stage") + pf(" quit - Exit") + pf("") + else: + pf(f"Unknown command: {user_input}. Type 'help' for options.") + + # ------------------------------------------------------------------ # + # Stage runners + # ------------------------------------------------------------------ # + + def _run_design(self, **kwargs) -> None: + """Launch the design (discovery + architecture) session.""" + self._adapter.clear_tasks("design") + self._adapter.update_task("design", TaskStatus.IN_PROGRESS) + # Show an initial subtask so the tree isn't empty during discovery + self._adapter.add_task("design", "design-discovery", "Discovery") + self._adapter.update_task("design-discovery", TaskStatus.IN_PROGRESS) + + try: + _, config, registry, agent_context = self._prepare() + from azext_prototype.stages.design_stage import DesignStage + + stage = DesignStage() + result = stage.execute( + agent_context, + registry, + input_fn=self._adapter.input_fn, + print_fn=self._adapter.print_fn, + status_fn=self._adapter.status_fn, + section_fn=self._adapter.section_fn, + response_fn=self._adapter.response_fn, + update_task_fn=lambda tid, status: self._adapter.update_task(tid, TaskStatus(status)), + **kwargs, + ) + if result.get("status") == "cancelled": + self._adapter.print_fn("[bright_yellow]![/bright_yellow] Design session cancelled.") + self._app.call_from_thread(self._app.exit) + return + self._adapter.update_task("design", TaskStatus.COMPLETED) + self._populate_design_subtasks() + except ShutdownRequested: + raise + except Exception as exc: + logger.exception("Design stage failed") + self._adapter.update_task("design", TaskStatus.FAILED) + self._adapter.print_fn(f"Design stage failed: {exc}") + + def _run_build(self) -> None: + """Launch the build session.""" + self._adapter.clear_tasks("build") + self._adapter.update_task("build", TaskStatus.IN_PROGRESS) + + try: + _, config, registry, agent_context = self._prepare() + from azext_prototype.stages.build_stage import BuildStage + + stage = BuildStage() + stage.execute( + agent_context, + registry, + input_fn=self._adapter.input_fn, + print_fn=self._adapter.print_fn, + ) + self._adapter.update_task("build", TaskStatus.COMPLETED) + self._populate_build_subtasks() + except ShutdownRequested: + raise + except Exception as exc: + logger.exception("Build stage failed") + self._adapter.update_task("build", TaskStatus.FAILED) + self._adapter.print_fn(f"Build stage failed: {exc}") + + def _run_deploy(self) -> None: + """Launch the deploy session.""" + self._adapter.clear_tasks("deploy") + self._adapter.update_task("deploy", TaskStatus.IN_PROGRESS) + + try: + _, config, registry, agent_context = self._prepare() + from azext_prototype.stages.deploy_stage import DeployStage + + stage = DeployStage() + stage.execute( + agent_context, + registry, + input_fn=self._adapter.input_fn, + print_fn=self._adapter.print_fn, + ) + self._adapter.update_task("deploy", TaskStatus.COMPLETED) + self._populate_deploy_subtasks() + except ShutdownRequested: + raise + except Exception as exc: + logger.exception("Deploy stage failed") + self._adapter.update_task("deploy", TaskStatus.FAILED) + self._adapter.print_fn(f"Deploy stage failed: {exc}") + + # ------------------------------------------------------------------ # + # Helpers + # ------------------------------------------------------------------ # + + def _prepare(self) -> tuple[Any, Any, Any, Any]: + """Load config, registry, and agent context. + + Lazy import to avoid circular dependencies and keep the UI module + lightweight. Returns ``(project_dir, config, registry, agent_context)``. + """ + from azext_prototype.custom import _prepare_command + + return _prepare_command(self._project_dir) diff --git a/azext_prototype/ui/task_model.py b/azext_prototype/ui/task_model.py new file mode 100644 index 0000000..fdf6c27 --- /dev/null +++ b/azext_prototype/ui/task_model.py @@ -0,0 +1,115 @@ +"""Task data model for the TUI task tree. + +Provides ``TaskItem`` and ``TaskStore`` for tracking stage/sub-task +progress displayed in the :class:`~.widgets.task_tree.TaskTree`. +""" + +from __future__ import annotations + +import enum +from dataclasses import dataclass, field + + +class TaskStatus(enum.Enum): + """Lifecycle status for a task tree node.""" + + PENDING = "pending" + IN_PROGRESS = "in_progress" + COMPLETED = "completed" + FAILED = "failed" + + +# Status → display symbol mapping +STATUS_SYMBOLS = { + TaskStatus.PENDING: "\u25cb", # ○ + TaskStatus.IN_PROGRESS: "\u25cf", # ● + TaskStatus.COMPLETED: "\u2713", # ✓ + TaskStatus.FAILED: "\u2717", # ✗ +} + + +@dataclass +class TaskItem: + """A single node in the task tree.""" + + id: str + label: str + status: TaskStatus = TaskStatus.PENDING + children: list[TaskItem] = field(default_factory=list) + + @property + def symbol(self) -> str: + return STATUS_SYMBOLS.get(self.status, "?") + + @property + def display(self) -> str: + return f"{self.symbol} {self.label}" + + +class TaskStore: + """In-memory store for all task items, keyed by id. + + The store maintains four root tasks (one per stage) and allows + dynamic addition/removal of sub-tasks. + """ + + def __init__(self) -> None: + self._items: dict[str, TaskItem] = {} + self._roots: list[str] = [] + self._init_roots() + + def _init_roots(self) -> None: + """Create the four permanent stage root nodes.""" + for stage_id, label in [ + ("init", "Initialize"), + ("design", "Design"), + ("build", "Build"), + ("deploy", "Deploy"), + ]: + item = TaskItem(id=stage_id, label=label, status=TaskStatus.PENDING) + self._items[stage_id] = item + self._roots.append(stage_id) + + @property + def roots(self) -> list[TaskItem]: + return [self._items[rid] for rid in self._roots] + + def get(self, task_id: str) -> TaskItem | None: + return self._items.get(task_id) + + def update_status(self, task_id: str, status: TaskStatus) -> TaskItem | None: + """Update a task's status. Returns the item if found.""" + item = self._items.get(task_id) + if item: + item.status = status + return item + + def add_child(self, parent_id: str, child: TaskItem) -> bool: + """Add a sub-task under a parent. Returns True on success.""" + parent = self._items.get(parent_id) + if not parent: + return False + self._items[child.id] = child + parent.children.append(child) + return True + + def remove(self, task_id: str) -> bool: + """Remove a task (and all its children) from the store.""" + item = self._items.pop(task_id, None) + if not item: + return False + # Remove from any parent's children list + for other in self._items.values(): + other.children = [c for c in other.children if c.id != task_id] + # Recursively remove children + for child in item.children: + self._items.pop(child.id, None) + return True + + def clear_children(self, parent_id: str) -> None: + """Remove all children of a parent task.""" + parent = self._items.get(parent_id) + if parent: + for child in parent.children: + self._items.pop(child.id, None) + parent.children.clear() diff --git a/azext_prototype/ui/theme.py b/azext_prototype/ui/theme.py new file mode 100644 index 0000000..fc152dc --- /dev/null +++ b/azext_prototype/ui/theme.py @@ -0,0 +1,157 @@ +"""Textual CSS variables and Rich theme dict for the TUI dashboard. + +Ports the Claude Code-inspired color scheme from console.py into both +Textual CSS custom properties and a Rich Theme for use inside RichLog +widgets. +""" + +from __future__ import annotations + +from rich.theme import Theme + +# -------------------------------------------------------------------- # +# Shared color constants — single source of truth +# -------------------------------------------------------------------- # + +COLORS = { + "dim": "#888888", + "muted": "#666666", + "content": "bright_white", + "success": "bright_green", + "error": "bright_red", + "warning": "bright_yellow", + "info": "bright_cyan", + "accent": "bright_magenta", + "border": "#555555", + "bg_panel": "#1a1a2e", + "bg_surface": "#16213e", + "bg_input": "#0f3460", +} + +# -------------------------------------------------------------------- # +# Rich theme — used by RichLog widgets to render Rich renderables +# -------------------------------------------------------------------- # + +# -------------------------------------------------------------------- # +# prompt_toolkit style dict — used by console.py DiscoveryPrompt +# -------------------------------------------------------------------- # + +PT_STYLE_DICT = { + "prompt": COLORS["dim"], + "": "#ffffff", + "bottom-toolbar": f"noreverse {COLORS['dim']}", +} + +# -------------------------------------------------------------------- # +# Rich theme — used by RichLog widgets to render Rich renderables +# -------------------------------------------------------------------- # + +RICH_THEME = Theme( + { + "dim": COLORS["dim"], + "muted": COLORS["muted"], + "content": COLORS["content"], + "success": COLORS["success"], + "error": COLORS["error"], + "warning": COLORS["warning"], + "info": COLORS["info"], + "accent": COLORS["accent"], + "prompt.border": COLORS["border"], + "prompt.instruction": COLORS["info"], + "prompt.input": COLORS["content"], + "progress.description": COLORS["content"], + "progress.percentage": COLORS["info"], + "progress.bar.complete": COLORS["success"], + "progress.bar.finished": COLORS["success"], + "agent": "bright_magenta bold", + "stage": "bright_cyan bold", + "path": "bright_cyan", + "markdown.paragraph": "white", + "markdown.h1": "bright_magenta bold underline", + "markdown.h2": "bright_magenta bold", + "markdown.h3": "bright_cyan bold", + "markdown.h4": "bright_cyan italic", + "markdown.bold": "bright_white bold", + "markdown.italic": "white italic", + "markdown.code": "bright_green", + "markdown.code_block": "bright_green", + "markdown.block_quote": "bright_yellow italic", + "markdown.item.bullet": "bright_cyan", + "markdown.item.number": "bright_cyan", + "markdown.link": "bright_cyan underline", + "markdown.link_url": "bright_cyan", + "markdown.hr": COLORS["border"], + } +) + +# -------------------------------------------------------------------- # +# Task-status colors (for the Tree widget) +# -------------------------------------------------------------------- # + +TASK_COLORS = { + "pending": COLORS["dim"], + "in_progress": "bright_white bold", + "completed": COLORS["success"], + "failed": COLORS["error"], +} + +# -------------------------------------------------------------------- # +# Textual CSS stylesheet +# -------------------------------------------------------------------- # + +APP_CSS = """\ +/* ── Layout ────────────────────────────────────────────────── */ + +Screen { + layout: grid; + grid-size: 1; + grid-rows: 1fr auto auto; +} + +#body { + layout: horizontal; + height: 1fr; +} + +#console-view { + width: 3fr; + border-right: solid $accent; +} + +#task-tree { + width: 1fr; + min-width: 24; + max-width: 40; +} + +/* ── Prompt ────────────────────────────────────────────────── */ + +#prompt-input { + height: auto; + min-height: 3; + max-height: 10; + border-top: solid $accent; + border-bottom: solid $accent; + border-left: none; + border-right: none; +} + +/* ── Info bar ──────────────────────────────────────────────── */ + +#info-bar { + layout: horizontal; + height: 1; + dock: bottom; +} + +#assist-label { + width: 1fr; + content-align-vertical: middle; +} + +#status-label { + width: 1fr; + text-align: right; + content-align-vertical: middle; +} +""" diff --git a/azext_prototype/ui/tui_adapter.py b/azext_prototype/ui/tui_adapter.py new file mode 100644 index 0000000..7273aa1 --- /dev/null +++ b/azext_prototype/ui/tui_adapter.py @@ -0,0 +1,398 @@ +"""Bridge between synchronous sessions and the async Textual TUI. + +Sessions (Discovery, Build, Deploy, Backlog) run on Textual worker +threads and call ``input_fn`` / ``print_fn`` synchronously. This +adapter translates those calls into Textual widget operations using +``call_from_thread`` (thread-safe scheduling on the main event loop) +and ``threading.Event`` (blocking the worker until the user submits). + +On shutdown the adapter's :meth:`shutdown` method is called by the app. +This sets the ``_shutdown`` event so that any worker thread blocked in +``input_fn`` unblocks immediately and raises :class:`ShutdownRequested`. + +Usage inside ``PrototypeApp``:: + + adapter = TUIAdapter(app) + # Then pass to a session: + session.run( + input_fn=adapter.input_fn, + print_fn=adapter.print_fn, + ) +""" + +from __future__ import annotations + +import logging +import re +import threading +import time +from typing import TYPE_CHECKING + +from azext_prototype.ui.task_model import TaskItem, TaskStatus +from azext_prototype.ui.theme import COLORS + +if TYPE_CHECKING: + from azext_prototype.ui.app import PrototypeApp + +logger = logging.getLogger(__name__) + +# Rich markup tag pattern (e.g. [success], [/error], [info]...[/info]) +_RICH_TAG_RE = re.compile(r"\[/?[a-zA-Z_.]+(?:\s[^\]]+)?\]") + + +def _format_elapsed(seconds: float) -> str: + """Format elapsed seconds as ``12s`` or ``1m04s`` when >= 60.""" + if seconds < 60: + return f"{seconds:.0f}s" + minutes = int(seconds) // 60 + secs = int(seconds) % 60 + return f"{minutes}m{secs:02d}s" + + +def _strip_rich_markup(text: str) -> str: + """Remove Rich-style markup tags for plain-text rendering.""" + return _RICH_TAG_RE.sub("", text) + + +class ShutdownRequested(Exception): + """Raised on a worker thread when the app is shutting down.""" + + +class TUIAdapter: + """Thread-safe bridge from sync session I/O to Textual widgets. + + Constructed with a reference to the running :class:`PrototypeApp`. + The three main callables — ``input_fn``, ``print_fn``, ``status_fn`` + — are suitable for passing directly to session ``run()`` methods. + """ + + def __init__(self, app: PrototypeApp) -> None: + self._app = app + # Synchronization for blocking input + self._input_event = threading.Event() + self._input_value: str = "" + # Shutdown signal — unblocks any waiting worker thread + self._shutdown = threading.Event() + # Elapsed timer state (managed on the main thread) + self._timer_start: float | None = None + self._timer_handle = None # Textual Timer reference + # Track last level-2 section for nesting level-3 subsections + self._last_l2_section_id: str = "design" + + # ------------------------------------------------------------------ # + # Shutdown + # ------------------------------------------------------------------ # + + def shutdown(self) -> None: + """Signal all waiting worker threads to exit immediately.""" + self._shutdown.set() + self._cancel_timer() + # Also set the input event so any blocked wait() returns + self._input_event.set() + + @property + def is_shutdown(self) -> bool: + """True if the adapter has been told to shut down.""" + return self._shutdown.is_set() + + # ------------------------------------------------------------------ # + # Screen refresh helper + # ------------------------------------------------------------------ # + + def _request_screen_update(self) -> None: + """Force Textual to repaint the screen. + + Must be called on the main thread (inside a ``call_from_thread`` + callback). ``call_from_thread`` runs callbacks as asyncio tasks + *outside* Textual's message loop, so widget ``refresh()`` calls + may not trigger a compositor pass. Calling ``screen.refresh()`` + explicitly schedules one. + """ + try: + self._app.screen.refresh() + except Exception: + pass + + # ------------------------------------------------------------------ # + # print_fn — called from worker thread + # ------------------------------------------------------------------ # + + def print_fn(self, message: str = "", **kwargs) -> None: + """Write *message* to the ConsoleView widget. + + Called from a worker thread; delegates to the main thread via + ``call_from_thread``. If the message contains Rich markup tags + (e.g. ``[success]✓[/success]``), they are preserved so the + console renders colored output. + """ + if self._shutdown.is_set(): + return + + msg = str(message) + + def _write() -> None: + if _RICH_TAG_RE.search(msg): + self._app.console_view.write_markup(msg) + else: + self._app.console_view.write_text(msg) + self._request_screen_update() + + try: + self._app.call_from_thread(_write) + except Exception: + pass # App already torn down + + # ------------------------------------------------------------------ # + # response_fn — render agent responses with color + pagination + # ------------------------------------------------------------------ # + + def response_fn(self, content: str) -> None: + """Render an agent response as colored Markdown — full content, no pagination.""" + if self._shutdown.is_set(): + return + try: + + def _render(): + self._app.console_view.write_agent_response(content) + self._request_screen_update() + + self._app.call_from_thread(_render) + except Exception: + pass + + # ------------------------------------------------------------------ # + # input_fn — called from worker thread, blocks until user submits + # ------------------------------------------------------------------ # + + def input_fn(self, prompt_text: str = "> ") -> str: + """Block the worker thread until the user submits input. + + 1. Schedules prompt activation on the main thread. + 2. Waits on ``_input_event`` (checks for shutdown every 0.25 s). + 3. Returns the submitted text. + + Raises :class:`ShutdownRequested` if the app is shutting down. + """ + if self._shutdown.is_set(): + raise ShutdownRequested() + + self._input_event.clear() + + def _enable_prompt() -> None: + self._app.prompt_input.enable(placeholder=prompt_text) + self._request_screen_update() + + try: + self._app.call_from_thread(_enable_prompt) + except Exception: + raise ShutdownRequested() + + # Block worker thread — poll with a short timeout so we can + # detect shutdown without waiting for user input. + while not self._input_event.wait(timeout=0.25): + if self._shutdown.is_set(): + raise ShutdownRequested() + + if self._shutdown.is_set(): + raise ShutdownRequested() + + return self._input_value + + def on_prompt_submitted(self, value: str) -> None: + """Called on the main thread when PromptInput.Submitted fires. + + Stores the value, disables the prompt, echoes the input to the + console (unless empty — e.g. pagination "Enter to continue"), + and unblocks the worker. + """ + self._input_value = value + self._app.prompt_input.disable() + # Echo user input to console (skip for empty pagination presses) + if value: + self._app.console_view.write_text(f"> {value}", style=COLORS["content"]) + # Unblock the waiting worker thread + self._input_event.set() + + # ------------------------------------------------------------------ # + # status_fn — called from worker thread for spinner replacement + # ------------------------------------------------------------------ # + + def status_fn(self, message: str, event: str = "start") -> None: + """Update the info bar as spinner replacement with elapsed timer. + + Called by ``_maybe_spinner(..., status_fn=adapter.status_fn)``. + + Events: + ``"start"`` — show assist text and start an elapsed timer on + the right side of the info bar. + ``"end"`` — stop the timer, show final elapsed time (will be + replaced by token usage shortly after). + ``"tokens"`` — replace the timer/elapsed text with token usage. + """ + if self._shutdown.is_set(): + return + + if event == "start": + + def _start() -> None: + self._cancel_timer() + self._timer_start = time.monotonic() + self._app.info_bar.update_assist(f"\u23f3 {message}") + self._app.info_bar.update_status("\u23f1 0s") + self._timer_handle = self._app.set_interval( + 1.0, + self._tick_timer, + ) + self._request_screen_update() + + try: + self._app.call_from_thread(_start) + except Exception: + pass + + elif event == "end": + + def _stop() -> None: + self._cancel_timer() + if self._timer_start is not None: + elapsed = time.monotonic() - self._timer_start + self._app.info_bar.update_status(f"\u23f1 {_format_elapsed(elapsed)}") + self._timer_start = None + self._app.info_bar.update_assist("Enter = submit | Ctrl+J = newline | Ctrl+C = quit") + self._request_screen_update() + + try: + self._app.call_from_thread(_stop) + except Exception: + pass + + elif event == "tokens": + + def _tokens() -> None: + if message: + self._app.info_bar.update_status(message) + self._request_screen_update() + + try: + self._app.call_from_thread(_tokens) + except Exception: + pass + + def _tick_timer(self) -> None: + """Update the elapsed timer display (called on the main thread).""" + if self._timer_start is None or self._timer_handle is None: + return + elapsed = time.monotonic() - self._timer_start + self._app.info_bar.update_status(f"\u23f1 {_format_elapsed(elapsed)}") + self._request_screen_update() + + def _cancel_timer(self) -> None: + """Stop the interval timer if running.""" + if self._timer_handle is not None: + self._timer_handle.stop() + self._timer_handle = None + + # ------------------------------------------------------------------ # + # Token status — called from worker thread + # ------------------------------------------------------------------ # + + def print_token_status(self, status_text: str) -> None: + """Update the right side of the info bar with token usage.""" + if self._shutdown.is_set(): + return + + def _update() -> None: + self._app.info_bar.update_status(status_text) + self._request_screen_update() + + try: + self._app.call_from_thread(_update) + except Exception: + pass + + # ------------------------------------------------------------------ # + # Task tree — called from worker thread + # ------------------------------------------------------------------ # + + def _refresh_tree(self) -> None: + """Force a tree re-render and screen repaint (must be on main thread).""" + self._app.task_tree.refresh() + self._request_screen_update() + + def update_task(self, task_id: str, status: TaskStatus) -> None: + """Update a task's status in the tree widget.""" + if self._shutdown.is_set(): + return + + def _update() -> None: + self._app.task_tree.update_task(task_id, status) + self._refresh_tree() + + try: + self._app.call_from_thread(_update) + except Exception: + pass + + def add_task(self, parent_id: str, task_id: str, label: str) -> None: + """Add a sub-task to the tree widget.""" + if self._shutdown.is_set(): + return + + item = TaskItem(id=task_id, label=label) + + def _add() -> None: + self._app.task_tree.add_task(parent_id, item) + self._refresh_tree() + + try: + self._app.call_from_thread(_add) + except Exception: + pass + + def clear_tasks(self, parent_id: str) -> None: + """Remove all sub-tasks under a parent stage.""" + if self._shutdown.is_set(): + return + + def _clear() -> None: + self._app.task_tree.clear_children(parent_id) + self._refresh_tree() + + try: + self._app.call_from_thread(_clear) + except Exception: + pass + + def section_fn(self, headers: list[tuple[str, int]]) -> None: + """Add section headers as sub-tasks under 'design' with hierarchy. + + Level-2 headings become expandable sections directly under 'design'. + Level-3 headings nest under the most recent level-2 section. + """ + if self._shutdown.is_set(): + return + + def _add() -> None: + changed = False + for header_text, level in headers: + slug = re.sub(r"[^a-z0-9]+", "-", header_text.lower()).strip("-") + task_id = f"design-section-{slug}" + if self._app.task_tree.store.get(task_id) is not None: + if level == 2: + self._last_l2_section_id = task_id + continue # dedup + item = TaskItem(id=task_id, label=header_text) + if level == 2: + self._app.task_tree.add_section("design", item) + self._last_l2_section_id = task_id + else: # level 3 — nest under most recent level-2 + parent = self._last_l2_section_id + self._app.task_tree.add_task(parent, item) + changed = True + if changed: + self._refresh_tree() + + try: + self._app.call_from_thread(_add) + except Exception: + logger.debug("section_fn call_from_thread failed", exc_info=True) diff --git a/azext_prototype/ui/widgets/__init__.py b/azext_prototype/ui/widgets/__init__.py new file mode 100644 index 0000000..a54d51b --- /dev/null +++ b/azext_prototype/ui/widgets/__init__.py @@ -0,0 +1,13 @@ +"""TUI dashboard widgets for the prototype extension.""" + +from azext_prototype.ui.widgets.console_view import ConsoleView +from azext_prototype.ui.widgets.info_bar import InfoBar +from azext_prototype.ui.widgets.prompt_input import PromptInput +from azext_prototype.ui.widgets.task_tree import TaskTree + +__all__ = [ + "ConsoleView", + "InfoBar", + "PromptInput", + "TaskTree", +] diff --git a/azext_prototype/ui/widgets/console_view.py b/azext_prototype/ui/widgets/console_view.py new file mode 100644 index 0000000..6846311 --- /dev/null +++ b/azext_prototype/ui/widgets/console_view.py @@ -0,0 +1,115 @@ +"""Scrollable console output widget wrapping Textual's RichLog. + +Renders Rich renderables (Markdown, Panel, Table, Text) directly +and provides semantic convenience methods mirroring ``Console`` from +``console.py``. +""" + +from __future__ import annotations + +import re + +from rich.markdown import Markdown +from rich.text import Text +from textual.widgets import RichLog + +from azext_prototype.ui.theme import RICH_THEME + +# Ordered list fix ported from console.py +_ORDERED_LIST_RE = re.compile(r"^(\s*)(\d+)\.\s", re.MULTILINE) + + +def _preprocess_markdown(content: str) -> str: + return _ORDERED_LIST_RE.sub(r"**\2.** ", content) + + +class ConsoleView(RichLog): + """Scrollable console panel for agent output, status messages, etc.""" + + DEFAULT_CSS = """ + ConsoleView { + background: $surface; + scrollbar-size: 1 1; + } + """ + + def __init__(self, **kwargs) -> None: + super().__init__( + highlight=False, + markup=True, + auto_scroll=True, + wrap=True, + **kwargs, + ) + + # ------------------------------------------------------------------ # + # Semantic write methods (mirror console.py Console) + # ------------------------------------------------------------------ # + + def write_text(self, message: str, style: str = "") -> None: + """Write a plain styled line.""" + self.write(Text(message, style=style)) + + def write_markup(self, message: str) -> None: + """Write a message with Rich markup tags preserved.""" + try: + styled = Text.from_markup(message) + self.write(styled) + except Exception: + self.write(Text(message)) + + def write_dim(self, message: str) -> None: + self.write(Text(message, style=RICH_THEME.styles.get("dim", ""))) + + def write_success(self, message: str) -> None: + text = Text() + text.append("\u2713 ", style=str(RICH_THEME.styles.get("success", ""))) + text.append(message) + self.write(text) + + def write_error(self, message: str) -> None: + text = Text() + text.append("\u2717 ", style=str(RICH_THEME.styles.get("error", ""))) + text.append(message) + self.write(text) + + def write_warning(self, message: str) -> None: + text = Text() + text.append("! ", style=str(RICH_THEME.styles.get("warning", ""))) + text.append(message) + self.write(text) + + def write_info(self, message: str) -> None: + text = Text() + text.append("\u2192 ", style=str(RICH_THEME.styles.get("info", ""))) + text.append(message) + self.write(text) + + def write_header(self, title: str) -> None: + from rich.style import Style + + self.write(Text()) + base = RICH_THEME.styles.get("accent") + style = (base + Style(bold=True)) if base else Style(bold=True) + self.write(Text(title, style=style)) + self.write(Text()) + + def write_agent_response(self, content: str) -> None: + """Render a markdown agent response.""" + self.write(Text()) + self.write(Markdown(_preprocess_markdown(content))) + self.write(Text()) + + def write_token_status(self, status_text: str) -> None: + if status_text: + self.write(Text(status_text, style=str(RICH_THEME.styles.get("muted", "")), justify="right")) + + def write_file_list(self, files: list[str], success: bool = True) -> None: + style_name = "success" if success else "error" + marker = "\u2713" if success else "\u2717" + style = str(RICH_THEME.styles.get(style_name, "")) + for f in files: + text = Text() + text.append(f" {marker} ", style=style) + text.append(f, style=str(RICH_THEME.styles.get("path", ""))) + self.write(text) diff --git a/azext_prototype/ui/widgets/info_bar.py b/azext_prototype/ui/widgets/info_bar.py new file mode 100644 index 0000000..22ce581 --- /dev/null +++ b/azext_prototype/ui/widgets/info_bar.py @@ -0,0 +1,45 @@ +"""Info bar widget — assist label (left) + token status (right). + +Sits at the very bottom of the TUI layout as a single-line status area. +""" + +from __future__ import annotations + +from textual.app import ComposeResult +from textual.containers import Horizontal +from textual.widgets import Static + + +class InfoBar(Horizontal): + """Bottom info bar with assist and token status regions.""" + + DEFAULT_CSS = """ + InfoBar { + height: 1; + dock: bottom; + background: $surface; + } + + InfoBar > #assist-label { + width: 1fr; + color: $text-muted; + } + + InfoBar > #status-label { + width: 1fr; + text-align: right; + color: $text-muted; + } + """ + + def compose(self) -> ComposeResult: + yield Static("", id="assist-label") + yield Static("", id="status-label") + + def update_assist(self, text: str) -> None: + """Update the left-side assist/instruction text.""" + self.query_one("#assist-label", Static).update(text) + + def update_status(self, text: str) -> None: + """Update the right-side token/status text.""" + self.query_one("#status-label", Static).update(text) diff --git a/azext_prototype/ui/widgets/prompt_input.py b/azext_prototype/ui/widgets/prompt_input.py new file mode 100644 index 0000000..297b513 --- /dev/null +++ b/azext_prototype/ui/widgets/prompt_input.py @@ -0,0 +1,151 @@ +"""Prompt input widget — growable TextArea with submit behavior. + +Enter submits the input. Shift+Enter or Ctrl+J inserts a newline. +The widget auto-grows vertically (up to a max) as the user types +multi-line text. A ``"> "`` prefix is pre-filled when the prompt is +enabled. + +Note: some terminals (e.g. Windows PowerShell) cannot distinguish +Shift+Enter from bare Enter. Ctrl+J is provided as a universal +fallback that works everywhere. + +Implementation note: TextArea processes Enter internally in ``_on_key`` +to insert a newline *before* the BINDINGS system runs. We must +intercept at the same level to override that behavior. +""" + +from __future__ import annotations + +from textual.message import Message +from textual.widgets import TextArea + +_PROMPT_PREFIX = "> " + + +class PromptInput(TextArea): + """Multi-line prompt that submits on Enter and grows upward.""" + + DEFAULT_CSS = """ + PromptInput { + height: auto; + min-height: 3; + max-height: 10; + border-top: solid $accent; + border-bottom: solid $accent; + border-left: none; + border-right: none; + } + """ + + class Submitted(Message): + """Posted when the user presses Enter to submit their input.""" + + def __init__(self, value: str) -> None: + super().__init__() + self.value = value + + def __init__(self, **kwargs) -> None: + super().__init__( + language=None, + show_line_numbers=False, + soft_wrap=True, + **kwargs, + ) + self._enabled = False + self._allow_empty = False + self.text = _PROMPT_PREFIX + + # ------------------------------------------------------------------ # + # Enable / disable (blocks input while session is thinking) + # ------------------------------------------------------------------ # + + def enable(self, placeholder: str = "Type your response...", allow_empty: bool = False) -> None: + """Enable the prompt for user input. + + When *allow_empty* is True, pressing Enter with no text submits + an empty string (used for "Enter to continue" pagination). + In that mode the ``"> "`` prefix is hidden and a placeholder is + shown instead, giving a clear visual distinction from input mode. + """ + self._enabled = True + self._allow_empty = allow_empty + self.read_only = False + if allow_empty: + # Pagination mode — show placeholder, no "> " prefix + self.text = "" + self.placeholder = placeholder + else: + # Input mode — show "> " prefix + self.text = _PROMPT_PREFIX + self.placeholder = "" + self.move_cursor_to_end_of_line() + self.focus() + + def disable(self) -> None: + """Disable the prompt (session is processing).""" + self._enabled = False + self.read_only = True + + def move_cursor_to_end_of_line(self) -> None: + """Place the cursor after the '> ' prefix.""" + row = self.document.line_count - 1 + col = len(self.document.get_line(row)) + self.move_cursor((row, col)) + + # ------------------------------------------------------------------ # + # Key handling + # + # TextArea handles Enter in _on_key to insert a newline before the + # BINDINGS system runs, so we must intercept at the same level. + # + # enter → submit + # shift+enter → newline (terminals with kitty keyboard protocol) + # ctrl+j → newline (universal fallback) + # everything else → default TextArea behavior + # ------------------------------------------------------------------ # + + async def _on_key(self, event) -> None: + # Always let Ctrl+C bubble up to the app's quit binding + if event.key == "ctrl+c": + return + + if not self._enabled: + event.prevent_default() + event.stop() + return + + if event.key == "enter": + # Bare Enter → submit the prompt + event.prevent_default() + event.stop() + self._submit() + return + + if event.key == "ctrl+j": + # Ctrl+J → insert newline (universal fallback) + event.prevent_default() + event.stop() + self.insert("\n") + return + + # shift+enter and all other keys → default TextArea behavior + # (TextArea's _on_key inserts a newline for shift+enter) + await super()._on_key(event) + + # ------------------------------------------------------------------ # + # Submit logic + # ------------------------------------------------------------------ # + + def _submit(self) -> None: + """Strip the prefix, post the Submitted message, and reset.""" + raw = self.text + if raw.startswith(_PROMPT_PREFIX): + raw = raw[len(_PROMPT_PREFIX) :] + value = raw.strip() + if value or self._allow_empty: + self.post_message(self.Submitted(value)) + if self._allow_empty: + self.text = "" + else: + self.text = _PROMPT_PREFIX + self.move_cursor_to_end_of_line() diff --git a/azext_prototype/ui/widgets/task_tree.py b/azext_prototype/ui/widgets/task_tree.py new file mode 100644 index 0000000..0e01ba6 --- /dev/null +++ b/azext_prototype/ui/widgets/task_tree.py @@ -0,0 +1,102 @@ +"""Collapsible task tree widget showing stage and sub-task progress. + +Four permanent root nodes — Initialize, Design, Build, Deploy — with +dynamic sub-tasks added/updated by sessions via the TUI adapter. +""" + +from __future__ import annotations + +from textual.widgets import Tree +from textual.widgets._tree import TreeNode + +from azext_prototype.ui.task_model import TaskItem, TaskStatus, TaskStore +from azext_prototype.ui.theme import COLORS, TASK_COLORS + + +class TaskTree(Tree[str]): + """Tree widget displaying stage progress with colored status symbols.""" + + DEFAULT_CSS = """ + TaskTree { + background: $surface; + scrollbar-size: 1 1; + } + """ + + def __init__(self, store: TaskStore | None = None, **kwargs) -> None: + super().__init__("Stages", **kwargs) + self._store = store or TaskStore() + # Map task id → tree node for fast updates + self._node_map: dict[str, TreeNode[str]] = {} + + @property + def store(self) -> TaskStore: + return self._store + + def on_mount(self) -> None: + """Populate root nodes on mount.""" + self.root.expand() + for root_item in self._store.roots: + node = self.root.add(self._render_label(root_item), data=root_item.id) + node.allow_expand = True + node.expand() + self._node_map[root_item.id] = node + # Add any existing children + for child in root_item.children: + self._add_child_node(node, child) + + # ------------------------------------------------------------------ # + # Public API (called from main thread via call_from_thread) + # ------------------------------------------------------------------ # + + def update_task(self, task_id: str, status: TaskStatus) -> None: + """Update a task's status and refresh its tree label.""" + item = self._store.update_status(task_id, status) + node = self._node_map.get(task_id) + if item and node: + node.set_label(self._render_label(item)) + + def add_task(self, parent_id: str, task: TaskItem) -> None: + """Add a sub-task node under *parent_id*.""" + self._store.add_child(parent_id, task) + parent_node = self._node_map.get(parent_id) + if parent_node: + self._add_child_node(parent_node, task) + + def add_section(self, parent_id: str, task: TaskItem) -> None: + """Add an expandable section node that can have children. + + Unlike :meth:`add_task` (which creates a leaf node), this creates + an expandable node so level-3 subsections can be nested under it. + """ + self._store.add_child(parent_id, task) + parent_node = self._node_map.get(parent_id) + if parent_node: + child_node = parent_node.add(self._render_label(task), data=task.id) + child_node.allow_expand = True + child_node.expand() + self._node_map[task.id] = child_node + + def clear_children(self, parent_id: str) -> None: + """Remove all sub-tasks under *parent_id*.""" + parent = self._store.get(parent_id) + if parent: + for child in list(parent.children): + node = self._node_map.pop(child.id, None) + if node: + node.remove() + self._store.clear_children(parent_id) + + # ------------------------------------------------------------------ # + # Internals + # ------------------------------------------------------------------ # + + def _add_child_node(self, parent_node: TreeNode[str], item: TaskItem) -> None: + child_node = parent_node.add_leaf(self._render_label(item), data=item.id) + self._node_map[item.id] = child_node + + @staticmethod + def _render_label(item: TaskItem) -> str: + """Build a Rich-markup label string with color.""" + color = TASK_COLORS.get(item.status.value, COLORS["dim"]) + return f"[{color}]{item.symbol}[/{color}] {item.label}" diff --git a/build.sh b/build.sh old mode 100644 new mode 100755 diff --git a/pyproject.toml b/pyproject.toml index 0195e2d..e145b02 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -8,6 +8,7 @@ python_files = "test_*.py" python_classes = "Test*" python_functions = "test_*" addopts = "-v --tb=short -q" +asyncio_mode = "strict" [tool.coverage.run] source = ["azext_prototype"] diff --git a/setup.py b/setup.py index 57d50dd..d564fd6 100644 --- a/setup.py +++ b/setup.py @@ -3,7 +3,7 @@ from setuptools import find_packages, setup -VERSION = "0.2.1b3" +VERSION = "0.2.1b4" CLASSIFIERS = [ "Development Status :: 4 - Beta", "Intended Audience :: Developers", @@ -25,6 +25,8 @@ "opencensus-ext-azure>=1.1.0", # prompt_toolkit for multi-line input (Shift+Enter, backslash continuation) "prompt_toolkit>=3.0.0", + # Textual TUI dashboard for interactive sessions + "textual>=8.0.0", # Pin psutil — only 7.1.1 ships a pre-built win32 binary wheel. # Later versions (7.1.2+) require a source build which fails on # Azure CLI's bundled 32-bit Python (no setuptools). @@ -46,7 +48,10 @@ author_email="joshuadavis@microsoft.com", url="https://github.com/Azure/az-prototype", classifiers=CLASSIFIERS, - packages=find_packages(exclude=["tests", "tests.*"]), + packages=[ + p for p in find_packages(exclude=["tests", "tests.*"]) + if "__pycache__" not in p + ], install_requires=DEPENDENCIES, include_package_data=True, package_data={ diff --git a/tests/test_build_session.py b/tests/test_build_session.py index 6b57b54..a0b8337 100644 --- a/tests/test_build_session.py +++ b/tests/test_build_session.py @@ -722,13 +722,16 @@ def test_reentrant_skips_generated_stages(self, build_context, build_registry, m session = BuildSession(build_context, build_registry) - # Pre-populate with a generated stage + design = {"architecture": "Test"} + + # Pre-populate with a generated stage and matching design snapshot session._build_state.set_deployment_plan([ {"stage": 1, "name": "Foundation", "category": "infra", "services": [], "status": "generated", "dir": "", "files": ["main.tf"]}, {"stage": 2, "name": "Documentation", "category": "docs", "services": [], "status": "pending", "dir": "concept/docs", "files": []}, ]) + session._build_state.set_design_snapshot(design) inputs = iter(["", "done"]) @@ -741,7 +744,7 @@ def test_reentrant_skips_generated_stages(self, build_context, build_registry, m mock_orch.return_value.delegate.return_value = _make_response("QA ok") result = session.run( - design={"architecture": "Test"}, + design=design, input_fn=lambda p: next(inputs), print_fn=lambda m: None, ) @@ -752,6 +755,407 @@ def test_reentrant_skips_generated_stages(self, build_context, build_registry, m assert mock_doc_agent.execute.call_count == 1 +# ====================================================================== +# Incremental build / design snapshot tests +# ====================================================================== + +class TestDesignSnapshot: + """Tests for design snapshot tracking and change detection in BuildState.""" + + def test_design_snapshot_set_on_first_build(self, tmp_project): + from azext_prototype.stages.build_state import BuildState + + bs = BuildState(str(tmp_project)) + design = { + "architecture": "## Architecture\nKey Vault + SQL Database", + "_metadata": {"iteration": 3}, + } + bs.set_design_snapshot(design) + + snapshot = bs.state["design_snapshot"] + assert snapshot["iteration"] == 3 + assert snapshot["architecture_hash"] is not None + assert len(snapshot["architecture_hash"]) == 16 + assert snapshot["architecture_text"] == design["architecture"] + + def test_design_has_changed_detects_modification(self, tmp_project): + from azext_prototype.stages.build_state import BuildState + + bs = BuildState(str(tmp_project)) + original = {"architecture": "Key Vault + SQL"} + bs.set_design_snapshot(original) + + modified = {"architecture": "Key Vault + SQL + Redis Cache"} + assert bs.design_has_changed(modified) is True + + def test_design_has_changed_no_change(self, tmp_project): + from azext_prototype.stages.build_state import BuildState + + bs = BuildState(str(tmp_project)) + design = {"architecture": "Key Vault + SQL"} + bs.set_design_snapshot(design) + + assert bs.design_has_changed(design) is False + + def test_design_has_changed_legacy_no_snapshot(self, tmp_project): + from azext_prototype.stages.build_state import BuildState + + bs = BuildState(str(tmp_project)) + # No snapshot set — simulates legacy build + assert bs.design_has_changed({"architecture": "anything"}) is True + + def test_get_previous_architecture(self, tmp_project): + from azext_prototype.stages.build_state import BuildState + + bs = BuildState(str(tmp_project)) + assert bs.get_previous_architecture() is None + + design = {"architecture": "The full architecture text here"} + bs.set_design_snapshot(design) + assert bs.get_previous_architecture() == "The full architecture text here" + + def test_design_snapshot_persists_across_load(self, tmp_project): + from azext_prototype.stages.build_state import BuildState + + bs = BuildState(str(tmp_project)) + design = {"architecture": "Persistent arch", "_metadata": {"iteration": 2}} + bs.set_design_snapshot(design) + + bs2 = BuildState(str(tmp_project)) + bs2.load() + assert bs2.design_has_changed(design) is False + assert bs2.get_previous_architecture() == "Persistent arch" + + +class TestStageManipulation: + """Tests for mark_stages_stale, remove_stages, add_stages, renumber_stages.""" + + def _sample_stages(self): + return [ + {"stage": 1, "name": "Foundation", "category": "infra", + "services": [], "status": "generated", "dir": "concept/infra/terraform/stage-1-foundation", + "files": ["main.tf"]}, + {"stage": 2, "name": "Data", "category": "data", + "services": [{"name": "sql", "computed_name": "sql-1", "resource_type": "Microsoft.Sql/servers", "sku": ""}], + "status": "generated", "dir": "concept/infra/terraform/stage-2-data", + "files": ["sql.tf"]}, + {"stage": 3, "name": "App", "category": "app", + "services": [], "status": "generated", "dir": "concept/apps/stage-3-api", + "files": ["app.py"]}, + {"stage": 4, "name": "Documentation", "category": "docs", + "services": [], "status": "generated", "dir": "concept/docs", + "files": ["DEPLOY.md"]}, + ] + + def test_mark_stages_stale(self, tmp_project): + from azext_prototype.stages.build_state import BuildState + + bs = BuildState(str(tmp_project)) + bs.set_deployment_plan(self._sample_stages()) + + bs.mark_stages_stale([2, 3]) + + assert bs.get_stage(1)["status"] == "generated" + assert bs.get_stage(2)["status"] == "pending" + assert bs.get_stage(3)["status"] == "pending" + assert bs.get_stage(4)["status"] == "generated" + + def test_remove_stages(self, tmp_project): + from azext_prototype.stages.build_state import BuildState + + bs = BuildState(str(tmp_project)) + bs.set_deployment_plan(self._sample_stages()) + bs._state["files_generated"] = ["main.tf", "sql.tf", "app.py", "DEPLOY.md"] + + bs.remove_stages([2]) + + stage_nums = [s["stage"] for s in bs.state["deployment_stages"]] + assert 2 not in stage_nums + assert len(bs.state["deployment_stages"]) == 3 + # sql.tf should be removed from files_generated + assert "sql.tf" not in bs.state["files_generated"] + assert "main.tf" in bs.state["files_generated"] + + def test_add_stages(self, tmp_project): + from azext_prototype.stages.build_state import BuildState + + bs = BuildState(str(tmp_project)) + bs.set_deployment_plan(self._sample_stages()) + + new_stages = [ + {"name": "Redis Cache", "category": "data", + "services": [{"name": "redis", "computed_name": "redis-1", + "resource_type": "Microsoft.Cache/redis", "sku": "Basic"}]}, + ] + bs.add_stages(new_stages) + + stages = bs.state["deployment_stages"] + # Should be inserted before docs (stage 4 originally) + # After renumbering: Foundation(1), Data(2), App(3), Redis(4), Docs(5) + assert len(stages) == 5 + assert stages[3]["name"] == "Redis Cache" + assert stages[3]["stage"] == 4 + assert stages[4]["name"] == "Documentation" + assert stages[4]["stage"] == 5 + + def test_renumber_stages(self, tmp_project): + from azext_prototype.stages.build_state import BuildState + + bs = BuildState(str(tmp_project)) + # Set up stages with gaps + bs._state["deployment_stages"] = [ + {"stage": 1, "name": "A", "category": "infra", "services": [], "status": "generated", "dir": "", "files": []}, + {"stage": 5, "name": "B", "category": "data", "services": [], "status": "pending", "dir": "", "files": []}, + {"stage": 10, "name": "C", "category": "docs", "services": [], "status": "pending", "dir": "", "files": []}, + ] + + bs.renumber_stages() + + assert bs.state["deployment_stages"][0]["stage"] == 1 + assert bs.state["deployment_stages"][1]["stage"] == 2 + assert bs.state["deployment_stages"][2]["stage"] == 3 + + +class TestArchitectureDiff: + """Tests for _diff_architectures and _parse_diff_result.""" + + def test_diff_architectures_parses_response(self, build_context, build_registry, mock_architect_agent_for_build): + from azext_prototype.stages.build_session import BuildSession + + session = BuildSession(build_context, build_registry) + + existing = [ + {"stage": 1, "name": "Foundation", "category": "infra", "services": [{"name": "key-vault"}], + "status": "generated", "dir": "", "files": []}, + {"stage": 2, "name": "Data", "category": "data", "services": [{"name": "sql"}], + "status": "generated", "dir": "", "files": []}, + ] + + diff_response = json.dumps({ + "unchanged": [1], + "modified": [2], + "removed": [], + "added": [{"name": "Redis", "category": "data", "services": []}], + "plan_restructured": False, + "summary": "Modified data stage; added Redis.", + }) + mock_architect_agent_for_build.execute.return_value = _make_response( + f"```json\n{diff_response}\n```" + ) + + result = session._diff_architectures("old arch", "new arch", existing) + + assert result["unchanged"] == [1] + assert result["modified"] == [2] + assert result["removed"] == [] + assert len(result["added"]) == 1 + assert result["added"][0]["name"] == "Redis" + assert result["plan_restructured"] is False + + def test_diff_architectures_fallback_no_architect(self, build_context, build_registry): + from azext_prototype.stages.build_session import BuildSession + + # Remove the architect agent + session = BuildSession(build_context, build_registry) + session._architect_agent = None + + existing = [ + {"stage": 1, "name": "A", "category": "infra", "services": [], "status": "generated", "dir": "", "files": []}, + {"stage": 2, "name": "B", "category": "data", "services": [], "status": "generated", "dir": "", "files": []}, + ] + + result = session._diff_architectures("old", "new", existing) + + # Fallback: all stages marked as modified + assert set(result["modified"]) == {1, 2} + assert result["unchanged"] == [] + + def test_parse_diff_result_defaults_to_unchanged(self, build_context, build_registry): + from azext_prototype.stages.build_session import BuildSession + + session = BuildSession(build_context, build_registry) + existing = [ + {"stage": 1, "name": "A", "category": "infra", "services": [], "status": "generated", "dir": "", "files": []}, + {"stage": 2, "name": "B", "category": "data", "services": [], "status": "generated", "dir": "", "files": []}, + {"stage": 3, "name": "C", "category": "app", "services": [], "status": "generated", "dir": "", "files": []}, + ] + + # Only mention stage 2 as modified; 1 and 3 should default to unchanged + content = json.dumps({"modified": [2], "summary": "test"}) + result = session._parse_diff_result(content, existing) + + assert result is not None + assert 1 in result["unchanged"] + assert 3 in result["unchanged"] + assert result["modified"] == [2] + + def test_parse_diff_result_invalid_json(self, build_context, build_registry): + from azext_prototype.stages.build_session import BuildSession + + session = BuildSession(build_context, build_registry) + result = session._parse_diff_result("This is not JSON", []) + assert result is None + + +class TestIncrementalBuildSession: + """End-to-end tests for the incremental build flow.""" + + def test_incremental_run_no_changes(self, build_context, build_registry): + """When design hasn't changed and all stages are generated, report up to date.""" + from azext_prototype.stages.build_session import BuildSession + + session = BuildSession(build_context, build_registry) + + design = {"architecture": "Sample arch"} + + # Set up: pre-populate with generated stages and a matching snapshot + session._build_state.set_deployment_plan([ + {"stage": 1, "name": "Foundation", "category": "infra", + "services": [], "status": "generated", "dir": "", "files": ["main.tf"]}, + {"stage": 2, "name": "Docs", "category": "docs", + "services": [], "status": "generated", "dir": "concept/docs", "files": ["README.md"]}, + ]) + session._build_state.set_design_snapshot(design) + + printed = [] + inputs = iter(["done"]) + + result = session.run( + design=design, + input_fn=lambda p: next(inputs), + print_fn=lambda m: printed.append(m), + ) + + output = "\n".join(printed) + assert "up to date" in output.lower() + assert result.review_accepted is True + + def test_incremental_run_with_changes(self, build_context, build_registry, mock_architect_agent_for_build, mock_tf_agent): + """When design has changed, only affected stages should be regenerated.""" + from azext_prototype.stages.build_session import BuildSession + + session = BuildSession(build_context, build_registry) + + old_design = {"architecture": "Original architecture with Key Vault"} + new_design = {"architecture": "Updated architecture with Key Vault + Redis"} + + # Set up existing build + session._build_state.set_deployment_plan([ + {"stage": 1, "name": "Foundation", "category": "infra", + "services": [{"name": "key-vault"}], "status": "generated", + "dir": "concept/infra/terraform/stage-1-foundation", "files": ["main.tf"]}, + {"stage": 2, "name": "Documentation", "category": "docs", + "services": [], "status": "generated", "dir": "concept/docs", "files": ["README.md"]}, + ]) + session._build_state.set_design_snapshot(old_design) + + # Mock architect: stage 1 unchanged, no removed, add Redis + diff_response = json.dumps({ + "unchanged": [1], + "modified": [], + "removed": [], + "added": [{"name": "Redis Cache", "category": "data", + "services": [{"name": "redis-cache", "computed_name": "redis-1", + "resource_type": "Microsoft.Cache/redis", "sku": "Basic"}]}], + "plan_restructured": False, + "summary": "Added Redis Cache stage.", + }) + mock_architect_agent_for_build.execute.return_value = _make_response( + f"```json\n{diff_response}\n```" + ) + + printed = [] + inputs = iter(["", "done"]) + + with patch("azext_prototype.stages.build_session.GovernanceContext") as mock_gov_cls: + mock_gov_cls.return_value.check_response_for_violations.return_value = [] + session._governance = mock_gov_cls.return_value + session._policy_resolver._governance = mock_gov_cls.return_value + + with patch("azext_prototype.stages.build_session.AgentOrchestrator") as mock_orch: + mock_orch.return_value.delegate.return_value = _make_response("QA ok") + + result = session.run( + design=new_design, + input_fn=lambda p: next(inputs), + print_fn=lambda m: printed.append(m), + ) + + output = "\n".join(printed) + assert "Design changes detected" in output + assert "Added 1 new stage" in output + assert result.cancelled is False + + def test_incremental_run_plan_restructured(self, build_context, build_registry, mock_architect_agent_for_build, mock_tf_agent): + """When plan_restructured is True, a full re-derive should be offered.""" + from azext_prototype.stages.build_session import BuildSession + + session = BuildSession(build_context, build_registry) + + old_design = {"architecture": "Simple architecture"} + new_design = {"architecture": "Completely redesigned architecture"} + + session._build_state.set_deployment_plan([ + {"stage": 1, "name": "Foundation", "category": "infra", + "services": [], "status": "generated", "dir": "", "files": ["main.tf"]}, + ]) + session._build_state.set_design_snapshot(old_design) + + # First call: diff says plan_restructured + diff_response = json.dumps({ + "unchanged": [], + "modified": [1], + "removed": [], + "added": [], + "plan_restructured": True, + "summary": "Major restructuring needed.", + }) + + # Second call: re-derive returns new plan + new_plan = { + "stages": [ + {"stage": 1, "name": "New Foundation", "category": "infra", + "dir": "concept/infra/terraform/stage-1-new", + "services": [], "status": "pending", "files": []}, + {"stage": 2, "name": "Documentation", "category": "docs", + "dir": "concept/docs", + "services": [], "status": "pending", "files": []}, + ] + } + + call_count = [0] + def architect_side_effect(ctx, task): + call_count[0] += 1 + if call_count[0] == 1: + return _make_response(f"```json\n{diff_response}\n```") + else: + return _make_response(f"```json\n{json.dumps(new_plan)}\n```") + + mock_architect_agent_for_build.execute.side_effect = architect_side_effect + + printed = [] + # First prompt: confirm re-derive (Enter), second: confirm plan, third: done + inputs = iter(["", "", "done"]) + + with patch("azext_prototype.stages.build_session.GovernanceContext") as mock_gov_cls: + mock_gov_cls.return_value.check_response_for_violations.return_value = [] + session._governance = mock_gov_cls.return_value + session._policy_resolver._governance = mock_gov_cls.return_value + + with patch("azext_prototype.stages.build_session.AgentOrchestrator") as mock_orch: + mock_orch.return_value.delegate.return_value = _make_response("QA ok") + + result = session.run( + design=new_design, + input_fn=lambda p: next(inputs), + print_fn=lambda m: printed.append(m), + ) + + output = "\n".join(printed) + assert "full plan re-derive" in output.lower() + assert result.cancelled is False + + # ====================================================================== # Telemetry tests # ====================================================================== @@ -1614,3 +2018,137 @@ def test_advisory_qa_header_says_advisory(self, tmp_project): assert "Advisory Notes" in output # Should NOT contain "QA Review:" as a section header assert "QA Review:" not in output + + +# ====================================================================== +# Stable ID tests +# ====================================================================== + +class TestStableIds: + + def test_stable_ids_assigned_on_set_deployment_plan(self, tmp_project): + from azext_prototype.stages.build_state import BuildState + + bs = BuildState(str(tmp_project)) + stages = [ + {"stage": 1, "name": "Foundation", "category": "infra", "services": [], "status": "pending", "files": []}, + {"stage": 2, "name": "Data Layer", "category": "data", "services": [], "status": "pending", "files": []}, + ] + bs.set_deployment_plan(stages) + + for s in bs.state["deployment_stages"]: + assert "id" in s + assert s["id"] # non-empty + assert bs.state["deployment_stages"][0]["id"] == "foundation" + assert bs.state["deployment_stages"][1]["id"] == "data-layer" + + def test_stable_ids_preserved_on_renumber(self, tmp_project): + from azext_prototype.stages.build_state import BuildState + + bs = BuildState(str(tmp_project)) + stages = [ + {"stage": 1, "name": "Foundation", "category": "infra", "services": [], "status": "pending", "files": []}, + {"stage": 2, "name": "Data Layer", "category": "data", "services": [], "status": "pending", "files": []}, + ] + bs.set_deployment_plan(stages) + + original_ids = [s["id"] for s in bs.state["deployment_stages"]] + bs.renumber_stages() + new_ids = [s["id"] for s in bs.state["deployment_stages"]] + assert original_ids == new_ids + + def test_stable_ids_unique_on_name_collision(self, tmp_project): + from azext_prototype.stages.build_state import BuildState + + bs = BuildState(str(tmp_project)) + stages = [ + {"stage": 1, "name": "Foundation", "category": "infra", "services": [], "status": "pending", "files": []}, + {"stage": 2, "name": "Foundation", "category": "infra", "services": [], "status": "pending", "files": []}, + ] + bs.set_deployment_plan(stages) + + ids = [s["id"] for s in bs.state["deployment_stages"]] + assert len(set(ids)) == 2 # all unique + assert ids[0] == "foundation" + assert ids[1] == "foundation-2" + + def test_stable_ids_backfilled_on_load(self, tmp_project): + from azext_prototype.stages.build_state import BuildState + + # Write a legacy state file without ids + state_dir = Path(str(tmp_project)) / ".prototype" / "state" + state_dir.mkdir(parents=True, exist_ok=True) + legacy = { + "deployment_stages": [ + {"stage": 1, "name": "Foundation", "category": "infra", "services": [], "status": "generated", "files": []}, + ], + "templates_used": [], + "iac_tool": "terraform", + "_metadata": {"created": None, "last_updated": None, "iteration": 0}, + } + with open(state_dir / "build.yaml", "w") as f: + yaml.dump(legacy, f) + + bs = BuildState(str(tmp_project)) + bs.load() + assert bs.state["deployment_stages"][0]["id"] == "foundation" + assert bs.state["deployment_stages"][0]["deploy_mode"] == "auto" + + def test_get_stage_by_id(self, tmp_project): + from azext_prototype.stages.build_state import BuildState + + bs = BuildState(str(tmp_project)) + stages = [ + {"stage": 1, "name": "Foundation", "category": "infra", "services": [], "status": "pending", "files": []}, + {"stage": 2, "name": "Data Layer", "category": "data", "services": [], "status": "pending", "files": []}, + ] + bs.set_deployment_plan(stages) + + found = bs.get_stage_by_id("data-layer") + assert found is not None + assert found["name"] == "Data Layer" + assert bs.get_stage_by_id("nonexistent") is None + + def test_deploy_mode_in_stage_schema(self, tmp_project): + from azext_prototype.stages.build_state import BuildState + + bs = BuildState(str(tmp_project)) + stages = [ + { + "stage": 1, + "name": "Manual Upload", + "category": "external", + "services": [], + "status": "pending", + "files": [], + "deploy_mode": "manual", + "manual_instructions": "Upload the notebook to the Fabric workspace.", + }, + { + "stage": 2, + "name": "Foundation", + "category": "infra", + "services": [], + "status": "pending", + "files": [], + }, + ] + bs.set_deployment_plan(stages) + + assert bs.state["deployment_stages"][0]["deploy_mode"] == "manual" + assert "Upload" in bs.state["deployment_stages"][0]["manual_instructions"] + assert bs.state["deployment_stages"][1]["deploy_mode"] == "auto" + assert bs.state["deployment_stages"][1]["manual_instructions"] is None + + def test_add_stages_assigns_ids(self, tmp_project): + from azext_prototype.stages.build_state import BuildState + + bs = BuildState(str(tmp_project)) + bs.set_deployment_plan([ + {"stage": 1, "name": "Foundation", "category": "infra", "services": [], "status": "pending", "files": []}, + ]) + bs.add_stages([ + {"name": "API Layer", "category": "app"}, + ]) + ids = [s["id"] for s in bs.state["deployment_stages"]] + assert "api-layer" in ids diff --git a/tests/test_coverage_gaps.py b/tests/test_coverage_gaps.py index 667fed0..440a0b2 100644 --- a/tests/test_coverage_gaps.py +++ b/tests/test_coverage_gaps.py @@ -316,65 +316,29 @@ def test_deploy_app_stage_no_scripts(self, tmp_path): class TestPrototypeDesign: """Test the design command.""" - @patch(f"{_MOD}._prepare_command") - @patch(f"{_MOD}._check_guards") - def test_design_interactive(self, mock_guards, mock_prep, project_with_config, mock_ai_provider): + @patch(f"{_MOD}._run_tui") + @patch(f"{_MOD}._get_project_dir") + def test_design_interactive(self, mock_dir, mock_tui, project_with_config): from azext_prototype.custom import prototype_design - from azext_prototype.ai.provider import AIResponse - - mock_ai_provider.chat.return_value = AIResponse(content="Architecture plan", model="gpt-4o") - mock_registry = MagicMock() - architect = MagicMock() - architect.name = "cloud-architect" - architect.execute.return_value = AIResponse(content="# Architecture\nDesign", model="gpt-4o") - mock_registry.find_by_capability.return_value = [architect] - - mock_ctx = MagicMock() - mock_ctx.project_dir = str(project_with_config) - mock_prep.return_value = (str(project_with_config), MagicMock(), mock_registry, mock_ctx) + mock_dir.return_value = str(project_with_config) cmd = MagicMock() - # Discovery now happens inside DesignStage.execute — mock it - with patch("azext_prototype.stages.design_stage.DiscoverySession") as MockDS: - from azext_prototype.stages.discovery import DiscoveryResult - MockDS.return_value.run.return_value = DiscoveryResult( - requirements="User wants an API", - conversation=[], - policy_overrides=[], - exchange_count=2, - ) - result = prototype_design(cmd, json_output=True) - # Design stage calls execute which returns a dict + result = prototype_design(cmd, json_output=True) assert isinstance(result, dict) + mock_tui.assert_called_once() - @patch(f"{_MOD}._prepare_command") - @patch(f"{_MOD}._check_guards") - def test_design_with_context(self, mock_guards, mock_prep, project_with_config, mock_ai_provider): + @patch(f"{_MOD}._run_tui") + @patch(f"{_MOD}._get_project_dir") + def test_design_with_context(self, mock_dir, mock_tui, project_with_config): from azext_prototype.custom import prototype_design - from azext_prototype.ai.provider import AIResponse - - mock_registry = MagicMock() - architect = MagicMock() - architect.name = "cloud-architect" - architect.execute.return_value = AIResponse(content="# Architecture", model="gpt-4o") - mock_registry.find_by_capability.return_value = [architect] - mock_ctx = MagicMock() - mock_ctx.project_dir = str(project_with_config) - mock_prep.return_value = (str(project_with_config), MagicMock(), mock_registry, mock_ctx) + mock_dir.return_value = str(project_with_config) cmd = MagicMock() - with patch("azext_prototype.stages.design_stage.DiscoverySession") as MockDS: - from azext_prototype.stages.discovery import DiscoveryResult - MockDS.return_value.run.return_value = DiscoveryResult( - requirements="Build an API with Cosmos DB", - conversation=[], - policy_overrides=[], - exchange_count=2, - ) - result = prototype_design(cmd, context="Build an API with Cosmos DB", json_output=True) + result = prototype_design(cmd, context="Build an API with Cosmos DB", json_output=True) assert isinstance(result, dict) + mock_tui.assert_called_once() class TestPrototypeGenerateDocs: diff --git a/tests/test_deploy_session.py b/tests/test_deploy_session.py index 10daf02..76d4592 100644 --- a/tests/test_deploy_session.py +++ b/tests/test_deploy_session.py @@ -736,6 +736,10 @@ def test_deploy_failure_qa_routing(self, mock_tf, mock_sub, mock_login, mock_sub # Mock QA agent response session._qa_agent = MagicMock() session._qa_agent.execute.return_value = _make_response("Check your service principal credentials.") + # Clear fix agents so remediation is skipped (this test verifies QA routing only) + session._iac_agents = {} + session._dev_agent = None + session._architect_agent = None inputs = iter(["", "done"]) # confirm, then done output = [] @@ -2269,3 +2273,1176 @@ def test_resolve_context_no_oid_when_lookup_fails(self, _mock_lookup, tmp_projec session._resolve_context("sub-123", None) assert "TF_VAR_deployer_object_id" not in session._deploy_env + + +# ====================================================================== +# Natural Language Intent Detection — Deploy Integration +# ====================================================================== + + +class TestNaturalLanguageIntentDeploy: + """Test that natural language triggers correct deploy commands.""" + + def _make_session(self, project_dir, build_stages=None): + """Create a DeploySession with dependencies mocked.""" + from azext_prototype.agents.base import AgentContext + from azext_prototype.agents.registry import AgentRegistry + from azext_prototype.agents.builtin import register_all_builtin + from azext_prototype.stages.deploy_session import DeploySession + + config_path = Path(project_dir) / "prototype.yaml" + if not config_path.exists(): + config_data = { + "project": {"name": "test", "location": "eastus", "iac_tool": "terraform"}, + "ai": {"provider": "github-models"}, + } + with open(config_path, "w") as f: + yaml.dump(config_data, f) + + _write_build_yaml(project_dir, stages=build_stages) + + context = AgentContext( + project_config={"project": {"iac_tool": "terraform"}}, + project_dir=str(project_dir), + ai_provider=MagicMock(), + ) + registry = AgentRegistry() + register_all_builtin(registry) + + return DeploySession(context, registry) + + @patch("azext_prototype.stages.deploy_session.subprocess.run", return_value=MagicMock(returncode=0, stdout="Terraform v1.7.0\n", stderr="")) + @patch("azext_prototype.stages.deploy_session.check_az_login", return_value=True) + @patch("azext_prototype.stages.deploy_session.get_current_subscription", return_value="sub-123") + @patch("azext_prototype.stages.deploy_session.deploy_terraform", return_value={"status": "deployed"}) + def test_nl_deploy_stage_1(self, mock_tf, mock_sub, mock_login, mock_subprocess, tmp_project): + """'deploy stage 1' in natural language triggers deploy.""" + stages = [ + { + "stage": 1, "name": "Infra", "category": "infra", + "services": [], "dir": "concept/infra/terraform", + "status": "generated", "files": [], + }, + ] + (tmp_project / "concept" / "infra" / "terraform").mkdir(parents=True, exist_ok=True) + + session = self._make_session(tmp_project, build_stages=stages) + inputs = iter(["", "deploy stage 1", "done"]) + output = [] + result = session.run( + subscription="sub-123", + input_fn=lambda p: next(inputs), + print_fn=lambda msg: output.append(msg), + ) + joined = "\n".join(output) + # Should show deploy success or at least process the deploy command + assert "deployed" in joined.lower() or "Stage 1" in joined + + def test_nl_describe_stage(self, tmp_project): + """'describe stage 1' shows stage details.""" + session = self._make_session(tmp_project) + inputs = iter(["", "describe stage 1", "done"]) + output = [] + session.run( + subscription="sub-123", + input_fn=lambda p: next(inputs), + print_fn=lambda msg: output.append(msg), + ) + joined = "\n".join(output) + assert "Foundation" in joined or "Stage 1" in joined + + +# ====================================================================== +# Deploy State Remediation tests +# ====================================================================== + +class TestDeployStateRemediation: + """Tests for remediation state tracking in DeployState.""" + + def test_mark_stage_remediating(self, tmp_project): + from azext_prototype.stages.deploy_state import DeployState + + build_path = _write_build_yaml(tmp_project) + ds = DeployState(str(tmp_project)) + ds.load_from_build_state(build_path) + + ds.mark_stage_failed(1, "auth error") + ds.mark_stage_remediating(1) + + stage = ds.get_stage(1) + assert stage["deploy_status"] == "remediating" + assert stage["remediation_attempts"] == 1 + + def test_remediation_attempts_increment(self, tmp_project): + from azext_prototype.stages.deploy_state import DeployState + + build_path = _write_build_yaml(tmp_project) + ds = DeployState(str(tmp_project)) + ds.load_from_build_state(build_path) + + ds.mark_stage_remediating(1) + assert ds.get_stage(1)["remediation_attempts"] == 1 + + ds.mark_stage_remediating(1) + assert ds.get_stage(1)["remediation_attempts"] == 2 + + ds.mark_stage_remediating(1) + assert ds.get_stage(1)["remediation_attempts"] == 3 + + def test_reset_stage_to_pending(self, tmp_project): + from azext_prototype.stages.deploy_state import DeployState + + build_path = _write_build_yaml(tmp_project) + ds = DeployState(str(tmp_project)) + ds.load_from_build_state(build_path) + + ds.mark_stage_failed(1, "timeout") + assert ds.get_stage(1)["deploy_status"] == "failed" + assert ds.get_stage(1)["deploy_error"] == "timeout" + + ds.reset_stage_to_pending(1) + stage = ds.get_stage(1) + assert stage["deploy_status"] == "pending" + assert stage["deploy_error"] == "" + + def test_add_patch_stages(self, tmp_project): + from azext_prototype.stages.deploy_state import DeployState + + build_path = _write_build_yaml(tmp_project) + ds = DeployState(str(tmp_project)) + ds.load_from_build_state(build_path) + + new_stages = [ + {"stage": 0, "name": "Patch Fix", "category": "infra"}, + ] + ds.add_patch_stages(new_stages) + + stages = ds.state["deployment_stages"] + assert len(stages) == 4 + # Should have deploy-specific fields + patch_stage = [s for s in stages if s["name"] == "Patch Fix"][0] + assert patch_stage["deploy_status"] == "pending" + assert patch_stage["remediation_attempts"] == 0 + assert patch_stage["deploy_timestamp"] is None + + def test_add_patch_stages_before_docs(self, tmp_project): + from azext_prototype.stages.deploy_state import DeployState + + stages = [ + {"stage": 1, "name": "Infra", "category": "infra", "services": [], "dir": "s1", "files": []}, + {"stage": 2, "name": "Docs", "category": "docs", "services": [], "dir": "s2", "files": []}, + ] + build_path = _write_build_yaml(tmp_project, stages=stages) + ds = DeployState(str(tmp_project)) + ds.load_from_build_state(build_path) + + ds.add_patch_stages([{"stage": 0, "name": "Patch", "category": "infra"}]) + + stage_names = [s["name"] for s in ds.state["deployment_stages"]] + # Patch should be before Docs + assert stage_names.index("Patch") < stage_names.index("Docs") + + def test_renumber_stages(self, tmp_project): + from azext_prototype.stages.deploy_state import DeployState + + build_path = _write_build_yaml(tmp_project) + ds = DeployState(str(tmp_project)) + ds.load_from_build_state(build_path) + + # Manually set non-sequential numbers + ds.state["deployment_stages"][0]["stage"] = 10 + ds.state["deployment_stages"][1]["stage"] = 20 + ds.state["deployment_stages"][2]["stage"] = 30 + + ds.renumber_stages() + + nums = [s["stage"] for s in ds.state["deployment_stages"]] + assert nums == [1, 2, 3] + + def test_remediation_attempts_in_load_from_build_state(self, tmp_project): + """Verify remediation_attempts field is added during build state import.""" + from azext_prototype.stages.deploy_state import DeployState + + build_path = _write_build_yaml(tmp_project) + ds = DeployState(str(tmp_project)) + ds.load_from_build_state(build_path) + + for stage in ds.state["deployment_stages"]: + assert "remediation_attempts" in stage + assert stage["remediation_attempts"] == 0 + + def test_remediating_status_icon(self, tmp_project): + from azext_prototype.stages.deploy_state import DeployState + + build_path = _write_build_yaml(tmp_project) + ds = DeployState(str(tmp_project)) + ds.load_from_build_state(build_path) + + ds.mark_stage_remediating(1) + status = ds.format_stage_status() + assert "<>" in status + + +# ====================================================================== +# Deploy Remediation Loop tests +# ====================================================================== + +class TestDeployRemediation: + """Tests for the deploy auto-remediation loop in DeploySession.""" + + _SENTINEL = object() + + def _make_session(self, project_dir, iac_tool="terraform", build_stages=None, ai_provider=_SENTINEL): + from azext_prototype.agents.base import AgentContext + from azext_prototype.agents.registry import AgentRegistry + from azext_prototype.agents.builtin import register_all_builtin + from azext_prototype.stages.deploy_session import DeploySession + + config_path = Path(project_dir) / "prototype.yaml" + if not config_path.exists(): + config_data = { + "project": {"name": "test", "location": "eastus", "iac_tool": iac_tool}, + "ai": {"provider": "github-models"}, + } + with open(config_path, "w") as f: + yaml.dump(config_data, f) + + _write_build_yaml(project_dir, stages=build_stages, iac_tool=iac_tool) + + provider = MagicMock() if ai_provider is self._SENTINEL else ai_provider + context = AgentContext( + project_config={"project": {"iac_tool": iac_tool}}, + project_dir=str(project_dir), + ai_provider=provider, + ) + registry = AgentRegistry() + register_all_builtin(registry) + + session = DeploySession(context, registry) + # Pre-load build state into deploy state + build_path = Path(project_dir) / ".prototype" / "state" / "build.yaml" + session._deploy_state.load_from_build_state(build_path) + return session + + def test_remediation_succeeds_first_attempt(self, tmp_project): + """Deploy fails -> QA diagnoses -> fix agent fixes -> redeploy succeeds.""" + stages = [ + {"stage": 1, "name": "Infra", "category": "infra", "services": [], + "dir": "concept/infra/terraform", "status": "generated", "files": []}, + ] + (tmp_project / "concept" / "infra" / "terraform").mkdir(parents=True, exist_ok=True) + (tmp_project / "concept" / "infra" / "terraform" / "main.tf").write_text("# original") + + session = self._make_session(tmp_project, build_stages=stages) + + # Mock QA agent + session._qa_agent = MagicMock() + session._qa_agent.execute.return_value = _make_response("Missing provider configuration. Add required_providers block.") + + # Mock architect agent + session._architect_agent = MagicMock() + session._architect_agent.execute.return_value = _make_response( + "Root cause: missing provider. Add azurerm provider config.\nNo downstream impact." + ) + + # Mock IaC agent (terraform) + mock_iac = MagicMock() + mock_iac.execute.return_value = _make_response( + "```main.tf\n# fixed provider config\nterraform { required_providers { azurerm = { source = \"hashicorp/azurerm\" } } }\n```" + ) + session._iac_agents["terraform"] = mock_iac + + result = {"status": "failed", "error": "Error: No provider configured"} + stage = session._deploy_state.get_stage(1) + output = [] + + with patch("azext_prototype.stages.deploy_session.deploy_terraform", return_value={"status": "deployed"}): + remediated = session._remediate_deploy_failure( + stage, result, False, lambda msg: output.append(msg), lambda p: "", + ) + + assert remediated is not None + assert remediated["status"] == "deployed" + joined = "\n".join(output) + assert "Remediating" in joined + assert "deployed successfully after remediation" in joined + + def test_remediation_succeeds_second_attempt(self, tmp_project): + """First redeploy fails, second attempt succeeds.""" + stages = [ + {"stage": 1, "name": "Infra", "category": "infra", "services": [], + "dir": "concept/infra/terraform", "status": "generated", "files": []}, + ] + (tmp_project / "concept" / "infra" / "terraform").mkdir(parents=True, exist_ok=True) + (tmp_project / "concept" / "infra" / "terraform" / "main.tf").write_text("# original") + + session = self._make_session(tmp_project, build_stages=stages) + + session._qa_agent = MagicMock() + session._qa_agent.execute.return_value = _make_response("Diagnosis: missing config") + + session._architect_agent = MagicMock() + session._architect_agent.execute.return_value = _make_response("Fix the provider.\n[]") + + mock_iac = MagicMock() + mock_iac.execute.return_value = _make_response( + "```main.tf\n# fixed\n```" + ) + session._iac_agents["terraform"] = mock_iac + + result = {"status": "failed", "error": "Error: provider error"} + stage = session._deploy_state.get_stage(1) + output = [] + + deploy_call_count = [0] + + def mock_deploy(*args, **kwargs): + deploy_call_count[0] += 1 + if deploy_call_count[0] <= 1: + return {"status": "failed", "error": "still broken"} + return {"status": "deployed"} + + with patch.object(session, "_deploy_single_stage", side_effect=mock_deploy): + remediated = session._remediate_deploy_failure( + stage, result, False, lambda msg: output.append(msg), lambda p: "", + ) + + assert remediated is not None + assert remediated["status"] == "deployed" + assert deploy_call_count[0] == 2 + + def test_remediation_exhausted(self, tmp_project): + """All remediation attempts fail — falls through.""" + stages = [ + {"stage": 1, "name": "Infra", "category": "infra", "services": [], + "dir": "concept/infra/terraform", "status": "generated", "files": []}, + ] + (tmp_project / "concept" / "infra" / "terraform").mkdir(parents=True, exist_ok=True) + (tmp_project / "concept" / "infra" / "terraform" / "main.tf").write_text("# original") + + session = self._make_session(tmp_project, build_stages=stages) + + session._qa_agent = MagicMock() + session._qa_agent.execute.return_value = _make_response("Diagnosis: broken") + + session._architect_agent = MagicMock() + session._architect_agent.execute.return_value = _make_response("Fix it.\n[]") + + mock_iac = MagicMock() + mock_iac.execute.return_value = _make_response("```main.tf\n# attempt\n```") + session._iac_agents["terraform"] = mock_iac + + result = {"status": "failed", "error": "persistent error"} + stage = session._deploy_state.get_stage(1) + output = [] + + with patch.object(session, "_deploy_single_stage", return_value={"status": "failed", "error": "still broken"}): + remediated = session._remediate_deploy_failure( + stage, result, False, lambda msg: output.append(msg), lambda p: "", + ) + + assert remediated is not None + assert remediated["status"] == "failed" + joined = "\n".join(output) + assert "Re-deploy failed" in joined + + def test_remediation_no_agents(self, tmp_project): + """Gracefully skipped when no fix agents are available.""" + stages = [ + {"stage": 1, "name": "Infra", "category": "infra", "services": [], + "dir": "concept/infra/terraform", "status": "generated", "files": []}, + ] + session = self._make_session(tmp_project, build_stages=stages) + + # Clear all agents + session._qa_agent = None + session._iac_agents = {} + session._dev_agent = None + session._architect_agent = None + + result = {"status": "failed", "error": "auth error"} + stage = session._deploy_state.get_stage(1) + output = [] + + remediated = session._remediate_deploy_failure( + stage, result, False, lambda msg: output.append(msg), lambda p: "", + ) + + assert remediated is None # No remediation attempted + + def test_remediation_qa_cannot_diagnose(self, tmp_project): + """Stops early when QA can't diagnose.""" + stages = [ + {"stage": 1, "name": "Infra", "category": "infra", "services": [], + "dir": "concept/infra/terraform", "status": "generated", "files": []}, + ] + session = self._make_session(tmp_project, build_stages=stages) + + # QA returns no diagnosis + session._qa_agent = MagicMock() + session._qa_agent.execute.return_value = _make_response("") + + mock_iac = MagicMock() + session._iac_agents["terraform"] = mock_iac + + result = {"status": "failed", "error": "auth error"} + stage = session._deploy_state.get_stage(1) + output = [] + + remediated = session._remediate_deploy_failure( + stage, result, False, lambda msg: output.append(msg), lambda p: "", + ) + + # Should not have called the IaC agent since QA couldn't diagnose + mock_iac.execute.assert_not_called() + + def test_remediation_updates_build_state(self, tmp_project): + """Build.yaml files list is updated after remediation writes.""" + stages = [ + {"stage": 1, "name": "Infra", "category": "infra", "services": [], + "dir": "concept/infra/terraform", "status": "generated", + "files": ["concept/infra/terraform/main.tf"]}, + ] + (tmp_project / "concept" / "infra" / "terraform").mkdir(parents=True, exist_ok=True) + (tmp_project / "concept" / "infra" / "terraform" / "main.tf").write_text("# original") + + session = self._make_session(tmp_project, build_stages=stages) + + content = "```main.tf\n# fixed content\n```" + stage = session._deploy_state.get_stage(1) + written = session._write_stage_files(stage, content) + + assert len(written) == 1 + assert "main.tf" in written[0] + + # Verify build state was updated + from azext_prototype.stages.build_state import BuildState + bs = BuildState(str(tmp_project)) + bs.load() + build_stage = bs.state["deployment_stages"][0] + assert build_stage["files"] == written + + @patch("azext_prototype.stages.deploy_session.subprocess.run", return_value=MagicMock(returncode=0, stdout="Terraform v1.7.0\n", stderr="")) + @patch("azext_prototype.stages.deploy_session.check_az_login", return_value=True) + @patch("azext_prototype.stages.deploy_session.get_current_subscription", return_value="sub-123") + @patch("azext_prototype.stages.deploy_session.deploy_terraform") + def test_slash_deploy_routes_through_remediation(self, mock_tf, mock_sub, mock_login, mock_subprocess, tmp_project): + """/deploy N triggers remediation on failure.""" + stages = [ + {"stage": 1, "name": "Infra", "category": "infra", "services": [], + "dir": "concept/infra/terraform", "status": "generated", "files": []}, + ] + (tmp_project / "concept" / "infra" / "terraform").mkdir(parents=True, exist_ok=True) + + session = self._make_session(tmp_project, build_stages=stages) + + mock_tf.return_value = {"status": "failed", "error": "auth error"} + output = [] + + with patch.object(session, "_handle_deploy_failure", return_value={"status": "failed", "error": "auth error"}) as mock_handle: + session._handle_slash_command( + "/deploy 1", False, False, + lambda msg: output.append(msg), lambda p: "", + ) + + # _handle_deploy_failure should have been called + mock_handle.assert_called_once() + + @patch("azext_prototype.stages.deploy_session.deploy_terraform") + def test_slash_redeploy_routes_through_remediation(self, mock_tf, tmp_project): + """/redeploy N triggers remediation on failure.""" + stages = [ + {"stage": 1, "name": "Infra", "category": "infra", "services": [], + "dir": "concept/infra/terraform", "status": "generated", "files": []}, + ] + (tmp_project / "concept" / "infra" / "terraform").mkdir(parents=True, exist_ok=True) + + session = self._make_session(tmp_project, build_stages=stages) + session._deploy_env = {"ARM_SUBSCRIPTION_ID": "sub-123"} + + mock_tf.return_value = {"status": "failed", "error": "deploy error"} + output = [] + + with patch.object(session, "_handle_deploy_failure", return_value={"status": "failed", "error": "deploy error"}) as mock_handle: + session._handle_slash_command( + "/redeploy 1", False, False, + lambda msg: output.append(msg), lambda p: "", + ) + + mock_handle.assert_called_once() + + def test_downstream_impact_detected(self, tmp_project): + """Architect flags downstream stages for regeneration.""" + stages = [ + {"stage": 1, "name": "Foundation", "category": "infra", "services": [], + "dir": "concept/infra/terraform/stage-1", "status": "generated", "files": []}, + {"stage": 2, "name": "Data Layer", "category": "data", "services": [], + "dir": "concept/infra/terraform/stage-2", "status": "generated", "files": []}, + {"stage": 3, "name": "App", "category": "app", "services": [], + "dir": "concept/apps/stage-3", "status": "generated", "files": []}, + ] + session = self._make_session(tmp_project, build_stages=stages) + + # Mark stage 2 and 3 as pending (downstream) + session._deploy_state.get_stage(2)["deploy_status"] = "pending" + session._deploy_state.get_stage(3)["deploy_status"] = "pending" + + # Architect returns stage 2 as affected + session._architect_agent = MagicMock() + session._architect_agent.execute.return_value = _make_response("Affected stages: [2]") + + stage = session._deploy_state.get_stage(1) + result = session._check_downstream_impact(stage, "Changed outputs from foundation") + + assert 2 in result + assert 1 not in result # Not downstream of itself + + def test_downstream_regeneration(self, tmp_project): + """Flagged downstream stages get regenerated code.""" + stages = [ + {"stage": 1, "name": "Foundation", "category": "infra", "services": [], + "dir": "concept/infra/terraform/stage-1", "status": "generated", "files": []}, + {"stage": 2, "name": "Data Layer", "category": "data", "services": [], + "dir": "concept/infra/terraform/stage-2", "status": "generated", "files": []}, + ] + for s in stages: + (tmp_project / s["dir"]).mkdir(parents=True, exist_ok=True) + (tmp_project / s["dir"] / "main.tf").write_text("# original") + + session = self._make_session(tmp_project, build_stages=stages) + + # Mock IaC agent to return regenerated content + mock_iac = MagicMock() + mock_iac.execute.return_value = _make_response( + "```main.tf\n# regenerated with fixed references\n```" + ) + session._iac_agents["terraform"] = mock_iac + + output = [] + session._regenerate_downstream_stages( + [2], False, lambda msg: output.append(msg), + ) + + joined = "\n".join(output) + assert "regenerated" in joined.lower() + # Verify the file was actually written + content = (tmp_project / "concept" / "infra" / "terraform" / "stage-2" / "main.tf").read_text() + assert "regenerated" in content + + def test_handle_deploy_failure_returns_result(self, tmp_project): + """_handle_deploy_failure returns the remediation result.""" + stages = [ + {"stage": 1, "name": "Infra", "category": "infra", "services": [], + "dir": "concept/infra/terraform", "status": "generated", "files": []}, + ] + session = self._make_session(tmp_project, build_stages=stages) + + # No agents available — remediation returns None + session._qa_agent = None + session._iac_agents = {} + session._dev_agent = None + + result = {"status": "failed", "error": "auth error"} + stage = session._deploy_state.get_stage(1) + output = [] + + returned = session._handle_deploy_failure( + stage, result, False, + lambda msg: output.append(msg), lambda p: "", + ) + + # Should return original result when remediation not possible + assert returned["status"] == "failed" + # Should still show interactive options + joined = "\n".join(output) + assert "/deploy" in joined + + def test_no_ai_provider_skips_remediation(self, tmp_project): + """Remediation is skipped when ai_provider is None.""" + stages = [ + {"stage": 1, "name": "Infra", "category": "infra", "services": [], + "dir": "concept/infra/terraform", "status": "generated", "files": []}, + ] + session = self._make_session(tmp_project, build_stages=stages, ai_provider=None) + + result = {"status": "failed", "error": "auth error"} + stage = session._deploy_state.get_stage(1) + + remediated = session._remediate_deploy_failure( + stage, result, False, lambda msg: None, lambda p: "", + ) + + assert remediated is None + + +# ====================================================================== +# Build-Deploy Decoupling: Stable IDs, Sync, Splitting, Manual Steps +# ====================================================================== + +def _build_yaml_with_ids(stages=None, iac_tool="terraform"): + """Build YAML with stable IDs.""" + if stages is None: + stages = [ + { + "stage": 1, "name": "Foundation", "category": "infra", "id": "foundation", + "deploy_mode": "auto", "manual_instructions": None, + "services": [{"name": "key-vault", "computed_name": "kv-1", "resource_type": "Microsoft.KeyVault/vaults", "sku": "standard"}], + "status": "generated", "dir": "concept/infra/terraform/stage-1-foundation", "files": ["main.tf"], + }, + { + "stage": 2, "name": "Data Layer", "category": "data", "id": "data-layer", + "deploy_mode": "auto", "manual_instructions": None, + "services": [{"name": "sql-db", "computed_name": "sql-1", "resource_type": "Microsoft.Sql/servers", "sku": "S0"}], + "status": "generated", "dir": "concept/infra/terraform/stage-2-data", "files": ["main.tf"], + }, + { + "stage": 3, "name": "Application", "category": "app", "id": "application", + "deploy_mode": "auto", "manual_instructions": None, + "services": [{"name": "web-app", "computed_name": "app-1", "resource_type": "Microsoft.Web/sites", "sku": "B1"}], + "status": "generated", "dir": "concept/apps/stage-3-application", "files": ["app.py"], + }, + ] + return { + "iac_tool": iac_tool, + "deployment_stages": stages, + "_metadata": {"created": "2026-01-01T00:00:00", "last_updated": "2026-01-01T00:00:00", "iteration": 1}, + } + + +def _write_build_yaml_with_ids(project_dir, stages=None, iac_tool="terraform"): + """Write build.yaml with stable IDs.""" + state_dir = Path(project_dir) / ".prototype" / "state" + state_dir.mkdir(parents=True, exist_ok=True) + data = _build_yaml_with_ids(stages, iac_tool) + with open(state_dir / "build.yaml", "w", encoding="utf-8") as f: + yaml.dump(data, f, default_flow_style=False) + return state_dir / "build.yaml" + + +class TestSyncFromBuildState: + + def test_sync_from_build_state_fresh(self, tmp_project): + """First sync creates deploy stages from build stages.""" + from azext_prototype.stages.deploy_state import DeployState + + build_path = _write_build_yaml_with_ids(tmp_project) + ds = DeployState(str(tmp_project)) + result = ds.sync_from_build_state(build_path) + + assert result.created == 3 + assert result.matched == 0 + assert result.orphaned == 0 + assert len(ds.state["deployment_stages"]) == 3 + assert ds.state["deployment_stages"][0]["build_stage_id"] == "foundation" + + def test_sync_from_build_state_preserves_deploy_status(self, tmp_project): + """Matched stages keep their deploy state.""" + from azext_prototype.stages.deploy_state import DeployState + + build_path = _write_build_yaml_with_ids(tmp_project) + ds = DeployState(str(tmp_project)) + ds.load_from_build_state(build_path) + + # Deploy stage 1 + ds.mark_stage_deployed(1, output="done") + + # Re-sync + result = ds.sync_from_build_state(build_path) + assert result.matched == 3 + assert result.created == 0 + + stage1 = ds.state["deployment_stages"][0] + assert stage1["deploy_status"] == "deployed" + assert stage1["deploy_output"] == "done" + + def test_sync_from_build_state_detects_code_change(self, tmp_project): + """Changed files trigger _code_updated marking.""" + from azext_prototype.stages.deploy_state import DeployState + + build_path = _write_build_yaml_with_ids(tmp_project) + ds = DeployState(str(tmp_project)) + ds.load_from_build_state(build_path) + ds.mark_stage_deployed(1) + + # Update build state with new files + updated_stages = _build_yaml_with_ids()["deployment_stages"] + updated_stages[0]["files"] = ["main.tf", "variables.tf"] # changed + _write_build_yaml_with_ids(tmp_project, stages=updated_stages) + + result = ds.sync_from_build_state(build_path) + assert result.updated_code == 1 + assert ds.state["deployment_stages"][0].get("_code_updated") is True + + def test_sync_from_build_state_creates_new(self, tmp_project): + """New build stage creates new deploy stage.""" + from azext_prototype.stages.deploy_state import DeployState + + build_path = _write_build_yaml_with_ids(tmp_project) + ds = DeployState(str(tmp_project)) + ds.load_from_build_state(build_path) + + # Add new stage to build + stages = _build_yaml_with_ids()["deployment_stages"] + stages.append({ + "stage": 4, "name": "Monitoring", "category": "infra", "id": "monitoring", + "deploy_mode": "auto", "manual_instructions": None, + "services": [], "status": "generated", "dir": "concept/infra/terraform/stage-4-monitoring", "files": [], + }) + _write_build_yaml_with_ids(tmp_project, stages=stages) + + result = ds.sync_from_build_state(build_path) + assert result.created == 1 + assert len(ds.state["deployment_stages"]) == 4 + assert ds.state["deployment_stages"][3]["build_stage_id"] == "monitoring" + + def test_sync_from_build_state_with_substages(self, tmp_project): + """Split stages preserved across sync.""" + from azext_prototype.stages.deploy_state import DeployState + + build_path = _write_build_yaml_with_ids(tmp_project) + ds = DeployState(str(tmp_project)) + ds.load_from_build_state(build_path) + + # Split stage 2 into substages + ds.split_stage(2, [ + {"name": "Data Layer - Base", "dir": "concept/infra/terraform/stage-2-data"}, + {"name": "Data Layer - Schema", "dir": "concept/db/schema"}, + ]) + + # Re-sync — substages should be preserved + result = ds.sync_from_build_state(build_path) + data_stages = ds.get_stages_for_build_stage("data-layer") + assert len(data_stages) == 2 + assert data_stages[0]["substage_label"] == "a" + assert data_stages[1]["substage_label"] == "b" + + def test_sync_orphan_sets_removed_status(self, tmp_project): + """Removed build stage → deploy stage gets 'removed' status.""" + from azext_prototype.stages.deploy_state import DeployState + + build_path = _write_build_yaml_with_ids(tmp_project) + ds = DeployState(str(tmp_project)) + ds.load_from_build_state(build_path) + + # Remove a stage from build + stages = _build_yaml_with_ids()["deployment_stages"] + stages = [s for s in stages if s["id"] != "data-layer"] + _write_build_yaml_with_ids(tmp_project, stages=stages) + + result = ds.sync_from_build_state(build_path) + assert result.orphaned == 1 + + removed = [s for s in ds.state["deployment_stages"] if s.get("deploy_status") == "removed"] + assert len(removed) == 1 + assert removed[0]["build_stage_id"] == "data-layer" + + +class TestStageSpitting: + + def test_split_stage(self, tmp_project): + """Split creates substages with shared build_stage_id.""" + from azext_prototype.stages.deploy_state import DeployState + + build_path = _write_build_yaml_with_ids(tmp_project) + ds = DeployState(str(tmp_project)) + ds.load_from_build_state(build_path) + + ds.split_stage(2, [ + {"name": "Data - Base", "dir": "concept/infra/terraform/stage-2-data"}, + {"name": "Data - Schema", "dir": "concept/db/schema"}, + ]) + + # All substages share the same build_stage_id + data_stages = ds.get_stages_for_build_stage("data-layer") + assert len(data_stages) == 2 + assert data_stages[0]["substage_label"] == "a" + assert data_stages[1]["substage_label"] == "b" + assert data_stages[0]["_is_substage"] is True + assert data_stages[1]["_is_substage"] is True + + def test_split_stage_renumbering(self, tmp_project): + """After split, stage numbers are correct.""" + from azext_prototype.stages.deploy_state import DeployState + + build_path = _write_build_yaml_with_ids(tmp_project) + ds = DeployState(str(tmp_project)) + ds.load_from_build_state(build_path) + + ds.split_stage(2, [ + {"name": "Data - Base", "dir": "dir1"}, + {"name": "Data - Schema", "dir": "dir2"}, + ]) + + stages = ds.state["deployment_stages"] + # Stage 1 stays as 1, substages get stage 2 with labels, stage 3 stays + assert stages[0]["stage"] == 1 # Foundation + assert stages[1]["stage"] == 2 # Data - Base (2a) + assert stages[1]["substage_label"] == "a" + assert stages[2]["stage"] == 2 # Data - Schema (2b) + assert stages[2]["substage_label"] == "b" + assert stages[3]["stage"] == 3 # Application + + def test_get_stage_groups(self, tmp_project): + """Verify grouping by build_stage_id.""" + from azext_prototype.stages.deploy_state import DeployState + + build_path = _write_build_yaml_with_ids(tmp_project) + ds = DeployState(str(tmp_project)) + ds.load_from_build_state(build_path) + + ds.split_stage(2, [ + {"name": "Data - Base", "dir": "dir1"}, + {"name": "Data - Schema", "dir": "dir2"}, + ]) + + groups = ds.get_stage_groups() + assert "foundation" in groups + assert "data-layer" in groups + assert "application" in groups + assert len(groups["data-layer"]) == 2 + assert len(groups["foundation"]) == 1 + + def test_can_rollback_with_substages(self, tmp_project): + """Rollback checks work with substages.""" + from azext_prototype.stages.deploy_state import DeployState + + build_path = _write_build_yaml_with_ids(tmp_project) + ds = DeployState(str(tmp_project)) + ds.load_from_build_state(build_path) + + ds.split_stage(2, [ + {"name": "Data - Base", "dir": "dir1"}, + {"name": "Data - Schema", "dir": "dir2"}, + ]) + + # Deploy both substages + substages = ds.get_stages_for_build_stage("data-layer") + substages[0]["deploy_status"] = "deployed" + substages[1]["deploy_status"] = "deployed" + ds.save() + + # Can't rollback "a" while "b" is deployed + assert ds.can_rollback(2, "a") is False + # Can rollback "b" + assert ds.can_rollback(2, "b") is True + + def test_get_stage_by_display_id(self, tmp_project): + """Parse and lookup by compound display ID.""" + from azext_prototype.stages.deploy_state import DeployState + + build_path = _write_build_yaml_with_ids(tmp_project) + ds = DeployState(str(tmp_project)) + ds.load_from_build_state(build_path) + + ds.split_stage(2, [ + {"name": "Data - Base", "dir": "dir1"}, + {"name": "Data - Schema", "dir": "dir2"}, + ]) + + found = ds.get_stage_by_display_id("2a") + assert found is not None + assert found["name"] == "Data - Base" + + found_b = ds.get_stage_by_display_id("2b") + assert found_b is not None + assert found_b["name"] == "Data - Schema" + + +class TestDeployStateNewStatuses: + + def test_load_from_build_state_backward_compat(self, tmp_project): + """Legacy build state without IDs still imports correctly.""" + from azext_prototype.stages.deploy_state import DeployState + + # Write legacy build yaml (no id field) + build_path = _write_build_yaml(tmp_project) + ds = DeployState(str(tmp_project)) + result = ds.load_from_build_state(build_path) + + assert result is True + # build_stage_id should be auto-generated from name + for stage in ds.state["deployment_stages"]: + assert stage.get("build_stage_id") + + def test_destroy_stage(self, tmp_project): + """Destroyed status after rollback.""" + from azext_prototype.stages.deploy_state import DeployState + + build_path = _write_build_yaml_with_ids(tmp_project) + ds = DeployState(str(tmp_project)) + ds.load_from_build_state(build_path) + + ds.mark_stage_deployed(1) + ds.mark_stage_rolled_back(1) + ds.mark_stage_destroyed(1) + + assert ds.get_stage(1)["deploy_status"] == "destroyed" + + def test_destruction_declined_not_reprompted(self, tmp_project): + """_destruction_declined flag persists across save/load.""" + from azext_prototype.stages.deploy_state import DeployState + + build_path = _write_build_yaml_with_ids(tmp_project) + ds = DeployState(str(tmp_project)) + ds.load_from_build_state(build_path) + + stage = ds.get_stage(1) + stage["_destruction_declined"] = True + ds.save() + + ds2 = DeployState(str(tmp_project)) + ds2.load() + assert ds2.get_stage(1)["_destruction_declined"] is True + + def test_awaiting_manual_status(self, tmp_project): + """Manual step sets awaiting_manual status.""" + from azext_prototype.stages.deploy_state import DeployState + + build_path = _write_build_yaml_with_ids(tmp_project) + ds = DeployState(str(tmp_project)) + ds.load_from_build_state(build_path) + + ds.mark_stage_awaiting_manual(1) + assert ds.get_stage(1)["deploy_status"] == "awaiting_manual" + + +class TestManualStepDeploy: + + def test_manual_step_deploy(self, tmp_project): + """Manual stage shows instructions, waits for confirmation.""" + from azext_prototype.stages.deploy_state import DeployState + + stages = [ + { + "stage": 1, "name": "Upload Notebook", "category": "external", "id": "upload-notebook", + "deploy_mode": "manual", "manual_instructions": "Upload the notebook to Fabric workspace.", + "services": [], "status": "generated", + "dir": "concept/docs", "files": [], + }, + ] + build_path = _write_build_yaml_with_ids(tmp_project, stages=stages) + ds = DeployState(str(tmp_project)) + ds.load_from_build_state(build_path) + + # Verify the manual stage imported correctly + stage = ds.get_stage(1) + assert stage["deploy_mode"] == "manual" + assert "Upload" in stage["manual_instructions"] + + def test_manual_step_from_build(self, tmp_project): + """deploy_mode: 'manual' inherited from build stage via sync.""" + from azext_prototype.stages.deploy_state import DeployState + + stages = [ + { + "stage": 1, "name": "Foundation", "category": "infra", "id": "foundation", + "deploy_mode": "auto", "manual_instructions": None, + "services": [], "status": "generated", + "dir": "concept/infra/terraform/stage-1-foundation", "files": [], + }, + { + "stage": 2, "name": "Manual Config", "category": "external", "id": "manual-config", + "deploy_mode": "manual", "manual_instructions": "Configure the firewall rules manually.", + "services": [], "status": "generated", + "dir": "", "files": [], + }, + ] + build_path = _write_build_yaml_with_ids(tmp_project, stages=stages) + + ds = DeployState(str(tmp_project)) + ds.load_from_build_state(build_path) + + manual_stage = ds.state["deployment_stages"][1] + assert manual_stage["deploy_mode"] == "manual" + assert "firewall" in manual_stage["manual_instructions"] + + def test_code_split_syncs_back_to_build(self, tmp_project): + """Type A split: _sync_build_state uses build_stage_id for matching.""" + from azext_prototype.stages.build_state import BuildState + from azext_prototype.stages.deploy_state import DeployState + + build_path = _write_build_yaml_with_ids(tmp_project) + + # Load into deploy state + ds = DeployState(str(tmp_project)) + ds.load_from_build_state(build_path) + + # Load build state and verify get_stage_by_id works + bs = BuildState(str(tmp_project)) + bs.load() + + # Verify the build stage has the right id + build_stage = bs.get_stage_by_id("data-layer") + assert build_stage is not None + assert build_stage["name"] == "Data Layer" + + # Deploy stage links back correctly + deploy_stage = ds.state["deployment_stages"][1] + assert deploy_stage["build_stage_id"] == "data-layer" + + +class TestParseStageRef: + + def test_parse_simple_number(self): + from azext_prototype.stages.deploy_state import parse_stage_ref + + num, label = parse_stage_ref("5") + assert num == 5 + assert label is None + + def test_parse_substage(self): + from azext_prototype.stages.deploy_state import parse_stage_ref + + num, label = parse_stage_ref("5a") + assert num == 5 + assert label == "a" + + def test_parse_invalid(self): + from azext_prototype.stages.deploy_state import parse_stage_ref + + num, label = parse_stage_ref("abc") + assert num is None + assert label is None + + def test_parse_empty(self): + from azext_prototype.stages.deploy_state import parse_stage_ref + + num, label = parse_stage_ref("") + assert num is None + + def test_parse_with_whitespace(self): + from azext_prototype.stages.deploy_state import parse_stage_ref + + num, label = parse_stage_ref(" 3b ") + assert num == 3 + assert label == "b" + + +class TestRenumberWithSubstages: + + def test_renumber_preserves_substage_labels(self, tmp_project): + """Substages keep their labels and inherit parent number.""" + from azext_prototype.stages.deploy_state import DeployState + + build_path = _write_build_yaml_with_ids(tmp_project) + ds = DeployState(str(tmp_project)) + ds.load_from_build_state(build_path) + + # Split stage 2 + ds.split_stage(2, [ + {"name": "Data - Base", "dir": "dir1"}, + {"name": "Data - Schema", "dir": "dir2"}, + ]) + + # Remove stage 1 — renumber should shift substages + stages = ds.state["deployment_stages"] + ds._state["deployment_stages"] = [s for s in stages if s.get("build_stage_id") != "foundation"] + ds.renumber_stages() + + stages = ds.state["deployment_stages"] + # Now data substages should be stage 1 + assert stages[0]["stage"] == 1 + assert stages[0]["substage_label"] == "a" + assert stages[1]["stage"] == 1 + assert stages[1]["substage_label"] == "b" + # Application should be stage 2 + assert stages[2]["stage"] == 2 + assert stages[2]["substage_label"] is None + + +class TestFormatDisplayId: + + def test_format_top_level(self): + from azext_prototype.stages.deploy_state import _format_display_id + + assert _format_display_id({"stage": 3}) == "3" + + def test_format_substage(self): + from azext_prototype.stages.deploy_state import _format_display_id + + assert _format_display_id({"stage": 3, "substage_label": "b"}) == "3b" + + def test_format_no_label(self): + from azext_prototype.stages.deploy_state import _format_display_id + + assert _format_display_id({"stage": 1, "substage_label": None}) == "1" + + +class TestNewStatusIcons: + + def test_removed_icon(self): + from azext_prototype.stages.deploy_state import _status_icon + + assert _status_icon("removed") == "~~" + + def test_destroyed_icon(self): + from azext_prototype.stages.deploy_state import _status_icon + + assert _status_icon("destroyed") == "xx" + + def test_awaiting_manual_icon(self): + from azext_prototype.stages.deploy_state import _status_icon + + assert _status_icon("awaiting_manual") == "!!" + + def test_existing_icons_unchanged(self): + from azext_prototype.stages.deploy_state import _status_icon + + assert _status_icon("pending") == " " + assert _status_icon("deployed") == " v" + assert _status_icon("failed") == " x" + assert _status_icon("remediating") == "<>" + + +class TestDeployReportFormatting: + + def test_format_shows_removed_stages(self, tmp_project): + """Removed stages show with strikethrough in report.""" + from azext_prototype.stages.deploy_state import DeployState + + build_path = _write_build_yaml_with_ids(tmp_project) + ds = DeployState(str(tmp_project)) + ds.load_from_build_state(build_path) + ds.mark_stage_removed(2) + + report = ds.format_deploy_report() + assert "(Removed)" in report + assert "~~Data Layer~~" in report + + def test_format_shows_manual_badge(self, tmp_project): + """Manual stages show [Manual] badge.""" + from azext_prototype.stages.deploy_state import DeployState + + stages = [ + { + "stage": 1, "name": "Manual Step", "category": "external", "id": "manual", + "deploy_mode": "manual", "manual_instructions": "Do the thing.", + "services": [], "status": "generated", "dir": "", "files": [], + }, + ] + build_path = _write_build_yaml_with_ids(tmp_project, stages=stages) + ds = DeployState(str(tmp_project)) + ds.load_from_build_state(build_path) + + report = ds.format_deploy_report() + assert "[Manual]" in report + + status = ds.format_stage_status() + assert "[Manual]" in status + + def test_format_shows_substage_ids(self, tmp_project): + """Substages show compound display IDs like 2a, 2b.""" + from azext_prototype.stages.deploy_state import DeployState + + build_path = _write_build_yaml_with_ids(tmp_project) + ds = DeployState(str(tmp_project)) + ds.load_from_build_state(build_path) + + ds.split_stage(2, [ + {"name": "Data - Base", "dir": "dir1"}, + {"name": "Data - Schema", "dir": "dir2"}, + ]) + + status = ds.format_stage_status() + assert "2a" in status + assert "2b" in status diff --git a/tests/test_discovery.py b/tests/test_discovery.py index 33eacc5..e0b14a2 100644 --- a/tests/test_discovery.py +++ b/tests/test_discovery.py @@ -10,9 +10,13 @@ from azext_prototype.stages.discovery import ( DiscoverySession, DiscoveryResult, + Section, + extract_section_headers, + parse_sections, _READY_MARKER, _QUIT_WORDS, _DONE_WORDS, + _SECTION_COMPLETE_MARKER, ) @@ -321,6 +325,24 @@ def test_all_done_words(self, mock_agent_context, mock_registry, mock_biz_agent) ) assert not result.cancelled, f"'{word}' should end gracefully, not cancel" + def test_end_in_done_words(self): + """'end' should be recognized as a done word.""" + assert "end" in _DONE_WORDS + + def test_end_word_finishes_session(self, mock_agent_context, mock_registry, mock_biz_agent): + """Typing 'end' should complete the session (not cancel).""" + mock_agent_context.ai_provider.chat.side_effect = [ + _make_response("Hi! Tell me about your project."), + _make_response("## Summary\nHere's what we discussed."), + ] + session = DiscoverySession(mock_agent_context, mock_registry) + result = session.run( + input_fn=lambda _: "end", + print_fn=lambda x: None, + ) + assert not result.cancelled + assert result.exchange_count >= 1 + def test_eof_exits_gracefully(self, mock_agent_context, mock_registry, mock_biz_agent): mock_agent_context.ai_provider.chat.return_value = _make_response("Hi!") session = DiscoverySession(mock_agent_context, mock_registry) @@ -581,7 +603,7 @@ def test_design_stage_uses_discovery( mock_agent_context.project_dir = str(project_with_config) mock_agent_context.ai_provider.chat.return_value = _make_response( - "## Architecture\nMock output" + "Tell me more about your project." ) inputs = iter(["Build a REST API", "PostgreSQL, 50 users", "done"]) @@ -605,7 +627,7 @@ def test_cancelled_discovery_cancels_design( mock_agent_context.project_dir = str(project_with_config) mock_agent_context.ai_provider.chat.return_value = _make_response( - "## Architecture\nShould not appear" + "Tell me about your project." ) result = stage.execute( @@ -629,7 +651,7 @@ def test_design_stage_persists_policy_overrides( mock_agent_context.project_dir = str(project_with_config) mock_agent_context.ai_provider.chat.return_value = _make_response( - "## Architecture\nDesign with overrides" + "Architecture design with overrides." ) mock_result = DiscoveryResult( @@ -818,7 +840,7 @@ def test_restart_clears_conversation_history( class TestWhyCommand: def test_why_no_argument_shows_usage( - self, mock_agent_context, mock_registry, mock_biz_agent, capsys, + self, mock_agent_context, mock_registry, mock_biz_agent, ): """/why with no argument should show usage hint, not crash.""" mock_agent_context.ai_provider.chat.side_effect = [ @@ -828,17 +850,18 @@ def test_why_no_argument_shows_usage( session = DiscoverySession(mock_agent_context, mock_registry) inputs = iter(["/why", "done"]) + output = [] session.run( input_fn=lambda _: next(inputs), - print_fn=lambda x: None, + print_fn=output.append, ) - captured = capsys.readouterr() - assert "Usage" in captured.out or "/why" in captured.out + combined = "\n".join(str(x) for x in output) + assert "Usage" in combined or "/why" in combined def test_why_with_matching_query( - self, mock_agent_context, mock_registry, mock_biz_agent, capsys, + self, mock_agent_context, mock_registry, mock_biz_agent, ): """/why should find exchanges mentioning the queried topic.""" mock_agent_context.ai_provider.chat.side_effect = [ @@ -849,17 +872,18 @@ def test_why_with_matching_query( session = DiscoverySession(mock_agent_context, mock_registry) inputs = iter(["Use managed identity for auth", "/why managed identity", "done"]) + output = [] session.run( input_fn=lambda _: next(inputs), - print_fn=lambda x: None, + print_fn=output.append, ) - captured = capsys.readouterr() - assert "Exchange" in captured.out + combined = "\n".join(str(x) for x in output) + assert "Exchange" in combined def test_why_no_matches( - self, mock_agent_context, mock_registry, mock_biz_agent, capsys, + self, mock_agent_context, mock_registry, mock_biz_agent, ): """/why with no matching history should show 'no exchanges found'.""" mock_agent_context.ai_provider.chat.side_effect = [ @@ -869,14 +893,15 @@ def test_why_no_matches( session = DiscoverySession(mock_agent_context, mock_registry) inputs = iter(["/why kubernetes", "done"]) + output = [] session.run( input_fn=lambda _: next(inputs), - print_fn=lambda x: None, + print_fn=output.append, ) - captured = capsys.readouterr() - assert "No exchanges found" in captured.out + combined = "\n".join(str(x) for x in output) + assert "No exchanges found" in combined # ====================================================================== @@ -1185,3 +1210,661 @@ def test_summary_prompt_asks_for_no_skipped_sections( user_msgs = [m.content for m in messages if m.role == "user"] summary_prompt = user_msgs[-1] assert "None" in summary_prompt or "skip" in summary_prompt.lower() + + +# ====================================================================== +# Natural Language Intent Detection — Integration +# ====================================================================== + + +class TestNaturalLanguageIntentDiscovery: + """Test that natural language triggers the correct slash commands.""" + + def test_nl_open_items(self, mock_agent_context, mock_registry): + """'what are the open items' should trigger the /open display.""" + # Use return_value — any call returns a valid response (no headings + # to avoid triggering section-at-a-time gating) + mock_agent_context.ai_provider.chat.return_value = _make_response( + "Tell me about your project." + ) + session = DiscoverySession(mock_agent_context, mock_registry) + output = [] + inputs = iter(["what are the open items", "done"]) + result = session.run( + input_fn=lambda _: next(inputs), + print_fn=output.append, + ) + # The /open handler should have run and printed open items info + assert any("open" in o.lower() for o in output if isinstance(o, str)) + + def test_nl_status(self, mock_agent_context, mock_registry): + """'where do we stand' should trigger the /status display.""" + mock_agent_context.ai_provider.chat.return_value = _make_response( + "Tell me about your project." + ) + session = DiscoverySession(mock_agent_context, mock_registry) + output = [] + inputs = iter(["where do we stand", "done"]) + result = session.run( + input_fn=lambda _: next(inputs), + print_fn=output.append, + ) + assert any("status" in o.lower() or "discovery" in o.lower() for o in output if isinstance(o, str)) + + +# ====================================================================== +# extract_section_headers +# ====================================================================== + +class TestExtractSectionHeaders: + """Unit tests for extract_section_headers().""" + + def test_extracts_h2_headings(self): + text = "## Project Context & Scope\nSome text\n## Data & Content\nMore text" + result = extract_section_headers(text) + assert result == [("Project Context & Scope", 2), ("Data & Content", 2)] + + def test_extracts_h3_headings(self): + text = "### Authentication\nDetails\n### Authorization\nMore details" + result = extract_section_headers(text) + assert result == [("Authentication", 3), ("Authorization", 3)] + + def test_mixed_h2_h3(self): + text = "## Overview\nText\n### Sub-section\nText\n## Architecture\nText" + result = extract_section_headers(text) + assert result == [("Overview", 2), ("Sub-section", 3), ("Architecture", 2)] + + def test_skips_structural_headings(self): + text = ( + "## Project Context\nText\n" + "## Summary\nText\n" + "## Policy Overrides\nText\n" + "## Next Steps\nText\n" + ) + result = extract_section_headers(text) + assert result == [("Project Context", 2)] + + def test_skips_policy_override_singular(self): + text = "## Policy Override\nText" + result = extract_section_headers(text) + assert result == [] + + def test_skips_short_headings(self): + text = "## AB\nText\n## OK\nMore" + result = extract_section_headers(text) + assert result == [] + + def test_empty_string(self): + assert extract_section_headers("") == [] + + def test_no_headings(self): + text = "Just plain text without any headings at all." + assert extract_section_headers(text) == [] + + def test_h1_not_extracted(self): + """Only ## and ### are extracted, not #.""" + text = "# Title\n## Section One\nContent" + result = extract_section_headers(text) + assert result == [("Section One", 2)] + + def test_strips_whitespace(self): + text = "## Padded Heading \nText" + result = extract_section_headers(text) + assert result == [("Padded Heading", 2)] + + def test_case_insensitive_skip(self): + text = "## SUMMARY\nText\n## NEXT STEPS\nText\n## Actual Content\nText" + result = extract_section_headers(text) + assert result == [("Actual Content", 2)] + + def test_bold_headings_extracted(self): + """**Bold Heading** on its own line should be extracted as level 2.""" + text = ( + "Let me ask about your project.\n" + "\n" + "**Hosting & Deployment**\n" + "How do you plan to host this?\n" + "\n" + "**Data Layer**\n" + "What database will you use?" + ) + result = extract_section_headers(text) + assert ("Hosting & Deployment", 2) in result + assert ("Data Layer", 2) in result + + def test_bold_inline_not_extracted(self): + """Bold text mid-line should NOT be extracted as a heading.""" + text = "I think **this is important** for the project." + result = extract_section_headers(text) + assert result == [] + + def test_bold_and_markdown_headings_merged(self): + """Both ## headings and **bold headings** should be found with levels.""" + text = ( + "## Architecture Overview\n" + "Details here.\n" + "\n" + "**Security Considerations**\n" + "More details." + ) + result = extract_section_headers(text) + assert ("Architecture Overview", 2) in result + assert ("Security Considerations", 2) in result + + def test_bold_headings_deduped(self): + """Duplicate headings (same text in both formats) should appear once.""" + text = ( + "## Security\n" + "Details.\n" + "\n" + "**Security**\n" + "More details." + ) + result = extract_section_headers(text) + texts = [h[0] for h in result] + assert texts.count("Security") == 1 + + def test_bold_headings_skip_structural(self): + """Bold structural headings (Summary, Next Steps) should be skipped.""" + text = "**Summary**\nText\n**Actual Topic**\nMore text" + result = extract_section_headers(text) + texts = [h[0] for h in result] + assert "Summary" not in texts + assert "Actual Topic" in texts + + def test_bold_heading_too_short(self): + """Bold headings under 3 chars should be skipped.""" + text = "**AB**\nText" + result = extract_section_headers(text) + assert result == [] + + def test_skip_what_ive_understood(self): + """'What I've Understood So Far' and variants should be filtered.""" + text = ( + "## What I've Understood So Far\nStuff\n" + "## What We've Covered\nMore stuff\n" + "## Actual Topic\nReal content" + ) + result = extract_section_headers(text) + texts = [h[0] for h in result] + assert "What I've Understood So Far" not in texts + assert "What We've Covered" not in texts + assert "Actual Topic" in texts + + def test_position_ordering(self): + """Headers should be sorted by their position in the response.""" + text = ( + "**First Bold**\n" + "Text\n" + "## Second Markdown\n" + "Text\n" + "**Third Bold**\n" + "Text" + ) + result = extract_section_headers(text) + assert result == [("First Bold", 2), ("Second Markdown", 2), ("Third Bold", 2)] + + +# ====================================================================== +# section_fn callback integration +# ====================================================================== + +class TestSectionFnCallback: + """Verify that section_fn is called with extracted headers during a session.""" + + def test_section_fn_receives_headers( + self, mock_agent_context, mock_registry, + ): + """section_fn should be called upfront with all headers from the AI response.""" + mock_agent_context.ai_provider.chat.side_effect = [ + _make_response( + "## Project Context & Scope\n" + "Let me ask about your project.\n" + "## Data & Content\n" + "What kind of data will you store?" + ), + # Summary after "done" exits the section loop + _make_response("## Summary\nAll done."), + ] + + session = DiscoverySession(mock_agent_context, mock_registry) + captured_headers = [] + + def _section_fn(headers): + captured_headers.extend(headers) + + # "done" exits from the section loop immediately + result = session.run( + input_fn=lambda _: "done", + print_fn=lambda x: None, + section_fn=_section_fn, + ) + + texts = [h[0] for h in captured_headers] + assert "Project Context & Scope" in texts + assert "Data & Content" in texts + + def test_section_fn_not_called_when_none( + self, mock_agent_context, mock_registry, + ): + """When section_fn is None, no error should occur.""" + mock_agent_context.ai_provider.chat.side_effect = [ + _make_response("## Some Heading\nContent"), + _make_response("## Summary\nDone"), + ] + + session = DiscoverySession(mock_agent_context, mock_registry) + # Should not raise — section_fn defaults to None + result = session.run( + input_fn=lambda _: "done", + print_fn=lambda x: None, + ) + assert not result.cancelled + + +# ====================================================================== +# response_fn callback integration +# ====================================================================== + +class TestResponseFnCallback: + """Verify that response_fn is called with agent responses during a session.""" + + def test_response_fn_receives_agent_responses( + self, mock_agent_context, mock_registry, + ): + """response_fn should be called with cleaned agent responses.""" + # Use a response without ## headings so it takes the non-sectioned path + mock_agent_context.ai_provider.chat.side_effect = [ + _make_response("Let me understand your project. What are you building?"), + _make_response("An API. Got it."), + _make_response("Final summary."), + ] + + session = DiscoverySession(mock_agent_context, mock_registry) + captured = [] + + def _response_fn(content): + captured.append(content) + + inputs = iter(["A REST API", "done"]) + result = session.run( + input_fn=lambda _: next(inputs), + print_fn=lambda x: None, + response_fn=_response_fn, + ) + + # response_fn should have been called for the opening and the reply + assert len(captured) == 2 + assert "understand your project" in captured[0] + assert "API" in captured[1] + + def test_response_fn_not_called_when_none( + self, mock_agent_context, mock_registry, + ): + """When response_fn is None, print_fn should be used instead.""" + mock_agent_context.ai_provider.chat.side_effect = [ + _make_response("What are you building?"), + _make_response("## Summary\nDone"), + ] + + session = DiscoverySession(mock_agent_context, mock_registry) + printed = [] + + result = session.run( + input_fn=lambda _: "done", + print_fn=lambda x: printed.append(x), + ) + + # print_fn should have received the response + assert any("building" in p.lower() for p in printed if isinstance(p, str)) + + def test_response_fn_takes_precedence_over_print_fn( + self, mock_agent_context, mock_registry, + ): + """response_fn should be used instead of print_fn for agent responses.""" + mock_agent_context.ai_provider.chat.side_effect = [ + _make_response("Tell me about your project."), + _make_response("## Summary\nDone."), + ] + + session = DiscoverySession(mock_agent_context, mock_registry) + printed = [] + response_captured = [] + + result = session.run( + input_fn=lambda _: "done", + print_fn=lambda x: printed.append(x), + response_fn=lambda x: response_captured.append(x), + ) + + # response_fn should have the agent response + assert len(response_captured) == 1 + assert "Tell me about your project" in response_captured[0] + # print_fn should NOT have the agent response text + assert not any("Tell me about your project" in p for p in printed if isinstance(p, str)) + + +# ====================================================================== +# parse_sections() +# ====================================================================== + +class TestParseSections: + """Verify section parsing from AI responses.""" + + def test_basic_section_splitting(self): + text = ( + "Here's my analysis.\n\n" + "## Authentication\n" + "How do users sign in?\n\n" + "## Data Layer\n" + "What database do you prefer?" + ) + preamble, sections = parse_sections(text) + assert preamble == "Here's my analysis." + assert len(sections) == 2 + assert sections[0].heading == "Authentication" + assert sections[0].level == 2 + assert "How do users sign in?" in sections[0].content + assert sections[1].heading == "Data Layer" + assert "What database" in sections[1].content + + def test_preamble_only(self): + text = "No headings here, just a plain response." + preamble, sections = parse_sections(text) + assert preamble == text + assert sections == [] + + def test_empty_preamble(self): + text = "## First Topic\nQuestion here." + preamble, sections = parse_sections(text) + assert preamble == "" + assert len(sections) == 1 + + def test_skip_headings_filtered(self): + text = ( + "## Authentication\nHow do users sign in?\n\n" + "## Summary\nThis is a summary.\n\n" + "## Next Steps\nDo this next." + ) + _, sections = parse_sections(text) + assert len(sections) == 1 + assert sections[0].heading == "Authentication" + + def test_task_id_generation(self): + text = "## Data & Content\nWhat kind of data?" + _, sections = parse_sections(text) + assert len(sections) == 1 + assert sections[0].task_id == "design-section-data-content" + + def test_bold_headings(self): + text = ( + "Here's what I need to know.\n\n" + "**Authentication & Security**\n" + "How do users log in?\n\n" + "**Data Storage**\n" + "What database?" + ) + preamble, sections = parse_sections(text) + assert len(sections) == 2 + assert sections[0].heading == "Authentication & Security" + assert sections[0].level == 2 + + def test_level_3_headings(self): + text = "### Sub-topic\nDetailed question." + _, sections = parse_sections(text) + assert len(sections) == 1 + assert sections[0].level == 3 + + def test_mixed_heading_levels(self): + text = ( + "## Main Topic\nOverview.\n\n" + "### Sub-topic\nDetail." + ) + _, sections = parse_sections(text) + assert len(sections) == 2 + assert sections[0].level == 2 + assert sections[1].level == 3 + + def test_empty_string(self): + preamble, sections = parse_sections("") + assert preamble == "" + assert sections == [] + + def test_duplicate_headings_deduped(self): + text = ( + "## Authentication\nFirst mention.\n\n" + "## Authentication\nSecond mention." + ) + _, sections = parse_sections(text) + assert len(sections) == 1 + + +# ====================================================================== +# Section completion via AI "Yes" gate +# ====================================================================== + +class TestSectionDoneDetection: + """Verify section completion detection via AI 'Yes' gate. + + The old heuristic-based ``_is_section_done()`` has been replaced with + an explicit AI confirmation step. When the AI responds with exactly + "Yes" (case-insensitive, optional trailing period) the section is + considered complete. + """ + + def test_continue_in_done_words(self): + """'continue' should be accepted as a done keyword.""" + assert "continue" in _DONE_WORDS + + +# ====================================================================== +# Section-at-a-time flow integration +# ====================================================================== + +class TestSectionAtATimeFlow: + """Verify sections are shown one at a time with follow-ups.""" + + def test_sections_shown_one_at_a_time( + self, mock_agent_context, mock_registry, + ): + """Each section should be shown individually, collecting user input.""" + mock_agent_context.ai_provider.chat.side_effect = [ + # Initial response with 2 sections + _make_response( + "Great, let me explore a few areas.\n\n" + "## Authentication\n" + "How do users sign in?\n\n" + "## Data Layer\n" + "What database do you need?" + ), + # Follow-up for section 1 (auth) — marks section done + _make_response("Yes"), + # Follow-up for section 2 (data) — marks section done + _make_response("Yes"), + # Summary after free-form "done" + _make_response("## Summary\nAll done."), + ] + + session = DiscoverySession(mock_agent_context, mock_registry) + printed = [] + inputs = iter([ + "We use Entra ID", # Answer for section 1 + "SQL Database", # Answer for section 2 + "done", # Exit free-form loop + ]) + + result = session.run( + input_fn=lambda _: next(inputs), + print_fn=lambda x: printed.append(x), + ) + assert not result.cancelled + # Both sections should have been displayed + printed_text = "\n".join(str(p) for p in printed) + assert "Authentication" in printed_text + assert "Data Layer" in printed_text + + def test_skip_advances_to_next_section( + self, mock_agent_context, mock_registry, + ): + """Typing 'skip' should advance to the next section.""" + mock_agent_context.ai_provider.chat.side_effect = [ + _make_response( + "## Auth\nHow do users sign in?\n\n" + "## Data\nWhat database?" + ), + # Follow-up for data section + _make_response("Yes"), + # Summary + _make_response("## Summary\nDone."), + ] + + session = DiscoverySession(mock_agent_context, mock_registry) + inputs = iter([ + "skip", # Skip auth section + "Cosmos DB", # Answer data section + "done", # Exit free-form + ]) + + result = session.run( + input_fn=lambda _: next(inputs), + print_fn=lambda x: None, + ) + assert not result.cancelled + + def test_done_exits_section_loop( + self, mock_agent_context, mock_registry, + ): + """Typing 'done' during section loop should jump to summary.""" + mock_agent_context.ai_provider.chat.side_effect = [ + _make_response( + "## Auth\nHow do users sign in?\n\n" + "## Data\nWhat database?" + ), + # Summary produced after "done" + _make_response("## Summary\nFinal summary."), + ] + + session = DiscoverySession(mock_agent_context, mock_registry) + result = session.run( + input_fn=lambda _: "done", + print_fn=lambda x: None, + ) + assert not result.cancelled + assert result.requirements # Should have summary + + def test_quit_cancels_from_section_loop( + self, mock_agent_context, mock_registry, + ): + """Typing 'quit' during section loop should cancel the session.""" + mock_agent_context.ai_provider.chat.side_effect = [ + _make_response( + "## Auth\nHow do users sign in?\n\n" + "## Data\nWhat database?" + ), + ] + + session = DiscoverySession(mock_agent_context, mock_registry) + result = session.run( + input_fn=lambda _: "quit", + print_fn=lambda x: None, + ) + assert result.cancelled + + def test_follow_ups_iterate_within_section( + self, mock_agent_context, mock_registry, + ): + """Multiple follow-ups within a section should work.""" + mock_agent_context.ai_provider.chat.side_effect = [ + _make_response("## Auth\nHow do users sign in?"), + # First follow-up — needs more info + _make_response("What about service-to-service auth?"), + # Second follow-up — section done + _make_response("Yes"), + # Summary + _make_response("## Summary\nDone."), + ] + + session = DiscoverySession(mock_agent_context, mock_registry) + inputs = iter([ + "Entra ID for users", # First answer + "Managed identity for services", # Second answer + "done", # Exit free-form + ]) + + result = session.run( + input_fn=lambda _: next(inputs), + print_fn=lambda x: None, + ) + assert not result.cancelled + assert result.exchange_count >= 3 # opening + 2 follow-ups + + def test_update_task_fn_called( + self, mock_agent_context, mock_registry, + ): + """update_task_fn should be called with in_progress and completed.""" + mock_agent_context.ai_provider.chat.side_effect = [ + _make_response("## Auth\nHow do users sign in?"), + _make_response("Yes"), + _make_response("## Summary\nDone."), + ] + + session = DiscoverySession(mock_agent_context, mock_registry) + task_updates = [] + + def _update_task_fn(tid, status): + task_updates.append((tid, status)) + + inputs = iter(["Entra ID", "done"]) + result = session.run( + input_fn=lambda _: next(inputs), + print_fn=lambda x: None, + update_task_fn=_update_task_fn, + ) + + # Should have in_progress then completed for the auth section + assert ("design-section-auth", "in_progress") in task_updates + assert ("design-section-auth", "completed") in task_updates + + def test_no_sections_fallback( + self, mock_agent_context, mock_registry, + ): + """When no sections are found, should display full response.""" + mock_agent_context.ai_provider.chat.side_effect = [ + _make_response("Tell me what you want to build."), + _make_response("## Summary\nDone."), + ] + + session = DiscoverySession(mock_agent_context, mock_registry) + printed = [] + + result = session.run( + input_fn=lambda _: "done", + print_fn=lambda x: printed.append(x), + ) + + assert not result.cancelled + printed_text = "\n".join(str(p) for p in printed) + assert "Tell me what you want to build" in printed_text + + def test_yes_gate_not_displayed( + self, mock_agent_context, mock_registry, + ): + """AI 'Yes' confirmation should not be printed to the user.""" + mock_agent_context.ai_provider.chat.side_effect = [ + _make_response("## Auth\nHow do users sign in?"), + _make_response("Yes"), + _make_response("## Summary\nDone."), + ] + + session = DiscoverySession(mock_agent_context, mock_registry) + printed = [] + + inputs = iter(["Entra ID", "continue"]) + result = session.run( + input_fn=lambda _: next(inputs), + print_fn=lambda x: printed.append(x), + ) + + printed_text = "\n".join(str(p) for p in printed) + # The "Yes" response should not appear in output + assert "\nYes\n" not in printed_text diff --git a/tests/test_intent.py b/tests/test_intent.py new file mode 100644 index 0000000..838721b --- /dev/null +++ b/tests/test_intent.py @@ -0,0 +1,546 @@ +"""Tests for azext_prototype.stages.intent — natural language intent classification.""" + +from __future__ import annotations + +import pytest +from unittest.mock import MagicMock + +from azext_prototype.ai.provider import AIResponse +from azext_prototype.stages.intent import ( + CommandDef, + IntentClassifier, + IntentKind, + IntentPattern, + IntentResult, + build_backlog_classifier, + build_build_classifier, + build_deploy_classifier, + build_discovery_classifier, + read_files_for_session, +) + + +# ====================================================================== +# Helpers +# ====================================================================== + + +def _make_response(content: str) -> AIResponse: + return AIResponse(content=content, model="gpt-4o", usage={}) + + +def _make_classifier_with_ai(response_content: str) -> IntentClassifier: + """Build a classifier with a mock AI provider that returns the given content.""" + provider = MagicMock() + provider.chat.return_value = _make_response(response_content) + c = IntentClassifier(ai_provider=provider) + c.add_command_def(CommandDef("/open", "Show open items")) + c.add_command_def(CommandDef("/status", "Show status")) + return c + + +# ====================================================================== +# TestIntentClassifier — core classifier +# ====================================================================== + + +class TestIntentClassifier: + """Core IntentClassifier tests.""" + + def test_empty_input_conversational(self): + c = IntentClassifier() + result = c.classify("") + assert result.kind == IntentKind.CONVERSATIONAL + + def test_whitespace_only_conversational(self): + c = IntentClassifier() + result = c.classify(" ") + assert result.kind == IntentKind.CONVERSATIONAL + + def test_slash_command_passthrough(self): + """Explicit slash commands should return CONVERSATIONAL for pass-through.""" + c = IntentClassifier() + result = c.classify("/open") + assert result.kind == IntentKind.CONVERSATIONAL + + def test_ai_classification_parses_command(self): + """AI classification used when keywords have partial match.""" + c = _make_classifier_with_ai('{"command": "/open", "args": "", "is_command": true}') + # Register a keyword with partial signal (one keyword = 0.2, below 0.5 threshold) + c.register(IntentPattern(command="/open", keywords=["items"], min_confidence=0.5)) + result = c.classify("what are the open items") + assert result.kind == IntentKind.COMMAND + assert result.command == "/open" + + def test_ai_classification_conversational(self): + c = _make_classifier_with_ai('{"command": "", "args": "", "is_command": false}') + result = c.classify("I think we should use PostgreSQL") + assert result.kind == IntentKind.CONVERSATIONAL + + def test_ai_classification_with_args(self): + """AI classification used when keywords have partial match.""" + c = _make_classifier_with_ai('{"command": "/deploy", "args": "3", "is_command": true}') + # Register a keyword with partial signal + c.register(IntentPattern(command="/deploy", keywords=["deploy"], min_confidence=0.5)) + result = c.classify("deploy stage 3") + assert result.kind == IntentKind.COMMAND + assert result.command == "/deploy" + assert result.args == "3" + + def test_ai_classification_falls_back_on_parse_error(self): + """When AI returns unparseable JSON, fall through to keyword scoring.""" + c = _make_classifier_with_ai("This is not JSON at all") + # Register a keyword pattern that will match (keyword + phrase = 0.6) + c.register(IntentPattern( + command="/open", + keywords=["open"], + phrases=["open items"], + )) + result = c.classify("what are the open items") + assert result.kind == IntentKind.COMMAND + assert result.command == "/open" + + def test_ai_classification_falls_back_when_no_provider(self): + """When no AI provider, keyword fallback runs.""" + c = IntentClassifier() # No AI provider + c.add_command_def(CommandDef("/open", "Show open items")) + c.register(IntentPattern( + command="/open", + keywords=["open"], + phrases=["open items"], + )) + result = c.classify("what are the open items") + assert result.kind == IntentKind.COMMAND + assert result.command == "/open" + + def test_keyword_matching_triggers_command(self): + c = IntentClassifier() + c.register(IntentPattern( + command="/status", + keywords=["status"], + phrases=["build status"], + )) + result = c.classify("what's the build status") + assert result.kind == IntentKind.COMMAND + assert result.command == "/status" + + def test_below_threshold_conversational(self): + c = IntentClassifier() + c.register(IntentPattern( + command="/deploy", + keywords=[], + phrases=["deploy stage"], + min_confidence=0.5, + )) + # "the" keyword alone shouldn't match + result = c.classify("I like the architecture") + assert result.kind == IntentKind.CONVERSATIONAL + + def test_phrase_outscores_keyword(self): + c = IntentClassifier() + c.register(IntentPattern( + command="/files", + keywords=["files"], + phrases=["generated files"], + )) + result = c.classify("show me the generated files") + # phrase(0.4) + keyword(0.2) = 0.6 > 0.5 threshold + assert result.kind == IntentKind.COMMAND + assert result.command == "/files" + assert result.confidence >= 0.6 + + def test_regex_match_extracts_args(self): + c = IntentClassifier() + c.register(IntentPattern( + command="/deploy", + regex_patterns=[r"deploy\s+(?:stage\s+)?\d+"], + arg_extractor=lambda t: " ".join(__import__("re").findall(r"\d+", t)), + )) + result = c.classify("deploy stage 3") + assert result.kind == IntentKind.COMMAND + assert result.command == "/deploy" + assert result.args == "3" + + def test_file_read_detection(self): + c = IntentClassifier() + result = c.classify("read artifacts from ~/docs/requirements") + assert result.kind == IntentKind.READ_FILES + assert result.command == "__read_files" + assert "docs/requirements" in result.args + + def test_file_load_detection(self): + c = IntentClassifier() + result = c.classify("load files from /tmp/specs") + assert result.kind == IntentKind.READ_FILES + assert "/tmp/specs" in result.args + + def test_file_import_detection(self): + c = IntentClassifier() + result = c.classify("import documents from ./design") + assert result.kind == IntentKind.READ_FILES + + def test_no_false_file_read(self): + """'I read a book yesterday' should NOT match file read pattern.""" + c = IntentClassifier() + result = c.classify("I read a book yesterday") + assert result.kind == IntentKind.CONVERSATIONAL + + def test_ai_markdown_fenced_json(self): + """AI response with markdown fences should still parse.""" + c = _make_classifier_with_ai('```json\n{"command": "/status", "args": "", "is_command": true}\n```') + # Register a keyword with partial signal + c.register(IntentPattern(command="/status", keywords=["status"], min_confidence=0.5)) + result = c.classify("what's the status") + assert result.kind == IntentKind.COMMAND + assert result.command == "/status" + + def test_ai_network_error_falls_back(self): + """Network errors should fall through to keyword fallback.""" + provider = MagicMock() + provider.chat.side_effect = ConnectionError("timeout") + c = IntentClassifier(ai_provider=provider) + c.add_command_def(CommandDef("/open", "Show open items")) + c.register(IntentPattern( + command="/open", + keywords=["open"], + phrases=["open items"], + )) + result = c.classify("what are the open items") + assert result.kind == IntentKind.COMMAND + assert result.command == "/open" + + +# ====================================================================== +# TestDiscoveryIntents — discovery session factory +# ====================================================================== + + +class TestDiscoveryIntents: + """Tests for the discovery session classifier (keyword fallback path).""" + + def test_open_items(self): + c = build_discovery_classifier() + result = c.classify("What are the open items?") + assert result.kind == IntentKind.COMMAND + assert result.command == "/open" + + def test_status(self): + c = build_discovery_classifier() + result = c.classify("Where do we stand?") + assert result.kind == IntentKind.COMMAND + assert result.command == "/status" + + def test_summary(self): + c = build_discovery_classifier() + result = c.classify("Give me a summary") + assert result.kind == IntentKind.COMMAND + assert result.command == "/summary" + + def test_conversational_feedback(self): + """Design feedback should NOT be classified as a command.""" + c = build_discovery_classifier() + result = c.classify("I don't like the database choice, change it to PostgreSQL") + assert result.kind == IntentKind.CONVERSATIONAL + + def test_why_command(self): + c = build_discovery_classifier() + result = c.classify("Why did we choose Cosmos DB?") + assert result.kind == IntentKind.COMMAND + assert result.command == "/why" + assert "Cosmos DB" in result.args + + def test_restart(self): + c = build_discovery_classifier() + result = c.classify("let's start over") + assert result.kind == IntentKind.COMMAND + assert result.command == "/restart" + + def test_unresolved(self): + c = build_discovery_classifier() + result = c.classify("What's still unresolved?") + assert result.kind == IntentKind.COMMAND + assert result.command == "/open" + + +# ====================================================================== +# TestDeployIntents — deploy session factory +# ====================================================================== + + +class TestDeployIntents: + """Tests for the deploy session classifier (keyword fallback path).""" + + def test_deploy_stage_3(self): + c = build_deploy_classifier() + result = c.classify("deploy stage 3") + assert result.kind == IntentKind.COMMAND + assert result.command == "/deploy" + assert "3" in result.args + + def test_deploy_all(self): + c = build_deploy_classifier() + result = c.classify("deploy all stages") + assert result.kind == IntentKind.COMMAND + assert result.command == "/deploy" + + def test_rollback_stage_2(self): + c = build_deploy_classifier() + result = c.classify("rollback stage 2") + assert result.kind == IntentKind.COMMAND + assert result.command == "/rollback" + assert "2" in result.args + + def test_deploy_stages_3_and_4(self): + c = build_deploy_classifier() + result = c.classify("deploy stages 3 and 4") + assert result.kind == IntentKind.COMMAND + assert result.command == "/deploy" + assert "3" in result.args + assert "4" in result.args + + def test_deployment_status(self): + c = build_deploy_classifier() + result = c.classify("what's the deployment status") + assert result.kind == IntentKind.COMMAND + assert result.command == "/status" + + def test_describe_stage(self): + c = build_deploy_classifier() + result = c.classify("describe stage 3") + assert result.kind == IntentKind.COMMAND + assert result.command == "/describe" + assert "3" in result.args + + def test_whats_being_deployed(self): + c = build_deploy_classifier() + result = c.classify("what's being deployed in stage 2") + assert result.kind == IntentKind.COMMAND + assert result.command == "/describe" + assert "2" in result.args + + def test_rollback_all(self): + c = build_deploy_classifier() + result = c.classify("roll back all") + assert result.kind == IntentKind.COMMAND + assert result.command == "/rollback" + assert "all" in result.args + + def test_undo_stage(self): + c = build_deploy_classifier() + result = c.classify("undo stage 1") + assert result.kind == IntentKind.COMMAND + assert result.command == "/rollback" + assert "1" in result.args + + +# ====================================================================== +# TestBuildIntents — build session factory +# ====================================================================== + + +class TestBuildIntents: + """Tests for the build session classifier (keyword fallback path).""" + + def test_generated_files(self): + c = build_build_classifier() + result = c.classify("show me the generated files") + assert result.kind == IntentKind.COMMAND + assert result.command == "/files" + + def test_build_status(self): + c = build_build_classifier() + result = c.classify("what's the build status") + assert result.kind == IntentKind.COMMAND + assert result.command == "/status" + + def test_describe_stage(self): + c = build_build_classifier() + result = c.classify("describe stage 1") + assert result.kind == IntentKind.COMMAND + assert result.command == "/describe" + assert "1" in result.args + + def test_conversational_feedback(self): + """Build feedback should NOT be classified as a command.""" + c = build_build_classifier() + result = c.classify("I don't like the key vault config") + assert result.kind == IntentKind.CONVERSATIONAL + + def test_show_policy(self): + c = build_build_classifier() + result = c.classify("show policy status") + assert result.kind == IntentKind.COMMAND + assert result.command == "/policy" + + +# ====================================================================== +# TestBacklogIntents — backlog session factory +# ====================================================================== + + +class TestBacklogIntents: + """Tests for the backlog session classifier (keyword fallback path).""" + + def test_show_all_items(self): + c = build_backlog_classifier() + result = c.classify("show all items") + assert result.kind == IntentKind.COMMAND + assert result.command == "/list" + + def test_push_item(self): + c = build_backlog_classifier() + result = c.classify("push item 3") + assert result.kind == IntentKind.COMMAND + assert result.command == "/push" + assert "3" in result.args + + def test_add_story_is_conversational(self): + """'add a story for API rate limiting' should fall through to AI mutation.""" + c = build_backlog_classifier() + result = c.classify("add a story for API rate limiting") + assert result.kind == IntentKind.CONVERSATIONAL + + def test_show_item(self): + c = build_backlog_classifier() + result = c.classify("show me item 2") + assert result.kind == IntentKind.COMMAND + assert result.command == "/show" + assert "2" in result.args + + def test_remove_item(self): + c = build_backlog_classifier() + result = c.classify("remove item 5") + assert result.kind == IntentKind.COMMAND + assert result.command == "/remove" + assert "5" in result.args + + def test_save_backlog(self): + c = build_backlog_classifier() + result = c.classify("save the backlog") + assert result.kind == IntentKind.COMMAND + assert result.command == "/save" + + +# ====================================================================== +# TestFileReadDetection — cross-session file reading +# ====================================================================== + + +class TestFileReadDetection: + """Tests for the file-read regex detection.""" + + def test_read_artifacts_from_path(self): + c = IntentClassifier() + result = c.classify("Read artifacts from ~/docs/requirements") + assert result.kind == IntentKind.READ_FILES + assert "docs/requirements" in result.args + + def test_load_files_from_path(self): + c = IntentClassifier() + result = c.classify("Load files from /tmp/specs") + assert result.kind == IntentKind.READ_FILES + assert "/tmp/specs" in result.args + + def test_no_false_read(self): + """'I read a book yesterday' should NOT match.""" + c = IntentClassifier() + result = c.classify("I read a book yesterday") + assert result.kind == IntentKind.CONVERSATIONAL + + def test_import_documents(self): + c = IntentClassifier() + result = c.classify("import documents from ./specs") + assert result.kind == IntentKind.READ_FILES + assert "specs" in result.args + + +# ====================================================================== +# TestReadFilesForSession — file reading helper +# ====================================================================== + + +class TestReadFilesForSession: + """Tests for the read_files_for_session helper.""" + + def test_nonexistent_path(self, tmp_path): + output = [] + text, images = read_files_for_session( + str(tmp_path / "nonexistent"), + str(tmp_path), + output.append, + ) + assert text == "" + assert images == [] + assert any("not found" in o for o in output) + + def test_read_text_file(self, tmp_path): + (tmp_path / "hello.txt").write_text("Hello world", encoding="utf-8") + output = [] + text, images = read_files_for_session( + str(tmp_path / "hello.txt"), + str(tmp_path), + output.append, + ) + assert "Hello world" in text + assert images == [] + + def test_read_directory(self, tmp_path): + (tmp_path / "a.txt").write_text("File A", encoding="utf-8") + (tmp_path / "b.txt").write_text("File B", encoding="utf-8") + output = [] + text, images = read_files_for_session( + str(tmp_path), + str(tmp_path), + output.append, + ) + assert "File A" in text + assert "File B" in text + + def test_read_skips_hidden_files(self, tmp_path): + (tmp_path / ".hidden").write_text("secret", encoding="utf-8") + (tmp_path / "visible.txt").write_text("visible", encoding="utf-8") + output = [] + text, images = read_files_for_session( + str(tmp_path), + str(tmp_path), + output.append, + ) + assert "visible" in text + assert "secret" not in text + + def test_relative_path_resolution(self, tmp_path): + (tmp_path / "specs").mkdir() + (tmp_path / "specs" / "req.txt").write_text("Requirements", encoding="utf-8") + output = [] + text, images = read_files_for_session( + "specs", + str(tmp_path), + output.append, + ) + assert "Requirements" in text + + +# ====================================================================== +# TestAIClassificationPrompt — prompt construction +# ====================================================================== + + +class TestAIClassificationPrompt: + """Tests that the AI classification prompt is built correctly.""" + + def test_prompt_includes_commands(self): + c = IntentClassifier() + c.add_command_def(CommandDef("/open", "Show open items")) + c.add_command_def(CommandDef("/deploy", "Deploy stage", has_args=True, arg_description="N")) + prompt = c._build_classification_prompt() + assert "/open" in prompt + assert "Show open items" in prompt + assert "/deploy" in prompt + assert "" in prompt + + def test_prompt_includes_special_commands(self): + c = IntentClassifier() + c.add_command_def(CommandDef("/status", "Show status")) + prompt = c._build_classification_prompt() + assert "__prompt_context" in prompt + assert "__read_files" in prompt diff --git a/tests/test_telemetry.py b/tests/test_telemetry.py index c981185..6efc9c5 100644 --- a/tests/test_telemetry.py +++ b/tests/test_telemetry.py @@ -354,7 +354,7 @@ def test_reads_from_metadata(self): from azext_prototype.telemetry import _get_extension_version version = _get_extension_version() - assert version == "0.2.1b3" + assert version == "0.2.1b4" def test_returns_unknown_on_error(self): from azext_prototype.telemetry import _get_extension_version @@ -1199,7 +1199,7 @@ def test_command_count(self): if name.startswith("prototype_") and callable(getattr(custom_mod, name)) ] - assert len(command_functions) == 23 + assert len(command_functions) == 24 # ====================================================================== diff --git a/tests/test_token_tracker.py b/tests/test_token_tracker.py index ffc11e9..bd32029 100644 --- a/tests/test_token_tracker.py +++ b/tests/test_token_tracker.py @@ -230,6 +230,39 @@ def test_gpt4_small_window(self): assert pct is not None assert abs(pct - 50.0) < 0.1 # 4096 / 8192 = 50% + def test_claude_model_exact(self): + """Claude models should have known context windows.""" + t = TokenTracker() + t.record(AIResponse( + content="x", model="claude-sonnet-4", + usage={"prompt_tokens": 100_000, "completion_tokens": 0}, + )) + pct = t.budget_pct + assert pct is not None + assert abs(pct - 50.0) < 0.1 # 100000 / 200000 = 50% + + def test_claude_model_substring(self): + """Claude model names with suffixes should match via substring.""" + t = TokenTracker() + t.record(AIResponse( + content="x", model="claude-sonnet-4-20250514", + usage={"prompt_tokens": 50_000, "completion_tokens": 0}, + )) + pct = t.budget_pct + assert pct is not None + assert abs(pct - 25.0) < 0.1 # 50000 / 200000 = 25% + + def test_gemini_model(self): + """Gemini models should have known context windows.""" + t = TokenTracker() + t.record(AIResponse( + content="x", model="gemini-2.0-flash", + usage={"prompt_tokens": 524_288, "completion_tokens": 0}, + )) + pct = t.budget_pct + assert pct is not None + assert abs(pct - 50.0) < 0.1 # 524288 / 1048576 = 50% + # -------------------------------------------------------------------- # # Console.print_token_status — unit tests diff --git a/tests/test_tui_adapter.py b/tests/test_tui_adapter.py new file mode 100644 index 0000000..76da42e --- /dev/null +++ b/tests/test_tui_adapter.py @@ -0,0 +1,490 @@ +"""Threading and bridge tests for TUIAdapter. + +Verifies that the adapter correctly shuttles data between worker +threads (sessions) and the main Textual event loop (widgets). +""" + +from __future__ import annotations + +import threading + +import pytest + +from azext_prototype.ui.app import PrototypeApp +from azext_prototype.ui.task_model import TaskStatus +from azext_prototype.ui.tui_adapter import _strip_rich_markup, _RICH_TAG_RE + + +# -------------------------------------------------------------------- # +# Unit tests (no Textual) +# -------------------------------------------------------------------- # + + +class TestStripRichMarkup: + def test_strips_simple_tags(self): + assert _strip_rich_markup("[success]OK[/success]") == "OK" + + def test_strips_nested(self): + assert _strip_rich_markup("[bold][info]hello[/info][/bold]") == "hello" + + def test_leaves_plain_text(self): + assert _strip_rich_markup("no markup here") == "no markup here" + + def test_preserves_brackets_in_non_tag_context(self): + # e.g. list notation + assert _strip_rich_markup("list[0]") == "list[0]" + + def test_empty(self): + assert _strip_rich_markup("") == "" + + +# -------------------------------------------------------------------- # +# Integration tests with Textual pilot +# -------------------------------------------------------------------- # + + +@pytest.mark.asyncio +async def test_adapter_print_fn(): + """print_fn should route text to the ConsoleView.""" + app = PrototypeApp() + async with app.run_test() as pilot: + adapter = app.adapter + # Simulate a worker thread calling print_fn + done = threading.Event() + + def _worker(): + adapter.print_fn("Hello from worker") + done.set() + + t = threading.Thread(target=_worker) + t.start() + done.wait(timeout=5) + t.join(timeout=5) + # The message should have been routed through — no exception = success + + +@pytest.mark.asyncio +async def test_adapter_input_fn_and_submit(): + """input_fn should block until on_prompt_submitted is called.""" + app = PrototypeApp() + async with app.run_test() as pilot: + adapter = app.adapter + result = {} + + def _worker(): + result["value"] = adapter.input_fn("> ") + + t = threading.Thread(target=_worker) + t.start() + + # Give the worker thread time to block + await pilot.pause() + await pilot.pause() + + # Simulate user submitting input from the main thread + adapter.on_prompt_submitted("test response") + + t.join(timeout=5) + assert result.get("value") == "test response" + + +@pytest.mark.asyncio +async def test_adapter_status_fn(): + """status_fn should update the info bar assist text.""" + app = PrototypeApp() + async with app.run_test() as pilot: + adapter = app.adapter + + done = threading.Event() + + def _worker(): + adapter.status_fn("Building Stage 1...", "start") + adapter.status_fn("Building Stage 1...", "end") + done.set() + + t = threading.Thread(target=_worker) + t.start() + done.wait(timeout=5) + t.join(timeout=5) + + +@pytest.mark.asyncio +async def test_adapter_token_status(): + """print_token_status should update the info bar status.""" + app = PrototypeApp() + async with app.run_test() as pilot: + adapter = app.adapter + done = threading.Event() + + def _worker(): + adapter.print_token_status("1,200 tokens · 5,000 session") + done.set() + + t = threading.Thread(target=_worker) + t.start() + done.wait(timeout=5) + t.join(timeout=5) + + +@pytest.mark.asyncio +async def test_adapter_status_fn_timer_lifecycle(): + """status_fn start/end/tokens lifecycle should manage elapsed timer.""" + app = PrototypeApp() + async with app.run_test() as pilot: + adapter = app.adapter + + # Start the timer + start_done = threading.Event() + + def _start(): + adapter.status_fn("Analyzing your input...", "start") + start_done.set() + + t1 = threading.Thread(target=_start) + t1.start() + start_done.wait(timeout=5) + t1.join(timeout=5) + await pilot.pause() + await pilot.pause() + + assert adapter._timer_start is not None + assert adapter._timer_handle is not None + + # Stop the timer and replace with tokens + stop_done = threading.Event() + + def _stop(): + adapter.status_fn("Analyzing your input...", "end") + adapter.status_fn("1,200 tokens \u00b7 5,000 session", "tokens") + stop_done.set() + + t2 = threading.Thread(target=_stop) + t2.start() + stop_done.wait(timeout=5) + t2.join(timeout=5) + await pilot.pause() + await pilot.pause() + + assert adapter._timer_handle is None + + +@pytest.mark.asyncio +async def test_adapter_task_updates(): + """Task tree operations via adapter should work from threads.""" + app = PrototypeApp() + async with app.run_test() as pilot: + adapter = app.adapter + done = threading.Event() + + def _worker(): + adapter.update_task("init", TaskStatus.COMPLETED) + adapter.add_task("design", "design-d1", "Discovery") + adapter.update_task("design-d1", TaskStatus.IN_PROGRESS) + done.set() + + t = threading.Thread(target=_worker) + t.start() + done.wait(timeout=5) + t.join(timeout=5) + + # Let the main event loop process the queued callbacks + await pilot.pause() + await pilot.pause() + + # Verify state + assert app.task_tree.store.get("init").status == TaskStatus.COMPLETED + assert app.task_tree.store.get("design-d1") is not None + + +@pytest.mark.asyncio +async def test_adapter_clear_tasks(): + """clear_tasks should remove sub-tasks via worker thread.""" + app = PrototypeApp() + async with app.run_test() as pilot: + adapter = app.adapter + + # Add tasks from a worker thread + setup_done = threading.Event() + + def _setup(): + adapter.add_task("build", "build-s1", "Stage 1") + adapter.add_task("build", "build-s2", "Stage 2") + setup_done.set() + + t1 = threading.Thread(target=_setup) + t1.start() + setup_done.wait(timeout=5) + t1.join(timeout=5) + + await pilot.pause() + await pilot.pause() + assert app.task_tree.store.get("build-s1") is not None + + done = threading.Event() + + def _worker(): + adapter.clear_tasks("build") + done.set() + + t = threading.Thread(target=_worker) + t.start() + done.wait(timeout=5) + t.join(timeout=5) + + +@pytest.mark.asyncio +async def test_adapter_section_fn(): + """section_fn should add design sub-tasks with dedup and hierarchy.""" + app = PrototypeApp() + async with app.run_test() as pilot: + adapter = app.adapter + done = threading.Event() + + def _worker(): + adapter.section_fn([("Project Context & Scope", 2), ("Data & Content", 2)]) + # Call again with overlapping header — should dedup + adapter.section_fn([("Project Context & Scope", 2), ("Security", 2)]) + done.set() + + t = threading.Thread(target=_worker) + t.start() + done.wait(timeout=5) + t.join(timeout=5) + + # Let the main event loop process the queued callbacks + await pilot.pause() + await pilot.pause() + + # Verify: 3 unique sections (not 4) + assert app.task_tree.store.get("design-section-project-context-scope") is not None + assert app.task_tree.store.get("design-section-data-content") is not None + assert app.task_tree.store.get("design-section-security") is not None + + # Check labels + assert app.task_tree.store.get("design-section-project-context-scope").label == "Project Context & Scope" + assert app.task_tree.store.get("design-section-security").label == "Security" + + +@pytest.mark.asyncio +async def test_adapter_section_fn_hierarchy(): + """Level-3 headings should nest under the most recent level-2 section.""" + app = PrototypeApp() + async with app.run_test() as pilot: + adapter = app.adapter + done = threading.Event() + + def _worker(): + # First call: a level-2 parent with level-3 children + adapter.section_fn([ + ("Architecture", 2), + ("Compute", 3), + ("Networking", 3), + ]) + # Second call: new level-2, then level-3 under it + adapter.section_fn([ + ("Security", 2), + ("Authentication", 3), + ]) + done.set() + + t = threading.Thread(target=_worker) + t.start() + done.wait(timeout=5) + t.join(timeout=5) + + await pilot.pause() + await pilot.pause() + + # All nodes should exist in the store + assert app.task_tree.store.get("design-section-architecture") is not None + assert app.task_tree.store.get("design-section-compute") is not None + assert app.task_tree.store.get("design-section-networking") is not None + assert app.task_tree.store.get("design-section-security") is not None + assert app.task_tree.store.get("design-section-authentication") is not None + + # Level-3 nodes should be children of their level-2 parent + arch = app.task_tree.store.get("design-section-architecture") + child_ids = [c.id for c in arch.children] + assert "design-section-compute" in child_ids + assert "design-section-networking" in child_ids + + sec = app.task_tree.store.get("design-section-security") + sec_child_ids = [c.id for c in sec.children] + assert "design-section-authentication" in sec_child_ids + + +# -------------------------------------------------------------------- # +# print_fn markup preservation +# -------------------------------------------------------------------- # + + +@pytest.mark.asyncio +async def test_adapter_print_fn_preserves_markup(): + """print_fn should detect and preserve Rich markup tags.""" + app = PrototypeApp() + async with app.run_test() as pilot: + adapter = app.adapter + done = threading.Event() + + def _worker(): + adapter.print_fn("[success]✓[/success] All good") + adapter.print_fn("Plain text without markup") + done.set() + + t = threading.Thread(target=_worker) + t.start() + done.wait(timeout=5) + t.join(timeout=5) + # No exception = success (markup preserved for styled, plain for unstyled) + + +# -------------------------------------------------------------------- # +# response_fn +# -------------------------------------------------------------------- # + + +@pytest.mark.asyncio +async def test_adapter_response_fn_single_section(): + """response_fn with no headings should render without pagination.""" + app = PrototypeApp() + async with app.run_test() as pilot: + adapter = app.adapter + done = threading.Event() + + def _worker(): + adapter.response_fn("Just a simple response with no headings.") + done.set() + + t = threading.Thread(target=_worker) + t.start() + done.wait(timeout=5) + t.join(timeout=5) + # No exception = success + + +@pytest.mark.asyncio +async def test_adapter_on_prompt_submitted_empty_no_echo(): + """Empty submission (pagination) should not echo to console.""" + app = PrototypeApp() + async with app.run_test() as pilot: + adapter = app.adapter + # Submit empty string (simulating "Enter to continue") + adapter.on_prompt_submitted("") + # The input event should be set + assert adapter._input_event.is_set() + + +@pytest.mark.asyncio +async def test_adapter_status_fn_timer_start_cleared_after_end(): + """After 'end' event, _timer_start should be None.""" + app = PrototypeApp() + async with app.run_test() as pilot: + adapter = app.adapter + + # Start the timer + done1 = threading.Event() + + def _start(): + adapter.status_fn("Analyzing...", "start") + done1.set() + + t1 = threading.Thread(target=_start) + t1.start() + done1.wait(timeout=5) + t1.join(timeout=5) + await pilot.pause() + await pilot.pause() + + assert adapter._timer_start is not None + + # Stop the timer + done2 = threading.Event() + + def _stop(): + adapter.status_fn("Analyzing...", "end") + done2.set() + + t2 = threading.Thread(target=_stop) + t2.start() + done2.wait(timeout=5) + t2.join(timeout=5) + await pilot.pause() + await pilot.pause() + + # _timer_start should be cleared + assert adapter._timer_start is None + + +@pytest.mark.asyncio +async def test_adapter_timer_tick_after_cancel_is_noop(): + """_tick_timer() after _stop() should not overwrite info bar.""" + app = PrototypeApp() + async with app.run_test() as pilot: + adapter = app.adapter + + # Start then stop + done = threading.Event() + + def _lifecycle(): + adapter.status_fn("Thinking...", "start") + adapter.status_fn("Thinking...", "end") + adapter.status_fn("500 tokens · 500 session", "tokens") + done.set() + + t = threading.Thread(target=_lifecycle) + t.start() + done.wait(timeout=5) + t.join(timeout=5) + await pilot.pause() + await pilot.pause() + + # Now call _tick_timer on the main thread — should be a no-op + adapter._tick_timer() + await pilot.pause() + + # The status should still show the token text, not overwritten by timer + # (We can't easily read the status text, but verifying no exception + # and that _timer_start is None confirms the guard works) + assert adapter._timer_start is None + assert adapter._timer_handle is None + + +@pytest.mark.asyncio +async def test_adapter_section_fn_with_bold_headings(): + """section_fn should work when discovery extracts **bold** headings as tuples.""" + from azext_prototype.stages.discovery import extract_section_headers + + app = PrototypeApp() + async with app.run_test() as pilot: + adapter = app.adapter + + # Simulate what the discovery session does with bold headings + response = ( + "Let me explore your requirements.\n" + "\n" + "**Hosting & Deployment**\n" + "How do you plan to host this?\n" + "\n" + "**Data Layer**\n" + "What database will you use?" + ) + headers = extract_section_headers(response) + assert len(headers) >= 2 # sanity check + # Bold headings should be level 2 + assert all(level == 2 for _, level in headers) + + done = threading.Event() + + def _worker(): + adapter.section_fn(headers) + done.set() + + t = threading.Thread(target=_worker) + t.start() + done.wait(timeout=5) + t.join(timeout=5) + await pilot.pause() + await pilot.pause() + + assert app.task_tree.store.get("design-section-hosting-deployment") is not None + assert app.task_tree.store.get("design-section-data-layer") is not None diff --git a/tests/test_tui_widgets.py b/tests/test_tui_widgets.py new file mode 100644 index 0000000..40509cf --- /dev/null +++ b/tests/test_tui_widgets.py @@ -0,0 +1,284 @@ +"""Widget isolation tests for the Textual TUI dashboard. + +Uses Textual's pilot test harness to mount individual widgets and +the full PrototypeApp in a headless terminal. +""" + +from __future__ import annotations + +import pytest + +from azext_prototype.ui.app import PrototypeApp +from azext_prototype.ui.task_model import TaskItem, TaskStatus, TaskStore +from azext_prototype.ui.widgets.console_view import ConsoleView +from azext_prototype.ui.widgets.info_bar import InfoBar +from azext_prototype.ui.widgets.prompt_input import PromptInput +from azext_prototype.ui.widgets.task_tree import TaskTree + + +# -------------------------------------------------------------------- # +# TaskStore unit tests (no Textual needed) +# -------------------------------------------------------------------- # + + +class TestTaskStore: + def test_roots_initialized(self): + store = TaskStore() + roots = store.roots + assert len(roots) == 4 + assert [r.id for r in roots] == ["init", "design", "build", "deploy"] + + def test_update_status(self): + store = TaskStore() + item = store.update_status("init", TaskStatus.COMPLETED) + assert item is not None + assert item.status == TaskStatus.COMPLETED + + def test_update_nonexistent(self): + store = TaskStore() + assert store.update_status("nope", TaskStatus.COMPLETED) is None + + def test_add_child(self): + store = TaskStore() + child = TaskItem(id="design-req1", label="Gather requirements") + assert store.add_child("design", child) is True + assert len(store.get("design").children) == 1 + assert store.get("design-req1") is child + + def test_add_child_invalid_parent(self): + store = TaskStore() + child = TaskItem(id="orphan", label="Orphan") + assert store.add_child("nonexistent", child) is False + + def test_remove(self): + store = TaskStore() + child = TaskItem(id="build-stage1", label="Stage 1") + store.add_child("build", child) + assert store.remove("build-stage1") is True + assert store.get("build-stage1") is None + assert len(store.get("build").children) == 0 + + def test_clear_children(self): + store = TaskStore() + store.add_child("deploy", TaskItem(id="d1", label="Stage 1")) + store.add_child("deploy", TaskItem(id="d2", label="Stage 2")) + assert len(store.get("deploy").children) == 2 + store.clear_children("deploy") + assert len(store.get("deploy").children) == 0 + assert store.get("d1") is None + + def test_display(self): + item = TaskItem(id="t", label="Test", status=TaskStatus.COMPLETED) + assert "\u2713" in item.display # checkmark + assert "Test" in item.display + + +# -------------------------------------------------------------------- # +# TaskItem unit tests +# -------------------------------------------------------------------- # + + +class TestTaskItem: + def test_symbols(self): + assert TaskItem(id="a", label="a", status=TaskStatus.PENDING).symbol == "\u25cb" + assert TaskItem(id="b", label="b", status=TaskStatus.IN_PROGRESS).symbol == "\u25cf" + assert TaskItem(id="c", label="c", status=TaskStatus.COMPLETED).symbol == "\u2713" + assert TaskItem(id="d", label="d", status=TaskStatus.FAILED).symbol == "\u2717" + + +# -------------------------------------------------------------------- # +# Textual pilot tests +# -------------------------------------------------------------------- # + + +@pytest.mark.asyncio +async def test_app_mounts(): + """The app should mount all four panels without errors.""" + app = PrototypeApp() + async with app.run_test() as pilot: + # All four widget types should be queryable + assert app.query_one("#console-view", ConsoleView) + assert app.query_one("#task-tree", TaskTree) + assert app.query_one("#prompt-input", PromptInput) + assert app.query_one("#info-bar", InfoBar) + + +@pytest.mark.asyncio +async def test_console_view_write_text(): + """ConsoleView should accept text writes.""" + app = PrototypeApp() + async with app.run_test() as pilot: + cv = app.console_view + cv.write_text("Hello, TUI!") + cv.write_success("It worked") + cv.write_error("Something failed") + cv.write_warning("Watch out") + cv.write_info("FYI") + cv.write_header("Section") + cv.write_dim("Quiet text") + # No exception raised = success + + +@pytest.mark.asyncio +async def test_console_view_agent_response(): + """ConsoleView should render markdown agent responses.""" + app = PrototypeApp() + async with app.run_test() as pilot: + app.console_view.write_agent_response("# Hello\n\nThis is **bold**.") + + +@pytest.mark.asyncio +async def test_task_tree_roots(): + """TaskTree should show 4 root nodes on mount.""" + app = PrototypeApp() + async with app.run_test() as pilot: + tree = app.task_tree + # Root should have 4 children (Init, Design, Build, Deploy) + assert len(tree.root.children) == 4 + + +@pytest.mark.asyncio +async def test_task_tree_update(): + """TaskTree should update status labels.""" + app = PrototypeApp() + async with app.run_test() as pilot: + tree = app.task_tree + tree.update_task("init", TaskStatus.COMPLETED) + item = tree.store.get("init") + assert item.status == TaskStatus.COMPLETED + + +@pytest.mark.asyncio +async def test_task_tree_add_child(): + """TaskTree should add and display sub-tasks.""" + app = PrototypeApp() + async with app.run_test() as pilot: + tree = app.task_tree + child = TaskItem(id="design-discovery", label="Discovery conversation") + tree.add_task("design", child) + assert tree.store.get("design-discovery") is not None + # Node should be in the map + assert "design-discovery" in tree._node_map + + +@pytest.mark.asyncio +async def test_task_tree_add_section(): + """TaskTree.add_section() should create an expandable node that accepts children.""" + app = PrototypeApp() + async with app.run_test() as pilot: + tree = app.task_tree + section = TaskItem(id="design-section-arch", label="Architecture") + tree.add_section("design", section) + assert tree.store.get("design-section-arch") is not None + assert "design-section-arch" in tree._node_map + # The section node should be expandable (not a leaf) + node = tree._node_map["design-section-arch"] + assert node.allow_expand is True + + # Now add a child under the section + child = TaskItem(id="design-section-compute", label="Compute") + tree.add_task("design-section-arch", child) + assert tree.store.get("design-section-compute") is not None + + +@pytest.mark.asyncio +async def test_info_bar_updates(): + """InfoBar should update assist and status text.""" + app = PrototypeApp() + async with app.run_test() as pilot: + app.info_bar.update_assist("Press Enter to continue") + app.info_bar.update_status("1,200 tokens") + # No exception = success + + +@pytest.mark.asyncio +async def test_prompt_input_disable(): + """PromptInput should be disabled by default.""" + app = PrototypeApp() + async with app.run_test() as pilot: + prompt = app.prompt_input + assert prompt._enabled is False + assert prompt.read_only is True + + +@pytest.mark.asyncio +async def test_prompt_input_enable(): + """PromptInput should allow enabling for input.""" + app = PrototypeApp() + async with app.run_test() as pilot: + prompt = app.prompt_input + prompt.enable() + assert prompt._enabled is True + assert prompt.read_only is False + + +@pytest.mark.asyncio +async def test_file_list(): + """ConsoleView should render file lists.""" + app = PrototypeApp() + async with app.run_test() as pilot: + app.console_view.write_file_list(["main.tf", "variables.tf"], success=True) + app.console_view.write_file_list(["broken.tf"], success=False) + + +# -------------------------------------------------------------------- # +# ConsoleView.write_markup tests +# -------------------------------------------------------------------- # + + +@pytest.mark.asyncio +async def test_console_view_write_markup(): + """write_markup should accept Rich markup without error.""" + app = PrototypeApp() + async with app.run_test() as pilot: + app.console_view.write_markup("[success]✓[/success] All good") + app.console_view.write_markup("[info]→[/info] Starting session") + # No exception = success + + +@pytest.mark.asyncio +async def test_console_view_write_markup_invalid_falls_back(): + """write_markup with invalid markup should fall back to plain text.""" + app = PrototypeApp() + async with app.run_test() as pilot: + # This has an unclosed tag — should not raise + app.console_view.write_markup("[invalid_tag_that_wont_parse") + + +# -------------------------------------------------------------------- # +# PromptInput allow_empty tests +# -------------------------------------------------------------------- # + + +@pytest.mark.asyncio +async def test_prompt_input_allow_empty(): + """PromptInput with allow_empty=True should submit empty string.""" + app = PrototypeApp() + async with app.run_test() as pilot: + prompt = app.prompt_input + prompt.enable(allow_empty=True) + assert prompt._allow_empty is True + assert prompt._enabled is True + + +@pytest.mark.asyncio +async def test_prompt_input_default_no_allow_empty(): + """PromptInput defaults to allow_empty=False.""" + app = PrototypeApp() + async with app.run_test() as pilot: + prompt = app.prompt_input + prompt.enable() + assert prompt._allow_empty is False + + +@pytest.mark.asyncio +async def test_prompt_input_input_mode(): + """In input mode (default), text has '> ' prefix and placeholder is empty.""" + app = PrototypeApp() + async with app.run_test() as pilot: + prompt = app.prompt_input + prompt.enable() + assert prompt._allow_empty is False + assert prompt._enabled is True + assert prompt.text == "> " + assert prompt.placeholder == ""