From aa7a9c1b564b2667db7fbd41e09ab72f5d58dcdb Mon Sep 17 00:00:00 2001 From: funky81 Date: Tue, 18 Aug 2009 13:27:25 +0700 Subject: [PATCH] Enhancement for projecting into new type class Signed-off-by: funky81 --- SubSonic.Core/Extensions/Database.cs | 36 ++++++++----- .../Linq/Structure/DbQueryProvider.cs | 45 +++++++++------- .../Linq/Structure/ExecutionBuilder.cs | 52 ++++++++++--------- SubSonic.Core/Linq/Structure/QueryCommand.cs | 6 ++- .../Linq/Translation/Parameterizer.cs | 36 ++++++++++--- 5 files changed, 110 insertions(+), 65 deletions(-) diff --git a/SubSonic.Core/Extensions/Database.cs b/SubSonic.Core/Extensions/Database.cs index a32c8d8..4b6414d 100644 --- a/SubSonic.Core/Extensions/Database.cs +++ b/SubSonic.Core/Extensions/Database.cs @@ -138,10 +138,7 @@ public static List ToConstraintList(this object value) return query.Constraints; } - /// - /// Coerces an IDataReader to try and load an object using name/property matching - /// - public static void Load(this IDataReader rdr, T item) + public static void Load(this IDataReader rdr, T item, List ColumnNames) { Type iType = typeof(T); @@ -155,10 +152,11 @@ public static void Load(this IDataReader rdr, T item) { string pName = rdr.GetName(i); currentProp = cachedProps.SingleOrDefault(x => (x.Name.EndsWith("X") ? x.Name.Chop(1) : x.Name).Equals(pName, StringComparison.InvariantCultureIgnoreCase)); - - if (currentProp == null) - /** maybe this is projection **/ - currentProp = cachedProps[i]; + + if (currentProp == null && ColumnNames.Count != 0) + { + currentProp = cachedProps.First(x => x.Name == ColumnNames[i]); + } //if the property is null, likely it's a Field if (currentProp == null) @@ -199,6 +197,13 @@ public static void Load(this IDataReader rdr, T item) } } + /// + /// Coerces an IDataReader to try and load an object using name/property matching + /// + public static void Load(this IDataReader rdr, T item) + { + Load(rdr, item, null); + } /// /// Loads a single primitive value type @@ -266,12 +271,19 @@ private static bool IsCoreSystemType(Type type) type == typeof(bool?); } + public static IEnumerable ToEnumerable(this IDataReader rdr) + { + return ToEnumerable(rdr); + } + /// - /// Coerces an IDataReader to load an enumerable of T + /// Make into Enumerable /// /// - /// - public static IEnumerable ToEnumerable(this IDataReader rdr) + /// The RDR. + /// The column names. + /// + public static IEnumerable ToEnumerable(this IDataReader rdr, List ColumnNames) { List result = new List(); while (rdr.Read()) @@ -302,7 +314,7 @@ public static IEnumerable ToEnumerable(this IDataReader rdr) instance = Activator.CreateInstance(); //do we have a parameterless constructor? - Load(rdr, instance); + Load(rdr, instance, ColumnNames); result.Add(instance); } return result.AsEnumerable(); diff --git a/SubSonic.Core/Linq/Structure/DbQueryProvider.cs b/SubSonic.Core/Linq/Structure/DbQueryProvider.cs index 8efd1fd..2ad4689 100644 --- a/SubSonic.Core/Linq/Structure/DbQueryProvider.cs +++ b/SubSonic.Core/Linq/Structure/DbQueryProvider.cs @@ -41,7 +41,7 @@ public DbQueryProvider(IDataProvider provider, QueryPolicy paramPolicy, TextWrit language = mapping.Language; //log = log; } - + public DbQueryProvider(IDataProvider provider) { _provider = provider; @@ -52,7 +52,7 @@ public DbQueryProvider(IDataProvider provider) case DataClient.MySqlClient: lang = new MySqlLanguage(_provider); break; - case DataClient.SQLite: + case DataClient.SQLite: lang = new SqliteLanguage(_provider); break; default: @@ -126,7 +126,7 @@ public override object Execute(Expression expression) else { // compile the execution plan and invoke it - Expression> efn = ConvertThis(plan,typeof(object)); + Expression> efn = ConvertThis(plan, typeof(object)); Func fn = efn.Compile(); return fn(); } @@ -247,34 +247,35 @@ public virtual IEnumerable Execute(QueryCommand query, object[] paramVa //DbDataReader reader = cmd.ExecuteReader(); //return Project(reader, query.Projector); - QueryCommand cmd = new QueryCommand(query.CommandText, _provider); for (int i = 0; i < paramValues.Length; i++) { - + //need to assign a DbType var valueType = paramValues[i].GetType(); var dbType = Database.GetDbType(valueType); - - - cmd.AddParameter(query.ParameterNames[i], paramValues[i],dbType); + + + cmd.AddParameter(query.ParameterNames[i], paramValues[i], dbType); } -/* - var reader = _provider.ExecuteReader(cmd); - var result = Project(reader, query.Projector); - return result; -*/ + /* + var reader = _provider.ExecuteReader(cmd); + var result = Project(reader, query.Projector); + return result; + */ IEnumerable result; - Type type = typeof (T); + Type type = typeof(T); //this is so hacky - the issue is that the Projector below uses Expression.Convert, which is a bottleneck //it's about 10x slower than our ToEnumerable. Our ToEnumerable, however, stumbles on Anon types and groupings //since it doesn't know how to instantiate them (I tried - not smart enough). So we do some trickery here. - if (type.Name.Contains("AnonymousType") || type.Name.StartsWith("Grouping`") || type.FullName.StartsWith("System.")) { + if (type.Name.Contains("AnonymousType") || type.Name.StartsWith("Grouping`") || type.FullName.StartsWith("System.")) + { var reader = _provider.ExecuteReader(cmd); result = Project(reader, query.Projector); - } else + } + else { using (var reader = _provider.ExecuteReader(cmd)) @@ -282,16 +283,20 @@ public virtual IEnumerable Execute(QueryCommand query, object[] paramVa //use our reader stuff //thanks to Pascal LaCroix for the help here... - var resultType = typeof (T); + var resultType = typeof(T); + + var test = mapping.GetMappedMembers(resultType); + if (resultType.IsValueType) { result = reader.ToEnumerableValueType(); - } else { - result = reader.ToEnumerable(); - + if (query.ColumnNames.Count != 0) + result = reader.ToEnumerable(query.ColumnNames); + else + result = reader.ToEnumerable(); } } } diff --git a/SubSonic.Core/Linq/Structure/ExecutionBuilder.cs b/SubSonic.Core/Linq/Structure/ExecutionBuilder.cs index 4670319..a7d2fbe 100644 --- a/SubSonic.Core/Linq/Structure/ExecutionBuilder.cs +++ b/SubSonic.Core/Linq/Structure/ExecutionBuilder.cs @@ -74,8 +74,8 @@ private static Expression MakeSequence(IList expressions) Expression last = expressions[expressions.Count - 1]; return Expression.Convert( - Expression.Call(typeof (ExecutionBuilder), "Sequence", null, - Expression.NewArrayInit(typeof (object), expressions)), last.Type); + Expression.Call(typeof(ExecutionBuilder), "Sequence", null, + Expression.NewArrayInit(typeof(object), expressions)), last.Type); } public static object Sequence(params object[] values) @@ -85,7 +85,7 @@ public static object Sequence(params object[] values) private static Expression MakeAssign(Expression variable, Expression value) { - return Expression.Call(typeof (ExecutionBuilder), "Assign", new[] {variable.Type}, variable, value); + return Expression.Call(typeof(ExecutionBuilder), "Assign", new[] { variable.Type }, variable, value); } public static T Assign(ref T variable, T value) @@ -121,11 +121,11 @@ private static Expression MakeJoinKey(IList key) return key[0]; return - Expression.New(typeof (CompoundKey).GetConstructors()[0], - Expression.NewArrayInit(typeof (object), + Expression.New(typeof(CompoundKey).GetConstructors()[0], + Expression.NewArrayInit(typeof(object), key.Select(k => (Expression) - Expression.Convert(k, typeof (object + Expression.Convert(k, typeof(object ))))); } @@ -138,8 +138,8 @@ protected override Expression VisitClientJoin(ClientJoinExpression join) Expression outerKey = MakeJoinKey(join.OuterKey); ConstructorInfo kvpConstructor = - typeof (KeyValuePair<,>).MakeGenericType(innerKey.Type, join.Projection.Projector.Type).GetConstructor( - new[] {innerKey.Type, join.Projection.Projector.Type}); + typeof(KeyValuePair<,>).MakeGenericType(innerKey.Type, join.Projection.Projector.Type).GetConstructor( + new[] { innerKey.Type, join.Projection.Projector.Type }); Expression constructKVPair = Expression.New(kvpConstructor, innerKey, join.Projection.Projector); ProjectionExpression newProjection = new ProjectionExpression(join.Projection.Source, constructKVPair); @@ -149,7 +149,7 @@ protected override Expression VisitClientJoin(ClientJoinExpression join) ParameterExpression kvp = Expression.Parameter(constructKVPair.Type, "kvp"); // filter out nulls - if (join.Projection.Projector.NodeType == (ExpressionType) DbExpressionType.OuterJoined) + if (join.Projection.Projector.NodeType == (ExpressionType)DbExpressionType.OuterJoined) { LambdaExpression pred = Expression.Lambda( Expression.NotEqual( @@ -158,14 +158,14 @@ protected override Expression VisitClientJoin(ClientJoinExpression join) ), kvp ); - execution = Expression.Call(typeof (Enumerable), "Where", new[] {kvp.Type}, execution, pred); + execution = Expression.Call(typeof(Enumerable), "Where", new[] { kvp.Type }, execution, pred); } // make lookup LambdaExpression keySelector = Expression.Lambda(Expression.PropertyOrField(kvp, "Key"), kvp); LambdaExpression elementSelector = Expression.Lambda(Expression.PropertyOrField(kvp, "Value"), kvp); - Expression toLookup = Expression.Call(typeof (Enumerable), "ToLookup", - new[] {kvp.Type, outerKey.Type, join.Projection.Projector.Type}, + Expression toLookup = Expression.Call(typeof(Enumerable), "ToLookup", + new[] { kvp.Type, outerKey.Type, join.Projection.Projector.Type }, execution, keySelector, elementSelector); // 2) agg(lookup[outer]) @@ -200,24 +200,26 @@ private Expression ExecuteProjection(ProjectionExpression projection, bool okayT okayToDefer &= (receivingMember != null && policy.IsDeferLoaded(receivingMember)); // parameterize query - projection = (ProjectionExpression) Parameterizer.Parameterize(projection); + projection = (ProjectionExpression)Parameterizer.Parameterize(projection); if (scope != null) { // also convert references to outer alias to named values! these become SQL parameters too - projection = (ProjectionExpression) OuterParameterizer.Parameterize(scope.Alias, projection); + projection = (ProjectionExpression)OuterParameterizer.Parameterize(scope.Alias, projection); } var saveScope = scope; - ParameterExpression reader = Expression.Parameter(typeof (DbDataReader), "r" + nReaders++); + ParameterExpression reader = Expression.Parameter(typeof(DbDataReader), "r" + nReaders++); scope = new Scope(scope, reader, projection.Source.Alias, projection.Source.Columns); LambdaExpression projector = Expression.Lambda(Visit(projection.Projector), reader); scope = saveScope; + List columnNames = ColumnNamedGatherer.Gather(projector.Body); + string commandText = policy.Mapping.Language.Format(projection.Source); ReadOnlyCollection namedValues = NamedValueGatherer.Gather(projection.Source); string[] names = namedValues.Select(v => v.Name).ToArray(); - Expression[] values = namedValues.Select(v => Expression.Convert(Visit(v.Value), typeof (object))).ToArray(); + Expression[] values = namedValues.Select(v => Expression.Convert(Visit(v.Value), typeof(object))).ToArray(); string methExecute = okayToDefer ? "ExecuteDeferred" @@ -228,15 +230,15 @@ private Expression ExecuteProjection(ProjectionExpression projection, bool okayT } // call low-level execute directly on supplied DbQueryProvider - Expression result = Expression.Call(provider, methExecute, new[] {projector.Body.Type}, + Expression result = Expression.Call(provider, methExecute, new[] { projector.Body.Type }, Expression.New( - typeof (QueryCommand<>).MakeGenericType(projector.Body.Type). + typeof(QueryCommand<>).MakeGenericType(projector.Body.Type). GetConstructors()[0], Expression.Constant(commandText), Expression.Constant(names), - projector + projector, Expression.Constant(columnNames) ), - Expression.NewArrayInit(typeof (object), values) + Expression.NewArrayInit(typeof(object), values) ); if (projection.Aggregator != null) @@ -251,7 +253,7 @@ private Expression ExecuteProjection(ProjectionExpression projection, bool okayT protected override Expression VisitOuterJoined(OuterJoinedExpression outer) { Expression expr = Visit(outer.Expression); - ColumnExpression column = (ColumnExpression) outer.Test; + ColumnExpression column = (ColumnExpression)outer.Test; ParameterExpression reader; int iOrdinal; if (scope.TryGetValue(column, out reader, out iOrdinal)) @@ -285,7 +287,7 @@ protected override Expression VisitColumn(ColumnExpression column) // this sucks, but since we don't track true SQL types through the query, and ADO throws exception if you // call the wrong accessor, the best we can do is call GetValue and Convert.ChangeType Expression value = Expression.Convert( - Expression.Call(typeof (Convert), "ChangeType", null, + Expression.Call(typeof(Convert), "ChangeType", null, Expression.Call(reader, "GetValue", null, Expression.Constant(iOrdinal)), Expression.Constant(TypeHelper.GetNonNullableType(column.Type)) ), @@ -317,13 +319,13 @@ private class OuterParameterizer : DbExpressionVisitor internal static Expression Parameterize(TableAlias outerAlias, Expression expr) { - OuterParameterizer op = new OuterParameterizer {outerAlias = outerAlias}; + OuterParameterizer op = new OuterParameterizer { outerAlias = outerAlias }; return op.Visit(expr); } protected override Expression VisitProjection(ProjectionExpression proj) { - SelectExpression select = (SelectExpression) Visit(proj.Source); + SelectExpression select = (SelectExpression)Visit(proj.Source); if (select != proj.Source) { return new ProjectionExpression(select, proj.Projector, proj.Aggregator); @@ -363,7 +365,7 @@ private class Scope this.outer = outer; dbDataReader = dbDataReaderParam; Alias = alias; - nameMap = columns.Select((c, i) => new {c, i}).ToDictionary(x => x.c.Name, x => x.i); + nameMap = columns.Select((c, i) => new { c, i }).ToDictionary(x => x.c.Name, x => x.i); } internal TableAlias Alias { get; private set; } diff --git a/SubSonic.Core/Linq/Structure/QueryCommand.cs b/SubSonic.Core/Linq/Structure/QueryCommand.cs index a87b7f4..a13b67f 100644 --- a/SubSonic.Core/Linq/Structure/QueryCommand.cs +++ b/SubSonic.Core/Linq/Structure/QueryCommand.cs @@ -8,16 +8,20 @@ using System.Collections.Generic; using System.Collections.ObjectModel; using System.Data.Common; +using System.Data; namespace SubSonic.Linq.Structure { public class QueryCommand { - public QueryCommand(string commandText, IEnumerable paramNames, Func projector) + public List ColumnNames = new List(); + public QueryCommand(string commandText, IEnumerable paramNames, Func projector,List ColumnNames) { CommandText = commandText; ParameterNames = new List(paramNames).AsReadOnly(); Projector = projector; + this.ColumnNames = ColumnNames; + } public string CommandText { get; private set; } diff --git a/SubSonic.Core/Linq/Translation/Parameterizer.cs b/SubSonic.Core/Linq/Translation/Parameterizer.cs index 1f48c39..4d69d23 100644 --- a/SubSonic.Core/Linq/Translation/Parameterizer.cs +++ b/SubSonic.Core/Linq/Translation/Parameterizer.cs @@ -21,7 +21,7 @@ namespace SubSonic.Linq.Translation public class Parameterizer : DbExpressionVisitor { Dictionary map = new Dictionary(); - Dictionary pmap = new Dictionary(); + Dictionary pmap = new Dictionary(); private Parameterizer() { @@ -36,7 +36,8 @@ protected override Expression VisitProjection(ProjectionExpression proj) { // don't parameterize the projector or aggregator! SelectExpression select = (SelectExpression)this.Visit(proj.Source); - if (select != proj.Source) { + if (select != proj.Source) + { return new ProjectionExpression(select, proj.Projector, proj.Aggregator); } return proj; @@ -45,9 +46,11 @@ protected override Expression VisitProjection(ProjectionExpression proj) int iParam = 0; protected override Expression VisitConstant(ConstantExpression c) { - if (c.Value != null && !IsNumeric(c.Value.GetType())) { + if (c.Value != null && !IsNumeric(c.Value.GetType())) + { NamedValueExpression nv; - if (!this.map.TryGetValue(c.Value, out nv)) { // re-use same name-value if same value + if (!this.map.TryGetValue(c.Value, out nv)) + { // re-use same name-value if same value string name = "p" + (iParam++); nv = new NamedValueExpression(name, c); this.map.Add(c.Value, nv); @@ -57,7 +60,7 @@ protected override Expression VisitConstant(ConstantExpression c) return c; } - protected override Expression VisitParameter(ParameterExpression p) + protected override Expression VisitParameter(ParameterExpression p) { return this.GetNamedValue(p); } @@ -76,7 +79,8 @@ private Expression GetNamedValue(Expression e) private bool IsNumeric(Type type) { - switch (Type.GetTypeCode(type)) { + switch (Type.GetTypeCode(type)) + { case TypeCode.Boolean: case TypeCode.Byte: case TypeCode.Decimal: @@ -95,7 +99,25 @@ private bool IsNumeric(Type type) } } } - + internal class ColumnNamedGatherer : DbExpressionVisitor + { + List columnNames = new List(); + internal static List Gather(Expression ex) + { + ColumnNamedGatherer gatherer = new ColumnNamedGatherer(); + gatherer.Visit(ex); + return gatherer.columnNames; + } + protected override Expression VisitMemberInit(MemberInitExpression init) + { + //var ex = this.VisitBinding(init.Bindings[0]); + foreach (var binding in init.Bindings) + { + this.columnNames.Add(binding.Member.Name); + } + return init; + } + } internal class NamedValueGatherer : DbExpressionVisitor { HashSet namedValues = new HashSet();