Skip to content

Commit

Permalink
forced serialization sender middleware (#1759)
Browse files Browse the repository at this point in the history
  • Loading branch information
marcinbudny authored Sep 23, 2022
1 parent d56ddd1 commit 6b34459
Show file tree
Hide file tree
Showing 3 changed files with 345 additions and 0 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
// -----------------------------------------------------------------------
// <copyright file = "ForceSerializationSenderMiddleware.cs" company = "Asynkron AB">
// Copyright (C) 2015-2022 Asynkron AB All rights reserved
// </copyright>
// -----------------------------------------------------------------------
using System;
using Google.Protobuf;
using Microsoft.Extensions.Logging;

namespace Proto.Remote;

public static class ForcedSerializationSenderMiddleware
{
private static readonly ILogger Logger = Log.CreateLogger(nameof(ForcedSerializationSenderMiddleware));

/// <summary>
/// Returns sender middleware that forces serialization of the message. This middleware serializes and then deserializes the message before
/// sending it further down the pipeline. It simulates the serialization process in <see cref="Endpoint"/>.
/// Useful for testing if serialization is working correctly and the messages are immutable.
/// </summary>
/// <param name="shouldSerialize">
/// A predicate that can prevent serialization by returning false.
/// If null, it defaults to <see cref="SkipInternalProtoMessages"/>
/// </param>
/// <returns>
/// Middleware configuration function, to be used with WithSenderMiddleware on
/// <see cref="Props"/> or on <see cref="RootContext"/> configuration
/// </returns>
public static Func<Sender, Sender> Create(Func<Proto.MessageEnvelope, bool>? shouldSerialize = null)
{
shouldSerialize ??= SkipInternalProtoMessages;

return next =>
(context, target, envelope) => {
object? message = null;
PID? sender;
Proto.MessageHeader headers;
try
{
if (shouldSerialize?.Invoke(envelope) == false)
return next(context, target, envelope);
var serialization = context.System.Serialization();
// serialize
(message, sender, headers) = Proto.MessageEnvelope.Unwrap(envelope);
if (message is IRootSerializable rootSerializable)
message = rootSerializable.Serialize(context.System);
if (message is null)
throw new Exception("Null message passed to the forced serialization middleware");
var (bytes, typeName, serializerId) = serialization.Serialize(message);
// deserialize
var deserializedMessage = serialization.Deserialize(typeName, bytes, serializerId);
if (message is IRootSerialized rootDeserialized)
deserializedMessage = rootDeserialized.Deserialize(context.System);
// forward
var newEnvelope = new Proto.MessageEnvelope(deserializedMessage, sender, headers);
return next(context, target, newEnvelope);
}
catch (CodedOutputStream.OutOfSpaceException oom)
{
Logger.LogError(oom, "Message is too large for serialization {Message}", message?.GetType().Name);
throw;
}
catch (Exception ex)
{
ex.CheckFailFast();
Logger.LogError(ex, "Forced serialization -> deserialization failed for message {Message}", message?.GetType().Name);
throw;
}
};
}

/// <summary>
/// Predicate to skip serialization of internal Proto messages
/// </summary>
/// <param name="envelope"></param>
/// <returns></returns>
public static bool SkipInternalProtoMessages(Proto.MessageEnvelope envelope)
{
var (message, _, _) = Proto.MessageEnvelope.Unwrap(envelope);
return message.GetType().FullName?.StartsWith("Proto.") == false;
}
}
59 changes: 59 additions & 0 deletions tests/Proto.Cluster.Tests/ForcedSerializationTests.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
// -----------------------------------------------------------------------
// <copyright file = "ForcedSerializationTests.cs" company = "Asynkron AB">
// Copyright (C) 2015-2022 Asynkron AB All rights reserved
// </copyright>
// -----------------------------------------------------------------------
using System.Linq;
using System.Threading.Tasks;
using ClusterTest.Messages;
using FluentAssertions;
using Proto.Cluster.Gossip;
using Proto.Remote;
using Xunit;

namespace Proto.Cluster.Tests;

public class ForcedSerializationTests
{
[Fact]
public async Task Forced_serialization_works_correctly_in_a_cluster()
{
await using var fixture = new ForcedSerializationClusterFixture();
await fixture.InitializeAsync();
var entryMember = fixture.Members.First();

var testData = Enumerable.Range(1, 100).Select(i => i.ToString()).ToList();

var tasks = testData.Select(id => entryMember.Ping(id, id, CancellationTokens.FromSeconds(10))).ToList();
await Task.WhenAll(tasks);

var results = tasks.Select(t => t.Result.Message).ToList();

results.Should().BeEquivalentTo(testData);
}

[Fact]
public void The_test_messages_are_allowed_by_the_default_predicate()
{
var predicate = ForcedSerializationSenderMiddleware.SkipInternalProtoMessages;

predicate(MessageEnvelope.Wrap(new Ping())).Should().BeTrue();
}

[Fact]
public void Sample_internal_proto_messages_are_not_allowed_by_the_default_predicate()
{
var predicate = ForcedSerializationSenderMiddleware.SkipInternalProtoMessages;

predicate(MessageEnvelope.Wrap(new GetGossipStateRequest("test"))).Should().BeFalse();
predicate(MessageEnvelope.Wrap(new GossipState())).Should().BeFalse();
}

private class ForcedSerializationClusterFixture : InMemoryClusterFixture
{
protected override ActorSystemConfig GetActorSystemConfig() =>
base.GetActorSystemConfig().WithConfigureRootContext(
conf => conf.WithSenderMiddleware(ForcedSerializationSenderMiddleware.Create())
);
}
}
194 changes: 194 additions & 0 deletions tests/Proto.Remote.Tests/ForcedSerializationTests.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,194 @@
// -----------------------------------------------------------------------
// <copyright file = "ForcedSerializationTests.cs" company = "Asynkron AB">
// Copyright (C) 2015-2022 Asynkron AB All rights reserved
// </copyright>
// -----------------------------------------------------------------------
using System;
using System.Collections.Generic;
using System.Threading;
using System.Threading.Tasks;
using FluentAssertions;
using ForcedSerialization.TestMessages;
using Proto;
using Proto.Remote;
using Xunit;

namespace Proto.Remote.Tests
{
public class ForcedSerializationTests
{
private object _receivedMessage;
private PID _sender;
private Proto.MessageHeader _header;
private readonly ManualResetEvent _wait = new(false);
private readonly Props _receivingActorProps;
private readonly Props _sendingActorProps;

public ForcedSerializationTests()
{
_receivingActorProps = Props.FromFunc(ctx => {
if (ctx.Message is TestMessage or TestRootSerializableMessage)
{
_receivedMessage = ctx.Message;
_sender = ctx.Sender;
_header = ctx.Headers;
ctx.Respond(new TestResponse());
_wait.Set();
}
return Task.CompletedTask;
}
);

_sendingActorProps = Props.FromFunc(ctx => {
switch (ctx.Message)
{
case RunRequestAsync msg:
_ = ctx.RequestWithHeadersAsync<TestResponse>(msg.Target, new TestMessage("From another actor"), msg.Headers);
break;
case RunRequest msg:
ctx.Request(msg.Target, new TestMessage("From another actor"));
break;
}
return Task.CompletedTask;
}
).WithSenderMiddleware(ForcedSerializationSenderMiddleware.Create());
}

[Fact]
public void The_test_messages_are_allowed_by_the_default_predicate()
{
var predicate = ForcedSerializationSenderMiddleware.SkipInternalProtoMessages;

predicate(Proto.MessageEnvelope.Wrap(new TestMessage("test"))).Should().BeTrue();
predicate(Proto.MessageEnvelope.Wrap(new TestRootSerializableMessage("test"))).Should().BeTrue();
}

[Fact]
public void Sample_internal_proto_messages_are_not_allowed_by_the_default_predicate()
{
var predicate = ForcedSerializationSenderMiddleware.SkipInternalProtoMessages;

predicate(Proto.MessageEnvelope.Wrap(Started.Instance)).Should().BeFalse();
predicate(Proto.MessageEnvelope.Wrap(new RemoteDeliver(null!, null!, null!, null))).Should().BeFalse();
}

[Fact]
public void It_serializes_and_deserializes()
{
var system = new ActorSystem(ActorSystemConfig.Setup()
.WithConfigureRootContext(ctx => ctx.WithSenderMiddleware(
ForcedSerializationSenderMiddleware.Create()
)
)
);
system.Extensions.Register(new Serialization());

var pid = system.Root.Spawn(_receivingActorProps);
var sentMessage = new TestMessage("Serialized");
system.Root.Send(pid, sentMessage);

_wait.WaitOne(TimeSpan.FromSeconds(2));

_receivedMessage.Should().BeEquivalentTo(sentMessage, "the received message should be the same as the sent message");
_receivedMessage.Should().NotBeSameAs(sentMessage, "the message should have been serialized");
}

[Fact]
public void It_should_not_serialize_if_predicate_prevents_it()
{
var system = new ActorSystem(ActorSystemConfig.Setup()
.WithConfigureRootContext(ctx => ctx.WithSenderMiddleware(
ForcedSerializationSenderMiddleware.Create(_ => false)
)
)
);
system.Extensions.Register(new Serialization());

var pid = system.Root.Spawn(_receivingActorProps);
var sentMessage = new TestMessage("Not serialized");
system.Root.Send(pid, sentMessage);

_wait.WaitOne(TimeSpan.FromSeconds(2));

_receivedMessage.Should().BeEquivalentTo(sentMessage, "the received message should be the same as the sent message");
_receivedMessage.Should().BeSameAs(sentMessage, "the message should not have been serialized");
}

[Fact]
public async Task It_preserves_headers()
{
await using var system = new ActorSystem(ActorSystemConfig.Setup());
system.Extensions.Register(new Serialization());

var pid = system.Root.Spawn(_receivingActorProps);
var sender = system.Root.Spawn(_sendingActorProps);

var headers = new Proto.MessageHeader(new Dictionary<string, string> {{"key", "value"}});
system.Root.Send(sender, new RunRequestAsync(pid, headers));

_wait.WaitOne(TimeSpan.FromSeconds(2));

_header.Should().BeEquivalentTo(headers);
}

[Fact]
public async Task It_preserves_sender()
{
await using var system = new ActorSystem(ActorSystemConfig.Setup());
system.Extensions.Register(new Serialization());

var pid = system.Root.Spawn(_receivingActorProps);
var sender = system.Root.Spawn(_sendingActorProps);

system.Root.Send(sender, new RunRequest(pid, null));

_wait.WaitOne(TimeSpan.FromSeconds(2));

_sender.Should().BeEquivalentTo(sender);
}

[Fact]
public async Task It_can_handle_root_serializable()
{
await using var system = new ActorSystem(ActorSystemConfig.Setup()
.WithConfigureRootContext(ctx => ctx.WithSenderMiddleware(
ForcedSerializationSenderMiddleware.Create()
)
)
);
system.Extensions.Register(new Serialization());

var pid = system.Root.Spawn(_receivingActorProps);
var sentMessage = new TestRootSerializableMessage("Serialized");
system.Root.Send(pid, sentMessage);

_wait.WaitOne(TimeSpan.FromSeconds(2));

_receivedMessage.Should().BeEquivalentTo(sentMessage, "the received message should be the same as the sent message");
_receivedMessage.Should().NotBeSameAs(sentMessage, "the message should have been serialized");
}
}
}

namespace ForcedSerialization.TestMessages
{
record TestMessage(string Value);

record TestRootSerializableMessage(string Value) : IRootSerializable
{
public IRootSerialized Serialize(ActorSystem system) => new TestRootSerializedMessage(Value);
}

record TestRootSerializedMessage(string Value) : IRootSerialized
{
public IRootSerializable Deserialize(ActorSystem system) => new TestRootSerializableMessage(Value);
}

record TestResponse();

record RunRequest(PID Target, Proto.MessageHeader Headers);

record RunRequestAsync(PID Target, Proto.MessageHeader Headers);
}

0 comments on commit 6b34459

Please sign in to comment.