diff --git a/Extensions/Xtensive.Orm.BulkOperations.Tests/Issues/JoinedTableAsSourceForDelete.cs b/Extensions/Xtensive.Orm.BulkOperations.Tests/Issues/JoinedTableAsSourceForDelete.cs new file mode 100644 index 0000000000..9a08893b45 --- /dev/null +++ b/Extensions/Xtensive.Orm.BulkOperations.Tests/Issues/JoinedTableAsSourceForDelete.cs @@ -0,0 +1,92 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using NUnit.Framework; +using TestCommon.Model; +using Xtensive.Orm.BulkOperations.Tests.Issues.WrongAliassesIssue; +using Xtensive.Sql; + +namespace Xtensive.Orm.BulkOperations.Tests.Issues +{ + public class JoinedTableAsSourceForOperationsCauseWrongAliases : AutoBuildTest + { + [Test] + public void CustomerCase() + { + using (var session = Domain.OpenSession()) + using (var tx = session.OpenTransaction()) { + var query = session.Query.All() + .Where(r => r.DowntimeInfo.Record.Equipment.Id == 333); + + var queryResult = query.Delete(); + } + } + + [Test] + public void MultipleKeyTest() + { + using (var session = Domain.OpenSession()) + using (var tx = session.OpenTransaction()) { + var query = session.Query.All() + .Where(r => r.DowntimeInfo.Record.Equipment.Id == 333); + + var queryResult = query.Delete(); + } + } + } +} + +namespace Xtensive.Orm.BulkOperations.Tests.Issues.WrongAliassesIssue +{ + [HierarchyRoot] + public class DowntimeReason : Entity + { + [Field, Key] + public int Id { get; private set; } + + [Field] + public DowntimeInfo DowntimeInfo { get; set; } + } + + [HierarchyRoot] + [KeyGenerator(KeyGeneratorKind.None)] + public class DowntimeReason2 : Entity + { + [Field, Key(0)] + public int Id { get; private set; } + + [Field, Key(1)] + public int Id2 { get; private set; } + + [Field] + public DowntimeInfo DowntimeInfo { get; set; } + } + + [HierarchyRoot] + public class DowntimeInfo : Entity + { + [Field, Key] + public int Id { get; private set; } + + [Field] + public Record Record { get; set; } + } + + [HierarchyRoot] + public class Record : Entity + { + [Field, Key] + public int Id { get; private set; } + + [Field] + public Equipment Equipment { get; set; } + } + + [HierarchyRoot] + public class Equipment : Entity + { + [Field, Key] + public int Id { get; private set; } + } +} \ No newline at end of file diff --git a/Extensions/Xtensive.Orm.BulkOperations/Internals/Operation.cs b/Extensions/Xtensive.Orm.BulkOperations/Internals/Operation.cs index 5176778834..7adcab32da 100644 --- a/Extensions/Xtensive.Orm.BulkOperations/Internals/Operation.cs +++ b/Extensions/Xtensive.Orm.BulkOperations/Internals/Operation.cs @@ -4,15 +4,12 @@ using System; using System.Collections.Generic; -using System.Data.Common; using System.Linq; using System.Reflection; using Xtensive.Core; using Xtensive.Orm.Linq; -using Xtensive.Orm.Model; using Xtensive.Orm.Providers; using Xtensive.Orm.Services; -using Xtensive.Sql.Model; using Xtensive.Sql; using Xtensive.Sql.Dml; using QueryParameterBinding = Xtensive.Orm.Services.QueryParameterBinding; @@ -24,21 +21,24 @@ internal abstract class Operation where T : class, IEntity { private static readonly MethodInfo TranslateQueryMethod = typeof(QueryBuilder).GetMethod("TranslateQuery"); + public readonly QueryProvider QueryProvider; + public readonly QueryBuilder QueryBuilder; public List Bindings; - protected DomainHandler DomainHandler; - protected PrimaryIndexMapping[] PrimaryIndexes; - public QueryBuilder QueryBuilder; - public Session Session; - protected TypeInfo TypeInfo; public SqlTableRef JoinedTableRef; + protected readonly DomainHandler DomainHandler; + protected readonly PrimaryIndexMapping[] PrimaryIndexes; + protected readonly TypeInfo TypeInfo; + + public Session Session { get { return QueryBuilder.Session; } } + public int Execute() { EnsureTransactionIsStarted(); - QueryProvider.Session.SaveChanges(); + Session.SaveChanges(); int value = ExecuteInternal(); - SessionStateAccessor accessor = DirectStateAccessor.Get(QueryProvider.Session); + var accessor = DirectStateAccessor.Get(Session); accessor.Invalidate(); return value; } @@ -49,45 +49,36 @@ protected void EnsureTransactionIsStarted() { var accessor = QueryProvider.Session.Services.Demand(); #pragma warning disable 168 - DbTransaction notUsed = accessor.Transaction; + var notUsed = accessor.Transaction; #pragma warning restore 168 } protected abstract int ExecuteInternal(); - public QueryTranslationResult GetRequest(IQueryable query) - { - return QueryBuilder.TranslateQuery(query); - } + public QueryTranslationResult GetRequest(IQueryable query) => QueryBuilder.TranslateQuery(query); - public QueryTranslationResult GetRequest(Type type, IQueryable query) - { - return - (QueryTranslationResult) TranslateQueryMethod.MakeGenericMethod(type).Invoke(QueryBuilder, new object[] {query}); - } + public QueryTranslationResult GetRequest(Type type, IQueryable query) => + (QueryTranslationResult) TranslateQueryMethod.MakeGenericMethod(type).Invoke(QueryBuilder, new object[] { query }); - public TypeInfo GetTypeInfo(Type entityType) - { - return Session.Domain.Model.Hierarchies.SelectMany(a => a.Types).Single(a => a.UnderlyingType==entityType); - } + public TypeInfo GetTypeInfo(Type entityType) => + Session.Domain.Model.Hierarchies.SelectMany(a => a.Types).Single(a => a.UnderlyingType == entityType); #endregion protected Operation(QueryProvider queryProvider) { QueryProvider = queryProvider; - Type entityType = typeof (T); - Session = queryProvider.Session; - DomainHandler = Session.Domain.Services.Get(); - TypeInfo = - queryProvider.Session.Domain.Model.Hierarchies.SelectMany(a => a.Types).Single( - a => a.UnderlyingType==entityType); - var mapping = Session.StorageNode.Mapping; + var entityType = typeof(T); + var session = queryProvider.Session; + DomainHandler = session.Domain.Services.Get(); + TypeInfo = DomainHandler.Domain.Model.Hierarchies.SelectMany(a => a.Types) + .Single(a => a.UnderlyingType == entityType); + var mapping = session.StorageNode.Mapping; PrimaryIndexes = TypeInfo.AffectedIndexes .Where(i => i.IsPrimary) .Select(i => new PrimaryIndexMapping(i, mapping[i.ReflectedType])) .ToArray(); - QueryBuilder = Session.Services.Get(); + QueryBuilder = session.Services.Get(); } protected QueryCommand ToCommand(SqlStatement statement) diff --git a/Extensions/Xtensive.Orm.BulkOperations/Internals/QueryOperation.cs b/Extensions/Xtensive.Orm.BulkOperations/Internals/QueryOperation.cs index c99c0dda84..36f496ffad 100644 --- a/Extensions/Xtensive.Orm.BulkOperations/Internals/QueryOperation.cs +++ b/Extensions/Xtensive.Orm.BulkOperations/Internals/QueryOperation.cs @@ -3,6 +3,7 @@ using System.Linq; using System.Linq.Expressions; using System.Reflection; +using Xtensive.Core; using Xtensive.Orm.Linq; using Xtensive.Orm.Model; using Xtensive.Orm.Rse; @@ -14,7 +15,7 @@ namespace Xtensive.Orm.BulkOperations internal abstract class QueryOperation : Operation where T : class, IEntity { - private readonly static MethodInfo inMethod = GetInMethod(); + private readonly static MethodInfo InMethod = GetInMethod(); protected IQueryable query; protected QueryOperation(QueryProvider queryProvider) @@ -24,55 +25,55 @@ protected QueryOperation(QueryProvider queryProvider) private static MethodInfo GetInMethod() { - foreach (var method in typeof (QueryableExtensions).GetMethods().Where(a => string.Equals(a.Name, "In", StringComparison.Ordinal))) { + foreach (var method in typeof(QueryableExtensions).GetMethods().Where(a => string.Equals(a.Name, "In", StringComparison.Ordinal))) { var parameters = method.GetParameters(); - if (parameters.Length == 3 && string.Equals(parameters[2].ParameterType.Name, "IEnumerable`1", StringComparison.Ordinal)) + if (parameters.Length == 3 && string.Equals(parameters[2].ParameterType.Name, "IEnumerable`1", StringComparison.Ordinal)) { return method; + } } return null; } protected override int ExecuteInternal() { - Expression e = query.Expression.Visit((MethodCallExpression ex) => - { - var methodInfo = ex.Method; - //rewrite localCollection.Contains(entity.SomeField) -> entity.SomeField.In(localCollection) - if (methodInfo.DeclaringType == typeof(Enumerable) && - string.Equals(methodInfo.Name, "Contains", StringComparison.Ordinal) && - ex.Arguments.Count == 2) { - var localCollection = ex.Arguments[0];//IEnumerable - var valueToCheck = ex.Arguments[1]; - var genericInMethod = inMethod.MakeGenericMethod(new[] { valueToCheck.Type }); - ex = Expression.Call(genericInMethod, valueToCheck, Expression.Constant(IncludeAlgorithm.ComplexCondition), localCollection); - methodInfo = ex.Method; - } + var e = query.Expression.Visit((MethodCallExpression ex) => { + var methodInfo = ex.Method; + //rewrite localCollection.Contains(entity.SomeField) -> entity.SomeField.In(localCollection) + if (methodInfo.DeclaringType == typeof(Enumerable) && + string.Equals(methodInfo.Name, "Contains", StringComparison.Ordinal) && + ex.Arguments.Count == 2) { + var localCollection = ex.Arguments[0];//IEnumerable + var valueToCheck = ex.Arguments[1]; + var genericInMethod = InMethod.MakeGenericMethod(new[] { valueToCheck.Type }); + ex = Expression.Call(genericInMethod, valueToCheck, Expression.Constant(IncludeAlgorithm.ComplexCondition), localCollection); + methodInfo = ex.Method; + } - if (methodInfo.DeclaringType == typeof(QueryableExtensions) && - string.Equals(methodInfo.Name, "In", StringComparison.Ordinal) && - ex.Arguments.Count > 1) { - if (ex.Arguments[1].Type == typeof(IncludeAlgorithm)) { - var algorithm = (IncludeAlgorithm) ex.Arguments[1].Invoke(); - if (algorithm == IncludeAlgorithm.TemporaryTable) { - throw new NotSupportedException("IncludeAlgorithm.TemporaryTable is not supported"); - } - if (algorithm == IncludeAlgorithm.Auto) { - List arguments = ex.Arguments.ToList(); - arguments[1] = Expression.Constant(IncludeAlgorithm.ComplexCondition); - ex = Expression.Call(methodInfo, arguments); - } + if (methodInfo.DeclaringType == typeof(QueryableExtensions) && + string.Equals(methodInfo.Name, "In", StringComparison.Ordinal) && + ex.Arguments.Count > 1) { + if (ex.Arguments[1].Type == typeof(IncludeAlgorithm)) { + var algorithm = (IncludeAlgorithm) ex.Arguments[1].Invoke(); + if (algorithm == IncludeAlgorithm.TemporaryTable) { + throw new NotSupportedException("IncludeAlgorithm.TemporaryTable is not supported"); } - else { - List arguments = ex.Arguments.ToList(); - arguments.Insert(1, Expression.Constant(IncludeAlgorithm.ComplexCondition)); - List types = methodInfo.GetParameters().Select(a => a.ParameterType).ToList(); - types.Insert(1, typeof(IncludeAlgorithm)); - ex = Expression.Call(inMethod.MakeGenericMethod(methodInfo.GetGenericArguments()), - arguments.ToArray()); + if (algorithm == IncludeAlgorithm.Auto) { + var arguments = ex.Arguments.ToList(); + arguments[1] = Expression.Constant(IncludeAlgorithm.ComplexCondition); + ex = Expression.Call(methodInfo, arguments); } } - return ex; - }); + else { + var arguments = ex.Arguments.ToList(); + arguments.Insert(1, Expression.Constant(IncludeAlgorithm.ComplexCondition)); + var types = methodInfo.GetParameters().Select(a => a.ParameterType).ToList(); + types.Insert(1, typeof(IncludeAlgorithm)); + ex = Expression.Call(InMethod.MakeGenericMethod(methodInfo.GetGenericArguments()), + arguments.ToArray()); + } + } + return ex; + }); query = QueryProvider.CreateQuery(e); return 0; } @@ -84,10 +85,12 @@ protected override int ExecuteInternal() protected void Join(SqlQueryStatement statement, SqlSelect select) { - if (select.HasLimit) + if (select.HasLimit) { JoinWhenQueryHasLimitation(statement, select); - else + } + else { JoinWhenQueryHasNoLimitation(statement, select); + } } protected abstract void SetStatementFrom(SqlStatement statement, SqlTable from); @@ -99,59 +102,60 @@ protected void Join(SqlQueryStatement statement, SqlSelect select) private void JoinWhenQueryHasLimitation(SqlStatement statement, SqlSelect select) { - if (!SupportsLimitation() && !SupportsJoin()) + if (!SupportsLimitation() && !SupportsJoin()) { throw new NotSupportedException("This provider does not supported limitation of affected rows."); + } - var sqlTableRef = @select.From as SqlTableRef; - if (sqlTableRef!=null) { + if (@select.From is SqlTableRef sqlTableRef) { SetStatementTable(statement, sqlTableRef); JoinedTableRef = sqlTableRef; } - else + else { sqlTableRef = GetStatementTable(statement); + } + if (SupportsLimitation()) { SetStatementLimit(statement, select.Limit); SetStatementWhere(statement, select.Where); } - else + else { JoinViaFrom(statement, select); - - + } } private void JoinWhenQueryHasNoLimitation(SqlStatement statement, SqlSelect select) { - var sqlTableRef = @select.From as SqlTableRef; - if (sqlTableRef!=null) { + if (@select.From is SqlTableRef sqlTableRef) { SetStatementTable(statement, sqlTableRef); SetStatementWhere(statement, select.Where); JoinedTableRef = sqlTableRef; return; } - - if (SupportsJoin()) + if (SupportsJoin()) { JoinViaFrom(statement, select); - else + } + else { JoinViaIn(statement, select); + } } private void JoinViaIn(SqlStatement statement, SqlSelect @select) { - SqlTableRef table = GetStatementTable(statement); - SqlExpression where = GetStatementWhere(statement); + var table = GetStatementTable(statement); + var where = GetStatementWhere(statement); JoinedTableRef = table; - PrimaryIndexMapping indexMapping = PrimaryIndexes[0]; + var indexMapping = PrimaryIndexes[0]; var columns = new List(); - foreach (ColumnInfo columnInfo in indexMapping.PrimaryIndex.KeyColumns.Keys) - { - SqlSelect s = select.ShallowClone(); - foreach (ColumnInfo column in columns) - { - SqlBinary ex = SqlDml.Equals(SqlDml.TableColumn(s.From, column.Name), SqlDml.TableColumn(table, column.Name)); + foreach (var columnInfo in indexMapping.PrimaryIndex.KeyColumns.Keys) { + var s = (SqlSelect) select.Clone(); + foreach (var column in columns) { + var ex = SqlDml.Equals(s.From.Columns[column.Name], table.Columns[column.Name]); s.Where = s.Where.IsNullReference() ? ex : SqlDml.And(s.Where, ex); } + var existingColumns = s.Columns.ToChainedBuffer(); s.Columns.Clear(); - s.Columns.Add(SqlDml.TableColumn(s.From, columnInfo.Name)); - SqlBinary @in = SqlDml.In(SqlDml.TableColumn(table, columnInfo.Name), s); + var columnToAdd = existingColumns.First(c => c.Name.Equals(columnInfo.Name, StringComparison.Ordinal)); + s.Columns.Add(columnToAdd); + var @in = SqlDml.In(SqlDml.TableColumn(table, columnInfo.Name), s); @where = @where.IsNullReference() ? @in : SqlDml.And(@where, @in); columns.Add(columnInfo); } @@ -165,37 +169,40 @@ private void JoinViaFrom(SqlStatement statement, SqlSelect select) SetStatementFrom(statement, queryRef); var sqlTableRef = GetStatementTable(statement); SqlExpression whereExpression = null; - PrimaryIndexMapping indexMapping = PrimaryIndexes[0]; - foreach (ColumnInfo columnInfo in indexMapping.PrimaryIndex.KeyColumns.Keys) { + var indexMapping = PrimaryIndexes[0]; + foreach (var columnInfo in indexMapping.PrimaryIndex.KeyColumns.Keys) { var leftColumn = queryRef.Columns[columnInfo.Name]; - var rightColumn = sqlTableRef==null ? GetStatementTable(statement).Columns[columnInfo.Name] : sqlTableRef.Columns[columnInfo.Name]; - if (leftColumn==null || rightColumn==null) + var rightColumn = sqlTableRef == null + ? GetStatementTable(statement).Columns[columnInfo.Name] + : sqlTableRef.Columns[columnInfo.Name]; + if (leftColumn == null || rightColumn == null) { throw new InvalidOperationException("Source query doesn't contain one of key columns of updated table."); + } var columnEqualityExperssion = SqlDml.Equals(queryRef.Columns[columnInfo.Name], sqlTableRef.Columns[columnInfo.Name]); - if (whereExpression==null) - whereExpression = columnEqualityExperssion; - else - whereExpression = SqlDml.And(whereExpression, columnEqualityExperssion); + whereExpression = whereExpression == null + ? columnEqualityExperssion + : SqlDml.And(whereExpression, columnEqualityExperssion); } SetStatementWhere(statement, whereExpression); } private void JoinViaJoin(SqlStatement statement, SqlSelect @select) { - PrimaryIndexMapping indexMapping = PrimaryIndexes[0]; - SqlTableRef left = SqlDml.TableRef(indexMapping.Table); - SqlQueryRef right = SqlDml.QueryRef(@select); + var indexMapping = PrimaryIndexes[0]; + var left = SqlDml.TableRef(indexMapping.Table); + var right = SqlDml.QueryRef(@select); SqlExpression joinExpression = null; - for (int i = 0; i < indexMapping.PrimaryIndex.KeyColumns.Count; i++) - { - SqlBinary binary = (left.Columns[i] == right.Columns[i]); - if (joinExpression.IsNullReference()) + for (var i = 0; i < indexMapping.PrimaryIndex.KeyColumns.Count; i++) { + var binary = (left.Columns[i] == right.Columns[i]); + if (joinExpression.IsNullReference()) { joinExpression = binary; - else + } + else { joinExpression &= binary; + } } JoinedTableRef = left; - SqlJoinedTable joinedTable = left.InnerJoin(right, joinExpression); + var joinedTable = left.InnerJoin(right, joinExpression); SetStatementFrom(statement, joinedTable); }