Skip to content

Commit

Permalink
FLUME-3437 - Improve JMSSource validation
Browse files Browse the repository at this point in the history
  • Loading branch information
rgoers committed Oct 8, 2022
1 parent 34f3722 commit eee179a
Show file tree
Hide file tree
Showing 5 changed files with 46 additions and 46 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -35,14 +35,11 @@
import javax.jms.Topic;
import javax.naming.InitialContext;
import javax.naming.NamingException;
import java.net.URI;
import java.net.URISyntaxException;
import java.util.ArrayList;
import java.util.List;

class JMSMessageConsumer {
private static final Logger logger = LoggerFactory.getLogger(JMSMessageConsumer.class);
private static final String JAVA_SCHEME = "java";

private final int batchSize;
private final long pollTimeout;
Expand Down Expand Up @@ -102,14 +99,7 @@ class JMSMessageConsumer {
throw new IllegalStateException(String.valueOf(destinationType));
}
} else {
try {
URI uri = new URI(destinationName);
String scheme = uri.getScheme();
assertTrue(scheme == null || scheme.equals(JAVA_SCHEME),
"Unsupported JNDI URI: " + destinationName);
} catch (URISyntaxException ex) {
logger.warn("Invalid JNDI URI - {}", destinationName);
}
JMSSource.verifyContext(destinationName);
destination = (Destination) initialContext.lookup(destinationName);
}
} catch (JMSException e) {
Expand Down Expand Up @@ -220,8 +210,4 @@ void close() {
logger.error("Could not destroy connection", e);
}
}

private void assertTrue(boolean arg, String msg) {
Preconditions.checkArgument(arg, msg);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -39,13 +39,13 @@
@InterfaceStability.Stable
public interface JMSMessageConverter {

public List<Event> convert(Message message) throws JMSException;
List<Event> convert(Message message) throws JMSException;

/**
* Implementors of JMSMessageConverter must either provide
* a suitable builder or implement the Configurable interface.
*/
public interface Builder {
public JMSMessageConverter build(Context context);
interface Builder {
JMSMessageConverter build(Context context);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@
import java.io.IOException;
import java.net.URI;
import java.net.URISyntaxException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Locale;
import java.util.Properties;
Expand Down Expand Up @@ -56,6 +58,7 @@
public class JMSSource extends AbstractPollableSource implements BatchSizeSupported {
private static final Logger logger = LoggerFactory.getLogger(JMSSource.class);
private static final String JAVA_SCHEME = "java";
public static final String JNDI_ALLOWED_PROTOCOLS = "JndiAllowedProtocols";

// setup by constructor
private final InitialContextFactory initialContextFactory;
Expand All @@ -82,6 +85,7 @@ public class JMSSource extends AbstractPollableSource implements BatchSizeSuppor

private int jmsExceptionCounter;
private InitialContext initialContext;
private static List<String> allowedSchemes = getAllowedProtocols();

public JMSSource() {
this(new InitialContextFactory());
Expand All @@ -92,6 +96,34 @@ public JMSSource(InitialContextFactory initialContextFactory) {
this.initialContextFactory = initialContextFactory;
}

private static List<String> getAllowedProtocols() {
String allowed = System.getProperty(JNDI_ALLOWED_PROTOCOLS, null);
if (allowed == null) {
return Collections.singletonList(JAVA_SCHEME);
} else {
String[] items = allowed.split(",");
List<String> schemes = new ArrayList<>();
schemes.add(JAVA_SCHEME);
for (String item : items) {
if (!item.equals(JAVA_SCHEME)) {
schemes.add(item.trim());
}
}
return schemes;
}
}

public static void verifyContext(String location) {
try {
String scheme = new URI(location).getScheme();
if (scheme != null && !allowedSchemes.contains(scheme)) {
throw new IllegalArgumentException("Invalid JNDI URI: " + location);
}
} catch (URISyntaxException ex) {
logger.trace("{}} is not a valid URI", location);
}
}

@Override
protected void doConfigure(Context context) throws FlumeException {
sourceCounter = new SourceCounter(getName());
Expand All @@ -100,14 +132,7 @@ protected void doConfigure(Context context) throws FlumeException {
JMSSourceConfiguration.INITIAL_CONTEXT_FACTORY, "").trim();

providerUrl = context.getString(JMSSourceConfiguration.PROVIDER_URL, "").trim();
try {
URI uri = new URI(providerUrl);
String scheme = uri.getScheme();
assertTrue(scheme == null || scheme.equals(JAVA_SCHEME),
"Unsupported JNDI URI: " + providerUrl);
} catch (URISyntaxException ex) {
logger.warn("Invalid JNDI URI - {}", providerUrl);
}
verifyContext(providerUrl);

destinationName = context.getString(JMSSourceConfiguration.DESTINATION_NAME, "").trim();

Expand Down Expand Up @@ -190,14 +215,7 @@ protected void doConfigure(Context context) throws FlumeException {
String connectionFactoryName = context.getString(
JMSSourceConfiguration.CONNECTION_FACTORY,
JMSSourceConfiguration.CONNECTION_FACTORY_DEFAULT).trim();
try {
URI uri = new URI(connectionFactoryName);
String scheme = uri.getScheme();
assertTrue(scheme == null || scheme.equals(JAVA_SCHEME),
"Unsupported JNDI URI: " + connectionFactoryName);
} catch (URISyntaxException ex) {
logger.warn("Invalid JNDI URI - {}", connectionFactoryName);
}
verifyContext(connectionFactoryName);

assertNotEmpty(initialContextFactoryName, String.format(
"Initial Context Factory is empty. This is specified by %s",
Expand Down Expand Up @@ -291,10 +309,6 @@ private void assertNotEmpty(String arg, String msg) {
Preconditions.checkArgument(!arg.isEmpty(), msg);
}

private void assertTrue(boolean arg, String msg) {
Preconditions.checkArgument(arg, msg);
}

@Override
protected synchronized Status doProcess() throws EventDeliveryException {
boolean error = true;
Expand Down Expand Up @@ -322,14 +336,12 @@ protected synchronized Status doProcess() throws EventDeliveryException {
sourceCounter.incrementChannelWriteFail();
} catch (JMSException jmsException) {
logger.warn("JMSException consuming events", jmsException);
if (++jmsExceptionCounter > errorThreshold) {
if (consumer != null) {
logger.warn("Exceeded JMSException threshold, closing consumer");
sourceCounter.incrementEventReadFail();
consumer.rollback();
consumer.close();
consumer = null;
}
if (++jmsExceptionCounter > errorThreshold && consumer != null) {
logger.warn("Exceeded JMSException threshold, closing consumer");
sourceCounter.incrementEventReadFail();
consumer.rollback();
consumer.close();
consumer = null;
}
} catch (Throwable throwable) {
logger.error("Unexpected error processing events", throwable);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ private enum TestMode {
private final String jmsPassword;

public TestIntegrationActiveMQ(TestMode testMode) {
System.setProperty(JMSSource.JNDI_ALLOWED_PROTOCOLS, "tcp");
LOGGER.info("Testing with test mode {}", testMode);

switch (testMode) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ public class TestJMSSource extends JMSMessageConsumerTestBase {
@SuppressWarnings("unchecked")
@Override
void afterSetup() throws Exception {
System.setProperty(JMSSource.JNDI_ALLOWED_PROTOCOLS, "dummy");
baseDir = Files.createTempDir();
passwordFile = new File(baseDir, "password");
Assert.assertTrue(passwordFile.createNewFile());
Expand Down

0 comments on commit eee179a

Please sign in to comment.