Skip to content

Commit

Permalink
clenup code of MidpointAuthFilter.doFilterInternal
Browse files Browse the repository at this point in the history
  • Loading branch information
skublik committed Jun 12, 2020
1 parent d45f992 commit 82e1116
Showing 1 changed file with 105 additions and 77 deletions.
Expand Up @@ -10,34 +10,29 @@
import com.evolveum.midpoint.model.common.SystemObjectCache;
import com.evolveum.midpoint.prism.PrismContext;
import com.evolveum.midpoint.prism.PrismObject;
import com.evolveum.midpoint.prism.schema.SchemaRegistry;
import com.evolveum.midpoint.schema.result.OperationResult;
import com.evolveum.midpoint.schema.util.SecurityPolicyUtil;
import com.evolveum.midpoint.util.exception.SchemaException;
import com.evolveum.midpoint.util.logging.Trace;
import com.evolveum.midpoint.util.logging.TraceManager;
import com.evolveum.midpoint.web.security.MidpointAuthenticationManager;
import com.evolveum.midpoint.web.security.MidpointProviderManager;
import com.evolveum.midpoint.web.security.factory.channel.AuthChannelRegistryImpl;
import com.evolveum.midpoint.web.security.module.ModuleWebSecurityConfig;
import com.evolveum.midpoint.web.security.factory.module.AuthModuleRegistryImpl;
import com.evolveum.midpoint.web.security.util.SecurityUtils;
import com.evolveum.midpoint.xml.ns._public.common.common_3.*;
import org.apache.commons.lang3.StringUtils;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.security.authentication.AuthenticationManager;
import org.springframework.security.authentication.AuthenticationServiceException;
import org.springframework.security.config.annotation.ObjectPostProcessor;
import org.springframework.security.core.AuthenticationException;
import org.springframework.security.core.context.SecurityContextHolder;
import org.springframework.security.web.DefaultRedirectStrategy;
import org.springframework.security.web.WebAttributes;
import org.springframework.security.web.util.UrlUtils;
import org.springframework.web.filter.GenericFilterBean;

import javax.servlet.*;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import java.io.IOException;
import java.util.*;

Expand Down Expand Up @@ -68,7 +63,6 @@ public class MidpointAuthFilter extends GenericFilterBean {
@Autowired
private PrismContext prismContext;

// private SecurityFilterChain authenticatedFilter;
private AuthenticationsPolicyType authenticationPolicy;
private PreLogoutFilter preLogoutFilter = new PreLogoutFilter();

Expand All @@ -83,12 +77,6 @@ public PreLogoutFilter getPreLogoutFilter() {
public void createFilterForAuthenticatedRequest() {
ModuleWebSecurityConfig module = objectObjectPostProcessor.postProcess(new ModuleWebSecurityConfig(null));
module.setObjectPostProcessor(objectObjectPostProcessor);
// try {
// HttpSecurity http = module.getNewHttpSecurity();
// authenticatedFilter = http.build();
// } catch (Exception e) {
// LOGGER.error("Couldn't create filter for authenticated requests", e);
// }
}

public AuthenticationsPolicyType getDefaultAuthenticationPolicy() throws SchemaException {
Expand All @@ -108,6 +96,7 @@ private void doFilterInternal(ServletRequest request, ServletResponse response,
FilterChain chain) throws IOException, ServletException {

HttpServletRequest httpRequest = (HttpServletRequest) request;
//request for permit all page (for example errors and login pages)
if (SecurityUtils.isPermitAll(httpRequest) && !SecurityUtils.isLoginPage(httpRequest)) {
chain.doFilter(request, response);
return;
Expand All @@ -119,21 +108,11 @@ private void doFilterInternal(ServletRequest request, ServletResponse response,
CredentialsPolicyType credentialsPolicy = null;
PrismObject<SecurityPolicyType> authPolicy = null;
try {
authPolicy = systemObjectCache.getSecurityPolicy(new OperationResult("load authentication policy"));

//security policy without authentication
if (authPolicy == null || authPolicy.asObjectable().getAuthentication() == null
|| authPolicy.asObjectable().getAuthentication().getSequence() == null
|| authPolicy.asObjectable().getAuthentication().getSequence().isEmpty()) {
authenticationsPolicy = getDefaultAuthenticationPolicy();
} else {
authenticationsPolicy = authPolicy.asObjectable().getAuthentication();
}

authPolicy = getSecurityPolicy();
authenticationsPolicy = getAuthenticationPolicy(authPolicy);
if (authPolicy != null) {
credentialsPolicy = authPolicy.asObjectable().getCredentials();
}

} catch (SchemaException e) {
LOGGER.error("Couldn't load Authentication policy", e);
try {
Expand All @@ -144,54 +123,29 @@ private void doFilterInternal(ServletRequest request, ServletResponse response,
}
}

//is path for which is ignored authentication
if (SecurityUtils.isIgnoredLocalPath(authenticationsPolicy, httpRequest)) {
chain.doFilter(request, response);
return;
}

AuthenticationSequenceType sequence;
// permitAll pages (login, select ID for saml ...) during processing of modules
if (mpAuthentication != null && SecurityUtils.isLoginPage(httpRequest)) {
sequence = mpAuthentication.getSequence();
} else {
sequence = SecurityUtils.getSequenceByPath(httpRequest, authenticationsPolicy);
}


if (mpAuthentication != null && !mpAuthentication.getSequence().equals(sequence) && mpAuthentication.isAuthenticated()
&& (((sequence != null && sequence.getChannel() != null && mpAuthentication.getAuthenticationChannel().matchChannel(sequence)))
|| mpAuthentication.getAuthenticationChannel().getChannelId().equals(SecurityUtils.findChannelByRequest(httpRequest)))) {
if (SecurityUtils.isBasePathForSequence(httpRequest, sequence)) {
mpAuthentication.getAuthenticationChannel().setPathAfterLogout(((HttpServletRequest) request).getServletPath());
ModuleAuthentication authenticatedModule = SecurityUtils.getAuthenticatedModule();
authenticatedModule.setInternalLogout(true);
}
sequence = mpAuthentication.getSequence();

}

AuthenticationSequenceType sequence = getAuthenticationSequence(mpAuthentication, httpRequest, authenticationsPolicy);
if (sequence == null) {
throw new IllegalArgumentException("Couldn't find sequence for URI '" + httpRequest.getRequestURI() + "' in authentication of Security Policy with oid " + authPolicy.getOid());
}

//change generic logout path to logout path for actual module
getPreLogoutFilter().doFilter(request, response);

AuthenticationChannel authenticationChannel = SecurityUtils.buildAuthChannel(authChannelRegistry, sequence);

List<AuthModule> authModules;
//change sequence of authentication during another sequence
if (mpAuthentication == null || !sequence.equals(mpAuthentication.getSequence())) {
SecurityContextHolder.getContext().setAuthentication(null);
authenticationManager.getProviders().clear();
authModules = SecurityUtils.buildModuleFilters(authModuleRegistry, sequence, httpRequest, authenticationsPolicy.getModules(),
credentialsPolicy, sharedObjects, authenticationChannel);
} else {
//authenticated request
if (mpAuthentication != null && mpAuthentication.isAuthenticated()) {
processingOfAuthenticatedRequest(mpAuthentication, httpRequest, response, chain);
return;
}
authModules = mpAuthentication.getAuthModules();
List<AuthModule> authModules = createAuthenticationModuleBySequence(mpAuthentication, sequence, httpRequest, authenticationsPolicy.getModules()
,authenticationChannel, credentialsPolicy);

//authenticated request
if (mpAuthentication != null && mpAuthentication.isAuthenticated() && sequence.equals(mpAuthentication.getSequence())) {
processingOfAuthenticatedRequest(mpAuthentication, httpRequest, response, chain);
return;
}

//couldn't find authentication modules
Expand All @@ -203,13 +157,39 @@ private void doFilterInternal(ServletRequest request, ServletResponse response,
throw new AuthenticationServiceException("Couldn't find filters for sequence " + sequence.getName());
}

int indexOfProcessingModule = -1;
// if exist authentication (authentication flow is processed) find actual processing module
if (SecurityContextHolder.getContext().getAuthentication() != null) {
indexOfProcessingModule = mpAuthentication.getIndexOfProcessingModule(true);
indexOfProcessingModule = mpAuthentication.resolveParallelModules((HttpServletRequest) request, indexOfProcessingModule);
int indexOfProcessingModule = getIndexOfActualProcessingModule(mpAuthentication, httpRequest);

resolveErrorWithMoreModules(mpAuthentication, httpRequest);

if (needRestartAuthFlow(indexOfProcessingModule)) {
indexOfProcessingModule = restartAuthFlow(mpAuthentication, httpRequest, sequence, authModules);
mpAuthentication = (MidpointAuthentication) SecurityContextHolder.getContext().getAuthentication();
}

if (mpAuthentication.getAuthenticationChannel() == null) {
mpAuthentication.setAuthenticationChannel(authenticationChannel);
}

MidpointAuthFilter.VirtualFilterChain vfc = new MidpointAuthFilter.VirtualFilterChain(httpRequest, chain, authModules.get(indexOfProcessingModule).getSecurityFilterChain().getFilters());
vfc.doFilter(httpRequest, response);
}

private boolean needRestartAuthFlow(int indexOfProcessingModule) {
// if index == -1 indicate restart authentication flow
return indexOfProcessingModule == -1;
}

private int restartAuthFlow(MidpointAuthentication mpAuthentication, HttpServletRequest httpRequest, AuthenticationSequenceType sequence, List<AuthModule> authModules) {
SecurityContextHolder.getContext().setAuthentication(null);
SecurityContextHolder.getContext().setAuthentication(new MidpointAuthentication(sequence));
mpAuthentication = (MidpointAuthentication) SecurityContextHolder.getContext().getAuthentication();
mpAuthentication.setAuthModules(authModules);
mpAuthentication.setSessionId(httpRequest.getSession().getId());
mpAuthentication.addAuthentications(authModules.get(0).getBaseModuleAuthentication());
return mpAuthentication.resolveParallelModules(httpRequest, 0);
}

private void resolveErrorWithMoreModules(MidpointAuthentication mpAuthentication, HttpServletRequest httpRequest) {
//authentication flow fail and exist more as one authentication module write error
if (mpAuthentication != null && mpAuthentication.isAuthenticationFailed() && mpAuthentication.getAuthModules().size() > 1) {

Expand All @@ -224,25 +204,73 @@ private void doFilterInternal(ServletRequest request, ServletResponse response,
AuthenticationException exception = new AuthenticationServiceException(actualMessage);
SecurityUtils.saveException(httpRequest, exception);
}
}

// if index == -1 indicate restart authentication flow
if (indexOfProcessingModule == -1) {
private int getIndexOfActualProcessingModule(MidpointAuthentication mpAuthentication, HttpServletRequest request) {
int indexOfProcessingModule = -1;
// if exist authentication (authentication flow is processed) find actual processing module
if (SecurityContextHolder.getContext().getAuthentication() != null) {
indexOfProcessingModule = mpAuthentication.getIndexOfProcessingModule(true);
indexOfProcessingModule = mpAuthentication.resolveParallelModules(request, indexOfProcessingModule);
}
return indexOfProcessingModule;
}

private List<AuthModule> createAuthenticationModuleBySequence(MidpointAuthentication mpAuthentication, AuthenticationSequenceType sequence,
HttpServletRequest httpRequest, AuthenticationModulesType modules, AuthenticationChannel authenticationChannel, CredentialsPolicyType credentialsPolicy) {
List<AuthModule> authModules;
//change sequence of authentication during another sequence
if (mpAuthentication == null || !sequence.equals(mpAuthentication.getSequence())) {
SecurityContextHolder.getContext().setAuthentication(null);
SecurityContextHolder.getContext().setAuthentication(new MidpointAuthentication(sequence));
mpAuthentication = (MidpointAuthentication) SecurityContextHolder.getContext().getAuthentication();
mpAuthentication.setAuthModules(authModules);
mpAuthentication.setSessionId(httpRequest.getSession().getId());
indexOfProcessingModule = 0;
mpAuthentication.addAuthentications(authModules.get(indexOfProcessingModule).getBaseModuleAuthentication());
indexOfProcessingModule = mpAuthentication.resolveParallelModules((HttpServletRequest) request, indexOfProcessingModule);
authenticationManager.getProviders().clear();
authModules = SecurityUtils.buildModuleFilters(authModuleRegistry, sequence, httpRequest, modules,
credentialsPolicy, sharedObjects, authenticationChannel);
} else {
authModules = mpAuthentication.getAuthModules();
}
return authModules;
}

private AuthenticationSequenceType getAuthenticationSequence(MidpointAuthentication mpAuthentication, HttpServletRequest httpRequest, AuthenticationsPolicyType authenticationsPolicy) {
AuthenticationSequenceType sequence;
// permitAll pages (login, select ID for saml ...) during processing of modules
if (mpAuthentication != null && SecurityUtils.isLoginPage(httpRequest)) {
sequence = mpAuthentication.getSequence();
} else {
sequence = SecurityUtils.getSequenceByPath(httpRequest, authenticationsPolicy);
}

if (mpAuthentication.getAuthenticationChannel() == null) {
mpAuthentication.setAuthenticationChannel(authenticationChannel);
// use same sequence if focus is authenticated and channel id of new sequence is same
if (mpAuthentication != null && !mpAuthentication.getSequence().equals(sequence) && mpAuthentication.isAuthenticated()
&& (((sequence != null && sequence.getChannel() != null && mpAuthentication.getAuthenticationChannel().matchChannel(sequence)))
|| mpAuthentication.getAuthenticationChannel().getChannelId().equals(SecurityUtils.findChannelByRequest(httpRequest)))) {
//change logout path to new sequence
if (SecurityUtils.isBasePathForSequence(httpRequest, sequence)) {
mpAuthentication.getAuthenticationChannel().setPathAfterLogout(httpRequest.getServletPath());
ModuleAuthentication authenticatedModule = SecurityUtils.getAuthenticatedModule();
authenticatedModule.setInternalLogout(true);
}
sequence = mpAuthentication.getSequence();

}
return sequence;
}

MidpointAuthFilter.VirtualFilterChain vfc = new MidpointAuthFilter.VirtualFilterChain(httpRequest, chain, authModules.get(indexOfProcessingModule).getSecurityFilterChain().getFilters());
vfc.doFilter(httpRequest, response);
private AuthenticationsPolicyType getAuthenticationPolicy(PrismObject<SecurityPolicyType> authPolicy) throws SchemaException {
//security policy without authentication
AuthenticationsPolicyType authenticationsPolicy;
if (authPolicy == null || authPolicy.asObjectable().getAuthentication() == null
|| authPolicy.asObjectable().getAuthentication().getSequence() == null
|| authPolicy.asObjectable().getAuthentication().getSequence().isEmpty()) {
authenticationsPolicy = getDefaultAuthenticationPolicy();
} else {
authenticationsPolicy = authPolicy.asObjectable().getAuthentication();
}
return authenticationsPolicy;
}

private PrismObject<SecurityPolicyType> getSecurityPolicy() throws SchemaException {
return systemObjectCache.getSecurityPolicy(new OperationResult("load security policy"));
}

private void processingOfAuthenticatedRequest(MidpointAuthentication mpAuthentication, ServletRequest httpRequest, ServletResponse response, FilterChain chain) throws IOException, ServletException {
Expand Down

0 comments on commit 82e1116

Please sign in to comment.