Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions Settings.StyleCop
Original file line number Diff line number Diff line change
Expand Up @@ -247,6 +247,11 @@
<BooleanProperty Name="Enabled">False</BooleanProperty>
</RuleSettings>
</Rule>
<Rule Name="FieldNamesMustBeginWithLowerCaseLetter">
<RuleSettings>
<BooleanProperty Name="Enabled">False</BooleanProperty>
</RuleSettings>
</Rule>
</Rules>
<AnalyzerSettings>
<CollectionProperty Name="Hungarian">
Expand Down Expand Up @@ -371,6 +376,16 @@
<BooleanProperty Name="Enabled">False</BooleanProperty>
</RuleSettings>
</Rule>
<Rule Name="ElementDocumentationHeaderMustBePrecededByBlankLine">
<RuleSettings>
<BooleanProperty Name="Enabled">False</BooleanProperty>
</RuleSettings>
</Rule>
<Rule Name="CurlyBracketsMustNotBeOmitted">
<RuleSettings>
<BooleanProperty Name="Enabled">False</BooleanProperty>
</RuleSettings>
</Rule>
</Rules>
<AnalyzerSettings />
</Analyzer>
Expand Down
227 changes: 134 additions & 93 deletions WebAPI.NHibernate-OData/Internal/FixStringMethodsVisitor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -5,97 +5,138 @@

namespace Pathoschild.WebApi.NhibernateOdata.Internal
{
/// <summary>Intercepts queries before they're parsed by NHibernate to rewrite unsupported lambdas for <see cref="string.Contains"/>, <see cref="string.StartsWith(string)"/> and <see cref="string.EndsWith(string)"/>.</summary>
/// <remarks>
/// The expression tree generated by the <c>ODataQueryOptions.ApplyTo</c> method looks like the following sample.
/// <code>
/// .Lambda #Lambda1&lt;System.Func`2[Pathoschild.WebApi.NhibernateOdata.Tests.Models.Parent,System.Boolean]&gt;(Pathoschild.WebApi.NhibernateOdata.Tests.Models.Parent $$it)
/// {
/// (.If (
/// $$it.Name == null | .Constant&lt;System.Web.Http.OData.Query.Expressions.LinqParameterContainer+TypedLinqParameterContainer`1[System.String]&gt;(System.Web.Http.OData.Query.Expressions.LinqParameterContainer+TypedLinqParameterContainer`1[System.String]).TypedProperty ==
/// null
/// ) {
/// null
/// } .Else {
/// (System.Nullable`1[System.Boolean]).Call ($$it.Name).Contains(.Constant&lt;System.Web.Http.OData.Query.Expressions.LinqParameterContainer+TypedLinqParameterContainer`1[System.String]&gt;(System.Web.Http.OData.Query.Expressions.LinqParameterContainer+TypedLinqParameterContainer`1[System.String]).TypedProperty)
/// } == (System.Nullable`1[System.Boolean]).Constant&lt;System.Web.Http.OData.Query.Expressions.LinqParameterContainer+TypedLinqParameterContainer`1[System.Boolean]&gt;(System.Web.Http.OData.Query.Expressions.LinqParameterContainer+TypedLinqParameterContainer`1[System.Boolean]).TypedProperty)
/// == .Constant&lt;System.Nullable`1[System.Boolean]&gt;(True)
/// }
/// </code>
/// </remarks>
public class FixStringMethodsVisitor : ExpressionVisitor
{
/*********
** Properties
*********/
/// <summary>Whether the visitor is visiting a nested node.</summary>
/// <remarks>This is used to recognize the top-level node for logging.</remarks>
private bool IsRecursing;

/// <summary>A list of <see cref="string"/> methods supported by this visitor.</summary>
private readonly List<MethodInfo> StringMethods = new List<MethodInfo>();


/*********
** Public methods
*********/
/// <summary>Constructs an instance.</summary>
public FixStringMethodsVisitor()
{
this.StringMethods.AddRange(typeof(string).GetMethods(BindingFlags.Public | BindingFlags.Instance).Where(x => x.Name == "Contains" || x.Name == "StartsWith" || x.Name == "EndsWith"));
}

/// <summary>Dispatches the expression to one of the more specialized visit methods in this class.</summary>
/// <param name="node">The expression to visit.</param>
/// <returns>The modified expression, if it or any subexpression was modified; otherwise, returns the original expression.</returns>
public override Expression Visit(Expression node)
{
// top node
if (!this.IsRecursing)
{
this.IsRecursing = true;
return base.Visit(node);
}

var conditionalExpression = node as ConditionalExpression;
if (conditionalExpression != null)
return this.HandleConditionalExpression(node, conditionalExpression);

return base.Visit(node);
}


/*********
** Protected methods
*********/
/// <summary>Handles the conditional expression (equivalent to <c>.If {} .Else {}</c> in the sample expression tree in the <see cref="FixStringMethodsVisitor"/> remarks).</summary>
/// <param name="original">The original expression.</param>
/// <param name="ifElse">The conditional expression.</param>
/// <returns>A reduced if/else statement if it contains any of the matched methods. Otherwise, the original expression.</returns>
private Expression HandleConditionalExpression(Expression original, ConditionalExpression ifElse)
{
var elseExpression = ifElse.IfFalse as UnaryExpression;
if (elseExpression != null)
{
var methodCallExpression = elseExpression.Operand as MethodCallExpression;
if (methodCallExpression != null)
{
if (this.StringMethods.Contains(methodCallExpression.Method))
{
var methodCallReplacement = Expression.Call(
methodCallExpression.Object,
methodCallExpression.Method,
methodCallExpression.Arguments);

// Convert the result to a nullable boolean so the Expression.Equal works.
var result = Expression.Convert(methodCallReplacement, typeof(bool?));

return result;
}
}
}

return original;
}
}
/// <summary>Intercepts queries before they're parsed by NHibernate to rewrite unsupported lambdas for <see cref="string.Contains"/>, <see cref="string.StartsWith(string)"/> and <see cref="string.EndsWith(string)"/>.</summary>
/// <remarks>
/// The expression tree generated by the <c>ODataQueryOptions.ApplyTo</c> method looks like the following sample.
/// <code>
/// .Lambda #Lambda1&lt;System.Func`2[Pathoschild.WebApi.NhibernateOdata.Tests.Models.Parent,System.Boolean]&gt;(Pathoschild.WebApi.NhibernateOdata.Tests.Models.Parent $$it)
/// {
/// (.If (
/// $$it.Name == null | .Constant&lt;System.Web.Http.OData.Query.Expressions.LinqParameterContainer+TypedLinqParameterContainer`1[System.String]&gt;(System.Web.Http.OData.Query.Expressions.LinqParameterContainer+TypedLinqParameterContainer`1[System.String]).TypedProperty ==
/// null
/// ) {
/// null
/// } .Else {
/// (System.Nullable`1[System.Boolean]).Call ($$it.Name).Contains(.Constant&lt;System.Web.Http.OData.Query.Expressions.LinqParameterContainer+TypedLinqParameterContainer`1[System.String]&gt;(System.Web.Http.OData.Query.Expressions.LinqParameterContainer+TypedLinqParameterContainer`1[System.String]).TypedProperty)
/// } == (System.Nullable`1[System.Boolean]).Constant&lt;System.Web.Http.OData.Query.Expressions.LinqParameterContainer+TypedLinqParameterContainer`1[System.Boolean]&gt;(System.Web.Http.OData.Query.Expressions.LinqParameterContainer+TypedLinqParameterContainer`1[System.Boolean]).TypedProperty)
/// == .Constant&lt;System.Nullable`1[System.Boolean]&gt;(True)
/// }
/// </code>
///
/// The actual System.Web.Http.OData parser DOES NOT support the "replace" string method, so we can't make it go through NHibernate.
/// </remarks>
public class FixStringMethodsVisitor : ExpressionVisitor
{
/*********
** Properties
*********/
/// <summary>Whether the visitor is visiting a nested node.</summary>
/// <remarks>This is used to recognize the top-level node for logging.</remarks>
private bool IsRecursing;

/// <summary>A list of boolean return <see cref="string"/> methods supported by this visitor.</summary>
private readonly List<MethodInfo> BooleanReturnStringMethods = new List<MethodInfo>();

/// <summary>A list of integer return <see cref="string"/> methods supported by this visitor.</summary>
private readonly List<MethodInfo> IntegerStringMethods = new List<MethodInfo>();

/// <summary>A list of concatenation <see cref="string"/> methods supported by this visitor.</summary>
private readonly List<MethodInfo> ConcatStringMethods = new List<MethodInfo>();


/*********
** Public methods
*********/
/// <summary>Constructs an instance.</summary>
public FixStringMethodsVisitor()
{
this.BooleanReturnStringMethods.AddRange(typeof(string).GetMethods(BindingFlags.Public | BindingFlags.Instance).Where(x => x.Name == "Contains" || x.Name == "StartsWith" || x.Name == "EndsWith"));
this.IntegerStringMethods.AddRange(typeof(string).GetMethods(BindingFlags.Public | BindingFlags.Instance).Where(x => x.Name == "IndexOf").ToList());
this.ConcatStringMethods.AddRange(typeof(string).GetMethods(BindingFlags.Public | BindingFlags.Static).Where(x => x.Name == "Concat").ToList());
}

/// <summary>Dispatches the expression to one of the more specialized visit methods in this class.</summary>
/// <param name="node">The expression to visit.</param>
/// <returns>The modified expression, if it or any subexpression was modified; otherwise, returns the original expression.</returns>
public override Expression Visit(Expression node)
{
// top node
if (!this.IsRecursing)
{
this.IsRecursing = true;
return base.Visit(node);
}

var conditionalExpression = node as ConditionalExpression;
if (conditionalExpression != null)
return this.HandleConditionalExpression(node, conditionalExpression);

return base.Visit(node);
}


/*********
** Protected methods
*********/
/// <summary>Handles the conditional expression (equivalent to <c>.If {} .Else {}</c> in the sample expression tree in the <see cref="FixStringMethodsVisitor"/> remarks).</summary>
/// <param name="original">The original expression.</param>
/// <param name="ifElse">The conditional expression.</param>
/// <returns>A reduced if/else statement if it contains any of the matched methods. Otherwise, the original expression.</returns>
private Expression HandleConditionalExpression(Expression original, ConditionalExpression ifElse)
{
var elseExpression = ifElse.IfFalse as UnaryExpression;
if (elseExpression != null)
{
var methodCallExpression = elseExpression.Operand as MethodCallExpression;
if (methodCallExpression != null)
{
if (this.BooleanReturnStringMethods.Contains(methodCallExpression.Method))
{
var methodCallReplacement = Expression.Call(
methodCallExpression.Object,
methodCallExpression.Method,
methodCallExpression.Arguments);

// Convert the result to a nullable boolean so the Expression.Equal works.
var result = Expression.Convert(methodCallReplacement, typeof(bool?));
return result;
}

if (this.IntegerStringMethods.Contains(methodCallExpression.Method))
{
var methodCallReplacement = Expression.Call(
methodCallExpression.Object,
methodCallExpression.Method,
methodCallExpression.Arguments);

var result = Expression.Convert(methodCallReplacement, typeof(int?));
return result;
}
}
}

var firstLevelMethodCallExpression = ifElse.IfFalse as MethodCallExpression;
if (firstLevelMethodCallExpression != null)
{
// Using the method name and declaring type as strings because I don't want to add a dependency to the project for a simple check like that.
if (firstLevelMethodCallExpression.Method.DeclaringType != null &&
firstLevelMethodCallExpression.Method.DeclaringType.FullName == "System.Web.Http.OData.Query.Expressions.ClrSafeFunctions" &&
(firstLevelMethodCallExpression.Method.Name == "SubstringStartAndLength" || firstLevelMethodCallExpression.Method.Name == "SubstringStart"))
{
var arguments = firstLevelMethodCallExpression.Arguments.Skip(1).ToArray();
return Expression.Call(
firstLevelMethodCallExpression.Arguments[0],
typeof(string).GetMethod("Substring", arguments.Select(x => typeof(int)).ToArray()),
arguments);
}

if (this.ConcatStringMethods.Contains(firstLevelMethodCallExpression.Method))
{
return Expression.Add(firstLevelMethodCallExpression.Arguments.First(), firstLevelMethodCallExpression.Arguments.Last(), firstLevelMethodCallExpression.Method);
}
}

return original;
}
}
}
Loading