diff --git a/auth0/src/main/java/com/auth0/android/provider/OAuthManager.java b/auth0/src/main/java/com/auth0/android/provider/OAuthManager.java index f089b17a1..20fb4fe68 100644 --- a/auth0/src/main/java/com/auth0/android/provider/OAuthManager.java +++ b/auth0/src/main/java/com/auth0/android/provider/OAuthManager.java @@ -70,6 +70,7 @@ class OAuthManager extends ResumableManager { private CustomTabsOptions ctOptions; private Integer idTokenVerificationLeeway; private String idTokenVerificationIssuer; + private Map headers; OAuthManager(@NonNull Auth0 account, @NonNull AuthCallback callback, @NonNull Map parameters, @NonNull CustomTabsOptions ctOptions) { this.account = account; @@ -77,6 +78,7 @@ class OAuthManager extends ResumableManager { this.parameters = new HashMap<>(parameters); this.apiClient = new AuthenticationAPIClient(account); this.ctOptions = ctOptions; + this.headers = new HashMap<>(); } void useFullScreen(boolean useFullScreen) { @@ -101,7 +103,7 @@ void setIdTokenVerificationIssuer(String issuer) { } void startAuthentication(Activity activity, String redirectUri, int requestCode) { - addPKCEParameters(parameters, redirectUri); + addPKCEParameters(parameters, redirectUri, headers); addClientParameters(parameters, redirectUri); addValidationParameters(parameters); Uri uri = buildAuthorizeUri(); @@ -114,6 +116,10 @@ void startAuthentication(Activity activity, String redirectUri, int requestCode) } } + void setHeaders(@NonNull Map headers) { + this.headers.putAll(headers); + } + @SuppressWarnings("ConstantConditions") @Override boolean resume(AuthorizeResult result) { @@ -305,12 +311,12 @@ private Uri buildAuthorizeUri() { return uri; } - private void addPKCEParameters(Map parameters, String redirectUri) { + private void addPKCEParameters(Map parameters, String redirectUri, Map headers) { if (!shouldUsePKCE()) { return; } try { - createPKCE(redirectUri); + createPKCE(redirectUri, headers); String codeChallenge = pkce.getCodeChallenge(); parameters.put(KEY_CODE_CHALLENGE, codeChallenge); parameters.put(KEY_CODE_CHALLENGE_METHOD, METHOD_SHA_256); @@ -340,9 +346,9 @@ private void addClientParameters(Map parameters, String redirect parameters.put(KEY_REDIRECT_URI, redirectUri); } - private void createPKCE(String redirectUri) { + private void createPKCE(String redirectUri, Map headers) { if (pkce == null) { - pkce = new PKCE(apiClient, redirectUri); + pkce = new PKCE(apiClient, redirectUri, headers); } } diff --git a/auth0/src/main/java/com/auth0/android/provider/PKCE.java b/auth0/src/main/java/com/auth0/android/provider/PKCE.java index 1aa2f04b1..0425de96f 100644 --- a/auth0/src/main/java/com/auth0/android/provider/PKCE.java +++ b/auth0/src/main/java/com/auth0/android/provider/PKCE.java @@ -31,9 +31,13 @@ import com.auth0.android.authentication.AuthenticationAPIClient; import com.auth0.android.authentication.AuthenticationException; +import com.auth0.android.authentication.request.TokenRequest; import com.auth0.android.callback.BaseCallback; import com.auth0.android.result.Credentials; +import java.util.HashMap; +import java.util.Map; + /** * Performs code exchange according to Proof Key for Code Exchange (PKCE) spec. */ @@ -44,6 +48,7 @@ class PKCE { private final String codeVerifier; private final String redirectUri; private final String codeChallenge; + private final Map headers; /** * Creates a new instance of this class with the given AuthenticationAPIClient. @@ -51,19 +56,22 @@ class PKCE { * * @param apiClient to get the OAuth Token. * @param redirectUri going to be used in the OAuth code request. + * @param headers HTTP headers added to the OAuth token request. * @throws IllegalStateException when either 'US-ASCII` encoding or 'SHA-256' algorithm is not available. * @see #isAvailable() */ - public PKCE(@NonNull AuthenticationAPIClient apiClient, String redirectUri) { - this(apiClient, new AlgorithmHelper(), redirectUri); + public PKCE(@NonNull AuthenticationAPIClient apiClient, String redirectUri, @NonNull Map headers) { + this(apiClient, new AlgorithmHelper(), redirectUri, headers); } @VisibleForTesting - PKCE(@NonNull AuthenticationAPIClient apiClient, @NonNull AlgorithmHelper algorithmHelper, @NonNull String redirectUri) { + PKCE(@NonNull AuthenticationAPIClient apiClient, @NonNull AlgorithmHelper algorithmHelper, + @NonNull String redirectUri, @NonNull Map headers) { this.apiClient = apiClient; this.redirectUri = redirectUri; this.codeVerifier = algorithmHelper.generateCodeVerifier(); this.codeChallenge = algorithmHelper.generateCodeChallenge(codeVerifier); + this.headers = headers; } /** @@ -83,8 +91,13 @@ public String getCodeChallenge() { * @param callback to notify the result of this call to. */ public void getToken(String authorizationCode, @NonNull final AuthCallback callback) { - apiClient.token(authorizationCode, redirectUri) - .setCodeVerifier(codeVerifier) + TokenRequest tokenRequest = apiClient.token(authorizationCode, redirectUri); + + for (Map.Entry entry : headers.entrySet()) { + tokenRequest.addHeader(entry.getKey(), entry.getValue()); + } + + tokenRequest.setCodeVerifier(codeVerifier) .start(new BaseCallback() { @Override public void onSuccess(@Nullable Credentials payload) { diff --git a/auth0/src/main/java/com/auth0/android/provider/WebAuthProvider.java b/auth0/src/main/java/com/auth0/android/provider/WebAuthProvider.java index ad913073e..7403c2be3 100644 --- a/auth0/src/main/java/com/auth0/android/provider/WebAuthProvider.java +++ b/auth0/src/main/java/com/auth0/android/provider/WebAuthProvider.java @@ -156,6 +156,7 @@ public static class Builder { private final Auth0 account; private final Map values; + private final Map headers; private boolean useBrowser; private boolean useFullscreen; private PKCE pkce; @@ -174,6 +175,7 @@ public static class Builder { this.useBrowser = true; this.useFullscreen = false; this.ctOptions = CustomTabsOptions.newBuilder().build(); + this.headers = new HashMap<>(); withResponseType(ResponseType.CODE); withScope(SCOPE_TYPE_OPENID); } @@ -323,6 +325,17 @@ public Builder withScope(@NonNull String scope) { return this; } + /** + * Add custom headers for PKCE token request. + * + * @param headers for token request. + * @return the current builder instance + */ + public Builder withHeaders(@NonNull Map headers) { + this.headers.putAll(headers); + return this; + } + /** * Give a connection scope for this request. * @@ -448,6 +461,7 @@ public void start(@NonNull Activity activity, @NonNull AuthCallback callback, in OAuthManager manager = new OAuthManager(account, callback, values, ctOptions); manager.useFullScreen(useFullscreen); manager.useBrowser(useBrowser); + manager.setHeaders(headers); manager.setPKCE(pkce); manager.setIdTokenVerificationLeeway(leeway); manager.setIdTokenVerificationIssuer(issuer); diff --git a/auth0/src/test/java/com/auth0/android/provider/PKCETest.java b/auth0/src/test/java/com/auth0/android/provider/PKCETest.java index 870122e52..11ccb3d15 100644 --- a/auth0/src/test/java/com/auth0/android/provider/PKCETest.java +++ b/auth0/src/test/java/com/auth0/android/provider/PKCETest.java @@ -44,6 +44,8 @@ import java.io.UnsupportedEncodingException; import java.security.NoSuchAlgorithmException; +import java.util.HashMap; +import java.util.Map; import static org.hamcrest.CoreMatchers.containsString; import static org.hamcrest.CoreMatchers.equalTo; @@ -77,18 +79,18 @@ public class PKCETest { @Before public void setUp() { MockitoAnnotations.initMocks(this); - pkce = new PKCE(apiClient, new AlgorithmHelperMock(CODE_VERIFIER), REDIRECT_URI); + pkce = new PKCE(apiClient, new AlgorithmHelperMock(CODE_VERIFIER), REDIRECT_URI, new HashMap()); } @Test public void shouldGenerateChallengeFromRandomVerifier() { - PKCE pkce = new PKCE(apiClient, REDIRECT_URI); + PKCE pkce = new PKCE(apiClient, REDIRECT_URI, new HashMap()); assertThat(pkce.getCodeChallenge(), is(notNullValue())); } @Test public void shouldGenerateValidRandomCodeChallenge() { - PKCE randomPKCE = new PKCE(apiClient, REDIRECT_URI); + PKCE randomPKCE = new PKCE(apiClient, REDIRECT_URI, new HashMap()); String challenge = randomPKCE.getCodeChallenge(); assertThat(challenge, is(notNullValue())); assertThat(challenge, CoreMatchers.not(Matchers.isEmptyString())); @@ -118,6 +120,24 @@ public void shouldGetToken() { verify(callback).onSuccess(credentials); } + @Test + public void shouldAddHeaders() { + String header1Name = "header1"; + String header1Value = "val1"; + String header2Name = "header2"; + String header2Value = "val2"; + Map headers = new HashMap<>(); + headers.put(header1Name, header1Value); + headers.put(header2Name, header2Value); + PKCE pkce = new PKCE(apiClient, new AlgorithmHelperMock(CODE_VERIFIER), REDIRECT_URI, headers); + TokenRequest tokenRequest = mock(TokenRequest.class); + when(apiClient.token(AUTHORIZATION_CODE, REDIRECT_URI)).thenReturn(tokenRequest); + when(tokenRequest.setCodeVerifier(CODE_VERIFIER)).thenReturn(tokenRequest); + pkce.getToken(AUTHORIZATION_CODE, callback); + verify(tokenRequest).addHeader(header1Name, header1Value); + verify(tokenRequest).addHeader(header2Name, header2Value); + } + @Test public void shouldFailToGetToken() { TokenRequest tokenRequest = mock(TokenRequest.class);