Skip to content

Commit

Permalink
Improve p2p message deserialization (neo-project#1262)
Browse files Browse the repository at this point in the history
  • Loading branch information
erikzhang authored and Tommo-L committed Jun 22, 2020
1 parent 5eb9432 commit 16c1223
Show file tree
Hide file tree
Showing 6 changed files with 71 additions and 144 deletions.
51 changes: 24 additions & 27 deletions neo.UnitTests/IO/Caching/UT_ReflectionCache.cs
Original file line number Diff line number Diff line change
@@ -1,11 +1,17 @@
using System.IO;
using FluentAssertions;
using Microsoft.VisualStudio.TestTools.UnitTesting;
using Neo.IO;
using Neo.IO.Caching;
using System;

namespace Neo.UnitTests.IO.Caching
{
public class TestItem { }
public class TestItem : ISerializable
{
public int Size => 0;
public void Deserialize(BinaryReader reader) { }
public void Serialize(BinaryWriter writer) { }
}

public class TestItem1 : TestItem { }

Expand All @@ -25,55 +31,46 @@ public enum MyEmptyEnum : byte { }
[TestClass]
public class UT_ReflectionCache
{
ReflectionCache<byte> reflectionCache;

[TestInitialize]
public void SetUp()
{
reflectionCache = ReflectionCache<byte>.CreateFromEnum<MyTestEnum>();
}

[TestMethod]
public void TestCreateFromEnum()
public void TestCreateFromEmptyEnum()
{
reflectionCache.Should().NotBeNull();
ReflectionCache<MyEmptyEnum>.Count.Should().Be(0);
}

[TestMethod]
public void TestCreateFromObjectNotEnum()
public void TestCreateInstance()
{
Action action = () => ReflectionCache<byte>.CreateFromEnum<int>();
action.Should().Throw<ArgumentException>();
}
object item1 = ReflectionCache<MyTestEnum>.CreateInstance(MyTestEnum.Item1, null);
(item1 is TestItem1).Should().BeTrue();

[TestMethod]
public void TestCreateFromEmptyEnum()
{
reflectionCache = ReflectionCache<byte>.CreateFromEnum<MyEmptyEnum>();
reflectionCache.Count.Should().Be(0);
object item2 = ReflectionCache<MyTestEnum>.CreateInstance(MyTestEnum.Item2, null);
(item2 is TestItem2).Should().BeTrue();

object item3 = ReflectionCache<MyTestEnum>.CreateInstance((MyTestEnum)0x02, null);
item3.Should().BeNull();
}

[TestMethod]
public void TestCreateInstance()
public void TestCreateSerializable()
{
object item1 = reflectionCache.CreateInstance((byte)MyTestEnum.Item1, null);
object item1 = ReflectionCache<MyTestEnum>.CreateSerializable(MyTestEnum.Item1, new byte[0]);
(item1 is TestItem1).Should().BeTrue();

object item2 = reflectionCache.CreateInstance((byte)MyTestEnum.Item2, null);
object item2 = ReflectionCache<MyTestEnum>.CreateSerializable(MyTestEnum.Item2, new byte[0]);
(item2 is TestItem2).Should().BeTrue();

object item3 = reflectionCache.CreateInstance(0x02, null);
object item3 = ReflectionCache<MyTestEnum>.CreateSerializable((MyTestEnum)0x02, new byte[0]);
item3.Should().BeNull();
}

[TestMethod]
public void TestCreateInstance2()
{
TestItem defaultItem = new TestItem1();
object item2 = reflectionCache.CreateInstance((byte)MyTestEnum.Item2, defaultItem);
object item2 = ReflectionCache<MyTestEnum>.CreateInstance(MyTestEnum.Item2, defaultItem);
(item2 is TestItem2).Should().BeTrue();

object item1 = reflectionCache.CreateInstance(0x02, new TestItem1());
object item1 = ReflectionCache<MyTestEnum>.CreateInstance((MyTestEnum)0x02, new TestItem1());
(item1 is TestItem1).Should().BeTrue();
}
}
Expand Down
18 changes: 4 additions & 14 deletions neo/Consensus/ConsensusMessage.cs
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,6 @@ namespace Neo.Consensus
{
public abstract class ConsensusMessage : ISerializable
{
/// <summary>
/// Reflection cache for ConsensusMessageType
/// </summary>
private static ReflectionCache<byte> ReflectionCache = ReflectionCache<byte>.CreateFromEnum<ConsensusMessageType>();

public readonly ConsensusMessageType Type;
public byte ViewNumber;

Expand All @@ -31,15 +26,10 @@ public virtual void Deserialize(BinaryReader reader)

public static ConsensusMessage DeserializeFrom(byte[] data)
{
ConsensusMessage message = ReflectionCache.CreateInstance<ConsensusMessage>(data[0]);
if (message == null) throw new FormatException();

using (MemoryStream ms = new MemoryStream(data, false))
using (BinaryReader r = new BinaryReader(ms))
{
message.Deserialize(r);
}
return message;
ConsensusMessageType type = (ConsensusMessageType)data[0];
ISerializable message = ReflectionCache<ConsensusMessageType>.CreateSerializable(type, data);
if (message is null) throw new FormatException();
return (ConsensusMessage)message;
}

public virtual void Serialize(BinaryWriter writer)
Expand Down
74 changes: 19 additions & 55 deletions neo/IO/Caching/ReflectionCache.cs
Original file line number Diff line number Diff line change
@@ -1,80 +1,44 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Reflection;

namespace Neo.IO.Caching
{
public class ReflectionCache<T> : Dictionary<T, Type>
internal static class ReflectionCache<T> where T : Enum
{
/// <summary>
/// Constructor
/// </summary>
public ReflectionCache() { }
/// <summary>
/// Constructor
/// </summary>
/// <typeparam name="EnumType">Enum type</typeparam>
public static ReflectionCache<T> CreateFromEnum<EnumType>() where EnumType : struct, IConvertible
{
Type enumType = typeof(EnumType);

if (!enumType.GetTypeInfo().IsEnum)
throw new ArgumentException("K must be an enumerated type");
private static readonly Dictionary<T, Type> dictionary = new Dictionary<T, Type>();

// Cache all types
ReflectionCache<T> r = new ReflectionCache<T>();
public static int Count => dictionary.Count;

foreach (object t in Enum.GetValues(enumType))
static ReflectionCache()
{
Type enumType = typeof(T);
foreach (FieldInfo field in enumType.GetFields(BindingFlags.Public | BindingFlags.Static))
{
// Get enumn member
MemberInfo[] memInfo = enumType.GetMember(t.ToString());
if (memInfo == null || memInfo.Length != 1)
throw (new FormatException());

// Get attribute
ReflectionCacheAttribute attribute = memInfo[0].GetCustomAttributes(typeof(ReflectionCacheAttribute), false)
.Cast<ReflectionCacheAttribute>()
.FirstOrDefault();

if (attribute == null)
throw (new FormatException());
ReflectionCacheAttribute attribute = field.GetCustomAttribute<ReflectionCacheAttribute>();
if (attribute == null) continue;

// Append to cache
r.Add((T)t, attribute.Type);
dictionary.Add((T)field.GetValue(null), attribute.Type);
}
return r;
}
/// <summary>
/// Create object from key
/// </summary>
/// <param name="key">Key</param>
/// <param name="def">Default value</param>
public object CreateInstance(T key, object def = null)
{
Type tp;

public static object CreateInstance(T key, object def = null)
{
// Get Type from cache
if (TryGetValue(key, out tp)) return Activator.CreateInstance(tp);
if (dictionary.TryGetValue(key, out Type t))
return Activator.CreateInstance(t);

// return null
return def;
}
/// <summary>
/// Create object from key
/// </summary>
/// <typeparam name="K">Type</typeparam>
/// <param name="key">Key</param>
/// <param name="def">Default value</param>
public K CreateInstance<K>(T key, K def = default(K))
{
Type tp;

// Get Type from cache
if (TryGetValue(key, out tp)) return (K)Activator.CreateInstance(tp);

// return null
return def;
public static ISerializable CreateSerializable(T key, byte[] data)
{
if (dictionary.TryGetValue(key, out Type t))
return data.AsSerializable(t);
return null;
}
}
}
5 changes: 3 additions & 2 deletions neo/IO/Caching/ReflectionCacheAttribute.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,13 @@

namespace Neo.IO.Caching
{
public class ReflectionCacheAttribute : Attribute
[AttributeUsage(AttributeTargets.Field, AllowMultiple = false)]
internal class ReflectionCacheAttribute : Attribute
{
/// <summary>
/// Type
/// </summary>
public Type Type { get; private set; }
public Type Type { get; }

/// <summary>
/// Constructor
Expand Down
48 changes: 2 additions & 46 deletions neo/Network/P2P/Message.cs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
using Akka.IO;
using Neo.Cryptography;
using Neo.IO;
using Neo.Network.P2P.Payloads;
using Neo.IO.Caching;
using System;
using System.IO;

Expand Down Expand Up @@ -51,51 +51,7 @@ private void DecompressPayload()
byte[] decompressed = Flags.HasFlag(MessageFlags.Compressed)
? _payload_compressed.DecompressLz4(PayloadMaxSize)
: _payload_compressed;
switch (Command)
{
case MessageCommand.Version:
Payload = decompressed.AsSerializable<VersionPayload>();
break;
case MessageCommand.Addr:
Payload = decompressed.AsSerializable<AddrPayload>();
break;
case MessageCommand.Ping:
case MessageCommand.Pong:
Payload = decompressed.AsSerializable<PingPayload>();
break;
case MessageCommand.GetHeaders:
case MessageCommand.GetBlocks:
Payload = decompressed.AsSerializable<GetBlocksPayload>();
break;
case MessageCommand.Headers:
Payload = decompressed.AsSerializable<HeadersPayload>();
break;
case MessageCommand.Inv:
case MessageCommand.GetData:
Payload = decompressed.AsSerializable<InvPayload>();
break;
case MessageCommand.GetBlockData:
Payload = decompressed.AsSerializable<GetBlockDataPayload>();
break;
case MessageCommand.Transaction:
Payload = decompressed.AsSerializable<Transaction>();
break;
case MessageCommand.Block:
Payload = decompressed.AsSerializable<Block>();
break;
case MessageCommand.Consensus:
Payload = decompressed.AsSerializable<ConsensusPayload>();
break;
case MessageCommand.FilterLoad:
Payload = decompressed.AsSerializable<FilterLoadPayload>();
break;
case MessageCommand.FilterAdd:
Payload = decompressed.AsSerializable<FilterAddPayload>();
break;
case MessageCommand.MerkleBlock:
Payload = decompressed.AsSerializable<MerkleBlockPayload>();
break;
}
Payload = ReflectionCache<MessageCommand>.CreateSerializable(Command, decompressed);
}

void ISerializable.Deserialize(BinaryReader reader)
Expand Down
19 changes: 19 additions & 0 deletions neo/Network/P2P/MessageCommand.cs
Original file line number Diff line number Diff line change
@@ -1,35 +1,54 @@
using Neo.IO.Caching;
using Neo.Network.P2P.Payloads;

namespace Neo.Network.P2P
{
public enum MessageCommand : byte
{
//handshaking
[ReflectionCache(typeof(VersionPayload))]
Version = 0x00,
Verack = 0x01,

//connectivity
GetAddr = 0x10,
[ReflectionCache(typeof(AddrPayload))]
Addr = 0x11,
[ReflectionCache(typeof(PingPayload))]
Ping = 0x18,
[ReflectionCache(typeof(PingPayload))]
Pong = 0x19,

//synchronization
[ReflectionCache(typeof(GetBlocksPayload))]
GetHeaders = 0x20,
[ReflectionCache(typeof(HeadersPayload))]
Headers = 0x21,
[ReflectionCache(typeof(GetBlocksPayload))]
GetBlocks = 0x24,
Mempool = 0x25,
[ReflectionCache(typeof(InvPayload))]
Inv = 0x27,
[ReflectionCache(typeof(InvPayload))]
GetData = 0x28,
[ReflectionCache(typeof(GetBlockDataPayload))]
GetBlockData = 0x29,
NotFound = 0x2a,
[ReflectionCache(typeof(Transaction))]
Transaction = 0x2b,
[ReflectionCache(typeof(Block))]
Block = 0x2c,
[ReflectionCache(typeof(ConsensusPayload))]
Consensus = 0x2d,
Reject = 0x2f,

//SPV protocol
[ReflectionCache(typeof(FilterLoadPayload))]
FilterLoad = 0x30,
[ReflectionCache(typeof(FilterAddPayload))]
FilterAdd = 0x31,
FilterClear = 0x32,
[ReflectionCache(typeof(MerkleBlockPayload))]
MerkleBlock = 0x38,

//others
Expand Down

0 comments on commit 16c1223

Please sign in to comment.