Skip to content

Commit

Permalink
allow logout from SAML IDP for unathorization user and fix redirect f…
Browse files Browse the repository at this point in the history
…or internal login (MID-7076)
  • Loading branch information
skublik committed Aug 16, 2021
1 parent 6dbcce4 commit f887364
Show file tree
Hide file tree
Showing 21 changed files with 177 additions and 34 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,10 @@ protected void createBreadcrumb() {
@Override
protected void onBeforeRender() {
super.onBeforeRender();
confirmUserPrincipal();
}

protected void confirmUserPrincipal() {
if (SecurityUtils.getPrincipalUser() != null) {
MidPointApplication app = getMidpointApplication();
throw new RestartResponseException(app.getHomePage());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,6 @@ private String getUrlProcessingLogin() {
}
}

return "/midpoint/spring_security_login";
return "./spring_security_login";
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,17 @@
<body>
<wicket:extend>
<h4><wicket:message key="PageSamlSelect.select.identity.provider"/></h4>
<div wicket:id="providers">
<div wicket:id="providers" class="col-md-12">
<a wicket:id="provider"/>
</div>
<div class="pull-right">
<form method="post" wicket:id="logoutForm">

<div wicket:id="csrfField"/>

<input type="submit" wicket:message="value:UserMenuPanel.logout"/>
</form>
</div>
</wicket:extend>

</body>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,17 @@
import java.util.ArrayList;
import java.util.List;

import com.evolveum.midpoint.gui.api.util.WebComponentUtil;
import com.evolveum.midpoint.web.component.form.MidpointForm;
import com.evolveum.midpoint.web.component.util.VisibleBehaviour;
import com.evolveum.midpoint.web.security.util.SecurityUtils;

import org.apache.wicket.AttributeModifier;
import org.apache.wicket.markup.html.WebMarkupContainer;
import org.apache.wicket.markup.html.link.ExternalLink;
import org.apache.wicket.markup.html.list.ListItem;
import org.apache.wicket.markup.html.list.ListView;
import org.apache.wicket.model.IModel;
import org.springframework.security.core.Authentication;
import org.springframework.security.core.context.SecurityContextHolder;

Expand All @@ -23,6 +31,8 @@
import com.evolveum.midpoint.web.security.module.authentication.Saml2ModuleAuthentication;
import com.evolveum.midpoint.web.security.util.IdentityProvider;

import org.springframework.security.saml.SamlAuthentication;

/**
* @author skublik
*/
Expand All @@ -32,18 +42,38 @@
public class PageSamlSelect extends AbstractPageLogin implements Serializable {
private static final long serialVersionUID = 1L;

private static final String ID_PROVIDERS = "providers";
private static final String ID_PROVIDER = "provider";
private static final String ID_LOGOUT_FORM = "logoutForm";
private static final String ID_CSRF_FIELD = "csrfField";

public PageSamlSelect() {
}

@Override
protected void initCustomLayer() {
List<IdentityProvider> providers = getProviders();
add(new ListView<IdentityProvider>("providers", providers) {
add(new ListView<IdentityProvider>(ID_PROVIDERS, providers) {
@Override
protected void populateItem(ListItem<IdentityProvider> item) {
item.add(new ExternalLink("provider", item.getModelObject().getRedirectLink(), item.getModelObject().getLinkText()));
item.add(new ExternalLink(ID_PROVIDER, item.getModelObject().getRedirectLink(), item.getModelObject().getLinkText()));
}
});
MidpointForm<?> form = new MidpointForm<>(ID_LOGOUT_FORM);
ModuleAuthentication actualModule = SecurityUtils.getProcessingModule(false);
form.add(new VisibleBehaviour(() -> existSamlAuthentication(actualModule)));
form.add(AttributeModifier.replace("action",
(IModel<String>) () -> existSamlAuthentication(actualModule) ?
SecurityUtils.getPathForLogoutWithContextPath(getRequest().getContextPath(), actualModule) : ""));
add(form);

WebMarkupContainer csrfField = SecurityUtils.createHiddenInputForCsrf(ID_CSRF_FIELD);
form.add(csrfField);
}

private boolean existSamlAuthentication(ModuleAuthentication actualModule) {
return actualModule instanceof Saml2ModuleAuthentication
&& actualModule.getAuthentication() instanceof SamlAuthentication;
}

private List<IdentityProvider> getProviders() {
Expand All @@ -68,4 +98,8 @@ private List<IdentityProvider> getProviders() {
error(getString(key));
return providers;
}

@Override
protected void confirmUserPrincipal() {
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,9 @@
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;

import com.evolveum.midpoint.model.api.authentication.StateOfModule;
import com.evolveum.midpoint.web.security.module.SamlModuleWebSecurityConfig;

import org.springframework.security.core.Authentication;
import org.springframework.security.core.AuthenticationException;
import org.springframework.security.core.context.SecurityContextHolder;
Expand Down Expand Up @@ -40,9 +43,13 @@ public void commence(HttpServletRequest request, HttpServletResponse response, A
ModuleAuthentication moduleAuthentication = mpAuthentication.getProcessingModuleAuthentication();
if (moduleAuthentication instanceof Saml2ModuleAuthentication) {
providers = ((Saml2ModuleAuthentication) moduleAuthentication).getProviders();
if (providers.size() == 1
&& request.getSession().getAttribute("SPRING_SECURITY_LAST_EXCEPTION") == null) {
response.sendRedirect(providers.get(0).getRedirectLink());
if (request.getSession().getAttribute("SPRING_SECURITY_LAST_EXCEPTION") == null) {
if (providers.size() == 1) {
response.sendRedirect(providers.get(0).getRedirectLink());
return;
}
} else if (SamlModuleWebSecurityConfig.SAML_LOGIN_PATH.equals(request.getServletPath())
&& StateOfModule.LOGIN_PROCESSING.equals(moduleAuthentication.getState())) {
return;
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -243,12 +243,12 @@ private int restartAuthFlow(HttpServletRequest httpRequest, AuthenticationSequen
}

private void createMpAuthentication(HttpServletRequest httpRequest, AuthenticationSequenceType sequence, List<AuthModule> authModules) {
SecurityContextHolder.getContext().setAuthentication(null);
SecurityContextHolder.getContext().setAuthentication(new MidpointAuthentication(sequence));
MidpointAuthentication mpAuthentication = (MidpointAuthentication) SecurityContextHolder.getContext().getAuthentication();
MidpointAuthentication mpAuthentication = new MidpointAuthentication(sequence);
mpAuthentication.setAuthModules(authModules);
mpAuthentication.setSessionId(httpRequest.getSession(false) != null ? httpRequest.getSession(false).getId() : RandomStringUtils.random(30, true, true).toUpperCase());
mpAuthentication.addAuthentications(authModules.get(0).getBaseModuleAuthentication());
SecurityContextHolder.getContext().setAuthentication(null);
SecurityContextHolder.getContext().setAuthentication(mpAuthentication);
}

private void resolveErrorWithMoreModules(MidpointAuthentication mpAuthentication, HttpServletRequest httpRequest) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@
public class SamlModuleWebSecurityConfig<C extends SamlModuleWebSecurityConfiguration> extends ModuleWebSecurityConfig<C> {

private static final Trace LOGGER = TraceManager.getTrace(SamlModuleWebSecurityConfig.class);
public static final String SAML_LOGIN_PATH = "/saml2/select";

@Autowired
private ModelAuditRecorder auditProvider;
Expand All @@ -76,7 +77,7 @@ protected void configure(HttpSecurity http) throws Exception {
http.csrf().disable();

getOrApply(http, new MidpointExceptionHandlingConfigurer())
.authenticationEntryPoint(new SamlAuthenticationEntryPoint("/saml2/select"));
.authenticationEntryPoint(new SamlAuthenticationEntryPoint(SAML_LOGIN_PATH));

http.addFilterAfter(
getBeanConfiguration().samlConfigurationFilter(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ public void setCredentialType(Class<? extends CredentialPolicyType> credentialTy

public ModuleAuthentication clone() {
CredentialModuleAuthentication module = new CredentialModuleAuthentication(this.getNameOfModuleType());
module.setAuthentication(this.getAuthentication());
clone(module);
return module;
}
Expand All @@ -54,6 +55,7 @@ protected void clone(ModuleAuthentication module) {
((CredentialModuleAuthentication)module).setCredentialName(getCredentialName());
((CredentialModuleAuthentication)module).setCredentialType(getCredentialType());
}
module.setAuthentication(this.getAuthentication());
super.clone(module);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ public HttpHeaderModuleAuthentication() {

public ModuleAuthentication clone() {
HttpHeaderModuleAuthentication module = new HttpHeaderModuleAuthentication();
module.setAuthentication(this.getAuthentication());
clone(module);
return module;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ protected void clone(ModuleAuthentication module) {
((HttpModuleAuthentication)module).setProxyUserOid(this.getProxyUserOid());
((HttpModuleAuthentication)module).setRealm(this.getRealm());
}
module.setAuthentication(this.getAuthentication());
super.clone(module);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ public LdapModuleAuthentication() {
public ModuleAuthentication clone() {
LdapModuleAuthentication module = new LdapModuleAuthentication();
module.setNamingAttribute(this.namingAttribute);
module.setAuthentication(this.getAuthentication());
clone(module);
return module;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ public LoginFormModuleAuthentication() {

public ModuleAuthentication clone() {
LoginFormModuleAuthentication module = new LoginFormModuleAuthentication();
module.setAuthentication(this.getAuthentication());
super.clone(module);
return module;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ public MailNonceModuleAuthentication() {

public ModuleAuthentication clone() {
MailNonceModuleAuthentication module = new MailNonceModuleAuthentication();
module.setAuthentication(this.getAuthentication());
super.clone(module);
return module;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ public OtherModuleAuthentication() {

public ModuleAuthentication clone() {
OtherModuleAuthentication module = new OtherModuleAuthentication();
module.setAuthentication(this.getAuthentication());
clone(module);
return module;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,14 @@
*/
package com.evolveum.midpoint.web.security.module.authentication;

import com.evolveum.midpoint.model.api.authentication.ModuleAuthentication;
import com.evolveum.midpoint.model.api.authentication.AuthenticationModuleNameConstants;
import com.evolveum.midpoint.model.api.authentication.ModuleType;
import com.evolveum.midpoint.model.api.authentication.StateOfModule;
import com.evolveum.midpoint.model.api.authentication.*;
import com.evolveum.midpoint.web.security.util.IdentityProvider;
import com.evolveum.midpoint.web.security.util.RequestState;
import com.evolveum.midpoint.web.security.util.SecurityUtils;

import org.springframework.security.core.Authentication;
import org.springframework.security.core.context.SecurityContextHolder;
import org.springframework.security.saml.SamlAuthentication;

import java.util.ArrayList;
import java.util.List;
Expand Down Expand Up @@ -62,6 +64,18 @@ public ModuleAuthentication clone() {
Saml2ModuleAuthentication module = new Saml2ModuleAuthentication();
module.setNamesOfUsernameAttributes(this.getNamesOfUsernameAttributes());
module.setProviders(this.getProviders());
Authentication actualAuth = SecurityContextHolder.getContext().getAuthentication();
Authentication newAuthentication = this.getAuthentication();
if (actualAuth instanceof MidpointAuthentication
&& ((MidpointAuthentication) actualAuth).getAuthentications() != null
&& !((MidpointAuthentication) actualAuth).getAuthentications().isEmpty()) {
ModuleAuthentication actualModule = ((MidpointAuthentication) actualAuth).getAuthentications().get(0);
if (actualModule instanceof Saml2ModuleAuthentication
&& actualModule.getAuthentication() instanceof SamlAuthentication) {
newAuthentication = actualModule.getAuthentication();
}
}
module.setAuthentication(newAuthentication);
super.clone(module);
return module;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ public SecurityQuestionFormModuleAuthentication() {

public ModuleAuthentication clone() {
SecurityQuestionFormModuleAuthentication module = new SecurityQuestionFormModuleAuthentication();
module.setAuthentication(this.getAuthentication());
super.clone(module);
return module;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import java.util.List;

import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Qualifier;
import org.springframework.security.authentication.AuthenticationServiceException;
import org.springframework.security.core.Authentication;
import org.springframework.security.core.AuthenticationException;
Expand Down Expand Up @@ -40,6 +41,7 @@ public class Saml2Provider extends MidPointAbstractAuthenticationProvider {
private static final Trace LOGGER = TraceManager.getTrace(Saml2Provider.class);

@Autowired
@Qualifier("samlAuthenticationEvaluator")
private AuthenticationEvaluator<PasswordAuthenticationContext> authenticationEvaluator;

@Override
Expand All @@ -63,11 +65,11 @@ protected Authentication internalAuthentication(Authentication authentication, L
AuthenticationChannel channel, Class focusType) throws AuthenticationException {
ConnectionEnvironment connEnv = createEnvironment(channel);

try {
Authentication token;
if (authentication instanceof DefaultSamlAuthentication) {
DefaultSamlAuthentication samlAuthentication = (DefaultSamlAuthentication) authentication;
Saml2ModuleAuthentication samlModule = (Saml2ModuleAuthentication) SecurityUtils.getProcessingModule(true);
Authentication token;
if (authentication instanceof DefaultSamlAuthentication) {
DefaultSamlAuthentication samlAuthentication = (DefaultSamlAuthentication) authentication;
Saml2ModuleAuthentication samlModule = (Saml2ModuleAuthentication) SecurityUtils.getProcessingModule(true);
try {
List<Attribute> attributes = ((DefaultSamlAuthentication) authentication).getAssertion().getAttributes();
String enteredUsername = "";
for (Attribute attribute : attributes) {
Expand All @@ -91,21 +93,22 @@ protected Authentication internalAuthentication(Authentication authentication, L
authContext.setSupportActivationByChannel(channel.isSupportActivationByChannel());
}
token = authenticationEvaluator.authenticateUserPreAuthenticated(connEnv, authContext);
} else {
LOGGER.error("Unsupported authentication {}", authentication);
throw new AuthenticationServiceException("web.security.provider.unavailable");
} catch (AuthenticationException e) {
samlModule.setAuthentication(samlAuthentication);
LOGGER.info("Authentication with saml module failed: {}", e.getMessage());
throw e;
}
} else {
LOGGER.error("Unsupported authentication {}", authentication);
throw new AuthenticationServiceException("web.security.provider.unavailable");
}

MidPointPrincipal principal = (MidPointPrincipal) token.getPrincipal();
MidPointPrincipal principal = (MidPointPrincipal) token.getPrincipal();

LOGGER.debug("User '{}' authenticated ({}), authorities: {}", authentication.getPrincipal(),
authentication.getClass().getSimpleName(), principal.getAuthorities());
return token;
LOGGER.debug("User '{}' authenticated ({}), authorities: {}", authentication.getPrincipal(),
authentication.getClass().getSimpleName(), principal.getAuthorities());
return token;

} catch (AuthenticationException e) {
LOGGER.info("Authentication with saml module failed: {}", e.getMessage());
throw e;
}
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,6 @@ public ModuleAuthentication clone() {

protected void clone (ModuleAuthentication module) {
module.setState(this.getState());
module.setAuthentication(this.getAuthentication());
module.setNameOfModule(this.nameOfModule);
module.setType(this.getType());
module.setPrefix(this.getPrefix());
Expand Down
9 changes: 9 additions & 0 deletions model/model-impl/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -238,6 +238,15 @@
</exclusions>
</dependency>

<dependency>
<groupId>jakarta.servlet</groupId>
<artifactId>jakarta.servlet-api</artifactId>
</dependency>
<dependency>
<groupId>org.springframework</groupId>
<artifactId>spring-web</artifactId>
</dependency>

<!-- Test -->
<dependency>
<groupId>com.evolveum.midpoint</groupId>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -255,7 +255,7 @@ public PreAuthenticatedAuthenticationToken authenticateUserPreAuthenticated(Conn
}

@NotNull
private MidPointPrincipal getAndCheckPrincipal(ConnectionEnvironment connEnv, String enteredUsername, Class<? extends FocusType> clazz,
protected MidPointPrincipal getAndCheckPrincipal(ConnectionEnvironment connEnv, String enteredUsername, Class<? extends FocusType> clazz,
boolean supportsActivationCheck) {

if (StringUtils.isBlank(enteredUsername)) {
Expand Down Expand Up @@ -299,7 +299,7 @@ private MidPointPrincipal getAndCheckPrincipal(ConnectionEnvironment connEnv, St
return principal;
}

private boolean hasAnyAuthorization(MidPointPrincipal principal) {
protected boolean hasAnyAuthorization(MidPointPrincipal principal) {
Collection<Authorization> authorizations = principal.getAuthorities();
if (authorizations == null || authorizations.isEmpty()){
return false;
Expand Down

0 comments on commit f887364

Please sign in to comment.