Navigation Menu

Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Minor access token jwt fixes #4333

Merged
merged 19 commits into from Oct 14, 2019
Merged
Show file tree
Hide file tree
Changes from 15 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
Expand Up @@ -11,9 +11,6 @@
import lombok.Getter;
import lombok.val;

import java.time.ZoneOffset;
import java.time.ZonedDateTime;

/**
* This is {@link OAuth20JwtAccessTokenEncoder}.
*
Expand All @@ -32,7 +29,7 @@ public String encode() {
val oAuthRegisteredService = OAuthRegisteredService.class.cast(this.registeredService);
val authentication = accessToken.getAuthentication();
if (oAuthRegisteredService != null && oAuthRegisteredService.isJwtAccessToken()) {
val dt = ZonedDateTime.now(ZoneOffset.UTC).plusSeconds(accessToken.getExpirationPolicy().getTimeToLive());
val dt = authentication.getAuthenticationDate().plusSeconds(accessToken.getExpirationPolicy().getTimeToLive());
val builder = JwtBuilder.JwtRequest.builder();

val request = builder
Expand Down
Expand Up @@ -136,7 +136,7 @@ protected ModelAndView buildCallbackUrlResponseType(final AccessTokenRequestData

LOGGER.debug("Redirecting to URL [{}]", url);
val parameters = new LinkedHashMap<String, String>();
parameters.put(OAuth20Constants.ACCESS_TOKEN, accessToken.getId());
parameters.put(OAuth20Constants.ACCESS_TOKEN, encodedAccessToken);
if (refreshToken != null) {
parameters.put(OAuth20Constants.REFRESH_TOKEN, refreshToken.getId());
}
Expand Down
Expand Up @@ -20,15 +20,15 @@ public interface IdTokenGeneratorService {
*
* @param request the request
* @param response the response
* @param accessTokenId the access token id
* @param accessToken the access token
* @param timeoutInSeconds the timeoutInSeconds
* @param responseType the response type
* @param registeredService the registered service
* @return the string
*/
String generate(HttpServletRequest request,
HttpServletResponse response,
AccessToken accessTokenId,
AccessToken accessToken,
long timeoutInSeconds,
OAuth20ResponseTypes responseType,
OAuthRegisteredService registeredService);
Expand Down
Expand Up @@ -6,6 +6,7 @@
import org.apereo.cas.ticket.IdTokenGeneratorService;
import org.apereo.cas.ticket.OAuthTokenSigningAndEncryptionService;
import org.apereo.cas.ticket.registry.TicketRegistry;
import org.apereo.cas.token.JwtBuilder;
import org.apereo.cas.uma.claim.UmaResourceSetClaimPermissionExaminer;
import org.apereo.cas.uma.ticket.permission.UmaPermissionTicketFactory;
import org.apereo.cas.uma.ticket.resource.repository.ResourceSetRepository;
Expand All @@ -32,6 +33,7 @@ public class UmaConfigurationContext {
private final ServicesManager servicesManager;
private final TicketRegistry ticketRegistry;
private final OAuth20TokenGenerator accessTokenGenerator;
private final JwtBuilder accessTokenJwtBuilder;
private final UmaPermissionTicketFactory umaPermissionTicketFactory;
private final ResourceSetRepository umaResourceSetRepository;
private final CasConfigurationProperties casProperties;
Expand Down
Expand Up @@ -53,7 +53,7 @@ public String generate(final HttpServletRequest request,
* Build jwt claims jwt claims.
*
* @param request the request
* @param accessTokenId the access token id
* @param accessToken the access token
* @param timeoutInSeconds the timeout in seconds
* @param service the service
* @param profile the profile
Expand All @@ -62,7 +62,7 @@ public String generate(final HttpServletRequest request,
* @return the jwt claims
*/
protected JwtClaims buildJwtClaims(final HttpServletRequest request,
final AccessToken accessTokenId,
final AccessToken accessToken,
final long timeoutInSeconds,
final OAuthRegisteredService service,
final UserProfile profile,
Expand Down
Expand Up @@ -6,6 +6,7 @@
import org.apereo.cas.support.oauth.OAuth20ResponseTypes;
import org.apereo.cas.support.oauth.util.OAuth20Utils;
import org.apereo.cas.support.oauth.web.response.accesstoken.ext.AccessTokenRequestDataHolder;
import org.apereo.cas.support.oauth.web.response.accesstoken.response.OAuth20JwtAccessTokenEncoder;
import org.apereo.cas.ticket.accesstoken.AccessToken;
import org.apereo.cas.uma.UmaConfigurationContext;
import org.apereo.cas.uma.claim.UmaResourceSetClaimPermissionResult;
Expand Down Expand Up @@ -42,6 +43,7 @@
@Slf4j
@Controller("umaAuthorizationRequestEndpointController")
public class UmaAuthorizationRequestEndpointController extends BaseUmaEndpointController {

public UmaAuthorizationRequestEndpointController(final UmaConfigurationContext umaConfigurationContext) {
super(umaConfigurationContext);
}
Expand Down Expand Up @@ -185,19 +187,29 @@ protected ResponseEntity generateRequestingPartyToken(final HttpServletRequest r
}

val accessToken = result.getAccessToken().get();

val encodedAccessToken = OAuth20JwtAccessTokenEncoder.builder()
.accessToken(accessToken)
.registeredService(holder.getRegisteredService())
.service(holder.getService())
.accessTokenJwtBuilder(getUmaConfigurationContext().getAccessTokenJwtBuilder())
.build()
.encode();

val timeout = Beans.newDuration(getUmaConfigurationContext().getCasProperties()
.getAuthn().getUma().getRequestingPartyToken().getMaxTimeToLiveInSeconds()).getSeconds();
request.setAttribute(UmaPermissionTicket.class.getName(), permissionTicket);
request.setAttribute(ResourceSet.class.getName(), resourceSet);
val idToken = getUmaConfigurationContext().getRequestingPartyTokenGenerator().generate(request, response, accessToken,
timeout, OAuth20ResponseTypes.CODE, registeredService);
val idToken = getUmaConfigurationContext().getRequestingPartyTokenGenerator().generate(request, response,
accessToken, timeout, OAuth20ResponseTypes.CODE, registeredService);
accessToken.setIdToken(idToken);
getUmaConfigurationContext().getTicketRegistry().updateTicket(accessToken);

if (StringUtils.isNotBlank(umaRequest.getRpt())) {
getUmaConfigurationContext().getTicketRegistry().deleteTicket(umaRequest.getRpt());
}
val model = CollectionUtils.wrap("rpt", accessToken.getId(), "code", HttpStatus.CREATED);

val model = CollectionUtils.wrap("rpt", encodedAccessToken, "code", HttpStatus.CREATED);
return new ResponseEntity<>(model, HttpStatus.OK);
}
}
Expand Up @@ -342,12 +342,6 @@ public Authenticator<TokenCredentials> oAuthAccessTokenAuthenticator() {
return new OAuth20AccessTokenAuthenticator(ticketRegistry.getObject());
}

@ConditionalOnMissingBean(name = "oauthAccessTokenResponseGenerator")
@Bean
public OAuth20AccessTokenResponseGenerator oauthAccessTokenResponseGenerator() {
return new OAuth20DefaultAccessTokenResponseGenerator(accessTokenJwtBuilder());
}
mmoayyed marked this conversation as resolved.
Show resolved Hide resolved

@Bean
@RefreshScope
@ConditionalOnMissingBean(name = "defaultAccessTokenFactory")
Expand Down
Expand Up @@ -69,7 +69,7 @@ public String generate(final HttpServletRequest request,
* Produce claims as jwt.
*
* @param request the request
* @param accessTokenId the access token id
* @param accessToken the access token
* @param timeoutInSeconds the timeoutInSeconds
* @param service the service
* @param profile the user profile
Expand All @@ -78,26 +78,26 @@ public String generate(final HttpServletRequest request,
* @return the jwt claims
*/
protected JwtClaims buildJwtClaims(final HttpServletRequest request,
final AccessToken accessTokenId,
final AccessToken accessToken,
final long timeoutInSeconds,
final OidcRegisteredService service,
final UserProfile profile,
final JEEContext context,
final OAuth20ResponseTypes responseType) {
val authentication = accessTokenId.getAuthentication();
val authentication = accessToken.getAuthentication();

val principal = this.getConfigurationContext().getProfileScopeToAttributesFilter()
.filter(accessTokenId.getService(), authentication.getPrincipal(), service, context, accessTokenId);
.filter(accessToken.getService(), authentication.getPrincipal(), service, context, accessToken);

val oidc = getConfigurationContext().getCasProperties().getAuthn().getOidc();

val claims = new JwtClaims();

val jwtId = getJwtId(accessTokenId.getTicketGrantingTicket());
val jwtId = getJwtId(accessToken.getTicketGrantingTicket());
claims.setJwtId(jwtId);

claims.setIssuer(oidc.getIssuer());
claims.setAudience(accessTokenId.getClientId());
claims.setAudience(accessToken.getClientId());

val expirationDate = NumericDate.now();
expirationDate.addSeconds(timeoutInSeconds);
Expand Down Expand Up @@ -125,7 +125,7 @@ protected JwtClaims buildJwtClaims(final HttpServletRequest request,
if (attributes.containsKey(OAuth20Constants.NONCE)) {
claims.setClaim(OAuth20Constants.NONCE, attributes.get(OAuth20Constants.NONCE).get(0));
}
generateAccessTokenHash(accessTokenId, service, claims);
generateAccessTokenHash(accessToken, service, claims);

LOGGER.trace("Comparing principal attributes [{}] with supported claims [{}]", principal.getAttributes(), oidc.getClaims());

Expand Down Expand Up @@ -185,7 +185,7 @@ protected String getJwtId(final TicketGrantingTicket tgt) {
/**
* Generate access token hash string.
*
* @param accessToken the access token id
* @param accessToken the access token
* @param registeredService the service
* @param claims the claims
*/
Expand All @@ -200,7 +200,6 @@ protected void generateAccessTokenHash(final AccessToken accessToken,
.build()
.encode();

claims.setClaim(OAuth20Constants.ACCESS_TOKEN, encodedAccessToken);
val alg = getConfigurationContext().getIdTokenSigningAndEncryptionService().getJsonWebKeySigningAlgorithm(registeredService);
val hash = OAuth20AccessTokenAtHashGenerator.builder()
.accessTokenId(encodedAccessToken)
Expand Down
Expand Up @@ -13,6 +13,7 @@
import org.apereo.cas.support.oauth.util.OAuth20Utils;
import org.apereo.cas.support.oauth.web.endpoints.BaseOAuth20Controller;
import org.apereo.cas.support.oauth.web.endpoints.OAuth20ConfigurationContext;
import org.apereo.cas.support.oauth.web.response.accesstoken.response.OAuth20JwtAccessTokenEncoder;
import org.apereo.cas.ticket.accesstoken.AccessToken;
import org.apereo.cas.util.HttpUtils;
import org.apereo.cas.util.RandomUtils;
Expand Down Expand Up @@ -149,7 +150,16 @@ public ResponseEntity handleRequestInternal(@RequestBody final String jsonInput,
val clientResponse = OidcClientRegistrationUtils.getClientRegistrationResponse(registeredService, prefix);

val accessToken = generateRegistrationAccessToken(request, response, registeredService, registrationRequest);
clientResponse.setRegistrationAccessToken(accessToken.getId());

val encodedAccessToken = OAuth20JwtAccessTokenEncoder.builder()
.accessToken(accessToken)
.registeredService(registeredService)
.service(accessToken.getService())
.accessTokenJwtBuilder(getOAuthConfigurationContext().getAccessTokenJwtBuilder())
.build()
.encode();

clientResponse.setRegistrationAccessToken(encodedAccessToken);

registeredService.setScopes(supportedScopes);
val processedScopes = new LinkedHashSet<String>(supportedScopes);
Expand Down