Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: rpcs and syncvars now support generic components #574

Merged
merged 19 commits into from
Feb 21, 2021
Merged
Show file tree
Hide file tree
Changes from 12 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions Assets/Mirage/Weaver/Extensions.cs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
using System;
using System.Linq;
using Mono.Cecil;
using Mono.Collections.Generic;

namespace Mirage.Weaver
{
Expand Down Expand Up @@ -203,5 +204,21 @@ public static T GetField<T>(this CustomAttribute ca, string field, T defaultValu
return defaultValue;
}

public static FieldReference MakeHostGenericIfNeeded(this FieldReference fd)
{
if (fd.DeclaringType.HasGenericParameters)
{
return new FieldReference(fd.Name, fd.FieldType, fd.DeclaringType.Resolve().ConvertToGenericIfNeeded());
}

return fd;
}

public static FieldReference Duplicate(this FieldReference fr, TypeReference declaringType = null)
{
FieldReference newFr = new FieldReference(fr.Name, fr.FieldType, declaringType ?? fr.DeclaringType);

return newFr;
}
}
}
30 changes: 29 additions & 1 deletion Assets/Mirage/Weaver/MethodExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -52,5 +52,33 @@ public static SequencePoint GetSequencePoint(this MethodDefinition method, Instr
SequencePoint sequencePoint = method.DebugInformation.GetSequencePoint(instruction);
return sequencePoint;
}

/// <summary>
/// Duplicates a method reference.
/// </summary>
/// <param name="method"></param>
/// <param name="declaringType">A new declaring type. Set to null to be the same as the base method.</param>
/// <returns></returns>
public static MethodReference Duplicate(this MethodReference method, TypeReference declaringType = null)
{
MethodReference newMethod = new MethodReference(method.Name, method.ReturnType, declaringType ?? method.DeclaringType)
{
HasThis = method.HasThis,
ExplicitThis = method.ExplicitThis
};

if (method.HasParameters)
{
// Add back all the parameters.
for (int i = 0; i < method.Parameters.Count; i++)
{
newMethod.Parameters.Add(new ParameterDefinition(method.Parameters[i].Name,
method.Parameters[i].Attributes,
method.Parameters[i].ParameterType));
}
}

return newMethod;
}
}
}
}
6 changes: 3 additions & 3 deletions Assets/Mirage/Weaver/Processors/ClientRpcProcessor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ MethodDefinition GenerateSkeleton(MethodDefinition md, MethodDefinition userCode

// setup for reader
worker.Append(worker.Create(OpCodes.Ldarg_0));
worker.Append(worker.Create(OpCodes.Castclass, md.DeclaringType));
worker.Append(worker.Create(OpCodes.Castclass, md.DeclaringType.ConvertToGenericIfNeeded()));

// NetworkConnection parameter is only required for Client.Connection
Client target = clientRpcAttr.GetField("target", Client.Observers);
Expand Down Expand Up @@ -168,7 +168,7 @@ MethodDefinition GenerateStub(MethodDefinition md, CustomAttribute clientRpcAttr
else if (target == Client.Owner)
worker.Append(worker.Create(OpCodes.Ldnull));

worker.Append(worker.Create(OpCodes.Ldtoken, md.DeclaringType));
worker.Append(worker.Create(OpCodes.Ldtoken, md.DeclaringType.ConvertToGenericIfNeeded()));
// invokerClass
worker.Append(worker.Create(OpCodes.Call, () => Type.GetTypeFromHandle(default)));
worker.Append(worker.Create(OpCodes.Ldstr, rpcName));
Expand Down Expand Up @@ -237,7 +237,7 @@ public void RegisterClientRpcs(ILProcessor cctorWorker)
*/
void GenerateRegisterRemoteDelegate(ILProcessor worker, MethodDefinition func, string cmdName)
{
TypeDefinition netBehaviourSubclass = func.DeclaringType;
TypeReference netBehaviourSubclass = func.DeclaringType.ConvertToGenericIfNeeded();
worker.Append(worker.Create(OpCodes.Ldtoken, netBehaviourSubclass));
worker.Append(worker.Create(OpCodes.Call, () => Type.GetTypeFromHandle(default)));
worker.Append(worker.Create(OpCodes.Ldstr, cmdName));
Expand Down
8 changes: 4 additions & 4 deletions Assets/Mirage/Weaver/Processors/FieldReferenceComparator.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
using System.Collections.Generic;
using System.Collections.Generic;
using Mono.Cecil;

namespace Mirage.Weaver
Expand All @@ -7,9 +7,9 @@ internal class FieldReferenceComparator : IEqualityComparer<FieldReference>
{
public bool Equals(FieldReference x, FieldReference y)
{
return x.FullName == y.FullName;
return x.DeclaringType.FullName == y.DeclaringType.FullName && x.Name == y.Name;
}

public int GetHashCode(FieldReference obj) => obj.FullName.GetHashCode();
public int GetHashCode(FieldReference obj) => (obj.DeclaringType.FullName + "." + obj.Name).GetHashCode();
}
}
}
56 changes: 47 additions & 9 deletions Assets/Mirage/Weaver/Processors/PropertySiteProcessor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,19 @@ void ProcessInstructionSetterField(Instruction i, FieldReference opField)
// does it set a field that we replaced?
if (Setters.TryGetValue(opField, out MethodDefinition replacement))
{
//replace with property
i.OpCode = OpCodes.Call;
i.Operand = replacement;
if (opField.DeclaringType.IsGenericInstance || opField.DeclaringType.HasGenericParameters) // We're calling to a generic class
{
FieldReference newField = i.Operand as FieldReference;
GenericInstanceType genericType = (GenericInstanceType)newField.DeclaringType;
i.OpCode = OpCodes.Callvirt;
i.Operand = replacement.MakeHostInstanceGeneric(genericType);
}
else
{
//replace with property
i.OpCode = OpCodes.Call;
i.Operand = replacement;
}
}
}

Expand All @@ -46,31 +56,59 @@ void ProcessInstructionGetterField(Instruction i, FieldReference opField)
// does it set a field that we replaced?
if (Getters.TryGetValue(opField, out MethodDefinition replacement))
{
//replace with property
i.OpCode = OpCodes.Call;
i.Operand = replacement;
if (opField.DeclaringType.IsGenericInstance || opField.DeclaringType.HasGenericParameters) // We're calling to a generic class
{
FieldReference newField = i.Operand as FieldReference;
GenericInstanceType genericType = (GenericInstanceType)newField.DeclaringType;
i.OpCode = OpCodes.Callvirt;
i.Operand = replacement.MakeHostInstanceGeneric(genericType);
}
else
{
//replace with property
i.OpCode = OpCodes.Call;
i.Operand = replacement;
}
}
}

Instruction ProcessInstruction(MethodDefinition md, Instruction instr, SequencePoint sequencePoint)
{
if (instr.OpCode == OpCodes.Stfld && instr.Operand is FieldReference opFieldst)
{
FieldReference resolved = opFieldst.Resolve();
if (resolved == null)
{
resolved = opFieldst.Duplicate(opFieldst.DeclaringType.Resolve());
}

// this instruction sets the value of a field. cache the field reference.
ProcessInstructionSetterField(instr, opFieldst);
ProcessInstructionSetterField(instr, resolved);
}

if (instr.OpCode == OpCodes.Ldfld && instr.Operand is FieldReference opFieldld)
{
FieldReference resolved = opFieldld.Resolve();
if (resolved == null)
{
resolved = opFieldld.Duplicate(opFieldld.DeclaringType.Resolve());
}

// this instruction gets the value of a field. cache the field reference.
ProcessInstructionGetterField(instr, opFieldld);
ProcessInstructionGetterField(instr, resolved);
}

if (instr.OpCode == OpCodes.Ldflda && instr.Operand is FieldReference opFieldlda)
{
FieldReference resolved = opFieldlda.Resolve();
if (resolved == null)
{
resolved = opFieldlda.Duplicate(opFieldlda.DeclaringType.Resolve());
Hertzole marked this conversation as resolved.
Show resolved Hide resolved
}

// loading a field by reference, watch out for initobj instruction
// see https://github.com/vis2k/Mirror/issues/696
return ProcessInstructionLoadAddress(md, instr, opFieldlda);
return ProcessInstructionLoadAddress(md, instr, resolved);
}

return instr;
Expand Down
8 changes: 4 additions & 4 deletions Assets/Mirage/Weaver/Processors/RpcProcessor.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
using System.Linq;
using System.Linq;
using System.Reflection;
using Cysharp.Threading.Tasks;
using Mirage.RemoteCalls;
Expand Down Expand Up @@ -338,15 +338,15 @@ public void FixRemoteCallToBaseMethod(TypeDefinition type, MethodDefinition meth
calledMethod.Name == baseRemoteCallName)
{
TypeDefinition baseType = type.BaseType.Resolve();
MethodDefinition baseMethod = baseType.GetMethodInBaseType(callName);
MethodReference baseMethod = baseType.GetMethodInBaseType(callName);

if (baseMethod == null)
{
logger.Error($"Could not find base method for {callName}", method);
return;
}

if (!baseMethod.IsVirtual)
if (!baseMethod.Resolve().IsVirtual)
{
logger.Error($"Could not find base method that was virtual {callName}", method);
return;
Expand Down Expand Up @@ -375,4 +375,4 @@ static bool IsCallToMethod(Instruction instruction, out MethodDefinition calledM
}

}
}
}
6 changes: 3 additions & 3 deletions Assets/Mirage/Weaver/Processors/ServerRpcProcessor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ MethodDefinition GenerateStub(MethodDefinition md, CustomAttribute serverRpcAttr
// invoke internal send and return
// load 'base.' to call the SendServerRpc function with
worker.Append(worker.Create(OpCodes.Ldarg_0));
worker.Append(worker.Create(OpCodes.Ldtoken, md.DeclaringType));
worker.Append(worker.Create(OpCodes.Ldtoken, md.DeclaringType.ConvertToGenericIfNeeded()));
// invokerClass
worker.Append(worker.Create(OpCodes.Call, () => Type.GetTypeFromHandle(default)));
worker.Append(worker.Create(OpCodes.Ldstr, cmdName));
Expand Down Expand Up @@ -159,7 +159,7 @@ MethodDefinition GenerateSkeleton(MethodDefinition method, MethodDefinition user

// setup for reader
worker.Append(worker.Create(OpCodes.Ldarg_0));
worker.Append(worker.Create(OpCodes.Castclass, method.DeclaringType));
worker.Append(worker.Create(OpCodes.Castclass, method.DeclaringType.ConvertToGenericIfNeeded()));

if (!ReadArguments(method, worker, false))
return cmd;
Expand Down Expand Up @@ -214,7 +214,7 @@ void GenerateRegisterServerRpcDelegate(ILProcessor worker, ServerRpcMethod cmdRe
bool requireAuthority = cmdResult.requireAuthority;

TypeDefinition netBehaviourSubclass = skeleton.DeclaringType;
worker.Append(worker.Create(OpCodes.Ldtoken, netBehaviourSubclass));
worker.Append(worker.Create(OpCodes.Ldtoken, netBehaviourSubclass.ConvertToGenericIfNeeded()));
worker.Append(worker.Create(OpCodes.Call, () => Type.GetTypeFromHandle(default)));
worker.Append(worker.Create(OpCodes.Ldstr, cmdName));
worker.Append(worker.Create(OpCodes.Ldnull));
Expand Down
2 changes: 1 addition & 1 deletion Assets/Mirage/Weaver/Processors/SyncObjectProcessor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ public void ProcessSyncObjects(TypeDefinition td)
{
foreach (FieldDefinition fd in td.Fields)
{
if (fd.FieldType.IsGenericParameter) // Just ignore all generic objects.
if (fd.FieldType.IsGenericParameter || fd.ContainsGenericParameter) // Just ignore all generic objects.
{
continue;
}
Expand Down
24 changes: 16 additions & 8 deletions Assets/Mirage/Weaver/Processors/SyncVarProcessor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -203,15 +203,15 @@ private void StoreField(FieldDefinition fd, ParameterDefinition valueParam, ILPr
MethodReference setter = module.ImportReference(fd.FieldType.Resolve().GetMethod("set_Value"));

worker.Append(worker.Create(OpCodes.Ldarg_0));
worker.Append(worker.Create(OpCodes.Ldflda, fd));
worker.Append(worker.Create(OpCodes.Ldflda, fd.MakeHostGenericIfNeeded()));
worker.Append(worker.Create(OpCodes.Ldarg, valueParam));
worker.Append(worker.Create(OpCodes.Call, setter));
}
else
{
worker.Append(worker.Create(OpCodes.Ldarg_0));
worker.Append(worker.Create(OpCodes.Ldarg, valueParam));
worker.Append(worker.Create(OpCodes.Stfld, fd));
worker.Append(worker.Create(OpCodes.Stfld, fd.MakeHostGenericIfNeeded()));
}
}

Expand All @@ -227,7 +227,7 @@ private void LoadField(FieldDefinition fd, ILProcessor worker)
}
else
{
worker.Append(worker.Create(OpCodes.Ldfld, fd));
worker.Append(worker.Create(OpCodes.Ldfld, fd.MakeHostGenericIfNeeded()));
}
}

Expand Down Expand Up @@ -413,7 +413,15 @@ void WriteEndFunctionCall()
{
// only use Callvirt when not static
OpCode opcode = hookMethod.IsStatic ? OpCodes.Call : OpCodes.Callvirt;
worker.Append(worker.Create(opcode, hookMethod));
MethodReference hookMethodReference = hookMethod;
// If the delcaring type is generic we need to duplicate the method reference.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

as a rule of thumb, don't explain what this is doing. Explain why instead.

I can see that you are duplicating the method ref just by reading the code, the comment is redundant. What I can't figure out is why we are doing it.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I removed duplicate and changed it with some more straightforward code. I hope it becomes more clear.
Removed the duplicate in 86a8205 and added a comment in cb40e99.

// This will make sure it calls the correct instance of the class.
if (hookMethod.DeclaringType.HasGenericParameters)
{
hookMethodReference = hookMethod.Duplicate(hookMethod.DeclaringType.ConvertToGenericIfNeeded());
}

worker.Append(worker.Create(opcode, module.ImportReference(hookMethodReference)));
}
}

Expand Down Expand Up @@ -444,7 +452,7 @@ void GenerateSerialization(TypeDefinition netBehaviourSubclass)
// loc_0, this local variable is to determine if any variable was dirty
VariableDefinition dirtyLocal = serialize.AddLocal<bool>();

MethodDefinition baseSerialize = netBehaviourSubclass.BaseType.Resolve().GetMethodInBaseType(SerializeMethodName);
MethodReference baseSerialize = netBehaviourSubclass.BaseType.GetMethodInBaseType(SerializeMethodName);
if (baseSerialize != null)
{
// base
Expand Down Expand Up @@ -529,7 +537,7 @@ private void WriteVariable(ILProcessor worker, ParameterDefinition writerParamet
worker.Append(worker.Create(OpCodes.Ldarg, writerParameter));
// this
worker.Append(worker.Create(OpCodes.Ldarg_0));
worker.Append(worker.Create(OpCodes.Ldfld, syncVar));
worker.Append(worker.Create(OpCodes.Ldfld, syncVar.MakeHostGenericIfNeeded()));
MethodReference writeFunc = writers.GetWriteFunc(syncVar.FieldType, null);
if (writeFunc != null)
{
Expand Down Expand Up @@ -565,7 +573,7 @@ void GenerateDeSerialization(TypeDefinition netBehaviourSubclass)
serialize.Body.InitLocals = true;
VariableDefinition dirtyBitsLocal = serialize.AddLocal<long>();

MethodDefinition baseDeserialize = netBehaviourSubclass.BaseType.Resolve().GetMethodInBaseType(DeserializeMethodName);
MethodReference baseDeserialize = netBehaviourSubclass.BaseType.GetMethodInBaseType(DeserializeMethodName);
if (baseDeserialize != null)
{
// base
Expand Down Expand Up @@ -672,7 +680,7 @@ void DeserializeField(FieldDefinition syncVar, ILProcessor serWorker, MethodDefi
// reader.Read()
serWorker.Append(serWorker.Create(OpCodes.Call, readFunc));
// syncvar
serWorker.Append(serWorker.Create(OpCodes.Stfld, syncVar));
serWorker.Append(serWorker.Create(OpCodes.Stfld, syncVar.MakeHostGenericIfNeeded()));

if (hookMethod != null)
{
Expand Down
Loading