From 498c982156c194fdeaa02929074bb4fec315b2fd Mon Sep 17 00:00:00 2001 From: Brent Date: Tue, 12 May 2026 13:18:33 -0400 Subject: [PATCH] SMOODEV-948: Async auth-token provider across TS, Python, Rust, Go Only .NET shipped a first-class `AuthTokenProvider` delegate. The other four ports forced callers to wire auth either statically (`with_auth(string)`) or manually via a pre-request hook. Add a real async provider hook to each. TypeScript ---------- - `FetchBuilder.withAuthTokenProvider(provider, scheme = 'Bearer')`. - `PreRequestHook` return type widened to allow `Promise<...>` so the pipeline can `await` async hooks. The new builder method composes with any existing pre-request hook (hook runs first, Authorization injected after). - Export new `AuthTokenProvider` type. Python ------ - `FetchBuilder.with_auth_provider(provider, scheme='Bearer')`. - New `auth_token_provider` + `auth_scheme` fields on `FetchOptions`. - Provider may return `str` or an awaitable; the client awaits inline. Rust ---- - New `AuthTokenProvider` type alias (`Arc Pin>>>`). - `FetchBuilder::with_auth_provider(provider, scheme)`. - `FetchClient` calls the provider before each `fetch()` / `fetch_with_options()` and injects the Authorization header into the merged init. Go -- - `AuthTokenProvider` type: `func(ctx context.Context) (string, error)`. - `ClientBuilder.WithAuthTokenProvider(provider, scheme)`. - Provider is invoked inside `executeHTTPRequest` after pre-request hook so cancellation/timeouts flow through; provider errors propagate via the PostResponseError hook chain. Tests ----- - TS: 4 new cases in `fetch.spec.ts` (sync, async, custom scheme, hook composition). - Python: new `test_auth_provider.py` with 5 cases. - Rust: new `auth_provider_tests.rs` with 3 cases. - Go: new `auth_provider_test.go` with 4 cases. --- .changeset/SMOODEV-948-async-auth-provider.md | 5 + go/fetch/auth_provider_test.go | 111 +++++++++++++++ go/fetch/builder.go | 23 +++ go/fetch/client.go | 22 +++ go/fetch/options.go | 7 + python/src/smooai_fetch/__init__.py | 2 + python/src/smooai_fetch/_builder.py | 29 ++++ python/src/smooai_fetch/_client.py | 13 ++ python/src/smooai_fetch/_types.py | 17 +++ python/tests/test_auth_provider.py | 94 +++++++++++++ rust/fetch/src/builder.rs | 66 ++++++++- rust/fetch/src/lib.rs | 4 +- rust/fetch/src/types.rs | 12 ++ rust/fetch/tests/auth_provider_tests.rs | 132 ++++++++++++++++++ src/fetch.spec.ts | 72 ++++++++++ src/fetch.ts | 69 ++++++++- 16 files changed, 668 insertions(+), 10 deletions(-) create mode 100644 .changeset/SMOODEV-948-async-auth-provider.md create mode 100644 go/fetch/auth_provider_test.go create mode 100644 python/tests/test_auth_provider.py create mode 100644 rust/fetch/tests/auth_provider_tests.rs diff --git a/.changeset/SMOODEV-948-async-auth-provider.md b/.changeset/SMOODEV-948-async-auth-provider.md new file mode 100644 index 0000000..0fc3559 --- /dev/null +++ b/.changeset/SMOODEV-948-async-auth-provider.md @@ -0,0 +1,5 @@ +--- +'@smooai/fetch': patch +--- + +SMOODEV-948: Async auth-token provider across TS, Python, Rust, Go. Adds a first-class hook that's invoked before every request to mint / refresh an auth token (sync or async), with the resulting `Authorization` header injected using a configurable scheme (default `Bearer`). Mirrors the existing .NET `AuthTokenProvider` delegate. diff --git a/go/fetch/auth_provider_test.go b/go/fetch/auth_provider_test.go new file mode 100644 index 0000000..c6531c1 --- /dev/null +++ b/go/fetch/auth_provider_test.go @@ -0,0 +1,111 @@ +package fetch + +import ( + "context" + "errors" + "fmt" + "net/http" + "net/http/httptest" + "sync/atomic" + "testing" +) + +func TestClientBuilder_WithAuthTokenProvider_DefaultScheme(t *testing.T) { + var captured string + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + captured = r.Header.Get("Authorization") + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"ok":true}`)) + })) + defer server.Close() + + client := NewClientBuilder(). + WithNoRetry(). + WithAuthTokenProvider(func(_ context.Context) (string, error) { + return "fresh-token", nil + }, ""). + Build() + + _, err := SimpleGet(context.Background(), client, server.URL, nil) + if err != nil { + t.Fatalf("fetch failed: %v", err) + } + + if captured != "Bearer fresh-token" { + t.Errorf("expected 'Bearer fresh-token', got %q", captured) + } +} + +func TestClientBuilder_WithAuthTokenProvider_CustomScheme(t *testing.T) { + var captured string + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + captured = r.Header.Get("Authorization") + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"ok":true}`)) + })) + defer server.Close() + + client := NewClientBuilder(). + WithNoRetry(). + WithAuthTokenProvider(func(_ context.Context) (string, error) { + return "abc", nil + }, "Token"). + Build() + + _, err := SimpleGet(context.Background(), client, server.URL, nil) + if err != nil { + t.Fatalf("fetch failed: %v", err) + } + + if captured != "Token abc" { + t.Errorf("expected 'Token abc', got %q", captured) + } +} + +func TestClientBuilder_WithAuthTokenProvider_InvokedPerRequest(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"ok":true}`)) + })) + defer server.Close() + + var calls int32 + client := NewClientBuilder(). + WithNoRetry(). + WithAuthTokenProvider(func(_ context.Context) (string, error) { + n := atomic.AddInt32(&calls, 1) + return fmt.Sprintf("tok-%d", n), nil + }, "Bearer"). + Build() + + for i := 0; i < 3; i++ { + if _, err := SimpleGet(context.Background(), client, server.URL, nil); err != nil { + t.Fatalf("fetch %d failed: %v", i, err) + } + } + + if got := atomic.LoadInt32(&calls); got != 3 { + t.Errorf("expected 3 provider invocations, got %d", got) + } +} + +func TestClientBuilder_WithAuthTokenProvider_PropagatesError(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + t.Error("HTTP request fired despite provider error") + w.WriteHeader(200) + })) + defer server.Close() + + wantErr := errors.New("token-mint-failure") + client := NewClientBuilder(). + WithNoRetry(). + WithAuthTokenProvider(func(_ context.Context) (string, error) { + return "", wantErr + }, "Bearer"). + Build() + + _, err := SimpleGet(context.Background(), client, server.URL, nil) + if err == nil || !errors.Is(err, wantErr) { + t.Fatalf("expected wrapped %v, got %v", wantErr, err) + } +} diff --git a/go/fetch/builder.go b/go/fetch/builder.go index 9c64ef0..b65aeed 100644 --- a/go/fetch/builder.go +++ b/go/fetch/builder.go @@ -16,6 +16,8 @@ type ClientBuilder struct { circuitBreakerOpts *CircuitBreakerOptions circuitBreakerName string hooks *LifecycleHooks + authProvider AuthTokenProvider + authScheme string } // NewClientBuilder creates a new ClientBuilder with default retry and timeout options. @@ -86,6 +88,22 @@ func (b *ClientBuilder) WithHooks(hooks *LifecycleHooks) *ClientBuilder { return b } +// WithAuthTokenProvider registers a sync-or-async auth token provider that is +// invoked before every request and used to populate the `Authorization` +// header. The provider receives the request context, so it can short-circuit +// on cancellation/timeouts. Mirrors the .NET `AuthTokenProvider` delegate and +// the TypeScript `FetchBuilder.withAuthTokenProvider(...)` method. +// +// Pass an empty string for scheme to default to "Bearer". +func (b *ClientBuilder) WithAuthTokenProvider(provider AuthTokenProvider, scheme string) *ClientBuilder { + b.authProvider = provider + if scheme == "" { + scheme = "Bearer" + } + b.authScheme = scheme + return b +} + // WithNoRetry disables retries. func (b *ClientBuilder) WithNoRetry() *ClientBuilder { b.retryOpts = nil @@ -126,6 +144,11 @@ func (b *ClientBuilder) Build() *Client { timeout: b.timeoutOpts, rateLimitRetry: b.rateLimitRetryOpts, hooks: b.hooks, + authProvider: b.authProvider, + authScheme: b.authScheme, + } + if c.authScheme == "" { + c.authScheme = "Bearer" } if c.httpClient == nil { diff --git a/go/fetch/client.go b/go/fetch/client.go index 0e3258e..ad5084e 100644 --- a/go/fetch/client.go +++ b/go/fetch/client.go @@ -22,6 +22,8 @@ type Client struct { rateLimitRetry *RateLimitRetryOptions circuitBreaker *CircuitBreaker hooks *LifecycleHooks + authProvider AuthTokenProvider + authScheme string } // NewClient creates a new Client with default settings (retry + timeout). @@ -186,6 +188,26 @@ func executeHTTPRequest[T any]( } } + // Apply auth-token provider (after the pre-request hook so the hook can + // adjust the URL/init first; the resulting Authorization header overrides + // any value the hook may have set). + if client.authProvider != nil { + token, err := client.authProvider(ctx) + if err != nil { + if hooks != nil && hooks.PostResponseError != nil { + if replacementErr := hooks.PostResponseError(requestURL, req, err, nil); replacementErr != nil { + return nil, replacementErr + } + } + return nil, fmt.Errorf("auth token provider failed: %w", err) + } + scheme := client.authScheme + if scheme == "" { + scheme = "Bearer" + } + req.Header.Set("Authorization", fmt.Sprintf("%s %s", scheme, token)) + } + // Execute the HTTP request httpClient := client.httpClient if httpClient == nil { diff --git a/go/fetch/options.go b/go/fetch/options.go index 1d38c87..44ef01d 100644 --- a/go/fetch/options.go +++ b/go/fetch/options.go @@ -1,6 +1,7 @@ package fetch import ( + "context" "net/http" "time" ) @@ -114,6 +115,12 @@ type CircuitBreakerCounts struct { ConsecutiveFailures uint32 } +// AuthTokenProvider is invoked before every request to mint an auth token that +// is injected into the `Authorization` header. The provider receives the +// request context so callers can hook into cancellation/timeouts when fetching +// or refreshing tokens. Mirrors the .NET `AuthTokenProvider` delegate. +type AuthTokenProvider func(ctx context.Context) (string, error) + // LifecycleHooks provides hooks into the request/response lifecycle. type LifecycleHooks struct { // PreRequest is called before sending the request. diff --git a/python/src/smooai_fetch/__init__.py b/python/src/smooai_fetch/__init__.py index ed2be8f..cf2f86d 100644 --- a/python/src/smooai_fetch/__init__.py +++ b/python/src/smooai_fetch/__init__.py @@ -43,6 +43,7 @@ # Types from smooai_fetch._types import ( + AuthTokenProvider, CircuitBreakerOptions, FetchContainerOptions, FetchOptions, @@ -66,6 +67,7 @@ "FetchBuilder", "FetchResponse", # Types + "AuthTokenProvider", "CircuitBreakerOptions", "FetchContainerOptions", "FetchOptions", diff --git a/python/src/smooai_fetch/_builder.py b/python/src/smooai_fetch/_builder.py index 31964d7..688393c 100644 --- a/python/src/smooai_fetch/_builder.py +++ b/python/src/smooai_fetch/_builder.py @@ -10,6 +10,7 @@ from smooai_fetch._defaults import DEFAULT_RETRY_OPTIONS from smooai_fetch._response import FetchResponse from smooai_fetch._types import ( + AuthTokenProvider, CircuitBreakerOptions, FetchContainerOptions, FetchOptions, @@ -54,6 +55,8 @@ def __init__(self) -> None: self._schema: type[BaseModel] | None = None self._headers: dict[str, str] = {} self._hooks: LifecycleHooks = LifecycleHooks() + self._auth_token_provider: AuthTokenProvider | None = None + self._auth_scheme: str = "Bearer" def with_retry(self, options: RetryOptions | None = None) -> FetchBuilder: """Configure retry behavior. @@ -159,6 +162,30 @@ def with_auth(self, token: str, scheme: str = "Bearer") -> FetchBuilder: self._headers["Authorization"] = f"{scheme} {token}" return self + def with_auth_provider( + self, + provider: AuthTokenProvider, + scheme: str = "Bearer", + ) -> FetchBuilder: + """Register a sync or async auth token provider. + + The provider is invoked before every request and its result is injected + as the `Authorization` header. If the provider returns an awaitable, + it is awaited inline. Mirrors the .NET `AuthTokenProvider` delegate. + + Args: + provider: Callable returning the bare token (no scheme prefix). + Sync `() -> str` or async `() -> Awaitable[str]` are both + accepted. + scheme: Auth scheme prefix. Defaults to "Bearer". + + Returns: + The builder instance for method chaining. + """ + self._auth_token_provider = provider + self._auth_scheme = scheme + return self + def with_pre_request_hook(self, hook: PreRequestHook) -> FetchBuilder: """Set a pre-request hook. @@ -216,6 +243,8 @@ def build(self) -> FetchOptions: schema=self._schema, hooks=self._hooks, container_options=container_options, + auth_token_provider=self._auth_token_provider, + auth_scheme=self._auth_scheme, ) async def fetch( diff --git a/python/src/smooai_fetch/_client.py b/python/src/smooai_fetch/_client.py index 9e62475..a887ff4 100644 --- a/python/src/smooai_fetch/_client.py +++ b/python/src/smooai_fetch/_client.py @@ -2,6 +2,7 @@ from __future__ import annotations +import inspect import json from typing import Any, TypeVar @@ -234,6 +235,18 @@ async def fetch(url: str, options: FetchOptions | None = None) -> FetchResponse[ current_url, request_kwargs = hook_result request_kwargs["url"] = current_url + # Apply auth-token provider (after pre-request hook so the hook can still + # adjust headers / URL first). Supports both sync providers and async + # providers; the awaitable is awaited inline so we always get the resolved + # token before the request fires. + if opts.auth_token_provider is not None: + raw_token = opts.auth_token_provider() + if inspect.isawaitable(raw_token): + raw_token = await raw_token + token: str = str(raw_token) + headers: dict[str, str] = request_kwargs.setdefault("headers", {}) + headers["Authorization"] = f"{opts.auth_scheme} {token}" + # Build rate limiter rate_limiter: SlidingWindowRateLimiter | None = None if container_options and container_options.rate_limit: diff --git a/python/src/smooai_fetch/_types.py b/python/src/smooai_fetch/_types.py index ec6efa4..0312bad 100644 --- a/python/src/smooai_fetch/_types.py +++ b/python/src/smooai_fetch/_types.py @@ -88,6 +88,13 @@ class RetryContext: """Callback invoked before each retry attempt to override default behavior.""" +AuthTokenProvider = Callable[[], "Any"] +"""Provider that returns an auth token. May return a `str` directly or an awaitable +that resolves to a `str`. Invoked before every request to populate the +`Authorization` header. Mirrors the .NET delegate of the same name. +""" + + @dataclass class RetryOptions: """Configuration options for retry behavior.""" @@ -242,3 +249,13 @@ class FetchOptions: container_options: FetchContainerOptions | None = None """Container-level options (rate limit, circuit breaker).""" + + auth_token_provider: AuthTokenProvider | None = None + """Optional sync or async provider invoked before each request to mint an auth token. + + The returned token is injected into the `Authorization` header using + `auth_scheme` (default `"Bearer"`). Awaitable return values are awaited. + """ + + auth_scheme: str = "Bearer" + """Auth scheme prefix used with `auth_token_provider`. Defaults to "Bearer".""" diff --git a/python/tests/test_auth_provider.py b/python/tests/test_auth_provider.py new file mode 100644 index 0000000..0f1d203 --- /dev/null +++ b/python/tests/test_auth_provider.py @@ -0,0 +1,94 @@ +"""Tests for `FetchBuilder.with_auth_provider` / `FetchOptions.auth_token_provider`.""" + +import httpx +import respx + +from smooai_fetch import FetchBuilder, FetchOptions, fetch + +URL = "https://api.example.com/data" + + +@respx.mock +async def test_sync_auth_provider_injects_bearer(): + """A sync provider injects an Authorization header with the default Bearer scheme.""" + route = respx.get(URL).mock( + return_value=httpx.Response(200, json={"ok": True}, headers={"Content-Type": "application/json"}), + ) + + fetcher = FetchBuilder().with_auth_provider(lambda: "tok-sync") + r = await fetcher.fetch(URL) + assert r.ok + + headers = route.calls.last.request.headers + assert headers["authorization"] == "Bearer tok-sync" + + +@respx.mock +async def test_async_auth_provider_is_awaited(): + """An async provider is awaited inline before the request fires.""" + route = respx.get(URL).mock( + return_value=httpx.Response(200, json={"ok": True}, headers={"Content-Type": "application/json"}), + ) + + call_count = 0 + + async def provider() -> str: + nonlocal call_count + call_count += 1 + return f"tok-async-{call_count}" + + fetcher = FetchBuilder().with_auth_provider(provider) + + await fetcher.fetch(URL) + await fetcher.fetch(URL) + + assert call_count == 2 + # Last call should carry the second token. + headers = route.calls.last.request.headers + assert headers["authorization"] == "Bearer tok-async-2" + + +@respx.mock +async def test_custom_auth_scheme(): + """The configured scheme prefixes the token.""" + route = respx.get(URL).mock( + return_value=httpx.Response(200, json={"ok": True}, headers={"Content-Type": "application/json"}), + ) + + fetcher = FetchBuilder().with_auth_provider(lambda: "abc", scheme="Token") + await fetcher.fetch(URL) + + assert route.calls.last.request.headers["authorization"] == "Token abc" + + +@respx.mock +async def test_auth_provider_via_fetch_options_directly(): + """`FetchOptions.auth_token_provider` works without the builder.""" + route = respx.get(URL).mock( + return_value=httpx.Response(200, json={"ok": True}, headers={"Content-Type": "application/json"}), + ) + + async def provider() -> str: + return "direct-tok" + + opts = FetchOptions(auth_token_provider=provider, auth_scheme="Bearer") + r = await fetch(URL, opts) + assert r.ok + assert route.calls.last.request.headers["authorization"] == "Bearer direct-tok" + + +@respx.mock +async def test_auth_provider_overrides_static_auth_header(): + """The provider runs after the pre-request hook and overrides any prior Authorization header.""" + route = respx.get(URL).mock( + return_value=httpx.Response(200, json={"ok": True}, headers={"Content-Type": "application/json"}), + ) + + fetcher = ( + FetchBuilder() + .with_auth("stale-static-token") # Sets Authorization header on the builder + .with_auth_provider(lambda: "fresh-token-from-provider") + ) + + await fetcher.fetch(URL) + assert route.calls.last.request.headers["authorization"] == "Bearer fresh-token-from-provider" diff --git a/rust/fetch/src/builder.rs b/rust/fetch/src/builder.rs index fc9ccd6..7f64a80 100644 --- a/rust/fetch/src/builder.rs +++ b/rust/fetch/src/builder.rs @@ -13,8 +13,9 @@ use crate::hooks::{ use crate::rate_limit::SlidingWindowRateLimiter; use crate::response::FetchResponse; use crate::types::{ - CircuitBreakerOptions, FetchContainerOptions, FetchOptions, RateLimitOptions, - RateLimitRetryOptions, RequestInit, RetryCallback, RetryOptions, TimeoutOptions, + AuthTokenProvider, CircuitBreakerOptions, FetchContainerOptions, FetchOptions, + RateLimitOptions, RateLimitRetryOptions, RequestInit, RetryCallback, RetryOptions, + TimeoutOptions, }; /// Builder for creating configured fetch functions with retry, timeout, rate limiting, @@ -52,6 +53,8 @@ pub struct FetchBuilder { container_options: FetchContainerOptions, default_init: Option, hooks: LifecycleHooks, + auth_token_provider: Option, + auth_scheme: String, } impl FetchBuilder { @@ -65,6 +68,8 @@ impl FetchBuilder { container_options: FetchContainerOptions::default(), default_init: None, hooks: LifecycleHooks::default(), + auth_token_provider: None, + auth_scheme: "Bearer".to_string(), } } @@ -191,6 +196,42 @@ impl FetchBuilder { self } + /// Register an async auth-token provider that is invoked before every + /// request and used to populate the `Authorization` header. + /// + /// Mirrors the .NET `AuthTokenProvider` delegate and the TypeScript + /// `FetchBuilder.withAuthTokenProvider(...)` method. + /// + /// # Example + /// + /// ```rust,no_run + /// use std::sync::Arc; + /// use smooai_fetch::builder::FetchBuilder; + /// use smooai_fetch::types::AuthTokenProvider; + /// use serde::Deserialize; + /// + /// #[derive(Deserialize, Clone, Debug)] + /// struct Reply { ok: bool } + /// + /// # async fn example() { + /// let provider: AuthTokenProvider = Arc::new(|| { + /// Box::pin(async move { + /// // Imagine fetching/refreshing the token here. + /// "fresh-token".to_string() + /// }) + /// }); + /// + /// let _client = FetchBuilder::::new() + /// .with_auth_provider(provider, "Bearer".to_string()) + /// .build(); + /// # } + /// ``` + pub fn with_auth_provider(mut self, provider: AuthTokenProvider, scheme: String) -> Self { + self.auth_token_provider = Some(provider); + self.auth_scheme = scheme; + self + } + /// Build the configured fetch client. pub fn build(self) -> FetchClient { let rate_limiter = self @@ -213,6 +254,8 @@ impl FetchBuilder { rate_limit_retry: self.container_options.rate_limit_retry, circuit_breaker, hooks: Arc::new(self.hooks), + auth_token_provider: self.auth_token_provider, + auth_scheme: self.auth_scheme, } } } @@ -232,6 +275,8 @@ pub struct FetchClient { rate_limit_retry: Option, circuit_breaker: Option, hooks: Arc>, + auth_token_provider: Option, + auth_scheme: String, } impl FetchClient { @@ -242,7 +287,7 @@ impl FetchClient { init: RequestInit, ) -> Result, FetchError> { // Merge default init with per-request init - let merged_init = self.merge_init(init); + let merged_init = self.apply_auth(self.merge_init(init)).await; crate::client::fetch::( url, @@ -263,7 +308,7 @@ impl FetchClient { init: RequestInit, options: FetchOptions, ) -> Result, FetchError> { - let merged_init = self.merge_init(init); + let merged_init = self.apply_auth(self.merge_init(init)).await; crate::client::fetch::( url, @@ -277,6 +322,19 @@ impl FetchClient { .await } + /// Invoke the configured auth-token provider (if any) and inject the + /// resulting `Authorization` header into the request init. + async fn apply_auth(&self, mut init: RequestInit) -> RequestInit { + if let Some(ref provider) = self.auth_token_provider { + let token = provider().await; + init.headers.insert( + "Authorization".to_string(), + format!("{} {}", self.auth_scheme, token), + ); + } + init + } + /// Get a reference to the circuit breaker, if configured. pub fn circuit_breaker(&self) -> Option<&CircuitBreaker> { self.circuit_breaker.as_ref() diff --git a/rust/fetch/src/lib.rs b/rust/fetch/src/lib.rs index 4922a0b..82d03f3 100644 --- a/rust/fetch/src/lib.rs +++ b/rust/fetch/src/lib.rs @@ -63,8 +63,8 @@ pub use error::FetchError; pub use rate_limit::SlidingWindowRateLimiter; pub use response::FetchResponse; pub use types::{ - FetchContainerOptions, FetchOptions, Method, RateLimitRetryOptions, RequestInit, RetryCallback, - RetryContext, RetryDecision, RetryOptions, + AuthTokenFuture, AuthTokenProvider, FetchContainerOptions, FetchOptions, Method, + RateLimitRetryOptions, RequestInit, RetryCallback, RetryContext, RetryDecision, RetryOptions, }; /// Convenience function: perform a single fetch with default options. diff --git a/rust/fetch/src/types.rs b/rust/fetch/src/types.rs index bd7e20f..fc52400 100644 --- a/rust/fetch/src/types.rs +++ b/rust/fetch/src/types.rs @@ -54,6 +54,18 @@ pub enum RetryDecision { /// task boundaries and stored inside [`RetryOptions`] (which is `Clone`). pub type RetryCallback = Arc RetryDecision + Send + Sync>; +/// Future returned by an [`AuthTokenProvider`]. +pub type AuthTokenFuture = + std::pin::Pin + Send + 'static>>; + +/// Async provider that mints an auth token. Invoked before every request to +/// populate the `Authorization` header. Mirrors the .NET `AuthTokenProvider` +/// delegate. +/// +/// The provider is wrapped in an [`Arc`] so the configured client can be +/// cheaply cloned across task boundaries. +pub type AuthTokenProvider = Arc AuthTokenFuture + Send + Sync>; + /// Configuration options for retry behavior. #[derive(Clone)] pub struct RetryOptions { diff --git a/rust/fetch/tests/auth_provider_tests.rs b/rust/fetch/tests/auth_provider_tests.rs new file mode 100644 index 0000000..975a46a --- /dev/null +++ b/rust/fetch/tests/auth_provider_tests.rs @@ -0,0 +1,132 @@ +//! Tests for `FetchBuilder::with_auth_provider`. + +use std::sync::atomic::{AtomicUsize, Ordering}; +use std::sync::Arc; + +use serde::Deserialize; +use wiremock::matchers::{header, method, path}; +use wiremock::{Mock, MockServer, ResponseTemplate}; + +use smooai_fetch::builder::FetchBuilder; +use smooai_fetch::types::{AuthTokenProvider, Method, RequestInit}; + +#[derive(Deserialize, Clone, Debug)] +struct TestReply { + ok: bool, +} + +#[tokio::test] +async fn test_async_auth_provider_injects_bearer_header() { + let mock_server = MockServer::start().await; + Mock::given(method("GET")) + .and(path("/data")) + .and(header("Authorization", "Bearer fresh-token")) + .respond_with( + ResponseTemplate::new(200) + .set_body_json(serde_json::json!({"ok": true})) + .insert_header("content-type", "application/json"), + ) + .mount(&mock_server) + .await; + + let provider: AuthTokenProvider = + Arc::new(|| Box::pin(async move { "fresh-token".to_string() })); + + let client = FetchBuilder::::new() + .without_retry() + .with_auth_provider(provider, "Bearer".to_string()) + .build(); + + let url = format!("{}/data", mock_server.uri()); + let response = client + .fetch( + &url, + RequestInit { + method: Method::GET, + ..Default::default() + }, + ) + .await + .unwrap(); + assert!(response.ok); +} + +#[tokio::test] +async fn test_provider_is_invoked_before_every_request() { + let mock_server = MockServer::start().await; + Mock::given(method("GET")) + .and(path("/data")) + .respond_with( + ResponseTemplate::new(200) + .set_body_json(serde_json::json!({"ok": true})) + .insert_header("content-type", "application/json"), + ) + .mount(&mock_server) + .await; + + let counter = Arc::new(AtomicUsize::new(0)); + let counter_for_provider = counter.clone(); + let provider: AuthTokenProvider = Arc::new(move || { + let c = counter_for_provider.clone(); + Box::pin(async move { + let n = c.fetch_add(1, Ordering::SeqCst); + format!("tok-{}", n) + }) + }); + + let client = FetchBuilder::::new() + .without_retry() + .with_auth_provider(provider, "Bearer".to_string()) + .build(); + + let url = format!("{}/data", mock_server.uri()); + for _ in 0..3 { + let _ = client + .fetch( + &url, + RequestInit { + method: Method::GET, + ..Default::default() + }, + ) + .await + .unwrap(); + } + + assert_eq!(counter.load(Ordering::SeqCst), 3); +} + +#[tokio::test] +async fn test_custom_auth_scheme() { + let mock_server = MockServer::start().await; + Mock::given(method("GET")) + .and(path("/data")) + .and(header("Authorization", "Token abc")) + .respond_with( + ResponseTemplate::new(200) + .set_body_json(serde_json::json!({"ok": true})) + .insert_header("content-type", "application/json"), + ) + .mount(&mock_server) + .await; + + let provider: AuthTokenProvider = Arc::new(|| Box::pin(async move { "abc".to_string() })); + + let client = FetchBuilder::::new() + .without_retry() + .with_auth_provider(provider, "Token".to_string()) + .build(); + + let url = format!("{}/data", mock_server.uri()); + let response = client + .fetch( + &url, + RequestInit { + method: Method::GET, + ..Default::default() + }, + ) + .await + .unwrap(); + assert!(response.ok); +} diff --git a/src/fetch.spec.ts b/src/fetch.spec.ts index 4cf8cfc..a4ae3f3 100644 --- a/src/fetch.spec.ts +++ b/src/fetch.spec.ts @@ -1218,4 +1218,76 @@ describe('Test fetch', () => { expect(mockFetch.mock.calls[0][0].toString()).toContain('timestamp=1234567890'); }); }); + + describe('withAuthTokenProvider', () => { + test('injects Authorization header using async provider with default Bearer scheme', async () => { + const mockFetch = global.fetch as MockedFunction<(url: RequestInfo, init?: RequestInit) => Promise>; + mockFetch.mockResolvedValue(fakeResponse(true, 200, { ok: true })); + + let mintCount = 0; + const provider = async () => { + mintCount++; + return 'tok-abc'; + }; + + const fetch = new FetchBuilder().withAuthTokenProvider(provider).build(); + + const r = await fetch(URL_TO_USE); + expect(r.ok).toBe(true); + + const headers = new Headers(mockFetch.mock.calls[0][1]?.headers); + expect(headers.get('Authorization')).toBe('Bearer tok-abc'); + expect(mintCount).toBe(1); + }); + + test('honors a custom scheme', async () => { + const mockFetch = global.fetch as MockedFunction<(url: RequestInfo, init?: RequestInit) => Promise>; + mockFetch.mockResolvedValue(fakeResponse(true, 200, { ok: true })); + + const fetch = new FetchBuilder().withAuthTokenProvider(() => 'static-tok', 'Token').build(); + + await fetch(URL_TO_USE); + const headers = new Headers(mockFetch.mock.calls[0][1]?.headers); + expect(headers.get('Authorization')).toBe('Token static-tok'); + }); + + test('invokes the provider before every request (fresh token per call)', async () => { + const mockFetch = global.fetch as MockedFunction<(url: RequestInfo, init?: RequestInit) => Promise>; + mockFetch.mockResolvedValue(fakeResponse(true, 200, { ok: true })); + + const tokens = ['t1', 't2', 't3']; + let idx = 0; + const provider = () => tokens[idx++]; + + const fetch = new FetchBuilder().withAuthTokenProvider(provider).build(); + + await fetch(URL_TO_USE); + await fetch(URL_TO_USE); + await fetch(URL_TO_USE); + + expect(new Headers(mockFetch.mock.calls[0][1]?.headers).get('Authorization')).toBe('Bearer t1'); + expect(new Headers(mockFetch.mock.calls[1][1]?.headers).get('Authorization')).toBe('Bearer t2'); + expect(new Headers(mockFetch.mock.calls[2][1]?.headers).get('Authorization')).toBe('Bearer t3'); + }); + + test('composes with an existing preRequest hook', async () => { + const mockFetch = global.fetch as MockedFunction<(url: RequestInfo, init?: RequestInit) => Promise>; + mockFetch.mockResolvedValue(fakeResponse(true, 200, { ok: true })); + + const fetch = new FetchBuilder() + .withHooks({ + preRequest: (url, init) => { + init.headers = { ...(init.headers as Record), 'X-Hook': 'fired' }; + return [url, init]; + }, + }) + .withAuthTokenProvider(async () => 'composed-tok') + .build(); + + await fetch(URL_TO_USE); + const headers = new Headers(mockFetch.mock.calls[0][1]?.headers); + expect(headers.get('X-Hook')).toBe('fired'); + expect(headers.get('Authorization')).toBe('Bearer composed-tok'); + }); + }); }); diff --git a/src/fetch.ts b/src/fetch.ts index e0ff8a3..a87c0b9 100644 --- a/src/fetch.ts +++ b/src/fetch.ts @@ -244,9 +244,18 @@ export const DEFAULT_RATE_LIMIT_RETRY_OPTIONS: RetryOptions = { }; /** - * Hook that runs before the request is made, allowing modification of the request + * Hook that runs before the request is made, allowing modification of the request. + * + * May return synchronously or async — the fetch pipeline awaits the result, so + * callers can perform asynchronous work (e.g., minting a fresh auth token). + */ +export type PreRequestHook = (url: string, init: RequestInit) => [string, RequestInit] | void | Promise<[string, RequestInit] | void>; + +/** + * Provider that returns an auth token to inject as the `Authorization` header. + * May be sync or async. Invoked before every request. */ -export type PreRequestHook = (url: string, init: RequestInit) => [string, RequestInit] | void; +export type AuthTokenProvider = () => string | Promise; /** * Hook that runs after a successful response, allowing modification of the response @@ -557,10 +566,10 @@ async function doFetch( }); const logger = options.logger || contextLoggerToUse; - // Apply pre-request hook if present + // Apply pre-request hook if present (supports async hooks) let modifiedInit = init; if (options.hooks?.preRequest) { - const hookResult = options.hooks.preRequest(url.toString(), init); + const hookResult = await options.hooks.preRequest(url.toString(), init); if (hookResult) { modifiedInit = hookResult[1]; url = hookResult[0]; @@ -917,6 +926,58 @@ export class FetchBuilder { return this; } + /** + * Registers an async (or sync) auth token provider that is invoked before + * every request and injects an `Authorization` header. + * + * Composes cleanly with `withHooks(...)` — if a `preRequest` hook is + * already configured, the existing hook runs first and the + * `Authorization` header is appended afterwards. The provider is awaited, + * so callers can fetch / refresh / mint tokens lazily. + * + * @example + * ```typescript + * const fetch = new FetchBuilder() + * .withAuthTokenProvider(async () => await tokenStore.getFreshToken(), 'Bearer') + * .build(); + * ``` + * + * @param provider - Sync or async function returning the bare token (no scheme prefix). + * @param scheme - The auth scheme to prefix the token with. Defaults to "Bearer". + * @returns The builder instance for method chaining. + */ + withAuthTokenProvider(provider: AuthTokenProvider, scheme: string = 'Bearer'): FetchBuilder { + const existingHooks = this._requestOptions?.hooks; + const existingPreRequest = existingHooks?.preRequest; + + const composedPreRequest: PreRequestHook = async (url, init) => { + // Run any existing pre-request hook first so it can adjust the URL/init. + let nextUrl = url; + let nextInit = init; + if (existingPreRequest) { + const prior = await existingPreRequest(url, init); + if (prior) { + nextUrl = prior[0]; + nextInit = prior[1]; + } + } + + const token = await provider(); + const headers = new Headers(nextInit.headers ?? {}); + headers.set('Authorization', `${scheme} ${token}`); + return [nextUrl, { ...nextInit, headers }]; + }; + + this._requestOptions = { + ...this._requestOptions, + hooks: { + ...existingHooks, + preRequest: composedPreRequest, + }, + }; + return this; + } + /** * Builds and returns a configured fetch function. * Applies default options for any unset configurations.