From 828f42811801e0e75a17ed86a3ed6a46103e716f Mon Sep 17 00:00:00 2001 From: Leila Wang Date: Mon, 24 Feb 2020 16:59:29 +0000 Subject: [PATCH] fix(sdk): avoid race conditions when establishing connection --- .../src/background/tasks/acceptConnection.js | 2 + .../extension/src/client/Aztec/ApiManager.js | 30 ++++++++-- .../services/ConnectionService/index.js | 44 +++++++++++++- packages/extension/src/utils/Web3Service.js | 58 +++++++++++-------- 4 files changed, 102 insertions(+), 32 deletions(-) diff --git a/packages/extension/src/background/tasks/acceptConnection.js b/packages/extension/src/background/tasks/acceptConnection.js index a265fe32b..9a79a5f39 100644 --- a/packages/extension/src/background/tasks/acceptConnection.js +++ b/packages/extension/src/background/tasks/acceptConnection.js @@ -27,6 +27,7 @@ export default function acceptConnection() { window.addEventListener('message', async (event) => { if (event.data.type === connectionRequestEvent) { const { + requestId, clientProfile, } = event.data; @@ -61,6 +62,7 @@ export default function acceptConnection() { type: connectionApprovedEvent, code: '200', data: networkConfig, + requestId, }, '*', [channel.port2]); } }); diff --git a/packages/extension/src/client/Aztec/ApiManager.js b/packages/extension/src/client/Aztec/ApiManager.js index c3bdcec06..80e08d4f9 100644 --- a/packages/extension/src/client/Aztec/ApiManager.js +++ b/packages/extension/src/client/Aztec/ApiManager.js @@ -47,7 +47,7 @@ export default class ApiManager { } if (!this.autoRefreshOnProfileChange) { - this.disable(); + this.disable(this.currentOptions, true); } else if (!this.aztecAccount && !this.sessionPromise) { this.generateDefaultApis(); } @@ -59,6 +59,7 @@ export default class ApiManager { async generateDefaultApis() { const apis = await ApiPermissionService.generateApis(); this.setApis(apis); + return apis; } bindOneTimeProfileChangeListener(cb) { @@ -154,12 +155,14 @@ export default class ApiManager { return this.sessionPromise; }; - async disable(options = this.currentOptions) { - this.flushEnableListeners(options); + async disable(options = this.currentOptions, internalCall = false) { this.currentOptions = null; this.aztecAccount = null; this.error = null; - this.unbindOneTimeProfileChangeListener(); + this.flushEnableListeners(options); + if (!internalCall) { + this.unbindOneTimeProfileChangeListener(); + } await this.generateDefaultApis(); await ConnectionService.disconnect(); } @@ -167,10 +170,12 @@ export default class ApiManager { async refreshSession(options) { this.aztecAccount = null; this.error = null; - await this.generateDefaultApis(); + const defaultApis = await this.generateDefaultApis(); await ConnectionService.disconnect(); + const hasWalletPermission = !!(defaultApis.web3.account && defaultApis.web3.network); let networkSwitchedDuringStart = false; + let abort = false; this.bindOneTimeProfileChangeListener(() => { networkSwitchedDuringStart = true; @@ -190,6 +195,18 @@ export default class ApiManager { } = options; const tasks = [ + async () => { + await Web3Service.init({ + providerUrl, + }); + const { + account, + networkId, + } = Web3Service; + if (!hasWalletPermission || !account.address || !networkId) { + abort = true; + } + }, async () => ConnectionService.openConnection({ apiKey, providerUrl, @@ -205,7 +222,7 @@ export default class ApiManager { ApiPermissionService.validateContractConfigs(networkConfig); return networkConfig; }, - async networkConfig => Web3Service.init(networkConfig), + async ({ contractsConfig }) => Web3Service.registerContractsConfig(contractsConfig), async () => { const { account: aztecAccount, @@ -232,6 +249,7 @@ export default class ApiManager { let prevResult; await asyncForEach(tasks, async (task) => { if (networkSwitchedDuringStart + || abort || !isEqual(options, this.currentOptions) ) { return; diff --git a/packages/extension/src/client/services/ConnectionService/index.js b/packages/extension/src/client/services/ConnectionService/index.js index 7c7f42c5f..23598b2fe 100644 --- a/packages/extension/src/client/services/ConnectionService/index.js +++ b/packages/extension/src/client/services/ConnectionService/index.js @@ -36,6 +36,11 @@ import getApiKeyApproval from '~/client/utils/getApiKeyApproval'; import backgroundFrame from './backgroundFrame'; class ConnectionService { + constructor() { + this.callbackMapping = {}; + this.disconnectRequestId = null; + } + async init() { this.clientId = randomId(); this.setInitialVars(); @@ -65,20 +70,50 @@ class ConnectionService { async disconnect() { if (!this.port) return; + let requestId = this.disconnectRequestId; + if (requestId) { + await new Promise((resolve) => { + if (!this.callbackMapping[requestId]) { + this.callbackMapping[requestId] = { + [uiCloseEvent]: [], + }; + } + this.callbackMapping[requestId][uiCloseEvent].push(resolve); + }); + return; + } + + requestId = randomId(); + this.disconnectRequestId = requestId; + + const uiDisconnected = new Promise((resolve) => { + this.callbackMapping[requestId] = { + [uiCloseEvent]: [resolve], + }; + }); + await this.postToBackground({ type: clientDisconnectEvent, + requestId, }); + await uiDisconnected; + + this.disconnectRequestId = null; this.setInitialVars(); } async openConnection(clientProfile) { + if (this.port) { + errorLog('Connection to background has been established'); + } const { apiKey, } = clientProfile; this.apiKey = apiKey; const frame = await backgroundFrame.ensureCreated(); + const requestId = randomId(); const backgroundResponse = fromEvent(window, 'message') .pipe( @@ -89,7 +124,7 @@ class ConnectionService { frame.contentWindow.postMessage({ type: connectionRequestEvent, - requestId: randomId(), + requestId, clientId: this.clientId, sender: 'WEB_CLIENT', clientProfile, @@ -112,6 +147,7 @@ class ConnectionService { this.port.onmessage = ({ data }) => { const { type, + requestId, } = data; switch (type) { case uiOpenEvent: @@ -131,6 +167,9 @@ class ConnectionService { break; default: } + if (this.callbackMapping[requestId] && this.callbackMapping[requestId][type]) { + this.callbackMapping[requestId][type].forEach(cb => cb()); + } }; } @@ -174,6 +213,7 @@ class ConnectionService { async postToBackground({ type, data, + requestId: customRequestId, }) { if (!this.port) { return { @@ -181,7 +221,7 @@ class ConnectionService { }; } - const requestId = randomId(); + const requestId = customRequestId || randomId(); this.port.postMessage({ type, clientId: this.clientId, diff --git a/packages/extension/src/utils/Web3Service.js b/packages/extension/src/utils/Web3Service.js index 5f70e7256..c34d27351 100644 --- a/packages/extension/src/utils/Web3Service.js +++ b/packages/extension/src/utils/Web3Service.js @@ -7,6 +7,8 @@ import { class Web3Service { constructor() { + this.providerUrl = ''; + this.provider = null; this.web3 = null; this.eth = null; this.contracts = {}; @@ -29,44 +31,52 @@ class Web3Service { contractsConfig, account, } = {}) { - this.reset(); - - let provider; - if (providerUrl) { - if (providerUrl.match(/^wss?:\/\//)) { - provider = new Web3.providers.WebsocketProvider(providerUrl); + let { + web3, + provider, + } = this; + if (!web3 || providerUrl !== this.providerUrl) { + if (providerUrl) { + if (providerUrl.match(/^wss?:\/\//)) { + provider = new Web3.providers.WebsocketProvider(providerUrl); + } else { + provider = new Web3.providers.HttpProvider(providerUrl); + } } else { - provider = new Web3.providers.HttpProvider(providerUrl); + provider = window.ethereum; } - } else { - provider = window.ethereum; // TODO - to be removed // https://metamask.github.io/metamask-docs/API_Reference/Ethereum_Provider#ethereum.autorefreshonnetworkchange-(to-be-removed) provider.autoRefreshOnNetworkChange = false; - } - if (!provider) { - errorLog('Provider cannot be empty.'); - return; - } + if (!provider) { + errorLog('Provider cannot be empty.'); + return; + } - if (provider - && typeof provider.enable === 'function' - ) { - await provider.enable(); + if (typeof provider.enable === 'function') { + await provider.enable(); + } + + web3 = new Web3(provider); + + this.providerUrl = providerUrl; + this.provider = provider; + this.web3 = web3; + this.eth = web3.eth; } - const web3 = new Web3(provider); const networkId = await web3.eth.net.getId(); - - this.web3 = web3; - this.eth = web3.eth; this.networkId = networkId || 0; if (account) { this.account = account; } else { - const [address] = await web3.eth.getAccounts(); + let [address] = await web3.eth.getAccounts(); + if (!address && typeof provider.enable === 'function') { + await provider.enable(); + [address] = await web3.eth.getAccounts(); + } this.account = { address, }; @@ -370,7 +380,7 @@ class Web3Service { try { gasPrice = await web3.eth.getGasPrice(); } catch (e) { - console.log(e); + errorLog(e); } return this.triggerMethod('send', { method: gsnContract.methods[methodName],