Skip to content

Commit

Permalink
feat: adding attribute to ignore extension method for read writer (#841)
Browse files Browse the repository at this point in the history
* feat: adding attibute to ignore extension method for read writer

* tests

* refactoring ReaderWriterProcessor

adding SerailizeExtensionHelper to handle checking of types and methods to find extension methods

making reflection and cecil methods have the same logic

* adding better fail message to assert

* fixing checking type is void

* fixing code smell
  • Loading branch information
James-Frowen committed Jun 20, 2021
1 parent b9f9ef6 commit 9494500
Show file tree
Hide file tree
Showing 5 changed files with 237 additions and 103 deletions.
6 changes: 6 additions & 0 deletions Assets/Mirage/Runtime/CustomAttributes.cs
Original file line number Diff line number Diff line change
Expand Up @@ -124,4 +124,10 @@ public sealed class ShowInInspectorAttribute : Attribute { }
/// </summary>
[AttributeUsage(AttributeTargets.Field)]
public sealed class FoldoutEventAttribute : PropertyAttribute { }

/// <summary>
/// Tells Weaver to ignore an Extension method
/// </summary>
[AttributeUsage(AttributeTargets.Method)]
public sealed class WeaverIgnoreAttribute : PropertyAttribute { }
}
268 changes: 166 additions & 102 deletions Assets/Mirage/Weaver/Processors/ReaderWriterProcessor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
using System.Collections.Generic;
using System.Linq;
using System.Reflection;
using System.Runtime.CompilerServices;
using Mirage.Serialization;
using Mono.Cecil;
using Mono.Cecil.Cil;
Expand All @@ -18,6 +19,7 @@ public class ReaderWriterProcessor
private readonly ModuleDefinition module;
private readonly Readers readers;
private readonly Writers writers;
private readonly SerailizeExtensionHelper extensionHelper;

/// <summary>
/// Mirage's main module used to find built in extension methods and messages
Expand All @@ -29,6 +31,7 @@ public ReaderWriterProcessor(ModuleDefinition module, Readers readers, Writers w
this.module = module;
this.readers = readers;
this.writers = writers;
extensionHelper = new SerailizeExtensionHelper(module, readers, writers);
}

public bool Process()
Expand All @@ -46,37 +49,18 @@ public bool Process()
return writers.Count != writeCount || readers.Count != readCount;
}


#region Load Mirage built in readers and writers
private void LoadBuiltinExtensions()
{
// find all extension methods
IEnumerable<Type> types = MirageModule.GetTypes().Where(IsStatic);
IEnumerable<Type> types = MirageModule.GetTypes();

foreach (Type type in types)
{
IEnumerable<MethodInfo> methods = type.GetMethods(BindingFlags.Static | BindingFlags.Public)
.Where(IsExtension)
.Where(NotGeneric);

foreach (MethodInfo method in methods)
{
RegisterReader(method);
RegisterWriter(method);
}
extensionHelper.RegisterExtensionMethodsInType(type);
}
}

/// <summary>
/// static classes are declared abstract and sealed at the IL level.
/// <see href="https://stackoverflow.com/a/1175901/8479976"/>
/// </summary>
private static bool IsStatic(Type t) => t.IsSealed && t.IsAbstract;

private static bool IsExtension(MethodInfo method) => Attribute.IsDefined(method, typeof(System.Runtime.CompilerServices.ExtensionAttribute));
private static bool NotGeneric(MethodInfo method) => !method.IsGenericMethod;


private void LoadBuiltinMessages()
{
IEnumerable<Type> types = MirageModule.GetTypes().Where(t => t.GetCustomAttribute<NetworkMessageAttribute>() != null);
Expand All @@ -88,35 +72,6 @@ private void LoadBuiltinMessages()
messages.Add(typeReference);
}
}


private void RegisterReader(MethodInfo method)
{
if (method.GetParameters().Length != 1)
return;

if (method.GetParameters()[0].ParameterType.FullName != typeof(NetworkReader).FullName)
return;

if (method.ReturnType == typeof(void))
return;
readers.Register(module.ImportReference(method.ReturnType), module.ImportReference(method));
}

private void RegisterWriter(MethodInfo method)
{
if (method.GetParameters().Length != 2)
return;

if (method.GetParameters()[0].ParameterType.FullName != typeof(NetworkWriter).FullName)
return;

if (method.ReturnType != typeof(void))
return;

Type dataType = method.GetParameters()[1].ParameterType;
writers.Register(module.ImportReference(dataType), module.ImportReference(method));
}
#endregion

#region Assembly defined reader/writer
Expand All @@ -128,11 +83,7 @@ void ProcessAssemblyClasses()
{
// extension methods only live in static classes
// static classes are represented as sealed and abstract
if (klass.IsAbstract && klass.IsSealed)
{
LoadDeclaredWriters(klass);
LoadDeclaredReaders(klass);
}
extensionHelper.RegisterExtensionMethodsInType(klass);

if (klass.GetCustomAttribute<NetworkMessageAttribute>() != null)
{
Expand Down Expand Up @@ -255,54 +206,7 @@ private static bool IsReadWriteMethod(MethodReference method)
method.Is<NetworkReader>(nameof(NetworkReader.Read));
}

void LoadDeclaredWriters(TypeDefinition klass)
{
// register all the writers in this class. Skip the ones with wrong signature
foreach (MethodDefinition method in klass.Methods)
{
if (method.Parameters.Count != 2)
continue;

if (!method.Parameters[0].ParameterType.Is<NetworkWriter>())
continue;

if (!method.ReturnType.Is(typeof(void)))
continue;

if (!method.HasCustomAttribute<System.Runtime.CompilerServices.ExtensionAttribute>())
continue;

if (method.HasGenericParameters)
continue;

TypeReference dataType = method.Parameters[1].ParameterType;
writers.Register(dataType, module.ImportReference(method));
}
}

void LoadDeclaredReaders(TypeDefinition klass)
{
// register all the reader in this class. Skip the ones with wrong signature
foreach (MethodDefinition method in klass.Methods)
{
if (method.Parameters.Count != 1)
continue;

if (!method.Parameters[0].ParameterType.Is<NetworkReader>())
continue;

if (method.ReturnType.Is(typeof(void)))
continue;

if (!method.HasCustomAttribute<System.Runtime.CompilerServices.ExtensionAttribute>())
continue;

if (method.HasGenericParameters)
continue;

readers.Register(method.ReturnType, module.ImportReference(method));
}
}

private static bool IsEditorAssembly(ModuleDefinition module)
{
Expand Down Expand Up @@ -364,4 +268,164 @@ private void RegisterMessages(ILProcessor worker)

#endregion
}

/// <summary>
/// Helps get Extension methods using either reflection or cecil
/// </summary>
public class SerailizeExtensionHelper
{
private readonly ModuleDefinition module;
private readonly Readers readers;
private readonly Writers writers;

public SerailizeExtensionHelper(ModuleDefinition module, Readers readers, Writers writers)
{
this.module = module;
this.readers = readers;
this.writers = writers;
}


public void RegisterExtensionMethodsInType(Type type)
{
// only check static types
if (!IsStatic(type))
return;

IEnumerable<MethodInfo> methods = type.GetMethods(BindingFlags.Static | BindingFlags.Public)
.Where(IsExtension)
.Where(NotGeneric)
.Where(NotIgnored);

foreach (MethodInfo method in methods)
{
if (IsWriterMethod(method))
{
RegisterWriter(method);
}

if (IsReaderMethod(method))
{
RegisterReader(method);
}
}
}
public void RegisterExtensionMethodsInType(TypeDefinition type)
{
// only check static types
if (!IsStatic(type))
return;

IEnumerable<MethodDefinition> methods = type.Methods
.Where(IsExtension)
.Where(NotGeneric)
.Where(NotIgnored);

foreach (MethodDefinition method in methods)
{
if (IsWriterMethod(method))
{
RegisterWriter(method);
}

if (IsReaderMethod(method))
{
RegisterReader(method);
}
}
}

/// <summary>
/// static classes are declared abstract and sealed at the IL level.
/// <see href="https://stackoverflow.com/a/1175901/8479976"/>
/// </summary>
private static bool IsStatic(Type t) => t.IsSealed && t.IsAbstract;
private static bool IsStatic(TypeDefinition t) => t.IsSealed && t.IsAbstract;

private static bool IsExtension(MethodInfo method) => Attribute.IsDefined(method, typeof(ExtensionAttribute));
private static bool IsExtension(MethodDefinition method) => method.HasCustomAttribute<ExtensionAttribute>();
private static bool NotGeneric(MethodInfo method) => !method.IsGenericMethod;
private static bool NotGeneric(MethodDefinition method) => !method.IsGenericInstance;

/// <returns>true if method does not have <see cref="WeaverIgnoreAttribute"/></returns>
private static bool NotIgnored(MethodInfo method) => !Attribute.IsDefined(method, typeof(WeaverIgnoreAttribute));
/// <returns>true if method does not have <see cref="WeaverIgnoreAttribute"/></returns>
private static bool NotIgnored(MethodDefinition method) => !method.HasCustomAttribute<WeaverIgnoreAttribute>();


private static bool IsWriterMethod(MethodInfo method)
{
if (method.GetParameters().Length != 2)
return false;

if (method.GetParameters()[0].ParameterType.FullName != typeof(NetworkWriter).FullName)
return false;

if (method.ReturnType != typeof(void))
return false;

return true;
}
private bool IsWriterMethod(MethodDefinition method)
{
if (method.Parameters.Count != 2)
return false;

if (method.Parameters[0].ParameterType.FullName != typeof(NetworkWriter).FullName)
return false;

if (!method.ReturnType.Is(typeof(void)))
return false;

return true;
}

private static bool IsReaderMethod(MethodInfo method)
{
if (method.GetParameters().Length != 1)
return false;

if (method.GetParameters()[0].ParameterType.FullName != typeof(NetworkReader).FullName)
return false;

if (method.ReturnType == typeof(void))
return false;

return true;
}
private bool IsReaderMethod(MethodDefinition method)
{
if (method.Parameters.Count != 1)
return false;

if (method.Parameters[0].ParameterType.FullName != typeof(NetworkReader).FullName)
return false;

if (method.ReturnType.Is(typeof(void)))
return false;

return true;
}

private void RegisterWriter(MethodInfo method)
{
Type dataType = method.GetParameters()[1].ParameterType;
writers.Register(module.ImportReference(dataType), module.ImportReference(method));
}
private void RegisterWriter(MethodDefinition method)
{
TypeReference dataType = method.Parameters[1].ParameterType;
writers.Register(module.ImportReference(dataType), module.ImportReference(method));
}


private void RegisterReader(MethodInfo method)
{
readers.Register(module.ImportReference(method.ReturnType), module.ImportReference(method));
}
private void RegisterReader(MethodDefinition method)
{
readers.Register(module.ImportReference(method.ReturnType), module.ImportReference(method));
}
}
}
53 changes: 53 additions & 0 deletions Assets/Tests/Runtime/Serialization/WeaverIgnoreTest.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
using System;
using Mirage.Serialization;
using NUnit.Framework;

namespace Mirage.Tests.Runtime.Serialization
{
public class WeaverIgnoreTest
{
[Test]
public void DoesNotUseCustomWriter()
{
// check method names
Assert.That(Writer<MyCustomType>.Write.Method.Name, Is.Not.EqualTo(new Action<NetworkWriter, MyCustomType>(MyCustomTypeExtension.WriteOnlyPartOfCustom).Method.Name));
Assert.That(Reader<MyCustomType>.Read.Method.Name, Is.Not.EqualTo(new Func<NetworkReader, MyCustomType>(MyCustomTypeExtension.ReadOnlyPartOfCustom).Method.Name));

// check writing and reading
var data = new MyCustomType
{
first = 10,
second = 20,
};
var writer = new NetworkWriter();
writer.Write(data);
var reader = new NetworkReader(writer.ToArraySegment());
MyCustomType copy = reader.Read<MyCustomType>();

// should have copied both fields,
// if it uses custom extension methods it will only write first
Assert.That(copy.first, Is.EqualTo(data.first));
Assert.That(copy.second, Is.EqualTo(data.second));
}

}
public struct MyCustomType
{
public int first;
public int second;
}
public static class MyCustomTypeExtension
{
[WeaverIgnore]
public static void WriteOnlyPartOfCustom(this NetworkWriter writer, MyCustomType value)
{
writer.WriteInt32(value.first);
}
[WeaverIgnore]

public static MyCustomType ReadOnlyPartOfCustom(this NetworkReader reader)
{
return new MyCustomType { first = reader.ReadInt32() };
}
}
}

0 comments on commit 9494500

Please sign in to comment.