Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions .changeset/SMOODEV-948-async-auth-provider.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
'@smooai/fetch': patch
---

SMOODEV-948: Async auth-token provider across TS, Python, Rust, Go. Adds a first-class hook that's invoked before every request to mint / refresh an auth token (sync or async), with the resulting `Authorization` header injected using a configurable scheme (default `Bearer`). Mirrors the existing .NET `AuthTokenProvider` delegate.
111 changes: 111 additions & 0 deletions go/fetch/auth_provider_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
package fetch

import (
"context"
"errors"
"fmt"
"net/http"
"net/http/httptest"
"sync/atomic"
"testing"
)

func TestClientBuilder_WithAuthTokenProvider_DefaultScheme(t *testing.T) {
var captured string
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
captured = r.Header.Get("Authorization")
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write([]byte(`{"ok":true}`))
}))
defer server.Close()

client := NewClientBuilder().
WithNoRetry().
WithAuthTokenProvider(func(_ context.Context) (string, error) {
return "fresh-token", nil
}, "").
Build()

_, err := SimpleGet(context.Background(), client, server.URL, nil)
if err != nil {
t.Fatalf("fetch failed: %v", err)
}

if captured != "Bearer fresh-token" {
t.Errorf("expected 'Bearer fresh-token', got %q", captured)
}
}

func TestClientBuilder_WithAuthTokenProvider_CustomScheme(t *testing.T) {
var captured string
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
captured = r.Header.Get("Authorization")
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write([]byte(`{"ok":true}`))
}))
defer server.Close()

client := NewClientBuilder().
WithNoRetry().
WithAuthTokenProvider(func(_ context.Context) (string, error) {
return "abc", nil
}, "Token").
Build()

_, err := SimpleGet(context.Background(), client, server.URL, nil)
if err != nil {
t.Fatalf("fetch failed: %v", err)
}

if captured != "Token abc" {
t.Errorf("expected 'Token abc', got %q", captured)
}
}

func TestClientBuilder_WithAuthTokenProvider_InvokedPerRequest(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write([]byte(`{"ok":true}`))
}))
defer server.Close()

var calls int32
client := NewClientBuilder().
WithNoRetry().
WithAuthTokenProvider(func(_ context.Context) (string, error) {
n := atomic.AddInt32(&calls, 1)
return fmt.Sprintf("tok-%d", n), nil
}, "Bearer").
Build()

for i := 0; i < 3; i++ {
if _, err := SimpleGet(context.Background(), client, server.URL, nil); err != nil {
t.Fatalf("fetch %d failed: %v", i, err)
}
}

if got := atomic.LoadInt32(&calls); got != 3 {
t.Errorf("expected 3 provider invocations, got %d", got)
}
}

func TestClientBuilder_WithAuthTokenProvider_PropagatesError(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
t.Error("HTTP request fired despite provider error")
w.WriteHeader(200)
}))
defer server.Close()

wantErr := errors.New("token-mint-failure")
client := NewClientBuilder().
WithNoRetry().
WithAuthTokenProvider(func(_ context.Context) (string, error) {
return "", wantErr
}, "Bearer").
Build()

_, err := SimpleGet(context.Background(), client, server.URL, nil)
if err == nil || !errors.Is(err, wantErr) {
t.Fatalf("expected wrapped %v, got %v", wantErr, err)
}
}
23 changes: 23 additions & 0 deletions go/fetch/builder.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ type ClientBuilder struct {
circuitBreakerOpts *CircuitBreakerOptions
circuitBreakerName string
hooks *LifecycleHooks
authProvider AuthTokenProvider
authScheme string
}

// NewClientBuilder creates a new ClientBuilder with default retry and timeout options.
Expand Down Expand Up @@ -86,6 +88,22 @@ func (b *ClientBuilder) WithHooks(hooks *LifecycleHooks) *ClientBuilder {
return b
}

// WithAuthTokenProvider registers a sync-or-async auth token provider that is
// invoked before every request and used to populate the `Authorization`
// header. The provider receives the request context, so it can short-circuit
// on cancellation/timeouts. Mirrors the .NET `AuthTokenProvider` delegate and
// the TypeScript `FetchBuilder.withAuthTokenProvider(...)` method.
//
// Pass an empty string for scheme to default to "Bearer".
func (b *ClientBuilder) WithAuthTokenProvider(provider AuthTokenProvider, scheme string) *ClientBuilder {
b.authProvider = provider
if scheme == "" {
scheme = "Bearer"
}
b.authScheme = scheme
return b
}

// WithNoRetry disables retries.
func (b *ClientBuilder) WithNoRetry() *ClientBuilder {
b.retryOpts = nil
Expand Down Expand Up @@ -126,6 +144,11 @@ func (b *ClientBuilder) Build() *Client {
timeout: b.timeoutOpts,
rateLimitRetry: b.rateLimitRetryOpts,
hooks: b.hooks,
authProvider: b.authProvider,
authScheme: b.authScheme,
}
if c.authScheme == "" {
c.authScheme = "Bearer"
}

if c.httpClient == nil {
Expand Down
22 changes: 22 additions & 0 deletions go/fetch/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ type Client struct {
rateLimitRetry *RateLimitRetryOptions
circuitBreaker *CircuitBreaker
hooks *LifecycleHooks
authProvider AuthTokenProvider
authScheme string
}

// NewClient creates a new Client with default settings (retry + timeout).
Expand Down Expand Up @@ -186,6 +188,26 @@ func executeHTTPRequest[T any](
}
}

// Apply auth-token provider (after the pre-request hook so the hook can
// adjust the URL/init first; the resulting Authorization header overrides
// any value the hook may have set).
if client.authProvider != nil {
token, err := client.authProvider(ctx)
if err != nil {
if hooks != nil && hooks.PostResponseError != nil {
if replacementErr := hooks.PostResponseError(requestURL, req, err, nil); replacementErr != nil {
return nil, replacementErr
}
}
return nil, fmt.Errorf("auth token provider failed: %w", err)
}
scheme := client.authScheme
if scheme == "" {
scheme = "Bearer"
}
req.Header.Set("Authorization", fmt.Sprintf("%s %s", scheme, token))
}

// Execute the HTTP request
httpClient := client.httpClient
if httpClient == nil {
Expand Down
7 changes: 7 additions & 0 deletions go/fetch/options.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package fetch

import (
"context"
"net/http"
"time"
)
Expand Down Expand Up @@ -114,6 +115,12 @@ type CircuitBreakerCounts struct {
ConsecutiveFailures uint32
}

// AuthTokenProvider is invoked before every request to mint an auth token that
// is injected into the `Authorization` header. The provider receives the
// request context so callers can hook into cancellation/timeouts when fetching
// or refreshing tokens. Mirrors the .NET `AuthTokenProvider` delegate.
type AuthTokenProvider func(ctx context.Context) (string, error)

// LifecycleHooks provides hooks into the request/response lifecycle.
type LifecycleHooks struct {
// PreRequest is called before sending the request.
Expand Down
2 changes: 2 additions & 0 deletions python/src/smooai_fetch/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@

# Types
from smooai_fetch._types import (
AuthTokenProvider,
CircuitBreakerOptions,
FetchContainerOptions,
FetchOptions,
Expand All @@ -66,6 +67,7 @@
"FetchBuilder",
"FetchResponse",
# Types
"AuthTokenProvider",
"CircuitBreakerOptions",
"FetchContainerOptions",
"FetchOptions",
Expand Down
29 changes: 29 additions & 0 deletions python/src/smooai_fetch/_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from smooai_fetch._defaults import DEFAULT_RETRY_OPTIONS
from smooai_fetch._response import FetchResponse
from smooai_fetch._types import (
AuthTokenProvider,
CircuitBreakerOptions,
FetchContainerOptions,
FetchOptions,
Expand Down Expand Up @@ -54,6 +55,8 @@ def __init__(self) -> None:
self._schema: type[BaseModel] | None = None
self._headers: dict[str, str] = {}
self._hooks: LifecycleHooks = LifecycleHooks()
self._auth_token_provider: AuthTokenProvider | None = None
self._auth_scheme: str = "Bearer"

def with_retry(self, options: RetryOptions | None = None) -> FetchBuilder:
"""Configure retry behavior.
Expand Down Expand Up @@ -159,6 +162,30 @@ def with_auth(self, token: str, scheme: str = "Bearer") -> FetchBuilder:
self._headers["Authorization"] = f"{scheme} {token}"
return self

def with_auth_provider(
self,
provider: AuthTokenProvider,
scheme: str = "Bearer",
) -> FetchBuilder:
"""Register a sync or async auth token provider.

The provider is invoked before every request and its result is injected
as the `Authorization` header. If the provider returns an awaitable,
it is awaited inline. Mirrors the .NET `AuthTokenProvider` delegate.

Args:
provider: Callable returning the bare token (no scheme prefix).
Sync `() -> str` or async `() -> Awaitable[str]` are both
accepted.
scheme: Auth scheme prefix. Defaults to "Bearer".

Returns:
The builder instance for method chaining.
"""
self._auth_token_provider = provider
self._auth_scheme = scheme
return self

def with_pre_request_hook(self, hook: PreRequestHook) -> FetchBuilder:
"""Set a pre-request hook.

Expand Down Expand Up @@ -216,6 +243,8 @@ def build(self) -> FetchOptions:
schema=self._schema,
hooks=self._hooks,
container_options=container_options,
auth_token_provider=self._auth_token_provider,
auth_scheme=self._auth_scheme,
)

async def fetch(
Expand Down
13 changes: 13 additions & 0 deletions python/src/smooai_fetch/_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from __future__ import annotations

import inspect
import json
from typing import Any, TypeVar

Expand Down Expand Up @@ -234,6 +235,18 @@ async def fetch(url: str, options: FetchOptions | None = None) -> FetchResponse[
current_url, request_kwargs = hook_result
request_kwargs["url"] = current_url

# Apply auth-token provider (after pre-request hook so the hook can still
# adjust headers / URL first). Supports both sync providers and async
# providers; the awaitable is awaited inline so we always get the resolved
# token before the request fires.
if opts.auth_token_provider is not None:
raw_token = opts.auth_token_provider()
if inspect.isawaitable(raw_token):
raw_token = await raw_token
token: str = str(raw_token)
headers: dict[str, str] = request_kwargs.setdefault("headers", {})
headers["Authorization"] = f"{opts.auth_scheme} {token}"

# Build rate limiter
rate_limiter: SlidingWindowRateLimiter | None = None
if container_options and container_options.rate_limit:
Expand Down
17 changes: 17 additions & 0 deletions python/src/smooai_fetch/_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,13 @@ class RetryContext:
"""Callback invoked before each retry attempt to override default behavior."""


AuthTokenProvider = Callable[[], "Any"]
"""Provider that returns an auth token. May return a `str` directly or an awaitable
that resolves to a `str`. Invoked before every request to populate the
`Authorization` header. Mirrors the .NET delegate of the same name.
"""


@dataclass
class RetryOptions:
"""Configuration options for retry behavior."""
Expand Down Expand Up @@ -242,3 +249,13 @@ class FetchOptions:

container_options: FetchContainerOptions | None = None
"""Container-level options (rate limit, circuit breaker)."""

auth_token_provider: AuthTokenProvider | None = None
"""Optional sync or async provider invoked before each request to mint an auth token.

The returned token is injected into the `Authorization` header using
`auth_scheme` (default `"Bearer"`). Awaitable return values are awaited.
"""

auth_scheme: str = "Bearer"
"""Auth scheme prefix used with `auth_token_provider`. Defaults to "Bearer"."""
Loading
Loading