Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions .changeset/strong-trains-act.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
---
'@tanstack/start-client-core': minor
'@tanstack/start-plugin-core': patch
'@tanstack/start-server-core': patch
'@tanstack/start-fn-stubs': patch
---

add createCsrfMiddleware based on Sec-Fetch-Site header, auto-apply to unconfigured servers, warn for others
45 changes: 45 additions & 0 deletions docs/start/framework/react/guide/middleware.md
Original file line number Diff line number Diff line change
Expand Up @@ -440,6 +440,51 @@ export const startInstance = createStart(() => {
> [!NOTE]
> Global **request** middleware runs before **every request, including server routes, SSR and server functions**.

### CSRF Middleware

Server functions are same-origin RPC endpoints and should be protected from cross-site requests. If your app does not define `src/start.ts`, TanStack Start installs its CSRF middleware automatically for server functions.

If you define a custom `src/start.ts`, add `createCsrfMiddleware()` explicitly:

```tsx
// src/start.ts
import { createStart, createCsrfMiddleware } from '@tanstack/react-start'

const csrfMiddleware = createCsrfMiddleware({
filter: (ctx) => ctx.handlerType === 'serverFn',
})

export const startInstance = createStart(() => ({
requestMiddleware: [csrfMiddleware],
}))
```

By default, `Origin` and `Referer` checks compare against the incoming request URL origin. If your deployment needs to allow a different public origin, configure it on the CSRF middleware with `createCsrfMiddleware({ origin: 'https://app.example.com' })`.

By default, `createCsrfMiddleware()` validates every request handled by the middleware. Use `filter: (ctx) => ctx.handlerType === 'serverFn'` when installing it globally for server function protection. It verifies same-origin browser request metadata with `Sec-Fetch-Site`, `Origin`, or `Referer` headers and rejects requests that cannot be proven same-origin.

You can also use the same middleware to protect any other route.

```tsx
export const Route = createFileRoute('/api/foo')({
server: {
middleware: [createCsrfMiddleware()],
handlers: { GET: () => {...} }
}
})
```
Comment on lines +466 to +475
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟡 Minor | ⚡ Quick win

Clarify whether to reuse or create a new middleware instance.

Line 466 says "use the same middleware" but the code example creates a new instance with createCsrfMiddleware(). Either update the text to say "You can also use CSRF middleware to protect any other route" or update the code to reuse the csrfMiddleware variable from the earlier example.

Option 1: Update text to match code (creates new instance)
-You can also use the same middleware to protect any other route.
+You can also use CSRF middleware to protect any other route.
Option 2: Update code to reuse existing instance
 export const Route = createFileRoute('/api/foo')({
   server: {
-    middleware: [createCsrfMiddleware()],
+    middleware: [csrfMiddleware],
     handlers: { GET: () => {...} }
   }
 })
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
You can also use the same middleware to protect any other route.
```tsx
export const Route = createFileRoute('/api/foo')({
server: {
middleware: [createCsrfMiddleware()],
handlers: { GET: () => {...} }
}
})
```
You can also use CSRF middleware to protect any other route.
Suggested change
You can also use the same middleware to protect any other route.
```tsx
export const Route = createFileRoute('/api/foo')({
server: {
middleware: [createCsrfMiddleware()],
handlers: { GET: () => {...} }
}
})
```
You can also use the same middleware to protect any other route.
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@docs/start/framework/react/guide/middleware.md` around lines 466 - 475, The
wording "use the same middleware" is inconsistent with the example which calls
createCsrfMiddleware() anew; either change the sentence to "You can also use
CSRF middleware to protect any other route" or update the example to reuse the
earlier csrfMiddleware variable (e.g. replace createCsrfMiddleware() with
csrfMiddleware in the createFileRoute call). Ensure references to
createCsrfMiddleware, csrfMiddleware and the Route/createFileRoute example are
updated so text and code match.


If you define `src/start.ts` without the CSRF middleware, Start shows a development warning for server function requests. If you intentionally handle CSRF another way, disable the warning:

```tsx
// vite.config.ts or rsbuild.config.ts
tanstackStart({
serverFns: {
disableCsrfMiddlewareWarning: true,
},
})
```

### Global Server Function Middleware

To have a middleware run for **every server function in your application**, add it to the `functionMiddleware` array in your `src/start.ts` file:
Expand Down
24 changes: 24 additions & 0 deletions docs/start/framework/react/guide/server-functions.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,30 @@ const time = await getServerTime()

Server functions provide server capabilities (database access, environment variables, file system) while maintaining type safety across the network boundary.

## Same-Origin Requests

Server functions are same-origin RPC endpoints for your application. Browser requests to server functions should come from the same origin, verified with Fetch Metadata (`Sec-Fetch-Site`), `Origin`, or `Referer` headers. Use server routes for public APIs or endpoints that intentionally support cross-origin requests.

TanStack Start provides `createCsrfMiddleware()` to protect server functions from cross-site requests. If your app does not define `src/start.ts`, Start installs this middleware automatically for server functions. If you define `src/start.ts`, add the middleware explicitly:

```tsx
// src/start.ts
import { createStart, createCsrfMiddleware } from '@tanstack/react-start'

const csrfMiddleware = createCsrfMiddleware({
filter: (ctx) => ctx.handlerType === 'serverFn',
})

export const startInstance = createStart(() => ({
requestMiddleware: [csrfMiddleware],
}))
```

By default, `Origin` and `Referer` checks compare against the incoming request URL origin. If your deployment needs to allow a different public origin, configure it on the CSRF middleware with `createCsrfMiddleware({ origin: 'https://app.example.com' })`.

> [!TIP]
> Requests without any of these headers (`Sec-Fetch-Site`, `Origin`, or `Referer`) are rejected by default. If your deployment strips these headers and you have another layer that guarantees same-origin server function requests, you can opt in with `createCsrfMiddleware({ filter: (ctx) => ctx.handlerType === 'serverFn', allowRequestsWithoutOriginCheck: true })`.
## Basic Usage

Server functions are created with `createServerFn()` and can specify HTTP method:
Expand Down
7 changes: 6 additions & 1 deletion e2e/react-start/server-functions/src/start.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import { createStart } from '@tanstack/react-start'
import { createCsrfMiddleware, createStart } from '@tanstack/react-start'
import type { CustomFetch } from '@tanstack/react-start'

/**
Expand All @@ -16,6 +16,11 @@ const globalServerFnFetch: CustomFetch = (input, init) => {
}

export const startInstance = createStart(() => ({
requestMiddleware: [
createCsrfMiddleware({
filter: (ctx) => ctx.handlerType === 'serverFn',
}),
],
serverFns: {
fetch: globalServerFnFetch,
},
Expand Down
26 changes: 26 additions & 0 deletions e2e/react-start/server-functions/tests/server-functions.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -288,6 +288,32 @@ test('Direct POST submitting FormData to a Server function returns the correct m
expect(result).toBe(expected)
})

test('CSRF middleware rejects cross-site Server function requests', async ({
page,
request,
}) => {
await page.goto('/submit-post-formdata')
await page.waitForLoadState('networkidle')

const actionUrl = await page
.getByTestId('submit-post-formdata-form')
.getAttribute('action')

expect(actionUrl).toBeTruthy()

const response = await request.post(actionUrl!, {
headers: {
'Sec-Fetch-Site': 'cross-site',
},
multipart: {
name: 'Sean',
},
})

expect(response.status()).toBe(403)
await expect(response.text()).resolves.toBe('Forbidden')
})

test("server function's dead code is preserved if already there", async ({
page,
}) => {
Expand Down
16 changes: 12 additions & 4 deletions packages/react-start-client/src/tests/createServerFn.test-d.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,9 @@ test('createServerFn returns async array', () => {
return result
})

expectTypeOf(serverFn()).toEqualTypeOf<Promise<Array<{ a: number }>>>()
expectTypeOf<ReturnType<typeof serverFn>>().toEqualTypeOf<
Promise<Array<{ a: number }>>
>()
})

test('createServerFn returns sync array', () => {
Expand All @@ -33,7 +35,9 @@ test('createServerFn returns sync array', () => {
return result
})

expectTypeOf(serverFn()).toEqualTypeOf<Promise<Array<{ a: number }>>>()
expectTypeOf<ReturnType<typeof serverFn>>().toEqualTypeOf<
Promise<Array<{ a: number }>>
>()
})

test('createServerFn returns async union', () => {
Expand All @@ -42,7 +46,9 @@ test('createServerFn returns async union', () => {
return result
})

expectTypeOf(serverFn()).toEqualTypeOf<Promise<string | number>>()
expectTypeOf<ReturnType<typeof serverFn>>().toEqualTypeOf<
Promise<string | number>
>()
})

test('createServerFn returns sync union', () => {
Expand All @@ -51,5 +57,7 @@ test('createServerFn returns sync union', () => {
return result
})

expectTypeOf(serverFn()).toEqualTypeOf<Promise<string | number>>()
expectTypeOf<ReturnType<typeof serverFn>>().toEqualTypeOf<
Promise<string | number>
>()
})
197 changes: 197 additions & 0 deletions packages/start-client-core/src/createCsrfMiddleware.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,197 @@
import { createIsomorphicFn } from '@tanstack/start-fn-stubs'
import { createMiddleware } from './createMiddleware'
import type {
RequestMiddlewareAfterServer,
RequestServerOptions,
} from './createMiddleware'
import type { Register } from '@tanstack/router-core'

export const csrfSymbol = Symbol.for('tanstack-start:csrf-middleware')

export type CsrfSecFetchSite =
| 'same-origin'
| 'same-site'
| 'cross-site'
| 'none'

export type CsrfMatcher<TValue, TRegister = Register, TMiddlewares = unknown> =
| TValue
| Array<TValue>
| ((
value: TValue | (string & {}),
ctx: RequestServerOptions<TRegister, TMiddlewares>,
) => boolean | Promise<boolean>)

export interface CsrfMiddlewareOptions<
TRegister = Register,
TMiddlewares = unknown,
> {
/**
* Return `true` to validate this request, or `false` to skip validation.
*
* @default undefined, which validates every request handled by this middleware.
*/
filter?: (
ctx: RequestServerOptions<TRegister, TMiddlewares>,
) => boolean | Promise<boolean>
/**
* Allowed Origin values. Defaults to the trusted request origin.
*/
origin?: CsrfMatcher<string, TRegister, TMiddlewares>
/**
* Allowed Sec-Fetch-Site values.
*
* @default 'same-origin'
*/
secFetchSite?: CsrfMatcher<CsrfSecFetchSite, TRegister, TMiddlewares>
/**
* Whether to use Referer as a fallback when Sec-Fetch-Site and Origin are absent.
*
* @default true
*/
referer?:
| boolean
| ((
referer: string,
ctx: RequestServerOptions<TRegister, TMiddlewares>,
) => boolean | Promise<boolean>)
/**
* Allow requests when Sec-Fetch-Site, Origin, and Referer are all missing.
*
* @default false
*/
allowRequestsWithoutOriginCheck?: boolean
/**
* Optional response returned when CSRF validation fails.
*
* @default new Response('Forbidden', { status: 403 })
*/
failureResponse?:
| Response
| ((
ctx: RequestServerOptions<TRegister, TMiddlewares>,
) => Response | Promise<Response>)
}

type CreateCsrfMiddleware = <TRegister, TMiddlewares>(
opts?: CsrfMiddlewareOptions<TRegister, TMiddlewares>,
) => RequestMiddlewareAfterServer<{}, undefined, undefined>

const innerCreateCsrfMiddleware: CreateCsrfMiddleware = (opts = {}) => {
const middleware = createMiddleware().server(async (ctx) => {
const csrfCtx = ctx as RequestServerOptions<any, any> & typeof ctx

if (opts.filter && !(await opts.filter(csrfCtx))) {
return ctx.next()
}

if (await isCsrfRequestAllowed(opts, csrfCtx)) {
return ctx.next()
}

return getFailureResponse(opts, csrfCtx)
})

if (process.env.NODE_ENV !== 'production') {
Object.defineProperty(middleware, csrfSymbol, { value: true })
}

return middleware
}

export const createCsrfMiddleware: CreateCsrfMiddleware =
createIsomorphicFn().server(innerCreateCsrfMiddleware) as CreateCsrfMiddleware

export async function isCsrfRequestAllowed<TRegister, TMiddlewares>(
opts: CsrfMiddlewareOptions<TRegister, TMiddlewares>,
ctx: RequestServerOptions<TRegister, TMiddlewares>,
): Promise<boolean> {
const result = await getCsrfRequestValidationResult(opts, ctx)
return (
result === true ||
(result === undefined && opts.allowRequestsWithoutOriginCheck === true)
)
}

export async function getCsrfRequestValidationResult<TRegister, TMiddlewares>(
opts: CsrfMiddlewareOptions<TRegister, TMiddlewares>,
ctx: RequestServerOptions<TRegister, TMiddlewares>,
): Promise<boolean | undefined> {
const fetchSite = ctx.request.headers.get('Sec-Fetch-Site')
if (fetchSite !== null) {
return matchValue(opts.secFetchSite ?? 'same-origin', fetchSite, ctx)
}

const origin = ctx.request.headers.get('Origin')
if (origin !== null) {
if (opts.origin) {
return matchValue(opts.origin, origin, ctx)
}

return origin === new URL(ctx.request.url).origin
}

const referer = ctx.request.headers.get('Referer')
if (referer === null || opts.referer === false) {
return undefined
}

if (typeof opts.referer === 'function') {
return opts.referer(referer, ctx)
}

if (opts.origin) {
const refererOrigin = getOriginFromUrl(referer)
return (
refererOrigin !== undefined && matchValue(opts.origin, refererOrigin, ctx)
)
}

return isRefererSameOrigin(referer, new URL(ctx.request.url).origin)
}

async function matchValue<TValue extends string, TRegister, TMiddlewares>(
matcher: CsrfMatcher<TValue, TRegister, TMiddlewares>,
value: string,
ctx: RequestServerOptions<TRegister, TMiddlewares>,
): Promise<boolean> {
if (typeof matcher === 'function') {
return matcher(value, ctx)
}

if (Array.isArray(matcher)) {
// typescript is dumb for array.includes()
return matcher.includes(value as TValue)
}

return value === matcher
}

function getOriginFromUrl(url: string): string | undefined {
try {
return new URL(url).origin
} catch {
return undefined
}
}

function isRefererSameOrigin(referer: string, requestOrigin: string): boolean {
if (referer === requestOrigin) return true
if (!referer.startsWith(requestOrigin)) return false
if (referer.length === requestOrigin.length) return true
const code = referer.charCodeAt(requestOrigin.length)
return code === 47 /* '/' */ || code === 63 /* '?' */ || code === 35 /* '#' */
}

async function getFailureResponse<TRegister, TMiddlewares>(
opts: CsrfMiddlewareOptions<TRegister, TMiddlewares>,
ctx: RequestServerOptions<TRegister, TMiddlewares>,
): Promise<Response> {
if (typeof opts.failureResponse === 'function') {
return opts.failureResponse(ctx)
}

return (
opts.failureResponse?.clone() ?? new Response('Forbidden', { status: 403 })
)
}
4 changes: 4 additions & 0 deletions packages/start-client-core/src/createMiddleware.ts
Original file line number Diff line number Diff line change
Expand Up @@ -770,6 +770,10 @@ export interface RequestServerOptions<TRegister, TMiddlewares> {
pathname: string
context: Expand<AssignAllServerRequestContext<TRegister, TMiddlewares>>
next: RequestServerNextFn<TRegister, TMiddlewares>
/**
* Type of Start handler currently processing this request.
*/
handlerType: 'serverFn' | 'router'
/**
* Metadata about the server function being invoked.
* This is only present when the request is handling a server function call.
Expand Down
Loading
Loading