From 82e1116a4bcc82e569cfecb931f3088c5a33a245 Mon Sep 17 00:00:00 2001 From: lskublik Date: Fri, 12 Jun 2020 12:53:45 +0200 Subject: [PATCH] clenup code of MidpointAuthFilter.doFilterInternal --- .../security/filter/MidpointAuthFilter.java | 182 ++++++++++-------- 1 file changed, 105 insertions(+), 77 deletions(-) diff --git a/gui/admin-gui/src/main/java/com/evolveum/midpoint/web/security/filter/MidpointAuthFilter.java b/gui/admin-gui/src/main/java/com/evolveum/midpoint/web/security/filter/MidpointAuthFilter.java index a713f7179c8..1a5dbcc5876 100644 --- a/gui/admin-gui/src/main/java/com/evolveum/midpoint/web/security/filter/MidpointAuthFilter.java +++ b/gui/admin-gui/src/main/java/com/evolveum/midpoint/web/security/filter/MidpointAuthFilter.java @@ -10,14 +10,12 @@ import com.evolveum.midpoint.model.common.SystemObjectCache; import com.evolveum.midpoint.prism.PrismContext; import com.evolveum.midpoint.prism.PrismObject; -import com.evolveum.midpoint.prism.schema.SchemaRegistry; import com.evolveum.midpoint.schema.result.OperationResult; import com.evolveum.midpoint.schema.util.SecurityPolicyUtil; import com.evolveum.midpoint.util.exception.SchemaException; import com.evolveum.midpoint.util.logging.Trace; import com.evolveum.midpoint.util.logging.TraceManager; import com.evolveum.midpoint.web.security.MidpointAuthenticationManager; -import com.evolveum.midpoint.web.security.MidpointProviderManager; import com.evolveum.midpoint.web.security.factory.channel.AuthChannelRegistryImpl; import com.evolveum.midpoint.web.security.module.ModuleWebSecurityConfig; import com.evolveum.midpoint.web.security.factory.module.AuthModuleRegistryImpl; @@ -25,19 +23,16 @@ import com.evolveum.midpoint.xml.ns._public.common.common_3.*; import org.apache.commons.lang3.StringUtils; import org.springframework.beans.factory.annotation.Autowired; -import org.springframework.security.authentication.AuthenticationManager; import org.springframework.security.authentication.AuthenticationServiceException; import org.springframework.security.config.annotation.ObjectPostProcessor; import org.springframework.security.core.AuthenticationException; import org.springframework.security.core.context.SecurityContextHolder; -import org.springframework.security.web.DefaultRedirectStrategy; import org.springframework.security.web.WebAttributes; import org.springframework.security.web.util.UrlUtils; import org.springframework.web.filter.GenericFilterBean; import javax.servlet.*; import javax.servlet.http.HttpServletRequest; -import javax.servlet.http.HttpServletResponse; import java.io.IOException; import java.util.*; @@ -68,7 +63,6 @@ public class MidpointAuthFilter extends GenericFilterBean { @Autowired private PrismContext prismContext; -// private SecurityFilterChain authenticatedFilter; private AuthenticationsPolicyType authenticationPolicy; private PreLogoutFilter preLogoutFilter = new PreLogoutFilter(); @@ -83,12 +77,6 @@ public PreLogoutFilter getPreLogoutFilter() { public void createFilterForAuthenticatedRequest() { ModuleWebSecurityConfig module = objectObjectPostProcessor.postProcess(new ModuleWebSecurityConfig(null)); module.setObjectPostProcessor(objectObjectPostProcessor); -// try { -// HttpSecurity http = module.getNewHttpSecurity(); -// authenticatedFilter = http.build(); -// } catch (Exception e) { -// LOGGER.error("Couldn't create filter for authenticated requests", e); -// } } public AuthenticationsPolicyType getDefaultAuthenticationPolicy() throws SchemaException { @@ -108,6 +96,7 @@ private void doFilterInternal(ServletRequest request, ServletResponse response, FilterChain chain) throws IOException, ServletException { HttpServletRequest httpRequest = (HttpServletRequest) request; + //request for permit all page (for example errors and login pages) if (SecurityUtils.isPermitAll(httpRequest) && !SecurityUtils.isLoginPage(httpRequest)) { chain.doFilter(request, response); return; @@ -119,21 +108,11 @@ private void doFilterInternal(ServletRequest request, ServletResponse response, CredentialsPolicyType credentialsPolicy = null; PrismObject authPolicy = null; try { - authPolicy = systemObjectCache.getSecurityPolicy(new OperationResult("load authentication policy")); - - //security policy without authentication - if (authPolicy == null || authPolicy.asObjectable().getAuthentication() == null - || authPolicy.asObjectable().getAuthentication().getSequence() == null - || authPolicy.asObjectable().getAuthentication().getSequence().isEmpty()) { - authenticationsPolicy = getDefaultAuthenticationPolicy(); - } else { - authenticationsPolicy = authPolicy.asObjectable().getAuthentication(); - } - + authPolicy = getSecurityPolicy(); + authenticationsPolicy = getAuthenticationPolicy(authPolicy); if (authPolicy != null) { credentialsPolicy = authPolicy.asObjectable().getCredentials(); } - } catch (SchemaException e) { LOGGER.error("Couldn't load Authentication policy", e); try { @@ -144,54 +123,29 @@ private void doFilterInternal(ServletRequest request, ServletResponse response, } } + //is path for which is ignored authentication if (SecurityUtils.isIgnoredLocalPath(authenticationsPolicy, httpRequest)) { chain.doFilter(request, response); return; } - AuthenticationSequenceType sequence; - // permitAll pages (login, select ID for saml ...) during processing of modules - if (mpAuthentication != null && SecurityUtils.isLoginPage(httpRequest)) { - sequence = mpAuthentication.getSequence(); - } else { - sequence = SecurityUtils.getSequenceByPath(httpRequest, authenticationsPolicy); - } - - - if (mpAuthentication != null && !mpAuthentication.getSequence().equals(sequence) && mpAuthentication.isAuthenticated() - && (((sequence != null && sequence.getChannel() != null && mpAuthentication.getAuthenticationChannel().matchChannel(sequence))) - || mpAuthentication.getAuthenticationChannel().getChannelId().equals(SecurityUtils.findChannelByRequest(httpRequest)))) { - if (SecurityUtils.isBasePathForSequence(httpRequest, sequence)) { - mpAuthentication.getAuthenticationChannel().setPathAfterLogout(((HttpServletRequest) request).getServletPath()); - ModuleAuthentication authenticatedModule = SecurityUtils.getAuthenticatedModule(); - authenticatedModule.setInternalLogout(true); - } - sequence = mpAuthentication.getSequence(); - - } - + AuthenticationSequenceType sequence = getAuthenticationSequence(mpAuthentication, httpRequest, authenticationsPolicy); if (sequence == null) { throw new IllegalArgumentException("Couldn't find sequence for URI '" + httpRequest.getRequestURI() + "' in authentication of Security Policy with oid " + authPolicy.getOid()); } + //change generic logout path to logout path for actual module getPreLogoutFilter().doFilter(request, response); AuthenticationChannel authenticationChannel = SecurityUtils.buildAuthChannel(authChannelRegistry, sequence); - List authModules; - //change sequence of authentication during another sequence - if (mpAuthentication == null || !sequence.equals(mpAuthentication.getSequence())) { - SecurityContextHolder.getContext().setAuthentication(null); - authenticationManager.getProviders().clear(); - authModules = SecurityUtils.buildModuleFilters(authModuleRegistry, sequence, httpRequest, authenticationsPolicy.getModules(), - credentialsPolicy, sharedObjects, authenticationChannel); - } else { - //authenticated request - if (mpAuthentication != null && mpAuthentication.isAuthenticated()) { - processingOfAuthenticatedRequest(mpAuthentication, httpRequest, response, chain); - return; - } - authModules = mpAuthentication.getAuthModules(); + List authModules = createAuthenticationModuleBySequence(mpAuthentication, sequence, httpRequest, authenticationsPolicy.getModules() + ,authenticationChannel, credentialsPolicy); + + //authenticated request + if (mpAuthentication != null && mpAuthentication.isAuthenticated() && sequence.equals(mpAuthentication.getSequence())) { + processingOfAuthenticatedRequest(mpAuthentication, httpRequest, response, chain); + return; } //couldn't find authentication modules @@ -203,13 +157,39 @@ private void doFilterInternal(ServletRequest request, ServletResponse response, throw new AuthenticationServiceException("Couldn't find filters for sequence " + sequence.getName()); } - int indexOfProcessingModule = -1; - // if exist authentication (authentication flow is processed) find actual processing module - if (SecurityContextHolder.getContext().getAuthentication() != null) { - indexOfProcessingModule = mpAuthentication.getIndexOfProcessingModule(true); - indexOfProcessingModule = mpAuthentication.resolveParallelModules((HttpServletRequest) request, indexOfProcessingModule); + int indexOfProcessingModule = getIndexOfActualProcessingModule(mpAuthentication, httpRequest); + + resolveErrorWithMoreModules(mpAuthentication, httpRequest); + + if (needRestartAuthFlow(indexOfProcessingModule)) { + indexOfProcessingModule = restartAuthFlow(mpAuthentication, httpRequest, sequence, authModules); + mpAuthentication = (MidpointAuthentication) SecurityContextHolder.getContext().getAuthentication(); } + if (mpAuthentication.getAuthenticationChannel() == null) { + mpAuthentication.setAuthenticationChannel(authenticationChannel); + } + + MidpointAuthFilter.VirtualFilterChain vfc = new MidpointAuthFilter.VirtualFilterChain(httpRequest, chain, authModules.get(indexOfProcessingModule).getSecurityFilterChain().getFilters()); + vfc.doFilter(httpRequest, response); + } + + private boolean needRestartAuthFlow(int indexOfProcessingModule) { + // if index == -1 indicate restart authentication flow + return indexOfProcessingModule == -1; + } + + private int restartAuthFlow(MidpointAuthentication mpAuthentication, HttpServletRequest httpRequest, AuthenticationSequenceType sequence, List authModules) { + SecurityContextHolder.getContext().setAuthentication(null); + SecurityContextHolder.getContext().setAuthentication(new MidpointAuthentication(sequence)); + mpAuthentication = (MidpointAuthentication) SecurityContextHolder.getContext().getAuthentication(); + mpAuthentication.setAuthModules(authModules); + mpAuthentication.setSessionId(httpRequest.getSession().getId()); + mpAuthentication.addAuthentications(authModules.get(0).getBaseModuleAuthentication()); + return mpAuthentication.resolveParallelModules(httpRequest, 0); + } + + private void resolveErrorWithMoreModules(MidpointAuthentication mpAuthentication, HttpServletRequest httpRequest) { //authentication flow fail and exist more as one authentication module write error if (mpAuthentication != null && mpAuthentication.isAuthenticationFailed() && mpAuthentication.getAuthModules().size() > 1) { @@ -224,25 +204,73 @@ private void doFilterInternal(ServletRequest request, ServletResponse response, AuthenticationException exception = new AuthenticationServiceException(actualMessage); SecurityUtils.saveException(httpRequest, exception); } + } - // if index == -1 indicate restart authentication flow - if (indexOfProcessingModule == -1) { + private int getIndexOfActualProcessingModule(MidpointAuthentication mpAuthentication, HttpServletRequest request) { + int indexOfProcessingModule = -1; + // if exist authentication (authentication flow is processed) find actual processing module + if (SecurityContextHolder.getContext().getAuthentication() != null) { + indexOfProcessingModule = mpAuthentication.getIndexOfProcessingModule(true); + indexOfProcessingModule = mpAuthentication.resolveParallelModules(request, indexOfProcessingModule); + } + return indexOfProcessingModule; + } + + private List createAuthenticationModuleBySequence(MidpointAuthentication mpAuthentication, AuthenticationSequenceType sequence, + HttpServletRequest httpRequest, AuthenticationModulesType modules, AuthenticationChannel authenticationChannel, CredentialsPolicyType credentialsPolicy) { + List authModules; + //change sequence of authentication during another sequence + if (mpAuthentication == null || !sequence.equals(mpAuthentication.getSequence())) { SecurityContextHolder.getContext().setAuthentication(null); - SecurityContextHolder.getContext().setAuthentication(new MidpointAuthentication(sequence)); - mpAuthentication = (MidpointAuthentication) SecurityContextHolder.getContext().getAuthentication(); - mpAuthentication.setAuthModules(authModules); - mpAuthentication.setSessionId(httpRequest.getSession().getId()); - indexOfProcessingModule = 0; - mpAuthentication.addAuthentications(authModules.get(indexOfProcessingModule).getBaseModuleAuthentication()); - indexOfProcessingModule = mpAuthentication.resolveParallelModules((HttpServletRequest) request, indexOfProcessingModule); + authenticationManager.getProviders().clear(); + authModules = SecurityUtils.buildModuleFilters(authModuleRegistry, sequence, httpRequest, modules, + credentialsPolicy, sharedObjects, authenticationChannel); + } else { + authModules = mpAuthentication.getAuthModules(); + } + return authModules; + } + + private AuthenticationSequenceType getAuthenticationSequence(MidpointAuthentication mpAuthentication, HttpServletRequest httpRequest, AuthenticationsPolicyType authenticationsPolicy) { + AuthenticationSequenceType sequence; + // permitAll pages (login, select ID for saml ...) during processing of modules + if (mpAuthentication != null && SecurityUtils.isLoginPage(httpRequest)) { + sequence = mpAuthentication.getSequence(); + } else { + sequence = SecurityUtils.getSequenceByPath(httpRequest, authenticationsPolicy); } - if (mpAuthentication.getAuthenticationChannel() == null) { - mpAuthentication.setAuthenticationChannel(authenticationChannel); + // use same sequence if focus is authenticated and channel id of new sequence is same + if (mpAuthentication != null && !mpAuthentication.getSequence().equals(sequence) && mpAuthentication.isAuthenticated() + && (((sequence != null && sequence.getChannel() != null && mpAuthentication.getAuthenticationChannel().matchChannel(sequence))) + || mpAuthentication.getAuthenticationChannel().getChannelId().equals(SecurityUtils.findChannelByRequest(httpRequest)))) { + //change logout path to new sequence + if (SecurityUtils.isBasePathForSequence(httpRequest, sequence)) { + mpAuthentication.getAuthenticationChannel().setPathAfterLogout(httpRequest.getServletPath()); + ModuleAuthentication authenticatedModule = SecurityUtils.getAuthenticatedModule(); + authenticatedModule.setInternalLogout(true); + } + sequence = mpAuthentication.getSequence(); + } + return sequence; + } - MidpointAuthFilter.VirtualFilterChain vfc = new MidpointAuthFilter.VirtualFilterChain(httpRequest, chain, authModules.get(indexOfProcessingModule).getSecurityFilterChain().getFilters()); - vfc.doFilter(httpRequest, response); + private AuthenticationsPolicyType getAuthenticationPolicy(PrismObject authPolicy) throws SchemaException { + //security policy without authentication + AuthenticationsPolicyType authenticationsPolicy; + if (authPolicy == null || authPolicy.asObjectable().getAuthentication() == null + || authPolicy.asObjectable().getAuthentication().getSequence() == null + || authPolicy.asObjectable().getAuthentication().getSequence().isEmpty()) { + authenticationsPolicy = getDefaultAuthenticationPolicy(); + } else { + authenticationsPolicy = authPolicy.asObjectable().getAuthentication(); + } + return authenticationsPolicy; + } + + private PrismObject getSecurityPolicy() throws SchemaException { + return systemObjectCache.getSecurityPolicy(new OperationResult("load security policy")); } private void processingOfAuthenticatedRequest(MidpointAuthentication mpAuthentication, ServletRequest httpRequest, ServletResponse response, FilterChain chain) throws IOException, ServletException {