Skip to content

Commit

Permalink
some attempt for authentication modules, the code isn't clean
Browse files Browse the repository at this point in the history
  • Loading branch information
KaterynaHonchar committed Dec 19, 2022
1 parent 8a5c9da commit 5200112
Show file tree
Hide file tree
Showing 3 changed files with 86 additions and 62 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.Map;
import javax.servlet.http.HttpServletRequest;

import com.evolveum.midpoint.authentication.api.AuthModule;
Expand All @@ -18,6 +19,8 @@
import com.evolveum.midpoint.security.api.MidPointPrincipal;
import com.evolveum.midpoint.authentication.api.AuthenticationModuleState;

import com.evolveum.midpoint.xml.ns._public.common.common_3.AuthenticationSequenceModuleType;

import org.apache.commons.lang3.Validate;
import org.springframework.security.authentication.AbstractAuthenticationToken;
import org.springframework.security.authentication.AnonymousAuthenticationToken;
Expand All @@ -38,7 +41,9 @@ public class MidpointAuthentication extends AbstractAuthenticationToken implemen
/**
* Configuration of sequence from xml
*/
private final AuthenticationSequenceType sequence;
private AuthenticationSequenceType sequence;
private Map<Class<?>, Object> sharedObjects;


/**
* Authentications for modules of sequence
Expand All @@ -64,8 +69,8 @@ public class MidpointAuthentication extends AbstractAuthenticationToken implemen
private Object credential;
private String sessionId;
private Collection<? extends GrantedAuthority> authorities = AuthorityUtils.NO_AUTHORITIES;
public static int NO_PROCESSING_MODULE_INDEX = -2;
public static int NO_MODULE_FOUND_INDEX = -1;
public static final int NO_PROCESSING_MODULE_INDEX = -2;
public static final int NO_MODULE_FOUND_INDEX = -1;
private boolean merged = false;


Expand All @@ -87,6 +92,18 @@ public AuthenticationSequenceType getSequence() {
return sequence;
}

public void setSequence(AuthenticationSequenceType sequence) {
this.sequence = sequence;
}

public Map<Class<?>, Object> getSharedObjects() {
return sharedObjects;
}

public void setSharedObjects(Map<Class<?>, Object> sharedObjects) {
this.sharedObjects = sharedObjects;
}

public AuthenticationChannel getAuthenticationChannel() {
return authenticationChannel;
}
Expand Down Expand Up @@ -150,24 +167,15 @@ public String getSessionId() {

@Override
public boolean isAuthenticated() {
List<AuthModule> modules = getAuthModules();
List<AuthenticationSequenceModuleType> modules = sequence.getModule();
if (modules.isEmpty()) {
return false;
}
for (AuthModule module : modules) {
ModuleAuthentication authentication = getAuthenticationByIdentifier(module.getModuleIdentifier());
if (authentication == null) {
continue;
}
//TODO we will complete after supporting of full "necessity"
// if (AuthenticationSequenceModuleNecessityType.SUFFICIENT.equals(authentication.getNecessity())) {
if (!AuthenticationModuleState.SUCCESSFULLY.equals(authentication.getState())) {
return false;
}
// }

}
return true;
//todo
boolean isAuth = modules.stream().filter(m -> getAuthenticationByIdentifier(m.getIdentifier()) == null).findAny().isEmpty() &&
modules.stream().filter(m ->
!AuthenticationModuleState.SUCCESSFULLY.equals(getAuthenticationByIdentifier(m.getIdentifier()).getState())).findAny().isEmpty();
return isAuth;
}

public ModuleAuthentication getAuthenticationByIdentifier(String moduleIdentifier) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -152,15 +152,15 @@ private void doFilterInternal(ServletRequest request, ServletResponse response,
authWrapper.authenticationChannel = AuthSequenceUtil.buildAuthChannel(authChannelRegistry, authWrapper.sequence);
try {
initAuthenticationModule(mpAuthentication, authWrapper, httpRequest);
if (mpAuthentication != null && authWrapper.switchSecurityPolicy && !mpAuthentication.isMerged()) {
mpAuthentication.getAuthModules().clear();
mpAuthentication.setAuthModules(authWrapper.authModules);
AuthModule module = getUnauthenticatedModule(authWrapper.authModules, mpAuthentication);
if (module != null) {
mpAuthentication.addAuthentications(module.getBaseModuleAuthentication());
}
mpAuthentication.setMerged(true);
}
// if (mpAuthentication != null && authWrapper.switchSecurityPolicy && !mpAuthentication.isMerged()) {
// mpAuthentication.getAuthModules().clear();
// mpAuthentication.setAuthModules(authWrapper.authModules);
// AuthModule module = getUnauthenticatedModule(authWrapper.authModules, mpAuthentication);
// if (module != null) {
// mpAuthentication.addAuthentications(module.getBaseModuleAuthentication());
// }
// mpAuthentication.setMerged(true);
// }

if (isRequestAuthenticated(mpAuthentication, authWrapper)) {
processingOfAuthenticatedRequest(mpAuthentication, httpRequest, response, chain);
Expand All @@ -184,31 +184,13 @@ private void doFilterInternal(ServletRequest request, ServletResponse response,
indexOfProcessingModule = getIndexOfActualProcessingModule(mpAuthentication, httpRequest);
}
setAuthenticationChanel(mpAuthentication, authWrapper);
authWrapper.switchSecurityPolicy = false;
// authWrapper.switchSecurityPolicy = false;
runFilters(authWrapper, indexOfProcessingModule, chain, httpRequest, response);
} finally {
removingFiltersAfterProcessing(mpAuthentication, httpRequest);
}
}

private AuthModule getUnauthenticatedModule(List<AuthModule> modules, MidpointAuthentication mpAuthentication) {
if (CollectionUtils.isEmpty(modules)) {
return null;
}
return modules.stream()
.filter(module -> !authModuleAlreadyProcessed(mpAuthentication, module.getModuleIdentifier()))
.findFirst()
.orElse(null);
}

private boolean authModuleAlreadyProcessed(MidpointAuthentication mpAuthentication, String moduleIdentifier) {
if (CollectionUtils.isEmpty(mpAuthentication.getAuthentications())) {
return false;
}
return mpAuthentication.getAuthentications().stream().anyMatch(auth -> auth.getModuleIdentifier().equals(moduleIdentifier)
&& AuthenticationModuleState.SUCCESSFULLY.equals(auth.getState()));
}

private void removingFiltersAfterProcessing(MidpointAuthentication mpAuthentication, HttpServletRequest httpRequest) {
if (!AuthSequenceUtil.isSpecificSequence(httpRequest) && httpRequest.getSession(false) == null && mpAuthentication != null) {
removeUnusedSecurityFilterPublisher.publishCustomEvent(mpAuthentication);
Expand Down Expand Up @@ -247,7 +229,7 @@ private int initNewAuthenticationToken(AuthenticationWrapper authWrapper, HttpSe

private boolean needCreateNewAuthenticationToken(MidpointAuthentication mpAuthentication, AuthenticationWrapper authWrapper, HttpServletRequest httpRequest) {
return AuthSequenceUtil.isSpecificSequence(httpRequest)
|| needRestartAuthFlow(getIndexOfActualProcessingModule(mpAuthentication, httpRequest), authWrapper);
|| needRestartAuthFlow(getIndexOfActualProcessingModule(mpAuthentication, httpRequest), authWrapper, mpAuthentication);
}

private void setLogoutPath(ServletRequest request, ServletResponse response) {
Expand All @@ -260,8 +242,7 @@ private boolean wasNotFoundAuthModule(AuthenticationWrapper authWrapper) {

private boolean isRequestAuthenticated(MidpointAuthentication mpAuthentication, AuthenticationWrapper authWrapper) {
return mpAuthentication != null && mpAuthentication.isAuthenticated()
&& sequenceIdentifiersMatch(authWrapper.sequence, mpAuthentication.getSequence())
&& CollectionUtils.size(mpAuthentication.getAuthModules()) == CollectionUtils.size(authWrapper.authModules);
&& sequenceIdentifiersMatch(authWrapper.sequence, mpAuthentication.getSequence());
}

private boolean sequenceIdentifiersMatch(AuthenticationSequenceType seq1, AuthenticationSequenceType seq2) {
Expand Down Expand Up @@ -331,9 +312,9 @@ private boolean isPermitAllPage(HttpServletRequest request) {
return AuthSequenceUtil.isPermitAll(request) && !AuthSequenceUtil.isLoginPage(request);
}

private boolean needRestartAuthFlow(int indexOfProcessingModule, AuthenticationWrapper authWrapper) {
private boolean needRestartAuthFlow(int indexOfProcessingModule, AuthenticationWrapper authWrapper, MidpointAuthentication mpAuthentication) {
// if index == -1 indicate restart authentication flow
return !authWrapper.switchSecurityPolicy && indexOfProcessingModule == MidpointAuthentication.NO_MODULE_FOUND_INDEX;
return (mpAuthentication == null || !mpAuthentication.isMerged()) && indexOfProcessingModule == MidpointAuthentication.NO_MODULE_FOUND_INDEX;
}

private int restartAuthFlow(HttpServletRequest httpRequest, AuthenticationWrapper authWrapper) {
Expand All @@ -344,6 +325,7 @@ private int restartAuthFlow(HttpServletRequest httpRequest, AuthenticationWrappe

private void createMpAuthentication(HttpServletRequest httpRequest, AuthenticationWrapper authWrapper) {
MidpointAuthentication mpAuthentication = new MidpointAuthentication(authWrapper.sequence);
mpAuthentication.setSharedObjects(sharedObjects);
mpAuthentication.setAuthModules(authWrapper.authModules);
mpAuthentication.setSessionId(httpRequest.getSession(false) != null ?
httpRequest.getSession(false).getId() : RandomStringUtils.random(30, true, true).toUpperCase());
Expand Down Expand Up @@ -391,16 +373,16 @@ private List<AuthModule> createAuthenticationModuleBySequence(MidpointAuthentica
authModuleRegistry, authWrapper.sequence, httpRequest, authWrapper.authenticationsPolicy.getModules(),
authWrapper.credentialsPolicy, sharedObjects, authWrapper.authenticationChannel);
} else {
AuthenticationSequenceType sequence =
AuthSequenceUtil.getSequenceByIdentifier(mpAuthentication.getSequence().getIdentifier(), authWrapper.authenticationsPolicy);
authWrapper.switchSecurityPolicy = sequence != null && CollectionUtils.size(mpAuthentication.getSequence().getModule()) != CollectionUtils.size(sequence.getModule());
if (authWrapper.switchSecurityPolicy) {
authModules = AuthSequenceUtil.buildModuleFilters(
authModuleRegistry, authWrapper.sequence, httpRequest, authWrapper.authenticationsPolicy.getModules(),
authWrapper.credentialsPolicy, sharedObjects, authWrapper.authenticationChannel);
} else {
// AuthenticationSequenceType sequence =
// AuthSequenceUtil.getSequenceByIdentifier(mpAuthentication.getSequence().getIdentifier(), authWrapper.authenticationsPolicy);
// authWrapper.switchSecurityPolicy = sequence != null && CollectionUtils.size(mpAuthentication.getSequence().getModule()) != CollectionUtils.size(sequence.getModule());
// if (authWrapper.switchSecurityPolicy) {
// authModules = AuthSequenceUtil.buildModuleFilters(
// authModuleRegistry, authWrapper.sequence, httpRequest, authWrapper.authenticationsPolicy.getModules(),
// authWrapper.credentialsPolicy, sharedObjects, authWrapper.authenticationChannel);
// } else {
authModules = mpAuthentication.getAuthModules();
}
// }
}
return authModules;
}
Expand Down Expand Up @@ -544,7 +526,7 @@ private class AuthenticationWrapper {
AuthenticationsPolicyType authenticationsPolicy;
CredentialsPolicyType credentialsPolicy = null;
PrismObject<SecurityPolicyType> securityPolicy = null;
List<AuthModule> authModules;
List<AuthModule> authModules; //vsetky module
AuthenticationSequenceType sequence = null;
AuthenticationChannel authenticationChannel;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;

import com.evolveum.midpoint.authentication.api.AuthModule;
import com.evolveum.midpoint.authentication.impl.factory.module.AuthModuleRegistryImpl;
import com.evolveum.midpoint.authentication.impl.module.authentication.ModuleAuthenticationImpl;
import com.evolveum.midpoint.authentication.impl.util.AuthSequenceUtil;
import com.evolveum.midpoint.authentication.api.util.AuthConstants;
Expand All @@ -24,7 +26,9 @@
import com.evolveum.midpoint.xml.ns._public.common.common_3.AuthenticationSequenceType;
import com.evolveum.midpoint.xml.ns._public.common.common_3.SecurityPolicyType;

import org.apache.commons.collections4.CollectionUtils;
import org.apache.commons.lang3.StringUtils;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.security.core.Authentication;
import org.springframework.security.web.authentication.SavedRequestAwareAuthenticationSuccessHandler;
import org.springframework.security.web.savedrequest.HttpSessionRequestCache;
Expand All @@ -39,6 +43,8 @@
*/
public class MidPointAuthenticationSuccessHandler extends SavedRequestAwareAuthenticationSuccessHandler {

@Autowired
private AuthModuleRegistryImpl authModuleRegistry;
private String defaultTargetUrl;

public MidPointAuthenticationSuccessHandler() {
Expand Down Expand Up @@ -70,14 +76,23 @@ public void onAuthenticationSuccess(HttpServletRequest request, HttpServletRespo
if (mpAuthentication.getAuthenticationChannel() != null) {
authenticatedChannel = mpAuthentication.getAuthenticationChannel().getChannelId();
boolean continueSequence = false;
if (mpAuthentication.getPrincipal() instanceof MidPointPrincipal) {
if (mpAuthentication.getPrincipal() instanceof MidPointPrincipal) { //todo refactor, move to methods
MidPointPrincipal principal = (MidPointPrincipal) mpAuthentication.getPrincipal();
SecurityPolicyType securityPolicy = principal.getApplicableSecurityPolicy();
if (securityPolicy != null) {
AuthenticationSequenceType processingSequence = mpAuthentication.getSequence();
AuthenticationSequenceType sequence = SecurityPolicyUtil.findSequenceByIdentifier(securityPolicy, processingSequence.getIdentifier());
if (processingSequence.getModule().size() != sequence.getModule().size()) {
continueSequence = true;
mpAuthentication.setSequence(sequence);
mpAuthentication.setAuthModules(AuthSequenceUtil.buildModuleFilters(
authModuleRegistry, sequence, request, securityPolicy.getAuthentication().getModules(),
securityPolicy.getCredentials(), mpAuthentication.getSharedObjects(), mpAuthentication.getAuthenticationChannel()));
mpAuthentication.setMerged(true);
AuthModule module = getUnauthenticatedModule(mpAuthentication);
// if (module != null) {
// mpAuthentication.addAuthentications(module.getBaseModuleAuthentication());
// }
}
}
}
Expand Down Expand Up @@ -120,6 +135,25 @@ public void onAuthenticationSuccess(HttpServletRequest request, HttpServletRespo
super.onAuthenticationSuccess(request, response, authentication);
}

private AuthModule getUnauthenticatedModule(MidpointAuthentication mpAuthentication) {
if (CollectionUtils.isEmpty(mpAuthentication.getAuthModules())) {
return null;
}
return mpAuthentication.getAuthModules().stream()
.filter(module -> !authModuleAlreadyProcessed(mpAuthentication, module.getModuleIdentifier()))
.findFirst()
.orElse(null);
}


private boolean authModuleAlreadyProcessed(MidpointAuthentication mpAuthentication, String moduleIdentifier) {
if (CollectionUtils.isEmpty(mpAuthentication.getAuthentications())) {
return false;
}
return mpAuthentication.getAuthentications().stream().anyMatch(auth -> auth.getModuleIdentifier().equals(moduleIdentifier)
&& AuthenticationModuleState.SUCCESSFULLY.equals(auth.getState()));
}

@Override
protected String getTargetUrlParameter() {
return defaultTargetUrl;
Expand Down

0 comments on commit 5200112

Please sign in to comment.