Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -42,29 +42,6 @@
</execution>
</executions>
</plugin>
<plugin>
<groupId>org.patrodyne.jvnet</groupId>
<artifactId>hisrc-higherjaxb40-maven-plugin</artifactId>
<executions>
<execution>
<id>current</id>
<goals>
<goal>generate</goal>
</goals>
<configuration>
<generatePackage>org.apache.nifi.authentication.generated</generatePackage>
<schemaDirectory>src/main/xsd</schemaDirectory>
</configuration>
</execution>
</executions>
</plugin>
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-checkstyle-plugin</artifactId>
<configuration>
<excludes>**/authentication/generated/*.java,</excludes>
</configuration>
</plugin>
<plugin>
<groupId>org.apache.rat</groupId>
<artifactId>apache-rat-plugin</artifactId>
Expand Down Expand Up @@ -209,10 +186,6 @@
<groupId>org.springframework</groupId>
<artifactId>spring-context</artifactId>
</dependency>
<dependency>
<groupId>jakarta.xml.bind</groupId>
<artifactId>jakarta.xml.bind-api</artifactId>
</dependency>
<dependency>
<groupId>org.springframework.security</groupId>
<artifactId>spring-security-saml2-service-provider</artifactId>
Expand Down Expand Up @@ -308,10 +281,6 @@
<artifactId>nifi-web-client</artifactId>
<version>2.7.0-SNAPSHOT</version>
</dependency>
<dependency>
<groupId>org.glassfish.jaxb</groupId>
<artifactId>jaxb-runtime</artifactId>
</dependency>
<dependency>
<groupId>com.squareup.okhttp3</groupId>
<artifactId>mockwebserver3</artifactId>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
package org.apache.nifi.web.security.spring;

import java.io.File;
import java.io.FileInputStream;
import java.io.InputStream;
import java.lang.reflect.Constructor;
import java.lang.reflect.Field;
import java.lang.reflect.InvocationTargetException;
Expand All @@ -25,12 +27,8 @@
import java.util.List;
import java.util.Map;
import javax.xml.XMLConstants;
import jakarta.xml.bind.JAXBContext;
import jakarta.xml.bind.JAXBElement;
import jakarta.xml.bind.JAXBException;
import jakarta.xml.bind.Unmarshaller;
import javax.xml.stream.XMLStreamReader;
import javax.xml.transform.stream.StreamSource;
import javax.xml.transform.Source;
import javax.xml.transform.dom.DOMSource;
import javax.xml.validation.Schema;
import javax.xml.validation.SchemaFactory;
import org.apache.commons.lang3.StringUtils;
Expand All @@ -43,41 +41,31 @@
import org.apache.nifi.authentication.annotation.LoginIdentityProviderContext;
import org.apache.nifi.authentication.exception.ProviderCreationException;
import org.apache.nifi.authentication.exception.ProviderDestructionException;
import org.apache.nifi.authentication.generated.LoginIdentityProviders;
import org.apache.nifi.authentication.generated.Property;
import org.apache.nifi.authentication.generated.Provider;
import org.apache.nifi.bundle.Bundle;
import org.apache.nifi.nar.ExtensionManager;
import org.apache.nifi.nar.NarCloseable;
import org.apache.nifi.util.NiFiProperties;
import org.apache.nifi.xml.processing.stream.StandardXMLStreamReaderProvider;
import org.apache.nifi.xml.processing.stream.XMLStreamReaderProvider;
import org.apache.nifi.xml.processing.ProcessingException;
import org.apache.nifi.xml.processing.parsers.DocumentProvider;
import org.apache.nifi.xml.processing.parsers.StandardDocumentProvider;
import org.apache.nifi.xml.processing.validation.SchemaValidator;
import org.apache.nifi.xml.processing.validation.StandardSchemaValidator;
import org.springframework.beans.factory.DisposableBean;
import org.springframework.beans.factory.FactoryBean;
import org.xml.sax.SAXException;
import org.w3c.dom.Document;
import org.w3c.dom.Element;
import org.w3c.dom.Node;
import org.w3c.dom.NodeList;

/**
* Spring Factory Bean implementation requires a generic Object return type to handle a null Provider configuration
*/
public class LoginIdentityProviderFactoryBean implements FactoryBean<Object>, DisposableBean, LoginIdentityProviderLookup {

private static final String LOGIN_IDENTITY_PROVIDERS_XSD = "/login-identity-providers.xsd";
private static final String JAXB_GENERATED_PATH = "org.apache.nifi.authentication.generated";
private static final JAXBContext JAXB_CONTEXT = initializeJaxbContext();

private NiFiProperties properties;

/**
* Load the JAXBContext.
*/
private static JAXBContext initializeJaxbContext() {
try {
return JAXBContext.newInstance(JAXB_GENERATED_PATH, LoginIdentityProviderFactoryBean.class.getClassLoader());
} catch (JAXBException e) {
throw new RuntimeException("Unable to create JAXBContext.");
}
}

private ExtensionManager extensionManager;
private LoginIdentityProvider loginIdentityProvider;
private final Map<String, LoginIdentityProvider> loginIdentityProviders = new HashMap<>();
Expand All @@ -100,71 +88,86 @@ public LoginIdentityProvider getLoginIdentityProvider(String identifier) {
@Override
public Object getObject() throws Exception {
if (loginIdentityProvider == null) {
// look up the login identity provider to use
final String loginIdentityProviderIdentifier = properties.getProperty(NiFiProperties.SECURITY_USER_LOGIN_IDENTITY_PROVIDER);

// ensure the login identity provider class name was specified
if (StringUtils.isNotBlank(loginIdentityProviderIdentifier)) {
final LoginIdentityProviders loginIdentityProviderConfiguration = loadLoginIdentityProvidersConfiguration();

// create each login identity provider
for (final Provider provider : loginIdentityProviderConfiguration.getProvider()) {
loginIdentityProviders.put(provider.getIdentifier(), createLoginIdentityProvider(provider.getIdentifier(), provider.getClazz()));
}

loadProviderProperties(loginIdentityProviderConfiguration);
final Map<String, LoginIdentityProvider> loadedLoginIdentityProviders = loadLoginIdentityProviders();
loginIdentityProviders.putAll(loadedLoginIdentityProviders);

// get the login identity provider instance
loginIdentityProvider = getLoginIdentityProvider(loginIdentityProviderIdentifier);

// ensure it was found
loginIdentityProvider = loginIdentityProviders.get(loginIdentityProviderIdentifier);
if (loginIdentityProvider == null) {
throw new Exception(String.format("The specified login identity provider '%s' could not be found.", loginIdentityProviderIdentifier));
throw new IllegalStateException("Login Identity Provider [%s] not found".formatted(loginIdentityProviderIdentifier));
}
}
}

return loginIdentityProvider;
}

private LoginIdentityProviders loadLoginIdentityProvidersConfiguration() throws Exception {
private Map<String, LoginIdentityProvider> loadLoginIdentityProviders() throws Exception {
final File loginIdentityProvidersConfigurationFile = properties.getLoginIdentityProviderConfigurationFile();

// load the users from the specified file
if (loginIdentityProvidersConfigurationFile.exists()) {
try {
// find the schema
try (InputStream inputStream = new FileInputStream(loginIdentityProvidersConfigurationFile)) {
final SchemaFactory schemaFactory = SchemaFactory.newInstance(XMLConstants.W3C_XML_SCHEMA_NS_URI);
final Schema schema = schemaFactory.newSchema(LoginIdentityProviders.class.getResource(LOGIN_IDENTITY_PROVIDERS_XSD));

// attempt to unmarshal
final XMLStreamReaderProvider provider = new StandardXMLStreamReaderProvider();
XMLStreamReader xsr = provider.getStreamReader(new StreamSource(loginIdentityProvidersConfigurationFile));
final Unmarshaller unmarshaller = JAXB_CONTEXT.createUnmarshaller();
unmarshaller.setSchema(schema);
final JAXBElement<LoginIdentityProviders> element = unmarshaller.unmarshal(xsr, LoginIdentityProviders.class);
return element.getValue();
} catch (SAXException | JAXBException e) {
final Schema schema = schemaFactory.newSchema(getClass().getResource(LOGIN_IDENTITY_PROVIDERS_XSD));
final SchemaValidator schemaValidator = new StandardSchemaValidator();

final DocumentProvider documentProvider = new StandardDocumentProvider();
final Document document = documentProvider.parse(inputStream);
final Source source = new DOMSource(document);

// Validate Document using Schema before parsing
schemaValidator.validate(schema, source);

return loadLoginIdentityProviders(document);
} catch (final ProcessingException e) {
throw new Exception("Unable to load the login identity provider configuration file at: " + loginIdentityProvidersConfigurationFile.getAbsolutePath());
}
} else {
throw new Exception("Unable to find the login identity provider configuration file at " + loginIdentityProvidersConfigurationFile.getAbsolutePath());
}
}

private Map<String, LoginIdentityProvider> loadLoginIdentityProviders(final Document document) throws Exception {
final Element loginIdentityProviders = (Element) document.getElementsByTagName("loginIdentityProviders").item(0);
final NodeList providers = loginIdentityProviders.getElementsByTagName("provider");

final Map<String, LoginIdentityProvider> loadedProviders = new HashMap<>();
for (int i = 0; i < providers.getLength(); i++) {
final Element provider = (Element) providers.item(i);
final NodeList identifiers = provider.getElementsByTagName("identifier");
final Node firstIdentifier = identifiers.item(0);

final String providerIdentifier = firstIdentifier.getFirstChild().getTextContent();

final Node providerClass = provider.getElementsByTagName("class").item(0);
final String providerClassName = providerClass.getFirstChild().getTextContent();
final LoginIdentityProvider identityProvider = createLoginIdentityProvider(providerIdentifier, providerClassName);

final LoginIdentityProviderConfigurationContext configurationContext = getConfigurationContext(providerIdentifier, provider);
identityProvider.onConfigured(configurationContext);

loadedProviders.put(providerIdentifier, identityProvider);
}

return loadedProviders;
}

private LoginIdentityProvider createLoginIdentityProvider(final String identifier, final String loginIdentityProviderClassName) throws Exception {
// get the classloader for the specified login identity provider
final List<Bundle> loginIdentityProviderBundles = extensionManager.getBundles(loginIdentityProviderClassName);

if (loginIdentityProviderBundles.isEmpty()) {
throw new Exception(String.format("The specified login identity provider class '%s' is not known to this nifi.", loginIdentityProviderClassName));
throw new Exception("Login Identity Provider class [%s] not registered in loaded Extension Bundles".formatted(loginIdentityProviderClassName));
}

if (loginIdentityProviderBundles.size() > 1) {
throw new Exception(String.format("Multiple bundles found for the specified login identity provider class '%s', only one is allowed.", loginIdentityProviderClassName));
}

final Bundle loginIdentityProviderBundle = loginIdentityProviderBundles.get(0);
final Bundle loginIdentityProviderBundle = loginIdentityProviderBundles.getFirst();
final ClassLoader loginIdentityProviderClassLoader = loginIdentityProviderBundle.getClassLoader();

// get the current context classloader
Expand Down Expand Up @@ -200,23 +203,23 @@ private LoginIdentityProvider createLoginIdentityProvider(final String identifie
return withNarLoader(instance);
}

private void loadProviderProperties(final LoginIdentityProviders loginIdentityProviderConfiguration) {
for (final Provider provider : loginIdentityProviderConfiguration.getProvider()) {
final LoginIdentityProvider instance = loginIdentityProviders.get(provider.getIdentifier());
final LoginIdentityProviderConfigurationContext configurationContext = getConfigurationContext(provider);
instance.onConfigured(configurationContext);
}
}

private LoginIdentityProviderConfigurationContext getConfigurationContext(final Provider provider) {
final String providerIdentifier = provider.getIdentifier();
private LoginIdentityProviderConfigurationContext getConfigurationContext(final String identifier, final Element provider) {
final Map<String, String> providerProperties = new HashMap<>();

for (final Property property : provider.getProperty()) {
providerProperties.put(property.getName(), property.getValue());
final NodeList properties = provider.getElementsByTagName("property");
for (int i = 0; i < properties.getLength(); i++) {
final Element property = (Element) properties.item(i);
final String propertyName = property.getAttribute("name");

if (property.hasChildNodes()) {
final String propertyValue = property.getFirstChild().getNodeValue();
if (StringUtils.isNotBlank(propertyValue)) {
providerProperties.put(propertyName, propertyValue);
}
}
}

return new StandardLoginIdentityProviderConfigurationContext(providerIdentifier, providerProperties);
return new StandardLoginIdentityProviderConfigurationContext(identifier, providerProperties);
}

private void performMethodInjection(final LoginIdentityProvider instance, final Class<?> loginIdentityProviderClass)
Expand Down Expand Up @@ -316,7 +319,7 @@ public boolean isSingleton() {
}

@Override
public void destroy() throws Exception {
public void destroy() {
if (loginIdentityProvider != null) {
loginIdentityProvider.preDestruction();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,5 +18,6 @@
<identifier>login-identity-provider</identifier>
<class>org.apache.nifi.web.security.spring.mock.MockLoginIdentityProvider</class>
<property name="strategy">MOCK</property>
<property name="unused"/>
</provider>
</loginIdentityProviders>