Permalink
Browse files

Find AR::Relations in build_where, append binds

Possible fix for #272.
  • Loading branch information...
ernie committed Oct 2, 2013
1 parent db1e244 commit eaa0033e0cc04c1cfa80bdbd2bfda31d3bc74e40
@@ -55,6 +55,33 @@ def build_arel
arel
end
def build_where(opts, other = [])
case opts
when String, Array
super
else # Let's prevent PredicateBuilder from doing its thing
[opts, *other].map do |arg|
case arg
when Array # Just in case there's an array in there somewhere
@klass.send(:sanitize_sql, arg)
when Hash
attrs = @klass.send(:expand_hash_conditions_for_aggregates, arg)
attrs.values.grep(::ActiveRecord::Relation) do |rel|
self.bind_values += rel.bind_values
end
attrs
when Squeel::Nodes::Node
arg.grep(::ActiveRecord::Relation) do |rel|
self.bind_values += rel.bind_values
end
arg
else
arg
end
end
end
end
def build_from
opts, name = from_visit(from_value)
case opts
View
@@ -1,6 +1,23 @@
module Squeel
module Nodes
class Node
def each(&block)
return enum_for(:each) unless block_given?
Visitors::EnumerationVisitor.new(block).accept(self)
end
# We don't want the full Enumerable method list, because it will mess
# with stuff like KeyPath
def grep(object, &block)
if block_given?
each { |value| yield value if object === value }
else
[].tap do |results|
each { |value| results << value if object === value }
end
end
end
end
end
end
View
@@ -8,3 +8,5 @@
require 'squeel/visitors/select_visitor'
require 'squeel/visitors/from_visitor'
require 'squeel/visitors/preload_visitor'
require 'squeel/visitors/enumeration_visitor'
@@ -0,0 +1,101 @@
require 'active_support/core_ext/module'
require 'squeel/nodes'
module Squeel
module Visitors
# The Enumeration visitor class, used to implement Node#each
class EnumerationVisitor
# Create a new EnumerationVisitor.
#
# @param [Proc] block The block to execute against each node.
def initialize(block = Proc.new)
@block = block
end
# Accept an object.
#
# @param object The object to visit
# @return The results of the node visitation, which will be the last
# call to the @block
def accept(object)
visit(object)
end
private
# A hash that caches the method name to use for a visitor for a given
# class
DISPATCH = Hash.new do |hash, klass|
hash[klass] = "visit_#{(klass.name || '').gsub('::', '_')}"
end
# Visit the object.
#
# @param object The object to visit
def visit(object)
send(DISPATCH[object.class], object)
@block.call(object)
rescue NoMethodError => e
raise e if respond_to?(DISPATCH[object.class], true)
superklass = object.class.ancestors.find { |klass|
respond_to?(DISPATCH[klass], true)
}
raise(TypeError, "Cannot visit #{object.class}") unless superklass
DISPATCH[object.class] = DISPATCH[superklass]
retry
end
def visit_terminal(o)
end
alias :visit_Object :visit_terminal
def visit_Array(o)
o.map { |v| visit(v) }
end
def visit_Hash(o)
o.each { |k, v| visit(k); visit(v) }
end
def visit_Squeel_Nodes_Nary(o)
visit(o.children)
end
def visit_Squeel_Nodes_Binary(o)
visit(o.left)
visit(o.right)
end
def visit_Squeel_Nodes_Unary(o)
visit(o.expr)
end
def visit_Squeel_Nodes_Order(o)
visit(o.expr)
end
def visit_Squeel_Nodes_Function(o)
visit(o.args)
end
def visit_Squeel_Nodes_Predicate(o)
visit(o.expr)
visit(o.value)
end
def visit_Squeel_Nodes_KeyPath(o)
visit(o.path)
end
def visit_Squeel_Nodes_Join(o)
visit(o._join)
end
def visit_Squeel_Nodes_Literal(o)
visit(o.expr)
end
end
end
end
@@ -504,6 +504,18 @@ module ActiveRecord
old_and_busted.to_a.should eq new_hotness.to_a
end
it 'allows a subquery from an association in a hash' do
scope = Person.first.articles
articles = Article.where(:id => scope)
articles.should have(3).articles
end
it 'allows a subquery from an association in a Squeel node' do
scope = Person.first.articles
articles = Article.where{id.in scope}
articles.should have(3).articles
end
it 'is backwards-compatible with "where.not"' do
if activerecord_version_at_least '4.0.0'
name = Person.first.name

0 comments on commit eaa0033

Please sign in to comment.