Skip to content

Commit

Permalink
Avoid re-prompting a user login when refresh token is processed (#3749)
Browse files Browse the repository at this point in the history
* Refresh token refactor

* Change the if/else into a boolean flag

* Use validateToken to exchange refresh token in OAuth getFetcher (#3750)

* Add missing "$" in template

* Use validateToken to ensure refresh token is exchanged

---------

Co-authored-by: Garrett Stevens <stevens.garrett.j@gmail.com>
  • Loading branch information
cmdcolin and garrettjstevens committed Jun 8, 2023
1 parent d3219d6 commit 27484b9
Show file tree
Hide file tree
Showing 10 changed files with 107 additions and 124 deletions.
2 changes: 1 addition & 1 deletion package.json
Expand Up @@ -103,7 +103,7 @@
"cross-spawn": "^7.0.1",
"crypto-js": "^3.0.0",
"dependency-graph": "^0.11.0",
"electron": "25.0.1",
"electron": "25.1.0",
"electron-builder": "^23.0.3",
"electron-mock-ipc": "^0.3.8",
"eslint": "^8.0.0",
Expand Down
Expand Up @@ -210,12 +210,15 @@ export const InternetAccount = types
* #action
*/
addAuthHeaderToInit(init: RequestInit = {}, token: string) {
const newHeaders = new Headers(init.headers || {})
newHeaders.append(
self.authHeader,
self.tokenType ? `${self.tokenType} ${token}` : token,
)
return { ...init, headers: newHeaders }
return {
...init,
headers: new Headers({
...init.headers,
[self.authHeader]: self.tokenType
? `${self.tokenType} ${token}`
: token,
}),
}
},
/**
* #action
Expand Down
8 changes: 0 additions & 8 deletions plugins/authentication/src/DropboxOAuthModel/configSchema.ts
Expand Up @@ -51,14 +51,6 @@ const DropboxOAuthConfigSchema = ConfigurationSchema(
'getdropbox.com',
],
},
/**
* #slot
*/
hasRefreshToken: {
description: 'true if the endpoint can supply a refresh token',
type: 'boolean',
defaultValue: true,
},
},
{
/**
Expand Down
3 changes: 1 addition & 2 deletions plugins/authentication/src/DropboxOAuthModel/model.tsx
Expand Up @@ -97,8 +97,7 @@ const stateModelFactory = (
},
)
if (!response.ok) {
const refreshToken =
self.hasRefreshToken && self.retrieveRefreshToken()
const refreshToken = self.retrieveRefreshToken()
if (refreshToken) {
self.removeRefreshToken()
const newToken = await self.exchangeRefreshForAccessToken(
Expand Down
Expand Up @@ -17,7 +17,7 @@ export const HTTPBasicLoginForm = ({
open
maxWidth="xl"
data-testid="login-httpbasic"
title={`Log In for {internetAccountId}`}
title={`Log In for ${internetAccountId}`}
>
<form
onSubmit={event => {
Expand Down
11 changes: 2 additions & 9 deletions plugins/authentication/src/OAuthModel/configSchema.ts
Expand Up @@ -70,18 +70,11 @@ const OAuthConfigSchema = ConfigurationSchema(
* #slot
*/
responseType: {
description: 'the type of response from the authorization endpoint',
description:
"the type of response from the authorization endpoint. can be 'token' or 'code'",
type: 'string',
defaultValue: 'code',
},
/**
* #slot
*/
hasRefreshToken: {
description: 'true if the endpoint can supply a refresh token',
type: 'boolean',
defaultValue: false,
},
},
{
/**
Expand Down
148 changes: 60 additions & 88 deletions plugins/authentication/src/OAuthModel/model.tsx
Expand Up @@ -5,7 +5,12 @@ import { Instance, types } from 'mobx-state-tree'

// locals
import { OAuthInternetAccountConfigModel } from './configSchema'
import { fixup, generateChallenge } from './util'
import {
fixup,
generateChallenge,
processError,
processTokenResponse,
} from './util'
import { getResponseError } from '../util'

interface OAuthData {
Expand All @@ -19,14 +24,6 @@ interface OAuthData {
state?: string
}

interface OAuthExchangeData {
code: string
grant_type: string
client_id: string
redirect_uri: string
code_verifier?: string
}

/**
* #stateModel OAuthInternetAccount
*/
Expand Down Expand Up @@ -98,27 +95,22 @@ const stateModelFactory = (configSchema: OAuthInternetAccountConfigModel) => {
* Can override or extend if dynamic state is needed.
*/
state(): string | undefined {
return getConf(self, 'state') || undefined
return getConf(self, 'state')
},
/**
* #getter
*/
get responseType(): 'token' | 'code' {
return getConf(self, 'responseType')
},
/**
* #getter
*/
get hasRefreshToken(): boolean {
return getConf(self, 'hasRefreshToken')
},
/**
* #getter
*/
get refreshTokenKey() {
return `${self.internetAccountId}-refreshToken`
},
}))

.actions(self => ({
/**
* #action
Expand All @@ -145,17 +137,15 @@ const stateModelFactory = (configSchema: OAuthInternetAccountConfigModel) => {
token: string,
redirectUri: string,
): Promise<string> {
const data: OAuthExchangeData = {
code: token,
grant_type: 'authorization_code',
client_id: self.clientId,
redirect_uri: redirectUri,
}
if (self.needsPKCE) {
data.code_verifier = self.codeVerifierPKCE
}

const params = new URLSearchParams(Object.entries(data))
const params = new URLSearchParams(
Object.entries({
code: token,
grant_type: 'authorization_code',
client_id: self.clientId,
redirect_uri: redirectUri,
...(self.needsPKCE ? { code_verifier: self.codeVerifierPKCE } : {}),
}),
)

const response = await fetch(self.tokenEndpoint, {
method: 'POST',
Expand All @@ -172,55 +162,43 @@ const stateModelFactory = (configSchema: OAuthInternetAccountConfigModel) => {
)
}

const accessToken = await response.json()
if (accessToken.refresh_token) {
this.storeRefreshToken(accessToken.refresh_token)
}
return accessToken.access_token
const data = await response.json()
return processTokenResponse(data, token =>
this.storeRefreshToken(token),
)
},
/**
* #action
*/
async exchangeRefreshForAccessToken(
refreshToken: string,
): Promise<string> {
const data = {
grant_type: 'refresh_token',
refresh_token: refreshToken,
client_id: self.clientId,
}

const params = new URLSearchParams(Object.entries(data))

const response = await fetch(self.tokenEndpoint, {
method: 'POST',
headers: { 'Content-Type': 'application/x-www-form-urlencoded' },
body: params.toString(),
body: new URLSearchParams(
Object.entries({
grant_type: 'refresh_token',
refresh_token: refreshToken,
client_id: self.clientId,
}),
).toString(),
})

if (!response.ok) {
self.removeToken()
let text = await response.text()
try {
const obj = JSON.parse(text)
if (obj.error === 'invalid_grant') {
this.removeRefreshToken()
}
text = obj?.error_description ?? text
} catch (e) {
/* just use original text as error */
}

const text = await response.text()
throw new Error(
await getResponseError({ response, statusText: text }),
await getResponseError({
response,
statusText: processError(text, () => this.removeRefreshToken()),
}),
)
}

const accessToken = await response.json()
if (accessToken.refresh_token) {
this.storeRefreshToken(accessToken.refresh_token)
}
return accessToken.access_token
const data = await response.json()
return processTokenResponse(data, token =>
this.storeRefreshToken(token),
)
},
}))
.actions(self => {
Expand Down Expand Up @@ -286,10 +264,10 @@ const stateModelFactory = (configSchema: OAuthInternetAccountConfigModel) => {
)
self.storeToken(token)
return resolve(token)
} catch (error) {
return error instanceof Error
? reject(error)
: reject(new Error(String(error)))
} catch (e) {
return e instanceof Error
? reject(e)
: reject(new Error(String(e)))
}
}
if (redirectUriWithInfo.includes('access_denied')) {
Expand All @@ -316,6 +294,7 @@ const stateModelFactory = (configSchema: OAuthInternetAccountConfigModel) => {
client_id: self.clientId,
redirect_uri: redirectUri,
response_type: self.responseType || 'code',
token_access_type: 'offline',
}

if (self.state()) {
Expand All @@ -331,10 +310,6 @@ const stateModelFactory = (configSchema: OAuthInternetAccountConfigModel) => {
data.code_challenge_method = 'S256'
}

if (self.hasRefreshToken) {
data.token_access_type = 'offline'
}

const params = new URLSearchParams(Object.entries(data))

const url = new URL(self.authEndpoint)
Expand All @@ -356,8 +331,7 @@ const stateModelFactory = (configSchema: OAuthInternetAccountConfigModel) => {
// eslint-disable-next-line @typescript-eslint/no-floating-promises
this.finishOAuthWindow(eventFromDesktop, resolve, reject)
} else {
const options = `width=500,height=600,left=0,top=0`
window.open(url, eventName, options)
window.open(url, eventName, `width=500,height=600,left=0,top=0`)
}
},
/**
Expand All @@ -367,22 +341,29 @@ const stateModelFactory = (configSchema: OAuthInternetAccountConfigModel) => {
resolve: (token: string) => void,
reject: (error: Error) => void,
) {
const refreshToken =
self.hasRefreshToken && self.retrieveRefreshToken()
const refreshToken = self.retrieveRefreshToken()
let doUserFlow = true

// if there is a refresh token, then try it out, and only if that
// refresh token succeeds, set doUserFlow to false
if (refreshToken) {
try {
const token = await self.exchangeRefreshForAccessToken(
refreshToken,
)
resolve(token)
} catch (err) {
doUserFlow = false
} catch (e) {
console.error(e)
self.removeRefreshToken()
}
}
this.addMessageChannel(resolve, reject)
// may want to improve handling
// eslint-disable-next-line @typescript-eslint/no-floating-promises
this.useEndpointForAuthorization(resolve, reject)
if (doUserFlow) {
this.addMessageChannel(resolve, reject)
// may want to improve handling
// eslint-disable-next-line @typescript-eslint/no-floating-promises
this.useEndpointForAuthorization(resolve, reject)
}
},
/**
* #action
Expand All @@ -392,8 +373,7 @@ const stateModelFactory = (configSchema: OAuthInternetAccountConfigModel) => {
const response = await fetch(location.uri, newInit)
if (!response.ok) {
self.removeToken()
const refreshToken =
self.hasRefreshToken && self.retrieveRefreshToken()
const refreshToken = self.retrieveRefreshToken()
if (refreshToken) {
try {
if (!exchangedTokenPromise) {
Expand Down Expand Up @@ -436,17 +416,9 @@ const stateModelFactory = (configSchema: OAuthInternetAccountConfigModel) => {
const fetcher = superGetFetcher(loc)
return async (input: RequestInfo, init?: RequestInit) => {
if (loc) {
try {
await self.getPreAuthorizationInformation(loc)
} catch (e) {
/* ignore error */
}
}
const response = await fetcher(input, init)
if (!response.ok) {
throw new Error(await getResponseError({ response }))
await self.validateToken(await self.getToken(loc), loc)
}
return response
return fetcher(input, init)
}
},
}
Expand Down
24 changes: 24 additions & 0 deletions plugins/authentication/src/OAuthModel/util.ts
Expand Up @@ -7,3 +7,27 @@ export async function generateChallenge(val: string) {
const Base64 = await import('crypto-js/enc-base64')
return fixup(Base64.stringify(sha256(val)))
}

// if response is JSON, checks if it needs to remove tokens in error, or just plain throw
export function processError(text: string, invalidErrorCb: () => void) {
try {
const obj = JSON.parse(text)
if (obj.error === 'invalid_grant') {
invalidErrorCb()
}
return obj?.error_description ?? text
} catch (e) {
/* response text is not json, just use original text as error */
}
return text
}

export function processTokenResponse(
data: { refresh_token?: string; access_token: string },
storeRefreshTokenCb: (str: string) => void,
) {
if (data.refresh_token) {
storeRefreshTokenCb(data.refresh_token)
}
return data.access_token
}

0 comments on commit 27484b9

Please sign in to comment.