Skip to content

Commit

Permalink
Merge pull request #4401 from Particular/hotfix-6.0.5
Browse files Browse the repository at this point in the history
Reuse of incoming message as saga timeout message causes duplicate handler invocations
  • Loading branch information
timbussmann committed Dec 16, 2016
2 parents 56dfc0d + 0895202 commit e573354
Show file tree
Hide file tree
Showing 7 changed files with 167 additions and 24 deletions.
Expand Up @@ -93,6 +93,7 @@
<Compile Include="Sagas\When_a_base_class_mapped_is_handled_by_a_saga.cs" />
<Compile Include="Sagas\When_a_base_class_message_starts_a_saga.cs" />
<Compile Include="Sagas\When_saga_started_concurrently.cs" />
<Compile Include="Sagas\When_using_a_received_message_for_timeout.cs" />
<Compile Include="Serialization\When_configuring_custom_xml_namespace.cs" />
<Compile Include="Serialization\When_registering_additional_deserializers.cs" />
<Compile Include="Basic\When_no_content_type.cs" />
Expand Down Expand Up @@ -309,7 +310,7 @@
<Compile Include="Sagas\When_started_by_event_from_another_saga.cs" />
<Compile Include="Sagas\When_a_existing_saga_instance_exists.cs" />
<Compile Include="Sagas\When_two_sagas_subscribe_to_the_same_event.cs" />
<Compile Include="Sagas\When_using_a_received_message_for_timeout.cs" />
<Compile Include="Sagas\When_handling_message_with_handler_and_timeout_handler.cs" />
<Compile Include="Sagas\When_receiving_that_completes_the_saga.cs" />
<Compile Include="Sagas\When_receiving_that_should_start_a_saga.cs" />
<Compile Include="Sagas\When_saga_message_goes_through_delayed_retries.cs" />
Expand Down
@@ -0,0 +1,78 @@
namespace NServiceBus.AcceptanceTests.Sagas
{
using System;
using System.Threading.Tasks;
using AcceptanceTesting;
using EndpointTemplates;
using Features;
using NUnit.Framework;

public class When_handling_message_with_handler_and_timeout_handler : NServiceBusAcceptanceTest
{
[Test]
public async Task Should_not_invoke_timeout_handler()
{
var context = await Scenario.Define<Context>()
.WithEndpoint<TimeoutSagaEndpoint>(g => g.When(session => session.SendLocal(new StartSagaMessage
{
SomeId = Guid.NewGuid()
})))
.Done(c => c.HandlerInvoked || c.TimeoutHandlerInvoked)
.Run();

Assert.True(context.HandlerInvoked, "Regular handler should be invoked");
Assert.False(context.TimeoutHandlerInvoked, "Timeout handler should not be invoked");
}

public class Context : ScenarioContext
{
public bool TimeoutHandlerInvoked { get; set; }
public bool HandlerInvoked { get; set; }
}

public class TimeoutSagaEndpoint : EndpointConfigurationBuilder
{
public TimeoutSagaEndpoint()
{
EndpointSetup<DefaultServer>(config => config.EnableFeature<TimeoutManager>());
}

public class HandlerAndTimeoutSaga : Saga<HandlerAndTimeoutSagaData>, IAmStartedByMessages<StartSagaMessage>,
IHandleTimeouts<StartSagaMessage>
{
public Context TestContext { get; set; }

public Task Handle(StartSagaMessage message, IMessageHandlerContext context)
{
TestContext.HandlerInvoked = true;
return Task.FromResult(0);
}

public Task Timeout(StartSagaMessage message, IMessageHandlerContext context)
{
TestContext.TimeoutHandlerInvoked = true;
return Task.FromResult(0);
}

protected override void ConfigureHowToFindSaga(SagaPropertyMapper<HandlerAndTimeoutSagaData> mapper)
{
mapper.ConfigureMapping<StartSagaMessage>(m => m.SomeId)
.ToSaga(s => s.SomeId);
}
}

public class HandlerAndTimeoutSagaData : IContainSagaData
{
public virtual Guid SomeId { get; set; }
public virtual Guid Id { get; set; }
public virtual string Originator { get; set; }
public virtual string OriginalMessageId { get; set; }
}
}

public class StartSagaMessage : IMessage
{
public Guid SomeId { get; set; }
}
}
}
Expand Up @@ -10,24 +10,24 @@
public class When_using_a_received_message_for_timeout : NServiceBusAcceptanceTest
{
[Test]
public Task Timeout_should_be_received_after_expiration()
public async Task Timeout_should_be_received_after_expiration()
{
return Scenario.Define<Context>(c => { c.Id = Guid.NewGuid(); })
var context = await Scenario.Define<Context>()
.WithEndpoint<ReceiveMessageForTimeoutEndpoint>(g => g.When(session => session.SendLocal(new StartSagaMessage
{
SomeId = Guid.NewGuid()
})))
.Done(c => c.TimeoutReceived)
.Run();

Assert.True(context.TimeoutReceived);
Assert.AreEqual(1, context.HandlerCalled);
}

public class Context : ScenarioContext
{
public Guid Id { get; set; }

public bool StartSagaMessageReceived { get; set; }

public bool TimeoutReceived { get; set; }
public int HandlerCalled { get; set; }
}

public class ReceiveMessageForTimeoutEndpoint : EndpointConfigurationBuilder
Expand All @@ -43,7 +43,7 @@ public class TestSaga01 : Saga<TestSagaData01>, IAmStartedByMessages<StartSagaMe

public Task Handle(StartSagaMessage message, IMessageHandlerContext context)
{
Data.SomeId = message.SomeId;
TestContext.HandlerCalled++;
return RequestTimeout(context, TimeSpan.FromMilliseconds(100), message);
}

Expand Down
@@ -1,10 +1,12 @@
namespace NServiceBus.Core.Tests.Handlers
{
using System;
using System.Linq;
using System.Threading.Tasks;
using NUnit.Framework;
using Unicast;

[TestFixture]
public class MessageHandlerRegistryTests
{
[TestCase(typeof(HandlerWithIMessageSessionProperty))]
Expand All @@ -20,6 +22,35 @@ public void ShouldThrowIfUserTriesToBypassTheHandlerContext(Type handlerType)
Assert.Throws<Exception>(() => registry.RegisterHandler(handlerType));
}

[Test] public void ShouldIndicateWhetherAHandlerIsATimeoutHandler() { var registry = new MessageHandlerRegistry(new Conventions()); registry.RegisterHandler(typeof(SagaWithTimeoutOfMessage)); var handlers = registry.GetHandlersFor(typeof(MyMessage));

Assert.AreEqual(2, handlers.Count);

var timeoutHandler = handlers.SingleOrDefault(h => h.IsTimeoutHandler);

Assert.NotNull(timeoutHandler, "Timeout handler should be marked as such");

var timeoutInstance = new SagaWithTimeoutOfMessage();

timeoutHandler.Instance = timeoutInstance;
timeoutHandler.Invoke(null, null);

Assert.True(timeoutInstance.TimeoutCalled);
Assert.False(timeoutInstance.HandlerCalled);

var regularHandler = handlers.SingleOrDefault(h => !h.IsTimeoutHandler);

Assert.NotNull(regularHandler, "Regular handler should be marked as timeout handler");

var regularInstance = new SagaWithTimeoutOfMessage();

regularHandler.Instance = regularInstance;
regularHandler.Invoke(null, null);

Assert.False(regularInstance.TimeoutCalled);
Assert.True(regularInstance.HandlerCalled);
}

class HandlerWithIMessageSessionProperty : IHandleMessages<MyMessage>
{
public IMessageSession MessageSession { get; set; }
Expand Down Expand Up @@ -107,8 +138,21 @@ class HandlerBaseWithIMessageSessionDep
public IMessageSession MessageSession { get; set; }
}

class MyMessage
class MyMessage : IMessage
{
}

class SagaWithTimeoutOfMessage : Saga<SagaWithTimeoutOfMessage.MySagaData>, IAmStartedByMessages<MyMessage>, IHandleTimeouts<MyMessage> { public Task Handle(MyMessage message, IMessageHandlerContext context) {
HandlerCalled = true;
return TaskEx.CompletedTask;
} protected override void ConfigureHowToFindSaga(SagaPropertyMapper<MySagaData> mapper) { throw new NotImplementedException(); } public Task Timeout(MyMessage state, IMessageHandlerContext context)
{
TimeoutCalled = true;
return TaskEx.CompletedTask;
}

public bool HandlerCalled { get; set; }
public bool TimeoutCalled { get; set; } public class MySagaData : ContainSagaData { } }

}
}
Expand Down
2 changes: 2 additions & 0 deletions src/NServiceBus.Core/Pipeline/Incoming/MessageHandler.cs
Expand Up @@ -29,6 +29,8 @@ public MessageHandler(Func<object, object, IMessageHandlerContext, Task> invocat
/// </summary>
public Type HandlerType { get; private set; }

internal bool IsTimeoutHandler { get; set; }

/// <summary>
/// Invokes the message handler.
/// </summary>
Expand Down
21 changes: 17 additions & 4 deletions src/NServiceBus.Core/Sagas/SagaPersistenceBehavior.cs
Expand Up @@ -21,6 +21,19 @@ public SagaPersistenceBehavior(ISagaPersister persister, ICancelDeferredMessages

public async Task Invoke(IInvokeHandlerContext context, Func<IInvokeHandlerContext, Task> next)
{
var isTimeoutMessage = IsTimeoutMessage(context.Headers);
var isTimeoutHandler = context.MessageHandler.IsTimeoutHandler;

if (isTimeoutHandler && !isTimeoutMessage)
{
return;
}

if (!isTimeoutHandler && isTimeoutMessage)
{
return;
}

currentContext = context;

RemoveSagaHeadersIfProcessingAEvent(context);
Expand Down Expand Up @@ -63,7 +76,7 @@ public async Task Invoke(IInvokeHandlerContext context, Func<IInvokeHandlerConte
sagaInstanceState.MarkAsNotFound();

//we don't invoke not found handlers for timeouts
if (IsTimeoutMessage(context.Headers))
if (isTimeoutMessage)
{
context.Extensions.Get<SagaInvocationResult>().SagaFound();
logger.InfoFormat("No saga found for timeout message {0}, ignoring since the saga has been marked as complete before the timeout fired", context.MessageId);
Expand Down Expand Up @@ -225,7 +238,7 @@ Task<IContainSagaData> TryLoadSagaEntity(SagaMetadata metadata, IInvokeHandlerCo
//since we have a saga id available we can now shortcut the finders and just load the saga
var loaderType = typeof(LoadSagaByIdWrapper<>).MakeGenericType(sagaEntityType);

var loader = (SagaLoader) Activator.CreateInstance(loaderType);
var loader = (SagaLoader)Activator.CreateInstance(loaderType);

return loader.Load(sagaPersister, sagaId, context.SynchronizedStorageSession, context.Extensions);
}
Expand All @@ -239,7 +252,7 @@ Task<IContainSagaData> TryLoadSagaEntity(SagaMetadata metadata, IInvokeHandlerCo
}

var finderType = finderDefinition.Type;
var finder = (SagaFinder) currentContext.Builder.Build(finderType);
var finder = (SagaFinder)currentContext.Builder.Build(finderType);

return finder.Find(currentContext.Builder, finderDefinition, context.SynchronizedStorageSession, context.Extensions, context.MessageBeingHandled);
}
Expand All @@ -261,7 +274,7 @@ IContainSagaData CreateNewSagaEntity(SagaMetadata metadata, IInvokeHandlerContex
{
var sagaEntityType = metadata.SagaEntityType;

var sagaEntity = (IContainSagaData) Activator.CreateInstance(sagaEntityType);
var sagaEntity = (IContainSagaData)Activator.CreateInstance(sagaEntityType);

sagaEntity.Id = CombGuid.Generate();
sagaEntity.OriginalMessageId = context.MessageId;
Expand Down
27 changes: 16 additions & 11 deletions src/NServiceBus.Core/Unicast/MessageHandlerRegistry.cs
Expand Up @@ -37,11 +37,14 @@ public List<MessageHandler> GetHandlersFor(Type messageType)
{
var handlerType = handlersAndMessages.Key;
// ReSharper disable once LoopCanBeConvertedToQuery
foreach (var messagesBeingHandled in handlersAndMessages.Value)
foreach (var handlerDelegate in handlersAndMessages.Value)
{
if (messagesBeingHandled.MessageType.IsAssignableFrom(messageType))
if (handlerDelegate.MessageType.IsAssignableFrom(messageType))
{
messageHandlers.Add(new MessageHandler(messagesBeingHandled.MethodDelegate, handlerType));
messageHandlers.Add(new MessageHandler(handlerDelegate.MethodDelegate, handlerType)
{
IsTimeoutHandler = handlerDelegate.IsTimeoutHandler
});
}
}
}
Expand Down Expand Up @@ -99,11 +102,11 @@ public void Clear()

static void CacheHandlerMethods(Type handler, Type messageType, ICollection<DelegateHolder> typeList)
{
CacheMethod(handler, messageType, typeof(IHandleMessages<>), typeList);
CacheMethod(handler, messageType, typeof(IHandleTimeouts<>), typeList);
CacheMethod(handler, messageType, typeof(IHandleMessages<>), typeList, isTimeoutHandler: false);
CacheMethod(handler, messageType, typeof(IHandleTimeouts<>), typeList, isTimeoutHandler: true);
}

static void CacheMethod(Type handler, Type messageType, Type interfaceGenericType, ICollection<DelegateHolder> methodList)
static void CacheMethod(Type handler, Type messageType, Type interfaceGenericType, ICollection<DelegateHolder> methodList, bool isTimeoutHandler)
{
var handleMethod = GetMethod(handler, messageType, interfaceGenericType);
if (handleMethod == null)
Expand All @@ -115,7 +118,8 @@ static void CacheMethod(Type handler, Type messageType, Type interfaceGenericTyp
var delegateHolder = new DelegateHolder
{
MessageType = messageType,
MethodDelegate = handleMethod
MethodDelegate = handleMethod,
IsTimeoutHandler = isTimeoutHandler
};
methodList.Add(delegateHolder);
}
Expand Down Expand Up @@ -153,12 +157,12 @@ static void CacheMethod(Type handler, Type messageType, Type interfaceGenericTyp
static Type[] GetMessageTypesBeingHandledBy(Type type)
{
return (from t in type.GetInterfaces()
where t.IsGenericType
let potentialMessageType = t.GetGenericArguments()[0]
where
where t.IsGenericType
let potentialMessageType = t.GetGenericArguments()[0]
where
typeof(IHandleMessages<>).MakeGenericType(potentialMessageType).IsAssignableFrom(t) ||
typeof(IHandleTimeouts<>).MakeGenericType(potentialMessageType).IsAssignableFrom(t)
select potentialMessageType)
select potentialMessageType)
.Distinct()
.ToArray();
}
Expand All @@ -185,6 +189,7 @@ void ValidateHandlerType(Type handlerType)

class DelegateHolder
{
public bool IsTimeoutHandler { get; set; }
public Type MessageType;
public Func<object, object, IMessageHandlerContext, Task> MethodDelegate;
}
Expand Down

0 comments on commit e573354

Please sign in to comment.