diff --git a/client/src/main/java/org/apache/rocketmq/client/consumer/DefaultMQPushConsumer.java b/client/src/main/java/org/apache/rocketmq/client/consumer/DefaultMQPushConsumer.java index 2cce03d34738..1776a54a0676 100644 --- a/client/src/main/java/org/apache/rocketmq/client/consumer/DefaultMQPushConsumer.java +++ b/client/src/main/java/org/apache/rocketmq/client/consumer/DefaultMQPushConsumer.java @@ -214,6 +214,11 @@ public class DefaultMQPushConsumer extends ClientConfig implements MQPushConsume */ private long consumeTimeout = 15; + /** + * Maximum time to await message consuming when shutdown consumer, 0 indicates no await. + */ + private long awaitTerminationMillisWhenShutdown = 0; + /** * Default constructor. */ @@ -461,7 +466,7 @@ public void start() throws MQClientException { */ @Override public void shutdown() { - this.defaultMQPushConsumerImpl.shutdown(); + this.defaultMQPushConsumerImpl.shutdown(awaitTerminationMillisWhenShutdown); } @Override @@ -616,4 +621,12 @@ public long getConsumeTimeout() { public void setConsumeTimeout(final long consumeTimeout) { this.consumeTimeout = consumeTimeout; } + + public long getAwaitTerminationMillisWhenShutdown() { + return awaitTerminationMillisWhenShutdown; + } + + public void setAwaitTerminationMillisWhenShutdown(long awaitTerminationMillisWhenShutdown) { + this.awaitTerminationMillisWhenShutdown = awaitTerminationMillisWhenShutdown; + } } diff --git a/client/src/main/java/org/apache/rocketmq/client/impl/consumer/ConsumeMessageConcurrentlyService.java b/client/src/main/java/org/apache/rocketmq/client/impl/consumer/ConsumeMessageConcurrentlyService.java index f566ed0fcca4..ae4118fe30ad 100644 --- a/client/src/main/java/org/apache/rocketmq/client/impl/consumer/ConsumeMessageConcurrentlyService.java +++ b/client/src/main/java/org/apache/rocketmq/client/impl/consumer/ConsumeMessageConcurrentlyService.java @@ -92,9 +92,23 @@ public void run() { }, this.defaultMQPushConsumer.getConsumeTimeout(), this.defaultMQPushConsumer.getConsumeTimeout(), TimeUnit.MINUTES); } - public void shutdown() { + @Override + public void shutdown(long awaitTerminateMillis) { this.scheduledExecutorService.shutdown(); this.consumeExecutor.shutdown(); + //await to consume + if (awaitTerminateMillis > 0) { + try { + this.consumeExecutor.awaitTermination(awaitTerminateMillis,TimeUnit.MILLISECONDS); + if (!this.consumeExecutor.isTerminated()) { + log.info("There are messages still being consumed in thread pool, but not going to await them anymore after waiting for {} ms",awaitTerminateMillis); + } + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); // set interrupt flag + log.warn("got InterruptedException when await termination"); + } + } + this.cleanExpireMsgExecutors.shutdown(); } diff --git a/client/src/main/java/org/apache/rocketmq/client/impl/consumer/ConsumeMessageOrderlyService.java b/client/src/main/java/org/apache/rocketmq/client/impl/consumer/ConsumeMessageOrderlyService.java index 1fa474caa1d2..5b74091effb2 100644 --- a/client/src/main/java/org/apache/rocketmq/client/impl/consumer/ConsumeMessageOrderlyService.java +++ b/client/src/main/java/org/apache/rocketmq/client/impl/consumer/ConsumeMessageOrderlyService.java @@ -92,10 +92,22 @@ public void run() { } } - public void shutdown() { + @Override + public void shutdown(long awaitTerminateMillis) { this.stopped = true; this.scheduledExecutorService.shutdown(); this.consumeExecutor.shutdown(); + //await to consume + if (awaitTerminateMillis > 0) { + try { + this.consumeExecutor.awaitTermination(awaitTerminateMillis,TimeUnit.MILLISECONDS); + if (!this.consumeExecutor.isTerminated()) log.info("There are messages still being consumed in thread pool, but not going to await them anymore. Have awaited for {} ms",awaitTerminateMillis); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); // set interrupt flag + log.warn("got InterruptedException when awaitTermination"); + } + } + if (MessageModel.CLUSTERING.equals(this.defaultMQPushConsumerImpl.messageModel())) { this.unlockAllMQ(); } diff --git a/client/src/main/java/org/apache/rocketmq/client/impl/consumer/ConsumeMessageService.java b/client/src/main/java/org/apache/rocketmq/client/impl/consumer/ConsumeMessageService.java index 8742191b59d5..ab4448bc4025 100644 --- a/client/src/main/java/org/apache/rocketmq/client/impl/consumer/ConsumeMessageService.java +++ b/client/src/main/java/org/apache/rocketmq/client/impl/consumer/ConsumeMessageService.java @@ -24,7 +24,7 @@ public interface ConsumeMessageService { void start(); - void shutdown(); + void shutdown(long awaitTerminateMillis); void updateCorePoolSize(int corePoolSize); diff --git a/client/src/main/java/org/apache/rocketmq/client/impl/consumer/DefaultMQPushConsumerImpl.java b/client/src/main/java/org/apache/rocketmq/client/impl/consumer/DefaultMQPushConsumerImpl.java index 4f33732dddb7..d91adf1c132f 100644 --- a/client/src/main/java/org/apache/rocketmq/client/impl/consumer/DefaultMQPushConsumerImpl.java +++ b/client/src/main/java/org/apache/rocketmq/client/impl/consumer/DefaultMQPushConsumerImpl.java @@ -515,12 +515,12 @@ private int getMaxReconsumeTimes() { } } - public void shutdown() { + public void shutdown(long awaitTerminateMillis) { switch (this.serviceState) { case CREATE_JUST: break; case RUNNING: - this.consumeMessageService.shutdown(); + this.consumeMessageService.shutdown(awaitTerminateMillis); this.persistConsumerOffset(); this.mQClientFactory.unregisterConsumer(this.defaultMQPushConsumer.getConsumerGroup()); this.mQClientFactory.shutdown(); @@ -593,7 +593,7 @@ public void start() throws MQClientException { boolean registerOK = mQClientFactory.registerConsumer(this.defaultMQPushConsumer.getConsumerGroup(), this); if (!registerOK) { this.serviceState = ServiceState.CREATE_JUST; - this.consumeMessageService.shutdown(); + this.consumeMessageService.shutdown(defaultMQPushConsumer.getAwaitTerminationMillisWhenShutdown()); throw new MQClientException("The consumer group[" + this.defaultMQPushConsumer.getConsumerGroup() + "] has been created before, specify another name please." + FAQUrl.suggestTodo(FAQUrl.GROUP_NAME_DUPLICATE_URL), null); diff --git a/client/src/test/java/org/apache/rocketmq/client/consumer/DefaultMQPushConsumerTest.java b/client/src/test/java/org/apache/rocketmq/client/consumer/DefaultMQPushConsumerTest.java index 2e0af5affdd0..53ac37a6369d 100644 --- a/client/src/test/java/org/apache/rocketmq/client/consumer/DefaultMQPushConsumerTest.java +++ b/client/src/test/java/org/apache/rocketmq/client/consumer/DefaultMQPushConsumerTest.java @@ -19,10 +19,13 @@ import java.io.ByteArrayOutputStream; import java.lang.reflect.Field; import java.net.InetSocketAddress; +import java.util.ArrayList; import java.util.Collections; import java.util.HashSet; +import java.util.LinkedList; import java.util.List; import java.util.Set; +import java.util.TreeSet; import java.util.concurrent.CountDownLatch; import java.util.concurrent.TimeUnit; import org.apache.rocketmq.client.consumer.listener.ConsumeConcurrentlyContext; @@ -31,6 +34,7 @@ import org.apache.rocketmq.client.consumer.listener.ConsumeOrderlyStatus; import org.apache.rocketmq.client.consumer.listener.MessageListenerConcurrently; import org.apache.rocketmq.client.consumer.listener.MessageListenerOrderly; +import org.apache.rocketmq.client.consumer.store.ReadOffsetType; import org.apache.rocketmq.client.exception.MQBrokerException; import org.apache.rocketmq.client.impl.CommunicationMode; import org.apache.rocketmq.client.impl.FindBrokerResult; @@ -52,6 +56,7 @@ import org.apache.rocketmq.common.protocol.header.PullMessageRequestHeader; import org.apache.rocketmq.remoting.exception.RemotingException; import org.junit.After; +import org.junit.Assert; import org.junit.Before; import org.junit.Test; import org.junit.runner.RunWith; @@ -135,6 +140,7 @@ public void init() throws Exception { messageClientExt.setOffsetMsgId("234"); messageClientExt.setBornHost(new InetSocketAddress(8080)); messageClientExt.setStoreHost(new InetSocketAddress(8080)); + messageClientExt.setQueueOffset(((PullMessageRequestHeader)mock.getArgument(1)).getQueueOffset()); PullResult pullResult = createPullResult(requestHeader, PullStatus.FOUND, Collections.singletonList(messageClientExt)); ((PullCallback)mock.getArgument(4)).onSuccess(pullResult); return pullResult; @@ -174,6 +180,37 @@ public void testPullMessage_Success() throws InterruptedException, RemotingExcep assertThat(messageExts[0].getBody()).isEqualTo(new byte[] {'a'}); } + @Test + public void testShutdownAwait() throws Exception { + final LinkedList consumedOffset = new LinkedList<>(); + pushConsumer.setPullInterval(0); + pushConsumer.setPullThresholdForQueue(100); + pushConsumer.setAwaitTerminationMillisWhenShutdown(60* 1000);//await consume for at most 60 seconds. If we do not set await millis, this test case will not pass + pushConsumer.getDefaultMQPushConsumerImpl().setConsumeMessageService(new ConsumeMessageConcurrentlyService(pushConsumer.getDefaultMQPushConsumerImpl(), new MessageListenerConcurrently() { + @Override public ConsumeConcurrentlyStatus consumeMessage(List msgs, + ConsumeConcurrentlyContext context) { + for (MessageExt msg : msgs) { + try { + Thread.sleep(100); + } catch (InterruptedException e) {e.printStackTrace();} + synchronized (consumedOffset) { + consumedOffset.add(msg.getQueueOffset()); + } + } + return ConsumeConcurrentlyStatus.CONSUME_SUCCESS; + } + })); + pushConsumer.getDefaultMQPushConsumerImpl().doRebalance(); + PullMessageService pullMessageService = mQClientFactory.getPullMessageService(); + pullMessageService.executePullRequestImmediately(createPullRequest()); + Thread.sleep(1000); + pushConsumer.shutdown(); + long persitOffset =pushConsumer.getDefaultMQPushConsumerImpl().getOffsetStore().readOffset(new MessageQueue(topic, brokerName, 0), ReadOffsetType.READ_FROM_MEMORY); + Thread.sleep(1000);//wait for thread pool to continue consume for sometime if not terminated well + Collections.sort(consumedOffset); + Assert.assertEquals("actual consumed offset is not equals to persist offset when shutdown await", consumedOffset.getLast() + 1, persitOffset);//when shutdown with await, the persisted offset should be the latest message offset + } + @Test public void testPullMessage_SuccessWithOrderlyService() throws Exception { final CountDownLatch countDownLatch = new CountDownLatch(1);