Skip to content

Commit

Permalink
Merge pull request spring-projects#369 from garyrussell/INT-2449
Browse files Browse the repository at this point in the history
* INT-2449:
  INT-2449 Fix PubSub Subscriber Accounting
  • Loading branch information
olegz committed Mar 26, 2012
2 parents c5e3bcf + b6b4e25 commit 2f3330c
Show file tree
Hide file tree
Showing 3 changed files with 92 additions and 22 deletions.
Expand Up @@ -24,8 +24,8 @@
import org.springframework.integration.MessageDispatchingException;
import org.springframework.integration.core.MessageHandler;
import org.springframework.integration.core.SubscribableChannel;
import org.springframework.integration.dispatcher.AbstractDispatcher;
import org.springframework.integration.dispatcher.MessageDispatcher;
import org.springframework.integration.dispatcher.UnicastingDispatcher;
import org.springframework.util.Assert;
import org.springframework.util.StringUtils;

Expand All @@ -44,20 +44,31 @@ public abstract class AbstractSubscribableChannel extends AbstractMessageChannel
public boolean subscribe(MessageHandler handler) {
MessageDispatcher dispatcher = this.getRequiredDispatcher();
boolean added = dispatcher.addHandler(handler);
if (added) {
int counter = handlerCounter.incrementAndGet();
if (logger.isInfoEnabled()) {
logger.info("Channel '" + this.getComponentName() + "' has " + counter + " subscriber(s).");
}
}
this.adjustCounterIfNecessary(dispatcher, added ? 1 : 0);
return added;
}

public boolean unsubscribe(MessageHandler handle) {
if (this.getRequiredDispatcher() instanceof UnicastingDispatcher){
handlerCounter.getAndDecrement();
MessageDispatcher dispatcher = this.getRequiredDispatcher();
boolean removed = dispatcher.removeHandler(handle);
this.adjustCounterIfNecessary(dispatcher, removed ? -1 : 0);
return removed;
}

private void adjustCounterIfNecessary(MessageDispatcher dispatcher, int delta) {
if (delta != 0) {
int counter = 0;
if (dispatcher instanceof AbstractDispatcher) {
counter = ((AbstractDispatcher) dispatcher).getHandlerCount();
}
else {
// some other dispatcher - hand-roll the counter
counter = handlerCounter.addAndGet(delta);
}
if (logger.isInfoEnabled()) {
logger.info("Channel '" + this.getComponentName() + "' has " + counter + " subscriber(s).");
}
}
return this.getRequiredDispatcher().removeHandler(handle);
}

@Override
Expand Down
@@ -1,5 +1,5 @@
/*
* Copyright 2002-2011 the original author or authors.
* Copyright 2002-2012 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -39,6 +39,7 @@
* @author Mark Fisher
* @author Iwein Fuld
* @author Oleg Zhurakousky
* @author Gary Russell
*/
public abstract class AbstractDispatcher implements MessageDispatcher {

Expand Down Expand Up @@ -79,4 +80,11 @@ public boolean removeHandler(MessageHandler handler) {
public String toString() {
return this.getClass().getSimpleName() + " with handlers: " + this.handlers.toString();
}

/**
* @return The current number of handlers
*/
public int getHandlerCount() {
return this.handlers.size();
}
}
@@ -1,5 +1,5 @@
/*
* Copyright 2002-2011 the original author or authors.
* Copyright 2002-2012 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -15,24 +15,31 @@
*/
package org.springframework.integration.channel;

import static org.junit.Assert.assertEquals;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;

import java.lang.reflect.Field;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.Executor;

import org.apache.commons.logging.Log;
import org.junit.Test;
import org.mockito.Mockito;

import org.mockito.invocation.InvocationOnMock;
import org.mockito.stubbing.Answer;
import org.springframework.integration.core.MessageHandler;
import org.springframework.integration.dispatcher.MessageDispatcher;
import org.springframework.util.ReflectionUtils;
import org.springframework.util.ReflectionUtils.FieldCallback;

/**
* @author Oleg Zhurakousky
* @author Gary Russell
*
*/
public class P2pChannelTests {
Expand All @@ -41,10 +48,40 @@ public class P2pChannelTests {
public void testDirectChannelLoggingWithMoreThenOneSubscriber() {
final DirectChannel channel = new DirectChannel();
channel.setBeanName("directChannel");


verifySubscriptions(channel);
}

@Test
public void testCustomChannelLoggingWithMoreThenOneSubscriberNotAbstractDispatcher() {
final MessageDispatcher mockDispatcher = mock(MessageDispatcher.class);
when(mockDispatcher.addHandler(Mockito.any(MessageHandler.class))).thenReturn(true);
when(mockDispatcher.removeHandler(Mockito.any(MessageHandler.class))).thenReturn(true).thenReturn(false).thenReturn(true);

final AbstractSubscribableChannel channel = new AbstractSubscribableChannel() {
@Override
protected MessageDispatcher getDispatcher() {
return mockDispatcher;
}
};
channel.setBeanName("customChannel");

verifySubscriptions(channel);
}

/**
* @param channel
*/
private void verifySubscriptions(final AbstractSubscribableChannel channel) {
final Log logger = mock(Log.class);
when(logger.isInfoEnabled()).thenReturn(true);
ReflectionUtils.doWithFields(AbstractMessageChannel.class, new FieldCallback() {
final List<String> logs = new ArrayList<String>();
doAnswer(new Answer<Object>() {
public Object answer(InvocationOnMock invocation) throws Throwable {
logs.add((String) invocation.getArguments()[0]);
return null;
}}).when(logger).info(Mockito.anyString());
ReflectionUtils.doWithFields(AbstractMessageChannel.class, new FieldCallback() {
public void doWith(Field field) throws IllegalArgumentException,
IllegalAccessException {
if ("logger".equals(field.getName())){
Expand All @@ -53,21 +90,35 @@ public void doWith(Field field) throws IllegalArgumentException,
}
}
});

channel.subscribe(mock(MessageHandler.class));
channel.subscribe(mock(MessageHandler.class));
verify(logger, times(2)).info(Mockito.anyString());
String log = "Channel '"
+ channel.getComponentName()
+ "' has "
+ "%d subscriber(s).";

MessageHandler handler1 = mock(MessageHandler.class);
channel.subscribe(handler1);
assertEquals(String.format(log, 1), logs.remove(0));
MessageHandler handler2 = mock(MessageHandler.class);
channel.subscribe(handler2);
assertEquals(String.format(log, 2), logs.remove(0));
channel.unsubscribe(handler1);
assertEquals(String.format(log, 1), logs.remove(0));
channel.unsubscribe(handler1);
assertEquals(0, logs.size());
channel.unsubscribe(handler2);
assertEquals(String.format(log, 0), logs.remove(0));
verify(logger, times(4)).info(Mockito.anyString());
}

@Test
public void testExecutorChannelLoggingWithMoreThenOneSubscriber() {
final ExecutorChannel channel = new ExecutorChannel(mock(Executor.class));
channel.setBeanName("executorChannel");

final Log logger = mock(Log.class);
when(logger.isInfoEnabled()).thenReturn(true);
ReflectionUtils.doWithFields(AbstractMessageChannel.class, new FieldCallback() {

public void doWith(Field field) throws IllegalArgumentException,
IllegalAccessException {
if ("logger".equals(field.getName())){
Expand All @@ -85,7 +136,7 @@ public void doWith(Field field) throws IllegalArgumentException,
public void testPubSubChannelLoggingWithMoreThenOneSubscriber() {
final PublishSubscribeChannel channel = new PublishSubscribeChannel();
channel.setBeanName("pubSubChannel");

final Log logger = mock(Log.class);
when(logger.isInfoEnabled()).thenReturn(true);
ReflectionUtils.doWithFields(AbstractMessageChannel.class, new FieldCallback() {
Expand Down

0 comments on commit 2f3330c

Please sign in to comment.