diff --git a/__tests__/integration/core-p2p/socket-server/peer.test.ts b/__tests__/integration/core-p2p/socket-server/peer.test.ts index 0172e48f0a..0875b7afe7 100644 --- a/__tests__/integration/core-p2p/socket-server/peer.test.ts +++ b/__tests__/integration/core-p2p/socket-server/peer.test.ts @@ -151,6 +151,10 @@ describe("Peer socket endpoint", () => { await delay(1000); expect(socket.state).toBe("closed"); + + // kill workers to reset ipLastError (or we won't pass handshake for 1 minute) + server.killWorkers({ immediate: true }); + await delay(2000); // give time to workers to respawn }); it("should disconnect the client if it sends too many pongs too quickly", async () => { @@ -172,6 +176,10 @@ describe("Peer socket endpoint", () => { await delay(1000); expect(socket.state).toBe("closed"); + + // kill workers to reset ipLastError (or we won't pass handshake for 1 minute) + server.killWorkers({ immediate: true }); + await delay(2000); // give time to workers to respawn }); it("should disconnect the client if it sends a ping frame", async () => { @@ -183,6 +191,10 @@ describe("Peer socket endpoint", () => { ping(); await delay(500); expect(socket.state).toBe("closed"); + + // kill workers to reset ipLastError (or we won't pass handshake for 1 minute) + server.killWorkers({ immediate: true }); + await delay(2000); // give time to workers to respawn }); it("should disconnect the client if it sends a pong frame", async () => { @@ -194,6 +206,10 @@ describe("Peer socket endpoint", () => { pong(); await delay(500); expect(socket.state).toBe("closed"); + + // kill workers to reset ipLastError (or we won't pass handshake for 1 minute) + server.killWorkers({ immediate: true }); + await delay(2000); // give time to workers to respawn }); }); }); diff --git a/packages/core-p2p/src/socket-server/worker.ts b/packages/core-p2p/src/socket-server/worker.ts index eb8629ba95..7a03c0d3d3 100644 --- a/packages/core-p2p/src/socket-server/worker.ts +++ b/packages/core-p2p/src/socket-server/worker.ts @@ -5,10 +5,14 @@ import { SocketErrors } from "../enums"; import { requestSchemas } from "../schemas"; import { RateLimiter } from "./rate-limiter"; +const MINUTE_IN_MILLISECONDS = 1000 * 60; +const HOUR_IN_MILLISECONDS = MINUTE_IN_MILLISECONDS * 60; + const ajv = new Ajv(); export class Worker extends SCWorker { private config: Record; + private ipLastError: Record = {}; private rateLimiter: RateLimiter; public async run() { @@ -16,6 +20,9 @@ export class Worker extends SCWorker { await this.loadConfiguration(); + // purge ipLastError every hour to free up memory + setInterval(() => (this.ipLastError = {}), HOUR_IN_MILLISECONDS); + // @ts-ignore this.scServer.wsServer.on("connection", (ws, req) => this.handlePayload(ws, req)); this.scServer.on("connection", socket => this.handleConnection(socket)); @@ -64,13 +71,13 @@ export class Worker extends SCWorker { } private handlePayload(ws, req) { - ws.on("ping", () => { - ws.terminate(); + ws.prependListener("ping", () => { + this.setErrorForIpAndTerminate(ws, req); }); - ws.on("pong", () => { - ws.terminate(); + ws.prependListener("pong", () => { + this.setErrorForIpAndTerminate(ws, req); }); - ws.on("message", message => { + ws.prependListener("message", message => { try { const InvalidMessagePayloadError: Error = this.createError( SocketErrors.InvalidMessagePayload, @@ -82,6 +89,10 @@ export class Worker extends SCWorker { throw InvalidMessagePayloadError; } ws._lastPingTime = timeNow; + } else if (message.length < 10) { + // except for #2 message, we should have JSON with some required properties + // (see below) which implies that message length should be longer than 10 chars + this.setErrorForIpAndTerminate(ws, req); } else { const parsed = JSON.parse(message); if ( @@ -90,15 +101,20 @@ export class Worker extends SCWorker { (typeof parsed.cid !== "number" && (parsed.event === "#disconnect" && typeof parsed.cid !== "undefined")) ) { - throw InvalidMessagePayloadError; + this.setErrorForIpAndTerminate(ws, req); } } } catch (error) { - ws.terminate(); + this.setErrorForIpAndTerminate(ws, req); } }); } + private setErrorForIpAndTerminate(ws, req): void { + this.ipLastError[req.socket.remoteAddress] = Date.now(); + ws.terminate(); + } + private async handleConnection(socket): Promise { const { data } = await this.sendToMasterAsync("p2p.utils.getHandlers"); @@ -117,14 +133,20 @@ export class Worker extends SCWorker { } private async handleHandshake(req, next): Promise { - const isBlocked = await this.rateLimiter.isBlocked(req.socket.remoteAddress); - const isBlacklisted = (this.config.blacklist || []).includes(req.socket.remoteAddress); + const ip = req.socket.remoteAddress; + if (this.ipLastError[ip] && this.ipLastError[ip] > Date.now() - MINUTE_IN_MILLISECONDS) { + req.socket.destroy(); + return; + } + + const isBlocked = await this.rateLimiter.isBlocked(ip); + const isBlacklisted = (this.config.blacklist || []).includes(ip); if (isBlocked || isBlacklisted) { next(this.createError(SocketErrors.Forbidden, "Blocked due to rate limit or blacklisted.")); return; } - const cidrRemoteAddress = cidr(`${req.socket.remoteAddress}/24`); + const cidrRemoteAddress = cidr(`${ip}/24`); const sameSubnetSockets = Object.values({ ...this.scServer.clients, ...this.scServer.pendingClients }).filter( client => cidr(`${client.remoteAddress}/24`) === cidrRemoteAddress, );