Skip to content

Commit

Permalink
fix(RPC): fixing rpc calls to base methods and overloads
Browse files Browse the repository at this point in the history
adding more test for overload and base class rpcs
  • Loading branch information
James-Frowen committed Feb 2, 2022
1 parent 4dcd604 commit 8bc165d
Show file tree
Hide file tree
Showing 14 changed files with 350 additions and 72 deletions.
8 changes: 6 additions & 2 deletions Assets/Mirage/Weaver/Extensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -234,8 +234,12 @@ public static CustomAttribute GetCustomAttribute<TAttribute>(this ICustomAttribu

public static bool HasCustomAttribute<TAttribute>(this ICustomAttributeProvider attributeProvider)
{
// Linq allocations don't matter in weaver
return attributeProvider.CustomAttributes.Any(attr => attr.AttributeType.Is<TAttribute>());
return HasCustomAttribute(attributeProvider, typeof(TAttribute));
}

public static bool HasCustomAttribute(this ICustomAttributeProvider attributeProvider, Type t)
{
return attributeProvider.CustomAttributes.Any(attr => attr.AttributeType.Is(t));
}

public static T GetField<T>(this CustomAttribute ca, string field, T defaultValue)
Expand Down
4 changes: 3 additions & 1 deletion Assets/Mirage/Weaver/Processors/ClientRpcProcessor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@ public ClientRpcProcessor(ModuleDefinition module, Readers readers, Writers writ
{
}

protected override Type AttributeType => typeof(ClientRpcAttribute);

/// <summary>
/// Generates a skeleton for an RPC
/// </summary>
Expand Down Expand Up @@ -145,7 +147,7 @@ MethodDefinition GenerateStub(MethodDefinition md, CustomAttribute clientRpcAttr
// write all the arguments that the user passed to the Rpc call
WriteArguments(worker, md, writer, paramSerializers, RemoteCallType.ClientRpc);

string rpcName = md.Name;
string rpcName = md.FullName;

RpcTarget target = clientRpcAttr.GetField("target", RpcTarget.Observers);
int channel = clientRpcAttr.GetField("channel", 0);
Expand Down
62 changes: 33 additions & 29 deletions Assets/Mirage/Weaver/Processors/RpcProcessor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,11 @@ protected RpcProcessor(ModuleDefinition module, Readers readers, Writers writers
this.logger = logger;
}

/// <summary>
/// Type of attribute for this rpc, eg [ServerRPC] or [ClientRPC}
/// </summary>
protected abstract Type AttributeType { get; }

// helper functions to check if the method has a NetworkPlayer parameter
protected static bool HasNetworkPlayerParameter(MethodDefinition md)
{
Expand Down Expand Up @@ -302,8 +307,7 @@ void ValidateParameter(MethodReference method, ParameterDefinition param, Remote
// this returns the newly created method with all the user provided code
public MethodDefinition SubstituteMethod(MethodDefinition method)
{
// append fullName hash to end to support overloads, but keep "md.Name" so it is human readable when debugging
string newName = $"UserCode_{method.Name}_{method.FullName.GetStableHashCode()}";
string newName = UserCodeMethodName(method);
MethodDefinition generatedMethod = method.DeclaringType.AddMethod(newName, method.Attributes, method.ReturnType);

// add parameters
Expand Down Expand Up @@ -343,35 +347,35 @@ void FixRemoteCallToBaseMethod(TypeDefinition type, MethodDefinition method, Met

foreach (Instruction instruction in generatedMethod.Body.Instructions)
{
// if call to base.CmdDoSomething within this.CallCmdDoSomething
if (IsCallToMethod(instruction, out MethodDefinition calledMethod) &&
calledMethod.Name == rpcName)
if (!IsCallToMethod(instruction, out MethodDefinition calledMethod))
continue;

// does method have same name? (NOTE: could be overload or non RPC at this point)
if (calledMethod.Name != rpcName)
continue;

// method (base or overload) is not an rpc, dont try to change it
if (!calledMethod.HasCustomAttribute(AttributeType))
continue;

string targetName = UserCodeMethodName(calledMethod);
// check this type and base types for methods
// if the calledMethod is an rpc, then it will have a UserCode_ method generated for it
MethodReference userCodeReplacement = type.GetMethodInBaseType(targetName);

if (userCodeReplacement == null)
{
throw new RpcException($"Could not find base method for {userCodeName}", generatedMethod);
}

if (!userCodeReplacement.Resolve().IsVirtual)
{
TypeDefinition baseType = type.BaseType.Resolve();
MethodReference baseMethod = baseType.GetMethodInBaseType(userCodeName);

// todo isn't calledMethod == baseMethod, improve this
if (calledMethod != baseMethod)
{
throw new RpcException("call was not to base method", method);
}

if (baseMethod == null)
{
logger.Error($"Could not find base method for {userCodeName}", generatedMethod);
return;
}

if (!baseMethod.Resolve().IsVirtual)
{
logger.Error($"Could not find base method that was virtual {userCodeName}", generatedMethod);
return;
}

instruction.Operand = generatedMethod.Module.ImportReference(baseMethod);

Weaver.DebugLog(type, $"Replacing call to '{calledMethod.FullName}' with '{baseMethod.FullName}' inside '{ generatedMethod.FullName}'");
throw new RpcException($"Could not find base method that was virtual {userCodeName}", generatedMethod);
}

instruction.Operand = generatedMethod.Module.ImportReference(userCodeReplacement);

Weaver.DebugLog(type, $"Replacing call to '{calledMethod.FullName}' with '{userCodeReplacement.FullName}' inside '{ generatedMethod.FullName}'");
}
}

Expand Down
4 changes: 3 additions & 1 deletion Assets/Mirage/Weaver/Processors/ServerRpcProcessor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ public ServerRpcProcessor(ModuleDefinition module, Readers readers, Writers writ
{
}

protected override Type AttributeType => typeof(ServerRpcAttribute);

/// <summary>
/// Replaces the user code with a stub.
/// Moves the original code to a new method
Expand Down Expand Up @@ -65,7 +67,7 @@ MethodDefinition GenerateStub(MethodDefinition md, CustomAttribute serverRpcAttr
// write all the arguments that the user passed to the Cmd call
WriteArguments(worker, md, writer, paramSerializers, RemoteCallType.ServerRpc);

string cmdName = md.Name;
string cmdName = md.FullName;

int channel = serverRpcAttr.GetField("channel", 0);
bool requireAuthority = serverRpcAttr.GetField("requireAuthority", true);
Expand Down
8 changes: 8 additions & 0 deletions Assets/Tests/Runtime/ClientServer/RpcTests.meta

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

70 changes: 64 additions & 6 deletions Assets/Tests/Runtime/ClientServer/RpcTests/CallToNonRpcBase.cs
Original file line number Diff line number Diff line change
@@ -1,22 +1,80 @@
using System;
using System.Collections;
using NSubstitute;
using UnityEngine.TestTools;

namespace Mirage.Tests.Runtime.ClientServer.RpcTests
{
class CallToNonRpcBase_behaviour : CallToNonRpcBase_base
public class CallToNonRpcBase_behaviour : CallToNonRpcBase_base
{
public event Action<int> clientRpcCalled;
public event Action<int> serverRpcCalled;

[ClientRpc]
public override void RpcThatIsTotallyValid(int a)
public override void MyRpc(int arg1)
{
clientRpcCalled?.Invoke(arg1);

// should call normal base method, no swapping to rpc (that doesn't exist)
base.RpcThatIsTotallyValid(a);
base.MyRpc(arg1);
}

[ServerRpc(requireAuthority = false)]
public void MyRpc(int arg1, INetworkPlayer sender)
{
serverRpcCalled?.Invoke(arg1);

// should call normal base method, no swapping to rpc (that doesn't exist)
base.MyRpc(arg1);
}
}

class CallToNonRpcBase_base : NetworkBehaviour
public class CallToNonRpcBase_base : NetworkBehaviour
{
public event Action<int> baseCalled;

// not an rpc, override is, so it should just be called normally on receiver
public virtual void RpcThatIsTotallyValid(int a) { }
public virtual void MyRpc(int arg1)
{
baseCalled?.Invoke(arg1);
}
}

public class CallToNonRpcBase : ClientServerSetup<MockComponent>
public class CallToNonRpcBase : ClientServerSetup<CallToNonRpcBase_behaviour>
{

[UnityTest]
public IEnumerator CanCallServerRpc()
{
const int num = 32;
Action<int> sub = Substitute.For<Action<int>>();
Action<int> baseSub = Substitute.For<Action<int>>();
serverComponent.serverRpcCalled += sub;
serverComponent.baseCalled += baseSub;
clientComponent.MyRpc(num, default(INetworkPlayer));

yield return null;
yield return null;

sub.Received(1).Invoke(num);
baseSub.Received(1).Invoke(num);
}

[UnityTest]
public IEnumerator CanCallClientRpc()
{
const int num = 32;
Action<int> sub = Substitute.For<Action<int>>();
Action<int> baseSub = Substitute.For<Action<int>>();
clientComponent.clientRpcCalled += sub;
clientComponent.baseCalled += baseSub;
serverComponent.MyRpc(num);

yield return null;
yield return null;

sub.Received(1).Invoke(num);
baseSub.Received(1).Invoke(num);
}
}
}
73 changes: 67 additions & 6 deletions Assets/Tests/Runtime/ClientServer/RpcTests/CallToNonRpcOverLoad.cs
Original file line number Diff line number Diff line change
@@ -1,20 +1,81 @@
using System;
using System.Collections;
using NSubstitute;
using UnityEngine.TestTools;

namespace Mirage.Tests.Runtime.ClientServer.RpcTests
{
class CallToNonRpcOverLoad_behaviour : NetworkBehaviour
// normal and rpc method in same class
public class CallToNonRpcOverLoad_behaviour : NetworkBehaviour
{
// normal and rpc method in same class
public event Action<int> clientRpcCalled;
public event Action<int> serverRpcCalled;
public event Action<int> overloadCalled;

[ClientRpc(target = RpcTarget.Player)]
public void RpcThatIsTotallyValid(INetworkPlayer player, int a)
public void MyRpc(INetworkPlayer player, int arg1)
{
clientRpcCalled?.Invoke(arg1);

// should call overload without any problem
RpcThatIsTotallyValid(a);
MyRpc(arg1);
}

public void RpcThatIsTotallyValid(int a) { }
[ServerRpc(requireAuthority = false)]
public void MyRpc(int arg1, INetworkPlayer sender)
{
serverRpcCalled?.Invoke(arg1);

// should call base user code, not generated rpc
MyRpc(arg1);
}

public void MyRpc(int arg1)
{
overloadCalled?.Invoke(arg1);
}
}

public class CallToNonRpcOverLoad : ClientServerSetup<MockComponent>
public class CallToNonRpcOverLoad : ClientServerSetup<CallToNonRpcOverLoad_behaviour>
{
[UnityTest]
public IEnumerator CanCallServerRpc()
{
const int num = 32;
Action<int> clientSub = Substitute.For<Action<int>>();
Action<int> serverSub = Substitute.For<Action<int>>();
Action<int> overloadSub = Substitute.For<Action<int>>();
serverComponent.clientRpcCalled += clientSub;
serverComponent.serverRpcCalled += serverSub;
serverComponent.overloadCalled += overloadSub;
clientComponent.MyRpc(num, default(INetworkPlayer));

yield return null;
yield return null;

clientSub.DidNotReceive().Invoke(num);
serverSub.Received(1).Invoke(num);
overloadSub.Received(1).Invoke(num);
}

[UnityTest]
public IEnumerator CanCallClientRpc()
{
const int num = 32;
Action<int> clientSub = Substitute.For<Action<int>>();
Action<int> serverSub = Substitute.For<Action<int>>();
Action<int> overloadSub = Substitute.For<Action<int>>();
clientComponent.clientRpcCalled += clientSub;
clientComponent.serverRpcCalled += serverSub;
clientComponent.overloadCalled += overloadSub;
serverComponent.MyRpc(serverPlayer, num);

yield return null;
yield return null;

clientSub.Received(1).Invoke(num);
serverSub.DidNotReceive().Invoke(num);
overloadSub.Received(1).Invoke(num);
}
}
}

0 comments on commit 8bc165d

Please sign in to comment.