Skip to content

Commit

Permalink
re-org access strategy rules for ticket validation
Browse files Browse the repository at this point in the history
# Conflicts:
#	core/cas-server-core-services-api/src/main/java/org/apereo/cas/services/RegisteredServiceAccessStrategyAuditableEnforcer.java
#	core/cas-server-core-services-api/src/main/java/org/apereo/cas/services/RegisteredServiceAccessStrategyUtils.java
#	core/cas-server-core-services/src/test/java/org/apereo/cas/services/RegisteredServiceTestUtils.java
#	core/cas-server-core/src/main/java/org/apereo/cas/DefaultCentralAuthenticationService.java
  • Loading branch information
mmoayyed committed Oct 31, 2021
1 parent 10886bc commit ba0929c
Show file tree
Hide file tree
Showing 9 changed files with 129 additions and 155 deletions.
Expand Up @@ -2,6 +2,7 @@

import org.apereo.cas.authentication.Authentication;
import org.apereo.cas.authentication.AuthenticationResult;
import org.apereo.cas.authentication.principal.Principal;
import org.apereo.cas.authentication.principal.Service;
import org.apereo.cas.services.RegisteredService;
import org.apereo.cas.ticket.ServiceTicket;
Expand All @@ -27,6 +28,8 @@ public class AuditableContext {

private final RegisteredService registeredService;

private final Principal principal;

private final Authentication authentication;

private final ServiceTicket serviceTicket;
Expand All @@ -39,44 +42,25 @@ public class AuditableContext {

private final Object httpResponse;

/**
* Properties.
*/
@Builder.Default
private Map<String, Object> properties = new LinkedHashMap<>(0);

/**
* Get service.
*
* @return optional service
*/
public Optional<Service> getService() {
return Optional.ofNullable(service);
}

/**
* Get registered service.
*
* @return optional registered service
*/
public Optional<RegisteredService> getRegisteredService() {
return Optional.ofNullable(registeredService);
}

/**
* Get.
*
* @return optional authentication
*/
public Optional<Authentication> getAuthentication() {
return Optional.ofNullable(authentication);
}

/**
* Get.
*
* @return optional service ticket
*/
public Optional<Principal> getPrincipal() {
return Optional.ofNullable(this.principal);
}

public Optional<ServiceTicket> getServiceTicket() {
return Optional.ofNullable(serviceTicket);
}
Expand All @@ -89,29 +73,14 @@ public Optional<Object> getResponse() {
return Optional.ofNullable(httpResponse);
}

/**
* Get.
*
* @return optional authentication result
*/
public Optional<AuthenticationResult> getAuthenticationResult() {
return Optional.ofNullable(authenticationResult);
}

/**
* Get.
*
* @return optional tgt
*/
public Optional<TicketGrantingTicket> getTicketGrantingTicket() {
return Optional.ofNullable(ticketGrantingTicket);
}

/**
* Get.
*
* @return optional properties
*/

public Map<String, Object> getProperties() {
return properties;
}
Expand Down
Expand Up @@ -43,6 +43,7 @@ public List<IPersonAttributeDao> attributeRepositories() {
public IPersonAttributeDao attributeRepository() {
val attrs = CollectionUtils.wrap(
"uid", CollectionUtils.wrap("uid"),
"mail", CollectionUtils.wrap("cas@apereo.org"),
"eduPersonAffiliation", CollectionUtils.wrap("developer"),
"groupMembership", CollectionUtils.wrap("adopters"));
return new StubPersonAttributeDao((Map) attrs);
Expand Down
Expand Up @@ -8,12 +8,14 @@
import org.apereo.cas.audit.BaseAuditableExecution;
import org.apereo.cas.authentication.PrincipalException;
import org.apereo.cas.configuration.CasConfigurationProperties;
import org.apereo.cas.util.CollectionUtils;
import org.apereo.cas.util.scripting.WatchableGroovyScriptResource;

import lombok.extern.slf4j.Slf4j;
import lombok.val;
import org.apereo.inspektr.audit.annotation.Audit;

import java.util.Map;
import java.util.Optional;

/**
Expand Down Expand Up @@ -41,9 +43,10 @@ private static Optional<AuditableExecutionResult> byServiceTicketAndAuthnResultA
val result = AuditableExecutionResult.of(context);
try {
val serviceTicket = context.getServiceTicket().orElseThrow();
val authResult = context.getAuthenticationResult().orElseThrow();
RegisteredServiceAccessStrategyUtils.ensurePrincipalAccessIsAllowedForService(serviceTicket,
authResult, providedRegisteredService.get());
val authResult = context.getAuthenticationResult().orElseThrow().getAuthentication();
RegisteredServiceAccessStrategyUtils.ensurePrincipalAccessIsAllowedForService(serviceTicket.getService(),
providedRegisteredService.get(), authResult.getPrincipal().getId(),
(Map) CollectionUtils.merge(authResult.getAttributes(), authResult.getPrincipal().getAttributes()));
} catch (final PrincipalException | UnauthorizedServiceException e) {
result.setException(e);
}
Expand All @@ -66,8 +69,10 @@ private static Optional<AuditableExecutionResult> byServiceAndRegisteredServiceA
.ticketGrantingTicket(ticketGrantingTicket.get())
.build();
try {
val authResult = ticketGrantingTicket.get().getRoot().getAuthentication();
RegisteredServiceAccessStrategyUtils.ensurePrincipalAccessIsAllowedForService(service,
registeredService, ticketGrantingTicket.get());
registeredService, authResult.getPrincipal().getId(),
(Map) CollectionUtils.merge(authResult.getAttributes(), authResult.getPrincipal().getAttributes()));
} catch (final PrincipalException | UnauthorizedServiceException e) {
result.setException(e);
}
Expand Down Expand Up @@ -114,6 +119,37 @@ private static Optional<AuditableExecutionResult> byServiceAndRegisteredService(
return Optional.empty();
}

/**
* By service and registered service and principal optional.
*
* @param context the context
* @return the optional
*/
public static Optional<AuditableExecutionResult> byServiceAndRegisteredServiceAndPrincipal(final AuditableContext context) {
val providedService = context.getService();
val providedRegisteredService = context.getRegisteredService();
val providedPrincipal = context.getPrincipal();
if (providedService.isPresent() && providedRegisteredService.isPresent() && providedPrincipal.isPresent()) {
val registeredService = providedRegisteredService.get();
val service = providedService.get();
val principal = providedPrincipal.get();

val result = AuditableExecutionResult.builder()
.registeredService(registeredService)
.service(service)
.build();

try {
RegisteredServiceAccessStrategyUtils.ensurePrincipalAccessIsAllowedForService(service,
registeredService, principal.getId(), principal.getAttributes());
} catch (final PrincipalException | UnauthorizedServiceException e) {
result.setException(e);
}
return Optional.of(result);
}
return Optional.empty();
}

private static Optional<AuditableExecutionResult> byServiceAndRegisteredServiceAndAuthentication(final AuditableContext context) {
val providedService = context.getService();
val providedRegisteredService = context.getRegisteredService();
Expand All @@ -128,10 +164,11 @@ private static Optional<AuditableExecutionResult> byServiceAndRegisteredServiceA
.service(service)
.authentication(authentication)
.build();

try {
RegisteredServiceAccessStrategyUtils.ensurePrincipalAccessIsAllowedForService(service,
registeredService, authentication);
registeredService, authentication.getPrincipal().getId(),
(Map) CollectionUtils.merge(authentication.getAttributes(),
authentication.getPrincipal().getAttributes()));
} catch (final PrincipalException | UnauthorizedServiceException e) {
result.setException(e);
}
Expand All @@ -148,6 +185,7 @@ public AuditableExecutionResult execute(final AuditableContext context) {
return byExternalGroovyScript(context)
.or(() -> byServiceTicketAndAuthnResultAndRegisteredService(context))
.or(() -> byServiceAndRegisteredServiceAndTicketGrantingTicket(context))
.or(() -> byServiceAndRegisteredServiceAndPrincipal(context))
.or(() -> byServiceAndRegisteredServiceAndAuthentication(context))
.or(() -> byServiceAndRegisteredService(context))
.or(() -> byRegisteredService(context))
Expand Down
@@ -1,12 +1,7 @@
package org.apereo.cas.services;

import org.apereo.cas.authentication.Authentication;
import org.apereo.cas.authentication.AuthenticationResult;
import org.apereo.cas.authentication.CoreAuthenticationUtils;
import org.apereo.cas.authentication.PrincipalException;
import org.apereo.cas.authentication.principal.Service;
import org.apereo.cas.configuration.model.core.authentication.PrincipalAttributesCoreProperties;
import org.apereo.cas.ticket.ServiceTicket;
import org.apereo.cas.ticket.TicketGrantingTicket;

import lombok.experimental.UtilityClass;
Expand Down Expand Up @@ -105,7 +100,8 @@ public static void ensureServiceSsoAccessIsAllowed(final RegisteredService regis
* @param ticketGrantingTicket the ticket granting ticket
* @param credentialsProvided the credentials provided
*/
public static void ensureServiceSsoAccessIsAllowed(final RegisteredService registeredService, final Service service,
public static void ensureServiceSsoAccessIsAllowed(final RegisteredService registeredService,
final Service service,
final TicketGrantingTicket ticketGrantingTicket,
final boolean credentialsProvided) {

Expand Down Expand Up @@ -136,112 +132,32 @@ public static void ensureServiceSsoAccessIsAllowed(final RegisteredService regis
* @param attributes the attributes
* @return the boolean
*/
static boolean ensurePrincipalAccessIsAllowedForService(final Service service,
final RegisteredService registeredService,
final String principalId,
final Map<String, List<Object>> attributes) {
public static boolean ensurePrincipalAccessIsAllowedForService(final Service service,
final RegisteredService registeredService,
final String principalId,
final Map<String, List<Object>> attributes) {
ensureServiceAccessIsAllowed(service, registeredService);
LOGGER.trace("Checking access strategy for service [{}], requested by [{}] with attributes [{}].",
service != null ? service.getId() : "unknown", principalId, attributes);

if (!registeredService.getAccessStrategy().doPrincipalAttributesAllowServiceAccess(principalId, (Map) attributes)) {
LOGGER.warn("Cannot grant access to service [{}]; it is not authorized for use by [{}].",
service != null ? service.getId() : "unknown", principalId);
val handlerErrors = new HashMap<String, Throwable>();
val message = String.format("Cannot grant service access to %s", principalId);
val exception = new UnauthorizedServiceForPrincipalException(message, registeredService, principalId, attributes);
handlerErrors.put(UnauthorizedServiceForPrincipalException.class.getSimpleName(), exception);
throw new PrincipalException(UnauthorizedServiceForPrincipalException.CODE_UNAUTHZ_SERVICE, handlerErrors, new HashMap<>(0));
throw new PrincipalException(UnauthorizedServiceForPrincipalException.CODE_UNAUTHZ_SERVICE,
handlerErrors, new HashMap<>(0));
}
return true;
}

/**
* Ensure service access is allowed.
*
* @param service the service
* @param registeredService the registered service
* @param authentication the authentication
* @return the true if access is granted. false otherwise
* @throws UnauthorizedServiceException the unauthorized service exception
* @throws PrincipalException the principal exception
*/
public static boolean ensurePrincipalAccessIsAllowedForService(final Service service,
final RegisteredService registeredService,
final Authentication authentication)
throws UnauthorizedServiceException, PrincipalException {

ensureServiceAccessIsAllowed(service, registeredService);

val principal = authentication.getPrincipal();
val principalAttributes = new HashMap<>(principal.getAttributes());
val merger = CoreAuthenticationUtils.getAttributeMerger(PrincipalAttributesCoreProperties.MergingStrategyTypes.MULTIVALUED);
val context = RegisteredServiceAttributeReleasePolicyContext.builder()
.registeredService(registeredService)
.service(service)
.principal(principal)
.build();
val policyAttributes = registeredService.getAttributeReleasePolicy().getAttributes(context);
val result = CoreAuthenticationUtils.mergeAttributes(principalAttributes, policyAttributes, merger);
LOGGER.trace("Merged principal attributes [{}] with attributes from release policy [{}]. Result: [{}]",
principalAttributes, policyAttributes, result);
result.putAll(authentication.getAttributes());
return ensurePrincipalAccessIsAllowedForService(service, registeredService, principal.getId(), result);
}

/**
* Ensure service access is allowed.
*
* @param serviceTicket the service ticket
* @param registeredService the registered service
* @param ticketGrantingTicket the ticket granting ticket
* @throws UnauthorizedServiceException the unauthorized service exception
* @throws PrincipalException the principal exception
*/
static void ensurePrincipalAccessIsAllowedForService(final ServiceTicket serviceTicket,
final RegisteredService registeredService,
final TicketGrantingTicket ticketGrantingTicket)
throws UnauthorizedServiceException, PrincipalException {
ensurePrincipalAccessIsAllowedForService(serviceTicket.getService(),
registeredService, ticketGrantingTicket.getAuthentication());
}

/**
* Ensure service access is allowed. Determines the final authentication object
* by looking into the chained authentications of the ticket granting ticket.
*
* @param service the service
* @param registeredService the registered service
* @param ticketGrantingTicket the ticket granting ticket
* @throws UnauthorizedServiceException the unauthorized service exception
* @throws PrincipalException the principal exception
*/
static void ensurePrincipalAccessIsAllowedForService(final Service service,
final RegisteredService registeredService,
final TicketGrantingTicket ticketGrantingTicket)
throws UnauthorizedServiceException, PrincipalException {
ensurePrincipalAccessIsAllowedForService(service, registeredService,
ticketGrantingTicket.getRoot().getAuthentication());

}

/**
* Ensure service access is allowed.
*
* @param serviceTicket the service ticket
* @param context the context
* @param registeredService the registered service
* @throws UnauthorizedServiceException the unauthorized service exception
* @throws PrincipalException the principal exception
*/
static void ensurePrincipalAccessIsAllowedForService(final ServiceTicket serviceTicket,
final AuthenticationResult context,
final RegisteredService registeredService)
throws UnauthorizedServiceException, PrincipalException {
ensurePrincipalAccessIsAllowedForService(serviceTicket.getService(), registeredService, context.getAuthentication());
}

/**
* Returns a predicate that determined whether a service has expired.
* Gets registered service expiration policy predicate.
*
* @return true if the service is still valid. false if service has expired.
* @return the registered service expiration policy predicate
*/
public static Predicate<RegisteredService> getRegisteredServiceExpirationPolicyPredicate() {
return service -> {
Expand Down
@@ -1,15 +1,16 @@
package org.apereo.cas.services;

import org.apereo.cas.authentication.PrincipalException;
import org.apereo.cas.ticket.ServiceTicket;
import org.apereo.cas.ticket.TicketGrantingTicket;
import org.apereo.cas.util.CollectionUtils;

import lombok.val;
import org.junit.jupiter.api.Tag;
import org.junit.jupiter.api.Test;

import java.time.LocalDate;
import java.time.ZoneOffset;
import java.util.Map;

import static org.junit.jupiter.api.Assertions.*;
import static org.mockito.Mockito.*;
Expand Down Expand Up @@ -46,12 +47,11 @@ public void verifySsoAccess() {
@Test
public void verifyPrincipalAccess() {
val service = RegisteredServiceTestUtils.getRegisteredService();
val st = mock(ServiceTicket.class);
when(st.getService()).thenReturn(RegisteredServiceTestUtils.getService());
val tgt = mock(TicketGrantingTicket.class);
when(tgt.getAuthentication()).thenReturn(RegisteredServiceTestUtils.getAuthentication());
val authentication = RegisteredServiceTestUtils.getAuthentication();
assertThrows(PrincipalException.class, () ->
RegisteredServiceAccessStrategyUtils.ensurePrincipalAccessIsAllowedForService(st, service, tgt));
RegisteredServiceAccessStrategyUtils.ensurePrincipalAccessIsAllowedForService(
RegisteredServiceTestUtils.getService(), service, authentication.getPrincipal().getId(),
(Map) CollectionUtils.merge(authentication.getAttributes(), authentication.getPrincipal().getAttributes())));
}

}

0 comments on commit ba0929c

Please sign in to comment.