Skip to content

Commit

Permalink
feat: support generic network behaviors (#574)
Browse files Browse the repository at this point in the history
It is now possible to use generic NetworkBehaviors.  The following example now works fine:

```cs
public class Pepe<T> : NetworkBehavior
{
    [SyncVar]
    public int someVariable;

    [ClientRpc]
    public void SomeRpc(string something) {
    }
}

class MyBehavior: Pepe<int> {}
```

Note that as of this PR,  the synccvar or rpc cannot be generic.
  • Loading branch information
Hertzole committed Feb 21, 2021
1 parent 4cbf2b4 commit 715642c
Show file tree
Hide file tree
Showing 20 changed files with 1,511 additions and 44 deletions.
10 changes: 10 additions & 0 deletions Assets/Mirage/Weaver/Extensions.cs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
using System;
using System.Linq;
using Mono.Cecil;
using Mono.Collections.Generic;

namespace Mirage.Weaver
{
Expand Down Expand Up @@ -203,5 +204,14 @@ public static T GetField<T>(this CustomAttribute ca, string field, T defaultValu
return defaultValue;
}

public static FieldReference MakeHostGenericIfNeeded(this FieldReference fd)
{
if (fd.DeclaringType.HasGenericParameters)
{
return new FieldReference(fd.Name, fd.FieldType, fd.DeclaringType.Resolve().ConvertToGenericIfNeeded());
}

return fd;
}
}
}
2 changes: 1 addition & 1 deletion Assets/Mirage/Weaver/MethodExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -53,4 +53,4 @@ public static SequencePoint GetSequencePoint(this MethodDefinition method, Instr
return sequencePoint;
}
}
}
}
9 changes: 3 additions & 6 deletions Assets/Mirage/Weaver/Processors/ClientRpcProcessor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -51,18 +51,15 @@ MethodDefinition GenerateSkeleton(MethodDefinition md, MethodDefinition userCode
{
MethodDefinition rpc = md.DeclaringType.AddMethod(
SkeletonPrefix + md.Name,
MethodAttributes.Family | MethodAttributes.Static | MethodAttributes.HideBySig);
MethodAttributes.Family | MethodAttributes.HideBySig);

_ = rpc.AddParam<NetworkBehaviour>("obj");
_ = rpc.AddParam<NetworkReader>("reader");
_ = rpc.AddParam<INetworkConnection>("senderConnection");
_ = rpc.AddParam<int>("replyId");

ILProcessor worker = rpc.Body.GetILProcessor();

// setup for reader
worker.Append(worker.Create(OpCodes.Ldarg_0));
worker.Append(worker.Create(OpCodes.Castclass, md.DeclaringType));

// NetworkConnection parameter is only required for Client.Connection
Client target = clientRpcAttr.GetField("target", Client.Observers);
Expand Down Expand Up @@ -168,7 +165,7 @@ MethodDefinition GenerateStub(MethodDefinition md, CustomAttribute clientRpcAttr
else if (target == Client.Owner)
worker.Append(worker.Create(OpCodes.Ldnull));

worker.Append(worker.Create(OpCodes.Ldtoken, md.DeclaringType));
worker.Append(worker.Create(OpCodes.Ldtoken, md.DeclaringType.ConvertToGenericIfNeeded()));
// invokerClass
worker.Append(worker.Create(OpCodes.Call, () => Type.GetTypeFromHandle(default)));
worker.Append(worker.Create(OpCodes.Ldstr, rpcName));
Expand Down Expand Up @@ -237,7 +234,7 @@ public void RegisterClientRpcs(ILProcessor cctorWorker)
*/
void GenerateRegisterRemoteDelegate(ILProcessor worker, MethodDefinition func, string cmdName)
{
TypeDefinition netBehaviourSubclass = func.DeclaringType;
TypeReference netBehaviourSubclass = func.DeclaringType.ConvertToGenericIfNeeded();
worker.Append(worker.Create(OpCodes.Ldtoken, netBehaviourSubclass));
worker.Append(worker.Create(OpCodes.Call, () => Type.GetTypeFromHandle(default)));
worker.Append(worker.Create(OpCodes.Ldstr, cmdName));
Expand Down
8 changes: 4 additions & 4 deletions Assets/Mirage/Weaver/Processors/FieldReferenceComparator.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
using System.Collections.Generic;
using System.Collections.Generic;
using Mono.Cecil;

namespace Mirage.Weaver
Expand All @@ -7,9 +7,9 @@ internal class FieldReferenceComparator : IEqualityComparer<FieldReference>
{
public bool Equals(FieldReference x, FieldReference y)
{
return x.FullName == y.FullName;
return x.DeclaringType.FullName == y.DeclaringType.FullName && x.Name == y.Name;
}

public int GetHashCode(FieldReference obj) => obj.FullName.GetHashCode();
public int GetHashCode(FieldReference obj) => (obj.DeclaringType.FullName + "." + obj.Name).GetHashCode();
}
}
}
56 changes: 47 additions & 9 deletions Assets/Mirage/Weaver/Processors/PropertySiteProcessor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,19 @@ void ProcessInstructionSetterField(Instruction i, FieldReference opField)
// does it set a field that we replaced?
if (Setters.TryGetValue(opField, out MethodDefinition replacement))
{
//replace with property
i.OpCode = OpCodes.Call;
i.Operand = replacement;
if (opField.DeclaringType.IsGenericInstance || opField.DeclaringType.HasGenericParameters) // We're calling to a generic class
{
FieldReference newField = i.Operand as FieldReference;
GenericInstanceType genericType = (GenericInstanceType)newField.DeclaringType;
i.OpCode = OpCodes.Callvirt;
i.Operand = replacement.MakeHostInstanceGeneric(genericType);
}
else
{
//replace with property
i.OpCode = OpCodes.Call;
i.Operand = replacement;
}
}
}

Expand All @@ -46,31 +56,59 @@ void ProcessInstructionGetterField(Instruction i, FieldReference opField)
// does it set a field that we replaced?
if (Getters.TryGetValue(opField, out MethodDefinition replacement))
{
//replace with property
i.OpCode = OpCodes.Call;
i.Operand = replacement;
if (opField.DeclaringType.IsGenericInstance || opField.DeclaringType.HasGenericParameters) // We're calling to a generic class
{
FieldReference newField = i.Operand as FieldReference;
GenericInstanceType genericType = (GenericInstanceType)newField.DeclaringType;
i.OpCode = OpCodes.Callvirt;
i.Operand = replacement.MakeHostInstanceGeneric(genericType);
}
else
{
//replace with property
i.OpCode = OpCodes.Call;
i.Operand = replacement;
}
}
}

Instruction ProcessInstruction(MethodDefinition md, Instruction instr, SequencePoint sequencePoint)
{
if (instr.OpCode == OpCodes.Stfld && instr.Operand is FieldReference opFieldst)
{
FieldReference resolved = opFieldst.Resolve();
if (resolved == null)
{
resolved = opFieldst.DeclaringType.Resolve().GetField(opFieldst.Name);
}

// this instruction sets the value of a field. cache the field reference.
ProcessInstructionSetterField(instr, opFieldst);
ProcessInstructionSetterField(instr, resolved);
}

if (instr.OpCode == OpCodes.Ldfld && instr.Operand is FieldReference opFieldld)
{
FieldReference resolved = opFieldld.Resolve();
if (resolved == null)
{
resolved = opFieldld.DeclaringType.Resolve().GetField(opFieldld.Name);
}

// this instruction gets the value of a field. cache the field reference.
ProcessInstructionGetterField(instr, opFieldld);
ProcessInstructionGetterField(instr, resolved);
}

if (instr.OpCode == OpCodes.Ldflda && instr.Operand is FieldReference opFieldlda)
{
FieldReference resolved = opFieldlda.Resolve();
if (resolved == null)
{
resolved = opFieldlda.DeclaringType.Resolve().GetField(opFieldlda.Name);
}

// loading a field by reference, watch out for initobj instruction
// see https://github.com/vis2k/Mirror/issues/696
return ProcessInstructionLoadAddress(md, instr, opFieldlda);
return ProcessInstructionLoadAddress(md, instr, resolved);
}

return instr;
Expand Down
8 changes: 4 additions & 4 deletions Assets/Mirage/Weaver/Processors/RpcProcessor.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
using System.Linq;
using System.Linq;
using System.Reflection;
using Cysharp.Threading.Tasks;
using Mirage.RemoteCalls;
Expand Down Expand Up @@ -338,15 +338,15 @@ public void FixRemoteCallToBaseMethod(TypeDefinition type, MethodDefinition meth
calledMethod.Name == baseRemoteCallName)
{
TypeDefinition baseType = type.BaseType.Resolve();
MethodDefinition baseMethod = baseType.GetMethodInBaseType(callName);
MethodReference baseMethod = baseType.GetMethodInBaseType(callName);

if (baseMethod == null)
{
logger.Error($"Could not find base method for {callName}", method);
return;
}

if (!baseMethod.IsVirtual)
if (!baseMethod.Resolve().IsVirtual)
{
logger.Error($"Could not find base method that was virtual {callName}", method);
return;
Expand Down Expand Up @@ -375,4 +375,4 @@ static bool IsCallToMethod(Instruction instruction, out MethodDefinition calledM
}

}
}
}
8 changes: 3 additions & 5 deletions Assets/Mirage/Weaver/Processors/ServerRpcProcessor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ MethodDefinition GenerateStub(MethodDefinition md, CustomAttribute serverRpcAttr
// invoke internal send and return
// load 'base.' to call the SendServerRpc function with
worker.Append(worker.Create(OpCodes.Ldarg_0));
worker.Append(worker.Create(OpCodes.Ldtoken, md.DeclaringType));
worker.Append(worker.Create(OpCodes.Ldtoken, md.DeclaringType.ConvertToGenericIfNeeded()));
// invokerClass
worker.Append(worker.Create(OpCodes.Call, () => Type.GetTypeFromHandle(default)));
worker.Append(worker.Create(OpCodes.Ldstr, cmdName));
Expand Down Expand Up @@ -146,10 +146,9 @@ private void CallSendServerRpc(MethodDefinition md, ILProcessor worker)
MethodDefinition GenerateSkeleton(MethodDefinition method, MethodDefinition userCodeFunc)
{
MethodDefinition cmd = method.DeclaringType.AddMethod(SkeletonPrefix + method.Name,
MethodAttributes.Family | MethodAttributes.Static | MethodAttributes.HideBySig,
MethodAttributes.Family | MethodAttributes.HideBySig,
userCodeFunc.ReturnType);

_ = cmd.AddParam<NetworkBehaviour>("obj");
_ = cmd.AddParam<NetworkReader>("reader");
_ = cmd.AddParam<INetworkConnection>("senderConnection");
_ = cmd.AddParam<int>("replyId");
Expand All @@ -159,7 +158,6 @@ MethodDefinition GenerateSkeleton(MethodDefinition method, MethodDefinition user

// setup for reader
worker.Append(worker.Create(OpCodes.Ldarg_0));
worker.Append(worker.Create(OpCodes.Castclass, method.DeclaringType));

if (!ReadArguments(method, worker, false))
return cmd;
Expand Down Expand Up @@ -214,7 +212,7 @@ void GenerateRegisterServerRpcDelegate(ILProcessor worker, ServerRpcMethod cmdRe
bool requireAuthority = cmdResult.requireAuthority;

TypeDefinition netBehaviourSubclass = skeleton.DeclaringType;
worker.Append(worker.Create(OpCodes.Ldtoken, netBehaviourSubclass));
worker.Append(worker.Create(OpCodes.Ldtoken, netBehaviourSubclass.ConvertToGenericIfNeeded()));
worker.Append(worker.Create(OpCodes.Call, () => Type.GetTypeFromHandle(default)));
worker.Append(worker.Create(OpCodes.Ldstr, cmdName));
worker.Append(worker.Create(OpCodes.Ldnull));
Expand Down
2 changes: 1 addition & 1 deletion Assets/Mirage/Weaver/Processors/SyncObjectProcessor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ public void ProcessSyncObjects(TypeDefinition td)
{
foreach (FieldDefinition fd in td.Fields)
{
if (fd.FieldType.IsGenericParameter) // Just ignore all generic objects.
if (fd.FieldType.IsGenericParameter || fd.ContainsGenericParameter) // Just ignore all generic objects.
{
continue;
}
Expand Down
27 changes: 18 additions & 9 deletions Assets/Mirage/Weaver/Processors/SyncVarProcessor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -203,15 +203,15 @@ private void StoreField(FieldDefinition fd, ParameterDefinition valueParam, ILPr
MethodReference setter = module.ImportReference(fd.FieldType.Resolve().GetMethod("set_Value"));

worker.Append(worker.Create(OpCodes.Ldarg_0));
worker.Append(worker.Create(OpCodes.Ldflda, fd));
worker.Append(worker.Create(OpCodes.Ldflda, fd.MakeHostGenericIfNeeded()));
worker.Append(worker.Create(OpCodes.Ldarg, valueParam));
worker.Append(worker.Create(OpCodes.Call, setter));
}
else
{
worker.Append(worker.Create(OpCodes.Ldarg_0));
worker.Append(worker.Create(OpCodes.Ldarg, valueParam));
worker.Append(worker.Create(OpCodes.Stfld, fd));
worker.Append(worker.Create(OpCodes.Stfld, fd.MakeHostGenericIfNeeded()));
}
}

Expand All @@ -221,7 +221,7 @@ private void LoadField(FieldDefinition fd, TypeReference originalType, ILProces

if (IsWrapped(fd.FieldType))
{
worker.Append(worker.Create(OpCodes.Ldflda, fd));
worker.Append(worker.Create(OpCodes.Ldflda, fd.MakeHostGenericIfNeeded()));
MethodReference getter = module.ImportReference(fd.FieldType.Resolve().GetMethod("get_Value"));
worker.Append(worker.Create(OpCodes.Call, getter));

Expand All @@ -236,7 +236,7 @@ private void LoadField(FieldDefinition fd, TypeReference originalType, ILProces
}
else
{
worker.Append(worker.Create(OpCodes.Ldfld, fd));
worker.Append(worker.Create(OpCodes.Ldfld, fd.MakeHostGenericIfNeeded()));
}
}

Expand Down Expand Up @@ -422,7 +422,16 @@ void WriteEndFunctionCall()
{
// only use Callvirt when not static
OpCode opcode = hookMethod.IsStatic ? OpCodes.Call : OpCodes.Callvirt;
worker.Append(worker.Create(opcode, hookMethod));
MethodReference hookMethodReference = hookMethod;

if (hookMethodReference.DeclaringType.HasGenericParameters)
{
// we need to get the Type<T>.HookMethod so convert it to a generic<T>.
var genericType = (GenericInstanceType)hookMethod.DeclaringType.ConvertToGenericIfNeeded();
hookMethodReference = hookMethod.MakeHostInstanceGeneric(genericType);
}

worker.Append(worker.Create(opcode, module.ImportReference(hookMethodReference)));
}
}

Expand Down Expand Up @@ -453,7 +462,7 @@ void GenerateSerialization(TypeDefinition netBehaviourSubclass)
// loc_0, this local variable is to determine if any variable was dirty
VariableDefinition dirtyLocal = serialize.AddLocal<bool>();

MethodDefinition baseSerialize = netBehaviourSubclass.BaseType.Resolve().GetMethodInBaseType(SerializeMethodName);
MethodReference baseSerialize = netBehaviourSubclass.BaseType.GetMethodInBaseType(SerializeMethodName);
if (baseSerialize != null)
{
// base
Expand Down Expand Up @@ -538,7 +547,7 @@ private void WriteVariable(ILProcessor worker, ParameterDefinition writerParamet
worker.Append(worker.Create(OpCodes.Ldarg, writerParameter));
// this
worker.Append(worker.Create(OpCodes.Ldarg_0));
worker.Append(worker.Create(OpCodes.Ldfld, syncVar));
worker.Append(worker.Create(OpCodes.Ldfld, syncVar.MakeHostGenericIfNeeded()));
MethodReference writeFunc = writers.GetWriteFunc(syncVar.FieldType, null);
if (writeFunc != null)
{
Expand Down Expand Up @@ -574,7 +583,7 @@ void GenerateDeSerialization(TypeDefinition netBehaviourSubclass)
serialize.Body.InitLocals = true;
VariableDefinition dirtyBitsLocal = serialize.AddLocal<long>();

MethodDefinition baseDeserialize = netBehaviourSubclass.BaseType.Resolve().GetMethodInBaseType(DeserializeMethodName);
MethodReference baseDeserialize = netBehaviourSubclass.BaseType.GetMethodInBaseType(DeserializeMethodName);
if (baseDeserialize != null)
{
// base
Expand Down Expand Up @@ -681,7 +690,7 @@ void DeserializeField(FieldDefinition syncVar, ILProcessor serWorker, MethodDefi
// reader.Read()
serWorker.Append(serWorker.Create(OpCodes.Call, readFunc));
// syncvar
serWorker.Append(serWorker.Create(OpCodes.Stfld, syncVar));
serWorker.Append(serWorker.Create(OpCodes.Stfld, syncVar.MakeHostGenericIfNeeded()));

if (hookMethod != null)
{
Expand Down

0 comments on commit 715642c

Please sign in to comment.