From 94f17e2b7d57e29917182f8052c1a33717408df2 Mon Sep 17 00:00:00 2001 From: Phil Zampino Date: Fri, 6 Mar 2020 19:55:39 -0500 Subject: [PATCH] KNOX-2266 - Tokens Should Include a Unique Identifier --- .../jwt/filter/AbstractJWTFilter.java | 28 +-- .../filter/AccessTokenFederationFilter.java | 11 +- .../federation/CommonJWTFilterTest.java | 9 +- .../impl/AliasBasedTokenStateService.java | 91 ++++----- .../token/impl/DefaultTokenStateService.java | 184 ++++++++---------- .../token/impl/TokenStateServiceMessages.java | 36 ++-- .../impl/DefaultTokenStateServiceTest.java | 32 +-- .../service/knoxtoken/TokenResource.java | 44 ++++- .../knoxtoken/TokenServiceMessages.java | 26 ++- .../knoxtoken/TokenServiceResourceTest.java | 101 ++++++++-- .../security/token/TokenStateService.java | 40 ++-- .../services/security/token/TokenUtils.java | 58 +++++- .../security/token/UnknownTokenException.java | 16 +- .../security/token/impl/JWTToken.java | 6 + .../security/token/impl/JWTTokenTest.java | 20 ++ 15 files changed, 420 insertions(+), 282 deletions(-) diff --git a/gateway-provider-security-jwt/src/main/java/org/apache/knox/gateway/provider/federation/jwt/filter/AbstractJWTFilter.java b/gateway-provider-security-jwt/src/main/java/org/apache/knox/gateway/provider/federation/jwt/filter/AbstractJWTFilter.java index af33275b26..33af86ff89 100644 --- a/gateway-provider-security-jwt/src/main/java/org/apache/knox/gateway/provider/federation/jwt/filter/AbstractJWTFilter.java +++ b/gateway-provider-security-jwt/src/main/java/org/apache/knox/gateway/provider/federation/jwt/filter/AbstractJWTFilter.java @@ -49,7 +49,6 @@ import org.apache.knox.gateway.audit.api.Auditor; import org.apache.knox.gateway.audit.api.ResourceType; import org.apache.knox.gateway.audit.log4j.audit.AuditConstants; -import org.apache.knox.gateway.config.GatewayConfig; import org.apache.knox.gateway.filter.AbstractGatewayFilter; import org.apache.knox.gateway.i18n.messages.MessagesFactory; import org.apache.knox.gateway.provider.federation.jwt.JWTMessages; @@ -59,6 +58,7 @@ import org.apache.knox.gateway.services.security.token.JWTokenAuthority; import org.apache.knox.gateway.services.security.token.TokenServiceException; import org.apache.knox.gateway.services.security.token.TokenStateService; +import org.apache.knox.gateway.services.security.token.TokenUtils; import org.apache.knox.gateway.services.security.token.UnknownTokenException; import org.apache.knox.gateway.services.security.token.impl.JWT; @@ -113,35 +113,13 @@ public void init( FilterConfig filterConfig ) throws ServletException { GatewayServices services = (GatewayServices) context.getAttribute(GatewayServices.GATEWAY_SERVICES_ATTRIBUTE); if (services != null) { authority = services.getService(ServiceType.TOKEN_SERVICE); - if (isServerManagedTokenStateEnabled(filterConfig)) { + if (TokenUtils.isServerManagedTokenStateEnabled(filterConfig)) { tokenStateService = services.getService(ServiceType.TOKEN_STATE_SERVICE); } } } } - protected boolean isServerManagedTokenStateEnabled(FilterConfig filterConfig) { - boolean isServerManaged = false; - - // First, check for explicit provider-level configuration - String providerParamValue = filterConfig.getInitParameter(TokenStateService.CONFIG_SERVER_MANAGED); - - // If there is no provider-level configuration - if (providerParamValue == null || providerParamValue.isEmpty()) { - // Fall back to the gateway-level default - ServletContext context = filterConfig.getServletContext(); - if (context != null) { - GatewayConfig config = (GatewayConfig) context.getAttribute(GatewayConfig.GATEWAY_CONFIG_ATTRIBUTE); - isServerManaged = (config != null) && config.isServerManagedTokenStateEnabled(); - } - } else { - // Otherwise, apply the provider-level configuration - isServerManaged = Boolean.valueOf(providerParamValue); - } - - return isServerManaged; - } - protected void configureExpectedParameters(FilterConfig filterConfig) { expectedIssuer = filterConfig.getInitParameter(JWT_EXPECTED_ISSUER); if (expectedIssuer == null) { @@ -171,7 +149,7 @@ protected List parseExpectedAudiences(String expectedAudiences) { protected boolean tokenIsStillValid(JWT jwtToken) throws UnknownTokenException { Date expires; if (tokenStateService != null) { - expires = new Date(tokenStateService.getTokenExpiration(jwtToken.toString())); + expires = new Date(tokenStateService.getTokenExpiration(jwtToken)); } else { // if there is no expiration date then the lifecycle is tied entirely to // the cookie validity - otherwise ensure that the current time is before diff --git a/gateway-provider-security-jwt/src/main/java/org/apache/knox/gateway/provider/federation/jwt/filter/AccessTokenFederationFilter.java b/gateway-provider-security-jwt/src/main/java/org/apache/knox/gateway/provider/federation/jwt/filter/AccessTokenFederationFilter.java index 57696d642f..1bc8b4dddc 100644 --- a/gateway-provider-security-jwt/src/main/java/org/apache/knox/gateway/provider/federation/jwt/filter/AccessTokenFederationFilter.java +++ b/gateway-provider-security-jwt/src/main/java/org/apache/knox/gateway/provider/federation/jwt/filter/AccessTokenFederationFilter.java @@ -25,6 +25,7 @@ import org.apache.knox.gateway.services.security.token.JWTokenAuthority; import org.apache.knox.gateway.services.security.token.TokenServiceException; import org.apache.knox.gateway.services.security.token.TokenStateService; +import org.apache.knox.gateway.services.security.token.TokenUtils; import org.apache.knox.gateway.services.security.token.UnknownTokenException; import org.apache.knox.gateway.services.security.token.impl.JWTToken; @@ -59,7 +60,7 @@ public void init( FilterConfig filterConfig ) throws ServletException { GatewayServices services = (GatewayServices) filterConfig.getServletContext().getAttribute(GatewayServices.GATEWAY_SERVICES_ATTRIBUTE); authority = services.getService(ServiceType.TOKEN_SERVICE); - if (Boolean.valueOf(filterConfig.getInitParameter(TokenStateService.CONFIG_SERVER_MANAGED))) { + if (TokenUtils.isServerManagedTokenStateEnabled(filterConfig)) { tokenStateService = services.getService(ServiceType.TOKEN_STATE_SERVICE); } } @@ -117,14 +118,18 @@ public void doFilter(ServletRequest request, ServletResponse response, FilterCha } private boolean isExpired(JWTToken token) throws UnknownTokenException { - return (tokenStateService != null) ? tokenStateService.isExpired(token.toString()) : (Long.parseLong(token.getExpires()) <= System.currentTimeMillis()); + return (tokenStateService != null) ? tokenStateService.isExpired(token) + : (Long.parseLong(token.getExpires()) <= System.currentTimeMillis()); } private void sendUnauthorized(ServletResponse response) throws IOException { ((HttpServletResponse) response).sendError(HttpServletResponse.SC_UNAUTHORIZED); } - private void continueWithEstablishedSecurityContext(Subject subject, final HttpServletRequest request, final HttpServletResponse response, final FilterChain chain) throws IOException, ServletException { + private void continueWithEstablishedSecurityContext(Subject subject, + final HttpServletRequest request, + final HttpServletResponse response, + final FilterChain chain) throws IOException, ServletException { try { Subject.doAs( subject, diff --git a/gateway-provider-security-jwt/src/test/java/org/apache/knox/gateway/provider/federation/CommonJWTFilterTest.java b/gateway-provider-security-jwt/src/test/java/org/apache/knox/gateway/provider/federation/CommonJWTFilterTest.java index 8c46900cf8..35bb310487 100644 --- a/gateway-provider-security-jwt/src/test/java/org/apache/knox/gateway/provider/federation/CommonJWTFilterTest.java +++ b/gateway-provider-security-jwt/src/test/java/org/apache/knox/gateway/provider/federation/CommonJWTFilterTest.java @@ -19,6 +19,7 @@ import org.apache.knox.gateway.config.GatewayConfig; import org.apache.knox.gateway.provider.federation.jwt.filter.AbstractJWTFilter; import org.apache.knox.gateway.services.security.token.TokenStateService; +import org.apache.knox.gateway.services.security.token.TokenUtils; import org.apache.knox.gateway.services.security.token.UnknownTokenException; import org.apache.knox.gateway.services.security.token.impl.JWT; import org.easymock.EasyMock; @@ -109,9 +110,7 @@ private boolean doTestServerManagedTokenState(final Boolean isEnabledAtGateway, EasyMock.expect(fc.getServletContext()).andReturn(sc).anyTimes(); EasyMock.replay(fc); - Method m = AbstractJWTFilter.class.getDeclaredMethod("isServerManagedTokenStateEnabled", FilterConfig.class); - m.setAccessible(true); - return (Boolean) m.invoke(handler, fc); + return TokenUtils.isServerManagedTokenStateEnabled(fc); } @Test @@ -129,7 +128,7 @@ public void testIsStillValidExpired() throws Exception { @Test(expected = UnknownTokenException.class) public void testIsStillValidUnknownToken() throws Exception { TokenStateService tss = EasyMock.createNiceMock(TokenStateService.class); - EasyMock.expect(tss.getTokenExpiration(anyObject())) + EasyMock.expect(tss.getTokenExpiration(anyObject(JWT.class))) .andThrow(new UnknownTokenException("eyjhbgcioi1234567890neg")) .anyTimes(); EasyMock.replay(tss); @@ -139,7 +138,7 @@ public void testIsStillValidUnknownToken() throws Exception { private boolean doTestIsStillValid(final Long expiration) throws Exception { TokenStateService tss = EasyMock.createNiceMock(TokenStateService.class); - EasyMock.expect(tss.getTokenExpiration(anyObject())) + EasyMock.expect(tss.getTokenExpiration(anyObject(JWT.class))) .andReturn(expiration) .anyTimes(); EasyMock.replay(tss); diff --git a/gateway-server/src/main/java/org/apache/knox/gateway/services/token/impl/AliasBasedTokenStateService.java b/gateway-server/src/main/java/org/apache/knox/gateway/services/token/impl/AliasBasedTokenStateService.java index ab900e5f9f..1178e87fb4 100644 --- a/gateway-server/src/main/java/org/apache/knox/gateway/services/token/impl/AliasBasedTokenStateService.java +++ b/gateway-server/src/main/java/org/apache/knox/gateway/services/token/impl/AliasBasedTokenStateService.java @@ -20,7 +20,6 @@ import org.apache.knox.gateway.services.ServiceLifecycleException; import org.apache.knox.gateway.services.security.AliasService; import org.apache.knox.gateway.services.security.AliasServiceException; -import org.apache.knox.gateway.services.security.token.TokenUtils; import org.apache.knox.gateway.services.security.token.UnknownTokenException; import java.util.ArrayList; @@ -49,120 +48,102 @@ public void init(final GatewayConfig config, final Map options) } @Override - public void addToken(final String token, - long issueTime, - long expiration, - long maxLifetimeDuration) { - isValidIdentifier(token); + public void addToken(final String tokenId, + long issueTime, + long expiration, + long maxLifetimeDuration) { + isValidIdentifier(tokenId); try { - aliasService.addAliasForCluster(AliasService.NO_CLUSTER_NAME, token, String.valueOf(expiration)); - setMaxLifetime(token, issueTime, maxLifetimeDuration); - log.addedToken(TokenUtils.getTokenDisplayText(token), getTimestampDisplay(expiration)); + aliasService.addAliasForCluster(AliasService.NO_CLUSTER_NAME, tokenId, String.valueOf(expiration)); + setMaxLifetime(tokenId, issueTime, maxLifetimeDuration); + log.addedToken(tokenId, getTimestampDisplay(expiration)); } catch (AliasServiceException e) { - log.failedToSaveTokenState(TokenUtils.getTokenDisplayText(token), e); + log.failedToSaveTokenState(tokenId, e); } } @Override - protected void setMaxLifetime(final String token, long issueTime, long maxLifetimeDuration) { + protected void setMaxLifetime(final String tokenId, long issueTime, long maxLifetimeDuration) { try { aliasService.addAliasForCluster(AliasService.NO_CLUSTER_NAME, - token + TOKEN_MAX_LIFETIME_POSTFIX, + tokenId + TOKEN_MAX_LIFETIME_POSTFIX, String.valueOf(issueTime + maxLifetimeDuration)); } catch (AliasServiceException e) { - log.failedToSaveTokenState(TokenUtils.getTokenDisplayText(token), e); + log.failedToSaveTokenState(tokenId, e); } } @Override - protected long getMaxLifetime(final String token) { + protected long getMaxLifetime(final String tokenId) { long result = 0; try { char[] maxLifetimeStr = - aliasService.getPasswordFromAliasForCluster(AliasService.NO_CLUSTER_NAME, token + TOKEN_MAX_LIFETIME_POSTFIX); + aliasService.getPasswordFromAliasForCluster(AliasService.NO_CLUSTER_NAME, + tokenId + TOKEN_MAX_LIFETIME_POSTFIX); if (maxLifetimeStr != null) { result = Long.parseLong(new String(maxLifetimeStr)); } } catch (AliasServiceException e) { - log.errorAccessingTokenState(TokenUtils.getTokenDisplayText(token), e); + log.errorAccessingTokenState(tokenId, e); } return result; } @Override - public long getTokenExpiration(final String token) throws UnknownTokenException { + public long getTokenExpiration(final String tokenId) throws UnknownTokenException { long expiration = 0; + + validateToken(tokenId); + try { - validateToken(token); - } catch (final UnknownTokenException e) { - /* if token permissiveness is enabled we check JWT token expiration when the token state is unknown */ - if (permissiveValidationEnabled && getJWTTokenExpiration(token).isPresent()) { - return getJWTTokenExpiration(token).getAsLong(); - } else { - throw e; - } - } - try { - char[] expStr = aliasService.getPasswordFromAliasForCluster(AliasService.NO_CLUSTER_NAME, token); + char[] expStr = aliasService.getPasswordFromAliasForCluster(AliasService.NO_CLUSTER_NAME, tokenId); if (expStr != null) { expiration = Long.parseLong(new String(expStr)); } } catch (Exception e) { - log.errorAccessingTokenState(TokenUtils.getTokenDisplayText(token), e); + log.errorAccessingTokenState(tokenId, e); } return expiration; } @Override - public void revokeToken(final String token) throws UnknownTokenException { - /* no reason to keep revoked tokens around */ - removeToken(token); - log.revokedToken(TokenUtils.getTokenDisplayText(token)); - } - - @Override - protected boolean isUnknown(final String token) { + protected boolean isUnknown(final String tokenId) { boolean isUnknown = false; try { - isUnknown = (aliasService.getPasswordFromAliasForCluster(AliasService.NO_CLUSTER_NAME, token) == null); + isUnknown = (aliasService.getPasswordFromAliasForCluster(AliasService.NO_CLUSTER_NAME, tokenId) == null); } catch (AliasServiceException e) { - log.errorAccessingTokenState(TokenUtils.getTokenDisplayText(token), e); + log.errorAccessingTokenState(tokenId, e); } return isUnknown; } @Override - protected void removeToken(final String token) throws UnknownTokenException { - validateToken(token); + protected void removeToken(final String tokenId) throws UnknownTokenException { + validateToken(tokenId); try { - aliasService.removeAliasForCluster(AliasService.NO_CLUSTER_NAME, token); - aliasService.removeAliasForCluster(AliasService.NO_CLUSTER_NAME,token + TOKEN_MAX_LIFETIME_POSTFIX); - log.removedTokenState(TokenUtils.getTokenDisplayText(token)); + aliasService.removeAliasForCluster(AliasService.NO_CLUSTER_NAME, tokenId); + aliasService.removeAliasForCluster(AliasService.NO_CLUSTER_NAME, tokenId + TOKEN_MAX_LIFETIME_POSTFIX); + log.removedTokenState(tokenId); } catch (AliasServiceException e) { - log.failedToRemoveTokenState(TokenUtils.getTokenDisplayText(token), e); + log.failedToRemoveTokenState(tokenId, e); } } @Override - protected void updateExpiration(final String token, long expiration) { - if (isUnknown(token)) { - log.unknownToken(TokenUtils.getTokenDisplayText(token)); - throw new IllegalArgumentException("Unknown token."); - } - + protected void updateExpiration(final String tokenId, long expiration) { try { - aliasService.removeAliasForCluster(AliasService.NO_CLUSTER_NAME, token); - aliasService.addAliasForCluster(AliasService.NO_CLUSTER_NAME, token, String.valueOf(expiration)); + aliasService.removeAliasForCluster(AliasService.NO_CLUSTER_NAME, tokenId); + aliasService.addAliasForCluster(AliasService.NO_CLUSTER_NAME, tokenId, String.valueOf(expiration)); } catch (AliasServiceException e) { - log.failedToUpdateTokenExpiration(TokenUtils.getTokenDisplayText(token), e); + log.failedToUpdateTokenExpiration(tokenId, e); } } @Override protected List getTokens() { - List allAliases = new ArrayList(); + List allAliases = new ArrayList<>(); try { allAliases = aliasService.getAliasesForCluster(AliasService.NO_CLUSTER_NAME); /* only get the aliases that represent tokens and extract the current list of tokens */ diff --git a/gateway-server/src/main/java/org/apache/knox/gateway/services/token/impl/DefaultTokenStateService.java b/gateway-server/src/main/java/org/apache/knox/gateway/services/token/impl/DefaultTokenStateService.java index f0f5ca7782..cf5d1907e0 100644 --- a/gateway-server/src/main/java/org/apache/knox/gateway/services/token/impl/DefaultTokenStateService.java +++ b/gateway-server/src/main/java/org/apache/knox/gateway/services/token/impl/DefaultTokenStateService.java @@ -25,13 +25,10 @@ import org.apache.knox.gateway.services.security.token.impl.JWT; import org.apache.knox.gateway.services.security.token.impl.JWTToken; -import java.text.ParseException; import java.time.Instant; -import java.util.Date; import java.util.HashMap; import java.util.List; import java.util.Map; -import java.util.OptionalLong; import java.util.concurrent.Executors; import java.util.concurrent.ScheduledExecutorService; import java.util.concurrent.TimeUnit; @@ -97,48 +94,62 @@ public long getDefaultMaxLifetimeDuration() { @Override public void addToken(final JWTToken token, long issueTime) { if (token == null) { - throw new IllegalArgumentException("Token data cannot be null."); + throw new IllegalArgumentException("Token cannot be null."); } - addToken(token.getPayload(), issueTime, token.getExpiresDate().getTime()); + addToken(TokenUtils.getTokenId(token), issueTime, token.getExpiresDate().getTime()); } @Override - public void addToken(final String token, long issueTime, long expiration) { - addToken(token, issueTime, expiration, getDefaultMaxLifetimeDuration()); + public void addToken(final String tokenId, long issueTime, long expiration) { + addToken(tokenId, issueTime, expiration, getDefaultMaxLifetimeDuration()); } @Override - public void addToken(final String token, - long issueTime, - long expiration, - long maxLifetimeDuration) { - if (!isValidIdentifier(token)) { - throw new IllegalArgumentException("Token data cannot be null."); + public void addToken(final String tokenId, + long issueTime, + long expiration, + long maxLifetimeDuration) { + if (!isValidIdentifier(tokenId)) { + throw new IllegalArgumentException("Token identifier cannot be null."); } synchronized (tokenExpirations) { - tokenExpirations.put(token, expiration); + tokenExpirations.put(tokenId, expiration); } - setMaxLifetime(token, issueTime, maxLifetimeDuration); - log.addedToken(TokenUtils.getTokenDisplayText(token), getTimestampDisplay(expiration)); + setMaxLifetime(tokenId, issueTime, maxLifetimeDuration); + log.addedToken(tokenId, getTimestampDisplay(expiration)); } @Override - public long getTokenExpiration(final String token) throws UnknownTokenException { - long expiration; - + public long getTokenExpiration(final JWT token) throws UnknownTokenException { + long expiration = -1; try { - validateToken(token); - } catch (final UnknownTokenException e) { - /* if token permissiveness is enabled we check JWT token expiration when the token state is unknown */ - if (permissiveValidationEnabled && getJWTTokenExpiration(token).isPresent()) { - return getJWTTokenExpiration(token).getAsLong(); - } else { + expiration = getTokenExpiration(TokenUtils.getTokenId(token)); + } catch (UnknownTokenException e) { + if (permissiveValidationEnabled) { + String exp = token.getExpires(); + if (exp != null) { + log.permissiveTokenHandling(TokenUtils.getTokenId(token), e.getMessage()); + expiration = Long.parseLong(exp); + } + } + + if (expiration == -1) { throw e; } } + + return expiration; + } + + @Override + public long getTokenExpiration(final String tokenId) throws UnknownTokenException { + long expiration; + + validateToken(tokenId); + synchronized (tokenExpirations) { - expiration = tokenExpirations.get(token); + expiration = tokenExpirations.get(tokenId); } return expiration; @@ -152,29 +163,29 @@ public long renewToken(final JWTToken token) throws UnknownTokenException { @Override public long renewToken(final JWTToken token, long renewInterval) throws UnknownTokenException { if (token == null) { - throw new IllegalArgumentException("Token data cannot be null."); + throw new IllegalArgumentException("Token cannot be null."); } - return renewToken(token.getPayload(), renewInterval); + return renewToken(TokenUtils.getTokenId(token), renewInterval); } @Override - public long renewToken(final String token) throws UnknownTokenException { - return renewToken(token, DEFAULT_RENEWAL_INTERVAL); + public long renewToken(final String tokenId) throws UnknownTokenException { + return renewToken(tokenId, DEFAULT_RENEWAL_INTERVAL); } @Override - public long renewToken(final String token, long renewInterval) throws UnknownTokenException { + public long renewToken(final String tokenId, long renewInterval) throws UnknownTokenException { long expiration; - validateToken(token); + validateToken(tokenId); // Make sure the maximum lifetime has not been (and will not be) exceeded - if (hasRemainingRenewals(token, renewInterval)) { + if (hasRemainingRenewals(tokenId, renewInterval)) { expiration = System.currentTimeMillis() + renewInterval; - updateExpiration(token, expiration); - log.renewedToken(TokenUtils.getTokenDisplayText(token), getTimestampDisplay(expiration)); + updateExpiration(tokenId, expiration); + log.renewedToken(tokenId, getTimestampDisplay(expiration)); } else { - log.renewalLimitExceeded(token); + log.renewalLimitExceeded(tokenId); throw new IllegalArgumentException("The renewal limit for the token has been exceeded"); } @@ -184,33 +195,22 @@ public long renewToken(final String token, long renewInterval) throws UnknownTok @Override public void revokeToken(final JWTToken token) throws UnknownTokenException { if (token == null) { - throw new IllegalArgumentException("Token data cannot be null."); + throw new IllegalArgumentException("Token cannot be null."); } - revokeToken(token.getPayload()); + revokeToken(TokenUtils.getTokenId(token)); } @Override - public void revokeToken(final String token) throws UnknownTokenException { + public void revokeToken(final String tokenId) throws UnknownTokenException { /* no reason to keep revoked tokens around */ - removeToken(token); - log.revokedToken(TokenUtils.getTokenDisplayText(token)); + removeToken(tokenId); + log.revokedToken(tokenId); } @Override public boolean isExpired(final JWTToken token) throws UnknownTokenException { - return isExpired(token.getPayload()); - } - - @Override - public boolean isExpired(final String token) throws UnknownTokenException { - boolean isExpired; - isExpired = isUnknown(token); // Check if the token exist - if (!isExpired) { - // If it not unknown, check its expiration - isExpired = (getTokenExpiration(token) <= System.currentTimeMillis()); - } - return isExpired; + return getTokenExpiration(token) <= System.currentTimeMillis(); } protected void setMaxLifetime(final String token, long issueTime, long maxLifetimeDuration) { @@ -233,57 +233,57 @@ protected boolean isUnknown(final String token) { return isUnknown; } - protected void updateExpiration(final String token, long expiration) { + protected void updateExpiration(final String tokenId, long expiration) { synchronized (tokenExpirations) { - tokenExpirations.replace(token, expiration); + tokenExpirations.replace(tokenId, expiration); } } - protected void removeToken(final String token) throws UnknownTokenException { - validateToken(token); + protected void removeToken(final String tokenId) throws UnknownTokenException { + validateToken(tokenId); synchronized (tokenExpirations) { - tokenExpirations.remove(token); + tokenExpirations.remove(tokenId); } synchronized (maxTokenLifetimes) { - maxTokenLifetimes.remove(token); + maxTokenLifetimes.remove(tokenId); } - log.removedTokenState(TokenUtils.getTokenDisplayText(token)); + log.removedTokenState(tokenId); } - protected boolean hasRemainingRenewals(final String token, long renewInterval) { + protected boolean hasRemainingRenewals(final String tokenId, long renewInterval) { // Is the current time + 30-second buffer + the renewal interval is less than the max lifetime for the token? - return ((System.currentTimeMillis() + 30000 + renewInterval) < getMaxLifetime(token)); + return ((System.currentTimeMillis() + 30000 + renewInterval) < getMaxLifetime(tokenId)); } - protected long getMaxLifetime(final String token) { + protected long getMaxLifetime(final String tokenId) { long result; synchronized (maxTokenLifetimes) { - result = maxTokenLifetimes.getOrDefault(token, 0L); + result = maxTokenLifetimes.getOrDefault(tokenId, 0L); } return result; } - protected boolean isValidIdentifier(final String token) { - return token != null && !token.isEmpty(); + protected boolean isValidIdentifier(final String tokenId) { + return tokenId != null && !tokenId.isEmpty(); } /** * Validate the specified token identifier. * - * @param token The token identifier to validate. + * @param tokenId The token identifier to validate. * * @throws IllegalArgumentException if the specified token in invalid. * @throws UnknownTokenException if the specified token in valid, but not known to the service. */ - protected void validateToken(final String token) throws IllegalArgumentException, UnknownTokenException { - if (!isValidIdentifier(token)) { - throw new IllegalArgumentException("Token data cannot be null."); + protected void validateToken(final String tokenId) throws IllegalArgumentException, UnknownTokenException { + if (!isValidIdentifier(tokenId)) { + throw new IllegalArgumentException("Token identifier cannot be null."); } // First, make sure the token is one we know about - if (isUnknown(token)) { - log.unknownToken(TokenUtils.getTokenDisplayText(token)); - throw new UnknownTokenException(token); + if (isUnknown(tokenId)) { + log.unknownToken(tokenId); + throw new UnknownTokenException(tokenId); } } @@ -295,14 +295,14 @@ protected String getTimestampDisplay(long timestamp) { * Method that deletes expired tokens based on the token timestamp. */ protected void evictExpiredTokens() { - for (final String token : getTokens()) { + for (final String tokenId : getTokens()) { try { - if (needsEviction(token)) { - log.evictToken(TokenUtils.getTokenDisplayText(token)); - removeToken(token); + if (needsEviction(tokenId)) { + log.evictToken(tokenId); + removeToken(tokenId); } } catch (final Exception e) { - log.failedExpiredTokenEviction(TokenUtils.getTokenDisplayText(token), e); + log.failedExpiredTokenEviction(tokenId, e); } } } @@ -310,11 +310,11 @@ protected void evictExpiredTokens() { /** * Method that checks if an expired token is ready to be evicted * by adding configured grace period to the expiry time. - * @param token + * @param tokenId * @return */ - protected boolean needsEviction(final String token) throws UnknownTokenException { - return ((getTokenExpiration(token) + TimeUnit.SECONDS.toMillis(tokenEvictionGracePeriod)) <= System.currentTimeMillis()); + protected boolean needsEviction(final String tokenId) throws UnknownTokenException { + return ((getTokenExpiration(tokenId) + TimeUnit.SECONDS.toMillis(tokenEvictionGracePeriod)) <= System.currentTimeMillis()); } /** @@ -326,26 +326,4 @@ protected List getTokens() { return tokenExpirations.keySet().stream().collect(Collectors.toList()); } - /** - * A function that returns the JWT token expiration. This is only called when - * gateway.knox.token.permissive.validation property is set to true. - * @param token token to be verified and saved - */ - protected OptionalLong getJWTTokenExpiration(final String token) { - JWT jwt; - try { - jwt = new JWTToken(token); - } catch (final ParseException e) { - log.errorParsingToken(e.toString()); - return OptionalLong.empty(); - } - final Date expires = jwt.getExpiresDate(); - if (expires == null) { - log.jwtTokenExpiry(TokenUtils.getTokenDisplayText(token), "-1"); - return OptionalLong.of(-1); - } - log.jwtTokenExpiry(TokenUtils.getTokenDisplayText(token), expires.toString()); - return OptionalLong.of(expires.getTime()); - } - } diff --git a/gateway-server/src/main/java/org/apache/knox/gateway/services/token/impl/TokenStateServiceMessages.java b/gateway-server/src/main/java/org/apache/knox/gateway/services/token/impl/TokenStateServiceMessages.java index 318ccedb94..15eeb18805 100644 --- a/gateway-server/src/main/java/org/apache/knox/gateway/services/token/impl/TokenStateServiceMessages.java +++ b/gateway-server/src/main/java/org/apache/knox/gateway/services/token/impl/TokenStateServiceMessages.java @@ -25,48 +25,46 @@ public interface TokenStateServiceMessages { @Message(level = MessageLevel.DEBUG, text = "Added token {0}, expiration {1}") - void addedToken(String tokenDisplayText, String expiration); + void addedToken(String tokenId, String expiration); @Message(level = MessageLevel.DEBUG, text = "Renewed token {0}, expiration {1}") - void renewedToken(String tokenDisplayText, String expiration); + void renewedToken(String tokenId, String expiration); @Message(level = MessageLevel.DEBUG, text = "Revoked token {0}") - void revokedToken(String tokenDisplayText); + void revokedToken(String tokenId); @Message(level = MessageLevel.DEBUG, text = "Removed state for token {0}") - void removedTokenState(String tokenDisplayText); + void removedTokenState(String tokenId); - @Message(level = MessageLevel.DEBUG, text = "Unknown token {0}") - void unknownToken(String tokenDisplayText); + @Message(level = MessageLevel.ERROR, text = "Unknown token {0}") + void unknownToken(String tokenId); @Message(level = MessageLevel.ERROR, text = "The renewal limit for the token ({0}) has been exceeded.") - void renewalLimitExceeded(String tokenDisplayText); + void renewalLimitExceeded(String tokenId); @Message(level = MessageLevel.ERROR, text = "Failed to save state for token {0} : {1}") - void failedToSaveTokenState(String tokenDisplayText, @StackTrace(level = MessageLevel.DEBUG) Exception e); + void failedToSaveTokenState(String tokenId, @StackTrace(level = MessageLevel.DEBUG) Exception e); @Message(level = MessageLevel.ERROR, text = "Error accessing state for token {0} : {1}") - void errorAccessingTokenState(String tokenDisplayText, @StackTrace(level = MessageLevel.DEBUG) Exception e); + void errorAccessingTokenState(String tokenId, @StackTrace(level = MessageLevel.DEBUG) Exception e); + + @Message(level = MessageLevel.INFO, + text = "Referencing the expiration in the token ({0}) because no state could not be found: {1}") + void permissiveTokenHandling(String tokenId, String errorMessage); @Message(level = MessageLevel.ERROR, text = "Failed to update expiration for token {1} : {1}") - void failedToUpdateTokenExpiration(String tokenDisplayText, @StackTrace(level = MessageLevel.DEBUG) Exception e); + void failedToUpdateTokenExpiration(String tokenId, @StackTrace(level = MessageLevel.DEBUG) Exception e); @Message(level = MessageLevel.ERROR, text = "Failed to remove state for token {0} : {1}") - void failedToRemoveTokenState(String tokenDisplayText, @StackTrace(level = MessageLevel.DEBUG) Exception e); + void failedToRemoveTokenState(String tokenId, @StackTrace(level = MessageLevel.DEBUG) Exception e); @Message(level = MessageLevel.ERROR, text = "Failed to evict expired token {0} : {1}") - void failedExpiredTokenEviction(String tokenDisplayText, @StackTrace(level = MessageLevel.DEBUG) Exception e); + void failedExpiredTokenEviction(String tokenId, @StackTrace(level = MessageLevel.DEBUG) Exception e); @Message(level = MessageLevel.DEBUG, text = "Evicting expired token {0}") - void evictToken(String tokenDisplayText); + void evictToken(String tokenId); @Message(level = MessageLevel.ERROR, text = "Error occurred evicting token {0}") void errorEvictingTokens(@StackTrace(level = MessageLevel.DEBUG) Exception e); - @Message(level = MessageLevel.ERROR, text = "Error occurred while parsing JWT token, cause: {0}") - void errorParsingToken(String cause); - - @Message(level = MessageLevel.DEBUG, text = "Permissive validation for token is enabled, expiration for token {0} is {1}") - void jwtTokenExpiry(String tokenDisplayText, String expiration); - } diff --git a/gateway-server/src/test/java/org/apache/knox/gateway/services/token/impl/DefaultTokenStateServiceTest.java b/gateway-server/src/test/java/org/apache/knox/gateway/services/token/impl/DefaultTokenStateServiceTest.java index 28ac7b8622..0e60b7e73f 100644 --- a/gateway-server/src/test/java/org/apache/knox/gateway/services/token/impl/DefaultTokenStateServiceTest.java +++ b/gateway-server/src/test/java/org/apache/knox/gateway/services/token/impl/DefaultTokenStateServiceTest.java @@ -21,8 +21,8 @@ import org.apache.knox.gateway.config.GatewayConfig; import org.apache.knox.gateway.services.ServiceLifecycleException; import org.apache.knox.gateway.services.security.token.TokenStateService; -import org.apache.knox.gateway.services.security.token.impl.JWT; import org.apache.knox.gateway.services.security.token.TokenUtils; +import org.apache.knox.gateway.services.security.token.impl.JWT; import org.apache.knox.gateway.services.security.token.UnknownTokenException; import org.apache.knox.gateway.services.security.token.impl.JWTToken; import org.easymock.EasyMock; @@ -34,6 +34,7 @@ import java.security.interfaces.RSAPrivateKey; import java.util.Collections; import java.util.Date; +import java.util.UUID; import java.util.concurrent.TimeUnit; import static org.junit.Assert.assertEquals; @@ -62,14 +63,14 @@ public void testGetExpiration() throws Exception { final TokenStateService tss = createTokenStateService(); tss.addToken(token, System.currentTimeMillis()); - long expiration = tss.getTokenExpiration(token.getPayload()); + long expiration = tss.getTokenExpiration(TokenUtils.getTokenId(token)); assertEquals(token.getExpiresDate().getTime(), expiration); } @Test(expected = IllegalArgumentException.class) public void testGetExpiration_NullToken() throws Exception { // Expecting an IllegalArgumentException because the token is null - createTokenStateService().getTokenExpiration(null); + createTokenStateService().getTokenExpiration((String) null); } @Test(expected = IllegalArgumentException.class) @@ -83,7 +84,7 @@ public void testGetExpiration_InvalidToken() throws Exception { final JWTToken token = createMockToken(System.currentTimeMillis() + 60000); // Expecting an UnknownTokenException because the token is not known to the TokenStateService - createTokenStateService().getTokenExpiration(token.getPayload()); + createTokenStateService().getTokenExpiration(TokenUtils.getTokenId(token)); } @Test @@ -92,12 +93,12 @@ public void testGetExpiration_AfterRenewal() throws Exception { final TokenStateService tss = createTokenStateService(); tss.addToken(token, System.currentTimeMillis()); - long expiration = tss.getTokenExpiration(token.getPayload()); + long expiration = tss.getTokenExpiration(TokenUtils.getTokenId(token)); assertEquals(token.getExpiresDate().getTime(), expiration); long newExpiration = tss.renewToken(token); assertTrue(newExpiration > token.getExpiresDate().getTime()); - assertTrue(tss.getTokenExpiration(token.getPayload()) > token.getExpiresDate().getTime()); + assertTrue(tss.getTokenExpiration(TokenUtils.getTokenId(token)) > token.getExpiresDate().getTime()); } @Test @@ -119,7 +120,7 @@ public void testIsExpired_Positive() throws Exception { } - @Test + @Test(expected = UnknownTokenException.class) public void testIsExpired_Revoked() throws Exception { final JWTToken token = createMockToken(System.currentTimeMillis() + 60000); final TokenStateService tss = createTokenStateService(); @@ -128,7 +129,7 @@ public void testIsExpired_Revoked() throws Exception { assertFalse("Expected the token to be valid.", tss.isExpired(token)); tss.revokeToken(token); - assertTrue("Expected the token to have been marked as revoked.", tss.isExpired(token)); + tss.isExpired(token); } @@ -154,7 +155,7 @@ public void testRenewalBeyondMaxLifetime() throws Exception { final TokenStateService tss = createTokenStateService(); // Add the token with a short maximum lifetime - tss.addToken(token.getPayload(), issueTime, expiration, 5000L); + tss.addToken(TokenUtils.getTokenId(token), issueTime, expiration, 5000L); try { // Attempt to renew the token for the default interval, which should exceed the specified short maximum lifetime @@ -196,7 +197,7 @@ public void testTokenEviction() /* expect the renew call to fail since the token is evicted */ final UnknownTokenException e = assertThrows(UnknownTokenException.class, () -> tss.renewToken(token)); - assertEquals("Unknown token: " + TokenUtils.getTokenDisplayText(token.getPayload()), e.getMessage()); + assertEquals("Unknown token: " + TokenUtils.getTokenId(token), e.getMessage()); } finally { tss.stop(); } @@ -212,12 +213,12 @@ public void testTokenPermissiveness() throws UnknownTokenException { } catch (ServiceLifecycleException e) { fail("Error creating TokenStateService: " + e.getMessage()); } - assertEquals(expiry/1000, tss.getTokenExpiration(token.toString())/1000); + assertEquals(expiry/1000, tss.getTokenExpiration(token)/1000); } - @Test + @Test(expected = UnknownTokenException.class) public void testTokenPermissivenessNoExpiry() throws UnknownTokenException { - final JWT token = getJWTToken(-1); + final JWT token = getJWTToken(-1L); TokenStateService tss = new DefaultTokenStateService(); try { tss.init(createMockGatewayConfig(true), Collections.emptyMap()); @@ -225,7 +226,7 @@ public void testTokenPermissivenessNoExpiry() throws UnknownTokenException { fail("Error creating TokenStateService: " + e.getMessage()); } - assertEquals(-1L, tss.getTokenExpiration(token.toString())); + tss.getTokenExpiration(token); } protected static JWTToken createMockToken(final long expiration) { @@ -233,8 +234,10 @@ protected static JWTToken createMockToken(final long expiration) { } protected static JWTToken createMockToken(final String payload, final long expiration) { + UUID tokenUID = UUID.randomUUID(); JWTToken token = EasyMock.createNiceMock(JWTToken.class); EasyMock.expect(token.getPayload()).andReturn(payload).anyTimes(); + EasyMock.expect(token.getClaim(JWTToken.KNOX_ID_CLAIM)).andReturn(String.valueOf(tokenUID)).anyTimes(); EasyMock.expect(token.getExpiresDate()).andReturn(new Date(expiration)).anyTimes(); EasyMock.replay(token); return token; @@ -273,6 +276,7 @@ protected JWT getJWTToken(final long expiry) { if(expiry > 0) { claims[3] = Long.toString(expiry); } + JWT token = new JWTToken("RS256", claims); // Sign the token JWSSigner signer = new RSASSASigner(privateKey); diff --git a/gateway-service-knoxtoken/src/main/java/org/apache/knox/gateway/service/knoxtoken/TokenResource.java b/gateway-service-knoxtoken/src/main/java/org/apache/knox/gateway/service/knoxtoken/TokenResource.java index 01c30733ee..26474205f8 100644 --- a/gateway-service-knoxtoken/src/main/java/org/apache/knox/gateway/service/knoxtoken/TokenResource.java +++ b/gateway-service-knoxtoken/src/main/java/org/apache/knox/gateway/service/knoxtoken/TokenResource.java @@ -22,6 +22,7 @@ import java.security.cert.Certificate; import java.security.cert.CertificateEncodingException; import java.security.cert.X509Certificate; +import java.text.ParseException; import java.util.ArrayList; import java.util.Map; import java.util.HashMap; @@ -50,8 +51,10 @@ import org.apache.knox.gateway.services.security.token.JWTokenAuthority; import org.apache.knox.gateway.services.security.token.TokenServiceException; import org.apache.knox.gateway.services.security.token.TokenStateService; +import org.apache.knox.gateway.services.security.token.TokenUtils; import org.apache.knox.gateway.services.security.token.UnknownTokenException; import org.apache.knox.gateway.services.security.token.impl.JWT; +import org.apache.knox.gateway.services.security.token.impl.JWTToken; import org.apache.knox.gateway.util.JsonUtils; import static javax.ws.rs.core.MediaType.APPLICATION_JSON; @@ -246,11 +249,16 @@ public Response renew(String token) { String renewer = SubjectUtils.getCurrentEffectivePrincipalName(); if (allowedRenewers.contains(renewer)) { try { + JWTToken jwt = new JWTToken(token); // If renewal fails, it should be an exception - expiration = tokenStateService.renewToken(token, + expiration = tokenStateService.renewToken(jwt, renewInterval.orElse(tokenStateService.getDefaultRenewInterval())); + log.renewedToken(getTopologyName(), TokenUtils.getTokenDisplayText(token), TokenUtils.getTokenId(jwt)); + } catch (ParseException e) { + log.invalidToken(getTopologyName(), TokenUtils.getTokenDisplayText(token), e); + error = safeGetMessage(e); } catch (Exception e) { - error = e.getMessage(); + error = safeGetMessage(e); } } else { errorStatus = Response.Status.FORBIDDEN; @@ -263,7 +271,7 @@ public Response renew(String token) { .entity("{\n \"renewed\": \"true\",\n \"expires\": \"" + expiration + "\"\n}\n") .build(); } else { - log.badRenewalRequest(getTopologyName(), error); + log.badRenewalRequest(getTopologyName(), TokenUtils.getTokenDisplayText(token), error); resp = Response.status(errorStatus) .entity("{\n \"renewed\": \"false\",\n \"error\": \"" + error + "\"\n}\n") .build(); @@ -287,9 +295,14 @@ public Response revoke(String token) { String renewer = SubjectUtils.getCurrentEffectivePrincipalName(); if (allowedRenewers.contains(renewer)) { try { - tokenStateService.revokeToken(token); + JWTToken jwt = new JWTToken(token); + tokenStateService.revokeToken(jwt); + log.revokedToken(getTopologyName(), TokenUtils.getTokenDisplayText(token), TokenUtils.getTokenId(jwt)); + } catch (ParseException e) { + log.invalidToken(getTopologyName(), TokenUtils.getTokenDisplayText(token), e); + error = safeGetMessage(e); } catch (UnknownTokenException e) { - error = e.getMessage(); + error = safeGetMessage(e); } } else { errorStatus = Response.Status.FORBIDDEN; @@ -297,12 +310,12 @@ public Response revoke(String token) { } } - if(error.isEmpty()) { + if (error.isEmpty()) { resp = Response.status(Response.Status.OK) .entity("{\n \"revoked\": \"true\"\n}\n") .build(); } else { - log.badRevocationRequest(getTopologyName(), error); + log.badRevocationRequest(getTopologyName(), TokenUtils.getTokenDisplayText(token), error); resp = Response.status(errorStatus) .entity("{\n \"revoked\": \"false\",\n \"error\": \"" + error + "\"\n}\n") .build(); @@ -367,6 +380,9 @@ private Response getAuthenticationToken() { if (token != null) { String accessToken = token.toString(); + String tokenId = TokenUtils.getTokenId(token); + log.issuedToken(getTopologyName(), TokenUtils.getTokenDisplayText(accessToken), tokenId); + HashMap map = new HashMap<>(); map.put(ACCESS_TOKEN, accessToken); map.put(TOKEN_TYPE, BEARER); @@ -385,10 +401,11 @@ private Response getAuthenticationToken() { // Optional token store service persistence if (tokenStateService != null) { - tokenStateService.addToken(accessToken, + tokenStateService.addToken(tokenId, System.currentTimeMillis(), expires, maxTokenLifetime.orElse(tokenStateService.getDefaultMaxLifetimeDuration())); + log.storedToken(getTopologyName(), TokenUtils.getTokenDisplayText(accessToken), tokenId); } return Response.ok().entity(jsonResponse).build(); @@ -426,4 +443,15 @@ private String getTopologyName() { return (String) context.getAttribute("org.apache.knox.gateway.gateway.cluster"); } + /** + * Safely get the message from the specified Throwable. + * + * @param t A Throwable + * @return The result of t.getMessage(), or "null" if that result is null. + */ + private String safeGetMessage(Throwable t) { + String message = t.getMessage(); + return message != null ? message : "null"; + } + } diff --git a/gateway-service-knoxtoken/src/main/java/org/apache/knox/gateway/service/knoxtoken/TokenServiceMessages.java b/gateway-service-knoxtoken/src/main/java/org/apache/knox/gateway/service/knoxtoken/TokenServiceMessages.java index df374cebe3..60ee4bae21 100644 --- a/gateway-service-knoxtoken/src/main/java/org/apache/knox/gateway/service/knoxtoken/TokenServiceMessages.java +++ b/gateway-service-knoxtoken/src/main/java/org/apache/knox/gateway/service/knoxtoken/TokenServiceMessages.java @@ -22,9 +22,20 @@ import org.apache.knox.gateway.i18n.messages.Messages; import org.apache.knox.gateway.i18n.messages.StackTrace; +import java.text.ParseException; + @Messages(logger="org.apache.knox.gateway.service.knoxtoken") public interface TokenServiceMessages { + @Message( level = MessageLevel.INFO, text = "Knox Token service ({0}) issued token {1} ({2})") + void issuedToken(String topologyName, String tokenDisplayText, String tokenId); + + @Message( level = MessageLevel.INFO, text = "Knox Token service ({0}) renewed the expiration for token {1} ({2})") + void renewedToken(String topologyName, String tokenDisplayText, String tokenId); + + @Message( level = MessageLevel.INFO, text = "Knox Token service ({0}) revoked token {1} ({2})") + void revokedToken(String topologyName, String tokenDisplayText, String tokenId); + @Message( level = MessageLevel.ERROR, text = "Unable to issue token.") void unableToIssueToken(@StackTrace( level = MessageLevel.DEBUG) Exception e); @@ -45,15 +56,22 @@ void invalidConfigValue(String topologyName, @Message( level = MessageLevel.INFO, text = "Server management of token state is enabled for the \"{0}\" topology.") void serverManagedTokenStateEnabled(String topologyName); + @Message( level = MessageLevel.ERROR, text = "Knox Token service ({0}) could not parse token {1}: {2}") + void invalidToken(String topologyName, + String tokenDisplayText, + @StackTrace( level = MessageLevel.DEBUG ) ParseException e); + @Message( level = MessageLevel.WARN, text = "There are no token renewers white-listed in the \"{0}\" topology.") void noRenewersConfigured(String topologyName); - @Message( level = MessageLevel.ERROR, text = "Knox Token service ({0}) rejected a bad token renewal request: {1}") - void badRenewalRequest(String topologyName, String error); + @Message( level = MessageLevel.ERROR, text = "Knox Token service ({0}) rejected a bad renewal request for token {1}: {2}") + void badRenewalRequest(String topologyName, String tokenDisplayText, String error); - @Message( level = MessageLevel.ERROR, text = "Knox Token service ({0}) rejected a bad token revocation request: {1}") - void badRevocationRequest(String topologyName, String error); + @Message( level = MessageLevel.ERROR, text = "Knox Token service ({0}) rejected a bad revocation request for token {1}: {2}") + void badRevocationRequest(String topologyName, String tokenDisplayText, String error); + @Message( level = MessageLevel.DEBUG, text = "Knox Token service ({0}) stored state for token {1} ({2})") + void storedToken(String topologyName, String tokenDisplayText, String tokenId); } diff --git a/gateway-service-knoxtoken/src/test/java/org/apache/knox/gateway/service/knoxtoken/TokenServiceResourceTest.java b/gateway-service-knoxtoken/src/test/java/org/apache/knox/gateway/service/knoxtoken/TokenServiceResourceTest.java index 074dfeb848..7b9a82e328 100644 --- a/gateway-service-knoxtoken/src/test/java/org/apache/knox/gateway/service/knoxtoken/TokenServiceResourceTest.java +++ b/gateway-service-knoxtoken/src/test/java/org/apache/knox/gateway/service/knoxtoken/TokenServiceResourceTest.java @@ -30,6 +30,8 @@ import org.apache.knox.gateway.services.GatewayServices; import org.apache.knox.gateway.services.security.token.JWTokenAuthority; import org.apache.knox.gateway.services.security.token.TokenStateService; +import org.apache.knox.gateway.services.security.token.TokenUtils; +import org.apache.knox.gateway.services.security.token.UnknownTokenException; import org.apache.knox.gateway.services.security.token.impl.JWT; import org.apache.knox.gateway.services.security.token.impl.JWTToken; import org.easymock.EasyMock; @@ -58,9 +60,14 @@ import java.util.List; import java.util.Map; import java.util.Set; +import java.util.concurrent.Callable; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.Future; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNotEquals; import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertTrue; @@ -113,6 +120,7 @@ public void testClientData() { public void testGetToken() throws Exception { ServletContext context = EasyMock.createNiceMock(ServletContext.class); + EasyMock.expect(context.getAttribute("org.apache.knox.gateway.gateway.cluster")).andReturn("test").anyTimes(); HttpServletRequest request = EasyMock.createNiceMock(HttpServletRequest.class); EasyMock.expect(request.getServletContext()).andReturn(context).anyTimes(); @@ -129,6 +137,7 @@ public void testGetToken() throws Exception { EasyMock.replay(principal, services, context, request); TokenResource tr = new TokenResource(); + tr.context = context; tr.request = request; // Issue a token @@ -149,6 +158,58 @@ public void testGetToken() throws Exception { assertTrue(authority.verifyToken(parsedToken)); } + /** + * KNOX-2266 + */ + @Test + public void testConcurrentGetToken() throws Exception { + + ServletContext context = EasyMock.createNiceMock(ServletContext.class); + EasyMock.expect(context.getAttribute("org.apache.knox.gateway.gateway.cluster")).andReturn("test").anyTimes(); + + HttpServletRequest request = EasyMock.createNiceMock(HttpServletRequest.class); + EasyMock.expect(request.getServletContext()).andReturn(context).anyTimes(); + Principal principal = EasyMock.createNiceMock(Principal.class); + EasyMock.expect(principal.getName()).andReturn("alice").anyTimes(); + EasyMock.expect(request.getUserPrincipal()).andReturn(principal).anyTimes(); + + GatewayServices services = EasyMock.createNiceMock(GatewayServices.class); + EasyMock.expect(context.getAttribute(GatewayServices.GATEWAY_SERVICES_ATTRIBUTE)).andReturn(services).anyTimes(); + + JWTokenAuthority authority = new TestJWTokenAuthority(publicKey, privateKey); + EasyMock.expect(services.getService(ServiceType.TOKEN_SERVICE)).andReturn(authority).anyTimes(); + + EasyMock.replay(principal, services, context, request); + + final TokenResource tr = new TokenResource(); + tr.context = context; + tr.request = request; + + // Request two tokens concurrently + Callable task = tr::doGet; + List> tasks = Collections.nCopies(2, task); + ExecutorService executorService = Executors.newFixedThreadPool(2); + List> futures = executorService.invokeAll(tasks); + List responses = new ArrayList<>(futures.size()); + for (Future f : futures) { + responses.add(f.get()); + } + + // Parse the responses + String accessToken1 = getTagValue(responses.get(0).getEntity().toString(), "access_token"); + assertNotNull(accessToken1); + JWT jwt1 = new JWTToken(accessToken1); + + String accessToken2 = getTagValue(responses.get(1).getEntity().toString(), "access_token"); + assertNotNull(accessToken1); + JWT jwt2 = new JWTToken(accessToken2); + + // Verify the tokens + assertNotEquals("Access tokens should be different.", accessToken1, accessToken2); + assertEquals("The token expirations should be the same.", jwt1.getExpires(), jwt2.getExpires()); + assertNotEquals("Tokens should have unique IDs.", TokenUtils.getTokenId(jwt1), TokenUtils.getTokenId(jwt2)); + } + @Test public void testAudiences() throws Exception { @@ -1125,7 +1186,7 @@ long getExpiration(final String token) { @Override public void addToken(JWTToken token, long issueTime) { - addToken(token.getPayload(), issueTime, token.getExpiresDate().getTime()); + addToken(TokenUtils.getTokenId(token), issueTime, token.getExpiresDate().getTime()); } @Override @@ -1139,58 +1200,58 @@ public long getDefaultMaxLifetimeDuration() { } @Override - public void addToken(String token, long issueTime, long expiration) { - addToken(token, issueTime, expiration, getDefaultMaxLifetimeDuration()); + public void addToken(String tokenId, long issueTime, long expiration) { + addToken(tokenId, issueTime, expiration, getDefaultMaxLifetimeDuration()); } @Override - public void addToken(String token, long issueTime, long expiration, long maxLifetimeDuration) { - issueTimes.put(token, issueTime); - expirationData.put(token, expiration); - maxLifetimes.put(token, issueTime + maxLifetimeDuration); + public void addToken(String tokenId, long issueTime, long expiration, long maxLifetimeDuration) { + issueTimes.put(tokenId, issueTime); + expirationData.put(tokenId, expiration); + maxLifetimes.put(tokenId, issueTime + maxLifetimeDuration); } @Override public boolean isExpired(JWTToken token) { - return isExpired(token.getPayload()); - } - - @Override - public boolean isExpired(String token) { return false; } @Override public void revokeToken(JWTToken token) { - revokeToken(token.getPayload()); + revokeToken(TokenUtils.getTokenId(token)); } @Override - public void revokeToken(String token) { + public void revokeToken(String tokenId) { } @Override public long renewToken(JWTToken token) { - return renewToken(token.getPayload()); + return renewToken(TokenUtils.getTokenId(token)); } @Override - public long renewToken(String token) { - return renewToken(token, 0L); + public long renewToken(String tokenId) { + return renewToken(tokenId, 0L); } @Override public long renewToken(JWTToken token, long renewInterval) { - return renewToken(token.getPayload()); + return renewToken(TokenUtils.getTokenId(token), renewInterval); + } + + @Override + public long renewToken(String tokenId, long renewInterval) { + return 0; } @Override - public long renewToken(String token, long renewInterval) { + public long getTokenExpiration(JWT token) throws UnknownTokenException { return 0; } @Override - public long getTokenExpiration(String token) { + public long getTokenExpiration(String tokenId) { return 0; } diff --git a/gateway-spi/src/main/java/org/apache/knox/gateway/services/security/token/TokenStateService.java b/gateway-spi/src/main/java/org/apache/knox/gateway/services/security/token/TokenStateService.java index 4de929555a..533cf9d21d 100644 --- a/gateway-spi/src/main/java/org/apache/knox/gateway/services/security/token/TokenStateService.java +++ b/gateway-spi/src/main/java/org/apache/knox/gateway/services/security/token/TokenStateService.java @@ -17,6 +17,7 @@ package org.apache.knox.gateway.services.security.token; import org.apache.knox.gateway.services.Service; +import org.apache.knox.gateway.services.security.token.impl.JWT; import org.apache.knox.gateway.services.security.token.impl.JWTToken; @@ -44,24 +45,25 @@ public interface TokenStateService extends Service { * @param issueTime The time the token was issued. */ void addToken(JWTToken token, long issueTime); + /** * Add state for the specified token. * - * @param token The token. + * @param tokenId The token unique identifier. * @param issueTime The time the token was issued. * @param expiration The token expiration time. */ - void addToken(String token, long issueTime, long expiration); + void addToken(String tokenId, long issueTime, long expiration); /** * Add state for the specified token. * - * @param token The token. + * @param tokenId The token unique identifier. * @param issueTime The time the token was issued. * @param expiration The token expiration time. * @param maxLifetimeDuration The maximum allowed lifetime for the token. */ - void addToken(String token, long issueTime, long expiration, long maxLifetimeDuration); + void addToken(String tokenId, long issueTime, long expiration, long maxLifetimeDuration); /** * @@ -71,14 +73,6 @@ public interface TokenStateService extends Service { */ boolean isExpired(JWTToken token) throws UnknownTokenException; - /** - * - * @param token The token. - * - * @return true, if the token has expired; Otherwise, false. - */ - boolean isExpired(String token) throws UnknownTokenException; - /** * Disable any subsequent use of the specified token. * @@ -89,9 +83,9 @@ public interface TokenStateService extends Service { /** * Disable any subsequent use of the specified token. * - * @param token The token. + * @param tokenId The token unique identifier. */ - void revokeToken(String token) throws UnknownTokenException; + void revokeToken(String tokenId) throws UnknownTokenException; /** * Extend the lifetime of the specified token by the default amount of time. @@ -115,21 +109,21 @@ public interface TokenStateService extends Service { /** * Extend the lifetime of the specified token by the default amount of time. * - * @param token The token. + * @param tokenId The token unique identifier. * * @return The token's updated expiration time in milliseconds. */ - long renewToken(String token) throws UnknownTokenException; + long renewToken(String tokenId) throws UnknownTokenException; /** * Extend the lifetime of the specified token by the specified amount of time. * - * @param token The token. + * @param tokenId The token unique identifier. * @param renewInterval The amount of time that should be added to the token's lifetime. * * @return The token's updated expiration time in milliseconds. */ - long renewToken(String token, long renewInterval) throws UnknownTokenException; + long renewToken(String tokenId, long renewInterval) throws UnknownTokenException; /** * @@ -137,6 +131,14 @@ public interface TokenStateService extends Service { * * @return The token's expiration time in milliseconds. */ - long getTokenExpiration(String token) throws UnknownTokenException; + long getTokenExpiration(JWT token) throws UnknownTokenException; + + /** + * + * @param tokenId The token unique identifier. + * + * @return The token's expiration time in milliseconds. + */ + long getTokenExpiration(String tokenId) throws UnknownTokenException; } diff --git a/gateway-spi/src/main/java/org/apache/knox/gateway/services/security/token/TokenUtils.java b/gateway-spi/src/main/java/org/apache/knox/gateway/services/security/token/TokenUtils.java index ab630c0b05..dd054d924d 100644 --- a/gateway-spi/src/main/java/org/apache/knox/gateway/services/security/token/TokenUtils.java +++ b/gateway-spi/src/main/java/org/apache/knox/gateway/services/security/token/TokenUtils.java @@ -16,12 +16,68 @@ */ package org.apache.knox.gateway.services.security.token; +import org.apache.knox.gateway.config.GatewayConfig; +import org.apache.knox.gateway.services.security.token.impl.JWT; +import org.apache.knox.gateway.services.security.token.impl.JWTToken; + +import javax.servlet.FilterConfig; +import javax.servlet.ServletContext; import java.util.Locale; + public class TokenUtils { + /** + * Get a String derived from a JWT String, which is suitable for presentation (e.g., logging) without compromising + * security. + * + * @param token A BASE64-encoded JWT String. + * + * @return An abbreviated form of the specified JWT String. + */ public static String getTokenDisplayText(final String token) { - return String.format(Locale.ROOT, "%s...%s", token.substring(0, 10), token.substring(token.length() - 3)); + return String.format(Locale.ROOT, "%s...%s", token.substring(0, 6), token.substring(token.length() - 6)); + } + + /** + * Extract the unique Knox token identifier from the specified JWT's claim set. + * + * @param token A JWT + * + * @return The unique identifier, or null. + */ + public static String getTokenId(final JWT token) { + return token.getClaim(JWTToken.KNOX_ID_CLAIM); + } + + /** + * Determine if server-managed token state is enabled for a provider, based on configuration. + * The analysis includes checking the provider params and the gateway configuration. + * + * @param filterConfig A FilterConfig object. + * + * @return true, if server-managed state is enabled; Otherwise, false. + */ + public static boolean isServerManagedTokenStateEnabled(FilterConfig filterConfig) { + boolean isServerManaged = false; + + // First, check for explicit provider-level configuration + String providerParamValue = filterConfig.getInitParameter(TokenStateService.CONFIG_SERVER_MANAGED); + + // If there is no provider-level configuration + if (providerParamValue == null || providerParamValue.isEmpty()) { + // Fall back to the gateway-level default + ServletContext context = filterConfig.getServletContext(); + if (context != null) { + GatewayConfig config = (GatewayConfig) context.getAttribute(GatewayConfig.GATEWAY_CONFIG_ATTRIBUTE); + isServerManaged = (config != null) && config.isServerManagedTokenStateEnabled(); + } + } else { + // Otherwise, apply the provider-level configuration + isServerManaged = Boolean.valueOf(providerParamValue); + } + + return isServerManaged; } } diff --git a/gateway-spi/src/main/java/org/apache/knox/gateway/services/security/token/UnknownTokenException.java b/gateway-spi/src/main/java/org/apache/knox/gateway/services/security/token/UnknownTokenException.java index 265e878fa1..1a8c4e0004 100644 --- a/gateway-spi/src/main/java/org/apache/knox/gateway/services/security/token/UnknownTokenException.java +++ b/gateway-spi/src/main/java/org/apache/knox/gateway/services/security/token/UnknownTokenException.java @@ -18,19 +18,23 @@ public class UnknownTokenException extends Exception { - private String token; + private String tokenId; - public UnknownTokenException(final String token) { - this.token = token; + /** + * + * @param tokenId The token unique identifier + */ + public UnknownTokenException(final String tokenId) { + this.tokenId = tokenId; } - public String getToken() { - return token; + public String getTokenId() { + return tokenId; } @Override public String getMessage() { - return "Unknown token: " + TokenUtils.getTokenDisplayText(token); + return "Unknown token: " + tokenId; } } diff --git a/gateway-spi/src/main/java/org/apache/knox/gateway/services/security/token/impl/JWTToken.java b/gateway-spi/src/main/java/org/apache/knox/gateway/services/security/token/impl/JWTToken.java index 35d5e082e5..023010f1ce 100644 --- a/gateway-spi/src/main/java/org/apache/knox/gateway/services/security/token/impl/JWTToken.java +++ b/gateway-spi/src/main/java/org/apache/knox/gateway/services/security/token/impl/JWTToken.java @@ -20,6 +20,7 @@ import java.util.Date; import java.util.ArrayList; import java.util.List; +import java.util.UUID; import org.apache.knox.gateway.i18n.messages.MessagesFactory; @@ -36,6 +37,8 @@ public class JWTToken implements JWT { private static JWTProviderMessages log = MessagesFactory.get( JWTProviderMessages.class ); + public static final String KNOX_ID_CLAIM = "knox.id"; + SignedJWT jwt; private JWTToken(String header, String claims, String signature) throws ParseException { @@ -73,6 +76,9 @@ public JWTToken(String alg, String[] claimsArray, List audiences) { builder = builder.expirationTime(new Date(Long.parseLong(claimsArray[3]))); } + // Add a private UUID claim for uniqueness + builder.claim(KNOX_ID_CLAIM, String.valueOf(UUID.randomUUID())); + claims = builder.build(); jwt = new SignedJWT(header, claims); diff --git a/gateway-spi/src/test/java/org/apache/knox/gateway/services/security/token/impl/JWTTokenTest.java b/gateway-spi/src/test/java/org/apache/knox/gateway/services/security/token/impl/JWTTokenTest.java index f53bc165ce..44e3d28537 100644 --- a/gateway-spi/src/test/java/org/apache/knox/gateway/services/security/token/impl/JWTTokenTest.java +++ b/gateway-spi/src/test/java/org/apache/knox/gateway/services/security/token/impl/JWTTokenTest.java @@ -25,6 +25,7 @@ import java.util.ArrayList; import java.util.Date; import java.util.List; +import java.util.UUID; import org.junit.BeforeClass; import org.junit.Test; @@ -80,6 +81,25 @@ public void testTokenCreation() throws Exception { assertEquals("https://login.example.com", token.getAudience()); } + @Test + public void testPrivateUUIDClaim() throws Exception { + String[] claims = new String[4]; + claims[0] = "KNOXSSO"; + claims[1] = "john.doe@example.com"; + claims[2] = "https://login.example.com"; + claims[3] = Long.toString( ( System.currentTimeMillis()/1000 ) + 300); + JWT token = new JWTToken("RS256", claims); + + assertEquals("KNOXSSO", token.getIssuer()); + assertEquals("john.doe@example.com", token.getSubject()); + assertEquals("https://login.example.com", token.getAudience()); + + String uuidString = token.getClaim(JWTToken.KNOX_ID_CLAIM); + assertNotNull(uuidString); + UUID uuid = UUID.fromString(uuidString); + assertNotNull(uuid); + } + @Test public void testTokenCreationWithAudienceListSingle() throws Exception { String[] claims = new String[4];