Skip to content

Commit

Permalink
cleanu for module fatories - configurer now creates configuration and…
Browse files Browse the repository at this point in the history
… filter chain
  • Loading branch information
katkav committed Jul 27, 2023
1 parent 5837547 commit 64eccab
Show file tree
Hide file tree
Showing 41 changed files with 645 additions and 202 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -12,29 +12,21 @@

import com.evolveum.midpoint.authentication.api.config.ModuleAuthentication;

import com.evolveum.midpoint.authentication.impl.filter.RefuseUnauthenticatedRequestFilter;

import jakarta.servlet.ServletRequest;

import com.evolveum.midpoint.authentication.impl.module.configurer.ModuleWebSecurityConfigurer;
import com.evolveum.midpoint.authentication.impl.util.AuthModuleImpl;
import com.evolveum.midpoint.authentication.api.AuthModule;
import com.evolveum.midpoint.authentication.api.AuthenticationChannel;
import com.evolveum.midpoint.authentication.impl.module.authentication.ModuleAuthenticationImpl;
import com.evolveum.midpoint.authentication.api.ModuleWebSecurityConfiguration;

import org.apache.commons.lang3.StringUtils;
import org.springframework.security.authentication.AuthenticationProvider;
import org.springframework.security.config.annotation.ObjectPostProcessor;
import org.springframework.security.config.annotation.web.builders.HttpSecurity;
import org.springframework.security.web.SecurityFilterChain;

import com.evolveum.midpoint.util.logging.Trace;
import com.evolveum.midpoint.util.logging.TraceManager;
import com.evolveum.midpoint.xml.ns._public.common.common_3.*;

import org.springframework.security.web.authentication.switchuser.SwitchUserFilter;

/**
* @author skublik
*/
Expand All @@ -43,7 +35,7 @@ public abstract class AbstractCredentialModuleFactory<
CA extends ModuleWebSecurityConfigurer<C, MT>,
MT extends AbstractAuthenticationModuleType,
MA extends ModuleAuthentication>
extends AbstractModuleFactory<MT, MA> {
extends AbstractModuleFactory<C, CA, MT, MA> {

private static final Trace LOGGER = TraceManager.getTrace(AbstractCredentialModuleFactory.class);

Expand Down Expand Up @@ -76,20 +68,21 @@ public AuthModule<MA> createModuleFilter(
// getProvider((AbstractCredentialAuthenticationModuleType) moduleType, credentialPolicy));


CA moduleConfigurer = getObjectObjectPostProcessor()
.postProcess(createModuleConfigurer(moduleType, sequenceSuffix, authenticationChannel, getObjectObjectPostProcessor()));

HttpSecurity http = moduleConfigurer.getNewHttpSecurity();
http.addFilterAfter(new RefuseUnauthenticatedRequestFilter(), SwitchUserFilter.class);
setSharedObjects(http, sharedObjects);

SecurityFilterChain filter = http.build();

// CA moduleConfigurer = getObjectObjectPostProcessor()
// .postProcess(createModuleConfigurer(moduleType, sequenceSuffix, authenticationChannel, getObjectObjectPostProcessor()));

MA moduleAuthentication = createEmptyModuleAuthentication(moduleType, moduleConfigurer.getConfiguration(), necessity);
moduleAuthentication.setFocusType(moduleType.getFocusType());
// HttpSecurity http = moduleConfigurer.getNewHttpSecurity();
// http.addFilterAfter(new RefuseUnauthenticatedRequestFilter(), SwitchUserFilter.class);
// setSharedObjects(http, sharedObjects);
//
// SecurityFilterChain filter = http.build();
//
//
// MA moduleAuthentication = createEmptyModuleAuthentication(moduleType, moduleConfigurer.getConfiguration(), necessity);
// moduleAuthentication.setFocusType(moduleType.getFocusType());

return AuthModuleImpl.build(filter, moduleConfigurer.getConfiguration(), moduleAuthentication);
// return AuthModuleImpl.build(filter, moduleConfigurer.getConfiguration(), moduleAuthentication);
return null;
}


Expand Down Expand Up @@ -157,13 +150,13 @@ private String getCredentialAuthModuleIdentifier(AbstractCredentialAuthenticatio
}

protected abstract MA createEmptyModuleAuthentication(
MT moduleType, C configuration, AuthenticationSequenceModuleType sequenceModule);
MT moduleType, C configuration, AuthenticationSequenceModuleType sequenceModule, ServletRequest request);


protected abstract CA createModuleConfigurer(MT moduleType,
String sequenceSuffix,
AuthenticationChannel authenticationChannel,
ObjectPostProcessor<Object> objectPostProcessor);
ObjectPostProcessor<Object> objectPostProcessor, ServletRequest request);


protected abstract AuthenticationProvider createProvider(CredentialPolicyType usedPolicy);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,14 @@

import java.util.Map;

import com.evolveum.midpoint.authentication.api.ModuleWebSecurityConfiguration;
import com.evolveum.midpoint.authentication.api.config.ModuleAuthentication;

import com.evolveum.midpoint.authentication.impl.util.AuthModuleImpl;

import com.evolveum.midpoint.util.logging.Trace;
import com.evolveum.midpoint.util.logging.TraceManager;

import jakarta.annotation.PostConstruct;
import jakarta.servlet.ServletRequest;

Expand All @@ -28,13 +34,20 @@

import com.evolveum.midpoint.schema.constants.SchemaConstants;

import org.springframework.security.web.SecurityFilterChain;
import org.springframework.security.web.authentication.switchuser.SwitchUserFilter;

/**
* @author skublik
*/

public abstract class AbstractModuleFactory<MT extends AbstractAuthenticationModuleType, MA extends ModuleAuthentication> {
public abstract class AbstractModuleFactory<
C extends ModuleWebSecurityConfiguration,
CA extends ModuleWebSecurityConfigurer<C, MT>,
MT extends AbstractAuthenticationModuleType,
MA extends ModuleAuthentication> implements ModuleFactory<MT, MA> {

private static final Trace LOGGER = TraceManager.getTrace(AbstractModuleFactory.class);

@PostConstruct
public void register() {
Expand Down Expand Up @@ -62,7 +75,72 @@ public abstract AuthModule<MA> createModuleFilter(MT moduleType, String sequence
AuthenticationModulesType authenticationsPolicy, CredentialsPolicyType credentialPolicy,
AuthenticationChannel authenticationChannel, AuthenticationSequenceModuleType sequenceModule) throws Exception;

protected Integer getOrder(){
@Override
public AuthModule<MA> createAuthModule(MT moduleType, String sequenceSuffix,
ServletRequest request, Map<Class<?>, Object> sharedObjects,
AuthenticationModulesType authenticationsPolicy, CredentialsPolicyType credentialPolicy,
AuthenticationChannel authenticationChannel, AuthenticationSequenceModuleType sequenceModule) throws Exception {

validateChanelAndModule(authenticationChannel, moduleType);


//TODO PROVIDERS
// configuration.addAuthenticationProvider(
// getProvider((AbstractCredentialAuthenticationModuleType) moduleType, credentialPolicy));


CA moduleConfigurer = getObjectObjectPostProcessor()
.postProcess(createModuleConfigurer(moduleType, sequenceSuffix, authenticationChannel, getObjectObjectPostProcessor(), request));

HttpSecurity http = moduleConfigurer.getNewHttpSecurity();
http.addFilterAfter(new RefuseUnauthenticatedRequestFilter(), SwitchUserFilter.class);
setSharedObjects(http, sharedObjects);

SecurityFilterChain filter = http.build();
postProcessFilter(filter, moduleConfigurer);


MA moduleAuthentication = createEmptyModuleAuthentication(moduleType, moduleConfigurer.getConfiguration(), sequenceModule, request);
moduleAuthentication.setFocusType(moduleType.getFocusType());

return AuthModuleImpl.build(filter, moduleConfigurer.getConfiguration(), moduleAuthentication);
}

protected void postProcessFilter(SecurityFilterChain filter, CA configurer) {
// Nothing to do here. Subclasses may override.
}

protected void validateChanelAndModule(AuthenticationChannel authenticationChannel, MT moduleType) {
if (!(moduleType instanceof AbstractCredentialAuthenticationModuleType)) {
LOGGER.error("This factory supports only AbstractCredentialAuthenticationModuleType, but moduleType is " + moduleType);
throw new IllegalArgumentException("Unsupported factory " + this.getClass().getSimpleName()
+ " for module " + moduleType);
}


if (authenticationChannel == null) {
return;
}

//TODO chanel
if (SchemaConstants.CHANNEL_SELF_REGISTRATION_URI.equals(authenticationChannel.getChannelId())) {
throw new IllegalArgumentException("Unsupported factory " + this.getClass().getSimpleName()
+ " for channel " + authenticationChannel.getChannelId());
}
}

protected abstract CA createModuleConfigurer(MT moduleType,
String sequenceSuffix,
AuthenticationChannel authenticationChannel,
ObjectPostProcessor<Object> objectPostProcessor, ServletRequest request);

protected abstract MA createEmptyModuleAuthentication(
MT moduleType, C configuration,
AuthenticationSequenceModuleType sequenceModule,
ServletRequest request);


public Integer getOrder() {
return 0;
}

Expand All @@ -82,11 +160,11 @@ protected void isSupportedChannel(AuthenticationChannel authenticationChannel) {
}
}

HttpSecurity getNewHttpSecurity(ModuleWebSecurityConfigurer module) throws Exception {
// module.setObjectPostProcessor(getObjectObjectPostProcessor());
HttpSecurity httpSecurity = module.getNewHttpSecurity();
httpSecurity.addFilterAfter(new RefuseUnauthenticatedRequestFilter(), SwitchUserFilter.class);
return httpSecurity;
}
// HttpSecurity getNewHttpSecurity(ModuleWebSecurityConfigurer module) throws Exception {
//// module.setObjectPostProcessor(getObjectObjectPostProcessor());
// HttpSecurity httpSecurity = module.getNewHttpSecurity();
// httpSecurity.addFilterAfter(new RefuseUnauthenticatedRequestFilter(), SwitchUserFilter.class);
// return httpSecurity;
// }

}
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,14 @@

import com.evolveum.midpoint.authentication.api.AuthenticationChannel;
import com.evolveum.midpoint.authentication.impl.module.authentication.ArchetypeSelectionModuleAuthentication;
import com.evolveum.midpoint.authentication.impl.module.authentication.ModuleAuthenticationImpl;
import com.evolveum.midpoint.authentication.impl.module.configuration.LoginFormModuleWebSecurityConfiguration;
import com.evolveum.midpoint.authentication.impl.module.configurer.ArchetypeSelectionModuleWebSecurityConfigurer;
import com.evolveum.midpoint.authentication.impl.provider.ArchetypeSelectionAuthenticationProvider;
import com.evolveum.midpoint.util.logging.Trace;
import com.evolveum.midpoint.util.logging.TraceManager;
import com.evolveum.midpoint.xml.ns._public.common.common_3.*;

import jakarta.servlet.ServletRequest;
import org.springframework.security.authentication.AuthenticationProvider;
import org.springframework.security.config.annotation.ObjectPostProcessor;
import org.springframework.stereotype.Component;
Expand All @@ -37,7 +37,7 @@ public boolean match(AbstractAuthenticationModuleType moduleType, Authentication

@Override
protected ArchetypeSelectionModuleAuthentication createEmptyModuleAuthentication(ArchetypeSelectionModuleType moduleType,
LoginFormModuleWebSecurityConfiguration configuration, AuthenticationSequenceModuleType sequenceModule) {
LoginFormModuleWebSecurityConfiguration configuration, AuthenticationSequenceModuleType sequenceModule, ServletRequest request) {
ArchetypeSelectionModuleAuthentication moduleAuthentication = new ArchetypeSelectionModuleAuthentication(sequenceModule);
moduleAuthentication.setPrefix(configuration.getPrefixOfModule());
moduleAuthentication.setCredentialName(moduleType.getCredentialName());
Expand All @@ -51,8 +51,8 @@ protected ArchetypeSelectionModuleWebSecurityConfigurer<LoginFormModuleWebSecuri
ArchetypeSelectionModuleType moduleType,
String sequenceSuffix,
AuthenticationChannel authenticationChannel,
ObjectPostProcessor<Object> objectPostProcessor) {
return new ArchetypeSelectionModuleWebSecurityConfigurer<>(moduleType, sequenceSuffix, authenticationChannel, objectPostProcessor);
ObjectPostProcessor<Object> objectPostProcessor, ServletRequest request) {
return new ArchetypeSelectionModuleWebSecurityConfigurer<>(moduleType, sequenceSuffix, authenticationChannel, objectPostProcessor, request);
// return null;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,12 @@

import com.evolveum.midpoint.authentication.api.AuthenticationChannel;
import com.evolveum.midpoint.authentication.impl.module.authentication.AttributeVerificationModuleAuthentication;
import com.evolveum.midpoint.authentication.impl.module.authentication.ModuleAuthenticationImpl;
import com.evolveum.midpoint.authentication.impl.module.configuration.LoginFormModuleWebSecurityConfiguration;
import com.evolveum.midpoint.authentication.impl.module.configurer.ArchetypeSelectionModuleWebSecurityConfigurer;
import com.evolveum.midpoint.authentication.impl.module.configurer.AttributeVerificationModuleWebSecurityConfigurer;
import com.evolveum.midpoint.authentication.impl.provider.AttributeVerificationProvider;
import com.evolveum.midpoint.xml.ns._public.common.common_3.*;

import jakarta.servlet.ServletRequest;
import org.springframework.security.authentication.AuthenticationProvider;
import org.springframework.security.config.annotation.ObjectPostProcessor;
import org.springframework.stereotype.Component;
Expand All @@ -36,8 +35,8 @@ protected AttributeVerificationModuleWebSecurityConfigurer<LoginFormModuleWebSec
AttributeVerificationAuthenticationModuleType moduleType,
String sequenceSuffix,
AuthenticationChannel authenticationChannel,
ObjectPostProcessor<Object> objectPostProcessor) {
return new AttributeVerificationModuleWebSecurityConfigurer<>(moduleType, sequenceSuffix, authenticationChannel, objectPostProcessor);
ObjectPostProcessor<Object> objectPostProcessor, ServletRequest request) {
return new AttributeVerificationModuleWebSecurityConfigurer<>(moduleType, sequenceSuffix, authenticationChannel, objectPostProcessor, request);
// return null;
}

Expand All @@ -53,7 +52,7 @@ protected Class<? extends CredentialPolicyType> supportedClass() {

@Override
protected AttributeVerificationModuleAuthentication createEmptyModuleAuthentication(AttributeVerificationAuthenticationModuleType moduleType,
LoginFormModuleWebSecurityConfiguration configuration, AuthenticationSequenceModuleType sequenceModule) {
LoginFormModuleWebSecurityConfiguration configuration, AuthenticationSequenceModuleType sequenceModule, ServletRequest request) {
AttributeVerificationModuleAuthentication moduleAuthentication = new AttributeVerificationModuleAuthentication(sequenceModule);
moduleAuthentication.setPrefix(configuration.getPrefixOfModule());
moduleAuthentication.setCredentialName(moduleType.getCredentialName());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@

import com.evolveum.midpoint.authentication.api.AuthenticationChannel;

import com.evolveum.midpoint.authentication.api.config.ModuleAuthentication;

import org.springframework.stereotype.Component;

import com.evolveum.midpoint.util.logging.Trace;
Expand All @@ -27,12 +29,12 @@ public class AuthModuleRegistryImpl {

private static final Trace LOGGER = TraceManager.getTrace(AuthModuleRegistryImpl.class);

List<AbstractModuleFactory> moduleFactories = new ArrayList<>();
List<ModuleFactory> moduleFactories = new ArrayList<>();

public void addToRegistry(AbstractModuleFactory factory) {
public void addToRegistry(ModuleFactory factory) {
moduleFactories.add(factory);

Comparator<? super AbstractModuleFactory> comparator =
Comparator<? super ModuleFactory> comparator =
(f1,f2) -> {

Integer f1Order = f1.getOrder();
Expand All @@ -57,26 +59,30 @@ public void addToRegistry(AbstractModuleFactory factory) {

}

public AbstractModuleFactory findModuleFactory(AbstractAuthenticationModuleType configuration, AuthenticationChannel authenticationChannel) {
public <MT extends AbstractAuthenticationModuleType, MA extends ModuleAuthentication> ModuleFactory<MT, MA> findModuleFactory(
AbstractAuthenticationModuleType configuration, AuthenticationChannel authenticationChannel) {

Optional<AbstractModuleFactory> opt = moduleFactories.stream().filter(f -> f.match(configuration, authenticationChannel)).findFirst();
Optional<ModuleFactory> opt = moduleFactories.stream().filter(f -> f.match(configuration, authenticationChannel)).findFirst();
if (opt.isEmpty()) {
LOGGER.trace("No factory found for {}", configuration);
return null;
}
AbstractModuleFactory factory = opt.get();
ModuleFactory factory = opt.get();
LOGGER.trace("Found component factory {} for {}", factory, configuration);
return factory;
}

public <T extends AbstractModuleFactory> T findModelFactoryByClass(Class<T> clazz) {

Optional<T> opt = (Optional<T>) moduleFactories.stream().filter(f -> f.getClass().equals(clazz)).findFirst();
if (opt.isEmpty()) {
LOGGER.trace("No factory found for class {}", clazz);
return null;
}
T factory = opt.get();
public <T extends ModuleFactory> T findModuleFactoryByClass(Class<T> clazz) {

T factory = (T) moduleFactories.stream()
.filter(f -> f.getClass().equals(clazz))
.findFirst()
.orElse(null);
// if (opt.isEmpty()) {
// LOGGER.trace("No factory found for class {}", clazz);
// return null;
// }
// T factory = opt.get();
LOGGER.trace("Found component factory {} for class {}", factory, clazz);
return factory;
}
Expand Down

0 comments on commit 64eccab

Please sign in to comment.