Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@
import org.apache.hadoop.fs.azurebfs.oauth2.ClientCredsTokenProvider;
import org.apache.hadoop.fs.azurebfs.oauth2.CustomTokenProviderAdapter;
import org.apache.hadoop.fs.azurebfs.oauth2.MsiTokenProvider;
import org.apache.hadoop.fs.azurebfs.oauth2.ArcMsiTokenProvider;
import org.apache.hadoop.fs.azurebfs.oauth2.RefreshTokenBasedTokenProvider;
import org.apache.hadoop.fs.azurebfs.oauth2.UserPasswordTokenProvider;
import org.apache.hadoop.fs.azurebfs.oauth2.WorkloadIdentityTokenProvider;
Expand Down Expand Up @@ -961,6 +962,9 @@ public AccessTokenProvider getTokenProvider() throws TokenAccessProviderExceptio
String authEndpoint = getTrimmedPasswordString(
FS_AZURE_ACCOUNT_OAUTH_MSI_ENDPOINT,
AuthConfigurations.DEFAULT_FS_AZURE_ACCOUNT_OAUTH_MSI_ENDPOINT);
String apiVersion = getTrimmedPasswordString(
FS_AZURE_ACCOUNT_OAUTH_MSI_ENDPOINT_API_VERSION,
AuthConfigurations.DEFAULT_FS_AZURE_ACCOUNT_OAUTH_MSI_ENDPOINT_API_VERSION);
String tenantGuid =
getPasswordString(FS_AZURE_ACCOUNT_OAUTH_MSI_TENANT);
String clientId =
Expand All @@ -969,9 +973,27 @@ public AccessTokenProvider getTokenProvider() throws TokenAccessProviderExceptio
FS_AZURE_ACCOUNT_OAUTH_MSI_AUTHORITY,
AuthConfigurations.DEFAULT_FS_AZURE_ACCOUNT_OAUTH_MSI_AUTHORITY);
authority = appendSlashIfNeeded(authority);
tokenProvider = new MsiTokenProvider(authEndpoint, tenantGuid,
tokenProvider = new MsiTokenProvider(authEndpoint, apiVersion, tenantGuid,
clientId, authority);
LOG.trace("MsiTokenProvider initialized");
} else if (tokenProviderClass == ArcMsiTokenProvider.class) {
String authEndpoint = getTrimmedPasswordString(
FS_AZURE_ACCOUNT_OAUTH_MSI_ENDPOINT,
AuthConfigurations.DEFAULT_FS_AZURE_ACCOUNT_OAUTH_MSI_ENDPOINT);
String apiVersion = getTrimmedPasswordString(
FS_AZURE_ACCOUNT_OAUTH_ARC_MSI_ENDPOINT_API_VERSION,
AuthConfigurations.DEFAULT_FS_AZURE_ACCOUNT_OAUTH_ARC_MSI_ENDPOINT_API_VERSION);
String tenantGuid =
getPasswordString(FS_AZURE_ACCOUNT_OAUTH_MSI_TENANT);
String clientId =
getPasswordString(FS_AZURE_ACCOUNT_OAUTH_CLIENT_ID);
String authority = getTrimmedPasswordString(
FS_AZURE_ACCOUNT_OAUTH_MSI_AUTHORITY,
AuthConfigurations.DEFAULT_FS_AZURE_ACCOUNT_OAUTH_MSI_AUTHORITY);
authority = appendSlashIfNeeded(authority);
tokenProvider = new ArcMsiTokenProvider(authEndpoint, apiVersion, tenantGuid,
clientId, authority);
LOG.trace("ArcMsiTokenProvider initialized");
} else if (tokenProviderClass == RefreshTokenBasedTokenProvider.class) {
String authEndpoint = getTrimmedPasswordString(
FS_AZURE_ACCOUNT_OAUTH_REFRESH_TOKEN_ENDPOINT,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,14 @@ public final class AuthConfigurations {
public static final String
DEFAULT_FS_AZURE_ACCOUNT_OAUTH_TOKEN_FILE =
"/var/run/secrets/azure/tokens/azure-identity-token";
/** Default OAuth api-version end point for the MSI flow. */
public static final String
DEFAULT_FS_AZURE_ACCOUNT_OAUTH_MSI_ENDPOINT_API_VERSION =
"2018-02-01";
/** Default OAuth api-version end point for the ARC MSI flow. */
public static final String
DEFAULT_FS_AZURE_ACCOUNT_OAUTH_ARC_MSI_ENDPOINT_API_VERSION =
"2021-02-01";

private AuthConfigurations() {
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -258,6 +258,10 @@ public final class ConfigurationKeys {
public static final String FS_AZURE_ACCOUNT_OAUTH_MSI_ENDPOINT = "fs.azure.account.oauth2.msi.endpoint";
/** Key for oauth msi Authority: {@value}. */
public static final String FS_AZURE_ACCOUNT_OAUTH_MSI_AUTHORITY = "fs.azure.account.oauth2.msi.authority";
/** Key for oauth msi endpoint api version: {@value}. */
public static final String FS_AZURE_ACCOUNT_OAUTH_MSI_ENDPOINT_API_VERSION = "fs.azure.account.oauth2.msi.endpoint.api.version";
/** Key for oauth arc msi endpoint api version: {@value}. */
public static final String FS_AZURE_ACCOUNT_OAUTH_ARC_MSI_ENDPOINT_API_VERSION = "fs.azure.account.oauth2.arc.msi.endpoint.api.version";
/** Key for oauth user name: {@value}. */
public static final String FS_AZURE_ACCOUNT_OAUTH_USER_NAME = "fs.azure.account.oauth2.user.name";
/** Key for oauth user password: {@value}. */
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
/**
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.hadoop.fs.azurebfs.oauth2;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.IOException;

/**
* Provides tokens based on Azure VM's Managed Service Identity.
*/
public class ArcMsiTokenProvider extends AccessTokenProvider {

private final String authEndpoint;

private final String apiVersion;

private final String authority;

private final String tenantGuid;

private final String clientId;

private long tokenFetchTime = -1;

private static final long ONE_HOUR = 3600 * 1000;

private static final Logger LOG = LoggerFactory.getLogger(AccessTokenProvider.class);

public ArcMsiTokenProvider(final String authEndpoint, final String apiVersion, final String tenantGuid,
final String clientId, final String authority) {
this.authEndpoint = authEndpoint;
this.tenantGuid = tenantGuid;
this.clientId = clientId;
this.authority = authority;
this.apiVersion = apiVersion;
}

@Override
protected AzureADToken refreshToken() throws IOException {
LOG.debug("AADToken: refreshing token from ARC MSI");
AzureADToken token = AzureADAuthenticator
.getTokenFromArcMsi(authEndpoint, apiVersion, tenantGuid, clientId, authority, false);
tokenFetchTime = System.currentTimeMillis();
return token;
}

/**
* Checks if the token is about to expire as per base expiry logic.
* Otherwise try to expire every 1 hour
*
* @return true if the token is expiring in next 1 hour or if a token has
* never been fetched
*/
@Override
protected boolean isTokenAboutToExpire() {
if (tokenFetchTime == -1 || super.isTokenAboutToExpire()) {
return true;
}

boolean expiring = false;
long elapsedTimeSinceLastTokenRefreshInMillis =
System.currentTimeMillis() - tokenFetchTime;
expiring = elapsedTimeSinceLastTokenRefreshInMillis >= ONE_HOUR
|| elapsedTimeSinceLastTokenRefreshInMillis < 0;
// In case of, Token is not refreshed for 1 hr or any clock skew issues,
// refresh token.
if (expiring) {
LOG.debug("ARCMSIToken: token renewing. Time elapsed since last token fetch:"
+ " {} milli seconds", elapsedTimeSinceLastTokenRefreshInMillis);
}

return expiring;
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@
import java.io.FileNotFoundException;
import java.io.IOException;
import java.io.InputStream;
import java.io.BufferedReader;
import java.io.FileReader;
import java.net.HttpURLConnection;
import java.net.MalformedURLException;
import java.net.URL;
Expand Down Expand Up @@ -169,11 +171,11 @@
* @return {@link AzureADToken} obtained using the creds
* @throws IOException throws IOException if there is a failure in obtaining the token
*/
public static AzureADToken getTokenFromMsi(final String authEndpoint,
public static AzureADToken getTokenFromMsi(final String authEndpoint, final String apiVersion,

Check failure on line 174 in hadoop-tools/hadoop-azure/src/main/java/org/apache/hadoop/fs/azurebfs/oauth2/AzureADAuthenticator.java

View check run for this annotation

ASF Cloudbees Jenkins ci-hadoop / Apache Yetus

hadoop-tools/hadoop-azure/src/main/java/org/apache/hadoop/fs/azurebfs/oauth2/AzureADAuthenticator.java#L174

javadoc: warning: no @param for apiVersion

Check failure on line 174 in hadoop-tools/hadoop-azure/src/main/java/org/apache/hadoop/fs/azurebfs/oauth2/AzureADAuthenticator.java

View check run for this annotation

ASF Cloudbees Jenkins ci-hadoop / Apache Yetus

hadoop-tools/hadoop-azure/src/main/java/org/apache/hadoop/fs/azurebfs/oauth2/AzureADAuthenticator.java#L174

javadoc: warning: no @param for apiVersion

Check failure on line 174 in hadoop-tools/hadoop-azure/src/main/java/org/apache/hadoop/fs/azurebfs/oauth2/AzureADAuthenticator.java

View check run for this annotation

ASF Cloudbees Jenkins ci-hadoop / Apache Yetus

hadoop-tools/hadoop-azure/src/main/java/org/apache/hadoop/fs/azurebfs/oauth2/AzureADAuthenticator.java#L174

javadoc: warning: no @param for authority

Check failure on line 174 in hadoop-tools/hadoop-azure/src/main/java/org/apache/hadoop/fs/azurebfs/oauth2/AzureADAuthenticator.java

View check run for this annotation

ASF Cloudbees Jenkins ci-hadoop / Apache Yetus

hadoop-tools/hadoop-azure/src/main/java/org/apache/hadoop/fs/azurebfs/oauth2/AzureADAuthenticator.java#L174

javadoc: warning: no @param for authority
final String tenantGuid, final String clientId, String authority,
boolean bypassCache) throws IOException {
QueryParams qp = new QueryParams();
qp.add("api-version", "2018-02-01");
qp.add("api-version", apiVersion);
qp.add("resource", RESOURCE_NAME);

if (tenantGuid != null && tenantGuid.length() > 0) {
Expand All @@ -194,7 +196,51 @@
headers.put("Metadata", "true");

LOG.debug("AADToken: starting to fetch token using MSI");
return getTokenCall(authEndpoint, qp.serialize(), headers, "GET", true);
return getTokenCall(authEndpoint, qp.serialize(), headers, "GET", true, false);
}

/**
* Gets AAD token from the local virtual machine's ARC extension. This only works on
* an Azure VM with MSI extension
* enabled.
*
* @param authEndpoint the OAuth 2.0 token endpoint associated
* with the user's directory (obtain from
* Active Directory configuration)
* @param tenantGuid (optional) The guid of the AAD tenant. Can be {@code null}.
* @param clientId (optional) The clientId guid of the MSI service
* principal to use. Can be {@code null}.
* @param bypassCache {@code boolean} specifying whether a cached token is acceptable or a fresh token
* request should me made to AAD
* @return {@link AzureADToken} obtained using the creds
* @throws IOException throws IOException if there is a failure in obtaining the token
*/
public static AzureADToken getTokenFromArcMsi(final String authEndpoint, final String apiVersion,

Check failure on line 218 in hadoop-tools/hadoop-azure/src/main/java/org/apache/hadoop/fs/azurebfs/oauth2/AzureADAuthenticator.java

View check run for this annotation

ASF Cloudbees Jenkins ci-hadoop / Apache Yetus

hadoop-tools/hadoop-azure/src/main/java/org/apache/hadoop/fs/azurebfs/oauth2/AzureADAuthenticator.java#L218

javadoc: warning: no @param for apiVersion

Check failure on line 218 in hadoop-tools/hadoop-azure/src/main/java/org/apache/hadoop/fs/azurebfs/oauth2/AzureADAuthenticator.java

View check run for this annotation

ASF Cloudbees Jenkins ci-hadoop / Apache Yetus

hadoop-tools/hadoop-azure/src/main/java/org/apache/hadoop/fs/azurebfs/oauth2/AzureADAuthenticator.java#L218

javadoc: warning: no @param for apiVersion
final String tenantGuid, final String clientId, String authority,
boolean bypassCache) throws IOException {
QueryParams qp = new QueryParams();
qp.add("api-version", apiVersion);
qp.add("resource", RESOURCE_NAME);

if (tenantGuid != null && tenantGuid.length() > 0) {
authority = authority + tenantGuid;
LOG.debug("MSI authority : {}", authority);
qp.add("authority", authority);
}

if (clientId != null && clientId.length() > 0) {
qp.add("client_id", clientId);
}

if (bypassCache) {
qp.add("bypass_cache", "true");
}

Hashtable<String, String> headers = new Hashtable<>();
headers.put("Metadata", "true");

LOG.debug("AADToken: starting to fetch token using MSI from ARC");
return getTokenCall(authEndpoint, qp.serialize(), headers, "GET", true, true);
}

/**
Expand Down Expand Up @@ -327,11 +373,11 @@

private static AzureADToken getTokenCall(String authEndpoint, String body,
Hashtable<String, String> headers, String httpMethod) throws IOException {
return getTokenCall(authEndpoint, body, headers, httpMethod, false);
return getTokenCall(authEndpoint, body, headers, httpMethod, false, false);
}

private static AzureADToken getTokenCall(String authEndpoint, String body,
Hashtable<String, String> headers, String httpMethod, boolean isMsi)
Hashtable<String, String> headers, String httpMethod, boolean isMsi, boolean isArc)
throws IOException {
AzureADToken token = null;

Expand All @@ -346,7 +392,7 @@
httperror = 0;
ex = null;
try {
token = getTokenSingleCall(authEndpoint, body, headers, httpMethod, isMsi);
token = getTokenSingleCall(authEndpoint, body, headers, httpMethod, isMsi, isArc);
} catch (HttpException e) {
httperror = e.httpErrorCode;
ex = e;
Expand Down Expand Up @@ -385,18 +431,83 @@

private static AzureADToken getTokenSingleCall(String authEndpoint,
String payload, Hashtable<String, String> headers, String httpMethod,
boolean isMsi)
boolean isMsi, boolean isArc)
throws IOException {

AzureADToken token = null;
HttpURLConnection conn = null;
String urlString = authEndpoint;
String challengerToken = null;

httpMethod = (httpMethod == null) ? "POST" : httpMethod;
if (httpMethod.equals("GET")) {
urlString = urlString + "?" + payload;
}

if (isArc) {
// Currently there is a known flow that ARC needs obtain a challenge token first
// before and in order to get access_token from the same MSI endpoint
try{
LOG.debug("Requesting a challenge token by {} to {}",
httpMethod, authEndpoint);
URL url = new URL(urlString);
conn = (HttpURLConnection) url.openConnection();
conn.setRequestMethod(httpMethod);
conn.setReadTimeout(READ_TIMEOUT);
conn.setConnectTimeout(CONNECT_TIMEOUT);

if (headers != null && headers.size() > 0) {
for (Map.Entry<String, String> entry : headers.entrySet()) {
conn.setRequestProperty(entry.getKey(), entry.getValue());
}
}
conn.setRequestProperty("Connection", "close");
AbfsIoUtils.dumpHeadersToDebugLog("Request Headers",
conn.getRequestProperties());
if (httpMethod.equals("POST")) {
conn.setDoOutput(true);
conn.getOutputStream().write(payload.getBytes(StandardCharsets.UTF_8));
}
AbfsIoUtils.dumpHeadersToDebugLog("Response Headers",
conn.getHeaderFields());

int httpResponseCode = conn.getResponseCode();
String requestId = conn.getHeaderField("x-ms-request-id");
String responseContentType = conn.getHeaderField("Content-Type");
String operation = "Challenge Token: HTTP connection to " + authEndpoint
+ " failed for getting challenge token from ARC MSI endpoint.";
InputStream stream = conn.getErrorStream();
if (stream == null) {
// no error stream, try the original input stream
stream = conn.getInputStream();
}
String responseBody = consumeInputStream(stream, 1024);

String authHeader = conn.getHeaderField("Www-Authenticate");
if (authHeader != null) {
// Extract the challenge token path
int index = authHeader.indexOf('=');
if (index != -1) {
String authHeaderPath = authHeader.substring(index + 1).trim();
try (BufferedReader reader = new BufferedReader(new FileReader(authHeaderPath))) {
challengerToken = reader.readLine().trim();
}
}
} else {
throw new HttpException(httpResponseCode,
requestId,
operation,
authEndpoint,
responseContentType,
responseBody);
}
} finally {
if (conn != null) {
conn.disconnect();
}
}
}

try {
LOG.debug("Requesting an OAuth token by {} to {}",
httpMethod, authEndpoint);
Expand All @@ -406,6 +517,10 @@
conn.setReadTimeout(READ_TIMEOUT);
conn.setConnectTimeout(CONNECT_TIMEOUT);

if (isArc) {
conn.setRequestProperty("Authorization", "Basic " + challengerToken);
}

if (headers != null && headers.size() > 0) {
for (Map.Entry<String, String> entry : headers.entrySet()) {
conn.setRequestProperty(entry.getKey(), entry.getValue());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@ public class MsiTokenProvider extends AccessTokenProvider {

private final String authEndpoint;

private final String apiVersion;

private final String authority;

private final String tenantGuid;
Expand All @@ -42,19 +44,20 @@ public class MsiTokenProvider extends AccessTokenProvider {

private static final Logger LOG = LoggerFactory.getLogger(AccessTokenProvider.class);

public MsiTokenProvider(final String authEndpoint, final String tenantGuid,
public MsiTokenProvider(final String authEndpoint, final String apiVersion, final String tenantGuid,
final String clientId, final String authority) {
this.authEndpoint = authEndpoint;
this.tenantGuid = tenantGuid;
this.clientId = clientId;
this.authority = authority;
this.apiVersion = apiVersion;
}

@Override
protected AzureADToken refreshToken() throws IOException {
LOG.debug("AADToken: refreshing token from MSI");
AzureADToken token = AzureADAuthenticator
.getTokenFromMsi(authEndpoint, tenantGuid, clientId, authority, false);
.getTokenFromMsi(authEndpoint, apiVersion, tenantGuid, clientId, authority, false);
tokenFetchTime = System.currentTimeMillis();
return token;
}
Expand Down
Loading