Skip to content

Commit

Permalink
Ensure that only application wide and service scopes are allowed for …
Browse files Browse the repository at this point in the history
…token requests
  • Loading branch information
sbearcsiro committed May 3, 2022
1 parent aae9112 commit a38001f
Show file tree
Hide file tree
Showing 14 changed files with 374 additions and 12 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
package org.apereo.cas.support.oauth.scopes;

import org.apereo.cas.support.oauth.web.response.accesstoken.ext.AccessTokenRequestContext;

import lombok.RequiredArgsConstructor;

import java.util.LinkedHashSet;
import java.util.List;
import java.util.Set;

/**
* This is {@link CompositeScopeResolver}.
*
* @author sbearcsiro
* @since 6.6.0
*/
@RequiredArgsConstructor
public class CompositeScopeResolver implements ScopeResolver {

private final List<ScopeResolver> resolvers;

@Override
public boolean supportsService(final AccessTokenRequestContext requestContext) {
return resolvers.stream().anyMatch(resolver -> resolver.supportsService(requestContext));
}

@Override
public Set<String> resolveRequestScopes(final AccessTokenRequestContext requestContext) {
return resolvers
.stream()
.filter(resolver -> resolver.supportsService(requestContext))
.map(resolver -> resolver.resolveRequestScopes(requestContext))
.findFirst()
.orElseGet(LinkedHashSet::new);
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
package org.apereo.cas.support.oauth.scopes;

import org.apereo.cas.support.oauth.web.response.accesstoken.ext.AccessTokenRequestContext;
import org.springframework.core.annotation.Order;

import java.util.LinkedHashSet;
import java.util.Set;

/**
* This is {@link DefaultOAuth20ScopeResolver}.
*
* Since OAuth services don't know about scopes all scopes are allowed.
*
* @author sbearcsiro
* @since 6.6.0
*/
@Order
public class DefaultOAuth20ScopeResolver implements ScopeResolver {

@Override
public boolean supportsService(final AccessTokenRequestContext requestContext) {
return true;
}

@Override
public Set<String> resolveRequestScopes(final AccessTokenRequestContext requestContext) {
return new LinkedHashSet<>(requestContext.getScopes());
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
package org.apereo.cas.support.oauth.scopes;

import org.apereo.cas.support.oauth.web.response.accesstoken.ext.AccessTokenRequestContext;

import java.util.Set;

/**
* Implementations of this interface resolve the allowed scopes for a given request context.
*
* @author sbearcsiro
* @since 6.6.0
*/
public interface ScopeResolver {

/**
* Whether this {@link ScopeResolver} supports the given request context.
* @param requestContext The request context
* @return true if the resolver can handle the given context
*/
boolean supportsService(AccessTokenRequestContext requestContext);

/**
* Resolves the scopes for the request context.
*
* @param requestContext The request context
* @return The set of allowed scopes for this request
*/
Set<String> resolveRequestScopes(AccessTokenRequestContext requestContext);

}
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import org.apereo.cas.support.oauth.authenticator.OAuth20CasAuthenticationBuilder;
import org.apereo.cas.support.oauth.profile.OAuth20ProfileScopeToAttributesFilter;
import org.apereo.cas.support.oauth.profile.OAuth20UserProfileDataCreator;
import org.apereo.cas.support.oauth.scopes.ScopeResolver;
import org.apereo.cas.support.oauth.validator.OAuth20ClientSecretValidator;
import org.apereo.cas.support.oauth.validator.authorization.OAuth20AuthorizationRequestValidator;
import org.apereo.cas.support.oauth.validator.token.OAuth20TokenRequestValidator;
Expand Down Expand Up @@ -132,6 +133,8 @@ public class OAuth20ConfigurationContext {

private final OAuth20ClientSecretValidator clientSecretValidator;

private final ScopeResolver scopeResolver;

/**
* Gets ticket granting ticket.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import org.apereo.cas.configuration.support.Beans;
import org.apereo.cas.support.oauth.OAuth20Constants;
import org.apereo.cas.support.oauth.OAuth20ResponseTypes;
import org.apereo.cas.support.oauth.scopes.ScopeResolver;
import org.apereo.cas.support.oauth.validator.token.device.InvalidOAuth20DeviceTokenException;
import org.apereo.cas.support.oauth.validator.token.device.ThrottledOAuth20DeviceUserCodeApprovalException;
import org.apereo.cas.support.oauth.validator.token.device.UnapprovedOAuth20DeviceUserCodeException;
Expand Down Expand Up @@ -72,6 +73,11 @@ public class OAuth20DefaultTokenGenerator implements OAuth20TokenGenerator {
*/
protected final CentralAuthenticationService centralAuthenticationService;

/**
* The scope resolver.
*/
protected final ScopeResolver scopeResolver;

/**
* CAS configuration settings.
*/
Expand Down Expand Up @@ -169,7 +175,7 @@ protected Pair<OAuth20AccessToken, OAuth20RefreshToken> generateAccessTokenOAuth
.newInstance(holder.getAuthentication())
.setAuthenticationDate(ZonedDateTime.now(ZoneOffset.UTC))
.addAttribute(OAuth20Constants.GRANT_TYPE, holder.getGrantType().toString())
.addAttribute(OAuth20Constants.SCOPE, holder.getScopes())
.addAttribute(OAuth20Constants.SCOPE, scopeResolver.resolveRequestScopes(holder))
.addAttribute(OAuth20Constants.CLIENT_ID, clientId);

val requestedClaims = holder.getClaims().getOrDefault(OAuth20Constants.CLAIMS_USERINFO, new HashMap<>());
Expand All @@ -179,7 +185,7 @@ protected Pair<OAuth20AccessToken, OAuth20RefreshToken> generateAccessTokenOAuth
LOGGER.debug("Creating access token for [{}]", holder);
val ticketGrantingTicket = holder.getTicketGrantingTicket();
val accessToken = this.accessTokenFactory.create(holder.getService(),
authentication, ticketGrantingTicket, holder.getScopes(),
authentication, ticketGrantingTicket, this.scopeResolver.resolveRequestScopes(holder),
Optional.ofNullable(holder.getToken()).map(Ticket::getId).orElse(null),
clientId, holder.getClaims(),
holder.getResponseType(), holder.getGrantType());
Expand Down Expand Up @@ -265,7 +271,7 @@ protected OAuth20RefreshToken generateRefreshToken(final AccessTokenRequestConte
val refreshToken = this.refreshTokenFactory.create(responseHolder.getService(),
responseHolder.getAuthentication(),
responseHolder.getTicketGrantingTicket(),
responseHolder.getScopes(),
scopeResolver.resolveRequestScopes(responseHolder),
responseHolder.getRegisteredService().getClientId(),
accessToken.getId(),
responseHolder.getClaims(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ protected Set<String> extractRequestedScopesByToken(final Set<String> requestedS
final OAuth20Token token,
final WebContext context) {
val scopes = new TreeSet<>(requestedScopes);
scopes.addAll(token.getScopes());
scopes.retainAll(token.getScopes());
return scopes;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ public ModelAndView build(final AccessTokenRequestContext holder) throws Excepti
val authentication = holder.getAuthentication();
val factory = (OAuth20CodeFactory) configurationContext.getTicketFactory().get(OAuth20Code.class);
val code = factory.create(holder.getService(), authentication,
holder.getTicketGrantingTicket(), holder.getScopes(),
holder.getTicketGrantingTicket(), configurationContext.getScopeResolver().resolveRequestScopes(holder),
holder.getCodeChallenge(), holder.getCodeChallengeMethod(),
holder.getClientId(), holder.getClaims(),
holder.getResponseType(), holder.getGrantType());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,9 @@
import org.apereo.cas.support.oauth.profile.DefaultOAuth20UserProfileDataCreator;
import org.apereo.cas.support.oauth.profile.OAuth20ProfileScopeToAttributesFilter;
import org.apereo.cas.support.oauth.profile.OAuth20UserProfileDataCreator;
import org.apereo.cas.support.oauth.scopes.CompositeScopeResolver;
import org.apereo.cas.support.oauth.scopes.DefaultOAuth20ScopeResolver;
import org.apereo.cas.support.oauth.scopes.ScopeResolver;
import org.apereo.cas.support.oauth.services.OAuth20RegisteredServiceCipherExecutor;
import org.apereo.cas.support.oauth.util.OAuth20Utils;
import org.apereo.cas.support.oauth.validator.DefaultOAuth20ClientSecretValidator;
Expand Down Expand Up @@ -225,6 +228,26 @@ public JwtBuilder accessTokenJwtBuilder(
}
}

@Configuration(value = "CasOAuth20ScopesConfiguration", proxyBeanMethods = false)
@EnableConfigurationProperties(CasConfigurationProperties.class)
public static class CasOAuth20ScopesConfiguration {

@Bean
@ConditionalOnMissingBean(name = "oauthScopeResolver")
@RefreshScope(proxyMode = ScopedProxyMode.DEFAULT)
public ScopeResolver oauthScopeResolver() {
return new DefaultOAuth20ScopeResolver();
}

@Bean
@ConditionalOnMissingBean(name = "defaultScopeResolver")
@RefreshScope(proxyMode = ScopedProxyMode.DEFAULT)
public CompositeScopeResolver defaultScopeResolver(final List<ScopeResolver> scopeResolvers) {
scopeResolvers.sort(AnnotationAwareOrderComparator.INSTANCE);
return new CompositeScopeResolver(scopeResolvers);
}
}

@Configuration(value = "CasOAuth20ContextConfiguration", proxyBeanMethods = false)
@EnableConfigurationProperties(CasConfigurationProperties.class)
public static class CasOAuth20ContextConfiguration {
Expand Down Expand Up @@ -285,7 +308,9 @@ public OAuth20ConfigurationContext oauth20ConfigurationContext(
final ObjectProvider<List<OAuth20AuthorizationResponseBuilder>> oauthAuthorizationResponseBuilders,
final ObjectProvider<List<OAuth20AuthorizationRequestValidator>> oauthAuthorizationRequestValidators,
@Qualifier("oauthTokenGenerator")
final OAuth20TokenGenerator oauthTokenGenerator) {
final OAuth20TokenGenerator oauthTokenGenerator,
@Qualifier("defaultScopeResolver")
final ScopeResolver defaultScopeResolver) {
return OAuth20ConfigurationContext.builder()
.requestParameterResolver(oauthRequestParameterResolver)
.applicationContext(applicationContext)
Expand Down Expand Up @@ -317,6 +342,7 @@ public OAuth20ConfigurationContext oauth20ConfigurationContext(
.oauthAuthorizationResponseBuilders(oauthAuthorizationResponseBuilders)
.oauthRequestValidators(oauthAuthorizationRequestValidators)
.clientSecretValidator(oauth20ClientSecretValidator)
.scopeResolver(defaultScopeResolver)
.build();
}
}
Expand Down Expand Up @@ -365,11 +391,13 @@ public OAuth20TokenGenerator oauthTokenGenerator(
final OAuth20AccessTokenFactory defaultAccessTokenFactory,
@Qualifier(CentralAuthenticationService.BEAN_NAME)
final CentralAuthenticationService centralAuthenticationService,
@Qualifier("defaultScopeResolver")
final ScopeResolver defaultScopeResolver,
final CasConfigurationProperties casProperties) {
return new OAuth20DefaultTokenGenerator(
defaultAccessTokenFactory, defaultDeviceTokenFactory,
defaultDeviceUserCodeFactory, defaultRefreshTokenFactory,
centralAuthenticationService, casProperties);
centralAuthenticationService, defaultScopeResolver, casProperties);
}


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@
import org.apereo.cas.support.oauth.OAuth20Constants;
import org.apereo.cas.support.oauth.OAuth20GrantTypes;
import org.apereo.cas.support.oauth.OAuth20ResponseTypes;
import org.apereo.cas.support.oauth.scopes.ScopeResolver;
import org.apereo.cas.support.oauth.services.OAuthRegisteredService;
import org.apereo.cas.support.oauth.validator.OAuth20ClientSecretValidator;
import org.apereo.cas.support.oauth.web.CasOAuth20TestAuthenticationEventExecutionPlanConfiguration;
Expand Down Expand Up @@ -134,6 +135,7 @@
import java.time.ZoneOffset;
import java.time.ZonedDateTime;
import java.util.ArrayList;
import java.util.Collection;
import java.util.EnumSet;
import java.util.HashMap;
import java.util.HashSet;
Expand Down Expand Up @@ -277,6 +279,10 @@ public abstract class AbstractOAuth20Tests {
@Qualifier(CentralAuthenticationService.BEAN_NAME)
protected CentralAuthenticationService centralAuthenticationService;

@Autowired
@Qualifier("defaultScopeResolver")
protected ScopeResolver defaultScopeResolver;

@Autowired
@Qualifier("requiresAuthenticationAccessTokenInterceptor")
protected HandlerInterceptor requiresAuthenticationInterceptor;
Expand Down Expand Up @@ -519,8 +525,18 @@ protected OAuth20Code addCode(final Principal principal, final OAuthRegisteredSe
return addCodeWithChallenge(principal, registeredService, null, null);
}

protected OAuth20Code addCodeWithScopes(final Principal principal, final OAuthRegisteredService registeredService,
final Collection<String> scopes) throws Exception {
return addCodeWithChallengeAndScopes(principal, registeredService, null, null, scopes);
}

protected OAuth20Code addCodeWithChallenge(final Principal principal, final OAuthRegisteredService registeredService,
final String codeChallenge, final String codeChallengeMethod) throws Exception {
return addCodeWithChallengeAndScopes(principal, registeredService, codeChallenge, codeChallengeMethod, new ArrayList<>());
}

protected OAuth20Code addCodeWithChallengeAndScopes(final Principal principal, final OAuthRegisteredService registeredService,
final String codeChallenge, final String codeChallengeMethod, final Collection<String> scopes) throws Exception {
val authentication = getAuthentication(principal);
val factory = new WebApplicationServiceFactory();
val service = factory.createService(registeredService.getClientId());
Expand All @@ -529,7 +545,7 @@ protected OAuth20Code addCodeWithChallenge(final Principal principal, final OAut
this.ticketRegistry.addTicket(tgt);

val code = oAuthCodeFactory.create(service, authentication,
tgt, new ArrayList<>(),
tgt, scopes,
codeChallenge, codeChallengeMethod, CLIENT_ID, new HashMap<>(),
OAuth20ResponseTypes.CODE, OAuth20GrantTypes.AUTHORIZATION_CODE);
this.ticketRegistry.addTicket(code);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import org.springframework.test.context.TestPropertySource;

import java.util.LinkedHashSet;
import java.util.Set;
import java.util.UUID;

import static org.apereo.cas.util.junit.Assertions.*;
Expand Down Expand Up @@ -86,7 +87,7 @@ public void verifyAccessTokenAsJwt() throws Exception {
@Test
public void verifySlowDown() throws Exception {
val generator = new OAuth20DefaultTokenGenerator(defaultAccessTokenFactory, defaultDeviceTokenFactory,
defaultDeviceUserCodeFactory, oAuthRefreshTokenFactory, centralAuthenticationService, casProperties);
defaultDeviceUserCodeFactory, oAuthRefreshTokenFactory, centralAuthenticationService, defaultScopeResolver, casProperties);
val token = defaultDeviceTokenFactory.createDeviceCode(
RegisteredServiceTestUtils.getService("https://device.oauth.org"));
ticketRegistry.addTicket(token);
Expand All @@ -104,7 +105,7 @@ public void verifySlowDown() throws Exception {
@Test
public void verifyUnapproved() throws Exception {
val generator = new OAuth20DefaultTokenGenerator(defaultAccessTokenFactory, defaultDeviceTokenFactory,
defaultDeviceUserCodeFactory, oAuthRefreshTokenFactory, centralAuthenticationService, casProperties);
defaultDeviceUserCodeFactory, oAuthRefreshTokenFactory, centralAuthenticationService, defaultScopeResolver, casProperties);
val token = defaultDeviceTokenFactory.createDeviceCode(
RegisteredServiceTestUtils.getService("https://device.oauth.org"));
ticketRegistry.addTicket(token);
Expand All @@ -121,10 +122,31 @@ public void verifyUnapproved() throws Exception {
assertThrows(UnapprovedOAuth20DeviceUserCodeException.class, () -> generator.generate(holder));
}

@Test
public void verifyScopes() throws Exception {
val registeredService = getRegisteredService(UUID.randomUUID().toString(), "secret", new LinkedHashSet<>());
servicesManager.save(registeredService);
val generator = new OAuth20DefaultTokenGenerator(defaultAccessTokenFactory, defaultDeviceTokenFactory,
defaultDeviceUserCodeFactory, oAuthRefreshTokenFactory, centralAuthenticationService, defaultScopeResolver, casProperties);
val service = RegisteredServiceTestUtils.getService(registeredService.getServiceId());

Thread.sleep(2000);
val holder = AccessTokenRequestContext.builder()
.service(service)
.responseType(OAuth20ResponseTypes.CODE)
.scopes(Set.of("test1", "test3"))
.authentication(RegisteredServiceTestUtils.getAuthentication())
.registeredService(registeredService)
.build();
val result = generator.generate(holder);
assertTrue(result.getAccessToken().isPresent());
assertEquals(Set.of("test1", "test3"), result.getAccessToken().get().getScopes());
}

@Test
public void verifyExpiredUserCode() throws Exception {
val generator = new OAuth20DefaultTokenGenerator(defaultAccessTokenFactory, defaultDeviceTokenFactory,
defaultDeviceUserCodeFactory, oAuthRefreshTokenFactory, centralAuthenticationService, casProperties);
defaultDeviceUserCodeFactory, oAuthRefreshTokenFactory, centralAuthenticationService, defaultScopeResolver, casProperties);
val token = defaultDeviceTokenFactory.createDeviceCode(
RegisteredServiceTestUtils.getService("https://device.oauth.org"));
ticketRegistry.addTicket(token);
Expand All @@ -146,7 +168,7 @@ public void verifyExpiredUserCode() throws Exception {
@Test
public void verifyDeviceCodeExpired() throws Exception {
val generator = new OAuth20DefaultTokenGenerator(defaultAccessTokenFactory, defaultDeviceTokenFactory,
defaultDeviceUserCodeFactory, oAuthRefreshTokenFactory, centralAuthenticationService, casProperties);
defaultDeviceUserCodeFactory, oAuthRefreshTokenFactory, centralAuthenticationService, defaultScopeResolver, casProperties);
val token = defaultDeviceTokenFactory.createDeviceCode(
RegisteredServiceTestUtils.getService("https://device.oauth.org"));
ticketRegistry.addTicket(token);
Expand Down

0 comments on commit a38001f

Please sign in to comment.