diff --git a/src/GraphQL.EntityFramework/Where/ArgumentProcessor_List.cs b/src/GraphQL.EntityFramework/Where/ArgumentProcessor_List.cs index 2f470b414..bff813b88 100644 --- a/src/GraphQL.EntityFramework/Where/ArgumentProcessor_List.cs +++ b/src/GraphQL.EntityFramework/Where/ArgumentProcessor_List.cs @@ -4,31 +4,29 @@ public static partial class ArgumentProcessor { public static IEnumerable ApplyGraphQlArguments(this IEnumerable items, bool hasId, IResolveFieldContext context) { - object? GetArguments(Type type, string name) => context.GetArgument(type, name); - if (hasId) { - if (ArgumentReader.TryReadIds(GetArguments, out var values)) + if (ArgumentReader.TryReadIds(context, out var values)) { var predicate = ExpressionBuilder.BuildPredicate("Id", Comparison.In, values); items = items.Where(predicate.Compile()); } } - if (ArgumentReader.TryReadWhere(GetArguments, out var wheres)) + if (ArgumentReader.TryReadWhere(context, out var wheres)) { var predicate = ExpressionBuilder.BuildPredicate(wheres); items = items.Where(predicate.Compile()); } - items = Order(items, GetArguments); + items = Order(items, context); - if (ArgumentReader.TryReadSkip(GetArguments, out var skip)) + if (ArgumentReader.TryReadSkip(context, out var skip)) { items = items.Skip(skip); } - if (ArgumentReader.TryReadTake(GetArguments, out var take)) + if (ArgumentReader.TryReadTake(context, out var take)) { items = items.Take(take); } @@ -36,10 +34,10 @@ public static IEnumerable ApplyGraphQlArguments(this IEnumerable Order(IEnumerable queryable, Func getArguments) + static IEnumerable Order(IEnumerable queryable, IResolveFieldContext context) { var items = queryable.ToList(); - var orderBys = ArgumentReader.ReadOrderBy(getArguments).ToList(); + var orderBys = ArgumentReader.ReadOrderBy(context).ToList(); IOrderedEnumerable ordered; if (orderBys.Count > 0) { diff --git a/src/GraphQL.EntityFramework/Where/ArgumentProcessor_Queryable.cs b/src/GraphQL.EntityFramework/Where/ArgumentProcessor_Queryable.cs index fd5d692c1..c729c684e 100644 --- a/src/GraphQL.EntityFramework/Where/ArgumentProcessor_Queryable.cs +++ b/src/GraphQL.EntityFramework/Where/ArgumentProcessor_Queryable.cs @@ -9,11 +9,9 @@ public static IQueryable ApplyGraphQlArguments( bool applyOrder) where TItem : class { - object? GetArguments(Type type, string name) => context.GetArgument(type, name); - if (keyNames is not null) { - if (ArgumentReader.TryReadIds(GetArguments, out var values)) + if (ArgumentReader.TryReadIds(context, out var values)) { var keyName = GetKeyName(keyNames); var predicate = ExpressionBuilder.BuildPredicate(keyName, Comparison.In, values); @@ -21,7 +19,7 @@ public static IQueryable ApplyGraphQlArguments( } } - if (ArgumentReader.TryReadWhere(GetArguments, out var wheres)) + if (ArgumentReader.TryReadWhere(context, out var wheres)) { var predicate = ExpressionBuilder.BuildPredicate(wheres); queryable = queryable.Where(predicate); @@ -29,14 +27,14 @@ public static IQueryable ApplyGraphQlArguments( if (applyOrder) { - queryable = Order(queryable, GetArguments); + queryable = Order(queryable, context); - if (ArgumentReader.TryReadSkip(GetArguments, out var skip)) + if (ArgumentReader.TryReadSkip(context, out var skip)) { queryable = queryable.Skip(skip); } - if (ArgumentReader.TryReadTake(GetArguments, out var take)) + if (ArgumentReader.TryReadTake(context, out var take)) { queryable = queryable.Take(take); } @@ -55,9 +53,9 @@ static string GetKeyName(List keyNames) return keyNames[0]; } - static IQueryable Order(IQueryable queryable, Func getArguments) + static IQueryable Order(IQueryable queryable, IResolveFieldContext context) { - var orderBys = ArgumentReader.ReadOrderBy(getArguments).ToList(); + var orderBys = ArgumentReader.ReadOrderBy(context).ToList(); IOrderedQueryable ordered; if (orderBys.Count > 0) { diff --git a/src/GraphQL.EntityFramework/Where/ArgumentReader.cs b/src/GraphQL.EntityFramework/Where/ArgumentReader.cs index c9cf23b82..44a28c94f 100644 --- a/src/GraphQL.EntityFramework/Where/ArgumentReader.cs +++ b/src/GraphQL.EntityFramework/Where/ArgumentReader.cs @@ -1,15 +1,16 @@ static class ArgumentReader { - public static bool TryReadWhere(Func getArgument, out IEnumerable expression) + public static bool TryReadWhere(IResolveFieldContext context, out IEnumerable expression) { - expression = getArgument.ReadList("where"); + expression = ReadList(context, "where"); return expression.Any(); } - public static IEnumerable ReadOrderBy(Func getArgument) => getArgument.ReadList("orderBy"); + public static IEnumerable ReadOrderBy(IResolveFieldContext context) => + ReadList(context, "orderBy"); - public static bool TryReadIds(Func getArgument, [NotNullWhen(true)] out string[]? result) + public static bool TryReadIds(IResolveFieldContext context, [NotNullWhen(true)] out string[]? result) { string ArgumentToExpression(object argument) { @@ -22,8 +23,8 @@ string ArgumentToExpression(object argument) }; } - var idsArgument = getArgument(typeof(object), "ids"); - var idArgument = getArgument(typeof(object), "id"); + var idsArgument = context.GetArgument(typeof(object), "ids"); + var idArgument = context.GetArgument(typeof(object), "id"); if (idsArgument is null && idArgument is null) { result = null; @@ -51,9 +52,9 @@ string ArgumentToExpression(object argument) return true; } - public static bool TryReadSkip(Func getArgument, out int skip) + public static bool TryReadSkip(IResolveFieldContext context, out int skip) { - var result = getArgument.TryReadInt("skip", out skip); + var result = TryReadInt("skip", context, out skip); if (result) { if (skip < 0) @@ -64,9 +65,9 @@ public static bool TryReadSkip(Func getArgument, out int return result; } - public static bool TryReadTake(Func getArgument, out int take) + public static bool TryReadTake(IResolveFieldContext context, out int take) { - var result = getArgument.TryReadInt("take", out take); + var result = TryReadInt("take", context, out take); if (result) { if (take < 0) @@ -77,9 +78,9 @@ public static bool TryReadTake(Func getArgument, out int return result; } - static IEnumerable ReadList(this Func getArgument, string name) + static IEnumerable ReadList(IResolveFieldContext context, string name) { - var argument = getArgument(typeof(T[]), name); + var argument = context.GetArgument(typeof(T[]), name); if (argument is null) { return Enumerable.Empty(); @@ -88,9 +89,9 @@ static IEnumerable ReadList(this Func getArgument, return (T[]) argument; } - static bool TryReadInt(this Func getArgument, string name, out int value) + static bool TryReadInt(string name, IResolveFieldContext context, out int value) { - var argument = getArgument(typeof(int), name); + var argument = context.GetArgument(typeof(int), name); if (argument is null) { value = 0;