Skip to content

Commit

Permalink
Merge pull request #379 from andymc12/registerContracts
Browse files Browse the repository at this point in the history
[CXF-7638] Only register provider if it implements specified contracts
  • Loading branch information
andymc12 committed Feb 12, 2018
2 parents 403af86 + 8cd6a3e commit 2fd0b2e
Show file tree
Hide file tree
Showing 9 changed files with 402 additions and 62 deletions.
Expand Up @@ -83,7 +83,7 @@ private static class CdiServerFeatureContextConfigurable extends ConfigurableImp
private final Instantiator instantiator;

CdiServerFeatureContextConfigurable(FeatureContext mc, BeanManager beanManager) {
super(mc, RuntimeType.SERVER, SERVER_FILTER_INTERCEPTOR_CLASSES);
super(mc, RuntimeType.SERVER);
this.instantiator = new CdiInstantiator(beanManager);
}

Expand Down
Expand Up @@ -19,11 +19,20 @@

package org.apache.cxf.jaxrs.impl;

import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.logging.Logger;
import java.util.stream.Collectors;

import javax.ws.rs.ConstrainedTo;
import javax.ws.rs.Priorities;
import javax.ws.rs.RuntimeType;
import javax.ws.rs.client.ClientRequestFilter;
import javax.ws.rs.client.ClientResponseFilter;
import javax.ws.rs.container.ContainerRequestFilter;
import javax.ws.rs.container.ContainerResponseFilter;
import javax.ws.rs.core.Configurable;
import javax.ws.rs.core.Configuration;
import javax.ws.rs.core.Feature;
Expand All @@ -33,29 +42,46 @@

public class ConfigurableImpl<C extends Configurable<C>> implements Configurable<C> {
private static final Logger LOG = LogUtils.getL7dLogger(ConfigurableImpl.class);

private static final Class<?>[] RESTRICTED_CLASSES_IN_SERVER = {ClientRequestFilter.class,
ClientResponseFilter.class};
private static final Class<?>[] RESTRICTED_CLASSES_IN_CLIENT = {ContainerRequestFilter.class,
ContainerResponseFilter.class};

private ConfigurationImpl config;
private final C configurable;
private final Class<?>[] supportedProviderClasses;

private final Class<?>[] restrictedContractTypes;

public interface Instantiator {
<T> Object create(Class<T> cls);
}

public ConfigurableImpl(C configurable, RuntimeType rt, Class<?>[] supportedProviderClasses) {
this(configurable, supportedProviderClasses, new ConfigurationImpl(rt));
public ConfigurableImpl(C configurable, RuntimeType rt) {
this(configurable, new ConfigurationImpl(rt));
}

public ConfigurableImpl(C configurable, Class<?>[] supportedProviderClasses, Configuration config) {
this(configurable, supportedProviderClasses);
public ConfigurableImpl(C configurable, Configuration config) {
this.configurable = configurable;
this.config = config instanceof ConfigurationImpl
? (ConfigurationImpl)config : new ConfigurationImpl(config, supportedProviderClasses);
? (ConfigurationImpl)config : new ConfigurationImpl(config);
restrictedContractTypes = RuntimeType.CLIENT.equals(config.getRuntimeType()) ? RESTRICTED_CLASSES_IN_CLIENT
: RESTRICTED_CLASSES_IN_SERVER;
}

private ConfigurableImpl(C configurable, Class<?>[] supportedProviderClasses) {
this.configurable = configurable;
this.supportedProviderClasses = supportedProviderClasses;
static Class<?>[] getImplementedContracts(Object provider, Class<?>[] restrictedClasses) {
Class<?> providerClass = provider instanceof Class<?> ? ((Class<?>)provider) : provider.getClass();
Set<Class<?>> interfaces = Arrays.stream(providerClass.getInterfaces()).collect(Collectors.toSet());
providerClass = providerClass.getSuperclass();
for (; providerClass != null && providerClass != Object.class; providerClass = providerClass.getSuperclass()) {
interfaces.addAll(Arrays.stream(providerClass.getInterfaces()).collect(Collectors.toSet()));
}
List<Class<?>> implementedContracts = interfaces.stream()
.filter(el -> Arrays.stream(restrictedClasses).noneMatch(el::equals))
.collect(Collectors.toList());
return implementedContracts.toArray(new Class<?>[]{});
}

protected C getConfigurable() {
return configurable;
}
Expand All @@ -78,7 +104,7 @@ public C register(Object provider) {

@Override
public C register(Object provider, int bindingPriority) {
return doRegister(provider, bindingPriority, supportedProviderClasses);
return doRegister(provider, bindingPriority, getImplementedContracts(provider, restrictedContractTypes));
}

@Override
Expand All @@ -98,7 +124,8 @@ public C register(Class<?> providerClass) {

@Override
public C register(Class<?> providerClass, int bindingPriority) {
return doRegister(getInstantiator().create(providerClass), bindingPriority, supportedProviderClasses);
return doRegister(getInstantiator().create(providerClass), bindingPriority,
getImplementedContracts(providerClass, restrictedContractTypes));
}

@Override
Expand All @@ -110,20 +137,23 @@ public C register(Class<?> providerClass, Class<?>... contracts) {
public C register(Class<?> providerClass, Map<Class<?>, Integer> contracts) {
return register(getInstantiator().create(providerClass), contracts);
}

protected Instantiator getInstantiator() {
return ConfigurationImpl::createProvider;
}

private C doRegister(Object provider, int bindingPriority, Class<?>... contracts) {
if (contracts == null || contracts.length == 0) {
LOG.warning("Null or empty contracts specified for " + provider + "; ignoring.");
LOG.warning("Null, empty or invalid contracts specified for " + provider + "; ignoring.");
return configurable;
}
return doRegister(provider, ConfigurationImpl.initContractsMap(bindingPriority, contracts));
}

private C doRegister(Object provider, Map<Class<?>, Integer> contracts) {
if (!checkConstraints(provider)) {
return configurable;
}
if (provider instanceof Feature) {
Feature feature = (Feature)provider;
boolean enabled = feature.configure(new FeatureContextImpl(this));
Expand All @@ -134,4 +164,33 @@ private C doRegister(Object provider, Map<Class<?>, Integer> contracts) {
config.register(provider, contracts);
return configurable;
}

private boolean checkConstraints(Object provider) {
Class<?> providerClass = provider.getClass();
ConstrainedTo providerConstraint = providerClass.getAnnotation(ConstrainedTo.class);
if (providerConstraint != null) {
RuntimeType currentRuntime = config.getRuntimeType();
RuntimeType providerRuntime = providerConstraint.value();
// need to check (1) whether the registration is occurring in the specified runtime type
// and (2) does the provider implement an invalid interface based on the constrained runtime type
if (!providerRuntime.equals(currentRuntime)) {
LOG.warning("Provider " + provider + " cannot be registered in this " + currentRuntime
+ " runtime because it is constrained to " + providerRuntime + " runtimes.");
return false;
}

Class<?>[] restrictedInterfaces = RuntimeType.CLIENT.equals(providerRuntime) ? RESTRICTED_CLASSES_IN_CLIENT
: RESTRICTED_CLASSES_IN_SERVER;
for (Class<?> restrictedContract : restrictedInterfaces) {
if (restrictedContract.isAssignableFrom(providerClass)) {
RuntimeType opposite = RuntimeType.CLIENT.equals(providerRuntime) ? RuntimeType.SERVER
: RuntimeType.CLIENT;
LOG.warning("Provider " + providerClass.getName() + " is invalid - it is constrained to "
+ providerRuntime + " runtimes but implements a " + opposite + " interface ");
return false;
}
}
}
return true;
}
}
Expand Up @@ -47,34 +47,35 @@ public ConfigurationImpl(RuntimeType rt) {
this.runtimeType = rt;
}

public ConfigurationImpl(Configuration parent, Class<?>[] defaultContracts) {
public ConfigurationImpl(Configuration parent) {
if (parent != null) {
this.props.putAll(parent.getProperties());
this.runtimeType = parent.getRuntimeType();

Set<Class<?>> providerClasses = new HashSet<Class<?>>(parent.getClasses());
for (Object o : parent.getInstances()) {
if (!(o instanceof Feature)) {
registerParentProvider(o, parent, defaultContracts);
registerParentProvider(o, parent);
} else {
Feature f = (Feature)o;
features.put(f, parent.isEnabled(f));
}
providerClasses.remove(o.getClass());
}
for (Class<?> cls : providerClasses) {
registerParentProvider(createProvider(cls), parent, defaultContracts);
registerParentProvider(createProvider(cls), parent);
}

}
}

private void registerParentProvider(Object o, Configuration parent, Class<?>[] defaultContracts) {
private void registerParentProvider(Object o, Configuration parent) {
Map<Class<?>, Integer> contracts = parent.getContracts(o.getClass());
if (contracts != null) {
providers.put(o, contracts);
} else {
register(o, AnnotationUtils.getBindingPriority(o.getClass()), defaultContracts);
register(o, AnnotationUtils.getBindingPriority(o.getClass()),
ConfigurableImpl.getImplementedContracts(o, new Class<?>[]{}));
}
}

Expand Down Expand Up @@ -131,13 +132,15 @@ public RuntimeType getRuntimeType() {

@Override
public boolean isEnabled(Feature f) {
return features.containsKey(f);
return features.containsKey(f) && features.get(f);
}

@Override
public boolean isEnabled(Class<? extends Feature> f) {
for (Feature feature : features.keySet()) {
if (feature.getClass().isAssignableFrom(f)) {
for (Entry<Feature, Boolean> entry : features.entrySet()) {
Feature feature = entry.getKey();
Boolean enabled = entry.getValue();
if (f.isAssignableFrom(feature.getClass()) && enabled.booleanValue()) {
return true;
}
}
Expand Down Expand Up @@ -194,6 +197,10 @@ public boolean register(Object provider, Map<Class<?>, Integer> contracts) {
return false;
}

if (!contractsValid(provider, contracts)) {
return false;
}

Map<Class<?>, Integer> metadata = providers.get(provider);
if (metadata == null) {
metadata = new HashMap<>();
Expand All @@ -207,6 +214,17 @@ public boolean register(Object provider, Map<Class<?>, Integer> contracts) {
return true;
}

private boolean contractsValid(Object provider, Map<Class<?>, Integer> contracts) {
final Class<?> providerClass = provider.getClass();
for (Class<?> contractInterface : contracts.keySet()) {
if (!contractInterface.isAssignableFrom(providerClass)) {
LOG.warning("Provider " + providerClass.getName() + " does not implement specified contract: "
+ contractInterface.getName());
return false;
}
}
return true;
}
public static Map<Class<?>, Integer> initContractsMap(int bindingPriority, Class<?>... contracts) {
Map<Class<?>, Integer> metadata = new HashMap<>();
for (Class<?> contract : contracts) {
Expand Down
Expand Up @@ -19,12 +19,8 @@

package org.apache.cxf.jaxrs.provider;

import javax.ws.rs.container.ContainerRequestFilter;
import javax.ws.rs.container.ContainerResponseFilter;
import javax.ws.rs.core.Configurable;
import javax.ws.rs.core.FeatureContext;
import javax.ws.rs.ext.ReaderInterceptor;
import javax.ws.rs.ext.WriterInterceptor;

/**
* Manages the creation of server-side {@code Configurable<FeatureContext>} depending on
Expand All @@ -34,12 +30,6 @@
* notice, please be aware of that.
*/
public interface ServerConfigurableFactory {
Class<?>[] SERVER_FILTER_INTERCEPTOR_CLASSES = new Class<?>[] {
ContainerRequestFilter.class,
ContainerResponseFilter.class,
ReaderInterceptor.class,
WriterInterceptor.class
};


Configurable<FeatureContext> create(FeatureContext context);
}
Expand Up @@ -453,7 +453,7 @@ protected static boolean isPrematching(Class<?> filterCls) {

private static class ServerFeatureContextConfigurable extends ConfigurableImpl<FeatureContext> {
protected ServerFeatureContextConfigurable(FeatureContext mc) {
super(mc, RuntimeType.SERVER, ServerConfigurableFactory.SERVER_FILTER_INTERCEPTOR_CLASSES);
super(mc, RuntimeType.SERVER);
}
}

Expand Down

0 comments on commit 2fd0b2e

Please sign in to comment.