diff --git a/packages/middleware-sdk-transcribe-streaming/src/websocket-handler.spec.ts b/packages/middleware-sdk-transcribe-streaming/src/websocket-handler.spec.ts index ae54530112dc..df6cd3f41436 100644 --- a/packages/middleware-sdk-transcribe-streaming/src/websocket-handler.spec.ts +++ b/packages/middleware-sdk-transcribe-streaming/src/websocket-handler.spec.ts @@ -6,35 +6,98 @@ import { PassThrough } from "stream"; import { WebSocketHandler } from "./websocket-handler"; describe("WebSocketHandler", () => { + const mockHostname = "localhost:6789"; + const mockUrl = `ws://${mockHostname}/`; + beforeEach(() => { (global as any).WebSocket = WebSocket; }); + afterEach(() => { WS.clean(); jest.clearAllMocks(); }); + it("should contain protocol metadata", () => { const handler = new WebSocketHandler(); expect(handler.metadata.handlerProtocol).toEqual("websocket"); }); + it("populates socket in socket pool based on handle() requests", async () => { + const handler = new WebSocketHandler(); + const server = new WS(mockUrl); + + // @ts-expect-error Property 'sockets' is private and only accessible within class 'WebSocketHandler'. + expect(handler.sockets[mockUrl]).not.toBeDefined(); + + await handler.handle( + new HttpRequest({ + body: new PassThrough(), + hostname: mockHostname, + protocol: "ws:", + }) + ); + + // @ts-expect-error Property 'sockets' is private and only accessible within class 'WebSocketHandler'. + expect(handler.sockets[mockUrl]).toBeDefined(); + // @ts-expect-error Property 'sockets' is private and only accessible within class 'WebSocketHandler'. + expect(handler.sockets[mockUrl].length).toBe(1); + + await handler.handle( + new HttpRequest({ + body: new PassThrough(), + hostname: mockHostname, + protocol: "ws:", + }) + ); + + // @ts-expect-error Property 'sockets' is private and only accessible within class 'WebSocketHandler'. + expect(handler.sockets[mockUrl].length).toBe(2); + }); + + it("closes socket in socket pool on handler.destroy()", async () => { + const handler = new WebSocketHandler(); + const server = new WS(mockUrl); + + await handler.handle( + new HttpRequest({ + body: new PassThrough(), + hostname: mockHostname, + protocol: "ws:", + }) + ); + + // @ts-expect-error Property 'sockets' is private and only accessible within class 'WebSocketHandler'. + const socket = handler.sockets[mockUrl][0]; + + expect(socket.readyState).toBe(WebSocket.OPEN); + handler.destroy(); + + // Verify that socket.close() is called + expect(socket.readyState).toBe(WebSocket.CLOSING); + }); + it("should throw in output stream if input stream throws", async () => { - expect.assertions(2); + expect.assertions(3); const handler = new WebSocketHandler(); //Using Node stream is fine because they are also async iterables. const payload = new PassThrough(); - const server = new WS("ws://localhost:6789"); + + const server = new WS(mockUrl); + const { response: { body: responsePayload }, } = await handler.handle( new HttpRequest({ body: payload, - hostname: "localhost:6789", + hostname: mockHostname, protocol: "ws:", }) ); + await server.connected; payload.emit("error", new Error("FakeError")); + try { // eslint-disable-next-line @typescript-eslint/no-unused-vars for await (const chunk of responsePayload) { @@ -43,22 +106,30 @@ describe("WebSocketHandler", () => { } catch (err) { expect(err).toBeDefined(); expect(err.message).toEqual("FakeError"); + // @ts-expect-error Property 'sockets' is private and only accessible within class 'WebSocketHandler'. + expect(handler.sockets[mockUrl].length).toBe(0); } }); it("should return retryable error if cannot setup ws connection", async () => { - expect.assertions(4); + expect.assertions(5); + const originalFn = setTimeout; (global as any).setTimeout = jest.fn().mockImplementation(setTimeout); + const connectionTimeout = 1000; const handler = new WebSocketHandler({ connectionTimeout }); + //Using Node stream is fine because they are also async iterables. const payload = new PassThrough(); + const mockInvalidHostname = "localhost:9876"; + const mockInvalidUrl = `ws://${mockInvalidHostname}/`; + try { await handler.handle( new HttpRequest({ body: payload, - hostname: "localhost:9876", //invalid websocket endpoint + hostname: mockInvalidHostname, //invalid websocket endpoint protocol: "ws:", }) ); @@ -72,6 +143,8 @@ describe("WebSocketHandler", () => { return args[0].toString().indexOf("$metadata") >= 0; })[0][1] ).toBe(connectionTimeout); + // @ts-expect-error Property 'sockets' is private and only accessible within class 'WebSocketHandler'. + expect(handler.sockets[mockInvalidUrl].length).toBe(0); } (global as any).setTimeout = originalFn; }); diff --git a/packages/middleware-sdk-transcribe-streaming/src/websocket-handler.ts b/packages/middleware-sdk-transcribe-streaming/src/websocket-handler.ts index 99d48bef22fc..5258400edf84 100644 --- a/packages/middleware-sdk-transcribe-streaming/src/websocket-handler.ts +++ b/packages/middleware-sdk-transcribe-streaming/src/websocket-handler.ts @@ -22,21 +22,52 @@ export class WebSocketHandler implements HttpHandler { handlerProtocol: "websocket", }; private readonly connectionTimeout: number; + private readonly sockets: Record = {}; + constructor({ connectionTimeout }: WebSocketHandlerOptions = {}) { this.connectionTimeout = connectionTimeout || 2000; } - destroy(): void {} + /** + * Destroys the WebSocketHandler. + * Closes all sockets from the socket pool. + */ + destroy(): void { + for (const [key, sockets] of Object.entries(this.sockets)) { + for (const socket of sockets) { + socket.close(1000, `Socket closed through destroy() call`); + } + delete this.sockets[key]; + } + } + + /** + * Removes all closing/closed sockets from the socket pool for URL. + */ + private removeNotUsableSockets(url: string): void { + this.sockets[url] = this.sockets[url].filter( + (socket) => ![WebSocket.CLOSING, WebSocket.CLOSED].includes(socket.readyState) + ); + } async handle(request: HttpRequest): Promise<{ response: HttpResponse }> { const url = formatUrl(request); const socket: WebSocket = new WebSocket(url); + + // Add socket to sockets pool + if (!this.sockets[url]) { + this.sockets[url] = []; + } + this.sockets[url].push(socket); + socket.binaryType = "arraybuffer"; - await waitForReady(socket, this.connectionTimeout); + await this.waitForReady(socket, this.connectionTimeout); + const { body } = request; const bodyStream = getIterator(body); - const asyncIterable = connect(socket, bodyStream); + const asyncIterable = this.connect(socket, bodyStream); const outputPayload = toReadableStream(asyncIterable); + return { response: new HttpResponse({ statusCode: 200, // indicates connection success @@ -44,76 +75,90 @@ export class WebSocketHandler implements HttpHandler { }), }; } -} -const waitForReady = (socket: WebSocket, connectionTimeout: number): Promise => - new Promise((resolve, reject) => { - const timeout = setTimeout(() => { - reject({ - $metadata: { - httpStatusCode: 500, - }, - }); - }, connectionTimeout); - socket.onopen = () => { - clearTimeout(timeout); - resolve(); - }; - }); - -const connect = (socket: WebSocket, data: AsyncIterable): AsyncIterable => { - // To notify output stream any error thrown after response - // is returned while data keeps streaming. - let streamError: Error | undefined = undefined; - const outputStream: AsyncIterable = { - [Symbol.asyncIterator]: () => ({ - next: () => { - return new Promise((resolve, reject) => { - socket.onerror = (error) => { - socket.onclose = null; - socket.close(); - reject(error); - }; - socket.onclose = () => { - if (streamError) { - reject(streamError); - } else { + private waitForReady(socket: WebSocket, connectionTimeout: number): Promise { + return new Promise((resolve, reject) => { + const timeout = setTimeout(() => { + this.removeNotUsableSockets(socket.url); + reject({ + $metadata: { + httpStatusCode: 500, + }, + }); + }, connectionTimeout); + + socket.onopen = () => { + clearTimeout(timeout); + resolve(); + }; + }); + } + + private connect(socket: WebSocket, data: AsyncIterable): AsyncIterable { + // To notify output stream any error thrown after response + // is returned while data keeps streaming. + let streamError: Error | undefined = undefined; + + const outputStream: AsyncIterable = { + [Symbol.asyncIterator]: () => ({ + next: () => { + return new Promise((resolve, reject) => { + // To notify onclose event that error has occurred + let socketErrorOccurred = false; + + socket.onerror = (error) => { + socketErrorOccurred = true; + socket.close(); + reject(error); + }; + + socket.onclose = () => { + this.removeNotUsableSockets(socket.url); + if (socketErrorOccurred) return; + + if (streamError) { + reject(streamError); + } else { + resolve({ + done: true, + value: undefined, + }); + } + }; + + socket.onmessage = (event) => { resolve({ - done: true, - value: undefined, + done: false, + value: new Uint8Array(event.data), }); - } - }; - socket.onmessage = (event) => { - resolve({ - done: false, - value: new Uint8Array(event.data), - }); - }; - }); - }, - }), - }; + }; + }); + }, + }), + }; - const send = async (): Promise => { - try { - for await (const inputChunk of data) { - socket.send(inputChunk); + const send = async (): Promise => { + try { + for await (const inputChunk of data) { + socket.send(inputChunk); + } + } catch (err) { + // We don't throw the error here because the send()'s returned + // would already be settled by the time sending chunk throws error. + // Instead, the notify the output stream to throw if there's + // exceptions + streamError = err; + } finally { + // WS status code: https://tools.ietf.org/html/rfc6455#section-7.4 + socket.close(1000); } - } catch (err) { - // We don't throw the error here because the send()'s returned - // would already be settled by the time sending chunk throws error. - // Instead, the notify the output stream to throw if there's - // exceptions - streamError = err; - } finally { - // WS status code: https://tools.ietf.org/html/rfc6455#section-7.4 - socket.close(1000); - } - }; - send(); - return outputStream; -}; + }; + + send(); + + return outputStream; + } +} /** * Transfer payload data to an AsyncIterable. @@ -123,18 +168,21 @@ const connect = (socket: WebSocket, data: AsyncIterable): AsyncItera */ const getIterator = (stream: any): AsyncIterable => { // Noop if stream is already an async iterable - if (stream[Symbol.asyncIterator]) return stream; - else if (isReadableStream(stream)) { + if (stream[Symbol.asyncIterator]) { + return stream; + } + + if (isReadableStream(stream)) { //If stream is a ReadableStream, transfer the ReadableStream to async iterable. return readableStreamtoIterable(stream); - } else { - //For other types, just wrap them with an async iterable. - return { - [Symbol.asyncIterator]: async function* () { - yield stream; - }, - }; } + + // For other types, just wrap them with an async iterable. + return { + [Symbol.asyncIterator]: async function* () { + yield stream; + }, + }; }; /**