Permalink
Browse files

Merge pull request #369 from garyrussell/INT-2449

* INT-2449:
  INT-2449 Fix PubSub Subscriber Accounting
  • Loading branch information...
2 parents c5e3bcf + b6b4e25 commit 2f3330c78f718cb3df4fd44153dd0bba67794c0d Oleg Zhurakousky committed Mar 26, 2012
@@ -24,8 +24,8 @@
import org.springframework.integration.MessageDispatchingException;
import org.springframework.integration.core.MessageHandler;
import org.springframework.integration.core.SubscribableChannel;
+import org.springframework.integration.dispatcher.AbstractDispatcher;
import org.springframework.integration.dispatcher.MessageDispatcher;
-import org.springframework.integration.dispatcher.UnicastingDispatcher;
import org.springframework.util.Assert;
import org.springframework.util.StringUtils;
@@ -44,20 +44,31 @@
public boolean subscribe(MessageHandler handler) {
MessageDispatcher dispatcher = this.getRequiredDispatcher();
boolean added = dispatcher.addHandler(handler);
- if (added) {
- int counter = handlerCounter.incrementAndGet();
- if (logger.isInfoEnabled()) {
- logger.info("Channel '" + this.getComponentName() + "' has " + counter + " subscriber(s).");
- }
- }
+ this.adjustCounterIfNecessary(dispatcher, added ? 1 : 0);
return added;
}
public boolean unsubscribe(MessageHandler handle) {
- if (this.getRequiredDispatcher() instanceof UnicastingDispatcher){
- handlerCounter.getAndDecrement();
+ MessageDispatcher dispatcher = this.getRequiredDispatcher();
+ boolean removed = dispatcher.removeHandler(handle);
+ this.adjustCounterIfNecessary(dispatcher, removed ? -1 : 0);
+ return removed;
+ }
+
+ private void adjustCounterIfNecessary(MessageDispatcher dispatcher, int delta) {
+ if (delta != 0) {
+ int counter = 0;
+ if (dispatcher instanceof AbstractDispatcher) {
+ counter = ((AbstractDispatcher) dispatcher).getHandlerCount();
+ }
+ else {
+ // some other dispatcher - hand-roll the counter
+ counter = handlerCounter.addAndGet(delta);
+ }
+ if (logger.isInfoEnabled()) {
+ logger.info("Channel '" + this.getComponentName() + "' has " + counter + " subscriber(s).");
+ }
}
- return this.getRequiredDispatcher().removeHandler(handle);
}
@Override
@@ -1,5 +1,5 @@
/*
- * Copyright 2002-2011 the original author or authors.
+ * Copyright 2002-2012 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -39,6 +39,7 @@
* @author Mark Fisher
* @author Iwein Fuld
* @author Oleg Zhurakousky
+ * @author Gary Russell
*/
public abstract class AbstractDispatcher implements MessageDispatcher {
@@ -79,4 +80,11 @@ public boolean removeHandler(MessageHandler handler) {
public String toString() {
return this.getClass().getSimpleName() + " with handlers: " + this.handlers.toString();
}
+
+ /**
+ * @return The current number of handlers
+ */
+ public int getHandlerCount() {
+ return this.handlers.size();
+ }
}
@@ -1,5 +1,5 @@
/*
- * Copyright 2002-2011 the original author or authors.
+ * Copyright 2002-2012 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -15,24 +15,31 @@
*/
package org.springframework.integration.channel;
+import static org.junit.Assert.assertEquals;
+import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
import java.lang.reflect.Field;
+import java.util.ArrayList;
+import java.util.List;
import java.util.concurrent.Executor;
import org.apache.commons.logging.Log;
import org.junit.Test;
import org.mockito.Mockito;
-
+import org.mockito.invocation.InvocationOnMock;
+import org.mockito.stubbing.Answer;
import org.springframework.integration.core.MessageHandler;
+import org.springframework.integration.dispatcher.MessageDispatcher;
import org.springframework.util.ReflectionUtils;
import org.springframework.util.ReflectionUtils.FieldCallback;
/**
* @author Oleg Zhurakousky
+ * @author Gary Russell
*
*/
public class P2pChannelTests {
@@ -41,10 +48,40 @@
public void testDirectChannelLoggingWithMoreThenOneSubscriber() {
final DirectChannel channel = new DirectChannel();
channel.setBeanName("directChannel");
-
+
+ verifySubscriptions(channel);
+ }
+
+ @Test
+ public void testCustomChannelLoggingWithMoreThenOneSubscriberNotAbstractDispatcher() {
+ final MessageDispatcher mockDispatcher = mock(MessageDispatcher.class);
+ when(mockDispatcher.addHandler(Mockito.any(MessageHandler.class))).thenReturn(true);
+ when(mockDispatcher.removeHandler(Mockito.any(MessageHandler.class))).thenReturn(true).thenReturn(false).thenReturn(true);
+
+ final AbstractSubscribableChannel channel = new AbstractSubscribableChannel() {
+ @Override
+ protected MessageDispatcher getDispatcher() {
+ return mockDispatcher;
+ }
+ };
+ channel.setBeanName("customChannel");
+
+ verifySubscriptions(channel);
+ }
+
+ /**
+ * @param channel
+ */
+ private void verifySubscriptions(final AbstractSubscribableChannel channel) {
final Log logger = mock(Log.class);
when(logger.isInfoEnabled()).thenReturn(true);
- ReflectionUtils.doWithFields(AbstractMessageChannel.class, new FieldCallback() {
+ final List<String> logs = new ArrayList<String>();
+ doAnswer(new Answer<Object>() {
+ public Object answer(InvocationOnMock invocation) throws Throwable {
+ logs.add((String) invocation.getArguments()[0]);
+ return null;
+ }}).when(logger).info(Mockito.anyString());
+ ReflectionUtils.doWithFields(AbstractMessageChannel.class, new FieldCallback() {
public void doWith(Field field) throws IllegalArgumentException,
IllegalAccessException {
if ("logger".equals(field.getName())){
@@ -53,21 +90,35 @@ public void doWith(Field field) throws IllegalArgumentException,
}
}
});
-
- channel.subscribe(mock(MessageHandler.class));
- channel.subscribe(mock(MessageHandler.class));
- verify(logger, times(2)).info(Mockito.anyString());
+ String log = "Channel '"
+ + channel.getComponentName()
+ + "' has "
+ + "%d subscriber(s).";
+
+ MessageHandler handler1 = mock(MessageHandler.class);
+ channel.subscribe(handler1);
+ assertEquals(String.format(log, 1), logs.remove(0));
+ MessageHandler handler2 = mock(MessageHandler.class);
+ channel.subscribe(handler2);
+ assertEquals(String.format(log, 2), logs.remove(0));
+ channel.unsubscribe(handler1);
+ assertEquals(String.format(log, 1), logs.remove(0));
+ channel.unsubscribe(handler1);
+ assertEquals(0, logs.size());
+ channel.unsubscribe(handler2);
+ assertEquals(String.format(log, 0), logs.remove(0));
+ verify(logger, times(4)).info(Mockito.anyString());
}
@Test
public void testExecutorChannelLoggingWithMoreThenOneSubscriber() {
final ExecutorChannel channel = new ExecutorChannel(mock(Executor.class));
channel.setBeanName("executorChannel");
-
+
final Log logger = mock(Log.class);
when(logger.isInfoEnabled()).thenReturn(true);
ReflectionUtils.doWithFields(AbstractMessageChannel.class, new FieldCallback() {
-
+
public void doWith(Field field) throws IllegalArgumentException,
IllegalAccessException {
if ("logger".equals(field.getName())){
@@ -85,7 +136,7 @@ public void doWith(Field field) throws IllegalArgumentException,
public void testPubSubChannelLoggingWithMoreThenOneSubscriber() {
final PublishSubscribeChannel channel = new PublishSubscribeChannel();
channel.setBeanName("pubSubChannel");
-
+
final Log logger = mock(Log.class);
when(logger.isInfoEnabled()).thenReturn(true);
ReflectionUtils.doWithFields(AbstractMessageChannel.class, new FieldCallback() {

0 comments on commit 2f3330c

Please sign in to comment.