Skip to content

Commit

Permalink
Added more tests
Browse files Browse the repository at this point in the history
  • Loading branch information
Shane32 committed Mar 23, 2021
1 parent 7ea1d6b commit c3cc139
Show file tree
Hide file tree
Showing 4 changed files with 181 additions and 53 deletions.
82 changes: 30 additions & 52 deletions src/GraphQL.DI/DIObjectGraphType.cs
Original file line number Diff line number Diff line change
Expand Up @@ -248,29 +248,10 @@ protected virtual bool GetNullability(MethodInfo method)
if (method.ReturnType.IsValueType)
return method.ReturnType.IsConstructedGenericType && method.ReturnType.GetGenericTypeDefinition() == typeof(Nullable<>);

Nullability nullable = Nullability.Unknown;

// check the parent type first to see if there's a nullable context attribute set for it
var parentType = method.DeclaringType;
var attribute = parentType.CustomAttributes.FirstOrDefault(x =>
x.AttributeType.FullName == "System.Runtime.CompilerServices.NullableContextAttribute" &&
x.ConstructorArguments.Count == 1 &&
x.ConstructorArguments[0].ArgumentType == typeof(byte));
if (attribute != null) {
nullable = (Nullability)(byte)attribute.ConstructorArguments[0].Value;
}

// now check the method to see if there's a nullable context attribute set for it
attribute = method.CustomAttributes.FirstOrDefault(x =>
x.AttributeType.FullName == "System.Runtime.CompilerServices.NullableContextAttribute" &&
x.ConstructorArguments.Count == 1 &&
x.ConstructorArguments[0].ArgumentType == typeof(byte));
if (attribute != null) {
nullable = (Nullability)(byte)attribute.ConstructorArguments[0].Value;
}
Nullability nullable = GetMethodDefaultNullability(method);

// now check the return type to see if there's a nullable attribute for it
attribute = method.ReturnParameter.CustomAttributes.FirstOrDefault(x =>
var attribute = method.ReturnParameter.CustomAttributes.FirstOrDefault(x =>
x.AttributeType.FullName == "System.Runtime.CompilerServices.NullableAttribute" &&
x.ConstructorArguments.Count == 1 &&
(x.ConstructorArguments[0].ArgumentType == typeof(byte) ||
Expand Down Expand Up @@ -312,21 +293,9 @@ private enum Nullability : byte
Nullable = 2,
}

/// <summary>
/// Returns a boolean indicating if the parameter value is nullable
/// </summary>
protected virtual bool GetNullability(MethodInfo method, ParameterInfo parameter)
private Nullability GetMethodDefaultNullability(MethodInfo method)
{
if (parameter.GetCustomAttribute<OptionalAttribute>() != null)
return true;
if (parameter.GetCustomAttribute<RequiredAttribute>() != null)
return false;
if (parameter.GetCustomAttribute<System.ComponentModel.DataAnnotations.RequiredAttribute>() != null)
return false;
if (parameter.ParameterType.IsValueType)
return parameter.ParameterType.IsConstructedGenericType && parameter.ParameterType.GetGenericTypeDefinition() == typeof(Nullable<>);

Nullability nullable = Nullability.Unknown;
var nullable = Nullability.Unknown;

// check the parent type first to see if there's a nullable context attribute set for it
var parentType = method.DeclaringType;
Expand All @@ -347,8 +316,29 @@ protected virtual bool GetNullability(MethodInfo method, ParameterInfo parameter
nullable = (Nullability)(byte)attribute.ConstructorArguments[0].Value;
}

return nullable;
}

/// <summary>
/// Returns a boolean indicating if the parameter value is nullable
/// </summary>
protected virtual bool GetNullability(MethodInfo method, ParameterInfo parameter)
{
if (parameter.GetCustomAttribute<OptionalAttribute>() != null)
return true;
if (parameter.GetCustomAttribute<RequiredAttribute>() != null)
return false;
if (parameter.GetCustomAttribute<System.ComponentModel.DataAnnotations.RequiredAttribute>() != null)
return false;
if (parameter.IsOptional)
return true;
if (parameter.ParameterType.IsValueType)
return parameter.ParameterType.IsConstructedGenericType && parameter.ParameterType.GetGenericTypeDefinition() == typeof(Nullable<>);

Nullability nullable = GetMethodDefaultNullability(method);

// now check the parameter to see if there's a nullable attribute for it
attribute = parameter.CustomAttributes.FirstOrDefault(x =>
var attribute = parameter.CustomAttributes.FirstOrDefault(x =>
x.AttributeType.FullName == "System.Runtime.CompilerServices.NullableAttribute" &&
x.ConstructorArguments.Count == 1 &&
(x.ConstructorArguments[0].ArgumentType == typeof(byte) ||
Expand All @@ -358,23 +348,11 @@ protected virtual bool GetNullability(MethodInfo method, ParameterInfo parameter
}

var nullabilityBytes = attribute?.ConstructorArguments[0].Value as byte[];
var index = 0;
nullable = Consider(parameter.ParameterType);
if ((nullabilityBytes != null && nullabilityBytes[0] == (byte)Nullability.Nullable) || (nullabilityBytes == null && nullable == Nullability.Nullable))
return true;
if (nullabilityBytes != null)
return (Nullability)nullabilityBytes[0] != Nullability.NonNullable;
return nullable != Nullability.NonNullable;

Nullability Consider(Type t)
{
var g = t.IsGenericType ? t.GetGenericTypeDefinition() : null;
if (g == typeof(Nullable<>))
return Nullability.Nullable;
if (t.IsValueType)
return Nullability.NonNullable;
if ((nullabilityBytes != null && nullabilityBytes[index] == (byte)Nullability.Nullable) || (nullabilityBytes == null && nullable == Nullability.Nullable))
return Nullability.Nullable;
if (nullabilityBytes != null)
return (Nullability)nullabilityBytes[index];
return nullable;
}
}

/// <summary>
Expand Down
62 changes: 62 additions & 0 deletions src/Tests/DIObjectGraphTypeTests/Argument.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
using System.ComponentModel;
using System.Threading;
using GraphQL.DI;
using GraphQL.Types;
using Shouldly;
Expand Down Expand Up @@ -78,6 +79,20 @@ public class CNonNullableObject : DIObjectGraphBase
public static string Field1([Required] string arg) => arg;
}

[Fact]
public void NonNullableObject2()
{
Configure<CNonNullableObject2, object>();
VerifyFieldArgument("Field1", "arg", false, "hello");
VerifyField("Field1", true, false, "hello");
Verify(false);
}

public class CNonNullableObject2 : DIObjectGraphBase
{
public static string Field1([System.ComponentModel.DataAnnotations.Required] string arg) => arg;
}

[Fact]
public void Name()
{
Expand Down Expand Up @@ -121,5 +136,52 @@ public class CGraphType : DIObjectGraphBase
public static string Field1([GraphType(typeof(StringGraphType))] string arg) => arg;
}

[Fact]
public void CancellationToken()
{
using var cts = new CancellationTokenSource();
var token = cts.Token;
cts.Cancel();
Configure<CCancellationToken, object>();
_contextMock.SetupGet(x => x.CancellationToken).Returns(token).Verifiable();
VerifyField("Field1", true, false, "hello + True");
Verify(false);
}

public class CCancellationToken : DIObjectGraphBase
{
public static string Field1(CancellationToken cancellationToken) => "hello + " + cancellationToken.IsCancellationRequested;
}

[Fact]
public void NullNameArgument()
{
Configure<CNullNameArgument, object>();
VerifyField("Field1", true, false, "hello + 0");
Verify(false);
}

public class CNullNameArgument : DIObjectGraphBase
{
public static string Field1([Name(null)] int arg) => "hello + " + arg;
}

[Fact]
public void DefaultValue()
{
Configure<CDefaultValue, object>();
VerifyFieldArgument("Field1", "arg1", true, 2);
VerifyField("Field1", false, false, 2);
VerifyFieldArgument<int>("Field2", "arg2", true);
VerifyField("Field2", false, false, 5);
Verify(false);
}

public class CDefaultValue : DIObjectGraphBase
{
public static int Field1(int arg1 = 4) => arg1;
public static int Field2(int arg2 = 5) => arg2;
}

}
}
15 changes: 14 additions & 1 deletion src/Tests/DIObjectGraphTypeTests/DIObjectGraphTypeTestBase.cs
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ protected FieldType VerifyField<T>(string fieldName, Type fieldGraphType, bool c
field.Resolver.ShouldNotBeNull();
_contextMock.Setup(x => x.FieldDefinition).Returns(field);
field.Resolver.Resolve(context).ShouldBe(returnValue);
_arguments.Clear();
return field;
}

Expand All @@ -88,6 +89,7 @@ protected async Task<FieldType> VerifyFieldAsync<T>(string fieldName, bool nulla
var taskRet = ret.ShouldBeOfType<Task<T>>();
var final = await taskRet;
final.ShouldBe(returnValue);
_arguments.Clear();
return field;
}

Expand All @@ -96,7 +98,12 @@ protected QueryArgument VerifyFieldArgument<T>(string fieldName, string argument
return VerifyFieldArgument(fieldName, argumentName, typeof(T).GetGraphTypeFromType(nullable, TypeMappingMode.InputType), returnValue);
}

protected QueryArgument VerifyFieldArgument<T>(string fieldName, string argumentName, Type graphType, T returnValue)
protected QueryArgument VerifyFieldArgument<T>(string fieldName, string argumentName, bool nullable)
{
return VerifyFieldArgument<T>(fieldName, argumentName, typeof(T).GetGraphTypeFromType(nullable, TypeMappingMode.InputType));
}

protected QueryArgument VerifyFieldArgument<T>(string fieldName, string argumentName, Type graphType)
{
_graphType.ShouldNotBeNull();
_graphType.Fields.ShouldNotBeNull();
Expand All @@ -106,6 +113,12 @@ protected QueryArgument VerifyFieldArgument<T>(string fieldName, string argument
var argument = field.Arguments.Find(argumentName);
argument.ShouldNotBeNull();
argument.Type.ShouldBe(graphType);
return argument;
}

protected QueryArgument VerifyFieldArgument<T>(string fieldName, string argumentName, Type graphType, T returnValue)
{
var argument = VerifyFieldArgument<T>(fieldName, argumentName, graphType);
_arguments.Add(argumentName, new ArgumentValue(returnValue, ArgumentSource.FieldDefault));
return argument;
}
Expand Down
75 changes: 75 additions & 0 deletions src/Tests/DIObjectGraphTypeTests/Field.cs
Original file line number Diff line number Diff line change
Expand Up @@ -402,6 +402,81 @@ public class CInheritedConcurrentOverridable : DIObjectGraphBase<object>
public static Task<string> Field1(IServiceProvider services) => Task.FromResult<string>("hello");
}

[Fact]
public async Task RemoveAsyncFromName()
{
Configure<CRemoveAsyncFromName, object>();
await VerifyFieldAsync("Field1", true, true, "hello");
Verify(false);
}

public class CRemoveAsyncFromName : DIObjectGraphBase<object>
{
public static Task<string> Field1Async() => Task.FromResult("hello");
}

[Fact]
public void SkipVoidMembers()
{
Configure<CSkipVoidMembers, object>();
_graphType.Fields.Find("Field1").ShouldBeNull();
_graphType.Fields.Find("Field2").ShouldBeNull();
}

public class CSkipVoidMembers : DIObjectGraphBase<object>
{
public static Task Field1() => Task.CompletedTask;
public static void Field2() { }
}

[Fact]
public void SkipNullName()
{
Configure<CSkipNullName, object>();
_graphType.Fields.Find("Field1").ShouldBeNull();
}

public class CSkipNullName : DIObjectGraphBase<object>
{
[Name(null)]
public static string Field1() => "hello";
}

[Fact]
public void AddsMetadata()
{
Configure<CAddsMetadata, object>();
var field = VerifyField("Field1", true, false, "hello");
field.GetMetadata<string>("test").ShouldBe("value");
Verify(false);
}

public class CAddsMetadata : DIObjectGraphBase<object>
{
[Metadata("test", "value")]
public static string Field1() => "hello";
}

[Fact]
public void AddsInheritedMetadata()
{
Configure<CAddsInheritedMetadata, object>();
var field = VerifyField("Field1", true, false, "hello");
field.GetMetadata<string>("test").ShouldBe("value2");
Verify(false);
}

public class CAddsInheritedMetadata : DIObjectGraphBase<object>
{
[InheritedMetadata("value2")]
public static string Field1() => "hello";
}

private class InheritedMetadata : MetadataAttribute
{
public InheritedMetadata(string value) : base("test", value) { }
}

[Theory]
[InlineData("Field1", typeof(GraphQLClrOutputTypeReference<string>))]
[InlineData("Field2", typeof(NonNullGraphType<GraphQLClrOutputTypeReference<string>>))]
Expand Down

0 comments on commit c3cc139

Please sign in to comment.