Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(base-driver): Update the web socket upgrade behavior #20142

Merged
merged 8 commits into from
May 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
86 changes: 70 additions & 16 deletions packages/base-driver/lib/express/middleware.js
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,16 @@ import _ from 'lodash';
import log from './logger';
import {errors} from '../protocol';
import {handleIdempotency} from './idempotency';
import {pathToRegexp} from 'path-to-regexp';

function allowCrossDomain(req, res, next) {
/**
*
* @param {import('express').Request} req
* @param {import('express').Response} res
* @param {import('express').NextFunction} next
* @returns {any}
*/
export function allowCrossDomain(req, res, next) {
try {
res.header('Access-Control-Allow-Origin', '*');
res.header('Access-Control-Allow-Methods', 'GET, POST, PUT, OPTIONS, DELETE');
Expand All @@ -22,7 +30,11 @@ function allowCrossDomain(req, res, next) {
next();
}

function allowCrossDomainAsyncExecute(basePath) {
/**
* @param {string} basePath
* @returns {import('express').RequestHandler}
*/
export function allowCrossDomainAsyncExecute(basePath) {
return (req, res, next) => {
// there are two paths for async responses, so cover both
// https://regex101.com/r/txYiEz/1
Expand All @@ -36,12 +48,17 @@ function allowCrossDomainAsyncExecute(basePath) {
};
}

function fixPythonContentType(basePath) {
/**
*
* @param {string} basePath
* @returns {import('express').RequestHandler}
*/
export function fixPythonContentType(basePath) {
return (req, res, next) => {
// hack because python client library gives us wrong content-type
if (
new RegExp(`^${_.escapeRegExp(basePath)}`).test(req.path) &&
/^Python/.test(req.headers['user-agent'])
/^Python/.test(req.headers['user-agent'] ?? '')
) {
if (req.headers['content-type'] === 'application/x-www-form-urlencoded') {
req.headers['content-type'] = 'application/json; charset=utf-8';
Expand All @@ -51,14 +68,55 @@ function fixPythonContentType(basePath) {
};
}

function defaultToJSONContentType(req, res, next) {
/**
*
* @param {import('express').Request} req
* @param {import('express').Response} res
* @param {import('express').NextFunction} next
* @returns {any}
*/
export function defaultToJSONContentType(req, res, next) {
if (!req.headers['content-type']) {
req.headers['content-type'] = 'application/json; charset=utf-8';
}
next();
}

function catchAllHandler(err, req, res, next) {
/**
*
* @param {import('@appium/types').StringRecord<import('@appium/types').WSServer>} webSocketsMapping
* @returns {import('express').RequestHandler}
*/
export function handleUpgrade(webSocketsMapping) {
return (req, res, next) => {
if (!req.headers?.upgrade || _.toLower(req.headers.upgrade) !== 'websocket') {
return next();
}
let currentPathname;
try {
currentPathname = new URL(req.url ?? '').pathname;
} catch {
currentPathname = req.url ?? '';
}
for (const [pathname, wsServer] of _.toPairs(webSocketsMapping)) {
if (pathToRegexp(pathname).test(currentPathname)) {
return wsServer.handleUpgrade(req, req.socket, Buffer.from(''), (ws) => {
wsServer.emit('connection', ws, req);
});
}
}
log.info(`Did not match the websocket upgrade request at ${currentPathname} to any known route`);
next();
};
}

/**
* @param {Error} err
* @param {import('express').Request} req
* @param {import('express').Response} res
* @param {import('express').NextFunction} next
*/
export function catchAllHandler(err, req, res, next) {
if (res.headersSent) {
return next(err);
}
Expand All @@ -79,7 +137,11 @@ function catchAllHandler(err, req, res, next) {
log.error(err);
}

function catch404Handler(req, res) {
/**
* @param {import('express').Request} req
* @param {import('express').Response} res
*/
export function catch404Handler(req, res) {
log.debug(`No route found for ${req.url}`);
const error = errors.UnknownCommandError;
res.status(error.w3cStatus()).json(
Expand Down Expand Up @@ -107,12 +169,4 @@ function patchWithSessionId(req, body) {
return body;
}

export {
allowCrossDomain,
fixPythonContentType,
defaultToJSONContentType,
catchAllHandler,
allowCrossDomainAsyncExecute,
handleIdempotency,
catch404Handler,
};
export { handleIdempotency };
7 changes: 6 additions & 1 deletion packages/base-driver/lib/express/server.js
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import {
catchAllHandler,
allowCrossDomainAsyncExecute,
handleIdempotency,
handleUpgrade,
catch404Handler,
} from './middleware';
import {guineaPig, guineaPigScrollable, guineaPigAppBanner, welcome, STATIC_DIR} from './static';
Expand Down Expand Up @@ -109,6 +110,7 @@ async function server(opts) {
allowCors,
basePath,
extraMethodMap,
webSocketsMapping: appiumServer.webSocketsMapping,
});
// allow extensions to update the app and http server objects
for (const updater of serverUpdaters) {
Expand Down Expand Up @@ -139,6 +141,7 @@ function configureServer({
allowCors = true,
basePath = DEFAULT_BASE_PATH,
extraMethodMap = {},
webSocketsMapping = {},
}) {
basePath = normalizeBasePath(basePath);

Expand All @@ -152,7 +155,7 @@ function configureServer({
app.use(`${basePath}/produce_error`, produceError);
app.use(`${basePath}/crash`, produceCrash);

// add middlewares
app.use(handleUpgrade(webSocketsMapping));
if (allowCors) {
app.use(allowCrossDomain);
} else {
Expand Down Expand Up @@ -195,6 +198,7 @@ function configureHttp({httpServer, reject, keepAliveTimeout}) {
* @type {AppiumServer}
*/
const appiumServer = /** @type {any} */ (httpServer);
appiumServer.webSocketsMapping = {};
appiumServer.addWebSocketHandler = addWebSocketHandler;
appiumServer.removeWebSocketHandler = removeWebSocketHandler;
appiumServer.removeAllWebSocketHandlers = removeAllWebSocketHandlers;
Expand Down Expand Up @@ -370,4 +374,5 @@ export {server, configureServer, normalizeBasePath};
* @property {boolean} [allowCors]
* @property {string} [basePath]
* @property {MethodMap} [extraMethodMap]
* @property {import('@appium/types').StringRecord} [webSocketsMapping={}]
*/
27 changes: 0 additions & 27 deletions packages/base-driver/lib/express/websocket.js
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
/* eslint-disable require-await */
import _ from 'lodash';
import {URL} from 'url';
import B from 'bluebird';
import { pathToRegexp } from 'path-to-regexp';

const DEFAULT_WS_PATHNAME_PREFIX = '/ws';

Expand All @@ -11,27 +9,6 @@ const DEFAULT_WS_PATHNAME_PREFIX = '/ws';
* @type {AppiumServer['addWebSocketHandler']}
*/
async function addWebSocketHandler(handlerPathname, handlerServer) {
if (_.isUndefined(this.webSocketsMapping)) {
this.webSocketsMapping = {};
// https://github.com/websockets/ws/pull/885
this.on('upgrade', (request, socket, head) => {
let currentPathname;
try {
currentPathname = new URL(request.url ?? '').pathname;
} catch {
currentPathname = request.url ?? '';
}
for (const [pathname, wsServer] of _.toPairs(this.webSocketsMapping)) {
if (pathToRegexp(pathname).test(currentPathname)) {
wsServer.handleUpgrade(request, socket, head, (ws) => {
wsServer.emit('connection', ws, request);
});
return;
}
}
socket.destroy();
});
}
this.webSocketsMapping[handlerPathname] = handlerServer;
}

Expand All @@ -40,10 +17,6 @@ async function addWebSocketHandler(handlerPathname, handlerServer) {
* @type {AppiumServer['getWebSocketHandlers']}
*/
async function getWebSocketHandlers(keysFilter = null) {
if (_.isEmpty(this.webSocketsMapping)) {
return {};
}

return _.toPairs(this.webSocketsMapping).reduce((acc, [pathname, wsServer]) => {
if (!_.isString(keysFilter) || pathname.includes(keysFilter)) {
acc[pathname] = wsServer;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,11 +35,9 @@ describe('Websockets (e2e)', function () {
ws.send(WS_DATA);
}
});
const previousListenerCount = baseServer.listenerCount('upgrade');
const endpoint = `${DEFAULT_WS_PATHNAME_PREFIX}/hello`;
const timeout = 5000;
await baseServer.addWebSocketHandler(endpoint, wss);
baseServer.listenerCount('upgrade').should.be.above(previousListenerCount);
_.keys(await baseServer.getWebSocketHandlers()).length.should.eql(1);
await new B((resolve, reject) => {
const client = new WebSocket(`ws://${TEST_HOST}:${port}${endpoint}`);
Expand Down Expand Up @@ -78,7 +76,6 @@ describe('Websockets (e2e)', function () {
client.on('error', resolve);
setTimeout(resolve, timeout);
});
baseServer.listenerCount('upgrade').should.be.above(previousListenerCount);
});
});
});
2 changes: 1 addition & 1 deletion packages/base-driver/test/unit/express/server.spec.js
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ describe('server configuration', function () {
const app = fakeApp();
const configureRoutes = () => {};
configureServer({app, addRoutes: configureRoutes});
app.use.callCount.should.equal(14);
app.use.callCount.should.equal(15);
app.all.callCount.should.equal(4);
});

Expand Down