diff --git a/src/app/api/auth/[...nextauth]/route.test.ts b/src/app/api/auth/[...nextauth]/route.test.ts index bec4d8c13..e30b330b0 100644 --- a/src/app/api/auth/[...nextauth]/route.test.ts +++ b/src/app/api/auth/[...nextauth]/route.test.ts @@ -1,6 +1,6 @@ import { handlers } from "@project/auth"; import { extractRequestContextFromHeadersAndCookies } from "@project/src/utils/requestScopedStorageWrapper"; -import { GET } from "@src/app/api/auth/[...nextauth]/route"; +import { GET, POST } from "@src/app/api/auth/[...nextauth]/route"; import { logger } from "@src/utils/logger"; import { NextRequest } from "next/server"; @@ -11,10 +11,6 @@ jest.mock("@project/auth", () => ({ }, })); -jest.mock("@src/app/api/auth/[...nextauth]/provider", () => ({ - NHS_LOGIN_PROVIDER_ID: "nhs-login", -})); - jest.mock("@src/utils/logger", () => ({ logger: { child: jest.fn().mockReturnValue({ @@ -32,112 +28,160 @@ jest.mock("@project/src/utils/requestContext", () => ({ jest.mock("@project/src/utils/requestScopedStorageWrapper"); -jest.mock("next/server", () => ({ - NextRequest: jest.fn().mockImplementation((url: string, init?: RequestInit) => ({ - url, - headers: init?.headers ?? new Headers(), - })), +jest.mock("@src/utils/getHeadersForLogging", () => ({ + getHeadersForLogging: jest.fn().mockReturnValue({ "mock-header": "mock-value" }), })); -const buildMockRequest = (pathname: string, params?: Record): NextRequest => +const buildMockGetRequest = ( + pathname: string, + params?: Record, + headers: Headers = new Headers(), +): NextRequest => ({ method: "GET", nextUrl: { pathname, searchParams: new URLSearchParams(params), }, + headers, }) as unknown as NextRequest; -describe("GET", () => { +describe("NextAuth API Route", () => { const mockLogError = (logger.child as jest.Mock).mock.results[0].value.error as jest.Mock; const mockLogInfo = (logger.child as jest.Mock).mock.results[0].value.info as jest.Mock; const mockResponse = { status: 200 } as Response; - const mockContext = { sessionId: "test-session-id", traceId: "test-trace-id", nextUrl: "" }; beforeEach(() => { (handlers.GET as jest.Mock).mockResolvedValue(mockResponse); + (handlers.POST as jest.Mock).mockResolvedValue(mockResponse); (extractRequestContextFromHeadersAndCookies as jest.Mock).mockReturnValueOnce(mockContext); }); - it("should log the pathname on GET request with correct context", async () => { - const req = buildMockRequest("/api/auth/signin") as unknown as NextRequest; - - await GET(req); - - expect(mockLogInfo).toHaveBeenCalledWith( - { - context: { pathname: "/api/auth/signin" }, - sessionId: "test-session-id", - traceId: "test-trace-id", - nextUrl: "/api/auth/signin", - }, - "GET NextAuth route", - ); - }); - - it("should delegate to nextAuth handlers.GET with the original request", async () => { - const req = buildMockRequest("/api/auth/callback/nhs-login") as unknown as NextRequest; - - await GET(req); - - expect(handlers.GET).toHaveBeenCalledWith(req); - }); - - describe("when the callback URL contains an OAuth error", () => { - it("should log the error and error_description", async () => { - const req = buildMockRequest("/api/auth/callback/nhs-login", { - error: "access_denied", - error_description: "User cancelled login", - }) as unknown as NextRequest; + describe("GET", () => { + it("should log the pathname on GET request with correct context", async () => { + const req = buildMockGetRequest("/api/auth/signin") as unknown as NextRequest; await GET(req); - expect(mockLogError).toHaveBeenCalledWith( - { error: "access_denied", error_description: "User cancelled login" }, - "OAuth provider returned error in callback", + expect(mockLogInfo).toHaveBeenCalledWith( { + context: { + method: "GET", + pathname: "/api/auth/signin", + headers: { "mock-header": "mock-value" }, + }, sessionId: "test-session-id", traceId: "test-trace-id", - nextUrl: "/api/auth/callback/nhs-login", + nextUrl: "/api/auth/signin", }, + "NextAuth route", ); }); - it("should log null when error_description is absent", async () => { - const req = buildMockRequest("/api/auth/callback/nhs-login", { error: "server_error" }) as unknown as NextRequest; + it("should delegate the request to nextAuth handlers.GET", async () => { + const req = buildMockGetRequest("/api/auth/callback/nhs-login") as unknown as NextRequest; await GET(req); - expect(mockLogError).toHaveBeenCalledWith( - { error: "server_error", error_description: null }, - "OAuth provider returned error in callback", + expect(handlers.GET as jest.Mock).toHaveBeenCalledWith(req); + }); + + describe("when the callback URL is called", () => { + it("should log the error and error_description when present", async () => { + const req = buildMockGetRequest("/api/auth/callback/nhs-login", { + error: "access_denied", + error_description: "User cancelled login", + }) as unknown as NextRequest; + + await GET(req); + + expect(mockLogError).toHaveBeenCalledWith( + { error: "access_denied", error_description: "User cancelled login" }, + "OAuth provider returned error in callback", + { + sessionId: "test-session-id", + traceId: "test-trace-id", + nextUrl: "/api/auth/callback/nhs-login", + }, + ); + }); + + it("should log null when error_description is absent", async () => { + const req = buildMockGetRequest("/api/auth/callback/nhs-login", { + error: "access_denied", + }) as unknown as NextRequest; + + await GET(req); + + expect(mockLogError).toHaveBeenCalledWith( + { error: "access_denied", error_description: null }, + "OAuth provider returned error in callback", + { + sessionId: "test-session-id", + traceId: "test-trace-id", + nextUrl: "/api/auth/callback/nhs-login", + }, + ); + }); + + it("should carry on and delegate to handlers.GET when error not present", async () => { + const req = buildMockGetRequest("/api/auth/callback/nhs-login") as unknown as NextRequest; + + await GET(req); + + expect(handlers.GET as jest.Mock).toHaveBeenCalledWith(req); + }); + }); + + describe("when a non-callback url is called", () => { + it("should not log an error even if an error param is present", async () => { + const req = buildMockGetRequest("/api/auth/signin", { error: "OAuthCallbackError" }) as unknown as NextRequest; + + await GET(req); + + expect(mockLogError).not.toHaveBeenCalled(); + }); + }); + }); + + describe("POST", () => { + it("should log the pathname and method on POST request with correct context", async () => { + const req = { + method: "POST", + nextUrl: { pathname: "/api/auth/callback/nhs-login" }, + headers: new Headers(), + } as unknown as NextRequest; + + await POST(req); + + expect(mockLogInfo).toHaveBeenCalledWith( { + context: { + method: "POST", + pathname: "/api/auth/callback/nhs-login", + headers: { "mock-header": "mock-value" }, + }, sessionId: "test-session-id", traceId: "test-trace-id", nextUrl: "/api/auth/callback/nhs-login", }, + "NextAuth route", ); }); - }); - describe("when the callback URL has no error", () => { - it("should not log an error", async () => { - const req = buildMockRequest("/api/auth/callback/nhs-login") as unknown as NextRequest; + it("should delegate the request to nextAuth handlers.POST", async () => { + const req = { + method: "POST", + nextUrl: { pathname: "/api/auth/callback/nhs-login" }, + headers: new Headers(), + } as unknown as NextRequest; - await GET(req); + (extractRequestContextFromHeadersAndCookies as jest.Mock).mockReturnValueOnce(mockContext); - expect(mockLogError).not.toHaveBeenCalled(); - }); - }); - - describe("when the path is not the OAuth callback", () => { - it("should not log an error even if an error param is present", async () => { - const req = buildMockRequest("/api/auth/signin", { error: "OAuthCallbackError" }) as unknown as NextRequest; - - await GET(req); + await POST(req); - expect(mockLogError).not.toHaveBeenCalled(); + expect(handlers.POST as jest.Mock).toHaveBeenCalledWith(req); }); }); }); diff --git a/src/app/api/auth/[...nextauth]/route.ts b/src/app/api/auth/[...nextauth]/route.ts index 1bea4c952..ba500ed40 100644 --- a/src/app/api/auth/[...nextauth]/route.ts +++ b/src/app/api/auth/[...nextauth]/route.ts @@ -2,6 +2,7 @@ import { handlers } from "@project/auth"; import { RequestContext, asyncLocalStorage } from "@project/src/utils/requestContext"; import { extractRequestContextFromHeadersAndCookies } from "@project/src/utils/requestScopedStorageWrapper"; import { NHS_LOGIN_PROVIDER_ID } from "@src/app/api/auth/[...nextauth]/provider"; +import { getHeadersForLogging } from "@src/utils/getHeadersForLogging"; import { logger } from "@src/utils/logger"; import { NextRequest } from "next/server"; @@ -16,7 +17,10 @@ export const GET = async (req: NextRequest) => { requestContext.nextUrl = pathname; return await asyncLocalStorage.run(requestContext, async () => { - log.info({ context: { pathname }, ...requestContext }, "GET NextAuth route"); + log.info( + { context: { method: req.method, pathname, headers: getHeadersForLogging(req) }, ...requestContext }, + "NextAuth route", + ); const error = searchParams.get("error"); if (pathname.includes(NHS_LOGIN_CALLBACK_PATH) && error) { @@ -33,4 +37,17 @@ export const GET = async (req: NextRequest) => { }); }; -export const { POST } = handlers; +export const POST = async (req: NextRequest) => { + const { pathname } = req.nextUrl; + + const requestContext: RequestContext = extractRequestContextFromHeadersAndCookies(req.headers, req?.cookies); + requestContext.nextUrl = pathname; + + return await asyncLocalStorage.run(requestContext, async () => { + log.info( + { context: { method: req.method, pathname, headers: getHeadersForLogging(req) }, ...requestContext }, + "NextAuth route", + ); + return await handlers.POST(req); + }); +}; diff --git a/src/app/api/sso/route.test.ts b/src/app/api/sso/route.test.ts index d8574173b..9231e7409 100644 --- a/src/app/api/sso/route.test.ts +++ b/src/app/api/sso/route.test.ts @@ -2,6 +2,7 @@ import { signIn } from "@project/auth"; import { GET } from "@src/app/api/sso/route"; import config from "@src/utils/config"; import { SESSION_ID_COOKIE_NAME } from "@src/utils/constants"; +import { logger } from "@src/utils/logger"; import { ConfigMock, configBuilder } from "@test-data/config/builders"; import { ResponseCookie, ResponseCookies } from "next/dist/compiled/@edge-runtime/cookies"; import { ReadonlyHeaders } from "next/dist/server/web/spec-extension/adapters/headers"; @@ -15,7 +16,20 @@ jest.mock("@project/auth", () => ({ jest.mock("next/navigation", () => ({ redirect: jest.fn(), })); -jest.mock("sanitize-data", () => ({ sanitize: jest.fn() })); +jest.mock("@src/utils/getHeadersForLogging", () => ({ + getHeadersForLogging: jest.fn().mockReturnValue({ "mock-header": "mock-value" }), +})); +jest.mock("@src/utils/logger", () => ({ + logger: { + child: jest.fn().mockReturnValue({ + info: jest.fn(), + warn: jest.fn(), + error: jest.fn(), + debug: jest.fn(), + }), + }, + extractRootTraceIdFromAmznTraceId: jest.fn().mockReturnValue("mock-trace-id"), +})); jest.mock("next/headers", () => ({ headers: jest.fn(), @@ -34,13 +48,16 @@ const getMockRequest = (testUrl: string, params?: Record) => { ["X-Clacks-Overhead", "GNU Terry Pratchett"], ["referer", "testing"], ]); + const url = new URL(testUrl); return { nextUrl: { searchParams: new URLSearchParams(params), - origin: new URL(testUrl).origin, + origin: url.origin, href: testUrl, + pathname: url.pathname, }, + method: "GET", headers: headers, } as NextRequest; }; @@ -48,6 +65,8 @@ const getMockRequest = (testUrl: string, params?: Record) => { let responseCookies: ResponseCookies; describe("GET handler", () => { + const mockLogInfo = (logger.child as jest.Mock).mock.results[0].value.info as jest.Mock; + beforeEach(() => { jest.clearAllMocks(); jest.useFakeTimers().setSystemTime(mockNowTimeInSeconds * 1000); @@ -68,6 +87,26 @@ describe("GET handler", () => { jest.useRealTimers(); }); + it("logs method, pathname and headers when the route is invoked", async () => { + const testUrl = "https://testurl/api/sso"; + const mockRequest = getMockRequest(testUrl, { assertedLoginIdentity: "test-identity" }); + + (signIn as jest.Mock).mockResolvedValue("/some-url"); + + await GET(mockRequest); + + expect(mockLogInfo).toHaveBeenCalledWith( + expect.objectContaining({ + context: expect.objectContaining({ + method: "GET", + pathname: "/api/sso", + headers: { "mock-header": "mock-value" }, + }), + }), + "SSO route invoked", + ); + }); + it("redirects to sso-failure if assertedLoginIdentity parameter is missing", async () => { const testUrl = "https://testurl"; const mockRequest = getMockRequest(testUrl); diff --git a/src/app/api/sso/route.ts b/src/app/api/sso/route.ts index 1a33510ed..e9c83ac71 100644 --- a/src/app/api/sso/route.ts +++ b/src/app/api/sso/route.ts @@ -2,6 +2,7 @@ import { signIn } from "@project/auth"; import { NHS_LOGIN_PROVIDER_ID } from "@src/app/api/auth/[...nextauth]/provider"; import { SSO_FAILURE_ROUTE } from "@src/app/sso-failure/constants"; import config from "@src/utils/config"; +import { getHeadersForLogging } from "@src/utils/getHeadersForLogging"; import { logger } from "@src/utils/logger"; import { profilePerformanceEnd, profilePerformanceStart } from "@src/utils/performance"; import { RequestContext, asyncLocalStorage } from "@src/utils/requestContext"; @@ -22,7 +23,12 @@ export const GET = async (request: NextRequest) => { requestContext.nextUrl = request.nextUrl.pathname; await asyncLocalStorage.run(requestContext, async () => { - log.info("SSO route invoked"); + log.info( + { + context: { method: request.method, pathname: request.nextUrl.pathname, headers: getHeadersForLogging(request) }, + }, + "SSO route invoked", + ); const assertedLoginIdentity: string | null = request.nextUrl.searchParams.get(ASSERTED_LOGIN_IDENTITY_PARAM); const MAX_SESSION_AGE_MILLISECONDS: number = (await config.MAX_SESSION_AGE_MINUTES) * 60 * 1000; diff --git a/src/proxy.test.ts b/src/proxy.test.ts index 14a0a1493..b2b1a5592 100644 --- a/src/proxy.test.ts +++ b/src/proxy.test.ts @@ -3,7 +3,7 @@ */ import { auth } from "@project/auth"; import { unprotectedUrlPaths } from "@src/app/_components/inactivity/constants"; -import { _getHeadersForLogging, config, proxy } from "@src/proxy"; +import { config, proxy } from "@src/proxy"; import appConfig from "@src/utils/config"; import { SESSION_ID_COOKIE_NAME } from "@src/utils/constants"; import { ConfigMock, configBuilder } from "@test-data/config/builders"; @@ -77,7 +77,7 @@ describe("proxy", () => { expect(result.status).toBe(200); }); - it("pass the nextUrl to the request headers for users with active session", async () => { + it("adds the nextUrl to the request headers for users with active session", async () => { const testUrl = "https://testurl/abc"; const mockRequest = getMockRequest(testUrl); @@ -87,15 +87,6 @@ describe("proxy", () => { expect(result.headers.get("x-middleware-request-nexturl")).toEqual(testUrl); }); - it("_getHeadersForLogging() contains map of special headers", async () => { - const testUrl = "https://nhs-app-redirect-login-url/"; - const mockRequest = getMockRequest(testUrl); - - expect(_getHeadersForLogging(mockRequest as NextRequest)).toEqual( - expect.objectContaining({ referer: "testing", "user-agent": "-" }), - ); - }); - it.each(unprotectedUrlPaths)("is skipped for unprotected path %s", async (path: string) => { // verify the regex does not match unprotected paths expect(proxyRegex.test(path)).toBe(false); diff --git a/src/proxy.ts b/src/proxy.ts index 7acbcd5d0..3bba2816d 100644 --- a/src/proxy.ts +++ b/src/proxy.ts @@ -1,5 +1,6 @@ import { auth } from "@project/auth"; import appConfig from "@src/utils/config"; +import { getHeadersForLogging } from "@src/utils/getHeadersForLogging"; import { logger } from "@src/utils/logger"; import { profilePerformanceEnd, profilePerformanceStart } from "@src/utils/performance"; import { RequestContext, asyncLocalStorage } from "@src/utils/requestContext"; @@ -24,7 +25,7 @@ export async function proxy(request: NextRequest) { const middlewareWrapper = async (request: NextRequest) => { profilePerformanceStart(MiddlewarePerformanceMarker); log.info( - { context: { nextUrl: request.nextUrl.href, headers: _getHeadersForLogging(request) } }, + { context: { method: request.method, nextUrl: request.nextUrl.href, headers: getHeadersForLogging(request) } }, "Inspecting request", ); @@ -50,23 +51,6 @@ const middlewareWrapper = async (request: NextRequest) => { return response; }; -export const _getHeadersForLogging = (request: NextRequest) => { - const SAFE_HEADERS = [ - "cache-control", - "cloudfront-is-desktop-viewer", - "cloudfront-is-mobile-viewer", - "cloudfront-is-tablet-viewer", - "referer", - "user-agent", - ]; - const headersObj: Record = {}; - SAFE_HEADERS.forEach((header) => { - headersObj[header] = request.headers.get(header) ?? "-"; - }); - - return headersObj; -}; - export const config = { matcher: [ /* diff --git a/src/utils/getHeadersForLogging.test.ts b/src/utils/getHeadersForLogging.test.ts new file mode 100644 index 000000000..23501a2c0 --- /dev/null +++ b/src/utils/getHeadersForLogging.test.ts @@ -0,0 +1,61 @@ +import { getHeadersForLogging } from "@src/utils/getHeadersForLogging"; +import { NextRequest } from "next/server"; + +function buildMockRequest(headers: Record = {}): NextRequest { + return { headers: new Headers(Object.entries(headers)) } as unknown as NextRequest; +} + +describe("getHeadersForLogging", () => { + it("returns '-' for all safe headers when none are present", () => { + const req = buildMockRequest(); + + expect(getHeadersForLogging(req)).toEqual({ + "cache-control": "-", + "cloudfront-is-desktop-viewer": "-", + "cloudfront-is-mobile-viewer": "-", + "cloudfront-is-tablet-viewer": "-", + "content-length": "-", + "content-type": "-", + referer: "-", + "user-agent": "-", + }); + }); + + it("returns the value of all safe headers when all are present", () => { + const req = buildMockRequest({ + "cache-control": "no-cache", + "cloudfront-is-desktop-viewer": "true", + "cloudfront-is-mobile-viewer": "false", + "cloudfront-is-tablet-viewer": "false", + "content-length": "42", + "content-type": "application/x-www-form-urlencoded", + referer: "https://example.com/previous-page", + "user-agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7)", + }); + + expect(getHeadersForLogging(req)).toEqual({ + "cache-control": "no-cache", + "cloudfront-is-desktop-viewer": "true", + "cloudfront-is-mobile-viewer": "false", + "cloudfront-is-tablet-viewer": "false", + "content-length": "42", + "content-type": "application/x-www-form-urlencoded", + referer: "https://example.com/previous-page", + "user-agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7)", + }); + }); + + it("does not include headers outside the safe list", () => { + const req = buildMockRequest({ + authorization: "Bearer token", + "x-amzn-trace-id": "Root=1-abc", + cookie: "session=secret", + }); + + const result = getHeadersForLogging(req); + + expect(result).not.toHaveProperty("authorization"); + expect(result).not.toHaveProperty("x-amzn-trace-id"); + expect(result).not.toHaveProperty("cookie"); + }); +}); diff --git a/src/utils/getHeadersForLogging.ts b/src/utils/getHeadersForLogging.ts new file mode 100644 index 000000000..8206e8bde --- /dev/null +++ b/src/utils/getHeadersForLogging.ts @@ -0,0 +1,20 @@ +import { NextRequest } from "next/server"; + +const SAFE_HEADERS = [ + "cache-control", + "cloudfront-is-desktop-viewer", + "cloudfront-is-mobile-viewer", + "cloudfront-is-tablet-viewer", + "content-length", + "content-type", + "referer", + "user-agent", +]; + +export const getHeadersForLogging = (request: NextRequest): Record => { + const headersObj: Record = {}; + SAFE_HEADERS.forEach((header) => { + headersObj[header] = request.headers.get(header) ?? "-"; + }); + return headersObj; +};