diff --git a/.github/workflows/rust.yml b/.github/workflows/rust.yml index 6cbd4d68f..ba0cd04f2 100644 --- a/.github/workflows/rust.yml +++ b/.github/workflows/rust.yml @@ -90,8 +90,8 @@ jobs: - name: Run tests run: cargo test --locked --workspace --lib --bins --test '*' --exclude fig_desktop-fuzz - cargo-clippy-windows-chat-cli: - name: Clippy Windows (chat_cli) + cargo-clippy-windows: + name: Clippy Windows runs-on: windows-latest timeout-minutes: 60 steps: @@ -107,11 +107,11 @@ jobs: ~/.cargo/registry/cache/ ~/.cargo/git/db/ target/ - key: cargo-clippy-windows-chat-cli-${{ hashFiles('**/Cargo.lock') }}-${{ steps.toolchain.outputs.cachekey }} - - run: cargo clippy --locked -p chat_cli --color always -- -D warnings + key: cargo-clippy-windows-${{ hashFiles('**/Cargo.lock') }}-${{ steps.toolchain.outputs.cachekey }} + - run: cargo clippy --locked -p q_cli --color always -- -D warnings - cargo-test-windows-chat-cli: - name: Test Windows (chat_cli) + cargo-test-windows: + name: Test Windows runs-on: windows-latest timeout-minutes: 60 steps: @@ -127,9 +127,9 @@ jobs: ~/.cargo/registry/cache/ ~/.cargo/git/db/ target/ - key: cargo-test-windows-chat-cli-${{ hashFiles('**/Cargo.lock') }}-${{ steps.toolchain.outputs.cachekey }} + key: cargo-test-windows-${{ hashFiles('**/Cargo.lock') }}-${{ steps.toolchain.outputs.cachekey }} - name: Run tests - run: cargo test --locked -p chat_cli + run: cargo test --locked -p q_cli cargo-fmt: name: Fmt diff --git a/Cargo.lock b/Cargo.lock index c7df44821..101afc77b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1144,12 +1144,6 @@ dependencies = [ "vsimd", ] -[[package]] -name = "beef" -version = "0.5.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3a8241f3ebb85c056b509d4327ad0358fbbba6ffb340bf388f26350aeda225b1" - [[package]] name = "bincode" version = "1.3.3" @@ -1556,9 +1550,9 @@ dependencies = [ [[package]] name = "cc" -version = "1.2.43" +version = "1.2.44" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "739eb0f94557554b3ca9a86d2d37bebd49c5e6d0c1d2bda35ba5bdac830befc2" +checksum = "37521ac7aabe3d13122dc382493e20c9416f299d2ccd5b3a5340a2570cdeb0f3" dependencies = [ "find-msvc-tools", "jobserver", @@ -1614,125 +1608,6 @@ version = "0.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "613afe47fcd5fac7ccf1db93babcb082c5994d996f20b8b159f2ad1658eb5724" -[[package]] -name = "chat_cli" -version = "1.19.3" -dependencies = [ - "amzn-codewhisperer-client", - "amzn-codewhisperer-streaming-client", - "amzn-consolas-client", - "amzn-qdeveloper-streaming-client", - "amzn-toolkit-telemetry-client", - "anstream", - "arboard", - "assert_cmd", - "async-trait", - "aws-config", - "aws-credential-types", - "aws-runtime", - "aws-sdk-cognitoidentity", - "aws-sdk-ssooidc", - "aws-smithy-async", - "aws-smithy-runtime-api", - "aws-smithy-types", - "aws-types", - "base64 0.22.1", - "bitflags 2.10.0", - "bstr", - "bytes", - "camino", - "cfg-if", - "clap", - "clap_complete", - "clap_complete_fig", - "color-eyre", - "color-print", - "convert_case 0.8.0", - "cookie", - "criterion", - "crossterm", - "ctrlc", - "dialoguer", - "dirs 5.0.1", - "eyre", - "fd-lock", - "futures", - "glob", - "globset", - "hex", - "http 1.3.1", - "http-body-util", - "hyper 1.7.0", - "hyper-util", - "indicatif", - "indoc", - "insta", - "libc", - "mimalloc", - "mockito", - "nix 0.29.0", - "objc2 0.5.2", - "objc2-app-kit 0.2.2", - "objc2-foundation 0.2.2", - "owo-colors", - "parking_lot", - "paste", - "percent-encoding", - "predicates", - "prettyplease", - "quote", - "r2d2", - "r2d2_sqlite", - "rand 0.9.2", - "regex", - "reqwest", - "ring", - "rusqlite", - "rustls 0.23.34", - "rustls-native-certs 0.8.2", - "rustls-pemfile 2.2.0", - "rustyline", - "security-framework 3.5.1", - "semantic_search_client", - "semver", - "serde", - "serde_json", - "sha2", - "shell-color", - "shell-words", - "shellexpand", - "shlex", - "similar", - "skim", - "spinners", - "strip-ansi-escapes", - "strum 0.27.2", - "syn 2.0.108", - "syntect", - "sysinfo", - "tempfile", - "thiserror 2.0.17", - "time", - "tokio", - "tokio-tungstenite", - "tokio-util", - "toml", - "tracing", - "tracing-appender", - "tracing-subscriber", - "tracing-test", - "typed-path", - "unicode-width 0.2.2", - "url", - "uuid", - "walkdir", - "webpki-roots 0.26.8", - "whoami", - "windows 0.61.3", - "winnow 0.6.2", - "winreg 0.55.0", -] - [[package]] name = "chrono" version = "0.4.42" @@ -1845,10 +1720,8 @@ version = "4.5.49" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2a0b5487afeab2deb2ff4e03a807ad1a03ac532ff5a2cee5d86884440c7f7671" dependencies = [ - "anstyle", "heck 0.5.0", "proc-macro2", - "pulldown-cmark", "quote", "syn 2.0.108", ] @@ -2215,19 +2088,6 @@ dependencies = [ "itertools 0.10.5", ] -[[package]] -name = "crossbeam" -version = "0.8.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1137cd7e7fc0fb5d3c5a8678be38ec56e819125d8d7907411fe24ccb943faca8" -dependencies = [ - "crossbeam-channel", - "crossbeam-deque", - "crossbeam-epoch", - "crossbeam-queue", - "crossbeam-utils", -] - [[package]] name = "crossbeam-channel" version = "0.5.15" @@ -2256,15 +2116,6 @@ dependencies = [ "crossbeam-utils", ] -[[package]] -name = "crossbeam-queue" -version = "0.3.12" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0f58bbc28f91df819d0aa2a2c00cd19754769c2fad90579b3592b1c9ba7a3115" -dependencies = [ - "crossbeam-utils", -] - [[package]] name = "crossbeam-utils" version = "0.8.21" @@ -2433,16 +2284,6 @@ dependencies = [ "zbus 4.4.0", ] -[[package]] -name = "defer-drop" -version = "1.3.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f613ec9fa66a6b28cdb1842b27f9adf24f39f9afc4dcdd9fdecee4aca7945c57" -dependencies = [ - "crossbeam-channel", - "once_cell", -] - [[package]] name = "deranged" version = "0.4.0" @@ -2569,16 +2410,6 @@ dependencies = [ "dirs-sys 0.5.0", ] -[[package]] -name = "dirs-next" -version = "2.0.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b98cf8ebf19c3d1b223e151f99a4f9f0690dca41414773390fc824184ac833e1" -dependencies = [ - "cfg-if", - "dirs-sys-next", -] - [[package]] name = "dirs-sys" version = "0.4.1" @@ -2603,17 +2434,6 @@ dependencies = [ "windows-sys 0.61.2", ] -[[package]] -name = "dirs-sys-next" -version = "0.1.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4ebda144c4fe02d1f7ea1a7d9641b6fc6b580adcfa024ae48797ecdeb6825b4d" -dependencies = [ - "libc", - "redox_users 0.4.6", - "winapi", -] - [[package]] name = "dispatch" version = "0.2.0" @@ -2837,12 +2657,6 @@ dependencies = [ "regex", ] -[[package]] -name = "env_home" -version = "0.1.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c7f84e12ccf0a7ddc17a6c41c93326024c42920d7ee630d04950e6926645c0fe" - [[package]] name = "env_logger" version = "0.10.2" @@ -3072,7 +2886,7 @@ dependencies = [ "tracing", "tracing-subscriber", "url", - "which 6.0.3", + "which", ] [[package]] @@ -3195,7 +3009,7 @@ dependencies = [ "tray-icon", "url", "uuid", - "which 6.0.3", + "which", "whoami", "windows 0.58.0", "wry", @@ -3248,7 +3062,7 @@ dependencies = [ "tracing", "tracing-subscriber", "uuid", - "which 6.0.3", + "which", "whoami", ] @@ -3671,7 +3485,7 @@ dependencies = [ "tracing", "tracing-subscriber", "uuid", - "which 6.0.3", + "which", "winapi", "windows 0.58.0", "winreg 0.55.0", @@ -6100,17 +5914,6 @@ dependencies = [ "smallvec", ] -[[package]] -name = "nix" -version = "0.24.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fa52e972a9a719cecb6864fb88568781eb706bac2cd1d4f04a648542dbf78069" -dependencies = [ - "bitflags 1.3.2", - "cfg-if", - "libc", -] - [[package]] name = "nix" version = "0.25.1" @@ -7648,17 +7451,6 @@ dependencies = [ "psl-types", ] -[[package]] -name = "pulldown-cmark" -version = "0.13.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1e8bbe1a966bd2f362681a44f6edce3c2310ac21e4d5067a6e7ec396297a6ea0" -dependencies = [ - "bitflags 2.10.0", - "memchr", - "unicase", -] - [[package]] name = "pulp" version = "0.18.22" @@ -7779,7 +7571,7 @@ dependencies = [ "url", "uuid", "walkdir", - "which 6.0.3", + "which", "whoami", "winapi", "windows 0.58.0", @@ -8572,39 +8364,6 @@ version = "1.0.22" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b39cdef0fa800fc44525c84ccb54a029961a8215f9619753635a9c0d2538d46d" -[[package]] -name = "rustyline" -version = "15.0.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2ee1e066dc922e513bda599c6ccb5f3bb2b0ea5870a579448f2622993f0a9a2f" -dependencies = [ - "bitflags 2.10.0", - "cfg-if", - "clipboard-win", - "fd-lock", - "libc", - "log", - "memchr", - "nix 0.29.0", - "radix_trie", - "rustyline-derive", - "unicode-segmentation", - "unicode-width 0.2.2", - "utf8parse", - "windows-sys 0.59.0", -] - -[[package]] -name = "rustyline-derive" -version = "0.11.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5d66de233f908aebf9cc30ac75ef9103185b4b715c6f2fb7a626aa5e5ede53ab" -dependencies = [ - "proc-macro2", - "quote", - "syn 2.0.108", -] - [[package]] name = "ryu" version = "1.0.20" @@ -8988,15 +8747,6 @@ dependencies = [ "nu-color-config", ] -[[package]] -name = "shell-quote" -version = "0.7.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fb502615975ae2365825521fa1529ca7648fd03ce0b0746604e0683856ecd7e4" -dependencies = [ - "bstr", -] - [[package]] name = "shell-words" version = "1.1.0" @@ -9081,37 +8831,6 @@ version = "1.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "56199f7ddabf13fe5074ce809e7d3f42b42ae711800501b5b16ea82ad029c39d" -[[package]] -name = "skim" -version = "0.16.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8e29ac781a242d0cc04f1bbf5cc3522ef2205b778fdc8c2a6f293eef34467968" -dependencies = [ - "beef", - "bitflags 1.3.2", - "chrono", - "clap", - "crossbeam", - "defer-drop", - "derive_builder", - "env_logger 0.11.8", - "fuzzy-matcher", - "indexmap 2.11.4", - "log", - "nix 0.29.0", - "rand 0.9.2", - "rayon", - "regex", - "shell-quote", - "shlex", - "time", - "timer", - "tuikit", - "unicode-width 0.2.2", - "vte 0.15.0", - "which 7.0.3", -] - [[package]] name = "slab" version = "0.4.11" @@ -9405,27 +9124,6 @@ dependencies = [ "syn 2.0.108", ] -[[package]] -name = "syntect" -version = "5.3.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "656b45c05d95a5704399aeef6bd0ddec7b2b3531b7c9e900abbf7c4d2190c925" -dependencies = [ - "bincode", - "flate2", - "fnv", - "once_cell", - "onig", - "plist", - "regex-syntax", - "serde", - "serde_derive", - "serde_json", - "thiserror 2.0.17", - "walkdir", - "yaml-rust", -] - [[package]] name = "sys-locale" version = "0.3.2" @@ -9608,17 +9306,6 @@ dependencies = [ "utf-8", ] -[[package]] -name = "term" -version = "0.7.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c59df8ac95d96ff9bede18eb7300b0fda5e5d8d90960e76f8e14ae765eedbf1f" -dependencies = [ - "dirs-next", - "rustversion", - "winapi", -] - [[package]] name = "termcolor" version = "1.4.1" @@ -9783,15 +9470,6 @@ dependencies = [ "time-core", ] -[[package]] -name = "timer" -version = "0.2.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "31d42176308937165701f50638db1c31586f183f1aab416268216577aec7306b" -dependencies = [ - "chrono", -] - [[package]] name = "tiny-keccak" version = "2.0.2" @@ -10255,20 +9933,6 @@ version = "0.2.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e421abadd41a4225275504ea4d6566923418b7f05506fbc9c0fe86ba7396114b" -[[package]] -name = "tuikit" -version = "0.5.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5e19c6ab038babee3d50c8c12ff8b910bdb2196f62278776422f50390d8e53d8" -dependencies = [ - "bitflags 1.3.2", - "lazy_static", - "log", - "nix 0.24.3", - "term", - "unicode-width 0.1.14", -] - [[package]] name = "tungstenite" version = "0.26.2" @@ -10286,12 +9950,6 @@ dependencies = [ "utf-8", ] -[[package]] -name = "typed-path" -version = "0.11.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c462d18470a2857aa657d338af5fa67170bb48bcc80a296710ce3b0802a32566" - [[package]] name = "typeid" version = "1.0.3" @@ -10522,9 +10180,9 @@ checksum = "accd4ea62f7bb7a82fe23066fb0957d48ef677f6eeb8215f372f52e48bb32426" [[package]] name = "version-compare" -version = "0.2.0" +version = "0.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "852e951cb7832cb45cb1169900d19760cfa39b82bc0ea9c0e5a14ae88411c98b" +checksum = "03c2856837ef78f57382f06b2b8563a2f512f7185d732608fd9176cb3b8edf0e" [[package]] name = "version_check" @@ -10902,18 +10560,6 @@ dependencies = [ "winsafe", ] -[[package]] -name = "which" -version = "7.0.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "24d643ce3fd3e5b54854602a080f34fb10ab75e0b813ee32d00ca2b44fa74762" -dependencies = [ - "either", - "env_home", - "rustix 1.1.2", - "winsafe", -] - [[package]] name = "whichlang" version = "0.1.1" @@ -11007,28 +10653,6 @@ dependencies = [ "windows-targets 0.52.6", ] -[[package]] -name = "windows" -version = "0.61.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9babd3a767a4c1aef6900409f85f5d53ce2544ccdfaa86dad48c91782c6d6893" -dependencies = [ - "windows-collections", - "windows-core 0.61.2", - "windows-future", - "windows-link 0.1.3", - "windows-numerics", -] - -[[package]] -name = "windows-collections" -version = "0.2.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3beeceb5e5cfd9eb1d76b381630e82c4241ccd0d27f1a39ed41b2760b255c5e8" -dependencies = [ - "windows-core 0.61.2", -] - [[package]] name = "windows-core" version = "0.56.0" @@ -11079,17 +10703,6 @@ dependencies = [ "windows-strings 0.4.2", ] -[[package]] -name = "windows-future" -version = "0.2.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fc6a41e98427b19fe4b73c550f060b59fa592d7d686537eebf9385621bfbad8e" -dependencies = [ - "windows-core 0.61.2", - "windows-link 0.1.3", - "windows-threading", -] - [[package]] name = "windows-implement" version = "0.56.0" @@ -11190,16 +10803,6 @@ version = "0.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f0805222e57f7521d6a62e36fa9163bc891acd422f971defe97d64e70d0a4fe5" -[[package]] -name = "windows-numerics" -version = "0.2.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9150af68066c4c5c07ddc0ce30421554771e528bde427614c61038bc2c92c2b1" -dependencies = [ - "windows-core 0.61.2", - "windows-link 0.1.3", -] - [[package]] name = "windows-registry" version = "0.5.3" @@ -11374,15 +10977,6 @@ dependencies = [ "windows_x86_64_msvc 0.53.1", ] -[[package]] -name = "windows-threading" -version = "0.1.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b66463ad2e0ea3bbf808b7f1d371311c80e115c0b71d60efc142cafbcfb057a6" -dependencies = [ - "windows-link 0.1.3", -] - [[package]] name = "windows-version" version = "0.1.7" @@ -11581,15 +11175,6 @@ dependencies = [ "memchr", ] -[[package]] -name = "winnow" -version = "0.6.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7a4191c47f15cc3ec71fcb4913cb83d58def65dd3787610213c649283b5ce178" -dependencies = [ - "memchr", -] - [[package]] name = "winnow" version = "0.7.13" @@ -11774,15 +11359,6 @@ version = "0.13.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "66fee0b777b0f5ac1c69bb06d361268faafa61cd4682ae064a171c16c433e9e4" -[[package]] -name = "yaml-rust" -version = "0.4.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "56c1936c4cc7a1c9ab21a1ebb602eb942ba868cbd44a99cb7cdc5892335e1c85" -dependencies = [ - "linked-hash-map", -] - [[package]] name = "yansi" version = "1.0.1" diff --git a/Cargo.toml b/Cargo.toml index 4eb119ed6..3b2a91673 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -9,7 +9,7 @@ members = [ "tests/fig-api/fig-api-mock", "tests/figterm2", ] -default-members = ["crates/chat-cli"] +default-members = ["crates/q_cli"] [workspace.package] authors = [ diff --git a/README.md b/README.md index 8e0e5634c..58a0cf01a 100644 --- a/README.md +++ b/README.md @@ -155,14 +155,14 @@ pnpm install --ignore-scripts ### 3. Start Local Development To compile and view changes made to `q chat`: ```shell -cargo run --bin chat_cli +cargo run --bin q_cli ``` -> If you are working on other q commands, just append `-- `. For example, to run `q login`, you can run `cargo run --bin chat_cli -- login` +> If you are working on other q commands, just append `-- `. For example, to run `q login`, you can run `cargo run --bin q_cli -- login` To run tests for the Q CLI crate: ```shell -cargo test -p chat_cli +cargo test -p q_cli ``` To format Rust files: diff --git a/build-scripts/qchatbuild.py b/build-scripts/qchatbuild.py deleted file mode 100644 index bae469b49..000000000 --- a/build-scripts/qchatbuild.py +++ /dev/null @@ -1,587 +0,0 @@ -import base64 -from dataclasses import dataclass -import json -import pathlib -from functools import cache -import os -import shutil -import time -from typing import Any, Mapping, Sequence, List, Optional -from build import generate_sha -from const import APPLE_TEAM_ID, CHAT_BINARY_NAME, CHAT_PACKAGE_NAME -from util import debug, info, isDarwin, isLinux, run_cmd, run_cmd_output, warn -from rust import cargo_cmd_name, rust_env, rust_targets -from importlib import import_module - -Args = Sequence[str | os.PathLike] -Env = Mapping[str, str | os.PathLike] -Cwd = str | os.PathLike - -BUILD_DIR_RELATIVE = pathlib.Path(os.environ.get("BUILD_DIR") or "build") -BUILD_DIR = BUILD_DIR_RELATIVE.absolute() - -CD_SIGNER_REGION = "us-west-2" -SIGNING_API_BASE_URL = "https://api.signer.builder-tools.aws.dev" - - -@dataclass -class CdSigningData: - bucket_name: str - """The bucket hosting signing artifacts accessible by CD Signer.""" - apple_notarizing_secret_arn: str - """The ARN of the secret containing the Apple ID and password, used during notarization""" - signing_role_arn: str - """The ARN of the role used by CD Signer""" - - -@dataclass -class MacOSBuildOutput: - chat_path: pathlib.Path - """The path to the chat binary""" - chat_zip_path: pathlib.Path - """The path to the chat binary zipped""" - - -def run_cargo_tests(): - args = [cargo_cmd_name()] - args.extend(["test", "--locked", "--package", CHAT_PACKAGE_NAME]) - run_cmd( - args, - env={ - **os.environ, - **rust_env(release=False), - }, - ) - - -def run_clippy(): - args = [cargo_cmd_name(), "clippy", "--locked", "--package", CHAT_PACKAGE_NAME] - run_cmd( - args, - env={ - **os.environ, - **rust_env(release=False), - }, - ) - - -def build_chat_bin( - release: bool, - output_name: str | None = None, - targets: Sequence[str] = [], -): - package = CHAT_PACKAGE_NAME - - args = [cargo_cmd_name(), "build", "--locked", "--package", package] - - for target in targets: - args.extend(["--target", target]) - - if release: - args.append("--release") - target_subdir = "release" - else: - target_subdir = "debug" - - run_cmd( - args, - env={ - **os.environ, - **rust_env(release=release), - }, - ) - - # create "universal" binary for macos - if isDarwin(): - out_path = BUILD_DIR / f"{output_name or package}-universal-apple-darwin" - args = [ - "lipo", - "-create", - "-output", - out_path, - ] - for target in targets: - args.append(pathlib.Path("target") / target / target_subdir / package) - run_cmd(args) - return out_path - else: - # linux does not cross compile arch - target = targets[0] - target_path = pathlib.Path("target") / target / target_subdir / package - out_path = BUILD_DIR / "bin" / f"{(output_name or package)}-{target}" - out_path.parent.mkdir(parents=True, exist_ok=True) - shutil.copy2(target_path, out_path) - return out_path - - -@cache -def get_creds(): - boto3 = import_module("boto3") - session = boto3.Session() - credentials = session.get_credentials() - creds = credentials.get_frozen_credentials() - return creds - - -def cd_signer_request(method: str, path: str, data: str | None = None): - """ - Sends a request to the CD Signer API. - """ - SigV4Auth = import_module("botocore.auth").SigV4Auth - AWSRequest = import_module("botocore.awsrequest").AWSRequest - requests = import_module("requests") - - url = f"{SIGNING_API_BASE_URL}{path}" - headers = {"Content-Type": "application/json"} - request = AWSRequest(method=method, url=url, data=data, headers=headers) - SigV4Auth(get_creds(), "signer-builder-tools", CD_SIGNER_REGION).add_auth(request) - - for i in range(1, 8): - debug(f"Sending request {method} to {url} with data: {data}") - response = requests.request(method=method, url=url, headers=dict(request.headers), data=data) - info(f"CDSigner Request ({url}): {response.status_code}") - if response.status_code == 429: - warn(f"Too many requests, backing off for {2**i} seconds") - time.sleep(2**i) - continue - return response - - raise Exception(f"Failed to request {url}") - - -def cd_signer_create_request(manifest: Any) -> str: - """ - Sends a POST request to create a new signing request. After creation, we - need to send another request to start it. - """ - response = cd_signer_request( - method="POST", - path="/signing_requests", - data=json.dumps({"manifest": manifest}), - ) - response_json = response.json() - info(f"Signing request create: {response_json}") - request_id = response_json["signingRequestId"] - return request_id - - -def cd_signer_start_request(request_id: str, source_key: str, destination_key: str, signing_data: CdSigningData): - """ - Sends a POST request to start the signing process. - """ - response_text = cd_signer_request( - method="POST", - path=f"/signing_requests/{request_id}/start", - data=json.dumps( - { - "iamRole": f"{signing_data.signing_role_arn}", - "s3Location": { - "bucket": signing_data.bucket_name, - "sourceKey": source_key, - "destinationKey": destination_key, - }, - } - ), - ).text - info(f"Signing request start: {response_text}") - - -def cd_signer_status_request(request_id: str): - response_json = cd_signer_request( - method="GET", - path=f"/signing_requests/{request_id}", - ).json() - info(f"Signing request status: {response_json}") - return response_json["signingRequest"]["status"] - - -def cd_build_signed_package(exe_path: pathlib.Path): - """ - Creates a tarball `package.tar.gz` with the following structure: - ``` - package - ├─ EXECUTABLES_TO_SIGN - | ├─ {chat_binary} - ``` - """ - # Trying a different format without manifest.yaml and placing EXECUTABLES_TO_SIGN - # at the root. - # The docs contain conflicting information, idk what to even do here - working_dir = BUILD_DIR / "package" - shutil.rmtree(working_dir, ignore_errors=True) - (BUILD_DIR / "package" / "EXECUTABLES_TO_SIGN").mkdir(parents=True) - - shutil.copy2(exe_path, working_dir / "EXECUTABLES_TO_SIGN" / exe_path.name) - exe_path.unlink() - - run_cmd(["gtar", "-czf", "artifact.gz", "EXECUTABLES_TO_SIGN"], cwd=working_dir) - run_cmd( - ["gtar", "-czf", BUILD_DIR / "package.tar.gz", "artifact.gz"], - cwd=working_dir, - ) - - return BUILD_DIR / "package.tar.gz" - - -def manifest( - identifier: str, -): - """ - Returns the manifest arguments required when creating a new CD Signer request. - """ - return { - "type": "app", - "os": "osx", - "name": "EXECUTABLES_TO_SIGN", - "outputs": [{"label": "macos", "path": "EXECUTABLES_TO_SIGN"}], - "app": { - "identifier": identifier, - "signing_requirements": { - "certificate_type": "developerIDAppDistribution", - "app_id_prefix": APPLE_TEAM_ID, - }, - }, - } - - -def sign_executable(signing_data: CdSigningData, exe_path: pathlib.Path) -> pathlib.Path: - """ - Signs an executable with CD Signer. - - Returns: - The path to the signed executable - """ - name = exe_path.name - info(f"Signing {name}") - - info("Packaging...") - package_path = cd_build_signed_package(exe_path) - - info("Uploading...") - run_cmd(["aws", "s3", "rm", "--recursive", f"s3://{signing_data.bucket_name}/signed"]) - run_cmd(["aws", "s3", "rm", "--recursive", f"s3://{signing_data.bucket_name}/pre-signed"]) - run_cmd(["aws", "s3", "cp", package_path, f"s3://{signing_data.bucket_name}/pre-signed/package.tar.gz"]) - - info("Sending request...") - request_id = cd_signer_create_request(manifest("com.amazon.codewhisperer")) - cd_signer_start_request( - request_id=request_id, - source_key="pre-signed/package.tar.gz", - destination_key="signed/signed.zip", - signing_data=signing_data, - ) - - max_duration = 180 - end_time = time.time() + max_duration - i = 1 - while True: - info(f"Checking for signed package attempt #{i}") - status = cd_signer_status_request(request_id) - info(f"Package has status: {status}") - - match status: - case "success": - break - case "created" | "processing" | "inProgress": - pass - case "failure": - raise RuntimeError("Signing request failed") - case _: - warn(f"Unexpected status, ignoring: {status}") - - if time.time() >= end_time: - raise RuntimeError("Signed package did not appear, check signer logs") - time.sleep(2) - i += 1 - - info("Signed!") - - # CD Signer should return the signed executable in a zip file containing the structure: - # "Payload/EXECUTABLES_TO_SIGN/{executable name}". - info("Downloading...") - - # Create a new directory for unzipping the signed executable. - zip_dl_path = BUILD_DIR / pathlib.Path("signed.zip") - run_cmd(["aws", "s3", "cp", f"s3://{signing_data.bucket_name}/signed/signed.zip", zip_dl_path]) - payload_path = BUILD_DIR / "signed" - shutil.rmtree(payload_path, ignore_errors=True) - run_cmd(["unzip", zip_dl_path, "-d", payload_path]) - zip_dl_path.unlink() - signed_exe_path = BUILD_DIR / "signed" / "Payload" / "EXECUTABLES_TO_SIGN" / name - # Verify that the exe is signed - run_cmd(["codesign", "--verify", "--verbose=4", signed_exe_path]) - return signed_exe_path - - -def notarize_executable(signing_data: CdSigningData, exe_path: pathlib.Path): - """ - Submits an executable to Apple notary service. - """ - # Load the Apple id and password from secrets manager. - secret_id = signing_data.apple_notarizing_secret_arn - secret_region = parse_region_from_arn(signing_data.apple_notarizing_secret_arn) - info(f"Loading secretmanager value: {secret_id}") - secret_value = run_cmd_output( - ["aws", "--region", secret_region, "secretsmanager", "get-secret-value", "--secret-id", secret_id] - ) - secret_string = json.loads(secret_value)["SecretString"] - secrets = json.loads(secret_string) - - # Submit the exe to Apple notary service. It must be zipped first. - info(f"Submitting {exe_path} to Apple notary service") - zip_path = BUILD_DIR / f"{exe_path.name}.zip" - zip_path.unlink(missing_ok=True) - run_cmd(["zip", "-j", zip_path, exe_path], cwd=BUILD_DIR) - submit_res = run_cmd_output( - [ - "xcrun", - "notarytool", - "submit", - zip_path, - "--team-id", - APPLE_TEAM_ID, - "--apple-id", - secrets["appleId"], - "--password", - secrets["appleIdPassword"], - "--wait", - "-f", - "json", - ] - ) - debug(f"Notary service response: {submit_res}") - - # Confirm notarization succeeded. - assert json.loads(submit_res)["status"] == "Accepted" - - # Cleanup - zip_path.unlink() - - -def sign_and_notarize(signing_data: CdSigningData, chat_path: pathlib.Path) -> pathlib.Path: - """ - Signs an executable with CD Signer, and verifies it with Apple notary service. - - Returns: - The path to the signed executable. - """ - # First, sign the application - chat_path = sign_executable(signing_data, chat_path) - - # Next, notarize the application - notarize_executable(signing_data, chat_path) - - return chat_path - - -def build_macos(chat_path: pathlib.Path, signing_data: CdSigningData | None): - """ - Creates a chat binary zip under the build directory. - """ - chat_dst = BUILD_DIR / CHAT_BINARY_NAME - chat_dst.unlink(missing_ok=True) - shutil.copy2(chat_path, chat_dst) - - if signing_data: - chat_dst = sign_and_notarize(signing_data, chat_dst) - - zip_path = BUILD_DIR / f"{CHAT_BINARY_NAME}.zip" - zip_path.unlink(missing_ok=True) - - info(f"Creating zip output to {zip_path}") - run_cmd(["zip", "-j", zip_path, chat_dst], cwd=BUILD_DIR) - generate_sha(zip_path) - - -class GpgSigner: - def __init__(self, gpg_id: str, gpg_secret_key: str, gpg_passphrase: str): - self.gpg_id = gpg_id - self.gpg_secret_key = gpg_secret_key - self.gpg_passphrase = gpg_passphrase - - self.gpg_home = pathlib.Path.home() / ".gnupg-tmp" - self.gpg_home.mkdir(parents=True, exist_ok=True, mode=0o700) - - # write gpg secret key to file - self.gpg_secret_key_path = self.gpg_home / "gpg_secret" - self.gpg_secret_key_path.write_bytes(base64.b64decode(gpg_secret_key)) - - self.gpg_passphrase_path = self.gpg_home / "gpg_pass" - self.gpg_passphrase_path.write_text(gpg_passphrase) - - run_cmd(["gpg", "--version"]) - - info("Importing GPG key") - run_cmd(["gpg", "--list-keys"], env=self.gpg_env()) - run_cmd( - ["gpg", *self.sign_args(), "--allow-secret-key-import", "--import", self.gpg_secret_key_path], - env=self.gpg_env(), - ) - run_cmd(["gpg", "--list-keys"], env=self.gpg_env()) - - def gpg_env(self) -> Env: - return {**os.environ, "GNUPGHOME": self.gpg_home} - - def sign_args(self) -> Args: - return [ - "--batch", - "--pinentry-mode", - "loopback", - "--no-tty", - "--yes", - "--passphrase-file", - self.gpg_passphrase_path, - ] - - def sign_file(self, path: pathlib.Path) -> List[pathlib.Path]: - info(f"Signing {path.name}") - run_cmd( - ["gpg", "--detach-sign", *self.sign_args(), "--local-user", self.gpg_id, path], - env=self.gpg_env(), - ) - run_cmd( - ["gpg", "--detach-sign", *self.sign_args(), "--armor", "--local-user", self.gpg_id, path], - env=self.gpg_env(), - ) - return [path.with_suffix(f"{path.suffix}.asc"), path.with_suffix(f"{path.suffix}.sig")] - - def clean(self): - info("Cleaning gpg keys") - shutil.rmtree(self.gpg_home, ignore_errors=True) - - -def get_secretmanager_json(secret_id: str, secret_region: str): - info(f"Loading secretmanager value: {secret_id}") - secret_value = run_cmd_output( - ["aws", "--region", secret_region, "secretsmanager", "get-secret-value", "--secret-id", secret_id] - ) - secret_string = json.loads(secret_value)["SecretString"] - return json.loads(secret_string) - - -def load_gpg_signer() -> Optional[GpgSigner]: - if gpg_id := os.getenv("TEST_PGP_ID"): - gpg_secret_key = os.getenv("TEST_PGP_SECRET_KEY") - gpg_passphrase = os.getenv("TEST_PGP_PASSPHRASE") - if gpg_secret_key is not None and gpg_passphrase is not None: - info("Using test pgp key", gpg_id) - return GpgSigner(gpg_id=gpg_id, gpg_secret_key=gpg_secret_key, gpg_passphrase=gpg_passphrase) - - pgp_secret_arn = os.getenv("SIGNING_PGP_KEY_SECRET_ARN") - info(f"SIGNING_PGP_KEY_SECRET_ARN: {pgp_secret_arn}") - if pgp_secret_arn: - pgp_secret_region = parse_region_from_arn(pgp_secret_arn) - gpg_secret_json = get_secretmanager_json(pgp_secret_arn, pgp_secret_region) - gpg_id = gpg_secret_json["gpg_id"] - gpg_secret_key = gpg_secret_json["gpg_secret_key"] - gpg_passphrase = gpg_secret_json["gpg_passphrase"] - return GpgSigner(gpg_id=gpg_id, gpg_secret_key=gpg_secret_key, gpg_passphrase=gpg_passphrase) - else: - return None - - -def parse_region_from_arn(arn: str) -> str: - # ARN format: arn:partition:service:region:account-id:resource-type/resource-id - # Check if we have enough parts and the ARN starts with "arn:" - parts = arn.split(":") - if len(parts) >= 4: - return parts[3] - - return "" - - -def build_linux(chat_path: pathlib.Path, signer: GpgSigner | None): - """ - Creates tar.gz, tar.xz, tar.zst, and zip archives under `BUILD_DIR`. - - Each archive has the following structure: - - archive/{chat_binary} - """ - archive_name = CHAT_BINARY_NAME - - archive_path = pathlib.Path(archive_name) - archive_path.mkdir(parents=True, exist_ok=True) - shutil.copy2(chat_path, archive_path / CHAT_BINARY_NAME) - - info(f"Building {archive_name}.tar.gz") - tar_gz_path = BUILD_DIR / f"{archive_name}.tar.gz" - run_cmd(["tar", "-czf", tar_gz_path, archive_path]) - generate_sha(tar_gz_path) - if signer: - signer.sign_file(tar_gz_path) - - info(f"Building {archive_name}.zip") - zip_path = BUILD_DIR / f"{archive_name}.zip" - run_cmd(["zip", "-r", zip_path, archive_path]) - generate_sha(zip_path) - if signer: - signer.sign_file(zip_path) - - # clean up - shutil.rmtree(archive_path) - if signer: - signer.clean() - - -def build( - release: bool, - stage_name: str | None = None, - run_lints: bool = True, - run_test: bool = True, -): - BUILD_DIR.mkdir(exist_ok=True) - - disable_signing = os.environ.get("DISABLE_SIGNING") - - gpg_signer = load_gpg_signer() if not disable_signing and isLinux() else None - signing_role_arn = os.environ.get("SIGNING_ROLE_ARN") - signing_bucket_name = os.environ.get("SIGNING_BUCKET_NAME") - signing_apple_notarizing_secret_arn = os.environ.get("SIGNING_APPLE_NOTARIZING_SECRET_ARN") - if ( - not disable_signing - and isDarwin() - and signing_role_arn - and signing_bucket_name - and signing_apple_notarizing_secret_arn - ): - signing_data = CdSigningData( - bucket_name=signing_bucket_name, - apple_notarizing_secret_arn=signing_apple_notarizing_secret_arn, - signing_role_arn=signing_role_arn, - ) - else: - signing_data = None - - match stage_name: - case "prod" | None: - info("Building for prod") - case "gamma": - info("Building for gamma") - case _: - raise ValueError(f"Unknown stage name: {stage_name}") - - targets = rust_targets() - - info(f"Release: {release}") - info(f"Targets: {targets}") - info(f"Signing app: {signing_data is not None or gpg_signer is not None}") - - if run_test: - info("Running cargo tests") - run_cargo_tests() - - if run_lints: - info("Running cargo clippy") - run_clippy() - - info("Building", CHAT_PACKAGE_NAME) - chat_path = build_chat_bin( - release=release, - output_name=CHAT_BINARY_NAME, - targets=targets, - ) - - if isDarwin(): - build_macos(chat_path, signing_data) - else: - build_linux(chat_path, gpg_signer) diff --git a/build-scripts/qchatmain.py b/build-scripts/qchatmain.py deleted file mode 100644 index b6142ee52..000000000 --- a/build-scripts/qchatmain.py +++ /dev/null @@ -1,50 +0,0 @@ -import argparse -from qchatbuild import build - - -class StoreIfNotEmptyAction(argparse.Action): - def __call__(self, parser, namespace, values, option_string=None): - if values and len(values) > 0: - setattr(namespace, self.dest, values) - - -parser = argparse.ArgumentParser( - prog="build", - description="Builds the chat binary", -) -subparsers = parser.add_subparsers(help="sub-command help", dest="subparser", required=True) - -build_subparser = subparsers.add_parser(name="build") -build_subparser.add_argument( - "--stage-name", - action=StoreIfNotEmptyAction, - help="The name of the stage", -) -build_subparser.add_argument( - "--not-release", - action="store_true", - help="Build a non-release version", -) -build_subparser.add_argument( - "--skip-tests", - action="store_true", - help="Skip running npm and rust tests", -) -build_subparser.add_argument( - "--skip-lints", - action="store_true", - help="Skip running lints", -) - -args = parser.parse_args() - -match args.subparser: - case "build": - build( - release=not args.not_release, - stage_name=args.stage_name, - run_lints=not args.skip_lints, - run_test=not args.skip_tests, - ) - case _: - raise ValueError(f"Unsupported subparser {args.subparser}") diff --git a/crates/chat-cli/.gitignore b/crates/chat-cli/.gitignore deleted file mode 100644 index b082f8a65..000000000 --- a/crates/chat-cli/.gitignore +++ /dev/null @@ -1,5 +0,0 @@ -build/ -spec.ts - -# This is created by the build script for macOS -src/Info.plist diff --git a/crates/chat-cli/Cargo.toml b/crates/chat-cli/Cargo.toml deleted file mode 100644 index cddbdc41a..000000000 --- a/crates/chat-cli/Cargo.toml +++ /dev/null @@ -1,204 +0,0 @@ -[package] -name = "chat_cli" -authors.workspace = true -edition.workspace = true -homepage.workspace = true -publish.workspace = true -version.workspace = true -license.workspace = true -default-run = "chat_cli" - -[lints] -workspace = true - -[features] -default = [] -wayland = ["arboard/wayland-data-control"] - -[[bin]] -name = "test_mcp_server" -path = "test_mcp_server/test_server.rs" -test = true -doc = false - -[dependencies] -amzn-codewhisperer-client = { path = "../amzn-codewhisperer-client" } -amzn-codewhisperer-streaming-client = { path = "../amzn-codewhisperer-streaming-client" } -amzn-consolas-client = { path = "../amzn-consolas-client" } -amzn-qdeveloper-streaming-client = { path = "../amzn-qdeveloper-streaming-client" } -amzn-toolkit-telemetry-client = { path = "../amzn-toolkit-telemetry-client" } -anstream = "0.6.13" -arboard = { version = "3.5.0", default-features = false } -async-trait = "0.1.87" -aws-config = "1.0.3" -aws-credential-types = "1.0.3" -aws-runtime = "1.4.4" -aws-sdk-cognitoidentity = "1.51.0" -aws-sdk-ssooidc = "1.51.0" -aws-smithy-async = "1.2.2" -aws-smithy-runtime-api = "1.6.1" -aws-smithy-types = "1.2.10" -aws-types = "1.3.0" -base64 = "0.22.1" -bitflags = "2.9.0" -bstr = "1.12.0" -bytes = "1.10.1" -camino = { version = "1.1.3", features = ["serde1"] } -cfg-if = "1.0.0" -clap = { version = "4.5.32", features = [ - "deprecated", - "derive", - "string", - "unicode", - "wrap_help", -] } -clap_complete = "4.5.46" -clap_complete_fig = "4.4.0" -color-eyre = "0.6.5" -color-print = "0.3.5" -convert_case = "0.8.0" -cookie = "0.18.1" -crossterm = { version = "0.28.1", features = ["event-stream", "events"] } -ctrlc = "3.4.6" -dialoguer = { version = "0.11.0", features = ["fuzzy-select"] } -dirs = "5.0.0" -eyre = "0.6.8" -fd-lock = "4.0.4" -futures = "0.3.26" -glob = "0.3.2" -globset = "0.4.16" -hex = "0.4.3" -http = "1.2.0" -http-body-util = "0.1.3" -hyper = { version = "1.6.0", features = ["server"] } -hyper-util = { version = "0.1.11", features = ["tokio"] } -indicatif = "0.17.11" -indoc = "2.0.6" -insta = "1.43.1" -libc = "0.2.172" -mimalloc = "0.1.46" -nix = { version = "0.29.0", features = [ - "feature", - "fs", - "ioctl", - "process", - "signal", - "term", - "user", -] } -owo-colors = "4.2.0" -parking_lot = "0.12.3" -paste = "1.0.11" -percent-encoding = "2.2.0" -r2d2 = "0.8.10" -r2d2_sqlite = "0.25.0" -rand = "0.9.0" -regex = "1.7.0" -reqwest = { version = "0.12.14", default-features = false, features = [ - "http2", - "charset", - "rustls-tls", - "rustls-tls-native-roots", - "gzip", - "json", - "socks", - "cookies", -] } -ring = "0.17.14" -rusqlite = { version = "0.32.1", features = ["bundled", "serde_json"] } -rustls = "0.23.23" -rustls-native-certs = "0.8.1" -rustls-pemfile = "2.1.0" -rustyline = { version = "15.0.0", features = [ - "custom-bindings", - "derive", - "with-file-history", -], default-features = false } -semantic_search_client = { path = "../semantic_search_client" } -semver = { version = "1.0.26", features = ["serde"] } -serde = { version = "1.0.219", features = ["derive", "rc"] } -serde_json = { version = "1.0.140", features = ["preserve_order"] } -sha2 = "0.10.9" -shell-color = "1.0.0" -shell-words = "1.1.0" -shellexpand = "3.0.0" -shlex = "1.3.0" -similar = "2.7.0" -spinners = "4.1.0" -strip-ansi-escapes = "0.2.1" -strum = { version = "0.27.1", features = ["derive"] } -syntect = "5.2.0" -sysinfo = "0.33.1" -tempfile = "3.18.0" -thiserror = "2.0.12" -time = { version = "0.3.39", features = [ - "parsing", - "formatting", - "local-offset", - "macros", - "serde", -] } -tokio = { version = "1.47.1", features = ["full"] } -tokio-tungstenite = "0.26.2" -tokio-util = { version = "0.7.16", features = ["codec", "compat"] } -toml = "0.8.12" -tracing = { version = "0.1.40", features = ["log"] } -tracing-appender = "0.2.2" -tracing-subscriber = { version = "0.3.19", features = [ - "env-filter", - "fmt", - "parking_lot", - "time", -] } -typed-path = "0.11.0" -unicode-width = "0.2.0" -url = "2.5.4" -uuid = { version = "1.15.1", features = ["v4", "serde"] } -walkdir = "2.5.0" -webpki-roots = "=0.26.8" -whoami = "1.6.0" -winnow = "=0.6.2" - -[target.'cfg(unix)'.dependencies] -nix = { version = "0.29.0", features = [ - "feature", - "fs", - "ioctl", - "process", - "signal", - "term", - "user", -] } -skim = { version = "0.16.2" } - -[target.'cfg(target_os = "macos")'.dependencies] -objc2 = "0.5.2" -objc2-app-kit = { version = "0.2.2", features = ["NSWorkspace"] } -objc2-foundation = { version = "0.2.2", features = ["NSString", "NSURL"] } -security-framework = "3.2.0" - -[target.'cfg(windows)'.dependencies] -windows = { version = "0.61.1", features = [ - "Foundation", - "Win32_System_ProcessStatus", - "Win32_System_Kernel", - "Win32_System_Threading", - "Wdk_System_Threading", -] } -winreg = "0.55.0" - -[dev-dependencies] -assert_cmd = "2.0" -criterion = "0.6.0" -mockito = "1.7.0" -paste = "1.0.11" -predicates = "3.0" -tracing-test = "0.2.4" - -[build-dependencies] -convert_case = "0.8.0" -prettyplease = "0.2.32" -quote = "1.0.40" -serde = { version = "1.0.219", features = ["derive", "rc"] } -serde_json = "1.0.140" -syn = "2.0.101" diff --git a/crates/chat-cli/build.rs b/crates/chat-cli/build.rs deleted file mode 100644 index f097298af..000000000 --- a/crates/chat-cli/build.rs +++ /dev/null @@ -1,324 +0,0 @@ -use convert_case::{ - Case, - Casing, -}; -use quote::{ - format_ident, - quote, -}; - -// TODO(brandonskiser): update bundle identifier for signed builds -#[cfg(target_os = "macos")] -const MACOS_BUNDLE_IDENTIFIER: &str = "com.amazon.codewhisperer"; - -const DEF: &str = include_str!("./telemetry_definitions.json"); - -#[derive(Debug, Clone, serde::Deserialize)] -#[serde(rename_all = "camelCase")] -struct TypeDef { - name: String, - r#type: Option, - allowed_values: Option>, - description: String, -} - -#[derive(Debug, Clone, serde::Deserialize)] -struct MetricDef { - name: String, - description: String, - metadata: Option>, - passive: Option, - unit: Option, -} - -#[derive(Debug, Clone, serde::Deserialize)] -struct MetricMetadata { - r#type: String, - required: Option, -} - -#[derive(Debug, Clone, serde::Deserialize)] -struct Def { - types: Vec, - metrics: Vec, -} - -/// Writes a generated Info.plist for the qchat executable under src/. -/// -/// This is required for signing the executable since we must embed the Info.plist directly within -/// the binary. -#[cfg(target_os = "macos")] -fn write_plist() { - let plist = format!( - r#" - - - - CFBundlePackageType - APPL - CFBundleIdentifier - {} - CFBundleName - {} - CFBundleVersion - {} - CFBundleShortVersionString - {} - CFBundleInfoDictionaryVersion - 6.0 - NSHumanReadableCopyright - Copyright © 2022 Amazon Q CLI Team (q-cli@amazon.com):Chay Nabors (nabochay@amazon.com):Brandon Kiser (bskiser@amazon.com) All rights reserved. - - -"#, - MACOS_BUNDLE_IDENTIFIER, - option_env!("AMAZON_Q_BUILD_HASH").unwrap_or("unknown"), - option_env!("AMAZON_Q_BUILD_DATETIME").unwrap_or("unknown"), - env!("CARGO_PKG_VERSION") - ); - - std::fs::write("src/Info.plist", plist).expect("writing the Info.plist should not fail"); -} - -fn main() { - println!("cargo:rerun-if-changed=def.json"); - - #[cfg(target_os = "macos")] - write_plist(); - - let outdir = std::env::var("OUT_DIR").unwrap(); - - let data = serde_json::from_str::(DEF).unwrap(); - - let mut out = " - #[allow(rustdoc::invalid_html_tags)] - #[allow(rustdoc::bare_urls)] - mod inner { - " - .to_string(); - - out.push_str("pub mod types {"); - for t in data.types { - let name = format_ident!("{}", t.name.to_case(Case::Pascal)); - - let rust_type = match t.allowed_values { - // enum - Some(allowed_values) => { - let mut variants = vec![]; - let mut variant_as_str = vec![]; - - for v in allowed_values { - let ident = format_ident!("{}", v.replace('.', "").to_case(Case::Pascal)); - variants.push(quote!( - #[doc = concat!("`", #v, "`")] - #ident - )); - variant_as_str.push(quote!( - #name::#ident => #v - )); - } - - let description = t.description; - - quote::quote!( - #[doc = #description] - #[derive(Debug, Clone, PartialEq)] - #[non_exhaustive] - pub enum #name { - #( - #variants, - )* - } - - impl #name { - pub fn as_str(&self) -> &'static str { - match self { - #( #variant_as_str, )* - } - } - } - - impl ::std::fmt::Display for #name { - fn fmt(&self, f: &mut ::std::fmt::Formatter<'_>) -> ::std::fmt::Result { - f.write_str(self.as_str()) - } - } - ) - .to_string() - }, - // struct - None => { - let r#type = match t.r#type.as_deref() { - Some("string") | None => quote!(::std::string::String), - Some("int") => quote!(::std::primitive::i64), - Some("double") => quote!(::std::primitive::f64), - Some("boolean") => quote!(::std::primitive::bool), - Some(other) => panic!("{}", other), - }; - let description = t.description; - - quote::quote!( - #[doc = #description] - #[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)] - #[serde(transparent)] - pub struct #name(pub #r#type); - - impl #name { - pub fn new(t: #r#type) -> Self { - Self(t) - } - - pub fn value(&self) -> &#r#type { - &self.0 - } - - pub fn into_value(self) -> #r#type { - self.0 - } - } - - impl ::std::fmt::Display for #name { - fn fmt(&self, f: &mut ::std::fmt::Formatter<'_>) -> ::std::fmt::Result { - write!(f, "{}", self.0) - } - } - - impl From<#r#type> for #name { - fn from(t: #r#type) -> Self { - Self(t) - } - } - ) - .to_string() - }, - }; - - out.push_str(&rust_type); - } - out.push('}'); - - out.push_str("pub mod metrics {"); - for m in data.metrics.clone() { - let raw_name = m.name; - let name = format_ident!("{}", raw_name.to_case(Case::Pascal)); - let description = m.description; - - let passive = m.passive.unwrap_or_default(); - - let unit = match m.unit.map(|u| u.to_lowercase()).as_deref() { - Some("bytes") => quote!(::amzn_toolkit_telemetry_client::types::Unit::Bytes), - Some("count") => quote!(::amzn_toolkit_telemetry_client::types::Unit::Count), - Some("milliseconds") => quote!(::amzn_toolkit_telemetry_client::types::Unit::Milliseconds), - Some("percent") => quote!(::amzn_toolkit_telemetry_client::types::Unit::Percent), - Some("none") | None => quote!(::amzn_toolkit_telemetry_client::types::Unit::None), - Some(unknown) => { - panic!("unknown unit: {:?}", unknown); - }, - }; - - let metadata = m.metadata.unwrap_or_default(); - - let mut fields = Vec::new(); - for field in &metadata { - let field_name = format_ident!("{}", &field.r#type.to_case(Case::Snake)); - let ty_name = format_ident!("{}", field.r#type.to_case(Case::Pascal)); - let ty = if field.required.unwrap_or_default() { - quote!(crate::telemetry::definitions::types::#ty_name) - } else { - quote!(::std::option::Option) - }; - - fields.push(quote!( - #field_name: #ty - )); - } - - let metadata_entries = metadata.iter().map(|m| { - let raw_name = &m.r#type; - let key = format_ident!("{}", m.r#type.to_case(Case::Snake)); - - let value = if m.required.unwrap_or_default() { - quote!(.value(self.#key.to_string())) - } else { - quote!(.value(self.#key.map(|v| v.to_string()).unwrap_or_default())) - }; - - quote!( - ::amzn_toolkit_telemetry_client::types::MetadataEntry::builder() - .key(#raw_name) - #value - .build() - ) - }); - - let rust_type = quote::quote!( - #[doc = #description] - #[derive(Debug, Clone, PartialEq, ::serde::Serialize, ::serde::Deserialize)] - #[serde(rename_all = "camelCase")] - pub struct #name { - /// The time that the event took place, - pub create_time: ::std::option::Option<::std::time::SystemTime>, - /// Value based on unit and call type, - pub value: ::std::option::Option, - #( pub #fields, )* - } - - impl #name { - const NAME: &'static ::std::primitive::str = #raw_name; - const PASSIVE: ::std::primitive::bool = #passive; - const UNIT: ::amzn_toolkit_telemetry_client::types::Unit = #unit; - } - - impl crate::telemetry::definitions::IntoMetricDatum for #name { - fn into_metric_datum(self) -> ::amzn_toolkit_telemetry_client::types::MetricDatum { - let metadata_entries = vec![ - #( - #metadata_entries, - )* - ]; - - let epoch_timestamp = self.create_time - .map_or_else( - || ::std::time::SystemTime::now().duration_since(std::time::UNIX_EPOCH).unwrap().as_millis() as ::std::primitive::i64, - |t| t.duration_since(std::time::UNIX_EPOCH).unwrap().as_millis() as ::std::primitive::i64 - ); - - ::amzn_toolkit_telemetry_client::types::MetricDatum::builder() - .metric_name(#name::NAME) - .passive(#name::PASSIVE) - .unit(#name::UNIT) - .epoch_timestamp(epoch_timestamp) - .value(self.value.unwrap_or(1.0)) - .set_metadata(Some(metadata_entries)) - .build() - .unwrap() - } - } - ); - - out.push_str(&rust_type.to_string()); - } - out.push('}'); - - // enum of all metrics - let mut metrics = Vec::new(); - for m in data.metrics { - let name = format_ident!("{}", m.name.to_case(Case::Pascal)); - metrics.push(quote!( - #name - )); - } - out.push_str("#[derive(Debug, Clone, PartialEq, ::serde::Serialize, ::serde::Deserialize)]\n#[serde(tag = \"type\", content = \"content\")]\npub enum Metric {\n"); - for m in metrics { - out.push_str(&format!("{m}(crate::telemetry::definitions::metrics::{m}),\n")); - } - out.push('}'); - - out.push_str("}\npub use inner::*;"); - - let file: syn::File = syn::parse_str(&out).unwrap(); - let pp = prettyplease::unparse(&file); - - // write an empty file to the output directory - std::fs::write(format!("{}/mod.rs", outdir), pp).unwrap(); -} diff --git a/crates/chat-cli/src/api_client/credentials.rs b/crates/chat-cli/src/api_client/credentials.rs deleted file mode 100644 index e24d8cdb9..000000000 --- a/crates/chat-cli/src/api_client/credentials.rs +++ /dev/null @@ -1,80 +0,0 @@ -use aws_config::default_provider::region::DefaultRegionChain; -use aws_config::ecs::EcsCredentialsProvider; -use aws_config::environment::credentials::EnvironmentVariableCredentialsProvider; -use aws_config::imds::credentials::ImdsCredentialsProvider; -use aws_config::meta::credentials::CredentialsProviderChain; -use aws_config::profile::ProfileFileCredentialsProvider; -use aws_config::provider_config::ProviderConfig; -use aws_config::web_identity_token::WebIdentityTokenCredentialsProvider; -use aws_credential_types::Credentials; -use aws_credential_types::provider::{ - self, - ProvideCredentials, - future, -}; -use tracing::Instrument; - -#[derive(Debug)] -pub struct CredentialsChain { - provider_chain: CredentialsProviderChain, -} - -impl CredentialsChain { - /// Based on code the code for - /// [aws_config::default_provider::credentials::DefaultCredentialsChain] - pub async fn new() -> Self { - let region = DefaultRegionChain::builder().build().region().await; - let config = ProviderConfig::default().with_region(region.clone()); - - let env_provider = EnvironmentVariableCredentialsProvider::new(); - let profile_provider = ProfileFileCredentialsProvider::builder().configure(&config).build(); - let web_identity_token_provider = WebIdentityTokenCredentialsProvider::builder() - .configure(&config) - .build(); - let imds_provider = ImdsCredentialsProvider::builder().configure(&config).build(); - let ecs_provider = EcsCredentialsProvider::builder().configure(&config).build(); - - let mut provider_chain = CredentialsProviderChain::first_try("Environment", env_provider); - - provider_chain = provider_chain - .or_else("Profile", profile_provider) - .or_else("WebIdentityToken", web_identity_token_provider) - .or_else("EcsContainer", ecs_provider) - .or_else("Ec2InstanceMetadata", imds_provider); - - CredentialsChain { provider_chain } - } - - async fn credentials(&self) -> provider::Result { - self.provider_chain - .provide_credentials() - .instrument(tracing::debug_span!("provide_credentials", provider = %"default_chain")) - .await - } -} - -impl ProvideCredentials for CredentialsChain { - fn provide_credentials<'a>(&'a self) -> future::ProvideCredentials<'a> - where - Self: 'a, - { - future::ProvideCredentials::new(self.credentials()) - } - - fn fallback_on_interrupt(&self) -> Option { - self.provider_chain.fallback_on_interrupt() - } -} - -#[cfg(test)] -mod tests { - use super::*; - - #[tokio::test] - async fn test_credentials_chain() { - let credentials_chain = CredentialsChain::new().await; - let credentials_res = credentials_chain.provide_credentials().await; - let fallback_on_interrupt_res = credentials_chain.fallback_on_interrupt(); - println!("credentials_res: {credentials_res:?}, fallback_on_interrupt_res: {fallback_on_interrupt_res:?}"); - } -} diff --git a/crates/chat-cli/src/api_client/customization.rs b/crates/chat-cli/src/api_client/customization.rs deleted file mode 100644 index 98d22c8f6..000000000 --- a/crates/chat-cli/src/api_client/customization.rs +++ /dev/null @@ -1,123 +0,0 @@ -use amzn_codewhisperer_client::types::Customization as CodewhispererCustomization; -use amzn_consolas_client::types::CustomizationSummary as ConsolasCustomization; -use serde::{ - Deserialize, - Serialize, -}; - -#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)] -#[serde(rename_all = "camelCase")] -pub struct Customization { - pub arn: String, - #[serde(skip_serializing_if = "Option::is_none")] - pub name: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub description: Option, -} - -impl From for CodewhispererCustomization { - fn from(Customization { arn, name, description }: Customization) -> Self { - CodewhispererCustomization::builder() - .arn(arn) - .set_name(name) - .set_description(description) - .build() - .expect("Failed to build CW Customization") - } -} - -impl From for Customization { - fn from(cw_customization: CodewhispererCustomization) -> Self { - Customization { - arn: cw_customization.arn, - name: cw_customization.name, - description: cw_customization.description, - } - } -} - -impl From for Customization { - fn from(consolas_customization: ConsolasCustomization) -> Self { - Customization { - arn: consolas_customization.arn, - name: Some(consolas_customization.customization_name), - description: consolas_customization.description, - } - } -} - -#[cfg(test)] -mod tests { - use amzn_consolas_client::types::CustomizationStatus; - use aws_smithy_types::DateTime; - - use super::*; - - #[test] - fn test_customization_from_impls() { - let cw_customization = CodewhispererCustomization::builder() - .arn("arn") - .name("name") - .description("description") - .build() - .unwrap(); - - let custom_from_cw: Customization = cw_customization.into(); - let cw_from_custom: CodewhispererCustomization = custom_from_cw.into(); - - assert_eq!(cw_from_custom.arn, "arn"); - assert_eq!(cw_from_custom.name, Some("name".into())); - assert_eq!(cw_from_custom.description, Some("description".into())); - - let cw_customization = CodewhispererCustomization::builder().arn("arn").build().unwrap(); - - let custom_from_cw: Customization = cw_customization.into(); - let cw_from_custom: CodewhispererCustomization = custom_from_cw.into(); - - assert_eq!(cw_from_custom.arn, "arn"); - assert_eq!(cw_from_custom.name, None); - assert_eq!(cw_from_custom.description, None); - - let consolas_customization = ConsolasCustomization::builder() - .arn("arn") - .customization_name("name") - .description("description") - .status(CustomizationStatus::Activated) - .updated_at(DateTime::from_secs(0)) - .build() - .unwrap(); - - let custom_from_consolas: Customization = consolas_customization.into(); - - assert_eq!(custom_from_consolas.arn, "arn"); - assert_eq!(custom_from_consolas.name, Some("name".into())); - assert_eq!(custom_from_consolas.description, Some("description".into())); - } - - #[test] - fn test_customization_serde() { - let customization = Customization { - arn: "arn".into(), - name: Some("name".into()), - description: Some("description".into()), - }; - - let serialized = serde_json::to_string(&customization).unwrap(); - assert_eq!(serialized, r#"{"arn":"arn","name":"name","description":"description"}"#); - - let deserialized: Customization = serde_json::from_str(&serialized).unwrap(); - assert_eq!(deserialized, customization); - - let customization = Customization { - arn: "arn".into(), - name: None, - description: None, - }; - - let serialized = serde_json::to_string(&customization).unwrap(); - assert_eq!(serialized, r#"{"arn":"arn"}"#); - - let deserialized: Customization = serde_json::from_str(&serialized).unwrap(); - assert_eq!(deserialized, customization); - } -} diff --git a/crates/chat-cli/src/api_client/endpoints.rs b/crates/chat-cli/src/api_client/endpoints.rs deleted file mode 100644 index f761ce8b2..000000000 --- a/crates/chat-cli/src/api_client/endpoints.rs +++ /dev/null @@ -1,91 +0,0 @@ -use std::borrow::Cow; - -use aws_config::Region; -use serde_json::Value; -use tracing::error; - -use crate::database::Database; -use crate::database::settings::Setting; - -#[derive(Debug, Clone, PartialEq, Eq)] -pub struct Endpoint { - pub url: Cow<'static, str>, - pub region: Region, -} - -impl Endpoint { - pub const CODEWHISPERER_ENDPOINTS: [Self; 2] = [Self::DEFAULT_ENDPOINT, Self::FRA_ENDPOINT]; - pub const DEFAULT_ENDPOINT: Self = Self { - url: Cow::Borrowed("https://q.us-east-1.amazonaws.com"), - region: Region::from_static("us-east-1"), - }; - pub const FRA_ENDPOINT: Self = Self { - url: Cow::Borrowed("https://q.eu-central-1.amazonaws.com/"), - region: Region::from_static("eu-central-1"), - }; - - pub fn configured_value(database: &Database) -> Self { - let (endpoint, region) = if let Some(Value::Object(o)) = database.settings.get(Setting::ApiCodeWhispererService) - { - // The following branch is evaluated in case the user has set their own endpoint. - ( - o.get("endpoint").and_then(|v| v.as_str()).map(|v| v.to_owned()), - o.get("region").and_then(|v| v.as_str()).map(|v| v.to_owned()), - ) - } else if let Ok(Some(profile)) = database.get_auth_profile() { - // The following branch is evaluated in the case of user profile being set. - let region = profile.arn.split(':').nth(3).unwrap_or_default().to_owned(); - match Self::CODEWHISPERER_ENDPOINTS - .iter() - .find(|e| e.region().as_ref() == region) - { - Some(endpoint) => (Some(endpoint.url().to_owned()), Some(region)), - None => { - error!("Failed to find endpoint for region: {region}"); - (None, None) - }, - } - } else { - (None, None) - }; - - match (endpoint, region) { - (Some(endpoint), Some(region)) => Self { - url: endpoint.clone().into(), - region: Region::new(region.clone()), - }, - _ => Endpoint::DEFAULT_ENDPOINT, - } - } - - pub(crate) fn url(&self) -> &str { - &self.url - } - - pub(crate) fn region(&self) -> &Region { - &self.region - } -} - -#[cfg(test)] -mod tests { - use url::Url; - - use super::*; - - #[tokio::test] - async fn test_endpoints() { - let database = Database::new().await.unwrap(); - let _ = Endpoint::configured_value(&database); - - let prod = &Endpoint::DEFAULT_ENDPOINT; - Url::parse(prod.url()).unwrap(); - - let custom = Endpoint { - region: Region::new("us-west-2"), - url: "https://example.com".into(), - }; - Url::parse(custom.url()).unwrap(); - assert_eq!(custom.region(), &Region::new("us-west-2")); - } -} diff --git a/crates/chat-cli/src/api_client/error.rs b/crates/chat-cli/src/api_client/error.rs deleted file mode 100644 index 88f4a1f70..000000000 --- a/crates/chat-cli/src/api_client/error.rs +++ /dev/null @@ -1,226 +0,0 @@ -use amzn_codewhisperer_client::operation::create_subscription_token::CreateSubscriptionTokenError; -use amzn_codewhisperer_client::operation::generate_completions::GenerateCompletionsError; -use amzn_codewhisperer_client::operation::list_available_customizations::ListAvailableCustomizationsError; -use amzn_codewhisperer_client::operation::list_available_profiles::ListAvailableProfilesError; -use amzn_codewhisperer_client::operation::send_telemetry_event::SendTelemetryEventError; -pub use amzn_codewhisperer_streaming_client::operation::generate_assistant_response::GenerateAssistantResponseError; -use amzn_codewhisperer_streaming_client::types::error::ChatResponseStreamError as CodewhispererChatResponseStreamError; -use amzn_consolas_client::operation::generate_recommendations::GenerateRecommendationsError; -use amzn_consolas_client::operation::list_customizations::ListCustomizationsError; -use amzn_qdeveloper_streaming_client::operation::send_message::SendMessageError as QDeveloperSendMessageError; -use amzn_qdeveloper_streaming_client::types::error::ChatResponseStreamError as QDeveloperChatResponseStreamError; -use aws_credential_types::provider::error::CredentialsError; -use aws_sdk_ssooidc::error::ProvideErrorMetadata; -use aws_smithy_runtime_api::client::orchestrator::HttpResponse; -pub use aws_smithy_runtime_api::client::result::SdkError; -use aws_smithy_runtime_api::http::Response; -use aws_smithy_types::event_stream::RawMessage; -use thiserror::Error; - -use crate::auth::AuthError; -use crate::aws_common::SdkErrorDisplay; -use crate::telemetry::ReasonCode; - -#[derive(Debug, Error)] -pub enum ApiClientError { - // Generate completions errors - #[error("{}", SdkErrorDisplay(.0))] - GenerateCompletions(#[from] SdkError), - #[error("{}", SdkErrorDisplay(.0))] - GenerateRecommendations(#[from] SdkError), - - // List customizations error - #[error("{}", SdkErrorDisplay(.0))] - ListAvailableCustomizations(#[from] SdkError), - #[error("{}", SdkErrorDisplay(.0))] - ListAvailableServices(#[from] SdkError), - - // Telemetry client error - #[error("{}", SdkErrorDisplay(.0))] - SendTelemetryEvent(#[from] SdkError), - - // Send message errors - #[error("{}", SdkErrorDisplay(.0))] - CodewhispererGenerateAssistantResponse(#[from] SdkError), - #[error("{}", SdkErrorDisplay(.0))] - QDeveloperSendMessage(#[from] SdkError), - - // chat stream errors - #[error("{}", SdkErrorDisplay(.0))] - CodewhispererChatResponseStream(#[from] SdkError), - #[error("{}", SdkErrorDisplay(.0))] - QDeveloperChatResponseStream(#[from] SdkError), - - // quota breach - #[error("quota has reached its limit")] - QuotaBreach { - message: &'static str, - status_code: Option, - }, - - // Separate from quota breach (somehow) - #[error("monthly query limit reached")] - MonthlyLimitReached { status_code: Option }, - - #[error("{}", SdkErrorDisplay(.0))] - CreateSubscriptionToken(#[from] SdkError), - - /// Returned from the backend when the user input is too large to fit within the model context - /// window. - /// - /// Note that we currently do not receive token usage information regarding how large the - /// context window is. - #[error("the context window has overflowed")] - ContextWindowOverflow { status_code: Option }, - - #[error(transparent)] - SmithyBuild(#[from] aws_smithy_types::error::operation::BuildError), - - #[error(transparent)] - ListAvailableProfilesError(#[from] SdkError), - - #[error(transparent)] - AuthError(#[from] AuthError), - - #[error( - "The model you've selected is temporarily unavailable. Please use '/model' to select a different model and try again." - )] - ModelOverloadedError { - request_id: Option, - status_code: Option, - }, - - // Credential errors - #[error("failed to load credentials: {}", .0)] - Credentials(CredentialsError), -} - -impl ApiClientError { - pub fn status_code(&self) -> Option { - match self { - Self::GenerateCompletions(e) => sdk_status_code(e), - Self::GenerateRecommendations(e) => sdk_status_code(e), - Self::ListAvailableCustomizations(e) => sdk_status_code(e), - Self::ListAvailableServices(e) => sdk_status_code(e), - Self::CodewhispererGenerateAssistantResponse(e) => sdk_status_code(e), - Self::QDeveloperSendMessage(e) => sdk_status_code(e), - Self::CodewhispererChatResponseStream(_) => None, - Self::QDeveloperChatResponseStream(_) => None, - Self::ListAvailableProfilesError(e) => sdk_status_code(e), - Self::SendTelemetryEvent(e) => sdk_status_code(e), - Self::CreateSubscriptionToken(e) => sdk_status_code(e), - Self::QuotaBreach { status_code, .. } => *status_code, - Self::ContextWindowOverflow { status_code } => *status_code, - Self::SmithyBuild(_) => None, - Self::AuthError(_) => None, - Self::ModelOverloadedError { status_code, .. } => *status_code, - Self::MonthlyLimitReached { status_code } => *status_code, - Self::Credentials(_e) => None, - } - } -} - -impl ReasonCode for ApiClientError { - fn reason_code(&self) -> String { - match self { - Self::GenerateCompletions(e) => sdk_error_code(e), - Self::GenerateRecommendations(e) => sdk_error_code(e), - Self::ListAvailableCustomizations(e) => sdk_error_code(e), - Self::ListAvailableServices(e) => sdk_error_code(e), - Self::CodewhispererGenerateAssistantResponse(e) => sdk_error_code(e), - Self::QDeveloperSendMessage(e) => sdk_error_code(e), - Self::CodewhispererChatResponseStream(e) => sdk_error_code(e), - Self::QDeveloperChatResponseStream(e) => sdk_error_code(e), - Self::ListAvailableProfilesError(e) => sdk_error_code(e), - Self::SendTelemetryEvent(e) => sdk_error_code(e), - Self::CreateSubscriptionToken(e) => sdk_error_code(e), - Self::QuotaBreach { .. } => "QuotaBreachError".to_string(), - Self::ContextWindowOverflow { .. } => "ContextWindowOverflow".to_string(), - Self::SmithyBuild(_) => "SmithyBuildError".to_string(), - Self::AuthError(_) => "AuthError".to_string(), - Self::ModelOverloadedError { .. } => "ModelOverloadedError".to_string(), - Self::MonthlyLimitReached { .. } => "MonthlyLimitReached".to_string(), - Self::Credentials(_) => "CredentialsError".to_string(), - } - } -} - -fn sdk_error_code(e: &SdkError) -> String { - e.as_service_error() - .and_then(|se| se.meta().code().map(str::to_string)) - .unwrap_or_else(|| e.to_string()) -} - -fn sdk_status_code(e: &SdkError) -> Option { - e.raw_response().map(|res| res.status().as_u16()) -} - -#[cfg(test)] -mod tests { - use std::error::Error as _; - - use aws_smithy_runtime_api::http::Response; - use aws_smithy_types::body::SdkBody; - use aws_smithy_types::event_stream::Message; - - use super::*; - - fn response() -> Response { - Response::new(500.try_into().unwrap(), SdkBody::empty()) - } - - fn raw_message() -> RawMessage { - RawMessage::Decoded(Message::new(b"".to_vec())) - } - - fn all_errors() -> Vec { - vec![ - ApiClientError::Credentials(CredentialsError::unhandled("")), - ApiClientError::GenerateCompletions(SdkError::service_error( - GenerateCompletionsError::unhandled(""), - response(), - )), - ApiClientError::GenerateRecommendations(SdkError::service_error( - GenerateRecommendationsError::unhandled(""), - response(), - )), - ApiClientError::ListAvailableCustomizations(SdkError::service_error( - ListAvailableCustomizationsError::unhandled(""), - response(), - )), - ApiClientError::ListAvailableServices(SdkError::service_error( - ListCustomizationsError::unhandled(""), - response(), - )), - ApiClientError::CodewhispererGenerateAssistantResponse(SdkError::service_error( - GenerateAssistantResponseError::unhandled(""), - response(), - )), - ApiClientError::QDeveloperSendMessage(SdkError::service_error( - QDeveloperSendMessageError::unhandled(""), - response(), - )), - ApiClientError::CreateSubscriptionToken(SdkError::service_error( - CreateSubscriptionTokenError::unhandled(""), - response(), - )), - ApiClientError::CodewhispererChatResponseStream(SdkError::service_error( - CodewhispererChatResponseStreamError::unhandled(""), - raw_message(), - )), - ApiClientError::QDeveloperChatResponseStream(SdkError::service_error( - QDeveloperChatResponseStreamError::unhandled(""), - raw_message(), - )), - ApiClientError::SmithyBuild(aws_smithy_types::error::operation::BuildError::other("")), - ] - } - - #[test] - fn test_errors() { - for error in all_errors() { - let _ = error.source(); - println!("{error} {error:?}"); - } - } -} diff --git a/crates/chat-cli/src/api_client/mod.rs b/crates/chat-cli/src/api_client/mod.rs deleted file mode 100644 index be01fb4ba..000000000 --- a/crates/chat-cli/src/api_client/mod.rs +++ /dev/null @@ -1,591 +0,0 @@ -mod credentials; -pub mod customization; -mod endpoints; -mod error; -pub mod model; -mod opt_out; -pub mod profile; -pub mod send_message_output; - -use std::sync::Arc; -use std::time::Duration; - -use amzn_codewhisperer_client::Client as CodewhispererClient; -use amzn_codewhisperer_client::operation::create_subscription_token::CreateSubscriptionTokenOutput; -use amzn_codewhisperer_client::types::{ - OptOutPreference, - SubscriptionStatus, - TelemetryEvent, - UserContext, -}; -use amzn_codewhisperer_streaming_client::Client as CodewhispererStreamingClient; -use amzn_qdeveloper_streaming_client::Client as QDeveloperStreamingClient; -use amzn_qdeveloper_streaming_client::types::Origin; -use aws_config::retry::RetryConfig; -use aws_config::timeout::TimeoutConfig; -use aws_credential_types::Credentials; -use aws_credential_types::provider::ProvideCredentials; -use aws_types::request_id::RequestId; -use aws_types::sdk_config::StalledStreamProtectionConfig; -pub use endpoints::Endpoint; -pub use error::ApiClientError; -use parking_lot::Mutex; -pub use profile::list_available_profiles; -use serde_json::Map; -use tracing::{ - debug, - error, -}; - -use crate::api_client::credentials::CredentialsChain; -use crate::api_client::model::{ - ChatResponseStream, - ConversationState, -}; -use crate::api_client::opt_out::OptOutInterceptor; -use crate::api_client::send_message_output::SendMessageOutput; -use crate::auth::builder_id::BearerResolver; -use crate::aws_common::{ - UserAgentOverrideInterceptor, - app_name, - behavior_version, -}; -use crate::database::settings::Setting; -use crate::database::{ - AuthProfile, - Database, -}; -use crate::os::{ - Env, - Fs, -}; - -// Opt out constants -pub const X_AMZN_CODEWHISPERER_OPT_OUT_HEADER: &str = "x-amzn-codewhisperer-optout"; - -// TODO(bskiser): confirm timeout is updated to an appropriate value? -const DEFAULT_TIMEOUT_DURATION: Duration = Duration::from_secs(60 * 5); - -#[derive(Clone, Debug)] -pub struct ApiClient { - client: CodewhispererClient, - streaming_client: Option, - sigv4_streaming_client: Option, - mock_client: Option>>>>, - profile: Option, -} - -impl ApiClient { - pub async fn new( - env: &Env, - fs: &Fs, - database: &mut Database, - // endpoint is only passed here for list_profiles where it needs to be called for each region - endpoint: Option, - ) -> Result { - let endpoint = endpoint.unwrap_or(Endpoint::configured_value(database)); - - let credentials = Credentials::new("xxx", "xxx", None, None, "xxx"); - let bearer_sdk_config = aws_config::defaults(behavior_version()) - .region(endpoint.region.clone()) - .credentials_provider(credentials) - .timeout_config(timeout_config(database)) - .retry_config(retry_config()) - .load() - .await; - - let client = CodewhispererClient::from_conf( - amzn_codewhisperer_client::config::Builder::from(&bearer_sdk_config) - .http_client(crate::aws_common::http_client::client()) - .interceptor(OptOutInterceptor::new(database)) - .interceptor(UserAgentOverrideInterceptor::new()) - .bearer_token_resolver(BearerResolver) - .app_name(app_name()) - .endpoint_url(endpoint.url()) - .build(), - ); - - if cfg!(test) { - let mut this = Self { - client, - streaming_client: None, - sigv4_streaming_client: None, - mock_client: None, - profile: None, - }; - - if let Ok(json) = env.get("Q_MOCK_CHAT_RESPONSE") { - this.set_mock_output(serde_json::from_str(fs.read_to_string(json).await.unwrap().as_str()).unwrap()); - } - - return Ok(this); - } - - // If SIGV4_AUTH_ENABLED is true, use Q developer client - let mut streaming_client = None; - let mut sigv4_streaming_client = None; - match env.get("AMAZON_Q_SIGV4").is_ok() { - true => { - let credentials_chain = CredentialsChain::new().await; - if let Err(err) = credentials_chain.provide_credentials().await { - return Err(ApiClientError::Credentials(err)); - }; - - sigv4_streaming_client = Some(QDeveloperStreamingClient::from_conf( - amzn_qdeveloper_streaming_client::config::Builder::from( - &aws_config::defaults(behavior_version()) - .region(endpoint.region.clone()) - .credentials_provider(credentials_chain) - .timeout_config(timeout_config(database)) - .retry_config(retry_config()) - .load() - .await, - ) - .http_client(crate::aws_common::http_client::client()) - .interceptor(OptOutInterceptor::new(database)) - .interceptor(UserAgentOverrideInterceptor::new()) - .app_name(app_name()) - .endpoint_url(endpoint.url()) - .stalled_stream_protection(stalled_stream_protection_config()) - .build(), - )); - }, - false => { - streaming_client = Some(CodewhispererStreamingClient::from_conf( - amzn_codewhisperer_streaming_client::config::Builder::from(&bearer_sdk_config) - .http_client(crate::aws_common::http_client::client()) - .interceptor(OptOutInterceptor::new(database)) - .interceptor(UserAgentOverrideInterceptor::new()) - .bearer_token_resolver(BearerResolver) - .app_name(app_name()) - .endpoint_url(endpoint.url()) - .stalled_stream_protection(stalled_stream_protection_config()) - .build(), - )); - }, - } - - let profile = match database.get_auth_profile() { - Ok(profile) => profile, - Err(err) => { - error!("Failed to get auth profile: {err}"); - None - }, - }; - - Ok(Self { - client, - streaming_client, - sigv4_streaming_client, - mock_client: None, - profile, - }) - } - - pub async fn send_telemetry_event( - &self, - telemetry_event: TelemetryEvent, - user_context: UserContext, - telemetry_enabled: bool, - model: Option, - ) -> Result<(), ApiClientError> { - if cfg!(test) { - return Ok(()); - } - - self.client - .send_telemetry_event() - .telemetry_event(telemetry_event) - .user_context(user_context) - .opt_out_preference(match telemetry_enabled { - true => OptOutPreference::OptIn, - false => OptOutPreference::OptOut, - }) - .set_profile_arn(self.profile.as_ref().map(|p| p.arn.clone())) - .set_model_id(model) - .send() - .await?; - - Ok(()) - } - - pub async fn list_available_profiles(&self) -> Result, ApiClientError> { - if cfg!(test) { - return Ok(vec![ - AuthProfile { - arn: "my:arn:1".to_owned(), - profile_name: "MyProfile".to_owned(), - }, - AuthProfile { - arn: "my:arn:2".to_owned(), - profile_name: "MyOtherProfile".to_owned(), - }, - ]); - } - - let mut profiles = vec![]; - let mut stream = self.client.list_available_profiles().into_paginator().send(); - while let Some(profiles_output) = stream.next().await { - profiles.extend(profiles_output?.profiles().iter().cloned().map(AuthProfile::from)); - } - - Ok(profiles) - } - - pub async fn create_subscription_token(&self) -> Result { - if cfg!(test) { - return Ok(CreateSubscriptionTokenOutput::builder() - .set_encoded_verification_url(Some("test/url".to_string())) - .set_status(Some(SubscriptionStatus::Inactive)) - .set_token(Some("test-token".to_string())) - .build()?); - } - - self.client - .create_subscription_token() - .send() - .await - .map_err(ApiClientError::CreateSubscriptionToken) - } - - pub async fn send_message(&self, conversation: ConversationState) -> Result { - debug!("Sending conversation: {:#?}", conversation); - - let ConversationState { - conversation_id, - user_input_message, - history, - } = conversation; - - let model_id_opt: Option = user_input_message.model_id.clone(); - - if let Some(client) = &self.streaming_client { - let conversation_state = amzn_codewhisperer_streaming_client::types::ConversationState::builder() - .set_conversation_id(conversation_id) - .current_message( - amzn_codewhisperer_streaming_client::types::ChatMessage::UserInputMessage( - user_input_message.into(), - ), - ) - .chat_trigger_type(amzn_codewhisperer_streaming_client::types::ChatTriggerType::Manual) - .set_history( - history - .map(|v| v.into_iter().map(|i| i.try_into()).collect::, _>>()) - .transpose()?, - ) - .build() - .expect("building conversation should not fail"); - - match client - .generate_assistant_response() - .conversation_state(conversation_state) - .set_profile_arn(self.profile.as_ref().map(|p| p.arn.clone())) - .send() - .await - { - Ok(response) => Ok(SendMessageOutput::Codewhisperer(response)), - Err(err) => { - let status_code = err.raw_response().map(|res| res.status().as_u16()); - let is_quota_breach = status_code.is_some_and(|status| status == 429); - let is_context_window_overflow = err.as_service_error().is_some_and(|err| { - matches!(err, err if err.meta().code() == Some("ValidationException") && err.meta().message() == Some("Input is too long.")) - }); - - let is_model_unavailable = model_id_opt.is_some() - && status_code.is_some_and(|status| status == 500) - && err.as_service_error().is_some_and(|err| { - err.meta().message() - == Some( - "Encountered unexpectedly high load when processing the request, please try again.", - ) - }); - - let is_monthly_limit_err = err - .raw_response() - .and_then(|resp| resp.body().bytes()) - .and_then(|bytes| match String::from_utf8(bytes.to_vec()) { - Ok(s) => Some(s.contains("MONTHLY_REQUEST_COUNT")), - Err(_) => None, - }) - .unwrap_or(false); - - if is_quota_breach { - return Err(ApiClientError::QuotaBreach { - message: "quota has reached its limit", - status_code, - }); - } - - if is_context_window_overflow { - return Err(ApiClientError::ContextWindowOverflow { status_code }); - } - - if is_model_unavailable { - return Err(ApiClientError::ModelOverloadedError { - request_id: err - .as_service_error() - .and_then(|err| err.meta().request_id()) - .map(|s| s.to_string()), - status_code, - }); - } - - if is_monthly_limit_err { - return Err(ApiClientError::MonthlyLimitReached { status_code }); - } - - Err(err.into()) - }, - } - } else if let Some(client) = &self.sigv4_streaming_client { - let conversation_state = amzn_qdeveloper_streaming_client::types::ConversationState::builder() - .set_conversation_id(conversation_id) - .current_message(amzn_qdeveloper_streaming_client::types::ChatMessage::UserInputMessage( - user_input_message.into(), - )) - .chat_trigger_type(amzn_qdeveloper_streaming_client::types::ChatTriggerType::Manual) - .set_history( - history - .map(|v| v.into_iter().map(|i| i.try_into()).collect::, _>>()) - .transpose()?, - ) - .build() - .expect("building conversation_state should not fail"); - - match client - .send_message() - .conversation_state(conversation_state) - .set_source(Some(Origin::from("CLI"))) - .send() - .await - { - Ok(response) => Ok(SendMessageOutput::QDeveloper(response)), - Err(err) => { - let status_code = err.raw_response().map(|res| res.status().as_u16()); - let is_quota_breach = status_code.is_some_and(|status| status == 429); - let is_context_window_overflow = err.as_service_error().is_some_and(|err| { - matches!(err, err if err.meta().code() == Some("ValidationException") && err.meta().message() == Some("Input is too long.")) - }); - - let is_model_unavailable = model_id_opt.is_some() - && status_code.is_some_and(|status| status == 500) - && err.as_service_error().is_some_and(|err| { - err.meta().message() - == Some( - "Encountered unexpectedly high load when processing the request, please try again.", - ) - }); - - let is_monthly_limit_err = err - .raw_response() - .and_then(|resp| resp.body().bytes()) - .and_then(|bytes| match String::from_utf8(bytes.to_vec()) { - Ok(s) => Some(s.contains("MONTHLY_REQUEST_COUNT")), - Err(_) => None, - }) - .unwrap_or(false); - - if is_quota_breach { - return Err(ApiClientError::QuotaBreach { - message: "quota has reached its limit", - status_code, - }); - } - - if is_context_window_overflow { - return Err(ApiClientError::ContextWindowOverflow { status_code }); - } - - if is_model_unavailable { - return Err(ApiClientError::ModelOverloadedError { - request_id: err - .as_service_error() - .and_then(|err| err.meta().request_id()) - .map(|s| s.to_string()), - status_code, - }); - } - - if is_monthly_limit_err { - return Err(ApiClientError::MonthlyLimitReached { status_code }); - } - - Err(err.into()) - }, - } - } else if let Some(client) = &self.mock_client { - let mut new_events = client.lock().next().unwrap_or_default().clone(); - new_events.reverse(); - - return Ok(SendMessageOutput::Mock(new_events)); - } else { - unreachable!("One of the clients must be created by this point"); - } - } - - /// Only meant for testing. Do not use outside of testing responses. - pub fn set_mock_output(&mut self, json: serde_json::Value) { - let mut mock = Vec::new(); - for response in json.as_array().unwrap() { - let mut stream = Vec::new(); - for event in response.as_array().unwrap() { - match event { - serde_json::Value::String(assistant_text) => { - stream.push(ChatResponseStream::AssistantResponseEvent { - content: assistant_text.clone(), - }); - }, - serde_json::Value::Object(tool_use) => { - stream.append(&mut split_tool_use_event(tool_use)); - }, - other => panic!("Unexpected value: {:?}", other), - } - } - mock.push(stream); - } - - self.mock_client = Some(Arc::new(Mutex::new(mock.into_iter()))); - } -} - -fn timeout_config(database: &Database) -> TimeoutConfig { - let timeout = database - .settings - .get_int(Setting::ApiTimeout) - .and_then(|i| i.try_into().ok()) - .map_or(DEFAULT_TIMEOUT_DURATION, Duration::from_millis); - - TimeoutConfig::builder() - .read_timeout(timeout) - .operation_timeout(timeout) - .operation_attempt_timeout(timeout) - .connect_timeout(timeout) - .build() -} - -fn retry_config() -> RetryConfig { - RetryConfig::standard().with_max_attempts(1) -} - -pub fn stalled_stream_protection_config() -> StalledStreamProtectionConfig { - StalledStreamProtectionConfig::enabled() - .grace_period(Duration::from_secs(60 * 5)) - .build() -} - -fn split_tool_use_event(value: &Map) -> Vec { - let tool_use_id = value.get("tool_use_id").unwrap().as_str().unwrap().to_string(); - let name = value.get("name").unwrap().as_str().unwrap().to_string(); - let args_str = value.get("args").unwrap().to_string(); - let split_point = args_str.len() / 2; - vec![ - ChatResponseStream::ToolUseEvent { - tool_use_id: tool_use_id.clone(), - name: name.clone(), - input: None, - stop: None, - }, - ChatResponseStream::ToolUseEvent { - tool_use_id: tool_use_id.clone(), - name: name.clone(), - input: Some(args_str.split_at(split_point).0.to_string()), - stop: None, - }, - ChatResponseStream::ToolUseEvent { - tool_use_id: tool_use_id.clone(), - name: name.clone(), - input: Some(args_str.split_at(split_point).1.to_string()), - stop: None, - }, - ChatResponseStream::ToolUseEvent { - tool_use_id: tool_use_id.clone(), - name: name.clone(), - input: None, - stop: Some(true), - }, - ] -} - -#[cfg(test)] -mod tests { - use amzn_codewhisperer_client::types::{ - ChatAddMessageEvent, - IdeCategory, - OperatingSystem, - }; - - use super::*; - use crate::api_client::model::UserInputMessage; - - #[tokio::test] - async fn create_clients() { - let env = Env::new(); - let fs = Fs::new(); - let mut database = crate::database::Database::new().await.unwrap(); - let _ = ApiClient::new(&env, &fs, &mut database, None).await; - } - - #[tokio::test] - async fn test_mock() { - let env = Env::new(); - let fs = Fs::new(); - let mut database = crate::database::Database::new().await.unwrap(); - let mut client = ApiClient::new(&env, &fs, &mut database, None).await.unwrap(); - client - .send_telemetry_event( - TelemetryEvent::ChatAddMessageEvent( - ChatAddMessageEvent::builder() - .conversation_id("") - .message_id("") - .build() - .unwrap(), - ), - UserContext::builder() - .ide_category(IdeCategory::Cli) - .operating_system(OperatingSystem::Linux) - .product("") - .build() - .unwrap(), - false, - Some("model".to_owned()), - ) - .await - .unwrap(); - - client.mock_client = Some(Arc::new(Mutex::new( - vec![vec![ - ChatResponseStream::AssistantResponseEvent { - content: "Hello!".to_owned(), - }, - ChatResponseStream::AssistantResponseEvent { - content: " How can I".to_owned(), - }, - ChatResponseStream::AssistantResponseEvent { - content: " assist you today?".to_owned(), - }, - ]] - .into_iter(), - ))); - - let mut output = client - .send_message(ConversationState { - conversation_id: None, - user_input_message: UserInputMessage { - images: None, - content: "Hello".into(), - user_input_message_context: None, - user_intent: None, - model_id: Some("model".to_owned()), - }, - history: None, - }) - .await - .unwrap(); - - let mut output_content = String::new(); - while let Some(ChatResponseStream::AssistantResponseEvent { content }) = output.recv().await.unwrap() { - output_content.push_str(&content); - } - assert_eq!(output_content, "Hello! How can I assist you today?"); - } -} diff --git a/crates/chat-cli/src/api_client/model.rs b/crates/chat-cli/src/api_client/model.rs deleted file mode 100644 index 808081ec6..000000000 --- a/crates/chat-cli/src/api_client/model.rs +++ /dev/null @@ -1,1227 +0,0 @@ -use std::collections::HashMap; - -use aws_smithy_types::{ - Blob, - Document as AwsDocument, -}; -use serde::de::{ - self, - MapAccess, - SeqAccess, - Visitor, -}; -use serde::{ - Deserialize, - Deserializer, - Serialize, - Serializer, -}; - -#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)] -#[serde(rename_all = "camelCase")] -pub struct FileContext { - pub left_file_content: String, - pub right_file_content: String, - pub filename: String, - pub programming_language: ProgrammingLanguage, -} - -#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)] -#[serde(rename_all = "camelCase")] -pub struct ProgrammingLanguage { - pub language_name: LanguageName, -} - -#[derive(Debug, Clone, serde::Serialize, serde::Deserialize, strum::AsRefStr)] -#[serde(rename_all = "lowercase")] -#[strum(serialize_all = "lowercase")] -pub enum LanguageName { - Python, - Javascript, - Java, - Csharp, - Typescript, - C, - Cpp, - Go, - Kotlin, - Php, - Ruby, - Rust, - Scala, - Shell, - Sql, -} - -#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)] -#[serde(rename_all = "camelCase")] -pub struct ReferenceTrackerConfiguration { - pub recommendations_with_references: RecommendationsWithReferences, -} - -#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)] -#[serde(rename_all = "UPPERCASE")] -pub enum RecommendationsWithReferences { - Block, -} - -#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)] -#[serde(rename_all = "camelCase")] -pub struct RecommendationsInput { - pub file_context: FileContext, - pub max_results: i32, - pub next_token: Option, -} - -#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)] -#[serde(rename_all = "camelCase")] -pub struct RecommendationsOutput { - pub recommendations: Vec, - pub next_token: Option, - pub session_id: Option, - pub request_id: Option, -} - -#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)] -#[serde(rename_all = "camelCase")] -pub struct Recommendation { - pub content: String, -} - -// ========= -// Streaming -// ========= - -#[derive(Debug, Clone)] -pub struct ConversationState { - pub conversation_id: Option, - pub user_input_message: UserInputMessage, - pub history: Option>, -} - -#[derive(Debug, Clone)] -pub enum ChatMessage { - AssistantResponseMessage(AssistantResponseMessage), - UserInputMessage(UserInputMessage), -} - -impl TryFrom for amzn_codewhisperer_streaming_client::types::ChatMessage { - type Error = aws_smithy_types::error::operation::BuildError; - - fn try_from(value: ChatMessage) -> Result { - Ok(match value { - ChatMessage::AssistantResponseMessage(message) => { - amzn_codewhisperer_streaming_client::types::ChatMessage::AssistantResponseMessage(message.try_into()?) - }, - ChatMessage::UserInputMessage(message) => { - amzn_codewhisperer_streaming_client::types::ChatMessage::UserInputMessage(message.into()) - }, - }) - } -} - -impl TryFrom for amzn_qdeveloper_streaming_client::types::ChatMessage { - type Error = aws_smithy_types::error::operation::BuildError; - - fn try_from(value: ChatMessage) -> Result { - Ok(match value { - ChatMessage::AssistantResponseMessage(message) => { - amzn_qdeveloper_streaming_client::types::ChatMessage::AssistantResponseMessage(message.try_into()?) - }, - ChatMessage::UserInputMessage(message) => { - amzn_qdeveloper_streaming_client::types::ChatMessage::UserInputMessage(message.into()) - }, - }) - } -} - -/// Wrapper around [aws_smithy_types::Document]. -/// -/// Used primarily so we can implement [Serialize] and [Deserialize] for -/// [aws_smith_types::Document]. -#[derive(Debug, Clone)] -pub struct FigDocument(AwsDocument); - -impl From for FigDocument { - fn from(value: AwsDocument) -> Self { - Self(value) - } -} - -impl From for AwsDocument { - fn from(value: FigDocument) -> Self { - value.0 - } -} - -/// Internal type used only during serialization for `FigDocument` to avoid unnecessary cloning. -#[derive(Debug, Clone)] -struct FigDocumentRef<'a>(&'a AwsDocument); - -impl Serialize for FigDocumentRef<'_> { - fn serialize(&self, serializer: S) -> Result - where - S: Serializer, - { - use aws_smithy_types::Number; - match self.0 { - AwsDocument::Null => serializer.serialize_unit(), - AwsDocument::Bool(b) => serializer.serialize_bool(*b), - AwsDocument::Number(n) => match n { - Number::PosInt(u) => serializer.serialize_u64(*u), - Number::NegInt(i) => serializer.serialize_i64(*i), - Number::Float(f) => serializer.serialize_f64(*f), - }, - AwsDocument::String(s) => serializer.serialize_str(s), - AwsDocument::Array(arr) => { - use serde::ser::SerializeSeq; - let mut seq = serializer.serialize_seq(Some(arr.len()))?; - for value in arr { - seq.serialize_element(&Self(value))?; - } - seq.end() - }, - AwsDocument::Object(m) => { - use serde::ser::SerializeMap; - let mut map = serializer.serialize_map(Some(m.len()))?; - for (k, v) in m { - map.serialize_entry(k, &Self(v))?; - } - map.end() - }, - } - } -} - -impl Serialize for FigDocument { - fn serialize(&self, serializer: S) -> Result - where - S: Serializer, - { - FigDocumentRef(&self.0).serialize(serializer) - } -} - -impl<'de> Deserialize<'de> for FigDocument { - fn deserialize(deserializer: D) -> Result - where - D: Deserializer<'de>, - { - use aws_smithy_types::Number; - - struct FigDocumentVisitor; - - impl<'de> Visitor<'de> for FigDocumentVisitor { - type Value = FigDocument; - - fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("any valid JSON value") - } - - fn visit_bool(self, value: bool) -> Result - where - E: de::Error, - { - Ok(FigDocument(AwsDocument::Bool(value))) - } - - fn visit_i64(self, value: i64) -> Result - where - E: de::Error, - { - Ok(FigDocument(AwsDocument::Number(if value < 0 { - Number::NegInt(value) - } else { - Number::PosInt(value as u64) - }))) - } - - fn visit_u64(self, value: u64) -> Result - where - E: de::Error, - { - Ok(FigDocument(AwsDocument::Number(Number::PosInt(value)))) - } - - fn visit_f64(self, value: f64) -> Result - where - E: de::Error, - { - Ok(FigDocument(AwsDocument::Number(Number::Float(value)))) - } - - fn visit_str(self, value: &str) -> Result - where - E: de::Error, - { - Ok(FigDocument(AwsDocument::String(value.to_owned()))) - } - - fn visit_string(self, value: String) -> Result - where - E: de::Error, - { - Ok(FigDocument(AwsDocument::String(value))) - } - - fn visit_none(self) -> Result - where - E: de::Error, - { - Ok(FigDocument(AwsDocument::Null)) - } - - fn visit_some(self, deserializer: D) -> Result - where - D: Deserializer<'de>, - { - Deserialize::deserialize(deserializer) - } - - fn visit_unit(self) -> Result - where - E: de::Error, - { - Ok(FigDocument(AwsDocument::Null)) - } - - fn visit_seq(self, mut seq: A) -> Result - where - A: SeqAccess<'de>, - { - let mut vec = Vec::new(); - - while let Some(elem) = seq.next_element::()? { - vec.push(elem.0); - } - - Ok(FigDocument(AwsDocument::Array(vec))) - } - - fn visit_map(self, mut access: M) -> Result - where - M: MapAccess<'de>, - { - let mut map = HashMap::new(); - - while let Some((key, value)) = access.next_entry::()? { - map.insert(key, value.0); - } - - Ok(FigDocument(AwsDocument::Object(map))) - } - } - - deserializer.deserialize_any(FigDocumentVisitor) - } -} - -/// Information about a tool that can be used. -#[derive(Debug, Clone, Serialize, Deserialize)] -pub enum Tool { - ToolSpecification(ToolSpecification), -} - -impl From for amzn_codewhisperer_streaming_client::types::Tool { - fn from(value: Tool) -> Self { - match value { - Tool::ToolSpecification(v) => amzn_codewhisperer_streaming_client::types::Tool::ToolSpecification(v.into()), - } - } -} - -impl From for amzn_qdeveloper_streaming_client::types::Tool { - fn from(value: Tool) -> Self { - match value { - Tool::ToolSpecification(v) => amzn_qdeveloper_streaming_client::types::Tool::ToolSpecification(v.into()), - } - } -} - -/// The specification for the tool. -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct ToolSpecification { - /// The name for the tool. - pub name: String, - /// The description for the tool. - pub description: String, - /// The input schema for the tool in JSON format. - pub input_schema: ToolInputSchema, -} - -impl From for amzn_codewhisperer_streaming_client::types::ToolSpecification { - fn from(value: ToolSpecification) -> Self { - Self::builder() - .name(value.name) - .description(value.description) - .input_schema(value.input_schema.into()) - .build() - .expect("building ToolSpecification should not fail") - } -} - -impl From for amzn_qdeveloper_streaming_client::types::ToolSpecification { - fn from(value: ToolSpecification) -> Self { - Self::builder() - .name(value.name) - .description(value.description) - .input_schema(value.input_schema.into()) - .build() - .expect("building ToolSpecification should not fail") - } -} - -/// The input schema for the tool in JSON format. -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct ToolInputSchema { - pub json: Option, -} - -impl From for amzn_codewhisperer_streaming_client::types::ToolInputSchema { - fn from(value: ToolInputSchema) -> Self { - Self::builder().set_json(value.json.map(Into::into)).build() - } -} - -impl From for amzn_qdeveloper_streaming_client::types::ToolInputSchema { - fn from(value: ToolInputSchema) -> Self { - Self::builder().set_json(value.json.map(Into::into)).build() - } -} - -/// Contains information about a tool that the model is requesting be run. The model uses the result -/// from the tool to generate a response. -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct ToolUse { - /// The ID for the tool request. - pub tool_use_id: String, - /// The name for the tool. - pub name: String, - /// The input to pass to the tool. - pub input: FigDocument, -} - -impl From for amzn_codewhisperer_streaming_client::types::ToolUse { - fn from(value: ToolUse) -> Self { - Self::builder() - .tool_use_id(value.tool_use_id) - .name(value.name) - .input(value.input.into()) - .build() - .expect("building ToolUse should not fail") - } -} - -impl From for amzn_qdeveloper_streaming_client::types::ToolUse { - fn from(value: ToolUse) -> Self { - Self::builder() - .tool_use_id(value.tool_use_id) - .name(value.name) - .input(value.input.into()) - .build() - .expect("building ToolUse should not fail") - } -} - -/// A tool result that contains the results for a tool request that was previously made. -#[derive(Debug, Clone)] -pub struct ToolResult { - /// The ID for the tool request. - pub tool_use_id: String, - /// Content of the tool result. - pub content: Vec, - /// Status of the tools result. - pub status: ToolResultStatus, -} - -impl From for amzn_codewhisperer_streaming_client::types::ToolResult { - fn from(value: ToolResult) -> Self { - Self::builder() - .tool_use_id(value.tool_use_id) - .set_content(Some(value.content.into_iter().map(Into::into).collect::<_>())) - .status(value.status.into()) - .build() - .expect("building ToolResult should not fail") - } -} - -impl From for amzn_qdeveloper_streaming_client::types::ToolResult { - fn from(value: ToolResult) -> Self { - Self::builder() - .tool_use_id(value.tool_use_id) - .set_content(Some(value.content.into_iter().map(Into::into).collect::<_>())) - .status(value.status.into()) - .build() - .expect("building ToolResult should not fail") - } -} - -#[derive(Debug, Clone)] -pub enum ToolResultContentBlock { - /// A tool result that is JSON format data. - Json(AwsDocument), - /// A tool result that is text. - Text(String), -} - -impl From for amzn_codewhisperer_streaming_client::types::ToolResultContentBlock { - fn from(value: ToolResultContentBlock) -> Self { - match value { - ToolResultContentBlock::Json(document) => Self::Json(document), - ToolResultContentBlock::Text(text) => Self::Text(text), - } - } -} - -impl From for amzn_qdeveloper_streaming_client::types::ToolResultContentBlock { - fn from(value: ToolResultContentBlock) -> Self { - match value { - ToolResultContentBlock::Json(document) => Self::Json(document), - ToolResultContentBlock::Text(text) => Self::Text(text), - } - } -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -pub enum ToolResultStatus { - Error, - Success, -} - -impl From for amzn_codewhisperer_streaming_client::types::ToolResultStatus { - fn from(value: ToolResultStatus) -> Self { - match value { - ToolResultStatus::Error => Self::Error, - ToolResultStatus::Success => Self::Success, - } - } -} - -impl From for amzn_qdeveloper_streaming_client::types::ToolResultStatus { - fn from(value: ToolResultStatus) -> Self { - match value { - ToolResultStatus::Error => Self::Error, - ToolResultStatus::Success => Self::Success, - } - } -} - -/// Markdown text message. -#[derive(Debug, Clone)] -pub struct AssistantResponseMessage { - /// Unique identifier for the chat message - pub message_id: Option, - /// The content of the text message in markdown format. - pub content: String, - /// ToolUse Request - pub tool_uses: Option>, -} - -impl TryFrom for amzn_codewhisperer_streaming_client::types::AssistantResponseMessage { - type Error = aws_smithy_types::error::operation::BuildError; - - fn try_from(value: AssistantResponseMessage) -> Result { - Self::builder() - .content(value.content) - .set_message_id(value.message_id) - .set_tool_uses(value.tool_uses.map(|uses| uses.into_iter().map(Into::into).collect())) - .build() - } -} - -impl TryFrom for amzn_qdeveloper_streaming_client::types::AssistantResponseMessage { - type Error = aws_smithy_types::error::operation::BuildError; - - fn try_from(value: AssistantResponseMessage) -> Result { - Self::builder() - .content(value.content) - .set_message_id(value.message_id) - .set_tool_uses(value.tool_uses.map(|uses| uses.into_iter().map(Into::into).collect())) - .build() - } -} - -#[non_exhaustive] -#[derive(Debug, Clone, PartialEq, Eq)] -pub enum ChatResponseStream { - AssistantResponseEvent { - content: String, - }, - /// Streaming response event for generated code text. - CodeEvent { - content: String, - }, - // TODO: finish events here - CodeReferenceEvent(()), - FollowupPromptEvent(()), - IntentsEvent(()), - InvalidStateEvent { - reason: String, - message: String, - }, - MessageMetadataEvent { - conversation_id: Option, - utterance_id: Option, - }, - SupplementaryWebLinksEvent(()), - ToolUseEvent { - tool_use_id: String, - name: String, - input: Option, - stop: Option, - }, - - #[non_exhaustive] - Unknown, -} - -impl From for ChatResponseStream { - fn from(value: amzn_codewhisperer_streaming_client::types::ChatResponseStream) -> Self { - match value { - amzn_codewhisperer_streaming_client::types::ChatResponseStream::AssistantResponseEvent( - amzn_codewhisperer_streaming_client::types::AssistantResponseEvent { content, .. }, - ) => ChatResponseStream::AssistantResponseEvent { content }, - amzn_codewhisperer_streaming_client::types::ChatResponseStream::CodeEvent( - amzn_codewhisperer_streaming_client::types::CodeEvent { content, .. }, - ) => ChatResponseStream::CodeEvent { content }, - amzn_codewhisperer_streaming_client::types::ChatResponseStream::CodeReferenceEvent(_) => { - ChatResponseStream::CodeReferenceEvent(()) - }, - amzn_codewhisperer_streaming_client::types::ChatResponseStream::FollowupPromptEvent(_) => { - ChatResponseStream::FollowupPromptEvent(()) - }, - amzn_codewhisperer_streaming_client::types::ChatResponseStream::IntentsEvent(_) => { - ChatResponseStream::IntentsEvent(()) - }, - amzn_codewhisperer_streaming_client::types::ChatResponseStream::InvalidStateEvent( - amzn_codewhisperer_streaming_client::types::InvalidStateEvent { reason, message, .. }, - ) => ChatResponseStream::InvalidStateEvent { - reason: reason.to_string(), - message, - }, - amzn_codewhisperer_streaming_client::types::ChatResponseStream::MessageMetadataEvent( - amzn_codewhisperer_streaming_client::types::MessageMetadataEvent { - conversation_id, - utterance_id, - .. - }, - ) => ChatResponseStream::MessageMetadataEvent { - conversation_id, - utterance_id, - }, - amzn_codewhisperer_streaming_client::types::ChatResponseStream::ToolUseEvent( - amzn_codewhisperer_streaming_client::types::ToolUseEvent { - tool_use_id, - name, - input, - stop, - .. - }, - ) => ChatResponseStream::ToolUseEvent { - tool_use_id, - name, - input, - stop, - }, - amzn_codewhisperer_streaming_client::types::ChatResponseStream::SupplementaryWebLinksEvent(_) => { - ChatResponseStream::SupplementaryWebLinksEvent(()) - }, - _ => ChatResponseStream::Unknown, - } - } -} - -impl From for ChatResponseStream { - fn from(value: amzn_qdeveloper_streaming_client::types::ChatResponseStream) -> Self { - match value { - amzn_qdeveloper_streaming_client::types::ChatResponseStream::AssistantResponseEvent( - amzn_qdeveloper_streaming_client::types::AssistantResponseEvent { content, .. }, - ) => ChatResponseStream::AssistantResponseEvent { content }, - amzn_qdeveloper_streaming_client::types::ChatResponseStream::CodeEvent( - amzn_qdeveloper_streaming_client::types::CodeEvent { content, .. }, - ) => ChatResponseStream::CodeEvent { content }, - amzn_qdeveloper_streaming_client::types::ChatResponseStream::CodeReferenceEvent(_) => { - ChatResponseStream::CodeReferenceEvent(()) - }, - amzn_qdeveloper_streaming_client::types::ChatResponseStream::FollowupPromptEvent(_) => { - ChatResponseStream::FollowupPromptEvent(()) - }, - amzn_qdeveloper_streaming_client::types::ChatResponseStream::IntentsEvent(_) => { - ChatResponseStream::IntentsEvent(()) - }, - amzn_qdeveloper_streaming_client::types::ChatResponseStream::InvalidStateEvent( - amzn_qdeveloper_streaming_client::types::InvalidStateEvent { reason, message, .. }, - ) => ChatResponseStream::InvalidStateEvent { - reason: reason.to_string(), - message, - }, - amzn_qdeveloper_streaming_client::types::ChatResponseStream::MessageMetadataEvent( - amzn_qdeveloper_streaming_client::types::MessageMetadataEvent { - conversation_id, - utterance_id, - .. - }, - ) => ChatResponseStream::MessageMetadataEvent { - conversation_id, - utterance_id, - }, - amzn_qdeveloper_streaming_client::types::ChatResponseStream::ToolUseEvent( - amzn_qdeveloper_streaming_client::types::ToolUseEvent { - tool_use_id, - name, - input, - stop, - .. - }, - ) => ChatResponseStream::ToolUseEvent { - tool_use_id, - name, - input, - stop, - }, - amzn_qdeveloper_streaming_client::types::ChatResponseStream::SupplementaryWebLinksEvent(_) => { - ChatResponseStream::SupplementaryWebLinksEvent(()) - }, - _ => ChatResponseStream::Unknown, - } - } -} - -#[derive(Debug, Clone, Default, Serialize, Deserialize)] -pub struct EnvState { - pub operating_system: Option, - pub current_working_directory: Option, - pub environment_variables: Vec, -} - -impl From for amzn_codewhisperer_streaming_client::types::EnvState { - fn from(value: EnvState) -> Self { - let environment_variables: Vec<_> = value.environment_variables.into_iter().map(Into::into).collect(); - Self::builder() - .set_operating_system(value.operating_system) - .set_current_working_directory(value.current_working_directory) - .set_environment_variables(if environment_variables.is_empty() { - None - } else { - Some(environment_variables) - }) - .build() - } -} - -impl From for amzn_qdeveloper_streaming_client::types::EnvState { - fn from(value: EnvState) -> Self { - let environment_variables: Vec<_> = value.environment_variables.into_iter().map(Into::into).collect(); - Self::builder() - .set_operating_system(value.operating_system) - .set_current_working_directory(value.current_working_directory) - .set_environment_variables(if environment_variables.is_empty() { - None - } else { - Some(environment_variables) - }) - .build() - } -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct EnvironmentVariable { - pub key: String, - pub value: String, -} - -impl From for amzn_codewhisperer_streaming_client::types::EnvironmentVariable { - fn from(value: EnvironmentVariable) -> Self { - Self::builder().key(value.key).value(value.value).build() - } -} - -impl From for amzn_qdeveloper_streaming_client::types::EnvironmentVariable { - fn from(value: EnvironmentVariable) -> Self { - Self::builder().key(value.key).value(value.value).build() - } -} - -#[derive(Debug, Clone)] -pub struct GitState { - pub status: String, -} - -impl From for amzn_codewhisperer_streaming_client::types::GitState { - fn from(value: GitState) -> Self { - Self::builder().status(value.status).build() - } -} - -impl From for amzn_qdeveloper_streaming_client::types::GitState { - fn from(value: GitState) -> Self { - Self::builder().status(value.status).build() - } -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct ImageBlock { - pub format: ImageFormat, - pub source: ImageSource, -} - -#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] -pub enum ImageFormat { - Gif, - Jpeg, - Png, - Webp, -} - -impl std::str::FromStr for ImageFormat { - type Err = String; - - fn from_str(s: &str) -> Result { - match s.trim().to_lowercase().as_str() { - "gif" => Ok(ImageFormat::Gif), - "jpeg" => Ok(ImageFormat::Jpeg), - "jpg" => Ok(ImageFormat::Jpeg), - "png" => Ok(ImageFormat::Png), - "webp" => Ok(ImageFormat::Webp), - _ => Err(format!("Failed to parse '{}' as ImageFormat", s)), - } - } -} - -impl From for amzn_codewhisperer_streaming_client::types::ImageFormat { - fn from(value: ImageFormat) -> Self { - match value { - ImageFormat::Gif => Self::Gif, - ImageFormat::Jpeg => Self::Jpeg, - ImageFormat::Png => Self::Png, - ImageFormat::Webp => Self::Webp, - } - } -} -impl From for amzn_qdeveloper_streaming_client::types::ImageFormat { - fn from(value: ImageFormat) -> Self { - match value { - ImageFormat::Gif => Self::Gif, - ImageFormat::Jpeg => Self::Jpeg, - ImageFormat::Png => Self::Png, - ImageFormat::Webp => Self::Webp, - } - } -} - -#[non_exhaustive] -#[derive(Debug, Clone, Serialize, Deserialize)] -pub enum ImageSource { - Bytes(Vec), - #[non_exhaustive] - Unknown, -} - -impl From for amzn_codewhisperer_streaming_client::types::ImageSource { - fn from(value: ImageSource) -> Self { - match value { - ImageSource::Bytes(bytes) => Self::Bytes(Blob::new(bytes)), - ImageSource::Unknown => Self::Unknown, - } - } -} -impl From for amzn_qdeveloper_streaming_client::types::ImageSource { - fn from(value: ImageSource) -> Self { - match value { - ImageSource::Bytes(bytes) => Self::Bytes(Blob::new(bytes)), - ImageSource::Unknown => Self::Unknown, - } - } -} - -impl From for amzn_codewhisperer_streaming_client::types::ImageBlock { - fn from(value: ImageBlock) -> Self { - Self::builder() - .format(value.format.into()) - .source(value.source.into()) - .build() - .expect("Failed to build ImageBlock") - } -} -impl From for amzn_qdeveloper_streaming_client::types::ImageBlock { - fn from(value: ImageBlock) -> Self { - Self::builder() - .format(value.format.into()) - .source(value.source.into()) - .build() - .expect("Failed to build ImageBlock") - } -} - -#[derive(Debug, Clone)] -pub struct UserInputMessage { - pub content: String, - pub user_input_message_context: Option, - pub user_intent: Option, - pub images: Option>, - pub model_id: Option, -} - -impl From for amzn_codewhisperer_streaming_client::types::UserInputMessage { - fn from(value: UserInputMessage) -> Self { - Self::builder() - .content(value.content) - .set_images(value.images.map(|images| images.into_iter().map(Into::into).collect())) - .set_user_input_message_context(value.user_input_message_context.map(Into::into)) - .set_user_intent(value.user_intent.map(Into::into)) - .set_model_id(value.model_id) - .origin(amzn_codewhisperer_streaming_client::types::Origin::Cli) - .build() - .expect("Failed to build UserInputMessage") - } -} - -impl From for amzn_qdeveloper_streaming_client::types::UserInputMessage { - fn from(value: UserInputMessage) -> Self { - Self::builder() - .content(value.content) - .set_images(value.images.map(|images| images.into_iter().map(Into::into).collect())) - .set_user_input_message_context(value.user_input_message_context.map(Into::into)) - .set_user_intent(value.user_intent.map(Into::into)) - .set_model_id(value.model_id) - .origin(amzn_qdeveloper_streaming_client::types::Origin::Cli) - .build() - .expect("Failed to build UserInputMessage") - } -} - -#[derive(Debug, Clone, Default)] -pub struct UserInputMessageContext { - pub env_state: Option, - pub git_state: Option, - pub tool_results: Option>, - pub tools: Option>, -} - -impl From for amzn_codewhisperer_streaming_client::types::UserInputMessageContext { - fn from(value: UserInputMessageContext) -> Self { - Self::builder() - .set_env_state(value.env_state.map(Into::into)) - .set_git_state(value.git_state.map(Into::into)) - .set_tool_results(value.tool_results.map(|t| t.into_iter().map(Into::into).collect())) - .set_tools(value.tools.map(|t| t.into_iter().map(Into::into).collect())) - .build() - } -} - -impl From for amzn_qdeveloper_streaming_client::types::UserInputMessageContext { - fn from(value: UserInputMessageContext) -> Self { - Self::builder() - .set_env_state(value.env_state.map(Into::into)) - .set_git_state(value.git_state.map(Into::into)) - .set_tool_results(value.tool_results.map(|t| t.into_iter().map(Into::into).collect())) - .set_tools(value.tools.map(|t| t.into_iter().map(Into::into).collect())) - .build() - } -} - -#[derive(Debug, Clone)] -pub enum UserIntent { - ApplyCommonBestPractices, -} - -impl From for amzn_codewhisperer_streaming_client::types::UserIntent { - fn from(value: UserIntent) -> Self { - match value { - UserIntent::ApplyCommonBestPractices => Self::ApplyCommonBestPractices, - } - } -} - -impl From for amzn_qdeveloper_streaming_client::types::UserIntent { - fn from(value: UserIntent) -> Self { - match value { - UserIntent::ApplyCommonBestPractices => Self::ApplyCommonBestPractices, - } - } -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn build_user_input_message() { - let user_input_message = UserInputMessage { - images: Some(vec![ImageBlock { - format: ImageFormat::Png, - source: ImageSource::Bytes(vec![1, 2, 3]), - }]), - content: "test content".to_string(), - user_input_message_context: Some(UserInputMessageContext { - env_state: Some(EnvState { - operating_system: Some("test os".to_string()), - current_working_directory: Some("test cwd".to_string()), - environment_variables: vec![EnvironmentVariable { - key: "test key".to_string(), - value: "test value".to_string(), - }], - }), - git_state: Some(GitState { - status: "test status".to_string(), - }), - tool_results: Some(vec![ToolResult { - tool_use_id: "test id".to_string(), - content: vec![ToolResultContentBlock::Text("test text".to_string())], - status: ToolResultStatus::Success, - }]), - tools: Some(vec![Tool::ToolSpecification(ToolSpecification { - name: "test tool name".to_string(), - description: "test tool description".to_string(), - input_schema: ToolInputSchema { - json: Some(AwsDocument::Null.into()), - }, - })]), - }), - user_intent: Some(UserIntent::ApplyCommonBestPractices), - model_id: Some("model id".to_string()), - }; - - let codewhisper_input = - amzn_codewhisperer_streaming_client::types::UserInputMessage::from(user_input_message.clone()); - let qdeveloper_input = amzn_qdeveloper_streaming_client::types::UserInputMessage::from(user_input_message); - - assert_eq!(format!("{codewhisper_input:?}"), format!("{qdeveloper_input:?}")); - - let minimal_message = UserInputMessage { - images: None, - content: "test content".to_string(), - user_input_message_context: None, - user_intent: None, - model_id: Some("model id".to_string()), - }; - - let codewhisper_minimal = - amzn_codewhisperer_streaming_client::types::UserInputMessage::from(minimal_message.clone()); - let qdeveloper_minimal = amzn_qdeveloper_streaming_client::types::UserInputMessage::from(minimal_message); - assert_eq!(format!("{codewhisper_minimal:?}"), format!("{qdeveloper_minimal:?}")); - } - - #[test] - fn build_assistant_response_message() { - let message = AssistantResponseMessage { - message_id: Some("testid".to_string()), - content: "test content".to_string(), - tool_uses: Some(vec![ToolUse { - tool_use_id: "tooluseid_test".to_string(), - name: "tool_name_test".to_string(), - input: FigDocument(AwsDocument::Object( - [("key1".to_string(), AwsDocument::Null)].into_iter().collect(), - )), - }]), - }; - let codewhisper_input = - amzn_codewhisperer_streaming_client::types::AssistantResponseMessage::try_from(message.clone()).unwrap(); - let qdeveloper_input = - amzn_qdeveloper_streaming_client::types::AssistantResponseMessage::try_from(message).unwrap(); - assert_eq!(format!("{codewhisper_input:?}"), format!("{qdeveloper_input:?}")); - } - - #[test] - fn build_chat_response() { - let assistant_response_event = - amzn_codewhisperer_streaming_client::types::ChatResponseStream::AssistantResponseEvent( - amzn_codewhisperer_streaming_client::types::AssistantResponseEvent::builder() - .content("context") - .build() - .unwrap(), - ); - assert_eq!( - ChatResponseStream::from(assistant_response_event), - ChatResponseStream::AssistantResponseEvent { - content: "context".into(), - } - ); - - let assistant_response_event = - amzn_qdeveloper_streaming_client::types::ChatResponseStream::AssistantResponseEvent( - amzn_qdeveloper_streaming_client::types::AssistantResponseEvent::builder() - .content("context") - .build() - .unwrap(), - ); - assert_eq!( - ChatResponseStream::from(assistant_response_event), - ChatResponseStream::AssistantResponseEvent { - content: "context".into(), - } - ); - - let code_event = amzn_codewhisperer_streaming_client::types::ChatResponseStream::CodeEvent( - amzn_codewhisperer_streaming_client::types::CodeEvent::builder() - .content("context") - .build() - .unwrap(), - ); - assert_eq!(ChatResponseStream::from(code_event), ChatResponseStream::CodeEvent { - content: "context".into() - }); - - let code_event = amzn_qdeveloper_streaming_client::types::ChatResponseStream::CodeEvent( - amzn_qdeveloper_streaming_client::types::CodeEvent::builder() - .content("context") - .build() - .unwrap(), - ); - assert_eq!(ChatResponseStream::from(code_event), ChatResponseStream::CodeEvent { - content: "context".into() - }); - - let code_reference_event = amzn_codewhisperer_streaming_client::types::ChatResponseStream::CodeReferenceEvent( - amzn_codewhisperer_streaming_client::types::CodeReferenceEvent::builder().build(), - ); - assert_eq!( - ChatResponseStream::from(code_reference_event), - ChatResponseStream::CodeReferenceEvent(()) - ); - - let code_reference_event = amzn_qdeveloper_streaming_client::types::ChatResponseStream::CodeReferenceEvent( - amzn_qdeveloper_streaming_client::types::CodeReferenceEvent::builder().build(), - ); - assert_eq!( - ChatResponseStream::from(code_reference_event), - ChatResponseStream::CodeReferenceEvent(()) - ); - - let followup_prompt_event = amzn_codewhisperer_streaming_client::types::ChatResponseStream::FollowupPromptEvent( - amzn_codewhisperer_streaming_client::types::FollowupPromptEvent::builder().build(), - ); - assert_eq!( - ChatResponseStream::from(followup_prompt_event), - ChatResponseStream::FollowupPromptEvent(()) - ); - - let followup_prompt_event = amzn_qdeveloper_streaming_client::types::ChatResponseStream::FollowupPromptEvent( - amzn_qdeveloper_streaming_client::types::FollowupPromptEvent::builder().build(), - ); - assert_eq!( - ChatResponseStream::from(followup_prompt_event), - ChatResponseStream::FollowupPromptEvent(()) - ); - - let intents_event = amzn_codewhisperer_streaming_client::types::ChatResponseStream::IntentsEvent( - amzn_codewhisperer_streaming_client::types::IntentsEvent::builder().build(), - ); - assert_eq!( - ChatResponseStream::from(intents_event), - ChatResponseStream::IntentsEvent(()) - ); - - let intents_event = amzn_qdeveloper_streaming_client::types::ChatResponseStream::IntentsEvent( - amzn_qdeveloper_streaming_client::types::IntentsEvent::builder().build(), - ); - assert_eq!( - ChatResponseStream::from(intents_event), - ChatResponseStream::IntentsEvent(()) - ); - - let user_input_event = amzn_codewhisperer_streaming_client::types::ChatResponseStream::InvalidStateEvent( - amzn_codewhisperer_streaming_client::types::InvalidStateEvent::builder() - .reason(amzn_codewhisperer_streaming_client::types::InvalidStateReason::InvalidTaskAssistPlan) - .message("message") - .build() - .unwrap(), - ); - assert_eq!( - ChatResponseStream::from(user_input_event), - ChatResponseStream::InvalidStateEvent { - reason: amzn_codewhisperer_streaming_client::types::InvalidStateReason::InvalidTaskAssistPlan - .to_string(), - message: "message".into() - } - ); - - let user_input_event = amzn_qdeveloper_streaming_client::types::ChatResponseStream::InvalidStateEvent( - amzn_qdeveloper_streaming_client::types::InvalidStateEvent::builder() - .reason(amzn_qdeveloper_streaming_client::types::InvalidStateReason::InvalidTaskAssistPlan) - .message("message") - .build() - .unwrap(), - ); - assert_eq!( - ChatResponseStream::from(user_input_event), - ChatResponseStream::InvalidStateEvent { - reason: amzn_qdeveloper_streaming_client::types::InvalidStateReason::InvalidTaskAssistPlan.to_string(), - message: "message".into() - } - ); - - let user_input_event = amzn_codewhisperer_streaming_client::types::ChatResponseStream::MessageMetadataEvent( - amzn_codewhisperer_streaming_client::types::MessageMetadataEvent::builder().build(), - ); - assert_eq!( - ChatResponseStream::from(user_input_event), - ChatResponseStream::MessageMetadataEvent { - conversation_id: None, - utterance_id: None - } - ); - - let user_input_event = amzn_qdeveloper_streaming_client::types::ChatResponseStream::MessageMetadataEvent( - amzn_qdeveloper_streaming_client::types::MessageMetadataEvent::builder().build(), - ); - assert_eq!( - ChatResponseStream::from(user_input_event), - ChatResponseStream::MessageMetadataEvent { - conversation_id: None, - utterance_id: None - } - ); - - let user_input_event = - amzn_codewhisperer_streaming_client::types::ChatResponseStream::SupplementaryWebLinksEvent( - amzn_codewhisperer_streaming_client::types::SupplementaryWebLinksEvent::builder().build(), - ); - assert_eq!( - ChatResponseStream::from(user_input_event), - ChatResponseStream::SupplementaryWebLinksEvent(()) - ); - - let user_input_event = amzn_qdeveloper_streaming_client::types::ChatResponseStream::SupplementaryWebLinksEvent( - amzn_qdeveloper_streaming_client::types::SupplementaryWebLinksEvent::builder().build(), - ); - assert_eq!( - ChatResponseStream::from(user_input_event), - ChatResponseStream::SupplementaryWebLinksEvent(()) - ); - - let user_input_event = amzn_codewhisperer_streaming_client::types::ChatResponseStream::ToolUseEvent( - amzn_codewhisperer_streaming_client::types::ToolUseEvent::builder() - .tool_use_id("tool_use_id".to_string()) - .name("tool_name".to_string()) - .build() - .unwrap(), - ); - assert_eq!( - ChatResponseStream::from(user_input_event), - ChatResponseStream::ToolUseEvent { - tool_use_id: "tool_use_id".to_string(), - name: "tool_name".to_string(), - input: None, - stop: None, - } - ); - - let user_input_event = amzn_qdeveloper_streaming_client::types::ChatResponseStream::ToolUseEvent( - amzn_qdeveloper_streaming_client::types::ToolUseEvent::builder() - .tool_use_id("tool_use_id".to_string()) - .name("tool_name".to_string()) - .build() - .unwrap(), - ); - assert_eq!( - ChatResponseStream::from(user_input_event), - ChatResponseStream::ToolUseEvent { - tool_use_id: "tool_use_id".to_string(), - name: "tool_name".to_string(), - input: None, - stop: None, - } - ); - } -} diff --git a/crates/chat-cli/src/api_client/opt_out.rs b/crates/chat-cli/src/api_client/opt_out.rs deleted file mode 100644 index 9ffb6fb2f..000000000 --- a/crates/chat-cli/src/api_client/opt_out.rs +++ /dev/null @@ -1,94 +0,0 @@ -use aws_smithy_runtime_api::box_error::BoxError; -use aws_smithy_runtime_api::client::interceptors::Intercept; -use aws_smithy_runtime_api::client::interceptors::context::BeforeTransmitInterceptorContextMut; -use aws_smithy_runtime_api::client::runtime_components::RuntimeComponents; -use aws_smithy_types::config_bag::ConfigBag; - -use crate::api_client::X_AMZN_CODEWHISPERER_OPT_OUT_HEADER; -use crate::database::Database; -use crate::database::settings::Setting; - -fn is_codewhisperer_content_optout(database: &Database) -> bool { - !database - .settings - .get_bool(Setting::ShareCodeWhispererContent) - .unwrap_or(true) -} - -#[derive(Debug, Clone)] -pub struct OptOutInterceptor { - is_codewhisperer_content_optout: bool, - override_value: Option, - _inner: (), -} - -impl OptOutInterceptor { - pub fn new(database: &Database) -> Self { - Self { - is_codewhisperer_content_optout: is_codewhisperer_content_optout(database), - override_value: None, - _inner: (), - } - } -} - -impl Intercept for OptOutInterceptor { - fn name(&self) -> &'static str { - "OptOutInterceptor" - } - - fn modify_before_signing( - &self, - context: &mut BeforeTransmitInterceptorContextMut<'_>, - _runtime_components: &RuntimeComponents, - _cfg: &mut ConfigBag, - ) -> Result<(), BoxError> { - let opt_out = self.override_value.unwrap_or(self.is_codewhisperer_content_optout); - context - .request_mut() - .headers_mut() - .insert(X_AMZN_CODEWHISPERER_OPT_OUT_HEADER, opt_out.to_string()); - Ok(()) - } -} - -#[cfg(test)] -mod tests { - use amzn_consolas_client::config::RuntimeComponentsBuilder; - use amzn_consolas_client::config::interceptors::InterceptorContext; - use aws_smithy_runtime_api::client::interceptors::context::Input; - - use super::*; - - #[tokio::test] - async fn test_opt_out_interceptor() { - let rc = RuntimeComponentsBuilder::for_tests().build().unwrap(); - let mut cfg = ConfigBag::base(); - - let mut context = InterceptorContext::new(Input::erase(())); - context.set_request(aws_smithy_runtime_api::http::Request::empty()); - let mut context = BeforeTransmitInterceptorContextMut::from(&mut context); - - let database = Database::new().await.unwrap(); - let mut interceptor = OptOutInterceptor::new(&database); - println!("Interceptor: {}", interceptor.name()); - - interceptor - .modify_before_signing(&mut context, &rc, &mut cfg) - .expect("success"); - - interceptor.override_value = Some(false); - interceptor - .modify_before_signing(&mut context, &rc, &mut cfg) - .expect("success"); - let val = context.request().headers().get(X_AMZN_CODEWHISPERER_OPT_OUT_HEADER); - assert_eq!(val, Some("false")); - - interceptor.override_value = Some(true); - interceptor - .modify_before_signing(&mut context, &rc, &mut cfg) - .expect("success"); - let val = context.request().headers().get(X_AMZN_CODEWHISPERER_OPT_OUT_HEADER); - assert_eq!(val, Some("true")); - } -} diff --git a/crates/chat-cli/src/api_client/profile.rs b/crates/chat-cli/src/api_client/profile.rs deleted file mode 100644 index 17e29e25a..000000000 --- a/crates/chat-cli/src/api_client/profile.rs +++ /dev/null @@ -1,30 +0,0 @@ -use crate::api_client::endpoints::Endpoint; -use crate::api_client::{ - ApiClient, - ApiClientError, -}; -use crate::database::{ - AuthProfile, - Database, -}; -use crate::os::{ - Env, - Fs, -}; - -pub async fn list_available_profiles( - env: &Env, - fs: &Fs, - database: &mut Database, -) -> Result, ApiClientError> { - let mut profiles = vec![]; - for endpoint in Endpoint::CODEWHISPERER_ENDPOINTS { - let client = ApiClient::new(env, fs, database, Some(endpoint.clone())).await?; - match client.list_available_profiles().await { - Ok(mut p) => profiles.append(&mut p), - Err(e) => tracing::error!("Failed to list profiles from endpoint {:?}: {:?}", endpoint, e), - } - } - - Ok(profiles) -} diff --git a/crates/chat-cli/src/api_client/send_message_output.rs b/crates/chat-cli/src/api_client/send_message_output.rs deleted file mode 100644 index 43c15ab66..000000000 --- a/crates/chat-cli/src/api_client/send_message_output.rs +++ /dev/null @@ -1,45 +0,0 @@ -use aws_types::request_id::RequestId; - -use crate::api_client::ApiClientError; -use crate::api_client::model::ChatResponseStream; - -#[derive(Debug)] -pub enum SendMessageOutput { - Codewhisperer( - amzn_codewhisperer_streaming_client::operation::generate_assistant_response::GenerateAssistantResponseOutput, - ), - QDeveloper(amzn_qdeveloper_streaming_client::operation::send_message::SendMessageOutput), - Mock(Vec), -} - -impl SendMessageOutput { - pub fn request_id(&self) -> Option<&str> { - match self { - SendMessageOutput::Codewhisperer(output) => output.request_id(), - SendMessageOutput::QDeveloper(output) => output.request_id(), - SendMessageOutput::Mock(_) => None, - } - } - - pub async fn recv(&mut self) -> Result, ApiClientError> { - match self { - SendMessageOutput::Codewhisperer(output) => Ok(output - .generate_assistant_response_response - .recv() - .await? - .map(|s| s.into())), - SendMessageOutput::QDeveloper(output) => Ok(output.send_message_response.recv().await?.map(|s| s.into())), - SendMessageOutput::Mock(vec) => Ok(vec.pop()), - } - } -} - -impl RequestId for SendMessageOutput { - fn request_id(&self) -> Option<&str> { - match self { - SendMessageOutput::Codewhisperer(output) => output.request_id(), - SendMessageOutput::QDeveloper(output) => output.request_id(), - SendMessageOutput::Mock(_) => Some(""), - } - } -} diff --git a/crates/chat-cli/src/api_client/stage.rs b/crates/chat-cli/src/api_client/stage.rs deleted file mode 100644 index 31b301786..000000000 --- a/crates/chat-cli/src/api_client/stage.rs +++ /dev/null @@ -1,40 +0,0 @@ -use std::str::FromStr; - -#[derive(Debug, PartialEq, Eq, Clone, Copy)] -pub enum Stage { - Prod, - Gamma, - Alpha, - Beta, -} - -impl Stage { - pub fn as_str(&self) -> &'static str { - match self { - Stage::Prod => "prod", - Stage::Gamma => "gamma", - Stage::Alpha => "alpha", - Stage::Beta => "beta", - } - } -} - -impl FromStr for Stage { - type Err = (); - - fn from_str(s: &str) -> Result { - match s.to_ascii_lowercase().trim() { - "prod" | "production" => Ok(Stage::Prod), - "gamma" => Ok(Stage::Gamma), - "alpha" => Ok(Stage::Alpha), - "beta" => Ok(Stage::Beta), - _ => Err(()), - } - } -} - -impl std::fmt::Display for Stage { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.write_str(self.as_str()) - } -} diff --git a/crates/chat-cli/src/auth/builder_id.rs b/crates/chat-cli/src/auth/builder_id.rs deleted file mode 100644 index a257fe415..000000000 --- a/crates/chat-cli/src/auth/builder_id.rs +++ /dev/null @@ -1,658 +0,0 @@ -//! # Builder ID -//! -//! SSO flow (RFC: ) -//! 1. Get a client id (SSO-OIDC identifier, formatted per RFC6749). -//! - Code: [DeviceRegistration::register] -//! - Calls [Client::register_client] -//! - RETURNS: [DeviceRegistration] -//! - Client registration is valid for potentially months and creates state server-side, so -//! the client SHOULD cache them to disk. -//! 2. Start device authorization. -//! - Code: [start_device_authorization] -//! - Calls [Client::start_device_authorization] -//! - RETURNS (RFC: ): -//! [StartDeviceAuthorizationResponse] -//! 3. Poll for the access token -//! - Code: [poll_create_token] -//! - Calls [Client::create_token] -//! - RETURNS: [PollCreateToken] -//! 4. (Repeat) Tokens SHOULD be refreshed if expired and a refresh token is available. -//! - Code: [BuilderIdToken::refresh_token] -//! - Calls [Client::create_token] -//! - RETURNS: [BuilderIdToken] - -use aws_sdk_ssooidc::client::Client; -use aws_sdk_ssooidc::config::retry::RetryConfig; -use aws_sdk_ssooidc::config::{ - BehaviorVersion, - ConfigBag, - RuntimeComponents, - SharedAsyncSleep, -}; -use aws_sdk_ssooidc::error::SdkError; -use aws_sdk_ssooidc::operation::create_token::CreateTokenOutput; -use aws_sdk_ssooidc::operation::register_client::RegisterClientOutput; -use aws_smithy_async::rt::sleep::TokioSleep; -use aws_smithy_runtime_api::client::identity::http::Token; -use aws_smithy_runtime_api::client::identity::{ - Identity, - IdentityFuture, - ResolveIdentity, -}; -use aws_smithy_types::error::display::DisplayErrorContext; -use aws_types::region::Region; -use eyre::{ - Result, - eyre, -}; -use time::OffsetDateTime; -use tracing::{ - debug, - error, - info, - trace, - warn, -}; - -use crate::api_client::stalled_stream_protection_config; -use crate::auth::AuthError; -use crate::auth::consts::*; -use crate::auth::scope::is_scopes; -use crate::aws_common::app_name; -use crate::database::{ - Database, - Secret, -}; - -#[derive(Debug, Copy, Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize)] -pub enum OAuthFlow { - DeviceCode, - // This must remain backwards compatible - #[serde(alias = "PKCE")] - Pkce, -} - -/// Indicates if an expiration time has passed, there is a small 1 min window that is removed -/// so the token will not expire in transit -fn is_expired(expiration_time: &OffsetDateTime) -> bool { - let now = time::OffsetDateTime::now_utc(); - &(now + time::Duration::minutes(1)) > expiration_time -} - -pub(crate) fn oidc_url(region: &Region) -> String { - format!("https://oidc.{region}.amazonaws.com") -} - -pub fn client(region: Region) -> Client { - Client::new( - &aws_types::SdkConfig::builder() - .http_client(crate::aws_common::http_client::client()) - .behavior_version(BehaviorVersion::v2025_08_07()) - .endpoint_url(oidc_url(®ion)) - .region(region) - .retry_config(RetryConfig::standard().with_max_attempts(3)) - .sleep_impl(SharedAsyncSleep::new(TokioSleep::new())) - .stalled_stream_protection(stalled_stream_protection_config()) - .app_name(app_name()) - .build(), - ) -} - -/// Represents an OIDC registered client, resulting from the "register client" API call. -#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)] -pub struct DeviceRegistration { - pub client_id: String, - pub client_secret: Secret, - #[serde(with = "time::serde::rfc3339::option")] - pub client_secret_expires_at: Option, - pub region: String, - pub oauth_flow: OAuthFlow, - pub scopes: Option>, -} - -impl DeviceRegistration { - const SECRET_KEY: &'static str = "codewhisperer:odic:device-registration"; - - pub fn from_output( - output: RegisterClientOutput, - region: &Region, - oauth_flow: OAuthFlow, - scopes: Vec, - ) -> Self { - Self { - client_id: output.client_id.unwrap_or_default(), - client_secret: output.client_secret.unwrap_or_default().into(), - client_secret_expires_at: time::OffsetDateTime::from_unix_timestamp(output.client_secret_expires_at).ok(), - region: region.to_string(), - oauth_flow, - scopes: Some(scopes), - } - } - - /// Loads the OIDC registered client from the secret store, deleting it if it is expired. - async fn load_from_secret_store(database: &Database, region: &Region) -> Result, AuthError> { - trace!(?region, "loading device registration from secret store"); - let device_registration = database.get_secret(Self::SECRET_KEY).await?; - - if let Some(device_registration) = device_registration { - // check that the data is not expired, assume it is invalid if not present - let device_registration: Self = serde_json::from_str(&device_registration.0)?; - - if let Some(client_secret_expires_at) = device_registration.client_secret_expires_at { - let is_expired = is_expired(&client_secret_expires_at); - let registration_region_is_valid = device_registration.region == region.as_ref(); - trace!( - ?is_expired, - ?registration_region_is_valid, - "checking if device registration is valid" - ); - if !is_expired && registration_region_is_valid { - return Ok(Some(device_registration)); - } - } else { - warn!("no expiration time found for the client secret"); - } - } - - // delete the data if its expired or invalid - if let Err(err) = database.delete_secret(Self::SECRET_KEY).await { - error!(?err, "Failed to delete device registration from keychain"); - } - - Ok(None) - } - - /// Loads the client saved in the secret store if available, otherwise registers a new client - /// and saves it in the secret store. - pub async fn init_device_code_registration( - database: &Database, - client: &Client, - region: &Region, - ) -> Result { - match Self::load_from_secret_store(database, region).await { - Ok(Some(registration)) if registration.oauth_flow == OAuthFlow::DeviceCode => match ®istration.scopes { - Some(scopes) if is_scopes(scopes) => return Ok(registration), - _ => warn!("Invalid scopes in device registration, ignoring"), - }, - // If it doesn't exist or is for another OAuth flow, - // then continue with creating a new one. - Ok(None | Some(_)) => {}, - Err(err) => { - error!(?err, "Failed to read device registration from keychain"); - }, - }; - - let mut register = client - .register_client() - .client_name(CLIENT_NAME) - .client_type(CLIENT_TYPE); - for scope in SCOPES { - register = register.scopes(*scope); - } - let output = register.send().await?; - - let device_registration = Self::from_output( - output, - region, - OAuthFlow::DeviceCode, - SCOPES.iter().map(|s| (*s).to_owned()).collect(), - ); - - if let Err(err) = device_registration.save(database).await { - error!(?err, "Failed to write device registration to keychain"); - } - - Ok(device_registration) - } - - /// Saves to the passed secret store. - pub async fn save(&self, secret_store: &Database) -> Result<(), AuthError> { - secret_store - .set_secret(Self::SECRET_KEY, &serde_json::to_string(&self)?) - .await?; - Ok(()) - } -} - -#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)] -pub struct StartDeviceAuthorizationResponse { - /// Device verification code. - pub device_code: String, - /// User verification code. - pub user_code: String, - /// Verification URI on the authorization server. - pub verification_uri: String, - /// User verification URI on the authorization server. - pub verification_uri_complete: String, - /// Lifetime (seconds) of `device_code` and `user_code`. - pub expires_in: i32, - /// Minimum time (seconds) the client SHOULD wait between polling intervals. - pub interval: i32, - pub region: String, - pub start_url: String, -} - -/// Init a builder id request -pub async fn start_device_authorization( - database: &Database, - start_url: Option, - region: Option, -) -> Result { - let region = region.clone().map_or(OIDC_BUILDER_ID_REGION, Region::new); - let client = client(region.clone()); - - let DeviceRegistration { - client_id, - client_secret, - .. - } = DeviceRegistration::init_device_code_registration(database, &client, ®ion).await?; - - let output = client - .start_device_authorization() - .client_id(&client_id) - .client_secret(&client_secret.0) - .start_url(start_url.as_deref().unwrap_or(START_URL)) - .send() - .await?; - - Ok(StartDeviceAuthorizationResponse { - device_code: output.device_code.unwrap_or_default(), - user_code: output.user_code.unwrap_or_default(), - verification_uri: output.verification_uri.unwrap_or_default(), - verification_uri_complete: output.verification_uri_complete.unwrap_or_default(), - expires_in: output.expires_in, - interval: output.interval, - region: region.to_string(), - start_url: start_url.unwrap_or_else(|| START_URL.to_owned()), - }) -} - -#[derive(Debug, Clone, PartialEq, Eq)] -pub enum TokenType { - BuilderId, - IamIdentityCenter, -} - -#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)] -pub struct BuilderIdToken { - pub access_token: Secret, - #[serde(with = "time::serde::rfc3339")] - pub expires_at: time::OffsetDateTime, - pub refresh_token: Option, - pub region: Option, - pub start_url: Option, - pub oauth_flow: OAuthFlow, - pub scopes: Option>, -} - -impl BuilderIdToken { - const SECRET_KEY: &'static str = "codewhisperer:odic:token"; - - #[cfg(test)] - fn test() -> Self { - Self { - access_token: Secret("test_access_token".to_string()), - expires_at: time::OffsetDateTime::now_utc() + time::Duration::minutes(60), - refresh_token: Some(Secret("test_refresh_token".to_string())), - region: Some(OIDC_BUILDER_ID_REGION.to_string()), - start_url: Some(START_URL.to_string()), - oauth_flow: OAuthFlow::DeviceCode, - scopes: Some(SCOPES.iter().map(|s| (*s).to_owned()).collect()), - } - } - - /// Load the token from the keychain, refresh the token if it is expired and return it - pub async fn load(database: &Database) -> Result, AuthError> { - trace!("loading builder id token from the secret store"); - match database.get_secret(Self::SECRET_KEY).await { - Ok(Some(secret)) => { - let token: Option = serde_json::from_str(&secret.0)?; - match token { - Some(token) => { - let region = token.region.clone().map_or(OIDC_BUILDER_ID_REGION, Region::new); - let client = client(region.clone()); - - if token.is_expired() { - trace!("token is expired, refreshing"); - token.refresh_token(&client, database, ®ion).await - } else { - trace!(?token, "found a valid token"); - Ok(Some(token)) - } - }, - None => { - debug!("secret stored in the database was empty"); - Ok(None) - }, - } - }, - Ok(None) => { - debug!("no secret found in the database"); - Ok(None) - }, - Err(err) => { - error!(%err, "Error getting builder id token from keychain"); - Err(err)? - }, - } - } - - /// Refresh the access token - pub async fn refresh_token( - &self, - client: &Client, - database: &Database, - region: &Region, - ) -> Result, AuthError> { - let Some(refresh_token) = &self.refresh_token else { - warn!("no refresh token was found"); - // if the token is expired and has no refresh token, delete it - if let Err(err) = self.delete(database).await { - error!(?err, "Failed to delete builder id token"); - } - - return Ok(None); - }; - - trace!("loading device registration from secret store"); - let registration = match DeviceRegistration::load_from_secret_store(database, region).await? { - Some(registration) if registration.oauth_flow == self.oauth_flow => registration, - // If the OIDC client registration is for a different oauth flow or doesn't exist, then - // we can't refresh the token. - Some(registration) => { - warn!( - "Unable to refresh token: Stored client registration has oauth flow: {:?} but current access token has oauth flow: {:?}", - registration.oauth_flow, self.oauth_flow - ); - return Ok(None); - }, - None => { - warn!("Unable to refresh token: No registered client was found"); - return Ok(None); - }, - }; - - debug!("Refreshing access token"); - match client - .create_token() - .client_id(registration.client_id) - .client_secret(registration.client_secret.0) - .refresh_token(&refresh_token.0) - .grant_type(REFRESH_GRANT_TYPE) - .send() - .await - { - Ok(output) => { - let token: BuilderIdToken = Self::from_output( - output, - region.clone(), - self.start_url.clone(), - self.oauth_flow, - self.scopes.clone(), - ); - debug!("Refreshed access token, new token: {:?}", token); - - if let Err(err) = token.save(database).await { - error!(?err, "Failed to store builder id access token"); - }; - - Ok(Some(token)) - }, - Err(err) => { - let display_err = DisplayErrorContext(&err); - error!("Failed to refresh builder id access token: {}", display_err); - - // if the error is the client's fault, clear the token - if let SdkError::ServiceError(service_err) = &err { - if !service_err.err().is_slow_down_exception() { - if let Err(err) = self.delete(database).await { - error!(?err, "Failed to delete builder id token"); - } - } - } - - Err(err.into()) - }, - } - } - - /// If the time has passed the `expires_at` time - /// - /// The token is marked as expired 1 min before it actually does to account for the potential a - /// token expires while in transit - pub fn is_expired(&self) -> bool { - is_expired(&self.expires_at) - } - - /// Save the token to the keychain - pub async fn save(&self, database: &Database) -> Result<(), AuthError> { - database - .set_secret(Self::SECRET_KEY, &serde_json::to_string(self)?) - .await?; - Ok(()) - } - - /// Delete the token from the keychain - pub async fn delete(&self, database: &Database) -> Result<(), AuthError> { - database.delete_secret(Self::SECRET_KEY).await?; - Ok(()) - } - - pub(crate) fn from_output( - output: CreateTokenOutput, - region: Region, - start_url: Option, - oauth_flow: OAuthFlow, - scopes: Option>, - ) -> Self { - Self { - access_token: output.access_token.unwrap_or_default().into(), - expires_at: time::OffsetDateTime::now_utc() + time::Duration::seconds(output.expires_in as i64), - refresh_token: output.refresh_token.map(|t| t.into()), - region: Some(region.to_string()), - start_url, - oauth_flow, - scopes, - } - } - - pub fn token_type(&self) -> TokenType { - match &self.start_url { - Some(url) if url == START_URL => TokenType::BuilderId, - None => TokenType::BuilderId, - Some(_) => TokenType::IamIdentityCenter, - } - } - - /// Check if the token is for the internal amzn start URL (`https://amzn.awsapps.com/start`), - /// this implies the user will use midway for private specs - pub fn is_amzn_user(&self) -> bool { - matches!(&self.start_url, Some(url) if url == AMZN_START_URL) - } -} - -pub enum PollCreateToken { - Pending, - Complete, - Error(AuthError), -} - -/// Poll for the create token response -pub async fn poll_create_token( - database: &Database, - device_code: String, - start_url: Option, - region: Option, -) -> PollCreateToken { - let region = region.clone().map_or(OIDC_BUILDER_ID_REGION, Region::new); - let client = client(region.clone()); - - let DeviceRegistration { - client_id, - client_secret, - scopes, - .. - } = match DeviceRegistration::init_device_code_registration(database, &client, ®ion).await { - Ok(res) => res, - Err(err) => { - return PollCreateToken::Error(err); - }, - }; - - match client - .create_token() - .grant_type(DEVICE_GRANT_TYPE) - .device_code(device_code) - .client_id(client_id) - .client_secret(client_secret.0) - .send() - .await - { - Ok(output) => { - let token: BuilderIdToken = - BuilderIdToken::from_output(output, region, start_url, OAuthFlow::DeviceCode, scopes); - - if let Err(err) = token.save(database).await { - error!(?err, "Failed to store builder id token"); - }; - - PollCreateToken::Complete - }, - Err(SdkError::ServiceError(service_error)) if service_error.err().is_authorization_pending_exception() => { - PollCreateToken::Pending - }, - Err(err) => { - error!(?err, "Failed to poll for builder id token"); - PollCreateToken::Error(err.into()) - }, - } -} - -pub async fn is_logged_in(database: &mut Database) -> bool { - // Check for BuilderId if not using Sigv4 - if std::env::var("AMAZON_Q_SIGV4").is_ok_and(|v| !v.is_empty()) { - debug!("logged in using sigv4 credentials"); - return true; - } - - match BuilderIdToken::load(database).await { - Ok(Some(_)) => true, - Ok(None) => { - info!("not logged in - no valid token found"); - false - }, - Err(err) => { - warn!(?err, "failed to try to load a builder id token"); - false - }, - } -} - -pub async fn logout(database: &mut Database) -> Result<(), AuthError> { - let Ok(secret_store) = Database::new().await else { - return Ok(()); - }; - - let (builder_res, device_res) = tokio::join!( - secret_store.delete_secret(BuilderIdToken::SECRET_KEY), - secret_store.delete_secret(DeviceRegistration::SECRET_KEY), - ); - - let profile_res = database.unset_auth_profile(); - - builder_res?; - device_res?; - profile_res?; - - Ok(()) -} - -pub async fn get_start_url_and_region(database: &Database) -> (Option, Option) { - // NOTE: Database provides direct methods to access the start_url and region, but they are not - // guaranteed to be up to date in the chat session. Example: login is changed mid-chat session. - let token = BuilderIdToken::load(database).await; - match token { - Ok(Some(t)) => (t.start_url, t.region), - _ => (None, None), - } -} - -#[derive(Debug, Clone)] -pub struct BearerResolver; - -impl ResolveIdentity for BearerResolver { - fn resolve_identity<'a>( - &'a self, - _runtime_components: &'a RuntimeComponents, - _config_bag: &'a ConfigBag, - ) -> IdentityFuture<'a> { - IdentityFuture::new_boxed(Box::pin(async { - let database = Database::new().await?; - match BuilderIdToken::load(&database).await? { - Some(token) => Ok(Identity::new( - Token::new(token.access_token.0.clone(), Some(token.expires_at.into())), - Some(token.expires_at.into()), - )), - None => Err(AuthError::NoToken.into()), - } - })) - } -} - -pub async fn is_idc_user(database: &Database) -> Result { - if cfg!(test) { - return Ok(false); - } - if let Ok(Some(token)) = BuilderIdToken::load(database).await { - Ok(token.token_type() == TokenType::IamIdentityCenter) - } else { - Err(eyre!("No auth token found - is the user signed in?")) - } -} - -#[cfg(test)] -mod tests { - use super::*; - - const US_EAST_1: Region = Region::from_static("us-east-1"); - const US_WEST_2: Region = Region::from_static("us-west-2"); - - #[test] - fn test_oauth_flow_deser() { - assert_eq!(OAuthFlow::Pkce, serde_json::from_str("\"PKCE\"").unwrap()); - assert_eq!(OAuthFlow::Pkce, serde_json::from_str("\"Pkce\"").unwrap()); - } - - #[tokio::test] - async fn test_client() { - println!("{:?}", client(US_EAST_1)); - println!("{:?}", client(US_WEST_2)); - } - - #[test] - fn oidc_url_snapshot() { - insta::assert_snapshot!(oidc_url(&US_EAST_1), @"https://oidc.us-east-1.amazonaws.com"); - insta::assert_snapshot!(oidc_url(&US_WEST_2), @"https://oidc.us-west-2.amazonaws.com"); - } - - #[test] - fn test_is_expired() { - let mut token = BuilderIdToken::test(); - assert!(!token.is_expired()); - - token.expires_at = time::OffsetDateTime::now_utc() - time::Duration::seconds(60); - assert!(token.is_expired()); - } - - #[test] - fn test_token_type() { - let mut token = BuilderIdToken::test(); - assert_eq!(token.token_type(), TokenType::BuilderId); - - token.start_url = None; - assert_eq!(token.token_type(), TokenType::BuilderId); - - token.start_url = Some("https://amzn.awsapps.com/start".into()); - assert_eq!(token.token_type(), TokenType::IamIdentityCenter); - } -} diff --git a/crates/chat-cli/src/auth/consts.rs b/crates/chat-cli/src/auth/consts.rs deleted file mode 100644 index a09e42a85..000000000 --- a/crates/chat-cli/src/auth/consts.rs +++ /dev/null @@ -1,28 +0,0 @@ -use aws_types::region::Region; - -pub(crate) const CLIENT_NAME: &str = "Amazon Q Developer for command line"; - -pub(crate) const OIDC_BUILDER_ID_REGION: Region = Region::from_static("us-east-1"); - -/// The scopes requested for OIDC -/// -/// Do not include `sso:account:access`, these permissions are not needed and were -/// previously included -pub(crate) const SCOPES: &[&str] = &[ - "codewhisperer:completions", - "codewhisperer:analysis", - "codewhisperer:conversations", - // "codewhisperer:taskassist", - // "codewhisperer:transformations", -]; - -pub(crate) const CLIENT_TYPE: &str = "public"; - -// The start URL for public builder ID users -pub const START_URL: &str = "https://view.awsapps.com/start"; - -// The start URL for internal amzn users -pub const AMZN_START_URL: &str = "https://amzn.awsapps.com/start"; - -pub(crate) const DEVICE_GRANT_TYPE: &str = "urn:ietf:params:oauth:grant-type:device_code"; -pub(crate) const REFRESH_GRANT_TYPE: &str = "refresh_token"; diff --git a/crates/chat-cli/src/auth/index.html b/crates/chat-cli/src/auth/index.html deleted file mode 100644 index c68c852af..000000000 --- a/crates/chat-cli/src/auth/index.html +++ /dev/null @@ -1,181 +0,0 @@ - - - - - AWS Authentication - - - - - -
-
- - - - - -
-
- -
-
- -
-

Request approved

-

-
-
-

-
- - - -
-
- - - - diff --git a/crates/chat-cli/src/auth/mod.rs b/crates/chat-cli/src/auth/mod.rs deleted file mode 100644 index 4b425f2a6..000000000 --- a/crates/chat-cli/src/auth/mod.rs +++ /dev/null @@ -1,73 +0,0 @@ -pub mod builder_id; -mod consts; -pub mod pkce; -mod scope; - -use aws_sdk_ssooidc::error::SdkError; -use aws_sdk_ssooidc::operation::create_token::CreateTokenError; -use aws_sdk_ssooidc::operation::register_client::RegisterClientError; -use aws_sdk_ssooidc::operation::start_device_authorization::StartDeviceAuthorizationError; -pub use builder_id::{ - is_logged_in, - logout, -}; -pub use consts::START_URL; -use thiserror::Error; - -#[derive(Debug, Error)] -pub enum AuthError { - #[error(transparent)] - Ssooidc(Box), - #[error(transparent)] - SdkRegisterClient(Box>), - #[error(transparent)] - SdkCreateToken(Box>), - #[error(transparent)] - SdkStartDeviceAuthorization(Box>), - #[error(transparent)] - Io(#[from] std::io::Error), - #[error(transparent)] - TimeComponentRange(#[from] time::error::ComponentRange), - #[error(transparent)] - Directories(#[from] crate::util::directories::DirectoryError), - #[error(transparent)] - SerdeJson(#[from] serde_json::Error), - #[error(transparent)] - DbOpenError(#[from] crate::database::DbOpenError), - #[error("No token")] - NoToken, - #[error("OAuth state mismatch. Actual: {} | Expected: {}", .actual, .expected)] - OAuthStateMismatch { actual: String, expected: String }, - #[error("Timeout waiting for authentication to complete")] - OAuthTimeout, - #[error("No code received on redirect")] - OAuthMissingCode, - #[error("OAuth error: {0}")] - OAuthCustomError(String), - #[error(transparent)] - DatabaseError(#[from] crate::database::DatabaseError), -} - -impl From for AuthError { - fn from(value: aws_sdk_ssooidc::Error) -> Self { - Self::Ssooidc(Box::new(value)) - } -} - -impl From> for AuthError { - fn from(value: SdkError) -> Self { - Self::SdkRegisterClient(Box::new(value)) - } -} - -impl From> for AuthError { - fn from(value: SdkError) -> Self { - Self::SdkCreateToken(Box::new(value)) - } -} - -impl From> for AuthError { - fn from(value: SdkError) -> Self { - Self::SdkStartDeviceAuthorization(Box::new(value)) - } -} diff --git a/crates/chat-cli/src/auth/pkce.rs b/crates/chat-cli/src/auth/pkce.rs deleted file mode 100644 index e001abe97..000000000 --- a/crates/chat-cli/src/auth/pkce.rs +++ /dev/null @@ -1,633 +0,0 @@ -//! # OAuth 2.0 Proof Key for Code Exchange -//! -//! This module implements the PKCE integration with AWS OIDC according to their -//! developer guide. -//! -//! The benefit of PKCE over device code is to simplify the user experience by not -//! requiring the user to validate the generated code across the browser and the -//! device. -//! -//! SSO flow (RFC: ) -//! 1. Register an OIDC client -//! - Code: [PkceRegistration::register] -//! 2. Host a local HTTP server to handle the redirect -//! - Code: [PkceRegistration::finish] -//! 3. Open the [PkceRegistration::url] in the browser, and approve the request. -//! 4. Exchange the code for access and refresh tokens. -//! - This completes the future returned by [PkceRegistration::finish]. -//! -//! Once access/refresh tokens are received, there is no difference between PKCE -//! and device code (as already implemented in [crate::builder_id]). - -use std::future::Future; -use std::pin::Pin; -use std::time::Duration; - -pub use aws_sdk_ssooidc::client::Client; -pub use aws_sdk_ssooidc::operation::create_token::CreateTokenOutput; -pub use aws_sdk_ssooidc::operation::register_client::RegisterClientOutput; -pub use aws_types::region::Region; -use base64::Engine; -use base64::engine::general_purpose::URL_SAFE; -use bytes::Bytes; -use http_body_util::Full; -use hyper::body::Incoming; -use hyper::server::conn::http1; -use hyper::service::Service; -use hyper::{ - Request, - Response, -}; -use hyper_util::rt::TokioIo; -use percent_encoding::{ - NON_ALPHANUMERIC, - utf8_percent_encode, -}; -use rand::Rng; -use tokio::net::TcpListener; -use tracing::{ - debug, - error, -}; - -use crate::auth::builder_id::*; -use crate::auth::consts::*; -use crate::auth::{ - AuthError, - START_URL, -}; -use crate::database::Database; - -const DEFAULT_AUTHORIZATION_TIMEOUT: Duration = Duration::from_secs(60 * 3); - -/// Starts the PKCE authorization flow, using [`START_URL`] and [`OIDC_BUILDER_ID_REGION`] as the -/// default issuer URL and region. Returns the [`PkceClient`] to use to finish the flow. -pub async fn start_pkce_authorization( - start_url: Option, - region: Option, -) -> Result<(Client, PkceRegistration), AuthError> { - let issuer_url = start_url.as_deref().unwrap_or(START_URL); - let region = region.clone().map_or(OIDC_BUILDER_ID_REGION, Region::new); - let client = client(region.clone()); - let registration = PkceRegistration::register(&client, region, issuer_url.to_string(), None).await?; - Ok((client, registration)) -} - -/// Represents a client used for registering with AWS IAM OIDC. -#[async_trait::async_trait] -pub trait PkceClient { - /// The scopes that the client will request - fn scopes() -> Vec; - - async fn register_client( - &self, - redirect_uri: String, - issuer_url: String, - ) -> Result; - - async fn create_token(&self, args: CreateTokenArgs) -> Result; -} - -#[derive(Debug, Clone)] -pub struct RegisterClientResponse { - pub output: RegisterClientOutput, -} - -impl RegisterClientResponse { - pub fn client_id(&self) -> &str { - self.output.client_id().unwrap_or_default() - } - - pub fn client_secret(&self) -> &str { - self.output.client_secret().unwrap_or_default() - } -} - -#[derive(Debug)] -pub struct CreateTokenResponse { - pub output: CreateTokenOutput, -} - -#[derive(Debug)] -pub struct CreateTokenArgs { - pub client_id: String, - pub client_secret: String, - pub redirect_uri: String, - pub code_verifier: String, - pub code: String, -} - -#[async_trait::async_trait] -impl PkceClient for Client { - fn scopes() -> Vec { - SCOPES.iter().map(|s| (*s).to_owned()).collect() - } - - async fn register_client( - &self, - redirect_uri: String, - issuer_url: String, - ) -> Result { - let mut register = self - .register_client() - .client_name(CLIENT_NAME) - .client_type(CLIENT_TYPE) - .issuer_url(issuer_url.clone()) - .redirect_uris(redirect_uri.clone()) - .grant_types("authorization_code") - .grant_types("refresh_token"); - for scope in Self::scopes() { - register = register.scopes(scope); - } - let output = register.send().await?; - Ok(RegisterClientResponse { output }) - } - - async fn create_token(&self, args: CreateTokenArgs) -> Result { - let output = self - .create_token() - .client_id(args.client_id.clone()) - .client_secret(args.client_secret.clone()) - .grant_type("authorization_code") - .redirect_uri(args.redirect_uri) - .code_verifier(args.code_verifier) - .code(args.code) - .send() - .await?; - Ok(CreateTokenResponse { output }) - } -} - -/// Represents an active PKCE registration flow. To execute the flow, you should (in order): -/// 1. Call [`PkceRegistration::register`] to register an AWS OIDC client and receive the URL to be -/// opened by the browser. -/// 2. Call [`PkceRegistration::finish`] to host a local server to handle redirects, and trade the -/// authorization code for an access token. -#[derive(Debug)] -pub struct PkceRegistration { - /// URL to be opened by the user's browser. - pub url: String, - registered_client: RegisterClientResponse, - /// Configured URI that the authorization server will redirect the client to. - pub redirect_uri: String, - code_verifier: String, - /// Random value generated for every authentication attempt. - /// - /// - pub state: String, - /// Listener for hosting the local HTTP server. - listener: TcpListener, - region: Region, - /// Interchangeable with the "start URL" concept in the device code flow. - issuer_url: String, - /// Time to wait for [`Self::finish`] to complete. Default is [`DEFAULT_AUTHORIZATION_TIMEOUT`]. - timeout: Duration, -} - -impl PkceRegistration { - pub async fn register( - client: &impl PkceClient, - region: Region, - issuer_url: String, - timeout: Option, - ) -> Result { - let listener = TcpListener::bind("127.0.0.1:0").await?; - let redirect_uri = format!("http://{}/oauth/callback", listener.local_addr()?); - let code_verifier = generate_code_verifier(); - let code_challenge = generate_code_challenge(&code_verifier); - let state = rand::rng() - .sample_iter(rand::distr::Alphanumeric) - .take(10) - .collect::>(); - let state = String::from_utf8(state).unwrap_or("state".to_string()); - - let response = client.register_client(redirect_uri.clone(), issuer_url.clone()).await?; - - let query = PkceQueryParams { - client_id: response.client_id().to_string(), - redirect_uri: redirect_uri.clone(), - // Scopes must be space delimited. - scopes: SCOPES.join(" "), - state: state.clone(), - code_challenge: code_challenge.clone(), - code_challenge_method: "S256".to_string(), - }; - let url = format!("{}/authorize?{}", oidc_url(®ion), query.as_query_params()); - - Ok(Self { - url, - registered_client: response, - code_verifier, - state, - listener, - redirect_uri, - region, - issuer_url, - timeout: timeout.unwrap_or(DEFAULT_AUTHORIZATION_TIMEOUT), - }) - } - - /// Hosts a local HTTP server to listen for browser redirects. If a [`Database`] is passed, - /// then the access and refresh tokens will be saved. - /// - /// Only the first connection will be served. - pub async fn finish(self, client: &C, database: Option<&mut Database>) -> Result<(), AuthError> { - let code = tokio::select! { - code = Self::recv_code(self.listener, self.state) => { - code? - }, - _ = tokio::time::sleep(self.timeout) => { - return Err(AuthError::OAuthTimeout); - } - }; - - let response = client - .create_token(CreateTokenArgs { - client_id: self.registered_client.client_id().to_string(), - client_secret: self.registered_client.client_secret().to_string(), - redirect_uri: self.redirect_uri, - code_verifier: self.code_verifier, - code, - }) - .await?; - - // Tokens are redacted in the log output. - debug!(?response, "Received create_token response"); - - let token = BuilderIdToken::from_output( - response.output, - self.region.clone(), - Some(self.issuer_url), - OAuthFlow::Pkce, - Some(C::scopes()), - ); - - let device_registration = DeviceRegistration::from_output( - self.registered_client.output, - &self.region, - OAuthFlow::Pkce, - C::scopes(), - ); - - if let Some(database) = database { - if let Err(err) = device_registration.save(database).await { - error!(?err, "Failed to store pkce registration to secret store"); - } - - if let Err(err) = token.save(database).await { - error!(?err, "Failed to store builder id token"); - }; - } - - Ok(()) - } - - async fn recv_code(listener: TcpListener, expected_state: String) -> Result { - let (code_tx, mut code_rx) = tokio::sync::mpsc::channel::>(1); - let (stream, _) = listener.accept().await?; - let stream = TokioIo::new(stream); // Wrapper to implement Hyper IO traits for Tokio types. - let host = listener.local_addr()?.to_string(); - tokio::spawn(async move { - if let Err(err) = http1::Builder::new() - .serve_connection(stream, PkceHttpService { - code_tx: std::sync::Arc::new(code_tx), - host, - }) - .await - { - error!(?err, "Error occurred serving the connection"); - } - }); - match code_rx.recv().await { - Some(Ok((code, state))) => { - debug!(code = "", state, "Received code and state"); - if state != expected_state { - return Err(AuthError::OAuthStateMismatch { - actual: state, - expected: expected_state, - }); - } - // Give time for the user to be redirected to index.html. - tokio::time::sleep(Duration::from_millis(200)).await; - Ok(code) - }, - Some(Err(err)) => { - // Give time for the user to be redirected to index.html. - tokio::time::sleep(Duration::from_millis(200)).await; - Err(err) - }, - None => Err(AuthError::OAuthMissingCode), - } - } -} - -type CodeSender = std::sync::Arc>>; -type ServiceError = AuthError; -type ServiceResponse = Response>; -type ServiceFuture = Pin> + Send>>; - -#[derive(Debug, Clone)] -struct PkceHttpService { - /// [`tokio::sync::mpsc::Sender`] for a (code, state) pair. - code_tx: CodeSender, - - /// The host being served - ie, the hostname and port. - /// Used for responding with redirects. - host: String, -} - -impl PkceHttpService { - /// Handles the browser redirect to `"http://{host}/oauth/callback"` which contains either the - /// code and state query params, or an error query param. Redirects to "/index.html". - /// - /// The [`Request`] doesn't actually contain the host, hence the `host` argument. - async fn handle_oauth_callback( - code_tx: CodeSender, - host: String, - req: Request, - ) -> Result { - let query_params = req - .uri() - .query() - .map(|query| { - query - .split('&') - .filter_map(|kv| kv.split_once('=')) - .collect::>() - }) - .ok_or(AuthError::OAuthCustomError("query parameters are missing".into()))?; - - // Error handling: if something goes wrong at the authorization endpoint, the - // client will be redirected to the redirect url with "error" and - // "error_description" query parameters. - if let Some(error) = query_params.get("error") { - let error_description = query_params.get("error_description").unwrap_or(&""); - let _ = code_tx - .send(Err(AuthError::OAuthCustomError(format!( - "error occurred during authorization: {:?}, {:?}", - error, error_description - )))) - .await; - return Self::redirect_to_index(&host, &format!("?error={}", error)); - } else { - let code = query_params.get("code"); - let state = query_params.get("state"); - if let (Some(code), Some(state)) = (code, state) { - let _ = code_tx.send(Ok(((*code).to_string(), (*state).to_string()))).await; - } else { - let _ = code_tx - .send(Err(AuthError::OAuthCustomError( - "missing code and/or state in the query parameters".into(), - ))) - .await; - return Self::redirect_to_index(&host, "?error=missing%20required%20query%20parameters"); - } - } - - Self::redirect_to_index(&host, "") - } - - fn redirect_to_index(host: &str, query_params: &str) -> Result { - Ok(Response::builder() - .status(302) - .header("Location", format!("http://{}/index.html{}", host, query_params)) - .body("".into()) - .expect("is valid builder, should not panic")) - } -} - -impl Service> for PkceHttpService { - type Error = ServiceError; - type Future = ServiceFuture; - type Response = ServiceResponse; - - fn call(&self, req: Request) -> Self::Future { - let code_tx: CodeSender = std::sync::Arc::clone(&self.code_tx); - let host = self.host.clone(); - Box::pin(async move { - debug!(?req, "Handling connection"); - match req.uri().path() { - "/oauth/callback" | "/oauth/callback/" => Self::handle_oauth_callback(code_tx, host, req).await, - "/index.html" => Ok(Response::builder() - .status(200) - .header("Content-Type", "text/html") - .header("Connection", "close") - .body(include_str!("./index.html").into()) - .expect("valid builder will not panic")), - _ => Ok(Response::builder() - .status(404) - .body("".into()) - .expect("valid builder will not panic")), - } - }) - } -} - -/// Query params for the initial GET request that starts the PKCE flow. Use -/// [`PkceQueryParams::as_query_params`] to get a URL-safe string. -#[derive(Debug, Clone, serde::Serialize)] -struct PkceQueryParams { - client_id: String, - redirect_uri: String, - scopes: String, - state: String, - code_challenge: String, - code_challenge_method: String, -} - -macro_rules! encode { - ($expr:expr) => { - utf8_percent_encode(&$expr, NON_ALPHANUMERIC) - }; -} - -impl PkceQueryParams { - fn as_query_params(&self) -> String { - [ - "response_type=code".to_string(), - format!("client_id={}", encode!(self.client_id)), - format!("redirect_uri={}", encode!(self.redirect_uri)), - format!("scopes={}", encode!(self.scopes)), - format!("state={}", encode!(self.state)), - format!("code_challenge={}", encode!(self.code_challenge)), - format!("code_challenge_method={}", encode!(self.code_challenge_method)), - ] - .join("&") - } -} - -/// Generates a random 43-octet URL safe string according to the RFC recommendation. -/// -/// Reference: https://datatracker.ietf.org/doc/html/rfc7636#section-4.1 -fn generate_code_verifier() -> String { - URL_SAFE.encode(rand::random::<[u8; 32]>()).replace('=', "") -} - -/// Base64 URL encoded sha256 hash of the code verifier. -/// -/// Reference: https://datatracker.ietf.org/doc/html/rfc7636#section-4.2 -fn generate_code_challenge(code_verifier: &str) -> String { - use sha2::{ - Digest, - Sha256, - }; - let mut hasher = Sha256::new(); - hasher.update(code_verifier); - URL_SAFE.encode(hasher.finalize()).replace('=', "") -} - -#[cfg(test)] -mod tests { - use super::*; - use crate::auth::scope::is_scopes; - - #[derive(Debug, Clone)] - struct TestPkceClient; - - #[async_trait::async_trait] - impl PkceClient for TestPkceClient { - fn scopes() -> Vec { - vec!["scope:1".to_string(), "scope:2".to_string()] - } - - async fn register_client(&self, _: String, _: String) -> Result { - Ok(RegisterClientResponse { - output: RegisterClientOutput::builder() - .client_id("test_client_id") - .client_secret("test_client_secret") - .build(), - }) - } - - async fn create_token(&self, _: CreateTokenArgs) -> Result { - Ok(CreateTokenResponse { - output: CreateTokenOutput::builder().build(), - }) - } - } - - #[ignore = "not in ci"] - #[tokio::test] - async fn test_pkce_flow_e2e() { - tracing_subscriber::fmt::init(); - - let start_url = "https://amzn.awsapps.com/start".to_string(); - let region = Region::new("us-east-1"); - let client = client(region.clone()); - let registration = PkceRegistration::register(&client, region.clone(), start_url, None) - .await - .unwrap(); - println!("{:?}", registration); - if crate::util::open::open_url_async(®istration.url).await.is_err() { - panic!("unable to open the URL"); - } - println!("Waiting for authorization to complete..."); - - registration.finish(&client, None).await.unwrap(); - println!("Authorization successful"); - } - - #[tokio::test] - async fn test_pkce_flow_completes_successfully() { - // tracing_subscriber::fmt::init(); - let region = Region::new("us-east-1"); - let issuer_url = START_URL.into(); - let client = TestPkceClient {}; - let registration = PkceRegistration::register(&client, region, issuer_url, None) - .await - .unwrap(); - - let redirect_uri = registration.redirect_uri.clone(); - let state = registration.state.clone(); - tokio::spawn(async move { - // Let registration.finish be called to handle the request. - tokio::time::sleep(std::time::Duration::from_secs(1)).await; - reqwest::get(format!("{}/?code={}&state={}", redirect_uri, "code", state)) - .await - .unwrap(); - }); - - registration.finish(&client, None).await.unwrap(); - } - - #[tokio::test] - async fn test_pkce_flow_with_state_mismatch_throws_err() { - let region = Region::new("us-east-1"); - let issuer_url = START_URL.into(); - let client = TestPkceClient {}; - let registration = PkceRegistration::register(&client, region, issuer_url, None) - .await - .unwrap(); - - let redirect_uri = registration.redirect_uri.clone(); - tokio::spawn(async move { - // Let registration.finish be called to handle the request. - tokio::time::sleep(std::time::Duration::from_secs(1)).await; - reqwest::get(format!("{}/?code={}&state={}", redirect_uri, "code", "not_my_state")) - .await - .unwrap(); - }); - - assert!(matches!( - registration.finish(&client, None).await, - Err(AuthError::OAuthStateMismatch { actual: _, expected: _ }) - )); - } - - #[tokio::test] - async fn test_pkce_flow_with_authorization_redirect_error() { - let region = Region::new("us-east-1"); - let issuer_url = START_URL.into(); - let client = TestPkceClient {}; - let registration = PkceRegistration::register(&client, region, issuer_url, None) - .await - .unwrap(); - - let redirect_uri = registration.redirect_uri.clone(); - tokio::spawn(async move { - // Let registration.finish be called to handle the request. - tokio::time::sleep(std::time::Duration::from_secs(1)).await; - reqwest::get(format!( - "{}/?error={}&error_description={}", - redirect_uri, "error code", "something bad happened?" - )) - .await - .unwrap(); - }); - - assert!(matches!( - registration.finish(&client, None).await, - Err(AuthError::OAuthCustomError(_)) - )); - } - - #[tokio::test] - async fn test_pkce_flow_with_timeout() { - let region = Region::new("us-east-1"); - let issuer_url = START_URL.into(); - let client = TestPkceClient {}; - let registration = PkceRegistration::register(&client, region, issuer_url, Some(Duration::from_millis(100))) - .await - .unwrap(); - - assert!(matches!( - registration.finish(&client, None).await, - Err(AuthError::OAuthTimeout) - )); - } - - #[tokio::test] - async fn verify_gen_code_challenge() { - let code_verifier = generate_code_verifier(); - println!("{:?}", code_verifier); - - let code_challenge = generate_code_challenge(&code_verifier); - println!("{:?}", code_challenge); - assert!(code_challenge.len() >= 43); - } - - #[test] - fn verify_client_scopes() { - assert!(is_scopes(&Client::scopes())); - } -} diff --git a/crates/chat-cli/src/auth/scope.rs b/crates/chat-cli/src/auth/scope.rs deleted file mode 100644 index b6f9cddd0..000000000 --- a/crates/chat-cli/src/auth/scope.rs +++ /dev/null @@ -1,33 +0,0 @@ -use crate::auth::consts::SCOPES; - -pub fn scopes_match, B: AsRef>(a: &[A], b: &[B]) -> bool { - if a.len() != b.len() { - return false; - } - - let mut a = a.iter().map(|s| s.as_ref()).collect::>(); - let mut b = b.iter().map(|s| s.as_ref()).collect::>(); - a.sort(); - b.sort(); - a == b -} - -/// Checks if the given scopes match the predefined scopes. -pub(crate) fn is_scopes>(scopes: &[S]) -> bool { - scopes_match(SCOPES, scopes) -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn test_scopes_match() { - assert!(scopes_match(&["a", "b", "c"], &["a", "b", "c"])); - assert!(scopes_match(&["a", "b", "c"], &["a", "c", "b"])); - assert!(!scopes_match(&["a", "b", "c"], &["a", "b"])); - assert!(!scopes_match(&["a", "b"], &["a", "b", "c"])); - - assert!(is_scopes(SCOPES)); - } -} diff --git a/crates/chat-cli/src/aws_common/http_client.rs b/crates/chat-cli/src/aws_common/http_client.rs deleted file mode 100644 index 85c2b482c..000000000 --- a/crates/chat-cli/src/aws_common/http_client.rs +++ /dev/null @@ -1,198 +0,0 @@ -use std::time::Duration; - -use aws_smithy_runtime_api::client::http::{ - HttpClient, - HttpConnector, - HttpConnectorFuture, - HttpConnectorSettings, - SharedHttpConnector, -}; -use aws_smithy_runtime_api::client::result::ConnectorError; -use aws_smithy_runtime_api::client::runtime_components::RuntimeComponents; -use aws_smithy_runtime_api::http::Request; -use aws_smithy_types::body::SdkBody; -use reqwest::Client as ReqwestClient; - -/// Returns a wrapper around the global [fig_request::client] that implements -/// [HttpClient]. -pub fn client() -> Client { - let client = crate::request::new_client().expect("failed to create http client"); - Client::new(client.clone()) -} - -/// A wrapper around [reqwest::Client] that implements [HttpClient]. -/// -/// This is required to support using proxy servers with the AWS SDK. -#[derive(Debug, Clone)] -pub struct Client { - inner: ReqwestClient, -} - -impl Client { - pub fn new(client: ReqwestClient) -> Self { - Self { inner: client } - } -} - -#[derive(Debug)] -struct CallError { - kind: CallErrorKind, - message: &'static str, - source: Option>, -} - -impl CallError { - fn user(message: &'static str) -> Self { - Self { - kind: CallErrorKind::User, - message, - source: None, - } - } - - fn user_with_source(message: &'static str, source: E) -> Self - where - E: std::error::Error + Send + Sync + 'static, - { - Self { - kind: CallErrorKind::User, - message, - source: Some(Box::new(source)), - } - } - - fn timeout(source: E) -> Self - where - E: std::error::Error + Send + Sync + 'static, - { - Self { - kind: CallErrorKind::Timeout, - message: "request timed out", - source: Some(Box::new(source)), - } - } - - fn io(source: E) -> Self - where - E: std::error::Error + Send + Sync + 'static, - { - Self { - kind: CallErrorKind::Io, - message: "an i/o error occurred", - source: Some(Box::new(source)), - } - } - - fn other(message: &'static str, source: E) -> Self - where - E: std::error::Error + Send + Sync + 'static, - { - Self { - kind: CallErrorKind::Other, - message, - source: Some(Box::new(source)), - } - } -} - -impl std::error::Error for CallError {} - -impl std::fmt::Display for CallError { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "{}", self.message)?; - if let Some(err) = self.source.as_ref() { - write!(f, ": {}", err)?; - } - Ok(()) - } -} - -impl From for ConnectorError { - fn from(value: CallError) -> Self { - match &value.kind { - CallErrorKind::User => Self::user(Box::new(value)), - CallErrorKind::Timeout => Self::timeout(Box::new(value)), - CallErrorKind::Io => Self::io(Box::new(value)), - CallErrorKind::Other => Self::other(Box::new(value), None), - } - } -} - -impl From for CallError { - fn from(err: reqwest::Error) -> Self { - if err.is_timeout() { - CallError::timeout(err) - } else if err.is_connect() { - CallError::io(err) - } else { - CallError::other("an unknown error occurred", err) - } - } -} - -#[derive(Debug, Clone)] -enum CallErrorKind { - User, - Timeout, - Io, - Other, -} - -#[derive(Debug)] -struct ReqwestConnector { - client: ReqwestClient, - timeout: Option, -} - -impl HttpConnector for ReqwestConnector { - fn call(&self, request: Request) -> HttpConnectorFuture { - let client = self.client.clone(); - let timeout = self.timeout; - - HttpConnectorFuture::new(async move { - // Convert the aws_smithy_runtime_api request to a reqwest request. - // TODO: There surely has to be a better way to convert an aws_smith_runtime_api - // Request to a reqwest Request. - let mut req_builder = client.request( - reqwest::Method::from_bytes(request.method().as_bytes()) - .map_err(|err| CallError::user_with_source("failed to create method name", err))?, - request.uri().to_owned(), - ); - // Copy the header, body, and timeout. - let parts = request.into_parts(); - for (name, value) in parts.headers.iter() { - let name = name.to_owned(); - let value = value.as_bytes().to_owned(); - req_builder = req_builder.header(name, value); - } - let body_bytes = parts - .body - .bytes() - .ok_or(CallError::user("streaming request body is not supported"))? - .to_owned(); - req_builder = req_builder.body(body_bytes); - if let Some(timeout) = timeout { - req_builder = req_builder.timeout(timeout); - } - - let reqwest_response = req_builder.send().await.map_err(CallError::from)?; - - // Converts from a reqwest Response into an http::Response. - let (parts, body) = http::Response::from(reqwest_response).into_parts(); - let http_response = http::Response::from_parts(parts, SdkBody::from_body_1_x(body)); - - Ok(aws_smithy_runtime_api::http::Response::try_from(http_response) - .map_err(|err| CallError::other("failed to convert to a proper response", err))?) - }) - } -} - -impl HttpClient for Client { - fn http_connector(&self, settings: &HttpConnectorSettings, _components: &RuntimeComponents) -> SharedHttpConnector { - let connector = ReqwestConnector { - client: self.inner.clone(), - timeout: settings.read_timeout(), - }; - SharedHttpConnector::new(connector) - } -} diff --git a/crates/chat-cli/src/aws_common/mod.rs b/crates/chat-cli/src/aws_common/mod.rs deleted file mode 100644 index 4632a3bf0..000000000 --- a/crates/chat-cli/src/aws_common/mod.rs +++ /dev/null @@ -1,36 +0,0 @@ -pub mod http_client; -mod sdk_error_display; -mod user_agent_override_interceptor; - -use std::sync::LazyLock; - -use aws_smithy_runtime_api::client::behavior_version::BehaviorVersion; -use aws_types::app_name::AppName; -pub use sdk_error_display::SdkErrorDisplay; -pub use user_agent_override_interceptor::UserAgentOverrideInterceptor; - -const APP_NAME_STR: &str = "AmazonQ-For-CLI"; - -pub fn app_name() -> AppName { - static APP_NAME: LazyLock = LazyLock::new(|| AppName::new(APP_NAME_STR).expect("invalid app name")); - APP_NAME.clone() -} - -pub fn behavior_version() -> BehaviorVersion { - BehaviorVersion::v2025_08_07() -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn test_app_name() { - println!("{}", app_name()); - } - - #[test] - fn test_behavior_version() { - assert!(behavior_version() == BehaviorVersion::latest()); - } -} diff --git a/crates/chat-cli/src/aws_common/sdk_error_display.rs b/crates/chat-cli/src/aws_common/sdk_error_display.rs deleted file mode 100644 index 6bd8b544c..000000000 --- a/crates/chat-cli/src/aws_common/sdk_error_display.rs +++ /dev/null @@ -1,96 +0,0 @@ -use std::error::Error; -use std::fmt::{ - self, - Debug, - Display, -}; - -use aws_smithy_runtime_api::client::result::SdkError; - -#[derive(Debug)] -pub struct SdkErrorDisplay<'a, E, R>(pub &'a SdkError); - -impl Display for SdkErrorDisplay<'_, E, R> -where - E: Display, -{ - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - match &self.0 { - SdkError::ConstructionFailure(_) => { - write!(f, "failed to construct request") - }, - SdkError::TimeoutError(_) => write!(f, "request has timed out"), - SdkError::DispatchFailure(e) => { - write!(f, "dispatch failure")?; - if let Some(connector_error) = e.as_connector_error() { - if let Some(source) = connector_error.source() { - write!(f, " ({connector_error}): {source}")?; - } else { - write!(f, ": {connector_error}")?; - } - } - Ok(()) - }, - SdkError::ResponseError(_) => write!(f, "response error"), - SdkError::ServiceError(e) => { - write!(f, "{}", e.err()) - }, - other => write!(f, "{other}"), - } - } -} - -impl Error for SdkErrorDisplay<'_, E, R> -where - E: Error + 'static, - R: Debug, -{ - fn source(&self) -> Option<&(dyn Error + 'static)> { - self.0.source() - } -} - -#[cfg(test)] -mod tests { - use aws_smithy_runtime_api::client::result::{ - ConnectorError, - ConstructionFailure, - DispatchFailure, - ResponseError, - SdkError, - ServiceError, - TimeoutError, - }; - - use super::SdkErrorDisplay; - - #[test] - fn test_displays_sdk_error() { - let construction_failure = ConstructionFailure::builder().source("").build(); - let sdk_error: SdkError = SdkError::ConstructionFailure(construction_failure); - let sdk_error_display = SdkErrorDisplay(&sdk_error); - assert_eq!("failed to construct request", sdk_error_display.to_string()); - - let timeout_error = TimeoutError::builder().source("").build(); - let sdk_error: SdkError = SdkError::TimeoutError(timeout_error); - let sdk_error_display = SdkErrorDisplay(&sdk_error); - assert_eq!("request has timed out", sdk_error_display.to_string()); - - let dispatch_failure = DispatchFailure::builder() - .source(ConnectorError::io("".into())) - .build(); - let sdk_error: SdkError = SdkError::DispatchFailure(dispatch_failure); - let sdk_error_display = SdkErrorDisplay(&sdk_error); - assert_eq!("dispatch failure (io error): ", sdk_error_display.to_string()); - - let response_error = ResponseError::builder().source("").raw("".into()).build(); - let sdk_error: SdkError = SdkError::ResponseError(response_error); - let sdk_error_display = SdkErrorDisplay(&sdk_error); - assert_eq!("response error", sdk_error_display.to_string()); - - let service_error = ServiceError::builder().source("").raw("".into()).build(); - let sdk_error: SdkError = SdkError::ServiceError(service_error); - let sdk_error_display = SdkErrorDisplay(&sdk_error); - assert_eq!("", sdk_error_display.to_string()); - } -} diff --git a/crates/chat-cli/src/aws_common/user_agent_override_interceptor.rs b/crates/chat-cli/src/aws_common/user_agent_override_interceptor.rs deleted file mode 100644 index b8c7d2846..000000000 --- a/crates/chat-cli/src/aws_common/user_agent_override_interceptor.rs +++ /dev/null @@ -1,239 +0,0 @@ -use std::borrow::Cow; -use std::error::Error; -use std::fmt; - -use aws_runtime::user_agent::{ - AdditionalMetadata, - ApiMetadata, - AwsUserAgent, -}; -use aws_smithy_runtime_api::box_error::BoxError; -use aws_smithy_runtime_api::client::interceptors::Intercept; -use aws_smithy_runtime_api::client::interceptors::context::BeforeTransmitInterceptorContextMut; -use aws_smithy_runtime_api::client::runtime_components::RuntimeComponents; -use aws_smithy_types::config_bag::ConfigBag; -use aws_types::app_name::AppName; -use aws_types::os_shim_internal::Env; -use http::header::{ - InvalidHeaderValue, - USER_AGENT, -}; -use tracing::{ - trace, - warn, -}; - -/// The environment variable name of additional user agent metadata we include in the user agent -/// string. This is used in AWS CloudShell where they want to track usage by version. -const AWS_TOOLING_USER_AGENT: &str = "AWS_TOOLING_USER_AGENT"; - -const VERSION_HEADER: &str = "appVersion"; -const VERSION_VALUE: &str = env!("CARGO_PKG_VERSION"); - -#[derive(Debug)] -enum UserAgentOverrideInterceptorError { - MissingApiMetadata, - InvalidHeaderValue(InvalidHeaderValue), -} - -impl Error for UserAgentOverrideInterceptorError { - fn source(&self) -> Option<&(dyn Error + 'static)> { - match self { - Self::InvalidHeaderValue(source) => Some(source), - Self::MissingApiMetadata => None, - } - } -} - -impl fmt::Display for UserAgentOverrideInterceptorError { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.write_str(match self { - Self::InvalidHeaderValue(_) => "AwsUserAgent generated an invalid HTTP header value. This is a bug. Please file an issue.", - Self::MissingApiMetadata => "The UserAgentInterceptor requires ApiMetadata to be set before the request is made. This is a bug. Please file an issue.", - }) - } -} - -impl From for UserAgentOverrideInterceptorError { - fn from(err: InvalidHeaderValue) -> Self { - UserAgentOverrideInterceptorError::InvalidHeaderValue(err) - } -} -/// Generates and attaches the AWS SDK's user agent to a HTTP request -#[non_exhaustive] -#[derive(Debug, Default)] -pub struct UserAgentOverrideInterceptor { - env: Env, -} - -impl UserAgentOverrideInterceptor { - /// Creates a new `UserAgentInterceptor` - pub fn new() -> Self { - Self { env: Env::real() } - } - - #[cfg(test)] - pub fn from_env(env: Env) -> Self { - Self { env } - } -} - -impl Intercept for UserAgentOverrideInterceptor { - fn name(&self) -> &'static str { - "UserAgentOverrideInterceptor" - } - - fn modify_before_signing( - &self, - context: &mut BeforeTransmitInterceptorContextMut<'_>, - _runtime_components: &RuntimeComponents, - cfg: &mut ConfigBag, - ) -> Result<(), BoxError> { - let env = self.env.clone(); - - // Allow for overriding the user agent by an earlier interceptor (so, for example, - // tests can use `AwsUserAgent::for_tests()`) by attempting to grab one out of the - // config bag before creating one. - let ua: Cow<'_, AwsUserAgent> = match cfg.get_mut::() { - Some(ua) => { - apply_additional_metadata(&self.env, ua); - Cow::Borrowed(ua) - }, - None => { - let api_metadata = cfg - .load::() - .ok_or(UserAgentOverrideInterceptorError::MissingApiMetadata)?; - - let mut ua = AwsUserAgent::new_from_environment(self.env.clone(), api_metadata.clone()); - - let maybe_app_name = cfg.load::(); - if let Some(app_name) = maybe_app_name { - ua.set_app_name(app_name.clone()); - } - - apply_additional_metadata(&env, &mut ua); - - Cow::Owned(ua) - }, - }; - - trace!(?ua, "setting user agent"); - - let headers = context.request_mut().headers_mut(); - headers.insert(USER_AGENT.as_str(), ua.aws_ua_header()); - Ok(()) - } -} - -fn apply_additional_metadata(env: &Env, ua: &mut AwsUserAgent) { - let ver = format!("{VERSION_HEADER}/{VERSION_VALUE}"); - match AdditionalMetadata::new(clean_metadata(&ver)) { - Ok(md) => { - ua.add_additional_metadata(md); - }, - Err(err) => panic!("Failed to parse version: {err}"), - }; - - if let Ok(val) = env.get(AWS_TOOLING_USER_AGENT) { - match AdditionalMetadata::new(clean_metadata(&val)) { - Ok(md) => { - ua.add_additional_metadata(md); - }, - Err(err) => warn!(%err, %val, "Failed to parse {AWS_TOOLING_USER_AGENT}"), - }; - } -} - -fn clean_metadata(s: &str) -> String { - let valid_character = |c: char| -> bool { - match c { - _ if c.is_ascii_alphanumeric() => true, - '!' | '#' | '$' | '%' | '&' | '\'' | '*' | '+' | '-' | '.' | '^' | '_' | '`' | '|' | '~' => true, - _ => false, - } - }; - s.chars().map(|c| if valid_character(c) { c } else { '-' }).collect() -} - -#[cfg(test)] -mod tests { - use aws_smithy_runtime_api::client::interceptors::context::{ - Input, - InterceptorContext, - }; - use aws_smithy_runtime_api::client::runtime_components::RuntimeComponentsBuilder; - use aws_smithy_types::config_bag::Layer; - use http::HeaderValue; - - use super::*; - use crate::aws_common::{ - APP_NAME_STR, - app_name, - }; - - #[test] - fn error_test() { - let err = UserAgentOverrideInterceptorError::InvalidHeaderValue(HeaderValue::from_bytes(b"\0").unwrap_err()); - assert!(err.source().is_some()); - println!("{err}"); - - let err = UserAgentOverrideInterceptorError::MissingApiMetadata; - assert!(err.source().is_none()); - println!("{err}"); - } - - fn user_agent_base() -> (RuntimeComponents, ConfigBag, InterceptorContext) { - let rc = RuntimeComponentsBuilder::for_tests().build().unwrap(); - let mut cfg = ConfigBag::base(); - - let mut layer = Layer::new("layer"); - layer.store_put(ApiMetadata::new("q", "123")); - layer.store_put(app_name()); - cfg.push_layer(layer); - - let mut context = InterceptorContext::new(Input::erase(())); - context.set_request(aws_smithy_runtime_api::http::Request::empty()); - - (rc, cfg, context) - } - - #[test] - fn user_agent_override_test() { - let (rc, mut cfg, mut context) = user_agent_base(); - let mut context = BeforeTransmitInterceptorContextMut::from(&mut context); - let interceptor = UserAgentOverrideInterceptor::new(); - println!("Interceptor: {}", interceptor.name()); - interceptor - .modify_before_signing(&mut context, &rc, &mut cfg) - .expect("success"); - - let ua = context.request().headers().get(USER_AGENT).unwrap(); - println!("User-Agent: {ua}"); - assert!(ua.contains(&format!("app/{APP_NAME_STR}"))); - assert!(ua.contains(VERSION_HEADER)); - assert!(ua.contains(VERSION_VALUE)); - } - - #[test] - fn user_agent_override_cloudshell_test() { - let (rc, mut cfg, mut context) = user_agent_base(); - let mut context = BeforeTransmitInterceptorContextMut::from(&mut context); - let env = Env::from_slice(&[ - ("AWS_EXECUTION_ENV", "CloudShell"), - (AWS_TOOLING_USER_AGENT, "AWS-CloudShell/2024.08.29"), - ]); - let interceptor = UserAgentOverrideInterceptor::from_env(env); - println!("Interceptor: {}", interceptor.name()); - interceptor - .modify_before_signing(&mut context, &rc, &mut cfg) - .expect("success"); - - let ua = context.request().headers().get(USER_AGENT).unwrap(); - println!("User-Agent: {ua}"); - assert!(ua.contains(&format!("app/{APP_NAME_STR}"))); - assert!(ua.contains("exec-env/CloudShell")); - assert!(ua.contains("md/AWS-CloudShell-2024.08.29")); - assert!(ua.contains(VERSION_HEADER)); - assert!(ua.contains(VERSION_VALUE)); - } -} diff --git a/crates/chat-cli/src/cli/chat/cli/clear.rs b/crates/chat-cli/src/cli/chat/cli/clear.rs deleted file mode 100644 index d8ac04aa5..000000000 --- a/crates/chat-cli/src/cli/chat/cli/clear.rs +++ /dev/null @@ -1,65 +0,0 @@ -use clap::Args; -use crossterm::style::{ - self, - Color, - Stylize, -}; -use crossterm::{ - cursor, - execute, -}; - -use crate::cli::chat::{ - ChatError, - ChatSession, - ChatState, -}; - -#[deny(missing_docs)] -#[derive(Debug, PartialEq, Args)] -pub struct ClearArgs; - -impl ClearArgs { - pub async fn execute(self, session: &mut ChatSession) -> Result { - execute!( - session.stderr, - style::SetForegroundColor(Color::DarkGrey), - style::Print( - "\nAre you sure? This will erase the conversation history and context from hooks for the current session. " - ), - style::Print("["), - style::SetForegroundColor(Color::Green), - style::Print("y"), - style::SetForegroundColor(Color::DarkGrey), - style::Print("/"), - style::SetForegroundColor(Color::Green), - style::Print("n"), - style::SetForegroundColor(Color::DarkGrey), - style::Print("]:\n\n"), - style::SetForegroundColor(Color::Reset), - cursor::Show, - )?; - - // Setting `exit_on_single_ctrl_c` for better ux: exit the confirmation dialog rather than the CLI - let user_input = match session.read_user_input("> ".yellow().to_string().as_str(), true) { - Some(input) => input, - None => "".to_string(), - }; - - if ["y", "Y"].contains(&user_input.as_str()) { - session.conversation.clear(true); - if let Some(cm) = session.conversation.context_manager.as_mut() { - cm.hook_executor.global_cache.clear(); - cm.hook_executor.profile_cache.clear(); - } - execute!( - session.stderr, - style::SetForegroundColor(Color::Green), - style::Print("\nConversation history cleared.\n\n"), - style::SetForegroundColor(Color::Reset) - )?; - } - - Ok(ChatState::default()) - } -} diff --git a/crates/chat-cli/src/cli/chat/cli/compact.rs b/crates/chat-cli/src/cli/chat/cli/compact.rs deleted file mode 100644 index f132ef27e..000000000 --- a/crates/chat-cli/src/cli/chat/cli/compact.rs +++ /dev/null @@ -1,91 +0,0 @@ -use clap::Args; - -use crate::cli::chat::consts::MAX_USER_MESSAGE_SIZE; -use crate::cli::chat::message::UserMessageContent; -use crate::cli::chat::{ - ChatError, - ChatSession, - ChatState, -}; -use crate::os::Os; - -#[deny(missing_docs)] -#[derive(Debug, PartialEq, Args)] -#[command( - before_long_help = "/compact summarizes the conversation history to free up context space -while preserving essential information. This is useful for long-running conversations -that may eventually reach memory constraints. - -When to use -• When you see the memory constraint warning message -• When a conversation has been running for a long time -• Before starting a new topic within the same session -• After completing complex tool operations - -How it works -• Creates an AI-generated summary of your conversation -• Retains key information, code, and tool executions in the summary -• Clears the conversation history to free up space -• The assistant will reference the summary context in future responses - -Compaction will be automatically performed whenever the context window overflows. -To disable this behavior, run: `q settings chat.disableAutoCompaction true`" -)] -pub struct CompactArgs { - /// The prompt to use when generating the summary - prompt: Vec, - #[arg(long)] - show_summary: bool, - /// The number of user and assistant message pairs to exclude from the summarization. - #[arg(long)] - messages_to_exclude: Option, - /// Whether or not large messages should be truncated. - #[arg(long)] - truncate_large_messages: Option, - /// Maximum allowed size of messages in the conversation history. Requires - /// truncate_large_messages to be set. - #[arg(long, requires = "truncate_large_messages")] - max_message_length: Option, -} - -impl CompactArgs { - pub async fn execute(self, os: &Os, session: &mut ChatSession) -> Result { - let default = CompactStrategy::default(); - let prompt = if self.prompt.is_empty() { - None - } else { - Some(self.prompt.join(" ")) - }; - - session - .compact_history(os, prompt, self.show_summary, CompactStrategy { - messages_to_exclude: self.messages_to_exclude.unwrap_or(default.messages_to_exclude), - truncate_large_messages: self.truncate_large_messages.unwrap_or(default.truncate_large_messages), - max_message_length: self.max_message_length.map_or(default.max_message_length, |v| { - v.clamp(UserMessageContent::TRUNCATED_SUFFIX.len(), MAX_USER_MESSAGE_SIZE) - }), - }) - .await - } -} - -/// Parameters for performing the history compaction request. -#[derive(Debug, Copy, Clone)] -pub struct CompactStrategy { - /// Number of user/assistant pairs to exclude from the history as part of compaction. - pub messages_to_exclude: usize, - /// Whether or not to truncate large messages in the history. - pub truncate_large_messages: bool, - /// Maximum allowed size of messages in the conversation history. - pub max_message_length: usize, -} - -impl Default for CompactStrategy { - fn default() -> Self { - Self { - messages_to_exclude: Default::default(), - truncate_large_messages: Default::default(), - max_message_length: MAX_USER_MESSAGE_SIZE, - } - } -} diff --git a/crates/chat-cli/src/cli/chat/cli/context.rs b/crates/chat-cli/src/cli/chat/cli/context.rs deleted file mode 100644 index 0fedc8c01..000000000 --- a/crates/chat-cli/src/cli/chat/cli/context.rs +++ /dev/null @@ -1,445 +0,0 @@ -use std::collections::HashSet; - -use clap::Subcommand; -use crossterm::style::{ - Attribute, - Color, -}; -use crossterm::{ - execute, - style, -}; - -use crate::cli::chat::cli::hooks::{ - HookTrigger, - map_chat_error, - print_hook_section, -}; -use crate::cli::chat::consts::CONTEXT_FILES_MAX_SIZE; -use crate::cli::chat::token_counter::TokenCounter; -use crate::cli::chat::util::drop_matched_context_files; -use crate::cli::chat::{ - ChatError, - ChatSession, - ChatState, -}; -use crate::os::Os; - -#[deny(missing_docs)] -#[derive(Debug, PartialEq, Subcommand)] -#[command( - before_long_help = "Context rules determine which files are included in your Amazon Q session. -The files matched by these rules provide Amazon Q with additional information -about your project or environment. Adding relevant files helps Q generate -more accurate and helpful responses. - -Notes: -• You can add specific files or use glob patterns (e.g., \"*.py\", \"src/**/*.js\") -• Profile rules apply only to the current profile -• Global rules apply across all profiles -• Context is preserved between chat sessions" -)] -pub enum ContextSubcommand { - /// Display the context rule configuration and matched files - Show { - /// Print out each matched file's content, hook configurations, and last - /// session.conversation summary - #[arg(long)] - expand: bool, - }, - /// Add context rules (filenames or glob patterns) - Add { - /// Add to global rules (available in all profiles) - #[arg(short, long)] - global: bool, - /// Include even if matched files exceed size limits - #[arg(short, long)] - force: bool, - #[arg(required = true)] - paths: Vec, - }, - /// Remove specified rules from current profile - #[command(alias = "rm")] - Remove { - /// Remove specified rules globally - #[arg(short, long)] - global: bool, - #[arg(required = true)] - paths: Vec, - }, - /// Remove all rules from current profile - Clear { - /// Remove global rules - #[arg(short, long)] - global: bool, - }, - #[command(hide = true)] - Hooks, -} - -impl ContextSubcommand { - pub async fn execute(self, os: &Os, session: &mut ChatSession) -> Result { - let Some(context_manager) = &mut session.conversation.context_manager else { - execute!( - session.stderr, - style::SetForegroundColor(Color::Red), - style::Print("\nContext management is not available.\n\n"), - style::SetForegroundColor(Color::Reset) - )?; - - return Ok(ChatState::PromptUser { - skip_printing_tools: true, - }); - }; - - match self { - Self::Show { expand } => { - // Display global context - execute!( - session.stderr, - style::SetAttribute(Attribute::Bold), - style::SetForegroundColor(Color::Magenta), - style::Print("\n🌍 global:\n"), - style::SetAttribute(Attribute::Reset), - )?; - let mut global_context_files = HashSet::new(); - let mut profile_context_files = HashSet::new(); - if context_manager.global_config.paths.is_empty() { - execute!( - session.stderr, - style::SetForegroundColor(Color::DarkGrey), - style::Print(" \n"), - style::SetForegroundColor(Color::Reset) - )?; - } else { - for path in &context_manager.global_config.paths { - execute!(session.stderr, style::Print(format!(" {} ", path)))?; - if let Ok(context_files) = context_manager.get_context_files_by_path(os, path).await { - execute!( - session.stderr, - style::SetForegroundColor(Color::Green), - style::Print(format!( - "({} match{})", - context_files.len(), - if context_files.len() == 1 { "" } else { "es" } - )), - style::SetForegroundColor(Color::Reset) - )?; - global_context_files.extend(context_files); - } - execute!(session.stderr, style::Print("\n"))?; - } - } - - if expand { - execute!( - session.stderr, - style::SetAttribute(Attribute::Bold), - style::SetForegroundColor(Color::DarkYellow), - style::Print("\n 🔧 Hooks:\n") - )?; - print_hook_section( - &mut session.stderr, - &context_manager.global_config.hooks, - HookTrigger::ConversationStart, - ) - .map_err(map_chat_error)?; - - print_hook_section( - &mut session.stderr, - &context_manager.global_config.hooks, - HookTrigger::PerPrompt, - ) - .map_err(map_chat_error)?; - } - - // Display profile context - execute!( - session.stderr, - style::SetAttribute(Attribute::Bold), - style::SetForegroundColor(Color::Magenta), - style::Print(format!("\n👤 profile ({}):\n", context_manager.current_profile)), - style::SetAttribute(Attribute::Reset), - )?; - - if context_manager.profile_config.paths.is_empty() { - execute!( - session.stderr, - style::SetForegroundColor(Color::DarkGrey), - style::Print(" \n\n"), - style::SetForegroundColor(Color::Reset) - )?; - } else { - for path in &context_manager.profile_config.paths { - execute!(session.stderr, style::Print(format!(" {} ", path)))?; - if let Ok(context_files) = context_manager.get_context_files_by_path(os, path).await { - execute!( - session.stderr, - style::SetForegroundColor(Color::Green), - style::Print(format!( - "({} match{})", - context_files.len(), - if context_files.len() == 1 { "" } else { "es" } - )), - style::SetForegroundColor(Color::Reset) - )?; - profile_context_files.extend(context_files); - } - execute!(session.stderr, style::Print("\n"))?; - } - execute!(session.stderr, style::Print("\n"))?; - } - - if expand { - execute!( - session.stderr, - style::SetAttribute(Attribute::Bold), - style::SetForegroundColor(Color::DarkYellow), - style::Print(" 🔧 Hooks:\n") - )?; - print_hook_section( - &mut session.stderr, - &context_manager.profile_config.hooks, - HookTrigger::ConversationStart, - ) - .map_err(map_chat_error)?; - print_hook_section( - &mut session.stderr, - &context_manager.profile_config.hooks, - HookTrigger::PerPrompt, - ) - .map_err(map_chat_error)?; - execute!(session.stderr, style::Print("\n"))?; - } - - if global_context_files.is_empty() && profile_context_files.is_empty() { - execute!( - session.stderr, - style::SetForegroundColor(Color::DarkGrey), - style::Print("No files in the current directory matched the rules above.\n\n"), - style::SetForegroundColor(Color::Reset) - )?; - } else { - let total = global_context_files.len() + profile_context_files.len(); - let total_tokens = global_context_files - .iter() - .map(|(_, content)| TokenCounter::count_tokens(content)) - .sum::() - + profile_context_files - .iter() - .map(|(_, content)| TokenCounter::count_tokens(content)) - .sum::(); - execute!( - session.stderr, - style::SetForegroundColor(Color::Green), - style::SetAttribute(Attribute::Bold), - style::Print(format!( - "{} matched file{} in use:\n", - total, - if total == 1 { "" } else { "s" } - )), - style::SetForegroundColor(Color::Reset), - style::SetAttribute(Attribute::Reset) - )?; - - for (filename, content) in &global_context_files { - let est_tokens = TokenCounter::count_tokens(content); - execute!( - session.stderr, - style::Print(format!("🌍 {} ", filename)), - style::SetForegroundColor(Color::DarkGrey), - style::Print(format!("(~{} tkns)\n", est_tokens)), - style::SetForegroundColor(Color::Reset), - )?; - if expand { - execute!( - session.stderr, - style::SetForegroundColor(Color::DarkGrey), - style::Print(format!("{}\n\n", content)), - style::SetForegroundColor(Color::Reset) - )?; - } - } - - for (filename, content) in &profile_context_files { - let est_tokens = TokenCounter::count_tokens(content); - execute!( - session.stderr, - style::Print(format!("👤 {} ", filename)), - style::SetForegroundColor(Color::DarkGrey), - style::Print(format!("(~{} tkns)\n", est_tokens)), - style::SetForegroundColor(Color::Reset), - )?; - if expand { - execute!( - session.stderr, - style::SetForegroundColor(Color::DarkGrey), - style::Print(format!("{}\n\n", content)), - style::SetForegroundColor(Color::Reset) - )?; - } - } - - if expand { - execute!(session.stderr, style::Print(format!("{}\n\n", "▔".repeat(3))),)?; - } - - let mut combined_files: Vec<(String, String)> = global_context_files - .iter() - .chain(profile_context_files.iter()) - .cloned() - .collect(); - - let dropped_files = drop_matched_context_files(&mut combined_files, CONTEXT_FILES_MAX_SIZE).ok(); - - execute!( - session.stderr, - style::Print(format!("\nTotal: ~{} tokens\n\n", total_tokens)) - )?; - - if let Some(dropped_files) = dropped_files { - if !dropped_files.is_empty() { - execute!( - session.stderr, - style::SetForegroundColor(Color::DarkYellow), - style::Print(format!( - "Total token count exceeds limit: {}. The following files will be automatically dropped when interacting with Q. Consider removing them. \n\n", - CONTEXT_FILES_MAX_SIZE - )), - style::SetForegroundColor(Color::Reset) - )?; - let total_files = dropped_files.len(); - - let truncated_dropped_files = &dropped_files[..10]; - - for (filename, content) in truncated_dropped_files { - let est_tokens = TokenCounter::count_tokens(content); - execute!( - session.stderr, - style::Print(format!("{} ", filename)), - style::SetForegroundColor(Color::DarkGrey), - style::Print(format!("(~{} tkns)\n", est_tokens)), - style::SetForegroundColor(Color::Reset), - )?; - } - - if total_files > 10 { - execute!( - session.stderr, - style::Print(format!("({} more files)\n", total_files - 10)) - )?; - } - } - } - - execute!(session.stderr, style::Print("\n"))?; - } - - // Show last cached session.conversation summary if available, otherwise regenerate it - if expand { - if let Some(summary) = session.conversation.latest_summary() { - let border = "═".repeat(session.terminal_width().min(80)); - execute!( - session.stderr, - style::Print("\n"), - style::SetForegroundColor(Color::Cyan), - style::Print(&border), - style::Print("\n"), - style::SetAttribute(Attribute::Bold), - style::Print(" CONVERSATION SUMMARY"), - style::Print("\n"), - style::Print(&border), - style::SetAttribute(Attribute::Reset), - style::Print("\n\n"), - style::Print(&summary), - style::Print("\n\n\n") - )?; - } - } - }, - Self::Add { global, force, paths } => { - match context_manager.add_paths(os, paths.clone(), global, force).await { - Ok(_) => { - let target = if global { "global" } else { "profile" }; - execute!( - session.stderr, - style::SetForegroundColor(Color::Green), - style::Print(format!("\nAdded {} path(s) to {} context.\n\n", paths.len(), target)), - style::SetForegroundColor(Color::Reset) - )?; - }, - Err(e) => { - execute!( - session.stderr, - style::SetForegroundColor(Color::Red), - style::Print(format!("\nError: {}\n\n", e)), - style::SetForegroundColor(Color::Reset) - )?; - }, - } - }, - Self::Remove { global, paths } => match context_manager.remove_paths(os, paths.clone(), global).await { - Ok(_) => { - let target = if global { "global" } else { "profile" }; - execute!( - session.stderr, - style::SetForegroundColor(Color::Green), - style::Print(format!( - "\nRemoved {} path(s) from {} context.\n\n", - paths.len(), - target - )), - style::SetForegroundColor(Color::Reset) - )?; - }, - Err(e) => { - execute!( - session.stderr, - style::SetForegroundColor(Color::Red), - style::Print(format!("\nError: {}\n\n", e)), - style::SetForegroundColor(Color::Reset) - )?; - }, - }, - Self::Clear { global } => match context_manager.clear(os, global).await { - Ok(_) => { - let target = if global { - "global".to_string() - } else { - format!("profile '{}'", context_manager.current_profile) - }; - execute!( - session.stderr, - style::SetForegroundColor(Color::Green), - style::Print(format!("\nCleared context for {}\n\n", target)), - style::SetForegroundColor(Color::Reset) - )?; - }, - Err(e) => { - execute!( - session.stderr, - style::SetForegroundColor(Color::Red), - style::Print(format!("\nError: {}\n\n", e)), - style::SetForegroundColor(Color::Reset) - )?; - }, - }, - Self::Hooks => { - execute!( - session.stderr, - style::SetForegroundColor(Color::Yellow), - style::Print("The /context hooks command is deprecated. Use "), - style::SetForegroundColor(Color::Green), - style::Print("/hooks"), - style::SetForegroundColor(Color::Yellow), - style::Print(" instead.\n\n"), - style::SetForegroundColor(Color::Reset) - )?; - }, - } - - Ok(ChatState::PromptUser { - skip_printing_tools: true, - }) - } -} diff --git a/crates/chat-cli/src/cli/chat/cli/editor.rs b/crates/chat-cli/src/cli/chat/cli/editor.rs deleted file mode 100644 index c88d6db2c..000000000 --- a/crates/chat-cli/src/cli/chat/cli/editor.rs +++ /dev/null @@ -1,134 +0,0 @@ -use clap::Args; -use crossterm::execute; -use crossterm::style::{ - self, - Attribute, - Color, -}; -use uuid::Uuid; - -use crate::cli::chat::{ - ChatError, - ChatSession, - ChatState, -}; - -#[deny(missing_docs)] -#[derive(Debug, PartialEq, Args)] -pub struct EditorArgs { - pub initial_text: Vec, -} - -impl EditorArgs { - pub async fn execute(self, session: &mut ChatSession) -> Result { - let initial_text = if self.initial_text.is_empty() { - None - } else { - Some(self.initial_text.join(" ")) - }; - - let content = match open_editor(initial_text) { - Ok(content) => content, - Err(err) => { - execute!( - session.stderr, - style::SetForegroundColor(Color::Red), - style::Print(format!("\nError opening editor: {}\n\n", err)), - style::SetForegroundColor(Color::Reset) - )?; - - return Ok(ChatState::PromptUser { - skip_printing_tools: true, - }); - }, - }; - - Ok(match content.trim().is_empty() { - true => { - execute!( - session.stderr, - style::SetForegroundColor(Color::Yellow), - style::Print("\nEmpty content from editor, not submitting.\n\n"), - style::SetForegroundColor(Color::Reset) - )?; - - ChatState::PromptUser { - skip_printing_tools: true, - } - }, - false => { - execute!( - session.stderr, - style::SetForegroundColor(Color::Green), - style::Print("\nContent loaded from editor. Submitting prompt...\n\n"), - style::SetForegroundColor(Color::Reset) - )?; - - // Display the content as if the user typed it - execute!( - session.stderr, - style::SetAttribute(Attribute::Reset), - style::SetForegroundColor(Color::Magenta), - style::Print("> "), - style::SetAttribute(Attribute::Reset), - style::Print(&content), - style::Print("\n") - )?; - - // Process the content as user input - ChatState::HandleInput { input: content } - }, - }) - } -} - -/// Opens the user's preferred editor to compose a prompt -fn open_editor(initial_text: Option) -> Result { - // Create a temporary file with a unique name - let temp_dir = std::env::temp_dir(); - let file_name = format!("q_prompt_{}.md", Uuid::new_v4()); - let temp_file_path = temp_dir.join(file_name); - - // Get the editor from environment variable or use a default - let editor_cmd = std::env::var("EDITOR").unwrap_or_else(|_| "vi".to_string()); - - // Parse the editor command to handle arguments - let mut parts = - shlex::split(&editor_cmd).ok_or_else(|| ChatError::Custom("Failed to parse EDITOR command".into()))?; - - if parts.is_empty() { - return Err(ChatError::Custom("EDITOR environment variable is empty".into())); - } - - let editor_bin = parts.remove(0); - - // Write initial content to the file if provided - let initial_content = initial_text.unwrap_or_default(); - std::fs::write(&temp_file_path, &initial_content) - .map_err(|e| ChatError::Custom(format!("Failed to create temporary file: {}", e).into()))?; - - // Open the editor with the parsed command and arguments - let mut cmd = std::process::Command::new(editor_bin); - // Add any arguments that were part of the EDITOR variable - for arg in parts { - cmd.arg(arg); - } - // Add the file path as the last argument - let status = cmd - .arg(&temp_file_path) - .status() - .map_err(|e| ChatError::Custom(format!("Failed to open editor: {}", e).into()))?; - - if !status.success() { - return Err(ChatError::Custom("Editor exited with non-zero status".into())); - } - - // Read the content back - let content = std::fs::read_to_string(&temp_file_path) - .map_err(|e| ChatError::Custom(format!("Failed to read temporary file: {}", e).into()))?; - - // Clean up the temporary file - let _ = std::fs::remove_file(&temp_file_path); - - Ok(content.trim().to_string()) -} diff --git a/crates/chat-cli/src/cli/chat/cli/hooks.rs b/crates/chat-cli/src/cli/chat/cli/hooks.rs deleted file mode 100644 index b9497dfb6..000000000 --- a/crates/chat-cli/src/cli/chat/cli/hooks.rs +++ /dev/null @@ -1,1136 +0,0 @@ -use std::collections::HashMap; -use std::io::Write; -use std::process::Stdio; -use std::time::{ - Duration, - Instant, -}; - -use bstr::ByteSlice; -use clap::{ - Args, - Subcommand, -}; -use crossterm::style::{ - self, - Attribute, - Color, - Stylize, -}; -use crossterm::{ - cursor, - execute, - queue, - terminal, -}; -use eyre::{ - ErrReport, - Result, - eyre, -}; -use futures::stream::{ - FuturesUnordered, - StreamExt, -}; -use serde::{ - Deserialize, - Serialize, -}; -use spinners::{ - Spinner, - Spinners, -}; - -use crate::cli::chat::util::truncate_safe; -use crate::cli::chat::{ - ChatError, - ChatSession, - ChatState, -}; -use crate::os::Os; - -const DEFAULT_TIMEOUT_MS: u64 = 30_000; -const DEFAULT_MAX_OUTPUT_SIZE: usize = 1024 * 10; -const DEFAULT_CACHE_TTL_SECONDS: u64 = 0; - -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct Hook { - pub trigger: HookTrigger, - - pub r#type: HookType, - - #[serde(default = "Hook::default_disabled")] - pub disabled: bool, - - /// Max time the hook can run before it throws a timeout error - #[serde(default = "Hook::default_timeout_ms")] - pub timeout_ms: u64, - - /// Max output size of the hook before it is truncated - #[serde(default = "Hook::default_max_output_size")] - pub max_output_size: usize, - - /// How long the hook output is cached before it will be executed again - #[serde(default = "Hook::default_cache_ttl_seconds")] - pub cache_ttl_seconds: u64, - - // Type-specific fields - /// The bash command to execute - pub command: Option, // For inline hooks - - // Internal data - #[serde(skip)] - pub name: String, - #[serde(skip)] - pub is_global: bool, -} - -impl Hook { - pub fn new_inline_hook(trigger: HookTrigger, command: String) -> Self { - Self { - trigger, - r#type: HookType::Inline, - disabled: Self::default_disabled(), - timeout_ms: Self::default_timeout_ms(), - max_output_size: Self::default_max_output_size(), - cache_ttl_seconds: Self::default_cache_ttl_seconds(), - command: Some(command), - is_global: false, - name: "new hook".to_string(), - } - } - - fn default_disabled() -> bool { - false - } - - fn default_timeout_ms() -> u64 { - DEFAULT_TIMEOUT_MS - } - - fn default_max_output_size() -> usize { - DEFAULT_MAX_OUTPUT_SIZE - } - - fn default_cache_ttl_seconds() -> u64 { - DEFAULT_CACHE_TTL_SECONDS - } -} - -#[derive(Debug, Clone, Serialize, Deserialize, Eq, PartialEq)] -#[serde(rename_all = "lowercase")] -pub enum HookType { - // Execute an inline shell command - Inline, -} - -#[derive(Debug, Clone, Serialize, Deserialize, Eq, PartialEq)] -#[serde(rename_all = "snake_case")] -pub enum HookTrigger { - ConversationStart, - PerPrompt, -} - -#[derive(Debug, Clone)] -pub struct CachedHook { - output: String, - expiry: Option, -} - -/// Maps a hook name to a [`CachedHook`] -#[derive(Debug, Clone, Default)] -pub struct HookExecutor { - pub global_cache: HashMap, - pub profile_cache: HashMap, -} - -impl HookExecutor { - pub fn new() -> Self { - Self { - global_cache: HashMap::new(), - profile_cache: HashMap::new(), - } - } - - /// Run and cache [`Hook`]s. Any hooks that are already cached will be returned without - /// executing. Hooks that fail to execute will not be returned. - /// - /// If `updates` is `Some`, progress on hook execution will be written to it. - /// Errors encountered with write operations to `updates` are ignored. - /// - /// Note: [`HookTrigger::ConversationStart`] hooks never leave the cache. - pub async fn run_hooks( - &mut self, - hooks: Vec<&Hook>, - output: &mut impl Write, - ) -> Result, ChatError> { - let mut results = Vec::with_capacity(hooks.len()); - let mut futures = FuturesUnordered::new(); - - // Start all hook future OR fetch from cache if available - // Why enumerate? We want to return the hook results in the order of hooks that we received, - // however, for output display we want to process hooks as they complete rather than the - // order they were started in. The index will be used later to sort them back to output order. - for (index, hook) in hooks.into_iter().enumerate() { - if hook.disabled { - continue; - } - - if let Some(cached) = self.get_cache(hook) { - results.push((index, (hook.clone(), cached.clone()))); - continue; - } - let future = self.execute_hook(hook); - futures.push(async move { (index, future.await) }); - } - - // Start caching the results added after whats already their (they are from the cache already) - let start_cache_index = results.len(); - - let mut succeeded = 0; - let total = futures.len(); - - let mut spinner = None; - let spinner_text = |complete: usize, total: usize| { - format!( - "{} of {} hooks finished", - complete.to_string().blue(), - total.to_string().blue(), - ) - }; - - if total != 0 { - spinner = Some(Spinner::new(Spinners::Dots12, spinner_text(succeeded, total))); - } - - // Process results as they complete - let start_time = Instant::now(); - while let Some((index, (hook, result, duration))) = futures.next().await { - // If output is enabled, handle that first - if let Some(spinner) = spinner.as_mut() { - spinner.stop(); - - // Erase the spinner - execute!( - output, - cursor::MoveToColumn(0), - terminal::Clear(terminal::ClearType::CurrentLine), - cursor::Hide, - )?; - } - - match &result { - Ok(_) => { - queue!( - output, - style::SetForegroundColor(style::Color::Green), - style::Print("✓ "), - style::SetForegroundColor(style::Color::Blue), - style::Print(&hook.name), - style::ResetColor, - style::Print(" finished in "), - style::SetForegroundColor(style::Color::Yellow), - style::Print(format!("{:.2} s\n", duration.as_secs_f32())), - style::ResetColor, - )?; - }, - Err(e) => { - queue!( - output, - style::SetForegroundColor(style::Color::Red), - style::Print("✗ "), - style::SetForegroundColor(style::Color::Blue), - style::Print(&hook.name), - style::ResetColor, - style::Print(" failed after "), - style::SetForegroundColor(style::Color::Yellow), - style::Print(format!("{:.2} s", duration.as_secs_f32())), - style::ResetColor, - style::Print(format!(": {}\n", e)), - )?; - }, - } - - // Process results regardless of output enabled - if let Ok(output) = result { - succeeded += 1; - results.push((index, (hook.clone(), output))); - } - - // Display ending summary or add a new spinner - // The futures set size decreases each time we process one - if futures.is_empty() { - let symbol = if total == succeeded { - "✓".to_string().green() - } else { - "✗".to_string().red() - }; - - queue!( - output, - style::SetForegroundColor(Color::Blue), - style::Print(format!("{symbol} {} in ", spinner_text(succeeded, total))), - style::SetForegroundColor(style::Color::Yellow), - style::Print(format!("{:.2} s\n", start_time.elapsed().as_secs_f32())), - style::ResetColor, - )?; - } else { - spinner = Some(Spinner::new(Spinners::Dots, spinner_text(succeeded, total))); - } - } - - drop(futures); - - // Fill cache with executed results, skipping what was already from cache - results.iter().skip(start_cache_index).for_each(|(_, (hook, output))| { - let expiry = match hook.trigger { - HookTrigger::ConversationStart => None, - HookTrigger::PerPrompt => Some(Instant::now() + Duration::from_secs(hook.cache_ttl_seconds)), - }; - self.insert_cache(hook, CachedHook { - output: output.clone(), - expiry, - }); - }); - - // Return back to order at request start - results.sort_by_key(|(idx, _)| *idx); - Ok(results.into_iter().map(|(_, r)| r).collect()) - } - - async fn execute_hook<'a>(&self, hook: &'a Hook) -> (&'a Hook, Result, Duration) { - let start_time = Instant::now(); - let result = match hook.r#type { - HookType::Inline => self.execute_inline_hook(hook).await, - }; - - (hook, result, start_time.elapsed()) - } - - async fn execute_inline_hook(&self, hook: &Hook) -> Result { - let command = hook.command.as_ref().ok_or_else(|| eyre!("no command specified"))?; - - #[cfg(unix)] - let command_future = tokio::process::Command::new("bash") - .arg("-c") - .arg(command) - .stdin(Stdio::piped()) - .stdout(Stdio::piped()) - .stderr(Stdio::piped()) - .output(); - - #[cfg(windows)] - let command_future = tokio::process::Command::new("cmd") - .arg("/C") - .arg(command) - .stdin(Stdio::piped()) - .stdout(Stdio::piped()) - .stderr(Stdio::piped()) - .output(); - - let timeout = Duration::from_millis(hook.timeout_ms); - - // Run with timeout - match tokio::time::timeout(timeout, command_future).await { - Ok(result) => { - let result = result?; - if result.status.success() { - let stdout = result.stdout.to_str_lossy(); - let stdout = format!( - "{}{}", - truncate_safe(&stdout, hook.max_output_size), - if stdout.len() > hook.max_output_size { - " ... truncated" - } else { - "" - } - ); - Ok(stdout) - } else { - Err(eyre!("command returned non-zero exit code: {}", result.status)) - } - }, - Err(_) => Err(eyre!("command timed out after {} ms", timeout.as_millis())), - } - } - - /// Will return a cached hook's output if it exists and isn't expired. - fn get_cache(&self, hook: &Hook) -> Option { - let cache = if hook.is_global { - &self.global_cache - } else { - &self.profile_cache - }; - - cache.get(&hook.name).and_then(|o| { - if let Some(expiry) = o.expiry { - if Instant::now() < expiry { - Some(o.output.clone()) - } else { - None - } - } else { - Some(o.output.clone()) - } - }) - } - - fn insert_cache(&mut self, hook: &Hook, hook_output: CachedHook) { - let cache = if hook.is_global { - &mut self.global_cache - } else { - &mut self.profile_cache - }; - - cache.insert(hook.name.clone(), hook_output); - } -} - -#[deny(missing_docs)] -#[derive(Debug, PartialEq, Args)] -#[command( - before_long_help = "Use context hooks to specify shell commands to run. The output from these -commands will be appended to the prompt to Amazon Q. Hooks can be defined -in global or local profiles. - -Notes: -• Hooks are executed in parallel -• 'conversation_start' hooks run on the first user prompt and are attached once to the conversation history sent to Amazon Q -• 'per_prompt' hooks run on each user prompt and are attached to the prompt, but are not stored in conversation history" -)] -pub struct HooksArgs { - #[command(subcommand)] - subcommand: Option, -} - -impl HooksArgs { - pub async fn execute(self, os: &Os, session: &mut ChatSession) -> Result { - if let Some(subcommand) = self.subcommand { - return subcommand.execute(os, session).await; - } - - let Some(context_manager) = &mut session.conversation.context_manager else { - return Ok(ChatState::PromptUser { - skip_printing_tools: true, - }); - }; - - queue!( - session.stderr, - style::SetAttribute(Attribute::Bold), - style::SetForegroundColor(Color::Magenta), - style::Print("\n🌍 global:\n"), - style::SetAttribute(Attribute::Reset), - )?; - - print_hook_section( - &mut session.stderr, - &context_manager.global_config.hooks, - HookTrigger::ConversationStart, - ) - .map_err(map_chat_error)?; - print_hook_section( - &mut session.stderr, - &context_manager.global_config.hooks, - HookTrigger::PerPrompt, - ) - .map_err(map_chat_error)?; - - queue!( - session.stderr, - style::SetAttribute(Attribute::Bold), - style::SetForegroundColor(Color::Magenta), - style::Print(format!("\n👤 profile ({}):\n", &context_manager.current_profile)), - style::SetAttribute(Attribute::Reset), - )?; - - print_hook_section( - &mut session.stderr, - &context_manager.profile_config.hooks, - HookTrigger::ConversationStart, - ) - .map_err(map_chat_error)?; - print_hook_section( - &mut session.stderr, - &context_manager.profile_config.hooks, - HookTrigger::PerPrompt, - ) - .map_err(map_chat_error)?; - - execute!( - session.stderr, - style::Print(format!( - "\nUse {} to manage hooks.\n\n", - "/hooks help".to_string().dark_green() - )), - )?; - - Ok(ChatState::PromptUser { - skip_printing_tools: true, - }) - } -} - -#[deny(missing_docs)] -#[derive(Clone, Debug, PartialEq, Subcommand)] -pub enum HooksSubcommand { - /// Add a new command context hook - Add { - /// The name of the hook - name: String, - /// When to trigger the hook, valid options: `per_prompt` or `conversation_start` - #[arg(long, value_parser = ["per_prompt", "conversation_start"])] - trigger: String, - /// Shell command to execute - #[arg(long, value_parser = clap::value_parser!(String))] - command: String, - /// Add to global hooks - #[arg(long)] - global: bool, - }, - /// Remove an existing context hook - #[command(name = "rm")] - Remove { - /// The name of the hook - name: String, - /// Remove from global hooks - #[arg(long)] - global: bool, - }, - /// Enable an existing context hook - Enable { - /// The name of the hook - name: String, - /// Enable in global hooks - #[arg(long)] - global: bool, - }, - /// Disable an existing context hook - Disable { - /// The name of the hook - name: String, - /// Disable in global hooks - #[arg(long)] - global: bool, - }, - /// Enable all existing context hooks - EnableAll { - /// Enable all in global hooks - #[arg(long)] - global: bool, - }, - /// Disable all existing context hooks - DisableAll { - /// Disable all in global hooks - #[arg(long)] - global: bool, - }, - /// Display the context rule configuration and matched files - Show, -} - -impl HooksSubcommand { - pub async fn execute(self, os: &Os, session: &mut ChatSession) -> Result { - let Some(context_manager) = &mut session.conversation.context_manager else { - return Ok(ChatState::PromptUser { - skip_printing_tools: true, - }); - }; - - let scope = |g: bool| if g { "global" } else { "profile" }; - - match self { - Self::Add { - name, - trigger, - command, - global, - } => { - let trigger = if trigger == "conversation_start" { - HookTrigger::ConversationStart - } else { - HookTrigger::PerPrompt - }; - - let result = context_manager - .add_hook(os, name.clone(), Hook::new_inline_hook(trigger, command), global) - .await; - match result { - Ok(_) => { - execute!( - session.stderr, - style::SetForegroundColor(Color::Green), - style::Print(format!("\nAdded {} hook '{name}'.\n\n", scope(global))), - style::SetForegroundColor(Color::Reset) - )?; - }, - Err(e) => { - execute!( - session.stderr, - style::SetForegroundColor(Color::Red), - style::Print(format!("\nCannot add {} hook '{name}': {}\n\n", scope(global), e)), - style::SetForegroundColor(Color::Reset) - )?; - }, - } - }, - Self::Remove { name, global } => { - let result = context_manager.remove_hook(os, &name, global).await; - match result { - Ok(_) => { - execute!( - session.stderr, - style::SetForegroundColor(Color::Green), - style::Print(format!("\nRemoved {} hook '{name}'.\n\n", scope(global))), - style::SetForegroundColor(Color::Reset) - )?; - }, - Err(e) => { - execute!( - session.stderr, - style::SetForegroundColor(Color::Red), - style::Print(format!("\nCannot remove {} hook '{name}': {}\n\n", scope(global), e)), - style::SetForegroundColor(Color::Reset) - )?; - }, - } - }, - Self::Enable { name, global } => { - let result = context_manager.set_hook_disabled(os, &name, global, false).await; - match result { - Ok(_) => { - execute!( - session.stderr, - style::SetForegroundColor(Color::Green), - style::Print(format!("\nEnabled {} hook '{name}'.\n\n", scope(global))), - style::SetForegroundColor(Color::Reset) - )?; - }, - Err(e) => { - execute!( - session.stderr, - style::SetForegroundColor(Color::Red), - style::Print(format!("\nCannot enable {} hook '{name}': {}\n\n", scope(global), e)), - style::SetForegroundColor(Color::Reset) - )?; - }, - } - }, - Self::Disable { name, global } => { - let result = context_manager.set_hook_disabled(os, &name, global, true).await; - match result { - Ok(_) => { - execute!( - session.stderr, - style::SetForegroundColor(Color::Green), - style::Print(format!("\nDisabled {} hook '{name}'.\n\n", scope(global))), - style::SetForegroundColor(Color::Reset) - )?; - }, - Err(e) => { - execute!( - session.stderr, - style::SetForegroundColor(Color::Red), - style::Print(format!("\nCannot disable {} hook '{name}': {}\n\n", scope(global), e)), - style::SetForegroundColor(Color::Reset) - )?; - }, - } - }, - Self::EnableAll { global } => { - context_manager - .set_all_hooks_disabled(os, global, false) - .await - .map_err(map_chat_error)?; - execute!( - session.stderr, - style::SetForegroundColor(Color::Green), - style::Print(format!("\nEnabled all {} hooks.\n\n", scope(global))), - style::SetForegroundColor(Color::Reset) - )?; - }, - Self::DisableAll { global } => { - context_manager - .set_all_hooks_disabled(os, global, true) - .await - .map_err(map_chat_error)?; - execute!( - session.stderr, - style::SetForegroundColor(Color::Green), - style::Print(format!("\nDisabled all {} hooks.\n\n", scope(global))), - style::SetForegroundColor(Color::Reset) - )?; - }, - Self::Show => { - // Display global context - execute!( - session.stderr, - style::SetAttribute(Attribute::Bold), - style::SetForegroundColor(Color::Magenta), - style::Print("\n🌍 global:\n"), - style::SetAttribute(Attribute::Reset), - )?; - - print_hook_section( - &mut session.stderr, - &context_manager.global_config.hooks, - HookTrigger::ConversationStart, - ) - .map_err(map_chat_error)?; - print_hook_section( - &mut session.stderr, - &context_manager.global_config.hooks, - HookTrigger::PerPrompt, - ) - .map_err(map_chat_error)?; - - // Display profile hooks - execute!( - session.stderr, - style::SetAttribute(Attribute::Bold), - style::SetForegroundColor(Color::Magenta), - style::Print(format!("\n👤 profile ({}):\n", context_manager.current_profile)), - style::SetAttribute(Attribute::Reset), - )?; - - print_hook_section( - &mut session.stderr, - &context_manager.profile_config.hooks, - HookTrigger::ConversationStart, - ) - .map_err(map_chat_error)?; - print_hook_section( - &mut session.stderr, - &context_manager.profile_config.hooks, - HookTrigger::PerPrompt, - ) - .map_err(map_chat_error)?; - execute!(session.stderr, style::Print("\n"))?; - }, - } - - Ok(ChatState::PromptUser { - skip_printing_tools: true, - }) - } -} - -/// Prints hook configuration grouped by trigger: conversation session start or per user message -pub fn print_hook_section(output: &mut impl Write, hooks: &HashMap, trigger: HookTrigger) -> Result<()> { - let section = match trigger { - HookTrigger::ConversationStart => "On Session Start", - HookTrigger::PerPrompt => "Per User Message", - }; - let hooks: Vec<(&String, &Hook)> = hooks.iter().filter(|(_, h)| h.trigger == trigger).collect(); - - queue!( - output, - style::SetForegroundColor(Color::Cyan), - style::Print(format!(" {section}:\n")), - style::SetForegroundColor(Color::Reset), - )?; - - if hooks.is_empty() { - queue!( - output, - style::SetForegroundColor(Color::DarkGrey), - style::Print(" \n"), - style::SetForegroundColor(Color::Reset) - )?; - } else { - for (name, hook) in hooks { - if hook.disabled { - queue!( - output, - style::SetForegroundColor(Color::DarkGrey), - style::Print(format!(" {} (disabled)\n", name)), - style::SetForegroundColor(Color::Reset) - )?; - } else { - queue!(output, style::Print(format!(" {}\n", name)),)?; - } - } - } - Ok(()) -} - -pub fn map_chat_error(e: ErrReport) -> ChatError { - ChatError::Custom(e.to_string().into()) -} - -#[cfg(test)] -mod tests { - use std::time::Duration; - - use tokio::time::sleep; - - use super::*; - use crate::cli::chat::util::test::create_test_context_manager; - - #[tokio::test] - async fn test_add_hook() -> Result<()> { - let os = Os::new().await.unwrap(); - let mut manager = create_test_context_manager(None).await?; - let hook = Hook::new_inline_hook(HookTrigger::ConversationStart, "echo test".to_string()); - - // Test adding hook to profile config - manager - .add_hook(&os, "test_hook".to_string(), hook.clone(), false) - .await?; - assert!(manager.profile_config.hooks.contains_key("test_hook")); - - // Test adding hook to global config - manager - .add_hook(&os, "global_hook".to_string(), hook.clone(), true) - .await?; - assert!(manager.global_config.hooks.contains_key("global_hook")); - - // Test adding duplicate hook name - assert!( - manager - .add_hook(&os, "test_hook".to_string(), hook, false) - .await - .is_err() - ); - - Ok(()) - } - - #[tokio::test] - async fn test_remove_hook() -> Result<()> { - let os = Os::new().await.unwrap(); - let mut manager = create_test_context_manager(None).await?; - let hook = Hook::new_inline_hook(HookTrigger::ConversationStart, "echo test".to_string()); - - manager.add_hook(&os, "test_hook".to_string(), hook, false).await?; - - // Test removing existing hook - manager.remove_hook(&os, "test_hook", false).await?; - assert!(!manager.profile_config.hooks.contains_key("test_hook")); - - // Test removing non-existent hook - assert!(manager.remove_hook(&os, "test_hook", false).await.is_err()); - - Ok(()) - } - - #[tokio::test] - async fn test_set_hook_disabled() -> Result<()> { - let os = Os::new().await.unwrap(); - let mut manager = create_test_context_manager(None).await?; - let hook = Hook::new_inline_hook(HookTrigger::ConversationStart, "echo test".to_string()); - - manager.add_hook(&os, "test_hook".to_string(), hook, false).await?; - - // Test disabling hook - manager.set_hook_disabled(&os, "test_hook", false, true).await?; - assert!(manager.profile_config.hooks.get("test_hook").unwrap().disabled); - - // Test enabling hook - manager.set_hook_disabled(&os, "test_hook", false, false).await?; - assert!(!manager.profile_config.hooks.get("test_hook").unwrap().disabled); - - // Test with non-existent hook - assert!( - manager - .set_hook_disabled(&os, "nonexistent", false, true) - .await - .is_err() - ); - - Ok(()) - } - - #[tokio::test] - async fn test_set_all_hooks_disabled() -> Result<()> { - let os = Os::new().await.unwrap(); - let mut manager = create_test_context_manager(None).await?; - let hook1 = Hook::new_inline_hook(HookTrigger::ConversationStart, "echo test".to_string()); - let hook2 = Hook::new_inline_hook(HookTrigger::ConversationStart, "echo test".to_string()); - - manager.add_hook(&os, "hook1".to_string(), hook1, false).await?; - manager.add_hook(&os, "hook2".to_string(), hook2, false).await?; - - // Test disabling all hooks - manager.set_all_hooks_disabled(&os, false, true).await?; - assert!(manager.profile_config.hooks.values().all(|h| h.disabled)); - - // Test enabling all hooks - manager.set_all_hooks_disabled(&os, false, false).await?; - assert!(manager.profile_config.hooks.values().all(|h| !h.disabled)); - - Ok(()) - } - - #[tokio::test] - async fn test_run_hooks() -> Result<()> { - let os = Os::new().await.unwrap(); - let mut manager = create_test_context_manager(None).await?; - let hook1 = Hook::new_inline_hook(HookTrigger::ConversationStart, "echo test".to_string()); - let hook2 = Hook::new_inline_hook(HookTrigger::ConversationStart, "echo test".to_string()); - - manager.add_hook(&os, "hook1".to_string(), hook1, false).await?; - manager.add_hook(&os, "hook2".to_string(), hook2, false).await?; - - // Run the hooks - let results = manager - .run_hooks(HookTrigger::ConversationStart, &mut vec![]) - .await - .unwrap(); - assert_eq!(results.len(), 2); - - let results = manager.run_hooks(HookTrigger::PerPrompt, &mut vec![]).await.unwrap(); - assert_eq!(results.len(), 0); - - Ok(()) - } - - #[tokio::test] - async fn test_hooks_across_profiles() -> Result<()> { - let os = Os::new().await.unwrap(); - let mut manager = create_test_context_manager(None).await?; - let hook1 = Hook::new_inline_hook(HookTrigger::ConversationStart, "echo test".to_string()); - let hook2 = Hook::new_inline_hook(HookTrigger::ConversationStart, "echo test".to_string()); - - manager.add_hook(&os, "profile_hook".to_string(), hook1, false).await?; - manager.add_hook(&os, "global_hook".to_string(), hook2, true).await?; - - let results = manager - .run_hooks(HookTrigger::ConversationStart, &mut vec![]) - .await - .unwrap(); - assert_eq!(results.len(), 2); // Should include both hooks - - // Create and switch to a new profile - manager.create_profile(&os, "test_profile").await?; - manager.switch_profile(&os, "test_profile").await?; - - let results = manager - .run_hooks(HookTrigger::ConversationStart, &mut vec![]) - .await - .unwrap(); - assert_eq!(results.len(), 1); // Should include global hook - assert_eq!(results[0].0.name, "global_hook"); - - Ok(()) - } - - #[test] - fn test_hook_creation() { - let command = "echo 'hello'"; - let hook = Hook::new_inline_hook(HookTrigger::PerPrompt, command.to_string()); - - assert_eq!(hook.r#type, HookType::Inline); - assert!(!hook.disabled); - assert_eq!(hook.timeout_ms, DEFAULT_TIMEOUT_MS); - assert_eq!(hook.max_output_size, DEFAULT_MAX_OUTPUT_SIZE); - assert_eq!(hook.cache_ttl_seconds, DEFAULT_CACHE_TTL_SECONDS); - assert_eq!(hook.command, Some(command.to_string())); - assert_eq!(hook.trigger, HookTrigger::PerPrompt); - assert!(!hook.is_global); - } - - #[tokio::test] - async fn test_hook_executor_cached_conversation_start() { - let mut executor = HookExecutor::new(); - let mut hook1 = Hook::new_inline_hook(HookTrigger::ConversationStart, "echo 'test1'".to_string()); - hook1.is_global = true; - - let mut hook2 = Hook::new_inline_hook(HookTrigger::ConversationStart, "echo 'test2'".to_string()); - hook2.is_global = false; - - // First execution should run the command - let mut output = vec![]; - let results = executor.run_hooks(vec![&hook1, &hook2], &mut output).await.unwrap(); - - assert_eq!(results.len(), 2); - assert!(results[0].1.contains("test1")); - assert!(results[1].1.contains("test2")); - assert!(!output.is_empty()); - - // Second execution should use cache - let mut output = Vec::new(); - let results = executor.run_hooks(vec![&hook1, &hook2], &mut output).await.unwrap(); - - assert_eq!(results.len(), 2); - assert!(results[0].1.contains("test1")); - assert!(results[1].1.contains("test2")); - assert!(output.is_empty()); // Should not have run the hook, so no output. - } - - #[tokio::test] - async fn test_hook_executor_cached_per_prompt() { - let mut executor = HookExecutor::new(); - let mut hook1 = Hook::new_inline_hook(HookTrigger::PerPrompt, "echo 'test1'".to_string()); - hook1.is_global = true; - hook1.cache_ttl_seconds = 60; - - let mut hook2 = Hook::new_inline_hook(HookTrigger::PerPrompt, "echo 'test2'".to_string()); - hook2.is_global = false; - hook2.cache_ttl_seconds = 60; - - // First execution should run the command - let mut output = vec![]; - let results = executor.run_hooks(vec![&hook1, &hook2], &mut output).await.unwrap(); - - assert_eq!(results.len(), 2); - assert!(results[0].1.contains("test1")); - assert!(results[1].1.contains("test2")); - assert!(!output.is_empty()); - - // Second execution should use cache - let mut output = Vec::new(); - let results = executor.run_hooks(vec![&hook1, &hook2], &mut output).await.unwrap(); - - assert_eq!(results.len(), 2); - assert!(results[0].1.contains("test1")); - assert!(results[1].1.contains("test2")); - assert!(output.is_empty()); // Should not have run the hook, so no output. - } - - #[tokio::test] - async fn test_hook_executor_not_cached_per_prompt() { - let mut executor = HookExecutor::new(); - let mut hook1 = Hook::new_inline_hook(HookTrigger::PerPrompt, "echo 'test1'".to_string()); - hook1.is_global = true; - - let mut hook2 = Hook::new_inline_hook(HookTrigger::PerPrompt, "echo 'test2'".to_string()); - hook2.is_global = false; - - // First execution should run the command - let mut output = Vec::new(); - let results = executor.run_hooks(vec![&hook1, &hook2], &mut output).await.unwrap(); - - assert_eq!(results.len(), 2); - assert!(results[0].1.contains("test1")); - assert!(results[1].1.contains("test2")); - assert!(!output.is_empty()); - - // Second execution should use cache - let mut output = Vec::new(); - let results = executor.run_hooks(vec![&hook1, &hook2], &mut output).await.unwrap(); - - assert_eq!(results.len(), 2); - assert!(results[0].1.contains("test1")); - assert!(results[1].1.contains("test2")); - assert!(!output.is_empty()); - } - - #[tokio::test] - async fn test_hook_timeout() { - let mut executor = HookExecutor::new(); - let mut hook = Hook::new_inline_hook(HookTrigger::PerPrompt, "sleep 2".to_string()); - hook.timeout_ms = 100; // Set very short timeout - - let results = executor.run_hooks(vec![&hook], &mut vec![]).await.unwrap(); - - assert_eq!(results.len(), 0); // Should fail due to timeout - } - - #[tokio::test] - async fn test_disabled_hook() { - let mut executor = HookExecutor::new(); - let mut hook = Hook::new_inline_hook(HookTrigger::PerPrompt, "echo 'test'".to_string()); - hook.disabled = true; - - let results = executor.run_hooks(vec![&hook], &mut vec![]).await.unwrap(); - - assert_eq!(results.len(), 0); // Disabled hook should not run - } - - #[tokio::test] - async fn test_cache_expiration() { - let mut executor = HookExecutor::new(); - let mut hook = Hook::new_inline_hook(HookTrigger::PerPrompt, "echo 'test'".to_string()); - hook.cache_ttl_seconds = 1; - - // First execution - let results1 = executor.run_hooks(vec![&hook], &mut vec![]).await.unwrap(); - assert_eq!(results1.len(), 1); - - // Wait for cache to expire - sleep(Duration::from_millis(1001)).await; - - // Second execution should run command again - let results2 = executor.run_hooks(vec![&hook], &mut vec![]).await.unwrap(); - assert_eq!(results2.len(), 1); - } - - #[test] - fn test_hook_cache_storage() { - let mut executor: HookExecutor = HookExecutor::new(); - let hook = Hook::new_inline_hook(HookTrigger::PerPrompt, "".to_string()); - - let cached_hook = CachedHook { - output: "test output".to_string(), - expiry: None, - }; - - executor.insert_cache(&hook, cached_hook.clone()); - - assert_eq!(executor.get_cache(&hook), Some("test output".to_string())); - } - - #[test] - fn test_hook_cache_storage_expired() { - let mut executor: HookExecutor = HookExecutor::new(); - let hook = Hook::new_inline_hook(HookTrigger::PerPrompt, "".to_string()); - - let cached_hook = CachedHook { - output: "test output".to_string(), - expiry: Some(Instant::now()), - }; - - executor.insert_cache(&hook, cached_hook.clone()); - - // Item should not return since it is expired - assert_eq!(executor.get_cache(&hook), None); - } - - #[tokio::test] - async fn test_max_output_size() { - let mut executor = HookExecutor::new(); - - // Use different commands based on OS - #[cfg(unix)] - let command = "for i in {1..1000}; do echo $i; done"; - - #[cfg(windows)] - let command = "for /L %i in (1,1,1000) do @echo %i"; - - let mut hook = Hook::new_inline_hook(HookTrigger::PerPrompt, command.to_string()); - hook.max_output_size = 100; - - let results = executor.run_hooks(vec![&hook], &mut vec![]).await.unwrap(); - - assert!(results[0].1.len() <= hook.max_output_size + " ... truncated".len()); - } - - #[tokio::test] - async fn test_os_specific_command_execution() { - let mut executor = HookExecutor::new(); - - // Create a simple command that outputs the shell name - #[cfg(unix)] - let command = "echo $SHELL"; - - #[cfg(windows)] - let command = "echo %ComSpec%"; - - let hook = Hook::new_inline_hook(HookTrigger::PerPrompt, command.to_string()); - - let results = executor.run_hooks(vec![&hook], &mut vec![]).await.unwrap(); - - assert_eq!(results.len(), 1, "Command execution should succeed"); - - // Verify output contains expected shell information - #[cfg(unix)] - assert!(results[0].1.contains("/"), "Unix shell path should contain '/'"); - - #[cfg(windows)] - assert!( - results[0].1.to_lowercase().contains("cmd.exe") || results[0].1.to_lowercase().contains("command.com"), - "Windows shell path should contain cmd.exe or command.com" - ); - } -} diff --git a/crates/chat-cli/src/cli/chat/cli/knowledge.rs b/crates/chat-cli/src/cli/chat/cli/knowledge.rs deleted file mode 100644 index d868a14ef..000000000 --- a/crates/chat-cli/src/cli/chat/cli/knowledge.rs +++ /dev/null @@ -1,481 +0,0 @@ -use std::io::Write; - -use clap::Subcommand; -use crossterm::queue; -use crossterm::style::{ - self, - Color, -}; -use eyre::Result; -use semantic_search_client::{ - KnowledgeContext, - OperationStatus, - SystemStatus, -}; - -use crate::cli::chat::tools::sanitize_path_tool_arg; -use crate::cli::chat::{ - ChatError, - ChatSession, - ChatState, -}; -use crate::database::settings::Setting; -use crate::os::Os; -use crate::util::knowledge_store::KnowledgeStore; - -/// Knowledge base management commands -#[derive(Clone, Debug, PartialEq, Eq, Subcommand)] -pub enum KnowledgeSubcommand { - /// Display the knowledge base contents - Show, - /// Add a file or directory to knowledge base - Add { path: String }, - /// Remove specified knowledge context by path - #[command(alias = "rm")] - Remove { path: String }, - /// Update a file or directory in knowledge base - Update { path: String }, - /// Remove all knowledge contexts - Clear, - /// Show background operation status - Status, - /// Cancel a background operation - Cancel { - /// Operation ID to cancel (optional - cancels most recent if not provided) - operation_id: Option, - }, -} - -#[derive(Debug)] -enum OperationResult { - Success(String), - Info(String), - Warning(String), - Error(String), -} - -impl KnowledgeSubcommand { - pub async fn execute(self, os: &Os, session: &mut ChatSession) -> Result { - if !Self::is_feature_enabled(os) { - Self::write_feature_disabled_message(session)?; - return Ok(Self::default_chat_state()); - } - - let result = self.execute_operation(os, session).await; - - Self::write_operation_result(session, result)?; - - Ok(Self::default_chat_state()) - } - - fn is_feature_enabled(os: &Os) -> bool { - os.database - .settings - .get_bool(Setting::EnabledKnowledge) - .unwrap_or(false) - } - - fn write_feature_disabled_message(session: &mut ChatSession) -> Result<(), std::io::Error> { - queue!( - session.stderr, - style::SetForegroundColor(Color::Red), - style::Print("\nKnowledge tool is disabled. Enable it with: q settings chat.enableKnowledge true\n\n"), - style::SetForegroundColor(Color::Reset) - ) - } - - fn default_chat_state() -> ChatState { - ChatState::PromptUser { - skip_printing_tools: true, - } - } - - async fn execute_operation(&self, os: &Os, session: &mut ChatSession) -> OperationResult { - match self { - KnowledgeSubcommand::Show => { - match Self::handle_show(session).await { - Ok(_) => OperationResult::Info("".to_string()), // Empty Info, formatting already done - Err(e) => OperationResult::Error(format!("Failed to show contexts: {}", e)), - } - }, - KnowledgeSubcommand::Add { path } => Self::handle_add(os, path).await, - KnowledgeSubcommand::Remove { path } => Self::handle_remove(os, path).await, - KnowledgeSubcommand::Update { path } => Self::handle_update(os, path).await, - KnowledgeSubcommand::Clear => Self::handle_clear(session).await, - KnowledgeSubcommand::Status => Self::handle_status().await, - KnowledgeSubcommand::Cancel { operation_id } => Self::handle_cancel(operation_id.as_deref()).await, - } - } - - async fn handle_show(session: &mut ChatSession) -> Result<(), std::io::Error> { - let async_knowledge_store = KnowledgeStore::get_async_instance().await; - let store = async_knowledge_store.lock().await; - - // Use the async get_all method which is concurrent with indexing - let contexts = store.get_all().await.unwrap_or_else(|e| { - // Write error to output using queue system - let _ = queue!( - session.stderr, - style::SetForegroundColor(Color::Red), - style::Print(&format!("Error getting contexts: {}\n", e)), - style::ResetColor - ); - Vec::new() - }); - - Self::format_contexts(session, &contexts) - } - - fn format_contexts(session: &mut ChatSession, contexts: &[KnowledgeContext]) -> Result<(), std::io::Error> { - if contexts.is_empty() { - queue!( - session.stderr, - style::Print("\nNo knowledge base entries found.\n"), - style::Print("💡 Tip: If indexing is in progress, contexts may not appear until indexing completes.\n"), - style::Print(" Use 'knowledge status' to check active operations.\n\n") - )?; - } else { - queue!( - session.stderr, - style::Print("\n📚 Knowledge Base Contexts:\n"), - style::Print(format!("{}\n", "━".repeat(80))) - )?; - - for context in contexts { - Self::format_single_context(session, &context)?; - queue!(session.stderr, style::Print(format!("{}\n", "━".repeat(80))))?; - } - // Add final newline to match original formatting exactly - queue!(session.stderr, style::Print("\n"))?; - } - Ok(()) - } - - fn format_single_context(session: &mut ChatSession, context: &&KnowledgeContext) -> Result<(), std::io::Error> { - queue!( - session.stderr, - style::SetAttribute(style::Attribute::Bold), - style::SetForegroundColor(Color::Cyan), - style::Print(format!("📂 {}: ", context.id)), - style::SetForegroundColor(Color::Green), - style::Print(&context.name), - style::SetAttribute(style::Attribute::Reset), - style::Print("\n") - )?; - - queue!( - session.stderr, - style::Print(format!(" Description: {}\n", context.description)), - style::Print(format!( - " Created: {}\n", - context.created_at.format("%Y-%m-%d %H:%M:%S") - )), - style::Print(format!( - " Updated: {}\n", - context.updated_at.format("%Y-%m-%d %H:%M:%S") - )) - )?; - - if let Some(path) = &context.source_path { - queue!(session.stderr, style::Print(format!(" Source: {}\n", path)))?; - } - - queue!( - session.stderr, - style::Print(" Items: "), - style::SetForegroundColor(Color::Yellow), - style::Print(format!("{}", context.item_count)), - style::SetForegroundColor(Color::Reset), - style::Print(" | Persistent: ") - )?; - - if context.persistent { - queue!( - session.stderr, - style::SetForegroundColor(Color::Green), - style::Print("Yes"), - style::SetForegroundColor(Color::Reset), - style::Print("\n") - )?; - } else { - queue!( - session.stderr, - style::SetForegroundColor(Color::Yellow), - style::Print("No"), - style::SetForegroundColor(Color::Reset), - style::Print("\n") - )?; - } - Ok(()) - } - - /// Handle add operation - async fn handle_add(os: &Os, path: &str) -> OperationResult { - match Self::validate_and_sanitize_path(os, path) { - Ok(sanitized_path) => { - let async_knowledge_store = KnowledgeStore::get_async_instance().await; - let mut store = async_knowledge_store.lock().await; - - // Use the async add method which is fire-and-forget - match store.add(path, &sanitized_path).await { - Ok(message) => OperationResult::Info(message), - Err(e) => OperationResult::Error(format!("Failed to add to knowledge base: {}", e)), - } - }, - Err(e) => OperationResult::Error(e), - } - } - - /// Handle remove operation - async fn handle_remove(os: &Os, path: &str) -> OperationResult { - let sanitized_path = sanitize_path_tool_arg(os, path); - let async_knowledge_store = KnowledgeStore::get_async_instance().await; - let mut store = async_knowledge_store.lock().await; - - // Try path first, then name - if store.remove_by_path(&sanitized_path.to_string_lossy()).await.is_ok() { - OperationResult::Success(format!("Removed context with path '{}'", path)) - } else if store.remove_by_name(path).await.is_ok() { - OperationResult::Success(format!("Removed context with name '{}'", path)) - } else { - OperationResult::Warning(format!("Entry not found in knowledge base: {}", path)) - } - } - - /// Handle update operation - async fn handle_update(os: &Os, path: &str) -> OperationResult { - match Self::validate_and_sanitize_path(os, path) { - Ok(sanitized_path) => { - let async_knowledge_store = KnowledgeStore::get_async_instance().await; - let mut store = async_knowledge_store.lock().await; - - match store.update_by_path(&sanitized_path).await { - Ok(message) => OperationResult::Info(message), - Err(e) => OperationResult::Error(format!("Failed to update: {}", e)), - } - }, - Err(e) => OperationResult::Error(e), - } - } - - /// Handle clear operation - async fn handle_clear(session: &mut ChatSession) -> OperationResult { - // Require confirmation - queue!( - session.stderr, - style::Print("⚠️ This will remove ALL knowledge base entries. Are you sure? (y/N): ") - ) - .unwrap(); - session.stderr.flush().unwrap(); - - let mut input = String::new(); - if std::io::stdin().read_line(&mut input).is_err() { - return OperationResult::Error("Failed to read input".to_string()); - } - - let input = input.trim().to_lowercase(); - if input != "y" && input != "yes" { - return OperationResult::Info("Clear operation cancelled".to_string()); - } - - let async_knowledge_store = KnowledgeStore::get_async_instance().await; - let mut store = async_knowledge_store.lock().await; - - // First, cancel any pending operations - queue!( - session.stderr, - style::Print("🛑 Cancelling any pending operations...\n") - ) - .unwrap(); - if let Err(e) = store.cancel_operation(None).await { - queue!( - session.stderr, - style::Print(&format!("⚠️ Warning: Failed to cancel operations: {}\n", e)) - ) - .unwrap(); - } - - // Now perform immediate synchronous clear - queue!( - session.stderr, - style::Print("🗑️ Clearing all knowledge base entries...\n") - ) - .unwrap(); - match store.clear_immediate().await { - Ok(message) => OperationResult::Success(message), - Err(e) => OperationResult::Error(format!("Failed to clear: {}", e)), - } - } - - /// Handle status operation - async fn handle_status() -> OperationResult { - let async_knowledge_store = KnowledgeStore::get_async_instance().await; - let store = async_knowledge_store.lock().await; - - match store.get_status_data().await { - Ok(status_data) => { - let formatted_status = Self::format_status_display(&status_data); - OperationResult::Info(formatted_status) - }, - Err(e) => OperationResult::Error(format!("Failed to get status: {}", e)), - } - } - - /// Format status data for display (UI rendering responsibility) - fn format_status_display(status: &SystemStatus) -> String { - let mut status_lines = Vec::new(); - - // Show context summary - status_lines.push(format!( - "📚 Total contexts: {} ({} persistent, {} volatile)", - status.total_contexts, status.persistent_contexts, status.volatile_contexts - )); - - if status.operations.is_empty() { - status_lines.push("✅ No active operations".to_string()); - return status_lines.join("\n"); - } - - status_lines.push("📊 Active Operations:".to_string()); - status_lines.push(format!( - " 📈 Queue Status: {} active, {} waiting (max {} concurrent)", - status.active_count, status.waiting_count, status.max_concurrent - )); - - for op in &status.operations { - let formatted_operation = Self::format_operation_display(op); - status_lines.push(formatted_operation); - } - - status_lines.join("\n") - } - - /// Format a single operation for display - fn format_operation_display(op: &OperationStatus) -> String { - let elapsed = op.started_at.elapsed().unwrap_or_default(); - - let (status_icon, status_info) = if op.is_cancelled { - ("🛑", "Cancelled".to_string()) - } else if op.is_failed { - ("❌", op.message.clone()) - } else if op.is_waiting { - ("⏳", op.message.clone()) - } else if Self::should_show_progress_bar(op.current, op.total) { - ("🔄", Self::create_progress_bar(op.current, op.total, &op.message)) - } else { - ("🔄", op.message.clone()) - }; - - let operation_desc = op.operation_type.display_name(); - - // Format with conditional elapsed time and ETA - if op.is_cancelled || op.is_failed { - format!( - " {} {} | {}\n {}", - status_icon, op.short_id, operation_desc, status_info - ) - } else { - let mut time_info = format!("Elapsed: {}s", elapsed.as_secs()); - - if let Some(eta) = op.eta { - time_info.push_str(&format!(" | ETA: {}s", eta.as_secs())); - } - - format!( - " {} {} | {}\n {} | {}", - status_icon, op.short_id, operation_desc, status_info, time_info - ) - } - } - - /// Check if progress bar should be shown - fn should_show_progress_bar(current: u64, total: u64) -> bool { - total > 0 && current <= total - } - - /// Create progress bar display - fn create_progress_bar(current: u64, total: u64, message: &str) -> String { - if total == 0 { - return message.to_string(); - } - - let percentage = (current as f64 / total as f64 * 100.0) as u8; - let filled = (current as f64 / total as f64 * 30.0) as usize; - let empty = 30 - filled; - - let mut bar = String::new(); - bar.push_str(&"█".repeat(filled)); - if filled < 30 && current < total { - bar.push('▓'); - bar.push_str(&"░".repeat(empty.saturating_sub(1))); - } else { - bar.push_str(&"░".repeat(empty)); - } - - format!("{} {}% ({}/{}) {}", bar, percentage, current, total, message) - } - - /// Handle cancel operation - async fn handle_cancel(operation_id: Option<&str>) -> OperationResult { - let async_knowledge_store = KnowledgeStore::get_async_instance().await; - let mut store = async_knowledge_store.lock().await; - - match store.cancel_operation(operation_id).await { - Ok(result) => OperationResult::Success(result), - Err(e) => OperationResult::Error(format!("Failed to cancel operation: {}", e)), - } - } - - /// Validate and sanitize path - fn validate_and_sanitize_path(os: &Os, path: &str) -> Result { - if path.contains('\n') { - return Ok(path.to_string()); - } - - let os_path = sanitize_path_tool_arg(os, path); - if !os_path.exists() { - return Err(format!("Path '{}' does not exist", path)); - } - - Ok(os_path.to_string_lossy().to_string()) - } - - fn write_operation_result(session: &mut ChatSession, result: OperationResult) -> Result<(), std::io::Error> { - match result { - OperationResult::Success(msg) => { - queue!( - session.stderr, - style::SetForegroundColor(Color::Green), - style::Print(format!("\n{}\n\n", msg)), - style::SetForegroundColor(Color::Reset) - ) - }, - OperationResult::Info(msg) => { - if !msg.trim().is_empty() { - queue!( - session.stderr, - style::Print(format!("\n{}\n\n", msg)), - style::SetForegroundColor(Color::Reset) - )?; - } - Ok(()) - }, - OperationResult::Warning(msg) => { - queue!( - session.stderr, - style::SetForegroundColor(Color::Yellow), - style::Print(format!("\n{}\n\n", msg)), - style::SetForegroundColor(Color::Reset) - ) - }, - OperationResult::Error(msg) => { - queue!( - session.stderr, - style::SetForegroundColor(Color::Red), - style::Print(format!("\nError: {}\n\n", msg)), - style::SetForegroundColor(Color::Reset) - ) - }, - } - } -} diff --git a/crates/chat-cli/src/cli/chat/cli/mcp.rs b/crates/chat-cli/src/cli/chat/cli/mcp.rs deleted file mode 100644 index e653ddca7..000000000 --- a/crates/chat-cli/src/cli/chat/cli/mcp.rs +++ /dev/null @@ -1,70 +0,0 @@ -use std::io::Write; - -use clap::Args; -use crossterm::{ - queue, - style, -}; - -use crate::cli::chat::tool_manager::LoadingRecord; -use crate::cli::chat::{ - ChatError, - ChatSession, - ChatState, -}; - -#[deny(missing_docs)] -#[derive(Debug, PartialEq, Args)] -pub struct McpArgs; - -impl McpArgs { - pub async fn execute(self, session: &mut ChatSession) -> Result { - let terminal_width = session.terminal_width(); - let still_loading = session - .conversation - .tool_manager - .pending_clients() - .await - .into_iter() - .map(|name| format!(" - {name}\n")) - .collect::>() - .join(""); - - for (server_name, msg) in session.conversation.tool_manager.mcp_load_record.lock().await.iter() { - let msg = msg - .iter() - .map(|record| match record { - LoadingRecord::Err(content) | LoadingRecord::Warn(content) | LoadingRecord::Success(content) => { - content.clone() - }, - }) - .collect::>() - .join("\n--- tools refreshed ---\n"); - - queue!( - session.stderr, - style::Print(server_name), - style::Print("\n"), - style::Print(format!("{}\n", "▔".repeat(terminal_width))), - style::Print(msg), - style::Print("\n") - )?; - } - - if !still_loading.is_empty() { - queue!( - session.stderr, - style::Print("Still loading:\n"), - style::Print(format!("{}\n", "▔".repeat(terminal_width))), - style::Print(still_loading), - style::Print("\n") - )?; - } - - session.stderr.flush()?; - - Ok(ChatState::PromptUser { - skip_printing_tools: true, - }) - } -} diff --git a/crates/chat-cli/src/cli/chat/cli/mod.rs b/crates/chat-cli/src/cli/chat/cli/mod.rs deleted file mode 100644 index 7df327d3f..000000000 --- a/crates/chat-cli/src/cli/chat/cli/mod.rs +++ /dev/null @@ -1,125 +0,0 @@ -pub mod clear; -pub mod compact; -pub mod context; -pub mod editor; -pub mod hooks; -pub mod knowledge; -pub mod mcp; -pub mod model; -pub mod persist; -pub mod profile; -pub mod prompts; -pub mod subscribe; -pub mod tools; -pub mod usage; - -use clap::Parser; -use clear::ClearArgs; -use compact::CompactArgs; -use context::ContextSubcommand; -use editor::EditorArgs; -use hooks::HooksArgs; -use knowledge::KnowledgeSubcommand; -use mcp::McpArgs; -use model::ModelArgs; -use persist::PersistSubcommand; -use profile::ProfileSubcommand; -use prompts::PromptsArgs; -use tools::ToolsArgs; - -use crate::cli::chat::cli::subscribe::SubscribeArgs; -use crate::cli::chat::cli::usage::UsageArgs; -use crate::cli::chat::{ - ChatError, - ChatSession, - ChatState, - EXTRA_HELP, -}; -use crate::cli::issue; -use crate::os::Os; - -/// q (Amazon Q Chat) -#[derive(Debug, PartialEq, Parser)] -#[command(color = clap::ColorChoice::Always, term_width = 0, after_long_help = EXTRA_HELP)] -pub enum SlashCommand { - /// Quit the application - #[command(aliases = ["q", "exit"])] - Quit, - /// Clear the conversation history - Clear(ClearArgs), - /// Manage profiles - #[command(subcommand)] - Profile(ProfileSubcommand), - /// Manage context files for the chat session - #[command(subcommand)] - Context(ContextSubcommand), - /// (Beta) Manage knowledge base for persistent context storage. Requires "q settings - /// chat.enableKnowledge true" - #[command(subcommand, hide = true)] - Knowledge(KnowledgeSubcommand), - /// Open $EDITOR (defaults to vi) to compose a prompt - #[command(name = "editor")] - PromptEditor(EditorArgs), - /// Summarize the conversation to free up context space - Compact(CompactArgs), - /// View and manage tools and permissions - Tools(ToolsArgs), - /// Create a new Github issue or make a feature request - Issue(issue::IssueArgs), - /// View and retrieve prompts - Prompts(PromptsArgs), - /// View and manage context hooks - Hooks(HooksArgs), - /// Show current session's context window usage - Usage(UsageArgs), - /// See mcp server loaded - Mcp(McpArgs), - /// Select a model for the current conversation session - Model(ModelArgs), - /// Upgrade to a Q Developer Pro subscription for increased query limits - Subscribe(SubscribeArgs), - #[command(flatten)] - Persist(PersistSubcommand), - // #[command(flatten)] - // Root(RootSubcommand), -} - -impl SlashCommand { - pub async fn execute(self, os: &mut Os, session: &mut ChatSession) -> Result { - match self { - Self::Quit => Ok(ChatState::Exit), - Self::Clear(args) => args.execute(session).await, - Self::Profile(subcommand) => subcommand.execute(os, session).await, - Self::Context(args) => args.execute(os, session).await, - Self::Knowledge(subcommand) => subcommand.execute(os, session).await, - Self::PromptEditor(args) => args.execute(session).await, - Self::Compact(args) => args.execute(os, session).await, - Self::Tools(args) => args.execute(session).await, - Self::Issue(args) => { - if let Err(err) = args.execute(os).await { - return Err(ChatError::Custom(err.to_string().into())); - } - - Ok(ChatState::PromptUser { - skip_printing_tools: true, - }) - }, - Self::Prompts(args) => args.execute(session).await, - Self::Hooks(args) => args.execute(os, session).await, - Self::Usage(args) => args.execute(os, session).await, - Self::Mcp(args) => args.execute(session).await, - Self::Model(args) => args.execute(session).await, - Self::Subscribe(args) => args.execute(os, session).await, - Self::Persist(subcommand) => subcommand.execute(os, session).await, - // Self::Root(subcommand) => { - // if let Err(err) = subcommand.execute(os, database, telemetry).await { - // return Err(ChatError::Custom(err.to_string().into())); - // } - // - // Ok(ChatState::PromptUser { - // skip_printing_tools: true, - // }) - // }, - } - } -} diff --git a/crates/chat-cli/src/cli/chat/cli/model.rs b/crates/chat-cli/src/cli/chat/cli/model.rs deleted file mode 100644 index 5ce6015d3..000000000 --- a/crates/chat-cli/src/cli/chat/cli/model.rs +++ /dev/null @@ -1,130 +0,0 @@ -use clap::Args; -use crossterm::style::{ - self, - Color, -}; -use crossterm::{ - execute, - queue, -}; -use dialoguer::Select; - -use crate::auth::builder_id::{ - BuilderIdToken, - TokenType, -}; -use crate::cli::chat::{ - ChatError, - ChatSession, - ChatState, -}; -use crate::os::Os; - -pub struct ModelOption { - pub name: &'static str, - pub model_id: &'static str, -} - -pub const MODEL_OPTIONS: [ModelOption; 3] = [ - ModelOption { - name: "claude-4-sonnet", - model_id: "CLAUDE_SONNET_4_20250514_V1_0", - }, - ModelOption { - name: "claude-3.7-sonnet", - model_id: "CLAUDE_3_7_SONNET_20250219_V1_0", - }, - ModelOption { - name: "claude-3.5-sonnet", - model_id: "CLAUDE_3_5_SONNET_20241022_V2_0", - }, -]; - -#[deny(missing_docs)] -#[derive(Debug, PartialEq, Args)] -pub struct ModelArgs; - -impl ModelArgs { - pub async fn execute(self, session: &mut ChatSession) -> Result { - Ok(select_model(session)?.unwrap_or(ChatState::PromptUser { - skip_printing_tools: false, - })) - } -} - -pub fn select_model(session: &mut ChatSession) -> Result, ChatError> { - queue!(session.stderr, style::Print("\n"))?; - let active_model_id = session.conversation.model.as_deref(); - let labels: Vec = MODEL_OPTIONS - .iter() - .map(|opt| { - if (opt.model_id.is_empty() && active_model_id.is_none()) || Some(opt.model_id) == active_model_id { - format!("{} (active)", opt.name) - } else { - opt.name.to_owned() - } - }) - .collect(); - - let selection: Option<_> = match Select::with_theme(&crate::util::dialoguer_theme()) - .with_prompt("Select a model for this chat session") - .items(&labels) - .default(0) - .interact_on_opt(&dialoguer::console::Term::stdout()) - { - Ok(sel) => { - let _ = crossterm::execute!( - std::io::stdout(), - crossterm::style::SetForegroundColor(crossterm::style::Color::Magenta) - ); - sel - }, - // Ctrl‑C -> Err(Interrupted) - Err(dialoguer::Error::IO(ref e)) if e.kind() == std::io::ErrorKind::Interrupted => return Ok(None), - Err(e) => return Err(ChatError::Custom(format!("Failed to choose model: {e}").into())), - }; - - queue!(session.stderr, style::ResetColor)?; - - if let Some(index) = selection { - let selected = &MODEL_OPTIONS[index]; - let model_id_str = selected.model_id.to_string(); - session.conversation.model = Some(model_id_str); - - queue!( - session.stderr, - style::Print("\n"), - style::Print(format!(" Using {}\n\n", selected.name)), - style::ResetColor, - style::SetForegroundColor(Color::Reset), - style::SetBackgroundColor(Color::Reset), - )?; - } - - execute!(session.stderr, style::ResetColor)?; - - Ok(Some(ChatState::PromptUser { - skip_printing_tools: false, - })) -} - -/// Returns Claude 3.7 for: Amazon IDC users, FRA region users -/// Returns Claude 4.0 for: Builder ID users, other regions -pub async fn default_model_id(os: &Os) -> &'static str { - // Check FRA region first - if let Ok(Some(profile)) = os.database.get_auth_profile() { - if profile.arn.split(':').nth(3) == Some("eu-central-1") { - return "CLAUDE_3_7_SONNET_20250219_V1_0"; - } - } - - // Check if Amazon IDC user - if let Ok(Some(token)) = BuilderIdToken::load(&os.database).await { - if matches!(token.token_type(), TokenType::IamIdentityCenter) && token.is_amzn_user() { - return "CLAUDE_3_7_SONNET_20250219_V1_0"; - } - } - - // Default to 4.0 - "CLAUDE_SONNET_4_20250514_V1_0" -} diff --git a/crates/chat-cli/src/cli/chat/cli/persist.rs b/crates/chat-cli/src/cli/chat/cli/persist.rs deleted file mode 100644 index 3b8c8d0a9..000000000 --- a/crates/chat-cli/src/cli/chat/cli/persist.rs +++ /dev/null @@ -1,114 +0,0 @@ -use clap::Subcommand; -use crossterm::execute; -use crossterm::style::{ - self, - Attribute, - Color, -}; - -use crate::cli::ConversationState; -use crate::cli::chat::{ - ChatError, - ChatSession, - ChatState, -}; -use crate::os::Os; - -#[deny(missing_docs)] -#[derive(Debug, PartialEq, Subcommand)] -pub enum PersistSubcommand { - /// Save the current conversation - Save { - path: String, - #[arg(short, long)] - force: bool, - }, - /// Load a previous conversation - Load { path: String }, -} - -impl PersistSubcommand { - pub async fn execute(self, os: &Os, session: &mut ChatSession) -> Result { - macro_rules! tri { - ($v:expr, $name:expr, $path:expr) => { - match $v { - Ok(v) => v, - Err(err) => { - execute!( - session.stderr, - style::SetForegroundColor(Color::Red), - style::Print(format!("\nFailed to {} {}: {}\n\n", $name, $path, &err)), - style::SetAttribute(Attribute::Reset) - )?; - - return Ok(ChatState::PromptUser { - skip_printing_tools: true, - }); - }, - } - }; - } - - match self { - Self::Save { path, force } => { - let contents = tri!(serde_json::to_string_pretty(&session.conversation), "export to", &path); - if os.fs.exists(&path) && !force { - execute!( - session.stderr, - style::SetForegroundColor(Color::Red), - style::Print(format!( - "\nFile at {} already exists. To overwrite, use -f or --force\n\n", - &path - )), - style::SetAttribute(Attribute::Reset) - )?; - return Ok(ChatState::PromptUser { - skip_printing_tools: true, - }); - } - tri!(os.fs.write(&path, contents).await, "export to", &path); - - execute!( - session.stderr, - style::SetForegroundColor(Color::Green), - style::Print(format!("\n✔ Exported conversation state to {}\n\n", &path)), - style::SetAttribute(Attribute::Reset) - )?; - }, - Self::Load { path } => { - // Try the original path first - let original_result = os.fs.read_to_string(&path).await; - - // If the original path fails and doesn't end with .json, try with .json appended - let contents = if original_result.is_err() && !path.ends_with(".json") { - let json_path = format!("{}.json", path); - match os.fs.read_to_string(&json_path).await { - Ok(content) => content, - Err(_) => { - // If both paths fail, return the original error for better user experience - tri!(original_result, "import from", &path) - }, - } - } else { - tri!(original_result, "import from", &path) - }; - - let mut new_state: ConversationState = tri!(serde_json::from_str(&contents), "import from", &path); - new_state.reload_serialized_state(os).await; - std::mem::swap(&mut new_state.tool_manager, &mut session.conversation.tool_manager); - session.conversation = new_state; - - execute!( - session.stderr, - style::SetForegroundColor(Color::Green), - style::Print(format!("\n✔ Imported conversation state from {}\n\n", &path)), - style::SetAttribute(Attribute::Reset) - )?; - }, - } - - Ok(ChatState::PromptUser { - skip_printing_tools: true, - }) - } -} diff --git a/crates/chat-cli/src/cli/chat/cli/profile.rs b/crates/chat-cli/src/cli/chat/cli/profile.rs deleted file mode 100644 index a963e0d6d..000000000 --- a/crates/chat-cli/src/cli/chat/cli/profile.rs +++ /dev/null @@ -1,153 +0,0 @@ -use clap::Subcommand; -use crossterm::execute; -use crossterm::style::{ - self, - Color, -}; -use tracing::warn; - -use crate::cli::chat::{ - ChatError, - ChatSession, - ChatState, -}; -use crate::os::Os; - -#[deny(missing_docs)] -#[derive(Debug, PartialEq, Subcommand)] -#[command( - before_long_help = "Profiles allow you to organize and manage different sets of context files for different projects or tasks. - -Notes -• The \"global\" profile contains context files that are available in all profiles -• The \"default\" profile is used when no profile is specified -• You can switch between profiles to work on different projects -• Each profile maintains its own set of context files" -)] -pub enum ProfileSubcommand { - /// List all available profiles - List, - /// Create a new profile with the specified name - Create { name: String }, - /// Delete the specified profile - Delete { name: String }, - /// Switch to the specified profile - Set { name: String }, - /// Rename a profile - Rename { old_name: String, new_name: String }, -} - -impl ProfileSubcommand { - pub async fn execute(self, os: &Os, session: &mut ChatSession) -> Result { - let Some(context_manager) = &mut session.conversation.context_manager else { - return Ok(ChatState::PromptUser { - skip_printing_tools: true, - }); - }; - - macro_rules! print_err { - ($err:expr) => { - execute!( - session.stderr, - style::SetForegroundColor(Color::Red), - style::Print(format!("\nError: {}\n\n", $err)), - style::SetForegroundColor(Color::Reset) - )? - }; - } - - match self { - Self::List => { - let profiles = match context_manager.list_profiles(os).await { - Ok(profiles) => profiles, - Err(e) => { - execute!( - session.stderr, - style::SetForegroundColor(Color::Red), - style::Print(format!("\nError listing profiles: {}\n\n", e)), - style::SetForegroundColor(Color::Reset) - )?; - vec![] - }, - }; - - execute!(session.stderr, style::Print("\n"))?; - for profile in profiles { - if profile == context_manager.current_profile { - execute!( - session.stderr, - style::SetForegroundColor(Color::Green), - style::Print("* "), - style::Print(&profile), - style::SetForegroundColor(Color::Reset), - style::Print("\n") - )?; - } else { - execute!( - session.stderr, - style::Print(" "), - style::Print(&profile), - style::Print("\n") - )?; - } - } - execute!(session.stderr, style::Print("\n"))?; - }, - Self::Create { name } => match context_manager.create_profile(os, &name).await { - Ok(_) => { - execute!( - session.stderr, - style::SetForegroundColor(Color::Green), - style::Print(format!("\nCreated profile: {}\n\n", name)), - style::SetForegroundColor(Color::Reset) - )?; - context_manager - .switch_profile(os, &name) - .await - .map_err(|e| warn!(?e, "failed to switch to newly created profile")) - .ok(); - }, - Err(e) => print_err!(e), - }, - Self::Delete { name } => match context_manager.delete_profile(os, &name).await { - Ok(_) => { - execute!( - session.stderr, - style::SetForegroundColor(Color::Green), - style::Print(format!("\nDeleted profile: {}\n\n", name)), - style::SetForegroundColor(Color::Reset) - )?; - }, - Err(e) => print_err!(e), - }, - Self::Set { name } => match context_manager.switch_profile(os, &name).await { - Ok(_) => { - execute!( - session.stderr, - style::SetForegroundColor(Color::Green), - style::Print(format!("\nSwitched to profile: {}\n\n", name)), - style::SetForegroundColor(Color::Reset) - )?; - }, - Err(e) => print_err!(e), - }, - Self::Rename { old_name, new_name } => { - match context_manager.rename_profile(os, &old_name, &new_name).await { - Ok(_) => { - execute!( - session.stderr, - style::SetForegroundColor(Color::Green), - style::Print(format!("\nRenamed profile: {} -> {}\n\n", old_name, new_name)), - style::SetForegroundColor(Color::Reset) - )?; - }, - Err(e) => print_err!(e), - } - }, - } - - Ok(ChatState::PromptUser { - skip_printing_tools: true, - }) - } -} diff --git a/crates/chat-cli/src/cli/chat/cli/prompts.rs b/crates/chat-cli/src/cli/chat/cli/prompts.rs deleted file mode 100644 index e61cc7966..000000000 --- a/crates/chat-cli/src/cli/chat/cli/prompts.rs +++ /dev/null @@ -1,308 +0,0 @@ -use std::collections::{ - HashMap, - VecDeque, -}; - -use clap::{ - Args, - Subcommand, -}; -use crossterm::style::{ - self, - Attribute, - Color, -}; -use crossterm::{ - execute, - queue, -}; -use thiserror::Error; -use unicode_width::UnicodeWidthStr; - -use crate::cli::chat::error_formatter::format_mcp_error; -use crate::cli::chat::tool_manager::PromptBundle; -use crate::cli::chat::{ - ChatError, - ChatSession, - ChatState, -}; -use crate::mcp_client::PromptGetResult; - -#[derive(Debug, Error)] -pub enum GetPromptError { - #[error("Prompt with name {0} does not exist")] - PromptNotFound(String), - #[error("Prompt {0} is offered by more than one server. Use one of the following {1}")] - AmbiguousPrompt(String, String), - #[error("Missing client")] - MissingClient, - #[error("Missing prompt name")] - MissingPromptName, - #[error("Synchronization error: {0}")] - Synchronization(String), - #[error("Missing prompt bundle")] - MissingPromptInfo, - #[error(transparent)] - General(#[from] eyre::Report), -} - -#[deny(missing_docs)] -#[derive(Debug, PartialEq, Args)] -#[command(color = clap::ColorChoice::Always, - before_long_help = color_print::cstr!{"Prompts are reusable templates that help you quickly access common workflows and tasks. -These templates are provided by the mcp servers you have installed and configured. - -To actually retrieve a prompt, directly start with the following command (without prepending /prompt get): - @<> [arg] Retrieve prompt specified -Or if you prefer the long way: - /prompts get <> [arg] Retrieve prompt specified" -})] -pub struct PromptsArgs { - #[command(subcommand)] - subcommand: Option, -} - -impl PromptsArgs { - pub async fn execute(self, session: &mut ChatSession) -> Result { - let search_word = match &self.subcommand { - Some(PromptsSubcommand::List { search_word }) => search_word.clone(), - _ => None, - }; - - if let Some(subcommand) = self.subcommand { - if matches!(subcommand, PromptsSubcommand::Get { .. }) { - return subcommand.execute(session).await; - } - } - - let terminal_width = session.terminal_width(); - let mut prompts_wl = session.conversation.tool_manager.prompts.write().map_err(|e| { - ChatError::Custom(format!("Poison error encountered while retrieving prompts: {}", e).into()) - })?; - session.conversation.tool_manager.refresh_prompts(&mut prompts_wl)?; - let mut longest_name = ""; - let arg_pos = { - let optimal_case = UnicodeWidthStr::width(longest_name) + terminal_width / 4; - if optimal_case > terminal_width { - terminal_width / 3 - } else { - optimal_case - } - }; - // Add usage guidance at the top - queue!( - session.stderr, - style::Print("\n"), - style::SetAttribute(Attribute::Bold), - style::Print("Usage: "), - style::SetAttribute(Attribute::Reset), - style::Print("You can use a prompt by typing "), - style::SetAttribute(Attribute::Bold), - style::SetForegroundColor(Color::Green), - style::Print("'@ [...args]'"), - style::SetForegroundColor(Color::Reset), - style::SetAttribute(Attribute::Reset), - style::Print("\n\n"), - )?; - queue!( - session.stderr, - style::Print("\n"), - style::SetAttribute(Attribute::Bold), - style::Print("Prompt"), - style::SetAttribute(Attribute::Reset), - style::Print({ - let name_width = UnicodeWidthStr::width("Prompt"); - let padding = arg_pos.saturating_sub(name_width); - " ".repeat(padding) - }), - style::SetAttribute(Attribute::Bold), - style::Print("Arguments (* = required)"), - style::SetAttribute(Attribute::Reset), - style::Print("\n"), - style::Print(format!("{}\n", "▔".repeat(terminal_width))), - )?; - let mut prompts_by_server: Vec<_> = prompts_wl - .iter() - .fold( - HashMap::<&String, Vec<&PromptBundle>>::new(), - |mut acc, (prompt_name, bundles)| { - if prompt_name.contains(search_word.as_deref().unwrap_or("")) { - if prompt_name.len() > longest_name.len() { - longest_name = prompt_name.as_str(); - } - for bundle in bundles { - acc.entry(&bundle.server_name) - .and_modify(|b| b.push(bundle)) - .or_insert(vec![bundle]); - } - } - acc - }, - ) - .into_iter() - .collect(); - prompts_by_server.sort_by_key(|(server_name, _)| server_name.as_str()); - - for (i, (server_name, bundles)) in prompts_by_server.iter_mut().enumerate() { - bundles.sort_by_key(|bundle| &bundle.prompt_get.name); - - if i > 0 { - queue!(session.stderr, style::Print("\n"))?; - } - queue!( - session.stderr, - style::SetAttribute(Attribute::Bold), - style::Print(server_name), - style::Print(" (MCP):"), - style::SetAttribute(Attribute::Reset), - style::Print("\n"), - )?; - for bundle in bundles { - queue!( - session.stderr, - style::Print("- "), - style::Print(&bundle.prompt_get.name), - style::Print({ - if bundle - .prompt_get - .arguments - .as_ref() - .is_some_and(|args| !args.is_empty()) - { - let name_width = UnicodeWidthStr::width(bundle.prompt_get.name.as_str()); - let padding = arg_pos - .saturating_sub(name_width) - .saturating_sub(UnicodeWidthStr::width("- ")); - " ".repeat(padding.max(1)) - } else { - "\n".to_owned() - } - }) - )?; - if let Some(args) = bundle.prompt_get.arguments.as_ref() { - for (i, arg) in args.iter().enumerate() { - queue!( - session.stderr, - style::SetForegroundColor(Color::DarkGrey), - style::Print(match arg.required { - Some(true) => format!("{}*", arg.name), - _ => arg.name.clone(), - }), - style::SetForegroundColor(Color::Reset), - style::Print(if i < args.len() - 1 { ", " } else { "\n" }), - )?; - } - } - } - } - - Ok(ChatState::PromptUser { - skip_printing_tools: true, - }) - } -} - -#[deny(missing_docs)] -#[derive(Clone, Debug, PartialEq, Subcommand)] -pub enum PromptsSubcommand { - /// List available prompts from a tool or show all available prompt - List { search_word: Option }, - Get { - #[arg(long, hide = true)] - orig_input: Option, - name: String, - arguments: Option>, - }, -} - -impl PromptsSubcommand { - pub async fn execute(self, session: &mut ChatSession) -> Result { - let PromptsSubcommand::Get { - orig_input, - name, - arguments, - } = self - else { - unreachable!("List has already been parsed out at this point"); - }; - - let prompts = match session.conversation.tool_manager.get_prompt(name, arguments).await { - Ok(resp) => resp, - Err(e) => { - match e { - GetPromptError::AmbiguousPrompt(prompt_name, alt_msg) => { - queue!( - session.stderr, - style::Print("\n"), - style::SetForegroundColor(Color::Yellow), - style::Print("Prompt "), - style::SetForegroundColor(Color::Cyan), - style::Print(prompt_name), - style::SetForegroundColor(Color::Yellow), - style::Print(" is ambiguous. Use one of the following "), - style::SetForegroundColor(Color::Cyan), - style::Print(alt_msg), - style::SetForegroundColor(Color::Reset), - )?; - }, - GetPromptError::PromptNotFound(prompt_name) => { - queue!( - session.stderr, - style::Print("\n"), - style::SetForegroundColor(Color::Yellow), - style::Print("Prompt "), - style::SetForegroundColor(Color::Cyan), - style::Print(prompt_name), - style::SetForegroundColor(Color::Yellow), - style::Print(" not found. Use "), - style::SetForegroundColor(Color::Cyan), - style::Print("/prompts list"), - style::SetForegroundColor(Color::Yellow), - style::Print(" to see available prompts.\n"), - style::SetForegroundColor(Color::Reset), - )?; - }, - _ => return Err(ChatError::Custom(e.to_string().into())), - } - execute!(session.stderr, style::Print("\n"))?; - return Ok(ChatState::PromptUser { - skip_printing_tools: true, - }); - }, - }; - if let Some(err) = prompts.error { - // If we are running into error we should just display the error - // and abort. - let to_display = serde_json::json!(err); - queue!( - session.stderr, - style::Print("\n"), - style::SetAttribute(Attribute::Bold), - style::Print("Error encountered while retrieving prompt:"), - style::SetAttribute(Attribute::Reset), - style::Print("\n"), - style::SetForegroundColor(Color::Red), - style::Print(format_mcp_error(&to_display)), - style::SetForegroundColor(Color::Reset), - style::Print("\n"), - )?; - } else { - let prompts = prompts - .result - .ok_or(ChatError::Custom("Result field missing from prompt/get request".into()))?; - let prompts = serde_json::from_value::(prompts) - .map_err(|e| ChatError::Custom(format!("Failed to deserialize prompt/get result: {:?}", e).into()))?; - session.pending_prompts.clear(); - session.pending_prompts.append(&mut VecDeque::from(prompts.messages)); - return Ok(ChatState::HandleInput { - input: orig_input.unwrap_or_default(), - }); - } - - execute!(session.stderr, style::Print("\n"))?; - - Ok(ChatState::PromptUser { - skip_printing_tools: true, - }) - } -} diff --git a/crates/chat-cli/src/cli/chat/cli/subscribe.rs b/crates/chat-cli/src/cli/chat/cli/subscribe.rs deleted file mode 100644 index c92090874..000000000 --- a/crates/chat-cli/src/cli/chat/cli/subscribe.rs +++ /dev/null @@ -1,197 +0,0 @@ -use clap::Args; -use crossterm::style::{ - Color, - Stylize, -}; -use crossterm::{ - cursor, - execute, - queue, - style, -}; - -use crate::auth::builder_id::is_idc_user; -use crate::cli::chat::{ - ActualSubscriptionStatus, - ChatError, - ChatSession, - ChatState, - get_subscription_status_with_spinner, - with_spinner, -}; -use crate::os::Os; -use crate::util::system_info::is_remote; - -const SUBSCRIBE_TITLE_TEXT: &str = color_print::cstr! { "Subscribe to Q Developer Pro" }; - -const SUBSCRIBE_TEXT: &str = color_print::cstr! { "During the upgrade, you'll be asked to link your Builder ID to the AWS account that will be billed the monthly subscription fee. - -Need help? Visit our subscription support page> https://docs.aws.amazon.com/console/amazonq/upgrade-builder-id" }; - -#[deny(missing_docs)] -#[derive(Debug, PartialEq, Args)] -pub struct SubscribeArgs { - #[arg(long)] - manage: bool, -} - -impl SubscribeArgs { - pub async fn execute(self, os: &mut Os, session: &mut ChatSession) -> Result { - if is_idc_user(&os.database) - .await - .map_err(|e| ChatError::Custom(e.to_string().into()))? - { - execute!( - session.stderr, - style::SetForegroundColor(Color::Yellow), - style::Print("\nYour Q Developer Pro subscription is managed through IAM Identity Center.\n\n"), - style::SetForegroundColor(Color::Reset), - )?; - } else if self.manage { - queue!(session.stderr, style::Print("\n"),)?; - match get_subscription_status_with_spinner(os, &mut session.stderr).await { - Ok(status) => { - if status != ActualSubscriptionStatus::Active { - queue!( - session.stderr, - style::SetForegroundColor(Color::Yellow), - style::Print("You don't seem to have a Q Developer Pro subscription. "), - style::SetForegroundColor(Color::DarkGrey), - style::Print("Use "), - style::SetForegroundColor(Color::Green), - style::Print("/subscribe"), - style::SetForegroundColor(Color::DarkGrey), - style::Print(" to upgrade your subscription.\n\n"), - style::SetForegroundColor(Color::Reset), - )?; - } - }, - Err(err) => { - queue!( - session.stderr, - style::SetForegroundColor(Color::Red), - style::Print(format!("Failed to get subscription status: {}\n\n", err)), - style::SetForegroundColor(Color::Reset), - )?; - }, - } - - let url = format!( - "https://{}.console.aws.amazon.com/amazonq/developer/home#/subscriptions", - os.database - .get_idc_region() - .ok() - .flatten() - .unwrap_or("us-east-1".to_string()) - ); - if is_remote() || crate::util::open::open_url_async(&url).await.is_err() { - execute!( - session.stderr, - style::Print(format!("Open this URL to manage your subscription: {}\n\n", url.blue())), - style::ResetColor, - style::SetForegroundColor(Color::Reset), - )?; - } - } else { - upgrade_to_pro(os, session).await?; - } - - Ok(ChatState::PromptUser { - skip_printing_tools: true, - }) - } -} - -async fn upgrade_to_pro(os: &mut Os, session: &mut ChatSession) -> Result<(), ChatError> { - queue!(session.stderr, style::Print("\n"),)?; - - // Get current subscription status - match get_subscription_status_with_spinner(os, &mut session.stderr).await { - Ok(status) => { - if status == ActualSubscriptionStatus::Active { - queue!( - session.stderr, - style::SetForegroundColor(Color::Yellow), - style::Print("Your Builder ID already has a Q Developer Pro subscription.\n\n"), - style::SetForegroundColor(Color::Reset), - )?; - return Ok(()); - } - }, - Err(e) => { - execute!( - session.stderr, - style::SetForegroundColor(Color::Red), - style::Print(format!("{}\n\n", e)), - style::SetForegroundColor(Color::Reset), - )?; - // Don't exit early here, the check isn't required to subscribe. - }, - } - - // Upgrade information - queue!( - session.stderr, - style::Print(SUBSCRIBE_TITLE_TEXT), - style::SetForegroundColor(Color::Grey), - style::Print(format!("\n\n{}\n\n", SUBSCRIBE_TEXT)), - style::SetForegroundColor(Color::Reset), - cursor::Show - )?; - - let prompt = format!( - "{}{}{}{}{}", - "Would you like to open the AWS console to upgrade? [".dark_grey(), - "y".green(), - "/".dark_grey(), - "n".green(), - "]: ".dark_grey(), - ); - - let user_input = session.read_user_input(&prompt, true); - queue!( - session.stderr, - style::SetForegroundColor(Color::Reset), - style::Print("\n"), - )?; - - if !user_input.is_some_and(|i| ["y", "Y"].contains(&i.as_str())) { - execute!( - session.stderr, - style::SetForegroundColor(Color::Red), - style::Print("Upgrade cancelled.\n\n"), - style::SetForegroundColor(Color::Reset), - )?; - return Ok(()); - } - - // Create a subscription token and open the webpage - let r = os.client.create_subscription_token().await?; - - let url = with_spinner(&mut session.stderr, "Preparing to upgrade...", || async move { - r.encoded_verification_url() - .map(|s| s.to_string()) - .ok_or(ChatError::Custom("Missing verification URL".into())) - }) - .await?; - - if is_remote() || crate::util::open::open_url_async(&url).await.is_err() { - queue!( - session.stderr, - style::SetForegroundColor(Color::DarkGrey), - style::Print(format!( - "{} Having issues opening the AWS console? Try copy and pasting the URL > {}\n\n", - "?".magenta(), - url.blue() - )), - style::SetForegroundColor(Color::Reset), - )?; - } - - execute!( - session.stderr, - style::Print("Once upgraded, type a new prompt to continue your work, or type /quit to exit the chat.\n\n") - )?; - - Ok(()) -} diff --git a/crates/chat-cli/src/cli/chat/cli/tools.rs b/crates/chat-cli/src/cli/chat/cli/tools.rs deleted file mode 100644 index 35bf7da6d..000000000 --- a/crates/chat-cli/src/cli/chat/cli/tools.rs +++ /dev/null @@ -1,318 +0,0 @@ -use std::collections::HashSet; -use std::io::Write; - -use clap::{ - Args, - Subcommand, -}; -use crossterm::style::{ - Attribute, - Color, -}; -use crossterm::{ - queue, - style, -}; - -use crate::api_client::model::Tool as FigTool; -use crate::cli::chat::consts::DUMMY_TOOL_NAME; -use crate::cli::chat::tools::ToolOrigin; -use crate::cli::chat::{ - ChatError, - ChatSession, - ChatState, - TRUST_ALL_TEXT, -}; - -#[deny(missing_docs)] -#[derive(Debug, PartialEq, Args)] -pub struct ToolsArgs { - #[command(subcommand)] - subcommand: Option, -} - -impl ToolsArgs { - pub async fn execute(self, session: &mut ChatSession) -> Result { - if let Some(subcommand) = self.subcommand { - return subcommand.execute(session).await; - } - - // No subcommand - print the current tools and their permissions. - // Determine how to format the output nicely. - let terminal_width = session.terminal_width(); - let longest = session - .conversation - .tools - .values() - .flatten() - .map(|FigTool::ToolSpecification(spec)| spec.name.len()) - .max() - .unwrap_or(0); - - queue!( - session.stderr, - style::Print("\n"), - style::SetAttribute(Attribute::Bold), - style::Print({ - // Adding 2 because of "- " preceding every tool name - let width = longest + 2 - "Tool".len() + 4; - format!("Tool{:>width$}Permission", "", width = width) - }), - style::SetAttribute(Attribute::Reset), - style::Print("\n"), - style::Print("▔".repeat(terminal_width)), - )?; - - let mut origin_tools: Vec<_> = session.conversation.tools.iter().collect(); - - // Built in tools always appear first. - origin_tools.sort_by(|(origin_a, _), (origin_b, _)| match (origin_a, origin_b) { - (ToolOrigin::Native, _) => std::cmp::Ordering::Less, - (_, ToolOrigin::Native) => std::cmp::Ordering::Greater, - (ToolOrigin::McpServer(name_a), ToolOrigin::McpServer(name_b)) => name_a.cmp(name_b), - }); - - for (origin, tools) in origin_tools.iter() { - let mut sorted_tools: Vec<_> = tools - .iter() - .filter(|FigTool::ToolSpecification(spec)| spec.name != DUMMY_TOOL_NAME) - .collect(); - - sorted_tools.sort_by_key(|t| match t { - FigTool::ToolSpecification(spec) => &spec.name, - }); - - let to_display = sorted_tools - .iter() - .fold(String::new(), |mut acc, FigTool::ToolSpecification(spec)| { - let width = longest - spec.name.len() + 4; - acc.push_str( - format!( - "- {}{:>width$}{}\n", - spec.name, - "", - session.tool_permissions.display_label(&spec.name), - width = width - ) - .as_str(), - ); - acc - }); - - let _ = queue!( - session.stderr, - style::SetAttribute(Attribute::Bold), - style::Print(format!("{}:\n", origin)), - style::SetAttribute(Attribute::Reset), - style::Print(to_display), - style::Print("\n") - ); - } - - let loading = session.conversation.tool_manager.pending_clients().await; - if !loading.is_empty() { - queue!( - session.stderr, - style::SetAttribute(Attribute::Bold), - style::Print("Servers still loading"), - style::SetAttribute(Attribute::Reset), - style::Print("\n"), - style::Print("▔".repeat(terminal_width)), - )?; - for client in loading { - queue!(session.stderr, style::Print(format!(" - {client}")), style::Print("\n"))?; - } - } - - queue!( - session.stderr, - style::Print("\nTrusted tools will run without confirmation."), - style::SetForegroundColor(Color::DarkGrey), - style::Print(format!("\n{}\n", "* Default settings")), - style::Print("\n💡 Use "), - style::SetForegroundColor(Color::Green), - style::Print("/tools help"), - style::SetForegroundColor(Color::Reset), - style::SetForegroundColor(Color::DarkGrey), - style::Print(" to edit permissions.\n\n"), - style::SetForegroundColor(Color::Reset), - )?; - - Ok(ChatState::default()) - } -} - -#[deny(missing_docs)] -#[derive(Debug, PartialEq, Subcommand)] -#[command( - before_long_help = "By default, Amazon Q will ask for your permission to use certain tools. You can control which tools you -trust so that no confirmation is required. These settings will last only for this session." -)] -pub enum ToolsSubcommand { - /// Show the input schema for all available tools - Schema, - /// Trust a specific tool or tools for the session - Trust { - #[arg(required = true)] - tool_names: Vec, - }, - /// Revert a tool or tools to per-request confirmation - Untrust { - #[arg(required = true)] - tool_names: Vec, - }, - /// Trust all tools (equivalent to deprecated /acceptall) - TrustAll, - /// Reset all tools to default permission levels - Reset, - /// Reset a single tool to default permission level - ResetSingle { tool_name: String }, -} - -impl ToolsSubcommand { - pub async fn execute(self, session: &mut ChatSession) -> Result { - let existing_tools: HashSet<&String> = session - .conversation - .tools - .values() - .flatten() - .map(|FigTool::ToolSpecification(spec)| &spec.name) - .collect(); - - match self { - Self::Schema => { - let schema_json = serde_json::to_string_pretty(&session.conversation.tool_manager.schema) - .map_err(|e| ChatError::Custom(format!("Error converting tool schema to string: {e}").into()))?; - queue!(session.stderr, style::Print(schema_json), style::Print("\n"))?; - }, - Self::Trust { tool_names } => { - let (valid_tools, invalid_tools): (Vec, Vec) = tool_names - .into_iter() - .partition(|tool_name| existing_tools.contains(tool_name)); - - if !invalid_tools.is_empty() { - queue!( - session.stderr, - style::SetForegroundColor(Color::Red), - style::Print(format!("\nCannot trust '{}', ", invalid_tools.join("', '"))), - if invalid_tools.len() > 1 { - style::Print("they do not exist.") - } else { - style::Print("it does not exist.") - }, - style::SetForegroundColor(Color::Reset), - )?; - } - if !valid_tools.is_empty() { - valid_tools.iter().for_each(|t| session.tool_permissions.trust_tool(t)); - queue!( - session.stderr, - style::SetForegroundColor(Color::Green), - if valid_tools.len() > 1 { - style::Print(format!("Tools '{}' are ", valid_tools.join("', '"))) - } else { - style::Print(format!("Tool '{}' is ", valid_tools[0])) - }, - style::Print("now trusted. I will "), - style::SetAttribute(Attribute::Bold), - style::Print("not"), - style::SetAttribute(Attribute::Reset), - style::SetForegroundColor(Color::Green), - style::Print(format!( - " ask for confirmation before running {}.", - if valid_tools.len() > 1 { - "these tools" - } else { - "this tool" - } - )), - style::Print("\n"), - style::SetForegroundColor(Color::Reset), - )?; - } - }, - Self::Untrust { tool_names } => { - let (valid_tools, invalid_tools): (Vec, Vec) = tool_names - .into_iter() - .partition(|tool_name| existing_tools.contains(tool_name)); - - if !invalid_tools.is_empty() { - queue!( - session.stderr, - style::SetForegroundColor(Color::Red), - style::Print(format!("\nCannot untrust '{}', ", invalid_tools.join("', '"))), - if invalid_tools.len() > 1 { - style::Print("they do not exist.") - } else { - style::Print("it does not exist.") - }, - style::SetForegroundColor(Color::Reset), - )?; - } - if !valid_tools.is_empty() { - valid_tools - .iter() - .for_each(|t| session.tool_permissions.untrust_tool(t)); - queue!( - session.stderr, - style::SetForegroundColor(Color::Green), - if valid_tools.len() > 1 { - style::Print(format!("Tools '{}' are ", valid_tools.join("', '"))) - } else { - style::Print(format!("Tool '{}' is ", valid_tools[0])) - }, - style::Print("set to per-request confirmation.\n"), - style::SetForegroundColor(Color::Reset), - )?; - } - }, - Self::TrustAll => { - session - .conversation - .tools - .values() - .flatten() - .for_each(|FigTool::ToolSpecification(spec)| { - session.tool_permissions.trust_tool(spec.name.as_str()); - }); - queue!(session.stderr, style::Print(TRUST_ALL_TEXT), style::Print("\n"))?; - }, - Self::Reset => { - session.tool_permissions.reset(); - queue!( - session.stderr, - style::SetForegroundColor(Color::Green), - style::Print("Reset all tools to the default permission levels.\n"), - style::SetForegroundColor(Color::Reset), - )?; - }, - Self::ResetSingle { tool_name } => { - if session.tool_permissions.has(&tool_name) || session.tool_permissions.trust_all { - session.tool_permissions.reset_tool(&tool_name); - queue!( - session.stderr, - style::SetForegroundColor(Color::Green), - style::Print(format!("Reset tool '{}' to the default permission level.\n", tool_name)), - style::SetForegroundColor(Color::Reset), - )?; - } else { - queue!( - session.stderr, - style::SetForegroundColor(Color::Red), - style::Print(format!( - "Tool '{}' does not exist or is already in default settings.\n", - tool_name - )), - style::SetForegroundColor(Color::Reset), - )?; - } - }, - }; - - session.stderr.flush()?; - - Ok(ChatState::PromptUser { - skip_printing_tools: true, - }) - } -} diff --git a/crates/chat-cli/src/cli/chat/cli/usage.rs b/crates/chat-cli/src/cli/chat/cli/usage.rs deleted file mode 100644 index ef4d8be42..000000000 --- a/crates/chat-cli/src/cli/chat/cli/usage.rs +++ /dev/null @@ -1,214 +0,0 @@ -use clap::Args; -use crossterm::style::{ - Attribute, - Color, -}; -use crossterm::{ - execute, - queue, - style, -}; - -use crate::cli::chat::consts::CONTEXT_WINDOW_SIZE; -use crate::cli::chat::token_counter::{ - CharCount, - TokenCount, -}; -use crate::cli::chat::{ - ChatError, - ChatSession, - ChatState, -}; -use crate::os::Os; -#[deny(missing_docs)] -#[derive(Debug, PartialEq, Args)] -pub struct UsageArgs; - -impl UsageArgs { - pub async fn execute(self, os: &Os, session: &mut ChatSession) -> Result { - let state = session - .conversation - .backend_conversation_state(os, true, &mut session.stderr) - .await?; - - if !state.dropped_context_files.is_empty() { - execute!( - session.stderr, - style::SetForegroundColor(Color::DarkYellow), - style::Print("\nSome context files are dropped due to size limit, please run "), - style::SetForegroundColor(Color::DarkGreen), - style::Print("/context show "), - style::SetForegroundColor(Color::DarkYellow), - style::Print("to learn more.\n"), - style::SetForegroundColor(style::Color::Reset) - )?; - } - - let data = state.calculate_conversation_size(); - let tool_specs_json: String = state - .tools - .values() - .filter_map(|s| serde_json::to_string(s).ok()) - .collect::>() - .join(""); - let context_token_count: TokenCount = data.context_messages.into(); - let assistant_token_count: TokenCount = data.assistant_messages.into(); - let user_token_count: TokenCount = data.user_messages.into(); - let tools_char_count: CharCount = tool_specs_json.len().into(); // usize → CharCount - let tools_token_count: TokenCount = tools_char_count.into(); // CharCount → TokenCount - let total_token_used: TokenCount = - (data.context_messages + data.user_messages + data.assistant_messages + tools_char_count).into(); - let window_width = session.terminal_width(); - // set a max width for the progress bar for better aesthetic - let progress_bar_width = std::cmp::min(window_width, 80); - - let context_width = - ((context_token_count.value() as f64 / CONTEXT_WINDOW_SIZE as f64) * progress_bar_width as f64) as usize; - let assistant_width = - ((assistant_token_count.value() as f64 / CONTEXT_WINDOW_SIZE as f64) * progress_bar_width as f64) as usize; - let tools_width = - ((tools_token_count.value() as f64 / CONTEXT_WINDOW_SIZE as f64) * progress_bar_width as f64) as usize; - let user_width = - ((user_token_count.value() as f64 / CONTEXT_WINDOW_SIZE as f64) * progress_bar_width as f64) as usize; - - let left_over_width = progress_bar_width - - std::cmp::min( - context_width + assistant_width + user_width + tools_width, - progress_bar_width, - ); - - let is_overflow = (context_width + assistant_width + user_width + tools_width) > progress_bar_width; - - if is_overflow { - queue!( - session.stderr, - style::Print(format!( - "\nCurrent context window ({} of {}k tokens used)\n", - total_token_used, - CONTEXT_WINDOW_SIZE / 1000 - )), - style::SetForegroundColor(Color::DarkRed), - style::Print("█".repeat(progress_bar_width)), - style::SetForegroundColor(Color::Reset), - style::Print(" "), - style::Print(format!( - "{:.2}%", - (total_token_used.value() as f32 / CONTEXT_WINDOW_SIZE as f32) * 100.0 - )), - )?; - } else { - queue!( - session.stderr, - style::Print(format!( - "\nCurrent context window ({} of {}k tokens used)\n", - total_token_used, - CONTEXT_WINDOW_SIZE / 1000 - )), - // Context files - style::SetForegroundColor(Color::DarkCyan), - // add a nice visual to mimic "tiny" progress, so the overrall progress bar doesn't look too - // empty - style::Print("|".repeat(if context_width == 0 && *context_token_count > 0 { - 1 - } else { - 0 - })), - style::Print("█".repeat(context_width)), - // Tools - style::SetForegroundColor(Color::DarkRed), - style::Print("|".repeat(if tools_width == 0 && *tools_token_count > 0 { - 1 - } else { - 0 - })), - style::Print("█".repeat(tools_width)), - // Assistant responses - style::SetForegroundColor(Color::Blue), - style::Print("|".repeat(if assistant_width == 0 && *assistant_token_count > 0 { - 1 - } else { - 0 - })), - style::Print("█".repeat(assistant_width)), - // User prompts - style::SetForegroundColor(Color::Magenta), - style::Print("|".repeat(if user_width == 0 && *user_token_count > 0 { 1 } else { 0 })), - style::Print("█".repeat(user_width)), - style::SetForegroundColor(Color::DarkGrey), - style::Print("█".repeat(left_over_width)), - style::Print(" "), - style::SetForegroundColor(Color::Reset), - style::Print(format!( - "{:.2}%", - (total_token_used.value() as f32 / CONTEXT_WINDOW_SIZE as f32) * 100.0 - )), - )?; - } - - execute!(session.stderr, style::Print("\n\n"))?; - - queue!( - session.stderr, - style::SetForegroundColor(Color::DarkCyan), - style::Print("█ Context files: "), - style::SetForegroundColor(Color::Reset), - style::Print(format!( - "~{} tokens ({:.2}%)\n", - context_token_count, - (context_token_count.value() as f32 / CONTEXT_WINDOW_SIZE as f32) * 100.0 - )), - style::SetForegroundColor(Color::DarkRed), - style::Print("█ Tools: "), - style::SetForegroundColor(Color::Reset), - style::Print(format!( - " ~{} tokens ({:.2}%)\n", - tools_token_count, - (tools_token_count.value() as f32 / CONTEXT_WINDOW_SIZE as f32) * 100.0 - )), - style::SetForegroundColor(Color::Blue), - style::Print("█ Q responses: "), - style::SetForegroundColor(Color::Reset), - style::Print(format!( - " ~{} tokens ({:.2}%)\n", - assistant_token_count, - (assistant_token_count.value() as f32 / CONTEXT_WINDOW_SIZE as f32) * 100.0 - )), - style::SetForegroundColor(Color::Magenta), - style::Print("█ Your prompts: "), - style::SetForegroundColor(Color::Reset), - style::Print(format!( - " ~{} tokens ({:.2}%)\n\n", - user_token_count, - (user_token_count.value() as f32 / CONTEXT_WINDOW_SIZE as f32) * 100.0 - )), - )?; - - queue!( - session.stderr, - style::SetAttribute(Attribute::Bold), - style::Print("\n💡 Pro Tips:\n"), - style::SetAttribute(Attribute::Reset), - style::SetForegroundColor(Color::DarkGrey), - style::Print("Run "), - style::SetForegroundColor(Color::DarkGreen), - style::Print("/compact"), - style::SetForegroundColor(Color::DarkGrey), - style::Print(" to replace the conversation history with its summary\n"), - style::Print("Run "), - style::SetForegroundColor(Color::DarkGreen), - style::Print("/clear"), - style::SetForegroundColor(Color::DarkGrey), - style::Print(" to erase the entire chat history\n"), - style::Print("Run "), - style::SetForegroundColor(Color::DarkGreen), - style::Print("/context show"), - style::SetForegroundColor(Color::DarkGrey), - style::Print(" to see tokens per context file\n\n"), - style::SetForegroundColor(Color::Reset), - )?; - - Ok(ChatState::PromptUser { - skip_printing_tools: true, - }) - } -} diff --git a/crates/chat-cli/src/cli/chat/consts.rs b/crates/chat-cli/src/cli/chat/consts.rs deleted file mode 100644 index ac1c1fab3..000000000 --- a/crates/chat-cli/src/cli/chat/consts.rs +++ /dev/null @@ -1,28 +0,0 @@ -use super::token_counter::TokenCounter; - -// These limits are the internal undocumented values from the service for each item - -pub const MAX_CURRENT_WORKING_DIRECTORY_LEN: usize = 256; - -/// Limit to send the number of messages as part of chat. -pub const MAX_CONVERSATION_STATE_HISTORY_LEN: usize = 250; - -/// Actual service limit is 800_000 -pub const MAX_TOOL_RESPONSE_SIZE: usize = 400_000; - -/// Actual service limit is 600_000 -pub const MAX_USER_MESSAGE_SIZE: usize = 400_000; - -/// In tokens -pub const CONTEXT_WINDOW_SIZE: usize = 200_000; - -pub const CONTEXT_FILES_MAX_SIZE: usize = 150_000; - -pub const MAX_CHARS: usize = TokenCounter::token_to_chars(CONTEXT_WINDOW_SIZE); // Character-based warning threshold - -pub const DUMMY_TOOL_NAME: &str = "dummy"; - -pub const MAX_NUMBER_OF_IMAGES_PER_REQUEST: usize = 10; - -/// In bytes - 10 MB -pub const MAX_IMAGE_SIZE: usize = 10 * 1024 * 1024; diff --git a/crates/chat-cli/src/cli/chat/context.rs b/crates/chat-cli/src/cli/chat/context.rs deleted file mode 100644 index 10e2aa17d..000000000 --- a/crates/chat-cli/src/cli/chat/context.rs +++ /dev/null @@ -1,919 +0,0 @@ -use std::collections::HashMap; -use std::io::Write; -use std::path::{ - Path, - PathBuf, -}; - -use eyre::{ - Result, - eyre, -}; -use glob::glob; -use regex::Regex; -use serde::{ - Deserialize, - Serialize, -}; -use tracing::debug; - -use super::cli::hooks::HookTrigger; -use super::consts::CONTEXT_FILES_MAX_SIZE; -use super::util::drop_matched_context_files; -use crate::cli::chat::ChatError; -use crate::cli::chat::cli::hooks::{ - Hook, - HookExecutor, -}; -use crate::os::Os; -use crate::util::directories; - -pub const AMAZONQ_FILENAME: &str = "AmazonQ.md"; - -/// Configuration for context files, containing paths to include in the context. -#[derive(Debug, Clone, Serialize, Deserialize, Default)] -#[serde(default)] -pub struct ContextConfig { - /// List of file paths or glob patterns to include in the context. - pub paths: Vec, - - /// Map of Hook Name to [`Hook`]. The hook name serves as the hook's ID. - pub hooks: HashMap, -} - -/// Manager for context files and profiles. -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct ContextManager { - max_context_files_size: usize, - - /// Global context configuration that applies to all profiles. - pub global_config: ContextConfig, - - /// Name of the current active profile. - pub current_profile: String, - - /// Context configuration for the current profile. - pub profile_config: ContextConfig, - - #[serde(skip)] - pub hook_executor: HookExecutor, -} - -impl ContextManager { - /// Create a new ContextManager with default settings. - /// - /// This will: - /// 1. Create the necessary directories if they don't exist - /// 2. Load the global configuration - /// 3. Load the default profile configuration - /// - /// # Arguments - /// * `os` - The context to use - /// * `max_context_files_size` - Optional maximum token size for context files. If not provided, - /// defaults to `CONTEXT_FILES_MAX_SIZE`. - /// - /// # Returns - /// A Result containing the new ContextManager or an error - pub async fn new(os: &Os, max_context_files_size: Option) -> Result { - let max_context_files_size = max_context_files_size.unwrap_or(CONTEXT_FILES_MAX_SIZE); - - let profiles_dir = directories::chat_profiles_dir(os)?; - - os.fs.create_dir_all(&profiles_dir).await?; - - let global_config = load_global_config(os).await?; - let current_profile = "default".to_string(); - let profile_config = load_profile_config(os, ¤t_profile).await?; - - Ok(Self { - max_context_files_size, - global_config, - current_profile, - profile_config, - hook_executor: HookExecutor::new(), - }) - } - - /// Save the current configuration to disk. - /// - /// # Arguments - /// * `global` - If true, save the global configuration; otherwise, save the current profile - /// configuration - /// - /// # Returns - /// A Result indicating success or an error - async fn save_config(&self, os: &Os, global: bool) -> Result<()> { - if global { - let global_path = directories::chat_global_context_path(os)?; - let contents = serde_json::to_string_pretty(&self.global_config) - .map_err(|e| eyre!("Failed to serialize global configuration: {}", e))?; - - os.fs.write(&global_path, contents).await?; - } else { - let profile_path = profile_context_path(os, &self.current_profile)?; - if let Some(parent) = profile_path.parent() { - os.fs.create_dir_all(parent).await?; - } - let contents = serde_json::to_string_pretty(&self.profile_config) - .map_err(|e| eyre!("Failed to serialize profile configuration: {}", e))?; - - os.fs.write(&profile_path, contents).await?; - } - - Ok(()) - } - - /// Reloads the global and profile config from disk. - pub async fn reload_config(&mut self, os: &Os) -> Result<()> { - self.global_config = load_global_config(os).await?; - self.profile_config = load_profile_config(os, &self.current_profile).await?; - Ok(()) - } - - /// Add paths to the context configuration. - /// - /// # Arguments - /// * `paths` - List of paths to add - /// * `global` - If true, add to global configuration; otherwise, add to current profile - /// configuration - /// * `force` - If true, skip validation that the path exists - /// - /// # Returns - /// A Result indicating success or an error - pub async fn add_paths(&mut self, os: &Os, paths: Vec, global: bool, force: bool) -> Result<()> { - let mut all_paths = self.global_config.paths.clone(); - all_paths.append(&mut self.profile_config.paths.clone()); - - // Validate paths exist before adding them - if !force { - let mut context_files = Vec::new(); - - // Check each path to make sure it exists or matches at least one file - for path in &paths { - // We're using a temporary context_files vector just for validation - // Pass is_validation=true to ensure we error if glob patterns don't match any files - match process_path(os, path, &mut context_files, true).await { - Ok(_) => {}, // Path is valid - Err(e) => return Err(eyre!("Invalid path '{}': {}. Use --force to add anyway.", path, e)), - } - } - } - - // Add each path, checking for duplicates - for path in paths { - if all_paths.contains(&path) { - return Err(eyre!("Rule '{}' already exists.", path)); - } - if global { - self.global_config.paths.push(path); - } else { - self.profile_config.paths.push(path); - } - } - - // Save the updated configuration - self.save_config(os, global).await?; - - Ok(()) - } - - /// Remove paths from the context configuration. - /// - /// # Arguments - /// * `paths` - List of paths to remove - /// * `global` - If true, remove from global configuration; otherwise, remove from current - /// profile configuration - /// - /// # Returns - /// A Result indicating success or an error - pub async fn remove_paths(&mut self, os: &Os, paths: Vec, global: bool) -> Result<()> { - // Get reference to the appropriate config - let config = self.get_config_mut(global); - - // Track if any paths were removed - let mut removed_any = false; - - // Remove each path if it exists - for path in paths { - let original_len = config.paths.len(); - config.paths.retain(|p| p != &path); - - if config.paths.len() < original_len { - removed_any = true; - } - } - - if !removed_any { - return Err(eyre!("None of the specified paths were found in the context")); - } - - // Save the updated configuration - self.save_config(os, global).await?; - - Ok(()) - } - - /// List all available profiles. - /// - /// # Returns - /// A Result containing a vector of profile names, with "default" always first - pub async fn list_profiles(&self, os: &Os) -> Result> { - let mut profiles = Vec::new(); - - // Always include default profile - profiles.push("default".to_string()); - - // Read profile directory and extract profile names - let profiles_dir = directories::chat_profiles_dir(os)?; - if profiles_dir.exists() { - let mut read_dir = os.fs.read_dir(&profiles_dir).await?; - while let Some(entry) = read_dir.next_entry().await? { - let path = entry.path(); - if let (true, Some(name)) = (path.is_dir(), path.file_name()) { - if name != "default" { - profiles.push(name.to_string_lossy().to_string()); - } - } - } - } - - // Sort non-default profiles alphabetically - if profiles.len() > 1 { - profiles[1..].sort(); - } - - Ok(profiles) - } - - /// List all available profiles using blocking operations. - /// - /// Similar to list_profiles but uses synchronous filesystem operations. - /// - /// # Returns - /// A Result containing a vector of profile names, with "default" always first - #[cfg_attr(target_os = "windows", allow(dead_code))] - pub fn list_profiles_blocking(&self, os: &Os) -> Result> { - let _ = self; - - let mut profiles = Vec::new(); - - // Always include default profile - profiles.push("default".to_string()); - - // Read profile directory and extract profile names - let profiles_dir = directories::chat_profiles_dir(os)?; - if profiles_dir.exists() { - for entry in std::fs::read_dir(profiles_dir)? { - let entry = entry?; - let path = entry.path(); - if let (true, Some(name)) = (path.is_dir(), path.file_name()) { - if name != "default" { - profiles.push(name.to_string_lossy().to_string()); - } - } - } - } - - // Sort non-default profiles alphabetically - if profiles.len() > 1 { - profiles[1..].sort(); - } - - Ok(profiles) - } - - /// Clear all paths from the context configuration. - /// - /// # Arguments - /// * `global` - If true, clear global configuration; otherwise, clear current profile - /// configuration - /// - /// # Returns - /// A Result indicating success or an error - pub async fn clear(&mut self, os: &Os, global: bool) -> Result<()> { - // Clear the appropriate config - if global { - self.global_config.paths.clear(); - } else { - self.profile_config.paths.clear(); - } - - // Save the updated configuration - self.save_config(os, global).await?; - - Ok(()) - } - - /// Create a new profile. - /// - /// # Arguments - /// * `name` - Name of the profile to create - /// - /// # Returns - /// A Result indicating success or an error - pub async fn create_profile(&self, os: &Os, name: &str) -> Result<()> { - validate_profile_name(name)?; - - // Check if profile already exists - let profile_path = profile_context_path(os, name)?; - if profile_path.exists() { - return Err(eyre!("Profile '{}' already exists", name)); - } - - // Create empty profile configuration - let config = ContextConfig::default(); - let contents = serde_json::to_string_pretty(&config) - .map_err(|e| eyre!("Failed to serialize profile configuration: {}", e))?; - - // Create the file - if let Some(parent) = profile_path.parent() { - os.fs.create_dir_all(parent).await?; - } - os.fs.write(&profile_path, contents).await?; - - Ok(()) - } - - /// Delete a profile. - /// - /// # Arguments - /// * `name` - Name of the profile to delete - /// - /// # Returns - /// A Result indicating success or an error - pub async fn delete_profile(&self, os: &Os, name: &str) -> Result<()> { - if name == "default" { - return Err(eyre!("Cannot delete the default profile")); - } else if name == self.current_profile { - return Err(eyre!( - "Cannot delete the active profile. Switch to another profile first" - )); - } - - let profile_path = profile_dir_path(os, name)?; - if !profile_path.exists() { - return Err(eyre!("Profile '{}' does not exist", name)); - } - - os.fs.remove_dir_all(&profile_path).await?; - - Ok(()) - } - - /// Rename a profile. - /// - /// # Arguments - /// * `old_name` - Current name of the profile - /// * `new_name` - New name for the profile - /// - /// # Returns - /// A Result indicating success or an error - pub async fn rename_profile(&mut self, os: &Os, old_name: &str, new_name: &str) -> Result<()> { - // Validate profile names - if old_name == "default" { - return Err(eyre!("Cannot rename the default profile")); - } - if new_name == "default" { - return Err(eyre!("Cannot rename to 'default' as it's a reserved profile name")); - } - - validate_profile_name(new_name)?; - - let old_profile_path = profile_dir_path(os, old_name)?; - if !old_profile_path.exists() { - return Err(eyre!("Profile '{}' not found", old_name)); - } - - let new_profile_path = profile_dir_path(os, new_name)?; - if new_profile_path.exists() { - return Err(eyre!("Profile '{}' already exists", new_name)); - } - - os.fs.rename(&old_profile_path, &new_profile_path).await?; - - // If the current profile is being renamed, update the current_profile field - if self.current_profile == old_name { - self.current_profile = new_name.to_string(); - self.profile_config = load_profile_config(os, new_name).await?; - } - - Ok(()) - } - - /// Switch to a different profile. - /// - /// # Arguments - /// * `name` - Name of the profile to switch to - /// - /// # Returns - /// A Result indicating success or an error - pub async fn switch_profile(&mut self, os: &Os, name: &str) -> Result<()> { - validate_profile_name(name)?; - self.hook_executor.profile_cache.clear(); - - // Special handling for default profile - it always exists - if name == "default" { - // Load the default profile configuration - let profile_config = load_profile_config(os, name).await?; - - // Update the current profile - self.current_profile = name.to_string(); - self.profile_config = profile_config; - - return Ok(()); - } - - // Check if profile exists - let profile_path = profile_context_path(os, name)?; - if !profile_path.exists() { - return Err(eyre!("Profile '{}' does not exist. Use 'create' to create it", name)); - } - - // Update the current profile - self.current_profile = name.to_string(); - self.profile_config = load_profile_config(os, name).await?; - - Ok(()) - } - - /// Get all context files (global + profile-specific). - /// - /// This method: - /// 1. Processes all paths in the global and profile configurations - /// 2. Expands glob patterns to include matching files - /// 3. Reads the content of each file - /// 4. Returns a vector of (filename, content) pairs - /// - /// - /// # Returns - /// A Result containing a vector of (filename, content) pairs or an error - pub async fn get_context_files(&self, os: &Os) -> Result> { - let mut context_files = Vec::new(); - - self.collect_context_files(os, &self.global_config.paths, &mut context_files) - .await?; - self.collect_context_files(os, &self.profile_config.paths, &mut context_files) - .await?; - - context_files.sort_by(|a, b| a.0.cmp(&b.0)); - context_files.dedup_by(|a, b| a.0 == b.0); - - Ok(context_files) - } - - pub async fn get_context_files_by_path(&self, os: &Os, path: &str) -> Result> { - let mut context_files = Vec::new(); - process_path(os, path, &mut context_files, true).await?; - Ok(context_files) - } - - /// Collects context files and optionally drops files if the total size exceeds the limit. - /// Returns (files_to_use, dropped_files) - pub async fn collect_context_files_with_limit( - &self, - os: &Os, - ) -> Result<(Vec<(String, String)>, Vec<(String, String)>)> { - let mut files = self.get_context_files(os).await?; - - let dropped_files = drop_matched_context_files(&mut files, self.max_context_files_size).unwrap_or_default(); - - // remove dropped files from files - files.retain(|file| !dropped_files.iter().any(|dropped| dropped.0 == file.0)); - - Ok((files, dropped_files)) - } - - async fn collect_context_files( - &self, - os: &Os, - paths: &[String], - context_files: &mut Vec<(String, String)>, - ) -> Result<()> { - for path in paths { - // Use is_validation=false to handle non-matching globs gracefully - process_path(os, path, context_files, false).await?; - } - Ok(()) - } - - fn get_config_mut(&mut self, global: bool) -> &mut ContextConfig { - if global { - &mut self.global_config - } else { - &mut self.profile_config - } - } - - /// Add hooks to the context config. If another hook with the same name already exists, throw an - /// error. - /// - /// # Arguments - /// * `hook` - name of the hook to delete - /// * `global` - If true, the add to the global config. If false, add to the current profile - /// config. - /// * `conversation_start` - If true, add the hook to conversation_start. Otherwise, it will be - /// added to per_prompt. - pub async fn add_hook(&mut self, os: &Os, name: String, hook: Hook, global: bool) -> Result<()> { - let config = self.get_config_mut(global); - - if config.hooks.contains_key(&name) { - return Err(eyre!("name already exists.")); - } - - config.hooks.insert(name, hook); - self.save_config(os, global).await - } - - /// Delete hook(s) by name - /// # Arguments - /// * `name` - name of the hook to delete - /// * `global` - If true, the delete from the global config. If false, delete from the current - /// profile config - pub async fn remove_hook(&mut self, os: &Os, name: &str, global: bool) -> Result<()> { - let config = self.get_config_mut(global); - - if !config.hooks.contains_key(name) { - return Err(eyre!("does not exist.")); - } - - config.hooks.remove(name); - - self.save_config(os, global).await - } - - /// Sets the "disabled" field on any [`Hook`] with the given name - /// # Arguments - /// * `disable` - Set "disabled" field to this value - pub async fn set_hook_disabled(&mut self, os: &Os, name: &str, global: bool, disable: bool) -> Result<()> { - let config = self.get_config_mut(global); - - if !config.hooks.contains_key(name) { - return Err(eyre!("does not exist.")); - } - - if let Some(hook) = config.hooks.get_mut(name) { - hook.disabled = disable; - } - - self.save_config(os, global).await - } - - /// Sets the "disabled" field on all [`Hook`]s - /// # Arguments - /// * `disable` - Set all "disabled" fields to this value - pub async fn set_all_hooks_disabled(&mut self, os: &Os, global: bool, disable: bool) -> Result<()> { - let config = self.get_config_mut(global); - - config.hooks.iter_mut().for_each(|(_, h)| h.disabled = disable); - - self.save_config(os, global).await - } - - /// Run all the currently enabled hooks from both the global and profile contexts. - /// Skipped hooks (disabled) will not appear in the output. - /// # Arguments - /// * `updates` - output stream to write hook run status to if Some, else do nothing if None - /// # Returns - /// A vector containing pairs of a [`Hook`] definition and its execution output - pub async fn run_hooks( - &mut self, - trigger: HookTrigger, - output: &mut impl Write, - ) -> Result, ChatError> { - let mut hooks: Vec<&Hook> = Vec::new(); - - // Set internal hook states - let configs = [ - (&mut self.global_config.hooks, true), - (&mut self.profile_config.hooks, false), - ]; - - for (hook_list, is_global) in configs { - hooks.extend(hook_list.iter_mut().filter_map(|(name, h)| { - if h.trigger == trigger { - h.name = name.clone(); - h.is_global = is_global; - Some(&*h) - } else { - None - } - })); - } - - self.hook_executor.run_hooks(hooks, output).await - } -} - -fn profile_dir_path(os: &Os, profile_name: &str) -> Result { - Ok(directories::chat_profiles_dir(os)?.join(profile_name)) -} - -/// Path to the context config file for `profile_name`. -pub fn profile_context_path(os: &Os, profile_name: &str) -> Result { - Ok(directories::chat_profiles_dir(os)? - .join(profile_name) - .join("context.json")) -} - -/// Load the global context configuration. -/// -/// If the global configuration file doesn't exist, returns a default configuration. -async fn load_global_config(os: &Os) -> Result { - let global_path = directories::chat_global_context_path(os)?; - debug!(?global_path, "loading profile config"); - if os.fs.exists(&global_path) { - let contents = os.fs.read_to_string(&global_path).await?; - let config: ContextConfig = - serde_json::from_str(&contents).map_err(|e| eyre!("Failed to parse global configuration: {}", e))?; - Ok(config) - } else { - // Return default global configuration with predefined paths - use crate::util::paths::workspace; - - Ok(ContextConfig { - paths: vec![ - workspace::RULES_PATTERN.to_string(), - "README.md".to_string(), - AMAZONQ_FILENAME.to_string(), - ], - hooks: HashMap::new(), - }) - } -} - -/// Load a profile's context configuration. -/// -/// If the profile configuration file doesn't exist, creates a default configuration. -async fn load_profile_config(os: &Os, profile_name: &str) -> Result { - let profile_path = profile_context_path(os, profile_name)?; - debug!(?profile_path, "loading profile config"); - if os.fs.exists(&profile_path) { - let contents = os.fs.read_to_string(&profile_path).await?; - let config: ContextConfig = - serde_json::from_str(&contents).map_err(|e| eyre!("Failed to parse profile configuration: {}", e))?; - Ok(config) - } else { - // Return empty configuration for new profiles - Ok(ContextConfig::default()) - } -} - -/// Process a path, handling glob patterns and file types. -/// -/// This method: -/// 1. Expands the path (handling ~ for home directory) -/// 2. If the path contains glob patterns, expands them -/// 3. For each resulting path, adds the file to the context collection -/// 4. Handles directories by including all files in the directory (non-recursive) -/// 5. With force=true, includes paths that don't exist yet -/// -/// # Arguments -/// * `path` - The path to process -/// * `context_files` - The collection to add files to -/// * `is_validation` - If true, error when glob patterns don't match; if false, silently skip -/// -/// # Returns -/// A Result indicating success or an error -async fn process_path( - os: &Os, - path: &str, - context_files: &mut Vec<(String, String)>, - is_validation: bool, -) -> Result<()> { - // Expand ~ to home directory - let expanded_path = if path.starts_with('~') { - if let Some(home_dir) = os.env.home() { - home_dir.join(&path[2..]).to_string_lossy().to_string() - } else { - return Err(eyre!("Could not determine home directory")); - } - } else { - path.to_string() - }; - - // Handle absolute, relative paths, and glob patterns - let full_path = if expanded_path.starts_with('/') { - expanded_path - } else { - os.env.current_dir()?.join(&expanded_path).to_string_lossy().to_string() - }; - - // Required in chroot testing scenarios so that we can use `Path::exists`. - let full_path = os.fs.chroot_path_str(full_path); - - // Check if the path contains glob patterns - if full_path.contains('*') || full_path.contains('?') || full_path.contains('[') { - // Expand glob pattern - match glob(&full_path) { - Ok(entries) => { - let mut found_any = false; - - for entry in entries { - match entry { - Ok(path) => { - if path.is_file() { - add_file_to_context(os, &path, context_files).await?; - found_any = true; - } - }, - Err(e) => return Err(eyre!("Glob error: {}", e)), - } - } - - if !found_any && is_validation { - // When validating paths (e.g., for /context add), error if no files match - return Err(eyre!("No files found matching glob pattern '{}'", full_path)); - } - // When just showing expanded files (e.g., for /context show --expand), - // silently skip non-matching patterns (don't add anything to context_files) - }, - Err(e) => return Err(eyre!("Invalid glob pattern '{}': {}", full_path, e)), - } - } else { - // Regular path - let path = Path::new(&full_path); - if path.exists() { - if path.is_file() { - add_file_to_context(os, path, context_files).await?; - } else if path.is_dir() { - // For directories, add all files in the directory (non-recursive) - let mut read_dir = os.fs.read_dir(path).await?; - while let Some(entry) = read_dir.next_entry().await? { - let path = entry.path(); - if path.is_file() { - add_file_to_context(os, &path, context_files).await?; - } - } - } - } else if is_validation { - // When validating paths (e.g., for /context add), error if the path doesn't exist - return Err(eyre!("Path '{}' does not exist", full_path)); - } - } - - Ok(()) -} - -/// Add a file to the context collection. -/// -/// This method: -/// 1. Reads the content of the file -/// 2. Adds the (filename, content) pair to the context collection -/// -/// # Arguments -/// * `path` - The path to the file -/// * `context_files` - The collection to add the file to -/// -/// # Returns -/// A Result indicating success or an error -async fn add_file_to_context(os: &Os, path: &Path, context_files: &mut Vec<(String, String)>) -> Result<()> { - let filename = path.to_string_lossy().to_string(); - let content = os.fs.read_to_string(path).await?; - context_files.push((filename, content)); - Ok(()) -} - -/// Validate a profile name. -/// -/// Profile names can only contain alphanumeric characters, hyphens, and underscores. -/// -/// # Arguments -/// * `name` - Name to validate -/// -/// # Returns -/// A Result indicating if the name is valid -fn validate_profile_name(name: &str) -> Result<()> { - // Check if name is empty - if name.is_empty() { - return Err(eyre!("Profile name cannot be empty")); - } - - // Check if name contains only allowed characters and starts with an alphanumeric character - let re = Regex::new(r"^[a-zA-Z0-9][a-zA-Z0-9_-]*$").unwrap(); - if !re.is_match(name) { - return Err(eyre!( - "Profile name must start with an alphanumeric character and can only contain alphanumeric characters, hyphens, and underscores" - )); - } - - Ok(()) -} - -#[cfg(test)] -mod tests { - use super::*; - use crate::cli::chat::util::test::create_test_context_manager; - - #[tokio::test] - async fn test_validate_profile_name() { - // Test valid names - assert!(validate_profile_name("valid").is_ok()); - assert!(validate_profile_name("valid-name").is_ok()); - assert!(validate_profile_name("valid_name").is_ok()); - assert!(validate_profile_name("valid123").is_ok()); - assert!(validate_profile_name("1valid").is_ok()); - assert!(validate_profile_name("9test").is_ok()); - - // Test invalid names - assert!(validate_profile_name("").is_err()); - assert!(validate_profile_name("invalid/name").is_err()); - assert!(validate_profile_name("invalid.name").is_err()); - assert!(validate_profile_name("invalid name").is_err()); - assert!(validate_profile_name("_invalid").is_err()); - assert!(validate_profile_name("-invalid").is_err()); - } - - #[tokio::test] - async fn test_profile_ops() -> Result<()> { - let os = Os::new().await.unwrap(); - let mut manager = create_test_context_manager(None).await?; - - assert_eq!(manager.current_profile, "default"); - - // Create ops - manager.create_profile(&os, "test_profile").await?; - assert!(profile_context_path(&os, "test_profile")?.exists()); - assert!(manager.create_profile(&os, "test_profile").await.is_err()); - manager.create_profile(&os, "alt").await?; - - // Listing - let profiles = manager.list_profiles(&os).await?; - assert!(profiles.contains(&"default".to_string())); - assert!(profiles.contains(&"test_profile".to_string())); - assert!(profiles.contains(&"alt".to_string())); - - // Switching - manager.switch_profile(&os, "test_profile").await?; - assert!(manager.switch_profile(&os, "notexists").await.is_err()); - - // Renaming - manager.rename_profile(&os, "alt", "renamed").await?; - assert!(!profile_context_path(&os, "alt")?.exists()); - assert!(profile_context_path(&os, "renamed")?.exists()); - - // Delete ops - assert!(manager.delete_profile(&os, "test_profile").await.is_err()); - manager.switch_profile(&os, "default").await?; - manager.delete_profile(&os, "test_profile").await?; - assert!(!profile_context_path(&os, "test_profile")?.exists()); - assert!(manager.delete_profile(&os, "test_profile").await.is_err()); - assert!(manager.delete_profile(&os, "default").await.is_err()); - - Ok(()) - } - - #[tokio::test] - async fn test_collect_exceeds_limit() -> Result<()> { - let os = Os::new().await.unwrap(); - let mut manager = create_test_context_manager(Some(2)).await?; - - os.fs.create_dir_all("test").await?; - os.fs.write("test/to-include.md", "ha").await?; - os.fs.write("test/to-drop.md", "long content that exceed limit").await?; - manager - .add_paths(&os, vec!["test/*.md".to_string()], false, false) - .await?; - - let (used, dropped) = manager.collect_context_files_with_limit(&os).await.unwrap(); - - assert!(used.len() + dropped.len() == 2); - assert!(used.len() == 1); - assert!(dropped.len() == 1); - Ok(()) - } - - #[tokio::test] - async fn test_path_ops() -> Result<()> { - let os = Os::new().await.unwrap(); - let mut manager = create_test_context_manager(None).await?; - - // Create some test files for matching. - os.fs.create_dir_all("test").await?; - os.fs.write("test/p1.md", "p1").await?; - os.fs.write("test/p2.md", "p2").await?; - - assert!( - manager.get_context_files(&os).await?.is_empty(), - "no files should be returned for an empty profile when force is false" - ); - - manager - .add_paths(&os, vec!["test/*.md".to_string()], false, false) - .await?; - let files = manager.get_context_files(&os).await?; - assert!(files[0].0.ends_with("p1.md")); - assert_eq!(files[0].1, "p1"); - assert!(files[1].0.ends_with("p2.md")); - assert_eq!(files[1].1, "p2"); - - assert!( - manager - .add_paths(&os, vec!["test/*.txt".to_string()], false, false) - .await - .is_err(), - "adding a glob with no matching and without force should fail" - ); - - Ok(()) - } -} diff --git a/crates/chat-cli/src/cli/chat/conversation.rs b/crates/chat-cli/src/cli/chat/conversation.rs deleted file mode 100644 index 20cf79253..000000000 --- a/crates/chat-cli/src/cli/chat/conversation.rs +++ /dev/null @@ -1,1225 +0,0 @@ -use std::collections::{ - HashMap, - HashSet, - VecDeque, -}; -use std::io::Write; -use std::sync::atomic::Ordering; - -use crossterm::style::Color; -use crossterm::{ - execute, - style, -}; -use serde::{ - Deserialize, - Serialize, -}; -use tracing::{ - debug, - error, - warn, -}; - -use super::cli::compact::CompactStrategy; -use super::consts::{ - DUMMY_TOOL_NAME, - MAX_CHARS, - MAX_CONVERSATION_STATE_HISTORY_LEN, -}; -use super::context::ContextManager; -use super::message::{ - AssistantMessage, - ToolUseResult, - UserMessage, -}; -use super::token_counter::{ - CharCount, - CharCounter, -}; -use super::tool_manager::ToolManager; -use super::tools::{ - InputSchema, - QueuedTool, - ToolOrigin, - ToolSpec, -}; -use super::util::serde_value_to_document; -use crate::api_client::model::{ - ChatMessage, - ConversationState as FigConversationState, - ImageBlock, - Tool, - ToolInputSchema, - ToolSpecification, - UserInputMessage, -}; -use crate::cli::chat::ChatError; -use crate::cli::chat::cli::hooks::{ - Hook, - HookTrigger, -}; -use crate::mcp_client::Prompt; -use crate::os::Os; - -const CONTEXT_ENTRY_START_HEADER: &str = "--- CONTEXT ENTRY BEGIN ---\n"; -const CONTEXT_ENTRY_END_HEADER: &str = "--- CONTEXT ENTRY END ---\n\n"; - -/// Tracks state related to an ongoing conversation. -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct ConversationState { - /// Randomly generated on creation. - conversation_id: String, - /// The next user message to be sent as part of the conversation. Required to be [Some] before - /// calling [Self::as_sendable_conversation_state]. - next_message: Option, - history: VecDeque<(UserMessage, AssistantMessage)>, - /// The range in the history sendable to the backend (start inclusive, end exclusive). - valid_history_range: (usize, usize), - /// Similar to history in that stores user and assistant responses, except that it is not used - /// in message requests. Instead, the responses are expected to be in human-readable format, - /// e.g user messages prefixed with '> '. Should also be used to store errors posted in the - /// chat. - pub transcript: VecDeque, - pub tools: HashMap>, - /// Context manager for handling sticky context files - pub context_manager: Option, - /// Tool manager for handling tool and mcp related activities - #[serde(skip)] - pub tool_manager: ToolManager, - /// Cached value representing the length of the user context message. - context_message_length: Option, - /// Stores the latest conversation summary created by /compact - latest_summary: Option, - /// Model explicitly selected by the user in this conversation state via `/model`. - #[serde(default, skip_serializing_if = "Option::is_none")] - pub model: Option, -} - -impl ConversationState { - pub async fn new( - os: &mut Os, - conversation_id: &str, - tool_config: HashMap, - profile: Option, - tool_manager: ToolManager, - current_model_id: Option, - ) -> Self { - // Initialize context manager - let context_manager = match ContextManager::new(os, None).await { - Ok(mut manager) => { - // Switch to specified profile if provided - if let Some(profile_name) = profile { - if let Err(e) = manager.switch_profile(os, &profile_name).await { - warn!("Failed to switch to profile {}: {}", profile_name, e); - } - } - Some(manager) - }, - Err(e) => { - warn!("Failed to initialize context manager: {}", e); - None - }, - }; - - Self { - conversation_id: conversation_id.to_string(), - next_message: None, - history: VecDeque::new(), - valid_history_range: Default::default(), - transcript: VecDeque::with_capacity(MAX_CONVERSATION_STATE_HISTORY_LEN), - tools: tool_config - .into_values() - .fold(HashMap::>::new(), |mut acc, v| { - let tool = Tool::ToolSpecification(ToolSpecification { - name: v.name, - description: v.description, - input_schema: v.input_schema.into(), - }); - acc.entry(v.tool_origin) - .and_modify(|tools| tools.push(tool.clone())) - .or_insert(vec![tool]); - acc - }), - context_manager, - tool_manager, - context_message_length: None, - latest_summary: None, - model: current_model_id, - } - } - - /// Reloads necessary fields after being deserialized. This should be called after - /// deserialization. - pub async fn reload_serialized_state(&mut self, os: &Os) { - // Try to reload ContextManager, but do not return an error if we fail. - // TODO: Currently the failure modes around ContextManager is unclear, and we don't return - // errors in most cases. Thus, we try to preserve the same behavior here and simply have - // self.context_manager equal to None if any errors are encountered. This needs to be - // refactored. - let mut failed = false; - if let Some(context_manager) = self.context_manager.as_mut() { - match context_manager.reload_config(os).await { - Ok(_) => (), - Err(err) => { - error!(?err, "failed to reload context config"); - match ContextManager::new(os, None).await { - Ok(v) => *context_manager = v, - Err(err) => { - failed = true; - error!(?err, "failed to construct context manager"); - }, - } - }, - } - } - - if failed { - self.context_manager.take(); - } - } - - pub fn latest_summary(&self) -> Option<&str> { - self.latest_summary.as_deref() - } - - pub fn history(&self) -> &VecDeque<(UserMessage, AssistantMessage)> { - &self.history - } - - /// Clears the conversation history and optionally the summary. - pub fn clear(&mut self, preserve_summary: bool) { - self.next_message = None; - self.history.clear(); - if !preserve_summary { - self.latest_summary = None; - } - } - - /// Appends a collection prompts into history and returns the last message in the collection. - /// It asserts that the collection ends with a prompt that assumes the role of user. - pub fn append_prompts(&mut self, mut prompts: VecDeque) -> Option { - debug_assert!(self.next_message.is_none(), "next_message should not exist"); - debug_assert!(prompts.back().is_some_and(|p| p.role == crate::mcp_client::Role::User)); - let last_msg = prompts.pop_back()?; - let (mut candidate_user, mut candidate_asst) = (None::, None::); - while let Some(prompt) = prompts.pop_front() { - let Prompt { role, content } = prompt; - match role { - crate::mcp_client::Role::User => { - let user_msg = UserMessage::new_prompt(content.to_string()); - candidate_user.replace(user_msg); - }, - crate::mcp_client::Role::Assistant => { - let assistant_msg = AssistantMessage::new_response(None, content.into()); - candidate_asst.replace(assistant_msg); - }, - } - if candidate_asst.is_some() && candidate_user.is_some() { - let asst = candidate_asst.take().unwrap(); - let user = candidate_user.take().unwrap(); - self.append_assistant_transcript(&asst); - self.history.push_back((user, asst)); - } - } - Some(last_msg.content.to_string()) - } - - pub fn next_user_message(&self) -> Option<&UserMessage> { - self.next_message.as_ref() - } - - pub fn reset_next_user_message(&mut self) { - self.next_message = None; - } - - pub async fn set_next_user_message(&mut self, input: String) { - debug_assert!(self.next_message.is_none(), "next_message should not exist"); - if let Some(next_message) = self.next_message.as_ref() { - warn!(?next_message, "next_message should not exist"); - } - - let input = if input.is_empty() { - warn!("input must not be empty when adding new messages"); - "Empty prompt".to_string() - } else { - input - }; - - let msg = UserMessage::new_prompt(input); - self.next_message = Some(msg); - } - - /// Sets the response message according to the currently set [Self::next_message]. - pub fn push_assistant_message(&mut self, os: &mut Os, message: AssistantMessage) { - debug_assert!(self.next_message.is_some(), "next_message should exist"); - let next_user_message = self.next_message.take().expect("next user message should exist"); - - self.append_assistant_transcript(&message); - self.history.push_back((next_user_message, message)); - - if let Ok(cwd) = std::env::current_dir() { - os.database.set_conversation_by_path(cwd, self).ok(); - } - } - - /// Returns the conversation id. - pub fn conversation_id(&self) -> &str { - self.conversation_id.as_ref() - } - - /// Returns the message id associated with the last assistant message, if present. - /// - /// This is equivalent to `utterance_id` in the Q API. - pub fn message_id(&self) -> Option<&str> { - self.history.back().and_then(|(_, msg)| msg.message_id()) - } - - /// Updates the history so that, when non-empty, the following invariants are in place: - /// 1. The history length is `<= MAX_CONVERSATION_STATE_HISTORY_LEN`. Oldest messages are - /// dropped. - /// 2. The first message is from the user, and does not contain tool results. Oldest messages - /// are dropped. - /// 3. If the last message from the assistant contains tool results, and a next user message is - /// set without tool results, then the user message will have "cancelled" tool results. - pub fn enforce_conversation_invariants(&mut self) { - self.valid_history_range = - enforce_conversation_invariants(&mut self.history, &mut self.next_message, &self.tools); - } - - /// Here we also need to make sure that the tool result corresponds to one of the tools - /// in the list. Otherwise we will see validation error from the backend. There are three - /// such circumstances where intervention would be needed: - /// 1. The model had decided to call a tool with its partial name AND there is only one such - /// tool, in which case we would automatically resolve this tool call to its correct name. - /// This will NOT result in an error in its tool result. The intervention here is to - /// substitute the partial name with its full name. - /// 2. The model had decided to call a tool with its partial name AND there are multiple tools - /// it could be referring to, in which case we WILL return an error in the tool result. The - /// intervention here is to substitute the ambiguous, partial name with a dummy. - /// 3. The model had decided to call a tool that does not exist. The intervention here is to - /// substitute the non-existent tool name with a dummy. - pub fn enforce_tool_use_history_invariants(&mut self) { - enforce_tool_use_history_invariants(&mut self.history, &self.tools); - } - - pub fn add_tool_results(&mut self, tool_results: Vec) { - debug_assert!(self.next_message.is_none()); - self.next_message = Some(UserMessage::new_tool_use_results(tool_results)); - } - - pub fn add_tool_results_with_images(&mut self, tool_results: Vec, images: Vec) { - debug_assert!(self.next_message.is_none()); - self.next_message = Some(UserMessage::new_tool_use_results_with_images(tool_results, images)); - } - - /// Sets the next user message with "cancelled" tool results. - pub fn abandon_tool_use(&mut self, tools_to_be_abandoned: &[QueuedTool], deny_input: String) { - self.next_message = Some(UserMessage::new_cancelled_tool_uses( - Some(deny_input), - tools_to_be_abandoned.iter().map(|t| t.id.as_str()), - )); - } - - /// Returns a [FigConversationState] capable of being sent by [api_client::StreamingClient]. - /// - /// Params: - /// - `run_perprompt_hooks` - whether per-prompt hooks should be executed and included as - /// context - pub async fn as_sendable_conversation_state( - &mut self, - os: &Os, - stderr: &mut impl Write, - run_perprompt_hooks: bool, - ) -> Result { - debug_assert!(self.next_message.is_some()); - self.enforce_conversation_invariants(); - self.history.drain(self.valid_history_range.1..); - self.history.drain(..self.valid_history_range.0); - - let context = self.backend_conversation_state(os, run_perprompt_hooks, stderr).await?; - if !context.dropped_context_files.is_empty() { - execute!( - stderr, - style::SetForegroundColor(Color::DarkYellow), - style::Print("\nSome context files are dropped due to size limit, please run "), - style::SetForegroundColor(Color::DarkGreen), - style::Print("/context show "), - style::SetForegroundColor(Color::DarkYellow), - style::Print("to learn more.\n"), - style::SetForegroundColor(style::Color::Reset) - ) - .ok(); - } - - Ok(context - .into_fig_conversation_state() - .expect("unable to construct conversation state")) - } - - pub async fn update_state(&mut self, force_update: bool) { - let needs_update = self.tool_manager.has_new_stuff.load(Ordering::Acquire) || force_update; - if !needs_update { - return; - } - self.tool_manager.update().await; - // TODO: make this more targeted so we don't have to clone the entire list of tools - self.tools = self - .tool_manager - .schema - .values() - .fold(HashMap::>::new(), |mut acc, v| { - let tool = Tool::ToolSpecification(ToolSpecification { - name: v.name.clone(), - description: v.description.clone(), - input_schema: v.input_schema.clone().into(), - }); - acc.entry(v.tool_origin.clone()) - .and_modify(|tools| tools.push(tool.clone())) - .or_insert(vec![tool]); - acc - }); - self.tool_manager.has_new_stuff.store(false, Ordering::Release); - // We call this in [Self::enforce_conversation_invariants] as well. But we need to call it - // here as well because when it's being called in [Self::enforce_conversation_invariants] - // it is only checking the last entry. - self.enforce_tool_use_history_invariants(); - } - - /// Returns a conversation state representation which reflects the exact conversation to send - /// back to the model. - pub async fn backend_conversation_state( - &mut self, - os: &Os, - run_perprompt_hooks: bool, - output: &mut impl Write, - ) -> Result, ChatError> { - self.update_state(false).await; - self.enforce_conversation_invariants(); - - let mut conversation_start_context = None; - if let Some(cm) = self.context_manager.as_mut() { - let conv_start = cm.run_hooks(HookTrigger::ConversationStart, output).await?; - conversation_start_context = format_hook_context(&conv_start, HookTrigger::ConversationStart); - - if let (true, Some(next_message)) = (run_perprompt_hooks, self.next_message.as_mut()) { - let per_prompt = cm.run_hooks(HookTrigger::PerPrompt, output).await?; - if let Some(ctx) = format_hook_context(&per_prompt, HookTrigger::PerPrompt) { - next_message.additional_context = ctx; - } - } - } - - let (context_messages, dropped_context_files) = self.context_messages(os, conversation_start_context).await; - - Ok(BackendConversationState { - conversation_id: self.conversation_id.as_str(), - next_user_message: self.next_message.as_ref(), - history: self - .history - .range(self.valid_history_range.0..self.valid_history_range.1), - context_messages, - dropped_context_files, - tools: &self.tools, - model_id: self.model.as_deref(), - }) - } - - /// Returns a [FigConversationState] capable of replacing the history of the current - /// conversation with a summary generated by the model. - /// - /// The resulting summary should update the state by immediately following with - /// [ConversationState::replace_history_with_summary]. - pub async fn create_summary_request( - &mut self, - os: &Os, - custom_prompt: Option>, - strategy: CompactStrategy, - ) -> Result { - let mut summary_content = match custom_prompt { - Some(custom_prompt) => { - // Make the custom instructions much more prominent and directive - format!( - "[SYSTEM NOTE: This is an automated summarization request, not from the user]\n\n\ - FORMAT REQUIREMENTS: Create a structured, concise summary in bullet-point format. DO NOT respond conversationally. DO NOT address the user directly.\n\n\ - IMPORTANT CUSTOM INSTRUCTION: {}\n\n\ - Your task is to create a structured summary document containing:\n\ - 1) A bullet-point list of key topics/questions covered\n\ - 2) Bullet points for all significant tools executed and their results\n\ - 3) Bullet points for any code or technical information shared\n\ - 4) A section of key insights gained\n\n\ - FORMAT THE SUMMARY IN THIRD PERSON, NOT AS A DIRECT RESPONSE. Example format:\n\n\ - ## CONVERSATION SUMMARY\n\ - * Topic 1: Key information\n\ - * Topic 2: Key information\n\n\ - ## TOOLS EXECUTED\n\ - * Tool X: Result Y\n\n\ - Remember this is a DOCUMENT not a chat response. The custom instruction above modifies what to prioritize.\n\ - FILTER OUT CHAT CONVENTIONS (greetings, offers to help, etc).", - custom_prompt.as_ref() - ) - }, - None => { - // Default prompt - "[SYSTEM NOTE: This is an automated summarization request, not from the user]\n\n\ - FORMAT REQUIREMENTS: Create a structured, concise summary in bullet-point format. DO NOT respond conversationally. DO NOT address the user directly.\n\n\ - Your task is to create a structured summary document containing:\n\ - 1) A bullet-point list of key topics/questions covered\n\ - 2) Bullet points for all significant tools executed and their results\n\ - 3) Bullet points for any code or technical information shared\n\ - 4) A section of key insights gained\n\n\ - FORMAT THE SUMMARY IN THIRD PERSON, NOT AS A DIRECT RESPONSE. Example format:\n\n\ - ## CONVERSATION SUMMARY\n\ - * Topic 1: Key information\n\ - * Topic 2: Key information\n\n\ - ## TOOLS EXECUTED\n\ - * Tool X: Result Y\n\n\ - Remember this is a DOCUMENT not a chat response.\n\ - FILTER OUT CHAT CONVENTIONS (greetings, offers to help, etc).".to_string() - }, - }; - if let Some(summary) = &self.latest_summary { - summary_content.push_str("\n\n"); - summary_content.push_str(CONTEXT_ENTRY_START_HEADER); - summary_content.push_str("This summary contains ALL relevant information from our previous conversation including tool uses, results, code analysis, and file operations. YOU MUST be sure to include this information when creating your summarization document.\n\n"); - summary_content.push_str("SUMMARY CONTENT:\n"); - summary_content.push_str(summary); - summary_content.push('\n'); - summary_content.push_str(CONTEXT_ENTRY_END_HEADER); - } - - let conv_state = self.backend_conversation_state(os, false, &mut vec![]).await?; - let mut summary_message = Some(UserMessage::new_prompt(summary_content.clone())); - - // Create the history according to the passed compact strategy. - let mut history = conv_state.history.cloned().collect::>(); - history.drain((history.len().saturating_sub(strategy.messages_to_exclude))..); - if strategy.truncate_large_messages { - for (user_message, _) in &mut history { - user_message.truncate_safe(strategy.max_message_length); - } - } - - // Only send the dummy tool spec in order to prevent the model from ever attempting a tool - // use. - let mut tools = self.tools.clone(); - tools.retain(|k, v| match k { - ToolOrigin::Native => { - v.retain(|tool| match tool { - Tool::ToolSpecification(tool_spec) => tool_spec.name == DUMMY_TOOL_NAME, - }); - true - }, - ToolOrigin::McpServer(_) => false, - }); - - enforce_conversation_invariants(&mut history, &mut summary_message, &tools); - - Ok(FigConversationState { - conversation_id: Some(self.conversation_id.clone()), - user_input_message: summary_message - .unwrap_or(UserMessage::new_prompt(summary_content)) // should not happen - .into_user_input_message(self.model.clone(), &tools), - history: Some(flatten_history(history.iter())), - }) - } - - /// `strategy` - The [CompactStrategy] used for the corresponding - /// [ConversationState::create_summary_request]. - pub fn replace_history_with_summary(&mut self, summary: String, strategy: CompactStrategy) { - self.history - .drain(..(self.history.len().saturating_sub(strategy.messages_to_exclude))); - self.latest_summary = Some(summary); - } - - pub fn current_profile(&self) -> Option<&str> { - if let Some(cm) = self.context_manager.as_ref() { - Some(cm.current_profile.as_str()) - } else { - None - } - } - - /// Returns pairs of user and assistant messages to include as context in the message history - /// including both summaries and context files if available, and the dropped context files. - /// - /// TODO: - /// - Either add support for multiple context messages if the context is too large to fit inside - /// a single user message, or handle this case more gracefully. For now, always return 2 - /// messages. - /// - Cache this return for some period of time. - async fn context_messages( - &mut self, - os: &Os, - conversation_start_context: Option, - ) -> (Option>, Vec<(String, String)>) { - let mut context_content = String::new(); - let mut dropped_context_files = Vec::new(); - if let Some(summary) = &self.latest_summary { - context_content.push_str(CONTEXT_ENTRY_START_HEADER); - context_content.push_str("This summary contains ALL relevant information from our previous conversation including tool uses, results, code analysis, and file operations. YOU MUST reference this information when answering questions and explicitly acknowledge specific details from the summary when they're relevant to the current question.\n\n"); - context_content.push_str("SUMMARY CONTENT:\n"); - context_content.push_str(summary); - context_content.push('\n'); - context_content.push_str(CONTEXT_ENTRY_END_HEADER); - } - - // Add context files if available - if let Some(context_manager) = self.context_manager.as_mut() { - match context_manager.collect_context_files_with_limit(os).await { - Ok((files_to_use, files_dropped)) => { - if !files_dropped.is_empty() { - dropped_context_files.extend(files_dropped); - } - - if !files_to_use.is_empty() { - context_content.push_str(CONTEXT_ENTRY_START_HEADER); - for (filename, content) in files_to_use { - context_content.push_str(&format!("[{}]\n{}\n", filename, content)); - } - context_content.push_str(CONTEXT_ENTRY_END_HEADER); - } - }, - Err(e) => { - warn!("Failed to get context files: {}", e); - }, - } - } - - if let Some(context) = conversation_start_context { - context_content.push_str(&context); - } - - if !context_content.is_empty() { - self.context_message_length = Some(context_content.len()); - let user_msg = UserMessage::new_prompt(context_content); - let assistant_msg = AssistantMessage::new_response(None, "I will fully incorporate this information when generating my responses, and explicitly acknowledge relevant parts of the summary when answering questions.".into()); - (Some(vec![(user_msg, assistant_msg)]), dropped_context_files) - } else { - (None, dropped_context_files) - } - } - - /// The length of the user message used as context, if any. - pub fn context_message_length(&self) -> Option { - self.context_message_length - } - - /// Calculate the total character count in the conversation - pub async fn calculate_char_count(&mut self, os: &Os) -> Result { - Ok(self - .backend_conversation_state(os, false, &mut vec![]) - .await? - .char_count()) - } - - /// Get the current token warning level - pub async fn get_token_warning_level(&mut self, os: &Os) -> Result { - let total_chars = self.calculate_char_count(os).await?; - - Ok(if *total_chars >= MAX_CHARS { - TokenWarningLevel::Critical - } else { - TokenWarningLevel::None - }) - } - - pub fn append_user_transcript(&mut self, message: &str) { - self.append_transcript(format!("> {}", message.replace("\n", "> \n"))); - } - - pub fn append_assistant_transcript(&mut self, message: &AssistantMessage) { - let tool_uses = message.tool_uses().map_or("none".to_string(), |tools| { - tools.iter().map(|tool| tool.name.clone()).collect::>().join(",") - }); - self.append_transcript(format!("{}\n[Tool uses: {tool_uses}]", message.content())); - } - - pub fn append_transcript(&mut self, message: String) { - if self.transcript.len() >= MAX_CONVERSATION_STATE_HISTORY_LEN { - self.transcript.pop_front(); - } - self.transcript.push_back(message); - } -} - -/// Represents a conversation state that can be converted into a [FigConversationState] (the type -/// used by the API client). Represents borrowed data, and reflects an exact [FigConversationState] -/// that can be generated from [ConversationState] at any point in time. -/// -/// This is intended to provide us ways to accurately assess the exact state that is sent to the -/// model without having to needlessly clone and mutate [ConversationState] in strange ways. -pub type BackendConversationState<'a> = BackendConversationStateImpl< - 'a, - std::collections::vec_deque::Iter<'a, (UserMessage, AssistantMessage)>, - Option>, ->; - -/// See [BackendConversationState] -#[derive(Debug, Clone)] -pub struct BackendConversationStateImpl<'a, T, U> { - pub conversation_id: &'a str, - pub next_user_message: Option<&'a UserMessage>, - pub history: T, - pub context_messages: U, - pub dropped_context_files: Vec<(String, String)>, - pub tools: &'a HashMap>, - pub model_id: Option<&'a str>, -} - -impl - BackendConversationStateImpl< - '_, - std::collections::vec_deque::Iter<'_, (UserMessage, AssistantMessage)>, - Option>, - > -{ - fn into_fig_conversation_state(self) -> eyre::Result { - let history = flatten_history(self.context_messages.unwrap_or_default().iter().chain(self.history)); - let user_input_message: UserInputMessage = self - .next_user_message - .cloned() - .map(|msg| msg.into_user_input_message(self.model_id.map(str::to_string), self.tools)) - .ok_or(eyre::eyre!("next user message is not set"))?; - - Ok(FigConversationState { - conversation_id: Some(self.conversation_id.to_string()), - user_input_message, - history: Some(history), - }) - } - - pub fn calculate_conversation_size(&self) -> ConversationSize { - let mut user_chars = 0; - let mut assistant_chars = 0; - let mut context_chars = 0; - - // Count the chars used by the messages in the history. - // this clone is cheap - let history = self.history.clone(); - for (user, assistant) in history { - user_chars += *user.char_count(); - assistant_chars += *assistant.char_count(); - } - - // Add any chars from context messages, if available. - context_chars += self - .context_messages - .as_ref() - .map(|v| { - v.iter().fold(0, |acc, (user, assistant)| { - acc + *user.char_count() + *assistant.char_count() - }) - }) - .unwrap_or_default(); - - ConversationSize { - context_messages: context_chars.into(), - user_messages: user_chars.into(), - assistant_messages: assistant_chars.into(), - } - } -} - -/// Reflects a detailed accounting of the context window utilization for a given conversation. -#[derive(Debug, Clone, Copy)] -pub struct ConversationSize { - pub context_messages: CharCount, - pub user_messages: CharCount, - pub assistant_messages: CharCount, -} - -/// Converts a list of user/assistant message pairs into a flattened list of ChatMessage. -fn flatten_history<'a, T>(history: T) -> Vec -where - T: Iterator, -{ - history.fold(Vec::new(), |mut acc, (user, assistant)| { - acc.push(ChatMessage::UserInputMessage(user.clone().into_history_entry())); - acc.push(ChatMessage::AssistantResponseMessage(assistant.clone().into())); - acc - }) -} - -/// Character count warning levels for conversation size -#[derive(Debug, Clone, PartialEq, Eq)] -pub enum TokenWarningLevel { - /// No warning, conversation is within normal limits - None, - /// Critical level - at single warning threshold (600K characters) - Critical, -} - -impl From for ToolInputSchema { - fn from(value: InputSchema) -> Self { - Self { - json: Some(serde_value_to_document(value.0).into()), - } - } -} - -/// Formats hook output to be used within context blocks (e.g., in context messages or in new user -/// prompts). -/// -/// # Returns -/// [Option::Some] if `hook_results` is not empty and at least one hook has content. Otherwise, -/// [Option::None] -fn format_hook_context(hook_results: &[(Hook, String)], trigger: HookTrigger) -> Option { - if hook_results.iter().all(|(_, content)| content.is_empty()) { - return None; - } - - let mut context_content = String::new(); - - context_content.push_str(CONTEXT_ENTRY_START_HEADER); - context_content.push_str("This section (like others) contains important information that I want you to use in your responses. I have gathered this context from valuable programmatic script hooks. You must follow any requests and consider all of the information in this section"); - if trigger == HookTrigger::ConversationStart { - context_content.push_str(" for the entire conversation"); - } - context_content.push_str("\n\n"); - - for (hook, output) in hook_results.iter().filter(|(h, _)| h.trigger == trigger) { - context_content.push_str(&format!("'{}': {output}\n\n", &hook.name)); - } - context_content.push_str(CONTEXT_ENTRY_END_HEADER); - Some(context_content) -} - -fn enforce_conversation_invariants( - history: &mut VecDeque<(UserMessage, AssistantMessage)>, - next_message: &mut Option, - tools: &HashMap>, -) -> (usize, usize) { - // First set the valid range as the entire history - this will be truncated as necessary - // later below. - let mut valid_history_range = (0, history.len()); - - // Trim the conversation history by finding the second oldest message from the user without - // tool results - this will be the new oldest message in the history. - // - // Note that we reserve extra slots for [ConversationState::context_messages]. - if (history.len() * 2) > MAX_CONVERSATION_STATE_HISTORY_LEN - 6 { - match history - .iter() - .enumerate() - .skip(1) - .find(|(_, (m, _))| -> bool { !m.has_tool_use_results() }) - .map(|v| v.0) - { - Some(i) => { - debug!("removing the first {i} user/assistant response pairs in the history"); - valid_history_range.0 = i; - }, - None => { - debug!("no valid starting user message found in the history, clearing"); - valid_history_range = (0, 0); - // Edge case: if the next message contains tool results, then we have to just - // abandon them. - if next_message.as_ref().is_some_and(|m| m.has_tool_use_results()) { - debug!("abandoning tool results"); - *next_message = Some(UserMessage::new_prompt( - "The conversation history has overflowed, clearing state".to_string(), - )); - } - }, - } - } - - // If the first message contains tool results, then we add the results to the content field - // instead. This is required to avoid validation errors. - if let Some((user, _)) = history.front_mut() { - if user.has_tool_use_results() { - user.replace_content_with_tool_use_results(); - } - } - - // If the next message is set with tool results, but the previous assistant message is not a - // tool use, then we add the results to the content field instead. - match ( - next_message.as_mut(), - history.range(valid_history_range.0..valid_history_range.1).last(), - ) { - (Some(next_message), prev_msg) if next_message.has_tool_use_results() => match prev_msg { - None | Some((_, AssistantMessage::Response { .. })) => { - next_message.replace_content_with_tool_use_results(); - }, - _ => (), - }, - (_, _) => (), - } - - // If the last message from the assistant contains tool uses AND next_message is set, we need to - // ensure that next_message contains tool results. - if let (Some((_, AssistantMessage::ToolUse { tool_uses, .. })), Some(user_msg)) = ( - history.range(valid_history_range.0..valid_history_range.1).last(), - next_message, - ) { - if !user_msg.has_tool_use_results() { - debug!( - "last assistant message contains tool uses, but next message is set and does not contain tool results. setting tool results as cancelled" - ); - *user_msg = UserMessage::new_cancelled_tool_uses( - user_msg.prompt().map(|p| p.to_string()), - tool_uses.iter().map(|t| t.id.as_str()), - ); - } - } - - enforce_tool_use_history_invariants(history, tools); - - valid_history_range -} - -fn enforce_tool_use_history_invariants( - history: &mut VecDeque<(UserMessage, AssistantMessage)>, - tools: &HashMap>, -) { - let tool_names: HashSet<_> = tools - .values() - .flat_map(|tools| { - tools.iter().map(|tool| match tool { - Tool::ToolSpecification(tool_specification) => tool_specification.name.as_str(), - }) - }) - .filter(|name| *name != DUMMY_TOOL_NAME) - .collect(); - - for (_, assistant) in history { - if let AssistantMessage::ToolUse { tool_uses, .. } = assistant { - for tool_use in tool_uses { - if tool_names.contains(tool_use.name.as_str()) { - continue; - } - - if tool_names.contains(tool_use.orig_name.as_str()) { - tool_use.name = tool_use.orig_name.clone(); - tool_use.args = tool_use.orig_args.clone(); - continue; - } - - let names: Vec<&str> = tool_names - .iter() - .filter_map(|name| { - if name.ends_with(&tool_use.name) { - Some(*name) - } else { - None - } - }) - .collect(); - - // There's only one tool use matching, so we can just replace it with the - // found name. - if names.len() == 1 { - tool_use.name = (*names.first().unwrap()).to_string(); - continue; - } - - // Otherwise, we have to replace it with a dummy. - tool_use.name = DUMMY_TOOL_NAME.to_string(); - } - } - } -} - -#[cfg(test)] -mod tests { - use super::super::context::{ - AMAZONQ_FILENAME, - profile_context_path, - }; - use super::super::message::AssistantToolUse; - use super::*; - use crate::api_client::model::{ - AssistantResponseMessage, - ToolResultStatus, - }; - use crate::cli::chat::tool_manager::ToolManager; - - fn assert_conversation_state_invariants(state: FigConversationState, assertion_iteration: usize) { - if let Some(Some(msg)) = state.history.as_ref().map(|h| h.first()) { - assert!( - matches!(msg, ChatMessage::UserInputMessage(_)), - "{assertion_iteration}: First message in the history must be from the user, instead found: {:?}", - msg - ); - } - if let Some(Some(msg)) = state.history.as_ref().map(|h| h.last()) { - assert!( - matches!(msg, ChatMessage::AssistantResponseMessage(_)), - "{assertion_iteration}: Last message in the history must be from the assistant, instead found: {:?}", - msg - ); - // If the last message from the assistant contains tool uses, then the next user - // message must contain tool results. - match (state.user_input_message.user_input_message_context.as_ref(), msg) { - ( - Some(os), - ChatMessage::AssistantResponseMessage(AssistantResponseMessage { - tool_uses: Some(tool_uses), - .. - }), - ) if !tool_uses.is_empty() => { - assert!( - os.tool_results.as_ref().is_some_and(|r| !r.is_empty()), - "The user input message must contain tool results when the last assistant message contains tool uses" - ); - }, - _ => {}, - } - } - - if let Some(history) = state.history.as_ref() { - for (i, msg) in history.iter().enumerate() { - // User message checks. - if let ChatMessage::UserInputMessage(user) = msg { - assert!( - user.user_input_message_context - .as_ref() - .is_none_or(|os| os.tools.is_none()), - "the tool specification should be empty for all user messages in the history" - ); - - // Check that messages with tool results are immediately preceded by an - // assistant message with tool uses. - if user - .user_input_message_context - .as_ref() - .is_some_and(|os| os.tool_results.as_ref().is_some_and(|r| !r.is_empty())) - { - match history.get(i.checked_sub(1).unwrap_or_else(|| { - panic!( - "{assertion_iteration}: first message in the history should not contain tool results" - ) - })) { - Some(ChatMessage::AssistantResponseMessage(assistant)) => { - assert!(assistant.tool_uses.is_some()); - }, - _ => panic!( - "expected an assistant response message with tool uses at index: {}", - i - 1 - ), - } - } - } - } - } - - let actual_history_len = state.history.unwrap_or_default().len(); - assert!( - actual_history_len <= MAX_CONVERSATION_STATE_HISTORY_LEN, - "history should not extend past the max limit of {}, instead found length {}", - MAX_CONVERSATION_STATE_HISTORY_LEN, - actual_history_len - ); - - let os = state - .user_input_message - .user_input_message_context - .as_ref() - .expect("user input message context must exist"); - assert!( - os.tools.is_some(), - "Currently, the tool spec must be included in the next user message" - ); - } - - #[tokio::test] - async fn test_conversation_state_history_handling_truncation() { - let mut os = Os::new().await.unwrap(); - let mut tool_manager = ToolManager::default(); - let tools = tool_manager.load_tools(&mut os, &mut vec![]).await.unwrap(); - let mut conversation = ConversationState::new(&mut os, "fake_conv_id", tools, None, tool_manager, None).await; - - // First, build a large conversation history. We need to ensure that the order is always - // User -> Assistant -> User -> Assistant ...and so on. - conversation.set_next_user_message("start".to_string()).await; - for i in 0..=(MAX_CONVERSATION_STATE_HISTORY_LEN + 100) { - let s = conversation - .as_sendable_conversation_state(&os, &mut vec![], true) - .await - .unwrap(); - assert_conversation_state_invariants(s, i); - conversation.push_assistant_message(&mut os, AssistantMessage::new_response(None, i.to_string())); - conversation.set_next_user_message(i.to_string()).await; - } - } - - #[tokio::test] - async fn test_conversation_state_history_handling_with_tool_results() { - let mut os = Os::new().await.unwrap(); - - // Build a long conversation history of tool use results. - let mut tool_manager = ToolManager::default(); - let tool_config = tool_manager.load_tools(&mut os, &mut vec![]).await.unwrap(); - let mut conversation = ConversationState::new( - &mut os, - "fake_conv_id", - tool_config.clone(), - None, - tool_manager.clone(), - None, - ) - .await; - conversation.set_next_user_message("start".to_string()).await; - for i in 0..=(MAX_CONVERSATION_STATE_HISTORY_LEN + 100) { - let s = conversation - .as_sendable_conversation_state(&os, &mut vec![], true) - .await - .unwrap(); - assert_conversation_state_invariants(s, i); - - conversation.push_assistant_message( - &mut os, - AssistantMessage::new_tool_use(None, i.to_string(), vec![AssistantToolUse { - id: "tool_id".to_string(), - name: "tool name".to_string(), - args: serde_json::Value::Null, - ..Default::default() - }]), - ); - conversation.add_tool_results(vec![ToolUseResult { - tool_use_id: "tool_id".to_string(), - content: vec![], - status: ToolResultStatus::Success, - }]); - } - - // Build a long conversation history of user messages mixed in with tool results. - let mut conversation = ConversationState::new( - &mut os, - "fake_conv_id", - tool_config.clone(), - None, - tool_manager.clone(), - None, - ) - .await; - conversation.set_next_user_message("start".to_string()).await; - for i in 0..=(MAX_CONVERSATION_STATE_HISTORY_LEN + 100) { - let s = conversation - .as_sendable_conversation_state(&os, &mut vec![], true) - .await - .unwrap(); - assert_conversation_state_invariants(s, i); - if i % 3 == 0 { - conversation.push_assistant_message( - &mut os, - AssistantMessage::new_tool_use(None, i.to_string(), vec![AssistantToolUse { - id: "tool_id".to_string(), - name: "tool name".to_string(), - args: serde_json::Value::Null, - ..Default::default() - }]), - ); - conversation.add_tool_results(vec![ToolUseResult { - tool_use_id: "tool_id".to_string(), - content: vec![], - status: ToolResultStatus::Success, - }]); - } else { - conversation.push_assistant_message(&mut os, AssistantMessage::new_response(None, i.to_string())); - conversation.set_next_user_message(i.to_string()).await; - } - } - } - - #[tokio::test] - async fn test_conversation_state_with_context_files() { - let mut os = Os::new().await.unwrap(); - os.fs.write(AMAZONQ_FILENAME, "test context").await.unwrap(); - - let mut tool_manager = ToolManager::default(); - let tools = tool_manager.load_tools(&mut os, &mut vec![]).await.unwrap(); - let mut conversation = ConversationState::new(&mut os, "fake_conv_id", tools, None, tool_manager, None).await; - - // First, build a large conversation history. We need to ensure that the order is always - // User -> Assistant -> User -> Assistant ...and so on. - conversation.set_next_user_message("start".to_string()).await; - for i in 0..=(MAX_CONVERSATION_STATE_HISTORY_LEN + 100) { - let s = conversation - .as_sendable_conversation_state(&os, &mut vec![], true) - .await - .unwrap(); - - // Ensure that the first two messages are the fake context messages. - let hist = s.history.as_ref().unwrap(); - let user = &hist[0]; - let assistant = &hist[1]; - match (user, assistant) { - (ChatMessage::UserInputMessage(user), ChatMessage::AssistantResponseMessage(_)) => { - assert!( - user.content.contains("test context"), - "expected context message to contain context file, instead found: {}", - user.content - ); - }, - _ => panic!("Expected the first two messages to be from the user and the assistant"), - } - - assert_conversation_state_invariants(s, i); - - conversation.push_assistant_message(&mut os, AssistantMessage::new_response(None, i.to_string())); - conversation.set_next_user_message(i.to_string()).await; - } - } - - #[tokio::test] - async fn test_conversation_state_additional_context() { - let mut os = Os::new().await.unwrap(); - let mut tool_manager = ToolManager::default(); - let conversation_start_context = "conversation start context"; - let prompt_context = "prompt context"; - let config = serde_json::json!({ - "hooks": { - "test_per_prompt": { - "trigger": "per_prompt", - "type": "inline", - "command": format!("echo {}", prompt_context) - }, - "test_conversation_start": { - "trigger": "conversation_start", - "type": "inline", - "command": format!("echo {}", conversation_start_context) - } - } - }); - let config_path = profile_context_path(&os, "default").unwrap(); - os.fs.create_dir_all(config_path.parent().unwrap()).await.unwrap(); - os.fs - .write(&config_path, serde_json::to_string(&config).unwrap()) - .await - .unwrap(); - let tools = tool_manager.load_tools(&mut os, &mut vec![]).await.unwrap(); - let mut conversation = ConversationState::new(&mut os, "fake_conv_id", tools, None, tool_manager, None).await; - - // Simulate conversation flow - conversation.set_next_user_message("start".to_string()).await; - for i in 0..=5 { - let s = conversation - .as_sendable_conversation_state(&os, &mut vec![], true) - .await - .unwrap(); - let hist = s.history.as_ref().unwrap(); - #[allow(clippy::match_wildcard_for_single_variants)] - match &hist[0] { - ChatMessage::UserInputMessage(user) => { - assert!( - user.content.contains(conversation_start_context), - "expected to contain '{conversation_start_context}', instead found: {}", - user.content - ); - }, - _ => panic!("Expected user message."), - } - assert!( - s.user_input_message.content.contains(prompt_context), - "expected to contain '{prompt_context}', instead found: {}", - s.user_input_message.content - ); - - conversation.push_assistant_message(&mut os, AssistantMessage::new_response(None, i.to_string())); - conversation.set_next_user_message(i.to_string()).await; - } - } -} diff --git a/crates/chat-cli/src/cli/chat/error_formatter.rs b/crates/chat-cli/src/cli/chat/error_formatter.rs deleted file mode 100644 index 96a604bdb..000000000 --- a/crates/chat-cli/src/cli/chat/error_formatter.rs +++ /dev/null @@ -1,148 +0,0 @@ -/// Formats an MCP error message to be more user-friendly. -/// -/// This function extracts nested JSON from the error message and formats it -/// with proper indentation and newlines. -/// -/// # Arguments -/// -/// * `err` - A reference to a serde_json::Value containing the error information -/// -/// # Returns -/// -/// A formatted string representation of the error message -pub fn format_mcp_error(err: &serde_json::Value) -> String { - // Extract the message field from the error JSON - if let Some(message) = err.get("message").and_then(|m| m.as_str()) { - // Check if the message contains a nested JSON array - if let Some(start_idx) = message.find('[') { - if let Some(end_idx) = message.rfind(']') { - let prefix = &message[..start_idx].trim(); - let nested_json = &message[start_idx..=end_idx]; - - // Try to parse the nested JSON - if let Ok(nested_value) = serde_json::from_str::(nested_json) { - // Format the error message with the prefix and pretty-printed nested JSON - return format!( - "{}\n{}", - prefix, - serde_json::to_string_pretty(&nested_value).unwrap_or_else(|_| nested_json.to_string()) - ); - } - } - } - } - - // Fallback if message field is missing or if we couldn't extract and parse nested JSON - serde_json::to_string_pretty(err).unwrap_or_else(|_| format!("{:?}", err)) -} - -#[cfg(test)] -mod tests { - use serde_json::json; - - use super::*; - - #[test] - fn test_format_mcp_error_with_nested_json() { - let error = json!({ - "code": -32602, - "message": "MCP error -32602: Invalid arguments for prompt agent_script_coco_was_sev2_ticket_details_retrieve: [\n {\n \"code\": \"invalid_type\",\n \"expected\": \"object\",\n \"received\": \"undefined\",\n \"path\": [],\n \"message\": \"Required\"\n }\n]" - }); - - let formatted = format_mcp_error(&error); - - // Extract the prefix and JSON part from the formatted string - let parts: Vec<&str> = formatted.split('\n').collect(); - let prefix = parts[0]; - let json_part = &formatted[prefix.len() + 1..]; - - // Check that the prefix is correct - assert_eq!( - prefix, - "MCP error -32602: Invalid arguments for prompt agent_script_coco_was_sev2_ticket_details_retrieve:" - ); - - // Parse the JSON part to compare the actual content rather than the exact string - let parsed_json: serde_json::Value = serde_json::from_str(json_part).expect("Failed to parse JSON part"); - - // Expected JSON structure - let expected_json = json!([ - { - "code": "invalid_type", - "expected": "object", - "received": "undefined", - "path": [], - "message": "Required" - } - ]); - - // Compare the parsed JSON values - assert_eq!(parsed_json, expected_json); - } - - #[test] - fn test_format_mcp_error_without_nested_json() { - let error = json!({ - "code": -32602, - "message": "MCP error -32602: Invalid arguments for prompt" - }); - - let formatted = format_mcp_error(&error); - - assert_eq!( - formatted, - "{\n \"code\": -32602,\n \"message\": \"MCP error -32602: Invalid arguments for prompt\"\n}" - ); - } - - #[test] - fn test_format_mcp_error_non_mcp_error() { - let error = json!({ - "error": "Unknown error occurred" - }); - - let formatted = format_mcp_error(&error); - - // Should pretty-print the entire error - assert_eq!(formatted, "{\n \"error\": \"Unknown error occurred\"\n}"); - } - - #[test] - fn test_format_mcp_error_empty_message() { - let error = json!({ - "code": -32602, - "message": "" - }); - - let formatted = format_mcp_error(&error); - - assert_eq!(formatted, "{\n \"code\": -32602,\n \"message\": \"\"\n}"); - } - - #[test] - fn test_format_mcp_error_missing_message() { - let error = json!({ - "code": -32602 - }); - - let formatted = format_mcp_error(&error); - - assert_eq!(formatted, "{\n \"code\": -32602\n}"); - } - - #[test] - fn test_format_mcp_error_malformed_nested_json() { - let error = json!({ - "code": -32602, - "message": "MCP error -32602: Invalid arguments for prompt: [{\n \"code\": \"invalid_type\",\n \"expected\": \"object\",\n \"received\": \"undefined\",\n \"path\": [],\n \"message\": \"Required\"\n" - }); - - let formatted = format_mcp_error(&error); - - // Should return the pretty-printed JSON since the nested JSON is malformed - assert_eq!( - formatted, - "{\n \"code\": -32602,\n \"message\": \"MCP error -32602: Invalid arguments for prompt: [{\\n \\\"code\\\": \\\"invalid_type\\\",\\n \\\"expected\\\": \\\"object\\\",\\n \\\"received\\\": \\\"undefined\\\",\\n \\\"path\\\": [],\\n \\\"message\\\": \\\"Required\\\"\\n\"\n}" - ); - } -} diff --git a/crates/chat-cli/src/cli/chat/input_source.rs b/crates/chat-cli/src/cli/chat/input_source.rs deleted file mode 100644 index 028b2e288..000000000 --- a/crates/chat-cli/src/cli/chat/input_source.rs +++ /dev/null @@ -1,126 +0,0 @@ -use eyre::Result; -use rustyline::error::ReadlineError; - -use super::prompt::rl; -#[cfg(unix)] -use super::skim_integration::SkimCommandSelector; -use crate::os::Os; - -#[derive(Debug)] -pub struct InputSource(inner::Inner); - -mod inner { - use rustyline::Editor; - use rustyline::history::FileHistory; - - use super::super::prompt::ChatHelper; - - #[allow(clippy::large_enum_variant)] - #[derive(Debug)] - pub enum Inner { - Readline(Editor), - #[allow(dead_code)] - Mock { - index: usize, - lines: Vec, - }, - } -} - -impl InputSource { - pub fn new( - os: &Os, - sender: std::sync::mpsc::Sender>, - receiver: std::sync::mpsc::Receiver>, - ) -> Result { - Ok(Self(inner::Inner::Readline(rl(os, sender, receiver)?))) - } - - #[cfg(unix)] - pub fn put_skim_command_selector( - &mut self, - os: &Os, - context_manager: std::sync::Arc, - tool_names: Vec, - ) { - use rustyline::{ - EventHandler, - KeyEvent, - }; - - use crate::database::settings::Setting; - - if let inner::Inner::Readline(rl) = &mut self.0 { - let key_char = match os.database.settings.get_string(Setting::SkimCommandKey) { - Some(key) if key.len() == 1 => key.chars().next().unwrap_or('s'), - _ => 's', // Default to 's' if setting is missing or invalid - }; - rl.bind_sequence( - KeyEvent::ctrl(key_char), - EventHandler::Conditional(Box::new(SkimCommandSelector::new( - os.clone(), - context_manager, - tool_names, - ))), - ); - } - } - - #[allow(dead_code)] - pub fn new_mock(lines: Vec) -> Self { - Self(inner::Inner::Mock { index: 0, lines }) - } - - pub fn read_line(&mut self, prompt: Option<&str>) -> Result, ReadlineError> { - match &mut self.0 { - inner::Inner::Readline(rl) => { - let prompt = prompt.unwrap_or_default(); - let curr_line = rl.readline(prompt); - match curr_line { - Ok(line) => { - let _ = rl.add_history_entry(line.as_str()); - - if let Some(helper) = rl.helper_mut() { - helper.update_hinter_history(&line); - } - - Ok(Some(line)) - }, - Err(ReadlineError::Interrupted | ReadlineError::Eof) => Ok(None), - Err(err) => Err(err), - } - }, - inner::Inner::Mock { index, lines } => { - *index += 1; - Ok(lines.get(*index - 1).cloned()) - }, - } - } - - // We're keeping this method for potential future use - #[allow(dead_code)] - pub fn set_buffer(&mut self, content: &str) { - if let inner::Inner::Readline(rl) = &mut self.0 { - // Add to history so user can access it with up arrow - let _ = rl.add_history_entry(content); - } - } -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn test_mock_input_source() { - let l1 = "Hello,".to_string(); - let l2 = "Line 2".to_string(); - let l3 = "World!".to_string(); - let mut input = InputSource::new_mock(vec![l1.clone(), l2.clone(), l3.clone()]); - - assert_eq!(input.read_line(None).unwrap().unwrap(), l1); - assert_eq!(input.read_line(None).unwrap().unwrap(), l2); - assert_eq!(input.read_line(None).unwrap().unwrap(), l3); - assert!(input.read_line(None).unwrap().is_none()); - } -} diff --git a/crates/chat-cli/src/cli/chat/message.rs b/crates/chat-cli/src/cli/chat/message.rs deleted file mode 100644 index 756398981..000000000 --- a/crates/chat-cli/src/cli/chat/message.rs +++ /dev/null @@ -1,514 +0,0 @@ -use std::collections::HashMap; -use std::env; - -use serde::{ - Deserialize, - Serialize, -}; -use tracing::{ - error, - warn, -}; - -use super::consts::{ - MAX_CURRENT_WORKING_DIRECTORY_LEN, - MAX_USER_MESSAGE_SIZE, -}; -use super::tools::{ - InvokeOutput, - OutputKind, - ToolOrigin, -}; -use super::util::{ - document_to_serde_value, - serde_value_to_document, - truncate_safe, - truncate_safe_in_place, -}; -use crate::api_client::model::{ - AssistantResponseMessage, - EnvState, - ImageBlock, - Tool, - ToolResult, - ToolResultContentBlock, - ToolResultStatus, - ToolUse, - UserInputMessage, - UserInputMessageContext, -}; - -const USER_ENTRY_START_HEADER: &str = "--- USER MESSAGE BEGIN ---\n"; -const USER_ENTRY_END_HEADER: &str = "--- USER MESSAGE END ---\n\n"; - -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct UserMessage { - pub additional_context: String, - pub env_context: UserEnvContext, - pub content: UserMessageContent, - pub images: Option>, -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -pub enum UserMessageContent { - Prompt { - /// The original prompt as input by the user. - prompt: String, - }, - CancelledToolUses { - /// The original prompt as input by the user, if any. - prompt: Option, - tool_use_results: Vec, - }, - ToolUseResults { - tool_use_results: Vec, - }, -} - -impl UserMessageContent { - pub const TRUNCATED_SUFFIX: &str = "...content truncated due to length"; - - fn truncate_safe(&mut self, max_bytes: usize) { - match self { - UserMessageContent::Prompt { prompt } => { - truncate_safe_in_place(prompt, max_bytes, Self::TRUNCATED_SUFFIX); - }, - UserMessageContent::CancelledToolUses { - prompt, - tool_use_results, - } => { - if let Some(prompt) = prompt { - truncate_safe_in_place(prompt, max_bytes / 2, Self::TRUNCATED_SUFFIX); - truncate_safe_tool_use_results( - tool_use_results.as_mut_slice(), - max_bytes / 2, - Self::TRUNCATED_SUFFIX, - ); - } else { - truncate_safe_tool_use_results(tool_use_results.as_mut_slice(), max_bytes, Self::TRUNCATED_SUFFIX); - } - }, - UserMessageContent::ToolUseResults { tool_use_results } => { - truncate_safe_tool_use_results(tool_use_results.as_mut_slice(), max_bytes, Self::TRUNCATED_SUFFIX); - }, - } - } -} - -impl UserMessage { - /// Creates a new [UserMessage::Prompt], automatically detecting and adding the user's - /// environment [UserEnvContext]. - pub fn new_prompt(prompt: String) -> Self { - Self { - images: None, - additional_context: String::new(), - env_context: UserEnvContext::generate_new(), - content: UserMessageContent::Prompt { prompt }, - } - } - - pub fn new_cancelled_tool_uses<'a>(prompt: Option, tool_use_ids: impl Iterator) -> Self { - Self { - images: None, - additional_context: String::new(), - env_context: UserEnvContext::generate_new(), - content: UserMessageContent::CancelledToolUses { - prompt, - tool_use_results: tool_use_ids - .map(|id| ToolUseResult { - tool_use_id: id.to_string(), - content: vec![ToolUseResultBlock::Text( - "Tool use was cancelled by the user".to_string(), - )], - status: ToolResultStatus::Error, - }) - .collect(), - }, - } - } - - pub fn new_tool_use_results(results: Vec) -> Self { - Self { - additional_context: String::new(), - env_context: UserEnvContext::generate_new(), - content: UserMessageContent::ToolUseResults { - tool_use_results: results, - }, - images: None, - } - } - - pub fn new_tool_use_results_with_images(results: Vec, images: Vec) -> Self { - Self { - additional_context: String::new(), - env_context: UserEnvContext::generate_new(), - content: UserMessageContent::ToolUseResults { - tool_use_results: results, - }, - images: Some(images), - } - } - - /// Converts this message into a [UserInputMessage] to be stored in the history of - /// [api_client::model::ConversationState]. - pub fn into_history_entry(self) -> UserInputMessage { - UserInputMessage { - images: None, - content: self.prompt().unwrap_or_default().to_string(), - user_input_message_context: Some(UserInputMessageContext { - env_state: self.env_context.env_state, - tool_results: match self.content { - UserMessageContent::CancelledToolUses { tool_use_results, .. } - | UserMessageContent::ToolUseResults { tool_use_results } => { - Some(tool_use_results.into_iter().map(Into::into).collect()) - }, - UserMessageContent::Prompt { .. } => None, - }, - tools: None, - ..Default::default() - }), - user_intent: None, - model_id: None, - } - } - - /// Converts this message into a [UserInputMessage] to be sent as - /// [FigConversationState::user_input_message]. - pub fn into_user_input_message( - self, - model_id: Option, - tools: &HashMap>, - ) -> UserInputMessage { - let formatted_prompt = match self.prompt() { - Some(prompt) if !prompt.is_empty() => { - format!("{}{}{}", USER_ENTRY_START_HEADER, prompt, USER_ENTRY_END_HEADER) - }, - _ => String::new(), - }; - UserInputMessage { - images: self.images, - content: format!("{} {}", self.additional_context, formatted_prompt) - .trim() - .to_string(), - user_input_message_context: Some(UserInputMessageContext { - env_state: self.env_context.env_state, - tool_results: match self.content { - UserMessageContent::CancelledToolUses { tool_use_results, .. } - | UserMessageContent::ToolUseResults { tool_use_results } => { - Some(tool_use_results.into_iter().map(Into::into).collect()) - }, - UserMessageContent::Prompt { .. } => None, - }, - tools: if tools.is_empty() { - None - } else { - Some(tools.values().flatten().cloned().collect::>()) - }, - ..Default::default() - }), - user_intent: None, - model_id, - } - } - - pub fn has_tool_use_results(&self) -> bool { - match self.content() { - UserMessageContent::CancelledToolUses { .. } | UserMessageContent::ToolUseResults { .. } => true, - UserMessageContent::Prompt { .. } => false, - } - } - - pub fn tool_use_results(&self) -> Option<&[ToolUseResult]> { - match self.content() { - UserMessageContent::Prompt { .. } => None, - UserMessageContent::CancelledToolUses { tool_use_results, .. } => Some(tool_use_results.as_slice()), - UserMessageContent::ToolUseResults { tool_use_results } => Some(tool_use_results.as_slice()), - } - } - - pub fn additional_context(&self) -> &str { - &self.additional_context - } - - pub fn content(&self) -> &UserMessageContent { - &self.content - } - - pub fn prompt(&self) -> Option<&str> { - match self.content() { - UserMessageContent::Prompt { prompt } => Some(prompt.as_str()), - UserMessageContent::CancelledToolUses { prompt, .. } => prompt.as_ref().map(|s| s.as_str()), - UserMessageContent::ToolUseResults { .. } => None, - } - } - - /// Truncates the content contained in this user message to a maximum length of `max_bytes`. - pub fn truncate_safe(&mut self, max_bytes: usize) { - self.content.truncate_safe(max_bytes); - } - - pub fn replace_content_with_tool_use_results(&mut self) { - if let Some(tool_results) = self.tool_use_results() { - let tool_content: Vec = tool_results - .iter() - .flat_map(|tr| { - tr.content.iter().map(|c| match c { - ToolUseResultBlock::Json(document) => serde_json::to_string(&document) - .map_err(|err| error!(?err, "failed to serialize tool result")) - .unwrap_or_default(), - ToolUseResultBlock::Text(s) => s.clone(), - }) - }) - .collect::<_>(); - let mut tool_content = tool_content.join(" "); - if tool_content.is_empty() { - // To avoid validation errors with empty content, we need to make sure - // something is set. - tool_content.push_str(""); - } - let prompt = truncate_safe(&tool_content, MAX_USER_MESSAGE_SIZE).to_string(); - self.content = UserMessageContent::Prompt { prompt }; - } - } -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct ToolUseResult { - /// The ID for the tool request. - pub tool_use_id: String, - /// Content of the tool result. - pub content: Vec, - /// Status of the tool result. - pub status: ToolResultStatus, -} - -impl From for ToolUseResult { - fn from(value: ToolResult) -> Self { - Self { - tool_use_id: value.tool_use_id, - content: value.content.into_iter().map(Into::into).collect(), - status: value.status, - } - } -} - -impl From for ToolResult { - fn from(value: ToolUseResult) -> Self { - Self { - tool_use_id: value.tool_use_id, - content: value.content.into_iter().map(Into::into).collect(), - status: value.status, - } - } -} - -fn truncate_safe_tool_use_results(tool_use_results: &mut [ToolUseResult], max_bytes: usize, truncated_suffix: &str) { - let max_bytes = max_bytes / tool_use_results.len(); - for result in tool_use_results { - for content in &mut result.content { - match content { - ToolUseResultBlock::Json(value) => match serde_json::to_string(value) { - Ok(mut value_str) => { - if value_str.len() > max_bytes { - truncate_safe_in_place(&mut value_str, max_bytes, truncated_suffix); - *content = ToolUseResultBlock::Text(value_str); - return; - } - }, - Err(err) => { - warn!(?err, "Unable to truncate JSON"); - }, - }, - ToolUseResultBlock::Text(t) => { - truncate_safe_in_place(t, max_bytes, truncated_suffix); - }, - } - } - } -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -pub enum ToolUseResultBlock { - Json(serde_json::Value), - Text(String), -} - -impl From for ToolResultContentBlock { - fn from(value: ToolUseResultBlock) -> Self { - match value { - ToolUseResultBlock::Json(v) => Self::Json(serde_value_to_document(v)), - ToolUseResultBlock::Text(s) => Self::Text(s), - } - } -} - -impl From for ToolUseResultBlock { - fn from(value: ToolResultContentBlock) -> Self { - match value { - ToolResultContentBlock::Json(v) => Self::Json(document_to_serde_value(v)), - ToolResultContentBlock::Text(s) => Self::Text(s), - } - } -} - -impl From for ToolUseResultBlock { - fn from(value: InvokeOutput) -> Self { - match value.output { - OutputKind::Text(text) => Self::Text(text), - OutputKind::Json(value) => Self::Json(value), - OutputKind::Images(_) => Self::Text("See images data supplied".to_string()), - } - } -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct UserEnvContext { - env_state: Option, -} - -impl UserEnvContext { - pub fn generate_new() -> Self { - Self { - env_state: Some(build_env_state()), - } - } -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -pub enum AssistantMessage { - /// Normal response containing no tool uses. - Response { - message_id: Option, - content: String, - }, - /// An assistant message containing tool uses. - ToolUse { - message_id: Option, - content: String, - tool_uses: Vec, - }, -} - -impl AssistantMessage { - pub fn new_response(message_id: Option, content: String) -> Self { - Self::Response { message_id, content } - } - - pub fn new_tool_use(message_id: Option, content: String, tool_uses: Vec) -> Self { - Self::ToolUse { - message_id, - content, - tool_uses, - } - } - - pub fn message_id(&self) -> Option<&str> { - match self { - AssistantMessage::Response { message_id, .. } => message_id.as_ref().map(|s| s.as_str()), - AssistantMessage::ToolUse { message_id, .. } => message_id.as_ref().map(|s| s.as_str()), - } - } - - pub fn content(&self) -> &str { - match self { - AssistantMessage::Response { content, .. } => content.as_str(), - AssistantMessage::ToolUse { content, .. } => content.as_str(), - } - } - - pub fn tool_uses(&self) -> Option<&[AssistantToolUse]> { - match self { - AssistantMessage::ToolUse { tool_uses, .. } => Some(tool_uses.as_slice()), - AssistantMessage::Response { .. } => None, - } - } -} - -impl From for AssistantResponseMessage { - fn from(value: AssistantMessage) -> Self { - let (message_id, content, tool_uses) = match value { - AssistantMessage::Response { message_id, content } => (message_id, content, None), - AssistantMessage::ToolUse { - message_id, - content, - tool_uses, - } => ( - message_id, - content, - Some(tool_uses.into_iter().map(Into::into).collect()), - ), - }; - Self { - message_id, - content, - tool_uses, - } - } -} - -#[derive(Default, Debug, Clone, Serialize, Deserialize)] -pub struct AssistantToolUse { - /// The ID for the tool request. - pub id: String, - /// The name for the tool as exposed to the model - pub name: String, - /// Original name of the tool - pub orig_name: String, - /// The input to pass to the tool as exposed to the model - pub args: serde_json::Value, - /// Original input passed to the tool - pub orig_args: serde_json::Value, -} - -impl From for ToolUse { - fn from(value: AssistantToolUse) -> Self { - Self { - tool_use_id: value.id, - name: value.name, - input: serde_value_to_document(value.args).into(), - } - } -} - -impl From for AssistantToolUse { - fn from(value: ToolUse) -> Self { - Self { - id: value.tool_use_id, - name: value.name, - args: document_to_serde_value(value.input.into()), - ..Default::default() - } - } -} - -pub fn build_env_state() -> EnvState { - let mut env_state = EnvState { - operating_system: Some(env::consts::OS.into()), - ..Default::default() - }; - - match env::current_dir() { - Ok(current_dir) => { - env_state.current_working_directory = - Some(truncate_safe(¤t_dir.to_string_lossy(), MAX_CURRENT_WORKING_DIRECTORY_LEN).into()); - }, - Err(err) => { - error!(?err, "Attempted to fetch the CWD but it did not exist."); - }, - } - - env_state -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn test_env_state() { - let env_state = build_env_state(); - assert!(env_state.current_working_directory.is_some()); - assert!(env_state.operating_system.as_ref().is_some_and(|os| !os.is_empty())); - println!("{env_state:?}"); - } -} diff --git a/crates/chat-cli/src/cli/chat/mod.rs b/crates/chat-cli/src/cli/chat/mod.rs deleted file mode 100644 index 76609e3c9..000000000 --- a/crates/chat-cli/src/cli/chat/mod.rs +++ /dev/null @@ -1,2884 +0,0 @@ -mod cli; -mod consts; -mod context; -mod conversation; -mod error_formatter; -mod input_source; -mod message; -mod parse; -use std::path::MAIN_SEPARATOR; -mod parser; -mod prompt; -mod prompt_parser; -mod server_messenger; -#[cfg(unix)] -mod skim_integration; -mod token_counter; -pub mod tool_manager; -pub mod tools; -pub mod util; - -use std::borrow::Cow; -use std::collections::{ - HashMap, - HashSet, - VecDeque, -}; -use std::io::{ - IsTerminal, - Read, - Write, -}; -use std::process::ExitCode; -use std::time::Duration; - -use amzn_codewhisperer_client::types::SubscriptionStatus; -use clap::{ - Args, - CommandFactory, - Parser, -}; -use cli::compact::CompactStrategy; -use cli::model::select_model; -use context::ContextManager; -pub use conversation::ConversationState; -use conversation::TokenWarningLevel; -use crossterm::style::{ - Attribute, - Color, - Stylize, -}; -use crossterm::{ - cursor, - execute, - queue, - style, - terminal, -}; -use eyre::{ - Report, - Result, - bail, - eyre, -}; -use input_source::InputSource; -use message::{ - AssistantMessage, - AssistantToolUse, - ToolUseResult, - ToolUseResultBlock, -}; -use parse::{ - ParseState, - interpret_markdown, -}; -use parser::{ - RecvErrorKind, - ResponseParser, -}; -use regex::Regex; -use spinners::{ - Spinner, - Spinners, -}; -use thiserror::Error; -use time::OffsetDateTime; -use token_counter::TokenCounter; -use tokio::signal::ctrl_c; -use tool_manager::{ - McpServerConfig, - ToolManager, - ToolManagerBuilder, -}; -use tools::gh_issue::GhIssueContext; -use tools::{ - OutputKind, - QueuedTool, - Tool, - ToolPermissions, - ToolSpec, -}; -use tracing::{ - debug, - error, - info, - trace, - warn, -}; -use util::images::RichImageBlock; -use util::ui::draw_box; -use util::{ - animate_output, - play_notification_bell, -}; -use winnow::Partial; -use winnow::stream::Offset; - -use crate::api_client::ApiClientError; -use crate::api_client::model::{ - Tool as FigTool, - ToolResultStatus, -}; -use crate::api_client::send_message_output::SendMessageOutput; -use crate::auth::AuthError; -use crate::auth::builder_id::is_idc_user; -use crate::cli::chat::cli::SlashCommand; -use crate::cli::chat::cli::model::{ - MODEL_OPTIONS, - default_model_id, -}; -use crate::cli::chat::cli::prompts::{ - GetPromptError, - PromptsSubcommand, -}; -use crate::database::settings::Setting; -use crate::mcp_client::Prompt; -use crate::os::Os; -use crate::telemetry::core::ToolUseEventBuilder; -use crate::telemetry::{ - ReasonCode, - TelemetryResult, - get_error_reason, -}; - -const LIMIT_REACHED_TEXT: &str = color_print::cstr! { "You've used all your free requests for this month. You have two options: -1. Upgrade to a paid subscription for increased limits. See our Pricing page for what's included> https://aws.amazon.com/q/developer/pricing/ -2. Wait until next month when your limit automatically resets." }; - -pub const EXTRA_HELP: &str = color_print::cstr! {" -MCP: -You can now configure the Amazon Q CLI to use MCP servers. \nLearn how: https://docs.aws.amazon.com/en_us/amazonq/latest/qdeveloper-ug/command-line-mcp.html - -Tips: -!{command} Quickly execute a command in your current session -Ctrl(^) + j Insert new-line to provide multi-line prompt - Alternatively, [Alt(⌥) + Enter(⏎)] -Ctrl(^) + s Fuzzy search commands and context files - Use Tab to select multiple items - Change the keybind using: q settings chat.skimCommandKey x -chat.editMode The prompt editing mode (vim or emacs) - Change using: q settings chat.skimCommandKey x -"}; - -#[derive(Debug, Clone, PartialEq, Eq, Default, Args)] -pub struct ChatArgs { - /// Resumes the previous conversation from this directory. - #[arg(short, long)] - pub resume: bool, - /// Context profile to use - #[arg(long = "profile")] - pub profile: Option, - /// Current model to use - #[arg(long = "model")] - pub model: Option, - /// Allows the model to use any tool to run commands without asking for confirmation. - #[arg(short = 'a', long)] - pub trust_all_tools: bool, - /// Trust only this set of tools. Example: trust some tools: - /// '--trust-tools=fs_read,fs_write', trust no tools: '--trust-tools=' - #[arg(long, value_delimiter = ',', value_name = "TOOL_NAMES")] - pub trust_tools: Option>, - /// Whether the command should run without expecting user input - #[arg(long, alias = "non-interactive")] - pub no_interactive: bool, - /// The first question to ask - pub input: Option, -} - -impl ChatArgs { - pub async fn execute(self, os: &mut Os) -> Result { - let mut input = self.input; - - if self.no_interactive && input.is_none() { - if !std::io::stdin().is_terminal() { - let mut buffer = String::new(); - match std::io::stdin().read_to_string(&mut buffer) { - Ok(_) => { - if !buffer.trim().is_empty() { - input = Some(buffer.trim().to_string()); - } - }, - Err(e) => { - eprintln!("Error reading from stdin: {}", e); - }, - } - } - - if input.is_none() { - bail!("Input must be supplied when running in non-interactive mode"); - } - } - - let stdout = std::io::stdout(); - let mut stderr = std::io::stderr(); - - let mcp_server_configs = match McpServerConfig::load_config(&mut stderr).await { - Ok(config) => { - if !os.database.settings.get_bool(Setting::McpLoadedBefore).unwrap_or(false) { - execute!( - stderr, - style::Print( - "To learn more about MCP safety, see https://docs.aws.amazon.com/amazonq/latest/qdeveloper-ug/command-line-mcp-security.html\n\n" - ) - )?; - } - os.database.settings.set(Setting::McpLoadedBefore, true).await?; - config - }, - Err(e) => { - warn!("No mcp server config loaded: {}", e); - McpServerConfig::default() - }, - }; - - // If profile is specified, verify it exists before starting the chat - if let Some(ref profile_name) = self.profile { - // Create a temporary context manager to check if the profile exists - match ContextManager::new(os, None).await { - Ok(context_manager) => { - let profiles = context_manager.list_profiles(os).await?; - if !profiles.contains(profile_name) { - bail!( - "Profile '{}' does not exist. Available profiles: {}", - profile_name, - profiles.join(", ") - ); - } - }, - Err(e) => { - warn!("Failed to initialize context manager to verify profile: {}", e); - // Continue without verification if context manager can't be initialized - }, - } - } - - // If modelId is specified, verify it exists before starting the chat - let model_id: Option = if let Some(model_name) = self.model { - let model_name_lower = model_name.to_lowercase(); - match MODEL_OPTIONS.iter().find(|opt| opt.name == model_name_lower) { - Some(opt) => Some((opt.model_id).to_string()), - None => { - let available_names: Vec<&str> = MODEL_OPTIONS.iter().map(|opt| opt.name).collect(); - bail!( - "Model '{}' does not exist. Available models: {}", - model_name, - available_names.join(", ") - ); - }, - } - } else { - None - }; - - let conversation_id = uuid::Uuid::new_v4().to_string(); - info!(?conversation_id, "Generated new conversation id"); - let (prompt_request_sender, prompt_request_receiver) = std::sync::mpsc::channel::>(); - let (prompt_response_sender, prompt_response_receiver) = std::sync::mpsc::channel::>(); - let mut tool_manager = ToolManagerBuilder::default() - .mcp_server_config(mcp_server_configs) - .prompt_list_sender(prompt_response_sender) - .prompt_list_receiver(prompt_request_receiver) - .conversation_id(&conversation_id) - .build(os, Box::new(std::io::stderr()), !self.no_interactive) - .await?; - let tool_config = tool_manager.load_tools(os, &mut stderr).await?; - let mut tool_permissions = ToolPermissions::new(tool_config.len()); - - if self.trust_all_tools { - tool_permissions.trust_all = true; - for tool in tool_config.values() { - tool_permissions.trust_tool(&tool.name); - } - } else if let Some(trusted) = self.trust_tools.map(|vec| vec.into_iter().collect::>()) { - // --trust-all-tools takes precedence over --trust-tools=... - for tool_name in &trusted { - if !tool_name.is_empty() { - // Store the original trust settings for later use with MCP tools - tool_permissions.add_pending_trust_tool(tool_name.clone()); - } - } - - // Apply to currently known tools - for tool in tool_config.values() { - if trusted.contains(&tool.name) { - tool_permissions.trust_tool(&tool.name); - } else { - tool_permissions.untrust_tool(&tool.name); - } - } - } - - ChatSession::new( - os, - stdout, - stderr, - &conversation_id, - input, - InputSource::new(os, prompt_request_sender, prompt_response_receiver)?, - self.resume, - || terminal::window_size().map(|s| s.columns.into()).ok(), - tool_manager, - self.profile, - model_id, - tool_config, - tool_permissions, - !self.no_interactive, - ) - .await? - .spawn(os) - .await - .map(|_| ExitCode::SUCCESS) - } -} - -const WELCOME_TEXT: &str = color_print::cstr! {" - ⢠⣶⣶⣦⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⢀⣤⣶⣿⣿⣿⣶⣦⡀⠀ - ⠀⠀⠀⣾⡿⢻⣿⡆⠀⠀⠀⢀⣄⡄⢀⣠⣤⣤⡀⢀⣠⣤⣤⡀⠀⠀⢀⣠⣤⣤⣤⣄⠀⠀⢀⣤⣤⣤⣤⣤⣤⡀⠀⠀⣀⣤⣤⣤⣀⠀⠀⠀⢠⣤⡀⣀⣤⣤⣄⡀⠀⠀⠀⠀⠀⠀⢠⣿⣿⠋⠀⠀⠀⠙⣿⣿⡆ - ⠀⠀⣼⣿⠇⠀⣿⣿⡄⠀⠀⢸⣿⣿⠛⠉⠻⣿⣿⠛⠉⠛⣿⣿⠀⠀⠘⠛⠉⠉⠻⣿⣧⠀⠈⠛⠛⠛⣻⣿⡿⠀⢀⣾⣿⠛⠉⠻⣿⣷⡀⠀⢸⣿⡟⠛⠉⢻⣿⣷⠀⠀⠀⠀⠀⠀⣼⣿⡏⠀⠀⠀⠀⠀⢸⣿⣿ - ⠀⢰⣿⣿⣤⣤⣼⣿⣷⠀⠀⢸⣿⣿⠀⠀⠀⣿⣿⠀⠀⠀⣿⣿⠀⠀⢀⣴⣶⣶⣶⣿⣿⠀⠀⠀⣠⣾⡿⠋⠀⠀⢸⣿⣿⠀⠀⠀⣿⣿⡇⠀⢸⣿⡇⠀⠀⢸⣿⣿⠀⠀⠀⠀⠀⠀⢹⣿⣇⠀⠀⠀⠀⠀⢸⣿⡿ - ⢀⣿⣿⠋⠉⠉⠉⢻⣿⣇⠀⢸⣿⣿⠀⠀⠀⣿⣿⠀⠀⠀⣿⣿⠀⠀⣿⣿⡀⠀⣠⣿⣿⠀⢀⣴⣿⣋⣀⣀⣀⡀⠘⣿⣿⣄⣀⣠⣿⣿⠃⠀⢸⣿⡇⠀⠀⢸⣿⣿⠀⠀⠀⠀⠀⠀⠈⢿⣿⣦⣀⣀⣀⣴⣿⡿⠃ - ⠚⠛⠋⠀⠀⠀⠀⠘⠛⠛⠀⠘⠛⠛⠀⠀⠀⠛⠛⠀⠀⠀⠛⠛⠀⠀⠙⠻⠿⠟⠋⠛⠛⠀⠘⠛⠛⠛⠛⠛⠛⠃⠀⠈⠛⠿⠿⠿⠛⠁⠀⠀⠘⠛⠃⠀⠀⠘⠛⠛⠀⠀⠀⠀⠀⠀⠀⠀⠙⠛⠿⢿⣿⣿⣋⠀⠀ - ⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠈⠛⠿⢿⡧"}; - -const SMALL_SCREEN_WELCOME_TEXT: &str = color_print::cstr! {"Welcome to Amazon Q!"}; -const RESUME_TEXT: &str = color_print::cstr! {"Picking up where we left off..."}; - -// Only show the model-related tip for now to make users aware of this feature. -const ROTATING_TIPS: [&str; 16] = [ - color_print::cstr! {"You can resume the last conversation from your current directory by launching with - q chat --resume"}, - color_print::cstr! {"Get notified whenever Q CLI finishes responding. - Just run q settings chat.enableNotifications true"}, - color_print::cstr! {"You can use - /editor to edit your prompt with a vim-like experience"}, - color_print::cstr! {"/usage shows you a visual breakdown of your current context window usage"}, - color_print::cstr! {"Get notified whenever Q CLI finishes responding. Just run q settings - chat.enableNotifications true"}, - color_print::cstr! {"You can execute bash commands by typing - ! followed by the command"}, - color_print::cstr! {"Q can use tools without asking for - confirmation every time. Give /tools trust a try"}, - color_print::cstr! {"You can - programmatically inject context to your prompts by using hooks. Check out /context hooks - help"}, - color_print::cstr! {"You can use /compact to replace the conversation - history with its summary to free up the context space"}, - color_print::cstr! {"If you want to file an issue - to the Q CLI team, just tell me, or run q issue"}, - color_print::cstr! {"You can enable - custom tools with MCP servers. Learn more with /help"}, - color_print::cstr! {"You can - specify wait time (in ms) for mcp server loading with q settings mcp.initTimeout {timeout in - int}. Servers that takes longer than the specified time will continue to load in the background. Use - /tools to see pending servers."}, - color_print::cstr! {"You can see the server load status as well as any - warnings or errors associated with /mcp"}, - color_print::cstr! {"Use /model to select the model to use for this conversation"}, - color_print::cstr! {"Set a default model by running q settings chat.defaultModel MODEL. Run /model to learn more."}, - color_print::cstr! {"Run /prompts to learn how to build & run repeatable workflows"}, -]; - -const GREETING_BREAK_POINT: usize = 80; - -const POPULAR_SHORTCUTS: &str = color_print::cstr! {"/help all commands ctrl + j new lines ctrl + s fuzzy search"}; -const SMALL_SCREEN_POPULAR_SHORTCUTS: &str = color_print::cstr! {"/help all commands -ctrl + j new lines -ctrl + s fuzzy search -"}; - -const RESPONSE_TIMEOUT_CONTENT: &str = "Response timed out - message took too long to generate"; -const TRUST_ALL_TEXT: &str = color_print::cstr! {"All tools are now trusted (!). Amazon Q will execute tools without asking for confirmation.\ -\nAgents can sometimes do unexpected things so understand the risks. -\nLearn more at https://docs.aws.amazon.com/amazonq/latest/qdeveloper-ug/command-line-chat-security.html#command-line-chat-trustall-safety"}; - -const TOOL_BULLET: &str = " ● "; -const CONTINUATION_LINE: &str = " ⋮ "; -const PURPOSE_ARROW: &str = " ↳ "; - -/// Enum used to denote the origin of a tool use event -enum ToolUseStatus { - /// Variant denotes that the tool use event associated with chat context is a direct result of - /// a user request - Idle, - /// Variant denotes that the tool use event associated with the chat context is a result of a - /// retry for one or more previously attempted tool use. The tuple is the utterance id - /// associated with the original user request that necessitated the tool use - RetryInProgress(String), -} - -#[derive(Debug, Error)] -pub enum ChatError { - #[error("{0}")] - Client(Box), - #[error("{0}")] - Auth(#[from] AuthError), - #[error("{0}")] - ResponseStream(Box), - #[error("{0}")] - Std(#[from] std::io::Error), - #[error("{0}")] - Readline(#[from] rustyline::error::ReadlineError), - #[error("{0}")] - Custom(Cow<'static, str>), - #[error("interrupted")] - Interrupted { tool_uses: Option> }, - #[error(transparent)] - GetPromptError(#[from] GetPromptError), - #[error( - "Tool approval required but --no-interactive was specified. Use --trust-all-tools to automatically approve tools." - )] - NonInteractiveToolApproval, - #[error("The conversation history is too large to compact")] - CompactHistoryFailure, -} - -impl ChatError { - fn status_code(&self) -> Option { - match self { - ChatError::Client(e) => e.status_code(), - ChatError::Auth(_) => None, - ChatError::ResponseStream(_) => None, - ChatError::Std(_) => None, - ChatError::Readline(_) => None, - ChatError::Custom(_) => None, - ChatError::Interrupted { .. } => None, - ChatError::GetPromptError(_) => None, - ChatError::NonInteractiveToolApproval => None, - ChatError::CompactHistoryFailure => None, - } - } -} - -impl ReasonCode for ChatError { - fn reason_code(&self) -> String { - match self { - ChatError::Client(e) => e.reason_code(), - ChatError::ResponseStream(e) => e.reason_code(), - ChatError::Std(_) => "StdIoError".to_string(), - ChatError::Readline(_) => "ReadlineError".to_string(), - ChatError::Custom(_) => "GenericError".to_string(), - ChatError::Interrupted { .. } => "Interrupted".to_string(), - ChatError::GetPromptError(_) => "GetPromptError".to_string(), - ChatError::Auth(_) => "AuthError".to_string(), - ChatError::NonInteractiveToolApproval => "NonInteractiveToolApproval".to_string(), - ChatError::CompactHistoryFailure => "CompactHistoryFailure".to_string(), - } - } -} - -impl From for ChatError { - fn from(value: ApiClientError) -> Self { - Self::Client(Box::new(value)) - } -} - -impl From for ChatError { - fn from(value: parser::RecvError) -> Self { - Self::ResponseStream(Box::new(value)) - } -} - -pub struct ChatSession { - /// For output read by humans and machine - pub stdout: std::io::Stdout, - /// For display output, only read by humans - pub stderr: std::io::Stderr, - initial_input: Option, - /// Whether we're starting a new conversation or continuing an old one. - existing_conversation: bool, - input_source: InputSource, - /// Width of the terminal, required for [ParseState]. - terminal_width_provider: fn() -> Option, - spinner: Option, - /// [ConversationState]. - conversation: ConversationState, - tool_uses: Vec, - pending_tool_index: Option, - /// State to track tools that need confirmation. - tool_permissions: ToolPermissions, - /// Telemetry events to be sent as part of the conversation. - tool_use_telemetry_events: HashMap, - /// State used to keep track of tool use relation - tool_use_status: ToolUseStatus, - /// Any failed requests that could be useful for error report/debugging - failed_request_ids: Vec, - /// Pending prompts to be sent - pending_prompts: VecDeque, - interactive: bool, - inner: Option, -} - -impl ChatSession { - #[allow(clippy::too_many_arguments)] - pub async fn new( - os: &mut Os, - stdout: std::io::Stdout, - stderr: std::io::Stderr, - conversation_id: &str, - mut input: Option, - input_source: InputSource, - resume_conversation: bool, - terminal_width_provider: fn() -> Option, - tool_manager: ToolManager, - profile: Option, - model_id: Option, - tool_config: HashMap, - tool_permissions: ToolPermissions, - interactive: bool, - ) -> Result { - let valid_model_id = match model_id { - Some(id) => id, - None => { - let from_settings = os - .database - .settings - .get_string(Setting::ChatDefaultModel) - .and_then(|model_name| { - MODEL_OPTIONS - .iter() - .find(|opt| opt.name == model_name) - .map(|opt| opt.model_id.to_owned()) - }); - - match from_settings { - Some(id) => id, - None => default_model_id(os).await.to_owned(), - } - }, - }; - - // Reload prior conversation - let mut existing_conversation = false; - let previous_conversation = std::env::current_dir() - .ok() - .and_then(|cwd| os.database.get_conversation_by_path(cwd).ok()) - .flatten(); - - // Only restore conversations where there were actual messages. - // Prevents edge case where user clears conversation then exits without chatting. - let conversation = match resume_conversation - && previous_conversation - .as_ref() - .is_some_and(|cs| !cs.history().is_empty()) - { - true => { - let mut cs = previous_conversation.unwrap(); - existing_conversation = true; - cs.reload_serialized_state(os).await; - input = Some(input.unwrap_or("In a few words, summarize our conversation so far.".to_owned())); - cs.tool_manager = tool_manager; - cs.update_state(true).await; - cs.enforce_tool_use_history_invariants(); - cs - }, - false => { - ConversationState::new( - os, - conversation_id, - tool_config, - profile, - tool_manager, - Some(valid_model_id), - ) - .await - }, - }; - - Ok(Self { - stdout, - stderr, - initial_input: input, - existing_conversation, - input_source, - terminal_width_provider, - spinner: None, - tool_permissions, - conversation, - tool_uses: vec![], - pending_tool_index: None, - tool_use_telemetry_events: HashMap::new(), - tool_use_status: ToolUseStatus::Idle, - failed_request_ids: Vec::new(), - pending_prompts: VecDeque::new(), - interactive, - inner: Some(ChatState::default()), - }) - } - - pub async fn next(&mut self, os: &mut Os) -> Result<(), ChatError> { - // Update conversation state with new tool information - self.conversation.update_state(false).await; - - let ctrl_c_stream = ctrl_c(); - let result = match self.inner.take().expect("state must always be Some") { - ChatState::PromptUser { skip_printing_tools } => { - match (self.interactive, self.tool_uses.is_empty()) { - (false, true) => { - self.inner = Some(ChatState::Exit); - return Ok(()); - }, - (false, false) => { - return Err(ChatError::NonInteractiveToolApproval); - }, - _ => (), - }; - - self.prompt_user(os, skip_printing_tools).await - }, - ChatState::HandleInput { input } => { - tokio::select! { - res = self.handle_input(os, input) => res, - Ok(_) = ctrl_c_stream => Err(ChatError::Interrupted { tool_uses: Some(self.tool_uses.clone()) }) - } - }, - ChatState::CompactHistory { - prompt, - show_summary, - strategy, - } => { - tokio::select! { - res = self.compact_history(os, prompt, show_summary, strategy) => res, - Ok(_) = ctrl_c_stream => Err(ChatError::Interrupted { tool_uses: Some(self.tool_uses.clone()) }) - } - }, - ChatState::ExecuteTools => { - let tool_uses_clone = self.tool_uses.clone(); - tokio::select! { - res = self.tool_use_execute(os) => res, - Ok(_) = ctrl_c_stream => Err(ChatError::Interrupted { tool_uses: Some(tool_uses_clone) }) - } - }, - ChatState::ValidateTools(tool_uses) => { - tokio::select! { - res = self.validate_tools(os, tool_uses) => res, - Ok(_) = ctrl_c_stream => Err(ChatError::Interrupted { tool_uses: None }) - } - }, - ChatState::HandleResponseStream(response) => tokio::select! { - res = self.handle_response(os, response) => res, - Ok(_) = ctrl_c_stream => { - self.send_chat_telemetry(os, None, TelemetryResult::Cancelled, None, None, None).await; - Err(ChatError::Interrupted { tool_uses: None }) - } - }, - ChatState::RetryModelOverload => tokio::select! { - res = self.retry_model_overload(os) => res, - Ok(_) = ctrl_c_stream => { - Err(ChatError::Interrupted { tool_uses: None }) - } - }, - ChatState::Exit => return Ok(()), - }; - - let err = match result { - Ok(state) => { - self.inner = Some(state); - return Ok(()); - }, - Err(err) => err, - }; - - // We encountered an error. Handle it. - error!(?err, "An error occurred processing the current state"); - let (reason, reason_desc) = get_error_reason(&err); - self.send_error_telemetry(os, reason, Some(reason_desc), err.status_code()) - .await; - - if self.spinner.is_some() { - drop(self.spinner.take()); - queue!( - self.stderr, - terminal::Clear(terminal::ClearType::CurrentLine), - cursor::MoveToColumn(0), - )?; - } - - let (context, report, display_err_message) = match err { - ChatError::Interrupted { tool_uses: ref inter } => { - execute!(self.stderr, style::Print("\n\n"))?; - - // If there was an interrupt during tool execution, then we add fake - // messages to "reset" the chat state. - match inter { - Some(tool_uses) if !tool_uses.is_empty() => { - self.conversation - .abandon_tool_use(tool_uses, "The user interrupted the tool execution.".to_string()); - let _ = self - .conversation - .as_sendable_conversation_state(os, &mut self.stderr, false) - .await?; - self.conversation.push_assistant_message( - os, - AssistantMessage::new_response( - None, - "Tool uses were interrupted, waiting for the next user prompt".to_string(), - ), - ); - }, - _ => (), - } - - ("Tool use was interrupted", Report::from(err), false) - }, - ChatError::CompactHistoryFailure => { - // This error is not retryable - the user must take manual intervention to manage - // their context. - execute!( - self.stderr, - style::SetForegroundColor(Color::Red), - style::Print("Your conversation is too large to continue.\n"), - style::SetForegroundColor(Color::Reset), - style::Print(format!( - "• Run {} to compact your conversation. See {} for compaction options\n", - "/compact".green(), - "/compact --help".green() - )), - style::Print(format!("• Run {} to analyze your context usage\n", "/usage".green())), - style::Print(format!("• Run {} to reset your conversation state\n", "/clear".green())), - style::SetAttribute(Attribute::Reset), - style::Print("\n\n"), - )?; - ("Unable to compact the conversation history", eyre!(err), true) - }, - ChatError::Client(err) => match *err { - // Errors from attempting to send too large of a conversation history. In - // this case, attempt to automatically compact the history for the user. - ApiClientError::ContextWindowOverflow { .. } => { - if os - .database - .settings - .get_bool(Setting::ChatDisableAutoCompaction) - .unwrap_or(false) - { - execute!( - self.stderr, - style::SetForegroundColor(Color::Red), - style::Print("The conversation history has overflowed.\n"), - style::SetForegroundColor(Color::Reset), - style::Print(format!("• Run {} to compact your conversation\n", "/compact".green())), - style::SetAttribute(Attribute::Reset), - style::Print("\n\n"), - )?; - ("The conversation history has overflowed", eyre!(err), false) - } else { - self.inner = Some(ChatState::CompactHistory { - prompt: None, - show_summary: false, - strategy: CompactStrategy { - truncate_large_messages: self.conversation.history().len() <= 2, - ..Default::default() - }, - }); - - execute!( - self.stdout, - style::SetForegroundColor(Color::Yellow), - style::Print("The context window has overflowed, summarizing the history..."), - style::SetAttribute(Attribute::Reset), - style::Print("\n\n"), - )?; - - return Ok(()); - } - }, - ApiClientError::QuotaBreach { message, .. } => (message, Report::from(err), true), - ApiClientError::ModelOverloadedError { request_id, .. } => { - if self.interactive { - execute!( - self.stderr, - style::SetAttribute(Attribute::Bold), - style::SetForegroundColor(Color::Red), - style::Print( - "\nThe model you've selected is temporarily unavailable. Please select a different model.\n" - ), - style::SetAttribute(Attribute::Reset), - style::SetForegroundColor(Color::Reset), - )?; - - if let Some(id) = request_id { - self.conversation - .append_transcript(format!("Model unavailable (Request ID: {})", id)); - } - - self.inner = Some(ChatState::RetryModelOverload); - - return Ok(()); - } - - // non-interactive throws this error - let model_instruction = "Please relaunch with '--model ' to use a different model."; - let err = format!( - "The model you've selected is temporarily unavailable. {}{}\n\n", - model_instruction, - match request_id { - Some(id) => format!("\n Request ID: {}", id), - None => "".to_owned(), - } - ); - self.conversation.append_transcript(err.clone()); - execute!( - self.stderr, - style::SetAttribute(Attribute::Bold), - style::SetForegroundColor(Color::Red), - style::Print("Amazon Q is having trouble responding right now:\n"), - style::Print(format!(" {}\n", err.clone())), - style::SetAttribute(Attribute::Reset), - style::SetForegroundColor(Color::Reset), - )?; - ("Amazon Q is having trouble responding right now", eyre!(err), false) - }, - ApiClientError::MonthlyLimitReached { .. } => { - let subscription_status = get_subscription_status(os).await; - if subscription_status.is_err() { - execute!( - self.stderr, - style::SetForegroundColor(Color::Red), - style::Print(format!( - "Unable to verify subscription status: {}\n\n", - subscription_status.as_ref().err().unwrap() - )), - style::SetForegroundColor(Color::Reset), - )?; - } - - execute!( - self.stderr, - style::SetForegroundColor(Color::Yellow), - style::Print("Monthly request limit reached"), - style::SetForegroundColor(Color::Reset), - )?; - - let limits_text = format!( - "The limits reset on {:02}/01.", - OffsetDateTime::now_utc().month().next() as u8 - ); - - if subscription_status.is_err() - || subscription_status.is_ok_and(|s| s == ActualSubscriptionStatus::None) - { - execute!( - self.stderr, - style::Print(format!("\n\n{LIMIT_REACHED_TEXT} {limits_text}")), - style::SetForegroundColor(Color::DarkGrey), - style::Print("\n\nUse "), - style::SetForegroundColor(Color::Green), - style::Print("/subscribe"), - style::SetForegroundColor(Color::DarkGrey), - style::Print(" to upgrade your subscription.\n\n"), - style::SetForegroundColor(Color::Reset), - )?; - } else { - execute!( - self.stderr, - style::SetForegroundColor(Color::Yellow), - style::Print(format!(" - {limits_text}\n\n")), - style::SetForegroundColor(Color::Reset), - )?; - } - - self.inner = Some(ChatState::PromptUser { - skip_printing_tools: false, - }); - - return Ok(()); - }, - _ => ( - "Amazon Q is having trouble responding right now", - Report::from(err), - true, - ), - }, - _ => ( - "Amazon Q is having trouble responding right now", - Report::from(err), - true, - ), - }; - - if display_err_message { - // Remove non-ASCII and ANSI characters. - let re = Regex::new(r"((\x9B|\x1B\[)[0-?]*[ -\/]*[@-~])|([^\x00-\x7F]+)").unwrap(); - - queue!( - self.stderr, - style::SetAttribute(Attribute::Bold), - style::SetForegroundColor(Color::Red), - )?; - - let text = re.replace_all(&format!("{}: {:?}\n", context, report), "").into_owned(); - - queue!(self.stderr, style::Print(&text),)?; - self.conversation.append_transcript(text); - - execute!( - self.stderr, - style::SetAttribute(Attribute::Reset), - style::SetForegroundColor(Color::Reset), - )?; - } - - self.conversation.enforce_conversation_invariants(); - self.conversation.reset_next_user_message(); - self.pending_tool_index = None; - - self.inner = Some(ChatState::PromptUser { - skip_printing_tools: false, - }); - - Ok(()) - } -} - -impl Drop for ChatSession { - fn drop(&mut self) { - if let Some(spinner) = &mut self.spinner { - spinner.stop(); - } - - execute!( - self.stderr, - cursor::MoveToColumn(0), - style::SetAttribute(Attribute::Reset), - style::ResetColor, - cursor::Show - ) - .ok(); - } -} - -/// The chat execution state. -/// -/// Intended to provide more robust handling around state transitions while dealing with, e.g., -/// tool validation, execution, response stream handling, etc. -#[allow(clippy::large_enum_variant)] -#[derive(Debug)] -enum ChatState { - /// Prompt the user with `tool_uses`, if available. - PromptUser { - /// Used to avoid displaying the tool info at inappropriate times, e.g. after clear or help - /// commands. - skip_printing_tools: bool, - }, - /// Handle the user input, depending on if any tools require execution. - HandleInput { input: String }, - /// Validate the list of tool uses provided by the model. - ValidateTools(Vec), - /// Execute the list of tools. - ExecuteTools, - /// Consume the response stream and display to the user. - HandleResponseStream(SendMessageOutput), - /// Compact the chat history. - CompactHistory { - /// Custom prompt to include as part of history compaction. - prompt: Option, - /// Whether or not the summary should be shown on compact success. - show_summary: bool, - /// Parameters for how to perform the compaction request. - strategy: CompactStrategy, - }, - /// Retry the current request if we encounter a model overloaded error. - RetryModelOverload, - /// Exit the chat. - Exit, -} - -impl Default for ChatState { - fn default() -> Self { - Self::PromptUser { - skip_printing_tools: false, - } - } -} - -impl ChatSession { - async fn spawn(&mut self, os: &mut Os) -> Result<()> { - let is_small_screen = self.terminal_width() < GREETING_BREAK_POINT; - if os - .database - .settings - .get_bool(Setting::ChatGreetingEnabled) - .unwrap_or(true) - { - let welcome_text = match self.existing_conversation { - true => RESUME_TEXT, - false => match is_small_screen { - true => SMALL_SCREEN_WELCOME_TEXT, - false => WELCOME_TEXT, - }, - }; - - execute!(self.stderr, style::Print(welcome_text), style::Print("\n\n"),)?; - - let tip = ROTATING_TIPS[usize::try_from(rand::random::()).unwrap_or(0) % ROTATING_TIPS.len()]; - if is_small_screen { - // If the screen is small, print the tip in a single line - execute!( - self.stderr, - style::Print("💡 ".to_string()), - style::Print(tip), - style::Print("\n") - )?; - } else { - draw_box( - &mut self.stderr, - "Did you know?", - tip, - GREETING_BREAK_POINT, - Color::DarkGrey, - )?; - } - - execute!( - self.stderr, - style::Print("\n"), - style::Print(match is_small_screen { - true => SMALL_SCREEN_POPULAR_SHORTCUTS, - false => POPULAR_SHORTCUTS, - }), - style::Print("\n"), - style::Print( - "━" - .repeat(if is_small_screen { 0 } else { GREETING_BREAK_POINT }) - .dark_grey() - ) - )?; - execute!(self.stderr, style::Print("\n"), style::SetForegroundColor(Color::Reset))?; - } - - if self.all_tools_trusted() { - queue!( - self.stderr, - style::Print(format!( - "{}{TRUST_ALL_TEXT}\n\n", - if !is_small_screen { "\n" } else { "" } - )) - )?; - } - self.stderr.flush()?; - - if let Some(ref id) = self.conversation.model { - if let Some(model_option) = MODEL_OPTIONS.iter().find(|option| option.model_id == *id) { - execute!( - self.stderr, - style::SetForegroundColor(Color::Cyan), - style::Print(format!("🤖 You are chatting with {}\n", model_option.name)), - style::SetForegroundColor(Color::Reset), - style::Print("\n") - )?; - } - } - - if let Some(user_input) = self.initial_input.take() { - self.inner = Some(ChatState::HandleInput { input: user_input }); - } - - while !matches!(self.inner, Some(ChatState::Exit)) { - self.next(os).await?; - } - - Ok(()) - } - - /// Compacts the conversation history using the strategy specified by [CompactStrategy], - /// replacing the history with a summary generated by the model. - /// - /// If the compact request itself fails, it will be retried depending on [CompactStrategy] - /// - /// If [CompactStrategy::messages_to_exclude] is greater than 0, and - /// [CompactStrategy::truncate_large_messages] is true, then compaction will not be retried and - /// will fail with [ChatError::CompactHistoryFailure]. - async fn compact_history( - &mut self, - os: &Os, - custom_prompt: Option, - show_summary: bool, - strategy: CompactStrategy, - ) -> Result { - let hist = self.conversation.history(); - debug!(?strategy, ?hist, "compacting history"); - - if self.conversation.history().is_empty() { - execute!( - self.stderr, - style::SetForegroundColor(Color::Yellow), - style::Print("\nConversation too short to compact.\n\n"), - style::SetForegroundColor(Color::Reset) - )?; - - return Ok(ChatState::PromptUser { - skip_printing_tools: true, - }); - } - - if strategy.truncate_large_messages { - info!("truncating large messages"); - execute!( - self.stderr, - terminal::Clear(terminal::ClearType::CurrentLine), - cursor::MoveToColumn(0), - style::SetForegroundColor(Color::Yellow), - style::Print("Truncating large messages..."), - style::SetAttribute(Attribute::Reset), - style::Print("\n\n"), - )?; - } - - // Send a request for summarizing the history. - let summary_state = self - .conversation - .create_summary_request(os, custom_prompt.as_ref(), strategy) - .await?; - - execute!(self.stderr, cursor::Hide, style::Print("\n"))?; - - if self.interactive { - self.spinner = Some(Spinner::new(Spinners::Dots, "Creating summary...".to_string())); - } - - let response = os.client.send_message(summary_state).await; - - let response = match response { - Ok(res) => res, - Err(err) => { - if self.interactive { - self.spinner.take(); - execute!( - self.stderr, - terminal::Clear(terminal::ClearType::CurrentLine), - cursor::MoveToColumn(0), - style::SetAttribute(Attribute::Reset) - )?; - } - - let (reason, reason_desc) = get_error_reason(&err); - self.send_chat_telemetry( - os, - None, - TelemetryResult::Failed, - Some(reason), - Some(reason_desc), - err.status_code(), - ) - .await; - let history_len = self.conversation.history().len(); - match err { - ApiClientError::ContextWindowOverflow { .. } => { - error!(?strategy, "failed to send compaction request"); - // If there's only two messages in the history, we have no choice but to - // truncate it. We use two messages since it's almost guaranteed to contain: - // 1. A small user prompt - // 2. A large user tool use result - if history_len <= 2 && !strategy.truncate_large_messages { - return Ok(ChatState::CompactHistory { - prompt: custom_prompt, - show_summary, - strategy: CompactStrategy { - truncate_large_messages: true, - max_message_length: 25_000, - messages_to_exclude: 0, - }, - }); - } - - // Otherwise, we will first exclude the most recent message, and only then - // truncate. If both of these have already been set, then return an error. - if history_len > 2 && strategy.messages_to_exclude < 1 { - return Ok(ChatState::CompactHistory { - prompt: custom_prompt, - show_summary, - strategy: CompactStrategy { - messages_to_exclude: 1, - ..strategy - }, - }); - } else if !strategy.truncate_large_messages { - return Ok(ChatState::CompactHistory { - prompt: custom_prompt, - show_summary, - strategy: CompactStrategy { - truncate_large_messages: true, - max_message_length: 25_000, - ..strategy - }, - }); - } else { - return Err(ChatError::CompactHistoryFailure); - } - }, - err => return Err(err.into()), - } - }, - }; - - let request_id = response.request_id().map(|s| s.to_string()); - let summary = { - let mut parser = ResponseParser::new(response); - loop { - match parser.recv().await { - Ok(parser::ResponseEvent::EndStream { message }) => { - break message.content().to_string(); - }, - Ok(_) => (), - Err(err) => { - if let Some(request_id) = &err.request_id { - self.failed_request_ids.push(request_id.clone()); - }; - let (reason, reason_desc) = get_error_reason(&err); - self.send_chat_telemetry( - os, - err.request_id.clone(), - TelemetryResult::Failed, - Some(reason), - Some(reason_desc), - err.status_code(), - ) - .await; - return Err(err.into()); - }, - } - } - }; - - if self.spinner.is_some() { - drop(self.spinner.take()); - queue!( - self.stderr, - terminal::Clear(terminal::ClearType::CurrentLine), - cursor::MoveToColumn(0), - cursor::Show - )?; - } - - self.send_chat_telemetry(os, request_id, TelemetryResult::Succeeded, None, None, None) - .await; - - self.conversation - .replace_history_with_summary(summary.clone(), strategy); - - // Print output to the user. - { - execute!( - self.stderr, - style::SetForegroundColor(Color::Green), - style::Print("✔ Conversation history has been compacted successfully!\n\n"), - style::SetForegroundColor(Color::DarkGrey) - )?; - - let mut output = Vec::new(); - if let Some(custom_prompt) = &custom_prompt { - execute!( - output, - style::Print(format!("• Custom prompt applied: {}\n", custom_prompt)) - )?; - } - animate_output(&mut self.stderr, &output)?; - - // Display the summary if the show_summary flag is set - if show_summary { - // Add a border around the summary for better visual separation - let terminal_width = self.terminal_width(); - let border = "═".repeat(terminal_width.min(80)); - execute!( - self.stderr, - style::Print("\n"), - style::SetForegroundColor(Color::Cyan), - style::Print(&border), - style::Print("\n"), - style::SetAttribute(Attribute::Bold), - style::Print(" CONVERSATION SUMMARY"), - style::Print("\n"), - style::Print(&border), - style::SetAttribute(Attribute::Reset), - style::Print("\n\n"), - )?; - - execute!( - output, - style::Print(&summary), - style::Print("\n\n"), - style::SetForegroundColor(Color::Cyan), - style::Print("The conversation history has been replaced with this summary.\n"), - style::Print("It contains all important details from previous interactions.\n"), - )?; - animate_output(&mut self.stderr, &output)?; - - execute!( - self.stderr, - style::Print(&border), - style::Print("\n\n"), - style::SetForegroundColor(Color::Reset) - )?; - } - } - - // If a next message is set, then retry the request. - if self.conversation.next_user_message().is_some() { - Ok(ChatState::HandleResponseStream( - os.client - .send_message( - self.conversation - .as_sendable_conversation_state(os, &mut self.stderr, false) - .await?, - ) - .await?, - )) - } else { - // Otherwise, return back to the prompt for any pending tool uses. - Ok(ChatState::PromptUser { - skip_printing_tools: true, - }) - } - } - - /// Read input from the user. - async fn prompt_user(&mut self, os: &Os, skip_printing_tools: bool) -> Result { - execute!(self.stderr, cursor::Show)?; - - // Check token usage and display warnings if needed - if self.pending_tool_index.is_none() { - // Only display warnings when not waiting for tool approval - if let Err(err) = self.display_char_warnings(os).await { - warn!("Failed to display character limit warnings: {}", err); - } - } - - let show_tool_use_confirmation_dialog = !skip_printing_tools && self.pending_tool_index.is_some(); - if show_tool_use_confirmation_dialog { - execute!( - self.stderr, - style::SetForegroundColor(Color::DarkGrey), - style::Print("\nAllow this action? Use '"), - style::SetForegroundColor(Color::Green), - style::Print("t"), - style::SetForegroundColor(Color::DarkGrey), - style::Print("' to trust (always allow) this tool for the session. ["), - style::SetForegroundColor(Color::Green), - style::Print("y"), - style::SetForegroundColor(Color::DarkGrey), - style::Print("/"), - style::SetForegroundColor(Color::Green), - style::Print("n"), - style::SetForegroundColor(Color::DarkGrey), - style::Print("/"), - style::SetForegroundColor(Color::Green), - style::Print("t"), - style::SetForegroundColor(Color::DarkGrey), - style::Print("]:\n\n"), - style::SetForegroundColor(Color::Reset), - )?; - } - - // Do this here so that the skim integration sees an updated view of the context *during the current - // q session*. (e.g., if I add files to context, that won't show up for skim for the current - // q session unless we do this in prompt_user... unless you can find a better way) - #[cfg(unix)] - if let Some(ref context_manager) = self.conversation.context_manager { - use std::sync::Arc; - - use crate::cli::chat::consts::DUMMY_TOOL_NAME; - - let tool_names = self - .conversation - .tool_manager - .tn_map - .keys() - .filter(|name| *name != DUMMY_TOOL_NAME) - .cloned() - .collect::>(); - self.input_source - .put_skim_command_selector(os, Arc::new(context_manager.clone()), tool_names); - } - - execute!( - self.stderr, - style::SetForegroundColor(Color::Reset), - style::SetAttribute(Attribute::Reset) - )?; - let prompt = self.generate_tool_trust_prompt(); - let user_input = match self.read_user_input(&prompt, false) { - Some(input) => input, - None => return Ok(ChatState::Exit), - }; - - self.conversation.append_user_transcript(&user_input); - Ok(ChatState::HandleInput { input: user_input }) - } - - async fn handle_input(&mut self, os: &mut Os, mut user_input: String) -> Result { - queue!(self.stderr, style::Print('\n'))?; - - let input = user_input.trim(); - - // handle image path - if let Some(chat_state) = does_input_reference_file(input) { - return Ok(chat_state); - } - if let Some(mut args) = input.strip_prefix("/").and_then(shlex::split) { - // Required for printing errors correctly. - let orig_args = args.clone(); - - // We set the binary name as a dummy name "slash_command" which we - // replace anytime we error out and print a usage statement. - args.insert(0, "slash_command".to_owned()); - - match SlashCommand::try_parse_from(args) { - Ok(command) => { - match command.execute(os, self).await { - Ok(chat_state) - if matches!(chat_state, ChatState::Exit) - || matches!(chat_state, ChatState::HandleInput { input: _ }) - // TODO(bskiser): this is just a hotfix for handling state changes - // from manually running /compact, without impacting behavior of - // other slash commands. - || matches!(chat_state, ChatState::CompactHistory { .. }) => - { - return Ok(chat_state); - }, - Err(err) => { - queue!( - self.stderr, - style::SetForegroundColor(Color::Red), - style::Print(format!("\nFailed to execute command: {}\n", err)), - style::SetForegroundColor(Color::Reset) - )?; - }, - _ => {}, - } - - writeln!(self.stderr)?; - }, - Err(err) => { - // Replace the dummy name with a slash. Also have to check for an ansi sequence - // for invalid slash commands (e.g. on a "/doesntexist" input). - let ansi_output = err - .render() - .ansi() - .to_string() - .replace("slash_command ", "/") - .replace("slash_command\u{1b}[0m ", "/"); - - writeln!(self.stderr, "{}", ansi_output)?; - - // Print the subcommand help, if available. Required since by default we won't - // show what the actual arguments are, requiring an unnecessary --help call. - if let clap::error::ErrorKind::InvalidValue - | clap::error::ErrorKind::UnknownArgument - | clap::error::ErrorKind::InvalidSubcommand - | clap::error::ErrorKind::MissingRequiredArgument = err.kind() - { - let mut cmd = SlashCommand::command(); - for arg in &orig_args { - match cmd.find_subcommand(arg) { - Some(subcmd) => cmd = subcmd.clone(), - None => break, - } - } - let help = cmd.help_template("{all-args}").render_help(); - writeln!(self.stderr, "{}", help.ansi())?; - } - }, - } - - Ok(ChatState::PromptUser { - skip_printing_tools: false, - }) - } else if let Some(command) = input.strip_prefix("@") { - let input_parts = - shlex::split(command).ok_or(ChatError::Custom("Error splitting prompt command".into()))?; - - let mut iter = input_parts.into_iter(); - let prompt_name = iter - .next() - .ok_or(ChatError::Custom("Prompt name needs to be specified".into()))?; - - let args: Vec = iter.collect(); - let arguments = if args.is_empty() { None } else { Some(args) }; - - let subcommand = PromptsSubcommand::Get { - orig_input: Some(command.to_string()), - name: prompt_name, - arguments, - }; - return subcommand.execute(self).await; - } else if let Some(command) = input.strip_prefix("!") { - // Use platform-appropriate shell - let result = if cfg!(target_os = "windows") { - std::process::Command::new("cmd").args(["/C", command]).status() - } else { - std::process::Command::new("bash").args(["-c", command]).status() - }; - - // Handle the result and provide appropriate feedback - match result { - Ok(status) => { - if !status.success() { - queue!( - self.stderr, - style::SetForegroundColor(Color::Yellow), - style::Print(format!("Self exited with status: {}\n", status)), - style::SetForegroundColor(Color::Reset) - )?; - } - }, - Err(e) => { - queue!( - self.stderr, - style::SetForegroundColor(Color::Red), - style::Print(format!("\nFailed to execute command: {}\n", e)), - style::SetForegroundColor(Color::Reset) - )?; - }, - } - - Ok(ChatState::PromptUser { - skip_printing_tools: false, - }) - } else { - // Check for a pending tool approval - if let Some(index) = self.pending_tool_index { - let is_trust = ["t", "T"].contains(&input); - let tool_use = &mut self.tool_uses[index]; - if ["y", "Y"].contains(&input) || is_trust { - if is_trust { - self.tool_permissions.trust_tool(&tool_use.name); - } - tool_use.accepted = true; - - return Ok(ChatState::ExecuteTools); - } - } else if !self.pending_prompts.is_empty() { - let prompts = self.pending_prompts.drain(0..).collect(); - user_input = self - .conversation - .append_prompts(prompts) - .ok_or(ChatError::Custom("Prompt append failed".into()))?; - } - - // Otherwise continue with normal chat on 'n' or other responses - self.tool_use_status = ToolUseStatus::Idle; - - if self.pending_tool_index.is_some() { - // If the user just enters "n", replace the message we send to the model with - // something more substantial. - // TODO: Update this flow to something that does *not* require two requests just to - // get a meaningful response from the user - this is a short term solution before - // we decide on a better flow. - let user_input = if ["n", "N"].contains(&user_input.trim()) { - "I deny this tool request. Ask a follow up question clarifying the expected action".to_string() - } else { - user_input - }; - self.conversation.abandon_tool_use(&self.tool_uses, user_input); - } else { - self.conversation.set_next_user_message(user_input).await; - } - - let conv_state = self - .conversation - .as_sendable_conversation_state(os, &mut self.stderr, true) - .await?; - self.send_tool_use_telemetry(os).await; - - queue!(self.stderr, style::SetForegroundColor(Color::Magenta))?; - queue!(self.stderr, style::SetForegroundColor(Color::Reset))?; - queue!(self.stderr, cursor::Hide)?; - - if self.interactive { - self.spinner = Some(Spinner::new(Spinners::Dots, "Thinking...".to_owned())); - } - - Ok(ChatState::HandleResponseStream( - os.client.send_message(conv_state).await?, - )) - } - } - - async fn tool_use_execute(&mut self, os: &mut Os) -> Result { - // Verify tools have permissions. - for i in 0..self.tool_uses.len() { - let tool = &mut self.tool_uses[i]; - - // Manually accepted by the user or otherwise verified already. - if tool.accepted { - continue; - } - - // If there is an override, we will use it. Otherwise fall back to Tool's default. - let allowed = self.tool_permissions.trust_all - || self.tool_permissions.is_trusted(&tool.name) - || (!self.tool_permissions.has(&tool.name) && !tool.tool.requires_acceptance(os)); - - if os - .database - .settings - .get_bool(Setting::ChatEnableNotifications) - .unwrap_or(false) - { - play_notification_bell(!allowed); - } - - // TODO: Control flow is hacky here because of borrow rules - let _ = tool; - self.print_tool_description(os, i, allowed).await?; - let tool = &mut self.tool_uses[i]; - - if allowed { - tool.accepted = true; - continue; - } - - self.pending_tool_index = Some(i); - - return Ok(ChatState::PromptUser { - skip_printing_tools: false, - }); - } - - // Execute the requested tools. - let mut tool_results = vec![]; - let mut image_blocks: Vec = Vec::new(); - - for tool in &self.tool_uses { - let mut tool_telemetry = self.tool_use_telemetry_events.entry(tool.id.clone()); - tool_telemetry = tool_telemetry.and_modify(|ev| ev.is_accepted = true); - - // Extract AWS service name and operation name if available - if let Some(additional_info) = tool.tool.get_additional_info() { - if let Some(aws_service_name) = additional_info.get("aws_service_name").and_then(|v| v.as_str()) { - tool_telemetry = - tool_telemetry.and_modify(|ev| ev.aws_service_name = Some(aws_service_name.to_string())); - } - if let Some(aws_operation_name) = additional_info.get("aws_operation_name").and_then(|v| v.as_str()) { - tool_telemetry = - tool_telemetry.and_modify(|ev| ev.aws_operation_name = Some(aws_operation_name.to_string())); - } - } - - let tool_start = std::time::Instant::now(); - let invoke_result = tool.tool.invoke(os, &mut self.stdout).await; - - if self.spinner.is_some() { - queue!( - self.stderr, - terminal::Clear(terminal::ClearType::CurrentLine), - cursor::MoveToColumn(0), - cursor::Show - )?; - } - execute!(self.stdout, style::Print("\n"))?; - - let tool_time = std::time::Instant::now().duration_since(tool_start); - if let Tool::Custom(ct) = &tool.tool { - tool_telemetry = tool_telemetry.and_modify(|ev| { - ev.custom_tool_call_latency = Some(tool_time.as_secs() as usize); - ev.input_token_size = Some(ct.get_input_token_size()); - ev.is_custom_tool = true; - }); - } - let tool_time = format!("{}.{}", tool_time.as_secs(), tool_time.subsec_millis()); - match invoke_result { - Ok(result) => { - match result.output { - OutputKind::Text(ref text) => { - debug!("Output is Text: {}", text); - }, - OutputKind::Json(ref json) => { - debug!("Output is JSON: {}", json); - }, - OutputKind::Images(ref image) => { - image_blocks.extend(image.clone()); - }, - } - - debug!("tool result output: {:#?}", result); - execute!( - self.stdout, - style::Print(CONTINUATION_LINE), - style::Print("\n"), - style::SetForegroundColor(Color::Green), - style::SetAttribute(Attribute::Bold), - style::Print(format!(" ● Completed in {}s", tool_time)), - style::SetForegroundColor(Color::Reset), - style::Print("\n\n"), - )?; - - tool_telemetry = tool_telemetry.and_modify(|ev| ev.is_success = Some(true)); - if let Tool::Custom(_) = &tool.tool { - tool_telemetry - .and_modify(|ev| ev.output_token_size = Some(TokenCounter::count_tokens(result.as_str()))); - } - tool_results.push(ToolUseResult { - tool_use_id: tool.id.clone(), - content: vec![result.into()], - status: ToolResultStatus::Success, - }); - }, - Err(err) => { - error!(?err, "An error occurred processing the tool"); - execute!( - self.stderr, - style::Print(CONTINUATION_LINE), - style::Print("\n"), - style::SetAttribute(Attribute::Bold), - style::SetForegroundColor(Color::Red), - style::Print(format!(" ● Execution failed after {}s:\n", tool_time)), - style::SetAttribute(Attribute::Reset), - style::SetForegroundColor(Color::Red), - style::Print(&err), - style::SetAttribute(Attribute::Reset), - style::Print("\n\n"), - )?; - - tool_telemetry.and_modify(|ev| ev.is_success = Some(false)); - tool_results.push(ToolUseResult { - tool_use_id: tool.id.clone(), - content: vec![ToolUseResultBlock::Text(format!( - "An error occurred processing the tool: \n{}", - &err - ))], - status: ToolResultStatus::Error, - }); - if let ToolUseStatus::Idle = self.tool_use_status { - self.tool_use_status = ToolUseStatus::RetryInProgress( - self.conversation - .message_id() - .map_or("No utterance id found".to_string(), |v| v.to_string()), - ); - } - }, - } - } - - if !image_blocks.is_empty() { - let images = image_blocks.into_iter().map(|(block, _)| block).collect(); - self.conversation.add_tool_results_with_images(tool_results, images); - execute!( - self.stderr, - style::SetAttribute(Attribute::Reset), - style::SetForegroundColor(Color::Reset), - style::Print("\n") - )?; - } else { - self.conversation.add_tool_results(tool_results); - } - - execute!(self.stderr, cursor::Hide)?; - execute!(self.stderr, style::Print("\n"), style::SetAttribute(Attribute::Reset))?; - if self.interactive { - self.spinner = Some(Spinner::new(Spinners::Dots, "Thinking...".to_string())); - } - - self.send_tool_use_telemetry(os).await; - return Ok(ChatState::HandleResponseStream( - os.client - .send_message( - self.conversation - .as_sendable_conversation_state(os, &mut self.stderr, false) - .await?, - ) - .await?, - )); - } - - async fn handle_response(&mut self, os: &mut Os, response: SendMessageOutput) -> Result { - let request_id = response.request_id().map(|s| s.to_string()); - let mut buf = String::new(); - let mut offset = 0; - let mut ended = false; - let mut parser = ResponseParser::new(response); - let mut state = ParseState::new(Some(self.terminal_width())); - let mut response_prefix_printed = false; - - let mut tool_uses = Vec::new(); - let mut tool_name_being_recvd: Option = None; - - if self.spinner.is_some() { - drop(self.spinner.take()); - queue!( - self.stderr, - style::SetForegroundColor(Color::Reset), - cursor::MoveToColumn(0), - cursor::Show, - terminal::Clear(terminal::ClearType::CurrentLine), - )?; - } - - loop { - match parser.recv().await { - Ok(msg_event) => { - trace!("Consumed: {:?}", msg_event); - match msg_event { - parser::ResponseEvent::ToolUseStart { name } => { - // We need to flush the buffer here, otherwise text will not be - // printed while we are receiving tool use events. - buf.push('\n'); - tool_name_being_recvd = Some(name); - }, - parser::ResponseEvent::AssistantText(text) => { - // Add Q response prefix before the first assistant text. - // This must be markdown - using a code tick, which is printed - // as green. - if !response_prefix_printed && !text.trim().is_empty() { - buf.push_str("`>` "); - response_prefix_printed = true; - } - buf.push_str(&text); - }, - parser::ResponseEvent::ToolUse(tool_use) => { - if self.spinner.is_some() { - drop(self.spinner.take()); - queue!( - self.stderr, - terminal::Clear(terminal::ClearType::CurrentLine), - cursor::MoveToColumn(0), - cursor::Show - )?; - } - tool_uses.push(tool_use); - tool_name_being_recvd = None; - }, - parser::ResponseEvent::EndStream { message } => { - // This log is attempting to help debug instances where users encounter - // the response timeout message. - if message.content() == RESPONSE_TIMEOUT_CONTENT { - error!(?request_id, ?message, "Encountered an unexpected model response"); - } - self.conversation.push_assistant_message(os, message); - ended = true; - }, - } - }, - Err(recv_error) => { - if let Some(request_id) = &recv_error.request_id { - self.failed_request_ids.push(request_id.clone()); - }; - - let (reason, reason_desc) = get_error_reason(&recv_error); - self.send_chat_telemetry( - os, - recv_error.request_id.clone(), - TelemetryResult::Failed, - Some(reason), - Some(reason_desc), - recv_error.status_code(), - ) - .await; - - match recv_error.source { - RecvErrorKind::StreamTimeout { source, duration } => { - error!( - recv_error.request_id, - ?source, - "Encountered a stream timeout after waiting for {}s", - duration.as_secs() - ); - - execute!(self.stderr, cursor::Hide)?; - self.spinner = Some(Spinner::new(Spinners::Dots, "Dividing up the work...".to_string())); - - // For stream timeouts, we'll tell the model to try and split its response into - // smaller chunks. - self.conversation.push_assistant_message( - os, - AssistantMessage::new_response(None, RESPONSE_TIMEOUT_CONTENT.to_string()), - ); - self.conversation - .set_next_user_message( - "You took too long to respond - try to split up the work into smaller steps." - .to_string(), - ) - .await; - self.send_tool_use_telemetry(os).await; - return Ok(ChatState::HandleResponseStream( - os.client - .send_message( - self.conversation - .as_sendable_conversation_state(os, &mut self.stderr, false) - .await?, - ) - .await?, - )); - }, - RecvErrorKind::UnexpectedToolUseEos { - tool_use_id, - name, - message, - .. - } => { - error!( - recv_error.request_id, - tool_use_id, name, "The response stream ended before the entire tool use was received" - ); - self.conversation.push_assistant_message(os, *message); - let tool_results = vec![ToolUseResult { - tool_use_id, - content: vec![ToolUseResultBlock::Text( - "The generated tool was too large, try again but this time split up the work between multiple tool uses".to_string(), - )], - status: ToolResultStatus::Error, - }]; - self.conversation.add_tool_results(tool_results); - self.send_tool_use_telemetry(os).await; - return Ok(ChatState::HandleResponseStream( - os.client - .send_message( - self.conversation - .as_sendable_conversation_state(os, &mut self.stderr, false) - .await?, - ) - .await?, - )); - }, - _ => return Err(recv_error.into()), - } - }, - } - - // Fix for the markdown parser copied over from q chat: - // this is a hack since otherwise the parser might report Incomplete with useful data - // still left in the buffer. I'm not sure how this is intended to be handled. - if ended { - buf.push('\n'); - } - - if tool_name_being_recvd.is_none() && !buf.is_empty() && self.spinner.is_some() { - drop(self.spinner.take()); - queue!( - self.stderr, - terminal::Clear(terminal::ClearType::CurrentLine), - cursor::MoveToColumn(0), - cursor::Show - )?; - } - - // Print the response for normal cases - loop { - let input = Partial::new(&buf[offset..]); - match interpret_markdown(input, &mut self.stdout, &mut state) { - Ok(parsed) => { - offset += parsed.offset_from(&input); - self.stdout.flush()?; - state.newline = state.set_newline; - state.set_newline = false; - }, - Err(err) => match err.into_inner() { - Some(err) => return Err(ChatError::Custom(err.to_string().into())), - None => break, // Data was incomplete - }, - } - - // TODO: We should buffer output based on how much we have to parse, not as a constant - // Do not remove unless you are nabochay :) - tokio::time::sleep(Duration::from_millis(8)).await; - } - - // Set spinner after showing all of the assistant text content so far. - if tool_name_being_recvd.is_some() { - queue!(self.stderr, cursor::Hide)?; - if self.interactive { - self.spinner = Some(Spinner::new(Spinners::Dots, "Thinking...".to_string())); - } - } - - if ended { - self.send_chat_telemetry(os, request_id, TelemetryResult::Succeeded, None, None, None) - .await; - - if os - .database - .settings - .get_bool(Setting::ChatEnableNotifications) - .unwrap_or(false) - { - // For final responses (no tools suggested), always play the bell - play_notification_bell(tool_uses.is_empty()); - } - - queue!(self.stderr, style::ResetColor, style::SetAttribute(Attribute::Reset))?; - execute!(self.stdout, style::Print("\n"))?; - - for (i, citation) in &state.citations { - queue!( - self.stdout, - style::Print("\n"), - style::SetForegroundColor(Color::Blue), - style::Print(format!("[^{i}]: ")), - style::SetForegroundColor(Color::DarkGrey), - style::Print(format!("{citation}\n")), - style::SetForegroundColor(Color::Reset) - )?; - } - - break; - } - } - - if !tool_uses.is_empty() { - Ok(ChatState::ValidateTools(tool_uses)) - } else { - self.tool_uses.clear(); - self.pending_tool_index = None; - - Ok(ChatState::PromptUser { - skip_printing_tools: false, - }) - } - } - - async fn validate_tools(&mut self, os: &Os, tool_uses: Vec) -> Result { - let conv_id = self.conversation.conversation_id().to_owned(); - debug!(?tool_uses, "Validating tool uses"); - let mut queued_tools: Vec = Vec::new(); - let mut tool_results: Vec = Vec::new(); - - for tool_use in tool_uses { - let tool_use_id = tool_use.id.clone(); - let tool_use_name = tool_use.name.clone(); - let mut tool_telemetry = - ToolUseEventBuilder::new(conv_id.clone(), tool_use.id.clone(), self.conversation.model.clone()) - .set_tool_use_id(tool_use_id.clone()) - .set_tool_name(tool_use.name.clone()) - .utterance_id(self.conversation.message_id().map(|s| s.to_string())); - match self.conversation.tool_manager.get_tool_from_tool_use(tool_use) { - Ok(mut tool) => { - // Apply non-Q-generated context to tools - self.contextualize_tool(&mut tool); - - match tool.validate(os).await { - Ok(()) => { - tool_telemetry.is_valid = Some(true); - queued_tools.push(QueuedTool { - id: tool_use_id.clone(), - name: tool_use_name, - tool, - accepted: false, - }); - }, - Err(err) => { - tool_telemetry.is_valid = Some(false); - tool_results.push(ToolUseResult { - tool_use_id: tool_use_id.clone(), - content: vec![ToolUseResultBlock::Text(format!( - "Failed to validate tool parameters: {err}" - ))], - status: ToolResultStatus::Error, - }); - }, - }; - }, - Err(err) => { - tool_telemetry.is_valid = Some(false); - tool_results.push(err.into()); - }, - } - self.tool_use_telemetry_events.insert(tool_use_id, tool_telemetry); - } - - // If we have any validation errors, then return them immediately to the model. - if !tool_results.is_empty() { - debug!(?tool_results, "Error found in the model tools"); - queue!( - self.stderr, - style::SetAttribute(Attribute::Bold), - style::Print("Tool validation failed: "), - style::SetAttribute(Attribute::Reset), - )?; - for tool_result in &tool_results { - for block in &tool_result.content { - let content: Option> = match block { - ToolUseResultBlock::Text(t) => Some(t.as_str().into()), - ToolUseResultBlock::Json(d) => serde_json::to_string(d) - .map_err(|err| error!(?err, "failed to serialize tool result content")) - .map(Into::into) - .ok(), - }; - if let Some(content) = content { - queue!( - self.stderr, - style::Print("\n"), - style::SetForegroundColor(Color::Red), - style::Print(format!("{}\n", content)), - style::SetForegroundColor(Color::Reset), - )?; - } - } - } - self.conversation.add_tool_results(tool_results); - self.send_tool_use_telemetry(os).await; - if let ToolUseStatus::Idle = self.tool_use_status { - self.tool_use_status = ToolUseStatus::RetryInProgress( - self.conversation - .message_id() - .map_or("No utterance id found".to_string(), |v| v.to_string()), - ); - } - - let response = os - .client - .send_message( - self.conversation - .as_sendable_conversation_state(os, &mut self.stderr, false) - .await?, - ) - .await?; - return Ok(ChatState::HandleResponseStream(response)); - } - - self.tool_uses = queued_tools; - self.pending_tool_index = Some(0); - Ok(ChatState::ExecuteTools) - } - - async fn retry_model_overload(&mut self, os: &mut Os) -> Result { - match select_model(self) { - Ok(Some(_)) => (), - Ok(None) => { - // User did not select a model, so reset the current request state. - self.conversation.enforce_conversation_invariants(); - self.conversation.reset_next_user_message(); - self.pending_tool_index = None; - return Ok(ChatState::PromptUser { - skip_printing_tools: false, - }); - }, - Err(err) => return Err(err), - } - - let conv_state = self - .conversation - .as_sendable_conversation_state(os, &mut self.stderr, true) - .await?; - - if self.interactive { - self.spinner = Some(Spinner::new(Spinners::Dots, "Thinking...".to_owned())); - } - - Ok(ChatState::HandleResponseStream( - os.client.send_message(conv_state).await?, - )) - } - - /// Apply program context to tools that Q may not have. - // We cannot attach this any other way because Tools are constructed by deserializing - // output from Amazon Q. - // TODO: Is there a better way? - fn contextualize_tool(&self, tool: &mut Tool) { - if let Tool::GhIssue(gh_issue) = tool { - gh_issue.set_context(GhIssueContext { - // Ideally we avoid cloning, but this function is not called very often. - // Using references with lifetimes requires a large refactor, and Arc> - // seems like overkill and may incur some performance cost anyway. - context_manager: self.conversation.context_manager.clone(), - transcript: self.conversation.transcript.clone(), - failed_request_ids: self.failed_request_ids.clone(), - tool_permissions: self.tool_permissions.permissions.clone(), - }); - } - } - - async fn print_tool_description(&mut self, os: &Os, tool_index: usize, trusted: bool) -> Result<(), ChatError> { - let tool_use = &self.tool_uses[tool_index]; - - queue!( - self.stdout, - style::SetForegroundColor(Color::Magenta), - style::Print(format!( - "🛠️ Using tool: {}{}", - tool_use.tool.display_name(), - if trusted { " (trusted)".dark_green() } else { "".reset() } - )), - style::SetForegroundColor(Color::Reset) - )?; - if let Tool::Custom(ref tool) = tool_use.tool { - queue!( - self.stdout, - style::SetForegroundColor(Color::Reset), - style::Print(" from mcp server "), - style::SetForegroundColor(Color::Magenta), - style::Print(tool.client.get_server_name()), - style::SetForegroundColor(Color::Reset), - )?; - } - - execute!( - self.stdout, - style::Print("\n"), - style::Print(CONTINUATION_LINE), - style::Print("\n"), - style::Print(TOOL_BULLET) - )?; - - tool_use - .tool - .queue_description(os, &mut self.stdout) - .await - .map_err(|e| ChatError::Custom(format!("failed to print tool, `{}`: {}", tool_use.name, e).into()))?; - - Ok(()) - } - - /// Helper function to read user input with a prompt and Ctrl+C handling - fn read_user_input(&mut self, prompt: &str, exit_on_single_ctrl_c: bool) -> Option { - let mut ctrl_c = false; - loop { - match (self.input_source.read_line(Some(prompt)), ctrl_c) { - (Ok(Some(line)), _) => { - if line.trim().is_empty() { - continue; // Reprompt if the input is empty - } - return Some(line); - }, - (Ok(None), false) => { - if exit_on_single_ctrl_c { - return None; - } - execute!( - self.stderr, - style::Print(format!( - "\n(To exit the CLI, press Ctrl+C or Ctrl+D again or type {})\n\n", - "/quit".green() - )) - ) - .unwrap_or_default(); - ctrl_c = true; - }, - (Ok(None), true) => return None, // Exit if Ctrl+C was pressed twice - (Err(_), _) => return None, - } - } - } - - /// Helper function to generate a prompt based on the current context - fn generate_tool_trust_prompt(&mut self) -> String { - let profile = self.conversation.current_profile().map(|s| s.to_string()); - let all_trusted = self.all_tools_trusted(); - prompt::generate_prompt(profile.as_deref(), all_trusted) - } - - async fn send_tool_use_telemetry(&mut self, os: &Os) { - for (_, mut event) in self.tool_use_telemetry_events.drain() { - event.user_input_id = match self.tool_use_status { - ToolUseStatus::Idle => self.conversation.message_id(), - ToolUseStatus::RetryInProgress(ref id) => Some(id.as_str()), - } - .map(|v| v.to_string()); - - os.telemetry.send_tool_use_suggested(&os.database, event).await.ok(); - } - } - - fn terminal_width(&self) -> usize { - (self.terminal_width_provider)().unwrap_or(80) - } - - fn all_tools_trusted(&mut self) -> bool { - self.conversation.tools.values().flatten().all(|t| match t { - FigTool::ToolSpecification(t) => self.tool_permissions.is_trusted(&t.name), - }) - } - - /// Display character limit warnings based on current conversation size - async fn display_char_warnings(&mut self, os: &Os) -> Result<(), ChatError> { - let warning_level = self.conversation.get_token_warning_level(os).await?; - - match warning_level { - TokenWarningLevel::Critical => { - // Memory constraint warning with gentler wording - execute!( - self.stderr, - style::SetForegroundColor(Color::Yellow), - style::SetAttribute(Attribute::Bold), - style::Print("\n⚠️ This conversation is getting lengthy.\n"), - style::SetAttribute(Attribute::Reset), - style::Print( - "To ensure continued smooth operation, please use /compact to summarize the conversation.\n\n" - ), - style::SetForegroundColor(Color::Reset) - )?; - }, - TokenWarningLevel::None => { - // No warning needed - }, - } - - Ok(()) - } - - #[allow(clippy::too_many_arguments)] - async fn send_chat_telemetry( - &self, - os: &Os, - request_id: Option, - result: TelemetryResult, - reason: Option, - reason_desc: Option, - status_code: Option, - ) { - os.telemetry - .send_chat_added_message( - &os.database, - self.conversation.conversation_id().to_owned(), - self.conversation.message_id().map(|s| s.to_owned()), - request_id, - self.conversation.context_message_length(), - result, - reason, - reason_desc, - status_code, - self.conversation.model.clone(), - ) - .await - .ok(); - } - - async fn send_error_telemetry( - &self, - os: &Os, - reason: String, - reason_desc: Option, - status_code: Option, - ) { - os.telemetry - .send_response_error( - &os.database, - self.conversation.conversation_id().to_owned(), - self.conversation.context_message_length(), - TelemetryResult::Failed, - Some(reason), - reason_desc, - status_code, - ) - .await - .ok(); - } -} - -/// Replaces amzn_codewhisperer_client::types::SubscriptionStatus with a more descriptive type. -/// See response expectations in [`get_subscription_status`] for reasoning. -#[derive(Debug, Clone, PartialEq, Eq)] -enum ActualSubscriptionStatus { - Active, // User has paid for this month - Expiring, // User has paid for this month but cancelled - None, // User has not paid for this month -} - -// NOTE: The subscription API behaves in a non-intuitive way. We expect the following responses: -// -// 1. SubscriptionStatus::Active: -// - The user *has* a subscription, but it is set to *not auto-renew* (i.e., cancelled). -// - We return ActualSubscriptionStatus::Expiring to indicate they are eligible to re-subscribe -// -// 2. SubscriptionStatus::Inactive: -// - The user has no subscription at all (no Pro access). -// - We return ActualSubscriptionStatus::None to indicate they are eligible to subscribe. -// -// 3. ConflictException (as an error): -// - The user already has an active subscription *with auto-renewal enabled*. -// - We return ActualSubscriptionStatus::Active since they don’t need to subscribe again. -// -// Also, it is currently not possible to subscribe or re-subscribe via console, only IDE/CLI. -async fn get_subscription_status(os: &mut Os) -> Result { - if is_idc_user(&os.database).await? { - return Ok(ActualSubscriptionStatus::Active); - } - - match os.client.create_subscription_token().await { - Ok(response) => match response.status() { - SubscriptionStatus::Active => Ok(ActualSubscriptionStatus::Expiring), - SubscriptionStatus::Inactive => Ok(ActualSubscriptionStatus::None), - _ => Ok(ActualSubscriptionStatus::None), - }, - Err(ApiClientError::CreateSubscriptionToken(e)) => { - let sdk_error_code = e.as_service_error().and_then(|err| err.meta().code()); - - if sdk_error_code.is_some_and(|c| c.contains("ConflictException")) { - Ok(ActualSubscriptionStatus::Active) - } else { - Err(e.into()) - } - }, - Err(e) => Err(e.into()), - } -} - -async fn get_subscription_status_with_spinner( - os: &mut Os, - output: &mut impl Write, -) -> Result { - return with_spinner(output, "Checking subscription status...", || async { - get_subscription_status(os).await - }) - .await; -} - -async fn with_spinner(output: &mut impl std::io::Write, spinner_text: &str, f: F) -> Result -where - F: FnOnce() -> Fut, - Fut: std::future::Future>, -{ - queue!(output, cursor::Hide,).ok(); - let spinner = Some(Spinner::new(Spinners::Dots, spinner_text.to_owned())); - - let result = f().await; - - if let Some(mut s) = spinner { - s.stop(); - let _ = queue!( - output, - terminal::Clear(terminal::ClearType::CurrentLine), - cursor::MoveToColumn(0), - ); - } - - result -} - -/// Checks if an input may be referencing a file and should not be handled as a typical slash -/// command. If true, then return [Option::Some], otherwise [Option::None]. -fn does_input_reference_file(input: &str) -> Option { - let after_slash = input.strip_prefix("/")?; - - if let Some(first) = shlex::split(after_slash).unwrap_or_default().first() { - let looks_like_path = - first.contains(MAIN_SEPARATOR) || first.contains('/') || first.contains('\\') || first.contains('.'); - - if looks_like_path { - return Some(ChatState::HandleInput { - input: after_slash.to_string(), - }); - } - } - - None -} - -#[cfg(test)] -mod tests { - use super::*; - - #[tokio::test] - async fn test_flow() { - let mut os = Os::new().await.unwrap(); - os.client.set_mock_output(serde_json::json!([ - [ - "Sure, I'll create a file for you", - { - "tool_use_id": "1", - "name": "fs_write", - "args": { - "command": "create", - "file_text": "Hello, world!", - "path": "/file.txt", - } - } - ], - [ - "Hope that looks good to you!", - ], - ])); - - let tool_manager = ToolManager::default(); - let tool_config = serde_json::from_str::>(include_str!("tools/tool_index.json")) - .expect("Tools failed to load"); - ChatSession::new( - &mut os, - std::io::stdout(), - std::io::stderr(), - "fake_conv_id", - None, - InputSource::new_mock(vec![ - "create a new file".to_string(), - "y".to_string(), - "exit".to_string(), - ]), - false, - || Some(80), - tool_manager, - None, - None, - tool_config, - ToolPermissions::new(0), - true, - ) - .await - .unwrap() - .spawn(&mut os) - .await - .unwrap(); - - assert_eq!(os.fs.read_to_string("/file.txt").await.unwrap(), "Hello, world!\n"); - } - - #[tokio::test] - async fn test_flow_tool_permissions() { - let mut os = Os::new().await.unwrap(); - os.client.set_mock_output(serde_json::json!([ - [ - "Ok", - { - "tool_use_id": "1", - "name": "fs_write", - "args": { - "command": "create", - "file_text": "Hello, world!", - "path": "/file1.txt", - } - } - ], - [ - "Done", - ], - [ - "Ok", - { - "tool_use_id": "1", - "name": "fs_write", - "args": { - "command": "create", - "file_text": "Hello, world!", - "path": "/file2.txt", - } - } - ], - [ - "Done", - ], - [ - "Ok", - { - "tool_use_id": "1", - "name": "fs_write", - "args": { - "command": "create", - "file_text": "Hello, world!", - "path": "/file3.txt", - } - } - ], - [ - "Done", - ], - [ - "Ok", - { - "tool_use_id": "1", - "name": "fs_write", - "args": { - "command": "create", - "file_text": "Hello, world!", - "path": "/file4.txt", - } - } - ], - [ - "Ok, I won't make it.", - ], - [ - "Ok", - { - "tool_use_id": "1", - "name": "fs_write", - "args": { - "command": "create", - "file_text": "Hello, world!", - "path": "/file5.txt", - } - } - ], - [ - "Done", - ], - [ - "Ok", - { - "tool_use_id": "1", - "name": "fs_write", - "args": { - "command": "create", - "file_text": "Hello, world!", - "path": "/file6.txt", - } - } - ], - [ - "Ok, I won't make it.", - ], - ])); - - let tool_manager = ToolManager::default(); - let tool_config = serde_json::from_str::>(include_str!("tools/tool_index.json")) - .expect("Tools failed to load"); - ChatSession::new( - &mut os, - std::io::stdout(), - std::io::stderr(), - "fake_conv_id", - None, - InputSource::new_mock(vec![ - "/tools".to_string(), - "/tools help".to_string(), - "create a new file".to_string(), - "y".to_string(), - "create a new file".to_string(), - "t".to_string(), - "create a new file".to_string(), // should make without prompting due to 't' - "/tools untrust fs_write".to_string(), - "create a file".to_string(), // prompt again due to untrust - "n".to_string(), // cancel - "/tools trust fs_write".to_string(), - "create a file".to_string(), // again without prompting due to '/tools trust' - "/tools reset".to_string(), - "create a file".to_string(), // prompt again due to reset - "n".to_string(), // cancel - "exit".to_string(), - ]), - false, - || Some(80), - tool_manager, - None, - None, - tool_config, - ToolPermissions::new(0), - true, - ) - .await - .unwrap() - .spawn(&mut os) - .await - .unwrap(); - - assert_eq!(os.fs.read_to_string("/file2.txt").await.unwrap(), "Hello, world!\n"); - assert_eq!(os.fs.read_to_string("/file3.txt").await.unwrap(), "Hello, world!\n"); - assert!(!os.fs.exists("/file4.txt")); - assert_eq!(os.fs.read_to_string("/file5.txt").await.unwrap(), "Hello, world!\n"); - assert!(!os.fs.exists("/file6.txt")); - } - - #[tokio::test] - async fn test_flow_multiple_tools() { - // let _ = tracing_subscriber::fmt::try_init(); - let mut os = Os::new().await.unwrap(); - os.client.set_mock_output(serde_json::json!([ - [ - "Sure, I'll create a file for you", - { - "tool_use_id": "1", - "name": "fs_write", - "args": { - "command": "create", - "file_text": "Hello, world!", - "path": "/file1.txt", - } - }, - { - "tool_use_id": "2", - "name": "fs_write", - "args": { - "command": "create", - "file_text": "Hello, world!", - "path": "/file2.txt", - } - } - ], - [ - "Done", - ], - [ - "Sure, I'll create a file for you", - { - "tool_use_id": "1", - "name": "fs_write", - "args": { - "command": "create", - "file_text": "Hello, world!", - "path": "/file3.txt", - } - }, - { - "tool_use_id": "2", - "name": "fs_write", - "args": { - "command": "create", - "file_text": "Hello, world!", - "path": "/file4.txt", - } - } - ], - [ - "Done", - ], - ])); - - let tool_manager = ToolManager::default(); - let tool_config = serde_json::from_str::>(include_str!("tools/tool_index.json")) - .expect("Tools failed to load"); - ChatSession::new( - &mut os, - std::io::stdout(), - std::io::stderr(), - "fake_conv_id", - None, - InputSource::new_mock(vec![ - "create 2 new files parallel".to_string(), - "t".to_string(), - "/tools reset".to_string(), - "create 2 new files parallel".to_string(), - "y".to_string(), - "y".to_string(), - "exit".to_string(), - ]), - false, - || Some(80), - tool_manager, - None, - None, - tool_config, - ToolPermissions::new(0), - true, - ) - .await - .unwrap() - .spawn(&mut os) - .await - .unwrap(); - - assert_eq!(os.fs.read_to_string("/file1.txt").await.unwrap(), "Hello, world!\n"); - assert_eq!(os.fs.read_to_string("/file2.txt").await.unwrap(), "Hello, world!\n"); - assert_eq!(os.fs.read_to_string("/file3.txt").await.unwrap(), "Hello, world!\n"); - assert_eq!(os.fs.read_to_string("/file4.txt").await.unwrap(), "Hello, world!\n"); - } - - #[tokio::test] - async fn test_flow_tools_trust_all() { - // let _ = tracing_subscriber::fmt::try_init(); - let mut os = Os::new().await.unwrap(); - os.client.set_mock_output(serde_json::json!([ - [ - "Sure, I'll create a file for you", - { - "tool_use_id": "1", - "name": "fs_write", - "args": { - "command": "create", - "file_text": "Hello, world!", - "path": "/file1.txt", - } - } - ], - [ - "Done", - ], - [ - "Sure, I'll create a file for you", - { - "tool_use_id": "1", - "name": "fs_write", - "args": { - "command": "create", - "file_text": "Hello, world!", - "path": "/file3.txt", - } - } - ], - [ - "Ok I won't.", - ], - ])); - - let tool_manager = ToolManager::default(); - let tool_config = serde_json::from_str::>(include_str!("tools/tool_index.json")) - .expect("Tools failed to load"); - ChatSession::new( - &mut os, - std::io::stdout(), - std::io::stderr(), - "fake_conv_id", - None, - InputSource::new_mock(vec![ - "/tools trust-all".to_string(), - "create a new file".to_string(), - "/tools reset".to_string(), - "create a new file".to_string(), - "exit".to_string(), - ]), - false, - || Some(80), - tool_manager, - None, - None, - tool_config, - ToolPermissions::new(0), - true, - ) - .await - .unwrap() - .spawn(&mut os) - .await - .unwrap(); - - assert_eq!(os.fs.read_to_string("/file1.txt").await.unwrap(), "Hello, world!\n"); - assert!(!os.fs.exists("/file2.txt")); - } - - #[test] - fn test_editor_content_processing() { - // Since we no longer have template replacement, this test is simplified - let cases = vec![ - ("My content", "My content"), - ("My content with newline\n", "My content with newline"), - ("", ""), - ]; - - for (input, expected) in cases { - let processed = input.trim().to_string(); - assert_eq!(processed, expected.trim().to_string(), "Failed for input: {}", input); - } - } - - #[tokio::test] - async fn test_subscribe_flow() { - let mut os = Os::new().await.unwrap(); - os.client.set_mock_output(serde_json::Value::Array(vec![])); - let tool_manager = ToolManager::default(); - let tool_config = serde_json::from_str::>(include_str!("tools/tool_index.json")) - .expect("Tools failed to load"); - ChatSession::new( - &mut os, - std::io::stdout(), - std::io::stderr(), - "fake_conv_id", - None, - InputSource::new_mock(vec!["/subscribe".to_string(), "y".to_string(), "/quit".to_string()]), - false, - || Some(80), - tool_manager, - None, - None, - tool_config, - ToolPermissions::new(0), - true, - ) - .await - .unwrap() - .spawn(&mut os) - .await - .unwrap(); - } - - #[test] - fn test_does_input_reference_file() { - let tests = &[ - ( - r"/Users/user/Desktop/Screenshot\ 2025-06-30\ at\ 2.13.34 PM.png read this image for me", - true, - ), - ("/path/to/file.json", true), - ("/save output.json", false), - ("~/does/not/start/with/slash", false), - ]; - for (input, expected) in tests { - let actual = does_input_reference_file(input).is_some(); - assert_eq!(actual, *expected, "expected {} for input {}", expected, input); - } - } -} diff --git a/crates/chat-cli/src/cli/chat/parse.rs b/crates/chat-cli/src/cli/chat/parse.rs deleted file mode 100644 index f59049a33..000000000 --- a/crates/chat-cli/src/cli/chat/parse.rs +++ /dev/null @@ -1,762 +0,0 @@ -use std::io::Write; - -use crossterm::style::{ - Attribute, - Color, - Stylize, -}; -use crossterm::{ - Command, - style, -}; -use unicode_width::{ - UnicodeWidthChar, - UnicodeWidthStr, -}; -use winnow::Partial; -use winnow::ascii::{ - self, - digit1, - space0, - space1, - till_line_ending, -}; -use winnow::combinator::{ - alt, - delimited, - preceded, - repeat, - terminated, -}; -use winnow::error::{ - ErrMode, - ErrorKind, - ParserError, -}; -use winnow::prelude::*; -use winnow::stream::{ - AsChar, - Stream, -}; -use winnow::token::{ - any, - take_till, - take_until, - take_while, -}; - -const CODE_COLOR: Color = Color::Green; -const HEADING_COLOR: Color = Color::Magenta; -const BLOCKQUOTE_COLOR: Color = Color::DarkGrey; -const URL_TEXT_COLOR: Color = Color::Blue; -const URL_LINK_COLOR: Color = Color::DarkGrey; - -const DEFAULT_RULE_WIDTH: usize = 40; - -#[derive(Debug, thiserror::Error)] -pub enum Error<'a> { - #[error(transparent)] - Stdio(#[from] std::io::Error), - #[error("parse error {1}, input {0}")] - Winnow(Partial<&'a str>, ErrorKind), -} - -impl<'a> ParserError> for Error<'a> { - fn from_error_kind(input: &Partial<&'a str>, kind: ErrorKind) -> Self { - Self::Winnow(*input, kind) - } - - fn append( - self, - _input: &Partial<&'a str>, - _checkpoint: &winnow::stream::Checkpoint< - winnow::stream::Checkpoint<&'a str, &'a str>, - winnow::Partial<&'a str>, - >, - _kind: ErrorKind, - ) -> Self { - self - } -} - -#[derive(Debug)] -pub struct ParseState { - pub terminal_width: Option, - pub column: usize, - pub in_codeblock: bool, - pub bold: bool, - pub italic: bool, - pub strikethrough: bool, - pub set_newline: bool, - pub newline: bool, - pub citations: Vec<(String, String)>, -} - -impl ParseState { - pub fn new(terminal_width: Option) -> Self { - Self { - terminal_width, - column: 0, - in_codeblock: false, - bold: false, - italic: false, - strikethrough: false, - set_newline: false, - newline: true, - citations: vec![], - } - } -} - -pub fn interpret_markdown<'a, 'b>( - mut i: Partial<&'a str>, - mut o: impl Write + 'b, - state: &mut ParseState, -) -> PResult, Error<'a>> { - let mut error: Option> = None; - let start = i.checkpoint(); - - macro_rules! stateful_alt { - ($($fns:ident),*) => { - $({ - i.reset(&start); - match $fns(&mut o, state).parse_next(&mut i) { - Err(ErrMode::Backtrack(e)) => { - error = match error { - Some(error) => Some(error.or(e)), - None => Some(e), - }; - }, - res => { - return res.map(|_| i); - } - } - })* - }; - } - - match state.in_codeblock { - false => { - stateful_alt!( - // This pattern acts as a short circuit for alphanumeric plaintext - // More importantly, it's needed to support manual wordwrapping - text, - // multiline patterns - blockquote, - // linted_codeblock, - codeblock_begin, - // single line patterns - horizontal_rule, - heading, - bulleted_item, - numbered_item, - // inline patterns - code, - citation, - url, - bold, - italic, - strikethrough, - // symbols - less_than, - greater_than, - ampersand, - quot, - line_ending, - // fallback - fallback - ); - }, - true => { - stateful_alt!( - codeblock_less_than, - codeblock_greater_than, - codeblock_ampersand, - codeblock_quot, - codeblock_end, - codeblock_line_ending, - codeblock_fallback - ); - }, - } - - match error { - Some(e) => Err(ErrMode::Backtrack(e.append(&i, &start, ErrorKind::Alt))), - None => Err(ErrMode::assert(&i, "no parsers")), - } -} - -fn text<'a, 'b>( - mut o: impl Write + 'b, - state: &'b mut ParseState, -) -> impl FnMut(&mut Partial<&'a str>) -> PResult<(), Error<'a>> + 'b { - move |i| { - let content = take_while(1.., |t| AsChar::is_alphanum(t) || "+,.!?\"".contains(t)).parse_next(i)?; - queue_newline_or_advance(&mut o, state, content.width())?; - queue(&mut o, style::Print(content))?; - Ok(()) - } -} - -fn heading<'a, 'b>( - mut o: impl Write + 'b, - state: &'b mut ParseState, -) -> impl FnMut(&mut Partial<&'a str>) -> PResult<(), Error<'a>> + 'b { - move |i| { - if !state.newline { - return Err(ErrMode::from_error_kind(i, ErrorKind::Fail)); - } - - let level = terminated(take_while(1.., |c| c == '#'), space1).parse_next(i)?; - let print = format!("{level} "); - - queue_newline_or_advance(&mut o, state, print.width())?; - queue(&mut o, style::SetForegroundColor(HEADING_COLOR))?; - queue(&mut o, style::SetAttribute(Attribute::Bold))?; - queue(&mut o, style::Print(print)) - } -} - -fn bulleted_item<'a, 'b>( - mut o: impl Write + 'b, - state: &'b mut ParseState, -) -> impl FnMut(&mut Partial<&'a str>) -> PResult<(), Error<'a>> + 'b { - move |i| { - if !state.newline { - return Err(ErrMode::from_error_kind(i, ErrorKind::Fail)); - } - - let ws = (space0, alt(("-", "*")), space1).parse_next(i)?.0; - let print = format!("{ws}• "); - - queue_newline_or_advance(&mut o, state, print.width())?; - queue(&mut o, style::Print(print)) - } -} - -fn numbered_item<'a, 'b>( - mut o: impl Write + 'b, - state: &'b mut ParseState, -) -> impl FnMut(&mut Partial<&'a str>) -> PResult<(), Error<'a>> + 'b { - move |i| { - if !state.newline { - return Err(ErrMode::from_error_kind(i, ErrorKind::Fail)); - } - - let (ws, digits, _, _) = (space0, digit1, ".", space1).parse_next(i)?; - let print = format!("{ws}{digits}. "); - - queue_newline_or_advance(&mut o, state, print.width())?; - queue(&mut o, style::Print(print)) - } -} - -fn horizontal_rule<'a, 'b>( - mut o: impl Write + 'b, - state: &'b mut ParseState, -) -> impl FnMut(&mut Partial<&'a str>) -> PResult<(), Error<'a>> + 'b { - move |i| { - if !state.newline { - return Err(ErrMode::from_error_kind(i, ErrorKind::Fail)); - } - - ( - space0, - alt((take_while(3.., '-'), take_while(3.., '*'), take_while(3.., '_'))), - ) - .parse_next(i)?; - - state.column = 0; - state.set_newline = true; - - let rule_width = state.terminal_width.unwrap_or(DEFAULT_RULE_WIDTH); - queue(&mut o, style::Print(format!("{}\n", "━".repeat(rule_width)))) - } -} - -fn code<'a, 'b>( - mut o: impl Write + 'b, - state: &'b mut ParseState, -) -> impl FnMut(&mut Partial<&'a str>) -> PResult<(), Error<'a>> + 'b { - move |i| { - "`".parse_next(i)?; - let code = terminated(take_until(0.., "`"), "`").parse_next(i)?; - let out = code.replace("&", "&").replace(">", ">").replace("<", "<"); - - queue_newline_or_advance(&mut o, state, out.width())?; - queue(&mut o, style::SetForegroundColor(Color::Green))?; - queue(&mut o, style::Print(out))?; - queue(&mut o, style::ResetColor) - } -} - -fn blockquote<'a, 'b>( - mut o: impl Write + 'b, - state: &'b mut ParseState, -) -> impl FnMut(&mut Partial<&'a str>) -> PResult<(), Error<'a>> + 'b { - move |i| { - if !state.newline { - return Err(ErrMode::from_error_kind(i, ErrorKind::Fail)); - } - - let level = repeat::<_, _, Vec<&'_ str>, _, _>(1.., terminated(">", space0)) - .parse_next(i)? - .len(); - let print = "│ ".repeat(level); - - queue(&mut o, style::SetForegroundColor(BLOCKQUOTE_COLOR))?; - queue_newline_or_advance(&mut o, state, print.width())?; - queue(&mut o, style::Print(print)) - } -} - -fn bold<'a, 'b>( - mut o: impl Write + 'b, - state: &'b mut ParseState, -) -> impl FnMut(&mut Partial<&'a str>) -> PResult<(), Error<'a>> + 'b { - move |i| { - match state.newline { - true => { - alt(("**", "__")).parse_next(i)?; - queue(&mut o, style::SetAttribute(Attribute::Bold))?; - }, - false => match state.bold { - true => { - alt(("**", "__")).parse_next(i)?; - queue(&mut o, style::SetAttribute(Attribute::NormalIntensity))?; - }, - false => { - preceded(space1, alt(("**", "__"))).parse_next(i)?; - queue(&mut o, style::Print(' '))?; - queue(&mut o, style::SetAttribute(Attribute::Bold))?; - }, - }, - }; - - state.bold = !state.bold; - - Ok(()) - } -} - -fn italic<'a, 'b>( - mut o: impl Write + 'b, - state: &'b mut ParseState, -) -> impl FnMut(&mut Partial<&'a str>) -> PResult<(), Error<'a>> + 'b { - move |i| { - match state.newline { - true => { - alt(("*", "_")).parse_next(i)?; - queue(&mut o, style::SetAttribute(Attribute::Italic))?; - }, - false => match state.italic { - true => { - alt(("*", "_")).parse_next(i)?; - queue(&mut o, style::SetAttribute(Attribute::NoItalic))?; - }, - false => { - preceded(space1, alt(("*", "_"))).parse_next(i)?; - queue(&mut o, style::Print(' '))?; - queue(&mut o, style::SetAttribute(Attribute::Italic))?; - }, - }, - }; - - state.italic = !state.italic; - - Ok(()) - } -} - -fn strikethrough<'a, 'b>( - mut o: impl Write + 'b, - state: &'b mut ParseState, -) -> impl FnMut(&mut Partial<&'a str>) -> PResult<(), Error<'a>> + 'b { - move |i| { - "~~".parse_next(i)?; - state.strikethrough = !state.strikethrough; - match state.strikethrough { - true => queue(&mut o, style::SetAttribute(Attribute::CrossedOut)), - false => queue(&mut o, style::SetAttribute(Attribute::NotCrossedOut)), - } - } -} - -fn citation<'a, 'b>( - mut o: impl Write + 'b, - state: &'b mut ParseState, -) -> impl FnMut(&mut Partial<&'a str>) -> PResult<(), Error<'a>> + 'b { - move |i| { - let num = delimited("[[", digit1, "]]").parse_next(i)?; - let link = delimited("(", take_till(0.., ')'), ")").parse_next(i)?; - - state.citations.push((num.to_owned(), link.to_owned())); - - queue_newline_or_advance(&mut o, state, num.width() + 1)?; - queue(&mut o, style::SetForegroundColor(URL_TEXT_COLOR))?; - queue(&mut o, style::Print(format!("[^{num}]")))?; - queue(&mut o, style::ResetColor) - } -} - -fn url<'a, 'b>( - mut o: impl Write + 'b, - state: &'b mut ParseState, -) -> impl FnMut(&mut Partial<&'a str>) -> PResult<(), Error<'a>> + 'b { - move |i| { - // Save the current input position - let start = i.checkpoint(); - - // Try to match the first part of URL pattern "[text]" - let display = match delimited::<_, _, _, _, Error<'a>, _, _, _>("[", take_until(1.., "]("), "]").parse_next(i) { - Ok(display) => display, - Err(_) => { - // If it doesn't match, reset position and fail - i.reset(&start); - return Err(ErrMode::from_error_kind(i, ErrorKind::Fail)); - }, - }; - - // Try to match the second part of URL pattern "(url)" - let link = match delimited::<_, _, _, _, Error<'a>, _, _, _>("(", take_till(0.., ')'), ")").parse_next(i) { - Ok(link) => link, - Err(_) => { - // If it doesn't match, reset position and fail - i.reset(&start); - return Err(ErrMode::from_error_kind(i, ErrorKind::Fail)); - }, - }; - - // Only generate output if the complete URL pattern matches - queue_newline_or_advance(&mut o, state, display.width() + 1)?; - queue(&mut o, style::SetForegroundColor(URL_TEXT_COLOR))?; - queue(&mut o, style::Print(format!("{display} ")))?; - queue(&mut o, style::SetForegroundColor(URL_LINK_COLOR))?; - state.column += link.width(); - queue(&mut o, style::Print(link))?; - queue(&mut o, style::ResetColor) - } -} - -fn less_than<'a, 'b>( - mut o: impl Write + 'b, - state: &'b mut ParseState, -) -> impl FnMut(&mut Partial<&'a str>) -> PResult<(), Error<'a>> + 'b { - move |i| { - "<".parse_next(i)?; - queue_newline_or_advance(&mut o, state, 1)?; - queue(&mut o, style::Print('<')) - } -} - -fn greater_than<'a, 'b>( - mut o: impl Write + 'b, - state: &'b mut ParseState, -) -> impl FnMut(&mut Partial<&'a str>) -> PResult<(), Error<'a>> + 'b { - move |i| { - ">".parse_next(i)?; - queue_newline_or_advance(&mut o, state, 1)?; - queue(&mut o, style::Print('>')) - } -} - -fn ampersand<'a, 'b>( - mut o: impl Write + 'b, - state: &'b mut ParseState, -) -> impl FnMut(&mut Partial<&'a str>) -> PResult<(), Error<'a>> + 'b { - move |i| { - "&".parse_next(i)?; - queue_newline_or_advance(&mut o, state, 1)?; - queue(&mut o, style::Print('&')) - } -} - -fn quot<'a, 'b>( - mut o: impl Write + 'b, - state: &'b mut ParseState, -) -> impl FnMut(&mut Partial<&'a str>) -> PResult<(), Error<'a>> + 'b { - move |i| { - """.parse_next(i)?; - queue_newline_or_advance(&mut o, state, 1)?; - queue(&mut o, style::Print('"')) - } -} - -fn line_ending<'a, 'b>( - mut o: impl Write + 'b, - state: &'b mut ParseState, -) -> impl FnMut(&mut Partial<&'a str>) -> PResult<(), Error<'a>> + 'b { - move |i| { - ascii::line_ending.parse_next(i)?; - - state.column = 0; - state.set_newline = true; - - queue(&mut o, style::ResetColor)?; - queue(&mut o, style::SetAttribute(style::Attribute::Reset))?; - queue(&mut o, style::Print("\n")) - } -} - -fn fallback<'a, 'b>( - mut o: impl Write + 'b, - state: &'b mut ParseState, -) -> impl FnMut(&mut Partial<&'a str>) -> PResult<(), Error<'a>> + 'b { - move |i| { - let fallback = any.parse_next(i)?; - if let Some(width) = fallback.width() { - queue_newline_or_advance(&mut o, state, width)?; - if fallback != ' ' || state.column != 1 { - queue(&mut o, style::Print(fallback))?; - } - } - - Ok(()) - } -} - -fn queue_newline_or_advance<'a, 'b>( - mut o: impl Write + 'b, - state: &'b mut ParseState, - width: usize, -) -> Result<(), ErrMode>> { - if let Some(terminal_width) = state.terminal_width { - if state.column > 0 && state.column + width > terminal_width { - state.column = width; - queue(&mut o, style::Print('\n'))?; - return Ok(()); - } - } - - // else - state.column += width; - - Ok(()) -} - -fn queue<'a>(mut o: impl Write, command: impl Command) -> Result<(), ErrMode>> { - use crossterm::QueueableCommand; - o.queue(command).map_err(|err| ErrMode::Cut(Error::Stdio(err)))?; - Ok(()) -} - -fn codeblock_begin<'a, 'b>( - mut o: impl Write + 'b, - state: &'b mut ParseState, -) -> impl FnMut(&mut Partial<&'a str>) -> PResult<(), Error<'a>> + 'b { - move |i| { - if !state.newline { - return Err(ErrMode::from_error_kind(i, ErrorKind::Fail)); - } - - // We don't want to do anything special to text inside codeblocks so we wait for all of it - // The alternative is to switch between parse rules at the top level but that's slightly involved - let language = preceded("```", till_line_ending).parse_next(i)?; - ascii::line_ending.parse_next(i)?; - - state.in_codeblock = true; - - if !language.is_empty() { - queue(&mut o, style::Print(format!("{}\n", language).bold()))?; - } - - queue(&mut o, style::SetForegroundColor(CODE_COLOR))?; - - Ok(()) - } -} - -fn codeblock_end<'a, 'b>( - mut o: impl Write + 'b, - state: &'b mut ParseState, -) -> impl FnMut(&mut Partial<&'a str>) -> PResult<(), Error<'a>> + 'b { - move |i| { - "```".parse_next(i)?; - state.in_codeblock = false; - queue(&mut o, style::ResetColor) - } -} - -fn codeblock_less_than<'a, 'b>( - mut o: impl Write + 'b, - _state: &'b mut ParseState, -) -> impl FnMut(&mut Partial<&'a str>) -> PResult<(), Error<'a>> + 'b { - move |i| { - "<".parse_next(i)?; - queue(&mut o, style::Print('<')) - } -} - -fn codeblock_greater_than<'a, 'b>( - mut o: impl Write + 'b, - _state: &'b mut ParseState, -) -> impl FnMut(&mut Partial<&'a str>) -> PResult<(), Error<'a>> + 'b { - move |i| { - ">".parse_next(i)?; - queue(&mut o, style::Print('>')) - } -} - -fn codeblock_ampersand<'a, 'b>( - mut o: impl Write + 'b, - _state: &'b mut ParseState, -) -> impl FnMut(&mut Partial<&'a str>) -> PResult<(), Error<'a>> + 'b { - move |i| { - "&".parse_next(i)?; - queue(&mut o, style::Print('&')) - } -} - -fn codeblock_quot<'a, 'b>( - mut o: impl Write + 'b, - _state: &'b mut ParseState, -) -> impl FnMut(&mut Partial<&'a str>) -> PResult<(), Error<'a>> + 'b { - move |i| { - """.parse_next(i)?; - queue(&mut o, style::Print('"')) - } -} - -fn codeblock_line_ending<'a, 'b>( - mut o: impl Write + 'b, - _state: &'b mut ParseState, -) -> impl FnMut(&mut Partial<&'a str>) -> PResult<(), Error<'a>> + 'b { - move |i| { - ascii::line_ending.parse_next(i)?; - queue(&mut o, style::Print("\n")) - } -} - -fn codeblock_fallback<'a, 'b>( - mut o: impl Write + 'b, - _state: &'b mut ParseState, -) -> impl FnMut(&mut Partial<&'a str>) -> PResult<(), Error<'a>> + 'b { - move |i| { - let fallback = any.parse_next(i)?; - queue(&mut o, style::Print(fallback)) - } -} - -#[cfg(test)] -mod tests { - use std::io::Write; - - use winnow::stream::Offset; - - use super::*; - - macro_rules! validate { - ($test:ident, $input:literal, [$($commands:expr),+ $(,)?]) => { - #[test] - fn $test() -> eyre::Result<()> { - use crossterm::ExecutableCommand; - - let mut input = $input.trim().to_owned(); - input.push(' '); - input.push(' '); - - let mut state = ParseState::new(Some(80)); - let mut presult = vec![]; - let mut offset = 0; - - loop { - let input = Partial::new(&input[offset..]); - match interpret_markdown(input, &mut presult, &mut state) { - Ok(parsed) => { - offset += parsed.offset_from(&input); - state.newline = state.set_newline; - state.set_newline = false; - }, - Err(err) => match err.into_inner() { - Some(err) => panic!("{err}"), - None => break, // Data was incomplete - }, - } - } - - presult.flush()?; - let presult = String::from_utf8(presult)?; - - let mut wresult: Vec = vec![]; - $(wresult.execute($commands)?;)+ - let wresult = String::from_utf8(wresult)?; - - assert_eq!(presult.trim(), wresult); - - Ok(()) - } - }; - } - - validate!(text_1, "hello world!", [style::Print("hello world!")]); - validate!(linted_codeblock_1, "```java\nhello world!```", [ - style::SetAttribute(Attribute::Bold), - style::Print("java\n"), - style::SetAttribute(Attribute::Reset), - style::SetForegroundColor(CODE_COLOR), - style::Print("hello world!"), - style::ResetColor, - ]); - validate!(code_1, "`print`", [ - style::SetForegroundColor(CODE_COLOR), - style::Print("print"), - style::ResetColor, - ]); - validate!(url_1, "[google](google.com)", [ - style::SetForegroundColor(URL_TEXT_COLOR), - style::Print("google "), - style::SetForegroundColor(URL_LINK_COLOR), - style::Print("google.com"), - style::ResetColor, - ]); - validate!(citation_1, "[[1]](google.com)", [ - style::SetForegroundColor(URL_TEXT_COLOR), - style::Print("[^1]"), - style::ResetColor, - ]); - validate!(bold_1, "**hello**", [ - style::SetAttribute(Attribute::Bold), - style::Print("hello"), - style::SetAttribute(Attribute::NormalIntensity) - ]); - validate!(italic_1, "*hello*", [ - style::SetAttribute(Attribute::Italic), - style::Print("hello"), - style::SetAttribute(Attribute::NoItalic) - ]); - validate!(strikethrough_1, "~~hello~~", [ - style::SetAttribute(Attribute::CrossedOut), - style::Print("hello"), - style::SetAttribute(Attribute::NotCrossedOut) - ]); - validate!(less_than_1, "<", [style::Print('<')]); - validate!(greater_than_1, ".>.", [style::Print(".>.")]); - validate!(ampersand_1, "&", [style::Print('&')]); - validate!(quote_1, """, [style::Print('"')]); - validate!(fallback_1, "+ % @ . ? ", [style::Print("+ % @ . ?")]); - validate!(horizontal_rule_1, "---", [style::Print("━".repeat(80))]); - validate!(heading_1, "# Hello World", [ - style::SetForegroundColor(HEADING_COLOR), - style::SetAttribute(Attribute::Bold), - style::Print("# Hello World"), - ]); - validate!(bulleted_item_1, "- bullet", [style::Print("• bullet")]); - validate!(bulleted_item_2, "* bullet", [style::Print("• bullet")]); - validate!(numbered_item_1, "1. number", [style::Print("1. number")]); - validate!(blockquote_1, "> hello", [ - style::SetForegroundColor(BLOCKQUOTE_COLOR), - style::Print("│ hello"), - ]); - validate!(square_bracket_1, "[test]", [style::Print("[test]")]); - validate!(square_bracket_2, "Text with [brackets]", [style::Print( - "Text with [brackets]" - )]); - validate!(square_bracket_empty, "[]", [style::Print("[]")]); - validate!(square_bracket_array, "a[i]", [style::Print("a[i]")]); - validate!(square_bracket_url_like_1, "[text] without url part", [style::Print( - "[text] without url part" - )]); - validate!(square_bracket_url_like_2, "[text](without url part", [style::Print( - "[text](without url part" - )]); -} diff --git a/crates/chat-cli/src/cli/chat/parser.rs b/crates/chat-cli/src/cli/chat/parser.rs deleted file mode 100644 index 821ba1087..000000000 --- a/crates/chat-cli/src/cli/chat/parser.rs +++ /dev/null @@ -1,403 +0,0 @@ -use std::time::{ - Duration, - Instant, -}; - -use eyre::Result; -use thiserror::Error; -use tracing::{ - error, - info, - trace, -}; - -use super::message::{ - AssistantMessage, - AssistantToolUse, -}; -use crate::api_client::model::ChatResponseStream; -use crate::api_client::send_message_output::SendMessageOutput; -use crate::telemetry::ReasonCode; - -#[derive(Debug, Error)] -pub struct RecvError { - /// The request id associated with the [SendMessageOutput] stream. - pub request_id: Option, - #[source] - pub source: RecvErrorKind, -} - -impl RecvError { - pub fn status_code(&self) -> Option { - match &self.source { - RecvErrorKind::Client(e) => e.status_code(), - RecvErrorKind::Json(_) => None, - RecvErrorKind::StreamTimeout { .. } => None, - RecvErrorKind::UnexpectedToolUseEos { .. } => None, - } - } -} - -impl ReasonCode for RecvError { - fn reason_code(&self) -> String { - match &self.source { - RecvErrorKind::Client(_) => "RecvErrorApiClient".to_string(), - RecvErrorKind::Json(_) => "RecvErrorJson".to_string(), - RecvErrorKind::StreamTimeout { .. } => "RecvErrorStreamTimeout".to_string(), - RecvErrorKind::UnexpectedToolUseEos { .. } => "RecvErrorUnexpectedToolUseEos".to_string(), - } - } -} - -impl std::fmt::Display for RecvError { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "Failed to receive the next message: ")?; - if let Some(request_id) = self.request_id.as_ref() { - write!(f, "request_id: {}, error: ", request_id)?; - } - write!(f, "{}", self.source)?; - Ok(()) - } -} - -#[derive(Debug, Error)] -pub enum RecvErrorKind { - #[error("{0}")] - Client(#[from] crate::api_client::ApiClientError), - #[error("{0}")] - Json(#[from] serde_json::Error), - /// An error was encountered while waiting for the next event in the stream after a noticeably - /// long wait time. - /// - /// *Context*: the client can throw an error after ~100s of waiting with no response, likely due - /// to an exceptionally complex tool use taking too long to generate. - #[error("The stream ended after {}s: {source}", .duration.as_secs())] - StreamTimeout { - source: crate::api_client::ApiClientError, - duration: std::time::Duration, - }, - /// Unexpected end of stream while receiving a tool use. - /// - /// *Context*: the stream can unexpectedly end with `Ok(None)` while waiting for an - /// exceptionally complex tool use. This is due to some proxy server dropping idle - /// connections after some timeout is reached. - /// - /// TODO: should this be removed? - #[error("Unexpected end of stream for tool: {} with id: {}", .name, .tool_use_id)] - UnexpectedToolUseEos { - tool_use_id: String, - name: String, - message: Box, - time_elapsed: Duration, - }, -} - -/// State associated with parsing a [ChatResponseStream] into a [Message]. -/// -/// # Usage -/// -/// You should repeatedly call [Self::recv] to receive [ResponseEvent]'s until a -/// [ResponseEvent::EndStream] value is returned. -#[derive(Debug)] -pub struct ResponseParser { - /// The response to consume and parse into a sequence of [Ev]. - response: SendMessageOutput, - /// Buffer to hold the next event in [SendMessageOutput]. - peek: Option, - /// Message identifier for the assistant's response. Randomly generated on creation. - message_id: String, - /// Buffer for holding the accumulated assistant response. - assistant_text: String, - /// Tool uses requested by the model. - tool_uses: Vec, - /// Whether or not we are currently receiving tool use delta events. Tuple of - /// `Some((tool_use_id, name))` if true, [None] otherwise. - parsing_tool_use: Option<(String, String)>, -} - -impl ResponseParser { - pub fn new(response: SendMessageOutput) -> Self { - let message_id = uuid::Uuid::new_v4().to_string(); - info!(?message_id, "Generated new message id"); - Self { - response, - peek: None, - message_id, - assistant_text: String::new(), - tool_uses: Vec::new(), - parsing_tool_use: None, - } - } - - /// Consumes the associated [ConverseStreamResponse] until a valid [ResponseEvent] is parsed. - pub async fn recv(&mut self) -> Result { - if let Some((id, name)) = self.parsing_tool_use.take() { - let tool_use = self.parse_tool_use(id, name).await?; - self.tool_uses.push(tool_use.clone()); - return Ok(ResponseEvent::ToolUse(tool_use)); - } - - // First, handle discarding AssistantResponseEvent's that immediately precede a - // CodeReferenceEvent. - let peek = self.peek().await?; - if let Some(ChatResponseStream::AssistantResponseEvent { content }) = peek { - // Cloning to bypass borrowchecker stuff. - let content = content.clone(); - self.next().await?; - match self.peek().await? { - Some(ChatResponseStream::CodeReferenceEvent(_)) => (), - _ => { - self.assistant_text.push_str(&content); - return Ok(ResponseEvent::AssistantText(content)); - }, - } - } - - loop { - match self.next().await { - Ok(Some(output)) => match output { - ChatResponseStream::AssistantResponseEvent { content } => { - self.assistant_text.push_str(&content); - return Ok(ResponseEvent::AssistantText(content)); - }, - ChatResponseStream::InvalidStateEvent { reason, message } => { - error!(%reason, %message, "invalid state event"); - }, - ChatResponseStream::ToolUseEvent { - tool_use_id, - name, - input, - stop, - } => { - debug_assert!(input.is_none(), "Unexpected initial content in first tool use event"); - debug_assert!( - stop.is_none_or(|v| !v), - "Unexpected immediate stop in first tool use event" - ); - self.parsing_tool_use = Some((tool_use_id.clone(), name.clone())); - return Ok(ResponseEvent::ToolUseStart { name }); - }, - _ => {}, - }, - Ok(None) => { - let message_id = Some(self.message_id.clone()); - let content = std::mem::take(&mut self.assistant_text); - let message = if self.tool_uses.is_empty() { - AssistantMessage::new_response(message_id, content) - } else { - AssistantMessage::new_tool_use( - message_id, - content, - self.tool_uses.clone().into_iter().collect(), - ) - }; - return Ok(ResponseEvent::EndStream { message }); - }, - Err(err) => return Err(err), - } - } - } - - /// Consumes the response stream until a valid [ToolUse] is parsed. - /// - /// The arguments are the fields from the first [ChatResponseStream::ToolUseEvent] consumed. - async fn parse_tool_use(&mut self, id: String, name: String) -> Result { - let mut tool_string = String::new(); - let start = Instant::now(); - while let Some(ChatResponseStream::ToolUseEvent { .. }) = self.peek().await? { - if let Some(ChatResponseStream::ToolUseEvent { input, stop, .. }) = self.next().await? { - if let Some(i) = input { - tool_string.push_str(&i); - } - if let Some(true) = stop { - break; - } - } - } - - let args = match serde_json::from_str(&tool_string) { - Ok(args) => args, - Err(err) if !tool_string.is_empty() => { - // If we failed deserializing after waiting for a long time, then this is most - // likely bedrock responding with a stop event for some reason without actually - // including the tool contents. Essentially, the tool was too large. - let time_elapsed = start.elapsed(); - let args = serde_json::Value::Object( - [( - "key".to_string(), - serde_json::Value::String( - "WARNING: the actual tool use arguments were too complicated to be generated".to_string(), - ), - )] - .into_iter() - .collect(), - ); - if self.peek().await?.is_none() { - error!( - "Received an unexpected end of stream after spending ~{}s receiving tool events", - time_elapsed.as_secs_f64() - ); - self.tool_uses.push(AssistantToolUse { - id: id.clone(), - name: name.clone(), - orig_name: name.clone(), - args: args.clone(), - orig_args: args.clone(), - }); - let message = Box::new(AssistantMessage::new_tool_use( - Some(self.message_id.clone()), - std::mem::take(&mut self.assistant_text), - self.tool_uses.clone().into_iter().collect(), - )); - return Err(self.error(RecvErrorKind::UnexpectedToolUseEos { - tool_use_id: id, - name, - message, - time_elapsed, - })); - } else { - return Err(self.error(err)); - } - }, - // if the tool just does not need any input - _ => serde_json::json!({}), - }; - let orig_name = name.clone(); - let orig_args = args.clone(); - Ok(AssistantToolUse { - id, - name, - orig_name, - args, - orig_args, - }) - } - - /// Returns the next event in the [SendMessageOutput] without consuming it. - async fn peek(&mut self) -> Result, RecvError> { - if self.peek.is_some() { - return Ok(self.peek.as_ref()); - } - match self.next().await? { - Some(v) => { - self.peek = Some(v); - Ok(self.peek.as_ref()) - }, - None => Ok(None), - } - } - - /// Consumes the next [SendMessageOutput] event. - async fn next(&mut self) -> Result, RecvError> { - if let Some(ev) = self.peek.take() { - return Ok(Some(ev)); - } - trace!("Attempting to recv next event"); - let start = std::time::Instant::now(); - let result = self.response.recv().await; - let duration = std::time::Instant::now().duration_since(start); - match result { - Ok(r) => { - trace!(?r, "Received new event"); - Ok(r) - }, - Err(err) => { - if duration.as_secs() >= 59 { - Err(self.error(RecvErrorKind::StreamTimeout { source: err, duration })) - } else { - Err(self.error(err)) - } - }, - } - } - - fn request_id(&self) -> Option<&str> { - self.response.request_id() - } - - /// Helper to create a new [RecvError] populated with the associated request id for the stream. - fn error(&self, source: impl Into) -> RecvError { - RecvError { - request_id: self.request_id().map(str::to_string), - source: source.into(), - } - } -} - -#[derive(Debug)] -pub enum ResponseEvent { - /// Text returned by the assistant. This should be displayed to the user as it is received. - AssistantText(String), - /// Notification that a tool use is being received. - ToolUseStart { name: String }, - /// A tool use requested by the assistant. This should be displayed to the user as it is - /// received. - ToolUse(AssistantToolUse), - /// Represents the end of the response. No more events will be returned. - EndStream { - /// The completed message containing all of the assistant text and tool use events - /// previously emitted. This should be stored in the conversation history and sent in - /// subsequent requests. - message: AssistantMessage, - }, -} - -#[cfg(test)] -mod tests { - use super::*; - - #[tokio::test] - async fn test_parse() { - // let _ = tracing_subscriber::fmt::try_init(); - let tool_use_id = "TEST_ID".to_string(); - let tool_name = "execute_bash".to_string(); - let tool_args = serde_json::json!({ - "command": "echo hello" - }) - .to_string(); - let tool_use_split_at = 5; - let mut events = vec![ - ChatResponseStream::AssistantResponseEvent { - content: "hi".to_string(), - }, - ChatResponseStream::AssistantResponseEvent { - content: " there".to_string(), - }, - ChatResponseStream::AssistantResponseEvent { - content: "IGNORE ME PLEASE".to_string(), - }, - ChatResponseStream::CodeReferenceEvent(()), - ChatResponseStream::ToolUseEvent { - tool_use_id: tool_use_id.clone(), - name: tool_name.clone(), - input: None, - stop: None, - }, - ChatResponseStream::ToolUseEvent { - tool_use_id: tool_use_id.clone(), - name: tool_name.clone(), - input: Some(tool_args.as_str().split_at(tool_use_split_at).0.to_string()), - stop: None, - }, - ChatResponseStream::ToolUseEvent { - tool_use_id: tool_use_id.clone(), - name: tool_name.clone(), - input: Some(tool_args.as_str().split_at(tool_use_split_at).1.to_string()), - stop: None, - }, - ChatResponseStream::ToolUseEvent { - tool_use_id: tool_use_id.clone(), - name: tool_name.clone(), - input: None, - stop: Some(true), - }, - ]; - events.reverse(); - let mock = SendMessageOutput::Mock(events); - let mut parser = ResponseParser::new(mock); - - for _ in 0..5 { - println!("{:?}", parser.recv().await.unwrap()); - } - } -} diff --git a/crates/chat-cli/src/cli/chat/prompt.rs b/crates/chat-cli/src/cli/chat/prompt.rs deleted file mode 100644 index 707c2558d..000000000 --- a/crates/chat-cli/src/cli/chat/prompt.rs +++ /dev/null @@ -1,610 +0,0 @@ -use std::borrow::Cow; - -use eyre::Result; -use rustyline::completion::{ - Completer, - FilenameCompleter, - extract_word, -}; -use rustyline::error::ReadlineError; -use rustyline::highlight::{ - CmdKind, - Highlighter, -}; -use rustyline::hint::Hinter as RustylineHinter; -use rustyline::history::DefaultHistory; -use rustyline::validate::{ - ValidationContext, - ValidationResult, - Validator, -}; -use rustyline::{ - Cmd, - Completer, - CompletionType, - Config, - Context, - EditMode, - Editor, - EventHandler, - Helper, - Hinter, - KeyCode, - KeyEvent, - Modifiers, -}; -use winnow::stream::AsChar; - -pub use super::prompt_parser::generate_prompt; -use super::prompt_parser::parse_prompt_components; -use crate::database::settings::Setting; -use crate::os::Os; - -pub const COMMANDS: &[&str] = &[ - "/clear", - "/help", - "/editor", - "/issue", - "/quit", - "/tools", - "/tools trust", - "/tools untrust", - "/tools trust-all", - "/tools reset", - "/mcp", - "/model", - "/profile", - "/profile help", - "/profile list", - "/profile create", - "/profile delete", - "/profile rename", - "/profile set", - "/prompts", - "/context", - "/context help", - "/context show", - "/context show --expand", - "/context add", - "/context add --global", - "/context rm", - "/context rm --global", - "/context clear", - "/context clear --global", - "/hooks", - "/hooks help", - "/hooks add", - "/hooks rm", - "/hooks enable", - "/hooks disable", - "/hooks enable-all", - "/hooks disable-all", - "/compact", - "/compact help", - "/usage", - "/save", - "/load", - "/subscribe", -]; - -/// Complete commands that start with a slash -fn complete_command(word: &str, start: usize) -> (usize, Vec) { - ( - start, - COMMANDS - .iter() - .filter(|p| p.starts_with(word)) - .map(|s| (*s).to_owned()) - .collect(), - ) -} - -/// A wrapper around FilenameCompleter that provides enhanced path detection -/// and completion capabilities for the chat interface. -pub struct PathCompleter { - /// The underlying filename completer from rustyline - filename_completer: FilenameCompleter, -} - -impl PathCompleter { - /// Creates a new PathCompleter instance - pub fn new() -> Self { - Self { - filename_completer: FilenameCompleter::new(), - } - } - - /// Attempts to complete a file path at the given position in the line - pub fn complete_path( - &self, - line: &str, - pos: usize, - os: &Context<'_>, - ) -> Result<(usize, Vec), ReadlineError> { - // Use the filename completer to get path completions - match self.filename_completer.complete(line, pos, os) { - Ok((pos, completions)) => { - // Convert the filename completer's pairs to strings - let file_completions: Vec = completions.iter().map(|pair| pair.replacement.clone()).collect(); - - // Return the completions if we have any - Ok((pos, file_completions)) - }, - Err(err) => Err(err), - } - } -} - -pub struct PromptCompleter { - sender: std::sync::mpsc::Sender>, - receiver: std::sync::mpsc::Receiver>, -} - -impl PromptCompleter { - fn new(sender: std::sync::mpsc::Sender>, receiver: std::sync::mpsc::Receiver>) -> Self { - PromptCompleter { sender, receiver } - } - - fn complete_prompt(&self, word: &str) -> Result, ReadlineError> { - let sender = &self.sender; - let receiver = &self.receiver; - sender - .send(if !word.is_empty() { Some(word.to_string()) } else { None }) - .map_err(|e| ReadlineError::Io(std::io::Error::other(e.to_string())))?; - let prompt_info = receiver - .recv() - .map_err(|e| ReadlineError::Io(std::io::Error::other(e.to_string())))? - .iter() - .map(|n| format!("@{n}")) - .collect::>(); - - Ok(prompt_info) - } -} - -pub struct ChatCompleter { - path_completer: PathCompleter, - prompt_completer: PromptCompleter, -} - -impl ChatCompleter { - fn new(sender: std::sync::mpsc::Sender>, receiver: std::sync::mpsc::Receiver>) -> Self { - Self { - path_completer: PathCompleter::new(), - prompt_completer: PromptCompleter::new(sender, receiver), - } - } -} - -impl Completer for ChatCompleter { - type Candidate = String; - - fn complete( - &self, - line: &str, - pos: usize, - _os: &Context<'_>, - ) -> Result<(usize, Vec), ReadlineError> { - let (start, word) = extract_word(line, pos, None, |c| c.is_space()); - - // Handle command completion - if word.starts_with('/') { - return Ok(complete_command(word, start)); - } - - if line.starts_with('@') { - let search_word = line.strip_prefix('@').unwrap_or(""); - if let Ok(completions) = self.prompt_completer.complete_prompt(search_word) { - if !completions.is_empty() { - return Ok((0, completions)); - } - } - } - - // Handle file path completion as fallback - if let Ok((pos, completions)) = self.path_completer.complete_path(line, pos, _os) { - if !completions.is_empty() { - return Ok((pos, completions)); - } - } - - // Default: no completions - Ok((start, Vec::new())) - } -} - -/// Custom hinter that provides shadowtext suggestions -pub struct ChatHinter { - /// Command history for providing suggestions based on past commands - history: Vec, - /// Whether history-based hints are enabled - history_hints_enabled: bool, -} - -impl ChatHinter { - /// Creates a new ChatHinter instance - pub fn new(history_hints_enabled: bool) -> Self { - Self { - history: Vec::new(), - history_hints_enabled, - } - } - - /// Updates the history with a new command - pub fn update_history(&mut self, command: &str) { - if !command.trim().is_empty() { - self.history.push(command.to_string()); - } - } - - /// Finds the best hint for the current input - fn find_hint(&self, line: &str) -> Option { - // If line is empty, no hint - if line.is_empty() { - return None; - } - - // If line starts with a slash, try to find a command hint - if line.starts_with('/') { - return COMMANDS - .iter() - .find(|cmd| cmd.starts_with(line)) - .map(|cmd| cmd[line.len()..].to_string()); - } - - // Try to find a hint from history, but only if history hints are enabled - if self.history_hints_enabled { - self.history - .iter() - .rev() // Start from most recent - .find(|cmd| cmd.starts_with(line) && cmd.len() > line.len()) - .map(|cmd| cmd[line.len()..].to_string()) - } else { - None - } - } -} - -impl RustylineHinter for ChatHinter { - type Hint = String; - - fn hint(&self, line: &str, pos: usize, _ctx: &Context<'_>) -> Option { - // Only provide hints when cursor is at the end of the line - if pos < line.len() { - return None; - } - - self.find_hint(line) - } -} - -/// Custom validator for multi-line input -pub struct MultiLineValidator; - -impl Validator for MultiLineValidator { - fn validate(&self, os: &mut ValidationContext<'_>) -> rustyline::Result { - let input = os.input(); - - // Check for explicit multi-line markers - if input.starts_with("```") && !input.ends_with("```") { - return Ok(ValidationResult::Incomplete); - } - - // Check for backslash continuation - if input.ends_with('\\') { - return Ok(ValidationResult::Incomplete); - } - - Ok(ValidationResult::Valid(None)) - } -} - -#[derive(Helper, Completer, Hinter)] -pub struct ChatHelper { - #[rustyline(Completer)] - completer: ChatCompleter, - #[rustyline(Hinter)] - hinter: ChatHinter, - validator: MultiLineValidator, -} - -impl ChatHelper { - /// Updates the history of the ChatHinter with a new command - pub fn update_hinter_history(&mut self, command: &str) { - if command.contains("\n") || command.contains("\r") { - return; - } - - self.hinter.update_history(command); - } -} - -impl Validator for ChatHelper { - fn validate(&self, os: &mut ValidationContext<'_>) -> rustyline::Result { - self.validator.validate(os) - } -} - -impl Highlighter for ChatHelper { - fn highlight_hint<'h>(&self, hint: &'h str) -> Cow<'h, str> { - Cow::Owned(format!("\x1b[38;5;240m{hint}\x1b[m")) - } - - fn highlight<'l>(&self, line: &'l str, _pos: usize) -> Cow<'l, str> { - Cow::Borrowed(line) - } - - fn highlight_char(&self, _line: &str, _pos: usize, _kind: CmdKind) -> bool { - false - } - - fn highlight_prompt<'b, 's: 'b, 'p: 'b>(&'s self, prompt: &'p str, _default: bool) -> Cow<'b, str> { - use crossterm::style::Stylize; - - // Parse the plain text prompt to extract profile and warning information - // and apply colors using crossterm's ANSI escape codes - if let Some(components) = parse_prompt_components(prompt) { - let mut result = String::new(); - - // Add profile part if present - if let Some(profile) = components.profile { - result.push_str(&format!("[{}] ", profile).cyan().to_string()); - } - - // Add warning symbol if present - if components.warning { - result.push_str(&"!".red().to_string()); - } - - // Add the prompt symbol - result.push_str(&"> ".magenta().to_string()); - - Cow::Owned(result) - } else { - // If we can't parse the prompt, return it as-is - Cow::Borrowed(prompt) - } - } -} - -pub fn rl( - os: &Os, - sender: std::sync::mpsc::Sender>, - receiver: std::sync::mpsc::Receiver>, -) -> Result> { - let edit_mode = match os.database.settings.get_string(Setting::ChatEditMode).as_deref() { - Some("vi" | "vim") => EditMode::Vi, - _ => EditMode::Emacs, - }; - let config = Config::builder() - .history_ignore_space(true) - .completion_type(CompletionType::List) - .edit_mode(edit_mode) - .build(); - - let history_hints_enabled = os - .database - .settings - .get_bool(Setting::ChatEnableHistoryHints) - .unwrap_or(false); - let h = ChatHelper { - completer: ChatCompleter::new(sender, receiver), - hinter: ChatHinter::new(history_hints_enabled), - validator: MultiLineValidator, - }; - - let mut rl = Editor::with_config(config)?; - rl.set_helper(Some(h)); - - // Add custom keybinding for Alt+Enter to insert a newline - rl.bind_sequence( - KeyEvent(KeyCode::Enter, Modifiers::ALT), - EventHandler::Simple(Cmd::Insert(1, "\n".to_string())), - ); - - // Add custom keybinding for Ctrl+J to insert a newline - rl.bind_sequence( - KeyEvent(KeyCode::Char('j'), Modifiers::CTRL), - EventHandler::Simple(Cmd::Insert(1, "\n".to_string())), - ); - - // Add custom keybinding for Ctrl+F to accept hint (like fish shell) - rl.bind_sequence( - KeyEvent(KeyCode::Char('f'), Modifiers::CTRL), - EventHandler::Simple(Cmd::CompleteHint), - ); - - Ok(rl) -} - -#[cfg(test)] -mod tests { - use crossterm::style::Stylize; - use rustyline::highlight::Highlighter; - - use super::*; - - #[test] - fn test_chat_completer_command_completion() { - let (prompt_request_sender, _) = std::sync::mpsc::channel::>(); - let (_, prompt_response_receiver) = std::sync::mpsc::channel::>(); - let completer = ChatCompleter::new(prompt_request_sender, prompt_response_receiver); - let line = "/h"; - let pos = 2; // Position at the end of "/h" - - // Create a mock context with empty history - let empty_history = DefaultHistory::new(); - let os = Context::new(&empty_history); - - // Get completions - let (start, completions) = completer.complete(line, pos, &os).unwrap(); - - // Verify start position - assert_eq!(start, 0); - - // Verify completions contain expected commands - assert!(completions.contains(&"/help".to_string())); - } - - #[test] - fn test_chat_completer_no_completion() { - let (prompt_request_sender, _) = std::sync::mpsc::channel::>(); - let (_, prompt_response_receiver) = std::sync::mpsc::channel::>(); - let completer = ChatCompleter::new(prompt_request_sender, prompt_response_receiver); - let line = "Hello, how are you?"; - let pos = line.len(); - - // Create a mock context with empty history - let empty_history = DefaultHistory::new(); - let os = Context::new(&empty_history); - - // Get completions - let (_, completions) = completer.complete(line, pos, &os).unwrap(); - - // Verify no completions are returned for regular text - assert!(completions.is_empty()); - } - - #[test] - fn test_highlight_prompt_basic() { - let (prompt_request_sender, _) = std::sync::mpsc::channel::>(); - let (_, prompt_response_receiver) = std::sync::mpsc::channel::>(); - let helper = ChatHelper { - completer: ChatCompleter::new(prompt_request_sender, prompt_response_receiver), - hinter: ChatHinter::new(true), - validator: MultiLineValidator, - }; - - // Test basic prompt highlighting - let highlighted = helper.highlight_prompt("> ", true); - - assert_eq!(highlighted, "> ".magenta().to_string()); - } - - #[test] - fn test_highlight_prompt_with_warning() { - let (prompt_request_sender, _) = std::sync::mpsc::channel::>(); - let (_, prompt_response_receiver) = std::sync::mpsc::channel::>(); - let helper = ChatHelper { - completer: ChatCompleter::new(prompt_request_sender, prompt_response_receiver), - hinter: ChatHinter::new(true), - validator: MultiLineValidator, - }; - - // Test warning prompt highlighting - let highlighted = helper.highlight_prompt("!> ", true); - - assert_eq!(highlighted, format!("{}{}", "!".red(), "> ".magenta())); - } - - #[test] - fn test_highlight_prompt_with_profile() { - let (prompt_request_sender, _) = std::sync::mpsc::channel::>(); - let (_, prompt_response_receiver) = std::sync::mpsc::channel::>(); - let helper = ChatHelper { - completer: ChatCompleter::new(prompt_request_sender, prompt_response_receiver), - hinter: ChatHinter::new(true), - validator: MultiLineValidator, - }; - - // Test profile prompt highlighting - let highlighted = helper.highlight_prompt("[test-profile] > ", true); - - assert_eq!(highlighted, format!("{}{}", "[test-profile] ".cyan(), "> ".magenta())); - } - - #[test] - fn test_highlight_prompt_with_profile_and_warning() { - let (prompt_request_sender, _) = std::sync::mpsc::channel::>(); - let (_, prompt_response_receiver) = std::sync::mpsc::channel::>(); - let helper = ChatHelper { - completer: ChatCompleter::new(prompt_request_sender, prompt_response_receiver), - hinter: ChatHinter::new(true), - validator: MultiLineValidator, - }; - - // Test profile + warning prompt highlighting - let highlighted = helper.highlight_prompt("[dev] !> ", true); - // Should have cyan profile + red warning + cyan bold prompt - assert_eq!( - highlighted, - format!("{}{}{}", "[dev] ".cyan(), "!".red(), "> ".magenta()) - ); - } - - #[test] - fn test_highlight_prompt_invalid_format() { - let (prompt_request_sender, _) = std::sync::mpsc::channel::>(); - let (_, prompt_response_receiver) = std::sync::mpsc::channel::>(); - let helper = ChatHelper { - completer: ChatCompleter::new(prompt_request_sender, prompt_response_receiver), - hinter: ChatHinter::new(true), - validator: MultiLineValidator, - }; - - // Test invalid prompt format (should return as-is) - let invalid_prompt = "invalid prompt format"; - let highlighted = helper.highlight_prompt(invalid_prompt, true); - assert_eq!(highlighted, invalid_prompt); - } - - #[test] - fn test_chat_hinter_command_hint() { - let hinter = ChatHinter::new(true); - - // Test hint for a command - let line = "/he"; - let pos = line.len(); - let empty_history = DefaultHistory::new(); - let ctx = Context::new(&empty_history); - - let hint = hinter.hint(line, pos, &ctx); - assert_eq!(hint, Some("lp".to_string())); - - // Test hint when cursor is not at the end - let hint = hinter.hint(line, 1, &ctx); - assert_eq!(hint, None); - - // Test hint for a non-existent command - let line = "/xyz"; - let pos = line.len(); - let hint = hinter.hint(line, pos, &ctx); - assert_eq!(hint, None); - } - - #[test] - fn test_chat_hinter_history_hint_enabled() { - let mut hinter = ChatHinter::new(true); - - // Add some history - hinter.update_history("Hello, world!"); - hinter.update_history("How are you?"); - - // Test hint from history - let line = "How"; - let pos = line.len(); - let empty_history = DefaultHistory::new(); - let ctx = Context::new(&empty_history); - - let hint = hinter.hint(line, pos, &ctx); - assert_eq!(hint, Some(" are you?".to_string())); - } - - #[test] - fn test_chat_hinter_history_hint_disabled() { - let mut hinter = ChatHinter::new(false); - - // Add some history - hinter.update_history("Hello, world!"); - hinter.update_history("How are you?"); - - // Test hint from history when disabled - let line = "How"; - let pos = line.len(); - let empty_history = DefaultHistory::new(); - let ctx = Context::new(&empty_history); - - let hint = hinter.hint(line, pos, &ctx); - assert_eq!(hint, None); - } -} diff --git a/crates/chat-cli/src/cli/chat/prompt_parser.rs b/crates/chat-cli/src/cli/chat/prompt_parser.rs deleted file mode 100644 index 59b0caf99..000000000 --- a/crates/chat-cli/src/cli/chat/prompt_parser.rs +++ /dev/null @@ -1,93 +0,0 @@ -/// Components extracted from a prompt string -#[derive(Debug, PartialEq)] -pub struct PromptComponents { - pub profile: Option, - pub warning: bool, -} - -/// Parse prompt components from a plain text prompt -pub fn parse_prompt_components(prompt: &str) -> Option { - // Expected format: "[profile] !> " or "> " or "!> " etc. - let mut profile = None; - let mut warning = false; - let mut remaining = prompt.trim(); - - // Check for profile pattern [profile] - if let Some(start) = remaining.find('[') { - if let Some(end) = remaining.find(']') { - if start < end { - profile = Some(remaining[start + 1..end].to_string()); - remaining = remaining[end + 1..].trim_start(); - } - } - } - - // Check for warning symbol ! - if remaining.starts_with('!') { - warning = true; - remaining = remaining[1..].trim_start(); - } - - // Should end with "> " - if remaining.trim_end() == ">" { - Some(PromptComponents { profile, warning }) - } else { - None - } -} - -pub fn generate_prompt(current_profile: Option<&str>, warning: bool) -> String { - // Generate plain text prompt that will be colored by highlight_prompt - let warning_symbol = if warning { "!" } else { "" }; - let profile_part = current_profile - .filter(|&p| p != "default") - .map(|p| format!("[{p}] ")) - .unwrap_or_default(); - - format!("{profile_part}{warning_symbol}> ") -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn test_generate_prompt() { - // Test default prompt (no profile) - assert_eq!(generate_prompt(None, false), "> "); - // Test default prompt with warning - assert_eq!(generate_prompt(None, true), "!> "); - // Test default profile (should be same as no profile) - assert_eq!(generate_prompt(Some("default"), false), "> "); - // Test custom profile - assert_eq!(generate_prompt(Some("test-profile"), false), "[test-profile] > "); - // Test another custom profile with warning - assert_eq!(generate_prompt(Some("dev"), true), "[dev] !> "); - } - - #[test] - fn test_parse_prompt_components() { - // Test basic prompt - let components = parse_prompt_components("> ").unwrap(); - assert!(components.profile.is_none()); - assert!(!components.warning); - - // Test warning prompt - let components = parse_prompt_components("!> ").unwrap(); - assert!(components.profile.is_none()); - assert!(components.warning); - - // Test profile prompt - let components = parse_prompt_components("[test] > ").unwrap(); - assert_eq!(components.profile.as_deref(), Some("test")); - assert!(!components.warning); - - // Test profile with warning - let components = parse_prompt_components("[dev] !> ").unwrap(); - assert_eq!(components.profile.as_deref(), Some("dev")); - assert!(components.warning); - - // Test invalid prompt - assert!(parse_prompt_components("invalid").is_none()); - } -} diff --git a/crates/chat-cli/src/cli/chat/server_messenger.rs b/crates/chat-cli/src/cli/chat/server_messenger.rs deleted file mode 100644 index 966600fc4..000000000 --- a/crates/chat-cli/src/cli/chat/server_messenger.rs +++ /dev/null @@ -1,133 +0,0 @@ -use tokio::sync::mpsc::{ - Receiver, - Sender, - channel, -}; - -use crate::mcp_client::{ - Messenger, - MessengerError, - PromptsListResult, - ResourceTemplatesListResult, - ResourcesListResult, - ToolsListResult, -}; - -#[allow(dead_code)] -#[derive(Debug)] -pub enum UpdateEventMessage { - ToolsListResult { - server_name: String, - result: eyre::Result, - }, - PromptsListResult { - server_name: String, - result: eyre::Result, - }, - ResourcesListResult { - server_name: String, - result: eyre::Result, - }, - ResourceTemplatesListResult { - server_name: String, - result: eyre::Result, - }, - InitStart { - server_name: String, - }, -} - -#[derive(Clone, Debug)] -pub struct ServerMessengerBuilder { - pub update_event_sender: Sender, -} - -impl ServerMessengerBuilder { - pub fn new(capacity: usize) -> (Receiver, Self) { - let (tx, rx) = channel::(capacity); - let this = Self { - update_event_sender: tx, - }; - (rx, this) - } - - pub fn build_with_name(&self, server_name: String) -> ServerMessenger { - ServerMessenger { - server_name, - update_event_sender: self.update_event_sender.clone(), - } - } -} - -#[derive(Clone, Debug)] -pub struct ServerMessenger { - pub server_name: String, - pub update_event_sender: Sender, -} - -#[async_trait::async_trait] -impl Messenger for ServerMessenger { - async fn send_tools_list_result(&self, result: eyre::Result) -> Result<(), MessengerError> { - Ok(self - .update_event_sender - .send(UpdateEventMessage::ToolsListResult { - server_name: self.server_name.clone(), - result, - }) - .await - .map_err(|e| MessengerError::Custom(e.to_string()))?) - } - - async fn send_prompts_list_result(&self, result: eyre::Result) -> Result<(), MessengerError> { - Ok(self - .update_event_sender - .send(UpdateEventMessage::PromptsListResult { - server_name: self.server_name.clone(), - result, - }) - .await - .map_err(|e| MessengerError::Custom(e.to_string()))?) - } - - async fn send_resources_list_result( - &self, - result: eyre::Result, - ) -> Result<(), MessengerError> { - Ok(self - .update_event_sender - .send(UpdateEventMessage::ResourcesListResult { - server_name: self.server_name.clone(), - result, - }) - .await - .map_err(|e| MessengerError::Custom(e.to_string()))?) - } - - async fn send_resource_templates_list_result( - &self, - result: eyre::Result, - ) -> Result<(), MessengerError> { - Ok(self - .update_event_sender - .send(UpdateEventMessage::ResourceTemplatesListResult { - server_name: self.server_name.clone(), - result, - }) - .await - .map_err(|e| MessengerError::Custom(e.to_string()))?) - } - - async fn send_init_msg(&self) -> Result<(), MessengerError> { - Ok(self - .update_event_sender - .send(UpdateEventMessage::InitStart { - server_name: self.server_name.clone(), - }) - .await - .map_err(|e| MessengerError::Custom(e.to_string()))?) - } - - fn duplicate(&self) -> Box { - Box::new(self.clone()) - } -} diff --git a/crates/chat-cli/src/cli/chat/skim_integration.rs b/crates/chat-cli/src/cli/chat/skim_integration.rs deleted file mode 100644 index 625b91f1a..000000000 --- a/crates/chat-cli/src/cli/chat/skim_integration.rs +++ /dev/null @@ -1,385 +0,0 @@ -use std::io::{ - BufReader, - Cursor, - Write, - stdout, -}; - -use crossterm::execute; -use crossterm::terminal::{ - EnterAlternateScreen, - LeaveAlternateScreen, -}; -use eyre::{ - Result, - eyre, -}; -use rustyline::{ - Cmd, - ConditionalEventHandler, - EventContext, - RepeatCount, -}; -use skim::prelude::*; -use tempfile::NamedTempFile; - -use super::context::ContextManager; -use crate::os::Os; - -pub fn select_profile_with_skim(os: &Os, context_manager: &ContextManager) -> Result> { - let profiles = context_manager.list_profiles_blocking(os)?; - - launch_skim_selector(&profiles, "Select profile: ", false) - .map(|selected| selected.and_then(|s| s.into_iter().next())) -} - -pub struct SkimCommandSelector { - os: Os, - context_manager: Arc, - tool_names: Vec, -} - -impl SkimCommandSelector { - /// This allows the ConditionalEventHandler handle function to be bound to a KeyEvent. - pub fn new(os: Os, context_manager: Arc, tool_names: Vec) -> Self { - Self { - os, - context_manager, - tool_names, - } - } -} - -impl ConditionalEventHandler for SkimCommandSelector { - fn handle(&self, _evt: &rustyline::Event, _n: RepeatCount, _positive: bool, _os: &EventContext<'_>) -> Option { - // Launch skim command selector with the context manager if available - match select_command(&self.os, self.context_manager.as_ref(), &self.tool_names) { - Ok(Some(command)) => Some(Cmd::Insert(1, command)), - _ => { - // If cancelled or error, do nothing - Some(Cmd::Noop) - }, - } - } -} - -pub fn get_available_commands() -> Vec { - // Import the COMMANDS array directly from prompt.rs - // This is the single source of truth for available commands - let commands_array = super::prompt::COMMANDS; - - let mut commands = Vec::new(); - for &cmd in commands_array { - commands.push(cmd.to_string()); - } - - commands -} - -/// Format commands for skim display -/// Create a standard set of skim options with consistent styling -fn create_skim_options(prompt: &str, multi: bool) -> Result { - SkimOptionsBuilder::default() - .height("100%".to_string()) - .prompt(prompt.to_string()) - .reverse(true) - .multi(multi) - .build() - .map_err(|e| eyre!("Failed to build skim options: {}", e)) -} - -/// Run skim with the given options and items in an alternate screen -/// This helper function handles entering/exiting the alternate screen and running skim -fn run_skim_with_options(options: &SkimOptions, items: SkimItemReceiver) -> Result>>> { - // Enter alternate screen to prevent skim output from persisting in terminal history - execute!(stdout(), EnterAlternateScreen).map_err(|e| eyre!("Failed to enter alternate screen: {}", e))?; - - let selected_items = - Skim::run_with(options, Some(items)).and_then(|out| if out.is_abort { None } else { Some(out.selected_items) }); - - execute!(stdout(), LeaveAlternateScreen).map_err(|e| eyre!("Failed to leave alternate screen: {}", e))?; - - Ok(selected_items) -} - -/// Extract string selections from skim items -fn extract_selections(items: Vec>) -> Vec { - items.iter().map(|item| item.output().to_string()).collect() -} - -/// Launch skim with the given items and return the selected item -pub fn launch_skim_selector(items: &[String], prompt: &str, multi: bool) -> Result>> { - let mut temp_file_for_skim_input = NamedTempFile::new()?; - temp_file_for_skim_input.write_all(items.join("\n").as_bytes())?; - - let options = create_skim_options(prompt, multi)?; - let item_reader = SkimItemReader::default(); - let items = item_reader.of_bufread(BufReader::new(std::fs::File::open(temp_file_for_skim_input.path())?)); - - // Run skim and get selected items - match run_skim_with_options(&options, items)? { - Some(items) if !items.is_empty() => { - let selections = extract_selections(items); - Ok(Some(selections)) - }, - _ => Ok(None), // User cancelled or no selection - } -} - -/// Select files using skim -pub fn select_files_with_skim() -> Result>> { - // Create skim options with appropriate settings - let options = create_skim_options("Select files: ", true)?; - - // Create a command that will be executed by skim - // This command checks if git is installed and if we're in a git repo - // Otherwise falls back to find command - let find_cmd = r#" - # Check if git is available and we're in a git repo - if command -v git >/dev/null 2>&1 && git rev-parse --is-inside-work-tree &>/dev/null; then - # Git repository - respect .gitignore - { git ls-files; git ls-files --others --exclude-standard; } | sort | uniq - else - # Not a git repository or git not installed - use find command - find . -type f -not -path '*/\.*' - fi - "#; - - // Create a command collector that will execute the find command - let item_reader = SkimItemReader::default(); - let items = item_reader.of_bufread(BufReader::new( - std::process::Command::new("sh") - .args(["-c", find_cmd]) - .stdout(std::process::Stdio::piped()) - .spawn()? - .stdout - .ok_or_else(|| eyre!("Failed to get stdout from command"))?, - )); - - // Run skim with the command output as a stream - match run_skim_with_options(&options, items)? { - Some(items) if !items.is_empty() => { - let selections = extract_selections(items); - Ok(Some(selections)) - }, - _ => Ok(None), // User cancelled or no selection - } -} - -/// Select context paths using skim -pub fn select_context_paths_with_skim(context_manager: &ContextManager) -> Result, bool)>> { - let mut global_paths = Vec::new(); - let mut profile_paths = Vec::new(); - - // Get global paths - for path in &context_manager.global_config.paths { - global_paths.push(format!("(global) {}", path)); - } - - // Get profile-specific paths - for path in &context_manager.profile_config.paths { - profile_paths.push(format!("(profile: {}) {}", context_manager.current_profile, path)); - } - - // Combine paths, but keep track of which are global - let mut all_paths = Vec::new(); - all_paths.extend(global_paths); - all_paths.extend(profile_paths); - - if all_paths.is_empty() { - return Ok(None); // No paths to select - } - - // Create skim options - let options = create_skim_options("Select paths to remove: ", true)?; - - // Create item reader - let item_reader = SkimItemReader::default(); - let items = item_reader.of_bufread(Cursor::new(all_paths.join("\n"))); - - // Run skim and get selected paths - match run_skim_with_options(&options, items)? { - Some(items) if !items.is_empty() => { - let selected_paths = extract_selections(items); - - // Check if any global paths were selected - let has_global = selected_paths.iter().any(|p| p.starts_with("(global)")); - - // Extract the actual paths from the formatted strings - let paths: Vec = selected_paths - .iter() - .map(|p| { - // Extract the path part after the prefix - let parts: Vec<&str> = p.splitn(2, ") ").collect(); - if parts.len() > 1 { - parts[1].to_string() - } else { - p.clone() - } - }) - .collect(); - - Ok(Some((paths, has_global))) - }, - _ => Ok(None), // User cancelled selection - } -} - -/// Launch the command selector and handle the selected command -pub fn select_command(os: &Os, context_manager: &ContextManager, tools: &[String]) -> Result> { - let commands = get_available_commands(); - - match launch_skim_selector(&commands, "Select command: ", false)? { - Some(selections) if !selections.is_empty() => { - let selected_command = &selections[0]; - - match CommandType::from_str(selected_command) { - Some(CommandType::ContextAdd(cmd)) => { - // For context add commands, we need to select files - match select_files_with_skim()? { - Some(files) if !files.is_empty() => { - // Construct the full command with selected files - let mut cmd = cmd.clone(); - for file in files { - cmd.push_str(&format!(" {}", file)); - } - Ok(Some(cmd)) - }, - _ => Ok(Some(selected_command.clone())), /* User cancelled file selection, return just the - * command */ - } - }, - Some(CommandType::ContextRemove(cmd)) => { - // For context rm commands, we need to select from existing context paths - match select_context_paths_with_skim(context_manager)? { - Some((paths, has_global)) if !paths.is_empty() => { - // Construct the full command with selected paths - let mut full_cmd = cmd.clone(); - if has_global { - full_cmd.push_str(" --global"); - } - for path in paths { - full_cmd.push_str(&format!(" {}", path)); - } - Ok(Some(full_cmd)) - }, - Some((_, _)) => Ok(Some(format!("{} (No paths selected)", cmd))), - None => Ok(Some(selected_command.clone())), // User cancelled path selection - } - }, - Some(CommandType::Tools(_)) => { - let options = create_skim_options("Select tool: ", false)?; - let item_reader = SkimItemReader::default(); - let items = item_reader.of_bufread(Cursor::new(tools.join("\n"))); - let selected_tool = match run_skim_with_options(&options, items)? { - Some(items) if !items.is_empty() => Some(items[0].output().to_string()), - _ => None, - }; - - match selected_tool { - Some(tool) => Ok(Some(format!("{} {}", selected_command, tool))), - None => Ok(Some(selected_command.clone())), /* User cancelled tool selection, return just the - * command */ - } - }, - Some(cmd @ CommandType::Profile(_)) if cmd.needs_profile_selection() => { - // For profile operations that need a profile name, show profile selector - match select_profile_with_skim(os, context_manager)? { - Some(profile) => { - let full_cmd = format!("{} {}", selected_command, profile); - Ok(Some(full_cmd)) - }, - None => Ok(Some(selected_command.clone())), // User cancelled profile selection - } - }, - Some(CommandType::Profile(_)) => { - // For other profile operations (like create), just return the command - Ok(Some(selected_command.clone())) - }, - None => { - // Command doesn't need additional parameters - Ok(Some(selected_command.clone())) - }, - } - }, - _ => Ok(None), // User cancelled command selection - } -} - -#[derive(PartialEq)] -enum CommandType { - ContextAdd(String), - ContextRemove(String), - Tools(&'static str), - Profile(&'static str), -} - -impl CommandType { - fn needs_profile_selection(&self) -> bool { - matches!(self, CommandType::Profile("set" | "delete" | "rename")) - } - - fn from_str(cmd: &str) -> Option { - if cmd.starts_with("/context add") { - Some(CommandType::ContextAdd(cmd.to_string())) - } else if cmd.starts_with("/context rm") { - Some(CommandType::ContextRemove(cmd.to_string())) - } else { - match cmd { - "/tools trust" => Some(CommandType::Tools("trust")), - "/tools untrust" => Some(CommandType::Tools("untrust")), - "/profile set" => Some(CommandType::Profile("set")), - "/profile delete" => Some(CommandType::Profile("delete")), - "/profile rename" => Some(CommandType::Profile("rename")), - "/profile create" => Some(CommandType::Profile("create")), - _ => None, - } - } - } -} - -#[cfg(test)] -mod tests { - use std::collections::HashSet; - - use super::*; - - /// Test to verify that all hardcoded command strings in select_command - /// are present in the COMMANDS array from prompt.rs - #[test] - fn test_hardcoded_commands_in_commands_array() { - // Get the set of available commands from prompt.rs - let available_commands: HashSet = get_available_commands().iter().cloned().collect(); - - // List of hardcoded commands used in select_command - let hardcoded_commands = vec![ - "/context add", - "/context add --global", - "/context rm", - "/context rm --global", - "/tools trust", - "/tools untrust", - "/profile set", - "/profile delete", - "/profile rename", - "/profile create", - ]; - - // Check that each hardcoded command is in the COMMANDS array - for cmd in hardcoded_commands { - assert!( - available_commands.contains(cmd), - "Command '{}' is used in select_command but not defined in COMMANDS array", - cmd - ); - - // This should assert that all the commands we assert are present in the match statement of - // select_command() - assert!( - CommandType::from_str(cmd).is_some(), - "Command '{}' cannot be parsed into a CommandType", - cmd - ); - } - } -} diff --git a/crates/chat-cli/src/cli/chat/token_counter.rs b/crates/chat-cli/src/cli/chat/token_counter.rs deleted file mode 100644 index 2c7b61f3b..000000000 --- a/crates/chat-cli/src/cli/chat/token_counter.rs +++ /dev/null @@ -1,251 +0,0 @@ -use std::ops::Deref; - -use super::message::{ - AssistantMessage, - ToolUseResult, - ToolUseResultBlock, - UserMessage, - UserMessageContent, -}; -use crate::cli::chat::conversation::{ - BackendConversationState, - ConversationSize, -}; - -#[derive(Debug, Clone, Copy)] -pub struct CharCount(usize); - -impl CharCount { - pub fn value(&self) -> usize { - self.0 - } -} - -impl Deref for CharCount { - type Target = usize; - - fn deref(&self) -> &Self::Target { - &self.0 - } -} - -impl From for CharCount { - fn from(value: usize) -> Self { - Self(value) - } -} - -impl std::ops::Add for CharCount { - type Output = CharCount; - - fn add(self, rhs: Self) -> Self::Output { - Self(self.value() + rhs.value()) - } -} - -#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)] -pub struct TokenCount(usize); - -impl TokenCount { - pub fn value(&self) -> usize { - self.0 - } -} - -impl Deref for TokenCount { - type Target = usize; - - fn deref(&self) -> &Self::Target { - &self.0 - } -} - -impl From for TokenCount { - fn from(value: CharCount) -> Self { - Self(TokenCounter::count_tokens_char_count(value.value())) - } -} - -impl std::fmt::Display for TokenCount { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "{}", self.0) - } -} - -pub struct TokenCounter; - -impl TokenCounter { - pub const TOKEN_TO_CHAR_RATIO: usize = 4; - - /// Estimates the number of tokens in the input content. - /// Currently uses a simple heuristic: content length / TOKEN_TO_CHAR_RATIO - /// - /// Rounds up to the nearest multiple of 10 to avoid giving users a false sense of precision. - pub fn count_tokens(content: &str) -> usize { - Self::count_tokens_char_count(content.len()) - } - - fn count_tokens_char_count(count: usize) -> usize { - (count / Self::TOKEN_TO_CHAR_RATIO + 5) / 10 * 10 - } - - pub const fn token_to_chars(token: usize) -> usize { - token * Self::TOKEN_TO_CHAR_RATIO - } -} - -/// A trait for types that represent some number of characters (aka bytes). For use in calculating -/// context window size utilization. -pub trait CharCounter { - /// Returns the number of characters contained within this type. - /// - /// One "character" is essentially the same as one "byte" - fn char_count(&self) -> CharCount; -} - -impl CharCounter for BackendConversationState<'_> { - fn char_count(&self) -> CharCount { - self.calculate_conversation_size().char_count() - } -} - -impl CharCounter for ConversationSize { - fn char_count(&self) -> CharCount { - self.user_messages + self.assistant_messages + self.context_messages - } -} - -impl CharCounter for UserMessage { - fn char_count(&self) -> CharCount { - let mut total_chars = 0; - total_chars += self.additional_context().len(); - match self.content() { - UserMessageContent::Prompt { prompt } => { - total_chars += prompt.len(); - }, - UserMessageContent::CancelledToolUses { - prompt, - tool_use_results, - } => { - total_chars += prompt.as_ref().map_or(0, String::len); - total_chars += tool_use_results.as_slice().char_count().0; - }, - UserMessageContent::ToolUseResults { tool_use_results } => { - total_chars += tool_use_results.as_slice().char_count().0; - }, - } - total_chars.into() - } -} - -impl CharCounter for AssistantMessage { - fn char_count(&self) -> CharCount { - let mut total_chars = 0; - total_chars += self.content().len(); - if let Some(tool_uses) = self.tool_uses() { - total_chars += tool_uses - .iter() - .map(|v| calculate_value_char_count(&v.args)) - .reduce(|acc, e| acc + e) - .unwrap_or_default(); - } - total_chars.into() - } -} - -impl CharCounter for &[ToolUseResult] { - fn char_count(&self) -> CharCount { - self.iter() - .flat_map(|v| &v.content) - .fold(0, |acc, v| { - acc + match v { - ToolUseResultBlock::Json(v) => calculate_value_char_count(v), - ToolUseResultBlock::Text(s) => s.len(), - } - }) - .into() - } -} - -fn calculate_value_char_count(document: &serde_json::Value) -> usize { - match document { - serde_json::Value::Null => 1, - serde_json::Value::Bool(_) => 1, - serde_json::Value::Number(_) => 1, - serde_json::Value::String(s) => s.len(), - serde_json::Value::Array(vec) => vec.iter().fold(0, |acc, v| acc + calculate_value_char_count(v)), - serde_json::Value::Object(map) => map.values().fold(0, |acc, v| acc + calculate_value_char_count(v)), - } -} - -#[cfg(test)] -mod tests { - - use super::*; - - #[test] - fn test_token_count() { - let text = "This is a test sentence."; - let count = TokenCounter::count_tokens(text); - assert_eq!(count, (text.len() / 3 + 5) / 10 * 10); - } - - #[test] - fn test_calculate_value_char_count() { - // Test simple types - assert_eq!( - calculate_value_char_count(&serde_json::Value::String("hello".to_string())), - 5 - ); - assert_eq!( - calculate_value_char_count(&serde_json::Value::Number(serde_json::Number::from(123))), - 1 - ); - assert_eq!(calculate_value_char_count(&serde_json::Value::Bool(true)), 1); - assert_eq!(calculate_value_char_count(&serde_json::Value::Null), 1); - - // Test array - let array = serde_json::Value::Array(vec![ - serde_json::Value::String("test".to_string()), - serde_json::Value::Number(serde_json::Number::from(42)), - serde_json::Value::Bool(false), - ]); - assert_eq!(calculate_value_char_count(&array), 6); // "test" (4) + Number (1) + Bool (1) - - // Test object - let mut obj = serde_json::Map::new(); - obj.insert("key1".to_string(), serde_json::Value::String("value1".to_string())); - obj.insert( - "key2".to_string(), - serde_json::Value::Number(serde_json::Number::from(99)), - ); - let object = serde_json::Value::Object(obj); - assert_eq!(calculate_value_char_count(&object), 7); // "value1" (6) + Number (1) - - // Test nested structure - let mut nested_obj = serde_json::Map::new(); - let mut inner_obj = serde_json::Map::new(); - inner_obj.insert( - "inner_key".to_string(), - serde_json::Value::String("inner_value".to_string()), - ); - nested_obj.insert("outer_key".to_string(), serde_json::Value::Object(inner_obj)); - nested_obj.insert( - "array_key".to_string(), - serde_json::Value::Array(vec![ - serde_json::Value::String("item1".to_string()), - serde_json::Value::String("item2".to_string()), - ]), - ); - - let complex = serde_json::Value::Object(nested_obj); - assert_eq!(calculate_value_char_count(&complex), 21); // "inner_value" (11) + "item1" (5) + "item2" (5) - - // Test empty structures - assert_eq!(calculate_value_char_count(&serde_json::Value::Array(vec![])), 0); - assert_eq!( - calculate_value_char_count(&serde_json::Value::Object(serde_json::Map::new())), - 0 - ); - } -} diff --git a/crates/chat-cli/src/cli/chat/tool_manager.rs b/crates/chat-cli/src/cli/chat/tool_manager.rs deleted file mode 100644 index fb59f3b44..000000000 --- a/crates/chat-cli/src/cli/chat/tool_manager.rs +++ /dev/null @@ -1,1575 +0,0 @@ -use std::collections::{ - HashMap, - HashSet, -}; -use std::future::Future; -use std::hash::{ - DefaultHasher, - Hasher, -}; -use std::io::{ - BufWriter, - Write, -}; -use std::path::{ - Path, - PathBuf, -}; -use std::pin::Pin; -use std::sync::atomic::{ - AtomicBool, - Ordering, -}; -use std::sync::{ - Arc, - RwLock as SyncRwLock, -}; -use std::time::{ - Duration, - Instant, -}; - -use convert_case::Casing; -use crossterm::{ - cursor, - execute, - queue, - style, - terminal, -}; -use futures::{ - StreamExt, - future, - stream, -}; -use regex::Regex; -use serde::{ - Deserialize, - Serialize, -}; -use tokio::signal::ctrl_c; -use tokio::sync::{ - Mutex, - Notify, - RwLock, -}; -use tracing::{ - error, - warn, -}; - -use crate::api_client::model::{ - ToolResult, - ToolResultContentBlock, - ToolResultStatus, -}; -use crate::cli::chat::cli::prompts::GetPromptError; -use crate::cli::chat::message::AssistantToolUse; -use crate::cli::chat::server_messenger::{ - ServerMessengerBuilder, - UpdateEventMessage, -}; -use crate::cli::chat::tools::custom_tool::{ - CustomTool, - CustomToolClient, - CustomToolConfig, -}; -use crate::cli::chat::tools::execute::ExecuteCommand; -use crate::cli::chat::tools::fs_read::FsRead; -use crate::cli::chat::tools::fs_write::FsWrite; -use crate::cli::chat::tools::gh_issue::GhIssue; -use crate::cli::chat::tools::knowledge::Knowledge; -use crate::cli::chat::tools::thinking::Thinking; -use crate::cli::chat::tools::use_aws::UseAws; -use crate::cli::chat::tools::{ - Tool, - ToolOrigin, - ToolSpec, -}; -use crate::database::Database; -use crate::database::settings::Setting; -use crate::mcp_client::{ - JsonRpcResponse, - Messenger, - PromptGet, -}; -use crate::os::Os; -use crate::telemetry::TelemetryThread; -use crate::util::paths::PathResolver; - -const NAMESPACE_DELIMITER: &str = "___"; -// This applies for both mcp server and tool name since in the end the tool name as seen by the -// model is just {server_name}{NAMESPACE_DELIMITER}{tool_name} -const VALID_TOOL_NAME: &str = "^[a-zA-Z][a-zA-Z0-9_]*$"; -const SPINNER_CHARS: [char; 10] = ['⠋', '⠙', '⠹', '⠸', '⠼', '⠴', '⠦', '⠧', '⠇', '⠏']; - -pub fn workspace_mcp_config_path(os: &Os) -> eyre::Result { - Ok(PathResolver::new(os).workspace().mcp_config()?) -} - -pub fn global_mcp_config_path(os: &Os) -> eyre::Result { - Ok(PathResolver::new(os).global().mcp_config()?) -} - -/// Messages used for communication between the tool initialization thread and the loading -/// display thread. These messages control the visual loading indicators shown to -/// the user during tool initialization. -enum LoadingMsg { - /// Indicates a tool has finished initializing successfully and should be removed from - /// the loading display. The String parameter is the name of the tool that - /// completed initialization. - Done { name: String, time: String }, - /// Represents an error that occurred during tool initialization. - /// Contains the name of the server that failed to initialize and the error message. - Error { - name: String, - msg: eyre::Report, - time: String, - }, - /// Represents a warning that occurred during tool initialization. - /// Contains the name of the server that generated the warning and the warning message. - Warn { - name: String, - msg: eyre::Report, - time: String, - }, - /// Signals that the loading display thread should terminate. - /// This is sent when all tool initialization is complete or when the application is shutting - /// down. - Terminate { still_loading: Vec }, -} - -/// Used to denote the loading outcome associated with a server. -/// This is mainly used in the non-interactive mode to determine if there is any fatal errors to -/// surface (since we would only want to surface fatal errors in non-interactive mode). -#[derive(Clone, Debug)] -pub enum LoadingRecord { - Success(String), - Warn(String), - Err(String), -} - -// This is to mirror claude's config set up -#[derive(Clone, Serialize, Deserialize, Debug, Default)] -#[serde(rename_all = "camelCase")] -pub struct McpServerConfig { - pub mcp_servers: HashMap, -} - -impl McpServerConfig { - pub async fn load_config(stderr: &mut impl Write) -> eyre::Result { - let os = Os::new().await?; - let resolver = PathResolver::new(&os); - let workspace_path = resolver.workspace().mcp_config()?; - let global_path = resolver.global().mcp_config()?; - - let global_buf = tokio::fs::read(global_path).await.ok(); - let local_buf = tokio::fs::read(workspace_path).await.ok(); - let conf = match (global_buf, local_buf) { - (Some(global_buf), Some(local_buf)) => { - let mut global_conf = Self::from_slice(&global_buf, stderr, "global")?; - let local_conf = Self::from_slice(&local_buf, stderr, "local")?; - for (server_name, config) in local_conf.mcp_servers { - if global_conf.mcp_servers.insert(server_name.clone(), config).is_some() { - queue!( - stderr, - style::SetForegroundColor(style::Color::Yellow), - style::Print("WARNING: "), - style::ResetColor, - style::Print("MCP config conflict for "), - style::SetForegroundColor(style::Color::Green), - style::Print(server_name), - style::ResetColor, - style::Print(". Using workspace version.\n") - )?; - } - } - global_conf - }, - (None, Some(local_buf)) => Self::from_slice(&local_buf, stderr, "local")?, - (Some(global_buf), None) => Self::from_slice(&global_buf, stderr, "global")?, - _ => Default::default(), - }; - - stderr.flush()?; - Ok(conf) - } - - pub async fn load_from_file(os: &Os, path: impl AsRef) -> eyre::Result { - let contents = os.fs.read_to_string(path.as_ref()).await?; - Ok(serde_json::from_str(&contents)?) - } - - pub async fn save_to_file(&self, os: &Os, path: impl AsRef) -> eyre::Result<()> { - let json = serde_json::to_string_pretty(self)?; - os.fs.write(path.as_ref(), json).await?; - Ok(()) - } - - fn from_slice(slice: &[u8], stderr: &mut impl Write, location: &str) -> eyre::Result { - match serde_json::from_slice::(slice) { - Ok(config) => Ok(config), - Err(e) => { - queue!( - stderr, - style::SetForegroundColor(style::Color::Yellow), - style::Print("WARNING: "), - style::ResetColor, - style::Print(format!("Error reading {location} mcp config: {e}\n")), - style::Print("Please check to make sure config is correct. Discarding.\n"), - )?; - Ok(McpServerConfig::default()) - }, - } - } -} - -#[derive(Default)] -pub struct ToolManagerBuilder { - mcp_server_config: Option, - prompt_list_sender: Option>>, - prompt_list_receiver: Option>>, - conversation_id: Option, -} - -impl ToolManagerBuilder { - pub fn mcp_server_config(mut self, config: McpServerConfig) -> Self { - self.mcp_server_config.replace(config); - self - } - - pub fn prompt_list_sender(mut self, sender: std::sync::mpsc::Sender>) -> Self { - self.prompt_list_sender.replace(sender); - self - } - - pub fn prompt_list_receiver(mut self, receiver: std::sync::mpsc::Receiver>) -> Self { - self.prompt_list_receiver.replace(receiver); - self - } - - pub fn conversation_id(mut self, conversation_id: &str) -> Self { - self.conversation_id.replace(conversation_id.to_string()); - self - } - - pub async fn build( - mut self, - os: &mut Os, - mut output: Box, - interactive: bool, - ) -> eyre::Result { - let McpServerConfig { mcp_servers } = self.mcp_server_config.ok_or(eyre::eyre!("Missing mcp server config"))?; - debug_assert!(self.conversation_id.is_some()); - let conversation_id = self.conversation_id.ok_or(eyre::eyre!("Missing conversation id"))?; - let regex = regex::Regex::new(VALID_TOOL_NAME)?; - let mut hasher = DefaultHasher::new(); - - // Separate enabled and disabled servers - let (enabled_servers, disabled_servers): (Vec<_>, Vec<_>) = mcp_servers - .into_iter() - .partition(|(_, server_config)| !server_config.disabled); - - // Prepare disabled servers for display - let disabled_servers_display: Vec = disabled_servers - .iter() - .map(|(server_name, _)| { - let snaked_cased_name = server_name.to_case(convert_case::Case::Snake); - sanitize_name(snaked_cased_name, ®ex, &mut hasher) - }) - .collect(); - - let pre_initialized = enabled_servers - .into_iter() - .map(|(server_name, server_config)| { - let snaked_cased_name = server_name.to_case(convert_case::Case::Snake); - let sanitized_server_name = sanitize_name(snaked_cased_name, ®ex, &mut hasher); - let custom_tool_client = CustomToolClient::from_config(sanitized_server_name.clone(), server_config); - (sanitized_server_name, custom_tool_client) - }) - .collect::>(); - - let mut loading_servers = HashMap::::new(); - for (server_name, _) in &pre_initialized { - let init_time = std::time::Instant::now(); - loading_servers.insert(server_name.clone(), init_time); - } - let total = loading_servers.len(); - - // Spawn a task for displaying the mcp loading statuses. - // This is only necessary when we are in interactive mode AND there are servers to load. - // Otherwise we do not need to be spawning this. - let (_loading_display_task, loading_status_sender) = if interactive - && (total > 0 || !disabled_servers_display.is_empty()) - { - let (tx, mut rx) = tokio::sync::mpsc::channel::(50); - let disabled_servers_display_clone = disabled_servers_display.clone(); - ( - Some(tokio::task::spawn(async move { - let mut spinner_logo_idx: usize = 0; - let mut complete: usize = 0; - let mut failed: usize = 0; - - // Show disabled servers immediately - for server_name in &disabled_servers_display_clone { - queue_disabled_message(server_name, &mut output)?; - } - - if total > 0 { - queue_init_message(spinner_logo_idx, complete, failed, total, &mut output)?; - } - - loop { - match tokio::time::timeout(Duration::from_millis(50), rx.recv()).await { - Ok(Some(recv_result)) => match recv_result { - LoadingMsg::Done { name, time } => { - complete += 1; - execute!( - output, - cursor::MoveToColumn(0), - cursor::MoveUp(1), - terminal::Clear(terminal::ClearType::CurrentLine), - )?; - queue_success_message(&name, &time, &mut output)?; - queue_init_message(spinner_logo_idx, complete, failed, total, &mut output)?; - }, - LoadingMsg::Error { name, msg, time } => { - failed += 1; - execute!( - output, - cursor::MoveToColumn(0), - cursor::MoveUp(1), - terminal::Clear(terminal::ClearType::CurrentLine), - )?; - queue_failure_message(&name, &msg, time.as_str(), &mut output)?; - queue_init_message(spinner_logo_idx, complete, failed, total, &mut output)?; - }, - LoadingMsg::Warn { name, msg, time } => { - complete += 1; - execute!( - output, - cursor::MoveToColumn(0), - cursor::MoveUp(1), - terminal::Clear(terminal::ClearType::CurrentLine), - )?; - let msg = eyre::eyre!(msg.to_string()); - queue_warn_message(&name, &msg, time.as_str(), &mut output)?; - queue_init_message(spinner_logo_idx, complete, failed, total, &mut output)?; - }, - LoadingMsg::Terminate { still_loading } => { - if !still_loading.is_empty() && total > 0 { - execute!( - output, - cursor::MoveToColumn(0), - cursor::MoveUp(1), - terminal::Clear(terminal::ClearType::CurrentLine), - )?; - let msg = still_loading.iter().fold(String::new(), |mut acc, server_name| { - acc.push_str(format!("\n - {server_name}").as_str()); - acc - }); - let msg = eyre::eyre!(msg); - queue_incomplete_load_message(complete, total, &msg, &mut output)?; - } else if total > 0 { - // Clear the loading line if we have enabled servers - execute!( - output, - cursor::MoveToColumn(0), - cursor::MoveUp(1), - terminal::Clear(terminal::ClearType::CurrentLine), - )?; - } - execute!(output, style::Print("\n"),)?; - break; - }, - }, - Err(_e) => { - spinner_logo_idx = (spinner_logo_idx + 1) % SPINNER_CHARS.len(); - execute!( - output, - cursor::SavePosition, - cursor::MoveToColumn(0), - cursor::MoveUp(1), - style::Print(SPINNER_CHARS[spinner_logo_idx]), - cursor::RestorePosition - )?; - }, - _ => break, - } - output.flush()?; - } - Ok::<_, eyre::Report>(()) - })), - Some(tx), - ) - } else { - (None, None) - }; - let mut clients = HashMap::>::new(); - let mut loading_status_sender_clone = loading_status_sender.clone(); - let conv_id_clone = conversation_id.clone(); - let regex = Regex::new(VALID_TOOL_NAME)?; - let new_tool_specs = Arc::new(Mutex::new(HashMap::new())); - let new_tool_specs_clone = new_tool_specs.clone(); - let has_new_stuff = Arc::new(AtomicBool::new(false)); - let has_new_stuff_clone = has_new_stuff.clone(); - let pending = Arc::new(RwLock::new(HashSet::::new())); - let pending_clone = pending.clone(); - let (mut msg_rx, messenger_builder) = ServerMessengerBuilder::new(20); - let telemetry_clone = os.telemetry.clone(); - let database_clone = os.database.clone(); - let notify = Arc::new(Notify::new()); - let notify_weak = Arc::downgrade(¬ify); - let load_record = Arc::new(Mutex::new(HashMap::>::new())); - let load_record_clone = load_record.clone(); - tokio::spawn(async move { - let mut record_temp_buf = Vec::::new(); - let mut initialized = HashSet::::new(); - while let Some(msg) = msg_rx.recv().await { - record_temp_buf.clear(); - // For now we will treat every list result as if they contain the - // complete set of tools. This is not necessarily true in the future when - // request method on the mcp client no longer buffers all the pages from - // list calls. - match msg { - UpdateEventMessage::ToolsListResult { server_name, result } => { - let time_taken = loading_servers - .remove(&server_name) - .map_or("0.0".to_owned(), |init_time| { - let time_taken = (std::time::Instant::now() - init_time).as_secs_f64().abs(); - format!("{:.2}", time_taken) - }); - pending_clone.write().await.remove(&server_name); - match result { - Ok(result) => { - let mut specs = result - .tools - .into_iter() - .filter_map(|v| serde_json::from_value::(v).ok()) - .collect::>(); - let mut sanitized_mapping = HashMap::::new(); - let process_result = process_tool_specs( - conv_id_clone.as_str(), - &server_name, - &mut specs, - &mut sanitized_mapping, - ®ex, - &telemetry_clone, - &database_clone, - ) - .await; - if let Some(sender) = &loading_status_sender_clone { - // Anomalies here are not considered fatal, thus we shall give - // warnings. - let msg = match process_result { - Ok(_) => LoadingMsg::Done { - name: server_name.clone(), - time: time_taken.clone(), - }, - Err(ref e) => LoadingMsg::Warn { - name: server_name.clone(), - msg: eyre::eyre!(e.to_string()), - time: time_taken.clone(), - }, - }; - if let Err(e) = sender.send(msg).await { - warn!( - "Error sending update message to display task: {:?}\nAssume display task has completed", - e - ); - loading_status_sender_clone.take(); - } - } - new_tool_specs_clone - .lock() - .await - .insert(server_name.clone(), (sanitized_mapping, specs)); - has_new_stuff_clone.store(true, Ordering::Release); - // Maintain a record of the server load: - let mut buf_writer = BufWriter::new(&mut record_temp_buf); - if let Err(e) = &process_result { - let _ = queue_warn_message( - server_name.as_str(), - e, - time_taken.as_str(), - &mut buf_writer, - ); - } else { - let _ = queue_success_message( - server_name.as_str(), - time_taken.as_str(), - &mut buf_writer, - ); - } - let _ = buf_writer.flush(); - drop(buf_writer); - let record = String::from_utf8_lossy(&record_temp_buf).to_string(); - let record = if process_result.is_err() { - LoadingRecord::Warn(record) - } else { - LoadingRecord::Success(record) - }; - load_record_clone - .lock() - .await - .entry(server_name.clone()) - .and_modify(|load_record| { - load_record.push(record.clone()); - }) - .or_insert(vec![record]); - }, - Err(e) => { - // Log error to chat Log - error!("Error loading server {server_name}: {:?}", e); - // Maintain a record of the server load: - let mut buf_writer = BufWriter::new(&mut record_temp_buf); - let _ = queue_failure_message(server_name.as_str(), &e, &time_taken, &mut buf_writer); - let _ = buf_writer.flush(); - drop(buf_writer); - let record = String::from_utf8_lossy(&record_temp_buf).to_string(); - let record = LoadingRecord::Err(record); - load_record_clone - .lock() - .await - .entry(server_name.clone()) - .and_modify(|load_record| { - load_record.push(record.clone()); - }) - .or_insert(vec![record]); - // Errors surfaced at this point (i.e. before [process_tool_specs] - // is called) are fatals and should be considered errors - if let Some(sender) = &loading_status_sender_clone { - let msg = LoadingMsg::Error { - name: server_name.clone(), - msg: e, - time: time_taken, - }; - if let Err(e) = sender.send(msg).await { - warn!( - "Error sending update message to display task: {:?}\nAssume display task has completed", - e - ); - loading_status_sender_clone.take(); - } - } - }, - } - if let Some(notify) = notify_weak.upgrade() { - initialized.insert(server_name); - if initialized.len() >= total { - notify.notify_one(); - } - } - }, - UpdateEventMessage::PromptsListResult { - server_name: _, - result: _, - } => {}, - UpdateEventMessage::ResourcesListResult { - server_name: _, - result: _, - } => {}, - UpdateEventMessage::ResourceTemplatesListResult { - server_name: _, - result: _, - } => {}, - UpdateEventMessage::InitStart { server_name } => { - pending_clone.write().await.insert(server_name.clone()); - loading_servers.insert(server_name, std::time::Instant::now()); - }, - } - } - }); - for (mut name, init_res) in pre_initialized { - let messenger = messenger_builder.build_with_name(name.clone()); - match init_res { - Ok(mut client) => { - client.assign_messenger(Box::new(messenger)); - let mut client = Arc::new(client); - while let Some(collided_client) = clients.insert(name.clone(), client) { - // to avoid server name collision we are going to circumvent this by - // appending the name with 1 - name.push('1'); - client = collided_client; - } - }, - Err(e) => { - error!("Error initializing mcp client for server {}: {:?}", name, &e); - os.telemetry - .send_mcp_server_init(&os.database, conversation_id.clone(), Some(e.to_string()), 0) - .await - .ok(); - let _ = messenger.send_tools_list_result(Err(e)).await; - }, - } - } - - // Set up task to handle prompt requests - let sender = self.prompt_list_sender.take(); - let receiver = self.prompt_list_receiver.take(); - let prompts = Arc::new(SyncRwLock::new(HashMap::default())); - // TODO: accommodate hot reload of mcp servers - if let (Some(sender), Some(receiver)) = (sender, receiver) { - let clients = clients.iter().fold(HashMap::new(), |mut acc, (n, c)| { - acc.insert(n.clone(), Arc::downgrade(c)); - acc - }); - let prompts_clone = prompts.clone(); - tokio::task::spawn_blocking(move || { - let receiver = Arc::new(std::sync::Mutex::new(receiver)); - loop { - let search_word = receiver.lock().map_err(|e| eyre::eyre!("{:?}", e))?.recv()?; - if clients - .values() - .any(|client| client.upgrade().is_some_and(|c| c.is_prompts_out_of_date())) - { - let mut prompts_wl = prompts_clone.write().map_err(|e| { - eyre::eyre!( - "Error retrieving write lock on prompts for tab complete {}", - e.to_string() - ) - })?; - *prompts_wl = clients.iter().fold( - HashMap::>::new(), - |mut acc, (server_name, client)| { - let Some(client) = client.upgrade() else { - return acc; - }; - let prompt_gets = client.list_prompt_gets(); - let Ok(prompt_gets) = prompt_gets.read() else { - tracing::error!("Error retrieving read lock for prompt gets for tab complete"); - return acc; - }; - for (prompt_name, prompt_get) in prompt_gets.iter() { - acc.entry(prompt_name.clone()) - .and_modify(|bundles| { - bundles.push(PromptBundle { - server_name: server_name.to_owned(), - prompt_get: prompt_get.clone(), - }); - }) - .or_insert(vec![PromptBundle { - server_name: server_name.to_owned(), - prompt_get: prompt_get.clone(), - }]); - } - client.prompts_updated(); - acc - }, - ); - } - let prompts_rl = prompts_clone.read().map_err(|e| { - eyre::eyre!( - "Error retrieving read lock on prompts for tab complete {}", - e.to_string() - ) - })?; - let filtered_prompts = prompts_rl - .iter() - .flat_map(|(prompt_name, bundles)| { - if bundles.len() > 1 { - bundles - .iter() - .map(|b| format!("{}/{}", b.server_name, prompt_name)) - .collect() - } else { - vec![prompt_name.to_owned()] - } - }) - .filter(|n| { - if let Some(p) = &search_word { - n.contains(p) - } else { - true - } - }) - .collect::>(); - if let Err(e) = sender.send(filtered_prompts) { - error!("Error sending prompts to chat helper: {:?}", e); - } - } - #[allow(unreachable_code)] - Ok::<(), eyre::Report>(()) - }); - } - - Ok(ToolManager { - conversation_id, - clients, - prompts, - pending_clients: pending, - notify: Some(notify), - loading_status_sender, - new_tool_specs, - has_new_stuff, - is_interactive: interactive, - mcp_load_record: load_record, - disabled_servers: disabled_servers_display, - ..Default::default() - }) - } -} - -#[derive(Clone, Debug)] -/// A collection of information that is used for the following purposes: -/// - Checking if prompt info cached is out of date -/// - Retrieve new prompt info -pub struct PromptBundle { - /// The server name from which the prompt is offered / exposed - pub server_name: String, - /// The prompt get (info with which a prompt is retrieved) cached - pub prompt_get: PromptGet, -} - -/// Categorizes different types of tool name validation failures: -/// - `TooLong`: The tool name exceeds the maximum allowed length -/// - `IllegalChar`: The tool name contains characters that are not allowed -/// - `EmptyDescription`: The tool description is empty or missing -#[allow(dead_code)] -enum OutOfSpecName { - TooLong(String), - IllegalChar(String), - EmptyDescription(String), -} - -type NewToolSpecs = Arc, Vec)>>>; - -#[derive(Default, Debug)] -/// Manages the lifecycle and interactions with tools from various sources, including MCP servers. -/// This struct is responsible for initializing tools, handling tool requests, and maintaining -/// a cache of available prompts from connected servers. -pub struct ToolManager { - /// Unique identifier for the current conversation. - /// This ID is used to track and associate tools with a specific chat session. - pub conversation_id: String, - - /// Map of server names to their corresponding client instances. - /// These clients are used to communicate with MCP servers. - pub clients: HashMap>, - - /// A list of client names that are still in the process of being initialized - pub pending_clients: Arc>>, - - /// Flag indicating whether new tool specifications have been added since the last update. - /// When set to true, it signals that the tool manager needs to refresh its internal state - /// to incorporate newly available tools from MCP servers. - pub has_new_stuff: Arc, - - /// Storage for newly discovered tool specifications from MCP servers that haven't yet been - /// integrated into the main tool registry. This field holds a thread-safe reference to a map - /// of server names to their tool specifications and name mappings, allowing concurrent updates - /// from server initialization processes. - new_tool_specs: NewToolSpecs, - - /// Cache for prompts collected from different servers. - /// Key: prompt name - /// Value: a list of PromptBundle that has a prompt of this name. - /// This cache helps resolve prompt requests efficiently and handles - /// cases where multiple servers offer prompts with the same name. - pub prompts: Arc>>>, - - /// A notifier to understand if the initial loading has completed. - /// This is only used for initial loading and is discarded after. - notify: Option>, - - /// Channel sender for communicating with the loading display thread. - /// Used to send status updates about tool initialization progress. - loading_status_sender: Option>, - - /// Mapping from sanitized tool names to original tool names. - /// This is used to handle tool name transformations that may occur during initialization - /// to ensure tool names comply with naming requirements. - pub tn_map: HashMap, - - /// A cache of tool's input schema for all of the available tools. - /// This is mainly used to show the user what the tools look like from the perspective of the - /// model. - pub schema: HashMap, - - is_interactive: bool, - - /// This serves as a record of the loading of mcp servers. - /// The key of which is the server name as they are recognized by the current instance of chat - /// (which may be different than how it is written in the config, depending of the presence of - /// invalid characters). - /// The value is the load message (i.e. load time, warnings, and errors) - pub mcp_load_record: Arc>>>, - - /// List of disabled MCP server names for display purposes - disabled_servers: Vec, -} - -impl Clone for ToolManager { - fn clone(&self) -> Self { - Self { - conversation_id: self.conversation_id.clone(), - clients: self.clients.clone(), - has_new_stuff: self.has_new_stuff.clone(), - new_tool_specs: self.new_tool_specs.clone(), - prompts: self.prompts.clone(), - tn_map: self.tn_map.clone(), - schema: self.schema.clone(), - is_interactive: self.is_interactive, - mcp_load_record: self.mcp_load_record.clone(), - disabled_servers: self.disabled_servers.clone(), - ..Default::default() - } - } -} - -impl ToolManager { - pub async fn load_tools( - &mut self, - os: &mut Os, - stderr: &mut impl Write, - ) -> eyre::Result> { - let tx = self.loading_status_sender.take(); - let notify = self.notify.take(); - self.schema = { - let mut tool_specs = - serde_json::from_str::>(include_str!("tools/tool_index.json"))?; - if !crate::cli::chat::tools::thinking::Thinking::is_enabled(os) { - tool_specs.remove("thinking"); - } - if !crate::cli::chat::tools::knowledge::Knowledge::is_enabled(os) { - tool_specs.remove("knowledge"); - } - - #[cfg(windows)] - { - use serde_json::json; - - use crate::cli::chat::tools::InputSchema; - - tool_specs.remove("execute_bash"); - - tool_specs.insert("execute_cmd".to_string(), ToolSpec { - name: "execute_cmd".to_string(), - description: "Execute the specified Windows command.".to_string(), - input_schema: InputSchema(json!({ - "type": "object", - "properties": { - "command": { - "type": "string", - "description": "Windows command to execute" - }, - "summary": { - "type": "string", - "description": "A brief explanation of what the command does" - } - }, - "required": ["command"]})), - tool_origin: ToolOrigin::Native, - }); - } - - tool_specs - }; - let load_tools = self - .clients - .values() - .map(|c| { - let clone = Arc::clone(c); - async move { clone.init().await } - }) - .collect::>(); - let initial_poll = stream::iter(load_tools) - .map(|async_closure| tokio::spawn(async_closure)) - .buffer_unordered(20); - tokio::spawn(async move { - initial_poll.collect::>().await; - }); - // We need to cast it to erase the type otherwise the compiler will default to static - // dispatch, which would result in an error of inconsistent match arm return type. - let timeout_fut: Pin>> = if self.clients.is_empty() { - // If there is no server loaded, we want to resolve immediately - Box::pin(future::ready(())) - } else if self.is_interactive { - let init_timeout = os - .database - .settings - .get_int(Setting::McpInitTimeout) - .map_or(5000_u64, |s| s as u64); - Box::pin(tokio::time::sleep(std::time::Duration::from_millis(init_timeout))) - } else { - // if it is non-interactive we will want to use the "mcp.noInteractiveTimeout" - let init_timeout = os - .database - .settings - .get_int(Setting::McpNoInteractiveTimeout) - .map_or(30_000_u64, |s| s as u64); - Box::pin(tokio::time::sleep(std::time::Duration::from_millis(init_timeout))) - }; - let server_loading_fut: Pin>> = if let Some(notify) = notify { - Box::pin(async move { notify.notified().await }) - } else { - Box::pin(future::ready(())) - }; - tokio::select! { - _ = timeout_fut => { - if let Some(tx) = tx { - let still_loading = self.pending_clients.read().await.iter().cloned().collect::>(); - let _ = tx.send(LoadingMsg::Terminate { still_loading }).await; - } - if !self.clients.is_empty() && !self.is_interactive { - let _ = queue!( - stderr, - style::Print( - "Not all mcp servers loaded. Configure non-interactive timeout with q settings mcp.noInteractiveTimeout" - ), - style::Print("\n------\n") - ); - } - }, - _ = server_loading_fut => { - if let Some(tx) = tx { - let still_loading = self.pending_clients.read().await.iter().cloned().collect::>(); - let _ = tx.send(LoadingMsg::Terminate { still_loading }).await; - } - } - _ = ctrl_c() => { - if self.is_interactive { - if let Some(tx) = tx { - let still_loading = self.pending_clients.read().await.iter().cloned().collect::>(); - let _ = tx.send(LoadingMsg::Terminate { still_loading }).await; - } - } else { - return Err(eyre::eyre!("User interrupted mcp server loading in non-interactive mode. Ending.")); - } - } - } - if !self.is_interactive - && self - .mcp_load_record - .lock() - .await - .iter() - .any(|(_, records)| records.iter().any(|record| matches!(record, LoadingRecord::Err(_)))) - { - queue!( - stderr, - style::Print( - "One or more mcp server did not load correctly. See $TMPDIR/qlog/chat.log for more details." - ), - style::Print("\n------\n") - )?; - } - self.update().await; - Ok(self.schema.clone()) - } - - pub fn get_tool_from_tool_use(&self, value: AssistantToolUse) -> Result { - let map_err = |parse_error| ToolResult { - tool_use_id: value.id.clone(), - content: vec![ToolResultContentBlock::Text(format!( - "Failed to validate tool parameters: {parse_error}. The model has either suggested tool parameters which are incompatible with the existing tools, or has suggested one or more tool that does not exist in the list of known tools." - ))], - status: ToolResultStatus::Error, - }; - - Ok(match value.name.as_str() { - "fs_read" => Tool::FsRead(serde_json::from_value::(value.args).map_err(map_err)?), - "fs_write" => Tool::FsWrite(serde_json::from_value::(value.args).map_err(map_err)?), - #[cfg(windows)] - "execute_cmd" => { - Tool::ExecuteCommand(serde_json::from_value::(value.args).map_err(map_err)?) - }, - #[cfg(not(windows))] - "execute_bash" => { - Tool::ExecuteCommand(serde_json::from_value::(value.args).map_err(map_err)?) - }, - "use_aws" => Tool::UseAws(serde_json::from_value::(value.args).map_err(map_err)?), - "report_issue" => Tool::GhIssue(serde_json::from_value::(value.args).map_err(map_err)?), - "thinking" => Tool::Thinking(serde_json::from_value::(value.args).map_err(map_err)?), - "knowledge" => Tool::Knowledge(serde_json::from_value::(value.args).map_err(map_err)?), - // Note that this name is namespaced with server_name{DELIMITER}tool_name - name => { - // Note: tn_map also has tools that underwent no transformation. In otherwords, if - // it is a valid tool name, we should get a hit. - let name = match self.tn_map.get(name) { - Some(name) => Ok::<&str, ToolResult>(name.as_str()), - None => { - // There are three possibilities: - // - The tool name supplied is valid, it's just missing the server name - // prefix. - // - The tool name supplied is valid, it's missing the server name prefix - // and there are more than one possible tools that fit this description. - // - No server has a tool with this name. - let candidates = self.tn_map.keys().filter(|n| n.ends_with(name)).collect::>(); - #[allow(clippy::comparison_chain)] - if candidates.len() == 1 { - Ok(candidates.first().map(|s| s.as_str()).unwrap()) - } else if candidates.len() > 1 { - let mut content = candidates.iter().fold( - "There are multiple tools with given tool name: ".to_string(), - |mut acc, name| { - acc.push_str(name); - acc.push_str(", "); - acc - }, - ); - content.push_str("specify a tool with its full name."); - Err(ToolResult { - tool_use_id: value.id.clone(), - content: vec![ToolResultContentBlock::Text(content)], - status: ToolResultStatus::Error, - }) - } else { - Err(ToolResult { - tool_use_id: value.id.clone(), - content: vec![ToolResultContentBlock::Text(format!( - "The tool, \"{name}\" is supplied with incorrect name" - ))], - status: ToolResultStatus::Error, - }) - } - }, - }?; - let name = self.tn_map.get(name).map_or(name, String::as_str); - let (server_name, tool_name) = name.split_once(NAMESPACE_DELIMITER).ok_or(ToolResult { - tool_use_id: value.id.clone(), - content: vec![ToolResultContentBlock::Text(format!( - "The tool, \"{name}\" is supplied with incorrect name" - ))], - status: ToolResultStatus::Error, - })?; - let Some(client) = self.clients.get(server_name) else { - return Err(ToolResult { - tool_use_id: value.id, - content: vec![ToolResultContentBlock::Text(format!( - "The tool, \"{server_name}\" is not supported by the client" - ))], - status: ToolResultStatus::Error, - }); - }; - // The tool input schema has the shape of { type, properties }. - // The field "params" expected by MCP is { name, arguments }, where name is the - // name of the tool being invoked, - // https://spec.modelcontextprotocol.io/specification/2024-11-05/server/tools/#calling-tools. - // The field "arguments" is where ToolUse::args belong. - let mut params = serde_json::Map::::new(); - params.insert("name".to_owned(), serde_json::Value::String(tool_name.to_owned())); - params.insert("arguments".to_owned(), value.args); - let params = serde_json::Value::Object(params); - let custom_tool = CustomTool { - name: tool_name.to_owned(), - client: client.clone(), - method: "tools/call".to_owned(), - params: Some(params), - }; - Tool::Custom(custom_tool) - }, - }) - } - - /// Updates tool managers various states with new information - pub async fn update(&mut self) { - // A hashmap of - let mut tool_specs = HashMap::::new(); - let new_tools = { - let mut new_tool_specs = self.new_tool_specs.lock().await; - new_tool_specs.drain().fold(HashMap::new(), |mut acc, (k, v)| { - acc.insert(k, v); - acc - }) - }; - let mut updated_servers = HashSet::::new(); - for (server_name, (tool_name_map, specs)) in new_tools { - let target = format!("{server_name}{NAMESPACE_DELIMITER}"); - self.tn_map.retain(|k, _| !k.starts_with(&target)); - for (k, v) in tool_name_map { - self.tn_map.insert(k, v); - } - if let Some(spec) = specs.first() { - updated_servers.insert(spec.tool_origin.clone()); - } - for spec in specs { - tool_specs.insert(spec.name.clone(), spec); - } - } - // Caching the tool names for skim operations - for tool_name in tool_specs.keys() { - if !self.tn_map.contains_key(tool_name) { - self.tn_map.insert(tool_name.clone(), tool_name.clone()); - } - } - // Update schema - // As we are writing over the ensemble of tools in a given server, we will need to first - // remove everything that it has. - self.schema - .retain(|_tool_name, spec| !updated_servers.contains(&spec.tool_origin)); - self.schema.extend(tool_specs); - } - - #[allow(clippy::await_holding_lock)] - pub async fn get_prompt( - &self, - name: String, - arguments: Option>, - ) -> Result { - let (server_name, prompt_name) = match name.split_once('/') { - None => (None::, Some(name.clone())), - Some((server_name, prompt_name)) => (Some(server_name.to_string()), Some(prompt_name.to_string())), - }; - let prompt_name = prompt_name.ok_or(GetPromptError::MissingPromptName)?; - // We need to use a sync lock here because this lock is also used in a blocking thread, - // necessitated by the fact that said thread is also responsible for using a sync channel, - // which is itself necessitated by the fact that consumer of said channel is calling from a - // sync function - let mut prompts_wl = self - .prompts - .write() - .map_err(|e| GetPromptError::Synchronization(e.to_string()))?; - let mut maybe_bundles = prompts_wl.get(&prompt_name); - let mut has_retried = false; - 'blk: loop { - match (maybe_bundles, server_name.as_ref(), has_retried) { - // If we have more than one eligible clients but no server name specified - (Some(bundles), None, _) if bundles.len() > 1 => { - break 'blk Err(GetPromptError::AmbiguousPrompt(prompt_name.clone(), { - bundles.iter().fold("\n".to_string(), |mut acc, b| { - acc.push_str(&format!("- @{}/{}\n", b.server_name, prompt_name)); - acc - }) - })); - }, - // Normal case where we have enough info to proceed - // Note that if bundle exists, it should never be empty - (Some(bundles), sn, _) => { - let bundle = if bundles.len() > 1 { - let Some(server_name) = sn else { - maybe_bundles = None; - continue 'blk; - }; - let bundle = bundles.iter().find(|b| b.server_name == *server_name); - match bundle { - Some(bundle) => bundle, - None => { - maybe_bundles = None; - continue 'blk; - }, - } - } else { - bundles.first().ok_or(GetPromptError::MissingPromptInfo)? - }; - let server_name = bundle.server_name.clone(); - let client = self.clients.get(&server_name).ok_or(GetPromptError::MissingClient)?; - // Here we lazily update the out of date cache - if client.is_prompts_out_of_date() { - let prompt_gets = client.list_prompt_gets(); - let prompt_gets = prompt_gets - .read() - .map_err(|e| GetPromptError::Synchronization(e.to_string()))?; - for (prompt_name, prompt_get) in prompt_gets.iter() { - prompts_wl - .entry(prompt_name.clone()) - .and_modify(|bundles| { - let mut is_modified = false; - for bundle in &mut *bundles { - let mut updated_bundle = PromptBundle { - server_name: server_name.clone(), - prompt_get: prompt_get.clone(), - }; - if bundle.server_name == *server_name { - std::mem::swap(bundle, &mut updated_bundle); - is_modified = true; - break; - } - } - if !is_modified { - bundles.push(PromptBundle { - server_name: server_name.clone(), - prompt_get: prompt_get.clone(), - }); - } - }) - .or_insert(vec![PromptBundle { - server_name: server_name.clone(), - prompt_get: prompt_get.clone(), - }]); - } - client.prompts_updated(); - } - - let PromptBundle { prompt_get, .. } = prompts_wl - .get(&prompt_name) - .and_then(|bundles| bundles.iter().find(|b| b.server_name == server_name)) - .ok_or(GetPromptError::MissingPromptInfo)?; - - // Here we need to convert the positional arguments into key value pair - // The assignment order is assumed to be the order of args as they are - // presented in PromptGet::arguments - let args = if let (Some(schema), Some(value)) = (&prompt_get.arguments, &arguments) { - let params = schema.iter().zip(value.iter()).fold( - HashMap::::new(), - |mut acc, (prompt_get_arg, value)| { - acc.insert(prompt_get_arg.name.clone(), value.clone()); - acc - }, - ); - Some(serde_json::json!(params)) - } else { - None - }; - let params = { - let mut params = serde_json::Map::new(); - params.insert("name".to_string(), serde_json::Value::String(prompt_name)); - if let Some(args) = args { - params.insert("arguments".to_string(), args); - } - Some(serde_json::Value::Object(params)) - }; - let resp = client.request("prompts/get", params).await?; - break 'blk Ok(resp); - }, - // If we have no eligible clients this would mean one of the following: - // - The prompt does not exist, OR - // - This is the first time we have a query / our cache is out of date - // Both of which means we would have to requery - (None, _, false) => { - has_retried = true; - self.refresh_prompts(&mut prompts_wl)?; - maybe_bundles = prompts_wl.get(&prompt_name); - }, - (_, _, true) => { - break 'blk Err(GetPromptError::PromptNotFound(prompt_name)); - }, - } - } - } - - pub fn refresh_prompts(&self, prompts_wl: &mut HashMap>) -> Result<(), GetPromptError> { - *prompts_wl = self.clients.iter().fold( - HashMap::>::new(), - |mut acc, (server_name, client)| { - let prompt_gets = client.list_prompt_gets(); - let Ok(prompt_gets) = prompt_gets.read() else { - tracing::error!("Error encountered while retrieving read lock"); - return acc; - }; - for (prompt_name, prompt_get) in prompt_gets.iter() { - acc.entry(prompt_name.clone()) - .and_modify(|bundles| { - bundles.push(PromptBundle { - server_name: server_name.to_owned(), - prompt_get: prompt_get.clone(), - }); - }) - .or_insert(vec![PromptBundle { - server_name: server_name.to_owned(), - prompt_get: prompt_get.clone(), - }]); - } - acc - }, - ); - Ok(()) - } - - pub async fn pending_clients(&self) -> Vec { - self.pending_clients.read().await.iter().cloned().collect::>() - } -} - -#[inline] -async fn process_tool_specs( - conversation_id: &str, - server_name: &str, - specs: &mut Vec, - tn_map: &mut HashMap, - regex: &Regex, - telemetry: &TelemetryThread, - database: &Database, -) -> eyre::Result<()> { - // Each mcp server might have multiple tools. - // To avoid naming conflicts we are going to namespace it. - // This would also help us locate which mcp server to call the tool from. - let mut out_of_spec_tool_names = Vec::::new(); - let mut hasher = DefaultHasher::new(); - let number_of_tools = specs.len(); - // Sanitize tool names to ensure they comply with the naming requirements: - // 1. If the name already matches the regex pattern and doesn't contain the namespace delimiter, use - // it as is - // 2. Otherwise, remove invalid characters and handle special cases: - // - Remove namespace delimiters - // - Ensure the name starts with an alphabetic character - // - Generate a hash-based name if the sanitized result is empty - // This ensures all tool names are valid identifiers that can be safely used in the system - // If after all of the aforementioned modification the combined tool - // name we have exceeds a length of 64, we surface it as an error - for spec in specs.iter_mut() { - let sn = if !regex.is_match(&spec.name) { - let mut sn = sanitize_name(spec.name.clone(), regex, &mut hasher); - while tn_map.contains_key(&sn) { - sn.push('1'); - } - sn - } else { - spec.name.clone() - }; - let full_name = format!("{}{}{}", server_name, NAMESPACE_DELIMITER, sn); - if full_name.len() > 64 { - out_of_spec_tool_names.push(OutOfSpecName::TooLong(spec.name.clone())); - continue; - } else if spec.description.is_empty() { - out_of_spec_tool_names.push(OutOfSpecName::EmptyDescription(spec.name.clone())); - continue; - } - if sn != spec.name { - tn_map.insert( - full_name.clone(), - format!("{}{}{}", server_name, NAMESPACE_DELIMITER, spec.name), - ); - } - spec.name = full_name; - spec.tool_origin = ToolOrigin::McpServer(server_name.to_string()); - } - // Native origin is the default, and since this function never reads native tools, if we still - // have it, that would indicate a tool that should not be included. - specs.retain(|spec| !matches!(spec.tool_origin, ToolOrigin::Native)); - // Send server load success metric datum - let conversation_id = conversation_id.to_string(); - telemetry - .send_mcp_server_init(database, conversation_id, None, number_of_tools) - .await - .ok(); - // Tool name translation. This is beyond of the scope of what is - // considered a "server load". Reasoning being: - // - Failures here are not related to server load - // - There is not a whole lot we can do with this data - if !out_of_spec_tool_names.is_empty() { - Err(eyre::eyre!(out_of_spec_tool_names.iter().fold( - String::from( - "The following tools are out of spec. They will be excluded from the list of available tools:\n", - ), - |mut acc, name| { - let (tool_name, msg) = match name { - OutOfSpecName::TooLong(tool_name) => ( - tool_name.as_str(), - "tool name exceeds max length of 64 when combined with server name", - ), - OutOfSpecName::IllegalChar(tool_name) => ( - tool_name.as_str(), - "tool name must be compliant with ^[a-zA-Z][a-zA-Z0-9_]*$", - ), - OutOfSpecName::EmptyDescription(tool_name) => { - (tool_name.as_str(), "tool schema contains empty description") - }, - }; - acc.push_str(format!(" - {} ({})\n", tool_name, msg).as_str()); - acc - }, - ))) - // TODO: if no tools are valid, we need to offload the server - // from the fleet (i.e. kill the server) - } else if !tn_map.is_empty() { - Err(eyre::eyre!(tn_map.iter().fold( - String::from("The following tool names are changed:\n"), - |mut acc, (k, v)| { - acc.push_str(format!(" - {} -> {}\n", v, k).as_str()); - acc - }, - ))) - } else { - Ok(()) - } -} - -fn sanitize_name(orig: String, regex: ®ex::Regex, hasher: &mut impl Hasher) -> String { - if regex.is_match(&orig) && !orig.contains(NAMESPACE_DELIMITER) { - return orig; - } - let sanitized: String = orig - .chars() - .filter(|c| c.is_ascii_alphabetic() || c.is_ascii_digit() || *c == '_') - .collect::() - .replace(NAMESPACE_DELIMITER, ""); - if sanitized.is_empty() { - hasher.write(orig.as_bytes()); - let hash = format!("{:03}", hasher.finish() % 1000); - return format!("a{}", hash); - } - match sanitized.chars().next() { - Some(c) if c.is_ascii_alphabetic() => sanitized, - Some(_) => { - format!("a{}", sanitized) - }, - None => { - hasher.write(orig.as_bytes()); - format!("a{}", hasher.finish()) - }, - } -} - -fn queue_success_message(name: &str, time_taken: &str, output: &mut impl Write) -> eyre::Result<()> { - Ok(queue!( - output, - style::SetForegroundColor(style::Color::Green), - style::Print("✓ "), - style::SetForegroundColor(style::Color::Blue), - style::Print(name), - style::ResetColor, - style::Print(" loaded in "), - style::SetForegroundColor(style::Color::Yellow), - style::Print(format!("{time_taken} s\n")), - style::ResetColor, - )?) -} - -fn queue_init_message( - spinner_logo_idx: usize, - complete: usize, - failed: usize, - total: usize, - output: &mut impl Write, -) -> eyre::Result<()> { - if total == complete { - queue!( - output, - style::SetForegroundColor(style::Color::Green), - style::Print("✓"), - style::ResetColor, - )?; - } else if total == complete + failed { - queue!( - output, - style::SetForegroundColor(style::Color::Red), - style::Print("✗"), - style::ResetColor, - )?; - } else { - queue!(output, style::Print(SPINNER_CHARS[spinner_logo_idx]))?; - } - queue!( - output, - style::SetForegroundColor(style::Color::Blue), - style::Print(format!(" {}", complete)), - style::ResetColor, - style::Print(" of "), - style::SetForegroundColor(style::Color::Blue), - style::Print(format!("{} ", total)), - style::ResetColor, - style::Print("mcp servers initialized."), - )?; - if total > complete + failed { - queue!( - output, - style::SetForegroundColor(style::Color::Blue), - style::Print(" ctrl-c "), - style::ResetColor, - style::Print("to start chatting now") - )?; - } - Ok(queue!(output, style::Print("\n"))?) -} - -fn queue_failure_message( - name: &str, - fail_load_msg: &eyre::Report, - time: &str, - output: &mut impl Write, -) -> eyre::Result<()> { - use crate::util::CHAT_BINARY_NAME; - Ok(queue!( - output, - style::SetForegroundColor(style::Color::Red), - style::Print("✗ "), - style::SetForegroundColor(style::Color::Blue), - style::Print(name), - style::ResetColor, - style::Print(" has failed to load after"), - style::SetForegroundColor(style::Color::Yellow), - style::Print(format!(" {time} s")), - style::ResetColor, - style::Print("\n - "), - style::Print(fail_load_msg), - style::Print("\n"), - style::Print(format!( - " - run with Q_LOG_LEVEL=trace and see $TMPDIR/{CHAT_BINARY_NAME} for detail\n" - )), - style::ResetColor, - )?) -} - -fn queue_warn_message(name: &str, msg: &eyre::Report, time: &str, output: &mut impl Write) -> eyre::Result<()> { - Ok(queue!( - output, - style::SetForegroundColor(style::Color::Yellow), - style::Print("⚠ "), - style::SetForegroundColor(style::Color::Blue), - style::Print(name), - style::ResetColor, - style::Print(" has loaded in"), - style::SetForegroundColor(style::Color::Yellow), - style::Print(format!(" {time} s")), - style::ResetColor, - style::Print(" with the following warning:\n"), - style::Print(msg), - style::ResetColor, - )?) -} - -fn queue_disabled_message(name: &str, output: &mut impl Write) -> eyre::Result<()> { - Ok(queue!( - output, - style::SetForegroundColor(style::Color::DarkGrey), - style::Print("○ "), - style::SetForegroundColor(style::Color::Blue), - style::Print(name), - style::ResetColor, - style::Print(" is disabled\n"), - style::ResetColor, - )?) -} - -fn queue_incomplete_load_message( - complete: usize, - total: usize, - msg: &eyre::Report, - output: &mut impl Write, -) -> eyre::Result<()> { - Ok(queue!( - output, - style::SetForegroundColor(style::Color::Yellow), - style::Print("⚠"), - style::SetForegroundColor(style::Color::Blue), - style::Print(format!(" {}", complete)), - style::ResetColor, - style::Print(" of "), - style::SetForegroundColor(style::Color::Blue), - style::Print(format!("{} ", total)), - style::ResetColor, - style::Print("mcp servers initialized."), - style::ResetColor, - // We expect the message start with a newline - style::Print(" Servers still loading:"), - style::Print(msg), - style::ResetColor, - )?) -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn test_sanitize_server_name() { - let regex = regex::Regex::new(VALID_TOOL_NAME).unwrap(); - let mut hasher = DefaultHasher::new(); - let orig_name = "@awslabs.cdk-mcp-server"; - let sanitized_server_name = sanitize_name(orig_name.to_string(), ®ex, &mut hasher); - assert_eq!(sanitized_server_name, "awslabscdkmcpserver"); - - let orig_name = "good_name"; - let sanitized_good_name = sanitize_name(orig_name.to_string(), ®ex, &mut hasher); - assert_eq!(sanitized_good_name, orig_name); - - let all_bad_name = "@@@@@"; - let sanitized_all_bad_name = sanitize_name(all_bad_name.to_string(), ®ex, &mut hasher); - assert!(regex.is_match(&sanitized_all_bad_name)); - - let with_delim = format!("a{}b{}c", NAMESPACE_DELIMITER, NAMESPACE_DELIMITER); - let sanitized = sanitize_name(with_delim, ®ex, &mut hasher); - assert_eq!(sanitized, "abc"); - } -} diff --git a/crates/chat-cli/src/cli/chat/tools/custom_tool.rs b/crates/chat-cli/src/cli/chat/tools/custom_tool.rs deleted file mode 100644 index 0f7338ee4..000000000 --- a/crates/chat-cli/src/cli/chat/tools/custom_tool.rs +++ /dev/null @@ -1,246 +0,0 @@ -use std::collections::HashMap; -use std::io::Write; -use std::sync::Arc; -use std::sync::atomic::Ordering; - -use crossterm::{ - queue, - style, -}; -use eyre::Result; -use serde::{ - Deserialize, - Serialize, -}; -use tokio::sync::RwLock; -use tracing::warn; - -use super::InvokeOutput; -use crate::cli::chat::CONTINUATION_LINE; -use crate::cli::chat::token_counter::TokenCounter; -use crate::mcp_client::{ - Client as McpClient, - ClientConfig as McpClientConfig, - JsonRpcResponse, - JsonRpcStdioTransport, - MessageContent, - Messenger, - PromptGet, - ServerCapabilities, - StdioTransport, - ToolCallResult, -}; -use crate::os::Os; - -// TODO: support http transport type -#[derive(Clone, Serialize, Deserialize, Debug)] -pub struct CustomToolConfig { - pub command: String, - #[serde(default)] - pub args: Vec, - #[serde(skip_serializing_if = "Option::is_none")] - pub env: Option>, - #[serde(default = "default_timeout")] - pub timeout: u64, - #[serde(default)] - pub disabled: bool, -} - -pub fn default_timeout() -> u64 { - 120 * 1000 -} - -#[derive(Debug)] -pub enum CustomToolClient { - Stdio { - server_name: String, - client: McpClient, - server_capabilities: RwLock>, - }, -} - -impl CustomToolClient { - // TODO: add support for http transport - pub fn from_config(server_name: String, config: CustomToolConfig) -> Result { - let CustomToolConfig { - command, - args, - env, - timeout, - disabled: _, - } = config; - let mcp_client_config = McpClientConfig { - server_name: server_name.clone(), - bin_path: command.clone(), - args, - timeout, - client_info: serde_json::json!({ - "name": "Q CLI Chat", - "version": "1.0.0" - }), - env, - }; - let client = McpClient::::from_config(mcp_client_config)?; - Ok(CustomToolClient::Stdio { - server_name, - client, - server_capabilities: RwLock::new(None), - }) - } - - pub async fn init(&self) -> Result<()> { - match self { - CustomToolClient::Stdio { - client, - server_capabilities, - .. - } => { - if let Some(messenger) = &client.messenger { - let _ = messenger.send_init_msg().await; - } - // We'll need to first initialize. This is the handshake every client and server - // needs to do before proceeding to anything else - let cap = client.init().await?; - // We'll be scrapping this for background server load: https://github.com/aws/amazon-q-developer-cli/issues/1466 - // So don't worry about the tidiness for now - server_capabilities.write().await.replace(cap); - Ok(()) - }, - } - } - - pub fn assign_messenger(&mut self, messenger: Box) { - match self { - CustomToolClient::Stdio { client, .. } => { - client.messenger = Some(messenger); - }, - } - } - - pub fn get_server_name(&self) -> &str { - match self { - CustomToolClient::Stdio { server_name, .. } => server_name.as_str(), - } - } - - pub async fn request(&self, method: &str, params: Option) -> Result { - match self { - CustomToolClient::Stdio { client, .. } => Ok(client.request(method, params).await?), - } - } - - pub fn list_prompt_gets(&self) -> Arc>> { - match self { - CustomToolClient::Stdio { client, .. } => client.prompt_gets.clone(), - } - } - - #[allow(dead_code)] - pub async fn notify(&self, method: &str, params: Option) -> Result<()> { - match self { - CustomToolClient::Stdio { client, .. } => Ok(client.notify(method, params).await?), - } - } - - pub fn is_prompts_out_of_date(&self) -> bool { - match self { - CustomToolClient::Stdio { client, .. } => client.is_prompts_out_of_date.load(Ordering::Relaxed), - } - } - - pub fn prompts_updated(&self) { - match self { - CustomToolClient::Stdio { client, .. } => client.is_prompts_out_of_date.store(false, Ordering::Relaxed), - } - } -} - -/// Represents a custom tool that can be invoked through the Model Context Protocol (MCP). -#[derive(Clone, Debug)] -pub struct CustomTool { - /// Actual tool name as recognized by its MCP server. This differs from the tool names as they - /// are seen by the model since they are not prefixed by its MCP server name. - pub name: String, - /// Reference to the client that manages communication with the tool's server process. - pub client: Arc, - /// The method name to call on the tool's server, following the JSON-RPC convention. - /// This corresponds to a specific functionality provided by the tool. - pub method: String, - /// Optional parameters to pass to the tool when invoking the method. - /// Structured as a JSON value to accommodate various parameter types and structures. - pub params: Option, -} - -impl CustomTool { - pub async fn invoke(&self, _os: &Os, _updates: impl Write) -> Result { - // Assuming a response shape as per https://spec.modelcontextprotocol.io/specification/2024-11-05/server/tools/#calling-tools - let resp = self.client.request(self.method.as_str(), self.params.clone()).await?; - let result = match resp.result { - Some(result) => result, - None => { - let failure = resp.error.map_or("Unknown error encountered".to_string(), |err| { - serde_json::to_string(&err).unwrap_or_default() - }); - return Err(eyre::eyre!(failure)); - }, - }; - - match serde_json::from_value::(result.clone()) { - Ok(mut de_result) => { - for content in &mut de_result.content { - if let MessageContent::Image { data, .. } = content { - *data = format!("Redacted base64 encoded string of an image of size {}", data.len()); - } - } - Ok(InvokeOutput { - output: super::OutputKind::Json(serde_json::json!(de_result)), - }) - }, - Err(e) => { - warn!("Tool call result deserialization failed: {:?}", e); - Ok(InvokeOutput { - output: super::OutputKind::Json(result.clone()), - }) - }, - } - } - - pub fn queue_description(&self, output: &mut impl Write) -> Result<()> { - queue!( - output, - style::Print("Running "), - style::SetForegroundColor(style::Color::Green), - style::Print(&self.name), - style::ResetColor, - )?; - if let Some(params) = &self.params { - let params = match serde_json::to_string_pretty(params) { - Ok(params) => params - .split("\n") - .map(|p| format!("{CONTINUATION_LINE} {p}")) - .collect::>() - .join("\n"), - _ => format!("{:?}", params), - }; - queue!( - output, - style::Print(" with the param:\n"), - style::Print(params), - style::Print("\n"), - style::ResetColor, - )?; - } else { - queue!(output, style::Print("\n"))?; - } - Ok(()) - } - - pub async fn validate(&mut self, _os: &Os) -> Result<()> { - Ok(()) - } - - pub fn get_input_token_size(&self) -> usize { - TokenCounter::count_tokens(self.method.as_str()) - + TokenCounter::count_tokens(self.params.as_ref().map_or("", |p| p.as_str().unwrap_or_default())) - } -} diff --git a/crates/chat-cli/src/cli/chat/tools/execute/mod.rs b/crates/chat-cli/src/cli/chat/tools/execute/mod.rs deleted file mode 100644 index 96ae6891e..000000000 --- a/crates/chat-cli/src/cli/chat/tools/execute/mod.rs +++ /dev/null @@ -1,270 +0,0 @@ -use std::io::Write; - -use crossterm::queue; -use crossterm::style::{ - self, - Color, -}; -use eyre::Result; -use serde::Deserialize; - -use crate::cli::chat::tools::{ - InvokeOutput, - MAX_TOOL_RESPONSE_SIZE, - OutputKind, -}; -use crate::cli::chat::util::truncate_safe; -use crate::os::Os; - -// Platform-specific modules -#[cfg(windows)] -mod windows; -#[cfg(windows)] -pub use windows::*; - -#[cfg(not(windows))] -mod unix; -#[cfg(not(windows))] -pub use unix::*; - -// Common readonly commands that are safe to execute without user confirmation -pub const READONLY_COMMANDS: &[&str] = &[ - "ls", "cat", "echo", "pwd", "which", "head", "tail", "find", "grep", "dir", "type", -]; - -#[derive(Debug, Clone, Deserialize)] -pub struct ExecuteCommand { - pub command: String, - pub summary: Option, -} - -impl ExecuteCommand { - pub fn requires_acceptance(&self) -> bool { - let Some(args) = shlex::split(&self.command) else { - return true; - }; - - const DANGEROUS_PATTERNS: &[&str] = &["<(", "$(", "`", ">", "&&", "||", "&", ";"]; - if args - .iter() - .any(|arg| DANGEROUS_PATTERNS.iter().any(|p| arg.contains(p))) - { - return true; - } - - // Split commands by pipe and check each one - let mut current_cmd = Vec::new(); - let mut all_commands = Vec::new(); - - for arg in args { - if arg == "|" { - if !current_cmd.is_empty() { - all_commands.push(current_cmd); - } - current_cmd = Vec::new(); - } else if arg.contains("|") { - // if pipe appears without spacing e.g. `echo myimportantfile|args rm` it won't get - // parsed out, in this case - we want to verify before running - return true; - } else { - current_cmd.push(arg); - } - } - if !current_cmd.is_empty() { - all_commands.push(current_cmd); - } - - // Check if each command in the pipe chain starts with a safe command - for cmd_args in all_commands { - match cmd_args.first() { - // Special casing for `find` so that we support most cases while safeguarding - // against unwanted mutations - Some(cmd) - if cmd == "find" - && cmd_args.iter().any(|arg| { - arg.contains("-exec") // includes -execdir - || arg.contains("-delete") - || arg.contains("-ok") // includes -okdir - }) => - { - return true; - }, - // Special casing for `grep`. -P flag for perl regexp has RCE issues, apparently - // should not be supported within grep but is flagged as a possibility since this is perl - // regexp. - Some(cmd) if cmd == "grep" && cmd_args.iter().any(|arg| arg.contains("-P")) => { - return true; - }, - Some(cmd) if !READONLY_COMMANDS.contains(&cmd.as_str()) => return true, - None => return true, - _ => (), - } - } - - false - } - - pub async fn invoke(&self, output: &mut impl Write) -> Result { - let output = run_command(&self.command, MAX_TOOL_RESPONSE_SIZE / 3, Some(output)).await?; - let result = serde_json::json!({ - "exit_status": output.exit_status.unwrap_or(0).to_string(), - "stdout": output.stdout, - "stderr": output.stderr, - }); - - Ok(InvokeOutput { - output: OutputKind::Json(result), - }) - } - - pub fn queue_description(&self, output: &mut impl Write) -> Result<()> { - queue!(output, style::Print("I will run the following shell command: "),)?; - - // TODO: Could use graphemes for a better heuristic - if self.command.len() > 20 { - queue!(output, style::Print("\n"),)?; - } - - queue!( - output, - style::SetForegroundColor(Color::Green), - style::Print(&self.command), - style::Print("\n"), - style::ResetColor - )?; - - // Add the summary if available - if let Some(ref summary) = self.summary { - super::display_purpose(Some(summary), output)?; - } - - queue!(output, style::Print("\n"))?; - - Ok(()) - } - - pub async fn validate(&mut self, _os: &Os) -> Result<()> { - // TODO: probably some small amount of PATH checking - Ok(()) - } -} - -pub struct CommandResult { - pub exit_status: Option, - /// Truncated stdout - pub stdout: String, - /// Truncated stderr - pub stderr: String, -} - -// Helper function to format command output with truncation -pub fn format_output(output: &str, max_size: usize) -> String { - format!( - "{}{}", - truncate_safe(output, max_size), - if output.len() > max_size { " ... truncated" } else { "" } - ) -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn test_requires_acceptance_for_readonly_commands() { - let cmds = &[ - // Safe commands - ("ls ~", false), - ("ls -al ~", false), - ("pwd", false), - ("echo 'Hello, world!'", false), - ("which aws", false), - // Potentially dangerous readonly commands - ("echo hi > myimportantfile", true), - ("ls -al >myimportantfile", true), - ("echo hi 2> myimportantfile", true), - ("echo hi >> myimportantfile", true), - ("echo $(rm myimportantfile)", true), - ("echo `rm myimportantfile`", true), - ("echo hello && rm myimportantfile", true), - ("echo hello&&rm myimportantfile", true), - ("ls nonexistantpath || rm myimportantfile", true), - ("echo myimportantfile | xargs rm", true), - ("echo myimportantfile|args rm", true), - ("echo <(rm myimportantfile)", true), - ("cat <<< 'some string here' > myimportantfile", true), - ("echo '\n#!/usr/bin/env bash\necho hello\n' > myscript.sh", true), - ("cat < myimportantfile\nhello world\nEOF", true), - // Safe piped commands - ("find . -name '*.rs' | grep main", false), - ("ls -la | grep .git", false), - ("cat file.txt | grep pattern | head -n 5", false), - // Unsafe piped commands - ("find . -name '*.rs' | rm", true), - ("ls -la | grep .git | rm -rf", true), - ("echo hello | sudo rm -rf /", true), - // `find` command arguments - ("find important-dir/ -exec rm {} \\;", true), - ("find . -name '*.c' -execdir gcc -o '{}.out' '{}' \\;", true), - ("find important-dir/ -delete", true), - ( - "echo y | find . -type f -maxdepth 1 -okdir open -a Calculator {} +", - true, - ), - ("find important-dir/ -name '*.txt'", false), - // `grep` command arguments - ("echo 'test data' | grep -P '(?{system(\"date\")})'", true), - ]; - for (cmd, expected) in cmds { - let tool = serde_json::from_value::(serde_json::json!({ - "command": cmd, - })) - .unwrap(); - assert_eq!( - tool.requires_acceptance(), - *expected, - "expected command: `{}` to have requires_acceptance: `{}`", - cmd, - expected - ); - } - } - - #[test] - fn test_requires_acceptance_for_windows_commands() { - let cmds = &[ - // Safe Windows commands - ("dir", false), - ("type file.txt", false), - ("echo Hello, world!", false), - // Potentially dangerous Windows commands - ("del file.txt", true), - ("rmdir /s /q folder", true), - ("rd /s /q folder", true), - ("format c:", true), - ("erase file.txt", true), - ("copy file.txt > important.txt", true), - ("move file.txt destination", true), - // Command with pipes - ("dir | findstr txt", true), - ("type file.txt | findstr pattern", true), - // Dangerous piped commands - ("dir | del", true), - ("type file.txt | del", true), - ]; - - for (cmd, expected) in cmds { - let tool = serde_json::from_value::(serde_json::json!({ - "command": cmd, - })) - .unwrap(); - assert_eq!( - tool.requires_acceptance(), - *expected, - "expected command: `{}` to have requires_acceptance: `{}`", - cmd, - expected - ); - } - } -} diff --git a/crates/chat-cli/src/cli/chat/tools/execute/unix.rs b/crates/chat-cli/src/cli/chat/tools/execute/unix.rs deleted file mode 100644 index 586dbd29c..000000000 --- a/crates/chat-cli/src/cli/chat/tools/execute/unix.rs +++ /dev/null @@ -1,184 +0,0 @@ -use std::collections::VecDeque; -use std::io::Write; -use std::process::Stdio; - -use eyre::{ - Context as EyreContext, - Result, -}; -use tokio::io::AsyncBufReadExt; -use tokio::select; -use tracing::error; - -use super::{ - CommandResult, - format_output, -}; - -/// Run a bash command on Unix systems. -/// # Arguments -/// * `command` - The command to run -/// * `max_result_size` - max size of output streams, truncating if required -/// * `updates` - output stream to push informational messages about the progress -/// # Returns -/// A [`CommandResult`] -pub async fn run_command( - command: &str, - max_result_size: usize, - mut updates: Option, -) -> Result { - let shell = std::env::var("AMAZON_Q_CHAT_SHELL").unwrap_or("bash".to_string()); - - // We need to maintain a handle on stderr and stdout, but pipe it to the terminal as well - let mut child = tokio::process::Command::new(shell) - .arg("-c") - .arg(command) - .stdin(Stdio::inherit()) - .stdout(Stdio::piped()) - .stderr(Stdio::piped()) - .spawn() - .wrap_err_with(|| format!("Unable to spawn command '{}'", command))?; - - let stdout_final: String; - let stderr_final: String; - let exit_status; - - // Buffered output vs all-at-once - if let Some(u) = updates.as_mut() { - let stdout = child.stdout.take().unwrap(); - let stdout = tokio::io::BufReader::new(stdout); - let mut stdout = stdout.lines(); - - let stderr = child.stderr.take().unwrap(); - let stderr = tokio::io::BufReader::new(stderr); - let mut stderr = stderr.lines(); - - const LINE_COUNT: usize = 1024; - let mut stdout_buf = VecDeque::with_capacity(LINE_COUNT); - let mut stderr_buf = VecDeque::with_capacity(LINE_COUNT); - - let mut stdout_done = false; - let mut stderr_done = false; - exit_status = loop { - select! { - biased; - line = stdout.next_line(), if !stdout_done => match line { - Ok(Some(line)) => { - writeln!(u, "{line}")?; - if stdout_buf.len() >= LINE_COUNT { - stdout_buf.pop_front(); - } - stdout_buf.push_back(line); - }, - Ok(None) => stdout_done = true, - Err(err) => error!(%err, "Failed to read stdout of child process"), - }, - line = stderr.next_line(), if !stderr_done => match line { - Ok(Some(line)) => { - writeln!(u, "{line}")?; - if stderr_buf.len() >= LINE_COUNT { - stderr_buf.pop_front(); - } - stderr_buf.push_back(line); - }, - Ok(None) => stderr_done = true, - Err(err) => error!(%err, "Failed to read stderr of child process"), - }, - exit_status = child.wait() => { - break exit_status; - }, - }; - } - .wrap_err_with(|| format!("No exit status for '{}'", command))?; - - u.flush()?; - - stdout_final = stdout_buf.into_iter().collect::>().join("\n"); - stderr_final = stderr_buf.into_iter().collect::>().join("\n"); - } else { - // Take output all at once since we are not reporting anything in real time - // - // NOTE: If we don't split this logic, then any writes to stdout while calling - // this function concurrently may cause the piped child output to be ignored - - let output = child - .wait_with_output() - .await - .wrap_err_with(|| format!("No exit status for '{}'", command))?; - - exit_status = output.status; - stdout_final = String::from_utf8_lossy(&output.stdout).to_string(); - stderr_final = String::from_utf8_lossy(&output.stderr).to_string(); - } - - Ok(CommandResult { - exit_status: exit_status.code(), - stdout: format_output(&stdout_final, max_result_size), - stderr: format_output(&stderr_final, max_result_size), - }) -} - -#[cfg(test)] -mod tests { - use crate::cli::chat::tools::OutputKind; - use crate::cli::chat::tools::execute::ExecuteCommand; - - #[ignore = "todo: fix failing on musl for some reason"] - #[tokio::test] - async fn test_execute_bash_tool() { - let mut stdout = std::io::stdout(); - - // Verifying stdout - let v = serde_json::json!({ - "command": "echo Hello, world!", - }); - let out = serde_json::from_value::(v) - .unwrap() - .invoke(&mut stdout) - .await - .unwrap(); - - if let OutputKind::Json(json) = out.output { - assert_eq!(json.get("exit_status").unwrap(), &0.to_string()); - assert_eq!(json.get("stdout").unwrap(), "Hello, world!"); - assert_eq!(json.get("stderr").unwrap(), ""); - } else { - panic!("Expected JSON output"); - } - - // Verifying stderr - let v = serde_json::json!({ - "command": "echo Hello, world! 1>&2", - }); - let out = serde_json::from_value::(v) - .unwrap() - .invoke(&mut stdout) - .await - .unwrap(); - - if let OutputKind::Json(json) = out.output { - assert_eq!(json.get("exit_status").unwrap(), &0.to_string()); - assert_eq!(json.get("stdout").unwrap(), ""); - assert_eq!(json.get("stderr").unwrap(), "Hello, world!"); - } else { - panic!("Expected JSON output"); - } - - // Verifying exit code - let v = serde_json::json!({ - "command": "exit 1", - }); - let out = serde_json::from_value::(v) - .unwrap() - .invoke(&mut stdout) - .await - .unwrap(); - if let OutputKind::Json(json) = out.output { - assert_eq!(json.get("exit_status").unwrap(), &1.to_string()); - assert_eq!(json.get("stdout").unwrap(), ""); - assert_eq!(json.get("stderr").unwrap(), ""); - } else { - panic!("Expected JSON output"); - } - } -} diff --git a/crates/chat-cli/src/cli/chat/tools/execute/windows.rs b/crates/chat-cli/src/cli/chat/tools/execute/windows.rs deleted file mode 100644 index 950417b4c..000000000 --- a/crates/chat-cli/src/cli/chat/tools/execute/windows.rs +++ /dev/null @@ -1,177 +0,0 @@ -use std::collections::VecDeque; -use std::io::Write; -use std::process::Stdio; - -use eyre::{ - Context as EyreContext, - Result, -}; -use tokio::io::AsyncBufReadExt; -use tokio::select; -use tracing::error; - -use super::{ - CommandResult, - format_output, -}; - -/// Run a command on Windows using cmd.exe. -/// # Arguments -/// * `command` - The command to run -/// * `max_result_size` - max size of output streams, truncating if required -/// * `updates` - output stream to push informational messages about the progress -/// # Returns -/// A [`CommandResult`] -pub async fn run_command( - command: &str, - max_result_size: usize, - mut updates: Option, -) -> Result { - // We need to maintain a handle on stderr and stdout, but pipe it to the terminal as well - let mut child = tokio::process::Command::new("cmd") - .arg("/C") - .arg(command) - .stdin(Stdio::inherit()) - .stdout(Stdio::piped()) - .stderr(Stdio::piped()) - .spawn() - .wrap_err_with(|| format!("Unable to spawn command '{}'", command))?; - - let stdout_final: String; - let stderr_final: String; - let exit_status; - - // Buffered output vs all-at-once - if let Some(u) = updates.as_mut() { - let stdout = child.stdout.take().unwrap(); - let stdout = tokio::io::BufReader::new(stdout); - let mut stdout = stdout.lines(); - - let stderr = child.stderr.take().unwrap(); - let stderr = tokio::io::BufReader::new(stderr); - let mut stderr = stderr.lines(); - - const LINE_COUNT: usize = 1024; - let mut stdout_buf = VecDeque::with_capacity(LINE_COUNT); - let mut stderr_buf = VecDeque::with_capacity(LINE_COUNT); - - let mut stdout_done = false; - let mut stderr_done = false; - exit_status = loop { - select! { - biased; - line = stdout.next_line(), if !stdout_done => match line { - Ok(Some(line)) => { - writeln!(u, "{line}")?; - if stdout_buf.len() >= LINE_COUNT { - stdout_buf.pop_front(); - } - stdout_buf.push_back(line); - }, - Ok(None) => stdout_done = true, - Err(err) => error!(%err, "Failed to read stdout of child process"), - }, - line = stderr.next_line(), if !stderr_done => match line { - Ok(Some(line)) => { - writeln!(u, "{line}")?; - if stderr_buf.len() >= LINE_COUNT { - stderr_buf.pop_front(); - } - stderr_buf.push_back(line); - }, - Ok(None) => stderr_done = true, - Err(err) => error!(%err, "Failed to read stderr of child process"), - }, - exit_status = child.wait() => { - break exit_status; - }, - }; - } - .wrap_err_with(|| format!("No exit status for '{}'", command))?; - - u.flush()?; - - stdout_final = stdout_buf.into_iter().collect::>().join("\n"); - stderr_final = stderr_buf.into_iter().collect::>().join("\n"); - } else { - // Take output all at once since we are not reporting anything in real time - let output = child - .wait_with_output() - .await - .wrap_err_with(|| format!("No exit status for '{}'", command))?; - - exit_status = output.status; - stdout_final = String::from_utf8_lossy(&output.stdout).to_string(); - stderr_final = String::from_utf8_lossy(&output.stderr).to_string(); - } - - Ok(CommandResult { - exit_status: exit_status.code(), - stdout: format_output(&stdout_final, max_result_size), - stderr: format_output(&stderr_final, max_result_size), - }) -} - -#[cfg(test)] -mod tests { - use crate::cli::chat::tools::OutputKind; - use crate::cli::chat::tools::execute::ExecuteCommand; - - #[tokio::test] - async fn test_execute_cmd_tool() { - let mut stdout = std::io::stdout(); - - // Verifying stdout - let v = serde_json::json!({ - "command": "echo Hello, world!", - }); - let out = serde_json::from_value::(v) - .unwrap() - .invoke(&mut stdout) - .await - .unwrap(); - - if let OutputKind::Json(json) = out.output { - assert_eq!(json.get("exit_status").unwrap(), &0.to_string()); - assert!(json.get("stdout").unwrap().to_string().contains("Hello, world!")); - assert_eq!(json.get("stderr").unwrap(), ""); - } else { - panic!("Expected JSON output"); - } - - // Verifying stderr (using 2>&1 redirection for Windows) - let v = serde_json::json!({ - "command": "echo Hello, world! 1>&2", - }); - let out = serde_json::from_value::(v) - .unwrap() - .invoke(&mut stdout) - .await - .unwrap(); - - if let OutputKind::Json(json) = out.output { - assert_eq!(json.get("exit_status").unwrap(), &0.to_string()); - assert_eq!(json.get("stdout").unwrap(), ""); - assert!(json.get("stderr").unwrap().to_string().contains("Hello, world!")); - } else { - panic!("Expected JSON output"); - } - - // Verifying exit code - let v = serde_json::json!({ - "command": "exit /b 1", - }); - let out = serde_json::from_value::(v) - .unwrap() - .invoke(&mut stdout) - .await - .unwrap(); - if let OutputKind::Json(json) = out.output { - assert_eq!(json.get("exit_status").unwrap(), &1.to_string()); - assert_eq!(json.get("stdout").unwrap(), ""); - assert_eq!(json.get("stderr").unwrap(), ""); - } else { - panic!("Expected JSON output"); - } - } -} diff --git a/crates/chat-cli/src/cli/chat/tools/fs_read.rs b/crates/chat-cli/src/cli/chat/tools/fs_read.rs deleted file mode 100644 index 00ad936b8..000000000 --- a/crates/chat-cli/src/cli/chat/tools/fs_read.rs +++ /dev/null @@ -1,988 +0,0 @@ -use std::collections::VecDeque; -use std::fs::Metadata; -use std::io::Write; - -use crossterm::queue; -use crossterm::style::{ - self, - Color, - Stylize, -}; -use eyre::{ - Result, - bail, -}; -use serde::{ - Deserialize, - Serialize, -}; -use syntect::util::LinesWithEndings; -use tracing::{ - debug, - warn, -}; - -use super::{ - InvokeOutput, - MAX_TOOL_RESPONSE_SIZE, - OutputKind, - format_path, - sanitize_path_tool_arg, -}; -use crate::cli::chat::CONTINUATION_LINE; -use crate::cli::chat::util::images::{ - handle_images_from_paths, - is_supported_image_type, - pre_process, -}; -use crate::os::Os; - -const CHECKMARK: &str = "✔"; -const CROSS: &str = "✘"; - -#[derive(Debug, Clone, Deserialize)] -#[serde(tag = "mode")] -pub enum FsRead { - Line(FsLine), - Directory(FsDirectory), - Search(FsSearch), - Image(FsImage), -} - -impl FsRead { - pub async fn validate(&mut self, os: &Os) -> Result<()> { - match self { - FsRead::Line(fs_line) => fs_line.validate(os).await, - FsRead::Directory(fs_directory) => fs_directory.validate(os).await, - FsRead::Search(fs_search) => fs_search.validate(os).await, - FsRead::Image(fs_image) => fs_image.validate(os).await, - } - } - - pub async fn queue_description(&self, os: &Os, updates: &mut impl Write) -> Result<()> { - match self { - FsRead::Line(fs_line) => fs_line.queue_description(os, updates).await, - FsRead::Directory(fs_directory) => fs_directory.queue_description(updates), - FsRead::Search(fs_search) => fs_search.queue_description(updates), - FsRead::Image(fs_image) => fs_image.queue_description(updates), - } - } - - pub async fn invoke(&self, os: &Os, updates: &mut impl Write) -> Result { - match self { - FsRead::Line(fs_line) => fs_line.invoke(os, updates).await, - FsRead::Directory(fs_directory) => fs_directory.invoke(os, updates).await, - FsRead::Search(fs_search) => fs_search.invoke(os, updates).await, - FsRead::Image(fs_image) => fs_image.invoke(updates).await, - } - } -} - -/// Read images from given paths. -#[derive(Debug, Clone, Deserialize)] -pub struct FsImage { - pub image_paths: Vec, -} - -impl FsImage { - pub async fn validate(&mut self, os: &Os) -> Result<()> { - for path in &self.image_paths { - let path = sanitize_path_tool_arg(os, path); - if let Some(path) = path.to_str() { - let processed_path = pre_process(path); - if !is_supported_image_type(&processed_path) { - bail!("'{}' is not a supported image type", &processed_path); - } - let is_file = os.fs.symlink_metadata(&processed_path).await?.is_file(); - if !is_file { - bail!("'{}' is not a file", &processed_path); - } - } else { - bail!("Unable to parse path"); - } - } - Ok(()) - } - - pub async fn invoke(&self, updates: &mut impl Write) -> Result { - let pre_processed_paths: Vec = self.image_paths.iter().map(|path| pre_process(path)).collect(); - let valid_images = handle_images_from_paths(updates, &pre_processed_paths); - Ok(InvokeOutput { - output: OutputKind::Images(valid_images), - }) - } - - pub fn queue_description(&self, updates: &mut impl Write) -> Result<()> { - queue!( - updates, - style::Print("Reading images: \n"), - style::SetForegroundColor(Color::Green), - style::Print(&self.image_paths.join("\n")), - style::ResetColor, - )?; - Ok(()) - } -} - -/// Read lines from a file. -#[derive(Debug, Clone, Deserialize)] -pub struct FsLine { - pub path: String, - pub start_line: Option, - pub end_line: Option, -} - -impl FsLine { - const DEFAULT_END_LINE: i32 = -1; - const DEFAULT_START_LINE: i32 = 1; - - pub async fn validate(&mut self, os: &Os) -> Result<()> { - let path = sanitize_path_tool_arg(os, &self.path); - if !path.exists() { - bail!("'{}' does not exist", self.path); - } - let is_file = os.fs.symlink_metadata(&path).await?.is_file(); - if !is_file { - bail!("'{}' is not a file", self.path); - } - Ok(()) - } - - pub async fn queue_description(&self, os: &Os, updates: &mut impl Write) -> Result<()> { - let path = sanitize_path_tool_arg(os, &self.path); - let file_bytes = os.fs.read(&path).await?; - let file_content = String::from_utf8_lossy(&file_bytes); - let line_count = file_content.lines().count(); - queue!( - updates, - style::Print("Reading file: "), - style::SetForegroundColor(Color::Green), - style::Print(&self.path), - style::ResetColor, - style::Print(", "), - )?; - - let start = convert_negative_index(line_count, self.start_line()) + 1; - let end = convert_negative_index(line_count, self.end_line()) + 1; - match (start, end) { - _ if start == 1 && end == line_count => Ok(queue!(updates, style::Print("all lines".to_string()))?), - _ if end == line_count => Ok(queue!( - updates, - style::Print("from line "), - style::SetForegroundColor(Color::Green), - style::Print(start), - style::ResetColor, - style::Print(" to end of file"), - )?), - _ => Ok(queue!( - updates, - style::Print("from line "), - style::SetForegroundColor(Color::Green), - style::Print(start), - style::ResetColor, - style::Print(" to "), - style::SetForegroundColor(Color::Green), - style::Print(end), - style::ResetColor, - )?), - } - } - - pub async fn invoke(&self, os: &Os, _updates: &mut impl Write) -> Result { - let path = sanitize_path_tool_arg(os, &self.path); - debug!(?path, "Reading"); - let file_bytes = os.fs.read(&path).await?; - let file_content = String::from_utf8_lossy(&file_bytes); - let line_count = file_content.lines().count(); - let (start, end) = ( - convert_negative_index(line_count, self.start_line()), - convert_negative_index(line_count, self.end_line()), - ); - - // safety check to ensure end is always greater than start - let end = end.max(start); - - if start >= line_count { - bail!( - "starting index: {} is outside of the allowed range: ({}, {})", - self.start_line(), - -(line_count as i64), - line_count - ); - } - - // The range should be inclusive on both ends. - let file_contents = file_content - .lines() - .skip(start) - .take(end - start + 1) - .collect::>() - .join("\n"); - - let byte_count = file_contents.len(); - if byte_count > MAX_TOOL_RESPONSE_SIZE { - bail!( - "This tool only supports reading {MAX_TOOL_RESPONSE_SIZE} bytes at a -time. You tried to read {byte_count} bytes. Try executing with fewer lines specified." - ); - } - - Ok(InvokeOutput { - output: OutputKind::Text(file_contents), - }) - } - - fn start_line(&self) -> i32 { - self.start_line.unwrap_or(Self::DEFAULT_START_LINE) - } - - fn end_line(&self) -> i32 { - self.end_line.unwrap_or(Self::DEFAULT_END_LINE) - } -} - -/// Search in a file. -#[derive(Debug, Clone, Deserialize)] -pub struct FsSearch { - pub path: String, - pub pattern: String, - pub context_lines: Option, -} - -impl FsSearch { - const CONTEXT_LINE_PREFIX: &str = " "; - const DEFAULT_CONTEXT_LINES: usize = 2; - const MATCHING_LINE_PREFIX: &str = "→ "; - - pub async fn validate(&mut self, os: &Os) -> Result<()> { - let path = sanitize_path_tool_arg(os, &self.path); - let relative_path = format_path(os.env.current_dir()?, &path); - if !path.exists() { - bail!("File not found: {}", relative_path); - } - if !os.fs.symlink_metadata(path).await?.is_file() { - bail!("Path is not a file: {}", relative_path); - } - if self.pattern.is_empty() { - bail!("Search pattern cannot be empty"); - } - Ok(()) - } - - pub fn queue_description(&self, updates: &mut impl Write) -> Result<()> { - queue!( - updates, - style::Print("Searching: "), - style::SetForegroundColor(Color::Green), - style::Print(&self.path), - style::ResetColor, - style::Print(" for pattern: "), - style::SetForegroundColor(Color::Green), - style::Print(&self.pattern.to_lowercase()), - style::ResetColor, - style::Print("\n"), - )?; - Ok(()) - } - - pub async fn invoke(&self, os: &Os, updates: &mut impl Write) -> Result { - let file_path = sanitize_path_tool_arg(os, &self.path); - let pattern = &self.pattern; - - let file_bytes = os.fs.read(&file_path).await?; - let file_content = String::from_utf8_lossy(&file_bytes); - let lines: Vec<&str> = LinesWithEndings::from(&file_content).collect(); - - let mut results = Vec::new(); - let mut total_matches = 0; - - // Case insensitive search - let pattern_lower = pattern.to_lowercase(); - for (line_num, line) in lines.iter().enumerate() { - if line.to_lowercase().contains(&pattern_lower) { - total_matches += 1; - let start = line_num.saturating_sub(self.context_lines()); - let end = lines.len().min(line_num + self.context_lines() + 1); - let mut context_text = Vec::new(); - (start..end).for_each(|i| { - let prefix = if i == line_num { - Self::MATCHING_LINE_PREFIX - } else { - Self::CONTEXT_LINE_PREFIX - }; - let line_text = lines[i].to_string(); - context_text.push(format!("{}{}: {}", prefix, i + 1, line_text)); - }); - let match_text = context_text.join(""); - results.push(SearchMatch { - line_number: line_num + 1, - context: match_text, - }); - } - } - let match_text = if total_matches == 1 { - "1 match".to_string() - } else { - format!("{} matches", total_matches) - }; - - let color = if total_matches == 0 { - Color::Yellow - } else { - Color::Green - }; - - let result = if total_matches == 0 { - CROSS.yellow() - } else { - CHECKMARK.green() - }; - - queue!( - updates, - style::SetForegroundColor(Color::Yellow), - style::ResetColor, - style::Print(CONTINUATION_LINE), - style::Print("\n"), - style::Print(" "), - style::Print(result), - style::Print(" Found: "), - style::SetForegroundColor(color), - style::Print(match_text), - style::ResetColor, - )?; - - Ok(InvokeOutput { - output: OutputKind::Text(serde_json::to_string(&results)?), - }) - } - - fn context_lines(&self) -> usize { - self.context_lines.unwrap_or(Self::DEFAULT_CONTEXT_LINES) - } -} - -/// List directory contents. -#[derive(Debug, Clone, Deserialize)] -pub struct FsDirectory { - pub path: String, - pub depth: Option, -} - -impl FsDirectory { - const DEFAULT_DEPTH: usize = 0; - - pub async fn validate(&mut self, os: &Os) -> Result<()> { - let path = sanitize_path_tool_arg(os, &self.path); - let relative_path = format_path(os.env.current_dir()?, &path); - if !path.exists() { - bail!("Directory not found: {}", relative_path); - } - if !os.fs.symlink_metadata(path).await?.is_dir() { - bail!("Path is not a directory: {}", relative_path); - } - Ok(()) - } - - pub fn queue_description(&self, updates: &mut impl Write) -> Result<()> { - queue!( - updates, - style::Print("Reading directory: "), - style::SetForegroundColor(Color::Green), - style::Print(&self.path), - style::ResetColor, - style::Print(" "), - )?; - let depth = self.depth.unwrap_or_default(); - Ok(queue!( - updates, - style::Print(format!("with maximum depth of {}", depth)) - )?) - } - - pub async fn invoke(&self, os: &Os, _updates: &mut impl Write) -> Result { - let path = sanitize_path_tool_arg(os, &self.path); - let max_depth = self.depth(); - debug!(?path, max_depth, "Reading directory at path with depth"); - let mut result = Vec::new(); - let mut dir_queue = VecDeque::new(); - dir_queue.push_back((path, 0)); - while let Some((path, depth)) = dir_queue.pop_front() { - if depth > max_depth { - break; - } - let mut read_dir = os.fs.read_dir(path).await?; - - #[cfg(windows)] - while let Some(ent) = read_dir.next_entry().await? { - let md = ent.metadata().await?; - - let modified_timestamp = md.modified()?.duration_since(std::time::UNIX_EPOCH)?.as_secs(); - let datetime = time::OffsetDateTime::from_unix_timestamp(modified_timestamp as i64).unwrap(); - let formatted_date = datetime - .format(time::macros::format_description!( - "[month repr:short] [day] [hour]:[minute]" - )) - .unwrap(); - - result.push(format!( - "{} {} {} {}", - format_ftype(&md), - String::from_utf8_lossy(ent.file_name().as_encoded_bytes()), - formatted_date, - ent.path().to_string_lossy() - )); - - if md.is_dir() && md.is_dir() { - dir_queue.push_back((ent.path(), depth + 1)); - } - } - - #[cfg(unix)] - while let Some(ent) = read_dir.next_entry().await? { - use std::os::unix::fs::{ - MetadataExt, - PermissionsExt, - }; - - let md = ent.metadata().await?; - let formatted_mode = format_mode(md.permissions().mode()).into_iter().collect::(); - - let modified_timestamp = md.modified()?.duration_since(std::time::UNIX_EPOCH)?.as_secs(); - let datetime = time::OffsetDateTime::from_unix_timestamp(modified_timestamp as i64).unwrap(); - let formatted_date = datetime - .format(time::macros::format_description!( - "[month repr:short] [day] [hour]:[minute]" - )) - .unwrap(); - - // Mostly copying "The Long Format" from `man ls`. - // TODO: query user/group database to convert uid/gid to names? - result.push(format!( - "{}{} {} {} {} {} {} {}", - format_ftype(&md), - formatted_mode, - md.nlink(), - md.uid(), - md.gid(), - md.size(), - formatted_date, - ent.path().to_string_lossy() - )); - if md.is_dir() { - dir_queue.push_back((ent.path(), depth + 1)); - } - } - } - - let file_count = result.len(); - let result = result.join("\n"); - let byte_count = result.len(); - if byte_count > MAX_TOOL_RESPONSE_SIZE { - bail!( - "This tool only supports reading up to {MAX_TOOL_RESPONSE_SIZE} bytes at a time. You tried to read {byte_count} bytes ({file_count} files). Try executing with fewer lines specified." - ); - } - - Ok(InvokeOutput { - output: OutputKind::Text(result), - }) - } - - fn depth(&self) -> usize { - self.depth.unwrap_or(Self::DEFAULT_DEPTH) - } -} - -/// Converts negative 1-based indices to positive 0-based indices. -fn convert_negative_index(line_count: usize, i: i32) -> usize { - if i <= 0 { - (line_count as i32 + i).max(0) as usize - } else { - i as usize - 1 - } -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -struct SearchMatch { - line_number: usize, - context: String, -} - -fn format_ftype(md: &Metadata) -> char { - if md.is_symlink() { - 'l' - } else if md.is_file() { - '-' - } else if md.is_dir() { - 'd' - } else { - warn!("unknown file metadata: {:?}", md); - '-' - } -} - -/// Formats a permissions mode into the form used by `ls`, e.g. `0o644` to `rw-r--r--` -#[cfg(unix)] -fn format_mode(mode: u32) -> [char; 9] { - let mut mode = mode & 0o777; - let mut res = ['-'; 9]; - fn octal_to_chars(val: u32) -> [char; 3] { - match val { - 1 => ['-', '-', 'x'], - 2 => ['-', 'w', '-'], - 3 => ['-', 'w', 'x'], - 4 => ['r', '-', '-'], - 5 => ['r', '-', 'x'], - 6 => ['r', 'w', '-'], - 7 => ['r', 'w', 'x'], - _ => ['-', '-', '-'], - } - } - for c in res.rchunks_exact_mut(3) { - c.copy_from_slice(&octal_to_chars(mode & 0o7)); - mode /= 0o10; - } - res -} - -#[cfg(test)] -mod tests { - use super::*; - use crate::cli::chat::util::test::{ - TEST_FILE_CONTENTS, - TEST_FILE_PATH, - setup_test_directory, - }; - - #[test] - fn test_negative_index_conversion() { - assert_eq!(convert_negative_index(5, -100), 0); - assert_eq!(convert_negative_index(5, -1), 4); - } - - #[test] - fn test_fs_read_deser() { - serde_json::from_value::(serde_json::json!({ "path": "/test_file.txt", "mode": "Line" })).unwrap(); - serde_json::from_value::( - serde_json::json!({ "path": "/test_file.txt", "mode": "Line", "end_line": 5 }), - ) - .unwrap(); - serde_json::from_value::( - serde_json::json!({ "path": "/test_file.txt", "mode": "Line", "start_line": -1 }), - ) - .unwrap(); - serde_json::from_value::( - serde_json::json!({ "path": "/test_file.txt", "mode": "Line", "start_line": None:: }), - ) - .unwrap(); - serde_json::from_value::(serde_json::json!({ "path": "/", "mode": "Directory" })).unwrap(); - serde_json::from_value::( - serde_json::json!({ "path": "/test_file.txt", "mode": "Directory", "depth": 2 }), - ) - .unwrap(); - serde_json::from_value::( - serde_json::json!({ "path": "/test_file.txt", "mode": "Search", "pattern": "hello" }), - ) - .unwrap(); - } - - #[tokio::test] - async fn test_fs_read_line_invoke() { - let os = setup_test_directory().await; - let lines = TEST_FILE_CONTENTS.lines().collect::>(); - let mut stdout = std::io::stdout(); - - macro_rules! assert_lines { - ($start_line:expr, $end_line:expr, $expected:expr) => { - let v = serde_json::json!({ - "path": TEST_FILE_PATH, - "mode": "Line", - "start_line": $start_line, - "end_line": $end_line, - }); - let output = serde_json::from_value::(v) - .unwrap() - .invoke(&os, &mut stdout) - .await - .unwrap(); - - if let OutputKind::Text(text) = output.output { - assert_eq!(text, $expected.join("\n"), "actual(left) does not equal - expected(right) for (start_line, end_line): ({:?}, {:?})", $start_line, $end_line); - } else { - panic!("expected text output"); - } - } - } - assert_lines!(None::, None::, lines[..]); - assert_lines!(1, 2, lines[..=1]); - assert_lines!(1, -1, lines[..]); - assert_lines!(2, 1, lines[1..=1]); - assert_lines!(-2, -1, lines[2..]); - assert_lines!(-2, None::, lines[2..]); - assert_lines!(2, None::, lines[1..]); - } - - #[tokio::test] - async fn test_fs_read_line_past_eof() { - let os = setup_test_directory().await; - let mut stdout = std::io::stdout(); - let v = serde_json::json!({ - "path": TEST_FILE_PATH, - "mode": "Line", - "start_line": 100, - "end_line": None::, - }); - assert!( - serde_json::from_value::(v) - .unwrap() - .invoke(&os, &mut stdout) - .await - .is_err() - ); - } - - #[test] - #[cfg(unix)] - fn test_format_mode() { - macro_rules! assert_mode { - ($actual:expr, $expected:expr) => { - assert_eq!(format_mode($actual).iter().collect::(), $expected); - }; - } - assert_mode!(0o000, "---------"); - assert_mode!(0o700, "rwx------"); - assert_mode!(0o744, "rwxr--r--"); - assert_mode!(0o641, "rw-r----x"); - } - - #[tokio::test] - async fn test_fs_read_directory_invoke() { - let os = setup_test_directory().await; - let mut stdout = std::io::stdout(); - - // Testing without depth - let v = serde_json::json!({ - "mode": "Directory", - "path": "/", - }); - let output = serde_json::from_value::(v) - .unwrap() - .invoke(&os, &mut stdout) - .await - .unwrap(); - - if let OutputKind::Text(text) = output.output { - assert_eq!(text.lines().collect::>().len(), 4); - } else { - panic!("expected text output"); - } - - // Testing with depth level 1 - let v = serde_json::json!({ - "mode": "Directory", - "path": "/", - "depth": 1, - }); - let output = serde_json::from_value::(v) - .unwrap() - .invoke(&os, &mut stdout) - .await - .unwrap(); - - if let OutputKind::Text(text) = output.output { - let lines = text.lines().collect::>(); - assert_eq!(lines.len(), 7); - assert!( - !lines.iter().any(|l| l.contains("cccc1")), - "directory at depth level 2 should not be included in output" - ); - } else { - panic!("expected text output"); - } - } - - #[tokio::test] - async fn test_fs_read_search_invoke() { - let os = setup_test_directory().await; - let mut stdout = std::io::stdout(); - - macro_rules! invoke_search { - ($value:tt) => {{ - let v = serde_json::json!($value); - let output = serde_json::from_value::(v) - .unwrap() - .invoke(&os, &mut stdout) - .await - .unwrap(); - - if let OutputKind::Text(value) = output.output { - serde_json::from_str::>(&value).unwrap() - } else { - panic!("expected Text output") - } - }}; - } - - let matches = invoke_search!({ - "mode": "Search", - "path": TEST_FILE_PATH, - "pattern": "hello", - }); - assert_eq!(matches.len(), 2); - assert_eq!(matches[0].line_number, 1); - assert_eq!( - matches[0].context, - format!( - "{}1: 1: Hello world!\n{}2: 2: This is line 2\n{}3: 3: asdf\n", - FsSearch::MATCHING_LINE_PREFIX, - FsSearch::CONTEXT_LINE_PREFIX, - FsSearch::CONTEXT_LINE_PREFIX - ) - ); - } - - #[tokio::test] - async fn test_fs_read_non_utf8_binary_file() { - let os = Os::new().await.unwrap(); - let mut stdout = std::io::stdout(); - - let binary_data = vec![0xff, 0xfe, 0xfd, 0xfc, 0xfb, 0xfa, 0xf9, 0xf8]; - let binary_file_path = "/binary_test.dat"; - os.fs.write(binary_file_path, &binary_data).await.unwrap(); - - let v = serde_json::json!({ - "path": binary_file_path, - "mode": "Line" - }); - let output = serde_json::from_value::(v) - .unwrap() - .invoke(&os, &mut stdout) - .await - .unwrap(); - - if let OutputKind::Text(text) = output.output { - assert!(text.contains('�'), "Binary data should contain replacement characters"); - assert_eq!(text.chars().count(), 8, "Should have 8 replacement characters"); - assert!( - text.chars().all(|c| c == '�'), - "All characters should be replacement characters" - ); - } else { - panic!("expected text output"); - } - } - - #[tokio::test] - async fn test_fs_read_latin1_encoded_file() { - let os = Os::new().await.unwrap(); - let mut stdout = std::io::stdout(); - - let latin1_data = vec![99, 97, 102, 233]; // "café" in Latin-1 - let latin1_file_path = "/latin1_test.txt"; - os.fs.write(latin1_file_path, &latin1_data).await.unwrap(); - - let v = serde_json::json!({ - "path": latin1_file_path, - "mode": "Line" - }); - let output = serde_json::from_value::(v) - .unwrap() - .invoke(&os, &mut stdout) - .await - .unwrap(); - - if let OutputKind::Text(text) = output.output { - // Latin-1 byte 233 (é) is invalid UTF-8, so it becomes a replacement character - assert!(text.starts_with("caf"), "Should start with 'caf'"); - assert!( - text.contains('�'), - "Should contain replacement character for invalid UTF-8" - ); - } else { - panic!("expected text output"); - } - } - - #[tokio::test] - async fn test_fs_search_non_utf8_file() { - let os = Os::new().await.unwrap(); - let mut stdout = std::io::stdout(); - - let mut mixed_data = Vec::new(); - mixed_data.extend_from_slice(b"Hello world\n"); - mixed_data.extend_from_slice(&[0xff, 0xfe]); // Invalid UTF-8 bytes - mixed_data.extend_from_slice(b"\nGoodbye world\n"); - - let mixed_file_path = "/mixed_encoding_test.txt"; - os.fs.write(mixed_file_path, &mixed_data).await.unwrap(); - - let v = serde_json::json!({ - "mode": "Search", - "path": mixed_file_path, - "pattern": "hello" - }); - let output = serde_json::from_value::(v) - .unwrap() - .invoke(&os, &mut stdout) - .await - .unwrap(); - - if let OutputKind::Text(value) = output.output { - let matches: Vec = serde_json::from_str(&value).unwrap(); - assert_eq!(matches.len(), 1, "Should find one match for 'hello'"); - assert_eq!(matches[0].line_number, 1, "Match should be on line 1"); - assert!( - matches[0].context.contains("Hello world"), - "Should contain the matched line" - ); - } else { - panic!("expected Text output"); - } - - let v = serde_json::json!({ - "mode": "Search", - "path": mixed_file_path, - "pattern": "goodbye" - }); - let output = serde_json::from_value::(v) - .unwrap() - .invoke(&os, &mut stdout) - .await - .unwrap(); - - if let OutputKind::Text(value) = output.output { - let matches: Vec = serde_json::from_str(&value).unwrap(); - assert_eq!(matches.len(), 1, "Should find one match for 'goodbye'"); - assert!( - matches[0].context.contains("Goodbye world"), - "Should contain the matched line" - ); - } else { - panic!("expected Text output"); - } - } - - #[tokio::test] - async fn test_fs_read_windows1252_encoded_file() { - let os = Os::new().await.unwrap(); - let mut stdout = std::io::stdout(); - - let mut windows1252_data = Vec::new(); - windows1252_data.extend_from_slice(b"Text with "); - windows1252_data.push(0x93); // Left double quotation mark in Windows-1252 - windows1252_data.extend_from_slice(b"smart quotes"); - windows1252_data.push(0x94); // Right double quotation mark in Windows-1252 - - let windows1252_file_path = "/windows1252_test.txt"; - os.fs.write(windows1252_file_path, &windows1252_data).await.unwrap(); - - let v = serde_json::json!({ - "path": windows1252_file_path, - "mode": "Line" - }); - let output = serde_json::from_value::(v) - .unwrap() - .invoke(&os, &mut stdout) - .await - .unwrap(); - - if let OutputKind::Text(text) = output.output { - assert!(text.contains("Text with"), "Should contain readable text"); - assert!(text.contains("smart quotes"), "Should contain readable text"); - assert!( - text.contains('�'), - "Should contain replacement characters for invalid UTF-8" - ); - } else { - panic!("expected text output"); - } - } - - #[tokio::test] - async fn test_fs_search_pattern_with_replacement_chars() { - let os = Os::new().await.unwrap(); - let mut stdout = std::io::stdout(); - - let mut data_with_invalid_utf8 = Vec::new(); - data_with_invalid_utf8.extend_from_slice(b"Line 1: caf"); - data_with_invalid_utf8.push(0xe9); // Invalid UTF-8 byte (Latin-1 é) - data_with_invalid_utf8.extend_from_slice(b"\nLine 2: hello world\n"); - - let invalid_utf8_file_path = "/invalid_utf8_search_test.txt"; - os.fs - .write(invalid_utf8_file_path, &data_with_invalid_utf8) - .await - .unwrap(); - - let v = serde_json::json!({ - "mode": "Search", - "path": invalid_utf8_file_path, - "pattern": "caf" - }); - let output = serde_json::from_value::(v) - .unwrap() - .invoke(&os, &mut stdout) - .await - .unwrap(); - - if let OutputKind::Text(value) = output.output { - let matches: Vec = serde_json::from_str(&value).unwrap(); - assert_eq!(matches.len(), 1, "Should find one match for 'caf'"); - assert_eq!(matches[0].line_number, 1, "Match should be on line 1"); - assert!(matches[0].context.contains("caf"), "Should contain 'caf'"); - } else { - panic!("expected Text output"); - } - } - - #[tokio::test] - async fn test_fs_read_empty_file_with_invalid_utf8() { - let os = Os::new().await.unwrap(); - let mut stdout = std::io::stdout(); - - let invalid_only_data = vec![0xff, 0xfe, 0xfd]; - let invalid_only_file_path = "/invalid_only_test.txt"; - os.fs.write(invalid_only_file_path, &invalid_only_data).await.unwrap(); - - let v = serde_json::json!({ - "path": invalid_only_file_path, - "mode": "Line" - }); - let output = serde_json::from_value::(v) - .unwrap() - .invoke(&os, &mut stdout) - .await - .unwrap(); - - if let OutputKind::Text(text) = output.output { - assert_eq!(text.chars().count(), 3, "Should have 3 replacement characters"); - assert!(text.chars().all(|c| c == '�'), "Should be all replacement characters"); - } else { - panic!("expected text output"); - } - - let v = serde_json::json!({ - "mode": "Search", - "path": invalid_only_file_path, - "pattern": "test" - }); - let output = serde_json::from_value::(v) - .unwrap() - .invoke(&os, &mut stdout) - .await - .unwrap(); - - if let OutputKind::Text(value) = output.output { - let matches: Vec = serde_json::from_str(&value).unwrap(); - assert_eq!( - matches.len(), - 0, - "Should find no matches in file with only invalid UTF-8" - ); - } else { - panic!("expected Text output"); - } - } -} diff --git a/crates/chat-cli/src/cli/chat/tools/fs_write.rs b/crates/chat-cli/src/cli/chat/tools/fs_write.rs deleted file mode 100644 index dbd4fe71f..000000000 --- a/crates/chat-cli/src/cli/chat/tools/fs_write.rs +++ /dev/null @@ -1,1083 +0,0 @@ -use std::io::Write; -use std::path::Path; -use std::sync::LazyLock; - -use crossterm::queue; -use crossterm::style::{ - self, - Color, -}; -use eyre::{ - ContextCompat as _, - Result, - bail, - eyre, -}; -use serde::Deserialize; -use similar::DiffableStr; -use syntect::easy::HighlightLines; -use syntect::highlighting::ThemeSet; -use syntect::parsing::SyntaxSet; -use syntect::util::{ - LinesWithEndings, - as_24_bit_terminal_escaped, -}; -use tracing::{ - error, - warn, -}; - -use super::{ - InvokeOutput, - format_path, - sanitize_path_tool_arg, - supports_truecolor, -}; -use crate::os::Os; - -static SYNTAX_SET: LazyLock = LazyLock::new(SyntaxSet::load_defaults_newlines); -static THEME_SET: LazyLock = LazyLock::new(ThemeSet::load_defaults); - -#[derive(Debug, Clone, Deserialize)] -#[serde(tag = "command")] -pub enum FsWrite { - /// The tool spec should only require `file_text`, but the model sometimes doesn't want to - /// provide it. Thus, including `new_str` as a fallback check, if it's available. - #[serde(rename = "create")] - Create { - path: String, - file_text: Option, - new_str: Option, - summary: Option, - }, - #[serde(rename = "str_replace")] - StrReplace { - path: String, - old_str: String, - new_str: String, - summary: Option, - }, - #[serde(rename = "insert")] - Insert { - path: String, - insert_line: usize, - new_str: String, - summary: Option, - }, - #[serde(rename = "append")] - Append { - path: String, - new_str: String, - summary: Option, - }, -} - -impl FsWrite { - pub async fn invoke(&self, os: &Os, output: &mut impl Write) -> Result { - let cwd = os.env.current_dir()?; - match self { - FsWrite::Create { path, .. } => { - let file_text = self.canonical_create_command_text(); - let path = sanitize_path_tool_arg(os, path); - if let Some(parent) = path.parent() { - os.fs.create_dir_all(parent).await?; - } - - let invoke_description = if os.fs.exists(&path) { - "Replacing: " - } else { - "Creating: " - }; - queue!( - output, - style::Print(invoke_description), - style::SetForegroundColor(Color::Green), - style::Print(format_path(cwd, &path)), - style::ResetColor, - style::Print("\n"), - )?; - - write_to_file(os, path, file_text).await?; - Ok(Default::default()) - }, - FsWrite::StrReplace { - path, old_str, new_str, .. - } => { - let path = sanitize_path_tool_arg(os, path); - let file = os.fs.read_to_string(&path).await?; - let matches = file.match_indices(old_str).collect::>(); - queue!( - output, - style::Print("Updating: "), - style::SetForegroundColor(Color::Green), - style::Print(format_path(cwd, &path)), - style::ResetColor, - style::Print("\n"), - )?; - match matches.len() { - 0 => Err(eyre!("no occurrences of \"{old_str}\" were found")), - 1 => { - let file = file.replacen(old_str, new_str, 1); - os.fs.write(path, file).await?; - Ok(Default::default()) - }, - x => Err(eyre!("{x} occurrences of old_str were found when only 1 is expected")), - } - }, - FsWrite::Insert { - path, - insert_line, - new_str, - .. - } => { - let path = sanitize_path_tool_arg(os, path); - let mut file = os.fs.read_to_string(&path).await?; - queue!( - output, - style::Print("Updating: "), - style::SetForegroundColor(Color::Green), - style::Print(format_path(cwd, &path)), - style::ResetColor, - style::Print("\n"), - )?; - - // Get the index of the start of the line to insert at. - let num_lines = file.lines().enumerate().map(|(i, _)| i + 1).last().unwrap_or(1); - let insert_line = insert_line.clamp(&0, &num_lines); - let mut i = 0; - for _ in 0..*insert_line { - let line_len = &file[i..].find("\n").map_or(file[i..].len(), |i| i + 1); - i += line_len; - } - file.insert_str(i, new_str); - write_to_file(os, &path, file).await?; - Ok(Default::default()) - }, - FsWrite::Append { path, new_str, .. } => { - let path = sanitize_path_tool_arg(os, path); - - queue!( - output, - style::Print("Appending to: "), - style::SetForegroundColor(Color::Green), - style::Print(format_path(cwd, &path)), - style::ResetColor, - style::Print("\n"), - )?; - - let mut file = os.fs.read_to_string(&path).await?; - if !file.ends_with_newline() { - file.push('\n'); - } - file.push_str(new_str); - write_to_file(os, path, file).await?; - Ok(Default::default()) - }, - } - } - - pub fn queue_description(&self, os: &Os, output: &mut impl Write) -> Result<()> { - let cwd = os.env.current_dir()?; - self.print_relative_path(os, output)?; - match self { - FsWrite::Create { path, .. } => { - let file_text = self.canonical_create_command_text(); - let path = sanitize_path_tool_arg(os, path); - let relative_path = format_path(cwd, &path); - let prev = if os.fs.exists(&path) { - let file = os.fs.read_to_string_sync(&path)?; - stylize_output_if_able(os, &path, &file) - } else { - Default::default() - }; - let new = stylize_output_if_able(os, &relative_path, &file_text); - print_diff(output, &prev, &new, 1)?; - - // Display summary as purpose if available after the diff - super::display_purpose(self.get_summary(), output)?; - - Ok(()) - }, - FsWrite::Insert { - path, - insert_line, - new_str, - .. - } => { - let path = sanitize_path_tool_arg(os, path); - let relative_path = format_path(cwd, &path); - let file = os.fs.read_to_string_sync(&path)?; - - // Diff the old with the new by adding extra context around the line being inserted - // at. - let (prefix, start_line, suffix, _) = get_lines_with_context(&file, *insert_line, *insert_line, 3); - let insert_line_content = LinesWithEndings::from(&file) - // don't include any content if insert_line is 0 - .nth(insert_line.checked_sub(1).unwrap_or(usize::MAX)) - .unwrap_or_default(); - let old = [prefix, insert_line_content, suffix].join(""); - let new = [prefix, insert_line_content, new_str, suffix].join(""); - - let old = stylize_output_if_able(os, &relative_path, &old); - let new = stylize_output_if_able(os, &relative_path, &new); - print_diff(output, &old, &new, start_line)?; - - // Display summary as purpose if available after the diff - super::display_purpose(self.get_summary(), output)?; - - Ok(()) - }, - FsWrite::StrReplace { - path, old_str, new_str, .. - } => { - let path = sanitize_path_tool_arg(os, path); - let relative_path = format_path(cwd, &path); - let file = os.fs.read_to_string_sync(&path)?; - let (start_line, _) = match line_number_at(&file, old_str) { - Some((start_line, end_line)) => (start_line, end_line), - _ => (0, 0), - }; - let old_str = stylize_output_if_able(os, &relative_path, old_str); - let new_str = stylize_output_if_able(os, &relative_path, new_str); - print_diff(output, &old_str, &new_str, start_line)?; - - // Display summary as purpose if available after the diff - super::display_purpose(self.get_summary(), output)?; - - Ok(()) - }, - FsWrite::Append { path, new_str, .. } => { - let path = sanitize_path_tool_arg(os, path); - let relative_path = format_path(cwd, &path); - let start_line = os.fs.read_to_string_sync(&path)?.lines().count() + 1; - let file = stylize_output_if_able(os, &relative_path, new_str); - print_diff(output, &Default::default(), &file, start_line)?; - - // Display summary as purpose if available after the diff - super::display_purpose(self.get_summary(), output)?; - - Ok(()) - }, - } - } - - pub async fn validate(&mut self, os: &Os) -> Result<()> { - match self { - FsWrite::Create { path, .. } => { - if path.is_empty() { - bail!("Path must not be empty") - }; - }, - FsWrite::StrReplace { path, .. } | FsWrite::Insert { path, .. } => { - let path = sanitize_path_tool_arg(os, path); - if !path.exists() { - bail!("The provided path must exist in order to replace or insert contents into it") - } - }, - FsWrite::Append { path, new_str, .. } => { - if path.is_empty() { - bail!("Path must not be empty") - }; - if new_str.is_empty() { - bail!("Content to append must not be empty") - }; - }, - } - - Ok(()) - } - - fn print_relative_path(&self, os: &Os, output: &mut impl Write) -> Result<()> { - let cwd = os.env.current_dir()?; - let path = match self { - FsWrite::Create { path, .. } => path, - FsWrite::StrReplace { path, .. } => path, - FsWrite::Insert { path, .. } => path, - FsWrite::Append { path, .. } => path, - }; - // Sanitize the path to handle tilde expansion - let path = sanitize_path_tool_arg(os, path); - let relative_path = format_path(cwd, &path); - queue!( - output, - style::Print("Path: "), - style::SetForegroundColor(Color::Green), - style::Print(&relative_path), - style::ResetColor, - style::Print("\n\n"), - )?; - Ok(()) - } - - /// Returns the text to use for the [FsWrite::Create] command. This is required since we can't - /// rely on the model always providing `file_text`. - fn canonical_create_command_text(&self) -> String { - match self { - FsWrite::Create { file_text, new_str, .. } => match (file_text, new_str) { - (Some(file_text), _) => file_text.clone(), - (None, Some(new_str)) => { - warn!("required field `file_text` is missing, using the provided `new_str` instead"); - new_str.clone() - }, - _ => { - warn!("no content provided for the create command"); - String::new() - }, - }, - _ => String::new(), - } - } - - /// Returns the summary from any variant of the FsWrite enum - fn get_summary(&self) -> Option<&String> { - match self { - FsWrite::Create { summary, .. } => summary.as_ref(), - FsWrite::StrReplace { summary, .. } => summary.as_ref(), - FsWrite::Insert { summary, .. } => summary.as_ref(), - FsWrite::Append { summary, .. } => summary.as_ref(), - } - } -} - -/// Writes `content` to `path`, adding a newline if necessary. -async fn write_to_file(os: &Os, path: impl AsRef, mut content: String) -> Result<()> { - let path_ref = path.as_ref(); - - // Log the path being written to - tracing::debug!("Writing to file: {:?}", path_ref); - - if !content.ends_with_newline() { - content.push('\n'); - } - os.fs.write(path.as_ref(), content).await?; - Ok(()) -} - -/// Returns a prefix/suffix pair before and after the content dictated by `[start_line, end_line]` -/// within `content`. The updated start and end lines containing the original context along with -/// the suffix and prefix are returned. -/// -/// Params: -/// - `start_line` - 1-indexed starting line of the content. -/// - `end_line` - 1-indexed ending line of the content. -/// - `context_lines` - number of lines to include before the start and end. -/// -/// Returns `(prefix, new_start_line, suffix, new_end_line)` -fn get_lines_with_context( - content: &str, - start_line: usize, - end_line: usize, - context_lines: usize, -) -> (&str, usize, &str, usize) { - let line_count = content.lines().count(); - // We want to support end_line being 0, in which case we should be able to set the first line - // as the suffix. - let zero_check_inc = if end_line == 0 { 0 } else { 1 }; - - // Convert to 0-indexing. - let (start_line, end_line) = ( - start_line.saturating_sub(1).clamp(0, line_count - 1), - end_line.saturating_sub(1).clamp(0, line_count - 1), - ); - let new_start_line = 0.max(start_line.saturating_sub(context_lines)); - let new_end_line = (line_count - 1).min(end_line + context_lines); - - // Build prefix - let mut prefix_start = 0; - for line in LinesWithEndings::from(content).take(new_start_line) { - prefix_start += line.len(); - } - let mut prefix_end = prefix_start; - for line in LinesWithEndings::from(&content[prefix_start..]).take(start_line - new_start_line) { - prefix_end += line.len(); - } - - // Build suffix - let mut suffix_start = 0; - for line in LinesWithEndings::from(content).take(end_line + zero_check_inc) { - suffix_start += line.len(); - } - let mut suffix_end = suffix_start; - for line in LinesWithEndings::from(&content[suffix_start..]).take(new_end_line - end_line) { - suffix_end += line.len(); - } - - ( - &content[prefix_start..prefix_end], - new_start_line + 1, - &content[suffix_start..suffix_end], - new_end_line + zero_check_inc, - ) -} - -/// Prints a git-diff style comparison between `old_str` and `new_str`. -/// - `start_line` - 1-indexed line number that `old_str` and `new_str` start at. -fn print_diff( - output: &mut impl Write, - old_str: &StylizedFile, - new_str: &StylizedFile, - start_line: usize, -) -> Result<()> { - let diff = similar::TextDiff::from_lines(&old_str.content, &new_str.content); - - // First, get the gutter width required for both the old and new lines. - let (mut max_old_i, mut max_new_i) = (1, 1); - for change in diff.iter_all_changes() { - if let Some(i) = change.old_index() { - max_old_i = i + start_line; - } - if let Some(i) = change.new_index() { - max_new_i = i + start_line; - } - } - let old_line_num_width = terminal_width_required_for_line_count(max_old_i); - let new_line_num_width = terminal_width_required_for_line_count(max_new_i); - - // Now, print - fn fmt_index(i: Option, start_line: usize) -> String { - match i { - Some(i) => (i + start_line).to_string(), - _ => " ".to_string(), - } - } - for change in diff.iter_all_changes() { - // Define the colors per line. - let (text_color, gutter_bg_color, line_bg_color) = match (change.tag(), new_str.truecolor) { - (similar::ChangeTag::Equal, true) => (style::Color::Reset, new_str.gutter_bg, new_str.line_bg), - (similar::ChangeTag::Delete, true) => ( - style::Color::Reset, - style::Color::Rgb { r: 79, g: 40, b: 40 }, - style::Color::Rgb { r: 36, g: 25, b: 28 }, - ), - (similar::ChangeTag::Insert, true) => ( - style::Color::Reset, - style::Color::Rgb { r: 40, g: 67, b: 43 }, - style::Color::Rgb { r: 24, g: 38, b: 30 }, - ), - (similar::ChangeTag::Equal, false) => (style::Color::Reset, new_str.gutter_bg, new_str.line_bg), - (similar::ChangeTag::Delete, false) => (style::Color::Red, new_str.gutter_bg, new_str.line_bg), - (similar::ChangeTag::Insert, false) => (style::Color::Green, new_str.gutter_bg, new_str.line_bg), - }; - // Define the change tag character to print, if any. - let sign = match change.tag() { - similar::ChangeTag::Equal => " ", - similar::ChangeTag::Delete => "-", - similar::ChangeTag::Insert => "+", - }; - - let old_i_str = fmt_index(change.old_index(), start_line); - let new_i_str = fmt_index(change.new_index(), start_line); - - // Print the gutter and line numbers. - queue!(output, style::SetBackgroundColor(gutter_bg_color))?; - queue!( - output, - style::SetForegroundColor(text_color), - style::Print(sign), - style::Print(" ") - )?; - queue!( - output, - style::Print(format!( - "{:>old_line_num_width$}", - old_i_str, - old_line_num_width = old_line_num_width - )) - )?; - if sign == " " { - queue!(output, style::Print(", "))?; - } else { - queue!(output, style::Print(" "))?; - } - queue!( - output, - style::Print(format!( - "{:>new_line_num_width$}", - new_i_str, - new_line_num_width = new_line_num_width - )) - )?; - // Print the line. - queue!( - output, - style::SetForegroundColor(style::Color::Reset), - style::Print(":"), - style::SetForegroundColor(text_color), - style::SetBackgroundColor(line_bg_color), - style::Print(" "), - style::Print(change), - style::ResetColor, - )?; - } - queue!( - output, - crossterm::terminal::Clear(crossterm::terminal::ClearType::UntilNewLine), - style::Print("\n"), - )?; - - Ok(()) -} - -/// Returns a 1-indexed line number range of the start and end of `needle` inside `file`. -fn line_number_at(file: impl AsRef, needle: impl AsRef) -> Option<(usize, usize)> { - let file = file.as_ref(); - let needle = needle.as_ref(); - if let Some((i, _)) = file.match_indices(needle).next() { - let start = file[..i].matches("\n").count(); - let end = needle.matches("\n").count(); - Some((start + 1, start + end + 1)) - } else { - None - } -} - -/// Returns the number of terminal cells required for displaying line numbers. This is used to -/// determine how many characters the gutter should allocate when displaying line numbers for a -/// text file. -/// -/// For example, `10` and `99` both take 2 cells, whereas `100` and `999` take 3. -fn terminal_width_required_for_line_count(line_count: usize) -> usize { - line_count.to_string().chars().count() -} - -fn stylize_output_if_able(os: &Os, path: impl AsRef, file_text: &str) -> StylizedFile { - if supports_truecolor(os) { - match stylized_file(path, file_text) { - Ok(s) => return s, - Err(err) => { - error!(?err, "unable to syntax highlight the output"); - }, - } - } - StylizedFile { - truecolor: false, - content: file_text.to_string(), - gutter_bg: style::Color::Reset, - line_bg: style::Color::Reset, - } -} - -/// Represents a [String] that is potentially stylized with truecolor escape codes. -#[derive(Debug)] -struct StylizedFile { - /// Whether or not the file is stylized with 24bit color. - truecolor: bool, - /// File content. If [Self::truecolor] is true, then it has escape codes for styling with 24bit - /// color. - content: String, - /// Background color for the gutter. - gutter_bg: style::Color, - /// Background color for the line content. - line_bg: style::Color, -} - -impl Default for StylizedFile { - fn default() -> Self { - Self { - truecolor: false, - content: Default::default(), - gutter_bg: style::Color::Reset, - line_bg: style::Color::Reset, - } - } -} - -/// Returns a 24bit terminal escaped syntax-highlighted [String] of the file pointed to by `path`, -/// if able. -fn stylized_file(path: impl AsRef, file_text: impl AsRef) -> Result { - let ps = &*SYNTAX_SET; - let ts = &*THEME_SET; - - let extension = path - .as_ref() - .extension() - .wrap_err("missing extension")? - .to_str() - .wrap_err("not utf8")?; - - let syntax = ps - .find_syntax_by_extension(extension) - .wrap_err_with(|| format!("missing extension: {}", extension))?; - - let theme = &ts.themes["base16-ocean.dark"]; - let mut highlighter = HighlightLines::new(syntax, theme); - let file_text = file_text.as_ref().lines(); - let mut file = String::new(); - for line in file_text { - let mut ranges = Vec::new(); - ranges.append(&mut highlighter.highlight_line(line, ps)?); - let mut escaped_line = as_24_bit_terminal_escaped(&ranges[..], false); - escaped_line.push_str(&format!( - "{}\n", - crossterm::terminal::Clear(crossterm::terminal::ClearType::UntilNewLine), - )); - file.push_str(&escaped_line); - } - - let (line_bg, gutter_bg) = match (theme.settings.background, theme.settings.gutter) { - (Some(line_bg), Some(gutter_bg)) => (line_bg, gutter_bg), - (Some(line_bg), None) => (line_bg, line_bg), - _ => bail!("missing theme"), - }; - Ok(StylizedFile { - truecolor: true, - content: file, - gutter_bg: syntect_to_crossterm_color(gutter_bg), - line_bg: syntect_to_crossterm_color(line_bg), - }) -} - -fn syntect_to_crossterm_color(syntect: syntect::highlighting::Color) -> style::Color { - style::Color::Rgb { - r: syntect.r, - g: syntect.g, - b: syntect.b, - } -} - -#[cfg(test)] -mod tests { - use super::*; - use crate::cli::chat::util::test::{ - TEST_FILE_CONTENTS, - TEST_FILE_PATH, - setup_test_directory, - }; - - #[test] - fn test_fs_write_deserialize() { - let path = "/my-file"; - let file_text = "hello world"; - - // create - let v = serde_json::json!({ - "path": path, - "command": "create", - "file_text": file_text - }); - let fw = serde_json::from_value::(v).unwrap(); - assert!(matches!(fw, FsWrite::Create { .. })); - - // str_replace - let v = serde_json::json!({ - "path": path, - "command": "str_replace", - "old_str": "prev string", - "new_str": "new string", - }); - let fw = serde_json::from_value::(v).unwrap(); - assert!(matches!(fw, FsWrite::StrReplace { .. })); - - // insert - let v = serde_json::json!({ - "path": path, - "command": "insert", - "insert_line": 3, - "new_str": "new string", - }); - let fw = serde_json::from_value::(v).unwrap(); - assert!(matches!(fw, FsWrite::Insert { .. })); - - // append - let v = serde_json::json!({ - "path": path, - "command": "append", - "new_str": "appended content", - }); - let fw = serde_json::from_value::(v).unwrap(); - assert!(matches!(fw, FsWrite::Append { .. })); - } - - #[test] - fn test_fs_write_deserialize_with_summary() { - let path = "/my-file"; - let file_text = "hello world"; - let summary = "Added hello world content"; - - // create with summary - let v = serde_json::json!({ - "path": path, - "command": "create", - "file_text": file_text, - "summary": summary - }); - let fw = serde_json::from_value::(v).unwrap(); - assert!(matches!(fw, FsWrite::Create { .. })); - if let FsWrite::Create { summary: s, .. } = &fw { - assert_eq!(s.as_ref().unwrap(), summary); - } - - // str_replace with summary - let v = serde_json::json!({ - "path": path, - "command": "str_replace", - "old_str": "prev string", - "new_str": "new string", - "summary": summary - }); - let fw = serde_json::from_value::(v).unwrap(); - assert!(matches!(fw, FsWrite::StrReplace { .. })); - if let FsWrite::StrReplace { summary: s, .. } = &fw { - assert_eq!(s.as_ref().unwrap(), summary); - } - } - - #[tokio::test] - async fn test_fs_write_tool_create() { - let os = setup_test_directory().await; - let mut stdout = std::io::stdout(); - - let file_text = "Hello, world!"; - let v = serde_json::json!({ - "path": "/my-file", - "command": "create", - "file_text": file_text - }); - serde_json::from_value::(v) - .unwrap() - .invoke(&os, &mut stdout) - .await - .unwrap(); - - assert_eq!( - os.fs.read_to_string("/my-file").await.unwrap(), - format!("{}\n", file_text) - ); - - let file_text = "Goodbye, world!\nSee you later"; - let v = serde_json::json!({ - "path": "/my-file", - "command": "create", - "file_text": file_text - }); - serde_json::from_value::(v) - .unwrap() - .invoke(&os, &mut stdout) - .await - .unwrap(); - - // File should end with a newline - assert_eq!( - os.fs.read_to_string("/my-file").await.unwrap(), - format!("{}\n", file_text) - ); - - let file_text = "This is a new string"; - let v = serde_json::json!({ - "path": "/my-file", - "command": "create", - "new_str": file_text - }); - serde_json::from_value::(v) - .unwrap() - .invoke(&os, &mut stdout) - .await - .unwrap(); - - assert_eq!( - os.fs.read_to_string("/my-file").await.unwrap(), - format!("{}\n", file_text) - ); - } - - #[tokio::test] - async fn test_fs_write_tool_str_replace() { - let os = setup_test_directory().await; - let mut stdout = std::io::stdout(); - - // No instances found - let v = serde_json::json!({ - "path": TEST_FILE_PATH, - "command": "str_replace", - "old_str": "asjidfopjaieopr", - "new_str": "1623749", - }); - assert!( - serde_json::from_value::(v) - .unwrap() - .invoke(&os, &mut stdout) - .await - .is_err() - ); - - // Multiple instances found - let v = serde_json::json!({ - "path": TEST_FILE_PATH, - "command": "str_replace", - "old_str": "Hello world!", - "new_str": "Goodbye world!", - }); - assert!( - serde_json::from_value::(v) - .unwrap() - .invoke(&os, &mut stdout) - .await - .is_err() - ); - - // Single instance found and replaced - let v = serde_json::json!({ - "path": TEST_FILE_PATH, - "command": "str_replace", - "old_str": "1: Hello world!", - "new_str": "1: Goodbye world!", - }); - serde_json::from_value::(v) - .unwrap() - .invoke(&os, &mut stdout) - .await - .unwrap(); - assert_eq!( - os.fs - .read_to_string(TEST_FILE_PATH) - .await - .unwrap() - .lines() - .next() - .unwrap(), - "1: Goodbye world!", - "expected the only occurrence to be replaced" - ); - } - - #[tokio::test] - async fn test_fs_write_tool_insert_at_beginning() { - let os = setup_test_directory().await; - let mut stdout = std::io::stdout(); - - let new_str = "1: New first line!\n"; - let v = serde_json::json!({ - "path": TEST_FILE_PATH, - "command": "insert", - "insert_line": 0, - "new_str": new_str, - }); - serde_json::from_value::(v) - .unwrap() - .invoke(&os, &mut stdout) - .await - .unwrap(); - let actual = os.fs.read_to_string(TEST_FILE_PATH).await.unwrap(); - assert_eq!( - format!("{}\n", actual.lines().next().unwrap()), - new_str, - "expected the first line to be updated to '{}'", - new_str - ); - assert_eq!( - actual.lines().skip(1).collect::>(), - TEST_FILE_CONTENTS.lines().collect::>(), - "the rest of the file should not have been updated" - ); - } - - #[tokio::test] - async fn test_fs_write_tool_insert_after_first_line() { - let os = setup_test_directory().await; - let mut stdout = std::io::stdout(); - - let new_str = "2: New second line!\n"; - let v = serde_json::json!({ - "path": TEST_FILE_PATH, - "command": "insert", - "insert_line": 1, - "new_str": new_str, - }); - - serde_json::from_value::(v) - .unwrap() - .invoke(&os, &mut stdout) - .await - .unwrap(); - let actual = os.fs.read_to_string(TEST_FILE_PATH).await.unwrap(); - assert_eq!( - format!("{}\n", actual.lines().nth(1).unwrap()), - new_str, - "expected the second line to be updated to '{}'", - new_str - ); - assert_eq!( - actual.lines().skip(2).collect::>(), - TEST_FILE_CONTENTS.lines().skip(1).collect::>(), - "the rest of the file should not have been updated" - ); - } - - #[tokio::test] - async fn test_fs_write_tool_insert_when_no_newlines_in_file() { - let os = Os::new().await.unwrap(); - let mut stdout = std::io::stdout(); - - let test_file_path = "/file.txt"; - let test_file_contents = "hello there"; - os.fs.write(test_file_path, test_file_contents).await.unwrap(); - - let new_str = "test"; - - // First, test appending - let v = serde_json::json!({ - "path": test_file_path, - "command": "insert", - "insert_line": 1, - "new_str": new_str, - }); - serde_json::from_value::(v) - .unwrap() - .invoke(&os, &mut stdout) - .await - .unwrap(); - let actual = os.fs.read_to_string(test_file_path).await.unwrap(); - assert_eq!(actual, format!("{}{}\n", test_file_contents, new_str)); - - // Then, test prepending - let v = serde_json::json!({ - "path": test_file_path, - "command": "insert", - "insert_line": 0, - "new_str": new_str, - }); - serde_json::from_value::(v) - .unwrap() - .invoke(&os, &mut stdout) - .await - .unwrap(); - let actual = os.fs.read_to_string(test_file_path).await.unwrap(); - assert_eq!(actual, format!("{}{}{}\n", new_str, test_file_contents, new_str)); - } - - #[tokio::test] - async fn test_fs_write_tool_append() { - let os = setup_test_directory().await; - let mut stdout = std::io::stdout(); - - // Test appending to existing file - let content_to_append = "5: Appended line"; - let v = serde_json::json!({ - "path": TEST_FILE_PATH, - "command": "append", - "new_str": content_to_append, - }); - - serde_json::from_value::(v) - .unwrap() - .invoke(&os, &mut stdout) - .await - .unwrap(); - - let actual = os.fs.read_to_string(TEST_FILE_PATH).await.unwrap(); - assert_eq!( - actual, - format!("{}{}\n", TEST_FILE_CONTENTS, content_to_append), - "Content should be appended to the end of the file with a newline added" - ); - - // Test appending to non-existent file (should fail) - let new_file_path = "/new_append_file.txt"; - let content = "This is a new file created by append"; - let v = serde_json::json!({ - "path": new_file_path, - "command": "append", - "new_str": content, - }); - - let result = serde_json::from_value::(v) - .unwrap() - .invoke(&os, &mut stdout) - .await; - - assert!(result.is_err(), "Appending to non-existent file should fail"); - } - - #[test] - fn test_lines_with_context() { - let content = "Hello\nWorld!\nhow\nare\nyou\ntoday?"; - assert_eq!(get_lines_with_context(content, 1, 1, 1), ("", 1, "World!\n", 2)); - assert_eq!(get_lines_with_context(content, 0, 0, 2), ("", 1, "Hello\nWorld!\n", 2)); - assert_eq!( - get_lines_with_context(content, 2, 4, 50), - ("Hello\n", 1, "you\ntoday?", 6) - ); - assert_eq!(get_lines_with_context(content, 4, 100, 2), ("World!\nhow\n", 2, "", 6)); - } - - #[test] - fn test_gutter_width() { - assert_eq!(terminal_width_required_for_line_count(1), 1); - assert_eq!(terminal_width_required_for_line_count(9), 1); - assert_eq!(terminal_width_required_for_line_count(10), 2); - assert_eq!(terminal_width_required_for_line_count(99), 2); - assert_eq!(terminal_width_required_for_line_count(100), 3); - assert_eq!(terminal_width_required_for_line_count(999), 3); - } - - #[tokio::test] - async fn test_fs_write_with_tilde_paths() { - // Create a test context - let os = Os::new().await.unwrap(); - let mut stdout = std::io::stdout(); - - // Get the home directory from the context - let home_dir = os.env.home().unwrap_or_default(); - println!("Test home directory: {:?}", home_dir); - - // Create a file directly in the home directory first to ensure it exists - let home_path = os.fs.chroot_path(&home_dir); - println!("Chrooted home path: {:?}", home_path); - - // Ensure the home directory exists - os.fs.create_dir_all(&home_path).await.unwrap(); - - let v = serde_json::json!({ - "path": "~/file.txt", - "command": "create", - "file_text": "content in home file" - }); - - let result = serde_json::from_value::(v) - .unwrap() - .invoke(&os, &mut stdout) - .await; - - match &result { - Ok(_) => println!("Writing to ~/file.txt succeeded"), - Err(e) => println!("Writing to ~/file.txt failed: {:?}", e), - } - - assert!(result.is_ok(), "Writing to ~/file.txt should succeed"); - - // Verify content was written correctly - let file_path = home_path.join("file.txt"); - println!("Checking file at: {:?}", file_path); - - let content_result = os.fs.read_to_string(&file_path).await; - match &content_result { - Ok(content) => println!("Read content: {:?}", content), - Err(e) => println!("Failed to read content: {:?}", e), - } - - assert!(content_result.is_ok(), "Should be able to read from expanded path"); - assert_eq!(content_result.unwrap(), "content in home file\n"); - - // Test with "~/nested/path/file.txt" to ensure deep paths work - let nested_dir = home_path.join("nested").join("path"); - os.fs.create_dir_all(&nested_dir).await.unwrap(); - - let v = serde_json::json!({ - "path": "~/nested/path/file.txt", - "command": "create", - "file_text": "content in nested path" - }); - - let result = serde_json::from_value::(v) - .unwrap() - .invoke(&os, &mut stdout) - .await; - - assert!(result.is_ok(), "Writing to ~/nested/path/file.txt should succeed"); - - // Verify nested path content - let nested_file_path = nested_dir.join("file.txt"); - let nested_content = os.fs.read_to_string(&nested_file_path).await.unwrap(); - assert_eq!(nested_content, "content in nested path\n"); - } -} diff --git a/crates/chat-cli/src/cli/chat/tools/gh_issue.rs b/crates/chat-cli/src/cli/chat/tools/gh_issue.rs deleted file mode 100644 index 787bf9d28..000000000 --- a/crates/chat-cli/src/cli/chat/tools/gh_issue.rs +++ /dev/null @@ -1,219 +0,0 @@ -use std::collections::{ - HashMap, - VecDeque, -}; -use std::io::Write; - -use crossterm::style::Color; -use crossterm::{ - queue, - style, -}; -use eyre::{ - Result, - WrapErr, - eyre, -}; -use serde::Deserialize; - -use super::super::context::ContextManager; -use super::super::util::issue::IssueCreator; -use super::{ - InvokeOutput, - ToolPermission, -}; -use crate::cli::chat::token_counter::TokenCounter; -use crate::os::Os; - -#[derive(Debug, Clone, Deserialize)] -pub struct GhIssue { - pub title: String, - pub expected_behavior: Option, - pub actual_behavior: Option, - pub steps_to_reproduce: Option, - - #[serde(skip_deserializing)] - pub context: Option, -} - -#[derive(Debug, Clone)] -pub struct GhIssueContext { - pub context_manager: Option, - pub transcript: VecDeque, - pub failed_request_ids: Vec, - pub tool_permissions: HashMap, -} - -/// Max amount of characters to include in the transcript. -const MAX_TRANSCRIPT_CHAR_LEN: usize = 3_000; - -impl GhIssue { - pub async fn invoke(&self, os: &Os, _updates: impl Write) -> Result { - let Some(context) = self.context.as_ref() else { - return Err(eyre!( - "report_issue: Required tool context (GhIssueContext) not set by the program." - )); - }; - - // Prepare additional details from the chat session - let additional_environment = [ - Self::get_chat_settings(context), - Self::get_request_ids(context), - Self::get_context(os, context).await, - ] - .join("\n\n"); - - // Add chat history to the actual behavior text. - let actual_behavior = self.actual_behavior.as_ref().map_or_else( - || Self::get_transcript(context), - |behavior| format!("{behavior}\n\n{}\n", Self::get_transcript(context)), - ); - - let _ = IssueCreator { - title: Some(self.title.clone()), - expected_behavior: self.expected_behavior.clone(), - actual_behavior: Some(actual_behavior), - steps_to_reproduce: self.steps_to_reproduce.clone(), - additional_environment: Some(additional_environment), - } - .create_url(os) - .await - .wrap_err("failed to invoke gh issue tool"); - - Ok(Default::default()) - } - - pub fn set_context(&mut self, context: GhIssueContext) { - self.context = Some(context); - } - - fn get_transcript(context: &GhIssueContext) -> String { - let mut transcript_str = String::from("```\n[chat-transcript]\n"); - let mut is_truncated = false; - let transcript: Vec = context.transcript - .iter() - .rev() // To take last N items - .scan(0, |user_msg_char_count, line| { - if *user_msg_char_count >= MAX_TRANSCRIPT_CHAR_LEN { - is_truncated = true; - return None; - } - let remaining_chars = MAX_TRANSCRIPT_CHAR_LEN - *user_msg_char_count; - let trimmed_line = if line.len() > remaining_chars { - &line[..remaining_chars] - } else { - line - }; - *user_msg_char_count += trimmed_line.len(); - - // backticks will mess up the markdown - let text = trimmed_line.replace("```", r"\```"); - Some(text) - }) - .collect::>() - .into_iter() - .rev() // Now return items to the proper order - .collect(); - - if !transcript.is_empty() { - transcript_str.push_str(&transcript.join("\n\n")); - } else { - transcript_str.push_str("No chat history found."); - } - - if is_truncated { - transcript_str.push_str("\n\n(...truncated)"); - } - transcript_str.push_str("\n```"); - transcript_str - } - - fn get_request_ids(context: &GhIssueContext) -> String { - format!( - "[chat-failed_request_ids]\n{}", - if context.failed_request_ids.is_empty() { - "none".to_string() - } else { - context.failed_request_ids.join("\n") - } - ) - } - - async fn get_context(os: &Os, context: &GhIssueContext) -> String { - let mut os_str = "[chat-context]\n".to_string(); - let Some(os_manager) = &context.context_manager else { - os_str.push_str("No context available."); - return os_str; - }; - - os_str.push_str(&format!("current_profile={}\n", os_manager.current_profile)); - match os_manager.list_profiles(os).await { - Ok(profiles) if !profiles.is_empty() => { - os_str.push_str(&format!("profiles=\n{}\n\n", profiles.join("\n"))); - }, - _ => os_str.push_str("profiles=none\n\n"), - } - - // Context file categories - if os_manager.global_config.paths.is_empty() { - os_str.push_str("global_context=none\n\n"); - } else { - os_str.push_str(&format!( - "global_context=\n{}\n\n", - &os_manager.global_config.paths.join("\n") - )); - } - - if os_manager.profile_config.paths.is_empty() { - os_str.push_str("profile_context=none\n\n"); - } else { - os_str.push_str(&format!( - "profile_context=\n{}\n\n", - &os_manager.profile_config.paths.join("\n") - )); - } - - // Handle context files - match os_manager.get_context_files(os).await { - Ok(context_files) if !context_files.is_empty() => { - os_str.push_str("files=\n"); - let total_size: usize = context_files - .iter() - .map(|(file, content)| { - let size = TokenCounter::count_tokens(content); - os_str.push_str(&format!("{}, {} tkns\n", file, size)); - size - }) - .sum(); - os_str.push_str(&format!("total context size={total_size} tkns")); - }, - _ => os_str.push_str("files=none"), - } - - os_str - } - - fn get_chat_settings(context: &GhIssueContext) -> String { - let mut result_str = "[chat-settings]\n".to_string(); - result_str.push_str("\n\n[chat-trusted_tools]"); - for (tool, permission) in context.tool_permissions.iter() { - result_str.push_str(&format!("\n{tool}={}", permission.trusted)); - } - - result_str - } - - pub fn queue_description(&self, output: &mut impl Write) -> Result<()> { - Ok(queue!( - output, - style::Print("I will prepare a github issue with our conversation history.\n\n"), - style::SetForegroundColor(Color::Green), - style::Print(format!("Title: {}\n", &self.title)), - style::ResetColor - )?) - } - - pub async fn validate(&mut self, _os: &Os) -> Result<()> { - Ok(()) - } -} diff --git a/crates/chat-cli/src/cli/chat/tools/knowledge.rs b/crates/chat-cli/src/cli/chat/tools/knowledge.rs deleted file mode 100644 index dca258fb8..000000000 --- a/crates/chat-cli/src/cli/chat/tools/knowledge.rs +++ /dev/null @@ -1,545 +0,0 @@ -use std::io::Write; - -use crossterm::queue; -use crossterm::style::{ - self, - Color, -}; -use eyre::Result; -use serde::Deserialize; -use tracing::warn; - -use super::{ - InvokeOutput, - OutputKind, -}; -use crate::database::settings::Setting; -use crate::os::Os; -use crate::util::knowledge_store::KnowledgeStore; - -/// The Knowledge tool allows storing and retrieving information across chat sessions. -/// It provides semantic search capabilities for files, directories, and text content. -/// -/// This feature can be enabled/disabled via settings: -/// `q settings chat.enableKnowledge true` -#[derive(Debug, Clone, Deserialize)] -#[serde(tag = "command", rename_all = "lowercase")] -pub enum Knowledge { - Add(KnowledgeAdd), - Remove(KnowledgeRemove), - Clear(KnowledgeClear), - Search(KnowledgeSearch), - Update(KnowledgeUpdate), - Show, - /// Show background operation status - Status, - /// Cancel a background operation - Cancel(KnowledgeCancel), -} - -#[derive(Debug, Clone, Deserialize)] -pub struct KnowledgeAdd { - pub name: String, - pub value: String, -} - -#[derive(Debug, Clone, Deserialize)] -pub struct KnowledgeRemove { - #[serde(default)] - pub name: String, - #[serde(default)] - pub context_id: String, - #[serde(default)] - pub path: String, -} - -#[derive(Debug, Clone, Deserialize)] -pub struct KnowledgeClear { - pub confirm: bool, -} - -#[derive(Debug, Clone, Deserialize)] -pub struct KnowledgeSearch { - pub query: String, - pub context_id: Option, -} - -#[derive(Debug, Clone, Deserialize)] -pub struct KnowledgeUpdate { - #[serde(default)] - pub path: String, - #[serde(default)] - pub context_id: String, - #[serde(default)] - pub name: String, -} - -#[derive(Debug, Clone, Deserialize)] -pub struct KnowledgeCancel { - /// Operation ID to cancel, or "all" to cancel all operations - pub operation_id: String, -} - -impl Knowledge { - /// Checks if the knowledge feature is enabled in settings - pub fn is_enabled(os: &Os) -> bool { - os.database - .settings - .get_bool(Setting::EnabledKnowledge) - .unwrap_or(false) - } - - pub async fn validate(&mut self, os: &Os) -> Result<()> { - match self { - Knowledge::Add(add) => { - // Check if value is intended to be a path (doesn't contain newlines) - if !add.value.contains('\n') { - let path = crate::cli::chat::tools::sanitize_path_tool_arg(os, &add.value); - if !path.exists() { - eyre::bail!("Path '{}' does not exist", add.value); - } - } - Ok(()) - }, - Knowledge::Remove(remove) => { - if remove.name.is_empty() && remove.context_id.is_empty() && remove.path.is_empty() { - eyre::bail!("Please provide at least one of: name, context_id, or path"); - } - // If path is provided, validate it exists - if !remove.path.is_empty() { - let path = crate::cli::chat::tools::sanitize_path_tool_arg(os, &remove.path); - if !path.exists() { - warn!( - "Path '{}' does not exist, will try to remove by path string match", - remove.path - ); - } - } - Ok(()) - }, - Knowledge::Update(update) => { - // Require at least one identifier (context_id or name) - if update.context_id.is_empty() && update.name.is_empty() && update.path.is_empty() { - eyre::bail!("Please provide either context_id or name or path to identify the context to update"); - } - - // Validate the path exists - if !update.path.is_empty() { - let path = crate::cli::chat::tools::sanitize_path_tool_arg(os, &update.path); - if !path.exists() { - eyre::bail!("Path '{}' does not exist", update.path); - } - } - - Ok(()) - }, - Knowledge::Clear(clear) => { - if !clear.confirm { - eyre::bail!("Please confirm clearing knowledge base by setting confirm=true"); - } - Ok(()) - }, - Knowledge::Search(_) => Ok(()), - Knowledge::Show => Ok(()), - Knowledge::Status => Ok(()), - Knowledge::Cancel(_) => Ok(()), - } - } - - pub async fn queue_description(&self, os: &Os, updates: &mut impl Write) -> Result<()> { - match self { - Knowledge::Add(add) => { - queue!( - updates, - style::Print("Adding to knowledge base: "), - style::SetForegroundColor(Color::Green), - style::Print(&add.name), - style::ResetColor, - )?; - - // Check if value is a path or text content - let path = crate::cli::chat::tools::sanitize_path_tool_arg(os, &add.value); - if path.exists() { - let path_type = if path.is_dir() { "directory" } else { "file" }; - queue!( - updates, - style::Print(format!(" ({}: ", path_type)), - style::SetForegroundColor(Color::Green), - style::Print(&add.value), - style::ResetColor, - style::Print(")\n") - )?; - } else { - let preview: String = add.value.chars().take(20).collect(); - if add.value.len() > 20 { - queue!( - updates, - style::Print(" (text: "), - style::SetForegroundColor(Color::Blue), - style::Print(format!("{}...", preview)), - style::ResetColor, - style::Print(")\n") - )?; - } else { - queue!( - updates, - style::Print(" (text: "), - style::SetForegroundColor(Color::Blue), - style::Print(&add.value), - style::ResetColor, - style::Print(")\n") - )?; - } - } - }, - Knowledge::Remove(remove) => { - if !remove.name.is_empty() { - queue!( - updates, - style::Print("Removing from knowledge base by name: "), - style::SetForegroundColor(Color::Green), - style::Print(&remove.name), - style::ResetColor, - )?; - } else if !remove.context_id.is_empty() { - queue!( - updates, - style::Print("Removing from knowledge base by ID: "), - style::SetForegroundColor(Color::Green), - style::Print(&remove.context_id), - style::ResetColor, - )?; - } else if !remove.path.is_empty() { - queue!( - updates, - style::Print("Removing from knowledge base by path: "), - style::SetForegroundColor(Color::Green), - style::Print(&remove.path), - style::ResetColor, - )?; - } else { - queue!( - updates, - style::Print("Removing from knowledge base: "), - style::SetForegroundColor(Color::Yellow), - style::Print("No identifier provided"), - style::ResetColor, - )?; - } - }, - Knowledge::Update(update) => { - queue!(updates, style::Print("Updating knowledge base context"),)?; - - if !update.context_id.is_empty() { - queue!( - updates, - style::Print(" with ID: "), - style::SetForegroundColor(Color::Green), - style::Print(&update.context_id), - style::ResetColor, - )?; - } else if !update.name.is_empty() { - queue!( - updates, - style::Print(" with name: "), - style::SetForegroundColor(Color::Green), - style::Print(&update.name), - style::ResetColor, - )?; - } - - let path = crate::cli::chat::tools::sanitize_path_tool_arg(os, &update.path); - let path_type = if path.is_dir() { "directory" } else { "file" }; - queue!( - updates, - style::Print(format!(" using new {}: ", path_type)), - style::SetForegroundColor(Color::Green), - style::Print(&update.path), - style::ResetColor, - )?; - }, - Knowledge::Clear(_) => { - queue!( - updates, - style::Print("Clearing "), - style::SetForegroundColor(Color::Yellow), - style::Print("all"), - style::ResetColor, - style::Print(" knowledge base entries"), - )?; - }, - Knowledge::Search(search) => { - queue!( - updates, - style::Print("Searching knowledge base for: "), - style::SetForegroundColor(Color::Green), - style::Print(&search.query), - style::ResetColor, - )?; - - if let Some(context_id) = &search.context_id { - queue!( - updates, - style::Print(" in context: "), - style::SetForegroundColor(Color::Green), - style::Print(context_id), - style::ResetColor, - )?; - } else { - queue!(updates, style::Print(" across all contexts"),)?; - } - }, - Knowledge::Show => { - queue!(updates, style::Print("Showing all knowledge base entries"),)?; - }, - Knowledge::Status => { - queue!(updates, style::Print("Checking background operation status"),)?; - }, - Knowledge::Cancel(cancel) => { - queue!( - updates, - style::Print(&format!("Cancelling operation: {}", cancel.operation_id)), - )?; - }, - }; - Ok(()) - } - - pub async fn invoke(&self, os: &Os, _updates: &mut impl Write) -> Result { - // Get the async knowledge store singleton - let async_knowledge_store = KnowledgeStore::get_async_instance().await; - let mut store = async_knowledge_store.lock().await; - - let result = match self { - Knowledge::Add(add) => { - // For path indexing, we'll show a progress message first - let path = crate::cli::chat::tools::sanitize_path_tool_arg(os, &add.value); - let value_to_use = if path.exists() { - path.to_string_lossy().to_string() - } else { - // If it's not a valid path, use the original value (might be text content) - add.value.clone() - }; - - match store.add(&add.name, &value_to_use).await { - Ok(context_id) => format!( - "Added '{}' to knowledge base with ID: {}. Track active jobs in '/knowledge status' with provided id.", - add.name, context_id - ), - Err(e) => format!("Failed to add to knowledge base: {}", e), - } - }, - Knowledge::Remove(remove) => { - if !remove.context_id.is_empty() { - // Remove by ID - match store.remove_by_id(&remove.context_id).await { - Ok(_) => format!("Removed context with ID '{}' from knowledge base", remove.context_id), - Err(e) => format!("Failed to remove context by ID: {}", e), - } - } else if !remove.name.is_empty() { - // Remove by name - match store.remove_by_name(&remove.name).await { - Ok(_) => format!("Removed context with name '{}' from knowledge base", remove.name), - Err(e) => format!("Failed to remove context by name: {}", e), - } - } else if !remove.path.is_empty() { - // Remove by path - let sanitized_path = crate::cli::chat::tools::sanitize_path_tool_arg(os, &remove.path); - match store.remove_by_path(sanitized_path.to_string_lossy().as_ref()).await { - Ok(_) => format!("Removed context with path '{}' from knowledge base", remove.path), - Err(e) => format!("Failed to remove context by path: {}", e), - } - } else { - "Error: No identifier provided for removal. Please specify name, context_id, or path.".to_string() - } - }, - Knowledge::Update(update) => { - // Validate that we have a path and at least one identifier - if update.path.is_empty() { - return Ok(InvokeOutput { - output: OutputKind::Text( - "Error: No path provided for update. Please specify a path to update with.".to_string(), - ), - }); - } - - // Sanitize the path - let path = crate::cli::chat::tools::sanitize_path_tool_arg(os, &update.path); - if !path.exists() { - return Ok(InvokeOutput { - output: OutputKind::Text(format!("Error: Path '{}' does not exist", update.path)), - }); - } - - let sanitized_path = path.to_string_lossy().to_string(); - - // Choose the appropriate update method based on provided identifiers - if !update.context_id.is_empty() { - // Update by ID - match store.update_context_by_id(&update.context_id, &sanitized_path).await { - Ok(_) => format!( - "Updated context with ID '{}' using path '{}'. Track active jobs in '/knowledge status' with provided id.", - update.context_id, update.path - ), - Err(e) => format!("Failed to update context by ID: {}", e), - } - } else if !update.name.is_empty() { - // Update by name - match store.update_context_by_name(&update.name, &sanitized_path).await { - Ok(_) => format!( - "Updated context with name '{}' using path '{}'. Track active jobs in '/knowledge status' with provided id.", - update.name, update.path - ), - Err(e) => format!("Failed to update context by name: {}", e), - } - } else { - // Update by path (if no ID or name provided) - match store.update_by_path(&sanitized_path).await { - Ok(_) => format!( - "Updated context with path '{}'. Track active jobs in '/knowledge status' with provided id.", - update.path - ), - Err(e) => format!("Failed to update context by path: {}", e), - } - } - }, - Knowledge::Clear(_) => store - .clear() - .await - .unwrap_or_else(|e| format!("Failed to clear knowledge base: {}", e)), - Knowledge::Search(search) => { - // Only use a spinner for search, not a full progress bar - let results = store.search(&search.query, search.context_id.as_deref()).await; - match results { - Ok(results) => { - if results.is_empty() { - "No matching entries found in knowledge base".to_string() - } else { - let mut output = String::from("Search results:\n"); - for result in results { - if let Some(text) = result.text() { - output.push_str(&format!("- {}\n", text)); - } - } - output - } - }, - Err(e) => format!("Search failed: {}", e), - } - }, - Knowledge::Show => { - let contexts = store.get_all().await; - match contexts { - Ok(contexts) => { - if contexts.is_empty() { - "No knowledge base entries found".to_string() - } else { - let mut output = String::from("Knowledge base entries:\n"); - for context in contexts { - output.push_str(&format!("- ID: {}\n Name: {}\n Description: {}\n Persistent: {}\n Created: {}\n Last Updated: {}\n Items: {}\n\n", - context.id, - context.name, - context.description, - context.persistent, - context.created_at.format("%Y-%m-%d %H:%M:%S"), - context.updated_at.format("%Y-%m-%d %H:%M:%S"), - context.item_count - )); - } - output - } - }, - Err(e) => format!("Failed to get knowledge base entries: {}", e), - } - }, - Knowledge::Status => { - match store.get_status_data().await { - Ok(status_data) => { - // Format the status data for display (same logic as knowledge command) - Self::format_status_display(&status_data) - }, - Err(e) => format!("Failed to get status: {}", e), - } - }, - Knowledge::Cancel(cancel) => store - .cancel_operation(Some(&cancel.operation_id)) - .await - .unwrap_or_else(|e| format!("Failed to cancel operation: {}", e)), - }; - - Ok(InvokeOutput { - output: OutputKind::Text(result), - }) - } - - /// Format status data for display (UI rendering responsibility) - fn format_status_display(status: &semantic_search_client::SystemStatus) -> String { - let mut status_lines = Vec::new(); - - // Show context summary - status_lines.push(format!( - "Total contexts: {} ({} persistent, {} volatile)", - status.total_contexts, status.persistent_contexts, status.volatile_contexts - )); - - if status.operations.is_empty() { - status_lines.push("No active operations".to_string()); - return status_lines.join("\n"); - } - - status_lines.push("Active Operations:".to_string()); - status_lines.push(format!( - "Queue Status: {} active, {} waiting (max {} concurrent)", - status.active_count, status.waiting_count, status.max_concurrent - )); - - for op in &status.operations { - let formatted_operation = Self::format_operation_display(op); - status_lines.push(formatted_operation); - } - - status_lines.join("\n") - } - - /// Format a single operation for display (LLM-friendly data format) - fn format_operation_display(op: &semantic_search_client::OperationStatus) -> String { - let elapsed = op.started_at.elapsed().unwrap_or_default(); - - let status_info = if op.is_cancelled { - "Status: Cancelled".to_string() - } else if op.is_failed { - format!("Status: Failed - {}", op.message) - } else if op.is_waiting { - format!("Status: Waiting - {}", op.message) - } else if op.total > 0 { - let percentage = (op.current as f64 / op.total as f64 * 100.0) as u8; - format!( - "Status: In Progress - {}% ({}/{}) - {}", - percentage, op.current, op.total, op.message - ) - } else { - format!("Status: In Progress - {}", op.message) - }; - - let operation_desc = op.operation_type.display_name(); - - // Format with conditional elapsed time and ETA - if op.is_cancelled || op.is_failed { - format!( - "Operation ID: {} | Type: {} | {}", - op.short_id, operation_desc, status_info - ) - } else { - let mut time_info = format!("Elapsed: {}s", elapsed.as_secs()); - - if let Some(eta) = op.eta { - time_info.push_str(&format!(" | ETA: {}s", eta.as_secs())); - } - - format!( - "Operation ID: {} | Type: {} | {} | {}", - op.short_id, operation_desc, status_info, time_info - ) - } - } -} diff --git a/crates/chat-cli/src/cli/chat/tools/mod.rs b/crates/chat-cli/src/cli/chat/tools/mod.rs deleted file mode 100644 index 1d03ea629..000000000 --- a/crates/chat-cli/src/cli/chat/tools/mod.rs +++ /dev/null @@ -1,523 +0,0 @@ -pub mod custom_tool; -pub mod execute; -pub mod fs_read; -pub mod fs_write; -pub mod gh_issue; -pub mod knowledge; -pub mod thinking; -pub mod use_aws; - -use std::collections::{ - HashMap, - HashSet, -}; -use std::io::Write; -use std::path::{ - Path, - PathBuf, -}; - -use crossterm::queue; -use crossterm::style::{ - self, - Color, - Stylize, -}; -use custom_tool::CustomTool; -use execute::ExecuteCommand; -use eyre::Result; -use fs_read::FsRead; -use fs_write::FsWrite; -use gh_issue::GhIssue; -use knowledge::Knowledge; -use serde::{ - Deserialize, - Serialize, -}; -use thinking::Thinking; -use use_aws::UseAws; - -use super::consts::MAX_TOOL_RESPONSE_SIZE; -use super::util::images::RichImageBlocks; -use crate::os::Os; - -/// Represents an executable tool use. -#[allow(clippy::large_enum_variant)] -#[derive(Debug, Clone)] -pub enum Tool { - FsRead(FsRead), - FsWrite(FsWrite), - ExecuteCommand(ExecuteCommand), - UseAws(UseAws), - Custom(CustomTool), - GhIssue(GhIssue), - Knowledge(Knowledge), - Thinking(Thinking), -} - -impl Tool { - /// The display name of a tool - pub fn display_name(&self) -> String { - match self { - Tool::FsRead(_) => "fs_read", - Tool::FsWrite(_) => "fs_write", - #[cfg(windows)] - Tool::ExecuteCommand(_) => "execute_cmd", - #[cfg(not(windows))] - Tool::ExecuteCommand(_) => "execute_bash", - Tool::UseAws(_) => "use_aws", - Tool::Custom(custom_tool) => &custom_tool.name, - Tool::GhIssue(_) => "gh_issue", - Tool::Knowledge(_) => "knowledge", - Tool::Thinking(_) => "thinking (prerelease)", - } - .to_owned() - } - - /// Whether or not the tool should prompt the user to accept before [Self::invoke] is called. - pub fn requires_acceptance(&self, _os: &Os) -> bool { - match self { - Tool::FsRead(_) => false, - Tool::FsWrite(_) => true, - Tool::ExecuteCommand(execute_command) => execute_command.requires_acceptance(), - Tool::UseAws(use_aws) => use_aws.requires_acceptance(), - Tool::Custom(_) => true, - Tool::GhIssue(_) => false, - Tool::Knowledge(_) => false, - Tool::Thinking(_) => false, - } - } - - /// Invokes the tool asynchronously - pub async fn invoke(&self, os: &Os, stdout: &mut impl Write) -> Result { - match self { - Tool::FsRead(fs_read) => fs_read.invoke(os, stdout).await, - Tool::FsWrite(fs_write) => fs_write.invoke(os, stdout).await, - Tool::ExecuteCommand(execute_command) => execute_command.invoke(stdout).await, - Tool::UseAws(use_aws) => use_aws.invoke(os, stdout).await, - Tool::Custom(custom_tool) => custom_tool.invoke(os, stdout).await, - Tool::GhIssue(gh_issue) => gh_issue.invoke(os, stdout).await, - Tool::Knowledge(knowledge) => knowledge.invoke(os, stdout).await, - Tool::Thinking(think) => think.invoke(stdout).await, - } - } - - /// Queues up a tool's intention in a human readable format - pub async fn queue_description(&self, os: &Os, output: &mut impl Write) -> Result<()> { - match self { - Tool::FsRead(fs_read) => fs_read.queue_description(os, output).await, - Tool::FsWrite(fs_write) => fs_write.queue_description(os, output), - Tool::ExecuteCommand(execute_command) => execute_command.queue_description(output), - Tool::UseAws(use_aws) => use_aws.queue_description(output), - Tool::Custom(custom_tool) => custom_tool.queue_description(output), - Tool::GhIssue(gh_issue) => gh_issue.queue_description(output), - Tool::Knowledge(knowledge) => knowledge.queue_description(os, output).await, - Tool::Thinking(thinking) => thinking.queue_description(output), - } - } - - /// Validates the tool with the arguments supplied - pub async fn validate(&mut self, os: &Os) -> Result<()> { - match self { - Tool::FsRead(fs_read) => fs_read.validate(os).await, - Tool::FsWrite(fs_write) => fs_write.validate(os).await, - Tool::ExecuteCommand(execute_command) => execute_command.validate(os).await, - Tool::UseAws(use_aws) => use_aws.validate(os).await, - Tool::Custom(custom_tool) => custom_tool.validate(os).await, - Tool::GhIssue(gh_issue) => gh_issue.validate(os).await, - Tool::Knowledge(knowledge) => knowledge.validate(os).await, - Tool::Thinking(think) => think.validate(os).await, - } - } - - /// Returns additional information about the tool if available - pub fn get_additional_info(&self) -> Option { - match self { - Tool::UseAws(use_aws) => Some(use_aws.get_additional_info()), - // Add other tool types here as they implement get_additional_info() - _ => None, - } - } -} - -#[derive(Debug, Clone)] -pub struct ToolPermission { - pub trusted: bool, -} - -#[derive(Debug, Clone)] -/// Holds overrides for tool permissions. -/// Tools that do not have an associated ToolPermission should use -/// their default logic to determine to permission. -pub struct ToolPermissions { - // We need this field for any stragglers - pub trust_all: bool, - pub permissions: HashMap, - // Store pending trust-tool patterns for MCP tools that may be loaded later - pub pending_trusted_tools: HashSet, -} - -impl ToolPermissions { - pub fn new(capacity: usize) -> Self { - Self { - trust_all: false, - permissions: HashMap::with_capacity(capacity), - pending_trusted_tools: HashSet::new(), - } - } - - pub fn is_trusted(&mut self, tool_name: &str) -> bool { - // Check if we should trust from pending patterns first - if self.should_trust_from_pending(tool_name) { - self.trust_tool(tool_name); - self.pending_trusted_tools.remove(tool_name); - } - - self.trust_all || self.permissions.get(tool_name).is_some_and(|perm| perm.trusted) - } - - /// Returns a label to describe the permission status for a given tool. - pub fn display_label(&mut self, tool_name: &str) -> String { - let is_trusted = self.is_trusted(tool_name); - let has_setting = self.has(tool_name) || self.trust_all; - - match (has_setting, is_trusted) { - (true, true) => format!(" {}", "trusted".dark_green().bold()), - (true, false) => format!(" {}", "not trusted".dark_grey()), - _ => self.default_permission_label(tool_name), - } - } - - pub fn trust_tool(&mut self, tool_name: &str) { - self.permissions - .insert(tool_name.to_string(), ToolPermission { trusted: true }); - } - - pub fn untrust_tool(&mut self, tool_name: &str) { - self.trust_all = false; - self.pending_trusted_tools.remove(tool_name); - self.permissions - .insert(tool_name.to_string(), ToolPermission { trusted: false }); - } - - pub fn reset(&mut self) { - self.trust_all = false; - self.permissions.clear(); - self.pending_trusted_tools.clear(); - } - - pub fn reset_tool(&mut self, tool_name: &str) { - self.trust_all = false; - self.permissions.remove(tool_name); - self.pending_trusted_tools.remove(tool_name); - } - - /// Add a pending trust pattern for tools that may be loaded later - pub fn add_pending_trust_tool(&mut self, pattern: String) { - self.pending_trusted_tools.insert(pattern); - } - - /// Check if a tool should be trusted based on preceding trust declarations - pub fn should_trust_from_pending(&self, tool_name: &str) -> bool { - // Check for exact match - self.pending_trusted_tools.contains(tool_name) - } - - pub fn has(&mut self, tool_name: &str) -> bool { - // Check if we should trust from pending tools first - if self.should_trust_from_pending(tool_name) { - self.trust_tool(tool_name); - self.pending_trusted_tools.remove(tool_name); - } - - self.permissions.contains_key(tool_name) - } - - /// Provide default permission labels for the built-in set of tools. - // This "static" way avoids needing to construct a tool instance. - fn default_permission_label(&self, tool_name: &str) -> String { - let label = match tool_name { - "fs_read" => "trusted".dark_green().bold(), - "fs_write" => "not trusted".dark_grey(), - #[cfg(not(windows))] - "execute_bash" => "trust read-only commands".dark_grey(), - #[cfg(windows)] - "execute_cmd" => "trust read-only commands".dark_grey(), - "use_aws" => "trust read-only commands".dark_grey(), - "report_issue" => "trusted".dark_green().bold(), - "knowledge" => "trusted".dark_green().bold(), - "thinking" => "trusted (prerelease)".dark_green().bold(), - _ if self.trust_all => "trusted".dark_grey().bold(), - _ => "not trusted".dark_grey(), - }; - - format!("{} {label}", "*".reset()) - } -} - -/// A tool specification to be sent to the model as part of a conversation. Maps to -/// [BedrockToolSpecification]. -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct ToolSpec { - pub name: String, - pub description: String, - #[serde(alias = "inputSchema")] - pub input_schema: InputSchema, - #[serde(skip_serializing, default = "tool_origin")] - pub tool_origin: ToolOrigin, -} - -#[derive(Debug, Clone, Eq, PartialEq, Hash)] -pub enum ToolOrigin { - Native, - McpServer(String), -} - -impl<'de> Deserialize<'de> for ToolOrigin { - fn deserialize(deserializer: D) -> Result - where - D: serde::Deserializer<'de>, - { - let s = String::deserialize(deserializer)?; - if s == "native___" { - Ok(ToolOrigin::Native) - } else { - Ok(ToolOrigin::McpServer(s)) - } - } -} - -impl Serialize for ToolOrigin { - fn serialize(&self, serializer: S) -> Result - where - S: serde::Serializer, - { - match self { - ToolOrigin::Native => serializer.serialize_str("native___"), - ToolOrigin::McpServer(server) => serializer.serialize_str(server), - } - } -} - -impl std::fmt::Display for ToolOrigin { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { - ToolOrigin::Native => write!(f, "Built-in"), - ToolOrigin::McpServer(server) => write!(f, "{} (MCP)", server), - } - } -} - -fn tool_origin() -> ToolOrigin { - ToolOrigin::Native -} - -#[derive(Debug, Clone)] -pub struct QueuedTool { - pub id: String, - pub name: String, - pub accepted: bool, - pub tool: Tool, -} - -/// The schema specification describing a tool's fields. -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct InputSchema(pub serde_json::Value); - -/// The output received from invoking a [Tool]. -#[derive(Debug, Default)] -pub struct InvokeOutput { - pub output: OutputKind, -} - -impl InvokeOutput { - pub fn as_str(&self) -> &str { - match &self.output { - OutputKind::Text(s) => s.as_str(), - OutputKind::Json(j) => j.as_str().unwrap_or_default(), - OutputKind::Images(_) => "", - } - } -} - -#[non_exhaustive] -#[derive(Debug)] -pub enum OutputKind { - Text(String), - Json(serde_json::Value), - Images(RichImageBlocks), -} - -impl Default for OutputKind { - fn default() -> Self { - Self::Text(String::new()) - } -} - -/// Performs tilde expansion and other required sanitization modifications for handling tool use -/// path arguments. -/// -/// Required since path arguments are defined by the model. -#[allow(dead_code)] -pub fn sanitize_path_tool_arg(os: &Os, path: impl AsRef) -> PathBuf { - let mut res = PathBuf::new(); - // Expand `~` only if it is the first part. - let mut path = path.as_ref().components(); - match path.next() { - Some(p) if p.as_os_str() == "~" => { - res.push(os.env.home().unwrap_or_default()); - }, - Some(p) => res.push(p), - None => return res, - } - for p in path { - res.push(p); - } - // For testing scenarios, we need to make sure paths are appropriately handled in chroot test - // file systems since they are passed directly from the model. - os.fs.chroot_path(res) -} - -/// Converts `path` to a relative path according to the current working directory `cwd`. -fn absolute_to_relative(cwd: impl AsRef, path: impl AsRef) -> Result { - let cwd = cwd.as_ref().canonicalize()?; - let path = path.as_ref().canonicalize()?; - let mut cwd_parts = cwd.components().peekable(); - let mut path_parts = path.components().peekable(); - - // Skip common prefix - while let (Some(a), Some(b)) = (cwd_parts.peek(), path_parts.peek()) { - if a == b { - cwd_parts.next(); - path_parts.next(); - } else { - break; - } - } - - // ".." for any uncommon parts, then just append the rest of the path. - let mut relative = PathBuf::new(); - for _ in cwd_parts { - relative.push(".."); - } - for part in path_parts { - relative.push(part); - } - - Ok(relative) -} - -/// Small helper for formatting the path as a relative path, if able. -fn format_path(cwd: impl AsRef, path: impl AsRef) -> String { - absolute_to_relative(cwd, path.as_ref()) - .map(|p| p.to_string_lossy().to_string()) - // If we have three consecutive ".." then it should probably just stay as an absolute path. - .map(|p| { - let three_up = format!("..{}..{}..", std::path::MAIN_SEPARATOR, std::path::MAIN_SEPARATOR); - if p.starts_with(&three_up) { - path.as_ref().to_string_lossy().to_string() - } else { - p - } - }) - .unwrap_or(path.as_ref().to_string_lossy().to_string()) -} - -fn supports_truecolor(os: &Os) -> bool { - // Simple override to disable truecolor since shell_color doesn't use Context. - !os.env.get("Q_DISABLE_TRUECOLOR").is_ok_and(|s| !s.is_empty()) - && shell_color::get_color_support().contains(shell_color::ColorSupport::TERM24BIT) -} - -/// Helper function to display a purpose if available (for execute commands) -pub fn display_purpose(purpose: Option<&String>, updates: &mut impl Write) -> Result<()> { - if let Some(purpose) = purpose { - queue!( - updates, - style::Print(super::CONTINUATION_LINE), - style::Print("\n"), - style::Print(super::PURPOSE_ARROW), - style::SetForegroundColor(Color::Blue), - style::Print("Purpose: "), - style::ResetColor, - style::Print(purpose), - style::Print("\n"), - )?; - } - Ok(()) -} - -#[cfg(test)] -mod tests { - use std::path::MAIN_SEPARATOR; - - use super::*; - use crate::os::ACTIVE_USER_HOME; - - #[tokio::test] - async fn test_tilde_path_expansion() { - let os = Os::new().await.unwrap(); - - let actual = sanitize_path_tool_arg(&os, "~"); - let expected_home = os.env.home().unwrap_or_default(); - assert_eq!(actual, os.fs.chroot_path(&expected_home), "tilde should expand"); - let actual = sanitize_path_tool_arg(&os, "~/hello"); - assert_eq!( - actual, - os.fs.chroot_path(expected_home.join("hello")), - "tilde should expand" - ); - let actual = sanitize_path_tool_arg(&os, "/~"); - assert_eq!( - actual, - os.fs.chroot_path("/~"), - "tilde should not expand when not the first component" - ); - } - - #[tokio::test] - async fn test_format_path() { - async fn assert_paths(cwd: &str, path: &str, expected: &str) { - let os = Os::new().await.unwrap(); - let cwd = sanitize_path_tool_arg(&os, cwd); - let path = sanitize_path_tool_arg(&os, path); - let fs = os.fs; - fs.create_dir_all(&cwd).await.unwrap(); - fs.create_dir_all(&path).await.unwrap(); - - let formatted = format_path(&cwd, &path); - - if Path::new(expected).is_absolute() { - // If the expected path is relative, we need to ensure it is relative to the cwd. - let expected = fs.chroot_path_str(expected); - - assert!(formatted == expected, "Expected '{}' to be '{}'", formatted, expected); - - return; - } - - assert!( - formatted.contains(expected), - "Expected '{}' to be '{}'", - formatted, - expected - ); - } - - // Test relative path from src to Downloads (sibling directories) - assert_paths( - format!("{ACTIVE_USER_HOME}{MAIN_SEPARATOR}src").as_str(), - format!("{ACTIVE_USER_HOME}{MAIN_SEPARATOR}Downloads").as_str(), - format!("..{MAIN_SEPARATOR}Downloads").as_str(), - ) - .await; - - // Test absolute path that should stay absolute (going up too many levels) - assert_paths( - format!("{ACTIVE_USER_HOME}{MAIN_SEPARATOR}projects{MAIN_SEPARATOR}some{MAIN_SEPARATOR}project").as_str(), - format!("{ACTIVE_USER_HOME}{MAIN_SEPARATOR}other").as_str(), - format!("{ACTIVE_USER_HOME}{MAIN_SEPARATOR}other").as_str(), - ) - .await; - } -} diff --git a/crates/chat-cli/src/cli/chat/tools/thinking.rs b/crates/chat-cli/src/cli/chat/tools/thinking.rs deleted file mode 100644 index 8c101909a..000000000 --- a/crates/chat-cli/src/cli/chat/tools/thinking.rs +++ /dev/null @@ -1,70 +0,0 @@ -use std::io::Write; - -use crossterm::queue; -use crossterm::style::{ - self, - Color, -}; -use eyre::Result; -use serde::Deserialize; - -use super::{ - InvokeOutput, - OutputKind, -}; -use crate::database::settings::Setting; -use crate::os::Os; - -/// The Think tool allows the model to reason through complex problems during response generation. -/// It provides a dedicated space for the model to process information from tool call results, -/// navigate complex decision trees, and improve the quality of responses in multi-step scenarios. -/// -/// This is a beta feature that can be enabled/disabled via settings: -/// `q settings chat.enableThinking true` -#[derive(Debug, Clone, Deserialize)] -pub struct Thinking { - /// The thought content that the model wants to process - pub thought: String, -} - -impl Thinking { - /// Checks if the thinking feature is enabled in settings - pub fn is_enabled(os: &Os) -> bool { - os.database.settings.get_bool(Setting::EnabledThinking).unwrap_or(false) - } - - /// Queues up a description of the think tool for the user - pub fn queue_description(&self, output: &mut impl Write) -> Result<()> { - // Only show a description if there's actual thought content - if !self.thought.trim().is_empty() { - // Show a preview of the thought that will be displayed - queue!( - output, - style::SetForegroundColor(Color::Blue), - style::Print("I'll share my reasoning process: "), - style::SetForegroundColor(Color::Reset), - style::Print(&self.thought), - style::Print("\n") - )?; - } - Ok(()) - } - - /// Invokes the think tool. This doesn't actually perform any system operations, - /// it's purely for the model's internal reasoning process. - pub async fn invoke(&self, _updates: impl Write) -> Result { - // The think tool always returns an empty output because: - // 1. When enabled with content: We've already shown the thought in queue_description - // 2. When disabled or empty: Nothing should be shown - Ok(InvokeOutput { - output: OutputKind::Text(String::new()), - }) - } - - /// Validates the thought - accepts empty thoughts - pub async fn validate(&mut self, _os: &crate::os::Os) -> Result<()> { - // We accept empty thoughts - they'll just be ignored - // This makes the tool more robust and prevents errors from blocking the model - Ok(()) - } -} diff --git a/crates/chat-cli/src/cli/chat/tools/tool_index.json b/crates/chat-cli/src/cli/chat/tools/tool_index.json deleted file mode 100644 index 448e4c892..000000000 --- a/crates/chat-cli/src/cli/chat/tools/tool_index.json +++ /dev/null @@ -1,236 +0,0 @@ -{ - "dummy": { - "name": "dummy", - "description": "This is a dummy tool. If you are seeing this that means the tool associated with this tool call is not in the list of available tools. This could be because a wrong tool name was supplied or the list of tools has changed since the conversation has started. Do not show this when user asks you to list tools.", - "input_schema": { - "type": "object", - "properties": {}, - "required": [] - } - }, - "execute_bash": { - "name": "execute_bash", - "description": "Execute the specified bash command.", - "input_schema": { - "type": "object", - "properties": { - "command": { - "type": "string", - "description": "Bash command to execute" - }, - "summary": { - "type": "string", - "description": "A brief explanation of what the command does" - } - }, - "required": ["command"] - } - }, - "fs_read": { - "name": "fs_read", - "description": "Tool for reading files (for example, `cat -n`), directories (for example, `ls -la`) and images. If user has supplied paths that appear to be leading to images, you should use this tool right away using Image mode. The behavior of this tool is determined by the `mode` parameter. The available modes are:\n- line: Show lines in a file, given by an optional `start_line` and optional `end_line`.\n- directory: List directory contents. Content is returned in the \"long format\" of ls (that is, `ls -la`).\n- search: Search for a pattern in a file. The pattern is a string. The matching is case insensitive.\n\nExample Usage:\n1. Read all lines from a file: command=\"line\", path=\"/path/to/file.txt\"\n2. Read the last 5 lines from a file: command=\"line\", path=\"/path/to/file.txt\", start_line=-5\n3. List the files in the home directory: command=\"line\", path=\"~\"\n4. Recursively list files in a directory to a max depth of 2: command=\"line\", path=\"/path/to/directory\", depth=2\n5. Search for all instances of \"test\" in a file: command=\"search\", path=\"/path/to/file.txt\", pattern=\"test\"\n", - "input_schema": { - "type": "object", - "properties": { - "path": { - "description": "Path to the file or directory. The path should be absolute, or otherwise start with ~ for the user's home.", - "type": "string" - }, - "image_paths": { - "description": "List of paths to the images. This is currently supported by the Image mode.", - "type": "array", - "items": { - "type": "string" - } - }, - "mode": { - "type": "string", - "enum": [ - "Line", - "Directory", - "Search", - "Image" - ], - "description": "The mode to run in: `Line`, `Directory`, `Search`. `Line` and `Search` are only for text files, and `Directory` is only for directories. `Image` is for image files, in this mode `image_paths` is required." - }, - "start_line": { - "type": "integer", - "description": "Starting line number (optional, for Line mode). A negative index represents a line number starting from the end of the file.", - "default": 1 - }, - "end_line": { - "type": "integer", - "description": "Ending line number (optional, for Line mode). A negative index represents a line number starting from the end of the file.", - "default": -1 - }, - "pattern": { - "type": "string", - "description": "Pattern to search for (required, for Search mode). Case insensitive. The pattern matching is performed per line." - }, - "context_lines": { - "type": "integer", - "description": "Number of context lines around search results (optional, for Search mode)", - "default": 2 - }, - "depth": { - "type": "integer", - "description": "Depth of a recursive directory listing (optional, for Directory mode)", - "default": 0 - } - }, - "required": ["path", "mode"] - } - }, - "fs_write": { - "name": "fs_write", - "description": "A tool for creating and editing files\n * The `create` command will override the file at `path` if it already exists as a file, and otherwise create a new file\n * The `append` command will add content to the end of an existing file, automatically adding a newline if the file doesn't end with one. The file must exist.\n Notes for using the `str_replace` command:\n * The `old_str` parameter should match EXACTLY one or more consecutive lines from the original file. Be mindful of whitespaces!\n * If the `old_str` parameter is not unique in the file, the replacement will not be performed. Make sure to include enough context in `old_str` to make it unique\n * The `new_str` parameter should contain the edited lines that should replace the `old_str`.", - "input_schema": { - "type": "object", - "properties": { - "command": { - "type": "string", - "enum": ["create", "str_replace", "insert", "append"], - "description": "The commands to run. Allowed options are: `create`, `str_replace`, `insert`, `append`." - }, - "file_text": { - "description": "Required parameter of `create` command, with the content of the file to be created.", - "type": "string" - }, - "insert_line": { - "description": "Required parameter of `insert` command. The `new_str` will be inserted AFTER the line `insert_line` of `path`.", - "type": "integer" - }, - "new_str": { - "description": "Required parameter of `str_replace` command containing the new string. Required parameter of `insert` command containing the string to insert. Required parameter of `append` command containing the content to append to the file.", - "type": "string" - }, - "old_str": { - "description": "Required parameter of `str_replace` command containing the string in `path` to replace.", - "type": "string" - }, - "path": { - "description": "Absolute path to file or directory, e.g. `/repo/file.py` or `/repo`.", - "type": "string" - }, - "summary": { - "description": "A brief explanation of what the file change does or why it's being made.", - "type": "string" - } - }, - "required": ["command", "path"] - } - }, - "use_aws": { - "name": "use_aws", - "description": "Make an AWS CLI api call with the specified service, operation, and parameters. All arguments MUST conform to the AWS CLI specification. Should the output of the invocation indicate a malformed command, invoke help to obtain the the correct command.", - "input_schema": { - "type": "object", - "properties": { - "service_name": { - "type": "string", - "description": "The name of the AWS service. If you want to query s3, you should use s3api if possible." - }, - "operation_name": { - "type": "string", - "description": "The name of the operation to perform." - }, - "parameters": { - "type": "object", - "description": "The parameters for the operation. The parameter keys MUST conform to the AWS CLI specification. You should prefer to use JSON Syntax over shorthand syntax wherever possible. For parameters that are booleans, prioritize using flags with no value. Denote these flags with flag names as key and an empty string as their value. You should also prefer kebab case." - }, - "region": { - "type": "string", - "description": "Region name for calling the operation on AWS." - }, - "profile_name": { - "type": "string", - "description": "Optional: AWS profile name to use from ~/.aws/credentials. Defaults to default profile if not specified." - }, - "label": { - "type": "string", - "description": "Human readable description of the api that is being called." - } - }, - "required": ["region", "service_name", "operation_name", "label"] - } - }, - "gh_issue": { - "name": "report_issue", - "description": "Opens the browser to a pre-filled gh (GitHub) issue template to report chat issues, bugs, or feature requests. Pre-filled information includes the conversation transcript, chat context, and chat request IDs from the service.", - "input_schema": { - "type": "object", - "properties": { - "title": { - "type": "string", - "description": "The title of the GitHub issue." - }, - "expected_behavior": { - "type": "string", - "description": "Optional: The expected chat behavior or action that did not happen." - }, - "actual_behavior": { - "type": "string", - "description": "Optional: The actual chat behavior that happened and demonstrates the issue or lack of a feature." - }, - "steps_to_reproduce": { - "type": "string", - "description": "Optional: Previous user chat requests or steps that were taken that may have resulted in the issue or error response." - } - }, - "required": ["title"] - } - }, - "thinking": { - "name": "thinking", - "description": "Thinking is an internal reasoning mechanism improving the quality of complex tasks by breaking their atomic actions down; use it specifically for multi-step problems requiring step-by-step dependencies, reasoning through multiple constraints, synthesizing results from previous tool calls, planning intricate sequences of actions, troubleshooting complex errors, or making decisions involving multiple trade-offs. Avoid using it for straightforward tasks, basic information retrieval, summaries, always clearly define the reasoning challenge, structure thoughts explicitly, consider multiple perspectives, and summarize key insights before important decisions or complex tool interactions.", - "input_schema": { - "type": "object", - "properties": { - "thought": { - "type": "string", - "description": "A reflective note or intermediate reasoning step such as \"The user needs to prepare their application for production. I need to complete three major asks including 1: building their code from source, 2: bundling their release artifacts together, and 3: signing the application bundle." - } - }, - "required": ["thought"] - } - }, - "knowledge": { - "name": "knowledge", - "description": "Store and retrieve information in knowledge base across chat sessions. Provides semantic search capabilities for files, directories, and text content.", - "input_schema": { - "type": "object", - "properties": { - "command": { - "type": "string", - "enum": ["show", "add", "remove", "clear", "search", "update", "status", "cancel"], - "description": "The knowledge operation to perform:\n- 'show': List all knowledge contexts (no additional parameters required)\n- 'add': Add content to knowledge base (requires 'name' and 'value')\n- 'remove': Remove content from knowledge base (requires one of: 'name', 'context_id', or 'path')\n- 'clear': Remove all knowledge contexts.\n- 'search': Search across knowledge contexts (requires 'query', optional 'context_id')\n- 'update': Update existing context with new content (requires 'path' and one of: 'name', 'context_id')\n- 'status': Show background operation status and progress\n- 'cancel': Cancel background operations (optional 'operation_id' to cancel specific operation, or cancel all if not provided)" - }, - "name": { - "type": "string", - "description": "A descriptive name for the knowledge context. Required for 'add' operations. Can be used for 'remove' and 'update' operations to identify the context." - }, - "value": { - "type": "string", - "description": "The content to store in knowledge base. Required for 'add' operations. Can be either text content or a file/directory path. If it's a valid file or directory path, the content will be indexed; otherwise it's treated as text." - }, - "context_id": { - "type": "string", - "description": "The unique context identifier for targeted operations. Can be obtained from 'show' command. Used for 'remove', 'update', and 'search' operations to specify which context to operate on." - }, - "path": { - "type": "string", - "description": "File or directory path. Used in 'remove' operations to remove contexts by their source path, and required for 'update' operations to specify the new content location." - }, - "query": { - "type": "string", - "description": "The search query string. Required for 'search' operations. Performs semantic search across knowledge contexts to find relevant content." - }, - "operation_id": { - "type": "string", - "description": "Optional operation ID to cancel a specific operation. Used with 'cancel' command. If not provided, all active operations will be cancelled. Can be either the full operation ID or the short 8-character ID." - } - }, - "required": ["command"] - } - } -} diff --git a/crates/chat-cli/src/cli/chat/tools/use_aws.rs b/crates/chat-cli/src/cli/chat/tools/use_aws.rs deleted file mode 100644 index 467083565..000000000 --- a/crates/chat-cli/src/cli/chat/tools/use_aws.rs +++ /dev/null @@ -1,323 +0,0 @@ -use std::collections::HashMap; -use std::io::Write; -use std::process::Stdio; - -use bstr::ByteSlice; -use convert_case::{ - Case, - Casing, -}; -use crossterm::{ - queue, - style, -}; -use eyre::{ - Result, - WrapErr, -}; -use serde::Deserialize; - -use super::{ - InvokeOutput, - MAX_TOOL_RESPONSE_SIZE, - OutputKind, -}; -use crate::os::Os; - -const READONLY_OPS: [&str; 6] = ["get", "describe", "list", "ls", "search", "batch_get"]; - -/// The environment variable name where we set additional metadata for the AWS CLI user agent. -const USER_AGENT_ENV_VAR: &str = "AWS_EXECUTION_ENV"; -const USER_AGENT_APP_NAME: &str = "AmazonQ-For-CLI"; -const USER_AGENT_VERSION_KEY: &str = "Version"; -const USER_AGENT_VERSION_VALUE: &str = env!("CARGO_PKG_VERSION"); - -// TODO: we should perhaps composite this struct with an interface that we can use to mock the -// actual cli with. That will allow us to more thoroughly test it. -#[derive(Debug, Clone, Deserialize)] -pub struct UseAws { - pub service_name: String, - pub operation_name: String, - pub parameters: Option>, - pub region: String, - pub profile_name: Option, - pub label: Option, -} - -impl UseAws { - pub fn requires_acceptance(&self) -> bool { - !READONLY_OPS.iter().any(|op| self.operation_name.starts_with(op)) - } - - pub async fn invoke(&self, _os: &Os, _updates: impl Write) -> Result { - let mut command = tokio::process::Command::new("aws"); - command.envs(std::env::vars()); - - // Set up environment variables - let mut env_vars: std::collections::HashMap = std::env::vars().collect(); - - // Set up additional metadata for the AWS CLI user agent - let user_agent_metadata_value = format!( - "{} {}/{}", - USER_AGENT_APP_NAME, USER_AGENT_VERSION_KEY, USER_AGENT_VERSION_VALUE - ); - - // If the user agent metadata env var already exists, append to it, otherwise set it - if let Some(existing_value) = env_vars.get(USER_AGENT_ENV_VAR) { - if !existing_value.is_empty() { - env_vars.insert( - USER_AGENT_ENV_VAR.to_string(), - format!("{} {}", existing_value, user_agent_metadata_value), - ); - } else { - env_vars.insert(USER_AGENT_ENV_VAR.to_string(), user_agent_metadata_value); - } - } else { - env_vars.insert(USER_AGENT_ENV_VAR.to_string(), user_agent_metadata_value); - } - - command.envs(env_vars).arg("--region").arg(&self.region); - if let Some(profile_name) = self.profile_name.as_deref() { - command.arg("--profile").arg(profile_name); - } - command.arg(&self.service_name).arg(&self.operation_name); - if let Some(parameters) = self.cli_parameters() { - for (name, val) in parameters { - command.arg(name); - if !val.is_empty() { - command.arg(val); - } - } - } - let output = command - .stdout(Stdio::piped()) - .stderr(Stdio::piped()) - .spawn() - .wrap_err_with(|| format!("Unable to spawn command '{:?}'", self))? - .wait_with_output() - .await - .wrap_err_with(|| format!("Unable to spawn command '{:?}'", self))?; - let status = output.status.code().unwrap_or(0).to_string(); - let stdout = output.stdout.to_str_lossy(); - let stderr = output.stderr.to_str_lossy(); - - let stdout = format!( - "{}{}", - &stdout[0..stdout.len().min(MAX_TOOL_RESPONSE_SIZE / 3)], - if stdout.len() > MAX_TOOL_RESPONSE_SIZE / 3 { - " ... truncated" - } else { - "" - } - ); - - let stderr = format!( - "{}{}", - &stderr[0..stderr.len().min(MAX_TOOL_RESPONSE_SIZE / 3)], - if stderr.len() > MAX_TOOL_RESPONSE_SIZE / 3 { - " ... truncated" - } else { - "" - } - ); - - if status.eq("0") { - Ok(InvokeOutput { - output: OutputKind::Json(serde_json::json!({ - "exit_status": status, - "stdout": stdout, - "stderr": stderr.clone() - })), - }) - } else { - Err(eyre::eyre!(stderr)) - } - } - - pub fn queue_description(&self, output: &mut impl Write) -> Result<()> { - queue!( - output, - style::Print("Running aws cli command:\n\n"), - style::Print(format!("Service name: {}\n", self.service_name)), - style::Print(format!("Operation name: {}\n", self.operation_name)), - )?; - if let Some(parameters) = &self.parameters { - queue!(output, style::Print("Parameters: \n".to_string()))?; - for (name, value) in parameters { - match value { - serde_json::Value::String(s) if s.is_empty() => { - queue!(output, style::Print(format!("- {}\n", name)))?; - }, - _ => { - queue!(output, style::Print(format!("- {}: {}\n", name, value)))?; - }, - } - } - } - - if let Some(ref profile_name) = self.profile_name { - queue!(output, style::Print(format!("Profile name: {}\n", profile_name)))?; - } else { - queue!(output, style::Print("Profile name: default\n".to_string()))?; - } - - queue!(output, style::Print(format!("Region: {}", self.region)))?; - - if let Some(ref label) = self.label { - queue!(output, style::Print(format!("\nLabel: {}", label)))?; - } - Ok(()) - } - - pub async fn validate(&mut self, _os: &Os) -> Result<()> { - Ok(()) - } - - pub fn get_additional_info(&self) -> serde_json::Value { - serde_json::json!({ - "aws_service_name": self.service_name.clone(), - "aws_operation_name": self.operation_name.clone() - }) - } - - /// Returns the CLI arguments properly formatted as kebab case if parameters is - /// [Option::Some], otherwise None - fn cli_parameters(&self) -> Option> { - if let Some(parameters) = &self.parameters { - let mut params = vec![]; - for (param_name, val) in parameters { - let param_name = format!("--{}", param_name.trim_start_matches("--").to_case(Case::Kebab)); - let param_val = val.as_str().map(|s| s.to_string()).unwrap_or(val.to_string()); - params.push((param_name, param_val)); - } - Some(params) - } else { - None - } - } -} - -#[cfg(test)] -mod tests { - use super::*; - - macro_rules! use_aws { - ($value:tt) => { - serde_json::from_value::(serde_json::json!($value)).unwrap() - }; - } - - #[test] - fn test_requires_acceptance() { - let cmd = use_aws! {{ - "service_name": "ecs", - "operation_name": "list-task-definitions", - "region": "us-west-2", - "profile_name": "default", - "label": "" - }}; - assert!(!cmd.requires_acceptance()); - let cmd = use_aws! {{ - "service_name": "lambda", - "operation_name": "list-functions", - "region": "us-west-2", - "profile_name": "default", - "label": "" - }}; - assert!(!cmd.requires_acceptance()); - let cmd = use_aws! {{ - "service_name": "s3", - "operation_name": "put-object", - "region": "us-west-2", - "profile_name": "default", - "label": "" - }}; - assert!(cmd.requires_acceptance()); - } - - #[test] - fn test_use_aws_deser() { - let cmd = use_aws! {{ - "service_name": "s3", - "operation_name": "put-object", - "parameters": { - "TableName": "table-name", - "KeyConditionExpression": "PartitionKey = :pkValue" - }, - "region": "us-west-2", - "profile_name": "default", - "label": "" - }}; - let params = cmd.cli_parameters().unwrap(); - assert!( - params.iter().any(|p| p.0 == "--table-name" && p.1 == "table-name"), - "not found in {:?}", - params - ); - assert!( - params - .iter() - .any(|p| p.0 == "--key-condition-expression" && p.1 == "PartitionKey = :pkValue"), - "not found in {:?}", - params - ); - } - - #[tokio::test] - #[ignore = "not in ci"] - async fn test_aws_read_only() { - let os = Os::new().await.unwrap(); - - let v = serde_json::json!({ - "service_name": "s3", - "operation_name": "put-object", - // technically this wouldn't be a valid request with an empty parameter set but it's - // okay for this test - "parameters": {}, - "region": "us-west-2", - "profile_name": "default", - "label": "" - }); - - assert!( - serde_json::from_value::(v) - .unwrap() - .invoke(&os, &mut std::io::stdout()) - .await - .is_err() - ); - } - - #[tokio::test] - #[ignore = "not in ci"] - async fn test_aws_output() { - let os = Os::new().await.unwrap(); - - let v = serde_json::json!({ - "service_name": "s3", - "operation_name": "ls", - "parameters": {}, - "region": "us-west-2", - "profile_name": "default", - "label": "" - }); - let out = serde_json::from_value::(v) - .unwrap() - .invoke(&os, &mut std::io::stdout()) - .await - .unwrap(); - - if let OutputKind::Json(json) = out.output { - // depending on where the test is ran we might get different outcome here but it does - // not mean the tool is not working - let exit_status = json.get("exit_status").unwrap(); - if exit_status == 0 { - assert_eq!(json.get("stderr").unwrap(), ""); - } else { - assert_ne!(json.get("stderr").unwrap(), ""); - } - } else { - panic!("Expected JSON output"); - } - } -} diff --git a/crates/chat-cli/src/cli/chat/util/images.rs b/crates/chat-cli/src/cli/chat/util/images.rs deleted file mode 100644 index 27f362392..000000000 --- a/crates/chat-cli/src/cli/chat/util/images.rs +++ /dev/null @@ -1,283 +0,0 @@ -use std::fs; -use std::io::Write; -use std::path::Path; -use std::str::FromStr; - -use crossterm::execute; -use crossterm::style::{ - self, - Color, -}; -use serde::{ - Deserialize, - Serialize, -}; - -use crate::api_client::model::{ - ImageBlock, - ImageFormat, - ImageSource, -}; -use crate::cli::chat::consts::{ - MAX_IMAGE_SIZE, - MAX_NUMBER_OF_IMAGES_PER_REQUEST, -}; - -#[derive(Clone, Debug, Serialize, Deserialize)] -pub struct ImageMetadata { - pub filepath: String, - /// The size of the image in bytes - pub size: u64, - pub filename: String, -} - -pub type RichImageBlocks = Vec; -pub type RichImageBlock = (ImageBlock, ImageMetadata); - -/// Macos screenshots insert a NNBSP character rather than a space between the timestamp and AM/PM -/// part. An example of a screenshot name is: /path-to/Screenshot 2025-03-13 at 1.46.32 PM.png -/// -/// However, the model will just treat it as a normal space and return the wrong path string to the -/// `fs_read` tool. This will lead to file-not-found errors. -pub fn pre_process(path: &str) -> String { - if cfg!(target_os = "macos") && path.contains("Screenshot") { - let mac_screenshot_regex = - regex::Regex::new(r"Screenshot \d{4}-\d{2}-\d{2} at \d{1,2}\.\d{2}\.\d{2} [AP]M").unwrap(); - if mac_screenshot_regex.is_match(path) { - if let Some(pos) = path.find(" at ") { - let mut new_path = String::new(); - new_path.push_str(&path[..pos + 4]); - new_path.push_str(&path[pos + 4..].replace(" ", "\u{202F}")); - return new_path; - } - } - } - - path.to_string() -} - -pub fn handle_images_from_paths(output: &mut impl Write, paths: &[String]) -> RichImageBlocks { - let mut extracted_images = Vec::new(); - let mut seen_args = std::collections::HashSet::new(); - - for path in paths.iter() { - if seen_args.contains(path) { - continue; - } - seen_args.insert(path); - if is_supported_image_type(path) { - if let Some(image_block) = get_image_block_from_file_path(path) { - let filename = Path::new(path) - .file_name() - .unwrap_or_default() - .to_string_lossy() - .to_string(); - - let image_size = fs::metadata(path).map(|m| m.len()).unwrap_or_default(); - - extracted_images.push((image_block, ImageMetadata { - filename, - filepath: path.clone(), - size: image_size, - })); - } - } - } - - let (mut valid_images, images_exceeding_size_limit): (RichImageBlocks, RichImageBlocks) = extracted_images - .into_iter() - .partition(|(_, metadata)| metadata.size as usize <= MAX_IMAGE_SIZE); - - if valid_images.len() > MAX_NUMBER_OF_IMAGES_PER_REQUEST { - execute!( - &mut *output, - style::SetForegroundColor(Color::DarkYellow), - style::Print(format!( - "\nMore than {} images detected. Extra ones will be dropped.\n", - MAX_NUMBER_OF_IMAGES_PER_REQUEST - )), - style::SetForegroundColor(Color::Reset) - ) - .ok(); - valid_images.truncate(MAX_NUMBER_OF_IMAGES_PER_REQUEST); - } - - if !images_exceeding_size_limit.is_empty() { - execute!( - &mut *output, - style::SetForegroundColor(Color::DarkYellow), - style::Print(format!( - "\nThe following images are dropped due to exceeding size limit ({}MB):\n", - MAX_IMAGE_SIZE / (1024 * 1024) - )), - style::SetForegroundColor(Color::Reset) - ) - .ok(); - for (_, metadata) in &images_exceeding_size_limit { - let image_size_str = if metadata.size > 1024 * 1024 { - format!("{:.2} MB", metadata.size as f64 / (1024.0 * 1024.0)) - } else if metadata.size > 1024 { - format!("{:.2} KB", metadata.size as f64 / 1024.0) - } else { - format!("{} bytes", metadata.size) - }; - execute!( - &mut *output, - style::SetForegroundColor(Color::DarkYellow), - style::Print(format!(" - {} ({})\n", metadata.filename, image_size_str)), - style::SetForegroundColor(Color::Reset) - ) - .ok(); - } - } - valid_images -} - -/// This function checks if the file path has a supported image type -/// and returns true if it does, otherwise false. -/// Supported image types are: jpg, jpeg, png, gif, webp -/// -/// # Arguments -/// -/// * `maybe_file_path` - A string slice that may or may not be a valid file path -/// -/// # Returns -/// -/// * `true` if the file path has a supported image type -/// * `false` otherwise -pub fn is_supported_image_type(maybe_file_path: &str) -> bool { - let supported_image_types = ["jpg", "jpeg", "png", "gif", "webp"]; - if let Some(extension) = maybe_file_path.split('.').next_back() { - return supported_image_types.contains(&extension.trim().to_lowercase().as_str()); - } - false -} - -pub fn get_image_block_from_file_path(maybe_file_path: &str) -> Option { - if !is_supported_image_type(maybe_file_path) { - return None; - } - - let file_path = Path::new(maybe_file_path); - if !file_path.exists() { - return None; - } - - let image_bytes = fs::read(file_path); - if image_bytes.is_err() { - return None; - } - - let image_format = ImageFormat::from_str(file_path.extension()?.to_str()?.to_lowercase().as_str()); - - if image_format.is_err() { - return None; - } - - let image_bytes = image_bytes.unwrap(); - let image_block = ImageBlock { - format: image_format.unwrap(), - source: ImageSource::Bytes(image_bytes), - }; - Some(image_block) -} - -#[cfg(test)] -mod tests { - use std::str::FromStr; - - use bstr::ByteSlice; - - use super::*; - - #[test] - fn test_is_supported_image_type() { - let test_cases = vec![ - ("image.jpg", true), - ("image.jpeg", true), - ("image.png", true), - ("image.gif", true), - ("image.webp", true), - ("image.txt", false), - ("image", false), - ]; - - for (path, expected) in test_cases { - assert_eq!(is_supported_image_type(path), expected, "Failed for path: {}", path); - } - } - - #[test] - fn test_get_image_format_from_ext() { - assert_eq!(ImageFormat::from_str("jpg"), Ok(ImageFormat::Jpeg)); - assert_eq!(ImageFormat::from_str("JPEG"), Ok(ImageFormat::Jpeg)); - assert_eq!(ImageFormat::from_str("png"), Ok(ImageFormat::Png)); - assert_eq!(ImageFormat::from_str("gif"), Ok(ImageFormat::Gif)); - assert_eq!(ImageFormat::from_str("webp"), Ok(ImageFormat::Webp)); - assert_eq!( - ImageFormat::from_str("txt"), - Err("Failed to parse 'txt' as ImageFormat".to_string()) - ); - } - - #[test] - fn test_handle_images_from_paths() { - let temp_dir = tempfile::tempdir().unwrap(); - let image_path = temp_dir.path().join("test_image.jpg"); - std::fs::write(&image_path, b"fake_image_data").unwrap(); - - let images = handle_images_from_paths(&mut vec![], &[image_path.to_string_lossy().to_string()]); - - assert_eq!(images.len(), 1); - assert_eq!(images[0].1.filename, "test_image.jpg"); - assert_eq!(images[0].1.filepath, image_path.to_string_lossy()); - } - - #[test] - fn test_get_image_block_from_file_path() { - let temp_dir = tempfile::tempdir().unwrap(); - let image_path = temp_dir.path().join("test_image.png"); - std::fs::write(&image_path, b"fake_image_data").unwrap(); - - let image_block = get_image_block_from_file_path(&image_path.to_string_lossy()); - assert!(image_block.is_some()); - let image_block = image_block.unwrap(); - assert_eq!(image_block.format, ImageFormat::Png); - if let ImageSource::Bytes(bytes) = image_block.source { - assert_eq!(bytes, b"fake_image_data"); - } else { - panic!("Expected ImageSource::Bytes"); - } - } - - #[test] - fn test_handle_images_size_limit_exceeded() { - let temp_dir = tempfile::tempdir().unwrap(); - let large_image_path = temp_dir.path().join("large_image.jpg"); - let large_image_size = MAX_IMAGE_SIZE + 1; - std::fs::write(&large_image_path, vec![0; large_image_size]).unwrap(); - let mut output = vec![]; - let images = handle_images_from_paths(&mut output, &[large_image_path.to_string_lossy().to_string()]); - let output_str = output.to_str_lossy(); - print!("{}", output_str); - assert!(output_str.contains("The following images are dropped due to exceeding size limit (10MB):")); - assert!(output_str.contains("- large_image.jpg (10.00 MB)")); - assert!(images.is_empty()); - } - - #[test] - fn test_handle_images_number_exceeded() { - let temp_dir = tempfile::tempdir().unwrap(); - - let mut paths = vec![]; - for i in 0..(MAX_NUMBER_OF_IMAGES_PER_REQUEST + 2) { - let image_path = temp_dir.path().join(format!("image_{}.jpg", i)); - paths.push(image_path.to_string_lossy().to_string()); - std::fs::write(&image_path, b"fake_image_data").unwrap(); - } - - let images = handle_images_from_paths(&mut vec![], &paths); - - assert_eq!(images.len(), MAX_NUMBER_OF_IMAGES_PER_REQUEST); - } -} diff --git a/crates/chat-cli/src/cli/chat/util/issue.rs b/crates/chat-cli/src/cli/chat/util/issue.rs deleted file mode 100644 index 5b2434904..000000000 --- a/crates/chat-cli/src/cli/chat/util/issue.rs +++ /dev/null @@ -1,84 +0,0 @@ -use anstream::{ - eprintln, - println, -}; -use crossterm::style::Stylize; -use eyre::Result; - -use crate::os::Os; -use crate::os::diagnostics::Diagnostics; -use crate::util::GITHUB_REPO_NAME; -use crate::util::system_info::is_remote; - -const TEMPLATE_NAME: &str = "1_bug_report_template.yml"; - -pub struct IssueCreator { - /// Issue title - pub title: Option, - /// Issue description - pub expected_behavior: Option, - /// Issue description - pub actual_behavior: Option, - /// Issue description - pub steps_to_reproduce: Option, - /// Issue description - pub additional_environment: Option, -} - -impl IssueCreator { - pub async fn create_url(&self, os: &Os) -> Result { - println!("Heading over to GitHub..."); - - let warning = |text: &String| { - format!("\n\n{text}") - }; - let diagnostics = Diagnostics::new(&os.env).await; - - let os = match &diagnostics.system_info.os { - Some(os) => os.to_string(), - None => "None".to_owned(), - }; - - let diagnostic_info = match diagnostics.user_readable() { - Ok(diagnostics) => diagnostics, - Err(err) => { - eprintln!("Error getting diagnostics: {err}"); - "Error occurred while generating diagnostics".to_owned() - }, - }; - - let environment = match &self.additional_environment { - Some(os) => format!("{diagnostic_info}\n{os}"), - None => diagnostic_info, - }; - - let mut params = Vec::new(); - params.push(("template", TEMPLATE_NAME.to_string())); - params.push(("os", os)); - params.push(("environment", warning(&environment))); - - if let Some(t) = self.title.clone() { - params.push(("title", t)); - } - if let Some(t) = self.expected_behavior.as_ref() { - params.push(("expected", warning(t))); - } - if let Some(t) = self.actual_behavior.as_ref() { - params.push(("actual", warning(t))); - } - if let Some(t) = self.steps_to_reproduce.as_ref() { - params.push(("reproduce", warning(t))); - } - - let url = url::Url::parse_with_params( - &format!("https://github.com/{GITHUB_REPO_NAME}/issues/new"), - params.iter(), - )?; - - if is_remote() || crate::util::open::open_url_async(url.as_str()).await.is_err() { - println!("Issue Url: {}", url.as_str().underlined()); - } - - Ok(url) - } -} diff --git a/crates/chat-cli/src/cli/chat/util/mod.rs b/crates/chat-cli/src/cli/chat/util/mod.rs deleted file mode 100644 index 0afb6d24d..000000000 --- a/crates/chat-cli/src/cli/chat/util/mod.rs +++ /dev/null @@ -1,252 +0,0 @@ -pub mod images; -pub mod issue; -#[cfg(test)] -pub mod test; -pub mod ui; - -use std::io::Write; -use std::time::Duration; - -use aws_smithy_types::{ - Document, - Number as SmithyNumber, -}; -use eyre::Result; - -use super::ChatError; -use super::token_counter::TokenCounter; - -pub fn truncate_safe(s: &str, max_bytes: usize) -> &str { - if s.len() <= max_bytes { - return s; - } - - let mut byte_count = 0; - let mut char_indices = s.char_indices(); - - for (byte_idx, _) in &mut char_indices { - if byte_count + (byte_idx - byte_count) > max_bytes { - break; - } - byte_count = byte_idx; - } - - &s[..byte_count] -} - -/// Truncates `s` to a maximum length of `max_bytes`, appending `suffix` if `s` was truncated. The -/// result is always guaranteed to be at least less than `max_bytes`. -/// -/// If `suffix` is larger than `max_bytes`, or `s` is within `max_bytes`, then this function does -/// nothing. -pub fn truncate_safe_in_place(s: &mut String, max_bytes: usize, suffix: &str) { - // Do nothing if the suffix is too large to be truncated within max_bytes, or s is already small - // enough to not be truncated. - if suffix.len() > max_bytes || s.len() <= max_bytes { - return; - } - - let end = truncate_safe(s, max_bytes - suffix.len()).len(); - s.replace_range(end..s.len(), suffix); - s.truncate(max_bytes); -} - -pub fn animate_output(output: &mut impl Write, bytes: &[u8]) -> Result<(), ChatError> { - for b in bytes.chunks(12) { - output.write_all(b)?; - std::thread::sleep(Duration::from_millis(16)); - } - Ok(()) -} - -/// Play the terminal bell notification sound -pub fn play_notification_bell(requires_confirmation: bool) { - // Don't play bell for tools that don't require confirmation - if !requires_confirmation { - return; - } - - // Check if we should play the bell based on terminal type - if should_play_bell() { - print!("\x07"); // ASCII bell character - std::io::stdout().flush().unwrap(); - } -} - -/// Determine if we should play the bell based on terminal type -fn should_play_bell() -> bool { - // Get the TERM environment variable - if let Ok(term) = std::env::var("TERM") { - // List of terminals known to handle bell character well - let bell_compatible_terms = [ - "xterm", - "xterm-256color", - "screen", - "screen-256color", - "tmux", - "tmux-256color", - "rxvt", - "rxvt-unicode", - "linux", - "konsole", - "gnome", - "gnome-256color", - "alacritty", - "iterm2", - "eat-truecolor", - "eat-256color", - "eat-color", - ]; - - // Check if the current terminal is in the compatible list - for compatible_term in bell_compatible_terms.iter() { - if term.starts_with(compatible_term) { - return true; - } - } - - // For other terminals, don't play the bell - return false; - } - - // If TERM is not set, default to not playing the bell - false -} - -/// This is a simple greedy algorithm that drops the largest files first -/// until the total size is below the limit -/// -/// # Arguments -/// * `files` - A mutable reference to a vector of tuples: (filename, content). This file will be -/// sorted but the content will not be changed. -/// -/// Returns the dropped files -pub fn drop_matched_context_files(files: &mut [(String, String)], limit: usize) -> Result> { - files.sort_by(|a, b| TokenCounter::count_tokens(&b.1).cmp(&TokenCounter::count_tokens(&a.1))); - let mut total_size = 0; - let mut dropped_files = Vec::new(); - - for (filename, content) in files.iter() { - let size = TokenCounter::count_tokens(content); - if total_size + size > limit { - dropped_files.push((filename.clone(), content.clone())); - } else { - total_size += size; - } - } - Ok(dropped_files) -} - -pub fn serde_value_to_document(value: serde_json::Value) -> Document { - match value { - serde_json::Value::Null => Document::Null, - serde_json::Value::Bool(bool) => Document::Bool(bool), - serde_json::Value::Number(number) => { - if let Some(num) = number.as_u64() { - Document::Number(SmithyNumber::PosInt(num)) - } else if number.as_i64().is_some_and(|n| n < 0) { - Document::Number(SmithyNumber::NegInt(number.as_i64().unwrap())) - } else { - Document::Number(SmithyNumber::Float(number.as_f64().unwrap_or_default())) - } - }, - serde_json::Value::String(string) => Document::String(string), - serde_json::Value::Array(vec) => { - Document::Array(vec.clone().into_iter().map(serde_value_to_document).collect::<_>()) - }, - serde_json::Value::Object(map) => Document::Object( - map.into_iter() - .map(|(k, v)| (k, serde_value_to_document(v))) - .collect::<_>(), - ), - } -} - -pub fn document_to_serde_value(value: Document) -> serde_json::Value { - use serde_json::Value; - match value { - Document::Object(map) => Value::Object( - map.into_iter() - .map(|(k, v)| (k, document_to_serde_value(v))) - .collect::<_>(), - ), - Document::Array(vec) => Value::Array(vec.clone().into_iter().map(document_to_serde_value).collect::<_>()), - Document::Number(number) => { - if let Ok(v) = TryInto::::try_into(number) { - Value::Number(v.into()) - } else if let Ok(v) = TryInto::::try_into(number) { - Value::Number(v.into()) - } else { - Value::Number( - serde_json::Number::from_f64(number.to_f64_lossy()) - .unwrap_or(serde_json::Number::from_f64(0.0).expect("converting from 0.0 will not fail")), - ) - } - }, - Document::String(s) => serde_json::Value::String(s), - Document::Bool(b) => serde_json::Value::Bool(b), - Document::Null => serde_json::Value::Null, - } -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn test_truncate_safe() { - assert_eq!(truncate_safe("Hello World", 5), "Hello"); - assert_eq!(truncate_safe("Hello ", 5), "Hello"); - assert_eq!(truncate_safe("Hello World", 11), "Hello World"); - assert_eq!(truncate_safe("Hello World", 15), "Hello World"); - } - - #[test] - fn test_tsip() { - let suffix = "suffix"; - let tests = &[ - ("Hello World", 5, "Hello World"), - ("Hello World", 7, "Hsuffix"), - ("Hello World", usize::MAX, "Hello World"), - // α -> 2 byte length - ("αααααα", 7, "suffix"), - ("αααααα", 8, "αsuffix"), - ("αααααα", 9, "αsuffix"), - ]; - assert!("α".len() == 2); - - for (input, max_bytes, expected) in tests { - let mut input = (*input).to_string(); - truncate_safe_in_place(&mut input, *max_bytes, suffix); - assert_eq!( - input.as_str(), - *expected, - "input: {} with max bytes: {} failed", - input, - max_bytes - ); - } - } - - #[test] - fn test_drop_matched_context_files() { - let mut files = vec![ - ("file1".to_string(), "This is a test file".to_string()), - ( - "file3".to_string(), - "Yet another test file that's has the largest context file".to_string(), - ), - ]; - let limit = 9; - - let dropped_files = drop_matched_context_files(&mut files, limit).unwrap(); - assert_eq!(dropped_files.len(), 1); - assert_eq!(dropped_files[0].0, "file3"); - assert_eq!(files.len(), 2); - - for (filename, _) in dropped_files.iter() { - files.retain(|(f, _)| f != filename); - } - assert_eq!(files.len(), 1); - } -} diff --git a/crates/chat-cli/src/cli/chat/util/test.rs b/crates/chat-cli/src/cli/chat/util/test.rs deleted file mode 100644 index 1106e02a7..000000000 --- a/crates/chat-cli/src/cli/chat/util/test.rs +++ /dev/null @@ -1,45 +0,0 @@ -use eyre::Result; - -use crate::cli::chat::consts::CONTEXT_FILES_MAX_SIZE; -use crate::cli::chat::context::ContextManager; -use crate::os::Os; - -pub const TEST_FILE_CONTENTS: &str = "\ -1: Hello world! -2: This is line 2 -3: asdf -4: Hello world! -"; - -pub const TEST_FILE_PATH: &str = "/test_file.txt"; -pub const TEST_HIDDEN_FILE_PATH: &str = "/aaaa2/.hidden"; - -// Helper function to create a test ContextManager with Context -pub async fn create_test_context_manager(context_file_size: Option) -> Result { - let context_file_size = context_file_size.unwrap_or(CONTEXT_FILES_MAX_SIZE); - let os = Os::new().await.unwrap(); - let manager = ContextManager::new(&os, Some(context_file_size)).await?; - Ok(manager) -} - -/// Sets up the following filesystem structure: -/// ```text -/// test_file.txt -/// /home/testuser/ -/// /aaaa1/ -/// /bbbb1/ -/// /cccc1/ -/// /aaaa2/ -/// .hidden -/// ``` -pub async fn setup_test_directory() -> Os { - let os = Os::new().await.unwrap(); - os.fs.write(TEST_FILE_PATH, TEST_FILE_CONTENTS).await.unwrap(); - os.fs.create_dir_all("/aaaa1/bbbb1/cccc1").await.unwrap(); - os.fs.create_dir_all("/aaaa2").await.unwrap(); - os.fs - .write(TEST_HIDDEN_FILE_PATH, "this is a hidden file") - .await - .unwrap(); - os -} diff --git a/crates/chat-cli/src/cli/chat/util/ui.rs b/crates/chat-cli/src/cli/chat/util/ui.rs deleted file mode 100644 index 43efc8a68..000000000 --- a/crates/chat-cli/src/cli/chat/util/ui.rs +++ /dev/null @@ -1,202 +0,0 @@ -use std::io::Write; - -use crossterm::style::{ - Color, - Stylize, -}; -use crossterm::terminal::{ - self, - ClearType, -}; -use crossterm::{ - cursor, - execute, - style, -}; -use eyre::Result; -use strip_ansi_escapes::strip_str; - -pub fn draw_box( - output: &mut impl Write, - title: &str, - content: &str, - box_width: usize, - border_color: Color, -) -> Result<()> { - let inner_width = box_width - 4; // account for │ and padding - - // wrap the single line into multiple lines respecting inner width - // Manually wrap the text by splitting at word boundaries - let mut wrapped_lines = Vec::new(); - let mut line = String::new(); - - for word in content.split_whitespace() { - if line.len() + word.len() < inner_width { - if !line.is_empty() { - line.push(' '); - } - line.push_str(word); - } else { - // Here we need to account for words that are too long as well - if word.len() >= inner_width { - let mut start = 0_usize; - for (i, _) in word.chars().enumerate() { - if i - start >= inner_width { - wrapped_lines.push(word[start..i].to_string()); - start = i; - } - } - wrapped_lines.push(word[start..].to_string()); - line = String::new(); - } else { - wrapped_lines.push(line); - line = word.to_string(); - } - } - } - - if !line.is_empty() { - wrapped_lines.push(line); - } - - let side_len = (box_width.saturating_sub(title.len())) / 2; - let top_border = format!( - "{} {} {}", - style::style(format!("╭{}", "─".repeat(side_len - 2))).with(border_color), - title, - style::style(format!("{}╮", "─".repeat(box_width - side_len - title.len() - 2))).with(border_color) - ); - - execute!( - output, - terminal::Clear(ClearType::CurrentLine), - cursor::MoveToColumn(0), - style::Print(format!("{top_border}\n")), - )?; - - // Top vertical padding - let top_vertical_border = format!( - "{}", - style::style(format!("│{: = long_tip.split_whitespace().collect(); - for part in long_tip_parts.iter().take(3) { - assert!(output_str.contains(part), "Output should contain parts of the long tip"); - } - } -} diff --git a/crates/chat-cli/src/cli/debug.rs b/crates/chat-cli/src/cli/debug.rs deleted file mode 100644 index 27e0477b3..000000000 --- a/crates/chat-cli/src/cli/debug.rs +++ /dev/null @@ -1,82 +0,0 @@ -use clap::ValueEnum; - -#[derive(Debug, ValueEnum, Clone, PartialEq, Eq)] -pub enum Build { - Production, - #[value(alias = "staging")] - Beta, - #[value(hide = true, alias = "dev")] - Develop, -} - -impl std::fmt::Display for Build { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { - Build::Production => f.write_str("production"), - Build::Beta => f.write_str("beta"), - Build::Develop => f.write_str("develop"), - } - } -} - -#[derive(Debug, ValueEnum, Clone, PartialEq, Eq)] -pub enum App { - Dashboard, - Autocomplete, -} - -impl std::fmt::Display for App { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { - App::Dashboard => f.write_str("dashboard"), - App::Autocomplete => f.write_str("autocomplete"), - } - } -} - -#[derive(Debug, ValueEnum, Clone, PartialEq, Eq)] -pub enum AutocompleteWindowDebug { - On, - Off, -} - -#[derive(Debug, ValueEnum, Clone, PartialEq, Eq)] -pub enum AccessibilityAction { - Refresh, - Reset, - Prompt, - Open, - Status, -} - -#[cfg(target_os = "macos")] -#[derive(Debug, Clone, PartialEq, Eq, ValueEnum)] -pub enum TISAction { - Enable, - Disable, - Select, - Deselect, -} - -#[cfg(target_os = "macos")] -use std::path::PathBuf; - -#[cfg(target_os = "macos")] -#[derive(Debug, clap::Subcommand, Clone, PartialEq, Eq)] -pub enum InputMethodDebugAction { - Install { - bundle_path: Option, - }, - Uninstall { - bundle_path: Option, - }, - List, - Status { - bundle_path: Option, - }, - Source { - bundle_identifier: String, - #[arg(value_enum)] - action: TISAction, - }, -} diff --git a/crates/chat-cli/src/cli/diagnostics.rs b/crates/chat-cli/src/cli/diagnostics.rs deleted file mode 100644 index 0a5b977d8..000000000 --- a/crates/chat-cli/src/cli/diagnostics.rs +++ /dev/null @@ -1,69 +0,0 @@ -use std::io::{ - IsTerminal, - stdout, -}; -use std::process::ExitCode; - -use anstream::println; -use clap::Args; -use color_eyre::Result; -use crossterm::terminal::{ - Clear, - ClearType, -}; -use crossterm::{ - cursor, - execute, -}; -use spinners::{ - Spinner, - Spinners, -}; - -use super::OutputFormat; -use crate::os::Os; -use crate::os::diagnostics::Diagnostics; - -#[derive(Clone, Debug, Args, PartialEq, Eq)] -pub struct DiagnosticArgs { - /// The format of the output - #[arg(long, short, value_enum, default_value_t)] - format: OutputFormat, - /// Force limited diagnostic output - #[arg(long)] - force: bool, -} - -impl DiagnosticArgs { - pub async fn execute(&self, os: &Os) -> Result { - let spinner = if stdout().is_terminal() { - Some(Spinner::new(Spinners::Dots, "Generating...".into())) - } else { - None - }; - - if spinner.is_some() { - execute!(std::io::stdout(), cursor::Hide)?; - - ctrlc::set_handler(move || { - execute!(std::io::stdout(), cursor::Show).ok(); - std::process::exit(1); - })?; - } - - let diagnostics = Diagnostics::new(&os.env).await; - - if let Some(mut sp) = spinner { - sp.stop(); - execute!(std::io::stdout(), Clear(ClearType::CurrentLine), cursor::Show)?; - println!(); - } - - self.format.print( - || diagnostics.user_readable().expect("Failed to run user_readable()"), - || &diagnostics, - ); - - Ok(ExitCode::SUCCESS) - } -} diff --git a/crates/chat-cli/src/cli/feed.rs b/crates/chat-cli/src/cli/feed.rs deleted file mode 100644 index 7df058c94..000000000 --- a/crates/chat-cli/src/cli/feed.rs +++ /dev/null @@ -1,49 +0,0 @@ -use serde::{ - Deserialize, - Serialize, -}; - -#[derive(Debug, Serialize, Deserialize)] -pub struct Feed { - pub entries: Vec, -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct Entry { - #[serde(rename = "type")] - pub entry_type: String, - pub date: String, - pub version: String, - #[serde(default)] - pub hidden: bool, - #[serde(default)] - pub changes: Vec, -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct Change { - #[serde(rename = "type")] - pub change_type: String, - pub description: String, -} - -impl Feed { - pub fn load() -> Self { - serde_json::from_str(include_str!("../../../../feed.json")).expect("feed.json is valid json") - } - - pub fn get_version_changelog(&self, version: &str) -> Option { - self.entries - .iter() - .find(|entry| entry.entry_type == "release" && entry.version == version && !entry.hidden) - .cloned() - } - - pub fn get_all_changelogs(&self) -> Vec { - self.entries - .iter() - .filter(|entry| entry.entry_type == "release" && !entry.hidden) - .cloned() - .collect() - } -} diff --git a/crates/chat-cli/src/cli/issue.rs b/crates/chat-cli/src/cli/issue.rs deleted file mode 100644 index 4aa78bd6d..000000000 --- a/crates/chat-cli/src/cli/issue.rs +++ /dev/null @@ -1,40 +0,0 @@ -use std::process::ExitCode; - -use clap::Args; -use eyre::Result; - -use crate::os::Os; - -#[derive(Clone, Debug, Args, PartialEq, Eq)] -pub struct IssueArgs { - /// Force issue creation - #[arg(long, short = 'f')] - force: bool, - /// Issue description - description: Vec, -} - -impl IssueArgs { - pub async fn execute(&self, os: &Os) -> Result { - let joined_description = self.description.join(" ").trim().to_owned(); - - let issue_title = match joined_description.len() { - 0 => dialoguer::Input::with_theme(&crate::util::dialoguer_theme()) - .with_prompt("Issue Title") - .interact_text()?, - _ => joined_description, - }; - - let _ = crate::cli::chat::util::issue::IssueCreator { - title: Some(issue_title), - expected_behavior: None, - actual_behavior: None, - steps_to_reproduce: None, - additional_environment: None, - } - .create_url(os) - .await; - - Ok(ExitCode::SUCCESS) - } -} diff --git a/crates/chat-cli/src/cli/mcp.rs b/crates/chat-cli/src/cli/mcp.rs deleted file mode 100644 index 451d44b87..000000000 --- a/crates/chat-cli/src/cli/mcp.rs +++ /dev/null @@ -1,569 +0,0 @@ -use std::collections::HashMap; -use std::io::Write; -use std::path::PathBuf; -use std::process::ExitCode; - -use clap::{ - ArgAction, - Args, - ValueEnum, -}; -use crossterm::{ - execute, - style, -}; -use eyre::{ - Result, - bail, -}; -use tracing::warn; - -use crate::cli::chat::tool_manager::{ - McpServerConfig, - global_mcp_config_path, - workspace_mcp_config_path, -}; -use crate::cli::chat::tools::custom_tool::{ - CustomToolConfig, - default_timeout, -}; -use crate::os::Os; - -#[derive(Debug, Copy, Clone, PartialEq, Eq, ValueEnum)] -pub enum Scope { - Workspace, - Global, -} - -impl std::fmt::Display for Scope { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { - Scope::Workspace => write!(f, "workspace"), - Scope::Global => write!(f, "global"), - } - } -} - -#[derive(Debug, Clone, PartialEq, Eq, clap::Subcommand)] -pub enum McpSubcommand { - /// Add or replace a configured server - Add(AddArgs), - /// Remove a server from the MCP configuration - #[command(alias = "rm")] - Remove(RemoveArgs), - /// List configured servers - List(ListArgs), - /// Import a server configuration from another file - Import(ImportArgs), - /// Get the status of a configured server - Status(StatusArgs), -} - -impl McpSubcommand { - pub async fn execute(self, os: &mut Os, output: &mut impl Write) -> Result { - match self { - Self::Add(args) => args.execute(os, output).await?, - Self::Remove(args) => args.execute(os, output).await?, - Self::List(args) => args.execute(os, output).await?, - Self::Import(args) => args.execute(os, output).await?, - Self::Status(args) => args.execute(os, output).await?, - } - - output.flush()?; - Ok(ExitCode::SUCCESS) - } -} - -#[derive(Debug, Clone, PartialEq, Eq, Args)] -pub struct AddArgs { - /// Name for the server - #[arg(long)] - pub name: String, - /// The command used to launch the server - #[arg(long)] - pub command: String, - /// Arguments to pass to the command - #[arg(long, action = ArgAction::Append, allow_hyphen_values = true, value_delimiter = ',')] - pub args: Vec, - /// Where to add the server to. - #[arg(long, value_enum)] - pub scope: Option, - /// Environment variables to use when launching the server - #[arg(long, value_parser = parse_env_vars)] - pub env: Vec>, - /// Server launch timeout, in milliseconds - #[arg(long)] - pub timeout: Option, - /// Whether the server should be disabled (not loaded) - #[arg(long, default_value_t = false)] - pub disabled: bool, - /// Overwrite an existing server with the same name - #[arg(long, default_value_t = false)] - pub force: bool, -} - -impl AddArgs { - pub async fn execute(self, os: &Os, output: &mut impl Write) -> Result<()> { - let scope = self.scope.unwrap_or(Scope::Workspace); - let config_path = resolve_scope_profile(os, self.scope)?; - - let mut config: McpServerConfig = ensure_config_file(os, &config_path, output).await?; - - if config.mcp_servers.contains_key(&self.name) && !self.force { - bail!( - "\nMCP server '{}' already exists in {} (scope {}). Use --force to overwrite.", - self.name, - config_path.display(), - scope - ); - } - - let merged_env = self.env.into_iter().flatten().collect::>(); - let tool: CustomToolConfig = serde_json::from_value(serde_json::json!({ - "command": self.command, - "args": self.args, - "env": merged_env, - "timeout": self.timeout.unwrap_or(default_timeout()), - "disabled": self.disabled, - }))?; - - writeln!( - output, - "\nTo learn more about MCP safety, see https://docs.aws.amazon.com/amazonq/latest/qdeveloper-ug/command-line-mcp-security.html\n\n" - )?; - - config.mcp_servers.insert(self.name.clone(), tool); - config.save_to_file(os, &config_path).await?; - writeln!( - output, - "✓ Added MCP server '{}' to {}\n", - self.name, - scope_display(&scope) - )?; - Ok(()) - } -} - -#[derive(Debug, Clone, PartialEq, Eq, Args)] -pub struct RemoveArgs { - #[arg(long)] - pub name: String, - #[arg(long, value_enum)] - pub scope: Option, -} - -impl RemoveArgs { - pub async fn execute(self, os: &Os, output: &mut impl Write) -> Result<()> { - let scope = self.scope.unwrap_or(Scope::Workspace); - let config_path = resolve_scope_profile(os, self.scope)?; - - if !os.fs.exists(&config_path) { - writeln!(output, "\nNo MCP server configurations found.\n")?; - return Ok(()); - } - - let mut config = McpServerConfig::load_from_file(os, &config_path).await?; - match config.mcp_servers.remove(&self.name) { - Some(_) => { - config.save_to_file(os, &config_path).await?; - writeln!( - output, - "\n✓ Removed MCP server '{}' from {}\n", - self.name, - scope_display(&scope) - )?; - }, - None => { - writeln!( - output, - "\nNo MCP server named '{}' found in {}\n", - self.name, - scope_display(&scope) - )?; - }, - } - Ok(()) - } -} - -#[derive(Debug, Clone, PartialEq, Eq, Args)] -pub struct ListArgs { - #[arg(value_enum)] - pub scope: Option, - #[arg(long, hide = true)] - pub profile: Option, -} - -impl ListArgs { - pub async fn execute(self, os: &Os, output: &mut impl Write) -> Result<()> { - let configs = get_mcp_server_configs(os, self.scope).await?; - if configs.is_empty() { - writeln!(output, "No MCP server configurations found.\n")?; - return Ok(()); - } - - for (scope, path, cfg_opt) in configs { - writeln!(output)?; - writeln!(output, "{}:\n {}", scope_display(&scope), path.display())?; - match cfg_opt { - Some(cfg) if !cfg.mcp_servers.is_empty() => { - for (name, tool_cfg) in &cfg.mcp_servers { - let status = if tool_cfg.disabled { " (disabled)" } else { "" }; - writeln!(output, " • {name:<12} {}{}", tool_cfg.command, status)?; - } - }, - _ => { - writeln!(output, " (empty)")?; - }, - } - } - writeln!(output, "\n")?; - - Ok(()) - } -} - -#[derive(Debug, Clone, PartialEq, Eq, Args)] -pub struct ImportArgs { - #[arg(long)] - pub file: String, - #[arg(value_enum)] - pub scope: Option, - /// Overwrite an existing server with the same name - #[arg(long, default_value_t = false)] - pub force: bool, -} - -impl ImportArgs { - pub async fn execute(self, os: &Os, output: &mut impl Write) -> Result<()> { - let scope: Scope = self.scope.unwrap_or(Scope::Workspace); - let config_path = resolve_scope_profile(os, self.scope)?; - let mut dst_cfg = ensure_config_file(os, &config_path, output).await?; - - let src_path = expand_path(os, &self.file)?; - let src_cfg: McpServerConfig = McpServerConfig::load_from_file(os, &src_path).await?; - - let mut added = 0; - for (name, cfg) in src_cfg.mcp_servers { - if dst_cfg.mcp_servers.contains_key(&name) && !self.force { - bail!( - "\nMCP server '{}' already exists in {} (scope {}). Use --force to overwrite.\n", - name, - config_path.display(), - scope - ); - } - dst_cfg.mcp_servers.insert(name.clone(), cfg); - added += 1; - } - - writeln!( - output, - "\nTo learn more about MCP safety, see https://docs.aws.amazon.com/amazonq/latest/qdeveloper-ug/command-line-mcp-security.html\n\n" - )?; - - dst_cfg.save_to_file(os, &config_path).await?; - writeln!( - output, - "✓ Imported {added} MCP server(s) into {}\n", - scope_display(&scope) - )?; - Ok(()) - } -} - -#[derive(Debug, Clone, PartialEq, Eq, Args)] -pub struct StatusArgs { - #[arg(long)] - pub name: String, -} - -impl StatusArgs { - pub async fn execute(self, os: &Os, output: &mut impl Write) -> Result<()> { - let configs = get_mcp_server_configs(os, None).await?; - let mut found = false; - - for (sc, path, cfg_opt) in configs { - if let Some(cfg) = cfg_opt.and_then(|c| c.mcp_servers.get(&self.name).cloned()) { - found = true; - execute!( - output, - style::Print("\n─────────────\n"), - style::Print(format!("Scope : {}\n", scope_display(&sc))), - style::Print(format!("File : {}\n", path.display())), - style::Print(format!("Command : {}\n", cfg.command)), - style::Print(format!("Timeout : {} ms\n", cfg.timeout)), - style::Print(format!("Disabled: {}\n", cfg.disabled)), - style::Print(format!( - "Env Vars: {}\n", - cfg.env - .as_ref() - .map_or_else(|| "(none)".into(), |e| e.keys().cloned().collect::>().join(", ")) - )), - )?; - } - } - writeln!(output, "\n")?; - - if !found { - bail!("No MCP server named '{}' found in any scope/profile\n", self.name); - } - - Ok(()) - } -} - -async fn get_mcp_server_configs( - os: &Os, - scope: Option, -) -> Result)>> { - let mut targets = Vec::new(); - match scope { - Some(scope) => targets.push(scope), - None => targets.extend([Scope::Workspace, Scope::Global]), - } - - let mut results = Vec::new(); - for sc in targets { - let path = resolve_scope_profile(os, Some(sc))?; - let cfg_opt = if os.fs.exists(&path) { - match McpServerConfig::load_from_file(os, &path).await { - Ok(cfg) => Some(cfg), - Err(e) => { - warn!(?path, error = %e, "Invalid MCP config file—ignored, treated as null"); - None - }, - } - } else { - None - }; - results.push((sc, path, cfg_opt)); - } - Ok(results) -} - -fn scope_display(scope: &Scope) -> String { - match scope { - Scope::Workspace => "📄 workspace".into(), - Scope::Global => "🌍 global".into(), - } -} - -fn resolve_scope_profile(os: &Os, scope: Option) -> Result { - Ok(match scope { - Some(Scope::Global) => global_mcp_config_path(os)?, - _ => workspace_mcp_config_path(os)?, - }) -} - -fn expand_path(os: &Os, p: &str) -> Result { - let p = shellexpand::tilde(p); - let mut path = PathBuf::from(p.as_ref() as &str); - if path.is_relative() { - path = os.env.current_dir()?.join(path); - } - Ok(path) -} - -async fn ensure_config_file(os: &Os, path: &PathBuf, output: &mut impl Write) -> Result { - if !os.fs.exists(path) { - if let Some(parent) = path.parent() { - os.fs.create_dir_all(parent).await?; - } - McpServerConfig::default().save_to_file(os, path).await?; - writeln!(output, "\n📁 Created MCP config in '{}'", path.display())?; - } - - load_cfg(os, path).await -} - -fn parse_env_vars(arg: &str) -> Result> { - let mut vars = HashMap::new(); - - for pair in arg.split(",") { - match pair.split_once('=') { - Some((key, value)) => { - vars.insert(key.trim().to_string(), value.trim().to_string()); - }, - None => { - bail!( - "Failed to parse environment variables, invalid environment variable '{}'. Expected 'name=value'", - pair - ) - }, - } - } - - Ok(vars) -} - -async fn load_cfg(os: &Os, p: &PathBuf) -> Result { - Ok(if os.fs.exists(p) { - McpServerConfig::load_from_file(os, p).await? - } else { - McpServerConfig::default() - }) -} - -#[cfg(test)] -mod tests { - use super::*; - use crate::cli::RootSubcommand; - use crate::util::test::assert_parse; - - #[tokio::test] - async fn test_scope_and_profile_defaults_to_workspace() { - let os = Os::new().await.unwrap(); - let path = resolve_scope_profile(&os, None).unwrap(); - assert_eq!( - path.to_str(), - workspace_mcp_config_path(&os).unwrap().to_str(), - "No scope or profile should default to the workspace path" - ); - } - - #[tokio::test] - async fn test_resolve_paths() { - let os = Os::new().await.unwrap(); - // workspace - let p = resolve_scope_profile(&os, Some(Scope::Workspace)).unwrap(); - assert_eq!(p, workspace_mcp_config_path(&os).unwrap()); - - // global - let p = resolve_scope_profile(&os, Some(Scope::Global)).unwrap(); - assert_eq!(p, global_mcp_config_path(&os).unwrap()); - } - - #[ignore = "TODO: fix in CI"] - #[tokio::test] - async fn ensure_file_created_and_loaded() { - let os = Os::new().await.unwrap(); - let path = workspace_mcp_config_path(&os).unwrap(); - - let cfg = super::ensure_config_file(&os, &path, &mut vec![]).await.unwrap(); - assert!(path.exists(), "config file should be created"); - assert!(cfg.mcp_servers.is_empty()); - } - - #[tokio::test] - async fn add_then_remove_cycle() { - let os = Os::new().await.unwrap(); - - // 1. add - AddArgs { - name: "local".into(), - command: "echo hi".into(), - args: vec![ - "awslabs.eks-mcp-server".to_string(), - "--allow-write".to_string(), - "--allow-sensitive-data-access".to_string(), - ], - env: vec![], - timeout: None, - scope: None, - disabled: false, - force: false, - } - .execute(&os, &mut vec![]) - .await - .unwrap(); - - let cfg_path = workspace_mcp_config_path(&os).unwrap(); - let cfg: McpServerConfig = - serde_json::from_str(&os.fs.read_to_string(cfg_path.clone()).await.unwrap()).unwrap(); - assert!(cfg.mcp_servers.len() == 1); - - // 2. remove - RemoveArgs { - name: "local".into(), - scope: None, - } - .execute(&os, &mut vec![]) - .await - .unwrap(); - - let cfg: McpServerConfig = serde_json::from_str(&os.fs.read_to_string(cfg_path).await.unwrap()).unwrap(); - assert!(cfg.mcp_servers.is_empty()); - } - - #[test] - fn test_mcp_subcomman_add() { - assert_parse!( - [ - "mcp", - "add", - "--name", - "test_server", - "--command", - "test_command", - "--args", - "awslabs.eks-mcp-server,--allow-write,--allow-sensitive-data-access", - "--env", - "key1=value1,key2=value2" - ], - RootSubcommand::Mcp(McpSubcommand::Add(AddArgs { - name: "test_server".to_string(), - command: "test_command".to_string(), - args: vec![ - "awslabs.eks-mcp-server".to_string(), - "--allow-write".to_string(), - "--allow-sensitive-data-access".to_string(), - ], - scope: None, - env: vec![ - [ - ("key1".to_string(), "value1".to_string()), - ("key2".to_string(), "value2".to_string()) - ] - .into_iter() - .collect() - ], - timeout: None, - disabled: false, - force: false, - })) - ); - } - - #[test] - fn test_mcp_subcomman_remove_workspace() { - assert_parse!( - ["mcp", "remove", "--name", "old"], - RootSubcommand::Mcp(McpSubcommand::Remove(RemoveArgs { - name: "old".into(), - scope: None, - })) - ); - } - - #[test] - fn test_mcp_subcomman_import_profile_force() { - assert_parse!( - ["mcp", "import", "--file", "servers.json", "--force"], - RootSubcommand::Mcp(McpSubcommand::Import(ImportArgs { - file: "servers.json".into(), - scope: None, - force: true, - })) - ); - } - - #[test] - fn test_mcp_subcommand_status_simple() { - assert_parse!( - ["mcp", "status", "--name", "aws"], - RootSubcommand::Mcp(McpSubcommand::Status(StatusArgs { name: "aws".into() })) - ); - } - - #[test] - fn test_mcp_subcommand_list() { - assert_parse!( - ["mcp", "list", "global"], - RootSubcommand::Mcp(McpSubcommand::List(ListArgs { - scope: Some(Scope::Global), - profile: None - })) - ); - } -} diff --git a/crates/chat-cli/src/cli/mod.rs b/crates/chat-cli/src/cli/mod.rs deleted file mode 100644 index 637b86968..000000000 --- a/crates/chat-cli/src/cli/mod.rs +++ /dev/null @@ -1,517 +0,0 @@ -mod chat; -mod debug; -mod diagnostics; -mod feed; -mod issue; -mod mcp; -mod settings; -mod user; - -use std::fmt::Display; -use std::io::{ - Write as _, - stdout, -}; -use std::process::ExitCode; - -use anstream::println; -pub use chat::ConversationState; -use clap::{ - ArgAction, - CommandFactory, - Parser, - Subcommand, - ValueEnum, -}; -use crossterm::style::Stylize; -use eyre::{ - Result, - bail, -}; -use feed::Feed; -use serde::Serialize; -use tracing::{ - Level, - debug, -}; - -use crate::cli::chat::ChatArgs; -use crate::cli::mcp::McpSubcommand; -use crate::cli::user::{ - LoginArgs, - WhoamiArgs, -}; -use crate::logging::{ - LogArgs, - initialize_logging, -}; -use crate::os::Os; -use crate::util::directories::logs_dir; -use crate::util::{ - CHAT_BINARY_NAME, - CLI_BINARY_NAME, - GOV_REGIONS, -}; - -#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, ValueEnum)] -pub enum OutputFormat { - /// Outputs the results as markdown - #[default] - Plain, - /// Outputs the results as JSON - Json, - /// Outputs the results as pretty print JSON - JsonPretty, -} - -impl OutputFormat { - pub fn print(&self, text_fn: TFn, json_fn: JFn) - where - T: std::fmt::Display, - TFn: FnOnce() -> T, - J: Serialize, - JFn: FnOnce() -> J, - { - match self { - OutputFormat::Plain => println!("{}", text_fn()), - OutputFormat::Json => println!("{}", serde_json::to_string(&json_fn()).unwrap()), - OutputFormat::JsonPretty => println!("{}", serde_json::to_string_pretty(&json_fn()).unwrap()), - } - } -} - -/// The Amazon Q CLI -#[deny(missing_docs)] -#[derive(Debug, PartialEq, Subcommand)] -pub enum RootSubcommand { - /// AI assistant in your terminal - Chat(ChatArgs), - /// Log in to Amazon Q - Login(LoginArgs), - /// Log out of Amazon Q - Logout, - /// Print info about the current login session - Whoami(WhoamiArgs), - /// Show the profile associated with this idc user - Profile, - /// Customize appearance & behavior - #[command(alias("setting"))] - Settings(settings::SettingsArgs), - /// Run diagnostic tests - #[command(alias("diagnostics"))] - Diagnostic(diagnostics::DiagnosticArgs), - /// Create a new Github issue - Issue(issue::IssueArgs), - /// Version - #[command(hide = true)] - Version { - /// Show the changelog (use --changelog=all for all versions, or --changelog=x.x.x for a - /// specific version) - #[arg(long, num_args = 0..=1, default_missing_value = "")] - changelog: Option, - }, - /// Model Context Protocol (MCP) - #[command(subcommand)] - Mcp(McpSubcommand), -} - -impl RootSubcommand { - /// Whether the command should have an associated telemetry event. - /// - /// Emitting telemetry takes a long time so the answer is usually no. - pub fn valid_for_telemetry(&self) -> bool { - matches!(self, Self::Chat(_) | Self::Login(_) | Self::Profile | Self::Issue(_)) - } - - pub fn requires_auth(&self) -> bool { - matches!(self, Self::Chat(_) | Self::Profile) - } - - pub async fn execute(self, os: &mut Os) -> Result { - // Check for auth on subcommands that require it. - if self.requires_auth() && !crate::auth::is_logged_in(&mut os.database).await { - bail!( - "You are not logged in, please log in with {}", - format!("{CLI_BINARY_NAME} login").bold() - ); - } - - // Send executed telemetry. - if self.valid_for_telemetry() { - os.telemetry - .send_cli_subcommand_executed(&os.database, &self) - .await - .ok(); - } - - match self { - Self::Diagnostic(args) => args.execute(os).await, - Self::Login(args) => args.execute(os).await, - Self::Logout => user::logout(os).await, - Self::Whoami(args) => args.execute(os).await, - Self::Profile => user::profile(os).await, - Self::Settings(settings_args) => settings_args.execute(os).await, - Self::Issue(args) => args.execute(os).await, - Self::Version { changelog } => Cli::print_version(changelog), - Self::Chat(args) => args.execute(os).await, - Self::Mcp(args) => args.execute(os, &mut std::io::stderr()).await, - } - } -} - -impl Default for RootSubcommand { - fn default() -> Self { - Self::Chat(ChatArgs::default()) - } -} - -impl Display for RootSubcommand { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - let name = match self { - Self::Chat(_) => "chat", - Self::Login(_) => "login", - Self::Logout => "logout", - Self::Whoami(_) => "whoami", - Self::Profile => "profile", - Self::Settings(_) => "settings", - Self::Diagnostic(_) => "diagnostic", - Self::Issue(_) => "issue", - Self::Version { .. } => "version", - Self::Mcp(_) => "mcp", - }; - - write!(f, "{name}") - } -} - -#[derive(Debug, Parser, PartialEq, Default)] -#[command(version, about, name = crate::util::CHAT_BINARY_NAME)] -pub struct Cli { - #[command(subcommand)] - pub subcommand: Option, - /// Increase logging verbosity - #[arg(long, short = 'v', action = ArgAction::Count, global = true)] - pub verbose: u8, - /// Print help for all subcommands - #[arg(long)] - help_all: bool, -} - -impl Cli { - pub async fn execute(self) -> Result { - let subcommand = self.subcommand.unwrap_or_default(); - - // Initialize our logger and keep around the guard so logging can perform as expected. - let _log_guard = initialize_logging(LogArgs { - log_level: match self.verbose > 0 { - true => Some( - match self.verbose { - 1 => Level::WARN, - 2 => Level::INFO, - 3 => Level::DEBUG, - _ => Level::TRACE, - } - .to_string(), - ), - false => None, - }, - log_to_stdout: std::env::var_os("Q_LOG_STDOUT").is_some() || self.verbose > 0, - log_file_path: match subcommand { - RootSubcommand::Chat { .. } => Some( - logs_dir() - .expect("home dir must be set") - .join(format!("{CHAT_BINARY_NAME}.log")), - ), - _ => None, - }, - delete_old_log_file: false, - }); - - // Check for region support. - if let Ok(region) = std::env::var("AWS_REGION") { - if GOV_REGIONS.contains(®ion.as_str()) { - bail!("AWS GovCloud ({region}) is not supported.") - } - } - - debug!(command =? std::env::args().collect::>(), "Command being ran"); - - let mut os = Os::new().await?; - let result = subcommand.execute(&mut os).await; - - let telemetry_result = os.telemetry.finish().await; - let exit_code = result?; - telemetry_result?; - - Ok(exit_code) - } - - fn print_changelog_entry(entry: &feed::Entry) -> Result<()> { - println!("Version {} ({})", entry.version, entry.date); - - if entry.changes.is_empty() { - println!(" No changes recorded for this version."); - } else { - for change in &entry.changes { - let type_label = match change.change_type.as_str() { - "added" => "Added", - "fixed" => "Fixed", - "changed" => "Changed", - other => other, - }; - - println!(" - {}: {}", type_label, change.description); - } - } - - println!(); - Ok(()) - } - - fn print_version(changelog: Option) -> Result { - // If no changelog is requested, display normal version information - if changelog.is_none() { - let _ = writeln!(stdout(), "{}", Self::command().render_version()); - return Ok(ExitCode::SUCCESS); - } - - let changelog_value = changelog.unwrap_or_default(); - let feed = Feed::load(); - - // Display changelog for all versions - if changelog_value == "all" { - let entries = feed.get_all_changelogs(); - if entries.is_empty() { - println!("No changelog information available."); - } else { - println!("Changelog for all versions:"); - for entry in entries { - Self::print_changelog_entry(&entry)?; - } - } - return Ok(ExitCode::SUCCESS); - } - - // Display changelog for a specific version (--changelog=x.x.x) - if !changelog_value.is_empty() { - match feed.get_version_changelog(&changelog_value) { - Some(entry) => { - println!("Changelog for version {}:", changelog_value); - Self::print_changelog_entry(&entry)?; - return Ok(ExitCode::SUCCESS); - }, - None => { - println!("No changelog information available for version {}.", changelog_value); - return Ok(ExitCode::SUCCESS); - }, - } - } - - // Display changelog for the current version (--changelog only) - let current_version = env!("CARGO_PKG_VERSION"); - match feed.get_version_changelog(current_version) { - Some(entry) => { - println!("Changelog for version {}:", current_version); - Self::print_changelog_entry(&entry)?; - }, - None => { - println!("No changelog information available for version {}.", current_version); - }, - } - - Ok(ExitCode::SUCCESS) - } -} - -#[cfg(test)] -mod test { - use super::*; - use crate::util::CHAT_BINARY_NAME; - use crate::util::test::assert_parse; - - #[test] - fn debug_assert() { - Cli::command().debug_assert(); - } - - /// Test flag parsing for the top level [Cli] - #[test] - fn test_flags() { - assert_eq!(Cli::parse_from([CHAT_BINARY_NAME, "-v"]), Cli { - subcommand: None, - verbose: 1, - help_all: false, - }); - - assert_eq!(Cli::parse_from([CHAT_BINARY_NAME, "-vvv"]), Cli { - subcommand: None, - verbose: 3, - help_all: false, - }); - - assert_eq!(Cli::parse_from([CHAT_BINARY_NAME, "--help-all"]), Cli { - subcommand: None, - verbose: 0, - help_all: true, - }); - - assert_eq!(Cli::parse_from([CHAT_BINARY_NAME, "chat", "-vv"]), Cli { - subcommand: Some(RootSubcommand::Chat(ChatArgs { - resume: false, - input: None, - profile: None, - model: None, - trust_all_tools: false, - trust_tools: None, - no_interactive: false - })), - verbose: 2, - help_all: false, - }); - } - - #[test] - fn test_version_changelog() { - assert_parse!(["version", "--changelog"], RootSubcommand::Version { - changelog: Some("".to_string()), - }); - } - - #[test] - fn test_version_changelog_all() { - assert_parse!(["version", "--changelog=all"], RootSubcommand::Version { - changelog: Some("all".to_string()), - }); - } - - #[test] - fn test_version_changelog_specific() { - assert_parse!(["version", "--changelog=1.8.0"], RootSubcommand::Version { - changelog: Some("1.8.0".to_string()), - }); - } - - #[test] - fn test_chat_with_context_profile() { - assert_parse!( - ["chat", "--profile", "my-profile"], - RootSubcommand::Chat(ChatArgs { - resume: false, - input: None, - profile: Some("my-profile".to_string()), - model: None, - trust_all_tools: false, - trust_tools: None, - no_interactive: false - }) - ); - } - - #[test] - fn test_chat_with_context_profile_and_input() { - assert_parse!( - ["chat", "--profile", "my-profile", "Hello"], - RootSubcommand::Chat(ChatArgs { - resume: false, - input: Some("Hello".to_string()), - profile: Some("my-profile".to_string()), - model: None, - trust_all_tools: false, - trust_tools: None, - no_interactive: false - }) - ); - } - - #[test] - fn test_chat_with_context_profile_and_accept_all() { - assert_parse!( - ["chat", "--profile", "my-profile", "--trust-all-tools"], - RootSubcommand::Chat(ChatArgs { - resume: false, - input: None, - profile: Some("my-profile".to_string()), - model: None, - trust_all_tools: true, - trust_tools: None, - no_interactive: false - }) - ); - } - - #[test] - fn test_chat_with_no_interactive_and_resume() { - assert_parse!( - ["chat", "--no-interactive", "--resume"], - RootSubcommand::Chat(ChatArgs { - resume: true, - input: None, - profile: None, - model: None, - trust_all_tools: false, - trust_tools: None, - no_interactive: true - }) - ); - assert_parse!( - ["chat", "--non-interactive", "-r"], - RootSubcommand::Chat(ChatArgs { - resume: true, - input: None, - profile: None, - model: None, - trust_all_tools: false, - trust_tools: None, - no_interactive: true - }) - ); - } - - #[test] - fn test_chat_with_tool_trust_all() { - assert_parse!( - ["chat", "--trust-all-tools"], - RootSubcommand::Chat(ChatArgs { - resume: false, - input: None, - profile: None, - model: None, - trust_all_tools: true, - trust_tools: None, - no_interactive: false - }) - ); - } - - #[test] - fn test_chat_with_tool_trust_none() { - assert_parse!( - ["chat", "--trust-tools="], - RootSubcommand::Chat(ChatArgs { - resume: false, - input: None, - profile: None, - model: None, - trust_all_tools: false, - trust_tools: Some(vec!["".to_string()]), - no_interactive: false - }) - ); - } - - #[test] - fn test_chat_with_tool_trust_some() { - assert_parse!( - ["chat", "--trust-tools=fs_read,fs_write"], - RootSubcommand::Chat(ChatArgs { - resume: false, - input: None, - profile: None, - model: None, - trust_all_tools: false, - trust_tools: Some(vec!["fs_read".to_string(), "fs_write".to_string()]), - no_interactive: false - }) - ); - } -} diff --git a/crates/chat-cli/src/cli/settings.rs b/crates/chat-cli/src/cli/settings.rs deleted file mode 100644 index 0993fb5dc..000000000 --- a/crates/chat-cli/src/cli/settings.rs +++ /dev/null @@ -1,155 +0,0 @@ -use std::process::ExitCode; - -use anstream::println; -use clap::{ - ArgGroup, - Args, - Subcommand, -}; -use eyre::{ - Result, - WrapErr, - bail, -}; -use globset::Glob; -use serde_json::json; - -use super::OutputFormat; -use crate::database::settings::Setting; -use crate::os::Os; -use crate::util::directories; - -#[derive(Clone, Debug, Subcommand, PartialEq, Eq)] -pub enum SettingsSubcommands { - /// Open the settings file - Open, - /// List all the settings - All { - /// Format of the output - #[arg(long, short, value_enum, default_value_t)] - format: OutputFormat, - /// Whether or not we want to modify state instead - #[arg(long, short, hide = true)] - state: bool, - }, -} - -#[derive(Clone, Debug, Args, PartialEq, Eq)] -#[command(subcommand_negates_reqs = true)] -#[command(args_conflicts_with_subcommands = true)] -#[command(group(ArgGroup::new("vals").requires("key").args(&["value", "delete", "format"])))] -pub struct SettingsArgs { - #[command(subcommand)] - cmd: Option, - /// key - key: Option, - /// value - value: Option, - /// Delete a value - #[arg(long, short)] - delete: bool, - /// Format of the output - #[arg(long, short, value_enum, default_value_t)] - format: OutputFormat, -} - -impl SettingsArgs { - pub async fn execute(&self, os: &mut Os) -> Result { - match self.cmd { - Some(SettingsSubcommands::Open) => { - let file = directories::settings_path().context("Could not get settings path")?; - if let Ok(editor) = os.env.get("EDITOR") { - tokio::process::Command::new(editor).arg(file).spawn()?.wait().await?; - Ok(ExitCode::SUCCESS) - } else { - bail!("The EDITOR environment variable is not set") - } - }, - Some(SettingsSubcommands::All { format, state }) => { - let settings = match state { - true => os.database.get_all_entries()?, - false => os.database.settings.map().clone(), - }; - - match format { - OutputFormat::Plain => { - for (key, value) in settings { - println!("{key} = {value}"); - } - }, - OutputFormat::Json => println!("{}", serde_json::to_string(&settings)?), - OutputFormat::JsonPretty => { - println!("{}", serde_json::to_string_pretty(&settings)?); - }, - } - - Ok(ExitCode::SUCCESS) - }, - None => { - let Some(key) = &self.key else { - return Ok(ExitCode::SUCCESS); - }; - - let key = Setting::try_from(key.as_str())?; - match (&self.value, self.delete) { - (None, false) => match os.database.settings.get(key) { - Some(value) => { - match self.format { - OutputFormat::Plain => match value.as_str() { - Some(value) => println!("{value}"), - None => println!("{value:#}"), - }, - OutputFormat::Json => println!("{value}"), - OutputFormat::JsonPretty => println!("{value:#}"), - } - Ok(ExitCode::SUCCESS) - }, - None => match self.format { - OutputFormat::Plain => Err(eyre::eyre!("No value associated with {key}")), - OutputFormat::Json | OutputFormat::JsonPretty => { - println!("null"); - Ok(ExitCode::SUCCESS) - }, - }, - }, - (Some(value_str), false) => { - let value = serde_json::from_str(value_str).unwrap_or_else(|_| json!(value_str)); - os.database.settings.set(key, value).await?; - Ok(ExitCode::SUCCESS) - }, - (None, true) => { - let glob = Glob::new(key.as_ref()) - .context("Could not create glob")? - .compile_matcher(); - let map = os.database.settings.map(); - let keys_to_remove = map.keys().filter(|key| glob.is_match(key)).cloned().collect::>(); - - match keys_to_remove.len() { - 0 => { - return Err(eyre::eyre!("No settings found matching {key}")); - }, - 1 => { - println!("Removing {:?}", keys_to_remove[0]); - os.database - .settings - .remove(Setting::try_from(keys_to_remove[0].as_str())?) - .await?; - }, - _ => { - for key in &keys_to_remove { - if let Ok(key) = Setting::try_from(key.as_str()) { - println!("Removing `{key}`"); - os.database.settings.remove(key).await?; - } - } - }, - } - - Ok(ExitCode::SUCCESS) - }, - _ => Ok(ExitCode::SUCCESS), - } - }, - } - } -} diff --git a/crates/chat-cli/src/cli/user.rs b/crates/chat-cli/src/cli/user.rs deleted file mode 100644 index 50e761b9b..000000000 --- a/crates/chat-cli/src/cli/user.rs +++ /dev/null @@ -1,426 +0,0 @@ -use std::fmt; -use std::fmt::Display; -use std::process::{ - ExitCode, - exit, -}; -use std::time::Duration; - -use anstream::{ - eprintln, - println, -}; -use clap::{ - Args, - Subcommand, -}; -use crossterm::style::Stylize; -use dialoguer::Select; -use eyre::{ - Result, - bail, -}; -use serde_json::json; -use tokio::signal::ctrl_c; -use tracing::{ - error, - info, -}; - -use super::OutputFormat; -use crate::api_client::list_available_profiles; -use crate::auth::builder_id::{ - BuilderIdToken, - PollCreateToken, - TokenType, - poll_create_token, - start_device_authorization, -}; -use crate::auth::pkce::start_pkce_authorization; -use crate::os::Os; -use crate::telemetry::{ - QProfileSwitchIntent, - TelemetryResult, -}; -use crate::util::spinner::{ - Spinner, - SpinnerComponent, -}; -use crate::util::system_info::is_remote; -use crate::util::{ - CLI_BINARY_NAME, - PRODUCT_NAME, - choose, - input, -}; - -#[derive(Args, Debug, PartialEq, Eq, Clone, Default)] -pub struct LoginArgs { - /// License type (pro for Identity Center, free for Builder ID) - #[arg(long, value_enum)] - pub license: Option, - - /// Identity provider URL (for Identity Center) - #[arg(long)] - pub identity_provider: Option, - - /// Region (for Identity Center) - #[arg(long)] - pub region: Option, - - /// Always use the OAuth device flow for authentication. Useful for instances where browser - /// redirects cannot be handled. - #[arg(long)] - pub use_device_flow: bool, -} - -impl LoginArgs { - pub async fn execute(self, os: &mut Os) -> Result { - if crate::auth::is_logged_in(&mut os.database).await { - eyre::bail!( - "Already logged in, please logout with {} first", - format!("{CLI_BINARY_NAME} logout").magenta() - ); - } - - let login_method = match self.license { - Some(LicenseType::Free) => AuthMethod::BuilderId, - Some(LicenseType::Pro) => AuthMethod::IdentityCenter, - None => { - if self.identity_provider.is_some() && self.region.is_some() { - // If license is specified and --identity-provider and --region are specified, - // the license is determined to be pro - AuthMethod::IdentityCenter - } else { - // --license is not specified, prompt the user to choose - let options = [AuthMethod::BuilderId, AuthMethod::IdentityCenter]; - let i = match choose("Select login method", &options)? { - Some(i) => i, - None => bail!("No login method selected"), - }; - options[i] - } - }, - }; - - match login_method { - AuthMethod::BuilderId | AuthMethod::IdentityCenter => { - let (start_url, region) = match login_method { - AuthMethod::BuilderId => (None, None), - AuthMethod::IdentityCenter => { - let default_start_url = match self.identity_provider { - Some(start_url) => Some(start_url), - None => os.database.get_start_url()?, - }; - let default_region = match self.region { - Some(region) => Some(region), - None => os.database.get_idc_region()?, - }; - - let start_url = input("Enter Start URL", default_start_url.as_deref())?; - let region = input("Enter Region", default_region.as_deref())?; - - let _ = os.database.set_start_url(start_url.clone()); - let _ = os.database.set_idc_region(region.clone()); - - (Some(start_url), Some(region)) - }, - }; - - // Remote machine won't be able to handle browser opening and redirects, - // hence always use device code flow. - if is_remote() || self.use_device_flow { - try_device_authorization(os, start_url.clone(), region.clone()).await?; - } else { - let (client, registration) = start_pkce_authorization(start_url.clone(), region.clone()).await?; - - match crate::util::open::open_url_async(®istration.url).await { - // If it succeeded, finish PKCE. - Ok(()) => { - let mut spinner = Spinner::new(vec![ - SpinnerComponent::Spinner, - SpinnerComponent::Text(" Logging in...".into()), - ]); - let ctrl_c_stream = ctrl_c(); - tokio::select! { - res = registration.finish(&client, Some(&mut os.database)) => res?, - Ok(_) = ctrl_c_stream => { - #[allow(clippy::exit)] - exit(1); - }, - } - os.telemetry.send_user_logged_in().ok(); - spinner.stop_with_message("Logged in".into()); - }, - // If we are unable to open the link with the browser, then fallback to - // the device code flow. - Err(err) => { - error!(%err, "Failed to open URL with browser, falling back to device code flow"); - - // Try device code flow. - try_device_authorization(os, start_url.clone(), region.clone()).await?; - }, - } - } - }, - }; - - if login_method == AuthMethod::IdentityCenter { - select_profile_interactive(os, true).await?; - } - - Ok(ExitCode::SUCCESS) - } -} - -pub async fn logout(os: &mut Os) -> Result { - let _ = crate::auth::logout(&mut os.database).await; - - eprintln!("You are now logged out"); - eprintln!( - "Run {} to log back in to {PRODUCT_NAME}", - format!("{CLI_BINARY_NAME} login").magenta() - ); - - Ok(ExitCode::SUCCESS) -} - -#[derive(Args, Debug, PartialEq, Eq, Clone, Default)] -pub struct WhoamiArgs { - /// Output format to use - #[arg(long, short, value_enum, default_value_t)] - format: OutputFormat, -} - -impl WhoamiArgs { - pub async fn execute(self, os: &mut Os) -> Result { - let builder_id = BuilderIdToken::load(&os.database).await; - - match builder_id { - Ok(Some(token)) => { - self.format.print( - || match token.token_type() { - TokenType::BuilderId => "Logged in with Builder ID".into(), - TokenType::IamIdentityCenter => { - format!( - "Logged in with IAM Identity Center ({})", - token.start_url.as_ref().unwrap() - ) - }, - }, - || { - json!({ - "accountType": match token.token_type() { - TokenType::BuilderId => "BuilderId", - TokenType::IamIdentityCenter => "IamIdentityCenter", - }, - "startUrl": token.start_url, - "region": token.region, - }) - }, - ); - - if matches!(token.token_type(), TokenType::IamIdentityCenter) { - if let Ok(Some(profile)) = os.database.get_auth_profile() { - color_print::cprintln!("\nProfile:\n{}\n{}\n", profile.profile_name, profile.arn); - } - } - - Ok(ExitCode::SUCCESS) - }, - _ => { - self.format.print(|| "Not logged in", || json!({ "account": null })); - Ok(ExitCode::FAILURE) - }, - } - } -} - -#[derive(Debug, Clone, Copy, PartialEq, Eq, clap::ValueEnum)] -pub enum LicenseType { - /// Free license with Builder ID - Free, - /// Pro license with Identity Center - Pro, -} - -pub async fn profile(os: &mut Os) -> Result { - if let Ok(Some(token)) = BuilderIdToken::load(&os.database).await { - if matches!(token.token_type(), TokenType::BuilderId) { - bail!("This command is only available for Pro users"); - } - } - - select_profile_interactive(os, false).await?; - - Ok(ExitCode::SUCCESS) -} - -#[derive(Debug, Clone, Copy, PartialEq, Eq)] -enum AuthMethod { - /// Builder ID (free) - BuilderId, - /// IdC (enterprise) - IdentityCenter, -} - -impl Display for AuthMethod { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - match self { - AuthMethod::BuilderId => write!(f, "Use for Free with Builder ID"), - AuthMethod::IdentityCenter => write!(f, "Use with Pro license"), - } - } -} - -#[derive(Subcommand, Debug, PartialEq, Eq)] -pub enum UserSubcommand { - Profile, -} - -async fn try_device_authorization(os: &mut Os, start_url: Option, region: Option) -> Result<()> { - let device_auth = start_device_authorization(&os.database, start_url.clone(), region.clone()).await?; - - println!(); - println!("Confirm the following code in the browser"); - println!("Code: {}", device_auth.user_code.bold()); - println!(); - - let print_open_url = || println!("Open this URL: {}", device_auth.verification_uri_complete); - - if is_remote() { - print_open_url(); - } else if let Err(err) = crate::util::open::open_url_async(&device_auth.verification_uri_complete).await { - error!(%err, "Failed to open URL with browser"); - print_open_url(); - } - - let mut spinner = Spinner::new(vec![ - SpinnerComponent::Spinner, - SpinnerComponent::Text(" Logging in...".into()), - ]); - - loop { - let ctrl_c_stream = ctrl_c(); - tokio::select! { - _ = tokio::time::sleep(Duration::from_secs(device_auth.interval.try_into().unwrap_or(1))) => (), - Ok(_) = ctrl_c_stream => { - #[allow(clippy::exit)] - exit(1); - } - } - match poll_create_token( - &os.database, - device_auth.device_code.clone(), - start_url.clone(), - region.clone(), - ) - .await - { - PollCreateToken::Pending => {}, - PollCreateToken::Complete => { - os.telemetry.send_user_logged_in().ok(); - spinner.stop_with_message("Logged in".into()); - break; - }, - PollCreateToken::Error(err) => { - spinner.stop(); - return Err(err.into()); - }, - }; - } - Ok(()) -} - -async fn select_profile_interactive(os: &mut Os, whoami: bool) -> Result<()> { - let mut spinner = Spinner::new(vec![ - SpinnerComponent::Spinner, - SpinnerComponent::Text(" Fetching profiles...".into()), - ]); - let profiles = list_available_profiles(&os.env, &os.fs, &mut os.database).await?; - if profiles.is_empty() { - info!("Available profiles was empty"); - return Ok(()); - } - - let sso_region = os.database.get_idc_region()?; - let total_profiles = profiles.len() as i64; - - if whoami && profiles.len() == 1 { - if let Some(profile_region) = profiles[0].arn.split(':').nth(3) { - os.telemetry - .send_profile_state( - QProfileSwitchIntent::Update, - profile_region.to_string(), - TelemetryResult::Succeeded, - sso_region, - ) - .ok(); - } - - spinner.stop_with_message(String::new()); - os.database.set_auth_profile(&profiles[0])?; - return Ok(()); - } - - let mut items: Vec = profiles - .iter() - .map(|p| format!("{} (arn: {})", p.profile_name, p.arn)) - .collect(); - let active_profile = os.database.get_auth_profile()?; - - if let Some(default_idx) = active_profile - .as_ref() - .and_then(|active| profiles.iter().position(|p| p.arn == active.arn)) - { - items[default_idx] = format!("{} (active)", items[default_idx].as_str()); - } - - spinner.stop_with_message(String::new()); - let selected = Select::with_theme(&crate::util::dialoguer_theme()) - .with_prompt("Select an IAM Identity Center profile") - .items(&items) - .default(0) - .interact_opt()?; - - match selected { - Some(i) => { - let chosen = &profiles[i]; - eprintln!("Profile set"); - os.database.set_auth_profile(chosen)?; - - if let Some(profile_region) = chosen.arn.split(':').nth(3) { - let intent = if whoami { - QProfileSwitchIntent::Auth - } else { - QProfileSwitchIntent::User - }; - - os.telemetry - .send_did_select_profile( - intent, - profile_region.to_string(), - TelemetryResult::Succeeded, - sso_region, - Some(total_profiles), - ) - .ok(); - } - }, - None => { - os.telemetry - .send_did_select_profile( - QProfileSwitchIntent::User, - "not-set".to_string(), - TelemetryResult::Cancelled, - sso_region, - Some(total_profiles), - ) - .ok(); - - bail!("No profile selected.\n"); - }, - } - - Ok(()) -} diff --git a/crates/chat-cli/src/database/mod.rs b/crates/chat-cli/src/database/mod.rs deleted file mode 100644 index bccee45fd..000000000 --- a/crates/chat-cli/src/database/mod.rs +++ /dev/null @@ -1,617 +0,0 @@ -pub mod settings; - -use std::ops::Deref; -use std::path::Path; -use std::str::FromStr; -use std::sync::PoisonError; - -use aws_sdk_cognitoidentity::primitives::DateTimeFormat; -use aws_sdk_cognitoidentity::types::Credentials; -use r2d2::Pool; -use r2d2_sqlite::SqliteConnectionManager; -use rusqlite::types::FromSql; -use rusqlite::{ - Connection, - Error, - ToSql, - params, -}; -use serde::de::DeserializeOwned; -use serde::{ - Deserialize, - Serialize, -}; -use serde_json::{ - Map, - Value, -}; -use settings::Settings; -use thiserror::Error; -use tracing::{ - error, - info, - trace, -}; -use uuid::Uuid; - -use crate::cli::ConversationState; -use crate::util::directories::{ - DirectoryError, - database_path, -}; - -macro_rules! migrations { - ($($name:expr),*) => {{ - &[ - $( - Migration { - name: $name, - sql: include_str!(concat!("sqlite_migrations/", $name, ".sql")), - } - ),* - ] - }}; -} - -const CREDENTIALS_KEY: &str = "telemetry-cognito-credentials"; -const CLIENT_ID_KEY: &str = "telemetryClientId"; -const CODEWHISPERER_PROFILE_KEY: &str = "api.codewhisperer.profile"; -const START_URL_KEY: &str = "auth.idc.start-url"; -const IDC_REGION_KEY: &str = "auth.idc.region"; -// We include this key to remove for backwards compatibility -const CUSTOMIZATION_STATE_KEY: &str = "api.selectedCustomization"; - -const MIGRATIONS: &[Migration] = migrations![ - "000_migration_table", - "001_history_table", - "002_drop_history_in_ssh_docker", - "003_improved_history_timing", - "004_state_table", - "005_auth_table", - "006_make_state_blob", - "007_conversations_table" -]; - -#[derive(Debug, serde::Deserialize, serde::Serialize)] -pub struct CredentialsJson { - pub access_key_id: Option, - pub secret_key: Option, - pub session_token: Option, - pub expiration: Option, -} - -#[derive(Clone, Debug, Deserialize, Serialize)] -pub struct AuthProfile { - pub arn: String, - pub profile_name: String, -} - -impl From for AuthProfile { - fn from(profile: amzn_codewhisperer_client::types::Profile) -> Self { - Self { - arn: profile.arn, - profile_name: profile.profile_name, - } - } -} - -#[derive(Clone, PartialEq, Eq, PartialOrd, Ord, serde::Serialize, serde::Deserialize)] -#[serde(transparent)] -pub struct Secret(pub String); - -impl std::fmt::Debug for Secret { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.debug_struct("Secret").finish() - } -} - -impl From for Secret -where - T: Into, -{ - fn from(value: T) -> Self { - Self(value.into()) - } -} - -// A cloneable error -#[derive(Debug, Clone, thiserror::Error)] -#[error("Failed to open database: {}", .0)] -pub struct DbOpenError(pub(crate) String); - -#[derive(Debug, Error)] -pub enum DatabaseError { - #[error(transparent)] - IoError(#[from] std::io::Error), - #[error(transparent)] - JsonError(#[from] serde_json::Error), - #[error(transparent)] - FigUtilError(#[from] crate::util::UtilError), - #[error(transparent)] - DirectoryError(#[from] DirectoryError), - #[error(transparent)] - Rusqlite(#[from] rusqlite::Error), - #[error(transparent)] - R2d2(#[from] r2d2::Error), - #[error(transparent)] - DbOpenError(#[from] DbOpenError), - #[error("{}", .0)] - PoisonError(String), - #[error(transparent)] - StringFromUtf8(#[from] std::string::FromUtf8Error), - #[error(transparent)] - StrFromUtf8(#[from] std::str::Utf8Error), - #[error("`{}` is not a valid setting", .0)] - InvalidSetting(String), -} - -impl From> for DatabaseError { - fn from(value: PoisonError) -> Self { - Self::PoisonError(value.to_string()) - } -} - -#[derive(Debug)] -pub enum Table { - /// The state table contains persistent application state. - State, - /// The conversations tables contains user chat conversations. - Conversations, - /// The auth table contains SSO and Builder ID credentials. - Auth, -} - -impl std::fmt::Display for Table { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { - Table::State => write!(f, "state"), - Table::Conversations => write!(f, "conversations"), - Table::Auth => write!(f, "auth_kv"), - } - } -} - -#[derive(Debug)] -struct Migration { - name: &'static str, - sql: &'static str, -} - -#[derive(Clone, Debug)] -pub struct Database { - pool: Pool, - pub settings: Settings, -} - -impl Database { - pub async fn new() -> Result { - let path = match cfg!(test) { - true => { - return Self { - pool: Pool::builder().build(SqliteConnectionManager::memory()).unwrap(), - settings: Settings::new().await?, - } - .migrate(); - }, - false => database_path()?, - }; - - // make the parent dir if it doesnt exist - if let Some(parent) = path.parent() { - if !parent.exists() { - std::fs::create_dir_all(parent)?; - } - } - - let conn = SqliteConnectionManager::file(&path); - let pool = Pool::builder().build(conn)?; - - // Check the unix permissions of the database file, set them to 0600 if they are not - #[cfg(unix)] - { - use std::os::unix::fs::PermissionsExt; - let metadata = std::fs::metadata(&path)?; - let mut permissions = metadata.permissions(); - if permissions.mode() & 0o777 != 0o600 { - tracing::debug!(?path, "Setting database file permissions to 0600"); - permissions.set_mode(0o600); - std::fs::set_permissions(path, permissions)?; - } - } - - Ok(Self { - pool, - settings: Settings::new().await?, - } - .migrate() - .map_err(|e| DbOpenError(e.to_string()))?) - } - - /// Get all entries for dumping the persistent application state. - pub fn get_all_entries(&self) -> Result, DatabaseError> { - self.all_entries(Table::State) - } - - /// Get cognito credentials used by toolkit telemetry. - pub fn get_credentials_entry(&mut self) -> Result, DatabaseError> { - self.get_json_entry::(Table::State, CREDENTIALS_KEY) - } - - /// Set cognito credentials used by toolkit telemetry. - pub fn set_credentials_entry(&mut self, credentials: &Credentials) -> Result { - self.set_json_entry(Table::State, CREDENTIALS_KEY, CredentialsJson { - access_key_id: credentials.access_key_id.clone(), - secret_key: credentials.secret_key.clone(), - session_token: credentials.session_token.clone(), - expiration: credentials - .expiration - .and_then(|t| t.fmt(DateTimeFormat::DateTime).ok()), - }) - } - - /// Get the current user profile used to determine API endpoints. - pub fn get_auth_profile(&self) -> Result, DatabaseError> { - self.get_json_entry(Table::State, CODEWHISPERER_PROFILE_KEY) - } - - /// Set the current user profile used to determine API endpoints. - pub fn set_auth_profile(&mut self, profile: &AuthProfile) -> Result<(), DatabaseError> { - self.set_json_entry(Table::State, CODEWHISPERER_PROFILE_KEY, profile)?; - self.delete_entry(Table::State, CUSTOMIZATION_STATE_KEY) - } - - /// Unset the current user profile used to determine API endpoints. - pub fn unset_auth_profile(&mut self) -> Result<(), DatabaseError> { - self.delete_entry(Table::State, CODEWHISPERER_PROFILE_KEY)?; - self.delete_entry(Table::State, CUSTOMIZATION_STATE_KEY) - } - - /// Get the client ID used for telemetry requests. - pub fn get_client_id(&mut self) -> Result, DatabaseError> { - Ok(self - .get_json_entry::(Table::State, CLIENT_ID_KEY)? - .and_then(|s| Uuid::from_str(&s).ok())) - } - - /// Set the client ID used for telemetry requests. - pub fn set_client_id(&mut self, client_id: Uuid) -> Result { - self.set_json_entry(Table::State, CLIENT_ID_KEY, client_id.to_string()) - } - - /// Get the start URL used for IdC login. - pub fn get_start_url(&self) -> Result, DatabaseError> { - self.get_json_entry::(Table::State, START_URL_KEY) - } - - /// Set the start URL used for IdC login. - pub fn set_start_url(&mut self, start_url: String) -> Result { - self.set_json_entry(Table::State, START_URL_KEY, start_url) - } - - /// Get the region used for IdC login. - pub fn get_idc_region(&self) -> Result, DatabaseError> { - // Annoyingly, this is encoded as a JSON string on older clients - self.get_json_entry::(Table::State, IDC_REGION_KEY) - } - - /// Set the region used for IdC login. - pub fn set_idc_region(&mut self, region: String) -> Result { - // Annoyingly, this is encoded as a JSON string on older clients - self.set_json_entry(Table::State, IDC_REGION_KEY, region) - } - - // /// Get the model id used for last conversation state. - // pub fn get_last_used_model_id(&self) -> Result, DatabaseError> { - // self.get_json_entry::(Table::State, LAST_USED_MODEL_ID) - // } - - // /// Set the model id used for last conversation state. - // pub fn set_last_used_model_id(&mut self, last_used_model_id: String) -> Result { self.set_json_entry(Table::State, LAST_USED_MODEL_ID, - // last_used_model_id) } - - // /// UnsSet the model id used for last conversation state. - // pub fn unset_last_used_model_id(&mut self) -> Result<(), DatabaseError> { - // self.delete_entry(Table::State, LAST_USED_MODEL_ID) - // } - - /// Get a chat conversation given a path to the conversation. - pub fn get_conversation_by_path( - &mut self, - path: impl AsRef, - ) -> Result, DatabaseError> { - // We would need to encode this to support non utf8 paths. - let path = match path.as_ref().to_str() { - Some(path) => path, - None => return Ok(None), - }; - - self.get_json_entry(Table::Conversations, path) - } - - /// Set a chat conversation given a path to the conversation. - pub fn set_conversation_by_path( - &mut self, - path: impl AsRef, - state: &ConversationState, - ) -> Result { - // We would need to encode this to support non utf8 paths. - let path = match path.as_ref().to_str() { - Some(path) => path, - None => return Ok(0), - }; - - self.set_json_entry(Table::Conversations, path, state) - } - - pub async fn get_secret(&self, key: &str) -> Result, DatabaseError> { - trace!(key, "getting secret"); - Ok(self.get_entry::(Table::Auth, key)?.map(Into::into)) - } - - pub async fn set_secret(&self, key: &str, value: &str) -> Result<(), DatabaseError> { - trace!(key, "setting secret"); - self.set_entry(Table::Auth, key, value)?; - Ok(()) - } - - pub async fn delete_secret(&self, key: &str) -> Result<(), DatabaseError> { - trace!(key, "deleting secret"); - self.delete_entry(Table::Auth, key) - } - - // Private functions. Do not expose. - - fn migrate(self) -> Result { - let mut conn = self.pool.get()?; - let transaction = conn.transaction()?; - - let max_version = max_migration_version(&transaction); - - for (version, migration) in MIGRATIONS.iter().enumerate() { - if has_migration(&transaction, version, max_version)? { - continue; - } - - // execute the migration - transaction.execute_batch(migration.sql)?; - - info!(%version, name =% migration.name, "Applying migration"); - - // insert the migration entry - transaction.execute( - "INSERT INTO migrations (version, migration_time) VALUES (?1, strftime('%s', 'now'));", - params![version], - )?; - } - - // commit the transaction - transaction.commit()?; - - Ok(self) - } - - fn get_entry(&self, table: Table, key: impl AsRef) -> Result, DatabaseError> { - let conn = self.pool.get()?; - let mut stmt = conn.prepare(&format!("SELECT value FROM {table} WHERE key = ?1"))?; - match stmt.query_row([key.as_ref()], |row| row.get(0)) { - Ok(data) => Ok(Some(data)), - Err(Error::QueryReturnedNoRows) => Ok(None), - Err(err) => Err(err.into()), - } - } - - fn set_entry(&self, table: Table, key: impl AsRef, value: impl ToSql) -> Result { - Ok(self.pool.get()?.execute( - &format!("INSERT OR REPLACE INTO {table} (key, value) VALUES (?1, ?2)"), - params![key.as_ref(), value], - )?) - } - - fn get_json_entry( - &self, - table: Table, - key: impl AsRef, - ) -> Result, DatabaseError> { - Ok(match self.get_entry::(table, key.as_ref())? { - Some(value) => serde_json::from_str(&value)?, - None => None, - }) - } - - fn set_json_entry( - &self, - table: Table, - key: impl AsRef, - value: impl Serialize, - ) -> Result { - self.set_entry(table, key, serde_json::to_string(&value)?) - } - - fn delete_entry(&self, table: Table, key: impl AsRef) -> Result<(), DatabaseError> { - self.pool - .get()? - .execute(&format!("DELETE FROM {table} WHERE key = ?1"), [key.as_ref()])?; - Ok(()) - } - - fn all_entries(&self, table: Table) -> Result, DatabaseError> { - let conn = self.pool.get()?; - let mut stmt = conn.prepare(&format!("SELECT key, value FROM {table}"))?; - let rows = stmt.query_map([], |row| { - let key = row.get(0)?; - let value = Value::String(row.get(1)?); - Ok((key, value)) - })?; - - let mut map = Map::new(); - for row in rows { - let (key, value) = row?; - map.insert(key, value); - } - - Ok(map) - } -} - -fn max_migration_version>(conn: &C) -> Option { - let mut stmt = conn.prepare("SELECT MAX(version) FROM migrations").ok()?; - stmt.query_row([], |row| row.get(0)).ok() -} - -fn has_migration>( - conn: &C, - version: usize, - max_version: Option, -) -> Result { - // IMPORTANT: Due to a bug with the first 7 migrations, we have to check manually - // - // Background: the migrations table stores two identifying keys: the sqlite auto-generated - // auto-incrementing key `id`, and the `version` which is the index of the `MIGRATIONS` - // constant. - // - // Checking whether a migration exists would compare id with version, but since id is 1-indexed - // and version is 0-indexed, we would actually skip the last migration! Therefore, it's - // possible users are missing a critical migration (namely, auth_kv table creation) when - // upgrading to the qchat build (which includes two new migrations). Hence, we have to check - // all migrations until version 7 to make sure that nothing is missed. - if version <= 7 { - let mut stmt = match conn.prepare("SELECT COUNT(*) FROM migrations WHERE version = ?1") { - Ok(stmt) => stmt, - // If the migrations table does not exist, then we can reasonably say no migrations - // will exist. - Err(Error::SqliteFailure(_, Some(msg))) if msg.contains("no such table") => { - return Ok(false); - }, - Err(err) => return Err(err.into()), - }; - let count: i32 = stmt.query_row([version], |row| row.get(0))?; - return Ok(count >= 1); - } - - // Continuing from the previously implemented logic - any migrations after the 7th can have a simple - // maximum version check, since we can reasonably assume if any version >=7 will have all - // migrations prior to it. - #[allow(clippy::match_like_matches_macro)] - Ok(match max_version { - Some(max_version) if max_version >= version as i64 => true, - _ => false, - }) -} - -#[cfg(test)] -mod tests { - use super::*; - - fn all_errors() -> Vec { - vec![ - std::io::Error::new(std::io::ErrorKind::InvalidData, "oops").into(), - serde_json::from_str::<()>("oops").unwrap_err().into(), - crate::util::directories::DirectoryError::NoHomeDirectory.into(), - rusqlite::Error::SqliteSingleThreadedMode.into(), - // r2d2::Error - DbOpenError("oops".into()).into(), - PoisonError::<()>::new(()).into(), - ] - } - - #[test] - fn test_error_display_debug() { - for error in all_errors() { - eprintln!("{}", error); - eprintln!("{:?}", error); - } - } - - #[tokio::test] - async fn test_migrate() { - let db = Database::new().await.unwrap(); - - // assert migration count is correct - let max_migration = max_migration_version(&&*db.pool.get().unwrap()); - assert_eq!(max_migration, Some(MIGRATIONS.len() as i64 - 1)); - } - - #[test] - fn list_migrations() { - // Assert the migrations are in order - assert!(MIGRATIONS.windows(2).all(|w| w[0].name <= w[1].name)); - - // Assert the migrations start with their index - assert!( - MIGRATIONS - .iter() - .enumerate() - .all(|(i, m)| m.name.starts_with(&format!("{:03}_", i))) - ); - - // Assert all the files in migrations/ are in the list - let migration_folder = std::path::Path::new(env!("CARGO_MANIFEST_DIR")).join("src/database/sqlite_migrations"); - let migration_count = std::fs::read_dir(migration_folder).unwrap().count(); - assert_eq!(MIGRATIONS.len(), migration_count); - } - - #[tokio::test] - async fn state_table_tests() { - let db = Database::new().await.unwrap(); - - // set - db.set_entry(Table::State, "test", "test").unwrap(); - db.set_entry(Table::State, "int", 1).unwrap(); - db.set_entry(Table::State, "float", 1.0).unwrap(); - db.set_entry(Table::State, "bool", true).unwrap(); - db.set_entry(Table::State, "array", vec![1, 2, 3]).unwrap(); - db.set_entry(Table::State, "object", serde_json::json!({ "test": "test" })) - .unwrap(); - db.set_entry(Table::State, "binary", b"test".to_vec()).unwrap(); - - // unset - db.delete_entry(Table::State, "test").unwrap(); - db.delete_entry(Table::State, "int").unwrap(); - - // is some - assert!(db.get_entry::(Table::State, "test").unwrap().is_none()); - assert!(db.get_entry::(Table::State, "int").unwrap().is_none()); - assert!(db.get_entry::(Table::State, "float").unwrap().is_some()); - assert!(db.get_entry::(Table::State, "bool").unwrap().is_some()); - } - - #[tokio::test] - #[ignore = "not on ci"] - async fn test_set_password() { - let key = "test_set_password"; - let store = Database::new().await.unwrap(); - store.set_secret(key, "test").await.unwrap(); - assert_eq!(store.get_secret(key).await.unwrap().unwrap().0, "test"); - store.delete_secret(key).await.unwrap(); - } - - #[tokio::test] - #[ignore = "not on ci"] - async fn secret_get_time() { - let key = "test_secret_get_time"; - let store = Database::new().await.unwrap(); - store.set_secret(key, "1234").await.unwrap(); - - let now = std::time::Instant::now(); - for _ in 0..100 { - store.get_secret(key).await.unwrap(); - } - - println!("duration: {:?}", now.elapsed() / 100); - - store.delete_secret(key).await.unwrap(); - } - - #[tokio::test] - #[ignore = "not on ci"] - async fn secret_delete() { - let key = "test_secret_delete"; - - let store = Database::new().await.unwrap(); - store.set_secret(key, "1234").await.unwrap(); - assert_eq!(store.get_secret(key).await.unwrap().unwrap().0, "1234"); - store.delete_secret(key).await.unwrap(); - assert_eq!(store.get_secret(key).await.unwrap(), None); - } -} diff --git a/crates/chat-cli/src/database/settings.rs b/crates/chat-cli/src/database/settings.rs deleted file mode 100644 index 763742cc0..000000000 --- a/crates/chat-cli/src/database/settings.rs +++ /dev/null @@ -1,245 +0,0 @@ -use std::fmt::Display; -use std::io::SeekFrom; - -use fd_lock::RwLock; -use serde_json::{ - Map, - Value, -}; -use tokio::fs::File; -use tokio::io::{ - AsyncReadExt, - AsyncSeekExt, - AsyncWriteExt, -}; - -use super::DatabaseError; - -#[derive(Clone, Copy, Debug)] -pub enum Setting { - TelemetryEnabled, - OldClientId, - ShareCodeWhispererContent, - EnabledThinking, - EnabledKnowledge, - SkimCommandKey, - ChatGreetingEnabled, - ApiTimeout, - ChatEditMode, - ChatEnableNotifications, - ApiCodeWhispererService, - ApiQService, - McpInitTimeout, - McpNoInteractiveTimeout, - McpLoadedBefore, - ChatDefaultModel, - ChatDisableAutoCompaction, - ChatEnableHistoryHints, -} - -impl AsRef for Setting { - fn as_ref(&self) -> &'static str { - match self { - Self::TelemetryEnabled => "telemetry.enabled", - Self::OldClientId => "telemetryClientId", - Self::ShareCodeWhispererContent => "codeWhisperer.shareCodeWhispererContentWithAWS", - Self::EnabledThinking => "chat.enableThinking", - Self::EnabledKnowledge => "chat.enableKnowledge", - Self::SkimCommandKey => "chat.skimCommandKey", - Self::ChatGreetingEnabled => "chat.greeting.enabled", - Self::ApiTimeout => "api.timeout", - Self::ChatEditMode => "chat.editMode", - Self::ChatEnableNotifications => "chat.enableNotifications", - Self::ApiCodeWhispererService => "api.codewhisperer.service", - Self::ApiQService => "api.q.service", - Self::McpInitTimeout => "mcp.initTimeout", - Self::McpNoInteractiveTimeout => "mcp.noInteractiveTimeout", - Self::McpLoadedBefore => "mcp.loadedBefore", - Self::ChatDefaultModel => "chat.defaultModel", - Self::ChatDisableAutoCompaction => "chat.disableAutoCompaction", - Self::ChatEnableHistoryHints => "chat.enableHistoryHints", - } - } -} - -impl Display for Setting { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.write_str(self.as_ref()) - } -} - -impl TryFrom<&str> for Setting { - type Error = DatabaseError; - - fn try_from(value: &str) -> Result { - match value { - "telemetry.enabled" => Ok(Self::TelemetryEnabled), - "telemetryClientId" => Ok(Self::OldClientId), - "codeWhisperer.shareCodeWhispererContentWithAWS" => Ok(Self::ShareCodeWhispererContent), - "chat.enableThinking" => Ok(Self::EnabledThinking), - "chat.enableKnowledge" => Ok(Self::EnabledKnowledge), - "chat.skimCommandKey" => Ok(Self::SkimCommandKey), - "chat.greeting.enabled" => Ok(Self::ChatGreetingEnabled), - "api.timeout" => Ok(Self::ApiTimeout), - "chat.editMode" => Ok(Self::ChatEditMode), - "chat.enableNotifications" => Ok(Self::ChatEnableNotifications), - "api.codewhisperer.service" => Ok(Self::ApiCodeWhispererService), - "api.q.service" => Ok(Self::ApiQService), - "mcp.initTimeout" => Ok(Self::McpInitTimeout), - "mcp.noInteractiveTimeout" => Ok(Self::McpNoInteractiveTimeout), - "mcp.loadedBefore" => Ok(Self::McpLoadedBefore), - "chat.defaultModel" => Ok(Self::ChatDefaultModel), - "chat.disableAutoCompaction" => Ok(Self::ChatDisableAutoCompaction), - "chat.enableHistoryHints" => Ok(Self::ChatEnableHistoryHints), - _ => Err(DatabaseError::InvalidSetting(value.to_string())), - } - } -} - -#[derive(Debug, Clone, Default)] -pub struct Settings(Map); - -impl Settings { - pub async fn new() -> Result { - if cfg!(test) { - return Ok(Self::default()); - } - - let path = crate::util::directories::settings_path()?; - - // If the folder doesn't exist, create it. - if let Some(parent) = path.parent() { - if !parent.exists() { - std::fs::create_dir_all(parent)?; - } - } - - Ok(Self(match path.exists() { - true => { - let mut file = RwLock::new(File::open(&path).await?); - let mut buf = Vec::new(); - file.write()?.read_to_end(&mut buf).await?; - serde_json::from_slice(&buf)? - }, - false => { - let mut file = RwLock::new(File::create(path).await?); - file.write()?.write_all(b"{}").await?; - serde_json::Map::new() - }, - })) - } - - pub fn map(&self) -> &'_ Map { - &self.0 - } - - pub fn get(&self, key: Setting) -> Option<&Value> { - self.0.get(key.as_ref()) - } - - pub async fn set(&mut self, key: Setting, value: impl Into) -> Result<(), DatabaseError> { - self.0.insert(key.to_string(), value.into()); - self.save_to_file().await - } - - pub async fn remove(&mut self, key: Setting) -> Result, DatabaseError> { - let key = self.0.remove(key.as_ref()); - self.save_to_file().await?; - Ok(key) - } - - pub fn get_bool(&self, key: Setting) -> Option { - self.get(key).and_then(|value| value.as_bool()) - } - - pub fn get_string(&self, key: Setting) -> Option { - self.get(key).and_then(|value| value.as_str().map(|s| s.into())) - } - - pub fn get_int(&self, key: Setting) -> Option { - self.get(key).and_then(|value| value.as_i64()) - } - - pub async fn save_to_file(&self) -> Result<(), DatabaseError> { - if cfg!(test) { - return Ok(()); - } - - let path = crate::util::directories::settings_path()?; - - // If the folder doesn't exist, create it. - if let Some(parent) = path.parent() { - if !parent.exists() { - tokio::fs::create_dir_all(parent).await?; - } - } - - let mut file_opts = File::options(); - file_opts.create(true).write(true).truncate(true); - - #[cfg(unix)] - file_opts.mode(0o600); - let mut file = RwLock::new(file_opts.open(&path).await?); - let mut lock = file.write()?; - - match serde_json::to_string_pretty(&self.0) { - Ok(json) => lock.write_all(json.as_bytes()).await?, - Err(_err) => { - lock.seek(SeekFrom::Start(0)).await?; - lock.set_len(0).await?; - lock.write_all(b"{}").await?; - }, - } - lock.flush().await?; - - Ok(()) - } -} - -#[cfg(test)] -mod test { - use super::*; - - /// General read/write settings test - #[tokio::test] - async fn test_settings() { - let mut settings = Settings::new().await.unwrap(); - - assert_eq!(settings.get(Setting::TelemetryEnabled), None); - assert_eq!(settings.get(Setting::OldClientId), None); - assert_eq!(settings.get(Setting::ShareCodeWhispererContent), None); - assert_eq!(settings.get(Setting::McpLoadedBefore), None); - assert_eq!(settings.get(Setting::ChatDefaultModel), None); - - settings.set(Setting::TelemetryEnabled, true).await.unwrap(); - settings.set(Setting::OldClientId, "test").await.unwrap(); - settings.set(Setting::ShareCodeWhispererContent, false).await.unwrap(); - settings.set(Setting::McpLoadedBefore, true).await.unwrap(); - settings.set(Setting::ChatDefaultModel, "model 1").await.unwrap(); - - assert_eq!(settings.get(Setting::TelemetryEnabled), Some(&Value::Bool(true))); - assert_eq!( - settings.get(Setting::OldClientId), - Some(&Value::String("test".to_string())) - ); - assert_eq!( - settings.get(Setting::ShareCodeWhispererContent), - Some(&Value::Bool(false)) - ); - assert_eq!(settings.get(Setting::McpLoadedBefore), Some(&Value::Bool(true))); - assert_eq!( - settings.get(Setting::ChatDefaultModel), - Some(&Value::String("model 1".to_string())) - ); - - settings.remove(Setting::TelemetryEnabled).await.unwrap(); - settings.remove(Setting::OldClientId).await.unwrap(); - settings.remove(Setting::ShareCodeWhispererContent).await.unwrap(); - settings.remove(Setting::McpLoadedBefore).await.unwrap(); - - assert_eq!(settings.get(Setting::TelemetryEnabled), None); - assert_eq!(settings.get(Setting::OldClientId), None); - assert_eq!(settings.get(Setting::ShareCodeWhispererContent), None); - assert_eq!(settings.get(Setting::McpLoadedBefore), None); - } -} diff --git a/crates/chat-cli/src/database/sqlite_migrations/000_migration_table.sql b/crates/chat-cli/src/database/sqlite_migrations/000_migration_table.sql deleted file mode 100644 index 1437deb0d..000000000 --- a/crates/chat-cli/src/database/sqlite_migrations/000_migration_table.sql +++ /dev/null @@ -1,5 +0,0 @@ -CREATE TABLE IF NOT EXISTS migrations ( - id INTEGER PRIMARY KEY, - version INTEGER NOT NULL, - migration_time INTEGER NOT NULL -); \ No newline at end of file diff --git a/crates/chat-cli/src/database/sqlite_migrations/001_history_table.sql b/crates/chat-cli/src/database/sqlite_migrations/001_history_table.sql deleted file mode 100644 index 7d2591338..000000000 --- a/crates/chat-cli/src/database/sqlite_migrations/001_history_table.sql +++ /dev/null @@ -1,13 +0,0 @@ -CREATE TABLE IF NOT EXISTS history ( - id INTEGER PRIMARY KEY, - command TEXT, - shell TEXT, - pid INTEGER, - session_id TEXT, - cwd TEXT, - time INTEGER, - in_ssh INTEGER, - in_docker INTEGER, - hostname TEXT, - exit_code INTEGER -); diff --git a/crates/chat-cli/src/database/sqlite_migrations/002_drop_history_in_ssh_docker.sql b/crates/chat-cli/src/database/sqlite_migrations/002_drop_history_in_ssh_docker.sql deleted file mode 100644 index 45e518e02..000000000 --- a/crates/chat-cli/src/database/sqlite_migrations/002_drop_history_in_ssh_docker.sql +++ /dev/null @@ -1,3 +0,0 @@ -ALTER TABLE history DROP COLUMN in_ssh; -ALTER TABLE history DROP COLUMN in_docker; - \ No newline at end of file diff --git a/crates/chat-cli/src/database/sqlite_migrations/003_improved_history_timing.sql b/crates/chat-cli/src/database/sqlite_migrations/003_improved_history_timing.sql deleted file mode 100644 index 58e3bb1c3..000000000 --- a/crates/chat-cli/src/database/sqlite_migrations/003_improved_history_timing.sql +++ /dev/null @@ -1,3 +0,0 @@ -ALTER TABLE history RENAME COLUMN time TO start_time; -ALTER TABLE history ADD COLUMN end_time INTEGER; -ALTER TABLE history ADD COLUMN duration INTEGER; diff --git a/crates/chat-cli/src/database/sqlite_migrations/004_state_table.sql b/crates/chat-cli/src/database/sqlite_migrations/004_state_table.sql deleted file mode 100644 index 3a7b43c00..000000000 --- a/crates/chat-cli/src/database/sqlite_migrations/004_state_table.sql +++ /dev/null @@ -1,4 +0,0 @@ -CREATE TABLE state ( - key TEXT PRIMARY KEY, - value TEXT -); diff --git a/crates/chat-cli/src/database/sqlite_migrations/005_auth_table.sql b/crates/chat-cli/src/database/sqlite_migrations/005_auth_table.sql deleted file mode 100644 index 17b28fb8e..000000000 --- a/crates/chat-cli/src/database/sqlite_migrations/005_auth_table.sql +++ /dev/null @@ -1,6 +0,0 @@ --- We create a separate auth_kv to ensure the data is not available in all the same --- places that the state is available in -CREATE TABLE auth_kv ( - key TEXT PRIMARY KEY, - value TEXT -); diff --git a/crates/chat-cli/src/database/sqlite_migrations/006_make_state_blob.sql b/crates/chat-cli/src/database/sqlite_migrations/006_make_state_blob.sql deleted file mode 100644 index fc3153823..000000000 --- a/crates/chat-cli/src/database/sqlite_migrations/006_make_state_blob.sql +++ /dev/null @@ -1,7 +0,0 @@ -ALTER TABLE state RENAME TO state_old; -CREATE TABLE state ( - key TEXT PRIMARY KEY, - value BLOB -); -INSERT INTO state SELECT key, value FROM state_old; -DROP TABLE state_old; \ No newline at end of file diff --git a/crates/chat-cli/src/database/sqlite_migrations/007_conversations_table.sql b/crates/chat-cli/src/database/sqlite_migrations/007_conversations_table.sql deleted file mode 100644 index 99f4244f5..000000000 --- a/crates/chat-cli/src/database/sqlite_migrations/007_conversations_table.sql +++ /dev/null @@ -1,4 +0,0 @@ -CREATE TABLE conversations ( - key TEXT PRIMARY KEY, - value TEXT -); diff --git a/crates/chat-cli/src/lib.rs b/crates/chat-cli/src/lib.rs deleted file mode 100644 index 3df8f93cf..000000000 --- a/crates/chat-cli/src/lib.rs +++ /dev/null @@ -1,17 +0,0 @@ -#![cfg(not(test))] -//! This lib.rs is only here for testing purposes. -//! `test_mcp_server/test_server.rs` is declared as a separate binary and would need a way to -//! reference types defined inside of this crate, hence the export. -pub mod api_client; -pub mod auth; -pub mod aws_common; -pub mod cli; -pub mod database; -pub mod logging; -pub mod mcp_client; -pub mod os; -pub mod request; -pub mod telemetry; -pub mod util; - -pub use mcp_client::*; diff --git a/crates/chat-cli/src/logging.rs b/crates/chat-cli/src/logging.rs deleted file mode 100644 index 3f8dbd811..000000000 --- a/crates/chat-cli/src/logging.rs +++ /dev/null @@ -1,314 +0,0 @@ -use std::fs::File; -use std::path::Path; -use std::sync::Mutex; - -use thiserror::Error; -use tracing::info; -use tracing::level_filters::LevelFilter; -use tracing_appender::non_blocking::WorkerGuard; -use tracing_subscriber::filter::Directive; -use tracing_subscriber::prelude::*; -use tracing_subscriber::{ - EnvFilter, - Registry, - fmt, -}; - -use crate::util::consts::CHAT_BINARY_NAME; -use crate::util::env_var::Q_LOG_LEVEL; - -const MAX_FILE_SIZE: u64 = 10 * 1024 * 1024; -const DEFAULT_FILTER: LevelFilter = LevelFilter::ERROR; - -static Q_LOG_LEVEL_GLOBAL: Mutex> = Mutex::new(None); -static MAX_LEVEL: Mutex> = Mutex::new(None); -static ENV_FILTER_RELOADABLE_HANDLE: Mutex>> = - Mutex::new(None); - -// A logging error -#[derive(Debug, Error)] -pub enum Error { - #[error(transparent)] - Io(#[from] std::io::Error), - #[error(transparent)] - TracingReload(#[from] tracing_subscriber::reload::Error), -} - -/// Arguments to the initialize_logging function -#[derive(Debug)] -pub struct LogArgs> { - /// The log level to use. When not set, the default log level is used. - pub log_level: Option, - /// Whether or not we log to stdout. - pub log_to_stdout: bool, - /// The log file path which we write logs to. When not set, we do not write to a file. - pub log_file_path: Option, - /// Whether we should delete the log file at each launch. - pub delete_old_log_file: bool, -} - -/// The log guard maintains tracing guards which send log information to other threads. -/// -/// This must be kept alive for logging to function as expected. -#[must_use] -#[derive(Debug)] -pub struct LogGuard { - _file_guard: Option, - _stdout_guard: Option, - _mcp_file_guard: Option, -} - -/// Initialize our application level logging using the given LogArgs. -/// -/// # Returns -/// -/// On success, this returns a guard which must be kept alive. -#[inline] -pub fn initialize_logging>(args: LogArgs) -> Result { - let filter_layer = create_filter_layer(); - let (reloadable_filter_layer, reloadable_handle) = tracing_subscriber::reload::Layer::new(filter_layer); - ENV_FILTER_RELOADABLE_HANDLE.lock().unwrap().replace(reloadable_handle); - let mut mcp_path = None; - - // First we construct the file logging layer if a file name was provided. - let (file_layer, _file_guard) = match args.log_file_path { - Some(log_file_path) => { - let log_path = log_file_path.as_ref(); - - // Make the log path parent directory if it doesn't exist. - if let Some(parent) = log_path.parent() { - if log_path.ends_with(format!("{CHAT_BINARY_NAME}.log")) { - mcp_path = Some(parent.to_path_buf()); - } - std::fs::create_dir_all(parent)?; - } - - // We delete the old log file when requested each time the logger is initialized, otherwise we only - // delete the file when it has grown too large. - if args.delete_old_log_file { - std::fs::remove_file(log_path).ok(); - } else if log_path.exists() && std::fs::metadata(log_path)?.len() > MAX_FILE_SIZE { - std::fs::remove_file(log_path)?; - } - - // Create the new log file or append to the existing one. - let file = if args.delete_old_log_file { - File::create(log_path)? - } else { - File::options().append(true).create(true).open(log_path)? - }; - - // On posix-like systems, we modify permissions so that only the owner has access. - #[cfg(unix)] - { - use std::os::unix::fs::PermissionsExt; - if let Ok(metadata) = file.metadata() { - let mut permissions = metadata.permissions(); - permissions.set_mode(0o600); - file.set_permissions(permissions).ok(); - } - } - - let (non_blocking, guard) = tracing_appender::non_blocking(file); - let file_layer = fmt::layer().with_line_number(true).with_writer(non_blocking); - - (Some(file_layer), Some(guard)) - }, - None => (None, None), - }; - - // If we log to stdout, we need to add this layer to our logger. - let (stdout_layer, _stdout_guard) = if args.log_to_stdout { - let (non_blocking, guard) = tracing_appender::non_blocking(std::io::stdout()); - let stdout_layer = fmt::layer().with_line_number(true).with_writer(non_blocking); - (Some(stdout_layer), Some(guard)) - } else { - (None, None) - }; - - // Set up for mcp servers layer if we are in chat - let (mcp_server_layer, _mcp_file_guard) = if let Some(parent) = mcp_path { - let mcp_path = parent.join("mcp.log"); - if args.delete_old_log_file { - std::fs::remove_file(&mcp_path).ok(); - } else if mcp_path.exists() && std::fs::metadata(&mcp_path)?.len() > MAX_FILE_SIZE { - std::fs::remove_file(&mcp_path)?; - } - let file = if args.delete_old_log_file { - File::create(&mcp_path)? - } else { - File::options().append(true).create(true).open(&mcp_path)? - }; - #[cfg(unix)] - { - use std::os::unix::fs::PermissionsExt; - if let Ok(metadata) = file.metadata() { - let mut permissions = metadata.permissions(); - permissions.set_mode(0o600); - file.set_permissions(permissions).ok(); - } - } - let (non_blocking, guard) = tracing_appender::non_blocking(file); - let file_layer = fmt::layer() - .with_line_number(true) - .with_writer(non_blocking) - .with_filter(EnvFilter::new("mcp=trace")); - (Some(file_layer), Some(guard)) - } else { - (None, None) - }; - - if let Some(level) = args.log_level { - set_log_level(level)?; - } - - // Finally, initialize our logging - let subscriber = tracing_subscriber::registry() - .with(reloadable_filter_layer) - .with(file_layer) - .with(stdout_layer); - - if let Some(mcp_server_layer) = mcp_server_layer { - subscriber.with(mcp_server_layer).init(); - return Ok(LogGuard { - _file_guard, - _stdout_guard, - _mcp_file_guard, - }); - } - - subscriber.init(); - - Ok(LogGuard { - _file_guard, - _stdout_guard, - _mcp_file_guard, - }) -} - -/// Get the current log level by first seeing if it is set in application, then environment, then -/// otherwise using the default -/// -/// # Returns -/// -/// Returns a string identifying the current log level. -pub fn get_log_level() -> String { - Q_LOG_LEVEL_GLOBAL - .lock() - .unwrap() - .clone() - .unwrap_or_else(|| std::env::var(Q_LOG_LEVEL).unwrap_or_else(|_| DEFAULT_FILTER.to_string())) -} - -/// Set the log level to the given level. -/// -/// # Returns -/// -/// On success, returns the old log level. -pub fn set_log_level(level: String) -> Result { - info!("Setting log level to {level:?}"); - - let old_level = get_log_level(); - *Q_LOG_LEVEL_GLOBAL.lock().unwrap() = Some(level); - - let filter_layer = create_filter_layer(); - *MAX_LEVEL.lock().unwrap() = filter_layer.max_level_hint(); - - ENV_FILTER_RELOADABLE_HANDLE - .lock() - .unwrap() - .as_ref() - .expect("set_log_level must not be called before logging is initialized") - .reload(filter_layer)?; - - Ok(old_level) -} - -/// Get the current max log level -/// -/// # Returns -/// -/// The max log level which is set every time the log level is set. -pub fn get_log_level_max() -> LevelFilter { - let max_level = *MAX_LEVEL.lock().unwrap(); - match max_level { - Some(level) => level, - None => { - let filter_layer = create_filter_layer(); - *MAX_LEVEL.lock().unwrap() = filter_layer.max_level_hint(); - filter_layer.max_level_hint().unwrap_or(DEFAULT_FILTER) - }, - } -} - -fn create_filter_layer() -> EnvFilter { - let directive = Directive::from(DEFAULT_FILTER); - - let log_level = Q_LOG_LEVEL_GLOBAL - .lock() - .unwrap() - .clone() - .or_else(|| std::env::var(Q_LOG_LEVEL).ok()); - - match log_level { - Some(level) => EnvFilter::builder() - .with_default_directive(directive) - .parse_lossy(level), - None => EnvFilter::default().add_directive(directive), - } -} - -#[cfg(test)] -mod tests { - use std::fs::read_to_string; - use std::time::Duration; - - use tracing::{ - debug, - error, - trace, - warn, - }; - - use super::*; - - #[test] - fn test_logging() { - // Create a temp path for where we write logs to. - let tempdir = tempfile::TempDir::new().unwrap(); - let log_path = tempdir.path().join("test.log"); - - // Assert that initialize logging simply doesn't panic. - let _guard = initialize_logging(LogArgs { - log_level: Some("trace".to_owned()), - log_to_stdout: true, - log_file_path: Some(&log_path), - delete_old_log_file: true, - }) - .unwrap(); - - // Test that get log level functions as expected. - assert_eq!(get_log_level(), "trace"); - - // Write some log messages out to file. (and stderr) - trace!("abc"); - debug!("def"); - info!("ghi"); - warn!("jkl"); - error!("mno"); - - // Test that set log level functions as expected. - // This also restores the default log level. - set_log_level(DEFAULT_FILTER.to_string()).unwrap(); - assert_eq!(get_log_level(), DEFAULT_FILTER.to_string()); - - // Sleep in order to ensure logs get written to file, then assert on the contents - std::thread::sleep(Duration::from_millis(100)); - let logs = read_to_string(&log_path).unwrap(); - for i in [ - "TRACE", "DEBUG", "INFO", "WARN", "ERROR", "abc", "def", "ghi", "jkl", "mno", - ] { - assert!(logs.contains(i)); - } - } -} diff --git a/crates/chat-cli/src/main.rs b/crates/chat-cli/src/main.rs deleted file mode 100644 index 1f643e7d6..000000000 --- a/crates/chat-cli/src/main.rs +++ /dev/null @@ -1,52 +0,0 @@ -mod api_client; -mod auth; -mod aws_common; -mod cli; -mod database; -mod logging; -mod mcp_client; -mod os; -mod request; -mod telemetry; -mod util; - -use std::process::ExitCode; - -use anstream::eprintln; -use clap::Parser; -use crossterm::style::Stylize; -use eyre::Result; -use logging::get_log_level_max; -use tracing::metadata::LevelFilter; - -#[global_allocator] -static GLOBAL: mimalloc::MiMalloc = mimalloc::MiMalloc; - -fn main() -> Result { - color_eyre::install()?; - - let parsed = match cli::Cli::try_parse() { - Ok(cli) => cli, - Err(err) => { - err.print().ok(); - return Ok(ExitCode::from(err.exit_code().try_into().unwrap_or(2))); - }, - }; - - let verbose = parsed.verbose > 0; - let runtime = tokio::runtime::Builder::new_multi_thread().enable_all().build()?; - let result = runtime.block_on(parsed.execute()); - - match result { - Ok(exit_code) => Ok(exit_code), - Err(err) => { - if verbose || get_log_level_max() > LevelFilter::INFO { - eprintln!("{} {err:?}", "error:".bold().red()); - } else { - eprintln!("{} {err}", "error:".bold().red()); - } - - Ok(ExitCode::FAILURE) - }, - } -} diff --git a/crates/chat-cli/src/mcp_client/client.rs b/crates/chat-cli/src/mcp_client/client.rs deleted file mode 100644 index 004c0623a..000000000 --- a/crates/chat-cli/src/mcp_client/client.rs +++ /dev/null @@ -1,1147 +0,0 @@ -use std::collections::HashMap; -use std::process::Stdio; -use std::sync::atomic::{ - AtomicBool, - AtomicU64, - Ordering, -}; -use std::sync::{ - Arc, - RwLock as SyncRwLock, -}; -use std::time::Duration; - -use serde::{ - Deserialize, - Serialize, -}; -use thiserror::Error; -use tokio::time; -use tokio::time::error::Elapsed; - -use super::transport::base_protocol::{ - JsonRpcMessage, - JsonRpcNotification, - JsonRpcRequest, - JsonRpcVersion, -}; -use super::transport::stdio::JsonRpcStdioTransport; -use super::transport::{ - self, - Transport, - TransportError, -}; -use super::{ - JsonRpcResponse, - Listener as _, - LogListener, - Messenger, - PaginationSupportedOps, - PromptGet, - PromptsListResult, - ResourceTemplatesListResult, - ResourcesListResult, - ServerCapabilities, - ToolsListResult, -}; -use crate::util::process::{ - Pid, - terminate_process, -}; - -pub type ClientInfo = serde_json::Value; -pub type StdioTransport = JsonRpcStdioTransport; - -/// Represents the capabilities of a client in the Model Context Protocol. -/// This structure is sent to the server during initialization to communicate -/// what features the client supports and provide information about the client. -/// When features are added to the client, these should be declared in the [From] trait implemented -/// for the struct. -#[derive(Default, Debug, Serialize)] -#[serde(rename_all = "camelCase")] -struct ClientCapabilities { - protocol_version: JsonRpcVersion, - capabilities: HashMap, - client_info: serde_json::Value, -} - -impl From for ClientCapabilities { - fn from(client_info: ClientInfo) -> Self { - ClientCapabilities { - client_info, - ..Default::default() - } - } -} - -#[derive(Debug, Deserialize)] -pub struct ClientConfig { - pub server_name: String, - pub bin_path: String, - pub args: Vec, - pub timeout: u64, - pub client_info: serde_json::Value, - pub env: Option>, -} - -#[allow(dead_code)] -#[derive(Debug, Error)] -pub enum ClientError { - #[error(transparent)] - TransportError(#[from] TransportError), - #[error(transparent)] - Io(#[from] std::io::Error), - #[error(transparent)] - Serialization(#[from] serde_json::Error), - #[error("Operation timed out: {context}")] - RuntimeError { - #[source] - source: tokio::time::error::Elapsed, - context: String, - }, - #[error("Unexpected msg type encountered")] - UnexpectedMsgType, - #[error("{0}")] - NegotiationError(String), - #[error("Failed to obtain process id")] - MissingProcessId, - #[error("Invalid path received")] - InvalidPath, - #[error("{0}")] - ProcessKillError(String), - #[error("{0}")] - PoisonError(String), -} - -impl From<(tokio::time::error::Elapsed, String)> for ClientError { - fn from((error, context): (tokio::time::error::Elapsed, String)) -> Self { - ClientError::RuntimeError { source: error, context } - } -} - -#[derive(Debug)] -pub struct Client { - server_name: String, - transport: Arc, - timeout: u64, - server_process_id: Option, - client_info: serde_json::Value, - current_id: Arc, - pub messenger: Option>, - // TODO: move this to tool manager that way all the assets are treated equally - pub prompt_gets: Arc>>, - pub is_prompts_out_of_date: Arc, -} - -impl Clone for Client { - fn clone(&self) -> Self { - Self { - server_name: self.server_name.clone(), - transport: self.transport.clone(), - timeout: self.timeout, - // Note that we cannot have an id for the clone because we would kill the original - // process when we drop the clone - server_process_id: None, - client_info: self.client_info.clone(), - current_id: self.current_id.clone(), - messenger: None, - prompt_gets: self.prompt_gets.clone(), - is_prompts_out_of_date: self.is_prompts_out_of_date.clone(), - } - } -} - -impl Client { - pub fn from_config(config: ClientConfig) -> Result { - let ClientConfig { - server_name, - bin_path, - args, - timeout, - client_info, - env, - } = config; - let child = { - let expanded_bin_path = shellexpand::tilde(&bin_path); - - // On Windows, we need to use cmd.exe to run the binary with arguments because Tokio - // always assumes that the program has an .exe extension, which is not the case for - // helpers like `uvx` or `npx`. - let mut command = if cfg!(windows) { - let mut cmd = tokio::process::Command::new("cmd.exe"); - cmd.args(["/C", &Self::build_windows_command(&expanded_bin_path, args)]); - cmd - } else { - let mut cmd = tokio::process::Command::new(expanded_bin_path.to_string()); - cmd.args(args); - cmd - }; - - command - .stdin(Stdio::piped()) - .stdout(Stdio::piped()) - .stderr(Stdio::piped()) - .envs(std::env::vars()); - - #[cfg(not(windows))] - command.process_group(0); - - if let Some(env) = env { - for (env_name, env_value) in env { - command.env(env_name, env_value); - } - } - - command.spawn()? - }; - - let server_process_id = child.id().ok_or(ClientError::MissingProcessId)?; - let server_process_id = Some(Pid::from_u32(server_process_id)); - - let transport = Arc::new(transport::stdio::JsonRpcStdioTransport::client(child)?); - Ok(Self { - server_name, - transport, - timeout, - server_process_id, - client_info, - current_id: Arc::new(AtomicU64::new(0)), - messenger: None, - prompt_gets: Arc::new(SyncRwLock::new(HashMap::new())), - is_prompts_out_of_date: Arc::new(AtomicBool::new(false)), - }) - } - - fn build_windows_command(bin_path: &str, args: Vec) -> String { - let mut parts = Vec::new(); - - // Add the binary path, quoted if necessary - parts.push(Self::quote_windows_arg(bin_path)); - - // Add all arguments, quoted if necessary - for arg in args { - parts.push(Self::quote_windows_arg(&arg)); - } - - parts.join(" ") - } - - fn quote_windows_arg(arg: &str) -> String { - // If the argument doesn't need quoting, return as-is - if !arg.chars().any(|c| " \t\n\r\"".contains(c)) { - return arg.to_string(); - } - - let mut result = String::from("\""); - let mut backslashes = 0; - - for c in arg.chars() { - match c { - '\\' => { - backslashes += 1; - result.push('\\'); - }, - '"' => { - // Escape all preceding backslashes and the quote - for _ in 0..backslashes { - result.push('\\'); - } - result.push_str("\\\""); - backslashes = 0; - }, - _ => { - backslashes = 0; - result.push(c); - }, - } - } - - // Escape trailing backslashes before the closing quote - for _ in 0..backslashes { - result.push('\\'); - } - - result.push('"'); - result - } -} - -impl Drop for Client -where - T: Transport, -{ - // IF the servers are implemented well, they will shutdown once the pipe closes. - // This drop trait is here as a fail safe to ensure we don't leave behind any orphans. - fn drop(&mut self) { - if let Some(process_id) = self.server_process_id { - let _ = terminate_process(process_id); - } - } -} - -impl Client -where - T: Transport, -{ - /// Exchange of information specified as per https://spec.modelcontextprotocol.io/specification/2024-11-05/basic/lifecycle/#initialization - /// - /// Also done are the following: - /// - Spawns task for listening to server driven workflows - /// - Spawns tasks to ask for relevant info such as tools and prompts in accordance to server - /// capabilities received - pub async fn init(&self) -> Result { - let transport_ref = self.transport.clone(); - let server_name = self.server_name.clone(); - - // Spawning a task to listen and log stderr output - tokio::spawn(async move { - let mut log_listener = transport_ref.get_log_listener(); - loop { - match log_listener.recv().await { - Ok(msg) => { - tracing::trace!(target: "mcp", "{server_name} logged {}", msg); - }, - Err(e) => { - tracing::error!( - "Error encountered while reading from stderr for {server_name}: {:?}\nEnding stderr listening task.", - e - ); - break; - }, - } - } - }); - - let init_params = Some({ - let client_cap = ClientCapabilities::from(self.client_info.clone()); - serde_json::json!(client_cap) - }); - let init_resp = self.request("initialize", init_params).await?; - if let Err(e) = examine_server_capabilities(&init_resp) { - return Err(ClientError::NegotiationError(format!( - "Client {} has failed to negotiate server capabilities with server: {:?}", - self.server_name, e - ))); - } - let cap = { - let result = init_resp.result.ok_or(ClientError::NegotiationError(format!( - "Server {} init resp is missing result", - self.server_name - )))?; - let cap = result - .get("capabilities") - .ok_or(ClientError::NegotiationError(format!( - "Server {} init resp result is missing capabilities", - self.server_name - )))? - .clone(); - serde_json::from_value::(cap)? - }; - self.notify("initialized", None).await?; - - // TODO: group this into examine_server_capabilities - // Prefetch prompts in the background. We should only do this after the server has been - // initialized - if cap.prompts.is_some() { - self.is_prompts_out_of_date.store(true, Ordering::Relaxed); - let client_ref = (*self).clone(); - let messenger_ref = self.messenger.as_ref().map(|m| m.duplicate()); - tokio::spawn(async move { - fetch_prompts_and_notify_with_messenger(&client_ref, messenger_ref.as_ref()).await; - }); - } - if cap.tools.is_some() { - let client_ref = (*self).clone(); - let messenger_ref = self.messenger.as_ref().map(|m| m.duplicate()); - tokio::spawn(async move { - fetch_tools_and_notify_with_messenger(&client_ref, messenger_ref.as_ref()).await; - }); - } - - let transport_ref = self.transport.clone(); - let server_name = self.server_name.clone(); - let messenger_ref = self.messenger.as_ref().map(|m| m.duplicate()); - let client_ref = (*self).clone(); - - let prompts_list_changed_supported = cap.prompts.as_ref().is_some_and(|p| p.get("listChanged").is_some()); - let tools_list_changed_supported = cap.tools.as_ref().is_some_and(|t| t.get("listChanged").is_some()); - tokio::spawn(async move { - let mut listener = transport_ref.get_listener(); - loop { - match listener.recv().await { - Ok(msg) => { - match msg { - JsonRpcMessage::Request(_req) => {}, - JsonRpcMessage::Notification(notif) => { - let JsonRpcNotification { method, params, .. } = notif; - match method.as_str() { - "notifications/message" | "message" => { - let level = params - .as_ref() - .and_then(|p| p.get("level")) - .and_then(|v| serde_json::to_string(v).ok()); - let data = params - .as_ref() - .and_then(|p| p.get("data")) - .and_then(|v| serde_json::to_string(v).ok()); - if let (Some(level), Some(data)) = (level, data) { - match level.to_lowercase().as_str() { - "error" => { - tracing::error!(target: "mcp", "{}: {}", server_name, data); - }, - "warn" => { - tracing::warn!(target: "mcp", "{}: {}", server_name, data); - }, - "info" => { - tracing::info!(target: "mcp", "{}: {}", server_name, data); - }, - "debug" => { - tracing::debug!(target: "mcp", "{}: {}", server_name, data); - }, - "trace" => { - tracing::trace!(target: "mcp", "{}: {}", server_name, data); - }, - _ => {}, - } - } - }, - "notifications/prompts/list_changed" | "prompts/list_changed" - if prompts_list_changed_supported => - { - // TODO: after we have moved the prompts to the tool - // manager we follow the same workflow as the list changed - // for tools - fetch_prompts_and_notify_with_messenger(&client_ref, messenger_ref.as_ref()) - .await; - client_ref.is_prompts_out_of_date.store(true, Ordering::Release); - }, - "notifications/tools/list_changed" | "tools/list_changed" - if tools_list_changed_supported => - { - fetch_tools_and_notify_with_messenger(&client_ref, messenger_ref.as_ref()) - .await; - }, - _ => {}, - } - }, - JsonRpcMessage::Response(_resp) => { /* noop since direct response is handled inside the request api */ - }, - } - }, - Err(e) => { - tracing::error!("Background listening thread for client {}: {:?}", server_name, e); - // If we don't have anything on the other end, we should just end the task - // now - if let TransportError::RecvError(tokio::sync::broadcast::error::RecvError::Closed) = e { - tracing::error!( - "All senders dropped for transport layer for server {}: {:?}. This likely means the mcp server process is no longer running.", - server_name, - e - ); - break; - } - }, - } - } - }); - - Ok(cap) - } - - /// Sends a request to the server associated. - /// This call will yield until a response is received. - pub async fn request( - &self, - method: &str, - params: Option, - ) -> Result { - let send_map_err = |e: Elapsed| (e, method.to_string()); - let recv_map_err = |e: Elapsed| (e, format!("recv for {method}")); - let mut id = self.get_id(); - let request = JsonRpcRequest { - jsonrpc: JsonRpcVersion::default(), - id, - method: method.to_owned(), - params, - }; - tracing::trace!(target: "mcp", "To {}:\n{:#?}", self.server_name, request); - let msg = JsonRpcMessage::Request(request); - time::timeout(Duration::from_millis(self.timeout), self.transport.send(&msg)) - .await - .map_err(send_map_err)??; - let mut listener = self.transport.get_listener(); - let mut resp = time::timeout(Duration::from_millis(self.timeout), async { - // we want to ignore all other messages sent by the server at this point and let the - // background loop handle them - // We also want to ignore all messages emitted by the server to its stdout that does - // not deserialize into a valid JsonRpcMessage (they are not supposed to do this but - // too many people complained about this so we are adding this safeguard in) - loop { - if let Ok(JsonRpcMessage::Response(resp)) = listener.recv().await { - if resp.id == id { - break Ok::(resp); - } - } - } - }) - .await - .map_err(recv_map_err)??; - // Pagination support: https://spec.modelcontextprotocol.io/specification/2024-11-05/server/utilities/pagination/#pagination-model - let mut next_cursor = resp.result.as_ref().and_then(|v| v.get("nextCursor")); - if next_cursor.is_some() { - let mut current_resp = resp.clone(); - let mut results = Vec::::new(); - let pagination_supported_ops = { - let maybe_pagination_supported_op: Result = method.try_into(); - maybe_pagination_supported_op.ok() - }; - if let Some(ops) = pagination_supported_ops { - loop { - let result = current_resp.result.as_ref().cloned().unwrap(); - let mut list: Vec = match ops { - PaginationSupportedOps::ResourcesList => { - let ResourcesListResult { resources: list, .. } = - serde_json::from_value::(result) - .map_err(ClientError::Serialization)?; - list - }, - PaginationSupportedOps::ResourceTemplatesList => { - let ResourceTemplatesListResult { - resource_templates: list, - .. - } = serde_json::from_value::(result) - .map_err(ClientError::Serialization)?; - list - }, - PaginationSupportedOps::PromptsList => { - let PromptsListResult { prompts: list, .. } = - serde_json::from_value::(result) - .map_err(ClientError::Serialization)?; - list - }, - PaginationSupportedOps::ToolsList => { - let ToolsListResult { tools: list, .. } = serde_json::from_value::(result) - .map_err(ClientError::Serialization)?; - list - }, - }; - results.append(&mut list); - if next_cursor.is_none() { - break; - } - id = self.get_id(); - let next_request = JsonRpcRequest { - jsonrpc: JsonRpcVersion::default(), - id, - method: method.to_owned(), - params: Some(serde_json::json!({ - "cursor": next_cursor, - })), - }; - let msg = JsonRpcMessage::Request(next_request); - time::timeout(Duration::from_millis(self.timeout), self.transport.send(&msg)) - .await - .map_err(send_map_err)??; - let resp = time::timeout(Duration::from_millis(self.timeout), async { - loop { - if let Ok(JsonRpcMessage::Response(resp)) = listener.recv().await { - if resp.id == id { - break Ok::(resp); - } - } - } - }) - .await - .map_err(recv_map_err)??; - current_resp = resp; - next_cursor = current_resp.result.as_ref().and_then(|v| v.get("nextCursor")); - } - resp.result = Some({ - let mut map = serde_json::Map::new(); - map.insert(ops.as_key().to_owned(), serde_json::to_value(results)?); - serde_json::to_value(map)? - }); - } - } - tracing::trace!(target: "mcp", "From {}:\n{:#?}", self.server_name, resp); - Ok(resp) - } - - /// Sends a notification to the server associated. - /// Notifications are requests that expect no responses. - pub async fn notify(&self, method: &str, params: Option) -> Result<(), ClientError> { - let send_map_err = |e: Elapsed| (e, method.to_string()); - let notification = JsonRpcNotification { - jsonrpc: JsonRpcVersion::default(), - method: format!("notifications/{}", method), - params, - }; - let msg = JsonRpcMessage::Notification(notification); - Ok( - time::timeout(Duration::from_millis(self.timeout), self.transport.send(&msg)) - .await - .map_err(send_map_err)??, - ) - } - - fn get_id(&self) -> u64 { - self.current_id.fetch_add(1, Ordering::SeqCst) - } -} - -fn examine_server_capabilities(ser_cap: &JsonRpcResponse) -> Result<(), ClientError> { - // Check the jrpc version. - // Currently we are only proceeding if the versions are EXACTLY the same. - let jrpc_version = ser_cap.jsonrpc.as_u32_vec(); - let client_jrpc_version = JsonRpcVersion::default().as_u32_vec(); - for (sv, cv) in jrpc_version.iter().zip(client_jrpc_version.iter()) { - if sv != cv { - return Err(ClientError::NegotiationError( - "Incompatible jrpc version between server and client".to_owned(), - )); - } - } - Ok(()) -} - -// TODO: after we move prompts to tool manager, use the messenger to notify the listener spawned by -// tool manager to update its own field. Currently this function does not make use of the -// messesnger. -#[allow(clippy::borrowed_box)] -async fn fetch_prompts_and_notify_with_messenger(client: &Client, _messenger: Option<&Box>) -where - T: Transport, -{ - let Ok(resp) = client.request("prompts/list", None).await else { - tracing::error!("Prompt list query failed for {0}", client.server_name); - return; - }; - let Some(result) = resp.result else { - tracing::warn!("Prompt list query returned no result for {0}", client.server_name); - return; - }; - let Some(prompts) = result.get("prompts") else { - tracing::warn!( - "Prompt list query result contained no field named prompts for {0}", - client.server_name - ); - return; - }; - let Ok(prompts) = serde_json::from_value::>(prompts.clone()) else { - tracing::error!("Prompt list query deserialization failed for {0}", client.server_name); - return; - }; - let Ok(mut lock) = client.prompt_gets.write() else { - tracing::error!( - "Failed to obtain write lock for prompt list query for {0}", - client.server_name - ); - return; - }; - lock.clear(); - for prompt in prompts { - let name = prompt.name.clone(); - lock.insert(name, prompt); - } -} - -#[allow(clippy::borrowed_box)] -async fn fetch_tools_and_notify_with_messenger(client: &Client, messenger: Option<&Box>) -where - T: Transport, -{ - // TODO: decouple pagination logic from request and have page fetching logic here - // instead - let tool_list_result = 'tool_list_result: { - let resp = match client.request("tools/list", None).await { - Ok(resp) => resp, - Err(e) => break 'tool_list_result Err(e.into()), - }; - if let Some(error) = resp.error { - let msg = format!("Failed to retrieve tool list for {}: {:?}", client.server_name, error); - break 'tool_list_result Err(eyre::eyre!(msg)); - } - let Some(result) = resp.result else { - let msg = format!("Tool list response from {} is missing result", client.server_name); - break 'tool_list_result Err(eyre::eyre!(msg)); - }; - let tool_list_result = match serde_json::from_value::(result) { - Ok(result) => result, - Err(e) => { - let msg = format!("Failed to deserialize tool result from {}: {:?}", client.server_name, e); - break 'tool_list_result Err(eyre::eyre!(msg)); - }, - }; - Ok::(tool_list_result) - }; - if let Some(messenger) = messenger { - let _ = messenger - .send_tools_list_result(tool_list_result) - .await - .map_err(|e| tracing::error!("Failed to send tool result through messenger {:?}", e)); - } -} - -#[cfg(test)] -mod tests { - use std::path::PathBuf; - - use serde_json::Value; - - use super::*; - const TEST_BIN_OUT_DIR: &str = "target/debug"; - const TEST_SERVER_NAME: &str = "test_mcp_server"; - - fn get_workspace_root() -> PathBuf { - let output = std::process::Command::new("cargo") - .args(["metadata", "--format-version=1", "--no-deps"]) - .output() - .expect("Failed to execute cargo metadata"); - - let metadata: serde_json::Value = - serde_json::from_slice(&output.stdout).expect("Failed to parse cargo metadata"); - - let workspace_root = metadata["workspace_root"] - .as_str() - .expect("Failed to find workspace_root in metadata"); - - PathBuf::from(workspace_root) - } - - #[tokio::test(flavor = "multi_thread")] - // For some reason this test is quite flakey when ran in the CI but not on developer's - // machines. As a result it is hard to debug, hence we are ignoring it for now. - #[ignore] - async fn test_client_stdio() { - std::process::Command::new("cargo") - .args(["build", "--bin", TEST_SERVER_NAME]) - .status() - .expect("Failed to build binary"); - let workspace_root = get_workspace_root(); - let bin_path = workspace_root.join(TEST_BIN_OUT_DIR).join(TEST_SERVER_NAME); - println!("bin path: {}", bin_path.to_str().unwrap_or("no path found")); - - // Testing 2 concurrent sessions to make sure transport layer does not overlap. - let client_info_one = serde_json::json!({ - "name": "TestClientOne", - "version": "1.0.0" - }); - let client_config_one = ClientConfig { - server_name: "test_tool".to_owned(), - bin_path: bin_path.to_str().unwrap().to_string(), - args: ["1".to_owned()].to_vec(), - timeout: 120 * 1000, - client_info: client_info_one.clone(), - env: { - let mut map = HashMap::::new(); - map.insert("ENV_ONE".to_owned(), "1".to_owned()); - map.insert("ENV_TWO".to_owned(), "2".to_owned()); - Some(map) - }, - }; - let client_info_two = serde_json::json!({ - "name": "TestClientTwo", - "version": "1.0.0" - }); - let client_config_two = ClientConfig { - server_name: "test_tool".to_owned(), - bin_path: bin_path.to_str().unwrap().to_string(), - args: ["2".to_owned()].to_vec(), - timeout: 120 * 1000, - client_info: client_info_two.clone(), - env: { - let mut map = HashMap::::new(); - map.insert("ENV_ONE".to_owned(), "1".to_owned()); - map.insert("ENV_TWO".to_owned(), "2".to_owned()); - Some(map) - }, - }; - let mut client_one = Client::::from_config(client_config_one).expect("Failed to create client"); - let mut client_two = Client::::from_config(client_config_two).expect("Failed to create client"); - let client_one_cap = ClientCapabilities::from(client_info_one); - let client_two_cap = ClientCapabilities::from(client_info_two); - - let (res_one, res_two) = tokio::join!( - time::timeout( - time::Duration::from_secs(10), - test_client_routine(&mut client_one, serde_json::json!(client_one_cap)) - ), - time::timeout( - time::Duration::from_secs(10), - test_client_routine(&mut client_two, serde_json::json!(client_two_cap)) - ) - ); - let res_one = res_one.expect("Client one timed out"); - let res_two = res_two.expect("Client two timed out"); - assert!(res_one.is_ok()); - assert!(res_two.is_ok()); - } - - #[allow(clippy::await_holding_lock)] - async fn test_client_routine( - client: &mut Client, - cap_sent: serde_json::Value, - ) -> Result<(), Box> { - // Test init - let _ = client.init().await.expect("Client init failed"); - tokio::time::sleep(time::Duration::from_millis(1500)).await; - let client_capabilities_sent = client - .request("verify_init_ack_sent", None) - .await - .expect("Verify init ack mock request failed"); - let has_server_recvd_init_ack = client_capabilities_sent - .result - .expect("Failed to retrieve client capabilities sent."); - assert_eq!(has_server_recvd_init_ack.to_string(), "true"); - let cap_recvd = client - .request("verify_init_params_sent", None) - .await - .expect("Verify init params mock request failed"); - let cap_recvd = cap_recvd - .result - .expect("Verify init params mock request does not contain required field (result)"); - assert!(are_json_values_equal(&cap_sent, &cap_recvd)); - - // test list tools - let fake_tool_names = ["get_weather_one", "get_weather_two", "get_weather_three"]; - let mock_result_spec = fake_tool_names.map(create_fake_tool_spec); - let mock_tool_specs_for_verify = serde_json::json!(mock_result_spec.clone()); - let mock_tool_specs_prep_param = mock_result_spec - .iter() - .zip(fake_tool_names.iter()) - .map(|(v, n)| { - serde_json::json!({ - "key": (*n).to_string(), - "value": v - }) - }) - .collect::>(); - let mock_tool_specs_prep_param = - serde_json::to_value(mock_tool_specs_prep_param).expect("Failed to create mock tool specs prep param"); - let _ = client - .request("store_mock_tool_spec", Some(mock_tool_specs_prep_param)) - .await - .expect("Mock tool spec prep failed"); - let tool_spec_recvd = client.request("tools/list", None).await.expect("List tools failed"); - assert!(are_json_values_equal( - tool_spec_recvd - .result - .as_ref() - .and_then(|v| v.get("tools")) - .expect("Failed to retrieve tool specs from result received"), - &mock_tool_specs_for_verify - )); - - // Test list prompts directly - let fake_prompt_names = ["code_review_one", "code_review_two", "code_review_three"]; - let mock_result_prompts = fake_prompt_names.map(create_fake_prompts); - let mock_prompts_for_verify = serde_json::json!(mock_result_prompts.clone()); - let mock_prompts_prep_param = mock_result_prompts - .iter() - .zip(fake_prompt_names.iter()) - .map(|(v, n)| { - serde_json::json!({ - "key": (*n).to_string(), - "value": v - }) - }) - .collect::>(); - let mock_prompts_prep_param = - serde_json::to_value(mock_prompts_prep_param).expect("Failed to create mock prompts prep param"); - let _ = client - .request("store_mock_prompts", Some(mock_prompts_prep_param)) - .await - .expect("Mock prompt prep failed"); - let prompts_recvd = client.request("prompts/list", None).await.expect("List prompts failed"); - client.is_prompts_out_of_date.store(false, Ordering::Release); - assert!(are_json_values_equal( - prompts_recvd - .result - .as_ref() - .and_then(|v| v.get("prompts")) - .expect("Failed to retrieve prompts from results received"), - &mock_prompts_for_verify - )); - - // Test prompts list changed - let fake_prompt_names = ["code_review_four", "code_review_five", "code_review_six"]; - let mock_result_prompts = fake_prompt_names.map(create_fake_prompts); - let mock_prompts_prep_param = mock_result_prompts - .iter() - .zip(fake_prompt_names.iter()) - .map(|(v, n)| { - serde_json::json!({ - "key": (*n).to_string(), - "value": v - }) - }) - .collect::>(); - let mock_prompts_prep_param = - serde_json::to_value(mock_prompts_prep_param).expect("Failed to create mock prompts prep param"); - let _ = client - .request("store_mock_prompts", Some(mock_prompts_prep_param)) - .await - .expect("Mock new prompt request failed"); - // After we send the signal for the server to clear prompts, we should be receiving signal - // to fetch for new prompts, after which we should be getting no prompts. - let is_prompts_out_of_date = client.is_prompts_out_of_date.clone(); - let wait_for_new_prompts = async move { - while !is_prompts_out_of_date.load(Ordering::Acquire) { - tokio::time::sleep(time::Duration::from_millis(100)).await; - } - }; - time::timeout(time::Duration::from_secs(5), wait_for_new_prompts) - .await - .expect("Timed out while waiting for new prompts"); - let new_prompts = client.prompt_gets.read().expect("Failed to read new prompts"); - for k in new_prompts.keys() { - assert!(fake_prompt_names.contains(&k.as_str())); - } - - // Test env var inclusion - let env_vars = client.request("get_env_vars", None).await.expect("Get env vars failed"); - let env_one = env_vars - .result - .as_ref() - .expect("Failed to retrieve results from env var request") - .get("ENV_ONE") - .expect("Failed to retrieve env one from env var request"); - let env_two = env_vars - .result - .as_ref() - .expect("Failed to retrieve results from env var request") - .get("ENV_TWO") - .expect("Failed to retrieve env two from env var request"); - let env_one_as_str = serde_json::to_string(env_one).expect("Failed to convert env one to string"); - let env_two_as_str = serde_json::to_string(env_two).expect("Failed to convert env two to string"); - assert_eq!(env_one_as_str, "\"1\"".to_string()); - assert_eq!(env_two_as_str, "\"2\"".to_string()); - - Ok(()) - } - - fn are_json_values_equal(a: &Value, b: &Value) -> bool { - match (a, b) { - (Value::Null, Value::Null) => true, - (Value::Bool(a_val), Value::Bool(b_val)) => a_val == b_val, - (Value::Number(a_val), Value::Number(b_val)) => a_val == b_val, - (Value::String(a_val), Value::String(b_val)) => a_val == b_val, - (Value::Array(a_arr), Value::Array(b_arr)) => { - if a_arr.len() != b_arr.len() { - return false; - } - a_arr - .iter() - .zip(b_arr.iter()) - .all(|(a_item, b_item)| are_json_values_equal(a_item, b_item)) - }, - (Value::Object(a_obj), Value::Object(b_obj)) => { - if a_obj.len() != b_obj.len() { - return false; - } - a_obj.iter().all(|(key, a_value)| match b_obj.get(key) { - Some(b_value) => are_json_values_equal(a_value, b_value), - None => false, - }) - }, - _ => false, - } - } - - fn create_fake_tool_spec(name: &str) -> serde_json::Value { - serde_json::json!({ - "name": name, - "description": "Get current weather information for a location", - "inputSchema": { - "type": "object", - "properties": { - "location": { - "type": "string", - "description": "City name or zip code" - } - }, - "required": ["location"] - } - }) - } - - fn create_fake_prompts(name: &str) -> serde_json::Value { - serde_json::json!({ - "name": name, - "description": "Asks the LLM to analyze code quality and suggest improvements", - "arguments": [ - { - "name": "code", - "description": "The code to review", - "required": true - } - ] - }) - } - - #[cfg(windows)] - mod windows_command_tests { - use super::*; - use crate::mcp_client::transport::stdio::JsonRpcStdioTransport as StdioTransport; - - #[test] - fn test_quote_windows_arg_no_special_chars() { - let result = Client::::quote_windows_arg("simple"); - assert_eq!(result, "simple"); - } - - #[test] - fn test_quote_windows_arg_with_spaces() { - let result = Client::::quote_windows_arg("with spaces"); - assert_eq!(result, "\"with spaces\""); - } - - #[test] - fn test_quote_windows_arg_with_quotes() { - let result = Client::::quote_windows_arg("with \"quotes\""); - assert_eq!(result, "\"with \\\"quotes\\\"\""); - } - - #[test] - fn test_quote_windows_arg_with_backslashes() { - let result = Client::::quote_windows_arg("path\\to\\file"); - assert_eq!(result, "path\\to\\file"); - } - - #[test] - fn test_quote_windows_arg_with_trailing_backslashes() { - let result = Client::::quote_windows_arg("path\\to\\dir\\"); - assert_eq!(result, "path\\to\\dir\\"); - } - - #[test] - fn test_quote_windows_arg_with_backslashes_before_quote() { - let result = Client::::quote_windows_arg("path\\\\\"quoted\""); - assert_eq!(result, "\"path\\\\\\\\\\\"quoted\\\"\""); - } - - #[test] - fn test_quote_windows_arg_complex_case() { - let result = Client::::quote_windows_arg("C:\\Program Files\\My App\\bin\\app.exe"); - assert_eq!(result, "\"C:\\Program Files\\My App\\bin\\app.exe\""); - } - - #[test] - fn test_quote_windows_arg_with_tabs_and_newlines() { - let result = Client::::quote_windows_arg("with\ttabs\nand\rnewlines"); - assert_eq!(result, "\"with\ttabs\nand\rnewlines\""); - } - - #[test] - fn test_quote_windows_arg_edge_case_only_backslashes() { - let result = Client::::quote_windows_arg("\\\\\\"); - assert_eq!(result, "\\\\\\"); - } - - #[test] - fn test_quote_windows_arg_edge_case_only_quotes() { - let result = Client::::quote_windows_arg("\"\"\""); - assert_eq!(result, "\"\\\"\\\"\\\"\""); - } - - // Tests for build_windows_command function - #[test] - fn test_build_windows_command_empty_args() { - let bin_path = "myapp"; - let args = vec![]; - let result = Client::::build_windows_command(bin_path, args); - assert_eq!(result, "myapp"); - } - - #[test] - fn test_build_windows_command_uvx_example() { - let bin_path = "uvx"; - let args = vec!["mcp-server-fetch".to_string()]; - let result = Client::::build_windows_command(bin_path, args); - assert_eq!(result, "uvx mcp-server-fetch"); - } - - #[test] - fn test_build_windows_command_npx_example() { - let bin_path = "npx"; - let args = vec!["-y".to_string(), "@modelcontextprotocol/server-memory".to_string()]; - let result = Client::::build_windows_command(bin_path, args); - assert_eq!(result, "npx -y @modelcontextprotocol/server-memory"); - } - - #[test] - fn test_build_windows_command_docker_example() { - let bin_path = "docker"; - let args = vec![ - "run".to_string(), - "-i".to_string(), - "--rm".to_string(), - "-e".to_string(), - "GITHUB_PERSONAL_ACCESS_TOKEN".to_string(), - "ghcr.io/github/github-mcp-server".to_string(), - ]; - let result = Client::::build_windows_command(bin_path, args); - assert_eq!( - result, - "docker run -i --rm -e GITHUB_PERSONAL_ACCESS_TOKEN ghcr.io/github/github-mcp-server" - ); - } - - #[test] - fn test_build_windows_command_with_quotes_in_args() { - let bin_path = "myapp"; - let args = vec!["--config".to_string(), "{\"key\": \"value\"}".to_string()]; - let result = Client::::build_windows_command(bin_path, args); - assert_eq!(result, "myapp --config \"{\\\"key\\\": \\\"value\\\"}\""); - } - - #[test] - fn test_build_windows_command_with_spaces_in_path() { - let bin_path = "C:\\Program Files\\My App\\bin\\app.exe"; - let args = vec!["--input".to_string(), "file with spaces.txt".to_string()]; - let result = Client::::build_windows_command(bin_path, args); - assert_eq!( - result, - "\"C:\\Program Files\\My App\\bin\\app.exe\" --input \"file with spaces.txt\"" - ); - } - - #[test] - fn test_build_windows_command_complex_args() { - let bin_path = "myapp"; - let args = vec![ - "--config".to_string(), - "C:\\Users\\test\\config.json".to_string(), - "--output".to_string(), - "C:\\Output\\result file.txt".to_string(), - "--verbose".to_string(), - ]; - let result = Client::::build_windows_command(bin_path, args); - assert_eq!( - result, - "myapp --config C:\\Users\\test\\config.json --output \"C:\\Output\\result file.txt\" --verbose" - ); - } - - #[test] - fn test_build_windows_command_with_environment_variables() { - let bin_path = "cmd"; - let args = vec!["/c".to_string(), "echo %PATH%".to_string()]; - let result = Client::::build_windows_command(bin_path, args); - assert_eq!(result, "cmd /c \"echo %PATH%\""); - } - - #[test] - fn test_build_windows_command_real_world_python() { - let bin_path = "python"; - let args = vec![ - "-m".to_string(), - "mcp_server".to_string(), - "--config".to_string(), - "C:\\configs\\server.json".to_string(), - ]; - let result = Client::::build_windows_command(bin_path, args); - assert_eq!(result, "python -m mcp_server --config C:\\configs\\server.json"); - } - } -} diff --git a/crates/chat-cli/src/mcp_client/error.rs b/crates/chat-cli/src/mcp_client/error.rs deleted file mode 100644 index 01f77cfa8..000000000 --- a/crates/chat-cli/src/mcp_client/error.rs +++ /dev/null @@ -1,66 +0,0 @@ -/// Error codes as defined in the MCP protocol. -/// -/// These error codes are based on the JSON-RPC 2.0 specification with additional -/// MCP-specific error codes in the -32000 to -32099 range. -#[derive(Debug, Clone, Copy, PartialEq, Eq)] -#[repr(i32)] -pub enum ErrorCode { - /// Invalid JSON was received by the server. - /// An error occurred on the server while parsing the JSON text. - ParseError = -32700, - - /// The JSON sent is not a valid Request object. - InvalidRequest = -32600, - - /// The method does not exist / is not available. - MethodNotFound = -32601, - - /// Invalid method parameter(s). - InvalidParams = -32602, - - /// Internal JSON-RPC error. - InternalError = -32603, - - /// Server has not been initialized. - /// This error is returned when a request is made before the server - /// has been properly initialized. - ServerNotInitialized = -32002, - - /// Unknown error code. - /// This error is returned when an error code is received that is not - /// recognized by the implementation. - Unknown = -32001, - - /// Request failed. - /// This error is returned when a request fails for a reason not covered - /// by other error codes. - RequestFailed = -32000, -} - -impl From for ErrorCode { - fn from(code: i32) -> Self { - match code { - -32700 => ErrorCode::ParseError, - -32600 => ErrorCode::InvalidRequest, - -32601 => ErrorCode::MethodNotFound, - -32602 => ErrorCode::InvalidParams, - -32603 => ErrorCode::InternalError, - -32002 => ErrorCode::ServerNotInitialized, - -32001 => ErrorCode::Unknown, - -32000 => ErrorCode::RequestFailed, - _ => ErrorCode::Unknown, - } - } -} - -impl From for i32 { - fn from(code: ErrorCode) -> Self { - code as i32 - } -} - -impl std::fmt::Display for ErrorCode { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "{:?}", self) - } -} diff --git a/crates/chat-cli/src/mcp_client/facilitator_types.rs b/crates/chat-cli/src/mcp_client/facilitator_types.rs deleted file mode 100644 index 87fbd79b2..000000000 --- a/crates/chat-cli/src/mcp_client/facilitator_types.rs +++ /dev/null @@ -1,248 +0,0 @@ -use serde::{ - Deserialize, - Serialize, -}; -use thiserror::Error; - -/// https://spec.modelcontextprotocol.io/specification/2024-11-05/server/utilities/pagination/#operations-supporting-pagination -#[allow(clippy::enum_variant_names)] -#[derive(Debug, Clone, PartialEq, Eq)] -pub enum PaginationSupportedOps { - ResourcesList, - ResourceTemplatesList, - PromptsList, - ToolsList, -} - -impl PaginationSupportedOps { - pub fn as_key(&self) -> &str { - match self { - PaginationSupportedOps::ResourcesList => "resources", - PaginationSupportedOps::ResourceTemplatesList => "resourceTemplates", - PaginationSupportedOps::PromptsList => "prompts", - PaginationSupportedOps::ToolsList => "tools", - } - } -} - -impl TryFrom<&str> for PaginationSupportedOps { - type Error = OpsConversionError; - - fn try_from(value: &str) -> Result { - match value { - "resources/list" => Ok(PaginationSupportedOps::ResourcesList), - "resources/templates/list" => Ok(PaginationSupportedOps::ResourceTemplatesList), - "prompts/list" => Ok(PaginationSupportedOps::PromptsList), - "tools/list" => Ok(PaginationSupportedOps::ToolsList), - _ => Err(OpsConversionError::InvalidMethod), - } - } -} - -#[derive(Error, Debug)] -pub enum OpsConversionError { - #[error("Invalid method encountered")] - InvalidMethod, -} - -#[derive(Debug, Clone, Serialize, Deserialize, Eq, PartialEq)] -#[serde(rename_all = "camelCase")] -/// Role assumed for a particular message -pub enum Role { - User, - Assistant, -} - -impl std::fmt::Display for Role { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { - Role::User => write!(f, "user"), - Role::Assistant => write!(f, "assistant"), - } - } -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -#[serde(rename_all = "camelCase")] -/// Result of listing resources operation -pub struct ResourcesListResult { - /// List of resources - pub resources: Vec, - /// Optional cursor for pagination - #[serde(skip_serializing_if = "Option::is_none")] - pub next_cursor: Option, -} - -/// Result of listing resource templates operation -#[derive(Debug, Clone, Serialize, Deserialize)] -#[serde(rename_all = "camelCase")] -pub struct ResourceTemplatesListResult { - /// List of resource templates - pub resource_templates: Vec, - /// Optional cursor for pagination - #[serde(skip_serializing_if = "Option::is_none")] - pub next_cursor: Option, -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -#[serde(rename_all = "camelCase")] -/// Result of prompt listing query -pub struct PromptsListResult { - /// List of prompts - pub prompts: Vec, - /// Optional cursor for pagination - #[serde(skip_serializing_if = "Option::is_none")] - pub next_cursor: Option, -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -#[serde(rename_all = "camelCase")] -/// Represents an argument to be supplied to a [PromptGet] -pub struct PromptGetArg { - /// The name identifier of the prompt - pub name: String, - /// Optional description providing context about the prompt - #[serde(skip_serializing_if = "Option::is_none")] - pub description: Option, - /// Indicates whether a response to this prompt is required - /// If not specified, defaults to false - #[serde(skip_serializing_if = "Option::is_none")] - pub required: Option, -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -#[serde(rename_all = "camelCase")] -/// Represents a request to get a prompt from a mcp server -pub struct PromptGet { - /// Unique identifier for the prompt - pub name: String, - /// Optional description providing context about the prompt's purpose - #[serde(skip_serializing_if = "Option::is_none")] - pub description: Option, - /// Optional list of arguments that define the structure of information to be collected - #[serde(skip_serializing_if = "Option::is_none")] - pub arguments: Option>, -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -#[serde(rename_all = "camelCase")] -/// `result` field in [JsonRpcResponse] from a `prompts/get` request -pub struct PromptGetResult { - #[serde(skip_serializing_if = "Option::is_none")] - pub description: Option, - pub messages: Vec, -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -#[serde(rename_all = "camelCase")] -/// Completed prompt from `prompts/get` to be returned by a mcp server -pub struct Prompt { - pub role: Role, - pub content: MessageContent, -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -#[serde(rename_all = "camelCase")] -/// Result of listing tools operation -pub struct ToolsListResult { - /// List of tools - pub tools: Vec, - /// Optional cursor for pagination - #[serde(skip_serializing_if = "Option::is_none")] - pub next_cursor: Option, -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -#[serde(rename_all = "camelCase")] -pub struct ToolCallResult { - pub content: Vec, - #[serde(skip_serializing_if = "Option::is_none")] - pub is_error: Option, -} - -/// Content of a message -#[derive(Debug, Clone, Serialize, Deserialize)] -#[serde(tag = "type", rename_all = "camelCase")] -pub enum MessageContent { - /// Text content - Text { - /// The text content - text: String, - }, - /// Image content - #[serde(rename_all = "camelCase")] - Image { - /// base64-encoded-data - data: String, - mime_type: String, - }, - /// Resource content - Resource { - /// The resource - resource: Resource, - }, -} - -impl From for String { - fn from(val: MessageContent) -> Self { - match val { - MessageContent::Text { text } => text, - MessageContent::Image { data, mime_type } => serde_json::json!({ - "data": data, - "mime_type": mime_type - }) - .to_string(), - MessageContent::Resource { resource } => serde_json::json!(resource).to_string(), - } - } -} - -impl std::fmt::Display for MessageContent { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { - MessageContent::Text { text } => write!(f, "{}", text), - MessageContent::Image { data: _, mime_type } => write!(f, "Image [base64-encoded-string] ({})", mime_type), - MessageContent::Resource { resource } => write!(f, "Resource: {} ({})", resource.title, resource.uri), - } - } -} - -/// Resource contents -#[derive(Debug, Clone, Serialize, Deserialize)] -#[serde(tag = "type", rename_all = "camelCase")] -pub enum ResourceContents { - Text { text: String }, - Blob { data: Vec }, -} - -/// A resource in the system -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct Resource { - /// Unique identifier for the resource - pub uri: String, - /// Human-readable title - pub title: String, - /// Optional description - #[serde(skip_serializing_if = "Option::is_none")] - pub description: Option, - /// Resource contents - pub contents: ResourceContents, -} - -/// Represents the capabilities supported by a Model Context Protocol server -/// This is the "capabilities" field in the result of a response for init -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct ServerCapabilities { - /// Configuration for server logging capabilities - #[serde(skip_serializing_if = "Option::is_none")] - pub logging: Option, - /// Configuration for prompt-related capabilities - #[serde(skip_serializing_if = "Option::is_none")] - pub prompts: Option, - /// Configuration for resource management capabilities - #[serde(skip_serializing_if = "Option::is_none")] - pub resources: Option, - /// Configuration for tool integration capabilities - #[serde(skip_serializing_if = "Option::is_none")] - pub tools: Option, -} diff --git a/crates/chat-cli/src/mcp_client/messenger.rs b/crates/chat-cli/src/mcp_client/messenger.rs deleted file mode 100644 index 14f79e518..000000000 --- a/crates/chat-cli/src/mcp_client/messenger.rs +++ /dev/null @@ -1,85 +0,0 @@ -use thiserror::Error; - -use super::{ - PromptsListResult, - ResourceTemplatesListResult, - ResourcesListResult, - ToolsListResult, -}; - -/// An interface that abstracts the implementation for information delivery from client and its -/// consumer. It is through this interface secondary information (i.e. information that are needed -/// to make requests to mcp servers) are obtained passively. Consumers of client can of course -/// choose to "actively" retrieve these information via explicitly making these requests. -#[allow(dead_code)] -#[async_trait::async_trait] -pub trait Messenger: std::fmt::Debug + Send + Sync + 'static { - /// Sends the result of a tools list operation to the consumer - /// This function is used to deliver information about available tools - async fn send_tools_list_result(&self, result: eyre::Result) -> Result<(), MessengerError>; - - /// Sends the result of a prompts list operation to the consumer - /// This function is used to deliver information about available prompts - async fn send_prompts_list_result(&self, result: eyre::Result) -> Result<(), MessengerError>; - - /// Sends the result of a resources list operation to the consumer - /// This function is used to deliver information about available resources - async fn send_resources_list_result(&self, result: eyre::Result) - -> Result<(), MessengerError>; - - /// Sends the result of a resource templates list operation to the consumer - /// This function is used to deliver information about available resource templates - async fn send_resource_templates_list_result( - &self, - result: eyre::Result, - ) -> Result<(), MessengerError>; - - /// Signals to the orchestrator that a server has started initializing - async fn send_init_msg(&self) -> Result<(), MessengerError>; - - /// Creates a duplicate of the messenger object - /// This function is used to create a new instance of the messenger with the same configuration - fn duplicate(&self) -> Box; -} - -#[derive(Clone, Debug, Error)] -pub enum MessengerError { - #[error("{0}")] - Custom(String), -} - -#[derive(Clone, Debug)] -pub struct NullMessenger; - -#[async_trait::async_trait] -impl Messenger for NullMessenger { - async fn send_tools_list_result(&self, _result: eyre::Result) -> Result<(), MessengerError> { - Ok(()) - } - - async fn send_prompts_list_result(&self, _result: eyre::Result) -> Result<(), MessengerError> { - Ok(()) - } - - async fn send_resources_list_result( - &self, - _result: eyre::Result, - ) -> Result<(), MessengerError> { - Ok(()) - } - - async fn send_resource_templates_list_result( - &self, - _result: eyre::Result, - ) -> Result<(), MessengerError> { - Ok(()) - } - - async fn send_init_msg(&self) -> Result<(), MessengerError> { - Ok(()) - } - - fn duplicate(&self) -> Box { - Box::new(NullMessenger) - } -} diff --git a/crates/chat-cli/src/mcp_client/mod.rs b/crates/chat-cli/src/mcp_client/mod.rs deleted file mode 100644 index 51f8b178f..000000000 --- a/crates/chat-cli/src/mcp_client/mod.rs +++ /dev/null @@ -1,13 +0,0 @@ -pub mod client; -pub mod error; -pub mod facilitator_types; -pub mod messenger; -pub mod server; -pub mod transport; - -pub use client::*; -pub use facilitator_types::*; -pub use messenger::*; -#[allow(unused_imports)] -pub use server::*; -pub use transport::*; diff --git a/crates/chat-cli/src/mcp_client/server.rs b/crates/chat-cli/src/mcp_client/server.rs deleted file mode 100644 index 7b320a2c6..000000000 --- a/crates/chat-cli/src/mcp_client/server.rs +++ /dev/null @@ -1,311 +0,0 @@ -#![allow(dead_code)] -use std::collections::HashMap; -use std::sync::atomic::{ - AtomicBool, - AtomicU64, - Ordering, -}; -use std::sync::{ - Arc, - Mutex, -}; - -use tokio::io::{ - Stdin, - Stdout, -}; -use tokio::task::JoinHandle; - -use super::Listener as _; -use super::client::StdioTransport; -use super::error::ErrorCode; -use super::transport::base_protocol::{ - JsonRpcError, - JsonRpcMessage, - JsonRpcNotification, - JsonRpcRequest, - JsonRpcResponse, -}; -use super::transport::stdio::JsonRpcStdioTransport; -use super::transport::{ - JsonRpcVersion, - Transport, - TransportError, -}; - -pub type Request = serde_json::Value; -pub type Response = Option; -pub type InitializedServer = JoinHandle>; - -pub trait PreServerRequestHandler { - fn register_pending_request_callback(&mut self, cb: impl Fn(u64) -> Option + Send + Sync + 'static); - fn register_send_request_callback( - &mut self, - cb: impl Fn(&str, Option) -> Result<(), ServerError> + Send + Sync + 'static, - ); -} - -#[async_trait::async_trait] -pub trait ServerRequestHandler: PreServerRequestHandler + Send + Sync + 'static { - async fn handle_initialize(&self, params: Option) -> Result; - async fn handle_incoming(&self, method: &str, params: Option) -> Result; - async fn handle_response(&self, resp: JsonRpcResponse) -> Result<(), ServerError>; - async fn handle_shutdown(&self) -> Result<(), ServerError>; -} - -pub struct Server { - transport: Option>, - handler: Option, - #[allow(dead_code)] - pending_requests: Arc>>, - #[allow(dead_code)] - current_id: Arc, -} - -#[derive(Debug, thiserror::Error)] -pub enum ServerError { - #[error(transparent)] - TransportError(#[from] TransportError), - #[error(transparent)] - Io(#[from] std::io::Error), - #[error(transparent)] - Serialization(#[from] serde_json::Error), - #[error("Unexpected msg type encountered")] - UnexpectedMsgType, - #[error("{0}")] - NegotiationError(String), - #[error(transparent)] - TokioJoinError(#[from] tokio::task::JoinError), - #[error("Failed to obtain mutex lock")] - MutexError, - #[error("Failed to obtain request method")] - MissingMethod, - #[error("Failed to obtain request id")] - MissingId, - #[error("Failed to initialize server. Missing transport")] - MissingTransport, - #[error("Failed to initialize server. Missing handler")] - MissingHandler, -} - -impl Server -where - H: ServerRequestHandler, -{ - pub fn new(mut handler: H, stdin: Stdin, stdout: Stdout) -> Result { - let transport = Arc::new(JsonRpcStdioTransport::server(stdin, stdout)?); - let pending_requests = Arc::new(Mutex::new(HashMap::::new())); - let pending_requests_clone_one = pending_requests.clone(); - let current_id = Arc::new(AtomicU64::new(0)); - let pending_request_getter = move |id: u64| -> Option { - match pending_requests_clone_one.lock() { - Ok(mut p) => p.remove(&id), - Err(_) => None, - } - }; - handler.register_pending_request_callback(pending_request_getter); - let transport_clone = transport.clone(); - let pending_request_clone_two = pending_requests.clone(); - let current_id_clone = current_id.clone(); - let request_sender = move |method: &str, params: Option| -> Result<(), ServerError> { - let id = current_id_clone.fetch_add(1, Ordering::SeqCst); - let msg = match method.split_once("/") { - Some(("request", _)) => { - let request = JsonRpcRequest { - jsonrpc: JsonRpcVersion::default(), - id, - method: method.to_owned(), - params, - }; - let msg = JsonRpcMessage::Request(request.clone()); - #[allow(clippy::map_err_ignore)] - let mut pending_request = pending_request_clone_two.lock().map_err(|_| ServerError::MutexError)?; - pending_request.insert(id, request); - Some(msg) - }, - Some(("notifications", _)) => { - let notif = JsonRpcNotification { - jsonrpc: JsonRpcVersion::default(), - method: method.to_owned(), - params, - }; - let msg = JsonRpcMessage::Notification(notif); - Some(msg) - }, - _ => None, - }; - if let Some(msg) = msg { - let transport = transport_clone.clone(); - tokio::task::spawn(async move { - let _ = transport.send(&msg).await; - }); - } - Ok(()) - }; - handler.register_send_request_callback(request_sender); - let server = Self { - transport: Some(transport), - handler: Some(handler), - pending_requests, - current_id, - }; - Ok(server) - } -} - -impl Server -where - T: Transport, - H: ServerRequestHandler, -{ - pub fn init(mut self) -> Result { - let transport = self.transport.take().ok_or(ServerError::MissingTransport)?; - let handler = Arc::new(self.handler.take().ok_or(ServerError::MissingHandler)?); - let has_initialized = Arc::new(AtomicBool::new(false)); - let listener = tokio::spawn(async move { - let mut listener = transport.get_listener(); - loop { - let request = listener.recv().await; - let transport_clone = transport.clone(); - let has_init_clone = has_initialized.clone(); - let handler_clone = handler.clone(); - tokio::task::spawn(async move { - process_request(has_init_clone, transport_clone, handler_clone, request).await; - }); - } - }); - Ok(listener) - } -} - -async fn process_request( - has_initialized: Arc, - transport: Arc, - handler: Arc, - request: Result, -) where - T: Transport, - H: ServerRequestHandler, -{ - match request { - Ok(msg) if msg.is_initialize() => { - let id = msg.id().unwrap_or_default(); - if has_initialized.load(Ordering::SeqCst) { - let resp = JsonRpcMessage::Response(JsonRpcResponse { - jsonrpc: JsonRpcVersion::default(), - id, - error: Some(JsonRpcError { - code: ErrorCode::InvalidRequest.into(), - message: "Server has already been initialized".to_owned(), - data: None, - }), - ..Default::default() - }); - let _ = transport.send(&resp).await; - return; - } - let JsonRpcMessage::Request(req) = msg else { - let resp = JsonRpcMessage::Response(JsonRpcResponse { - jsonrpc: JsonRpcVersion::default(), - id, - error: Some(JsonRpcError { - code: ErrorCode::InvalidRequest.into(), - message: "Invalid method for initialization (use request)".to_owned(), - data: None, - }), - ..Default::default() - }); - let _ = transport.send(&resp).await; - return; - }; - let JsonRpcRequest { params, .. } = req; - match handler.handle_initialize(params).await { - Ok(result) => { - let resp = JsonRpcMessage::Response(JsonRpcResponse { - id, - result, - ..Default::default() - }); - let _ = transport.send(&resp).await; - has_initialized.store(true, Ordering::SeqCst); - }, - Err(_e) => { - let resp = JsonRpcMessage::Response(JsonRpcResponse { - jsonrpc: JsonRpcVersion::default(), - id, - error: Some(JsonRpcError { - code: ErrorCode::InternalError.into(), - message: "Error producing initialization response".to_owned(), - data: None, - }), - ..Default::default() - }); - let _ = transport.send(&resp).await; - }, - } - }, - Ok(msg) if msg.is_shutdown() => { - // TODO: add shutdown routine - }, - Ok(msg) if has_initialized.load(Ordering::SeqCst) => match msg { - JsonRpcMessage::Request(req) => { - let JsonRpcRequest { - id, - jsonrpc, - params, - ref method, - } = req; - let resp = handler.handle_incoming(method, params).await.map_or_else( - |error| { - let err = JsonRpcError { - code: ErrorCode::InternalError.into(), - message: error.to_string(), - data: None, - }; - let resp = JsonRpcResponse { - jsonrpc: jsonrpc.clone(), - id, - result: None, - error: Some(err), - }; - JsonRpcMessage::Response(resp) - }, - |result| { - let resp = JsonRpcResponse { - jsonrpc: jsonrpc.clone(), - id, - result, - error: None, - }; - JsonRpcMessage::Response(resp) - }, - ); - let _ = transport.send(&resp).await; - }, - JsonRpcMessage::Notification(notif) => { - let JsonRpcNotification { ref method, params, .. } = notif; - let _ = handler.handle_incoming(method, params).await; - }, - JsonRpcMessage::Response(resp) => { - let _ = handler.handle_response(resp).await; - }, - }, - Ok(msg) => { - let id = msg.id().unwrap_or_default(); - let resp = JsonRpcMessage::Response(JsonRpcResponse { - jsonrpc: JsonRpcVersion::default(), - id, - error: Some(JsonRpcError { - code: ErrorCode::ServerNotInitialized.into(), - message: "Server has not been initialized".to_owned(), - data: None, - }), - ..Default::default() - }); - let _ = transport.send(&resp).await; - }, - Err(_e) => { - // TODO: error handling - }, - } -} diff --git a/crates/chat-cli/src/mcp_client/transport/base_protocol.rs b/crates/chat-cli/src/mcp_client/transport/base_protocol.rs deleted file mode 100644 index b0394e6e0..000000000 --- a/crates/chat-cli/src/mcp_client/transport/base_protocol.rs +++ /dev/null @@ -1,108 +0,0 @@ -//! Referencing https://spec.modelcontextprotocol.io/specification/2024-11-05/basic/messages/ -//! Protocol Revision 2024-11-05 -use serde::{ - Deserialize, - Serialize, -}; - -pub type RequestId = u64; - -#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] -pub struct JsonRpcVersion(String); - -impl Default for JsonRpcVersion { - fn default() -> Self { - JsonRpcVersion("2.0".to_owned()) - } -} - -impl JsonRpcVersion { - pub fn as_u32_vec(&self) -> Vec { - self.0 - .split(".") - .map(|n| n.parse::().unwrap()) - .collect::>() - } -} - -#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] -#[serde(untagged)] -#[serde(deny_unknown_fields)] -// DO NOT change the order of these variants. This body of json is [untagged](https://serde.rs/enum-representations.html#untagged) -// The categorization of the deserialization depends on the order in which the variants are -// declared. -pub enum JsonRpcMessage { - Response(JsonRpcResponse), - Notification(JsonRpcNotification), - Request(JsonRpcRequest), -} - -impl JsonRpcMessage { - pub fn is_initialize(&self) -> bool { - match self { - JsonRpcMessage::Request(req) => req.method == "initialize", - _ => false, - } - } - - pub fn is_shutdown(&self) -> bool { - match self { - JsonRpcMessage::Notification(notif) => notif.method == "notification/shutdown", - _ => false, - } - } - - pub fn id(&self) -> Option { - match self { - JsonRpcMessage::Request(req) => Some(req.id), - JsonRpcMessage::Response(resp) => Some(resp.id), - JsonRpcMessage::Notification(_) => None, - } - } -} - -#[derive(Debug, Clone, Serialize, Deserialize, Default, PartialEq, Eq)] -#[serde(default, deny_unknown_fields)] -pub struct JsonRpcRequest { - pub jsonrpc: JsonRpcVersion, - pub id: RequestId, - pub method: String, - #[serde(skip_serializing_if = "Option::is_none")] - pub params: Option, -} - -#[derive(Debug, Clone, Serialize, Deserialize, Default, PartialEq, Eq)] -#[serde(default, deny_unknown_fields)] -pub struct JsonRpcResponse { - pub jsonrpc: JsonRpcVersion, - pub id: RequestId, - #[serde(skip_serializing_if = "Option::is_none")] - pub result: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub error: Option, -} - -#[derive(Debug, Clone, Serialize, Deserialize, Default, PartialEq, Eq)] -#[serde(default, deny_unknown_fields)] -pub struct JsonRpcNotification { - pub jsonrpc: JsonRpcVersion, - pub method: String, - #[serde(skip_serializing_if = "Option::is_none")] - pub params: Option, -} - -#[derive(Debug, Clone, Serialize, Deserialize, Default, PartialEq, Eq)] -#[serde(default, deny_unknown_fields)] -pub struct JsonRpcError { - pub code: i32, - pub message: String, - #[serde(skip_serializing_if = "Option::is_none")] - pub data: Option, -} - -#[derive(Debug, Clone, Serialize, Deserialize, Default, PartialEq, Eq)] -pub enum TransportType { - #[default] - Stdio, - Websocket, -} diff --git a/crates/chat-cli/src/mcp_client/transport/mod.rs b/crates/chat-cli/src/mcp_client/transport/mod.rs deleted file mode 100644 index f752b1675..000000000 --- a/crates/chat-cli/src/mcp_client/transport/mod.rs +++ /dev/null @@ -1,57 +0,0 @@ -pub mod base_protocol; -pub mod stdio; - -use std::fmt::Debug; - -pub use base_protocol::*; -pub use stdio::*; -use thiserror::Error; - -#[derive(Clone, Debug, Error)] -pub enum TransportError { - #[error("Serialization error: {0}")] - Serialization(String), - #[error("IO error: {0}")] - Stdio(String), - #[error("{0}")] - Custom(String), - #[error(transparent)] - RecvError(#[from] tokio::sync::broadcast::error::RecvError), -} - -impl From for TransportError { - fn from(err: serde_json::Error) -> Self { - TransportError::Serialization(err.to_string()) - } -} - -impl From for TransportError { - fn from(err: std::io::Error) -> Self { - TransportError::Stdio(err.to_string()) - } -} - -#[allow(dead_code)] -#[async_trait::async_trait] -pub trait Transport: Send + Sync + Debug + 'static { - /// Sends a message over the transport layer. - async fn send(&self, msg: &JsonRpcMessage) -> Result<(), TransportError>; - /// Listens to awaits for a response. This is a call that should be used after `send` is called - /// to listen for a response from the message recipient. - fn get_listener(&self) -> impl Listener; - /// Gracefully terminates the transport connection, cleaning up any resources. - /// This should be called when the transport is no longer needed to ensure proper cleanup. - async fn shutdown(&self) -> Result<(), TransportError>; - /// Listener that listens for logging messages. - fn get_log_listener(&self) -> impl LogListener; -} - -#[async_trait::async_trait] -pub trait Listener: Send + Sync + 'static { - async fn recv(&mut self) -> Result; -} - -#[async_trait::async_trait] -pub trait LogListener: Send + Sync + 'static { - async fn recv(&mut self) -> Result; -} diff --git a/crates/chat-cli/src/mcp_client/transport/stdio.rs b/crates/chat-cli/src/mcp_client/transport/stdio.rs deleted file mode 100644 index 89266a183..000000000 --- a/crates/chat-cli/src/mcp_client/transport/stdio.rs +++ /dev/null @@ -1,285 +0,0 @@ -use std::sync::Arc; - -use tokio::io::{ - AsyncBufReadExt, - AsyncRead, - AsyncWriteExt as _, - BufReader, - Stdin, - Stdout, -}; -use tokio::process::{ - Child, - ChildStdin, -}; -use tokio::sync::{ - Mutex, - broadcast, -}; - -use super::base_protocol::JsonRpcMessage; -use super::{ - Listener, - LogListener, - Transport, - TransportError, -}; - -#[derive(Debug)] -pub enum JsonRpcStdioTransport { - Client { - stdin: Arc>, - receiver: broadcast::Receiver>, - log_receiver: broadcast::Receiver, - }, - Server { - stdout: Arc>, - receiver: broadcast::Receiver>, - }, -} - -impl JsonRpcStdioTransport { - fn spawn_reader( - reader: R, - tx: broadcast::Sender>, - ) { - tokio::spawn(async move { - let mut buffer = Vec::::new(); - let mut buf_reader = BufReader::new(reader); - loop { - buffer.clear(); - // Messages are delimited by newlines and assumed to contain no embedded newlines - // See https://spec.modelcontextprotocol.io/specification/2024-11-05/basic/transports/#stdio - match buf_reader.read_until(b'\n', &mut buffer).await { - Ok(0) => break, - Ok(_) => match serde_json::from_slice::(buffer.as_slice()) { - Ok(msg) => { - let _ = tx.send(Ok(msg)); - }, - Err(e) => { - let _ = tx.send(Err(e.into())); - }, - }, - Err(e) => { - let _ = tx.send(Err(e.into())); - }, - } - } - }); - } - - pub fn client(child_process: Child) -> Result { - let (tx, receiver) = broadcast::channel::>(100); - let Some(stdout) = child_process.stdout else { - return Err(TransportError::Custom("No stdout found on child process".to_owned())); - }; - let Some(stdin) = child_process.stdin else { - return Err(TransportError::Custom("No stdin found on child process".to_owned())); - }; - let Some(stderr) = child_process.stderr else { - return Err(TransportError::Custom("No stderr found on child process".to_owned())); - }; - let (log_tx, log_receiver) = broadcast::channel::(100); - tokio::task::spawn(async move { - let stderr = tokio::io::BufReader::new(stderr); - let mut lines = stderr.lines(); - while let Ok(Some(line)) = lines.next_line().await { - let _ = log_tx.send(line); - } - }); - let stdin = Arc::new(Mutex::new(stdin)); - Self::spawn_reader(stdout, tx); - Ok(JsonRpcStdioTransport::Client { - stdin, - receiver, - log_receiver, - }) - } - - pub fn server(stdin: Stdin, stdout: Stdout) -> Result { - let (tx, receiver) = broadcast::channel::>(100); - Self::spawn_reader(stdin, tx); - let stdout = Arc::new(Mutex::new(stdout)); - Ok(JsonRpcStdioTransport::Server { stdout, receiver }) - } -} - -#[async_trait::async_trait] -impl Transport for JsonRpcStdioTransport { - async fn send(&self, msg: &JsonRpcMessage) -> Result<(), TransportError> { - match self { - JsonRpcStdioTransport::Client { stdin, .. } => { - let mut serialized = serde_json::to_vec(msg)?; - serialized.push(b'\n'); - let mut stdin = stdin.lock().await; - stdin - .write_all(&serialized) - .await - .map_err(|e| TransportError::Custom(format!("Error writing to server: {:?}", e)))?; - stdin - .flush() - .await - .map_err(|e| TransportError::Custom(format!("Error writing to server: {:?}", e)))?; - Ok(()) - }, - JsonRpcStdioTransport::Server { stdout, .. } => { - let mut serialized = serde_json::to_vec(msg)?; - serialized.push(b'\n'); - let mut stdout = stdout.lock().await; - stdout - .write_all(&serialized) - .await - .map_err(|e| TransportError::Custom(format!("Error writing to client: {:?}", e)))?; - stdout - .flush() - .await - .map_err(|e| TransportError::Custom(format!("Error writing to client: {:?}", e)))?; - Ok(()) - }, - } - } - - fn get_listener(&self) -> impl Listener { - match self { - JsonRpcStdioTransport::Client { receiver, .. } | JsonRpcStdioTransport::Server { receiver, .. } => { - StdioListener { - receiver: receiver.resubscribe(), - } - }, - } - } - - async fn shutdown(&self) -> Result<(), TransportError> { - match self { - JsonRpcStdioTransport::Client { stdin, .. } => { - let mut stdin = stdin.lock().await; - Ok(stdin.shutdown().await?) - }, - JsonRpcStdioTransport::Server { stdout, .. } => { - let mut stdout = stdout.lock().await; - Ok(stdout.shutdown().await?) - }, - } - } - - fn get_log_listener(&self) -> impl LogListener { - match self { - JsonRpcStdioTransport::Client { log_receiver, .. } => StdioLogListener { - receiver: log_receiver.resubscribe(), - }, - JsonRpcStdioTransport::Server { .. } => unreachable!("server does not need a log listener"), - } - } -} - -pub struct StdioListener { - pub receiver: broadcast::Receiver>, -} - -#[async_trait::async_trait] -impl Listener for StdioListener { - async fn recv(&mut self) -> Result { - self.receiver.recv().await? - } -} - -pub struct StdioLogListener { - pub receiver: broadcast::Receiver, -} - -#[async_trait::async_trait] -impl LogListener for StdioLogListener { - async fn recv(&mut self) -> Result { - Ok(self.receiver.recv().await?) - } -} - -#[cfg(test)] -mod tests { - use std::process::Stdio; - - use serde_json::{ - Value, - json, - }; - use tokio::process::Command; - - use super::{ - JsonRpcMessage, - JsonRpcStdioTransport, - Listener, - Transport, - }; - - // Helpers for testing - fn create_test_message() -> JsonRpcMessage { - serde_json::from_value(json!({ - "jsonrpc": "2.0", - "id": 1, - "method": "test_method", - "params": { - "test_param": "test_value" - } - })) - .unwrap() - } - - #[tokio::test] - async fn test_client_transport() { - #[cfg(windows)] - let mut cmd = { - let mut cmd = Command::new("powershell"); - cmd.args(&["cat"]); - cmd - }; - #[cfg(not(windows))] - let mut cmd = Command::new("cat"); - - cmd.stdin(Stdio::piped()).stdout(Stdio::piped()).stderr(Stdio::piped()); - - // Inject our mock transport instead - let child = cmd.spawn().expect("Failed to spawn command"); - let transport = JsonRpcStdioTransport::client(child).expect("Failed to create client transport"); - - let message = create_test_message(); - let result = transport.send(&message).await; - assert!(result.is_ok(), "Failed to send message: {:?}", result); - - let echo = transport - .get_listener() - .recv() - .await - .expect("Failed to receive message"); - let echo_value = serde_json::to_value(&echo).expect("Failed to convert echo to value"); - let message_value = serde_json::to_value(&message).expect("Failed to convert message to value"); - assert!(are_json_values_equal(&echo_value, &message_value)); - } - - fn are_json_values_equal(a: &Value, b: &Value) -> bool { - match (a, b) { - (Value::Null, Value::Null) => true, - (Value::Bool(a_val), Value::Bool(b_val)) => a_val == b_val, - (Value::Number(a_val), Value::Number(b_val)) => a_val == b_val, - (Value::String(a_val), Value::String(b_val)) => a_val == b_val, - (Value::Array(a_arr), Value::Array(b_arr)) => { - if a_arr.len() != b_arr.len() { - return false; - } - a_arr - .iter() - .zip(b_arr.iter()) - .all(|(a_item, b_item)| are_json_values_equal(a_item, b_item)) - }, - (Value::Object(a_obj), Value::Object(b_obj)) => { - if a_obj.len() != b_obj.len() { - return false; - } - a_obj.iter().all(|(key, a_value)| match b_obj.get(key) { - Some(b_value) => are_json_values_equal(a_value, b_value), - None => false, - }) - }, - _ => false, - } - } -} diff --git a/crates/chat-cli/src/mcp_client/transport/websocket.rs b/crates/chat-cli/src/mcp_client/transport/websocket.rs deleted file mode 100644 index e69de29bb..000000000 diff --git a/crates/chat-cli/src/os/diagnostics.rs b/crates/chat-cli/src/os/diagnostics.rs deleted file mode 100644 index 0152b56c3..000000000 --- a/crates/chat-cli/src/os/diagnostics.rs +++ /dev/null @@ -1,236 +0,0 @@ -#![allow(clippy::ref_option_ref)] -use std::collections::BTreeMap; - -use serde::Serialize; -use sysinfo::{ - CpuRefreshKind, - MemoryRefreshKind, - RefreshKind, -}; -use time::OffsetDateTime; -use time::format_description::well_known::Rfc3339; - -use crate::os::Env; -use crate::telemetry::InstallMethod; -use crate::util::consts::build::HASH; -use crate::util::system_info::{ - OSVersion, - os_version, -}; - -fn serialize_display(display: D, serializer: S) -> Result -where - D: std::fmt::Display, - S: serde::Serializer, -{ - serializer.serialize_str(&display.to_string()) -} - -fn is_false(value: &bool) -> bool { - !value -} - -#[derive(Debug, Clone, Serialize, Default)] -#[serde(rename_all = "kebab-case")] -pub struct BuildDetails { - pub version: String, - pub hash: Option<&'static str>, - pub date: Option, -} - -impl BuildDetails { - pub fn new() -> BuildDetails { - let date = crate::util::consts::build::DATETIME - .and_then(|input| OffsetDateTime::parse(input, &Rfc3339).ok()) - .and_then(|time| { - let rfc3339 = time.format(&Rfc3339).ok()?; - let duration = OffsetDateTime::now_utc() - time; - Some(format!("{rfc3339} ({duration:.0} ago)")) - }); - - BuildDetails { - version: env!("CARGO_PKG_VERSION").to_owned(), - hash: HASH, - date, - } - } -} - -fn serialize_os_version(version: &Option<&OSVersion>, serializer: S) -> Result -where - S: serde::Serializer, -{ - match version { - Some(version) => match version { - OSVersion::Linux { .. } => version.serialize(serializer), - other => serializer.serialize_str(&other.to_string()), - }, - None => serializer.serialize_none(), - } -} - -#[derive(Debug, Clone, Serialize)] -#[serde(rename_all = "kebab-case")] -pub struct SystemInfo { - #[serde(serialize_with = "serialize_os_version")] - pub os: Option<&'static OSVersion>, - pub chip: Option, - pub total_cores: Option, - pub memory: Option, -} - -impl SystemInfo { - fn new() -> SystemInfo { - let system = sysinfo::System::new_with_specifics( - RefreshKind::nothing() - .with_cpu(CpuRefreshKind::everything()) - .with_memory(MemoryRefreshKind::everything()), - ); - - let mut hardware_info = SystemInfo { - os: os_version(), - chip: None, - total_cores: system.physical_core_count(), - memory: Some(format!("{:0.2} GB", system.total_memory() as f32 / 2.0_f32.powi(30))), - }; - - if let Some(processor) = system.cpus().first() { - hardware_info.chip = Some(processor.brand().into()); - } - - hardware_info - } -} - -#[derive(Debug, Clone, Serialize)] -#[serde(rename_all = "kebab-case")] -pub struct EnvVarDiagnostic { - pub env_vars: BTreeMap, -} - -impl EnvVarDiagnostic { - fn new() -> EnvVarDiagnostic { - let env_vars = std::env::vars() - .filter(|(key, _)| { - let fig_var = crate::util::env_var::ALL.contains(&key.as_str()); - let other_var = [ - // General env vars - "SHELL", - "DISPLAY", - "PATH", - "TERM", - "ZDOTDIR", - // Linux vars - "XDG_CURRENT_DESKTOP", - "XDG_SESSION_DESKTOP", - "XDG_SESSION_TYPE", - "GLFW_IM_MODULE", - "GTK_IM_MODULE", - "QT_IM_MODULE", - "XMODIFIERS", - // Macos vars - "__CFBundleIdentifier", - ] - .contains(&key.as_str()); - - fig_var || other_var - }) - .map(|(key, value)| { - // sanitize username from values - let username = format!("/{}", whoami::username()); - (key, value.replace(&username, "/USER")) - }) - .collect(); - - EnvVarDiagnostic { env_vars } - } -} - -#[derive(Debug, Clone, Serialize)] -#[serde(rename_all = "kebab-case")] -pub struct CurrentEnvironment { - pub cwd: Option, - pub cli_path: Option, - #[serde(serialize_with = "serialize_display")] - pub install_method: InstallMethod, - #[serde(skip_serializing_if = "is_false")] - pub in_ssh: bool, - #[serde(skip_serializing_if = "is_false")] - pub in_ci: bool, - #[serde(skip_serializing_if = "is_false")] - pub in_wsl: bool, - #[serde(skip_serializing_if = "is_false")] - pub in_codespaces: bool, -} - -impl CurrentEnvironment { - async fn new(env: &Env) -> CurrentEnvironment { - let username = format!("/{}", whoami::username()); - - let cwd = env - .current_dir() - .ok() - .map(|path| path.to_string_lossy().replace(&username, "/USER")); - - let cli_path = env - .current_dir() - .ok() - .map(|path| path.to_string_lossy().replace(&username, "/USER")); - - let install_method = crate::telemetry::get_install_method(); - - let in_ssh = crate::util::system_info::in_ssh(); - let in_ci = crate::util::system_info::in_ci(); - let in_wsl = crate::util::system_info::in_wsl(); - let in_codespaces = crate::util::system_info::in_codespaces(); - - CurrentEnvironment { - cwd, - cli_path, - install_method, - in_ssh, - in_ci, - in_wsl, - in_codespaces, - } - } -} - -#[derive(Clone, Debug, Serialize)] -#[serde(rename_all = "kebab-case")] -pub struct Diagnostics { - #[serde(rename = "q-details")] - pub build_details: BuildDetails, - pub system_info: SystemInfo, - pub environment: CurrentEnvironment, - #[serde(flatten)] - pub environment_variables: EnvVarDiagnostic, -} - -impl Diagnostics { - pub async fn new(env: &Env) -> Diagnostics { - Diagnostics { - build_details: BuildDetails::new(), - system_info: SystemInfo::new(), - environment: CurrentEnvironment::new(env).await, - environment_variables: EnvVarDiagnostic::new(), - } - } - - pub fn user_readable(&self) -> Result { - toml::to_string(&self) - } -} - -#[cfg(test)] -mod tests { - use super::*; - - #[tokio::test] - async fn test_diagnostics_user_readable() { - let env = Env::new(); - let diagnostics = Diagnostics::new(&env).await; - let toml = diagnostics.user_readable().unwrap(); - assert!(!toml.is_empty()); - } -} diff --git a/crates/chat-cli/src/os/env.rs b/crates/chat-cli/src/os/env.rs deleted file mode 100644 index 756e1746e..000000000 --- a/crates/chat-cli/src/os/env.rs +++ /dev/null @@ -1,207 +0,0 @@ -use std::collections::HashMap; -use std::env::{ - self, - VarError, -}; -use std::ffi::{ - OsStr, - OsString, -}; -use std::io; -use std::path::PathBuf; -use std::sync::{ - Arc, - Mutex, -}; - -use crate::os::ACTIVE_USER_HOME; - -#[derive(Debug, Clone)] -pub struct Env(inner::Inner); - -mod inner { - use std::collections::HashMap; - use std::path::PathBuf; - use std::sync::{ - Arc, - Mutex, - }; - - #[derive(Debug, Clone)] - pub(super) enum Inner { - Real, - Fake(Arc>), - } - - #[derive(Debug, Clone)] - pub(super) struct Fake { - pub vars: HashMap, - pub cwd: PathBuf, - pub current_exe: PathBuf, - } -} - -impl Env { - pub fn new() -> Self { - if cfg!(test) { - match cfg!(windows) { - true => Env::from_slice(&[ - ("USERPROFILE", ACTIVE_USER_HOME), - ("USERNAME", "testuser"), - ("PATH", ""), - ]), - false => Env::from_slice(&[("HOME", ACTIVE_USER_HOME), ("USER", "testuser"), ("PATH", "")]), - } - } else { - Env(inner::Inner::Real) - } - } - - /// Create a fake process environment from a slice of tuples. - pub fn from_slice(vars: &[(&str, &str)]) -> Self { - use inner::Inner; - let map: HashMap<_, _> = vars.iter().map(|(k, v)| ((*k).to_owned(), (*v).to_owned())).collect(); - Self(Inner::Fake(Arc::new(Mutex::new(inner::Fake { - vars: map, - cwd: PathBuf::from("/"), - current_exe: PathBuf::from("/current_exe"), - })))) - } - - pub fn get>(&self, key: K) -> Result { - use inner::Inner; - match &self.0 { - Inner::Real => env::var(key.as_ref()), - Inner::Fake(fake) => fake - .lock() - .unwrap() - .vars - .get(key.as_ref()) - .cloned() - .ok_or(VarError::NotPresent), - } - } - - pub fn get_os>(&self, key: K) -> Option { - use inner::Inner; - match &self.0 { - Inner::Real => env::var_os(key.as_ref()), - Inner::Fake(fake) => fake - .lock() - .unwrap() - .vars - .get(key.as_ref().to_str()?) - .cloned() - .map(OsString::from), - } - } - - /// Sets the environment variable `key` to the value `value` for the currently running - /// process. - /// - /// # Safety - /// - /// See [std::env::set_var] for the safety requirements. - pub unsafe fn set_var(&self, key: impl AsRef, value: impl AsRef) { - unsafe { - use inner::Inner; - match &self.0 { - Inner::Real => std::env::set_var(key, value), - Inner::Fake(fake) => { - fake.lock().unwrap().vars.insert( - key.as_ref().to_str().expect("key must be valid str").to_string(), - value.as_ref().to_str().expect("key must be valid str").to_string(), - ); - }, - } - } - } - - pub fn home(&self) -> Option { - match &self.0 { - inner::Inner::Real => dirs::home_dir(), - inner::Inner::Fake(fake) => fake.lock().unwrap().vars.get("HOME").map(PathBuf::from), - } - } - - pub fn current_dir(&self) -> Result { - use inner::Inner; - match &self.0 { - Inner::Real => std::env::current_dir(), - Inner::Fake(fake) => Ok(fake.lock().unwrap().cwd.clone()), - } - } - - pub fn current_exe(&self) -> Result { - use inner::Inner; - match &self.0 { - Inner::Real => std::env::current_exe(), - Inner::Fake(fake) => Ok(fake.lock().unwrap().current_exe.clone()), - } - } - - pub fn in_ssh(&self) -> bool { - self.get("SSH_CLIENT").is_ok() || self.get("SSH_CONNECTION").is_ok() || self.get("SSH_TTY").is_ok() - } - - pub fn in_codespaces(&self) -> bool { - self.get_os("CODESPACES").is_some() || self.get_os("Q_CODESPACES").is_some() - } - - pub fn in_ci(&self) -> bool { - self.get_os("CI").is_some() || self.get_os("Q_CI").is_some() - } - - /// Whether or not the current executable is run from an AppImage. - /// - /// See: https://docs.appimage.org/packaging-guide/environment-variables.html - pub fn in_appimage(&self) -> bool { - self.get_os("APPIMAGE").is_some() - } -} - -impl Default for Env { - fn default() -> Self { - Env::new() - } -} - -#[cfg(test)] -mod tests { - use std::path::Path; - - use super::*; - - #[test] - fn test_get() { - let env = Env::new(); - assert!(env.home().is_some()); - assert!(env.get("PATH").is_ok()); - assert!(env.get_os("PATH").is_some()); - assert!(env.get("NON_EXISTENT").is_err()); - - let env = Env::from_slice(&[("HOME", "/home/user"), ("PATH", "/bin:/usr/bin")]); - assert_eq!(env.home().unwrap(), Path::new("/home/user")); - assert_eq!(env.get("PATH").unwrap(), "/bin:/usr/bin"); - assert!(env.get_os("PATH").is_some()); - assert!(env.get("NON_EXISTENT").is_err()); - } - - #[test] - fn test_in_envs() { - let env = Env::from_slice(&[]); - assert!(!env.in_ssh()); - - let env = Env::from_slice(&[("SSH_CLIENT", "1")]); - assert!(env.in_ssh()); - - let env = Env::from_slice(&[("APPIMAGE", "/tmp/.mount-asdf/usr")]); - assert!(env.in_appimage()); - } - - #[test] - fn test_default_current_dir() { - let env = Env::from_slice(&[]); - assert_eq!(env.current_dir().unwrap(), PathBuf::from("/")); - } -} diff --git a/crates/chat-cli/src/os/fs/mod.rs b/crates/chat-cli/src/os/fs/mod.rs deleted file mode 100644 index 755966aca..000000000 --- a/crates/chat-cli/src/os/fs/mod.rs +++ /dev/null @@ -1,627 +0,0 @@ -use std::collections::HashMap; -use std::fs::Permissions; -use std::io; -use std::path::{ - Path, - PathBuf, -}; -use std::sync::{ - Arc, - Mutex, -}; - -use tempfile::TempDir; -use tokio::fs; - -pub const WINDOWS_USER_HOME: &str = "C:\\Users\\testuser"; -pub const UNIX_USER_HOME: &str = "/home/testuser"; - -pub const ACTIVE_USER_HOME: &str = if cfg!(windows) { - WINDOWS_USER_HOME -} else { - UNIX_USER_HOME -}; - -// Import platform-specific modules -#[cfg(unix)] -mod unix; -#[cfg(windows)] -mod windows; - -// Use platform-specific functions -#[cfg(unix)] -use unix::{ - append as platform_append, - symlink_sync, -}; -#[cfg(windows)] -use windows::{ - append as platform_append, - symlink_sync, -}; - -/// Rust path handling is hard coded to work specific ways depending on the -/// OS that is being executed on. Because of this, if Unix paths are provided, -/// they aren't recognized. For example a leading prefix of '/' isn't considered -/// an absolute path. To fix this, all test paths would need to have windows -/// equivalents which is tedious and can lead to errors and missed test cases. -/// To make writing tests easier, path normalization happens on Windows systems -/// implicitly during test runtime. -#[cfg(test)] -fn normalize_test_path(path: impl AsRef) -> PathBuf { - #[cfg(windows)] - { - use typed_path::Utf8TypedPath; - let path_ref = path.as_ref(); - - // Only process string paths with forward slashes - let typed_path = Utf8TypedPath::derive(path_ref.to_str().unwrap()); - if typed_path.is_unix() { - let windows_path = typed_path.with_windows_encoding().to_string(); - - // If path is absolute (starts with /) and doesn't already have a drive letter - if PathBuf::from(&windows_path).has_root() { - // Prepend C: drive letter to make it truly absolute on Windows - return PathBuf::from(format!("C:{}", windows_path)); - } - - return PathBuf::from(windows_path); - } - } - path.as_ref().to_path_buf() -} - -/// Cross-platform path append that handles test paths consistently -fn append(base: impl AsRef, path: impl AsRef) -> PathBuf { - #[cfg(test)] - { - // Normalize the path for tests, then use the platform-specific append - platform_append(normalize_test_path(base), normalize_test_path(path)) - } - - #[cfg(not(test))] - { - // In non-test code, just use the platform-specific append directly - platform_append(base, path) - } -} - -#[derive(Debug, Clone)] -pub enum Fs { - Real, - /// Uses the real filesystem except acts as if the process has - /// a different root directory by using [TempDir] - Chroot(Arc), - Fake(Arc>>>), -} - -impl Fs { - pub fn new() -> Self { - match cfg!(test) { - true => { - let tempdir = tempfile::tempdir().expect("failed creating temporary directory"); - let fs = Self::Chroot(tempdir.into()); - futures::executor::block_on(fs.create_dir_all(ACTIVE_USER_HOME)) - .expect("failed to create test user home"); - - fs - }, - false => Self::Real, - } - } - - pub fn is_chroot(&self) -> bool { - matches!(self, Self::Chroot(_)) - } - - pub fn from_slice(vars: &[(&str, &str)]) -> Self { - let map: HashMap<_, _> = vars - .iter() - .map(|(k, v)| (PathBuf::from(k), v.as_bytes().to_vec())) - .collect(); - - Self::Fake(Arc::new(Mutex::new(map))) - } - - pub async fn create_new(&self, path: impl AsRef) -> io::Result { - match self { - Self::Real => fs::File::create_new(path).await, - Self::Chroot(root) => fs::File::create_new(append(root.path(), path)).await, - Self::Fake(_) => Err(io::Error::other("unimplemented")), - } - } - - pub async fn create_dir(&self, path: impl AsRef) -> io::Result<()> { - match self { - Self::Real => fs::create_dir(path).await, - Self::Chroot(root) => fs::create_dir(append(root.path(), path)).await, - Self::Fake(_) => Err(io::Error::other("unimplemented")), - } - } - - pub async fn create_dir_all(&self, path: impl AsRef) -> io::Result<()> { - match self { - Self::Real => fs::create_dir_all(path).await, - Self::Chroot(root) => fs::create_dir_all(append(root.path(), path)).await, - Self::Fake(_) => Err(io::Error::other("unimplemented")), - } - } - - /// Attempts to open a file in read-only mode. - /// - /// This is a proxy to [`tokio::fs::File::open`]. - pub async fn open(&self, path: impl AsRef) -> io::Result { - match self { - Self::Real => fs::File::open(path).await, - Self::Chroot(root) => fs::File::open(append(root.path(), path)).await, - Self::Fake(_) => Err(io::Error::other("unimplemented")), - } - } - - pub async fn read(&self, path: impl AsRef) -> io::Result> { - match self { - Self::Real => fs::read(path).await, - Self::Chroot(root) => fs::read(append(root.path(), path)).await, - Self::Fake(map) => { - let Ok(lock) = map.lock() else { - return Err(io::Error::other("poisoned lock")); - }; - let Some(data) = lock.get(path.as_ref()) else { - return Err(io::Error::new(io::ErrorKind::NotFound, "not found")); - }; - Ok(data.clone()) - }, - } - } - - pub async fn read_to_string(&self, path: impl AsRef) -> io::Result { - match self { - Self::Real => fs::read_to_string(path).await, - Self::Chroot(root) => fs::read_to_string(append(root.path(), path)).await, - Self::Fake(map) => { - let Ok(lock) = map.lock() else { - return Err(io::Error::other("poisoned lock")); - }; - let Some(data) = lock.get(path.as_ref()) else { - return Err(io::Error::new(io::ErrorKind::NotFound, "not found")); - }; - match String::from_utf8(data.clone()) { - Ok(string) => Ok(string), - Err(err) => Err(io::Error::new(io::ErrorKind::InvalidData, err)), - } - }, - } - } - - pub fn read_to_string_sync(&self, path: impl AsRef) -> io::Result { - match self { - Self::Real => std::fs::read_to_string(path), - Self::Chroot(root) => std::fs::read_to_string(append(root.path(), path)), - Self::Fake(map) => { - let Ok(lock) = map.lock() else { - return Err(io::Error::other("poisoned lock")); - }; - let Some(data) = lock.get(path.as_ref()) else { - return Err(io::Error::new(io::ErrorKind::NotFound, "not found")); - }; - match String::from_utf8(data.clone()) { - Ok(string) => Ok(string), - Err(err) => Err(io::Error::new(io::ErrorKind::InvalidData, err)), - } - }, - } - } - - /// Creates a future that will open a file for writing and write the entire - /// contents of `contents` to it. - /// - /// This is a proxy to [`tokio::fs::write`]. - pub async fn write(&self, path: impl AsRef, contents: impl AsRef<[u8]>) -> io::Result<()> { - match self { - Self::Real => fs::write(path, contents).await, - Self::Chroot(root) => fs::write(append(root.path(), path), contents).await, - Self::Fake(map) => { - let Ok(mut lock) = map.lock() else { - return Err(io::Error::other("poisoned lock")); - }; - lock.insert(path.as_ref().to_owned(), contents.as_ref().to_owned()); - Ok(()) - }, - } - } - - /// Removes a file from the filesystem. - /// - /// Note that there is no guarantee that the file is immediately deleted (e.g. - /// depending on platform, other open file descriptors may prevent immediate - /// removal). - /// - /// This is a proxy to [`tokio::fs::remove_file`]. - pub async fn remove_file(&self, path: impl AsRef) -> io::Result<()> { - match self { - Self::Real => fs::remove_file(path).await, - Self::Chroot(root) => fs::remove_file(append(root.path(), path)).await, - Self::Fake(_) => panic!("unimplemented"), - } - } - - /// Removes a directory at this path, after removing all its contents. Use carefully! - /// - /// This is a proxy to [`tokio::fs::remove_dir_all`]. - pub async fn remove_dir_all(&self, path: impl AsRef) -> io::Result<()> { - match self { - Self::Real => fs::remove_dir_all(path).await, - Self::Chroot(root) => fs::remove_dir_all(append(root.path(), path)).await, - Self::Fake(_) => panic!("unimplemented"), - } - } - - /// Renames a file or directory to a new name, replacing the original file if - /// `to` already exists. - /// - /// This will not work if the new name is on a different mount point. - /// - /// This is a proxy to [`tokio::fs::rename`]. - pub async fn rename(&self, from: impl AsRef, to: impl AsRef) -> io::Result<()> { - match self { - Self::Real => fs::rename(from, to).await, - Self::Chroot(root) => fs::rename(append(root.path(), from), append(root.path(), to)).await, - Self::Fake(_) => panic!("unimplemented"), - } - } - - /// Copies the contents of one file to another. This function will also copy the permission bits - /// of the original file to the destination file. - /// This function will overwrite the contents of to. - /// - /// This is a proxy to [`tokio::fs::copy`]. - pub async fn copy(&self, from: impl AsRef, to: impl AsRef) -> io::Result { - match self { - Self::Real => fs::copy(from, to).await, - Self::Chroot(root) => fs::copy(append(root.path(), from), append(root.path(), to)).await, - Self::Fake(_) => panic!("unimplemented"), - } - } - - /// Returns `Ok(true)` if the path points at an existing entity. - /// - /// This function will traverse symbolic links to query information about the - /// destination file. In case of broken symbolic links this will return `Ok(false)`. - /// - /// This is a proxy to [`tokio::fs::try_exists`]. - pub async fn try_exists(&self, path: impl AsRef) -> Result { - match self { - Self::Real => fs::try_exists(path).await, - Self::Chroot(root) => fs::try_exists(append(root.path(), path)).await, - Self::Fake(_) => panic!("unimplemented"), - } - } - - /// Returns `true` if the path points at an existing entity. - /// - /// This is a proxy to [std::path::Path::exists]. See the related doc comment in std - /// on the pitfalls of using this versus [std::path::Path::try_exists]. - pub fn exists(&self, path: impl AsRef) -> bool { - match self { - Self::Real => path.as_ref().exists(), - Self::Chroot(root) => append(root.path(), path).exists(), - Self::Fake(_) => panic!("unimplemented"), - } - } - - /// Returns `true` if the path points at an existing entity without following symlinks. - /// - /// This does *not* guarantee that the path doesn't point to a symlink. For example, `false` - /// will be returned if the user doesn't have permission to perform a metadata operation on - /// `path`. - pub async fn symlink_exists(&self, path: impl AsRef) -> bool { - match self.symlink_metadata(path).await { - Ok(_) => true, - Err(err) if err.kind() != std::io::ErrorKind::NotFound => true, - Err(_) => false, - } - } - - pub async fn create_tempdir(&self) -> io::Result { - match self { - Self::Real => TempDir::new(), - Self::Chroot(root) => TempDir::new_in(root.path()), - Self::Fake(_) => panic!("unimplemented"), - } - } - - /// Creates a new symbolic link on the filesystem. - /// - /// The `link` path will be a symbolic link pointing to the `original` path. - pub async fn symlink(&self, original: impl AsRef, link: impl AsRef) -> io::Result<()> { - #[cfg(unix)] - async fn do_symlink(original: impl AsRef, link: impl AsRef) -> io::Result<()> { - fs::symlink(original, link).await - } - - #[cfg(windows)] - async fn do_symlink(original: impl AsRef, link: impl AsRef) -> io::Result<()> { - windows::symlink_async(original, link).await - } - - match self { - Self::Real => do_symlink(original, link).await, - Self::Chroot(root) => do_symlink(append(root.path(), original), append(root.path(), link)).await, - Self::Fake(_) => panic!("unimplemented"), - } - } - - /// Creates a new symbolic link on the filesystem. - /// - /// The `link` path will be a symbolic link pointing to the `original` path. - pub fn symlink_sync(&self, original: impl AsRef, link: impl AsRef) -> io::Result<()> { - match self { - Self::Real => symlink_sync(original, link), - Self::Chroot(root) => symlink_sync(append(root.path(), original), append(root.path(), link)), - Self::Fake(_) => panic!("unimplemented"), - } - } - - /// Query the metadata about a file without following symlinks. - /// - /// This is a proxy to [`tokio::fs::symlink_metadata`] - /// - /// # Errors - /// - /// This function will return an error in the following situations, but is not - /// limited to just these cases: - /// - /// * The user lacks permissions to perform `metadata` call on `path`. - /// * `path` does not exist. - pub async fn symlink_metadata(&self, path: impl AsRef) -> io::Result { - match self { - Self::Real => fs::symlink_metadata(path).await, - Self::Chroot(root) => fs::symlink_metadata(append(root.path(), path)).await, - Self::Fake(_) => panic!("unimplemented"), - } - } - - /// Reads a symbolic link, returning the file that the link points to. - /// - /// This is a proxy to [`tokio::fs::read_link`]. - pub async fn read_link(&self, path: impl AsRef) -> io::Result { - match self { - Self::Real => fs::read_link(path).await, - Self::Chroot(root) => Ok(append(root.path(), fs::read_link(append(root.path(), path)).await?)), - Self::Fake(_) => panic!("unimplemented"), - } - } - - /// Returns a stream over the entries within a directory. - /// - /// This is a proxy to [`tokio::fs::read_dir`]. - pub async fn read_dir(&self, path: impl AsRef) -> Result { - match self { - Self::Real => fs::read_dir(path).await, - Self::Chroot(root) => fs::read_dir(append(root.path(), path)).await, - Self::Fake(_) => panic!("unimplemented"), - } - } - - /// Returns the canonical, absolute form of a path with all intermediate - /// components normalized and symbolic links resolved. - /// - /// This is a proxy to [`tokio::fs::canonicalize`]. - pub async fn canonicalize(&self, path: impl AsRef) -> Result { - match self { - Self::Real => fs::canonicalize(path).await, - Self::Chroot(root) => fs::canonicalize(append(root.path(), path)).await, - Self::Fake(_) => panic!("unimplemented"), - } - } - - /// Changes the permissions found on a file or a directory. - /// - /// This is a proxy to [`tokio::fs::set_permissions`] - pub async fn set_permissions(&self, path: impl AsRef, perm: Permissions) -> Result<(), io::Error> { - match self { - Self::Real => fs::set_permissions(path, perm).await, - Self::Chroot(root) => fs::set_permissions(append(root.path(), path), perm).await, - Self::Fake(_) => panic!("unimplemented"), - } - } - - /// For test [Fs]'s that use a different root, returns an absolute path. - /// - /// This must be used for any paths indirectly used by code using a chroot - /// [Fs]. - pub fn chroot_path(&self, path: impl AsRef) -> PathBuf { - match self { - Self::Chroot(root) => append(root.path(), path), - _ => path.as_ref().to_path_buf(), - } - } - - /// See [Fs::chroot_path]. - pub fn chroot_path_str(&self, path: impl AsRef) -> String { - match self { - Self::Chroot(root) => append(root.path(), path).to_string_lossy().to_string(), - _ => path.as_ref().to_path_buf().to_string_lossy().to_string(), - } - } -} - -impl Default for Fs { - fn default() -> Self { - Self::new() - } -} - -#[cfg(test)] -mod tests { - use super::*; - - #[tokio::test] - async fn test_fake() { - let dir = PathBuf::from("/dir"); - let fs = Fs::from_slice(&[("/test", "test")]); - - fs.create_dir(dir.join("create_dir")).await.unwrap_err(); - fs.create_dir_all(dir.join("create/dir/all")).await.unwrap_err(); - fs.write(dir.join("write"), b"write").await.unwrap(); - assert_eq!(fs.read(dir.join("write")).await.unwrap(), b"write"); - assert_eq!(fs.read_to_string(dir.join("write")).await.unwrap(), "write"); - } - - #[tokio::test] - async fn test_real() { - let dir = tempfile::tempdir().unwrap(); - let fs = Fs::Real; - - fs.create_dir(dir.path().join("create_dir")).await.unwrap(); - fs.create_dir_all(dir.path().join("create/dir/all")).await.unwrap(); - fs.write(dir.path().join("write"), b"write").await.unwrap(); - assert_eq!(fs.read(dir.path().join("write")).await.unwrap(), b"write"); - assert_eq!(fs.read_to_string(dir.path().join("write")).await.unwrap(), "write"); - } - - macro_rules! test_append_cases { - ($( - $name:ident: ($a:expr, $b:expr) => $expected:expr - ),* $(,)?) => { - $( - #[test] - fn $name() { - assert_eq!(append($a, $b), normalize_test_path($expected)); - } - )* - }; -} - - test_append_cases!( - append_test_path_to_dir: ("/abc/test", "/test") => "/abc/test/test", - append_absolute_to_tmp_dir: ("/tmp/.dir", "/tmp/.dir/home/myuser") => "/tmp/.dir/home/myuser", - append_different_tmp_path: ("/tmp/.dir", "/tmp/hello") => "/tmp/.dir/tmp/hello", - append_nested_path_to_tmpdir: ("/tmp/.dir", "/tmp/.dir/tmp/.dir/home/user") => "/tmp/.dir/home/user", - ); - - #[tokio::test] - async fn test_read_to_string() { - let fs = Fs::new(); - fs.write("fake", "contents").await.unwrap(); - fs.write("invalid_utf8", &[255]).await.unwrap(); - - // async tests - assert_eq!( - fs.read_to_string("fake").await.unwrap(), - "contents", - "should read fake file" - ); - assert!( - fs.read_to_string("unknown") - .await - .is_err_and(|err| err.kind() == io::ErrorKind::NotFound), - "unknown path should return NotFound" - ); - assert!( - fs.read_to_string("invalid_utf8") - .await - .is_err_and(|err| err.kind() == io::ErrorKind::InvalidData), - "invalid utf8 should return InvalidData" - ); - - // sync tests - assert_eq!( - fs.read_to_string_sync("fake").unwrap(), - "contents", - "should read fake file" - ); - assert!( - fs.read_to_string_sync("unknown") - .is_err_and(|err| err.kind() == io::ErrorKind::NotFound), - "unknown path should return NotFound" - ); - assert!( - fs.read_to_string_sync("invalid_utf8") - .is_err_and(|err| err.kind() == io::ErrorKind::InvalidData), - "invalid utf8 should return InvalidData" - ); - } - - #[tokio::test] - #[cfg(unix)] - async fn test_chroot_file_operations_for_unix() { - if nix::unistd::Uid::effective().is_root() { - println!("currently running as root, skipping."); - return; - } - - let fs = Fs::new(); - assert!(fs.is_chroot()); - - fs.write("/fake", "contents").await.unwrap(); - assert_eq!(fs.read_to_string("/fake").await.unwrap(), "contents"); - assert_eq!(fs.read_to_string_sync("/fake").unwrap(), "contents"); - - assert!(!fs.try_exists("/etc").await.unwrap()); - - fs.create_dir_all("/etc/b/c").await.unwrap(); - assert!(fs.try_exists("/etc").await.unwrap()); - let mut read_dir = fs.read_dir("/etc").await.unwrap(); - let e = read_dir.next_entry().await.unwrap(); - assert!(e.unwrap().metadata().await.unwrap().is_dir()); - assert!(read_dir.next_entry().await.unwrap().is_none()); - - fs.remove_dir_all("/etc").await.unwrap(); - assert!(!fs.try_exists("/etc").await.unwrap()); - - fs.copy("/fake", "/fake_copy").await.unwrap(); - assert_eq!(fs.read_to_string("/fake_copy").await.unwrap(), "contents"); - assert_eq!(fs.read_to_string_sync("/fake_copy").unwrap(), "contents"); - - fs.remove_file("/fake_copy").await.unwrap(); - assert!(!fs.try_exists("/fake_copy").await.unwrap()); - - fs.symlink("/fake", "/fake_symlink").await.unwrap(); - fs.symlink_sync("/fake", "/fake_symlink_sync").unwrap(); - assert_eq!(fs.read_to_string("/fake_symlink").await.unwrap(), "contents"); - assert_eq!( - fs.read_to_string(fs.read_link("/fake_symlink").await.unwrap()) - .await - .unwrap(), - "contents" - ); - assert_eq!(fs.read_to_string("/fake_symlink_sync").await.unwrap(), "contents"); - assert_eq!(fs.read_to_string_sync("/fake_symlink").unwrap(), "contents"); - - // Checking symlink exist - assert!(fs.symlink_exists("/fake_symlink").await); - assert!(fs.exists("/fake_symlink")); - fs.remove_file("/fake").await.unwrap(); - assert!(fs.symlink_exists("/fake_symlink").await); - assert!(!fs.exists("/fake_symlink")); - - // Checking rename - fs.write("/rename_1", "abc").await.unwrap(); - fs.write("/rename_2", "123").await.unwrap(); - fs.rename("/rename_2", "/rename_1").await.unwrap(); - assert_eq!(fs.read_to_string("/rename_1").await.unwrap(), "123"); - - // Checking open - assert!(fs.open("/does_not_exist").await.is_err()); - assert!(fs.open("/rename_1").await.is_ok()); - } - - #[tokio::test] - async fn test_chroot_tempdir() { - let fs = Fs::new(); - let tempdir = fs.create_tempdir().await.unwrap(); - if let Fs::Chroot(root) = fs { - assert_eq!(tempdir.path().parent().unwrap(), root.path()); - } else { - panic!("tempdir should be created under root"); - } - } - - #[tokio::test] - async fn test_create_new() { - let fs = Fs::new(); - fs.create_new("my_file.txt").await.unwrap(); - assert!(fs.create_new("my_file.txt").await.is_err()); - } -} diff --git a/crates/chat-cli/src/os/fs/unix.rs b/crates/chat-cli/src/os/fs/unix.rs deleted file mode 100644 index 313b04213..000000000 --- a/crates/chat-cli/src/os/fs/unix.rs +++ /dev/null @@ -1,51 +0,0 @@ -use std::ffi::OsString; -use std::os::unix::ffi::{ - OsStrExt, - OsStringExt, -}; -use std::path::{ - Path, - PathBuf, -}; - -/// Performs `a.join(b)`, except: -/// - if `b` is an absolute path, then the resulting path will equal `/a/b` -/// - if the prefix of `b` contains some `n` copies of a, then the resulting path will equal `/a/b` -pub(super) fn append(a: impl AsRef, b: impl AsRef) -> PathBuf { - // Have to use byte slices since rust seems to always append - // a forward slash at the end of a path... - let a = a.as_ref().as_os_str().as_bytes(); - let mut b = b.as_ref().as_os_str().as_bytes(); - while b.starts_with(a) { - b = b.strip_prefix(a).unwrap(); - } - while b.starts_with(b"/") { - b = b.strip_prefix(b"/").unwrap(); - } - PathBuf::from(OsString::from_vec(a.to_vec())).join(PathBuf::from(OsString::from_vec(b.to_vec()))) -} - -/// Creates a new symbolic link on the filesystem. -/// -/// The `link` path will be a symbolic link pointing to the `original` path. -pub(super) fn symlink_sync(original: impl AsRef, link: impl AsRef) -> std::io::Result<()> { - std::os::unix::fs::symlink(original, link) -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn test_append() { - macro_rules! assert_append { - ($a:expr, $b:expr, $expected:expr) => { - assert_eq!(append($a, $b), PathBuf::from($expected)); - }; - } - assert_append!("/abc/test", "/test", "/abc/test/test"); - assert_append!("/tmp/.dir", "/tmp/.dir/home/myuser", "/tmp/.dir/home/myuser"); - assert_append!("/tmp/.dir", "/tmp/hello", "/tmp/.dir/tmp/hello"); - assert_append!("/tmp/.dir", "/tmp/.dir/tmp/.dir/home/user", "/tmp/.dir/home/user"); - } -} diff --git a/crates/chat-cli/src/os/fs/windows.rs b/crates/chat-cli/src/os/fs/windows.rs deleted file mode 100644 index 7ba871f1a..000000000 --- a/crates/chat-cli/src/os/fs/windows.rs +++ /dev/null @@ -1,148 +0,0 @@ -use std::fs::metadata; -use std::io; -use std::path::{ - Component, - Path, - PathBuf, -}; - -/// Performs `a.join(b)`, except: -/// - if `b` is an absolute path, then the resulting path will equal `/a/b` -/// - if the prefix of `b` contains some `n` copies of a, then the resulting path will equal `/a/b` -pub(super) fn append(a: impl AsRef, b: impl AsRef) -> PathBuf { - let a_path = a.as_ref(); - let b_path = b.as_ref(); - - // Extract the non-prefix, non-root components of paths a and b for comparison - let a_normal_components: Vec<_> = a_path - .components() - .filter(|c| !matches!(c, Component::Prefix(_) | Component::RootDir)) - .collect(); - - // Create a version of b_path without prefix/root components - let mut b_normal_path = PathBuf::new(); - for comp in b_path.components() { - match comp { - Component::Prefix(_) | Component::RootDir => (), - _ => b_normal_path.push(comp.as_os_str()), - } - } - - // Iteratively strip a from the beginning of b - let mut cleaned_b = b_normal_path.clone(); - let mut done = false; - - while !done { - let b_normal_components: Vec<_> = cleaned_b.components().collect(); - - if b_normal_components.len() >= a_normal_components.len() { - // Check if the beginning of b matches a (case-insensitive on Windows) - let matches = a_normal_components - .iter() - .zip(b_normal_components.iter()) - .all(|(a_comp, b_comp)| { - // Case-insensitive comparison for Windows - a_comp.as_os_str().to_string_lossy().to_lowercase() - == b_comp.as_os_str().to_string_lossy().to_lowercase() - }); - - if matches { - // Create a new path with a's components removed from the beginning of b - let mut new_b = PathBuf::new(); - for comp in b_normal_components.iter().skip(a_normal_components.len()) { - new_b.push(comp.as_os_str()); - } - cleaned_b = new_b; - } else { - done = true; - } - } else { - done = true; - } - } - - // Join the paths - a_path.join(cleaned_b) -} - -/// Creates a new symbolic link on the filesystem. -/// -/// The `link` path will be a symbolic link pointing to the `original` path. -/// On Windows, we need to determine if the target is a file or directory. -pub(super) fn symlink_sync(original: impl AsRef, link: impl AsRef) -> io::Result<()> { - // Determine if the original is a file or directory - let meta = metadata(original.as_ref())?; - if meta.is_dir() { - std::os::windows::fs::symlink_dir(original, link) - } else { - std::os::windows::fs::symlink_file(original, link) - } -} - -/// Creates a new symbolic link asynchronously. -/// -/// This is a helper function for the Windows implementation. -pub(super) async fn symlink_async(original: impl AsRef, link: impl AsRef) -> io::Result<()> { - // Determine if the original is a file or directory - let meta = metadata(original.as_ref())?; - if meta.is_dir() { - tokio::fs::symlink_dir(original, link).await - } else { - tokio::fs::symlink_file(original, link).await - } -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn test_append() { - macro_rules! assert_append { - ($a:expr, $b:expr, $expected:expr) => { - assert_eq!(append($a, $b), PathBuf::from($expected)); - }; - } - - // Test different drive letters (should strip prefix) - assert_append!("C:\\temp", "D:\\test", "C:\\temp\\test"); - - // Test same path prefixes (should use strip_prefix) - assert_append!("C:\\temp", "C:\\temp\\subdir", "C:\\temp\\subdir"); - assert_append!("C:\\temp", "C:\\temp\\subdir\\file.txt", "C:\\temp\\subdir\\file.txt"); - - // Test relative path (standard join) - assert_append!("C:\\temp", "subdir\\file.txt", "C:\\temp\\subdir\\file.txt"); - - // Test different absolute paths with same drive (strip drive and root) - assert_append!("C:\\temprootdir", "C:\\test_file.txt", "C:\\temprootdir\\test_file.txt"); - - // Test different absolute paths with different drives - assert_append!("C:\\temprootdir", "D:\\test_file.txt", "C:\\temprootdir\\test_file.txt"); - - // Test paths with mixed case (should be case-insensitive on Windows) - assert_append!("C:\\Temp", "c:\\temp\\file.txt", "C:\\Temp\\file.txt"); - } -} - -#[cfg(test)] -mod integration_tests { - use tempfile::TempDir; - - use super::*; - - #[test] - fn test_append_with_real_paths() { - // Create a temporary directory for testing - let temp_dir = TempDir::new().unwrap(); - let temp_path = temp_dir.path(); - - // Test appending an absolute path - let drive_letter = temp_path.to_string_lossy().chars().next().unwrap_or('C'); - let absolute_path = format!("{}:\\test.txt", drive_letter); - - let result = append(temp_path, absolute_path); - assert!(result.to_string_lossy().contains("test.txt")); - assert!(!result.to_string_lossy().contains(":\\test.txt")); - } -} diff --git a/crates/chat-cli/src/os/mod.rs b/crates/chat-cli/src/os/mod.rs deleted file mode 100644 index 69fb4b74c..000000000 --- a/crates/chat-cli/src/os/mod.rs +++ /dev/null @@ -1,86 +0,0 @@ -#![allow(dead_code)] - -pub mod diagnostics; -mod env; -mod fs; -mod sysinfo; - -pub use env::Env; -use eyre::Result; -pub use fs::Fs; -pub use sysinfo::SysInfo; - -use crate::api_client::ApiClient; -use crate::database::Database; -use crate::telemetry::TelemetryThread; - -const WINDOWS_USER_HOME: &str = "C:\\Users\\testuser"; -const UNIX_USER_HOME: &str = "/home/testuser"; - -pub const ACTIVE_USER_HOME: &str = if cfg!(windows) { - WINDOWS_USER_HOME -} else { - UNIX_USER_HOME -}; - -// TODO OS SHOULD NOT BE CLONE - -/// Struct that contains the interface to every system related IO operation. -/// -/// Every operation that accesses the file system, environment, or other related platform -/// primitives should be done through a [Context] as this enables testing otherwise untestable -/// code paths in unit tests. -#[derive(Clone, Debug)] -pub struct Os { - pub env: Env, - pub fs: Fs, - pub sysinfo: SysInfo, - pub database: Database, - pub client: ApiClient, - pub telemetry: TelemetryThread, -} - -impl Os { - pub async fn new() -> Result { - let env = Env::new(); - let fs = Fs::new(); - let mut database = Database::new().await?; - let client = ApiClient::new(&env, &fs, &mut database, None).await?; - let telemetry = TelemetryThread::new(&env, &fs, &mut database).await?; - - Ok(Self { - env, - fs, - sysinfo: SysInfo::new(), - database, - client, - telemetry, - }) - } -} - -#[cfg(test)] -mod tests { - use super::*; - - #[tokio::test] - async fn test_context_builder_with_test_home() { - let os = Os::new().await.unwrap(); - unsafe { - os.env.set_var("hello", "world"); - } - - #[cfg(windows)] - { - assert!(os.fs.try_exists(ACTIVE_USER_HOME).await.unwrap()); - assert_eq!(os.env.get("USERPROFILE").unwrap(), ACTIVE_USER_HOME); - } - #[cfg(not(windows))] - { - assert!(os.fs.try_exists(ACTIVE_USER_HOME).await.unwrap()); - assert_eq!(os.env.get("HOME").unwrap(), ACTIVE_USER_HOME); - } - - assert_eq!(os.env.get("hello").unwrap(), "world"); - } -} diff --git a/crates/chat-cli/src/os/sysinfo.rs b/crates/chat-cli/src/os/sysinfo.rs deleted file mode 100644 index 40cbbb158..000000000 --- a/crates/chat-cli/src/os/sysinfo.rs +++ /dev/null @@ -1,63 +0,0 @@ -use std::ffi::OsString; -use std::sync::{ - Arc, - Mutex, -}; - -#[derive(Debug, Clone, Default)] -pub struct SysInfo(inner::Inner); - -mod inner { - use std::collections::HashSet; - use std::sync::{ - Arc, - Mutex, - }; - - #[derive(Debug, Clone, Default)] - pub enum Inner { - #[default] - Real, - Fake(Arc>), - } - - #[derive(Debug, Clone, Default)] - pub struct Fake { - pub process_names: HashSet, - } -} - -impl SysInfo { - pub fn new() -> Self { - match cfg!(test) { - true => Self(inner::Inner::Fake(Arc::new(Mutex::new(inner::Fake::default())))), - false => Self(inner::Inner::Real), - } - } - - /// Returns whether the process containing `name` is running. - pub fn is_process_running(&self, name: &str) -> bool { - use inner::Inner; - match &self.0 { - Inner::Real => { - let system = sysinfo::System::new_all(); - - system.processes_by_name(&OsString::from(name)).next().is_some() - }, - Inner::Fake(fake) => fake.lock().unwrap().process_names.contains(name), - } - } - - pub fn add_running_processes(&self, process_names: &[&str]) { - use inner::Inner; - match &self.0 { - Inner::Real => panic!("unimplemented"), - Inner::Fake(fake) => { - let curr_names = &mut fake.lock().unwrap().process_names; - for name in process_names { - curr_names.insert((*name).to_string()); - } - }, - } - } -} diff --git a/crates/chat-cli/src/request.rs b/crates/chat-cli/src/request.rs deleted file mode 100644 index e9b3abacc..000000000 --- a/crates/chat-cli/src/request.rs +++ /dev/null @@ -1,105 +0,0 @@ -use std::env::current_exe; -use std::sync::{ - Arc, - LazyLock, -}; - -use reqwest::Client; -use rustls::{ - ClientConfig, - RootCertStore, -}; -use thiserror::Error; -use url::ParseError; - -#[derive(Debug, Error)] -pub enum RequestError { - #[error(transparent)] - Reqwest(#[from] reqwest::Error), - #[error(transparent)] - Serde(#[from] serde_json::Error), - #[error(transparent)] - Io(#[from] std::io::Error), - #[error(transparent)] - Dir(#[from] crate::util::directories::DirectoryError), - #[error(transparent)] - Settings(#[from] crate::database::DatabaseError), - #[error(transparent)] - UrlParseError(#[from] ParseError), -} - -pub fn new_client() -> Result { - Ok(Client::builder() - .use_preconfigured_tls(client_config()) - .user_agent(USER_AGENT.chars().filter(|c| c.is_ascii_graphic()).collect::()) - .cookie_store(true) - .build()?) -} - -pub fn create_default_root_cert_store() -> RootCertStore { - let mut root_cert_store: RootCertStore = webpki_roots::TLS_SERVER_ROOTS.iter().cloned().collect(); - - // The errors are ignored because root certificates often include - // ancient or syntactically invalid certificates - let rustls_native_certs::CertificateResult { certs, errors: _, .. } = rustls_native_certs::load_native_certs(); - for cert in certs { - let _ = root_cert_store.add(cert); - } - - root_cert_store -} - -fn client_config() -> ClientConfig { - let provider = rustls::crypto::CryptoProvider::get_default() - .cloned() - .unwrap_or_else(|| Arc::new(rustls::crypto::ring::default_provider())); - - ClientConfig::builder_with_provider(provider) - .with_protocol_versions(rustls::DEFAULT_VERSIONS) - .expect("Failed to set supported TLS versions") - .with_root_certificates(create_default_root_cert_store()) - .with_no_client_auth() -} - -static USER_AGENT: LazyLock = LazyLock::new(|| { - let name = current_exe() - .ok() - .and_then(|exe| exe.file_stem().and_then(|name| name.to_str().map(String::from))) - .unwrap_or_else(|| "unknown-rust-client".into()); - - let os = std::env::consts::OS; - let arch = std::env::consts::ARCH; - let version = env!("CARGO_PKG_VERSION"); - - format!("{name}-{os}-{arch}-{version}") -}); - -#[cfg(test)] -mod tests { - use super::*; - - #[tokio::test] - async fn get_client() { - new_client().unwrap(); - } - - #[tokio::test] - async fn request_test() { - let mut server = mockito::Server::new_async().await; - let mock = server - .mock("GET", "/hello") - .with_status(200) - .with_header("content-type", "text/plain") - .with_body("world") - .create(); - let url = server.url(); - - let client = new_client().unwrap(); - let res = client.get(format!("{url}/hello")).send().await.unwrap(); - assert_eq!(res.status(), 200); - assert_eq!(res.headers()["content-type"], "text/plain"); - assert_eq!(res.text().await.unwrap(), "world"); - - mock.expect(1).assert(); - } -} diff --git a/crates/chat-cli/src/telemetry/cognito.rs b/crates/chat-cli/src/telemetry/cognito.rs deleted file mode 100644 index 4fa921155..000000000 --- a/crates/chat-cli/src/telemetry/cognito.rs +++ /dev/null @@ -1,178 +0,0 @@ -use std::time::SystemTime; - -use amzn_toolkit_telemetry_client::config::BehaviorVersion; -use aws_credential_types::provider::error::CredentialsError; -use aws_credential_types::{ - Credentials, - provider, -}; -use aws_sdk_cognitoidentity::primitives::{ - DateTime, - DateTimeFormat, -}; -use tracing::{ - trace, - warn, -}; - -use crate::aws_common::app_name; -use crate::database::{ - CredentialsJson, - Database, -}; -use crate::telemetry::TelemetryStage; - -pub async fn get_cognito_credentials_send( - database: &mut Database, - telemetry_stage: &TelemetryStage, -) -> Result { - trace!("Creating new cognito credentials"); - - let conf = aws_sdk_cognitoidentity::Config::builder() - .behavior_version(BehaviorVersion::v2025_08_07()) - .region(telemetry_stage.region.clone()) - .app_name(app_name()) - .build(); - let client = aws_sdk_cognitoidentity::Client::from_conf(conf); - - let identity_id = client - .get_id() - .identity_pool_id(telemetry_stage.cognito_pool_id) - .send() - .await - .map_err(CredentialsError::provider_error)? - .identity_id - .ok_or(CredentialsError::provider_error("no identity_id from get_id"))?; - - let credentials = client - .get_credentials_for_identity() - .identity_id(identity_id) - .send() - .await - .map_err(CredentialsError::provider_error)? - .credentials - .ok_or(CredentialsError::provider_error( - "no credentials from get_credentials_for_identity", - ))?; - - database.set_credentials_entry(&credentials).ok(); - - let Some(access_key_id) = credentials.access_key_id else { - return Err(CredentialsError::provider_error("access key id not found")); - }; - - let Some(secret_key) = credentials.secret_key else { - return Err(CredentialsError::provider_error("secret access key not found")); - }; - - Ok(Credentials::new( - access_key_id, - secret_key, - credentials.session_token, - credentials.expiration.and_then(|dt| dt.try_into().ok()), - "", - )) -} - -pub async fn get_cognito_credentials( - database: &mut Database, - telemetry_stage: &TelemetryStage, -) -> Result { - match database - .get_credentials_entry() - .map_err(CredentialsError::provider_error)? - { - Some(CredentialsJson { - access_key_id, - secret_key, - session_token, - expiration, - }) => { - if is_expired(expiration.as_ref()) { - return get_cognito_credentials_send(database, telemetry_stage).await; - } - - let Some(access_key_id) = access_key_id else { - return get_cognito_credentials_send(database, telemetry_stage).await; - }; - - let Some(secret_key) = secret_key else { - return get_cognito_credentials_send(database, telemetry_stage).await; - }; - - Ok(Credentials::new( - access_key_id, - secret_key, - session_token, - expiration - .and_then(|s| DateTime::from_str(&s, DateTimeFormat::DateTime).ok()) - .and_then(|dt| dt.try_into().ok()), - "", - )) - }, - None => get_cognito_credentials_send(database, telemetry_stage).await, - } -} - -#[derive(Debug)] -pub struct CognitoProvider { - telemetry_stage: TelemetryStage, -} - -impl CognitoProvider { - pub fn new(telemetry_stage: TelemetryStage) -> CognitoProvider { - CognitoProvider { telemetry_stage } - } -} - -impl provider::ProvideCredentials for CognitoProvider { - fn provide_credentials<'a>(&'a self) -> provider::future::ProvideCredentials<'a> - where - Self: 'a, - { - provider::future::ProvideCredentials::new(async { - match Database::new().await { - Ok(mut db) => get_cognito_credentials(&mut db, &self.telemetry_stage).await, - Err(err) => Err(CredentialsError::provider_error(format!( - "failed to get database: {:?}", - err - ))), - } - }) - } -} - -fn is_expired(expiration: Option<&String>) -> bool { - let expiration = if let Some(v) = expiration { - v - } else { - warn!("no cognito expiration was saved"); - return true; - }; - - match DateTime::from_str(expiration, DateTimeFormat::DateTime) { - Ok(expiration) => { - // Check if the expiration is at least after five minutes after the current time. - let curr: DateTime = (SystemTime::now() + std::time::Duration::from_secs(60 * 5)).into(); - expiration < curr - }, - Err(err) => { - warn!(?err, "invalid cognito expiration was saved"); - true - }, - } -} - -#[cfg(test)] -mod test { - use super::*; - - #[tokio::test] - async fn pools() { - for telemetry_stage in [TelemetryStage::BETA, TelemetryStage::EXTERNAL_PROD] { - get_cognito_credentials_send(&mut Database::new().await.unwrap(), &telemetry_stage) - .await - .unwrap(); - } - } -} diff --git a/crates/chat-cli/src/telemetry/core.rs b/crates/chat-cli/src/telemetry/core.rs deleted file mode 100644 index a8600b2d5..000000000 --- a/crates/chat-cli/src/telemetry/core.rs +++ /dev/null @@ -1,467 +0,0 @@ -use std::fmt::Debug; -use std::time::SystemTime; - -pub use amzn_toolkit_telemetry_client::types::MetricDatum; -use strum::{ - Display, - EnumString, -}; - -use crate::telemetry::definitions::IntoMetricDatum; -use crate::telemetry::definitions::metrics::{ - AmazonqDidSelectProfile, - AmazonqEndChat, - AmazonqMessageResponseError, - AmazonqProfileState, - AmazonqStartChat, - CodewhispererterminalAddChatMessage, - CodewhispererterminalCliSubcommandExecuted, - CodewhispererterminalMcpServerInit, - CodewhispererterminalRefreshCredentials, - CodewhispererterminalToolUseSuggested, - CodewhispererterminalUserLoggedIn, -}; -use crate::telemetry::definitions::types::{ - CodewhispererterminalCustomToolInputTokenSize, - CodewhispererterminalCustomToolLatency, - CodewhispererterminalCustomToolOutputTokenSize, - CodewhispererterminalIsToolValid, - CodewhispererterminalMcpServerInitFailureReason, - CodewhispererterminalToolName, - CodewhispererterminalToolUseId, - CodewhispererterminalToolUseIsSuccess, - CodewhispererterminalToolsPerMcpServer, - CodewhispererterminalUserInputId, - CodewhispererterminalUtteranceId, -}; - -/// A serializable telemetry event that can be sent or queued. -#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)] -#[serde(rename_all = "camelCase")] -pub struct Event { - pub created_time: Option, - pub credential_start_url: Option, - pub sso_region: Option, - pub client_application: Option, - #[serde(flatten)] - pub ty: EventType, -} - -impl Event { - pub fn new(ty: EventType) -> Self { - Self { - ty, - created_time: Some(SystemTime::now()), - credential_start_url: None, - sso_region: None, - client_application: None, - } - } - - pub fn set_start_url(&mut self, start_url: String) { - self.credential_start_url = Some(start_url); - } - - pub fn set_sso_region(&mut self, sso_region: String) { - self.sso_region = Some(sso_region); - } - - pub fn set_client_application(&mut self, client_application: String) { - self.client_application = Some(client_application); - } - - pub fn into_metric_datum(self) -> Option { - match self.ty { - EventType::UserLoggedIn {} => Some( - CodewhispererterminalUserLoggedIn { - create_time: self.created_time, - value: None, - credential_start_url: self.credential_start_url.map(Into::into), - codewhispererterminal_in_cloudshell: None, - } - .into_metric_datum(), - ), - EventType::RefreshCredentials { - request_id, - result, - reason, - oauth_flow, - } => Some( - CodewhispererterminalRefreshCredentials { - create_time: self.created_time, - value: None, - credential_start_url: self.credential_start_url.map(Into::into), - request_id: Some(request_id.into()), - result: Some(result.to_string().into()), - reason: reason.map(Into::into), - oauth_flow: Some(oauth_flow.into()), - codewhispererterminal_in_cloudshell: None, - } - .into_metric_datum(), - ), - EventType::CliSubcommandExecuted { subcommand } => Some( - CodewhispererterminalCliSubcommandExecuted { - create_time: self.created_time, - value: None, - credential_start_url: self.credential_start_url.map(Into::into), - codewhispererterminal_subcommand: Some(subcommand.into()), - codewhispererterminal_in_cloudshell: None, - codewhispererterminal_client_application: self.client_application.map(Into::into), - } - .into_metric_datum(), - ), - EventType::ChatStart { conversation_id, model } => Some( - AmazonqStartChat { - create_time: self.created_time, - value: None, - credential_start_url: self.credential_start_url.map(Into::into), - amazonq_conversation_id: Some(conversation_id.into()), - codewhispererterminal_in_cloudshell: None, - codewhispererterminal_model: model.map(Into::into), - } - .into_metric_datum(), - ), - EventType::ChatEnd { conversation_id, model } => Some( - AmazonqEndChat { - create_time: self.created_time, - value: None, - credential_start_url: self.credential_start_url.map(Into::into), - amazonq_conversation_id: Some(conversation_id.into()), - codewhispererterminal_in_cloudshell: None, - codewhispererterminal_model: model.map(Into::into), - } - .into_metric_datum(), - ), - EventType::ChatAddedMessage { - conversation_id, - context_file_length, - message_id, - request_id, - result, - reason, - reason_desc, - status_code, - model, - .. - } => Some( - CodewhispererterminalAddChatMessage { - create_time: self.created_time, - value: None, - amazonq_conversation_id: Some(conversation_id.into()), - request_id: request_id.map(Into::into), - codewhispererterminal_utterance_id: message_id.map(Into::into), - credential_start_url: self.credential_start_url.map(Into::into), - sso_region: self.sso_region.map(Into::into), - codewhispererterminal_in_cloudshell: None, - codewhispererterminal_context_file_length: context_file_length.map(|l| l as i64).map(Into::into), - result: result.to_string().into(), - reason: reason.map(Into::into), - reason_desc: reason_desc.map(Into::into), - status_code: status_code.map(|v| v as i64).map(Into::into), - codewhispererterminal_model: model.map(Into::into), - codewhispererterminal_client_application: self.client_application.map(Into::into), - } - .into_metric_datum(), - ), - EventType::ToolUseSuggested { - conversation_id, - utterance_id, - user_input_id, - tool_use_id, - tool_name, - is_accepted, - is_valid, - is_success, - is_custom_tool, - input_token_size, - output_token_size, - custom_tool_call_latency, - model, - aws_service_name, - aws_operation_name, - } => Some( - CodewhispererterminalToolUseSuggested { - create_time: self.created_time, - credential_start_url: self.credential_start_url.map(Into::into), - value: None, - amazonq_conversation_id: Some(conversation_id.into()), - codewhispererterminal_utterance_id: utterance_id.map(CodewhispererterminalUtteranceId), - codewhispererterminal_user_input_id: user_input_id.map(CodewhispererterminalUserInputId), - codewhispererterminal_tool_use_id: tool_use_id.map(CodewhispererterminalToolUseId), - codewhispererterminal_tool_name: tool_name.map(CodewhispererterminalToolName), - codewhispererterminal_is_tool_use_accepted: Some(is_accepted.into()), - codewhispererterminal_is_tool_valid: is_valid.map(CodewhispererterminalIsToolValid), - codewhispererterminal_tool_use_is_success: is_success.map(CodewhispererterminalToolUseIsSuccess), - codewhispererterminal_is_custom_tool: Some(is_custom_tool.into()), - codewhispererterminal_custom_tool_input_token_size: input_token_size - .map(|s| CodewhispererterminalCustomToolInputTokenSize(s as i64)), - codewhispererterminal_custom_tool_output_token_size: output_token_size - .map(|s| CodewhispererterminalCustomToolOutputTokenSize(s as i64)), - codewhispererterminal_custom_tool_latency: custom_tool_call_latency - .map(|l| CodewhispererterminalCustomToolLatency(l as i64)), - codewhispererterminal_model: model.map(Into::into), - codewhispererterminal_client_application: self.client_application.map(Into::into), - codewhispererterminal_aws_service_name: aws_service_name.map(Into::into), - codewhispererterminal_aws_operation_name: aws_operation_name.map(Into::into), - } - .into_metric_datum(), - ), - EventType::McpServerInit { - conversation_id, - init_failure_reason, - number_of_tools, - } => Some( - CodewhispererterminalMcpServerInit { - create_time: self.created_time, - credential_start_url: self.credential_start_url.map(Into::into), - value: None, - amazonq_conversation_id: Some(conversation_id.into()), - codewhispererterminal_mcp_server_init_failure_reason: init_failure_reason - .map(CodewhispererterminalMcpServerInitFailureReason), - codewhispererterminal_tools_per_mcp_server: Some(CodewhispererterminalToolsPerMcpServer( - number_of_tools as i64, - )), - codewhispererterminal_client_application: self.client_application.map(Into::into), - } - .into_metric_datum(), - ), - EventType::DidSelectProfile { - source, - amazonq_profile_region, - result, - sso_region, - profile_count, - } => Some( - AmazonqDidSelectProfile { - create_time: self.created_time, - value: None, - source: Some(source.to_string().into()), - amazon_q_profile_region: Some(amazonq_profile_region.into()), - result: Some(result.to_string().into()), - sso_region: sso_region.map(Into::into), - credential_start_url: self.credential_start_url.map(Into::into), - profile_count: profile_count.map(Into::into), - } - .into_metric_datum(), - ), - EventType::ProfileState { - source, - amazonq_profile_region, - result, - sso_region, - } => Some( - AmazonqProfileState { - create_time: self.created_time, - value: None, - source: Some(source.to_string().into()), - amazon_q_profile_region: Some(amazonq_profile_region.into()), - result: Some(result.to_string().into()), - sso_region: sso_region.map(Into::into), - credential_start_url: self.credential_start_url.map(Into::into), - } - .into_metric_datum(), - ), - EventType::MessageResponseError { - conversation_id, - context_file_length, - result, - reason, - reason_desc, - status_code, - } => Some( - AmazonqMessageResponseError { - create_time: self.created_time, - value: None, - amazonq_conversation_id: Some(conversation_id.into()), - codewhispererterminal_context_file_length: context_file_length.map(|l| l as i64).map(Into::into), - credential_start_url: self.credential_start_url.map(Into::into), - sso_region: self.sso_region.map(Into::into), - result: Some(result.to_string().into()), - reason: reason.map(Into::into), - reason_desc: reason_desc.map(Into::into), - status_code: status_code.map(|v| v as i64).map(Into::into), - codewhispererterminal_client_application: self.client_application.map(Into::into), - } - .into_metric_datum(), - ), - } - } -} - -#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)] -#[serde(rename_all = "camelCase")] -#[serde(tag = "type")] -pub enum EventType { - UserLoggedIn {}, - RefreshCredentials { - request_id: String, - result: TelemetryResult, - reason: Option, - oauth_flow: String, - }, - CliSubcommandExecuted { - subcommand: String, - }, - ChatStart { - conversation_id: String, - model: Option, - }, - ChatEnd { - conversation_id: String, - model: Option, - }, - ChatAddedMessage { - conversation_id: String, - message_id: Option, - request_id: Option, - context_file_length: Option, - result: TelemetryResult, - reason: Option, - reason_desc: Option, - status_code: Option, - model: Option, - }, - ToolUseSuggested { - conversation_id: String, - utterance_id: Option, - user_input_id: Option, - tool_use_id: Option, - tool_name: Option, - is_accepted: bool, - is_success: Option, - is_valid: Option, - is_custom_tool: bool, - input_token_size: Option, - output_token_size: Option, - custom_tool_call_latency: Option, - model: Option, - aws_service_name: Option, - aws_operation_name: Option, - }, - McpServerInit { - conversation_id: String, - init_failure_reason: Option, - number_of_tools: usize, - }, - DidSelectProfile { - source: QProfileSwitchIntent, - amazonq_profile_region: String, - result: TelemetryResult, - sso_region: Option, - profile_count: Option, - }, - ProfileState { - source: QProfileSwitchIntent, - amazonq_profile_region: String, - result: TelemetryResult, - sso_region: Option, - }, - MessageResponseError { - result: TelemetryResult, - reason: Option, - reason_desc: Option, - status_code: Option, - conversation_id: String, - context_file_length: Option, - }, -} - -#[derive(Debug)] -pub struct ToolUseEventBuilder { - pub conversation_id: String, - pub utterance_id: Option, - pub user_input_id: Option, - pub tool_use_id: Option, - pub tool_name: Option, - pub is_accepted: bool, - pub is_success: Option, - pub is_valid: Option, - pub is_custom_tool: bool, - pub input_token_size: Option, - pub output_token_size: Option, - pub custom_tool_call_latency: Option, - pub model: Option, - pub aws_service_name: Option, - pub aws_operation_name: Option, -} - -impl ToolUseEventBuilder { - pub fn new(conv_id: String, tool_use_id: String, model: Option) -> Self { - Self { - conversation_id: conv_id, - utterance_id: None, - user_input_id: None, - tool_use_id: Some(tool_use_id), - tool_name: None, - is_accepted: false, - is_success: None, - is_valid: None, - is_custom_tool: false, - input_token_size: None, - output_token_size: None, - custom_tool_call_latency: None, - model, - aws_service_name: None, - aws_operation_name: None, - } - } - - pub fn utterance_id(mut self, id: Option) -> Self { - self.utterance_id = id; - self - } - - pub fn set_tool_use_id(mut self, id: String) -> Self { - self.tool_use_id.replace(id); - self - } - - pub fn set_tool_name(mut self, name: String) -> Self { - self.tool_name.replace(name); - self - } -} - -#[derive(Debug, Clone, Copy, PartialEq, Eq, serde::Serialize, serde::Deserialize)] -pub enum SuggestionState { - Accept, - Discard, - Empty, - Reject, -} - -impl SuggestionState { - pub fn is_accepted(&self) -> bool { - matches!(self, SuggestionState::Accept) - } -} - -impl From for amzn_codewhisperer_client::types::SuggestionState { - fn from(value: SuggestionState) -> Self { - match value { - SuggestionState::Accept => amzn_codewhisperer_client::types::SuggestionState::Accept, - SuggestionState::Discard => amzn_codewhisperer_client::types::SuggestionState::Discard, - SuggestionState::Empty => amzn_codewhisperer_client::types::SuggestionState::Empty, - SuggestionState::Reject => amzn_codewhisperer_client::types::SuggestionState::Reject, - } - } -} - -#[derive(Debug, Copy, Clone, PartialEq, Eq, EnumString, Display, serde::Serialize, serde::Deserialize)] -pub enum TelemetryResult { - Succeeded, - Failed, - Cancelled, -} - -/// 'user' -> users change the profile through Q CLI user profile command -/// 'auth' -> users change the profile through dashboard -/// 'update' -> CLI auto select the profile on users' behalf as there is only 1 profile -/// 'reload' -> CLI will try to reload previous selected profile upon CLI is running -#[derive(Debug, Copy, Clone, PartialEq, Eq, EnumString, Display, serde::Serialize, serde::Deserialize)] -pub enum QProfileSwitchIntent { - User, - Auth, - Update, - Reload, -} diff --git a/crates/chat-cli/src/telemetry/definitions.rs b/crates/chat-cli/src/telemetry/definitions.rs deleted file mode 100644 index ba4ba069d..000000000 --- a/crates/chat-cli/src/telemetry/definitions.rs +++ /dev/null @@ -1,46 +0,0 @@ -#![allow(dead_code)] - -// https://github.com/aws/aws-toolkit-common/blob/main/telemetry/telemetryformat.md - -pub trait IntoMetricDatum: Send { - fn into_metric_datum(self) -> amzn_toolkit_telemetry_client::types::MetricDatum; -} - -include!(concat!(env!("OUT_DIR"), "/mod.rs")); - -#[cfg(test)] -mod tests { - use std::time::SystemTime; - - use super::*; - use crate::telemetry::definitions::metrics::CodewhispererterminalAddChatMessage; - - #[test] - fn test_serde() { - let metric_datum_init = Metric::CodewhispererterminalAddChatMessage(CodewhispererterminalAddChatMessage { - amazonq_conversation_id: None, - request_id: None, - codewhispererterminal_context_file_length: None, - create_time: Some(SystemTime::now()), - value: None, - credential_start_url: Some("https://example.com".to_owned().into()), - sso_region: Some("us-east-1".to_owned().into()), - codewhispererterminal_in_cloudshell: None, - codewhispererterminal_utterance_id: Some("message_id".to_owned().into()), - result: crate::telemetry::definitions::types::Result::new("Succeeded".to_string()), - reason: None, - reason_desc: None, - status_code: None, - codewhispererterminal_model: None, - codewhispererterminal_client_application: None, - }); - - let s = serde_json::to_string_pretty(&metric_datum_init).unwrap(); - println!("{s}"); - - let metric_datum_out: Metric = serde_json::from_str(&s).unwrap(); - println!("{metric_datum_out:#?}"); - - assert_eq!(metric_datum_init, metric_datum_out); - } -} diff --git a/crates/chat-cli/src/telemetry/endpoint.rs b/crates/chat-cli/src/telemetry/endpoint.rs deleted file mode 100644 index 681d19af7..000000000 --- a/crates/chat-cli/src/telemetry/endpoint.rs +++ /dev/null @@ -1,32 +0,0 @@ -use amzn_toolkit_telemetry_client::config::endpoint::{ - Endpoint, - EndpointFuture, - Params, - ResolveEndpoint, -}; - -#[derive(Debug, Clone, Copy)] -pub(crate) struct StaticEndpoint(pub &'static str); - -impl ResolveEndpoint for StaticEndpoint { - fn resolve_endpoint<'a>(&'a self, _params: &'a Params) -> EndpointFuture<'a> { - let endpoint = Endpoint::builder().url(self.0).build(); - tracing::info!(?endpoint, "Resolving endpoint"); - EndpointFuture::ready(Ok(endpoint)) - } -} - -#[cfg(test)] -mod tests { - use super::*; - - #[tokio::test] - async fn test_static_endpoint() { - let endpoint = StaticEndpoint("https://example.com"); - let params = Params::builder().build().unwrap(); - let endpoint = endpoint.resolve_endpoint(¶ms).await.unwrap(); - assert_eq!(endpoint.url(), "https://example.com"); - assert!(endpoint.properties().is_empty()); - assert!(endpoint.headers().count() == 0); - } -} diff --git a/crates/chat-cli/src/telemetry/install_method.rs b/crates/chat-cli/src/telemetry/install_method.rs deleted file mode 100644 index 2a541252a..000000000 --- a/crates/chat-cli/src/telemetry/install_method.rs +++ /dev/null @@ -1,45 +0,0 @@ -use std::process::Command; -use std::sync::LazyLock; - -use serde::{ - Deserialize, - Serialize, -}; - -static INSTALL_METHOD: LazyLock = LazyLock::new(|| { - if let Ok(output) = Command::new("brew").args(["list", "amazon-q", "-1"]).output() { - if output.status.success() { - return InstallMethod::Brew; - } - } - - if let Ok(current_exe) = std::env::current_exe() { - if current_exe.components().any(|c| c.as_os_str() == ".toolbox") { - return InstallMethod::Toolbox; - } - } - - InstallMethod::Unknown -}); - -/// The method of installation that Fig was installed with -#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] -pub enum InstallMethod { - Brew, - Toolbox, - Unknown, -} - -impl std::fmt::Display for InstallMethod { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.write_str(match self { - InstallMethod::Brew => "brew", - InstallMethod::Toolbox => "toolbox", - InstallMethod::Unknown => "unknown", - }) - } -} - -pub fn get_install_method() -> InstallMethod { - *INSTALL_METHOD -} diff --git a/crates/chat-cli/src/telemetry/mod.rs b/crates/chat-cli/src/telemetry/mod.rs deleted file mode 100644 index 277f5186f..000000000 --- a/crates/chat-cli/src/telemetry/mod.rs +++ /dev/null @@ -1,710 +0,0 @@ -pub mod cognito; -pub mod core; -pub mod definitions; -pub mod endpoint; -mod install_method; - -use core::ToolUseEventBuilder; -use std::str::FromStr; - -use amzn_codewhisperer_client::types::{ - ChatAddMessageEvent, - IdeCategory, - OperatingSystem, - TelemetryEvent, - UserContext, -}; -use amzn_toolkit_telemetry_client::config::{ - BehaviorVersion, - Region, -}; -use amzn_toolkit_telemetry_client::error::DisplayErrorContext; -use amzn_toolkit_telemetry_client::types::AwsProduct; -use amzn_toolkit_telemetry_client::{ - Client as ToolkitTelemetryClient, - Config, -}; -use aws_credential_types::provider::SharedCredentialsProvider; -use cognito::CognitoProvider; -use endpoint::StaticEndpoint; -pub use install_method::{ - InstallMethod, - get_install_method, -}; -use tokio::sync::mpsc; -use tokio::task::JoinHandle; -use tokio::time::error::Elapsed; -use tracing::{ - debug, - error, - trace, -}; -use uuid::{ - Uuid, - uuid, -}; - -use crate::api_client::{ - ApiClient, - ApiClientError, -}; -use crate::auth::builder_id::get_start_url_and_region; -use crate::aws_common::app_name; -use crate::cli::RootSubcommand; -use crate::database::settings::Setting; -use crate::database::{ - Database, - DatabaseError, -}; -use crate::os::{ - Env, - Fs, -}; -use crate::telemetry::core::Event; -pub use crate::telemetry::core::{ - EventType, - QProfileSwitchIntent, - TelemetryResult, -}; -use crate::util::env_var::Q_CLI_CLIENT_APPLICATION; -use crate::util::system_info::os_version; - -#[derive(thiserror::Error, Debug)] -pub enum TelemetryError { - #[error(transparent)] - Client(Box), - #[error(transparent)] - Send(Box>), - #[error(transparent)] - ApiClient(Box), - #[error(transparent)] - Join(#[from] tokio::task::JoinError), - #[error(transparent)] - Database(#[from] DatabaseError), - #[error(transparent)] - Timeout(#[from] Elapsed), -} - -impl From for TelemetryError { - fn from(value: amzn_toolkit_telemetry_client::operation::post_metrics::PostMetricsError) -> Self { - Self::Client(Box::new(value)) - } -} - -impl From>> for TelemetryError { - fn from(value: Box>) -> Self { - Self::Send(value) - } -} - -impl From for TelemetryError { - fn from(value: ApiClientError) -> Self { - Self::ApiClient(Box::new(value)) - } -} - -const PRODUCT: &str = "CodeWhisperer"; -const PRODUCT_VERSION: &str = env!("CARGO_PKG_VERSION"); -const CLIENT_ID_ENV_VAR: &str = "Q_TELEMETRY_CLIENT_ID"; - -/// A IDE toolkit telemetry stage -#[derive(Debug, Clone)] -#[non_exhaustive] -pub struct TelemetryStage { - pub endpoint: &'static str, - pub cognito_pool_id: &'static str, - pub region: Region, -} - -impl TelemetryStage { - #[cfg(test)] - const BETA: Self = Self::new( - "https://7zftft3lj2.execute-api.us-east-1.amazonaws.com/Beta", - "us-east-1:db7bfc9f-8ecd-4fbb-bea7-280c16069a99", - "us-east-1", - ); - const EXTERNAL_PROD: Self = Self::new( - "https://client-telemetry.us-east-1.amazonaws.com", - "us-east-1:820fd6d1-95c0-4ca4-bffb-3f01d32da842", - "us-east-1", - ); - - const fn new(endpoint: &'static str, cognito_pool_id: &'static str, region: &'static str) -> Self { - Self { - endpoint, - cognito_pool_id, - region: Region::from_static(region), - } - } -} - -#[derive(Debug)] -enum TelemetrySender { - Strong(mpsc::UnboundedSender), - Weak(mpsc::WeakUnboundedSender), -} - -impl TelemetrySender { - fn send(&self, ev: Event) -> Result<(), Box>> { - match self { - Self::Strong(sender) => sender.send(ev).map_err(Box::new), - Self::Weak(sender) => { - if let Some(sender) = sender.upgrade() { - sender.send(ev).map_err(Box::new) - } else { - tracing::error!( - "Attempted to send telemetry after telemetry thread has been dropped. Event attempted {:?}", - ev - ); - Ok(()) - } - }, - } - } -} - -impl Clone for TelemetrySender { - fn clone(&self) -> Self { - match self { - Self::Strong(sender) => Self::Weak(sender.downgrade()), - Self::Weak(sender) => Self::Weak(sender.clone()), - } - } -} - -#[derive(Debug)] -pub struct TelemetryThread { - handle: Option>, - tx: TelemetrySender, -} - -impl Clone for TelemetryThread { - fn clone(&self) -> Self { - Self { - handle: None, - tx: self.tx.clone(), - } - } -} - -impl TelemetryThread { - pub async fn new(env: &Env, fs: &Fs, database: &mut Database) -> Result { - let telemetry_client = TelemetryClient::new(env, fs, database).await?; - let (tx, mut rx) = mpsc::unbounded_channel(); - let tx = TelemetrySender::Strong(tx); - let handle = tokio::spawn(async move { - while let Some(event) = rx.recv().await { - trace!("TelemetryThread received new telemetry event: {:?}", event); - telemetry_client.send_event(event).await; - } - }); - - Ok(Self { - handle: Some(handle), - tx, - }) - } - - pub async fn finish(self) -> Result<(), TelemetryError> { - drop(self.tx); - if let Some(handle) = self.handle { - match tokio::time::timeout(std::time::Duration::from_millis(1000), handle).await { - Ok(result) => { - if let Err(e) = result { - return Err(TelemetryError::Join(e)); - } - }, - Err(_) => { - // Ignore timeout errors - }, - } - } - - Ok(()) - } - - pub fn send_user_logged_in(&self) -> Result<(), TelemetryError> { - Ok(self.tx.send(Event::new(EventType::UserLoggedIn {}))?) - } - - pub async fn send_cli_subcommand_executed( - &self, - database: &Database, - subcommand: &RootSubcommand, - ) -> Result<(), TelemetryError> { - let mut telemetry_event = Event::new(EventType::CliSubcommandExecuted { - subcommand: subcommand.to_string(), - }); - set_event_metadata(database, &mut telemetry_event).await; - - Ok(self.tx.send(telemetry_event)?) - } - - #[allow(clippy::too_many_arguments)] // TODO: Should make a parameters struct. - pub async fn send_chat_added_message( - &self, - database: &Database, - conversation_id: String, - message_id: Option, - request_id: Option, - context_file_length: Option, - result: TelemetryResult, - reason: Option, - reason_desc: Option, - status_code: Option, - model: Option, - ) -> Result<(), TelemetryError> { - let mut telemetry_event = Event::new(EventType::ChatAddedMessage { - conversation_id, - message_id, - request_id, - context_file_length, - result, - reason, - reason_desc, - status_code, - model, - }); - set_event_metadata(database, &mut telemetry_event).await; - - Ok(self.tx.send(telemetry_event)?) - } - - pub async fn send_tool_use_suggested( - &self, - database: &Database, - event: ToolUseEventBuilder, - ) -> Result<(), TelemetryError> { - let mut telemetry_event = Event::new(EventType::ToolUseSuggested { - conversation_id: event.conversation_id, - utterance_id: event.utterance_id, - user_input_id: event.user_input_id, - tool_use_id: event.tool_use_id, - tool_name: event.tool_name, - is_accepted: event.is_accepted, - is_success: event.is_success, - is_valid: event.is_valid, - is_custom_tool: event.is_custom_tool, - input_token_size: event.input_token_size, - output_token_size: event.output_token_size, - custom_tool_call_latency: event.custom_tool_call_latency, - model: event.model, - aws_service_name: event.aws_service_name, - aws_operation_name: event.aws_operation_name, - }); - set_event_metadata(database, &mut telemetry_event).await; - - Ok(self.tx.send(telemetry_event)?) - } - - pub async fn send_mcp_server_init( - &self, - database: &Database, - conversation_id: String, - init_failure_reason: Option, - number_of_tools: usize, - ) -> Result<(), TelemetryError> { - let mut telemetry_event = Event::new(crate::telemetry::EventType::McpServerInit { - conversation_id, - init_failure_reason, - number_of_tools, - }); - set_event_metadata(database, &mut telemetry_event).await; - - Ok(self.tx.send(telemetry_event)?) - } - - pub fn send_did_select_profile( - &self, - source: QProfileSwitchIntent, - amazonq_profile_region: String, - result: TelemetryResult, - sso_region: Option, - profile_count: Option, - ) -> Result<(), TelemetryError> { - Ok(self.tx.send(Event::new(EventType::DidSelectProfile { - source, - amazonq_profile_region, - result, - sso_region, - profile_count, - }))?) - } - - pub fn send_profile_state( - &self, - source: QProfileSwitchIntent, - amazonq_profile_region: String, - result: TelemetryResult, - sso_region: Option, - ) -> Result<(), TelemetryError> { - Ok(self.tx.send(Event::new(EventType::ProfileState { - source, - amazonq_profile_region, - result, - sso_region, - }))?) - } - - #[allow(clippy::too_many_arguments)] - pub async fn send_response_error( - &self, - database: &Database, - conversation_id: String, - context_file_length: Option, - result: TelemetryResult, - reason: Option, - reason_desc: Option, - status_code: Option, - ) -> Result<(), TelemetryError> { - let mut telemetry_event = Event::new(EventType::MessageResponseError { - result, - reason, - reason_desc, - status_code, - conversation_id, - context_file_length, - }); - set_event_metadata(database, &mut telemetry_event).await; - - Ok(self.tx.send(telemetry_event)?) - } -} - -async fn set_event_metadata(database: &Database, event: &mut Event) { - let (start_url, region) = get_start_url_and_region(database).await; - if let Some(start_url) = start_url { - event.set_start_url(start_url); - } - if let Some(region) = region { - event.set_sso_region(region); - } - - // Set the client application from environment variable - if let Ok(client_app) = std::env::var(Q_CLI_CLIENT_APPLICATION) { - event.set_client_application(client_app); - } -} - -#[derive(Debug)] -struct TelemetryClient { - client_id: Uuid, - telemetry_enabled: bool, - codewhisperer_client: Option, - toolkit_telemetry_client: Option, -} - -impl TelemetryClient { - async fn new(env: &Env, fs: &Fs, database: &mut Database) -> Result { - let telemetry_enabled = !cfg!(test) - && env.get_os("Q_DISABLE_TELEMETRY").is_none() - && database.settings.get_bool(Setting::TelemetryEnabled).unwrap_or(true); - - // If telemetry is disabled we do not emit using toolkit_telemetry - let toolkit_telemetry_client = if telemetry_enabled { - Some(ToolkitTelemetryClient::from_conf( - Config::builder() - .http_client(crate::aws_common::http_client::client()) - .behavior_version(BehaviorVersion::v2025_08_07()) - .endpoint_resolver(StaticEndpoint(TelemetryStage::EXTERNAL_PROD.endpoint)) - .app_name(app_name()) - .region(TelemetryStage::EXTERNAL_PROD.region.clone()) - .credentials_provider(SharedCredentialsProvider::new(CognitoProvider::new( - TelemetryStage::EXTERNAL_PROD, - ))) - .build(), - )) - } else { - None - }; - - fn client_id(env: &Env, database: &mut Database, telemetry_enabled: bool) -> Result { - if !telemetry_enabled { - return Ok(uuid!("ffffffff-ffff-ffff-ffff-ffffffffffff")); - } - - if let Ok(client_id) = env.get(CLIENT_ID_ENV_VAR) { - if let Ok(uuid) = Uuid::from_str(&client_id) { - return Ok(uuid); - } - } - - Ok(match database.get_client_id()? { - Some(uuid) => uuid, - None => { - let uuid = database - .settings - .get_string(Setting::OldClientId) - .and_then(|id| Uuid::try_parse(&id).ok()) - .unwrap_or_else(Uuid::new_v4); - - if let Err(err) = database.set_client_id(uuid) { - error!(%err, "Failed to set client id in state"); - } - - uuid - }, - }) - } - - // cw telemetry is only available with bearer token auth. - let codewhisperer_client = if env.get("AMAZON_Q_SIGV4").is_ok() { - None - } else { - Some(ApiClient::new(env, fs, database, None).await?) - }; - - Ok(Self { - client_id: client_id(env, database, telemetry_enabled)?, - telemetry_enabled, - toolkit_telemetry_client, - codewhisperer_client, - }) - } - - /// Sends a telemetry event to both the CW and toolkit API's. If the clients do not exist, then - /// telemetry is not sent. - /// - /// See [TelemetryClient::new] for which conditions the clients are created for. - async fn send_event(&self, event: Event) { - self.send_cw_telemetry_event(&event).await; - self.send_telemetry_toolkit_metric(event).await; - } - - async fn send_cw_telemetry_event(&self, event: &Event) { - let Some(codewhisperer_client) = self.codewhisperer_client.clone() else { - trace!("not sending cw metric - client does not exist"); - return; - }; - - if let EventType::ChatAddedMessage { - conversation_id, - message_id, - model, - .. - } = &event.ty - { - let user_context = self.user_context().unwrap(); - - let chat_add_message_event = match ChatAddMessageEvent::builder() - .conversation_id(conversation_id) - .message_id(message_id.clone().unwrap_or("not_set".to_string())) - .build() - { - Ok(event) => event, - Err(err) => { - error!(err =% DisplayErrorContext(err), "Failed to send cw telemetry event"); - return; - }, - }; - - let event = TelemetryEvent::ChatAddMessageEvent(chat_add_message_event); - debug!( - ?event, - ?user_context, - telemetry_enabled = self.telemetry_enabled, - "Sending cw telemetry event" - ); - if let Err(err) = codewhisperer_client - .send_telemetry_event(event, user_context, self.telemetry_enabled, model.to_owned()) - .await - { - error!(err =% DisplayErrorContext(err), "Failed to send cw telemetry event"); - } - } - } - - async fn send_telemetry_toolkit_metric(&self, event: Event) { - let Some(toolkit_telemetry_client) = self.toolkit_telemetry_client.clone() else { - trace!("not sending toolkit metric - client does not exist"); - return; - }; - let client_id = self.client_id; - let Some(metric_datum) = event.into_metric_datum() else { - trace!("not sending toolkit metric - metric datum does not exist"); - return; - }; - - let product = AwsProduct::CodewhispererTerminal; - let metric_name = metric_datum.metric_name().to_owned(); - - debug!(?client_id, ?product, ?metric_datum, "Sending toolkit telemetry event"); - if let Err(err) = toolkit_telemetry_client - .post_metrics() - .aws_product(product) - .aws_product_version(env!("CARGO_PKG_VERSION")) - .client_id(client_id) - .os(std::env::consts::OS) - .os_architecture(std::env::consts::ARCH) - .os_version(os_version().map(|v| v.to_string()).unwrap_or_default()) - .metric_data(metric_datum) - .send() - .await - .map_err(DisplayErrorContext) - { - error!(%err, ?metric_name, "Failed to post toolkit metric"); - } - } - - fn user_context(&self) -> Option { - let operating_system = match std::env::consts::OS { - "linux" => OperatingSystem::Linux, - "macos" => OperatingSystem::Mac, - "windows" => OperatingSystem::Windows, - os => { - error!(%os, "Unsupported operating system"); - return None; - }, - }; - - match UserContext::builder() - .client_id(self.client_id.hyphenated().to_string()) - .operating_system(operating_system) - .product(PRODUCT) - .ide_category(IdeCategory::Cli) - .ide_version(PRODUCT_VERSION) - .build() - { - Ok(user_context) => Some(user_context), - Err(err) => { - error!(%err, "Failed to build user context"); - None - }, - } - } -} - -pub trait ReasonCode: std::error::Error { - fn reason_code(&self) -> String; -} - -/// Returns a generic error reason + reason description pair. -pub fn get_error_reason(error: &E) -> (String, String) -where - E: ReasonCode + 'static, -{ - let mut err_chain = eyre::Chain::new(error); - let reason_desc = if err_chain.len() > 1 { - format!( - "'{}' caused by: {}", - error, - err_chain.next_back().map_or("UNKNOWN".to_string(), |e| e.to_string()) - ) - } else { - error.to_string() - }; - - (error.reason_code(), reason_desc) -} - -#[cfg(test)] -mod test { - use uuid::uuid; - - use super::*; - - #[tokio::test] - async fn client_context() { - let mut database = Database::new().await.unwrap(); - let client = TelemetryClient::new(&Env::new(), &Fs::new(), &mut database) - .await - .unwrap(); - let context = client.user_context().unwrap(); - - assert_eq!(context.ide_category, IdeCategory::Cli); - assert!(matches!( - context.operating_system, - OperatingSystem::Linux | OperatingSystem::Mac | OperatingSystem::Windows - )); - assert_eq!(context.product, PRODUCT); - assert_eq!( - context.client_id, - Some(uuid!("ffffffff-ffff-ffff-ffff-ffffffffffff").hyphenated().to_string()) - ); - assert_eq!(context.ide_version.as_deref(), Some(PRODUCT_VERSION)); - } - - #[tracing_test::traced_test] - #[tokio::test] - #[ignore = "needs auth which is not in CI"] - async fn test_send() { - let mut database = Database::new().await.unwrap(); - let thread = TelemetryThread::new(&Env::new(), &Fs::new(), &mut database) - .await - .unwrap(); - thread.send_user_logged_in().ok(); - drop(thread); - - assert!(!logs_contain("ERROR")); - assert!(!logs_contain("error")); - assert!(!logs_contain("WARN")); - assert!(!logs_contain("warn")); - assert!(!logs_contain("Failed to post metric")); - } - - #[tracing_test::traced_test] - #[tokio::test] - #[ignore = "needs auth which is not in CI"] - async fn test_all_telemetry() { - let mut database = Database::new().await.unwrap(); - let thread = TelemetryThread::new(&Env::new(), &Fs::new(), &mut database) - .await - .unwrap(); - - thread.send_user_logged_in().ok(); - thread - .send_cli_subcommand_executed(&database, &RootSubcommand::Version { changelog: None }) - .await - .ok(); - thread - .send_chat_added_message( - &database, - "conv_id".to_owned(), - Some("message_id".to_owned()), - Some("req_id".to_owned()), - Some(123), - TelemetryResult::Succeeded, - None, - None, - None, - None, - ) - .await - .ok(); - - drop(thread); - - assert!(!logs_contain("ERROR")); - assert!(!logs_contain("error")); - assert!(!logs_contain("WARN")); - assert!(!logs_contain("warn")); - assert!(!logs_contain("Failed to post metric")); - } - - #[tokio::test] - #[ignore = "needs auth which is not in CI"] - async fn test_without_optout() { - let mut database = Database::new().await.unwrap(); - let client = TelemetryClient::new(&Env::new(), &Fs::new(), &mut database) - .await - .unwrap(); - client - .codewhisperer_client - .as_ref() - .expect("cw telemetry client should exist") - .send_telemetry_event( - TelemetryEvent::ChatAddMessageEvent( - ChatAddMessageEvent::builder() - .conversation_id("debug".to_owned()) - .message_id("debug".to_owned()) - .build() - .unwrap(), - ), - client.user_context().unwrap(), - false, - Some("model".to_owned()), - ) - .await - .unwrap(); - } -} diff --git a/crates/chat-cli/src/util/consts.rs b/crates/chat-cli/src/util/consts.rs deleted file mode 100644 index 4c5542a00..000000000 --- a/crates/chat-cli/src/util/consts.rs +++ /dev/null @@ -1,93 +0,0 @@ -/// TODO(brandonskiser): revert back to "qchat" for prompting login after standalone releases. -pub const CLI_BINARY_NAME: &str = "q"; -pub const CHAT_BINARY_NAME: &str = "qchat"; - -pub const PRODUCT_NAME: &str = "Amazon Q"; - -pub const GITHUB_REPO_NAME: &str = "aws/amazon-q-developer-cli"; - -pub const GOV_REGIONS: &[&str] = &["us-gov-east-1", "us-gov-west-1"]; - -/// Build time env vars -pub mod build { - /// A git full sha hash of the current build - pub const HASH: Option<&str> = option_env!("AMAZON_Q_BUILD_HASH"); - - /// The datetime in rfc3339 format of the current build - pub const DATETIME: Option<&str> = option_env!("AMAZON_Q_BUILD_DATETIME"); -} - -pub mod env_var { - macro_rules! define_env_vars { - ($($(#[$meta:meta])* $ident:ident = $name:expr),*) => { - $( - $(#[$meta])* - pub const $ident: &str = $name; - )* - - pub const ALL: &[&str] = &[$($ident),*]; - } - } - - define_env_vars! { - /// The UUID of the current parent qterm instance - QTERM_SESSION_ID = "QTERM_SESSION_ID", - - /// The current parent socket to connect to - Q_PARENT = "Q_PARENT", - - /// Set the [`Q_PARENT`] parent socket to connect to - Q_SET_PARENT = "Q_SET_PARENT", - - /// Guard for the [`Q_SET_PARENT`] check - Q_SET_PARENT_CHECK = "Q_SET_PARENT_CHECK", - - /// Set if qterm is running, contains the version - Q_TERM = "Q_TERM", - - /// Sets the current log level - Q_LOG_LEVEL = "Q_LOG_LEVEL", - - /// Overrides the ZDOTDIR environment variable - Q_ZDOTDIR = "Q_ZDOTDIR", - - /// Indicates a process was launched by Amazon Q - PROCESS_LAUNCHED_BY_Q = "PROCESS_LAUNCHED_BY_Q", - - /// The shell to use in qterm - Q_SHELL = "Q_SHELL", - - /// Indicates the user is debugging the shell - Q_DEBUG_SHELL = "Q_DEBUG_SHELL", - - /// Indicates the user is using zsh autosuggestions which disables Inline - Q_USING_ZSH_AUTOSUGGESTIONS = "Q_USING_ZSH_AUTOSUGGESTIONS", - - /// Overrides the path to the bundle metadata released with certain desktop builds. - Q_BUNDLE_METADATA_PATH = "Q_BUNDLE_METADATA_PATH", - - /// Identifier for the client application or service using the chat-cli - Q_CLI_CLIENT_APPLICATION = "Q_CLI_CLIENT_APPLICATION" - } -} - -#[cfg(test)] -mod tests { - use time::OffsetDateTime; - use time::format_description::well_known::Rfc3339; - - use super::*; - - #[test] - fn test_build_envs() { - if let Some(build_hash) = build::HASH { - println!("build_hash: {build_hash}"); - assert!(!build_hash.is_empty()); - } - - if let Some(build_datetime) = build::DATETIME { - println!("build_datetime: {build_datetime}"); - println!("{}", OffsetDateTime::parse(build_datetime, &Rfc3339).unwrap()); - } - } -} diff --git a/crates/chat-cli/src/util/directories.rs b/crates/chat-cli/src/util/directories.rs deleted file mode 100644 index 525df4784..000000000 --- a/crates/chat-cli/src/util/directories.rs +++ /dev/null @@ -1,262 +0,0 @@ -use std::path::PathBuf; - -use thiserror::Error; - -use crate::os::Os; -use crate::util::paths::PathResolver; - -#[allow(dead_code)] // Allow unused variants during migration -#[derive(Debug, Error)] -pub enum DirectoryError { - #[error("home directory not found")] - NoHomeDirectory, - #[cfg(unix)] - #[error("runtime directory not found: neither XDG_RUNTIME_DIR nor TMPDIR were found")] - NoRuntimeDirectory, - #[error("IO Error: {0}")] - Io(#[from] std::io::Error), - #[error(transparent)] - TimeFormat(#[from] time::error::Format), - #[error(transparent)] - Utf8FromPath(#[from] camino::FromPathError), - #[error(transparent)] - Utf8FromPathBuf(#[from] camino::FromPathBufError), - #[error(transparent)] - FromVecWithNul(#[from] std::ffi::FromVecWithNulError), - #[error(transparent)] - IntoString(#[from] std::ffi::IntoStringError), -} - -type Result = std::result::Result; - -/// The directory of the users home -/// -/// - Linux: /home/Alice -/// - MacOS: /Users/Alice -/// - Windows: C:\Users\Alice -#[allow(dead_code)] // Allow unused function during migration -pub fn home_dir(#[cfg_attr(windows, allow(unused_variables))] os: &Os) -> Result { - #[cfg(unix)] - match cfg!(test) { - true => os - .env - .get("HOME") - .map_err(|_err| DirectoryError::NoHomeDirectory) - .and_then(|h| { - if h.is_empty() { - Err(DirectoryError::NoHomeDirectory) - } else { - Ok(h) - } - }) - .map(PathBuf::from) - .map(|p| os.fs.chroot_path(p)), - false => dirs::home_dir().ok_or(DirectoryError::NoHomeDirectory), - } - - #[cfg(windows)] - match cfg!(test) { - true => os - .env - .get("USERPROFILE") - .map_err(|_err| DirectoryError::NoHomeDirectory) - .and_then(|h| { - if h.is_empty() { - Err(DirectoryError::NoHomeDirectory) - } else { - Ok(h) - } - }) - .map(PathBuf::from) - .map(|p| os.fs.chroot_path(p)), - false => dirs::home_dir().ok_or(DirectoryError::NoHomeDirectory), - } -} - -/// Get the macos tempdir from the `confstr` function -/// -/// See: -#[cfg(target_os = "macos")] -fn macos_tempdir() -> Result { - let len = unsafe { libc::confstr(libc::_CS_DARWIN_USER_TEMP_DIR, std::ptr::null::().cast_mut(), 0) }; - let mut buf: Vec = vec![0; len]; - unsafe { libc::confstr(libc::_CS_DARWIN_USER_TEMP_DIR, buf.as_mut_ptr().cast(), buf.len()) }; - let c_string = std::ffi::CString::from_vec_with_nul(buf)?; - let str = c_string.into_string()?; - Ok(PathBuf::from(str)) -} - -/// Runtime dir is used for runtime data that should not be persisted for a long time, e.g. socket -/// files and logs -/// -/// The XDG_RUNTIME_DIR is set by systemd , -/// if this is not set such as on macOS it will fallback to TMPDIR which is secure on macOS -#[cfg(unix)] -pub fn runtime_dir() -> Result { - let mut dir = dirs::runtime_dir(); - dir = dir.or_else(|| std::env::var_os("TMPDIR").map(PathBuf::from)); - - cfg_if::cfg_if! { - if #[cfg(target_os = "macos")] { - let macos_tempdir = macos_tempdir()?; - dir = dir.or(Some(macos_tempdir)); - } else { - dir = dir.or_else(|| Some(std::env::temp_dir())); - } - } - - dir.ok_or(DirectoryError::NoRuntimeDirectory) -} - -/// The directory to all the fig logs -/// - Linux: `/tmp/fig/$USER/logs` -/// - MacOS: `$TMPDIR/logs` -/// - Windows: `%TEMP%\fig\logs` -pub fn logs_dir() -> Result { - cfg_if::cfg_if! { - if #[cfg(unix)] { - Ok(runtime_dir()?.join("qlog")) - } else if #[cfg(windows)] { - use crate::util::paths::application::DATA_DIR_NAME; - Ok(std::env::temp_dir().join(DATA_DIR_NAME).join("logs")) - } - } -} - -/// The directory to the directory containing config for the `/context` feature in `q chat`. -pub fn chat_global_context_path(os: &Os) -> Result { - PathResolver::new(os) - .global() - .global_context() - .map_err(|e| DirectoryError::Io(std::io::Error::other(e))) -} - -/// The directory to the directory containing config for the `/context` feature in `q chat`. -pub fn chat_profiles_dir(os: &Os) -> Result { - PathResolver::new(os) - .global() - .profiles_dir() - .map_err(|e| DirectoryError::Io(std::io::Error::other(e))) -} - -/// The path to the fig settings file -pub fn settings_path() -> Result { - crate::util::paths::ApplicationPaths::settings_path_static() - .map_err(|e| DirectoryError::Io(std::io::Error::other(e))) -} - -/// The path to the local sqlite database -pub fn database_path() -> Result { - crate::util::paths::ApplicationPaths::database_path_static() - .map_err(|e| DirectoryError::Io(std::io::Error::other(e))) -} - -#[cfg(test)] -mod linux_tests { - use super::*; - - #[test] - fn all_paths() { - assert!(logs_dir().is_ok()); - assert!(settings_path().is_ok()); - } -} - -// TODO(grant): Add back path tests on linux -#[cfg(all(test, not(target_os = "linux")))] -mod tests { - use insta; - - use super::*; - - macro_rules! assert_directory { - ($value:expr, @$snapshot:literal) => { - insta::assert_snapshot!( - sanitized_directory_path($value), - @$snapshot, - ) - }; - } - - macro_rules! macos { - ($value:expr, @$snapshot:literal) => { - #[cfg(target_os = "macos")] - assert_directory!($value, @$snapshot) - }; - } - - macro_rules! linux { - ($value:expr, @$snapshot:literal) => { - #[cfg(target_os = "linux")] - assert_directory!($value, @$snapshot) - }; - } - - macro_rules! windows { - ($value:expr, @$snapshot:literal) => { - #[cfg(target_os = "windows")] - assert_directory!($value, @$snapshot) - }; - } - - fn sanitized_directory_path(path: Result) -> String { - let mut path = path.unwrap().into_os_string().into_string().unwrap(); - - if let Ok(home) = std::env::var("HOME") { - let home = home.strip_suffix('/').unwrap_or(&home); - path = path.replace(home, "$HOME"); - } - - let user = whoami::username(); - path = path.replace(&user, "$USER"); - - if let Ok(tmpdir) = std::env::var("TMPDIR") { - let tmpdir = tmpdir.strip_suffix('/').unwrap_or(&tmpdir); - path = path.replace(tmpdir, "$TMPDIR"); - } - - #[cfg(target_os = "macos")] - { - if let Ok(tmpdir) = macos_tempdir() { - let tmpdir = tmpdir.to_str().unwrap(); - let tmpdir = tmpdir.strip_suffix('/').unwrap_or(tmpdir); - path = path.replace(tmpdir, "$TMPDIR"); - }; - } - - if let Ok(xdg_runtime_dir) = std::env::var("XDG_RUNTIME_DIR") { - let xdg_runtime_dir = xdg_runtime_dir.strip_suffix('/').unwrap_or(&xdg_runtime_dir); - path = path.replace(xdg_runtime_dir, "$XDG_RUNTIME_DIR"); - } - - #[cfg(target_os = "linux")] - { - path = path.replace("/tmp", "$TMPDIR"); - } - - path - } - - #[test] - fn snapshot_fig_data_dir() { - let app_data_dir = - || crate::util::paths::app_data_dir().map_err(|e| DirectoryError::Io(std::io::Error::other(e))); - linux!(app_data_dir(), @"$HOME/.local/share/amazon-q"); - macos!(app_data_dir(), @"$HOME/Library/Application Support/amazon-q"); - windows!(app_data_dir(), @r"C:\Users\$USER\AppData\Local\AmazonQ"); - } - - #[test] - fn snapshot_settings_path() { - linux!(settings_path(), @"$HOME/.local/share/amazon-q/settings.json"); - macos!(settings_path(), @"$HOME/Library/Application Support/amazon-q/settings.json"); - windows!(settings_path(), @r"C:\Users\$USER\AppData\Local\amazon-q\settings.json"); - } - - #[test] - #[cfg(target_os = "macos")] - fn macos_tempdir_test() { - let tmpdir = macos_tempdir().unwrap(); - println!("{:?}", tmpdir); - } -} diff --git a/crates/chat-cli/src/util/knowledge_store.rs b/crates/chat-cli/src/util/knowledge_store.rs deleted file mode 100644 index 23ef1345e..000000000 --- a/crates/chat-cli/src/util/knowledge_store.rs +++ /dev/null @@ -1,262 +0,0 @@ -use std::sync::{ - Arc, - LazyLock as Lazy, -}; - -use eyre::Result; -use semantic_search_client::KnowledgeContext; -use semantic_search_client::client::AsyncSemanticSearchClient; -use semantic_search_client::types::SearchResult; -use tokio::sync::Mutex; -use uuid::Uuid; - -#[derive(Debug)] -pub enum KnowledgeError { - ClientError(String), -} - -impl std::fmt::Display for KnowledgeError { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { - KnowledgeError::ClientError(msg) => write!(f, "Client error: {}", msg), - } - } -} - -impl std::error::Error for KnowledgeError {} - -/// Async knowledge store - just a thin wrapper! -pub struct KnowledgeStore { - client: AsyncSemanticSearchClient, -} - -impl KnowledgeStore { - /// Get singleton instance - pub async fn get_async_instance() -> Arc> { - static ASYNC_INSTANCE: Lazy>>> = - Lazy::new(tokio::sync::OnceCell::new); - - if cfg!(test) { - Arc::new(Mutex::new( - KnowledgeStore::new() - .await - .expect("Failed to create test async knowledge store"), - )) - } else { - ASYNC_INSTANCE - .get_or_init(|| async { - Arc::new(Mutex::new( - KnowledgeStore::new() - .await - .expect("Failed to create async knowledge store"), - )) - }) - .await - .clone() - } - } - - pub async fn new() -> Result { - let client = AsyncSemanticSearchClient::new_with_default_dir() - .await - .map_err(|e| eyre::eyre!("Failed to create client: {}", e))?; - - Ok(Self { client }) - } - - /// Add context - delegates to async client - pub async fn add(&mut self, name: &str, path_str: &str) -> Result { - let path_buf = std::path::PathBuf::from(path_str); - let canonical_path = path_buf - .canonicalize() - .map_err(|_io_error| format!("❌ Path does not exist: {}", path_str))?; - - match self - .client - .add_context_from_path(&canonical_path, name, &format!("Knowledge context for {}", name), true) - .await - { - Ok((operation_id, _)) => Ok(format!( - "🚀 Started indexing '{}'\n📁 Path: {}\n🆔 Operation ID: {}.", - name, - canonical_path.display(), - &operation_id.to_string()[..8] - )), - Err(e) => Err(format!("Failed to start indexing: {}", e)), - } - } - - /// Get all contexts - delegates to async client - pub async fn get_all(&self) -> Result, KnowledgeError> { - Ok(self.client.get_contexts().await) - } - - /// Search - delegates to async client - pub async fn search(&self, query: &str, _context_id: Option<&str>) -> Result, KnowledgeError> { - let results = self - .client - .search_all(query, None) - .await - .map_err(|e| KnowledgeError::ClientError(e.to_string()))?; - - let mut flattened = Vec::new(); - for (_, context_results) in results { - flattened.extend(context_results); - } - - flattened.sort_by(|a, b| a.distance.partial_cmp(&b.distance).unwrap_or(std::cmp::Ordering::Equal)); - - Ok(flattened) - } - - /// Get status data - delegates to async client - pub async fn get_status_data(&self) -> Result { - self.client - .get_status_data() - .await - .map_err(|e| format!("Failed to get status data: {}", e)) - } - - /// Cancel operation - delegates to async client - pub async fn cancel_operation(&mut self, operation_id: Option<&str>) -> Result { - if let Some(short_id) = operation_id { - // Debug: List all available operations - let available_ops = self.client.list_operation_ids().await; - if available_ops.is_empty() { - return Err("No active operations found".to_string()); - } - - // Try to parse as full UUID first - if let Ok(uuid) = Uuid::parse_str(short_id) { - self.client.cancel_operation(uuid).await.map_err(|e| e.to_string()) - } else { - // Try to find by short ID (first 8 characters) - if let Some(full_uuid) = self.client.find_operation_by_short_id(short_id).await { - self.client.cancel_operation(full_uuid).await.map_err(|e| e.to_string()) - } else { - Err(format!( - "No operation found matching ID: {}\nAvailable operations:\n{}", - short_id, - available_ops.join("\n") - )) - } - } - } else { - // Cancel all operations - self.client.cancel_all_operations().await.map_err(|e| e.to_string()) - } - } - - /// Clear all contexts (background operation) - pub async fn clear(&mut self) -> Result { - match self.client.clear_all().await { - Ok((operation_id, _cancel_token)) => Ok(format!( - "🚀 Started clearing all contexts in background.\n📊 Use 'knowledge status' to check progress.\n🆔 Operation ID: {}", - &operation_id.to_string()[..8] - )), - Err(e) => Err(format!("Failed to start clear operation: {}", e)), - } - } - - /// Clear all contexts immediately (synchronous operation) - pub async fn clear_immediate(&mut self) -> Result { - match self.client.clear_all_immediate().await { - Ok(count) => Ok(format!("✅ Successfully cleared {} knowledge base entries", count)), - Err(e) => Err(format!("Failed to clear knowledge base: {}", e)), - } - } - - /// Remove context by path - pub async fn remove_by_path(&mut self, path: &str) -> Result<(), String> { - if let Some(context) = self.client.get_context_by_path(path).await { - self.client - .remove_context_by_id(&context.id) - .await - .map_err(|e| e.to_string()) - } else { - Err(format!("No context found with path '{}'", path)) - } - } - - /// Remove context by name - pub async fn remove_by_name(&mut self, name: &str) -> Result<(), String> { - if let Some(context) = self.client.get_context_by_name(name).await { - self.client - .remove_context_by_id(&context.id) - .await - .map_err(|e| e.to_string()) - } else { - Err(format!("No context found with name '{}'", name)) - } - } - - /// Remove context by ID - pub async fn remove_by_id(&mut self, context_id: &str) -> Result<(), String> { - self.client - .remove_context_by_id(context_id) - .await - .map_err(|e| e.to_string()) - } - - /// Update context by path - pub async fn update_by_path(&mut self, path_str: &str) -> Result { - if let Some(context) = self.client.get_context_by_path(path_str).await { - // Remove the existing context first - self.client - .remove_context_by_id(&context.id) - .await - .map_err(|e| e.to_string())?; - - // Then add it back with the same name - self.add(&context.name, path_str).await - } else { - // Debug: List all available contexts - let available_paths = self.client.list_context_paths().await; - if available_paths.is_empty() { - Err("No contexts found. Add a context first with 'knowledge add '".to_string()) - } else { - Err(format!( - "No context found with path '{}'\nAvailable contexts:\n{}", - path_str, - available_paths.join("\n") - )) - } - } - } - - /// Update context by ID - pub async fn update_context_by_id(&mut self, context_id: &str, path_str: &str) -> Result { - let contexts = self.get_all().await.map_err(|e| e.to_string())?; - let context = contexts - .iter() - .find(|c| c.id == context_id) - .ok_or_else(|| format!("Context '{}' not found", context_id))?; - - let context_name = context.name.clone(); - - // Remove the existing context first - self.client - .remove_context_by_id(context_id) - .await - .map_err(|e| e.to_string())?; - - // Then add it back with the same name - self.add(&context_name, path_str).await - } - - /// Update context by name - pub async fn update_context_by_name(&mut self, name: &str, path_str: &str) -> Result { - if let Some(context) = self.client.get_context_by_name(name).await { - // Remove the existing context first - self.client - .remove_context_by_id(&context.id) - .await - .map_err(|e| e.to_string())?; - - // Then add it back with the same name - self.add(name, path_str).await - } else { - Err(format!("Context with name '{}' not found", name)) - } - } -} diff --git a/crates/chat-cli/src/util/mod.rs b/crates/chat-cli/src/util/mod.rs deleted file mode 100644 index 041373df0..000000000 --- a/crates/chat-cli/src/util/mod.rs +++ /dev/null @@ -1,100 +0,0 @@ -pub mod consts; -pub mod directories; -pub mod knowledge_store; -pub mod open; -pub mod paths; -pub mod process; -pub mod spinner; -pub mod system_info; -#[cfg(test)] -pub mod test; - -use std::fmt::Display; -use std::io::{ - ErrorKind, - stdout, -}; - -use anstream::stream::IsTerminal; -pub use consts::*; -use dialoguer::Select; -use dialoguer::theme::ColorfulTheme; -use eyre::{ - Context, - Result, - bail, -}; -use thiserror::Error; -use tracing::warn; - -#[derive(Debug, Error)] -pub enum UtilError { - #[error("io operation error")] - IoError(#[from] std::io::Error), - #[error(transparent)] - Directory(#[from] directories::DirectoryError), - #[error(transparent)] - StrUtf8Error(#[from] std::str::Utf8Error), - #[error(transparent)] - Json(#[from] serde_json::Error), -} - -#[derive(Debug, Clone)] -pub struct UnknownDesktopErrContext { - xdg_current_desktop: String, - xdg_session_desktop: String, - gdm_session: String, -} - -impl std::fmt::Display for UnknownDesktopErrContext { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "XDG_CURRENT_DESKTOP: `{}`, ", self.xdg_current_desktop)?; - write!(f, "XDG_SESSION_DESKTOP: `{}`, ", self.xdg_session_desktop)?; - write!(f, "GDMSESSION: `{}`", self.gdm_session) - } -} - -pub fn choose(prompt: impl Display, options: &[impl ToString]) -> Result> { - if options.is_empty() { - bail!("no options passed to choose") - } - - if !stdout().is_terminal() { - warn!("called choose while stdout is not a terminal"); - return Ok(Some(0)); - } - - match Select::with_theme(&dialoguer_theme()) - .items(options) - .default(0) - .with_prompt(prompt.to_string()) - .interact_opt() - { - Ok(ok) => Ok(ok), - Err(dialoguer::Error::IO(io)) if io.kind() == ErrorKind::Interrupted => Ok(None), - Err(e) => Err(e).wrap_err("Failed to choose"), - } -} - -pub fn input(prompt: &str, initial_text: Option<&str>) -> Result { - if !stdout().is_terminal() { - warn!("called input while stdout is not a terminal"); - return Ok(String::new()); - } - - let theme = dialoguer_theme(); - let mut input = dialoguer::Input::with_theme(&theme).with_prompt(prompt); - - if let Some(initial_text) = initial_text { - input = input.with_initial_text(initial_text); - } - - Ok(input.interact_text()?) -} - -pub fn dialoguer_theme() -> ColorfulTheme { - ColorfulTheme { - prompt_prefix: dialoguer::console::style("?".into()).for_stderr().magenta(), - ..ColorfulTheme::default() - } -} diff --git a/crates/chat-cli/src/util/open.rs b/crates/chat-cli/src/util/open.rs deleted file mode 100644 index 6309c7fcc..000000000 --- a/crates/chat-cli/src/util/open.rs +++ /dev/null @@ -1,102 +0,0 @@ -use cfg_if::cfg_if; - -#[derive(Debug, thiserror::Error)] -pub enum Error { - #[error(transparent)] - Io(#[from] std::io::Error), - #[error("Failed to open URL")] - Failed, -} - -#[cfg(target_os = "macos")] -#[allow(unexpected_cfgs)] -fn open_macos(url_str: impl AsRef) -> Result<(), Error> { - use objc2::ClassType; - use objc2_foundation::{ - NSString, - NSURL, - }; - - let url_nsstring = NSString::from_str(url_str.as_ref()); - let nsurl = unsafe { NSURL::initWithString(NSURL::alloc(), &url_nsstring) }.ok_or(Error::Failed)?; - let res = unsafe { objc2_app_kit::NSWorkspace::sharedWorkspace().openURL(&nsurl) }; - res.then_some(()).ok_or(Error::Failed) -} - -#[cfg(target_os = "windows")] -fn open_command(url: impl AsRef) -> std::process::Command { - use std::os::windows::process::CommandExt; - - let detached = 0x8; - let mut command = std::process::Command::new("cmd"); - command.creation_flags(detached); - command.args(["/c", "start", url.as_ref()]); - command -} - -#[cfg(any(target_os = "linux", target_os = "freebsd"))] -fn open_command(url: impl AsRef) -> std::process::Command { - let executable = if super::system_info::in_wsl() { - "wslview" - } else { - "xdg-open" - }; - - let mut command = std::process::Command::new(executable); - command.arg(url.as_ref()); - command -} - -/// Returns bool indicating whether the URL was opened successfully -#[allow(dead_code)] -pub fn open_url(url: impl AsRef) -> Result<(), Error> { - cfg_if! { - if #[cfg(target_os = "macos")] { - open_macos(url) - } else { - match open_command(url).output() { - Ok(output) => { - tracing::trace!(?output, "open_url output"); - if output.status.success() { - Ok(()) - } else { - Err(Error::Failed) - } - }, - Err(err) => Err(err.into()), - } - } - } -} - -/// Returns bool indicating whether the URL was opened successfully -pub async fn open_url_async(url: impl AsRef) -> Result<(), Error> { - cfg_if! { - if #[cfg(target_os = "macos")] { - open_macos(url) - } else { - match tokio::process::Command::from(open_command(url)).output().await { - Ok(output) => { - tracing::trace!(?output, "open_url_async output"); - if output.status.success() { - Ok(()) - } else { - Err(Error::Failed) - } - }, - Err(err) => Err(err.into()), - } - } - } -} - -#[cfg(test)] -mod tests { - use super::*; - - #[ignore] - #[test] - fn test_open_url() { - open_url("https://fig.io").unwrap(); - } -} diff --git a/crates/chat-cli/src/util/paths.rs b/crates/chat-cli/src/util/paths.rs deleted file mode 100644 index ea141448b..000000000 --- a/crates/chat-cli/src/util/paths.rs +++ /dev/null @@ -1,156 +0,0 @@ -//! Hierarchical path management for the application - -use std::path::PathBuf; - -use crate::os::Os; - -#[derive(Debug, thiserror::Error)] -pub enum DirectoryError { - #[error("home directory not found")] - NoHomeDirectory, - #[error("IO Error: {0}")] - Io(#[from] std::io::Error), -} - -pub mod workspace { - //! Project-level paths (relative to current working directory) - pub const MCP_CONFIG: &str = ".amazonq/mcp.json"; - pub const RULES_PATTERN: &str = ".amazonq/rules/**/*.md"; -} - -pub mod global { - //! User-level paths (relative to home directory) - pub const MCP_CONFIG: &str = ".aws/amazonq/mcp.json"; - pub const GLOBAL_CONTEXT: &str = ".aws/amazonq/global_context.json"; - pub const PROFILES_DIR: &str = ".aws/amazonq/profiles"; -} - -pub mod application { - //! Application data paths (system-specific) - #[cfg(unix)] - pub const DATA_DIR_NAME: &str = "amazon-q"; - #[cfg(windows)] - pub const DATA_DIR_NAME: &str = "AmazonQ"; - pub const SETTINGS_FILE: &str = "settings.json"; - pub const DATABASE_FILE: &str = "data.sqlite3"; -} - -type Result = std::result::Result; - -/// The directory of the users home -/// - Linux: /home/Alice -/// - MacOS: /Users/Alice -/// - Windows: C:\Users\Alice -pub fn home_dir(#[cfg_attr(windows, allow(unused_variables))] os: &Os) -> Result { - #[cfg(unix)] - match cfg!(test) { - true => os - .env - .get("HOME") - .map_err(|_err| DirectoryError::NoHomeDirectory) - .and_then(|h| { - if h.is_empty() { - Err(DirectoryError::NoHomeDirectory) - } else { - Ok(h) - } - }) - .map(PathBuf::from) - .map(|p| os.fs.chroot_path(p)), - false => dirs::home_dir().ok_or(DirectoryError::NoHomeDirectory), - } - - #[cfg(windows)] - match cfg!(test) { - true => os - .env - .get("USERPROFILE") - .map_err(|_err| DirectoryError::NoHomeDirectory) - .and_then(|h| { - if h.is_empty() { - Err(DirectoryError::NoHomeDirectory) - } else { - Ok(h) - } - }) - .map(PathBuf::from) - .map(|p| os.fs.chroot_path(p)), - false => dirs::home_dir().ok_or(DirectoryError::NoHomeDirectory), - } -} - -/// The application data directory -/// - Linux: `$XDG_DATA_HOME/{data_dir}` or `$HOME/.local/share/{data_dir}` -/// - MacOS: `$HOME/Library/Application Support/{data_dir}` -/// - Windows: `%LOCALAPPDATA%\{data_dir}` -pub fn app_data_dir() -> Result { - Ok(dirs::data_local_dir() - .ok_or(DirectoryError::NoHomeDirectory)? - .join(application::DATA_DIR_NAME)) -} - -/// Path resolver with hierarchy-aware methods -pub struct PathResolver<'a> { - os: &'a Os, -} - -impl<'a> PathResolver<'a> { - pub fn new(os: &'a Os) -> Self { - Self { os } - } - - /// Get workspace-scoped path resolver - pub fn workspace(&self) -> WorkspacePaths<'_> { - WorkspacePaths { os: self.os } - } - - /// Get global-scoped path resolver - pub fn global(&self) -> GlobalPaths<'_> { - GlobalPaths { os: self.os } - } -} - -/// Workspace-scoped path methods -pub struct WorkspacePaths<'a> { - os: &'a Os, -} - -impl<'a> WorkspacePaths<'a> { - pub fn mcp_config(&self) -> Result { - Ok(self.os.env.current_dir()?.join(workspace::MCP_CONFIG)) - } -} - -/// Global-scoped path methods -pub struct GlobalPaths<'a> { - os: &'a Os, -} - -impl<'a> GlobalPaths<'a> { - pub fn mcp_config(&self) -> Result { - Ok(home_dir(self.os)?.join(global::MCP_CONFIG)) - } - - pub fn global_context(&self) -> Result { - Ok(home_dir(self.os)?.join(global::GLOBAL_CONTEXT)) - } - - pub fn profiles_dir(&self) -> Result { - Ok(home_dir(self.os)?.join(global::PROFILES_DIR)) - } -} - -/// Application path static methods -pub struct ApplicationPaths; - -impl ApplicationPaths { - /// Static method for settings path (to avoid circular dependency) - pub fn settings_path_static() -> Result { - Ok(app_data_dir()?.join(application::SETTINGS_FILE)) - } - - /// Static method for database path (to avoid circular dependency) - pub fn database_path_static() -> Result { - Ok(app_data_dir()?.join(application::DATABASE_FILE)) - } -} diff --git a/crates/chat-cli/src/util/process/mod.rs b/crates/chat-cli/src/util/process/mod.rs deleted file mode 100644 index e0a841459..000000000 --- a/crates/chat-cli/src/util/process/mod.rs +++ /dev/null @@ -1,11 +0,0 @@ -pub use sysinfo::Pid; - -#[cfg(target_os = "windows")] -mod windows; -#[cfg(target_os = "windows")] -pub use windows::*; - -#[cfg(not(windows))] -mod unix; -#[cfg(not(windows))] -pub use unix::*; diff --git a/crates/chat-cli/src/util/process/unix.rs b/crates/chat-cli/src/util/process/unix.rs deleted file mode 100644 index b0ffc6093..000000000 --- a/crates/chat-cli/src/util/process/unix.rs +++ /dev/null @@ -1,64 +0,0 @@ -use nix::sys::signal::Signal; -use sysinfo::Pid; - -pub fn terminate_process(pid: Pid) -> Result<(), String> { - let nix_pid = nix::unistd::Pid::from_raw(pid.as_u32() as i32); - nix::sys::signal::kill(nix_pid, Signal::SIGTERM).map_err(|e| format!("Failed to terminate process: {}", e)) -} - -#[cfg(test)] -#[cfg(not(windows))] -mod tests { - use std::process::Command; - use std::time::Duration; - - use super::*; - - // Helper to create a long-running process for testing - fn spawn_test_process() -> std::process::Child { - let mut command = Command::new("sleep"); - command.arg("30"); - command.spawn().expect("Failed to spawn test process") - } - - #[test] - fn test_terminate_process() { - // Spawn a test process - let mut child = spawn_test_process(); - let pid = Pid::from_u32(child.id()); - - // Terminate the process - let result = terminate_process(pid); - - // Verify termination was successful - assert!(result.is_ok(), "Process termination failed: {:?}", result.err()); - - // Give it a moment to terminate - std::thread::sleep(Duration::from_millis(100)); - - // Verify the process is actually terminated - match child.try_wait() { - Ok(Some(_)) => { - // Process exited, which is what we expect - }, - Ok(None) => { - panic!("Process is still running after termination"); - }, - Err(e) => { - panic!("Error checking process status: {}", e); - }, - } - } - - #[test] - fn test_terminate_nonexistent_process() { - // Use a likely invalid PID - let invalid_pid = Pid::from_u32(u32::MAX - 1); - - // Attempt to terminate a non-existent process - let result = terminate_process(invalid_pid); - - // Should return an error - assert!(result.is_err(), "Terminating non-existent process should fail"); - } -} diff --git a/crates/chat-cli/src/util/process/windows.rs b/crates/chat-cli/src/util/process/windows.rs deleted file mode 100644 index 12e0389bd..000000000 --- a/crates/chat-cli/src/util/process/windows.rs +++ /dev/null @@ -1,120 +0,0 @@ -use std::ops::Deref; - -use sysinfo::Pid; -use windows::Win32::Foundation::{ - CloseHandle, - HANDLE, -}; -use windows::Win32::System::Threading::{ - OpenProcess, - PROCESS_TERMINATE, - TerminateProcess, -}; - -/// Terminate a process on Windows using the Windows API -pub fn terminate_process(pid: Pid) -> Result<(), String> { - unsafe { - // Open the process with termination rights - let handle = OpenProcess(PROCESS_TERMINATE, false, pid.as_u32()) - .map_err(|e| format!("Failed to open process: {}", e))?; - - // Create a safe handle that will be closed automatically when dropped - let safe_handle = SafeHandle::new(handle).ok_or_else(|| "Invalid process handle".to_string())?; - - // Terminate the process with exit code 1 - TerminateProcess(*safe_handle, 1).map_err(|e| format!("Failed to terminate process: {}", e))?; - - Ok(()) - } -} - -struct SafeHandle(HANDLE); - -impl SafeHandle { - fn new(handle: HANDLE) -> Option { - if !handle.is_invalid() { Some(Self(handle)) } else { None } - } -} - -impl Drop for SafeHandle { - fn drop(&mut self) { - unsafe { - let _ = CloseHandle(self.0); - } - } -} - -impl Deref for SafeHandle { - type Target = HANDLE; - - fn deref(&self) -> &Self::Target { - &self.0 - } -} - -#[cfg(test)] -mod tests { - use std::process::Command; - use std::time::Duration; - - use super::*; - - // Helper to create a long-running process for testing - fn spawn_test_process() -> std::process::Child { - let mut command = Command::new("cmd"); - command.args(["/C", "timeout 30 > nul"]); - command.spawn().expect("Failed to spawn test process") - } - - #[test] - fn test_terminate_process() { - // Spawn a test process - let mut child = spawn_test_process(); - let pid = Pid::from_u32(child.id()); - - // Terminate the process - let result = terminate_process(pid); - - // Verify termination was successful - assert!(result.is_ok(), "Process termination failed: {:?}", result.err()); - - // Give it a moment to terminate - std::thread::sleep(Duration::from_millis(100)); - - // Verify the process is actually terminated - match child.try_wait() { - Ok(Some(_)) => { - // Process exited, which is what we expect - }, - Ok(None) => { - panic!("Process is still running after termination"); - }, - Err(e) => { - panic!("Error checking process status: {}", e); - }, - } - } - - #[test] - fn test_terminate_nonexistent_process() { - // Use a likely invalid PID - let invalid_pid = Pid::from_u32(u32::MAX - 1); - - // Attempt to terminate a non-existent process - let result = terminate_process(invalid_pid); - - // Should return an error - assert!(result.is_err(), "Terminating non-existent process should fail"); - } - - #[test] - fn test_safe_handle() { - // Test creating a SafeHandle with an invalid handle - let invalid_handle = HANDLE(std::ptr::null_mut()); - let safe_handle = SafeHandle::new(invalid_handle); - assert!(safe_handle.is_none(), "SafeHandle should be None for invalid handle"); - - // We can't easily test a valid handle without actually opening a process, - // which would require additional setup and teardown - } -} diff --git a/crates/chat-cli/src/util/spinner.rs b/crates/chat-cli/src/util/spinner.rs deleted file mode 100644 index 1ffcf1cc1..000000000 --- a/crates/chat-cli/src/util/spinner.rs +++ /dev/null @@ -1,126 +0,0 @@ -use std::io::{ - Write, - stdout, -}; -use std::sync::mpsc::{ - Sender, - TryRecvError, - channel, -}; -use std::thread; -use std::thread::JoinHandle; -use std::time::Duration; - -use anstream::{ - print, - println, -}; -use crossterm::ExecutableCommand; - -const FRAMES: &[&str] = &[ - "▰▱▱▱▱▱▱", - "▰▰▱▱▱▱▱", - "▰▰▰▱▱▱▱", - "▰▰▰▰▱▱▱", - "▰▰▰▰▰▱▱", - "▰▰▰▰▰▰▱", - "▰▰▰▰▰▰▰", - "▰▱▱▱▱▱▱", -]; -const INTERVAL: Duration = Duration::from_millis(100); - -pub struct Spinner { - sender: Sender>, - join: Option>, -} - -impl Drop for Spinner { - fn drop(&mut self) { - if self.join.is_some() { - self.sender.send(Some("\x1b[2K\r".into())).unwrap(); - self.join.take().unwrap().join().unwrap(); - } - } -} - -#[derive(Debug, Clone)] -pub enum SpinnerComponent { - Text(String), - Spinner, -} - -impl Spinner { - pub fn new(components: Vec) -> Self { - let (sender, recv) = channel::>(); - - stdout().execute(crossterm::cursor::Hide).ok(); - - let join = thread::spawn(move || { - 'outer: loop { - let mut stdout = stdout(); - for frame in FRAMES.iter() { - let (do_stop, stop_symbol) = match recv.try_recv() { - Ok(stop_symbol) => (true, stop_symbol), - Err(TryRecvError::Disconnected) => (true, None), - Err(TryRecvError::Empty) => (false, None), - }; - - let frame = stop_symbol.unwrap_or_else(|| (*frame).to_string()); - - let line = components.iter().fold(String::new(), |mut acc, elem| { - acc.push_str(match elem { - SpinnerComponent::Text(t) => t, - SpinnerComponent::Spinner => &frame, - }); - acc - }); - - print!("\r{line}"); - - stdout.flush().unwrap(); - - if do_stop { - stdout.execute(crossterm::cursor::Show).ok(); - break 'outer; - } - - thread::sleep(INTERVAL); - } - } - }); - - Self { - sender, - join: Some(join), - } - } - - fn stop_inner(&mut self, stop_symbol: Option) { - self.sender.send(stop_symbol).expect("Could not stop spinner thread."); - self.join.take().unwrap().join().unwrap(); - } - - pub fn stop(&mut self) { - self.stop_inner(None); - } - - pub fn stop_with_message(&mut self, msg: String) { - self.stop(); - println!("\x1b[2K\r{msg}"); - } -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn test_spinner() { - let mut spinner = Spinner::new(vec![ - SpinnerComponent::Spinner, - SpinnerComponent::Text("Loading".into()), - ]); - thread::sleep(Duration::from_secs(1)); - spinner.stop_with_message("Done".into()); - } -} diff --git a/crates/chat-cli/src/util/system_info/linux.rs b/crates/chat-cli/src/util/system_info/linux.rs deleted file mode 100644 index 20d257d55..000000000 --- a/crates/chat-cli/src/util/system_info/linux.rs +++ /dev/null @@ -1,134 +0,0 @@ -use std::io; -use std::path::Path; -use std::sync::OnceLock; - -use nix::sys::utsname::uname; -use serde::{ - Deserialize, - Serialize, -}; - -use super::{ - OSVersion, - OsRelease, -}; - -#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] -pub enum DisplayServer { - X11, - Wayland, -} - -#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] -pub enum DesktopEnvironment { - Gnome, - Plasma, - I3, - Sway, -} - -pub fn get_os_release() -> Option<&'static OsRelease> { - static OS_RELEASE: OnceLock> = OnceLock::new(); - OS_RELEASE.get_or_init(|| OsRelease::load().ok()).as_ref() -} - -pub fn get_os_version() -> Option { - let kernel_version = uname().ok()?.release().to_string_lossy().into(); - let os_release = get_os_release().cloned(); - - Some(OSVersion::Linux { - kernel_version, - os_release, - }) -} - -impl OsRelease { - fn path() -> &'static Path { - Path::new("/etc/os-release") - } - - pub(crate) fn load() -> io::Result { - let os_release_str = std::fs::read_to_string(Self::path())?; - Ok(OsRelease::from_str(&os_release_str)) - } - - pub(crate) fn from_str(s: &str) -> OsRelease { - // Remove the starting and ending quotes from a string if they match - let strip_quotes = |s: &str| -> Option { - if s.starts_with('"') && s.ends_with('"') { - Some(s[1..s.len() - 1].into()) - } else { - Some(s.into()) - } - }; - - let mut os_release = OsRelease::default(); - for line in s.lines() { - if let Some((key, value)) = line.split_once('=') { - match key { - "ID" => os_release.id = strip_quotes(value), - "NAME" => os_release.name = strip_quotes(value), - "PRETTY_NAME" => os_release.pretty_name = strip_quotes(value), - "VERSION" => os_release.version = strip_quotes(value), - "VERSION_ID" => os_release.version_id = strip_quotes(value), - "BUILD_ID" => os_release.build_id = strip_quotes(value), - "VARIANT" => os_release.variant = strip_quotes(value), - "VARIANT_ID" => os_release.variant_id = strip_quotes(value), - _ => {}, - } - } - } - os_release - } -} - -#[cfg(test)] -mod test { - use super::*; - - #[test] - fn os_release() { - if OsRelease::path().exists() { - OsRelease::load().unwrap(); - } else { - println!("Skipping os-release test as /etc/os-release does not exist"); - } - } - - #[test] - fn os_release_parse() { - let os_release_str = indoc::indoc! {r#" - NAME="Amazon Linux" - VERSION="2023" - ID="amzn" - ID_LIKE="fedora" - VERSION_ID="2023" - PLATFORM_ID="platform:al2023" - PRETTY_NAME="Amazon Linux 2023.4.20240416" - ANSI_COLOR="0;33" - CPE_NAME="cpe:2.3:o:amazon:amazon_linux:2023" - HOME_URL="https://aws.amazon.com/linux/amazon-linux-2023/" - DOCUMENTATION_URL="https://docs.aws.amazon.com/linux/" - SUPPORT_URL="https://aws.amazon.com/premiumsupport/" - BUG_REPORT_URL="https://github.com/amazonlinux/amazon-linux-2023" - VENDOR_NAME="AWS" - VENDOR_URL="https://aws.amazon.com/" - SUPPORT_END="2028-03-15" - "#}; - - let os_release = OsRelease::from_str(os_release_str); - - assert_eq!(os_release.id, Some("amzn".into())); - - assert_eq!(os_release.name, Some("Amazon Linux".into())); - assert_eq!(os_release.pretty_name, Some("Amazon Linux 2023.4.20240416".into())); - - assert_eq!(os_release.version_id, Some("2023".into())); - assert_eq!(os_release.version, Some("2023".into())); - - assert_eq!(os_release.build_id, None); - - assert_eq!(os_release.variant_id, None); - assert_eq!(os_release.variant, None); - } -} diff --git a/crates/chat-cli/src/util/system_info/mod.rs b/crates/chat-cli/src/util/system_info/mod.rs deleted file mode 100644 index 6f4aa75f3..000000000 --- a/crates/chat-cli/src/util/system_info/mod.rs +++ /dev/null @@ -1,190 +0,0 @@ -#[cfg(target_os = "linux")] -pub mod linux; -#[cfg(target_os = "windows")] -pub mod windows; - -use std::sync::OnceLock; - -use cfg_if::cfg_if; -use serde::{ - Deserialize, - Serialize, -}; - -use crate::os::Env; - -/// Fields for OS release information -/// Fields from -#[derive(Debug, Clone, Default, PartialEq, Eq, Serialize, Deserialize)] -pub struct OsRelease { - pub id: Option, - - pub name: Option, - pub pretty_name: Option, - - pub version_id: Option, - pub version: Option, - - pub build_id: Option, - - pub variant_id: Option, - pub variant: Option, -} - -#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] -#[serde(rename_all = "lowercase")] -pub enum OSVersion { - MacOS { - major: i32, - minor: i32, - patch: Option, - build: String, - }, - Linux { - kernel_version: String, - #[serde(flatten)] - os_release: Option, - }, - Windows { - name: String, - build: u32, - }, - FreeBsd { - version: String, - }, -} - -impl std::fmt::Display for OSVersion { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { - OSVersion::MacOS { - major, - minor, - patch, - build, - } => { - let patch = patch.unwrap_or(0); - write!(f, "macOS {major}.{minor}.{patch} ({build})") - }, - OSVersion::Linux { - kernel_version, - os_release, - } => match os_release - .as_ref() - .and_then(|r| r.pretty_name.as_ref().or(r.name.as_ref())) - { - Some(distro_name) => write!(f, "Linux {kernel_version} - {distro_name}"), - None => write!(f, "Linux {kernel_version}"), - }, - OSVersion::Windows { name, build } => write!(f, "{name} (or newer) - build {build}"), - OSVersion::FreeBsd { version } => write!(f, "FreeBSD {version}"), - } - } -} - -pub fn os_version() -> Option<&'static OSVersion> { - static OS_VERSION: OnceLock> = OnceLock::new(); - OS_VERSION - .get_or_init(|| { - cfg_if! { - if #[cfg(target_os = "macos")] { - use std::process::Command; - use regex::Regex; - - let version_info = Command::new("sw_vers") - .output() - .ok()?; - - let version_info: String = String::from_utf8_lossy(&version_info.stdout).trim().into(); - - let version_regex = Regex::new(r"ProductVersion:\s*(\S+)").unwrap(); - let build_regex = Regex::new(r"BuildVersion:\s*(\S+)").unwrap(); - - let version: String = version_regex - .captures(&version_info) - .and_then(|c| c.get(1)) - .map(|v| v.as_str().into())?; - - let major = version - .split('.') - .next()? - .parse().ok()?; - - let minor = version - .split('.') - .nth(1)? - .parse().ok()?; - - let patch = version.split('.').nth(2).and_then(|p| p.parse().ok()); - - let build = build_regex - .captures(&version_info) - .and_then(|c| c.get(1))? - .as_str() - .into(); - - Some(OSVersion::MacOS { - major, - minor, - patch, - build, - }) - } else if #[cfg(target_os = "linux")] { - linux::get_os_version() - } else if #[cfg(target_os = "windows")] { - windows::get_os_version() - } else if #[cfg(target_os = "freebsd")] { - use nix::sys::utsname::uname; - - let version = uname().ok()?.release().to_string_lossy().into(); - - Some(OSVersion::FreeBsd { - version, - }) - } - } - }) - .as_ref() -} - -pub fn in_ssh() -> bool { - static IN_SSH: OnceLock = OnceLock::new(); - *IN_SSH.get_or_init(|| Env::new().in_ssh()) -} - -/// Test if the program is running under WSL -pub fn in_wsl() -> bool { - cfg_if! { - if #[cfg(target_os = "linux")] { - static IN_WSL: OnceLock = OnceLock::new(); - *IN_WSL.get_or_init(|| { - if let Ok(b) = std::fs::read("/proc/sys/kernel/osrelease") { - if let Ok(s) = std::str::from_utf8(&b) { - let a = s.to_ascii_lowercase(); - return a.contains("microsoft") || a.contains("wsl"); - } - } - false - }) - } else { - false - } - } -} - -/// Is the calling binary running on a remote instance -pub fn is_remote() -> bool { - // TODO(chay): Add detection for inside docker container - in_ssh() || in_wsl() || std::env::var_os("Q_FAKE_IS_REMOTE").is_some() -} - -pub fn in_codespaces() -> bool { - static IN_CODESPACES: OnceLock = OnceLock::new(); - *IN_CODESPACES - .get_or_init(|| std::env::var_os("CODESPACES").is_some() || std::env::var_os("Q_CODESPACES").is_some()) -} - -pub fn in_ci() -> bool { - static IN_CI: OnceLock = OnceLock::new(); - *IN_CI.get_or_init(|| std::env::var_os("CI").is_some() || std::env::var_os("Q_CI").is_some()) -} diff --git a/crates/chat-cli/src/util/system_info/windows.rs b/crates/chat-cli/src/util/system_info/windows.rs deleted file mode 100644 index e0182de7f..000000000 --- a/crates/chat-cli/src/util/system_info/windows.rs +++ /dev/null @@ -1,33 +0,0 @@ -use serde::{ - Deserialize, - Serialize, -}; -use winreg::RegKey; -use winreg::enums::HKEY_LOCAL_MACHINE; - -use super::OSVersion; - -#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] -pub enum DisplayServer { - Win32, -} - -#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] -pub enum DesktopEnvironment { - Windows, - WindowsTerminal, -} - -pub fn get_os_version() -> Option { - let rkey = RegKey::predef(HKEY_LOCAL_MACHINE) - .open_subkey(r"SOFTWARE\Microsoft\Windows NT\CurrentVersion") - .ok()?; - - let build: String = rkey.get_value("CurrentBuild").ok()?; - let name: String = rkey.get_value("ProductName").ok()?; - - Some(OSVersion::Windows { - name, - build: build.parse::().ok()?, - }) -} diff --git a/crates/chat-cli/src/util/test.rs b/crates/chat-cli/src/util/test.rs deleted file mode 100644 index c18c8eead..000000000 --- a/crates/chat-cli/src/util/test.rs +++ /dev/null @@ -1,16 +0,0 @@ -macro_rules! assert_parse { - ( - [ $($args:expr),+ ], - $subcommand:expr - ) => { - assert_eq!( - ::parse_from([crate::util::CHAT_BINARY_NAME, $($args),*]), - crate::cli::Cli { - subcommand: Some($subcommand), - ..Default::default() - } - ); - }; -} - -pub(crate) use assert_parse; diff --git a/crates/chat-cli/telemetry_definitions.json b/crates/chat-cli/telemetry_definitions.json deleted file mode 100644 index bc28fc134..000000000 --- a/crates/chat-cli/telemetry_definitions.json +++ /dev/null @@ -1,328 +0,0 @@ -{ - "types": [ - { - "name": "amazonQProfileRegion", - "type": "string", - "description": "Region of the Q Profile associated with a metric\n- \"n/a\" if metric is not associated with a profile or region.\n- \"not-set\" if metric is associated with a profile, but profile is unknown." - }, - { - "name": "ssoRegion", - "type": "string", - "description": "Region of the current SSO connection. Typically associated with credentialStartUrl\n- \"n/a\" if metric is not associated with a region.\n- \"not-set\" if metric is associated with a region, but region is unknown." - }, - { - "name": "profileCount", - "type": "int", - "description": "The number of profiles that were available to choose from" - }, - { - "name": "source", - "type": "string", - "description": "Identifies the source component where the telemetry event originated." - }, - { - "name": "amazonqConversationId", - "type": "string", - "description": "Uniquely identifies a message with which the user interacts." - }, - { - "name": "codewhispererterminal_command", - "type": "string", - "description": "The CLI tool a completion was for" - }, - { - "name": "codewhispererterminal_subcommand", - "type": "string", - "description": "A codewhisperer CLI subcommand" - }, - { - "name": "codewhispererterminal_inCloudshell", - "type": "boolean", - "description": "Whether the CLI is running in the AWS CloudShell environment" - }, - { - "name": "credentialStartUrl", - "type": "string", - "description": "The start URL of current SSO connection" - }, - { - "name": "requestId", - "type": "string", - "description": "The id assigned to an AWS request" - }, - { - "name": "oauthFlow", - "type": "string", - "description": "The oauth authentication flow executed by the user, e.g. device code or PKCE" - }, - { - "name": "result", - "type": "string", - "description": "Whether or not the operation succeeded" - }, - { - "name": "reason", - "type": "string", - "description": "Description of what caused an error, if any. Should be a stable/predictable name." - }, - { - "name": "reasonDesc", - "type": "string", - "description": "Error message detail. May contain arbitrary message details (unlike the `reason` field)." - }, - { - "name": "statusCode", - "type": "int", - "description": "The HTTP status code of the request, e.g. 200, 400, etc." - }, - { - "name": "codewhispererterminal_toolUseId", - "type": "string", - "description": "The id assigned to the client by the model representing a tool use event" - }, - { - "name": "codewhispererterminal_toolName", - "type": "string", - "description": "The name associated with a tool" - }, - { - "name": "codewhispererterminal_AwsServiceName", - "type": "string", - "description": "AWS service called by the tool" - }, - { - "name": "codewhispererterminal_AwsOperationName", - "type": "string", - "description": "Specific operation of the AWS service invoked by the tool" - }, - { - "name": "codewhispererterminal_isToolUseAccepted", - "type": "boolean", - "description": "Denotes if a tool use event has been fulfilled" - }, - { - "name": "codewhispererterminal_toolUseIsSuccess", - "type": "boolean", - "description": "The outcome of a tool use" - }, - { - "name": "codewhispererterminal_utteranceId", - "type": "string", - "description": "Id associated with a given response from the model" - }, - { - "name": "codewhispererterminal_userInputId", - "type": "string", - "description": "Id associated with a given user input. This is used to differentiate responses to user input and that of retries from tool uses. This id is the utterance id of the first response following an user input" - }, - { - "name": "codewhispererterminal_isToolValid", - "type": "boolean", - "description": "If the use of tool as instructed by the model is valid" - }, - { - "name": "codewhispererterminal_contextFileLength", - "type": "int", - "description": "The length of the files included as part of context management" - }, - { - "name": "codewhispererterminal_mcpServerInitFailureReason", - "type": "string", - "description": "Reason for which a mcp server has failed to be initialized" - }, - { - "name": "codewhispererterminal_toolsPerMcpServer", - "type": "int", - "description": "The number of tools provided by a mcp server" - }, - { - "name": "codewhispererterminal_isCustomTool", - "type": "boolean", - "description": "Denoting whether or not the tool is a custom tool" - }, - { - "name": "codewhispererterminal_customToolInputTokenSize", - "type": "int", - "description": "Number of tokens used on invoking the custom tool" - }, - { - "name": "codewhispererterminal_customToolOutputTokenSize", - "type": "int", - "description": "Number of tokens received from invoking the custom tool" - }, - { - "name": "codewhispererterminal_customToolLatency", - "type": "int", - "description": "Custom tool call latency in seconds" - }, - { - "name": "codewhispererterminal_model", - "type": "string", - "description": "The underlying LLM used by the service, set by the client" - }, - { - "name": "codewhispererterminal_clientApplication", - "type": "string", - "description": "Identifier for the client application or service using the chat-cli" - } - ], - "metrics": [ - { - "name": "amazonq_startChat", - "description": "Captures start of the conversation with amazonq /dev", - "metadata": [ - { "type": "amazonqConversationId" }, - { "type": "credentialStartUrl", "required": false }, - { "type": "codewhispererterminal_inCloudshell" }, - { "type": "codewhispererterminal_model" } - ] - }, - { - "name": "codewhispererterminal_addChatMessage", - "description": "Captures active usage with Q Chat in shell", - "metadata": [ - { "type": "amazonqConversationId" }, - { "type": "codewhispererterminal_utteranceId" }, - { "type": "credentialStartUrl", "required": false }, - { "type": "ssoRegion", "required": false }, - { "type": "codewhispererterminal_inCloudshell" }, - { "type": "codewhispererterminal_contextFileLength", "required": false }, - { "type": "requestId" }, - { "type": "result", "required": true }, - { "type": "reason", "required": false }, - { "type": "reasonDesc", "required": false }, - { "type": "statusCode", "required": false }, - { "type": "codewhispererterminal_model" }, - { "type": "codewhispererterminal_clientApplication" } - ] - }, - { - "name": "amazonq_endChat", - "description": "Captures end of the conversation with amazonq /dev", - "metadata": [ - { "type": "amazonqConversationId" }, - { "type": "credentialStartUrl", "required": false }, - { "type": "codewhispererterminal_inCloudshell" }, - { "type": "codewhispererterminal_model" } - ] - }, - { - "name": "codewhispererterminal_userLoggedIn", - "description": "Emitted when users log in", - "passive": false, - "metadata": [ - { "type": "credentialStartUrl" }, - { "type": "codewhispererterminal_inCloudshell" } - ] - }, - { - "name": "codewhispererterminal_refreshCredentials", - "description": "Emitted when users refresh their credentials", - "passive": false, - "metadata": [ - { "type": "credentialStartUrl" }, - { "type": "requestId" }, - { "type": "oauthFlow" }, - { "type": "result" }, - { "type": "reason", "required": false }, - { "type": "codewhispererterminal_inCloudshell" } - ] - }, - { - "name": "codewhispererterminal_cliSubcommandExecuted", - "description": "Emitted on CW CLI subcommand executed", - "passive": false, - "metadata": [ - { "type": "credentialStartUrl" }, - { "type": "codewhispererterminal_subcommand" }, - { "type": "codewhispererterminal_inCloudshell" }, - { "type": "codewhispererterminal_clientApplication" } - ] - }, - { - "name": "codewhispererterminal_toolUseSuggested", - "description": "Emitted once per tool use to report outcome of tool use suggested", - "passive": false, - "metadata": [ - { "type": "credentialStartUrl" }, - { "type": "amazonqConversationId" }, - { "type": "codewhispererterminal_utteranceId" }, - { "type": "codewhispererterminal_userInputId" }, - { "type": "codewhispererterminal_toolUseId" }, - { "type": "codewhispererterminal_toolName" }, - { "type": "codewhispererterminal_isToolUseAccepted" }, - { "type": "codewhispererterminal_isToolValid" }, - { "type": "codewhispererterminal_toolUseIsSuccess", "required": false }, - { "type": "codewhispererterminal_isCustomTool" }, - { - "type": "codewhispererterminal_customToolInputTokenSize", - "required": false - }, - { - "type": "codewhispererterminal_customToolOutputTokenSize", - "required": false - }, - { "type": "codewhispererterminal_customToolLatency", "required": false }, - { "type": "codewhispererterminal_model" }, - { "type": "codewhispererterminal_clientApplication" }, - { "type": "codewhispererterminal_AwsServiceName", "required": false }, - { "type": "codewhispererterminal_AwsOperationName", "required": false } - ] - }, - { - "name": "codewhispererterminal_mcpServerInit", - "description": "Emitted once per mcp server on start up", - "passive": false, - "metadata": [ - { "type": "credentialStartUrl" }, - { "type": "amazonqConversationId" }, - { - "type": "codewhispererterminal_mcpServerInitFailureReason", - "required": false - }, - { "type": "codewhispererterminal_toolsPerMcpServer" }, - { "type": "codewhispererterminal_clientApplication" } - ] - }, - { - "name": "amazonq_didSelectProfile", - "description": "Emitted after the user's Q Profile has been set, whether the user was prompted with a dialog, or a profile was automatically assigned after signing in.", - "metadata": [ - { "type": "source" }, - { "type": "amazonQProfileRegion" }, - { "type": "result" }, - { "type": "ssoRegion", "required": false }, - { "type": "credentialStartUrl", "required": false }, - { "type": "profileCount", "required": false } - ], - "passive": true - }, - { - "name": "amazonq_profileState", - "description": "Indicates a change in the user's Q Profile state", - "metadata": [ - { "type": "source" }, - { "type": "amazonQProfileRegion" }, - { "type": "result" }, - { "type": "ssoRegion", "required": false }, - { "type": "credentialStartUrl", "required": false } - ], - "passive": true - }, - { - "name": "amazonq_messageResponseError", - "description": "When an error has occurred in response to a prompt", - "metadata": [ - { "type": "credentialStartUrl", "required": false }, - { "type": "ssoRegion", "required": false }, - { "type": "amazonqConversationId" }, - { "type": "codewhispererterminal_contextFileLength", "required": false }, - { "type": "result" }, - { "type": "reason", "required": false }, - { "type": "reasonDesc", "required": false }, - { "type": "statusCode", "required": false }, - { "type": "codewhispererterminal_clientApplication" } - ] - } - ] -} diff --git a/crates/chat-cli/test_mcp_server/test_server.rs b/crates/chat-cli/test_mcp_server/test_server.rs deleted file mode 100644 index 970157f96..000000000 --- a/crates/chat-cli/test_mcp_server/test_server.rs +++ /dev/null @@ -1,340 +0,0 @@ -//! This is a bin used solely for testing the client -use std::collections::HashMap; -use std::str::FromStr; -use std::sync::atomic::{ - AtomicU8, - Ordering, -}; - -use chat_cli::{ - self, - JsonRpcRequest, - JsonRpcResponse, - JsonRpcStdioTransport, - PreServerRequestHandler, - Response, - Server, - ServerError, - ServerRequestHandler, -}; -use tokio::sync::Mutex; - -#[derive(Default)] -struct Handler { - pending_request: Option Option + Send + Sync>>, - #[allow(clippy::type_complexity)] - send_request: Option) -> Result<(), ServerError> + Send + Sync>>, - storage: Mutex>, - tool_spec: Mutex>, - tool_spec_key_list: Mutex>, - prompts: Mutex>, - prompt_key_list: Mutex>, - prompt_list_call_no: AtomicU8, -} - -impl PreServerRequestHandler for Handler { - fn register_pending_request_callback( - &mut self, - cb: impl Fn(u64) -> Option + Send + Sync + 'static, - ) { - self.pending_request = Some(Box::new(cb)); - } - - fn register_send_request_callback( - &mut self, - cb: impl Fn(&str, Option) -> Result<(), ServerError> + Send + Sync + 'static, - ) { - self.send_request = Some(Box::new(cb)); - } -} - -#[async_trait::async_trait] -impl ServerRequestHandler for Handler { - async fn handle_initialize(&self, params: Option) -> Result { - let mut storage = self.storage.lock().await; - if let Some(params) = params { - storage.insert("client_cap".to_owned(), params); - } - let capabilities = serde_json::json!({ - "protocolVersion": "2024-11-05", - "capabilities": { - "logging": {}, - "prompts": { - "listChanged": true - }, - "resources": { - "subscribe": true, - "listChanged": true - }, - "tools": { - "listChanged": true - } - }, - "serverInfo": { - "name": "TestServer", - "version": "1.0.0" - } - }); - Ok(Some(capabilities)) - } - - async fn handle_incoming(&self, method: &str, params: Option) -> Result { - match method { - "notifications/initialized" => { - { - let mut storage = self.storage.lock().await; - storage.insert( - "init_ack_sent".to_owned(), - serde_json::Value::from_str("true").expect("Failed to convert string to value"), - ); - } - Ok(None) - }, - "verify_init_params_sent" => { - let client_capabilities = { - let storage = self.storage.lock().await; - storage.get("client_cap").cloned() - }; - Ok(client_capabilities) - }, - "verify_init_ack_sent" => { - let result = { - let storage = self.storage.lock().await; - storage.get("init_ack_sent").cloned() - }; - Ok(result) - }, - "store_mock_tool_spec" => { - let Some(params) = params else { - eprintln!("Params missing from store mock tool spec"); - return Ok(None); - }; - // expecting a mock_specs: { key: String, value: serde_json::Value }[]; - let Ok(mock_specs) = serde_json::from_value::>(params) else { - eprintln!("Failed to convert to mock specs from value"); - return Ok(None); - }; - let self_tool_specs = self.tool_spec.lock().await; - let mut self_tool_spec_key_list = self.tool_spec_key_list.lock().await; - let _ = mock_specs.iter().fold(self_tool_specs, |mut acc, spec| { - let Some(key) = spec.get("key").cloned() else { - return acc; - }; - let Ok(key) = serde_json::from_value::(key) else { - eprintln!("Failed to convert serde value to string for key"); - return acc; - }; - self_tool_spec_key_list.push(key.clone()); - acc.insert(key, spec.get("value").cloned()); - acc - }); - Ok(None) - }, - "tools/list" => { - if let Some(params) = params { - if let Some(cursor) = params.get("cursor").cloned() { - let Ok(cursor) = serde_json::from_value::(cursor) else { - eprintln!("Failed to convert cursor to string: {:#?}", params); - return Ok(None); - }; - let self_tool_spec_key_list = self.tool_spec_key_list.lock().await; - let self_tool_spec = self.tool_spec.lock().await; - let (next_cursor, spec) = { - 'blk: { - for (i, item) in self_tool_spec_key_list.iter().enumerate() { - if item == &cursor { - break 'blk ( - self_tool_spec_key_list.get(i + 1).cloned(), - self_tool_spec.get(&cursor).cloned().unwrap(), - ); - } - } - (None, None) - } - }; - if let Some(next_cursor) = next_cursor { - return Ok(Some(serde_json::json!({ - "tools": [spec.unwrap()], - "nextCursor": next_cursor, - }))); - } else { - return Ok(Some(serde_json::json!({ - "tools": [spec.unwrap()], - }))); - } - } else { - eprintln!("Params exist but cursor is missing"); - return Ok(None); - } - } else { - let tool_spec_key_list = self.tool_spec_key_list.lock().await; - let tool_spec = self.tool_spec.lock().await; - let first_key = tool_spec_key_list - .first() - .expect("First key missing from tool specs") - .clone(); - let first_value = tool_spec - .get(&first_key) - .expect("First value missing from tool specs") - .clone(); - let second_key = tool_spec_key_list - .get(1) - .expect("Second key missing from tool specs") - .clone(); - return Ok(Some(serde_json::json!({ - "tools": [first_value], - "nextCursor": second_key - }))); - }; - }, - "get_env_vars" => { - let kv = std::env::vars().fold(HashMap::::new(), |mut acc, (k, v)| { - acc.insert(k, v); - acc - }); - Ok(Some(serde_json::json!(kv))) - }, - // This is a test path relevant only to sampling - "trigger_server_request" => { - let Some(ref send_request) = self.send_request else { - return Err(ServerError::MissingMethod); - }; - let params = Some(serde_json::json!({ - "messages": [ - { - "role": "user", - "content": { - "type": "text", - "text": "What is the capital of France?" - } - } - ], - "modelPreferences": { - "hints": [ - { - "name": "claude-3-sonnet" - } - ], - "intelligencePriority": 0.8, - "speedPriority": 0.5 - }, - "systemPrompt": "You are a helpful assistant.", - "maxTokens": 100 - })); - send_request("sampling/createMessage", params)?; - Ok(None) - }, - "store_mock_prompts" => { - let Some(params) = params else { - eprintln!("Params missing from store mock prompts"); - return Ok(None); - }; - // expecting a mock_prompts: { key: String, value: serde_json::Value }[]; - let Ok(mock_prompts) = serde_json::from_value::>(params) else { - eprintln!("Failed to convert to mock specs from value"); - return Ok(None); - }; - let mut self_prompts = self.prompts.lock().await; - let mut self_prompt_key_list = self.prompt_key_list.lock().await; - let is_first_mock = self_prompts.is_empty(); - self_prompts.clear(); - self_prompt_key_list.clear(); - let _ = mock_prompts.iter().fold(self_prompts, |mut acc, spec| { - let Some(key) = spec.get("key").cloned() else { - return acc; - }; - let Ok(key) = serde_json::from_value::(key) else { - eprintln!("Failed to convert serde value to string for key"); - return acc; - }; - self_prompt_key_list.push(key.clone()); - acc.insert(key, spec.get("value").cloned()); - acc - }); - if !is_first_mock { - if let Some(sender) = &self.send_request { - let _ = sender("notifications/prompts/list_changed", None); - } - } - Ok(None) - }, - "prompts/list" => { - // We expect this method to be called after the mock prompts have already been - // stored. - self.prompt_list_call_no.fetch_add(1, Ordering::Relaxed); - if let Some(params) = params { - if let Some(cursor) = params.get("cursor").cloned() { - let Ok(cursor) = serde_json::from_value::(cursor) else { - eprintln!("Failed to convert cursor to string: {:#?}", params); - return Ok(None); - }; - let self_prompt_key_list = self.prompt_key_list.lock().await; - let self_prompts = self.prompts.lock().await; - let (next_cursor, spec) = { - 'blk: { - for (i, item) in self_prompt_key_list.iter().enumerate() { - if item == &cursor { - break 'blk ( - self_prompt_key_list.get(i + 1).cloned(), - self_prompts.get(&cursor).cloned().unwrap(), - ); - } - } - (None, None) - } - }; - if let Some(next_cursor) = next_cursor { - return Ok(Some(serde_json::json!({ - "prompts": [spec.unwrap()], - "nextCursor": next_cursor, - }))); - } else { - return Ok(Some(serde_json::json!({ - "prompts": [spec.unwrap()], - }))); - } - } else { - eprintln!("Params exist but cursor is missing"); - return Ok(None); - } - } else { - // If there is no parameter, this is the request to retrieve the first page - let prompt_key_list = self.prompt_key_list.lock().await; - let prompts = self.prompts.lock().await; - let first_key = prompt_key_list.first().expect("first key missing"); - let first_value = prompts.get(first_key).cloned().unwrap().unwrap(); - let second_key = prompt_key_list.get(1).expect("second key missing"); - return Ok(Some(serde_json::json!({ - "prompts": [first_value], - "nextCursor": second_key - }))); - }; - }, - "get_prompt_list_call_no" => Ok(Some( - serde_json::to_value::(self.prompt_list_call_no.load(Ordering::Relaxed)) - .expect("Failed to convert list call no to u8"), - )), - _ => Err(ServerError::MissingMethod), - } - } - - // This is a test path relevant only to sampling - async fn handle_response(&self, resp: JsonRpcResponse) -> Result<(), ServerError> { - let JsonRpcResponse { id, .. } = resp; - let _pending = self.pending_request.as_ref().and_then(|f| f(id)); - Ok(()) - } - - async fn handle_shutdown(&self) -> Result<(), ServerError> { - Ok(()) - } -} - -#[tokio::main] -async fn main() { - let handler = Handler::default(); - let stdin = tokio::io::stdin(); - let stdout = tokio::io::stdout(); - let test_server = Server::::new(handler, stdin, stdout).expect("Failed to create server"); - let _ = test_server.init().expect("Test server failed to init").await; -} diff --git a/crates/fig_proto/build.rs b/crates/fig_proto/build.rs index 6d22deb1e..504274197 100644 --- a/crates/fig_proto/build.rs +++ b/crates/fig_proto/build.rs @@ -133,7 +133,7 @@ fn download_protoc_windows(protoc_version: &str, tmp_folder: &tempfile::TempDir) // Verify checksum using PowerShell let mut checksum_command = Command::new("powershell"); checksum_command.arg("-Command").arg(format!( - "(Get-FileHash -Path '{}' -Algorithm SHA256).Hash.ToLower()", + "(Get-FileHash -Path '{}' -Algorithm SHA256).Hash", tmp_folder.path().join("protoc.zip").display() )); let checksum_output = checksum_command.output().unwrap(); @@ -141,11 +141,9 @@ fn download_protoc_windows(protoc_version: &str, tmp_folder: &tempfile::TempDir) eprintln!("checksum: {checksum_output:?}"); assert_eq!( - checksum_output, - checksum.to_lowercase(), + checksum_output, checksum, "Checksum verification failed. Expected: {}, Got: {}", - checksum.to_lowercase(), - checksum_output + checksum, checksum_output ); // Extract using PowerShell @@ -169,7 +167,9 @@ fn download_protoc_windows(protoc_version: &str, tmp_folder: &tempfile::TempDir) )); assert!(copy_command.spawn().unwrap().wait().unwrap().success()); - std::env::set_var("PROTOC", out_bin); + unsafe { + std::env::set_var("PROTOC", out_bin); + } } fn main() -> Result<()> { diff --git a/crates/fig_settings/src/sqlite/mod.rs b/crates/fig_settings/src/sqlite/mod.rs index c7b56c55b..58a453369 100644 --- a/crates/fig_settings/src/sqlite/mod.rs +++ b/crates/fig_settings/src/sqlite/mod.rs @@ -16,10 +16,7 @@ use rusqlite::{ params, }; use serde_json::Map; -use tracing::{ - debug, - info, -}; +use tracing::info; use crate::Result; use crate::error::DbOpenError; @@ -100,7 +97,6 @@ impl Db { let metadata = std::fs::metadata(path)?; let mut permissions = metadata.permissions(); if permissions.mode() & 0o777 != 0o600 { - debug!(?path, "Setting database file permissions to 0600"); permissions.set_mode(0o600); std::fs::set_permissions(path, permissions)?; }