diff --git a/src/commands/scan/create-scan-from-github.mts b/src/commands/scan/create-scan-from-github.mts index 2a5b55c55..58dd9baee 100644 --- a/src/commands/scan/create-scan-from-github.mts +++ b/src/commands/scan/create-scan-from-github.mts @@ -16,6 +16,7 @@ import { confirm, select } from '@socketsecurity/registry/lib/prompts' import { fetchSupportedScanFileNames } from './fetch-supported-scan-file-names.mts' import { handleCreateNewScan } from './handle-create-new-scan.mts' import constants from '../../constants.mts' +import { apiFetch } from '../../utils/api.mts' import { debugApiRequest, debugApiResponse } from '../../utils/debug.mts' import { formatErrorWithDetail } from '../../utils/errors.mts' import { isReportSupportedFile } from '../../utils/glob.mts' @@ -402,7 +403,7 @@ async function downloadManifestFile({ debugApiRequest('GET', fileUrl) let downloadUrlResponse: Response try { - downloadUrlResponse = await fetch(fileUrl, { + downloadUrlResponse = await apiFetch(fileUrl, { method: 'GET', headers: { Authorization: `Bearer ${githubToken}`, @@ -466,7 +467,7 @@ async function streamDownloadWithFetch( try { debugApiRequest('GET', downloadUrl) - response = await fetch(downloadUrl) + response = await apiFetch(downloadUrl) debugApiResponse('GET', downloadUrl, response.status) if (!response.ok) { @@ -567,7 +568,7 @@ async function getLastCommitDetails({ debugApiRequest('GET', commitApiUrl) let commitResponse: Response try { - commitResponse = await fetch(commitApiUrl, { + commitResponse = await apiFetch(commitApiUrl, { headers: { Authorization: `Bearer ${githubToken}`, }, @@ -679,7 +680,7 @@ async function getRepoDetails({ let repoDetailsResponse: Response try { debugApiRequest('GET', repoApiUrl) - repoDetailsResponse = await fetch(repoApiUrl, { + repoDetailsResponse = await apiFetch(repoApiUrl, { method: 'GET', headers: { Authorization: `Bearer ${githubToken}`, @@ -743,7 +744,7 @@ async function getRepoBranchTree({ let treeResponse: Response try { debugApiRequest('GET', treeApiUrl) - treeResponse = await fetch(treeApiUrl, { + treeResponse = await apiFetch(treeApiUrl, { method: 'GET', headers: { Authorization: `Bearer ${githubToken}`, diff --git a/src/utils/api.mts b/src/utils/api.mts index 77eaf8357..2a31f0f49 100644 --- a/src/utils/api.mts +++ b/src/utils/api.mts @@ -19,6 +19,8 @@ * - Falls back to configured apiBaseUrl or default API_V0_URL */ +import { Agent as HttpsAgent, request as httpsRequest } from 'node:https' + import { messageWithCauses } from 'pony-cause' import { debugDir, debugFn } from '@socketsecurity/registry/lib/debug' @@ -37,7 +39,7 @@ import constants, { HTTP_STATUS_UNAUTHORIZED, } from '../constants.mts' import { getRequirements, getRequirementsKey } from './requirements.mts' -import { getDefaultApiToken } from './sdk.mts' +import { getDefaultApiToken, getExtraCaCerts } from './sdk.mts' import type { CResult } from '../types.mts' import type { Spinner } from '@socketsecurity/registry/lib/spinner' @@ -48,8 +50,149 @@ import type { SocketSdkSuccessResult, } from '@socketsecurity/sdk' +const MAX_REDIRECTS = 20 const NO_ERROR_MESSAGE = 'No error message returned' +// Cached HTTPS agent for extra CA certificate support in direct API calls. +let _httpsAgent: HttpsAgent | undefined +let _httpsAgentResolved = false + +// Returns an HTTPS agent configured with extra CA certificates when +// SSL_CERT_FILE is set but NODE_EXTRA_CA_CERTS is not. +function getHttpsAgent(): HttpsAgent | undefined { + if (_httpsAgentResolved) { + return _httpsAgent + } + _httpsAgentResolved = true + const ca = getExtraCaCerts() + if (!ca) { + return undefined + } + _httpsAgent = new HttpsAgent({ ca }) + return _httpsAgent +} + +// Wrapper around fetch that supports extra CA certificates via SSL_CERT_FILE. +// Uses node:https.request with a custom agent when extra CA certs are needed, +// falling back to regular fetch() otherwise. Follows redirects like fetch(). +export type ApiFetchInit = { + body?: string | undefined + headers?: Record | undefined + method?: string | undefined +} + +// Internal httpsRequest-based fetch with redirect support. +function _httpsRequestFetch( + url: string, + init: ApiFetchInit, + agent: HttpsAgent, + redirectCount: number, +): Promise { + return new Promise((resolve, reject) => { + const headers: Record = { ...init.headers } + // Set Content-Length for request bodies to avoid chunked transfer encoding. + if (init.body) { + headers['content-length'] = String(Buffer.byteLength(init.body)) + } + const req = httpsRequest( + url, + { + method: init.method || 'GET', + headers, + agent, + }, + res => { + const { statusCode } = res + // Follow redirects to match fetch() behavior. + if ( + statusCode && + statusCode >= 300 && + statusCode < 400 && + res.headers['location'] + ) { + // Consume the response body to free up memory. + res.resume() + if (redirectCount >= MAX_REDIRECTS) { + reject(new Error('Maximum redirect limit reached')) + return + } + const redirectUrl = new URL(res.headers['location'], url).href + // Strip sensitive headers on cross-origin redirects to match + // fetch() behavior per the Fetch spec. + const originalOrigin = new URL(url).origin + const redirectOrigin = new URL(redirectUrl).origin + let redirectHeaders = init.headers + if (originalOrigin !== redirectOrigin && redirectHeaders) { + redirectHeaders = { ...redirectHeaders } + for (const key of Object.keys(redirectHeaders)) { + const lower = key.toLowerCase() + if ( + lower === 'authorization' || + lower === 'cookie' || + lower === 'proxy-authorization' + ) { + delete redirectHeaders[key] + } + } + } + // 307 and 308 preserve the original method and body. + const preserveMethod = statusCode === 307 || statusCode === 308 + resolve( + _httpsRequestFetch( + redirectUrl, + preserveMethod + ? { ...init, headers: redirectHeaders } + : { headers: redirectHeaders, method: 'GET' }, + agent, + redirectCount + 1, + ), + ) + return + } + const chunks: Buffer[] = [] + res.on('data', (chunk: Buffer) => chunks.push(chunk)) + res.on('end', () => { + const body = Buffer.concat(chunks) + const responseHeaders = new Headers() + for (const [key, value] of Object.entries(res.headers)) { + if (typeof value === 'string') { + responseHeaders.set(key, value) + } else if (Array.isArray(value)) { + for (const v of value) { + responseHeaders.append(key, v) + } + } + } + resolve( + new Response(body, { + status: statusCode ?? 0, + statusText: res.statusMessage ?? '', + headers: responseHeaders, + }), + ) + }) + res.on('error', reject) + }, + ) + if (init.body) { + req.write(init.body) + } + req.on('error', reject) + req.end() + }) +} + +export async function apiFetch( + url: string, + init: ApiFetchInit = {}, +): Promise { + const agent = getHttpsAgent() + if (!agent) { + return await fetch(url, init as globalThis.RequestInit) + } + return await _httpsRequestFetch(url, init, agent, 0) +} + export type CommandRequirements = { permissions?: string[] | undefined quota?: number | undefined @@ -287,7 +430,7 @@ async function queryApi(path: string, apiToken: string) { } const url = `${baseUrl}${baseUrl.endsWith('/') ? '' : '/'}${path}` - const result = await fetch(url, { + const result = await apiFetch(url, { method: 'GET', headers: { Authorization: `Basic ${btoa(`${apiToken}:`)}`, @@ -480,7 +623,7 @@ export async function sendApiRequest( ...(body ? { body: JSON.stringify(body) } : {}), } - result = await fetch( + result = await apiFetch( `${baseUrl}${baseUrl.endsWith('/') ? '' : '/'}${path}`, fetchOptions, ) diff --git a/src/utils/api.test.mts b/src/utils/api.test.mts new file mode 100644 index 000000000..22146a2df --- /dev/null +++ b/src/utils/api.test.mts @@ -0,0 +1,585 @@ +/** + * Unit tests for API utilities with extra CA certificate support. + * + * Purpose: + * Tests the apiFetch wrapper that enables SSL_CERT_FILE support for + * direct API calls when NODE_EXTRA_CA_CERTS is not set at process startup. + * + * Test Coverage: + * - apiFetch falls back to regular fetch when no extra CA certs are needed. + * - apiFetch uses node:https.request with custom agent when CA certs are set. + * - Response object construction from https.request output. + * - POST requests with JSON body through https.request path. + * - Error propagation from https.request failures. + * + * Testing Approach: + * Mocks node:https, node:fs, node:tls, and the SDK module to test the + * apiFetch behavior in isolation without network calls. + * + * Related Files: + * - utils/api.mts (implementation) + * - utils/sdk.mts (getExtraCaCerts) + */ + +import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest' + +import type { EventEmitter } from 'node:events' + +// Mock getExtraCaCerts from sdk.mts. +const mockGetExtraCaCerts = vi.hoisted(() => vi.fn(() => undefined)) +const mockGetDefaultApiToken = vi.hoisted(() => vi.fn(() => 'test-api-token')) +vi.mock('./sdk.mts', () => ({ + getDefaultApiToken: mockGetDefaultApiToken, + getDefaultApiBaseUrl: vi.fn(() => undefined), + getDefaultProxyUrl: vi.fn(() => undefined), + getExtraCaCerts: mockGetExtraCaCerts, +})) + +// Mock node:https request function. +type RequestCallback = ( + res: EventEmitter & { + statusCode?: number + statusMessage?: string + headers: Record + }, +) => void +const mockHttpsRequest = vi.hoisted(() => vi.fn()) +const MockHttpsAgent = vi.hoisted(() => + vi.fn().mockImplementation(opts => ({ ...opts, _isHttpsAgent: true })), +) +vi.mock('node:https', () => ({ + Agent: MockHttpsAgent, + request: mockHttpsRequest, +})) + +// Mock constants. +vi.mock('../constants.mts', () => ({ + default: { + API_V0_URL: 'https://api.socket.dev/v0/', + ENV: { + NODE_EXTRA_CA_CERTS: '', + SOCKET_CLI_API_BASE_URL: 'https://api.socket.dev/v0/', + SOCKET_CLI_API_TIMEOUT: 30_000, + }, + spinner: { + failAndStop: vi.fn(), + start: vi.fn(), + stop: vi.fn(), + successAndStop: vi.fn(), + }, + }, + CONFIG_KEY_API_BASE_URL: 'apiBaseUrl', + EMPTY_VALUE: '', + HTTP_STATUS_BAD_REQUEST: 400, + HTTP_STATUS_FORBIDDEN: 403, + HTTP_STATUS_INTERNAL_SERVER_ERROR: 500, + HTTP_STATUS_NOT_FOUND: 404, + HTTP_STATUS_UNAUTHORIZED: 401, +})) + +// Mock config. +vi.mock('./config.mts', () => ({ + getConfigValueOrUndef: vi.fn(() => undefined), +})) + +// Mock debug functions. +vi.mock('./debug.mts', () => ({ + debugApiRequest: vi.fn(), + debugApiResponse: vi.fn(), +})) + +// Mock requirements. +vi.mock('./requirements.mts', () => ({ + getRequirements: vi.fn(() => ({ api: {} })), + getRequirementsKey: vi.fn(() => ''), +})) + +// Mock telemetry. +vi.mock('./telemetry/integration.mts', () => ({ + trackCliEvent: vi.fn(), +})) + +// Store original fetch for restoration. +const originalFetch = globalThis.fetch + +describe('apiFetch with extra CA certificates', () => { + beforeEach(() => { + vi.clearAllMocks() + vi.resetModules() + mockGetExtraCaCerts.mockReturnValue(undefined) + }) + + afterEach(() => { + globalThis.fetch = originalFetch + }) + + it('should use regular fetch when no extra CA certs are needed', async () => { + const mockResponse = new Response(JSON.stringify({ ok: true }), { + status: 200, + statusText: 'OK', + }) + globalThis.fetch = vi.fn().mockResolvedValue(mockResponse) + + const { queryApiSafeText } = await import('./api.mts') + const result = await queryApiSafeText('test/path', 'test request') + + expect(globalThis.fetch).toHaveBeenCalled() + expect(mockHttpsRequest).not.toHaveBeenCalled() + }) + + it('should use https.request when extra CA certs are available', async () => { + const caCerts = ['ROOT_CERT', 'EXTRA_CERT'] + mockGetExtraCaCerts.mockReturnValue(caCerts) + + // Create a mock request object that simulates node:https.request. + const mockReq = { + end: vi.fn(), + on: vi.fn(), + write: vi.fn(), + } + + mockHttpsRequest.mockImplementation( + (_url: string, _opts: unknown, callback: RequestCallback) => { + // Simulate an async response. + setTimeout(() => { + const mockRes = { + headers: { 'content-type': 'text/plain' }, + on: vi.fn(), + statusCode: 200, + statusMessage: 'OK', + } + // Capture data and end handlers. + const handlers: Record = {} + mockRes.on.mockImplementation((event: string, handler: Function) => { + handlers[event] = handler + return mockRes + }) + callback(mockRes) + // Emit data and end events. + handlers['data']?.(Buffer.from('response body')) + handlers['end']?.() + }, 0) + return mockReq + }, + ) + + const { queryApiSafeText } = await import('./api.mts') + const result = await queryApiSafeText('test/path', 'test request') + + expect(mockHttpsRequest).toHaveBeenCalled() + // Verify the agent was created with CA certs. + expect(MockHttpsAgent).toHaveBeenCalledWith({ ca: caCerts }) + // Verify the request was made with the agent. + const callArgs = mockHttpsRequest.mock.calls[0] + expect(callArgs[1]).toEqual( + expect.objectContaining({ + agent: expect.objectContaining({ ca: caCerts }), + method: 'GET', + }), + ) + }) + + it('should construct valid Response from https.request output', async () => { + const caCerts = ['ROOT_CERT', 'EXTRA_CERT'] + mockGetExtraCaCerts.mockReturnValue(caCerts) + + const responseBody = JSON.stringify({ data: 'test-value' }) + const mockReq = { + end: vi.fn(), + on: vi.fn(), + write: vi.fn(), + } + + mockHttpsRequest.mockImplementation( + (_url: string, _opts: unknown, callback: RequestCallback) => { + setTimeout(() => { + const mockRes = { + headers: { 'content-type': 'application/json' }, + on: vi.fn(), + statusCode: 200, + statusMessage: 'OK', + } + const handlers: Record = {} + mockRes.on.mockImplementation((event: string, handler: Function) => { + handlers[event] = handler + return mockRes + }) + callback(mockRes) + handlers['data']?.(Buffer.from(responseBody)) + handlers['end']?.() + }, 0) + return mockReq + }, + ) + + const { queryApiSafeText } = await import('./api.mts') + const result = await queryApiSafeText('test/path', 'test request') + + expect(result.ok).toBe(true) + if (result.ok) { + expect(result.data).toBe(responseBody) + } + }) + + it('should handle https.request errors gracefully', async () => { + const caCerts = ['ROOT_CERT', 'EXTRA_CERT'] + mockGetExtraCaCerts.mockReturnValue(caCerts) + + const mockReq = { + end: vi.fn(), + on: vi.fn(), + write: vi.fn(), + } + + // Simulate a connection error. + mockHttpsRequest.mockImplementation(() => { + // Capture the error handler and trigger it. + const handlers: Record = {} + mockReq.on.mockImplementation((event: string, handler: Function) => { + handlers[event] = handler + return mockReq + }) + setTimeout(() => { + handlers['error']?.(new Error('unable to get local issuer certificate')) + }, 0) + return mockReq + }) + + const { queryApiSafeText } = await import('./api.mts') + const result = await queryApiSafeText('test/path', 'test request') + + expect(result.ok).toBe(false) + if (!result.ok) { + expect(result.cause).toContain('unable to get local issuer certificate') + } + }) + + it('should pass request body for POST requests through https.request', async () => { + const caCerts = ['ROOT_CERT', 'EXTRA_CERT'] + mockGetExtraCaCerts.mockReturnValue(caCerts) + + const mockReq = { + end: vi.fn(), + on: vi.fn(), + write: vi.fn(), + } + + mockHttpsRequest.mockImplementation( + (_url: string, _opts: unknown, callback: RequestCallback) => { + setTimeout(() => { + const mockRes = { + headers: { 'content-type': 'application/json' }, + on: vi.fn(), + statusCode: 200, + statusMessage: 'OK', + } + const handlers: Record = {} + mockRes.on.mockImplementation((event: string, handler: Function) => { + handlers[event] = handler + return mockRes + }) + callback(mockRes) + handlers['data']?.(Buffer.from('{"result":"ok"}')) + handlers['end']?.() + }, 0) + return mockReq + }, + ) + + const { sendApiRequest } = await import('./api.mts') + const result = await sendApiRequest('test/path', { + body: { key: 'value' }, + method: 'POST', + }) + + // Verify body was written to the request. + expect(mockReq.write).toHaveBeenCalledWith('{"key":"value"}') + expect(result.ok).toBe(true) + }) + + it('should handle multi-value response headers from https.request', async () => { + const caCerts = ['ROOT_CERT', 'EXTRA_CERT'] + mockGetExtraCaCerts.mockReturnValue(caCerts) + + const mockReq = { + end: vi.fn(), + on: vi.fn(), + write: vi.fn(), + } + + mockHttpsRequest.mockImplementation( + (_url: string, _opts: unknown, callback: RequestCallback) => { + setTimeout(() => { + const mockRes = { + headers: { + 'content-type': 'text/plain', + 'set-cookie': ['a=1', 'b=2'], + }, + on: vi.fn(), + statusCode: 200, + statusMessage: 'OK', + } + const handlers: Record = {} + mockRes.on.mockImplementation((event: string, handler: Function) => { + handlers[event] = handler + return mockRes + }) + callback(mockRes as any) + handlers['data']?.(Buffer.from('ok')) + handlers['end']?.() + }, 0) + return mockReq + }, + ) + + const { queryApiSafeText } = await import('./api.mts') + const result = await queryApiSafeText('test/path') + + expect(result.ok).toBe(true) + }) + + it('should set Content-Length header for POST requests through https.request', async () => { + const caCerts = ['ROOT_CERT', 'EXTRA_CERT'] + mockGetExtraCaCerts.mockReturnValue(caCerts) + + const mockReq = { + end: vi.fn(), + on: vi.fn(), + write: vi.fn(), + } + + mockHttpsRequest.mockImplementation( + (_url: string, _opts: unknown, callback: RequestCallback) => { + setTimeout(() => { + const mockRes = { + headers: { 'content-type': 'application/json' }, + on: vi.fn(), + statusCode: 200, + statusMessage: 'OK', + } + const handlers: Record = {} + mockRes.on.mockImplementation((event: string, handler: Function) => { + handlers[event] = handler + return mockRes + }) + callback(mockRes) + handlers['data']?.(Buffer.from('{"result":"ok"}')) + handlers['end']?.() + }, 0) + return mockReq + }, + ) + + const { sendApiRequest } = await import('./api.mts') + await sendApiRequest('test/path', { + body: { key: 'value' }, + method: 'POST', + }) + + // Verify Content-Length header was set in the request options. + const callArgs = mockHttpsRequest.mock.calls[0] + const requestHeaders = callArgs[1].headers + expect(requestHeaders['content-length']).toBe( + String(Buffer.byteLength('{"key":"value"}')), + ) + }) + + it('should follow redirects in https.request path', async () => { + const caCerts = ['ROOT_CERT', 'EXTRA_CERT'] + mockGetExtraCaCerts.mockReturnValue(caCerts) + + const mockReq = { + end: vi.fn(), + on: vi.fn(), + write: vi.fn(), + } + + // First call: return a 302 redirect. + mockHttpsRequest.mockImplementationOnce( + (_url: string, _opts: unknown, callback: RequestCallback) => { + setTimeout(() => { + const mockRes = { + headers: { location: 'https://api.socket.dev/v0/redirected' }, + on: vi.fn(), + resume: vi.fn(), + statusCode: 302, + statusMessage: 'Found', + } + callback(mockRes as any) + }, 0) + return mockReq + }, + ) + + // Second call: return the actual response. + mockHttpsRequest.mockImplementationOnce( + (_url: string, _opts: unknown, callback: RequestCallback) => { + setTimeout(() => { + const mockRes = { + headers: { 'content-type': 'text/plain' }, + on: vi.fn(), + statusCode: 200, + statusMessage: 'OK', + } + const handlers: Record = {} + mockRes.on.mockImplementation((event: string, handler: Function) => { + handlers[event] = handler + return mockRes + }) + callback(mockRes) + handlers['data']?.(Buffer.from('redirected response')) + handlers['end']?.() + }, 0) + return mockReq + }, + ) + + const { queryApiSafeText } = await import('./api.mts') + const result = await queryApiSafeText('test/path', 'test request') + + // Should have made two https requests: original and redirect. + expect(mockHttpsRequest).toHaveBeenCalledTimes(2) + // Second call should be to the redirected URL. + expect(mockHttpsRequest.mock.calls[1][0]).toBe( + 'https://api.socket.dev/v0/redirected', + ) + expect(result.ok).toBe(true) + if (result.ok) { + expect(result.data).toBe('redirected response') + } + }) + + it('should strip Authorization header on cross-origin redirects', async () => { + const caCerts = ['ROOT_CERT', 'EXTRA_CERT'] + mockGetExtraCaCerts.mockReturnValue(caCerts) + + const mockReq = { + end: vi.fn(), + on: vi.fn(), + write: vi.fn(), + } + + // First call: return a 302 redirect to a different origin. + mockHttpsRequest.mockImplementationOnce( + (_url: string, _opts: unknown, callback: RequestCallback) => { + setTimeout(() => { + const mockRes = { + headers: { location: 'https://cdn.example.com/file' }, + on: vi.fn(), + resume: vi.fn(), + statusCode: 302, + statusMessage: 'Found', + } + callback(mockRes as any) + }, 0) + return mockReq + }, + ) + + // Second call: return the actual response from the CDN. + mockHttpsRequest.mockImplementationOnce( + (_url: string, _opts: unknown, callback: RequestCallback) => { + setTimeout(() => { + const mockRes = { + headers: { 'content-type': 'text/plain' }, + on: vi.fn(), + statusCode: 200, + statusMessage: 'OK', + } + const handlers: Record = {} + mockRes.on.mockImplementation((event: string, handler: Function) => { + handlers[event] = handler + return mockRes + }) + callback(mockRes) + handlers['data']?.(Buffer.from('cdn response')) + handlers['end']?.() + }, 0) + return mockReq + }, + ) + + const { apiFetch } = await import('./api.mts') + await apiFetch('https://api.github.com/repos/test/contents', { + headers: { + Authorization: 'Bearer ghp_secret', + 'Content-Type': 'application/json', + }, + }) + + // First request should have Authorization header. + const firstCallHeaders = (mockHttpsRequest.mock.calls[0][1] as any).headers + expect(firstCallHeaders['Authorization']).toBe('Bearer ghp_secret') + + // Second request (cross-origin redirect) should NOT have Authorization. + const secondCallHeaders = (mockHttpsRequest.mock.calls[1][1] as any).headers + expect(secondCallHeaders['Authorization']).toBeUndefined() + // Non-sensitive headers should still be present. + expect(secondCallHeaders['Content-Type']).toBe('application/json') + }) + + it('should preserve Authorization header on same-origin redirects', async () => { + const caCerts = ['ROOT_CERT', 'EXTRA_CERT'] + mockGetExtraCaCerts.mockReturnValue(caCerts) + + const mockReq = { + end: vi.fn(), + on: vi.fn(), + write: vi.fn(), + } + + // First call: return a 302 redirect to the same origin. + mockHttpsRequest.mockImplementationOnce( + (_url: string, _opts: unknown, callback: RequestCallback) => { + setTimeout(() => { + const mockRes = { + headers: { location: '/v0/other-path' }, + on: vi.fn(), + resume: vi.fn(), + statusCode: 302, + statusMessage: 'Found', + } + callback(mockRes as any) + }, 0) + return mockReq + }, + ) + + // Second call: return the actual response. + mockHttpsRequest.mockImplementationOnce( + (_url: string, _opts: unknown, callback: RequestCallback) => { + setTimeout(() => { + const mockRes = { + headers: { 'content-type': 'text/plain' }, + on: vi.fn(), + statusCode: 200, + statusMessage: 'OK', + } + const handlers: Record = {} + mockRes.on.mockImplementation((event: string, handler: Function) => { + handlers[event] = handler + return mockRes + }) + callback(mockRes) + handlers['data']?.(Buffer.from('same-origin response')) + handlers['end']?.() + }, 0) + return mockReq + }, + ) + + const { apiFetch } = await import('./api.mts') + await apiFetch('https://api.github.com/repos/test/contents', { + headers: { + Authorization: 'Bearer ghp_secret', + }, + }) + + // Both requests should have Authorization since same origin. + const firstCallHeaders = (mockHttpsRequest.mock.calls[0][1] as any).headers + expect(firstCallHeaders['Authorization']).toBe('Bearer ghp_secret') + + const secondCallHeaders = (mockHttpsRequest.mock.calls[1][1] as any).headers + expect(secondCallHeaders['Authorization']).toBe('Bearer ghp_secret') + }) +}) diff --git a/src/utils/dlx-binary.mts b/src/utils/dlx-binary.mts index eaf309838..9b55fc9fa 100644 --- a/src/utils/dlx-binary.mts +++ b/src/utils/dlx-binary.mts @@ -31,6 +31,7 @@ import { readJson } from '@socketsecurity/registry/lib/fs' import { spawn } from '@socketsecurity/registry/lib/spawn' import constants from '../constants.mts' +import { apiFetch } from './api.mts' import { InputError } from './errors.mts' import type { @@ -117,7 +118,7 @@ async function downloadBinary( destPath: string, checksum?: string, ): Promise { - const response = await fetch(url) + const response = await apiFetch(url) if (!response.ok) { throw new InputError( `Failed to download binary: ${response.status} ${response.statusText}`, diff --git a/src/utils/sdk.mts b/src/utils/sdk.mts index e0c749f5d..ac72b0dc2 100644 --- a/src/utils/sdk.mts +++ b/src/utils/sdk.mts @@ -24,9 +24,14 @@ * - Includes CLI version and platform information */ +import { readFileSync } from 'node:fs' +import { Agent as HttpsAgent } from 'node:https' +import { rootCertificates } from 'node:tls' + import { HttpProxyAgent, HttpsProxyAgent } from 'hpagent' import isInteractive from '@socketregistry/is-interactive/index.cjs' +import { debugFn } from '@socketsecurity/registry/lib/debug' import { logger } from '@socketsecurity/registry/lib/logger' import { password } from '@socketsecurity/registry/lib/prompts' import { isNonEmptyString } from '@socketsecurity/registry/lib/strings' @@ -65,6 +70,41 @@ export function getDefaultProxyUrl(): string | undefined { return isUrl(apiProxy) ? apiProxy : undefined } +// Cached extra CA certificates for SSL_CERT_FILE support. +let _extraCaCerts: string[] | undefined +let _extraCaCertsResolved = false + +// Returns combined root and extra CA certificates when SSL_CERT_FILE is set +// but NODE_EXTRA_CA_CERTS is not. Node.js loads NODE_EXTRA_CA_CERTS at process +// startup, so setting SSL_CERT_FILE alone does not affect the current process. +// This function reads the certificate file manually and combines it with the +// default root certificates for use in HTTPS agents. +export function getExtraCaCerts(): string[] | undefined { + if (_extraCaCertsResolved) { + return _extraCaCerts + } + _extraCaCertsResolved = true + // Node.js already loaded extra CA certs at startup. + if (process.env['NODE_EXTRA_CA_CERTS']) { + return undefined + } + // Check for SSL_CERT_FILE fallback via constants. + const certPath = constants.ENV.NODE_EXTRA_CA_CERTS + if (!certPath) { + return undefined + } + try { + const extraCerts = readFileSync(certPath, 'utf-8') + // Combine default root certificates with extra certificates. Specifying ca + // in an agent replaces the default trust store, so both must be included. + _extraCaCerts = [...rootCertificates, extraCerts] + return _extraCaCerts + } catch (e) { + debugFn('warn', `Failed to read certificate file: ${certPath}`, e) + return undefined + } +} + // This Socket API token should be stored globally for the duration of the CLI execution. let _defaultToken: string | undefined @@ -146,8 +186,21 @@ export async function setupSdk( ? HttpProxyAgent : HttpsProxyAgent + // Load extra CA certificates for SSL_CERT_FILE support when + // NODE_EXTRA_CA_CERTS was not set at process startup. + const ca = getExtraCaCerts() + const sdkOptions = { - ...(apiProxy ? { agent: new ProxyAgent({ proxy: apiProxy }) } : {}), + ...(apiProxy + ? { + agent: new ProxyAgent({ + proxy: apiProxy, + ...(ca ? { ca, proxyConnectOptions: { ca } } : {}), + }), + } + : ca + ? { agent: new HttpsAgent({ ca }) } + : {}), ...(apiBaseUrl ? { baseUrl: apiBaseUrl } : {}), timeout: constants.ENV.SOCKET_CLI_API_TIMEOUT, userAgent: createUserAgentFromPkgJson({ diff --git a/src/utils/sdk.test.mts b/src/utils/sdk.test.mts index f8278c214..ea8cd7e08 100644 --- a/src/utils/sdk.test.mts +++ b/src/utils/sdk.test.mts @@ -22,13 +22,47 @@ * - utils/telemetry/integration.mts (telemetry tracking) */ -import { beforeEach, describe, expect, it, vi } from 'vitest' +import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest' import constants from '../constants.mts' -import { setupSdk } from './sdk.mts' +import { getExtraCaCerts, setupSdk } from './sdk.mts' import type { RequestInfo, ResponseInfo } from '@socketsecurity/sdk' +// Mock node:fs for certificate file reading. +const mockReadFileSync = vi.hoisted(() => vi.fn()) +vi.mock('node:fs', () => ({ + readFileSync: mockReadFileSync, +})) + +// Mock node:tls for root certificates. +const MOCK_ROOT_CERTS = vi.hoisted(() => [ + '-----BEGIN CERTIFICATE-----\nROOT1\n-----END CERTIFICATE-----', +]) +vi.mock('node:tls', () => ({ + rootCertificates: MOCK_ROOT_CERTS, +})) + +// Mock node:https for HttpsAgent. +const MockHttpsAgent = vi.hoisted(() => + vi.fn().mockImplementation(opts => ({ ...opts, _isHttpsAgent: true })), +) +vi.mock('node:https', () => ({ + Agent: MockHttpsAgent, +})) + +// Mock hpagent proxy agents. +const MockHttpProxyAgent = vi.hoisted(() => + vi.fn().mockImplementation(opts => ({ ...opts, _isHttpProxyAgent: true })), +) +const MockHttpsProxyAgent = vi.hoisted(() => + vi.fn().mockImplementation(opts => ({ ...opts, _isHttpsProxyAgent: true })), +) +vi.mock('hpagent', () => ({ + HttpProxyAgent: MockHttpProxyAgent, + HttpsProxyAgent: MockHttpsProxyAgent, +})) + // Mock telemetry integration. const mockTrackCliEvent = vi.hoisted(() => vi.fn()) vi.mock('./telemetry/integration.mts', () => ({ @@ -43,6 +77,12 @@ vi.mock('./debug.mts', () => ({ debugApiResponse: mockDebugApiResponse, })) +// Mock registry debug functions used by getExtraCaCerts. +vi.mock('@socketsecurity/registry/lib/debug', () => ({ + debugDir: vi.fn(), + debugFn: vi.fn(), +})) + // Mock config. const mockGetConfigValueOrUndef = vi.hoisted(() => vi.fn(() => undefined)) vi.mock('./config.mts', () => ({ @@ -73,6 +113,7 @@ vi.mock('../constants.mts', () => ({ INLINED_SOCKET_CLI_HOMEPAGE: 'https://github.com/SocketDev/socket-cli', INLINED_SOCKET_CLI_NAME: 'socket-cli', INLINED_SOCKET_CLI_VERSION: '1.1.34', + NODE_EXTRA_CA_CERTS: '', SOCKET_CLI_API_TIMEOUT: 30_000, SOCKET_CLI_DEBUG: false, }, @@ -515,3 +556,161 @@ describe('SDK setup with telemetry hooks', () => { }) }) }) + +describe('getExtraCaCerts', () => { + const savedEnv = { ...process.env } + + beforeEach(() => { + vi.clearAllMocks() + // Reset the cached state by re-importing. Since vitest caches modules, + // we reset the internal state via the resolved flag workaround: calling + // the function after resetting module-level state is not possible without + // re-import, so we use resetModules. + vi.resetModules() + // Restore environment variables. + process.env = { ...savedEnv } + delete process.env['NODE_EXTRA_CA_CERTS'] + delete process.env['SSL_CERT_FILE'] + constants.ENV.NODE_EXTRA_CA_CERTS = '' + }) + + afterEach(() => { + process.env = savedEnv + }) + + it('should return undefined when no cert env vars are set', async () => { + const { getExtraCaCerts: fn } = await import('./sdk.mts') + const result = fn() + expect(result).toBeUndefined() + }) + + it('should return undefined when NODE_EXTRA_CA_CERTS is set in process.env', async () => { + process.env['NODE_EXTRA_CA_CERTS'] = '/some/cert.pem' + const { getExtraCaCerts: fn } = await import('./sdk.mts') + const result = fn() + expect(result).toBeUndefined() + // Should not attempt to read the file. + expect(mockReadFileSync).not.toHaveBeenCalled() + }) + + it('should read cert file and combine with root certs when SSL_CERT_FILE is set', async () => { + const fakePEM = + '-----BEGIN CERTIFICATE-----\nEXTRA\n-----END CERTIFICATE-----' + constants.ENV.NODE_EXTRA_CA_CERTS = '/path/to/ssl-cert.pem' + mockReadFileSync.mockReturnValue(fakePEM) + + const { getExtraCaCerts: fn } = await import('./sdk.mts') + const result = fn() + + expect(mockReadFileSync).toHaveBeenCalledWith( + '/path/to/ssl-cert.pem', + 'utf-8', + ) + expect(result).toEqual([...MOCK_ROOT_CERTS, fakePEM]) + }) + + it('should return undefined when cert file does not exist', async () => { + constants.ENV.NODE_EXTRA_CA_CERTS = '/nonexistent/cert.pem' + mockReadFileSync.mockImplementation(() => { + throw new Error('ENOENT: no such file or directory') + }) + + const { getExtraCaCerts: fn } = await import('./sdk.mts') + const result = fn() + + expect(mockReadFileSync).toHaveBeenCalled() + expect(result).toBeUndefined() + }) + + it('should cache the result after first call', async () => { + const fakePEM = + '-----BEGIN CERTIFICATE-----\nCACHED\n-----END CERTIFICATE-----' + constants.ENV.NODE_EXTRA_CA_CERTS = '/path/to/cert.pem' + mockReadFileSync.mockReturnValue(fakePEM) + + const { getExtraCaCerts: fn } = await import('./sdk.mts') + const result1 = fn() + const result2 = fn() + + // File should only be read once. + expect(mockReadFileSync).toHaveBeenCalledTimes(1) + expect(result1).toBe(result2) + }) +}) + +describe('setupSdk with extra CA certificates', () => { + const savedEnv = { ...process.env } + + beforeEach(() => { + vi.clearAllMocks() + vi.resetModules() + mockGetConfigValueOrUndef.mockReturnValue(undefined) + constants.ENV.SOCKET_CLI_DEBUG = false + constants.ENV.NODE_EXTRA_CA_CERTS = '' + process.env = { ...savedEnv } + delete process.env['NODE_EXTRA_CA_CERTS'] + }) + + afterEach(() => { + process.env = savedEnv + }) + + it('should pass CA certs to HttpsAgent when SSL_CERT_FILE is configured', async () => { + const fakePEM = + '-----BEGIN CERTIFICATE-----\nAGENT\n-----END CERTIFICATE-----' + constants.ENV.NODE_EXTRA_CA_CERTS = '/path/to/cert.pem' + mockReadFileSync.mockReturnValue(fakePEM) + + const { setupSdk: fn } = await import('./sdk.mts') + const result = await fn({ apiToken: 'test-token' }) + + expect(result.ok).toBe(true) + if (result.ok) { + // Should create an HttpsAgent with combined CA certs. + expect(result.data.options.agent).toBeDefined() + expect(MockHttpsAgent).toHaveBeenCalledWith({ + ca: [...MOCK_ROOT_CERTS, fakePEM], + }) + } + }) + + it('should pass CA certs to proxy agent when proxy and SSL_CERT_FILE are configured', async () => { + const fakePEM = + '-----BEGIN CERTIFICATE-----\nPROXY\n-----END CERTIFICATE-----' + constants.ENV.NODE_EXTRA_CA_CERTS = '/path/to/cert.pem' + mockReadFileSync.mockReturnValue(fakePEM) + + const expectedCa = [...MOCK_ROOT_CERTS, fakePEM] + const { setupSdk: fn } = await import('./sdk.mts') + const result = await fn({ + apiProxy: 'http://proxy.example.com:8080', + apiToken: 'test-token', + }) + + expect(result.ok).toBe(true) + if (result.ok) { + expect(result.data.options.agent).toBeDefined() + // Verify the proxy agent was constructed with CA and proxy connect options. + expect(MockSocketSdk).toHaveBeenCalledWith( + 'test-token', + expect.objectContaining({ + agent: expect.objectContaining({ + ca: expectedCa, + proxyConnectOptions: { ca: expectedCa }, + }), + }), + ) + } + }) + + it('should not create agent when no extra CA certs are needed', async () => { + const { setupSdk: fn } = await import('./sdk.mts') + const result = await fn({ apiToken: 'test-token' }) + + expect(result.ok).toBe(true) + if (result.ok) { + expect(result.data.options.agent).toBeUndefined() + expect(MockHttpsAgent).not.toHaveBeenCalled() + } + }) +})