Skip to content

Commit

Permalink
fix(base-driver): Update the web socket upgrade behavior (#20142)
Browse files Browse the repository at this point in the history
  • Loading branch information
mykola-mokhnach committed May 24, 2024
1 parent 82406f1 commit 275790e
Show file tree
Hide file tree
Showing 5 changed files with 77 additions and 48 deletions.
86 changes: 70 additions & 16 deletions packages/base-driver/lib/express/middleware.js
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,16 @@ import _ from 'lodash';
import log from './logger';
import {errors} from '../protocol';
import {handleIdempotency} from './idempotency';
import {pathToRegexp} from 'path-to-regexp';

function allowCrossDomain(req, res, next) {
/**
*
* @param {import('express').Request} req
* @param {import('express').Response} res
* @param {import('express').NextFunction} next
* @returns {any}
*/
export function allowCrossDomain(req, res, next) {
try {
res.header('Access-Control-Allow-Origin', '*');
res.header('Access-Control-Allow-Methods', 'GET, POST, PUT, OPTIONS, DELETE');
Expand All @@ -22,7 +30,11 @@ function allowCrossDomain(req, res, next) {
next();
}

function allowCrossDomainAsyncExecute(basePath) {
/**
* @param {string} basePath
* @returns {import('express').RequestHandler}
*/
export function allowCrossDomainAsyncExecute(basePath) {
return (req, res, next) => {
// there are two paths for async responses, so cover both
// https://regex101.com/r/txYiEz/1
Expand All @@ -36,12 +48,17 @@ function allowCrossDomainAsyncExecute(basePath) {
};
}

function fixPythonContentType(basePath) {
/**
*
* @param {string} basePath
* @returns {import('express').RequestHandler}
*/
export function fixPythonContentType(basePath) {
return (req, res, next) => {
// hack because python client library gives us wrong content-type
if (
new RegExp(`^${_.escapeRegExp(basePath)}`).test(req.path) &&
/^Python/.test(req.headers['user-agent'])
/^Python/.test(req.headers['user-agent'] ?? '')
) {
if (req.headers['content-type'] === 'application/x-www-form-urlencoded') {
req.headers['content-type'] = 'application/json; charset=utf-8';
Expand All @@ -51,14 +68,55 @@ function fixPythonContentType(basePath) {
};
}

function defaultToJSONContentType(req, res, next) {
/**
*
* @param {import('express').Request} req
* @param {import('express').Response} res
* @param {import('express').NextFunction} next
* @returns {any}
*/
export function defaultToJSONContentType(req, res, next) {
if (!req.headers['content-type']) {
req.headers['content-type'] = 'application/json; charset=utf-8';
}
next();
}

function catchAllHandler(err, req, res, next) {
/**
*
* @param {import('@appium/types').StringRecord<import('@appium/types').WSServer>} webSocketsMapping
* @returns {import('express').RequestHandler}
*/
export function handleUpgrade(webSocketsMapping) {
return (req, res, next) => {
if (!req.headers?.upgrade || _.toLower(req.headers.upgrade) !== 'websocket') {
return next();
}
let currentPathname;
try {
currentPathname = new URL(req.url ?? '').pathname;
} catch {
currentPathname = req.url ?? '';
}
for (const [pathname, wsServer] of _.toPairs(webSocketsMapping)) {
if (pathToRegexp(pathname).test(currentPathname)) {
return wsServer.handleUpgrade(req, req.socket, Buffer.from(''), (ws) => {
wsServer.emit('connection', ws, req);
});
}
}
log.info(`Did not match the websocket upgrade request at ${currentPathname} to any known route`);
next();
};
}

/**
* @param {Error} err
* @param {import('express').Request} req
* @param {import('express').Response} res
* @param {import('express').NextFunction} next
*/
export function catchAllHandler(err, req, res, next) {
if (res.headersSent) {
return next(err);
}
Expand All @@ -79,7 +137,11 @@ function catchAllHandler(err, req, res, next) {
log.error(err);
}

function catch404Handler(req, res) {
/**
* @param {import('express').Request} req
* @param {import('express').Response} res
*/
export function catch404Handler(req, res) {
log.debug(`No route found for ${req.url}`);
const error = errors.UnknownCommandError;
res.status(error.w3cStatus()).json(
Expand Down Expand Up @@ -107,12 +169,4 @@ function patchWithSessionId(req, body) {
return body;
}

export {
allowCrossDomain,
fixPythonContentType,
defaultToJSONContentType,
catchAllHandler,
allowCrossDomainAsyncExecute,
handleIdempotency,
catch404Handler,
};
export { handleIdempotency };
7 changes: 6 additions & 1 deletion packages/base-driver/lib/express/server.js
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import {
catchAllHandler,
allowCrossDomainAsyncExecute,
handleIdempotency,
handleUpgrade,
catch404Handler,
} from './middleware';
import {guineaPig, guineaPigScrollable, guineaPigAppBanner, welcome, STATIC_DIR} from './static';
Expand Down Expand Up @@ -109,6 +110,7 @@ async function server(opts) {
allowCors,
basePath,
extraMethodMap,
webSocketsMapping: appiumServer.webSocketsMapping,
});
// allow extensions to update the app and http server objects
for (const updater of serverUpdaters) {
Expand Down Expand Up @@ -139,6 +141,7 @@ function configureServer({
allowCors = true,
basePath = DEFAULT_BASE_PATH,
extraMethodMap = {},
webSocketsMapping = {},
}) {
basePath = normalizeBasePath(basePath);

Expand All @@ -152,7 +155,7 @@ function configureServer({
app.use(`${basePath}/produce_error`, produceError);
app.use(`${basePath}/crash`, produceCrash);

// add middlewares
app.use(handleUpgrade(webSocketsMapping));
if (allowCors) {
app.use(allowCrossDomain);
} else {
Expand Down Expand Up @@ -195,6 +198,7 @@ function configureHttp({httpServer, reject, keepAliveTimeout}) {
* @type {AppiumServer}
*/
const appiumServer = /** @type {any} */ (httpServer);
appiumServer.webSocketsMapping = {};
appiumServer.addWebSocketHandler = addWebSocketHandler;
appiumServer.removeWebSocketHandler = removeWebSocketHandler;
appiumServer.removeAllWebSocketHandlers = removeAllWebSocketHandlers;
Expand Down Expand Up @@ -370,4 +374,5 @@ export {server, configureServer, normalizeBasePath};
* @property {boolean} [allowCors]
* @property {string} [basePath]
* @property {MethodMap} [extraMethodMap]
* @property {import('@appium/types').StringRecord} [webSocketsMapping={}]
*/
27 changes: 0 additions & 27 deletions packages/base-driver/lib/express/websocket.js
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
/* eslint-disable require-await */
import _ from 'lodash';
import {URL} from 'url';
import B from 'bluebird';
import { pathToRegexp } from 'path-to-regexp';

const DEFAULT_WS_PATHNAME_PREFIX = '/ws';

Expand All @@ -11,27 +9,6 @@ const DEFAULT_WS_PATHNAME_PREFIX = '/ws';
* @type {AppiumServer['addWebSocketHandler']}
*/
async function addWebSocketHandler(handlerPathname, handlerServer) {
if (_.isUndefined(this.webSocketsMapping)) {
this.webSocketsMapping = {};
// https://github.com/websockets/ws/pull/885
this.on('upgrade', (request, socket, head) => {
let currentPathname;
try {
currentPathname = new URL(request.url ?? '').pathname;
} catch {
currentPathname = request.url ?? '';
}
for (const [pathname, wsServer] of _.toPairs(this.webSocketsMapping)) {
if (pathToRegexp(pathname).test(currentPathname)) {
wsServer.handleUpgrade(request, socket, head, (ws) => {
wsServer.emit('connection', ws, request);
});
return;
}
}
socket.destroy();
});
}
this.webSocketsMapping[handlerPathname] = handlerServer;
}

Expand All @@ -40,10 +17,6 @@ async function addWebSocketHandler(handlerPathname, handlerServer) {
* @type {AppiumServer['getWebSocketHandlers']}
*/
async function getWebSocketHandlers(keysFilter = null) {
if (_.isEmpty(this.webSocketsMapping)) {
return {};
}

return _.toPairs(this.webSocketsMapping).reduce((acc, [pathname, wsServer]) => {
if (!_.isString(keysFilter) || pathname.includes(keysFilter)) {
acc[pathname] = wsServer;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,11 +35,9 @@ describe('Websockets (e2e)', function () {
ws.send(WS_DATA);
}
});
const previousListenerCount = baseServer.listenerCount('upgrade');
const endpoint = `${DEFAULT_WS_PATHNAME_PREFIX}/hello`;
const timeout = 5000;
await baseServer.addWebSocketHandler(endpoint, wss);
baseServer.listenerCount('upgrade').should.be.above(previousListenerCount);
_.keys(await baseServer.getWebSocketHandlers()).length.should.eql(1);
await new B((resolve, reject) => {
const client = new WebSocket(`ws://${TEST_HOST}:${port}${endpoint}`);
Expand Down Expand Up @@ -78,7 +76,6 @@ describe('Websockets (e2e)', function () {
client.on('error', resolve);
setTimeout(resolve, timeout);
});
baseServer.listenerCount('upgrade').should.be.above(previousListenerCount);
});
});
});
2 changes: 1 addition & 1 deletion packages/base-driver/test/unit/express/server.spec.js
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ describe('server configuration', function () {
const app = fakeApp();
const configureRoutes = () => {};
configureServer({app, addRoutes: configureRoutes});
app.use.callCount.should.equal(14);
app.use.callCount.should.equal(15);
app.all.callCount.should.equal(4);
});

Expand Down

0 comments on commit 275790e

Please sign in to comment.