Skip to content

Commit

Permalink
Adds struct support to filtering (#5760)
Browse files Browse the repository at this point in the history
  • Loading branch information
PascalSenn committed Feb 5, 2023
1 parent fd000ca commit b027ca1
Show file tree
Hide file tree
Showing 16 changed files with 828 additions and 52 deletions.
Expand Up @@ -138,6 +138,7 @@ public virtual string GetTypeName(Type runtimeType)
runtimeType.GetGenericTypeDefinition() == typeof(EnumOperationFilterInputType<>))
{
var genericName = _namingConventions.GetTypeName(runtimeType.GenericTypeArguments[0]);

return genericName + "OperationFilterInput";
}

Expand Down Expand Up @@ -253,8 +254,8 @@ public string GetOperationName(int operation)
IFilterInputTypeDescriptor descriptor)
{
if (_configs.TryGetValue(
typeReference,
out var configurations))
typeReference,
out var configurations))
{
foreach (var configure in configurations)
{
Expand Down Expand Up @@ -290,11 +291,13 @@ public bool IsOrAllowed()
if (filterFieldHandler.CanHandle(context, typeDefinition, fieldDefinition))
{
handler = filterFieldHandler;

return true;
}
}

handler = null;

return false;
}

Expand All @@ -319,24 +322,34 @@ public bool IsOrAllowed()
TryCreateFilterType(runtimeType.ElementType, out var elementType))
{
type = typeof(ListFilterInputType<>).MakeGenericType(elementType);

return true;
}
}

if (runtimeType.Type.IsEnum)
{
type = typeof(EnumOperationFilterInputType<>).MakeGenericType(runtimeType.Source);

return true;
}

if (runtimeType.Type is { IsValueType: true, IsPrimitive: false })
{
type = typeof(FilterInputType<>).MakeGenericType(runtimeType.Type);

return true;
}

if (runtimeType.Type.IsClass ||
runtimeType.Type.IsInterface)
if (runtimeType.Type.IsClass || runtimeType.Type.IsInterface)
{
type = typeof(FilterInputType<>).MakeGenericType(runtimeType.Source);

return true;
}

type = null;

return false;
}

Expand All @@ -349,8 +362,8 @@ public bool IsOrAllowed()
foreach (var extensionType in definition.ProviderExtensionsTypes)
{
if (serviceProvider.TryGetOrCreateService<IFilterProviderExtension>(
extensionType,
out var createdExtension))
extensionType,
out var createdExtension))
{
extensions.Add(createdExtension);
}
Expand Down
Expand Up @@ -63,10 +63,7 @@ public static Expression Not(Expression expression)
return Expression.Call(
typeof(Enumerable),
nameof(Enumerable.Contains),
new Type[]
{
genericType
},
new Type[] { genericType },
Expression.Constant(parsedValue),
property);
}
Expand Down Expand Up @@ -124,16 +121,26 @@ public static Expression Not(Expression expression)
public static Expression NotNull(Expression expression)
=> Expression.NotEqual(expression, _null);

public static Expression HasValue(Expression expression)
=> Expression.IsTrue(
Expression.Property(
expression,
expression.Type.GetProperty(nameof(Nullable<int>.HasValue))!));

public static Expression NotNullAndAlso(Expression property, Expression condition)
=> Expression.AndAlso(NotNull(property), condition);

public static Expression HasValueAndAlso(Expression property, Expression condition)
=> Expression.AndAlso(HasValue(property), condition);

public static Expression Any(
Type type,
Expression property,
Expression body,
params ParameterExpression[] parameterExpression)
{
var lambda = Expression.Lambda(body, parameterExpression);

return Any(type, property, lambda);
}

Expand All @@ -143,22 +150,15 @@ public static Expression NotNullAndAlso(Expression property, Expression conditio
LambdaExpression lambda)
=> Expression.Call(
_anyWithParameter.MakeGenericMethod(type),
new Expression[]
{
property,
lambda
});
new Expression[] { property, lambda });

public static Expression Any(
Type type,
Expression property)
{
return Expression.Call(
_anyMethod.MakeGenericMethod(type),
new Expression[]
{
property
});
new Expression[] { property });
}

public static Expression All(
Expand All @@ -167,11 +167,7 @@ public static Expression NotNullAndAlso(Expression property, Expression conditio
LambdaExpression lambda)
=> Expression.Call(
_allMethod.MakeGenericMethod(type),
new Expression[]
{
property,
lambda
});
new Expression[] { property, lambda });

public static Expression NotContains(
Expression property,
Expand All @@ -188,15 +184,13 @@ public static Expression NotNullAndAlso(Expression property, Expression conditio
private static Expression CreateAndConvertParameter<T>(object value)
{
Expression<Func<T>> lambda = () => (T)value;

return lambda.Body;
}

private static Expression CreateParameter(object? value, Type type)
=> (Expression)_createAndConvert
.MakeGenericMethod(type)
.Invoke(null,
new[]
{
value
})!;
new[] { value })!;
}
Expand Up @@ -2,7 +2,6 @@
using System.Diagnostics.CodeAnalysis;
using System.Linq.Expressions;
using HotChocolate.Configuration;
using HotChocolate.Internal;
using HotChocolate.Language;
using HotChocolate.Language.Visitors;

Expand Down Expand Up @@ -47,6 +46,7 @@ public abstract class QueryableListOperationHandlerBase
ErrorHelper.CreateNonNullError(field, node.Value, context));

action = SyntaxVisitor.Skip;

return true;
}

Expand All @@ -61,10 +61,12 @@ public abstract class QueryableListOperationHandlerBase
context.AddScope();

action = SyntaxVisitor.Continue;

return true;
}

action = null;

return false;
}

Expand All @@ -90,15 +92,14 @@ public abstract class QueryableListOperationHandlerBase

if (context.InMemory)
{
expression = FilterExpressionBuilder.NotNullAndAlso(
instance,
expression);
expression = FilterExpressionBuilder.NotNullAndAlso(instance, expression);
}

context.GetLevel().Enqueue(expression);
}

action = SyntaxVisitor.Continue;

return true;
}

Expand Down
@@ -1,9 +1,9 @@
using System;
using System.Diagnostics.CodeAnalysis;
using System.Linq;
using System.Linq.Expressions;
using System.Reflection;
using HotChocolate.Configuration;
using HotChocolate.Internal;
using HotChocolate.Language;
using HotChocolate.Language.Visitors;

Expand Down Expand Up @@ -43,12 +43,14 @@ public class QueryableDefaultFieldHandler
ErrorHelper.CreateNonNullError(field, node.Value, context));

action = SyntaxVisitor.Skip;

return true;
}

if (field.RuntimeType is null)
{
action = null;

return false;
}

Expand All @@ -69,25 +71,37 @@ public class QueryableDefaultFieldHandler
}
else
{
var instance = context.GetInstance();

// we need to check if the previous value was a nullable value type. if it is a nullable
// value type we cannot just chain the next expression to it. We have to first select
// ".Value".
//
// without this check we would chain "previous" directly to "current": previous.current
// with this check we chain "previous" via ".Value" to "current": previous.Value.current
if (context.TryGetPreviousRuntimeType(out var previousRuntimeType) &&
previousRuntimeType.IsNullableValueType())
{
var valueGetter = instance.Type.GetProperty(nameof(Nullable<int>.Value));
instance = Expression.Property(instance, valueGetter!);
}

nestedProperty = field.Member switch
{
PropertyInfo propertyInfo =>
Expression.Property(context.GetInstance(), propertyInfo),
PropertyInfo propertyInfo => Expression.Property(instance, propertyInfo),

MethodInfo methodInfo =>
Expression.Call(context.GetInstance(), methodInfo),
MethodInfo methodInfo => Expression.Call(instance, methodInfo),

null =>
throw ThrowHelper.QueryableFiltering_NoMemberDeclared(field),
null => throw ThrowHelper.QueryableFiltering_NoMemberDeclared(field),

_ =>
throw ThrowHelper.QueryableFiltering_MemberInvalid(field.Member, field)
_ => throw ThrowHelper.QueryableFiltering_MemberInvalid(field.Member, field)
};
}

context.PushInstance(nestedProperty);
context.RuntimeTypes.Push(field.RuntimeType);
action = SyntaxVisitor.Continue;

return true;
}

Expand All @@ -100,6 +114,7 @@ public class QueryableDefaultFieldHandler
if (field.RuntimeType is null)
{
action = null;

return false;
}

Expand All @@ -109,15 +124,27 @@ public class QueryableDefaultFieldHandler
context.PopInstance();
context.RuntimeTypes.Pop();

if (context.InMemory)
// when we are in a in-memory context, it is possible that we have null reference exceptions
// To avoid these exceptions, we need to add null checks to the chain. We always wrap the
// field before in a null check.
//
// reference types:
// previous.current > 10 ==> previous is not null && previous.current > 10
//
// structs:
// previous.Value.current > 10 ==> previous is not null && previous.Value.current > 10
//
if (context.InMemory &&
context.TryGetPreviousRuntimeType(out var previousRuntimeType) &&
(previousRuntimeType.IsNullableValueType() || !previousRuntimeType.IsValueType()))
{
condition = FilterExpressionBuilder.NotNullAndAlso(
context.GetInstance(),
condition);
var peekedInstance = context.GetInstance();
condition = FilterExpressionBuilder.NotNullAndAlso(peekedInstance, condition);
}

context.GetLevel().Enqueue(condition);
action = SyntaxVisitor.Continue;

return true;
}

Expand All @@ -140,6 +167,7 @@ protected override Expression VisitParameter(ParameterExpression node)
{
return _replacement;
}

return base.VisitParameter(node);
}

Expand All @@ -151,3 +179,33 @@ protected override Expression VisitParameter(ParameterExpression node)
new ReplaceVariableExpressionVisitor(replacement, parameter).Visit(lambda);
}
}

static file class LocalExtensions
{
public static bool TryGetPreviousRuntimeType(
this QueryableFilterContext context,
[NotNullWhen(true)] out IExtendedType? runtimeType)
{
return context.RuntimeTypes.TryPeek(out runtimeType);
}

public static bool IsNullableValueType(this IExtendedType type)
{
return type.GetTypeOrElementType() is { Type.IsValueType: true, IsNullable: true };
}

public static bool IsValueType(this IExtendedType type)
{
return type.GetTypeOrElementType() is { Type.IsValueType: true };
}

private static IExtendedType GetTypeOrElementType(this IExtendedType type)
{
while (type is { IsArrayOrList: true, ElementType: { } nextType })
{
type = nextType;
}

return type;
}
}

0 comments on commit b027ca1

Please sign in to comment.