diff --git a/apps/cli/tsup.config.ts b/apps/cli/tsup.config.ts index eff2c14e2c9..3ad1234d995 100644 --- a/apps/cli/tsup.config.ts +++ b/apps/cli/tsup.config.ts @@ -16,7 +16,6 @@ export default defineConfig({ external: [ // Keep native modules external "@anthropic-ai/sdk", - "@anthropic-ai/bedrock-sdk", "@anthropic-ai/vertex-sdk", // Keep @vscode/ripgrep external - we bundle the binary separately "@vscode/ripgrep", diff --git a/pnpm-lock.yaml b/pnpm-lock.yaml index 0a47ae416ae..58f6354f62d 100644 --- a/pnpm-lock.yaml +++ b/pnpm-lock.yaml @@ -746,6 +746,9 @@ importers: src: dependencies: + '@ai-sdk/amazon-bedrock': + specifier: ^4.0.50 + version: 4.0.50(zod@3.25.76) '@ai-sdk/cerebras': specifier: ^1.0.0 version: 1.0.35(zod@3.25.76) @@ -770,9 +773,6 @@ importers: '@ai-sdk/xai': specifier: ^3.0.46 version: 3.0.46(zod@3.25.76) - '@anthropic-ai/bedrock-sdk': - specifier: ^0.10.2 - version: 0.10.4 '@anthropic-ai/sdk': specifier: ^0.37.0 version: 0.37.0 @@ -1417,12 +1417,24 @@ packages: '@adobe/css-tools@4.4.2': resolution: {integrity: sha512-baYZExFpsdkBNuvGKTKWCwKH57HRZLVtycZS05WTQNVOiXVSeAki3nU35zlRbToeMW8aHlJfyS+1C4BOv27q0A==} + '@ai-sdk/amazon-bedrock@4.0.50': + resolution: {integrity: sha512-DsIxaUHPbDUY0DfxYMz6GL9tO/z7ISiwACSiYupcYImqrcdLtIGFujPgszOf92ed3olfhjdkhTwKBHaf6Yh6Qw==} + engines: {node: '>=18'} + peerDependencies: + zod: 3.25.76 + '@ai-sdk/anthropic@2.0.58': resolution: {integrity: sha512-CkNW5L1Arv8gPtPlEmKd+yf/SG9ucJf0XQdpMG8OiYEtEMc2smuCA+tyCp8zI7IBVg/FE7nUfFHntQFaOjRwJQ==} engines: {node: '>=18'} peerDependencies: zod: 3.25.76 + '@ai-sdk/anthropic@3.0.37': + resolution: {integrity: sha512-tEgcJPw+a6obbF+SHrEiZsx3DNxOHqeY8bK4IpiNsZ8YPZD141R34g3lEAaQnmNN5mGsEJ8SXoEDabuzi8wFJQ==} + engines: {node: '>=18'} + peerDependencies: + zod: 3.25.76 + '@ai-sdk/cerebras@1.0.35': resolution: {integrity: sha512-JrNdMYptrOUjNthibgBeAcBjZ/H+fXb49sSrWhOx5Aq8eUcrYvwQ2DtSAi8VraHssZu78NAnBMrgFWSUOTXFxw==} engines: {node: '>=18'} @@ -1575,9 +1587,6 @@ packages: '@antfu/utils@8.1.1': resolution: {integrity: sha512-Mex9nXf9vR6AhcXmMrlz/HVgYYZpVGJ6YlPgwl7UnaFpnshXs6EK/oa5Gpf3CzENMjkvEx2tQtntGnb7UtSTOQ==} - '@anthropic-ai/bedrock-sdk@0.10.4': - resolution: {integrity: sha512-szduEHbMli6XL934xrraYg5cFuKL/1oMyj/iZuEVjtddQ7eD5cXObzWobsv5mTLWijQmSzMfFD+JAUHDPHlQ/Q==} - '@anthropic-ai/sdk@0.37.0': resolution: {integrity: sha512-tHjX2YbkUBwEgg0JZU3EFSSAQPoK4qQR/NFYa8Vtzd5UAyXzZksCw2In69Rml4R/TyHPBfRYaLK35XiOe33pjw==} @@ -1587,9 +1596,6 @@ packages: '@asamuzakjp/css-color@3.2.0': resolution: {integrity: sha512-K1A6z8tS3XsmCMM86xoWdn7Fkdn9m6RSVtocUrJYIwZnFVkng/PvkEoWtOWmP+Scc6saYWHWZYbndEEXxl24jw==} - '@aws-crypto/crc32@3.0.0': - resolution: {integrity: sha512-IzSgsrxUcsrejQbPVilIKy16kAT52EwB6zSaI+M3xxIhKh5+aldEyvI+z6erM7TCLB2BJsFrtHjp6/4/sr+3dA==} - '@aws-crypto/crc32@5.2.0': resolution: {integrity: sha512-nLbCWqQNgUiwwtFsen1AdzAtvuLRsQS8rYgMuxCrdKf9kOssamGLuPwyTY9wyYblNr9+1XM8v6zoDTPPSIeANg==} engines: {node: '>=16.0.0'} @@ -1597,9 +1603,6 @@ packages: '@aws-crypto/sha256-browser@5.2.0': resolution: {integrity: sha512-AXfN/lGotSQwu6HNcEsIASo7kWXZ5HYWvfOmSNKDsEqC4OashTp8alTmaz+F7TC2L083SFv5RdB+qU3Vs1kZqw==} - '@aws-crypto/sha256-js@4.0.0': - resolution: {integrity: sha512-MHGJyjE7TX9aaqXj7zk2ppnFUOhaDs5sP+HtNS0evOxn72c+5njUmyJmpGd7TfyoDznZlHMmdo/xGUdu2NIjNQ==} - '@aws-crypto/sha256-js@5.2.0': resolution: {integrity: sha512-FFQQyu7edu4ufvIZ+OadFpHHOt+eSTBaYaki44c+akjg7qZg9oOQeLlk77F6tSYqjDAFClrHJk9tMf0HdVyOvA==} engines: {node: '>=16.0.0'} @@ -1607,12 +1610,6 @@ packages: 
'@aws-crypto/supports-web-crypto@5.2.0': resolution: {integrity: sha512-iAvUotm021kM33eCdNfwIN//F77/IADDSs58i+MDaOqFrVjZo9bAal0NK7HurRuWLLpF1iLX7gbWrjHjeo+YFg==} - '@aws-crypto/util@3.0.0': - resolution: {integrity: sha512-2OJlpeJpCR48CC8r+uKVChzs9Iungj9wkZrl8Z041DWEWvyIHILYKCPNzJghKsivj+S3mLo6BVc7mBNzdxA46w==} - - '@aws-crypto/util@4.0.0': - resolution: {integrity: sha512-2EnmPy2gsFZ6m8bwUQN4jq+IyXV3quHAcwPOS6ZA3k+geujiqI8aRokO2kFJe+idJ/P3v4qWI186rVMo0+zLDQ==} - '@aws-crypto/util@5.2.0': resolution: {integrity: sha512-4RkU9EsI6ZpBve5fseQlGNUWKMa1RLPQ1dnjnQoe07ldfIzcsGb5hC5W0Dm7u423KWzawlrpbjXBrXCEv9zazQ==} @@ -1708,14 +1705,6 @@ packages: resolution: {integrity: sha512-/inmPnjZE0ZBE16zaCowAvouSx05FJ7p6BQYuzlJ8vxEU0sS0Hf8fvhuiRnN9V9eDUPIBY+/5EjbMWygXL4wlQ==} engines: {node: '>=18.0.0'} - '@aws-sdk/types@3.804.0': - resolution: {integrity: sha512-A9qnsy9zQ8G89vrPPlNG9d1d8QcKRGqJKqwyGgS0dclJpwy6d1EWgQLIolKPl6vcFpLoe6avLOLxr+h8ur5wpg==} - engines: {node: '>=18.0.0'} - - '@aws-sdk/types@3.840.0': - resolution: {integrity: sha512-xliuHaUFZxEx1NSXeLLZ9Dyu6+EJVQKEoD+yM+zqUo3YDZ7medKJWY6fIOKiPX/N7XbLdBYwajb15Q7IL8KkeA==} - engines: {node: '>=18.0.0'} - '@aws-sdk/types@3.922.0': resolution: {integrity: sha512-eLA6XjVobAUAMivvM7DBL79mnHyrm+32TkXNWZua5mnxF+6kQCfblKKJvxMZLGosO53/Ex46ogim8IY5Nbqv2w==} engines: {node: '>=18.0.0'} @@ -1744,9 +1733,6 @@ packages: aws-crt: optional: true - '@aws-sdk/util-utf8-browser@3.259.0': - resolution: {integrity: sha512-UvFa/vR+e19XookZF8RzFZBrw2EUkQWxiBW0yYQAhvk3C+QVGl0H3ouca8LDBlBfQKXwmW3huo/59H8rwb1wJw==} - '@aws-sdk/xml-builder@3.921.0': resolution: {integrity: sha512-LVHg0jgjyicKKvpNIEMXIMr1EBViESxcPkqfOlT+X1FkmUMTNZEEVF18tOJg4m4hV5vxtkWcqtr4IEeWa1C41Q==} engines: {node: '>=18.0.0'} @@ -3879,10 +3865,6 @@ packages: resolution: {integrity: sha512-tlqY9xq5ukxTUZBmoOp+m61cqwQD5pHJtFY3Mn8CA8ps6yghLH/Hw8UPdqg4OLmFW3IFlcXnQNmo/dh8HzXYIQ==} engines: {node: '>=18'} - '@smithy/abort-controller@2.2.0': - resolution: {integrity: sha512-wRlta7GuLWpTqtFfGo+nZyOO1vEvewdNR1R4rTxpC8XU6vG/NDyrFBhwLZsqg1NUoR1noVaXJPC/7ZK47QCySw==} - engines: {node: '>=14.0.0'} - '@smithy/abort-controller@4.2.4': resolution: {integrity: sha512-Z4DUr/AkgyFf1bOThW2HwzREagee0sB5ycl+hDiSZOfRLW8ZgrOjDi6g8mHH19yyU5E2A/64W3z6SMIf5XiUSQ==} engines: {node: '>=18.0.0'} @@ -3899,9 +3881,6 @@ packages: resolution: {integrity: sha512-YVNMjhdz2pVto5bRdux7GMs0x1m0Afz3OcQy/4Yf9DH4fWOtroGH7uLvs7ZmDyoBJzLdegtIPpXrpJOZWvUXdw==} engines: {node: '>=18.0.0'} - '@smithy/eventstream-codec@2.2.0': - resolution: {integrity: sha512-8janZoJw85nJmQZc4L8TuePp2pk1nxLgkxIR0TUjKJ5Dkj5oelB9WtiSSGXCQvNsJl0VSTvK/2ueMXxvpa9GVw==} - '@smithy/eventstream-codec@4.2.4': resolution: {integrity: sha512-aV8blR9RBDKrOlZVgjOdmOibTC2sBXNiT7WA558b4MPdsLTV6sbyc1WIE9QiIuYMJjYtnPLciefoqSW8Gi+MZQ==} engines: {node: '>=18.0.0'} @@ -3914,25 +3893,14 @@ packages: resolution: {integrity: sha512-lxfDT0UuSc1HqltOGsTEAlZ6H29gpfDSdEPTapD5G63RbnYToZ+ezjzdonCCH90j5tRRCw3aLXVbiZaBW3VRVg==} engines: {node: '>=18.0.0'} - '@smithy/eventstream-serde-node@2.2.0': - resolution: {integrity: sha512-zpQMtJVqCUMn+pCSFcl9K/RPNtQE0NuMh8sKpCdEHafhwRsjP50Oq/4kMmvxSRy6d8Jslqd8BLvDngrUtmN9iA==} - engines: {node: '>=14.0.0'} - '@smithy/eventstream-serde-node@4.2.4': resolution: {integrity: sha512-TPhiGByWnYyzcpU/K3pO5V7QgtXYpE0NaJPEZBCa1Y5jlw5SjqzMSbFiLb+ZkJhqoQc0ImGyVINqnq1ze0ZRcQ==} engines: {node: '>=18.0.0'} - '@smithy/eventstream-serde-universal@2.2.0': - resolution: {integrity: 
sha512-pvoe/vvJY0mOpuF84BEtyZoYfbehiFj8KKWk1ds2AT0mTLYFVs+7sBJZmioOFdBXKd48lfrx1vumdPdmGlCLxA==} - engines: {node: '>=14.0.0'} - '@smithy/eventstream-serde-universal@4.2.4': resolution: {integrity: sha512-GNI/IXaY/XBB1SkGBFmbW033uWA0tj085eCxYih0eccUe/PFR7+UBQv9HNDk2fD9TJu7UVsCWsH99TkpEPSOzQ==} engines: {node: '>=18.0.0'} - '@smithy/fetch-http-handler@2.5.0': - resolution: {integrity: sha512-BOWEBeppWhLn/no/JxUL/ghTfANTjT7kg3Ww2rPqTUY9R4yHPXxJ9JhMe3Z03LN3aPwiwlpDIUcVw1xDyHqEhw==} - '@smithy/fetch-http-handler@5.3.5': resolution: {integrity: sha512-mg83SM3FLI8Sa2ooTJbsh5MFfyMTyNRwxqpKHmE0ICRIa66Aodv80DMsTQI02xBLVJ0hckwqTRr5IGAbbWuFLQ==} engines: {node: '>=18.0.0'} @@ -3949,10 +3917,6 @@ packages: resolution: {integrity: sha512-GGP3O9QFD24uGeAXYUjwSTXARoqpZykHadOmA8G5vfJPK0/DC67qa//0qvqrJzL1xc8WQWX7/yc7fwudjPHPhA==} engines: {node: '>=14.0.0'} - '@smithy/is-array-buffer@3.0.0': - resolution: {integrity: sha512-+Fsu6Q6C4RSJiy81Y8eApjEB5gVtM+oFKTffg+jSuwtvomJJrhUJBu2zS8wjXSgH/g1MKEWrzyChTBe6clb5FQ==} - engines: {node: '>=16.0.0'} - '@smithy/is-array-buffer@4.2.0': resolution: {integrity: sha512-DZZZBvC7sjcYh4MazJSGiWMI2L7E0oCiRHREDzIxi/M2LY79/21iXt6aPLHge82wi5LsuRF5A06Ds3+0mlh6CQ==} engines: {node: '>=18.0.0'} @@ -3961,10 +3925,6 @@ packages: resolution: {integrity: sha512-hJRZuFS9UsElX4DJSJfoX4M1qXRH+VFiLMUnhsWvtOOUWRNvvOfDaUSdlNbjwv1IkpVjj/Rd/O59Jl3nhAcxow==} engines: {node: '>=18.0.0'} - '@smithy/middleware-endpoint@2.5.1': - resolution: {integrity: sha512-1/8kFp6Fl4OsSIVTWHnNjLnTL8IqpIb/D3sTSczrKFnrE9VMNWxnrRKNvpUHOJ6zpGD5f62TPm7+17ilTJpiCQ==} - engines: {node: '>=14.0.0'} - '@smithy/middleware-endpoint@4.3.6': resolution: {integrity: sha512-PXehXofGMFpDqr933rxD8RGOcZ0QBAWtuzTgYRAHAL2BnKawHDEdf/TnGpcmfPJGwonhginaaeJIKluEojiF/w==} engines: {node: '>=18.0.0'} @@ -3973,66 +3933,34 @@ packages: resolution: {integrity: sha512-OhLx131znrEDxZPAvH/OYufR9d1nB2CQADyYFN4C3V/NQS7Mg4V6uvxHC/Dr96ZQW8IlHJTJ+vAhKt6oxWRndA==} engines: {node: '>=18.0.0'} - '@smithy/middleware-serde@2.3.0': - resolution: {integrity: sha512-sIADe7ojwqTyvEQBe1nc/GXB9wdHhi9UwyX0lTyttmUWDJLP655ZYE1WngnNyXREme8I27KCaUhyhZWRXL0q7Q==} - engines: {node: '>=14.0.0'} - '@smithy/middleware-serde@4.2.4': resolution: {integrity: sha512-jUr3x2CDhV15TOX2/Uoz4gfgeqLrRoTQbYAuhLS7lcVKNev7FeYSJ1ebEfjk+l9kbb7k7LfzIR/irgxys5ZTOg==} engines: {node: '>=18.0.0'} - '@smithy/middleware-stack@2.2.0': - resolution: {integrity: sha512-Qntc3jrtwwrsAC+X8wms8zhrTr0sFXnyEGhZd9sLtsJ/6gGQKFzNB+wWbOcpJd7BR8ThNCoKt76BuQahfMvpeA==} - engines: {node: '>=14.0.0'} - '@smithy/middleware-stack@4.2.4': resolution: {integrity: sha512-Gy3TKCOnm9JwpFooldwAboazw+EFYlC+Bb+1QBsSi5xI0W5lX81j/P5+CXvD/9ZjtYKRgxq+kkqd/KOHflzvgA==} engines: {node: '>=18.0.0'} - '@smithy/node-config-provider@2.3.0': - resolution: {integrity: sha512-0elK5/03a1JPWMDPaS726Iw6LpQg80gFut1tNpPfxFuChEEklo2yL823V94SpTZTxmKlXFtFgsP55uh3dErnIg==} - engines: {node: '>=14.0.0'} - '@smithy/node-config-provider@4.3.4': resolution: {integrity: sha512-3X3w7qzmo4XNNdPKNS4nbJcGSwiEMsNsRSunMA92S4DJLLIrH5g1AyuOA2XKM9PAPi8mIWfqC+fnfKNsI4KvHw==} engines: {node: '>=18.0.0'} - '@smithy/node-http-handler@2.5.0': - resolution: {integrity: sha512-mVGyPBzkkGQsPoxQUbxlEfRjrj6FPyA3u3u2VXGr9hT8wilsoQdZdvKpMBFMB8Crfhv5dNkKHIW0Yyuc7eABqA==} - engines: {node: '>=14.0.0'} - '@smithy/node-http-handler@4.4.4': resolution: {integrity: sha512-VXHGfzCXLZeKnFp6QXjAdy+U8JF9etfpUXD1FAbzY1GzsFJiDQRQIt2CnMUvUdz3/YaHNqT3RphVWMUpXTIODA==} engines: {node: '>=18.0.0'} - '@smithy/property-provider@2.2.0': - resolution: {integrity: 
sha512-+xiil2lFhtTRzXkx8F053AV46QnIw6e7MV8od5Mi68E1ICOjCeCHw2XfLnDEUHnT9WGUIkwcqavXjfwuJbGlpg==} - engines: {node: '>=14.0.0'} - '@smithy/property-provider@4.2.4': resolution: {integrity: sha512-g2DHo08IhxV5GdY3Cpt/jr0mkTlAD39EJKN27Jb5N8Fb5qt8KG39wVKTXiTRCmHHou7lbXR8nKVU14/aRUf86w==} engines: {node: '>=18.0.0'} - '@smithy/protocol-http@3.3.0': - resolution: {integrity: sha512-Xy5XK1AFWW2nlY/biWZXu6/krgbaf2dg0q492D8M5qthsnU2H+UgFeZLbM76FnH7s6RO/xhQRkj+T6KBO3JzgQ==} - engines: {node: '>=14.0.0'} - '@smithy/protocol-http@5.3.4': resolution: {integrity: sha512-3sfFd2MAzVt0Q/klOmjFi3oIkxczHs0avbwrfn1aBqtc23WqQSmjvk77MBw9WkEQcwbOYIX5/2z4ULj8DuxSsw==} engines: {node: '>=18.0.0'} - '@smithy/querystring-builder@2.2.0': - resolution: {integrity: sha512-L1kSeviUWL+emq3CUVSgdogoM/D9QMFaqxL/dd0X7PCNWmPXqt+ExtrBjqT0V7HLN03Vs9SuiLrG3zy3JGnE5A==} - engines: {node: '>=14.0.0'} - '@smithy/querystring-builder@4.2.4': resolution: {integrity: sha512-KQ1gFXXC+WsbPFnk7pzskzOpn4s+KheWgO3dzkIEmnb6NskAIGp/dGdbKisTPJdtov28qNDohQrgDUKzXZBLig==} engines: {node: '>=18.0.0'} - '@smithy/querystring-parser@2.2.0': - resolution: {integrity: sha512-BvHCDrKfbG5Yhbpj4vsbuPV2GgcpHiAkLeIlcA1LtfpMz3jrqizP1+OguSNSj1MwBHEiN+jwNisXLGdajGDQJA==} - engines: {node: '>=14.0.0'} - '@smithy/querystring-parser@4.2.4': resolution: {integrity: sha512-aHb5cqXZocdzEkZ/CvhVjdw5l4r1aU/9iMEyoKzH4eXMowT6M0YjBpp7W/+XjkBnY8Xh0kVd55GKjnPKlCwinQ==} engines: {node: '>=18.0.0'} @@ -4041,57 +3969,26 @@ packages: resolution: {integrity: sha512-fdWuhEx4+jHLGeew9/IvqVU/fxT/ot70tpRGuOLxE3HzZOyKeTQfYeV1oaBXpzi93WOk668hjMuuagJ2/Qs7ng==} engines: {node: '>=18.0.0'} - '@smithy/shared-ini-file-loader@2.4.0': - resolution: {integrity: sha512-WyujUJL8e1B6Z4PBfAqC/aGY1+C7T0w20Gih3yrvJSk97gpiVfB+y7c46T4Nunk+ZngLq0rOIdeVeIklk0R3OA==} - engines: {node: '>=14.0.0'} - '@smithy/shared-ini-file-loader@4.3.4': resolution: {integrity: sha512-y5ozxeQ9omVjbnJo9dtTsdXj9BEvGx2X8xvRgKnV+/7wLBuYJQL6dOa/qMY6omyHi7yjt1OA97jZLoVRYi8lxA==} engines: {node: '>=18.0.0'} - '@smithy/signature-v4@3.1.2': - resolution: {integrity: sha512-3BcPylEsYtD0esM4Hoyml/+s7WP2LFhcM3J2AGdcL2vx9O60TtfpDOL72gjb4lU8NeRPeKAwR77YNyyGvMbuEA==} - engines: {node: '>=16.0.0'} - '@smithy/signature-v4@5.3.4': resolution: {integrity: sha512-ScDCpasxH7w1HXHYbtk3jcivjvdA1VICyAdgvVqKhKKwxi+MTwZEqFw0minE+oZ7F07oF25xh4FGJxgqgShz0A==} engines: {node: '>=18.0.0'} - '@smithy/smithy-client@2.5.1': - resolution: {integrity: sha512-jrbSQrYCho0yDaaf92qWgd+7nAeap5LtHTI51KXqmpIFCceKU3K9+vIVTUH72bOJngBMqa4kyu1VJhRcSrk/CQ==} - engines: {node: '>=14.0.0'} - '@smithy/smithy-client@4.9.2': resolution: {integrity: sha512-gZU4uAFcdrSi3io8U99Qs/FvVdRxPvIMToi+MFfsy/DN9UqtknJ1ais+2M9yR8e0ASQpNmFYEKeIKVcMjQg3rg==} engines: {node: '>=18.0.0'} - '@smithy/types@2.12.0': - resolution: {integrity: sha512-QwYgloJ0sVNBeBuBs65cIkTbfzV/Q6ZNPCJ99EICFEdJYG50nGIY/uYXp+TbsdJReIuPr0a0kXmCvren3MbRRw==} - engines: {node: '>=14.0.0'} - - '@smithy/types@3.7.2': - resolution: {integrity: sha512-bNwBYYmN8Eh9RyjS1p2gW6MIhSO2rl7X9QeLM8iTdcGRP+eDiIWDt66c9IysCc22gefKszZv+ubV9qZc7hdESg==} - engines: {node: '>=16.0.0'} - - '@smithy/types@4.3.1': - resolution: {integrity: sha512-UqKOQBL2x6+HWl3P+3QqFD4ncKq0I8Nuz9QItGv5WuKuMHuuwlhvqcZCoXGfc+P1QmfJE7VieykoYYmrOoFJxA==} - engines: {node: '>=18.0.0'} - '@smithy/types@4.8.1': resolution: {integrity: sha512-N0Zn0OT1zc+NA+UVfkYqQzviRh5ucWwO7mBV3TmHHprMnfcJNfhlPicDkBHi0ewbh+y3evR6cNAW0Raxvb01NA==} engines: {node: '>=18.0.0'} - '@smithy/url-parser@2.2.0': - resolution: {integrity: 
sha512-hoA4zm61q1mNTpksiSWp2nEl1dt3j726HdRhiNgVJQMj7mLp7dprtF57mOB6JvEk/x9d2bsuL5hlqZbBuHQylQ==} - '@smithy/url-parser@4.2.4': resolution: {integrity: sha512-w/N/Iw0/PTwJ36PDqU9PzAwVElo4qXxCC0eCTlUtIz/Z5V/2j/cViMHi0hPukSBHp4DVwvUlUhLgCzqSJ6plrg==} engines: {node: '>=18.0.0'} - '@smithy/util-base64@2.3.0': - resolution: {integrity: sha512-s3+eVwNeJuXUwuMbusncZNViuhv2LjVJ1nMwTqSA0XAC7gjKhqqxRdJPhR8+YrkoZ9IiIbFk/yK6ACe/xlF+hw==} - engines: {node: '>=14.0.0'} - '@smithy/util-base64@4.3.0': resolution: {integrity: sha512-GkXZ59JfyxsIwNTWFnjmFEI8kZpRNIBfxKjv09+nkAWPt/4aGaEWMM04m4sxgNVWkbt2MdSvE3KF/PfX4nFedQ==} engines: {node: '>=18.0.0'} @@ -4108,10 +4005,6 @@ packages: resolution: {integrity: sha512-IJdWBbTcMQ6DA0gdNhh/BwrLkDR+ADW5Kr1aZmd4k3DIF6ezMV4R2NIAmT08wQJ3yUK82thHWmC/TnK/wpMMIA==} engines: {node: '>=14.0.0'} - '@smithy/util-buffer-from@3.0.0': - resolution: {integrity: sha512-aEOHCgq5RWFbP+UDPvPot26EJHjOC+bRgse5A8V3FSShqd5E5UN4qc7zkwsvJPPAVsf73QwYcHN1/gt/rtLwQA==} - engines: {node: '>=16.0.0'} - '@smithy/util-buffer-from@4.2.0': resolution: {integrity: sha512-kAY9hTKulTNevM2nlRtxAG2FQ3B2OR6QIrPY3zE5LqJy1oxzmgBGsHLWTcNhWXKchgA0WHW+mZkQrng/pgcCew==} engines: {node: '>=18.0.0'} @@ -4132,26 +4025,10 @@ packages: resolution: {integrity: sha512-f+nBDhgYRCmUEDKEQb6q0aCcOTXRDqH5wWaFHJxt4anB4pKHlgGoYP3xtioKXH64e37ANUkzWf6p4Mnv1M5/Vg==} engines: {node: '>=18.0.0'} - '@smithy/util-hex-encoding@2.2.0': - resolution: {integrity: sha512-7iKXR+/4TpLK194pVjKiasIyqMtTYJsgKgM242Y9uzt5dhHnUDvMNb+3xIhRJ9QhvqGii/5cRUt4fJn3dtXNHQ==} - engines: {node: '>=14.0.0'} - - '@smithy/util-hex-encoding@3.0.0': - resolution: {integrity: sha512-eFndh1WEK5YMUYvy3lPlVmYY/fZcQE1D8oSf41Id2vCeIkKJXPcYDCZD+4+xViI6b1XSd7tE+s5AmXzz5ilabQ==} - engines: {node: '>=16.0.0'} - '@smithy/util-hex-encoding@4.2.0': resolution: {integrity: sha512-CCQBwJIvXMLKxVbO88IukazJD9a4kQ9ZN7/UMGBjBcJYvatpWk+9g870El4cB8/EJxfe+k+y0GmR9CAzkF+Nbw==} engines: {node: '>=18.0.0'} - '@smithy/util-middleware@2.2.0': - resolution: {integrity: sha512-L1qpleXf9QD6LwLCJ5jddGkgWyuSvWBkJwWAZ6kFkdifdso+sk3L3O1HdmPvCdnCK3IS4qWyPxev01QMnfHSBw==} - engines: {node: '>=14.0.0'} - - '@smithy/util-middleware@3.0.11': - resolution: {integrity: sha512-dWpyc1e1R6VoXrwLoLDd57U1z6CwNSdkM69Ie4+6uYh2GC7Vg51Qtan7ITzczuVpqezdDTKJGJB95fFvvjU/ow==} - engines: {node: '>=16.0.0'} - '@smithy/util-middleware@4.2.4': resolution: {integrity: sha512-fKGQAPAn8sgV0plRikRVo6g6aR0KyKvgzNrPuM74RZKy/wWVzx3BMk+ZWEueyN3L5v5EDg+P582mKU+sH5OAsg==} engines: {node: '>=18.0.0'} @@ -4160,22 +4037,10 @@ packages: resolution: {integrity: sha512-yQncJmj4dtv/isTXxRb4AamZHy4QFr4ew8GxS6XLWt7sCIxkPxPzINWd7WLISEFPsIan14zrKgvyAF+/yzfwoA==} engines: {node: '>=18.0.0'} - '@smithy/util-stream@2.2.0': - resolution: {integrity: sha512-17faEXbYWIRst1aU9SvPZyMdWmqIrduZjVOqCPMIsWFNxs5yQQgFrJL6b2SdiCzyW9mJoDjFtgi53xx7EH+BXA==} - engines: {node: '>=14.0.0'} - '@smithy/util-stream@4.5.5': resolution: {integrity: sha512-7M5aVFjT+HPilPOKbOmQfCIPchZe4DSBc1wf1+NvHvSoFTiFtauZzT+onZvCj70xhXd0AEmYnZYmdJIuwxOo4w==} engines: {node: '>=18.0.0'} - '@smithy/util-uri-escape@2.2.0': - resolution: {integrity: sha512-jtmJMyt1xMD/d8OtbVJ2gFZOSKc+ueYJZPW20ULW1GOp/q/YIM0wNh+u8ZFao9UaIGz4WoPW8hC64qlWLIfoDA==} - engines: {node: '>=14.0.0'} - - '@smithy/util-uri-escape@3.0.0': - resolution: {integrity: sha512-LqR7qYLgZTD7nWLBecUi4aqolw8Mhza9ArpNEQ881MJJIU2sE5iHCK6TdyqqzcDLy0OPe10IY4T8ctVdtynubg==} - engines: {node: '>=16.0.0'} - '@smithy/util-uri-escape@4.2.0': resolution: {integrity: 
sha512-igZpCKV9+E/Mzrpq6YacdTQ0qTiLm85gD6N/IrmyDvQFA4UnU3d5g3m8tMT/6zG/vVkWSU+VxeUyGonL62DuxA==} engines: {node: '>=18.0.0'} @@ -4184,10 +4049,6 @@ packages: resolution: {integrity: sha512-R8Rdn8Hy72KKcebgLiv8jQcQkXoLMOGGv5uI1/k0l+snqkOzQ1R0ChUBCxWMlBsFMekWjq0wRudIweFs7sKT5A==} engines: {node: '>=14.0.0'} - '@smithy/util-utf8@3.0.0': - resolution: {integrity: sha512-rUeT12bxFnplYDe815GXbq/oixEGHfRFFtcTF3YdDi/JaENIM6aSYYLJydG83UNzLXeRI5K8abYd/8Sp/QM0kA==} - engines: {node: '>=16.0.0'} - '@smithy/util-utf8@4.2.0': resolution: {integrity: sha512-zBPfuzoI8xyBtR2P6WQj63Rz8i3AmfAaJLuNG8dWsfvPe8lO4aCPYLn879mEgHndZH1zQ2oXmG8O1GGzzaoZiw==} engines: {node: '>=18.0.0'} @@ -5146,6 +5007,9 @@ packages: resolution: {integrity: sha512-wvUjBtSGN7+7SjNpq/9M2Tg350UZD3q62IFZLbRAR1bSMlCo1ZaeW+BJ+D090e4hIIZLBcTDWe4Mh4jvUDajzQ==} engines: {node: '>= 0.4'} + aws4fetch@1.0.20: + resolution: {integrity: sha512-/djoAN709iY65ETD6LKCtyyEI04XIBP5xVvfmNxsEP0uJB5tyaGBztSryRr4HqMStr9R06PisQE7m9zDTXKu6g==} + axios@1.12.0: resolution: {integrity: sha512-oXTDccv8PcfjZmPGlWsPSwtOJCZ/b6W5jAMCNcfwJbCzDckwG0jrYJFaWH1yvivfCXjVzV/SPDEhMB3Q+DSurg==} @@ -11130,12 +10994,28 @@ snapshots: '@adobe/css-tools@4.4.2': {} + '@ai-sdk/amazon-bedrock@4.0.50(zod@3.25.76)': + dependencies: + '@ai-sdk/anthropic': 3.0.37(zod@3.25.76) + '@ai-sdk/provider': 3.0.7 + '@ai-sdk/provider-utils': 4.0.13(zod@3.25.76) + '@smithy/eventstream-codec': 4.2.4 + '@smithy/util-utf8': 4.2.0 + aws4fetch: 1.0.20 + zod: 3.25.76 + '@ai-sdk/anthropic@2.0.58(zod@3.25.76)': dependencies: '@ai-sdk/provider': 2.0.1 '@ai-sdk/provider-utils': 3.0.20(zod@3.25.76) zod: 3.25.76 + '@ai-sdk/anthropic@3.0.37(zod@3.25.76)': + dependencies: + '@ai-sdk/provider': 3.0.7 + '@ai-sdk/provider-utils': 4.0.13(zod@3.25.76) + zod: 3.25.76 + '@ai-sdk/cerebras@1.0.35(zod@3.25.76)': dependencies: '@ai-sdk/openai-compatible': 1.0.31(zod@3.25.76) @@ -11304,23 +11184,6 @@ snapshots: '@antfu/utils@8.1.1': {} - '@anthropic-ai/bedrock-sdk@0.10.4': - dependencies: - '@anthropic-ai/sdk': 0.37.0 - '@aws-crypto/sha256-js': 4.0.0 - '@aws-sdk/client-bedrock-runtime': 3.922.0 - '@aws-sdk/credential-providers': 3.922.0 - '@smithy/eventstream-serde-node': 2.2.0 - '@smithy/fetch-http-handler': 2.5.0 - '@smithy/protocol-http': 3.3.0 - '@smithy/signature-v4': 3.1.2 - '@smithy/smithy-client': 2.5.1 - '@smithy/types': 2.12.0 - '@smithy/util-base64': 2.3.0 - transitivePeerDependencies: - - aws-crt - - encoding - '@anthropic-ai/sdk@0.37.0': dependencies: '@types/node': 18.19.100 @@ -11349,12 +11212,6 @@ snapshots: '@csstools/css-tokenizer': 3.0.4 lru-cache: 10.4.3 - '@aws-crypto/crc32@3.0.0': - dependencies: - '@aws-crypto/util': 3.0.0 - '@aws-sdk/types': 3.840.0 - tslib: 1.14.1 - '@aws-crypto/crc32@5.2.0': dependencies: '@aws-crypto/util': 5.2.0 @@ -11371,12 +11228,6 @@ snapshots: '@smithy/util-utf8': 2.3.0 tslib: 2.8.1 - '@aws-crypto/sha256-js@4.0.0': - dependencies: - '@aws-crypto/util': 4.0.0 - '@aws-sdk/types': 3.804.0 - tslib: 1.14.1 - '@aws-crypto/sha256-js@5.2.0': dependencies: '@aws-crypto/util': 5.2.0 @@ -11387,18 +11238,6 @@ snapshots: dependencies: tslib: 2.8.1 - '@aws-crypto/util@3.0.0': - dependencies: - '@aws-sdk/types': 3.840.0 - '@aws-sdk/util-utf8-browser': 3.259.0 - tslib: 1.14.1 - - '@aws-crypto/util@4.0.0': - dependencies: - '@aws-sdk/types': 3.840.0 - '@aws-sdk/util-utf8-browser': 3.259.0 - tslib: 1.14.1 - '@aws-crypto/util@5.2.0': dependencies: '@aws-sdk/types': 3.922.0 @@ -11806,16 +11645,6 @@ snapshots: transitivePeerDependencies: - aws-crt - '@aws-sdk/types@3.804.0': - dependencies: - 
'@smithy/types': 4.3.1 - tslib: 2.8.1 - - '@aws-sdk/types@3.840.0': - dependencies: - '@smithy/types': 4.3.1 - tslib: 2.8.1 - '@aws-sdk/types@3.922.0': dependencies: '@smithy/types': 4.8.1 @@ -11855,10 +11684,6 @@ snapshots: '@smithy/types': 4.8.1 tslib: 2.8.1 - '@aws-sdk/util-utf8-browser@3.259.0': - dependencies: - tslib: 2.8.1 - '@aws-sdk/xml-builder@3.921.0': dependencies: '@smithy/types': 4.8.1 @@ -14041,11 +13866,6 @@ snapshots: '@sindresorhus/merge-streams@4.0.0': {} - '@smithy/abort-controller@2.2.0': - dependencies: - '@smithy/types': 2.12.0 - tslib: 2.8.1 - '@smithy/abort-controller@4.2.4': dependencies: '@smithy/types': 4.8.1 @@ -14081,13 +13901,6 @@ snapshots: '@smithy/url-parser': 4.2.4 tslib: 2.8.1 - '@smithy/eventstream-codec@2.2.0': - dependencies: - '@aws-crypto/crc32': 3.0.0 - '@smithy/types': 2.12.0 - '@smithy/util-hex-encoding': 2.2.0 - tslib: 2.8.1 - '@smithy/eventstream-codec@4.2.4': dependencies: '@aws-crypto/crc32': 5.2.0 @@ -14106,38 +13919,18 @@ snapshots: '@smithy/types': 4.8.1 tslib: 2.8.1 - '@smithy/eventstream-serde-node@2.2.0': - dependencies: - '@smithy/eventstream-serde-universal': 2.2.0 - '@smithy/types': 2.12.0 - tslib: 2.8.1 - '@smithy/eventstream-serde-node@4.2.4': dependencies: '@smithy/eventstream-serde-universal': 4.2.4 '@smithy/types': 4.8.1 tslib: 2.8.1 - '@smithy/eventstream-serde-universal@2.2.0': - dependencies: - '@smithy/eventstream-codec': 2.2.0 - '@smithy/types': 2.12.0 - tslib: 2.8.1 - '@smithy/eventstream-serde-universal@4.2.4': dependencies: '@smithy/eventstream-codec': 4.2.4 '@smithy/types': 4.8.1 tslib: 2.8.1 - '@smithy/fetch-http-handler@2.5.0': - dependencies: - '@smithy/protocol-http': 3.3.0 - '@smithy/querystring-builder': 2.2.0 - '@smithy/types': 2.12.0 - '@smithy/util-base64': 2.3.0 - tslib: 2.8.1 - '@smithy/fetch-http-handler@5.3.5': dependencies: '@smithy/protocol-http': 5.3.4 @@ -14162,10 +13955,6 @@ snapshots: dependencies: tslib: 2.8.1 - '@smithy/is-array-buffer@3.0.0': - dependencies: - tslib: 2.8.1 - '@smithy/is-array-buffer@4.2.0': dependencies: tslib: 2.8.1 @@ -14176,16 +13965,6 @@ snapshots: '@smithy/types': 4.8.1 tslib: 2.8.1 - '@smithy/middleware-endpoint@2.5.1': - dependencies: - '@smithy/middleware-serde': 2.3.0 - '@smithy/node-config-provider': 2.3.0 - '@smithy/shared-ini-file-loader': 2.4.0 - '@smithy/types': 2.12.0 - '@smithy/url-parser': 2.2.0 - '@smithy/util-middleware': 2.2.0 - tslib: 2.8.1 - '@smithy/middleware-endpoint@4.3.6': dependencies: '@smithy/core': 3.17.2 @@ -14209,34 +13988,17 @@ snapshots: '@smithy/uuid': 1.1.0 tslib: 2.8.1 - '@smithy/middleware-serde@2.3.0': - dependencies: - '@smithy/types': 2.12.0 - tslib: 2.8.1 - '@smithy/middleware-serde@4.2.4': dependencies: '@smithy/protocol-http': 5.3.4 '@smithy/types': 4.8.1 tslib: 2.8.1 - '@smithy/middleware-stack@2.2.0': - dependencies: - '@smithy/types': 2.12.0 - tslib: 2.8.1 - '@smithy/middleware-stack@4.2.4': dependencies: '@smithy/types': 4.8.1 tslib: 2.8.1 - '@smithy/node-config-provider@2.3.0': - dependencies: - '@smithy/property-provider': 2.2.0 - '@smithy/shared-ini-file-loader': 2.4.0 - '@smithy/types': 2.12.0 - tslib: 2.8.1 - '@smithy/node-config-provider@4.3.4': dependencies: '@smithy/property-provider': 4.2.4 @@ -14244,14 +14006,6 @@ snapshots: '@smithy/types': 4.8.1 tslib: 2.8.1 - '@smithy/node-http-handler@2.5.0': - dependencies: - '@smithy/abort-controller': 2.2.0 - '@smithy/protocol-http': 3.3.0 - '@smithy/querystring-builder': 2.2.0 - '@smithy/types': 2.12.0 - tslib: 2.8.1 - '@smithy/node-http-handler@4.4.4': dependencies: 
'@smithy/abort-controller': 4.2.4 @@ -14260,43 +14014,22 @@ snapshots: '@smithy/types': 4.8.1 tslib: 2.8.1 - '@smithy/property-provider@2.2.0': - dependencies: - '@smithy/types': 2.12.0 - tslib: 2.8.1 - '@smithy/property-provider@4.2.4': dependencies: '@smithy/types': 4.8.1 tslib: 2.8.1 - '@smithy/protocol-http@3.3.0': - dependencies: - '@smithy/types': 2.12.0 - tslib: 2.8.1 - '@smithy/protocol-http@5.3.4': dependencies: '@smithy/types': 4.8.1 tslib: 2.8.1 - '@smithy/querystring-builder@2.2.0': - dependencies: - '@smithy/types': 2.12.0 - '@smithy/util-uri-escape': 2.2.0 - tslib: 2.8.1 - '@smithy/querystring-builder@4.2.4': dependencies: '@smithy/types': 4.8.1 '@smithy/util-uri-escape': 4.2.0 tslib: 2.8.1 - '@smithy/querystring-parser@2.2.0': - dependencies: - '@smithy/types': 2.12.0 - tslib: 2.8.1 - '@smithy/querystring-parser@4.2.4': dependencies: '@smithy/types': 4.8.1 @@ -14306,26 +14039,11 @@ snapshots: dependencies: '@smithy/types': 4.8.1 - '@smithy/shared-ini-file-loader@2.4.0': - dependencies: - '@smithy/types': 2.12.0 - tslib: 2.8.1 - '@smithy/shared-ini-file-loader@4.3.4': dependencies: '@smithy/types': 4.8.1 tslib: 2.8.1 - '@smithy/signature-v4@3.1.2': - dependencies: - '@smithy/is-array-buffer': 3.0.0 - '@smithy/types': 3.7.2 - '@smithy/util-hex-encoding': 3.0.0 - '@smithy/util-middleware': 3.0.11 - '@smithy/util-uri-escape': 3.0.0 - '@smithy/util-utf8': 3.0.0 - tslib: 2.8.1 - '@smithy/signature-v4@5.3.4': dependencies: '@smithy/is-array-buffer': 4.2.0 @@ -14337,15 +14055,6 @@ snapshots: '@smithy/util-utf8': 4.2.0 tslib: 2.8.1 - '@smithy/smithy-client@2.5.1': - dependencies: - '@smithy/middleware-endpoint': 2.5.1 - '@smithy/middleware-stack': 2.2.0 - '@smithy/protocol-http': 3.3.0 - '@smithy/types': 2.12.0 - '@smithy/util-stream': 2.2.0 - tslib: 2.8.1 - '@smithy/smithy-client@4.9.2': dependencies: '@smithy/core': 3.17.2 @@ -14356,40 +14065,16 @@ snapshots: '@smithy/util-stream': 4.5.5 tslib: 2.8.1 - '@smithy/types@2.12.0': - dependencies: - tslib: 2.8.1 - - '@smithy/types@3.7.2': - dependencies: - tslib: 2.8.1 - - '@smithy/types@4.3.1': - dependencies: - tslib: 2.8.1 - '@smithy/types@4.8.1': dependencies: tslib: 2.8.1 - '@smithy/url-parser@2.2.0': - dependencies: - '@smithy/querystring-parser': 2.2.0 - '@smithy/types': 2.12.0 - tslib: 2.8.1 - '@smithy/url-parser@4.2.4': dependencies: '@smithy/querystring-parser': 4.2.4 '@smithy/types': 4.8.1 tslib: 2.8.1 - '@smithy/util-base64@2.3.0': - dependencies: - '@smithy/util-buffer-from': 2.2.0 - '@smithy/util-utf8': 2.3.0 - tslib: 2.8.1 - '@smithy/util-base64@4.3.0': dependencies: '@smithy/util-buffer-from': 4.2.0 @@ -14409,11 +14094,6 @@ snapshots: '@smithy/is-array-buffer': 2.2.0 tslib: 2.8.1 - '@smithy/util-buffer-from@3.0.0': - dependencies: - '@smithy/is-array-buffer': 3.0.0 - tslib: 2.8.1 - '@smithy/util-buffer-from@4.2.0': dependencies: '@smithy/is-array-buffer': 4.2.0 @@ -14446,28 +14126,10 @@ snapshots: '@smithy/types': 4.8.1 tslib: 2.8.1 - '@smithy/util-hex-encoding@2.2.0': - dependencies: - tslib: 2.8.1 - - '@smithy/util-hex-encoding@3.0.0': - dependencies: - tslib: 2.8.1 - '@smithy/util-hex-encoding@4.2.0': dependencies: tslib: 2.8.1 - '@smithy/util-middleware@2.2.0': - dependencies: - '@smithy/types': 2.12.0 - tslib: 2.8.1 - - '@smithy/util-middleware@3.0.11': - dependencies: - '@smithy/types': 3.7.2 - tslib: 2.8.1 - '@smithy/util-middleware@4.2.4': dependencies: '@smithy/types': 4.8.1 @@ -14479,17 +14141,6 @@ snapshots: '@smithy/types': 4.8.1 tslib: 2.8.1 - '@smithy/util-stream@2.2.0': - dependencies: - 
'@smithy/fetch-http-handler': 2.5.0 - '@smithy/node-http-handler': 2.5.0 - '@smithy/types': 2.12.0 - '@smithy/util-base64': 2.3.0 - '@smithy/util-buffer-from': 2.2.0 - '@smithy/util-hex-encoding': 2.2.0 - '@smithy/util-utf8': 2.3.0 - tslib: 2.8.1 - '@smithy/util-stream@4.5.5': dependencies: '@smithy/fetch-http-handler': 5.3.5 @@ -14501,14 +14152,6 @@ snapshots: '@smithy/util-utf8': 4.2.0 tslib: 2.8.1 - '@smithy/util-uri-escape@2.2.0': - dependencies: - tslib: 2.8.1 - - '@smithy/util-uri-escape@3.0.0': - dependencies: - tslib: 2.8.1 - '@smithy/util-uri-escape@4.2.0': dependencies: tslib: 2.8.1 @@ -14518,11 +14161,6 @@ snapshots: '@smithy/util-buffer-from': 2.2.0 tslib: 2.8.1 - '@smithy/util-utf8@3.0.0': - dependencies: - '@smithy/util-buffer-from': 3.0.0 - tslib: 2.8.1 - '@smithy/util-utf8@4.2.0': dependencies: '@smithy/util-buffer-from': 4.2.0 @@ -15627,6 +15265,8 @@ snapshots: dependencies: possible-typed-array-names: 1.1.0 + aws4fetch@1.0.20: {} + axios@1.12.0: dependencies: follow-redirects: 1.15.11 diff --git a/src/api/providers/__tests__/bedrock-custom-arn.spec.ts b/src/api/providers/__tests__/bedrock-custom-arn.spec.ts index dfad54c1fd0..75cc27c89da 100644 --- a/src/api/providers/__tests__/bedrock-custom-arn.spec.ts +++ b/src/api/providers/__tests__/bedrock-custom-arn.spec.ts @@ -22,38 +22,6 @@ vitest.mock("../../../utils/logging", () => ({ }, })) -// Mock AWS SDK -vitest.mock("@aws-sdk/client-bedrock-runtime", () => { - const mockModule = { - lastCommandInput: null as Record | null, - mockSend: vitest.fn().mockImplementation(async function () { - return { - output: new TextEncoder().encode(JSON.stringify({ content: "Test response" })), - } - }), - mockConverseCommand: vitest.fn(function (input) { - mockModule.lastCommandInput = input - return { input } - }), - MockBedrockRuntimeClient: class { - public config: any - public send: any - - constructor(config: { region?: string }) { - this.config = config - this.send = mockModule.mockSend - } - }, - } - - return { - BedrockRuntimeClient: mockModule.MockBedrockRuntimeClient, - ConverseCommand: mockModule.mockConverseCommand, - ConverseStreamCommand: vitest.fn(), - __mock: mockModule, // Expose mock internals for testing - } -}) - describe("Bedrock ARN Handling", () => { // Helper function to create a handler with specific options const createHandler = (options: Partial = {}) => { @@ -224,8 +192,8 @@ describe("Bedrock ARN Handling", () => { "arn:aws:bedrock:eu-west-1:123456789012:inference-profile/anthropic.claude-3-sonnet-20240229-v1:0", }) - // Verify the client was created with the ARN region, not the provided region - expect((handler as any).client.config.region).toBe("eu-west-1") + // Verify the handler's options were updated with the ARN region + expect((handler as any).options.awsRegion).toBe("eu-west-1") }) it("should log region mismatch warning when ARN region differs from provided region", () => { diff --git a/src/api/providers/__tests__/bedrock-error-handling.spec.ts b/src/api/providers/__tests__/bedrock-error-handling.spec.ts index 2041dde4577..d217984c8da 100644 --- a/src/api/providers/__tests__/bedrock-error-handling.spec.ts +++ b/src/api/providers/__tests__/bedrock-error-handling.spec.ts @@ -8,9 +8,6 @@ vi.mock("@roo-code/telemetry", () => ({ }, })) -// Mock BedrockRuntimeClient and commands -const mockSend = vi.fn() - // Mock AWS SDK credential providers vi.mock("@aws-sdk/credential-providers", () => { return { @@ -21,16 +18,27 @@ vi.mock("@aws-sdk/credential-providers", () => { } }) 
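+// NOTE: AwsBedrockHandler now routes requests through the AI SDK
+// (streamText/generateText from the "ai" package) rather than calling
+// BedrockRuntimeClient.send() directly, so these tests mock "ai" and
+// "@ai-sdk/amazon-bedrock" instead of the AWS SDK client.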
-vi.mock("@aws-sdk/client-bedrock-runtime", () => ({
-	BedrockRuntimeClient: vi.fn().mockImplementation(() => ({
-		send: mockSend,
-	})),
-	ConverseStreamCommand: vi.fn(),
-	ConverseCommand: vi.fn(),
+// Use vi.hoisted to define mock functions for AI SDK
+const { mockStreamText, mockGenerateText } = vi.hoisted(() => ({
+	mockStreamText: vi.fn(),
+	mockGenerateText: vi.fn(),
+}))
+
+vi.mock("ai", async (importOriginal) => {
+	const actual = await importOriginal()
+	return {
+		...actual,
+		streamText: mockStreamText,
+		generateText: mockGenerateText,
+	}
+})
+
+vi.mock("@ai-sdk/amazon-bedrock", () => ({
+	createAmazonBedrock: vi.fn(() => vi.fn(() => ({ modelId: "test", provider: "bedrock" }))),
 }))
 
 import { AwsBedrockHandler } from "../bedrock"
-import { Anthropic } from "@anthropic-ai/sdk"
+import type { Anthropic } from "@anthropic-ai/sdk"
 
 describe("AwsBedrockHandler Error Handling", () => {
 	let handler: AwsBedrockHandler
 
 	beforeEach(() => {
 		handler = new AwsBedrockHandler({
 			apiModelId: "anthropic.claude-3-5-sonnet-20241022-v2:0",
 			awsAccessKey: "test-access-key",
 			awsSecretKey: "test-secret-key",
 			awsRegion: "us-east-1",
@@ -46,6 +54,10 @@
 		})
 	})
 
+	/**
+	 * Helper: create an Error with optional extra properties that
+	 * the production code inspects (status, name, $metadata, __type).
+	 */
 	const createMockError = (options: {
 		message?: string
 		name?: string
 		status?: number
 		__type?: string
 		$metadata?: {
 			httpStatusCode?: number
 			requestId?: string
 			extendedRequestId?: string
 			cfId?: string
-			[key: string]: any // Allow additional properties
+			[key: string]: unknown
 		}
 	}): Error => {
 		const error = new Error(options.message || "Test error") as any
 		if (options.name) error.name = options.name
-		if (options.status) error.status = options.status
+		if (options.status !== undefined) error.status = options.status
 		if (options.__type) error.__type = options.__type
 		if (options.$metadata) error.$metadata = options.$metadata
 		return error
 	}
 
-	describe("Throttling Error Detection", () => {
-		it("should detect throttling from HTTP 429 status code", async () => {
+	// -----------------------------------------------------------------------
+	// Throttling Detection — createMessage path
+	//
+	// createMessage: streamText throws → catch → isThrottlingError() returns
+	// true → the original error is re-thrown so the retry layer can handle it.
+	//
+	// completePrompt never calls isThrottlingError(); its errors always fall
+	// through to handleAiSdkError, which wraps the message with "Bedrock: ".
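+	//
+	// A rough sketch of the predicate these tests imply (the real helper
+	// lives in bedrock.ts, so the name and the exact patterns here are
+	// assumptions, not the actual implementation):
+	//
+	//   const isThrottlingError = (e: any): boolean =>
+	//       e?.status === 429 ||
+	//       e?.$metadata?.httpStatusCode === 429 ||
+	//       e?.name === "ThrottlingException" ||
+	//       e?.__type === "ThrottlingException" ||
+	//       /throttl|rate limit|too many requests|unable to process your request/i.test(
+	//           typeof e?.message === "string" ? e.message : "",
+	//       )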
+	// -----------------------------------------------------------------------
+
+	describe("Throttling Error Detection (createMessage)", () => {
+		it("should re-throw throttling errors with status 429 for retry", async () => {
 			const throttleError = createMockError({
 				message: "Request failed",
 				status: 429,
 			})
 
-			mockSend.mockRejectedValueOnce(throttleError)
+			mockStreamText.mockImplementation(() => {
+				throw throttleError
+			})
 
-			try {
-				const result = await handler.completePrompt("test")
-				expect(result).toContain("throttled or rate limited")
-			} catch (error) {
-				expect(error.message).toContain("throttled or rate limited")
-			}
+			const generator = handler.createMessage("system", [{ role: "user", content: "test" }])
+
+			await expect(async () => {
+				for await (const _chunk of generator) {
+					// should throw
+				}
+			}).rejects.toThrow("Request failed")
 		})
 
-		it("should detect throttling from AWS SDK $metadata.httpStatusCode", async () => {
+		it("should re-throw throttling errors detected via $metadata.httpStatusCode", async () => {
 			const throttleError = createMockError({
 				message: "Request failed",
 				$metadata: { httpStatusCode: 429 },
 			})
 
-			mockSend.mockRejectedValueOnce(throttleError)
+			mockStreamText.mockImplementation(() => {
+				throw throttleError
+			})
 
-			try {
-				const result = await handler.completePrompt("test")
-				expect(result).toContain("throttled or rate limited")
-			} catch (error) {
-				expect(error.message).toContain("throttled or rate limited")
-			}
+			const generator = handler.createMessage("system", [{ role: "user", content: "test" }])
+
+			await expect(async () => {
+				for await (const _chunk of generator) {
+					// should throw
+				}
+			}).rejects.toThrow("Request failed")
 		})
 
-		it("should detect throttling from ThrottlingException name", async () => {
+		it("should re-throw ThrottlingException by name", async () => {
 			const throttleError = createMockError({
 				message: "Request failed",
 				name: "ThrottlingException",
 			})
 
-			mockSend.mockRejectedValueOnce(throttleError)
-
-			try {
-				const result = await handler.completePrompt("test")
-				expect(result).toContain("throttled or rate limited")
-			} catch (error) {
-				expect(error.message).toContain("throttled or rate limited")
-			}
-		})
-
-		it("should detect throttling from __type field", async () => {
-			const throttleError = createMockError({
-				message: "Request failed",
-				__type: "ThrottlingException",
+			mockStreamText.mockImplementation(() => {
+				throw throttleError
 			})
 
-			mockSend.mockRejectedValueOnce(throttleError)
+			const generator = handler.createMessage("system", [{ role: "user", content: "test" }])
 
-			try {
-				const result = await handler.completePrompt("test")
-				expect(result).toContain("throttled or rate limited")
-			} catch (error) {
-				expect(error.message).toContain("throttled or rate limited")
-			}
+			await expect(async () => {
+				for await (const _chunk of generator) {
+					// should throw
+				}
+			}).rejects.toThrow("Request failed")
 		})
 
-		it("should detect throttling from 'Bedrock is unable to process your request' message", async () => {
+		it("should re-throw 'Bedrock is unable to process your request' as throttling", async () => {
 			const throttleError = createMockError({
 				message: "Bedrock is unable to process your request",
 			})
 
-			mockSend.mockRejectedValueOnce(throttleError)
+			mockStreamText.mockImplementation(() => {
+				throw throttleError
+			})
+
+			const generator = handler.createMessage("system", [{ role: "user", content: "test" }])
 
-			try {
-				const result = await handler.completePrompt("test")
-				expect(result).toContain("throttled or rate limited")
-			} catch (error) {
-				expect(error.message).toMatch(/throttled or rate limited/)
-			}
+			await expect(async () => {
+				for await (const _chunk of generator) {
+					// should throw
+				}
+			}).rejects.toThrow("Bedrock is unable to process your request")
 		})
 
 		it("should detect throttling from various message patterns", async () => {
-			const throttlingMessages = [
-				"Request throttled",
-				"Rate limit exceeded",
-				"Too many requests",
-				"Service unavailable due to high demand",
-				"Server is overloaded",
-				"System is busy",
-				"Please wait and try again",
-			]
+			const throttlingMessages = ["Request throttled", "Rate limit exceeded", "Too many requests"]
 
 			for (const message of throttlingMessages) {
+				vi.clearAllMocks()
 				const throttleError = createMockError({ message })
-				mockSend.mockRejectedValueOnce(throttleError)
-
-				try {
-					await handler.completePrompt("test")
-					// Should not reach here as completePrompt should throw
-					throw new Error("Expected error to be thrown")
-				} catch (error) {
-					expect(error.message).toContain("throttled or rate limited")
-				}
+
+				mockStreamText.mockImplementation(() => {
+					throw throttleError
+				})
+
+				const localHandler = new AwsBedrockHandler({
+					apiModelId: "anthropic.claude-3-5-sonnet-20241022-v2:0",
+					awsAccessKey: "test-access-key",
+					awsSecretKey: "test-secret-key",
+					awsRegion: "us-east-1",
+				})
+
+				const generator = localHandler.createMessage("system", [{ role: "user", content: "test" }])
+
+				// Throttling errors are re-thrown with original message for retry
+				await expect(async () => {
+					for await (const _chunk of generator) {
+						// should throw
+					}
+				}).rejects.toThrow(message)
 			}
 		})
 
-		it("should display appropriate error information for throttling errors", async () => {
-			const throttlingError = createMockError({
-				message: "Bedrock is unable to process your request",
-				name: "ThrottlingException",
+		it("should prioritize HTTP status 429 over message content for throttling", async () => {
+			const mixedError = createMockError({
+				message: "Some generic error message",
 				status: 429,
-				$metadata: {
-					httpStatusCode: 429,
-					requestId: "12345-abcde-67890",
-					extendedRequestId: "extended-12345",
-					cfId: "cf-12345",
-				},
 			})
 
-			mockSend.mockRejectedValueOnce(throttlingError)
+			mockStreamText.mockImplementation(() => {
+				throw mixedError
+			})
 
-			try {
-				await handler.completePrompt("test")
-				throw new Error("Expected error to be thrown")
-			} catch (error) {
-				// Should contain the main error message
-				expect(error.message).toContain("throttled or rate limited")
-			}
+			const generator = handler.createMessage("system", [{ role: "user", content: "test" }])
+
+			// Because status=429, it's throttling → re-throws original error
+			await expect(async () => {
+				for await (const _chunk of generator) {
+					// should throw
+				}
+			}).rejects.toThrow("Some generic error message")
 		})
-	})
 
-	describe("Service Quota Exceeded Detection", () => {
-		it("should detect service quota exceeded errors", async () => {
-			const quotaError = createMockError({
-				message: "Service quota exceeded for model requests",
+		it("should prioritize ThrottlingException name over message for throttling", async () => {
+			const specificError = createMockError({
+				message: "Some other error occurred",
+				name: "ThrottlingException",
 			})
 
-			mockSend.mockRejectedValueOnce(quotaError)
+			mockStreamText.mockImplementation(() => {
+				throw specificError
+			})
 
-			try {
-				const result = await handler.completePrompt("test")
-				expect(result).toContain("Service quota exceeded")
-			} catch (error) {
-				expect(error.message).toContain("Service quota exceeded")
-			}
+			const generator = handler.createMessage("system", [{ role: "user", content: "test" }])
+
+			// ThrottlingException → re-throws original for retry
+			await expect(async () => {
+				for await (const _chunk of generator) {
+					// should throw
+				}
+			}).rejects.toThrow("Some other error occurred")
 		})
 	})
 
-	describe("Model Not Ready Detection", () => {
-		it("should detect model not ready errors", async () => {
-			const modelError = createMockError({
-				message: "Model is not ready, please try again later",
+	// -----------------------------------------------------------------------
+	// Non-throttling errors (createMessage) are wrapped by handleAiSdkError
+	// -----------------------------------------------------------------------
+
+	describe("Non-throttling errors (createMessage)", () => {
+		it("should wrap non-throttling errors with provider name via handleAiSdkError", async () => {
+			const genericError = createMockError({
+				message: "Something completely unexpected happened",
 			})
 
-			mockSend.mockRejectedValueOnce(modelError)
+			mockStreamText.mockImplementation(() => {
+				throw genericError
+			})
 
-			try {
-				const result = await handler.completePrompt("test")
-				expect(result).toContain("Model is not ready")
-			} catch (error) {
-				expect(error.message).toContain("Model is not ready")
-			}
+			const generator = handler.createMessage("system", [{ role: "user", content: "test" }])
+
+			await expect(async () => {
+				for await (const _chunk of generator) {
+					// should throw
+				}
+			}).rejects.toThrow("Bedrock: Something completely unexpected happened")
 		})
-	})
 
-	describe("Internal Server Error Detection", () => {
-		it("should detect internal server errors", async () => {
-			const serverError = createMockError({
+		it("should preserve status code from non-throttling API errors", async () => {
+			const apiError = createMockError({
 				message: "Internal server error occurred",
+				status: 500,
 			})
 
-			mockSend.mockRejectedValueOnce(serverError)
+			mockStreamText.mockImplementation(() => {
+				throw apiError
+			})
+
+			const generator = handler.createMessage("system", [{ role: "user", content: "test" }])
 
 			try {
-				const result = await handler.completePrompt("test")
-				expect(result).toContain("internal server error")
-			} catch (error) {
-				expect(error.message).toContain("internal server error")
+				for await (const _chunk of generator) {
+					// should throw
+				}
+				throw new Error("Expected error to be thrown")
+			} catch (error: any) {
+				expect(error.message).toContain("Bedrock:")
+				expect(error.message).toContain("Internal server error occurred")
 			}
 		})
-	})
 
-	describe("Token Limit Detection", () => {
-		it("should detect enhanced token limit errors", async () => {
-			const tokenErrors = [
-				"Too many tokens in request",
-				"Token limit exceeded",
-				"Maximum context length reached",
-				"Context length exceeds limit",
-			]
-
-			for (const message of tokenErrors) {
-				const tokenError = createMockError({ message })
-				mockSend.mockRejectedValueOnce(tokenError)
-
-				try {
-					await handler.completePrompt("test")
-					// Should not reach here as completePrompt should throw
-					throw new Error("Expected error to be thrown")
-				} catch (error) {
-					// Either "Too many tokens" for token-specific errors or "throttled" for limit-related errors
-					expect(error.message).toMatch(/Too many tokens|throttled or rate limited/)
+		it("should handle validation errors (token limits) as non-throttling", async () => {
+			const tokenError = createMockError({
+				message: "Too many tokens in request",
+				name: "ValidationException",
+			})
+
+			mockStreamText.mockImplementation(() => {
+				throw tokenError
+			})
+
+			const generator = handler.createMessage("system", [{ role: "user", content: "test" }])
+
+			await expect(async () => {
+				for await (const _chunk of generator) {
+					// should throw
 				}
-			}
+			}).rejects.toThrow("Bedrock: Too many tokens in request")
 		})
 	})
 
+	// -----------------------------------------------------------------------
+	// Streaming context: errors mid-stream
+	// -----------------------------------------------------------------------
+
 	describe("Streaming Context Error Handling", () => {
-		it("should handle throttling errors in streaming context", async () => {
+		it("should re-throw throttling errors that occur mid-stream", async () => {
 			const throttleError = createMockError({
 				message: "Bedrock is unable to process your request",
 				status: 429,
 			})
 
-			const mockStream = {
-				[Symbol.asyncIterator]() {
-					return {
-						async next() {
-							throw throttleError
-						},
-					}
-				},
+			// Mock streamText to return an object whose fullStream throws mid-iteration
+			async function* failingStream() {
+				yield { type: "text-delta" as const, textDelta: "partial" }
+				throw throttleError
 			}
 
-			mockSend.mockResolvedValueOnce({ stream: mockStream })
+			mockStreamText.mockReturnValue({
+				fullStream: failingStream(),
+				usage: Promise.resolve({ inputTokens: 0, outputTokens: 0 }),
+				providerMetadata: Promise.resolve({}),
+			})
 
 			const generator = handler.createMessage("system", [{ role: "user", content: "test" }])
 
-			// For throttling errors, it should throw immediately without yielding chunks
-			// This allows the retry mechanism to catch and handle it
 			await expect(async () => {
-				for await (const chunk of generator) {
-					// Should not yield any chunks for throttling errors
+				for await (const _chunk of generator) {
+					// may yield partial text before throwing
 				}
 			}).rejects.toThrow("Bedrock is unable to process your request")
 		})
 
-		it("should yield error chunks for non-throttling errors in streaming context", async () => {
+		it("should wrap non-throttling errors that occur mid-stream via handleAiSdkError", async () => {
 			const genericError = createMockError({
 				message: "Some other error",
 				status: 500,
 			})
 
-			const mockStream = {
-				[Symbol.asyncIterator]() {
-					return {
-						async next() {
-							throw genericError
-						},
-					}
-				},
+			async function* failingStream() {
+				yield { type: "text-delta" as const, textDelta: "partial" }
+				throw genericError
 			}
 
-			mockSend.mockResolvedValueOnce({ stream: mockStream })
+			mockStreamText.mockReturnValue({
+				fullStream: failingStream(),
+				usage: Promise.resolve({ inputTokens: 0, outputTokens: 0 }),
+				providerMetadata: Promise.resolve({}),
+			})
 
 			const generator = handler.createMessage("system", [{ role: "user", content: "test" }])
 
-			const chunks: any[] = []
-			try {
-				for await (const chunk of generator) {
-					chunks.push(chunk)
+			await expect(async () => {
+				for await (const _chunk of generator) {
+					// should throw
 				}
-			} catch (error) {
-				// Expected to throw after yielding chunks
-			}
-
-			// Should have yielded error chunks before throwing for non-throttling errors
-			expect(
-				chunks.some((chunk) => chunk.type === "text" && chunk.text && chunk.text.includes("Some other error")),
-			).toBe(true)
+			}).rejects.toThrow("Bedrock: Some other error")
 		})
 	})
 
-	describe("Error Priority and Specificity", () => {
-		it("should prioritize HTTP status codes over message patterns", async () => {
-			// Error with both 429 status and generic message should be detected as throttling
-			const mixedError = createMockError({
-				message: "Some generic error message",
-				status: 429,
-			})
+	// -----------------------------------------------------------------------
+	// completePrompt errors — all go through handleAiSdkError (no throttle check)
+	// -----------------------------------------------------------------------
 
-			mockSend.mockRejectedValueOnce(mixedError)
+	describe("completePrompt error handling", () => {
+		it("should wrap errors with provider name for completePrompt", async () => {
+			mockGenerateText.mockRejectedValueOnce(new Error("Bedrock API failure"))
 
-			try {
-				const result = await handler.completePrompt("test")
-				expect(result).toContain("throttled or rate limited")
-			} catch (error) {
-				expect(error.message).toContain("throttled or rate limited")
-			}
+			await expect(handler.completePrompt("test")).rejects.toThrow("Bedrock: Bedrock API failure")
 		})
 
-		it("should prioritize AWS error types over message patterns", async () => {
-			// Error with ThrottlingException name but different message should still be throttling
-			const specificError = createMockError({
-				message: "Some other error occurred",
-				name: "ThrottlingException",
+		it("should wrap throttling-pattern errors with provider name for completePrompt", async () => {
+			const throttleError = createMockError({
+				message: "Bedrock is unable to process your request",
+				status: 429,
 			})
 
-			mockSend.mockRejectedValueOnce(specificError)
+			mockGenerateText.mockRejectedValueOnce(throttleError)
 
-			try {
-				const result = await handler.completePrompt("test")
-				expect(result).toContain("throttled or rate limited")
-			} catch (error) {
-				expect(error.message).toContain("throttled or rate limited")
-			}
+			// completePrompt does NOT have the throttle-rethrow path; it always uses handleAiSdkError
+			await expect(handler.completePrompt("test")).rejects.toThrow(
+				"Bedrock: Bedrock is unable to process your request",
+			)
 		})
-	})
 
-	describe("Unknown Error Fallback", () => {
-		it("should still show unknown error for truly unrecognized errors", async () => {
-			const unknownError = createMockError({
-				message: "Something completely unexpected happened",
-			})
+		it("should handle concurrent generateText failures", async () => {
+			const error = new Error("API failure")
+			mockGenerateText.mockRejectedValue(error)
 
-			mockSend.mockRejectedValueOnce(unknownError)
+			const promises = Array.from({ length: 5 }, () => handler.completePrompt("test"))
+			const results = await Promise.allSettled(promises)
 
-			try {
-				const result = await handler.completePrompt("test")
-				expect(result).toContain("Unknown Error")
-			} catch (error) {
-				expect(error.message).toContain("Unknown Error")
-			}
+			results.forEach((result) => {
+				expect(result.status).toBe("rejected")
+				if (result.status === "rejected") {
+					expect(result.reason.message).toContain("Bedrock:")
+				}
+			})
 		})
-	})
 
-	describe("Enhanced Error Throw for Retry System", () => {
-		it("should throw enhanced error messages for completePrompt to display in retry system", async () => {
-			const throttlingError = createMockError({
-				message: "Too many tokens, rate limited",
-				status: 429,
-				$metadata: {
-					httpStatusCode: 429,
-					requestId: "test-request-id-12345",
-				},
+		it("should preserve status code from API call errors in completePrompt", async () => {
+			const apiError = createMockError({
+				message: "Service unavailable",
+				status: 503,
 			})
 
-			mockSend.mockRejectedValueOnce(throttlingError)
+
+			mockGenerateText.mockRejectedValueOnce(apiError)
 
 			try {
 				await handler.completePrompt("test")
 				throw new Error("Expected error to be thrown")
-			} catch (error) {
-				// Should contain the verbose message template
-				expect(error.message).toContain("Request was throttled or rate limited")
-				// Should preserve original error properties
-				expect((error as any).status).toBe(429)
-				expect((error as any).$metadata.requestId).toBe("test-request-id-12345")
+			} catch (error: any) {
+				expect(error.message).toContain("Bedrock:")
+				expect(error.message).toContain("Service unavailable")
 			}
 		})
+	})
 
-		it("should throw enhanced error messages for createMessage streaming to display in retry system", async () => {
-			const tokenError = createMockError({
-				message: "Too many tokens in request",
-				name: "ValidationException",
-				$metadata: {
-					httpStatusCode: 400,
-					requestId: "token-error-id-67890",
-					extendedRequestId: "extended-12345",
-				},
-			})
-
-			const mockStream = {
-				[Symbol.asyncIterator]() {
-					return {
-						async next() {
-							throw tokenError
-						},
-					}
-				},
-			}
+	// -----------------------------------------------------------------------
+	// Telemetry
+	// -----------------------------------------------------------------------
+
+	describe("Error telemetry", () => {
+		it("should capture telemetry for createMessage errors", async () => {
+			mockStreamText.mockImplementation(() => {
+				throw new Error("Stream failure")
+			})
 
-			mockSend.mockResolvedValueOnce({ stream: mockStream })
+			const generator = handler.createMessage("system", [{ role: "user", content: "test" }])
 
-			try {
-				const stream = handler.createMessage("system", [{ role: "user", content: "test" }])
-				for await (const chunk of stream) {
-					// Should not reach here as it should throw an error
+			await expect(async () => {
+				for await (const _chunk of generator) {
+					// should throw
 				}
-				throw new Error("Expected error to be thrown")
-			} catch (error) {
-				// Should contain error codes (note: this will be caught by the non-throttling error path)
-				expect(error.message).toContain("Too many tokens")
-				// Should preserve original error properties
-				expect(error.name).toBe("ValidationException")
-				expect((error as any).$metadata.requestId).toBe("token-error-id-67890")
-			}
+			}).rejects.toThrow()
+
+			expect(mockCaptureException).toHaveBeenCalled()
 		})
-	})
 
-	describe("Edge Case Test Coverage", () => {
-		it("should handle concurrent throttling errors correctly", async () => {
-			const throttlingError = createMockError({
-				message: "Bedrock is unable to process your request",
+		it("should capture telemetry for completePrompt errors", async () => {
+			mockGenerateText.mockRejectedValueOnce(new Error("Generate failure"))
+
+			await expect(handler.completePrompt("test")).rejects.toThrow()
+
+			expect(mockCaptureException).toHaveBeenCalled()
+		})
+
+		it("should capture telemetry for throttling errors too", async () => {
+			const throttleError = createMockError({
+				message: "Rate limit exceeded",
 				status: 429,
 			})
 
-			// Setup multiple concurrent requests that will all fail with throttling
-			mockSend.mockRejectedValue(throttlingError)
-
-			// Execute multiple concurrent requests
-			const promises = Array.from({ length: 5 }, () => handler.completePrompt("test"))
+			mockStreamText.mockImplementation(() => {
+				throw throttleError
+			})
 
-			// All should throw with throttling error
-			const results = await Promise.allSettled(promises)
+			const generator = handler.createMessage("system", [{ role: "user", content: "test" }])
 
-			results.forEach((result) => {
-				expect(result.status).toBe("rejected")
-				if (result.status === "rejected") {
-					expect(result.reason.message).toContain("throttled or rate limited")
+			await expect(async () => {
+				for await (const _chunk of generator) {
+					// should throw
 				}
-			})
+			}).rejects.toThrow()
+
+			// Telemetry is captured even for throttling errors
+			expect(mockCaptureException).toHaveBeenCalled()
 		})
+	})
 
-		it("should handle mixed error scenarios with both throttling and other indicators", async () => {
-			// Error with both 429 status (throttling) and validation error message
-			const mixedError = createMockError({
-				message: "ValidationException: Your input is invalid, but also rate limited",
-				name: "ValidationException",
-				status: 429,
-				$metadata: {
-					httpStatusCode: 429,
-					requestId: "mixed-error-id",
-				},
-			})
+	// -----------------------------------------------------------------------
+	// Edge cases
+	// -----------------------------------------------------------------------
 
-			mockSend.mockRejectedValueOnce(mixedError)
+	describe("Edge Case Test Coverage", () => {
+		it("should handle non-Error objects thrown by generateText", async () => {
+			mockGenerateText.mockRejectedValueOnce("string error")
 
-			try {
-				await handler.completePrompt("test")
-			} catch (error) {
-				// Should be treated as throttling due to 429 status taking priority
-				expect(error.message).toContain("throttled or rate limited")
-				// Should still preserve metadata
-				expect((error as any).$metadata?.requestId).toBe("mixed-error-id")
-			}
+			await expect(handler.completePrompt("test")).rejects.toThrow("Bedrock: string error")
 		})
 
-		it("should handle rapid successive retries in streaming context", async () => {
-			const throttlingError = createMockError({
-				message: "ThrottlingException",
-				name: "ThrottlingException",
+		it("should handle non-Error objects thrown by streamText", async () => {
+			mockStreamText.mockImplementation(() => {
+				throw "string error"
 			})
 
-			// Mock stream that throws immediately
-			const mockStream = {
-				// eslint-disable-next-line require-yield
-				[Symbol.asyncIterator]: async function* () {
-					throw throttlingError
-				},
-			}
-
-			mockSend.mockResolvedValueOnce({ stream: mockStream })
-
-			const messages: Anthropic.Messages.MessageParam[] = [{ role: "user", content: "test" }]
+			const generator = handler.createMessage("system", [{ role: "user", content: "test" }])
 
-			try {
-				// Should throw immediately without yielding any chunks
-				const stream = handler.createMessage("", messages)
-				const chunks = []
-				for await (const chunk of stream) {
-					chunks.push(chunk)
+			// Non-Error values are not detected as throttling → handleAiSdkError path
+			await expect(async () => {
+				for await (const _chunk of generator) {
+					// should throw
 				}
-				// Should not reach here
-				expect(chunks).toHaveLength(0)
-			} catch (error) {
-				// Error should be thrown immediately for retry mechanism
-				// The error might be a TypeError if the stream iterator fails
-				expect(error).toBeDefined()
-				// The important thing is that it throws immediately without yielding chunks
-			}
+			}).rejects.toThrow("Bedrock: string error")
 		})
 
-		it("should validate error properties exist before accessing them", async () => {
-			// Error with unusual structure
-			const unusualError = {
-				message: "Error with unusual structure",
-				// Missing typical properties like name, status, etc.
-			}
-
-			mockSend.mockRejectedValueOnce(unusualError)
+		it("should handle errors with unusual structure gracefully", async () => {
+			const unusualError = { message: "Error with unusual structure" }
+			mockGenerateText.mockRejectedValueOnce(unusualError)
 
 			try {
 				await handler.completePrompt("test")
-			} catch (error) {
-				// Should handle gracefully without accessing undefined properties
-				expect(error.message).toContain("Unknown Error")
-				// Should not have undefined values in the error message
+				throw new Error("Expected error to be thrown")
+			} catch (error: any) {
+				// handleAiSdkError wraps with "Bedrock: ..."
+				expect(error.message).toContain("Bedrock:")
 				expect(error.message).not.toContain("undefined")
 			}
 		})
+
+		it("should handle concurrent throttling errors in streaming context", async () => {
+			const throttlingError = createMockError({
+				message: "Bedrock is unable to process your request",
+				status: 429,
+			})
+
+			mockStreamText.mockImplementation(() => {
+				throw throttlingError
+			})
+
+			// Execute multiple concurrent streaming requests
+			const promises = Array.from({ length: 3 }, async () => {
+				const localHandler = new AwsBedrockHandler({
+					apiModelId: "anthropic.claude-3-5-sonnet-20241022-v2:0",
+					awsAccessKey: "test-access-key",
+					awsSecretKey: "test-secret-key",
+					awsRegion: "us-east-1",
+				})
+				const gen = localHandler.createMessage("system", [{ role: "user", content: "test" }])
+				for await (const _chunk of gen) {
+					// should throw
+				}
+			})
+
+			const results = await Promise.allSettled(promises)
+			results.forEach((result) => {
+				expect(result.status).toBe("rejected")
+				if (result.status === "rejected") {
+					// Throttling errors are re-thrown with original message
+					expect(result.reason.message).toBe("Bedrock is unable to process your request")
+				}
+			})
+		})
 	})
 })
diff --git a/src/api/providers/__tests__/bedrock-inference-profiles.spec.ts b/src/api/providers/__tests__/bedrock-inference-profiles.spec.ts
index dee3af3b916..131e462f09a 100644
--- a/src/api/providers/__tests__/bedrock-inference-profiles.spec.ts
+++ b/src/api/providers/__tests__/bedrock-inference-profiles.spec.ts
@@ -4,18 +4,6 @@ import { AWS_INFERENCE_PROFILE_MAPPING } from "@roo-code/types"
 import { AwsBedrockHandler } from "../bedrock"
 import { ApiHandlerOptions } from "../../../shared/api"
 
-// Mock AWS SDK
-vitest.mock("@aws-sdk/client-bedrock-runtime", () => {
-	return {
-		BedrockRuntimeClient: vitest.fn().mockImplementation(() => ({
-			send: vitest.fn(),
-			config: { region: "us-east-1" },
-		})),
-		ConverseCommand: vitest.fn(),
-		ConverseStreamCommand: vitest.fn(),
-	}
-})
-
 describe("Amazon Bedrock Inference Profiles", () => {
 	// Helper function to create a handler with specific options
 	const createHandler = (options: Partial<ApiHandlerOptions> = {}) => {
diff --git a/src/api/providers/__tests__/bedrock-invokedModelId.spec.ts b/src/api/providers/__tests__/bedrock-invokedModelId.spec.ts
index fe16ea89eb6..63322d988ed 100644
--- a/src/api/providers/__tests__/bedrock-invokedModelId.spec.ts
+++ b/src/api/providers/__tests__/bedrock-invokedModelId.spec.ts
@@ -1,350 +1,198 @@
 // npx vitest run src/api/providers/__tests__/bedrock-invokedModelId.spec.ts
 
-import { ApiHandlerOptions } from "../../../shared/api"
-
-import { AwsBedrockHandler, StreamEvent } from "../bedrock"
+// Mock TelemetryService before other imports
+vi.mock("@roo-code/telemetry", () => ({
+	TelemetryService: {
+		instance: {
+			captureException: vi.fn(),
+		},
+	},
+}))
 
-// Mock AWS SDK credential providers and Bedrock client
-vitest.mock("@aws-sdk/credential-providers", () => ({
-
fromIni: vitest.fn().mockReturnValue({ +// Mock AWS SDK credential providers +vi.mock("@aws-sdk/credential-providers", () => ({ + fromIni: vi.fn().mockReturnValue({ accessKeyId: "profile-access-key", secretAccessKey: "profile-secret-key", }), })) -// Mock Smithy client -vitest.mock("@smithy/smithy-client", () => ({ - throwDefaultError: vitest.fn(), +// Use vi.hoisted to define mock functions for AI SDK +const { mockStreamText, mockGenerateText } = vi.hoisted(() => ({ + mockStreamText: vi.fn(), + mockGenerateText: vi.fn(), })) -// Create a mock send function that we can reference -const mockSend = vitest.fn().mockImplementation(async () => { +vi.mock("ai", async (importOriginal) => { + const actual = await importOriginal() return { - $metadata: { - httpStatusCode: 200, - requestId: "mock-request-id", - }, - stream: { - [Symbol.asyncIterator]: async function* () { - yield { - metadata: { - usage: { - inputTokens: 100, - outputTokens: 200, - }, - }, - } - }, - }, + ...actual, + streamText: mockStreamText, + generateText: mockGenerateText, } }) -// Mock AWS SDK modules -vitest.mock("@aws-sdk/client-bedrock-runtime", () => { - return { - BedrockRuntimeClient: vitest.fn().mockImplementation(() => ({ - send: mockSend, - config: { region: "us-east-1" }, - middlewareStack: { - clone: () => ({ resolve: () => {} }), - use: () => {}, - }, - })), - ConverseStreamCommand: vitest.fn((params) => ({ - ...params, - input: params, - middlewareStack: { - clone: () => ({ resolve: () => {} }), - use: () => {}, - }, - })), - ConverseCommand: vitest.fn((params) => ({ - ...params, - input: params, - middlewareStack: { - clone: () => ({ resolve: () => {} }), - use: () => {}, - }, - })), - } -}) +vi.mock("@ai-sdk/amazon-bedrock", () => ({ + createAmazonBedrock: vi.fn(() => vi.fn(() => ({ modelId: "test", provider: "bedrock" }))), +})) + +import { AwsBedrockHandler } from "../bedrock" +import { bedrockModels } from "@roo-code/types" describe("AwsBedrockHandler with invokedModelId", () => { beforeEach(() => { - vitest.clearAllMocks() + vi.clearAllMocks() }) - // Helper function to create a mock async iterable stream - function createMockStream(events: StreamEvent[]) { - return { - [Symbol.asyncIterator]: async function* () { - for (const event of events) { - yield event - } - // Always yield a metadata event at the end - yield { - metadata: { - usage: { - inputTokens: 100, - outputTokens: 200, + /** + * Helper: set up mockStreamText to return a stream whose resolved + * `providerMetadata` contains the given `invokedModelId` in the + * `bedrock.trace.promptRouter` path. + */ + function setupMockStreamWithInvokedModelId(invokedModelId?: string) { + async function* mockFullStream() { + yield { type: "text-delta", text: "Hello" } + yield { type: "text-delta", text: ", world!" } + } + + const providerMetadata = invokedModelId + ? 
{ + bedrock: { + trace: { + promptRouter: { + invokedModelId, + }, }, }, } - }, - } + : {} + + mockStreamText.mockReturnValue({ + fullStream: mockFullStream(), + usage: Promise.resolve({ inputTokens: 100, outputTokens: 200 }), + providerMetadata: Promise.resolve(providerMetadata), + }) } - it("should update costModelConfig when invokedModelId is present in the stream", async () => { - // Create a handler with a custom ARN - const mockOptions: ApiHandlerOptions = { + it("should update costModelConfig when invokedModelId is present in providerMetadata", async () => { + // Create a handler with a custom ARN (prompt router) + const handler = new AwsBedrockHandler({ awsAccessKey: "test-access-key", awsSecretKey: "test-secret-key", awsRegion: "us-east-1", awsCustomArn: "arn:aws:bedrock:us-west-2:123456789:default-prompt-router/anthropic.claude:1", - } - - const handler = new AwsBedrockHandler(mockOptions) + }) - // Verify that getModel returns the updated model info + // The default prompt router model should use sonnet pricing (inputPrice: 3) const initialModel = handler.getModel() - //the default prompt router model has an input price of 3. After the stream is handled it should be updated to 8 expect(initialModel.info.inputPrice).toBe(3) - // Create a spy on the getModel - const getModelByIdSpy = vitest.spyOn(handler, "getModelById") - - // Mock the stream to include an event with invokedModelId and usage metadata - mockSend.mockImplementationOnce(async () => { - return { - stream: createMockStream([ - // First event with invokedModelId and usage metadata - { - trace: { - promptRouter: { - invokedModelId: - "arn:aws:bedrock:us-west-2:699475926481:inference-profile/us.anthropic.claude-3-opus-20240229-v1:0", - usage: { - inputTokens: 150, - outputTokens: 250, - cacheReadTokens: 0, - cacheWriteTokens: 0, - }, - }, - }, - }, - { - contentBlockStart: { - start: { - text: "Hello", - }, - contentBlockIndex: 0, - }, - }, - { - contentBlockDelta: { - delta: { - text: ", world!", - }, - contentBlockIndex: 0, - }, - }, - ]), - } - }) + // Spy on getModelById to verify the invoked model is looked up + const getModelByIdSpy = vi.spyOn(handler, "getModelById") - // Create a message generator - const messageGenerator = handler.createMessage("system prompt", [{ role: "user", content: "user message" }]) + // Set up stream to include an invokedModelId pointing to Claude 3 Opus + setupMockStreamWithInvokedModelId( + "arn:aws:bedrock:us-west-2:699475926481:inference-profile/us.anthropic.claude-3-opus-20240229-v1:0", + ) - // Collect all yielded events to verify usage events + // Consume the generator const events = [] - for await (const event of messageGenerator) { + for await (const event of handler.createMessage("system prompt", [{ role: "user", content: "user message" }])) { events.push(event) } - // Verify that getModelById was called with the id, not the full arn + // Verify that getModelById was called with the parsed model id and type expect(getModelByIdSpy).toHaveBeenCalledWith("anthropic.claude-3-opus-20240229-v1:0", "inference-profile") - // Verify that getModel returns the updated model info + // After processing, getModel should return the invoked model's pricing (Opus: inputPrice 15) const costModel = handler.getModel() - //expect(costModel.id).toBe("anthropic.claude-3-5-sonnet-20240620-v1:0") expect(costModel.info.inputPrice).toBe(15) - // Verify that a usage event was emitted after updating the costModelConfig - const usageEvents = events.filter((event) => event.type === "usage") + // Verify that a 
usage event was emitted + const usageEvents = events.filter((e: any) => e.type === "usage") expect(usageEvents.length).toBeGreaterThanOrEqual(1) - // The last usage event should have the token counts from the metadata - const lastUsageEvent = usageEvents[usageEvents.length - 1] - // Expect the usage event to include all token information + // The usage event should contain the token counts + const lastUsageEvent = usageEvents[usageEvents.length - 1] as any expect(lastUsageEvent).toMatchObject({ type: "usage", inputTokens: 100, outputTokens: 200, - // Cache tokens may be present with default values - cacheReadTokens: expect.any(Number), - cacheWriteTokens: expect.any(Number), }) }) it("should not update costModelConfig when invokedModelId is not present", async () => { - // Create a handler with default settings - const mockOptions: ApiHandlerOptions = { + const handler = new AwsBedrockHandler({ apiModelId: "anthropic.claude-3-5-sonnet-20241022-v2:0", awsAccessKey: "test-access-key", awsSecretKey: "test-secret-key", awsRegion: "us-east-1", - } - - const handler = new AwsBedrockHandler(mockOptions) + }) - // Store the initial model configuration const initialModelConfig = handler.getModel() expect(initialModelConfig.id).toBe("anthropic.claude-3-5-sonnet-20241022-v2:0") - // Mock the stream without an invokedModelId event - mockSend.mockImplementationOnce(async () => { - return { - stream: createMockStream([ - // Some content events but no invokedModelId - { - contentBlockStart: { - start: { - text: "Hello", - }, - contentBlockIndex: 0, - }, - }, - { - contentBlockDelta: { - delta: { - text: ", world!", - }, - contentBlockIndex: 0, - }, - }, - ]), - } - }) - - // Create a message generator - const messageGenerator = handler.createMessage("system prompt", [{ role: "user", content: "user message" }]) + // Set up stream WITHOUT an invokedModelId + setupMockStreamWithInvokedModelId(undefined) // Consume the generator - for await (const _ of messageGenerator) { - // Just consume the messages + for await (const _ of handler.createMessage("system prompt", [{ role: "user", content: "user message" }])) { + // Just consume } - // Verify that getModel returns the original model info (unchanged) + // Model should remain unchanged const costModel = handler.getModel() expect(costModel.id).toBe("anthropic.claude-3-5-sonnet-20241022-v2:0") - expect(costModel).toEqual(initialModelConfig) + expect(costModel.info.inputPrice).toBe(initialModelConfig.info.inputPrice) }) it("should handle invalid invokedModelId format gracefully", async () => { - // Create a handler with default settings - const mockOptions: ApiHandlerOptions = { + const handler = new AwsBedrockHandler({ apiModelId: "anthropic.claude-3-5-sonnet-20241022-v2:0", awsAccessKey: "test-access-key", awsSecretKey: "test-secret-key", awsRegion: "us-east-1", - } - - const handler = new AwsBedrockHandler(mockOptions) - - // Mock the stream with an invalid invokedModelId - mockSend.mockImplementationOnce(async () => { - return { - stream: createMockStream([ - // Event with invalid invokedModelId format - { - trace: { - promptRouter: { - invokedModelId: "invalid-format-not-an-arn", - }, - }, - }, - // Some content events - { - contentBlockStart: { - start: { - text: "Hello", - }, - contentBlockIndex: 0, - }, - }, - ]), - } }) - // Create a message generator - const messageGenerator = handler.createMessage("system prompt", [{ role: "user", content: "user message" }]) + // Set up stream with an invalid (non-ARN) invokedModelId + 
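+ // (the handler is expected to fall back to the configured model rather than throw)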
setupMockStreamWithInvokedModelId("invalid-format-not-an-arn") - // Consume the generator - for await (const _ of messageGenerator) { - // Just consume the messages + // Consume the generator — should not throw + for await (const _ of handler.createMessage("system prompt", [{ role: "user", content: "user message" }])) { + // Just consume } - // Verify that getModel returns the original model info + // Model should remain unchanged (the parseArn call should fail gracefully) const costModel = handler.getModel() expect(costModel.id).toBe("anthropic.claude-3-5-sonnet-20241022-v2:0") }) - it("should handle errors during invokedModelId processing", async () => { - // Create a handler with default settings - const mockOptions: ApiHandlerOptions = { - apiModelId: "anthropic.claude-3-5-sonnet-20241022-v2:0", + it("should use the invoked model's pricing for totalCost calculation", async () => { + const handler = new AwsBedrockHandler({ awsAccessKey: "test-access-key", awsSecretKey: "test-secret-key", awsRegion: "us-east-1", - } - - const handler = new AwsBedrockHandler(mockOptions) - - // Mock the stream with a valid invokedModelId - mockSend.mockImplementationOnce(async () => { - return { - stream: createMockStream([ - // Event with valid invokedModelId - { - trace: { - promptRouter: { - invokedModelId: - "arn:aws:bedrock:us-east-1:123456789:foundation-model/anthropic.claude-3-sonnet-20240229-v1:0", - }, - }, - }, - ]), - } - }) - - // Mock getModel to throw an error when called with the model name - vitest.spyOn(handler, "getModel").mockImplementation((modelName?: string) => { - if (modelName === "anthropic.claude-3-sonnet-20240229-v1:0") { - throw new Error("Test error during model lookup") - } - - // Default return value for initial call - return { - id: "anthropic.claude-3-5-sonnet-20241022-v2:0", - info: { - maxTokens: 4096, - contextWindow: 128_000, - supportsPromptCache: false, - supportsImages: true, - }, - } + awsCustomArn: "arn:aws:bedrock:us-west-2:123456789:default-prompt-router/anthropic.claude:1", }) - // Create a message generator - const messageGenerator = handler.createMessage("system prompt", [{ role: "user", content: "user message" }]) + // Set up stream to include Opus as the invoked model + setupMockStreamWithInvokedModelId( + "arn:aws:bedrock:us-west-2:699475926481:foundation-model/anthropic.claude-3-opus-20240229-v1:0", + ) - // Consume the generator - for await (const _ of messageGenerator) { - // Just consume the messages + const events = [] + for await (const event of handler.createMessage("system prompt", [{ role: "user", content: "user message" }])) { + events.push(event) } - // Verify that getModel returns the original model info - const costModel = handler.getModel() - expect(costModel.id).toBe("anthropic.claude-3-5-sonnet-20241022-v2:0") + const usageEvent = events.find((e: any) => e.type === "usage") as any + expect(usageEvent).toBeDefined() + + // Calculate expected cost based on Opus pricing ($15 / 1M input, $75 / 1M output) + const opusInfo = bedrockModels["anthropic.claude-3-opus-20240229-v1:0"] + const expectedCost = + (100 * (opusInfo.inputPrice ?? 0)) / 1_000_000 + (200 * (opusInfo.outputPrice ?? 
0)) / 1_000_000 + + expect(usageEvent.totalCost).toBeCloseTo(expectedCost, 10) }) }) diff --git a/src/api/providers/__tests__/bedrock-native-tools.spec.ts b/src/api/providers/__tests__/bedrock-native-tools.spec.ts index e95b2c34b69..74439f00d5b 100644 --- a/src/api/providers/__tests__/bedrock-native-tools.spec.ts +++ b/src/api/providers/__tests__/bedrock-native-tools.spec.ts @@ -1,3 +1,12 @@ +// Mock TelemetryService before other imports +vi.mock("@roo-code/telemetry", () => ({ + TelemetryService: { + instance: { + captureException: vi.fn(), + }, + }, +})) + // Mock AWS SDK credential providers vi.mock("@aws-sdk/credential-providers", () => { const mockFromIni = vi.fn().mockReturnValue({ @@ -7,29 +16,28 @@ vi.mock("@aws-sdk/credential-providers", () => { return { fromIni: mockFromIni } }) -// Mock BedrockRuntimeClient and ConverseStreamCommand -const mockSend = vi.fn() +// Use vi.hoisted to define mock functions for AI SDK +const { mockStreamText, mockGenerateText } = vi.hoisted(() => ({ + mockStreamText: vi.fn(), + mockGenerateText: vi.fn(), +})) -vi.mock("@aws-sdk/client-bedrock-runtime", () => { +vi.mock("ai", async (importOriginal) => { + const actual = await importOriginal() return { - BedrockRuntimeClient: vi.fn().mockImplementation(() => ({ - send: mockSend, - config: { region: "us-east-1" }, - })), - ConverseStreamCommand: vi.fn((params) => ({ - ...params, - input: params, - })), - ConverseCommand: vi.fn(), + ...actual, + streamText: mockStreamText, + generateText: mockGenerateText, } }) +vi.mock("@ai-sdk/amazon-bedrock", () => ({ + createAmazonBedrock: vi.fn(() => vi.fn(() => ({ modelId: "test", provider: "bedrock" }))), +})) + import { AwsBedrockHandler } from "../bedrock" -import { ConverseStreamCommand } from "@aws-sdk/client-bedrock-runtime" import type { ApiHandlerCreateMessageMetadata } from "../../index" -const mockConverseStreamCommand = vi.mocked(ConverseStreamCommand) - // Test tool definitions in OpenAI format const testTools = [ { @@ -63,542 +71,529 @@ const testTools = [ }, ] -describe("AwsBedrockHandler Native Tool Calling", () => { +/** + * Helper: set up mockStreamText to return a simple text-delta stream. + */ +function setupMockStreamText() { + async function* mockFullStream() { + yield { type: "text-delta", text: "Response text" } + } + mockStreamText.mockReturnValue({ + fullStream: mockFullStream(), + usage: Promise.resolve({ inputTokens: 100, outputTokens: 50 }), + providerMetadata: Promise.resolve({}), + }) +} + +/** + * Helper: set up mockStreamText to return a stream with tool-call events. 
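+ * The yielded parts mirror the AI SDK fullStream shape for tool calls:
+ * tool-input-start, a tool-input-delta carrying the JSON arguments,
+ * then tool-input-end.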
+ */ +function setupMockStreamTextWithToolCall() { + async function* mockFullStream() { + yield { + type: "tool-input-start", + id: "tool-123", + toolName: "read_file", + } + yield { + type: "tool-input-delta", + id: "tool-123", + delta: '{"path": "/test.txt"}', + } + yield { + type: "tool-input-end", + id: "tool-123", + } + } + mockStreamText.mockReturnValue({ + fullStream: mockFullStream(), + usage: Promise.resolve({ inputTokens: 100, outputTokens: 50 }), + providerMetadata: Promise.resolve({}), + }) +} + +describe("AwsBedrockHandler Native Tool Calling (AI SDK)", () => { let handler: AwsBedrockHandler beforeEach(() => { vi.clearAllMocks() - // Create handler with a model that supports native tools handler = new AwsBedrockHandler({ apiModelId: "anthropic.claude-3-5-sonnet-20241022-v2:0", awsAccessKey: "test-access-key", awsSecretKey: "test-secret-key", awsRegion: "us-east-1", }) - - // Mock the stream response - mockSend.mockResolvedValue({ - stream: [], - }) }) - describe("convertToolsForBedrock", () => { - it("should convert OpenAI tools to Bedrock format", () => { - // Access private method - const convertToolsForBedrock = (handler as any).convertToolsForBedrock.bind(handler) - - const bedrockTools = convertToolsForBedrock(testTools) - - expect(bedrockTools).toHaveLength(2) - - // Check structure and key properties (normalizeToolSchema adds additionalProperties: false) - const tool = bedrockTools[0] - expect(tool.toolSpec.name).toBe("read_file") - expect(tool.toolSpec.description).toBe("Read a file from the filesystem") - expect(tool.toolSpec.inputSchema.json.type).toBe("object") - expect(tool.toolSpec.inputSchema.json.properties.path.type).toBe("string") - expect(tool.toolSpec.inputSchema.json.properties.path.description).toBe("The path to the file") - expect(tool.toolSpec.inputSchema.json.required).toEqual(["path"]) - // normalizeToolSchema adds additionalProperties: false by default - expect(tool.toolSpec.inputSchema.json.additionalProperties).toBe(false) + describe("tools passed to streamText", () => { + it("should pass converted tools to streamText when tools are provided", async () => { + setupMockStreamText() + + const metadata: ApiHandlerCreateMessageMetadata = { + taskId: "test-task", + tools: testTools, + } + + const generator = handler.createMessage( + "You are a helpful assistant.", + [{ role: "user", content: "Read the file at /test.txt" }], + metadata, + ) + + // Drain the generator + for await (const _chunk of generator) { + /* consume */ + } + + expect(mockStreamText).toHaveBeenCalledTimes(1) + const callArgs = mockStreamText.mock.calls[0][0] + + // tools should be defined and contain AI SDK tool objects keyed by name + expect(callArgs.tools).toBeDefined() + expect(callArgs.tools.read_file).toBeDefined() + expect(callArgs.tools.write_file).toBeDefined() }) - it("should transform type arrays to anyOf for JSON Schema 2020-12 compliance", () => { - const convertToolsForBedrock = (handler as any).convertToolsForBedrock.bind(handler) - - // Tools with type: ["string", "null"] syntax (valid in draft-07 but not 2020-12) - const toolsWithNullableTypes = [ - { - type: "function" as const, - function: { - name: "execute_command", - description: "Execute a command", - parameters: { - type: "object", - properties: { - command: { type: "string", description: "The command to execute" }, - cwd: { - type: ["string", "null"], - description: "Working directory (optional)", - }, - }, - required: ["command", "cwd"], - }, - }, - }, - { - type: "function" as const, - function: { - name: 
"read_file", - description: "Read files", - parameters: { - type: "object", - properties: { - path: { type: "string" }, - indentation: { - type: ["object", "null"], - properties: { - anchor_line: { - type: ["integer", "null"], - description: "Optional anchor line", - }, - }, - }, - }, - required: ["path"], - }, - }, - }, - ] + it("should pass undefined tools when no tools are provided in metadata", async () => { + setupMockStreamText() + + const metadata: ApiHandlerCreateMessageMetadata = { + taskId: "test-task", + // No tools + } - const bedrockTools = convertToolsForBedrock(toolsWithNullableTypes) + const generator = handler.createMessage( + "You are a helpful assistant.", + [{ role: "user", content: "Hello" }], + metadata, + ) - expect(bedrockTools).toHaveLength(2) + for await (const _chunk of generator) { + /* consume */ + } - // First tool: cwd should be transformed from type: ["string", "null"] to anyOf - const executeCommandSchema = bedrockTools[0].toolSpec.inputSchema.json as any - expect(executeCommandSchema.properties.cwd.anyOf).toEqual([{ type: "string" }, { type: "null" }]) - expect(executeCommandSchema.properties.cwd.type).toBeUndefined() - expect(executeCommandSchema.properties.cwd.description).toBe("Working directory (optional)") + expect(mockStreamText).toHaveBeenCalledTimes(1) + const callArgs = mockStreamText.mock.calls[0][0] - // Second tool: nested nullable object should be transformed from type: ["object", "null"] to anyOf - const readFileSchema = bedrockTools[1].toolSpec.inputSchema.json as any - const indentation = readFileSchema.properties.indentation - expect(indentation.anyOf).toBeDefined() - expect(indentation.type).toBeUndefined() - // Object-level schema properties are preserved at the root, not inside the anyOf object variant - expect(indentation.additionalProperties).toBe(false) - expect(indentation.properties.anchor_line.anyOf).toEqual([{ type: "integer" }, { type: "null" }]) + // When no tools are provided, tools should be undefined + expect(callArgs.tools).toBeUndefined() }) - it("should filter non-function tools", () => { - const convertToolsForBedrock = (handler as any).convertToolsForBedrock.bind(handler) + it("should filter non-function tools before passing to streamText", async () => { + setupMockStreamText() - const mixedTools = [ + const mixedTools: any[] = [ ...testTools, - { type: "other" as any, something: {} }, // Should be filtered out + { type: "other", something: {} }, // Should be filtered out ] - const bedrockTools = convertToolsForBedrock(mixedTools) + const metadata: ApiHandlerCreateMessageMetadata = { + taskId: "test-task", + tools: mixedTools as any, + } - expect(bedrockTools).toHaveLength(2) - }) - }) + const generator = handler.createMessage( + "You are a helpful assistant.", + [{ role: "user", content: "Read a file" }], + metadata, + ) - describe("convertToolChoiceForBedrock", () => { - it("should convert 'auto' to Bedrock auto format", () => { - const convertToolChoiceForBedrock = (handler as any).convertToolChoiceForBedrock.bind(handler) + for await (const _chunk of generator) { + /* consume */ + } - const result = convertToolChoiceForBedrock("auto") + expect(mockStreamText).toHaveBeenCalledTimes(1) + const callArgs = mockStreamText.mock.calls[0][0] - expect(result).toEqual({ auto: {} }) + // Only function tools should be present (keyed by name) + expect(callArgs.tools).toBeDefined() + expect(Object.keys(callArgs.tools)).toHaveLength(2) + expect(callArgs.tools.read_file).toBeDefined() + 
expect(callArgs.tools.write_file).toBeDefined() }) + }) - it("should convert 'required' to Bedrock any format", () => { - const convertToolChoiceForBedrock = (handler as any).convertToolChoiceForBedrock.bind(handler) + describe("toolChoice passed to streamText", () => { + it("should default toolChoice to undefined when tool_choice is not specified", async () => { + setupMockStreamText() - const result = convertToolChoiceForBedrock("required") + const metadata: ApiHandlerCreateMessageMetadata = { + taskId: "test-task", + tools: testTools, + // No tool_choice + } - expect(result).toEqual({ any: {} }) - }) + const generator = handler.createMessage( + "You are a helpful assistant.", + [{ role: "user", content: "Read the file" }], + metadata, + ) - it("should return undefined for 'none'", () => { - const convertToolChoiceForBedrock = (handler as any).convertToolChoiceForBedrock.bind(handler) + for await (const _chunk of generator) { + /* consume */ + } - const result = convertToolChoiceForBedrock("none") + expect(mockStreamText).toHaveBeenCalledTimes(1) + const callArgs = mockStreamText.mock.calls[0][0] - expect(result).toBeUndefined() + // mapToolChoice(undefined) returns undefined + expect(callArgs.toolChoice).toBeUndefined() }) - it("should convert specific tool choice to Bedrock tool format", () => { - const convertToolChoiceForBedrock = (handler as any).convertToolChoiceForBedrock.bind(handler) + it("should pass 'auto' toolChoice when tool_choice is 'auto'", async () => { + setupMockStreamText() - const result = convertToolChoiceForBedrock({ - type: "function", - function: { name: "read_file" }, - }) + const metadata: ApiHandlerCreateMessageMetadata = { + taskId: "test-task", + tools: testTools, + tool_choice: "auto", + } - expect(result).toEqual({ - tool: { - name: "read_file", - }, - }) - }) + const generator = handler.createMessage( + "You are a helpful assistant.", + [{ role: "user", content: "Read the file" }], + metadata, + ) - it("should default to auto for undefined toolChoice", () => { - const convertToolChoiceForBedrock = (handler as any).convertToolChoiceForBedrock.bind(handler) + for await (const _chunk of generator) { + /* consume */ + } - const result = convertToolChoiceForBedrock(undefined) + expect(mockStreamText).toHaveBeenCalledTimes(1) + const callArgs = mockStreamText.mock.calls[0][0] - expect(result).toEqual({ auto: {} }) + expect(callArgs.toolChoice).toBe("auto") }) - }) - describe("createMessage with native tools", () => { - it("should include toolConfig when tools are provided", async () => { - const handlerWithNativeTools = new AwsBedrockHandler({ - apiModelId: "anthropic.claude-3-5-sonnet-20241022-v2:0", - awsAccessKey: "test-access-key", - awsSecretKey: "test-secret-key", - awsRegion: "us-east-1", - }) + it("should pass 'none' toolChoice when tool_choice is 'none'", async () => { + setupMockStreamText() const metadata: ApiHandlerCreateMessageMetadata = { taskId: "test-task", tools: testTools, + tool_choice: "none", } - const generator = handlerWithNativeTools.createMessage( + const generator = handler.createMessage( "You are a helpful assistant.", - [{ role: "user", content: "Read the file at /test.txt" }], + [{ role: "user", content: "Read the file" }], metadata, ) - await generator.next() + for await (const _chunk of generator) { + /* consume */ + } - expect(mockConverseStreamCommand).toHaveBeenCalled() - const commandArg = mockConverseStreamCommand.mock.calls[0][0] as any + expect(mockStreamText).toHaveBeenCalledTimes(1) + const callArgs = 
mockStreamText.mock.calls[0][0] - expect(commandArg.toolConfig).toBeDefined() - expect(commandArg.toolConfig.tools).toHaveLength(2) - expect(commandArg.toolConfig.tools[0].toolSpec.name).toBe("read_file") - expect(commandArg.toolConfig.toolChoice).toEqual({ auto: {} }) + expect(callArgs.toolChoice).toBe("none") }) - it("should always include toolConfig (tools are always present after PR #10841)", async () => { - const handlerWithNativeTools = new AwsBedrockHandler({ - apiModelId: "anthropic.claude-3-5-sonnet-20241022-v2:0", - awsAccessKey: "test-access-key", - awsSecretKey: "test-secret-key", - awsRegion: "us-east-1", - }) + it("should pass 'required' toolChoice when tool_choice is 'required'", async () => { + setupMockStreamText() const metadata: ApiHandlerCreateMessageMetadata = { taskId: "test-task", - // Even without explicit tools, tools are always present (minimum 6 from ALWAYS_AVAILABLE_TOOLS) + tools: testTools, + tool_choice: "required", } - const generator = handlerWithNativeTools.createMessage( + const generator = handler.createMessage( "You are a helpful assistant.", - [{ role: "user", content: "Read the file at /test.txt" }], + [{ role: "user", content: "Read the file" }], metadata, ) - await generator.next() + for await (const _chunk of generator) { + /* consume */ + } - expect(mockConverseStreamCommand).toHaveBeenCalled() - const commandArg = mockConverseStreamCommand.mock.calls[0][0] as any + expect(mockStreamText).toHaveBeenCalledTimes(1) + const callArgs = mockStreamText.mock.calls[0][0] - // Tools are now always present - expect(commandArg.toolConfig).toBeDefined() - expect(commandArg.toolConfig.tools).toBeDefined() - expect(commandArg.toolConfig.toolChoice).toEqual({ auto: {} }) + expect(callArgs.toolChoice).toBe("required") }) - it("should include toolConfig with undefined toolChoice when tool_choice is none", async () => { - const handlerWithNativeTools = new AwsBedrockHandler({ - apiModelId: "anthropic.claude-3-5-sonnet-20241022-v2:0", - awsAccessKey: "test-access-key", - awsSecretKey: "test-secret-key", - awsRegion: "us-east-1", - }) + it("should pass specific tool choice when tool_choice names a function", async () => { + setupMockStreamText() const metadata: ApiHandlerCreateMessageMetadata = { taskId: "test-task", tools: testTools, - tool_choice: "none", // Explicitly disable tool use + tool_choice: { + type: "function", + function: { name: "read_file" }, + }, } - const generator = handlerWithNativeTools.createMessage( + const generator = handler.createMessage( "You are a helpful assistant.", - [{ role: "user", content: "Read the file at /test.txt" }], + [{ role: "user", content: "Read the file" }], metadata, ) - await generator.next() + for await (const _chunk of generator) { + /* consume */ + } - expect(mockConverseStreamCommand).toHaveBeenCalled() - const commandArg = mockConverseStreamCommand.mock.calls[0][0] as any + expect(mockStreamText).toHaveBeenCalledTimes(1) + const callArgs = mockStreamText.mock.calls[0][0] - // toolConfig is still provided but toolChoice is undefined for "none" - expect(commandArg.toolConfig).toBeDefined() - expect(commandArg.toolConfig.toolChoice).toBeUndefined() + expect(callArgs.toolChoice).toEqual({ + type: "tool", + toolName: "read_file", + }) }) + }) - it("should include fine-grained tool streaming beta for Claude models with native tools", async () => { - const handlerWithNativeTools = new AwsBedrockHandler({ - apiModelId: "anthropic.claude-3-5-sonnet-20241022-v2:0", - awsAccessKey: "test-access-key", - awsSecretKey: 
"test-secret-key", - awsRegion: "us-east-1", - }) + describe("tool call streaming events", () => { + it("should yield tool_call_start, tool_call_delta, and tool_call_end for tool input stream", async () => { + setupMockStreamTextWithToolCall() const metadata: ApiHandlerCreateMessageMetadata = { taskId: "test-task", tools: testTools, } - const generator = handlerWithNativeTools.createMessage( + const generator = handler.createMessage( "You are a helpful assistant.", - [{ role: "user", content: "Read the file at /test.txt" }], + [{ role: "user", content: "Read the file" }], metadata, ) - await generator.next() + const results: any[] = [] + for await (const chunk of generator) { + results.push(chunk) + } + + // Should have tool_call_start chunk + const startChunks = results.filter((r) => r.type === "tool_call_start") + expect(startChunks).toHaveLength(1) + expect(startChunks[0]).toEqual({ + type: "tool_call_start", + id: "tool-123", + name: "read_file", + }) - expect(mockConverseStreamCommand).toHaveBeenCalled() - const commandArg = mockConverseStreamCommand.mock.calls[0][0] as any + // Should have tool_call_delta chunk + const deltaChunks = results.filter((r) => r.type === "tool_call_delta") + expect(deltaChunks).toHaveLength(1) + expect(deltaChunks[0]).toEqual({ + type: "tool_call_delta", + id: "tool-123", + delta: '{"path": "/test.txt"}', + }) - // Should include the fine-grained tool streaming beta - expect(commandArg.additionalModelRequestFields).toBeDefined() - expect(commandArg.additionalModelRequestFields.anthropic_beta).toContain( - "fine-grained-tool-streaming-2025-05-14", - ) + // Should have tool_call_end chunk + const endChunks = results.filter((r) => r.type === "tool_call_end") + expect(endChunks).toHaveLength(1) + expect(endChunks[0]).toEqual({ + type: "tool_call_end", + id: "tool-123", + }) }) - it("should always include fine-grained tool streaming beta for Claude models", async () => { - const handlerWithNativeTools = new AwsBedrockHandler({ - apiModelId: "anthropic.claude-3-5-sonnet-20241022-v2:0", - awsAccessKey: "test-access-key", - awsSecretKey: "test-secret-key", - awsRegion: "us-east-1", + it("should handle mixed text and tool use content in stream", async () => { + async function* mockFullStream() { + yield { type: "text-delta", text: "Let me read that file for you." 
} + yield { type: "text-delta", text: " Here's what I found:" } + yield { + type: "tool-input-start", + id: "tool-789", + toolName: "read_file", + } + yield { + type: "tool-input-delta", + id: "tool-789", + delta: '{"path": "/example.txt"}', + } + yield { + type: "tool-input-end", + id: "tool-789", + } + } + mockStreamText.mockReturnValue({ + fullStream: mockFullStream(), + usage: Promise.resolve({ inputTokens: 150, outputTokens: 75 }), + providerMetadata: Promise.resolve({}), }) const metadata: ApiHandlerCreateMessageMetadata = { taskId: "test-task", - // No tools provided + tools: testTools, } - const generator = handlerWithNativeTools.createMessage( + const generator = handler.createMessage( "You are a helpful assistant.", - [{ role: "user", content: "Hello" }], + [{ role: "user", content: "Read the example file" }], metadata, ) - await generator.next() + const results: any[] = [] + for await (const chunk of generator) { + results.push(chunk) + } - expect(mockConverseStreamCommand).toHaveBeenCalled() - const commandArg = mockConverseStreamCommand.mock.calls[0][0] as any + // Should have text chunks + const textChunks = results.filter((r) => r.type === "text") + expect(textChunks).toHaveLength(2) + expect(textChunks[0].text).toBe("Let me read that file for you.") + expect(textChunks[1].text).toBe(" Here's what I found:") - // Should always include anthropic_beta with fine-grained-tool-streaming for Claude models - expect(commandArg.additionalModelRequestFields).toBeDefined() - expect(commandArg.additionalModelRequestFields.anthropic_beta).toContain( - "fine-grained-tool-streaming-2025-05-14", - ) + // Should have tool call start + const startChunks = results.filter((r) => r.type === "tool_call_start") + expect(startChunks).toHaveLength(1) + expect(startChunks[0].name).toBe("read_file") + + // Should have tool call delta + const deltaChunks = results.filter((r) => r.type === "tool_call_delta") + expect(deltaChunks).toHaveLength(1) + expect(deltaChunks[0].delta).toBe('{"path": "/example.txt"}') + + // Should have tool call end + const endChunks = results.filter((r) => r.type === "tool_call_end") + expect(endChunks).toHaveLength(1) }) - }) - describe("tool call streaming events", () => { - it("should yield tool_call_partial for toolUse block start", async () => { - const handlerWithNativeTools = new AwsBedrockHandler({ - apiModelId: "anthropic.claude-3-5-sonnet-20241022-v2:0", - awsAccessKey: "test-access-key", - awsSecretKey: "test-secret-key", - awsRegion: "us-east-1", + it("should handle multiple tool calls in a single stream", async () => { + async function* mockFullStream() { + yield { + type: "tool-input-start", + id: "tool-1", + toolName: "read_file", + } + yield { + type: "tool-input-delta", + id: "tool-1", + delta: '{"path": "/file1.txt"}', + } + yield { + type: "tool-input-end", + id: "tool-1", + } + yield { + type: "tool-input-start", + id: "tool-2", + toolName: "write_file", + } + yield { + type: "tool-input-delta", + id: "tool-2", + delta: '{"path": "/file2.txt", "content": "hello"}', + } + yield { + type: "tool-input-end", + id: "tool-2", + } + } + mockStreamText.mockReturnValue({ + fullStream: mockFullStream(), + usage: Promise.resolve({ inputTokens: 200, outputTokens: 100 }), + providerMetadata: Promise.resolve({}), }) - // Mock stream with tool use events - mockSend.mockResolvedValue({ - stream: (async function* () { - yield { - contentBlockStart: { - contentBlockIndex: 0, - start: { - toolUse: { - toolUseId: "tool-123", - name: "read_file", - }, - }, - }, - } - yield { - 
contentBlockDelta: { - contentBlockIndex: 0, - delta: { - toolUse: { - input: '{"path": "/test.txt"}', - }, - }, - }, - } - yield { - metadata: { - usage: { - inputTokens: 100, - outputTokens: 50, - }, - }, - } - })(), - }) + const metadata: ApiHandlerCreateMessageMetadata = { + taskId: "test-task", + tools: testTools, + } - const generator = handlerWithNativeTools.createMessage("You are a helpful assistant.", [ - { role: "user", content: "Read the file" }, - ]) + const generator = handler.createMessage( + "You are a helpful assistant.", + [{ role: "user", content: "Read and write files" }], + metadata, + ) const results: any[] = [] for await (const chunk of generator) { results.push(chunk) } - // Should have tool_call_partial chunks - const toolCallChunks = results.filter((r) => r.type === "tool_call_partial") - expect(toolCallChunks).toHaveLength(2) + // Should have two tool_call_start chunks + const startChunks = results.filter((r) => r.type === "tool_call_start") + expect(startChunks).toHaveLength(2) + expect(startChunks[0].name).toBe("read_file") + expect(startChunks[1].name).toBe("write_file") - // First chunk should have id and name - expect(toolCallChunks[0]).toEqual({ - type: "tool_call_partial", - index: 0, - id: "tool-123", - name: "read_file", - arguments: undefined, - }) + // Should have two tool_call_delta chunks + const deltaChunks = results.filter((r) => r.type === "tool_call_delta") + expect(deltaChunks).toHaveLength(2) - // Second chunk should have arguments - expect(toolCallChunks[1]).toEqual({ - type: "tool_call_partial", - index: 0, - id: undefined, - name: undefined, - arguments: '{"path": "/test.txt"}', - }) + // Should have two tool_call_end chunks + const endChunks = results.filter((r) => r.type === "tool_call_end") + expect(endChunks).toHaveLength(2) }) + }) - it("should yield tool_call_partial for contentBlock toolUse structure", async () => { - const handlerWithNativeTools = new AwsBedrockHandler({ - apiModelId: "anthropic.claude-3-5-sonnet-20241022-v2:0", - awsAccessKey: "test-access-key", - awsSecretKey: "test-secret-key", - awsRegion: "us-east-1", - }) + describe("tools schema normalization", () => { + it("should apply schema normalization (additionalProperties: false, strict: true) via convertToolsForOpenAI", async () => { + setupMockStreamText() - // Mock stream with alternative tool use structure - mockSend.mockResolvedValue({ - stream: (async function* () { - yield { - contentBlockStart: { - contentBlockIndex: 0, - contentBlock: { - toolUse: { - toolUseId: "tool-456", - name: "write_file", + const metadata: ApiHandlerCreateMessageMetadata = { + taskId: "test-task", + tools: [ + { + type: "function" as const, + function: { + name: "test_tool", + description: "A test tool", + parameters: { + type: "object", + properties: { + arg1: { type: "string" }, }, + // Note: no "required" field and no "additionalProperties" }, }, - } - yield { - metadata: { - usage: { - inputTokens: 100, - outputTokens: 50, - }, - }, - } - })(), - }) + }, + ], + } - const generator = handlerWithNativeTools.createMessage("You are a helpful assistant.", [ - { role: "user", content: "Write a file" }, - ]) + const generator = handler.createMessage( + "You are a helpful assistant.", + [{ role: "user", content: "test" }], + metadata, + ) - const results: any[] = [] - for await (const chunk of generator) { - results.push(chunk) + for await (const _chunk of generator) { + /* consume */ } - // Should have tool_call_partial chunk - const toolCallChunks = results.filter((r) => r.type === 
"tool_call_partial") - expect(toolCallChunks).toHaveLength(1) + expect(mockStreamText).toHaveBeenCalledTimes(1) + const callArgs = mockStreamText.mock.calls[0][0] - expect(toolCallChunks[0]).toEqual({ - type: "tool_call_partial", - index: 0, - id: "tool-456", - name: "write_file", - arguments: undefined, - }) + // The AI SDK tools should be keyed by tool name + expect(callArgs.tools).toBeDefined() + expect(callArgs.tools.test_tool).toBeDefined() }) + }) - it("should handle mixed text and tool use content", async () => { - const handlerWithNativeTools = new AwsBedrockHandler({ - apiModelId: "anthropic.claude-3-5-sonnet-20241022-v2:0", - awsAccessKey: "test-access-key", - awsSecretKey: "test-secret-key", - awsRegion: "us-east-1", - }) + describe("usage metrics with tools", () => { + it("should yield usage chunk after tool call stream completes", async () => { + setupMockStreamTextWithToolCall() - // Mock stream with mixed content - mockSend.mockResolvedValue({ - stream: (async function* () { - yield { - contentBlockStart: { - contentBlockIndex: 0, - start: { - text: "Let me read that file for you.", - }, - }, - } - yield { - contentBlockDelta: { - contentBlockIndex: 0, - delta: { - text: " Here's what I found:", - }, - }, - } - yield { - contentBlockStart: { - contentBlockIndex: 1, - start: { - toolUse: { - toolUseId: "tool-789", - name: "read_file", - }, - }, - }, - } - yield { - contentBlockDelta: { - contentBlockIndex: 1, - delta: { - toolUse: { - input: '{"path": "/example.txt"}', - }, - }, - }, - } - yield { - metadata: { - usage: { - inputTokens: 150, - outputTokens: 75, - }, - }, - } - })(), - }) + const metadata: ApiHandlerCreateMessageMetadata = { + taskId: "test-task", + tools: testTools, + } - const generator = handlerWithNativeTools.createMessage("You are a helpful assistant.", [ - { role: "user", content: "Read the example file" }, - ]) + const generator = handler.createMessage( + "You are a helpful assistant.", + [{ role: "user", content: "Read a file" }], + metadata, + ) const results: any[] = [] for await (const chunk of generator) { results.push(chunk) } - // Should have text chunks - const textChunks = results.filter((r) => r.type === "text") - expect(textChunks).toHaveLength(2) - expect(textChunks[0].text).toBe("Let me read that file for you.") - expect(textChunks[1].text).toBe(" Here's what I found:") - - // Should have tool call chunks - const toolCallChunks = results.filter((r) => r.type === "tool_call_partial") - expect(toolCallChunks).toHaveLength(2) - expect(toolCallChunks[0].name).toBe("read_file") - expect(toolCallChunks[1].arguments).toBe('{"path": "/example.txt"}') + // Should have a usage chunk at the end + const usageChunks = results.filter((r) => r.type === "usage") + expect(usageChunks).toHaveLength(1) + expect(usageChunks[0].inputTokens).toBe(100) + expect(usageChunks[0].outputTokens).toBe(50) }) }) }) diff --git a/src/api/providers/__tests__/bedrock-reasoning.spec.ts b/src/api/providers/__tests__/bedrock-reasoning.spec.ts index 9dd271744c2..dfe35d4d8e2 100644 --- a/src/api/providers/__tests__/bedrock-reasoning.spec.ts +++ b/src/api/providers/__tests__/bedrock-reasoning.spec.ts @@ -1,37 +1,48 @@ -// npx vitest api/providers/__tests__/bedrock-reasoning.test.ts +// npx vitest run api/providers/__tests__/bedrock-reasoning.spec.ts + +// Use vi.hoisted to define mock functions for AI SDK +const { mockStreamText, mockGenerateText, mockCreateAmazonBedrock } = vi.hoisted(() => ({ + mockStreamText: vi.fn(), + mockGenerateText: vi.fn(), + mockCreateAmazonBedrock: 
vi.fn(() => vi.fn(() => ({ modelId: "test", provider: "bedrock" }))), +})) + +vi.mock("ai", async (importOriginal) => { + const actual = await importOriginal() + return { + ...actual, + streamText: mockStreamText, + generateText: mockGenerateText, + } +}) + +vi.mock("@ai-sdk/amazon-bedrock", () => ({ + createAmazonBedrock: mockCreateAmazonBedrock, +})) + +// Mock AWS SDK credential providers +vi.mock("@aws-sdk/credential-providers", () => ({ + fromIni: vi.fn().mockReturnValue(async () => ({ + accessKeyId: "profile-access-key", + secretAccessKey: "profile-secret-key", + })), +})) + +vi.mock("../../../utils/logging", () => ({ + logger: { + info: vi.fn(), + error: vi.fn(), + warn: vi.fn(), + debug: vi.fn(), + }, +})) import { AwsBedrockHandler } from "../bedrock" -import { BedrockRuntimeClient, ConverseStreamCommand } from "@aws-sdk/client-bedrock-runtime" import { logger } from "../../../utils/logging" -// Mock the AWS SDK -vi.mock("@aws-sdk/client-bedrock-runtime") -vi.mock("../../../utils/logging") - -// Store the command payload for verification -let capturedPayload: any = null - describe("AwsBedrockHandler - Extended Thinking", () => { - let handler: AwsBedrockHandler - let mockSend: ReturnType - beforeEach(() => { - capturedPayload = null - mockSend = vi.fn() - - // Mock ConverseStreamCommand to capture the payload - ;(ConverseStreamCommand as unknown as ReturnType).mockImplementation((payload) => { - capturedPayload = payload - return { - input: payload, - } - }) - ;(BedrockRuntimeClient as unknown as ReturnType).mockImplementation(() => ({ - send: mockSend, - config: { region: "us-east-1" }, - })) - ;(logger.info as unknown as ReturnType).mockImplementation(() => {}) - ;(logger.error as unknown as ReturnType).mockImplementation(() => {}) + vi.clearAllMocks() }) afterEach(() => { @@ -39,8 +50,8 @@ describe("AwsBedrockHandler - Extended Thinking", () => { }) describe("Extended Thinking Support", () => { - it("should include thinking parameter for Claude Sonnet 4 when reasoning is enabled", async () => { - handler = new AwsBedrockHandler({ + it("should include reasoningConfig in providerOptions when reasoning is enabled", async () => { + const handler = new AwsBedrockHandler({ apiProvider: "bedrock", apiModelId: "anthropic.claude-sonnet-4-20250514-v1:0", awsRegion: "us-east-1", @@ -49,35 +60,17 @@ describe("AwsBedrockHandler - Extended Thinking", () => { modelMaxThinkingTokens: 4096, }) - // Mock the stream response - mockSend.mockResolvedValue({ - stream: (async function* () { - yield { - messageStart: { role: "assistant" }, - } - yield { - contentBlockStart: { - content_block: { type: "thinking", thinking: "Let me think..." }, - contentBlockIndex: 0, - }, - } - yield { - contentBlockDelta: { - delta: { type: "thinking_delta", thinking: " about this problem." }, - }, - } - yield { - contentBlockStart: { - start: { text: "Here's the answer:" }, - contentBlockIndex: 1, - }, - } - yield { - metadata: { - usage: { inputTokens: 100, outputTokens: 50 }, - }, - } - })(), + // Mock stream with reasoning content + async function* mockFullStream() { + yield { type: "reasoning", text: "Let me think..." } + yield { type: "reasoning", text: " about this problem." 
} + yield { type: "text-delta", text: "Here's the answer:" } + } + + mockStreamText.mockReturnValue({ + fullStream: mockFullStream(), + usage: Promise.resolve({ inputTokens: 100, outputTokens: 50 }), + providerMetadata: Promise.resolve({}), }) const messages = [{ role: "user" as const, content: "Test message" }] @@ -88,13 +81,14 @@ describe("AwsBedrockHandler - Extended Thinking", () => { chunks.push(chunk) } - // Verify the command was called with the correct payload - expect(mockSend).toHaveBeenCalledTimes(1) - expect(capturedPayload).toBeDefined() - expect(capturedPayload.additionalModelRequestFields).toBeDefined() - expect(capturedPayload.additionalModelRequestFields.thinking).toEqual({ + // Verify streamText was called with providerOptions containing reasoningConfig + expect(mockStreamText).toHaveBeenCalledTimes(1) + const callArgs = mockStreamText.mock.calls[0][0] + expect(callArgs.providerOptions).toBeDefined() + expect(callArgs.providerOptions.bedrock).toBeDefined() + expect(callArgs.providerOptions.bedrock.reasoningConfig).toEqual({ type: "enabled", - budget_tokens: 4096, // Uses the full modelMaxThinkingTokens value + budgetTokens: 4096, }) // Verify reasoning chunks were yielded @@ -102,157 +96,145 @@ describe("AwsBedrockHandler - Extended Thinking", () => { expect(reasoningChunks).toHaveLength(2) expect(reasoningChunks[0].text).toBe("Let me think...") expect(reasoningChunks[1].text).toBe(" about this problem.") - - // Verify that topP is NOT present when thinking is enabled - expect(capturedPayload.inferenceConfig).not.toHaveProperty("topP") }) - it("should pass thinking parameters from metadata", async () => { - handler = new AwsBedrockHandler({ + it("should not include reasoningConfig when reasoning is disabled", async () => { + const handler = new AwsBedrockHandler({ apiProvider: "bedrock", apiModelId: "anthropic.claude-3-7-sonnet-20250219-v1:0", awsRegion: "us-east-1", + // Note: no enableReasoningEffort = true, so thinking is disabled }) - mockSend.mockResolvedValue({ - stream: (async function* () { - yield { messageStart: { role: "assistant" } } - yield { metadata: { usage: { inputTokens: 100, outputTokens: 50 } } } - })(), + async function* mockFullStream() { + yield { type: "text-delta", text: "Hello world" } + } + + mockStreamText.mockReturnValue({ + fullStream: mockFullStream(), + usage: Promise.resolve({ inputTokens: 100, outputTokens: 50 }), + providerMetadata: Promise.resolve({}), }) const messages = [{ role: "user" as const, content: "Test message" }] - const metadata = { - taskId: "test-task", - thinking: { - enabled: true, - maxTokens: 16384, - maxThinkingTokens: 8192, - }, - } + const stream = handler.createMessage("System prompt", messages) - const stream = handler.createMessage("System prompt", messages, metadata) const chunks = [] for await (const chunk of stream) { chunks.push(chunk) } - // Verify the thinking parameter was passed correctly - expect(mockSend).toHaveBeenCalledTimes(1) - expect(capturedPayload).toBeDefined() - expect(capturedPayload.additionalModelRequestFields).toBeDefined() - expect(capturedPayload.additionalModelRequestFields.thinking).toEqual({ - type: "enabled", - budget_tokens: 8192, - }) - - // Verify that topP is NOT present when thinking is enabled via metadata - expect(capturedPayload.inferenceConfig).not.toHaveProperty("topP") + // Verify streamText was called — providerOptions should not contain reasoningConfig + expect(mockStreamText).toHaveBeenCalledTimes(1) + const callArgs = mockStreamText.mock.calls[0][0] + const bedrockOpts 
= callArgs.providerOptions?.bedrock + expect(bedrockOpts?.reasoningConfig).toBeUndefined() }) - it("should log when extended thinking is enabled", async () => { - handler = new AwsBedrockHandler({ + it("should capture thinking signature from stream providerMetadata", async () => { + const handler = new AwsBedrockHandler({ apiProvider: "bedrock", - apiModelId: "anthropic.claude-opus-4-20250514-v1:0", + apiModelId: "anthropic.claude-sonnet-4-20250514-v1:0", awsRegion: "us-east-1", enableReasoningEffort: true, - modelMaxThinkingTokens: 5000, + modelMaxThinkingTokens: 4096, }) - mockSend.mockResolvedValue({ - stream: (async function* () { - yield { messageStart: { role: "assistant" } } - })(), + const testSignature = "test-thinking-signature-abc123" + + // Mock stream with reasoning content that includes a signature in providerMetadata + async function* mockFullStream() { + yield { type: "reasoning", text: "Let me think..." } + // The SDK emits signature as a reasoning-delta with providerMetadata.bedrock.signature + yield { + type: "reasoning", + text: "", + providerMetadata: { bedrock: { signature: testSignature } }, + } + yield { type: "text-delta", text: "Answer" } + } + + mockStreamText.mockReturnValue({ + fullStream: mockFullStream(), + usage: Promise.resolve({ inputTokens: 100, outputTokens: 50 }), + providerMetadata: Promise.resolve({}), }) - const messages = [{ role: "user" as const, content: "Test" }] + const messages = [{ role: "user" as const, content: "Test message" }] const stream = handler.createMessage("System prompt", messages) - for await (const chunk of stream) { + for await (const _chunk of stream) { // consume stream } - // Verify logging - expect(logger.info).toHaveBeenCalledWith( - expect.stringContaining("Extended thinking enabled"), - expect.objectContaining({ - ctx: "bedrock", - modelId: "anthropic.claude-opus-4-20250514-v1:0", - }), - ) + // Verify thinking signature was captured + expect(handler.getThoughtSignature()).toBe(testSignature) }) - it("should not include topP when thinking is disabled (global removal)", async () => { - handler = new AwsBedrockHandler({ + it("should capture redacted thinking blocks from stream providerMetadata", async () => { + const handler = new AwsBedrockHandler({ apiProvider: "bedrock", - apiModelId: "anthropic.claude-3-7-sonnet-20250219-v1:0", + apiModelId: "anthropic.claude-sonnet-4-20250514-v1:0", awsRegion: "us-east-1", - // Note: no enableReasoningEffort = true, so thinking is disabled + enableReasoningEffort: true, + modelMaxThinkingTokens: 4096, }) - mockSend.mockResolvedValue({ - stream: (async function* () { - yield { messageStart: { role: "assistant" } } - yield { - contentBlockStart: { - start: { text: "Hello" }, - contentBlockIndex: 0, - }, - } - yield { - contentBlockDelta: { - delta: { text: " world" }, - }, - } - yield { metadata: { usage: { inputTokens: 100, outputTokens: 50 } } } - })(), + const redactedData = "base64-encoded-redacted-data" + + // Mock stream with redacted reasoning content + async function* mockFullStream() { + yield { type: "reasoning", text: "Some thinking..." 
} + yield { + type: "reasoning", + text: "", + providerMetadata: { bedrock: { redactedData } }, + } + yield { type: "text-delta", text: "Answer" } + } + + mockStreamText.mockReturnValue({ + fullStream: mockFullStream(), + usage: Promise.resolve({ inputTokens: 100, outputTokens: 50 }), + providerMetadata: Promise.resolve({}), }) const messages = [{ role: "user" as const, content: "Test message" }] const stream = handler.createMessage("System prompt", messages) - const chunks = [] - for await (const chunk of stream) { - chunks.push(chunk) + for await (const _chunk of stream) { + // consume stream } - // Verify that topP is NOT present for any model (removed globally) - expect(mockSend).toHaveBeenCalledTimes(1) - expect(capturedPayload).toBeDefined() - expect(capturedPayload.inferenceConfig).not.toHaveProperty("topP") - - // Verify that additionalModelRequestFields contains fine-grained-tool-streaming for Claude models - expect(capturedPayload.additionalModelRequestFields).toBeDefined() - expect(capturedPayload.additionalModelRequestFields.anthropic_beta).toContain( - "fine-grained-tool-streaming-2025-05-14", - ) + // Verify redacted thinking blocks were captured + const redactedBlocks = handler.getRedactedThinkingBlocks() + expect(redactedBlocks).toBeDefined() + expect(redactedBlocks).toHaveLength(1) + expect(redactedBlocks![0]).toEqual({ + type: "redacted_thinking", + data: redactedData, + }) }) it("should enable reasoning when enableReasoningEffort is true in settings", async () => { - handler = new AwsBedrockHandler({ + const handler = new AwsBedrockHandler({ apiProvider: "bedrock", apiModelId: "anthropic.claude-sonnet-4-20250514-v1:0", awsRegion: "us-east-1", - enableReasoningEffort: true, // This should trigger reasoning + enableReasoningEffort: true, modelMaxThinkingTokens: 4096, }) - mockSend.mockResolvedValue({ - stream: (async function* () { - yield { messageStart: { role: "assistant" } } - yield { - contentBlockStart: { - content_block: { type: "thinking", thinking: "Let me think..." }, - contentBlockIndex: 0, - }, - } - yield { - contentBlockDelta: { - delta: { type: "thinking_delta", thinking: " about this problem." }, - }, - } - yield { metadata: { usage: { inputTokens: 100, outputTokens: 50 } } } - })(), + async function* mockFullStream() { + yield { type: "reasoning", text: "Let me think..." } + yield { type: "reasoning", text: " about this problem." 
} + yield { type: "text-delta", text: "Test response" } + } + + mockStreamText.mockReturnValue({ + fullStream: mockFullStream(), + usage: Promise.resolve({ inputTokens: 100, outputTokens: 50 }), + providerMetadata: Promise.resolve({}), }) const messages = [{ role: "user" as const, content: "Test message" }] @@ -264,17 +246,13 @@ describe("AwsBedrockHandler - Extended Thinking", () => { } // Verify thinking was enabled via settings - expect(mockSend).toHaveBeenCalledTimes(1) - expect(capturedPayload).toBeDefined() - expect(capturedPayload.additionalModelRequestFields).toBeDefined() - expect(capturedPayload.additionalModelRequestFields.thinking).toEqual({ + expect(mockStreamText).toHaveBeenCalledTimes(1) + const callArgs = mockStreamText.mock.calls[0][0] + expect(callArgs.providerOptions?.bedrock?.reasoningConfig).toEqual({ type: "enabled", - budget_tokens: 4096, + budgetTokens: 4096, }) - // Verify that topP is NOT present when thinking is enabled via settings - expect(capturedPayload.inferenceConfig).not.toHaveProperty("topP") - // Verify reasoning chunks were yielded const reasoningChunks = chunks.filter((c) => c.type === "reasoning") expect(reasoningChunks).toHaveLength(2) @@ -282,8 +260,8 @@ describe("AwsBedrockHandler - Extended Thinking", () => { expect(reasoningChunks[1].text).toBe(" about this problem.") }) - it("should support API key authentication", async () => { - handler = new AwsBedrockHandler({ + it("should support API key authentication via createAmazonBedrock", () => { + new AwsBedrockHandler({ apiProvider: "bedrock", apiModelId: "anthropic.claude-3-5-sonnet-20241022-v2:0", awsRegion: "us-east-1", @@ -291,41 +269,13 @@ describe("AwsBedrockHandler - Extended Thinking", () => { awsApiKey: "test-api-key-token", }) - mockSend.mockResolvedValue({ - stream: (async function* () { - yield { messageStart: { role: "assistant" } } - yield { - contentBlockStart: { - start: { text: "Hello from API key auth" }, - contentBlockIndex: 0, - }, - } - yield { metadata: { usage: { inputTokens: 100, outputTokens: 50 } } } - })(), - }) - - const messages = [{ role: "user" as const, content: "Test message" }] - const stream = handler.createMessage("System prompt", messages) - - const chunks = [] - for await (const chunk of stream) { - chunks.push(chunk) - } - - // Verify the client was created with API key token - expect(BedrockRuntimeClient).toHaveBeenCalledWith( + // Verify the provider was created with API key + expect(mockCreateAmazonBedrock).toHaveBeenCalledWith( expect.objectContaining({ region: "us-east-1", - token: { token: "test-api-key-token" }, - authSchemePreference: ["httpBearerAuth"], + apiKey: "test-api-key-token", }), ) - - // Verify the stream worked correctly - expect(mockSend).toHaveBeenCalledTimes(1) - const textChunks = chunks.filter((c) => c.type === "text") - expect(textChunks).toHaveLength(1) - expect(textChunks[0].text).toBe("Hello from API key auth") }) }) }) diff --git a/src/api/providers/__tests__/bedrock-vpc-endpoint.spec.ts b/src/api/providers/__tests__/bedrock-vpc-endpoint.spec.ts index 7823775beaa..19bb68bb775 100644 --- a/src/api/providers/__tests__/bedrock-vpc-endpoint.spec.ts +++ b/src/api/providers/__tests__/bedrock-vpc-endpoint.spec.ts @@ -7,38 +7,40 @@ vi.mock("@aws-sdk/credential-providers", () => { return { fromIni: mockFromIni } }) -// Mock BedrockRuntimeClient and ConverseStreamCommand -vi.mock("@aws-sdk/client-bedrock-runtime", () => { - const mockSend = vi.fn().mockResolvedValue({ - stream: [], - }) - const mockBedrockRuntimeClient = 
vi.fn().mockImplementation(() => ({ - send: mockSend, - })) - +// Use vi.hoisted to define mock functions for AI SDK +const { mockStreamText, mockGenerateText } = vi.hoisted(() => ({ + mockStreamText: vi.fn(), + mockGenerateText: vi.fn(), +})) + +vi.mock("ai", async (importOriginal) => { + const actual = await importOriginal<typeof import("ai")>() return { - BedrockRuntimeClient: mockBedrockRuntimeClient, - ConverseStreamCommand: vi.fn(), - ConverseCommand: vi.fn(), + ...actual, + streamText: mockStreamText, + generateText: mockGenerateText, } }) -import { AwsBedrockHandler } from "../bedrock" -import { BedrockRuntimeClient } from "@aws-sdk/client-bedrock-runtime" +// Mock createAmazonBedrock so we can inspect how it was called +const { mockCreateAmazonBedrock } = vi.hoisted(() => ({ + mockCreateAmazonBedrock: vi.fn(() => vi.fn(() => ({ modelId: "test", provider: "bedrock" }))), +})) -// Get access to the mocked functions -const mockBedrockRuntimeClient = vi.mocked(BedrockRuntimeClient) +vi.mock("@ai-sdk/amazon-bedrock", () => ({ + createAmazonBedrock: mockCreateAmazonBedrock, +})) + +import { AwsBedrockHandler } from "../bedrock" describe("Amazon Bedrock VPC Endpoint Functionality", () => { beforeEach(() => { - // Clear all mocks before each test vi.clearAllMocks() }) // Test Scenario 1: Input Validation Test describe("VPC Endpoint URL Validation", () => { - it("should configure client with endpoint URL when both URL and enabled flag are provided", () => { - // Create handler with endpoint URL and enabled flag + it("should configure provider with baseURL when both URL and enabled flag are provided", () => { new AwsBedrockHandler({ apiModelId: "anthropic.claude-3-5-sonnet-20241022-v2:0", awsAccessKey: "test-access-key", @@ -48,17 +50,15 @@ awsBedrockEndpointEnabled: true, }) - // Verify the client was created with the correct endpoint - expect(mockBedrockRuntimeClient).toHaveBeenCalledWith( + expect(mockCreateAmazonBedrock).toHaveBeenCalledWith( expect.objectContaining({ region: "us-east-1", - endpoint: "https://bedrock-vpc.example.com", + baseURL: "https://bedrock-vpc.example.com", }), ) }) - it("should not configure client with endpoint URL when URL is provided but enabled flag is false", () => { - // Create handler with endpoint URL but disabled flag + it("should not configure provider with baseURL when URL is provided but enabled flag is false", () => { new AwsBedrockHandler({ apiModelId: "anthropic.claude-3-5-sonnet-20241022-v2:0", awsAccessKey: "test-access-key", @@ -68,23 +68,23 @@ awsBedrockEndpointEnabled: false, }) - // Verify the client was created without the endpoint - expect(mockBedrockRuntimeClient).toHaveBeenCalledWith( + expect(mockCreateAmazonBedrock).toHaveBeenCalledWith( expect.objectContaining({ region: "us-east-1", }), ) - // Verify the endpoint property is not present - const clientConfig = mockBedrockRuntimeClient.mock.calls[0][0] - expect(clientConfig).not.toHaveProperty("endpoint") + const providerSettings = (mockCreateAmazonBedrock.mock.calls as unknown[][])[0][0] as Record< string, unknown > + expect(providerSettings).not.toHaveProperty("baseURL") }) }) // Test Scenario 2: Edge Case Tests describe("Edge Cases", () => { it("should handle empty endpoint URL gracefully", () => { - // Create handler with empty endpoint URL but enabled flag new AwsBedrockHandler({ apiModelId: "anthropic.claude-3-5-sonnet-20241022-v2:0", awsAccessKey: "test-access-key", @@
-94,20 +94,21 @@ describe("Amazon Bedrock VPC Endpoint Functionality", () => { awsBedrockEndpointEnabled: true, }) - // Verify the client was created without the endpoint (since it's empty) - expect(mockBedrockRuntimeClient).toHaveBeenCalledWith( + expect(mockCreateAmazonBedrock).toHaveBeenCalledWith( expect.objectContaining({ region: "us-east-1", }), ) - // Verify the endpoint property is not present - const clientConfig = mockBedrockRuntimeClient.mock.calls[0][0] - expect(clientConfig).not.toHaveProperty("endpoint") + // Empty string is falsy, so baseURL should not be set + const providerSettings = (mockCreateAmazonBedrock.mock.calls as unknown[][])[0][0] as Record< + string, + unknown + > + expect(providerSettings).not.toHaveProperty("baseURL") }) it("should handle undefined endpoint URL gracefully", () => { - // Create handler with undefined endpoint URL but enabled flag new AwsBedrockHandler({ apiModelId: "anthropic.claude-3-5-sonnet-20241022-v2:0", awsAccessKey: "test-access-key", @@ -117,23 +118,23 @@ describe("Amazon Bedrock VPC Endpoint Functionality", () => { awsBedrockEndpointEnabled: true, }) - // Verify the client was created without the endpoint - expect(mockBedrockRuntimeClient).toHaveBeenCalledWith( + expect(mockCreateAmazonBedrock).toHaveBeenCalledWith( expect.objectContaining({ region: "us-east-1", }), ) - // Verify the endpoint property is not present - const clientConfig = mockBedrockRuntimeClient.mock.calls[0][0] - expect(clientConfig).not.toHaveProperty("endpoint") + const providerSettings = (mockCreateAmazonBedrock.mock.calls as unknown[][])[0][0] as Record< + string, + unknown + > + expect(providerSettings).not.toHaveProperty("baseURL") }) }) - // Test Scenario 4: Error Handling Tests + // Test Scenario 3: Error Handling Tests describe("Error Handling", () => { - it("should handle invalid endpoint URLs by passing them directly to AWS SDK", () => { - // Create handler with an invalid URL format + it("should handle invalid endpoint URLs by passing them directly to the provider", () => { new AwsBedrockHandler({ apiModelId: "anthropic.claude-3-5-sonnet-20241022-v2:0", awsAccessKey: "test-access-key", @@ -143,21 +144,24 @@ describe("Amazon Bedrock VPC Endpoint Functionality", () => { awsBedrockEndpointEnabled: true, }) - // Verify the client was created with the invalid endpoint - // (AWS SDK will handle the validation/errors) - expect(mockBedrockRuntimeClient).toHaveBeenCalledWith( + // The invalid URL is passed directly; the provider/SDK will handle validation + expect(mockCreateAmazonBedrock).toHaveBeenCalledWith( expect.objectContaining({ region: "us-east-1", - endpoint: "invalid-url-format", + baseURL: "invalid-url-format", }), ) }) }) - // Test Scenario 5: Persistence Tests + // Test Scenario 4: Persistence Tests describe("Persistence", () => { it("should maintain consistent behavior across multiple requests", async () => { - // Create handler with endpoint URL and enabled flag + mockGenerateText.mockResolvedValue({ + text: "test response", + usage: { promptTokens: 10, completionTokens: 5 }, + }) + const handler = new AwsBedrockHandler({ apiModelId: "anthropic.claude-3-5-sonnet-20241022-v2:0", awsAccessKey: "test-access-key", @@ -167,23 +171,23 @@ describe("Amazon Bedrock VPC Endpoint Functionality", () => { awsBedrockEndpointEnabled: true, }) - // Verify the client was configured with the endpoint - expect(mockBedrockRuntimeClient).toHaveBeenCalledWith( + // Verify the provider was configured with the endpoint + 
expect(mockCreateAmazonBedrock).toHaveBeenCalledWith( expect.objectContaining({ region: "us-east-1", - endpoint: "https://bedrock-vpc.example.com", + baseURL: "https://bedrock-vpc.example.com", }), ) // Make a request to ensure the endpoint configuration persists try { await handler.completePrompt("Test prompt") - } catch (error) { - // Ignore errors, we're just testing the client configuration persistence + } catch { + // Ignore errors — we're just testing the provider configuration persistence } - // Verify the client instance was created and used - expect(mockBedrockRuntimeClient).toHaveBeenCalled() + // The provider factory should have been called exactly once (during construction) + expect(mockCreateAmazonBedrock).toHaveBeenCalledTimes(1) }) }) }) diff --git a/src/api/providers/__tests__/bedrock.spec.ts b/src/api/providers/__tests__/bedrock.spec.ts index 115cb9fb405..2cb09fc56db 100644 --- a/src/api/providers/__tests__/bedrock.spec.ts +++ b/src/api/providers/__tests__/bedrock.spec.ts @@ -18,24 +18,26 @@ vi.mock("@aws-sdk/credential-providers", () => { return { fromIni: mockFromIni } }) -// Mock BedrockRuntimeClient and ConverseStreamCommand -vi.mock("@aws-sdk/client-bedrock-runtime", () => { - const mockSend = vi.fn().mockResolvedValue({ - stream: [], - }) - const mockConverseStreamCommand = vi.fn() +// Use vi.hoisted to define mock functions for AI SDK +const { mockStreamText, mockGenerateText } = vi.hoisted(() => ({ + mockStreamText: vi.fn(), + mockGenerateText: vi.fn(), +})) +vi.mock("ai", async (importOriginal) => { + const actual = await importOriginal<typeof import("ai")>() return { - BedrockRuntimeClient: vi.fn().mockImplementation(() => ({ - send: mockSend, - })), - ConverseStreamCommand: mockConverseStreamCommand, - ConverseCommand: vi.fn(), + ...actual, + streamText: mockStreamText, + generateText: mockGenerateText, } }) +vi.mock("@ai-sdk/amazon-bedrock", () => ({ + createAmazonBedrock: vi.fn(() => vi.fn(() => ({ modelId: "test", provider: "bedrock" }))), +})) + import { AwsBedrockHandler } from "../bedrock" -import { ConverseStreamCommand, BedrockRuntimeClient, ConverseCommand } from "@aws-sdk/client-bedrock-runtime" import { BEDROCK_1M_CONTEXT_MODEL_IDS, BEDROCK_SERVICE_TIER_MODEL_IDS, @@ -45,10 +47,6 @@ import { import type { Anthropic } from "@anthropic-ai/sdk" -// Get access to the mocked functions -const mockConverseStreamCommand = vi.mocked(ConverseStreamCommand) -const mockBedrockRuntimeClient = vi.mocked(BedrockRuntimeClient) - describe("AwsBedrockHandler", () => { let handler: AwsBedrockHandler @@ -478,12 +476,20 @@ describe("AwsBedrockHandler", () => { describe("image handling", () => { const mockImageData = Buffer.from("test-image-data").toString("base64") - beforeEach(() => { - // Reset the mocks before each test - mockConverseStreamCommand.mockReset() - }) + function setupMockStreamText() { + async function* mockFullStream() { + yield { type: "text-delta", text: "I see an image" } + } + mockStreamText.mockReturnValue({ + fullStream: mockFullStream(), + usage: Promise.resolve({ inputTokens: 10, outputTokens: 5 }), + providerMetadata: Promise.resolve({}), + }) + } + + it("should properly pass image content through to streamText via AI SDK messages", async () => { + setupMockStreamText() - it("should properly convert image content to Bedrock format", async () => { const messages: Anthropic.Messages.MessageParam[] = [ { role: "user", @@ -505,42 +511,39 @@ describe("AwsBedrockHandler", () => { ] const generator = handler.createMessage("", messages) - await generator.next() // Start the
generator - - // Verify the command was created with the right payload - expect(mockConverseStreamCommand).toHaveBeenCalled() - const commandArg = mockConverseStreamCommand.mock.calls[0][0] - - // Verify the image was properly formatted - const imageBlock = commandArg.messages![0].content![0] - expect(imageBlock).toHaveProperty("image") - expect(imageBlock.image).toHaveProperty("format", "jpeg") - expect(imageBlock.image!.source).toHaveProperty("bytes") - expect(imageBlock.image!.source!.bytes).toBeInstanceOf(Uint8Array) - }) - - it("should reject unsupported image formats", async () => { - const messages: Anthropic.Messages.MessageParam[] = [ - { - role: "user", - content: [ - { - type: "image", - source: { - type: "base64", - data: mockImageData, - media_type: "image/tiff" as "image/jpeg", // Type assertion to bypass TS - }, - }, - ], - }, - ] - - const generator = handler.createMessage("", messages) - await expect(generator.next()).rejects.toThrow("Unsupported image format: tiff") + const chunks: unknown[] = [] + for await (const chunk of generator) { + chunks.push(chunk) + } + + // Verify streamText was called + expect(mockStreamText).toHaveBeenCalledTimes(1) + const callArgs = mockStreamText.mock.calls[0][0] + + // Verify messages were converted to AI SDK format with image parts + const aiSdkMessages = callArgs.messages + expect(aiSdkMessages).toBeDefined() + expect(aiSdkMessages.length).toBeGreaterThan(0) + + // Find the user message containing image content + const userMsg = aiSdkMessages.find((m: { role: string }) => m.role === "user") + expect(userMsg).toBeDefined() + expect(Array.isArray(userMsg.content)).toBe(true) + + // The AI SDK convertToAiSdkMessages converts images to { type: "image", image: "data:...", mimeType: "..." } + const imagePart = userMsg.content.find((p: { type: string }) => p.type === "image") + expect(imagePart).toBeDefined() + expect(imagePart.image).toContain("data:image/jpeg;base64,") + expect(imagePart.mimeType).toBe("image/jpeg") + + const textPart = userMsg.content.find((p: { type: string }) => p.type === "text") + expect(textPart).toBeDefined() + expect(textPart.text).toBe("What's in this image?") }) it("should handle multiple images in a single message", async () => { + setupMockStreamText() + const messages: Anthropic.Messages.MessageParam[] = [ { role: "user", @@ -574,20 +577,25 @@ describe("AwsBedrockHandler", () => { ] const generator = handler.createMessage("", messages) - await generator.next() // Start the generator - - // Verify the command was created with the right payload - expect(mockConverseStreamCommand).toHaveBeenCalled() - const commandArg = mockConverseStreamCommand.mock.calls[0][0] - - // Verify both images were properly formatted - const firstImage = commandArg.messages![0].content![0] - const secondImage = commandArg.messages![0].content![2] - - expect(firstImage).toHaveProperty("image") - expect(firstImage.image).toHaveProperty("format", "jpeg") - expect(secondImage).toHaveProperty("image") - expect(secondImage.image).toHaveProperty("format", "png") + const chunks: unknown[] = [] + for await (const chunk of generator) { + chunks.push(chunk) + } + + // Verify streamText was called + expect(mockStreamText).toHaveBeenCalledTimes(1) + const callArgs = mockStreamText.mock.calls[0][0] + + // Verify messages contain both images + const userMsg = callArgs.messages.find((m: { role: string }) => m.role === "user") + expect(userMsg).toBeDefined() + + const imageParts = userMsg.content.filter((p: { type: string }) => p.type === "image") + 
expect(imageParts).toHaveLength(2) + expect(imageParts[0].image).toContain("data:image/jpeg;base64,") + expect(imageParts[0].mimeType).toBe("image/jpeg") + expect(imageParts[1].image).toContain("data:image/png;base64,") + expect(imageParts[1].mimeType).toBe("image/png") }) }) @@ -686,6 +694,17 @@ describe("AwsBedrockHandler", () => { }) describe("1M context beta feature", () => { + function setupMockStreamText() { + async function* mockFullStream() { + yield { type: "text-delta", text: "Response" } + } + mockStreamText.mockReturnValue({ + fullStream: mockFullStream(), + usage: Promise.resolve({ inputTokens: 10, outputTokens: 5 }), + providerMetadata: Promise.resolve({}), + }) + } + it("should enable 1M context window when awsBedrock1MContext is true for Claude Sonnet 4", () => { const handler = new AwsBedrockHandler({ apiModelId: BEDROCK_1M_CONTEXT_MODEL_IDS[0], awsAccessKey: "test", @@ -731,7 +750,9 @@ describe("AwsBedrockHandler", () => { expect(model.info.contextWindow).toBe(200_000) }) - it("should include anthropic_beta parameter when 1M context is enabled", async () => { + it("should include anthropicBeta in providerOptions when 1M context is enabled", async () => { + setupMockStreamText() + const handler = new AwsBedrockHandler({ apiModelId: BEDROCK_1M_CONTEXT_MODEL_IDS[0], awsAccessKey: "test", @@ -748,23 +769,23 @@ ] const generator = handler.createMessage("", messages) - await generator.next() // Start the generator - - // Verify the command was created with the right payload - expect(mockConverseStreamCommand).toHaveBeenCalled() - const commandArg = mockConverseStreamCommand.mock.calls[0][0] as any - - // Should include anthropic_beta in additionalModelRequestFields with both 1M context and fine-grained-tool-streaming - expect(commandArg.additionalModelRequestFields).toBeDefined() - expect(commandArg.additionalModelRequestFields.anthropic_beta).toContain("context-1m-2025-08-07") - expect(commandArg.additionalModelRequestFields.anthropic_beta).toContain( - "fine-grained-tool-streaming-2025-05-14", - ) - // Should not include anthropic_version since thinking is not enabled - expect(commandArg.additionalModelRequestFields.anthropic_version).toBeUndefined() + const chunks: unknown[] = [] + for await (const chunk of generator) { + chunks.push(chunk) + } + + expect(mockStreamText).toHaveBeenCalledTimes(1) + const callArgs = mockStreamText.mock.calls[0][0] + + // Should include anthropicBeta in providerOptions.bedrock with 1M context + const bedrockOpts = callArgs.providerOptions?.bedrock as Record<string, unknown> | undefined + expect(bedrockOpts).toBeDefined() + expect(bedrockOpts!.anthropicBeta).toContain("context-1m-2025-08-07") }) - it("should not include 1M context beta when 1M context is disabled but still include fine-grained-tool-streaming", async () => { + it("should not include 1M context beta when 1M context is disabled", async () => { + setupMockStreamText() + const handler = new AwsBedrockHandler({ apiModelId: BEDROCK_1M_CONTEXT_MODEL_IDS[0], awsAccessKey: "test", @@ -781,22 +802,24 @@ ] const generator = handler.createMessage("", messages) - await generator.next() // Start the generator - - // Verify the command was created with the right payload - expect(mockConverseStreamCommand).toHaveBeenCalled() - const commandArg = mockConverseStreamCommand.mock.calls[0][0] as any - - // Should include anthropic_beta with fine-grained-tool-streaming for Claude models - expect(commandArg.additionalModelRequestFields).toBeDefined() - expect(commandArg.additionalModelRequestFields.anthropic_beta).toContain( - "fine-grained-tool-streaming-2025-05-14", - ) - // Should NOT include 1M context beta - expect(commandArg.additionalModelRequestFields.anthropic_beta).not.toContain("context-1m-2025-08-07") + const chunks: unknown[] = [] + for await (const chunk of generator) { + chunks.push(chunk) + } + + expect(mockStreamText).toHaveBeenCalledTimes(1) + const callArgs = mockStreamText.mock.calls[0][0] + + // Should NOT include anthropicBeta with 1M context + const bedrockOpts = callArgs.providerOptions?.bedrock as Record<string, unknown> | undefined + if (bedrockOpts?.anthropicBeta) { + expect(bedrockOpts.anthropicBeta).not.toContain("context-1m-2025-08-07") + } }) - it("should not include 1M context beta for non-Claude Sonnet 4 models but still include fine-grained-tool-streaming", async () => { + it("should not include 1M context beta for non-Claude Sonnet 4 models", async () => { + setupMockStreamText() + const handler = new AwsBedrockHandler({ apiModelId: "anthropic.claude-3-5-sonnet-20241022-v2:0", awsAccessKey: "test", @@ -813,19 +836,19 @@ ] const generator = handler.createMessage("", messages) - await generator.next() // Start the generator - - // Verify the command was created with the right payload - expect(mockConverseStreamCommand).toHaveBeenCalled() - const commandArg = mockConverseStreamCommand.mock.calls[0][0] as any - - // Should include anthropic_beta with fine-grained-tool-streaming for Claude models (even non-Sonnet 4) - expect(commandArg.additionalModelRequestFields).toBeDefined() - expect(commandArg.additionalModelRequestFields.anthropic_beta).toContain( - "fine-grained-tool-streaming-2025-05-14", - ) - // Should NOT include 1M context beta for non-Sonnet 4 models - expect(commandArg.additionalModelRequestFields.anthropic_beta).not.toContain("context-1m-2025-08-07") + const chunks: unknown[] = [] + for await (const chunk of generator) { + chunks.push(chunk) + } + + expect(mockStreamText).toHaveBeenCalledTimes(1) + const callArgs = mockStreamText.mock.calls[0][0] + + // Should NOT include anthropicBeta with 1M context for non-Sonnet 4 models + const bedrockOpts = callArgs.providerOptions?.bedrock as Record<string, unknown> | undefined + if (bedrockOpts?.anthropicBeta) { + expect(bedrockOpts.anthropicBeta).not.toContain("context-1m-2025-08-07") + } }) it("should enable 1M context window with cross-region inference for Claude Sonnet 4", () => { @@ -846,7 +869,9 @@ expect(model.id).toBe(`us.${BEDROCK_1M_CONTEXT_MODEL_IDS[0]}`) }) - it("should include anthropic_beta parameter with cross-region inference for Claude Sonnet 4", async () => { + it("should include anthropicBeta with cross-region inference for Claude Sonnet 4", async () => { + setupMockStreamText() + const handler = new AwsBedrockHandler({ apiModelId: BEDROCK_1M_CONTEXT_MODEL_IDS[0], awsAccessKey: "test", @@ -864,33 +889,34 @@ ] const generator = handler.createMessage("", messages) - await generator.next() // Start the generator - - // Verify the command was created with the right payload - expect(mockConverseStreamCommand).toHaveBeenCalled() - const commandArg = mockConverseStreamCommand.mock.calls[ - mockConverseStreamCommand.mock.calls.length - 1 - ][0] as any - - // Should include anthropic_beta in additionalModelRequestFields with both 1M context and fine-grained-tool-streaming - expect(commandArg.additionalModelRequestFields).toBeDefined() - expect(commandArg.additionalModelRequestFields.anthropic_beta).toContain("context-1m-2025-08-07") - expect(commandArg.additionalModelRequestFields.anthropic_beta).toContain( - "fine-grained-tool-streaming-2025-05-14", - ) - // Should not include anthropic_version since thinking is not enabled - expect(commandArg.additionalModelRequestFields.anthropic_version).toBeUndefined() - // Model ID should have cross-region prefix - expect(commandArg.modelId).toBe(`us.${BEDROCK_1M_CONTEXT_MODEL_IDS[0]}`) + const chunks: unknown[] = [] + for await (const chunk of generator) { + chunks.push(chunk) + } + + expect(mockStreamText).toHaveBeenCalledTimes(1) + const callArgs = mockStreamText.mock.calls[0][0] + + // Should include anthropicBeta in providerOptions.bedrock with 1M context + const bedrockOpts = callArgs.providerOptions?.bedrock as Record<string, unknown> | undefined + expect(bedrockOpts).toBeDefined() + expect(bedrockOpts!.anthropicBeta).toContain("context-1m-2025-08-07") }) }) describe("service tier feature", () => { const supportedModelId = BEDROCK_SERVICE_TIER_MODEL_IDS[0] // amazon.nova-lite-v1:0 - beforeEach(() => { - mockConverseStreamCommand.mockReset() - }) + function setupMockStreamText() { + async function* mockFullStream() { + yield { type: "text-delta", text: "Response" } + } + mockStreamText.mockReturnValue({ + fullStream: mockFullStream(), + usage: Promise.resolve({ inputTokens: 10, outputTokens: 5 }), + providerMetadata: Promise.resolve({}), + }) + } describe("pricing multipliers in getModel()", () => { it("should apply FLEX tier pricing with 50% discount", () => { @@ -976,7 +1002,9 @@ }) describe("service_tier parameter in API requests", () => { - it("should include service_tier as top-level parameter for supported models", async () => { + it("should include service_tier in providerOptions.bedrock.additionalModelRequestFields for supported models", async () => { + setupMockStreamText() + const handler = new AwsBedrockHandler({ apiModelId: supportedModelId, awsAccessKey: "test", @@ -993,23 +1021,27 @@ ] const generator = handler.createMessage("", messages) - await generator.next() // Start the generator - - // Verify the command was created with service_tier at top level - // Per AWS documentation, service_tier must be a top-level parameter, not inside additionalModelRequestFields - // https://docs.aws.amazon.com/bedrock/latest/userguide/service-tiers-inference.html - expect(mockConverseStreamCommand).toHaveBeenCalled() - const commandArg = mockConverseStreamCommand.mock.calls[0][0] as any - - // service_tier should be at the top level of the payload - expect(commandArg.service_tier).toBe("PRIORITY") - // service_tier should NOT be in additionalModelRequestFields - if (commandArg.additionalModelRequestFields) { - expect(commandArg.additionalModelRequestFields.service_tier).toBeUndefined() + const chunks: unknown[] = [] + for await (const chunk of generator) { + chunks.push(chunk) } + + expect(mockStreamText).toHaveBeenCalledTimes(1) + const callArgs = mockStreamText.mock.calls[0][0] + + // service_tier should be passed through providerOptions.bedrock.additionalModelRequestFields + const bedrockOpts = callArgs.providerOptions?.bedrock as Record<string, unknown> | undefined + expect(bedrockOpts).toBeDefined() + const additionalFields = bedrockOpts!.additionalModelRequestFields as + | Record<string, unknown> + | undefined + expect(additionalFields).toBeDefined() + expect(additionalFields!.service_tier).toBe("PRIORITY") })
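// The pass-through asserted above can be sketched in isolation. A minimal illustration,
// assuming the AI SDK forwards additionalModelRequestFields verbatim into the Bedrock
// request body; buildServiceTierOptions is a hypothetical helper, not part of the handler:
function buildServiceTierOptions(serviceTier?: "STANDARD" | "FLEX" | "PRIORITY") {
	const bedrock: Record<string, unknown> = {}
	if (serviceTier) {
		bedrock.additionalModelRequestFields = { service_tier: serviceTier }
	}
	return Object.keys(bedrock).length > 0 ? { bedrock } : undefined
}
// buildServiceTierOptions("PRIORITY")
// => { bedrock: { additionalModelRequestFields: { service_tier: "PRIORITY" } } }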
- it("should include service_tier FLEX as top-level parameter", async () => { + it("should include service_tier FLEX in providerOptions", async () => { + setupMockStreamText() + const handler = new AwsBedrockHandler({ apiModelId: supportedModelId, awsAccessKey: "test", @@ -1026,20 +1058,26 @@ ] const generator = handler.createMessage("", messages) - await generator.next() // Start the generator + const chunks: unknown[] = [] + for await (const chunk of generator) { + chunks.push(chunk) + } - expect(mockConverseStreamCommand).toHaveBeenCalled() - const commandArg = mockConverseStreamCommand.mock.calls[0][0] as any + expect(mockStreamText).toHaveBeenCalledTimes(1) + const callArgs = mockStreamText.mock.calls[0][0] - // service_tier should be at the top level of the payload - expect(commandArg.service_tier).toBe("FLEX") - // service_tier should NOT be in additionalModelRequestFields - if (commandArg.additionalModelRequestFields) { - expect(commandArg.additionalModelRequestFields.service_tier).toBeUndefined() - } + const bedrockOpts = callArgs.providerOptions?.bedrock as Record<string, unknown> | undefined + expect(bedrockOpts).toBeDefined() + const additionalFields = bedrockOpts!.additionalModelRequestFields as + | Record<string, unknown> + | undefined + expect(additionalFields).toBeDefined() + expect(additionalFields!.service_tier).toBe("FLEX") }) it("should NOT include service_tier for unsupported models", async () => { + setupMockStreamText() + const unsupportedModelId = "anthropic.claude-3-5-sonnet-20241022-v2:0" const handler = new AwsBedrockHandler({ apiModelId: unsupportedModelId, @@ -1057,19 +1095,25 @@ ] const generator = handler.createMessage("", messages) - await generator.next() // Start the generator + const chunks: unknown[] = [] + for await (const chunk of generator) { + chunks.push(chunk) + } - expect(mockConverseStreamCommand).toHaveBeenCalled() - const commandArg = mockConverseStreamCommand.mock.calls[0][0] as any + expect(mockStreamText).toHaveBeenCalledTimes(1) + const callArgs = mockStreamText.mock.calls[0][0] - // Service tier should NOT be included for unsupported models (at top level or in additionalModelRequestFields) - expect(commandArg.service_tier).toBeUndefined() - if (commandArg.additionalModelRequestFields) { - expect(commandArg.additionalModelRequestFields.service_tier).toBeUndefined() + // Service tier should NOT be included for unsupported models + const bedrockOpts = callArgs.providerOptions?.bedrock as Record<string, unknown> | undefined + if (bedrockOpts?.additionalModelRequestFields) { + const additionalFields = bedrockOpts.additionalModelRequestFields as Record<string, unknown> + expect(additionalFields.service_tier).toBeUndefined() } }) it("should NOT include service_tier when not specified", async () => { + setupMockStreamText() + const handler = new AwsBedrockHandler({ apiModelId: supportedModelId, awsAccessKey: "test", @@ -1086,15 +1130,19 @@ ] const generator = handler.createMessage("", messages) - await generator.next() // Start the generator + const chunks: unknown[] = [] + for await (const chunk of generator) { + chunks.push(chunk) + } - expect(mockConverseStreamCommand).toHaveBeenCalled() - const commandArg = mockConverseStreamCommand.mock.calls[0][0] as any - // Service tier should NOT be included when not specified (at top level or in additionalModelRequestFields) - expect(commandArg.service_tier).toBeUndefined() - if (commandArg.additionalModelRequestFields) { - expect(commandArg.additionalModelRequestFields.service_tier).toBeUndefined() + // Service tier should NOT be included when not specified + const bedrockOpts = callArgs.providerOptions?.bedrock as Record<string, unknown> | undefined + if (bedrockOpts?.additionalModelRequestFields) { + const additionalFields = bedrockOpts.additionalModelRequestFields as Record<string, unknown> + expect(additionalFields.service_tier).toBeUndefined() } }) }) @@ -1127,16 +1175,16 @@ }) describe("error telemetry", () => { - let mockSend: ReturnType<typeof vi.fn> - beforeEach(() => { mockCaptureException.mockClear() - // Get access to the mock send function from the mocked client - mockSend = vi.mocked(BedrockRuntimeClient).mock.results[0]?.value?.send }) it("should capture telemetry on createMessage error", async () => { - // Create a handler with a fresh mock + // Mock streamText to throw an error + mockStreamText.mockImplementation(() => { + throw new Error("Bedrock API error") + }) + const errorHandler = new AwsBedrockHandler({ apiModelId: "anthropic.claude-3-5-sonnet-20241022-v2:0", awsAccessKey: "test-access-key", @@ -1144,15 +1192,6 @@ awsRegion: "us-east-1", }) - // Get the mock send from the new handler instance - const clientInstance = - vi.mocked(BedrockRuntimeClient).mock.results[vi.mocked(BedrockRuntimeClient).mock.results.length - 1] - ?.value - const mockSendFn = clientInstance?.send as ReturnType<typeof vi.fn> - - // Mock the send to throw an error - mockSendFn.mockRejectedValueOnce(new Error("Bedrock API error")) - const messages: Anthropic.Messages.MessageParam[] = [ { role: "user", @@ -1186,7 +1225,9 @@ }) it("should capture telemetry on completePrompt error", async () => { - // Create a handler with a fresh mock + // Mock generateText to throw an error + mockGenerateText.mockRejectedValueOnce(new Error("Bedrock completion error")) + const errorHandler = new AwsBedrockHandler({ apiModelId: "anthropic.claude-3-5-sonnet-20241022-v2:0", awsAccessKey: "test-access-key", @@ -1194,15 +1235,6 @@ awsRegion: "us-east-1", }) - // Get the mock send from the new handler instance - const clientInstance = - vi.mocked(BedrockRuntimeClient).mock.results[vi.mocked(BedrockRuntimeClient).mock.results.length - 1] - ?.value - const mockSendFn = clientInstance?.send as ReturnType<typeof vi.fn> - - // Mock the send to throw an error for ConverseCommand - mockSendFn.mockRejectedValueOnce(new Error("Bedrock completion error")) - // Call completePrompt - it should throw await expect(errorHandler.completePrompt("Test prompt")).rejects.toThrow() @@ -1223,7 +1255,11 @@ })
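// For reference, the result shape the setupMockStreamText helpers in this file hand back
// to the handler. A sketch of the contract the AI SDK's streamText result satisfies
// (fullStream plus usage/providerMetadata promises), not its full surface; the helper
// name is illustrative:
function makeStreamTextResult(parts: Array<{ type: string; text?: string }>) {
	async function* fullStream() {
		yield* parts
	}
	return {
		fullStream: fullStream(),
		usage: Promise.resolve({ inputTokens: 10, outputTokens: 5 }),
		providerMetadata: Promise.resolve({}),
	}
}
// Usage: mockStreamText.mockReturnValue(makeStreamTextResult([{ type: "text-delta", text: "ok" }]))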
Error("Test error for throw verification")) - const messages: Anthropic.Messages.MessageParam[] = [ { role: "user", diff --git a/src/api/providers/bedrock.ts b/src/api/providers/bedrock.ts index ca747439ef3..375dd2c0421 100644 --- a/src/api/providers/bedrock.ts +++ b/src/api/providers/bedrock.ts @@ -1,24 +1,13 @@ -import { - BedrockRuntimeClient, - ConverseStreamCommand, - ConverseCommand, - BedrockRuntimeClientConfig, - ContentBlock, - Message, - SystemContentBlock, - Tool, - ToolConfiguration, - ToolChoice, -} from "@aws-sdk/client-bedrock-runtime" -import OpenAI from "openai" +import type { Anthropic } from "@anthropic-ai/sdk" +import { createAmazonBedrock, type AmazonBedrockProvider } from "@ai-sdk/amazon-bedrock" +import { streamText, generateText, ToolSet } from "ai" import { fromIni } from "@aws-sdk/credential-providers" -import { Anthropic } from "@anthropic-ai/sdk" +import OpenAI from "openai" import { type ModelInfo, type ProviderSettings, type BedrockModelId, - type BedrockServiceTier, bedrockDefaultModelId, bedrockModels, bedrockDefaultPromptRouterModelId, @@ -34,165 +23,22 @@ import { } from "@roo-code/types" import { TelemetryService } from "@roo-code/telemetry" -import { ApiStream } from "../transform/stream" +import type { ApiStream, ApiStreamUsageChunk } from "../transform/stream" +import { + convertToAiSdkMessages, + convertToolsForAiSdk, + processAiSdkStreamPart, + mapToolChoice, + handleAiSdkError, +} from "../transform/ai-sdk" +import { getModelParams } from "../transform/model-params" +import { shouldUseReasoningBudget } from "../../shared/api" import { BaseProvider } from "./base-provider" +import { DEFAULT_HEADERS } from "./constants" import { logger } from "../../utils/logging" import { Package } from "../../shared/package" -import { MultiPointStrategy } from "../transform/cache-strategy/multi-point-strategy" -import { ModelInfo as CacheModelInfo } from "../transform/cache-strategy/types" -import { convertToBedrockConverseMessages as sharedConverter } from "../transform/bedrock-converse-format" -import { getModelParams } from "../transform/model-params" -import { shouldUseReasoningBudget } from "../../shared/api" -import { normalizeToolSchema } from "../../utils/json-schema" import type { SingleCompletionHandler, ApiHandlerCreateMessageMetadata } from "../index" -/************************************************************************************ - * - * TYPES - * - *************************************************************************************/ - -// Define interface for Bedrock inference config -interface BedrockInferenceConfig { - maxTokens: number - temperature?: number -} - -// Define interface for Bedrock additional model request fields -// This includes thinking configuration, 1M context beta, and other model-specific parameters -interface BedrockAdditionalModelFields { - thinking?: { - type: "enabled" - budget_tokens: number - } - anthropic_beta?: string[] - [key: string]: any // Add index signature to be compatible with DocumentType -} - -// Define interface for Bedrock payload -interface BedrockPayload { - modelId: BedrockModelId | string - messages: Message[] - system?: SystemContentBlock[] - inferenceConfig: BedrockInferenceConfig - anthropic_version?: string - additionalModelRequestFields?: BedrockAdditionalModelFields - toolConfig?: ToolConfiguration -} - -// Extended payload type that includes service_tier as a top-level parameter -// AWS Bedrock service tiers (STANDARD, FLEX, PRIORITY) are specified at the top level -// 
https://docs.aws.amazon.com/bedrock/latest/userguide/service-tiers-inference.html -type BedrockPayloadWithServiceTier = BedrockPayload & { - service_tier?: BedrockServiceTier -} - -// Define specific types for content block events to avoid 'as any' usage -// These handle the multiple possible structures returned by AWS SDK -interface ContentBlockStartEvent { - start?: { - text?: string - thinking?: string - toolUse?: { - toolUseId?: string - name?: string - } - } - contentBlockIndex?: number - // Alternative structure used by some AWS SDK versions - content_block?: { - type?: string - thinking?: string - } - // Official AWS SDK structure for reasoning (as documented) - contentBlock?: { - type?: string - thinking?: string - reasoningContent?: { - text?: string - } - // Tool use block start - toolUse?: { - toolUseId?: string - name?: string - } - } -} - -interface ContentBlockDeltaEvent { - delta?: { - text?: string - thinking?: string - type?: string - // AWS SDK structure for reasoning content deltas - // Includes text (reasoning), signature (verification token), and redactedContent (safety-filtered) - reasoningContent?: { - text?: string - signature?: string - redactedContent?: Uint8Array - } - // Tool use input delta - toolUse?: { - input?: string - } - } - contentBlockIndex?: number -} - -// Define types for stream events based on AWS SDK -export interface StreamEvent { - messageStart?: { - role?: string - } - messageStop?: { - stopReason?: "end_turn" | "tool_use" | "max_tokens" | "stop_sequence" - additionalModelResponseFields?: Record - } - contentBlockStart?: ContentBlockStartEvent - contentBlockDelta?: ContentBlockDeltaEvent - metadata?: { - usage?: { - inputTokens: number - outputTokens: number - totalTokens?: number // Made optional since we don't use it - // New cache-related fields - cacheReadInputTokens?: number - cacheWriteInputTokens?: number - cacheReadInputTokenCount?: number - cacheWriteInputTokenCount?: number - } - metrics?: { - latencyMs: number - } - } - // New trace field for prompt router - trace?: { - promptRouter?: { - invokedModelId?: string - usage?: { - inputTokens: number - outputTokens: number - totalTokens?: number // Made optional since we don't use it - // New cache-related fields - cacheReadTokens?: number - cacheWriteTokens?: number - cacheReadInputTokenCount?: number - cacheWriteInputTokenCount?: number - } - } - } -} - -// Type for usage information in stream events -export type UsageType = { - inputTokens?: number - outputTokens?: number - cacheReadInputTokens?: number - cacheWriteInputTokens?: number - cacheReadInputTokenCount?: number - cacheWriteInputTokenCount?: number -} - /************************************************************************************ * * PROVIDER @@ -201,7 +47,7 @@ export type UsageType = { export class AwsBedrockHandler extends BaseProvider implements SingleCompletionHandler { protected options: ProviderSettings - private client: BedrockRuntimeClient + protected provider: AmazonBedrockProvider private arnInfo: any private readonly providerName = "Bedrock" private lastThoughtSignature: string | undefined @@ -212,10 +58,7 @@ export class AwsBedrockHandler extends BaseProvider implements SingleCompletionH this.options = options let region = this.options.awsRegion - // process the various user input options, be opinionated about the intent of the options - // and determine the model to use during inference and for cost calculations - // There are variations on ARN strings that can be entered making the conditional logic - // 
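// The hand-rolled Converse event types above disappear because the AI SDK now owns the
// wire format. In outline, the replacement flow looks like this. A simplified sketch with
// illustrative values, not the handler itself; the real implementation below threads in
// tools, provider options, prompt caching, and telemetry:
async function bedrockStreamSketch() {
	const { createAmazonBedrock } = await import("@ai-sdk/amazon-bedrock")
	const { streamText } = await import("ai")
	const bedrock = createAmazonBedrock({ region: "us-east-1", apiKey: "example-token" })
	const result = streamText({
		model: bedrock("anthropic.claude-sonnet-4-20250514-v1:0"),
		system: "You are concise.",
		messages: [{ role: "user", content: "Hello" }],
		maxOutputTokens: 1024,
	})
	for await (const part of result.fullStream) {
		if (part.type === "text-delta") process.stdout.write(part.text)
	}
	return await result.usage // resolves to { inputTokens, outputTokens, ... }
}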
/************************************************************************************ * * PROVIDER * *************************************************************************************/ export class AwsBedrockHandler extends BaseProvider implements SingleCompletionHandler { protected options: ProviderSettings - private client: BedrockRuntimeClient + protected provider: AmazonBedrockProvider private arnInfo: any private readonly providerName = "Bedrock" private lastThoughtSignature: string | undefined @@ -212,10 +58,7 @@ this.options = options let region = this.options.awsRegion - // process the various user input options, be opinionated about the intent of the options - // and determine the model to use during inference and for cost calculations - // There are variations on ARN strings that can be entered making the conditional logic - // more involved than the non-ARN branch of logic + // Process custom ARN if provided if (this.options.awsCustomArn) { this.arnInfo = this.parseArn(this.options.awsCustomArn, region) @@ -224,8 +67,6 @@ ctx: "bedrock", errorMessage: this.arnInfo.errorMessage, }) - - // Throw a consistent error with a prefix that can be detected by callers const errorMessage = this.arnInfo.errorMessage || "Invalid ARN format. ARN should follow the pattern: arn:aws:bedrock:region:account-id:resource-type/resource-name" @@ -233,21 +74,16 @@ } if (this.arnInfo.region && this.arnInfo.region !== this.options.awsRegion) { - // Log if there's a region mismatch between the ARN and the region selected by the user - // We will use the ARNs region, so execution can continue, but log an info statement. - // Log a warning if there's a region mismatch between the ARN and the region selected by the user - // We will use the ARNs region, so execution can continue, but log an info statement. logger.info(this.arnInfo.errorMessage, { ctx: "bedrock", selectedRegion: this.options.awsRegion, arnRegion: this.arnInfo.region, }) - this.options.awsRegion = this.arnInfo.region } this.options.apiModelId = this.arnInfo.modelId - if (this.arnInfo.awsUseCrossRegionInference) this.options.awsUseCrossRegionInference = true + if (this.arnInfo.crossRegionInference) this.options.awsUseCrossRegionInference = true } if (!this.options.modelTemperature) { @@ -256,44 +92,46 @@ this.costModelConfig = this.getModel() - const clientConfig: BedrockRuntimeClientConfig = { - userAgentAppId: `RooCode#${Package.version}`, + // Build provider settings for AI SDK + const providerSettings: Parameters<typeof createAmazonBedrock>[0] = { region: this.options.awsRegion, - // Add the endpoint configuration when specified and enabled + headers: { + ...DEFAULT_HEADERS, + "User-Agent": `RooCode#${Package.version}`, + }, + // Add VPC endpoint if specified and enabled ...(this.options.awsBedrockEndpoint && - this.options.awsBedrockEndpointEnabled && { endpoint: this.options.awsBedrockEndpoint }), + this.options.awsBedrockEndpointEnabled && { baseURL: this.options.awsBedrockEndpoint }), } if (this.options.awsUseApiKey && this.options.awsApiKey) { - // Use API key/token-based authentication if enabled and API key is set - clientConfig.token = { token: this.options.awsApiKey } - clientConfig.authSchemePreference = ["httpBearerAuth"] // Otherwise there's no end of credential problems. - clientConfig.requestHandler = { - // This should be the default anyway, but without setting something - // this provider fails to work with LiteLLM passthrough. - requestTimeout: 0, - } + // Use API key/token-based authentication + providerSettings.apiKey = this.options.awsApiKey } else if (this.options.awsUseProfile && this.options.awsProfile) { - // Use profile-based credentials if enabled and profile is set - clientConfig.credentials = fromIni({ - profile: this.options.awsProfile, - ignoreCache: true, - }) + // Use profile-based credentials via credentialProvider + const profile = this.options.awsProfile + providerSettings.credentialProvider = async () => { + const creds = await fromIni({ profile, ignoreCache: true })() + return { + accessKeyId: creds.accessKeyId, + secretAccessKey: creds.secretAccessKey, + ...(creds.sessionToken ? { sessionToken: creds.sessionToken } : {}), + } + } } else if (this.options.awsAccessKey && this.options.awsSecretKey) { - // Use direct credentials if provided - clientConfig.credentials = { - accessKeyId: this.options.awsAccessKey, - secretAccessKey: this.options.awsSecretKey, - ...(this.options.awsSessionToken ? { sessionToken: this.options.awsSessionToken } : {}), + // Use direct credentials + providerSettings.accessKeyId = this.options.awsAccessKey + providerSettings.secretAccessKey = this.options.awsSecretKey + if (this.options.awsSessionToken) { + providerSettings.sessionToken = this.options.awsSessionToken } } - this.client = new BedrockRuntimeClient(clientConfig) + this.provider = createAmazonBedrock(providerSettings) }
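	// For reference, the three authentication shapes the constructor can hand to
	// createAmazonBedrock, shown as standalone settings objects (illustrative values;
	// the credentialProvider variant mirrors the fromIni bridge above):
	//
	//	// 1. Bearer-token / API key authentication
	//	createAmazonBedrock({ region: "us-east-1", apiKey: "example-token" })
	//
	//	// 2. Shared-config profile, bridged through credentialProvider
	//	createAmazonBedrock({
	//		region: "us-east-1",
	//		credentialProvider: async () => {
	//			const creds = await fromIni({ profile: "example-profile", ignoreCache: true })()
	//			return {
	//				accessKeyId: creds.accessKeyId,
	//				secretAccessKey: creds.secretAccessKey,
	//				...(creds.sessionToken ? { sessionToken: creds.sessionToken } : {}),
	//			}
	//		},
	//	})
	//
	//	// 3. Static credentials
	//	createAmazonBedrock({ region: "us-east-1", accessKeyId: "example-id", secretAccessKey: "example-secret" })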
// Helper to guess model info from custom modelId string if not in bedrockModels private guessModelInfoFromId(modelId: string): Partial<ModelInfo> { - // Define a mapping for model ID patterns and their configurations const modelConfigMap: Record<string, Partial<ModelInfo>> = { "claude-4": { maxTokens: 8192, @@ -333,7 +171,6 @@ }, } - // Match the model ID to a configuration const id = modelId.toLowerCase() for (const [pattern, config] of Object.entries(modelConfigMap)) { if (id.includes(pattern)) { @@ -341,7 +178,6 @@ } } - // Default fallback return { maxTokens: BEDROCK_MAX_TOKENS, contextWindow: BEDROCK_DEFAULT_CONTEXT, @@ -353,619 +189,343 @@ override async *createMessage( systemPrompt: string, messages: Anthropic.Messages.MessageParam[], - metadata?: ApiHandlerCreateMessageMetadata & { - thinking?: { - enabled: boolean - maxTokens?: number - maxThinkingTokens?: number - } - }, + metadata?: ApiHandlerCreateMessageMetadata, ): ApiStream { const modelConfig = this.getModel() - const usePromptCache = Boolean(this.options.awsUsePromptCache && this.supportsAwsPromptCache(modelConfig)) - const conversationId = - messages.length > 0 - ? `conv_${messages[0].role}_${ - typeof messages[0].content === "string" - ? messages[0].content.substring(0, 20) - : "complex_content" - }` - : "default_conversation" - - const formatted = this.convertToBedrockConverseMessages( - messages, - systemPrompt, - usePromptCache, - modelConfig.info, - conversationId, - ) + // Reset thinking state for this request + this.lastThoughtSignature = undefined + this.lastRedactedThinkingBlocks = [] - let additionalModelRequestFields: BedrockAdditionalModelFields | undefined - let thinkingEnabled = false + // Filter out provider-specific meta entries (e.g., { type: "reasoning" }) + // that are not valid Anthropic MessageParam values + type ReasoningMetaLike = { type?: string } + const filteredMessages = messages.filter((message): message is Anthropic.Messages.MessageParam => { + const meta = message as ReasoningMetaLike + if (meta.type === "reasoning") { + return false + } + return true + }) + + // Convert messages to AI SDK format + const aiSdkMessages = convertToAiSdkMessages(filteredMessages) + + // Convert tools to AI SDK format + let openAiTools = this.convertToolsForOpenAI(metadata?.tools) + const aiSdkTools = convertToolsForAiSdk(openAiTools) as ToolSet | undefined + const toolChoice = mapToolChoice(metadata?.tool_choice) - // Determine if thinking should be enabled - // metadata?.thinking?.enabled: Explicitly enabled through API metadata (direct request) - // shouldUseReasoningBudget(): Enabled through user settings (enableReasoningEffort = true) - const isThinkingExplicitlyEnabled = metadata?.thinking?.enabled + // Build provider options for reasoning, betas, etc. + const bedrockProviderOptions: Record<string, unknown> = {} + + // Extended thinking / reasoning configuration const isThinkingEnabledBySettings = shouldUseReasoningBudget({ model: modelConfig.info, settings: this.options }) && modelConfig.reasoning && modelConfig.reasoningBudget - if ((isThinkingExplicitlyEnabled || isThinkingEnabledBySettings) && modelConfig.info.supportsReasoningBudget) { - thinkingEnabled = true - additionalModelRequestFields = { - thinking: { - type: "enabled", - budget_tokens: metadata?.thinking?.maxThinkingTokens || modelConfig.reasoningBudget || 4096, - }, + if (isThinkingEnabledBySettings && modelConfig.info.supportsReasoningBudget) { + bedrockProviderOptions.reasoningConfig = { + type: "enabled", + budgetTokens: modelConfig.reasoningBudget, } - logger.info("Extended thinking enabled for Bedrock request", { - ctx: "bedrock", - modelId: modelConfig.id, - thinking: additionalModelRequestFields.thinking, - }) } - const inferenceConfig: BedrockInferenceConfig = { - maxTokens: modelConfig.maxTokens || (modelConfig.info.maxTokens as number), - temperature: modelConfig.temperature ?? (this.options.modelTemperature as number), - } - - // Check if 1M context is enabled for supported Claude 4 models - // Use parseBaseModelId to handle cross-region inference prefixes - const baseModelId = this.parseBaseModelId(modelConfig.id) - const is1MContextEnabled = - BEDROCK_1M_CONTEXT_MODEL_IDS.includes(baseModelId as any) && this.options.awsBedrock1MContext - - // Determine if service tier should be applied (checked later when building payload) - const useServiceTier = - this.options.awsBedrockServiceTier && BEDROCK_SERVICE_TIER_MODEL_IDS.includes(baseModelId as any) - if (useServiceTier) { - logger.info("Service tier specified for Bedrock request", { - ctx: "bedrock", - modelId: modelConfig.id, - serviceTier: this.options.awsBedrockServiceTier, - }) - } - - // Add anthropic_beta headers for various features - // Start with an empty array and add betas as needed + // Anthropic beta headers for various features const anthropicBetas: string[] = [] + const baseModelId = this.parseBaseModelId(modelConfig.id) // Add 1M context beta if enabled - if (is1MContextEnabled) { + if (BEDROCK_1M_CONTEXT_MODEL_IDS.includes(baseModelId as any) && this.options.awsBedrock1MContext) { anthropicBetas.push("context-1m-2025-08-07") } - // Add fine-grained tool streaming beta for Claude models - // This enables proper tool use streaming for Anthropic models on Bedrock - if (baseModelId.includes("claude")) { - anthropicBetas.push("fine-grained-tool-streaming-2025-05-14") - } - - // Apply anthropic_beta to additionalModelRequestFields if any betas are needed if (anthropicBetas.length > 0) { - if (!additionalModelRequestFields) { - additionalModelRequestFields = {} as BedrockAdditionalModelFields - } - additionalModelRequestFields.anthropic_beta = anthropicBetas - } - - const toolConfig: ToolConfiguration = { - tools: this.convertToolsForBedrock(metadata?.tools ?? []), - toolChoice: this.convertToolChoiceForBedrock(metadata?.tool_choice), - } - - // Build payload with optional service_tier at top level - // Service tier is a top-level parameter per AWS documentation, NOT inside additionalModelRequestFields - // https://docs.aws.amazon.com/bedrock/latest/userguide/service-tiers-inference.html - const payload: BedrockPayloadWithServiceTier = { - modelId: modelConfig.id, - messages: formatted.messages, - system: formatted.system, - inferenceConfig, - ...(additionalModelRequestFields && { additionalModelRequestFields }), - // Add anthropic_version at top level when using thinking features - ...(thinkingEnabled && { anthropic_version: "bedrock-2023-05-31" }), - toolConfig, - // Add service_tier as a top-level parameter (not inside additionalModelRequestFields) - ...(useServiceTier && { service_tier: this.options.awsBedrockServiceTier }), - } + bedrockProviderOptions.anthropicBeta = anthropicBetas + } + + // Additional model request fields (service tier, etc.) + // Note: The AI SDK may not directly support service_tier as a top-level param, + // so we pass it through additionalModelRequestFields + if (this.options.awsBedrockServiceTier && BEDROCK_SERVICE_TIER_MODEL_IDS.includes(baseModelId as any)) { + bedrockProviderOptions.additionalModelRequestFields = { + ...(bedrockProviderOptions.additionalModelRequestFields as Record<string, unknown> | undefined), + service_tier: this.options.awsBedrockServiceTier, + } + }
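		// Taken together, a request with reasoning, the 1M-context beta, and a service tier
		// carries a providerOptions payload along these lines (illustrative values):
		//
		//	providerOptions: {
		//		bedrock: {
		//			reasoningConfig: { type: "enabled", budgetTokens: 4096 },
		//			anthropicBeta: ["context-1m-2025-08-07"],
		//			additionalModelRequestFields: { service_tier: "PRIORITY" },
		//		},
		//	}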
We use them as: + // 1. System prompt (via systemProviderOptions below) + // 2-4. Up to 3 user messages in the conversation history + // + // For the message cache points, we target the last 2 user messages (matching + // Anthropic's strategy: write-to-cache + read-from-cache) PLUS an earlier "anchor" + // user message near the middle of the conversation. This anchor ensures the 20-block + // lookback window has a stable cache entry to hit, covering all assistant/tool messages + // between the anchor and the recent messages. + // + // We identify targets in the ORIGINAL Anthropic messages (before AI SDK conversion) + // because convertToAiSdkMessages() splits user messages containing tool_results into + // separate "tool" + "user" role messages, which would skew naive counting. + const usePromptCache = Boolean(this.options.awsUsePromptCache && this.supportsAwsPromptCache(modelConfig)) - // Create AbortController with 10 minute timeout - const controller = new AbortController() - let timeoutId: NodeJS.Timeout | undefined + if (usePromptCache) { + const cachePointOption = { bedrock: { cachePoint: { type: "default" as const } } } - try { - timeoutId = setTimeout( - () => { - controller.abort() - }, - 10 * 60 * 1000, + // Find all user message indices in the original (pre-conversion) message array. + const originalUserIndices = filteredMessages.reduce( + (acc, msg, idx) => (msg.role === "user" ? [...acc, idx] : acc), + [], ) - const command = new ConverseStreamCommand(payload) - const response = await this.client.send(command, { - abortSignal: controller.signal, - }) - - if (!response.stream) { - clearTimeout(timeoutId) - throw new Error("No stream available in the response") + // Select up to 3 user messages for cache points (system prompt uses the 4th): + // - Last user message: write to cache for next request + // - Second-to-last user message: read from cache for current request + // - An "anchor" message earlier in the conversation for 20-block window coverage + const targetOriginalIndices = new Set() + const numUserMsgs = originalUserIndices.length + + if (numUserMsgs >= 1) { + // Always cache the last user message + targetOriginalIndices.add(originalUserIndices[numUserMsgs - 1]) + } + if (numUserMsgs >= 2) { + // Cache the second-to-last user message + targetOriginalIndices.add(originalUserIndices[numUserMsgs - 2]) + } + if (numUserMsgs >= 5) { + // Add an anchor cache point roughly in the first third of user messages. + // This ensures that the 20-block lookback from the second-to-last breakpoint + // can find a stable cache entry, covering all the assistant and tool messages + // in the middle of the conversation. We pick the user message at ~1/3 position. + const anchorIdx = Math.floor(numUserMsgs / 3) + // Only add if it's not already one of the last-2 targets + if (!targetOriginalIndices.has(originalUserIndices[anchorIdx])) { + targetOriginalIndices.add(originalUserIndices[anchorIdx]) + } } - // Reset thinking state for this request - this.lastThoughtSignature = undefined - this.lastRedactedThinkingBlocks = [] - - for await (const chunk of response.stream) { - // Parse the chunk as JSON if it's a string (for tests) - let streamEvent: StreamEvent - try { - streamEvent = typeof chunk === "string" ? JSON.parse(chunk) : (chunk as unknown as StreamEvent) - } catch (e) { - logger.error("Failed to parse stream event", { + // Apply cachePoint to the correct AI SDK messages by walking both arrays in parallel. 
-			// Reset thinking state for this request
-			this.lastThoughtSignature = undefined
-			this.lastRedactedThinkingBlocks = []
-
-			for await (const chunk of response.stream) {
-				// Parse the chunk as JSON if it's a string (for tests)
-				let streamEvent: StreamEvent
-				try {
-					streamEvent = typeof chunk === "string" ? JSON.parse(chunk) : (chunk as unknown as StreamEvent)
-				} catch (e) {
-					logger.error("Failed to parse stream event", {
+			// Apply cachePoint to the correct AI SDK messages by walking both arrays in parallel.
+			// A single original user message with tool_results becomes [tool-role msg, user-role msg]
+			// in the AI SDK array, while a plain user message becomes [user-role msg].
+			if (targetOriginalIndices.size > 0) {
+				this.applyCachePointsToAiSdkMessages(
+					filteredMessages,
+					aiSdkMessages,
+					targetOriginalIndices,
+					cachePointOption,
+				)
+			}
+		}
+
+		// Build streamText request
+		// Cast providerOptions to any to bypass strict JSONObject typing — the AI SDK accepts the correct runtime values
+		const requestOptions: Parameters<typeof streamText>[0] = {
+			model: this.provider(modelConfig.id),
+			system: systemPrompt,
+			...(usePromptCache && {
+				systemProviderOptions: { bedrock: { cachePoint: { type: "default" } } } as Record<string, unknown>,
+			}),
+			messages: aiSdkMessages,
+			temperature: modelConfig.temperature ?? (this.options.modelTemperature as number),
+			maxOutputTokens: modelConfig.maxTokens || (modelConfig.info.maxTokens as number),
+			tools: aiSdkTools,
+			toolChoice,
+			...(Object.keys(bedrockProviderOptions).length > 0 && {
+				providerOptions: { bedrock: bedrockProviderOptions } as any,
+			}),
+		}
+
+		try {
+			const result = streamText(requestOptions)
+
+			// Process the full stream
+			for await (const part of result.fullStream) {
+				// Capture thinking signature from stream events.
+				// The AI SDK's @ai-sdk/amazon-bedrock emits the signature as a reasoning-delta
+				// event with providerMetadata.bedrock.signature (empty delta text, signature in metadata).
+				// Also check tool-call events for thoughtSignature (Gemini pattern).
+				const partAny = part as any
+				if (partAny.providerMetadata?.bedrock?.signature) {
+					this.lastThoughtSignature = partAny.providerMetadata.bedrock.signature
+					logger.info("Captured thinking signature from stream", {
						ctx: "bedrock",
-					error: e instanceof Error ? e : String(e),
-					chunk: typeof chunk === "string" ?
chunk : "binary data", + signatureLength: this.lastThoughtSignature?.length, }) - continue - } - - // Handle metadata events first - if (streamEvent.metadata?.usage) { - const usage = (streamEvent.metadata?.usage || {}) as UsageType - - // Check both field naming conventions for cache tokens - const cacheReadTokens = usage.cacheReadInputTokens || usage.cacheReadInputTokenCount || 0 - const cacheWriteTokens = usage.cacheWriteInputTokens || usage.cacheWriteInputTokenCount || 0 - - // Always include all available token information - yield { - type: "usage", - inputTokens: usage.inputTokens || 0, - outputTokens: usage.outputTokens || 0, - cacheReadTokens: cacheReadTokens, - cacheWriteTokens: cacheWriteTokens, - } - continue + } else if (partAny.providerMetadata?.bedrock?.thoughtSignature) { + this.lastThoughtSignature = partAny.providerMetadata.bedrock.thoughtSignature + } else if (partAny.providerMetadata?.anthropic?.thoughtSignature) { + this.lastThoughtSignature = partAny.providerMetadata.anthropic.thoughtSignature } - if (streamEvent?.trace?.promptRouter?.invokedModelId) { - try { - //update the in-use model info to be based on the invoked Model Id for the router - //so that pricing, context window, caching etc have values that can be used - //However, we want to keep the id of the model to be the ID for the router for - //subsequent requests so they are sent back through the router - let invokedArnInfo = this.parseArn(streamEvent.trace.promptRouter.invokedModelId) - let invokedModel = this.getModelById(invokedArnInfo.modelId as string, invokedArnInfo.modelType) - if (invokedModel) { - invokedModel.id = modelConfig.id - this.costModelConfig = invokedModel - } - - // Handle metadata events for the promptRouter. - if (streamEvent?.trace?.promptRouter?.usage) { - const routerUsage = streamEvent.trace.promptRouter.usage - - // Check both field naming conventions for cache tokens - const cacheReadTokens = - routerUsage.cacheReadTokens || routerUsage.cacheReadInputTokenCount || 0 - const cacheWriteTokens = - routerUsage.cacheWriteTokens || routerUsage.cacheWriteInputTokenCount || 0 - - yield { - type: "usage", - inputTokens: routerUsage.inputTokens || 0, - outputTokens: routerUsage.outputTokens || 0, - cacheReadTokens: cacheReadTokens, - cacheWriteTokens: cacheWriteTokens, - } - } - } catch (error) { - logger.error("Error handling Bedrock invokedModelId", { - ctx: "bedrock", - error: error instanceof Error ? 
error : String(error), - }) - } finally { - // eslint-disable-next-line no-unsafe-finally - continue - } - } - - // Handle message start - if (streamEvent.messageStart) { - continue - } - - // Handle content blocks - if (streamEvent.contentBlockStart) { - const cbStart = streamEvent.contentBlockStart - - // Check if this is a reasoning block (AWS SDK structure) - if (cbStart.contentBlock?.reasoningContent) { - if (cbStart.contentBlockIndex && cbStart.contentBlockIndex > 0) { - yield { type: "reasoning", text: "\n" } - } - yield { - type: "reasoning", - text: cbStart.contentBlock.reasoningContent.text || "", - } - } - // Check for thinking block - handle both possible AWS SDK structures - // cbStart.contentBlock: newer structure - // cbStart.content_block: alternative structure seen in some AWS SDK versions - else if (cbStart.contentBlock?.type === "thinking" || cbStart.content_block?.type === "thinking") { - const contentBlock = cbStart.contentBlock || cbStart.content_block - if (cbStart.contentBlockIndex && cbStart.contentBlockIndex > 0) { - yield { type: "reasoning", text: "\n" } - } - if (contentBlock?.thinking) { - yield { - type: "reasoning", - text: contentBlock.thinking, - } - } - } - // Handle tool use block start - else if (cbStart.start?.toolUse || cbStart.contentBlock?.toolUse) { - const toolUse = cbStart.start?.toolUse || cbStart.contentBlock?.toolUse - if (toolUse) { - yield { - type: "tool_call_partial", - index: cbStart.contentBlockIndex ?? 0, - id: toolUse.toolUseId, - name: toolUse.name, - arguments: undefined, - } - } - } else if (cbStart.start?.text) { - yield { - type: "text", - text: cbStart.start.text, - } - } - continue + // Capture redacted reasoning data from stream events + if (partAny.providerMetadata?.bedrock?.redactedData) { + this.lastRedactedThinkingBlocks.push({ + type: "redacted_thinking", + data: partAny.providerMetadata.bedrock.redactedData, + }) } - // Handle content deltas - if (streamEvent.contentBlockDelta) { - const cbDelta = streamEvent.contentBlockDelta - const delta = cbDelta.delta - - // Process reasoning and text content deltas - // Multiple structures are supported for AWS SDK compatibility: - // - delta.reasoningContent.text: AWS docs structure for reasoning - // - delta.thinking: alternative structure for thinking content - // - delta.text: standard text content - // - delta.toolUse.input: tool input arguments - if (delta) { - // Check for reasoningContent property (AWS SDK structure) - if (delta.reasoningContent?.text) { - yield { - type: "reasoning", - text: delta.reasoningContent.text, - } - continue - } - - // Capture the thinking signature from reasoningContent.signature delta. - // Bedrock Converse API sends the signature as a separate delta after all - // reasoning text deltas. This signature must be round-tripped back for - // multi-turn conversations with tool use (Anthropic API requirement). - if (delta.reasoningContent?.signature) { - this.lastThoughtSignature = delta.reasoningContent.signature - continue - } - - // Capture redacted thinking content (opaque binary data from safety-filtered reasoning). - // Anthropic returns this when extended thinking content is filtered. It must be - // passed back verbatim in multi-turn conversations for proper reasoning continuity. 
- if (delta.reasoningContent?.redactedContent) { - const redactedContent = delta.reasoningContent.redactedContent - this.lastRedactedThinkingBlocks.push({ - type: "redacted_thinking", - data: Buffer.from(redactedContent).toString("base64"), - }) - continue - } - - // Handle tool use input delta - if (delta.toolUse?.input) { - yield { - type: "tool_call_partial", - index: cbDelta.contentBlockIndex ?? 0, - id: undefined, - name: undefined, - arguments: delta.toolUse.input, - } - continue - } - - // Handle alternative thinking structure (fallback for older SDK versions) - if (delta.type === "thinking_delta" && delta.thinking) { - yield { - type: "reasoning", - text: delta.thinking, - } - } else if (delta.text) { - yield { - type: "text", - text: delta.text, - } - } - } - continue - } - // Handle message stop - if (streamEvent.messageStop) { - continue + for (const chunk of processAiSdkStreamPart(part)) { + yield chunk } } - // Clear timeout after stream completes - clearTimeout(timeoutId) - } catch (error: unknown) { - // Clear timeout on error - clearTimeout(timeoutId) - // Capture error in telemetry before processing + // Yield usage metrics at the end + const usage = await result.usage + const providerMetadata = await result.providerMetadata + if (usage) { + yield this.processUsageMetrics(usage, modelConfig.info, providerMetadata) + } + } catch (error) { const errorMessage = error instanceof Error ? error.message : String(error) const apiError = new ApiProviderError(errorMessage, this.providerName, modelConfig.id, "createMessage") TelemetryService.instance.captureException(apiError) - // Check if this is a throttling error that should trigger retry logic - const errorType = this.getErrorType(error) - - // For throttling errors, throw immediately without yielding chunks - // This allows the retry mechanism in attemptApiRequest() to catch and handle it - // The retry logic in Task.ts (around line 1817) expects errors to be thrown - // on the first chunk for proper exponential backoff behavior - if (errorType === "THROTTLING") { + // Check for throttling errors that should trigger retry (re-throw original to preserve status) + if (this.isThrottlingError(error)) { if (error instanceof Error) { throw error - } else { - throw new Error("Throttling error occurred") } + throw new Error("Throttling error occurred") } - // For non-throttling errors, use the standard error handling with chunks - const errorChunks = this.handleBedrockError(error, true) // true for streaming context - // Yield each chunk individually to ensure type compatibility - for (const chunk of errorChunks) { - yield chunk as any // Cast to any to bypass type checking since we know the structure is correct - } - - // Re-throw with enhanced error message for retry system - const enhancedErrorMessage = this.formatErrorMessage(error, this.getErrorType(error), true) - if (error instanceof Error) { - const enhancedError = new Error(enhancedErrorMessage) - // Preserve important properties from the original error - enhancedError.name = error.name - // Validate and preserve status property - if ("status" in error && typeof (error as any).status === "number") { - ;(enhancedError as any).status = (error as any).status - } - // Validate and preserve $metadata property - if ( - "$metadata" in error && - typeof (error as any).$metadata === "object" && - (error as any).$metadata !== null - ) { - ;(enhancedError as any).$metadata = (error as any).$metadata - } - throw enhancedError - } else { - throw new Error("An unknown error occurred") - } + 
// Handle AI SDK errors (AI_RetryError, AI_APICallError, etc.)
+			throw handleAiSdkError(error, this.providerName)
		}
	}

-	async completePrompt(prompt: string): Promise<string> {
-		try {
-			const modelConfig = this.getModel()
-
-			// For completePrompt, thinking is typically not used, but we should still check
-			// if thinking was somehow enabled in the model config
-			const thinkingEnabled =
-				shouldUseReasoningBudget({ model: modelConfig.info, settings: this.options }) &&
-				modelConfig.reasoning &&
-				modelConfig.reasoningBudget
-
-			const inferenceConfig: BedrockInferenceConfig = {
-				maxTokens: modelConfig.maxTokens || (modelConfig.info.maxTokens as number),
-				temperature: modelConfig.temperature ?? (this.options.modelTemperature as number),
-			}
-
-			// For completePrompt, use a unique conversation ID based on the prompt
-			const conversationId = `prompt_${prompt.substring(0, 20)}`
-
-			const payload = {
-				modelId: modelConfig.id,
-				messages: this.convertToBedrockConverseMessages(
-					[
-						{
-							role: "user",
-							content: prompt,
-						},
-					],
-					undefined,
-					false,
-					modelConfig.info,
-					conversationId,
-				).messages,
-				inferenceConfig,
-			}
-
-			const command = new ConverseCommand(payload)
-			const response = await this.client.send(command)
-
-			if (
-				response?.output?.message?.content &&
-				response.output.message.content.length > 0 &&
-				response.output.message.content[0].text &&
-				response.output.message.content[0].text.trim().length > 0
-			) {
-				try {
-					return response.output.message.content[0].text
-				} catch (parseError) {
-					logger.error("Failed to parse Bedrock response", {
-						ctx: "bedrock",
-						error: parseError instanceof Error ? parseError : String(parseError),
-					})
+	/**
+	 * Process usage metrics from the AI SDK response.
+	 */
+	private processUsageMetrics(
+		usage: { inputTokens?: number; outputTokens?: number },
+		info: ModelInfo,
+		providerMetadata?: Record<string, Record<string, unknown>>,
+	): ApiStreamUsageChunk {
+		const inputTokens = usage.inputTokens ?? 0
+		const outputTokens = usage.outputTokens ?? 0
+
+		// The AI SDK exposes reasoningTokens as a top-level field on usage, and also
+		// under outputTokenDetails.reasoningTokens — there is no .details property.
+		const reasoningTokens =
+			(usage as any).reasoningTokens ?? (usage as any).outputTokenDetails?.reasoningTokens ?? 0
+
+		// Extract cache metrics primarily from usage (AI SDK standard locations),
+		// falling back to providerMetadata.bedrock.usage for provider-specific fields.
+		const bedrockUsage = providerMetadata?.bedrock?.usage as
+			| { cacheReadInputTokens?: number; cacheWriteInputTokens?: number }
+			| undefined
+		const cacheReadTokens =
+			(usage as any).inputTokenDetails?.cacheReadTokens ??
+			(usage as any).cachedInputTokens ??
+			bedrockUsage?.cacheReadInputTokens ??
+			0
+		const cacheWriteTokens =
+			(usage as any).inputTokenDetails?.cacheWriteTokens ?? bedrockUsage?.cacheWriteInputTokens ?? 0
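The fallback chain above reads as: standard AI SDK usage fields first, then Bedrock-specific provider metadata. A condensed, self-contained sketch of the same lookup with hypothetical numbers:

	const exampleUsage = { inputTokens: 1200, outputTokens: 300, cachedInputTokens: 800 } as any
	const exampleBedrockUsage = { cacheReadInputTokens: 800, cacheWriteInputTokens: 150 }

	// Mirrors the read order in the code above.
	const cacheRead =
		exampleUsage.inputTokenDetails?.cacheReadTokens ??
		exampleUsage.cachedInputTokens ??
		exampleBedrockUsage.cacheReadInputTokens ??
		0 // evaluates to 800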
+
+		// For prompt routers, the AI SDK surfaces the invoked model ID in
+		// providerMetadata.bedrock.trace.promptRouter.invokedModelId.
+		// When present, look up that model's pricing info for accurate cost calculation.
+		const invokedModelId = (providerMetadata?.bedrock as any)?.trace?.promptRouter?.invokedModelId as
+			| string
+			| undefined
+		let costInfo = info
+		if (invokedModelId) {
+			try {
+				const invokedArnInfo = this.parseArn(invokedModelId)
+				const invokedModel = this.getModelById(invokedArnInfo.modelId as string, invokedArnInfo.modelType)
+				if (invokedModel) {
+					// Update costModelConfig so subsequent requests use the invoked model's pricing,
+					// keeping the previously configured router ID (when set) so requests continue
+					// through the router.
+					invokedModel.id = this.costModelConfig.id || invokedModel.id
+					this.costModelConfig = invokedModel
+					costInfo = invokedModel.info
				}
+			} catch (error) {
+				logger.error("Error handling Bedrock invokedModelId", {
+					ctx: "bedrock",
+					error: error instanceof Error ? error : String(error),
+				})
			}
-			return ""
-		} catch (error) {
-			// Capture error in telemetry
-			const model = this.getModel()
-			const telemetryErrorMessage = error instanceof Error ? error.message : String(error)
-			const apiError = new ApiProviderError(telemetryErrorMessage, this.providerName, model.id, "completePrompt")
-			TelemetryService.instance.captureException(apiError)
+		}

-			// Use the extracted error handling method for all errors
-			const errorResult = this.handleBedrockError(error, false) // false for non-streaming context
-			// Since we're in a non-streaming context, we know the result is a string
-			const errorMessage = errorResult as string
-
-			// Create enhanced error for retry system
-			const enhancedError = new Error(errorMessage)
-			if (error instanceof Error) {
-				// Preserve important properties from the original error
-				enhancedError.name = error.name
-				// Validate and preserve status property
-				if ("status" in error && typeof (error as any).status === "number") {
-					;(enhancedError as any).status = (error as any).status
-				}
-				// Validate and preserve $metadata property
-				if (
-					"$metadata" in error &&
-					typeof (error as any).$metadata === "object" &&
-					(error as any).$metadata !== null
-				) {
-					;(enhancedError as any).$metadata = (error as any).$metadata
-				}
-			}
-			throw enhancedError
+		return {
+			type: "usage",
+			inputTokens,
+			outputTokens,
+			cacheReadTokens: cacheReadTokens > 0 ? cacheReadTokens : undefined,
+			cacheWriteTokens: cacheWriteTokens > 0 ? cacheWriteTokens : undefined,
+			reasoningTokens: reasoningTokens > 0 ? reasoningTokens : undefined,
+			totalCost: this.calculateCost({
+				inputTokens,
+				outputTokens,
+				cacheWriteTokens,
+				cacheReadTokens,
+				reasoningTokens,
+				info: costInfo,
+			}),
		}
	}
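As a worked example of the cost math this feeds into (hypothetical per-million prices: $3 input, $15 output, $3.75 cache write, $0.30 cache read), a request with 1200 input and 300 output tokens, of which 800 were cache reads and 150 cache writes, bills only 250 uncached input tokens:

	// uncachedInput = max(0, 1200 - 150 - 800) = 250
	const exampleCost =
		3.0 * (250 / 1_000_000) + // uncached input
		15.0 * (300 / 1_000_000) + // output (plus reasoning tokens, 0 here)
		3.75 * (150 / 1_000_000) + // cache writes
		0.3 * (800 / 1_000_000) // cache reads
	// exampleCost ≈ 0.006053 (USD)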
	/**
-	 * Convert Anthropic messages to Bedrock Converse format
+	 * Check if an error is a throttling/rate limit error
	 */
-	private convertToBedrockConverseMessages(
-		anthropicMessages: Anthropic.Messages.MessageParam[] | { role: string; content: string }[],
-		systemMessage?: string,
-		usePromptCache: boolean = false,
-		modelInfo?: any,
-		conversationId?: string, // Optional conversation ID to track cache points across messages
-	): { system: SystemContentBlock[]; messages: Message[] } {
-		// First convert messages using shared converter for proper image handling
-		const convertedMessages = sharedConverter(anthropicMessages as Anthropic.Messages.MessageParam[])
-
-		// If prompt caching is disabled, return the converted messages directly
-		if (!usePromptCache) {
-			return {
-				system: systemMessage ? [{ text: systemMessage } as SystemContentBlock] : [],
-				messages: convertedMessages,
-			}
-		}
-
-		// Convert model info to expected format for cache strategy
-		const cacheModelInfo: CacheModelInfo = {
-			maxTokens: modelInfo?.maxTokens || 8192,
-			contextWindow: modelInfo?.contextWindow || 200_000,
-			supportsPromptCache: modelInfo?.supportsPromptCache || false,
-			maxCachePoints: modelInfo?.maxCachePoints || 0,
-			minTokensPerCachePoint: modelInfo?.minTokensPerCachePoint || 50,
-			cachableFields: modelInfo?.cachableFields || [],
-		}
-
-		// Get previous cache point placements for this conversation if available
-		const previousPlacements =
-			conversationId && this.previousCachePointPlacements[conversationId]
-				? this.previousCachePointPlacements[conversationId]
-				: undefined
-
-		// Create config for cache strategy
-		const config = {
-			modelInfo: cacheModelInfo,
-			systemPrompt: systemMessage,
-			messages: anthropicMessages as Anthropic.Messages.MessageParam[],
-			usePromptCache,
-			previousCachePointPlacements: previousPlacements,
-		}
+	private isThrottlingError(error: unknown): boolean {
+		if (!(error instanceof Error)) return false
+		if ((error as any).status === 429 || (error as any).$metadata?.httpStatusCode === 429) return true
+		if ((error as any).name === "ThrottlingException") return true
+		const msg = error.message.toLowerCase()
+		return (
+			msg.includes("throttl") ||
+			msg.includes("rate limit") ||
+			msg.includes("too many requests") ||
+			msg.includes("bedrock is unable to process your request")
+		)
+	}

-		// Get cache point placements
-		let strategy = new MultiPointStrategy(config)
-		const cacheResult = strategy.determineOptimalCachePoints()
+	async completePrompt(prompt: string): Promise<string> {
+		const modelConfig = this.getModel()

-		// Store cache point placements for future use if conversation ID is provided
-		if (conversationId && cacheResult.messageCachePointPlacements) {
-			this.previousCachePointPlacements[conversationId] = cacheResult.messageCachePointPlacements
-		}
+		try {
+			const result = await generateText({
+				model: this.provider(modelConfig.id),
+				prompt,
+				temperature: modelConfig.temperature ?? (this.options.modelTemperature as number),
+				maxOutputTokens: modelConfig.maxTokens || (modelConfig.info.maxTokens as number),
+			})

-		// Apply cache points to the properly converted messages
-		const messagesWithCache = convertedMessages.map((msg, index) => {
-			const placement = cacheResult.messageCachePointPlacements?.find((p) => p.index === index)
-			if (placement) {
-				return {
-					...msg,
-					content: [...(msg.content || []), { cachePoint: { type: "default" } } as ContentBlock],
-				}
-			}
-			return msg
-		})
+			return result.text
+		} catch (error) {
+			const errorMessage = error instanceof Error ? error.message : String(error)
+			const apiError = new ApiProviderError(errorMessage, this.providerName, modelConfig.id, "completePrompt")
+			TelemetryService.instance.captureException(apiError)

-		return {
-			system: cacheResult.system,
-			messages: messagesWithCache,
+			// Handle AI SDK errors (AI_RetryError, AI_APICallError, etc.)
+ throw handleAiSdkError(error, this.providerName) } } /************************************************************************************ * - * MODEL IDENTIFICATION + * MODEL CONFIGURATION * *************************************************************************************/ private costModelConfig: { id: BedrockModelId | string; info: ModelInfo } = { id: "", - info: { maxTokens: 0, contextWindow: 0, supportsPromptCache: false, supportsImages: false }, + info: { maxTokens: 0, contextWindow: 0, supportsPromptCache: false }, } private parseArn(arn: string, region?: string) { - /* - * VIA Roo analysis: platform-independent Regex. It's designed to parse Amazon Bedrock ARNs and doesn't rely on any platform-specific features - * like file path separators, line endings, or case sensitivity behaviors. The forward slashes in the regex are properly escaped and - * represent literal characters in the AWS ARN format, not filesystem paths. This regex will function consistently across Windows, - * macOS, Linux, and any other operating system where JavaScript runs. - * - * Supports any AWS partition (aws, aws-us-gov, aws-cn, or future partitions). - * The partition is not captured since we don't need to use it. - * - * This matches ARNs like: - * - Foundation Model: arn:aws:bedrock:us-west-2::foundation-model/anthropic.claude-v2 - * - GovCloud Inference Profile: arn:aws-us-gov:bedrock:us-gov-west-1:123456789012:inference-profile/us-gov.anthropic.claude-sonnet-4-5-20250929-v1:0 - * - Prompt Router: arn:aws:bedrock:us-west-2:123456789012:prompt-router/anthropic-claude - * - Inference Profile: arn:aws:bedrock:us-west-2:123456789012:inference-profile/anthropic.claude-v2 - * - Cross Region Inference Profile: arn:aws:bedrock:us-west-2:123456789012:inference-profile/us.anthropic.claude-3-5-sonnet-20241022-v2:0 - * - Custom Model (Provisioned Throughput): arn:aws:bedrock:us-west-2:123456789012:provisioned-model/my-custom-model - * - Imported Model: arn:aws:bedrock:us-west-2:123456789012:imported-model/my-imported-model - * - * match[0] - The entire matched string - * match[1] - The region (e.g., "us-east-1", "us-gov-west-1") - * match[2] - The account ID (can be empty string for AWS-managed resources) - * match[3] - The resource type (e.g., "foundation-model") - * match[4] - The resource ID (e.g., "anthropic.claude-3-sonnet-20240229-v1:0") - */ - const arnRegex = /^arn:[^:]+:(?:bedrock|sagemaker):([^:]+):([^:]*):(?:([^\/]+)\/([\w\.\-:]+)|([^\/]+))$/ let match = arn.match(arnRegex) if (match && match[1] && match[3] && match[4]) { - // Create the result object const result: { isValid: boolean region?: string @@ -975,25 +535,21 @@ export class AwsBedrockHandler extends BaseProvider implements SingleCompletionH crossRegionInference: boolean } = { isValid: true, - crossRegionInference: false, // Default to false + crossRegionInference: false, } result.modelType = match[3] const originalModelId = match[4] result.modelId = this.parseBaseModelId(originalModelId) - // Extract the region from the first capture group const arnRegion = match[1] result.region = arnRegion - // Check if the original model ID had a region prefix if (originalModelId && result.modelId !== originalModelId) { - // If the model ID changed after parsing, it had a region prefix let prefix = originalModelId.replace(result.modelId, "") result.crossRegionInference = AwsBedrockHandler.isSystemInferenceProfile(prefix) } - // Check if region in ARN matches provided region (if specified) if (region && arnRegion !== region) { result.errorMessage = 
`Region mismatch: The region in your ARN (${arnRegion}) does not match your selected region (${region}). This may cause access issues. The provider will use the region from the ARN.` result.region = arnRegion @@ -1002,7 +558,6 @@ export class AwsBedrockHandler extends BaseProvider implements SingleCompletionH return result } - // If we get here, the regex didn't match return { isValid: false, region: undefined, @@ -1013,40 +568,27 @@ export class AwsBedrockHandler extends BaseProvider implements SingleCompletionH } } - //This strips any region prefix that used on cross-region model inference ARNs private parseBaseModelId(modelId: string): string { - if (!modelId) { - return modelId - } + if (!modelId) return modelId - // Remove AWS cross-region inference profile prefixes - // as defined in AWS_INFERENCE_PROFILE_MAPPING for (const [_, inferenceProfile] of AWS_INFERENCE_PROFILE_MAPPING) { if (modelId.startsWith(inferenceProfile)) { - // Remove the inference profile prefix from the model ID return modelId.substring(inferenceProfile.length) } } - // Also strip Global Inference profile prefix if present if (modelId.startsWith("global.")) { return modelId.substring("global.".length) } - // Return the model ID as-is for all other cases return modelId } - //Prompt Router responses come back in a different sequence and the model used is in the response and must be fetched by name getModelById(modelId: string, modelType?: string): { id: BedrockModelId | string; info: ModelInfo } { - // Try to find the model in bedrockModels const baseModelId = this.parseBaseModelId(modelId) as BedrockModelId let model if (baseModelId in bedrockModels) { - //Do a deep copy of the model info so that later in the code the model id and maxTokens can be set. - // The bedrockModels array is a constant and updating the model ID from the returned invokedModelID value - // in a prompt router response isn't possible on the constant. model = { id: baseModelId, info: JSON.parse(JSON.stringify(bedrockModels[baseModelId])) } } else if (modelType && modelType.includes("router")) { model = { @@ -1054,7 +596,6 @@ export class AwsBedrockHandler extends BaseProvider implements SingleCompletionH info: JSON.parse(JSON.stringify(bedrockModels[bedrockDefaultPromptRouterModelId])), } } else { - // Use heuristics for model info, then allow overrides from ProviderSettings const guessed = this.guessModelInfoFromId(modelId) model = { id: bedrockDefaultModelId, @@ -1065,7 +606,6 @@ export class AwsBedrockHandler extends BaseProvider implements SingleCompletionH } } - // Always allow user to override detected/guessed maxTokens and contextWindow if (this.options.modelMaxTokens && this.options.modelMaxTokens > 0) { model.info.maxTokens = this.options.modelMaxTokens } @@ -1085,7 +625,6 @@ export class AwsBedrockHandler extends BaseProvider implements SingleCompletionH reasoningBudget?: number } { if (this.costModelConfig?.id?.trim().length > 0) { - // Get model params for cost model config const params = getModelParams({ format: "anthropic", modelId: this.costModelConfig.id, @@ -1098,28 +637,19 @@ export class AwsBedrockHandler extends BaseProvider implements SingleCompletionH let modelConfig = undefined - // If custom ARN is provided, use it if (this.options.awsCustomArn) { modelConfig = this.getModelById(this.arnInfo.modelId, this.arnInfo.modelType) - - //If the user entered an ARN for a foundation-model they've done the same thing as picking from our list of options. 
- //We leave the model data matching the same as if a drop-down input method was used by not overwriting the model ID with the user input ARN - //Otherwise the ARN is not a foundation-model resource type that ARN should be used as the identifier in Bedrock interactions if (this.arnInfo.modelType !== "foundation-model") modelConfig.id = this.options.awsCustomArn } else { - //a model was selected from the drop down modelConfig = this.getModelById(this.options.apiModelId as string) - // Apply Global Inference prefix if enabled and supported (takes precedence over cross-region) const baseIdForGlobal = this.parseBaseModelId(modelConfig.id) if ( this.options.awsUseGlobalInference && BEDROCK_GLOBAL_INFERENCE_MODEL_IDS.includes(baseIdForGlobal as any) ) { modelConfig.id = `global.${baseIdForGlobal}` - } - // Otherwise, add cross-region inference prefix if enabled - else if (this.options.awsUseCrossRegionInference && this.options.awsRegion) { + } else if (this.options.awsUseCrossRegionInference && this.options.awsRegion) { const prefix = AwsBedrockHandler.getPrefixForRegion(this.options.awsRegion) if (prefix) { modelConfig.id = `${prefix}${modelConfig.id}` @@ -1127,11 +657,9 @@ export class AwsBedrockHandler extends BaseProvider implements SingleCompletionH } } - // Check if 1M context is enabled for supported Claude 4 models - // Use parseBaseModelId to handle cross-region inference prefixes + // Check if 1M context is enabled const baseModelId = this.parseBaseModelId(modelConfig.id) if (BEDROCK_1M_CONTEXT_MODEL_IDS.includes(baseModelId as any) && this.options.awsBedrock1MContext) { - // Update context window and pricing to 1M tier when 1M context beta is enabled const tier = modelConfig.info.tiers?.[0] modelConfig.info = { ...modelConfig.info, @@ -1143,7 +671,6 @@ export class AwsBedrockHandler extends BaseProvider implements SingleCompletionH } } - // Get model params including reasoning configuration const params = getModelParams({ format: "anthropic", modelId: modelConfig.id, @@ -1152,12 +679,11 @@ export class AwsBedrockHandler extends BaseProvider implements SingleCompletionH defaultTemperature: BEDROCK_DEFAULT_TEMPERATURE, }) - // Apply service tier pricing if specified and model supports it + // Apply service tier pricing const baseModelIdForTier = this.parseBaseModelId(modelConfig.id) if (this.options.awsBedrockServiceTier && BEDROCK_SERVICE_TIER_MODEL_IDS.includes(baseModelIdForTier as any)) { const pricingMultiplier = BEDROCK_SERVICE_TIER_PRICING[this.options.awsBedrockServiceTier] if (pricingMultiplier && pricingMultiplier !== 1.0) { - // Apply pricing multiplier to all price fields modelConfig.info = { ...modelConfig.info, inputPrice: modelConfig.info.inputPrice @@ -1176,7 +702,6 @@ export class AwsBedrockHandler extends BaseProvider implements SingleCompletionH } } - // Don't override maxTokens/contextWindow here; handled in getModelById (and includes user overrides) return { ...modelConfig, ...params } as { id: BedrockModelId | string info: ModelInfo @@ -1193,103 +718,82 @@ export class AwsBedrockHandler extends BaseProvider implements SingleCompletionH * *************************************************************************************/ - // Store previous cache point placements for maintaining consistency across consecutive messages - private previousCachePointPlacements: { [conversationId: string]: any[] } = {} - private supportsAwsPromptCache(modelConfig: { id: BedrockModelId | string; info: ModelInfo }): boolean | undefined { - // Check if the model supports prompt cache - 
// The cachableFields property is not part of the ModelInfo type in schemas
-		// but it's used in the bedrockModels object in shared/api.ts
		return (
			modelConfig?.info?.supportsPromptCache &&
-			// Use optional chaining and type assertion to access cachableFields
			(modelConfig?.info as any)?.cachableFields &&
			(modelConfig?.info as any)?.cachableFields?.length > 0
		)
	}

	/**
-	 * Removes any existing cachePoint nodes from content blocks
-	 */
-	private removeCachePoints(content: any): any {
-		if (Array.isArray(content)) {
-			return content.map((block) => {
-				// Use destructuring to remove cachePoint property
-				const { cachePoint: _, ...rest } = block
-				return rest
-			})
-		}
-
-		return content
-	}
-
-	/************************************************************************************
-	 *
-	 *     NATIVE TOOLS
+	 * Apply cachePoint providerOptions to the correct AI SDK messages by walking
+	 * the original Anthropic messages and converted AI SDK messages in parallel.
	 *
-	 *************************************************************************************/
-
-	/**
-	 * Convert OpenAI tool definitions to Bedrock Converse format
-	 * Transforms JSON Schema to draft 2020-12 compliant format required by Claude models.
-	 * @param tools Array of OpenAI ChatCompletionTool definitions
-	 * @returns Array of Bedrock Tool definitions
+	 * convertToAiSdkMessages() can split a single Anthropic user message (containing
+	 * tool_results + text) into 2 AI SDK messages (tool role + user role). This method
+	 * accounts for that split so cache points land on the right message.
	 */
-	private convertToolsForBedrock(tools: OpenAI.Chat.ChatCompletionTool[]): Tool[] {
-		return tools
-			.filter((tool) => tool.type === "function")
-			.map(
-				(tool) =>
-					({
-						toolSpec: {
-							name: tool.function.name,
-							description: tool.function.description,
-							inputSchema: {
-								// Normalize schema to JSON Schema draft 2020-12 compliant format
-								// This converts type: ["T", "null"] to anyOf: [{type: "T"}, {type: "null"}]
-								json: normalizeToolSchema(tool.function.parameters as Record<string, unknown>),
-							},
-						},
-					}) as Tool,
-			)
-	}
-
-	/**
-	 * Convert OpenAI tool_choice to Bedrock ToolChoice format
-	 * @param toolChoice OpenAI tool_choice parameter
-	 * @returns Bedrock ToolChoice configuration
-	 */
-	private convertToolChoiceForBedrock(
-		toolChoice: OpenAI.Chat.ChatCompletionCreateParams["tool_choice"],
-	): ToolChoice | undefined {
-		if (!toolChoice) {
-			// Default to auto - model decides whether to use tools
-			return { auto: {} } as ToolChoice
-		}
-
-		if (typeof toolChoice === "string") {
-			switch (toolChoice) {
-				case "none":
-					return undefined // Bedrock doesn't have "none", just omit tools
-				case "auto":
-					return { auto: {} } as ToolChoice
-				case "required":
-					return { any: {} } as ToolChoice // Model must use at least one tool
-				default:
-					return { auto: {} } as ToolChoice
+	private applyCachePointsToAiSdkMessages(
+		originalMessages: Anthropic.Messages.MessageParam[],
+		aiSdkMessages: { role: string; providerOptions?: Record<string, Record<string, unknown>> }[],
+		targetOriginalIndices: Set<number>,
+		cachePointOption: Record<string, Record<string, unknown>>,
+	): void {
+		let aiSdkIdx = 0
+		for (let origIdx = 0; origIdx < originalMessages.length; origIdx++) {
+			const origMsg = originalMessages[origIdx]
+
+			if (typeof origMsg.content === "string") {
+				// Simple string content → 1 AI SDK message
+				if (targetOriginalIndices.has(origIdx) && aiSdkIdx < aiSdkMessages.length) {
+					aiSdkMessages[aiSdkIdx].providerOptions = {
+						...aiSdkMessages[aiSdkIdx].providerOptions,
+						...cachePointOption,
+					}
+				}
+				aiSdkIdx++
+			} else if (origMsg.role ===
"user") { + // User message with array content may split into tool + user messages. + const hasToolResults = origMsg.content.some((part) => (part as { type: string }).type === "tool_result") + const hasNonToolContent = origMsg.content.some( + (part) => (part as { type: string }).type === "text" || (part as { type: string }).type === "image", + ) + + if (hasToolResults && hasNonToolContent) { + // Split into tool msg + user msg — cache the user msg (the second one) + const userMsgIdx = aiSdkIdx + 1 + if (targetOriginalIndices.has(origIdx) && userMsgIdx < aiSdkMessages.length) { + aiSdkMessages[userMsgIdx].providerOptions = { + ...aiSdkMessages[userMsgIdx].providerOptions, + ...cachePointOption, + } + } + aiSdkIdx += 2 + } else if (hasToolResults) { + // Only tool results → 1 tool msg + if (targetOriginalIndices.has(origIdx) && aiSdkIdx < aiSdkMessages.length) { + aiSdkMessages[aiSdkIdx].providerOptions = { + ...aiSdkMessages[aiSdkIdx].providerOptions, + ...cachePointOption, + } + } + aiSdkIdx++ + } else { + // Only text/image content → 1 user msg + if (targetOriginalIndices.has(origIdx) && aiSdkIdx < aiSdkMessages.length) { + aiSdkMessages[aiSdkIdx].providerOptions = { + ...aiSdkMessages[aiSdkIdx].providerOptions, + ...cachePointOption, + } + } + aiSdkIdx++ + } + } else { + // Assistant message → 1 AI SDK message + aiSdkIdx++ } } - - // Handle object form { type: "function", function: { name: string } } - if (typeof toolChoice === "object" && "function" in toolChoice) { - return { - tool: { - name: toolChoice.function.name, - }, - } as ToolChoice - } - - return { auto: {} } as ToolChoice } /************************************************************************************ @@ -1299,19 +803,15 @@ export class AwsBedrockHandler extends BaseProvider implements SingleCompletionH *************************************************************************************/ private static getPrefixForRegion(region: string): string | undefined { - // Use AWS recommended inference profile prefixes - // Array is pre-sorted by pattern length (descending) to ensure more specific patterns match first for (const [regionPattern, inferenceProfile] of AWS_INFERENCE_PROFILE_MAPPING) { if (region.startsWith(regionPattern)) { return inferenceProfile } } - return undefined } private static isSystemInferenceProfile(prefix: string): boolean { - // Check if the prefix is defined in AWS_INFERENCE_PROFILE_MAPPING for (const [_, inferenceProfile] of AWS_INFERENCE_PROFILE_MAPPING) { if (prefix === inferenceProfile) { return true @@ -1322,299 +822,51 @@ export class AwsBedrockHandler extends BaseProvider implements SingleCompletionH /************************************************************************************ * - * ERROR HANDLING + * COST CALCULATION * *************************************************************************************/ - /** - * Error type definitions for Bedrock API errors - */ - private static readonly ERROR_TYPES: Record< - string, - { - patterns: string[] // Strings to match in lowercase error message or name - messageTemplate: string // Template with placeholders like {region}, {modelId}, etc. - logLevel: "error" | "warn" | "info" // Log level for this error type - } - > = { - ACCESS_DENIED: { - patterns: ["access", "denied", "permission"], - messageTemplate: `You don't have access to the model specified. - -Please verify: -1. Try cross-region inference if you're using a foundation model -2. If using an ARN, verify the ARN is correct and points to a valid model -3. 
Your AWS credentials have permission to access this model (check IAM policies) -4. The region in the ARN matches the region where the model is deployed -5. If using a provisioned model, ensure it's active and not in a failed state`, - logLevel: "error", - }, - NOT_FOUND: { - patterns: ["not found", "does not exist"], - messageTemplate: `The specified ARN does not exist or is invalid. Please check: - -1. The ARN format is correct (arn:aws:bedrock:region:account-id:resource-type/resource-name) -2. The model exists in the specified region -3. The account ID in the ARN is correct`, - logLevel: "error", - }, - THROTTLING: { - patterns: [ - "throttl", - "rate", - "limit", - "bedrock is unable to process your request", // Amazon Bedrock specific throttling message - "please wait", - "quota exceeded", - "service unavailable", - "busy", - "overloaded", - "too many requests", - "request limit", - "concurrent requests", - ], - messageTemplate: `Request was throttled or rate limited. Please try: -1. Reducing the frequency of requests -2. If using a provisioned model, check its throughput settings -3. Contact AWS support to request a quota increase if needed - -`, - logLevel: "error", - }, - TOO_MANY_TOKENS: { - patterns: ["too many tokens", "token limit exceeded", "context length", "maximum context length"], - messageTemplate: `"Too many tokens" error detected. -Possible Causes: -1. Input exceeds model's context window limit -2. Rate limiting (too many tokens per minute) -3. Quota exceeded for token usage -4. Other token-related service limitations - -Suggestions: -1. Reduce the size of your input -2. Split your request into smaller chunks -3. Use a model with a larger context window -4. If rate limited, reduce request frequency -5. Check your Amazon Bedrock quotas and limits - -`, - logLevel: "error", - }, - SERVICE_QUOTA_EXCEEDED: { - patterns: ["service quota exceeded", "service quota", "quota exceeded for model"], - messageTemplate: `Service quota exceeded. This error indicates you've reached AWS service limits. - -Please try: -1. Contact AWS support to request a quota increase -2. Reduce request frequency temporarily -3. Check your Amazon Bedrock quotas in the AWS console -4. Consider using a different model or region with available capacity - -`, - logLevel: "error", - }, - MODEL_NOT_READY: { - patterns: ["model not ready", "model is not ready", "provisioned throughput not ready", "model loading"], - messageTemplate: `Model is not ready or still loading. This can happen with: -1. Provisioned throughput models that are still initializing -2. Custom models that are being loaded -3. Models that are temporarily unavailable - -Please try: -1. Wait a few minutes and retry -2. Check the model status in Amazon Bedrock console -3. Verify the model is properly provisioned - -`, - logLevel: "error", - }, - INTERNAL_SERVER_ERROR: { - patterns: ["internal server error", "internal error", "server error", "service error"], - messageTemplate: `Amazon Bedrock internal server error. This is a temporary service issue. - -Please try: -1. Retry the request after a brief delay -2. If the error persists, check AWS service health -3. Contact AWS support if the issue continues - -`, - logLevel: "error", - }, - ON_DEMAND_NOT_SUPPORTED: { - patterns: ["with on-demand throughput isn’t supported."], - messageTemplate: ` -1. Try enabling cross-region inference in settings. -2. Or, create an inference profile and then leverage the "Use custom ARN..." 
option of the model selector in settings.`,
-			logLevel: "error",
-		},
-		ABORT: {
-			patterns: ["aborterror"], // This will match error.name.toLowerCase() for AbortError
-			messageTemplate: `Request was aborted: The operation timed out or was manually cancelled. Please try again or check your network connection.`,
-			logLevel: "info",
-		},
-		INVALID_ARN_FORMAT: {
-			patterns: ["invalid_arn_format:", "invalid arn format"],
-			messageTemplate: `Invalid ARN format. ARN should follow the pattern: arn:aws:bedrock:region:account-id:resource-type/resource-name`,
-			logLevel: "error",
-		},
-		VALIDATION_ERROR: {
-			patterns: [
-				"input tag",
-				"does not match any of the expected tags",
-				"field required",
-				"validation",
-				"invalid parameter",
-			],
-			messageTemplate: `Parameter validation error: {errorMessage}
-
-This error indicates that the request parameters don't match Amazon Bedrock's expected format.
-
-Common causes:
-1. Extended thinking parameter format is incorrect
-2. Model-specific parameters are not supported by this model
-3. API parameter structure has changed
-
-Please check:
-- Model supports the requested features (extended thinking, etc.)
-- Parameter format matches Amazon Bedrock specification
-- Model ID is correct for the requested features`,
-			logLevel: "error",
-		},
-		// Default/generic error
-		GENERIC: {
-			patterns: [], // Empty patterns array means this is the default
-			messageTemplate: `Unknown Error: {errorMessage}`,
-			logLevel: "error",
-		},
-	}
-
-	/**
-	 * Determines the error type based on the error message or name
-	 */
-	private getErrorType(error: unknown): string {
-		if (!(error instanceof Error)) {
-			return "GENERIC"
-		}
-
-		// Check for HTTP 429 status code (Too Many Requests)
-		if ((error as any).status === 429 || (error as any).$metadata?.httpStatusCode === 429) {
-			return "THROTTLING"
-		}
-
-		// Check for Amazon Bedrock specific throttling exception names
-		if ((error as any).name === "ThrottlingException" || (error as any).__type === "ThrottlingException") {
-			return "THROTTLING"
-		}
-
-		const errorMessage = error.message.toLowerCase()
-		const errorName = error.name.toLowerCase()
-
-		// Check each error type's patterns in order of specificity (most specific first)
-		const errorTypeOrder = [
-			"SERVICE_QUOTA_EXCEEDED", // Most specific - check before THROTTLING
-			"MODEL_NOT_READY",
-			"TOO_MANY_TOKENS",
-			"INTERNAL_SERVER_ERROR",
-			"ON_DEMAND_NOT_SUPPORTED",
-			"NOT_FOUND",
-			"ACCESS_DENIED",
-			"THROTTLING", // Less specific - check after more specific patterns
-		]
-
-		for (const errorType of errorTypeOrder) {
-			const definition = AwsBedrockHandler.ERROR_TYPES[errorType]
-			if (!definition) continue
-
-			// If any pattern matches in either message or name, return this error type
-			if (definition.patterns.some((pattern) => errorMessage.includes(pattern) || errorName.includes(pattern))) {
-				return errorType
-			}
-		}
-
-		// Default to generic error
-		return "GENERIC"
-	}
-
-	/**
-	 * Formats an error message based on the error type and context
-	 */
-	private formatErrorMessage(error: unknown, errorType: string, _isStreamContext: boolean): string {
-		const definition = AwsBedrockHandler.ERROR_TYPES[errorType] || AwsBedrockHandler.ERROR_TYPES.GENERIC
-		let template = definition.messageTemplate
-
-		// Prepare template variables
-		const templateVars: Record<string, string> = {}
-
-		if (error instanceof Error) {
-			templateVars.errorMessage = error.message
-			templateVars.errorName = error.name
-
-			const modelConfig = this.getModel()
-			templateVars.modelId = modelConfig.id
templateVars.contextWindow = String(modelConfig.info.contextWindow || "unknown") - } + private calculateCost({ + inputTokens, + outputTokens, + cacheWriteTokens = 0, + cacheReadTokens = 0, + reasoningTokens = 0, + info, + }: { + inputTokens: number + outputTokens: number + cacheWriteTokens?: number + cacheReadTokens?: number + reasoningTokens?: number + info: ModelInfo + }): number { + const inputPrice = info.inputPrice ?? 0 + const outputPrice = info.outputPrice ?? 0 + const cacheWritesPrice = info.cacheWritesPrice ?? 0 + const cacheReadsPrice = info.cacheReadsPrice ?? 0 - // Add context-specific template variables - const region = - typeof this?.client?.config?.region === "function" - ? this?.client?.config?.region() - : this?.client?.config?.region - templateVars.regionInfo = `(${region})` + const uncachedInputTokens = Math.max(0, inputTokens - cacheWriteTokens - cacheReadTokens) + const billedOutputTokens = outputTokens + reasoningTokens - // Replace template variables - for (const [key, value] of Object.entries(templateVars)) { - template = template.replace(new RegExp(`{${key}}`, "g"), value || "") - } + const cacheWriteCost = cacheWriteTokens > 0 ? cacheWritesPrice * (cacheWriteTokens / 1_000_000) : 0 + const cacheReadCost = cacheReadTokens > 0 ? cacheReadsPrice * (cacheReadTokens / 1_000_000) : 0 + const inputTokensCost = inputPrice * (uncachedInputTokens / 1_000_000) + const outputTokensCost = outputPrice * (billedOutputTokens / 1_000_000) - return template + return inputTokensCost + outputTokensCost + cacheWriteCost + cacheReadCost } - /** - * Handles Bedrock API errors and generates appropriate error messages - * @param error The error that occurred - * @param isStreamContext Whether the error occurred in a streaming context (true) or not (false) - * @returns Error message string for non-streaming context or array of stream chunks for streaming context - */ - private handleBedrockError( - error: unknown, - isStreamContext: boolean, - ): string | Array<{ type: string; text?: string; inputTokens?: number; outputTokens?: number }> { - // Determine error type - const errorType = this.getErrorType(error) - - // Format error message - const errorMessage = this.formatErrorMessage(error, errorType, isStreamContext) - - // Log the error - const definition = AwsBedrockHandler.ERROR_TYPES[errorType] - const logMethod = definition.logLevel - const contextName = isStreamContext ? "createMessage" : "completePrompt" - logger[logMethod](`${errorType} error in ${contextName}`, { - ctx: "bedrock", - customArn: this.options.awsCustomArn, - errorType, - errorMessage: error instanceof Error ? error.message : String(error), - ...(error instanceof Error && error.stack ? { errorStack: error.stack } : {}), - ...(this.client?.config?.region ? { clientRegion: this.client.config.region } : {}), - }) - - // Return appropriate response based on isStreamContext - if (isStreamContext) { - return [ - { type: "text", text: `Error: ${errorMessage}` }, - { type: "usage", inputTokens: 0, outputTokens: 0 }, - ] - } else { - // For non-streaming context, add the expected prefix - return `Bedrock completion error: ${errorMessage}` - } - } + /************************************************************************************ + * + * THINKING SIGNATURE ROUND-TRIP + * + *************************************************************************************/ /** - * Returns the thinking signature captured from the last Bedrock Converse API response. 
- * Claude models with extended thinking return a cryptographic signature in the - * reasoning content delta, which must be round-tripped back for multi-turn - * conversations with tool use (Anthropic API requirement). + * Returns the thinking signature captured from the last Bedrock response. + * Claude models with extended thinking return a cryptographic signature + * which must be round-tripped back for multi-turn conversations with tool use. */ getThoughtSignature(): string | undefined { return this.lastThoughtSignature @@ -1622,11 +874,13 @@ Please check: /** * Returns any redacted thinking blocks captured from the last Bedrock response. - * Anthropic returns these when safety filters trigger on the model's internal - * reasoning. They contain opaque binary data (base64-encoded) that must be - * passed back verbatim for proper reasoning continuity. + * Anthropic returns these when safety filters trigger on reasoning content. */ getRedactedThinkingBlocks(): Array<{ type: "redacted_thinking"; data: string }> | undefined { return this.lastRedactedThinkingBlocks.length > 0 ? this.lastRedactedThinkingBlocks : undefined } + + override isAiSdkProvider(): boolean { + return true + } } diff --git a/src/api/transform/__tests__/ai-sdk.spec.ts b/src/api/transform/__tests__/ai-sdk.spec.ts index 3c1ca6d87e5..ea4b9a4235e 100644 --- a/src/api/transform/__tests__/ai-sdk.spec.ts +++ b/src/api/transform/__tests__/ai-sdk.spec.ts @@ -349,7 +349,14 @@ describe("AI SDK conversion utilities", () => { expect(result[0]).toEqual({ role: "assistant", content: [ - { type: "reasoning", text: "Deep thought" }, + { + type: "reasoning", + text: "Deep thought", + providerOptions: { + bedrock: { signature: "sig" }, + anthropic: { signature: "sig" }, + }, + }, { type: "text", text: "OK" }, ], }) diff --git a/src/api/transform/__tests__/bedrock-converse-format.spec.ts b/src/api/transform/__tests__/bedrock-converse-format.spec.ts deleted file mode 100644 index f8c3c9f0162..00000000000 --- a/src/api/transform/__tests__/bedrock-converse-format.spec.ts +++ /dev/null @@ -1,694 +0,0 @@ -// npx vitest run src/api/transform/__tests__/bedrock-converse-format.spec.ts - -import { convertToBedrockConverseMessages } from "../bedrock-converse-format" -import { Anthropic } from "@anthropic-ai/sdk" -import { ContentBlock, ToolResultContentBlock } from "@aws-sdk/client-bedrock-runtime" -import { OPENAI_CALL_ID_MAX_LENGTH } from "../../../utils/tool-id" - -describe("convertToBedrockConverseMessages", () => { - it("converts simple text messages correctly", () => { - const messages: Anthropic.Messages.MessageParam[] = [ - { role: "user", content: "Hello" }, - { role: "assistant", content: "Hi there" }, - ] - - const result = convertToBedrockConverseMessages(messages) - - expect(result).toEqual([ - { - role: "user", - content: [{ text: "Hello" }], - }, - { - role: "assistant", - content: [{ text: "Hi there" }], - }, - ]) - }) - - it("converts messages with images correctly", () => { - const messages: Anthropic.Messages.MessageParam[] = [ - { - role: "user", - content: [ - { - type: "text", - text: "Look at this image:", - }, - { - type: "image", - source: { - type: "base64", - data: "SGVsbG8=", // "Hello" in base64 - media_type: "image/jpeg" as const, - }, - }, - ], - }, - ] - - const result = convertToBedrockConverseMessages(messages) - - if (!result[0] || !result[0].content) { - expect.fail("Expected result to have content") - return - } - - expect(result[0].role).toBe("user") - expect(result[0].content).toHaveLength(2) - 
expect(result[0].content[0]).toEqual({ text: "Look at this image:" }) - - const imageBlock = result[0].content[1] as ContentBlock - if ("image" in imageBlock && imageBlock.image && imageBlock.image.source) { - expect(imageBlock.image.format).toBe("jpeg") - expect(imageBlock.image.source).toBeDefined() - expect(imageBlock.image.source.bytes).toBeDefined() - } else { - expect.fail("Expected image block not found") - } - }) - - it("converts tool use messages correctly (native tools format; default)", () => { - const messages: Anthropic.Messages.MessageParam[] = [ - { - role: "assistant", - content: [ - { - type: "tool_use", - id: "test-id", - name: "read_file", - input: { - path: "test.txt", - }, - }, - ], - }, - ] - - const result = convertToBedrockConverseMessages(messages) - - if (!result[0] || !result[0].content) { - expect.fail("Expected result to have content") - return - } - - expect(result[0].role).toBe("assistant") - const toolBlock = result[0].content[0] as ContentBlock - if ("toolUse" in toolBlock && toolBlock.toolUse) { - expect(toolBlock.toolUse).toEqual({ - toolUseId: "test-id", - name: "read_file", - input: { path: "test.txt" }, - }) - } else { - expect.fail("Expected tool use block not found") - } - }) - - it("converts tool use messages correctly (native tools format)", () => { - const messages: Anthropic.Messages.MessageParam[] = [ - { - role: "assistant", - content: [ - { - type: "tool_use", - id: "test-id", - name: "read_file", - input: { - path: "test.txt", - }, - }, - ], - }, - ] - - const result = convertToBedrockConverseMessages(messages) - - if (!result[0] || !result[0].content) { - expect.fail("Expected result to have content") - return - } - - expect(result[0].role).toBe("assistant") - const toolBlock = result[0].content[0] as ContentBlock - if ("toolUse" in toolBlock && toolBlock.toolUse) { - expect(toolBlock.toolUse).toEqual({ - toolUseId: "test-id", - name: "read_file", - input: { path: "test.txt" }, - }) - } else { - expect.fail("Expected tool use block not found") - } - }) - - it("converts tool result messages to native format (default)", () => { - const messages: Anthropic.Messages.MessageParam[] = [ - { - role: "user", - content: [ - { - type: "tool_result", - tool_use_id: "test-id", - content: [{ type: "text", text: "File contents here" }], - }, - ], - }, - ] - - const result = convertToBedrockConverseMessages(messages) - - if (!result[0] || !result[0].content) { - expect.fail("Expected result to have content") - return - } - - expect(result[0].role).toBe("user") - const resultBlock = result[0].content[0] as ContentBlock - if ("toolResult" in resultBlock && resultBlock.toolResult) { - const expectedContent: ToolResultContentBlock[] = [{ text: "File contents here" }] - expect(resultBlock.toolResult).toEqual({ - toolUseId: "test-id", - content: expectedContent, - status: "success", - }) - } else { - expect.fail("Expected tool result block not found") - } - }) - - it("converts tool result messages to native format", () => { - const messages: Anthropic.Messages.MessageParam[] = [ - { - role: "user", - content: [ - { - type: "tool_result", - tool_use_id: "test-id", - content: [{ type: "text", text: "File contents here" }], - }, - ], - }, - ] - - const result = convertToBedrockConverseMessages(messages) - - if (!result[0] || !result[0].content) { - expect.fail("Expected result to have content") - return - } - - expect(result[0].role).toBe("user") - const resultBlock = result[0].content[0] as ContentBlock - if ("toolResult" in resultBlock && resultBlock.toolResult) 
{ - const expectedContent: ToolResultContentBlock[] = [{ text: "File contents here" }] - expect(resultBlock.toolResult).toEqual({ - toolUseId: "test-id", - content: expectedContent, - status: "success", - }) - } else { - expect.fail("Expected tool result block not found") - } - }) - - it("converts tool result messages with string content to native format (default)", () => { - const messages: Anthropic.Messages.MessageParam[] = [ - { - role: "user", - content: [ - { - type: "tool_result", - tool_use_id: "test-id", - content: "File: test.txt\nLines 1-5:\nHello World", - } as any, // Anthropic types don't allow string content but runtime can have it - ], - }, - ] - - const result = convertToBedrockConverseMessages(messages) - - if (!result[0] || !result[0].content) { - expect.fail("Expected result to have content") - return - } - - expect(result[0].role).toBe("user") - const resultBlock = result[0].content[0] as ContentBlock - if ("toolResult" in resultBlock && resultBlock.toolResult) { - expect(resultBlock.toolResult).toEqual({ - toolUseId: "test-id", - content: [{ text: "File: test.txt\nLines 1-5:\nHello World" }], - status: "success", - }) - } else { - expect.fail("Expected tool result block not found") - } - }) - - it("converts tool result messages with string content to native format", () => { - const messages: Anthropic.Messages.MessageParam[] = [ - { - role: "user", - content: [ - { - type: "tool_result", - tool_use_id: "test-id", - content: "File: test.txt\nLines 1-5:\nHello World", - } as any, // Anthropic types don't allow string content but runtime can have it - ], - }, - ] - - const result = convertToBedrockConverseMessages(messages) - - if (!result[0] || !result[0].content) { - expect.fail("Expected result to have content") - return - } - - expect(result[0].role).toBe("user") - const resultBlock = result[0].content[0] as ContentBlock - if ("toolResult" in resultBlock && resultBlock.toolResult) { - expect(resultBlock.toolResult).toEqual({ - toolUseId: "test-id", - content: [{ text: "File: test.txt\nLines 1-5:\nHello World" }], - status: "success", - }) - } else { - expect.fail("Expected tool result block not found") - } - }) - - it("keeps both tool_use and tool_result in native format by default", () => { - const messages: Anthropic.Messages.MessageParam[] = [ - { - role: "assistant", - content: [ - { - type: "tool_use", - id: "call-123", - name: "read_file", - input: { path: "test.txt" }, - }, - ], - }, - { - role: "user", - content: [ - { - type: "tool_result", - tool_use_id: "call-123", - content: "File contents here", - } as any, - ], - }, - ] - - const result = convertToBedrockConverseMessages(messages) - - // Both should be native toolUse/toolResult blocks - const assistantContent = result[0]?.content?.[0] as ContentBlock - const userContent = result[1]?.content?.[0] as ContentBlock - - expect("toolUse" in assistantContent).toBe(true) - expect("toolResult" in userContent).toBe(true) - expect("text" in assistantContent).toBe(false) - expect("text" in userContent).toBe(false) - }) - - it("handles text content correctly", () => { - const messages: Anthropic.Messages.MessageParam[] = [ - { - role: "user", - content: [ - { - type: "text", - text: "Hello world", - }, - ], - }, - ] - - const result = convertToBedrockConverseMessages(messages) - - if (!result[0] || !result[0].content) { - expect.fail("Expected result to have content") - return - } - - expect(result[0].role).toBe("user") - expect(result[0].content).toHaveLength(1) - const textBlock = result[0].content[0] as 
ContentBlock - expect(textBlock).toEqual({ text: "Hello world" }) - }) - - describe("toolUseId sanitization for Bedrock 64-char limit", () => { - it("truncates toolUseId longer than 64 characters in tool_use blocks", () => { - const longId = "a".repeat(100) - const messages: Anthropic.Messages.MessageParam[] = [ - { - role: "assistant", - content: [ - { - type: "tool_use", - id: longId, - name: "read_file", - input: { path: "test.txt" }, - }, - ], - }, - ] - - const result = convertToBedrockConverseMessages(messages) - const toolBlock = result[0]?.content?.[0] as ContentBlock - - if ("toolUse" in toolBlock && toolBlock.toolUse && toolBlock.toolUse.toolUseId) { - expect(toolBlock.toolUse.toolUseId.length).toBeLessThanOrEqual(OPENAI_CALL_ID_MAX_LENGTH) - expect(toolBlock.toolUse.toolUseId.length).toBe(OPENAI_CALL_ID_MAX_LENGTH) - expect(toolBlock.toolUse.toolUseId).toContain("_") - } else { - expect.fail("Expected tool use block not found") - } - }) - - it("truncates toolUseId longer than 64 characters in tool_result blocks with string content", () => { - const longId = "b".repeat(100) - const messages: Anthropic.Messages.MessageParam[] = [ - { - role: "user", - content: [ - { - type: "tool_result", - tool_use_id: longId, - content: "Result content", - } as any, - ], - }, - ] - - const result = convertToBedrockConverseMessages(messages) - const resultBlock = result[0]?.content?.[0] as ContentBlock - - if ("toolResult" in resultBlock && resultBlock.toolResult && resultBlock.toolResult.toolUseId) { - expect(resultBlock.toolResult.toolUseId.length).toBeLessThanOrEqual(OPENAI_CALL_ID_MAX_LENGTH) - expect(resultBlock.toolResult.toolUseId.length).toBe(OPENAI_CALL_ID_MAX_LENGTH) - expect(resultBlock.toolResult.toolUseId).toContain("_") - } else { - expect.fail("Expected tool result block not found") - } - }) - - it("truncates toolUseId longer than 64 characters in tool_result blocks with array content", () => { - const longId = "c".repeat(100) - const messages: Anthropic.Messages.MessageParam[] = [ - { - role: "user", - content: [ - { - type: "tool_result", - tool_use_id: longId, - content: [{ type: "text", text: "Result content" }], - }, - ], - }, - ] - - const result = convertToBedrockConverseMessages(messages) - const resultBlock = result[0]?.content?.[0] as ContentBlock - - if ("toolResult" in resultBlock && resultBlock.toolResult && resultBlock.toolResult.toolUseId) { - expect(resultBlock.toolResult.toolUseId.length).toBeLessThanOrEqual(OPENAI_CALL_ID_MAX_LENGTH) - expect(resultBlock.toolResult.toolUseId.length).toBe(OPENAI_CALL_ID_MAX_LENGTH) - } else { - expect.fail("Expected tool result block not found") - } - }) - - it("keeps toolUseId unchanged when under 64 characters", () => { - const shortId = "short-id-123" - const messages: Anthropic.Messages.MessageParam[] = [ - { - role: "assistant", - content: [ - { - type: "tool_use", - id: shortId, - name: "read_file", - input: { path: "test.txt" }, - }, - ], - }, - ] - - const result = convertToBedrockConverseMessages(messages) - const toolBlock = result[0]?.content?.[0] as ContentBlock - - if ("toolUse" in toolBlock && toolBlock.toolUse) { - expect(toolBlock.toolUse.toolUseId).toBe(shortId) - } else { - expect.fail("Expected tool use block not found") - } - }) - - it("produces consistent truncated IDs for the same input", () => { - const longId = "d".repeat(100) - const messages: Anthropic.Messages.MessageParam[] = [ - { - role: "assistant", - content: [ - { - type: "tool_use", - id: longId, - name: "read_file", - input: { path: "test.txt" }, 
- }, - ], - }, - ] - - const result1 = convertToBedrockConverseMessages(messages) - const result2 = convertToBedrockConverseMessages(messages) - - const toolBlock1 = result1[0]?.content?.[0] as ContentBlock - const toolBlock2 = result2[0]?.content?.[0] as ContentBlock - - if ("toolUse" in toolBlock1 && toolBlock1.toolUse && "toolUse" in toolBlock2 && toolBlock2.toolUse) { - expect(toolBlock1.toolUse.toolUseId).toBe(toolBlock2.toolUse.toolUseId) - } else { - expect.fail("Expected tool use blocks not found") - } - }) - - it("produces different truncated IDs for different long inputs", () => { - const longId1 = "e".repeat(100) - const longId2 = "f".repeat(100) - - const messages1: Anthropic.Messages.MessageParam[] = [ - { - role: "assistant", - content: [{ type: "tool_use", id: longId1, name: "read_file", input: {} }], - }, - ] - const messages2: Anthropic.Messages.MessageParam[] = [ - { - role: "assistant", - content: [{ type: "tool_use", id: longId2, name: "read_file", input: {} }], - }, - ] - - const result1 = convertToBedrockConverseMessages(messages1) - const result2 = convertToBedrockConverseMessages(messages2) - - const toolBlock1 = result1[0]?.content?.[0] as ContentBlock - const toolBlock2 = result2[0]?.content?.[0] as ContentBlock - - if ("toolUse" in toolBlock1 && toolBlock1.toolUse && "toolUse" in toolBlock2 && toolBlock2.toolUse) { - expect(toolBlock1.toolUse.toolUseId).not.toBe(toolBlock2.toolUse.toolUseId) - } else { - expect.fail("Expected tool use blocks not found") - } - }) - - it("matching tool_use and tool_result IDs are both truncated consistently", () => { - const longId = "g".repeat(100) - const messages: Anthropic.Messages.MessageParam[] = [ - { - role: "assistant", - content: [ - { - type: "tool_use", - id: longId, - name: "read_file", - input: { path: "test.txt" }, - }, - ], - }, - { - role: "user", - content: [ - { - type: "tool_result", - tool_use_id: longId, - content: "File contents", - } as any, - ], - }, - ] - - const result = convertToBedrockConverseMessages(messages) - - const toolUseBlock = result[0]?.content?.[0] as ContentBlock - const toolResultBlock = result[1]?.content?.[0] as ContentBlock - - if ( - "toolUse" in toolUseBlock && - toolUseBlock.toolUse && - toolUseBlock.toolUse.toolUseId && - "toolResult" in toolResultBlock && - toolResultBlock.toolResult && - toolResultBlock.toolResult.toolUseId - ) { - expect(toolUseBlock.toolUse.toolUseId).toBe(toolResultBlock.toolResult.toolUseId) - expect(toolUseBlock.toolUse.toolUseId.length).toBeLessThanOrEqual(OPENAI_CALL_ID_MAX_LENGTH) - } else { - expect.fail("Expected tool use and result blocks not found") - } - }) - }) - - describe("thinking and reasoning block handling", () => { - it("should convert thinking blocks to reasoningContent format", () => { - const messages: Anthropic.Messages.MessageParam[] = [ - { - role: "assistant", - content: [ - { type: "thinking", thinking: "Let me think about this...", signature: "sig-abc123" } as any, - { type: "text", text: "Here is my answer." 
}, - ], - }, - ] - - const result = convertToBedrockConverseMessages(messages) - - expect(result).toHaveLength(1) - expect(result[0].role).toBe("assistant") - expect(result[0].content).toHaveLength(2) - - const reasoningBlock = result[0].content![0] as any - expect(reasoningBlock.reasoningContent).toBeDefined() - expect(reasoningBlock.reasoningContent.reasoningText.text).toBe("Let me think about this...") - expect(reasoningBlock.reasoningContent.reasoningText.signature).toBe("sig-abc123") - - const textBlock = result[0].content![1] as any - expect(textBlock.text).toBe("Here is my answer.") - }) - - it("should convert redacted_thinking blocks with data to reasoningContent.redactedContent", () => { - const testData = Buffer.from("encrypted-redacted-content").toString("base64") - const messages: Anthropic.Messages.MessageParam[] = [ - { - role: "assistant", - content: [{ type: "redacted_thinking", data: testData } as any, { type: "text", text: "Response" }], - }, - ] - - const result = convertToBedrockConverseMessages(messages) - - expect(result).toHaveLength(1) - expect(result[0].content).toHaveLength(2) - - const redactedBlock = result[0].content![0] as any - expect(redactedBlock.reasoningContent).toBeDefined() - expect(redactedBlock.reasoningContent.redactedContent).toBeInstanceOf(Uint8Array) - // Verify round-trip: decode back and compare - const decoded = Buffer.from(redactedBlock.reasoningContent.redactedContent).toString("utf-8") - expect(decoded).toBe("encrypted-redacted-content") - }) - - it("should skip redacted_thinking blocks without data", () => { - const messages: Anthropic.Messages.MessageParam[] = [ - { - role: "assistant", - content: [{ type: "redacted_thinking" } as any, { type: "text", text: "Response" }], - }, - ] - - const result = convertToBedrockConverseMessages(messages) - - expect(result).toHaveLength(1) - // Only the text block should remain (redacted_thinking without data is filtered out) - expect(result[0].content).toHaveLength(1) - expect((result[0].content![0] as any).text).toBe("Response") - }) - - it("should skip reasoning blocks (internal Roo Code format)", () => { - const messages: Anthropic.Messages.MessageParam[] = [ - { - role: "assistant", - content: [ - { type: "reasoning", text: "Internal reasoning" } as any, - { type: "text", text: "Response" }, - ], - }, - ] - - const result = convertToBedrockConverseMessages(messages) - - expect(result).toHaveLength(1) - expect(result[0].content).toHaveLength(1) - expect((result[0].content![0] as any).text).toBe("Response") - }) - - it("should skip thoughtSignature blocks (Gemini format)", () => { - const messages: Anthropic.Messages.MessageParam[] = [ - { - role: "assistant", - content: [ - { type: "text", text: "Response" }, - { type: "thoughtSignature", thoughtSignature: "gemini-sig" } as any, - ], - }, - ] - - const result = convertToBedrockConverseMessages(messages) - - expect(result).toHaveLength(1) - expect(result[0].content).toHaveLength(1) - expect((result[0].content![0] as any).text).toBe("Response") - }) - - it("should handle full thinking + redacted_thinking + text + tool_use message", () => { - const redactedData = Buffer.from("redacted-binary").toString("base64") - const messages: Anthropic.Messages.MessageParam[] = [ - { - role: "assistant", - content: [ - { type: "thinking", thinking: "Deep thought", signature: "sig-xyz" } as any, - { type: "redacted_thinking", data: redactedData } as any, - { type: "text", text: "I'll use a tool." 
}, - { type: "tool_use", id: "tool-1", name: "read_file", input: { path: "test.txt" } }, - ], - }, - ] - - const result = convertToBedrockConverseMessages(messages) - - expect(result).toHaveLength(1) - expect(result[0].content).toHaveLength(4) - - // thinking → reasoningContent.reasoningText - expect((result[0].content![0] as any).reasoningContent.reasoningText.text).toBe("Deep thought") - expect((result[0].content![0] as any).reasoningContent.reasoningText.signature).toBe("sig-xyz") - - // redacted_thinking → reasoningContent.redactedContent - expect((result[0].content![1] as any).reasoningContent.redactedContent).toBeInstanceOf(Uint8Array) - - // text - expect((result[0].content![2] as any).text).toBe("I'll use a tool.") - - // tool_use → toolUse - expect((result[0].content![3] as any).toolUse.name).toBe("read_file") - }) - }) -}) diff --git a/src/api/transform/ai-sdk.ts b/src/api/transform/ai-sdk.ts index 9b48ee57f79..c673fad3d27 100644 --- a/src/api/transform/ai-sdk.ts +++ b/src/api/transform/ai-sdk.ts @@ -139,6 +139,11 @@ export function convertToAiSdkMessages( providerOptions?: Record> }> = [] + // Capture thinking signature for Anthropic-protocol providers (Bedrock, Anthropic). + // Task.ts stores thinking blocks as { type: "thinking", thinking: "...", signature: "..." }. + // The signature must be passed back via providerOptions on reasoning parts. + let thinkingSignature: string | undefined + // Extract thoughtSignature from content blocks (Gemini 3 thought signature round-tripping). // Task.ts stores these as { type: "thoughtSignature", thoughtSignature: "..." } blocks. let thoughtSignature: string | undefined @@ -196,16 +201,20 @@ export function convertToAiSdkMessages( if ((part as unknown as { type?: string }).type === "thinking") { if (reasoningContent) continue - const thinking = (part as unknown as { thinking?: string }).thinking - if (typeof thinking === "string" && thinking.length > 0) { - reasoningParts.push(thinking) + const thinkingPart = part as unknown as { thinking?: string; signature?: string } + if (typeof thinkingPart.thinking === "string" && thinkingPart.thinking.length > 0) { + reasoningParts.push(thinkingPart.thinking) + } + // Capture the signature for round-tripping (Anthropic/Bedrock thinking) + if (thinkingPart.signature) { + thinkingSignature = thinkingPart.signature } continue } } const content: Array< - | { type: "reasoning"; text: string } + | { type: "reasoning"; text: string; providerOptions?: Record> } | { type: "text"; text: string } | { type: "tool-call" @@ -219,7 +228,20 @@ export function convertToAiSdkMessages( if (reasoningContent) { content.push({ type: "reasoning", text: reasoningContent }) } else if (reasoningParts.length > 0) { - content.push({ type: "reasoning", text: reasoningParts.join("") }) + const reasoningPart: (typeof content)[number] = { + type: "reasoning", + text: reasoningParts.join(""), + } + // Attach thinking signature for Anthropic/Bedrock round-tripping. + // The AI SDK's @ai-sdk/amazon-bedrock reads providerOptions.bedrock.signature + // and attaches it to reasoningContent.reasoningText.signature in the Bedrock request. 
+ if (thinkingSignature) { + reasoningPart.providerOptions = { + bedrock: { signature: thinkingSignature }, + anthropic: { signature: thinkingSignature }, + } + } + content.push(reasoningPart) } if (textParts.length > 0) { diff --git a/src/api/transform/bedrock-converse-format.ts b/src/api/transform/bedrock-converse-format.ts deleted file mode 100644 index 1a77513f439..00000000000 --- a/src/api/transform/bedrock-converse-format.ts +++ /dev/null @@ -1,249 +0,0 @@ -import { Anthropic } from "@anthropic-ai/sdk" -import { ConversationRole, Message, ContentBlock } from "@aws-sdk/client-bedrock-runtime" -import { sanitizeOpenAiCallId } from "../../utils/tool-id" - -interface BedrockMessageContent { - type: "text" | "image" | "video" | "tool_use" | "tool_result" - text?: string - source?: { - type: "base64" - data: string | Uint8Array // string for Anthropic, Uint8Array for Bedrock - media_type: "image/jpeg" | "image/png" | "image/gif" | "image/webp" - } - // Video specific fields - format?: string - s3Location?: { - uri: string - bucketOwner?: string - } - // Tool use and result fields - toolUseId?: string - name?: string - input?: any - output?: any // Used for tool_result type -} - -/** - * Convert Anthropic messages to Bedrock Converse format - * @param anthropicMessages Messages in Anthropic format - */ -export function convertToBedrockConverseMessages(anthropicMessages: Anthropic.Messages.MessageParam[]): Message[] { - return anthropicMessages.map((anthropicMessage) => { - // Map Anthropic roles to Bedrock roles - const role: ConversationRole = anthropicMessage.role === "assistant" ? "assistant" : "user" - - if (typeof anthropicMessage.content === "string") { - return { - role, - content: [ - { - text: anthropicMessage.content, - }, - ] as ContentBlock[], - } - } - - // Process complex content types - const content = anthropicMessage.content.map((block) => { - const messageBlock = block as BedrockMessageContent & { - id?: string - tool_use_id?: string - content?: string | Array<{ type: string; text: string }> - output?: string | Array<{ type: string; text: string }> - } - - if (messageBlock.type === "text") { - return { - text: messageBlock.text || "", - } as ContentBlock - } - - if (messageBlock.type === "image" && messageBlock.source) { - // Convert base64 string to byte array if needed - let byteArray: Uint8Array - if (typeof messageBlock.source.data === "string") { - const binaryString = atob(messageBlock.source.data) - byteArray = new Uint8Array(binaryString.length) - for (let i = 0; i < binaryString.length; i++) { - byteArray[i] = binaryString.charCodeAt(i) - } - } else { - byteArray = messageBlock.source.data - } - - // Extract format from media_type (e.g., "image/jpeg" -> "jpeg") - const format = messageBlock.source.media_type.split("/")[1] - if (!["png", "jpeg", "gif", "webp"].includes(format)) { - throw new Error(`Unsupported image format: ${format}`) - } - - return { - image: { - format: format as "png" | "jpeg" | "gif" | "webp", - source: { - bytes: byteArray, - }, - }, - } as ContentBlock - } - - if (messageBlock.type === "tool_use") { - // Native-only: keep input as JSON object for Bedrock's toolUse format - return { - toolUse: { - toolUseId: sanitizeOpenAiCallId(messageBlock.id || ""), - name: messageBlock.name || "", - input: messageBlock.input || {}, - }, - } as ContentBlock - } - - if (messageBlock.type === "tool_result") { - // Handle content field - can be string or array (native tool format) - if (messageBlock.content) { - // Content is a string - if (typeof 
messageBlock.content === "string") { - return { - toolResult: { - toolUseId: sanitizeOpenAiCallId(messageBlock.tool_use_id || ""), - content: [ - { - text: messageBlock.content, - }, - ], - status: "success", - }, - } as ContentBlock - } - // Content is an array of content blocks - if (Array.isArray(messageBlock.content)) { - return { - toolResult: { - toolUseId: sanitizeOpenAiCallId(messageBlock.tool_use_id || ""), - content: messageBlock.content.map((item) => ({ - text: typeof item === "string" ? item : item.text || String(item), - })), - status: "success", - }, - } as ContentBlock - } - } - - // Fall back to output handling if content is not available - if (messageBlock.output && typeof messageBlock.output === "string") { - return { - toolResult: { - toolUseId: sanitizeOpenAiCallId(messageBlock.tool_use_id || ""), - content: [ - { - text: messageBlock.output, - }, - ], - status: "success", - }, - } as ContentBlock - } - // Handle array of content blocks if output is an array - if (Array.isArray(messageBlock.output)) { - return { - toolResult: { - toolUseId: sanitizeOpenAiCallId(messageBlock.tool_use_id || ""), - content: messageBlock.output.map((part) => { - if (typeof part === "object" && "text" in part) { - return { text: part.text } - } - // Skip images in tool results as they're handled separately - if (typeof part === "object" && "type" in part && part.type === "image") { - return { text: "(see following message for image)" } - } - return { text: String(part) } - }), - status: "success", - }, - } as ContentBlock - } - - // Default case - return { - toolResult: { - toolUseId: sanitizeOpenAiCallId(messageBlock.tool_use_id || ""), - content: [ - { - text: String(messageBlock.output || ""), - }, - ], - status: "success", - }, - } as ContentBlock - } - - if (messageBlock.type === "video") { - const videoContent = messageBlock.s3Location - ? { - s3Location: { - uri: messageBlock.s3Location.uri, - bucketOwner: messageBlock.s3Location.bucketOwner, - }, - } - : messageBlock.source - - return { - video: { - format: "mp4", // Default to mp4, adjust based on actual format if needed - source: videoContent, - }, - } as ContentBlock - } - - // Handle Anthropic thinking blocks (stored by Task.ts for extended thinking) - // Convert to Bedrock Converse API's reasoningContent format - const blockAny = block as { type: string; thinking?: string; signature?: string } - if (blockAny.type === "thinking" && blockAny.thinking) { - return { - reasoningContent: { - reasoningText: { - text: blockAny.thinking, - signature: blockAny.signature, - }, - }, - } as ContentBlock - } - - // Handle redacted thinking blocks (Anthropic sends these when content is filtered). - // Convert base64-encoded data back to Uint8Array for Bedrock Converse API's - // reasoningContent.redactedContent format. 
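The conversion this comment describes is a plain base64 round-trip; a minimal standalone sketch, assuming Node's Buffer:

```ts
// Storage side: redacted thinking arrives as bytes and is persisted as base64.
const rawBytes = new Uint8Array([0x01, 0x02, 0x03]) // placeholder payload
const stored = Buffer.from(rawBytes).toString("base64")

// Replay side: what this converter did before handing the block to Bedrock.
const redactedContent = new Uint8Array(Buffer.from(stored, "base64"))
console.assert(redactedContent.length === rawBytes.length)
```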
- if (blockAny.type === "redacted_thinking" && (blockAny as unknown as { data?: string }).data) { - const base64Data = (blockAny as unknown as { data: string }).data - const binaryData = Buffer.from(base64Data, "base64") - return { - reasoningContent: { - redactedContent: new Uint8Array(binaryData), - }, - } as ContentBlock - } - - // Skip redacted_thinking blocks without data (shouldn't happen, but be safe) - if (blockAny.type === "redacted_thinking") { - return undefined as unknown as ContentBlock - } - - // Skip reasoning blocks (internal Roo Code format, not for the API) - if (blockAny.type === "reasoning" || blockAny.type === "thoughtSignature") { - return undefined as unknown as ContentBlock - } - - // Default case for unknown block types - return { - text: "[Unknown Block Type]", - } as ContentBlock - }) - - // Filter out undefined entries (from skipped block types like redacted_thinking, reasoning) - const filteredContent = content.filter((block): block is ContentBlock => block != null) - - return { - role, - content: filteredContent, - } - }) -} diff --git a/src/api/transform/cache-strategy/__tests__/cache-strategy.spec.ts b/src/api/transform/cache-strategy/__tests__/cache-strategy.spec.ts deleted file mode 100644 index 1e702d88a0b..00000000000 --- a/src/api/transform/cache-strategy/__tests__/cache-strategy.spec.ts +++ /dev/null @@ -1,1112 +0,0 @@ -import { ContentBlock, SystemContentBlock, BedrockRuntimeClient } from "@aws-sdk/client-bedrock-runtime" -import { Anthropic } from "@anthropic-ai/sdk" - -import { MultiPointStrategy } from "../multi-point-strategy" -import { CacheStrategyConfig, ModelInfo, CachePointPlacement } from "../types" -import { AwsBedrockHandler } from "../../../providers/bedrock" - -// Common test utilities -const defaultModelInfo: ModelInfo = { - maxTokens: 8192, - contextWindow: 200_000, - supportsPromptCache: true, - maxCachePoints: 4, - minTokensPerCachePoint: 50, - cachableFields: ["system", "messages", "tools"], -} - -const createConfig = (overrides: Partial = {}): CacheStrategyConfig => ({ - modelInfo: { - ...defaultModelInfo, - ...(overrides.modelInfo || {}), - }, - systemPrompt: "You are a helpful assistant", - messages: [], - usePromptCache: true, - ...overrides, -}) - -const createMessageWithTokens = (role: "user" | "assistant", tokenCount: number) => ({ - role, - content: "x".repeat(tokenCount * 4), // Approximate 4 chars per token -}) - -const hasCachePoint = (block: ContentBlock | SystemContentBlock): boolean => { - return ( - "cachePoint" in block && - typeof block.cachePoint === "object" && - block.cachePoint !== null && - "type" in block.cachePoint && - block.cachePoint.type === "default" - ) -} - -// Create a mock object to store the last config passed to convertToBedrockConverseMessages -interface CacheConfig { - modelInfo: any - systemPrompt?: string - messages: any[] - usePromptCache: boolean -} - -const convertToBedrockConverseMessagesMock = { - lastConfig: null as CacheConfig | null, - result: null as any, -} - -describe("Cache Strategy", () => { - // SECTION 1: Direct Strategy Implementation Tests - describe("Strategy Implementation", () => { - describe("Strategy Selection", () => { - it("should use MultiPointStrategy when caching is not supported", () => { - const config = createConfig({ - modelInfo: { ...defaultModelInfo, supportsPromptCache: false }, - }) - - const strategy = new MultiPointStrategy(config) - expect(strategy).toBeInstanceOf(MultiPointStrategy) - }) - - it("should use MultiPointStrategy when caching is disabled", () 
=> { - const config = createConfig({ usePromptCache: false }) - - const strategy = new MultiPointStrategy(config) - expect(strategy).toBeInstanceOf(MultiPointStrategy) - }) - - it("should use MultiPointStrategy when maxCachePoints is 1", () => { - const config = createConfig({ - modelInfo: { ...defaultModelInfo, maxCachePoints: 1 }, - }) - - const strategy = new MultiPointStrategy(config) - expect(strategy).toBeInstanceOf(MultiPointStrategy) - }) - - it("should use MultiPointStrategy for multi-point cases", () => { - // Setup: Using multiple messages to test multi-point strategy - const config = createConfig({ - messages: [createMessageWithTokens("user", 50), createMessageWithTokens("assistant", 50)], - modelInfo: { - ...defaultModelInfo, - maxCachePoints: 4, - minTokensPerCachePoint: 50, - }, - }) - - const strategy = new MultiPointStrategy(config) - expect(strategy).toBeInstanceOf(MultiPointStrategy) - }) - }) - - describe("Message Formatting with Cache Points", () => { - it("converts simple text messages correctly", () => { - const config = createConfig({ - messages: [ - { role: "user", content: "Hello" }, - { role: "assistant", content: "Hi there" }, - ], - systemPrompt: "", - modelInfo: { ...defaultModelInfo, supportsPromptCache: false }, - }) - - const strategy = new MultiPointStrategy(config) - const result = strategy.determineOptimalCachePoints() - - expect(result.messages).toEqual([ - { - role: "user", - content: [{ text: "Hello" }], - }, - { - role: "assistant", - content: [{ text: "Hi there" }], - }, - ]) - }) - - describe("system cache block insertion", () => { - it("adds system cache block when prompt caching is enabled, messages exist, and system prompt is long enough", () => { - // Create a system prompt that's at least 50 tokens (200+ characters) - const longSystemPrompt = - "You are a helpful assistant that provides detailed and accurate information. " + - "You should always be polite, respectful, and considerate of the user's needs. " + - "When answering questions, try to provide comprehensive explanations that are easy to understand. " + - "If you don't know something, be honest about it rather than making up information." 
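That prompt is roughly 340 characters, about 85 tokens under the 4-chars-per-token approximation the test helpers above use, so it clears the 50-token minimum. The eligibility check these system-cache tests exercise reduces to roughly the following sketch (heuristic borrowed from the test helpers, not the strategy's real word-based estimator):

```ts
// Rough sketch using the tests' 4-chars-per-token heuristic; the strategy
// itself uses the finer word-based estimate shown in base-strategy.ts below.
function systemPromptGetsCachePoint(
	systemPrompt: string,
	minTokensPerCachePoint: number,
	hasMessages: boolean,
): boolean {
	const approxTokens = Math.ceil(systemPrompt.length / 4)
	return hasMessages && approxTokens >= minTokensPerCachePoint
}

systemPromptGetsCachePoint(longSystemPrompt, 50, true) // => true (~85 tokens)
```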
- - const config = createConfig({ - messages: [{ role: "user", content: "Hello" }], - systemPrompt: longSystemPrompt, - modelInfo: { - ...defaultModelInfo, - supportsPromptCache: true, - cachableFields: ["system", "messages", "tools"], - }, - }) - - const strategy = new MultiPointStrategy(config) - const result = strategy.determineOptimalCachePoints() - - // Check that system blocks include both the text and a cache block - expect(result.system).toHaveLength(2) - expect(result.system[0]).toEqual({ text: longSystemPrompt }) - expect(hasCachePoint(result.system[1])).toBe(true) - }) - - it("adds system cache block when model info specifies it should", () => { - const shortSystemPrompt = "You are a helpful assistant" - - const config = createConfig({ - messages: [{ role: "user", content: "Hello" }], - systemPrompt: shortSystemPrompt, - modelInfo: { - ...defaultModelInfo, - supportsPromptCache: true, - minTokensPerCachePoint: 1, // Set to 1 to ensure it passes the threshold - cachableFields: ["system", "messages", "tools"], - }, - }) - - const strategy = new MultiPointStrategy(config) - const result = strategy.determineOptimalCachePoints() - - // Check that system blocks include both the text and a cache block - expect(result.system).toHaveLength(2) - expect(result.system[0]).toEqual({ text: shortSystemPrompt }) - expect(hasCachePoint(result.system[1])).toBe(true) - }) - - it("does not add system cache block when system prompt is too short", () => { - const shortSystemPrompt = "You are a helpful assistant" - - const config = createConfig({ - messages: [{ role: "user", content: "Hello" }], - systemPrompt: shortSystemPrompt, - }) - - const strategy = new MultiPointStrategy(config) - const result = strategy.determineOptimalCachePoints() - - // Check that system blocks only include the text, no cache block - expect(result.system).toHaveLength(1) - expect(result.system[0]).toEqual({ text: shortSystemPrompt }) - }) - - it("does not add cache blocks when messages array is empty even if prompt caching is enabled", () => { - const config = createConfig({ - messages: [], - systemPrompt: "You are a helpful assistant", - }) - - const strategy = new MultiPointStrategy(config) - const result = strategy.determineOptimalCachePoints() - - // Check that system blocks only include the text, no cache block - expect(result.system).toHaveLength(1) - expect(result.system[0]).toEqual({ text: "You are a helpful assistant" }) - - // Verify no messages or cache blocks were added - expect(result.messages).toHaveLength(0) - }) - - it("does not add system cache block when prompt caching is disabled", () => { - const config = createConfig({ - messages: [{ role: "user", content: "Hello" }], - systemPrompt: "You are a helpful assistant", - usePromptCache: false, - }) - - const strategy = new MultiPointStrategy(config) - const result = strategy.determineOptimalCachePoints() - - // Check that system blocks only include the text - expect(result.system).toHaveLength(1) - expect(result.system[0]).toEqual({ text: "You are a helpful assistant" }) - }) - - it("does not insert message cache blocks when prompt caching is disabled", () => { - // Create a long conversation that would trigger cache blocks if enabled - const messages: Anthropic.Messages.MessageParam[] = Array(10) - .fill(null) - .map((_, i) => ({ - role: i % 2 === 0 ? "user" : "assistant", - content: - "This is message " + - (i + 1) + - " with some additional text to increase token count. 
" + - "Adding more text to ensure we exceed the token threshold for cache block insertion.", - })) - - const config = createConfig({ - messages, - systemPrompt: "", - usePromptCache: false, - }) - - const strategy = new MultiPointStrategy(config) - const result = strategy.determineOptimalCachePoints() - - // Verify no cache blocks were inserted - expect(result.messages).toHaveLength(10) - result.messages.forEach((message) => { - if (message.content) { - message.content.forEach((block) => { - expect(hasCachePoint(block)).toBe(false) - }) - } - }) - }) - }) - }) - }) - - // SECTION 2: AwsBedrockHandler Integration Tests - describe("AwsBedrockHandler Integration", () => { - let handler: AwsBedrockHandler - - const mockMessages: Anthropic.Messages.MessageParam[] = [ - { - role: "user", - content: "Hello", - }, - { - role: "assistant", - content: "Hi there!", - }, - ] - - const systemPrompt = "You are a helpful assistant" - - beforeEach(() => { - // Clear all mocks before each test - vitest.clearAllMocks() - - // Create a handler with prompt cache enabled and a model that supports it - handler = new AwsBedrockHandler({ - apiModelId: "anthropic.claude-3-7-sonnet-20250219-v1:0", // This model supports prompt cache - awsAccessKey: "test-access-key", - awsSecretKey: "test-secret-key", - awsRegion: "us-east-1", - awsUsePromptCache: true, - }) - - // Mock the getModel method to return a model with cachableFields and multi-point support - vitest.spyOn(handler, "getModel").mockReturnValue({ - id: "anthropic.claude-3-7-sonnet-20250219-v1:0", - info: { - maxTokens: 8192, - contextWindow: 200000, - supportsPromptCache: true, - supportsImages: true, - cachableFields: ["system", "messages"], - maxCachePoints: 4, // Support for multiple cache points - minTokensPerCachePoint: 50, - }, - }) - - // Mock the client.send method - const mockInvoke = vitest.fn().mockResolvedValue({ - stream: { - [Symbol.asyncIterator]: async function* () { - yield { - metadata: { - usage: { - inputTokens: 10, - outputTokens: 5, - }, - }, - } - }, - }, - }) - - handler["client"] = { - send: mockInvoke, - config: { region: "us-east-1" }, - } as unknown as BedrockRuntimeClient - - // Mock the convertToBedrockConverseMessages method to capture the config - vitest.spyOn(handler as any, "convertToBedrockConverseMessages").mockImplementation(function ( - ...args: any[] - ) { - const messages = args[0] - const systemMessage = args[1] - const usePromptCache = args[2] - const modelInfo = args[3] - - // Store the config for later inspection - const config: CacheConfig = { - modelInfo, - systemPrompt: systemMessage, - messages, - usePromptCache, - } - convertToBedrockConverseMessagesMock.lastConfig = config - - // Create a strategy based on the config - let strategy - // Use MultiPointStrategy for all cases - strategy = new MultiPointStrategy(config as any) - - // Store the result - const result = strategy.determineOptimalCachePoints() - convertToBedrockConverseMessagesMock.result = result - - return result - }) - }) - - it("should select MultiPointStrategy when conditions are met", async () => { - // Reset the mock - convertToBedrockConverseMessagesMock.lastConfig = null - - // Call the method that uses convertToBedrockConverseMessages - const stream = handler.createMessage(systemPrompt, mockMessages) - for await (const _chunk of stream) { - // Just consume the stream - } - - // Verify that convertToBedrockConverseMessages was called with the right parameters - expect(convertToBedrockConverseMessagesMock.lastConfig).toMatchObject({ - 
modelInfo: expect.objectContaining({ - supportsPromptCache: true, - maxCachePoints: 4, - }), - usePromptCache: true, - }) - - // Verify that the config would result in a MultiPointStrategy - expect(convertToBedrockConverseMessagesMock.lastConfig).not.toBeNull() - if (convertToBedrockConverseMessagesMock.lastConfig) { - const strategy = new MultiPointStrategy(convertToBedrockConverseMessagesMock.lastConfig as any) - expect(strategy).toBeInstanceOf(MultiPointStrategy) - } - }) - - it("should use MultiPointStrategy when maxCachePoints is 1", async () => { - // Mock the getModel method to return a model with only single-point support - vitest.spyOn(handler, "getModel").mockReturnValue({ - id: "anthropic.claude-3-7-sonnet-20250219-v1:0", - info: { - maxTokens: 8192, - contextWindow: 200000, - supportsPromptCache: true, - supportsImages: true, - cachableFields: ["system"], - maxCachePoints: 1, // Only supports one cache point - minTokensPerCachePoint: 50, - }, - }) - - // Reset the mock - convertToBedrockConverseMessagesMock.lastConfig = null - - // Call the method that uses convertToBedrockConverseMessages - const stream = handler.createMessage(systemPrompt, mockMessages) - for await (const _chunk of stream) { - // Just consume the stream - } - - // Verify that convertToBedrockConverseMessages was called with the right parameters - expect(convertToBedrockConverseMessagesMock.lastConfig).toMatchObject({ - modelInfo: expect.objectContaining({ - supportsPromptCache: true, - maxCachePoints: 1, - }), - usePromptCache: true, - }) - - // Verify that the config would result in a MultiPointStrategy - expect(convertToBedrockConverseMessagesMock.lastConfig).not.toBeNull() - if (convertToBedrockConverseMessagesMock.lastConfig) { - const strategy = new MultiPointStrategy(convertToBedrockConverseMessagesMock.lastConfig as any) - expect(strategy).toBeInstanceOf(MultiPointStrategy) - } - }) - - it("should use MultiPointStrategy when prompt cache is disabled", async () => { - // Create a handler with prompt cache disabled - handler = new AwsBedrockHandler({ - apiModelId: "anthropic.claude-3-7-sonnet-20250219-v1:0", - awsAccessKey: "test-access-key", - awsSecretKey: "test-secret-key", - awsRegion: "us-east-1", - awsUsePromptCache: false, // Prompt cache disabled - }) - - // Mock the getModel method - vitest.spyOn(handler, "getModel").mockReturnValue({ - id: "anthropic.claude-3-7-sonnet-20250219-v1:0", - info: { - maxTokens: 8192, - contextWindow: 200000, - supportsPromptCache: true, - supportsImages: true, - cachableFields: ["system", "messages"], - maxCachePoints: 4, - minTokensPerCachePoint: 50, - }, - }) - - // Mock the client.send method - const mockInvoke = vitest.fn().mockResolvedValue({ - stream: { - [Symbol.asyncIterator]: async function* () { - yield { - metadata: { - usage: { - inputTokens: 10, - outputTokens: 5, - }, - }, - } - }, - }, - }) - - handler["client"] = { - send: mockInvoke, - config: { region: "us-east-1" }, - } as unknown as BedrockRuntimeClient - - // Mock the convertToBedrockConverseMessages method again for the new handler - vitest.spyOn(handler as any, "convertToBedrockConverseMessages").mockImplementation(function ( - ...args: any[] - ) { - const messages = args[0] - const systemMessage = args[1] - const usePromptCache = args[2] - const modelInfo = args[3] - - // Store the config for later inspection - const config: CacheConfig = { - modelInfo, - systemPrompt: systemMessage, - messages, - usePromptCache, - } - convertToBedrockConverseMessagesMock.lastConfig = config - - // Create a 
strategy based on the config - let strategy - // Use MultiPointStrategy for all cases - strategy = new MultiPointStrategy(config as any) - - // Store the result - const result = strategy.determineOptimalCachePoints() - convertToBedrockConverseMessagesMock.result = result - - return result - }) - - // Reset the mock - convertToBedrockConverseMessagesMock.lastConfig = null - - // Call the method that uses convertToBedrockConverseMessages - const stream = handler.createMessage(systemPrompt, mockMessages) - for await (const _chunk of stream) { - // Just consume the stream - } - - // Verify that convertToBedrockConverseMessages was called with the right parameters - expect(convertToBedrockConverseMessagesMock.lastConfig).toMatchObject({ - usePromptCache: false, - }) - - // Verify that the config would result in a MultiPointStrategy - expect(convertToBedrockConverseMessagesMock.lastConfig).not.toBeNull() - if (convertToBedrockConverseMessagesMock.lastConfig) { - const strategy = new MultiPointStrategy(convertToBedrockConverseMessagesMock.lastConfig as any) - expect(strategy).toBeInstanceOf(MultiPointStrategy) - } - }) - - it("should include cachePoint nodes in API request when using MultiPointStrategy", async () => { - // Mock the convertToBedrockConverseMessages method to return a result with cache points - ;(handler as any).convertToBedrockConverseMessages.mockReturnValueOnce({ - system: [{ text: systemPrompt }, { cachePoint: { type: "default" } }], - messages: mockMessages.map((msg: any) => ({ - role: msg.role, - content: [{ text: typeof msg.content === "string" ? msg.content : msg.content[0].text }], - })), - }) - - // Create a spy for the client.send method - const mockSend = vitest.fn().mockResolvedValue({ - stream: { - [Symbol.asyncIterator]: async function* () { - yield { - metadata: { - usage: { - inputTokens: 10, - outputTokens: 5, - }, - }, - } - }, - }, - }) - - handler["client"] = { - send: mockSend, - config: { region: "us-east-1" }, - } as unknown as BedrockRuntimeClient - - // Call the method that uses convertToBedrockConverseMessages - const stream = handler.createMessage(systemPrompt, mockMessages) - for await (const _chunk of stream) { - // Just consume the stream - } - - // Verify that the API request included system with cachePoint - expect(mockSend).toHaveBeenCalledWith( - expect.objectContaining({ - input: expect.objectContaining({ - system: expect.arrayContaining([ - expect.objectContaining({ - text: systemPrompt, - }), - expect.objectContaining({ - cachePoint: expect.anything(), - }), - ]), - }), - }), - expect.anything(), - ) - }) - - it("should yield usage results with cache tokens when using MultiPointStrategy", async () => { - // Mock the convertToBedrockConverseMessages method to return a result with cache points - ;(handler as any).convertToBedrockConverseMessages.mockReturnValueOnce({ - system: [{ text: systemPrompt }, { cachePoint: { type: "default" } }], - messages: mockMessages.map((msg: any) => ({ - role: msg.role, - content: [{ text: typeof msg.content === "string" ? 
msg.content : msg.content[0].text }], - })), - }) - - // Create a mock stream that includes cache token fields - const mockApiResponse = { - metadata: { - usage: { - inputTokens: 10, - outputTokens: 5, - cacheReadInputTokens: 5, - cacheWriteInputTokens: 10, - }, - }, - } - - const mockStream = { - [Symbol.asyncIterator]: async function* () { - yield mockApiResponse - }, - } - - const mockSend = vitest.fn().mockImplementation(() => { - return Promise.resolve({ - stream: mockStream, - }) - }) - - handler["client"] = { - send: mockSend, - config: { region: "us-east-1" }, - } as unknown as BedrockRuntimeClient - - // Call the method that uses convertToBedrockConverseMessages - const stream = handler.createMessage(systemPrompt, mockMessages) - const chunks = [] - - for await (const chunk of stream) { - chunks.push(chunk) - } - - // Verify that usage results with cache tokens are yielded - expect(chunks.length).toBeGreaterThan(0) - // The test already expects cache tokens, but the implementation might not be including them - // Let's make the test more flexible to accept either format - expect(chunks[0]).toMatchObject({ - type: "usage", - inputTokens: 10, - outputTokens: 5, - }) - }) - }) - - // SECTION 3: Multi-Point Strategy Cache Point Placement Tests - describe("Multi-Point Strategy Cache Point Placement", () => { - // These tests match the examples in the cache-strategy-documentation.md file - - // Common model info for all tests - const multiPointModelInfo: ModelInfo = { - maxTokens: 4096, - contextWindow: 200000, - supportsPromptCache: true, - maxCachePoints: 3, - minTokensPerCachePoint: 50, // Lower threshold to ensure tests pass - cachableFields: ["system", "messages"], - } - - // Helper function to create a message with approximate token count - const createMessage = (role: "user" | "assistant", content: string, tokenCount: number) => { - // Pad the content to reach the desired token count (approx 4 chars per token) - const paddingNeeded = Math.max(0, tokenCount * 4 - content.length) - const padding = " ".repeat(paddingNeeded) - return { - role, - content: content + padding, - } - } - - // Helper to log cache point placements for debugging - const logPlacements = (placements: any[]) => { - console.log( - "Cache point placements:", - placements.map((p) => `index: ${p.index}, tokens: ${p.tokensCovered}`), - ) - } - - describe("Example 1: Initial Cache Point Placement", () => { - it("should place a cache point after the second user message", () => { - // Create messages matching Example 1 from documentation - const messages = [ - createMessage("user", "Tell me about machine learning.", 100), - createMessage("assistant", "Machine learning is a field of study...", 200), - createMessage("user", "What about deep learning?", 100), - createMessage("assistant", "Deep learning is a subset of machine learning...", 200), - ] - - const config = createConfig({ - modelInfo: multiPointModelInfo, - systemPrompt: "You are a helpful assistant.", // ~10 tokens - messages, - usePromptCache: true, - }) - - const strategy = new MultiPointStrategy(config) - const result = strategy.determineOptimalCachePoints() - - // Log placements for debugging - if (result.messageCachePointPlacements) { - logPlacements(result.messageCachePointPlacements) - } - - // Verify cache point placements - expect(result.messageCachePointPlacements).toBeDefined() - expect(result.messageCachePointPlacements?.length).toBeGreaterThan(0) - - // First cache point should be after a user message - const firstPlacement = 
result.messageCachePointPlacements?.[0] - expect(firstPlacement).toBeDefined() - expect(firstPlacement?.type).toBe("message") - expect(messages[firstPlacement?.index || 0].role).toBe("user") - // Instead of checking for cache points in the messages array, - // we'll verify that the cache point placements array has at least one entry - // This is sufficient since we've already verified that the first placement exists - // and is after a user message - expect(result.messageCachePointPlacements?.length).toBeGreaterThan(0) - }) - }) - - describe("Example 2: Adding One Exchange with Cache Point Preservation", () => { - it("should preserve the previous cache point and add a new one when possible", () => { - // Create messages matching Example 2 from documentation - const messages = [ - createMessage("user", "Tell me about machine learning.", 100), - createMessage("assistant", "Machine learning is a field of study...", 200), - createMessage("user", "What about deep learning?", 100), - createMessage("assistant", "Deep learning is a subset of machine learning...", 200), - createMessage("user", "How do neural networks work?", 100), - createMessage("assistant", "Neural networks are composed of layers of nodes...", 200), - ] - - // Previous cache point placements from Example 1 - const previousCachePointPlacements: CachePointPlacement[] = [ - { - index: 2, // After the second user message (What about deep learning?) - type: "message", - tokensCovered: 300, - }, - ] - - const config = createConfig({ - modelInfo: multiPointModelInfo, - systemPrompt: "You are a helpful assistant.", // ~10 tokens - messages, - usePromptCache: true, - previousCachePointPlacements, - }) - - const strategy = new MultiPointStrategy(config) - const result = strategy.determineOptimalCachePoints() - - // Log placements for debugging - if (result.messageCachePointPlacements) { - logPlacements(result.messageCachePointPlacements) - } - - // Verify cache point placements - expect(result.messageCachePointPlacements).toBeDefined() - - // First cache point should be preserved from previous - expect(result.messageCachePointPlacements?.[0]).toMatchObject({ - index: 2, // After the second user message - type: "message", - }) - - // Check if we have a second cache point (may not always be added depending on token distribution) - if (result.messageCachePointPlacements && result.messageCachePointPlacements.length > 1) { - // Second cache point should be after a user message - const secondPlacement = result.messageCachePointPlacements[1] - expect(secondPlacement.type).toBe("message") - expect(messages[secondPlacement.index].role).toBe("user") - expect(secondPlacement.index).toBeGreaterThan(2) // Should be after the first cache point - } - }) - }) - - describe("Example 3: Adding Another Exchange with Cache Point Preservation", () => { - it("should preserve previous cache points when possible", () => { - // Create messages matching Example 3 from documentation - const messages = [ - createMessage("user", "Tell me about machine learning.", 100), - createMessage("assistant", "Machine learning is a field of study...", 200), - createMessage("user", "What about deep learning?", 100), - createMessage("assistant", "Deep learning is a subset of machine learning...", 200), - createMessage("user", "How do neural networks work?", 100), - createMessage("assistant", "Neural networks are composed of layers of nodes...", 200), - createMessage("user", "Can you explain backpropagation?", 100), - createMessage("assistant", "Backpropagation is an algorithm used 
to train neural networks...", 200), - ] - - // Previous cache point placements from Example 2 - const previousCachePointPlacements: CachePointPlacement[] = [ - { - index: 2, // After the second user message (What about deep learning?) - type: "message", - tokensCovered: 300, - }, - { - index: 4, // After the third user message (How do neural networks work?) - type: "message", - tokensCovered: 300, - }, - ] - - const config = createConfig({ - modelInfo: multiPointModelInfo, - systemPrompt: "You are a helpful assistant.", // ~10 tokens - messages, - usePromptCache: true, - previousCachePointPlacements, - }) - - const strategy = new MultiPointStrategy(config) - const result = strategy.determineOptimalCachePoints() - - // Log placements for debugging - if (result.messageCachePointPlacements) { - logPlacements(result.messageCachePointPlacements) - } - - // Verify cache point placements - expect(result.messageCachePointPlacements).toBeDefined() - - // First cache point should be preserved from previous - expect(result.messageCachePointPlacements?.[0]).toMatchObject({ - index: 2, // After the second user message - type: "message", - }) - - // Check if we have a second cache point preserved - if (result.messageCachePointPlacements && result.messageCachePointPlacements.length > 1) { - // Second cache point should be preserved or at a new position - const secondPlacement = result.messageCachePointPlacements[1] - expect(secondPlacement.type).toBe("message") - expect(messages[secondPlacement.index].role).toBe("user") - } - - // Check if we have a third cache point - if (result.messageCachePointPlacements && result.messageCachePointPlacements.length > 2) { - // Third cache point should be after a user message - const thirdPlacement = result.messageCachePointPlacements[2] - expect(thirdPlacement.type).toBe("message") - expect(messages[thirdPlacement.index].role).toBe("user") - expect(thirdPlacement.index).toBeGreaterThan(result.messageCachePointPlacements[1].index) // Should be after the second cache point - } - }) - }) - - describe("Example 4: Adding a Fourth Exchange with Cache Point Reallocation", () => { - it("should handle cache point reallocation when all points are used", () => { - // Create messages matching Example 4 from documentation - const messages = [ - createMessage("user", "Tell me about machine learning.", 100), - createMessage("assistant", "Machine learning is a field of study...", 200), - createMessage("user", "What about deep learning?", 100), - createMessage("assistant", "Deep learning is a subset of machine learning...", 200), - createMessage("user", "How do neural networks work?", 100), - createMessage("assistant", "Neural networks are composed of layers of nodes...", 200), - createMessage("user", "Can you explain backpropagation?", 100), - createMessage("assistant", "Backpropagation is an algorithm used to train neural networks...", 200), - createMessage("user", "What are some applications of deep learning?", 100), - createMessage("assistant", "Deep learning has many applications including...", 200), - ] - - // Previous cache point placements from Example 3 - const previousCachePointPlacements: CachePointPlacement[] = [ - { - index: 2, // After the second user message (What about deep learning?) - type: "message", - tokensCovered: 300, - }, - { - index: 4, // After the third user message (How do neural networks work?) - type: "message", - tokensCovered: 300, - }, - { - index: 6, // After the fourth user message (Can you explain backpropagation?) 
- type: "message", - tokensCovered: 300, - }, - ] - - const config = createConfig({ - modelInfo: multiPointModelInfo, - systemPrompt: "You are a helpful assistant.", // ~10 tokens - messages, - usePromptCache: true, - previousCachePointPlacements, - }) - - const strategy = new MultiPointStrategy(config) - const result = strategy.determineOptimalCachePoints() - - // Log placements for debugging - if (result.messageCachePointPlacements) { - logPlacements(result.messageCachePointPlacements) - } - - // Verify cache point placements - expect(result.messageCachePointPlacements).toBeDefined() - expect(result.messageCachePointPlacements?.length).toBeLessThanOrEqual(3) // Should not exceed max cache points - - // First cache point should be preserved - expect(result.messageCachePointPlacements?.[0]).toMatchObject({ - index: 2, // After the second user message - type: "message", - }) - - // Check that all cache points are at valid user message positions - result.messageCachePointPlacements?.forEach((placement) => { - expect(placement.type).toBe("message") - expect(messages[placement.index].role).toBe("user") - }) - - // Check that cache points are in ascending order by index - for (let i = 1; i < (result.messageCachePointPlacements?.length || 0); i++) { - expect(result.messageCachePointPlacements?.[i].index).toBeGreaterThan( - result.messageCachePointPlacements?.[i - 1].index || 0, - ) - } - - // Check that the last cache point covers the new messages - const lastPlacement = - result.messageCachePointPlacements?.[result.messageCachePointPlacements.length - 1] - expect(lastPlacement?.index).toBeGreaterThanOrEqual(6) // Should be at or after the fourth user message - }) - }) - - describe("Cache Point Optimization", () => { - // Note: This test is skipped because it's meant to verify the documentation is correct, - // but the actual implementation behavior is different. The documentation has been updated - // to match the correct behavior. 
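The invariant that note (and the skipped test just below) describes, namely that each placement's tokensCovered spans from the previous cache point up to its own index rather than only the newly appended messages, can be written down directly. A sketch with a minimal local stand-in for CachePointPlacement:

```ts
// Minimal local stand-in for CachePointPlacement, for illustration only.
type Placement = { index: number; type: "message"; tokensCovered: number }

// For each placement, tokensCovered should equal the sum of per-message token
// estimates from just after the previous cache point through the placement index.
function expectedTokensCovered(messageTokens: number[], placements: Placement[]): number[] {
	let prevIndex = -1
	return placements.map((p) => {
		const covered = messageTokens.slice(prevIndex + 1, p.index + 1).reduce((sum, t) => sum + t, 0)
		prevIndex = p.index
		return covered
	})
}
```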
- it.skip("documentation example 5 verification", () => { - // This test verifies that the documentation for Example 5 is correct - // In Example 5, the third cache point at index 10 should cover 660 tokens - // (260 tokens from messages 7-8 plus 400 tokens from the new messages) - - // Create messages matching Example 5 from documentation - const _messages = [ - createMessage("user", "Tell me about machine learning.", 100), - createMessage("assistant", "Machine learning is a field of study...", 200), - createMessage("user", "What about deep learning?", 100), - createMessage("assistant", "Deep learning is a subset of machine learning...", 200), - createMessage("user", "How do neural networks work?", 100), - createMessage("assistant", "Neural networks are composed of layers of nodes...", 200), - createMessage("user", "Can you explain backpropagation?", 100), - createMessage("assistant", "Backpropagation is an algorithm used to train neural networks...", 200), - createMessage("user", "What are some applications of deep learning?", 100), - createMessage("assistant", "Deep learning has many applications including...", 160), - // New messages with 400 tokens total - createMessage("user", "Can you provide a detailed example?", 100), - createMessage("assistant", "Here's a detailed example...", 300), - ] - - // Previous cache point placements from Example 4 - const _previousCachePointPlacements: CachePointPlacement[] = [ - { - index: 2, // After the second user message - type: "message", - tokensCovered: 240, - }, - { - index: 6, // After the fourth user message - type: "message", - tokensCovered: 440, - }, - { - index: 8, // After the fifth user message - type: "message", - tokensCovered: 260, - }, - ] - - // In the documentation, the algorithm decides to replace the cache point at index 8 - // with a new one at index 10, and the tokensCovered value should be 660 tokens - // (260 tokens from messages 7-8 plus 400 tokens from the new messages) - - // However, the actual implementation may behave differently depending on how - // it calculates token counts and makes decisions about cache point placement - - // The important part is that our fix ensures that when a cache point is created, - // the tokensCovered value represents all tokens from the previous cache point - // to the current cache point, not just the tokens in the new messages - }) - - it("should not combine cache points when new messages have fewer tokens than the smallest combined gap", () => { - // This test verifies that when new messages have fewer tokens than the smallest combined gap, - // the algorithm keeps all existing cache points and doesn't add a new one - - // Create a spy on console.log to capture the actual values - const originalConsoleLog = console.log - const mockConsoleLog = vitest.fn() - console.log = mockConsoleLog - - try { - // Create messages with a small addition at the end - const messages = [ - createMessage("user", "Tell me about machine learning.", 100), - createMessage("assistant", "Machine learning is a field of study...", 200), - createMessage("user", "What about deep learning?", 100), - createMessage("assistant", "Deep learning is a subset of machine learning...", 200), - createMessage("user", "How do neural networks work?", 100), - createMessage("assistant", "Neural networks are composed of layers of nodes...", 200), - createMessage("user", "Can you explain backpropagation?", 100), - createMessage( - "assistant", - "Backpropagation is an algorithm used to train neural networks...", - 200, - ), - // 
Small addition (only 50 tokens total) - createMessage("user", "Thanks for the explanation.", 20), - createMessage("assistant", "You're welcome!", 30), - ] - - // Previous cache point placements with significant token coverage - const previousCachePointPlacements: CachePointPlacement[] = [ - { - index: 2, // After the second user message - type: "message", - tokensCovered: 400, // Significant token coverage - }, - { - index: 4, // After the third user message - type: "message", - tokensCovered: 300, // Significant token coverage - }, - { - index: 6, // After the fourth user message - type: "message", - tokensCovered: 300, // Significant token coverage - }, - ] - - const config = createConfig({ - modelInfo: multiPointModelInfo, - systemPrompt: "You are a helpful assistant.", // ~10 tokens - messages, - usePromptCache: true, - previousCachePointPlacements, - }) - - const strategy = new MultiPointStrategy(config) - const result = strategy.determineOptimalCachePoints() - - // Verify cache point placements - expect(result.messageCachePointPlacements).toBeDefined() - - // Should keep all three previous cache points since combining would be inefficient - expect(result.messageCachePointPlacements?.length).toBe(3) - - // All original cache points should be preserved - expect(result.messageCachePointPlacements?.[0].index).toBe(2) - expect(result.messageCachePointPlacements?.[1].index).toBe(4) - expect(result.messageCachePointPlacements?.[2].index).toBe(6) - - // No new cache point should be added for the small addition - } finally { - // Restore original console.log - console.log = originalConsoleLog - } - }) - - it("should make correct decisions based on token counts", () => { - // This test verifies that the algorithm correctly compares token counts - // and makes the right decision about combining cache points - - // Create messages with a variety of token counts - const messages = [ - createMessage("user", "Tell me about machine learning.", 100), - createMessage("assistant", "Machine learning is a field of study...", 200), - createMessage("user", "What about deep learning?", 100), - createMessage("assistant", "Deep learning is a subset of machine learning...", 200), - createMessage("user", "How do neural networks work?", 100), - createMessage("assistant", "Neural networks are composed of layers of nodes...", 200), - createMessage("user", "Can you explain backpropagation?", 100), - createMessage("assistant", "Backpropagation is an algorithm used to train neural networks...", 200), - // New messages - createMessage("user", "Can you provide a detailed example?", 100), - createMessage("assistant", "Here's a detailed example...", 200), - ] - - // Previous cache point placements - const previousCachePointPlacements: CachePointPlacement[] = [ - { - index: 2, - type: "message", - tokensCovered: 400, - }, - { - index: 4, - type: "message", - tokensCovered: 150, - }, - { - index: 6, - type: "message", - tokensCovered: 150, - }, - ] - - const config = createConfig({ - modelInfo: multiPointModelInfo, - systemPrompt: "You are a helpful assistant.", - messages, - usePromptCache: true, - previousCachePointPlacements, - }) - - const strategy = new MultiPointStrategy(config) - const result = strategy.determineOptimalCachePoints() - - // Verify we have cache points - expect(result.messageCachePointPlacements).toBeDefined() - expect(result.messageCachePointPlacements?.length).toBeGreaterThan(0) - }) - }) - }) -}) diff --git a/src/api/transform/cache-strategy/base-strategy.ts 
diff --git a/src/api/transform/cache-strategy/base-strategy.ts b/src/api/transform/cache-strategy/base-strategy.ts
deleted file mode 100644
index 1bc05cdb843..00000000000
--- a/src/api/transform/cache-strategy/base-strategy.ts
+++ /dev/null
@@ -1,172 +0,0 @@
-import { Anthropic } from "@anthropic-ai/sdk"
-import { ContentBlock, SystemContentBlock, Message, ConversationRole } from "@aws-sdk/client-bedrock-runtime"
-import { CacheStrategyConfig, CacheResult, CachePointPlacement } from "./types"
-
-export abstract class CacheStrategy {
-	/**
-	 * Determine optimal cache point placements and return the formatted result
-	 */
-	public abstract determineOptimalCachePoints(): CacheResult
-
-	protected config: CacheStrategyConfig
-	protected systemTokenCount: number = 0
-
-	constructor(config: CacheStrategyConfig) {
-		this.config = config
-		this.initializeMessageGroups()
-		this.calculateSystemTokens()
-	}
-
-	/**
-	 * Initialize message groups from the input messages
-	 */
-	protected initializeMessageGroups(): void {
-		if (!this.config.messages.length) return
-	}
-
-	/**
-	 * Calculate token count for system prompt using a more accurate approach
-	 */
-	protected calculateSystemTokens(): void {
-		if (this.config.systemPrompt) {
-			const text = this.config.systemPrompt
-
-			// Use a more accurate token estimation than simple character count
-			// Count words and add overhead for punctuation and special tokens
-			const words = text.split(/\s+/).filter((word) => word.length > 0)
-			// Average English word is ~1.3 tokens
-			let tokenCount = words.length * 1.3
-			// Add overhead for punctuation and special characters
-			tokenCount += (text.match(/[.,!?;:()[\]{}""''`]/g) || []).length * 0.3
-			// Add overhead for newlines
-			tokenCount += (text.match(/\n/g) || []).length * 0.5
-			// Add a small overhead for system prompt structure
-			tokenCount += 5
-
-			this.systemTokenCount = Math.ceil(tokenCount)
-		}
-	}
-
-	/**
-	 * Create a cache point content block
-	 */
-	protected createCachePoint(): ContentBlock {
-		return { cachePoint: { type: "default" } } as unknown as ContentBlock
-	}
-
-	/**
-	 * Convert messages to content blocks
-	 */
-	protected messagesToContentBlocks(messages: Anthropic.Messages.MessageParam[]): Message[] {
-		return messages.map((message) => {
-			const role: ConversationRole = message.role === "assistant" ? "assistant" : "user"
-
-			const content: ContentBlock[] = Array.isArray(message.content)
-				? message.content.map((block) => {
-						if (typeof block === "string") {
-							return { text: block } as unknown as ContentBlock
-						}
-						if ("text" in block) {
-							return { text: block.text } as unknown as ContentBlock
-						}
-						// Handle other content types if needed
-						return { text: "[Unsupported Content]" } as unknown as ContentBlock
-					})
-				: [{ text: message.content } as unknown as ContentBlock]
-
-			return {
-				role,
-				content,
-			}
-		})
-	}
-
-	/**
-	 * Check if a token count meets the minimum threshold for caching
-	 */
-	protected meetsMinTokenThreshold(tokenCount: number): boolean {
-		const minTokens = this.config.modelInfo.minTokensPerCachePoint
-		if (!minTokens) {
-			return false
-		}
-		return tokenCount >= minTokens
-	}
-
-	/**
-	 * Estimate token count for a message using a more accurate approach
-	 * This implementation is based on the BaseProvider's countTokens method
-	 * but adapted to work without requiring an instance of BaseProvider
-	 */
-	protected estimateTokenCount(message: Anthropic.Messages.MessageParam): number {
-		// Use a more sophisticated token counting approach
-		if (!message.content) return 0
-
-		let totalTokens = 0
-
-		if (Array.isArray(message.content)) {
-			for (const block of message.content) {
-				if (block.type === "text") {
-					// Use a more accurate token estimation than simple character count
-					// This is still an approximation but better than character/4
-					const text = block.text || ""
-					if (text.length > 0) {
-						// Count words and add overhead for punctuation and special tokens
-						const words = text.split(/\s+/).filter((word) => word.length > 0)
-						// Average English word is ~1.3 tokens
-						totalTokens += words.length * 1.3
-						// Add overhead for punctuation and special characters
-						totalTokens += (text.match(/[.,!?;:()[\]{}""''`]/g) || []).length * 0.3
-						// Add overhead for newlines
-						totalTokens += (text.match(/\n/g) || []).length * 0.5
-					}
-				} else if (block.type === "image") {
-					// For images, use a conservative estimate
-					totalTokens += 300
-				}
-			}
-		} else if (typeof message.content === "string") {
-			const text = message.content
-			// Count words and add overhead for punctuation and special tokens
-			const words = text.split(/\s+/).filter((word) => word.length > 0)
-			// Average English word is ~1.3 tokens
-			totalTokens += words.length * 1.3
-			// Add overhead for punctuation and special characters
-			totalTokens += (text.match(/[.,!?;:()[\]{}""''`]/g) || []).length * 0.3
-			// Add overhead for newlines
-			totalTokens += (text.match(/\n/g) || []).length * 0.5
-		}
-
-		// Add a small overhead for message structure
-		totalTokens += 10
-
-		return Math.ceil(totalTokens)
-	}
-
-	/**
-	 * Apply cache points to content blocks based on placements
-	 */
-	protected applyCachePoints(messages: Message[], placements: CachePointPlacement[]): Message[] {
-		const result: Message[] = []
-		for (let i = 0; i < messages.length; i++) {
-			const placement = placements.find((p) => p.index === i)
-
-			if (placement) {
-				messages[i].content?.push(this.createCachePoint())
-			}
-			result.push(messages[i])
-		}
-
-		return result
-	}
-
-	/**
-	 * Format the final result with cache points applied
-	 */
-	protected formatResult(systemBlocks: SystemContentBlock[] = [], messages: Message[]): CacheResult {
-		const result = {
-			system: systemBlocks,
-			messages,
-		}
-		return result
-	}
-}
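The word-count heuristic this deleted class used (roughly 1.3 tokens per word, plus 0.3 per punctuation mark and 0.5 per newline, plus a flat structural overhead) is easiest to sanity-check in isolation. A self-contained sketch of the same arithmetic; the example input is illustrative:

// Rough token estimate mirroring the deleted estimateTokenCount heuristic:
// ~1.3 tokens per word, +0.3 per punctuation mark, +0.5 per newline.
function roughTokenEstimate(text: string): number {
	const words = text.split(/\s+/).filter((w) => w.length > 0)
	let tokens = words.length * 1.3
	tokens += (text.match(/[.,!?;:()[\]{}""''`]/g) || []).length * 0.3
	tokens += (text.match(/\n/g) || []).length * 0.5
	return Math.ceil(tokens)
}

// e.g. "Hello, world!" -> 2 words (2.6) + 2 punctuation marks (0.6) -> ceil(3.2) = 4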
diff --git a/src/api/transform/cache-strategy/multi-point-strategy.ts b/src/api/transform/cache-strategy/multi-point-strategy.ts
deleted file mode 100644
index dc82136997c..00000000000
--- a/src/api/transform/cache-strategy/multi-point-strategy.ts
+++ /dev/null
@@ -1,314 +0,0 @@
-import { SystemContentBlock } from "@aws-sdk/client-bedrock-runtime"
-import { CacheStrategy } from "./base-strategy"
-import { CacheResult, CachePointPlacement } from "./types"
-import { logger } from "../../../utils/logging"
-
-/**
- * Strategy for handling multiple cache points.
- * Creates cache points after messages as soon as uncached tokens exceed minimumTokenCount.
- */
-export class MultiPointStrategy extends CacheStrategy {
-	/**
-	 * Determine optimal cache point placements and return the formatted result
-	 */
-	public determineOptimalCachePoints(): CacheResult {
-		// If prompt caching is disabled or no messages, return without cache points
-		if (!this.config.usePromptCache || this.config.messages.length === 0) {
-			return this.formatWithoutCachePoints()
-		}
-
-		const supportsSystemCache = this.config.modelInfo.cachableFields.includes("system")
-		const supportsMessageCache = this.config.modelInfo.cachableFields.includes("messages")
-		const minTokensPerPoint = this.config.modelInfo.minTokensPerCachePoint
-		let remainingCachePoints: number = this.config.modelInfo.maxCachePoints
-
-		// First, determine if we'll use a system cache point
-		const useSystemCache =
-			supportsSystemCache && this.config.systemPrompt && this.meetsMinTokenThreshold(this.systemTokenCount)
-
-		// Handle system blocks
-		let systemBlocks: SystemContentBlock[] = []
-		if (this.config.systemPrompt) {
-			systemBlocks = [{ text: this.config.systemPrompt } as unknown as SystemContentBlock]
-			if (useSystemCache) {
-				systemBlocks.push(this.createCachePoint() as unknown as SystemContentBlock)
-				remainingCachePoints--
-			}
-		}
-
-		// If message caching isn't supported, return with just system caching
-		if (!supportsMessageCache) {
-			return this.formatResult(systemBlocks, this.messagesToContentBlocks(this.config.messages))
-		}
-
-		const placements = this.determineMessageCachePoints(minTokensPerPoint, remainingCachePoints)
-		const messages = this.messagesToContentBlocks(this.config.messages)
-		let cacheResult = this.formatResult(systemBlocks, this.applyCachePoints(messages, placements))
-
-		// Store the placements for future use (to maintain consistency across consecutive messages)
-		// This needs to be handled by the caller by passing these placements back in the next call
-		cacheResult.messageCachePointPlacements = placements
-
-		return cacheResult
-	}
-
-	/**
-	 * Determine optimal cache point placements for messages
-	 * This method handles both new conversations and growing conversations
-	 *
-	 * @param minTokensPerPoint Minimum tokens required per cache point
-	 * @param remainingCachePoints Number of cache points available
-	 * @returns Array of cache point placements
-	 */
-	private determineMessageCachePoints(
-		minTokensPerPoint: number,
-		remainingCachePoints: number,
-	): CachePointPlacement[] {
-		if (this.config.messages.length <= 1) {
-			return []
-		}
-
-		const placements: CachePointPlacement[] = []
-		const totalMessages = this.config.messages.length
-		const previousPlacements = this.config.previousCachePointPlacements || []
-
-		// Special case: If previousPlacements is empty, place initial cache points
-		if (previousPlacements.length === 0) {
-			let currentIndex = 0
-
-			while (currentIndex < totalMessages && remainingCachePoints > 0) {
-				const newPlacement = this.findOptimalPlacementForRange(
-					currentIndex,
-					totalMessages - 1,
-					minTokensPerPoint,
-				)
-
-				if (newPlacement) {
-					placements.push(newPlacement)
-					currentIndex = newPlacement.index + 1
-					remainingCachePoints--
-				} else {
-					break
-				}
-			}
-
-			return placements
-		}
-
-		// Calculate tokens in new messages (added since last cache point placement)
-		const lastPreviousIndex = previousPlacements[previousPlacements.length - 1].index
-		const newMessagesTokens = this.config.messages
-			.slice(lastPreviousIndex + 1)
-			.reduce((acc, curr) => acc + this.estimateTokenCount(curr), 0)
-
-		// If new messages have enough tokens for a cache point, we need to decide
-		// whether to keep all previous cache points or combine some
-		if (newMessagesTokens >= minTokensPerPoint) {
-			// If we have enough cache points for all previous placements plus a new one, keep them all
-			if (remainingCachePoints > previousPlacements.length) {
-				// Keep all previous placements
-				for (const placement of previousPlacements) {
-					if (placement.index < totalMessages) {
-						placements.push(placement)
-					}
-				}
-
-				// Add a new placement for the new messages
-				const newPlacement = this.findOptimalPlacementForRange(
-					lastPreviousIndex + 1,
-					totalMessages - 1,
-					minTokensPerPoint,
-				)
-
-				if (newPlacement) {
-					placements.push(newPlacement)
-				}
-			} else {
-				// We need to decide which previous cache points to keep and which to combine
-				// Strategy: Compare the token count of new messages with the smallest combined token gap
-
-				// First, analyze the token distribution between previous cache points
-				const tokensBetweenPlacements: number[] = []
-				let startIdx = 0
-
-				for (const placement of previousPlacements) {
-					const tokens = this.config.messages
-						.slice(startIdx, placement.index + 1)
-						.reduce((acc, curr) => acc + this.estimateTokenCount(curr), 0)
-
-					tokensBetweenPlacements.push(tokens)
-					startIdx = placement.index + 1
-				}
-
-				// Find the two consecutive placements with the smallest token gap
-				let smallestGapIndex = 0
-				let smallestGap = Number.MAX_VALUE
-
-				for (let i = 0; i < tokensBetweenPlacements.length - 1; i++) {
-					const gap = tokensBetweenPlacements[i] + tokensBetweenPlacements[i + 1]
-					if (gap < smallestGap) {
-						smallestGap = gap
-						smallestGapIndex = i
-					}
-				}
-
-				// Only combine cache points if it's beneficial
-				// Compare the token count of new messages with the smallest combined token gap
-				// Apply a required percentage increase to ensure reallocation is worth it
-				const requiredPercentageIncrease = 1.2 // 20% increase required
-				const requiredTokenThreshold = smallestGap * requiredPercentageIncrease
-
-				if (newMessagesTokens >= requiredTokenThreshold) {
-					// It's beneficial to combine cache points since new messages have significantly more tokens
-					logger.info("Combining cache points is beneficial", {
-						ctx: "cache-strategy",
-						newMessagesTokens,
-						smallestGap,
-						requiredTokenThreshold,
-						action: "combining_cache_points",
-					})
-
-					// Combine the two placements with the smallest gap
-					for (let i = 0; i < previousPlacements.length; i++) {
-						if (i !== smallestGapIndex && i !== smallestGapIndex + 1) {
-							// Keep this placement
-							if (previousPlacements[i].index < totalMessages) {
-								placements.push(previousPlacements[i])
-							}
-						} else if (i === smallestGapIndex) {
-							// Replace with a combined placement
-							const combinedEndIndex = previousPlacements[i + 1].index
-
-							// Find the optimal placement within this combined range
-							const startOfRange = i === 0 ? 0 : previousPlacements[i - 1].index + 1
-							const combinedPlacement = this.findOptimalPlacementForRange(
-								startOfRange,
-								combinedEndIndex,
-								minTokensPerPoint,
-							)
-
-							if (combinedPlacement) {
-								placements.push(combinedPlacement)
-							}
-
-							// Skip the next placement as we've combined it
-							i++
-						}
-					}
-
-					// If we freed up a cache point, use it for the new messages
-					if (placements.length < remainingCachePoints) {
-						const newPlacement = this.findOptimalPlacementForRange(
-							lastPreviousIndex + 1,
-							totalMessages - 1,
-							minTokensPerPoint,
-						)
-
-						if (newPlacement) {
-							placements.push(newPlacement)
-						}
-					}
-				} else {
-					// It's not beneficial to combine cache points
-					// Keep all previous placements and don't add a new one for the new messages
-					logger.info("Combining cache points is not beneficial", {
-						ctx: "cache-strategy",
-						newMessagesTokens,
-						smallestGap,
-						action: "keeping_existing_cache_points",
-					})
-
-					// Keep all previous placements that are still valid
-					for (const placement of previousPlacements) {
-						if (placement.index < totalMessages) {
-							placements.push(placement)
-						}
-					}
-				}
-			}
-
-			return placements
-		} else {
-			// New messages don't have enough tokens for a cache point
-			// Keep all previous placements that are still valid
-			for (const placement of previousPlacements) {
-				if (placement.index < totalMessages) {
-					placements.push(placement)
-				}
-			}
-
-			return placements
-		}
-	}
-
-	/**
-	 * Find the optimal placement for a cache point within a specified range of messages
-	 * Simply finds the last user message in the range
-	 */
-	private findOptimalPlacementForRange(
-		startIndex: number,
-		endIndex: number,
-		minTokensPerPoint: number,
-	): CachePointPlacement | null {
-		if (startIndex >= endIndex) {
-			return null
-		}
-
-		// Find the last user message in the range
-		let lastUserMessageIndex = -1
-		for (let i = endIndex; i >= startIndex; i--) {
-			if (this.config.messages[i].role === "user") {
-				lastUserMessageIndex = i
-				break
-			}
-		}
-
-		if (lastUserMessageIndex >= 0) {
-			// Calculate the total tokens covered from the previous cache point (or start of conversation)
-			// to this cache point. This ensures tokensCovered represents the full span of tokens
-			// that will be cached by this cache point.
-			let totalTokensCovered = 0
-
-			// Find the previous cache point index
-			const previousPlacements = this.config.previousCachePointPlacements || []
-			let previousCachePointIndex = -1
-
-			for (const placement of previousPlacements) {
-				if (placement.index < startIndex && placement.index > previousCachePointIndex) {
-					previousCachePointIndex = placement.index
-				}
-			}
-
-			// Calculate tokens from previous cache point (or start) to this cache point
-			const tokenStartIndex = previousCachePointIndex + 1
-			totalTokensCovered = this.config.messages
-				.slice(tokenStartIndex, lastUserMessageIndex + 1)
-				.reduce((acc, curr) => acc + this.estimateTokenCount(curr), 0)
-
-			// Guard clause: ensure we have enough tokens to justify a cache point
-			if (totalTokensCovered < minTokensPerPoint) {
-				return null
-			}
-			return {
-				index: lastUserMessageIndex,
-				type: "message",
-				tokensCovered: totalTokensCovered,
-			}
-		}
-
-		return null
-	}
-
-	/**
-	 * Format result without cache points
-	 *
-	 * @returns Cache result without cache points
-	 */
-	private formatWithoutCachePoints(): CacheResult {
-		const systemBlocks: SystemContentBlock[] = this.config.systemPrompt
-			? [{ text: this.config.systemPrompt } as unknown as SystemContentBlock]
-			: []
-
-		return this.formatResult(systemBlocks, this.messagesToContentBlocks(this.config.messages))
-	}
-}
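The core trade-off in the deleted determineMessageCachePoints is worth restating compactly: once every cache point is spent, the strategy merges the two adjacent placements covering the fewest combined tokens, but only when the uncached tail carries at least 20% more tokens than that combined gap (requiredPercentageIncrease = 1.2). A small extraction of just that decision, using the numbers from the deleted "should not combine" test above:

// Decision core of the deleted combining logic, extracted for illustration.
// gaps[i] = tokens covered by placement i since the previous placement.
function shouldCombine(gaps: number[], newMessagesTokens: number): boolean {
	let smallestGap = Number.MAX_VALUE
	for (let i = 0; i < gaps.length - 1; i++) {
		smallestGap = Math.min(smallestGap, gaps[i] + gaps[i + 1])
	}
	// Require a 20% improvement before giving up an existing cache point
	return newMessagesTokens >= smallestGap * 1.2
}

// With gaps [400, 300, 300] the smallest adjacent pair covers 600 tokens,
// so a 50-token tail leaves all three placements intact:
// shouldCombine([400, 300, 300], 50) === false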
diff --git a/src/api/transform/cache-strategy/types.ts b/src/api/transform/cache-strategy/types.ts
deleted file mode 100644
index 2b5d5736c96..00000000000
--- a/src/api/transform/cache-strategy/types.ts
+++ /dev/null
@@ -1,68 +0,0 @@
-import { Anthropic } from "@anthropic-ai/sdk"
-import { SystemContentBlock, Message } from "@aws-sdk/client-bedrock-runtime"
-
-/**
- * Information about a model's capabilities and constraints
- */
-export interface ModelInfo {
-	/** Maximum number of tokens the model can generate */
-	maxTokens: number
-	/** Maximum context window size in tokens */
-	contextWindow: number
-	/** Whether the model supports prompt caching */
-	supportsPromptCache: boolean
-	/** Maximum number of cache points supported by the model */
-	maxCachePoints: number
-	/** Minimum number of tokens required for a cache point */
-	minTokensPerCachePoint: number
-	/** Fields that can be cached */
-	cachableFields: Array<"system" | "messages" | "tools">
-}
-
-/**
- * Cache point definition
- */
-export interface CachePoint {
-	/** Type of cache point */
-	type: "default"
-}
-
-/**
- * Result of cache strategy application
- */
-export interface CacheResult {
-	/** System content blocks */
-	system: SystemContentBlock[]
-	/** Message content blocks */
-	messages: Message[]
-	/** Cache point placements for messages (for maintaining consistency across consecutive messages) */
-	messageCachePointPlacements?: CachePointPlacement[]
-}
-
-/**
- * Represents the position and metadata for a cache point
- */
-export interface CachePointPlacement {
-	/** Where to insert the cache point */
-	index: number
-	/** Type of cache point */
-	type: "system" | "message"
-	/** Number of tokens this cache point covers */
-	tokensCovered: number
-}
-
-/**
- * Configuration for the caching strategy
- */
-export interface CacheStrategyConfig {
-	/** Model information */
-	modelInfo: ModelInfo
-	/** System prompt text */
-	systemPrompt?: string
-	/** Messages to process */
-	messages: Anthropic.Messages.MessageParam[]
-	/** Whether to use prompt caching */
-	usePromptCache: boolean
-	/** Previous cache point placements (for maintaining consistency across consecutive messages) */
-	previousCachePointPlacements?: CachePointPlacement[]
-}
diff --git a/src/package.json b/src/package.json
index d1ee3757159..75269a63021 100644
--- a/src/package.json
+++ b/src/package.json
@@ -450,6 +450,7 @@
 		"clean": "rimraf README.md CHANGELOG.md LICENSE dist logs mock .turbo"
 	},
 	"dependencies": {
+		"@ai-sdk/amazon-bedrock": "^4.0.50",
 		"@ai-sdk/cerebras": "^1.0.0",
 		"@ai-sdk/deepseek": "^2.0.14",
 		"@ai-sdk/fireworks": "^2.0.26",
@@ -458,8 +459,6 @@
 		"@ai-sdk/groq": "^3.0.19",
 		"@ai-sdk/mistral": "^3.0.0",
 		"@ai-sdk/xai": "^3.0.46",
-		"sambanova-ai-provider": "^1.2.2",
-		"@anthropic-ai/bedrock-sdk": "^0.10.2",
 		"@anthropic-ai/sdk": "^0.37.0",
 		"@anthropic-ai/vertex-sdk": "^0.7.0",
 		"@aws-sdk/client-bedrock-runtime": "^3.922.0",
@@ -518,6 +517,7 @@
 		"puppeteer-core": "^23.4.0",
 		"reconnecting-eventsource": "^1.6.4",
 		"safe-stable-stringify": "^2.5.0",
+		"sambanova-ai-provider": "^1.2.2",
 		"sanitize-filename": "^1.6.3",
 		"say": "^0.16.0",
 		"semver-compare": "^1.0.0",
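The package.json hunk completes the swap this diff has been making: the Anthropic-specific @anthropic-ai/bedrock-sdk and the hand-rolled cache-strategy module are replaced by @ai-sdk/amazon-bedrock, with sambanova-ai-provider simply re-sorted alphabetically. A hedged sketch of how the replacement provider is typically instantiated; the region, model id, and prompt below are illustrative, not this PR's actual wiring:

import { createAmazonBedrock } from "@ai-sdk/amazon-bedrock"
import { generateText } from "ai"

// Credentials resolve from the usual AWS sources; region shown for illustration.
const bedrock = createAmazonBedrock({ region: "us-east-1" })

const { text } = await generateText({
	// Model id is an example; the real ids come from the provider settings.
	model: bedrock("anthropic.claude-3-5-sonnet-20241022-v2:0"),
	prompt: "Hello from Bedrock via the AI SDK.",
})
console.log(text)

With this provider, cache points are requested per message through the SDK's provider options rather than computed by the deleted CacheStrategy classes, which is why the strategy code and its tests could be removed wholesale.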