Skip to content

Commit

Permalink
recent code optimizing
Browse files Browse the repository at this point in the history
  • Loading branch information
KaterynaHonchar committed Dec 22, 2022
1 parent 4a05596 commit d8a8307
Show file tree
Hide file tree
Showing 3 changed files with 36 additions and 67 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ public class MidpointAuthentication extends AbstractAuthenticationToken implemen
* Configuration of sequence from xml
*/
private AuthenticationSequenceType sequence;
private Map<Class<?>, Object> sharedObjects;
private Map<Class<?>, Object> sharedObjects; //todo may be wrong place


/**
Expand Down Expand Up @@ -320,9 +320,7 @@ public int getIndexOfModule(ModuleAuthentication authentication) {

for (int i = 0; i < getAuthModules().size(); i++) {
if (getAuthModules().get(i).getModuleIdentifier().equals(authentication.getModuleIdentifier())) {
int indexOfModule = i;
//TODO presumption that necessity is sufficient
return indexOfModule;
return i;
}
}
return NO_MODULE_FOUND_INDEX;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -152,16 +152,6 @@ private void doFilterInternal(ServletRequest request, ServletResponse response,
authWrapper.authenticationChannel = AuthSequenceUtil.buildAuthChannel(authChannelRegistry, authWrapper.sequence);
try {
initAuthenticationModule(mpAuthentication, authWrapper, httpRequest);
// if (mpAuthentication != null && authWrapper.switchSecurityPolicy && !mpAuthentication.isMerged()) {
// mpAuthentication.getAuthModules().clear();
// mpAuthentication.setAuthModules(authWrapper.authModules);
// AuthModule module = getUnauthenticatedModule(authWrapper.authModules, mpAuthentication);
// if (module != null) {
// mpAuthentication.addAuthentications(module.getBaseModuleAuthentication());
// }
// mpAuthentication.setMerged(true);
// }

if (isRequestAuthenticated(mpAuthentication, authWrapper)) {
processingOfAuthenticatedRequest(mpAuthentication, httpRequest, response, chain);
return;
Expand All @@ -177,14 +167,13 @@ private void doFilterInternal(ServletRequest request, ServletResponse response,
resolveErrorWithMoreModules(mpAuthentication, httpRequest);

int indexOfProcessingModule;
if (needCreateNewAuthenticationToken(mpAuthentication, authWrapper, httpRequest)) {
if (needCreateNewAuthenticationToken(mpAuthentication, httpRequest)) {
indexOfProcessingModule = initNewAuthenticationToken(authWrapper, httpRequest);
mpAuthentication = (MidpointAuthentication) SecurityContextHolder.getContext().getAuthentication();
} else {
indexOfProcessingModule = getIndexOfActualProcessingModule(mpAuthentication, httpRequest);
}
setAuthenticationChanel(mpAuthentication, authWrapper);
// authWrapper.switchSecurityPolicy = false;
runFilters(authWrapper, indexOfProcessingModule, chain, httpRequest, response);
} finally {
removingFiltersAfterProcessing(mpAuthentication, httpRequest);
Expand Down Expand Up @@ -227,9 +216,9 @@ private int initNewAuthenticationToken(AuthenticationWrapper authWrapper, HttpSe
}
}

private boolean needCreateNewAuthenticationToken(MidpointAuthentication mpAuthentication, AuthenticationWrapper authWrapper, HttpServletRequest httpRequest) {
private boolean needCreateNewAuthenticationToken(MidpointAuthentication mpAuthentication, HttpServletRequest httpRequest) {
return AuthSequenceUtil.isSpecificSequence(httpRequest)
|| needRestartAuthFlow(getIndexOfActualProcessingModule(mpAuthentication, httpRequest), authWrapper, mpAuthentication);
|| needRestartAuthFlow(getIndexOfActualProcessingModule(mpAuthentication, httpRequest), mpAuthentication);
}

private void setLogoutPath(ServletRequest request, ServletResponse response) {
Expand Down Expand Up @@ -312,9 +301,9 @@ private boolean isPermitAllPage(HttpServletRequest request) {
return AuthSequenceUtil.isPermitAll(request) && !AuthSequenceUtil.isLoginPage(request);
}

private boolean needRestartAuthFlow(int indexOfProcessingModule, AuthenticationWrapper authWrapper, MidpointAuthentication mpAuthentication) {
private boolean needRestartAuthFlow(int indexOfProcessingModule, MidpointAuthentication mpAuthentication) {
// if index == -1 indicate restart authentication flow
return (mpAuthentication == null || !mpAuthentication.isMerged()) && indexOfProcessingModule == MidpointAuthentication.NO_MODULE_FOUND_INDEX;
return mpAuthentication == null || !mpAuthentication.isMerged() || indexOfProcessingModule == MidpointAuthentication.NO_MODULE_FOUND_INDEX;
}

private int restartAuthFlow(HttpServletRequest httpRequest, AuthenticationWrapper authWrapper) {
Expand Down Expand Up @@ -373,16 +362,7 @@ private List<AuthModule> createAuthenticationModuleBySequence(MidpointAuthentica
authModuleRegistry, authWrapper.sequence, httpRequest, authWrapper.authenticationsPolicy.getModules(),
authWrapper.credentialsPolicy, sharedObjects, authWrapper.authenticationChannel);
} else {
// AuthenticationSequenceType sequence =
// AuthSequenceUtil.getSequenceByIdentifier(mpAuthentication.getSequence().getIdentifier(), authWrapper.authenticationsPolicy);
// authWrapper.switchSecurityPolicy = sequence != null && CollectionUtils.size(mpAuthentication.getSequence().getModule()) != CollectionUtils.size(sequence.getModule());
// if (authWrapper.switchSecurityPolicy) {
// authModules = AuthSequenceUtil.buildModuleFilters(
// authModuleRegistry, authWrapper.sequence, httpRequest, authWrapper.authenticationsPolicy.getModules(),
// authWrapper.credentialsPolicy, sharedObjects, authWrapper.authenticationChannel);
// } else {
authModules = mpAuthentication.getAuthModules();
// }
authModules = mpAuthentication.getAuthModules();
}
return authModules;
}
Expand Down Expand Up @@ -529,8 +509,6 @@ private class AuthenticationWrapper {
List<AuthModule> authModules; //vsetky module
AuthenticationSequenceType sequence = null;
AuthenticationChannel authenticationChannel;

boolean switchSecurityPolicy;
}
}

Original file line number Diff line number Diff line change
Expand Up @@ -78,27 +78,10 @@ public void onAuthenticationSuccess(HttpServletRequest request, HttpServletRespo
if (mpAuthentication.getAuthenticationChannel() != null) {
authenticatedChannel = mpAuthentication.getAuthenticationChannel().getChannelId();
boolean continueSequence = false;
if (mpAuthentication.getPrincipal() instanceof MidPointPrincipal) { //todo refactor, move to methods
MidPointPrincipal principal = (MidPointPrincipal) mpAuthentication.getPrincipal();
SecurityPolicyType securityPolicy = principal.getApplicableSecurityPolicy();
if (securityPolicy != null) {
AuthenticationSequenceType processingSequence = mpAuthentication.getSequence();
AuthenticationSequenceType sequence = SecurityPolicyUtil.findSequenceByIdentifier(securityPolicy, processingSequence.getIdentifier());
if (processingSequence.getModule().size() != sequence.getModule().size()) {
continueSequence = true;
mpAuthentication.setSequence(sequence);
List<AuthModule> modules = AuthSequenceUtil.buildModuleFilters(
authModuleRegistry, sequence, request, securityPolicy.getAuthentication().getModules(),
securityPolicy.getCredentials(), mpAuthentication.getSharedObjects(), mpAuthentication.getAuthenticationChannel());
modules.removeIf(Objects::isNull);
mpAuthentication.setAuthModules(modules);
mpAuthentication.setMerged(true);
AuthModule module = getUnauthenticatedModule(mpAuthentication);
// if (module != null) {
// mpAuthentication.addAuthentications(module.getBaseModuleAuthentication());
// }
}
}
if (isNewSecurityPolicyFound(mpAuthentication)) {
continueSequence = true;
SecurityPolicyType securityPolicy = ((MidPointPrincipal) mpAuthentication.getPrincipal()).getApplicableSecurityPolicy();
updateMidpointAuthentication(request, mpAuthentication, securityPolicy);
}
if (mpAuthentication.isAuthenticated() && !continueSequence) {
urlSuffix = mpAuthentication.getAuthenticationChannel().getPathAfterSuccessfulAuthentication();
Expand Down Expand Up @@ -139,23 +122,33 @@ public void onAuthenticationSuccess(HttpServletRequest request, HttpServletRespo
super.onAuthenticationSuccess(request, response, authentication);
}

private AuthModule getUnauthenticatedModule(MidpointAuthentication mpAuthentication) {
if (CollectionUtils.isEmpty(mpAuthentication.getAuthModules())) {
return null;
private boolean isNewSecurityPolicyFound(MidpointAuthentication mpAuthentication) {
if (mpAuthentication.getPrincipal() == null || !(mpAuthentication.getPrincipal() instanceof MidPointPrincipal)) {
return false;
}
return mpAuthentication.getAuthModules().stream()
.filter(module -> !authModuleAlreadyProcessed(mpAuthentication, module.getModuleIdentifier()))
.findFirst()
.orElse(null);
}


private boolean authModuleAlreadyProcessed(MidpointAuthentication mpAuthentication, String moduleIdentifier) {
if (CollectionUtils.isEmpty(mpAuthentication.getAuthentications())) {
if (mpAuthentication.isMerged()) {
return false;
}
return mpAuthentication.getAuthentications().stream().anyMatch(auth -> auth.getModuleIdentifier().equals(moduleIdentifier)
&& AuthenticationModuleState.SUCCESSFULLY.equals(auth.getState()));
MidPointPrincipal principal = (MidPointPrincipal) mpAuthentication.getPrincipal();
SecurityPolicyType securityPolicy = principal.getApplicableSecurityPolicy();
if (securityPolicy == null) {
return false;
}
AuthenticationSequenceType processingSequence = mpAuthentication.getSequence();
AuthenticationSequenceType sequence = SecurityPolicyUtil.findSequenceByIdentifier(securityPolicy, processingSequence.getIdentifier());
return processingSequence.getModule().size() != sequence.getModule().size();
}

private void updateMidpointAuthentication(HttpServletRequest request, MidpointAuthentication mpAuthentication, SecurityPolicyType newSecurityPolicy) {
AuthenticationSequenceType processingSequence = mpAuthentication.getSequence();
AuthenticationSequenceType sequence = SecurityPolicyUtil.findSequenceByIdentifier(newSecurityPolicy, processingSequence.getIdentifier());
mpAuthentication.setSequence(sequence);
List<AuthModule> modules = AuthSequenceUtil.buildModuleFilters(
authModuleRegistry, sequence, request, newSecurityPolicy.getAuthentication().getModules(),
newSecurityPolicy.getCredentials(), mpAuthentication.getSharedObjects(), mpAuthentication.getAuthenticationChannel());
modules.removeIf(Objects::isNull);
mpAuthentication.setAuthModules(modules);
mpAuthentication.setMerged(true);
}

@Override
Expand Down

0 comments on commit d8a8307

Please sign in to comment.