From 7620de2786efb7b5b0294314c38359e8b322f4fd Mon Sep 17 00:00:00 2001 From: michael1011 Date: Mon, 5 Feb 2024 15:51:46 +0100 Subject: [PATCH] feat: API V2 WebSocket --- lib/api/Api.ts | 6 +- lib/api/v2/WebSocketHandler.ts | 202 +++++++++++++++ test/unit/api/v2/WebSocketHandler.spec.ts | 289 ++++++++++++++++++++++ 3 files changed, 496 insertions(+), 1 deletion(-) create mode 100644 lib/api/v2/WebSocketHandler.ts create mode 100644 test/unit/api/v2/WebSocketHandler.spec.ts diff --git a/lib/api/Api.ts b/lib/api/Api.ts index e2659154..5dd25082 100644 --- a/lib/api/Api.ts +++ b/lib/api/Api.ts @@ -7,9 +7,11 @@ import Service from '../service/Service'; import Controller from './Controller'; import { errorResponse } from './Utils'; import ApiV2 from './v2/ApiV2'; +import WebSocketHandler from './v2/WebSocketHandler'; class Api { private app: Application; + private readonly websocket: WebSocketHandler; private readonly controller: Controller; constructor( @@ -46,6 +48,7 @@ class Api { ); this.controller = new Controller(logger, service, countryCodes); + this.websocket = new WebSocketHandler(service, this.controller); new ApiV2( this.logger, @@ -60,10 +63,11 @@ class Api { await this.controller.init(); await new Promise((resolve) => { - this.app.listen(this.config.port, this.config.host, () => { + const server = this.app.listen(this.config.port, this.config.host, () => { this.logger.info( `API server listening on: ${this.config.host}:${this.config.port}`, ); + this.websocket.register(server); resolve(); }); }); diff --git a/lib/api/v2/WebSocketHandler.ts b/lib/api/v2/WebSocketHandler.ts new file mode 100644 index 00000000..48915c75 --- /dev/null +++ b/lib/api/v2/WebSocketHandler.ts @@ -0,0 +1,202 @@ +import http from 'http'; +import ws from 'ws'; +import { formatError } from '../../Utils'; +import Service from '../../service/Service'; +import Controller from '../Controller'; +import { apiPrefix } from './Consts'; + +enum Operation { + // TODO: unsubscribe + Subscribe = 'subscribe', + Update = 'update', +} + +enum SubscriptionChannel { + SwapUpdate = 'swap.update', +} + +type WsRequest = { + op: Operation; +}; + +type WsSubscribeRequest = WsRequest & { + channel: SubscriptionChannel; + args: string[]; +}; + +type WsResponse = { + event: Operation; +}; + +type WsErrorResponse = { + error: string; +}; + +class WebSocketHandler { + private static readonly pingIntervalMs = 15_000; + + private readonly ws: ws.Server; + private pingInterval?: NodeJS.Timer; + + private readonly swapToSockets = new Map(); + private readonly socketToSwaps = new Map(); + + constructor( + private readonly service: Service, + private readonly controller: Controller, + ) { + this.ws = new ws.Server({ + noServer: true, + }); + this.listenConnections(); + this.listenSwapUpdates(); + } + + public register = (server: http.Server) => { + server.on('upgrade', (request, socket, head) => { + if (request.url !== `${apiPrefix}/ws`) { + request.destroy(); + socket.destroy(); + return; + } + + this.ws.handleUpgrade(request, socket, head, (ws) => { + this.ws.emit('connection', ws, request); + }); + }); + + this.pingInterval = setInterval(() => { + this.ws.clients.forEach((ws) => ws.ping()); + }, WebSocketHandler.pingIntervalMs); + }; + + public close = () => { + this.ws.close(); + clearInterval(this.pingInterval); + }; + + private listenConnections = () => { + this.ws.on('connection', (socket) => { + socket.on('message', (msg) => this.handleMessage(socket, msg)); + socket.on('close', () => { + const ids = this.socketToSwaps.get(socket); + if (ids === undefined) { + return; + } + + this.socketToSwaps.delete(socket); + + for (const id of ids) { + const sockets = this.swapToSockets + .get(id) + ?.filter((s) => s !== socket); + if (sockets === undefined || sockets.length === 0) { + this.swapToSockets.delete(id); + continue; + } + + this.swapToSockets.set(id, sockets); + } + }); + }); + }; + + private handleMessage = (socket: ws, message: ws.RawData) => { + try { + const data = JSON.parse(message.toString('utf-8')) as WsRequest; + + switch (data.op) { + case Operation.Subscribe: + this.handleSubscribe(socket, data); + break; + + default: + this.sendToSocket(socket, { error: 'unknown operation' }); + break; + } + } catch (e) { + this.sendToSocket(socket, { + error: `could not parse message: ${formatError(e)}`, + }); + } + }; + + private handleSubscribe = (socket: ws, data: WsRequest) => { + const subscribeData = data as WsSubscribeRequest; + switch (subscribeData.channel) { + case SubscriptionChannel.SwapUpdate: { + const existingIds = this.socketToSwaps.get(socket) || []; + this.socketToSwaps.set( + socket, + existingIds.concat( + subscribeData.args.filter((id) => !existingIds.includes(id)), + ), + ); + + for (const id of subscribeData.args) { + const existingSockets = this.swapToSockets.get(id) || []; + if (existingSockets.includes(socket)) { + continue; + } + + this.swapToSockets.set(id, existingSockets.concat(socket)); + } + + break; + } + + default: + this.sendToSocket(socket, { error: 'unknown channel' }); + return; + } + + this.sendToSocket(socket, { + event: Operation.Subscribe, + channel: subscribeData.channel, + args: subscribeData.args, + }); + + if (subscribeData.channel === SubscriptionChannel.SwapUpdate) { + const args = subscribeData.args + .map((id) => [id, this.controller.pendingSwapInfos.get(id)]) + .filter(([, data]) => data !== undefined); + + this.sendToSocket(socket, { + event: Operation.Update, + channel: SubscriptionChannel.SwapUpdate, + args: args, + }); + } + }; + + private listenSwapUpdates = () => { + this.service.eventHandler.on('swap.update', ({ id, status }) => { + const sockets = this.swapToSockets.get(id); + if (sockets === undefined) { + return; + } + + for (const socket of sockets) { + this.sendToSocket(socket, { + event: Operation.Update, + channel: SubscriptionChannel.SwapUpdate, + args: [[id, status]], + }); + } + }); + }; + + private sendToSocket = ( + socket: ws, + msg: T | WsErrorResponse, + ) => { + if (socket.readyState !== socket.OPEN) { + return; + } + + socket.send(JSON.stringify(msg)); + }; +} + +export default WebSocketHandler; +export { Operation, SubscriptionChannel }; diff --git a/test/unit/api/v2/WebSocketHandler.spec.ts b/test/unit/api/v2/WebSocketHandler.spec.ts new file mode 100644 index 00000000..eee71919 --- /dev/null +++ b/test/unit/api/v2/WebSocketHandler.spec.ts @@ -0,0 +1,289 @@ +import http from 'http'; +import ws from 'ws'; +import WebSocketHandler, { + Operation, + SubscriptionChannel, +} from '../../../../lib/api/v2/WebSocketHandler'; +import { SwapUpdateEvent } from '../../../../lib/consts/Enums'; +import { SwapUpdate } from '../../../../lib/service/EventHandler'; + +type SwapUpdateCallback = (args: { id: string; status: SwapUpdate }) => void; +let emitSwapUpdate: SwapUpdateCallback; + +describe('WebSocket', () => { + const service = { + eventHandler: { + on: jest.fn().mockImplementation((name, cb) => { + if (name === 'swap.update') { + emitSwapUpdate = cb; + } + }), + }, + } as any; + const controller = { + pendingSwapInfos: new Map([ + ['swap', { status: SwapUpdateEvent.InvoiceSet }], + ['reverse', { status: SwapUpdateEvent.SwapCreated }], + ]), + } as any; + + const server = http.createServer(); + const wsHandler = new WebSocketHandler(service, controller); + + const createWs = async (waitForInit: boolean = true) => { + const socket = new ws( + `ws://127.0.0.1:${(server.address() as any).port}/v2/ws`, + ); + + if (waitForInit) { + await new Promise((resolve) => { + socket.on('open', () => { + resolve(); + }); + }); + } + + return socket; + }; + + const waitForMessage = (socket: ws, message: any) => + new Promise((resolve) => { + socket.on('message', (msg) => { + expect(JSON.parse(msg.toString('utf-8'))).toStrictEqual(message); + resolve(); + }); + }); + + beforeAll(async () => { + await new Promise((resolve) => { + server.listen(0, () => { + resolve(); + }); + }); + + wsHandler.register(server); + }); + + afterAll(() => { + wsHandler.close(); + server.close(); + }); + + test('should upgrade connections', async () => { + const socket = await createWs(false); + await new Promise((resolve) => { + socket.on('open', () => { + resolve(); + }); + }); + + socket.close(); + }); + + test('should respond to pings', async () => { + const socket = await createWs(); + + const pongPromise = new Promise((resolve) => { + socket.on('pong', () => resolve()); + }); + + socket.ping(); + await pongPromise; + + socket.close(); + }); + + test('should respond with error when message cannot be parsed', async () => { + const socket = await createWs(); + const resPromise = waitForMessage(socket, { + error: + 'could not parse message: Unexpected token \'o\', "not json" is not valid JSON', + }); + + socket.send('not json'); + await resPromise; + + socket.close(); + }); + + test('should respond with error for unknown operations', async () => { + const socket = await createWs(); + const resPromise = waitForMessage(socket, { + error: 'unknown operation', + }); + + socket.send( + JSON.stringify({ + op: 'unknown', + }), + ); + await resPromise; + + socket.close(); + }); + + test('should respond with error for unknown subscription channels', async () => { + const socket = await createWs(); + const resPromise = waitForMessage(socket, { + error: 'unknown channel', + }); + + socket.send( + JSON.stringify({ + op: Operation.Subscribe, + channel: 'notachannel', + }), + ); + await resPromise; + + socket.close(); + }); + + test('should ignore swap events with no socket', () => { + emitSwapUpdate({ + id: 'noSocket', + status: { + status: SwapUpdateEvent.SwapCreated, + }, + }); + }); + + test('should subscribe to swap events and send latest status', async () => { + const swapIds = ['swap', 'reverse', 'notFound']; + + const socket = await createWs(); + + const resPromise = new Promise((resolve) => { + let msgCount = 0; + + socket.on('message', (msg) => { + const parsedMsg = JSON.parse(msg.toString('utf-8')); + expect(parsedMsg).toStrictEqual( + msgCount === 0 + ? { + event: Operation.Subscribe, + channel: SubscriptionChannel.SwapUpdate, + args: swapIds, + } + : { + event: Operation.Update, + channel: SubscriptionChannel.SwapUpdate, + args: [ + [ + 'swap', + { + status: SwapUpdateEvent.InvoiceSet, + }, + ], + [ + 'reverse', + { + status: SwapUpdateEvent.SwapCreated, + }, + ], + ], + }, + ); + + msgCount++; + if (msgCount == 2) { + socket.removeAllListeners('message'); + resolve(); + } + }); + }); + + socket.send( + JSON.stringify({ + op: Operation.Subscribe, + channel: SubscriptionChannel.SwapUpdate, + args: swapIds, + }), + ); + + await resPromise; + + socket.send( + JSON.stringify({ + op: Operation.Subscribe, + channel: SubscriptionChannel.SwapUpdate, + args: swapIds, + }), + ); + + expect(wsHandler['swapToSockets'].size).toEqual(3); + for (const id of swapIds) { + expect(wsHandler['swapToSockets'].get(id)).not.toBeUndefined(); + expect(wsHandler['swapToSockets'].get(id)!.length).toEqual(1); + } + + expect(wsHandler['socketToSwaps'].size).toEqual(1); + expect(Array.from(wsHandler['socketToSwaps'].values())).toEqual([swapIds]); + + socket.close(); + }); + + test('should subscribe to swap events and send swap updates', async () => { + const swapId = 'updateId'; + const status = { + status: SwapUpdateEvent.TransactionClaimPending, + }; + + const sockets = await Promise.all([createWs(), createWs()]); + + const setupPromises = sockets.map( + (sock) => + new Promise((resolve) => { + sock.on('message', (msg) => { + const parsedMsg = JSON.parse(msg.toString('utf-8')); + if (parsedMsg.event === Operation.Update) { + resolve(); + } + }); + }), + ); + + sockets.forEach((sock) => + sock.send( + JSON.stringify({ + op: Operation.Subscribe, + channel: SubscriptionChannel.SwapUpdate, + args: [swapId], + }), + ), + ); + + await Promise.all(setupPromises); + + const resPromises = sockets.map( + (sock) => + new Promise((resolve) => { + sock.on('message', (msg) => { + const parsedMsg = JSON.parse(msg.toString('utf-8')); + if ( + parsedMsg.event !== Operation.Update && + parsedMsg.args.length > 0 + ) { + return; + } + + expect(parsedMsg).toStrictEqual({ + event: Operation.Update, + channel: SubscriptionChannel.SwapUpdate, + args: [[swapId, status]], + }); + resolve(); + }); + }), + ); + + emitSwapUpdate({ + status, + id: swapId, + }); + + await Promise.all(resPromises); + + sockets.forEach((sock) => sock.close()); + }); +});