Skip to content

HTTPS clone URL

Subversion checkout URL

You can clone with
or
.
Download ZIP
Browse files

Find AR::Relations in build_where, append binds

Possible fix for #272.
  • Loading branch information...
commit eaa0033e0cc04c1cfa80bdbd2bfda31d3bc74e40 1 parent db1e244
@ernie ernie authored
View
27 lib/squeel/adapters/active_record/4.0/relation_extensions.rb
@@ -55,6 +55,33 @@ def build_arel
arel
end
+ def build_where(opts, other = [])
+ case opts
+ when String, Array
+ super
+ else # Let's prevent PredicateBuilder from doing its thing
+ [opts, *other].map do |arg|
+ case arg
+ when Array # Just in case there's an array in there somewhere
+ @klass.send(:sanitize_sql, arg)
+ when Hash
+ attrs = @klass.send(:expand_hash_conditions_for_aggregates, arg)
+ attrs.values.grep(::ActiveRecord::Relation) do |rel|
+ self.bind_values += rel.bind_values
+ end
+ attrs
+ when Squeel::Nodes::Node
+ arg.grep(::ActiveRecord::Relation) do |rel|
+ self.bind_values += rel.bind_values
+ end
+ arg
+ else
+ arg
+ end
+ end
+ end
+ end
+
def build_from
opts, name = from_visit(from_value)
case opts
View
17 lib/squeel/nodes/node.rb
@@ -1,6 +1,23 @@
module Squeel
module Nodes
class Node
+ def each(&block)
+ return enum_for(:each) unless block_given?
+
+ Visitors::EnumerationVisitor.new(block).accept(self)
+ end
+
+ # We don't want the full Enumerable method list, because it will mess
+ # with stuff like KeyPath
+ def grep(object, &block)
+ if block_given?
+ each { |value| yield value if object === value }
+ else
+ [].tap do |results|
+ each { |value| results << value if object === value }
+ end
+ end
+ end
end
end
end
View
2  lib/squeel/visitors.rb
@@ -8,3 +8,5 @@
require 'squeel/visitors/select_visitor'
require 'squeel/visitors/from_visitor'
require 'squeel/visitors/preload_visitor'
+
+require 'squeel/visitors/enumeration_visitor'
View
101 lib/squeel/visitors/enumeration_visitor.rb
@@ -0,0 +1,101 @@
+require 'active_support/core_ext/module'
+require 'squeel/nodes'
+
+module Squeel
+ module Visitors
+ # The Enumeration visitor class, used to implement Node#each
+ class EnumerationVisitor
+
+ # Create a new EnumerationVisitor.
+ #
+ # @param [Proc] block The block to execute against each node.
+ def initialize(block = Proc.new)
+ @block = block
+ end
+
+ # Accept an object.
+ #
+ # @param object The object to visit
+ # @return The results of the node visitation, which will be the last
+ # call to the @block
+ def accept(object)
+ visit(object)
+ end
+
+ private
+ # A hash that caches the method name to use for a visitor for a given
+ # class
+ DISPATCH = Hash.new do |hash, klass|
+ hash[klass] = "visit_#{(klass.name || '').gsub('::', '_')}"
+ end
+
+ # Visit the object.
+ #
+ # @param object The object to visit
+ def visit(object)
+ send(DISPATCH[object.class], object)
+ @block.call(object)
+ rescue NoMethodError => e
+ raise e if respond_to?(DISPATCH[object.class], true)
+
+ superklass = object.class.ancestors.find { |klass|
+ respond_to?(DISPATCH[klass], true)
+ }
+ raise(TypeError, "Cannot visit #{object.class}") unless superklass
+ DISPATCH[object.class] = DISPATCH[superklass]
+ retry
+ end
+
+ def visit_terminal(o)
+ end
+ alias :visit_Object :visit_terminal
+
+ def visit_Array(o)
+ o.map { |v| visit(v) }
+ end
+
+ def visit_Hash(o)
+ o.each { |k, v| visit(k); visit(v) }
+ end
+
+ def visit_Squeel_Nodes_Nary(o)
+ visit(o.children)
+ end
+
+ def visit_Squeel_Nodes_Binary(o)
+ visit(o.left)
+ visit(o.right)
+ end
+
+ def visit_Squeel_Nodes_Unary(o)
+ visit(o.expr)
+ end
+
+ def visit_Squeel_Nodes_Order(o)
+ visit(o.expr)
+ end
+
+ def visit_Squeel_Nodes_Function(o)
+ visit(o.args)
+ end
+
+ def visit_Squeel_Nodes_Predicate(o)
+ visit(o.expr)
+ visit(o.value)
+ end
+
+ def visit_Squeel_Nodes_KeyPath(o)
+ visit(o.path)
+ end
+
+ def visit_Squeel_Nodes_Join(o)
+ visit(o._join)
+ end
+
+ def visit_Squeel_Nodes_Literal(o)
+ visit(o.expr)
+ end
+
+ end
+ end
+end
View
12 spec/squeel/adapters/active_record/relation_extensions_spec.rb
@@ -504,6 +504,18 @@ module ActiveRecord
old_and_busted.to_a.should eq new_hotness.to_a
end
+ it 'allows a subquery from an association in a hash' do
+ scope = Person.first.articles
+ articles = Article.where(:id => scope)
+ articles.should have(3).articles
+ end
+
+ it 'allows a subquery from an association in a Squeel node' do
+ scope = Person.first.articles
+ articles = Article.where{id.in scope}
+ articles.should have(3).articles
+ end
+
it 'is backwards-compatible with "where.not"' do
if activerecord_version_at_least '4.0.0'
name = Person.first.name
Please sign in to comment.
Something went wrong with that request. Please try again.