Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,10 @@
// a generalized "nullable" option here to allow us to do that.
#nullable disable

using System.ComponentModel;
using System.Diagnostics.CodeAnalysis;
using System.Linq.Expressions;
using System.Reflection;

namespace CommunityToolkit.Datasync.Client.Query.Linq;

Expand All @@ -17,6 +19,30 @@ namespace CommunityToolkit.Datasync.Client.Query.Linq;
/// </summary>
internal static class ExpressionExtensions
{
private static readonly MethodInfo Contains;
private static readonly MethodInfo SequenceEqual;

static ExpressionExtensions()
{
Dictionary<string, List<MethodInfo>> queryableMethodGroups = typeof(Enumerable)
.GetMethods(BindingFlags.Public | BindingFlags.Static | BindingFlags.DeclaredOnly)
.GroupBy(mi => mi.Name)
.ToDictionary(e => e.Key, l => l.ToList());

MethodInfo GetMethod(string name, int genericParameterCount, Func<Type[], Type[]> parameterGenerator)
=> queryableMethodGroups[name].Single(mi => ((genericParameterCount == 0 && !mi.IsGenericMethod)
|| (mi.IsGenericMethod && mi.GetGenericArguments().Length == genericParameterCount))
&& mi.GetParameters().Select(e => e.ParameterType).SequenceEqual(
parameterGenerator(mi.IsGenericMethod ? mi.GetGenericArguments() : [])));

Contains = GetMethod(
nameof(Enumerable.Contains), 1,
types => [typeof(IEnumerable<>).MakeGenericType(types[0]), types[0]]);
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

NIT: The MakeGenericType() call can be made into a static readonly so we don't re-allocate. Wouldn't worry about this one - I'll fix it next time round.

SequenceEqual = GetMethod(
nameof(Enumerable.SequenceEqual), 1,
types => [typeof(IEnumerable<>).MakeGenericType(types[0]), typeof(IEnumerable<>).MakeGenericType(types[0])]);
}

/// <summary>
/// Walk the expression and compute all the subtrees that are not dependent on any
/// of the expressions parameters.
Expand Down Expand Up @@ -127,6 +153,7 @@ internal static bool IsValidLambdaExpression(this MethodCallExpression expressio
/// <returns>The partially evaluated expression</returns>
internal static Expression PartiallyEvaluate(this Expression expression)
{
expression = expression.RemoveSpanImplicitCast();
List<Expression> subtrees = expression.FindIndependentSubtrees();
return VisitorHelper.VisitAll(expression, (Expression expr, Func<Expression, Expression> recurse) =>
{
Expand All @@ -143,6 +170,63 @@ internal static Expression PartiallyEvaluate(this Expression expression)
});
}

internal static Expression RemoveSpanImplicitCast(this Expression expression)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This method needs a summary/parameters/returns XML comments.

{
return VisitorHelper.VisitAll(expression, (Expression expr, Func<Expression, Expression> recurse) =>
{
if (expr is MethodCallExpression methodCall)
{
MethodInfo method = methodCall.Method;

if (method.DeclaringType == typeof(MemoryExtensions))
{
switch (method.Name)
{
case nameof(MemoryExtensions.Contains)
when methodCall.Arguments is [Expression arg0, Expression arg1] && TryUnwrapSpanImplicitCast(arg0, out Expression unwrappedArg0):
{
Expression unwrappedExpr = Expression.Call(
Contains.MakeGenericMethod(methodCall.Method.GetGenericArguments()[0]),
unwrappedArg0, arg1);
return recurse(unwrappedExpr);
}

case nameof(MemoryExtensions.SequenceEqual)
when methodCall.Arguments is [Expression arg0, Expression arg1]
&& TryUnwrapSpanImplicitCast(arg0, out Expression unwrappedArg0)
&& TryUnwrapSpanImplicitCast(arg1, out Expression unwrappedArg1):
{
Expression unwrappedExpr = Expression.Call(
SequenceEqual.MakeGenericMethod(methodCall.Method.GetGenericArguments()[0]),
unwrappedArg0, unwrappedArg1);
return recurse(unwrappedExpr);
}
}

static bool TryUnwrapSpanImplicitCast(Expression expression, out Expression result)
{
if (expression is MethodCallExpression
{
Method: { Name: "op_Implicit", DeclaringType: { IsGenericType: true } implicitCastDeclaringType },
Arguments: [Expression unwrapped]
}
&& implicitCastDeclaringType.GetGenericTypeDefinition() is Type genericTypeDefinition
&& (genericTypeDefinition == typeof(Span<>) || genericTypeDefinition == typeof(ReadOnlySpan<>)))
{
result = unwrapped;
return true;
}

result = null;
return false;
}
}
}

return recurse(expr);
});
}

/// <summary>
/// Remove the quote from quoted expressions.
/// </summary>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1177,7 +1177,7 @@ public void Linq_Where_StartsWith_InvariantIgnoreCase()
);
}

[Fact(Skip = "OData v8.4 does not allow string.contains")]
[Fact]
public void Linq_Where_String_Contains()
{
string[] ratings = ["A", "B"];
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1416,7 +1416,7 @@ public void Linq_Where_StartsWith_InvariantIgnoreCase()
);
}

[Fact(Skip = "OData v8.4 does not allow string.contains")]
[Fact]
public void Linq_Where_String_Contains()
{
string[] ratings = ["A", "B"];
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3547,7 +3547,7 @@ public void Linq_Where_StartsWith_InvariantIgnoreCase()
);
}

[Fact(Skip = "OData v8.4 does not allow string.contains")]
[Fact]
public void Linq_Where_String_Contains()
{
string[] ratings = ["A", "B"];
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -540,7 +540,7 @@ await KitchenSinkQueryTest(
// );
//}

[Fact(Skip = "OData v8.4 does not allow string.contains")]
[Fact]
public async Task KitchenSinkQueryTest_020()
{
SeedKitchenSinkWithCountryData();
Expand Down