diff --git a/.vscode/launch.json b/.vscode/launch.json index ade2de0..e23a68c 100644 --- a/.vscode/launch.json +++ b/.vscode/launch.json @@ -28,7 +28,7 @@ "--inspect-wait", "--allow-all", "--filter", - "handles 400 response with non-JSON text" + "can use per-domain rate limiting with auto-update from headers" ], "attachSimplePort": 9229 } diff --git a/.vscode/tasks.json b/.vscode/tasks.json index 61e834a..e81dfc9 100644 --- a/.vscode/tasks.json +++ b/.vscode/tasks.json @@ -19,6 +19,10 @@ "problemMatcher": [ "$deno" ], + "group": { + "kind": "build", + "isDefault": true + }, "label": "deno: check", "detail": "$ deno check scripts/*.ts *.ts src/*.ts" }, @@ -43,7 +47,10 @@ "problemMatcher": [ "$deno-test" ], - "group": "test", + "group": { + "kind": "test", + "isDefault": true + }, "label": "deno: test" } ] diff --git a/readme.md b/readme.md index f5301a4..3b88b1c 100644 --- a/readme.md +++ b/readme.md @@ -3,12 +3,18 @@ FetchClient is a library that makes it easier to use the fetch API for JSON APIs. It provides the following features: -* [Makes fetch easier to use for JSON APIs](#typed-response) -* [Automatic model validation](#model-validator) -* [Caching](#caching) -* [Middleware](#middleware) -* [Problem Details](https://www.rfc-editor.org/rfc/rfc9457.html) support -* Option to parse dates in responses +- [FetchClient ](#fetchclient---) + - [Install](#install) + - [Docs](#docs) + - [Usage](#usage) + - [Typed Response](#typed-response) + - [Typed Response Using a Function](#typed-response-using-a-function) + - [Model Validator](#model-validator) + - [Caching](#caching) + - [Middleware](#middleware) + - [Rate Limiting](#rate-limiting) + - [Contributing](#contributing) + - [License](#license) ## Install @@ -130,6 +136,23 @@ const response = await client.getJSON( ); ``` +### Rate Limiting + +```ts +import { FetchClient, useRateLimit } from '@exceptionless/fetchclient'; + +// Enable rate limiting globally with 100 requests per minute +useRateLimit({ + maxRequests: 100, + windowSeconds: 60, +}); + +const client = new FetchClient(); +const response = await client.getJSON( + `https://api.example.com/data`, +); +``` + Also, take a look at the tests: [FetchClient Tests](src/FetchClient.test.ts) diff --git a/src/DefaultHelpers.ts b/src/DefaultHelpers.ts index 9697e35..3e8208e 100644 --- a/src/DefaultHelpers.ts +++ b/src/DefaultHelpers.ts @@ -7,6 +7,7 @@ import { } from "./FetchClientProvider.ts"; import type { FetchClientResponse } from "./FetchClientResponse.ts"; import type { ProblemDetails } from "./ProblemDetails.ts"; +import type { RateLimitMiddlewareOptions } from "./RateLimitMiddleware.ts"; import type { GetRequestOptions, RequestOptions } from "./RequestOptions.ts"; let getCurrentProviderFunc: () => FetchClientProvider | null = () => null; @@ -164,3 +165,23 @@ export function useMiddleware(middleware: FetchClientMiddleware) { export function setRequestOptions(options: RequestOptions) { getCurrentProvider().applyOptions({ defaultRequestOptions: options }); } + +/** + * Enables rate limiting for any FetchClient instances created by the current provider. + * @param options - The rate limiting configuration options. + */ +export function useRateLimit( + options: RateLimitMiddlewareOptions, +) { + getCurrentProvider().useRateLimit(options); +} + +/** + * Enables per-domain rate limiting for any FetchClient instances created by the current provider. + * @param options - The rate limiting configuration options. + */ +export function usePerDomainRateLimit( + options: Omit, +) { + getCurrentProvider().usePerDomainRateLimit(options); +} diff --git a/src/FetchClient.test.ts b/src/FetchClient.test.ts index 5c45879..6bd6632 100644 --- a/src/FetchClient.test.ts +++ b/src/FetchClient.test.ts @@ -16,6 +16,10 @@ import { } from "../mod.ts"; import { FetchClientProvider } from "./FetchClientProvider.ts"; import { z, type ZodTypeAny } from "zod"; +import { + buildRateLimitHeader, + buildRateLimitPolicyHeader, +} from "./RateLimiter.ts"; export const TodoSchema = z.object({ userId: z.number(), @@ -970,6 +974,122 @@ Deno.test("handles 400 response with non-JSON text", async () => { ); }); +Deno.test("can use per-domain rate limiting with auto-update from headers", async () => { + const provider = new FetchClientProvider(); + + const groupTracker = new Map(); + + const startTime = Date.now(); + + groupTracker.set("api.example.com", 100); + groupTracker.set("slow-api.example.com", 5); + + provider.usePerDomainRateLimit({ + maxRequests: 50, // Default limit + windowSeconds: 60, // 1 minute default window + autoUpdateFromHeaders: true, + groups: { + "api.example.com": { + maxRequests: 75, // API will override this with headers + windowSeconds: 60, + }, + "slow-api.example.com": { + maxRequests: 30, // API will override this with headers + windowSeconds: 30, + }, + }, + }); + + provider.fetch = ( + input: RequestInfo | URL, + _init?: RequestInit, + ): Promise => { + let url: URL; + if (input instanceof Request) { + url = new URL(input.url); + } else { + url = new URL(input.toString()); + } + + const headers = new Headers({ + "Content-Type": "application/json", + }); + + // Simulate different rate limits for different domains + if (url.hostname === "api.example.com") { + headers.set("X-RateLimit-Limit", "100"); + let remaining = groupTracker.get("api.example.com") ?? 0; + remaining = remaining > 0 ? remaining - 2 : 0; + groupTracker.set("api.example.com", remaining); + headers.set("X-RateLimit-Remaining", String(remaining)); + } else if (url.hostname === "slow-api.example.com") { + let remaining = groupTracker.get("slow-api.example.com") ?? 0; + remaining = remaining > 0 ? remaining - 2 : 0; + groupTracker.set("slow-api.example.com", remaining); + + headers.set( + "RateLimit-Policy", + buildRateLimitPolicyHeader({ + policy: "slow-api.example.com", + limit: 5, + windowSeconds: 30, + }), + ); + headers.set( + "RateLimit", + buildRateLimitHeader({ + policy: "slow-api.example.com", + remaining: remaining, + resetSeconds: 30 - ((Date.now() - startTime) / 1000), + }), + ); + } + // other-api.example.com gets no rate limit headers + + return Promise.resolve( + new Response(JSON.stringify({ success: true }), { + status: 200, + statusText: "OK", + headers, + }), + ); + }; + + assert(provider.rateLimiter); + + const client = provider.getFetchClient(); + + // check API rate limit + let apiOptions = provider.rateLimiter.getGroupOptions("api.example.com"); + assertEquals(apiOptions.maxRequests, 75); + assertEquals(apiOptions.windowSeconds, 60); + + const response1 = await client.getJSON( + "https://api.example.com/data", + ); + assertEquals(response1.status, 200); + + apiOptions = provider.rateLimiter.getGroupOptions("api.example.com"); + assertEquals(apiOptions.maxRequests, 100); // Updated from headers + + // check slow API rate limit + let slowApiOptions = provider.rateLimiter.getGroupOptions( + "slow-api.example.com", + ); + assertEquals(slowApiOptions.maxRequests, 30); + assertEquals(slowApiOptions.windowSeconds, 30); + + const response2 = await client.getJSON( + "https://slow-api.example.com/data", + ); + assertEquals(response2.status, 200); + + slowApiOptions = provider.rateLimiter.getGroupOptions( + "slow-api.example.com", + ); + assertEquals(slowApiOptions.maxRequests, 5); // Updated from headers +}); + function delay(time: number): Promise { return new Promise((resolve) => setTimeout(resolve, time)); } diff --git a/src/FetchClientProvider.ts b/src/FetchClientProvider.ts index 5c35523..130f9be 100644 --- a/src/FetchClientProvider.ts +++ b/src/FetchClientProvider.ts @@ -5,6 +5,11 @@ import type { ProblemDetails } from "./ProblemDetails.ts"; import { FetchClientCache } from "./FetchClientCache.ts"; import type { FetchClientOptions } from "./FetchClientOptions.ts"; import { type IObjectEvent, ObjectEvent } from "./ObjectEvent.ts"; +import { + RateLimitMiddleware, + type RateLimitMiddlewareOptions, +} from "./RateLimitMiddleware.ts"; +import { groupByDomain, type RateLimiter } from "./RateLimiter.ts"; type Fetch = typeof globalThis.fetch; @@ -15,6 +20,7 @@ export class FetchClientProvider { #options: FetchClientOptions = {}; #fetch?: Fetch; #cache: FetchClientCache; + #rateLimitMiddleware?: RateLimitMiddleware; #counter = new Counter(); #onLoading = new ObjectEvent(); @@ -187,6 +193,47 @@ export class FetchClientProvider { ], }; } + + /** + * Enables rate limiting for all FetchClient instances created by this provider. + * @param options - The rate limiting configuration options. + */ + public useRateLimit(options: RateLimitMiddlewareOptions) { + this.#rateLimitMiddleware = new RateLimitMiddleware(options); + this.useMiddleware(this.#rateLimitMiddleware.middleware()); + } + + /** + * Enables rate limiting for all FetchClient instances created by this provider. + * @param options - The rate limiting configuration options. + */ + public usePerDomainRateLimit( + options: Omit, + ) { + this.#rateLimitMiddleware = new RateLimitMiddleware({ + ...options, + getGroupFunc: groupByDomain, + }); + this.useMiddleware(this.#rateLimitMiddleware.middleware()); + } + + /** + * Gets the rate limiter instance used for rate limiting. + * @returns The rate limiter instance, or undefined if rate limiting is not enabled. + */ + public get rateLimiter(): RateLimiter | undefined { + return this.#rateLimitMiddleware?.rateLimiter; + } + + /** + * Removes the rate limiting middleware from all FetchClient instances created by this provider. + */ + public removeRateLimit() { + this.#rateLimitMiddleware = undefined; + this.#options.middleware = this.#options.middleware?.filter( + (m) => !(m instanceof RateLimitMiddleware), + ); + } } const provider = new FetchClientProvider(); diff --git a/src/RateLimit.test.ts b/src/RateLimit.test.ts new file mode 100644 index 0000000..9bdfb52 --- /dev/null +++ b/src/RateLimit.test.ts @@ -0,0 +1,521 @@ +import { assertEquals, assertRejects } from "@std/assert"; +import { FetchClientProvider } from "./FetchClientProvider.ts"; +import { + RateLimitError, + type RateLimitMiddlewareOptions, +} from "./RateLimitMiddleware.ts"; +import type { FetchClientResponse } from "./FetchClientResponse.ts"; +import { + buildRateLimitHeader, + buildRateLimitPolicyHeader, + parseRateLimitHeader, + parseRateLimitPolicyHeader, + RateLimiter, +} from "./RateLimiter.ts"; + +// Mock fetch function for testing +const createMockFetch = (response: { + status?: number; + statusText?: string; + body?: string; + headers?: Record; +} = {}) => { + return ( + _input: RequestInfo | URL, + _init?: RequestInit, + ): Promise => { + const headers = new Headers(response.headers || {}); + headers.set("Content-Type", "application/json"); + + return Promise.resolve( + new Response(response.body || JSON.stringify({ success: true }), { + status: response.status || 200, + statusText: response.statusText || "OK", + headers, + }), + ); + }; +}; + +Deno.test("RateLimiter - basic functionality", () => { + const rateLimiter = new RateLimiter({ + maxRequests: 2, + windowSeconds: 1, + }); + + // First request should be allowed + assertEquals(rateLimiter.isAllowed("http://example.com"), true); + assertEquals(rateLimiter.getRequestCount("http://example.com"), 1); + assertEquals( + rateLimiter.getRemainingRequests("http://example.com"), + 1, + ); + + // Second request should be allowed + assertEquals(rateLimiter.isAllowed("http://example.com"), true); + assertEquals(rateLimiter.getRequestCount("http://example.com"), 2); + assertEquals( + rateLimiter.getRemainingRequests("http://example.com"), + 0, + ); + + // Third request should be denied + assertEquals(rateLimiter.isAllowed("http://example.com"), false); + assertEquals(rateLimiter.getRequestCount("http://example.com"), 2); + assertEquals( + rateLimiter.getRemainingRequests("http://example.com"), + 0, + ); +}); + +Deno.test("RateLimiter - group generator", () => { + const rateLimiter = new RateLimiter({ + maxRequests: 1, + windowSeconds: 1, + getGroupFunc: (url: string) => `${url}`, + }); + + // Different URLs should have separate buckets + assertEquals(rateLimiter.isAllowed("http://example.com"), true); + assertEquals(rateLimiter.isAllowed("http://other.com"), true); + assertEquals(rateLimiter.isAllowed("http://example.com"), false); + assertEquals(rateLimiter.isAllowed("http://other.com"), false); +}); + +Deno.test("RateLimiter - group initialization", () => { + const rateLimiter = new RateLimiter({ + maxRequests: 5, + windowSeconds: 1, + getGroupFunc: (url: string) => new URL(url).hostname, + groups: { + "example.com": { + maxRequests: 2, + windowSeconds: 1, + }, + "api.example.com": { + maxRequests: 10, + windowSeconds: 2, + }, + }, + }); + + // Check that group options were applied correctly + const exampleOptions = rateLimiter.getGroupOptions("example.com"); + assertEquals(exampleOptions.maxRequests, 2); + assertEquals(exampleOptions.windowSeconds, 1); + + const apiOptions = rateLimiter.getGroupOptions("api.example.com"); + assertEquals(apiOptions.maxRequests, 10); + assertEquals(apiOptions.windowSeconds, 2); + + // Check that non-configured groups get empty options (will use defaults) + const otherOptions = rateLimiter.getGroupOptions("other.com"); + assertEquals(otherOptions.maxRequests, 5); + assertEquals(otherOptions.windowSeconds, 1); + + // Test that the group-specific limits are actually used + assertEquals(rateLimiter.isAllowed("https://example.com/test"), true); + assertEquals(rateLimiter.isAllowed("https://example.com/test"), true); + assertEquals(rateLimiter.isAllowed("https://example.com/test"), false); // Should be denied (limit=2) + + // API subdomain should have different limits + assertEquals( + rateLimiter.getRemainingRequests("https://api.example.com/test"), + 10, + ); +}); + +Deno.test("RateLimiter - time window expiry", async () => { + const rateLimiter = new RateLimiter({ + maxRequests: 1, + windowSeconds: 0.1, + }); + + // First request should be allowed + assertEquals(rateLimiter.isAllowed("http://example.com"), true); + + // Second request should be denied + assertEquals(rateLimiter.isAllowed("http://example.com"), false); + + // Wait for window to expire + await new Promise((resolve) => setTimeout(resolve, 150)); + + // Request should be allowed again + assertEquals(rateLimiter.isAllowed("http://example.com"), true); +}); + +Deno.test("RateLimitMiddleware - throws error when rate limit exceeded", async () => { + const mockFetch = createMockFetch(); + const provider = new FetchClientProvider(mockFetch); + + const options: RateLimitMiddlewareOptions = { + maxRequests: 1, + windowSeconds: 1, + throwOnRateLimit: true, + }; + + provider.useRateLimit(options); + + const client = provider.getFetchClient(); + + // First request should succeed + const response1 = await client.get("http://example.com"); + assertEquals(response1.status, 200); + + // Second request should throw RateLimitError + await assertRejects( + () => client.get("http://example.com"), + RateLimitError, + "Rate limit exceeded", + ); +}); + +Deno.test("RateLimitMiddleware - returns 429 response when configured", async () => { + const mockFetch = createMockFetch(); + const provider = new FetchClientProvider(mockFetch); + + const options: RateLimitMiddlewareOptions = { + maxRequests: 1, + windowSeconds: 1, + throwOnRateLimit: false, + errorMessage: "Custom rate limit message", + }; + + provider.useRateLimit(options); + + const client = provider.getFetchClient(); + + // First request should succeed + const response1 = await client.get("http://example.com"); + assertEquals(response1.status, 200); + + // Second request should throw 429 response + try { + await client.get("http://example.com"); + throw new Error("Expected rate limit response to be thrown"); + } catch (error) { + // The response object is thrown by FetchClient for 4xx/5xx status codes + const response = error as FetchClientResponse; + assertEquals(response.status, 429); + assertEquals(response.problem?.title, "Unexpected status code: 429"); + if (response.problem?.detail) { + assertEquals( + response.problem.detail.includes("Custom rate limit message"), + true, + ); + } + } +}); + +Deno.test("RateLimitMiddleware - provides rate limit info in error response", async () => { + const mockFetch = createMockFetch(); + const provider = new FetchClientProvider(mockFetch); + + const options: RateLimitMiddlewareOptions = { + maxRequests: 1, + windowSeconds: 1, + throwOnRateLimit: false, + }; + + provider.useRateLimit(options); + + const client = provider.getFetchClient(); + + // First request should succeed + const response1 = await client.get("http://example.com"); + assertEquals(response1.status, 200); + + // Second request should throw 429 with rate limit headers + try { + await client.get("http://example.com"); + throw new Error("Expected rate limit response to be thrown"); + } catch (error) { + const response = error as FetchClientResponse; + assertEquals(response.status, 429); + assertEquals(response.headers.get("RateLimit-Limit"), "1"); + assertEquals(response.headers.get("RateLimit-Remaining"), "0"); + assertEquals(response.headers.get("RateLimit-Reset") !== null, true); + assertEquals(response.headers.get("Retry-After") !== null, true); + } +}); + +Deno.test("createRateLimitMiddleware - custom group generator", async () => { + const mockFetch = createMockFetch(); + const provider = new FetchClientProvider(mockFetch); + + let callCount = 0; + const options: RateLimitMiddlewareOptions = { + maxRequests: 1, + windowSeconds: 1, + getGroupFunc: (url: string) => { + callCount++; + return `custom-${url}`; + }, + throwOnRateLimit: true, + autoUpdateFromHeaders: false, // Disable auto-update to prevent extra getGroupFunc calls + }; + + provider.useRateLimit(options); + + const client = provider.getFetchClient(); + + // First request should succeed and call key generator + await client.get("http://example.com"); + assertEquals(callCount, 1); + + // Second request should call key generator and throw + await assertRejects( + () => client.get("http://example.com"), + RateLimitError, + ); + // The key generator might be called multiple times due to the rate limiting logic + assertEquals(callCount >= 2, true); +}); + +Deno.test("RateLimitError - contains correct information", async () => { + const mockFetch = createMockFetch(); + const provider = new FetchClientProvider(mockFetch); + + provider.useRateLimit({ + maxRequests: 1, + windowSeconds: 1, + throwOnRateLimit: true, + }); + + const client = provider.getFetchClient(); + + // First request should succeed + await client.get("http://example.com"); + + // Second request should throw with proper error info + try { + await client.get("http://example.com"); + throw new Error("Expected request to fail"); + } catch (error) { + if (error instanceof RateLimitError) { + assertEquals(error.name, "RateLimitError"); + assertEquals(error.remainingRequests, 0); + assertEquals(typeof error.resetTime, "number"); + assertEquals(error.resetTime > Date.now(), true); + } else { + throw new Error("Expected RateLimitError"); + } + } +}); + +Deno.test("RateLimiter - updateFromHeaders with standard headers", () => { + const rateLimiter = new RateLimiter({ + maxRequests: 10, + windowSeconds: 5, + }); + + // Test with IETF standard headers + const headers = new Headers({ + "ratelimit-policy": '"default";q=100;w=60', + "ratelimit": '"default";r=75;t=30', + }); + + rateLimiter.updateFromHeaders("test-group", headers); + + const groupOptions = rateLimiter.getGroupOptions("test-group"); + assertEquals(groupOptions.maxRequests, 100); + assertEquals(groupOptions.windowSeconds, 60); +}); + +Deno.test("RateLimiter - updateFromHeaders with x-ratelimit fallback headers", () => { + const rateLimiter = new RateLimiter({ + maxRequests: 10, + windowSeconds: 5, + }); + + // Test with fallback x-ratelimit headers + const headers = new Headers({ + "x-ratelimit-limit": "50", + "x-ratelimit-remaining": "25", + "x-ratelimit-reset": "1234567890", + "x-ratelimit-window": "120", + }); + + rateLimiter.updateFromHeaders("test-group", headers); + + const groupOptions = rateLimiter.getGroupOptions("test-group"); + assertEquals(groupOptions.maxRequests, 50); + assertEquals(groupOptions.windowSeconds, 120); +}); + +Deno.test("RateLimiter - updateFromHeaders with x-rate-limit fallback headers", () => { + const rateLimiter = new RateLimiter({ + maxRequests: 10, + windowSeconds: 5, + }); + + // Test with alternate x-rate-limit headers + const headers = new Headers({ + "x-rate-limit-limit": "200", + "x-rate-limit-remaining": "150", + "x-rate-limit-reset": "1234567890", + "x-rate-limit-window": "30", + }); + + rateLimiter.updateFromHeaders("test-group", headers); + + const groupOptions = rateLimiter.getGroupOptions("test-group"); + assertEquals(groupOptions.maxRequests, 200); + assertEquals(groupOptions.windowSeconds, 30); +}); + +Deno.test("RateLimiter - updateFromHeaders prioritizes standard over x-ratelimit", () => { + const rateLimiter = new RateLimiter({ + maxRequests: 10, + windowSeconds: 5, + }); + + // Test with both IETF and x-ratelimit headers - IETF should take precedence + const headers = new Headers({ + "ratelimit-policy": '"default";q=100;w=60', + "ratelimit": '"default";r=75;t=30', + "x-ratelimit-limit": "50", + "x-ratelimit-remaining": "25", + "x-ratelimit-reset": "1234567890", + "x-ratelimit-window": "120", + }); + + rateLimiter.updateFromHeaders("test-group", headers); + + const groupOptions = rateLimiter.getGroupOptions("test-group"); + // Should use IETF standard values (100 limit, 60 window), not x-ratelimit values + assertEquals(groupOptions.maxRequests, 100); + assertEquals(groupOptions.windowSeconds, 60); +}); + +Deno.test("RateLimiter - updateFromHeaders with reset time calculation", () => { + const rateLimiter = new RateLimiter({ + maxRequests: 10, + windowSeconds: 5, + }); + + // Test with only reset time (no window) + const futureTime = Math.floor(Date.now() / 1000) + 90; // 90 seconds in the future + const headers = new Headers({ + "x-ratelimit-limit": "50", + "x-ratelimit-reset": futureTime.toString(), + }); + + rateLimiter.updateFromHeaders("test-group", headers); + + const groupOptions = rateLimiter.getGroupOptions("test-group"); + assertEquals(groupOptions.maxRequests, 50); + // Window should be approximately 90 seconds + assertEquals(groupOptions.windowSeconds! >= 85, true); + assertEquals(groupOptions.windowSeconds! <= 95, true); +}); + +Deno.test("RateLimiter - updateFromHeaders with malformed IETF headers", () => { + const rateLimiter = new RateLimiter({ + maxRequests: 10, + windowSeconds: 5, + }); + + // Test with malformed IETF headers should fall back to x-ratelimit + const headers = new Headers({ + "ratelimit-policy": '"default";invalid=format', + "ratelimit": '"default";bad=format', + "x-ratelimit-limit": "50", + "x-ratelimit-window": "120", + }); + + rateLimiter.updateFromHeaders("test-group", headers); + + const groupOptions = rateLimiter.getGroupOptions("test-group"); + assertEquals(groupOptions.maxRequests, 50); + assertEquals(groupOptions.windowSeconds, 120); +}); + +Deno.test("createRateLimitHeader - creates correct header format", () => { + const result = buildRateLimitHeader({ + policy: "default", + remaining: 75, + resetSeconds: 30, + }); + + assertEquals(result, '"default";r=75;t=30'); +}); + +Deno.test("createRateLimitHeader - handles missing reset time", () => { + const result = buildRateLimitHeader({ + policy: "default", + remaining: 75, + resetSeconds: 0, + }); + + assertEquals(result, '"default";r=75'); +}); + +Deno.test("createRateLimitPolicyHeader - creates correct header format", () => { + const result = buildRateLimitPolicyHeader({ + policy: "default", + limit: 100, + windowSeconds: 60, + }); + + assertEquals(result, '"default";q=100;w=60'); +}); + +Deno.test("createRateLimitPolicyHeader - handles missing window", () => { + const result = buildRateLimitPolicyHeader({ + policy: "default", + limit: 100, + }); + + assertEquals(result, '"default";q=100'); +}); + +Deno.test("parseRateLimitHeader - parses correct header format", () => { + const result = parseRateLimitHeader('"default";r=75;t=30'); + + assertEquals(result, { + policy: "default", + remaining: 75, + resetSeconds: 30, + }); +}); + +Deno.test("parseRateLimitHeader - handles missing parameters", () => { + const result = parseRateLimitHeader('"default";r=75'); + + assertEquals(result, { + policy: "default", + remaining: 75, + }); +}); + +Deno.test("parseRateLimitHeader - handles invalid format", () => { + const result = parseRateLimitHeader("invalid-format"); + + assertEquals(result, {}); +}); + +Deno.test("parseRateLimitPolicyHeader - parses correct header format", () => { + const result = parseRateLimitPolicyHeader('"default";q=100;w=60'); + + assertEquals(result, { + policy: "default", + limit: 100, + windowSeconds: 60, + }); +}); + +Deno.test("parseRateLimitPolicyHeader - handles missing parameters", () => { + const result = parseRateLimitPolicyHeader('"default";q=100'); + + assertEquals(result, { + policy: "default", + limit: 100, + }); +}); + +Deno.test("parseRateLimitPolicyHeader - handles invalid format", () => { + const result = parseRateLimitPolicyHeader("invalid-format"); + + assertEquals(result, {}); +}); diff --git a/src/RateLimitMiddleware.ts b/src/RateLimitMiddleware.ts new file mode 100644 index 0000000..aec1219 --- /dev/null +++ b/src/RateLimitMiddleware.ts @@ -0,0 +1,177 @@ +import type { FetchClientContext } from "./FetchClientContext.ts"; +import type { FetchClientMiddleware } from "./FetchClientMiddleware.ts"; +import type { FetchClientResponse } from "./FetchClientResponse.ts"; +import { ProblemDetails } from "./ProblemDetails.ts"; +import { + buildRateLimitHeader, + buildRateLimitPolicyHeader, + RateLimiter, + type RateLimiterOptions, +} from "./RateLimiter.ts"; + +/** + * Rate limiting error thrown when requests exceed the rate limit. + */ +export class RateLimitError extends Error { + public readonly resetTime: number; + public readonly remainingRequests: number; + + constructor(resetTime: number, remainingRequests: number, message?: string) { + super( + message || + `Rate limit exceeded. Try again after ${ + new Date(resetTime).toISOString() + }`, + ); + this.name = "RateLimitError"; + this.resetTime = resetTime; + this.remainingRequests = remainingRequests; + } +} + +/** + * Configuration options for the rate limiting middleware. + */ +export interface RateLimitMiddlewareOptions extends RateLimiterOptions { + /** + * Whether to throw an error when rate limit is exceeded. + * If false, the middleware will set a 429 status response. + * @default true + */ + throwOnRateLimit?: boolean; + + /** + * Custom error message when rate limit is exceeded. + */ + errorMessage?: string; + + /** + * Whether to automatically update rate limits based on response headers. + * @default true + */ + autoUpdateFromHeaders?: boolean; +} + +/** + * Rate limiting middleware instance that can be shared across requests. + */ +export class RateLimitMiddleware { + #rateLimiter: RateLimiter; + + private readonly throwOnRateLimit: boolean; + private readonly errorMessage?: string; + private readonly autoUpdateFromHeaders: boolean; + + constructor(options: RateLimitMiddlewareOptions) { + this.#rateLimiter = new RateLimiter(options); + this.throwOnRateLimit = options.throwOnRateLimit ?? true; + this.errorMessage = options.errorMessage; + this.autoUpdateFromHeaders = options.autoUpdateFromHeaders ?? true; + } + + /** + * Gets the underlying rate limiter instance. + */ + public get rateLimiter(): RateLimiter { + return this.#rateLimiter; + } + + /** + * Creates the middleware function. + * @returns The middleware function + */ + public middleware(): FetchClientMiddleware { + return async (context: FetchClientContext, next: () => Promise) => { + const url = context.request.url; + + // Check if request is allowed + if (!this.rateLimiter.isAllowed(url)) { + const group = this.rateLimiter.getGroup(url); + const resetTime = this.rateLimiter.getResetTime(url) ?? Date.now(); + const remainingRequests = this.rateLimiter.getRemainingRequests(url); + + if (this.throwOnRateLimit) { + throw new RateLimitError( + resetTime, + remainingRequests, + this.errorMessage, + ); + } + + // Create a 429 Too Many Requests response + const groupOptions = this.rateLimiter.getGroupOptions(group); + const maxRequests = groupOptions.maxRequests!; + const windowSeconds = groupOptions.windowSeconds!; + + // Create IETF standard rate limit headers + const resetSeconds = Math.ceil((resetTime - Date.now()) / 1000); + const rateLimitHeader = buildRateLimitHeader({ + policy: group, + remaining: remainingRequests, + resetSeconds: resetSeconds, + }); + + const rateLimitPolicyHeader = buildRateLimitPolicyHeader({ + policy: group, + limit: maxRequests, + windowSeconds: Math.floor(windowSeconds), + }); + + const headers = new Headers({ + "Content-Type": "application/problem+json", + "RateLimit": rateLimitHeader, + "RateLimit-Policy": rateLimitPolicyHeader, + // Legacy headers for backward compatibility + "RateLimit-Limit": maxRequests.toString(), + "RateLimit-Remaining": remainingRequests.toString(), + "RateLimit-Reset": Math.ceil(resetTime / 1000).toString(), + "Retry-After": resetSeconds.toString(), + }); + + const problem = new ProblemDetails(); + problem.status = 429; + problem.title = "Too Many Requests"; + problem.detail = this.errorMessage || + `Rate limit exceeded. Try again after ${ + new Date(resetTime).toISOString() + }`; + + context.response = { + url: context.request.url, + status: 429, + statusText: "Too Many Requests", + body: null, + bodyUsed: true, + ok: false, + headers: headers, + redirected: false, + type: "basic", + problem: problem, + data: null, + meta: { links: {} }, + json: () => Promise.resolve(problem), + text: () => Promise.resolve(JSON.stringify(problem)), + arrayBuffer: () => Promise.resolve(new ArrayBuffer(0)), + // @ts-ignore: New in Deno 1.44 + bytes: () => Promise.resolve(new Uint8Array()), + blob: () => Promise.resolve(new Blob()), + formData: () => Promise.resolve(new FormData()), + clone: () => { + throw new Error("Not implemented"); + }, + } as FetchClientResponse; + + return; + } + + await next(); + + if (this.autoUpdateFromHeaders && context.response) { + this.rateLimiter.updateFromHeadersForRequest( + url, + context.response.headers, + ); + } + }; + } +} diff --git a/src/RateLimiter.ts b/src/RateLimiter.ts new file mode 100644 index 0000000..a14803f --- /dev/null +++ b/src/RateLimiter.ts @@ -0,0 +1,494 @@ +/** + * Per-group rate limiter options that can override the global options. + */ +export interface GroupRateLimiterOptions { + /** + * Maximum number of requests allowed per time window for this group. + */ + maxRequests?: number; + + /** + * Time window in milliseconds for this group. + */ + windowSeconds?: number; + + /** + * Callback function called when rate limit is exceeded for this group. + * @param resetTime - Time when the rate limit will reset (in milliseconds since epoch) + */ + onRateLimitExceeded?: (resetTime: number) => void; +} + +/** + * Configuration options for the rate limiter. + */ +export interface RateLimiterOptions { + /** + * Maximum number of requests allowed per time window. + */ + maxRequests: number; + + /** + * Time window in seconds. + */ + windowSeconds: number; + + /** + * Optional group generator function to create unique rate limit buckets. + * If not provided, a global rate limit is applied. + * @param url - The request URL + * @returns A string group to identify the rate limit bucket + */ + getGroupFunc?: (url: string) => string; + + /** + * Callback function called when rate limit is exceeded. + * @param resetTime - Time when the rate limit will reset (in milliseconds since epoch) + */ + onRateLimitExceeded?: (resetTime: number) => void; + + /** + * Optional group-specific rate limit options. + * Map of group keys to their specific rate limit options. + */ + groups?: Record; +} + +/** + * Represents a rate limit bucket with request tracking. + */ +interface RateLimitBucket { + requests: number[]; + resetTime: number; +} + +/** + * A rate limiter that tracks requests per time window. + */ +export class RateLimiter { + private readonly options: Required; + private readonly buckets = new Map(); + private readonly groupOptions = new Map(); + + constructor(options: RateLimiterOptions) { + this.options = { + getGroupFunc: () => "global", + onRateLimitExceeded: () => {}, + groups: {}, + ...options, + }; + + // Initialize group options if provided + if (options.groups) { + for (const [groupKey, groupOptions] of Object.entries(options.groups)) { + this.groupOptions.set(groupKey, groupOptions); + } + } + } + + /** + * Checks if a request is allowed and updates the rate limit state. + * @param url - The request URL + * @returns True if the request is allowed, false if rate limit is exceeded + */ + public isAllowed(url: string): boolean { + const key = this.options.getGroupFunc(url); + const groupOptions = this.getGroupOptions(key); + const now = Date.now(); + + // Use group-specific options if available, otherwise fall back to global options + const maxRequests = groupOptions.maxRequests ?? 0; + const windowSeconds = groupOptions.windowSeconds ?? 0; + const onRateLimitExceeded = groupOptions.onRateLimitExceeded ?? + this.options.onRateLimitExceeded; + + let bucket = this.buckets.get(key); + if (!bucket) { + bucket = { + requests: [], + resetTime: now + (windowSeconds * 1000), + }; + this.buckets.set(key, bucket); + } + + // Clean up old requests outside the time window + const windowStart = now - (windowSeconds * 1000); + bucket.requests = bucket.requests.filter((time) => time > windowStart); + + // Update reset time if all requests have expired + if (bucket.requests.length === 0) { + bucket.resetTime = now + (windowSeconds * 1000); + } + + // Check if we're within the rate limit + if (bucket.requests.length >= maxRequests) { + onRateLimitExceeded(bucket.resetTime); + return false; + } + + // Add the current request + bucket.requests.push(now); + return true; + } + + /** + * Gets the current request count for a specific key. + * @param url - The request URL + * @returns The current number of requests in the time window + */ + public getRequestCount(url: string): number { + const key = this.options.getGroupFunc(url); + const groupOptions = this.getGroupOptions(key); + const bucket = this.buckets.get(key); + + if (!bucket) { + return 0; + } + + const now = Date.now(); + const windowSeconds = groupOptions.windowSeconds ?? 0; + const windowStart = now - (windowSeconds * 1000); + return bucket.requests.filter((time) => time > windowStart).length; + } + + /** + * Gets the remaining requests allowed for a specific key. + * @param url - The request URL + * @returns The number of remaining requests allowed + */ + public getRemainingRequests(url: string): number { + const key = this.options.getGroupFunc(url); + const groupOptions = this.getGroupOptions(key); + const maxRequests = groupOptions.maxRequests ?? 0; + + return Math.max( + 0, + maxRequests - this.getRequestCount(url), + ); + } + + /** + * Gets the time when the rate limit will reset for a specific key. + * @param url - The request URL + * @returns The reset time in milliseconds since epoch, or null if no bucket exists + */ + public getResetTime(url: string): number | null { + const key = this.options.getGroupFunc(url); + const bucket = this.buckets.get(key); + return bucket?.resetTime ?? null; + } + + /** + * Clears the rate limit state for a specific key. + * @param url - The request URL + */ + public clearBucket(url: string): void { + const key = this.options.getGroupFunc(url); + this.buckets.delete(key); + } + + /** + * Gets the group key for a URL. + * @param url - The request URL + * @returns The group key + */ + public getGroup(url: string): string { + return this.options.getGroupFunc(url); + } + + /** + * Gets the options for a specific group. Falls back to global options if not set. + * @param group - The group key + * @returns The options for the group + */ + public getGroupOptions(group: string): GroupRateLimiterOptions { + const options = this.groupOptions.get(group); + if (!options) { + return { + maxRequests: this.options.maxRequests, + windowSeconds: this.options.windowSeconds, + }; + } + return options; + } + + /** + * Checks if a group has specific options set. + * @param group - The group key + * @returns True if the group has options, false otherwise + */ + public hasGroupOptions(group: string): boolean { + return this.groupOptions.has(group); + } + + /** + * Sets options for a specific group. + * @param group - The group key + * @param options - The options to set + */ + public setGroupOptions( + group: string, + options: GroupRateLimiterOptions, + ): void { + this.groupOptions.set(group, options); + } + + /** + * Sets rate limit options for a request. + * @param url - The request URL + * @param options - The options to set for this group + */ + public setOptionsForRequest( + url: string, + options: GroupRateLimiterOptions, + ): void { + const group = this.getGroup(url); + this.setGroupOptions(group, options); + } + + /** + * Updates rate limit options for a request based on standard rate limit headers. + * @param url - The request URL + * @param method - The HTTP method + * @param headers - The response headers containing rate limit information + */ + public updateFromHeadersForRequest( + url: string, + headers: Headers, + ): void { + const group = this.getGroup(url); + this.updateFromHeaders(group, headers); + } + + /** + * Updates rate limit options based on standard rate limit headers. + * @param group - The group key + * @param headers - The response headers containing rate limit information + */ + public updateFromHeaders(group: string, headers: Headers): void { + // Get existing group-specific options (not global fallback) + const currentOptions = this.hasGroupOptions(group) + ? this.groupOptions.get(group)! + : {}; + const newOptions: GroupRateLimiterOptions = { ...currentOptions }; + + // Parse IETF standard rate limit headers first, then fall back to x-ratelimit headers + let limit: string | null = null; + let window: string | null = null; + let reset: string | null = null; + + // Try IETF standard headers first + const rateLimitPolicyHeader = headers.get("ratelimit-policy"); + if (rateLimitPolicyHeader) { + const parsed = parseRateLimitPolicyHeader(rateLimitPolicyHeader); + if (parsed?.limit) { + limit = parsed.limit.toString(); + } + if (parsed?.windowSeconds) { + window = parsed.windowSeconds.toString(); + } + } + + const rateLimitHeader = headers.get("ratelimit"); + if (rateLimitHeader) { + const parsed = parseRateLimitHeader(rateLimitHeader); + if (parsed?.resetSeconds) { + reset = parsed.resetSeconds.toString(); + } + } + + // Fall back to x-ratelimit headers if IETF headers not found + if (!limit) { + limit = headers.get("x-ratelimit-limit") || + headers.get("x-rate-limit-limit"); + } + + if (!window) { + window = headers.get("x-ratelimit-window") || + headers.get("x-rate-limit-window"); + } + + if (!reset) { + reset = headers.get("x-ratelimit-reset") || + headers.get("x-rate-limit-reset"); + } + + let hasChanges = false; + + // Apply the parsed values + if (limit) { + const maxRequests = parseInt(limit, 10); + if (!isNaN(maxRequests)) { + newOptions.maxRequests = maxRequests; + hasChanges = true; + } + } + + if (window) { + const windowSeconds = parseInt(window, 10); + if (!isNaN(windowSeconds)) { + newOptions.windowSeconds = windowSeconds; + hasChanges = true; + } + } else if (reset) { + // If no window header, try to calculate from reset time + const resetTime = parseInt(reset, 10); + if (!isNaN(resetTime)) { + const now = Math.floor(Date.now() / 1000); + const windowSeconds = Math.max(1, resetTime - now); + newOptions.windowSeconds = windowSeconds; + hasChanges = true; + } + } + + // Update the group options if we found valid headers + if (hasChanges) { + this.setGroupOptions(group, newOptions); + } + } + + /** + * Clears all rate limit state. + */ + public clearAll(): void { + this.buckets.clear(); + } +} + +/** + * Creates a group generator function that groups requests by domain only (no protocol). + * @param url - The request URL + * @returns A string representing the domain without protocol + */ +export function groupByDomain(url: string): string { + try { + const urlObj = new URL(url); + return urlObj.hostname; + } catch { + return url; + } +} + +/** + * IETF rate limit header information structure. + */ +export interface RateLimitInfo { + /** The policy name/identifier */ + policy: string; + /** Maximum requests allowed (quota) */ + limit: number; + /** Remaining requests */ + remaining: number; + /** Reset time in seconds from now */ + resetSeconds: number; + /** Window duration in seconds */ + windowSeconds?: number; +} + +/** + * Creates an IETF standard RateLimit header value. + * @param info - The rate limit information + * @returns The formatted RateLimit header value + */ +export function buildRateLimitHeader( + info: Omit, +): string { + let headerValue = `"${info.policy}";r=${info.remaining}`; + + if (info.resetSeconds > 0) { + headerValue += `;t=${info.resetSeconds}`; + } + + return headerValue; +} + +/** + * Creates an IETF standard RateLimit-Policy header value. + * @param info - The rate limit information + * @returns The formatted RateLimit-Policy header value + */ +export function buildRateLimitPolicyHeader( + info: Omit, +): string { + let headerValue = `"${info.policy}";q=${info.limit}`; + + if (info.windowSeconds && info.windowSeconds > 0) { + headerValue += `;w=${info.windowSeconds}`; + } + + return headerValue; +} + +/** + * Parses an IETF standard RateLimit header value. + * @param headerValue - The RateLimit header value to parse + * @returns The parsed rate limit information or null if invalid + */ +export function parseRateLimitHeader( + headerValue: string, +): Partial | null { + if (!headerValue) return null; + + try { + const result: Partial = {}; + + // Extract policy name (quoted string at the beginning) + const policyMatch = headerValue.match(/^"([^"]+)"/); + if (policyMatch) { + result.policy = policyMatch[1]; + } + + // Extract remaining (r parameter) + const remainingMatch = headerValue.match(/r=(\d+)/); + if (remainingMatch) { + result.remaining = parseInt(remainingMatch[1], 10); + } + + // Extract reset time (t parameter) + const resetMatch = headerValue.match(/t=(\d+)/); + if (resetMatch) { + result.resetSeconds = parseInt(resetMatch[1], 10); + } + + return result; + } catch { + return null; + } +} + +/** + * Parses an IETF standard RateLimit-Policy header value. + * @param headerValue - The RateLimit-Policy header value to parse + * @returns The parsed rate limit policy information or null if invalid + */ +export function parseRateLimitPolicyHeader( + headerValue: string, +): Partial | null { + if (!headerValue) return null; + + try { + const result: Partial = {}; + + // Extract policy name (quoted string at the beginning) + const policyMatch = headerValue.match(/^"([^"]+)"/); + if (policyMatch) { + result.policy = policyMatch[1]; + } + + // Extract quota/limit (q parameter) + const quotaMatch = headerValue.match(/q=(\d+)/); + if (quotaMatch) { + result.limit = parseInt(quotaMatch[1], 10); + } + + // Extract window (w parameter) + const windowMatch = headerValue.match(/w=(\d+)/); + if (windowMatch) { + result.windowSeconds = parseInt(windowMatch[1], 10); + } + + return result; + } catch { + return null; + } +}