Skip to content

Commit

Permalink
fix authentication tests for rest
Browse files Browse the repository at this point in the history
  • Loading branch information
skublik committed Mar 7, 2023
1 parent eb9cf57 commit 209001f
Show file tree
Hide file tree
Showing 16 changed files with 129 additions and 42 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,10 @@
*/
package com.evolveum.midpoint.authentication.api;

import com.evolveum.midpoint.authentication.api.config.MidpointAuthentication;

import org.springframework.context.ApplicationEvent;

import java.util.List;

/**
* @author skublik
*/
Expand All @@ -19,5 +19,5 @@ protected RemoveUnusedSecurityFilterEvent(Object source) {
super(source);
}

public abstract MidpointAuthentication getMpAuthentication();
public abstract List<AuthModule> getAuthModules();
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
import com.evolveum.midpoint.util.logging.Trace;
import com.evolveum.midpoint.util.logging.TraceManager;

import java.util.List;

/**
* @author skublik
*/
Expand All @@ -29,9 +31,9 @@ public class RemoveUnusedSecurityFilterPublisher {

private static RemoveUnusedSecurityFilterPublisher instance;

public void publishCustomEvent(final MidpointAuthentication mpAuthentication) {
LOGGER.trace("Publishing RemoveUnusedSecurityFilterEvent event. With authentication: " + mpAuthentication);
RemoveUnusedSecurityFilterEventImpl customSpringEvent = new RemoveUnusedSecurityFilterEventImpl(this, mpAuthentication);
public void publishCustomEvent(final List<AuthModule> modules) {
LOGGER.trace("Publishing RemoveUnusedSecurityFilterEvent event. With authentication modules: " + modules);
RemoveUnusedSecurityFilterEventImpl customSpringEvent = new RemoveUnusedSecurityFilterEventImpl(this, modules);
applicationEventPublisher.publishEvent(customSpringEvent);
}

Expand All @@ -46,16 +48,16 @@ public static RemoveUnusedSecurityFilterPublisher get() {

private static class RemoveUnusedSecurityFilterEventImpl extends RemoveUnusedSecurityFilterEvent {

private final MidpointAuthentication mpAuthentication;
private final List<AuthModule> modules;

RemoveUnusedSecurityFilterEventImpl(Object source, MidpointAuthentication mpAuthentication) {
RemoveUnusedSecurityFilterEventImpl(Object source, List<AuthModule> modules) {
super(source);
this.mpAuthentication = mpAuthentication;
this.modules = modules;
}

@Override
public MidpointAuthentication getMpAuthentication() {
return mpAuthentication;
public List<AuthModule> getAuthModules() {
return modules;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import java.util.stream.Collectors;
import javax.servlet.http.HttpServletRequest;

import com.evolveum.midpoint.authentication.api.AutheticationFailedData;
import com.evolveum.midpoint.authentication.api.*;

import org.apache.commons.lang3.StringUtils;
import org.apache.commons.lang3.Validate;
Expand All @@ -21,9 +21,6 @@
import org.springframework.security.core.GrantedAuthority;
import org.springframework.security.core.authority.AuthorityUtils;

import com.evolveum.midpoint.authentication.api.AuthModule;
import com.evolveum.midpoint.authentication.api.AuthenticationChannel;
import com.evolveum.midpoint.authentication.api.AuthenticationModuleState;
import com.evolveum.midpoint.authentication.api.util.AuthUtil;
import com.evolveum.midpoint.model.api.authentication.GuiProfiledPrincipal;
import com.evolveum.midpoint.security.api.AuthenticationAnonymousChecker;
Expand Down Expand Up @@ -94,6 +91,11 @@ public List<AuthModule> getAuthModules() {
}

public void setAuthModules(List<AuthModule> authModules) {
if (!this.authModules.isEmpty()) {
List<AuthModule> modules = new ArrayList<>();
modules.addAll(this.authModules);
RemoveUnusedSecurityFilterPublisher.get().publishCustomEvent(modules);
}
this.authModules = authModules;
}

Expand Down Expand Up @@ -451,7 +453,9 @@ public String getUsername() {
ModuleAuthentication moduleAuthentication = getFirstFailedAuthenticationModule();
if (moduleAuthentication != null) {
AutheticationFailedData failureData = moduleAuthentication.getFailureData();
return failureData.getUsername();
if (failureData != null) {
return failureData.getUsername();
}
}
return "";
}
Expand All @@ -460,19 +464,31 @@ public String getFailedReason() {
ModuleAuthentication moduleAuthentication = getFirstFailedAuthenticationModule();
if (moduleAuthentication != null) {
AutheticationFailedData failureData = moduleAuthentication.getFailureData();
return failureData.getFailureMessage();
if (failureData != null) {
return failureData.getFailureMessage();
}
}
return "";
}

public ModuleAuthentication getFirstFailedAuthenticationModule() {
List<ModuleAuthentication> moduleAuthentications = getAuthentications();
ModuleAuthentication found = null;
for (ModuleAuthentication moduleAuthentication : moduleAuthentications) {
if (AuthenticationModuleState.FAILURE == moduleAuthentication.getState()) {
if (AuthenticationModuleState.FAILURE == moduleAuthentication.getState() && found == null) {
found = moduleAuthentication;
if (found.getFailureData() != null) {
return found;
}
continue;
}
if (AuthenticationModuleState.FAILURE == moduleAuthentication.getState()
&& found.getOrder() == moduleAuthentication.getOrder()
&& moduleAuthentication.getFailureData() != null) {
return moduleAuthentication;
}
}
return null;
return found;
}

public AuthenticationException getAuthenticationExceptionIfExsits() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,8 @@ public static void clearMidpointAuthentication() {
if (oldAuthentication instanceof MidpointAuthentication
&& ((MidpointAuthentication) oldAuthentication).getAuthenticationChannel() != null
&& SecurityPolicyUtil.DEFAULT_CHANNEL.equals(((MidpointAuthentication) oldAuthentication).getAuthenticationChannel().getChannelId())) {
RemoveUnusedSecurityFilterPublisher.get().publishCustomEvent((MidpointAuthentication) oldAuthentication);
RemoveUnusedSecurityFilterPublisher.get().publishCustomEvent(
((MidpointAuthentication) oldAuthentication).getAuthModules());
}
SecurityContextHolder.getContext().setAuthentication(null);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -148,8 +148,10 @@ public void recordSequenceAuthenticationSuccess(MidPointPrincipal principal, Con

Integer failedLogins = behavior.getFailedLogins();

boolean successLoginAfterFail = false;
if (failedLogins != null && failedLogins > 0) {
behavior.setFailedLogins(0);
successLoginAfterFail = true;
}
LoginEventType event = new LoginEventType();
event.setTimestamp(clock.currentTimeXMLGregorianCalendar());
Expand All @@ -158,7 +160,9 @@ public void recordSequenceAuthenticationSuccess(MidPointPrincipal principal, Con
behavior.setPreviousSuccessfulLogin(behavior.getLastSuccessfulLogin());
behavior.setLastSuccessfulLogin(event);

focusProfileService.updateFocus(principal, computeModifications(focusBefore, principal.getFocus()));
if (AuthSequenceUtil.isAllowUpdatingAuthBehavior(successLoginAfterFail)) {
focusProfileService.updateFocus(principal, computeModifications(focusBefore, principal.getFocus()));
}
securityHelper.auditLoginSuccess(principal.getFocus(), connEnv);
}

Expand All @@ -173,7 +177,9 @@ public void recordSequenceAuthenticationFailure(String username, MidPointPrincip
}
if (principal != null) {
focusType = principal.getFocus();
processFocusChange(principal, credentialsPolicy, connEnv);
if (AuthSequenceUtil.isAllowUpdatingAuthBehavior(true)) {
processFocusChange(principal, credentialsPolicy, connEnv);
}
}
securityHelper.auditLoginFailure(username, focusType, connEnv, reason);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,8 @@ public Authentication getAuthentication() {
public void setAuthentication(Authentication authentication) {
if (getAuthentication() instanceof MidpointAuthentication
&& !getAuthentication().equals(authentication)) {
RemoveUnusedSecurityFilterPublisher.get().publishCustomEvent((MidpointAuthentication) getAuthentication());
RemoveUnusedSecurityFilterPublisher.get().publishCustomEvent(
((MidpointAuthentication) getAuthentication()).getAuthModules());
}
securityContext.setAuthentication(authentication);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,10 @@
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;

import com.evolveum.midpoint.authentication.api.AuthenticationModuleState;
import com.evolveum.midpoint.authentication.api.config.MidpointAuthentication;

import com.evolveum.midpoint.authentication.api.config.ModuleAuthentication;
import com.evolveum.midpoint.authentication.impl.handler.BasicMidPointAuthenticationSuccessHandler;
import com.evolveum.midpoint.authentication.impl.session.MidpointHttpServletRequest;
import com.evolveum.midpoint.authentication.impl.util.AuthSequenceUtil;
Expand Down Expand Up @@ -195,5 +197,11 @@ protected void onUnsuccessfulAuthentication(HttpServletRequest request, HttpServ
} catch (ServletException e) {
LOGGER.error("Couldn't execute post unsuccessful authentication method", e);
}
Authentication authentication = SecurityContextHolder.getContext().getAuthentication();
if (authentication instanceof MidpointAuthentication) {
MidpointAuthentication mpAuthentication = (MidpointAuthentication) authentication;
ModuleAuthentication moduleAuthentication = mpAuthentication.getProcessingModuleAuthentication();
moduleAuthentication.recordFailure(failed);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -192,14 +192,15 @@ private void validateAuthenticationCanContinue(MidpointAuthentication mpAuthenti

private void removingFiltersAfterProcessing(MidpointAuthentication mpAuthentication, HttpServletRequest httpRequest) {
if (!AuthSequenceUtil.isClusterSequence(httpRequest) && httpRequest.getSession(false) == null && mpAuthentication != null) {
removeUnusedSecurityFilterPublisher.publishCustomEvent(mpAuthentication);
removeUnusedSecurityFilterPublisher.publishCustomEvent(mpAuthentication.getAuthModules());
}
}

private void clearAuthentication(HttpServletRequest httpRequest) {
Authentication oldAuthentication = SecurityContextHolder.getContext().getAuthentication();
if (!AuthSequenceUtil.isClusterSequence(httpRequest) && oldAuthentication instanceof MidpointAuthentication) {
removeUnusedSecurityFilterPublisher.publishCustomEvent((MidpointAuthentication) oldAuthentication);
removeUnusedSecurityFilterPublisher.publishCustomEvent(
((MidpointAuthentication) oldAuthentication).getAuthModules());
}
SecurityContextHolder.getContext().setAuthentication(null);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@

package com.evolveum.midpoint.authentication.impl.filter;

import org.apache.commons.lang3.StringUtils;

import java.io.IOException;
import javax.servlet.FilterChain;
import javax.servlet.ServletException;
Expand Down Expand Up @@ -50,6 +52,8 @@ public class SequenceAuditFilter extends OncePerRequestFilter {

@Autowired private FocusAuthenticationResultRecorder authenticationRecorder;

private boolean recordOnEndOfChain = true;

public SequenceAuditFilter() {
}

Expand All @@ -58,49 +62,70 @@ public SequenceAuditFilter(FocusAuthenticationResultRecorder authenticationRecor
this.authenticationRecorder = authenticationRecorder;
}

public void setRecordOnEndOfChain(boolean recordOnEndOfChain) {
this.recordOnEndOfChain = recordOnEndOfChain;
}

@Override
protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response, FilterChain filterChain) throws ServletException, IOException {
LOGGER.info("Running SequenceAuditFilter");

if (recordOnEndOfChain) {
filterChain.doFilter(request, response);
}

Authentication authentication = SecurityUtil.getAuthentication();
if (!(authentication instanceof MidpointAuthentication)) {
LOGGER.trace("No MidpointAuthentication present, continue with filter chain");
filterChain.doFilter(request, response);
if (!recordOnEndOfChain) {
filterChain.doFilter(request, response);
}
return;
}

MidpointAuthentication mpAuthentication = (MidpointAuthentication) authentication;
if (mpAuthentication.isAlreadyAudited()) {
LOGGER.trace("Skipping auditing of authentication record, already audited.");
filterChain.doFilter(request, response);
if (!recordOnEndOfChain) {
filterChain.doFilter(request, response);
}
return;
}

writeRecord(mpAuthentication);
writeRecord(request, mpAuthentication);

if (!recordOnEndOfChain) {
filterChain.doFilter(request, response);
}

filterChain.doFilter(request, response);
}

@VisibleForTesting
public void writeRecord(MidpointAuthentication mpAuthentication) {
public void writeRecord(HttpServletRequest request, MidpointAuthentication mpAuthentication) {
MidPointPrincipal mpPrincipal = mpAuthentication.getPrincipal() instanceof MidPointPrincipal ? (MidPointPrincipal) mpAuthentication.getPrincipal() : null;
boolean isAuthenticated = mpAuthentication.isAuthenticated();
if (isAuthenticated) {
authenticationRecorder.recordSequenceAuthenticationSuccess(mpPrincipal, createConnectionEnvironment(mpAuthentication));
authenticationRecorder.recordSequenceAuthenticationSuccess(mpPrincipal, createConnectionEnvironment(request, mpAuthentication));
mpAuthentication.setAlreadyAudited(true);
LOGGER.trace("Authentication sequence {} evaluated as successful.", mpAuthentication.getSequenceIdentifier());
} else if (mpAuthentication.isFinished()) {
} else if (mpAuthentication.isFinished() && StringUtils.isNotEmpty(mpAuthentication.getUsername())) {
authenticationRecorder.recordSequenceAuthenticationFailure(mpAuthentication.getUsername(), mpPrincipal, null,
mpAuthentication.getFailedReason(), createConnectionEnvironment(mpAuthentication));
mpAuthentication.getFailedReason(), createConnectionEnvironment(request, mpAuthentication));
mpAuthentication.setAlreadyAudited(true);
LOGGER.trace("Authentication sequence {} evaluated as failed.", mpAuthentication.getSequenceIdentifier());
}
}

private ConnectionEnvironment createConnectionEnvironment(MidpointAuthentication mpAuthentication) {
private ConnectionEnvironment createConnectionEnvironment(HttpServletRequest request, MidpointAuthentication mpAuthentication) {
String sessionId = request != null ? request.getRequestedSessionId() : null;
if (mpAuthentication.getSessionId() != null) {
sessionId = mpAuthentication.getSessionId();
}

ConnectionEnvironment connectionEnvironment = ConnectionEnvironment.create(mpAuthentication.getAuthenticationChannel().getChannelId());
connectionEnvironment.setSequenceIdentifier(mpAuthentication.getSequenceIdentifier());
connectionEnvironment.setSessionIdOverride(sessionId);

return connectionEnvironment;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import com.evolveum.midpoint.authentication.impl.entry.point.HttpAuthenticationEntryPoint;
import com.evolveum.midpoint.authentication.impl.MidpointAuthenticationTrustResolverImpl;
import com.evolveum.midpoint.authentication.impl.filter.HttpBasicAuthenticationFilter;
import com.evolveum.midpoint.authentication.impl.filter.SequenceAuditFilter;
import com.evolveum.midpoint.authentication.impl.filter.configurers.MidpointExceptionHandlingConfigurer;
import com.evolveum.midpoint.authentication.api.util.AuthUtil;
import com.evolveum.midpoint.authentication.api.ModuleWebSecurityConfiguration;
Expand Down Expand Up @@ -60,6 +61,11 @@ protected void configure(HttpSecurity http) throws Exception {
}
http.authorizeRequests().accessDecisionManager(new MidpointHttpAuthorizationEvaluator(securityEnforcer, securityContextManager, taskManager, model));
http.addFilterAt(filter, BasicAuthenticationFilter.class);

SequenceAuditFilter sequenceAuditFilter = getObjectPostProcessor().postProcess(new SequenceAuditFilter());
sequenceAuditFilter.setRecordOnEndOfChain(false);
http.addFilterAfter(sequenceAuditFilter, BasicAuthenticationFilter.class);

http.formLogin().disable()
.csrf().disable();
getOrApply(http, new MidpointExceptionHandlingConfigurer<>())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import com.evolveum.midpoint.authentication.impl.entry.point.HttpSecurityQuestionsAuthenticationEntryPoint;
import com.evolveum.midpoint.authentication.impl.MidpointAuthenticationTrustResolverImpl;
import com.evolveum.midpoint.authentication.impl.filter.HttpSecurityQuestionsAuthenticationFilter;
import com.evolveum.midpoint.authentication.impl.filter.SequenceAuditFilter;
import com.evolveum.midpoint.authentication.impl.filter.configurers.MidpointExceptionHandlingConfigurer;
import com.evolveum.midpoint.authentication.api.util.AuthUtil;
import com.evolveum.midpoint.authentication.api.ModuleWebSecurityConfiguration;
Expand Down Expand Up @@ -62,6 +63,11 @@ protected void configure(HttpSecurity http) throws Exception {
filter.setRememberMeServices(rememberMeServices);
}
http.addFilterAt(filter, BasicAuthenticationFilter.class);

SequenceAuditFilter sequenceAuditFilter = getObjectPostProcessor().postProcess(new SequenceAuditFilter());
sequenceAuditFilter.setRecordOnEndOfChain(false);
http.addFilterAfter(sequenceAuditFilter, BasicAuthenticationFilter.class);

http.formLogin().disable()
.csrf().disable();
getOrApply(http, new MidpointExceptionHandlingConfigurer<>())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import com.evolveum.midpoint.authentication.impl.MidpointAuthenticationTrustResolverImpl;
import com.evolveum.midpoint.authentication.impl.authorization.evaluator.MidpointHttpAuthorizationEvaluator;
import com.evolveum.midpoint.authentication.impl.entry.point.HttpAuthenticationEntryPoint;
import com.evolveum.midpoint.authentication.impl.filter.SequenceAuditFilter;
import com.evolveum.midpoint.authentication.impl.filter.configurers.MidpointExceptionHandlingConfigurer;
import com.evolveum.midpoint.authentication.impl.module.configuration.OidcResourceServerModuleWebSecurityConfiguration;
import com.evolveum.midpoint.authentication.impl.oidc.OidcBearerTokenAuthenticationFilter;
Expand Down Expand Up @@ -61,6 +62,12 @@ protected void configure(HttpSecurity http) throws Exception {
}
http.authorizeRequests().accessDecisionManager(new MidpointHttpAuthorizationEvaluator(securityEnforcer, securityContextManager, taskManager, model));
http.addFilterAt(filter, BasicAuthenticationFilter.class);

SequenceAuditFilter sequenceAuditFilter = getObjectPostProcessor().postProcess(new SequenceAuditFilter());
sequenceAuditFilter.setRecordOnEndOfChain(false);
http.addFilterAfter(sequenceAuditFilter, BasicAuthenticationFilter.class);


http.formLogin().disable()
.csrf().disable();
getOrApply(http, new MidpointExceptionHandlingConfigurer<>())
Expand Down

0 comments on commit 209001f

Please sign in to comment.