Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[BEAM-5759] Ensuring JmsIO checkpoint state is accessed and modified safely #6702

Merged
merged 3 commits into from
Oct 17, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,17 @@

import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.locks.ReentrantReadWriteLock;
import java.util.function.BiFunction;
import java.util.function.Supplier;
import javax.jms.Message;
import org.apache.beam.sdk.coders.AvroCoder;
import org.apache.beam.sdk.coders.DefaultCoder;
import org.apache.beam.sdk.io.UnboundedSource;
import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
import org.joda.time.Instant;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
* Checkpoint for an unbounded JmsIO.Read. Consists of JMS destination name, and the latest message
Expand All @@ -33,25 +38,27 @@
@DefaultCoder(AvroCoder.class)
public class JmsCheckpointMark implements UnboundedSource.CheckpointMark {

private final List<Message> messages = new ArrayList<>();
private Instant oldestPendingTimestamp = BoundedWindow.TIMESTAMP_MIN_VALUE;
private static final Logger LOG = LoggerFactory.getLogger(JmsCheckpointMark.class);

private final State state = new State();

public JmsCheckpointMark() {}

protected List<Message> getMessages() {
return this.messages;
return state.getMessages();
}

protected void addMessage(Message message) throws Exception {
Instant currentMessageTimestamp = new Instant(message.getJMSTimestamp());
if (currentMessageTimestamp.isBefore(oldestPendingTimestamp)) {
oldestPendingTimestamp = currentMessageTimestamp;
}
messages.add(message);
state.atomicWrite(
() -> {
state.updateOldestPendingTimestampIf(currentMessageTimestamp, Instant::isBefore);
state.addMessage(message);
});
}

protected Instant getOldestPendingTimestamp() {
return oldestPendingTimestamp;
return state.getOldestPendingTimestamp();
}

/**
Expand All @@ -61,17 +68,117 @@ protected Instant getOldestPendingTimestamp() {
*/
@Override
public void finalizeCheckpoint() {
for (Message message : messages) {
State snapshot = state.snapshot();
for (Message message : snapshot.messages) {
try {
message.acknowledge();
Instant currentMessageTimestamp = new Instant(message.getJMSTimestamp());
if (currentMessageTimestamp.isAfter(oldestPendingTimestamp)) {
oldestPendingTimestamp = currentMessageTimestamp;
}
snapshot.updateOldestPendingTimestampIf(currentMessageTimestamp, Instant::isAfter);
} catch (Exception e) {
// nothing to do
LOG.error("Exception while finalizing message: {}", e);
}
}
state.atomicWrite(
() -> {
state.removeMessages(snapshot.messages);
state.updateOldestPendingTimestampIf(snapshot.oldestPendingTimestamp, Instant::isAfter);
});
}

/**
* Encapsulates the state of a checkpoint mark; the list of messages pending finalisation and the
* oldest pending timestamp. Read/write-exclusive access is provided throughout, and constructs
* allowing multiple operations to be performed atomically -- i.e. performed within the context of
* a single lock operation -- are made available.
*/
private class State {
private final ReentrantReadWriteLock lock = new ReentrantReadWriteLock();

private final List<Message> messages;
private Instant oldestPendingTimestamp;

public State() {
this(new ArrayList<>(), BoundedWindow.TIMESTAMP_MIN_VALUE);
}

private State(List<Message> messages, Instant oldestPendingTimestamp) {
this.messages = messages;
this.oldestPendingTimestamp = oldestPendingTimestamp;
}

/**
* Create and return a copy of the current state.
*
* @return A new {@code State} instance which is a deep copy of the target instance at the time
* of execution.
*/
public State snapshot() {
return atomicRead(() -> new State(new ArrayList<>(messages), oldestPendingTimestamp));
}

public Instant getOldestPendingTimestamp() {
return atomicRead(() -> oldestPendingTimestamp);
}

public List<Message> getMessages() {
return atomicRead(() -> messages);
}

public void addMessage(Message message) {
atomicWrite(() -> messages.add(message));
}

public void removeMessages(List<Message> messages) {
atomicWrite(() -> this.messages.removeAll(messages));
}

/**
* Conditionally sets {@code oldestPendingTimestamp} to the value of the supplied {@code
* candidate}, iff the provided {@code check} yields true for the {@code candidate} when called
* with the existing {@code oldestPendingTimestamp} value.
*
* @param candidate The potential new value.
* @param check The comparison method to call on {@code candidate} passing the existing {@code
* oldestPendingTimestamp} value as a parameter.
*/
private void updateOldestPendingTimestampIf(
Instant candidate, BiFunction<Instant, Instant, Boolean> check) {
atomicWrite(
() -> {
if (check.apply(candidate, oldestPendingTimestamp)) {
oldestPendingTimestamp = candidate;
}
});
}

/**
* Call the provided {@link Supplier} under this State's read lock and return its result.
*
* @param operation The code to execute in the context of this State's read lock.
* @param <T> The return type of the provided {@link Supplier}.
* @return The value produced by the provided {@link Supplier}.
*/
public <T> T atomicRead(Supplier<T> operation) {
lock.readLock().lock();
try {
return operation.get();
} finally {
lock.readLock().unlock();
}
}

/**
* Call the provided {@link Runnable} under this State's write lock.
*
* @param operation The code to execute in the context of this State's write lock.
*/
public void atomicWrite(Runnable operation) {
lock.writeLock().lock();
try {
operation.run();
} finally {
lock.writeLock().unlock();
}
}
messages.clear();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,13 @@
import static org.junit.Assert.fail;

import com.google.common.base.Throwables;
import java.io.IOException;
import java.lang.reflect.Proxy;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Enumeration;
import java.util.List;
import java.util.function.Function;
import javax.jms.BytesMessage;
import javax.jms.Connection;
import javax.jms.ConnectionFactory;
Expand All @@ -41,9 +44,11 @@
import org.apache.activemq.ActiveMQConnectionFactory;
import org.apache.activemq.broker.BrokerPlugin;
import org.apache.activemq.broker.BrokerService;
import org.apache.activemq.command.ActiveMQMessage;
import org.apache.activemq.security.AuthenticationUser;
import org.apache.activemq.security.SimpleAuthenticationPlugin;
import org.apache.activemq.store.memory.MemoryPersistenceAdapter;
import org.apache.activemq.util.Callback;
import org.apache.beam.sdk.coders.SerializableCoder;
import org.apache.beam.sdk.options.PipelineOptions;
import org.apache.beam.sdk.options.PipelineOptionsFactory;
Expand Down Expand Up @@ -91,7 +96,6 @@ public void startBroker() throws Exception {
// username and password to use to connect to the broker.
// This user has users privilege (able to browse, consume, produce, list destinations)
users.add(new AuthenticationUser(USERNAME, PASSWORD, "users"));
users.add(new AuthenticationUser(USERNAME, PASSWORD, "users"));
SimpleAuthenticationPlugin plugin = new SimpleAuthenticationPlugin(users);
BrokerPlugin[] plugins = new BrokerPlugin[] {plugin};
broker.setPlugins(plugins);
Expand Down Expand Up @@ -329,6 +333,76 @@ public void testCheckpointMark() throws Exception {
assertEquals(0, count(QUEUE));
}

@Test
public void testCheckpointMarkSafety() throws Exception {

final int messagesToProcess = 100;

// we are using no prefetch here
// prefetch is an ActiveMQ feature: to make efficient use of network resources the broker
// utilizes a 'push' model to dispatch messages to consumers. However, in the case of our
// test, it means that we can have some latency between the receiveNoWait() method used by
// the consumer and the prefetch buffer populated by the broker. Using a prefetch to 0 means
// that the consumer will poll for message, which is exactly what we want for the test.
// We are also sending message acknowledgements synchronously to ensure that they are
// processed before any subsequent assertions.
Connection connection =
connectionFactoryWithSyncAcksAndWithoutPrefetch.createConnection(USERNAME, PASSWORD);
connection.start();
Session session = connection.createSession(false, Session.CLIENT_ACKNOWLEDGE);

// Fill the queue with messages
MessageProducer producer = session.createProducer(session.createQueue(QUEUE));
for (int i = 0; i < messagesToProcess; i++) {
producer.send(session.createTextMessage("test " + i));
}
producer.close();
session.close();
connection.close();

// create a JmsIO.Read with a decorated ConnectionFactory which will introduce a delay in sending
// acknowledgements - this should help uncover threading issues around checkpoint management.
JmsIO.Read spec =
JmsIO.read()
.withConnectionFactory(
withSlowAcks(connectionFactoryWithSyncAcksAndWithoutPrefetch, 10))
.withUsername(USERNAME)
.withPassword(PASSWORD)
.withQueue(QUEUE);
JmsIO.UnboundedJmsSource source = new JmsIO.UnboundedJmsSource(spec);
JmsIO.UnboundedJmsReader reader = source.createReader(null, null);

// start the reader and move to the first record
assertTrue(reader.start());

// consume half the messages (NB: start already consumed the first message)
for (int i = 0; i < (messagesToProcess / 2) - 1; i++) {
assertTrue(reader.advance());
}

// the messages are still pending in the queue (no ACK yet)
assertEquals(messagesToProcess, count(QUEUE));

// we finalize the checkpoint for the already-processed messages while simultaneously consuming the remainder of
// messages from the queue
Thread runner =
new Thread(
() -> {
try {
for (int i = 0; i < messagesToProcess / 2; i++) {
assertTrue(reader.advance());
}
} catch (IOException ex) {
throw new RuntimeException(ex);
}
});
runner.start();
reader.getCheckpointMark().finalizeCheckpoint();

// Concurrency issues would cause an exception to be thrown before this method exits, failing the test
runner.join();
}

private int count(String queue) throws Exception {
Connection connection = connectionFactory.createConnection(USERNAME, PASSWORD);
connection.start();
Expand All @@ -355,4 +429,63 @@ public String mapMessage(Message message) throws Exception {
return new String(bytes, StandardCharsets.UTF_8);
}
}

/*
* A utility method which replaces a ConnectionFactory with one where calling receiveNoWait() -- i.e. pulling a
* message -- will return a message with its acknowledgement callback decorated to include a sleep for a specified
* duration. This gives the effect of ensuring messages take at least {@code delay} milliseconds to be processed.
*/
private ConnectionFactory withSlowAcks(ConnectionFactory factory, long delay) {
return proxyMethod(
factory,
ConnectionFactory.class,
"createConnection",
(Connection connection) ->
proxyMethod(
connection,
Connection.class,
"createSession",
(Session session) ->
proxyMethod(
session,
Session.class,
"createConsumer",
(MessageConsumer consumer) ->
proxyMethod(
consumer,
MessageConsumer.class,
"receiveNoWait",
(ActiveMQMessage message) -> {
final Callback originalCallback =
message.getAcknowledgeCallback();
message.setAcknowledgeCallback(
() -> {
Thread.sleep(delay);
originalCallback.execute();
});
return message;
}))));
}

/*
* A utility method which decorates an existing object with a proxy instance adhering to a given interface, with the
* specified method name having its return value transformed by the provided function.
*/
private <T, MethodArgT, MethodResultT> T proxyMethod(
T target,
Class<? super T> proxyInterface,
String methodName,
Function<MethodArgT, MethodResultT> resultTransformer) {
return (T)
Proxy.newProxyInstance(
this.getClass().getClassLoader(),
new Class[] {proxyInterface},
(proxy, method, args) -> {
Object result = method.invoke(target, args);
if (method.getName().equals(methodName)) {
result = resultTransformer.apply((MethodArgT) result);
}
return result;
});
}
}