Skip to content

Commit

Permalink
feat: adding WeaverSerializeCollection that can be added to generic w…
Browse files Browse the repository at this point in the history
…riters

allows generic writers for types like List<T> or Nullable<T> without those methods having to be listed in weaver code. This allows for users generic types too
  • Loading branch information
James-Frowen committed Nov 5, 2023
1 parent 3f4b4a8 commit 00d476b
Show file tree
Hide file tree
Showing 7 changed files with 103 additions and 55 deletions.
4 changes: 4 additions & 0 deletions Assets/Mirage/Runtime/Serialization/CollectionExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ public static void WriteBytesAndSizeSegment(this NetworkWriter writer, ArraySegm
writer.WriteBytesAndSize(buffer.Array, buffer.Offset, buffer.Count);
}

[WeaverSerializeCollection]
public static void WriteList<T>(this NetworkWriter writer, List<T> list)
{
WriteCountPlusOne(writer, list?.Count);
Expand All @@ -65,6 +66,7 @@ public static void WriteArray<T>(this NetworkWriter writer, T[] array)
writer.Write(array[i]);
}

[WeaverSerializeCollection]
public static void WriteArraySegment<T>(this NetworkWriter writer, ArraySegment<T> segment)
{
var array = segment.Array;
Expand Down Expand Up @@ -117,6 +119,7 @@ public static byte[] ReadBytes(this NetworkReader reader, int count)
return bytes;
}

[WeaverSerializeCollection]
public static List<T> ReadList<T>(this NetworkReader reader)
{
var hasValue = ReadCountPlusOne(reader, out var length);
Expand Down Expand Up @@ -149,6 +152,7 @@ public static T[] ReadArray<T>(this NetworkReader reader)
return result;
}

[WeaverSerializeCollection]
public static ArraySegment<T> ReadArraySegment<T>(this NetworkReader reader)
{
var array = reader.ReadArray<T>();
Expand Down
3 changes: 3 additions & 0 deletions Assets/Mirage/Runtime/Serialization/SystemTypesExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,7 @@ public static void WriteGuid(this NetworkWriter writer, Guid value)
writer.WriteBytes(data, 0, data.Length);
}

[WeaverSerializeCollection]
public static void WriteNullable<T>(this NetworkWriter writer, T? nullable) where T : struct
{
var hasValue = nullable.HasValue;
Expand Down Expand Up @@ -143,6 +144,8 @@ public static decimal ReadDecimalConverter(this NetworkReader reader)
return converter.decimalValue;
}
public static Guid ReadGuid(this NetworkReader reader) => new Guid(reader.ReadBytes(16));

[WeaverSerializeCollection]
public static T? ReadNullable<T>(this NetworkReader reader) where T : struct
{
var hasValue = reader.ReadBoolean();
Expand Down
10 changes: 10 additions & 0 deletions Assets/Mirage/Runtime/Serialization/WeaverAttributes.cs
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,16 @@ public sealed class WeaverIgnoreAttribute : Attribute { }
[AttributeUsage(AttributeTargets.Class | AttributeTargets.Struct | AttributeTargets.Enum | AttributeTargets.Interface)]
public sealed class WeaverWriteAsGenericAttribute : Attribute { }

/// <summary>
/// Tells weaver to use this method to write a generic type or collection
/// <para>Can also be used for other generic types like Nullable</para>
/// </summary>
[AttributeUsage(AttributeTargets.Method)]
public sealed class WeaverSerializeCollectionAttribute : Attribute
{
public WeaverSerializeCollectionAttribute() { }
}

/// <summary>
/// Tells weaver how many bits to sue for field
/// <para>Only works with integer fields (byte, int, ulong, enums etc)</para>
Expand Down
76 changes: 56 additions & 20 deletions Assets/Mirage/Weaver/Processors/ReaderWriterProcessor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -395,21 +395,40 @@ public void RegisterExtensionMethodsInType(Type type)
if (!IsStatic(type))
return;

var methods = type.GetMethods(BindingFlags.Static | BindingFlags.Public)
var extensionMethods = type.GetMethods(BindingFlags.Static | BindingFlags.Public)
.Where(IsExtension)
.Where(NotGeneric)
.Where(NotIgnored);

var methods = extensionMethods.Where(NotGeneric);
var collectionMethods = extensionMethods.Where(IsCollectionMethod);

foreach (var method in methods)
{
if (IsWriterMethod(method))
{
RegisterWriter(method);
var dataType = GetWriterDataType(method);
writers.Register(module.ImportReference(dataType), module.ImportReference(method));
}

if (IsReaderMethod(method))
{
var dataType = GetReaderDataType(method);
readers.Register(module.ImportReference(dataType), module.ImportReference(method));
}
}

foreach (var method in collectionMethods)
{
if (IsWriterMethod(method))
{
var dataType = GetWriterDataType(method);
writers.RegisterCollectionMethod(dataType.Resolve(), module.ImportReference(method));
}

if (IsReaderMethod(method))
{
RegisterReader(method);
var dataType = GetReaderDataType(method);
readers.RegisterCollectionMethod(dataType.Resolve(), module.ImportReference(method));
}
}
}
Expand All @@ -419,21 +438,40 @@ public void RegisterExtensionMethodsInType(TypeDefinition type)
if (!IsStatic(type))
return;

var methods = type.Methods
var extensionMethods = type.Methods
.Where(IsExtension)
.Where(NotGeneric)
.Where(NotIgnored);

var methods = extensionMethods.Where(NotGeneric);
var collectionMethods = extensionMethods.Where(IsCollectionMethod);

foreach (var method in methods)
{
if (IsWriterMethod(method))
{
RegisterWriter(method);
var dataType = GetWriterDataType(method);
writers.Register(module.ImportReference(dataType), module.ImportReference(method));
}

if (IsReaderMethod(method))
{
RegisterReader(method);
var dataType = GetReaderDataType(method);
readers.Register(module.ImportReference(dataType), module.ImportReference(method));
}
}

foreach (var method in collectionMethods)
{
if (IsWriterMethod(method))
{
var dataType = GetWriterDataType(method);
writers.RegisterCollectionMethod(dataType.Resolve(), module.ImportReference(method));
}

if (IsReaderMethod(method))
{
var dataType = GetReaderDataType(method);
readers.RegisterCollectionMethod(dataType.Resolve(), module.ImportReference(method));
}
}
}
Expand All @@ -449,11 +487,10 @@ public void RegisterExtensionMethodsInType(TypeDefinition type)
private static bool IsExtension(MethodDefinition method) => method.HasCustomAttribute<ExtensionAttribute>();
private static bool NotGeneric(MethodInfo method) => !method.IsGenericMethod;
private static bool NotGeneric(MethodDefinition method) => !method.IsGenericInstance && !method.HasGenericParameters;

/// <returns>true if method does not have <see cref="WeaverIgnoreAttribute"/></returns>
private static bool NotIgnored(MethodInfo method) => !Attribute.IsDefined(method, typeof(WeaverIgnoreAttribute));
/// <returns>true if method does not have <see cref="WeaverIgnoreAttribute"/></returns>
private static bool NotIgnored(MethodDefinition method) => !method.HasCustomAttribute<WeaverIgnoreAttribute>();
private static bool IsCollectionMethod(MethodInfo method) => Attribute.IsDefined(method, typeof(WeaverSerializeCollectionAttribute));
private static bool IsCollectionMethod(MethodDefinition method) => method.HasCustomAttribute<WeaverSerializeCollectionAttribute>();


private static bool IsWriterMethod(MethodInfo method)
Expand Down Expand Up @@ -510,31 +547,30 @@ private bool IsReaderMethod(MethodDefinition method)
return true;
}

private void RegisterWriter(MethodInfo method)
private TypeReference GetWriterDataType(MethodInfo method)
{
ReaderWriterProcessor.Log($"Found writer extension methods: {method.Name}");

var dataType = method.GetParameters()[1].ParameterType;
writers.Register(module.ImportReference(dataType), module.ImportReference(method));
return module.ImportReference(dataType);
}
private void RegisterWriter(MethodDefinition method)
private TypeReference GetWriterDataType(MethodDefinition method)
{
ReaderWriterProcessor.Log($"Found writer extension methods: {method.Name}");

var dataType = method.Parameters[1].ParameterType;
writers.Register(module.ImportReference(dataType), module.ImportReference(method));
return method.Parameters[1].ParameterType;
}


private void RegisterReader(MethodInfo method)
private TypeReference GetReaderDataType(MethodInfo method)
{
ReaderWriterProcessor.Log($"Found reader extension methods: {method.Name}");
readers.Register(module.ImportReference(method.ReturnType), module.ImportReference(method));
return module.ImportReference(method.ReturnType);
}
private void RegisterReader(MethodDefinition method)
private TypeReference GetReaderDataType(MethodDefinition method)
{
ReaderWriterProcessor.Log($"Found reader extension methods: {method.Name}");
readers.Register(module.ImportReference(method.ReturnType), module.ImportReference(method));
return method.ReturnType;
}
}
}
9 changes: 3 additions & 6 deletions Assets/Mirage/Weaver/Serialization/Readers.cs
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,6 @@ public class Readers : SerializeFunctionBase

protected override string FunctionTypeLog => "read function";
protected override Expression<Action> ArrayExpression => () => CollectionExtensions.ReadArray<byte>(default);
protected override Expression<Action> ListExpression => () => CollectionExtensions.ReadList<byte>(default);
protected override Expression<Action> SegmentExpression => () => CollectionExtensions.ReadArraySegment<byte>(default);
protected override Expression<Action> NullableExpression => () => SystemTypesExtensions.ReadNullable<byte>(default);

protected override MethodReference GetGenericFunction()
{
Expand Down Expand Up @@ -87,16 +84,16 @@ private ReadMethod GenerateReaderFunction(TypeReference variable)
return new ReadMethod(definition, readParameter, worker);
}

protected override MethodReference GenerateCollectionFunction(TypeReference typeReference, TypeReference elementType, Expression<Action> genericExpression)
protected override MethodReference GenerateCollectionFunction(TypeReference typeReference, TypeReference elementType, MethodReference collectionMethod)
{
// generate readers for the element
_ = GetFunction_Throws(elementType);

var readMethod = GenerateReaderFunction(typeReference);

var listReader = module.ImportReference(genericExpression);
var collectionReader = collectionMethod.GetElementMethod();

var methodRef = new GenericInstanceMethod(listReader.GetElementMethod());
var methodRef = new GenericInstanceMethod(collectionReader);
methodRef.GenericArguments.Add(elementType);

// generates
Expand Down
8 changes: 2 additions & 6 deletions Assets/Mirage/Weaver/Serialization/Writers.cs
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,6 @@ public class Writers : SerializeFunctionBase

protected override string FunctionTypeLog => "write function";
protected override Expression<Action> ArrayExpression => () => CollectionExtensions.WriteArray<byte>(default, default);
protected override Expression<Action> ListExpression => () => CollectionExtensions.WriteList<byte>(default, default);
protected override Expression<Action> SegmentExpression => () => CollectionExtensions.WriteArraySegment<byte>(default, default);
protected override Expression<Action> NullableExpression => () => SystemTypesExtensions.WriteNullable<byte>(default, default);

protected override MethodReference GetGenericFunction()
{
Expand Down Expand Up @@ -154,15 +151,14 @@ private void WriteAllFields(TypeReference type, WriteMethod writerFunc)
}
}

protected override MethodReference GenerateCollectionFunction(TypeReference typeReference, TypeReference elementType, Expression<Action> genericExpression)
protected override MethodReference GenerateCollectionFunction(TypeReference typeReference, TypeReference elementType, MethodReference collectionMethod)
{
// make sure element has a writer
// collection writers use the generic writer, so this will make sure one exists
_ = GetFunction_Throws(elementType);

var writerMethod = GenerateWriterFunc(typeReference);

var collectionWriter = module.ImportReference(genericExpression).GetElementMethod();
var collectionWriter = collectionMethod.GetElementMethod();

var methodRef = new GenericInstanceMethod(collectionWriter);
methodRef.GenericArguments.Add(elementType);
Expand Down
48 changes: 25 additions & 23 deletions Assets/Mirage/Weaver/SerializeFunctionBase.cs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ private static void Log(string msg)
}

protected readonly Dictionary<TypeReference, MethodReference> funcs = new Dictionary<TypeReference, MethodReference>(new TypeReferenceComparer());
protected readonly Dictionary<TypeDefinition, MethodReference> collectionMethod = new Dictionary<TypeDefinition, MethodReference>();
private readonly IWeaverLogger logger;
protected readonly ModuleDefinition module;

Expand Down Expand Up @@ -56,6 +57,21 @@ public void Register(TypeReference dataType, MethodReference methodReference)
//MarkAsGenerated(dataType); <--- broken in unity2021
}

public void RegisterCollectionMethod(TypeDefinition dataType, MethodReference methodReference)
{
if (collectionMethod.ContainsKey(dataType))
{
logger.Warning(
$"Registering a {FunctionTypeLog} for {dataType.FullName} when one already exists\n" +
$" old:{collectionMethod[dataType].FullName}\n" +
$" new:{methodReference.FullName}",
methodReference.Resolve());
}

Log($"Register Collection Method {FunctionTypeLog} for {dataType.FullName}, method:{methodReference.FullName}");
collectionMethod[dataType] = methodReference;
}

/// <summary>
/// Trys to get writer for type, returns null if not found
/// </summary>
Expand Down Expand Up @@ -147,35 +163,24 @@ private MethodReference GenerateFunction(TypeReference typeReference)
throw new SerializeFunctionException($"{typeReference.Name} is an unsupported type. Multidimensional arrays are not supported", typeReference);
}
var elementType = typeReference.GetElementType();
return GenerateCollectionFunction(typeReference, elementType, ArrayExpression);
var arrayMethod = module.ImportReference(ArrayExpression);
return GenerateCollectionFunction(typeReference, elementType, arrayMethod);
}

// check for collections
if (typeReference.Is(typeof(Nullable<>)))
{
var genericInstance = (GenericInstanceType)typeReference;
var elementType = genericInstance.GenericArguments[0];

return GenerateCollectionFunction(typeReference, elementType, NullableExpression);
}
if (typeReference.Is(typeof(ArraySegment<>)))
{
var genericInstance = (GenericInstanceType)typeReference;
var elementType = genericInstance.GenericArguments[0];
var typeDefinition = typeReference.Resolve();

return GenerateCollectionFunction(typeReference, elementType, SegmentExpression);
}
if (typeReference.Is(typeof(List<>)))
// check for collections
var isCollection = collectionMethod.TryGetValue(typeDefinition, out var collectionMethohd);
Console.WriteLine($"[CollectionMethod] {typeReference} isCollection={isCollection}");
if (isCollection)
{
var genericInstance = (GenericInstanceType)typeReference;
var elementType = genericInstance.GenericArguments[0];

return GenerateCollectionFunction(typeReference, elementType, ListExpression);
return GenerateCollectionFunction(typeReference, elementType, collectionMethohd);
}


// check for invalid types
var typeDefinition = typeReference.Resolve();
if (typeDefinition == null)
{
throw ThrowCantGenerate(typeReference);
Expand Down Expand Up @@ -258,12 +263,9 @@ private GenericInstanceMethod CreateGenericFunction(TypeReference argument)
protected abstract MethodReference GetNetworkBehaviourFunction(TypeReference typeReference);

protected abstract MethodReference GenerateEnumFunction(TypeReference typeReference);
protected abstract MethodReference GenerateCollectionFunction(TypeReference typeReference, TypeReference elementType, Expression<Action> genericExpression);
protected abstract MethodReference GenerateCollectionFunction(TypeReference typeReference, TypeReference elementType, MethodReference collectionMethod);

protected abstract Expression<Action> ArrayExpression { get; }
protected abstract Expression<Action> ListExpression { get; }
protected abstract Expression<Action> SegmentExpression { get; }
protected abstract Expression<Action> NullableExpression { get; }

protected abstract MethodReference GenerateClassOrStructFunction(TypeReference typeReference);
}
Expand Down

0 comments on commit 00d476b

Please sign in to comment.