diff --git a/change/@azure-msal-browser-2020-08-11-12-05-23-msal-throttling.json b/change/@azure-msal-browser-2020-08-11-12-05-23-msal-throttling.json new file mode 100644 index 0000000000..894a096d93 --- /dev/null +++ b/change/@azure-msal-browser-2020-08-11-12-05-23-msal-throttling.json @@ -0,0 +1,8 @@ +{ + "type": "minor", + "comment": "Added client-side throttling to enhance server stability (#1907)", + "packageName": "@azure/msal-browser", + "email": "jamckenn@microsoft.com", + "dependentChangeType": "patch", + "date": "2020-08-11T19:04:58.604Z" +} diff --git a/change/@azure-msal-common-2020-08-11-12-05-23-msal-throttling.json b/change/@azure-msal-common-2020-08-11-12-05-23-msal-throttling.json new file mode 100644 index 0000000000..4ee85349ec --- /dev/null +++ b/change/@azure-msal-common-2020-08-11-12-05-23-msal-throttling.json @@ -0,0 +1,8 @@ +{ + "type": "minor", + "comment": "Added client-side throttling to enhance server stability (#1907)", + "packageName": "@azure/msal-common", + "email": "jamckenn@microsoft.com", + "dependentChangeType": "patch", + "date": "2020-08-11T19:05:23.196Z" +} diff --git a/lib/msal-browser/src/app/ClientApplication.ts b/lib/msal-browser/src/app/ClientApplication.ts index 7593325516..64832e7f54 100644 --- a/lib/msal-browser/src/app/ClientApplication.ts +++ b/lib/msal-browser/src/app/ClientApplication.ts @@ -5,7 +5,7 @@ import { CryptoOps } from "../crypto/CryptoOps"; import { BrowserStorage } from "../cache/BrowserStorage"; -import { Authority, TrustedAuthority, StringUtils, CacheSchemaType, UrlString, ServerAuthorizationCodeResponse, AuthorizationCodeRequest, AuthorizationUrlRequest, AuthorizationCodeClient, PromptValue, SilentFlowRequest, ServerError, InteractionRequiredAuthError, EndSessionRequest, AccountInfo, AuthorityFactory, ServerTelemetryManager, SilentFlowClient, ClientConfiguration, BaseAuthRequest, ServerTelemetryRequest, PersistentCacheKeys, IdToken, ProtocolUtils, ResponseMode, Constants, INetworkModule, AuthenticationResult, Logger } from "@azure/msal-common"; +import { Authority, TrustedAuthority, StringUtils, CacheSchemaType, UrlString, ServerAuthorizationCodeResponse, AuthorizationCodeRequest, AuthorizationUrlRequest, AuthorizationCodeClient, PromptValue, SilentFlowRequest, ServerError, InteractionRequiredAuthError, EndSessionRequest, AccountInfo, AuthorityFactory, ServerTelemetryManager, SilentFlowClient, ClientConfiguration, BaseAuthRequest, ServerTelemetryRequest, PersistentCacheKeys, IdToken, ProtocolUtils, ResponseMode, Constants, INetworkModule, AuthenticationResult, Logger, ThrottlingUtils } from "@azure/msal-common"; import { buildConfiguration, Configuration } from "../config/Configuration"; import { TemporaryCacheKeys, InteractionType, ApiId, BrowserConstants, DEFAULT_REQUEST } from "../utils/BrowserConstants"; import { BrowserUtils } from "../utils/BrowserUtils"; @@ -194,7 +194,7 @@ export abstract class ClientApplication { const currentAuthority = this.browserStorage.getCachedAuthority(); const authClient = await this.createAuthCodeClient(serverTelemetryManager, currentAuthority); const interactionHandler = new RedirectHandler(authClient, this.browserStorage); - return await interactionHandler.handleCodeResponse(responseHash, this.browserCrypto); + return await interactionHandler.handleCodeResponse(responseHash, this.browserCrypto, this.config.auth.clientId); } catch (e) { serverTelemetryManager.cacheFailedRequest(e); this.browserStorage.cleanRequest(); @@ -290,6 +290,9 @@ export abstract class ClientApplication { // Monitor the window for the hash. Return the string value and close the popup when the hash is received. Default timeout is 60 seconds. const hash = await interactionHandler.monitorPopupForHash(popupWindow, this.config.system.windowHashTimeout); + // Remove throttle if it exists + ThrottlingUtils.removeThrottle(this.browserStorage, this.config.auth.clientId, authCodeRequest.authority, authCodeRequest.scopes); + // Handle response from hash string. return await interactionHandler.handleCodeResponse(hash); } catch (e) { diff --git a/lib/msal-browser/src/cache/BrowserStorage.ts b/lib/msal-browser/src/cache/BrowserStorage.ts index e0116fe492..5a4db8ebca 100644 --- a/lib/msal-browser/src/cache/BrowserStorage.ts +++ b/lib/msal-browser/src/cache/BrowserStorage.ts @@ -2,7 +2,7 @@ * Copyright (c) Microsoft Corporation. All rights reserved. * Licensed under the MIT License. */ -import { Constants, PersistentCacheKeys, StringUtils, AuthorizationCodeRequest, ICrypto, CacheSchemaType, AccountEntity, IdTokenEntity, CredentialType, AccessTokenEntity, RefreshTokenEntity, AppMetadataEntity, CacheManager, CredentialEntity, ServerTelemetryCacheValue } from "@azure/msal-common"; +import { Constants, PersistentCacheKeys, StringUtils, AuthorizationCodeRequest, ICrypto, CacheSchemaType, AccountEntity, IdTokenEntity, CredentialType, AccessTokenEntity, RefreshTokenEntity, AppMetadataEntity, CacheManager, CredentialEntity, ServerTelemetryCacheValue, ThrottlingEntity } from "@azure/msal-common"; import { CacheOptions } from "../config/Configuration"; import { BrowserAuthError } from "../error/BrowserAuthError"; import { BrowserConfigurationAuthError } from "../error/BrowserConfigurationAuthError"; @@ -114,6 +114,7 @@ export class BrowserStorage extends CacheManager { case CacheSchemaType.ACCOUNT: case CacheSchemaType.CREDENTIAL: case CacheSchemaType.APP_METADATA: + case CacheSchemaType.THROTTLING: this.windowStorage.setItem(key, JSON.stringify(value)); break; case CacheSchemaType.TEMPORARY: { @@ -169,6 +170,9 @@ export class BrowserStorage extends CacheManager { case CacheSchemaType.APP_METADATA: { return (JSON.parse(value) as AppMetadataEntity); } + case CacheSchemaType.THROTTLING: { + return (JSON.parse(value) as ThrottlingEntity); + } case CacheSchemaType.TEMPORARY: { const itemCookie = this.getItemCookie(key); if (this.cacheConfig.storeAuthStateInCookie) { diff --git a/lib/msal-browser/src/interaction_handler/RedirectHandler.ts b/lib/msal-browser/src/interaction_handler/RedirectHandler.ts index 5cfa23e2d4..8fcaffdf41 100644 --- a/lib/msal-browser/src/interaction_handler/RedirectHandler.ts +++ b/lib/msal-browser/src/interaction_handler/RedirectHandler.ts @@ -2,7 +2,7 @@ * Copyright (c) Microsoft Corporation. All rights reserved. * Licensed under the MIT License. */ -import { StringUtils, AuthorizationCodeRequest, ICrypto, CacheSchemaType, AuthenticationResult } from "@azure/msal-common"; +import { StringUtils, AuthorizationCodeRequest, ICrypto, CacheSchemaType, AuthenticationResult, ThrottlingUtils } from "@azure/msal-common"; import { InteractionHandler } from "./InteractionHandler"; import { BrowserAuthError } from "../error/BrowserAuthError"; import { BrowserConstants, TemporaryCacheKeys } from "../utils/BrowserConstants"; @@ -46,7 +46,7 @@ export class RedirectHandler extends InteractionHandler { * Handle authorization code response in the window. * @param hash */ - async handleCodeResponse(locationHash: string, browserCrypto?: ICrypto): Promise { + async handleCodeResponse(locationHash: string, browserCrypto?: ICrypto, clientId?: string): Promise { // Check that location hash isn't empty. if (StringUtils.isEmpty(locationHash)) { throw BrowserAuthError.createEmptyHashError(locationHash); @@ -65,6 +65,11 @@ export class RedirectHandler extends InteractionHandler { this.authCodeRequest = this.browserStorage.getCachedRequest(requestState, browserCrypto); this.authCodeRequest.code = authCode; + // Remove throttle if it exists + if (clientId) { + ThrottlingUtils.removeThrottle(this.browserStorage, clientId, this.authCodeRequest.authority, this.authCodeRequest.scopes); + } + // Acquire token with retrieved code. const tokenResponse = await this.authModule.acquireToken(this.authCodeRequest, cachedNonce, requestState); diff --git a/lib/msal-browser/test/cache/BrowserStorage.spec.ts b/lib/msal-browser/test/cache/BrowserStorage.spec.ts index 5c42bea57f..aecabdfd57 100644 --- a/lib/msal-browser/test/cache/BrowserStorage.spec.ts +++ b/lib/msal-browser/test/cache/BrowserStorage.spec.ts @@ -495,6 +495,6 @@ describe("BrowserStorage() tests", () => { // Perform test const tokenRequest = browserStorage.getCachedRequest(RANDOM_TEST_GUID, browserCrypto); expect(tokenRequest.authority).to.be.eq(alternateAuthority); - }); + }); }); }); diff --git a/lib/msal-common/src/cache/entities/ThrottlingEntity.ts b/lib/msal-common/src/cache/entities/ThrottlingEntity.ts new file mode 100644 index 0000000000..f5f9d741f7 --- /dev/null +++ b/lib/msal-common/src/cache/entities/ThrottlingEntity.ts @@ -0,0 +1,35 @@ +/* + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the MIT License. + */ +import { ThrottlingConstants } from "../../utils/Constants"; + +export class ThrottlingEntity { + // Unix-time value representing the expiration of the throttle + throttleTime: number; + // Information provided by the server + error?: string; + errorCodes?: Array; + errorMessage?: string; + subError?: string; + + /** + * validates if a given cache entry is "Throttling", parses + * @param key + * @param entity + */ + static isThrottlingEntity(key: string, entity?: object): boolean { + + let validateKey: boolean = false; + if (key) { + validateKey = key.indexOf(ThrottlingConstants.THROTTLING_PREFIX) === 0; + } + + let validateEntity: boolean = true; + if (entity) { + validateEntity = entity.hasOwnProperty("throttleTime"); + } + + return validateKey && validateEntity; + } +} diff --git a/lib/msal-common/src/cache/utils/CacheTypes.ts b/lib/msal-common/src/cache/utils/CacheTypes.ts index e2bb6de336..bba10b9ef7 100644 --- a/lib/msal-common/src/cache/utils/CacheTypes.ts +++ b/lib/msal-common/src/cache/utils/CacheTypes.ts @@ -9,6 +9,7 @@ import { AccessTokenEntity } from "../entities/AccessTokenEntity"; import { RefreshTokenEntity } from "../entities/RefreshTokenEntity"; import { AppMetadataEntity } from "../entities/AppMetadataEntity"; import { ServerTelemetryEntity } from "../entities/ServerTelemetryEntity"; +import { ThrottlingEntity } from "../entities/ThrottlingEntity"; export type AccountCache = Record; export type IdTokenCache = Record; @@ -21,7 +22,7 @@ export type CredentialCache = { refreshTokens: RefreshTokenCache; }; -export type ValidCacheType = AccountEntity | IdTokenEntity | AccessTokenEntity | RefreshTokenEntity | AppMetadataEntity | ServerTelemetryEntity | string; +export type ValidCacheType = AccountEntity | IdTokenEntity | AccessTokenEntity | RefreshTokenEntity | AppMetadataEntity | ServerTelemetryEntity | ThrottlingEntity | string; /** * Account: -- diff --git a/lib/msal-common/src/client/AuthorizationCodeClient.ts b/lib/msal-common/src/client/AuthorizationCodeClient.ts index 1e05e2490d..ab20fd31da 100644 --- a/lib/msal-common/src/client/AuthorizationCodeClient.ts +++ b/lib/msal-common/src/client/AuthorizationCodeClient.ts @@ -22,6 +22,7 @@ import { ServerAuthorizationCodeResponse } from "../response/ServerAuthorization import { AccountEntity } from "../cache/entities/AccountEntity"; import { EndSessionRequest } from "../request/EndSessionRequest"; import { ClientConfigurationError } from "../error/ClientConfigurationError"; +import { RequestThumbprint } from "../network/RequestThumbprint"; /** * Oauth2.0 Authorization Code client @@ -130,10 +131,16 @@ export class AuthorizationCodeClient extends BaseClient { * @param request */ private async executeTokenRequest(authority: Authority, request: AuthorizationCodeRequest): Promise> { + const thumbprint: RequestThumbprint = { + clientId: this.config.authOptions.clientId, + authority: authority.canonicalAuthority, + scopes: request.scopes + }; + const requestBody = this.createTokenRequestBody(request); const headers: Record = this.createDefaultTokenRequestHeaders(); - return this.executePostToTokenEndpoint(authority.tokenEndpoint, requestBody, headers); + return this.executePostToTokenEndpoint(authority.tokenEndpoint, requestBody, headers, thumbprint); } /** diff --git a/lib/msal-common/src/client/BaseClient.ts b/lib/msal-common/src/client/BaseClient.ts index 89c779fcc7..3a79d85d57 100644 --- a/lib/msal-common/src/client/BaseClient.ts +++ b/lib/msal-common/src/client/BaseClient.ts @@ -5,15 +5,16 @@ import { ClientConfiguration, buildClientConfiguration } from "../config/ClientConfiguration"; import { INetworkModule } from "../network/INetworkModule"; +import { NetworkManager, NetworkResponse } from "../network/NetworkManager"; import { ICrypto } from "../crypto/ICrypto"; import { Authority } from "../authority/Authority"; import { Logger } from "../logger/Logger"; import { AADServerParamKeys, Constants, HeaderNames } from "../utils/Constants"; -import { NetworkResponse } from "../network/NetworkManager"; import { ServerAuthorizationTokenResponse } from "../response/ServerAuthorizationTokenResponse"; import { TrustedAuthority } from "../authority/TrustedAuthority"; import { CacheManager } from "../cache/CacheManager"; import { ServerTelemetryManager } from "../telemetry/server/ServerTelemetryManager"; +import { RequestThumbprint } from "../network/RequestThumbprint"; /** * Base application class which will construct requests to send to and handle responses from the Microsoft STS using the authorization code flow. @@ -37,6 +38,9 @@ export abstract class BaseClient { // Server Telemetry Manager protected serverTelemetryManager: ServerTelemetryManager; + // Network Manager + protected networkManager: NetworkManager; + // Default authority object protected authority: Authority; @@ -56,6 +60,9 @@ export abstract class BaseClient { // Set the network interface this.networkClient = this.config.networkInterface; + // Set the NetworkManager + this.networkManager = new NetworkManager(this.networkClient, this.cacheManager); + // Set TelemetryManager this.serverTelemetryManager = this.config.serverTelemetryManager; @@ -70,6 +77,7 @@ export abstract class BaseClient { protected createDefaultTokenRequestHeaders(): Record { const headers = this.createDefaultLibraryHeaders(); headers[HeaderNames.CONTENT_TYPE] = Constants.URL_FORM_CONTENT_TYPE; + headers[HeaderNames.X_MS_LIB_CAPABILITY] = HeaderNames.X_MS_LIB_CAPABILITY_VALUE; if (this.serverTelemetryManager) { headers[HeaderNames.X_CLIENT_CURR_TELEM] = this.serverTelemetryManager.generateCurrentRequestHeaderValue(); @@ -99,14 +107,14 @@ export abstract class BaseClient { * @param tokenEndpoint * @param queryString * @param headers + * @param thumbprint */ - protected async executePostToTokenEndpoint(tokenEndpoint: string, queryString: string, headers: Record): Promise> { - const response = await this.networkClient.sendPostRequestAsync< - ServerAuthorizationTokenResponse - >(tokenEndpoint, { - body: queryString, - headers: headers, - }); + protected async executePostToTokenEndpoint(tokenEndpoint: string, queryString: string, headers: Record, thumbprint: RequestThumbprint): Promise> { + const response = await this.networkManager.sendPostRequest( + thumbprint, + tokenEndpoint, + { body: queryString, headers: headers } + ); if (this.config.serverTelemetryManager && response.status < 500 && response.status !== 429) { // Telemetry data successfully logged by server, clear Telemetry cache diff --git a/lib/msal-common/src/client/ClientCredentialClient.ts b/lib/msal-common/src/client/ClientCredentialClient.ts index 60b2607de3..fed669f126 100644 --- a/lib/msal-common/src/client/ClientCredentialClient.ts +++ b/lib/msal-common/src/client/ClientCredentialClient.ts @@ -17,6 +17,7 @@ import { CredentialType } from "../utils/Constants"; import { AccessTokenEntity } from "../cache/entities/AccessTokenEntity"; import { TimeUtils } from "../utils/TimeUtils"; import { StringUtils } from "../utils/StringUtils"; +import { RequestThumbprint } from "../network/RequestThumbprint"; /** * OAuth2.0 client credential grant @@ -81,8 +82,13 @@ export class ClientCredentialClient extends BaseClient { const requestBody = this.createTokenRequestBody(request); const headers: Record = this.createDefaultTokenRequestHeaders(); + const thumbprint: RequestThumbprint = { + clientId: this.config.authOptions.clientId, + authority: request.authority, + scopes: request.scopes + }; - const response = await this.executePostToTokenEndpoint(authority.tokenEndpoint, requestBody, headers); + const response = await this.executePostToTokenEndpoint(authority.tokenEndpoint, requestBody, headers, thumbprint); const responseHandler = new ResponseHandler( this.config.authOptions.clientId, diff --git a/lib/msal-common/src/client/DeviceCodeClient.ts b/lib/msal-common/src/client/DeviceCodeClient.ts index e209d7480f..35e49b0c3d 100644 --- a/lib/msal-common/src/client/DeviceCodeClient.ts +++ b/lib/msal-common/src/client/DeviceCodeClient.ts @@ -16,6 +16,7 @@ import { ScopeSet } from "../request/ScopeSet"; import { ResponseHandler } from "../response/ResponseHandler"; import { AuthenticationResult } from "../response/AuthenticationResult"; import { StringUtils } from "../utils/StringUtils"; +import { RequestThumbprint } from "../network/RequestThumbprint"; /** * OAuth2.0 Device code client @@ -61,11 +62,15 @@ export class DeviceCodeClient extends BaseClient { * @param request */ private async getDeviceCode(request: DeviceCodeRequest): Promise { - const queryString = this.createQueryString(request); const headers = this.createDefaultLibraryHeaders(); + const thumbprint: RequestThumbprint = { + clientId: this.config.authOptions.clientId, + authority: request.authority, + scopes: request.scopes + }; - return this.executePostRequestToDeviceCodeEndpoint(this.authority.deviceCodeEndpoint, queryString, headers); + return this.executePostRequestToDeviceCodeEndpoint(this.authority.deviceCodeEndpoint, queryString, headers, thumbprint); } /** @@ -77,7 +82,8 @@ export class DeviceCodeClient extends BaseClient { private async executePostRequestToDeviceCodeEndpoint( deviceCodeEndpoint: string, queryString: string, - headers: Record): Promise { + headers: Record, + thumbprint: RequestThumbprint): Promise { const { body: { @@ -88,7 +94,8 @@ export class DeviceCodeClient extends BaseClient { interval, message } - } = await this.networkClient.sendPostRequestAsync( + } = await this.networkManager.sendPostRequest( + thumbprint, deviceCodeEndpoint, { body: queryString, @@ -157,10 +164,16 @@ export class DeviceCodeClient extends BaseClient { reject(ClientAuthError.createDeviceCodeExpiredError()); } else { + const thumbprint: RequestThumbprint = { + clientId: this.config.authOptions.clientId, + authority: request.authority, + scopes: request.scopes + }; const response = await this.executePostToTokenEndpoint( this.authority.tokenEndpoint, requestBody, - headers); + headers, + thumbprint); if (response.body && response.body.error == Constants.AUTHORIZATION_PENDING) { // user authorization is pending. Sleep for polling interval and try again diff --git a/lib/msal-common/src/client/RefreshTokenClient.ts b/lib/msal-common/src/client/RefreshTokenClient.ts index 1ebc98fc41..018ae73c55 100644 --- a/lib/msal-common/src/client/RefreshTokenClient.ts +++ b/lib/msal-common/src/client/RefreshTokenClient.ts @@ -6,7 +6,7 @@ import { ClientConfiguration } from "../config/ClientConfiguration"; import { BaseClient } from "./BaseClient"; import { RefreshTokenRequest } from "../request/RefreshTokenRequest"; -import { Authority, NetworkResponse } from ".."; +import { Authority } from ".."; import { ServerAuthorizationTokenResponse } from "../response/ServerAuthorizationTokenResponse"; import { RequestParameterBuilder } from "../request/RequestParameterBuilder"; import { ScopeSet } from "../request/ScopeSet"; @@ -14,6 +14,8 @@ import { GrantType } from "../utils/Constants"; import { ResponseHandler } from "../response/ResponseHandler"; import { AuthenticationResult } from "../response/AuthenticationResult"; import { StringUtils } from "../utils/StringUtils"; +import { RequestThumbprint } from "../network/RequestThumbprint"; +import { NetworkResponse } from "../network/NetworkManager"; /** * OAuth2.0 refresh token client @@ -45,11 +47,15 @@ export class RefreshTokenClient extends BaseClient { private async executeTokenRequest(request: RefreshTokenRequest, authority: Authority) : Promise> { - const requestBody = this.createTokenRequestBody(request); const headers: Record = this.createDefaultTokenRequestHeaders(); + const thumbprint: RequestThumbprint = { + clientId: this.config.authOptions.clientId, + authority: authority.canonicalAuthority, + scopes: request.scopes + }; - return this.executePostToTokenEndpoint(authority.tokenEndpoint, requestBody, headers); + return this.executePostToTokenEndpoint(authority.tokenEndpoint, requestBody, headers, thumbprint); } private createTokenRequestBody(request: RefreshTokenRequest): string { diff --git a/lib/msal-common/src/index.ts b/lib/msal-common/src/index.ts index e5d1426ade..9422fb1c89 100644 --- a/lib/msal-common/src/index.ts +++ b/lib/msal-common/src/index.ts @@ -25,9 +25,12 @@ export { AccountEntity } from "./cache/entities/AccountEntity"; export { IdTokenEntity } from "./cache/entities/IdTokenEntity"; export { AccessTokenEntity } from "./cache/entities/AccessTokenEntity"; export { RefreshTokenEntity } from "./cache/entities/RefreshTokenEntity"; +export { ThrottlingEntity } from "./cache/entities/ThrottlingEntity"; // Network Interface export { INetworkModule, NetworkRequestOptions } from "./network/INetworkModule"; -export { NetworkResponse } from "./network/NetworkManager"; +export { NetworkManager, NetworkResponse } from "./network/NetworkManager"; +export { ThrottlingUtils } from "./network/ThrottlingUtils"; +export { RequestThumbprint } from "./network/RequestThumbprint"; export { IUri } from "./url/IUri"; export { UrlString } from "./url/UrlString"; // Crypto Interface diff --git a/lib/msal-common/src/network/INetworkModule.ts b/lib/msal-common/src/network/INetworkModule.ts index 567c7ad8d2..788498bb65 100644 --- a/lib/msal-common/src/network/INetworkModule.ts +++ b/lib/msal-common/src/network/INetworkModule.ts @@ -3,7 +3,7 @@ * Licensed under the MIT License. */ -import {NetworkResponse} from "./NetworkManager"; +import { NetworkResponse } from "./NetworkManager"; /** * Options allowed by network request APIs. diff --git a/lib/msal-common/src/network/NetworkManager.ts b/lib/msal-common/src/network/NetworkManager.ts index d560a64b45..68c44e01d4 100644 --- a/lib/msal-common/src/network/NetworkManager.ts +++ b/lib/msal-common/src/network/NetworkManager.ts @@ -3,11 +3,39 @@ * Licensed under the MIT License. */ +import { INetworkModule, NetworkRequestOptions } from "./INetworkModule"; +import { RequestThumbprint } from "./RequestThumbprint"; +import { ThrottlingUtils } from "./ThrottlingUtils"; +import { CacheManager } from "../cache/CacheManager"; + export type NetworkResponse = { headers: Record; body: T; status: number; }; -// TODO placeholder: this will be filled in by the throttling PR -export class NetworkManager {} +export class NetworkManager { + private networkClient: INetworkModule; + private cacheManager: CacheManager; + + constructor(networkClient: INetworkModule, cacheManager: CacheManager) { + this.networkClient = networkClient; + this.cacheManager = cacheManager; + } + + /** + * Wraps sendPostRequestAsync with necessary preflight and postflight logic + * @param thumbprint + * @param tokenEndpoint + * @param options + */ + async sendPostRequest(thumbprint: RequestThumbprint, tokenEndpoint: string, options: NetworkRequestOptions): Promise> { + ThrottlingUtils.preProcess(this.cacheManager, thumbprint); + const response = await this.networkClient.sendPostRequestAsync(tokenEndpoint, options); + ThrottlingUtils.postProcess(this.cacheManager, thumbprint, response); + + // Placeholder for Telemetry hook + + return response; + } +} diff --git a/lib/msal-common/src/network/RequestThumbprint.ts b/lib/msal-common/src/network/RequestThumbprint.ts new file mode 100644 index 0000000000..d792acca14 --- /dev/null +++ b/lib/msal-common/src/network/RequestThumbprint.ts @@ -0,0 +1,14 @@ +/* + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the MIT License. + */ + +/** + * Type representing a unique request thumbprint. + */ +export type RequestThumbprint = { + clientId: string; + authority: string; + scopes: Array; + homeAccountIdentifier?: string; +}; diff --git a/lib/msal-common/src/network/ThrottlingUtils.ts b/lib/msal-common/src/network/ThrottlingUtils.ts new file mode 100644 index 0000000000..b3a36d6da8 --- /dev/null +++ b/lib/msal-common/src/network/ThrottlingUtils.ts @@ -0,0 +1,110 @@ +/* + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the MIT License. + */ + +import { NetworkResponse } from "./NetworkManager"; +import { ServerAuthorizationTokenResponse } from "../response/ServerAuthorizationTokenResponse"; +import { HeaderNames, CacheSchemaType, ThrottlingConstants } from "../utils/Constants"; +import { CacheManager } from "../cache/CacheManager"; +import { ServerError } from "../error/ServerError"; +import { RequestThumbprint } from "./RequestThumbprint"; +import { ThrottlingEntity } from "../cache/entities/ThrottlingEntity"; + +export class ThrottlingUtils { + + /** + * Prepares a RequestThumbprint to be stored as a key. + * @param thumbprint + */ + static generateThrottlingStorageKey(thumbprint: RequestThumbprint): string { + return `${ThrottlingConstants.THROTTLING_PREFIX}.${JSON.stringify(thumbprint)}`; + } + + /** + * Performs necessary throttling checks before a network request. + * @param cacheManager + * @param thumbprint + */ + static preProcess(cacheManager: CacheManager, thumbprint: RequestThumbprint): void { + const key = ThrottlingUtils.generateThrottlingStorageKey(thumbprint); + const value = cacheManager.getItem(key, CacheSchemaType.THROTTLING) as ThrottlingEntity; + + if (value) { + if (value.throttleTime < Date.now()) { + cacheManager.removeItem(key, CacheSchemaType.THROTTLING); + return; + } + throw new ServerError(value.errorCodes.join(" "), value.errorMessage, value.subError); + } + } + + /** + * Performs necessary throttling checks after a network request. + * @param cacheManager + * @param thumbprint + * @param response + */ + static postProcess(cacheManager: CacheManager, thumbprint: RequestThumbprint, response: NetworkResponse): void { + if (ThrottlingUtils.checkResponseStatus(response) || ThrottlingUtils.checkResponseForRetryAfter(response)) { + const thumbprintValue: ThrottlingEntity = { + throttleTime: ThrottlingUtils.calculateThrottleTime(parseInt(response.headers[HeaderNames.RETRY_AFTER])), + error: response.body.error, + errorCodes: response.body.error_codes, + errorMessage: response.body.error_description, + subError: response.body.suberror + }; + cacheManager.setItem( + ThrottlingUtils.generateThrottlingStorageKey(thumbprint), + thumbprintValue, + CacheSchemaType.THROTTLING + ); + } + } + + /** + * Checks a NetworkResponse object's status codes against 429 or 5xx + * @param response + */ + static checkResponseStatus(response: NetworkResponse): boolean { + return response.status == 429 || response.status >= 500 && response.status < 600; + } + + /** + * Checks a NetworkResponse object's RetryAfter header + * @param response + */ + static checkResponseForRetryAfter(response: NetworkResponse): boolean { + if (response.headers) { + return response.headers.hasOwnProperty(HeaderNames.RETRY_AFTER) && (response.status < 200 || response.status >= 300); + } + return false; + } + + /** + * Calculates the Unix-time value for a throttle to expire given throttleTime in seconds. + * @param throttleTime + */ + static calculateThrottleTime(throttleTime: number): number { + if(throttleTime <= 0) { + throttleTime = null; + } + const currentSeconds = Date.now() / 1000; + return Math.floor(Math.min( + currentSeconds + (throttleTime || ThrottlingConstants.DEFAULT_THROTTLE_TIME_SECONDS), + currentSeconds + ThrottlingConstants.DEFAULT_MAX_THROTTLE_TIME_SECONDS + ) * 1000); + } + + static removeThrottle(cacheManager: CacheManager, clientId: string, authority: string, scopes: Array, homeAccountIdentifier?: string): boolean { + const thumbprint: RequestThumbprint = { + clientId, + authority, + scopes, + homeAccountIdentifier + }; + + const key = this.generateThrottlingStorageKey(thumbprint); + return cacheManager.removeItem(key, CacheSchemaType.THROTTLING); + } +} diff --git a/lib/msal-common/src/utils/Constants.ts b/lib/msal-common/src/utils/Constants.ts index d370a21e6f..40dd3695fc 100644 --- a/lib/msal-common/src/utils/Constants.ts +++ b/lib/msal-common/src/utils/Constants.ts @@ -43,7 +43,10 @@ export const Constants = { export enum HeaderNames { CONTENT_TYPE = "Content-Type", X_CLIENT_CURR_TELEM = "x-client-current-telemetry", - X_CLIENT_LAST_TELEM = "x-client-last-telemetry" + X_CLIENT_LAST_TELEM = "x-client-last-telemetry", + RETRY_AFTER = "Retry-After", + X_MS_LIB_CAPABILITY = "x-ms-lib-capability", + X_MS_LIB_CAPABILITY_VALUE = "retry-after, h429" } /** @@ -220,7 +223,7 @@ export enum Separators { } /** - * Credentail Type stored in the cache + * Credential Type stored in the cache */ export enum CredentialType { ID_TOKEN = "IdToken", @@ -229,7 +232,7 @@ export enum CredentialType { } /** - * Credentail Type stored in the cache + * Credential Type stored in the cache */ export enum CacheSchemaType { ACCOUNT = "Account", @@ -240,7 +243,8 @@ export enum CacheSchemaType { APP_METADATA = "AppMetadata", TEMPORARY = "TempCache", TELEMETRY = "Telemetry", - UNDEFINED = "Undefined" + UNDEFINED = "Undefined", + THROTTLING = "Throttling" } /** @@ -271,3 +275,15 @@ export const SERVER_TELEM_CONSTANTS = { CATEGORY_SEPARATOR: "|", VALUE_SEPARATOR: "," }; + +/** + * Constants related to throttling + */ +export const ThrottlingConstants = { + // Default time to throttle RequestThumbprint in seconds + DEFAULT_THROTTLE_TIME_SECONDS: 60, + // Default maximum time to throttle in seconds, overrides what the server sends back + DEFAULT_MAX_THROTTLE_TIME_SECONDS: 3600, + // Prefix for storing throttling entries + THROTTLING_PREFIX: "throttling" +}; diff --git a/lib/msal-common/src/utils/StringUtils.ts b/lib/msal-common/src/utils/StringUtils.ts index b424b11534..51059d9e02 100644 --- a/lib/msal-common/src/utils/StringUtils.ts +++ b/lib/msal-common/src/utils/StringUtils.ts @@ -86,4 +86,16 @@ export class StringUtils { return !StringUtils.isEmpty(entry); }); } + + /** + * Attempts to parse a string into JSON + * @param str + */ + static jsonParseHelper(str: string): T { + try { + return JSON.parse(str) as T; + } catch (e) { + return null; + } + } } diff --git a/lib/msal-common/test/cache/entities/ThrottlingEntity.spec.ts b/lib/msal-common/test/cache/entities/ThrottlingEntity.spec.ts new file mode 100644 index 0000000000..1432342f1a --- /dev/null +++ b/lib/msal-common/test/cache/entities/ThrottlingEntity.spec.ts @@ -0,0 +1,45 @@ +/* + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the MIT License. + */ + +import "mocha" +import { expect } from "chai"; +import { ThrottlingEntity } from "../../../src/cache/entities/ThrottlingEntity"; +import { ThrottlingConstants, Separators } from "../../../src/utils/Constants"; +import { TEST_CONFIG } from "../../utils/StringConstants"; + +describe("ThrottlingEntity", () => { + describe("isThrottlingEntity", () => { + + const key = ThrottlingConstants.THROTTLING_PREFIX + Separators.CACHE_KEY_SEPARATOR + TEST_CONFIG.MSAL_CLIENT_ID; + it("Verifies if an object is a ThrottlingEntity", () => { + const throttlingObject = { + throttleTime: 0 + } + expect(ThrottlingEntity.isThrottlingEntity(key, throttlingObject)).to.be.true; + + }); + + it("Verifies if an object is a ThrottlingEntity when no object is given", () => { + expect(ThrottlingEntity.isThrottlingEntity(key)).to.be.true; + expect(ThrottlingEntity.isThrottlingEntity(key, null)).to.be.true; + }); + + it("Verifies if an object is not a ThrottlingEntity based on field", () => { + const throttlingObject = { + test: 0 + } + expect(ThrottlingEntity.isThrottlingEntity(key, throttlingObject)).to.be.false; + }); + + it("Verifies if an object is not a ThrottlingEntity based on key", () => { + const throttlingObject = { + throttleTime: 0 + } + expect(ThrottlingEntity.isThrottlingEntity("asd", throttlingObject)).to.be.false; + expect(ThrottlingEntity.isThrottlingEntity("", throttlingObject)).to.be.false; + expect(ThrottlingEntity.isThrottlingEntity(null, throttlingObject)).to.be.false; + }) + }); +}); diff --git a/lib/msal-common/test/network/NetworkManager.spec.ts b/lib/msal-common/test/network/NetworkManager.spec.ts new file mode 100644 index 0000000000..f471bc55b9 --- /dev/null +++ b/lib/msal-common/test/network/NetworkManager.spec.ts @@ -0,0 +1,126 @@ +/* + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the MIT License. + */ + +import { expect } from "chai"; +import sinon from "sinon"; +import { ThrottlingUtils } from "../../src/network/ThrottlingUtils"; +import { RequestThumbprint } from "../../src/network/RequestThumbprint"; +import { NetworkManager, NetworkResponse } from "../../src/network/NetworkManager"; +import { ServerAuthorizationTokenResponse } from "../../src/response/ServerAuthorizationTokenResponse"; +import { MockStorageClass } from "../client/ClientTestUtils"; +import { NetworkRequestOptions } from "../../src/network/INetworkModule"; +import { ServerError } from "../../src/error/ServerError"; +import { AUTHENTICATION_RESULT, NETWORK_REQUEST_OPTIONS, THUMBPRINT, THROTTLING_ENTITY, DEFAULT_NETWORK_IMPLEMENTATION } from "../utils/StringConstants"; + +describe("NetworkManager", () => { + describe("sendPostRequest", () => { + afterEach(() => { + sinon.restore(); + }); + + it("returns a response", async () => { + const networkInterface = DEFAULT_NETWORK_IMPLEMENTATION; + const cache = new MockStorageClass(); + const networkManager = new NetworkManager(networkInterface, cache); + const thumbprint: RequestThumbprint = THUMBPRINT; + const options: NetworkRequestOptions = NETWORK_REQUEST_OPTIONS; + const mockRes: NetworkResponse = { + headers: { }, + body: AUTHENTICATION_RESULT.body, + status: 200 + } + const networkStub = sinon.stub(networkInterface, "sendPostRequestAsync").returns(Promise.resolve(mockRes)); + const getItemStub = sinon.stub(cache, "getItem"); + const setItemStub = sinon.stub(cache, "setItem"); + const removeItemStub = sinon.stub(cache, "removeItem"); + sinon.stub(Date, "now").callsFake(() => 1) + + const res = await networkManager.sendPostRequest>(thumbprint, "tokenEndpoint", options); + + sinon.assert.callCount(networkStub, 1); + sinon.assert.callCount(getItemStub, 1); + sinon.assert.callCount(setItemStub, 0); + sinon.assert.callCount(removeItemStub, 0); + expect(res).to.deep.eq(mockRes); + }); + + it("blocks the request if item is found in the cache", async () => { + const networkInterface = DEFAULT_NETWORK_IMPLEMENTATION; + const cache = new MockStorageClass(); + const networkManager = new NetworkManager(networkInterface, cache); + const thumbprint: RequestThumbprint = THUMBPRINT; + const options: NetworkRequestOptions = NETWORK_REQUEST_OPTIONS; + const mockThrottlingEntity = THROTTLING_ENTITY; + const networkStub = sinon.stub(networkInterface, "sendPostRequestAsync"); + const getItemStub = sinon.stub(cache, "getItem").returns(mockThrottlingEntity); + const setItemStub = sinon.stub(cache, "setItem"); + const removeItemStub = sinon.stub(cache, "removeItem"); + sinon.stub(Date, "now").callsFake(() => 1) + + try { + await networkManager.sendPostRequest>(thumbprint, "tokenEndpoint", options); + } catch { } + + sinon.assert.callCount(networkStub, 0); + sinon.assert.callCount(getItemStub, 1); + sinon.assert.callCount(setItemStub, 0); + sinon.assert.callCount(removeItemStub, 0); + expect(() => ThrottlingUtils.preProcess(cache, thumbprint)).to.throw(ServerError); + }); + + it("passes request through if expired item in cache", async () => { + const networkInterface = DEFAULT_NETWORK_IMPLEMENTATION; + const cache = new MockStorageClass(); + const networkManager = new NetworkManager(networkInterface, cache); + const thumbprint: RequestThumbprint = THUMBPRINT; + const options: NetworkRequestOptions = NETWORK_REQUEST_OPTIONS; + const mockRes: NetworkResponse = { + headers: { }, + body: AUTHENTICATION_RESULT.body, + status: 200 + } + const mockThrottlingEntity = THROTTLING_ENTITY; + const networkStub = sinon.stub(networkInterface, "sendPostRequestAsync").returns(Promise.resolve(mockRes)); + const getItemStub = sinon.stub(cache, "getItem").returns(mockThrottlingEntity); + const setItemStub = sinon.stub(cache, "setItem"); + const removeItemStub = sinon.stub(cache, "removeItem"); + sinon.stub(Date, "now").callsFake(() => 10) + + const res = await networkManager.sendPostRequest>(thumbprint, "tokenEndpoint", options); + + sinon.assert.callCount(networkStub, 1); + sinon.assert.callCount(getItemStub, 1); + sinon.assert.callCount(setItemStub, 0); + sinon.assert.callCount(removeItemStub, 1); + expect(res).to.deep.eq(mockRes); + }); + + it("creates cache entry on error", async () => { + const networkInterface = DEFAULT_NETWORK_IMPLEMENTATION; + const cache = new MockStorageClass(); + const networkManager = new NetworkManager(networkInterface, cache); + const thumbprint: RequestThumbprint = THUMBPRINT; + const options: NetworkRequestOptions = NETWORK_REQUEST_OPTIONS; + const mockRes: NetworkResponse = { + headers: { }, + body: AUTHENTICATION_RESULT.body, + status: 500 + } + const networkStub = sinon.stub(networkInterface, "sendPostRequestAsync").returns(Promise.resolve(mockRes)); + const getItemStub = sinon.stub(cache, "getItem"); + const setItemStub = sinon.stub(cache, "setItem"); + const removeItemStub = sinon.stub(cache, "removeItem"); + sinon.stub(Date, "now").callsFake(() => 1) + + const res = await networkManager.sendPostRequest>(thumbprint, "tokenEndpoint", options); + + sinon.assert.callCount(networkStub, 1); + sinon.assert.callCount(getItemStub, 1); + sinon.assert.callCount(setItemStub, 1); + sinon.assert.callCount(removeItemStub, 0); + expect(res).to.deep.eq(mockRes); + }); + }); +}); diff --git a/lib/msal-common/test/network/ThrottlingUtils.spec.ts b/lib/msal-common/test/network/ThrottlingUtils.spec.ts new file mode 100644 index 0000000000..09c51692c7 --- /dev/null +++ b/lib/msal-common/test/network/ThrottlingUtils.spec.ts @@ -0,0 +1,261 @@ +/* + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the MIT License. + */ + +import { expect } from "chai"; +import sinon from "sinon"; +import { ThrottlingUtils } from "../../src/network/ThrottlingUtils"; +import { RequestThumbprint } from "../../src/network/RequestThumbprint"; +import { ThrottlingEntity } from "../../src/cache/entities/ThrottlingEntity"; +import { NetworkResponse } from "../../src/network/NetworkManager"; +import { ServerAuthorizationTokenResponse } from "../../src/response/ServerAuthorizationTokenResponse"; +import { MockStorageClass } from "../client/ClientTestUtils"; +import { ServerError } from "../../src"; +import { THUMBPRINT, THROTTLING_ENTITY, TEST_CONFIG } from "../utils/StringConstants"; + +describe("ThrottlingUtils", () => { + describe("generateThrottlingStorageKey", () => { + it("returns a throttling key", () => { + const thumbprint: RequestThumbprint = THUMBPRINT; + const jsonString = JSON.stringify(thumbprint); + const key = ThrottlingUtils.generateThrottlingStorageKey(thumbprint); + + expect(key).to.deep.eq(`throttling.${jsonString}`); + }); + }); + + describe("preProcess", () => { + afterEach(() => { + sinon.restore(); + }) + + it("checks the cache and throws an error", () => { + const thumbprint: RequestThumbprint = THUMBPRINT; + const thumbprintValue: ThrottlingEntity = THROTTLING_ENTITY; + const cache = new MockStorageClass(); + const removeItemStub = sinon.stub(cache, "removeItem"); + sinon.stub(cache, "getItem").callsFake(() => thumbprintValue); + sinon.stub(Date, "now").callsFake(() => 1) + + try { + ThrottlingUtils.preProcess(cache, thumbprint); + } catch { } + sinon.assert.callCount(removeItemStub, 0) + + expect(() => ThrottlingUtils.preProcess(cache, thumbprint)).to.throw(ServerError); + }); + + it("checks the cache and removes an item", () => { + const thumbprint: RequestThumbprint = THUMBPRINT; + const thumbprintValue: ThrottlingEntity = THROTTLING_ENTITY; + const cache = new MockStorageClass(); + const removeItemStub = sinon.stub(cache, "removeItem"); + sinon.stub(cache, "getItem").callsFake(() => thumbprintValue); + sinon.stub(Date, "now").callsFake(() => 10) + + ThrottlingUtils.preProcess(cache, thumbprint); + sinon.assert.callCount(removeItemStub, 1) + + expect(() => ThrottlingUtils.preProcess(cache, thumbprint)).to.not.throw; + }); + + it("checks the cache and does nothing with no match", () => { + const thumbprint: RequestThumbprint = THUMBPRINT; + const cache = new MockStorageClass(); + const removeItemStub = sinon.stub(cache, "removeItem"); + sinon.stub(cache, "getItem").callsFake(() => null); + + ThrottlingUtils.preProcess(cache, thumbprint); + sinon.assert.callCount(removeItemStub, 0) + + expect(() => ThrottlingUtils.preProcess(cache, thumbprint)).to.not.throw; + }); + }); + + describe("postProcess", () => { + afterEach(() => { + sinon.restore(); + }); + + it("sets an item in the cache", () => { + const thumbprint: RequestThumbprint = THUMBPRINT; + const res: NetworkResponse = { + headers: { }, + body: { }, + status: 429 + }; + const cache = new MockStorageClass(); + const setItemStub = sinon.stub(cache, "setItem"); + + ThrottlingUtils.postProcess(cache, thumbprint, res); + sinon.assert.callCount(setItemStub, 1); + }); + + it("does not set an item in the cache", () => { + const thumbprint: RequestThumbprint = THUMBPRINT; + const res: NetworkResponse = { + headers: { }, + body: { }, + status: 200 + }; + const cache = new MockStorageClass(); + const setItemStub = sinon.stub(cache, "setItem"); + + ThrottlingUtils.postProcess(cache, thumbprint, res); + sinon.assert.callCount(setItemStub, 0); + }); + }) + + describe("checkResponseStatus", () => { + it("returns true if status == 429", () => { + const res: NetworkResponse = { + headers: { }, + body: { }, + status: 429 + }; + + const bool = ThrottlingUtils.checkResponseStatus(res); + expect(bool).to.be.true; + }); + + it("returns true if 500 <= status < 600", () => { + const res: NetworkResponse = { + headers: { }, + body: { }, + status: 500 + }; + + const bool = ThrottlingUtils.checkResponseStatus(res); + expect(bool).to.be.true; + }); + + it("returns false if status is not 429 or between 500 and 600", () => { + const res: NetworkResponse = { + headers: { }, + body: { }, + status: 430 + }; + + const bool = ThrottlingUtils.checkResponseStatus(res); + expect(bool).to.be.false; + }); + }); + + describe("checkResponseForRetryAfter", () => { + it("returns true when Retry-After header exists and when status <= 200", () => { + const headers: Record = { }; + headers["Retry-After"] = "test"; + const res: NetworkResponse = { + headers, + body: { }, + status: 199 + }; + + const bool = ThrottlingUtils.checkResponseForRetryAfter(res); + expect(bool).to.be.true; + }); + + it("returns true when Retry-After header exists and when status > 300", () => { + const headers: Record = { }; + headers["Retry-After"] = "test"; + const res: NetworkResponse = { + headers, + body: { }, + status: 300 + }; + + const bool = ThrottlingUtils.checkResponseForRetryAfter(res); + expect(bool).to.be.true; + }); + + it("returns false when there is no RetryAfter header", () => { + const headers: Record = { }; + const res: NetworkResponse = { + headers, + body: { }, + status: 301 + }; + + const bool = ThrottlingUtils.checkResponseForRetryAfter(res); + expect(bool).to.be.false; + }); + + it("returns false when 200 <= status < 300", () => { + const headers: Record = { }; + const res: NetworkResponse = { + headers, + body: { }, + status: 200 + }; + + const bool = ThrottlingUtils.checkResponseForRetryAfter(res); + expect(bool).to.be.false; + }); + }); + + describe("calculateThrottleTime", () => { + before(() => { + sinon.stub(Date, "now").callsFake(() => 5000) + }); + + after(() => { + sinon.restore(); + }); + + it("returns calculated time to throttle", () => { + const time = ThrottlingUtils.calculateThrottleTime(10); + expect(time).to.be.deep.eq(15000); + }); + + it("calculates with the default time given a bad number", () => { + const time1 = ThrottlingUtils.calculateThrottleTime(-1); + const time2 = ThrottlingUtils.calculateThrottleTime(0); + const time3 = ThrottlingUtils.calculateThrottleTime(null); + + // Based on Constants.DEFAULT_THROTTLE_TIME_SECONDS + expect(time1).to.be.deep.eq(65000); + expect(time2).to.be.deep.eq(65000); + expect(time3).to.be.deep.eq(65000); + }); + + it("calculates with the default MAX if given too large of a number", () => { + const time = ThrottlingUtils.calculateThrottleTime(1000000000); + + // Based on Constants.DEFAULT_MAX_THROTTLE_TIME_SECONDS + expect(time).to.be.deep.eq(3605000); + }); + }); + + describe("removeThrottle", () => { + after(() => { + sinon.restore(); + }); + + it("removes the entry from storage and returns true", () => { + const cache = new MockStorageClass(); + const removeItemStub = sinon.stub(cache, "removeItem").returns(true); + const clientId = TEST_CONFIG.MSAL_CLIENT_ID; + const authority = TEST_CONFIG.validAuthority; + const scopes = TEST_CONFIG.DEFAULT_SCOPES; + + const res = ThrottlingUtils.removeThrottle(cache, clientId, authority, scopes); + + sinon.assert.callCount(removeItemStub, 1); + expect(res).to.be.true; + }) + + it("doesn't find an entry and returns false", () => { + const cache = new MockStorageClass(); + const removeItemStub = sinon.stub(cache, "removeItem").returns(false); + const clientId = TEST_CONFIG.MSAL_CLIENT_ID; + const authority = TEST_CONFIG.validAuthority; + const scopes = TEST_CONFIG.DEFAULT_SCOPES; + + const res = ThrottlingUtils.removeThrottle(cache, clientId, authority, scopes); + + sinon.assert.callCount(removeItemStub, 1); + expect(res).to.be.false; + }); + }); +}); diff --git a/lib/msal-common/test/utils/StringConstants.ts b/lib/msal-common/test/utils/StringConstants.ts index 6d16b07a5a..676c91c0c4 100644 --- a/lib/msal-common/test/utils/StringConstants.ts +++ b/lib/msal-common/test/utils/StringConstants.ts @@ -3,6 +3,8 @@ */ import { Constants } from "../../src/utils/Constants"; +import { RequestThumbprint, ThrottlingEntity } from "../../src"; +import { NetworkRequestOptions } from "../../src/network/INetworkModule"; // Test Tokens export const TEST_TOKENS = { @@ -276,3 +278,31 @@ export const AUTHORIZATION_PENDING_RESPONSE = { error_uri: 'https://login.microsoftonline.com/error?code=70016' } }; + +export const DEFAULT_NETWORK_IMPLEMENTATION = { + sendGetRequestAsync: async (url: string, options?: NetworkRequestOptions): Promise => { + return { test: "test" }; + }, + sendPostRequestAsync: async (url: string, options?: NetworkRequestOptions): Promise => { + return { test: "test" }; + } +} + +export const THUMBPRINT: RequestThumbprint = { + clientId: TEST_CONFIG.MSAL_CLIENT_ID, + authority: TEST_CONFIG.validAuthority, + scopes: TEST_CONFIG.DEFAULT_SCOPES +}; + +export const THROTTLING_ENTITY: ThrottlingEntity = { + throttleTime: 5, + error: "This is a error", + errorCodes: ["ErrorCode"], + errorMessage:"This is an errorMessage", + subError: "This is a subError" +}; + +export const NETWORK_REQUEST_OPTIONS: NetworkRequestOptions = { + headers: { }, + body: "" +}; diff --git a/lib/msal-common/test/utils/StringUtils.spec.ts b/lib/msal-common/test/utils/StringUtils.spec.ts index ea6232ceab..895587c682 100644 --- a/lib/msal-common/test/utils/StringUtils.spec.ts +++ b/lib/msal-common/test/utils/StringUtils.spec.ts @@ -131,4 +131,20 @@ describe("StringUtils.ts Class Unit Tests", () => { it("removeEmptyStringsFromArray() removes empty strings from an array", () => { }); + + describe("jsonParseHelper", () => { + it("parses json", () => { + const test = { test: "json" }; + const jsonString = JSON.stringify(test); + const parsedVal = StringUtils.jsonParseHelper(jsonString); + expect(parsedVal).to.be.deep.eq(test); + }); + + it("returns null on error", () => { + const parsedValNull = StringUtils.jsonParseHelper(null); + const parsedValEmptyString = StringUtils.jsonParseHelper(""); + expect(parsedValNull).to.be.null; + expect(parsedValEmptyString).to.be.null; + }) + }); });