diff --git a/source/MongoDB.Tests/IntegrationTests/Linq/LinqDomain.cs b/source/MongoDB.Tests/IntegrationTests/Linq/LinqDomain.cs index 9777aa1b..639a760f 100644 --- a/source/MongoDB.Tests/IntegrationTests/Linq/LinqDomain.cs +++ b/source/MongoDB.Tests/IntegrationTests/Linq/LinqDomain.cs @@ -28,6 +28,8 @@ public class Address //[MongoAlias("city")] public string City { get; set; } + public bool IsInternational { get; set; } + public AddressType AddressType { get; set; } } diff --git a/source/MongoDB.Tests/IntegrationTests/Linq/MongoQueryProviderTests.cs b/source/MongoDB.Tests/IntegrationTests/Linq/MongoQueryProviderTests.cs index b4382a7e..e0afcd66 100644 --- a/source/MongoDB.Tests/IntegrationTests/Linq/MongoQueryProviderTests.cs +++ b/source/MongoDB.Tests/IntegrationTests/Linq/MongoQueryProviderTests.cs @@ -9,6 +9,24 @@ namespace MongoDB.IntegrationTests.Linq [TestFixture] public class MongoQueryProviderTests : LinqTestsBase { + [Test] + public void Boolean_Test1() + { + var people = Collection.Linq().Where(x => x.PrimaryAddress.IsInternational); + + var queryObject = ((IMongoQueryable)people).GetQueryObject(); + Assert.AreEqual(new Document("PrimaryAddress.IsInternational", true), queryObject.Query); + } + + [Test] + public void Boolean_Test2() + { + var people = Collection.Linq().Where(x => !x.PrimaryAddress.IsInternational); + + var queryObject = ((IMongoQueryable)people).GetQueryObject(); + Assert.AreEqual(new Document("$not", new Document("PrimaryAddress.IsInternational", true)), queryObject.Query); + } + [Test] public void Chained() { @@ -87,6 +105,18 @@ public void DocumentQuery() Assert.AreEqual(new Document("Age", Op.GreaterThan(21)), queryObject.Query); } + [Test] + public void Enum() + { + var people = Collection.Linq().Where(x => x.PrimaryAddress.AddressType == AddressType.Company); + + var queryObject = ((IMongoQueryable)people).GetQueryObject(); + Assert.AreEqual(0, queryObject.Fields.Count); + Assert.AreEqual(0, queryObject.NumberToLimit); + Assert.AreEqual(0, queryObject.NumberToSkip); + Assert.AreEqual(new Document("PrimaryAddress.AddressType", AddressType.Company), queryObject.Query); + } + [Test] public void LocalEnumerable_Contains() { diff --git a/source/MongoDB.Tests/IntegrationTests/Linq/MongoQueryTests.cs b/source/MongoDB.Tests/IntegrationTests/Linq/MongoQueryTests.cs index d8787987..76f2b3c5 100644 --- a/source/MongoDB.Tests/IntegrationTests/Linq/MongoQueryTests.cs +++ b/source/MongoDB.Tests/IntegrationTests/Linq/MongoQueryTests.cs @@ -20,16 +20,15 @@ public override void TestSetup() FirstName = "Bob", LastName = "McBob", Age = 42, - PrimaryAddress = new Address {City = "London",AddressType = AddressType.Company}, + PrimaryAddress = new Address {City = "London", IsInternational = true, AddressType = AddressType.Company}, Addresses = new List
{ - new Address {City = "London"}, - new Address {City = "Tokyo"}, - new Address {City = "Seattle"} + new Address { City = "London", IsInternational = true, AddressType = AddressType.Company }, + new Address { City = "Tokyo", IsInternational = true, AddressType = AddressType.Private }, + new Address { City = "Seattle", IsInternational = false, AddressType = AddressType.Private } }, - EmployerIds = new[] {1, 2} - }, - true); + EmployerIds = new[] { 1, 2 } + }, true); Collection.Insert( new Person @@ -37,10 +36,10 @@ public override void TestSetup() FirstName = "Jane", LastName = "McJane", Age = 35, - PrimaryAddress = new Address {City = "Paris", AddressType = AddressType.Private}, - Addresses = new List
+ PrimaryAddress = new Address { City = "Paris", IsInternational = true, AddressType = AddressType.Private }, + Addresses = new List
{ - new Address {City = "Paris"} + new Address { City = "Paris", AddressType = AddressType.Private } }, EmployerIds = new[] {1} }, @@ -52,22 +51,38 @@ public override void TestSetup() FirstName = "Joe", LastName = "McJoe", Age = 21, - PrimaryAddress = new Address {City = "Chicago", AddressType = AddressType.Private}, - Addresses = new List
+ PrimaryAddress = new Address { City = "Chicago", IsInternational = true, AddressType = AddressType.Private }, + Addresses = new List
{ - new Address {City = "Chicago"}, - new Address {City = "London"} + new Address { City = "Chicago", AddressType = AddressType.Private }, + new Address { City = "London", AddressType = AddressType.Company } }, EmployerIds = new[] {3} }, true); } + [Test] + public void Boolean_Test1() + { + var people = Enumerable.ToList(Collection.Linq().Where(x => x.PrimaryAddress.IsInternational)); + + Assert.AreEqual(3, people.Count); + } + + [Test] + public void Boolean_Test2() + { + var people = Enumerable.ToList(Collection.Linq().Where(x => !x.PrimaryAddress.IsInternational)); + + Assert.AreEqual(0, people.Count); + } + [Test] public void Chained() { var people = Collection.Linq() - .Select(x => new {Name = x.FirstName + x.LastName, x.Age}) + .Select(x => new { Name = x.FirstName + x.LastName, x.Age }) .Where(x => x.Age > 21) .Select(x => x.Name).ToList(); @@ -101,7 +116,7 @@ public void ConjuctionConstraint() [Test] public void ConstraintsAgainstLocalReferenceMember() { - var local = new {Test = new {Age = 21}}; + var local = new { Test = new { Age = 21 } }; var people = Collection.Linq().Where(p => p.Age > local.Test.Age).ToList(); Assert.AreEqual(2, people.Count); @@ -150,6 +165,16 @@ public void DocumentQuery() Assert.AreEqual(2, people.Count); } + [Test] + public void Enum() + { + var people = Collection.Linq() + .Where(x => x.PrimaryAddress.AddressType == AddressType.Company) + .ToList(); + + Assert.AreEqual(1, people.Count); + } + [Test] public void First() { @@ -161,7 +186,7 @@ public void First() [Test] public void LocalEnumerable_Contains() { - var names = new[] {"Joe", "Bob"}; + var names = new[] { "Joe", "Bob" }; var people = Collection.Linq().Where(x => names.Contains(x.FirstName)).ToList(); Assert.AreEqual(2, people.Count); @@ -170,7 +195,7 @@ public void LocalEnumerable_Contains() [Test] public void LocalList_Contains() { - var names = new List {"Joe", "Bob"}; + var names = new List { "Joe", "Bob" }; var people = Collection.Linq().Where(x => names.Contains(x.FirstName)).ToList(); Assert.AreEqual(2, people.Count); @@ -266,7 +291,7 @@ public void OrderBy() public void Projection() { var people = (from p in Collection.Linq() - select new {Name = p.FirstName + p.LastName}).ToList(); + select new { Name = p.FirstName + p.LastName }).ToList(); Assert.AreEqual(3, people.Count); } @@ -276,7 +301,7 @@ public void ProjectionWithConstraints() { var people = (from p in Collection.Linq() where p.Age > 21 && p.Age < 42 - select new {Name = p.FirstName + p.LastName}).ToList(); + select new { Name = p.FirstName + p.LastName }).ToList(); Assert.AreEqual(1, people.Count); } @@ -352,15 +377,5 @@ public void WithoutConstraints() Assert.AreEqual(3, people.Count); } - - [Test] - public void Enum() - { - var people = Collection.Linq() - .Where(x => x.PrimaryAddress.AddressType == AddressType.Company) - .ToArray(); - - Assert.AreEqual(1, people.Length); - } } } \ No newline at end of file diff --git a/source/MongoDB/Linq/Translators/DocumentFormatter.cs b/source/MongoDB/Linq/Translators/DocumentFormatter.cs index 8d532815..a4e1074e 100644 --- a/source/MongoDB/Linq/Translators/DocumentFormatter.cs +++ b/source/MongoDB/Linq/Translators/DocumentFormatter.cs @@ -16,6 +16,7 @@ internal class DocumentFormatter : MongoExpressionVisitor { private Document _query; private Stack _scopes; + private bool _hasPredicate; internal Document FormatDocument(Expression expression) { @@ -28,7 +29,7 @@ internal Document FormatDocument(Expression expression) protected override Expression VisitBinary(BinaryExpression b) { int scopeDepth = _scopes.Count; - Visit(b.Left); + VisitPredicate(b.Left, true); switch (b.NodeType) { @@ -58,7 +59,7 @@ protected override Expression VisitBinary(BinaryExpression b) throw new NotSupportedException(string.Format("The operation {0} is not supported.", b.NodeType)); } - Visit(b.Right); + VisitPredicate(b.Right, true); while (_scopes.Count > scopeDepth) PopConditionScope(); @@ -74,7 +75,14 @@ protected override Expression VisitConstant(ConstantExpression c) protected override Expression VisitField(FieldExpression f) { - PushConditionScope(f.Name); + if (!_hasPredicate) + { + PushConditionScope(f.Name); + AddCondition(true); + PopConditionScope(); + } + else + PushConditionScope(f.Name); return f; } @@ -84,7 +92,7 @@ protected override Expression VisitMemberAccess(MemberExpression m) { if (m.Member.Name == "Length") { - Visit(m.Expression); + VisitPredicate(m.Expression, true); PushConditionScope("$size"); return m; } @@ -93,7 +101,7 @@ protected override Expression VisitMemberAccess(MemberExpression m) { if (m.Member.Name == "Count") { - Visit(m.Expression); + VisitPredicate(m.Expression, true); PushConditionScope("$size"); return m; } @@ -102,7 +110,7 @@ protected override Expression VisitMemberAccess(MemberExpression m) { if (m.Member.Name == "Count") { - Visit(m.Expression); + VisitPredicate(m.Expression, true); PushConditionScope("$size"); return m; } @@ -125,9 +133,9 @@ protected override Expression VisitMethodCall(MethodCallExpression m) field = m.Arguments[0] as FieldExpression; if (field == null) throw new InvalidQueryException("A mongo field must be a part of the Contains method."); - Visit(field); + VisitPredicate(field, true); PushConditionScope("$elemMatch"); - Visit(m.Arguments[1]); + VisitPredicate(m.Arguments[1], true); PopConditionScope(); //elemMatch PopConditionScope(); //field return m; @@ -139,7 +147,7 @@ protected override Expression VisitMethodCall(MethodCallExpression m) field = m.Arguments[0] as FieldExpression; if (field != null) { - Visit(field); + VisitPredicate(field, true); AddCondition(EvaluateConstant(m.Arguments[1])); PopConditionScope(); return m; @@ -148,7 +156,7 @@ protected override Expression VisitMethodCall(MethodCallExpression m) field = m.Arguments[1] as FieldExpression; if (field == null) throw new InvalidQueryException("A mongo field must be a part of the Contains method."); - Visit(field); + VisitPredicate(field, true); AddCondition("$in", EvaluateConstant(m.Arguments[0])); PopConditionScope(); return m; @@ -170,7 +178,7 @@ protected override Expression VisitMethodCall(MethodCallExpression m) field = m.Arguments[0] as FieldExpression; if (field == null) throw new InvalidQueryException(string.Format("The mongo field must be the argument in method {0}.", m.Method.Name)); - Visit(field); + VisitPredicate(field, true); AddCondition("$in", EvaluateConstant(m.Object).OfType().ToArray()); PopConditionScope(); return m; @@ -181,7 +189,7 @@ protected override Expression VisitMethodCall(MethodCallExpression m) field = m.Object as FieldExpression; if (field == null) throw new InvalidQueryException(string.Format("The mongo field must be the operator for a string operation of type {0}.", m.Method.Name)); - Visit(field); + VisitPredicate(field, true); var value = EvaluateConstant(m.Arguments[0]); if (m.Method.Name == "StartsWith") @@ -204,7 +212,7 @@ protected override Expression VisitMethodCall(MethodCallExpression m) if (field == null) throw new InvalidQueryException(string.Format("The mongo field must be the operator for a string operation of type {0}.", m.Method.Name)); - Visit(field); + VisitPredicate(field, true); string value = null; if (m.Object == null) value = EvaluateConstant(m.Arguments[1]); @@ -226,13 +234,17 @@ protected override Expression VisitUnary(UnaryExpression u) { case ExpressionType.Not: PushConditionScope("$not"); - Visit(u.Operand); + VisitPredicate(u.Operand, false); PopConditionScope(); break; case ExpressionType.ArrayLength: Visit(u.Operand); PushConditionScope("$size"); break; + case ExpressionType.Convert: + case ExpressionType.ConvertChecked: + Visit(u.Operand); + break; default: throw new NotSupportedException(string.Format("The unary operator {0} is not supported.", u.NodeType)); } @@ -281,6 +293,14 @@ private void PopConditionScope() doc[scope.Key] = scope.Value; } + private void VisitPredicate(Expression expression, bool hasPredicate) + { + var oldHasPredicate = _hasPredicate; + _hasPredicate = hasPredicate; + Visit(expression); + _hasPredicate = oldHasPredicate; + } + private static T EvaluateConstant(Expression e) { if (e.NodeType != ExpressionType.Constant) @@ -289,6 +309,11 @@ private static T EvaluateConstant(Expression e) return (T)((ConstantExpression)e).Value; } + private static bool IsBoolean(Expression expression) + { + return expression.Type == typeof(bool) || expression.Type == typeof(bool?); + } + private class Scope { public string Key { get; private set; } diff --git a/source/MongoDB/Linq/Translators/JavascriptFormatter.cs b/source/MongoDB/Linq/Translators/JavascriptFormatter.cs index 815aaa8c..e64c2615 100644 --- a/source/MongoDB/Linq/Translators/JavascriptFormatter.cs +++ b/source/MongoDB/Linq/Translators/JavascriptFormatter.cs @@ -244,11 +244,24 @@ protected override Expression VisitUnary(UnaryExpression u) { switch (u.NodeType) { + case ExpressionType.Negate: + case ExpressionType.NegateChecked: + _js.Append("-"); + Visit(u.Operand); + break; + case ExpressionType.UnaryPlus: + _js.Append("+"); + Visit(u.Operand); + break; case ExpressionType.Not: _js.Append("!("); Visit(u.Operand); _js.Append(")"); break; + case ExpressionType.Convert: + case ExpressionType.ConvertChecked: + Visit(u.Operand); + break; default: throw new NotSupportedException(string.Format("The unary operator {0} is not supported.", u.NodeType)); }