Skip to content

Commit

Permalink
feat: support for generic network messages (#1040)
Browse files Browse the repository at this point in the history
Allowing writers to be creates for generic instance, eg MyType<int>

This does not work with the generic itself MyType<T> because of the way writers are registered
  • Loading branch information
James-Frowen committed Feb 17, 2022
1 parent 7cda9fb commit 2d8990d
Show file tree
Hide file tree
Showing 19 changed files with 406 additions and 74 deletions.
65 changes: 65 additions & 0 deletions Assets/Mirage/Weaver/Extensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -288,6 +288,71 @@ public static T GetField<T>(this CustomAttribute ca, string field, T defaultValu
return defaultValue;
}

/// <summary>
/// Imports a field and makes it a member of its orignal type.
/// <para>This is needed if orignal type is a generic instance, this will ensure that it stays a member of that instance, eg MyMessage{int}.Value</para>
/// </summary>
/// <param name="module"></param>
/// <param name="field"></param>
/// <param name="orignalType">Type that the field orignal belonged to, NOT the resolved type</param>
/// <returns></returns>
public static FieldReference ImportField(this ModuleDefinition module, FieldDefinition field, TypeReference orignalType)
{
if (orignalType.Module != module)
orignalType = module.ImportReference(orignalType);

TypeReference fieldType = module.ImportReference(field.FieldType);
return new FieldReference(field.Name, fieldType, orignalType);
}

/// <summary>
///
/// </summary>
/// <param name="field"></param>
/// <param name="orignalType">make sure orignalType is already imported</param>
/// <returns></returns>
public static TypeReference GetFieldTypeIncludingGeneric(this FieldDefinition field, TypeReference orignalType)
{
// if generic, then check if it has a type from orignalType
if (field.FieldType.IsGenericParameter && orignalType.IsGenericInstance)
{
if (FindGenericArgmentWithMatchingName(field.FieldType, orignalType, out TypeReference found))
return found;
}

// if not generic, or no matching found just return its type
return field.FieldType;
}

private static bool FindGenericArgmentWithMatchingName(TypeReference genericParameter, TypeReference orignalType, out TypeReference found)
{
// resolve to get GenericParameters
TypeDefinition resolved = orignalType.Resolve();

string typeName = genericParameter.Name;
for (int i = 0; i < resolved.GenericParameters.Count; i++)
{
GenericParameter param = resolved.GenericParameters[i];
if (param.Name == typeName)
{
var generic = (GenericInstanceType)orignalType;
found = generic.GenericArguments[i];
return true;
}
}

found = null;
return false;
}

/// <summary>
/// Makes a field part of generic defintion
/// <para>
/// NOTE: this only works when you need the type to be part of a generic defintion, NOT a generic instance, eg member of List{T} works, but List{int} doesn't
/// </para>
/// </summary>
/// <param name="fd"></param>
/// <returns></returns>
public static FieldReference MakeHostGenericIfNeeded(this FieldReference fd)
{
if (fd.DeclaringType.HasGenericParameters)
Expand Down
3 changes: 2 additions & 1 deletion Assets/Mirage/Weaver/Processors/SyncVarProcessor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -527,7 +527,8 @@ void WriteFromField(ILProcessor worker, ParameterDefinition writerParameter, Fou
{
if (!syncVar.HasProcessed) return;

syncVar.ValueSerializer.AppendWriteField(module, worker, writerParameter, null, syncVar.FieldDefinition);
FieldReference fieldRef = syncVar.FieldDefinition.MakeHostGenericIfNeeded();
syncVar.ValueSerializer.AppendWriteField(module, worker, writerParameter, null, fieldRef);
}


Expand Down
6 changes: 3 additions & 3 deletions Assets/Mirage/Weaver/Serialization/BitCountSerializer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -38,17 +38,17 @@ internal BitCountSerializer CopyWithZigZag()
return new BitCountSerializer(bitCount, typeConverter, true);
}

public override void AppendWriteField(ModuleDefinition module, ILProcessor worker, ParameterDefinition writerParameter, ParameterDefinition typeParameter, FieldDefinition fieldDefinition)
public override void AppendWriteField(ModuleDefinition module, ILProcessor worker, ParameterDefinition writerParameter, ParameterDefinition typeParameter, FieldReference fieldReference)
{
MethodReference writeWithBitCount = module.ImportReference(typeof(NetworkWriter).GetMethod(nameof(NetworkWriter.Write)));

worker.Append(LoadParamOrArg0(worker, writerParameter));
worker.Append(LoadParamOrArg0(worker, typeParameter));
worker.Append(worker.Create(OpCodes.Ldfld, ImportField(module, fieldDefinition)));
worker.Append(worker.Create(OpCodes.Ldfld, fieldReference));

if (useZigZag)
{
WriteZigZag(module, worker, fieldDefinition.FieldType);
WriteZigZag(module, worker, fieldReference.FieldType);
}
if (minValue.HasValue)
{
Expand Down
4 changes: 2 additions & 2 deletions Assets/Mirage/Weaver/Serialization/BlockSizeSerializer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,13 @@ public BlockSizeSerializer(int blockSize, OpCode? typeConverter)
this.typeConverter = typeConverter;
}

public override void AppendWriteField(ModuleDefinition module, ILProcessor worker, ParameterDefinition writerParameter, ParameterDefinition typeParameter, FieldDefinition fieldDefinition)
public override void AppendWriteField(ModuleDefinition module, ILProcessor worker, ParameterDefinition writerParameter, ParameterDefinition typeParameter, FieldReference fieldReference)
{
MethodReference writeWithBlockSize = module.ImportReference(() => VarIntBlocksPacker.Pack(default, default, default));

worker.Append(LoadParamOrArg0(worker, writerParameter));
worker.Append(LoadParamOrArg0(worker, typeParameter));
worker.Append(worker.Create(OpCodes.Ldfld, ImportField(module, fieldDefinition)));
worker.Append(worker.Create(OpCodes.Ldfld, fieldReference));
worker.Append(worker.Create(OpCodes.Conv_U8));
worker.Append(worker.Create(OpCodes.Ldc_I4, blockSize));
worker.Append(worker.Create(OpCodes.Call, writeWithBlockSize));
Expand Down
4 changes: 2 additions & 2 deletions Assets/Mirage/Weaver/Serialization/FunctionSerializer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,14 @@ public FunctionSerializer(MethodReference writeFunction, MethodReference readFun
this.readFunction = readFunction;
}

public override void AppendWriteField(ModuleDefinition module, ILProcessor worker, ParameterDefinition writerParameter, ParameterDefinition typeParameter, FieldDefinition fieldDefinition)
public override void AppendWriteField(ModuleDefinition module, ILProcessor worker, ParameterDefinition writerParameter, ParameterDefinition typeParameter, FieldReference fieldReference)
{
// make generic and import field

// if param is null then load arg0 instead
worker.Append(LoadParamOrArg0(worker, writerParameter));
worker.Append(LoadParamOrArg0(worker, typeParameter));
worker.Append(worker.Create(OpCodes.Ldfld, ImportField(module, fieldDefinition)));
worker.Append(worker.Create(OpCodes.Ldfld, fieldReference));
worker.Append(worker.Create(OpCodes.Call, writeFunction));

}
Expand Down
4 changes: 2 additions & 2 deletions Assets/Mirage/Weaver/Serialization/PackerSerializer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ public PackerSerializer(FieldDefinition packerField, LambdaExpression packMethod
IsIntType = isIntType;
}

public override void AppendWriteField(ModuleDefinition module, ILProcessor worker, ParameterDefinition writerParameter, ParameterDefinition typeParameter, FieldDefinition fieldDefinition)
public override void AppendWriteField(ModuleDefinition module, ILProcessor worker, ParameterDefinition writerParameter, ParameterDefinition typeParameter, FieldReference fieldReference)
{
// if PackerField is null it means there was an error earlier, so we dont need to do anything here
if (packerField == null) { return; }
Expand All @@ -32,7 +32,7 @@ public override void AppendWriteField(ModuleDefinition module, ILProcessor worke
worker.Append(worker.Create(OpCodes.Ldsfld, packerField));
worker.Append(LoadParamOrArg0(worker, writerParameter));
worker.Append(LoadParamOrArg0(worker, typeParameter));
worker.Append(worker.Create(OpCodes.Ldfld, ImportField(module, fieldDefinition)));
worker.Append(worker.Create(OpCodes.Ldfld, fieldReference));
worker.Append(worker.Create(OpCodes.Call, packMethod));
}

Expand Down
15 changes: 11 additions & 4 deletions Assets/Mirage/Weaver/Serialization/Readers.cs
Original file line number Diff line number Diff line change
Expand Up @@ -207,18 +207,25 @@ void ReadAllFields(TypeReference type, ReadMethod readMethod)
ILProcessor worker = readMethod.worker;
// create copy here because we might add static packer field
System.Collections.Generic.IEnumerable<FieldDefinition> fields = type.FindAllPublicFields();
foreach (FieldDefinition field in fields)
foreach (FieldDefinition fieldDef in fields)
{
// note:
// - fieldDef to get attributes
// - fieldType (made non-generic if possible) used to get type (eg if MyMessage<int> and field `T Value` then get writer for int)
// - fieldRef (imported) to emit IL codes
TypeReference fieldType = fieldDef.GetFieldTypeIncludingGeneric(type);
FieldReference fieldRef = module.ImportField(fieldDef, type);

ValueSerializer valueSerialize = ValueSerializerFinder.GetSerializer(module, fieldDef, fieldType, null, this);

// load this, write value, store value

// mismatched ldloca/ldloc for struct/class combinations is invalid IL, which causes crash at runtime
OpCode opcode = type.IsValueType ? OpCodes.Ldloca : OpCodes.Ldloc;
worker.Append(worker.Create(opcode, 0));

ValueSerializer valueSerialize = ValueSerializerFinder.GetSerializer(module, field, null, this);
valueSerialize.AppendRead(module, worker, readMethod.readParameter, field.FieldType);
valueSerialize.AppendRead(module, worker, readMethod.readParameter, fieldType);

FieldReference fieldRef = module.ImportReference(field);
worker.Append(worker.Create(OpCodes.Stfld, fieldRef));
}
}
Expand Down
7 changes: 1 addition & 6 deletions Assets/Mirage/Weaver/Serialization/ValueSerializer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -14,16 +14,11 @@ public abstract class ValueSerializer
/// </summary>
public abstract bool IsIntType { get; }

public abstract void AppendWriteField(ModuleDefinition module, ILProcessor worker, ParameterDefinition writerParameter, ParameterDefinition typeParameter, FieldDefinition fieldDefinition);
public abstract void AppendWriteField(ModuleDefinition module, ILProcessor worker, ParameterDefinition writerParameter, ParameterDefinition typeParameter, FieldReference fieldReference);
public abstract void AppendWriteParameter(ModuleDefinition module, ILProcessor worker, VariableDefinition writer, ParameterDefinition valueParameter);

public abstract void AppendRead(ModuleDefinition module, ILProcessor worker, ParameterDefinition readerParameter, TypeReference fieldType);

protected static FieldReference ImportField(ModuleDefinition module, FieldDefinition fieldDefinition)
{
return module.ImportReference(fieldDefinition.MakeHostGenericIfNeeded());
}

protected static Instruction LoadParamOrArg0(ILProcessor worker, ParameterDefinition parameter)
{
if (parameter == null)
Expand Down
51 changes: 32 additions & 19 deletions Assets/Mirage/Weaver/Serialization/ValueSerializerFinder.cs
Original file line number Diff line number Diff line change
Expand Up @@ -16,19 +16,21 @@ public static ValueSerializer GetSerializer(FoundSyncVar syncVar, Writers writer

/// <exception cref="ValueSerializerException">Throws when attribute is used incorrectly</exception>
/// <exception cref="SerializeFunctionException">Throws when can not generate read or write function</exception>
public static ValueSerializer GetSerializer(ModuleDefinition module, FieldDefinition field, Writers writers, Readers readers)
public static ValueSerializer GetSerializer(ModuleDefinition module, FieldReference field, TypeReference fieldType, Writers writers, Readers readers)
{
// note: we have to `Resolve()` DeclaringType first, because imported referencev `Module` will be equal.
TypeDefinition holder = field.DeclaringType.Resolve();
string name = field.Name;

// if field is in this module use its type for Packer field,
// else use the generated class
TypeDefinition holder = field.DeclaringType.Module == module
? field.DeclaringType
: module.GeneratedClass();

string name = field.DeclaringType.Module == module
? field.Name
: $"{field.DeclaringType.FullName}_{field.Name}";
if (holder.Module != module)
{
holder = module.GeneratedClass();
name = $"{field.DeclaringType.FullName}_{field.Name}";
}

return GetSerializer(module, holder, field, field.FieldType, name, writers, readers);
return GetSerializer(module, holder, field.Resolve(), fieldType, name, writers, readers);
}

/// <exception cref="ValueSerializerException">Throws when attribute is used incorrectly</exception>
Expand All @@ -39,6 +41,7 @@ public static ValueSerializer GetSerializer(MethodDefinition method, ParameterDe
return GetSerializer(method.DeclaringType.Module, method.DeclaringType, param, param.ParameterType, name, writers, readers);
}


/// <summary>
///
/// </summary>
Expand All @@ -61,31 +64,44 @@ public static ValueSerializer GetSerializer(ModuleDefinition module, TypeDefinit
// We need to check if other attributes are also used
// if user adds 2 attributes that dont work together weaver should then throw error
ValueSerializer valueSerializer = null;
bool HasIntAttribute() => valueSerializer != null && valueSerializer.IsIntType;

// attributeProvider is null for generic fields,
// but that is find because they wont have any of these attributes anyway
if (attributeProvider != null)
valueSerializer = GetUsingAttribute(module, holder, attributeProvider, fieldType, fieldName, valueSerializer);

if (valueSerializer == null)
{
valueSerializer = FindSerializeFunctions(writers, readers, fieldType);
}

return valueSerializer;
}

private static ValueSerializer GetUsingAttribute(ModuleDefinition module, TypeDefinition holder, ICustomAttributeProvider attributeProvider, TypeReference fieldType, string fieldName, ValueSerializer valueSerializer)
{
if (attributeProvider.HasCustomAttribute<BitCountAttribute>())
valueSerializer = BitCountFinder.GetSerializer(attributeProvider, fieldType);

if (attributeProvider.HasCustomAttribute<VarIntAttribute>())
{
if (HasIntAttribute())
if (HasIntAttribute(valueSerializer))
throw new VarIntException($"[VarInt] can't be used with [BitCount], [VarIntBlocks] or [BitCountFromRange]");

valueSerializer = new VarIntFinder().GetSerializer(module, holder, attributeProvider, fieldType, fieldName);
}

if (attributeProvider.HasCustomAttribute<VarIntBlocksAttribute>())
{
if (HasIntAttribute())
if (HasIntAttribute(valueSerializer))
throw new VarIntBlocksException($"[VarIntBlocks] can't be used with [BitCount], [VarInt] or [BitCountFromRange]");

valueSerializer = VarIntBlocksFinder.GetSerializer(attributeProvider, fieldType);
}

if (attributeProvider.HasCustomAttribute<BitCountFromRangeAttribute>())
{
if (HasIntAttribute())
if (HasIntAttribute(valueSerializer))
throw new BitCountFromRangeException($"[BitCountFromRange] can't be used with [BitCount], [VarInt] or [VarIntBlocks]");

valueSerializer = BitCountFromRangeFinder.GetSerializer(attributeProvider, fieldType);
Expand All @@ -104,15 +120,12 @@ public static ValueSerializer GetSerializer(ModuleDefinition module, TypeDefinit

if (attributeProvider.HasCustomAttribute<QuaternionPackAttribute>())
valueSerializer = new QuaternionFinder().GetSerializer(module, holder, attributeProvider, fieldType, fieldName);

if (valueSerializer == null)
{
valueSerializer = FindSerializeFunctions(writers, readers, fieldType);
}

return valueSerializer;
}

static bool HasIntAttribute(ValueSerializer valueSerializer) => valueSerializer != null && valueSerializer.IsIntType;


/// <exception cref="SerializeFunctionException">Throws when can not generate read or write function</exception>
static ValueSerializer FindSerializeFunctions(Writers writers, Readers readers, TypeReference fieldType)
{
Expand Down
14 changes: 11 additions & 3 deletions Assets/Mirage/Weaver/Serialization/Writers.cs
Original file line number Diff line number Diff line change
Expand Up @@ -136,10 +136,18 @@ void WriteAllFields(TypeReference type, WriteMethod writerFunc)
{
// create copy here because we might add static packer field
System.Collections.Generic.IEnumerable<FieldDefinition> fields = type.FindAllPublicFields();
foreach (FieldDefinition field in fields)
foreach (FieldDefinition fieldDef in fields)
{
ValueSerializer valueSerialize = ValueSerializerFinder.GetSerializer(module, field, this, null);
valueSerialize.AppendWriteField(module, writerFunc.worker, writerFunc.writerParameter, writerFunc.typeParameter, field);
// note:
// - fieldDef to get attributes
// - fieldType (made non-generic if possible) used to get type (eg if MyMessage<int> and field `T Value` then get writer for int)
// - fieldRef (imported) to emit IL codes

TypeReference fieldType = fieldDef.GetFieldTypeIncludingGeneric(type);
FieldReference fieldRef = module.ImportField(fieldDef, type);

ValueSerializer valueSerialize = ValueSerializerFinder.GetSerializer(module, fieldDef, fieldType, this, null);
valueSerialize.AppendWriteField(module, writerFunc.worker, writerFunc.writerParameter, writerFunc.typeParameter, fieldRef);
}
}

Expand Down
4 changes: 2 additions & 2 deletions Assets/Mirage/Weaver/SerializeFunctionBase.cs
Original file line number Diff line number Diff line change
Expand Up @@ -172,8 +172,8 @@ private MethodReference GenerateFunction(TypeReference typeReference)
throw ThrowCantGenerate(typeReference);
}


if (typeDefinition.HasGenericParameters)
// if it is genericInstance, then we can generate writer for it
if (!typeReference.IsGenericInstance && typeDefinition.HasGenericParameters)
{
throw ThrowCantGenerate(typeReference, "generic type");
}
Expand Down

0 comments on commit 2d8990d

Please sign in to comment.