Skip to content

Commit

Permalink
Fixed #74 (Join on nullable and not nullable type)
Browse files Browse the repository at this point in the history
  • Loading branch information
StefH committed Apr 15, 2017
1 parent 1e5c225 commit 26ed2f4
Show file tree
Hide file tree
Showing 7 changed files with 212 additions and 55 deletions.
35 changes: 35 additions & 0 deletions src/System.Linq.Dynamic.Core/DynamicExpressionArgument.cs
@@ -0,0 +1,35 @@
//using System.Linq.Expressions;

//namespace System.Linq.Dynamic.Core
//{
// /// <summary>
// /// DynamicExpressionArgument
// /// </summary>
// public class DynamicExpressionArgument
// {
// /// <summary>
// /// If set to <c>true</c> then also create a constructor for all the parameters. Note that this doesn't work for Linq-to-Database entities.
// /// </summary>
// public bool CreateParameterCtor { get; set; }

// /// <summary>
// /// Parameters
// /// </summary>
// public ParameterExpression[] Parameters { get; set; }

// /// <summary>
// /// ResultType
// /// </summary>
// public Type ResultType { get; set; }

// /// <summary>
// /// Expression
// /// </summary>
// public string Expression { get; set; }

// /// <summary>
// /// Values
// /// </summary>
// public object[] Values { get; set; }
// }
//}
137 changes: 83 additions & 54 deletions src/System.Linq.Dynamic.Core/DynamicQueryableExtensions.cs
@@ -1,5 +1,7 @@
using System.Collections.Generic;
using System.Collections;
using System.Globalization;
using System.Linq.Dynamic.Core.Exceptions;
#if !(WINDOWS_APP45x || SILVERLIGHT)
using System.Diagnostics;
#endif
Expand Down Expand Up @@ -46,7 +48,7 @@ private static Expression OptimizeExpression(Expression expression)
return expression;
}

#region Any
#region Any
private static readonly MethodInfo _any = GetMethod(nameof(Queryable.Any));

/// <summary>
Expand Down Expand Up @@ -94,9 +96,9 @@ public static bool Any([NotNull] this IQueryable source, [NotNull] string predic

return Execute<bool>(_anyPredicate, source, lambda);
}
#endregion Any
#endregion Any

#region AsEnumerable
#region AsEnumerable
#if NET35
/// <summary>
/// Returns the input typed as <see cref="IEnumerable{T}"/> of <see cref="object"/>./>
Expand All @@ -118,9 +120,9 @@ public static IEnumerable<dynamic> AsEnumerable([NotNull] this IQueryable source
yield return obj;
}
}
#endregion AsEnumerable
#endregion AsEnumerable

#region Count
#region Count
private static readonly MethodInfo _count = GetMethod(nameof(Queryable.Count));

/// <summary>
Expand Down Expand Up @@ -168,9 +170,9 @@ public static int Count([NotNull] this IQueryable source, [NotNull] string predi

return Execute<int>(_countPredicate, source, lambda);
}
#endregion Count
#endregion Count

#region Distinct
#region Distinct
private static readonly MethodInfo _distinct = GetMethod(nameof(Queryable.Distinct));

/// <summary>
Expand All @@ -192,9 +194,9 @@ public static IQueryable Distinct([NotNull] this IQueryable source)
var optimized = OptimizeExpression(Expression.Call(typeof(Queryable), "Distinct", new Type[] { source.ElementType }, source.Expression));
return source.Provider.CreateQuery(optimized);
}
#endregion Distinct
#endregion Distinct

#region First
#region First
private static readonly MethodInfo _first = GetMethod(nameof(Queryable.First));

/// <summary>
Expand Down Expand Up @@ -236,9 +238,9 @@ public static dynamic First([NotNull] this IQueryable source, [NotNull] string p

return Execute(_firstPredicate, source, lambda);
}
#endregion First
#endregion First

#region FirstOrDefault
#region FirstOrDefault
/// <summary>
/// Returns the first element of a sequence, or a default value if the sequence contains no elements.
/// </summary>
Expand Down Expand Up @@ -278,9 +280,9 @@ public static dynamic FirstOrDefault([NotNull] this IQueryable source, [NotNull]
return Execute(_firstOrDefaultPredicate, source, lambda);
}
private static readonly MethodInfo _firstOrDefaultPredicate = GetMethod(nameof(Queryable.FirstOrDefault), 1);
#endregion FirstOrDefault
#endregion FirstOrDefault

#region GroupBy
#region GroupBy
/// <summary>
/// Groups the elements of a sequence according to a specified key string function
/// and creates a result value from each group and its key.
Expand Down Expand Up @@ -366,9 +368,9 @@ public static IQueryable GroupBy([NotNull] this IQueryable source, [NotNull] str

return source.Provider.CreateQuery(optimized);
}
#endregion GroupBy
#endregion GroupBy

#region GroupByMany
#region GroupByMany
/// <summary>
/// Groups the elements of a sequence according to multiple specified key string functions
/// and creates a result value from each group (and subgroups) and its key.
Expand Down Expand Up @@ -427,9 +429,9 @@ static IEnumerable<GroupResult> GroupByManyInternal<TElement>(IEnumerable<TEleme

return result;
}
#endregion GroupByMany
#endregion GroupByMany

#region Join
#region Join
/// <summary>
/// Correlates the elements of two sequences based on matching keys. The default equality comparer is used to compare keys.
/// </summary>
Expand All @@ -450,20 +452,47 @@ public static IQueryable Join([NotNull] this IQueryable outer, [NotNull] IEnumer
Check.NotEmpty(innerKeySelector, nameof(innerKeySelector));
Check.NotEmpty(resultSelector, nameof(resultSelector));

Type outerType = outer.ElementType;
Type innerType = inner.AsQueryable().ElementType;

bool createParameterCtor = outer.IsLinqToObjects();
LambdaExpression outerSelectorLambda = DynamicExpressionParser.ParseLambda(createParameterCtor, outer.ElementType, null, outerKeySelector, args);
LambdaExpression innerSelectorLambda = DynamicExpressionParser.ParseLambda(createParameterCtor, inner.AsQueryable().ElementType, null, innerKeySelector, args);
LambdaExpression outerSelectorLambda = DynamicExpressionParser.ParseLambda(createParameterCtor, outerType, null, outerKeySelector, args);
LambdaExpression innerSelectorLambda = DynamicExpressionParser.ParseLambda(createParameterCtor, innerType, null, innerKeySelector, args);

Type outerSelectorReturnType = outerSelectorLambda.Body.Type;
Type innerSelectorReturnType = innerSelectorLambda.Body.Type;

// If types are not the same, try to convert to Nullable and generate new LambdaExpression
if (outerSelectorReturnType != innerSelectorReturnType)
{
if (ExpressionParser.IsNullableType(outerSelectorReturnType) && !ExpressionParser.IsNullableType(innerSelectorReturnType))
{
innerSelectorReturnType = ExpressionParser.ToNullableType(innerSelectorReturnType);
innerSelectorLambda = DynamicExpressionParser.ParseLambda(createParameterCtor, innerType, innerSelectorReturnType, innerKeySelector, args);
}
else if (!ExpressionParser.IsNullableType(outerSelectorReturnType) && ExpressionParser.IsNullableType(innerSelectorReturnType))
{
outerSelectorReturnType = ExpressionParser.ToNullableType(outerSelectorReturnType);
outerSelectorLambda = DynamicExpressionParser.ParseLambda(createParameterCtor, outerType, outerSelectorReturnType, outerKeySelector, args);
}

// If types are still not the same, throw an Exception
if (outerSelectorReturnType != innerSelectorReturnType)
{
throw new ParseException(string.Format(CultureInfo.CurrentCulture, Res.IncompatibleTypes, outerType, innerType), -1);
}
}

ParameterExpression[] parameters = new[]
ParameterExpression[] parameters =
{
Expression.Parameter(outer.ElementType, "outer"), Expression.Parameter(inner.AsQueryable().ElementType, "inner")
Expression.Parameter(outerType, "outer"), Expression.Parameter(innerType, "inner")
};

LambdaExpression resultSelectorLambda = DynamicExpressionParser.ParseLambda(createParameterCtor, parameters, null, resultSelector, args);

var optimized = OptimizeExpression(Expression.Call(
typeof(Queryable), "Join",
new[] { outer.ElementType, inner.AsQueryable().ElementType, outerSelectorLambda.Body.Type, resultSelectorLambda.Body.Type },
new[] { outerType, innerType, outerSelectorLambda.Body.Type, resultSelectorLambda.Body.Type },
outer.Expression, // outer: The first sequence to join.
inner.AsQueryable().Expression, // inner: The sequence to join to the first sequence.
Expression.Quote(outerSelectorLambda), // outerKeySelector: A function to extract the join key from each element of the first sequence.
Expand All @@ -490,9 +519,9 @@ public static IQueryable<TElement> Join<TElement>([NotNull] this IQueryable<TEle
{
return (IQueryable<TElement>)Join((IQueryable)outer, (IEnumerable)inner, outerKeySelector, innerKeySelector, resultSelector, args);
}
#endregion Join
#endregion Join

#region Last
#region Last
private static readonly MethodInfo _last = GetMethod(nameof(Queryable.Last));
/// <summary>
/// Returns the last element of a sequence.
Expand All @@ -509,9 +538,9 @@ public static dynamic Last([NotNull] this IQueryable source)

return Execute(_last, source);
}
#endregion Last
#endregion Last

#region LastOrDefault
#region LastOrDefault
private static readonly MethodInfo _lastDefault = GetMethod(nameof(Queryable.LastOrDefault));
/// <summary>
/// Returns the last element of a sequence, or a default value if the sequence contains no elements.
Expand All @@ -528,9 +557,9 @@ public static dynamic LastOrDefault([NotNull] this IQueryable source)

return Execute(_lastDefault, source);
}
#endregion LastOrDefault
#endregion LastOrDefault

#region OrderBy
#region OrderBy
/// <summary>
/// Sorts the elements of a sequence in ascending or descending order according to a key.
/// </summary>
Expand Down Expand Up @@ -589,9 +618,9 @@ public static IOrderedQueryable OrderBy([NotNull] this IQueryable source, [NotNu
var optimized = OptimizeExpression(queryExpr);
return (IOrderedQueryable)source.Provider.CreateQuery(optimized);
}
#endregion OrderBy
#endregion OrderBy

#region Page/PageResult
#region Page/PageResult
/// <summary>
/// Returns the elements as paged.
/// </summary>
Expand Down Expand Up @@ -677,9 +706,9 @@ public static PagedResult<TSource> PageResult<TSource>([NotNull] this IQueryable

return result;
}
#endregion Page/PageResult
#endregion Page/PageResult

#region Reverse
#region Reverse
/// <summary>
/// Inverts the order of the elements in a sequence.
/// </summary>
Expand All @@ -691,9 +720,9 @@ public static IQueryable Reverse([NotNull] this IQueryable source)

return Queryable.Reverse((IQueryable<object>)source);
}
#endregion Reverse
#endregion Reverse

#region Select
#region Select
/// <summary>
/// Projects each element of a sequence into a new form.
/// </summary>
Expand Down Expand Up @@ -786,9 +815,9 @@ public static IQueryable Select([NotNull] this IQueryable source, [NotNull] Type

return source.Provider.CreateQuery(optimized);
}
#endregion Select
#endregion Select

#region SelectMany
#region SelectMany
/// <summary>
/// Projects each element of a sequence to an <see cref="IQueryable"/> and combines the resulting sequences into one sequence.
/// </summary>
Expand Down Expand Up @@ -991,9 +1020,9 @@ public static IQueryable SelectMany([NotNull] this IQueryable source, [NotNull]

return source.Provider.CreateQuery(optimized);
}
#endregion SelectMany
#endregion SelectMany

#region Single/SingleOrDefault
#region Single/SingleOrDefault
/// <summary>
/// Returns the only element of a sequence, and throws an exception if there
/// is not exactly one element in the sequence.
Expand Down Expand Up @@ -1030,9 +1059,9 @@ public static dynamic SingleOrDefault([NotNull] this IQueryable source)
var optimized = OptimizeExpression(Expression.Call(typeof(Queryable), "SingleOrDefault", new[] { source.ElementType }, source.Expression));
return source.Provider.Execute(optimized);
}
#endregion Single/SingleOrDefault
#endregion Single/SingleOrDefault

#region Skip
#region Skip
private static readonly MethodInfo _skip = GetMethod(nameof(Queryable.Skip), 1);

/// <summary>
Expand All @@ -1052,9 +1081,9 @@ public static IQueryable Skip([NotNull] this IQueryable source, int count)

return CreateQuery(_skip, source, Expression.Constant(count));
}
#endregion Skip
#endregion Skip

#region SkipWhile
#region SkipWhile
private static readonly MethodInfo _skipWhilePredicate = GetMethod(nameof(Queryable.SkipWhile), 1, _predicateParameterHas2);

/// <summary>
Expand All @@ -1081,9 +1110,9 @@ public static IQueryable SkipWhile([NotNull] this IQueryable source, [NotNull] s

return CreateQuery(_skipWhilePredicate, source, lambda);
}
#endregion SkipWhile
#endregion SkipWhile

#region Sum
#region Sum
/// <summary>
/// Computes the sum of a sequence of numeric values.
/// </summary>
Expand All @@ -1096,9 +1125,9 @@ public static object Sum([NotNull] this IQueryable source)
var optimized = OptimizeExpression(Expression.Call(typeof(Queryable), "Sum", null, source.Expression));
return source.Provider.Execute(optimized);
}
#endregion Sum
#endregion Sum

#region Take
#region Take
private static readonly MethodInfo _take = GetMethod(nameof(Queryable.Take), 1);
/// <summary>
/// Returns a specified number of contiguous elements from the start of a sequence.
Expand All @@ -1113,9 +1142,9 @@ public static IQueryable Take([NotNull] this IQueryable source, int count)

return CreateQuery(_take, source, Expression.Constant(count));
}
#endregion Take
#endregion Take

#region TakeWhile
#region TakeWhile
private static readonly MethodInfo _takeWhilePredicate = GetMethod(nameof(Queryable.TakeWhile), 1, _predicateParameterHas2);

/// <summary>
Expand Down Expand Up @@ -1144,7 +1173,7 @@ public static IQueryable TakeWhile([NotNull] this IQueryable source, [NotNull] s
}
#endregion TakeWhile

#region ThenBy
#region ThenBy
/// <summary>
/// Performs a subsequent ordering of the elements in a sequence in ascending order according to a key.
/// </summary>
Expand Down Expand Up @@ -1205,9 +1234,9 @@ public static IOrderedQueryable ThenBy([NotNull] this IOrderedQueryable source,
var optimized = OptimizeExpression(queryExpr);
return (IOrderedQueryable)source.Provider.CreateQuery(optimized);
}
#endregion OrderBy
#endregion OrderBy

#region Where
#region Where
/// <summary>
/// Filters a sequence of values based on a predicate.
/// </summary>
Expand Down Expand Up @@ -1260,9 +1289,9 @@ public static IQueryable Where([NotNull] this IQueryable source, [NotNull] strin
var optimized = OptimizeExpression(Expression.Call(typeof(Queryable), "Where", new[] { source.ElementType }, source.Expression, Expression.Quote(lambda)));
return source.Provider.CreateQuery(optimized);
}
#endregion
#endregion

#region Private Helpers
#region Private Helpers
// Code below is based on https://github.com/aspnet/EntityFramework/blob/9186d0b78a3176587eeb0f557c331f635760fe92/src/Microsoft.EntityFrameworkCore/EntityFrameworkQueryableExtensions.cs

private static IQueryable CreateQuery(MethodInfo operatorMethodInfo, IQueryable source)
Expand Down Expand Up @@ -1341,6 +1370,6 @@ private static TResult Execute<TResult>(MethodInfo operatorMethodInfo, IQueryabl

private static MethodInfo GetMethod(string name, int parameterCount = 0, Func<MethodInfo, bool> predicate = null) =>
typeof(Queryable).GetTypeInfo().GetDeclaredMethods(name).Single(mi => (mi.GetParameters().Length == parameterCount + 1) && ((predicate == null) || predicate(mi)));
#endregion Private Helpers
#endregion Private Helpers
}
}

0 comments on commit 26ed2f4

Please sign in to comment.