Skip to content

Commit

Permalink
Do same to InterfacedPersistentActor. #21
Browse files Browse the repository at this point in the history
  • Loading branch information
veblush committed May 22, 2016
1 parent c442e07 commit 20b6b9c
Show file tree
Hide file tree
Showing 3 changed files with 97 additions and 22 deletions.
4 changes: 2 additions & 2 deletions core/Akka.Interfaced/InterfacedActor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ public InterfacedActor()
_handler = InterfacedActorHandlerTable.Get(GetType());
}

// Atomic async OnStart event (it will be called after OnStart)
// Atomic async OnStart event (it will be called after PreStart, PostRestart)
protected virtual Task OnStart(bool restarted)
{
return Task.FromResult(true);
Expand Down Expand Up @@ -78,7 +78,7 @@ private void InitializeActorState()

private void InvokeOnStart(bool restarted)
{
var context = new MessageHandleContext { Self = Self, Sender = Sender, CancellationToken = CancellationToken };
var context = new MessageHandleContext { Self = Self, Sender = base.Sender, CancellationToken = CancellationToken };
BecomeStacked(OnReceiveInAtomicTask);
_currentAtomicContext = context;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ public TestNotepadActor(string id, List<string> eventLog)
_eventLog = eventLog;
}

protected override Task OnStart()
protected override Task OnStart(bool restarted)
{
_state = new NotepadState { Document = new List<string>() };
return Task.FromResult(0);
Expand Down
113 changes: 94 additions & 19 deletions plugins/Akka.Interfaced.Persistence/InterfacedPersistentActor.cs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
using System;
using System.Collections.Generic;
using System.Threading;
using System.Threading.Tasks;
using Akka.Actor;
using Akka.Persistence;
Expand All @@ -9,18 +10,15 @@ namespace Akka.Interfaced.Persistence
public abstract class InterfacedPersistentActor : UntypedPersistentActor, IRequestWaiter, IFilterPerInstanceProvider
{
private readonly InterfacedActorHandler _handler;
private CancellationTokenSource _cancellationTokenSource;
private int _activeReentrantCount;
private HashSet<MessageHandleContext> _activeReentrantAsyncRequestSet;
private MessageHandleContext _currentAtomicContext;
private InterfacedActorRequestWaiter _requestWaiter;
private InterfacedActorObserverMap _observerMap;
private InterfacedActorPerInstanceFilterList _perInstanceFilterList;
private Dictionary<long, TaskCompletionSource<SnapshotMetadata>> _saveSnapshotTcsMap;

public InterfacedPersistentActor()
{
_handler = InterfacedActorHandlerTable.Get(GetType());
}

protected new IActorRef Sender
{
get
Expand All @@ -30,8 +28,16 @@ public InterfacedPersistentActor()
}
}

// Atomic async OnStart event (it will be called after PreStart)
protected virtual Task OnStart()
// Return a token which will be cancelled when an actor stops or restarts.
protected CancellationToken CancellationToken => _cancellationTokenSource.Token;

public InterfacedPersistentActor()
{
_handler = InterfacedActorHandlerTable.Get(GetType());
}

// Atomic async OnStart event (it will be called after PreStart, PostRestart)
protected virtual Task OnStart(bool restarted)
{
return Task.FromResult(true);
}
Expand All @@ -45,31 +51,48 @@ protected virtual Task OnGracefulStop()

public override void AroundPreStart()
{
if (_handler.PerInstanceFilterCreators.Count > 0)
_perInstanceFilterList = new InterfacedActorPerInstanceFilterList(this, _handler.PerInstanceFilterCreators);

InitializeActorState();
base.AroundPreStart();
InvokeOnStart(false);
}

public override void AroundPostRestart(Exception cause, object message)
{
InitializeActorState();
base.AroundPostRestart(cause, message);
InvokeOnStart(true);
}

private void InitializeActorState()
{
_cancellationTokenSource = new CancellationTokenSource();
_activeReentrantCount = 0;
_activeReentrantAsyncRequestSet = null;
_currentAtomicContext = null;
_requestWaiter = null;
_observerMap = null;

InvokeOnStart();
if (_handler.PerInstanceFilterCreators.Count > 0)
_perInstanceFilterList = new InterfacedActorPerInstanceFilterList(this, _handler.PerInstanceFilterCreators);
}

private void InvokeOnStart()
private void InvokeOnStart(bool restarted)
{
var context = new MessageHandleContext { Self = Self, Sender = base.Sender };
var context = new MessageHandleContext { Self = Self, Sender = base.Sender, CancellationToken = CancellationToken };
BecomeStacked(OnReceiveInAtomicTask);
_currentAtomicContext = context;

using (new SynchronizationContextSwitcher(new ActorSynchronizationContext(context)))
{
OnStart().ContinueWith(
OnStart(restarted).ContinueWith(
t => OnTaskCompleted(t.Exception, false),
TaskContinuationOptions.ExecuteSynchronously);
}
}

private void InvokeOnGracefulStop()
{
var context = new MessageHandleContext { Self = Self };
var context = new MessageHandleContext { Self = Self, CancellationToken = CancellationToken };
BecomeStacked(OnReceiveInAtomicTask);
_currentAtomicContext = context;

Expand All @@ -81,6 +104,46 @@ private void InvokeOnGracefulStop()
}
}

public override void AroundPreRestart(Exception cause, object message)
{
CancelAllTasks();
base.AroundPreRestart(cause, message);
}

public override void AroundPostStop()
{
CancelAllTasks();
base.AroundPostStop();
}

private void CancelAllTasks()
{
_cancellationTokenSource.Cancel();

// Send responses to requesters that waits for a reentrant async job

if (_activeReentrantAsyncRequestSet != null)
{
foreach (var i in _activeReentrantAsyncRequestSet)
{
i.Sender.Tell(new ResponseMessage
{
RequestId = i.RequestId,
Exception = new RequestHaltException()
});
}
}

if (_currentAtomicContext != null && _currentAtomicContext.RequestId != 0)
{
_currentAtomicContext.Sender.Tell(new ResponseMessage
{
RequestId = _currentAtomicContext.RequestId,
Exception = new RequestHaltException()
});
}
}

protected override void OnRecover(object message)
{
var messageHandler = _handler.MessageDispatcher.GetHandler(message.GetType());
Expand Down Expand Up @@ -192,10 +255,17 @@ private void OnRequestMessage(RequestMessage request)
{
// async handle

var context = new MessageHandleContext { Self = Self, Sender = base.Sender };
var context = new MessageHandleContext { Self = Self, Sender = base.Sender, CancellationToken = CancellationToken, RequestId = request.RequestId };
if (handlerItem.IsReentrant)
{
_activeReentrantCount += 1;
if (request.RequestId != 0)
{
if (_activeReentrantAsyncRequestSet == null)
_activeReentrantAsyncRequestSet = new HashSet<MessageHandleContext>();

_activeReentrantAsyncRequestSet.Add(context);
}
}
else
{
Expand All @@ -210,7 +280,12 @@ private void OnRequestMessage(RequestMessage request)
handlerItem.AsyncHandler(this, request, (response, exception) =>
{
if (requestId != 0)
{
if (isReentrant)
_activeReentrantAsyncRequestSet.Remove(context);
sender.Tell(response);
}
OnTaskCompleted(exception, isReentrant);
});
Expand Down Expand Up @@ -249,7 +324,7 @@ private void OnNotificationMessage(NotificationMessage notification)
{
// async handle

var context = new MessageHandleContext { Self = Self, Sender = base.Sender };
var context = new MessageHandleContext { Self = Self, Sender = base.Sender, CancellationToken = CancellationToken };
if (handlerItem.IsReentrant)
{
_activeReentrantCount += 1;
Expand All @@ -276,7 +351,7 @@ private void OnNotificationMessage(NotificationMessage notification)

private void OnTaskRunMessage(TaskRunMessage taskRunMessage)
{
var context = new MessageHandleContext { Self = Self, Sender = base.Sender };
var context = new MessageHandleContext { Self = Self, Sender = base.Sender, CancellationToken = CancellationToken };
if (taskRunMessage.IsReentrant)
{
_activeReentrantCount += 1;
Expand Down Expand Up @@ -311,7 +386,7 @@ private void HandleMessageByHandler(object message, MessageHandlerItem handlerIt
{
if (handlerItem.AsyncHandler != null)
{
var context = new MessageHandleContext { Self = Self, Sender = base.Sender };
var context = new MessageHandleContext { Self = Self, Sender = base.Sender, CancellationToken = CancellationToken };
if (handlerItem.IsReentrant)
{
_activeReentrantCount += 1;
Expand Down

0 comments on commit 20b6b9c

Please sign in to comment.