Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view

This file was deleted.

142 changes: 61 additions & 81 deletions Orm/Xtensive.Orm/Orm/Linq/MemberCompilation/MemberCompilerProvider.cs
Original file line number Diff line number Diff line change
Expand Up @@ -17,21 +17,42 @@ namespace Xtensive.Orm.Linq.MemberCompilation
{
internal partial class MemberCompilerProvider<T> : LockableBase, IMemberCompilerProvider<T>
{
private MemberCompilerCollection compilers = new MemberCompilerCollection();
private readonly struct CompilerKey : IEquatable<CompilerKey>
{
private readonly Module module;
private readonly int metadataToken;

public bool Equals(CompilerKey other) => metadataToken == other.metadataToken
&& (ReferenceEquals(module, other.module) || module == other.module);

public override bool Equals(object obj) => obj is CompilerKey other && Equals(other);

public override int GetHashCode()
{
unchecked {
return module == null ? metadataToken : (module.GetHashCode() * 397) ^ metadataToken;
}
}

public Type ExpressionType { get { return typeof(T); } }
public CompilerKey(MemberInfo memberInfo)
{
module = memberInfo.Module;
metadataToken = memberInfo.MetadataToken;
}
}

private readonly Dictionary<CompilerKey, Delegate> compilers
= new Dictionary<CompilerKey, Delegate>();

public Type ExpressionType => typeof(T);

public Delegate GetUntypedCompiler(MemberInfo target)
{
ArgumentValidator.EnsureArgumentNotNull(target, "target");

var actualTarget = GetCanonicalMember(target);
if (actualTarget == null)
return null;
var registration = compilers.Get(actualTarget);
if (registration == null)
return null;
return registration.CompilerInvoker;
ArgumentValidator.EnsureArgumentNotNull(target, nameof(target));

return compilers.TryGetValue(GetCompilerKey(target), out var compiler)
? compiler
: null;
}

public Func<T, T[], T> GetCompiler(MemberInfo target)
Expand All @@ -54,16 +75,11 @@ public void RegisterCompilers(Type compilerContainer, ConflictHandlingMethod con
throw new InvalidOperationException(string.Format(
Strings.ExTypeXShouldNotBeGeneric, compilerContainer.GetFullName(true)));

var compilersToRegister = new MemberCompilerCollection();

var compilerMethods = compilerContainer
.GetMethods(BindingFlags.Public | BindingFlags.Static | BindingFlags.DeclaredOnly)
.Where(method => method.IsDefined(typeof (CompilerAttribute), false) && !method.IsGenericMethod);

foreach (var compiler in compilerMethods)
compilersToRegister.Add(ProcessCompiler(compiler));

UpdateRegistry(compilersToRegister, conflictHandlingMethod);
UpdateRegistry(compilerMethods.Select(ProcessCompiler), conflictHandlingMethod);
}

public void RegisterCompilers(IEnumerable<KeyValuePair<MemberInfo, Func<MemberInfo, T, T[], T>>> compilerDefinitions)
Expand All @@ -76,71 +92,38 @@ public void RegisterCompilers(IEnumerable<KeyValuePair<MemberInfo, Func<MemberIn
ArgumentValidator.EnsureArgumentNotNull(compilerDefinitions, "compilerDefinitions");
this.EnsureNotLocked();

var newItems = new MemberCompilerCollection();
foreach (var item in compilerDefinitions)
newItems.Add(new MemberCompilerRegistration(GetCanonicalMember(item.Key), item.Value));
var newItems = compilerDefinitions.Select(item => (item.Key, (Delegate) item.Value));
UpdateRegistry(newItems, conflictHandlingMethod);
}

#region Private methods

private void UpdateRegistry(MemberCompilerCollection newItems, ConflictHandlingMethod conflictHandlingMethod)
private void UpdateRegistry(
IEnumerable<(MemberInfo targetMember, Delegate compiler)> newRegistrations, ConflictHandlingMethod conflictHandlingMethod)
{
if (newItems.Count==0)
return;
switch (conflictHandlingMethod) {
case ConflictHandlingMethod.KeepOld:
newItems.MergeWith(compilers, false);
compilers = newItems;
break;
case ConflictHandlingMethod.Overwrite:
compilers.MergeWith(newItems, false);
break;
case ConflictHandlingMethod.ReportError:
compilers.MergeWith(newItems, true);
break;
foreach (var (targetMember, compiler) in newRegistrations) {
var key = GetCompilerKey(targetMember);
if (conflictHandlingMethod != ConflictHandlingMethod.Overwrite && compilers.ContainsKey(key)) {
if (conflictHandlingMethod == ConflictHandlingMethod.ReportError) {
throw new InvalidOperationException(string.Format(
Strings.ExCompilerForXIsAlreadyRegistered, targetMember.GetFullName(true)));
}
continue;
}
compilers[key] = compiler;
}
}

private static bool ParameterTypeMatches(Type inputParameterType, Type candidateParameterType)
{
return inputParameterType.IsGenericParameter
? candidateParameterType==inputParameterType
: (candidateParameterType.IsGenericParameter || inputParameterType==candidateParameterType);
}

private static bool AllParameterTypesMatch(
IEnumerable<Type> inputParameterTypes, IEnumerable<Type> candidateParameterTypes)
{
return inputParameterTypes
.Zip(candidateParameterTypes)
.All(pair => ParameterTypeMatches(pair.First, pair.Second));
}

private static MethodBase GetCanonicalMethod(MethodBase inputMethod, MethodBase[] possibleCanonicalMethods)
{
var inputParameterTypes = inputMethod.GetParameterTypes();

var candidates = possibleCanonicalMethods
.Where(candidate => candidate.Name==inputMethod.Name
&& candidate.GetParameters().Length==inputParameterTypes.Length
&& candidate.IsStatic==inputMethod.IsStatic)
.ToArray();

if (candidates.Length==0)
return null;
if (candidates.Length==1)
return candidates[0];

candidates = candidates
.Where(candidate =>
AllParameterTypesMatch(inputParameterTypes, candidate.GetParameterTypes()))
.ToArray();

if (candidates.Length!=1)
return null;
foreach (var candidate in possibleCanonicalMethods) {
if (inputMethod.MetadataToken == candidate.MetadataToken
&& (ReferenceEquals(inputMethod.Module, candidate.Module) || inputMethod.Module == candidate.Module)) {
return candidate;
}
}

return candidates[0];
return null;
}

private static Type[] ValidateCompilerParametersAndExtractTargetSignature(MethodInfo compiler, bool requireMemberInfo)
Expand Down Expand Up @@ -176,7 +159,7 @@ private static Type[] ValidateCompilerParametersAndExtractTargetSignature(Method
return result;
}

private static MemberCompilerRegistration ProcessCompiler(MethodInfo compiler)
private static (MemberInfo targetMember, Delegate compilerInvoker) ProcessCompiler(MethodInfo compiler)
{
var attribute = compiler.GetAttribute<CompilerAttribute>(AttributeSearchOptions.InheritNone);

Expand Down Expand Up @@ -271,7 +254,7 @@ private static MemberCompilerRegistration ProcessCompiler(MethodInfo compiler)
compiler.GetFullName(true)));

var invoker = WrapInvoker(CreateInvoker(compiler, isStatic || isCtor, isGeneric));
return new MemberCompilerRegistration(targetMember, invoker);
return (targetMember, invoker);
}

private static Func<MemberInfo, T, T[], T> WrapInvoker(Func<MemberInfo, T, T[], T> invoker)
Expand Down Expand Up @@ -318,21 +301,18 @@ private static void ValidateCompilerParameter(ParameterInfo parameter, Type requ
compiler.GetFullName(true), parameter.Name, requiredType.GetFullName(true)));
}

private static MemberInfo GetCanonicalMember(MemberInfo member)
private static CompilerKey GetCompilerKey(MemberInfo member)
{
var canonicalMember = member;
var sourceProperty = canonicalMember as PropertyInfo;
if (sourceProperty!=null) {
canonicalMember = sourceProperty.GetGetMethod();
// GetGetMethod returns null in case of non public getter.
if (canonicalMember==null)
return null;
if (canonicalMember==null) {
return default;
}
}

var sourceMethod = canonicalMember as MethodInfo;
if (sourceMethod!=null && sourceMethod.IsGenericMethod)
canonicalMember = sourceMethod.GetGenericMethodDefinition();

var targetType = canonicalMember.ReflectedType;
if (targetType.IsGenericType) {
targetType = targetType.GetGenericTypeDefinition();
Expand All @@ -348,7 +328,7 @@ private static MemberInfo GetCanonicalMember(MemberInfo member)
}

if (canonicalMember == null) {
return null;
return default;
}

if (targetType.IsEnum) {
Expand All @@ -359,7 +339,7 @@ private static MemberInfo GetCanonicalMember(MemberInfo member)
canonicalMember = GetCanonicalMethod((MethodInfo) canonicalMember, targetType.GetMethods());
}

return canonicalMember;
return new CompilerKey(canonicalMember);
}

#endregion
Expand Down

This file was deleted.

11 changes: 6 additions & 5 deletions Orm/Xtensive.Orm/Orm/Linq/Model/QueryParser.cs
Original file line number Diff line number Diff line change
Expand Up @@ -7,36 +7,37 @@
using System;
using System.Linq.Expressions;
using Xtensive.Core;
using Xtensive.Reflection;

namespace Xtensive.Orm.Linq.Model
{
internal static class QueryParser
{
public static GroupByQuery ParseGroupBy(MethodCallExpression mc)
{
var method = mc.Method.GetGenericMethodDefinition();
var method = mc.Method;

if (method==QueryableMethodInfo.GroupBy)
if (method.IsGenericMethodSpecificationOf(QueryableMethodInfo.GroupBy))
return new GroupByQuery {
Source = mc.Arguments[0],
KeySelector = mc.Arguments[1].StripQuotes(),
};

if (method==QueryableMethodInfo.GroupByWithElementSelector)
if (method.IsGenericMethodSpecificationOf(QueryableMethodInfo.GroupByWithElementSelector))
return new GroupByQuery {
Source = mc.Arguments[0],
KeySelector = mc.Arguments[1].StripQuotes(),
ElementSelector = mc.Arguments[2].StripQuotes(),
};

if (method==QueryableMethodInfo.GroupByWithResultSelector)
if (method.IsGenericMethodSpecificationOf(QueryableMethodInfo.GroupByWithResultSelector))
return new GroupByQuery {
Source = mc.Arguments[0],
KeySelector = mc.Arguments[1].StripQuotes(),
ResultSelector = mc.Arguments[2].StripQuotes(),
};

if (method==QueryableMethodInfo.GroupByWithElementAndResultSelectors)
if (method.IsGenericMethodSpecificationOf(QueryableMethodInfo.GroupByWithElementAndResultSelectors))
return new GroupByQuery {
Source = mc.Arguments[0],
KeySelector = mc.Arguments[1].StripQuotes(),
Expand Down
Loading