Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -6,35 +6,98 @@ import { PassThrough } from "stream";
import { WebSocketHandler } from "./websocket-handler";

describe("WebSocketHandler", () => {
const mockHostname = "localhost:6789";
const mockUrl = `ws://${mockHostname}/`;

beforeEach(() => {
(global as any).WebSocket = WebSocket;
});

afterEach(() => {
WS.clean();
jest.clearAllMocks();
});

it("should contain protocol metadata", () => {
const handler = new WebSocketHandler();
expect(handler.metadata.handlerProtocol).toEqual("websocket");
});

it("populates socket in socket pool based on handle() requests", async () => {
const handler = new WebSocketHandler();
const server = new WS(mockUrl);

// @ts-expect-error Property 'sockets' is private and only accessible within class 'WebSocketHandler'.
expect(handler.sockets[mockUrl]).not.toBeDefined();

await handler.handle(
new HttpRequest({
body: new PassThrough(),
hostname: mockHostname,
protocol: "ws:",
})
);

// @ts-expect-error Property 'sockets' is private and only accessible within class 'WebSocketHandler'.
expect(handler.sockets[mockUrl]).toBeDefined();
// @ts-expect-error Property 'sockets' is private and only accessible within class 'WebSocketHandler'.
expect(handler.sockets[mockUrl].length).toBe(1);

await handler.handle(
new HttpRequest({
body: new PassThrough(),
hostname: mockHostname,
protocol: "ws:",
})
);

// @ts-expect-error Property 'sockets' is private and only accessible within class 'WebSocketHandler'.
expect(handler.sockets[mockUrl].length).toBe(2);
});

it("closes socket in socket pool on handler.destroy()", async () => {
const handler = new WebSocketHandler();
const server = new WS(mockUrl);

await handler.handle(
new HttpRequest({
body: new PassThrough(),
hostname: mockHostname,
protocol: "ws:",
})
);

// @ts-expect-error Property 'sockets' is private and only accessible within class 'WebSocketHandler'.
const socket = handler.sockets[mockUrl][0];

expect(socket.readyState).toBe(WebSocket.OPEN);
handler.destroy();

// Verify that socket.close() is called
expect(socket.readyState).toBe(WebSocket.CLOSING);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

on what timing does this transition to WebSocket.CLOSED?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I tried timeouts of 1s and 5s, but the state was still closing.
That's why I updated source code to remove sockets with closing state, and verified state to be closing.

});

it("should throw in output stream if input stream throws", async () => {
expect.assertions(2);
expect.assertions(3);
const handler = new WebSocketHandler();
//Using Node stream is fine because they are also async iterables.
const payload = new PassThrough();
const server = new WS("ws://localhost:6789");

const server = new WS(mockUrl);

const {
response: { body: responsePayload },
} = await handler.handle(
new HttpRequest({
body: payload,
hostname: "localhost:6789",
hostname: mockHostname,
protocol: "ws:",
})
);

await server.connected;
payload.emit("error", new Error("FakeError"));

try {
// eslint-disable-next-line @typescript-eslint/no-unused-vars
for await (const chunk of responsePayload) {
Expand All @@ -43,22 +106,30 @@ describe("WebSocketHandler", () => {
} catch (err) {
expect(err).toBeDefined();
expect(err.message).toEqual("FakeError");
// @ts-expect-error Property 'sockets' is private and only accessible within class 'WebSocketHandler'.
expect(handler.sockets[mockUrl].length).toBe(0);
}
});

it("should return retryable error if cannot setup ws connection", async () => {
expect.assertions(4);
expect.assertions(5);

const originalFn = setTimeout;
(global as any).setTimeout = jest.fn().mockImplementation(setTimeout);

const connectionTimeout = 1000;
const handler = new WebSocketHandler({ connectionTimeout });

//Using Node stream is fine because they are also async iterables.
const payload = new PassThrough();
const mockInvalidHostname = "localhost:9876";
const mockInvalidUrl = `ws://${mockInvalidHostname}/`;

try {
await handler.handle(
new HttpRequest({
body: payload,
hostname: "localhost:9876", //invalid websocket endpoint
hostname: mockInvalidHostname, //invalid websocket endpoint
protocol: "ws:",
})
);
Expand All @@ -72,6 +143,8 @@ describe("WebSocketHandler", () => {
return args[0].toString().indexOf("$metadata") >= 0;
})[0][1]
).toBe(connectionTimeout);
// @ts-expect-error Property 'sockets' is private and only accessible within class 'WebSocketHandler'.
expect(handler.sockets[mockInvalidUrl].length).toBe(0);
}
(global as any).setTimeout = originalFn;
});
Expand Down
202 changes: 125 additions & 77 deletions packages/middleware-sdk-transcribe-streaming/src/websocket-handler.ts
Original file line number Diff line number Diff line change
Expand Up @@ -22,98 +22,143 @@ export class WebSocketHandler implements HttpHandler {
handlerProtocol: "websocket",
};
private readonly connectionTimeout: number;
private readonly sockets: Record<string, WebSocket[]> = {};

constructor({ connectionTimeout }: WebSocketHandlerOptions = {}) {
this.connectionTimeout = connectionTimeout || 2000;
}

destroy(): void {}
/**
* Destroys the WebSocketHandler.
* Closes all sockets from the socket pool.
*/
destroy(): void {
for (const [key, sockets] of Object.entries(this.sockets)) {
for (const socket of sockets) {
socket.close(1000, `Socket closed through destroy() call`);
}
delete this.sockets[key];
}
}

/**
* Removes all closing/closed sockets from the socket pool for URL.
*/
private removeNotUsableSockets(url: string): void {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

rename to removeClosedSockets?

Copy link
Member Author

@trivikr trivikr Feb 3, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It was originally named removeClosedSockets
It removes Closing sockets too, that's why I'd renamed it from Closed to NotUsable.

this.sockets[url] = this.sockets[url].filter(
(socket) => ![WebSocket.CLOSING, WebSocket.CLOSED].includes(socket.readyState)
);
}

async handle(request: HttpRequest): Promise<{ response: HttpResponse }> {
const url = formatUrl(request);
const socket: WebSocket = new WebSocket(url);

// Add socket to sockets pool
if (!this.sockets[url]) {
this.sockets[url] = [];
}
this.sockets[url].push(socket);

socket.binaryType = "arraybuffer";
await waitForReady(socket, this.connectionTimeout);
await this.waitForReady(socket, this.connectionTimeout);

const { body } = request;
const bodyStream = getIterator(body);
const asyncIterable = connect(socket, bodyStream);
const asyncIterable = this.connect(socket, bodyStream);
const outputPayload = toReadableStream(asyncIterable);

return {
response: new HttpResponse({
statusCode: 200, // indicates connection success
body: outputPayload,
}),
};
}
}

const waitForReady = (socket: WebSocket, connectionTimeout: number): Promise<void> =>
new Promise((resolve, reject) => {
const timeout = setTimeout(() => {
reject({
$metadata: {
httpStatusCode: 500,
},
});
}, connectionTimeout);
socket.onopen = () => {
clearTimeout(timeout);
resolve();
};
});

const connect = (socket: WebSocket, data: AsyncIterable<Uint8Array>): AsyncIterable<Uint8Array> => {
// To notify output stream any error thrown after response
// is returned while data keeps streaming.
let streamError: Error | undefined = undefined;
const outputStream: AsyncIterable<Uint8Array> = {
[Symbol.asyncIterator]: () => ({
next: () => {
return new Promise((resolve, reject) => {
socket.onerror = (error) => {
socket.onclose = null;
socket.close();
reject(error);
};
socket.onclose = () => {
if (streamError) {
reject(streamError);
} else {
private waitForReady(socket: WebSocket, connectionTimeout: number): Promise<void> {
return new Promise((resolve, reject) => {
const timeout = setTimeout(() => {
this.removeNotUsableSockets(socket.url);
reject({
$metadata: {
httpStatusCode: 500,
},
});
}, connectionTimeout);

socket.onopen = () => {
clearTimeout(timeout);
resolve();
};
});
}

private connect(socket: WebSocket, data: AsyncIterable<Uint8Array>): AsyncIterable<Uint8Array> {
// To notify output stream any error thrown after response
// is returned while data keeps streaming.
let streamError: Error | undefined = undefined;

const outputStream: AsyncIterable<Uint8Array> = {
[Symbol.asyncIterator]: () => ({
next: () => {
return new Promise((resolve, reject) => {
// To notify onclose event that error has occurred
let socketErrorOccurred = false;

socket.onerror = (error) => {
socketErrorOccurred = true;
socket.close();
reject(error);
};

socket.onclose = () => {
this.removeNotUsableSockets(socket.url);
if (socketErrorOccurred) return;

if (streamError) {
reject(streamError);
} else {
resolve({
done: true,
value: undefined,
});
}
};

socket.onmessage = (event) => {
resolve({
done: true,
value: undefined,
done: false,
value: new Uint8Array(event.data),
});
}
};
socket.onmessage = (event) => {
resolve({
done: false,
value: new Uint8Array(event.data),
});
};
});
},
}),
};
};
});
},
}),
};

const send = async (): Promise<void> => {
try {
for await (const inputChunk of data) {
socket.send(inputChunk);
const send = async (): Promise<void> => {
try {
for await (const inputChunk of data) {
socket.send(inputChunk);
}
} catch (err) {
// We don't throw the error here because the send()'s returned
// would already be settled by the time sending chunk throws error.
// Instead, the notify the output stream to throw if there's
// exceptions
streamError = err;
} finally {
// WS status code: https://tools.ietf.org/html/rfc6455#section-7.4
socket.close(1000);
}
} catch (err) {
// We don't throw the error here because the send()'s returned
// would already be settled by the time sending chunk throws error.
// Instead, the notify the output stream to throw if there's
// exceptions
streamError = err;
} finally {
// WS status code: https://tools.ietf.org/html/rfc6455#section-7.4
socket.close(1000);
}
};
send();
return outputStream;
};
};

send();

return outputStream;
}
}

/**
* Transfer payload data to an AsyncIterable.
Expand All @@ -123,18 +168,21 @@ const connect = (socket: WebSocket, data: AsyncIterable<Uint8Array>): AsyncItera
*/
const getIterator = (stream: any): AsyncIterable<any> => {
// Noop if stream is already an async iterable
if (stream[Symbol.asyncIterator]) return stream;
else if (isReadableStream(stream)) {
if (stream[Symbol.asyncIterator]) {
return stream;
}

if (isReadableStream(stream)) {
//If stream is a ReadableStream, transfer the ReadableStream to async iterable.
return readableStreamtoIterable(stream);
} else {
//For other types, just wrap them with an async iterable.
return {
[Symbol.asyncIterator]: async function* () {
yield stream;
},
};
}

// For other types, just wrap them with an async iterable.
return {
[Symbol.asyncIterator]: async function* () {
yield stream;
},
};
};

/**
Expand Down