Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[msal-browser][msal-common] Make library state an object and add timestamp (Updated) #1790

Merged
merged 9 commits into from
Jun 23, 2020
2 changes: 1 addition & 1 deletion lib/msal-browser/src/app/PublicClientApplication.ts
Original file line number Diff line number Diff line change
Expand Up @@ -610,7 +610,7 @@ export class PublicClientApplication {

validatedRequest.state = ProtocolUtils.setRequestState(
(request && request.state) || "",
this.browserCrypto.createNewGuid()
this.browserCrypto
);

validatedRequest.correlationId = (request && request.correlationId) || this.browserCrypto.createNewGuid();
Expand Down
6 changes: 1 addition & 5 deletions lib/msal-browser/src/cache/BrowserStorage.ts
Original file line number Diff line number Diff line change
Expand Up @@ -367,11 +367,7 @@ export class BrowserStorage extends CacheManager {
// check state and remove associated cache items
this.getKeys().forEach(key => {
if (!StringUtils.isEmpty(state) && key.indexOf(state) !== -1) {
const splitKey = key.split(Constants.RESOURCE_DELIM);
const keyState = splitKey.length > 1 ? splitKey[splitKey.length-1]: null;
if (keyState === state) {
this.removeItem(key);
}
this.removeItem(key);
}
});

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the MIT License.
*/
import { SPAClient, StringUtils, AuthorizationCodeRequest, ProtocolUtils, CacheSchemaType, AuthenticationResult } from "@azure/msal-common";
import { SPAClient, StringUtils, AuthorizationCodeRequest, CacheSchemaType, AuthenticationResult } from "@azure/msal-common";
import { BrowserStorage } from "../cache/BrowserStorage";
import { BrowserAuthError } from "../error/BrowserAuthError";
import { TemporaryCacheKeys } from "../utils/BrowserConstants";
Expand Down Expand Up @@ -48,11 +48,8 @@ export abstract class InteractionHandler {
// Assign code to request
this.authCodeRequest.code = authCode;

// Extract user state.
const userState = ProtocolUtils.getUserRequestState(requestState);

// Acquire token with retrieved code.
const tokenResponse = await this.authModule.acquireToken(this.authCodeRequest, userState, cachedNonce);
const tokenResponse = await this.authModule.acquireToken(this.authCodeRequest, requestState, cachedNonce);
this.browserStorage.cleanRequest();
return tokenResponse;
}
Expand Down
7 changes: 2 additions & 5 deletions lib/msal-browser/src/interaction_handler/RedirectHandler.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the MIT License.
*/
import { StringUtils, AuthorizationCodeRequest, ICrypto, ProtocolUtils, CacheSchemaType, AuthenticationResult } from "@azure/msal-common";
import { StringUtils, AuthorizationCodeRequest, ICrypto, CacheSchemaType, AuthenticationResult } from "@azure/msal-common";
import { InteractionHandler } from "./InteractionHandler";
import { BrowserAuthError } from "../error/BrowserAuthError";
import { BrowserConstants, TemporaryCacheKeys } from "../utils/BrowserConstants";
Expand Down Expand Up @@ -68,11 +68,8 @@ export class RedirectHandler extends InteractionHandler {
// Hash was processed successfully - remove from cache
this.browserStorage.removeItem(this.browserStorage.generateCacheKey(TemporaryCacheKeys.URL_HASH));

// Extract user state.
const userState = ProtocolUtils.getUserRequestState(requestState);

// Acquire token with retrieved code.
const tokenResponse = await this.authModule.acquireToken(this.authCodeRequest, userState, cachedNonce);
const tokenResponse = await this.authModule.acquireToken(this.authCodeRequest, requestState, cachedNonce);
this.browserStorage.cleanRequest();
return tokenResponse;
}
Expand Down
11 changes: 6 additions & 5 deletions lib/msal-common/src/cache/CacheManager.ts
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,8 @@ export abstract class CacheManager implements ICacheManager {
return null;
} else {
const allAccounts = accountValues.map<AccountInfo>((value) => {
const accountObj: AccountEntity = JSON.parse(JSON.stringify(value));
let accountObj: AccountEntity = new AccountEntity();
accountObj = CacheManager.toObject(accountObj, JSON.parse(JSON.stringify(value)));
return accountObj.getAccountInfo();
});
return allAccounts;
Expand Down Expand Up @@ -462,19 +463,19 @@ export abstract class CacheManager implements ICacheManager {
}

export class DefaultStorageClass extends CacheManager {
setItem(key: string, value: string | object, type?: string): void {
setItem(): void {
const notImplErr = "Storage interface - setItem() has not been implemented for the cacheStorage interface.";
throw AuthError.createUnexpectedError(notImplErr);
}
getItem(key: string, type?: string): string | object {
getItem(): string | object {
const notImplErr = "Storage interface - getItem() has not been implemented for the cacheStorage interface.";
throw AuthError.createUnexpectedError(notImplErr);
}
removeItem(key: string, type?: string): boolean {
removeItem(): boolean {
const notImplErr = "Storage interface - removeItem() has not been implemented for the cacheStorage interface.";
throw AuthError.createUnexpectedError(notImplErr);
}
containsKey(key: string, type?: string): boolean {
containsKey(): boolean {
const notImplErr = "Storage interface - containsKey() has not been implemented for the cacheStorage interface.";
throw AuthError.createUnexpectedError(notImplErr);
}
Expand Down
5 changes: 1 addition & 4 deletions lib/msal-common/src/client/AuthorizationCodeClient.ts
Original file line number Diff line number Diff line change
Expand Up @@ -59,10 +59,7 @@ export class AuthorizationCodeClient extends BaseClient {
);

responseHandler.validateTokenResponse(response.body);
const tokenResponse = responseHandler.generateAuthenticationResult(
response.body,
this.defaultAuthority
);
const tokenResponse = responseHandler.generateAuthenticationResult(response.body, this.defaultAuthority);

return tokenResponse;
}
Expand Down
10 changes: 5 additions & 5 deletions lib/msal-common/src/client/SPAClient.ts
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ export class SPAClient extends BaseClient {
* also use the handleFragmentResponse() API to pass the codeResponse to this function afterwards.
* @param codeResponse
*/
async acquireToken(codeRequest: AuthorizationCodeRequest, userState: string, cachedNonce: string): Promise<AuthenticationResult> {
async acquireToken(codeRequest: AuthorizationCodeRequest, cachedState: string, cachedNonce: string): Promise<AuthenticationResult> {
// If no code response is given, we cannot acquire a token.
if (!codeRequest || StringUtils.isEmpty(codeRequest.code)) {
throw ClientAuthError.createTokenRequestCannotBeMadeError();
Expand Down Expand Up @@ -191,7 +191,7 @@ export class SPAClient extends BaseClient {
// User helper to retrieve token response.
// Need to await function call before return to catch any thrown errors.
// if errors are thrown asynchronously in return statement, they are caught by caller of this function instead.
return await this.getTokenResponse(tokenEndpoint, parameterBuilder, acquireTokenAuthority, cachedNonce, userState);
return await this.getTokenResponse(tokenEndpoint, parameterBuilder, acquireTokenAuthority, cachedNonce, cachedState);
}

/**
Expand Down Expand Up @@ -418,7 +418,7 @@ export class SPAClient extends BaseClient {
* @param tokenRequest
* @param codeResponse
*/
private async getTokenResponse(tokenEndpoint: string, parameterBuilder: RequestParameterBuilder, authority: Authority, cachedNonce?: string, userState?: string): Promise<AuthenticationResult> {
private async getTokenResponse(tokenEndpoint: string, parameterBuilder: RequestParameterBuilder, authority: Authority, cachedNonce?: string, cachedState?: string): Promise<AuthenticationResult> {
// Perform token request.
const acquiredTokenResponse = await this.networkClient.sendPostRequestAsync<
ServerAuthorizationTokenResponse
Expand All @@ -432,8 +432,8 @@ export class SPAClient extends BaseClient {
// Validate response. This function throws a server error if an error is returned by the server.
responseHandler.validateTokenResponse(acquiredTokenResponse.body);
// Return token response with given parameters
const tokenResponse = responseHandler.generateAuthenticationResult(acquiredTokenResponse.body, authority, cachedNonce);
tokenResponse.state = userState;
const tokenResponse = responseHandler.generateAuthenticationResult(acquiredTokenResponse.body, authority, cachedNonce, cachedState);
tokenResponse.state = cachedState;

return tokenResponse;
}
Expand Down
13 changes: 13 additions & 0 deletions lib/msal-common/src/error/ClientAuthError.ts
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,10 @@ export const ClientAuthErrorMessage = {
code: "blank_guid_generated",
desc: "The guid generated was blank. Please review the trace to determine the root cause."
},
invalidStateError: {
code: "invalid_state",
desc: "State was not valid, please check the network trace."
pkanher617 marked this conversation as resolved.
Show resolved Hide resolved
},
stateMismatchError: {
code: "state_mismatch",
desc: "State mismatch error. Please check your network. Continued requests may cause cache overflow."
Expand Down Expand Up @@ -195,6 +199,15 @@ export class ClientAuthError extends AuthError {
`${ClientAuthErrorMessage.hashNotDeserialized.desc} Given Object: ${hashParamObj}`);
}

/**
* Creates an error thrown when the state cannot be parsed.
* @param invalidState
*/
static createInvalidStateError(invalidState: string, errorString?: string): ClientAuthError {
return new ClientAuthError(ClientAuthErrorMessage.invalidStateError.code,
`${ClientAuthErrorMessage.invalidStateError.desc} Invalid State: ${invalidState}, Root Err: ${errorString}`);
}

/**
* Creates an error thrown when two states do not match.
*/
Expand Down
34 changes: 19 additions & 15 deletions lib/msal-common/src/response/ResponseHandler.ts
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import { InteractionRequiredAuthError } from "../error/InteractionRequiredAuthEr
import { CacheRecord } from "../cache/entities/CacheRecord";
import { EnvironmentAliases, PreferredCacheEnvironment } from "../utils/Constants";
import { CacheManager } from "../cache/CacheManager";
import { ProtocolUtils, LibraryStateObject } from "../utils/ProtocolUtils";

/**
* Class that handles response parsing.
Expand All @@ -49,11 +50,7 @@ export class ResponseHandler {
* @param cachedState
* @param cryptoObj
*/
validateServerAuthorizationCodeResponse(
serverResponseHash: ServerAuthorizationCodeResponse,
cachedState: string,
cryptoObj: ICrypto
): void {
validateServerAuthorizationCodeResponse(serverResponseHash: ServerAuthorizationCodeResponse, cachedState: string, cryptoObj: ICrypto): void {
if (serverResponseHash.state !== cachedState) {
throw ClientAuthError.createStateMismatchError();
}
Expand All @@ -76,9 +73,7 @@ export class ResponseHandler {
* Function which validates server authorization token response.
* @param serverResponse
*/
validateTokenResponse(
serverResponse: ServerAuthorizationTokenResponse
): void {
validateTokenResponse(serverResponse: ServerAuthorizationTokenResponse): void {
// Check for error
if (serverResponse.error || serverResponse.error_description || serverResponse.suberror) {
if (InteractionRequiredAuthError.isInteractionRequiredError(serverResponse.error, serverResponse.error_description, serverResponse.suberror)) {
Expand All @@ -103,7 +98,7 @@ export class ResponseHandler {
* @param serverTokenResponse
* @param authority
*/
generateAuthenticationResult(serverTokenResponse: ServerAuthorizationTokenResponse, authority: Authority, cachedNonce?: string): AuthenticationResult {
generateAuthenticationResult(serverTokenResponse: ServerAuthorizationTokenResponse, authority: Authority, cachedNonce?: string, cachedState?: string): AuthenticationResult {
// create an idToken object (not entity)
const idTokenObj = new IdToken(serverTokenResponse.id_token, this.cryptoObj);

Expand All @@ -115,7 +110,8 @@ export class ResponseHandler {
}

// save the response tokens
const cacheRecord = this.generateCacheRecord(serverTokenResponse, idTokenObj, authority);
const requestStateObj = ProtocolUtils.parseRequestState(cachedState, this.cryptoObj);
const cacheRecord = this.generateCacheRecord(serverTokenResponse, idTokenObj, authority, requestStateObj.libraryState);
const responseScopes = ScopeSet.fromString(serverTokenResponse.scope);
this.cacheStorage.saveCacheRecord(cacheRecord, responseScopes);

Expand All @@ -131,6 +127,7 @@ export class ResponseHandler {
expiresOn: new Date(cacheRecord.accessToken.expiresOn),
extExpiresOn: new Date(cacheRecord.accessToken.extendedExpiresOn),
familyId: serverTokenResponse.foci || null,
state: requestStateObj.userRequestState
};

return authenticationResult;
Expand Down Expand Up @@ -165,7 +162,7 @@ export class ResponseHandler {
* @param idTokenObj
* @param authority
*/
generateCacheRecord(serverTokenResponse: ServerAuthorizationTokenResponse, idTokenObj: IdToken, authority: Authority): CacheRecord {
generateCacheRecord(serverTokenResponse: ServerAuthorizationTokenResponse, idTokenObj: IdToken, authority: Authority, libraryState?: LibraryStateObject): CacheRecord {
// Account
const cachedAccount = this.generateAccountEntity(
serverTokenResponse,
Expand All @@ -188,8 +185,15 @@ export class ResponseHandler {
// AccessToken
const responseScopes = ScopeSet.fromString(serverTokenResponse.scope);
// Expiration calculation
const expiresInSeconds = TimeUtils.nowSeconds() + serverTokenResponse.expires_in;
const extendedExpiresInSeconds = expiresInSeconds + serverTokenResponse.ext_expires_in;
const currentTime = TimeUtils.nowSeconds();
let tokenExpirationSeconds = serverTokenResponse.expires_in;
if (!libraryState) {
tokenExpirationSeconds += currentTime;
} else {
console.log(libraryState.ts);
pkanher617 marked this conversation as resolved.
Show resolved Hide resolved
tokenExpirationSeconds += libraryState.ts;
}
const extendedTokenExpirationSeconds = tokenExpirationSeconds + serverTokenResponse.ext_expires_in;

const cachedAccessToken = AccessTokenEntity.createAccessTokenEntity(
this.homeAccountIdentifier,
Expand All @@ -198,8 +202,8 @@ export class ResponseHandler {
this.clientId,
idTokenObj.claims.tid,
responseScopes.asArray().join(" "),
expiresInSeconds,
extendedExpiresInSeconds
tokenExpirationSeconds,
extendedTokenExpirationSeconds
);

// refreshToken
Expand Down
60 changes: 47 additions & 13 deletions lib/msal-common/src/utils/ProtocolUtils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,19 @@
*/
import { StringUtils } from "./StringUtils";
import { Constants } from "./Constants";
import { ICrypto } from "../crypto/ICrypto";
import { TimeUtils } from "./TimeUtils";
import { ClientAuthError } from "../error/ClientAuthError";

export type LibraryStateObject = {
id: string,
ts: number
};

export type RequestStateObject = {
userRequestState: string,
libraryState: LibraryStateObject
};

/**
* Class which provides helpers for OAuth 2.0 protocol specific values
Expand All @@ -15,23 +28,44 @@ export class ProtocolUtils {
* @param userState
* @param randomGuid
*/
static setRequestState(userState: string, randomGuid: string): string {
return !StringUtils.isEmpty(userState) ? `${randomGuid}${Constants.RESOURCE_DELIM}${userState}` : randomGuid;
static setRequestState(userState: string, cryptoObj: ICrypto): string {
const libraryState = ProtocolUtils.generateLibraryState(cryptoObj);
return !StringUtils.isEmpty(userState) ? `${libraryState}${Constants.RESOURCE_DELIM}${userState}` : libraryState;
}

/**
*
* Extracts user state value from the state sent with the authentication request.
* @returns {string} scope.
* @ignore
* Generates the state value used by the library.
* @param randomGuid
* @param cryptoObj
*/
static getUserRequestState(serverResponseState: string): string {
if (!StringUtils.isEmpty(serverResponseState)) {
const splitIndex = serverResponseState.indexOf(Constants.RESOURCE_DELIM);
if (splitIndex > -1 && splitIndex + 1 < serverResponseState.length) {
return serverResponseState.substring(splitIndex + 1);
}
static generateLibraryState(cryptoObj: ICrypto): string {
const stateObj: LibraryStateObject = {
id: cryptoObj.createNewGuid(),
ts: TimeUtils.nowSeconds()
};

const stateString = JSON.stringify(stateObj);

return cryptoObj.base64Encode(stateString);
}

static parseRequestState(state: string, cryptoObj: ICrypto): RequestStateObject {
if (StringUtils.isEmpty(state)) {
throw ClientAuthError.createInvalidStateError(state, "Null, undefined or empty state");
}

try {
const splitState = decodeURIComponent(state).split(Constants.RESOURCE_DELIM);
const libraryState = splitState[0];
const userState = splitState.length > 1 ? splitState.slice(1).join(Constants.RESOURCE_DELIM) : "";
pkanher617 marked this conversation as resolved.
Show resolved Hide resolved
const libraryStateString = cryptoObj.base64Decode(libraryState);
const libraryStateObj = JSON.parse(libraryStateString) as LibraryStateObject;
return {
userRequestState: !StringUtils.isEmpty(userState) ? userState : "",
libraryState: libraryStateObj
};
} catch(e) {
throw ClientAuthError.createInvalidStateError(state, e);
}
return "";
}
}