diff --git a/src/GraphQL.DI/DIObjectGraphType.cs b/src/GraphQL.DI/DIObjectGraphType.cs index 24fe23b..b2d3d83 100644 --- a/src/GraphQL.DI/DIObjectGraphType.cs +++ b/src/GraphQL.DI/DIObjectGraphType.cs @@ -266,7 +266,7 @@ protected virtual bool GetNullability(MethodInfo method) nullable = (Nullability)(byte)attribute.ConstructorArguments[0].Value; } - var nullabilityBytes = attribute?.ConstructorArguments[0].Value as byte[]; + var nullabilityBytes = attribute?.ConstructorArguments[0].Value as IList; var index = 0; nullable = Consider(method.ReturnType); return nullable != Nullability.NonNullable; @@ -278,7 +278,7 @@ Nullability Consider(Type t) return Nullability.Nullable; if (t.IsValueType) return Nullability.NonNullable; - if ((nullabilityBytes != null && nullabilityBytes[index] == (byte)Nullability.Nullable) || (nullabilityBytes == null && nullable == Nullability.Nullable)) + if ((nullabilityBytes != null && (byte)nullabilityBytes[index].Value == (byte)Nullability.Nullable) || (nullabilityBytes == null && nullable == Nullability.Nullable)) return Nullability.Nullable; if (g == typeof(IDataLoaderResult<>) || g == typeof(Task<>)) { index++; @@ -287,7 +287,7 @@ Nullability Consider(Type t) if (t == typeof(IDataLoaderResult)) return Nullability.Nullable; if (nullabilityBytes != null) - return (Nullability)nullabilityBytes[index]; + return (Nullability)(byte)nullabilityBytes[index].Value; return nullable; } } diff --git a/src/Tests/DIObjectGraphTypeTests/Nullable.cs b/src/Tests/DIObjectGraphTypeTests/Nullable.cs index 54ae8fb..f75897f 100644 --- a/src/Tests/DIObjectGraphTypeTests/Nullable.cs +++ b/src/Tests/DIObjectGraphTypeTests/Nullable.cs @@ -32,6 +32,7 @@ public class Nullable [InlineData(typeof(NullableClass15), null, null)] [InlineData(typeof(NullableClass16), 1, 0)] [InlineData(typeof(NullableClass16.NestedClass1), null, 0)] + [InlineData(typeof(NullableClass17), 1, 0)] public void VerifyTestClass(Type type, int? nullableContext, int? nullable) { var actualHasNullableContext = type.CustomAttributes.FirstOrDefault( @@ -100,8 +101,13 @@ public void VerifyTestClass(Type type, int? nullableContext, int? nullable) [InlineData(typeof(NullableClass15), "Field4", false, true)] [InlineData(typeof(NullableClass16), "Field1", false, false)] [InlineData(typeof(NullableClass16), "Field2", false, false)] + [InlineData(typeof(NullableClass16), "Field3", false, true)] [InlineData(typeof(NullableClass16.NestedClass1), "Field1", false, false)] [InlineData(typeof(NullableClass16.NestedClass1), "Field2", false, false)] + [InlineData(typeof(NullableClass16.NestedClass1), "Field3", false, true)] + [InlineData(typeof(NullableClass17), "Field1", false, false)] + [InlineData(typeof(NullableClass17), "Field2", false, false)] + [InlineData(typeof(NullableClass17), "Field3", true, false)] public void VerifyTestMethod(Type type, string methodName, bool hasNullable, bool hasNullableContext) { var method = type.GetMethod(methodName); @@ -189,8 +195,13 @@ public void VerifyTestArgument(Type type, string methodName, string argumentName [InlineData(typeof(NullableClass15), "Field4", true)] [InlineData(typeof(NullableClass16), "Field1", false)] [InlineData(typeof(NullableClass16), "Field2", false)] + [InlineData(typeof(NullableClass16), "Field3", true)] [InlineData(typeof(NullableClass16.NestedClass1), "Field1", false)] [InlineData(typeof(NullableClass16.NestedClass1), "Field2", false)] + [InlineData(typeof(NullableClass16.NestedClass1), "Field3", true)] + [InlineData(typeof(NullableClass17), "Field1", false)] + [InlineData(typeof(NullableClass17), "Field2", false)] + [InlineData(typeof(NullableClass17), "Field3", true)] public void Method(Type type, string methodName, bool expected) { var method = type.GetMethod(methodName); diff --git a/src/Tests/NullabilityTestClasses.cs b/src/Tests/NullabilityTestClasses.cs index 6f75e8e..6b23448 100644 --- a/src/Tests/NullabilityTestClasses.cs +++ b/src/Tests/NullabilityTestClasses.cs @@ -114,11 +114,20 @@ public class NullableClass16 { public static string Field1() => "test"; public static string Field2() => "test"; + public static string? Field3() => "test"; public class NestedClass1 { public static string Field1() => "test"; public static string Field2() => "test"; + public static string? Field3() => "test"; } } + + public class NullableClass17 + { + public static Task Field1() => null!; + public static Task Field2() => null!; + public static Task Field3() => null!; + } }